|
import logging |
|
from fastapi import FastAPI |
|
from txagent.txagent import TxAgent |
|
from db.mongo import get_mongo_client |
|
import asyncio |
|
from pyfcm import FCMNotification |
|
from huggingface_hub import get_secret |
|
from datetime import datetime |
|
import httpx |
|
|
|
|
|
# Root logging: INFO level, timestamped "time - level - logger - message" lines.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    level=logging.INFO,
)
# Named logger shared by every function in this module.
logger = logging.getLogger("TxAgentAPI")
|
|
|
|
|
# Build the TxAgent model once at import time; any failure aborts startup.
try:
    agent = TxAgent(
        model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
        rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
        enable_finish=True,
        enable_rag=False,
        force_finish=True,
        enable_checker=True,
        step_rag_num=4,
        seed=42,
    )
    agent.init_model()
    # System-style prompt applied to the agent's chats.
    agent.chat_prompt = (
        "You are a clinical assistant AI. Analyze the patient's data and provide clear clinical recommendations."
    )
    logger.info("✅ TxAgent initialized synchronously with chat_prompt: %s", agent.chat_prompt)
except Exception as e:
    # logger.exception keeps the traceback in the log; the error is then
    # re-raised so the module (and the app) fails to load.
    logger.exception("❌ Failed to initialize TxAgent: %s", str(e))
    raise
|
|
|
|
|
# Handles to the "cps_db" database and the collections this service uses.
db = get_mongo_client()["cps_db"]
users_collection = db["users"]
patients_collection = db["patients"]
analysis_collection = db["patient_analysis_results"]
alerts_collection = db["clinical_alerts"]
notifications_collection = db["notifications"]
logger.info("📡 Connected to MongoDB synchronously at %s", datetime.utcnow().isoformat())
|
|
|
|
|
# Credentials are read through huggingface_hub.get_secret.
# NOTE(review): confirm the installed huggingface_hub actually provides
# get_secret; Hugging Face Spaces normally expose secrets as environment
# variables (os.environ).
FCM_API_KEY = get_secret("FCM_API_KEY")  # used to build the FCM client
SECRET_KEY = get_secret("SECRET_KEY")  # used to sign/verify JWTs

# Fail fast at import time: both values are required.
if not FCM_API_KEY or not SECRET_KEY:
    logger.error("❌ Missing FCM_API_KEY or SECRET_KEY in Hugging Face Secrets")
    raise ValueError("FCM_API_KEY and SECRET_KEY must be set in Hugging Face Secrets")
|
|
|
|
|
# Base URL of the separate auth Space; send_push_notification queries its /me endpoint.
AUTH_SPACE_URL = "https://rocketfarmstudios-cps-api.hf.space/auth"

# Shared FCM client used by send_push_notification.
# NOTE(review): newer pyfcm releases (2.x) dropped the legacy `api_key`
# constructor argument and notify_single_device; confirm the pinned version.
push_service = FCMNotification(api_key=FCM_API_KEY)
|
|
|
import concurrent.futures

# push_service.notify_single_device is a synchronous call. Run it on this
# single worker thread so it no longer blocks the event loop, while calls stay
# serialized exactly as they were when they ran on the loop itself.
# NOTE(review): serialization is kept deliberately, because the FCM client may
# not be safe to share across threads.
_FCM_EXECUTOR = concurrent.futures.ThreadPoolExecutor(max_workers=1, thread_name_prefix="fcm")


async def send_push_notification(recipient_email, message):
    """Send a "Risk Alert" push notification to a user's registered device.

    The user record is fetched from the auth Space's ``/me`` endpoint. The
    request is authenticated with a token minted for ``recipient_email``. The
    record's ``device_token`` is read and ``message`` is sent through FCM.

    Delivery is best-effort: every failure is logged, nothing is raised, and
    the function returns ``None``.

    Args:
        recipient_email: Email of the user to notify; also the token subject.
        message: Notification body text.
    """
    try:
        headers = {"Authorization": f"Bearer {create_access_token(recipient_email)}"}
        async with httpx.AsyncClient() as client:
            response = await client.get(f"{AUTH_SPACE_URL}/me", headers=headers)

        if response.status_code != 200:
            logger.warning(f"Failed to fetch user {recipient_email} from auth Space: {response.text}")
            return

        device_token = response.json().get("device_token")
        if not device_token:
            logger.warning(f"No device token found for {recipient_email} at {datetime.utcnow().isoformat()}")
            return

        loop = asyncio.get_running_loop()
        result = await loop.run_in_executor(
            _FCM_EXECUTOR,
            lambda: push_service.notify_single_device(
                registration_id=device_token,
                message_title="Risk Alert",
                message_body=message,
                sound="default",
                priority="high",
            ),
        )
        logger.info(f"Push notification sent to {recipient_email} at {datetime.utcnow().isoformat()}: {result}")
    except Exception as e:
        # Best-effort delivery: log with traceback instead of propagating.
        logger.exception(f"Failed to send push notification to {recipient_email} at {datetime.utcnow().isoformat()}: {str(e)}")
|
|
|
|
|
from datetime import timedelta |
|
import jwt |
|
|
|
# JWT signing algorithm and token lifetime (minutes) for tokens minted here.
ALGORITHM, ACCESS_TOKEN_EXPIRE_MINUTES = "HS256", 30
|
|
|
def create_access_token(email: str):
    """Return a JWT with subject ``email``, signed with SECRET_KEY / ALGORITHM.

    The ``exp`` claim is set ACCESS_TOKEN_EXPIRE_MINUTES minutes from now.
    """
    # Local import: the module only imports the `datetime` class at top level.
    from datetime import timezone

    # Timezone-aware UTC replaces the deprecated naive datetime.utcnow();
    # the encoded integer `exp` timestamp is the same.
    expires_at = datetime.now(timezone.utc) + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    return jwt.encode({"sub": email, "exp": expires_at}, SECRET_KEY, algorithm=ALGORITHM)
|
|
|
def decode_access_token(token: str):
    """Verify ``token`` with SECRET_KEY / ALGORITHM and return its claims.

    Returns ``None`` (after logging the reason) when PyJWT rejects the token.
    """
    try:
        claims = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except jwt.PyJWTError as exc:
        logger.error(f"Token decode error: {str(exc)}")
        return None
    return claims
|
|
|
def _log_background_task_failure(task):
    """Done-callback: log the exception of a finished background task, if any."""
    if task.cancelled():
        return
    exc = task.exception()
    if exc is not None:
        logger.error("Background patient analysis failed", exc_info=exc)


def setup_app(app: FastAPI):
    """Register a startup hook on ``app`` that analyzes all patients in the background.

    The hook schedules analyze_all_patients as an asyncio task and returns
    immediately, so startup does not wait for the analysis.
    """

    @app.on_event("startup")
    async def startup_event():
        # The event loop keeps only weak references to tasks, so an
        # unreferenced task can be garbage-collected before it finishes.
        # Holding it on app.state keeps it alive for the app's lifetime.
        task = asyncio.create_task(analyze_all_patients())
        app.state.patient_analysis_task = task
        # Log failures as they happen, instead of a late
        # "Task exception was never retrieved" warning.
        task.add_done_callback(_log_background_task_failure)
|
|
|
async def analyze_all_patients():
    """Run ``analyze_patient`` on every document in the patients collection.

    Patients are processed sequentially, with a 0.1 s pause between them. If
    one patient's analysis raises, the error is logged and the remaining
    patients are still processed.
    """
    # Imported lazily, presumably to avoid a circular import with the
    # analysis module. Verify before hoisting it to module level.
    from analysis import analyze_patient

    # The full list is loaded up front rather than streaming the cursor.
    # NOTE(review): per-patient analysis may be slow, and a server-side cursor
    # left idle between batches can time out. Revisit if the collection grows
    # very large.
    patients = await patients_collection.find({}).to_list(length=None)
    for patient in patients:
        try:
            await analyze_patient(patient)
        except Exception:
            logger.exception("Analysis failed for patient %s", patient.get("_id"))
        # Yield to the event loop between patients.
        await asyncio.sleep(0.1)