# TxAgent-Api / utils.py
# Repository page header (not code): Ali2206 - "Update utils.py" - commit 568ac0d (verified)
import re
import hashlib
import io
import json
from datetime import datetime
from typing import Dict, List, Tuple
from bson import ObjectId
import logging
from config import logger
# WebSocket notification support (connection registry and notification queue)
from fastapi import WebSocket
import asyncio
class NotificationManager:
    """Registry of live WebSocket connections plus a queue of pending notifications.

    Connections are keyed by user id, so calling ``connect`` again with the
    same ``user_id`` replaces the socket stored for that user.
    """

    def __init__(self):
        # user_id -> WebSocket registered through connect().
        self.active_connections = {}
        # Filled by the module-level broadcast_notification() coroutine.
        # NOTE(review): no consumer of this queue is visible here - confirm
        # that something drains it.
        self.notification_queue = asyncio.Queue()

    async def connect(self, websocket: WebSocket, user_id: str) -> None:
        """Accept the WebSocket handshake and register the socket under *user_id*."""
        await websocket.accept()
        self.active_connections[user_id] = websocket

    def disconnect(self, user_id: str) -> None:
        """Forget the connection stored for *user_id*; unknown ids are ignored."""
        self.active_connections.pop(user_id, None)

    async def broadcast_notification(self, notification: dict) -> None:
        """Broadcast to all connected clients.

        Each client receives ``{"type": "notification", "data": notification}``.
        A failure on one socket is logged and does not stop delivery to the
        remaining sockets.
        """
        # Iterate over a snapshot: every ``await`` below yields to the event
        # loop, and a connect()/disconnect() running in that gap would
        # otherwise resize the dict mid-iteration and raise RuntimeError.
        for connection in list(self.active_connections.values()):
            try:
                await connection.send_json({
                    "type": "notification",
                    "data": notification
                })
            except Exception as e:
                # Best-effort delivery: keep going for the other clients.
                logger.error(f"Error sending notification: {e}")
# Module-level instance created at import time; the module-level
# broadcast_notification() coroutine puts its entries on this instance's queue.
notification_manager = NotificationManager()
async def broadcast_notification(notification: dict):
    """Queue *notification* together with the recipient labels chosen for it.

    A notification whose ``"priority"`` is ``"high"`` is addressed to the
    psychiatrist, emergency team and primary care; any other priority goes to
    primary care and the case manager. Nothing is sent from here: the entry is
    only placed on ``notification_manager.notification_queue``.

    Raises:
        KeyError: if *notification* has no ``"priority"`` key.
    """
    targets = (
        ["psychiatrist", "emergency_team", "primary_care"]
        if notification["priority"] == "high"
        else ["primary_care", "case_manager"]
    )
    queued_entry = {
        "recipients": targets,
        "notification": notification,
    }
    await notification_manager.notification_queue.put(queued_entry)
def clean_text_response(text: str) -> str:
    """Normalise whitespace in *text* and drop ``**`` / ``__`` emphasis markers.

    Runs of blank (or whitespace-only) lines collapse to a single blank line,
    runs of spaces collapse to one space, then the markers are removed and the
    result is stripped of leading/trailing whitespace.
    """
    squeezed = re.sub(r'\n\s*\n', '\n\n', text)
    squeezed = re.sub(r'[ ]+', ' ', squeezed)
    # Markers are removed after space collapsing, so "a ** b" keeps both spaces.
    for marker in ("**", "__"):
        squeezed = squeezed.replace(marker, "")
    return squeezed.strip()
def extract_section(text: str, heading: str) -> str:
    """Return the text that follows ``"<heading>:"`` up to the next heading line.

    The heading itself is matched case-insensitively and must be followed by a
    colon and a line break. The section ends at the next line that starts with
    an uppercase ASCII letter and contains a colon, or at the end of *text*.

    Args:
        text: Document to search.
        heading: Heading label without the trailing colon.

    Returns:
        The stripped section body, or ``""`` when the heading is not found or
        the search fails (failures are logged, not raised).
    """
    try:
        # Case-insensitivity is scoped to the heading via (?i:...). Applying
        # re.IGNORECASE to the whole pattern made [A-Z] match lowercase too,
        # so any body line like "note: ..." ended the section early.
        pattern = rf"(?i:{re.escape(heading)}):\s*\n(.*?)(?=\n[A-Z][^\n]*:|\Z)"
        match = re.search(pattern, text, re.DOTALL)
        return match.group(1).strip() if match else ""
    except Exception as e:
        logger.error(f"Section extraction failed for heading '{heading}': {e}")
        return ""
def structure_medical_response(text: str) -> Dict:
    """Split a free-text medical analysis into four named sections.

    ``**`` and ``__`` markers are removed first. Each result key is then looked
    up under a primary heading and, if that yields nothing, under an
    alternative heading. Headings are matched case-insensitively; a section
    body runs to the next blank line or the end of the text, and leading
    ``-`` / ``*`` bullets are stripped from its lines.

    Returns:
        Dict with keys ``summary``, ``risks``, ``missed_issues`` and
        ``recommendations``; a value is ``""`` when neither heading matched.
    """
    plain = text.replace('**', '').replace('__', '')
    search_flags = re.DOTALL | re.IGNORECASE

    def find_block(heading: str) -> str:
        """Return the de-bulleted body under *heading* in the cleaned text, or ""."""
        layouts = (
            rf"{re.escape(heading)}:\s*\n(.*?)(?=\n\s*\n|\Z)",
            rf"\*\*{re.escape(heading)}\*\*:\s*\n(.*?)(?=\n\s*\n|\Z)",
            rf"{re.escape(heading)}[\s\-]+(.*?)(?=\n\s*\n|\Z)",
            rf"\n{re.escape(heading)}\s*\n(.*?)(?=\n\s*\n|\Z)",
        )
        for layout in layouts:
            hit = re.search(layout, plain, search_flags)
            if hit is None:
                continue
            # The first layout that matches wins, even if its body is empty.
            body = hit.group(1).strip()
            return re.sub(r'^\s*[\-\*]\s*', '', body, flags=re.MULTILINE)
        return ""

    # result key -> (primary heading, alternative heading)
    heading_table = {
        "summary": (
            "Summary of Patient's Medical History",
            "Summarize the patient's medical history",
        ),
        "risks": (
            "Identify Risks or Red Flags",
            "Risks or Red Flags",
        ),
        "missed_issues": (
            "Missed Diagnoses or Treatments",
            "What the doctor might have missed",
        ),
        "recommendations": (
            "Suggest Next Clinical Steps",
            "Suggested Clinical Actions",
        ),
    }
    return {
        key: find_block(primary) or find_block(alternative)
        for key, (primary, alternative) in heading_table.items()
    }
def serialize_patient(patient: dict) -> dict:
    """Return a shallow copy of *patient* whose ``"_id"`` value, if any, is a string.

    The input mapping is left untouched; nested values are shared with the
    copy, not duplicated.
    """
    serialized = patient.copy()
    if "_id" not in serialized:
        return serialized
    serialized["_id"] = str(serialized["_id"])
    return serialized
def _json_hash_fallback(value):
    """Render a value ``json`` cannot encode natively as a string for hashing."""
    to_iso = getattr(value, "isoformat", None)
    if callable(to_iso):
        # datetime / date / time objects: ISO 8601 text.
        return to_iso()
    # Anything else falls back to str().
    # NOTE(review): a type whose str() embeds a memory address would hash
    # differently between runs - confirm no such values reach this function.
    return str(value)


def compute_patient_data_hash(data: dict) -> str:
    """Return the SHA-256 hex digest of *data* serialised as key-sorted JSON.

    Keys are sorted so that two dicts with the same content hash identically
    regardless of insertion order. Values JSON cannot encode natively (for
    example ``datetime``) are converted by ``_json_hash_fallback`` instead of
    raising ``TypeError``; digests for plain JSON-compatible input are the
    same as without the fallback.
    """
    serialized = json.dumps(data, sort_keys=True, default=_json_hash_fallback)
    return hashlib.sha256(serialized.encode()).hexdigest()
def compute_file_content_hash(file_content: bytes) -> str:
    """Return the SHA-256 digest of *file_content* as a lowercase hex string."""
    digest = hashlib.sha256()
    digest.update(file_content)
    return digest.hexdigest()