File size: 2,182 Bytes
cee9c20
a59d348
 
 
 
2a02fdb
a59d348
 
cee9c20
a59d348
 
d899239
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a59d348
2a02fdb
 
 
 
 
8766e92
2a02fdb
8766e92
cee9c20
f6aff73
cee9c20
 
d899239
a59d348
392b971
 
 
 
 
 
 
 
d899239
a59d348
 
cee9c20
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import os
import logging
from fastapi import FastAPI
from txagent.txagent import TxAgent
from db.mongo import get_mongo_client
import asyncio

# Logging: root handler at INFO with timestamped records; this module logs
# through the named "TxAgentAPI" logger.
_LOG_FORMAT = "%(asctime)s - %(levelname)s - %(message)s"
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger("TxAgentAPI")

# Initialize the TxAgent when this module is imported. Any failure is logged
# with its full traceback and re-raised, so importing the module fails loudly
# instead of leaving `agent` undefined.
try:
    agent = TxAgent(
        model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
        rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
        enable_finish=True,
        # NOTE(review): RAG is disabled, so rag_model_name / step_rag_num look
        # unused here — confirm against TxAgent before relying on them.
        enable_rag=False,
        force_finish=True,
        enable_checker=True,
        step_rag_num=4,
        seed=42,
    )
    agent.init_model()
    # System-style prompt used by the agent for chat requests.
    agent.chat_prompt = (
        "You are a clinical assistant AI. Analyze the patient's data and provide clear clinical recommendations."
    )
    logger.info("✅ TxAgent initialized synchronously with chat_prompt: %s", agent.chat_prompt)
except Exception as e:
    # logger.exception keeps the traceback; logging only str(e) hid where
    # model construction/loading actually failed.
    logger.exception("❌ Failed to initialize TxAgent: %s", e)
    raise

# Resolve the "cps_db" database once and bind a module-level handle for every
# collection this service touches.
# NOTE(review): patients_collection is later used with `await ...to_list(...)`,
# so the client is presumably async (e.g. Motor). Whether get_mongo_client()
# actually verifies connectivity is not visible here, despite the log text.
db = get_mongo_client()["cps_db"]
users_collection, patients_collection, analysis_collection = (
    db["users"],
    db["patients"],
    db["patient_analysis_results"],
)
chats_collection, alerts_collection, notifications_collection = (
    db["chats"],
    db["clinical_alerts"],
    db["notifications"],
)
logger.info("📡 Connected to MongoDB synchronously")

# JWT settings.
# The fallback secret is a publicly known placeholder: any token signed with it
# can be forged. It is kept only so local development still starts; always set
# SECRET_KEY in deployed environments.
_INSECURE_DEFAULT_SECRET_KEY = "your-secret-key"
SECRET_KEY = os.getenv("SECRET_KEY", _INSECURE_DEFAULT_SECRET_KEY)
if SECRET_KEY == _INSECURE_DEFAULT_SECRET_KEY:
    logger.warning(
        "⚠️ SECRET_KEY is not set; using the insecure default. "
        "Set the SECRET_KEY environment variable in production."
    )
ALGORITHM = "HS256"

# WebSocket endpoint: the Hugging Face Spaces URL when the environment sets
# SPACES=1 (production), otherwise the local development server.
# NOTE(review): the two URLs use different paths ("/queue/join" vs
# "/ws/notifications") — confirm the production path is intended.
WS_URL = (
    "wss://rocketfarmstudios-txagent-api.hf.space/queue/join"
    if os.getenv("SPACES") == "1"
    else "ws://localhost:8000/ws/notifications"
)

def setup_app(app: FastAPI):
    """Register a startup hook that analyzes every stored patient in the background.

    On application startup a single asyncio task is spawned. It loads all
    documents from ``patients_collection`` and passes each one to
    ``analysis.analyze_patient`` in turn, pausing 0.1 s between patients.
    A failure for one patient is logged and the loop continues with the next.

    Args:
        app: The FastAPI application to attach the startup handler to.
    """
    # The event loop holds only weak references to tasks, so an unreferenced
    # task can be garbage-collected mid-run. Keep a strong reference here
    # until the task finishes.
    background_tasks = set()

    @app.on_event("startup")
    async def startup_event():
        task = asyncio.create_task(analyze_all_patients())
        background_tasks.add(task)
        task.add_done_callback(background_tasks.discard)

    async def analyze_all_patients():
        try:
            # Deferred import; presumably avoids an import cycle with
            # `analysis` — verify before hoisting it.
            from analysis import analyze_patient

            patients = await patients_collection.find({}).to_list(length=None)
        except Exception:
            # Nothing awaits this task, so an unhandled error would otherwise
            # only surface as "Task exception was never retrieved".
            logger.exception("❌ Failed to start startup patient analysis")
            return

        for patient in patients:
            try:
                await analyze_patient(patient)
            except Exception:
                # Isolate failures so one bad record doesn't abort the batch.
                patient_id = patient.get("_id") if isinstance(patient, dict) else None
                logger.exception("❌ Failed to analyze patient %s", patient_id)
            # Brief pause between patients (intent presumably throttling — confirm).
            await asyncio.sleep(0.1)