import os
import sys
import json
import shutil
from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.responses import JSONResponse, FileResponse
from fastapi.middleware.cors import CORSMiddleware
from typing import List, Dict, Optional
import torch
from datetime import datetime

# Configuration
persistent_dir = "/data/hf_cache"
model_cache_dir = os.path.join(persistent_dir, "txagent_models")
tool_cache_dir = os.path.join(persistent_dir, "tool_cache")
file_cache_dir = os.path.join(persistent_dir, "cache")
report_dir = os.path.join(persistent_dir, "reports")

# Create directories if they don't exist
os.makedirs(model_cache_dir, exist_ok=True)
os.makedirs(tool_cache_dir, exist_ok=True)
os.makedirs(file_cache_dir, exist_ok=True)
os.makedirs(report_dir, exist_ok=True)

# Point the Hugging Face caches at the persistent volume
os.environ["HF_HOME"] = model_cache_dir
os.environ["TRANSFORMERS_CACHE"] = model_cache_dir

# Set up the Python path so the bundled src package is importable
current_dir = os.path.dirname(os.path.abspath(__file__))
src_path = os.path.abspath(os.path.join(current_dir, "src"))
sys.path.insert(0, src_path)

# Initialize FastAPI app
app = FastAPI(
    title="Clinical Patient Support System API",
    description="API for analyzing medical documents",
    version="1.0.0"
)

# CORS configuration
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Agent instance, initialized at startup
agent = None


@app.on_event("startup")
async def startup_event():
    global agent
    try:
        agent = init_agent()
    except Exception as e:
        raise RuntimeError(f"Failed to initialize agent: {str(e)}") from e


def init_agent():
    """Initialize and return the TxAgent instance."""
    tool_path = os.path.join(tool_cache_dir, "new_tool.json")
    if not os.path.exists(tool_path):
        shutil.copy(os.path.abspath("data/new_tool.json"), tool_path)

    from txagent.txagent import TxAgent
    agent = TxAgent(
        model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
        rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
        tool_files_dict={"new_tool": tool_path},
        force_finish=True,
        enable_checker=True,
        step_rag_num=4,
        seed=100,
        use_vllm=False  # Disable vLLM for Hugging Face Spaces
    )
    agent.init_model()
    return agent


@app.post("/analyze")
async def analyze_document(file: UploadFile = File(...)):
    """Analyze a medical document and return the results."""
    # Use only the base name of the upload to avoid writing outside the cache dir
    temp_path = os.path.join(file_cache_dir, os.path.basename(file.filename))
    try:
        # Save the uploaded file temporarily
        with open(temp_path, "wb") as f:
            f.write(await file.read())

        # Process the file and generate the response
        result = agent.process_document(temp_path)

        return JSONResponse({
            "status": "success",
            "result": result,
            "timestamp": datetime.now().isoformat()
        })
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
    finally:
        # Clean up the temporary file even if processing fails
        if os.path.exists(temp_path):
            os.remove(temp_path)


@app.get("/status")
async def service_status():
    """Check service status."""
    return {
        "status": "running",
        "version": "1.0.0",
        "model": agent.model_name if agent else "not loaded",
        "device": str(agent.device) if agent else "unknown"
    }


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)
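
# Example calls (a sketch, assuming the service is running locally on its
# default port 7860; "report.pdf" is a hypothetical medical document on the
# caller's machine):
#
#   curl -X POST -F "file=@report.pdf" http://localhost:7860/analyze
#   curl http://localhost:7860/status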