TxAgent-Api / app.py
import os
import sys
import shutil
from datetime import datetime

from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
# Configuration
persistent_dir = "/data/hf_cache"
model_cache_dir = os.path.join(persistent_dir, "txagent_models")
tool_cache_dir = os.path.join(persistent_dir, "tool_cache")
file_cache_dir = os.path.join(persistent_dir, "cache")
report_dir = os.path.join(persistent_dir, "reports")
# Create directories if they don't exist
os.makedirs(model_cache_dir, exist_ok=True)
os.makedirs(tool_cache_dir, exist_ok=True)
os.makedirs(file_cache_dir, exist_ok=True)
os.makedirs(report_dir, exist_ok=True)
# Set environment variables
os.environ["HF_HOME"] = model_cache_dir
os.environ["TRANSFORMERS_CACHE"] = model_cache_dir
# Set up Python path
current_dir = os.path.dirname(os.path.abspath(__file__))
src_path = os.path.abspath(os.path.join(current_dir, "src"))
sys.path.insert(0, src_path)
# Initialize FastAPI app
app = FastAPI(
    title="Clinical Patient Support System API",
    description="API for analyzing medical documents",
    version="1.0.0",
)
# CORS configuration
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Initialize agent at startup
agent = None
@app.on_event("startup")
async def startup_event():
    """Load the TxAgent model once when the service starts."""
    global agent
    try:
        agent = init_agent()
    except Exception as e:
        raise RuntimeError(f"Failed to initialize agent: {e}") from e

def init_agent():
    """Initialize and return the TxAgent instance."""
    # Copy the bundled tool definitions into the persistent cache on first run
    tool_path = os.path.join(tool_cache_dir, "new_tool.json")
    if not os.path.exists(tool_path):
        shutil.copy(os.path.abspath("data/new_tool.json"), tool_path)

    from txagent.txagent import TxAgent
    agent = TxAgent(
        model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
        rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
        tool_files_dict={"new_tool": tool_path},
        force_finish=True,
        enable_checker=True,
        step_rag_num=4,
        seed=100,
        use_vllm=False,  # Disable vLLM for Hugging Face Spaces
    )
    agent.init_model()
    return agent

@app.post("/analyze")
async def analyze_document(file: UploadFile = File(...)):
    """Analyze a medical document and return the results."""
    if agent is None:
        raise HTTPException(status_code=503, detail="Agent not initialized")

    # Save the uploaded file temporarily; basename guards against path
    # traversal and a missing filename
    temp_path = os.path.join(file_cache_dir, os.path.basename(file.filename or "upload"))
    try:
        with open(temp_path, "wb") as f:
            f.write(await file.read())
        # Process the file and generate the response
        result = agent.process_document(temp_path)
        return JSONResponse({
            "status": "success",
            "result": result,
            "timestamp": datetime.now().isoformat()
        })
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
    finally:
        # Clean up the temporary file whether or not processing succeeded
        if os.path.exists(temp_path):
            os.remove(temp_path)

@app.get("/status")
async def service_status():
    """Check service status."""
    return {
        "status": "running",
        "version": "1.0.0",
        "model": agent.model_name if agent else "not loaded",
        "device": str(agent.device) if agent else "unknown",
    }

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)
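
# Example client call (a hypothetical usage sketch; the exact shape of
# "result" depends on what TxAgent.process_document returns):
#
#   curl -X POST http://localhost:7860/analyze -F "file=@/path/to/report.pdf"
#
# Expected response on success:
#   {"status": "success", "result": ..., "timestamp": "2024-01-01T00:00:00"}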