"""API routes for chat, voice I/O, clinical report analysis and patient records."""

import asyncio
import base64
import io
import re
from datetime import datetime
from typing import List, Optional

from bson import ObjectId
from bson.errors import InvalidId
from docx import Document
from fastapi import APIRouter, Depends, File, Form, HTTPException, Query, UploadFile
from fastapi.encoders import jsonable_encoder
from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import BaseModel

from analysis import analyze_patient_report
from auth import get_current_user
from utils import clean_text_response
from voice import extract_text_from_pdf, recognize_speech, text_to_speech

# File-name suffixes accepted by the transcription endpoint.
_SUPPORTED_AUDIO_SUFFIXES = (".wav", ".mp3", ".ogg", ".flac")

# MIME types accepted by the report-analysis endpoint.
_PDF_TYPE = "application/pdf"
_TEXT_TYPE = "text/plain"
_DOCX_TYPE = "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
_SUPPORTED_REPORT_TYPES = (_PDF_TYPE, _TEXT_TYPE, _DOCX_TYPE)


class ChatRequest(BaseModel):
    """Request body for the streaming chat endpoint."""

    message: str
    # Earlier conversation turns; appended after the system prompt when given.
    history: Optional[List[dict]] = None
    format: Optional[str] = "clean"
    temperature: Optional[float] = 0.7
    max_new_tokens: Optional[int] = 512
    # Optional: stored alongside the chat entry, defaults to None.
    patient_id: Optional[str] = None


class VoiceOutputRequest(BaseModel):
    """Request body for the text-to-speech endpoint."""

    text: str
    language: str = "en-US"
    slow: bool = False
    # "base64" returns JSON with encoded audio; anything else streams the audio.
    return_format: str = "mp3"


class RiskLevel(BaseModel):
    """Risk level with a numeric score and optional contributing factors.

    Not referenced by the routes defined in this module.
    """

    level: str
    score: float
    factors: Optional[List[str]] = None


def _extract_report_text(content_type, file_content):
    """Return the plain text of an uploaded report.

    Args:
        content_type: MIME type reported for the upload.
        file_content: Raw bytes of the uploaded file.

    Raises:
        HTTPException: 400 when the type is unsupported or a plain-text
            upload is not valid UTF-8.
    """
    if content_type == _PDF_TYPE:
        return extract_text_from_pdf(file_content)
    if content_type == _TEXT_TYPE:
        try:
            return file_content.decode("utf-8")
        except UnicodeDecodeError as exc:
            # A badly encoded upload is a client error, not a server fault.
            raise HTTPException(
                status_code=400, detail="Text file is not valid UTF-8"
            ) from exc
    if content_type == _DOCX_TYPE:
        doc = Document(io.BytesIO(file_content))
        return "\n".join(para.text for para in doc.paragraphs)
    raise HTTPException(status_code=400, detail="Unsupported file type")


def _chat_summary(chat):
    """Build the listing entry for one stored chat document.

    Missing ``message``, ``response`` or ``timestamp`` fields yield ``None``
    instead of raising, so one malformed document cannot fail the listing.
    """
    message = chat.get("message")
    title_source = "Untitled Chat" if message is None else message
    timestamp = chat.get("timestamp")
    return {
        "id": str(chat["_id"]),
        "title": title_source[:30],  # First 30 chars of message as title
        "timestamp": timestamp.isoformat() if isinstance(timestamp, datetime) else timestamp,
        "message": message,
        "response": chat.get("response"),
    }


def create_router(agent, logger, patients_collection, analysis_collection, users_collection):
    """Build the API router with all endpoints bound to the given dependencies.

    Args:
        agent: Object providing ``chat_prompt``, ``tokenizer``, ``model``,
            ``device`` and a ``chat(...)`` method.
        logger: Logger exposing ``info`` and ``error``.
        patients_collection: Async collection of patient documents
            (presumably Motor/MongoDB -- confirm against the caller).
        analysis_collection: Async collection holding analyses and chat entries.
        users_collection: Accepted for interface compatibility; not used by
            the routes defined here.

    Returns:
        The configured ``APIRouter``.
    """
    router = APIRouter()

    @router.get("/status")
    async def status(current_user: dict = Depends(get_current_user)):
        """Report service status, version and the feature list."""
        logger.info(f"Status endpoint accessed by {current_user['email']}")
        return {
            "status": "running",
            "timestamp": datetime.utcnow().isoformat(),
            "version": "2.6.0",
            "features": ["chat", "voice-input", "voice-output", "patient-analysis", "report-upload"],
        }

    @router.get("/patients/analysis-results")
    async def get_patient_analysis_results(
        name: Optional[str] = Query(None),
        current_user: dict = Depends(get_current_user),
    ):
        """Return up to 100 analyses, newest first, with the patient name attached.

        Args:
            name: Optional case-insensitive substring to match against the
                patient's ``full_name``.

        Raises:
            HTTPException: 500 when the lookup fails.
        """
        logger.info(f"Fetching analysis results by {current_user['email']}")
        try:
            query = {}
            if name:
                # Escape the user input so it is matched literally: raw input
                # could be an invalid pattern or a costly one.
                name_regex = re.compile(re.escape(name), re.IGNORECASE)
                matching_patients = await patients_collection.find(
                    {"full_name": name_regex}
                ).to_list(length=None)
                patient_ids = [p["fhir_id"] for p in matching_patients if "fhir_id" in p]
                if not patient_ids:
                    return []
                query = {"patient_id": {"$in": patient_ids}}

            analyses = await analysis_collection.find(query).sort("timestamp", -1).to_list(length=100)
            if not analyses:
                return []

            # Fetch every referenced patient in one query rather than one
            # query per analysis.
            referenced_ids = [analysis.get("patient_id") for analysis in analyses]
            patients = await patients_collection.find(
                {"fhir_id": {"$in": referenced_ids}}
            ).to_list(length=None)
            patients_by_id = {}
            for patient in patients:
                # Keep the first document per id, as a single lookup would.
                patients_by_id.setdefault(patient.get("fhir_id"), patient)

            enriched_results = []
            for analysis in analyses:
                patient = patients_by_id.get(analysis.get("patient_id"))
                if not patient:
                    continue  # Skip if patient no longer exists
                analysis["full_name"] = patient.get("full_name", "Unknown")
                analysis["_id"] = str(analysis["_id"])
                enriched_results.append(analysis)
            return enriched_results
        except Exception as e:
            logger.error(f"Error fetching analysis results: {e}")
            raise HTTPException(status_code=500, detail="Failed to retrieve analysis results")

    @router.post("/chat-stream")
    async def chat_stream_endpoint(
        request: ChatRequest,
        current_user: dict = Depends(get_current_user),
    ):
        """Generate a chat reply, store it, and stream it back word by word."""
        logger.info(f"Chat stream initiated by {current_user['email']}")

        async def token_stream():
            """Yield the reply one word at a time; on failure yield an error line."""
            try:
                conversation = [{"role": "system", "content": agent.chat_prompt}]
                if request.history:
                    conversation.extend(request.history)
                conversation.append({"role": "user", "content": request.message})

                input_ids = agent.tokenizer.apply_chat_template(
                    conversation, add_generation_prompt=True, return_tensors="pt"
                ).to(agent.device)

                # NOTE(review): generate() is called synchronously inside this
                # coroutine, so other requests wait while it runs.
                output = agent.model.generate(
                    input_ids,
                    do_sample=True,
                    temperature=request.temperature,
                    max_new_tokens=request.max_new_tokens,
                    pad_token_id=agent.tokenizer.eos_token_id,
                    return_dict_in_generate=True,
                )

                # Decode only what follows the prompt tokens.
                prompt_length = input_ids.shape[1]
                text = agent.tokenizer.decode(
                    output["sequences"][0][prompt_length:], skip_special_tokens=True
                )
                cleaned_text = clean_text_response(text)

                # Store chat session in database
                chat_entry = {
                    "user_id": current_user["email"],
                    "patient_id": request.patient_id,
                    "message": request.message,
                    "response": cleaned_text,
                    "chat_type": "chat",
                    "timestamp": datetime.utcnow(),
                    "temperature": request.temperature,
                    "max_new_tokens": request.max_new_tokens,
                }
                result = await analysis_collection.insert_one(chat_entry)

                words = cleaned_text.split()
                for word in words:
                    yield word + " "
                    await asyncio.sleep(0.05)

                # Replace the stored response with the text exactly as streamed
                # (whitespace-normalized, each word followed by a space).
                full_response = "".join(word + " " for word in words)
                await analysis_collection.update_one(
                    {"_id": result.inserted_id},
                    {"$set": {"response": full_response}},
                )
            except Exception as e:
                logger.error(f"Streaming error: {e}")
                yield f"⚠️ Error: {e}"

        return StreamingResponse(token_stream(), media_type="text/plain")

    @router.get("/chats")
    async def get_chats(
        current_user: dict = Depends(get_current_user),
    ):
        """List the current user's 100 most recent text chats, newest first.

        Raises:
            HTTPException: 500 when the lookup fails.
        """
        logger.info(f"Fetching chats for {current_user['email']}")
        try:
            chats = await analysis_collection.find(
                {"user_id": current_user["email"], "chat_type": "chat"}
            ).sort("timestamp", -1).to_list(length=100)
            return [_chat_summary(chat) for chat in chats]
        except Exception as e:
            logger.error(f"Error fetching chats: {e}")
            raise HTTPException(status_code=500, detail="Failed to retrieve chats")

    @router.post("/voice/transcribe")
    async def transcribe_voice(
        audio: UploadFile = File(...),
        language: str = Query("en-US", description="Language code for speech recognition"),
        current_user: dict = Depends(get_current_user),
    ):
        """Transcribe an uploaded audio file to text.

        Raises:
            HTTPException: 400 for a missing or unsupported file name,
                500 when recognition fails.
        """
        logger.info(f"Voice transcription initiated by {current_user['email']}")
        try:
            # Validate before reading the upload; the file name may be absent.
            filename = (audio.filename or "").lower()
            if not filename.endswith(_SUPPORTED_AUDIO_SUFFIXES):
                raise HTTPException(status_code=400, detail="Unsupported audio format")

            audio_data = await audio.read()
            text = recognize_speech(audio_data, language)
            return {"text": text}
        except HTTPException:
            raise
        except Exception as e:
            logger.error(f"Error in voice transcription: {e}")
            raise HTTPException(status_code=500, detail="Error processing voice input")

    @router.post("/voice/synthesize")
    async def synthesize_voice(
        request: VoiceOutputRequest,
        current_user: dict = Depends(get_current_user),
    ):
        """Convert text to speech.

        Returns JSON with base64-encoded audio when ``return_format`` is
        ``"base64"``; otherwise streams the audio as an attachment.

        Raises:
            HTTPException: 500 when synthesis fails.
        """
        logger.info(f"Voice synthesis initiated by {current_user['email']}")
        try:
            audio_data = text_to_speech(request.text, request.language, request.slow)

            if request.return_format == "base64":
                return {"audio": base64.b64encode(audio_data).decode('utf-8')}
            return StreamingResponse(
                io.BytesIO(audio_data),
                media_type="audio/mpeg",
                headers={"Content-Disposition": "attachment; filename=speech.mp3"},
            )
        except HTTPException:
            raise
        except Exception as e:
            logger.error(f"Error in voice synthesis: {e}")
            raise HTTPException(status_code=500, detail="Error generating voice output")

    @router.post("/voice/chat")
    async def voice_chat_endpoint(
        audio: UploadFile = File(...),
        language: str = Query("en-US", description="Language code for speech recognition"),
        temperature: float = Query(0.7, ge=0.1, le=1.0),
        max_new_tokens: int = Query(512, ge=50, le=1024),
        current_user: dict = Depends(get_current_user),
    ):
        """Answer a spoken message with spoken audio and store the exchange.

        Raises:
            HTTPException: 500 when any step fails.
        """
        logger.info(f"Voice chat initiated by {current_user['email']}")
        try:
            audio_data = await audio.read()
            user_message = recognize_speech(audio_data, language)

            chat_response = agent.chat(
                message=user_message,
                history=[],
                temperature=temperature,
                max_new_tokens=max_new_tokens,
            )

            # Only the part of the language code before "-" is passed on.
            response_audio = text_to_speech(chat_response, language.split('-')[0])

            # Store voice chat in database
            chat_entry = {
                "user_id": current_user["email"],
                "patient_id": None,
                "message": user_message,
                "response": chat_response,
                "chat_type": "voice_chat",
                "timestamp": datetime.utcnow(),
                "temperature": temperature,
                "max_new_tokens": max_new_tokens,
            }
            await analysis_collection.insert_one(chat_entry)

            return StreamingResponse(
                io.BytesIO(response_audio),
                media_type="audio/mpeg",
                headers={"Content-Disposition": "attachment; filename=response.mp3"},
            )
        except HTTPException:
            raise
        except Exception as e:
            logger.error(f"Error in voice chat: {e}")
            raise HTTPException(status_code=500, detail="Error processing voice chat")

    @router.post("/analyze-report")
    async def analyze_clinical_report(
        file: UploadFile = File(...),
        patient_id: Optional[str] = Form(None),
        temperature: float = Form(0.5),
        max_new_tokens: int = Form(1024),
        current_user: dict = Depends(get_current_user),
    ):
        """Extract text from an uploaded PDF, TXT or DOCX report and analyze it.

        Raises:
            HTTPException: 400 for an unsupported type, undecodable text or
                fewer than 50 characters of extracted text; 500 otherwise.
        """
        logger.info(f"Report analysis initiated by {current_user['email']}")
        try:
            content_type = file.content_type
            if content_type not in _SUPPORTED_REPORT_TYPES:
                raise HTTPException(
                    status_code=400,
                    detail=f"Unsupported file type: {content_type}. Supported types: PDF, TXT, DOCX",
                )

            file_content = await file.read()
            text = clean_text_response(_extract_report_text(content_type, file_content))
            if len(text.strip()) < 50:
                raise HTTPException(
                    status_code=400,
                    detail="Extracted text is too short (minimum 50 characters required)",
                )

            analysis = await analyze_patient_report(
                patient_id=patient_id,
                report_content=text,
                file_type=content_type,
                file_content=file_content,
            )

            # Convert the id and timestamp to strings before encoding.
            if "_id" in analysis and isinstance(analysis["_id"], ObjectId):
                analysis["_id"] = str(analysis["_id"])
            if "timestamp" in analysis and isinstance(analysis["timestamp"], datetime):
                analysis["timestamp"] = analysis["timestamp"].isoformat()

            return JSONResponse(content=jsonable_encoder({
                "status": "success",
                "analysis": analysis,
                "patient_id": patient_id,
                "file_type": content_type,
                "file_size": len(file_content),
            }))
        except HTTPException:
            raise
        except Exception as e:
            logger.error(f"Error in report analysis: {str(e)}")
            raise HTTPException(
                status_code=500,
                detail=f"Failed to analyze report: {str(e)}",
            )

    @router.delete("/patients/{patient_id}")
    async def delete_patient(
        patient_id: str,
        current_user: dict = Depends(get_current_user),
    ):
        """Delete a patient and every analysis or chat stored under their id.

        Only the user recorded in ``created_by`` or an admin may delete.

        Raises:
            HTTPException: 404 when the patient does not exist, 403 when the
                caller is not authorized, 500 when deletion fails.
        """
        logger.info(f"Patient deletion initiated by {current_user['email']} for patient {patient_id}")
        try:
            # Check if the patient exists
            patient = await patients_collection.find_one({"fhir_id": patient_id})
            if not patient:
                raise HTTPException(status_code=404, detail="Patient not found")

            # Check if the current user is authorized (creator or admin)
            is_owner = patient.get("created_by") == current_user["email"]
            if not is_owner and not current_user.get("is_admin", False):
                raise HTTPException(status_code=403, detail="Not authorized to delete this patient")

            # Delete all analyses and chats associated with this patient
            await analysis_collection.delete_many({"patient_id": patient_id})
            logger.info(f"Deleted analyses and chats for patient {patient_id}")

            # Delete the patient
            await patients_collection.delete_one({"fhir_id": patient_id})
            logger.info(f"Patient {patient_id} deleted successfully")

            return {"status": "success", "message": f"Patient {patient_id} and associated analyses/chats deleted"}
        except HTTPException:
            raise
        except Exception as e:
            logger.error(f"Error deleting patient {patient_id}: {str(e)}")
            raise HTTPException(status_code=500, detail=f"Failed to delete patient: {str(e)}")

    return router