Spaces:
Runtime error
Runtime error
import os | |
import sys | |
import json | |
import logging | |
import re | |
import hashlib | |
from datetime import datetime | |
from typing import List, Dict, Optional, Tuple | |
from enum import Enum | |
from fastapi import FastAPI, HTTPException | |
from fastapi.responses import StreamingResponse | |
from fastapi.middleware.cors import CORSMiddleware | |
from pydantic import BaseModel | |
import asyncio | |
from fastapi import Query | |
from bson import ObjectId | |
from txagent.txagent import TxAgent | |
from db.mongo import get_mongo_client | |
# Logging: a single INFO-level root configuration for the whole service,
# plus the named logger used by every helper below.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger("TxAgentAPI")
# App: FastAPI application object (version string tracks the hash-based analysis release).
app = FastAPI(
    title="TxAgent API",
    version="2.2.1",
)
# CORS is fully open: any origin, method and header is accepted.
# NOTE(review): wildcard origins are combined with allow_credentials=True —
# confirm that credentialed cross-origin requests from any site are intended.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Pydantic request models
class ChatRequest(BaseModel):
    # Body of the chat streaming request (consumed by chat_stream_endpoint).
    message: str  # the new user turn
    temperature: float = 0.7  # sampling temperature forwarded to generate()
    max_new_tokens: int = 512  # generation length cap forwarded to generate()
    history: Optional[List[Dict]] = None  # earlier turns, inserted after the system prompt
    format: Optional[str] = "clean"  # NOTE(review): not read by chat_stream_endpoint — confirm intended use
# Enums
class RiskLevel(str, Enum):
    """Ordered suicide-risk grades.

    Members are also ``str`` so they compare equal to their lowercase values,
    which are the strings stored in the analysis and alert documents.
    """

    NONE = "none"
    LOW = "low"
    MODERATE = "moderate"
    HIGH = "high"
    SEVERE = "severe"
# Globals — all start as None and are assigned by startup_event():
# the TxAgent model wrapper, and three collections from the "cps_db" database.
agent = None
patients_collection = analysis_collection = alerts_collection = None
# Helpers | |
def clean_text_response(text: str) -> str:
    """Normalise model output for plain-text display.

    Runs of blank (or whitespace-only) lines collapse to a single empty line,
    runs of spaces collapse to one space, the markdown emphasis markers
    ``**`` and ``__`` are removed, and outer whitespace is stripped.
    """
    tidied = re.sub(r'\n\s*\n', '\n\n', text)
    tidied = re.sub(r'[ ]+', ' ', tidied)
    for marker in ("**", "__"):
        tidied = tidied.replace(marker, "")
    return tidied.strip()
def extract_section(text: str, heading: str) -> str:
    """Return the body that follows ``"<heading>:"`` on its own line.

    The body runs until the next line that looks like a heading (a line
    starting with a letter and containing a colon) or the end of the text.
    Matching is case-insensitive. Returns ``""`` when the heading is absent or
    the search fails (the failure is logged).

    NOTE(review): because re.IGNORECASE applies to the whole pattern, the
    ``[A-Z]`` terminator also matches lowercase-initial lines containing a
    colon — confirm that is intended.
    """
    try:
        found = re.search(
            rf"{re.escape(heading)}:\s*\n(.*?)(?=\n[A-Z][^\n]*:|\Z)",
            text,
            re.IGNORECASE | re.DOTALL,
        )
    except Exception as e:
        logger.error(f"Section extraction failed for heading '{heading}': {e}")
        return ""
    if found is None:
        return ""
    return found.group(1).strip()
def structure_medical_response(text: str) -> Dict:
    """Split a free-text clinical review into its four named sections.

    Markdown emphasis markers (``**`` and ``__``) are removed first. Each
    section is then looked up by a primary heading, falling back to an
    alternative wording. A section body ends at the first blank line or at the
    end of the text, and leading ``-``/``*`` list bullets are removed from
    each of its lines.

    Recognised heading layouts (case-insensitive):
        ``Heading:`` followed by the body on the next line(s);
        ``Heading: body`` on the same line;
        ``Heading - body`` or ``Heading`` followed by a newline and the body.

    Args:
        text: Raw model output.

    Returns:
        Dict with keys ``summary``, ``risks``, ``missed_issues`` and
        ``recommendations``; a section that is not found maps to ``""``.
    """
    def extract_improved(text: str, heading: str) -> str:
        escaped = re.escape(heading)
        patterns = [
            # "Heading:" alone on its line. Tried first so that a real header wins
            # over an earlier in-prose mention of the same words.
            rf"{escaped}:\s*\n(.*?)(?=\n\s*\n|\Z)",
            # "Heading: body", "Heading - body", "Heading\nbody".
            rf"{escaped}[:\s\-]+(.*?)(?=\n\s*\n|\Z)",
        ]
        for pattern in patterns:
            match = re.search(pattern, text, re.DOTALL | re.IGNORECASE)
            if match:
                content = match.group(1).strip()
                # Strip list bullets; only horizontal whitespace is consumed so
                # separate lines are never merged.
                return re.sub(r'^[ \t]*[-*][ \t]*', '', content, flags=re.MULTILINE)
        return ""

    text = text.replace('**', '').replace('__', '')
    return {
        "summary": extract_improved(text, "Summary of Patient's Medical History") or
                   extract_improved(text, "Summarize the patient's medical history"),
        "risks": extract_improved(text, "Identify Risks or Red Flags") or
                 extract_improved(text, "Risks or Red Flags"),
        "missed_issues": extract_improved(text, "Missed Diagnoses or Treatments") or
                         extract_improved(text, "What the doctor might have missed"),
        "recommendations": extract_improved(text, "Suggest Next Clinical Steps") or
                           extract_improved(text, "Suggested Clinical Actions"),
    }
def detect_suicide_risk(text: str) -> Tuple[RiskLevel, float, List[str]]:
    """Screen ``text`` for suicide-risk language and grade it.

    A cheap keyword screen runs first. If no keyword is present, the result is
    ``(RiskLevel.NONE, 0.0, [])`` and the model is not called. Otherwise the
    module-level ``agent`` is asked for a JSON assessment. If that call or its
    parsing fails, a keyword-count fallback is used.

    Args:
        text: Text to screen (analyze_patient passes the model's analysis output).

    Returns:
        ``(risk_level, risk_score, factors)``.
        - From the model: ``risk_score`` is clamped to [0.0, 1.0], the range the
          prompt asks for, and ``factors`` is coerced to a list of strings.
        - Fallback: the score is 0.1 per matched keyword, capped at 0.9, and the
          factors are the matched keywords.
    """
    suicide_keywords = [
        'suicide', 'suicidal', 'kill myself', 'end my life',
        'want to die', 'self-harm', 'self harm', 'hopeless',
        'no reason to live', 'plan to die'
    ]
    # Lower-case once instead of once per keyword.
    lowered = text.lower()
    explicit_mentions = [kw for kw in suicide_keywords if kw in lowered]
    if not explicit_mentions:
        return RiskLevel.NONE, 0.0, []

    # If found, ask AI for detailed assessment
    assessment_prompt = (
        "Assess the suicide risk level based on this text. "
        "Consider frequency, specificity, and severity of statements. "
        "Respond with JSON format: {\"risk_level\": \"low/moderate/high/severe\", "
        "\"risk_score\": 0-1, \"factors\": [\"list of risk factors\"]}\n\n"
        f"Text to assess:\n{text}"
    )
    try:
        response = agent.chat(
            message=assessment_prompt,
            history=[],
            temperature=0.2,  # Lower temp for more deterministic responses
            max_new_tokens=256
        )
        # Extract the outermost {...} span from the reply.
        json_match = re.search(r'\{.*\}', response, re.DOTALL)
        if json_match:
            assessment = json.loads(json_match.group())
            # A non-string or unknown level raises and falls through to the fallback.
            level = RiskLevel(assessment.get("risk_level", "none").strip().lower())
            score = min(max(float(assessment.get("risk_score", 0)), 0.0), 1.0)
            factors = assessment.get("factors", [])
            if not isinstance(factors, list):
                factors = [factors]
            return level, score, [str(factor) for factor in factors]
    except Exception as e:
        logger.error(f"Error in suicide risk assessment: {e}")

    # Fallback if the model call or JSON parsing fails
    risk_score = min(0.1 * len(explicit_mentions), 0.9)  # Cap at 0.9 for fallback
    if risk_score > 0.7:
        return RiskLevel.HIGH, risk_score, explicit_mentions
    elif risk_score > 0.4:
        return RiskLevel.MODERATE, risk_score, explicit_mentions
    return RiskLevel.LOW, risk_score, explicit_mentions
async def create_alert(patient_id: str, risk_data: dict):
    """Insert an unacknowledged ``suicide_risk`` alert into ``alerts_collection``.

    Args:
        patient_id: Identifier of the patient (analyze_patient passes the fhir_id).
        risk_data: Mapping with ``level``, ``score`` and ``factors`` keys. A
            missing key raises KeyError before anything is written.
    """
    level = risk_data["level"]
    score = risk_data["score"]
    factors = risk_data["factors"]
    await alerts_collection.insert_one({
        "patient_id": patient_id,
        "type": "suicide_risk",
        "level": level,
        "score": score,
        "factors": factors,
        "timestamp": datetime.utcnow(),
        "acknowledged": False,
    })
    logger.warning(f"⚠️ Created suicide risk alert for patient {patient_id}")
def serialize_patient(patient: dict) -> dict:
    """Return a shallow copy of ``patient`` with ``_id`` (if present) converted to ``str``.

    The input mapping is not modified.
    """
    duplicate = patient.copy()
    if "_id" not in duplicate:
        return duplicate
    duplicate["_id"] = str(duplicate["_id"])
    return duplicate
def compute_patient_data_hash(patient: dict) -> str:
    """Return the hex SHA-256 digest of ``patient`` serialised as canonical JSON.

    Keys are sorted so that logically equal documents hash identically.
    Values ``json`` cannot encode natively (e.g. datetime or ObjectId, if
    present) are serialised via ``str()`` instead of raising TypeError. Inputs
    that were already JSON-serialisable hash exactly as before, so previously
    stored hashes remain valid.
    """
    serialized = json.dumps(patient, sort_keys=True, default=str)
    return hashlib.sha256(serialized.encode()).hexdigest()
async def analyze_patient(patient: dict):
    """Analyse one patient document with the model and upsert the result.

    Steps:
    1. Serialise the document.
    2. Skip it if it has no ``fhir_id``, or if the stored analysis was built
       from data with the same SHA-256 hash.
    3. Otherwise ask ``agent`` for a four-part clinical review and split it
       into sections.
    4. Screen the output for suicide risk.
    5. Upsert the analysis keyed by ``patient_id``.
    6. For MODERATE or higher risk, create an alert.

    Any exception is logged with its traceback and swallowed, so a single bad
    record does not abort a batch run.

    Args:
        patient: Raw document from the patients collection.
    """
    patient_id = None
    try:
        serialized = serialize_patient(patient)
        patient_id = serialized.get("fhir_id")
        if not patient_id:
            # Analyses are keyed by patient_id. Without an id, every such patient
            # would upsert into the same {"patient_id": None} document and
            # overwrite the others' results.
            logger.warning(f"Skipping patient without fhir_id (_id={serialized.get('_id')})")
            return
        patient_hash = compute_patient_data_hash(serialized)
        logger.info(f"🧾 Analyzing patient: {patient_id}")

        # Check if analysis exists and hash matches
        existing_analysis = await analysis_collection.find_one({"patient_id": patient_id})
        if existing_analysis and existing_analysis.get("data_hash") == patient_hash:
            logger.info(f"✅ No changes in patient data for {patient_id}, skipping analysis")
            return  # Skip analysis if data hasn't changed

        # Main clinical analysis. default=str keeps non-JSON values (e.g.
        # datetimes, if present) from aborting the prompt build.
        doc = json.dumps(serialized, indent=2, default=str)
        message = (
            "You are a clinical decision support AI.\n\n"
            "Given the patient document below:\n"
            "1. Summarize the patient's medical history.\n"
            "2. Identify risks or red flags (including mental health and suicide risk).\n"
            "3. Highlight missed diagnoses or treatments.\n"
            "4. Suggest next clinical steps.\n"
            f"\nPatient Document:\n{'-'*40}\n{doc[:10000]}"
        )
        # NOTE(review): agent.chat is a synchronous call inside a coroutine, so
        # the event loop is blocked while the model runs — confirm acceptable.
        raw = agent.chat(message=message, history=[], temperature=0.7, max_new_tokens=1024)
        structured = structure_medical_response(raw)

        # Suicide risk assessment.
        # NOTE(review): this screens the model's output, and the prompt above
        # explicitly asks about suicide risk, so the keyword "suicide" may appear
        # even for low-risk patients — confirm this is the intended input.
        risk_level, risk_score, risk_factors = detect_suicide_risk(raw)
        suicide_risk = {
            "level": risk_level.value,
            "score": risk_score,
            "factors": risk_factors
        }

        # Store analysis with data hash
        analysis_doc = {
            "patient_id": patient_id,
            "timestamp": datetime.utcnow(),
            "summary": structured,
            "suicide_risk": suicide_risk,
            "raw": raw,
            "data_hash": patient_hash  # Store the hash
        }
        await analysis_collection.update_one(
            {"patient_id": patient_id},
            {"$set": analysis_doc},
            upsert=True
        )

        # Create alert if risk is above threshold
        if risk_level in [RiskLevel.MODERATE, RiskLevel.HIGH, RiskLevel.SEVERE]:
            await create_alert(patient_id, suicide_risk)

        logger.info(f"✅ Stored analysis for patient {patient_id}")
    except Exception as e:
        logger.exception(f"Error analyzing patient {patient_id}: {e}")
async def analyze_all_patients():
    """Run analyze_patient over every document in ``patients_collection``.

    The whole result set is loaded first, then patients are processed one at
    a time with a 0.1 s pause between them.
    """
    pause_seconds = 0.1
    cursor = patients_collection.find({})
    all_patients = await cursor.to_list(length=None)
    for record in all_patients:
        await analyze_patient(record)
        # Brief pause so other coroutines get a turn between patients.
        await asyncio.sleep(pause_seconds)
async def startup_event():
    """Initialise the model and database handles, then start background analysis.

    Actions:
    - Assigns the module globals ``agent`` and the three collections from the
      "cps_db" database.
    - Schedules analyze_all_patients() as a background task.
    - Stores that task on ``app.state.analysis_task``. asyncio keeps only weak
      references to tasks, so an unreferenced task could be garbage-collected
      before it finishes.

    NOTE(review): no decorator is visible here; presumably this is registered as
    the FastAPI startup handler elsewhere — confirm.
    """
    global agent, patients_collection, analysis_collection, alerts_collection

    agent = TxAgent(
        model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
        rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
        enable_finish=True,
        enable_rag=False,
        force_finish=True,
        enable_checker=True,
        step_rag_num=4,
        seed=42
    )
    agent.chat_prompt = (
        "You are a clinical assistant AI. Analyze the patient's data and provide clear clinical recommendations."
    )
    agent.init_model()
    logger.info("✅ TxAgent initialized")

    db = get_mongo_client()["cps_db"]
    patients_collection = db["patients"]
    analysis_collection = db["patient_analysis_results"]
    alerts_collection = db["clinical_alerts"]
    logger.info("📡 Connected to MongoDB")

    # Keep a strong reference so the background task cannot be garbage-collected mid-run.
    app.state.analysis_task = asyncio.create_task(analyze_all_patients())
async def status():
    """Liveness payload: fixed ``status``/``version`` plus the current UTC time in ISO 8601.

    NOTE(review): no route decorator is visible here — confirm how it is exposed.
    """
    now = datetime.utcnow()
    payload = {"status": "running"}
    payload["timestamp"] = now.isoformat()
    payload["version"] = "2.2.1"
    return payload
async def get_patient_analysis_results(name: Optional[str] = Query(None)):
    """Return up to 100 stored analyses, newest first, each tagged with the patient's name.

    Args:
        name: Optional filter. It is matched as a literal, case-insensitive
            substring of the patient's ``full_name``; it is NOT interpreted as a
            regular expression.

    Returns:
        Analysis documents with ``_id`` stringified and ``full_name`` added.
        Analyses whose patient record cannot be found are omitted. An empty
        list is returned when the name filter matches no patient.

    Raises:
        HTTPException: 500 if any database lookup fails.

    NOTE(review): no route decorator is visible here — confirm how it is exposed.
    """
    try:
        query = {}
        # If a name filter is provided, resolve it to patient ids first.
        if name:
            # Escape the user-supplied text: a raw pattern could be invalid (-> 500)
            # or pathologically slow when evaluated by the database.
            name_regex = re.compile(re.escape(name), re.IGNORECASE)
            matching_patients = await patients_collection.find({"full_name": name_regex}).to_list(length=None)
            patient_ids = [p["fhir_id"] for p in matching_patients if "fhir_id" in p]
            if not patient_ids:
                return []
            query = {"patient_id": {"$in": patient_ids}}

        # Find analysis results based on patient_ids (or all if no filter)
        analyses = await analysis_collection.find(query).sort("timestamp", -1).to_list(length=100)

        # One batched patient lookup instead of a find_one per analysis.
        wanted_ids = list({a["patient_id"] for a in analyses if "patient_id" in a})
        names_by_id = {}
        if wanted_ids:
            patients = await patients_collection.find(
                {"fhir_id": {"$in": wanted_ids}}, {"fhir_id": 1, "full_name": 1}
            ).to_list(length=None)
            for patient in patients:
                # setdefault keeps the first match, as find_one would.
                names_by_id.setdefault(patient["fhir_id"], patient.get("full_name", "Unknown"))

        # Attach full_name to each analysis result
        enriched_results = []
        for analysis in analyses:
            patient_id = analysis.get("patient_id")
            if patient_id not in names_by_id:
                continue
            analysis["full_name"] = names_by_id[patient_id]
            analysis["_id"] = str(analysis["_id"])
            enriched_results.append(analysis)
        return enriched_results
    except Exception as e:
        logger.error(f"Error fetching analysis results: {e}")
        raise HTTPException(status_code=500, detail="Failed to retrieve analysis results")
async def chat_stream_endpoint(request: ChatRequest):
    """Generate a reply to ``request`` and return it as a plain-text stream.

    The conversation is the agent's system prompt, then ``request.history``,
    then the new user message. The full completion is generated first and then
    emitted in word-sized chunks, with a 50 ms pause between chunks. Each chunk
    carries its surrounding whitespace, so newlines and indentation in the
    model output are preserved. On failure the stream ends with an
    "⚠️ Error: ..." chunk instead of raising.

    NOTE(review): no route decorator is visible here — confirm how it is exposed.
    NOTE(review): ``request.format`` is not consulted.
    NOTE(review): this reads agent.tokenizer / agent.device and calls
    agent.model.generate with transformers-style keyword arguments — confirm
    the TxAgent object exposes these.
    """
    async def token_stream():
        try:
            conversation = [{"role": "system", "content": agent.chat_prompt}]
            if request.history:
                conversation.extend(request.history)
            conversation.append({"role": "user", "content": request.message})
            input_ids = agent.tokenizer.apply_chat_template(
                conversation, add_generation_prompt=True, return_tensors="pt"
            ).to(agent.device)
            # NOTE(review): generate() is synchronous and blocks the event loop until done.
            output = agent.model.generate(
                input_ids,
                do_sample=True,
                temperature=request.temperature,
                max_new_tokens=request.max_new_tokens,
                pad_token_id=agent.tokenizer.eos_token_id,
                return_dict_in_generate=True
            )
            # Decode only the newly generated tokens (skip the prompt).
            text = agent.tokenizer.decode(output["sequences"][0][input_ids.shape[1]:], skip_special_tokens=True)
            # Word + surrounding whitespace per chunk: the chunks concatenate back
            # to `text` exactly, whereas str.split() would collapse newlines.
            for chunk in re.findall(r"\s*\S+\s*", text):
                yield chunk
                await asyncio.sleep(0.05)
        except Exception as e:
            logger.exception(f"Streaming error: {e}")
            yield f"⚠️ Error: {e}"

    return StreamingResponse(token_stream(), media_type="text/plain")