from typing import Any, Optional, Type

import numpy as np
import torch
import torch.utils.checkpoint
from diffusers.models.transformers import LTXVideoTransformer3DModel
from diffusers.utils import (
    USE_PEFT_BACKEND,
    is_torch_version,
    scale_lora_layers,
    unscale_lora_layers,
)


class TeaCacheConfig:
    """Configuration and runtime state for TeaCache optimization."""

    def __init__(
        self,
        enabled: bool = True,
        rel_l1_thresh: float = 0.05,  # 0.03 for ~1.6x speedup, 0.05 for ~2.1x speedup
        num_inference_steps: int = 50,
    ):
        self.enabled = enabled
        self.rel_l1_thresh = rel_l1_thresh
        self.num_inference_steps = num_inference_steps

        # Internal state updated by the patched forward pass
        self.cnt = 0
        self.accumulated_rel_l1_distance = 0.0
        self.previous_modulated_input = None
        self.previous_residual = None
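
# Illustrative threshold choices (a sketch based on the comment above; actual
# speedups depend on the model, prompt, and number of inference steps):
#   TeaCacheConfig(rel_l1_thresh=0.03)  # fewer skipped steps, ~1.6x speedup
#   TeaCacheConfig(rel_l1_thresh=0.05)  # more skipped steps, ~2.1x speedup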


def create_teacache_forward(original_forward: Any):
    """Factory returning a TeaCache-enabled replacement for the model's forward pass.

    Note: ``original_forward`` is not invoked here; it is accepted so the caller can
    keep a reference to the method being replaced.
    """

    def teacache_forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        timestep: torch.LongTensor,
        encoder_attention_mask: torch.Tensor,
        num_frames: int,
        height: int,
        width: int,
        rope_interpolation_scale: Optional[tuple[float, float, float]] = None,
        attention_kwargs: Optional[dict[str, Any]] = None,
        return_dict: bool = True,
    ) -> torch.Tensor:
        # Handle LoRA scaling
        if attention_kwargs is not None:
            attention_kwargs = attention_kwargs.copy()
            lora_scale = attention_kwargs.pop("scale", 1.0)
        else:
            lora_scale = 1.0
        if USE_PEFT_BACKEND:
            scale_lora_layers(self, lora_scale)

        # Rotary positional embeddings
        image_rotary_emb = self.rope(hidden_states, num_frames, height, width, rope_interpolation_scale)

        # Convert a 2D padding mask into an additive attention bias
        if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
            encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
            encoder_attention_mask = encoder_attention_mask.unsqueeze(1)

        batch_size = hidden_states.size(0)
        hidden_states = self.proj_in(hidden_states)

        # Time embedding
        temb, embedded_timestep = self.time_embed(
            timestep.flatten(),
            batch_size=batch_size,
            hidden_dtype=hidden_states.dtype,
        )
        temb = temb.view(batch_size, -1, temb.size(-1))
        embedded_timestep = embedded_timestep.view(batch_size, -1, embedded_timestep.size(-1))

        # Caption projection
        encoder_hidden_states = self.caption_projection(encoder_hidden_states)
        encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.size(-1))

        # TeaCache: measure how much the modulated input drifted since the previous
        # step and decide whether the transformer blocks can be skipped by reusing
        # the cached residual.
        should_calc = True
        if hasattr(self, 'teacache_config') and self.teacache_config.enabled:
            inp = hidden_states.clone()
            temb_ = temb.clone()
            inp = self.transformer_blocks[0].norm1(inp)
            num_ada_params = self.transformer_blocks[0].scale_shift_table.shape[0]
            ada_values = (
                self.transformer_blocks[0].scale_shift_table[None, None]
                + temb_.reshape(batch_size, temb_.size(1), num_ada_params, -1)
            )
            shift_msa, scale_msa, *_ = ada_values.unbind(dim=2)
            modulated_inp = inp * (1 + scale_msa) + shift_msa

            # Always compute on the first and last step; otherwise accumulate the
            # rescaled relative L1 distance and skip while it stays below threshold.
            if self.teacache_config.cnt == 0 or self.teacache_config.cnt == self.teacache_config.num_inference_steps - 1:
                should_calc = True
                self.teacache_config.accumulated_rel_l1_distance = 0
            else:
                # Polynomial coefficients for rescaling the raw relative L1 distance
                coefficients = [2.14700694e+01, -1.28016453e+01, 2.31279151e+00, 7.92487521e-01, 9.69274326e-03]
                rescale_func = np.poly1d(coefficients)
                rel_diff = (
                    (modulated_inp - self.teacache_config.previous_modulated_input).abs().mean()
                    / self.teacache_config.previous_modulated_input.abs().mean()
                ).cpu().item()
                self.teacache_config.accumulated_rel_l1_distance += rescale_func(rel_diff)
                if self.teacache_config.accumulated_rel_l1_distance < self.teacache_config.rel_l1_thresh:
                    should_calc = False
                else:
                    should_calc = True
                    self.teacache_config.accumulated_rel_l1_distance = 0

            self.teacache_config.previous_modulated_input = modulated_inp
            self.teacache_config.cnt += 1
            if self.teacache_config.cnt == self.teacache_config.num_inference_steps:
                self.teacache_config.cnt = 0

        # Either reuse the cached residual (skip) or run the transformer blocks
        if hasattr(self, 'teacache_config') and self.teacache_config.enabled and not should_calc:
            hidden_states += self.teacache_config.previous_residual
        else:
            ori_hidden_states = hidden_states.clone() if hasattr(self, 'teacache_config') and self.teacache_config.enabled else None

            for block in self.transformer_blocks:
                if torch.is_grad_enabled() and self.gradient_checkpointing:

                    def create_custom_forward(module, return_dict=None):
                        def custom_forward(*inputs):
                            if return_dict is not None:
                                return module(*inputs, return_dict=return_dict)
                            else:
                                return module(*inputs)

                        return custom_forward

                    ckpt_kwargs: dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                    hidden_states = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(block),
                        hidden_states,
                        encoder_hidden_states,
                        temb,
                        image_rotary_emb,
                        encoder_attention_mask,
                        **ckpt_kwargs,
                    )
                else:
                    hidden_states = block(
                        hidden_states=hidden_states,
                        encoder_hidden_states=encoder_hidden_states,
                        temb=temb,
                        image_rotary_emb=image_rotary_emb,
                        encoder_attention_mask=encoder_attention_mask,
                    )

            # Output modulation
            scale_shift_values = self.scale_shift_table[None, None] + embedded_timestep[:, :, None]
            shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1]
            hidden_states = self.norm_out(hidden_states)
            hidden_states = hidden_states * (1 + scale) + shift

            # Cache the residual so skipped steps can reuse it
            if hasattr(self, 'teacache_config') and self.teacache_config.enabled:
                self.teacache_config.previous_residual = hidden_states - ori_hidden_states

        output = self.proj_out(hidden_states)

        if USE_PEFT_BACKEND:
            unscale_lora_layers(self, lora_scale)

        if not return_dict:
            return (output,)
        return {"sample": output}

    return teacache_forward


def enable_teacache(model_class: Type[LTXVideoTransformer3DModel], config: TeaCacheConfig) -> None:
    """Enable TeaCache optimization for a model class.

    Args:
        model_class: The model class to patch.
        config: TeaCache configuration.
    """
    # Keep a reference to the original forward so it can be restored later
    if not hasattr(model_class, '_original_forward'):
        model_class._original_forward = model_class.forward

    # Replace the class-level forward with the TeaCache-enabled version
    model_class.forward = create_teacache_forward(model_class._original_forward)

    # Attach the configuration (and its runtime state) to the class
    model_class.teacache_config = config


def disable_teacache(model_class: Type[LTXVideoTransformer3DModel]) -> None:
    """Disable TeaCache optimization for a model class.

    Args:
        model_class: The model class to unpatch.
    """
    if hasattr(model_class, '_original_forward'):
        model_class.forward = model_class._original_forward
        delattr(model_class, '_original_forward')

    if hasattr(model_class, 'teacache_config'):
        delattr(model_class, 'teacache_config')
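

if __name__ == "__main__":
    # Usage sketch (not part of the original module): shows how the patch functions
    # above would typically be wired into a diffusers LTX-Video pipeline. The
    # checkpoint id "Lightricks/LTX-Video" and the LTXPipeline call are assumptions;
    # adjust them to a diffusers version whose LTX transformer forward signature
    # matches the one patched above.
    from diffusers import LTXPipeline

    config = TeaCacheConfig(rel_l1_thresh=0.05, num_inference_steps=50)
    enable_teacache(LTXVideoTransformer3DModel, config)

    pipe = LTXPipeline.from_pretrained("Lightricks/LTX-Video", torch_dtype=torch.bfloat16)
    pipe.to("cuda")

    frames = pipe(
        prompt="A waterfall cascading through a misty forest",
        num_inference_steps=config.num_inference_steps,
    ).frames[0]

    # Restore the stock forward pass once TeaCache is no longer needed.
    disable_teacache(LTXVideoTransformer3DModel)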