Spaces:
Runtime error
Runtime error
File size: 7,721 Bytes
f126604 d377221 f126604 7e095f4 7757822 f126604 d377221 f126604 d377221 7757822 f126604 7e095f4 520104a 7e095f4 520104a d377221 f126604 7e095f4 f126604 32e4e6a 7e095f4 f126604 32e4e6a f126604 7757822 32e4e6a 7757822 32e4e6a f126604 7e095f4 32e4e6a 520104a 32e4e6a 7e095f4 32e4e6a f126604 d377221 7757822 7e095f4 f126604 7757822 7e095f4 7757822 7e095f4 7757822 7e095f4 7757822 7e095f4 7757822 7e095f4 7757822 7e095f4 7757822 32e4e6a f126604 7757822 f126604 7e095f4 f126604 7e095f4 f126604 7757822 7e095f4 7757822 7e095f4 7757822 7e095f4 32e4e6a d377221 f126604 7e095f4 f126604 7757822 d377221 f126604 7e095f4 f126604 7757822 7e095f4 f126604 d377221 7e095f4 32e4e6a 7e095f4 f126604 7e095f4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 |
import json
import logging
import os
import shutil
import sys
import tempfile
from datetime import datetime
from typing import Dict, List, Optional

import torch
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from pydantic import BaseModel
# Configure logging: records at INFO and above from any logger are written both
# to the console and to txagent_api.log in the current working directory.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
    handlers=[logging.StreamHandler(), logging.FileHandler('txagent_api.log')],
)

# Named logger used by every part of this service.
logger = logging.getLogger("TxAgentAPI")
# Put the bundled ``src`` directory at the front of sys.path so the local
# ``txagent`` package is found before any installed copy.
current_dir = os.path.dirname(os.path.abspath(__file__))
src_path = os.path.abspath(os.path.join(current_dir, "src"))
sys.path.insert(0, src_path)

# TxAgent can only be imported once the path above is in place; a failure is
# logged and then propagated unchanged so the process does not start.
try:
    from txagent.txagent import TxAgent
except ImportError as import_err:
    logger.error(f"Failed to import TxAgent: {str(import_err)}")
    raise
# Configuration
# Root for all persistent state (model/tool/file caches and reports). Defaults to
# /data/hf_cache as before; set TXAGENT_PERSISTENT_DIR to use another location on
# hosts where /data does not exist or is not writable. An empty value falls back
# to the default rather than creating directories relative to the CWD.
persistent_dir = os.environ.get("TXAGENT_PERSISTENT_DIR") or "/data/hf_cache"
os.makedirs(persistent_dir, exist_ok=True)
model_cache_dir = os.path.join(persistent_dir, "txagent_models")
tool_cache_dir = os.path.join(persistent_dir, "tool_cache")
file_cache_dir = os.path.join(persistent_dir, "cache")
report_dir = os.path.join(persistent_dir, "reports")

# Create directories if they don't exist. exist_ok=True makes this idempotent,
# so the log line is emitted for directories that already existed as well.
for directory in (model_cache_dir, tool_cache_dir, file_cache_dir, report_dir):
    os.makedirs(directory, exist_ok=True)
    logger.info(f"Created directory: {directory}")

# Point the Hugging Face caches at the persistent model directory.
# NOTE(review): these assignments run after torch and TxAgent have already been
# imported. If those imports load huggingface_hub/transformers, which may read
# their cache locations at import time, these values might not take effect.
# Confirm this, or export HF_HOME before the process starts.
os.environ["HF_HOME"] = model_cache_dir
os.environ["TRANSFORMERS_CACHE"] = model_cache_dir
# Request models
# JSON body accepted by POST /chat. Every field is forwarded to agent.chat.
# (Documented with comments rather than a docstring because pydantic would turn
# a docstring into the model's OpenAPI description.)
class ChatRequest(BaseModel):
    message: str  # user message; passed as message=
    temperature: float = 0.7  # passed as temperature=
    max_new_tokens: int = 512  # passed as max_new_tokens=
    history: Optional[List[Dict]] = None  # prior turns; dict contents are not validated here — presumably role/content pairs, confirm against TxAgent.chat
# JSON body accepted by POST /multistep. Every field is forwarded to
# agent.run_multistep_agent.
class MultistepRequest(BaseModel):
    message: str  # user message; passed as message=
    temperature: float = 0.7  # passed as temperature=
    max_new_tokens: int = 512  # passed as max_new_tokens=
    max_round: int = 5  # passed as max_round=; presumably caps the number of reasoning rounds — confirm against TxAgent
# Initialize FastAPI app; title/description/version feed the generated OpenAPI docs.
app = FastAPI(
    version="1.0.0",
    title="TxAgent API",
    description="API for TxAgent medical document analysis",
)

# CORS: accept every origin, method and header, with credentials allowed.
# NOTE(review): browsers refuse a literal "*" origin on credentialed requests,
# and Starlette's handling of wildcard origins combined with allow_credentials
# varies by version — it may reflect the caller's Origin. If this API is
# publicly reachable, consider listing explicit origins. Verify against the
# installed Starlette version.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)
# Initialize agent at startup
# Module-level TxAgent handle shared by all endpoints. It is None until
# startup_event has completed successfully.
agent = None


# NOTE(review): @app.on_event is deprecated in recent FastAPI releases in favour
# of a lifespan handler. Migrating requires changing the FastAPI(...) call too.
@app.on_event("startup")
async def startup_event():
    """Create the global TxAgent when the application starts.

    Raises:
        RuntimeError: if ``init_agent`` fails. The original exception is chained
            as ``__cause__`` so its traceback is preserved.
    """
    global agent
    try:
        logger.info("Initializing TxAgent...")
        agent = init_agent()
        logger.info("TxAgent initialized successfully")
    except Exception as e:
        logger.error(f"Failed to initialize agent: {str(e)}", exc_info=True)
        raise RuntimeError(f"Failed to initialize agent: {str(e)}") from e
def _ensure_tool_file(tool_path):
    """Copy the bundled ``data/new_tool.json`` to *tool_path* unless it already exists.

    The bundled file is looked up relative to the current working directory
    first (the original behaviour), then relative to this script's directory.
    The fallback lets the server start when launched from another CWD.

    Raises:
        FileNotFoundError: if no bundled copy exists in any searched location.
    """
    if os.path.exists(tool_path):
        return
    logger.info(f"Copying tool file to {tool_path}")
    # dict.fromkeys de-duplicates while keeping order, so when the CWD is the
    # script directory only one path is checked and reported.
    candidates = list(dict.fromkeys([
        os.path.abspath("data/new_tool.json"),
        os.path.join(current_dir, "data", "new_tool.json"),
    ]))
    for default_tool in candidates:
        if os.path.exists(default_tool):
            shutil.copy(default_tool, tool_path)
            return
    raise FileNotFoundError(f"Default tool file not found at {' or '.join(candidates)}")


def init_agent():
    """Initialize and return the TxAgent instance.

    Ensures the tool definition file is in the tool cache, constructs TxAgent
    with this deployment's fixed configuration, then calls ``init_model()``.

    Returns:
        The initialized TxAgent.

    Raises:
        Any exception from the tool-file setup, the TxAgent constructor or
        ``init_model()``. It is logged with its traceback and re-raised unchanged.
    """
    try:
        tool_path = os.path.join(tool_cache_dir, "new_tool.json")
        _ensure_tool_file(tool_path)

        logger.info("Creating TxAgent instance")
        # Named tx_agent so it does not shadow the module-level ``agent`` global.
        tx_agent = TxAgent(
            model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
            # NOTE(review): a RAG model is named but enable_rag=False; presumably
            # it is then not used. Confirm against TxAgent.
            rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
            tool_files_dict={"new_tool": tool_path},
            enable_finish=True,
            enable_rag=False,
            force_finish=True,
            enable_checker=True,
            step_rag_num=4,
            seed=100,  # fixed seed; presumably for reproducible sampling
        )
        tx_agent.init_model()
        return tx_agent
    except Exception as e:
        logger.error(f"Error in init_agent: {str(e)}", exc_info=True)
        raise
@app.post("/chat")
async def chat_endpoint(request: ChatRequest):
    """Handle chat conversations.

    Forwards the message, optional history and sampling settings to
    ``agent.chat``. The result is wrapped in ``{"status", "response", "timestamp"}``.

    Raises:
        HTTPException: 503 if the agent has not been initialized, or 500 if the
            agent call fails (the error text is returned as ``detail``).
    """
    # Checked before the try block so this 503 is not rewritten into a 500 below.
    if agent is None:
        raise HTTPException(status_code=503, detail="Agent not initialized")
    try:
        # NOTE(review): the first 50 characters of the user message are logged;
        # for medical input this may include sensitive data.
        logger.info(f"Chat request received: {request.message[:50]}...")
        # NOTE(review): agent.chat is called synchronously inside an async
        # handler, so the event loop is blocked until it returns. Moving it to a
        # thread would also need a lock, unless the agent is known to be thread-safe.
        response = agent.chat(
            message=request.message,
            history=request.history,
            temperature=request.temperature,
            max_new_tokens=request.max_new_tokens,
        )
        logger.info("Chat response generated successfully")
        return JSONResponse({
            "status": "success",
            "response": response,
            "timestamp": datetime.now().isoformat(),
        })
    except Exception as e:
        logger.error(f"Chat error: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.post("/multistep")
async def multistep_endpoint(request: MultistepRequest):
    """Run multi-step reasoning.

    Forwards the message, sampling settings and ``max_round`` to
    ``agent.run_multistep_agent``. The result is wrapped in
    ``{"status", "response", "timestamp"}``.

    Raises:
        HTTPException: 503 if the agent has not been initialized, or 500 if the
            agent call fails (the error text is returned as ``detail``).
    """
    # Checked before the try block so this 503 is not rewritten into a 500 below.
    if agent is None:
        raise HTTPException(status_code=503, detail="Agent not initialized")
    try:
        # NOTE(review): logs the first 50 characters of possibly sensitive input.
        logger.info(f"Multistep request received: {request.message[:50]}...")
        # NOTE(review): synchronous call inside an async handler; it blocks the
        # event loop until it returns.
        response = agent.run_multistep_agent(
            message=request.message,
            temperature=request.temperature,
            max_new_tokens=request.max_new_tokens,
            max_round=request.max_round,
        )
        logger.info("Multistep reasoning completed successfully")
        return JSONResponse({
            "status": "success",
            "response": response,
            "timestamp": datetime.now().isoformat(),
        })
    except Exception as e:
        logger.error(f"Multistep error: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e)) from e
def _safe_upload_name(filename: Optional[str]) -> str:
    """Return only the final file-name component of a client-supplied upload name.

    ``os.path.basename`` drops any directory part, so names such as ``../../x``
    or ``/etc/x`` cannot escape the directory they are later joined to. A
    missing name, or one that would still resolve to a directory (``""``,
    ``"."``, ``".."``), is replaced by ``"upload"``.
    """
    name = os.path.basename(filename or "")
    if name in ("", ".", ".."):
        return "upload"
    return name


@app.post("/analyze")
async def analyze_document(file: UploadFile = File(...)):
    """Analyze a medical document.

    Steps:
      1. Save the upload into a fresh private directory under ``file_cache_dir``.
      2. Run ``agent.extract_text_from_file`` and ``agent.analyze_text`` on it.
      3. Write a JSON report to ``report_dir``.
      4. Return the analysis and the report path.

    The temporary directory is removed whether or not the analysis succeeds.

    Raises:
        HTTPException: 503 if the agent has not been initialized, or 500 on any
            failure while saving, analyzing or writing the report.
    """
    # Checked before the try block so this 503 is not rewritten into a 500 below.
    if agent is None:
        raise HTTPException(status_code=503, detail="Agent not initialized")
    work_dir = None
    try:
        logger.info(f"Document analysis request received for: {file.filename}")
        # Never use the client-supplied name as a path directly: it may contain
        # directory components (path traversal), and concurrent uploads with the
        # same name would overwrite each other. A per-request directory keeps
        # the sanitized base name (and its extension) for the text extractor.
        safe_name = _safe_upload_name(file.filename)
        work_dir = tempfile.mkdtemp(dir=file_cache_dir)
        temp_path = os.path.join(work_dir, safe_name)
        with open(temp_path, "wb") as f:
            f.write(await file.read())
        logger.info(f"File saved temporarily at {temp_path}")

        # Process the document.
        # NOTE(review): confirm that TxAgent provides extract_text_from_file and
        # analyze_text; neither is defined in this file.
        text = agent.extract_text_from_file(temp_path)
        analysis = agent.analyze_text(text)
        logger.info("Document analysis completed successfully")

        # Generate the report. Its file name is derived from the sanitized name,
        # never from the raw client-supplied one.
        stem = os.path.splitext(safe_name)[0]
        report_filename = f"{stem}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
        report_path = os.path.join(report_dir, report_filename)
        with open(report_path, "w") as f:
            json.dump({
                "filename": file.filename,
                "analysis": analysis,
                "timestamp": datetime.now().isoformat(),
            }, f, indent=2)
        logger.info(f"Report generated at {report_path}")

        return JSONResponse({
            "status": "success",
            "analysis": analysis,
            "report_path": report_path,
            "timestamp": datetime.now().isoformat(),
        })
    except Exception as e:
        logger.error(f"Document analysis error: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        # Clean up on both success and failure. Previously the file leaked
        # whenever processing raised.
        if work_dir is not None:
            shutil.rmtree(work_dir, ignore_errors=True)
            logger.info(f"Temporary directory {work_dir} removed")
@app.get("/status")
async def service_status():
    """Check service status.

    Returns a dict with these keys:
      - ``status``: always ``"running"``.
      - ``version``: the API version.
      - ``model`` / ``device``: taken from the agent, or placeholders when no
        agent is loaded.
      - ``timestamp``: the current local time.

    The same dict is logged at INFO level.
    """
    if agent:
        # getattr with a fallback keeps the health check answering even if the
        # agent object does not expose one of these attributes. Reading them
        # directly would raise AttributeError and turn the check into a 500.
        model = getattr(agent, "model_name", "unknown")
        device = str(getattr(agent, "device", "unknown"))
    else:
        model = "not loaded"
        device = "unknown"
    status = {
        "status": "running",
        "version": "1.0.0",
        "model": model,
        "device": device,
        "timestamp": datetime.now().isoformat(),
    }
    logger.info(f"Status check: {status}")
    return status
if __name__ == "__main__":
    # Run the API directly with uvicorn. Any startup failure is logged with its
    # traceback and re-raised.
    try:
        logger.info("Starting TxAgent API server")
        import uvicorn

        # The hosting platform can choose the listen port via PORT; 8000 is the
        # default, as before.
        port = int(os.environ.get("PORT", "8000"))
        # log_config=None stops uvicorn from applying its own logging config,
        # leaving the module's basicConfig setup in effect.
        uvicorn.run(app, host="0.0.0.0", port=port, log_config=None)
    except Exception as e:
        logger.error(f"Failed to start server: {str(e)}", exc_info=True)
        raise