import gc
import hashlib
import json
import logging
import os
import re
import shutil
import subprocess
import sys
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Any, Dict, Generator, List, Optional

# ==================== CONFIGURATION ====================
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Setup directories
PERSISTENT_DIR = "/data/hf_cache"
DIRECTORIES = {
    "models": os.path.join(PERSISTENT_DIR, "txagent_models"),
    "tools": os.path.join(PERSISTENT_DIR, "tool_cache"),
    "cache": os.path.join(PERSISTENT_DIR, "cache"),
    "reports": os.path.join(PERSISTENT_DIR, "reports"),
    "vllm": os.path.join(PERSISTENT_DIR, "vllm_cache"),
}

# Create directories
for dir_path in DIRECTORIES.values():
    os.makedirs(dir_path, exist_ok=True)

# Environment variables.
# These are set BEFORE the third-party imports below on purpose: transformers /
# huggingface_hub (which gradio also imports) read HF_HOME and the cache-path
# variables when they are imported, so assigning them afterwards would not
# redirect the model cache to DIRECTORIES["models"].
os.environ.update({
    "HF_HOME": DIRECTORIES["models"],
    "TRANSFORMERS_CACHE": DIRECTORIES["models"],
    "VLLM_CACHE_DIR": DIRECTORIES["vllm"],
    "TOKENIZERS_PARALLELISM": "false",
    "CUDA_LAUNCH_BLOCKING": "1",
})

# Third-party imports (intentionally placed after the environment setup above).
import gradio as gr  # noqa: E402
import pandas as pd  # noqa: E402
import pdfplumber  # noqa: E402
import psutil  # noqa: E402
import torch  # noqa: E402
from diskcache import Cache  # noqa: E402
from transformers import AutoTokenizer  # noqa: E402

# Make the bundled ``src`` directory importable so ``txagent`` can be found.
current_dir = os.path.dirname(os.path.abspath(__file__))
src_path = os.path.abspath(os.path.join(current_dir, "src"))
sys.path.insert(0, src_path)

from txagent.txagent import TxAgent  # noqa: E402


# ==================== UTILITY FUNCTIONS ====================
def sanitize_text(text: str) -> str:
    """Return ``text`` with any characters that cannot be encoded as UTF-8 dropped."""
    return text.encode("utf-8", "ignore").decode("utf-8")


def get_file_hash(file_path: str, chunk_size: int = 1024 * 1024) -> str:
    """Return the hex MD5 digest of a file's contents.

    The file is read in ``chunk_size`` byte blocks so large uploads are not
    loaded into memory at once; the digest is identical to hashing the whole
    file in one call.
    """
    digest = hashlib.md5()
    with open(file_path, "rb") as f:
        for block in iter(lambda: f.read(chunk_size), b""):
            digest.update(block)
    return digest.hexdigest()


def log_system_resources(tag: str = "") -> None:
    """Log CPU/RAM usage and, when ``nvidia-smi`` is available, per-GPU usage.

    Best-effort: any failure is logged and swallowed, never raised.
    """
    try:
        cpu = psutil.cpu_percent(interval=1)
        mem = psutil.virtual_memory()
        logger.info(f"[{tag}] CPU: {cpu:.1f}% | RAM: {mem.used//(1024**2)}MB/{mem.total//(1024**2)}MB")

        try:
            gpu_info = subprocess.run(
                [
                    "nvidia-smi",
                    "--query-gpu=memory.used,memory.total,utilization.gpu",
                    "--format=csv,nounits,noheader",
                ],
                capture_output=True,
                text=True,
            )
        except FileNotFoundError:
            # No NVIDIA driver tooling on this host (e.g. CPU-only); not an error.
            logger.debug(f"[{tag}] nvidia-smi not found; skipping GPU stats")
            return

        if gpu_info.returncode == 0:
            # nvidia-smi prints one CSV line per GPU, so parse line by line.
            gpu_lines = [line for line in gpu_info.stdout.strip().splitlines() if line.strip()]
            for index, line in enumerate(gpu_lines):
                fields = [field.strip() for field in line.split(",")]
                if len(fields) != 3:
                    logger.warning(f"[{tag}] Unexpected nvidia-smi output: {line!r}")
                    continue
                used, total, util = fields
                label = "GPU" if len(gpu_lines) == 1 else f"GPU{index}"
                logger.info(f"[{tag}] {label}: {used}MB/{total}MB | Util: {util}%")
    except Exception as e:
        logger.error(f"[{tag}] Resource monitoring failed: {e}")


# ==================== FILE PROCESSING ====================
class FileProcessor:
    """Helpers that turn uploaded PDF / Excel / CSV files into JSON-serialisable records."""

    @staticmethod
    def extract_pdf_text(file_path: str) -> str:
        """Extract the text of every page of a PDF, processing page batches in parallel.

        Returns the page texts in document order, joined by blank lines, each
        prefixed with ``=== Page N ===`` (1-based). On failure the error is
        logged and an error string is returned instead of raising.
        """
        try:
            with pdfplumber.open(file_path) as pdf:
                total_pages = len(pdf.pages)
            if not total_pages:
                return ""

            def process_page_range(start: int, end: int) -> List[tuple]:
                """Return ``(page_index, formatted_text)`` for pages ``start``..``end - 1``."""
                results = []
                # Each worker opens its own handle so no PDF object is shared between threads.
                with pdfplumber.open(file_path) as pdf:
                    # enumerate(..., start=start) yields the absolute page index directly.
                    for page_num, page in enumerate(pdf.pages[start:end], start=start):
                        text = page.extract_text() or ""
                        results.append((page_num, f"=== Page {page_num + 1} ===\n{text.strip()}"))
                return results

            batch_size = 10
            batches = [(i, min(i + batch_size, total_pages)) for i in range(0, total_pages, batch_size)]
            # Indexed by page number so the output order does not depend on completion order.
            text_chunks = [""] * total_pages

            with ThreadPoolExecutor(max_workers=6) as executor:
                futures = [executor.submit(process_page_range, start, end) for start, end in batches]
                for future in as_completed(futures):
                    for page_num, text in future.result():
                        text_chunks[page_num] = text

            return "\n\n".join(filter(None, text_chunks))
        except Exception as e:
            logger.error(f"PDF processing error: {e}")
            return f"PDF processing error: {str(e)}"

    @staticmethod
    def excel_to_data(file_path: str) -> List[Dict]:
        """Read an Excel sheet as strings and wrap its rows in a single record.

        Missing cells become ``""``. On failure a one-element list holding an
        ``{"error": ...}`` dict is returned.
        """
        try:
            df = pd.read_excel(file_path, engine="openpyxl", header=None, dtype=str)
            content = df.where(pd.notnull(df), "").astype(str).values.tolist()
            return [{"filename": os.path.basename(file_path), "rows": content, "type": "excel"}]
        except Exception as e:
            logger.error(f"Excel processing error: {e}")
            return [{"error": f"Excel processing error: {str(e)}"}]

    @staticmethod
    def csv_to_data(file_path: str) -> List[Dict]:
        """Read a CSV file as strings and wrap its rows in a single record.

        Undecodable bytes are replaced and malformed lines skipped; missing
        cells become ``""``. On failure a one-element list holding an
        ``{"error": ...}`` dict is returned.
        """
        try:
            chunks = []
            for chunk in pd.read_csv(
                file_path,
                header=None,
                dtype=str,
                encoding_errors="replace",
                on_bad_lines="skip",
                chunksize=10000,
            ):
                chunks.append(chunk)
            df = pd.concat(chunks) if chunks else pd.DataFrame()
            content = df.where(pd.notnull(df), "").astype(str).values.tolist()
            return [{"filename": os.path.basename(file_path), "rows": content, "type": "csv"}]
        except Exception as e:
            logger.error(f"CSV processing error: {e}")
            return [{"error": f"CSV processing error: {str(e)}"}]

    @classmethod
    def process_file(cls, file_path: str, file_type: str) -> List[Dict]:
        """Dispatch ``file_path`` to the extractor for ``file_type`` (pdf/xls/xlsx/csv).

        Always returns a list of record dicts; unsupported types and failures
        are reported as ``{"error": ...}`` entries rather than raised.
        """
        processors = {
            "pdf": cls.extract_pdf_text,
            "xls": cls.excel_to_data,
            "xlsx": cls.excel_to_data,
            "csv": cls.csv_to_data,
        }
        if file_type not in processors:
            return [{"error": f"Unsupported file type: {file_type}"}]

        try:
            result = processors[file_type](file_path)
            if file_type == "pdf":
                # extract_pdf_text returns plain text; wrap it in the common record shape.
                return [{
                    "filename": os.path.basename(file_path),
                    "content": result,
                    "status": "initial",
                    "type": "pdf",
                }]
            return result
        except Exception as e:
            logger.error(f"Error processing {file_type} file: {e}")
            return [{"error": f"Error processing file: {str(e)}"}]


# ==================== TEXT PROCESSING ====================
class TextProcessor:
    """Tokenizer-based chunking plus parsing and summarizing of model responses."""

    # Bracketed text, the literal word "None" and known boilerplate phrases,
    # removed from responses before parsing.
    _NOISE_PATTERN = re.compile(
        r"\[.*?\]|\bNone\b|To analyze the patient record excerpt.*?medications\."
        r"|Since the previous attempts.*?\.|I need to.*?medications\."
        r"|Retrieving tools.*?\.",
        flags=re.DOTALL,
    )
    # Heading that starts the section whose bullet items are collected.
    _MISSED_HEADING = re.compile(r"###\s*Missed Diagnoses")
    # Headings that end the "Missed Diagnoses" section.
    _OTHER_HEADING = re.compile(r"###\s*(Medication Conflicts|Incomplete Assessments|Urgent Follow-up)")
    _BULLET = re.compile(r"-\s*.+")
    _NO_ISSUES_BULLET = re.compile(r"No issues identified", re.IGNORECASE)
    # Splitting on "--- Analysis for Chunk" leaves each section starting with this "N ---" tail.
    _CHUNK_HEADER_TAIL = re.compile(r"\d+\s*---")
    # Free-text statements that a chunk had nothing to report.
    _NO_FINDINGS = re.compile(r"No (missed diagnoses|issues)( were)? identified", re.IGNORECASE)

    def __init__(self):
        self.tokenizer = AutoTokenizer.from_pretrained("mims-harvard/TxAgent-T1-Llama-3.1-8B")
        self.cache = Cache(DIRECTORIES["cache"], size_limit=10 * 1024**3)

    def chunk_text(self, text: str, max_tokens: int = 1800) -> List[str]:
        """Split ``text`` into consecutive pieces of at most ``max_tokens`` tokens.

        Special tokens (e.g. a BOS marker) are not added while encoding, so the
        decoded chunks contain only the source text; empty text yields ``[]``.

        Raises:
            ValueError: if ``max_tokens`` is not positive.
        """
        if max_tokens <= 0:
            raise ValueError(f"max_tokens must be positive, got {max_tokens}")
        tokens = self.tokenizer.encode(text, add_special_tokens=False)
        return [
            self.tokenizer.decode(tokens[i:i + max_tokens])
            for i in range(0, len(tokens), max_tokens)
        ]

    @classmethod
    def _extract_missed_diagnoses(cls, text: str) -> Optional[List[str]]:
        """Collect bullet items listed under ``### Missed Diagnoses`` headings.

        Returns ``None`` when ``text`` contains no such heading; otherwise the
        (possibly empty) list of bullet texts, excluding "No issues identified".
        """
        found_section = False
        in_diagnoses = False
        diagnoses: List[str] = []
        for raw_line in text.splitlines():
            line = raw_line.strip()
            if not line:
                continue
            if cls._MISSED_HEADING.match(line):
                in_diagnoses = True
                found_section = True
                continue
            if cls._OTHER_HEADING.match(line):
                in_diagnoses = False
                continue
            if in_diagnoses and cls._BULLET.match(line):
                diagnosis = re.sub(r"^\-\s*", "", line).strip()
                if diagnosis and not cls._NO_ISSUES_BULLET.match(diagnosis):
                    diagnoses.append(diagnosis)
        return diagnoses if found_section else None

    def clean_response(self, text: str) -> str:
        """Strip noise from a model response and return its missed diagnoses.

        Returns the bullet items found under ``### Missed Diagnoses`` joined by
        single spaces, or ``""`` when there are none.
        """
        text = sanitize_text(text)
        text = self._NOISE_PATTERN.sub("", text)
        return " ".join(self._extract_missed_diagnoses(text) or [])

    @classmethod
    def _unstructured_finding(cls, section: str) -> str:
        """Return the free-text body of one chunk analysis, or ``""`` if it reports nothing."""
        lines = [line.strip() for line in section.splitlines() if line.strip()]
        if lines and cls._CHUNK_HEADER_TAIL.fullmatch(lines[0]):
            lines = lines[1:]
        body = " ".join(lines)
        if not body or cls._NO_FINDINGS.match(body):
            return ""
        return body

    def summarize_results(self, analysis: str) -> str:
        """Build a one-sentence summary of the missed diagnoses in ``analysis``.

        ``analysis`` is a sequence of ``--- Analysis for Chunk N ---`` sections.
        A section containing a ``### Missed Diagnoses`` heading contributes its
        bullet items. Any other non-empty section (e.g. the space-joined text that
        ``clean_response`` produces) contributes its body as one finding.
        Duplicates are removed, keeping first-seen order.
        """
        diagnoses: List[str] = []
        for section in analysis.split("--- Analysis for Chunk"):
            section = section.strip()
            if not section or "No oversights identified" in section:
                continue
            structured = self._extract_missed_diagnoses(section)
            if structured is not None:
                diagnoses.extend(structured)
            else:
                finding = self._unstructured_finding(section)
                if finding:
                    diagnoses.append(finding)

        unique_diagnoses = list(dict.fromkeys(diagnoses))  # de-duplicate, order preserved
        if not unique_diagnoses:
            return "No missed diagnoses were identified in the provided records."

        if len(unique_diagnoses) > 1:
            summary = "Missed diagnoses include " + ", ".join(unique_diagnoses[:-1])
            summary += f", and {unique_diagnoses[-1]}"
        else:
            summary = "Missed diagnoses include " + unique_diagnoses[0]
        return summary + ", all requiring urgent clinical review."
# ==================== CORE APPLICATION ====================
class ClinicalOversightApp:
    """Gradio app that extracts uploaded patient records and asks TxAgent for missed diagnoses."""

    # Maximum tokens of record text sent per agent call. process_response_stream
    # passes 512 and 2048 to run_gradio_chat (presumably max new tokens / max total
    # tokens — confirm against TxAgent), so each excerpt must stay small. Chunks are
    # sized in tokens up front rather than truncated by characters afterwards, so no
    # record text is silently dropped. 450 tokens is assumed (~4 chars/token) to be
    # roughly the size of an 1800-character excerpt.
    CHUNK_MAX_TOKENS = 450

    def __init__(self):
        self.agent = self._initialize_agent()
        self.text_processor = TextProcessor()
        self.file_processor = FileProcessor()

    def _initialize_agent(self):
        """Create TxAgent with the cached tool file, load its model and return it."""
        logger.info("Initializing AI model...")
        log_system_resources("Before Load")

        tool_path = os.path.join(DIRECTORIES["tools"], "new_tool.json")
        if not os.path.exists(tool_path):
            # Seed the persistent tool cache from the bundled default
            # (resolved relative to the current working directory).
            default_tools = os.path.abspath("data/new_tool.json")
            shutil.copy(default_tools, tool_path)

        agent = TxAgent(
            model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
            rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
            tool_files_dict={"new_tool": tool_path},
            force_finish=True,
            enable_checker=False,
            step_rag_num=4,
            seed=100,
            additional_default_tools=[],
        )
        agent.init_model()

        log_system_resources("After Load")
        logger.info("AI Agent Ready")
        return agent

    def process_response_stream(self, prompt: str, history: List[dict]) -> Generator[dict, None, None]:
        """Run the agent on ``prompt`` and stream the cleaned, accumulated answer.

        Yields ``{"role": "assistant", "content": <text so far>}`` each time new
        cleaned text is appended. ``history`` is accepted for API compatibility
        but unused; the agent is always started with an empty history.
        """
        full_response = ""
        seen_segments = set()
        for chunk in self.agent.run_gradio_chat(prompt, [], 0.2, 512, 2048, False, []):
            if not chunk:
                continue
            if isinstance(chunk, list):
                raw_texts = [message.content for message in chunk if getattr(message, "content", None)]
            elif isinstance(chunk, str) and chunk.strip():
                raw_texts = [chunk]
            else:
                continue

            for raw_text in raw_texts:
                cleaned = self.text_processor.clean_response(raw_text)
                # If the agent re-yields messages it already produced (e.g. each list
                # is a cumulative history snapshot — TODO confirm against TxAgent),
                # appending them again would duplicate the answer; skip exact repeats.
                if not cleaned or cleaned in seen_segments:
                    continue
                seen_segments.add(cleaned)
                full_response += cleaned + " "
                yield {"role": "assistant", "content": full_response}

    @staticmethod
    def _file_path(file_obj: Any) -> str:
        """Return the filesystem path of an uploaded file (a path string or an object with ``.name``)."""
        if isinstance(file_obj, (str, os.PathLike)):
            return os.fspath(file_obj)
        return file_obj.name

    @staticmethod
    def _files_hash(paths: List[str]) -> str:
        """Return an MD5 identifying all uploaded files (a single file's own hash when only one)."""
        if len(paths) == 1:
            return get_file_hash(paths[0])
        combined = "".join(get_file_hash(path) for path in paths)
        return hashlib.md5(combined.encode("utf-8")).hexdigest()

    def analyze(self, message: str, history: List[dict], files: List) -> Generator[Dict[str, Any], None, None]:
        """Run the full pipeline: extract uploads, analyze them chunk by chunk, summarize.

        Yields a state dict after every visible step with keys ``chatbot`` (message
        list), ``download_output`` (report path or None), ``final_summary`` (str)
        and ``progress_text`` (dict of Textbox properties). ``_analyze_for_ui``
        converts these dicts into the positional values Gradio expects.
        Errors are reported through the yielded state rather than raised.
        """
        history = list(history or [])
        outputs = {
            "chatbot": list(history),
            "download_output": None,
            "final_summary": "",
            "progress_text": {"value": "Starting analysis...", "visible": True},
        }
        yield outputs

        try:
            history.append({"role": "user", "content": message})
            outputs["chatbot"] = history
            yield outputs

            # ---- Extract uploaded files in parallel, keeping results in upload order ----
            file_paths = [self._file_path(f) for f in (files or [])]
            extracted: List[Dict] = []
            file_hash_value = ""
            if file_paths:
                per_file: List[List[Dict]] = [[] for _ in file_paths]
                with ThreadPoolExecutor(max_workers=4) as executor:
                    future_to_index = {
                        executor.submit(
                            self.file_processor.process_file,
                            path,
                            os.path.splitext(path)[1].lstrip(".").lower(),
                        ): index
                        for index, path in enumerate(file_paths)
                    }
                    for done, future in enumerate(as_completed(future_to_index), 1):
                        index = future_to_index[future]
                        try:
                            per_file[index] = future.result()
                        except Exception as e:
                            logger.error(f"File processing error: {e}")
                            per_file[index] = [{"error": f"Error processing file: {str(e)}"}]
                        outputs["progress_text"] = self._update_progress(done, len(file_paths), "Processing files")
                        yield outputs
                for items in per_file:
                    extracted.extend(items)

                file_hash_value = self._files_hash(file_paths)
                history.append({"role": "assistant", "content": "✅ File processing complete"})
                outputs.update({
                    "chatbot": history,
                    "progress_text": self._update_progress(len(file_paths), len(file_paths), "Files processed"),
                })
                yield outputs

            # ---- Chunk the extracted content ----
            # ensure_ascii=False keeps non-ASCII text as-is instead of \uXXXX escapes.
            text_content = "\n".join(json.dumps(item, ensure_ascii=False) for item in extracted)
            chunks = (
                self.text_processor.chunk_text(text_content, max_tokens=self.CHUNK_MAX_TOKENS)
                if text_content.strip()
                else []
            )
            if not chunks:
                notice = "⚠️ No patient record content to analyze. Please upload a PDF, CSV, or Excel file."
                history.append({"role": "assistant", "content": notice})
                outputs.update({
                    "chatbot": history,
                    "final_summary": notice,
                    "progress_text": {"visible": False},
                })
                yield outputs
                return

            # ---- Analyze each chunk with the agent ----
            combined_response = ""
            for chunk_idx, chunk in enumerate(chunks, 1):
                # The requested format must match what TextProcessor.clean_response and
                # summarize_results parse: bullet lines under a "### Missed Diagnoses"
                # heading. A heading-free paragraph would be discarded by that parser.
                prompt = (
                    "Analyze this patient record for missed diagnoses. "
                    "Respond using exactly this format:\n\n"
                    "### Missed Diagnoses\n"
                    "- <one missed diagnosis per bullet, citing the specific clinical findings, "
                    "their potential implications, and the urgent review recommended>\n\n"
                    "If no missed diagnoses are found, write '- No issues identified' under the heading.\n\n"
                    f"Patient Record (Chunk {chunk_idx}/{len(chunks)}):\n"
                    f"{chunk}\n"
                )

                history.append({"role": "assistant", "content": ""})
                outputs.update({
                    "chatbot": history,
                    "progress_text": self._update_progress(chunk_idx, len(chunks), "Analyzing"),
                })
                yield outputs

                chunk_response = ""
                for update in self.process_response_stream(prompt, history):
                    history[-1] = update
                    chunk_response = update["content"]
                    outputs.update({
                        "chatbot": history,
                        "progress_text": self._update_progress(chunk_idx, len(chunks), "Analyzing"),
                    })
                    yield outputs

                combined_response += f"--- Analysis for Chunk {chunk_idx} ---\n{chunk_response}\n"
                # Release cached GPU memory and Python garbage between agent calls.
                torch.cuda.empty_cache()
                gc.collect()

            # ---- Final summary and downloadable report ----
            summary = self.text_processor.summarize_results(combined_response)
            report_path = (
                os.path.join(DIRECTORIES["reports"], f"{file_hash_value}_report.txt")
                if file_hash_value
                else None
            )
            if report_path:
                with open(report_path, "w", encoding="utf-8") as f:
                    f.write(combined_response + "\n\n" + summary)

            outputs.update({
                "download_output": report_path,
                "final_summary": summary,
                "progress_text": {"visible": False},
            })
            yield outputs

        except Exception as e:
            logger.exception(f"Analysis error: {e}")
            history.append({"role": "assistant", "content": f"❌ Error: {str(e)}"})
            outputs.update({
                "chatbot": history,
                "final_summary": f"Error occurred: {str(e)}",
                "progress_text": {"visible": False},
            })
            yield outputs

    def _update_progress(self, current: int, total: int, stage: str = "") -> Dict[str, Any]:
        """Return Textbox properties describing progress ``current``/``total`` of ``stage``."""
        progress = f"{stage} - {current}/{total}" if stage else f"{current}/{total}"
        return {"value": progress, "visible": True, "label": f"Progress: {progress}"}

    def _to_ui_outputs(self, outputs: Dict[str, Any]) -> tuple:
        """Convert an ``analyze`` state dict into positional values for the Gradio outputs.

        Order matches ``[chatbot, download_output, final_summary, progress_text]``.
        Gradio only maps dict returns keyed by component objects, so string-keyed
        dicts must be converted. Property dicts become ``gr.update`` objects, and
        the report File is shown only when a report path exists.
        """
        progress = outputs.get("progress_text")
        progress_update = gr.update(**progress) if isinstance(progress, dict) else progress
        report_path = outputs.get("download_output")
        download_update = gr.update(value=report_path, visible=report_path is not None)
        return (
            outputs.get("chatbot"),
            download_update,
            outputs.get("final_summary"),
            progress_update,
        )

    def _analyze_for_ui(self, message: str, history: List[dict], files: List):
        """Gradio event handler: stream ``analyze`` results as positional output tuples."""
        for outputs in self.analyze(message, history, files):
            yield self._to_ui_outputs(outputs)

    def create_interface(self):
        """Build the Gradio Blocks UI and wire it to the analysis pipeline."""
        with gr.Blocks(
            theme=gr.themes.Soft(
                primary_hue="indigo",
                secondary_hue="blue",
                neutral_hue="slate",
            ),
            title="Clinical Oversight Assistant",
            css="""
            .diagnosis-summary {
                border-left: 4px solid #4f46e5;
                padding: 12px;
                background: #f8fafc;
                border-radius: 4px;
            }
            .file-upload {
                border: 2px dashed #cbd5e1;
                border-radius: 8px;
                padding: 20px;
            }
            """,
        ) as app:
            # Header Section
            gr.Markdown(
                "\n\n🩺 Clinical Oversight Assistant\n\n"
                "AI-powered analysis of patient records for potential missed diagnoses\n\n"
            )

            with gr.Row(equal_height=False):
                # Main Chat Column
                with gr.Column(scale=3):
                    chatbot = gr.Chatbot(
                        label="Clinical Analysis",
                        height=600,
                        show_copy_button=True,
                        avatar_images=(
                            "assets/user.png",
                            "assets/assistant.png",
                        ) if os.path.exists("assets/user.png") else None,
                        bubble_full_width=False,
                        type="messages",
                        elem_classes=["chat-container"],
                    )

                # Results Column
                with gr.Column(scale=1):
                    with gr.Group():
                        gr.Markdown("### 📝 Summary of Findings")
                        final_summary = gr.Markdown(
                            "Analysis results will appear here...",
                            elem_classes=["diagnosis-summary"],
                        )

                    with gr.Group():
                        gr.Markdown("### 📂 Report Download")
                        download_output = gr.File(
                            label="Full Report",
                            visible=False,
                            interactive=False,
                        )

            # Input Section
            with gr.Row():
                file_upload = gr.File(
                    file_types=[".pdf", ".csv", ".xls", ".xlsx"],
                    file_count="multiple",
                    label="Upload Patient Records",
                    elem_classes=["file-upload"],
                )

            # Interaction Section
            with gr.Row():
                msg_input = gr.Textbox(
                    placeholder="Ask about potential oversights or upload files...",
                    show_label=False,
                    container=False,
                    scale=7,
                    autofocus=True,
                )
                send_btn = gr.Button(
                    "Analyze",
                    variant="primary",
                    scale=1,
                    min_width=100,
                )

            # Progress Indicator
            progress_text = gr.Textbox(
                label="Progress Status",
                visible=False,
                interactive=False,
            )

            # Event Handlers: output order must match _to_ui_outputs.
            analysis_outputs = [chatbot, download_output, final_summary, progress_text]
            send_btn.click(
                self._analyze_for_ui,
                inputs=[msg_input, chatbot, file_upload],
                outputs=analysis_outputs,
                show_progress="hidden",
            )

            msg_input.submit(
                self._analyze_for_ui,
                inputs=[msg_input, chatbot, file_upload],
                outputs=analysis_outputs,
                show_progress="hidden",
            )

            app.load(
                lambda: [
                    [],
                    None,
                    "",
                    "",
                    None,
                    gr.update(value="", visible=False),
                ],
                outputs=[chatbot, download_output, final_summary, msg_input, file_upload, progress_text],
                queue=False,
            )

        return app


# ==================== APPLICATION ENTRY POINT ====================
if __name__ == "__main__":
    try:
        logger.info("Starting Clinical Oversight Assistant...")
        app = ClinicalOversightApp()
        interface = app.create_interface()

        interface.queue(
            api_open=False,
            max_size=20,
        ).launch(
            server_name="0.0.0.0",
            server_port=7860,
            show_error=True,
            allowed_paths=[DIRECTORIES["reports"]],
            share=False,
        )
    except Exception as e:
        logger.exception(f"Application failed to start: {e}")
        raise
    finally:
        # torch.distributed can be absent from some torch builds; check before querying it.
        if torch.distributed.is_available() and torch.distributed.is_initialized():
            torch.distributed.destroy_process_group()