File size: 15,741 Bytes
2df6ae7 4e8d40c 2df6ae7 3fb4272 2df6ae7 b97d514 cdf6b7a 8c9921c b7f5d29 609f5cd 2df6ae7 4e8d40c 2df6ae7 4e8d40c 2df6ae7 8b4a69c b7f5d29 8b8ed09 8b4a69c 8d2995d b7f5d29 4e8d40c 8d2995d 2df6ae7 baee908 b7f5d29 4e8d40c 8c9921c bab295a 4e8d40c bab295a b7f5d29 060887f b7f5d29 060887f b7f5d29 060887f 03f0c45 b7f5d29 4e8d40c 187fe3b 2df6ae7 4e8d40c 2df6ae7 28747d9 8b8ed09 8d2995d 8b8ed09 28747d9 4e8d40c 2df6ae7 8b4a69c 80253b4 8b4a69c cdf6b7a 97a4a3d 4e8d40c 97a4a3d 4e8d40c 97a4a3d 4e8d40c 2df6ae7 4e8d40c cdf6b7a 2df6ae7 4e8d40c 2df6ae7 4e8d40c 2df6ae7 f510634 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 |
from dataclasses import dataclass
from typing import Dict, Any, Optional
import base64
import asyncio
import logging
import random
import traceback
import torch
import os
import gc
# note: there is no HunyuanImageToVideoPipeline yet in Diffusers
from diffusers import HunyuanVideoPipeline, HunyuanVideoTransformer3DModel, FasterCacheConfig
from diffusers.hooks import apply_enhance_a_video, EnhanceAVideoConfig
from varnish import Varnish
from varnish.utils import is_truthy, process_input_image
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Check environment variable for pipeline support
support_image_prompt = is_truthy(os.getenv("SUPPORT_INPUT_IMAGE_PROMPT"))
@dataclass
class GenerationConfig:
"""Configuration for video generation"""
# Content settings
prompt: str
negative_prompt: str = ""
# Model settings
num_frames: int = 49 # Should be 4k + 1 format
height: int = 320
width: int = 576
num_inference_steps: int = 50
guidance_scale: float = 7.0
# Reproducibility
seed: int = -1
# Varnish post-processing settings
fps: int = 30
double_num_frames: bool = False
super_resolution: bool = False
grain_amount: float = 0.0
quality: int = 18 # CRF scale (0-51, lower is better)
# Audio settings
enable_audio: bool = False
audio_prompt: str = ""
audio_negative_prompt: str = "voices, voice, talking, speaking, speech"
# TeaCache settings
enable_teacache: bool = False
teacache_threshold: float = 0.15 # values: 0 (original), 0.1 (1.6x speedup), 0.15 (2.1x speedup)
# Enhance-A-Video settings
enable_enhance_a_video: bool = False
enhance_a_video_weight: float = 5.0
# LoRA settings
lora_model_name: str = "" # HuggingFace repo ID or path to LoRA model
lora_model_weight_file: str = "" # Specific weight file to load from the LoRA model
lora_model_trigger: str = "" # Optional trigger word to prepend to the prompt
def validate_and_adjust(self) -> 'GenerationConfig':
"""Validate and adjust parameters"""
# Ensure num_frames follows 4k + 1 format
k = (self.num_frames - 1) // 4
self.num_frames = (k * 4) + 1
# Set random seed if not specified
if self.seed == -1:
self.seed = random.randint(0, 2**32 - 1)
return self
class EndpointHandler:
"""Handles video generation requests using HunyuanVideo and Varnish"""
def __init__(self, path: str = ""):
"""Initialize handler with models
Args:
path: Path to model weights
"""
self.device = "cuda" if torch.cuda.is_available() else "cpu"
# Initialize transformer with Enhance-A-Video injection first
transformer = HunyuanVideoTransformer3DModel.from_pretrained(
path,
subfolder="transformer",
torch_dtype=torch.bfloat16
)
if support_image_prompt:
raise Exception("Please use a version of Diffusers that supports HunyuanImageToVideoPipeline")
# # Initialize image-to-video pipeline
# self.image_to_video = HunyuanImageToVideoPipeline.from_pretrained(
# path,
# transformer=transformer,
# torch_dtype=torch.float16,
# ).to(self.device)
#
# # Initialize components in appropriate precision
# self.image_to_video.text_encoder = self.image_to_video.text_encoder.half()
# self.image_to_video.text_encoder_2 = self.image_to_video.text_encoder_2.half()
# self.image_to_video.transformer = self.image_to_video.transformer.to(torch.bfloat16)
# self.image_to_video.vae = self.image_to_video.vae.half()
# apply_enhance_a_video(self.image_to_video.transformer, EnhanceAVideoConfig(
# weight=config.enhance_a_video_weight if config.enable_enhance_a_video else 0.0,
# num_frames_callback=lambda: (config.num_frames - 1),
# _attention_type=1
# ))
else:
# Initialize text-to-video pipeline
self.text_to_video = HunyuanVideoPipeline.from_pretrained(
path,
transformer=transformer,
torch_dtype=torch.float16,
).to(self.device)
# Initialize components in appropriate precision
self.text_to_video.text_encoder = self.text_to_video.text_encoder.half()
self.text_to_video.text_encoder_2 = self.text_to_video.text_encoder_2.half()
self.text_to_video.transformer = self.text_to_video.transformer.to(torch.bfloat16)
self.text_to_video.vae = self.text_to_video.vae.half()
# apply_enhance_a_video(self.text_to_video.transformer, EnhanceAVideoConfig(
# # weight=config.enhance_a_video_weight if config.enable_enhance_a_video else 0.0,
# weight=config.enhance_a_video_weight,
# num_frames_callback=lambda: (config.num_frames - 1),
# _attention_type=1
# ))
# enable FasterCache
# those values are coming from here:
# https://github.com/huggingface/diffusers/pull/10163/files#diff-777f4ee62cb325371233a450e0f6cc0ba357a3fade2ec2dea912260b4f8d08ceR67-R74
faster_cache_config = FasterCacheConfig(
current_timestep_callback=lambda: self.text_to_video.current_timestep,
spatial_attention_block_skip_range=2,
spatial_attention_timestep_skip_range=(-1, 901),
unconditional_batch_skip_range=2,
attention_weight_callback=lambda _: 0.5,
is_guidance_distilled=True,
# do we need to uncomment those?
#unconditional_batch_timestep_skip_range=(-1, 901),
#tensor_format="BFCHW",
)
#self.text_to_video.transformer.enable_cache(faster_cache_config)
# Initialize LoRA tracking
self._current_lora_model = None
# Initialize Varnish for post-processing
self.varnish = Varnish(
device=self.device,
model_base_dir="/repository/varnish"
)
async def process_frames(
self,
frames: torch.Tensor,
config: GenerationConfig
) -> tuple[str, dict]:
"""Post-process generated frames using Varnish
Args:
frames: Generated video frames tensor
config: Generation configuration
Returns:
Tuple of (video data URI, metadata dictionary)
"""
try:
# Process video with Varnish
result = await self.varnish(
input_data=frames,
fps=config.fps,
double_num_frames=config.double_num_frames,
super_resolution=config.super_resolution,
grain_amount=config.grain_amount,
enable_audio=config.enable_audio,
audio_prompt=config.audio_prompt,
audio_negative_prompt=config.audio_negative_prompt
)
# Convert to data URI
video_uri = await result.write(type="data-uri", quality=config.quality)
# Collect metadata
metadata = {
"width": result.metadata.width,
"height": result.metadata.height,
"num_frames": result.metadata.frame_count,
"fps": result.metadata.fps,
"duration": result.metadata.duration,
"seed": config.seed,
"enable_teacache": config.enable_teacache,
"teacache_threshold": config.teacache_threshold if config.enable_teacache else 0,
"enable_enhance_a_video": config.enable_enhance_a_video,
"enhance_a_video_weight": config.enhance_a_video_weight if config.enable_enhance_a_video else 0,
}
return video_uri, metadata
except Exception as e:
logger.error(f"Error in process_frames: {str(e)}")
raise RuntimeError(f"Failed to process frames: {str(e)}")
def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
"""Process video generation requests
Args:
data: Request data containing:
- inputs (str): Prompt for video generation
- parameters (dict): Generation parameters
Returns:
Dictionary containing:
- video: Base64 encoded MP4 data URI
- content-type: MIME type
- metadata: Generation metadata
"""
# Extract inputs
inputs = data.pop("inputs", data)
if isinstance(inputs, dict):
prompt = inputs.get("prompt", "")
else:
prompt = inputs
params = data.get("parameters", {})
# Create and validate config
config = GenerationConfig(
prompt=prompt,
negative_prompt=params.get("negative_prompt", ""),
num_frames=params.get("num_frames", 49),
height=params.get("height", 320),
width=params.get("width", 576),
num_inference_steps=params.get("num_inference_steps", 50),
guidance_scale=params.get("guidance_scale", 7.0),
seed=params.get("seed", -1),
fps=params.get("fps", 30),
double_num_frames=params.get("double_num_frames", False),
super_resolution=params.get("super_resolution", False),
grain_amount=params.get("grain_amount", 0.0),
quality=params.get("quality", 18),
enable_audio=params.get("enable_audio", False),
audio_prompt=params.get("audio_prompt", ""),
audio_negative_prompt=params.get("audio_negative_prompt", "voices, voice, talking, speaking, speech"),
enable_teacache=params.get("enable_teacache", False),
# values: 0 (original), 0.1 (1.6x speedup), 0.15 (2.1x speedup).
teacache_threshold=params.get("teacache_threshold", 0.15),
enable_enhance_a_video=params.get("enable_enhance_a_video", False),
enhance_a_video_weight=params.get("enhance_a_video_weight", 5.0),
lora_model_name=params.get("lora_model_name", ""),
lora_model_weight_file=params.get("lora_model_weight_file", ""),
lora_model_trigger=params.get("lora_model_trigger", ""),
).validate_and_adjust()
try:
# Set random seeds
if config.seed != -1:
torch.manual_seed(config.seed)
random.seed(config.seed)
generator = torch.Generator(device=self.device).manual_seed(config.seed)
else:
generator = None
# Configure TeaCache
#if config.enable_teacache:
# enable_teacache(
# self.pipeline.transformer,
# num_inference_steps=config.num_inference_steps,
# rel_l1_thresh=config.teacache_threshold
# )
#else:
# disable_teacache(self.pipeline.transformer)
with torch.amp.autocast_mode.autocast('cuda', torch.bfloat16), torch.no_grad(), torch.inference_mode():
# Prepare generation parameters
generation_kwargs = {
"prompt": config.prompt,
# Failed to generate video: HunyuanVideoPipeline.__call__() got an unexpected keyword argument 'negative_prompt'
#"negative_prompt": config.negative_prompt,
"num_frames": config.num_frames,
"height": config.height,
"width": config.width,
"num_inference_steps": config.num_inference_steps,
"guidance_scale": config.guidance_scale,
"generator": generator,
"output_type": "pt",
}
# Handle LoRA loading/unloading
if hasattr(self, '_current_lora_model'):
if self._current_lora_model != (config.lora_model_name, config.lora_model_weight_file):
# Unload previous LoRA if it exists and is different
if support_image_prompt and hasattr(self.image_to_video, 'unload_lora_weights'):
self.image_to_video.unload_lora_weights()
else:
if hasattr(self.text_to_video, 'unload_lora_weights'):
self.text_to_video.unload_lora_weights()
if config.lora_model_name:
# Load new LoRA
if support_image_prompt and hasattr(self.image_to_video, 'load_lora_weights'):
self.image_to_video.load_lora_weights(
config.lora_model_name,
weight_name=config.lora_model_weight_file if config.lora_model_weight_file else None,
token=hf_token,
)
else:
if hasattr(self.text_to_video, 'load_lora_weights'):
self.text_to_video.load_lora_weights(
config.lora_model_name,
weight_name=config.lora_model_weight_file if config.lora_model_weight_file else None,
token=hf_token,
)
self._current_lora_model = (config.lora_model_name, config.lora_model_weight_file)
# Modify prompt if trigger word is provided
if config.lora_model_trigger:
generation_kwargs["prompt"] = f"{config.lora_model_trigger} {generation_kwargs['prompt']}"
# Check if image-to-video generation is requested
if support_image_prompt and input_image:
processed_image = process_input_image(
input_image,
config.width,
config.height,
config.input_image_quality,
)
generation_kwargs["image"] = processed_image
frames = self.image_to_video(**generation_kwargs).frames
else:
frames = self.text_to_video(**generation_kwargs).frames
try:
loop = asyncio.get_event_loop()
except RuntimeError:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
video_uri, metadata = loop.run_until_complete(self.process_frames(frames, config))
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()
gc.collect()
return {
"video": video_uri,
"content-type": "video/mp4",
"metadata": metadata
}
except Exception as e:
message = f"Error generating video ({str(e)})\n{traceback.format_exc()}"
logger.error(message)
raise RuntimeError(message) |