import sys
import os
import pandas as pd
import pdfplumber
import json
import gradio as gr
from typing import List, Dict, Optional, Generator
from concurrent.futures import ThreadPoolExecutor, as_completed
import hashlib
import shutil
import re
import psutil
import subprocess
import logging
import torch
import gc
from diskcache import Cache
import time
from transformers import AutoTokenizer
from functools import lru_cache
import numpy as np
from difflib import SequenceMatcher

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Constants
MAX_TOKENS = 1800
BATCH_SIZE = 1
MAX_WORKERS = 2
CHUNK_SIZE = 5
MODEL_MAX_TOKENS = 131072
MAX_TEXT_LENGTH = 500000

# Persistent directory setup
persistent_dir = "/data/hf_cache"
os.makedirs(persistent_dir, exist_ok=True)

model_cache_dir = os.path.join(persistent_dir, "txagent_models")
tool_cache_dir = os.path.join(persistent_dir, "tool_cache")
file_cache_dir = os.path.join(persistent_dir, "cache")
report_dir = os.path.join(persistent_dir, "reports")
vllm_cache_dir = os.path.join(persistent_dir, "vllm_cache")

for directory in [model_cache_dir, tool_cache_dir, file_cache_dir, report_dir, vllm_cache_dir]:
    os.makedirs(directory, exist_ok=True)

os.environ.update({
    "HF_HOME": model_cache_dir,
    "TRANSFORMERS_CACHE": model_cache_dir,
    "VLLM_CACHE_DIR": vllm_cache_dir,
    "TOKENIZERS_PARALLELISM": "false",
    "CUDA_LAUNCH_BLOCKING": "1"
})

current_dir = os.path.dirname(os.path.abspath(__file__))
src_path = os.path.abspath(os.path.join(current_dir, "src"))
sys.path.insert(0, src_path)

from txagent.txagent import TxAgent

# Initialize cache with 10GB limit
cache = Cache(file_cache_dir, size_limit=10 * 1024**3)


@lru_cache(maxsize=1)
def get_tokenizer():
    return AutoTokenizer.from_pretrained("mims-harvard/TxAgent-T1-Llama-3.1-8B")


def sanitize_utf8(text: str) -> str:
    return text.encode("utf-8", "ignore").decode("utf-8")


def file_hash(path: str) -> str:
    hash_md5 = hashlib.md5()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(4096), b""):
            hash_md5.update(chunk)
    return hash_md5.hexdigest()


def extract_pdf_page(page, tokenizer, max_tokens=MAX_TOKENS) -> List[str]:
    try:
        text = page.extract_text() or ""
        text = sanitize_utf8(text)
        if len(text) > MAX_TEXT_LENGTH // 10:
            text = text[:MAX_TEXT_LENGTH // 10]
        tokens = tokenizer.encode(text, add_special_tokens=False)
        if len(tokens) > max_tokens:
            chunks = []
            current_chunk = []
            current_length = 0
            for token in tokens:
                if current_length + 1 > max_tokens:
                    chunks.append(tokenizer.decode(current_chunk))
                    current_chunk = [token]
                    current_length = 1
                else:
                    current_chunk.append(token)
                    current_length += 1
            if current_chunk:
                chunks.append(tokenizer.decode(current_chunk))
            return [f"=== Page {page.page_number} ===\n{c}" for c in chunks]
        return [f"=== Page {page.page_number} ===\n{text}"]
    except Exception as e:
        logger.warning(f"Error extracting page {page.page_number}: {str(e)}")
        return []
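# Hedged sketch (not part of the original script): the same fixed-window token
# chunking that extract_pdf_page performs inline, factored out for plain
# strings. The helper name is illustrative only; adopt or discard as needed.
def _chunk_text_by_tokens(text: str, tokenizer, max_tokens: int = MAX_TOKENS) -> List[str]:
    """Split text into consecutive chunks of at most max_tokens tokens each."""
    tokens = tokenizer.encode(text, add_special_tokens=False)
    return [
        tokenizer.decode(tokens[i:i + max_tokens])
        for i in range(0, len(tokens), max_tokens)
    ]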
def extract_all_pages(file_path: str, progress_callback=None) -> List[str]:
    try:
        tokenizer = get_tokenizer()
        with pdfplumber.open(file_path) as pdf:
            total_pages = len(pdf.pages)
        if total_pages == 0:
            return []

        results = []
        total_tokens = 0
        for chunk_start in range(0, total_pages, CHUNK_SIZE):
            chunk_end = min(chunk_start + CHUNK_SIZE, total_pages)
            # Re-open the PDF per chunk so pages can be released between batches
            with pdfplumber.open(file_path) as pdf:
                with ThreadPoolExecutor(max_workers=min(CHUNK_SIZE, 2)) as executor:
                    futures = [executor.submit(extract_pdf_page, pdf.pages[i], tokenizer)
                               for i in range(chunk_start, chunk_end)]
                    for future in as_completed(futures):
                        page_chunks = future.result()
                        for chunk in page_chunks:
                            chunk_tokens = len(tokenizer.encode(chunk, add_special_tokens=False))
                            if total_tokens + chunk_tokens > MODEL_MAX_TOKENS:
                                logger.warning("Total tokens exceed model limit. Stopping.")
                                return results
                            results.append(chunk)
                            total_tokens += chunk_tokens
            if progress_callback:
                progress_callback(min(chunk_end, total_pages), total_pages)
            del pdf
            gc.collect()
        return results
    except Exception as e:
        logger.error(f"PDF processing error: {e}")
        return [f"PDF processing error: {str(e)}"]


def excel_to_json(file_path: str) -> List[Dict]:
    # Try openpyxl first, then fall back to xlrd for legacy .xls files
    last_error = None
    for engine in ("openpyxl", "xlrd"):
        try:
            with pd.ExcelFile(file_path, engine=engine) as excel_file:
                results = []
                for sheet_name in excel_file.sheet_names:
                    df = pd.read_excel(
                        excel_file,
                        sheet_name=sheet_name,
                        header=None,
                        dtype=str,
                        na_filter=False
                    )
                    if not df.empty:
                        results.append({
                            "filename": f"{os.path.basename(file_path)} - {sheet_name}",
                            "rows": df.values.tolist(),
                            "type": "excel"
                        })
                return results if results else [{"error": "No data found in any sheet"}]
        except Exception as e:
            last_error = e
    logger.error(f"Excel processing failed: {last_error}")
    return [{"error": f"Excel processing failed: {str(last_error)}"}]


def csv_to_json(file_path: str) -> List[Dict]:
    try:
        chunks = []
        for chunk in pd.read_csv(
            file_path,
            header=None,
            dtype=str,
            encoding_errors='replace',
            on_bad_lines='skip',
            chunksize=10000,
            na_filter=False
        ):
            chunks.append(chunk)
        df = pd.concat(chunks) if chunks else pd.DataFrame()
        return [{
            "filename": os.path.basename(file_path),
            "rows": df.values.tolist(),
            "type": "csv"
        }]
    except Exception as e:
        logger.error(f"CSV processing error: {e}")
        return [{"error": f"CSV processing error: {str(e)}"}]


@lru_cache(maxsize=100)
def process_file_cached(file_path: str, file_type: str) -> List[Dict]:
    try:
        if file_type == "pdf":
            chunks = extract_all_pages(file_path)
            return [{
                "filename": os.path.basename(file_path),
                "content": chunk,
                "status": "initial",
                "type": "pdf"
            } for chunk in chunks]
        elif file_type in ["xls", "xlsx"]:
            return excel_to_json(file_path)
        elif file_type == "csv":
            return csv_to_json(file_path)
        else:
            return [{"error": f"Unsupported file type: {file_type}"}]
    except Exception as e:
        logger.error(f"Error processing file: {e}")
        return [{"error": f"Error processing file: {str(e)}"}]


def clean_response(text: str) -> str:
    if not text:
        return ""
    patterns = [
        (re.compile(r"\[.*?\]|\bNone\b", re.IGNORECASE), ""),
        (re.compile(r"\s+"), " "),
        (re.compile(r"[^\w\s\.\,\(\)\-]"), ""),
    ]
    for pattern, repl in patterns:
        text = pattern.sub(repl, text)

    # Drop near-duplicate sentences (similarity ratio > 0.9)
    sentences = text.split(". ")
    unique_sentences = []
    seen = set()
    for s in sentences:
        if not s:
            continue
        is_unique = True
        for seen_s in seen:
            if SequenceMatcher(None, s.lower(), seen_s.lower()).ratio() > 0.9:
                is_unique = False
                break
        if is_unique:
            unique_sentences.append(s)
            seen.add(s)
    text = ". ".join(unique_sentences).strip()
    return text if text else "No missed diagnoses identified."


@lru_cache(maxsize=1)
def init_agent():
    logger.info("Initializing model...")
    default_tool_path = os.path.abspath("data/new_tool.json")
    target_tool_path = os.path.join(tool_cache_dir, "new_tool.json")
    if not os.path.exists(target_tool_path):
        shutil.copy(default_tool_path, target_tool_path)

    agent = TxAgent(
        model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
        rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
        tool_files_dict={"new_tool": target_tool_path},
        force_finish=True,
        enable_checker=False,
        step_rag_num=4,
        seed=100,
        additional_default_tools=[],
    )
    agent.init_model()
    logger.info("Agent Ready")
    return agent


def create_ui(agent):
    PROMPT_TEMPLATE = """
Analyze the patient record excerpt for missed diagnoses. Provide detailed, evidence-based analysis.

Patient Record Excerpt (Chunk {0} of {1}):
{chunk}
"""
    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown("