File size: 8,542 Bytes
b1f3a76
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
from typing import Any, Optional, Type
import numpy as np
import torch
from diffusers.models.transformers import LTXVideoTransformer3DModel
from diffusers.utils import (
    USE_PEFT_BACKEND,
    is_torch_version,
    scale_lora_layers,
    unscale_lora_layers
)

class TeaCacheConfig:
    """Settings plus per-run mutable state for the TeaCache forward patch.

    TeaCache decides, call by call, whether the transformer blocks can be skipped
    by reusing the residual computed at an earlier call. The constructor arguments
    are user settings; the remaining attributes are mutated by the patched forward
    pass on every call.

    Args:
        enabled: Whether the patched forward should apply caching at all.
        rel_l1_thresh: Threshold on the accumulated, rescaled relative L1 change
            of the modulated input; blocks are skipped while the accumulated value
            stays below it, so larger values skip more calls (per the original
            inline note: 0.03 ~ 1.6x speedup, 0.05 ~ 2.1x speedup).
        num_inference_steps: Length of the step-counter cycle. The counter wraps
            back to 0 when it reaches this value, and the first and last call of
            each cycle always run the full computation.

    Raises:
        ValueError: If ``num_inference_steps`` is smaller than 1.
    """

    def __init__(
        self,
        enabled: bool = True,
        rel_l1_thresh: float = 0.05,  # 0.03 for 1.6x speedup, 0.05 for 2.1x speedup
        num_inference_steps: int = 50,
    ):
        if num_inference_steps < 1:
            # The step counter only returns to 0 upon reaching this value; with a
            # non-positive count it would never reset.
            raise ValueError(
                f"num_inference_steps must be >= 1, got {num_inference_steps}"
            )
        self.enabled = enabled
        self.rel_l1_thresh = rel_l1_thresh
        self.num_inference_steps = num_inference_steps
        self.reset()

    def reset(self) -> None:
        """Clear the per-run cache state (e.g. before a new generation).

        Useful after an interrupted run: otherwise the step counter only returns
        to 0 after exactly ``num_inference_steps`` calls.
        """
        # Number of forward calls seen in the current cycle.
        self.cnt = 0
        # Sum of rescaled relative L1 changes since the last full computation.
        self.accumulated_rel_l1_distance = 0
        # Modulated first-block input from the previous call, compared with the next one.
        self.previous_modulated_input = None
        # (final hidden states - hidden states entering the blocks) from the last full computation.
        self.previous_residual = None

# Rescaling polynomial (coefficients highest degree first, as np.poly1d expects)
# applied to the raw relative L1 change of the modulated input before it is
# accumulated. Built once at import time instead of on every denoising step.
_TEACACHE_RESCALE_FUNC = np.poly1d(
    [2.14700694e+01, -1.28016453e+01, 2.31279151e+00, 7.92487521e-01, 9.69274326e-03]
)


def _teacache_should_calc(
    model: Any,
    config: "TeaCacheConfig",
    hidden_states: torch.Tensor,
    temb: torch.Tensor,
    batch_size: int,
) -> bool:
    """Decide whether this call must run the transformer blocks; updates ``config`` state.

    The first block's ``norm1`` is applied to ``hidden_states`` and modulated with
    the shift/scale obtained from that block's ``scale_shift_table`` plus ``temb``.
    The rescaled relative L1 change of this modulated input versus the previous
    call is accumulated, and the blocks are skipped only while the sum stays below
    ``config.rel_l1_thresh``. The first and last call of each
    ``num_inference_steps`` cycle always compute.
    """
    first_block = model.transformer_blocks[0]
    # Clone defensively: norm1 is an opaque module and must not touch the live tensor.
    inp = first_block.norm1(hidden_states.clone())
    num_ada_params = first_block.scale_shift_table.shape[0]
    ada_values = first_block.scale_shift_table[None, None] + temb.reshape(
        batch_size, temb.size(1), num_ada_params, -1
    )
    # Only the first two modulation entries (shift_msa / scale_msa) are needed here.
    shift_msa, scale_msa, *_ = ada_values.unbind(dim=2)
    modulated_inp = inp * (1 + scale_msa) + shift_msa

    previous_inp = config.previous_modulated_input
    previous_residual = config.previous_residual
    # Reuse is only possible with state from a compatible earlier call; after an
    # interrupted run or a batch-size/resolution change, fall back to computing.
    cache_usable = (
        previous_inp is not None
        and previous_residual is not None
        and previous_inp.shape == modulated_inp.shape
        and previous_residual.shape == hidden_states.shape
    )
    is_boundary_step = config.cnt == 0 or config.cnt == config.num_inference_steps - 1

    if is_boundary_step or not cache_usable:
        should_calc = True
        config.accumulated_rel_l1_distance = 0
    else:
        rel_diff = (
            (modulated_inp - previous_inp).abs().mean() / previous_inp.abs().mean()
        ).cpu().item()
        config.accumulated_rel_l1_distance += _TEACACHE_RESCALE_FUNC(rel_diff)
        if config.accumulated_rel_l1_distance < config.rel_l1_thresh:
            should_calc = False
        else:
            should_calc = True
            config.accumulated_rel_l1_distance = 0

    config.previous_modulated_input = modulated_inp
    config.cnt += 1
    if config.cnt == config.num_inference_steps:
        config.cnt = 0
    return should_calc


def _run_transformer_blocks(
    model: Any,
    hidden_states: torch.Tensor,
    encoder_hidden_states: torch.Tensor,
    temb: torch.Tensor,
    image_rotary_emb: Any,
    encoder_attention_mask: Optional[torch.Tensor],
) -> torch.Tensor:
    """Run every block in ``model.transformer_blocks``, checkpointing when enabled."""
    use_checkpointing = torch.is_grad_enabled() and model.gradient_checkpointing
    ckpt_kwargs: dict[str, Any] = {}
    if use_checkpointing and is_torch_version(">=", "1.11.0"):
        ckpt_kwargs["use_reentrant"] = False

    for block in model.transformer_blocks:
        if use_checkpointing:
            # checkpoint calls block(*args), matching the block's positional order.
            hidden_states = torch.utils.checkpoint.checkpoint(
                block,
                hidden_states,
                encoder_hidden_states,
                temb,
                image_rotary_emb,
                encoder_attention_mask,
                **ckpt_kwargs,
            )
        else:
            hidden_states = block(
                hidden_states=hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                temb=temb,
                image_rotary_emb=image_rotary_emb,
                encoder_attention_mask=encoder_attention_mask,
            )
    return hidden_states


def create_teacache_forward(original_forward: Any):
    """Build a TeaCache-enabled replacement for the LTX transformer's ``forward``.

    Args:
        original_forward: The unpatched forward method. Accepted for API
            compatibility but not called: the returned function re-implements the
            whole forward pass itself.

    Returns:
        A function with the transformer ``forward`` signature. When
        ``self.teacache_config`` exists and is enabled, it may skip all transformer
        blocks and add the residual cached from the last full computation instead.
    """
    # diffusers' standard transformer output type; it supports both
    # ``out.sample`` and ``out["sample"]``.
    from diffusers.models.modeling_outputs import Transformer2DModelOutput

    def teacache_forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        timestep: torch.LongTensor,
        encoder_attention_mask: torch.Tensor,
        num_frames: int,
        height: int,
        width: int,
        rope_interpolation_scale: Optional[tuple[float, float, float]] = None,
        attention_kwargs: Optional[dict[str, Any]] = None,
        return_dict: bool = True,
    ) -> Any:
        """Transformer forward pass with optional TeaCache block skipping.

        Returns:
            ``(output,)`` when ``return_dict`` is False, otherwise a
            ``Transformer2DModelOutput`` whose ``sample`` is the output tensor.
        """
        # Copy before popping so the caller's dict is not mutated.
        if attention_kwargs is not None:
            attention_kwargs = attention_kwargs.copy()
            lora_scale = attention_kwargs.pop("scale", 1.0)
        else:
            lora_scale = 1.0

        if USE_PEFT_BACKEND:
            scale_lora_layers(self, lora_scale)
        try:
            image_rotary_emb = self.rope(
                hidden_states, num_frames, height, width, rope_interpolation_scale
            )

            if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
                # Turn a {0,1} keep-mask into an additive bias: 0 -> -10000, 1 -> 0.
                encoder_attention_mask = (
                    1 - encoder_attention_mask.to(hidden_states.dtype)
                ) * -10000.0
                encoder_attention_mask = encoder_attention_mask.unsqueeze(1)

            batch_size = hidden_states.size(0)
            hidden_states = self.proj_in(hidden_states)

            temb, embedded_timestep = self.time_embed(
                timestep.flatten(),
                batch_size=batch_size,
                hidden_dtype=hidden_states.dtype,
            )
            temb = temb.view(batch_size, -1, temb.size(-1))
            embedded_timestep = embedded_timestep.view(
                batch_size, -1, embedded_timestep.size(-1)
            )

            encoder_hidden_states = self.caption_projection(encoder_hidden_states)
            encoder_hidden_states = encoder_hidden_states.view(
                batch_size, -1, hidden_states.size(-1)
            )

            config = getattr(self, "teacache_config", None)
            use_teacache = config is not None and config.enabled
            should_calc = True
            if use_teacache:
                should_calc = _teacache_should_calc(
                    self, config, hidden_states, temb, batch_size
                )

            if not should_calc:
                # Reuse: input + cached residual reproduces the last full output.
                # Out-of-place to avoid mutating a tensor autograd may still need.
                hidden_states = hidden_states + config.previous_residual
            else:
                ori_hidden_states = hidden_states.clone() if use_teacache else None
                hidden_states = _run_transformer_blocks(
                    self,
                    hidden_states,
                    encoder_hidden_states,
                    temb,
                    image_rotary_emb,
                    encoder_attention_mask,
                )

                scale_shift_values = (
                    self.scale_shift_table[None, None] + embedded_timestep[:, :, None]
                )
                shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1]
                hidden_states = self.norm_out(hidden_states)
                hidden_states = hidden_states * (1 + scale) + shift

                if use_teacache:
                    config.previous_residual = hidden_states - ori_hidden_states

            output = self.proj_out(hidden_states)
        finally:
            # Always undo LoRA scaling, even if the pass raised.
            if USE_PEFT_BACKEND:
                unscale_lora_layers(self, lora_scale)

        if not return_dict:
            return (output,)
        return Transformer2DModelOutput(sample=output)

    return teacache_forward

def enable_teacache(model_class: Type[LTXVideoTransformer3DModel], config: TeaCacheConfig) -> None:
    """Patch ``model_class`` so its ``forward`` runs the TeaCache-aware pass.

    The unpatched ``forward`` is saved once as ``_original_forward``. Repeated calls
    reuse that saved method rather than wrapping an already-patched one. The config
    is attached as a class attribute, so every instance of the class shares it.

    Args:
        model_class: The model class to patch.
        config: TeaCache settings and run state to attach to the class.
    """
    try:
        saved_forward = model_class._original_forward
    except AttributeError:
        # First activation: remember the pristine forward so it can be restored.
        saved_forward = model_class.forward
        model_class._original_forward = saved_forward

    model_class.forward = create_teacache_forward(saved_forward)
    model_class.teacache_config = config

def disable_teacache(model_class: Type[LTXVideoTransformer3DModel]) -> None:
    """Undo TeaCache patching on ``model_class``.

    If an ``_original_forward`` is visible on the class, its value is restored as
    ``forward``. Attributes are then deleted only from the class's own namespace.
    Calling this on a subclass whose ``_original_forward`` / ``teacache_config``
    are inherited from a patched base class therefore no longer raises
    ``AttributeError``, and the base class is left untouched.

    Args:
        model_class: The model class to unpatch.
    """
    # Live view of the class's own attributes (excludes inherited ones).
    own_attrs = vars(model_class)

    if hasattr(model_class, "_original_forward"):
        model_class.forward = model_class._original_forward
        if "_original_forward" in own_attrs:
            delattr(model_class, "_original_forward")

    if "teacache_config" in own_attrs:
        delattr(model_class, "teacache_config")