# NOTE(review): removed non-source artifacts ("Spaces:" header and two
# duplicated "Runtime error" lines) left over from a web-editor export.
# Standard library
import logging
import mimetypes
import os
import time
import uuid
from types import SimpleNamespace

# Third-party
import markdown2
import torch

# Project-local
from config import Args, config
from connection_manager import ConnectionManager, ServerFullException
from fastapi import FastAPI, HTTPException, Request, WebSocket
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse
from fastapi.staticfiles import StaticFiles
from util import bytes_to_pil, pil_to_frame
from vid2vid import Pipeline

# fix mime error on windows: some Windows registries map .js to text/plain,
# which makes browsers refuse module scripts served by StaticFiles.
mimetypes.add_type("application/javascript", ".js")

# Minimum frame interval in seconds (i.e. a 120 fps cap).
# NOTE(review): not referenced anywhere in this file — presumably intended
# for the streaming loop; confirm before removing.
THROTTLE = 1.0 / 120

# Debug logging toggle (left disabled on purpose).
# logging.basicConfig(level=logging.DEBUG)
class App:
    """FastAPI wrapper around the vid2vid diffusion ``Pipeline``.

    Wires up CORS, a per-user websocket that receives generation parameters
    (and, in image mode, input frames), a multipart-JPEG stream of generated
    frames, and a settings endpoint the frontend uses to build its UI.
    """

    def __init__(self, config: Args):
        # Run on GPU when available; the pipeline is loaded in fp16.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        torch_dtype = torch.float16
        pipeline = Pipeline(config, device, torch_dtype)
        self.args = config
        self.pipeline = pipeline
        self.app = FastAPI()
        self.conn_manager = ConnectionManager()
        self.init_app()

    def init_app(self):
        # Wide-open CORS: the frontend may be served from another origin.
        self.app.add_middleware(
            CORSMiddleware,
            allow_origins=["*"],
            allow_credentials=True,
            allow_methods=["*"],
            allow_headers=["*"],
        )

        # NOTE(review): in the reviewed source these handlers were defined but
        # never registered on the app, so every endpoint was dead code.  The
        # decorators below restore registration; the paths follow the usual
        # "/api/..." convention — confirm against the frontend client.
        @self.app.websocket("/api/ws/{user_id}")
        async def websocket_endpoint(user_id: uuid.UUID, websocket: WebSocket):
            """Accept one client connection and pump its data until it drops."""
            try:
                await self.conn_manager.connect(
                    user_id, websocket, self.args.max_queue_size
                )
                await handle_websocket_data(user_id)
            except ServerFullException as e:
                logging.error(f"Server Full: {e}")
            finally:
                # Always release the slot, even on error/timeout.
                await self.conn_manager.disconnect(user_id)
                logging.info(f"User disconnected: {user_id}")

        async def handle_websocket_data(user_id: uuid.UUID):
            """Receive params (and frames, in image mode) for one user."""
            if not self.conn_manager.check_user(user_id):
                # NOTE(review): returned (not raised) in the original; inside a
                # websocket handler the value is simply discarded by the caller.
                return HTTPException(status_code=404, detail="User not found")
            last_time = time.time()
            try:
                while True:
                    # Enforce the per-session timeout when one is configured.
                    if (
                        self.args.timeout > 0
                        and time.time() - last_time > self.args.timeout
                    ):
                        await self.conn_manager.send_json(
                            user_id,
                            {
                                "status": "timeout",
                                "message": "Your session has ended",
                            },
                        )
                        await self.conn_manager.disconnect(user_id)
                        return
                    data = await self.conn_manager.receive_json(user_id)
                    if data["status"] == "next_frame":
                        info = self.pipeline.Info()
                        params = await self.conn_manager.receive_json(user_id)
                        params = self.pipeline.InputParams(**params)
                        params = SimpleNamespace(**params.model_dump())
                        if info.input_mode == "image":
                            image_data = await self.conn_manager.receive_bytes(user_id)
                            if len(image_data) == 0:
                                # Empty payload: ask the client to resend.
                                await self.conn_manager.send_json(
                                    user_id, {"status": "send_frame"}
                                )
                                continue
                            params.image = bytes_to_pil(image_data)
                        await self.conn_manager.update_data(user_id, params)
            except Exception as e:
                logging.error(f"Websocket Error: {e}, {user_id} ")
                await self.conn_manager.disconnect(user_id)

        @self.app.get("/api/queue")
        async def get_queue_size():
            """Report how many users are currently connected/queued."""
            queue_size = self.conn_manager.get_user_count()
            return JSONResponse({"queue_size": queue_size})

        @self.app.get("/api/stream/{user_id}")
        async def stream(user_id: uuid.UUID, request: Request):
            """Stream generated frames as a multipart/x-mixed-replace body."""
            try:

                async def generate():
                    while True:
                        last_time = time.time()
                        # Tell the client we are ready for the next frame.
                        await self.conn_manager.send_json(
                            user_id, {"status": "send_frame"}
                        )
                        params = await self.conn_manager.get_latest_data(user_id)
                        if params is None:
                            # NOTE(review): busy-loops until data arrives;
                            # module-level THROTTLE is never applied here —
                            # consider an asyncio.sleep(THROTTLE) backoff.
                            continue
                        image = self.pipeline.predict(params)
                        if image is None:
                            continue
                        frame = pil_to_frame(image)
                        yield frame
                        if self.args.debug:
                            print(f"Time taken: {time.time() - last_time}")

                return StreamingResponse(
                    generate(),
                    media_type="multipart/x-mixed-replace;boundary=frame",
                    headers={"Cache-Control": "no-cache"},
                )
            except Exception as e:
                logging.error(f"Streaming Error: {e}, {user_id} ")
                return HTTPException(status_code=404, detail="User not found")

        # route to setup frontend
        @self.app.get("/api/settings")
        async def settings():
            """Expose pipeline schemas and rendered page content to the UI."""
            info_schema = self.pipeline.Info.model_json_schema()
            info = self.pipeline.Info()
            if info.page_content:
                page_content = markdown2.markdown(info.page_content)
            input_params = self.pipeline.InputParams.model_json_schema()
            return JSONResponse(
                {
                    "info": info_schema,
                    "input_params": input_params,
                    "max_queue_size": self.args.max_queue_size,
                    # page_content is only bound when info.page_content is
                    # truthy; the ternary guard keeps this safe.
                    "page_content": page_content if info.page_content else "",
                }
            )

        # NOTE(review): creates ./public but mounts ./frontend/public — looks
        # inconsistent; confirm which directory the frontend build emits to.
        if not os.path.exists("public"):
            os.makedirs("public")

        self.app.mount(
            "/", StaticFiles(directory="./frontend/public", html=True), name="public"
        )
# Module-level ASGI entry point: instantiating App loads the pipeline and
# wires middleware, routes and the static frontend; servers (e.g. uvicorn)
# import `app` from this module.
app = App(config).app