|
import os |
|
import logging |
|
from fastapi import FastAPI |
|
from txagent.txagent import TxAgent |
|
from db.mongo import get_mongo_client |
|
import asyncio |
|
|
|
|
|
# Process-wide logging setup: INFO threshold, "timestamp - level - message" records.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)

# Named logger used by the rest of this module.
logger = logging.getLogger("TxAgentAPI")
|
|
|
|
|
# Construct and initialise the shared TxAgent instance at import time.
# Any failure is logged with its traceback and re-raised, so importing this
# module fails when the agent cannot be initialised.
try:
    agent = TxAgent(
        model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
        rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
        enable_finish=True,
        enable_rag=False,
        force_finish=True,
        enable_checker=True,
        step_rag_num=4,
        seed=42,
    )
    # NOTE(review): presumably loads the model weights — confirm against TxAgent.
    agent.init_model()
    # Replace the agent's chat prompt with the clinical-assistant instruction.
    agent.chat_prompt = (
        "You are a clinical assistant AI. Analyze the patient's data and provide clear clinical recommendations."
    )
    logger.info("✅ TxAgent initialized synchronously with chat_prompt: %s", agent.chat_prompt)
except Exception as e:
    # logger.exception keeps the original message and also records the
    # traceback, which str(e) alone would discard.
    logger.exception("❌ Failed to initialize TxAgent: %s", e)
    raise
|
|
|
|
|
# Select the "cps_db" database and bind one module-level handle per collection
# so other modules can import them by name.
db = get_mongo_client()["cps_db"]
users_collection = db["users"]
patients_collection = db["patients"]
analysis_collection = db["patient_analysis_results"]
chats_collection = db["chats"]
alerts_collection = db["clinical_alerts"]
notifications_collection = db["notifications"]
# NOTE(review): nothing visible here performs a round trip to the server, so
# this message may only mean the handles were created — confirm what
# get_mongo_client() actually verifies.
logger.info("📡 Connected to MongoDB synchronously")
|
|
|
|
|
# Signing secret read from the environment, with a placeholder fallback.
# NOTE(review): presumably used together with ALGORITHM for token signing
# elsewhere in the project — confirm against the auth code.
SECRET_KEY = os.getenv("SECRET_KEY", "your-secret-key")
if SECRET_KEY == "your-secret-key":
    # The fallback value is a publicly known placeholder; make its use visible
    # instead of silently running with a guessable secret.
    logging.getLogger("TxAgentAPI").warning(
        "SECRET_KEY is not set; falling back to the insecure built-in default. "
        "Set the SECRET_KEY environment variable before deploying."
    )

# Signature algorithm identifier paired with SECRET_KEY.
ALGORITHM = "HS256"
|
|
|
|
|
# WebSocket endpoint: the hosted hf.space URL when the SPACES environment
# variable equals "1", otherwise the local notifications socket.
WS_URL = (
    "wss://rocketfarmstudios-txagent-api.hf.space/queue/join"
    if os.getenv('SPACES') == '1'
    else "ws://localhost:8000/ws/notifications"
)
|
|
|
def setup_app(app: FastAPI):
    """Register a startup hook on *app* that launches the patient analysis.

    The hook schedules ``analyze_all_patients()`` as a background task instead
    of awaiting it, so the startup handler returns without waiting for the
    analysis to finish.

    Args:
        app: The FastAPI application the startup handler is attached to.
    """
    # The event loop only keeps weak references to tasks, so a task whose
    # handle is thrown away may be garbage-collected before it completes.
    # Hold a strong reference here until the task is done.
    pending_tasks = set()

    def _on_done(task):
        """Release the finished task and surface any exception it raised."""
        pending_tasks.discard(task)
        if task.cancelled():
            return
        failure = task.exception()
        if failure is not None:
            logger.error("Background patient analysis failed", exc_info=failure)

    @app.on_event("startup")
    async def startup_event():
        task = asyncio.create_task(analyze_all_patients())
        pending_tasks.add(task)
        task.add_done_callback(_on_done)
|
|
|
async def analyze_all_patients():
    """Run ``analyze_patient`` on every document in the patients collection.

    All patient documents are fetched up front and then processed one at a
    time, sleeping 0.1 s after each so other coroutines can run in between.
    A failure on one patient is logged with its traceback and the loop moves
    on to the next patient instead of aborting the whole run.
    """
    # Imported inside the function rather than at module top level.
    # NOTE(review): presumably to avoid a circular import with `analysis` — confirm.
    from analysis import analyze_patient

    patients = await patients_collection.find({}).to_list(length=None)
    for patient in patients:
        try:
            await analyze_patient(patient)
        except Exception:
            # One bad record must not stop analysis of the remaining patients.
            # Assumes each document is a mapping with an "_id" key — TODO confirm.
            logger.exception("Analysis failed for patient %s", patient.get("_id"))
        await asyncio.sleep(0.1)