# HunyuanVideo-HFIE / teacache.py
import torch
import numpy as np
from typing import Optional, Dict, Union, Any
from functools import wraps


class TeaCacheConfig:
    """Configuration for TeaCache acceleration"""
    def __init__(
        self,
        rel_l1_thresh: float = 0.15,
        enable: bool = True
    ):
        self.rel_l1_thresh = rel_l1_thresh
        self.enable = enable
        self._reset_state()

    def _reset_state(self):
        """Reset internal state"""
        self.cnt = 0
        self.accumulated_rel_l1_distance = 0
        self.previous_modulated_input = None
        self.previous_residual = None


def create_teacache_forward(original_forward):
    """
    Factory function to create a TeaCache-enabled forward pass

    Args:
        original_forward: Original forward method to wrap (expected to be the
            unbound function, so the wrapper can pass `self` explicitly)

    Returns:
        Wrapped forward method with TeaCache acceleration
    """
    @wraps(original_forward)
    def teacache_forward(
        self,
        x: torch.Tensor,
        t: torch.Tensor,
        text_states: Optional[torch.Tensor] = None,
        text_mask: Optional[torch.Tensor] = None,
        text_states_2: Optional[torch.Tensor] = None,
        freqs_cos: Optional[torch.Tensor] = None,
        freqs_sin: Optional[torch.Tensor] = None,
        guidance: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
        # Skip TeaCache if not enabled
        if not hasattr(self, 'teacache_config') or not self.teacache_config.enable:
            return original_forward(self, x, t, text_states, text_mask, text_states_2,
                                    freqs_cos, freqs_sin, guidance, return_dict)

        config = self.teacache_config
        out = {}
        img = x
        txt = text_states

        # Calculate patch dimensions
        _, _, ot, oh, ow = x.shape
        tt, th, tw = (
            ot // self.patch_size[0],
            oh // self.patch_size[1],
            ow // self.patch_size[2],
        )

        # Prepare modulation vectors
        vec = self.time_in(t)
        vec = vec + self.vector_in(text_states_2)
        if self.guidance_embed:
            if guidance is None:
                raise ValueError("Guidance strength required for guidance distilled model")
            vec = vec + self.guidance_in(guidance)

        # Embed image and text
        img = self.img_in(img)
        if hasattr(self, 'text_projection'):
            if self.text_projection == "linear":
                txt = self.txt_in(txt)
            elif self.text_projection == "single_refiner":
                txt = self.txt_in(txt, t, text_mask if self.use_attention_mask else None)
            else:
                raise NotImplementedError(f"Unsupported text_projection: {self.text_projection}")
        else:
            txt = self.txt_in(txt)
        # TeaCache optimization logic
        inp = img.clone()
        vec_ = vec.clone()

        # Get modulation parameters from the first double block
        mod_params = self.double_blocks[0].img_mod(vec_).chunk(6, dim=-1)
        img_mod1_shift, img_mod1_scale = mod_params[0], mod_params[1]

        # Calculate modulated input; the per-sample shift/scale are broadcast
        # over the sequence dimension
        normed_inp = self.double_blocks[0].img_norm1(inp)
        modulated_inp = (normed_inp * (1 + img_mod1_scale.unsqueeze(1))
                         + img_mod1_shift.unsqueeze(1))

        # Determine if we should calculate or use cache
        should_calc = True
        if config.cnt == 0 or config.cnt == self.num_inference_steps - 1:
            # Always recompute on the first and last steps
            should_calc = True
            config.accumulated_rel_l1_distance = 0
        elif config.previous_modulated_input is not None:
            # Empirical polynomial coefficients for rescaling the raw distance
            coefficients = [7.33226126e+02, -4.01131952e+02, 6.75869174e+01,
                            -3.14987800e+00, 9.61237896e-02]
            rescale_func = np.poly1d(coefficients)

            # Calculate relative L1 distance to the previous modulated input
            rel_l1 = ((modulated_inp - config.previous_modulated_input).abs().mean() /
                      config.previous_modulated_input.abs().mean()).cpu().item()
            config.accumulated_rel_l1_distance += rescale_func(rel_l1)

            # Recompute only once the accumulated distance crosses the threshold
            should_calc = config.accumulated_rel_l1_distance >= config.rel_l1_thresh
            if should_calc:
                config.accumulated_rel_l1_distance = 0

        config.previous_modulated_input = modulated_inp
        config.cnt += 1
        if config.cnt >= self.num_inference_steps:
            config.cnt = 0
        # Use cache or calculate new result
        if not should_calc and config.previous_residual is not None:
            img += config.previous_residual
        else:
            ori_img = img.clone()

            # Process through transformer blocks
            txt_seq_len = txt.shape[1]
            img_seq_len = img.shape[1]

            # Original processing logic
            for block in self.double_blocks:
                img, txt = block(
                    img, txt, vec,
                    None, None, img_seq_len + txt_seq_len,
                    img_seq_len + txt_seq_len,
                    (freqs_cos, freqs_sin) if freqs_cos is not None else None
                )

            if hasattr(self, 'single_blocks') and self.single_blocks:
                x = torch.cat((img, txt), 1)
                for block in self.single_blocks:
                    x = block(
                        x, vec, txt_seq_len,
                        None, None, img_seq_len + txt_seq_len,
                        img_seq_len + txt_seq_len,
                        (freqs_cos, freqs_sin)
                    )
                img = x[:, :img_seq_len, ...]

            # Store residual for future use
            config.previous_residual = img - ori_img

        # Final processing
        img = self.final_layer(img, vec)
        img = self.unpatchify(img, tt, th, tw)

        if return_dict:
            out["x"] = img
            return out
        return img

    return teacache_forward
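
# Note: instead of patching a single instance (see enable_teacache below), the wrapper
# can also be installed at the class level, e.g. assuming the HunyuanVideo transformer
# class is named HYVideoDiffusionTransformer:
#
#     HYVideoDiffusionTransformer.forward = create_teacache_forward(
#         HYVideoDiffusionTransformer.forward
#     )
#
# Each instance then still needs `teacache_config` and `num_inference_steps` attributes
# before the cached path activates.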


def enable_teacache(model: Any, num_inference_steps: int, rel_l1_thresh: float = 0.15):
    """
    Enable TeaCache acceleration for a model

    Args:
        model: The transformer model to accelerate
        num_inference_steps: Number of inference steps
        rel_l1_thresh: Relative L1 threshold for cache usage
    """
    # Store original forward method so it can be restored later
    if not hasattr(model, '_original_forward'):
        model._original_forward = model.forward

    # Create and attach TeaCache config
    model.teacache_config = TeaCacheConfig(rel_l1_thresh=rel_l1_thresh)

    # Set inference steps
    model.num_inference_steps = num_inference_steps

    # Replace forward with the TeaCache version. The wrapper calls
    # original_forward(self, ...), so unwrap the bound method to the plain
    # function before wrapping, then bind the wrapper back onto the instance.
    original_fn = getattr(model._original_forward, '__func__', model._original_forward)
    model.forward = create_teacache_forward(original_fn).__get__(model)


def disable_teacache(model: Any):
    """
    Disable TeaCache acceleration for a model

    Args:
        model: The transformer model to restore
    """
    if hasattr(model, '_original_forward'):
        model.forward = model._original_forward
        del model._original_forward

    if hasattr(model, 'teacache_config'):
        del model.teacache_config
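

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the upstream pipeline). On a loaded
# HunyuanVideo-style pipeline, enabling TeaCache would look roughly like:
#
#     enable_teacache(pipeline.transformer, num_inference_steps=50, rel_l1_thresh=0.15)
#     # ...run sampling as usual...
#     disable_teacache(pipeline.transformer)
#
# where `pipeline.transformer` is a hypothetical handle to the transformer.
# The self-contained demo below only exercises the cache-gating arithmetic
# (relative L1 distance -> polynomial rescale -> threshold) on random tensors.
if __name__ == "__main__":
    config = TeaCacheConfig(rel_l1_thresh=0.15)
    rescale_func = np.poly1d([7.33226126e+02, -4.01131952e+02, 6.75869174e+01,
                              -3.14987800e+00, 9.61237896e-02])

    torch.manual_seed(0)
    previous = torch.randn(1, 16, 32)  # stand-in for the previous modulated input
    for step in range(5):
        # Small perturbation standing in for the next step's modulated input
        current = previous + 0.01 * torch.randn_like(previous)
        rel_l1 = ((current - previous).abs().mean() / previous.abs().mean()).item()
        config.accumulated_rel_l1_distance += rescale_func(rel_l1)
        should_calc = config.accumulated_rel_l1_distance >= config.rel_l1_thresh
        print(f"step {step}: rel_l1={rel_l1:.4f} "
              f"accumulated={config.accumulated_rel_l1_distance:.4f} recompute={bool(should_calc)}")
        if should_calc:
            config.accumulated_rel_l1_distance = 0
        previous = current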