File size: 1,953 Bytes
a59d348
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import os
import logging
import asyncio
from fastapi import FastAPI
from txagent.txagent import TxAgent
from db.mongo import get_mongo_client

# Logging
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger("TxAgentAPI")

# Globals: None at import time; assigned by the startup hook registered in setup_app().
agent = None
patients_collection = None
analysis_collection = None
alerts_collection = None
users_collection = None

# JWT settings (must match CPS-API)
# The fallback is a publicly known placeholder. With a symmetric algorithm such as
# HS256, anyone who knows the key can produce valid signatures, so warn when the
# real secret has not been configured. The fallback value itself is kept unchanged
# for backward compatibility.
_PLACEHOLDER_SECRET_KEY = "your-secret-key"
SECRET_KEY = os.getenv("SECRET_KEY", _PLACEHOLDER_SECRET_KEY)
if SECRET_KEY == _PLACEHOLDER_SECRET_KEY:
    logger.warning(
        "SECRET_KEY is not set; falling back to an insecure placeholder. "
        "Set SECRET_KEY to the value used by CPS-API."
    )
ALGORITHM = "HS256"

def setup_app(app: FastAPI):
    """Register the TxAgent API startup hook on ``app``.

    On application startup the registered hook:
      * builds a ``TxAgent`` with the fixed model configuration below and calls
        ``init_model()`` on it,
      * assigns the module-level MongoDB collection globals from the ``cps_db``
        database,
      * schedules a background task that runs ``analyze_patient`` over every
        document in the ``patients`` collection.

    Args:
        app: The FastAPI application to attach the startup hook to.
    """
    # asyncio keeps only weak references to tasks, so a fire-and-forget task with
    # no other reference can be garbage-collected before it finishes. Hold strong
    # references here and drop each one when its task completes.
    background_tasks = set()

    @app.on_event("startup")
    async def startup_event():
        global agent, patients_collection, analysis_collection, alerts_collection, users_collection

        agent = TxAgent(
            model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
            rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
            enable_finish=True,
            enable_rag=False,
            force_finish=True,
            enable_checker=True,
            step_rag_num=4,
            seed=42
        )
        agent.chat_prompt = (
            "You are a clinical assistant AI. Analyze the patient's data and provide clear clinical recommendations."
        )
        # NOTE(review): init_model() is not awaited, so any heavy model loading runs
        # on the event-loop thread and delays startup; confirm this is acceptable.
        agent.init_model()
        logger.info("✅ TxAgent initialized")

        db = get_mongo_client()["cps_db"]
        users_collection = db["users"]
        patients_collection = db["patients"]
        analysis_collection = db["patient_analysis_results"]
        alerts_collection = db["clinical_alerts"]
        logger.info("📡 Connected to MongoDB")

        task = asyncio.create_task(analyze_all_patients())
        background_tasks.add(task)
        task.add_done_callback(background_tasks.discard)

    async def analyze_all_patients():
        """Run ``analyze_patient`` on every stored patient, one at a time.

        Failures are logged and skipped so that one bad record does not abort
        the rest of the batch. Cancellation still propagates, because
        ``asyncio.CancelledError`` is not an ``Exception`` subclass.
        """
        try:
            patients = await patients_collection.find({}).to_list(length=None)
        except Exception:
            logger.exception("Failed to load patients for background analysis")
            return

        for patient in patients:
            try:
                # NOTE(review): analyze_patient is neither defined nor imported in
                # this module as shown. Confirm it is provided; otherwise every call
                # raises NameError, which is logged below.
                await analyze_patient(patient)
            except Exception:
                patient_id = patient.get("_id") if isinstance(patient, dict) else None
                logger.exception("Analysis failed for patient %s", patient_id)
            # Short pause between patients (presumably to throttle the batch).
            await asyncio.sleep(0.1)

        logger.info("Background analysis finished for %d patients", len(patients))