# Spaces: Running on A100
import os | |
import sys | |
import json | |
import logging | |
import re | |
import hashlib | |
import io | |
import base64 | |
from datetime import datetime | |
from typing import List, Dict, Optional, Tuple | |
from enum import Enum | |
from fastapi import FastAPI, HTTPException, UploadFile, File, Query, Form | |
from fastapi.responses import StreamingResponse, JSONResponse | |
from fastapi.middleware.cors import CORSMiddleware | |
from pydantic import BaseModel | |
import asyncio | |
from bson import ObjectId | |
import speech_recognition as sr | |
from gtts import gTTS | |
from pydub import AudioSegment | |
import PyPDF2 | |
import mimetypes | |
from txagent.txagent import TxAgent | |
from db.mongo import get_mongo_client | |
# Logging: root handler at INFO level with a timestamped line format
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger("TxAgentAPI")

# Application (2.6.0 made patient_id optional)
app = FastAPI(title="TxAgent API", version="2.6.0")
# NOTE(review): wildcard origins combined with allow_credentials=True — confirm
# this is intended; an explicit origin allow-list is usually safer.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Request bodies (pydantic)
class ChatRequest(BaseModel):
    """JSON body accepted by chat_stream_endpoint."""

    # The user's new message.
    message: str
    # Sampling temperature forwarded to model.generate().
    temperature: float = 0.7
    # Upper bound on generated tokens.
    max_new_tokens: int = 512
    # Prior chat turns, appended after the system prompt as-is.
    history: Optional[List[Dict]] = None
    # NOTE(review): not read by chat_stream_endpoint — confirm whether it is used.
    format: Optional[str] = "clean"
class VoiceInputRequest(BaseModel):
    """Options describing an uploaded voice clip.

    NOTE(review): transcribe_voice takes ``language`` as a query parameter
    rather than this model; confirm whether this model is still referenced.
    """

    # File format of the clip.
    audio_format: str = "wav"
    # Recognition language code.
    language: str = "en-US"
class VoiceOutputRequest(BaseModel):
    """JSON body accepted by synthesize_voice."""

    # Text to speak.
    text: str
    # gTTS language code.
    language: str = "en"
    # Ask gTTS for slower speech.
    slow: bool = False
    # "base64" -> JSON {"audio": ...}; any other value streams raw MP3.
    return_format: str = "mp3"
# Enumerations
class RiskLevel(str, Enum):
    """Suicide-risk grade; members are also plain strings (their value)."""

    # Declared from least to most severe.
    NONE = "none"
    LOW = "low"
    MODERATE = "moderate"
    HIGH = "high"
    SEVERE = "severe"
# Module state, populated by startup_event(); None until then.
agent: Optional["TxAgent"] = None
# NOTE(review): these are used with `await` below, so presumably async
# (e.g. Motor) collection handles — confirm against db.mongo.
patients_collection = None
analysis_collection = None
alerts_collection = None
# Helpers | |
def clean_text_response(text: str) -> str:
    """Normalize whitespace and drop bold/underline markdown markers.

    Runs of blank (or whitespace-only) lines collapse to one blank line, runs
    of spaces collapse to a single space, every ``**`` and ``__`` is deleted,
    and the result is stripped of surrounding whitespace.
    """
    normalized = re.sub(r'\n\s*\n', '\n\n', text)
    normalized = re.sub(r'[ ]+', ' ', normalized)
    for marker in ("**", "__"):
        normalized = normalized.replace(marker, "")
    return normalized.strip()
def extract_section(text: str, heading: str) -> str:
    """Return the body of the ``<heading>:`` section of *text*, or "".

    The heading is matched case-insensitively and must be followed by a
    colon and a line break. The section ends at the next line that starts
    with an uppercase letter and contains a colon (treated as the next
    heading), or at the end of the text. The result is stripped.
    """
    try:
        # Case-insensitivity is scoped to the heading only. With a global
        # IGNORECASE flag, the [A-Z] in the lookahead also matched lowercase
        # letters, so any body line such as "the patient said: ..." was
        # mistaken for the next heading and truncated the section.
        pattern = rf"(?i:{re.escape(heading)}):\s*\n(.*?)(?=\n[A-Z][^\n]*:|\Z)"
        match = re.search(pattern, text, re.DOTALL)
        return match.group(1).strip() if match else ""
    except Exception as e:
        logger.error(f"Section extraction failed for heading '{heading}': {e}")
        return ""
def structure_medical_response(text: str) -> Dict:
    """Split a free-form model answer into the four analysis sections.

    ``**`` and ``__`` markers are removed first. Each section is then looked
    up under a primary heading and, if that yields nothing, a fallback
    heading. Every heading is tried against several layouts (``Heading:``
    followed by a line break, ``Heading`` followed by spaces/dashes, or
    ``Heading`` alone on its line). Content runs until the next blank line
    or end of text, and leading ``-``/``*`` bullets are removed from each
    line. Missing sections map to "".

    Returns a dict with keys "summary", "risks", "missed_issues" and
    "recommendations".
    """
    section_headings = {
        "summary": ("Summary of Patient's Medical History",
                    "Summarize the patient's medical history"),
        "risks": ("Identify Risks or Red Flags",
                  "Risks or Red Flags"),
        "missed_issues": ("Missed Diagnoses or Treatments",
                          "What the doctor might have missed"),
        "recommendations": ("Suggest Next Clinical Steps",
                            "Suggested Clinical Actions"),
    }

    def _find_section(body: str, heading: str) -> str:
        escaped = re.escape(heading)
        layouts = (
            rf"{escaped}:\s*\n(.*?)(?=\n\s*\n|\Z)",
            rf"\*\*{escaped}\*\*:\s*\n(.*?)(?=\n\s*\n|\Z)",
            rf"{escaped}[\s\-]+(.*?)(?=\n\s*\n|\Z)",
            rf"\n{escaped}\s*\n(.*?)(?=\n\s*\n|\Z)",
        )
        for layout in layouts:
            found = re.search(layout, body, re.DOTALL | re.IGNORECASE)
            if found is not None:
                section = found.group(1).strip()
                # Remove a leading bullet marker on every line.
                return re.sub(r'^\s*[\-\*]\s*', '', section, flags=re.MULTILINE)
        return ""

    plain = text.replace('**', '').replace('__', '')
    structured = {}
    for key, (primary, fallback) in section_headings.items():
        structured[key] = _find_section(plain, primary) or _find_section(plain, fallback)
    return structured
def detect_suicide_risk(text: str) -> Tuple[RiskLevel, float, List[str]]:
    """Analyze text for suicide risk factors and return an assessment.

    A keyword screen runs first (case-insensitive substring match). When no
    keyword occurs, the result is ``(RiskLevel.NONE, 0.0, [])`` and the model
    is not consulted. Otherwise the module-level ``agent`` is asked for a JSON
    assessment. If that call fails, the reply holds no JSON object, or the
    JSON is unusable, a keyword-count heuristic is used instead.

    Returns:
        ``(level, score, factors)``. When the model's answer is used, ``score``
        is clamped to [0.0, 1.0] and ``factors`` is a list of strings.
    """
    suicide_keywords = [
        'suicide', 'suicidal', 'kill myself', 'end my life',
        'want to die', 'self-harm', 'self harm', 'hopeless',
        'no reason to live', 'plan to die'
    ]
    lowered = text.lower()  # lower-case once rather than once per keyword
    explicit_mentions = [kw for kw in suicide_keywords if kw in lowered]
    if not explicit_mentions:
        return RiskLevel.NONE, 0.0, []
    assessment_prompt = (
        "Assess the suicide risk level based on this text. "
        "Consider frequency, specificity, and severity of statements. "
        "Respond with JSON format: {\"risk_level\": \"low/moderate/high/severe\", "
        "\"risk_score\": 0-1, \"factors\": [\"list of risk factors\"]}\n\n"
        f"Text to assess:\n{text}"
    )
    try:
        response = agent.chat(
            message=assessment_prompt,
            history=[],
            temperature=0.2,
            max_new_tokens=256
        )
        # Greedy: everything from the first "{" to the last "}".
        json_match = re.search(r'\{.*\}', response, re.DOTALL)
        if json_match:
            assessment = json.loads(json_match.group())
            level = RiskLevel(assessment.get("risk_level", "none").lower())
            raw_score = float(assessment.get("risk_score", 0))
            if raw_score != raw_score:  # NaN (json.loads accepts "NaN")
                raise ValueError("risk_score is NaN")
            # The model was asked for a 0-1 score. Clamp so out-of-range
            # values do not propagate into stored analyses and alerts.
            score = min(max(raw_score, 0.0), 1.0)
            factors = assessment.get("factors", [])
            if isinstance(factors, str):
                factors = [factors]
            elif isinstance(factors, list):
                factors = [str(factor) for factor in factors]
            else:
                factors = []
            return level, score, factors
    except Exception as e:
        logger.error(f"Error in suicide risk assessment: {e}")
    # Fallback heuristic: 0.1 per distinct keyword found, capped at 0.9.
    # NOTE(review): in floating point 0.1 * 7 == 0.7000000000000001, so seven
    # mentions already classify as HIGH — confirm the intended cut-offs.
    risk_score = min(0.1 * len(explicit_mentions), 0.9)
    if risk_score > 0.7:
        return RiskLevel.HIGH, risk_score, explicit_mentions
    elif risk_score > 0.4:
        return RiskLevel.MODERATE, risk_score, explicit_mentions
    return RiskLevel.LOW, risk_score, explicit_mentions
async def create_alert(patient_id: str, risk_data: dict):
    """Insert an unacknowledged "suicide_risk" alert for *patient_id*.

    *risk_data* must provide "level", "score" and "factors"; they are copied
    into the alert together with the current UTC timestamp. A warning is
    logged after the insert.
    """
    level, score, factors = (risk_data[key] for key in ("level", "score", "factors"))
    await alerts_collection.insert_one({
        "patient_id": patient_id,
        "type": "suicide_risk",
        "level": level,
        "score": score,
        "factors": factors,
        "timestamp": datetime.utcnow(),
        "acknowledged": False,
    })
    logger.warning(f"⚠️ Created suicide risk alert for patient {patient_id}")
def serialize_patient(patient: dict) -> dict:
    """Return a shallow copy of *patient* with its ``_id`` converted to str.

    The input dict is left untouched, and key order is preserved.
    """
    return {key: (str(value) if key == "_id" else value) for key, value in patient.items()}
def compute_patient_data_hash(data: dict) -> str:
    """Compute SHA-256 hash of patient data or report.

    The dict is serialized to JSON with sorted keys, so key order does not
    affect the digest. Values the json module cannot encode natively (for
    example ``datetime``) are encoded with ``str()`` instead of raising
    ``TypeError``. Previously such records could never be hashed, and so were
    never analyzed. Digests of plain-JSON data are unchanged.
    """
    serialized = json.dumps(data, sort_keys=True, default=str)
    return hashlib.sha256(serialized.encode()).hexdigest()
def compute_file_content_hash(file_content: bytes) -> str:
    """Return the hex SHA-256 digest of raw file bytes."""
    digest = hashlib.sha256()
    digest.update(file_content)
    return digest.hexdigest()
def extract_text_from_pdf(pdf_data: bytes) -> str:
    """Extract the text of an in-memory PDF and normalize it.

    Page texts are joined with a newline, so the last word of one page is
    not fused with the first word of the next. Pages with no extractable
    text contribute "". The result goes through ``clean_text_response``.

    Raises:
        HTTPException(400): if reading or extracting the PDF fails.
    """
    try:
        pdf_reader = PyPDF2.PdfReader(io.BytesIO(pdf_data))
        page_texts = [page.extract_text() or "" for page in pdf_reader.pages]
        return clean_text_response("\n".join(page_texts))
    except Exception as e:
        logger.error(f"Error extracting text from PDF: {e}")
        raise HTTPException(status_code=400, detail="Failed to extract text from PDF") from e
async def analyze_patient_report(patient_id: Optional[str], report_content: str, file_type: str, file_content: bytes):
    """Analyze one report with TxAgent and persist the result.

    The report is keyed by an ``identifier``: ``patient_id`` when truthy,
    otherwise the SHA-256 of the raw file bytes. If an analysis with the same
    identifier and report hash is already stored, it is returned unchanged and
    the model is not called. Otherwise the model output is split into
    sections, assessed for suicide risk and upserted into
    ``analysis_collection``. When a ``patient_id`` is given and the risk is
    moderate or higher, an alert is created.

    Returns:
        The stored document (on a cache hit) or the newly built analysis dict.

    Raises:
        HTTPException(500): on any failure during analysis or storage.
    """
    try:
        identifier = patient_id or compute_file_content_hash(file_content)
        report_hash = compute_patient_data_hash(
            {"identifier": identifier, "content": report_content, "file_type": file_type}
        )
        logger.info(f"🧾 Analyzing report for identifier: {identifier}")

        # The same filter is used for the cache lookup and the upsert.
        cache_key = {"identifier": identifier, "report_hash": report_hash}
        cached = await analysis_collection.find_one(cache_key)
        if cached:
            logger.info(f"✅ No changes in report data for {identifier}, skipping analysis")
            return cached

        raw_response = agent.chat(
            message=(
                "You are a clinical decision support AI. Analyze the following patient report:\n"
                "1. Summarize the patient's medical history.\n"
                "2. Identify risks or red flags (including mental health and suicide risk).\n"
                "3. Highlight missed diagnoses or treatments.\n"
                "4. Suggest next clinical steps.\n"
                f"\nPatient Report ({file_type}):\n{'-'*40}\n{report_content[:10000]}"
            ),
            history=[],
            temperature=0.7,
            max_new_tokens=1024,
        )
        structured_response = structure_medical_response(raw_response)
        risk_level, risk_score, risk_factors = detect_suicide_risk(raw_response)
        suicide_risk = {"level": risk_level.value, "score": risk_score, "factors": risk_factors}

        analysis_doc = {
            "identifier": identifier,
            "patient_id": patient_id,  # None for anonymous uploads
            "timestamp": datetime.utcnow(),
            "summary": structured_response,
            "suicide_risk": suicide_risk,
            "raw": raw_response,
            "report_hash": report_hash,
            "file_type": file_type,
        }
        await analysis_collection.update_one(cache_key, {"$set": analysis_doc}, upsert=True)

        # Alerts need an owner, so anonymous uploads never raise one.
        if patient_id and risk_level in (RiskLevel.MODERATE, RiskLevel.HIGH, RiskLevel.SEVERE):
            await create_alert(patient_id, suicide_risk)
        logger.info(f"✅ Stored analysis for identifier {identifier}")
        return analysis_doc
    except Exception as e:
        logger.error(f"Error analyzing patient report: {e}")
        raise HTTPException(status_code=500, detail="Failed to analyze patient report")
async def analyze_all_patients():
    """Run analyze_patient on every document in patients_collection, one at a time.

    All patient documents are loaded first, and a 0.1 s pause follows each
    analysis.
    """
    cursor = patients_collection.find({})
    for record in await cursor.to_list(length=None):
        await analyze_patient(record)
        await asyncio.sleep(0.1)
async def analyze_patient(patient: dict):
    """Analyze patient data (existing logic for patient records).

    The record is serialized (``_id`` stringified) and hashed. If the stored
    analysis for this patient already has the same ``data_hash``, nothing is
    done. Otherwise the model's answer is split into sections, assessed for
    suicide risk and upserted into ``analysis_collection`` keyed by
    ``identifier`` (the patient's ``fhir_id``). Moderate or higher risk also
    creates an alert. Records without a ``fhir_id`` are skipped. Errors are
    logged and not re-raised.
    """
    try:
        serialized = serialize_patient(patient)
        patient_id = serialized.get("fhir_id")
        if not patient_id:
            # Without a fhir_id, every such record would share the single
            # {"identifier": None} analysis document, and alerts would have
            # no owning patient.
            logger.warning(f"Skipping patient without fhir_id (_id={serialized.get('_id')})")
            return
        patient_hash = compute_patient_data_hash(serialized)
        logger.info(f"🧾 Analyzing patient: {patient_id}")
        # Use the same filter as the upsert below, so the hash comparison is
        # made against the document that will actually be overwritten.
        existing_analysis = await analysis_collection.find_one({"identifier": patient_id})
        if existing_analysis and existing_analysis.get("data_hash") == patient_hash:
            logger.info(f"✅ No changes in patient data for {patient_id}, skipping analysis")
            return
        # default=str keeps non-JSON values (e.g. datetime) from aborting the prompt.
        doc = json.dumps(serialized, indent=2, default=str)
        message = (
            "You are a clinical decision support AI.\n\n"
            "Given the patient document below:\n"
            "1. Summarize the patient's medical history.\n"
            "2. Identify risks or red flags (including mental health and suicide risk).\n"
            "3. Highlight missed diagnoses or treatments.\n"
            "4. Suggest next clinical steps.\n"
            f"\nPatient Document:\n{'-'*40}\n{doc[:10000]}"
        )
        raw = agent.chat(message=message, history=[], temperature=0.7, max_new_tokens=1024)
        structured = structure_medical_response(raw)
        risk_level, risk_score, risk_factors = detect_suicide_risk(raw)
        suicide_risk = {
            "level": risk_level.value,
            "score": risk_score,
            "factors": risk_factors
        }
        analysis_doc = {
            "identifier": patient_id,
            "patient_id": patient_id,
            "timestamp": datetime.utcnow(),
            "summary": structured,
            "suicide_risk": suicide_risk,
            "raw": raw,
            "data_hash": patient_hash
        }
        await analysis_collection.update_one(
            {"identifier": patient_id},
            {"$set": analysis_doc},
            upsert=True
        )
        if risk_level in [RiskLevel.MODERATE, RiskLevel.HIGH, RiskLevel.SEVERE]:
            await create_alert(patient_id, suicide_risk)
        logger.info(f"✅ Stored analysis for patient {patient_id}")
    except Exception as e:
        logger.error(f"Error analyzing patient: {e}")
def recognize_speech(audio_data: bytes, language: str = "en-US") -> str:
    """Transcribe audio bytes with Google's recognizer (speech_recognition).

    Raises:
        HTTPException(400): the audio could not be understood.
        HTTPException(503): the recognition service request failed.
        HTTPException(500): any other error while reading or recognizing.
    """
    recognizer = sr.Recognizer()
    try:
        with io.BytesIO(audio_data) as buffer, sr.AudioFile(buffer) as source:
            captured = recognizer.record(source)
        transcript = recognizer.recognize_google(captured, language=language)
        return transcript
    except sr.UnknownValueError:
        logger.error("Google Speech Recognition could not understand audio")
        raise HTTPException(status_code=400, detail="Could not understand audio")
    except sr.RequestError as err:
        logger.error(f"Could not request results from Google Speech Recognition service; {err}")
        raise HTTPException(status_code=503, detail="Speech recognition service unavailable")
    except Exception as err:
        logger.error(f"Error in speech recognition: {err}")
        raise HTTPException(status_code=500, detail="Error processing speech")
def text_to_speech(text: str, language: str = "en", slow: bool = False) -> bytes:
    """Synthesize *text* with gTTS and return the MP3 bytes.

    Raises:
        HTTPException(500): if synthesis fails for any reason.
    """
    try:
        buffer = io.BytesIO()
        gTTS(text=text, lang=language, slow=slow).write_to_fp(buffer)
        return buffer.getvalue()
    except Exception as err:
        logger.error(f"Error in text-to-speech conversion: {err}")
        raise HTTPException(status_code=500, detail="Error generating speech")
async def startup_event():
    """Load TxAgent, bind the MongoDB collections and start background analysis.

    Sets the module-level ``agent``, ``patients_collection``,
    ``analysis_collection`` and ``alerts_collection``, then schedules
    ``analyze_all_patients`` as a background task. A strong reference to that
    task is kept on ``app.state.patient_analysis_task``.

    NOTE(review): no startup decorator is visible on this function in the
    reviewed source — confirm it is registered, otherwise those globals stay None.
    """
    global agent, patients_collection, analysis_collection, alerts_collection
    agent = TxAgent(
        model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
        rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
        enable_finish=True,
        enable_rag=False,
        force_finish=True,
        enable_checker=True,
        step_rag_num=4,
        seed=42
    )
    agent.chat_prompt = (
        "You are a clinical assistant AI. Analyze the patient's data and provide clear clinical recommendations."
    )
    agent.init_model()
    logger.info("✅ TxAgent initialized")
    db = get_mongo_client()["cps_db"]
    patients_collection = db["patients"]
    analysis_collection = db["patient_analysis_results"]
    alerts_collection = db["clinical_alerts"]
    logger.info("📡 Connected to MongoDB")

    def _log_failure(finished: asyncio.Task) -> None:
        # Report a crash of the background job instead of leaving the exception
        # unretrieved on the task.
        if not finished.cancelled() and finished.exception() is not None:
            logger.error(f"Background patient analysis failed: {finished.exception()}")

    # The event loop keeps only weak references to tasks. Hold a strong one so
    # the job cannot be garbage-collected mid-run.
    analysis_task = asyncio.create_task(analyze_all_patients())
    app.state.patient_analysis_task = analysis_task
    analysis_task.add_done_callback(_log_failure)
async def status():
    """Report liveness, the current UTC time (ISO 8601), API version and features."""
    now_iso = datetime.utcnow().isoformat()
    return dict(
        status="running",
        timestamp=now_iso,
        version="2.6.0",
        features=["chat", "voice-input", "voice-output", "patient-analysis", "report-upload"],
    )
async def get_patient_analysis_results(name: Optional[str] = Query(None)):
    """Return up to 100 stored analyses, newest first.

    When *name* is given, it is matched as a literal, case-insensitive
    substring of patients' ``full_name``. Only analyses whose ``patient_id``
    belongs to a matching patient are returned, and an empty list comes back
    if no patient matches. Each result has ``_id`` stringified and gets
    ``full_name`` when its patient is found.

    Raises:
        HTTPException(500): on any database error.
    """
    try:
        query = {}
        if name:
            # Escape the user input: it is a search term, not a regex. A raw
            # pattern allowed regex injection, and invalid input such as "("
            # raised re.error, which surfaced as a 500.
            name_regex = re.compile(re.escape(name), re.IGNORECASE)
            matching_patients = await patients_collection.find({"full_name": name_regex}).to_list(length=None)
            patient_ids = [p["fhir_id"] for p in matching_patients if "fhir_id" in p]
            if not patient_ids:
                return []
            query = {"patient_id": {"$in": patient_ids}}
        analyses = await analysis_collection.find(query).sort("timestamp", -1).to_list(length=100)
        enriched_results = []
        for analysis in analyses:
            patient_id = analysis.get("patient_id")
            # Report analyses may have no patient. Querying {"fhir_id": None}
            # would match any patient lacking a fhir_id and attach a wrong name.
            if patient_id is not None:
                patient = await patients_collection.find_one({"fhir_id": patient_id})
                if patient:
                    analysis["full_name"] = patient.get("full_name", "Unknown")
            analysis["_id"] = str(analysis["_id"])
            enriched_results.append(analysis)
        return enriched_results
    except Exception as e:
        logger.error(f"Error fetching analysis results: {e}")
        raise HTTPException(status_code=500, detail="Failed to retrieve analysis results")
async def chat_stream_endpoint(request: ChatRequest):
    """Answer ``request.message`` and return the reply as a plain-text stream.

    The conversation is the agent's system prompt, then ``request.history``
    (if any), then the new user message. The full reply is generated first
    and then sent in word-sized chunks with a 50 ms pause between them. Each
    chunk keeps the whitespace that followed its word, so line breaks survive.
    Failures are reported in-band as a final "⚠️ Error: ..." chunk.
    """
    async def token_stream():
        try:
            conversation = [{"role": "system", "content": agent.chat_prompt}]
            if request.history:
                conversation.extend(request.history)
            conversation.append({"role": "user", "content": request.message})
            input_ids = agent.tokenizer.apply_chat_template(
                conversation, add_generation_prompt=True, return_tensors="pt"
            ).to(agent.device)
            # NOTE(review): generate() is synchronous, so it blocks the event
            # loop for the whole generation.
            output = agent.model.generate(
                input_ids,
                do_sample=True,
                temperature=request.temperature,
                max_new_tokens=request.max_new_tokens,
                pad_token_id=agent.tokenizer.eos_token_id,
                return_dict_in_generate=True
            )
            # Only decode the newly generated tokens, not the prompt.
            text = agent.tokenizer.decode(output["sequences"][0][input_ids.shape[1]:], skip_special_tokens=True)
            # Split on word boundaries but keep each word's trailing whitespace.
            # str.split() dropped newlines and flattened multi-line answers.
            for chunk in re.findall(r"\S+\s*", text):
                yield chunk
                await asyncio.sleep(0.05)
        except Exception as e:
            logger.error(f"Streaming error: {e}")
            yield f"⚠️ Error: {e}"
    return StreamingResponse(token_stream(), media_type="text/plain")
async def transcribe_voice(
    audio: UploadFile = File(...),
    language: str = Query("en-US", description="Language code for speech recognition")
):
    """Convert speech to text.

    Accepts .wav, .mp3, .ogg and .flac uploads (judged by file extension).
    speech_recognition's AudioFile reads only WAV/AIFF/FLAC, so .mp3 and .ogg
    are first transcoded to WAV with pydub (which needs ffmpeg).

    Returns:
        ``{"text": <transcript>}``

    Raises:
        HTTPException(400): unsupported extension, or audio not understood.
        HTTPException(503): recognition service unavailable.
        HTTPException(500): any other failure, including transcoding.
    """
    try:
        # UploadFile.filename may be None; treat that as an unsupported format.
        filename = (audio.filename or "").lower()
        if not filename.endswith(('.wav', '.mp3', '.ogg', '.flac')):
            raise HTTPException(status_code=400, detail="Unsupported audio format")
        audio_data = await audio.read()
        if filename.endswith(('.mp3', '.ogg')):
            source_format = filename.rsplit('.', 1)[1]
            wav_buffer = io.BytesIO()
            AudioSegment.from_file(io.BytesIO(audio_data), format=source_format).export(wav_buffer, format="wav")
            audio_data = wav_buffer.getvalue()
        text = recognize_speech(audio_data, language)
        return {"text": text}
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error in voice transcription: {e}")
        raise HTTPException(status_code=500, detail="Error processing voice input")
async def synthesize_voice(request: VoiceOutputRequest):
    """Render ``request.text`` as speech.

    Returns ``{"audio": <base64 MP3>}`` when ``return_format`` is "base64".
    For any other value, the MP3 is streamed as an attachment named
    speech.mp3. HTTPExceptions pass through; other errors become a 500.
    """
    try:
        mp3_bytes = text_to_speech(request.text, request.language, request.slow)
        if request.return_format != "base64":
            return StreamingResponse(
                io.BytesIO(mp3_bytes),
                media_type="audio/mpeg",
                headers={"Content-Disposition": "attachment; filename=speech.mp3"},
            )
        encoded = base64.b64encode(mp3_bytes).decode('utf-8')
        return {"audio": encoded}
    except HTTPException:
        raise
    except Exception as err:
        logger.error(f"Error in voice synthesis: {err}")
        raise HTTPException(status_code=500, detail="Error generating voice output")
async def voice_chat_endpoint(
    audio: UploadFile = File(...),
    language: str = Query("en-US", description="Language code for speech recognition"),
    temperature: float = Query(0.7, ge=0.1, le=1.0),
    max_new_tokens: int = Query(512, ge=50, le=1024)
):
    """Complete voice chat interaction (speech-to-text -> AI -> text-to-speech).

    The upload is transcribed, answered by the agent with no history, and the
    reply is spoken using the primary subtag of *language* (e.g. "en-US" ->
    "en"). The result is streamed as response.mp3. HTTPExceptions pass
    through; other errors become a 500.
    """
    try:
        spoken_bytes = await audio.read()
        transcript = recognize_speech(spoken_bytes, language)
        reply_text = agent.chat(
            message=transcript,
            history=[],
            temperature=temperature,
            max_new_tokens=max_new_tokens,
        )
        tts_language = language.split('-')[0]
        reply_audio = text_to_speech(reply_text, tts_language)
        return StreamingResponse(
            io.BytesIO(reply_audio),
            media_type="audio/mpeg",
            headers={"Content-Disposition": "attachment; filename=response.mp3"},
        )
    except HTTPException:
        raise
    except Exception as err:
        logger.error(f"Error in voice chat: {err}")
        raise HTTPException(status_code=500, detail="Error processing voice chat")
async def analyze_clinical_report(
    file: Optional[UploadFile] = File(None),
    text: Optional[str] = Form(None),
    temperature: float = Form(0.5),
    max_new_tokens: int = Form(1024)
):
    """
    Analyze a clinical patient report either from uploaded file or direct text input.

    Parameters:
    - file: Uploaded clinical report file (.pdf, .txt, .md or .docx). When both
      a file and text are sent, the file is used.
    - text: Direct text input of the clinical report
    - temperature: Sampling temperature passed to the model (intended range
      0.1-1.0; not validated here)
    - max_new_tokens: Maximum length of response

    Returns the sections produced by structure_medical_response plus a
    "suicide_risk" entry.

    Raises HTTPException 400 in these cases:
    - no input was given;
    - the file type is unsupported;
    - a text file is not UTF-8;
    - the trimmed report is shorter than 50 characters.
    Any other failure raises HTTPException 500.
    """
    try:
        # Validate input
        if not (file or text):
            raise HTTPException(status_code=400, detail="Either file or text input is required")
        # Extract text from file if provided
        report_text = text
        if file:
            # UploadFile.filename may be None; treat that as an unknown format.
            filename = (file.filename or "").lower()
            if filename.endswith('.pdf'):
                import PyPDF2
                pdf_reader = PyPDF2.PdfReader(file.file)
                # extract_text() may yield None for pages without a text layer.
                report_text = "\n".join(page.extract_text() or "" for page in pdf_reader.pages)
            elif filename.endswith(('.txt', '.md')):
                try:
                    report_text = (await file.read()).decode('utf-8')
                except UnicodeDecodeError:
                    raise HTTPException(status_code=400, detail="Text file is not valid UTF-8")
            elif filename.endswith('.docx'):
                # python-docx reads only .docx (Office Open XML). Legacy binary
                # .doc files fall through to "Unsupported file format".
                from docx import Document
                doc = Document(io.BytesIO(await file.read()))
                report_text = "\n".join(para.text for para in doc.paragraphs)
            else:
                raise HTTPException(status_code=400, detail="Unsupported file format")
        # Clean and validate the extracted text
        if not report_text or len(report_text.strip()) < 50:
            raise HTTPException(status_code=400, detail="Report text is too short or empty")
        # Create analysis prompt
        prompt = (
            "You are a senior clinical analyst. Analyze this patient report and provide:\n"
            "1. SUMMARY: Concise summary of key findings\n"
            "2. DIAGNOSES: List of confirmed or suspected diagnoses\n"
            "3. RISK FACTORS: Important risk factors identified\n"
            "4. RED FLAGS: Any urgent concerns that need attention\n"
            "5. RECOMMENDATIONS: Suggested next steps for care\n\n"
            f"PATIENT REPORT:\n{report_text[:15000]}"  # Limit input size
        )
        # Get AI analysis
        raw_response = agent.chat(
            message=prompt,
            history=[],
            temperature=temperature,
            max_new_tokens=max_new_tokens
        )
        # Structure the response
        structured_response = structure_medical_response(raw_response)
        # Add suicide risk assessment
        risk_level, risk_score, risk_factors = detect_suicide_risk(raw_response)
        structured_response["suicide_risk"] = {
            "level": risk_level.value,
            "score": risk_score,
            "factors": risk_factors
        }
        return JSONResponse(content=structured_response)
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error analyzing clinical report: {e}")
        raise HTTPException(status_code=500, detail="Error processing clinical report")
if __name__ == "__main__":
    # Direct execution: serve the app with uvicorn on all interfaces, port 8000.
    from uvicorn import run as run_server
    run_server(app, host="0.0.0.0", port=8000)