File size: 6,236 Bytes
f126604
d377221
f126604
5620229
7e095f4
60e4c3d
 
064da3f
7757822
f126604
7757822
5620229
064da3f
f126604
60e4c3d
7e095f4
 
bdcc052
7e095f4
 
 
60e4c3d
520104a
 
 
 
60e4c3d
7e095f4
 
 
 
 
520104a
60e4c3d
 
f126604
60e4c3d
f126604
 
 
 
 
 
 
 
60e4c3d
5620229
 
 
 
 
60e4c3d
5620229
60e4c3d
5620229
 
 
 
 
 
 
60e4c3d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f126604
 
 
 
 
 
7e095f4
f3089ba
7e095f4
 
 
 
 
 
f3089ba
7e095f4
60e4c3d
 
 
 
 
7e095f4
bdcc052
7e095f4
60e4c3d
bdcc052
7757822
 
 
5620229
7757822
 
 
 
 
5620229
 
 
 
 
 
7757822
 
5620229
60e4c3d
5620229
 
7757822
 
bdcc052
f126604
 
064da3f
 
 
 
 
60e4c3d
064da3f
 
 
 
 
60e4c3d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
064da3f
 
 
 
 
 
 
 
 
60e4c3d
064da3f
 
 
 
 
 
 
 
 
f126604
60e4c3d
bdcc052
f3089ba
60e4c3d
d377221
7e095f4
60e4c3d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
import os
import sys
import json
import re
import logging
from datetime import datetime
from typing import List, Dict, Optional
from fastapi import FastAPI, HTTPException, UploadFile, File
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import markdown
import PyPDF2

# Setup logging: a single process-wide configuration at INFO level whose
# records carry timestamp, logger name and level ahead of the message.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger("TxAgentAPI")

# Adjust sys path: the ``src`` folder sitting next to this file is placed at
# the front of the import search path, so it is consulted first.
current_dir = os.path.dirname(os.path.abspath(__file__))
src_path = os.path.abspath(os.path.join(current_dir, "src"))
sys.path.insert(0, src_path)

# Import TxAgent. The service cannot work without it, so an import failure is
# logged and then propagated unchanged to abort module loading.
try:
    from txagent.txagent import TxAgent
except ImportError as import_err:
    logger.error(f"Failed to import TxAgent: {str(import_err)}")
    raise

# Init FastAPI
app = FastAPI(version="2.1.0", title="TxAgent API")

# CORS: every origin, HTTP method and request header is accepted, and
# credentials are allowed.
# NOTE(review): a wildcard origin combined with allow_credentials=True is very
# permissive - confirm this is intended before exposing the API publicly.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# Request schema
class ChatRequest(BaseModel):
    """Request body accepted by the ``/chat`` endpoint."""

    # Text of the user's message; forwarded to ``agent.chat`` as ``message``.
    message: str
    # Forwarded to ``agent.chat`` as ``temperature``.
    temperature: float = 0.7
    # Forwarded to ``agent.chat`` as ``max_new_tokens``.
    max_new_tokens: int = 512
    # Prior conversation turns, forwarded to ``agent.chat`` as ``history``.
    # NOTE(review): the expected shape of each dict is defined by TxAgent - confirm.
    history: Optional[List[Dict]] = None
    # Which rendering of the reply to return. The endpoint offers "raw",
    # "clean", "structured" and "html"; any other value yields "clean".
    format: Optional[str] = "clean"

# Response formatting
def clean_text_response(text: str) -> str:
    """Normalise whitespace in ``text`` and drop markdown emphasis markers.

    Steps, in order:
      1. Any run of newline / whitespace-only content / newline becomes
         exactly one blank line.
      2. Runs of space characters become a single space.
      3. Every ``**`` and ``__`` marker is removed.
      4. Leading and trailing whitespace is stripped.

    Args:
        text: The raw text to tidy up.

    Returns:
        The cleaned text.
    """
    blank_lines_collapsed = re.sub(r'\n\s*\n', '\n\n', text)
    cleaned = re.sub(r'[ ]+', ' ', blank_lines_collapsed)
    # Markers are removed after space collapsing, exactly as the steps above say.
    for marker in ("**", "__"):
        cleaned = cleaned.replace(marker, "")
    return cleaned.strip()

def structure_medical_response(text: str) -> Dict:
    """Split a model reply into its four named sections.

    Each output key is filled by looking up one fixed heading in ``text`` via
    ``extract_section``.

    Args:
        text: The full reply text to split.

    Returns:
        A dict with the keys ``summary``, ``risks``, ``missed_issues`` and
        ``recommendations``, in that order, each holding whatever
        ``extract_section`` returned for the matching heading.
    """
    # (output key, heading searched for in the reply)
    key_to_heading = (
        ("summary", "Summary"),
        ("risks", "Risks or Red Flags"),
        ("missed_issues", "What the doctor might have missed"),
        ("recommendations", "Suggested Clinical Actions"),
    )
    return {key: extract_section(text, heading) for key, heading in key_to_heading}

def extract_section(text: str, heading: str) -> str:
    """Return the cleaned body found under ``heading`` in ``text``.

    A section is recognised as the literal ``heading`` immediately followed by
    a colon and a newline. Its body runs up to, but not including, the first
    newline that is directly followed by a word character, or to the end of
    ``text`` if there is no such newline.

    Args:
        text: The text to search.
        heading: The heading to look for. It is matched literally; regex
            metacharacters in it have no special meaning.

    Returns:
        The section body passed through ``clean_text_response``, or an empty
        string when the heading is not found or extraction fails.
    """
    try:
        # Escape the heading so characters such as ".", "(" or "?" match
        # themselves instead of altering the pattern (or making it invalid).
        pattern = rf"{re.escape(heading)}:\n(.*?)(?=\n\w|\Z)"
        # DOTALL lets the body span several lines until the lookahead stops it.
        match = re.search(pattern, text, re.DOTALL)
        return clean_text_response(match.group(1)) if match else ""
    except Exception as e:
        # Best effort: a failed extraction yields an empty section, not an error.
        logger.error(f"Section extraction failed: {e}")
        return ""

# Agent init
# Shared TxAgent instance used by the endpoints. It stays ``None`` until
# startup has completed successfully.
agent = None

@app.on_event("startup")
async def startup_event():
    """Create and initialise the shared TxAgent when the application starts.

    The agent is built with fixed model names and options, given its chat
    prompt, and then has ``init_model()`` called on it. Only after all of
    that succeeds is it stored in the module-level ``agent``, so a failed
    initialisation never leaves a half-built agent visible to the endpoints.

    Failures are logged and swallowed so the API process still starts.
    """
    global agent
    try:
        new_agent = TxAgent(
            model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
            rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
            enable_finish=True,
            enable_rag=False,
            force_finish=True,
            enable_checker=True,
            step_rag_num=4,
            seed=100
        )
        new_agent.chat_prompt = (
            "You are a clinical decision support assistant for doctors. "
            "You analyze patient documents, detect medical issues, identify missed diagnoses, "
            "and provide treatment suggestions with rationale in concise, readable language."
        )
        new_agent.init_model()
    except Exception as e:
        # Best effort by design: keep the service up and leave ``agent`` as it was.
        # ``logger.exception`` adds the traceback to the same message.
        logger.exception(f"Startup error: {str(e)}")
        return

    # Publish the agent only once it is fully initialised.
    agent = new_agent
    logger.info("TxAgent initialized successfully")

@app.post("/chat")
async def chat_endpoint(request: ChatRequest):
    """Send a chat message to the agent and return the reply in one format.

    Args:
        request: The validated chat request. Its ``format`` field selects the
            rendering: "raw", "clean", "structured" or "html". Any other
            value, including ``None``, yields the "clean" rendering.

    Returns:
        A JSON response with ``status``, the ``format`` actually served, the
        rendered ``response``, a ``timestamp`` and the ``available_formats``.

    Raises:
        HTTPException: 503 when no agent is loaded; 500 when the agent call
            or the rendering fails.
    """
    # Checked outside the try block so the generic handler below does not
    # rewrite this 503 into a 500.
    if agent is None:
        raise HTTPException(status_code=503, detail="TxAgent model is not loaded")

    # One renderer per format; only the selected one is run.
    renderers = {
        "raw": lambda text: text,
        "clean": clean_text_response,
        "structured": structure_medical_response,
        "html": markdown.markdown,
    }
    # Unknown or missing formats fall back to "clean", and the response
    # reports the format that was really used.
    selected_format = request.format if request.format in renderers else "clean"

    try:
        # NOTE(review): agent.chat is called directly inside this coroutine; if
        # it blocks, other requests wait meanwhile - confirm that is acceptable.
        raw_response = agent.chat(
            message=request.message,
            history=request.history,
            temperature=request.temperature,
            max_new_tokens=request.max_new_tokens
        )
        return JSONResponse({
            "status": "success",
            "format": selected_format,
            "response": renderers[selected_format](raw_response),
            "timestamp": datetime.now().isoformat(),
            "available_formats": list(renderers.keys())
        })
    except Exception as e:
        # ``logger.exception`` keeps the same message and adds the traceback.
        logger.exception(f"Chat error: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))

@app.post("/upload")
async def upload_file(file: UploadFile = File(...)):
    """Analyse an uploaded patient document with the agent.

    Files whose name ends in ".pdf" (any letter case) are read with PyPDF2 and
    the text of all pages is concatenated. Any other file is decoded as UTF-8,
    ignoring undecodable bytes. The first 10000 characters of the text are
    embedded in a fixed prompt and sent to the agent.

    Args:
        file: The uploaded file.

    Returns:
        A JSON response with ``status``, ``format`` ("clean"), the cleaned
        agent reply in ``response``, a ``timestamp`` and ``available_formats``.

    Raises:
        HTTPException: 503 when no agent is loaded; 400 when no text could be
            read from the file; 500 for any other failure.
    """
    try:
        logger.info(f"File upload received: {file.filename}")
        if agent is None:
            raise HTTPException(status_code=503, detail="TxAgent model is not loaded")

        # The filename may be missing, and extensions are not case-sensitive.
        filename = file.filename or ""
        if filename.lower().endswith(".pdf"):
            pdf_reader = PyPDF2.PdfReader(file.file)
            # Pages without extractable text contribute an empty string.
            content = "".join(page.extract_text() or "" for page in pdf_reader.pages)
        else:
            raw_bytes = await file.read()
            content = raw_bytes.decode("utf-8", errors="ignore")

        # Never ask the model to analyse an empty document.
        if not content.strip():
            raise HTTPException(
                status_code=400,
                detail="No readable text could be extracted from the uploaded file"
            )

        message = f"""
        You are a clinical decision support AI assisting physicians.

        Given the following patient report, do the following:
        1. Summarize the patient's main conditions and history.
        2. Identify any potential clinical risks or red flags.
        3. Highlight any important diagnoses or treatments the doctor might have missed.
        4. Suggest next clinical steps, treatments, or referrals (if applicable).
        5. Flag anything that could pose an urgent risk (e.g., suicide risk, untreated critical conditions).

        Patient Document:
        -----------------
        {content[:10000]}
        """

        raw_response = agent.chat(message=message, history=[], temperature=0.7, max_new_tokens=1024)
        # Only the "clean" rendering is returned, so only that one is computed.
        return JSONResponse({
            "status": "success",
            "format": "clean",
            "response": clean_text_response(raw_response),
            "timestamp": datetime.now().isoformat(),
            "available_formats": ["raw", "clean", "structured", "html"]
        })
    except HTTPException:
        # Deliberate HTTP errors raised above pass through unchanged.
        raise
    except Exception as e:
        # ``logger.exception`` keeps the same message and adds the traceback.
        logger.exception(f"File upload error: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
    finally:
        # Always release the underlying file object.
        file.file.close()

@app.get("/status")
async def status():
    """Report that the service is running and which model, if any, is loaded.

    Returns:
        A dict with ``status``, ``version``, ``model`` (the agent's
        ``model_name``, or "not loaded" when the agent is falsy) and a
        ``timestamp`` in ISO format.
    """
    if agent:
        model_label = agent.model_name
    else:
        model_label = "not loaded"

    payload = {
        "status": "running",
        "version": "2.1.0",
        "model": model_label,
    }
    payload["timestamp"] = datetime.now().isoformat()
    return payload