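"""FastAPI service that exposes the TxAgent clinical decision support model.

Endpoints:
    POST /chat    - free-form chat with the model
    POST /upload  - analyze an uploaded patient document (PDF or plain text)
    GET  /status  - report service version and model availability
"""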

import os
import sys
import json
import re
import logging
from datetime import datetime
from typing import List, Dict, Optional

from fastapi import FastAPI, HTTPException, UploadFile, File
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel

import markdown
import PyPDF2


logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger("TxAgentAPI")


# Make the bundled src/ directory importable so the txagent package can be found.
current_dir = os.path.dirname(os.path.abspath(__file__))
src_path = os.path.abspath(os.path.join(current_dir, "src"))
sys.path.insert(0, src_path)


try:
    from txagent.txagent import TxAgent
except ImportError as e:
    logger.error(f"Failed to import TxAgent: {str(e)}")
    raise


app = FastAPI(title="TxAgent API", version="2.1.0")


# Allow all origins for now; tighten these settings for production deployments.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


class ChatRequest(BaseModel):
    message: str
    temperature: float = 0.7
    max_new_tokens: int = 512
    history: Optional[List[Dict]] = None
    format: Optional[str] = "clean"
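
# Illustrative request body for POST /chat (field names and defaults match
# ChatRequest above; the message text is only an example):
# {
#     "message": "Summarize this patient's history and flag any red flags.",
#     "temperature": 0.7,
#     "max_new_tokens": 512,
#     "history": [],
#     "format": "structured"
# }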


def clean_text_response(text: str) -> str:
    """Collapse extra blank lines and spaces and strip Markdown emphasis markers."""
    text = re.sub(r'\n\s*\n', '\n\n', text)
    text = re.sub(r'[ ]+', ' ', text)
    text = text.replace("**", "").replace("__", "")
    return text.strip()


def structure_medical_response(text: str) -> Dict:
    """Split a model response into summary, risk, missed-issue, and recommendation sections."""
    return {
        "summary": extract_section(text, "Summary"),
        "risks": extract_section(text, "Risks or Red Flags"),
        "missed_issues": extract_section(text, "What the doctor might have missed"),
        "recommendations": extract_section(text, "Suggested Clinical Actions")
    }


def extract_section(text: str, heading: str) -> str:
    """Return the text under "<heading>:" up to the next heading-like line, or "" if absent."""
    try:
        # Escape the heading so literal characters are not treated as regex syntax.
        pattern = rf"{re.escape(heading)}:\n(.*?)(?=\n\w|\Z)"
        match = re.search(pattern, text, re.DOTALL)
        return clean_text_response(match.group(1)) if match else ""
    except Exception as e:
        logger.error(f"Section extraction failed: {e}")
        return ""


# Single TxAgent instance, created once at application startup.
agent: Optional[TxAgent] = None


@app.on_event("startup")
async def startup_event():
    global agent
    try:
        agent = TxAgent(
            model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
            rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
            enable_finish=True,
            enable_rag=False,
            force_finish=True,
            enable_checker=True,
            step_rag_num=4,
            seed=100
        )
        agent.chat_prompt = (
            "You are a clinical decision support assistant for doctors. "
            "You analyze patient documents, detect medical issues, identify missed diagnoses, "
            "and provide treatment suggestions with rationale in concise, readable language."
        )
        agent.init_model()
        logger.info("TxAgent initialized successfully")
    except Exception as e:
        # Leave `agent` as None; the endpoints below report the failure as 503.
        logger.error(f"Startup error: {str(e)}")


@app.post("/chat")
async def chat_endpoint(request: ChatRequest):
    if agent is None:
        raise HTTPException(status_code=503, detail="TxAgent model is not loaded")
    try:
        raw_response = agent.chat(
            message=request.message,
            history=request.history,
            temperature=request.temperature,
            max_new_tokens=request.max_new_tokens
        )
        formatted_response = {
            "raw": raw_response,
            "clean": clean_text_response(raw_response),
            "structured": structure_medical_response(raw_response),
            "html": markdown.markdown(raw_response)
        }
        return JSONResponse({
            "status": "success",
            "format": request.format,
            "response": formatted_response.get(request.format, formatted_response["clean"]),
            "timestamp": datetime.now().isoformat(),
            "available_formats": list(formatted_response.keys())
        })
    except Exception as e:
        logger.error(f"Chat error: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
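
# Example call (illustrative; assumes the server is listening on localhost:8000):
#   curl -X POST http://localhost:8000/chat \
#        -H "Content-Type: application/json" \
#        -d '{"message": "65-year-old with chest pain and dyspnea.", "format": "clean"}'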


@app.post("/upload")
async def upload_file(file: UploadFile = File(...)):
    if agent is None:
        raise HTTPException(status_code=503, detail="TxAgent model is not loaded")
    try:
        logger.info(f"File upload received: {file.filename}")
        content = ""
        if file.filename and file.filename.endswith(".pdf"):
            # Extract text page by page; pages without extractable text are skipped.
            pdf_reader = PyPDF2.PdfReader(file.file)
            for page in pdf_reader.pages:
                content += page.extract_text() or ""
        else:
            # Treat any other upload as plain text.
            content = await file.read()
            content = content.decode("utf-8", errors="ignore")

        # Build the analysis prompt; truncate very long documents to keep it bounded.
        message = f"""
You are a clinical decision support AI assisting physicians.

Given the following patient report, do the following:
1. Summarize the patient's main conditions and history.
2. Identify any potential clinical risks or red flags.
3. Highlight any important diagnoses or treatments the doctor might have missed.
4. Suggest next clinical steps, treatments, or referrals (if applicable).
5. Flag anything that could pose an urgent risk (e.g., suicide risk, untreated critical conditions).

Patient Document:
-----------------
{content[:10000]}
"""

        raw_response = agent.chat(message=message, history=[], temperature=0.7, max_new_tokens=1024)
        formatted_response = {
            "raw": raw_response,
            "clean": clean_text_response(raw_response),
            "structured": structure_medical_response(raw_response),
            "html": markdown.markdown(raw_response)
        }
        return JSONResponse({
            "status": "success",
            "format": "clean",
            "response": formatted_response["clean"],
            "timestamp": datetime.now().isoformat(),
            "available_formats": list(formatted_response.keys())
        })
    except Exception as e:
        logger.error(f"File upload error: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
    finally:
        file.file.close()


@app.get("/status")
async def status():
    return {
        "status": "running",
        "version": "2.1.0",
        "model": agent.model_name if agent else "not loaded",
        "timestamp": datetime.now().isoformat()
    }
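
# A minimal way to run the service locally (a sketch; assumes this module is saved
# as app.py and that uvicorn is installed):
#
#   uvicorn app:app --host 0.0.0.0 --port 8000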