# teacache.py
import torch
import numpy as np
from typing import Optional, Dict, Union, Any
from functools import wraps


class TeaCacheConfig:
    """Configuration for TeaCache acceleration"""

    def __init__(
        self,
        rel_l1_thresh: float = 0.15,
        enable: bool = True
    ):
        self.rel_l1_thresh = rel_l1_thresh
        self.enable = enable
        self._reset_state()

    def _reset_state(self):
        """Reset internal state"""
        self.cnt = 0
        self.accumulated_rel_l1_distance = 0
        self.previous_modulated_input = None
        self.previous_residual = None
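

# How the skip decision works (illustrative sketch only; nothing here is
# executed by this module): each denoising step computes the relative L1
# change of the modulated input, rescales it with an empirically fitted
# polynomial, and accumulates the result. The expensive transformer forward
# only runs once the accumulated value crosses `rel_l1_thresh`; otherwise the
# cached residual from the last full pass is reapplied. `prev` and `curr`
# below are hypothetical tensors from two consecutive steps, and
# `coefficients` are the fitted values used in `teacache_forward`:
#
#   rel_l1 = ((curr - prev).abs().mean() / prev.abs().mean()).item()
#   config.accumulated_rel_l1_distance += np.poly1d(coefficients)(rel_l1)
#   if config.accumulated_rel_l1_distance < config.rel_l1_thresh:
#       ...  # skip the transformer and reuse the cached residual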


def create_teacache_forward(original_forward):
    """Factory function to create a TeaCache-enabled forward pass"""
    @wraps(original_forward)
    def teacache_forward(
        self,
        hidden_states: torch.Tensor,
        timestep: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        pooled_projections: Optional[torch.Tensor] = None,
        guidance: Optional[torch.Tensor] = None,
        attention_kwargs: Optional[Dict[str, Any]] = None,
        return_dict: bool = True,
    ):
        # Skip TeaCache if not enabled
        if not hasattr(self, 'teacache_config') or not self.teacache_config.enable:
            return original_forward(
                self,
                hidden_states=hidden_states,
                timestep=timestep,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                pooled_projections=pooled_projections,
                guidance=guidance,
                attention_kwargs=attention_kwargs,
                return_dict=return_dict
            )

        config = self.teacache_config

        # Prepare modulation vectors similar to the HunyuanVideo implementation
        vec = None  # conditioning vector; stays None if neither input is given
        if pooled_projections is not None:
            vec = self.vector_in(pooled_projections)
        if guidance is not None:
            if vec is None:
                vec = self.guidance_in(guidance)
            else:
                vec = vec + self.guidance_in(guidance)

        # TeaCache optimization logic
        inp = hidden_states.clone()
        if hasattr(self, 'double_blocks') and hasattr(self.double_blocks[0], 'img_norm1'):
            # HunyuanVideo-specific modulation
            img_mod1_shift, img_mod1_scale, _, _, _, _ = self.double_blocks[0].img_mod(vec).chunk(6, dim=-1)
            normed_inp = self.double_blocks[0].img_norm1(inp)
            modulated_inp = normed_inp * (1 + img_mod1_scale) + img_mod1_shift
        else:
            # Fallback modulation for models without HunyuanVideo-style double blocks
            normed_inp = self.transformer_blocks[0].norm1(inp)
            modulated_inp = normed_inp

        # Determine whether to run the full forward pass or reuse the cache
        should_calc = True
        if config.cnt == 0 or config.cnt == self.num_inference_steps - 1:
            # Always compute on the first and last inference steps
            should_calc = True
            config.accumulated_rel_l1_distance = 0
        elif config.previous_modulated_input is not None:
            # Empirically fitted polynomial that rescales the relative L1 distance
            coefficients = [7.33226126e+02, -4.01131952e+02, 6.75869174e+01,
                            -3.14987800e+00, 9.61237896e-02]
            rescale_func = np.poly1d(coefficients)
            rel_l1 = ((modulated_inp - config.previous_modulated_input).abs().mean() /
                      config.previous_modulated_input.abs().mean()).cpu().item()
            config.accumulated_rel_l1_distance += rescale_func(rel_l1)
            should_calc = config.accumulated_rel_l1_distance >= config.rel_l1_thresh
            if should_calc:
                config.accumulated_rel_l1_distance = 0

        config.previous_modulated_input = modulated_inp
        config.cnt += 1
        if config.cnt >= self.num_inference_steps:
            config.cnt = 0

        # Use cache or calculate new result
        if not should_calc and config.previous_residual is not None:
            hidden_states += config.previous_residual
        else:
            ori_hidden_states = hidden_states.clone()
            # Use original forward pass
            out = original_forward(
                self,
                hidden_states=hidden_states,
                timestep=timestep,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                pooled_projections=pooled_projections,
                guidance=guidance,
                attention_kwargs=attention_kwargs,
                return_dict=True
            )
            hidden_states = out["sample"]
            # Store residual for future use
            config.previous_residual = hidden_states - ori_hidden_states

        if not return_dict:
            return (hidden_states,)
        return {"sample": hidden_states}

    return teacache_forward


def enable_teacache(model: Any, num_inference_steps: int, rel_l1_thresh: float = 0.15):
    """Enable TeaCache acceleration for a model"""
    if not hasattr(model, '_original_forward'):
        # Store the underlying (unbound) function so the wrapper can pass `self` explicitly
        model._original_forward = type(model).forward
    model.teacache_config = TeaCacheConfig(rel_l1_thresh=rel_l1_thresh)
    model.num_inference_steps = num_inference_steps
    model.forward = create_teacache_forward(model._original_forward).__get__(model)


def disable_teacache(model: Any):
    """Disable TeaCache acceleration for a model"""
    if hasattr(model, '_original_forward'):
        # Re-bind the stored function so `model.forward` behaves as before
        model.forward = model._original_forward.__get__(model)
        del model._original_forward
    if hasattr(model, 'teacache_config'):
        del model.teacache_config
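

# ---------------------------------------------------------------------------
# Usage sketch (illustration only): the pipeline class and checkpoint id
# below are assumptions about the surrounding project, not part of this
# module; any diffusers-style transformer exposing `forward` with the same
# signature should work the same way.
#
#   import torch
#   from diffusers import HunyuanVideoPipeline  # assumed pipeline class
#
#   pipe = HunyuanVideoPipeline.from_pretrained(
#       "hunyuanvideo-community/HunyuanVideo",  # assumed checkpoint id
#       torch_dtype=torch.bfloat16,
#   ).to("cuda")
#
#   num_steps = 30
#   enable_teacache(pipe.transformer, num_inference_steps=num_steps,
#                   rel_l1_thresh=0.15)  # larger thresh -> more cached steps
#   video = pipe(prompt="a cat walks on the grass",
#                num_inference_steps=num_steps)
#
#   disable_teacache(pipe.transformer)   # restore the original forward pass
# ---------------------------------------------------------------------------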