# teacache.py
import torch
import numpy as np
from typing import Optional, Dict, Union, Any
from functools import wraps

class TeaCacheConfig:
   """Configuration for TeaCache acceleration"""
   def __init__(
       self,
       rel_l1_thresh: float = 0.15,
       enable: bool = True
   ):
       self.rel_l1_thresh = rel_l1_thresh
       self.enable = enable
       self._reset_state()
   
   def _reset_state(self):
       """Reset internal state"""
       self.cnt = 0
       self.accumulated_rel_l1_distance = 0
       self.previous_modulated_input = None
       self.previous_residual = None

def create_teacache_forward(original_forward):
   """Factory function to create a TeaCache-enabled forward pass"""
   @wraps(original_forward)
   def teacache_forward(
       self,
       hidden_states: torch.Tensor,
       timestep: torch.Tensor,
       encoder_hidden_states: Optional[torch.Tensor] = None,
       encoder_attention_mask: Optional[torch.Tensor] = None,
       pooled_projections: Optional[torch.Tensor] = None,
       guidance: Optional[torch.Tensor] = None,
       attention_kwargs: Optional[Dict[str, Any]] = None,
       return_dict: bool = True,
   ):
       # Fall back to the unmodified forward pass when TeaCache is disabled
       if not hasattr(self, 'teacache_config') or not self.teacache_config.enable:
           # original_forward is the bound method saved by enable_teacache,
           # so `self` must not be passed again here
           return original_forward(
               hidden_states=hidden_states,
               timestep=timestep,
               encoder_hidden_states=encoder_hidden_states,
               encoder_attention_mask=encoder_attention_mask,
               pooled_projections=pooled_projections,
               guidance=guidance,
               attention_kwargs=attention_kwargs,
               return_dict=return_dict
           )

       config = self.teacache_config

       # Prepare the modulation vector as in the HunyuanVideo implementation
       vec = None  # stays None when neither pooled projections nor guidance are given
       if pooled_projections is not None:
           vec = self.vector_in(pooled_projections)

       if guidance is not None:
           if vec is None:
               vec = self.guidance_in(guidance)
           else:
               vec = vec + self.guidance_in(guidance)

       # TeaCache optimization logic
       inp = hidden_states.clone()
       if hasattr(self, 'double_blocks') and hasattr(self.double_blocks[0], 'img_norm1'):
           # HunyuanVideo specific modulation
           img_mod1_shift, img_mod1_scale, _, _, _, _ = self.double_blocks[0].img_mod(vec).chunk(6, dim=-1)
           normed_inp = self.double_blocks[0].img_norm1(inp)
           modulated_inp = normed_inp * (1 + img_mod1_scale) + img_mod1_shift
       else:
           # Fallback modulation
           normed_inp = self.transformer_blocks[0].norm1(inp)
           modulated_inp = normed_inp

       # Determine if we should calculate or use cache
       should_calc = True
       if config.cnt == 0 or config.cnt == self.num_inference_steps - 1:
           should_calc = True
           config.accumulated_rel_l1_distance = 0
       elif config.previous_modulated_input is not None:
           # Rescale the raw relative L1 distance with an empirical polynomial
           # (model-specific fit used by TeaCache for HunyuanVideo-style transformers)
           coefficients = [7.33226126e+02, -4.01131952e+02, 6.75869174e+01,
                           -3.14987800e+00, 9.61237896e-02]
           rescale_func = np.poly1d(coefficients)
           
           rel_l1 = ((modulated_inp - config.previous_modulated_input).abs().mean() / 
                    config.previous_modulated_input.abs().mean()).cpu().item()
           config.accumulated_rel_l1_distance += rescale_func(rel_l1)
           
           should_calc = config.accumulated_rel_l1_distance >= config.rel_l1_thresh
           if should_calc:
               config.accumulated_rel_l1_distance = 0

       config.previous_modulated_input = modulated_inp
       config.cnt += 1
       if config.cnt >= self.num_inference_steps:
           config.cnt = 0

       # Either reuse the cached residual or run the full forward pass
       if not should_calc and config.previous_residual is not None:
           # Out-of-place add so the caller's tensor is not mutated
           hidden_states = hidden_states + config.previous_residual
       else:
           ori_hidden_states = hidden_states.clone()
           
           # Run the original (bound) forward pass; again no explicit `self`
           out = original_forward(
               hidden_states=hidden_states,
               timestep=timestep,
               encoder_hidden_states=encoder_hidden_states,
               encoder_attention_mask=encoder_attention_mask,
               pooled_projections=pooled_projections,
               guidance=guidance,
               attention_kwargs=attention_kwargs,
               return_dict=True
           )
           hidden_states = out["sample"]

           # Store residual for future use
           config.previous_residual = hidden_states - ori_hidden_states

       if not return_dict:
           return (hidden_states,)
           
       return {"sample": hidden_states}

   return teacache_forward

def enable_teacache(model: Any, num_inference_steps: int, rel_l1_thresh: float = 0.15):
   """Enable TeaCache acceleration for a model by wrapping its forward pass"""
   # Save the original bound forward only once so repeated calls stay idempotent
   if not hasattr(model, '_original_forward'):
       model._original_forward = model.forward
   
   model.teacache_config = TeaCacheConfig(rel_l1_thresh=rel_l1_thresh)
   model.num_inference_steps = num_inference_steps
   model.forward = create_teacache_forward(model._original_forward).__get__(model)

def disable_teacache(model: Any):
   """Disable TeaCache acceleration for a model"""
   if hasattr(model, '_original_forward'):
       model.forward = model._original_forward
       del model._original_forward
   
   if hasattr(model, 'teacache_config'):
       del model.teacache_config
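
# ---------------------------------------------------------------------------
# Illustrative smoke test (a minimal sketch, not part of the module's API).
# The real target is a HunyuanVideo-style diffusers transformer; the stand-in
# module below only provides what the wrapper's fallback path touches
# (`transformer_blocks[0].norm1` plus a forward() with the same keyword
# signature), so it exercises the caching logic without downloading weights.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
   import torch.nn as nn

   class _ToyBlock(nn.Module):
       def __init__(self, dim: int):
           super().__init__()
           self.norm1 = nn.LayerNorm(dim)

   class _ToyTransformer(nn.Module):
       def __init__(self, dim: int = 16):
           super().__init__()
           self.transformer_blocks = nn.ModuleList([_ToyBlock(dim)])
           self.proj = nn.Linear(dim, dim)

       def forward(self, hidden_states, timestep, encoder_hidden_states=None,
                   encoder_attention_mask=None, pooled_projections=None,
                   guidance=None, attention_kwargs=None, return_dict=True):
           out = hidden_states + self.proj(hidden_states)
           return {"sample": out} if return_dict else (out,)

   model = _ToyTransformer()
   enable_teacache(model, num_inference_steps=4, rel_l1_thresh=0.15)

   x = torch.randn(1, 8, 16)
   for step in range(4):
       x = model.forward(hidden_states=x, timestep=torch.tensor([step]))["sample"]

   disable_teacache(model)
   print("TeaCache smoke test finished:", x.shape)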