"""FastAPI service that analyzes uploaded clinical documents with TxAgent.

Endpoints:
    POST /analyze            -- upload a .pdf/.xlsx/.csv file and get a summary.
    GET  /reports/{filename} -- download a previously generated report.
    GET  /status             -- basic service information.
"""

import contextlib
import gc
import json
import os
import re
import shutil
import sys
import tempfile
import time
from datetime import datetime
from typing import List, Tuple, Dict, Union, Optional

# === Configuration ===
# NOTE: this section runs *before* the third-party imports on purpose.
# MPLCONFIGDIR only has an effect if it is set before matplotlib is imported,
# and the Hugging Face cache variables must be set before the model libraries
# are loaded.
persistent_dir = "/data/hf_cache"
model_cache_dir = os.path.join(persistent_dir, "txagent_models")
tool_cache_dir = os.path.join(persistent_dir, "tool_cache")
file_cache_dir = os.path.join(persistent_dir, "cache")
report_dir = os.path.join(persistent_dir, "reports")

# Create directories if they don't exist
for d in [model_cache_dir, tool_cache_dir, file_cache_dir, report_dir]:
    os.makedirs(d, exist_ok=True)

# Set environment variables
os.environ["HF_HOME"] = model_cache_dir
os.environ["TRANSFORMERS_CACHE"] = model_cache_dir
os.environ["MPLCONFIGDIR"] = "/tmp/matplotlib"  # Fix for matplotlib permission issues

from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.responses import FileResponse, JSONResponse
from fastapi.middleware.cors import CORSMiddleware
import pandas as pd
import pdfplumber
import torch
import matplotlib.pyplot as plt
from fpdf import FPDF
import unicodedata
import uvicorn

# Set up Python path so the bundled ``src`` directory is importable
current_dir = os.path.dirname(os.path.abspath(__file__))
src_path = os.path.abspath(os.path.join(current_dir, "src"))
sys.path.insert(0, src_path)

# Import TxAgent after setting up paths
from txagent.txagent import TxAgent

# Constants
MAX_MODEL_TOKENS = 131072
MAX_NEW_TOKENS = 4096
MAX_CHUNK_TOKENS = 8192
BATCH_SIZE = 1
PROMPT_OVERHEAD = 300
SAFE_SLEEP = 0.5
# File extensions understood by extract_text(); also reported by /status.
SUPPORTED_FILE_TYPES = (".pdf", ".xlsx", ".csv")

# Initialize FastAPI app
app = FastAPI(
    title="Clinical Patient Support System API",
    description="API for analyzing and summarizing unstructured medical files",
    version="1.0.0"
)

# CORS configuration
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive -- confirm this is intended outside of development.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Initialize agent at startup
agent = None


@app.on_event("startup")
async def startup_event():
    """Create the global TxAgent instance when the application starts.

    Raises:
        RuntimeError: if the agent cannot be initialized; the original
            exception is attached as the cause.
    """
    global agent
    try:
        agent = init_agent()
    except Exception as e:
        raise RuntimeError(f"Failed to initialize agent: {str(e)}") from e


def init_agent() -> TxAgent:
    """Initialize and return the TxAgent instance.

    Copies ``data/new_tool.json`` (resolved against the current working
    directory) into the tool cache on first use, then builds the agent and
    loads its model.
    """
    tool_path = os.path.join(tool_cache_dir, "new_tool.json")
    if not os.path.exists(tool_path):
        shutil.copy(os.path.abspath("data/new_tool.json"), tool_path)

    agent = TxAgent(
        model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
        rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
        tool_files_dict={"new_tool": tool_path},
        force_finish=True,
        enable_checker=True,
        step_rag_num=4,
        seed=100
    )
    agent.init_model()
    return agent


# Utility functions
def estimate_tokens(text: str) -> int:
    """Estimate the number of tokens in the given text (about 4 chars each)."""
    return len(text) // 4 + 1


def clean_response(text: str) -> str:
    """Clean and format the response text.

    Removes bracketed spans (``[...]``, possibly multi-line) and the
    standalone word ``None``, collapses runs of three or more newlines into
    a blank line, and strips surrounding whitespace. Falsy input yields "".
    """
    if not text:
        return ""
    text = re.sub(r"\[.*?\]|\bNone\b", "", text, flags=re.DOTALL)
    text = re.sub(r"\n{3,}", "\n\n", text)
    return text.strip()


def _stringify_frame(df: pd.DataFrame) -> pd.DataFrame:
    """Return *df* with every cell as ``str`` and missing values as ""."""
    # astype(str) turns NaN into the literal text "nan", so a later fillna()
    # would be a no-op; blank the missing cells using the original frame's
    # NaN mask instead.
    return df.astype(str).mask(df.isna(), "")


def _non_empty_cells(row) -> List[str]:
    """Return the stripped, non-blank cell strings of one row."""
    return [cell.strip() for cell in row if cell.strip()]


def extract_text_from_excel(path: str) -> str:
    """Extract text from Excel file.

    Each row with at least two non-empty cells becomes one line of the form
    ``[sheet] a | b | ...``; joined rows of 15 characters or fewer are
    dropped. Sheets that fail to parse are skipped (best effort).

    Raises:
        RuntimeError: if the workbook cannot be opened or processed.
    """
    try:
        all_text = []
        # Context manager ensures the workbook's file handle is released.
        with pd.ExcelFile(path) as xls:
            for sheet_name in xls.sheet_names:
                try:
                    df = _stringify_frame(xls.parse(sheet_name))
                except Exception:
                    # Deliberate best effort: one bad sheet must not abort
                    # extraction of the remaining sheets.
                    continue
                for row in df.itertuples(index=False, name=None):
                    non_empty = _non_empty_cells(row)
                    if len(non_empty) >= 2:
                        text_line = " | ".join(non_empty)
                        if len(text_line) > 15:
                            all_text.append(f"[{sheet_name}] {text_line}")
        return "\n".join(all_text)
    except Exception as e:
        raise RuntimeError(f"Failed to extract text from Excel: {str(e)}") from e


def extract_text(file_path: str) -> str:
    """Extract text from supported file types.

    The type is chosen from the file extension (case-insensitive):
    ``.xlsx``, ``.csv`` or ``.pdf``. Any other extension yields "".

    Raises:
        RuntimeError: if reading or parsing the file fails.
    """
    lowered = file_path.lower()
    try:
        if lowered.endswith(".xlsx"):
            return extract_text_from_excel(file_path)
        if lowered.endswith(".csv"):
            df = _stringify_frame(pd.read_csv(file_path))
            lines = []
            for row in df.itertuples(index=False, name=None):
                cells = _non_empty_cells(row)
                # Rows with fewer than two filled cells carry no usable data.
                if len(cells) >= 2:
                    lines.append(" | ".join(cells))
            return "\n".join(lines)
        if lowered.endswith(".pdf"):
            with pdfplumber.open(file_path) as pdf:
                return "\n".join(page.extract_text() or "" for page in pdf.pages)
        return ""
    except Exception as e:
        raise RuntimeError(f"Failed to extract text from file: {str(e)}") from e


def _upload_suffix(filename: Optional[str]) -> str:
    """Return the whitelisted, lower-cased extension of an upload name, or ""."""
    suffix = os.path.splitext(os.path.basename(filename or ""))[1].lower()
    return suffix if suffix in SUPPORTED_FILE_TYPES else ""


# API endpoints
@app.post("/analyze")
async def analyze_document(file: UploadFile = File(...)):
    """Analyze a medical document and return results.

    The upload is stored under a server-generated temporary name (the client
    supplied name is only used to pick a whitelisted extension), its text is
    extracted, analyzed in batches, summarized, and written out as Markdown
    and PDF reports.

    Raises:
        HTTPException: 400 if no text or no valid analysis could be
            produced; 500 for any other failure.
    """
    start_time = time.time()
    temp_path = None
    try:
        # Save the uploaded file temporarily. Never build the path from the
        # client-controlled filename: it may contain path separators or
        # collide with a concurrent upload.
        fd, temp_path = tempfile.mkstemp(
            suffix=_upload_suffix(file.filename), dir=file_cache_dir
        )
        with os.fdopen(fd, "wb") as f:
            f.write(await file.read())

        extracted = extract_text(temp_path)
        if not extracted:
            raise HTTPException(status_code=400, detail="Could not extract text from the file")

        # NOTE(review): split_text, batch_chunks, analyze_batches,
        # generate_final_summary and generate_pdf_report_with_charts are not
        # defined in the visible source -- confirm they are provided elsewhere.
        chunks = split_text(extracted)
        batches = batch_chunks(chunks, batch_size=BATCH_SIZE)
        batch_results = analyze_batches(agent, batches)

        # Failed batches are marked with a leading "❌" and are excluded
        # from the final summary.
        valid_results = [res for res in batch_results if not res.startswith("❌")]
        if not valid_results:
            raise HTTPException(status_code=400, detail="No valid analysis results were generated")

        final_summary = generate_final_summary(agent, "\n\n".join(valid_results))

        # Generate report files
        report_filename = f"report_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
        report_path = os.path.join(report_dir, f"{report_filename}.md")
        with open(report_path, 'w', encoding='utf-8') as f:
            f.write(f"# Final Medical Report\n\n{final_summary}")

        pdf_path = generate_pdf_report_with_charts(final_summary, report_path, detailed_batches=batch_results)

        return JSONResponse({
            "status": "success",
            "summary": final_summary,
            "report_path": f"/reports/{os.path.basename(pdf_path)}",
            "processing_time": f"{time.time() - start_time:.2f} seconds",
            "detailed_outputs": batch_results
        })
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        # Clean up temp file on success *and* on failure.
        if temp_path is not None:
            with contextlib.suppress(OSError):
                os.remove(temp_path)


@app.get("/reports/{filename}")
async def download_report(filename: str):
    """Download a generated report.

    Only bare file names that resolve to a regular file directly inside the
    report directory are served.

    Raises:
        HTTPException: 404 if the name is not a plain file name or the
            report does not exist.
    """
    # Reject anything carrying a directory component (e.g. "../x").
    if not filename or os.path.basename(filename) != filename:
        raise HTTPException(status_code=404, detail="Report not found")
    file_path = os.path.join(report_dir, filename)
    # isfile (not exists) so that "." / ".." and directories are refused.
    if not os.path.isfile(file_path):
        raise HTTPException(status_code=404, detail="Report not found")
    return FileResponse(file_path, media_type='application/pdf', filename=filename)


@app.get("/status")
async def service_status():
    """Check service status."""
    return {
        "status": "running",
        "version": "1.0.0",
        "model": "mims-harvard/TxAgent-T1-Llama-3.1-8B",
        "max_tokens": MAX_MODEL_TOKENS,
        "supported_file_types": list(SUPPORTED_FILE_TYPES)
    }


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)