File size: 23,037 Bytes
067283f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""

A general implementation of adaln-modulated VIT-like~(DiT) transformer for video processing.

"""

from typing import Optional, Tuple

import torch
from einops import rearrange
from torch import nn
from torchvision import transforms

from enum import Enum
import logging

from comfy.ldm.modules.diffusionmodules.mmdit import RMSNorm

from .blocks import (
    FinalLayer,
    GeneralDITTransformerBlock,
    PatchEmbed,
    TimestepEmbedding,
    Timesteps,
)

from .position_embedding import LearnablePosEmbAxis, VideoRopePosition3DEmb


class DataType(Enum):
    """Kind of sample fed to the model.

    ``GeneralDIT.forward_before_blocks`` asserts that its ``data_type`` argument
    is a member of this enum (default: ``DataType.VIDEO``).
    """

    # Each member's value is its lowercase name.
    IMAGE = "image"
    VIDEO = "video"


class GeneralDIT(nn.Module):
    """
    A general implementation of adaln-modulated VIT-like~(DiT) transformer for video processing.

    Args:
        max_img_h (int): Maximum height of the input images.
        max_img_w (int): Maximum width of the input images.
        max_frames (int): Maximum number of frames in the video sequence.
        in_channels (int): Number of input channels (e.g., RGB channels for color images).
        out_channels (int): Number of output channels.
        patch_spatial (int): Spatial patch size, used for both height and width.
        patch_temporal (int): Temporal patch size.
        concat_padding_mask (bool): If True, a padding-mask channel is concatenated to the input,
            so the patch embedder is built with ``in_channels + 1`` input channels.
        block_config (str): Configuration of the transformer block. See Notes for supported block types.
        model_channels (int): Base number of channels used throughout the model.
        num_blocks (int): Number of transformer blocks.
        num_heads (int): Number of heads in the multi-head attention layers.
        mlp_ratio (float): Expansion ratio for MLP blocks.
        block_x_format (str): Layout of the token tensor inside the transformer blocks
            ('BTHWD' or 'THWBD').
        crossattn_emb_channels (int): Number of embedding channels for cross-attention.
        use_cross_attn_mask (bool): Whether to use mask in cross-attention.
        pos_emb_cls (str): Type of positional embeddings. ``build_pos_embed`` currently accepts only
            "rope3d"; any other value (including the default "sincos") raises ValueError.
        pos_emb_learnable (bool): Whether positional embeddings are learnable.
        pos_emb_interpolation (str): Method for interpolating positional embeddings.
        affline_emb_norm (bool): Whether to RMS-normalize the affine (timestep) embedding.
        use_adaln_lora (bool): Whether to use AdaLN-LoRA.
        adaln_lora_dim (int): Dimension for AdaLN-LoRA.
        rope_h_extrapolation_ratio (float): Height extrapolation ratio for RoPE.
        rope_w_extrapolation_ratio (float): Width extrapolation ratio for RoPE.
        rope_t_extrapolation_ratio (float): Temporal extrapolation ratio for RoPE.
        extra_per_block_abs_pos_emb (bool): Whether to add an extra absolute positional embedding
            to the tokens before every block.
        extra_per_block_abs_pos_emb_type (str): Type of the extra per-block positional embedding.
            Only "learnable" is accepted when ``extra_per_block_abs_pos_emb`` is True.
        extra_h_extrapolation_ratio (float): Height extrapolation ratio for extra embeddings.
        extra_w_extrapolation_ratio (float): Width extrapolation ratio for extra embeddings.
        extra_t_extrapolation_ratio (float): Temporal extrapolation ratio for extra embeddings.
        image_model: Accepted for interface compatibility; not used by this class.
        device, dtype: Forwarded (as ``weight_args``) to the submodules that create weights.
        operations: Forwarded to the submodules that create layers.

    Notes:
        Supported block types in block_config:

        * cross_attn, ca: Cross attention
        * full_attn: Full attention on all flattened tokens
        * mlp, ff: Feed forward block
    """

    def __init__(
        self,
        max_img_h: int,
        max_img_w: int,
        max_frames: int,
        in_channels: int,
        out_channels: int,
        patch_spatial: int,
        patch_temporal: int,
        concat_padding_mask: bool = True,
        # attention settings
        block_config: str = "FA-CA-MLP",
        model_channels: int = 768,
        num_blocks: int = 10,
        num_heads: int = 16,
        mlp_ratio: float = 4.0,
        block_x_format: str = "BTHWD",
        # cross attention settings
        crossattn_emb_channels: int = 1024,
        use_cross_attn_mask: bool = False,
        # positional embedding settings
        pos_emb_cls: str = "sincos",
        pos_emb_learnable: bool = False,
        pos_emb_interpolation: str = "crop",
        affline_emb_norm: bool = False,  # whether or not to normalize the affine embedding
        use_adaln_lora: bool = False,
        adaln_lora_dim: int = 256,
        rope_h_extrapolation_ratio: float = 1.0,
        rope_w_extrapolation_ratio: float = 1.0,
        rope_t_extrapolation_ratio: float = 1.0,
        extra_per_block_abs_pos_emb: bool = False,
        extra_per_block_abs_pos_emb_type: str = "sincos",
        extra_h_extrapolation_ratio: float = 1.0,
        extra_w_extrapolation_ratio: float = 1.0,
        extra_t_extrapolation_ratio: float = 1.0,
        image_model=None,
        device=None,
        dtype=None,
        operations=None,
    ) -> None:
        super().__init__()
        self.max_img_h = max_img_h
        self.max_img_w = max_img_w
        self.max_frames = max_frames
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.patch_spatial = patch_spatial
        self.patch_temporal = patch_temporal
        self.num_heads = num_heads
        self.num_blocks = num_blocks
        self.model_channels = model_channels
        self.use_cross_attn_mask = use_cross_attn_mask
        self.concat_padding_mask = concat_padding_mask
        # positional embedding settings
        self.pos_emb_cls = pos_emb_cls
        self.pos_emb_learnable = pos_emb_learnable
        self.pos_emb_interpolation = pos_emb_interpolation
        self.affline_emb_norm = affline_emb_norm
        self.rope_h_extrapolation_ratio = rope_h_extrapolation_ratio
        self.rope_w_extrapolation_ratio = rope_w_extrapolation_ratio
        self.rope_t_extrapolation_ratio = rope_t_extrapolation_ratio
        self.extra_per_block_abs_pos_emb = extra_per_block_abs_pos_emb
        self.extra_per_block_abs_pos_emb_type = extra_per_block_abs_pos_emb_type.lower()
        self.extra_h_extrapolation_ratio = extra_h_extrapolation_ratio
        self.extra_w_extrapolation_ratio = extra_w_extrapolation_ratio
        self.extra_t_extrapolation_ratio = extra_t_extrapolation_ratio
        self.dtype = dtype
        weight_args = {"device": device, "dtype": dtype}

        # The optional padding-mask channel is concatenated to the input in prepare_embedded_sequence.
        in_channels = in_channels + 1 if concat_padding_mask else in_channels
        self.x_embedder = PatchEmbed(
            spatial_patch_size=patch_spatial,
            temporal_patch_size=patch_temporal,
            in_channels=in_channels,
            out_channels=model_channels,
            bias=False,
            weight_args=weight_args,
            operations=operations,
        )

        self.build_pos_embed(device=device, dtype=dtype)
        self.block_x_format = block_x_format
        self.use_adaln_lora = use_adaln_lora
        self.adaln_lora_dim = adaln_lora_dim
        # [0]: raw timestep features, [1]: projection to (embedding, AdaLN-LoRA embedding).
        self.t_embedder = nn.ModuleList(
            [Timesteps(model_channels),
             TimestepEmbedding(model_channels, model_channels, use_adaln_lora=use_adaln_lora, weight_args=weight_args, operations=operations),]
        )

        # Blocks are keyed "block0".."block{num_blocks-1}"; these names are part of the state-dict layout.
        self.blocks = nn.ModuleDict()

        for idx in range(num_blocks):
            self.blocks[f"block{idx}"] = GeneralDITTransformerBlock(
                x_dim=model_channels,
                context_dim=crossattn_emb_channels,
                num_heads=num_heads,
                block_config=block_config,
                mlp_ratio=mlp_ratio,
                x_format=self.block_x_format,
                use_adaln_lora=use_adaln_lora,
                adaln_lora_dim=adaln_lora_dim,
                weight_args=weight_args,
                operations=operations,
            )

        if self.affline_emb_norm:
            logging.debug("Building affine embedding normalization layer")
            self.affline_norm = RMSNorm(model_channels, elementwise_affine=True, eps=1e-6)
        else:
            self.affline_norm = nn.Identity()

        self.final_layer = FinalLayer(
            hidden_size=self.model_channels,
            spatial_patch_size=self.patch_spatial,
            temporal_patch_size=self.patch_temporal,
            out_channels=self.out_channels,
            use_adaln_lora=self.use_adaln_lora,
            adaln_lora_dim=self.adaln_lora_dim,
            weight_args=weight_args,
            operations=operations,
        )

    def build_pos_embed(self, device=None, dtype=None):
        """
        Create ``self.pos_embedder`` and, when enabled, ``self.extra_pos_embedder``.

        Raises:
            ValueError: if ``self.pos_emb_cls`` is not "rope3d", or if extra per-block embeddings
                are enabled with a type other than "learnable".
        """
        if self.pos_emb_cls == "rope3d":
            cls_type = VideoRopePosition3DEmb
        else:
            raise ValueError(f"Unknown pos_emb_cls {self.pos_emb_cls}")

        logging.debug("Building positional embedding with %s class, impl %s", self.pos_emb_cls, cls_type)
        # Grid sizes are expressed in patches, not pixels/frames.
        kwargs = dict(
            model_channels=self.model_channels,
            len_h=self.max_img_h // self.patch_spatial,
            len_w=self.max_img_w // self.patch_spatial,
            len_t=self.max_frames // self.patch_temporal,
            is_learnable=self.pos_emb_learnable,
            interpolation=self.pos_emb_interpolation,
            head_dim=self.model_channels // self.num_heads,
            h_extrapolation_ratio=self.rope_h_extrapolation_ratio,
            w_extrapolation_ratio=self.rope_w_extrapolation_ratio,
            t_extrapolation_ratio=self.rope_t_extrapolation_ratio,
            device=device,
        )
        self.pos_embedder = cls_type(**kwargs)

        if self.extra_per_block_abs_pos_emb:
            if self.extra_per_block_abs_pos_emb_type not in ("learnable",):
                raise ValueError(f"Unknown extra_per_block_abs_pos_emb_type {self.extra_per_block_abs_pos_emb_type}")
            # Reuse the RoPE kwargs, overriding the extrapolation ratios and adding an explicit dtype.
            kwargs["h_extrapolation_ratio"] = self.extra_h_extrapolation_ratio
            kwargs["w_extrapolation_ratio"] = self.extra_w_extrapolation_ratio
            kwargs["t_extrapolation_ratio"] = self.extra_t_extrapolation_ratio
            kwargs["dtype"] = dtype
            self.extra_pos_embedder = LearnablePosEmbAxis(**kwargs)

    def prepare_embedded_sequence(
        self,
        x_B_C_T_H_W: torch.Tensor,
        fps: Optional[torch.Tensor] = None,
        padding_mask: Optional[torch.Tensor] = None,
        latent_condition: Optional[torch.Tensor] = None,
        latent_condition_sigma: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
        """
        Patch-embed the input video and compute its positional embeddings.

        Args:
            x_B_C_T_H_W: input video tensor of shape (B, C, T, H, W).
            fps: optional frames-per-second tensor forwarded to the positional embedders.
            padding_mask: only used when ``self.concat_padding_mask`` is True. It is resized (nearest)
                to the input's H x W and repeated over T before being concatenated as an extra channel.
                Presumably (B, 1, H', W'), matching the all-zeros default used when it is None.
            latent_condition, latent_condition_sigma: accepted for interface compatibility; unused here.

        Returns:
            A 3-tuple ``(x_B_T_H_W_D, rope_emb, extra_pos_emb)``:
                - the patch-embedded sequence produced by ``self.x_embedder``;
                - the RoPE embedding when "rope" is in ``self.pos_emb_cls`` (case-insensitive), otherwise
                  None, in which case an absolute embedding has already been added to the sequence;
                - the extra per-block absolute embedding when ``self.extra_per_block_abs_pos_emb`` is
                  True, otherwise None.
        """
        if self.concat_padding_mask:
            if padding_mask is not None:
                padding_mask = transforms.functional.resize(
                    padding_mask, list(x_B_C_T_H_W.shape[-2:]), interpolation=transforms.InterpolationMode.NEAREST
                )
            else:
                padding_mask = torch.zeros((x_B_C_T_H_W.shape[0], 1, x_B_C_T_H_W.shape[-2], x_B_C_T_H_W.shape[-1]), dtype=x_B_C_T_H_W.dtype, device=x_B_C_T_H_W.device)

            # (B, 1, H, W) -> (B, 1, 1, H, W) -> (B, 1, T, H, W), then append as an extra channel.
            x_B_C_T_H_W = torch.cat(
                [x_B_C_T_H_W, padding_mask.unsqueeze(1).repeat(1, 1, x_B_C_T_H_W.shape[2], 1, 1)], dim=1
            )
        x_B_T_H_W_D = self.x_embedder(x_B_C_T_H_W)

        if self.extra_per_block_abs_pos_emb:
            extra_pos_emb = self.extra_pos_embedder(x_B_T_H_W_D, fps=fps, device=x_B_C_T_H_W.device, dtype=x_B_C_T_H_W.dtype)
        else:
            extra_pos_emb = None

        if "rope" in self.pos_emb_cls.lower():
            # RoPE is applied inside attention, so it is returned rather than added here.
            return x_B_T_H_W_D, self.pos_embedder(x_B_T_H_W_D, fps=fps, device=x_B_C_T_H_W.device), extra_pos_emb

        if "fps_aware" in self.pos_emb_cls:
            x_B_T_H_W_D = x_B_T_H_W_D + self.pos_embedder(x_B_T_H_W_D, fps=fps, device=x_B_C_T_H_W.device)  # [B, T, H, W, D]
        else:
            x_B_T_H_W_D = x_B_T_H_W_D + self.pos_embedder(x_B_T_H_W_D, device=x_B_C_T_H_W.device)  # [B, T, H, W, D]

        return x_B_T_H_W_D, None, extra_pos_emb

    def decoder_head(
        self,
        x_B_T_H_W_D: torch.Tensor,
        emb_B_D: torch.Tensor,
        crossattn_emb: torch.Tensor,
        origin_shape: Tuple[int, int, int, int, int],  # [B, C, T, H, W]
        crossattn_mask: Optional[torch.Tensor] = None,
        adaln_lora_B_3D: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Apply the final layer to the token grid and fold the patches back into a video tensor.

        Args:
            x_B_T_H_W_D: tokens in (B, T', H', W', D) layout (T', H', W' counted in patches).
            emb_B_D: conditioning embedding passed to ``self.final_layer``.
            crossattn_emb, crossattn_mask: ignored (accepted for interface compatibility).
            origin_shape: shape (B, C, T, H, W) of the original, un-patchified input.
            adaln_lora_B_3D: optional AdaLN-LoRA embedding passed to ``self.final_layer``.

        Returns:
            Tensor of shape (B, C_out, T, H, W). C_out is inferred by the rearrange from the final
            layer's output width (presumably ``self.out_channels``).
        """
        del crossattn_emb, crossattn_mask
        B, C, T_before_patchify, H_before_patchify, W_before_patchify = origin_shape
        H_patches = H_before_patchify // self.patch_spatial
        W_patches = W_before_patchify // self.patch_spatial
        x_BT_HW_D = rearrange(x_B_T_H_W_D, "B T H W D -> (B T) (H W) D")
        x_BT_HW_D = self.final_layer(x_BT_HW_D, emb_B_D, adaln_lora_B_3D=adaln_lora_B_3D)
        # Reshape explicitly so the layout is (B*T', H'*W', -1) even if the final layer flattened
        # the token dimensions differently (e.g. to (B * T * H * W, 1*1, D)).
        x_BT_HW_D = x_BT_HW_D.view(
            B * T_before_patchify // self.patch_temporal,
            H_patches * W_patches,
            -1,
        )
        x_B_D_T_H_W = rearrange(
            x_BT_HW_D,
            "(B T) (H W) (p1 p2 t C) -> B C (T t) (H p1) (W p2)",
            p1=self.patch_spatial,
            p2=self.patch_spatial,
            H=H_patches,
            W=W_patches,
            t=self.patch_temporal,
            B=B,
        )
        return x_B_D_T_H_W

    def forward_before_blocks(
        self,
        x: torch.Tensor,
        timesteps: torch.Tensor,
        crossattn_emb: torch.Tensor,
        crossattn_mask: Optional[torch.Tensor] = None,
        fps: Optional[torch.Tensor] = None,
        image_size: Optional[torch.Tensor] = None,
        padding_mask: Optional[torch.Tensor] = None,
        scalar_feature: Optional[torch.Tensor] = None,
        data_type: Optional[DataType] = DataType.VIDEO,
        latent_condition: Optional[torch.Tensor] = None,
        latent_condition_sigma: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> dict:
        """
        Embed inputs and conditioning, and arrange them in the blocks' tensor layout.

        Args:
            x: (B, C, T, H, W) tensor of spatial-temp inputs
            timesteps: (B, ) tensor of timesteps
            crossattn_emb: (B, N, D) tensor of cross-attention embeddings
            crossattn_mask: (B, N) tensor of cross-attention masks. Only used when
                ``self.use_cross_attn_mask`` is True. Non-floating masks are treated as keep-masks
                (1 = keep) and turned into an additive bias; floating masks are used as-is.
            image_size: accepted for interface compatibility; unused.
            scalar_feature: must be None (not implemented).
            data_type: must be a ``DataType`` member.

        Returns:
            dict with keys "x", "affline_emb_B_D", "crossattn_emb", "crossattn_mask",
            "rope_emb_L_1_1_D", "adaln_lora_B_3D", "original_shape" and
            "extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D". "x", "crossattn_emb" and the extra positional
            embedding are in the layout given by the first block's ``x_format``.

        Raises:
            NotImplementedError: if ``scalar_feature`` is given.
            ValueError: if the first block's ``x_format`` is neither "THWBD" nor "BTHWD".
        """
        del kwargs
        assert isinstance(
            data_type, DataType
        ), f"Expected DataType, got {type(data_type)}. We need discuss this flag later."
        original_shape = x.shape
        x_B_T_H_W_D, rope_emb_L_1_1_D, extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = self.prepare_embedded_sequence(
            x,
            fps=fps,
            padding_mask=padding_mask,
            latent_condition=latent_condition,
            latent_condition_sigma=latent_condition_sigma,
        )

        # Raw timestep features are cast to the input dtype before the learned projection.
        timesteps_B_D, adaln_lora_B_3D = self.t_embedder[1](self.t_embedder[0](timesteps.flatten()).to(x.dtype))

        if scalar_feature is not None:
            raise NotImplementedError("Scalar feature is not implemented yet.")

        affline_emb_B_D = self.affline_norm(timesteps_B_D)

        if self.use_cross_attn_mask:
            if crossattn_mask is not None:
                if not torch.is_floating_point(crossattn_mask):
                    # Keep-mask -> additive bias (1 -> 0, 0 -> -finfo.max). Cast before subtracting
                    # so bool masks work too (torch does not support `-` on bool tensors).
                    crossattn_mask = (crossattn_mask.to(x.dtype) - 1) * torch.finfo(x.dtype).max
                crossattn_mask = crossattn_mask[:, None, None, :]  # [B, 1, 1, length]
        else:
            crossattn_mask = None

        block_x_format = self.blocks["block0"].x_format
        if block_x_format == "THWBD":
            x = rearrange(x_B_T_H_W_D, "B T H W D -> T H W B D")
            if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None:
                extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = rearrange(
                    extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, "B T H W D -> T H W B D"
                )
            crossattn_emb = rearrange(crossattn_emb, "B M D -> M B D")

            # NOTE(review): a non-None mask has already been expanded to 4-D ([B, 1, 1, length])
            # above, so neither the tensor truthiness test nor the 2-D "B M -> M B" pattern look
            # consistent with it. Confirm the intended mask layout for THWBD blocks.
            if crossattn_mask:
                crossattn_mask = rearrange(crossattn_mask, "B M -> M B")

        elif block_x_format == "BTHWD":
            x = x_B_T_H_W_D
        else:
            raise ValueError(f"Unknown x_format {block_x_format}")
        output = {
            "x": x,
            "affline_emb_B_D": affline_emb_B_D,
            "crossattn_emb": crossattn_emb,
            "crossattn_mask": crossattn_mask,
            "rope_emb_L_1_1_D": rope_emb_L_1_1_D,
            "adaln_lora_B_3D": adaln_lora_B_3D,
            "original_shape": original_shape,
            "extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D": extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D,
        }
        return output

    def forward(
        self,
        x: torch.Tensor,
        timesteps: torch.Tensor,
        context: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        fps: Optional[torch.Tensor] = None,
        image_size: Optional[torch.Tensor] = None,
        padding_mask: Optional[torch.Tensor] = None,
        scalar_feature: Optional[torch.Tensor] = None,
        data_type: Optional[DataType] = DataType.VIDEO,
        latent_condition: Optional[torch.Tensor] = None,
        latent_condition_sigma: Optional[torch.Tensor] = None,
        condition_video_augment_sigma: Optional[torch.Tensor] = None,
        **kwargs,
    ):
        """
        Run the full DiT: embed, apply every transformer block, then decode back to a video tensor.

        Args:
            x: (B, C, T, H, W) tensor of spatial-temp inputs
            timesteps: (B, ) tensor of timesteps
            context: (B, N, D) tensor of cross-attention embeddings
            attention_mask: (B, N) tensor of cross-attention masks
            condition_video_augment_sigma: (B,) used in lvg(long video generation), we add noise with this sigma to
                augment condition input, the lvg model will condition on the condition_video_augment_sigma value;
                it is forwarded to forward_before_blocks (which currently ignores it).

        Returns:
            Tensor of shape (B, C_out, T, H, W) produced by ``decoder_head``.
        """
        inputs = self.forward_before_blocks(
            x=x,
            timesteps=timesteps,
            crossattn_emb=context,
            crossattn_mask=attention_mask,
            fps=fps,
            image_size=image_size,
            padding_mask=padding_mask,
            scalar_feature=scalar_feature,
            data_type=data_type,
            latent_condition=latent_condition,
            latent_condition_sigma=latent_condition_sigma,
            condition_video_augment_sigma=condition_video_augment_sigma,
            **kwargs,
        )
        x = inputs["x"]
        affline_emb_B_D = inputs["affline_emb_B_D"]
        crossattn_emb = inputs["crossattn_emb"]
        crossattn_mask = inputs["crossattn_mask"]
        rope_emb_L_1_1_D = inputs["rope_emb_L_1_1_D"]
        adaln_lora_B_3D = inputs["adaln_lora_B_3D"]
        original_shape = inputs["original_shape"]
        # May be None when extra per-block embeddings are disabled; only cast when present.
        extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = inputs["extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D"]
        del inputs

        if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None:
            extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.to(x.dtype)
            assert (
                x.shape == extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape
            ), f"{x.shape} != {extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape} {original_shape}"

        block_x_format = self.blocks["block0"].x_format
        for block in self.blocks.values():
            assert (
                block_x_format == block.x_format
            ), f"First block has x_format {block_x_format}, got {block.x_format}"

            if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None:
                x += extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D
            x = block(
                x,
                affline_emb_B_D,
                crossattn_emb,
                crossattn_mask,
                rope_emb_L_1_1_D=rope_emb_L_1_1_D,
                adaln_lora_B_3D=adaln_lora_B_3D,
            )

        # The decoder head expects B T H W D; only THWBD blocks need their output permuted back.
        if block_x_format == "THWBD":
            x_B_T_H_W_D = rearrange(x, "T H W B D -> B T H W D")
        else:
            x_B_T_H_W_D = x

        x_B_D_T_H_W = self.decoder_head(
            x_B_T_H_W_D=x_B_T_H_W_D,
            emb_B_D=affline_emb_B_D,
            crossattn_emb=None,
            origin_shape=original_shape,
            crossattn_mask=None,
            adaln_lora_B_3D=adaln_lora_B_3D,
        )

        return x_B_D_T_H_W