Spaces:
Runtime error
Runtime error
| import os | |
| import sys | |
| import json | |
| import logging | |
| import re | |
| import hashlib | |
| import io | |
| import base64 | |
| from datetime import datetime | |
| from typing import List, Dict, Optional, Tuple | |
| from enum import Enum | |
| from fastapi import FastAPI, HTTPException, UploadFile, File, Query, Form | |
| from fastapi.responses import StreamingResponse, JSONResponse | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel | |
| import asyncio | |
| from bson import ObjectId | |
| import speech_recognition as sr | |
| from gtts import gTTS | |
| from pydub import AudioSegment | |
| import PyPDF2 | |
| import mimetypes | |
| from txagent.txagent import TxAgent | |
| from db.mongo import get_mongo_client | |
| from fastapi.encoders import jsonable_encoder | |
| from docx import Document | |
# Logging: timestamped, level-tagged records at INFO and above.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger("TxAgentAPI")

# App: version 2.6.0 is the release that made patient_id optional.
app = FastAPI(version="2.6.0", title="TxAgent API")

# CORS is left fully open: any origin, method and header, with credentials.
# NOTE(review): wildcard origins combined with credentials is very
# permissive -- confirm this is intended outside of development.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
| # Pydantic Models | |
class ChatRequest(BaseModel):
    # Request body for the chat endpoint.
    # The user's message text (required).
    message: str
    # Sampling temperature used for generation.
    temperature: float = 0.7
    # Upper bound on the number of generated tokens.
    max_new_tokens: int = 512
    # Earlier conversation turns; None means "no history".
    history: Optional[List[Dict]] = None
    # Requested output format; defaults to "clean".
    format: Optional[str] = "clean"
class VoiceInputRequest(BaseModel):
    # Options describing an incoming voice recording.
    # Container format of the audio.
    audio_format: str = "wav"
    # Language code of the spoken audio.
    language: str = "en-US"
class VoiceOutputRequest(BaseModel):
    # Request body for text-to-speech synthesis.
    # Text to be spoken (required).
    text: str
    # Language code for the synthesized voice.
    language: str = "en"
    # When True, ask for slower speech.
    slow: bool = False
    # How the audio is returned: "mp3" or "base64".
    return_format: str = "mp3"
| # Enums | |
class RiskLevel(str, Enum):
    """Risk categories, listed from least to most severe.

    Members are also ``str`` instances, so each one compares equal to its
    lowercase value.
    """

    NONE = "none"
    LOW = "low"
    MODERATE = "moderate"
    HIGH = "high"
    SEVERE = "severe"
# Globals
# All four are placeholders until startup_event() assigns the real objects.
agent = None
patients_collection = analysis_collection = alerts_collection = None
| # Helpers | |
def clean_text_response(text: str) -> str:
    """Normalize whitespace and strip markdown emphasis markers from *text*.

    Runs of blank (or whitespace-only) lines collapse to a single blank
    line, runs of spaces collapse to one space, every ``**`` and ``__`` is
    removed, and the result is stripped at both ends.
    """
    cleaned = re.sub(r'\n\s*\n', '\n\n', text)
    cleaned = re.sub(r'[ ]+', ' ', cleaned)
    for marker in ("**", "__"):
        cleaned = cleaned.replace(marker, "")
    return cleaned.strip()
def extract_section(text: str, heading: str) -> str:
    """Return the body that follows ``<heading>:`` in *text*.

    The body starts on the line after the heading and runs until the next
    line that begins with a letter and contains a colon, or to the end of
    the text. Matching ignores case. Returns "" when the heading is absent
    or when extraction raises.
    """
    try:
        section_re = re.compile(
            rf"{re.escape(heading)}:\s*\n(.*?)(?=\n[A-Z][^\n]*:|\Z)",
            re.DOTALL | re.IGNORECASE,
        )
        found = section_re.search(text)
        if found is None:
            return ""
        return found.group(1).strip()
    except Exception as e:
        logger.error(f"Section extraction failed for heading '{heading}': {e}")
        return ""
def structure_medical_response(text: str) -> Dict:
    """Split a model response into its four clinical sections.

    Handles both markdown and plain-text layouts. Returns a dict with the
    keys "summary", "risks", "missed_issues" and "recommendations". Each
    value is the text found under the primary heading or, failing that, an
    alternate heading; "" when neither is present.
    """

    def _grab(body: str, heading: str) -> str:
        # Try the heading layouts in order; the first one that matches wins.
        escaped = re.escape(heading)
        layouts = (
            rf"{escaped}:\s*\n(.*?)(?=\n\s*\n|\Z)",
            rf"\*\*{escaped}\*\*:\s*\n(.*?)(?=\n\s*\n|\Z)",
            rf"{escaped}[\s\-]+(.*?)(?=\n\s*\n|\Z)",
            rf"\n{escaped}\s*\n(.*?)(?=\n\s*\n|\Z)",
        )
        for layout in layouts:
            hit = re.search(layout, body, re.DOTALL | re.IGNORECASE)
            if hit is None:
                continue
            section = hit.group(1).strip()
            # Drop leading list bullets ("-" or "*") from every line.
            return re.sub(r'^\s*[\-\*]\s*', '', section, flags=re.MULTILINE)
        return ""

    # Emphasis markers are removed up front so headings match as plain text.
    plain = text.replace('**', '').replace('__', '')
    # (result key, primary heading, fallback heading)
    section_headings = (
        ("summary",
         "Summary of Patient's Medical History",
         "Summarize the patient's medical history"),
        ("risks",
         "Identify Risks or Red Flags",
         "Risks or Red Flags"),
        ("missed_issues",
         "Missed Diagnoses or Treatments",
         "What the doctor might have missed"),
        ("recommendations",
         "Suggest Next Clinical Steps",
         "Suggested Clinical Actions"),
    )
    return {
        key: _grab(plain, primary) or _grab(plain, fallback)
        for key, primary, fallback in section_headings
    }
def detect_suicide_risk(text: str) -> Tuple[RiskLevel, float, List[str]]:
    """Analyze text for suicide risk factors and return an assessment.

    The text is first screened for explicit keywords. If none occur the
    result is ``(RiskLevel.NONE, 0.0, [])`` and the model is not called.
    Otherwise the global ``agent`` is asked for a JSON assessment. When that
    call fails or its answer cannot be parsed, a keyword-count heuristic is
    used instead.

    Returns:
        ``(risk_level, risk_score, factors)``. A score taken from the model
        is clamped to the 0-1 range the prompt asks for.
    """
    suicide_keywords = [
        'suicide', 'suicidal', 'kill myself', 'end my life',
        'want to die', 'self-harm', 'self harm', 'hopeless',
        'no reason to live', 'plan to die'
    ]
    # Lower-case once instead of once per keyword.
    lowered = text.lower()
    explicit_mentions = [kw for kw in suicide_keywords if kw in lowered]
    if not explicit_mentions:
        return RiskLevel.NONE, 0.0, []
    assessment_prompt = (
        "Assess the suicide risk level based on this text. "
        "Consider frequency, specificity, and severity of statements. "
        "Respond with JSON format: {\"risk_level\": \"low/moderate/high/severe\", "
        "\"risk_score\": 0-1, \"factors\": [\"list of risk factors\"]}\n\n"
        f"Text to assess:\n{text}"
    )
    try:
        response = agent.chat(
            message=assessment_prompt,
            history=[],
            temperature=0.2,
            max_new_tokens=256
        )
        # Take everything from the first "{" to the last "}" as the JSON reply.
        json_match = re.search(r'\{.*\}', response, re.DOTALL)
        if json_match:
            assessment = json.loads(json_match.group())
            # An unknown level raises ValueError and drops to the heuristic.
            level = RiskLevel(assessment.get("risk_level", "none").lower())
            # The model is asked for 0-1 but may answer e.g. 7 or 85; clamp
            # so stored scores stay comparable.
            score = min(max(float(assessment.get("risk_score", 0)), 0.0), 1.0)
            factors = assessment.get("factors", [])
            if not isinstance(factors, list):
                # Keep the declared List[str] shape if a scalar came back.
                factors = [str(factors)]
            return level, score, factors
    except Exception as e:
        logger.error(f"Error in suicide risk assessment: {e}")
    # Heuristic fallback: 0.1 per distinct keyword found, capped at 0.9.
    risk_score = min(0.1 * len(explicit_mentions), 0.9)
    if risk_score > 0.7:
        return RiskLevel.HIGH, risk_score, explicit_mentions
    elif risk_score > 0.4:
        return RiskLevel.MODERATE, risk_score, explicit_mentions
    return RiskLevel.LOW, risk_score, explicit_mentions
async def create_alert(patient_id: str, risk_data: dict):
    """Insert a suicide-risk alert document for *patient_id*.

    *risk_data* must provide the keys "level", "score" and "factors"; they
    are copied into the alert. The alert is stored unacknowledged with the
    current UTC time, and a warning is logged afterwards.
    """
    alert_doc = {"patient_id": patient_id, "type": "suicide_risk"}
    for field in ("level", "score", "factors"):
        alert_doc[field] = risk_data[field]
    alert_doc["timestamp"] = datetime.utcnow()
    alert_doc["acknowledged"] = False
    await alerts_collection.insert_one(alert_doc)
    logger.warning(f"⚠️ Created suicide risk alert for patient {patient_id}")
def serialize_patient(patient: dict) -> dict:
    """Return a shallow copy of *patient* whose "_id", if present, is a string.

    The input mapping itself is not modified.
    """
    serialized = patient.copy()
    if "_id" not in serialized:
        return serialized
    serialized["_id"] = str(serialized["_id"])
    return serialized
def compute_patient_data_hash(data: dict) -> str:
    """Compute the SHA-256 hex digest of patient data or a report.

    The dict is serialized as JSON with sorted keys, so key order does not
    affect the digest. Values JSON cannot encode natively (for example
    datetimes) are hashed via their ``str()`` form instead of raising
    TypeError. Digests for data that already serialized are unchanged.
    """
    serialized = json.dumps(data, sort_keys=True, default=str)
    return hashlib.sha256(serialized.encode()).hexdigest()
def compute_file_content_hash(file_content: bytes) -> str:
    """Return the SHA-256 digest of *file_content* as a hex string."""
    digest = hashlib.sha256()
    digest.update(file_content)
    return digest.hexdigest()
def extract_text_from_pdf(pdf_data: bytes) -> str:
    """Return the cleaned, concatenated text of every page in a PDF.

    Pages that yield no text contribute an empty string.

    Raises:
        HTTPException: 400 if the PDF cannot be read or its text extracted.
    """
    try:
        reader = PyPDF2.PdfReader(io.BytesIO(pdf_data))
        page_texts = [page.extract_text() or "" for page in reader.pages]
        return clean_text_response("".join(page_texts))
    except Exception as e:
        logger.error(f"Error extracting text from PDF: {e}")
        raise HTTPException(status_code=400, detail="Failed to extract text from PDF")
async def analyze_patient_report(patient_id: Optional[str], report_content: str, file_type: str, file_content: bytes):
    """Analyze an uploaded report with the agent and persist the result.

    The report is keyed by *patient_id* when given, otherwise by the SHA-256
    of the raw file bytes. If an analysis with the same identifier and
    report hash already exists it is returned as stored. Otherwise the agent
    is prompted with the first 10000 characters of the report, the answer is
    structured and screened for suicide risk, and the result is upserted.
    An alert is created only when a patient_id was supplied and the risk
    level is moderate or above.

    Returns:
        The stored analysis document (existing or newly built).

    Raises:
        HTTPException: 500 on any failure.
    """
    try:
        # Fall back to the file content hash when no patient_id is provided.
        identifier = patient_id or compute_file_content_hash(file_content)
        report_hash = compute_patient_data_hash(
            {"identifier": identifier, "content": report_content, "file_type": file_type}
        )
        logger.info(f"🧾 Analyzing report for identifier: {identifier}")

        # The same identifier + hash means this exact report was analyzed before.
        lookup = {"identifier": identifier, "report_hash": report_hash}
        cached = await analysis_collection.find_one(lookup)
        if cached:
            logger.info(f"✅ No changes in report data for {identifier}, skipping analysis")
            return cached

        # Build the analysis prompt; only a bounded excerpt of the report is sent.
        divider = '-' * 40
        excerpt = report_content[:10000]
        prompt = (
            "You are a clinical decision support AI. Analyze the following patient report:\n"
            "1. Summarize the patient's medical history.\n"
            "2. Identify risks or red flags (including mental health and suicide risk).\n"
            "3. Highlight missed diagnoses or treatments.\n"
            "4. Suggest next clinical steps.\n"
            f"\nPatient Report ({file_type}):\n{divider}\n{excerpt}"
        )
        raw_response = agent.chat(
            message=prompt,
            history=[],
            temperature=0.7,
            max_new_tokens=1024,
        )
        structured_response = structure_medical_response(raw_response)

        # The suicide risk screen runs on the model's answer.
        risk_level, risk_score, risk_factors = detect_suicide_risk(raw_response)
        suicide_risk = {"level": risk_level.value, "score": risk_score, "factors": risk_factors}

        analysis_doc = {
            "identifier": identifier,
            "patient_id": patient_id,  # May be None
            "timestamp": datetime.utcnow(),
            "summary": structured_response,
            "suicide_risk": suicide_risk,
            "raw": raw_response,
            "report_hash": report_hash,
            "file_type": file_type,
        }
        await analysis_collection.update_one(lookup, {"$set": analysis_doc}, upsert=True)

        # Alerts need a patient to attach to, so anonymous reports never alert.
        alert_levels = (RiskLevel.MODERATE, RiskLevel.HIGH, RiskLevel.SEVERE)
        if patient_id and risk_level in alert_levels:
            await create_alert(patient_id, suicide_risk)
        logger.info(f"✅ Stored analysis for identifier {identifier}")
        return analysis_doc
    except Exception as e:
        logger.error(f"Error analyzing patient report: {e}")
        raise HTTPException(status_code=500, detail="Failed to analyze patient report")
async def analyze_all_patients():
    """Run analyze_patient() over every document in the patients collection.

    All patients are loaded first, then processed one at a time with a
    0.1 s pause after each.
    """
    cursor = patients_collection.find({})
    for record in await cursor.to_list(length=None):
        await analyze_patient(record)
        await asyncio.sleep(0.1)
async def analyze_patient(patient: dict):
    """Analyze one patient record and store the result.

    The record is serialized and hashed. If an analysis for this patient
    already carries the same data hash, nothing is done. Otherwise the agent
    is prompted with the first 10000 characters of the JSON document, the
    answer is structured and screened for suicide risk, the analysis is
    upserted under the patient's ``fhir_id``, and an alert is created for
    moderate or higher risk.

    Errors are logged and swallowed so one bad record does not stop a batch
    run; nothing is returned.
    """
    try:
        serialized = serialize_patient(patient)
        patient_id = serialized.get("fhir_id")
        patient_hash = compute_patient_data_hash(serialized)
        logger.info(f"🧾 Analyzing patient: {patient_id}")
        # Match on the hash inside the query. Several analysis documents can
        # share a patient_id (uploaded reports are stored with one too), and
        # an arbitrary one of them may lack data_hash, which used to force a
        # needless re-analysis.
        existing_analysis = await analysis_collection.find_one(
            {"patient_id": patient_id, "data_hash": patient_hash}
        )
        if existing_analysis:
            logger.info(f"✅ No changes in patient data for {patient_id}, skipping analysis")
            return
        # default=str lets values JSON cannot encode (e.g. datetimes) be
        # rendered instead of aborting the analysis with TypeError.
        doc = json.dumps(serialized, indent=2, default=str)
        message = (
            "You are a clinical decision support AI.\n\n"
            "Given the patient document below:\n"
            "1. Summarize the patient's medical history.\n"
            "2. Identify risks or red flags (including mental health and suicide risk).\n"
            "3. Highlight missed diagnoses or treatments.\n"
            "4. Suggest next clinical steps.\n"
            f"\nPatient Document:\n{'-'*40}\n{doc[:10000]}"
        )
        raw = agent.chat(message=message, history=[], temperature=0.7, max_new_tokens=1024)
        structured = structure_medical_response(raw)
        risk_level, risk_score, risk_factors = detect_suicide_risk(raw)
        suicide_risk = {
            "level": risk_level.value,
            "score": risk_score,
            "factors": risk_factors
        }
        analysis_doc = {
            "identifier": patient_id,
            "patient_id": patient_id,
            "timestamp": datetime.utcnow(),
            "summary": structured,
            "suicide_risk": suicide_risk,
            "raw": raw,
            "data_hash": patient_hash
        }
        # NOTE(review): this filter can also match a report analysis stored
        # under the same identifier -- confirm whether record and report
        # analyses are meant to share one document.
        await analysis_collection.update_one(
            {"identifier": patient_id},
            {"$set": analysis_doc},
            upsert=True
        )
        if risk_level in [RiskLevel.MODERATE, RiskLevel.HIGH, RiskLevel.SEVERE]:
            await create_alert(patient_id, suicide_risk)
        logger.info(f"✅ Stored analysis for patient {patient_id}")
    except Exception as e:
        logger.error(f"Error analyzing patient: {e}")
def recognize_speech(audio_data: bytes, language: str = "en-US") -> str:
    """Convert speech to text using Google's speech recognition.

    ``sr.AudioFile`` only reads PCM WAV, AIFF and native FLAC. Data it
    rejects (for example MP3 or OGG) is transcoded to WAV with pydub and
    retried, so the formats accepted by the upload endpoint can be
    recognized.

    Raises:
        HTTPException: 400 if the speech is unintelligible, 503 if the
            recognition service cannot be reached, 500 for any other
            failure (including audio that cannot be decoded).
    """
    recognizer = sr.Recognizer()

    def _record(stream):
        # Read the whole stream into an AudioData object.
        with sr.AudioFile(stream) as source:
            return recognizer.record(source)

    try:
        try:
            audio = _record(io.BytesIO(audio_data))
        except ValueError:
            # Not a format AudioFile understands: let pydub decode it and
            # re-encode as WAV, then try again.
            wav_buffer = io.BytesIO()
            AudioSegment.from_file(io.BytesIO(audio_data)).export(wav_buffer, format="wav")
            wav_buffer.seek(0)
            audio = _record(wav_buffer)
        text = recognizer.recognize_google(audio, language=language)
        return text
    except sr.UnknownValueError:
        logger.error("Google Speech Recognition could not understand audio")
        raise HTTPException(status_code=400, detail="Could not understand audio")
    except sr.RequestError as e:
        logger.error(f"Could not request results from Google Speech Recognition service; {e}")
        raise HTTPException(status_code=503, detail="Speech recognition service unavailable")
    except Exception as e:
        logger.error(f"Error in speech recognition: {e}")
        raise HTTPException(status_code=500, detail="Error processing speech")
def text_to_speech(text: str, language: str = "en", slow: bool = False) -> bytes:
    """Synthesize *text* with gTTS and return the audio as MP3 bytes.

    Raises:
        HTTPException: 500 if synthesis fails for any reason.
    """
    try:
        buffer = io.BytesIO()
        gTTS(text=text, lang=language, slow=slow).write_to_fp(buffer)
        return buffer.getvalue()
    except Exception as e:
        logger.error(f"Error in text-to-speech conversion: {e}")
        raise HTTPException(status_code=500, detail="Error generating speech")
async def startup_event():
    """Initialize the agent and database handles, then start batch analysis.

    Assigns the module-level ``agent`` and the three collection globals,
    then schedules analyze_all_patients() as a background task.

    NOTE(review): no ``@app.on_event("startup")`` (or lifespan) registration
    is visible in this file -- confirm this coroutine is actually wired up.
    """
    global agent, patients_collection, analysis_collection, alerts_collection
    agent = TxAgent(
        model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
        rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
        enable_finish=True,
        enable_rag=False,
        force_finish=True,
        enable_checker=True,
        step_rag_num=4,
        seed=42
    )
    agent.chat_prompt = (
        "You are a clinical assistant AI. Analyze the patient's data and provide clear clinical recommendations."
    )
    agent.init_model()
    logger.info("✅ TxAgent initialized")
    db = get_mongo_client()["cps_db"]
    patients_collection = db["patients"]
    analysis_collection = db["patient_analysis_results"]
    alerts_collection = db["clinical_alerts"]
    logger.info("📡 Connected to MongoDB")
    # The event loop holds only a weak reference to tasks. Keep a strong one
    # on the app so the batch analysis cannot be garbage-collected mid-run.
    app.state.background_analysis_task = asyncio.create_task(analyze_all_patients())
async def status():
    """Report service liveness, the current UTC time, version and features."""
    now_iso = datetime.utcnow().isoformat()
    features = ["chat", "voice-input", "voice-output", "patient-analysis", "report-upload"]
    payload = {"status": "running"}
    payload["timestamp"] = now_iso
    payload["version"] = "2.6.0"
    payload["features"] = features
    return payload
async def get_patient_analysis_results(name: Optional[str] = Query(None)):
    """Return up to 100 stored analyses, newest first, with patient names.

    Parameters:
    - name: optional case-insensitive substring of the patient's full name;
      when given, only analyses of matching patients are returned.

    Each analysis gets its ``_id`` converted to a string and, when the
    patient can be found by ``fhir_id``, a ``full_name`` field.

    Raises:
        HTTPException: 500 if the lookup fails.
    """
    try:
        query = {}
        if name:
            # The name comes from the client: escape it so it is matched
            # literally. A raw pattern could be invalid (-> 500) or
            # pathologically slow.
            name_regex = re.compile(re.escape(name), re.IGNORECASE)
            matching_patients = await patients_collection.find({"full_name": name_regex}).to_list(length=None)
            patient_ids = [p["fhir_id"] for p in matching_patients if "fhir_id" in p]
            if not patient_ids:
                return []
            query = {"patient_id": {"$in": patient_ids}}
        analyses = await analysis_collection.find(query).sort("timestamp", -1).to_list(length=100)

        # Resolve all patient names in one query instead of one per analysis.
        # Analyses without a patient_id are skipped: querying fhir_id == None
        # would match any patient that merely lacks the field.
        wanted_ids = list({a.get("patient_id") for a in analyses if a.get("patient_id")})
        names_by_id = {}
        if wanted_ids:
            patients = await patients_collection.find(
                {"fhir_id": {"$in": wanted_ids}}
            ).to_list(length=None)
            for patient in patients:
                # setdefault keeps the first document per fhir_id.
                names_by_id.setdefault(patient.get("fhir_id"), patient.get("full_name", "Unknown"))

        enriched_results = []
        for analysis in analyses:
            patient_id = analysis.get("patient_id")
            if patient_id in names_by_id:
                analysis["full_name"] = names_by_id[patient_id]
            analysis["_id"] = str(analysis["_id"])
            enriched_results.append(analysis)
        return enriched_results
    except Exception as e:
        logger.error(f"Error fetching analysis results: {e}")
        raise HTTPException(status_code=500, detail="Failed to retrieve analysis results")
async def chat_stream_endpoint(request: ChatRequest):
    """Generate a chat reply and stream it back as plain text.

    The complete reply is generated first, then emitted one
    whitespace-separated word at a time with a 0.05 s pause between words.
    If anything fails, a single error line is streamed instead.
    """

    async def token_stream():
        try:
            # System prompt, then any prior turns, then the new user message.
            messages = [{"role": "system", "content": agent.chat_prompt}]
            if request.history:
                messages.extend(request.history)
            messages.append({"role": "user", "content": request.message})

            prompt_ids = agent.tokenizer.apply_chat_template(
                messages, add_generation_prompt=True, return_tensors="pt"
            ).to(agent.device)
            generated = agent.model.generate(
                prompt_ids,
                do_sample=True,
                temperature=request.temperature,
                max_new_tokens=request.max_new_tokens,
                pad_token_id=agent.tokenizer.eos_token_id,
                return_dict_in_generate=True,
            )
            # Decode only the tokens that come after the prompt.
            prompt_length = prompt_ids.shape[1]
            reply_ids = generated["sequences"][0][prompt_length:]
            reply = agent.tokenizer.decode(reply_ids, skip_special_tokens=True)
            for word in reply.split():
                yield word + " "
                await asyncio.sleep(0.05)
        except Exception as e:
            logger.error(f"Streaming error: {e}")
            yield f"⚠️ Error: {e}"

    return StreamingResponse(token_stream(), media_type="text/plain")
async def transcribe_voice(
    audio: UploadFile = File(...),
    language: str = Query("en-US", description="Language code for speech recognition")
):
    """Convert speech to text.

    Parameters:
    - audio: Uploaded recording; the filename must end in .wav, .mp3, .ogg
      or .flac
    - language: Language code passed to the recognizer

    Returns ``{"text": <transcript>}``.

    Raises:
        HTTPException: 400 for a missing or unsupported filename, otherwise
            whatever recognize_speech() raises, or 500 for unexpected errors.
    """
    try:
        # Validate before reading so rejected uploads are never loaded into
        # memory. A missing filename is treated as unsupported rather than
        # crashing on None.
        filename = (audio.filename or "").lower()
        if not filename.endswith(('.wav', '.mp3', '.ogg', '.flac')):
            raise HTTPException(status_code=400, detail="Unsupported audio format")
        audio_data = await audio.read()
        text = recognize_speech(audio_data, language)
        return {"text": text}
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error in voice transcription: {e}")
        raise HTTPException(status_code=500, detail="Error processing voice input")
async def synthesize_voice(request: VoiceOutputRequest):
    """Turn the request text into speech.

    Returns ``{"audio": <base64 string>}`` when ``return_format`` is
    "base64"; otherwise the MP3 is streamed as an attachment named
    speech.mp3.

    Raises:
        HTTPException: propagated from text_to_speech(), or 500 for any
            other failure.
    """
    try:
        mp3_bytes = text_to_speech(request.text, request.language, request.slow)
        if request.return_format == "base64":
            encoded = base64.b64encode(mp3_bytes)
            return {"audio": encoded.decode('utf-8')}
        download_headers = {"Content-Disposition": "attachment; filename=speech.mp3"}
        return StreamingResponse(
            io.BytesIO(mp3_bytes),
            media_type="audio/mpeg",
            headers=download_headers,
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error in voice synthesis: {e}")
        raise HTTPException(status_code=500, detail="Error generating voice output")
async def voice_chat_endpoint(
    audio: UploadFile = File(...),
    language: str = Query("en-US", description="Language code for speech recognition"),
    temperature: float = Query(0.7, ge=0.1, le=1.0),
    max_new_tokens: int = Query(512, ge=50, le=1024)
):
    """Full voice round trip: speech-to-text, agent reply, text-to-speech.

    The upload is transcribed, sent to the agent with an empty history, and
    the reply is spoken in the base language (the part of *language* before
    the first "-"). The MP3 is streamed as an attachment named response.mp3.

    Raises:
        HTTPException: propagated from the speech helpers, or 500 for any
            other failure.
    """
    try:
        recording = await audio.read()
        user_message = recognize_speech(recording, language)
        reply_text = agent.chat(
            message=user_message,
            history=[],
            temperature=temperature,
            max_new_tokens=max_new_tokens,
        )
        # gTTS is given the bare language, e.g. "en" out of "en-US".
        base_language = language.partition('-')[0]
        reply_audio = text_to_speech(reply_text, base_language)
        download_headers = {"Content-Disposition": "attachment; filename=response.mp3"}
        return StreamingResponse(
            io.BytesIO(reply_audio),
            media_type="audio/mpeg",
            headers=download_headers,
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error in voice chat: {e}")
        raise HTTPException(status_code=500, detail="Error processing voice chat")
async def analyze_clinical_report(
    file: UploadFile = File(...),
    patient_id: Optional[str] = Form(None),
    temperature: float = Form(0.5),
    max_new_tokens: int = Form(1024)
):
    """
    Analyze a clinical patient report from an uploaded file.

    Parameters:
    - file: Uploaded clinical report file (PDF, TXT, DOCX)
    - patient_id: Optional patient ID to associate with this report
    - temperature: Controls randomness of response (0.1-1.0)
    - max_new_tokens: Maximum length of response

    NOTE(review): temperature and max_new_tokens are accepted but not
    forwarded to analyze_patient_report(), which uses its own fixed values.

    Returns a JSON object with "status", "analysis", "patient_id",
    "file_type" and "file_size".

    Raises:
        HTTPException: 400 for an unsupported type, undecodable text or
            too-short content; 500 for any other failure.
    """
    try:
        # Validate file type
        content_type = file.content_type
        allowed_types = [
            'application/pdf',
            'text/plain',
            'application/vnd.openxmlformats-officedocument.wordprocessingml.document'
        ]
        if content_type not in allowed_types:
            raise HTTPException(
                status_code=400,
                detail=f"Unsupported file type: {content_type}. Supported types: PDF, TXT, DOCX"
            )
        # Read file content
        file_content = await file.read()
        # Extract text from file
        if content_type == 'application/pdf':
            text = extract_text_from_pdf(file_content)
        elif content_type == 'text/plain':
            try:
                text = file_content.decode('utf-8')
            except UnicodeDecodeError:
                # A wrongly encoded upload is a client error, not a server fault.
                raise HTTPException(
                    status_code=400,
                    detail="Text file is not valid UTF-8"
                )
        elif content_type == 'application/vnd.openxmlformats-officedocument.wordprocessingml.document':
            doc = Document(io.BytesIO(file_content))
            text = "\n".join([para.text for para in doc.paragraphs])
        else:
            # Defensive: unreachable while the branches above cover allowed_types.
            raise HTTPException(status_code=400, detail="Unsupported file type")
        # Clean and validate text
        text = clean_text_response(text)
        if len(text.strip()) < 50:
            raise HTTPException(
                status_code=400,
                detail="Extracted text is too short (minimum 50 characters required)"
            )
        # Analyze the report
        analysis = await analyze_patient_report(
            patient_id=patient_id,
            report_content=text,
            file_type=content_type,
            file_content=file_content
        )
        # Manually convert ObjectId and timestamp if needed
        if "_id" in analysis and isinstance(analysis["_id"], ObjectId):
            analysis["_id"] = str(analysis["_id"])
        if "timestamp" in analysis and isinstance(analysis["timestamp"], datetime):
            analysis["timestamp"] = analysis["timestamp"].isoformat()
        # Return response using jsonable_encoder
        return JSONResponse(content=jsonable_encoder({
            "status": "success",
            "analysis": analysis,
            "patient_id": patient_id,
            "file_type": content_type,
            "file_size": len(file_content)
        }))
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error in report analysis: {str(e)}")
        raise HTTPException(
            status_code=500,
            detail=f"Failed to analyze report: {str(e)}"
        )
if __name__ == "__main__":
    # Imported here so uvicorn is only needed when run as a script.
    import uvicorn

    # Serve on all interfaces, port 8000.
    uvicorn.run(app, port=8000, host="0.0.0.0")