from fastapi import APIRouter, Depends, HTTPException, UploadFile, File, Query, Form
from fastapi.responses import StreamingResponse, JSONResponse
from fastapi.encoders import jsonable_encoder
from typing import Optional, List
from pydantic import BaseModel
from auth import get_current_user
from utils import clean_text_response
from analysis import analyze_patient_report
from voice import recognize_speech, text_to_speech, extract_text_from_pdf
from docx import Document
import re
import io
import base64  # needed for the "base64" return format in /voice/synthesize
from datetime import datetime
from bson import ObjectId
import asyncio
from bson.errors import InvalidId


class ChatRequest(BaseModel):
    message: str
    history: Optional[List[dict]] = None
    format: Optional[str] = "clean"
    temperature: Optional[float] = 0.7
    max_new_tokens: Optional[int] = 512
    patient_id: Optional[str] = None


class VoiceOutputRequest(BaseModel):
    text: str
    language: str = "en-US"
    slow: bool = False
    return_format: str = "mp3"


class RiskLevel(BaseModel):
    level: str
    score: float
    factors: Optional[List[str]] = None


def create_router(agent, logger, patients_collection, analysis_collection, users_collection):
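    """Build and return the API router for chat, voice, and report-analysis endpoints.

    The collections are awaited throughout, so Motor-style async MongoDB
    collections are expected. A minimal wiring sketch (the app name and prefix
    are illustrative, not defined in this module):

        app.include_router(
            create_router(agent, logger, patients, analyses, users),
            prefix="/api",
        )
    """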
    router = APIRouter()

    @router.get("/status")
    async def status(current_user: dict = Depends(get_current_user)):
        logger.info(f"Status endpoint accessed by {current_user['email']}")
        return {
            "status": "running",
            "timestamp": datetime.utcnow().isoformat(),
            "version": "2.6.0",
            "features": ["chat", "voice-input", "voice-output", "patient-analysis", "report-upload"]
        }

    @router.get("/patients/analysis-results")
    async def get_patient_analysis_results(
        name: Optional[str] = Query(None),
        current_user: dict = Depends(get_current_user)
    ):
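        """Return stored analyses, optionally filtered by a case-insensitive
        patient-name match, each enriched with the patient's full name."""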
        logger.info(f"Fetching analysis results by {current_user['email']}")
        try:
            query = {}
            if name:
                # Escape the user-supplied name so it is matched literally, not as a regex pattern.
                name_regex = re.compile(re.escape(name), re.IGNORECASE)
                matching_patients = await patients_collection.find({"full_name": name_regex}).to_list(length=None)
                patient_ids = [p["fhir_id"] for p in matching_patients if "fhir_id" in p]
                if not patient_ids:
                    return []
                query = {"patient_id": {"$in": patient_ids}}

            analyses = await analysis_collection.find(query).sort("timestamp", -1).to_list(length=100)
            enriched_results = []
            for analysis in analyses:
                patient = await patients_collection.find_one({"fhir_id": analysis.get("patient_id")})
                if not patient:
                    continue
                analysis["full_name"] = patient.get("full_name", "Unknown")
                analysis["_id"] = str(analysis["_id"])
                enriched_results.append(analysis)

            return enriched_results

        except Exception as e:
            logger.error(f"Error fetching analysis results: {e}")
            raise HTTPException(status_code=500, detail="Failed to retrieve analysis results")

    @router.post("/chat-stream")
    async def chat_stream_endpoint(
        request: ChatRequest,
        current_user: dict = Depends(get_current_user)
    ):
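        """Generate a chat reply and stream it back as plain text.

        The reply is produced in a single generate() call and then yielded word
        by word; the exchange is also persisted to the analysis collection.

        Illustrative client-side consumption (httpx; the URL, token, and prefix
        are assumptions, not part of this module):

            with httpx.stream("POST", f"{base_url}/chat-stream",
                              json={"message": "Hello"},
                              headers={"Authorization": f"Bearer {token}"}) as r:
                for chunk in r.iter_text():
                    print(chunk, end="", flush=True)
        """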
        logger.info(f"Chat stream initiated by {current_user['email']}")

        async def token_stream():
            try:
                conversation = [{"role": "system", "content": agent.chat_prompt}]
                if request.history:
                    conversation.extend(request.history)
                conversation.append({"role": "user", "content": request.message})

                input_ids = agent.tokenizer.apply_chat_template(
                    conversation, add_generation_prompt=True, return_tensors="pt"
                ).to(agent.device)

                # Generation is synchronous: the complete reply is produced before
                # the first chunk is sent to the client.
                output = agent.model.generate(
                    input_ids,
                    do_sample=True,
                    temperature=request.temperature,
                    max_new_tokens=request.max_new_tokens,
                    pad_token_id=agent.tokenizer.eos_token_id,
                    return_dict_in_generate=True
                )

                # Decode only the newly generated tokens (everything after the prompt).
                text = agent.tokenizer.decode(output["sequences"][0][input_ids.shape[1]:], skip_special_tokens=True)
                cleaned_text = clean_text_response(text)
                full_response = ""

                chat_entry = {
                    "user_id": current_user["email"],
                    "patient_id": request.patient_id,
                    "message": request.message,
                    "response": cleaned_text,
                    "chat_type": "chat",
                    "timestamp": datetime.utcnow(),
                    "temperature": request.temperature,
                    "max_new_tokens": request.max_new_tokens
                }
                result = await analysis_collection.insert_one(chat_entry)
                chat_entry["_id"] = str(result.inserted_id)

                # Simulate token streaming by yielding the cleaned reply word by word.
                for chunk in cleaned_text.split():
                    full_response += chunk + " "
                    yield chunk + " "
                    await asyncio.sleep(0.05)

                # Overwrite the stored response with exactly what was streamed.
                await analysis_collection.update_one(
                    {"_id": result.inserted_id},
                    {"$set": {"response": full_response}}
                )

            except Exception as e:
                logger.error(f"Streaming error: {e}")
                yield f"⚠️ Error: {e}"

        return StreamingResponse(token_stream(), media_type="text/plain")

    @router.get("/chats")
    async def get_chats(
        current_user: dict = Depends(get_current_user)
    ):
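        """Return the current user's 100 most recent text chats, newest first."""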
        logger.info(f"Fetching chats for {current_user['email']}")
        try:
            chats = await analysis_collection.find(
                {"user_id": current_user["email"], "chat_type": "chat"}
            ).sort("timestamp", -1).to_list(length=100)
            return [
                {
                    "id": str(chat["_id"]),
                    "title": chat.get("message", "Untitled Chat")[:30],
                    "timestamp": chat["timestamp"].isoformat(),
                    "message": chat["message"],
                    "response": chat["response"]
                }
                for chat in chats
            ]
        except Exception as e:
            logger.error(f"Error fetching chats: {e}")
            raise HTTPException(status_code=500, detail="Failed to retrieve chats")

    @router.post("/voice/transcribe")
    async def transcribe_voice(
        audio: UploadFile = File(...),
        language: str = Query("en-US", description="Language code for speech recognition"),
        current_user: dict = Depends(get_current_user)
    ):
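        """Transcribe an uploaded audio file (WAV, MP3, OGG, or FLAC) to text."""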
        logger.info(f"Voice transcription initiated by {current_user['email']}")
        try:
            # Reject unsupported formats before reading the upload into memory.
            if not audio.filename.lower().endswith(('.wav', '.mp3', '.ogg', '.flac')):
                raise HTTPException(status_code=400, detail="Unsupported audio format")

            audio_data = await audio.read()
            text = recognize_speech(audio_data, language)
            return {"text": text}

        except HTTPException:
            raise
        except Exception as e:
            logger.error(f"Error in voice transcription: {e}")
            raise HTTPException(status_code=500, detail="Error processing voice input")

    @router.post("/voice/synthesize")
    async def synthesize_voice(
        request: VoiceOutputRequest,
        current_user: dict = Depends(get_current_user)
    ):
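        """Convert text to speech; returns an MP3 stream by default, or
        base64-encoded audio when return_format is "base64"."""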
        logger.info(f"Voice synthesis initiated by {current_user['email']}")
        try:
            audio_data = text_to_speech(request.text, request.language, request.slow)

            if request.return_format == "base64":
                return {"audio": base64.b64encode(audio_data).decode('utf-8')}
            else:
                return StreamingResponse(
                    io.BytesIO(audio_data),
                    media_type="audio/mpeg",
                    headers={"Content-Disposition": "attachment; filename=speech.mp3"}
                )

        except HTTPException:
            raise
        except Exception as e:
            logger.error(f"Error in voice synthesis: {e}")
            raise HTTPException(status_code=500, detail="Error generating voice output")

    @router.post("/voice/chat")
    async def voice_chat_endpoint(
        audio: UploadFile = File(...),
        language: str = Query("en-US", description="Language code for speech recognition"),
        temperature: float = Query(0.7, ge=0.1, le=1.0),
        max_new_tokens: int = Query(512, ge=50, le=1024),
        current_user: dict = Depends(get_current_user)
    ):
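        """Voice round trip: transcribe the uploaded audio, generate a chat
        reply, store the exchange, and return the reply as MP3 audio."""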
        logger.info(f"Voice chat initiated by {current_user['email']}")
        try:
            audio_data = await audio.read()
            user_message = recognize_speech(audio_data, language)

            chat_response = agent.chat(
                message=user_message,
                history=[],
                temperature=temperature,
                max_new_tokens=max_new_tokens
            )

            # Synthesize the reply using the base language code (e.g. "en" from "en-US").
            reply_audio = text_to_speech(chat_response, language.split('-')[0])

            chat_entry = {
                "user_id": current_user["email"],
                "patient_id": None,
                "message": user_message,
                "response": chat_response,
                "chat_type": "voice_chat",
                "timestamp": datetime.utcnow(),
                "temperature": temperature,
                "max_new_tokens": max_new_tokens
            }
            result = await analysis_collection.insert_one(chat_entry)
            chat_entry["_id"] = str(result.inserted_id)

            return StreamingResponse(
                io.BytesIO(reply_audio),
                media_type="audio/mpeg",
                headers={"Content-Disposition": "attachment; filename=response.mp3"}
            )

        except HTTPException:
            raise
        except Exception as e:
            logger.error(f"Error in voice chat: {e}")
            raise HTTPException(status_code=500, detail="Error processing voice chat")

    @router.post("/analyze-report")
    async def analyze_clinical_report(
        file: UploadFile = File(...),
        patient_id: Optional[str] = Form(None),
        temperature: float = Form(0.5),
        max_new_tokens: int = Form(1024),
        current_user: dict = Depends(get_current_user)
    ):
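        """Extract text from an uploaded PDF, TXT, or DOCX report and run the
        patient-report analysis pipeline on it."""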
        logger.info(f"Report analysis initiated by {current_user['email']}")
        try:
            content_type = file.content_type
            allowed_types = [
                'application/pdf',
                'text/plain',
                'application/vnd.openxmlformats-officedocument.wordprocessingml.document'
            ]

            if content_type not in allowed_types:
                raise HTTPException(
                    status_code=400,
                    detail=f"Unsupported file type: {content_type}. Supported types: PDF, TXT, DOCX"
                )

            file_content = await file.read()

            if content_type == 'application/pdf':
                text = extract_text_from_pdf(file_content)
            elif content_type == 'text/plain':
                text = file_content.decode('utf-8')
            elif content_type == 'application/vnd.openxmlformats-officedocument.wordprocessingml.document':
                doc = Document(io.BytesIO(file_content))
                text = "\n".join([para.text for para in doc.paragraphs])
            else:
                # Defensive fallback; unreachable given the allow-list check above.
                raise HTTPException(status_code=400, detail="Unsupported file type")

            text = clean_text_response(text)
            if len(text.strip()) < 50:
                raise HTTPException(
                    status_code=400,
                    detail="Extracted text is too short (minimum 50 characters required)"
                )

            analysis = await analyze_patient_report(
                patient_id=patient_id,
                report_content=text,
                file_type=content_type,
                file_content=file_content
            )

            # Make Mongo-specific types JSON-serializable before returning.
            if "_id" in analysis and isinstance(analysis["_id"], ObjectId):
                analysis["_id"] = str(analysis["_id"])
            if "timestamp" in analysis and isinstance(analysis["timestamp"], datetime):
                analysis["timestamp"] = analysis["timestamp"].isoformat()

            return JSONResponse(content=jsonable_encoder({
                "status": "success",
                "analysis": analysis,
                "patient_id": patient_id,
                "file_type": content_type,
                "file_size": len(file_content)
            }))

        except HTTPException:
            raise
        except Exception as e:
            logger.error(f"Error in report analysis: {str(e)}")
            raise HTTPException(
                status_code=500,
                detail=f"Failed to analyze report: {str(e)}"
            )

    @router.delete("/patients/{patient_id}")
    async def delete_patient(
        patient_id: str,
        current_user: dict = Depends(get_current_user)
    ):
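        """Delete a patient (by FHIR id) together with all associated analyses
        and chats. Only the creating user or an admin may do this."""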
        logger.info(f"Patient deletion initiated by {current_user['email']} for patient {patient_id}")
        try:
            patient = await patients_collection.find_one({"fhir_id": patient_id})
            if not patient:
                raise HTTPException(status_code=404, detail="Patient not found")

            # Only the creating user or an admin may delete the patient.
            if patient.get("created_by") != current_user["email"] and not current_user.get("is_admin", False):
                raise HTTPException(status_code=403, detail="Not authorized to delete this patient")

            # Remove associated analyses and chats first, then the patient record itself.
            await analysis_collection.delete_many({"patient_id": patient_id})
            logger.info(f"Deleted analyses and chats for patient {patient_id}")

            await patients_collection.delete_one({"fhir_id": patient_id})
            logger.info(f"Patient {patient_id} deleted successfully")

            return {"status": "success", "message": f"Patient {patient_id} and associated analyses/chats deleted"}

        except HTTPException:
            raise
        except Exception as e:
            logger.error(f"Error deleting patient {patient_id}: {str(e)}")
            raise HTTPException(status_code=500, detail=f"Failed to delete patient: {str(e)}")

    return router