File size: 6,154 Bytes
f126604
d377221
f126604
7e095f4
ab172ce
60e4c3d
 
dfff005
ab172ce
ea3d9f9
f126604
7757822
ab172ce
 
 
7e095f4
ab172ce
 
520104a
ab172ce
dfff005
520104a
ab172ce
 
 
 
 
 
dfff005
 
60e4c3d
f126604
 
 
ab172ce
 
f126604
 
ab172ce
5620229
 
 
 
 
60e4c3d
5620229
ab172ce
 
 
 
 
 
 
 
5620229
 
ab172ce
dfff005
 
 
ab172ce
dfff005
ab172ce
 
 
dfff005
5620229
 
60e4c3d
 
 
 
ab172ce
60e4c3d
 
ab172ce
 
 
 
 
dfff005
ab172ce
60e4c3d
ab172ce
dfff005
ab172ce
 
 
 
 
 
 
dfff005
 
 
 
 
ab172ce
 
 
 
 
 
 
dfff005
ab172ce
dfff005
 
ab172ce
60e4c3d
dfff005
60e4c3d
ab172ce
 
 
 
 
f126604
ab172ce
f126604
 
ab172ce
 
 
dfff005
 
 
 
ab172ce
dfff005
ab172ce
 
 
dfff005
 
ab172ce
dfff005
 
ab172ce
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bdcc052
ea3d9f9
ab172ce
 
ea3d9f9
ab172ce
ea3d9f9
ab172ce
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dfff005
ea3d9f9
 
 
ab172ce
 
f126604
ab172ce
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
import os
import sys
import json
import logging
import re
from datetime import datetime
from typing import List, Dict, Optional

from fastapi import FastAPI, HTTPException
from fastapi.responses import JSONResponse, StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from pymongo import MongoClient
from bson import ObjectId
import asyncio

# Adjust sys path
# Prepend this file's "src" directory to the module search path so that the
# project imports below (txagent, db) can be resolved from it — presumably
# that is where those packages live; confirm against the repo layout.
# Must stay above those imports: it only affects imports executed afterwards.
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "src")))

# TxAgent
from txagent.txagent import TxAgent

# MongoDB
from db.mongo import get_mongo_client

# Logging: timestamped, INFO level and above, one shared named logger.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger("TxAgentAPI")

# The FastAPI application; the startup hook and all routes are registered on it.
app = FastAPI(
    title="TxAgent API",
    version="2.1.0",
)

# CORS: accept every origin, method and header.
# NOTE(review): a wildcard origin together with allow_credentials=True can let
# any website issue credentialed cross-origin requests to this API; consider
# restricting allow_origins to the known front-end hosts.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)

# Models
class ChatRequest(BaseModel):
    # Request body for POST /chat-stream.
    # Documented with comments rather than a docstring: pydantic/FastAPI would
    # publish a class docstring as the schema description in the OpenAPI docs.
    message: str  # the new user turn; appended after the system prompt and any history
    temperature: float = 0.7  # forwarded to generate(); sampling is always enabled there
    max_new_tokens: int = 512  # upper bound on generated tokens, forwarded to generate()
    history: Optional[List[Dict]] = None  # prior turns inserted before `message`; presumably {"role", "content"} dicts — TODO confirm
    format: Optional[str] = "clean"  # NOTE(review): not read by /chat-stream in this file

# Globals
# Service-wide state, assigned once by the startup hook (through `global`);
# everything is None until startup has run. The two collections are awaited
# (find().to_list, update_one), so presumably an async Mongo driver such as
# Motor is in use — confirm against db.mongo.get_mongo_client.
agent = mongo_client = patients_collection = analysis_collection = None

# Helpers
def clean_text_response(text: str) -> str:
    """Normalize whitespace in model output and drop Markdown emphasis markers.

    Runs of blank (or whitespace-only) lines collapse to one blank line, runs of
    spaces collapse to a single space, every ``**`` and ``__`` is removed, and
    leading/trailing whitespace is stripped.
    """
    text = re.sub(r'\n\s*\n', '\n\n', text)
    text = re.sub(r'[ ]+', ' ', text)
    return text.replace("**", "").replace("__", "").strip()


# A line that looks like the start of the next "Heading:" section: not indented,
# not a "- " / "* " / "• " bullet, and ending in a colon (optionally followed by
# Markdown emphasis markers). Trade-off: a content line that itself ends with a
# colon (e.g. "Current medications:") is also treated as a section boundary.
_SECTION_BREAK = r"(?![ \t]|[-*•][ \t])[^\n]*:[ \t*_]*(?:\n|\Z)"


def extract_section(
    text: str, heading: str, stop_headings: Optional[List[str]] = None
) -> str:
    """Return the cleaned body of the ``heading:`` section of *text*.

    The section starts right after the first occurrence of *heading* followed
    by a colon (Markdown ``**``/``__`` between the title and the colon are
    tolerated, and content may start on the same line). It ends just before the
    next heading-like line, or at the end of *text*. A heading-like line is
    either a non-indented, non-bullet line ending in a colon, or — when
    *stop_headings* is given — a line starting (after optional Markdown or
    "1." numbering) with one of those titles followed by a colon.

    Args:
        text: Model output to search.
        heading: Section title; matched literally and case-sensitively.
        stop_headings: Optional titles that always terminate a section, even
            when the model writes content on the same line as the title.

    Returns:
        The section body passed through ``clean_text_response``, or ``""``
        when the heading is absent or extraction fails.
    """
    try:
        boundaries = [_SECTION_BREAK]
        if stop_headings:
            names = "|".join(re.escape(name) for name in stop_headings)
            # Known headings end the section even with same-line content,
            # e.g. "Risks or Red Flags: none".
            boundaries.append(rf"[^\w\n]*(?:\d+[.)][^\w\n]*)?(?:{names})[*_]*:")
        # The body group starts right after the colon (not after a newline), so a
        # heading followed immediately by the next heading yields an empty body
        # instead of swallowing the following section.
        pattern = (
            re.escape(heading)
            + r"[*_]*:[ \t]*(.*?)"
            + r"(?=\n(?:" + "|".join(boundaries) + r")|\Z)"
        )
        match = re.search(pattern, text, re.DOTALL)
        return clean_text_response(match.group(1)) if match else ""
    except Exception as e:
        # Best effort: a malformed input (e.g. non-string text) yields "".
        logger.error(f"Section extraction failed: {e}")
        return ""


# (result key, section heading) pairs in output order. The headings are the
# section titles expected in the model's reply.
_MEDICAL_SECTIONS = (
    ("summary", "Summary"),
    ("risks", "Risks or Red Flags"),
    ("missed_issues", "What the doctor might have missed"),
    ("recommendations", "Suggested Clinical Actions"),
)


def structure_medical_response(text: str) -> Dict:
    """Split a model reply into the four clinical-analysis sections.

    Returns a dict with keys ``summary``, ``risks``, ``missed_issues`` and
    ``recommendations``; a section not found in *text* maps to ``""``. All
    four headings are passed as stop markers so one section never runs into
    the next.
    """
    headings = [heading for _, heading in _MEDICAL_SECTIONS]
    return {
        key: extract_section(text, heading, stop_headings=headings)
        for key, heading in _MEDICAL_SECTIONS
    }

def serialize_patient(patient: dict) -> dict:
    """Return a shallow copy of *patient* whose ``_id`` (if any) is a string.

    The input mapping is never modified; only the top-level ``_id`` key is
    converted with ``str`` so the result can be JSON-encoded.
    """
    serialized = patient.copy()
    if "_id" not in serialized:
        return serialized
    serialized["_id"] = str(serialized["_id"])
    return serialized

async def analyze_patient(patient: dict):
    """Analyze one patient record with the LLM and upsert the result.

    Serializes *patient* to JSON (truncated to 10,000 characters), asks the
    global ``agent`` for a summary, risks, possibly missed issues and next
    steps, splits the reply with ``structure_medical_response`` and upserts it
    into ``analysis_collection`` keyed by the patient's ``fhir_id``. Errors are
    logged and swallowed so one bad record does not abort a batch run.
    """
    try:
        patient_id = patient.get("fhir_id")
        if patient_id is None:
            # Every record lacking a fhir_id would upsert into the same
            # {"patient_id": None} document and overwrite the others' results.
            logger.warning(f"Skipping patient {patient.get('_id')}: no fhir_id")
            return

        # default=str: Mongo documents may hold values the json module cannot
        # encode (datetimes, nested ObjectIds). Without it json.dumps raises
        # TypeError and the whole analysis for this patient is lost.
        doc = json.dumps(serialize_patient(patient), indent=2, default=str)
        # The section headings requested here must match the ones
        # structure_medical_response looks for, otherwise every parsed field
        # comes back empty.
        message = (
            "You are a clinical decision support AI.\n\n"
            "Given the patient document below:\n"
            "1. Summarize their medical history.\n"
            "2. Identify risks or red flags.\n"
            "3. Highlight missed diagnoses or treatments.\n"
            "4. Suggest next clinical steps.\n\n"
            "Format your answer as four sections, in this order, each starting "
            "with its heading on its own line exactly as written:\n"
            "Summary:\n"
            "Risks or Red Flags:\n"
            "What the doctor might have missed:\n"
            "Suggested Clinical Actions:\n"
            f"\nPatient Document:\n{'-'*40}\n{doc[:10000]}"
        )

        # NOTE(review): agent.chat is a synchronous call, so the event loop is
        # blocked for the whole generation. Offloading it to a worker thread
        # would let it run concurrently with /chat-stream on the same model —
        # confirm that is safe before changing it.
        raw = agent.chat(message=message, history=[], temperature=0.7, max_new_tokens=1024)
        structured = structure_medical_response(raw)

        analysis_doc = {
            "patient_id": patient_id,
            "timestamp": datetime.utcnow(),
            "summary": structured,
            "raw": raw,
        }
        await analysis_collection.update_one(
            {"patient_id": patient_id},
            {"$set": analysis_doc},
            upsert=True,
        )
        logger.info(f"✔️ Analysis stored for patient {patient_id}")
    except Exception as e:
        # logger.exception keeps the traceback, which logger.error dropped.
        logger.exception(f"Error analyzing patient: {e}")

async def analyze_all_patients():
    """Analyze every patient in ``patients_collection``, one at a time.

    Runs as a background task. Per-patient failures are handled inside
    ``analyze_patient``. A failure to load the patient list is logged here,
    because an exception escaping a background task is otherwise only reported
    when the task object is destroyed, or never.
    """
    try:
        patients = await patients_collection.find({}).to_list(length=None)
    except Exception:
        logger.exception("Failed to load patients for batch analysis")
        return

    logger.info(f"Starting analysis of {len(patients)} patients")
    for patient in patients:
        await analyze_patient(patient)
        # Brief pause so other coroutines get a turn between patients.
        await asyncio.sleep(0.1)

# Startup logic
@app.on_event("startup")
async def startup_event():
    """Load the model, connect to MongoDB and start the batch analysis.

    Assigns the module-level ``agent``, ``mongo_client``,
    ``patients_collection`` and ``analysis_collection`` globals, then schedules
    ``analyze_all_patients`` as a background task.
    """
    global agent, mongo_client, patients_collection, analysis_collection

    # Init agent
    agent = TxAgent(
        model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
        rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
        enable_finish=True,
        enable_rag=False,
        force_finish=True,
        enable_checker=True,
        step_rag_num=4,
        seed=42
    )
    agent.chat_prompt = (
        "You are a clinical assistant AI. Analyze the patient's data and provide clear clinical recommendations."
    )
    agent.init_model()
    logger.info("✅ TxAgent initialized")

    # MongoDB
    mongo_client = get_mongo_client()
    db = mongo_client.get_default_database()
    patients_collection = db.get_collection("patients")
    analysis_collection = db.get_collection("patient_analysis_results")

    logger.info("📡 Connected to MongoDB")
    # The event loop keeps only weak references to tasks. Hold a strong one on
    # app.state so the background batch cannot be garbage-collected mid-run.
    app.state.analysis_task = asyncio.create_task(analyze_all_patients())

# Endpoints
@app.get("/status")
async def status():
    """Liveness probe: report that the API is up, its version, and the current UTC time."""
    checked_at = datetime.utcnow().isoformat()
    return dict(status="running", version="2.1.0", timestamp=checked_at)

@app.post("/chat-stream")
async def chat_stream_endpoint(request: ChatRequest):
    """Generate a reply to ``request.message`` and stream it back as plain text.

    The model sees the agent's system prompt, then ``request.history`` (if
    any), then the new user message. The full reply is generated first and is
    then emitted piece by piece with a short pause between pieces. Errors are
    logged and reported in-band as a final "⚠️ Error: ..." chunk.
    """
    async def token_stream():
        try:
            conversation = [{"role": "system", "content": agent.chat_prompt}]
            if request.history:
                conversation.extend(request.history)
            conversation.append({"role": "user", "content": request.message})

            input_ids = agent.tokenizer.apply_chat_template(
                conversation, add_generation_prompt=True, return_tensors="pt"
            ).to(agent.device)

            # NOTE(review): generate() is synchronous and runs on the event-loop
            # thread, so no other request is served until it finishes. Moving it
            # to a worker thread would allow concurrent use of the same model —
            # confirm that is safe first.
            output = agent.model.generate(
                input_ids,
                do_sample=True,
                temperature=request.temperature,
                max_new_tokens=request.max_new_tokens,
                pad_token_id=agent.tokenizer.eos_token_id,
                return_dict_in_generate=True
            )

            # Decode only the newly generated tokens, not the echoed prompt.
            new_tokens = output["sequences"][0][input_ids.shape[1]:]
            text = agent.tokenizer.decode(new_tokens, skip_special_tokens=True)
            # Emit each word together with the whitespace that follows it (plus
            # any leading whitespace), so newlines and paragraph breaks in the
            # reply survive. str.split() collapsed them all into single spaces.
            for piece in re.finditer(r"\S+\s*|\s+", text):
                yield piece.group(0)
                await asyncio.sleep(0.05)
        except Exception as e:
            logger.exception(f"Streaming error: {e}")
            yield f"⚠️ Error: {e}"

    return StreamingResponse(token_stream(), media_type="text/plain")