import sys
import os
import pandas as pd
import pdfplumber
import json
import gradio as gr
from typing import List, Dict, Generator, Any
from concurrent.futures import ThreadPoolExecutor, as_completed
import hashlib
import shutil
import re
import psutil
import subprocess
import logging
import torch
import gc
from diskcache import Cache
from transformers import AutoTokenizer
from datetime import datetime

# ==================== CONFIGURATION ====================

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Setup directories
PERSISTENT_DIR = "/data/hf_cache"
DIRECTORIES = {
    "models": os.path.join(PERSISTENT_DIR, "txagent_models"),
    "tools": os.path.join(PERSISTENT_DIR, "tool_cache"),
    "cache": os.path.join(PERSISTENT_DIR, "cache"),
    "reports": os.path.join(PERSISTENT_DIR, "reports"),
    "vllm": os.path.join(PERSISTENT_DIR, "vllm_cache"),
}

# Create directories
for dir_path in DIRECTORIES.values():
    os.makedirs(dir_path, exist_ok=True)

# Environment variables
os.environ.update({
    "HF_HOME": DIRECTORIES["models"],
    "TRANSFORMERS_CACHE": DIRECTORIES["models"],
    "VLLM_CACHE_DIR": DIRECTORIES["vllm"],
    "TOKENIZERS_PARALLELISM": "false",
    "CUDA_LAUNCH_BLOCKING": "1",
})

# Add src path for txagent
current_dir = os.path.dirname(os.path.abspath(__file__))
src_path = os.path.abspath(os.path.join(current_dir, "src"))
sys.path.insert(0, src_path)

from txagent.txagent import TxAgent

# Log Gradio version for debugging
logger.info(f"Gradio version: {gr.__version__}")

# ==================== UTILITY FUNCTIONS ====================

def sanitize_text(text: str) -> str:
    """Clean and sanitize text input (drops bytes that are not valid UTF-8)."""
    return text.encode("utf-8", "ignore").decode("utf-8")

def get_file_hash(file_path: str) -> str:
    """Generate MD5 hash of file content (used as a content-addressed cache key)."""
    with open(file_path, "rb") as f:
        return hashlib.md5(f.read()).hexdigest()

def log_system_resources(tag: str = "") -> None:
    """Log CPU, RAM, and (when nvidia-smi is available) GPU usage."""
    try:
        cpu = psutil.cpu_percent(interval=1)
        mem = psutil.virtual_memory()
        logger.info(f"[{tag}] CPU: {cpu:.1f}% | RAM: {mem.used//(1024**2)}MB/{mem.total//(1024**2)}MB")
        gpu_info = subprocess.run(
            ["nvidia-smi", "--query-gpu=memory.used,memory.total,utilization.gpu",
             "--format=csv,nounits,noheader"],
            capture_output=True, text=True
        )
        if gpu_info.returncode == 0:
            used, total, util = gpu_info.stdout.strip().split(", ")
            logger.info(f"[{tag}] GPU: {used}MB/{total}MB | Util: {util}%")
    except Exception as e:
        logger.error(f"[{tag}] Resource monitoring failed: {e}")
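# A minimal sketch of how the helpers above compose (illustrative only;
# "notes.pdf" stands in for any uploaded file):
#
#   cache = Cache(DIRECTORIES["cache"])
#   key = f"pdf_{get_file_hash('notes.pdf')}"   # content-addressed cache key
#   if key not in cache:
#       cache[key] = "...expensive extraction result..."
#   # re-uploading the identical file now hits the cache instead of pdfplumber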
# ==================== FILE PROCESSING ====================

class FileProcessor:
    @staticmethod
    def extract_pdf_text(file_path: str, cache: Cache) -> str:
        """Extract text from PDF with caching."""
        cache_key = f"pdf_{get_file_hash(file_path)}"
        if cache_key in cache:
            return cache[cache_key]
        try:
            with pdfplumber.open(file_path) as pdf:
                total_pages = len(pdf.pages)
            if not total_pages:
                return ""

            def process_page_range(start: int, end: int) -> List[tuple]:
                results = []
                # Each worker opens its own handle; pdfplumber objects are not thread-safe.
                with pdfplumber.open(file_path) as pdf:
                    # enumerate(..., start=start) yields the absolute page number directly;
                    # the previous start + pdf.pages.index(page) double-counted the offset.
                    for page_num, page in enumerate(pdf.pages[start:end], start=start):
                        text = page.extract_text() or ""
                        results.append((page_num, f"=== Page {page_num + 1} ===\n{text.strip()}"))
                return results

            batch_size = 10
            batches = [(i, min(i + batch_size, total_pages)) for i in range(0, total_pages, batch_size)]
            text_chunks = [""] * total_pages
            with ThreadPoolExecutor(max_workers=2) as executor:
                futures = [executor.submit(process_page_range, start, end) for start, end in batches]
                for future in as_completed(futures):
                    for page_num, text in future.result():
                        text_chunks[page_num] = text

            result = "\n\n".join(filter(None, text_chunks))
            cache[cache_key] = result
            return result
        except Exception as e:
            logger.error(f"PDF processing error: {e}")
            return f"PDF processing error: {str(e)}"

    @staticmethod
    def excel_to_data(file_path: str, cache: Cache) -> List[Dict]:
        """Convert Excel file to structured data with caching."""
        cache_key = f"excel_{get_file_hash(file_path)}"
        if cache_key in cache:
            return cache[cache_key]
        try:
            df = pd.read_excel(file_path, engine='openpyxl', header=None, dtype=str)
            content = df.where(pd.notnull(df), "").astype(str).values.tolist()
            result = [{"filename": os.path.basename(file_path), "rows": content, "type": "excel"}]
            cache[cache_key] = result
            return result
        except Exception as e:
            logger.error(f"Excel processing error: {e}")
            return [{"error": f"Excel processing error: {str(e)}"}]

    @staticmethod
    def csv_to_data(file_path: str, cache: Cache) -> List[Dict]:
        """Convert CSV file to structured data with caching."""
        cache_key = f"csv_{get_file_hash(file_path)}"
        if cache_key in cache:
            return cache[cache_key]
        try:
            chunks = []
            for chunk in pd.read_csv(
                file_path,
                header=None,
                dtype=str,
                encoding_errors='replace',
                on_bad_lines='skip',
                chunksize=10000,
            ):
                chunks.append(chunk)
            df = pd.concat(chunks) if chunks else pd.DataFrame()
            content = df.where(pd.notnull(df), "").astype(str).values.tolist()
            result = [{"filename": os.path.basename(file_path), "rows": content, "type": "csv"}]
            cache[cache_key] = result
            return result
        except Exception as e:
            logger.error(f"CSV processing error: {e}")
            return [{"error": f"CSV processing error: {str(e)}"}]

    @classmethod
    def process_file(cls, file_path: str, file_type: str, cache: Cache) -> List[Dict]:
        """Route file processing based on type."""
        processors = {
            "pdf": cls.extract_pdf_text,
            "xls": cls.excel_to_data,
            "xlsx": cls.excel_to_data,
            "csv": cls.csv_to_data,
        }
        if file_type not in processors:
            return [{"error": f"Unsupported file type: {file_type}"}]
        try:
            result = processors[file_type](file_path, cache)
            if file_type == "pdf":
                return [{
                    "filename": os.path.basename(file_path),
                    "content": result,
                    "status": "initial",
                    "type": "pdf",
                }]
            return result
        except Exception as e:
            logger.error(f"Error processing {file_type} file: {e}")
            return [{"error": f"Error processing file: {str(e)}"}]
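# Driving FileProcessor directly, outside the UI (hypothetical paths; repeat
# runs are near-instant because results land in the shared diskcache):
#
#   cache = Cache(DIRECTORIES["cache"])
#   rows = FileProcessor.process_file("/tmp/labs.csv", "csv", cache)
#   docs = FileProcessor.process_file("/tmp/chart.pdf", "pdf", cache)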
# ==================== TEXT PROCESSING ====================

class TextProcessor:
    def __init__(self):
        self.tokenizer = AutoTokenizer.from_pretrained("mims-harvard/TxAgent-T1-Llama-3.1-8B")
        self.cache = Cache(DIRECTORIES["cache"], size_limit=10 * 1024**3)

    def chunk_text(self, text: str, max_tokens: int = 1200) -> List[str]:
        """Split text into token-limited chunks."""
        tokens = self.tokenizer.encode(text)
        return [
            self.tokenizer.decode(tokens[i:i + max_tokens])
            for i in range(0, len(tokens), max_tokens)
        ]

    def _extract_diagnoses(self, text: str) -> List[str]:
        """Collect bullet items under '### Missed Diagnoses' headings.

        Shared by clean_response() and summarize_results(), which previously
        duplicated this loop.
        """
        diagnoses = []
        in_diagnoses = False
        for line in text.splitlines():
            line = line.strip()
            if not line:
                continue
            if re.match(r"###\s*Missed Diagnoses", line):
                in_diagnoses = True
                continue
            if re.match(r"###\s*(Medication Conflicts|Incomplete Assessments|Urgent Follow-up)", line):
                in_diagnoses = False
                continue
            if in_diagnoses and re.match(r"-\s*.+", line):
                diagnosis = re.sub(r"^-\s*", "", line).strip()
                if diagnosis and not re.match(r"No issues identified", diagnosis, re.IGNORECASE):
                    diagnoses.append(diagnosis)
        return diagnoses

    def clean_response(self, text: str) -> str:
        """Clean and format model response."""
        text = sanitize_text(text)
        text = re.sub(r"\[.*?\]|\bNone\b", "", text)
        diagnoses = self._extract_diagnoses(text)
        return " ".join(diagnoses) if diagnoses else ""

    def summarize_results(self, analysis: str) -> str:
        """Generate concise summary from full analysis."""
        chunks = analysis.split("--- Analysis for Chunk")
        diagnoses = []
        for chunk in chunks:
            chunk = chunk.strip()
            if not chunk or "No oversights identified" in chunk:
                continue
            diagnoses.extend(self._extract_diagnoses(chunk))
        # Deduplicate while preserving first-seen order.
        unique_diagnoses = list(dict.fromkeys(diagnoses))
        if not unique_diagnoses:
            return "No missed diagnoses were identified in the provided records."
        if len(unique_diagnoses) > 1:
            summary = "Missed diagnoses include " + ", ".join(unique_diagnoses[:-1])
            summary += f", and {unique_diagnoses[-1]}"
        else:
            summary = "Missed diagnoses include " + unique_diagnoses[0]
        return summary + ", all requiring urgent clinical review."
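# Chunking arithmetic, in miniature: a record that encodes to 3,000 tokens with
# max_tokens=1200 yields three chunks of 1200, 1200, and 600 tokens, each
# decoded back to a plain string. A quick interactive check (`record_text` is a
# hypothetical variable; re-encoding decoded text can differ by a special token
# or two):
#
#   tp = TextProcessor()
#   parts = tp.chunk_text(record_text)
#   lengths = [len(tp.tokenizer.encode(p)) for p in parts]  # all near or below 1200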
# ==================== CORE APPLICATION ====================

class ClinicalOversightApp:
    def __init__(self):
        self.agent = self._initialize_agent()
        self.text_processor = TextProcessor()
        self.file_processor = FileProcessor()

    def _initialize_agent(self):
        """Initialize the TxAgent with proper configuration."""
        logger.info("Initializing AI model...")
        log_system_resources("Before Load")
        tool_path = os.path.join(DIRECTORIES["tools"], "new_tool.json")
        if not os.path.exists(tool_path):
            default_tools = os.path.abspath("data/new_tool.json")
            shutil.copy(default_tools, tool_path)
        agent = TxAgent(
            model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
            rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
            tool_files_dict={"new_tool": tool_path},
            force_finish=True,
            enable_checker=False,
            step_rag_num=4,
            seed=100,
            additional_default_tools=[],
        )
        agent.init_model()
        log_system_resources("After Load")
        logger.info("AI Agent Ready")
        return agent

    def cleanup_resources(self):
        """Clean up GPU memory and collect garbage."""
        logger.info("Cleaning up resources...")
        if torch.cuda.is_available():  # guard so CPU-only hosts do not raise
            torch.cuda.empty_cache()
        gc.collect()
        if torch.distributed.is_available() and torch.distributed.is_initialized():
            logger.info("Destroying PyTorch distributed process group...")
            torch.distributed.destroy_process_group()

    def process_response_stream(self, prompt: str, history: List[dict]) -> Generator[dict, None, None]:
        """Stream the agent's response with proper formatting."""
        full_response = ""
        for chunk in self.agent.run_gradio_chat(prompt, [], 0.2, 512, 2048, False, []):
            if not chunk:
                continue
            if isinstance(chunk, list):
                for message in chunk:
                    if hasattr(message, 'content') and message.content:
                        cleaned = self.text_processor.clean_response(message.content)
                        if cleaned:
                            full_response += cleaned + " "
                            yield {
                                "role": "assistant",
                                "content": f"✅ {cleaned} [{datetime.now().strftime('%H:%M:%S')}]",
                            }
            elif isinstance(chunk, str) and chunk.strip():
                cleaned = self.text_processor.clean_response(chunk)
                if cleaned:
                    full_response += cleaned + " "
                    yield {
                        "role": "assistant",
                        "content": f"✅ {cleaned} [{datetime.now().strftime('%H:%M:%S')}]",
                    }

    def analyze(self, message: str, history: List[dict], files: List) -> Generator[tuple, None, None]:
        """Main analysis pipeline with proper output formatting."""
        chatbot_output = history.copy()
        download_output = None
        final_summary = ""
        progress_text = gr.update(value="Starting analysis...", visible=True)
        try:
            # Add user message to history
            chatbot_output.append({
                "role": "user",
                "content": f"{message} [{datetime.now().strftime('%H:%M:%S')}]",
            })
            yield (chatbot_output, download_output, final_summary, progress_text)

            # Process uploaded files
            extracted = []
            file_hash_value = ""
            num_files = len(files) if files else 0  # files may be None when nothing is uploaded
            if files:
                with ThreadPoolExecutor(max_workers=2) as executor:
                    futures = []
                    for f in files:
                        file_type = f.name.split(".")[-1].lower()
                        futures.append(executor.submit(
                            self.file_processor.process_file,
                            f.name, file_type, self.text_processor.cache,
                        ))
                    for i, future in enumerate(as_completed(futures), 1):
                        try:
                            extracted.extend(future.result())
                            progress_text = self._update_progress(i, num_files, "Processing files")
                            yield (chatbot_output, download_output, final_summary, progress_text)
                        except Exception as e:
                            logger.error(f"File processing error: {e}")
                            extracted.append({"error": f"Error processing file: {str(e)}"})

            file_hash_value = get_file_hash(files[0].name) if files else ""
            chatbot_output.append({
                "role": "assistant",
                "content": f"✅ File processing complete [{datetime.now().strftime('%H:%M:%S')}]",
            })
            progress_text = self._update_progress(num_files, num_files, "Files processed")
            yield (chatbot_output, download_output, final_summary, progress_text)

            # Analyze content
            text_content = "\n".join(json.dumps(item) for item in extracted)
            chunks = self.text_processor.chunk_text(text_content)
            combined_response = ""

            for chunk_idx, chunk in enumerate(chunks, 1):
                # chunk is already token-limited by chunk_text(), so it is passed
                # through whole; re-truncating it by character count would silently
                # discard most of each chunk.
                prompt = f"""
Analyze this patient record for missed diagnoses. Provide a concise, evidence-based summary as a single paragraph without headings or bullet points. Include specific clinical findings with their potential implications and urgent review recommendations. If no missed diagnoses are found, state 'No missed diagnoses identified'.

Patient Record (Chunk {chunk_idx}/{len(chunks)}):
{chunk}
"""
                chatbot_output.append({"role": "assistant", "content": "⏳ Analyzing..."})
                progress_text = self._update_progress(chunk_idx, len(chunks), "Analyzing")
                yield (chatbot_output, download_output, final_summary, progress_text)

                # Stream the response, replacing the placeholder message in place
                chunk_response = ""
                for update in self.process_response_stream(prompt, chatbot_output):
                    chatbot_output[-1] = update
                    chunk_response = update["content"]
                    progress_text = self._update_progress(chunk_idx, len(chunks), "Analyzing")
                    yield (chatbot_output, download_output, final_summary, progress_text)

                combined_response += f"--- Analysis for Chunk {chunk_idx} ---\n{chunk_response}\n"
                self.cleanup_resources()

            # Generate final outputs
            final_summary = self.text_processor.summarize_results(combined_response)
            report_path = (
                os.path.join(DIRECTORIES["reports"], f"{file_hash_value}_report.txt")
                if file_hash_value else None
            )
            if report_path:
                with open(report_path, "w", encoding="utf-8") as f:
                    f.write(combined_response + "\n\n" + final_summary)
            if report_path and os.path.exists(report_path):
                # Reveal the download component only once a report actually exists.
                download_output = gr.update(value=report_path, visible=True)
            progress_text = gr.update(visible=False)
            yield (chatbot_output, download_output, final_summary, progress_text)
        except Exception as e:
            logger.error(f"Analysis error: {e}")
            chatbot_output.append({
                "role": "assistant",
                "content": f"❌ Error: {str(e)} [{datetime.now().strftime('%H:%M:%S')}]",
            })
            final_summary = f"Error occurred: {str(e)}"
            progress_text = gr.update(visible=False)
            yield (chatbot_output, download_output, final_summary, progress_text)
        finally:
            self.cleanup_resources()

    def _update_progress(self, current: int, total: int, stage: str = "") -> Dict[str, Any]:
        """Format progress update for UI (gr.update keeps the Textbox visible)."""
        progress = f"{stage} - {current}/{total}" if stage else f"{current}/{total}"
        return gr.update(value=progress, visible=True)

    def toggle_theme(self, theme_state: str) -> tuple[str, str]:
        """Toggle between light and dark themes."""
        new_theme = "dark" if theme_state == "light" else "light"
        button_text = "☀️ Light Mode" if new_theme == "dark" else "🌙 Dark Mode"
        return new_theme, button_text

    def toggle_sidebar(self, sidebar_state: bool) -> bool:
        """Toggle sidebar visibility."""
        return not sidebar_state
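    # The streaming contract for analyze() above, in miniature: every yield is
    # one frame of (chat_messages, download_update, summary_text, progress_update),
    # in the same order as the `outputs=` list wired up in create_interface().
    # A hypothetical mid-run frame looks like:
    #
    #   yield (history, None, "", gr.update(value="Analyzing - 2/5", visible=True))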
    def create_interface(self):
        """Create Gradio interface with refined ChatGPT-like design."""
        css = """
        body, .gradio-container { font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif; background: var(--background); color: var(--text-color); transition: all 0.4s ease; }
        .gradio-container { max-width: 800px; margin: 0 auto; padding: 24px; }
        .chat-container { background: var(--chat-bg); border-radius: 16px; padding: 24px; height: 80vh; overflow-y: auto; box-shadow: 0 4px 12px rgba(0,0,0,0.15); position: relative; }
        .message { margin: 12px 0; padding: 12px 16px; border-radius: 12px; max-width: 80%; transition: all 0.3s ease; background: var(--message-bg); position: relative; }
        .message:hover { transform: translateY(-2px); box-shadow: 0 2px 4px rgba(0,0,0,0.1); }
        .message.user { background: linear-gradient(135deg, #007bff, #0056b3); color: white; margin-left: auto; }
        .message.assistant { background: var(--message-bg); color: var(--text-color); }
        .message-timestamp { font-size: 0.75em; opacity: 0.7; margin-top: 4px; text-align: right; }
        .input-container { display: flex; align-items: center; margin-top: 24px; background: var(--chat-bg); padding: 12px 24px; border-radius: 30px; box-shadow: 0 4px 8px rgba(0,0,0,0.15); position: sticky; bottom: 0; }
        .input-textbox { flex-grow: 1; border: none; background: transparent; color: var(--text-color); outline: none; font-size: 1em; }
        .input-textbox:focus { border-bottom: 2px solid #007bff; }
        .send-btn { background: linear-gradient(135deg, #007bff, #0056b3); color: white; border: none; border-radius: 20px; padding: 10px 20px; margin-left: 12px; transition: transform 0.2s ease; }
        .send-btn:hover { transform: scale(1.05); }
        .send-btn:active { animation: glow 0.3s ease; }
        /* Theme-aware sidebar; the duplicate hardcoded rgba() background that
           previously followed var(--sidebar-bg) overrode it and broke dark mode. */
        .sidebar { background: var(--sidebar-bg); padding: 24px; border-radius: 16px; margin-top: 24px; box-shadow: 0 4px 12px rgba(0,0,0,0.15); transition: transform 0.4s ease; transform: translateX(0); position: fixed; right: 0; top: 100px; width: 300px; z-index: 1000; backdrop-filter: blur(10px); }
        .sidebar-hidden { transform: translateX(100%); }
        .sidebar-backdrop { position: fixed; top: 0; left: 0; width: 100%; height: 100%; background: rgba(0,0,0,0.3); z-index: 999; display: none; }
        .sidebar:not(.sidebar-hidden) ~ .sidebar-backdrop { display: block; }
        .header { text-align: center; margin-bottom: 24px; }
        .theme-toggle { position: absolute; top: 24px; right: 24px; background: linear-gradient(135deg, #007bff, #0056b3); color: white; border: none; border-radius: 20px; padding: 8px 16px; display: flex; align-items: center; gap: 8px; }
        .tooltip { position: relative; }
        .tooltip:hover::after { content: attr(data-tooltip); position: absolute; bottom: 100%; left: 50%; transform: translateX(-50%); background: #333; color: white; padding: 6px 12px; border-radius: 6px; font-size: 0.85em; white-space: nowrap; z-index: 1000; animation: fadeIn 0.3s ease; }
        .loading-spinner { position: absolute; bottom: 80px; left: 50%; transform: translateX(-50%); font-size: 1.2em; animation: glow 1.5s ease infinite; }
        .typing-indicator { display: none; font-size: 0.9em; color: var(--text-color); opacity: 0.7; margin: 12px; }
        .typing-indicator.active { display: block; animation: blink 1s step-end infinite; }
        .progress-text { position: relative; padding: 8px; background: var(--message-bg); border-radius: 8px; margin-top: 12px; }
        .progress-text::before { content: ''; position: absolute; top: 0; left: 0; height: 100%; width: 0; background: #007bff; opacity: 0.2; animation: progress 2s linear infinite; }
        @keyframes glow { 0%, 100% { transform: translateX(-50%) scale(1); opacity: 1; color: #007bff; } 50% { transform: translateX(-50%) scale(1.2); opacity: 0.7; color: #0056b3; } }
        @keyframes blink { 50% { opacity: 0.3; } }
        @keyframes fadeIn { from { opacity: 0; } to { opacity: 1; } }
        @keyframes progress { 0% { width: 0; } 50% { width: 50%; } 100% { width: 0; } }
        :root { --background: #ffffff; --text-color: #333333; --chat-bg: #f9fafb; --message-bg: #e5e5ea; --sidebar-bg: #f1f3f5; }
        [data-theme="dark"] { --background: #1e2a44; --text-color: #ffffff; --chat-bg: #2d3b55; --message-bg: #3e4c6a; --sidebar-bg: #2a3650; }
        @media (max-width: 600px) {
            .gradio-container { padding: 12px; }
            .chat-container { height: 70vh; }
            .input-container { flex-direction: column; gap: 12px; padding: 12px; }
            .send-btn { width: 100%; margin-left: 0; }
            .sidebar { width: 100%; top: 80px; }
            .sidebar-hidden { transform: translateX(100%); }
        }
        """
        # Blocks-level js must be a single function that runs once on load; the
        # previous DOMContentLoaded listener never fired because the DOM is
        # already loaded by the time custom js executes. Helpers are attached
        # to window so per-event js snippets can call them.
        js = """
        () => {
            window.applyTheme = (theme) => {
                document.documentElement.setAttribute('data-theme', theme);
                localStorage.setItem('theme', theme);
                const btn = document.querySelector('.theme-toggle');
                if (btn) { btn.innerHTML = theme === 'dark' ? '☀️ Light Mode' : '🌙 Dark Mode'; }
            };
            window.toggleSidebar = () => {
                const sidebar = document.querySelector('.sidebar');
                sidebar.classList.toggle('sidebar-hidden');
                if (!sidebar.classList.contains('sidebar-hidden')) {
                    // Auto-collapse on small screens after 5 seconds.
                    setTimeout(() => {
                        if (window.innerWidth <= 600) {
                            sidebar.classList.add('sidebar-hidden');
                        }
                    }, 5000);
                }
            };
            const savedTheme = localStorage.getItem('theme') || 'light';
            window.applyTheme(savedTheme);
            document.querySelector('.sidebar').classList.add('sidebar-hidden');
        }
        """

        with gr.Blocks(theme=gr.themes.Default(), css=css, js=js, title="Clinical Oversight Assistant") as app:
            try:
                theme_state = gr.State(value="light")
                sidebar_state = gr.State(value=False)
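                # Theme plumbing sketch: window.applyTheme('dark') (defined in
                # the js string above) sets <html data-theme="dark">, which
                # swaps every var(--...) declared under [data-theme="dark"]
                # in the CSS, restyling the whole page without a reload.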
                gr.HTML("""
                    <div class="header">
                        <h1>🩺 Clinical Oversight Assistant</h1>
                        <p>AI-powered analysis of patient records for missed diagnoses</p>
                    </div>
                """)
                theme_button = gr.Button("🌙 Dark Mode", elem_classes="theme-toggle")

                with gr.Column(elem_classes="chat-container"):
                    chatbot = gr.Chatbot(
                        label="Clinical Analysis",
                        height="100%",
                        show_copy_button=True,
                        type="messages",
                        elem_classes="chatbot",
                        render_markdown=True,
                    )
                    # Dedicated indicator components so status updates never
                    # overwrite the chat history itself.
                    loading_indicator = gr.HTML("", visible=False)
                    typing_indicator = gr.HTML(
                        '<div class="typing-indicator">Typing...</div>',
                        visible=False,
                    )

                with gr.Row():
                    tools_button = gr.Button("📂 Tools", variant="secondary")

                with gr.Column(elem_classes="sidebar"):
                    # Note: gr.Markdown/gr.File accept no `data_tooltip` argument;
                    # the .tooltip CSS reads a data-tooltip HTML attribute, which
                    # would have to be set client-side, so only elem_classes is
                    # passed here.
                    gr.Markdown("### 📎 Upload Records", elem_classes="tooltip")
                    file_upload = gr.File(
                        file_types=[".pdf", ".csv", ".xls", ".xlsx"],
                        file_count="multiple",
                        label="Patient Records",
                        elem_classes="tooltip",
                    )
                    gr.Markdown("### 📝 Analysis Summary", elem_classes="tooltip")
                    final_summary = gr.Markdown(
                        "Analysis results will appear here...",
                        elem_classes="tooltip",
                    )
                    gr.Markdown("### 📄 Full Report", elem_classes="tooltip")
                    download_output = gr.File(
                        label="Download Report",
                        visible=False,
                        interactive=False,
                        elem_classes="tooltip",
                    )

                with gr.Row(elem_classes="input-container"):
                    msg_input = gr.Textbox(
                        placeholder="Ask about potential oversights or upload files...",
                        show_label=False,
                        container=False,
                        elem_classes="input-textbox",
                        autofocus=True,
                    )
                    send_btn = gr.Button(
                        "Analyze",
                        variant="primary",
                        elem_classes="send-btn",
                    )

                progress_text = gr.Textbox(
                    label="Progress Status",
                    visible=False,
                    interactive=False,
                    elem_classes="progress-text",
                )

                def show_loading(state: bool) -> dict:
                    """Toggle the animated loading spinner."""
                    return gr.update(
                        value='<div class="loading-spinner">⏳</div>' if state else "",
                        visible=state,
                    )
                def show_typing(state: bool) -> dict:
                    """Toggle the 'Typing...' indicator."""
                    return gr.update(
                        value='<div class="typing-indicator active">Typing...</div>',
                        visible=state,
                    )

                # gr.update(...) patches only the named fields of its target
                # component (e.g. gr.update(visible=False) hides an indicator
                # without touching its value), which lets the chains below
                # toggle the spinners around self.analyze without writing into
                # the chat history component.

                # Theme toggle handler. Gradio 4 renamed `_js` to `js`; the
                # client-side call is chained with .then so the Python state
                # update lands first and the fresh theme value reaches the js.
                theme_button.click(
                    fn=self.toggle_theme,
                    inputs=[theme_state],
                    outputs=[theme_state, theme_button],
                ).then(
                    fn=None,
                    inputs=[theme_state],
                    outputs=[],
                    js="(theme) => { window.applyTheme(theme); }",
                )

                # Sidebar toggle handler
                tools_button.click(
                    fn=self.toggle_sidebar,
                    inputs=[sidebar_state],
                    outputs=[sidebar_state],
                ).then(
                    fn=None,
                    inputs=[],
                    outputs=[],
                    js="() => { window.toggleSidebar(); }",
                )

                # Analysis handlers: show the indicators, run the streaming
                # analysis, then hide the indicators again. The indicator
                # updates target their own components, not the chatbot.
                send_btn.click(
                    fn=show_loading,
                    inputs=[gr.State(value=True)],
                    outputs=[loading_indicator],
                ).then(
                    fn=show_typing,
                    inputs=[gr.State(value=True)],
                    outputs=[typing_indicator],
                ).then(
                    fn=self.analyze,
                    inputs=[msg_input, chatbot, file_upload],
                    outputs=[chatbot, download_output, final_summary, progress_text],
                    show_progress="hidden",
                ).then(
                    fn=show_loading,
                    inputs=[gr.State(value=False)],
                    outputs=[loading_indicator],
                ).then(
                    fn=show_typing,
                    inputs=[gr.State(value=False)],
                    outputs=[typing_indicator],
                )

                msg_input.submit(
                    fn=show_loading,
                    inputs=[gr.State(value=True)],
                    outputs=[loading_indicator],
                ).then(
                    fn=show_typing,
                    inputs=[gr.State(value=True)],
                    outputs=[typing_indicator],
                ).then(
                    fn=self.analyze,
                    inputs=[msg_input, chatbot, file_upload],
                    outputs=[chatbot, download_output, final_summary, progress_text],
                    show_progress="hidden",
                ).then(
                    fn=show_loading,
                    inputs=[gr.State(value=False)],
                    outputs=[loading_indicator],
                ).then(
                    fn=show_typing,
                    inputs=[gr.State(value=False)],
                    outputs=[typing_indicator],
                )

                # Reset everything to a clean state on page load.
                app.load(
                    fn=lambda: [
                        [],                        # chatbot
                        None,                      # download_output
                        "",                        # final_summary
                        "",                        # msg_input
                        None,                      # file_upload
                        gr.update(visible=False),  # progress_text
                        "light",                   # theme_state
                        False,                     # sidebar_state
                        "🌙 Dark Mode",            # theme_button
                    ],
                    outputs=[chatbot, download_output, final_summary, msg_input,
                             file_upload, progress_text, theme_state,
                             sidebar_state, theme_button],
                    queue=False,
                )
            except Exception as e:
                logger.error(f"Interface creation failed: {e}")
                self.cleanup_resources()
                raise

        return app

# ==================== APPLICATION ENTRY POINT ====================

if __name__ == "__main__":
    app = None
    try:
        logger.info("Starting Clinical Oversight Assistant...")
        app = ClinicalOversightApp()
        interface = app.create_interface()
        interface.queue(
            api_open=False,
            max_size=20,
        ).launch(
            server_name="0.0.0.0",
            server_port=7860,
            show_error=True,
            allowed_paths=[DIRECTORIES["reports"]],
            share=False,
        )
    except Exception as e:
        logger.error(f"Application failed to start: {e}")
        raise
    finally:
        if app:
            app.cleanup_resources()