# Hugging Face Spaces banner (page-scrape artifact, not code): "Running on A100".
# app.py - FastAPI application
import json
import os
import shutil
import sys
from datetime import datetime
from typing import Dict, List, Optional

import torch
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from pydantic import BaseModel
# --- Configuration -----------------------------------------------------------
# Persistent cache layout on the mounted /data volume.
persistent_dir = "/data/hf_cache"
model_cache_dir = os.path.join(persistent_dir, "txagent_models")
tool_cache_dir = os.path.join(persistent_dir, "tool_cache")
file_cache_dir = os.path.join(persistent_dir, "cache")
report_dir = os.path.join(persistent_dir, "reports")

# Ensure every cache/report directory exists before first use.
for _cache_dir in (model_cache_dir, tool_cache_dir, file_cache_dir, report_dir):
    os.makedirs(_cache_dir, exist_ok=True)

# Point Hugging Face libraries at the persistent model cache.
os.environ["HF_HOME"] = model_cache_dir
os.environ["TRANSFORMERS_CACHE"] = model_cache_dir

# Make the bundled src/ package importable from this script's location.
current_dir = os.path.dirname(os.path.abspath(__file__))
src_path = os.path.abspath(os.path.join(current_dir, "src"))
sys.path.insert(0, src_path)
# Request models
class ChatRequest(BaseModel):
    """Request body for the single-conversation chat endpoint."""

    message: str  # user message to send to the agent
    temperature: float = 0.7  # sampling temperature
    max_new_tokens: int = 512  # cap on generated tokens
    history: Optional[List[Dict]] = None  # prior turns; None starts a fresh chat
class MultistepRequest(BaseModel):
    """Request body for the multi-step reasoning endpoint."""

    message: str  # task/question for the agent
    temperature: float = 0.7  # sampling temperature
    max_new_tokens: int = 512  # cap on generated tokens per step
    max_round: int = 5  # maximum number of reasoning rounds
# Initialize FastAPI app
app = FastAPI(
    title="TxAgent API",
    description="API for TxAgent medical document analysis",
    version="1.0.0"
)
# CORS configuration
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is
# rejected by browsers under the CORS spec (wildcard origin may not be
# credentialed) — confirm whether credentials are actually required.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Initialize agent at startup
# Global TxAgent handle; populated by startup_event(), None until then.
agent = None
@app.on_event("startup")  # FIX: handler was never registered with FastAPI, so the agent was never initialized
async def startup_event():
    """Initialize the global TxAgent instance when the application starts.

    Raises:
        RuntimeError: if agent initialization fails, so the server refuses
            to start with a half-configured agent instead of serving 500s.
    """
    global agent
    try:
        agent = init_agent()
    except Exception as e:
        # Chain the original exception so the startup failure is debuggable.
        raise RuntimeError(f"Failed to initialize agent: {str(e)}") from e
def init_agent():
    """Initialize and return the TxAgent instance"""
    # Seed the persistent tool cache from the bundled tool file on first
    # run only; later runs reuse the cached copy.
    tool_path = os.path.join(tool_cache_dir, "new_tool.json")
    if not os.path.exists(tool_path):
        # Relies on the process CWD containing data/new_tool.json.
        shutil.copy(os.path.abspath("data/new_tool.json"), tool_path)
    # NOTE(review): TxAgent is not imported anywhere in this file as shown —
    # presumably it comes from the src/ package added to sys.path above.
    # Confirm the import exists, otherwise this raises NameError at startup.
    agent = TxAgent(
        model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
        rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
        tool_files_dict={"new_tool": tool_path},
        enable_finish=True,
        enable_rag=False,  # RAG disabled; step_rag_num below is inert until enabled
        force_finish=True,
        enable_checker=True,
        step_rag_num=4,
        seed=100  # fixed seed for reproducible generations
    )
    agent.init_model()
    return agent
@app.post("/chat")  # FIX: no route decorator — endpoint was unreachable. NOTE(review): path inferred; confirm intended route.
async def chat_endpoint(request: ChatRequest):
    """Handle a single chat turn against the TxAgent model.

    Args:
        request: message text, optional history, and sampling parameters.

    Returns:
        JSONResponse with the model response and an ISO-format timestamp.

    Raises:
        HTTPException: 503 if the agent has not been initialized,
            500 for any error raised by the agent itself.
    """
    if agent is None:
        # Startup may have failed or not run yet; don't dereference None.
        raise HTTPException(status_code=503, detail="Agent not initialized")
    try:
        response = agent.chat(
            message=request.message,
            history=request.history,
            temperature=request.temperature,
            max_new_tokens=request.max_new_tokens
        )
        return JSONResponse({
            "status": "success",
            "response": response,
            "timestamp": datetime.now().isoformat()
        })
    except Exception as e:
        # Chain the cause so the traceback survives into server logs.
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.post("/multistep")  # FIX: no route decorator — endpoint was unreachable. NOTE(review): path inferred; confirm intended route.
async def multistep_endpoint(request: MultistepRequest):
    """Run multi-step agent reasoning on a single message.

    Args:
        request: message text, sampling parameters, and max reasoning rounds.

    Returns:
        JSONResponse with the final agent response and an ISO-format timestamp.

    Raises:
        HTTPException: 503 if the agent has not been initialized,
            500 for any error raised by the agent itself.
    """
    if agent is None:
        # Startup may have failed or not run yet; don't dereference None.
        raise HTTPException(status_code=503, detail="Agent not initialized")
    try:
        response = agent.run_multistep_agent(
            message=request.message,
            temperature=request.temperature,
            max_new_tokens=request.max_new_tokens,
            max_round=request.max_round
        )
        return JSONResponse({
            "status": "success",
            "response": response,
            "timestamp": datetime.now().isoformat()
        })
    except Exception as e:
        # Chain the cause so the traceback survives into server logs.
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.post("/analyze")  # FIX: no route decorator — endpoint was unreachable. NOTE(review): path inferred; confirm intended route.
async def analyze_document(file: UploadFile = File(...)):
    """Analyze an uploaded medical document and persist a JSON report.

    Args:
        file: the uploaded document.

    Returns:
        JSONResponse with the analysis, the saved report path, and a timestamp.

    Raises:
        HTTPException: 503 if the agent has not been initialized,
            500 for any processing error.
    """
    if agent is None:
        raise HTTPException(status_code=503, detail="Agent not initialized")
    # FIX: basename() strips any directory components from the client-supplied
    # filename, preventing path traversal outside the cache/report dirs.
    safe_name = os.path.basename(file.filename or "upload")
    temp_path = os.path.join(file_cache_dir, safe_name)
    try:
        # Save the uploaded file temporarily
        with open(temp_path, "wb") as f:
            f.write(await file.read())
        # Process the document
        text = agent.extract_text_from_file(temp_path)
        analysis = agent.analyze_text(text)
        # Generate report
        report_path = os.path.join(report_dir, f"{safe_name}.json")
        with open(report_path, "w") as f:
            json.dump({
                "filename": file.filename,
                "analysis": analysis,
                "timestamp": datetime.now().isoformat()
            }, f)
        return JSONResponse({
            "status": "success",
            "analysis": analysis,
            "report_path": report_path,
            "timestamp": datetime.now().isoformat()
        })
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        # FIX: cleanup moved to finally — the temp upload previously leaked
        # whenever extraction or analysis raised.
        if os.path.exists(temp_path):
            os.remove(temp_path)
@app.get("/status")  # FIX: no route decorator — endpoint was unreachable. NOTE(review): path inferred; confirm intended route.
async def service_status():
    """Report service health: version, loaded model name, and device.

    Returns:
        dict with status fields; model/device fall back to placeholder
        strings when the agent has not been initialized.
    """
    return {
        "status": "running",
        "version": "1.0.0",
        "model": agent.model_name if agent else "not loaded",
        "device": str(agent.device) if agent else "unknown"
    }
if __name__ == "__main__":
    # Launch a development server when this module is executed directly.
    import uvicorn

    _host, _port = "0.0.0.0", 8000
    uvicorn.run(app, host=_host, port=_port)