File size: 3,489 Bytes
f126604
d377221
f126604
 
 
d377221
f126604
d377221
f126604
d377221
f126604
d377221
f126604
 
 
 
 
 
32e4e6a
d377221
 
 
 
f126604
32e4e6a
f126604
 
 
32e4e6a
f126604
 
 
 
32e4e6a
 
 
d377221
32e4e6a
 
f126604
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32e4e6a
 
 
f126604
d377221
32e4e6a
f126604
 
 
32e4e6a
d377221
f126604
 
 
 
 
 
 
d377221
 
f126604
 
 
 
32e4e6a
f126604
32e4e6a
f126604
 
 
 
 
 
d377221
 
32e4e6a
d377221
f126604
 
 
 
d377221
 
f126604
 
 
 
 
 
 
32e4e6a
 
f126604
 
d377221
 
32e4e6a
f126604
 
d377221
f126604
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import json
import os
import shutil
import sys
import tempfile
from datetime import datetime
from typing import List, Dict, Optional

import torch
from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, FileResponse

# Configuration: all persistent state lives under a single root directory.
persistent_dir = "/data/hf_cache"
model_cache_dir, tool_cache_dir, file_cache_dir, report_dir = (
    os.path.join(persistent_dir, subdir)
    for subdir in ("txagent_models", "tool_cache", "cache", "reports")
)

# Ensure every cache/output directory exists before anything writes to it.
for _required_dir in (model_cache_dir, tool_cache_dir, file_cache_dir, report_dir):
    os.makedirs(_required_dir, exist_ok=True)

# Point the Hugging Face cache variables at the persistent model directory.
os.environ.update(HF_HOME=model_cache_dir, TRANSFORMERS_CACHE=model_cache_dir)

# Put the bundled ``src`` directory first on the import path (presumably where
# the ``txagent`` package imported by init_agent lives — confirm).
current_dir = os.path.dirname(os.path.abspath(__file__))
src_path = os.path.abspath(os.path.join(current_dir, "src"))
sys.path[:0] = [src_path]

# FastAPI application object for this service.
app = FastAPI(
    version="1.0.0",
    title="Clinical Patient Support System API",
    description="API for analyzing medical documents",
)

# Permissive CORS: every origin, method and header is accepted.
# NOTE(review): a wildcard origin together with allow_credentials=True likely
# lets any site make credentialed cross-origin calls; tighten if auth is added.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_credentials=True,
    allow_origins=["*"],
)

# Filled in by the startup hook; remains None until initialization succeeds.
agent = None

@app.on_event("startup")
async def startup_event():
    """Build the TxAgent once when the application starts.

    The instance is stored in the module-level ``agent`` for request handlers.

    Raises:
        RuntimeError: if ``init_agent`` fails. The underlying exception is
            chained as ``__cause__`` so its traceback is preserved.
    """
    global agent
    try:
        agent = init_agent()
    except Exception as e:
        # Chain the cause so the real failure (missing file, model load error,
        # ...) remains visible in the startup traceback.
        raise RuntimeError(f"Failed to initialize agent: {str(e)}") from e

def init_agent():
    """Initialize and return the TxAgent instance."""
    tool_path = os.path.join(tool_cache_dir, "new_tool.json")
    if not os.path.exists(tool_path):
        shutil.copy(os.path.abspath("data/new_tool.json"), tool_path)
    
    from txagent.txagent import TxAgent
    agent = TxAgent(
        model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
        rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
        tool_files_dict={"new_tool": tool_path},
        force_finish=True,
        enable_checker=True,
        step_rag_num=4,
        seed=100,
        use_vllm=False  # Disable vLLM for Hugging Face Spaces
    )
    agent.init_model()
    return agent

@app.post("/analyze")
async def analyze_document(file: UploadFile = File(...)):
    """Analyze a medical document and return results.

    The upload is saved into a fresh scratch directory under
    ``file_cache_dir`` and passed to ``agent.process_document``. The scratch
    directory is removed afterwards, whether or not processing succeeds.

    Returns:
        JSONResponse with ``status``, ``result`` and an ISO-8601 ``timestamp``.

    Raises:
        HTTPException: 503 if the agent has not been initialized; 500 for any
            error while saving or processing the document.
    """
    if agent is None:
        raise HTTPException(status_code=503, detail="Agent not initialized")

    # Keep only the last path component of the client-supplied name so values
    # like "../../etc/passwd" cannot escape the cache directory.
    safe_name = os.path.basename(file.filename or "")
    if safe_name in ("", ".", ".."):
        safe_name = "upload"

    try:
        # One directory per request: concurrent uploads with the same filename
        # cannot overwrite each other, and the original name (and extension)
        # is preserved for the processor.
        work_dir = tempfile.mkdtemp(dir=file_cache_dir)
        try:
            temp_path = os.path.join(work_dir, safe_name)
            with open(temp_path, "wb") as f:
                f.write(await file.read())

            # NOTE(review): process_document is called synchronously inside
            # this coroutine, so it blocks the event loop while it runs.
            result = agent.process_document(temp_path)

            return JSONResponse({
                "status": "success",
                "result": result,
                "timestamp": datetime.now().isoformat()
            })
        finally:
            # Always clean up, including when processing raises.
            shutil.rmtree(work_dir, ignore_errors=True)

    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

@app.get("/status")
async def service_status():
    """Report liveness, API version, and the agent's model name and device.

    Until the agent is loaded, ``model`` is "not loaded" and ``device`` is
    "unknown".
    """
    if not agent:
        model_label, device_label = "not loaded", "unknown"
    else:
        model_label, device_label = agent.model_name, str(agent.device)
    return {
        "status": "running",
        "version": "1.0.0",
        "model": model_label,
        "device": device_label,
    }

if __name__ == "__main__":
    # Serve the API directly when executed as a script; uvicorn is imported
    # here so it is only required in that mode.
    import uvicorn

    uvicorn.run(app, port=7860, host="0.0.0.0")