import sys
import os
import pandas as pd
import pdfplumber
import json
import gradio as gr
from typing import List, Dict, Optional, Generator
from concurrent.futures import ThreadPoolExecutor, as_completed
import hashlib
import shutil
import re
import psutil
import subprocess
import logging
import torch
import gc
from diskcache import Cache
import time
from transformers import AutoTokenizer

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Persistent directories for model weights, tool definitions, file cache, and reports
persistent_dir = "/data/hf_cache"
os.makedirs(persistent_dir, exist_ok=True)

model_cache_dir = os.path.join(persistent_dir, "txagent_models")
tool_cache_dir = os.path.join(persistent_dir, "tool_cache")
file_cache_dir = os.path.join(persistent_dir, "cache")
report_dir = os.path.join(persistent_dir, "reports")
vllm_cache_dir = os.path.join(persistent_dir, "vllm_cache")

for directory in [model_cache_dir, tool_cache_dir, file_cache_dir, report_dir, vllm_cache_dir]:
    os.makedirs(directory, exist_ok=True)

os.environ["HF_HOME"] = model_cache_dir
os.environ["TRANSFORMERS_CACHE"] = model_cache_dir
os.environ["VLLM_CACHE_DIR"] = vllm_cache_dir
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

current_dir = os.path.dirname(os.path.abspath(__file__))
src_path = os.path.abspath(os.path.join(current_dir, "src"))
sys.path.insert(0, src_path)

from txagent.txagent import TxAgent

# Initialize cache with 10GB limit
cache = Cache(file_cache_dir, size_limit=10 * 1024**3)

# Initialize tokenizer for precise chunking
tokenizer = AutoTokenizer.from_pretrained("mims-harvard/TxAgent-T1-Llama-3.1-8B")


def sanitize_utf8(text: str) -> str:
    """Drop any byte sequences that cannot be represented as UTF-8."""
    return text.encode("utf-8", "ignore").decode("utf-8")


def file_hash(path: str) -> str:
    """Return the MD5 hash of a file, used to key cached reports."""
    with open(path, "rb") as f:
        return hashlib.md5(f.read()).hexdigest()


def extract_all_pages(file_path: str, progress_callback=None) -> str:
    """Extract text from every page of a PDF, processing pages in parallel batches."""
    try:
        with pdfplumber.open(file_path) as pdf:
            total_pages = len(pdf.pages)
        if total_pages == 0:
            return ""

        batch_size = 10
        batches = [(i, min(i + batch_size, total_pages)) for i in range(0, total_pages, batch_size)]
        text_chunks = [""] * total_pages
        processed_pages = 0

        def extract_batch(start: int, end: int) -> List[tuple]:
            results = []
            # Re-open the PDF in each worker; pdfplumber objects are not thread-safe.
            with pdfplumber.open(file_path) as pdf:
                for page_num, page in enumerate(pdf.pages[start:end], start=start):
                    page_text = page.extract_text() or ""
                    results.append((page_num, f"=== Page {page_num + 1} ===\n{page_text.strip()}"))
            return results

        with ThreadPoolExecutor(max_workers=6) as executor:
            futures = [executor.submit(extract_batch, start, end) for start, end in batches]
            for future in as_completed(futures):
                batch_results = future.result()
                for page_num, text in batch_results:
                    text_chunks[page_num] = text
                processed_pages += len(batch_results)
                if progress_callback:
                    progress_callback(min(processed_pages, total_pages), total_pages)

        return "\n\n".join(filter(None, text_chunks))
    except Exception as e:
        logger.error("PDF processing error: %s", e)
        return f"PDF processing error: {str(e)}"


def excel_to_json(file_path: str) -> List[Dict]:
    """Convert an Excel file to JSON-serializable rows with optimized processing."""
    try:
        # Try openpyxl first (faster for .xlsx); fall back to xlrd for legacy .xls files
        try:
            df = pd.read_excel(file_path, engine='openpyxl', header=None, dtype=str)
        except Exception:
            df = pd.read_excel(file_path, engine='xlrd', header=None, dtype=str)

        # Convert to a list of lists, replacing nulls with empty strings
        content = df.where(pd.notnull(df), "").astype(str).values.tolist()

        return [{
            "filename": os.path.basename(file_path),
            "rows": content,
            "type": "excel"
        }]
    except Exception as e:
        logger.error("Error processing Excel file: %s", e)
        return [{"error": f"Error processing Excel file: {str(e)}"}]


def csv_to_json(file_path: str) -> List[Dict]:
    """Convert a CSV file to JSON-serializable rows with optimized processing."""
    try:
        # Read the CSV in chunks so large files do not exhaust memory
        chunks = []
        for chunk in pd.read_csv(
            file_path,
            header=None,
            dtype=str,
            encoding_errors='replace',
            on_bad_lines='skip',
            chunksize=10000
        ):
            chunks.append(chunk)
        df = pd.concat(chunks) if chunks else pd.DataFrame()

        content = df.where(pd.notnull(df), "").astype(str).values.tolist()

        return [{
            "filename": os.path.basename(file_path),
            "rows": content,
            "type": "csv"
        }]
    except Exception as e:
        logger.error("Error processing CSV file: %s", e)
        return [{"error": f"Error processing CSV file: {str(e)}"}]


def process_file(file_path: str, file_type: str) -> List[Dict]:
    """Dispatch a file to the appropriate parser based on its type and return JSON data."""
    try:
        if file_type == "pdf":
            text = extract_all_pages(file_path)
            return [{
                "filename": os.path.basename(file_path),
                "content": text,
                "status": "initial",
                "type": "pdf"
            }]
        elif file_type in ["xls", "xlsx"]:
            return excel_to_json(file_path)
        elif file_type == "csv":
            return csv_to_json(file_path)
        else:
            return [{"error": f"Unsupported file type: {file_type}"}]
    except Exception as e:
        logger.error("Error processing %s: %s", os.path.basename(file_path), e)
        return [{"error": f"Error processing {os.path.basename(file_path)}: {str(e)}"}]


def tokenize_and_chunk(text: str, max_tokens: int = 1800) -> List[str]:
    """Split text into chunks of at most max_tokens tokens."""
    tokens = tokenizer.encode(text)
    chunks = []
    for i in range(0, len(tokens), max_tokens):
        chunk_tokens = tokens[i:i + max_tokens]
        chunks.append(tokenizer.decode(chunk_tokens))
    return chunks


def log_system_usage(tag=""):
    """Log CPU, RAM, and (if available) GPU usage."""
    try:
        cpu = psutil.cpu_percent(interval=1)
        mem = psutil.virtual_memory()
        logger.info("[%s] CPU: %.1f%% | RAM: %dMB / %dMB", tag, cpu,
                    mem.used // (1024**2), mem.total // (1024**2))
        result = subprocess.run(
            ["nvidia-smi", "--query-gpu=memory.used,memory.total,utilization.gpu",
             "--format=csv,nounits,noheader"],
            capture_output=True, text=True
        )
        if result.returncode == 0:
            used, total, util = result.stdout.strip().split(", ")
            logger.info("[%s] GPU: %sMB / %sMB | Utilization: %s%%", tag, used, total, util)
    except Exception as e:
        logger.error("[%s] GPU/CPU monitor failed: %s", tag, e)


def clean_response(text: str) -> str:
    """Strip agent boilerplate and keep only items listed under 'Missed Diagnoses'."""
    text = sanitize_utf8(text)
    text = re.sub(
        r"\[.*?\]|\bNone\b|To analyze the patient record excerpt.*?medications\.|"
        r"Since the previous attempts.*?\.|I need to.*?medications\.|Retrieving tools.*?\.",
        "", text, flags=re.DOTALL
    )
    diagnoses = []
    lines = text.splitlines()
    in_diagnoses_section = False
    for line in lines:
        line = line.strip()
        if not line:
            continue
        if re.match(r"###\s*Missed Diagnoses", line):
            in_diagnoses_section = True
            continue
        if re.match(r"###\s*(Medication Conflicts|Incomplete Assessments|Urgent Follow-up)", line):
            in_diagnoses_section = False
            continue
        if in_diagnoses_section and re.match(r"-\s*.+", line):
            diagnosis = re.sub(r"^\-\s*", "", line).strip()
            if diagnosis and not re.match(r"No issues identified", diagnosis, re.IGNORECASE):
                diagnoses.append(diagnosis)
    text = " ".join(diagnoses)
    text = re.sub(r"\s+", " ", text).strip()
    text = re.sub(r"[^\w\s\.\,\(\)\-]", "", text)
    return text if text else ""
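
# Illustrative sketch (not part of the pipeline): clean_response keeps only the
# bullet items under a "### Missed Diagnoses" heading and flattens them into a
# single cleaned string. The sample text below is made up for demonstration.
#
#   sample = (
#       "### Missed Diagnoses\n"
#       "- Possible untreated hypertension noted on page 10\n"
#       "### Medication Conflicts\n"
#       "- Ibuprofen with warfarin\n"
#   )
#   clean_response(sample)
#   # -> "Possible untreated hypertension noted on page 10"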

def summarize_findings(combined_response: str) -> str:
    """Collapse the per-chunk analyses into a single summary of missed diagnoses."""
    chunks = combined_response.split("--- Analysis for Chunk")
    diagnoses = []
    for chunk in chunks:
        chunk = chunk.strip()
        if not chunk or "No oversights identified" in chunk:
            continue
        lines = chunk.splitlines()
        in_diagnoses_section = False
        for line in lines:
            line = line.strip()
            if not line:
                continue
            if re.match(r"###\s*Missed Diagnoses", line):
                in_diagnoses_section = True
                continue
            if re.match(r"###\s*(Medication Conflicts|Incomplete Assessments|Urgent Follow-up)", line):
                in_diagnoses_section = False
                continue
            if in_diagnoses_section and re.match(r"-\s*.+", line):
                diagnosis = re.sub(r"^\-\s*", "", line).strip()
                if diagnosis and not re.match(r"No issues identified", diagnosis, re.IGNORECASE):
                    diagnoses.append(diagnosis)

    # De-duplicate while preserving order
    seen = set()
    unique_diagnoses = [d for d in diagnoses if not (d in seen or seen.add(d))]

    if not unique_diagnoses:
        return "No missed diagnoses were identified in the provided records."

    if len(unique_diagnoses) == 1:
        summary = "Missed diagnoses include " + unique_diagnoses[0]
    else:
        summary = "Missed diagnoses include " + ", ".join(unique_diagnoses[:-1])
        summary += f", and {unique_diagnoses[-1]}"
    summary += ", all of which require urgent clinical review to prevent potential adverse outcomes."

    return summary.strip()


def init_agent():
    """Load the TxAgent model and its tool definitions."""
    logger.info("Initializing model...")
    log_system_usage("Before Load")

    default_tool_path = os.path.abspath("data/new_tool.json")
    target_tool_path = os.path.join(tool_cache_dir, "new_tool.json")
    if not os.path.exists(target_tool_path):
        shutil.copy(default_tool_path, target_tool_path)

    agent = TxAgent(
        model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
        rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
        tool_files_dict={"new_tool": target_tool_path},
        force_finish=True,
        enable_checker=False,
        step_rag_num=4,
        seed=100,
        additional_default_tools=[],
    )
    agent.init_model()

    log_system_usage("After Load")
    logger.info("Agent Ready")
    return agent


def create_ui(agent):
    """Build the Gradio interface around the loaded agent."""
    with gr.Blocks(theme=gr.themes.Soft(), title="Clinical Oversight Assistant") as demo:
        gr.Markdown("# 🩺 Clinical Oversight Assistant")
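        # Layout built below: a wide chat column for the streamed analysis next to a
        # narrower column holding the running summary and report download, followed
        # by the upload area, the question box, and a hidden progress indicator.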
") with gr.Row(): with gr.Column(scale=3): chatbot = gr.Chatbot( label="Analysis Conversation", height=600, show_copy_button=True, avatar_images=( "assets/user.png", "assets/assistant.png" ), render=False # Disable auto-render for better streaming control ) with gr.Column(scale=1): final_summary = gr.Markdown( label="Summary of Findings", value="### Summary will appear here\nAfter analysis completes" ) download_output = gr.File( label="Download Full Report", visible=False ) with gr.Row(): file_upload = gr.File( file_types=[".pdf", ".csv", ".xls", ".xlsx"], file_count="multiple", label="Upload Patient Records" ) with gr.Row(): msg_input = gr.Textbox( placeholder="Ask about potential oversights...", show_label=False, container=False, scale=7, autofocus=True ) send_btn = gr.Button( "Analyze", variant="primary", scale=1, min_width=100 ) progress_text = gr.Textbox( label="Progress", visible=False, interactive=False ) def update_progress(current, total, stage=""): progress = f"{stage} - {current}/{total}" if stage else f"{current}/{total}" return { progress_text: gr.Textbox( value=progress, visible=True, label=f"Progress: {progress}" ) } prompt_template = """ Analyze the patient record excerpt for missed diagnoses only. Provide a concise, evidence-based summary as a single paragraph without headings or bullet points. Include specific clinical findings (e.g., 'elevated blood pressure (160/95) on page 10'), their potential implications (e.g., 'may indicate untreated hypertension'), and a recommendation for urgent review. Do not include other oversight categories like medication conflicts. If no missed diagnoses are found, state 'No missed diagnoses identified' in a single sentence. Patient Record Excerpt (Chunk {0} of {1}): {chunk} """ def process_response_stream(prompt: str, history: List[dict]) -> Generator[dict, None, None]: """Process a single prompt and stream the response""" full_response = "" for chunk_output in agent.run_gradio_chat(prompt, [], 0.2, 512, 2048, False, []): if chunk_output is None: continue if isinstance(chunk_output, list): for m in chunk_output: if hasattr(m, 'content') and m.content: cleaned = clean_response(m.content) if cleaned: full_response += cleaned + " " yield {"role": "assistant", "content": full_response} elif isinstance(chunk_output, str) and chunk_output.strip(): cleaned = clean_response(chunk_output) if cleaned: full_response += cleaned + " " yield {"role": "assistant", "content": full_response} return full_response def analyze(message: str, history: List[dict], files: List) -> Generator[dict, None, None]: # Start with user message history.append({"role": "user", "content": message}) yield { "chatbot": history, "download_output": None, "final_summary": "", "progress_text": gr.Textbox(visible=True, value="Starting analysis...") } extracted = [] file_hash_value = "" if files: # Process files in parallel with ThreadPoolExecutor(max_workers=4) as executor: futures = [] for f in files: file_type = f.name.split(".")[-1].lower() futures.append(executor.submit(process_file, f.name, file_type)) for i, future in enumerate(as_completed(futures), 1): try: extracted.extend(future.result()) yield update_progress(i, len(files), "Processing files") except Exception as e: logger.error(f"File processing error: {e}") extracted.append({"error": f"Error processing file: {str(e)}"}) file_hash_value = file_hash(files[0].name) if files else "" history.append({"role": "assistant", "content": "✅ File processing complete"}) yield { "chatbot": history, "download_output": None, 
"final_summary": "", "progress_text": gr.Textbox(visible=True, value="Files processed, analyzing content...") } # Convert extracted data to JSON text text_content = "\n".join(json.dumps(item) for item in extracted) # Tokenize and chunk the content properly chunks = tokenize_and_chunk(text_content) combined_response = "" try: for chunk_idx, chunk in enumerate(chunks, 1): prompt = prompt_template.format(chunk_idx, len(chunks), chunk=chunk[:1800]) # Create a placeholder message history.append({"role": "assistant", "content": ""}) yield { "chatbot": history, "download_output": None, "final_summary": "", "progress_text": gr.Textbox( visible=True, value=f"Analyzing chunk {chunk_idx}/{len(chunks)}" ) } # Process and stream the response chunk_response = "" for update in process_response_stream(prompt, history): # Update the last message with streaming content history[-1] = update chunk_response = update["content"] yield { "chatbot": history, "download_output": None, "final_summary": "", "progress_text": gr.Textbox( visible=True, value=f"Analyzing chunk {chunk_idx}/{len(chunks)}" ) } combined_response += f"--- Analysis for Chunk {chunk_idx} ---\n{chunk_response}\n" # Clean up memory torch.cuda.empty_cache() gc.collect() # Generate final summary summary = summarize_findings(combined_response) report_path = os.path.join(report_dir, f"{file_hash_value}_report.txt") if file_hash_value else None if report_path: with open(report_path, "w", encoding="utf-8") as f: f.write(combined_response + "\n\n" + summary) yield { "chatbot": history, "download_output": gr.File(report_path) if report_path and os.path.exists(report_path) else None, "final_summary": summary, "progress_text": gr.Textbox(visible=False) } except Exception as e: logger.error("Analysis error: %s", e) history.append({"role": "assistant", "content": f"❌ Error occurred: {str(e)}"}) yield { "chatbot": history, "download_output": None, "final_summary": f"Error occurred during analysis: {str(e)}", "progress_text": gr.Textbox(visible=False) } def clear_and_start(): return { "chatbot": [], "download_output": None, "final_summary": "", "msg_input": "", "file_upload": None, "progress_text": gr.Textbox(visible=False) } # Event handlers send_btn.click( analyze, inputs=[msg_input, chatbot, file_upload], outputs=[chatbot, download_output, final_summary, progress_text], show_progress="hidden" ) msg_input.submit( analyze, inputs=[msg_input, chatbot, file_upload], outputs=[chatbot, download_output, final_summary, progress_text], show_progress="hidden" ) demo.load( clear_and_start, outputs=[chatbot, download_output, final_summary, msg_input, file_upload, progress_text], queue=False ) return demo if __name__ == "__main__": try: logger.info("Launching app...") agent = init_agent() demo = create_ui(agent) demo.queue( api_open=False, max_size=20 ).launch( server_name="0.0.0.0", server_port=7860, show_error=True, allowed_paths=[report_dir], share=False ) except Exception as e: logger.error(f"Failed to launch app: {e}") raise finally: if torch.distributed.is_initialized(): torch.distributed.destroy_process_group()