# Spaces: Restarting on A100
import os | |
import sys | |
import json | |
import shutil | |
from fastapi import FastAPI, UploadFile, File, HTTPException | |
from fastapi.responses import JSONResponse, FileResponse | |
from fastapi.middleware.cors import CORSMiddleware | |
from typing import List, Dict, Optional | |
import torch | |
from datetime import datetime | |
# Configuration: every cache/report directory lives under one persistent root.
persistent_dir = "/data/hf_cache"
model_cache_dir, tool_cache_dir, file_cache_dir, report_dir = (
    os.path.join(persistent_dir, subdir)
    for subdir in ("txagent_models", "tool_cache", "cache", "reports")
)

# Create each directory up front; already-existing ones are left alone.
for _required_dir in (model_cache_dir, tool_cache_dir, file_cache_dir, report_dir):
    os.makedirs(_required_dir, exist_ok=True)

# Point both HF_HOME and TRANSFORMERS_CACHE at the model cache directory.
os.environ.update(
    {
        "HF_HOME": model_cache_dir,
        "TRANSFORMERS_CACHE": model_cache_dir,
    }
)

# Put the "src" directory that sits next to this file first on the import path
# (presumably where the txagent package lives — confirm against the repo layout).
current_dir = os.path.dirname(os.path.abspath(__file__))
src_path = os.path.abspath(os.path.join(current_dir, "src"))
sys.path.insert(0, src_path)
# The FastAPI application object, constructed with its descriptive metadata.
app = FastAPI(
    title="Clinical Patient Support System API",
    description="API for analyzing medical documents",
    version="1.0.0",
)

# CORS policy: every origin, method and header is allowed, with credentials on.
# NOTE(review): a wildcard origin combined with allow_credentials=True is a very
# permissive policy — confirm this is intended before exposing the API publicly.
_cors_settings = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_settings)
# Shared agent instance; remains None until startup_event() has completed.
agent = None


# NOTE(review): nothing in the visible code registers this coroutine as a
# startup handler (e.g. through the app's "startup" event) — confirm it is
# actually wired up, otherwise `agent` stays None.
async def startup_event():
    """Populate the module-level ``agent`` by calling ``init_agent()``.

    Raises:
        RuntimeError: if ``init_agent()`` raises. The original exception is
            attached as ``__cause__`` so the real failure stays visible in
            the traceback.
    """
    global agent
    try:
        agent = init_agent()
    except Exception as e:
        # Chain explicitly so the underlying error is reported as the cause.
        raise RuntimeError(f"Failed to initialize agent: {e}") from e
def init_agent():
    """Build a TxAgent, load its model, and hand the instance back.

    If the tool cache has no ``new_tool.json`` yet, the copy under
    ``data/new_tool.json`` (resolved against the current working directory)
    is copied into the cache first.
    """
    cached_tool_file = os.path.join(tool_cache_dir, "new_tool.json")
    if not os.path.exists(cached_tool_file):
        bundled_tool_file = os.path.abspath("data/new_tool.json")
        shutil.copy(bundled_tool_file, cached_tool_file)

    # Imported here rather than at module top — presumably so the import path
    # and cache environment configured at module level are already in place.
    from txagent.txagent import TxAgent

    agent_options = {
        "model_name": "mims-harvard/TxAgent-T1-Llama-3.1-8B",
        "rag_model_name": "mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
        "tool_files_dict": {"new_tool": cached_tool_file},
        "force_finish": True,
        "enable_checker": True,
        "step_rag_num": 4,
        "seed": 100,
        "use_vllm": False,  # Disable vLLM for Hugging Face Spaces
    }
    tx_agent = TxAgent(**agent_options)
    tx_agent.init_model()
    return tx_agent
# NOTE(review): no route decorator is visible on this handler in the source
# under review — confirm it is registered on `app`.
async def analyze_document(file: UploadFile = File(...)):
    """Analyze an uploaded medical document and return the agent's result.

    The upload is written to ``file_cache_dir``, passed to
    ``agent.process_document`` by path, and removed again afterwards whether
    or not processing succeeded.

    Args:
        file: The uploaded document.

    Returns:
        A JSONResponse with ``status``, ``result`` and an ISO ``timestamp``.

    Raises:
        HTTPException: 503 if the agent has not been initialized, 400 if the
            upload has no usable filename, 500 if saving or processing fails.
    """
    if agent is None:
        # Fail before touching the disk instead of surfacing an AttributeError.
        raise HTTPException(status_code=503, detail="Agent is not initialized")

    # The filename is client-controlled: keep only its final path component so
    # values such as "../../etc/passwd" cannot escape the cache directory.
    safe_name = os.path.basename(file.filename or "")
    if safe_name in ("", ".", ".."):
        raise HTTPException(status_code=400, detail="Invalid or missing filename")

    temp_path = os.path.join(file_cache_dir, safe_name)
    try:
        # Save the uploaded file temporarily
        with open(temp_path, "wb") as f:
            f.write(await file.read())
        # Process the file and generate response
        result = agent.process_document(temp_path)
        return JSONResponse({
            "status": "success",
            "result": result,
            "timestamp": datetime.now().isoformat()
        })
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        # Clean up on success and on failure; the file may never have been
        # created if the write itself failed.
        try:
            os.remove(temp_path)
        except FileNotFoundError:
            pass
# NOTE(review): no route decorator is visible on this handler in the source
# under review — confirm it is registered on `app`.
async def service_status():
    """Report that the service is running, plus the agent's model and device.

    When no agent is available, the model is reported as "not loaded" and the
    device as "unknown".
    """
    model_label = "not loaded"
    device_label = "unknown"
    if agent:
        model_label = agent.model_name
        device_label = str(agent.device)

    return {
        "status": "running",
        "version": "1.0.0",
        "model": model_label,
        "device": device_label,
    }
if __name__ == "__main__":
    # Direct execution: serve the app with uvicorn on all interfaces, port 7860.
    import uvicorn

    server_options = {"host": "0.0.0.0", "port": 7860}
    uvicorn.run(app, **server_options)