Spaces:
				
			
			
	
			
			
		Runtime error
		
	
	
	
			
			
	
	
	
	
		
		
		Runtime error
		
	Update app.py
Browse files
    	
        app.py
    CHANGED
    
    | @@ -1,3 +1,4 @@ | |
|  | |
| 1 | 
             
            import os
         | 
| 2 | 
             
            import sys
         | 
| 3 | 
             
            import json
         | 
| @@ -9,9 +10,11 @@ import base64 | |
| 9 | 
             
            from datetime import datetime
         | 
| 10 | 
             
            from typing import List, Dict, Optional, Tuple
         | 
| 11 | 
             
            from enum import Enum
         | 
| 12 | 
            -
            from fastapi import FastAPI, HTTPException, UploadFile, File, Query, Form
         | 
| 13 | 
             
            from fastapi.responses import StreamingResponse, JSONResponse
         | 
| 14 | 
             
            from fastapi.middleware.cors import CORSMiddleware
         | 
|  | |
|  | |
| 15 | 
             
            from pydantic import BaseModel
         | 
| 16 | 
             
            import asyncio
         | 
| 17 | 
             
            from bson import ObjectId
         | 
| @@ -20,17 +23,19 @@ from gtts import gTTS | |
| 20 | 
             
            from pydub import AudioSegment
         | 
| 21 | 
             
            import PyPDF2
         | 
| 22 | 
             
            import mimetypes
         | 
|  | |
|  | |
| 23 | 
             
            from txagent.txagent import TxAgent
         | 
| 24 | 
             
            from db.mongo import get_mongo_client
         | 
| 25 | 
            -
             | 
| 26 | 
            -
            from docx import Document 
         | 
| 27 | 
             
            # Logging
         | 
| 28 | 
             
            logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
         | 
| 29 | 
             
            logger = logging.getLogger("TxAgentAPI")
         | 
| 30 |  | 
| 31 | 
             
            # App
         | 
| 32 | 
            -
            app = FastAPI(title="TxAgent API", version="2.6.0") | 
| 33 |  | 
|  | |
| 34 | 
             
            app.add_middleware(
         | 
| 35 | 
             
                CORSMiddleware,
         | 
| 36 | 
             
                allow_origins=["*"],
         | 
| @@ -39,6 +44,13 @@ app.add_middleware( | |
| 39 | 
             
                allow_headers=["*"]
         | 
| 40 | 
             
            )
         | 
| 41 |  | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 42 | 
             
            # Pydantic Models
         | 
| 43 | 
             
            class ChatRequest(BaseModel):
         | 
| 44 | 
             
                message: str
         | 
| @@ -55,7 +67,7 @@ class VoiceOutputRequest(BaseModel): | |
| 55 | 
             
                text: str
         | 
| 56 | 
             
                language: str = "en"
         | 
| 57 | 
             
                slow: bool = False
         | 
| 58 | 
            -
                return_format: str = "mp3" | 
| 59 |  | 
| 60 | 
             
            # Enums
         | 
| 61 | 
             
            class RiskLevel(str, Enum):
         | 
| @@ -71,7 +83,26 @@ patients_collection = None | |
| 71 | 
             
            analysis_collection = None
         | 
| 72 | 
             
            alerts_collection = None
         | 
| 73 |  | 
| 74 | 
            -
            #  | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 75 | 
             
            def clean_text_response(text: str) -> str:
         | 
| 76 | 
             
                text = re.sub(r'\n\s*\n', '\n\n', text)
         | 
| 77 | 
             
                text = re.sub(r'[ ]+', ' ', text)
         | 
| @@ -87,7 +118,6 @@ def extract_section(text: str, heading: str) -> str: | |
| 87 | 
             
                    return ""
         | 
| 88 |  | 
| 89 | 
             
            def structure_medical_response(text: str) -> Dict:
         | 
| 90 | 
            -
                """Improved version that handles both markdown and plain text formats"""
         | 
| 91 | 
             
                def extract_improved(text: str, heading: str) -> str:
         | 
| 92 | 
             
                    patterns = [
         | 
| 93 | 
             
                        rf"{re.escape(heading)}:\s*\n(.*?)(?=\n\s*\n|\Z)",
         | 
| @@ -95,7 +125,6 @@ def structure_medical_response(text: str) -> Dict: | |
| 95 | 
             
                        rf"{re.escape(heading)}[\s\-]+(.*?)(?=\n\s*\n|\Z)",
         | 
| 96 | 
             
                        rf"\n{re.escape(heading)}\s*\n(.*?)(?=\n\s*\n|\Z)"
         | 
| 97 | 
             
                    ]
         | 
| 98 | 
            -
                    
         | 
| 99 | 
             
                    for pattern in patterns:
         | 
| 100 | 
             
                        match = re.search(pattern, text, re.DOTALL | re.IGNORECASE)
         | 
| 101 | 
             
                        if match:
         | 
| @@ -103,9 +132,8 @@ def structure_medical_response(text: str) -> Dict: | |
| 103 | 
             
                            content = re.sub(r'^\s*[\-\*]\s*', '', content, flags=re.MULTILINE)
         | 
| 104 | 
             
                            return content
         | 
| 105 | 
             
                    return ""
         | 
| 106 | 
            -
             | 
| 107 | 
            -
                text = text.replace('**', '').replace('__', '')
         | 
| 108 |  | 
|  | |
| 109 | 
             
                return {
         | 
| 110 | 
             
                    "summary": extract_improved(text, "Summary of Patient's Medical History") or 
         | 
| 111 | 
             
                              extract_improved(text, "Summarize the patient's medical history"),
         | 
| @@ -118,15 +146,12 @@ def structure_medical_response(text: str) -> Dict: | |
| 118 | 
             
                }
         | 
| 119 |  | 
| 120 | 
             
            def detect_suicide_risk(text: str) -> Tuple[RiskLevel, float, List[str]]:
         | 
| 121 | 
            -
                """Analyze text for suicide risk factors and return assessment"""
         | 
| 122 | 
             
                suicide_keywords = [
         | 
| 123 | 
             
                    'suicide', 'suicidal', 'kill myself', 'end my life', 
         | 
| 124 | 
             
                    'want to die', 'self-harm', 'self harm', 'hopeless',
         | 
| 125 | 
             
                    'no reason to live', 'plan to die'
         | 
| 126 | 
             
                ]
         | 
| 127 | 
            -
                
         | 
| 128 | 
             
                explicit_mentions = [kw for kw in suicide_keywords if kw in text.lower()]
         | 
| 129 | 
            -
                
         | 
| 130 | 
             
                if not explicit_mentions:
         | 
| 131 | 
             
                    return RiskLevel.NONE, 0.0, []
         | 
| 132 |  | 
| @@ -145,7 +170,6 @@ def detect_suicide_risk(text: str) -> Tuple[RiskLevel, float, List[str]]: | |
| 145 | 
             
                        temperature=0.2,
         | 
| 146 | 
             
                        max_new_tokens=256
         | 
| 147 | 
             
                    )
         | 
| 148 | 
            -
                    
         | 
| 149 | 
             
                    json_match = re.search(r'\{.*\}', response, re.DOTALL)
         | 
| 150 | 
             
                    if json_match:
         | 
| 151 | 
             
                        assessment = json.loads(json_match.group())
         | 
| @@ -165,7 +189,6 @@ def detect_suicide_risk(text: str) -> Tuple[RiskLevel, float, List[str]]: | |
| 165 | 
             
                return RiskLevel.LOW, risk_score, explicit_mentions
         | 
| 166 |  | 
| 167 | 
             
            async def create_alert(patient_id: str, risk_data: dict):
         | 
| 168 | 
            -
                """Create an alert document in the database"""
         | 
| 169 | 
             
                alert_doc = {
         | 
| 170 | 
             
                    "patient_id": patient_id,
         | 
| 171 | 
             
                    "type": "suicide_risk",
         | 
| @@ -185,16 +208,13 @@ def serialize_patient(patient: dict) -> dict: | |
| 185 | 
             
                return patient_copy
         | 
| 186 |  | 
| 187 | 
             
            def compute_patient_data_hash(data: dict) -> str:
         | 
| 188 | 
            -
                """Compute SHA-256 hash of patient data or report."""
         | 
| 189 | 
             
                serialized = json.dumps(data, sort_keys=True)
         | 
| 190 | 
             
                return hashlib.sha256(serialized.encode()).hexdigest()
         | 
| 191 |  | 
| 192 | 
             
            def compute_file_content_hash(file_content: bytes) -> str:
         | 
| 193 | 
            -
                """Compute SHA-256 hash of file content."""
         | 
| 194 | 
             
                return hashlib.sha256(file_content).hexdigest()
         | 
| 195 |  | 
| 196 | 
             
            def extract_text_from_pdf(pdf_data: bytes) -> str:
         | 
| 197 | 
            -
                """Extract text from a PDF file."""
         | 
| 198 | 
             
                try:
         | 
| 199 | 
             
                    pdf_reader = PyPDF2.PdfReader(io.BytesIO(pdf_data))
         | 
| 200 | 
             
                    text = ""
         | 
| @@ -206,85 +226,70 @@ def extract_text_from_pdf(pdf_data: bytes) -> str: | |
| 206 | 
             
                    raise HTTPException(status_code=400, detail="Failed to extract text from PDF")
         | 
| 207 |  | 
| 208 | 
             
            async def analyze_patient_report(patient_id: Optional[str], report_content: str, file_type: str, file_content: bytes):
         | 
| 209 | 
            -
                 | 
| 210 | 
            -
                 | 
| 211 | 
            -
             | 
| 212 | 
            -
             | 
| 213 | 
            -
             | 
| 214 | 
            -
             | 
| 215 | 
            -
             | 
| 216 | 
            -
             | 
| 217 | 
            -
                     | 
| 218 | 
            -
             | 
| 219 | 
            -
             | 
| 220 | 
            -
             | 
| 221 | 
            -
             | 
| 222 | 
            -
             | 
| 223 | 
            -
                     | 
| 224 | 
            -
                     | 
| 225 | 
            -
             | 
| 226 | 
            -
             | 
| 227 | 
            -
                        "2. Identify risks or red flags (including mental health and suicide risk).\n"
         | 
| 228 | 
            -
                        "3. Highlight missed diagnoses or treatments.\n"
         | 
| 229 | 
            -
                        "4. Suggest next clinical steps.\n"
         | 
| 230 | 
            -
                        f"\nPatient Report ({file_type}):\n{'-'*40}\n{report_content[:10000]}"
         | 
| 231 | 
            -
                    )
         | 
| 232 | 
            -
             | 
| 233 | 
            -
                    # Perform analysis
         | 
| 234 | 
            -
                    raw_response = agent.chat(
         | 
| 235 | 
            -
                        message=prompt,
         | 
| 236 | 
            -
                        history=[],
         | 
| 237 | 
            -
                        temperature=0.7,
         | 
| 238 | 
            -
                        max_new_tokens=1024
         | 
| 239 | 
            -
                    )
         | 
| 240 | 
            -
                    structured_response = structure_medical_response(raw_response)
         | 
| 241 |  | 
| 242 | 
            -
             | 
| 243 | 
            -
                     | 
| 244 | 
            -
                     | 
| 245 | 
            -
             | 
| 246 | 
            -
             | 
| 247 | 
            -
             | 
| 248 | 
            -
             | 
| 249 |  | 
| 250 | 
            -
             | 
| 251 | 
            -
             | 
| 252 | 
            -
             | 
| 253 | 
            -
             | 
| 254 | 
            -
             | 
| 255 | 
            -
             | 
| 256 | 
            -
                        "suicide_risk": suicide_risk,
         | 
| 257 | 
            -
                        "raw": raw_response,
         | 
| 258 | 
            -
                        "report_hash": report_hash,
         | 
| 259 | 
            -
                        "file_type": file_type
         | 
| 260 | 
            -
                    }
         | 
| 261 |  | 
| 262 | 
            -
             | 
| 263 | 
            -
             | 
| 264 | 
            -
             | 
| 265 | 
            -
             | 
| 266 | 
            -
                     | 
|  | |
|  | |
|  | |
|  | |
|  | |
| 267 |  | 
| 268 | 
            -
             | 
| 269 | 
            -
                     | 
| 270 | 
            -
             | 
|  | |
|  | |
| 271 |  | 
| 272 | 
            -
             | 
| 273 | 
            -
                     | 
| 274 |  | 
| 275 | 
            -
                 | 
| 276 | 
            -
             | 
| 277 | 
            -
                    raise HTTPException(status_code=500, detail="Failed to analyze patient report")
         | 
| 278 |  | 
| 279 | 
             
            async def analyze_all_patients():
         | 
| 280 | 
            -
                """Analyze all patients in the database."""
         | 
| 281 | 
             
                patients = await patients_collection.find({}).to_list(length=None)
         | 
| 282 | 
             
                for patient in patients:
         | 
| 283 | 
             
                    await analyze_patient(patient)
         | 
| 284 | 
             
                    await asyncio.sleep(0.1)
         | 
| 285 |  | 
| 286 | 
             
            async def analyze_patient(patient: dict):
         | 
| 287 | 
            -
                """Analyze patient data (existing logic for patient records)."""
         | 
| 288 | 
             
                try:
         | 
| 289 | 
             
                    serialized = serialize_patient(patient)
         | 
| 290 | 
             
                    patient_id = serialized.get("fhir_id")
         | 
| @@ -342,9 +347,7 @@ async def analyze_patient(patient: dict): | |
| 342 | 
             
                    logger.error(f"Error analyzing patient: {e}")
         | 
| 343 |  | 
| 344 | 
             
            def recognize_speech(audio_data: bytes, language: str = "en-US") -> str:
         | 
| 345 | 
            -
                """Convert speech to text using Google's speech recognition."""
         | 
| 346 | 
             
                recognizer = sr.Recognizer()
         | 
| 347 | 
            -
                
         | 
| 348 | 
             
                try:
         | 
| 349 | 
             
                    with io.BytesIO(audio_data) as audio_file:
         | 
| 350 | 
             
                        with sr.AudioFile(audio_file) as source:
         | 
| @@ -362,7 +365,6 @@ def recognize_speech(audio_data: bytes, language: str = "en-US") -> str: | |
| 362 | 
             
                    raise HTTPException(status_code=500, detail="Error processing speech")
         | 
| 363 |  | 
| 364 | 
             
            def text_to_speech(text: str, language: str = "en", slow: bool = False) -> bytes:
         | 
| 365 | 
            -
                """Convert text to speech using gTTS and return as MP3 bytes."""
         | 
| 366 | 
             
                try:
         | 
| 367 | 
             
                    tts = gTTS(text=text, lang=language, slow=slow)
         | 
| 368 | 
             
                    mp3_fp = io.BytesIO()
         | 
| @@ -394,6 +396,8 @@ async def startup_event(): | |
| 394 | 
             
                logger.info("✅ TxAgent initialized")
         | 
| 395 |  | 
| 396 | 
             
                db = get_mongo_client()["cps_db"]
         | 
|  | |
|  | |
| 397 | 
             
                patients_collection = db["patients"]
         | 
| 398 | 
             
                analysis_collection = db["patient_analysis_results"]
         | 
| 399 | 
             
                alerts_collection = db["clinical_alerts"]
         | 
| @@ -401,8 +405,10 @@ async def startup_event(): | |
| 401 |  | 
| 402 | 
             
                asyncio.create_task(analyze_all_patients())
         | 
| 403 |  | 
|  | |
| 404 | 
             
            @app.get("/status")
         | 
| 405 | 
            -
            async def status():
         | 
|  | |
| 406 | 
             
                return {
         | 
| 407 | 
             
                    "status": "running",
         | 
| 408 | 
             
                    "timestamp": datetime.utcnow().isoformat(),
         | 
| @@ -411,7 +417,11 @@ async def status(): | |
| 411 | 
             
                }
         | 
| 412 |  | 
| 413 | 
             
            @app.get("/patients/analysis-results")
         | 
| 414 | 
            -
            async def get_patient_analysis_results( | 
|  | |
|  | |
|  | |
|  | |
| 415 | 
             
                try:
         | 
| 416 | 
             
                    query = {}
         | 
| 417 | 
             
                    if name:
         | 
| @@ -438,7 +448,11 @@ async def get_patient_analysis_results(name: Optional[str] = Query(None)): | |
| 438 | 
             
                    raise HTTPException(status_code=500, detail="Failed to retrieve analysis results")
         | 
| 439 |  | 
| 440 | 
             
            @app.post("/chat-stream")
         | 
| 441 | 
            -
            async def chat_stream_endpoint( | 
|  | |
|  | |
|  | |
|  | |
| 442 | 
             
                async def token_stream():
         | 
| 443 | 
             
                    try:
         | 
| 444 | 
             
                        conversation = [{"role": "system", "content": agent.chat_prompt}]
         | 
| @@ -472,9 +486,10 @@ async def chat_stream_endpoint(request: ChatRequest): | |
| 472 | 
             
            @app.post("/voice/transcribe")
         | 
| 473 | 
             
            async def transcribe_voice(
         | 
| 474 | 
             
                audio: UploadFile = File(...),
         | 
| 475 | 
            -
                language: str = Query("en-US", description="Language code for speech recognition")
         | 
|  | |
| 476 | 
             
            ):
         | 
| 477 | 
            -
                " | 
| 478 | 
             
                try:
         | 
| 479 | 
             
                    audio_data = await audio.read()
         | 
| 480 | 
             
                    if not audio.filename.lower().endswith(('.wav', '.mp3', '.ogg', '.flac')):
         | 
| @@ -490,8 +505,11 @@ async def transcribe_voice( | |
| 490 | 
             
                    raise HTTPException(status_code=500, detail="Error processing voice input")
         | 
| 491 |  | 
| 492 | 
             
            @app.post("/voice/synthesize")
         | 
| 493 | 
            -
            async def synthesize_voice( | 
| 494 | 
            -
                 | 
|  | |
|  | |
|  | |
| 495 | 
             
                try:
         | 
| 496 | 
             
                    audio_data = text_to_speech(request.text, request.language, request.slow)
         | 
| 497 |  | 
| @@ -515,9 +533,10 @@ async def voice_chat_endpoint( | |
| 515 | 
             
                audio: UploadFile = File(...),
         | 
| 516 | 
             
                language: str = Query("en-US", description="Language code for speech recognition"),
         | 
| 517 | 
             
                temperature: float = Query(0.7, ge=0.1, le=1.0),
         | 
| 518 | 
            -
                max_new_tokens: int = Query(512, ge=50, le=1024)
         | 
|  | |
| 519 | 
             
            ):
         | 
| 520 | 
            -
                " | 
| 521 | 
             
                try:
         | 
| 522 | 
             
                    audio_data = await audio.read()
         | 
| 523 | 
             
                    user_message = recognize_speech(audio_data, language)
         | 
| @@ -548,18 +567,11 @@ async def analyze_clinical_report( | |
| 548 | 
             
                file: UploadFile = File(...),
         | 
| 549 | 
             
                patient_id: Optional[str] = Form(None),
         | 
| 550 | 
             
                temperature: float = Form(0.5),
         | 
| 551 | 
            -
                max_new_tokens: int = Form(1024)
         | 
|  | |
| 552 | 
             
            ):
         | 
| 553 | 
            -
                "" | 
| 554 | 
            -
                Analyze a clinical patient report from an uploaded file.
         | 
| 555 | 
            -
                Parameters:
         | 
| 556 | 
            -
                - file: Uploaded clinical report file (PDF, TXT, DOCX)
         | 
| 557 | 
            -
                - patient_id: Optional patient ID to associate with this report
         | 
| 558 | 
            -
                - temperature: Controls randomness of response (0.1-1.0)
         | 
| 559 | 
            -
                - max_new_tokens: Maximum length of response
         | 
| 560 | 
            -
                """
         | 
| 561 | 
             
                try:
         | 
| 562 | 
            -
                    # Validate file type
         | 
| 563 | 
             
                    content_type = file.content_type
         | 
| 564 | 
             
                    allowed_types = [
         | 
| 565 | 
             
                        'application/pdf',
         | 
| @@ -573,10 +585,8 @@ async def analyze_clinical_report( | |
| 573 | 
             
                            detail=f"Unsupported file type: {content_type}. Supported types: PDF, TXT, DOCX"
         | 
| 574 | 
             
                        )
         | 
| 575 |  | 
| 576 | 
            -
                    # Read file content
         | 
| 577 | 
             
                    file_content = await file.read()
         | 
| 578 |  | 
| 579 | 
            -
                    # Extract text from file
         | 
| 580 | 
             
                    if content_type == 'application/pdf':
         | 
| 581 | 
             
                        text = extract_text_from_pdf(file_content)
         | 
| 582 | 
             
                    elif content_type == 'text/plain':
         | 
| @@ -587,7 +597,6 @@ async def analyze_clinical_report( | |
| 587 | 
             
                    else:
         | 
| 588 | 
             
                        raise HTTPException(status_code=400, detail="Unsupported file type")
         | 
| 589 |  | 
| 590 | 
            -
                    # Clean and validate text
         | 
| 591 | 
             
                    text = clean_text_response(text)
         | 
| 592 | 
             
                    if len(text.strip()) < 50:
         | 
| 593 | 
             
                        raise HTTPException(
         | 
| @@ -595,7 +604,6 @@ async def analyze_clinical_report( | |
| 595 | 
             
                            detail="Extracted text is too short (minimum 50 characters required)"
         | 
| 596 | 
             
                        )
         | 
| 597 |  | 
| 598 | 
            -
                    # Analyze the report
         | 
| 599 | 
             
                    analysis = await analyze_patient_report(
         | 
| 600 | 
             
                        patient_id=patient_id,
         | 
| 601 | 
             
                        report_content=text,
         | 
| @@ -603,13 +611,11 @@ async def analyze_clinical_report( | |
| 603 | 
             
                        file_content=file_content
         | 
| 604 | 
             
                    )
         | 
| 605 |  | 
| 606 | 
            -
                    # Manually convert ObjectId and timestamp if needed
         | 
| 607 | 
             
                    if "_id" in analysis and isinstance(analysis["_id"], ObjectId):
         | 
| 608 | 
             
                        analysis["_id"] = str(analysis["_id"])
         | 
| 609 | 
             
                    if "timestamp" in analysis and isinstance(analysis["timestamp"], datetime):
         | 
| 610 | 
             
                        analysis["timestamp"] = analysis["timestamp"].isoformat()
         | 
| 611 |  | 
| 612 | 
            -
                    # Return response using jsonable_encoder
         | 
| 613 | 
             
                    return JSONResponse(content=jsonable_encoder({
         | 
| 614 | 
             
                        "status": "success",
         | 
| 615 | 
             
                        "analysis": analysis,
         | 
| @@ -627,7 +633,6 @@ async def analyze_clinical_report( | |
| 627 | 
             
                        detail=f"Failed to analyze report: {str(e)}"
         | 
| 628 | 
             
                    )
         | 
| 629 |  | 
| 630 | 
            -
                    
         | 
| 631 | 
             
            if __name__ == "__main__":
         | 
| 632 | 
             
                import uvicorn
         | 
| 633 | 
             
                uvicorn.run(app, host="0.0.0.0", port=8000)
         | 
|  | |
| 1 | 
            +
            # app.py (in TxAgent-API)
         | 
| 2 | 
             
            import os
         | 
| 3 | 
             
            import sys
         | 
| 4 | 
             
            import json
         | 
|  | |
| 10 | 
             
            from datetime import datetime
         | 
| 11 | 
             
            from typing import List, Dict, Optional, Tuple
         | 
| 12 | 
             
            from enum import Enum
         | 
| 13 | 
            +
            from fastapi import FastAPI, HTTPException, UploadFile, File, Query, Form, Depends
         | 
| 14 | 
             
            from fastapi.responses import StreamingResponse, JSONResponse
         | 
| 15 | 
             
            from fastapi.middleware.cors import CORSMiddleware
         | 
| 16 | 
            +
            from fastapi.security import OAuth2PasswordBearer
         | 
| 17 | 
            +
            from fastapi.encoders import jsonable_encoder
         | 
| 18 | 
             
            from pydantic import BaseModel
         | 
| 19 | 
             
            import asyncio
         | 
| 20 | 
             
            from bson import ObjectId
         | 
|  | |
| 23 | 
             
            from pydub import AudioSegment
         | 
| 24 | 
             
            import PyPDF2
         | 
| 25 | 
             
            import mimetypes
         | 
| 26 | 
            +
            from docx import Document
         | 
| 27 | 
            +
            from jose import JWTError, jwt
         | 
| 28 | 
             
            from txagent.txagent import TxAgent
         | 
| 29 | 
             
            from db.mongo import get_mongo_client
         | 
| 30 | 
            +
             | 
|  | |
| 31 | 
             
            # Logging
         | 
| 32 | 
             
            logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
         | 
| 33 | 
             
            logger = logging.getLogger("TxAgentAPI")
         | 
| 34 |  | 
| 35 | 
             
            # App
         | 
| 36 | 
            +
            app = FastAPI(title="TxAgent API", version="2.6.0")
         | 
| 37 |  | 
| 38 | 
            +
            # CORS
         | 
| 39 | 
             
            app.add_middleware(
         | 
| 40 | 
             
                CORSMiddleware,
         | 
| 41 | 
             
                allow_origins=["*"],
         | 
|  | |
| 44 | 
             
                allow_headers=["*"]
         | 
| 45 | 
             
            )
         | 
| 46 |  | 
| 47 | 
            +
            # JWT settings (must match CPS-API)
         | 
| 48 | 
            +
            SECRET_KEY = os.getenv("SECRET_KEY", "your-secret-key")  # Same as CPS-API
         | 
| 49 | 
            +
            ALGORITHM = "HS256"
         | 
| 50 | 
            +
             | 
| 51 | 
            +
            # OAuth2 scheme (point to CPS-API's login endpoint)
         | 
| 52 | 
            +
            oauth2_scheme = OAuth2PasswordBearer(tokenUrl="https://rocketfarmstudios-cps-api.hf.space/auth/login")
         | 
| 53 | 
            +
             | 
| 54 | 
             
            # Pydantic Models
         | 
| 55 | 
             
            class ChatRequest(BaseModel):
         | 
| 56 | 
             
                message: str
         | 
|  | |
| 67 | 
             
                text: str
         | 
| 68 | 
             
                language: str = "en"
         | 
| 69 | 
             
                slow: bool = False
         | 
| 70 | 
            +
                return_format: str = "mp3"
         | 
| 71 |  | 
| 72 | 
             
            # Enums
         | 
| 73 | 
             
            class RiskLevel(str, Enum):
         | 
|  | |
| 83 | 
             
            analysis_collection = None
         | 
| 84 | 
             
            alerts_collection = None
         | 
| 85 |  | 
| 86 | 
            +
            # JWT validation
         | 
| 87 | 
            +
            async def get_current_user(token: str = Depends(oauth2_scheme)):
         | 
| 88 | 
            +
                credentials_exception = HTTPException(
         | 
| 89 | 
            +
                    status_code=401,
         | 
| 90 | 
            +
                    detail="Could not validate credentials",
         | 
| 91 | 
            +
                    headers={"WWW-Authenticate": "Bearer"},
         | 
| 92 | 
            +
                )
         | 
| 93 | 
            +
                try:
         | 
| 94 | 
            +
                    payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
         | 
| 95 | 
            +
                    email: str = payload.get("sub")
         | 
| 96 | 
            +
                    if email is None:
         | 
| 97 | 
            +
                        raise credentials_exception
         | 
| 98 | 
            +
                except JWTError:
         | 
| 99 | 
            +
                    raise credentials_exception
         | 
| 100 | 
            +
                user = await users_collection.find_one({"email": email})
         | 
| 101 | 
            +
                if user is None:
         | 
| 102 | 
            +
                    raise credentials_exception
         | 
| 103 | 
            +
                return user
         | 
| 104 | 
            +
             | 
| 105 | 
            +
            # Helper functions (unchanged from your original code)
         | 
| 106 | 
             
            def clean_text_response(text: str) -> str:
         | 
| 107 | 
             
                text = re.sub(r'\n\s*\n', '\n\n', text)
         | 
| 108 | 
             
                text = re.sub(r'[ ]+', ' ', text)
         | 
|  | |
| 118 | 
             
                    return ""
         | 
| 119 |  | 
| 120 | 
             
            def structure_medical_response(text: str) -> Dict:
         | 
|  | |
| 121 | 
             
                def extract_improved(text: str, heading: str) -> str:
         | 
| 122 | 
             
                    patterns = [
         | 
| 123 | 
             
                        rf"{re.escape(heading)}:\s*\n(.*?)(?=\n\s*\n|\Z)",
         | 
|  | |
| 125 | 
             
                        rf"{re.escape(heading)}[\s\-]+(.*?)(?=\n\s*\n|\Z)",
         | 
| 126 | 
             
                        rf"\n{re.escape(heading)}\s*\n(.*?)(?=\n\s*\n|\Z)"
         | 
| 127 | 
             
                    ]
         | 
|  | |
| 128 | 
             
                    for pattern in patterns:
         | 
| 129 | 
             
                        match = re.search(pattern, text, re.DOTALL | re.IGNORECASE)
         | 
| 130 | 
             
                        if match:
         | 
|  | |
| 132 | 
             
                            content = re.sub(r'^\s*[\-\*]\s*', '', content, flags=re.MULTILINE)
         | 
| 133 | 
             
                            return content
         | 
| 134 | 
             
                    return ""
         | 
|  | |
|  | |
| 135 |  | 
| 136 | 
            +
                text = text.replace('**', '').replace('__', '')
         | 
| 137 | 
             
                return {
         | 
| 138 | 
             
                    "summary": extract_improved(text, "Summary of Patient's Medical History") or 
         | 
| 139 | 
             
                              extract_improved(text, "Summarize the patient's medical history"),
         | 
|  | |
| 146 | 
             
                }
         | 
| 147 |  | 
| 148 | 
             
            def detect_suicide_risk(text: str) -> Tuple[RiskLevel, float, List[str]]:
         | 
|  | |
| 149 | 
             
                suicide_keywords = [
         | 
| 150 | 
             
                    'suicide', 'suicidal', 'kill myself', 'end my life', 
         | 
| 151 | 
             
                    'want to die', 'self-harm', 'self harm', 'hopeless',
         | 
| 152 | 
             
                    'no reason to live', 'plan to die'
         | 
| 153 | 
             
                ]
         | 
|  | |
| 154 | 
             
                explicit_mentions = [kw for kw in suicide_keywords if kw in text.lower()]
         | 
|  | |
| 155 | 
             
                if not explicit_mentions:
         | 
| 156 | 
             
                    return RiskLevel.NONE, 0.0, []
         | 
| 157 |  | 
|  | |
| 170 | 
             
                        temperature=0.2,
         | 
| 171 | 
             
                        max_new_tokens=256
         | 
| 172 | 
             
                    )
         | 
|  | |
| 173 | 
             
                    json_match = re.search(r'\{.*\}', response, re.DOTALL)
         | 
| 174 | 
             
                    if json_match:
         | 
| 175 | 
             
                        assessment = json.loads(json_match.group())
         | 
|  | |
| 189 | 
             
                return RiskLevel.LOW, risk_score, explicit_mentions
         | 
| 190 |  | 
| 191 | 
             
            async def create_alert(patient_id: str, risk_data: dict):
         | 
|  | |
| 192 | 
             
                alert_doc = {
         | 
| 193 | 
             
                    "patient_id": patient_id,
         | 
| 194 | 
             
                    "type": "suicide_risk",
         | 
|  | |
| 208 | 
             
                return patient_copy
         | 
| 209 |  | 
| 210 | 
             
            def compute_patient_data_hash(data: dict) -> str:
         | 
|  | |
| 211 | 
             
                serialized = json.dumps(data, sort_keys=True)
         | 
| 212 | 
             
                return hashlib.sha256(serialized.encode()).hexdigest()
         | 
| 213 |  | 
| 214 | 
             
            def compute_file_content_hash(file_content: bytes) -> str:
         | 
|  | |
| 215 | 
             
                return hashlib.sha256(file_content).hexdigest()
         | 
| 216 |  | 
| 217 | 
             
            def extract_text_from_pdf(pdf_data: bytes) -> str:
         | 
|  | |
| 218 | 
             
                try:
         | 
| 219 | 
             
                    pdf_reader = PyPDF2.PdfReader(io.BytesIO(pdf_data))
         | 
| 220 | 
             
                    text = ""
         | 
|  | |
| 226 | 
             
                    raise HTTPException(status_code=400, detail="Failed to extract text from PDF")
         | 
| 227 |  | 
| 228 | 
             
            async def analyze_patient_report(patient_id: Optional[str], report_content: str, file_type: str, file_content: bytes):
         | 
| 229 | 
            +
                identifier = patient_id if patient_id else compute_file_content_hash(file_content)
         | 
| 230 | 
            +
                report_data = {"identifier": identifier, "content": report_content, "file_type": file_type}
         | 
| 231 | 
            +
                report_hash = compute_patient_data_hash(report_data)
         | 
| 232 | 
            +
                logger.info(f"🧾 Analyzing report for identifier: {identifier}")
         | 
| 233 | 
            +
             | 
| 234 | 
            +
                existing_analysis = await analysis_collection.find_one({"identifier": identifier, "report_hash": report_hash})
         | 
| 235 | 
            +
                if existing_analysis:
         | 
| 236 | 
            +
                    logger.info(f"✅ No changes in report data for {identifier}, skipping analysis")
         | 
| 237 | 
            +
                    return existing_analysis
         | 
| 238 | 
            +
             | 
| 239 | 
            +
                prompt = (
         | 
| 240 | 
            +
                    "You are a clinical decision support AI. Analyze the following patient report:\n"
         | 
| 241 | 
            +
                    "1. Summarize the patient's medical history.\n"
         | 
| 242 | 
            +
                    "2. Identify risks or red flags (including mental health and suicide risk).\n"
         | 
| 243 | 
            +
                    "3. Highlight missed diagnoses or treatments.\n"
         | 
| 244 | 
            +
                    "4. Suggest next clinical steps.\n"
         | 
| 245 | 
            +
                    f"\nPatient Report ({file_type}):\n{'-'*40}\n{report_content[:10000]}"
         | 
| 246 | 
            +
                )
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 247 |  | 
| 248 | 
            +
                raw_response = agent.chat(
         | 
| 249 | 
            +
                    message=prompt,
         | 
| 250 | 
            +
                    history=[],
         | 
| 251 | 
            +
                    temperature=0.7,
         | 
| 252 | 
            +
                    max_new_tokens=1024
         | 
| 253 | 
            +
                )
         | 
| 254 | 
            +
                structured_response = structure_medical_response(raw_response)
         | 
| 255 |  | 
| 256 | 
            +
                risk_level, risk_score, risk_factors = detect_suicide_risk(raw_response)
         | 
| 257 | 
            +
                suicide_risk = {
         | 
| 258 | 
            +
                    "level": risk_level.value,
         | 
| 259 | 
            +
                    "score": risk_score,
         | 
| 260 | 
            +
                    "factors": risk_factors
         | 
| 261 | 
            +
                }
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
| 262 |  | 
| 263 | 
            +
                analysis_doc = {
         | 
| 264 | 
            +
                    "identifier": identifier,
         | 
| 265 | 
            +
                    "patient_id": patient_id,
         | 
| 266 | 
            +
                    "timestamp": datetime.utcnow(),
         | 
| 267 | 
            +
                    "summary": structured_response,
         | 
| 268 | 
            +
                    "suicide_risk": suicide_risk,
         | 
| 269 | 
            +
                    "raw": raw_response,
         | 
| 270 | 
            +
                    "report_hash": report_hash,
         | 
| 271 | 
            +
                    "file_type": file_type
         | 
| 272 | 
            +
                }
         | 
| 273 |  | 
| 274 | 
            +
                await analysis_collection.update_one(
         | 
| 275 | 
            +
                    {"identifier": identifier, "report_hash": report_hash},
         | 
| 276 | 
            +
                    {"$set": analysis_doc},
         | 
| 277 | 
            +
                    upsert=True
         | 
| 278 | 
            +
                )
         | 
| 279 |  | 
| 280 | 
            +
                if patient_id and risk_level in [RiskLevel.MODERATE, RiskLevel.HIGH, RiskLevel.SEVERE]:
         | 
| 281 | 
            +
                    await create_alert(patient_id, suicide_risk)
         | 
| 282 |  | 
| 283 | 
            +
                logger.info(f"✅ Stored analysis for identifier {identifier}")
         | 
| 284 | 
            +
                return analysis_doc
         | 
|  | |
| 285 |  | 
| 286 | 
             
            async def analyze_all_patients():
         | 
|  | |
| 287 | 
             
                patients = await patients_collection.find({}).to_list(length=None)
         | 
| 288 | 
             
                for patient in patients:
         | 
| 289 | 
             
                    await analyze_patient(patient)
         | 
| 290 | 
             
                    await asyncio.sleep(0.1)
         | 
| 291 |  | 
| 292 | 
             
            async def analyze_patient(patient: dict):
         | 
|  | |
| 293 | 
             
                try:
         | 
| 294 | 
             
                    serialized = serialize_patient(patient)
         | 
| 295 | 
             
                    patient_id = serialized.get("fhir_id")
         | 
|  | |
| 347 | 
             
                    logger.error(f"Error analyzing patient: {e}")
         | 
| 348 |  | 
| 349 | 
             
            def recognize_speech(audio_data: bytes, language: str = "en-US") -> str:
         | 
|  | |
| 350 | 
             
                recognizer = sr.Recognizer()
         | 
|  | |
| 351 | 
             
                try:
         | 
| 352 | 
             
                    with io.BytesIO(audio_data) as audio_file:
         | 
| 353 | 
             
                        with sr.AudioFile(audio_file) as source:
         | 
|  | |
| 365 | 
             
                    raise HTTPException(status_code=500, detail="Error processing speech")
         | 
| 366 |  | 
| 367 | 
             
            def text_to_speech(text: str, language: str = "en", slow: bool = False) -> bytes:
         | 
|  | |
| 368 | 
             
                try:
         | 
| 369 | 
             
                    tts = gTTS(text=text, lang=language, slow=slow)
         | 
| 370 | 
             
                    mp3_fp = io.BytesIO()
         | 
|  | |
| 396 | 
             
                logger.info("✅ TxAgent initialized")
         | 
| 397 |  | 
| 398 | 
             
                db = get_mongo_client()["cps_db"]
         | 
| 399 | 
            +
                global users_collection  # Add this to access users_collection for authentication
         | 
| 400 | 
            +
                users_collection = db["users"]
         | 
| 401 | 
             
                patients_collection = db["patients"]
         | 
| 402 | 
             
                analysis_collection = db["patient_analysis_results"]
         | 
| 403 | 
             
                alerts_collection = db["clinical_alerts"]
         | 
|  | |
| 405 |  | 
| 406 | 
             
                asyncio.create_task(analyze_all_patients())
         | 
| 407 |  | 
| 408 | 
            +
            # Protected Endpoints (add Depends(get_current_user) to all endpoints)
         | 
| 409 | 
             
            @app.get("/status")
         | 
| 410 | 
            +
            async def status(current_user: dict = Depends(get_current_user)):
         | 
| 411 | 
            +
                logger.info(f"Status endpoint accessed by {current_user['email']}")
         | 
| 412 | 
             
                return {
         | 
| 413 | 
             
                    "status": "running",
         | 
| 414 | 
             
                    "timestamp": datetime.utcnow().isoformat(),
         | 
|  | |
| 417 | 
             
                }
         | 
| 418 |  | 
| 419 | 
             
            @app.get("/patients/analysis-results")
         | 
| 420 | 
            +
            async def get_patient_analysis_results(
         | 
| 421 | 
            +
                name: Optional[str] = Query(None),
         | 
| 422 | 
            +
                current_user: dict = Depends(get_current_user)
         | 
| 423 | 
            +
            ):
         | 
| 424 | 
            +
                logger.info(f"Fetching analysis results by {current_user['email']}")
         | 
| 425 | 
             
                try:
         | 
| 426 | 
             
                    query = {}
         | 
| 427 | 
             
                    if name:
         | 
|  | |
| 448 | 
             
                    raise HTTPException(status_code=500, detail="Failed to retrieve analysis results")
         | 
| 449 |  | 
| 450 | 
             
            @app.post("/chat-stream")
         | 
| 451 | 
            +
            async def chat_stream_endpoint(
         | 
| 452 | 
            +
                request: ChatRequest,
         | 
| 453 | 
            +
                current_user: dict = Depends(get_current_user)
         | 
| 454 | 
            +
            ):
         | 
| 455 | 
            +
                logger.info(f"Chat stream initiated by {current_user['email']}")
         | 
| 456 | 
             
                async def token_stream():
         | 
| 457 | 
             
                    try:
         | 
| 458 | 
             
                        conversation = [{"role": "system", "content": agent.chat_prompt}]
         | 
|  | |
| 486 | 
             
            @app.post("/voice/transcribe")
         | 
| 487 | 
             
            async def transcribe_voice(
         | 
| 488 | 
             
                audio: UploadFile = File(...),
         | 
| 489 | 
            +
                language: str = Query("en-US", description="Language code for speech recognition"),
         | 
| 490 | 
            +
                current_user: dict = Depends(get_current_user)
         | 
| 491 | 
             
            ):
         | 
| 492 | 
            +
                logger.info(f"Voice transcription initiated by {current_user['email']}")
         | 
| 493 | 
             
                try:
         | 
| 494 | 
             
                    audio_data = await audio.read()
         | 
| 495 | 
             
                    if not audio.filename.lower().endswith(('.wav', '.mp3', '.ogg', '.flac')):
         | 
|  | |
| 505 | 
             
                    raise HTTPException(status_code=500, detail="Error processing voice input")
         | 
| 506 |  | 
| 507 | 
             
            @app.post("/voice/synthesize")
         | 
| 508 | 
            +
            async def synthesize_voice(
         | 
| 509 | 
            +
                request: VoiceOutputRequest,
         | 
| 510 | 
            +
                current_user: dict = Depends(get_current_user)
         | 
| 511 | 
            +
            ):
         | 
| 512 | 
            +
                logger.info(f"Voice synthesis initiated by {current_user['email']}")
         | 
| 513 | 
             
                try:
         | 
| 514 | 
             
                    audio_data = text_to_speech(request.text, request.language, request.slow)
         | 
| 515 |  | 
|  | |
| 533 | 
             
                audio: UploadFile = File(...),
         | 
| 534 | 
             
                language: str = Query("en-US", description="Language code for speech recognition"),
         | 
| 535 | 
             
                temperature: float = Query(0.7, ge=0.1, le=1.0),
         | 
| 536 | 
            +
                max_new_tokens: int = Query(512, ge=50, le=1024),
         | 
| 537 | 
            +
                current_user: dict = Depends(get_current_user)
         | 
| 538 | 
             
            ):
         | 
| 539 | 
            +
                logger.info(f"Voice chat initiated by {current_user['email']}")
         | 
| 540 | 
             
                try:
         | 
| 541 | 
             
                    audio_data = await audio.read()
         | 
| 542 | 
             
                    user_message = recognize_speech(audio_data, language)
         | 
|  | |
| 567 | 
             
                file: UploadFile = File(...),
         | 
| 568 | 
             
                patient_id: Optional[str] = Form(None),
         | 
| 569 | 
             
                temperature: float = Form(0.5),
         | 
| 570 | 
            +
                max_new_tokens: int = Form(1024),
         | 
| 571 | 
            +
                current_user: dict = Depends(get_current_user)
         | 
| 572 | 
             
            ):
         | 
| 573 | 
            +
                logger.info(f"Report analysis initiated by {current_user['email']}")
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 574 | 
             
                try:
         | 
|  | |
| 575 | 
             
                    content_type = file.content_type
         | 
| 576 | 
             
                    allowed_types = [
         | 
| 577 | 
             
                        'application/pdf',
         | 
|  | |
| 585 | 
             
                            detail=f"Unsupported file type: {content_type}. Supported types: PDF, TXT, DOCX"
         | 
| 586 | 
             
                        )
         | 
| 587 |  | 
|  | |
| 588 | 
             
                    file_content = await file.read()
         | 
| 589 |  | 
|  | |
| 590 | 
             
                    if content_type == 'application/pdf':
         | 
| 591 | 
             
                        text = extract_text_from_pdf(file_content)
         | 
| 592 | 
             
                    elif content_type == 'text/plain':
         | 
|  | |
| 597 | 
             
                    else:
         | 
| 598 | 
             
                        raise HTTPException(status_code=400, detail="Unsupported file type")
         | 
| 599 |  | 
|  | |
| 600 | 
             
                    text = clean_text_response(text)
         | 
| 601 | 
             
                    if len(text.strip()) < 50:
         | 
| 602 | 
             
                        raise HTTPException(
         | 
|  | |
| 604 | 
             
                            detail="Extracted text is too short (minimum 50 characters required)"
         | 
| 605 | 
             
                        )
         | 
| 606 |  | 
|  | |
| 607 | 
             
                    analysis = await analyze_patient_report(
         | 
| 608 | 
             
                        patient_id=patient_id,
         | 
| 609 | 
             
                        report_content=text,
         | 
|  | |
| 611 | 
             
                        file_content=file_content
         | 
| 612 | 
             
                    )
         | 
| 613 |  | 
|  | |
| 614 | 
             
                    if "_id" in analysis and isinstance(analysis["_id"], ObjectId):
         | 
| 615 | 
             
                        analysis["_id"] = str(analysis["_id"])
         | 
| 616 | 
             
                    if "timestamp" in analysis and isinstance(analysis["timestamp"], datetime):
         | 
| 617 | 
             
                        analysis["timestamp"] = analysis["timestamp"].isoformat()
         | 
| 618 |  | 
|  | |
| 619 | 
             
                    return JSONResponse(content=jsonable_encoder({
         | 
| 620 | 
             
                        "status": "success",
         | 
| 621 | 
             
                        "analysis": analysis,
         | 
|  | |
| 633 | 
             
                        detail=f"Failed to analyze report: {str(e)}"
         | 
| 634 | 
             
                    )
         | 
| 635 |  | 
|  | |
| 636 | 
             
            if __name__ == "__main__":
         | 
| 637 | 
             
                import uvicorn
         | 
| 638 | 
             
                uvicorn.run(app, host="0.0.0.0", port=8000)
         | 
