File size: 6,236 Bytes
f126604 d377221 f126604 5620229 7e095f4 60e4c3d 064da3f 7757822 f126604 7757822 5620229 064da3f f126604 60e4c3d 7e095f4 bdcc052 7e095f4 60e4c3d 520104a 60e4c3d 7e095f4 520104a 60e4c3d f126604 60e4c3d f126604 60e4c3d 5620229 60e4c3d 5620229 60e4c3d 5620229 60e4c3d f126604 7e095f4 f3089ba 7e095f4 f3089ba 7e095f4 60e4c3d 7e095f4 bdcc052 7e095f4 60e4c3d bdcc052 7757822 5620229 7757822 5620229 7757822 5620229 60e4c3d 5620229 7757822 bdcc052 f126604 064da3f 60e4c3d 064da3f 60e4c3d 064da3f 60e4c3d 064da3f f126604 60e4c3d bdcc052 f3089ba 60e4c3d d377221 7e095f4 60e4c3d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 |
import json
import logging
import os
import re
import sys
from datetime import datetime
from typing import Dict, List, Optional

import markdown
import PyPDF2
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field
# Setup logging: root handler at INFO, lines formatted as
# "<time> - <logger name> - <level> - <message>".
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger("TxAgentAPI")
# Put the "src" directory that sits next to this file at the front of the
# module search path so the ``txagent`` package below is found there first.
current_dir = os.path.dirname(os.path.abspath(__file__))
src_path = os.path.abspath(os.path.join(current_dir, "src"))
sys.path[:0] = [src_path]

# TxAgent is required: record why the import failed, then re-raise so the
# module does not load without it.
try:
    from txagent.txagent import TxAgent
except ImportError as import_err:
    logger.error(f"Failed to import TxAgent: {import_err!s}")
    raise
# FastAPI application instance; every route below registers on it.
app = FastAPI(title="TxAgent API", version="2.1.0")

# CORS: accept every origin, method and header, and allow credentials.
# NOTE(review): a wildcard origin combined with allow_credentials=True may let
# any website make credentialed cross-origin requests; confirm this is
# intended, or list explicit origins for production deployments.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)
# Request schema
class ChatRequest(BaseModel):
    """Body of ``POST /chat``.

    ``history`` is forwarded to ``agent.chat`` unchanged; the structure of its
    entries is not validated here. ``format`` names the rendering of the reply
    the endpoint should return.
    """

    message: str
    # Reject a negative temperature or a non-positive token budget during
    # request validation instead of passing it on to the model.
    temperature: float = Field(0.7, ge=0.0)
    max_new_tokens: int = Field(512, gt=0)
    history: Optional[List[Dict]] = None
    format: Optional[str] = "clean"
# Response formatting
def clean_text_response(text: str) -> str:
    """Normalise model output into compact plain text.

    Runs of blank or whitespace-only lines collapse to a single blank line,
    runs of spaces collapse to one space, every ``**`` and ``__`` marker is
    removed, and leading/trailing whitespace is stripped.
    """
    result = re.sub(r'\n\s*\n', '\n\n', text)
    result = re.sub(r'[ ]+', ' ', result)
    # Same order as before: "**" first, then "__".
    for marker in ("**", "__"):
        result = result.replace(marker, "")
    return result.strip()
def structure_medical_response(text: str) -> Dict:
    """Split model output into four named clinical sections.

    Returns a dict with keys ``summary``, ``risks``, ``missed_issues`` and
    ``recommendations``; each value is whatever ``extract_section`` yields
    for the matching heading.
    """
    section_headings = (
        ("summary", "Summary"),
        ("risks", "Risks or Red Flags"),
        ("missed_issues", "What the doctor might have missed"),
        ("recommendations", "Suggested Clinical Actions"),
    )
    return {key: extract_section(text, heading) for key, heading in section_headings}
def extract_section(text: str, heading: str) -> str:
    """Return the cleaned body of the ``<heading>:`` section in *text*.

    The heading must be followed by a colon, optional spaces/tabs and a line
    break. The body runs from there to the next heading-shaped line (one
    that starts with a word character and ends with a colon, optionally
    followed by spaces/tabs) or to the end of *text*.

    Args:
        text: Model output to search.
        heading: Section title, matched literally (not as a regex).

    Returns:
        The section body passed through ``clean_text_response``, or ``""``
        when the heading is not found or extraction fails.
    """
    try:
        # re.escape: the heading is plain text; "(", "?", "+" etc. in a title
        # must not be interpreted as regex syntax.
        # (?=\n) requires a line break after the colon without consuming it,
        # so an empty section ends at an immediately following heading
        # instead of swallowing it.
        # The terminator stops only at a heading-shaped line rather than at
        # any newline followed by a word character, so multi-line paragraphs
        # and numbered lists ("1. ...\n2. ...") are kept whole.
        pattern = (
            re.escape(heading)
            + r":[ \t]*(?=\n)(.*?)(?=\n\w[^\n]*:[ \t]*(?:\n|\Z)|\Z)"
        )
        match = re.search(pattern, text, re.DOTALL)
        return clean_text_response(match.group(1)) if match else ""
    except Exception as e:
        logger.error(f"Section extraction failed: {e}")
        return ""
# Shared TxAgent handle; None until startup_event assigns a loaded agent.
agent: Optional["TxAgent"] = None
@app.on_event("startup")
async def startup_event():
    """Construct and load the TxAgent model when the application starts.

    The module-level ``agent`` is assigned only after ``init_model()``
    returns, so a failure at any step leaves it untouched (``None`` on first
    start) instead of publishing a half-initialised agent that the endpoints
    would then use. Errors are logged with a traceback and not re-raised,
    so the server still starts.
    """
    global agent
    try:
        candidate = TxAgent(
            model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
            rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
            enable_finish=True,
            enable_rag=False,
            force_finish=True,
            enable_checker=True,
            step_rag_num=4,
            seed=100,
        )
        candidate.chat_prompt = (
            "You are a clinical decision support assistant for doctors. "
            "You analyze patient documents, detect medical issues, identify missed diagnoses, "
            "and provide treatment suggestions with rationale in concise, readable language."
        )
        candidate.init_model()
    except Exception as e:
        # Deliberately best-effort: keep serving, but record the traceback.
        logger.exception(f"Startup error: {str(e)}")
        return
    agent = candidate
    logger.info("TxAgent initialized successfully")
@app.post("/chat")
async def chat_endpoint(request: ChatRequest):
    """Send one chat turn to TxAgent and return the reply in the requested format.

    All renderings ("raw", "clean", "structured", "html") are built. The one
    named by ``request.format`` is returned; a missing or unknown name falls
    back to "clean". The response's ``format`` field reports the rendering
    actually returned.

    Raises:
        HTTPException: 503 if the model has not been loaded; 500 if
            generation or formatting fails.
    """
    # Checked outside the try so the 503 is not re-wrapped as a 500 below.
    if agent is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    try:
        # NOTE(review): agent.chat is a plain (non-awaited) call inside this
        # coroutine, so the event loop is blocked until it returns. Consider
        # fastapi.concurrency.run_in_threadpool once TxAgent is confirmed
        # safe to call from worker threads.
        raw_response = agent.chat(
            message=request.message,
            history=request.history,
            temperature=request.temperature,
            max_new_tokens=request.max_new_tokens,
        )
        formatted_response = {
            "raw": raw_response,
            "clean": clean_text_response(raw_response),
            "structured": structure_medical_response(raw_response),
            "html": markdown.markdown(raw_response),
        }
        response_format = (
            request.format if request.format in formatted_response else "clean"
        )
        return JSONResponse({
            "status": "success",
            "format": response_format,
            "response": formatted_response[response_format],
            "timestamp": datetime.now().isoformat(),
            "available_formats": list(formatted_response.keys()),
        })
    except Exception as e:
        logger.exception(f"Chat error: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.post("/upload")
async def upload_file(file: UploadFile = File(...)):
    """Analyse an uploaded patient document with TxAgent.

    Files whose name ends in ``.pdf`` (any letter case) are read page by page
    with PyPDF2; anything else is decoded as UTF-8 with undecodable bytes
    dropped. Only the first ``max_chars`` characters of the extracted text
    go into the prompt.

    Returns:
        JSONResponse whose ``response`` field is the clean-text rendering of
        the model output.

    Raises:
        HTTPException: 503 if the model has not been loaded, 400 if no text
            could be extracted from the file, 500 on any other failure.
    """
    max_chars = 10_000
    try:
        logger.info(f"File upload received: {file.filename}")
        if agent is None:
            raise HTTPException(status_code=503, detail="Model not loaded")

        # filename may be absent; compare case-insensitively so "REPORT.PDF"
        # is not decoded as UTF-8 text.
        if (file.filename or "").lower().endswith(".pdf"):
            pdf_reader = PyPDF2.PdfReader(file.file)
            # Separate pages with a newline so words at page boundaries do
            # not run together.
            content = "\n".join(
                page.extract_text() or "" for page in pdf_reader.pages
            )
        else:
            raw_bytes = await file.read()
            content = raw_bytes.decode("utf-8", errors="ignore")

        if not content.strip():
            # Presumably an image-only PDF or an empty file; an empty document
            # would only yield a meaningless analysis.
            raise HTTPException(
                status_code=400,
                detail="No readable text could be extracted from the uploaded file",
            )
        if len(content) > max_chars:
            logger.warning(
                f"Upload {file.filename}: truncating {len(content)} characters to {max_chars}"
            )

        message = f"""
You are a clinical decision support AI assisting physicians.
Given the following patient report, do the following:
1. Summarize the patient's main conditions and history.
2. Identify any potential clinical risks or red flags.
3. Highlight any important diagnoses or treatments the doctor might have missed.
4. Suggest next clinical steps, treatments, or referrals (if applicable).
5. Flag anything that could pose an urgent risk (e.g., suicide risk, untreated critical conditions).
Patient Document:
-----------------
{content[:max_chars]}
"""
        raw_response = agent.chat(message=message, history=[], temperature=0.7, max_new_tokens=1024)
        formatted_response = {
            "raw": raw_response,
            "clean": clean_text_response(raw_response),
            "structured": structure_medical_response(raw_response),
            "html": markdown.markdown(raw_response),
        }
        return JSONResponse({
            "status": "success",
            "format": "clean",
            "response": formatted_response["clean"],
            "timestamp": datetime.now().isoformat(),
            "available_formats": list(formatted_response.keys()),
        })
    except HTTPException:
        # Deliberate 4xx/5xx responses pass through unchanged.
        raise
    except Exception as e:
        logger.exception(f"File upload error: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        file.file.close()
@app.get("/status")
async def status():
    """Report liveness, API version, the loaded model's name and a timestamp.

    ``model`` is "not loaded" while the module-level ``agent`` is ``None``.
    """
    return {
        "status": "running",
        # Read from the app so it cannot drift from the version given to
        # FastAPI(...).
        "version": app.version,
        "model": agent.model_name if agent is not None else "not loaded",
        "timestamp": datetime.now().isoformat(),
    }
|