|
from typing import Any, Optional, Type |
|
import numpy as np |
|
import torch |
|
from diffusers.models.transformers import LTXVideoTransformer3DModel |
|
from diffusers.utils import ( |
|
USE_PEFT_BACKEND, |
|
is_torch_version, |
|
scale_lora_layers, |
|
unscale_lora_layers |
|
) |
|
|
|
class TeaCacheConfig:
    """Settings and mutable per-run state used by the TeaCache forward patch.

    Args:
        enabled: When False, the patched forward always runs every transformer
            block and never reads or writes the cached state.
        rel_l1_thresh: Threshold for the accumulated, rescaled relative L1
            change of the first block's modulated input; while the accumulated
            value stays below it, block computation is skipped.
        num_inference_steps: Number of denoising steps per run. The step
            counter wraps at this value, and the first and last steps of a
            run always compute the blocks.
    """

    def __init__(
        self,
        enabled: bool = True,
        rel_l1_thresh: float = 0.05,
        num_inference_steps: int = 50
    ):
        # User-facing knobs.
        self.enabled, self.rel_l1_thresh, self.num_inference_steps = (
            enabled,
            rel_l1_thresh,
            num_inference_steps,
        )
        # Cache state, mutated by the patched forward on every call.
        self._reset_state()

    def _reset_state(self) -> None:
        """Put the cache bookkeeping back to its pristine, empty values."""
        # Index of the current step within the run; wraps at num_inference_steps.
        self.cnt = 0
        # Rescaled relative L1 change summed since the last full computation.
        self.accumulated_rel_l1_distance = 0
        # Modulated input from the previous step (None until the first step).
        self.previous_modulated_input = None
        # Residual from the last full computation (None until one has run).
        self.previous_residual = None
|
|
|
# Coefficients of the polynomial that maps the raw relative L1 change of the
# modulated input to the value accumulated against ``rel_l1_thresh``.
# NOTE(review): presumably fitted offline for LTX-Video; confirm the source
# before changing them.
_TEACACHE_COEFFICIENTS = [2.14700694e+01, -1.28016453e+01, 2.31279151e+00, 7.92487521e-01, 9.69274326e-03]
# Built once at import time instead of on every denoising step.
_TEACACHE_RESCALE = np.poly1d(_TEACACHE_COEFFICIENTS)


def _teacache_should_calc(config: Any, modulated_inp: torch.Tensor, hidden_shape: torch.Size) -> bool:
    """Decide whether this step must run the transformer blocks, updating ``config``.

    The blocks are always computed on the first and last step of a run.
    They are also computed whenever the cached state is missing or its shape
    no longer matches the current inputs, for example because a previous run
    was interrupted or the resolution or batch size changed. Comparing or
    reusing such a cache would raise.

    Otherwise, the rescaled relative L1 change of ``modulated_inp`` is
    accumulated. The blocks are skipped while the running total stays below
    ``config.rel_l1_thresh``.

    Args:
        config: A ``TeaCacheConfig``-like object. Its ``cnt``,
            ``accumulated_rel_l1_distance`` and ``previous_modulated_input``
            attributes are updated in place.
        modulated_inp: The first transformer block's modulated input for this step.
        hidden_shape: Shape of the hidden states the cached residual would be added to.

    Returns:
        True if the blocks must be computed, False if the cached residual can be reused.
    """
    previous = config.previous_modulated_input
    residual = config.previous_residual
    must_calc = (
        config.cnt == 0
        or config.cnt == config.num_inference_steps - 1
        or previous is None
        or previous.shape != modulated_inp.shape
        or residual is None
        or residual.shape != hidden_shape
    )

    if must_calc:
        should_calc = True
        config.accumulated_rel_l1_distance = 0
    else:
        rel_diff = ((modulated_inp - previous).abs().mean() / previous.abs().mean()).item()
        config.accumulated_rel_l1_distance += _TEACACHE_RESCALE(rel_diff)
        # Written as "< thresh" so a NaN distance falls through to a full compute.
        if config.accumulated_rel_l1_distance < config.rel_l1_thresh:
            should_calc = False
        else:
            should_calc = True
            config.accumulated_rel_l1_distance = 0

    config.previous_modulated_input = modulated_inp
    config.cnt += 1
    if config.cnt == config.num_inference_steps:
        config.cnt = 0
    return should_calc


def create_teacache_forward(original_forward: Any):
    """Build a TeaCache-enabled replacement for the LTX transformer ``forward``.

    The returned function re-implements the forward pass. When the instance
    (or its class) carries an enabled ``teacache_config``, the function may
    skip all transformer blocks on a step. In that case it adds the residual
    cached from the last fully computed step.

    Args:
        original_forward: The unpatched forward. The returned function does not
            call it; it is accepted to keep the factory signature stable.

    Returns:
        A function suitable for installing as the model class's ``forward``.
    """
    # The unpatched diffusers LTX forward returns this type; returning the same
    # type keeps ``out.sample``, ``out[0]`` and ``out["sample"]`` all working.
    from diffusers.models.modeling_outputs import Transformer2DModelOutput

    def teacache_forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        timestep: torch.LongTensor,
        encoder_attention_mask: torch.Tensor,
        num_frames: int,
        height: int,
        width: int,
        rope_interpolation_scale: Optional[tuple[float, float, float]] = None,
        attention_kwargs: Optional[dict[str, Any]] = None,
        return_dict: bool = True,
    ) -> Any:
        """Run the transformer, possibly reusing the cached residual.

        The parameters mirror the forward being replaced. A ``"scale"`` entry
        in ``attention_kwargs`` is used as the LoRA scale (default 1.0).

        Returns:
            ``(output,)`` when ``return_dict`` is False, otherwise
            ``Transformer2DModelOutput(sample=output)``.
        """
        # Only the LoRA scale is read from attention_kwargs; the caller's dict is
        # left untouched.
        lora_scale = attention_kwargs.get("scale", 1.0) if attention_kwargs is not None else 1.0

        if USE_PEFT_BACKEND:
            scale_lora_layers(self, lora_scale)
        try:
            image_rotary_emb = self.rope(hidden_states, num_frames, height, width, rope_interpolation_scale)

            # Turn a 2-D keep-mask into an additive bias (0 where the mask is 1,
            # -10000 where it is 0) with a broadcast dim inserted at position 1.
            if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
                encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
                encoder_attention_mask = encoder_attention_mask.unsqueeze(1)

            batch_size = hidden_states.size(0)
            hidden_states = self.proj_in(hidden_states)

            temb, embedded_timestep = self.time_embed(
                timestep.flatten(),
                batch_size=batch_size,
                hidden_dtype=hidden_states.dtype,
            )
            temb = temb.view(batch_size, -1, temb.size(-1))
            embedded_timestep = embedded_timestep.view(batch_size, -1, embedded_timestep.size(-1))

            encoder_hidden_states = self.caption_projection(encoder_hidden_states)
            encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.size(-1))

            config = getattr(self, "teacache_config", None)
            use_teacache = config is not None and config.enabled

            should_calc = True
            if use_teacache:
                # Reproduce the first block's shift/scale modulation of its
                # normalized input; its step-to-step change drives the skip decision.
                first_block = self.transformer_blocks[0]
                table = first_block.scale_shift_table
                ada_values = table[None, None] + temb.reshape(batch_size, temb.size(1), table.shape[0], -1)
                shift_msa, scale_msa, *_ = ada_values.unbind(dim=2)
                modulated_inp = first_block.norm1(hidden_states) * (1 + scale_msa) + shift_msa
                should_calc = _teacache_should_calc(config, modulated_inp, hidden_states.shape)

            if not should_calc:
                # Reuse the residual cached by the last fully computed step.
                hidden_states = hidden_states + config.previous_residual
            else:
                ori_hidden_states = hidden_states.clone() if use_teacache else None

                use_checkpointing = torch.is_grad_enabled() and self.gradient_checkpointing
                ckpt_kwargs: dict[str, Any] = (
                    {"use_reentrant": False} if use_checkpointing and is_torch_version(">=", "1.11.0") else {}
                )
                for block in self.transformer_blocks:
                    if use_checkpointing:
                        hidden_states = torch.utils.checkpoint.checkpoint(
                            block,
                            hidden_states,
                            encoder_hidden_states,
                            temb,
                            image_rotary_emb,
                            encoder_attention_mask,
                            **ckpt_kwargs,
                        )
                    else:
                        hidden_states = block(
                            hidden_states=hidden_states,
                            encoder_hidden_states=encoder_hidden_states,
                            temb=temb,
                            image_rotary_emb=image_rotary_emb,
                            encoder_attention_mask=encoder_attention_mask,
                        )

                scale_shift_values = self.scale_shift_table[None, None] + embedded_timestep[:, :, None]
                shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1]

                hidden_states = self.norm_out(hidden_states)
                hidden_states = hidden_states * (1 + scale) + shift

                if use_teacache:
                    # The residual spans the blocks plus the output norm/modulation,
                    # so a skipped step only needs proj_out afterwards.
                    config.previous_residual = hidden_states - ori_hidden_states

            output = self.proj_out(hidden_states)
        finally:
            # Undo the LoRA scaling even if the forward raised.
            if USE_PEFT_BACKEND:
                unscale_lora_layers(self, lora_scale)

        if not return_dict:
            return (output,)
        return Transformer2DModelOutput(sample=output)

    return teacache_forward
|
|
|
def enable_teacache(model_class: Type[LTXVideoTransformer3DModel], config: TeaCacheConfig) -> None:
    """Install the TeaCache forward on ``model_class`` and attach ``config``.

    The first call stashes the class's current ``forward`` as
    ``_original_forward``. Later calls reuse that stash, so an already-patched
    forward is never wrapped again, and ``disable_teacache`` can restore it.

    Args:
        model_class: Class (not instance) whose ``forward`` is replaced. The
            patch is a class attribute, so it affects every instance.
        config: TeaCache settings and state, stored as the class attribute
            ``teacache_config`` and therefore shared by all instances.
    """
    try:
        unpatched_forward = model_class._original_forward
    except AttributeError:
        unpatched_forward = model_class.forward
        model_class._original_forward = unpatched_forward

    model_class.forward = create_teacache_forward(unpatched_forward)
    model_class.teacache_config = config
|
|
|
def disable_teacache(model_class: Type[LTXVideoTransformer3DModel]) -> None:
    """Undo ``enable_teacache`` on ``model_class``.

    If ``_original_forward`` is present, it is put back as ``forward`` and the
    stash is deleted. A ``teacache_config`` attribute, if present, is also
    deleted. Attributes that are absent are ignored, so calling this on an
    unpatched class is harmless.

    Args:
        model_class: The class to restore.
    """
    try:
        stashed_forward = model_class._original_forward
    except AttributeError:
        pass
    else:
        model_class.forward = stashed_forward
        del model_class._original_forward

    if hasattr(model_class, 'teacache_config'):
        del model_class.teacache_config