from dataclasses import dataclass, fields
from typing import Dict, Any, Optional
import asyncio
import base64
import logging
import random
import traceback

import torch

from skyreelsinfer import TaskType
from skyreelsinfer.offload import OffloadConfig
from skyreelsinfer.skyreels_video_infer import SkyReelsVideoInfer
from varnish import Varnish

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


@dataclass
class GenerationConfig:
    """Configuration for a single video generation request.

    The field defaults below are the single source of truth for request
    defaults: ``from_parameters`` only overrides the fields that a request
    actually supplies.
    """

    # Content settings
    prompt: str
    negative_prompt: str = "Aerial view, aerial view, overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, distortion"

    # Model settings
    num_frames: int = 97  # SkyReels default
    height: int = 544  # SkyReels default
    width: int = 960  # SkyReels default
    num_inference_steps: int = 30
    guidance_scale: float = 6.0

    # Reproducibility: -1 (or None) means "pick a random seed".
    seed: int = -1

    # Varnish post-processing settings
    fps: int = 30
    double_num_frames: bool = False
    super_resolution: bool = False
    grain_amount: float = 0.0
    quality: int = 18  # CRF scale (0-51, lower is better)

    # Audio settings
    enable_audio: bool = False
    audio_prompt: str = ""
    audio_negative_prompt: str = "voices, voice, talking, speaking, speech"

    # Model-specific settings.
    # NOTE: quant_model, gpu_num, offload, high_cpu_memory, parameters_level
    # and compiler_transformer are fixed when EndpointHandler loads the model;
    # per-request values are accepted for compatibility but do not take effect.
    embedded_guidance_scale: float = 1.0
    quant_model: bool = True
    gpu_num: int = 1
    offload: bool = True
    high_cpu_memory: bool = True
    parameters_level: bool = False
    compiler_transformer: bool = False
    sequence_batch: bool = False

    @classmethod
    def from_parameters(
        cls, prompt: str, params: Optional[Dict[str, Any]]
    ) -> "GenerationConfig":
        """Build a validated config from a prompt and a request ``parameters`` dict.

        Args:
            prompt: Text prompt for the video.
            params: Request parameters. Keys that match a config field (other
                than ``prompt``) override that field's default. Unknown keys
                are ignored. ``None`` is treated as an empty dict.

        Returns:
            A config that has already been passed through ``validate_and_adjust``.
        """
        params = params or {}
        # "prompt" always comes from the request inputs, never from parameters.
        settable = {f.name for f in fields(cls)} - {"prompt"}
        overrides = {key: value for key, value in params.items() if key in settable}
        return cls(prompt=prompt, **overrides).validate_and_adjust()

    def validate_and_adjust(self) -> 'GenerationConfig':
        """Resolve derived parameters in place and return ``self``.

        Currently this only replaces an unspecified seed (-1 or None) with a
        random 32-bit seed, so the seed actually used can be reported back.
        """
        if self.seed is None or self.seed == -1:
            self.seed = random.randint(0, 2**32 - 1)
        return self


class EndpointHandler:
    """Handles video generation requests using SkyReels and Varnish"""

    def __init__(self, path: str = ""):
        """Initialize handler with models.

        Args:
            path: Path to model weights. Falls back to "Skywork/SkyReels-V1"
                when empty.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # Settings baked into the model at load time. Requests cannot change
        # these, so they are kept here and reported in response metadata.
        self.load_settings: Dict[str, Any] = {
            "quant_model": True,  # Enable quantization by default
            "gpu_num": 1,  # Single GPU by default
            "offload": True,  # Enable offloading by default
            "high_cpu_memory": True,
            "parameters_level": False,
            "compiler_transformer": False,
        }
        settings = self.load_settings

        # Initialize SkyReelsVideoInfer
        self.predictor = SkyReelsVideoInfer(
            task_type=TaskType.T2V,
            model_id=path or "Skywork/SkyReels-V1",
            quant_model=settings["quant_model"],
            world_size=settings["gpu_num"],
            is_offload=settings["offload"],
            offload_config=OffloadConfig(
                high_cpu_memory=settings["high_cpu_memory"],
                parameters_level=settings["parameters_level"],
                compiler_transformer=settings["compiler_transformer"],
            ),
            enable_cfg_parallel=True,
        )

        # Initialize Varnish for post-processing
        self.varnish = Varnish(
            device=self.device,
            model_base_dir="/repository/varnish",
        )

    @staticmethod
    def _run_coroutine(coro):
        """Run an awaitable to completion on this thread's event loop.

        Reuses the thread's current loop when it has a usable one. Otherwise
        it creates a new loop and installs it. Like ``run_until_complete``,
        this raises RuntimeError if called from inside an already-running loop.
        """
        try:
            loop = asyncio.get_event_loop()
        except RuntimeError:
            loop = None
        if loop is None or loop.is_closed():
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
        return loop.run_until_complete(coro)

    async def _postprocess(self, frames, config: GenerationConfig):
        """Post-process raw frames with Varnish and encode them.

        Returns:
            Tuple of (Varnish result, video data URI).
        """
        result = await self.varnish(
            input_data=frames,
            fps=config.fps,
            double_num_frames=config.double_num_frames,
            super_resolution=config.super_resolution,
            grain_amount=config.grain_amount,
            enable_audio=config.enable_audio,
            audio_prompt=config.audio_prompt,
            audio_negative_prompt=config.audio_negative_prompt,
        )
        video_uri = await result.write(type="data-uri", quality=config.quality)
        return result, video_uri

    def _warn_on_ignored_settings(self, config: GenerationConfig) -> None:
        """Log a warning for load-time settings the request tried to change."""
        for name, loaded in self.load_settings.items():
            requested = getattr(config, name)
            if requested != loaded:
                logger.warning(
                    "Ignoring per-request %s=%r; model was loaded with %s=%r",
                    name, requested, name, loaded,
                )

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Process video generation requests.

        Args:
            data: Request data containing:
                - inputs (str | dict): Prompt for video generation, or a dict
                  with a "prompt" key. If absent, ``data`` itself is used.
                - parameters (dict, optional): Generation parameters. Keys
                  match ``GenerationConfig`` fields.

        Returns:
            Dictionary containing:
                - video: Base64 encoded MP4 data URI
                - content-type: MIME type
                - metadata: Generation metadata

        Raises:
            RuntimeError: If generation or post-processing fails. The message
                includes the original error and traceback.
        """
        # Read inputs without mutating the caller's request dict.
        inputs = data.get("inputs", data)
        if isinstance(inputs, dict):
            prompt = inputs.get("prompt", "")
        else:
            prompt = inputs
        params = data.get("parameters") or {}

        # Build the config from the dataclass defaults plus request overrides.
        config = GenerationConfig.from_parameters(prompt, params)
        self._warn_on_ignored_settings(config)

        try:
            # validate_and_adjust guarantees a concrete seed at this point.
            torch.manual_seed(config.seed)
            random.seed(config.seed)

            # Prepare generation parameters
            generation_kwargs = {
                "prompt": f"FPS-{config.fps}, {config.prompt}",  # SkyReels expects FPS in prompt
                "negative_prompt": config.negative_prompt,
                "height": config.height,
                "width": config.width,
                "num_frames": config.num_frames,
                "num_inference_steps": config.num_inference_steps,
                "guidance_scale": config.guidance_scale,
                "embedded_guidance_scale": config.embedded_guidance_scale,
                "seed": config.seed,
                # NOTE(review): assumed to toggle sequential CFG batching in
                # SkyReels; confirm against skyreelsinfer.
                "cfg_for": config.sequence_batch,
            }

            # Generate video frames using SkyReels
            output = self.predictor.inference(generation_kwargs)

            # Post-process with Varnish and encode to a data URI in one loop run.
            result, video_uri = self._run_coroutine(self._postprocess(output, config))

            return {
                "video": video_uri,
                "content-type": "video/mp4",
                "metadata": {
                    "width": result.metadata.width,
                    "height": result.metadata.height,
                    "num_frames": result.metadata.frame_count,
                    "fps": result.metadata.fps,
                    "duration": result.metadata.duration,
                    "seed": config.seed,
                    # Report the settings the model actually runs with.
                    "gpu_num": self.load_settings["gpu_num"],
                    "quant_model": self.load_settings["quant_model"],
                    "guidance_scale": config.guidance_scale,
                    "embedded_guidance_scale": config.embedded_guidance_scale,
                },
            }

        except Exception as e:
            message = f"Error generating video ({str(e)})\n{traceback.format_exc()}"
            logger.error(message)
            raise RuntimeError(message) from e