# Spaces: Running on A100
import asyncio
import base64
import io
import mimetypes
import re
from datetime import datetime
from typing import Optional

from bson import ObjectId
from docx import Document
from fastapi import Depends, HTTPException, UploadFile, File, Query, Form
from fastapi.encoders import jsonable_encoder
from fastapi.responses import StreamingResponse, JSONResponse

from config import app, agent, logger
from models import ChatRequest, VoiceOutputRequest, RiskLevel
from auth import get_current_user
from utils import clean_text_response
from analysis import analyze_patient_report, analyze_all_patients
from voice import recognize_speech, text_to_speech, extract_text_from_pdf
async def status(current_user: dict = Depends(get_current_user)):
    """Return a small health/status payload for an authenticated caller.

    The payload carries a fixed status string, the current naive UTC
    timestamp in ISO-8601 form, the service version and the list of
    feature names.
    """
    logger.info(f"Status endpoint accessed by {current_user['email']}")
    now_iso = datetime.utcnow().isoformat()
    feature_names = [
        "chat",
        "voice-input",
        "voice-output",
        "patient-analysis",
        "report-upload",
    ]
    payload = {
        "status": "running",
        "timestamp": now_iso,
        "version": "2.6.0",
        "features": feature_names,
    }
    return payload
async def get_patient_analysis_results(
    name: Optional[str] = Query(None),
    current_user: dict = Depends(get_current_user),
):
    """Return stored analyses, newest first, enriched with the patient name.

    Args:
        name: Optional case-insensitive text to look for in a patient's
            ``full_name``. It is matched literally (regex metacharacters
            are escaped). When omitted, analyses for all patients are
            considered.
        current_user: Authenticated user injected by ``get_current_user``.

    Returns:
        A list of at most 100 analysis documents sorted by ``timestamp``
        descending. Each one has ``_id`` converted to ``str`` and a
        ``full_name`` field added. Analyses whose patient cannot be found
        are left out. An empty list is returned when ``name`` matches no
        patient that has a ``fhir_id``.

    Raises:
        HTTPException: 500 when any lookup fails.
    """
    # NOTE(review): patients_collection / analysis_collection are not imported
    # in the visible part of this file -- confirm where they come from.
    logger.info(f"Fetching analysis results by {current_user['email']}")
    try:
        query = {}
        if name:
            # The name comes straight from the query string. Escaping it keeps
            # characters such as "(" or "*" from being interpreted as regex
            # syntax (which could raise re.error or enable costly patterns).
            name_regex = re.compile(re.escape(name), re.IGNORECASE)
            matching_patients = await patients_collection.find(
                {"full_name": name_regex}
            ).to_list(length=None)
            patient_ids = [p["fhir_id"] for p in matching_patients if "fhir_id" in p]
            if not patient_ids:
                return []
            query = {"patient_id": {"$in": patient_ids}}

        analyses = await (
            analysis_collection.find(query).sort("timestamp", -1).to_list(length=100)
        )

        # Several analyses usually belong to the same patient; remember each
        # lookup so every patient is fetched at most once per request.
        patients_by_id = {}
        enriched_results = []
        for analysis in analyses:
            patient_id = analysis.get("patient_id")
            if patient_id not in patients_by_id:
                patients_by_id[patient_id] = await patients_collection.find_one(
                    {"fhir_id": patient_id}
                )
            patient = patients_by_id[patient_id]
            if patient:
                analysis["full_name"] = patient.get("full_name", "Unknown")
                analysis["_id"] = str(analysis["_id"])
                enriched_results.append(analysis)
        return enriched_results
    except Exception as e:
        logger.error(f"Error fetching analysis results: {e}")
        raise HTTPException(status_code=500, detail="Failed to retrieve analysis results")
async def chat_stream_endpoint(
    request: ChatRequest,
    current_user: dict = Depends(get_current_user),
):
    """Generate a chat reply and stream it back as plain text, word by word.

    The conversation is built from the agent's chat prompt, the optional
    ``request.history`` and ``request.message``. The full reply is generated
    first; it is then split on whitespace and each word is sent followed by
    a single space, with a short pause between words.

    Args:
        request: Chat request carrying ``message``, ``history``,
            ``temperature`` and ``max_new_tokens``.
        current_user: Authenticated user injected by ``get_current_user``.

    Returns:
        A ``StreamingResponse`` with media type ``text/plain``. If generation
        fails, the stream contains a single error line instead of the reply.
    """
    logger.info(f"Chat stream initiated by {current_user['email']}")

    def generate_reply() -> str:
        """Tokenize the conversation, run the model and decode the new tokens."""
        conversation = [{"role": "system", "content": agent.chat_prompt}]
        if request.history:
            conversation.extend(request.history)
        conversation.append({"role": "user", "content": request.message})
        input_ids = agent.tokenizer.apply_chat_template(
            conversation, add_generation_prompt=True, return_tensors="pt"
        ).to(agent.device)
        output = agent.model.generate(
            input_ids,
            do_sample=True,
            temperature=request.temperature,
            max_new_tokens=request.max_new_tokens,
            pad_token_id=agent.tokenizer.eos_token_id,
            return_dict_in_generate=True,
        )
        # Drop the prompt tokens so only the newly generated text is decoded.
        new_tokens = output["sequences"][0][input_ids.shape[1]:]
        return agent.tokenizer.decode(new_tokens, skip_special_tokens=True)

    async def token_stream():
        try:
            # generate() is a long synchronous call; running it on the event
            # loop would stall every other request until it finishes, so it
            # is handed to the default executor instead.
            loop = asyncio.get_running_loop()
            text = await loop.run_in_executor(None, generate_reply)
            for chunk in text.split():
                yield chunk + " "
                await asyncio.sleep(0.05)
        except Exception as e:
            logger.error(f"Streaming error: {e}")
            yield f"⚠️ Error: {e}"

    return StreamingResponse(token_stream(), media_type="text/plain")
async def transcribe_voice(
    audio: UploadFile = File(...),
    language: str = Query("en-US", description="Language code for speech recognition"),
    current_user: dict = Depends(get_current_user),
):
    """Transcribe an uploaded audio file to text.

    Args:
        audio: Uploaded file; its name must end in ``.wav``, ``.mp3``,
            ``.ogg`` or ``.flac`` (case-insensitive).
        language: Language code passed to ``recognize_speech``.
        current_user: Authenticated user injected by ``get_current_user``.

    Returns:
        ``{"text": <recognized text>}``.

    Raises:
        HTTPException: 400 when the file name is missing or has an
            unsupported extension; 500 when reading or recognition fails.
    """
    logger.info(f"Voice transcription initiated by {current_user['email']}")
    try:
        # An upload may arrive without a file name; treat that as an
        # unsupported format rather than failing with an AttributeError.
        filename = (audio.filename or "").lower()
        if not filename.endswith(('.wav', '.mp3', '.ogg', '.flac')):
            raise HTTPException(status_code=400, detail="Unsupported audio format")
        # Read the payload only after the cheap name check has passed.
        audio_data = await audio.read()
        text = recognize_speech(audio_data, language)
        return {"text": text}
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error in voice transcription: {e}")
        raise HTTPException(status_code=500, detail="Error processing voice input")
async def synthesize_voice(
    request: VoiceOutputRequest,
    current_user: dict = Depends(get_current_user),
):
    """Convert the requested text to speech.

    Returns ``{"audio": <base64 string>}`` when ``request.return_format`` is
    ``"base64"``; otherwise streams the audio bytes as an ``audio/mpeg``
    attachment named ``speech.mp3``. Unexpected failures are logged and
    reported as a 500 error; ``HTTPException`` is passed through untouched.
    """
    logger.info(f"Voice synthesis initiated by {current_user['email']}")
    try:
        speech_bytes = text_to_speech(request.text, request.language, request.slow)

        if request.return_format != "base64":
            download_headers = {"Content-Disposition": "attachment; filename=speech.mp3"}
            return StreamingResponse(
                io.BytesIO(speech_bytes),
                media_type="audio/mpeg",
                headers=download_headers,
            )

        encoded = base64.b64encode(speech_bytes)
        return {"audio": encoded.decode('utf-8')}
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error in voice synthesis: {e}")
        raise HTTPException(status_code=500, detail="Error generating voice output")
async def voice_chat_endpoint(
    audio: UploadFile = File(...),
    language: str = Query("en-US", description="Language code for speech recognition"),
    temperature: float = Query(0.7, ge=0.1, le=1.0),
    max_new_tokens: int = Query(512, ge=50, le=1024),
    current_user: dict = Depends(get_current_user)
):
    """Run a full voice round trip: speech in, chat reply, speech out.

    The uploaded audio is transcribed with ``recognize_speech``, the
    transcript is sent to ``agent.chat`` with an empty history, and the reply
    is synthesized with ``text_to_speech`` using the part of ``language``
    before the first ``-``. The result is streamed as an ``audio/mpeg``
    attachment named ``response.mp3``. Unexpected failures are logged and
    reported as a 500 error; ``HTTPException`` is passed through untouched.
    """
    logger.info(f"Voice chat initiated by {current_user['email']}")
    try:
        recording = await audio.read()
        transcript = recognize_speech(recording, language)

        reply_text = agent.chat(
            message=transcript,
            history=[],
            temperature=temperature,
            max_new_tokens=max_new_tokens,
        )

        # text_to_speech is given only the leading part of the language code.
        base_language = language.split('-')[0]
        reply_audio = text_to_speech(reply_text, base_language)

        download_headers = {"Content-Disposition": "attachment; filename=response.mp3"}
        return StreamingResponse(
            io.BytesIO(reply_audio),
            media_type="audio/mpeg",
            headers=download_headers,
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error in voice chat: {e}")
        raise HTTPException(status_code=500, detail="Error processing voice chat")
def _extract_report_text(content_type: str, file_content: bytes) -> str:
    """Extract raw text from an uploaded PDF, plain-text or DOCX report."""
    if content_type == 'application/pdf':
        return extract_text_from_pdf(file_content)
    if content_type == 'text/plain':
        try:
            return file_content.decode('utf-8')
        except UnicodeDecodeError as exc:
            # A wrongly encoded upload is a client problem, not a server fault.
            raise HTTPException(
                status_code=400,
                detail="Text file could not be decoded as UTF-8"
            ) from exc
    if content_type == 'application/vnd.openxmlformats-officedocument.wordprocessingml.document':
        doc = Document(io.BytesIO(file_content))
        return "\n".join(para.text for para in doc.paragraphs)
    raise HTTPException(status_code=400, detail="Unsupported file type")


async def analyze_clinical_report(
    file: UploadFile = File(...),
    patient_id: Optional[str] = Form(None),
    temperature: float = Form(0.5),
    max_new_tokens: int = Form(1024),
    current_user: dict = Depends(get_current_user),
):
    """Extract the text of an uploaded clinical report and analyze it.

    Args:
        file: Uploaded report; its content type must be PDF, plain text or
            DOCX.
        patient_id: Optional patient identifier forwarded to
            ``analyze_patient_report`` and echoed in the response.
        temperature: Accepted for interface compatibility; not used by this
            function.
        max_new_tokens: Accepted for interface compatibility; not used by
            this function.
        current_user: Authenticated user injected by ``get_current_user``.

    Returns:
        A ``JSONResponse`` with ``status``, ``analysis``, ``patient_id``,
        ``file_type`` and ``file_size`` (size of the upload in bytes).

    Raises:
        HTTPException: 400 for an unsupported content type, a plain-text
            file that is not valid UTF-8, or extracted text shorter than 50
            characters; 500 for any other failure.
    """
    logger.info(f"Report analysis initiated by {current_user['email']}")
    try:
        content_type = file.content_type
        allowed_types = [
            'application/pdf',
            'text/plain',
            'application/vnd.openxmlformats-officedocument.wordprocessingml.document',
        ]
        # Reject unsupported uploads before reading the payload into memory.
        if content_type not in allowed_types:
            raise HTTPException(
                status_code=400,
                detail=f"Unsupported file type: {content_type}. Supported types: PDF, TXT, DOCX"
            )

        file_content = await file.read()
        text = clean_text_response(_extract_report_text(content_type, file_content))
        if len(text.strip()) < 50:
            raise HTTPException(
                status_code=400,
                detail="Extracted text is too short (minimum 50 characters required)"
            )

        analysis = await analyze_patient_report(
            patient_id=patient_id,
            report_content=text,
            file_type=content_type,
            file_content=file_content
        )

        # Convert the top-level ObjectId / datetime values to strings before
        # the document is passed to jsonable_encoder.
        if "_id" in analysis and isinstance(analysis["_id"], ObjectId):
            analysis["_id"] = str(analysis["_id"])
        if "timestamp" in analysis and isinstance(analysis["timestamp"], datetime):
            analysis["timestamp"] = analysis["timestamp"].isoformat()

        return JSONResponse(content=jsonable_encoder({
            "status": "success",
            "analysis": analysis,
            "patient_id": patient_id,
            "file_type": content_type,
            "file_size": len(file_content)
        }))
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error in report analysis: {str(e)}")
        raise HTTPException(
            status_code=500,
            detail=f"Failed to analyze report: {str(e)}"
        )