import asyncio
import logging
import mimetypes
import os
import time
import uuid
from types import SimpleNamespace

import markdown2
import torch
from config import Args, config
from connection_manager import ConnectionManager, ServerFullException
from fastapi import FastAPI, HTTPException, Request, WebSocket
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse
from fastapi.staticfiles import StaticFiles
from util import bytes_to_pil, pil_to_frame
from vid2vid import Pipeline

# fix mime error on windows
mimetypes.add_type("application/javascript", ".js")

THROTTLE = 1.0 / 120

# Directory served at "/" (the frontend's static files).
FRONTEND_DIR = "./frontend/public"

# logging.basicConfig(level=logging.DEBUG)


class App:
    """FastAPI application wrapping a vid2vid ``Pipeline``.

    Clients connect over a websocket (``/api/ws/{user_id}``) to push input
    parameters -- plus raw image bytes when the pipeline's input mode is
    ``"image"`` -- and read generated frames from a
    ``multipart/x-mixed-replace`` HTTP stream (``/api/stream/{user_id}``).
    """

    def __init__(self, config: Args):
        """Build the pipeline and the FastAPI app, then register all routes.

        Args:
            config: Server arguments; this class reads ``max_queue_size``,
                ``timeout`` and ``debug`` from it and passes it to ``Pipeline``.
        """
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # NOTE(review): float16 is used even when falling back to CPU; confirm
        # Pipeline supports that, or switch to float32 when CUDA is unavailable.
        torch_dtype = torch.float16
        pipeline = Pipeline(config, device, torch_dtype)

        self.args = config
        self.pipeline = pipeline
        self.app = FastAPI()
        self.conn_manager = ConnectionManager()
        self.init_app()

    def init_app(self):
        """Register CORS middleware, the API routes and the static frontend."""
        self.app.add_middleware(
            CORSMiddleware,
            allow_origins=["*"],
            allow_credentials=True,
            allow_methods=["*"],
            allow_headers=["*"],
        )

        @self.app.websocket("/api/ws/{user_id}")
        async def websocket_endpoint(user_id: uuid.UUID, websocket: WebSocket):
            """Register the client, then process its messages until it ends."""
            try:
                await self.conn_manager.connect(
                    user_id, websocket, self.args.max_queue_size
                )
                await handle_websocket_data(user_id)
            except ServerFullException as e:
                logging.error(f"Server Full: {e}")
            finally:
                await self.conn_manager.disconnect(user_id)
                logging.info(f"User disconnected: {user_id}")

        async def handle_websocket_data(user_id: uuid.UUID):
            """Receive ``next_frame`` messages and store the latest parameters.

            Each ``next_frame`` status message is followed by a JSON message
            with the pipeline input parameters and, in ``"image"`` input mode,
            a binary message with the image. The loop ends on session timeout
            or on any error while receiving/validating.
            """
            if not self.conn_manager.check_user(user_id):
                # A websocket handler has no HTTP response to carry a 404, so
                # just log and stop (returning an HTTPException did nothing).
                logging.warning(f"User not found: {user_id}")
                return

            # Taken once: args.timeout limits the total session length. It is
            # only checked at the top of the loop, i.e. between messages.
            session_start = time.time()
            try:
                while True:
                    if (
                        self.args.timeout > 0
                        and time.time() - session_start > self.args.timeout
                    ):
                        await self.conn_manager.send_json(
                            user_id,
                            {
                                "status": "timeout",
                                "message": "Your session has ended",
                            },
                        )
                        # NOTE(review): websocket_endpoint's `finally` calls
                        # disconnect() again; confirm it is safe to repeat.
                        await self.conn_manager.disconnect(user_id)
                        return
                    data = await self.conn_manager.receive_json(user_id)
                    if data["status"] == "next_frame":
                        info = self.pipeline.Info()
                        params = await self.conn_manager.receive_json(user_id)
                        params = self.pipeline.InputParams(**params)
                        params = SimpleNamespace(**params.model_dump())
                        if info.input_mode == "image":
                            image_data = await self.conn_manager.receive_bytes(
                                user_id
                            )
                            if len(image_data) == 0:
                                # Empty frame: ask the client to send again.
                                await self.conn_manager.send_json(
                                    user_id, {"status": "send_frame"}
                                )
                                continue
                            params.image = bytes_to_pil(image_data)
                        await self.conn_manager.update_data(user_id, params)
            except Exception as e:
                logging.exception(f"Websocket Error: {e}, {user_id} ")
                await self.conn_manager.disconnect(user_id)

        @self.app.get("/api/queue")
        async def get_queue_size():
            """Return the number of connected users."""
            queue_size = self.conn_manager.get_user_count()
            return JSONResponse({"queue_size": queue_size})

        @self.app.get("/api/stream/{user_id}")
        async def stream(user_id: uuid.UUID, request: Request):
            """Stream generated frames for ``user_id`` as multipart chunks.

            Raises:
                HTTPException: 404 if ``user_id`` has no active connection.
            """
            if not self.conn_manager.check_user(user_id):
                raise HTTPException(status_code=404, detail="User not found")

            async def generate():
                # Errors raised here happen while the response is being sent,
                # long after stream() returned, so they must be handled inside
                # the generator to be logged at all.
                try:
                    while True:
                        if not self.conn_manager.check_user(user_id):
                            # The websocket side is gone; end the stream
                            # instead of polling for data that never comes.
                            break
                        started = time.perf_counter()
                        await self.conn_manager.send_json(
                            user_id, {"status": "send_frame"}
                        )
                        params = await self.conn_manager.get_latest_data(user_id)
                        if params is None:
                            # Guarantee a suspension point so an empty poll
                            # cannot monopolize the event loop.
                            await asyncio.sleep(0)
                            continue
                        # NOTE(review): predict() is called synchronously and
                        # blocks the event loop while it runs.
                        image = self.pipeline.predict(params)
                        if image is None:
                            continue
                        frame = pil_to_frame(image)
                        yield frame
                        if self.args.debug:
                            print(f"Time taken: {time.perf_counter() - started}")
                except Exception as e:
                    logging.exception(f"Streaming Error: {e}, {user_id} ")

            return StreamingResponse(
                generate(),
                media_type="multipart/x-mixed-replace;boundary=frame",
                headers={"Cache-Control": "no-cache"},
            )

        # route to setup frontend
        @self.app.get("/api/settings")
        async def settings():
            """Describe the pipeline to the frontend: schemas, limits, page text."""
            info_schema = self.pipeline.Info.model_json_schema()
            info = self.pipeline.Info()
            # Markdown page content rendered to HTML; empty when not provided.
            page_content = (
                markdown2.markdown(info.page_content) if info.page_content else ""
            )
            input_params = self.pipeline.InputParams.model_json_schema()
            return JSONResponse(
                {
                    "info": info_schema,
                    "input_params": input_params,
                    "max_queue_size": self.args.max_queue_size,
                    "page_content": page_content,
                }
            )

        # StaticFiles checks that its directory exists when it is created, so
        # make sure the directory actually being mounted exists.
        os.makedirs(FRONTEND_DIR, exist_ok=True)
        self.app.mount(
            "/", StaticFiles(directory=FRONTEND_DIR, html=True), name="public"
        )


app = App(config).app