import sys
import os
import pandas as pd
import pdfplumber
import json
import gradio as gr
from typing import List, Dict, Generator, Any, Optional
from concurrent.futures import ThreadPoolExecutor, as_completed
import hashlib
import shutil
import re
import psutil
import subprocess
import logging
import torch
import gc
from diskcache import Cache
from transformers import AutoTokenizer
from pathlib import Path

# ==================== CONFIGURATION ====================
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Directory Setup
BASE_DIR = Path("/data/hf_cache")
DIRECTORIES = {
    "models": BASE_DIR / "txagent_models",
    "tools": BASE_DIR / "tool_cache",
    "cache": BASE_DIR / "cache",
    "reports": BASE_DIR / "reports",
    "vllm": BASE_DIR / "vllm_cache"
}
for dir_path in DIRECTORIES.values():
    dir_path.mkdir(parents=True, exist_ok=True)

# Environment Configuration
os.environ.update({
    "HF_HOME": str(DIRECTORIES["models"]),
    "TRANSFORMERS_CACHE": str(DIRECTORIES["models"]),
    "VLLM_CACHE_DIR": str(DIRECTORIES["vllm"]),
    "TOKENIZERS_PARALLELISM": "false",
    "CUDA_LAUNCH_BLOCKING": "1"
})

# Add src path for txagent
current_dir = os.path.dirname(os.path.abspath(__file__))
src_path = os.path.abspath(os.path.join(current_dir, "src"))
sys.path.insert(0, src_path)
from txagent.txagent import TxAgent

# ==================== CORE COMPONENTS ====================
class FileProcessor:
    """Handles all file processing operations"""

    @staticmethod
    def extract_pdf_content(file_path: str) -> str:
        """Extract text from PDF with parallel processing"""
        try:
            with pdfplumber.open(file_path) as pdf:
                total_pages = len(pdf.pages)
                if not total_pages:
                    return ""

            def process_batch(start: int, end: int) -> List[tuple]:
                # Each worker opens its own handle; pdfplumber objects are not thread-safe.
                results = []
                with pdfplumber.open(file_path) as pdf:
                    # enumerate keeps the absolute page index without a second lookup
                    for page_num, page in enumerate(pdf.pages[start:end], start):
                        text = page.extract_text() or ""
                        results.append((page_num, f"=== Page {page_num + 1} ===\n{text.strip()}"))
                return results

            batch_size = min(10, total_pages)
            batches = [(i, min(i + batch_size, total_pages)) for i in range(0, total_pages, batch_size)]
            text_chunks = [""] * total_pages

            with ThreadPoolExecutor(max_workers=min(6, os.cpu_count() or 4)) as executor:
                futures = [executor.submit(process_batch, start, end) for start, end in batches]
                for future in as_completed(futures):
                    for page_num, text in future.result():
                        text_chunks[page_num] = text

            return "\n\n".join(filter(None, text_chunks))
        except Exception as e:
            logger.error(f"PDF extraction failed: {e}")
            return f"PDF processing error: {str(e)}"

    @staticmethod
    def process_tabular_data(file_path: str, file_type: str) -> List[Dict]:
        """Process Excel or CSV files"""
        try:
            if file_type == "csv":
                chunks = pd.read_csv(
                    file_path,
                    header=None,
                    dtype=str,
                    encoding_errors='replace',
                    on_bad_lines='skip',
                    chunksize=10000
                )
                # Materialise the chunk iterator so an empty file does not break pd.concat
                frames = list(chunks)
                df = pd.concat(frames) if frames else pd.DataFrame()
            else:  # Excel
                try:
                    df = pd.read_excel(file_path, engine='openpyxl', header=None, dtype=str)
                except Exception:
                    df = pd.read_excel(file_path, engine='xlrd', header=None, dtype=str)

            return [{
                "filename": os.path.basename(file_path),
                "rows": df.where(pd.notnull(df), "").astype(str).values.tolist(),
                "type": file_type
            }]
        except Exception as e:
            logger.error(f"{file_type.upper()} processing failed: {e}")
            return [{"error": f"{file_type.upper()} processing error: {str(e)}"}]

    @classmethod
    def handle_upload(cls, file_path: str, file_type: str) -> List[Dict]:
        """Route file processing based on type"""
        processor_map = {
            "pdf": cls.extract_pdf_content,
            "xls": lambda x: cls.process_tabular_data(x, "excel"),
            "xlsx": lambda x: cls.process_tabular_data(x, "excel"),
            "csv": lambda x: cls.process_tabular_data(x, "csv")
        }

        if file_type not in processor_map:
            return [{"error": f"Unsupported file type: {file_type}"}]

        try:
            result = processor_map[file_type](file_path)
            if file_type == "pdf":
                return [{
                    "filename": os.path.basename(file_path),
                    "content": result,
                    "type": "pdf"
                }]
            return result
        except Exception as e:
            logger.error(f"File processing failed: {e}")
            return [{"error": f"File processing error: {str(e)}"}]


class TextAnalyzer:
    """Handles text processing and analysis"""

    def __init__(self):
        self.tokenizer = AutoTokenizer.from_pretrained("mims-harvard/TxAgent-T1-Llama-3.1-8B")
        self.cache = Cache(DIRECTORIES["cache"], size_limit=10 * 1024**3)

    def chunk_content(self, text: str, max_tokens: int = 1800) -> List[str]:
        """Split text into token-limited chunks"""
        tokens = self.tokenizer.encode(text)
        return [
            self.tokenizer.decode(tokens[i:i + max_tokens])
            for i in range(0, len(tokens), max_tokens)
        ]

    def clean_output(self, text: str) -> str:
        """Clean and format model response"""
        text = text.encode("utf-8", "ignore").decode("utf-8")
        text = re.sub(
            r"\[.*?\]|\bNone\b|To analyze the patient record excerpt.*?medications\."
            r"|Since the previous attempts.*?\.|I need to.*?medications\."
            r"|Retrieving tools.*?\.",
            "", text, flags=re.DOTALL
        )

        diagnoses = []
        in_section = False
        for line in text.splitlines():
            line = line.strip()
            if not line:
                continue
            if re.match(r"###\s*Missed Diagnoses", line):
                in_section = True
                continue
            if re.match(r"###\s*(Medication Conflicts|Incomplete Assessments|Urgent Follow-up)", line):
                in_section = False
                continue
            if in_section and re.match(r"-\s*.+", line):
                diagnosis = re.sub(r"^\-\s*", "", line).strip()
                if diagnosis and not re.match(r"No issues identified", diagnosis, re.IGNORECASE):
                    diagnoses.append(diagnosis)

        return " ".join(diagnoses) if diagnoses else ""

    def generate_summary(self, analysis: str) -> str:
        """Create concise clinical summary"""
        findings = []
        for chunk in analysis.split("--- Analysis for Chunk"):
            chunk = chunk.strip()
            if not chunk or "No oversights identified" in chunk:
                continue
            in_section = False
            for line in chunk.splitlines():
                line = line.strip()
                if not line:
                    continue
                if re.match(r"###\s*Missed Diagnoses", line):
                    in_section = True
                    continue
                if re.match(r"###\s*(Medication Conflicts|Incomplete Assessments|Urgent Follow-up)", line):
                    in_section = False
                    continue
                if in_section and re.match(r"-\s*.+", line):
                    finding = re.sub(r"^\-\s*", "", line).strip()
                    if finding and not re.match(r"No issues identified", finding, re.IGNORECASE):
                        findings.append(finding)

        unique_findings = list(dict.fromkeys(findings))
        if not unique_findings:
            return "No clinical concerns identified in the provided records."

        if len(unique_findings) > 1:
            summary = "Potential concerns include: " + ", ".join(unique_findings[:-1])
            summary += f", and {unique_findings[-1]}"
        else:
            summary = "Potential concern identified: " + unique_findings[0]

        return summary + ". Recommend urgent clinical review."
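

# Illustrative sketch (added for clarity, not part of the original application): one way the two
# helpers above might be combined outside the Gradio workflow, e.g. for a quick offline check that
# a document extracts and chunks cleanly. The file name "sample_records.pdf" and the helper name
# are hypothetical placeholders.
def _example_offline_chunking(path: str = "sample_records.pdf") -> List[str]:
    """Extract a document with FileProcessor and return token-limited chunks."""
    suffix = Path(path).suffix[1:].lower()      # route by extension, as handle_upload expects
    records = FileProcessor.handle_upload(path, suffix)
    text = "\n".join(json.dumps(item) for item in records)
    return TextAnalyzer().chunk_content(text)   # loads the TxAgent tokenizer on first use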
class ClinicalAgent:
    """Main application controller"""

    def __init__(self):
        self.agent = self._init_agent()
        self.file_processor = FileProcessor()
        self.text_analyzer = TextAnalyzer()

    def _init_agent(self) -> Any:
        """Initialize the AI agent"""
        logger.info("Initializing clinical agent...")
        self._log_system_status("pre-init")

        # Seed the tool cache with the bundled tool definitions if none exist yet
        tool_path = DIRECTORIES["tools"] / "new_tool.json"
        if not tool_path.exists():
            default_tools = Path("data/new_tool.json")
            if default_tools.exists():
                shutil.copy(default_tools, tool_path)

        agent = TxAgent(
            model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
            rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
            tool_files_dict={"new_tool": str(tool_path)},
            force_finish=True,
            enable_checker=False,
            step_rag_num=4,
            seed=100,
            additional_default_tools=[],
        )
        agent.init_model()

        self._log_system_status("post-init")
        logger.info("Clinical agent ready")
        return agent

    def _log_system_status(self, phase: str) -> None:
        """Log system resource utilization"""
        try:
            cpu = psutil.cpu_percent(interval=1)
            mem = psutil.virtual_memory()
            logger.info(f"[{phase}] CPU: {cpu:.1f}% | RAM: {mem.used//(1024**2)}MB/{mem.total//(1024**2)}MB")

            gpu_info = subprocess.run(
                ["nvidia-smi", "--query-gpu=memory.used,memory.total,utilization.gpu",
                 "--format=csv,nounits,noheader"],
                capture_output=True, text=True
            )
            if gpu_info.returncode == 0:
                used, total, util = gpu_info.stdout.strip().split(", ")
                logger.info(f"[{phase}] GPU: {used}MB/{total}MB | Util: {util}%")
        except Exception as e:
            logger.error(f"Resource monitoring failed: {e}")

    def process_stream(self, prompt: str, history: List[Dict]) -> Generator[Dict, None, None]:
        """Stream the agent's responses"""
        full_response = ""
        for chunk in self.agent.run_gradio_chat(prompt, [], 0.2, 512, 2048, False, []):
            if not chunk:
                continue
            if isinstance(chunk, list):
                for msg in chunk:
                    if hasattr(msg, 'content') and msg.content:
                        cleaned = self.text_analyzer.clean_output(msg.content)
                        if cleaned:
                            full_response += cleaned + " "
                            yield {"role": "assistant", "content": full_response}
            elif isinstance(chunk, str) and chunk.strip():
                cleaned = self.text_analyzer.clean_output(chunk)
                if cleaned:
                    full_response += cleaned + " "
                    yield {"role": "assistant", "content": full_response}

    def analyze_records(self, message: str, history: List[Dict], files: List) -> Generator[tuple, None, None]:
        """Main analysis workflow"""
        outputs = {
            "chatbot": history.copy(),
            "download_output": None,
            "final_summary": "",
            "progress": {"value": "Initializing...", "visible": True}
        }
        yield (outputs["chatbot"], outputs["download_output"], outputs["final_summary"], outputs["progress"])

        try:
            # Add user message
            history.append({"role": "user", "content": message})
            outputs["chatbot"] = history
            yield (outputs["chatbot"], outputs["download_output"], outputs["final_summary"], outputs["progress"])

            # Process files
            extracted = []
            file_hash = ""
            if files:
                with ThreadPoolExecutor(max_workers=4) as executor:
                    futures = []
                    for f in files:
                        file_type = Path(f.name).suffix[1:].lower()
                        futures.append(executor.submit(
                            self.file_processor.handle_upload, f.name, file_type
                        ))
                    for i, future in enumerate(as_completed(futures), 1):
                        try:
                            extracted.extend(future.result())
                            outputs["progress"] = self._format_progress(i, len(files), "Processing files")
                            yield (outputs["chatbot"], outputs["download_output"], outputs["final_summary"], outputs["progress"])
                        except Exception as e:
                            logger.error(f"File processing failed: {e}")
                            extracted.append({"error": str(e)})

                # Hash the first upload so the report gets a stable file name
                if files and os.path.exists(files[0].name):
                    with open(files[0].name, "rb") as fh:
                        file_hash = hashlib.md5(fh.read()).hexdigest()

                history.append({"role": "assistant", "content": "✅ Files processed successfully"})
                outputs.update({
                    "chatbot": history,
                    "progress": self._format_progress(len(files), len(files), "Files processed")
                })
                yield (outputs["chatbot"], outputs["download_output"], outputs["final_summary"], outputs["progress"])

            # Analyze content
            text_content = "\n".join(json.dumps(item) for item in extracted)
            chunks = self.text_analyzer.chunk_content(text_content)
            full_analysis = ""

            for idx, chunk in enumerate(chunks, 1):
                prompt = f"""
Analyze this clinical documentation for potential missed diagnoses. Provide:
1. Specific clinical findings with references (e.g., "Elevated BP (160/95) on page 3")
2. Their clinical significance
3. Urgency of review

Use concise, continuous prose without bullet points. If no concerns, state "No missed diagnoses identified."

Document Excerpt (Part {idx}/{len(chunks)}):
{chunk[:1750]}
"""
                history.append({"role": "assistant", "content": ""})
                outputs.update({
                    "chatbot": history,
                    "progress": self._format_progress(idx, len(chunks), "Analyzing")
                })
                yield (outputs["chatbot"], outputs["download_output"], outputs["final_summary"], outputs["progress"])

                # Stream analysis
                chunk_response = ""
                for update in self.process_stream(prompt, history):
                    history[-1] = update
                    chunk_response = update["content"]
                    outputs.update({
                        "chatbot": history,
                        "progress": self._format_progress(idx, len(chunks), "Analyzing")
                    })
                    yield (outputs["chatbot"], outputs["download_output"], outputs["final_summary"], outputs["progress"])

                # Delimiter matches the split key used by TextAnalyzer.generate_summary
                full_analysis += f"--- Analysis for Chunk {idx} ---\n{chunk_response}\n"
                torch.cuda.empty_cache()
                gc.collect()

            # Final outputs
            summary = self.text_analyzer.generate_summary(full_analysis)
            report_path = DIRECTORIES["reports"] / f"{file_hash}_report.txt" if file_hash else None
            if report_path:
                with open(report_path, "w", encoding="utf-8") as f:
                    f.write(full_analysis + "\n\nSUMMARY:\n" + summary)

            outputs.update({
                "download_output": str(report_path) if report_path and report_path.exists() else None,
                "final_summary": summary,
                "progress": {"visible": False}
            })
            yield (outputs["chatbot"], outputs["download_output"], outputs["final_summary"], outputs["progress"])

        except Exception as e:
            logger.error(f"Analysis failed: {e}")
            history.append({"role": "assistant", "content": f"❌ Analysis error: {str(e)}"})
            outputs.update({
                "chatbot": history,
                "final_summary": f"Error: {str(e)}",
                "progress": {"visible": False}
            })
            yield (outputs["chatbot"], outputs["download_output"], outputs["final_summary"], outputs["progress"])

    def _format_progress(self, current: int, total: int, stage: str = "") -> Dict[str, Any]:
        """Format progress update for UI"""
        status = f"{stage} - {current}/{total}" if stage else f"{current}/{total}"
        return {"value": status, "visible": True, "label": f"Progress: {status}"}

    def create_interface(self) -> gr.Blocks:
        """Build the Gradio interface"""
        css = """
        /* ==================== BASE STYLES ==================== */
        :root {
            --primary-color: #4f46e5; --primary-dark: #4338ca; --border-radius: 8px;
            --transition: all 0.3s ease; --shadow: 0 4px 12px rgba(0,0,0,0.1);
            --font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
            --background: #ffffff; --text-color: #1e293b; --chat-bg: #f8fafc;
            --message-bg: #e2e8f0; --panel-bg: rgba(248, 250, 252, 0.9);
            --panel-dark-bg: rgba(30, 41, 59, 0.9);
        }
        [data-theme="dark"] {
            --background: #1e2a44; --text-color: #f1f5f9; --chat-bg: #2d3b55;
            --message-bg: #475569; --panel-bg: var(--panel-dark-bg);
        }
        body, .gradio-container {
            font-family: var(--font-family); background: var(--background);
            color: var(--text-color); margin: 0; padding: 0; transition: var(--transition);
        }

        /* ==================== LAYOUT ==================== */
        .gradio-container {
            max-width: 1200px; margin: 0 auto; padding: 1.5rem;
            display: flex; flex-direction: column; gap: 1.5rem;
        }
        .chat-container {
            background: var(--chat-bg); border-radius: var(--border-radius);
            border: 1px solid #e2e8f0; padding: 1.5rem; min-height: 50vh; max-height: 80vh;
            overflow-y: auto; box-shadow: var(--shadow); margin-bottom: 4rem;
        }
        .summary-panel {
            background: var(--panel-bg); border-left: 4px solid var(--primary-color);
            padding: 1rem; border-radius: var(--border-radius); margin-bottom: 1rem;
            box-shadow: var(--shadow); backdrop-filter: blur(8px);
        }
        .upload-area {
            border: 2px dashed #cbd5e1; border-radius: var(--border-radius);
            padding: 1.5rem; margin: 0.75rem 0; transition: var(--transition);
        }
        .upload-area:hover { border-color: var(--primary-color); background: rgba(79, 70, 229, 0.05); }

        /* ==================== COMPONENTS ==================== */
        .chat__message {
            margin: 0.75rem 0; padding: 0.75rem 1rem; border-radius: var(--border-radius);
            max-width: 85%; transition: var(--transition); background: var(--message-bg);
            border: 1px solid rgba(0,0,0,0.05); animation: messageFade 0.3s ease;
        }
        .chat__message:hover { transform: translateY(-2px); box-shadow: 0 2px 4px rgba(0,0,0,0.1); }
        .chat__message.user {
            background: linear-gradient(135deg, var(--primary-color), var(--primary-dark));
            color: white; margin-left: auto;
        }
        .chat__message.assistant { background: var(--message-bg); color: var(--text-color); }
        .input-container {
            display: flex; align-items: center; gap: 0.75rem; background: var(--chat-bg);
            padding: 0.75rem 1rem; border-radius: 1.5rem; box-shadow: var(--shadow);
            position: sticky; bottom: 1rem; z-index: 10;
        }
        .input__textbox {
            flex-grow: 1; border: none; background: transparent;
            color: var(--text-color); outline: none; font-size: 1rem;
        }
        .input__textbox:focus { border-bottom: 2px solid var(--primary-color); }
        .submit-btn {
            background: linear-gradient(135deg, var(--primary-color), var(--primary-dark));
            color: white; border: none; border-radius: 1rem; padding: 0.5rem 1.25rem;
            font-size: 0.9rem; transition: var(--transition);
        }
        .submit-btn:hover { transform: scale(1.05); }
        .submit-btn:active { animation: glow 0.3s ease; }
        .tooltip { position: relative; }
        .tooltip:hover::after {
            content: attr(data-tip); position: absolute; top: -2.5rem; left: 50%;
            transform: translateX(-50%); background: #1e293b; color: white;
            padding: 0.4rem 0.8rem; border-radius: 0.4rem; font-size: 0.85rem;
            max-width: 200px; white-space: normal; text-align: center; z-index: 1000;
            animation: fadeIn 0.3s ease;
        }
        .progress-tracker {
            position: relative; padding: 0.5rem; background: var(--message-bg);
            border-radius: var(--border-radius); margin-top: 0.75rem; overflow: hidden;
        }
        .progress-tracker::before {
            content: ''; position: absolute; top: 0; left: 0; height: 100%; width: 0;
            background: linear-gradient(to right, var(--primary-color), var(--primary-dark));
            opacity: 0.3; animation: progress 2s ease-in-out infinite;
        }

        /* ==================== ANIMATIONS ==================== */
        @keyframes glow {
            0%, 100% { transform: scale(1); opacity: 1; }
            50% { transform: scale(1.1); opacity: 0.8; }
        }
        @keyframes fadeIn { from { opacity: 0; } to { opacity: 1; } }
        @keyframes messageFade {
            from { opacity: 0; transform: translateY(10px) scale(0.95); }
            to { opacity: 1; transform: translateY(0) scale(1); }
        }
        @keyframes progress {
            0% { width: 0; }
            50% { width: 60%; }
            100% { width: 0; }
        }

        /* ==================== THEMES ==================== */
        [data-theme="dark"] .chat-container { border-color: #475569; }
        [data-theme="dark"] .upload-area { border-color: #64748b; }
        [data-theme="dark"] .upload-area:hover { background: rgba(79, 70, 229, 0.1); }
        [data-theme="dark"] .summary-panel { border-left-color: #818cf8; }

        /* ==================== MEDIA QUERIES ==================== */
        @media (max-width: 768px) {
            .gradio-container { padding: 1rem; }
            .chat-container { min-height: 40vh; max-height: 70vh; margin-bottom: 3.5rem; }
            .summary-panel { padding: 0.75rem; }
            .upload-area { padding: 1rem; }
            .input-container { gap: 0.5rem; padding: 0.5rem; }
            .submit-btn { padding: 0.4rem 1rem; }
        }
        @media (max-width: 480px) {
            .chat-container { padding: 1rem; margin-bottom: 3rem; }
            .input-container { flex-direction: column; padding: 0.5rem; }
            .input__textbox { font-size: 0.9rem; }
            .submit-btn { width: 100%; padding: 0.5rem; font-size: 0.85rem; }
            .chat__message { max-width: 90%; padding: 0.5rem 0.75rem; }
            .tooltip:hover::after { top: auto; bottom: -2.5rem; max-width: 80vw; }
        }
        """

        js = """
        function applyTheme(theme) {
            document.documentElement.setAttribute('data-theme', theme);
            localStorage.setItem('theme', theme);
        }
        document.addEventListener('DOMContentLoaded', () => {
            const savedTheme = localStorage.getItem('theme') || 'light';
            applyTheme(savedTheme);
        });
        """

        with gr.Blocks(
            theme=gr.themes.Soft(
                primary_hue="indigo",
                secondary_hue="blue",
                neutral_hue="slate"
            ),
            title="Clinical Oversight Assistant",
            css=css,
            js=js
        ) as app:
            # Header
            gr.Markdown("""
                AI-powered analysis for identifying potential missed diagnoses in patient records