Spaces:
Runtime error
Runtime error
| import os | |
| import logging | |
| from fastapi import FastAPI | |
| from txagent.txagent import TxAgent | |
| from db.mongo import get_mongo_client | |
| import asyncio | |
# Logging: configure the root logger once (INFO and above, "time - LEVEL - message")
# and create the named logger used throughout this module.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger("TxAgentAPI")
# Build the module-level TxAgent at import time. Any failure is logged together
# with its traceback and then re-raised, so the module never loads without `agent`.
try:
    agent = TxAgent(
        model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
        # NOTE(review): rag_model_name / step_rag_num are passed although
        # enable_rag=False; presumably TxAgent ignores them when RAG is off — confirm.
        rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
        enable_finish=True,
        enable_rag=False,
        force_finish=True,
        enable_checker=True,
        step_rag_num=4,
        seed=42,
    )
    agent.init_model()
    # The prompt is assigned after init_model(), overriding whatever the agent had.
    agent.chat_prompt = (
        "You are a clinical assistant AI. "
        "Analyze the patient's data and provide clear clinical recommendations."
    )
except Exception as e:
    # logger.exception keeps the full traceback; logging only str(e) made
    # model-loading failures hard to diagnose.
    logger.exception("β Failed to initialize TxAgent: %s", e)
    raise
else:
    logger.info("β TxAgent initialized synchronously with chat_prompt: %s", agent.chat_prompt)
# MongoDB: one database handle plus a module-level handle per collection the API uses.
# These are created when the module is imported.
# NOTE(review): analyze_all_patients awaits find().to_list(), so get_mongo_client()
# presumably returns an async (Motor-style) client despite the "synchronously"
# wording here — confirm.
_mongo_client = get_mongo_client()
db = _mongo_client["cps_db"]

users_collection = db["users"]
patients_collection = db["patients"]
analysis_collection = db["patient_analysis_results"]
chats_collection = db["chats"]  # chat history
alerts_collection = db["clinical_alerts"]
notifications_collection = db["notifications"]

logger.info("π‘ Connected to MongoDB synchronously")
# JWT settings. The signing secret comes from the SECRET_KEY environment variable.
# The placeholder fallback is kept for backward compatibility only: it is visible
# in this source, so tokens signed with it can be forged by anyone who reads it.
_DEFAULT_SECRET_KEY = "your-secret-key"
SECRET_KEY = os.getenv("SECRET_KEY", _DEFAULT_SECRET_KEY)
ALGORITHM = "HS256"
if SECRET_KEY == _DEFAULT_SECRET_KEY:
    logger.warning(
        "SECRET_KEY is not set; falling back to an insecure default JWT signing secret"
    )
# Notifications WebSocket endpoint. SPACES=1 selects the Hugging Face Spaces
# production URL; any other value (or an unset variable) selects local development.
# NOTE(review): production targets /queue/join while local targets
# /ws/notifications — confirm the Spaces path is intentional.
_RUNNING_ON_SPACES = os.getenv('SPACES') == '1'
WS_URL = (
    "wss://rocketfarmstudios-txagent-api.hf.space/queue/join"
    if _RUNNING_ON_SPACES
    else "ws://localhost:8000/ws/notifications"
)
# Strong references to fire-and-forget tasks started on startup. The event loop
# keeps only weak references to tasks, so an unreferenced task may be
# garbage-collected before it completes. Each task removes itself when done.
_background_tasks = set()


def setup_app(app: FastAPI):
    """Register this module's startup hook on *app*.

    When the application starts, the hook schedules analyze_all_patients() as a
    background task and returns immediately, without awaiting it.

    Args:
        app: The FastAPI application to configure.
    """
    async def startup_event():
        task = asyncio.create_task(analyze_all_patients())
        _background_tasks.add(task)
        task.add_done_callback(_background_tasks.discard)

    # A handler that is only defined is never called; it must be registered
    # for FastAPI to run it on startup.
    app.add_event_handler("startup", startup_event)
async def analyze_all_patients(*, delay: float = 0.1):
    """Run analysis.analyze_patient over every document in patients_collection.

    Patients are fetched in one query and processed sequentially, sleeping
    *delay* seconds after each one. A failure for one patient is logged and does
    not prevent the remaining patients from being analysed. Task cancellation
    (asyncio.CancelledError) is not caught and still propagates.

    Args:
        delay: Seconds to sleep after each patient. The default of 0.1 matches the
            previous hard-coded value.
    """
    # Deferred import, presumably to avoid a circular import with `analysis` — confirm.
    from analysis import analyze_patient

    # Materialise the list up front rather than holding a server cursor open
    # across a long, throttled batch.
    patients = await patients_collection.find({}).to_list(length=None)
    for patient in patients:
        try:
            await analyze_patient(patient)
        except Exception:
            # This usually runs as a fire-and-forget task; an escaping exception
            # would silently end the whole batch at the first bad patient.
            logger.exception("Failed to analyze patient %s", patient.get("_id"))
        await asyncio.sleep(delay)