File size: 5,031 Bytes
a59d348
 
 
 
2a02fdb
d8f62d2
c1d7bb9
f6aff73
1eccd08
abc137a
 
a59d348
 
8bfd5d0
a59d348
 
d899239
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a59d348
2a02fdb
 
 
 
 
 
d8f62d2
c1d7bb9
 
1eccd08
abc137a
1eccd08
c1d7bb9
abc137a
 
 
 
 
 
 
 
2a02fdb
d8f62d2
abc137a
d8f62d2
f6aff73
1caf113
f6aff73
d8f62d2
 
f6aff73
 
 
 
 
 
 
 
 
 
d8f62d2
c1d7bb9
d8f62d2
 
 
 
 
 
 
8bfd5d0
c1d7bb9
d8f62d2
c1d7bb9
d8f62d2
c1d7bb9
abc137a
 
 
 
d8f62d2
f6aff73
 
 
 
d899239
f6aff73
 
 
 
 
 
 
 
 
 
 
 
 
a59d348
d899239
a59d348
 
4878407
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import asyncio
import atexit
import json
import logging
import os
import tempfile
from datetime import datetime, timezone

import httpx
from fastapi import FastAPI
from pyfcm import FCMNotification

from db.mongo import get_mongo_client
from txagent.txagent import TxAgent

# Logging: configure the root logger once for the whole process (INFO and
# above, timestamped) and expose the named logger used throughout this module.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger("TxAgentAPI")

# Initialize the TxAgent model synchronously at import time. On any failure the
# error is logged together with its full traceback and re-raised, so the module
# (and the service importing it) fails to load without a working model.
try:
    agent = TxAgent(
        model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
        # NOTE(review): a RAG model name is supplied although enable_rag=False;
        # confirm whether TxAgent loads it anyway.
        rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
        enable_finish=True,
        enable_rag=False,
        force_finish=True,
        enable_checker=True,
        step_rag_num=4,
        seed=42,
    )
    agent.init_model()
    # System-style prompt applied to every chat with the agent.
    agent.chat_prompt = (
        "You are a clinical assistant AI. Analyze the patient's data and provide clear clinical recommendations."
    )
    logger.info("✅ TxAgent initialized synchronously with chat_prompt: %s", agent.chat_prompt)
except Exception as e:
    # logger.exception keeps the traceback; logging only str(e) hid where a
    # fatal startup failure actually came from.
    logger.exception("❌ Failed to initialize TxAgent: %s", e)
    raise

# Resolve the "cps_db" database and the collection handles this module exposes.
db = get_mongo_client()["cps_db"]
users_collection = db["users"]
patients_collection = db["patients"]
analysis_collection = db["patient_analysis_results"]
alerts_collection = db["clinical_alerts"]
notifications_collection = db["notifications"]
# NOTE(review): whether get_mongo_client() contacts the server eagerly is not
# visible in this file; this "Connected" line may be logged before any real
# round-trip to MongoDB. The timestamp is timezone-aware UTC
# (datetime.utcnow() is deprecated since Python 3.12).
logger.info("📡 Connected to MongoDB synchronously at %s", datetime.now(timezone.utc).isoformat())

# Retrieve secrets from environment variables (set in Hugging Face Space Secrets).
# Both are required: the FCM service-account JSON for push notifications and
# SECRET_KEY for signing the JWTs sent to the auth Space.
FCM_SERVICE_ACCOUNT_JSON = os.getenv("FCM_SERVICE_ACCOUNT")
SECRET_KEY = os.getenv("SECRET_KEY")

# Name exactly which secret(s) are absent (or empty) so misconfiguration is
# diagnosable from the log alone.
_missing_secrets = [
    env_name
    for env_name, env_value in (
        ("FCM_SERVICE_ACCOUNT", FCM_SERVICE_ACCOUNT_JSON),
        ("SECRET_KEY", SECRET_KEY),
    )
    if not env_value
]
if _missing_secrets:
    logger.error("❌ Missing %s in environment variables", " and ".join(_missing_secrets))
    raise ValueError("FCM_SERVICE_ACCOUNT and SECRET_KEY must be set in Hugging Face Space Secrets")

# Write the FCM service account JSON to a temporary file, because
# FCMNotification is configured with a file path rather than the JSON itself.
# Parse first so malformed JSON raises before any temp file is created
# (previously an empty temp file was left behind in that case).
_fcm_service_account_info = json.loads(FCM_SERVICE_ACCOUNT_JSON)
with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".json") as temp_file:
    json.dump(_fcm_service_account_info, temp_file)
    temp_file_path = temp_file.name


def _remove_fcm_credentials_file(path):
    """Delete the temporary FCM credentials file at *path* if it still exists."""
    try:
        os.remove(path)
    except FileNotFoundError:
        pass


# NOTE(review): push_service only receives the path, so assume it may re-read
# the file after construction; keep it on disk for the life of the process and
# remove it (it holds a private key) when the interpreter exits.
atexit.register(_remove_fcm_credentials_file, temp_file_path)

# FCM settings
push_service = FCMNotification(service_account_file=temp_file_path)

# Base URL of the auth Space, whose "/me" endpoint is queried for a user's
# device token. Defaults to the original Space; set the AUTH_SPACE_URL
# environment variable to target another deployment. A trailing slash is
# stripped so f"{AUTH_SPACE_URL}/me" never produces "//me".
AUTH_SPACE_URL = os.getenv("AUTH_SPACE_URL", "https://rocketfarmstudios-cps-api.hf.space/auth").rstrip("/")

async def send_push_notification(recipient_email, message):
    """Send a "Risk Alert" push notification to *recipient_email*'s device.

    The device token is fetched from the auth Space's ``/me`` endpoint,
    authenticated with a JWT minted locally for the recipient. Every failure
    (auth lookup, missing token, FCM error) is logged and swallowed, so this
    coroutine never raises to its caller.

    Args:
        recipient_email: Email of the user to notify; used as the JWT subject.
        message: Body text of the notification.

    Returns:
        None. The FCM response is only logged.
    """
    # Cleanup of temp_file_path is intentionally NOT done here: push_service was
    # constructed from that path and is shared by every call, so deleting the
    # file per call can break all subsequent notifications.
    try:
        # Fetch the user's device token from the auth Space
        async with httpx.AsyncClient() as client:
            headers = {"Authorization": f"Bearer {create_access_token(recipient_email)}"}
            response = await client.get(f"{AUTH_SPACE_URL}/me", headers=headers)
            if response.status_code != 200:
                logger.warning(f"Failed to fetch user {recipient_email} from auth Space: {response.text}")
                return
            user_data = response.json()
            device_token = user_data.get("device_token")

        if not device_token:
            logger.warning(f"No device token found for {recipient_email} at {datetime.now(timezone.utc).isoformat()}")
            return

        # The FCM client call is synchronous network I/O; run it in the default
        # executor so it does not block the event loop while it waits.
        # NOTE(review): notify_single_device(registration_id=..., message_title=...)
        # is the pyfcm 1.x call shape, while service_account_file= is the pyfcm 2.x
        # constructor; confirm against the installed pyfcm version.
        loop = asyncio.get_running_loop()
        result = await loop.run_in_executor(
            None,
            lambda: push_service.notify_single_device(
                registration_id=device_token,
                message_title="Risk Alert",
                message_body=message,
                sound="default",
                priority="high",
            ),
        )
        logger.info(f"Push notification sent to {recipient_email} at {datetime.now(timezone.utc).isoformat()}: {result}")
    except Exception as e:
        # Best-effort delivery: keep the traceback in the log, but never let a
        # notification failure propagate to the caller.
        logger.exception(f"Failed to send push notification to {recipient_email} at {datetime.now(timezone.utc).isoformat()}: {str(e)}")

# JWT settings (minimal implementation for token creation)
from datetime import timedelta
import jwt

ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 30

def create_access_token(email: str):
    """Return a JWT identifying *email*, valid for ACCESS_TOKEN_EXPIRE_MINUTES.

    The token carries ``sub`` (the email) and ``exp`` claims and is signed with
    ``SECRET_KEY`` using ``ALGORITHM``. This module sends it as a Bearer token
    when calling the auth Space on the user's behalf.
    """
    # Timezone-aware UTC "now": datetime.utcnow() is deprecated since Python 3.12.
    # PyJWT converts datetimes to UTC epoch seconds, so the encoded ``exp`` is the
    # same instant the naive-UTC version produced.
    expires_at = datetime.now(timezone.utc) + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    to_encode = {"sub": email, "exp": expires_at}
    # NOTE(review): PyJWT >= 2 returns str here; PyJWT 1.x returned bytes.
    return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)

def decode_access_token(token: str):
    """Verify *token* and return its decoded claims, or ``None`` if it is rejected.

    The token is checked against ``SECRET_KEY`` and only ``ALGORITHM`` is
    accepted. Any ``jwt.PyJWTError`` raised during decoding is logged at error
    level and converted into a ``None`` return instead of propagating.
    """
    try:
        claims = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except jwt.PyJWTError as err:
        logger.error(f"Token decode error: {str(err)}")
        return None
    return claims

def setup_app(app: FastAPI):
    """Register this module's lifecycle hooks on *app*.

    Only a ``startup`` handler is attached, and it just logs: the background
    pass over all patients stays disabled because the ``analyze_patient``
    function it depends on is missing.
    """

    async def startup_event():
        # The analyze_all_patients background task is intentionally not
        # scheduled; re-enable with asyncio.create_task(analyze_all_patients())
        # once analyze_patient exists.
        logger.info("Startup event triggered, analyze_all_patients task disabled until implementation.")

    app.on_event("startup")(startup_event)

    # TODO: restore once analysis.analyze_patient is implemented. Intended sketch:
    # async def analyze_all_patients():
    #     from analysis import analyze_patient
    #     patients = await patients_collection.find({}).to_list(length=None)
    #     for patient in patients:
    #         await analyze_patient(patient)
    #         await asyncio.sleep(0.1)