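# NOTE: residue from the Hugging Face file viewer, not part of the module:
# author avatar "radames"; commit 2ab3299, "simplify type annotations and
# remove unused TypeVars"; viewer links "raw" / "history" / "blame";
# file size 3.95 kB.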
from uuid import UUID
import asyncio
from fastapi import WebSocket
from starlette.websockets import WebSocketState
import logging
from typing import Any
from util import ParamsModel
# Registry type: maps a user's UUID to that user's session record, a dict with
# string keys whose values are either a WebSocket or an asyncio.Queue.
Connections = dict[UUID, dict[str, WebSocket | asyncio.Queue]]
class ServerFullException(Exception):
    """Exception raised when the server is full.

    Signals that a new connection was rejected because the number of
    active users has reached the configured limit.
    """
class ConnectionManager:
    """Registry of active WebSocket sessions keyed by user UUID.

    Each session is stored in ``active_connections`` as a dict with two
    entries: ``"websocket"`` (the accepted WebSocket) and ``"queue"`` (an
    unbounded ``asyncio.Queue`` of parameter updates for that user).
    """

    def __init__(self):
        self.active_connections: Connections = {}

    async def connect(
        self, user_id: UUID, websocket: WebSocket, max_queue_size: int = 0
    ):
        """Accept ``websocket`` and register a new session for ``user_id``.

        Args:
            user_id: Key under which the session is stored.
            websocket: The incoming connection; it is accepted immediately.
            max_queue_size: Maximum number of simultaneous users. A value
                of ``0`` (or less) disables the limit.

        Raises:
            ServerFullException: If the limit is enabled and already
                reached. The client is sent an error message and the
                socket is closed before raising.
        """
        # The handshake is accepted before the capacity check, so a rejected
        # client still receives the JSON error message below.
        await websocket.accept()
        user_count = self.get_user_count()
        print(f"User count: {user_count}")
        if max_queue_size > 0 and user_count >= max_queue_size:
            print("Server is full")
            await websocket.send_json({"status": "error", "message": "Server is full"})
            await websocket.close()
            raise ServerFullException("Server is full")
        print(f"New user connected: {user_id}")
        self.active_connections[user_id] = {
            "websocket": websocket,
            "queue": asyncio.Queue(),
        }
        await websocket.send_json(
            {"status": "connected", "message": "Connected"},
        )
        await websocket.send_json({"status": "wait"})
        await websocket.send_json({"status": "send_frame"})

    def check_user(self, user_id: UUID) -> bool:
        """Return True if ``user_id`` currently has a registered session."""
        return user_id in self.active_connections

    async def update_data(self, user_id: UUID, new_data: ParamsModel):
        """Enqueue ``new_data`` for ``user_id``; do nothing for unknown users."""
        user_session = self.active_connections.get(user_id)
        if user_session:
            await user_session["queue"].put(new_data)

    async def get_latest_data(self, user_id: UUID) -> ParamsModel | None:
        """Wait for and return the next queued item for ``user_id``.

        Returns ``None`` immediately when the user has no session.
        Otherwise this awaits ``Queue.get()``, which suspends until an item
        is available and hands items back in FIFO order.
        """
        user_session = self.active_connections.get(user_id)
        if not user_session:
            return None
        # ``await Queue.get()`` never raises QueueEmpty (only ``get_nowait``
        # does), so no exception handler is needed around it.
        return await user_session["queue"].get()

    def delete_user(self, user_id: UUID):
        """Remove the session for ``user_id`` and discard its queued items.

        Safe to call for a user that is not registered.
        """
        user_session = self.active_connections.pop(user_id, None)
        if not user_session:
            return
        queue = user_session["queue"]
        # Drain whatever is still pending; QueueEmpty marks the end.
        try:
            while True:
                queue.get_nowait()
        except asyncio.QueueEmpty:
            pass

    def get_user_count(self) -> int:
        """Return the number of registered sessions."""
        return len(self.active_connections)

    def get_websocket(self, user_id: UUID) -> WebSocket | None:
        """Return the user's WebSocket if it is still connected, else None.

        ``None`` is returned both when the user has no session and when the
        stored socket's ``client_state`` is not ``CONNECTED``.
        """
        user_session = self.active_connections.get(user_id)
        if not user_session:
            return None
        websocket = user_session["websocket"]
        if websocket.client_state == WebSocketState.CONNECTED:
            return websocket
        return None

    async def disconnect(self, user_id: UUID):
        """Close the user's WebSocket (if connected) and drop the session.

        The session is removed even if closing the socket raises, so a
        failed close cannot leave a stale entry behind; the exception
        itself still propagates to the caller.
        """
        try:
            websocket = self.get_websocket(user_id)
            if websocket:
                await websocket.close()
        finally:
            self.delete_user(user_id)

    async def send_json(self, user_id: UUID, data: dict):
        """Send ``data`` as JSON to the user, logging any failure.

        Best effort: errors are logged and swallowed; nothing is sent when
        the user has no connected socket.
        """
        try:
            websocket = self.get_websocket(user_id)
            if websocket:
                await websocket.send_json(data)
        except Exception as e:
            logging.error("Error: Send json: %s", e)

    async def receive_json(self, user_id: UUID) -> dict | None:
        """Receive one JSON message from the user.

        Returns ``None`` when the user has no connected socket or when
        receiving fails (the error is logged).
        """
        try:
            websocket = self.get_websocket(user_id)
            if websocket:
                return await websocket.receive_json()
            return None
        except Exception as e:
            logging.error("Error: Receive json: %s", e)
            return None

    async def receive_bytes(self, user_id: UUID) -> bytes | None:
        """Receive one binary message from the user.

        Returns ``None`` when the user has no connected socket or when
        receiving fails (the error is logged).
        """
        try:
            websocket = self.get_websocket(user_id)
            if websocket:
                return await websocket.receive_bytes()
            return None
        except Exception as e:
            logging.error("Error: Receive bytes: %s", e)
            return None