# TxAgent-Api / app.py
# Source: Hugging Face file view by Ali2206, commit ab0089d ("Update app.py", verified), 6.88 kB.
import os
import sys
import json
import re
import logging
from fastapi import FastAPI, HTTPException, UploadFile, File
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from typing import List, Dict, Optional
from datetime import datetime
from pydantic import BaseModel
import markdown
import PyPDF2
# Root logging configuration: INFO and above, one timestamped line per record.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
# Named logger used by every handler in this module.
logger = logging.getLogger("TxAgentAPI")
# Put the bundled "src" directory at the front of sys.path so the local
# txagent package is importable ahead of anything else on the path.
current_dir = os.path.dirname(os.path.abspath(__file__))
src_path = os.path.abspath(os.path.join(current_dir, "src"))
sys.path.insert(0, src_path)

# TxAgent can only be imported once the path above is in place.
try:
    from txagent.txagent import TxAgent
except ImportError as import_error:
    # Log for visibility, then re-raise: the service cannot run without TxAgent.
    logger.error(f"Failed to import TxAgent: {str(import_error)}")
    raise
# FastAPI application object; metadata below feeds the generated OpenAPI docs.
app = FastAPI(
    version="2.0.0",
    title="TxAgent API",
    description="API for TxAgent medical document analysis",
)

# CORS: any origin, method and header is accepted, and credentials are allowed.
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive; confirm this is intended for the deployed service.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)
# Request models
class ChatRequest(BaseModel):
    """Request body for the ``POST /chat`` endpoint."""
    message: str  # user message passed to agent.chat() as ``message``
    temperature: float = 0.7  # passed to agent.chat() as ``temperature``
    max_new_tokens: int = 512  # passed to agent.chat() as ``max_new_tokens``
    history: Optional[List[Dict]] = None  # passed to agent.chat(); per-turn dict schema not shown here — TODO confirm
    format: Optional[str] = "clean" # Options: raw, clean, structured, html (other values fall back to clean in /chat)
# Response cleaning functions
def clean_text_response(text: str) -> str:
    """Return *text* with whitespace normalized and Markdown emphasis markers removed.

    Runs of blank lines collapse to a single blank line, runs of spaces
    collapse to one space, every ``**`` and ``__`` is deleted, and the result
    is stripped of leading/trailing whitespace.
    """
    normalized = re.sub(r'\n\s*\n', '\n\n', text)
    normalized = re.sub(r'[ ]+', ' ', normalized)
    # Drop bold/underline markers in the same order as before: "**" then "__".
    for marker in ("**", "__"):
        normalized = normalized.replace(marker, "")
    return normalized.strip()
def structure_medical_response(text: str) -> Dict:
    """Split a model response into an overview and per-diabetes-type line lists.

    Text before the first ``Type <digit> Diabetes:`` header becomes the cleaned
    ``overview`` (the whole text when no header exists). Each header, lower-cased
    with its colon removed (e.g. ``"type 1 diabetes"``), maps to the stripped,
    non-blank lines that follow it up to the next header. ``symptoms`` and
    ``notes`` are returned with their empty defaults; nothing here fills them.

    Args:
        text: Raw model response.

    Returns:
        Dict with keys ``overview``, ``symptoms``, ``types`` and ``notes``.
    """
    result = {"overview": "", "symptoms": [], "types": {}, "notes": ""}
    header_pattern = re.compile(r'(Type \d Diabetes:)')

    # Locate the first header of any type. The old str.find("Type 1 Diabetes:")
    # returned -1 when absent, and text[:-1] then dropped the last character.
    first_header = header_pattern.search(text)
    overview_end = first_header.start() if first_header else len(text)
    result["overview"] = clean_text_response(text[:overview_end])

    # With a capturing group, re.split yields [prefix, header, body, header, body, ...].
    # Pairing by position means a body that merely starts with "Type " is never
    # mistaken for a header.
    parts = header_pattern.split(text[overview_end:])
    for header, body in zip(parts[1::2], parts[2::2]):
        type_name = header.replace(":", "").strip().lower()
        result["types"][type_name] = [line.strip() for line in body.split('\n') if line.strip()]
    return result
# Shared TxAgent instance. It stays None until startup_event() builds it
# successfully; the endpoints check for None before using it.
agent = None
@app.on_event("startup")
async def startup_event():
    """Build and initialise the module-level TxAgent when the app starts.

    ``HUGGINGFACE_TOKEN``, when set, is forwarded to ``TxAgent`` as ``token=``.
    Any failure is logged with its traceback and deliberately swallowed: the
    app still starts with ``agent`` left as ``None``, and the endpoints then
    respond with 503.
    """
    global agent
    try:
        logger.info("Initializing TxAgent...")
        # Get Hugging Face token from environment variable
        hf_token = os.getenv("HUGGINGFACE_TOKEN")
        model_kwargs = {"token": hf_token} if hf_token else {}
        candidate = TxAgent(
            model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
            rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
            tool_files_dict={},
            enable_finish=True,
            enable_rag=False,
            force_finish=True,
            enable_checker=True,
            step_rag_num=4,
            seed=100,
            **model_kwargs,
        )
        candidate.init_model()
    except Exception as e:
        # logger.exception records the traceback, which str(e) alone discarded;
        # since the error is swallowed here, the log is the only diagnostic.
        logger.exception(f"Failed to initialize agent: {str(e)}")
        # Allow app to start, but endpoints will return errors if agent is None
        agent = None
        return
    # Publish the agent only after init_model() succeeded, so a half-built
    # instance is never assigned to the global.
    agent = candidate
    logger.info("TxAgent initialized successfully")
@app.post("/chat")
async def chat_endpoint(request: ChatRequest):
    """Run one chat turn through TxAgent and return it in the requested format.

    ``request.format`` selects ``raw``, ``clean``, ``structured`` or ``html``.
    Any other value, including ``None``, falls back to ``clean``. The response's
    ``format`` field reports the format actually returned.

    Raises:
        HTTPException: 503 if the agent is not initialised; 500 if the agent
            call or the formatting step raises.
    """
    if agent is None:
        logger.error("TxAgent not initialized")
        raise HTTPException(status_code=503, detail="TxAgent service unavailable")

    # Only the requested renderer runs. Previously all four ran on every request,
    # so a failure in one (e.g. the structured parser) broke every format.
    # NOTE(review): markdown.markdown passes raw HTML through unsanitized; a
    # browser client rendering "html" output should sanitize it.
    formatters = {
        "raw": lambda text: text,
        "clean": clean_text_response,
        "structured": structure_medical_response,
        "html": markdown.markdown,
    }
    effective_format = request.format if request.format in formatters else "clean"

    try:
        logger.info(f"Chat request received (format: {request.format})")
        # NOTE(review): agent.chat is synchronous inside an async handler, so it
        # blocks the event loop for the whole generation. Offloading it to a
        # thread would allow concurrent calls; confirm TxAgent is thread-safe first.
        raw_response = agent.chat(
            message=request.message,
            history=request.history,
            temperature=request.temperature,
            max_new_tokens=request.max_new_tokens,
        )
        response_content = formatters[effective_format](raw_response)
        return JSONResponse({
            "status": "success",
            "format": effective_format,
            "response": response_content,
            "timestamp": datetime.now().isoformat(),
            "available_formats": list(formatters),
        })
    except Exception as e:
        logger.exception(f"Chat error: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e)) from e
def _extract_pdf_text(stream, max_chars: int) -> str:
    """Return the text of the PDF in *stream*, extracting pages only until *max_chars* is reached.

    Pages are joined with newlines so words on adjacent pages do not run
    together. A page with no extractable text contributes an empty string.
    """
    pages: List[str] = []
    # Length of "\n".join(pages) so far; starts at -1 because joining n pages
    # adds only n - 1 separators.
    joined_length = -1
    for page in PyPDF2.PdfReader(stream).pages:
        page_text = page.extract_text() or ""
        pages.append(page_text)
        joined_length += len(page_text) + 1
        if joined_length >= max_chars:
            # The caller keeps only the first max_chars characters, so the
            # remaining pages would be extracted for nothing.
            break
    return "\n".join(pages)


@app.post("/upload")
async def upload_file(file: UploadFile = File(...)):
    """Extract text from an uploaded document and have TxAgent analyse it.

    Files whose name ends in ``.pdf`` (case-insensitive) go through PyPDF2.
    Anything else is read as UTF-8, ignoring undecodable bytes. At most the
    first 10000 characters are sent to the agent. The reply is returned in the
    ``clean`` format.

    Raises:
        HTTPException: 503 if the agent is not initialised; 500 on any
            extraction, agent or formatting error.
    """
    if agent is None:
        logger.error("TxAgent not initialized")
        raise HTTPException(status_code=503, detail="TxAgent service unavailable")
    max_chars = 10000
    try:
        logger.info(f"File upload received: {file.filename}")
        # filename may be None, and extensions may be upper-case (e.g. "SCAN.PDF").
        if (file.filename or "").lower().endswith(".pdf"):
            content = _extract_pdf_text(file.file, max_chars)
        else:
            content = (await file.read()).decode('utf-8', errors='ignore')
        message = f"Analyze the following medical document content:\n\n{content[:max_chars]}"
        # NOTE(review): synchronous agent call inside an async handler blocks
        # the event loop for the duration of the generation.
        raw_response = agent.chat(
            message=message,
            history=[],
            temperature=0.7,
            max_new_tokens=512,
        )
        # Only the "clean" rendering is returned, so only it is computed; a
        # failure in an unused renderer can no longer fail the upload.
        response_content = clean_text_response(raw_response)
        return JSONResponse({
            "status": "success",
            "format": "clean",
            "response": response_content,
            "timestamp": datetime.now().isoformat(),
            "available_formats": ["raw", "clean", "structured", "html"],
        })
    except Exception as e:
        logger.exception(f"File upload error: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        file.file.close()
@app.get("/status")
async def service_status():
    """Report whether TxAgent is loaded, together with static API metadata."""
    # Same `is None` test the endpoints use, rather than relying on the
    # truthiness of the agent object.
    loaded = agent is not None
    return {
        "status": "running" if loaded else "failed",
        "version": "2.0.0",
        # NOTE(review): assumes TxAgent exposes the constructor's model_name as
        # an attribute; fall back instead of failing the health check if not.
        "model": getattr(agent, "model_name", "unknown") if loaded else "not loaded",
        "formats_available": ["raw", "clean", "structured", "html"],
        "timestamp": datetime.now().isoformat(),
    }