TxAgent-Api / app.py
Ali2206's picture
Update app.py
7e095f4 verified
raw
history blame
7.72 kB
import os
import sys
import json
import shutil
import logging
from fastapi import FastAPI, HTTPException, UploadFile, File
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from typing import List, Dict, Optional
import torch
from datetime import datetime
from pydantic import BaseModel
# Configure logging
# Every record at INFO level or above is sent to two handlers: a StreamHandler
# created with no explicit stream, and a FileHandler writing 'txagent_api.log'
# (a relative path, so it resolves against the process working directory).
_log_handlers = [
    logging.StreamHandler(),
    logging.FileHandler('txagent_api.log'),
]
logging.basicConfig(
    handlers=_log_handlers,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
# Module-wide logger used by every function in this file.
logger = logging.getLogger("TxAgentAPI")
# Add src directory to Python path
# The "src" folder that sits next to this file is placed at the FRONT of
# sys.path, so it is searched before any other location when the import below
# runs (presumably this is where the txagent package lives -- confirm).
current_dir = os.path.dirname(os.path.abspath(__file__))
src_path = os.path.abspath(os.path.join(current_dir, "src"))
sys.path.insert(0, src_path)
# Import TxAgent after setting up path
# An ImportError is logged and then re-raised, so this module cannot be loaded
# when TxAgent is unavailable.
try:
    from txagent.txagent import TxAgent
except ImportError as e:
    logger.error(f"Failed to import TxAgent: {str(e)}")
    raise
# Configuration
# Everything this service stores on disk lives under one persistent root.
persistent_dir = "/data/hf_cache"
os.makedirs(persistent_dir, exist_ok=True)

# Sub-directories of the root: model cache, tool-file cache, upload scratch
# space, and generated reports (in that order).
model_cache_dir, tool_cache_dir, file_cache_dir, report_dir = (
    os.path.join(persistent_dir, leaf)
    for leaf in ("txagent_models", "tool_cache", "cache", "reports")
)

# Create directories if they don't exist
# (the log line is emitted for every directory, whether or not it was new)
for directory in (model_cache_dir, tool_cache_dir, file_cache_dir, report_dir):
    os.makedirs(directory, exist_ok=True)
    logger.info(f"Created directory: {directory}")

# Set environment variables
# NOTE(review): these are assigned after `from txagent.txagent import TxAgent`
# has already executed earlier in this file; if that import pulls in libraries
# that read these variables at import time, the values arrive too late --
# confirm the intended ordering.
os.environ.update({
    "HF_HOME": model_cache_dir,
    "TRANSFORMERS_CACHE": model_cache_dir,
})
# Request models
# JSON body accepted by POST /chat. Each field is forwarded unchanged to
# agent.chat() by chat_endpoint.
class ChatRequest(BaseModel):
    # Required: the user message. Only its first 50 characters are logged.
    message: str
    # Optional, defaults to 0.7; passed through as `temperature`.
    temperature: float = 0.7
    # Optional, defaults to 512; passed through as `max_new_tokens`.
    max_new_tokens: int = 512
    # Optional prior conversation; None when omitted. The dict layout is not
    # validated here -- presumably whatever TxAgent.chat expects; TODO confirm.
    history: Optional[List[Dict]] = None
# JSON body accepted by POST /multistep. Each field is forwarded unchanged to
# agent.run_multistep_agent() by multistep_endpoint.
class MultistepRequest(BaseModel):
    # Required: the user message. Only its first 50 characters are logged.
    message: str
    # Optional, defaults to 0.7; passed through as `temperature`.
    temperature: float = 0.7
    # Optional, defaults to 512; passed through as `max_new_tokens`.
    max_new_tokens: int = 512
    # Optional, defaults to 5; passed through as `max_round`.
    max_round: int = 5
# Initialize FastAPI app
# The single application object that every route, middleware and startup hook
# in this file is registered on.
# NOTE: the version string "1.0.0" is also hard-coded in the /status response;
# keep the two in sync when bumping it.
app = FastAPI(
    title="TxAgent API",
    description="API for TxAgent medical document analysis",
    version="1.0.0"
)
# CORS configuration
# Every origin, HTTP method and request header is allowed, and credentials are
# enabled.
# NOTE(review): FastAPI's CORS documentation advises against combining
# allow_credentials=True with wildcard origins/methods/headers. Confirm whether
# credentialed cross-origin requests are actually needed; if they are, list the
# allowed origins explicitly, otherwise consider allow_credentials=False.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Initialize agent at startup
# Shared TxAgent instance; stays None until startup_event has run successfully.
agent = None


@app.on_event("startup")
async def startup_event():
    """Create the shared TxAgent instance when the application starts.

    The object returned by init_agent() is stored in the module-level
    ``agent`` that the request handlers read.

    Raises:
        RuntimeError: if init_agent() raises. The original exception is
            attached as the explicit cause so the traceback shows it as the
            direct reason for the failure.
    """
    global agent
    try:
        logger.info("Initializing TxAgent...")
        agent = init_agent()
        logger.info("TxAgent initialized successfully")
    except Exception as exc:
        logger.error(f"Failed to initialize agent: {str(exc)}", exc_info=True)
        # Chain explicitly: without "from exc" the original error is only
        # reported as incidental context of the RuntimeError.
        raise RuntimeError(f"Failed to initialize agent: {str(exc)}") from exc
def init_agent():
    """Build a TxAgent, load its model, and return it.

    Ensures ``new_tool.json`` exists in the tool cache directory first: when it
    is missing, it is copied from ``data/new_tool.json`` (resolved against the
    current working directory).

    Returns:
        The TxAgent instance, after ``init_model()`` has been called on it.

    Raises:
        FileNotFoundError: the cached tool file is missing and the default
            one does not exist either.
        Exception: anything raised while constructing or initialising the
            agent is logged with its traceback and re-raised unchanged.
    """
    try:
        tool_path = os.path.join(tool_cache_dir, "new_tool.json")
        tool_already_cached = os.path.exists(tool_path)
        if not tool_already_cached:
            logger.info(f"Copying tool file to {tool_path}")
            default_tool = os.path.abspath("data/new_tool.json")
            # Guard clause: nothing to copy from, so give up early.
            if not os.path.exists(default_tool):
                raise FileNotFoundError(f"Default tool file not found at {default_tool}")
            shutil.copy(default_tool, tool_path)

        logger.info("Creating TxAgent instance")
        agent_options = {
            "model_name": "mims-harvard/TxAgent-T1-Llama-3.1-8B",
            "rag_model_name": "mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
            "tool_files_dict": {"new_tool": tool_path},
            "enable_finish": True,
            "enable_rag": False,
            "force_finish": True,
            "enable_checker": True,
            "step_rag_num": 4,
            "seed": 100,
        }
        instance = TxAgent(**agent_options)
        instance.init_model()
        return instance
    except Exception as exc:
        logger.error(f"Error in init_agent: {str(exc)}", exc_info=True)
        raise
@app.post("/chat")
async def chat_endpoint(request: ChatRequest):
    """Handle chat conversations"""
    # Forwards the validated request fields to agent.chat() and wraps the
    # result in a JSON envelope with status "success" and an ISO timestamp.
    # Any exception is logged with its traceback and reported as HTTP 500
    # whose detail is the exception text.
    try:
        logger.info(f"Chat request received: {request.message[:50]}...")
        chat_arguments = {
            "message": request.message,
            "history": request.history,
            "temperature": request.temperature,
            "max_new_tokens": request.max_new_tokens,
        }
        reply = agent.chat(**chat_arguments)
        logger.info("Chat response generated successfully")
        envelope = {
            "status": "success",
            "response": reply,
            "timestamp": datetime.now().isoformat(),
        }
        return JSONResponse(envelope)
    except Exception as exc:
        logger.error(f"Chat error: {str(exc)}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(exc))
@app.post("/multistep")
async def multistep_endpoint(request: MultistepRequest):
    """Run multi-step reasoning"""
    # Forwards the validated request fields to agent.run_multistep_agent() and
    # wraps the result in a JSON envelope with status "success" and an ISO
    # timestamp. Any exception is logged with its traceback and reported as
    # HTTP 500 whose detail is the exception text.
    try:
        logger.info(f"Multistep request received: {request.message[:50]}...")
        run_arguments = {
            "message": request.message,
            "temperature": request.temperature,
            "max_new_tokens": request.max_new_tokens,
            "max_round": request.max_round,
        }
        outcome = agent.run_multistep_agent(**run_arguments)
        logger.info("Multistep reasoning completed successfully")
        envelope = {
            "status": "success",
            "response": outcome,
            "timestamp": datetime.now().isoformat(),
        }
        return JSONResponse(envelope)
    except Exception as exc:
        logger.error(f"Multistep error: {str(exc)}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(exc))
def _safe_upload_name(raw_name):
    """Reduce a client-supplied filename to a bare file name.

    Directory components are dropped (forward-slash and backslash separators
    alike), so the result can be joined onto a server directory without
    escaping it. Returns "" when nothing usable remains (None, empty string,
    "." or "..").
    """
    leaf = os.path.basename((raw_name or "").replace("\\", "/"))
    return "" if leaf in ("", ".", "..") else leaf


@app.post("/analyze")
async def analyze_document(file: UploadFile = File(...)):
    """Analyze a medical document"""
    # Flow: store the upload in the file cache, extract its text and analyse it
    # with the agent, write a timestamped JSON report into the report
    # directory, then return the analysis together with the report path.
    #
    # Errors: HTTP 400 when the upload carries no usable filename; HTTP 500
    # (detail = exception text) for any failure during processing.

    # The filename comes from the client and must never be trusted as a path:
    # something like "../../x" or "/etc/x" would otherwise land outside the
    # cache and report directories.
    safe_name = _safe_upload_name(file.filename)
    if not safe_name:
        raise HTTPException(status_code=400, detail="Uploaded file has no usable filename")

    temp_path = os.path.join(file_cache_dir, safe_name)
    try:
        logger.info(f"Document analysis request received for: {file.filename}")
        # Read the whole upload before opening the target, so the file is
        # opened and written without an await in between.
        content = await file.read()
        with open(temp_path, "wb") as f:
            f.write(content)
        logger.info(f"File saved temporarily at {temp_path}")
        # Process the document
        text = agent.extract_text_from_file(temp_path)
        analysis = agent.analyze_text(text)
        logger.info("Document analysis completed successfully")
        # Generate report
        report_filename = f"{os.path.splitext(safe_name)[0]}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
        report_path = os.path.join(report_dir, report_filename)
        with open(report_path, "w") as f:
            json.dump({
                "filename": file.filename,
                "analysis": analysis,
                "timestamp": datetime.now().isoformat()
            }, f, indent=2)
        logger.info(f"Report generated at {report_path}")
        return JSONResponse({
            "status": "success",
            "analysis": analysis,
            "report_path": report_path,
            "timestamp": datetime.now().isoformat()
        })
    except Exception as e:
        logger.error(f"Document analysis error: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
    finally:
        # Clean up on success AND failure, so a failed analysis does not leave
        # the uploaded document behind in the cache.
        try:
            os.remove(temp_path)
        except FileNotFoundError:
            pass  # the upload was never written
        except OSError as cleanup_error:
            logger.warning(f"Could not remove temporary file {temp_path}: {cleanup_error}")
        else:
            logger.info(f"Temporary file {temp_path} removed")
@app.get("/status")
async def service_status():
    """Check service status"""
    # With an agent present, report its model name and device; otherwise fall
    # back to placeholder strings. The resulting dict is logged and returned.
    if agent:
        model_label = agent.model_name
        device_label = str(agent.device)
    else:
        model_label = "not loaded"
        device_label = "unknown"

    status = {
        "status": "running",
        "version": "1.0.0",
        "model": model_label,
        "device": device_label,
        "timestamp": datetime.now().isoformat(),
    }
    logger.info(f"Status check: {status}")
    return status
if __name__ == "__main__":
    # Script entry point: serve `app` with uvicorn on all interfaces, port 8000.
    # Start-up failures are logged with their traceback and re-raised.
    try:
        logger.info("Starting TxAgent API server")
        from uvicorn import run as serve_app
        # log_config=None is passed through to uvicorn -- presumably so the
        # logging configured in this module stays in effect; confirm.
        serve_app(app, host="0.0.0.0", port=8000, log_config=None)
    except Exception as exc:
        logger.error(f"Failed to start server: {str(exc)}", exc_info=True)
        raise