# TxAgent-Api / app.py
# File-viewer header: last commit 0ccca39 ("Update app.py", verified) by Ali2206; file size 4.07 kB.
import uvicorn
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from config import setup_app, agent, logger, patients_collection, analysis_collection, users_collection, chats_collection, notifications_collection
from endpoints import create_router
from fastapi import WebSocket, WebSocketDisconnect
# The ASGI application object; served by uvicorn in the __main__ guard below.
app = FastAPI(
    title="TxAgent API",
    version="2.6.0",
)

# CORS: accept cross-origin requests from any origin, method and header.
# NOTE(review): allow_origins=["*"] together with allow_credentials=True is
# risky. Browsers refuse a literal "*" origin on credentialed requests, and
# recent Starlette versions echo the caller's Origin instead, which lets any
# site make credentialed calls. Verify for the pinned Starlette version, and
# consider listing the real frontend origins explicitly.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)
@app.websocket("/ws/notifications")
async def websocket_endpoint(websocket: WebSocket):
    """Hold a notifications WebSocket open until the client disconnects.

    Incoming frames are read only so that the server notices when the client
    goes away; their payloads are discarded. Text and binary frames are both
    tolerated. ``receive_text()`` would raise on a binary frame, because it
    indexes ``message["text"]``.

    Args:
        websocket: The incoming WebSocket connection.
    """
    await websocket.accept()
    try:
        while True:
            # Raw receive: works for both text and binary frames, and reports
            # the disconnect as a message instead of an exception.
            message = await websocket.receive()
            if message["type"] == "websocket.disconnect":
                break
    except WebSocketDisconnect:
        # Defensive: kept from the original handler in case the disconnect
        # surfaces as an exception rather than as a message.
        pass
    logger.info("Client disconnected")
# Hand the app to config.setup_app for project-level initialization.
setup_app(app)

# Build the TxAgent router from the shared agent, logger and collections, then
# mount every route under the /txagent prefix.
router = create_router(
    agent,
    logger,
    patients_collection,
    analysis_collection,
    users_collection,
    chats_collection,
    notifications_collection,
)
app.include_router(router, prefix="/txagent", tags=["txagent"])
# Also include some endpoints at root level for frontend compatibility
from endpoints import ChatRequest, VoiceOutputRequest
from fastapi import Depends, HTTPException, UploadFile, File, Form
from typing import Optional
from auth import get_current_user
@app.post("/chat-stream")
async def chat_stream_root(
    request: ChatRequest,
    current_user: dict = Depends(get_current_user),
):
    """Root-level alias of ``/txagent/chat-stream`` for frontend compatibility.

    Looks up the ``/chat-stream`` route on the module-level ``router`` (the one
    mounted under ``/txagent``) and calls its handler directly.

    Args:
        request: The chat request body.
        current_user: User resolved by ``get_current_user``.

    Returns:
        Whatever the ``/chat-stream`` handler returns.

    Raises:
        HTTPException: 404 if ``router`` has no ``/chat-stream`` route.
    """
    # Reuse the router built once at import time. The previous version called
    # create_router() on every request, rebuilding every route just to find
    # one handler.
    for route in router.routes:
        if getattr(route, "path", None) == "/chat-stream":
            # Direct call: FastAPI's Depends() is not re-run, so the
            # already-resolved user is passed through positionally.
            # Assumes the handler's signature is (request, current_user);
            # confirm against endpoints.py.
            return await route.endpoint(request, current_user)
    raise HTTPException(status_code=404, detail="Chat stream endpoint not found")
@app.post("/voice/synthesize")
async def voice_synthesize_root(
    request: dict,
    current_user: dict = Depends(get_current_user),
):
    """Root-level alias of ``/txagent/voice/synthesize`` for frontend compatibility.

    Builds a ``VoiceOutputRequest`` from the raw JSON body, filling in
    defaults for missing keys (``text=""``, ``language="en-US"``,
    ``slow=False``, ``return_format="mp3"``). It then calls the
    ``/voice/synthesize`` handler on the module-level ``router``.

    Args:
        request: Raw JSON object body.
        current_user: User resolved by ``get_current_user``.

    Returns:
        Whatever the ``/voice/synthesize`` handler returns.

    Raises:
        HTTPException: 422 if the body cannot be turned into a
            ``VoiceOutputRequest``; 404 if ``router`` has no
            ``/voice/synthesize`` route.
    """
    try:
        voice_request = VoiceOutputRequest(
            text=request.get("text", ""),
            language=request.get("language", "en-US"),
            slow=request.get("slow", False),
            return_format=request.get("return_format", "mp3"),
        )
    except ValueError as exc:
        # Assumes VoiceOutputRequest is a pydantic model, whose
        # ValidationError subclasses ValueError. Report bad client input as a
        # 422 instead of an unhandled 500.
        raise HTTPException(status_code=422, detail=str(exc)) from exc

    # Reuse the router built once at import time instead of rebuilding every
    # route via create_router() on each request.
    for route in router.routes:
        if getattr(route, "path", None) == "/voice/synthesize":
            # Direct call: Depends() is bypassed and the resolved user is
            # passed positionally.
            return await route.endpoint(voice_request, current_user)
    raise HTTPException(status_code=404, detail="Voice synthesis endpoint not found")
@app.post("/analyze-report")
async def analyze_report_root(
    file: UploadFile = File(...),
    patient_id: Optional[str] = Form(None),
    temperature: float = Form(0.5),
    max_new_tokens: int = Form(1024),
    current_user: dict = Depends(get_current_user),
):
    """Root-level alias of ``/txagent/analyze-report`` for frontend compatibility.

    Forwards the uploaded file and form fields to the ``/analyze-report``
    handler registered on the module-level ``router``.

    Args:
        file: The uploaded report.
        patient_id: Optional patient identifier form field.
        temperature: Form field forwarded unchanged (default 0.5).
        max_new_tokens: Form field forwarded unchanged (default 1024).
        current_user: User resolved by ``get_current_user``.

    Returns:
        Whatever the ``/analyze-report`` handler returns.

    Raises:
        HTTPException: 404 if ``router`` has no ``/analyze-report`` route.
    """
    # Reuse the router built once at import time instead of rebuilding every
    # route via create_router() on each request.
    for route in router.routes:
        if getattr(route, "path", None) == "/analyze-report":
            # Direct positional call bypasses FastAPI's Form()/Depends()
            # resolution. Assumes the handler's parameter order matches;
            # confirm against endpoints.py.
            return await route.endpoint(
                file, patient_id, temperature, max_new_tokens, current_user
            )
    raise HTTPException(status_code=404, detail="Analyze report endpoint not found")
if __name__ == "__main__":
    # Direct-execution entry point: serve the app with uvicorn on all
    # interfaces, port 8000.
    uvicorn.run(app=app, port=8000, host="0.0.0.0")