from typing import Dict, Any
import os
import shutil
import gc
import time
from pathlib import Path
import argparse
from datetime import datetime
from loguru import logger
import torch
import base64

from hyvideo.utils.file_utils import save_videos_grid
from hyvideo.inference import HunyuanVideoSampler
from hyvideo.constants import NEGATIVE_PROMPT, VAE_PATH, TEXT_ENCODER_PATH, TOKENIZER_PATH

try:
    import triton
    has_triton = True
except ImportError:
    has_triton = False

try:
    from mmgp import offload, safetensors2, profile_type
    has_mmgp = True
except ImportError:
    has_mmgp = False

# Configure logger
logger.add("handler_debug.log", rotation="500 MB")

DEFAULT_RESOLUTION = "720p"
DEFAULT_WIDTH = 1280
DEFAULT_HEIGHT = 720
DEFAULT_NB_FRAMES = (4 * 30) + 1  # 121 frames (note: Hunyuan requires an extra +1 frame)
DEFAULT_NB_STEPS = 22  # Default for standard model
DEFAULT_FPS = 24


def get_attention_modes():
    """Get available attention modes - fallback if the module function isn't available"""
    modes = ["sdpa"]  # Always available (scaled_dot_product_attention ships with modern PyTorch)
    try:
        import flash_attn
        modes.append("flash")
    except ImportError:
        pass
    try:
        import sageattention
        modes.append("sage")
        if hasattr(sageattention, 'efficient_attention_v2'):
            modes.append("sage2")
    except ImportError:
        pass
    try:
        import xformers
        modes.append("xformers")
    except ImportError:
        pass
    return modes


# Get supported attention modes (prefer the module's own implementation if available)
try:
    from hyvideo.modules.attenion import get_attention_modes
    attention_modes_supported = get_attention_modes()
except ImportError:
    attention_modes_supported = get_attention_modes()


def setup_vae_path(vae_path: Path) -> Path:
    """Create a temporary directory with correctly named VAE config file"""
    tmp_vae_dir = Path("/tmp/vae")
    if tmp_vae_dir.exists():
        shutil.rmtree(tmp_vae_dir)
    tmp_vae_dir.mkdir(parents=True)

    # Copy files to temp directory
    logger.info(f"Setting up VAE in temporary directory: {tmp_vae_dir}")

    # Copy and rename config file
    original_config = vae_path / "hunyuan-video-t2v-720p_vae_config.json"
    new_config = tmp_vae_dir / "config.json"
    shutil.copy2(original_config, new_config)
    logger.info(f"Copied VAE config from {original_config} to {new_config}")

    # Copy model file
    original_model = vae_path / "pytorch_model.pt"
    new_model = tmp_vae_dir / "pytorch_model.pt"
    shutil.copy2(original_model, new_model)
    logger.info(f"Copied VAE model from {original_model} to {new_model}")

    return tmp_vae_dir


def get_default_args():
    """Create default arguments instead of parsing from command line"""
    parser = argparse.ArgumentParser()

    # Model configuration
    parser.add_argument("--model", type=str, default="HYVideo-T/2-cfgdistill")
    parser.add_argument("--model-resolution", type=str, default=DEFAULT_RESOLUTION, choices=["540p", "720p"])
    parser.add_argument("--latent-channels", type=int, default=16)
    parser.add_argument("--precision", type=str, default="bf16", choices=["bf16", "fp32", "fp16"])
    parser.add_argument("--rope-theta", type=int, default=256)
    parser.add_argument("--load-key", type=str, default="module")
    parser.add_argument("--use-fp8", action="store_true", default=False)

    # VAE settings
    parser.add_argument("--vae", type=str, default="884-16c-hy")
    parser.add_argument("--vae-precision", type=str, default="fp16")
    parser.add_argument("--vae-tiling", action="store_true", default=True)

    # Text encoder settings
    parser.add_argument("--text-encoder", type=str, default="llm")
    parser.add_argument("--text-encoder-precision", type=str, default="fp16")
    parser.add_argument("--text-states-dim", type=int, default=4096)
parser.add_argument("--text-len", type=int, default=256) parser.add_argument("--tokenizer", type=str, default="llm") # Prompt template settings parser.add_argument("--prompt-template", type=str, default="dit-llm-encode") parser.add_argument("--prompt-template-video", type=str, default="dit-llm-encode-video") # Additional text encoder settings parser.add_argument("--hidden-state-skip-layer", type=int, default=2) parser.add_argument("--apply-final-norm", action="store_true") parser.add_argument("--text-encoder-2", type=str, default="clipL") parser.add_argument("--text-encoder-precision-2", type=str, default="fp16") parser.add_argument("--text-states-dim-2", type=int, default=768) parser.add_argument("--tokenizer-2", type=str, default="clipL") parser.add_argument("--text-len-2", type=int, default=77) # Model architecture settings parser.add_argument("--hidden-size", type=int, default=1024) parser.add_argument("--heads-num", type=int, default=16) parser.add_argument("--layers-num", type=int, default=24) parser.add_argument("--mlp-ratio", type=float, default=4.0) parser.add_argument("--use-guidance-net", action="store_true", default=True) # Inference settings parser.add_argument("--denoise-type", type=str, default="flow") parser.add_argument("--flow-shift", type=float, default=7.0) parser.add_argument("--flow-reverse", action="store_true", default=True) parser.add_argument("--flow-solver", type=str, default="euler") parser.add_argument("--use-linear-quadratic-schedule", action="store_true") parser.add_argument("--linear-schedule-end", type=int, default=25) # Hardware settings parser.add_argument("--use-cpu-offload", action="store_true", default=False) parser.add_argument("--batch-size", type=int, default=1) parser.add_argument("--infer-steps", type=int, default=DEFAULT_NB_STEPS) parser.add_argument("--disable-autocast", action="store_true") # Output settings parser.add_argument("--save-path", type=str, default="outputs") parser.add_argument("--save-path-suffix", type=str, default="") parser.add_argument("--name-suffix", type=str, default="") # Generation settings parser.add_argument("--num-videos", type=int, default=1) parser.add_argument("--video-size", nargs="+", type=int, default=[DEFAULT_HEIGHT, DEFAULT_WIDTH]) parser.add_argument("--video-length", type=int, default=DEFAULT_NB_FRAMES) parser.add_argument("--prompt", type=str, default=None) parser.add_argument("--seed-type", type=str, default="auto", choices=["file", "random", "fixed", "auto"]) parser.add_argument("--seed", type=int, default=None) parser.add_argument("--neg-prompt", type=str, default="") parser.add_argument("--cfg-scale", type=float, default=1.0) parser.add_argument("--embedded-cfg-scale", type=float, default=6.0) parser.add_argument("--reproduce", action="store_true") # Parallel settings parser.add_argument("--ulysses-degree", type=int, default=1) parser.add_argument("--ring-degree", type=int, default=1) # Added from gradio server parser.add_argument("--attention", type=str, default="auto", choices=["auto", "sdpa", "flash", "sage", "sage2", "xformers"]) parser.add_argument("--profile", type=int, default=1) # HighRAM_HighVRAM parser.add_argument("--quantize-transformer", action="store_true", default=False) parser.add_argument("--tea-cache", type=float, default=0.0) parser.add_argument("--compile", action="store_true", default=False) parser.add_argument("--enable-riflex", action="store_true", default=True) parser.add_argument("--vae-config", type=int, default=0) # Parse with empty args list to avoid reading sys.argv args = 
    return args


def get_auto_attention():
    """Select the best available attention mode"""
    for attn in ["sage2", "sage", "sdpa"]:
        if attn in attention_modes_supported:
            return attn
    return "sdpa"


def setup_vae_config(device_mem_capacity, vae, vae_config=0):
    """Configure VAE tiling based on available VRAM"""
    if vae_config == 0:
        # Auto-select based on VRAM
        if device_mem_capacity >= 24000:
            use_vae_config = 1
        elif device_mem_capacity >= 16000:
            use_vae_config = 3
        elif device_mem_capacity >= 12000:
            use_vae_config = 4
        else:
            use_vae_config = 5
    else:
        use_vae_config = vae_config

    # VAE tiling configuration options
    if use_vae_config == 1:
        sample_tsize = 32
        sample_size = 256
    elif use_vae_config == 2:
        sample_tsize = 64
        sample_size = 192
    elif use_vae_config == 3:
        sample_tsize = 32
        sample_size = 192
    elif use_vae_config == 4:
        sample_tsize = 16
        sample_size = 256
    else:
        sample_tsize = 16
        sample_size = 192

    # Apply settings
    vae.tile_sample_min_tsize = sample_tsize
    vae.tile_latent_min_tsize = sample_tsize // vae.time_compression_ratio
    vae.tile_sample_min_size = sample_size
    vae.tile_latent_min_size = int(sample_size / (2 ** (len(vae.config.block_out_channels) - 1)))
    vae.tile_overlap_factor = 0.25

    return use_vae_config


class EndpointHandler:
    def __init__(self, path: str = ""):
        """Initialize the handler with model path and config."""
        logger.info(f"Initializing EndpointHandler with path: {path}")

        # Use default args instead of parsing from command line
        self.args = get_default_args()

        # Convert path to absolute path if not already
        path = str(Path(path).absolute())
        logger.info(f"Absolute path: {path}")

        # Set up model paths
        self.args.model_base = path

        # Model configurations
        self.init_model_paths(path)
        self.configure_model()

        # Initialize model
        self.initialize_model()

    def init_model_paths(self, path):
        """Setup paths for model components"""
        # We'll use the FP8 model for memory efficiency
        self.args.use_fp8 = True

        # Model component paths
        dit_weight_path = Path(path) / "hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states_fp8.pt"
        original_vae_path = Path(path) / "hunyuan-video-t2v-720p/vae"

        # Log all critical paths
        logger.info(f"Model base path: {self.args.model_base}")
        logger.info(f"DiT weight path: {dit_weight_path}")
        logger.info(f"Use fp8: {self.args.use_fp8}")
        logger.info(f"Original VAE path: {original_vae_path}")

        # Verify paths exist
        logger.info("Checking if paths exist:")
        logger.info(f"DiT weight exists: {dit_weight_path.exists()}")
        logger.info(f"VAE path exists: {original_vae_path.exists()}")
        if original_vae_path.exists():
            logger.info(f"VAE path contents: {list(original_vae_path.glob('*'))}")

        # Set up VAE in temporary directory with correct file names
        tmp_vae_path = setup_vae_path(original_vae_path)

        # Override the VAE path in constants to use our temporary directory
        VAE_PATH["884-16c-hy"] = str(tmp_vae_path)
        logger.info(f"Updated VAE_PATH to: {VAE_PATH['884-16c-hy']}")

        # Update text encoder paths to use absolute paths
        text_encoder_path = str(Path(path) / "text_encoder")
        text_encoder_2_path = str(Path(path) / "text_encoder_2")

        # Update both text encoder and tokenizer paths
        TEXT_ENCODER_PATH.update({
            "llm": text_encoder_path,
            "clipL": text_encoder_2_path
        })
        TOKENIZER_PATH.update({
            "llm": text_encoder_path,
            "clipL": text_encoder_2_path
        })

        logger.info("Updated text encoder paths:")
        logger.info(f"TEXT_ENCODER_PATH['llm']: {TEXT_ENCODER_PATH['llm']}")
        logger.info(f"TEXT_ENCODER_PATH['clipL']: {TEXT_ENCODER_PATH['clipL']}")
        logger.info(f"TOKENIZER_PATH['llm']: {TOKENIZER_PATH['llm']}")
logger.info(f"TOKENIZER_PATH['clipL']: {TOKENIZER_PATH['clipL']}") self.args.dit_weight = str(dit_weight_path) def configure_model(self): """Configure model based on available hardware and settings""" # Set attention mode (auto-select best available if set to 'auto') if self.args.attention == "auto": self.attention_mode = get_auto_attention() elif self.args.attention in attention_modes_supported: self.attention_mode = self.args.attention else: logger.warning(f"Attention mode {self.args.attention} not supported. Falling back to sdpa.") self.attention_mode = "sdpa" logger.info(f"Using attention mode: {self.attention_mode}") # Set compilation flag based on Triton availability if self.args.compile and not has_triton: logger.warning("Compilation requested but Triton not available. Compilation disabled.") self.args.compile = False # Set profile based on memory configuration # We default to HighRAM_HighVRAM (1) as specified if has_mmgp: self.profile = self.args.profile logger.info(f"Using memory profile: {self.profile}") else: logger.warning("MMGP not available. Memory profiles not used.") def initialize_model(self): """Initialize the model with configured settings""" models_root_path = Path(self.args.model_base) if not models_root_path.exists(): raise ValueError(f"models_root_path does not exist: {models_root_path}") try: logger.info("Attempting to initialize HunyuanVideoSampler...") # Extract necessary paths transformer_path = str(self.args.dit_weight) text_encoder_path = str(Path(self.args.model_base) / "text_encoder") logger.info(f"Transformer path: {transformer_path}") logger.info(f"Text encoder path: {text_encoder_path}") # Initialize the model using the exact signature from gradio_server.py self.model = HunyuanVideoSampler.from_pretrained( transformer_path, text_encoder_path, attention_mode=self.attention_mode, args=self.args ) # Set attention mode for transformer blocks if hasattr(self.model, 'pipeline') and hasattr(self.model.pipeline, 'transformer'): transformer = self.model.pipeline.transformer transformer.attention_mode = self.attention_mode # Apply to all blocks if hasattr(transformer, 'double_blocks'): for module in transformer.double_blocks: module.attention_mode = self.attention_mode if hasattr(transformer, 'single_blocks'): for module in transformer.single_blocks: module.attention_mode = self.attention_mode # Enable compilation if requested if self.args.compile: transformer.any_compilation = True logger.info("PyTorch compilation enabled for transformer") # Enable TeaCache if requested if self.args.tea_cache > 0: transformer.enable_teacache = True transformer.rel_l1_thresh = self.args.tea_cache logger.info(f"TeaCache enabled with threshold: {self.args.tea_cache}") else: transformer.enable_teacache = False # Apply VAE tiling configuration if supported if hasattr(self.model, 'vae'): if torch.cuda.is_available(): device_mem_capacity = torch.cuda.get_device_properties(0).total_memory / 1048576 vae_config = setup_vae_config(device_mem_capacity, self.model.vae, self.args.vae_config) logger.info(f"Configured VAE tiling with config: {vae_config}") else: logger.warning("CUDA not available, using default VAE configuration") logger.info("Successfully initialized HunyuanVideoSampler") except Exception as e: logger.error(f"Error initializing model: {str(e)}") raise def __call__(self, data: Dict[str, Any]) -> str: """Process a single request""" # Log incoming request logger.info(f"Processing request with data: {data}") # Get inputs from request data prompt = data.pop("inputs", None) if prompt 
        if prompt is None:
            raise ValueError("No prompt provided in the 'inputs' field")

        # Parse resolution
        resolution = data.pop("resolution", f"{DEFAULT_WIDTH}x{DEFAULT_HEIGHT}")
        width, height = map(int, resolution.split("x"))

        # Get other parameters with defaults
        video_length = int(data.pop("video_length", DEFAULT_NB_FRAMES))
        seed = data.pop("seed", -1)
        seed = None if seed == -1 else int(seed)
        num_inference_steps = int(data.pop("num_inference_steps", DEFAULT_NB_STEPS))
        guidance_scale = float(data.pop("guidance_scale", 1.0))
        flow_shift = float(data.pop("flow_shift", 7.0))
        embedded_guidance_scale = float(data.pop("embedded_guidance_scale", 6.0))
        enable_riflex = data.pop("enable_riflex", self.args.enable_riflex)
        tea_cache = float(data.pop("tea_cache", 0.0))

        logger.info(f"Processing with parameters: width={width}, height={height}, "
                    f"video_length={video_length}, seed={seed}, "
                    f"num_inference_steps={num_inference_steps}")

        try:
            # Set up TeaCache for this generation if enabled
            if hasattr(self.model.pipeline, 'transformer') and tea_cache > 0:
                transformer = self.model.pipeline.transformer
                transformer.enable_teacache = True
                transformer.num_steps = num_inference_steps
                transformer.cnt = 0
                transformer.rel_l1_thresh = tea_cache
                transformer.accumulated_rel_l1_distance = 0
                transformer.previous_modulated_input = None
                transformer.previous_residual = None

            # Clean up memory before generation
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            # Run inference
            outputs = self.model.predict(
                prompt=prompt,
                height=height,
                width=width,
                video_length=video_length,
                seed=seed,
                negative_prompt="",
                infer_steps=num_inference_steps,
                guidance_scale=guidance_scale,
                num_videos_per_prompt=1,
                flow_shift=flow_shift,
                batch_size=1,
                embedded_guidance_scale=embedded_guidance_scale,
                enable_riflex=enable_riflex
            )

            # Get the video tensor
            samples = outputs['samples']
            sample = samples[0].unsqueeze(0)

            # Save to temporary file
            temp_path = "/tmp/temp_video.mp4"
            save_videos_grid(sample, temp_path, fps=DEFAULT_FPS)

            # Read video file and convert to base64
            with open(temp_path, "rb") as f:
                video_bytes = f.read()
            video_base64 = base64.b64encode(video_bytes).decode()

            # Add MP4 data URI prefix
            video_data_uri = f"data:video/mp4;base64,{video_base64}"

            # Cleanup
            os.remove(temp_path)

            # Clean up memory after generation
            if has_mmgp and hasattr(offload, 'last_offload_obj'):
                offload.last_offload_obj.unload_all()
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            logger.info("Successfully generated and encoded video")

            # Return exactly what the demo.py expects
            return video_data_uri

        except Exception as e:
            logger.error(f"Error during video generation: {str(e)}")

            # Clean up memory after error
            if has_mmgp and hasattr(offload, 'last_offload_obj'):
                offload.last_offload_obj.unload_all()
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            raise
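

# Minimal local smoke test (a sketch, not part of the deployed endpoint): it assumes
# the model weights have been downloaded to a hypothetical "/repository" directory,
# and the prompt and parameter values below are illustrative placeholders only.
if __name__ == "__main__":
    handler = EndpointHandler(path="/repository")
    result = handler({
        "inputs": "a cat walking on wet grass, photorealistic, golden hour",
        "resolution": f"{DEFAULT_WIDTH}x{DEFAULT_HEIGHT}",
        "video_length": DEFAULT_NB_FRAMES,
        "num_inference_steps": DEFAULT_NB_STEPS,
        "seed": 42,
    })
    # The handler returns a data URI; strip the prefix and decode to recover raw MP4 bytes
    with open("sample.mp4", "wb") as f:
        f.write(base64.b64decode(result.split(",", 1)[1]))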