# TxAgent-Api / app.py
# Source: Hugging Face Space file view (raw / history / blame).
# Uploaded by Ali2206 - commit "Update app.py" (bdcc052, verified) - 3.1 kB
import os
import sys
import json
import logging
from fastapi import FastAPI, HTTPException, UploadFile, File
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from typing import List, Dict, Optional
from datetime import datetime
from pydantic import BaseModel
# Process-wide logging setup: INFO and above, each record stamped with
# time, logger name and level (originally configured for Hugging Face Spaces).
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)

# Named logger used by the rest of this module.
logger = logging.getLogger("TxAgentAPI")
# Put the sibling "src" directory at the front of sys.path so that the
# import below is resolved against it first.
current_dir = os.path.dirname(os.path.abspath(__file__))
src_path = os.path.abspath(os.path.join(current_dir, "src"))
sys.path.insert(0, src_path)

# This import has to come after the sys.path change above. A failure is
# logged and then re-raised unchanged, so the module still fails to load.
try:
    from txagent.txagent import TxAgent
except ImportError as import_err:
    logger.error(f"Failed to import TxAgent: {str(import_err)}")
    raise
# The ASGI application object; the metadata below is what FastAPI exposes
# in the generated API docs.
app = FastAPI(
    title="TxAgent API",
    description="API for TxAgent medical document analysis",
    version="1.0.0",
)

# CORS: every origin, method and header is accepted, with credentials allowed.
# NOTE(review): wildcard origins together with allow_credentials=True is the
# combination the FastAPI CORS docs warn against - confirm whether credentials
# are actually needed, or replace "*" with an explicit origin list.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Shared TxAgent instance. It stays None until startup_event() has fully
# initialised an agent; the route handlers read it.
agent = None


@app.on_event("startup")
async def startup_event():
    """Build the TxAgent and load its model when the application starts.

    The module-level ``agent`` is only assigned after ``init_model()`` has
    returned, so a failed initialisation never leaves a half-initialised
    agent visible to the request handlers.

    Raises:
        RuntimeError: If constructing the agent or loading its model fails.
            The original exception is attached as ``__cause__``.
    """
    global agent
    logger.info("Initializing TxAgent...")
    try:
        candidate = TxAgent(
            model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
            rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
            tool_files_dict={},  # No tool files in this example
            enable_finish=True,
            enable_rag=False,
            force_finish=True,
            enable_checker=True,
            step_rag_num=4,
            seed=100,
        )
        candidate.init_model()
    except Exception as e:
        # logger.exception records the traceback alongside the message.
        logger.exception("Failed to initialize agent: %s", e)
        raise RuntimeError(f"Failed to initialize agent: {str(e)}") from e

    # Publish the agent only once it is fully initialised.
    agent = candidate
    logger.info("TxAgent initialized successfully")
class ChatRequest(BaseModel):
    """Request body accepted by the POST /chat endpoint."""

    # Required user message; forwarded to agent.chat as `message`.
    message: str

    # Generation settings forwarded to agent.chat; both are optional.
    temperature: float = 0.7
    max_new_tokens: int = 512

    # Optional list of dicts forwarded as `history`; None when omitted.
    history: Optional[List[Dict]] = None
@app.post("/chat")
async def chat_endpoint(request: ChatRequest):
    """Handle chat conversations by delegating to the shared TxAgent.

    Args:
        request: Validated request body with the message, optional history
            and generation settings.

    Returns:
        JSONResponse with the keys "status", "response" and "timestamp".

    Raises:
        HTTPException: 503 if the agent has not been initialised; 500 if
            the agent call or building the response fails.
    """
    # Checked outside the try block so the 503 is not caught and rewrapped
    # as a 500 by the generic handler below.
    if agent is None:
        raise HTTPException(status_code=503, detail="TxAgent is not initialized")

    try:
        # Only the first 50 characters of the message are logged.
        logger.info("Chat request received: %s...", request.message[:50])
        response = agent.chat(
            message=request.message,
            history=request.history,
            temperature=request.temperature,
            max_new_tokens=request.max_new_tokens,
        )
        return JSONResponse({
            "status": "success",
            "response": response,
            "timestamp": datetime.now().isoformat(),
        })
    except Exception as e:
        # Keep the traceback in the logs; the client still gets str(e).
        logger.exception("Chat error: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.get("/status")
async def service_status():
    """Return a small status payload: state, version, model and timestamp."""
    # Report the agent's model name when an agent is set, otherwise a
    # placeholder string.
    if agent:
        model_label = agent.model_name
    else:
        model_label = "not loaded"

    payload = {
        "status": "running",
        "version": "1.0.0",
        "model": model_label,
        "timestamp": datetime.now().isoformat(),
    }
    return payload