# TxAgent-Api / app.py
# Repository: TxAgent-Api (author: Ali2206), commit 60e4c3d "Update app.py", 6.24 kB.
import os
import sys
import json
import re
import logging
from datetime import datetime
from typing import List, Dict, Optional
from fastapi import FastAPI, HTTPException, UploadFile, File
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import markdown
import PyPDF2
# Setup logging: one process-wide configuration (INFO and above, timestamped
# lines) plus the named logger used throughout this module.
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger("TxAgentAPI")
# Adjust sys path: put the ``src`` directory next to this file at the very
# front of the import path (presumably where the ``txagent`` package lives).
_this_file = os.path.abspath(__file__)
current_dir = os.path.dirname(_this_file)
src_path = os.path.abspath(os.path.join(current_dir, "src"))
sys.path[:0] = [src_path]  # same effect as sys.path.insert(0, src_path)
# Import TxAgent. A failed import is logged and then re-raised, so importing
# this module (and therefore starting the app) fails without it.
try:
    from txagent.txagent import TxAgent
except ImportError as import_error:
    logger.error(f"Failed to import TxAgent: {import_error!s}")
    raise
# Init FastAPI
app = FastAPI(version="2.1.0", title="TxAgent API")

# CORS: accept cross-origin requests from any origin, with any method and header.
# NOTE(review): allow_origins=["*"] together with allow_credentials=True means
# credentialed cross-site requests are accepted from any site (Starlette
# reflects the caller's Origin in that case — verify). Nothing in this file
# uses cookies/auth today; restrict the origins before adding any.
_cors_options = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_options)
# Request schema
class ChatRequest(BaseModel):
    """Request body for ``POST /chat``.

    The field names, annotations and defaults below are the pydantic schema
    that incoming JSON is validated against.
    """

    # Prompt text forwarded to ``agent.chat`` as ``message``.
    message: str
    # Forwarded to ``agent.chat``; no range is enforced by this model.
    temperature: float = 0.7
    # Forwarded to ``agent.chat`` as the generation length limit.
    max_new_tokens: int = 512
    # Prior conversation turns, passed through to ``agent.chat`` unchanged.
    # NOTE(review): the expected dict shape (e.g. role/content keys) is defined
    # by TxAgent, not here — confirm against TxAgent.chat.
    history: Optional[List[Dict]] = None
    # Which rendering /chat returns: "raw", "clean", "structured" or "html";
    # the endpoint falls back to "clean" for any other value (including None).
    format: Optional[str] = "clean"
# Response formatting
def clean_text_response(text: str) -> str:
    """Flatten a model reply into tidy plain text.

    Steps, in order:
      * runs of blank or whitespace-only lines become one blank line;
      * runs of space characters (not tabs) become a single space;
      * ``**`` and then ``__`` bold markers are deleted (note this also
        strips double underscores that are not markdown, e.g. ``__init__``);
      * leading/trailing whitespace is stripped.
    """
    regex_passes = (
        (r'\n\s*\n', '\n\n'),
        (r'[ ]+', ' '),
    )
    for pattern, replacement in regex_passes:
        text = re.sub(pattern, replacement, text)
    # Order matters: "_**_" must end up empty, which requires "**" to go first.
    for marker in ("**", "__"):
        text = text.replace(marker, "")
    return text.strip()
def structure_medical_response(text: str) -> Dict:
    """Split a model reply into the four clinical sections exposed by the API.

    Keys, in order: ``summary``, ``risks``, ``missed_issues``,
    ``recommendations``. Each value is what ``extract_section`` returns for
    the corresponding heading (an empty string when the heading is absent).
    """
    section_headings = (
        ("summary", "Summary"),
        ("risks", "Risks or Red Flags"),
        ("missed_issues", "What the doctor might have missed"),
        ("recommendations", "Suggested Clinical Actions"),
    )
    return {key: extract_section(text, heading) for key, heading in section_headings}
def extract_section(text: str, heading: str) -> str:
    """Return the cleaned body of the ``"<heading>:"`` section of *text*.

    The body starts on the line after ``"<heading>:\\n"`` and ends just before
    the first later newline that is followed by a word character, or at the
    end of *text*. It is passed through ``clean_text_response``. Returns
    ``""`` when the heading is not present or extraction fails.

    *heading* is matched literally (via ``re.escape``). Headings containing
    regex metacharacters such as ``(``, ``)``, ``+``, ``?`` or ``.`` are
    therefore found exactly as written. They no longer match different text,
    or raise ``re.error`` and silently return ``""``.

    NOTE(review): because the terminator is "newline + word character", a
    body written as plain prose lines keeps only its first line. An empty
    body followed directly by the next heading captures that heading line.
    """
    try:
        pattern = rf"{re.escape(heading)}:\n(.*?)(?=\n\w|\Z)"
        match = re.search(pattern, text, re.DOTALL)
        return clean_text_response(match.group(1)) if match else ""
    except Exception as e:
        # Best-effort: a failed extraction yields an empty section, not a 500.
        logger.error(f"Section extraction failed: {e}")
        return ""
# Agent init
# Module-level TxAgent handle used by the endpoints; it stays None until
# startup_event has fully constructed and loaded the model.
agent = None


@app.on_event("startup")
async def startup_event():
    """Create, configure and load the TxAgent model when the app starts.

    The instance is built in a local variable and only assigned to the
    module-level ``agent`` after ``init_model()`` returns. A failure part-way
    through therefore cannot leave a half-initialised object in ``agent``
    that ``/status`` would report as loaded and ``/chat`` would try to use.
    Errors are logged with traceback and not re-raised, so the server still
    starts with ``agent`` left unset.
    """
    # NOTE(review): recent FastAPI versions deprecate on_event in favour of a
    # lifespan handler; migrate when upgrading.
    global agent
    try:
        candidate = TxAgent(
            model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
            rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
            enable_finish=True,
            enable_rag=False,
            force_finish=True,
            enable_checker=True,
            step_rag_num=4,
            seed=100,
        )
        # Custom prompt text for the agent (presumably used as its system
        # prompt — confirm in TxAgent).
        candidate.chat_prompt = (
            "You are a clinical decision support assistant for doctors. "
            "You analyze patient documents, detect medical issues, identify missed diagnoses, "
            "and provide treatment suggestions with rationale in concise, readable language."
        )
        candidate.init_model()
    except Exception as e:
        # Best-effort startup: keep serving so /status can report "not loaded".
        logger.exception(f"Startup error: {str(e)}")
        return
    agent = candidate
    logger.info("TxAgent initialized successfully")
@app.post("/chat")
async def chat_endpoint(request: ChatRequest):
    """Run one chat turn through TxAgent and return the reply in one format.

    ``request.format`` selects the rendering: ``"raw"``, ``"clean"``,
    ``"structured"`` or ``"html"``. Any other value, including ``None``,
    falls back to ``"clean"``. The ``"format"`` field of the response
    reports the rendering that was actually used.

    Raises:
        HTTPException(503): the model is not loaded (startup failed).
        HTTPException(500): ``agent.chat`` or the formatting step raised.
    """
    # Without this guard a failed startup surfaces as an opaque
    # "'NoneType' object has no attribute 'chat'" 500.
    if agent is None:
        raise HTTPException(status_code=503, detail="TxAgent model is not loaded")

    # Renderers are looked up by name so only the requested one is computed.
    formatters = {
        "raw": lambda text: text,
        "clean": clean_text_response,
        "structured": structure_medical_response,
        "html": markdown.markdown,
    }
    selected = request.format if request.format in formatters else "clean"

    try:
        # NOTE(review): agent.chat is synchronous, so it blocks the event loop
        # for the whole generation. Offloading it to a worker thread would also
        # allow concurrent calls into the model — confirm TxAgent is
        # thread-safe before doing that.
        raw_response = agent.chat(
            message=request.message,
            history=request.history,
            temperature=request.temperature,
            max_new_tokens=request.max_new_tokens,
        )
        return JSONResponse({
            "status": "success",
            "format": selected,
            "response": formatters[selected](raw_response),
            "timestamp": datetime.now().isoformat(),
            "available_formats": list(formatters),
        })
    except Exception as e:
        logger.exception(f"Chat error: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e)) from e
# Maximum number of extracted document characters sent to the model;
# longer documents are truncated.
_MAX_DOCUMENT_CHARS = 10000

# Analysis prompt for /upload; ``{document}`` receives the (truncated)
# extracted text of the uploaded file.
_UPLOAD_PROMPT_TEMPLATE = """
You are a clinical decision support AI assisting physicians.
Given the following patient report, do the following:
1. Summarize the patient's main conditions and history.
2. Identify any potential clinical risks or red flags.
3. Highlight any important diagnoses or treatments the doctor might have missed.
4. Suggest next clinical steps, treatments, or referrals (if applicable).
5. Flag anything that could pose an urgent risk (e.g., suicide risk, untreated critical conditions).
Patient Document:
-----------------
{document}
"""


async def _read_upload_text(file: UploadFile) -> str:
    """Return the text of *file*: PDF pages via PyPDF2, anything else decoded as UTF-8."""
    # Case-insensitive so "REPORT.PDF" is parsed as a PDF; filename may be None.
    if (file.filename or "").lower().endswith(".pdf"):
        reader = PyPDF2.PdfReader(file.file)
        # Newline between pages so the last word of one page does not fuse
        # with the first word of the next.
        return "\n".join(page.extract_text() or "" for page in reader.pages)
    raw_bytes = await file.read()
    # errors="ignore" drops undecodable bytes instead of failing the upload.
    return raw_bytes.decode("utf-8", errors="ignore")


@app.post("/upload")
async def upload_file(file: UploadFile = File(...)):
    """Analyse an uploaded patient document with TxAgent.

    The document's text (see ``_read_upload_text``) is truncated to
    ``_MAX_DOCUMENT_CHARS`` characters and embedded in
    ``_UPLOAD_PROMPT_TEMPLATE``. The model reply is returned in the "clean"
    rendering. The uploaded file handle is always closed.

    Raises:
        HTTPException(503): the model is not loaded (startup failed).
        HTTPException(400): no readable text could be extracted from the file.
        HTTPException(500): reading/parsing the file or ``agent.chat`` raised.
    """
    try:
        logger.info(f"File upload received: {file.filename}")
        if agent is None:
            raise HTTPException(status_code=503, detail="TxAgent model is not loaded")

        content = await _read_upload_text(file)
        # An empty document (e.g. a scanned PDF with no text layer) would make
        # the model "analyse" nothing; reject it instead.
        if not content.strip():
            raise HTTPException(status_code=400, detail="No readable text found in the uploaded file")

        message = _UPLOAD_PROMPT_TEMPLATE.format(document=content[:_MAX_DOCUMENT_CHARS])
        raw_response = agent.chat(message=message, history=[], temperature=0.7, max_new_tokens=1024)

        return JSONResponse({
            "status": "success",
            "format": "clean",
            # Only the clean rendering is returned, so only it is computed.
            "response": clean_text_response(raw_response),
            "timestamp": datetime.now().isoformat(),
            "available_formats": ["raw", "clean", "structured", "html"],
        })
    except HTTPException:
        # Deliberate client-facing errors raised above pass through unchanged.
        raise
    except Exception as e:
        logger.exception(f"File upload error: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        file.file.close()
@app.get("/status")
async def status():
    """Report liveness, API version, the loaded model name and the server time.

    ``model`` is ``"not loaded"`` whenever the module-level ``agent`` is falsy
    (e.g. still ``None``).
    """
    model_label = "not loaded" if not agent else agent.model_name
    return dict(
        status="running",
        version="2.1.0",
        model=model_label,
        timestamp=datetime.now().isoformat(),
    )