File size: 4,321 Bytes
a59d348
 
 
 
2a02fdb
d8f62d2
8bfd5d0
c1d7bb9
f6aff73
a59d348
 
8bfd5d0
a59d348
 
d899239
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a59d348
2a02fdb
 
 
 
 
 
d8f62d2
c1d7bb9
 
 
 
 
 
 
 
 
2a02fdb
d8f62d2
 
 
f6aff73
48c0fb6
f6aff73
d8f62d2
 
f6aff73
 
 
 
 
 
 
 
 
 
d8f62d2
c1d7bb9
d8f62d2
 
 
 
 
 
 
8bfd5d0
c1d7bb9
d8f62d2
c1d7bb9
d8f62d2
c1d7bb9
d8f62d2
f6aff73
 
 
 
d899239
f6aff73
 
 
 
 
 
 
 
 
 
 
 
 
a59d348
d899239
a59d348
 
 
 
 
2a02fdb
a59d348
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import logging
from fastapi import FastAPI
from txagent.txagent import TxAgent
from db.mongo import get_mongo_client
import asyncio
from pyfcm import FCMNotification
from huggingface_hub import get_secret
from datetime import datetime
import httpx

# Logging
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(name)s - %(message)s")
logger = logging.getLogger("TxAgentAPI")

# Initialize agent synchronously with error handling.
# This runs at import time: the model is loaded before the module finishes
# importing, and any failure is logged (with traceback) and re-raised so the
# service does not start without a working agent.
try:
    agent = TxAgent(
        model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
        # NOTE(review): enable_rag=False below, so this RAG model is presumably
        # unused at runtime — confirm against TxAgent.
        rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
        enable_finish=True,
        enable_rag=False,
        force_finish=True,
        enable_checker=True,
        step_rag_num=4,
        seed=42
    )
    agent.init_model()
    # System prompt applied to chat interactions with the agent.
    agent.chat_prompt = (
        "You are a clinical assistant AI. Analyze the patient's data and provide clear clinical recommendations."
    )
    logger.info("✅ TxAgent initialized synchronously with chat_prompt: %s", agent.chat_prompt)
except Exception as e:
    # logger.exception keeps the traceback, which str(e) alone would drop.
    logger.exception("❌ Failed to initialize TxAgent: %s", e)
    raise

# Initialize collections synchronously: obtain handles to the "cps_db"
# database and every collection this service reads or writes.
# NOTE(review): these collections are awaited elsewhere (e.g. `.to_list()`),
# which suggests get_mongo_client returns an async (Motor) client — confirm.
db = get_mongo_client()["cps_db"]
users_collection = db["users"]
patients_collection = db["patients"]
analysis_collection = db["patient_analysis_results"]
alerts_collection = db["clinical_alerts"]
notifications_collection = db["notifications"]
logger.info("📑 Connected to MongoDB synchronously at %s", datetime.utcnow().isoformat())

# Retrieve secrets from Hugging Face.
# FCM_API_KEY authenticates the push-notification client; SECRET_KEY signs the
# JWTs this module mints. Both are mandatory, so a missing one aborts import.
# NOTE(review): confirm huggingface_hub actually provides `get_secret`; Space
# secrets are commonly read from environment variables instead.
FCM_API_KEY = get_secret("FCM_API_KEY")
SECRET_KEY = get_secret("SECRET_KEY")

if not FCM_API_KEY or not SECRET_KEY:
    logger.error("❌ Missing FCM_API_KEY or SECRET_KEY in Hugging Face Secrets")
    raise ValueError("FCM_API_KEY and SECRET_KEY must be set in Hugging Face Secrets")

# Base URL of the auth Space; its "/me" endpoint is queried for a user's
# device token before a push notification is sent.
AUTH_SPACE_URL = "https://rocketfarmstudios-cps-api.hf.space/auth"

# FCM client shared by every push notification sent from this module.
push_service = FCMNotification(api_key=FCM_API_KEY)

async def send_push_notification(recipient_email, message):
    """Send a "Risk Alert" push notification to ``recipient_email``'s device.

    The device token is read from the auth Space's ``/me`` endpoint, called
    with a short-lived JWT minted locally for ``recipient_email``. Delivery is
    best-effort: every failure is logged and swallowed, so this coroutine does
    not raise to its caller (task cancellation excepted).

    Args:
        recipient_email: Email of the user to notify; used as the JWT subject.
        message: Notification body text.
    """
    try:
        # Fetch the user's device token from the auth Space
        async with httpx.AsyncClient() as client:
            headers = {"Authorization": f"Bearer {create_access_token(recipient_email)}"}
            response = await client.get(f"{AUTH_SPACE_URL}/me", headers=headers)
            if response.status_code != 200:
                logger.warning("Failed to fetch user %s from auth Space: %s", recipient_email, response.text)
                return
            user_data = response.json()
            device_token = user_data.get("device_token")

        if not device_token:
            logger.warning("No device token found for %s at %s", recipient_email, datetime.utcnow().isoformat())
            return

        # notify_single_device performs a blocking HTTP request; run it in the
        # default thread-pool executor so the event loop keeps serving other
        # coroutines while FCM responds.
        # pyfcm's api_key-based client selects priority via `low_priority`
        # (False == high priority); it has no `priority` keyword, and passing
        # one raises TypeError.
        loop = asyncio.get_running_loop()
        result = await loop.run_in_executor(
            None,
            lambda: push_service.notify_single_device(
                registration_id=device_token,
                message_title="Risk Alert",
                message_body=message,
                sound="default",
                low_priority=False,
            ),
        )
        logger.info("Push notification sent to %s at %s: %s", recipient_email, datetime.utcnow().isoformat(), result)
    except Exception as e:
        # Best-effort delivery: log with traceback instead of propagating.
        logger.exception(
            "Failed to send push notification to %s at %s: %s",
            recipient_email,
            datetime.utcnow().isoformat(),
            e,
        )

# JWT settings: HMAC-SHA256-signed tokens that expire after a fixed number of
# minutes; used both to mint and to verify tokens in this module.
import jwt
from datetime import timedelta

ACCESS_TOKEN_EXPIRE_MINUTES = 30
ALGORITHM = "HS256"

def create_access_token(email: str):
    """Mint a signed JWT whose ``sub`` claim is ``email``.

    The ``exp`` claim is set ACCESS_TOKEN_EXPIRE_MINUTES after now (naive UTC),
    and the token is signed with SECRET_KEY using ALGORITHM.
    """
    expires_at = datetime.utcnow() + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    claims = {"sub": email, "exp": expires_at}
    return jwt.encode(claims, SECRET_KEY, algorithm=ALGORITHM)

def decode_access_token(token: str):
    """Verify ``token`` with SECRET_KEY/ALGORITHM and return its claims.

    Returns:
        The decoded payload on success, or ``None`` if PyJWT rejects the
        token (any ``jwt.PyJWTError``); the rejection reason is logged.
    """
    try:
        claims = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except jwt.PyJWTError as exc:
        logger.error("Token decode error: %s", exc)
        return None
    return claims

def setup_app(app: FastAPI):
    """Register this module's startup hook on ``app``.

    On startup a background task is scheduled that runs
    ``analysis.analyze_patient`` over every document in ``patients_collection``,
    pausing 0.1 s between patients. A reference to the task is stored on
    ``app.state.patient_analysis_task``.

    Args:
        app: The FastAPI application to attach the startup handler to.
    """
    @app.on_event("startup")
    async def startup_event():
        # Keep a strong reference: the event loop only holds weak references
        # to tasks, so an unreferenced task can be garbage-collected mid-run.
        app.state.patient_analysis_task = asyncio.create_task(analyze_all_patients())

    async def analyze_all_patients():
        # Imported lazily, presumably to avoid a circular import — TODO confirm.
        from analysis import analyze_patient
        try:
            patients = await patients_collection.find({}).to_list(length=None)
        except Exception:
            # Nothing awaits this background task, so log instead of raising
            # into the void.
            logger.exception("Failed to load patients for startup analysis")
            return
        for patient in patients:
            try:
                await analyze_patient(patient)
            except Exception:
                # Isolate failures: one bad record must not abort the batch.
                logger.exception("Startup analysis failed for patient %s", patient.get("_id"))
            # Brief pause between patients before moving to the next one.
            await asyncio.sleep(0.1)