# teacache.py: TeaCache optimization hooks for diffusers' LTXVideoTransformer3DModel.
from typing import Any, Optional, Type
import numpy as np
import torch
from diffusers.models.transformers import LTXVideoTransformer3DModel
from diffusers.utils import (
USE_PEFT_BACKEND,
is_torch_version,
scale_lora_layers,
unscale_lora_layers
)


class TeaCacheConfig:
    """Configuration and runtime state for TeaCache optimization."""
def __init__(
self,
enabled: bool = True,
rel_l1_thresh: float = 0.05, # 0.03 for 1.6x speedup, 0.05 for 2.1x speedup
num_inference_steps: int = 50
):
self.enabled = enabled
self.rel_l1_thresh = rel_l1_thresh
self.num_inference_steps = num_inference_steps
# Internal state
self.cnt = 0
self.accumulated_rel_l1_distance = 0
self.previous_modulated_input = None
self.previous_residual = None
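
# A minimal construction sketch for the config above; the values simply restate the
# defaults and are not tuned recommendations:
#
#   config = TeaCacheConfig(
#       enabled=True,
#       rel_l1_thresh=0.05,      # 0.03 trades some speedup for fidelity closer to baseline
#       num_inference_steps=50,  # drives the counter reset inside teacache_forward
#   )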


def create_teacache_forward(original_forward: Any):
    """Factory that builds a TeaCache-enabled forward pass for LTXVideoTransformer3DModel.

    ``original_forward`` is accepted so callers can keep a handle on the unpatched
    method, but it is not invoked here; the returned function re-implements the
    forward pass with caching added.
    """
def teacache_forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
timestep: torch.LongTensor,
encoder_attention_mask: torch.Tensor,
num_frames: int,
height: int,
width: int,
rope_interpolation_scale: Optional[tuple[float, float, float]] = None,
attention_kwargs: Optional[dict[str, Any]] = None,
return_dict: bool = True,
    ) -> Any:
# Handle LoRA scaling
if attention_kwargs is not None:
attention_kwargs = attention_kwargs.copy()
lora_scale = attention_kwargs.pop("scale", 1.0)
else:
attention_kwargs = {}
lora_scale = 1.0
if USE_PEFT_BACKEND:
scale_lora_layers(self, lora_scale)
# Initial processing
image_rotary_emb = self.rope(hidden_states, num_frames, height, width, rope_interpolation_scale)
if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
batch_size = hidden_states.size(0)
hidden_states = self.proj_in(hidden_states)
# Time embedding
temb, embedded_timestep = self.time_embed(
timestep.flatten(),
batch_size=batch_size,
hidden_dtype=hidden_states.dtype,
)
temb = temb.view(batch_size, -1, temb.size(-1))
embedded_timestep = embedded_timestep.view(batch_size, -1, embedded_timestep.size(-1))
# Caption projection
encoder_hidden_states = self.caption_projection(encoder_hidden_states)
encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.size(-1))
# TeaCache optimization logic
should_calc = True
if hasattr(self, 'teacache_config') and self.teacache_config.enabled:
inp = hidden_states.clone()
temb_ = temb.clone()
inp = self.transformer_blocks[0].norm1(inp)
num_ada_params = self.transformer_blocks[0].scale_shift_table.shape[0]
ada_values = (
self.transformer_blocks[0].scale_shift_table[None, None] +
temb_.reshape(batch_size, temb_.size(1), num_ada_params, -1)
)
shift_msa, scale_msa, *_ = ada_values.unbind(dim=2)
modulated_inp = inp * (1 + scale_msa) + shift_msa
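            # The modulated input of the first block serves as a cheap proxy signal:
            # how much it changes between steps stands in for how much the full block
            # stack's output would change, and that change drives the threshold test below.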
# Determine if we should calculate or reuse
if self.teacache_config.cnt == 0 or self.teacache_config.cnt == self.teacache_config.num_inference_steps - 1:
should_calc = True
self.teacache_config.accumulated_rel_l1_distance = 0
else:
# Polynomial coefficients for rescaling
coefficients = [2.14700694e+01, -1.28016453e+01, 2.31279151e+00, 7.92487521e-01, 9.69274326e-03]
rescale_func = np.poly1d(coefficients)
rel_diff = (
(modulated_inp - self.teacache_config.previous_modulated_input).abs().mean() /
self.teacache_config.previous_modulated_input.abs().mean()
).cpu().item()
self.teacache_config.accumulated_rel_l1_distance += rescale_func(rel_diff)
if self.teacache_config.accumulated_rel_l1_distance < self.teacache_config.rel_l1_thresh:
should_calc = False
else:
should_calc = True
self.teacache_config.accumulated_rel_l1_distance = 0
self.teacache_config.previous_modulated_input = modulated_inp
self.teacache_config.cnt += 1
if self.teacache_config.cnt == self.teacache_config.num_inference_steps:
self.teacache_config.cnt = 0
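            # Worked example (illustrative numbers): with rel_l1_thresh=0.05, if the
            # rescaled distances at three consecutive steps are 0.02, 0.02, and 0.03,
            # the first two steps skip the blocks (accumulator at 0.02, then 0.04) and
            # the third recomputes (0.04 + 0.03 >= 0.05), resetting the accumulator.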
# Process hidden states
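        # Two paths: if the threshold test decided to skip, approximate this step by
        # adding the residual cached at the last fully computed step; otherwise run the
        # transformer blocks and cache a fresh residual (output minus the cloned input)
        # for future skipped steps.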
if hasattr(self, 'teacache_config') and self.teacache_config.enabled and not should_calc:
hidden_states += self.teacache_config.previous_residual
else:
ori_hidden_states = hidden_states.clone() if hasattr(self, 'teacache_config') and self.teacache_config.enabled else None
for block in self.transformer_blocks:
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
encoder_hidden_states,
temb,
image_rotary_emb,
encoder_attention_mask,
**ckpt_kwargs,
)
else:
hidden_states = block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
temb=temb,
image_rotary_emb=image_rotary_emb,
encoder_attention_mask=encoder_attention_mask,
)
scale_shift_values = self.scale_shift_table[None, None] + embedded_timestep[:, :, None]
shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1]
hidden_states = self.norm_out(hidden_states)
hidden_states = hidden_states * (1 + scale) + shift
if hasattr(self, 'teacache_config') and self.teacache_config.enabled:
self.teacache_config.previous_residual = hidden_states - ori_hidden_states
output = self.proj_out(hidden_states)
if USE_PEFT_BACKEND:
unscale_lora_layers(self, lora_scale)
if not return_dict:
return (output,)
return {"sample": output}
return teacache_forward


def enable_teacache(model_class: Type[LTXVideoTransformer3DModel], config: TeaCacheConfig) -> None:
    """Enable TeaCache optimization for a model class.

    Patches ``forward`` at the class level, so every instance of the class uses the
    cached forward pass and shares the same ``TeaCacheConfig`` state.

    Args:
        model_class: The model class to patch.
        config: TeaCache configuration.
    """
# Store original forward method if needed
if not hasattr(model_class, '_original_forward'):
model_class._original_forward = model_class.forward
# Create new forward method with TeaCache
model_class.forward = create_teacache_forward(model_class._original_forward)
# Add config attribute to class
model_class.teacache_config = config


def disable_teacache(model_class: Type[LTXVideoTransformer3DModel]) -> None:
    """Disable TeaCache optimization for a model class.

    Restores the original ``forward`` and removes the attached configuration.

    Args:
        model_class: The model class to unpatch.
    """
if hasattr(model_class, '_original_forward'):
model_class.forward = model_class._original_forward
delattr(model_class, '_original_forward')
if hasattr(model_class, 'teacache_config'):
delattr(model_class, 'teacache_config')
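

# End-to-end usage sketch (assumptions: diffusers' LTXPipeline is available and the
# checkpoint id below is valid; illustrative only, not part of this module's API):
#
#   import torch
#   from diffusers import LTXPipeline
#
#   pipe = LTXPipeline.from_pretrained("Lightricks/LTX-Video", torch_dtype=torch.bfloat16)
#   pipe.to("cuda")
#   enable_teacache(
#       LTXVideoTransformer3DModel,
#       TeaCacheConfig(rel_l1_thresh=0.05, num_inference_steps=50),
#   )
#   frames = pipe(prompt="a sailboat gliding at sunset", num_inference_steps=50).frames[0]
#   disable_teacache(LTXVideoTransformer3DModel)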