File size: 6,024 Bytes
f126604
d377221
f126604
7e095f4
ab172ce
60e4c3d
 
dfff005
ab172ce
f0898a3
f126604
7757822
ab172ce
7e095f4
dfff005
ab172ce
 
f0898a3
ab172ce
 
dfff005
f0898a3
60e4c3d
f126604
 
 
ab172ce
 
f126604
 
f0898a3
5620229
 
 
 
 
60e4c3d
5620229
ab172ce
 
 
 
 
 
 
5620229
 
ab172ce
dfff005
 
 
f0898a3
6c1d81c
f0898a3
ab172ce
6c1d81c
dfff005
5620229
 
60e4c3d
4edd370
60e4c3d
 
ab172ce
60e4c3d
 
ab172ce
 
 
 
 
dfff005
ab172ce
60e4c3d
f0898a3
 
 
 
 
dfff005
ab172ce
 
f0898a3
ab172ce
 
 
 
dfff005
 
 
 
 
ab172ce
f0898a3
ab172ce
 
 
 
 
f0898a3
ab172ce
dfff005
 
f0898a3
60e4c3d
dfff005
60e4c3d
ab172ce
 
 
 
 
f126604
 
 
f0898a3
ab172ce
dfff005
 
 
 
ab172ce
dfff005
ab172ce
 
 
dfff005
 
ab172ce
dfff005
 
ab172ce
 
f0898a3
 
 
ab172ce
f0898a3
ab172ce
 
 
 
 
 
 
 
bdcc052
ea3d9f9
ab172ce
 
ea3d9f9
ab172ce
ea3d9f9
ab172ce
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dfff005
ea3d9f9
 
 
ab172ce
 
f126604
ab172ce
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
import os
import sys
import json
import logging
import re
from datetime import datetime
from typing import List, Dict, Optional

from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import asyncio

from txagent.txagent import TxAgent
from db.mongo import get_mongo_client

# Logging: configure the root logger at import time and fetch this service's
# named logger.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger("TxAgentAPI")

# App
app = FastAPI(version="2.1.0", title="TxAgent API")

# NOTE(review): all origins, methods and headers are allowed *together with*
# credentials. Depending on how Starlette's CORSMiddleware handles this, any
# site may be able to make credentialed cross-origin requests; restrict
# allow_origins before exposing the API beyond trusted clients.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)

# Pydantic: request body accepted by POST /chat-stream.
class ChatRequest(BaseModel):
    # The new user message; appended as the last turn of the conversation.
    message: str
    # Passed to generate() as the sampling temperature; no range validation here.
    temperature: float = 0.7
    # Passed to generate() as the cap on newly generated tokens.
    max_new_tokens: int = 512
    # Prior turns inserted between the system prompt and the new message.
    # Presumably {"role": ..., "content": ...} dicts for the tokenizer's chat
    # template — not validated here; TODO confirm with clients.
    history: Optional[List[Dict]] = None
    # NOTE(review): not read by any code visible in this file.
    format: Optional[str] = "clean"

# Globals: module-level state filled in by startup_event() and None until
# startup has run. `agent` is read by analyze_patient() and the chat endpoint;
# the two collections are read/written by the patient-analysis helpers.
agent, patients_collection, analysis_collection = None, None, None

# Helpers
def clean_text_response(text: str) -> str:
    """Normalise model output for plain-text display.

    Collapses any run of blank (or whitespace-only) lines into one blank line,
    squeezes repeated spaces into one, removes the Markdown emphasis markers
    ``**`` and ``__``, and trims surrounding whitespace.
    """
    normalised = re.sub(r'\n\s*\n', '\n\n', text)
    normalised = re.sub(r'[ ]+', ' ', normalised)
    for marker in ("**", "__"):
        normalised = normalised.replace(marker, "")
    return normalised.strip()

def extract_section(text: str, heading: str) -> str:
    """Return the body of the ``heading:`` section of ``text``, or "" if absent.

    The heading is matched case-insensitively and must be followed by a colon
    and a line break. The section body runs until the next line that starts
    with an uppercase letter and contains a colon (taken to be the next
    heading), or until the end of the text. The result is whitespace-stripped.
    Any error during matching is logged and "" is returned.
    """
    try:
        # Only the heading is case-insensitive (scoped ``(?i:...)``). Applying
        # re.IGNORECASE to the whole pattern would make ``[A-Z]`` in the
        # terminator match lowercase too, so an ordinary line such as
        # "note: ..." would be mistaken for the next heading.
        pattern = rf"(?i:{re.escape(heading)}):\s*\n(.*?)(?=\n[A-Z][^\n]*:|\Z)"
        match = re.search(pattern, text, re.DOTALL)
        return match.group(1).strip() if match else ""
    except Exception as e:
        logger.error(f"Section extraction failed for heading '{heading}': {e}")
        return ""

def structure_medical_response(text: str) -> Dict:
    """Split a TxAgent reply into named clinical sections.

    Returns a dict with the keys "summary", "risks", "missed_issues" and
    "recommendations" (in that order). Each value is what extract_section()
    finds under the corresponding heading, or "" when it finds nothing.
    """
    heading_for_key = (
        ("summary", "Summarize the patient's medical history"),
        ("risks", "Risks or Red Flags"),
        ("missed_issues", "What the doctor might have missed"),
        ("recommendations", "Suggested Clinical Actions"),
    )
    return {key: extract_section(text, heading) for key, heading in heading_for_key}

def serialize_patient(patient: dict) -> dict:
    """Return a shallow copy of ``patient`` with ``_id`` converted to ``str``.

    The input is not modified. ``_id`` (e.g. a Mongo ObjectId) is stringified
    so the copy can be passed to json.dumps; documents without ``_id`` are
    copied unchanged.
    """
    clone = patient.copy()
    if "_id" not in clone:
        return clone
    clone["_id"] = str(clone["_id"])
    return clone

async def analyze_patient(patient: dict):
    """Run TxAgent on one patient document and upsert the result.

    The document is serialised to JSON and embedded (truncated to 10,000
    characters) in a prompt that asks the model to answer under the fixed
    headings that structure_medical_response() extracts. The raw reply and the
    extracted sections are upserted into ``analysis_collection``, keyed by the
    patient's ``fhir_id``.

    Errors are logged (with traceback), not raised, so one bad patient does not
    stop a batch. Patients without a ``fhir_id`` are skipped: upserting them all
    under ``patient_id=None`` would make them overwrite each other's analysis.
    """
    try:
        serialized = serialize_patient(patient)
        patient_id = serialized.get("fhir_id")
        if patient_id is None:
            logger.warning(f"Skipping patient without fhir_id (_id={serialized.get('_id')})")
            return

        # default=str: Mongo documents can hold values json cannot encode
        # natively (e.g. datetime). A string rendering is enough for the prompt
        # and avoids failing the whole analysis with a TypeError.
        doc = json.dumps(serialized, indent=2, default=str)
        logger.info(f"🧾 Analyzing patient: {patient_id}")
        logger.debug(f"🧠 Data passed to TxAgent:\n{doc[:1000]}")

        # These headings must match the ones structure_medical_response() passes
        # to extract_section(); otherwise every extracted section comes back empty.
        message = (
            "You are a clinical decision support AI.\n\n"
            "Given the patient document below, answer using exactly these four "
            "section headings, in this order. Put each heading on its own line "
            "followed by a colon, and start the section content on the next line:\n"
            "Summarize the patient's medical history:\n"
            "Risks or Red Flags:\n"
            "What the doctor might have missed:\n"
            "Suggested Clinical Actions:\n"
            f"\nPatient Document:\n{'-'*40}\n{doc[:10000]}"
        )

        # NOTE(review): agent.chat() is called synchronously inside a coroutine,
        # so the event loop is blocked while the model runs. Moving it to a
        # thread would likely need model access to be serialised first.
        raw = agent.chat(message=message, history=[], temperature=0.7, max_new_tokens=1024)
        structured = structure_medical_response(raw)

        analysis_doc = {
            "patient_id": patient_id,
            "timestamp": datetime.utcnow(),
            "summary": structured,
            "raw": raw,
        }
        await analysis_collection.update_one(
            {"patient_id": patient_id},
            {"$set": analysis_doc},
            upsert=True,
        )
        logger.info(f"✅ Stored analysis for patient {patient_id}")
    except Exception as e:
        logger.exception(f"Error analyzing patient: {e}")

async def analyze_all_patients():
    """Run analyze_patient() over every document in ``patients_collection``.

    This is scheduled as a fire-and-forget background task by startup_event().
    A failure while loading patients is therefore logged here; otherwise it
    would stay inside the unobserved task. Per-patient errors are already
    handled by analyze_patient().
    """
    try:
        # Load the whole result set up front instead of iterating the cursor:
        # each analysis can take a long time, and a server-side cursor left
        # idle between batches can time out.
        patients = await patients_collection.find({}).to_list(length=None)
    except Exception:
        logger.exception("Failed to load patients for analysis")
        return

    logger.info(f"Starting analysis of {len(patients)} patient(s)")
    for patient in patients:
        await analyze_patient(patient)
        # Yield to the event loop between patients.
        await asyncio.sleep(0.1)
    logger.info("Finished analyzing all patients")

@app.on_event("startup")
async def startup_event():
    """Load TxAgent, connect to MongoDB, and start the background analysis.

    Populates the module-level ``agent``, ``patients_collection`` and
    ``analysis_collection`` globals. It then schedules analyze_all_patients()
    as a task instead of awaiting it, so startup completes without waiting for
    the analysis run.
    """
    global agent, patients_collection, analysis_collection

    agent = TxAgent(
        model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
        rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
        enable_finish=True,
        enable_rag=False,
        force_finish=True,
        enable_checker=True,
        step_rag_num=4,
        seed=42
    )
    # Also used as the system message by the /chat-stream endpoint.
    agent.chat_prompt = (
        "You are a clinical assistant AI. Analyze the patient's data and provide clear clinical recommendations."
    )
    agent.init_model()
    logger.info("✅ TxAgent initialized")

    db = get_mongo_client()["cps_db"]
    patients_collection = db["patients"]
    analysis_collection = db["patient_analysis_results"]
    logger.info("📡 Connected to MongoDB")

    # The event loop keeps only weak references to tasks, so an unreferenced
    # task can be garbage-collected before it finishes. Keep a strong
    # reference on app.state for the lifetime of the application.
    app.state.patient_analysis_task = asyncio.create_task(analyze_all_patients())

@app.get("/status")
async def status():
    """Report that the API is running, with the current (naive) UTC timestamp."""
    payload = {"status": "running"}
    payload["timestamp"] = datetime.utcnow().isoformat()
    return payload

@app.post("/chat-stream")
async def chat_stream_endpoint(request: ChatRequest):
    """Reply to a chat message, streaming the completion back as plain text.

    The conversation is the agent's system prompt, then ``request.history``
    (if any), then ``request.message``. The whole completion is generated
    first and then emitted in word-sized chunks with a 50 ms pause between
    them, so output appears incrementally even though generation is not.
    Errors are logged and reported in-band as a final "⚠️ Error: ..." chunk,
    because the response has already started by the time they can occur.
    """
    async def token_stream():
        try:
            conversation = [{"role": "system", "content": agent.chat_prompt}]
            if request.history:
                conversation.extend(request.history)
            conversation.append({"role": "user", "content": request.message})

            input_ids = agent.tokenizer.apply_chat_template(
                conversation, add_generation_prompt=True, return_tensors="pt"
            ).to(agent.device)

            # Assuming a Hugging Face model (as these generate() kwargs suggest),
            # sampling with temperature <= 0 raises. Treat a non-positive
            # temperature as a request for greedy decoding instead.
            if request.temperature > 0:
                sampling_kwargs = {"do_sample": True, "temperature": request.temperature}
            else:
                sampling_kwargs = {"do_sample": False}

            # NOTE(review): generate() runs synchronously and blocks the event
            # loop for the whole generation.
            output = agent.model.generate(
                input_ids,
                max_new_tokens=request.max_new_tokens,
                pad_token_id=agent.tokenizer.eos_token_id,
                return_dict_in_generate=True,
                **sampling_kwargs,
            )

            text = agent.tokenizer.decode(
                output["sequences"][0][input_ids.shape[1]:], skip_special_tokens=True
            ).strip()
            # Each chunk is one word plus the whitespace that follows it, so
            # newlines and indentation in the reply survive. The chunks
            # concatenate back to exactly `text`.
            for chunk in re.findall(r"\s*\S+\s*", text):
                yield chunk
                await asyncio.sleep(0.05)
        except Exception as e:
            logger.error(f"Streaming error: {e}")
            yield f"⚠️ Error: {e}"

    return StreamingResponse(token_stream(), media_type="text/plain")