jbilcke-hf committed on
Commit ab3b889 · verified · 1 parent: 61e34be

Update teacache.py

Files changed (1):
  1. teacache.py +123 -116
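
For context before the diff: the API this file exposes is a monkey-patching one (enable_teacache / disable_teacache swap the model's forward). A minimal usage sketch follows; the pipeline wiring and step count are illustrative assumptions, not part of this commit:

# Hypothetical usage sketch for teacache.py -- `pipe` and the step count
# are assumptions for illustration, not part of this commit.
from teacache import enable_teacache, disable_teacache

model = pipe.transformer          # assumed: a HunyuanVideo-style transformer
enable_teacache(model, num_inference_steps=30, rel_l1_thresh=0.15)
# ... run the denoising loop; steps whose accumulated rescaled relative-L1
# distance stays under rel_l1_thresh reuse the cached residual ...
disable_teacache(model)           # restores the original forward()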
teacache.py CHANGED
@@ -5,135 +5,142 @@ from typing import Optional, Dict, Union, Any
 from functools import wraps
 
 class TeaCacheConfig:
-    """Configuration for TeaCache acceleration"""
-    def __init__(
-        self,
-        rel_l1_thresh: float = 0.15,
-        enable: bool = True
-    ):
-        self.rel_l1_thresh = rel_l1_thresh
-        self.enable = enable
-        self._reset_state()
-
-    def _reset_state(self):
-        """Reset internal state"""
-        self.cnt = 0
-        self.accumulated_rel_l1_distance = 0
-        self.previous_modulated_input = None
-        self.previous_residual = None
+    """Configuration for TeaCache acceleration"""
+    def __init__(
+        self,
+        rel_l1_thresh: float = 0.15,
+        enable: bool = True
+    ):
+        self.rel_l1_thresh = rel_l1_thresh
+        self.enable = enable
+        self._reset_state()
+
+    def _reset_state(self):
+        """Reset internal state"""
+        self.cnt = 0
+        self.accumulated_rel_l1_distance = 0
+        self.previous_modulated_input = None
+        self.previous_residual = None
 
 def create_teacache_forward(original_forward):
-    """Factory function to create a TeaCache-enabled forward pass"""
-    @wraps(original_forward)
-    def teacache_forward(
-        self,
-        hidden_states: torch.Tensor,
-        timestep: torch.Tensor,
-        encoder_hidden_states: Optional[torch.Tensor] = None,
-        encoder_attention_mask: Optional[torch.Tensor] = None,
-        pooled_projections: Optional[torch.Tensor] = None,
-        guidance: Optional[torch.Tensor] = None,
-        attention_kwargs: Optional[Dict[str, Any]] = None,
-        return_dict: bool = True,
-    ):
-        # Skip TeaCache if not enabled
-        if not hasattr(self, 'teacache_config') or not self.teacache_config.enable:
-            return original_forward(
-                self,
-                hidden_states=hidden_states,
-                timestep=timestep,
-                encoder_hidden_states=encoder_hidden_states,
-                encoder_attention_mask=encoder_attention_mask,
-                pooled_projections=pooled_projections,
-                guidance=guidance,
-                attention_kwargs=attention_kwargs,
-                return_dict=return_dict
-            )
+    """Factory function to create a TeaCache-enabled forward pass"""
+    @wraps(original_forward)
+    def teacache_forward(
+        self,
+        hidden_states: torch.Tensor,
+        timestep: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        encoder_attention_mask: Optional[torch.Tensor] = None,
+        pooled_projections: Optional[torch.Tensor] = None,
+        guidance: Optional[torch.Tensor] = None,
+        attention_kwargs: Optional[Dict[str, Any]] = None,
+        return_dict: bool = True,
+    ):
+        # Skip TeaCache if not enabled
+        if not hasattr(self, 'teacache_config') or not self.teacache_config.enable:
+            return original_forward(
+                self,
+                hidden_states=hidden_states,
+                timestep=timestep,
+                encoder_hidden_states=encoder_hidden_states,
+                encoder_attention_mask=encoder_attention_mask,
+                pooled_projections=pooled_projections,
+                guidance=guidance,
+                attention_kwargs=attention_kwargs,
+                return_dict=return_dict
+            )
 
-        config = self.teacache_config
-
-        # Initial embeddings
-        # Changed from time_proj to time_in to match HunyuanVideo implementation
-        t_emb = self.time_in(timestep)
+        config = self.teacache_config
 
-        if pooled_projections is not None:
-            t_emb = t_emb + self.vector_in(pooled_projections)
+        # Prepare modulation vectors similar to HunyuanVideo implementation
+        if pooled_projections is not None:
+            vec = self.vector_in(pooled_projections)
 
-        if guidance is not None:
-            t_emb = t_emb + self.guidance_in(guidance)
+        if guidance is not None:
+            if vec is None:
+                vec = self.guidance_in(guidance)
+            else:
+                vec = vec + self.guidance_in(guidance)
 
-        # TeaCache optimization logic
-        inp = hidden_states.clone()
-        normed_inp = self.transformer_blocks[0].norm1(inp)
-        modulated_inp = self.transformer_blocks[0].attn1.to_q(normed_inp)
+        # TeaCache optimization logic
+        inp = hidden_states.clone()
+        if hasattr(self.double_blocks[0], 'img_norm1'):
+            # HunyuanVideo specific modulation
+            img_mod1_shift, img_mod1_scale, _, _, _, _ = self.double_blocks[0].img_mod(vec).chunk(6, dim=-1)
+            normed_inp = self.double_blocks[0].img_norm1(inp)
+            modulated_inp = normed_inp * (1 + img_mod1_scale) + img_mod1_shift
+        else:
+            # Fallback modulation
+            normed_inp = self.transformer_blocks[0].norm1(inp)
+            modulated_inp = normed_inp
 
-        # Determine if we should calculate or use cache
-        should_calc = True
-        if config.cnt == 0 or config.cnt == self.num_inference_steps - 1:
-            should_calc = True
-            config.accumulated_rel_l1_distance = 0
-        elif config.previous_modulated_input is not None:
-            coefficients = [7.33226126e+02, -4.01131952e+02, 6.75869174e+01,
-                            -3.14987800e+00, 9.61237896e-02]
-            rescale_func = np.poly1d(coefficients)
-
-            rel_l1 = ((modulated_inp - config.previous_modulated_input).abs().mean() /
-                      config.previous_modulated_input.abs().mean()).cpu().item()
-            config.accumulated_rel_l1_distance += rescale_func(rel_l1)
-
-            should_calc = config.accumulated_rel_l1_distance >= config.rel_l1_thresh
-            if should_calc:
-                config.accumulated_rel_l1_distance = 0
+        # Determine if we should calculate or use cache
+        should_calc = True
+        if config.cnt == 0 or config.cnt == self.num_inference_steps - 1:
+            should_calc = True
+            config.accumulated_rel_l1_distance = 0
+        elif config.previous_modulated_input is not None:
+            coefficients = [7.33226126e+02, -4.01131952e+02, 6.75869174e+01,
+                            -3.14987800e+00, 9.61237896e-02]
+            rescale_func = np.poly1d(coefficients)
+
+            rel_l1 = ((modulated_inp - config.previous_modulated_input).abs().mean() /
+                      config.previous_modulated_input.abs().mean()).cpu().item()
+            config.accumulated_rel_l1_distance += rescale_func(rel_l1)
+
+            should_calc = config.accumulated_rel_l1_distance >= config.rel_l1_thresh
+            if should_calc:
+                config.accumulated_rel_l1_distance = 0
 
-        config.previous_modulated_input = modulated_inp
-        config.cnt += 1
-        if config.cnt >= self.num_inference_steps:
-            config.cnt = 0
+        config.previous_modulated_input = modulated_inp
+        config.cnt += 1
+        if config.cnt >= self.num_inference_steps:
+            config.cnt = 0
 
-        # Use cache or calculate new result
-        if not should_calc and config.previous_residual is not None:
-            hidden_states += config.previous_residual
-        else:
-            ori_hidden_states = hidden_states.clone()
-
-            # Process through transformer blocks
-            for block in self.transformer_blocks:
-                hidden_states = block(
-                    hidden_states=hidden_states,
-                    temb=t_emb,
-                    encoder_hidden_states=encoder_hidden_states,
-                    attention_mask=encoder_attention_mask,
-                    **attention_kwargs if attention_kwargs else {}
-                )
+        # Use cache or calculate new result
+        if not should_calc and config.previous_residual is not None:
+            hidden_states += config.previous_residual
+        else:
+            ori_hidden_states = hidden_states.clone()
+
+            # Use original forward pass
+            out = original_forward(
+                self,
+                hidden_states=hidden_states,
+                timestep=timestep,
+                encoder_hidden_states=encoder_hidden_states,
+                encoder_attention_mask=encoder_attention_mask,
+                pooled_projections=pooled_projections,
+                guidance=guidance,
+                attention_kwargs=attention_kwargs,
+                return_dict=True
+            )
+            hidden_states = out["sample"]
 
-            # Store residual for future use
-            config.previous_residual = hidden_states - ori_hidden_states
+            # Store residual for future use
+            config.previous_residual = hidden_states - ori_hidden_states
 
-        # Final layer normalization and projection
-        hidden_states = self.norm_out(hidden_states)
-        output = self.proj_out(hidden_states)
+        if not return_dict:
+            return (hidden_states,)
+
+        return {"sample": hidden_states}
 
-        if not return_dict:
-            return (output,)
-
-        return {"sample": output}
-
-    return teacache_forward
+    return teacache_forward
 
 def enable_teacache(model: Any, num_inference_steps: int, rel_l1_thresh: float = 0.15):
-    """Enable TeaCache acceleration for a model"""
-    if not hasattr(model, '_original_forward'):
-        model._original_forward = model.forward
-
-    model.teacache_config = TeaCacheConfig(rel_l1_thresh=rel_l1_thresh)
-    model.num_inference_steps = num_inference_steps
-    model.forward = create_teacache_forward(model._original_forward).__get__(model)
+    """Enable TeaCache acceleration for a model"""
+    if not hasattr(model, '_original_forward'):
+        model._original_forward = model.forward
+
+    model.teacache_config = TeaCacheConfig(rel_l1_thresh=rel_l1_thresh)
+    model.num_inference_steps = num_inference_steps
+    model.forward = create_teacache_forward(model._original_forward).__get__(model)
 
 def disable_teacache(model: Any):
-    """Disable TeaCache acceleration for a model"""
-    if hasattr(model, '_original_forward'):
-        model.forward = model._original_forward
-        del model._original_forward
-
-    if hasattr(model, 'teacache_config'):
-        del model.teacache_config
+    """Disable TeaCache acceleration for a model"""
+    if hasattr(model, '_original_forward'):
+        model.forward = model._original_forward
+        del model._original_forward
+
+    if hasattr(model, 'teacache_config'):
+        del model.teacache_config
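
For readers tracing the caching condition in the hunk above, here is a self-contained sketch of the skip/compute decision. The polynomial coefficients and the threshold default are the ones from the diff; the tensor shapes and step count are stand-ins for illustration:

import numpy as np
import torch

# Rescaling polynomial from teacache.py, applied to the raw relative-L1 change
rescale_func = np.poly1d([7.33226126e+02, -4.01131952e+02, 6.75869174e+01,
                          -3.14987800e+00, 9.61237896e-02])

rel_l1_thresh = 0.15          # default threshold from TeaCacheConfig
num_inference_steps = 30      # stand-in value
accumulated = 0.0
prev_mod = None

for step in range(num_inference_steps):
    modulated = torch.randn(1, 16, 64)   # stand-in for the modulated input
    if step == 0 or step == num_inference_steps - 1 or prev_mod is None:
        should_calc = True               # always compute on first and last steps
        accumulated = 0.0
    else:
        # Relative L1 change between consecutive modulated inputs,
        # rescaled by the polynomial before accumulation
        rel_l1 = ((modulated - prev_mod).abs().mean() / prev_mod.abs().mean()).item()
        accumulated += rescale_func(rel_l1)
        should_calc = accumulated >= rel_l1_thresh
        if should_calc:
            accumulated = 0.0            # reset once a real compute happens
    prev_mod = modulated
    print(step, "compute" if should_calc else "reuse cached residual")

One caveat worth noting when reading the new hunk: the added guidance branch tests `if vec is None`, which assumes `vec` was initialized earlier; when `pooled_projections` is None, `vec` is never assigned before that check, so callers should ensure pooled projections (or an explicit `vec = None` guard) are present.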