import sys
import os
import pandas as pd
import pdfplumber
import json
import gradio as gr
from typing import List, Dict, Any
from concurrent.futures import ThreadPoolExecutor, as_completed
import hashlib
import shutil
import re
import psutil
import subprocess
import logging
import traceback
from datetime import datetime

# Configure logging: everything at INFO and above goes to stderr and to a local log file.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.StreamHandler(),
        logging.FileHandler('clinical_oversight.log')
    ]
)
logger = logging.getLogger(__name__)

# Persistent directory holding model weights, tool config, parsed-file cache and reports.
persistent_dir = "/data/hf_cache"
os.makedirs(persistent_dir, exist_ok=True)

model_cache_dir = os.path.join(persistent_dir, "txagent_models")
tool_cache_dir = os.path.join(persistent_dir, "tool_cache")
file_cache_dir = os.path.join(persistent_dir, "cache")
report_dir = os.path.join(persistent_dir, "reports")
vllm_cache_dir = os.path.join(persistent_dir, "vllm_cache")

for directory in [model_cache_dir, tool_cache_dir, file_cache_dir, report_dir, vllm_cache_dir]:
    os.makedirs(directory, exist_ok=True)

# Set before TxAgent is imported below, presumably so the model libraries pick up
# the cache locations at import time.
os.environ["HF_HOME"] = model_cache_dir
os.environ["TRANSFORMERS_CACHE"] = model_cache_dir
os.environ["VLLM_CACHE_DIR"] = vllm_cache_dir
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# NOTE(review): CUDA_LAUNCH_BLOCKING=1 makes CUDA kernel launches synchronous, which is a
# debugging aid and slows inference — confirm it is intended outside debugging.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

# Make the bundled ./src package (which provides txagent) importable.
current_dir = os.path.dirname(os.path.abspath(__file__))
src_path = os.path.abspath(os.path.join(current_dir, "src"))
sys.path.insert(0, src_path)

from txagent.txagent import TxAgent

MEDICAL_KEYWORDS = {'diagnosis', 'assessment', 'plan', 'results', 'medications',
                    'allergies', 'summary', 'impression', 'findings', 'recommendations'}

# One precompiled, word-bounded alternation of all keywords. Searching with it is
# equivalent to testing each keyword's own r'\bkw\b' pattern, but the regex is built
# once instead of once per keyword per page.
_MEDICAL_KEYWORD_RE = re.compile(
    r"\b(?:" + "|".join(re.escape(kw) for kw in sorted(MEDICAL_KEYWORDS)) + r")\b"
)


def sanitize_utf8(text: str) -> str:
    """Ensure text is UTF-8 encodable, dropping characters that are not.

    Returns "" (and logs) if the round-trip raises, e.g. when *text* is not a str.
    """
    try:
        return text.encode("utf-8", "ignore").decode("utf-8")
    except Exception as e:
        logger.error(f"UTF-8 sanitization failed: {str(e)}")
        return ""


def file_hash(path: str) -> str:
    """Return the hex MD5 digest of the file at *path*, or "" if it cannot be read.

    The digest is used as a cache key / report-name fragment, not for security.
    The file is read in 1 MiB chunks so large uploads are never loaded into
    memory all at once.
    """
    try:
        digest = hashlib.md5()
        with open(path, "rb") as f:
            for block in iter(lambda: f.read(1 << 20), b""):
                digest.update(block)
        return digest.hexdigest()
    except Exception as e:
        logger.error(f"File hash generation failed for {path}: {str(e)}")
        return ""


def extract_priority_pages(file_path: str, max_pages: int = 20) -> str:
    """Extract PDF text, giving priority to pages that contain medical keywords.

    The first three pages (fewer if ``max_pages`` is smaller) are always included.
    Of the remaining pages up to ``max_pages``, only those whose lowercased text
    contains a whole word from ``MEDICAL_KEYWORDS`` are kept.

    Args:
        file_path: Path to the PDF file.
        max_pages: Upper bound on how many leading pages are examined.

    Returns:
        The selected page texts, each prefixed with a 1-based ``=== Page N ===``
        header and separated by blank lines, or a ``"PDF processing error: ..."``
        string if the PDF itself cannot be processed.
    """
    try:
        text_chunks = []
        logger.info(f"Extracting pages from {file_path}")
        always_included = max(0, min(3, max_pages))
        with pdfplumber.open(file_path) as pdf:
            # Always extract the leading pages, recording a placeholder on failure.
            for i, page in enumerate(pdf.pages[:always_included]):
                try:
                    text = page.extract_text() or ""
                    text_chunks.append(f"=== Page {i+1} ===\n{text.strip()}")
                except Exception as page_error:
                    logger.warning(f"Error processing page {i+1}: {str(page_error)}")
                    text_chunks.append(f"=== Page {i+1} ===\n[Error extracting content]")

            # Extract remaining pages only if they contain a medical keyword.
            for i, page in enumerate(pdf.pages[always_included:max_pages], start=always_included + 1):
                try:
                    page_text = page.extract_text() or ""
                    if _MEDICAL_KEYWORD_RE.search(page_text.lower()):
                        text_chunks.append(f"=== Page {i} ===\n{page_text.strip()}")
                except Exception as page_error:
                    logger.warning(f"Error processing page {i}: {str(page_error)}")
        return "\n\n".join(text_chunks)
    except Exception as e:
        logger.error(f"PDF processing error for {file_path}: {str(e)}")
        return f"PDF processing error: {str(e)}"


def convert_file_to_json(file_path: str, file_type: str) -> str:
    """Convert a PDF/CSV/Excel file to a JSON string, caching results by content hash.

    Args:
        file_path: Path to the uploaded file.
        file_type: Lowercased extension without the dot ("pdf", "csv", "xls", "xlsx").

    Returns:
        A JSON string. For PDFs it has ``filename``/``content``/``status``/``file_type``;
        for CSV/Excel it has ``filename``/``rows`` (all cells as strings)/``file_type``.
        On failure, or for an unsupported type, it is an object with an ``error`` key.
        Errors are returned as JSON rather than raised.
    """
    try:
        h = file_hash(file_path)
        if not h:
            return json.dumps({"error": "Could not generate file hash"})

        cache_path = os.path.join(file_cache_dir, f"{h}.json")

        # Check cache first; a cache read failure falls through to reprocessing.
        if os.path.exists(cache_path):
            try:
                with open(cache_path, "r", encoding="utf-8") as f:
                    return f.read()
            except Exception as cache_error:
                logger.error(f"Cache read error for {file_path}: {str(cache_error)}")

        result = {}
        try:
            if file_type == "pdf":
                text = extract_priority_pages(file_path)
                result = {
                    "filename": os.path.basename(file_path),
                    "content": text,
                    "status": "initial",
                    "file_type": "pdf"
                }
            elif file_type == "csv":
                df = pd.read_csv(
                    file_path,
                    encoding_errors="replace",
                    header=None,
                    dtype=str,
                    skip_blank_lines=False,
                    on_bad_lines="skip"
                )
                content = df.fillna("").astype(str).values.tolist()
                result = {
                    "filename": os.path.basename(file_path),
                    "rows": content,
                    "file_type": "csv"
                }
            elif file_type in ["xls", "xlsx"]:
                # openpyxl handles .xlsx; fall back to xlrd (legacy .xls).
                try:
                    df = pd.read_excel(file_path, engine="openpyxl", header=None, dtype=str)
                except Exception:
                    try:
                        df = pd.read_excel(file_path, engine="xlrd", header=None, dtype=str)
                    except Exception as excel_error:
                        logger.error(f"Excel read error for {file_path}: {str(excel_error)}")
                        raise
                content = df.fillna("").astype(str).values.tolist()
                result = {
                    "filename": os.path.basename(file_path),
                    "rows": content,
                    "file_type": "excel"
                }
            else:
                result = {"error": f"Unsupported file type: {file_type}"}

            json_result = json.dumps(result)

            # Save to cache (best effort: a write failure is logged, not raised).
            try:
                with open(cache_path, "w", encoding="utf-8") as f:
                    f.write(json_result)
            except Exception as cache_write_error:
                logger.error(f"Cache write error for {file_path}: {str(cache_write_error)}")

            return json_result
        except Exception as processing_error:
            logger.error(f"Error processing {file_path}: {str(processing_error)}")
            return json.dumps({"error": f"Error processing {os.path.basename(file_path)}: {str(processing_error)}"})
    except Exception as e:
        logger.error(f"Unexpected error in convert_file_to_json: {str(e)}")
        return json.dumps({"error": f"Unexpected error processing file: {str(e)}"})


def log_system_usage(tag=""):
    """Log CPU, RAM and (when ``nvidia-smi`` is available) per-GPU usage.

    Best effort: every failure is logged and swallowed. Note that
    ``psutil.cpu_percent(interval=1)`` blocks for one second.

    Args:
        tag: Label prefixed to each log line (e.g. "Before Load").
    """
    try:
        cpu = psutil.cpu_percent(interval=1)
        mem = psutil.virtual_memory()
        logger.info(f"[{tag}] CPU: {cpu}% | RAM: {mem.used // (1024**2)}MB / {mem.total // (1024**2)}MB")
        try:
            result = subprocess.run(
                ["nvidia-smi", "--query-gpu=memory.used,memory.total,utilization.gpu",
                 "--format=csv,nounits,noheader"],
                capture_output=True,
                text=True,
                timeout=10,  # never let a wedged driver stall the caller
            )
            if result.returncode == 0:
                # nvidia-smi prints one CSV row per GPU; parse each row separately
                # (unpacking the whole output fails on multi-GPU hosts).
                for line in result.stdout.strip().splitlines():
                    used, total, util = (field.strip() for field in line.split(","))
                    logger.info(f"[{tag}] GPU: {used}MB / {total}MB | Utilization: {util}%")
        except Exception as gpu_error:
            logger.warning(f"[{tag}] GPU monitor failed: {gpu_error}")
    except Exception as e:
        logger.error(f"System usage logging failed: {str(e)}")


def init_agent():
    """Create and load a TxAgent, seeding the tool config into the cache dir if missing.

    Returns:
        The initialised ``TxAgent`` (``init_model()`` already called).

    Raises:
        Re-raises any error from copying the tool config or building/loading the agent.
    """
    logger.info("Initializing model...")
    log_system_usage("Before Load")
    # NOTE(review): resolved against the current working directory, not this file's dir.
    default_tool_path = os.path.abspath("data/new_tool.json")
    target_tool_path = os.path.join(tool_cache_dir, "new_tool.json")

    try:
        if not os.path.exists(target_tool_path):
            shutil.copy(default_tool_path, target_tool_path)
            logger.info("Copied default tool configuration")
    except Exception as e:
        logger.error(f"Tool configuration copy failed: {str(e)}")
        raise

    try:
        agent = TxAgent(
            model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
            rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
            tool_files_dict={"new_tool": target_tool_path},
            force_finish=True,
            enable_checker=True,
            step_rag_num=8,
            seed=100,
            additional_default_tools=[],
        )
        agent.init_model()
        log_system_usage("After Load")
        logger.info("Agent initialization successful")
        return agent
    except Exception as e:
        logger.error(f"Agent initialization failed: {str(e)}")
        raise


def save_report(content: str, file_hash_value: str = "") -> str:
    """Write *content* to a timestamped report file and return its path ("" on failure).

    The filename is ``report_<YYYYmmdd_HHMMSS>_<first 8 hex chars of hash>.txt``;
    when *file_hash_value* is empty, the MD5 of *content* is used instead.
    """
    try:
        if not file_hash_value:
            file_hash_value = hashlib.md5(content.encode()).hexdigest()

        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        report_filename = f"report_{timestamp}_{file_hash_value[:8]}.txt"
        report_path = os.path.join(report_dir, report_filename)

        with open(report_path, "w", encoding="utf-8") as f:
            f.write(content)

        logger.info(f"Report saved to {report_path}")
        return report_path
    except Exception as e:
        logger.error(f"Failed to save report: {str(e)}")
        return ""


def clean_response(content: str) -> str:
    """Strip ``[TOOL_CALLS]`` artifacts and collapse 3+ newlines to 2.

    Returns a warning placeholder for empty input or when nothing is left after cleaning.
    NOTE(review): the result is ``strip()``-ed, so if the agent streams token-level string
    deltas, inter-token spaces may be lost — confirm what ``run_gradio_chat`` yields.
    """
    if not content:
        return "⚠️ No content generated."

    try:
        # Remove tool call artifacts up to the next '[' or end of text.
        cleaned = re.sub(r"\[TOOL_CALLS\].*?(?=(\[|\Z))", "", content, flags=re.DOTALL).strip()
        # Remove excessive whitespace
        cleaned = re.sub(r"\n{3,}", "\n\n", cleaned)
        return cleaned or "⚠️ Empty response after cleaning."
    except Exception as e:
        logger.error(f"Response cleaning failed: {str(e)}")
        return content


def process_model_response(chunk: Any, history: List[Dict[str, str]]) -> List[Dict[str, str]]:
    """Merge one agent output chunk into *history* (mutated in place and returned).

    - A list of message-like objects (with ``role``/``content``) is appended message by message.
    - A string is appended to the last assistant message, or starts a new one.
    - Anything else is logged and ignored.

    NOTE(review): if the agent yields its full running message list on every step,
    the list branch re-appends earlier messages each time — confirm the chunk contract.
    """
    try:
        if chunk is None:
            return history

        if isinstance(chunk, list) and all(hasattr(m, 'role') and hasattr(m, 'content') for m in chunk):
            for m in chunk:
                cleaned_content = clean_response(m.content)
                history.append({"role": m.role, "content": cleaned_content})
        elif isinstance(chunk, str):
            cleaned_chunk = clean_response(chunk)
            if history and history[-1]["role"] == "assistant":
                history[-1]["content"] += cleaned_chunk
            else:
                history.append({"role": "assistant", "content": cleaned_chunk})
        else:
            logger.warning(f"Unexpected response type: {type(chunk)}")

        return history
    except Exception as e:
        logger.error(f"Error processing model response: {str(e)}")
        history.append({"role": "assistant", "content": f"⚠️ Error processing response: {str(e)}"})
        return history


def analyze(message: str, history: list, files: list, agent: Any = None):
    """Stream an oversight analysis of the uploaded records (a Gradio generator).

    Args:
        message: The user's question (echoed into the chat history).
        history: Prior chat messages as ``{"role", "content"}`` dicts; copied, not mutated.
        files: Uploaded file objects exposing ``.name`` (a filesystem path), or empty/None.
        agent: A preloaded ``TxAgent`` to reuse. When None, one is created with
            ``init_agent()`` for this call (the original, slow behaviour).

    Yields:
        ``(chat_history, report_path_or_None)`` tuples. The final yield carries the
        saved report path when a report was written.
    """
    try:
        # Initial response
        new_history = history.copy()
        new_history.append({"role": "user", "content": message})
        new_history.append({"role": "assistant", "content": "⏳ Analyzing records for potential oversights..."})
        yield new_history, None

        # Process files
        extracted = ""
        file_hash_value = ""
        if files:
            logger.info(f"Processing {len(files)} files...")
            with ThreadPoolExecutor(max_workers=4) as executor:
                futures = []
                for uploaded in files:
                    try:
                        file_type = uploaded.name.split(".")[-1].lower()
                        futures.append(executor.submit(convert_file_to_json, uploaded.name, file_type))
                    except Exception as e:
                        logger.error(f"Error submitting file {uploaded.name} for processing: {str(e)}")
                        new_history.append({"role": "system", "content": f"⚠️ Error processing {uploaded.name}: {str(e)}"})

                # Collect in submission (upload) order rather than completion order so the
                # prompt — and what survives its 12000-char truncation — is deterministic.
                results = []
                for future in futures:
                    try:
                        results.append(sanitize_utf8(future.result()))
                    except Exception as e:
                        logger.error(f"Error getting file processing result: {str(e)}")
                        results.append(json.dumps({"error": "File processing failed"}))

            extracted = "\n".join(results)
            try:
                # Only the first file's hash is used to name the report.
                file_hash_value = file_hash(files[0].name) if files else ""
            except Exception as e:
                logger.error(f"Error generating file hash: {str(e)}")
                file_hash_value = ""

        # Prepare prompt
        prompt = f"""Review these medical records and identify EXACTLY what might have been missed:
1. List potential missed diagnoses
2. Flag any medication conflicts
3. Note incomplete assessments
4. Highlight abnormal results needing follow-up

Medical Records:
{extracted[:12000]}

### Potential Oversights:
"""
        logger.info(f"Prompt length: {len(prompt)} characters")

        # Reuse the preloaded agent when given; loading the model per request is very slow.
        if agent is None:
            agent = init_agent()
        response_content = ""
        report_path = ""

        # Process agent response
        for chunk in agent.run_gradio_chat(
            message=prompt,
            history=[],
            temperature=0.2,
            max_new_tokens=2048,
            max_token=4096,
            call_agent=False,
            conversation=[],
        ):
            try:
                new_history = process_model_response(chunk, new_history)
                if isinstance(chunk, str):
                    response_content += clean_response(chunk)
                yield new_history, None
            except Exception as chunk_error:
                logger.error(f"Error processing chunk: {str(chunk_error)}")
                new_history.append({"role": "assistant", "content": f"⚠️ Error processing response chunk: {str(chunk_error)}"})
                yield new_history, None

        # Save final report
        if response_content:
            try:
                report_path = save_report(response_content, file_hash_value)
            except Exception as report_error:
                logger.error(f"Error saving report: {str(report_error)}")
                new_history.append({"role": "system", "content": "⚠️ Failed to save full report"})

        yield new_history, report_path if report_path and os.path.exists(report_path) else None

    except Exception as e:
        logger.error(f"Analysis error: {str(e)}\n{traceback.format_exc()}")
        error_history = history.copy()
        error_history.append({"role": "assistant", "content": f"❌ Critical error occurred: {str(e)}"})
        yield error_history, None


def create_ui(agent):
    """Build the Gradio Blocks UI.

    Args:
        agent: The preloaded ``TxAgent`` reused by every analysis request
            (if None, each request loads its own agent).

    Returns:
        The ``gr.Blocks`` app (not yet queued or launched).
    """
    def run_analysis(message, history, files):
        # A real generator function (not a partial) so Gradio streams each yield,
        # while binding the preloaded agent instead of reloading the model per request.
        yield from analyze(message, history, files, agent=agent)

    with gr.Blocks(theme=gr.themes.Soft(), title="Clinical Oversight Assistant") as demo:
        gr.Markdown("\n\n🩺 Clinical Oversight Assistant\n\n")
        gr.Markdown("\nUpload medical records and ask about potential oversights or missed diagnoses.\n")

        with gr.Row():
            with gr.Column(scale=2):
                chatbot = gr.Chatbot(
                    label="Analysis Conversation",
                    height=600,
                    bubble_full_width=False,
                    show_copy_button=True
                )
                msg_input = gr.Textbox(
                    placeholder="Ask about potential oversights...",
                    show_label=False,
                    container=False
                )
                with gr.Row():
                    send_btn = gr.Button("Analyze", variant="primary")
                    clear_btn = gr.Button("Clear")

            with gr.Column(scale=1):
                file_upload = gr.File(
                    file_types=[".pdf", ".csv", ".xls", ".xlsx"],
                    file_count="multiple",
                    label="Upload Medical Records"
                )
                download_output = gr.File(
                    label="Download Full Report",
                    interactive=False
                )

        gr.Markdown("\nNote: The system analyzes PDFs, CSVs, and Excel files for potential clinical oversights.\n")

        # Event handlers. Each analysis starts from an empty history (fresh gr.State([])).
        send_btn.click(
            run_analysis,
            inputs=[msg_input, gr.State([]), file_upload],
            outputs=[chatbot, download_output]
        )

        msg_input.submit(
            run_analysis,
            inputs=[msg_input, gr.State([]), file_upload],
            outputs=[chatbot, download_output]
        )

        clear_btn.click(
            lambda: ([], None),
            inputs=[],
            outputs=[chatbot, download_output]
        )

        # Add some examples
        gr.Examples(
            examples=[
                ["What potential diagnoses might have been missed in these records?"],
                ["Are there any medication conflicts I should be aware of?"],
                ["What abnormal results need follow-up in these reports?"]
            ],
            inputs=msg_input,
            label="Example Questions"
        )

    return demo


if __name__ == "__main__":
    try:
        logger.info("🚀 Launching Clinical Oversight Assistant...")
        agent = init_agent()
        demo = create_ui(agent)

        # NOTE(review): `concurrency_count` is a Gradio 3.x queue option (removed in 4.x in
        # favour of `default_concurrency_limit`) — confirm against the pinned Gradio version.
        demo.queue(
            api_open=False,
            concurrency_count=2
        ).launch(
            server_name="0.0.0.0",
            server_port=7860,
            show_error=True,
            allowed_paths=[report_dir],
            share=False
        )
    except Exception as e:
        logger.error(f"Application failed to start: {str(e)}\n{traceback.format_exc()}")
        raise