# Spaces: Runtime error
import uvicorn | |
from fastapi import FastAPI | |
from fastapi.middleware.cors import CORSMiddleware | |
from config import setup_app, agent, logger, patients_collection, analysis_collection, users_collection, chats_collection, notifications_collection | |
from endpoints import create_router | |
from fastapi import WebSocket, WebSocketDisconnect | |
# Application object that every route, router and middleware below attaches to.
app = FastAPI(title="TxAgent API", version="2.6.0")

# CORS policy: every origin, method and header is accepted, with credentials.
# NOTE(review): wildcard origins together with allow_credentials=True lets any
# site make credentialed cross-origin requests; confirm this is intended or
# replace "*" with the explicit list of frontend origins.
_cors_options = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_options)
# NOTE(review): no route decorator is visible on this handler in this view;
# confirm it is registered on `app` (e.g. via `@app.websocket(...)`).
async def websocket_endpoint(websocket: WebSocket) -> None:
    """Accept a WebSocket connection and hold it open until the client leaves.

    Incoming text frames are read and discarded; reading is what keeps the
    connection serviced. When the client disconnects, the event is logged and
    the handler returns.

    Args:
        websocket: The incoming WebSocket connection.
    """
    await websocket.accept()
    client_connected = True
    while client_connected:
        try:
            # The payload is intentionally ignored; the read only keeps the
            # connection alive and surfaces the disconnect.
            await websocket.receive_text()
        except WebSocketDisconnect:
            logger.info("Client disconnected")
            client_connected = False
# Hand the app to the config module for its setup step
# (e.g., initialize globals, startup event).
setup_app(app)

# Shared objects handed to the router factory, in the positional order
# `create_router` is called with.
_router_dependencies = (
    agent,
    logger,
    patients_collection,
    analysis_collection,
    users_collection,
    chats_collection,
    notifications_collection,
)

# Build the API router once and mount it under the /txagent prefix.
router = create_router(*_router_dependencies)
app.include_router(router, prefix="/txagent", tags=["txagent"])
# Also include some endpoints at root level for frontend compatibility | |
from endpoints import ChatRequest, VoiceOutputRequest | |
from fastapi import Depends, HTTPException, UploadFile, File, Form | |
from typing import Optional | |
from auth import get_current_user | |
# NOTE(review): no route decorator is visible on this handler in this view;
# confirm it is registered on `app` at the root-level path the frontend calls.
async def chat_stream_root(
    request: ChatRequest,
    current_user: dict = Depends(get_current_user),
):
    """Chat stream endpoint at root level for frontend compatibility.

    Delegates to the handler registered at ``/chat-stream`` on the
    module-level ``router``.

    Args:
        request: Chat request body.
        current_user: User resolved by the ``get_current_user`` dependency.

    Returns:
        Whatever the delegated ``/chat-stream`` handler returns.

    Raises:
        HTTPException: 404 when the router has no ``/chat-stream`` route.
    """
    # Reuse the router built once at import time rather than constructing a
    # throwaway router on every request just to locate a single handler.
    handler = next(
        (
            route.endpoint
            for route in router.routes
            if getattr(route, "path", None) == "/chat-stream"
        ),
        None,
    )
    if handler is None:
        raise HTTPException(status_code=404, detail="Chat stream endpoint not found")
    # The handler is invoked directly (outside FastAPI's dependency injection),
    # so the already-resolved arguments are forwarded positionally.
    return await handler(request, current_user)
# NOTE(review): no route decorator is visible on this handler in this view;
# confirm it is registered on `app` at the root-level path the frontend calls.
async def voice_synthesize_root(
    request: dict,
    current_user: dict = Depends(get_current_user),
):
    """Voice synthesis endpoint at root level for frontend compatibility.

    Converts the raw request dict into a ``VoiceOutputRequest`` and delegates
    to the handler registered at ``/voice/synthesize`` on the module-level
    ``router``.

    Args:
        request: Raw request payload; the keys ``text``, ``language``,
            ``slow`` and ``return_format`` are read, each with a default.
        current_user: User resolved by the ``get_current_user`` dependency.

    Returns:
        Whatever the delegated ``/voice/synthesize`` handler returns.

    Raises:
        HTTPException: 404 when the router has no ``/voice/synthesize`` route.
    """
    # Convert dict to VoiceOutputRequest, filling in defaults for missing keys.
    voice_request = VoiceOutputRequest(
        text=request.get('text', ''),
        language=request.get('language', 'en-US'),
        slow=request.get('slow', False),
        return_format=request.get('return_format', 'mp3'),
    )
    # Reuse the router built once at import time rather than constructing a
    # throwaway router on every request just to locate a single handler.
    handler = next(
        (
            route.endpoint
            for route in router.routes
            if getattr(route, "path", None) == "/voice/synthesize"
        ),
        None,
    )
    if handler is None:
        raise HTTPException(status_code=404, detail="Voice synthesis endpoint not found")
    # The handler is invoked directly (outside FastAPI's dependency injection),
    # so the already-resolved arguments are forwarded positionally.
    return await handler(voice_request, current_user)
# NOTE(review): no route decorator is visible on this handler in this view;
# confirm it is registered on `app` at the root-level path the frontend calls.
async def analyze_report_root(
    file: UploadFile = File(...),
    patient_id: Optional[str] = Form(None),
    temperature: float = Form(0.5),
    max_new_tokens: int = Form(1024),
    current_user: dict = Depends(get_current_user),
):
    """Report analysis endpoint at root level for frontend compatibility.

    Delegates to the handler registered at ``/analyze-report`` on the
    module-level ``router``, forwarding the uploaded file and form fields.

    Args:
        file: Uploaded report file (required).
        patient_id: Optional patient identifier form field.
        temperature: Form field forwarded to the handler; defaults to 0.5.
        max_new_tokens: Form field forwarded to the handler; defaults to 1024.
        current_user: User resolved by the ``get_current_user`` dependency.

    Returns:
        Whatever the delegated ``/analyze-report`` handler returns.

    Raises:
        HTTPException: 404 when the router has no ``/analyze-report`` route.
    """
    # Reuse the router built once at import time rather than constructing a
    # throwaway router on every request just to locate a single handler.
    handler = next(
        (
            route.endpoint
            for route in router.routes
            if getattr(route, "path", None) == "/analyze-report"
        ),
        None,
    )
    if handler is None:
        raise HTTPException(status_code=404, detail="Analyze report endpoint not found")
    # The handler is invoked directly (outside FastAPI's dependency injection),
    # so the already-resolved arguments are forwarded positionally.
    return await handler(file, patient_id, temperature, max_new_tokens, current_user)
if __name__ == "__main__":
    # Script entry point: serve the app on all interfaces, port 8000.
    server_options = {"host": "0.0.0.0", "port": 8000}
    uvicorn.run(app, **server_options)