import os
import sys
import json
import logging
import re
import asyncio
from datetime import datetime
from typing import List, Dict, Optional

from fastapi import FastAPI, HTTPException
from fastapi.responses import JSONResponse, StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from pymongo import MongoClient
from bson import ObjectId

# Adjust sys path so local packages under ./src are importable
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "src")))

# TxAgent
from txagent.txagent import TxAgent

# MongoDB
from db.mongo import get_mongo_client

# Setup logging
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger("TxAgentAPI")

# FastAPI app
app = FastAPI(title="TxAgent API", version="2.1.0")
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


# Models
class ChatRequest(BaseModel):
    message: str
    temperature: float = 0.7
    max_new_tokens: int = 512
    history: Optional[List[Dict]] = None
    format: Optional[str] = "clean"


# Globals (populated on startup)
agent = None
mongo_client = None
patients_collection = None
analysis_collection = None


# Helpers
def clean_text_response(text: str) -> str:
    """Collapse blank-line runs and extra spaces, and strip markdown emphasis."""
    text = re.sub(r'\n\s*\n', '\n\n', text)
    text = re.sub(r'[ ]+', ' ', text)
    return text.replace("**", "").replace("__", "").strip()


def extract_section(text: str, heading: str) -> str:
    """Pull the body of a named section (e.g. "Summary:") out of the model's reply."""
    try:
        pattern = rf"{heading}:\n(.*?)(?=\n\w|\Z)"
        match = re.search(pattern, text, re.DOTALL)
        return clean_text_response(match.group(1)) if match else ""
    except Exception as e:
        logger.error(f"Section extraction failed: {e}")
        return ""


def structure_medical_response(text: str) -> Dict:
    return {
        "summary": extract_section(text, "Summary"),
        "risks": extract_section(text, "Risks or Red Flags"),
        "missed_issues": extract_section(text, "What the doctor might have missed"),
        "recommendations": extract_section(text, "Suggested Clinical Actions"),
    }


def serialize_patient(patient: dict) -> dict:
    """Return a JSON-serializable copy of a patient document (ObjectId -> str)."""
    patient_copy = patient.copy()
    if "_id" in patient_copy:
        patient_copy["_id"] = str(patient_copy["_id"])
    return patient_copy


async def analyze_patient(patient: dict):
    try:
        doc = json.dumps(serialize_patient(patient), indent=2)
        message = (
            "You are a clinical decision support AI.\n\n"
            "Given the patient document below:\n"
            "1. Summarize their medical history.\n"
            "2. Identify risks or red flags.\n"
            "3. Highlight missed diagnoses or treatments.\n"
            "4. Suggest next clinical steps.\n"
            f"\nPatient Document:\n{'-'*40}\n{doc[:10000]}"
        )
        # Note: agent.chat() is a synchronous call and will block the event loop
        # while the model generates.
        raw = agent.chat(message=message, history=[], temperature=0.7, max_new_tokens=1024)
        structured = structure_medical_response(raw)
        analysis_doc = {
            "patient_id": patient.get("fhir_id"),
            "timestamp": datetime.utcnow(),
            "summary": structured,
            "raw": raw,
        }
        await analysis_collection.update_one(
            {"patient_id": patient.get("fhir_id")},
            {"$set": analysis_doc},
            upsert=True,
        )
        logger.info(f"✔️ Analysis stored for patient {patient.get('fhir_id')}")
    except Exception as e:
        logger.error(f"Error analyzing patient: {e}")


async def analyze_all_patients():
    """Background task: analyze every patient document once at startup."""
    patients = await patients_collection.find({}).to_list(length=None)
    for patient in patients:
        await analyze_patient(patient)
        await asyncio.sleep(0.1)


# Startup logic
@app.on_event("startup")
async def startup_event():
    global agent, mongo_client, patients_collection, analysis_collection

    # Init agent
    agent = TxAgent(
        model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
        rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
        enable_finish=True,
        enable_rag=False,
        force_finish=True,
        enable_checker=True,
        step_rag_num=4,
        seed=42,
    )
    agent.chat_prompt = (
        "You are a clinical assistant AI. Analyze the patient's data and provide clear clinical recommendations."
    )
    agent.init_model()
    logger.info("✅ TxAgent initialized")

    # MongoDB
    # NOTE: the awaited collection calls in this module assume get_mongo_client()
    # returns an async driver client (e.g. Motor); a plain synchronous pymongo
    # client would not support `await` on find/update_one.
    mongo_client = get_mongo_client()
    db = mongo_client.get_default_database()
    patients_collection = db.get_collection("patients")
    analysis_collection = db.get_collection("patient_analysis_results")
    logger.info("📡 Connected to MongoDB")

    asyncio.create_task(analyze_all_patients())


# Endpoints
@app.get("/status")
async def status():
    return {
        "status": "running",
        "version": "2.1.0",
        "timestamp": datetime.utcnow().isoformat(),
    }


@app.post("/chat-stream")
async def chat_stream_endpoint(request: ChatRequest):
    async def token_stream():
        try:
            conversation = [{"role": "system", "content": agent.chat_prompt}]
            if request.history:
                conversation.extend(request.history)
            conversation.append({"role": "user", "content": request.message})

            input_ids = agent.tokenizer.apply_chat_template(
                conversation, add_generation_prompt=True, return_tensors="pt"
            ).to(agent.device)

            output = agent.model.generate(
                input_ids,
                do_sample=True,
                temperature=request.temperature,
                max_new_tokens=request.max_new_tokens,
                pad_token_id=agent.tokenizer.eos_token_id,
                return_dict_in_generate=True,
            )

            # The full completion is generated first; it is then emitted word by
            # word to simulate token streaming to the client.
            text = agent.tokenizer.decode(
                output["sequences"][0][input_ids.shape[1]:], skip_special_tokens=True
            )
            for chunk in text.split():
                yield chunk + " "
                await asyncio.sleep(0.05)
        except Exception as e:
            logger.error(f"Streaming error: {e}")
            yield f"⚠️ Error: {e}"

    return StreamingResponse(token_stream(), media_type="text/plain")
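
# --- Usage sketch (assumptions, not part of this file: the module is named
# `main`, and uvicorn/httpx are installed) ---
#
# Run the API locally:
#   uvicorn main:app --host 0.0.0.0 --port 8000
#
# Consume the /chat-stream endpoint from Python:
#   import httpx
#   payload = {"message": "Summarize this patient's history", "max_new_tokens": 256}
#   with httpx.stream("POST", "http://localhost:8000/chat-stream",
#                     json=payload, timeout=None) as r:
#       for chunk in r.iter_text():
#           print(chunk, end="", flush=True)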