|
import torch |
|
import numpy as np |
|
from typing import Optional, Dict, Union, Any |
|
from functools import wraps |
|
|
|
class TeaCacheConfig:
    """Configuration for TeaCache acceleration"""

    # Public settings:
    #   rel_l1_thresh -- once the accumulated rescaled relative L1 change of the
    #                    modulated input reaches this value, the TeaCache forward
    #                    recomputes the blocks instead of reusing the residual.
    #   enable        -- when False, the TeaCache forward defers to the original.
    # Mutable per-run state (cleared by _reset_state):
    #   cnt                         -- index of the current inference step.
    #   accumulated_rel_l1_distance -- running sum since the last recompute.
    #   previous_modulated_input    -- modulated input seen on the prior step.
    #   previous_residual           -- (blocks output - blocks input) cached at
    #                                  the last full recompute.

    def __init__(self, rel_l1_thresh: float = 0.15, enable: bool = True):
        self.rel_l1_thresh, self.enable = rel_l1_thresh, enable
        self._reset_state()

    def _reset_state(self):
        """Clear the step counter, the accumulated distance and both cached tensors."""
        self.cnt, self.accumulated_rel_l1_distance = 0, 0
        self.previous_modulated_input = self.previous_residual = None
|
|
|
# Polynomial (highest degree first) that rescales the raw relative L1 change of
# the modulated input into the quantity accumulated against rel_l1_thresh.
# NOTE(review): model-specific calibration -- confirm before reusing these
# coefficients with a different transformer.
_TEACACHE_RESCALE_COEFFICIENTS = (
    7.33226126e+02,
    -4.01131952e+02,
    6.75869174e+01,
    -3.14987800e+00,
    9.61237896e-02,
)


def _teacache_embed(model, x, t, text_states, text_mask, text_states_2, guidance):
    """Embed the conditioning vector, the image tokens and the text tokens."""
    vec = model.time_in(t)
    vec = vec + model.vector_in(text_states_2)
    if model.guidance_embed:
        if guidance is None:
            raise ValueError("Guidance strength required for guidance distilled model")
        vec = vec + model.guidance_in(guidance)

    img = model.img_in(x)
    txt = text_states
    if not hasattr(model, 'text_projection') or model.text_projection == "linear":
        txt = model.txt_in(txt)
    elif model.text_projection == "single_refiner":
        txt = model.txt_in(txt, t, text_mask if model.use_attention_mask else None)
    else:
        raise NotImplementedError(f"Unsupported text_projection: {model.text_projection}")
    return img, txt, vec


def _teacache_modulated_input(model, img, vec):
    """Reproduce the first double block's shift/scale modulation of ``img``."""
    first_block = model.double_blocks[0]
    # Defensive copies in case the modulation modules operate in place.
    mods = first_block.img_mod(vec.clone()).chunk(6, dim=-1)
    shift, scale = mods[0], mods[1]
    normed = first_block.img_norm1(img.clone())
    # The modulation vectors carry no token axis; insert one so a batch of
    # several samples broadcasts over the tokens instead of failing (or
    # silently mis-broadcasting when batch size happens to equal token count).
    if scale.dim() == normed.dim() - 1:
        shift = shift.unsqueeze(1)
        scale = scale.unsqueeze(1)
    return normed * (1 + scale) + shift


def _teacache_should_calc(config, modulated_inp, num_steps, rescale_func):
    """Decide whether this step must run the blocks; advances ``config`` state."""
    prev = config.previous_modulated_input
    if config.cnt == 0 or config.cnt == num_steps - 1:
        # First and last step of every cycle are always computed exactly.
        should_calc = True
        config.accumulated_rel_l1_distance = 0
    elif prev is None or prev.shape != modulated_inp.shape:
        # Nothing comparable to measure against (e.g. resolution/batch changed).
        should_calc = True
        config.accumulated_rel_l1_distance = 0
    else:
        rel_l1 = ((modulated_inp - prev).abs().mean() / prev.abs().mean()).item()
        config.accumulated_rel_l1_distance += rescale_func(rel_l1)
        # A non-finite distance (e.g. an all-zero previous input) would
        # otherwise stick in the accumulator and suppress every later
        # recompute, so treat it as "changed too much".
        should_calc = (
            not np.isfinite(config.accumulated_rel_l1_distance)
            or config.accumulated_rel_l1_distance >= config.rel_l1_thresh
        )
        if should_calc:
            config.accumulated_rel_l1_distance = 0

    config.previous_modulated_input = modulated_inp
    config.cnt += 1
    if config.cnt >= num_steps:
        config.cnt = 0
    return should_calc


def _teacache_run_blocks(model, img, txt, vec, freqs_cis):
    """Run the double-stream then single-stream blocks; return the image tokens."""
    txt_seq_len = txt.shape[1]
    img_seq_len = img.shape[1]
    max_seqlen = img_seq_len + txt_seq_len

    for block in model.double_blocks:
        img, txt = block(img, txt, vec, None, None, max_seqlen, max_seqlen, freqs_cis)

    single_blocks = getattr(model, 'single_blocks', None)
    if single_blocks:
        x = torch.cat((img, txt), 1)
        for block in single_blocks:
            x = block(x, vec, txt_seq_len, None, None, max_seqlen, max_seqlen, freqs_cis)
        img = x[:, :img_seq_len, ...]
    return img


def create_teacache_forward(original_forward):
    """
    Factory function to create a TeaCache-enabled forward pass.

    The returned forward runs the full stack of double/single blocks only when
    the accumulated, polynomially rescaled relative L1 change of the first
    double block's modulated input reaches ``teacache_config.rel_l1_thresh``.
    On the other steps it adds the residual cached at the last full computation.
    The first and last step of each ``num_inference_steps`` cycle always
    recompute. If the model has no ``teacache_config`` (or it is None or has
    ``enable`` False), the call is forwarded to ``original_forward`` unchanged.

    The model must provide ``patch_size``, ``time_in``, ``vector_in``,
    ``guidance_embed`` (and ``guidance_in`` when it is truthy), ``img_in``,
    ``txt_in``, ``double_blocks``, ``final_layer``, ``unpatchify`` and
    ``num_inference_steps``; ``text_projection``, ``use_attention_mask`` and
    ``single_blocks`` are optional. ``x`` must be 5-D; its last three dims are
    divided by ``patch_size`` to obtain the token grid passed to ``unpatchify``.

    Args:
        original_forward: Forward to fall back to. Either a plain function
            taking ``self`` first, or a method already bound to the model.

    Returns:
        A function ``(self, x, t, ...)`` returning ``{"x": out}`` when
        ``return_dict`` is true, else ``out``. Bind it to the model (e.g. with
        ``__get__``) before assigning it to ``model.forward``.
    """
    rescale_func = np.poly1d(_TEACACHE_RESCALE_COEFFICIENTS)
    # A bound method already carries its instance; passing ``self`` again
    # would shift every argument by one position.
    forward_is_bound = getattr(original_forward, '__self__', None) is not None

    @wraps(original_forward)
    def teacache_forward(
        self,
        x: torch.Tensor,
        t: torch.Tensor,
        text_states: Optional[torch.Tensor] = None,
        text_mask: Optional[torch.Tensor] = None,
        text_states_2: Optional[torch.Tensor] = None,
        freqs_cos: Optional[torch.Tensor] = None,
        freqs_sin: Optional[torch.Tensor] = None,
        guidance: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
        config = getattr(self, 'teacache_config', None)
        if config is None or not config.enable:
            args = (x, t, text_states, text_mask, text_states_2,
                    freqs_cos, freqs_sin, guidance, return_dict)
            if forward_is_bound:
                return original_forward(*args)
            return original_forward(self, *args)

        _, _, ot, oh, ow = x.shape
        tt, th, tw = (
            ot // self.patch_size[0],
            oh // self.patch_size[1],
            ow // self.patch_size[2],
        )

        img, txt, vec = _teacache_embed(self, x, t, text_states, text_mask,
                                        text_states_2, guidance)
        modulated_inp = _teacache_modulated_input(self, img, vec)
        should_calc = _teacache_should_calc(config, modulated_inp,
                                            self.num_inference_steps, rescale_func)

        residual = config.previous_residual
        if not should_calc and residual is not None and residual.shape == img.shape:
            img = img + residual
        else:
            ori_img = img.clone()
            # Same rotary argument for both block kinds: None when absent.
            freqs_cis = (freqs_cos, freqs_sin) if freqs_cos is not None else None
            img = _teacache_run_blocks(self, img, txt, vec, freqs_cis)
            config.previous_residual = img - ori_img

        img = self.final_layer(img, vec)
        img = self.unpatchify(img, tt, th, tw)

        if return_dict:
            return {"x": img}
        return img

    return teacache_forward
|
|
|
def enable_teacache(model: Any, num_inference_steps: int, rel_l1_thresh: float = 0.15):
    """
    Enable TeaCache acceleration for a model.

    The model's current ``forward`` is saved as ``_original_forward`` only the
    first time, so repeated calls keep the true original. Every call attaches a
    fresh ``TeaCacheConfig`` (discarding any cached state), stores
    ``num_inference_steps`` on the model and installs a TeaCache-wrapped
    forward bound to the model.

    Args:
        model: The transformer model to accelerate.
        num_inference_steps: Number of inference steps per generation. The
            wrapped forward always recomputes on the first and last of them and
            restarts its step counter after this many calls.
        rel_l1_thresh: Relative L1 threshold for cache usage.
    """
    if not hasattr(model, '_original_forward'):
        model._original_forward = model.forward

    model.teacache_config = TeaCacheConfig(rel_l1_thresh=rel_l1_thresh)
    model.num_inference_steps = num_inference_steps

    # ``_original_forward`` is normally ``model.forward``, i.e. a method already
    # bound to ``model``. The wrapper is itself bound to ``model`` and may
    # supply ``self`` when falling back, so hand it the underlying function to
    # ensure ``self`` reaches the original exactly once.
    saved = model._original_forward
    if getattr(saved, '__self__', None) is model and hasattr(saved, '__func__'):
        saved = saved.__func__
    model.forward = create_teacache_forward(saved).__get__(model)
|
|
|
def disable_teacache(model: Any):
    """
    Disable TeaCache acceleration for a model.

    Restores the forward saved as ``_original_forward`` and removes that saved
    reference and ``teacache_config``. When the saved forward is simply the
    class's own ``forward`` bound to this model, the instance-level override is
    deleted instead of re-assigned. Normal class lookup then resumes, and no
    model -> bound method -> model reference cycle is left on the instance.
    ``num_inference_steps`` is left in place.

    Args:
        model: The transformer model to restore.
    """
    if hasattr(model, '_original_forward'):
        original = model._original_forward
        class_forward = getattr(type(model), 'forward', None)
        is_plain_class_method = (
            class_forward is not None
            and getattr(original, '__self__', None) is model
            and getattr(original, '__func__', None) is class_forward
        )
        if is_plain_class_method and 'forward' in getattr(model, '__dict__', {}):
            del model.forward
        else:
            model.forward = original
        del model._original_forward

    if hasattr(model, 'teacache_config'):
        del model.teacache_config
|
|