from typing import Any, Optional, Type, Union

import numpy as np
import torch
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from diffusers.models.transformers import LTXVideoTransformer3DModel
from diffusers.utils import (
    USE_PEFT_BACKEND,
    is_torch_version,
    scale_lora_layers,
    unscale_lora_layers,
)

# Polynomial that rescales the raw relative L1 change of the first block's
# modulated input before it is accumulated (coefficients highest degree first,
# as np.poly1d expects). Built once here instead of on every denoising step.
# NOTE(review): presumably fitted specifically for LTX-Video — confirm provenance
# before reusing for another model.
_TEACACHE_RESCALE = np.poly1d(
    [2.14700694e+01, -1.28016453e+01, 2.31279151e+00, 7.92487521e-01, 9.69274326e-03]
)


class TeaCacheConfig:
    """Configuration and mutable per-run state for TeaCache optimization.

    Args:
        enabled: When False the patched forward always runs every transformer block.
        rel_l1_thresh: Threshold on the accumulated rescaled relative L1 change;
            while the accumulation stays below it, steps reuse the cached residual.
            0.03 gives roughly a 1.6x speedup, 0.05 roughly 2.1x.
        num_inference_steps: Steps in one denoising run. The first and last step
            of a run are always fully computed, and the internal step counter
            wraps to 0 after this many forward calls (assumes one forward call
            per denoising step — TODO confirm for the calling pipeline).
    """

    def __init__(
        self,
        enabled: bool = True,
        rel_l1_thresh: float = 0.05,  # 0.03 for 1.6x speedup, 0.05 for 2.1x speedup
        num_inference_steps: int = 50
    ):
        self.enabled = enabled
        self.rel_l1_thresh = rel_l1_thresh
        self.num_inference_steps = num_inference_steps
        self.reset()

    def reset(self) -> None:
        """Clear the cache state so the next forward call starts a fresh run."""
        self.cnt = 0  # index of the next forward call within the current run
        self.accumulated_rel_l1_distance = 0
        self.previous_modulated_input = None  # first block's modulated input from the last call
        self.previous_residual = None  # (block stack + output modulation) - input, last computed step


def _teacache_should_calc(
    model: Any,
    config: TeaCacheConfig,
    hidden_states: torch.Tensor,
    temb: torch.Tensor,
    batch_size: int,
) -> bool:
    """Decide whether this step must run the transformer blocks, updating ``config``.

    Computes the first block's shift/scale-modulated normalized input and compares
    it with the one stored from the previous call. Reuse is allowed only while the
    accumulated rescaled relative L1 change stays below ``config.rel_l1_thresh``.
    The first and last step of each run, and any step without a usable cache,
    are always fully computed.
    """
    first_block = model.transformer_blocks[0]
    # norm1 and reshape do not mutate their inputs, so no defensive clones are needed.
    inp = first_block.norm1(hidden_states)
    num_ada_params = first_block.scale_shift_table.shape[0]
    ada_values = first_block.scale_shift_table[None, None] + temb.reshape(
        batch_size, temb.size(1), num_ada_params, -1
    )
    shift_msa, scale_msa, *_ = ada_values.unbind(dim=2)
    modulated_inp = inp * (1 + scale_msa) + shift_msa

    previous = config.previous_modulated_input
    is_boundary_step = config.cnt == 0 or config.cnt == config.num_inference_steps - 1
    # Stale or missing state (interrupted run, changed resolution/batch) must not
    # be compared against or reused — fall back to a full computation instead.
    cache_usable = (
        previous is not None
        and config.previous_residual is not None
        and previous.shape == modulated_inp.shape
        and config.previous_residual.shape == hidden_states.shape
    )

    if is_boundary_step or not cache_usable:
        should_calc = True
        config.accumulated_rel_l1_distance = 0
    else:
        rel_diff = (
            (modulated_inp - previous).abs().mean() / previous.abs().mean()
        ).cpu().item()
        config.accumulated_rel_l1_distance += _TEACACHE_RESCALE(rel_diff)
        if config.accumulated_rel_l1_distance < config.rel_l1_thresh:
            should_calc = False
        else:
            should_calc = True
            config.accumulated_rel_l1_distance = 0

    config.previous_modulated_input = modulated_inp
    config.cnt += 1
    if config.cnt == config.num_inference_steps:
        config.cnt = 0
    return should_calc


def _create_custom_forward(module: Any, return_dict: Optional[bool] = None):
    """Wrap ``module`` as a positional-args callable for ``torch.utils.checkpoint``."""
    def custom_forward(*inputs):
        if return_dict is not None:
            return module(*inputs, return_dict=return_dict)
        return module(*inputs)
    return custom_forward


def create_teacache_forward(original_forward: Any):
    """Build a TeaCache-enabled replacement for ``LTXVideoTransformer3DModel.forward``.

    Args:
        original_forward: The unpatched forward method. The returned function
            re-implements the forward pass inline and never calls it; it is
            accepted only to keep the factory's signature stable.

    Returns:
        A function to assign as the model class's ``forward``.
    """

    def teacache_forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        timestep: torch.LongTensor,
        encoder_attention_mask: torch.Tensor,
        num_frames: int,
        height: int,
        width: int,
        rope_interpolation_scale: Optional[tuple[float, float, float]] = None,
        attention_kwargs: Optional[dict[str, Any]] = None,
        return_dict: bool = True,
    ) -> Union[tuple[torch.Tensor], Transformer2DModelOutput]:
        """Run the LTX-Video transformer, possibly reusing a cached residual.

        When ``self.teacache_config`` exists and is enabled, steps judged similar
        enough to the last fully computed step skip the transformer blocks and the
        output norm/modulation, adding the residual cached from that step instead.

        Returns:
            ``(output,)`` when ``return_dict`` is False, otherwise a
            ``Transformer2DModelOutput`` with ``sample=output`` (which also
            supports ``["sample"]`` and ``[0]`` access).
        """
        # Pop the LoRA scale so it is applied via PEFT rather than passed on.
        if attention_kwargs is not None:
            attention_kwargs = attention_kwargs.copy()
            lora_scale = attention_kwargs.pop("scale", 1.0)
        else:
            attention_kwargs = {}
            lora_scale = 1.0

        if USE_PEFT_BACKEND:
            scale_lora_layers(self, lora_scale)

        image_rotary_emb = self.rope(hidden_states, num_frames, height, width, rope_interpolation_scale)

        # A 2-D 1/0 mask becomes an additive bias (0 kept, -10000 masked) with a
        # broadcast dimension inserted at dim 1.
        if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
            encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
            encoder_attention_mask = encoder_attention_mask.unsqueeze(1)

        batch_size = hidden_states.size(0)
        hidden_states = self.proj_in(hidden_states)

        temb, embedded_timestep = self.time_embed(
            timestep.flatten(),
            batch_size=batch_size,
            hidden_dtype=hidden_states.dtype,
        )
        temb = temb.view(batch_size, -1, temb.size(-1))
        embedded_timestep = embedded_timestep.view(batch_size, -1, embedded_timestep.size(-1))

        encoder_hidden_states = self.caption_projection(encoder_hidden_states)
        encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.size(-1))

        config = getattr(self, "teacache_config", None)
        use_cache = config is not None and config.enabled
        should_calc = True
        if use_cache:
            should_calc = _teacache_should_calc(self, config, hidden_states, temb, batch_size)

        if not should_calc:
            # Skip the blocks: reapply the residual from the last computed step.
            hidden_states += config.previous_residual
        else:
            ori_hidden_states = hidden_states.clone() if use_cache else None
            ckpt_kwargs: dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
            for block in self.transformer_blocks:
                if torch.is_grad_enabled() and self.gradient_checkpointing:
                    hidden_states = torch.utils.checkpoint.checkpoint(
                        _create_custom_forward(block),
                        hidden_states,
                        encoder_hidden_states,
                        temb,
                        image_rotary_emb,
                        encoder_attention_mask,
                        **ckpt_kwargs,
                    )
                else:
                    hidden_states = block(
                        hidden_states=hidden_states,
                        encoder_hidden_states=encoder_hidden_states,
                        temb=temb,
                        image_rotary_emb=image_rotary_emb,
                        encoder_attention_mask=encoder_attention_mask,
                    )

            scale_shift_values = self.scale_shift_table[None, None] + embedded_timestep[:, :, None]
            shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1]
            hidden_states = self.norm_out(hidden_states)
            hidden_states = hidden_states * (1 + scale) + shift

            if use_cache:
                # Cache everything between proj_in and proj_out as a single residual.
                config.previous_residual = hidden_states - ori_hidden_states

        output = self.proj_out(hidden_states)

        if USE_PEFT_BACKEND:
            unscale_lora_layers(self, lora_scale)

        if not return_dict:
            return (output,)
        return Transformer2DModelOutput(sample=output)

    return teacache_forward


def enable_teacache(model_class: Type[LTXVideoTransformer3DModel], config: TeaCacheConfig) -> None:
    """Enable TeaCache optimization for a model class.

    Patches ``forward`` on the class itself, so every instance of the class is
    affected and all instances share ``config`` and its cache state.

    Args:
        model_class: The model class to patch.
        config: TeaCache configuration; its cache state is reset here so stale
            state from an earlier run cannot leak into the next one.
    """
    # Keep the very first forward so repeated enables never wrap a patched one.
    if not hasattr(model_class, '_original_forward'):
        model_class._original_forward = model_class.forward

    model_class.forward = create_teacache_forward(model_class._original_forward)

    config.reset()
    model_class.teacache_config = config


def disable_teacache(model_class: Type[LTXVideoTransformer3DModel]) -> None:
    """Disable TeaCache optimization for a model class.

    Restores the saved original ``forward`` and removes the class-level config.

    Args:
        model_class: The model class to unpatch.
    """
    if hasattr(model_class, '_original_forward'):
        model_class.forward = model_class._original_forward
        delattr(model_class, '_original_forward')

    if hasattr(model_class, 'teacache_config'):
        delattr(model_class, 'teacache_config')