Spaces:
Runtime error
Runtime error
| import uvicorn | |
| from fastapi import FastAPI | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from config import setup_app, agent, logger, patients_collection, analysis_collection, users_collection, chats_collection, notifications_collection | |
| from endpoints import create_router | |
| from fastapi import WebSocket, WebSocketDisconnect | |
import os

# Create the FastAPI app
app = FastAPI(title="TxAgent API", version="2.6.0")

# Origins accepted by the CORS middleware. The default "*" (any origin) is
# what this app always used. Set CORS_ALLOW_ORIGINS to a comma-separated
# list (e.g. "https://a.example,https://b.example") to restrict it without
# a code change. An empty/blank value falls back to "*".
# NOTE(review): Starlette's CORS docs say wildcard settings do not combine
# cleanly with allow_credentials=True (browsers reject a literal "*" origin
# on credentialed responses). Use explicit origins if cookies/credentials
# are required — confirm against the pinned Starlette version.
_cors_origins = [
    origin.strip()
    for origin in os.environ.get("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]

# Apply CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
async def websocket_endpoint(websocket: WebSocket):
    """Accept a WebSocket client and hold the connection open until it leaves.

    Frames sent by the client are read and discarded; the loop exists only
    to keep the socket alive. Both text and binary frames are tolerated.

    Args:
        websocket: The incoming WebSocket connection.
    """
    # NOTE(review): no ``@app.websocket(...)`` decorator is visible for this
    # function in this file, so it may never be registered — confirm.
    await websocket.accept()
    try:
        while True:
            # receive() returns the raw ASGI message. receive_text() would
            # index message["text"], which binary frames do not carry, so a
            # binary frame would crash the keep-alive loop.
            message = await websocket.receive()
            if message["type"] == "websocket.disconnect":
                break
    except WebSocketDisconnect:
        # Kept defensively in case a disconnect surfaces as an exception.
        pass
    logger.info("Client disconnected")
# Let config.setup_app configure the app before any routes are mounted.
setup_app(app)

# Build the TxAgent router once from the shared agent, logger and database
# collections imported from config, then mount it under /txagent.
# NOTE(review): these names were bound by ``from config import ...`` when
# this module was imported. If config only assigns them inside a startup
# hook registered by setup_app, the values passed here are the import-time
# ones — confirm.
router = create_router(
    agent,
    logger,
    patients_collection,
    analysis_collection,
    users_collection,
    chats_collection,
    notifications_collection,
)
app.include_router(router, tags=["txagent"], prefix="/txagent")
| # Also include some endpoints at root level for frontend compatibility | |
| from endpoints import ChatRequest, VoiceOutputRequest | |
| from fastapi import Depends, HTTPException, UploadFile, File, Form | |
| from typing import Optional | |
| from auth import get_current_user | |
async def chat_stream_root(
    request: ChatRequest,
    current_user: dict = Depends(get_current_user)
):
    """Chat stream endpoint at root level for frontend compatibility.

    Delegates to the ``/chat-stream`` handler of the module-level ``router``
    (the router already mounted under ``/txagent``), so both paths share one
    implementation.

    Args:
        request: Parsed chat request body, forwarded unchanged.
        current_user: User resolved by ``get_current_user``.

    Returns:
        Whatever the ``/chat-stream`` handler returns.

    Raises:
        HTTPException: 404 if ``router`` has no ``/chat-stream`` route.
    """
    # NOTE(review): no route decorator is visible on this function in this
    # file, so it may never be registered on ``app`` — confirm it is wired
    # up (e.g. ``@app.post("/chat-stream")``) or remove it.
    # Look the handler up on the router built once at import time instead of
    # calling create_router (and rebuilding every route) on each request.
    endpoint = next(
        (
            route.endpoint
            for route in router.routes
            if getattr(route, "path", None) == "/chat-stream"
        ),
        None,
    )
    if endpoint is None:
        raise HTTPException(status_code=404, detail="Chat stream endpoint not found")
    # Positional call: relies on the handler's parameters being ordered
    # (request, current_user).
    return await endpoint(request, current_user)
async def voice_synthesize_root(
    request: dict,
    current_user: dict = Depends(get_current_user)
):
    """Voice synthesis endpoint at root level for frontend compatibility.

    Converts the raw JSON body into a ``VoiceOutputRequest`` (missing fields
    default to text ``''``, language ``'en-US'``, slow ``False``, format
    ``'mp3'``). It then delegates to the ``/voice/synthesize`` handler of the
    module-level ``router``.

    Args:
        request: Raw JSON body as a dict.
        current_user: User resolved by ``get_current_user``.

    Returns:
        Whatever the ``/voice/synthesize`` handler returns.

    Raises:
        HTTPException: 422 if the body fields are rejected when building
            ``VoiceOutputRequest``; 404 if ``router`` has no
            ``/voice/synthesize`` route.
    """
    # NOTE(review): no route decorator is visible on this function in this
    # file, so it may never be registered on ``app`` — confirm.
    try:
        voice_request = VoiceOutputRequest(
            text=request.get('text', ''),
            language=request.get('language', 'en-US'),
            slow=request.get('slow', False),
            return_format=request.get('return_format', 'mp3')
        )
    except ValueError as exc:
        # Assumes VoiceOutputRequest is a pydantic model, whose
        # ValidationError subclasses ValueError. Report bad input as a
        # client error rather than an unhandled 500.
        raise HTTPException(status_code=422, detail=str(exc)) from exc

    # Reuse the router built once at import time instead of calling
    # create_router on every request.
    endpoint = next(
        (
            route.endpoint
            for route in router.routes
            if getattr(route, "path", None) == "/voice/synthesize"
        ),
        None,
    )
    if endpoint is None:
        raise HTTPException(status_code=404, detail="Voice synthesis endpoint not found")
    # Positional call: relies on the handler's parameter order.
    return await endpoint(voice_request, current_user)
async def analyze_report_root(
    file: UploadFile = File(...),
    patient_id: Optional[str] = Form(None),
    temperature: float = Form(0.5),
    max_new_tokens: int = Form(1024),
    current_user: dict = Depends(get_current_user)
):
    """Report analysis endpoint at root level for frontend compatibility.

    Forwards the uploaded file and form fields to the ``/analyze-report``
    handler of the module-level ``router``.

    Args:
        file: Uploaded report file (required multipart field).
        patient_id: Optional patient identifier form field.
        temperature: Form field forwarded to the handler (default 0.5).
        max_new_tokens: Form field forwarded to the handler (default 1024).
        current_user: User resolved by ``get_current_user``.

    Returns:
        Whatever the ``/analyze-report`` handler returns.

    Raises:
        HTTPException: 404 if ``router`` has no ``/analyze-report`` route.
    """
    # NOTE(review): no route decorator is visible on this function in this
    # file, so it may never be registered on ``app`` — confirm.
    # Reuse the router built once at import time instead of calling
    # create_router on every request.
    endpoint = next(
        (
            route.endpoint
            for route in router.routes
            if getattr(route, "path", None) == "/analyze-report"
        ),
        None,
    )
    if endpoint is None:
        raise HTTPException(status_code=404, detail="Analyze report endpoint not found")
    # Positional call: relies on the handler's parameters being ordered
    # (file, patient_id, temperature, max_new_tokens, current_user).
    return await endpoint(file, patient_id, temperature, max_new_tokens, current_user)
if __name__ == "__main__":
    import os

    # Serve the app directly when run as a script. HOST and PORT environment
    # variables override the bind address; the defaults match the previous
    # hard-coded 0.0.0.0:8000.
    uvicorn.run(
        app,
        host=os.environ.get("HOST", "0.0.0.0"),
        port=int(os.environ.get("PORT", "8000")),
    )