TxAgent-Api / endpoints.py
Ali2206's picture
Update endpoints.py
534b887 verified
raw
history blame
15 kB
# Standard library
import asyncio
import base64
import io
import re
from datetime import datetime
from typing import List, Optional

# Third-party
from bson import ObjectId
from bson.errors import InvalidId
from docx import Document
from fastapi import APIRouter, Depends, File, Form, HTTPException, Query, UploadFile
from fastapi.encoders import jsonable_encoder
from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import BaseModel

# Local application
from analysis import analyze_patient_report
from auth import get_current_user
from utils import clean_text_response
from voice import extract_text_from_pdf, recognize_speech, text_to_speech
# Define the ChatRequest model with an optional patient_id
class ChatRequest(BaseModel):
    """Request body for the /chat-stream endpoint."""
    message: str                              # user's new message for this turn
    history: Optional[List[dict]] = None      # prior turns as {"role", "content"} dicts
    format: Optional[str] = "clean"           # response formatting hint (only "clean" observed here)
    temperature: Optional[float] = 0.7        # sampling temperature passed to model.generate
    max_new_tokens: Optional[int] = 512       # generation length cap passed to model.generate
    patient_id: Optional[str] = None          # optional patient context; None for free-form chat
class VoiceOutputRequest(BaseModel):
    """Request body for the /voice/synthesize endpoint."""
    text: str                   # text to convert to speech
    language: str = "en-US"     # language code forwarded to text_to_speech
    slow: bool = False          # slower speech rate when True
    return_format: str = "mp3"  # "base64" returns JSON; anything else streams MP3 bytes
class RiskLevel(BaseModel):
    """Risk assessment value object.

    NOTE(review): not referenced by any endpoint in this chunk — presumably
    used elsewhere (e.g. inside analysis results); verify before removing.
    """
    level: str                          # risk category label
    score: float                        # numeric risk score
    factors: Optional[List[str]] = None # contributing factors, when available
def create_router(agent, logger, patients_collection, analysis_collection, users_collection):
    """Assemble and return the application's API router.

    Every endpoint closes over the shared model ``agent``, ``logger`` and
    the Mongo collections supplied here.
    """
    router = APIRouter()

    @router.get("/status")
    async def status(current_user: dict = Depends(get_current_user)):
        """Report service health, version and available features."""
        logger.info(f"Status endpoint accessed by {current_user['email']}")
        payload = {
            "status": "running",
            "timestamp": datetime.utcnow().isoformat(),
            "version": "2.6.0",
            "features": ["chat", "voice-input", "voice-output", "patient-analysis", "report-upload"],
        }
        return payload
@router.get("/patients/analysis-results")
async def get_patient_analysis_results(
    name: Optional[str] = Query(None),
    current_user: dict = Depends(get_current_user)
):
    """Return the 100 most recent analyses, optionally filtered by patient name.

    Each result is enriched with the patient's ``full_name``; analyses whose
    patient record no longer exists are skipped.

    Raises:
        HTTPException 500: on any database failure.
    """
    logger.info(f"Fetching analysis results by {current_user['email']}")
    try:
        query = {}
        if name:
            # re.escape: treat the user-supplied name as a literal,
            # case-insensitive substring. Previously the raw input was
            # compiled as a regex, allowing pattern injection / ReDoS and
            # crashing on invalid patterns like "(".
            name_regex = re.compile(re.escape(name), re.IGNORECASE)
            matching_patients = await patients_collection.find({"full_name": name_regex}).to_list(length=None)
            patient_ids = [p["fhir_id"] for p in matching_patients if "fhir_id" in p]
            if not patient_ids:
                return []
            query = {"patient_id": {"$in": patient_ids}}
        analyses = await analysis_collection.find(query).sort("timestamp", -1).to_list(length=100)
        enriched_results = []
        for analysis in analyses:
            patient = await patients_collection.find_one({"fhir_id": analysis.get("patient_id")})
            if not patient:
                continue  # Skip if patient no longer exists
            analysis["full_name"] = patient.get("full_name", "Unknown")
            analysis["_id"] = str(analysis["_id"])  # ObjectId is not JSON-serializable
            enriched_results.append(analysis)
        return enriched_results
    except Exception as e:
        logger.error(f"Error fetching analysis results: {e}")
        raise HTTPException(status_code=500, detail="Failed to retrieve analysis results")
@router.post("/chat-stream")
async def chat_stream_endpoint(
    request: ChatRequest,
    current_user: dict = Depends(get_current_user)
):
    """Generate a chat completion and stream it back word by word.

    The full response is generated up front, persisted to the analysis
    collection, then yielded in whitespace-delimited chunks (50 ms apart)
    to simulate token streaming.
    """
    logger.info(f"Chat stream initiated by {current_user['email']}")

    async def token_stream():
        try:
            # Assemble the conversation: system prompt, prior turns, new message.
            conversation = [{"role": "system", "content": agent.chat_prompt}]
            if request.history:
                conversation.extend(request.history)
            conversation.append({"role": "user", "content": request.message})
            input_ids = agent.tokenizer.apply_chat_template(
                conversation, add_generation_prompt=True, return_tensors="pt"
            ).to(agent.device)
            output = agent.model.generate(
                input_ids,
                do_sample=True,
                temperature=request.temperature,
                max_new_tokens=request.max_new_tokens,
                pad_token_id=agent.tokenizer.eos_token_id,
                return_dict_in_generate=True
            )
            # Decode only the newly generated tokens (skip the prompt prefix).
            text = agent.tokenizer.decode(output["sequences"][0][input_ids.shape[1]:], skip_special_tokens=True)
            cleaned_text = clean_text_response(text)
            # Persist the exchange before streaming. The stored response is
            # the complete cleaned text; the old post-stream update_one
            # overwrote it with a whitespace-collapsed rejoin of the chunks
            # (losing line breaks and adding a trailing space), so it has
            # been removed as redundant and lossy.
            chat_entry = {
                "user_id": current_user["email"],
                "patient_id": request.patient_id,  # None when chatting outside a patient context
                "message": request.message,
                "response": cleaned_text,
                "chat_type": "chat",
                "timestamp": datetime.utcnow(),
                "temperature": request.temperature,
                "max_new_tokens": request.max_new_tokens
            }
            result = await analysis_collection.insert_one(chat_entry)
            chat_entry["_id"] = str(result.inserted_id)
            for chunk in cleaned_text.split():
                yield chunk + " "
                await asyncio.sleep(0.05)
        except Exception as e:
            logger.error(f"Streaming error: {e}")
            yield f"⚠️ Error: {e}"
    return StreamingResponse(token_stream(), media_type="text/plain")
@router.get("/chats")
async def get_chats(
    current_user: dict = Depends(get_current_user)
):
    """List the current user's chat sessions, newest first, capped at 100.

    Raises:
        HTTPException 500: on any database failure.
    """
    logger.info(f"Fetching chats for {current_user['email']}")
    try:
        chats = await analysis_collection.find({"user_id": current_user["email"], "chat_type": "chat"}).sort("timestamp", -1).to_list(length=100)
        results = []
        for chat in chats:
            # Use .get consistently: the original guarded "message" for the
            # title but indexed chat["message"]/chat["response"] directly,
            # which raises KeyError on legacy/partial documents.
            message = chat.get("message", "")
            results.append({
                "id": str(chat["_id"]),
                "title": (message or "Untitled Chat")[:30],  # first 30 chars of message as title
                "timestamp": chat["timestamp"].isoformat(),
                "message": message,
                "response": chat.get("response", ""),
            })
        return results
    except Exception as e:
        logger.error(f"Error fetching chats: {e}")
        raise HTTPException(status_code=500, detail="Failed to retrieve chats")
@router.post("/voice/transcribe")
async def transcribe_voice(
    audio: UploadFile = File(...),
    language: str = Query("en-US", description="Language code for speech recognition"),
    current_user: dict = Depends(get_current_user)
):
    """Transcribe an uploaded audio file (wav/mp3/ogg/flac) to text.

    Raises:
        HTTPException 400: unsupported audio format.
        HTTPException 500: on recognition failure.
    """
    logger.info(f"Voice transcription initiated by {current_user['email']}")
    try:
        # Validate the format BEFORE reading the payload into memory
        # (the original read first, then checked). filename may be None,
        # which previously raised AttributeError -> a misleading 500.
        filename = (audio.filename or "").lower()
        if not filename.endswith(('.wav', '.mp3', '.ogg', '.flac')):
            raise HTTPException(status_code=400, detail="Unsupported audio format")
        audio_data = await audio.read()
        text = recognize_speech(audio_data, language)
        return {"text": text}
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error in voice transcription: {e}")
        raise HTTPException(status_code=500, detail="Error processing voice input")
@router.post("/voice/synthesize")
async def synthesize_voice(
    request: VoiceOutputRequest,
    current_user: dict = Depends(get_current_user)
):
    """Convert text to speech.

    Returns a JSON base64 payload when ``return_format == "base64"``,
    otherwise streams MP3 bytes as an attachment.

    Raises:
        HTTPException 500: on synthesis failure.
    """
    logger.info(f"Voice synthesis initiated by {current_user['email']}")
    try:
        audio_data = text_to_speech(request.text, request.language, request.slow)
        if request.return_format == "base64":
            # FIX: this branch raised NameError because ``base64`` was never
            # imported; the import is now at the top of the file.
            return {"audio": base64.b64encode(audio_data).decode('utf-8')}
        return StreamingResponse(
            io.BytesIO(audio_data),
            media_type="audio/mpeg",
            headers={"Content-Disposition": "attachment; filename=speech.mp3"}
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error in voice synthesis: {e}")
        raise HTTPException(status_code=500, detail="Error generating voice output")
@router.post("/voice/chat")
async def voice_chat_endpoint(
    audio: UploadFile = File(...),
    language: str = Query("en-US", description="Language code for speech recognition"),
    temperature: float = Query(0.7, ge=0.1, le=1.0),
    max_new_tokens: int = Query(512, ge=50, le=1024),
    current_user: dict = Depends(get_current_user)
):
    """Full voice round-trip: transcribe the upload, chat, speak the reply.

    Persists the exchange to the analysis collection and streams the
    synthesized reply back as MP3.
    """
    logger.info(f"Voice chat initiated by {current_user['email']}")
    try:
        # Distinct names for the two audio buffers (the original reused
        # ``audio_data`` for both input and output).
        recorded_bytes = await audio.read()
        user_message = recognize_speech(recorded_bytes, language)
        chat_response = agent.chat(
            message=user_message,
            history=[],
            temperature=temperature,
            max_new_tokens=max_new_tokens
        )
        # Synthesize with the base language code (e.g. "en" from "en-US").
        spoken_reply = text_to_speech(chat_response, language.split('-')[0])
        chat_entry = {
            "user_id": current_user["email"],
            "patient_id": None,
            "message": user_message,
            "response": chat_response,
            "chat_type": "voice_chat",
            "timestamp": datetime.utcnow(),
            "temperature": temperature,
            "max_new_tokens": max_new_tokens
        }
        result = await analysis_collection.insert_one(chat_entry)
        chat_entry["_id"] = str(result.inserted_id)
        return StreamingResponse(
            io.BytesIO(spoken_reply),
            media_type="audio/mpeg",
            headers={"Content-Disposition": "attachment; filename=response.mp3"}
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error in voice chat: {e}")
        raise HTTPException(status_code=500, detail="Error processing voice chat")
@router.post("/analyze-report")
async def analyze_clinical_report(
    file: UploadFile = File(...),
    patient_id: Optional[str] = Form(None),
    temperature: float = Form(0.5),
    max_new_tokens: int = Form(1024),
    current_user: dict = Depends(get_current_user)
):
    """Extract text from an uploaded clinical report (PDF/TXT/DOCX) and analyze it.

    Raises:
        HTTPException 400: unsupported type, bad encoding, or too little text.
        HTTPException 500: on analysis failure.
    """
    logger.info(f"Report analysis initiated by {current_user['email']}")
    try:
        content_type = file.content_type
        allowed_types = [
            'application/pdf',
            'text/plain',
            'application/vnd.openxmlformats-officedocument.wordprocessingml.document'
        ]
        if content_type not in allowed_types:
            raise HTTPException(
                status_code=400,
                detail=f"Unsupported file type: {content_type}. Supported types: PDF, TXT, DOCX"
            )
        file_content = await file.read()
        if content_type == 'application/pdf':
            text = extract_text_from_pdf(file_content)
        elif content_type == 'text/plain':
            try:
                text = file_content.decode('utf-8')
            except UnicodeDecodeError:
                # Surface a client error instead of the generic 500 the
                # catch-all below would have produced.
                raise HTTPException(status_code=400, detail="Text file must be UTF-8 encoded")
        else:
            # Only DOCX remains after the allow-list check above (the
            # original's final `else: raise` branch was unreachable).
            doc = Document(io.BytesIO(file_content))
            text = "\n".join(para.text for para in doc.paragraphs)
        text = clean_text_response(text)
        if len(text.strip()) < 50:
            raise HTTPException(
                status_code=400,
                detail="Extracted text is too short (minimum 50 characters required)"
            )
        analysis = await analyze_patient_report(
            patient_id=patient_id,
            report_content=text,
            file_type=content_type,
            file_content=file_content
        )
        # Make Mongo-specific types JSON-serializable.
        if "_id" in analysis and isinstance(analysis["_id"], ObjectId):
            analysis["_id"] = str(analysis["_id"])
        if "timestamp" in analysis and isinstance(analysis["timestamp"], datetime):
            analysis["timestamp"] = analysis["timestamp"].isoformat()
        return JSONResponse(content=jsonable_encoder({
            "status": "success",
            "analysis": analysis,
            "patient_id": patient_id,
            "file_type": content_type,
            "file_size": len(file_content)
        }))
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error in report analysis: {str(e)}")
        raise HTTPException(
            status_code=500,
            detail=f"Failed to analyze report: {str(e)}"
        )
@router.delete("/patients/{patient_id}")
async def delete_patient(
patient_id: str,
current_user: dict = Depends(get_current_user)
):
logger.info(f"Patient deletion initiated by {current_user['email']} for patient {patient_id}")
try:
# Check if the patient exists
patient = await patients_collection.find_one({"fhir_id": patient_id})
if not patient:
raise HTTPException(status_code=404, detail="Patient not found")
# Check if the current user is authorized (e.g., created_by matches or is admin)
if patient.get("created_by") != current_user["email"] and not current_user.get("is_admin", False):
raise HTTPException(status_code=403, detail="Not authorized to delete this patient")
# Delete all analyses and chats associated with this patient
await analysis_collection.delete_many({"patient_id": patient_id})
logger.info(f"Deleted analyses and chats for patient {patient_id}")
# Delete the patient
await patients_collection.delete_one({"fhir_id": patient_id})
logger.info(f"Patient {patient_id} deleted successfully")
return {"status": "success", "message": f"Patient {patient_id} and associated analyses/chats deleted"}
except HTTPException:
raise
except Exception as e:
logger.error(f"Error deleting patient {patient_id}: {str(e)}")
raise HTTPException(status_code=500, detail=f"Failed to delete patient: {str(e)}")
return router