Spaces:
Build error
Build error
| import enum | |
| import json | |
| from datetime import datetime | |
| from typing import Dict, List, Optional | |
| from fastapi import WebSocket, websockets | |
| from fastapi.encoders import jsonable_encoder | |
| from pydantic import BaseModel | |
| from core.db.redis_session import redis_chat_client | |
| from core.security import get_uid_hash | |
| from models import User | |
class ChatMessageTypes(enum.Enum):
    """Numeric tags identifying the kind of payload carried by ``Message.msg_type``."""

    MESSAGE_HISTORY = 1
    PUBLIC_MESSAGE = 2
    ANON_MESSAGE = 3
    USER_JOINED = 4
    USER_LEFT = 5
    ACTIVE_USER_LIST = 6
class Message(BaseModel):
    """Envelope for every payload sent over the chat WebSocket.

    ``msg_type`` holds a ``ChatMessageTypes`` value. ``data`` and ``user``
    default to ``None`` so they can be omitted at construction and dropped
    from the output with ``.dict(exclude_none=True)``.
    """

    msg_type: int
    # Chat text, or a raw JSON string for history / active-user payloads.
    # Explicit default: under pydantic v2 an Optional field with no default
    # is required, which would break calls that omit it.
    data: Optional[str] = None
    # Sender id (or its hash for anonymous messages); omitted for payloads
    # that have no sender.
    user: Optional[str] = None
    time: datetime
class WebSocketManager:
    """Manage chat WebSocket connections grouped by class session.

    Open connections are held in memory in ``self.connections``. Chat history
    and the per-session list of active user ids are stored in Redis through
    ``redis_chat_client``.
    """

    # TTL passed to the Redis client (``expire=``) for history and
    # active-status keys.
    # NOTE(review): 60 * 60 * 1000 seconds is about 41.7 days. If one hour was
    # intended this looks like a milliseconds/seconds mix-up — confirm before
    # changing, since it alters how long chat data is retained.
    _KEY_TTL = 60 * 60 * 1000

    def __init__(self):
        # class_session_id -> list of accepted WebSocket connections.
        self.connections: Dict = {}

    @staticmethod
    def _active_status_key(class_session_id: int) -> str:
        """Return the Redis key that stores a session's active user ids."""
        return f"active_status_{class_session_id}"

    async def _save_active_status(self, class_session_id: int, user_ids: list) -> None:
        """Store a session's active user ids as compact JSON with the shared TTL."""
        await redis_chat_client.client.set(
            self._active_status_key(class_session_id),
            json.dumps(user_ids, separators=(",", ":")),
            expire=self._KEY_TTL,
        )

    async def update(self, data, key):
        """Append ``data`` to the JSON list stored at Redis ``key``.

        A missing key starts as an empty list. The TTL is reset on every write.

        NOTE(review): this is a read-modify-write with an ``await`` between the
        read and the write and no transaction, so two concurrent appends to the
        same key can lose one of the messages.
        """
        stored = await redis_chat_client.client.get(key)
        history = json.loads(stored) if stored else []
        history.append(data)
        await redis_chat_client.client.set(
            key, json.dumps(history, separators=(",", ":")), expire=self._KEY_TTL
        )

    async def send_history(self, websocket: WebSocket, class_session_id: int):
        """Send a class session's stored chat history to a single client.

        The history is forwarded as the raw JSON string held in Redis; ``data``
        is omitted from the payload when no history exists yet.
        """
        chat_history = await redis_chat_client.client.get(
            f"chat_class_sess_{class_session_id}", encoding="UTF-8"
        )
        msg_history_instance = Message(
            msg_type=ChatMessageTypes.MESSAGE_HISTORY.value,
            data=chat_history,
            time=datetime.utcnow(),
        )
        await websocket.send_json(
            jsonable_encoder(msg_history_instance.dict(exclude_none=True))
        )

    async def connect(self, websocket: WebSocket, user_id: int, class_session_id: int):
        """Accept ``websocket``, announce the join and mark ``user_id`` active.

        The method runs these steps in order:

        1. Accept the socket and register it under ``class_session_id``.
        2. Broadcast USER_JOINED to the session; this is not saved to history.
        3. Send the joining client the active-user list as stored *before*
           this user is added.
        4. Add ``user_id`` to the stored active list.
        """
        await websocket.accept()
        self.connections.setdefault(class_session_id, []).append(websocket)

        joined = Message(
            msg_type=ChatMessageTypes.USER_JOINED.value,
            time=datetime.utcnow(),
            user=user_id,
        )
        await self.broadcast(
            joined.dict(exclude_none=True), user_id, class_session_id, save=False
        )

        pre_status = await redis_chat_client.client.get(
            self._active_status_key(class_session_id), encoding="UTF-8"
        )
        active_user_instance = Message(
            msg_type=ChatMessageTypes.ACTIVE_USER_LIST.value,
            data=pre_status,
            time=datetime.utcnow(),
        )
        await websocket.send_json(
            jsonable_encoder(active_user_instance.dict(exclude_none=True))
        )

        active_users = json.loads(pre_status) if pre_status else []
        active_users.append(user_id)
        # De-duplicate (the same user may connect again) while keeping order.
        await self._save_active_status(
            class_session_id, list(dict.fromkeys(active_users))
        )

    async def disconnect(
        self, websocket: WebSocket, user_id: int, class_session_id: int
    ):
        """Unregister ``websocket``, announce the departure and mark the user inactive.

        It is safe to call this in any of these cases:

        - the socket was never registered or was already removed;
        - no active-status list is stored;
        - the user is not in the stored list.

        NOTE(review): ``user_id`` is removed from the active list even if the
        same user still has another open connection in this session — confirm
        whether that is intended.
        """
        session_connections = self.connections.get(class_session_id)
        if session_connections is not None and websocket in session_connections:
            session_connections.remove(websocket)

        left = Message(
            msg_type=ChatMessageTypes.USER_LEFT.value,
            time=datetime.utcnow(),
            user=user_id,
        )
        # Serialize like the USER_JOINED broadcast in ``connect`` (unset fields
        # are omitted rather than sent as null).
        await self.broadcast(
            left.dict(exclude_none=True), user_id, class_session_id, save=False
        )

        stored = await redis_chat_client.client.get(
            self._active_status_key(class_session_id)
        )
        active_users = json.loads(stored) if stored else []
        if user_id in active_users:
            active_users.remove(user_id)
        await self._save_active_status(class_session_id, active_users)

    async def broadcast(
        self, data: object, user_id: int, class_session_id: int, save: bool = True
    ):
        """Send ``data`` to every connection registered for a class session.

        Args:
            data: A value ``jsonable_encoder`` can encode (e.g. a dict).
            user_id: Sender id. Not used by this method; kept for callers.
            class_session_id: Session whose connections receive the message.
            save: When true, also append the encoded message to the session's
                Redis chat history.

        Delivery is best-effort: a socket that fails to send is logged and
        skipped so the remaining connections still receive the message.
        """
        import logging

        encoded_data = jsonable_encoder(data)
        # Iterate over a snapshot: ``send_json`` yields to the event loop, and a
        # concurrent connect/disconnect mutating the live list could otherwise
        # make this loop skip a connection. Unknown sessions have no sockets.
        for connection in list(self.connections.get(class_session_id, [])):
            try:
                await connection.send_json(encoded_data)
            except Exception:
                logging.getLogger(__name__).warning(
                    "Failed to deliver chat message in class session %s",
                    class_session_id,
                    exc_info=True,
                )
        if save:
            await self.update(encoded_data, f"chat_class_sess_{class_session_id}")

    async def message(
        self,
        websocket: WebSocket,
        message: str,
        user_id: int,
        class_session_id: int,
        anon: bool = False,
    ):
        """Broadcast a chat message from ``user_id`` and save it to history.

        With ``anon=True`` the message is tagged ANON_MESSAGE and the sender is
        replaced by ``get_uid_hash(user_id)``. ``websocket`` is not used here;
        it is kept for callers.
        """
        if anon:
            msg_type = ChatMessageTypes.ANON_MESSAGE.value
            user = get_uid_hash(user_id)
        else:
            msg_type = ChatMessageTypes.PUBLIC_MESSAGE.value
            user = user_id
        msg_instance = Message(
            msg_type=msg_type,
            data=message,
            user=user,
            time=datetime.utcnow(),
        )
        await self.broadcast(
            msg_instance.dict(exclude_none=True), user_id, class_session_id
        )
# Shared module-level connection manager. Its connection registry is an
# in-memory dict, so each process holding this module has its own registry
# (presumably one per server worker — confirm for multi-worker deployments).
ws = WebSocketManager()