# TxAgent-Api / app.py
# Copied from a repository file view: author Ali2206, commit f3089ba ("Update app.py"), 6.27 kB.
import os
import sys
import json
import re
import logging
from fastapi import FastAPI, HTTPException, UploadFile, File
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from typing import List, Dict, Optional
from datetime import datetime
from pydantic import BaseModel
import markdown
import PyPDF2
# Configure logging: INFO and above, each record rendered as
# "<time> - <logger name> - <level> - <message>".
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
# Named logger shared by every part of this module.
logger = logging.getLogger(name="TxAgentAPI")
# Make the bundled "src" directory importable. It is inserted at the front of
# sys.path so that it is searched before any other location.
current_dir = os.path.dirname(os.path.abspath(__file__))
src_path = os.path.abspath(os.path.join(current_dir, "src"))
sys.path.insert(0, src_path)

# TxAgent can only be imported once the path above is in place; a missing
# package is logged and then re-raised so the app refuses to start.
try:
    from txagent.txagent import TxAgent
except ImportError as import_error:
    logger.error(f"Failed to import TxAgent: {str(import_error)}")
    raise
# Initialize FastAPI app (title/description/version feed the OpenAPI docs).
app = FastAPI(
    title="TxAgent API",
    version="2.0.0",
    description="API for TxAgent medical document analysis",
)

# CORS configuration: every origin, method and header is accepted.
# NOTE(review): wildcard origins together with allow_credentials=True is a very
# permissive combination; confirm it is intended before relying on cookies/auth.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)
# Request models
class ChatRequest(BaseModel):
    """Request body for POST /chat."""

    message: str  # user message, forwarded to agent.chat(message=...)
    temperature: float = 0.7  # forwarded to agent.chat(temperature=...)
    max_new_tokens: int = 512  # forwarded to agent.chat(max_new_tokens=...)
    # Prior conversation, forwarded as agent.chat(history=...). The expected
    # dict keys are defined by TxAgent, not here — TODO confirm the schema.
    history: Optional[List[Dict]] = None
    # How /chat renders the reply. Options: raw, clean, structured, html;
    # any other value (or None) falls back to "clean".
    format: Optional[str] = "clean"
# Response cleaning functions
def clean_text_response(text: str) -> str:
    """Normalise whitespace and strip Markdown emphasis markers from *text*.

    Steps, in order:
      1. remove every ``**`` and ``__`` marker;
      2. collapse each run of blank (or whitespace-only) lines into a single
         blank line;
      3. collapse runs of spaces into one space;
      4. trim leading and trailing whitespace.

    Markers are removed *before* whitespace is normalised. Otherwise deleting
    a marker between two spaces (``"a ** b"``) would leave a double space, and
    a line holding only a marker would survive as an extra blank line.
    Note that ``__`` is removed everywhere, including inside identifiers such
    as ``__init__``.
    """
    text = text.replace("**", "").replace("__", "")
    text = re.sub(r'\n\s*\n', '\n\n', text)
    text = re.sub(r'[ ]+', ' ', text)
    return text.strip()
def structure_medical_response(text: str) -> Dict:
    """Split a diabetes-style answer into an overview and per-type sections.

    Returns a dict with keys:
      - ``overview``: the cleaned text (via ``clean_text_response``) that
        precedes the first ``Type <digit> Diabetes:`` header, or the whole
        cleaned text when no such header exists;
      - ``types``: maps each header, lower-cased with its colon removed
        (e.g. ``"type 1 diabetes"``), to the stripped, non-empty lines that
        follow it; a header that appears more than once accumulates lines
        from every occurrence;
      - ``symptoms`` and ``notes``: placeholders, always empty.
    """
    # Capturing group so re.split keeps the headers in its output.
    header_re = re.compile(r'(Type \d Diabetes:)')
    result = {"overview": "", "symptoms": [], "types": {}, "notes": ""}

    first_header = header_re.search(text)
    if first_header is None:
        # No type headers: the whole text is the overview.
        result["overview"] = clean_text_response(text)
        return result

    overview_end = first_header.start()
    result["overview"] = clean_text_response(text[:overview_end])

    current_type = None
    for section in header_re.split(text[overview_end:]):
        if header_re.fullmatch(section):
            current_type = section.replace(":", "").strip().lower()
            result["types"].setdefault(current_type, [])
        elif current_type:
            result["types"][current_type].extend(
                line.strip() for line in section.split('\n') if line.strip()
            )
    return result
# Initialize agent at startup.
# Shared TxAgent instance used by the endpoints; stays None until
# startup_event has fully succeeded.
agent = None


# NOTE(review): newer FastAPI releases deprecate on_event in favour of lifespan
# handlers; migrating requires changing how `app` is constructed.
@app.on_event("startup")
async def startup_event():
    """Construct the TxAgent and load its model when the application starts.

    The module-level ``agent`` is only assigned after ``init_model()`` has
    succeeded, so a failed start never leaves a half-initialised agent
    visible to the endpoints.

    Raises:
        RuntimeError: if constructing the agent or loading the model fails;
            the original exception is chained as the cause.
    """
    global agent
    try:
        logger.info("Initializing TxAgent...")
        new_agent = TxAgent(
            model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
            rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
            tool_files_dict={},
            enable_finish=True,
            enable_rag=False,
            force_finish=True,
            enable_checker=True,
            step_rag_num=4,
            seed=100,
        )
        new_agent.init_model()
    except Exception as e:
        # logger.exception also records the traceback.
        logger.exception(f"Failed to initialize agent: {str(e)}")
        raise RuntimeError(f"Failed to initialize agent: {str(e)}") from e
    agent = new_agent
    logger.info("TxAgent initialized successfully")
@app.post("/chat")
async def chat_endpoint(request: ChatRequest):
    """Answer a chat message with TxAgent and return it in the requested format.

    ``request.format`` selects the rendering of the raw agent output:
    ``raw`` (unchanged), ``clean`` (``clean_text_response``), ``structured``
    (``structure_medical_response``) or ``html`` (``markdown.markdown``).
    Any other value, including None, falls back to ``clean``. The ``format``
    field of the response reports the format that was actually applied.
    Only the selected format is computed.

    Raises:
        HTTPException: 503 if the agent has not been initialised; 500 if the
            agent call or the formatting fails.
    """
    if agent is None:
        # Startup has not completed; fail clearly instead of a 500 from None.chat.
        raise HTTPException(status_code=503, detail="TxAgent is not initialized")

    formatters = {
        "raw": lambda text: text,
        "clean": clean_text_response,
        "structured": structure_medical_response,
        "html": markdown.markdown,
    }
    effective_format = request.format if request.format in formatters else "clean"

    try:
        logger.info(f"Chat request received (format: {request.format})")
        # NOTE(review): agent.chat is synchronous inside an async endpoint, so it
        # blocks the event loop for the duration of the generation.
        raw_response = agent.chat(
            message=request.message,
            history=request.history,
            temperature=request.temperature,
            max_new_tokens=request.max_new_tokens,
        )
        response_content = formatters[effective_format](raw_response)
    except Exception as e:
        logger.exception(f"Chat error: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e)) from e

    return JSONResponse({
        "status": "success",
        "format": effective_format,
        "response": response_content,
        "timestamp": datetime.now().isoformat(),
        "available_formats": list(formatters),
    })
@app.post("/upload")
async def upload_file(file: UploadFile = File(...)):
    """Extract text from an uploaded document and have TxAgent analyse it.

    A file whose name ends in ``.pdf`` (in any letter case) is read page by
    page with PyPDF2, with pages joined by newlines. Any other file is decoded
    as UTF-8, and undecodable bytes are dropped. Only the first 10,000
    characters are sent to the agent. The reply is always returned in the
    ``clean`` format (``clean_text_response``). The underlying file is closed
    in every case.

    Raises:
        HTTPException: 503 if the agent has not been initialised; 500 if text
            extraction, the agent call or cleaning fails.
    """
    max_chars = 10000  # prompt truncation limit, in characters
    try:
        if agent is None:
            raise HTTPException(status_code=503, detail="TxAgent is not initialized")
        logger.info(f"File upload received: {file.filename}")

        # filename can be None for some clients; treat that as "not a PDF".
        filename = file.filename or ""
        if filename.lower().endswith('.pdf'):
            pdf_reader = PyPDF2.PdfReader(file.file)
            # `or ""` guards pages where extract_text() yields nothing; the
            # newline stops the last word of a page merging with the next one.
            content = "\n".join(page.extract_text() or "" for page in pdf_reader.pages)
        else:
            data = await file.read()
            content = data.decode('utf-8', errors='ignore')

        message = f"Analyze the following medical document content:\n\n{content[:max_chars]}"
        raw_response = agent.chat(
            message=message,
            history=[],
            temperature=0.7,
            max_new_tokens=512,
        )
        response_content = clean_text_response(raw_response)
    except HTTPException:
        # Already carries the intended status code (e.g. 503); do not wrap it.
        raise
    except Exception as e:
        logger.exception(f"File upload error: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        file.file.close()

    return JSONResponse({
        "status": "success",
        "format": "clean",
        "response": response_content,
        "timestamp": datetime.now().isoformat(),
        "available_formats": ["raw", "clean", "structured", "html"],
    })
@app.get("/status")
async def service_status():
    """Report liveness, API version, loaded model and supported output formats.

    ``version`` is read from the FastAPI app itself so it cannot drift from
    the value the app was constructed with. ``model`` is ``"not loaded"``
    until startup has assigned the agent.
    """
    return {
        "status": "running",
        "version": app.version,
        "model": agent.model_name if agent else "not loaded",
        "formats_available": ["raw", "clean", "structured", "html"],
        "timestamp": datetime.now().isoformat(),
    }