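"""Inference endpoint handler for HunyuanVideo text-to-video generation.

Loads the fp8 DiT weights plus the VAE and text encoders from the model
repository, selects the best available attention backend, and serves requests
that return the generated clip as a base64-encoded MP4 data URI.
"""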
from typing import Dict, Any

import os
import shutil
import gc
import time
import argparse
import base64
from pathlib import Path
from datetime import datetime

import torch
from loguru import logger

from hyvideo.utils.file_utils import save_videos_grid
from hyvideo.inference import HunyuanVideoSampler
from hyvideo.constants import NEGATIVE_PROMPT, VAE_PATH, TEXT_ENCODER_PATH, TOKENIZER_PATH

try:
    import triton
    has_triton = True
except ImportError:
    has_triton = False

try:
    from mmgp import offload, safetensors2, profile_type
    has_mmgp = True
except ImportError:
    has_mmgp = False


logger.add("handler_debug.log", rotation="500 MB")

DEFAULT_RESOLUTION = "720p"
DEFAULT_WIDTH = 1280
DEFAULT_HEIGHT = 720
DEFAULT_NB_FRAMES = (4 * 30) + 1
DEFAULT_NB_STEPS = 22
DEFAULT_FPS = 24

def get_attention_modes():
    """Get available attention modes - fallback if the module function isn't available."""
    # "sdpa" (torch scaled_dot_product_attention) is always kept as the baseline mode.
    modes = ["sdpa"]

    try:
        import flash_attn
        modes.append("flash")
    except ImportError:
        pass

    try:
        import sageattention
        modes.append("sage")
        if hasattr(sageattention, 'efficient_attention_v2'):
            modes.append("sage2")
    except ImportError:
        pass

    try:
        import xformers
        modes.append("xformers")
    except ImportError:
        pass

    return modes


# Prefer the detection helper shipped with hyvideo (the module file is spelled "attenion"
# upstream); fall back to the local implementation above if it cannot be imported or fails.
try:
    from hyvideo.modules.attenion import get_attention_modes
    attention_modes_supported = get_attention_modes()
except Exception:
    attention_modes_supported = get_attention_modes()

def setup_vae_path(vae_path: Path) -> Path:
    """Create a temporary directory with correctly named VAE config file"""
    tmp_vae_dir = Path("/tmp/vae")
    if tmp_vae_dir.exists():
        shutil.rmtree(tmp_vae_dir)
    tmp_vae_dir.mkdir(parents=True)

    logger.info(f"Setting up VAE in temporary directory: {tmp_vae_dir}")

    original_config = vae_path / "hunyuan-video-t2v-720p_vae_config.json"
    new_config = tmp_vae_dir / "config.json"
    shutil.copy2(original_config, new_config)
    logger.info(f"Copied VAE config from {original_config} to {new_config}")

    original_model = vae_path / "pytorch_model.pt"
    new_model = tmp_vae_dir / "pytorch_model.pt"
    shutil.copy2(original_model, new_model)
    logger.info(f"Copied VAE model from {original_model} to {new_model}")

    return tmp_vae_dir

def get_default_args():
    """Create default arguments instead of parsing from the command line"""
    parser = argparse.ArgumentParser()

    parser.add_argument("--model", type=str, default="HYVideo-T/2-cfgdistill")
    parser.add_argument("--model-resolution", type=str, default=DEFAULT_RESOLUTION, choices=["540p", "720p"])
    parser.add_argument("--latent-channels", type=int, default=16)
    parser.add_argument("--precision", type=str, default="bf16", choices=["bf16", "fp32", "fp16"])
    parser.add_argument("--rope-theta", type=int, default=256)
    parser.add_argument("--load-key", type=str, default="module")
    parser.add_argument("--use-fp8", action="store_true", default=False)

    parser.add_argument("--vae", type=str, default="884-16c-hy")
    parser.add_argument("--vae-precision", type=str, default="fp16")
    parser.add_argument("--vae-tiling", action="store_true", default=True)

    parser.add_argument("--text-encoder", type=str, default="llm")
    parser.add_argument("--text-encoder-precision", type=str, default="fp16")
    parser.add_argument("--text-states-dim", type=int, default=4096)
    parser.add_argument("--text-len", type=int, default=256)
    parser.add_argument("--tokenizer", type=str, default="llm")

    parser.add_argument("--prompt-template", type=str, default="dit-llm-encode")
    parser.add_argument("--prompt-template-video", type=str, default="dit-llm-encode-video")

    parser.add_argument("--hidden-state-skip-layer", type=int, default=2)
    parser.add_argument("--apply-final-norm", action="store_true")
    parser.add_argument("--text-encoder-2", type=str, default="clipL")
    parser.add_argument("--text-encoder-precision-2", type=str, default="fp16")
    parser.add_argument("--text-states-dim-2", type=int, default=768)
    parser.add_argument("--tokenizer-2", type=str, default="clipL")
    parser.add_argument("--text-len-2", type=int, default=77)

    parser.add_argument("--hidden-size", type=int, default=1024)
    parser.add_argument("--heads-num", type=int, default=16)
    parser.add_argument("--layers-num", type=int, default=24)
    parser.add_argument("--mlp-ratio", type=float, default=4.0)
    parser.add_argument("--use-guidance-net", action="store_true", default=True)

    parser.add_argument("--denoise-type", type=str, default="flow")
    parser.add_argument("--flow-shift", type=float, default=7.0)
    parser.add_argument("--flow-reverse", action="store_true", default=True)
    parser.add_argument("--flow-solver", type=str, default="euler")
    parser.add_argument("--use-linear-quadratic-schedule", action="store_true")
    parser.add_argument("--linear-schedule-end", type=int, default=25)

    parser.add_argument("--use-cpu-offload", action="store_true", default=False)
    parser.add_argument("--batch-size", type=int, default=1)
    parser.add_argument("--infer-steps", type=int, default=DEFAULT_NB_STEPS)
    parser.add_argument("--disable-autocast", action="store_true")

    parser.add_argument("--save-path", type=str, default="outputs")
    parser.add_argument("--save-path-suffix", type=str, default="")
    parser.add_argument("--name-suffix", type=str, default="")

    parser.add_argument("--num-videos", type=int, default=1)
    parser.add_argument("--video-size", nargs="+", type=int, default=[DEFAULT_HEIGHT, DEFAULT_WIDTH])
    parser.add_argument("--video-length", type=int, default=DEFAULT_NB_FRAMES)
    parser.add_argument("--prompt", type=str, default=None)
    parser.add_argument("--seed-type", type=str, default="auto", choices=["file", "random", "fixed", "auto"])
    parser.add_argument("--seed", type=int, default=None)
    parser.add_argument("--neg-prompt", type=str, default="")
    parser.add_argument("--cfg-scale", type=float, default=1.0)
    parser.add_argument("--embedded-cfg-scale", type=float, default=6.0)
    parser.add_argument("--reproduce", action="store_true")

    parser.add_argument("--ulysses-degree", type=int, default=1)
    parser.add_argument("--ring-degree", type=int, default=1)

    parser.add_argument("--attention", type=str, default="auto",
                        choices=["auto", "sdpa", "flash", "sage", "sage2", "xformers"])
    parser.add_argument("--profile", type=int, default=1)
    parser.add_argument("--quantize-transformer", action="store_true", default=False)
    parser.add_argument("--tea-cache", type=float, default=0.0)
    parser.add_argument("--compile", action="store_true", default=False)
    parser.add_argument("--enable-riflex", action="store_true", default=True)
    parser.add_argument("--vae-config", type=int, default=0)

    # Parse an empty list so the defaults above are used and nothing is read from sys.argv.
    args = parser.parse_args([])

    return args

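
# Note: the EndpointHandler below builds its configuration from get_default_args()
# rather than from the command line. argparse turns each dashed flag into an
# underscore attribute on the returned Namespace, for example:
#
#     args = get_default_args()
#     args.infer_steps   # 22            (from --infer-steps)
#     args.video_size    # [720, 1280]   (from --video-size)
#     args.tea_cache     # 0.0           (from --tea-cache)
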
def get_auto_attention():
    """Select the best available attention mode"""
    for attn in ["sage2", "sage", "sdpa"]:
        if attn in attention_modes_supported:
            return attn
    return "sdpa"

def setup_vae_config(device_mem_capacity, vae, vae_config=0):
    """Configure VAE tiling based on available VRAM (device_mem_capacity is in MiB)"""
    if vae_config == 0:
        # Pick a tiling preset automatically based on available VRAM
        if device_mem_capacity >= 24000:
            use_vae_config = 1
        elif device_mem_capacity >= 16000:
            use_vae_config = 3
        elif device_mem_capacity >= 12000:
            use_vae_config = 4
        else:
            use_vae_config = 5
    else:
        use_vae_config = vae_config

    if use_vae_config == 1:
        sample_tsize = 32
        sample_size = 256
    elif use_vae_config == 2:
        sample_tsize = 64
        sample_size = 192
    elif use_vae_config == 3:
        sample_tsize = 32
        sample_size = 192
    elif use_vae_config == 4:
        sample_tsize = 16
        sample_size = 256
    else:
        sample_tsize = 16
        sample_size = 192

    vae.tile_sample_min_tsize = sample_tsize
    vae.tile_latent_min_tsize = sample_tsize // vae.time_compression_ratio
    vae.tile_sample_min_size = sample_size
    vae.tile_latent_min_size = int(sample_size / (2 ** (len(vae.config.block_out_channels) - 1)))
    vae.tile_overlap_factor = 0.25

    return use_vae_config

class EndpointHandler:
    def __init__(self, path: str = ""):
        """Initialize the handler with model path and config."""
        logger.info(f"Initializing EndpointHandler with path: {path}")

        self.args = get_default_args()

        path = str(Path(path).absolute())
        logger.info(f"Absolute path: {path}")

        self.args.model_base = path

        self.init_model_paths(path)
        self.configure_model()

        self.initialize_model()

    def init_model_paths(self, path):
        """Set up paths for model components"""
        self.args.use_fp8 = True

        dit_weight_path = Path(path) / "hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states_fp8.pt"
        original_vae_path = Path(path) / "hunyuan-video-t2v-720p/vae"

        logger.info(f"Model base path: {self.args.model_base}")
        logger.info(f"DiT weight path: {dit_weight_path}")
        logger.info(f"Use fp8: {self.args.use_fp8}")
        logger.info(f"Original VAE path: {original_vae_path}")

        logger.info("Checking if paths exist:")
        logger.info(f"DiT weight exists: {dit_weight_path.exists()}")
        logger.info(f"VAE path exists: {original_vae_path.exists()}")

        if original_vae_path.exists():
            logger.info(f"VAE path contents: {list(original_vae_path.glob('*'))}")

        tmp_vae_path = setup_vae_path(original_vae_path)

        # Point the hyvideo constants at the local (temporary) VAE and text encoder folders
        VAE_PATH["884-16c-hy"] = str(tmp_vae_path)
        logger.info(f"Updated VAE_PATH to: {VAE_PATH['884-16c-hy']}")

        text_encoder_path = str(Path(path) / "text_encoder")
        text_encoder_2_path = str(Path(path) / "text_encoder_2")

        TEXT_ENCODER_PATH.update({
            "llm": text_encoder_path,
            "clipL": text_encoder_2_path
        })

        TOKENIZER_PATH.update({
            "llm": text_encoder_path,
            "clipL": text_encoder_2_path
        })

        logger.info("Updated text encoder paths:")
        logger.info(f"TEXT_ENCODER_PATH['llm']: {TEXT_ENCODER_PATH['llm']}")
        logger.info(f"TEXT_ENCODER_PATH['clipL']: {TEXT_ENCODER_PATH['clipL']}")
        logger.info(f"TOKENIZER_PATH['llm']: {TOKENIZER_PATH['llm']}")
        logger.info(f"TOKENIZER_PATH['clipL']: {TOKENIZER_PATH['clipL']}")

        self.args.dit_weight = str(dit_weight_path)

    def configure_model(self):
        """Configure model based on available hardware and settings"""
        # Resolve the attention implementation
        if self.args.attention == "auto":
            self.attention_mode = get_auto_attention()
        elif self.args.attention in attention_modes_supported:
            self.attention_mode = self.args.attention
        else:
            logger.warning(f"Attention mode {self.args.attention} not supported. Falling back to sdpa.")
            self.attention_mode = "sdpa"

        logger.info(f"Using attention mode: {self.attention_mode}")

        # torch.compile requires Triton
        if self.args.compile and not has_triton:
            logger.warning("Compilation requested but Triton not available. Compilation disabled.")
            self.args.compile = False

        # Memory/offload profile (only meaningful when mmgp is installed)
        if has_mmgp:
            self.profile = self.args.profile
            logger.info(f"Using memory profile: {self.profile}")
        else:
            self.profile = None
            logger.warning("MMGP not available. Memory profiles not used.")

    def initialize_model(self):
        """Initialize the model with configured settings"""
        models_root_path = Path(self.args.model_base)
        if not models_root_path.exists():
            raise ValueError(f"models_root_path does not exist: {models_root_path}")

        try:
            logger.info("Attempting to initialize HunyuanVideoSampler...")

            transformer_path = str(self.args.dit_weight)
            text_encoder_path = str(Path(self.args.model_base) / "text_encoder")

            logger.info(f"Transformer path: {transformer_path}")
            logger.info(f"Text encoder path: {text_encoder_path}")

            self.model = HunyuanVideoSampler.from_pretrained(
                transformer_path,
                text_encoder_path,
                attention_mode=self.attention_mode,
                args=self.args
            )

            # Propagate the chosen attention mode to the transformer and all of its blocks
            if hasattr(self.model, 'pipeline') and hasattr(self.model.pipeline, 'transformer'):
                transformer = self.model.pipeline.transformer
                transformer.attention_mode = self.attention_mode

                if hasattr(transformer, 'double_blocks'):
                    for module in transformer.double_blocks:
                        module.attention_mode = self.attention_mode
                if hasattr(transformer, 'single_blocks'):
                    for module in transformer.single_blocks:
                        module.attention_mode = self.attention_mode

                if self.args.compile:
                    transformer.any_compilation = True
                    logger.info("PyTorch compilation enabled for transformer")

                # TeaCache reuses cached transformer residuals across denoising steps
                # when successive inputs change less than the relative L1 threshold
                if self.args.tea_cache > 0:
                    transformer.enable_teacache = True
                    transformer.rel_l1_thresh = self.args.tea_cache
                    logger.info(f"TeaCache enabled with threshold: {self.args.tea_cache}")
                else:
                    transformer.enable_teacache = False

            # Configure VAE tiling according to the available VRAM (in MiB)
            if hasattr(self.model, 'vae'):
                if torch.cuda.is_available():
                    device_mem_capacity = torch.cuda.get_device_properties(0).total_memory / 1048576
                    vae_config = setup_vae_config(device_mem_capacity, self.model.vae, self.args.vae_config)
                    logger.info(f"Configured VAE tiling with config: {vae_config}")
                else:
                    logger.warning("CUDA not available, using default VAE configuration")

            logger.info("Successfully initialized HunyuanVideoSampler")

        except Exception as e:
            logger.error(f"Error initializing model: {str(e)}")
            raise

    def __call__(self, data: Dict[str, Any]) -> str:
        """Process a single request and return the generated video as a data URI.

        Expected fields in `data`:
        - inputs (required): the text prompt
        - resolution: "WIDTHxHEIGHT" string, e.g. "1280x720"
        - video_length: number of frames
        - seed: integer seed, or -1 for a random seed
        - num_inference_steps, guidance_scale, flow_shift, embedded_guidance_scale,
          enable_riflex, tea_cache: optional sampling parameters
        """
        logger.info(f"Processing request with data: {data}")

        prompt = data.pop("inputs", None)
        if prompt is None:
            raise ValueError("No prompt provided in the 'inputs' field")

        resolution = data.pop("resolution", f"{DEFAULT_WIDTH}x{DEFAULT_HEIGHT}")
        width, height = map(int, resolution.split("x"))

        video_length = int(data.pop("video_length", DEFAULT_NB_FRAMES))
        seed = data.pop("seed", -1)
        seed = None if seed == -1 else int(seed)
        num_inference_steps = int(data.pop("num_inference_steps", DEFAULT_NB_STEPS))
        guidance_scale = float(data.pop("guidance_scale", 1.0))
        flow_shift = float(data.pop("flow_shift", 7.0))
        embedded_guidance_scale = float(data.pop("embedded_guidance_scale", 6.0))
        enable_riflex = data.pop("enable_riflex", self.args.enable_riflex)
        tea_cache = float(data.pop("tea_cache", 0.0))

        logger.info(f"Processing with parameters: width={width}, height={height}, "
                    f"video_length={video_length}, seed={seed}, "
                    f"num_inference_steps={num_inference_steps}")

        try:
            # Reset TeaCache state for this request if it is enabled
            if hasattr(self.model.pipeline, 'transformer') and tea_cache > 0:
                transformer = self.model.pipeline.transformer
                transformer.enable_teacache = True
                transformer.num_steps = num_inference_steps
                transformer.cnt = 0
                transformer.rel_l1_thresh = tea_cache
                transformer.accumulated_rel_l1_distance = 0
                transformer.previous_modulated_input = None
                transformer.previous_residual = None

            # Free as much memory as possible before generation
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            outputs = self.model.predict(
                prompt=prompt,
                height=height,
                width=width,
                video_length=video_length,
                seed=seed,
                negative_prompt="",
                infer_steps=num_inference_steps,
                guidance_scale=guidance_scale,
                num_videos_per_prompt=1,
                flow_shift=flow_shift,
                batch_size=1,
                embedded_guidance_scale=embedded_guidance_scale,
                enable_riflex=enable_riflex
            )

            samples = outputs['samples']
            sample = samples[0].unsqueeze(0)

            # Save to a temporary file, then return its contents base64-encoded as a data URI
            temp_path = "/tmp/temp_video.mp4"
            save_videos_grid(sample, temp_path, fps=DEFAULT_FPS)

            with open(temp_path, "rb") as f:
                video_bytes = f.read()
            video_base64 = base64.b64encode(video_bytes).decode()

            video_data_uri = f"data:video/mp4;base64,{video_base64}"

            os.remove(temp_path)

            # Release model memory between requests
            if has_mmgp and hasattr(offload, 'last_offload_obj'):
                offload.last_offload_obj.unload_all()
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            logger.info("Successfully generated and encoded video")

            return video_data_uri

        except Exception as e:
            logger.error(f"Error during video generation: {str(e)}")

            # Clean up on failure as well
            if has_mmgp and hasattr(offload, 'last_offload_obj'):
                offload.last_offload_obj.unload_all()
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            raise
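

# --- Local smoke test (illustrative only) ------------------------------------
# A minimal sketch of how this handler can be exercised outside the inference
# endpoint, assuming a local checkout of the model weights. MODEL_DIR and the
# output filename are hypothetical; adjust them to your environment. Running
# this performs a full generation and requires a large GPU.
if __name__ == "__main__":
    MODEL_DIR = os.environ.get("MODEL_DIR", ".")  # hypothetical path to the model files

    handler = EndpointHandler(path=MODEL_DIR)
    video_data_uri = handler({
        "inputs": "A cat walks on the grass, realistic style.",
        "resolution": f"{DEFAULT_WIDTH}x{DEFAULT_HEIGHT}",
        "video_length": DEFAULT_NB_FRAMES,
        "num_inference_steps": DEFAULT_NB_STEPS,
        "seed": 42,
    })

    # The handler returns a data URI ("data:video/mp4;base64,..."); decode it back to MP4 bytes.
    _, encoded = video_data_uri.split(",", 1)
    with open("sample_output.mp4", "wb") as f:
        f.write(base64.b64decode(encoded))
    logger.info("Wrote sample_output.mp4")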