tsqn committed · verified
Commit 715f620 · 1 Parent(s): 2a75bce

Delete xora

xora/__init__.py DELETED
File without changes
xora/models/__init__.py DELETED
File without changes
xora/models/autoencoders/__init__.py DELETED
File without changes
xora/models/autoencoders/causal_conv3d.py DELETED
@@ -1,62 +0,0 @@
1
- from typing import Tuple, Union
2
-
3
- import torch
4
- import torch.nn as nn
5
-
6
-
7
- class CausalConv3d(nn.Module):
8
- def __init__(
9
- self,
10
- in_channels,
11
- out_channels,
12
- kernel_size: int = 3,
13
- stride: Union[int, Tuple[int]] = 1,
14
- dilation: int = 1,
15
- groups: int = 1,
16
- **kwargs,
17
- ):
18
- super().__init__()
19
-
20
- self.in_channels = in_channels
21
- self.out_channels = out_channels
22
-
23
- kernel_size = (kernel_size, kernel_size, kernel_size)
24
- self.time_kernel_size = kernel_size[0]
25
-
26
- dilation = (dilation, 1, 1)
27
-
28
- height_pad = kernel_size[1] // 2
29
- width_pad = kernel_size[2] // 2
30
- padding = (0, height_pad, width_pad)
31
-
32
- self.conv = nn.Conv3d(
33
- in_channels,
34
- out_channels,
35
- kernel_size,
36
- stride=stride,
37
- dilation=dilation,
38
- padding=padding,
39
- padding_mode="zeros",
40
- groups=groups,
41
- )
42
-
43
- def forward(self, x, causal: bool = True):
44
- if causal:
45
- first_frame_pad = x[:, :, :1, :, :].repeat(
46
- (1, 1, self.time_kernel_size - 1, 1, 1)
47
- )
48
- x = torch.concatenate((first_frame_pad, x), dim=2)
49
- else:
50
- first_frame_pad = x[:, :, :1, :, :].repeat(
51
- (1, 1, (self.time_kernel_size - 1) // 2, 1, 1)
52
- )
53
- last_frame_pad = x[:, :, -1:, :, :].repeat(
54
- (1, 1, (self.time_kernel_size - 1) // 2, 1, 1)
55
- )
56
- x = torch.concatenate((first_frame_pad, x, last_frame_pad), dim=2)
57
- x = self.conv(x)
58
- return x
59
-
60
- @property
61
- def weight(self):
62
- return self.conv.weight
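
For reference, the deleted CausalConv3d keeps the convolution causal along time by padding only at the start of the frame axis: the first frame is repeated kernel_size - 1 times, so output frame t never depends on frames later than t. A minimal standalone sketch of that padding rule (plain PyTorch, shapes assumed to be (B, C, F, H, W)):

import torch

def causal_time_pad(x: torch.Tensor, time_kernel_size: int) -> torch.Tensor:
    # Repeat the first frame (time_kernel_size - 1) times and prepend it,
    # mirroring the causal branch of CausalConv3d.forward.
    first_frame_pad = x[:, :, :1, :, :].repeat(1, 1, time_kernel_size - 1, 1, 1)
    return torch.cat((first_frame_pad, x), dim=2)

x = torch.randn(1, 4, 8, 16, 16)           # (B, C, F, H, W)
padded = causal_time_pad(x, time_kernel_size=3)
print(padded.shape)                        # torch.Size([1, 4, 10, 16, 16])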
 
xora/models/autoencoders/causal_video_autoencoder.py DELETED
@@ -1,1199 +0,0 @@
1
- import json
2
- import os
3
- from functools import partial
4
- from types import SimpleNamespace
5
- from typing import Any, Mapping, Optional, Tuple, Union, List
6
-
7
- import torch
8
- import numpy as np
9
- from einops import rearrange
10
- from torch import nn
11
- from diffusers.utils import logging
12
- import torch.nn.functional as F
13
- from diffusers.models.embeddings import PixArtAlphaCombinedTimestepSizeEmbeddings
14
-
15
-
16
- from xora.models.autoencoders.conv_nd_factory import make_conv_nd, make_linear_nd
17
- from xora.models.autoencoders.pixel_norm import PixelNorm
18
- from xora.models.autoencoders.vae import AutoencoderKLWrapper
19
- from xora.models.transformers.attention import Attention
20
-
21
- logger = logging.get_logger(__name__) # pylint: disable=invalid-name
22
-
23
-
24
- class CausalVideoAutoencoder(AutoencoderKLWrapper):
25
- @classmethod
26
- def from_pretrained(
27
- cls,
28
- pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
29
- *args,
30
- **kwargs,
31
- ):
32
- config_local_path = pretrained_model_name_or_path / "config.json"
33
- config = cls.load_config(config_local_path, **kwargs)
34
- video_vae = cls.from_config(config)
35
- video_vae.to(kwargs["torch_dtype"])
36
-
37
- model_local_path = pretrained_model_name_or_path / "autoencoder.pth"
38
- ckpt_state_dict = torch.load(model_local_path, map_location=torch.device("cpu"))
39
- video_vae.load_state_dict(ckpt_state_dict)
40
-
41
- statistics_local_path = (
42
- pretrained_model_name_or_path / "per_channel_statistics.json"
43
- )
44
- if statistics_local_path.exists():
45
- with open(statistics_local_path, "r") as file:
46
- data = json.load(file)
47
- transposed_data = list(zip(*data["data"]))
48
- data_dict = {
49
- col: torch.tensor(vals)
50
- for col, vals in zip(data["columns"], transposed_data)
51
- }
52
- video_vae.register_buffer("std_of_means", data_dict["std-of-means"])
53
- video_vae.register_buffer(
54
- "mean_of_means",
55
- data_dict.get(
56
- "mean-of-means", torch.zeros_like(data_dict["std-of-means"])
57
- ),
58
- )
59
-
60
- return video_vae
61
-
62
- @staticmethod
63
- def from_config(config):
64
- assert (
65
- config["_class_name"] == "CausalVideoAutoencoder"
66
- ), "config must have _class_name=CausalVideoAutoencoder"
67
- if isinstance(config["dims"], list):
68
- config["dims"] = tuple(config["dims"])
69
-
70
- assert config["dims"] in [2, 3, (2, 1)], "dims must be 2, 3 or (2, 1)"
71
-
72
- double_z = config.get("double_z", True)
73
- latent_log_var = config.get(
74
- "latent_log_var", "per_channel" if double_z else "none"
75
- )
76
- use_quant_conv = config.get("use_quant_conv", True)
77
-
78
- if use_quant_conv and latent_log_var == "uniform":
79
- raise ValueError("uniform latent_log_var requires use_quant_conv=False")
80
-
81
- encoder = Encoder(
82
- dims=config["dims"],
83
- in_channels=config.get("in_channels", 3),
84
- out_channels=config["latent_channels"],
85
- blocks=config.get("encoder_blocks", config.get("blocks")),
86
- patch_size=config.get("patch_size", 1),
87
- latent_log_var=latent_log_var,
88
- norm_layer=config.get("norm_layer", "group_norm"),
89
- )
90
-
91
- decoder = Decoder(
92
- dims=config["dims"],
93
- in_channels=config["latent_channels"],
94
- out_channels=config.get("out_channels", 3),
95
- blocks=config.get("decoder_blocks", config.get("blocks")),
96
- patch_size=config.get("patch_size", 1),
97
- norm_layer=config.get("norm_layer", "group_norm"),
98
- causal=config.get("causal_decoder", False),
99
- timestep_conditioning=config.get("timestep_conditioning", False),
100
- )
101
-
102
- dims = config["dims"]
103
- return CausalVideoAutoencoder(
104
- encoder=encoder,
105
- decoder=decoder,
106
- latent_channels=config["latent_channels"],
107
- dims=dims,
108
- use_quant_conv=use_quant_conv,
109
- )
110
-
111
- @property
112
- def config(self):
113
- return SimpleNamespace(
114
- _class_name="CausalVideoAutoencoder",
115
- dims=self.dims,
116
- in_channels=self.encoder.conv_in.in_channels // self.encoder.patch_size**2,
117
- out_channels=self.decoder.conv_out.out_channels
118
- // self.decoder.patch_size**2,
119
- latent_channels=self.decoder.conv_in.in_channels,
120
- encoder_blocks=self.encoder.blocks_desc,
121
- decoder_blocks=self.decoder.blocks_desc,
122
- scaling_factor=1.0,
123
- norm_layer=self.encoder.norm_layer,
124
- patch_size=self.encoder.patch_size,
125
- latent_log_var=self.encoder.latent_log_var,
126
- use_quant_conv=self.use_quant_conv,
127
- causal_decoder=self.decoder.causal,
128
- timestep_conditioning=self.decoder.timestep_conditioning,
129
- )
130
-
131
- @property
132
- def is_video_supported(self):
133
- """
134
- Check if the model supports video inputs of shape (B, C, F, H, W). Otherwise, the model only supports 2D images.
135
- """
136
- return self.dims != 2
137
-
138
- @property
139
- def spatial_downscale_factor(self):
140
- return (
141
- 2
142
- ** len(
143
- [
144
- block
145
- for block in self.encoder.blocks_desc
146
- if block[0] in ["compress_space", "compress_all"]
147
- ]
148
- )
149
- * self.encoder.patch_size
150
- )
151
-
152
- @property
153
- def temporal_downscale_factor(self):
154
- return 2 ** len(
155
- [
156
- block
157
- for block in self.encoder.blocks_desc
158
- if block[0] in ["compress_time", "compress_all"]
159
- ]
160
- )
161
-
162
- def to_json_string(self) -> str:
163
- import json
164
-
165
- return json.dumps(self.config.__dict__)
166
-
167
- def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
168
- per_channel_statistics_prefix = "per_channel_statistics."
169
- ckpt_state_dict = {
170
- key: value
171
- for key, value in state_dict.items()
172
- if not key.startswith(per_channel_statistics_prefix)
173
- }
174
-
175
- model_keys = set(name for name, _ in self.named_parameters())
176
-
177
- key_mapping = {
178
- ".resnets.": ".res_blocks.",
179
- "downsamplers.0": "downsample",
180
- "upsamplers.0": "upsample",
181
- }
182
- converted_state_dict = {}
183
- for key, value in ckpt_state_dict.items():
184
- for k, v in key_mapping.items():
185
- key = key.replace(k, v)
186
-
187
- if "norm" in key and key not in model_keys:
188
- logger.info(
189
- f"Removing key {key} from state_dict as it is not present in the model"
190
- )
191
- continue
192
-
193
- converted_state_dict[key] = value
194
-
195
- super().load_state_dict(converted_state_dict, strict=strict)
196
-
197
- data_dict = {
198
- key.removeprefix(per_channel_statistics_prefix): value
199
- for key, value in state_dict.items()
200
- if key.startswith(per_channel_statistics_prefix)
201
- }
202
- if len(data_dict) > 0:
203
- self.register_buffer("std_of_means", data_dict["std-of-means"])
204
- self.register_buffer(
205
- "mean_of_means",
206
- data_dict.get(
207
- "mean-of-means", torch.zeros_like(data_dict["std-of-means"])
208
- ),
209
- )
210
-
211
- def last_layer(self):
212
- if hasattr(self.decoder, "conv_out"):
213
- if isinstance(self.decoder.conv_out, nn.Sequential):
214
- last_layer = self.decoder.conv_out[-1]
215
- else:
216
- last_layer = self.decoder.conv_out
217
- else:
218
- last_layer = self.decoder.layers[-1]
219
- return last_layer
220
-
221
- def set_use_tpu_flash_attention(self):
222
- for block in self.decoder.up_blocks:
223
- if isinstance(block, UNetMidBlock3D) and block.attention_blocks:
224
- for attention_block in block.attention_blocks:
225
- attention_block.set_use_tpu_flash_attention()
226
-
227
-
228
- class Encoder(nn.Module):
229
- r"""
230
- The `Encoder` layer of a variational autoencoder that encodes its input into a latent representation.
231
-
232
- Args:
233
- dims (`int` or `Tuple[int, int]`, *optional*, defaults to 3):
234
- The number of dimensions to use in convolutions.
235
- in_channels (`int`, *optional*, defaults to 3):
236
- The number of input channels.
237
- out_channels (`int`, *optional*, defaults to 3):
238
- The number of output channels.
239
- blocks (`List[Tuple[str, int]]`, *optional*, defaults to `[("res_x", 1)]`):
240
- The blocks to use. Each block is a tuple of the block name and the number of layers.
241
- base_channels (`int`, *optional*, defaults to 128):
242
- The number of output channels for the first convolutional layer.
243
- norm_num_groups (`int`, *optional*, defaults to 32):
244
- The number of groups for normalization.
245
- patch_size (`int`, *optional*, defaults to 1):
246
- The patch size to use. Should be a power of 2.
247
- norm_layer (`str`, *optional*, defaults to `group_norm`):
248
- The normalization layer to use. Can be either `group_norm` or `pixel_norm`.
249
- latent_log_var (`str`, *optional*, defaults to `per_channel`):
250
- How the log variance is parameterized. Can be either `per_channel`, `uniform`, or `none`.
251
- """
252
-
253
- def __init__(
254
- self,
255
- dims: Union[int, Tuple[int, int]] = 3,
256
- in_channels: int = 3,
257
- out_channels: int = 3,
258
- blocks: List[Tuple[str, Union[int, dict]]] = [("res_x", 1)],
259
- base_channels: int = 128,
260
- norm_num_groups: int = 32,
261
- patch_size: Union[int, Tuple[int]] = 1,
262
- norm_layer: str = "group_norm", # group_norm, pixel_norm
263
- latent_log_var: str = "per_channel",
264
- ):
265
- super().__init__()
266
- self.patch_size = patch_size
267
- self.norm_layer = norm_layer
268
- self.latent_channels = out_channels
269
- self.latent_log_var = latent_log_var
270
- self.blocks_desc = blocks
271
-
272
- in_channels = in_channels * patch_size**2
273
- output_channel = base_channels
274
-
275
- self.conv_in = make_conv_nd(
276
- dims=dims,
277
- in_channels=in_channels,
278
- out_channels=output_channel,
279
- kernel_size=3,
280
- stride=1,
281
- padding=1,
282
- causal=True,
283
- )
284
-
285
- self.down_blocks = nn.ModuleList([])
286
-
287
- for block_name, block_params in blocks:
288
- input_channel = output_channel
289
- if isinstance(block_params, int):
290
- block_params = {"num_layers": block_params}
291
-
292
- if block_name == "res_x":
293
- block = UNetMidBlock3D(
294
- dims=dims,
295
- in_channels=input_channel,
296
- num_layers=block_params["num_layers"],
297
- resnet_eps=1e-6,
298
- resnet_groups=norm_num_groups,
299
- norm_layer=norm_layer,
300
- )
301
- elif block_name == "res_x_y":
302
- output_channel = block_params.get("multiplier", 2) * output_channel
303
- block = ResnetBlock3D(
304
- dims=dims,
305
- in_channels=input_channel,
306
- out_channels=output_channel,
307
- eps=1e-6,
308
- groups=norm_num_groups,
309
- norm_layer=norm_layer,
310
- )
311
- elif block_name == "compress_time":
312
- block = make_conv_nd(
313
- dims=dims,
314
- in_channels=input_channel,
315
- out_channels=output_channel,
316
- kernel_size=3,
317
- stride=(2, 1, 1),
318
- causal=True,
319
- )
320
- elif block_name == "compress_space":
321
- block = make_conv_nd(
322
- dims=dims,
323
- in_channels=input_channel,
324
- out_channels=output_channel,
325
- kernel_size=3,
326
- stride=(1, 2, 2),
327
- causal=True,
328
- )
329
- elif block_name == "compress_all":
330
- block = make_conv_nd(
331
- dims=dims,
332
- in_channels=input_channel,
333
- out_channels=output_channel,
334
- kernel_size=3,
335
- stride=(2, 2, 2),
336
- causal=True,
337
- )
338
- elif block_name == "compress_all_x_y":
339
- output_channel = block_params.get("multiplier", 2) * output_channel
340
- block = make_conv_nd(
341
- dims=dims,
342
- in_channels=input_channel,
343
- out_channels=output_channel,
344
- kernel_size=3,
345
- stride=(2, 2, 2),
346
- causal=True,
347
- )
348
- else:
349
- raise ValueError(f"unknown block: {block_name}")
350
-
351
- self.down_blocks.append(block)
352
-
353
- # out
354
- if norm_layer == "group_norm":
355
- self.conv_norm_out = nn.GroupNorm(
356
- num_channels=output_channel, num_groups=norm_num_groups, eps=1e-6
357
- )
358
- elif norm_layer == "pixel_norm":
359
- self.conv_norm_out = PixelNorm()
360
- elif norm_layer == "layer_norm":
361
- self.conv_norm_out = LayerNorm(output_channel, eps=1e-6)
362
-
363
- self.conv_act = nn.SiLU()
364
-
365
- conv_out_channels = out_channels
366
- if latent_log_var == "per_channel":
367
- conv_out_channels *= 2
368
- elif latent_log_var == "uniform":
369
- conv_out_channels += 1
370
- elif latent_log_var != "none":
371
- raise ValueError(f"Invalid latent_log_var: {latent_log_var}")
372
- self.conv_out = make_conv_nd(
373
- dims, output_channel, conv_out_channels, 3, padding=1, causal=True
374
- )
375
-
376
- self.gradient_checkpointing = False
377
-
378
- def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor:
379
- r"""The forward method of the `Encoder` class."""
380
-
381
- sample = patchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
382
- sample = self.conv_in(sample)
383
-
384
- checkpoint_fn = (
385
- partial(torch.utils.checkpoint.checkpoint, use_reentrant=False)
386
- if self.gradient_checkpointing and self.training
387
- else lambda x: x
388
- )
389
-
390
- for down_block in self.down_blocks:
391
- sample = checkpoint_fn(down_block)(sample)
392
-
393
- sample = self.conv_norm_out(sample)
394
- sample = self.conv_act(sample)
395
- sample = self.conv_out(sample)
396
-
397
- if self.latent_log_var == "uniform":
398
- last_channel = sample[:, -1:, ...]
399
- num_dims = sample.dim()
400
-
401
- if num_dims == 4:
402
- # For shape (B, C, H, W)
403
- repeated_last_channel = last_channel.repeat(
404
- 1, sample.shape[1] - 2, 1, 1
405
- )
406
- sample = torch.cat([sample, repeated_last_channel], dim=1)
407
- elif num_dims == 5:
408
- # For shape (B, C, F, H, W)
409
- repeated_last_channel = last_channel.repeat(
410
- 1, sample.shape[1] - 2, 1, 1, 1
411
- )
412
- sample = torch.cat([sample, repeated_last_channel], dim=1)
413
- else:
414
- raise ValueError(f"Invalid input shape: {sample.shape}")
415
-
416
- return sample
417
-
418
-
419
- class Decoder(nn.Module):
420
- r"""
421
- The `Decoder` layer of a variational autoencoder that decodes its latent representation into an output sample.
422
-
423
- Args:
424
- dims (`int` or `Tuple[int, int]`, *optional*, defaults to 3):
425
- The number of dimensions to use in convolutions.
426
- in_channels (`int`, *optional*, defaults to 3):
427
- The number of input channels.
428
- out_channels (`int`, *optional*, defaults to 3):
429
- The number of output channels.
430
- blocks (`List[Tuple[str, int]]`, *optional*, defaults to `[("res_x", 1)]`):
431
- The blocks to use. Each block is a tuple of the block name and the number of layers.
432
- base_channels (`int`, *optional*, defaults to 128):
433
- The number of output channels for the first convolutional layer.
434
- norm_num_groups (`int`, *optional*, defaults to 32):
435
- The number of groups for normalization.
436
- patch_size (`int`, *optional*, defaults to 1):
437
- The patch size to use. Should be a power of 2.
438
- norm_layer (`str`, *optional*, defaults to `group_norm`):
439
- The normalization layer to use. Can be either `group_norm` or `pixel_norm`.
440
- causal (`bool`, *optional*, defaults to `True`):
441
- Whether to use causal convolutions or not.
442
- """
443
-
444
- def __init__(
445
- self,
446
- dims,
447
- in_channels: int = 3,
448
- out_channels: int = 3,
449
- blocks: List[Tuple[str, Union[int, dict]]] = [("res_x", 1)],
450
- base_channels: int = 128,
451
- layers_per_block: int = 2,
452
- norm_num_groups: int = 32,
453
- patch_size: int = 1,
454
- norm_layer: str = "group_norm",
455
- causal: bool = True,
456
- timestep_conditioning: bool = False,
457
- ):
458
- super().__init__()
459
- self.patch_size = patch_size
460
- self.layers_per_block = layers_per_block
461
- out_channels = out_channels * patch_size**2
462
- self.causal = causal
463
- self.blocks_desc = blocks
464
-
465
- # Compute output channel to be product of all channel-multiplier blocks
466
- output_channel = base_channels
467
- for block_name, block_params in list(reversed(blocks)):
468
- block_params = block_params if isinstance(block_params, dict) else {}
469
- if block_name == "res_x_y":
470
- output_channel = output_channel * block_params.get("multiplier", 2)
471
- if block_name == "compress_all":
472
- output_channel = output_channel * block_params.get("multiplier", 1)
473
-
474
- self.conv_in = make_conv_nd(
475
- dims,
476
- in_channels,
477
- output_channel,
478
- kernel_size=3,
479
- stride=1,
480
- padding=1,
481
- causal=True,
482
- )
483
-
484
- self.up_blocks = nn.ModuleList([])
485
-
486
- for block_name, block_params in list(reversed(blocks)):
487
- input_channel = output_channel
488
- if isinstance(block_params, int):
489
- block_params = {"num_layers": block_params}
490
-
491
- if block_name == "res_x":
492
- block = UNetMidBlock3D(
493
- dims=dims,
494
- in_channels=input_channel,
495
- num_layers=block_params["num_layers"],
496
- resnet_eps=1e-6,
497
- resnet_groups=norm_num_groups,
498
- norm_layer=norm_layer,
499
- inject_noise=block_params.get("inject_noise", False),
500
- timestep_conditioning=timestep_conditioning,
501
- )
502
- elif block_name == "attn_res_x":
503
- block = UNetMidBlock3D(
504
- dims=dims,
505
- in_channels=input_channel,
506
- num_layers=block_params["num_layers"],
507
- resnet_groups=norm_num_groups,
508
- norm_layer=norm_layer,
509
- inject_noise=block_params.get("inject_noise", False),
510
- timestep_conditioning=timestep_conditioning,
511
- attention_head_dim=block_params["attention_head_dim"],
512
- )
513
- elif block_name == "res_x_y":
514
- output_channel = output_channel // block_params.get("multiplier", 2)
515
- block = ResnetBlock3D(
516
- dims=dims,
517
- in_channels=input_channel,
518
- out_channels=output_channel,
519
- eps=1e-6,
520
- groups=norm_num_groups,
521
- norm_layer=norm_layer,
522
- inject_noise=block_params.get("inject_noise", False),
523
- timestep_conditioning=False,
524
- )
525
- elif block_name == "compress_time":
526
- block = DepthToSpaceUpsample(
527
- dims=dims, in_channels=input_channel, stride=(2, 1, 1)
528
- )
529
- elif block_name == "compress_space":
530
- block = DepthToSpaceUpsample(
531
- dims=dims, in_channels=input_channel, stride=(1, 2, 2)
532
- )
533
- elif block_name == "compress_all":
534
- output_channel = output_channel // block_params.get("multiplier", 1)
535
- block = DepthToSpaceUpsample(
536
- dims=dims,
537
- in_channels=input_channel,
538
- stride=(2, 2, 2),
539
- residual=block_params.get("residual", False),
540
- out_channels_reduction_factor=block_params.get("multiplier", 1),
541
- )
542
- else:
543
- raise ValueError(f"unknown layer: {block_name}")
544
-
545
- self.up_blocks.append(block)
546
-
547
- if norm_layer == "group_norm":
548
- self.conv_norm_out = nn.GroupNorm(
549
- num_channels=output_channel, num_groups=norm_num_groups, eps=1e-6
550
- )
551
- elif norm_layer == "pixel_norm":
552
- self.conv_norm_out = PixelNorm()
553
- elif norm_layer == "layer_norm":
554
- self.conv_norm_out = LayerNorm(output_channel, eps=1e-6)
555
-
556
- self.conv_act = nn.SiLU()
557
- self.conv_out = make_conv_nd(
558
- dims, output_channel, out_channels, 3, padding=1, causal=True
559
- )
560
-
561
- self.gradient_checkpointing = False
562
-
563
- self.timestep_conditioning = timestep_conditioning
564
-
565
- if timestep_conditioning:
566
- self.timestep_scale_multiplier = nn.Parameter(
567
- torch.tensor(1000.0, dtype=torch.float32)
568
- )
569
- self.last_time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(
570
- output_channel * 2, 0
571
- )
572
- self.last_scale_shift_table = nn.Parameter(
573
- torch.randn(2, output_channel) / output_channel**0.5
574
- )
575
-
576
- def forward(
577
- self,
578
- sample: torch.FloatTensor,
579
- target_shape,
580
- timesteps: Optional[torch.Tensor] = None,
581
- ) -> torch.FloatTensor:
582
- r"""The forward method of the `Decoder` class."""
583
- assert target_shape is not None, "target_shape must be provided"
584
- batch_size = sample.shape[0]
585
-
586
- sample = self.conv_in(sample, causal=self.causal)
587
-
588
- upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
589
-
590
- checkpoint_fn = (
591
- partial(torch.utils.checkpoint.checkpoint, use_reentrant=False)
592
- if self.gradient_checkpointing and self.training
593
- else lambda x: x
594
- )
595
-
596
- sample = sample.to(upscale_dtype)
597
-
598
- if self.timestep_conditioning:
599
- assert (
600
- timesteps is not None
601
- ), "should pass timesteps with timestep_conditioning=True"
602
- scaled_timesteps = timesteps * self.timestep_scale_multiplier
603
-
604
- for up_block in self.up_blocks:
605
- if self.timestep_conditioning and isinstance(up_block, UNetMidBlock3D):
606
- sample = checkpoint_fn(up_block)(
607
- sample, causal=self.causal, timesteps=scaled_timesteps
608
- )
609
- else:
610
- sample = checkpoint_fn(up_block)(sample, causal=self.causal)
611
-
612
- sample = self.conv_norm_out(sample)
613
-
614
- if self.timestep_conditioning:
615
- embedded_timesteps = self.last_time_embedder(
616
- timestep=scaled_timesteps.flatten(),
617
- resolution=None,
618
- aspect_ratio=None,
619
- batch_size=sample.shape[0],
620
- hidden_dtype=sample.dtype,
621
- )
622
- embedded_timesteps = embedded_timesteps.view(
623
- batch_size, embedded_timesteps.shape[-1], 1, 1, 1
624
- )
625
- ada_values = self.last_scale_shift_table[
626
- None, ..., None, None, None
627
- ] + embedded_timesteps.reshape(
628
- batch_size,
629
- 2,
630
- -1,
631
- embedded_timesteps.shape[-3],
632
- embedded_timesteps.shape[-2],
633
- embedded_timesteps.shape[-1],
634
- )
635
- shift, scale = ada_values.unbind(dim=1)
636
- sample = sample * (1 + scale) + shift
637
-
638
- sample = self.conv_act(sample)
639
- sample = self.conv_out(sample, causal=self.causal)
640
-
641
- sample = unpatchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
642
-
643
- return sample
644
-
645
-
646
- class UNetMidBlock3D(nn.Module):
647
- """
648
- A 3D UNet mid-block [`UNetMidBlock3D`] with multiple residual blocks.
649
-
650
- Args:
651
- in_channels (`int`): The number of input channels.
652
- dropout (`float`, *optional*, defaults to 0.0): The dropout rate.
653
- num_layers (`int`, *optional*, defaults to 1): The number of residual blocks.
654
- resnet_eps (`float`, *optional*, defaults to 1e-6): The epsilon value for the resnet blocks.
655
- resnet_groups (`int`, *optional*, defaults to 32):
656
- The number of groups to use in the group normalization layers of the resnet blocks.
657
- norm_layer (`str`, *optional*, defaults to `group_norm`):
658
- The normalization layer to use. Can be either `group_norm` or `pixel_norm`.
659
- inject_noise (`bool`, *optional*, defaults to `False`):
660
- Whether to inject noise into the hidden states.
661
- timestep_conditioning (`bool`, *optional*, defaults to `False`):
662
- Whether to condition the hidden states on the timestep.
663
- attention_head_dim (`int`, *optional*, defaults to -1):
664
- The dimension of the attention head. If -1, no attention is used.
665
-
666
- Returns:
667
- `torch.FloatTensor`: The output of the last residual block, which is a tensor of shape `(batch_size,
668
- in_channels, height, width)`.
669
-
670
- """
671
-
672
- def __init__(
673
- self,
674
- dims: Union[int, Tuple[int, int]],
675
- in_channels: int,
676
- dropout: float = 0.0,
677
- num_layers: int = 1,
678
- resnet_eps: float = 1e-6,
679
- resnet_groups: int = 32,
680
- norm_layer: str = "group_norm",
681
- inject_noise: bool = False,
682
- timestep_conditioning: bool = False,
683
- attention_head_dim: int = -1,
684
- ):
685
- super().__init__()
686
- resnet_groups = (
687
- resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
688
- )
689
- self.timestep_conditioning = timestep_conditioning
690
-
691
- if timestep_conditioning:
692
- self.time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(
693
- in_channels * 4, 0
694
- )
695
-
696
- self.res_blocks = nn.ModuleList(
697
- [
698
- ResnetBlock3D(
699
- dims=dims,
700
- in_channels=in_channels,
701
- out_channels=in_channels,
702
- eps=resnet_eps,
703
- groups=resnet_groups,
704
- dropout=dropout,
705
- norm_layer=norm_layer,
706
- inject_noise=inject_noise,
707
- timestep_conditioning=timestep_conditioning,
708
- )
709
- for _ in range(num_layers)
710
- ]
711
- )
712
-
713
- self.attention_blocks = None
714
-
715
- if attention_head_dim > 0:
716
- if attention_head_dim > in_channels:
717
- raise ValueError(
718
- "attention_head_dim must be less than or equal to in_channels"
719
- )
720
-
721
- self.attention_blocks = nn.ModuleList(
722
- [
723
- Attention(
724
- query_dim=in_channels,
725
- heads=in_channels // attention_head_dim,
726
- dim_head=attention_head_dim,
727
- bias=True,
728
- out_bias=True,
729
- qk_norm="rms_norm",
730
- residual_connection=True,
731
- )
732
- for _ in range(num_layers)
733
- ]
734
- )
735
-
736
- def forward(
737
- self,
738
- hidden_states: torch.FloatTensor,
739
- causal: bool = True,
740
- timesteps: Optional[torch.Tensor] = None,
741
- ) -> torch.FloatTensor:
742
- timestep_embed = None
743
- if self.timestep_conditioning:
744
- assert (
745
- timesteps is not None
746
- ), "should pass timesteps with timestep_conditioning=True"
747
- batch_size = hidden_states.shape[0]
748
- timestep_embed = self.time_embedder(
749
- timestep=timesteps.flatten(),
750
- resolution=None,
751
- aspect_ratio=None,
752
- batch_size=batch_size,
753
- hidden_dtype=hidden_states.dtype,
754
- )
755
- timestep_embed = timestep_embed.view(
756
- batch_size, timestep_embed.shape[-1], 1, 1, 1
757
- )
758
-
759
- if self.attention_blocks:
760
- for resnet, attention in zip(self.res_blocks, self.attention_blocks):
761
- hidden_states = resnet(
762
- hidden_states, causal=causal, timesteps=timestep_embed
763
- )
764
-
765
- # Reshape the hidden states to be (batch_size, frames * height * width, channel)
766
- batch_size, channel, frames, height, width = hidden_states.shape
767
- hidden_states = hidden_states.view(
768
- batch_size, channel, frames * height * width
769
- ).transpose(1, 2)
770
-
771
- if attention.use_tpu_flash_attention:
772
- # Pad the second dimension to be divisible by block_k_major (block in flash attention)
773
- seq_len = hidden_states.shape[1]
774
- block_k_major = 512
775
- pad_len = (block_k_major - seq_len % block_k_major) % block_k_major
776
- if pad_len > 0:
777
- hidden_states = F.pad(
778
- hidden_states, (0, 0, 0, pad_len), "constant", 0
779
- )
780
-
781
- # Create a mask with ones for the original sequence length and zeros for the padded indexes
782
- mask = torch.ones(
783
- (hidden_states.shape[0], seq_len),
784
- device=hidden_states.device,
785
- dtype=hidden_states.dtype,
786
- )
787
- if pad_len > 0:
788
- mask = F.pad(mask, (0, pad_len), "constant", 0)
789
-
790
- hidden_states = attention(
791
- hidden_states,
792
- attention_mask=(
793
- None if not attention.use_tpu_flash_attention else mask
794
- ),
795
- )
796
-
797
- if attention.use_tpu_flash_attention:
798
- # Remove the padding
799
- if pad_len > 0:
800
- hidden_states = hidden_states[:, :-pad_len, :]
801
-
802
- # Reshape the hidden states back to (batch_size, channel, frames, height, width)
803
- hidden_states = hidden_states.transpose(-1, -2).reshape(
804
- batch_size, channel, frames, height, width
805
- )
806
- else:
807
- for resnet in self.res_blocks:
808
- hidden_states = resnet(
809
- hidden_states, causal=causal, timesteps=timestep_embed
810
- )
811
-
812
- return hidden_states
813
-
814
-
815
- class DepthToSpaceUpsample(nn.Module):
816
- def __init__(
817
- self, dims, in_channels, stride, residual=False, out_channels_reduction_factor=1
818
- ):
819
- super().__init__()
820
- self.stride = stride
821
- self.out_channels = (
822
- np.prod(stride) * in_channels // out_channels_reduction_factor
823
- )
824
- self.conv = make_conv_nd(
825
- dims=dims,
826
- in_channels=in_channels,
827
- out_channels=self.out_channels,
828
- kernel_size=3,
829
- stride=1,
830
- causal=True,
831
- )
832
- self.residual = residual
833
- self.out_channels_reduction_factor = out_channels_reduction_factor
834
-
835
- def forward(self, x, causal: bool = True):
836
- if self.residual:
837
- # Reshape and duplicate the input to match the output shape
838
- x_in = rearrange(
839
- x,
840
- "b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)",
841
- p1=self.stride[0],
842
- p2=self.stride[1],
843
- p3=self.stride[2],
844
- )
845
- num_repeat = np.prod(self.stride) // self.out_channels_reduction_factor
846
- x_in = x_in.repeat(1, num_repeat, 1, 1, 1)
847
- if self.stride[0] == 2:
848
- x_in = x_in[:, :, 1:, :, :]
849
- x = self.conv(x, causal=causal)
850
- x = rearrange(
851
- x,
852
- "b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)",
853
- p1=self.stride[0],
854
- p2=self.stride[1],
855
- p3=self.stride[2],
856
- )
857
- if self.stride[0] == 2:
858
- x = x[:, :, 1:, :, :]
859
- if self.residual:
860
- x = x + x_in
861
- return x
862
-
863
-
864
- class LayerNorm(nn.Module):
865
- def __init__(self, dim, eps, elementwise_affine=True) -> None:
866
- super().__init__()
867
- self.norm = nn.LayerNorm(dim, eps=eps, elementwise_affine=elementwise_affine)
868
-
869
- def forward(self, x):
870
- x = rearrange(x, "b c d h w -> b d h w c")
871
- x = self.norm(x)
872
- x = rearrange(x, "b d h w c -> b c d h w")
873
- return x
874
-
875
-
876
- class ResnetBlock3D(nn.Module):
877
- r"""
878
- A Resnet block.
879
-
880
- Parameters:
881
- in_channels (`int`): The number of channels in the input.
882
- out_channels (`int`, *optional*, default to be `None`):
883
- The number of output channels for the first conv layer. If None, same as `in_channels`.
884
- dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use.
885
- groups (`int`, *optional*, default to `32`): The number of groups to use for the first normalization layer.
886
- eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization.
887
- """
888
-
889
- def __init__(
890
- self,
891
- dims: Union[int, Tuple[int, int]],
892
- in_channels: int,
893
- out_channels: Optional[int] = None,
894
- dropout: float = 0.0,
895
- groups: int = 32,
896
- eps: float = 1e-6,
897
- norm_layer: str = "group_norm",
898
- inject_noise: bool = False,
899
- timestep_conditioning: bool = False,
900
- ):
901
- super().__init__()
902
- self.in_channels = in_channels
903
- out_channels = in_channels if out_channels is None else out_channels
904
- self.out_channels = out_channels
905
- self.inject_noise = inject_noise
906
-
907
- if norm_layer == "group_norm":
908
- self.norm1 = nn.GroupNorm(
909
- num_groups=groups, num_channels=in_channels, eps=eps, affine=True
910
- )
911
- elif norm_layer == "pixel_norm":
912
- self.norm1 = PixelNorm()
913
- elif norm_layer == "layer_norm":
914
- self.norm1 = LayerNorm(in_channels, eps=eps, elementwise_affine=True)
915
-
916
- self.non_linearity = nn.SiLU()
917
-
918
- self.conv1 = make_conv_nd(
919
- dims,
920
- in_channels,
921
- out_channels,
922
- kernel_size=3,
923
- stride=1,
924
- padding=1,
925
- causal=True,
926
- )
927
-
928
- if inject_noise:
929
- self.per_channel_scale1 = nn.Parameter(torch.zeros((in_channels, 1, 1)))
930
-
931
- if norm_layer == "group_norm":
932
- self.norm2 = nn.GroupNorm(
933
- num_groups=groups, num_channels=out_channels, eps=eps, affine=True
934
- )
935
- elif norm_layer == "pixel_norm":
936
- self.norm2 = PixelNorm()
937
- elif norm_layer == "layer_norm":
938
- self.norm2 = LayerNorm(out_channels, eps=eps, elementwise_affine=True)
939
-
940
- self.dropout = torch.nn.Dropout(dropout)
941
-
942
- self.conv2 = make_conv_nd(
943
- dims,
944
- out_channels,
945
- out_channels,
946
- kernel_size=3,
947
- stride=1,
948
- padding=1,
949
- causal=True,
950
- )
951
-
952
- if inject_noise:
953
- self.per_channel_scale2 = nn.Parameter(torch.zeros((in_channels, 1, 1)))
954
-
955
- self.conv_shortcut = (
956
- make_linear_nd(
957
- dims=dims, in_channels=in_channels, out_channels=out_channels
958
- )
959
- if in_channels != out_channels
960
- else nn.Identity()
961
- )
962
-
963
- self.norm3 = (
964
- LayerNorm(in_channels, eps=eps, elementwise_affine=True)
965
- if in_channels != out_channels
966
- else nn.Identity()
967
- )
968
-
969
- self.timestep_conditioning = timestep_conditioning
970
-
971
- if timestep_conditioning:
972
- self.scale_shift_table = nn.Parameter(
973
- torch.randn(4, in_channels) / in_channels**0.5
974
- )
975
-
976
- def _feed_spatial_noise(
977
- self, hidden_states: torch.FloatTensor, per_channel_scale: torch.FloatTensor
978
- ) -> torch.FloatTensor:
979
- spatial_shape = hidden_states.shape[-2:]
980
- device = hidden_states.device
981
- dtype = hidden_states.dtype
982
-
983
- # similar to the "explicit noise inputs" method in style-gan
984
- spatial_noise = torch.randn(spatial_shape, device=device, dtype=dtype)[None]
985
- scaled_noise = (spatial_noise * per_channel_scale)[None, :, None, ...]
986
- hidden_states = hidden_states + scaled_noise
987
-
988
- return hidden_states
989
-
990
- def forward(
991
- self,
992
- input_tensor: torch.FloatTensor,
993
- causal: bool = True,
994
- timesteps: Optional[torch.Tensor] = None,
995
- ) -> torch.FloatTensor:
996
- hidden_states = input_tensor
997
- batch_size = hidden_states.shape[0]
998
-
999
- hidden_states = self.norm1(hidden_states)
1000
- if self.timestep_conditioning:
1001
- assert (
1002
- timesteps is not None
1003
- ), "should pass timesteps with timestep_conditioning=True"
1004
- ada_values = self.scale_shift_table[
1005
- None, ..., None, None, None
1006
- ] + timesteps.reshape(
1007
- batch_size,
1008
- 4,
1009
- -1,
1010
- timesteps.shape[-3],
1011
- timesteps.shape[-2],
1012
- timesteps.shape[-1],
1013
- )
1014
- shift1, scale1, shift2, scale2 = ada_values.unbind(dim=1)
1015
-
1016
- hidden_states = hidden_states * (1 + scale1) + shift1
1017
-
1018
- hidden_states = self.non_linearity(hidden_states)
1019
-
1020
- hidden_states = self.conv1(hidden_states, causal=causal)
1021
-
1022
- if self.inject_noise:
1023
- hidden_states = self._feed_spatial_noise(
1024
- hidden_states, self.per_channel_scale1
1025
- )
1026
-
1027
- hidden_states = self.norm2(hidden_states)
1028
-
1029
- if self.timestep_conditioning:
1030
- hidden_states = hidden_states * (1 + scale2) + shift2
1031
-
1032
- hidden_states = self.non_linearity(hidden_states)
1033
-
1034
- hidden_states = self.dropout(hidden_states)
1035
-
1036
- hidden_states = self.conv2(hidden_states, causal=causal)
1037
-
1038
- if self.inject_noise:
1039
- hidden_states = self._feed_spatial_noise(
1040
- hidden_states, self.per_channel_scale2
1041
- )
1042
-
1043
- input_tensor = self.norm3(input_tensor)
1044
-
1045
- batch_size = input_tensor.shape[0]
1046
-
1047
- input_tensor = self.conv_shortcut(input_tensor)
1048
-
1049
- output_tensor = input_tensor + hidden_states
1050
-
1051
- return output_tensor
1052
-
1053
-
1054
- def patchify(x, patch_size_hw, patch_size_t=1):
1055
- if patch_size_hw == 1 and patch_size_t == 1:
1056
- return x
1057
- if x.dim() == 4:
1058
- x = rearrange(
1059
- x, "b c (h q) (w r) -> b (c r q) h w", q=patch_size_hw, r=patch_size_hw
1060
- )
1061
- elif x.dim() == 5:
1062
- x = rearrange(
1063
- x,
1064
- "b c (f p) (h q) (w r) -> b (c p r q) f h w",
1065
- p=patch_size_t,
1066
- q=patch_size_hw,
1067
- r=patch_size_hw,
1068
- )
1069
- else:
1070
- raise ValueError(f"Invalid input shape: {x.shape}")
1071
-
1072
- return x
1073
-
1074
-
1075
- def unpatchify(x, patch_size_hw, patch_size_t=1):
1076
- if patch_size_hw == 1 and patch_size_t == 1:
1077
- return x
1078
-
1079
- if x.dim() == 4:
1080
- x = rearrange(
1081
- x, "b (c r q) h w -> b c (h q) (w r)", q=patch_size_hw, r=patch_size_hw
1082
- )
1083
- elif x.dim() == 5:
1084
- x = rearrange(
1085
- x,
1086
- "b (c p r q) f h w -> b c (f p) (h q) (w r)",
1087
- p=patch_size_t,
1088
- q=patch_size_hw,
1089
- r=patch_size_hw,
1090
- )
1091
-
1092
- return x
1093
-
1094
-
1095
- def create_video_autoencoder_config(
1096
- latent_channels: int = 64,
1097
- ):
1098
- encoder_blocks = [
1099
- ("res_x", {"num_layers": 4}),
1100
- ("compress_all_x_y", {"multiplier": 3}),
1101
- ("res_x", {"num_layers": 4}),
1102
- ("compress_all_x_y", {"multiplier": 2}),
1103
- ("res_x", {"num_layers": 4}),
1104
- ("compress_all", {}),
1105
- ("res_x", {"num_layers": 3}),
1106
- ("res_x", {"num_layers": 4}),
1107
- ]
1108
- decoder_blocks = [
1109
- ("res_x", {"num_layers": 4}),
1110
- ("compress_all", {"residual": True}),
1111
- ("res_x_y", {"multiplier": 3}),
1112
- ("res_x", {"num_layers": 3}),
1113
- ("compress_all", {"residual": True}),
1114
- ("res_x_y", {"multiplier": 2}),
1115
- ("res_x", {"num_layers": 3}),
1116
- ("compress_all", {"residual": True}),
1117
- ("res_x", {"num_layers": 3}),
1118
- ("res_x", {"num_layers": 4}),
1119
- ]
1120
- return {
1121
- "_class_name": "CausalVideoAutoencoder",
1122
- "dims": 3,
1123
- "encoder_blocks": encoder_blocks,
1124
- "decoder_blocks": decoder_blocks,
1125
- "latent_channels": latent_channels,
1126
- "norm_layer": "pixel_norm",
1127
- "patch_size": 4,
1128
- "latent_log_var": "uniform",
1129
- "use_quant_conv": False,
1130
- "causal_decoder": False,
1131
- "timestep_conditioning": True,
1132
- }
1133
-
1134
-
1135
- def test_vae_patchify_unpatchify():
1136
- import torch
1137
-
1138
- x = torch.randn(2, 3, 8, 64, 64)
1139
- x_patched = patchify(x, patch_size_hw=4, patch_size_t=4)
1140
- x_unpatched = unpatchify(x_patched, patch_size_hw=4, patch_size_t=4)
1141
- assert torch.allclose(x, x_unpatched)
1142
-
1143
-
1144
- def demo_video_autoencoder_forward_backward():
1145
- # Configuration for the VideoAutoencoder
1146
- config = create_video_autoencoder_config()
1147
-
1148
- # Instantiate the VideoAutoencoder with the specified configuration
1149
- video_autoencoder = CausalVideoAutoencoder.from_config(config)
1150
-
1151
- print(video_autoencoder)
1152
- video_autoencoder.eval()
1153
- # Print the total number of parameters in the video autoencoder
1154
- total_params = sum(p.numel() for p in video_autoencoder.parameters())
1155
- print(f"Total number of parameters in VideoAutoencoder: {total_params:,}")
1156
-
1157
- # Create a mock input tensor simulating a batch of videos
1158
- # Shape: (batch_size, channels, depth, height, width)
1159
- # E.g., 2 videos, each with 3 color channels, 17 frames, and 64x64 pixels per frame
1160
- input_videos = torch.randn(2, 3, 17, 64, 64)
1161
-
1162
- # Forward pass: encode and decode the input videos
1163
- latent = video_autoencoder.encode(input_videos).latent_dist.mode()
1164
- print(f"input shape={input_videos.shape}")
1165
- print(f"latent shape={latent.shape}")
1166
-
1167
- timesteps = torch.ones(input_videos.shape[0]) * 0.1
1168
- reconstructed_videos = video_autoencoder.decode(
1169
- latent, target_shape=input_videos.shape, timesteps=timesteps
1170
- ).sample
1171
-
1172
- print(f"reconstructed shape={reconstructed_videos.shape}")
1173
-
1174
- # Validate that single image gets treated the same way as first frame
1175
- input_image = input_videos[:, :, :1, :, :]
1176
- image_latent = video_autoencoder.encode(input_image).latent_dist.mode()
1177
- _ = video_autoencoder.decode(
1178
- image_latent, target_shape=image_latent.shape, timesteps=timesteps
1179
- ).sample
1180
-
1181
- # first_frame_latent = latent[:, :, :1, :, :]
1182
-
1183
- # assert torch.allclose(image_latent, first_frame_latent, atol=1e-6)
1184
- # assert torch.allclose(reconstructed_image, reconstructed_videos[:, :, :1, :, :], atol=1e-6)
1185
- # assert (image_latent == first_frame_latent).all()
1186
- # assert (reconstructed_image == reconstructed_videos[:, :, :1, :, :]).all()
1187
-
1188
- # Calculate the loss (e.g., mean squared error)
1189
- loss = torch.nn.functional.mse_loss(input_videos, reconstructed_videos)
1190
-
1191
- # Perform backward pass
1192
- loss.backward()
1193
-
1194
- print(f"Demo completed with loss: {loss.item()}")
1195
-
1196
-
1197
- # Run the demo to execute the forward and backward pass
1198
- if __name__ == "__main__":
1199
- demo_video_autoencoder_forward_backward()
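
For reference, the patchify/unpatchify helpers above are plain space-to-depth rearranges over (frames, height, width) and invert each other exactly. A small self-contained check of the same round trip, assuming only torch and einops are available (the patch sizes below are illustrative):

import torch
from einops import rearrange

x = torch.randn(2, 3, 8, 64, 64)  # (B, C, F, H, W)

# Fold 4-frame groups and 4x4 spatial patches into channels, as patchify does.
packed = rearrange(
    x, "b c (f p) (h q) (w r) -> b (c p r q) f h w", p=4, q=4, r=4
)
# Undo it, as unpatchify does.
unpacked = rearrange(
    packed, "b (c p r q) f h w -> b c (f p) (h q) (w r)", p=4, q=4, r=4
)
assert torch.equal(x, unpacked)
print(packed.shape)  # torch.Size([2, 192, 2, 16, 16])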
 
xora/models/autoencoders/conv_nd_factory.py DELETED
@@ -1,82 +0,0 @@
1
- from typing import Tuple, Union
2
-
3
- import torch
4
-
5
- from xora.models.autoencoders.dual_conv3d import DualConv3d
6
- from xora.models.autoencoders.causal_conv3d import CausalConv3d
7
-
8
-
9
- def make_conv_nd(
10
- dims: Union[int, Tuple[int, int]],
11
- in_channels: int,
12
- out_channels: int,
13
- kernel_size: int,
14
- stride=1,
15
- padding=0,
16
- dilation=1,
17
- groups=1,
18
- bias=True,
19
- causal=False,
20
- ):
21
- if dims == 2:
22
- return torch.nn.Conv2d(
23
- in_channels=in_channels,
24
- out_channels=out_channels,
25
- kernel_size=kernel_size,
26
- stride=stride,
27
- padding=padding,
28
- dilation=dilation,
29
- groups=groups,
30
- bias=bias,
31
- )
32
- elif dims == 3:
33
- if causal:
34
- return CausalConv3d(
35
- in_channels=in_channels,
36
- out_channels=out_channels,
37
- kernel_size=kernel_size,
38
- stride=stride,
39
- padding=padding,
40
- dilation=dilation,
41
- groups=groups,
42
- bias=bias,
43
- )
44
- return torch.nn.Conv3d(
45
- in_channels=in_channels,
46
- out_channels=out_channels,
47
- kernel_size=kernel_size,
48
- stride=stride,
49
- padding=padding,
50
- dilation=dilation,
51
- groups=groups,
52
- bias=bias,
53
- )
54
- elif dims == (2, 1):
55
- return DualConv3d(
56
- in_channels=in_channels,
57
- out_channels=out_channels,
58
- kernel_size=kernel_size,
59
- stride=stride,
60
- padding=padding,
61
- bias=bias,
62
- )
63
- else:
64
- raise ValueError(f"unsupported dimensions: {dims}")
65
-
66
-
67
- def make_linear_nd(
68
- dims: int,
69
- in_channels: int,
70
- out_channels: int,
71
- bias=True,
72
- ):
73
- if dims == 2:
74
- return torch.nn.Conv2d(
75
- in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias
76
- )
77
- elif dims == 3 or dims == (2, 1):
78
- return torch.nn.Conv3d(
79
- in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias
80
- )
81
- else:
82
- raise ValueError(f"unsupported dimensions: {dims}")
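
For reference, make_conv_nd is a thin factory: dims=2 maps to Conv2d, dims=3 to Conv3d (or CausalConv3d when causal=True), and dims=(2, 1) to the factorized DualConv3d. A minimal sketch of the same dispatch restricted to the plain torch layers (the causal and dual branches from the deleted modules are intentionally omitted here):

from typing import Tuple, Union
import torch

def make_standard_conv(dims: Union[int, Tuple[int, int]], in_ch: int, out_ch: int, k: int):
    # Simplified dispatch mirroring make_conv_nd for the plain-torch cases only.
    if dims == 2:
        return torch.nn.Conv2d(in_ch, out_ch, k, padding=k // 2)
    if dims == 3:
        return torch.nn.Conv3d(in_ch, out_ch, k, padding=k // 2)
    raise ValueError(f"unsupported dimensions: {dims}")

conv = make_standard_conv(3, 3, 16, 3)
out = conv(torch.randn(1, 3, 8, 32, 32))
print(out.shape)  # torch.Size([1, 16, 8, 32, 32])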
 
xora/models/autoencoders/dual_conv3d.py DELETED
@@ -1,195 +0,0 @@
1
- import math
2
- from typing import Tuple, Union
3
-
4
- import torch
5
- import torch.nn as nn
6
- import torch.nn.functional as F
7
- from einops import rearrange
8
-
9
-
10
- class DualConv3d(nn.Module):
11
- def __init__(
12
- self,
13
- in_channels,
14
- out_channels,
15
- kernel_size,
16
- stride: Union[int, Tuple[int, int, int]] = 1,
17
- padding: Union[int, Tuple[int, int, int]] = 0,
18
- dilation: Union[int, Tuple[int, int, int]] = 1,
19
- groups=1,
20
- bias=True,
21
- ):
22
- super(DualConv3d, self).__init__()
23
-
24
- self.in_channels = in_channels
25
- self.out_channels = out_channels
26
- # Ensure kernel_size, stride, padding, and dilation are tuples of length 3
27
- if isinstance(kernel_size, int):
28
- kernel_size = (kernel_size, kernel_size, kernel_size)
29
- if kernel_size == (1, 1, 1):
30
- raise ValueError(
31
- "kernel_size must be greater than 1. Use make_linear_nd instead."
32
- )
33
- if isinstance(stride, int):
34
- stride = (stride, stride, stride)
35
- if isinstance(padding, int):
36
- padding = (padding, padding, padding)
37
- if isinstance(dilation, int):
38
- dilation = (dilation, dilation, dilation)
39
-
40
- # Set parameters for convolutions
41
- self.groups = groups
42
- self.bias = bias
43
-
44
- # Define the size of the channels after the first convolution
45
- intermediate_channels = (
46
- out_channels if in_channels < out_channels else in_channels
47
- )
48
-
49
- # Define parameters for the first convolution
50
- self.weight1 = nn.Parameter(
51
- torch.Tensor(
52
- intermediate_channels,
53
- in_channels // groups,
54
- 1,
55
- kernel_size[1],
56
- kernel_size[2],
57
- )
58
- )
59
- self.stride1 = (1, stride[1], stride[2])
60
- self.padding1 = (0, padding[1], padding[2])
61
- self.dilation1 = (1, dilation[1], dilation[2])
62
- if bias:
63
- self.bias1 = nn.Parameter(torch.Tensor(intermediate_channels))
64
- else:
65
- self.register_parameter("bias1", None)
66
-
67
- # Define parameters for the second convolution
68
- self.weight2 = nn.Parameter(
69
- torch.Tensor(
70
- out_channels, intermediate_channels // groups, kernel_size[0], 1, 1
71
- )
72
- )
73
- self.stride2 = (stride[0], 1, 1)
74
- self.padding2 = (padding[0], 0, 0)
75
- self.dilation2 = (dilation[0], 1, 1)
76
- if bias:
77
- self.bias2 = nn.Parameter(torch.Tensor(out_channels))
78
- else:
79
- self.register_parameter("bias2", None)
80
-
81
- # Initialize weights and biases
82
- self.reset_parameters()
83
-
84
- def reset_parameters(self):
85
- nn.init.kaiming_uniform_(self.weight1, a=math.sqrt(5))
86
- nn.init.kaiming_uniform_(self.weight2, a=math.sqrt(5))
87
- if self.bias:
88
- fan_in1, _ = nn.init._calculate_fan_in_and_fan_out(self.weight1)
89
- bound1 = 1 / math.sqrt(fan_in1)
90
- nn.init.uniform_(self.bias1, -bound1, bound1)
91
- fan_in2, _ = nn.init._calculate_fan_in_and_fan_out(self.weight2)
92
- bound2 = 1 / math.sqrt(fan_in2)
93
- nn.init.uniform_(self.bias2, -bound2, bound2)
94
-
95
- def forward(self, x, use_conv3d=False, skip_time_conv=False):
96
- if use_conv3d:
97
- return self.forward_with_3d(x=x, skip_time_conv=skip_time_conv)
98
- else:
99
- return self.forward_with_2d(x=x, skip_time_conv=skip_time_conv)
100
-
101
- def forward_with_3d(self, x, skip_time_conv):
102
- # First convolution
103
- x = F.conv3d(
104
- x,
105
- self.weight1,
106
- self.bias1,
107
- self.stride1,
108
- self.padding1,
109
- self.dilation1,
110
- self.groups,
111
- )
112
-
113
- if skip_time_conv:
114
- return x
115
-
116
- # Second convolution
117
- x = F.conv3d(
118
- x,
119
- self.weight2,
120
- self.bias2,
121
- self.stride2,
122
- self.padding2,
123
- self.dilation2,
124
- self.groups,
125
- )
126
-
127
- return x
128
-
129
- def forward_with_2d(self, x, skip_time_conv):
130
- b, c, d, h, w = x.shape
131
-
132
- # First 2D convolution
133
- x = rearrange(x, "b c d h w -> (b d) c h w")
134
- # Squeeze the depth dimension out of weight1 since it's 1
135
- weight1 = self.weight1.squeeze(2)
136
- # Select stride, padding, and dilation for the 2D convolution
137
- stride1 = (self.stride1[1], self.stride1[2])
138
- padding1 = (self.padding1[1], self.padding1[2])
139
- dilation1 = (self.dilation1[1], self.dilation1[2])
140
- x = F.conv2d(x, weight1, self.bias1, stride1, padding1, dilation1, self.groups)
141
-
142
- _, _, h, w = x.shape
143
-
144
- if skip_time_conv:
145
- x = rearrange(x, "(b d) c h w -> b c d h w", b=b)
146
- return x
147
-
148
- # Second convolution which is essentially treated as a 1D convolution across the 'd' dimension
149
- x = rearrange(x, "(b d) c h w -> (b h w) c d", b=b)
150
-
151
- # Reshape weight2 to match the expected dimensions for conv1d
152
- weight2 = self.weight2.squeeze(-1).squeeze(-1)
153
- # Use only the relevant dimension for stride, padding, and dilation for the 1D convolution
154
- stride2 = self.stride2[0]
155
- padding2 = self.padding2[0]
156
- dilation2 = self.dilation2[0]
157
- x = F.conv1d(x, weight2, self.bias2, stride2, padding2, dilation2, self.groups)
158
- x = rearrange(x, "(b h w) c d -> b c d h w", b=b, h=h, w=w)
159
-
160
- return x
161
-
162
- @property
163
- def weight(self):
164
- return self.weight2
165
-
166
-
167
- def test_dual_conv3d_consistency():
168
- # Initialize parameters
169
- in_channels = 3
170
- out_channels = 5
171
- kernel_size = (3, 3, 3)
172
- stride = (2, 2, 2)
173
- padding = (1, 1, 1)
174
-
175
- # Create an instance of the DualConv3d class
176
- dual_conv3d = DualConv3d(
177
- in_channels=in_channels,
178
- out_channels=out_channels,
179
- kernel_size=kernel_size,
180
- stride=stride,
181
- padding=padding,
182
- bias=True,
183
- )
184
-
185
- # Example input tensor
186
- test_input = torch.randn(1, 3, 10, 10, 10)
187
-
188
- # Perform forward passes with both 3D and 2D settings
189
- output_conv3d = dual_conv3d(test_input, use_conv3d=True)
190
- output_2d = dual_conv3d(test_input, use_conv3d=False)
191
-
192
- # Assert that the outputs from both methods are sufficiently close
193
- assert torch.allclose(
194
- output_conv3d, output_2d, atol=1e-6
195
- ), "Outputs are not consistent between 3D and 2D convolutions."
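
For reference, DualConv3d factorizes a k x k x k convolution into a 1 x k x k spatial convolution followed by a k x 1 x 1 temporal one (a (2+1)D decomposition), replacing one dense 3D kernel with two smaller ones. A rough parameter-count comparison under the same settings as the consistency test above (3 -> 5 channels, 3x3x3 kernel, biases ignored):

# Parameter counts only; intermediate_channels follows the rule in DualConv3d.__init__.
in_ch, out_ch, k = 3, 5, 3
mid_ch = max(in_ch, out_ch)  # "out_channels if in_channels < out_channels else in_channels"

full_3d = in_ch * out_ch * k * k * k                  # one dense 3D kernel
dual = in_ch * mid_ch * k * k + mid_ch * out_ch * k   # 1xkxk spatial + kx1x1 temporal
print(full_3d, dual)  # 405 vs 210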
 
xora/models/autoencoders/pixel_norm.py DELETED
@@ -1,12 +0,0 @@
1
- import torch
2
- from torch import nn
3
-
4
-
5
- class PixelNorm(nn.Module):
6
- def __init__(self, dim=1, eps=1e-8):
7
- super(PixelNorm, self).__init__()
8
- self.dim = dim
9
- self.eps = eps
10
-
11
- def forward(self, x):
12
- return x / torch.sqrt(torch.mean(x**2, dim=self.dim, keepdim=True) + self.eps)
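
For reference, PixelNorm divides each feature vector along dim=1 by its root-mean-square, so after normalization the per-position mean square over channels is approximately 1. A quick numeric check of that property, assuming only torch:

import torch

x = torch.randn(2, 8, 4, 4)
eps = 1e-8
y = x / torch.sqrt(torch.mean(x**2, dim=1, keepdim=True) + eps)  # same formula as PixelNorm.forward
print(torch.mean(y**2, dim=1).mean().item())  # ~1.0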
 
xora/models/autoencoders/vae.py DELETED
@@ -1,343 +0,0 @@
1
- from typing import Optional, Union
2
-
3
- import torch
4
- import inspect
5
- import math
6
- import torch.nn as nn
7
- from diffusers import ConfigMixin, ModelMixin
8
- from diffusers.models.autoencoders.vae import (
9
- DecoderOutput,
10
- DiagonalGaussianDistribution,
11
- )
12
- from diffusers.models.modeling_outputs import AutoencoderKLOutput
13
- from xora.models.autoencoders.conv_nd_factory import make_conv_nd
14
-
15
-
16
- class AutoencoderKLWrapper(ModelMixin, ConfigMixin):
17
- """Variational Autoencoder (VAE) model with KL loss.
18
-
19
- VAE from the paper Auto-Encoding Variational Bayes by Diederik P. Kingma and Max Welling.
20
- This model is a wrapper around an encoder and a decoder, and it adds a KL loss term to the reconstruction loss.
21
-
22
- Args:
23
- encoder (`nn.Module`):
24
- Encoder module.
25
- decoder (`nn.Module`):
26
- Decoder module.
27
- latent_channels (`int`, *optional*, defaults to 4):
28
- Number of latent channels.
29
- """
30
-
31
- def __init__(
32
- self,
33
- encoder: nn.Module,
34
- decoder: nn.Module,
35
- latent_channels: int = 4,
36
- dims: int = 2,
37
- sample_size=512,
38
- use_quant_conv: bool = True,
39
- ):
40
- super().__init__()
41
-
42
- # pass init params to Encoder
43
- self.encoder = encoder
44
- self.use_quant_conv = use_quant_conv
45
-
46
- # pass init params to Decoder
47
- quant_dims = 2 if dims == 2 else 3
48
- self.decoder = decoder
49
- if use_quant_conv:
50
- self.quant_conv = make_conv_nd(
51
- quant_dims, 2 * latent_channels, 2 * latent_channels, 1
52
- )
53
- self.post_quant_conv = make_conv_nd(
54
- quant_dims, latent_channels, latent_channels, 1
55
- )
56
- else:
57
- self.quant_conv = nn.Identity()
58
- self.post_quant_conv = nn.Identity()
59
- self.use_z_tiling = False
60
- self.use_hw_tiling = False
61
- self.dims = dims
62
- self.z_sample_size = 1
63
-
64
- self.decoder_params = inspect.signature(self.decoder.forward).parameters
65
-
66
- # only relevant if vae tiling is enabled
67
- self.set_tiling_params(sample_size=sample_size, overlap_factor=0.25)
68
-
69
- def set_tiling_params(self, sample_size: int = 512, overlap_factor: float = 0.25):
70
- self.tile_sample_min_size = sample_size
71
- num_blocks = len(self.encoder.down_blocks)
72
- self.tile_latent_min_size = int(sample_size / (2 ** (num_blocks - 1)))
73
- self.tile_overlap_factor = overlap_factor
74
-
75
- def enable_z_tiling(self, z_sample_size: int = 8):
76
- r"""
77
- Enable tiling during VAE decoding.
78
-
79
- When this option is enabled, the VAE will split the input tensor in tiles to compute decoding in several
80
- steps. This is useful to save some memory and allow larger batch sizes.
81
- """
82
- self.use_z_tiling = z_sample_size > 1
83
- self.z_sample_size = z_sample_size
84
- assert (
85
- z_sample_size % 8 == 0 or z_sample_size == 1
86
- ), f"z_sample_size must be a multiple of 8 or 1. Got {z_sample_size}."
87
-
88
- def disable_z_tiling(self):
89
- r"""
90
- Disable tiling during VAE decoding. If `use_tiling` was previously invoked, this method will go back to computing
91
- decoding in one step.
92
- """
93
- self.use_z_tiling = False
94
-
95
- def enable_hw_tiling(self):
96
- r"""
97
- Enable tiling during VAE decoding along the height and width dimension.
98
- """
99
- self.use_hw_tiling = True
100
-
101
- def disable_hw_tiling(self):
102
- r"""
103
- Disable tiling during VAE decoding along the height and width dimension.
104
- """
105
- self.use_hw_tiling = False
106
-
107
- def _hw_tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True):
108
- overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
109
- blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
110
- row_limit = self.tile_latent_min_size - blend_extent
111
-
112
- # Split the image into 512x512 tiles and encode them separately.
113
- rows = []
114
- for i in range(0, x.shape[3], overlap_size):
115
- row = []
116
- for j in range(0, x.shape[4], overlap_size):
117
- tile = x[
118
- :,
119
- :,
120
- :,
121
- i : i + self.tile_sample_min_size,
122
- j : j + self.tile_sample_min_size,
123
- ]
124
- tile = self.encoder(tile)
125
- tile = self.quant_conv(tile)
126
- row.append(tile)
127
- rows.append(row)
128
- result_rows = []
129
- for i, row in enumerate(rows):
130
- result_row = []
131
- for j, tile in enumerate(row):
132
- # blend the above tile and the left tile
133
- # to the current tile and add the current tile to the result row
134
- if i > 0:
135
- tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
136
- if j > 0:
137
- tile = self.blend_h(row[j - 1], tile, blend_extent)
138
- result_row.append(tile[:, :, :, :row_limit, :row_limit])
139
- result_rows.append(torch.cat(result_row, dim=4))
140
-
141
- moments = torch.cat(result_rows, dim=3)
142
- return moments
143
-
144
- def blend_z(
145
- self, a: torch.Tensor, b: torch.Tensor, blend_extent: int
146
- ) -> torch.Tensor:
147
- blend_extent = min(a.shape[2], b.shape[2], blend_extent)
148
- for z in range(blend_extent):
149
- b[:, :, z, :, :] = a[:, :, -blend_extent + z, :, :] * (
150
- 1 - z / blend_extent
151
- ) + b[:, :, z, :, :] * (z / blend_extent)
152
- return b
153
-
154
- def blend_v(
155
- self, a: torch.Tensor, b: torch.Tensor, blend_extent: int
156
- ) -> torch.Tensor:
157
- blend_extent = min(a.shape[3], b.shape[3], blend_extent)
158
- for y in range(blend_extent):
159
- b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (
160
- 1 - y / blend_extent
161
- ) + b[:, :, :, y, :] * (y / blend_extent)
162
- return b
163
-
164
- def blend_h(
165
- self, a: torch.Tensor, b: torch.Tensor, blend_extent: int
166
- ) -> torch.Tensor:
167
- blend_extent = min(a.shape[4], b.shape[4], blend_extent)
168
- for x in range(blend_extent):
169
- b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (
170
- 1 - x / blend_extent
171
- ) + b[:, :, :, :, x] * (x / blend_extent)
172
- return b
173
-
174
- def _hw_tiled_decode(self, z: torch.FloatTensor, target_shape):
175
- overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor))
176
- blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor)
177
- row_limit = self.tile_sample_min_size - blend_extent
178
- tile_target_shape = (
179
- *target_shape[:3],
180
- self.tile_sample_min_size,
181
- self.tile_sample_min_size,
182
- )
183
- # Split z into overlapping 64x64 tiles and decode them separately.
184
- # The tiles have an overlap to avoid seams between tiles.
185
- rows = []
186
- for i in range(0, z.shape[3], overlap_size):
187
- row = []
188
- for j in range(0, z.shape[4], overlap_size):
189
- tile = z[
190
- :,
191
- :,
192
- :,
193
- i : i + self.tile_latent_min_size,
194
- j : j + self.tile_latent_min_size,
195
- ]
196
- tile = self.post_quant_conv(tile)
197
- decoded = self.decoder(tile, target_shape=tile_target_shape)
198
- row.append(decoded)
199
- rows.append(row)
200
- result_rows = []
201
- for i, row in enumerate(rows):
202
- result_row = []
203
- for j, tile in enumerate(row):
204
- # blend the above tile and the left tile
205
- # to the current tile and add the current tile to the result row
206
- if i > 0:
207
- tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
208
- if j > 0:
209
- tile = self.blend_h(row[j - 1], tile, blend_extent)
210
- result_row.append(tile[:, :, :, :row_limit, :row_limit])
211
- result_rows.append(torch.cat(result_row, dim=4))
212
-
213
- dec = torch.cat(result_rows, dim=3)
214
- return dec
215
-
216
- def encode(
217
- self, z: torch.FloatTensor, return_dict: bool = True
218
- ) -> Union[DecoderOutput, torch.FloatTensor]:
219
- if self.use_z_tiling and z.shape[2] > self.z_sample_size > 1:
220
- num_splits = z.shape[2] // self.z_sample_size
221
- sizes = [self.z_sample_size] * num_splits
222
- sizes = (
223
- sizes + [z.shape[2] - sum(sizes)]
224
- if z.shape[2] - sum(sizes) > 0
225
- else sizes
226
- )
227
- tiles = z.split(sizes, dim=2)
228
- moments_tiles = [
229
- (
230
- self._hw_tiled_encode(z_tile, return_dict)
231
- if self.use_hw_tiling
232
- else self._encode(z_tile)
233
- )
234
- for z_tile in tiles
235
- ]
236
- moments = torch.cat(moments_tiles, dim=2)
237
-
238
- else:
239
- moments = (
240
- self._hw_tiled_encode(z, return_dict)
241
- if self.use_hw_tiling
242
- else self._encode(z)
243
- )
244
-
245
- posterior = DiagonalGaussianDistribution(moments)
246
- if not return_dict:
247
- return (posterior,)
248
-
249
- return AutoencoderKLOutput(latent_dist=posterior)
250
-
251
- def _encode(self, x: torch.FloatTensor) -> AutoencoderKLOutput:
252
- h = self.encoder(x)
253
- moments = self.quant_conv(h)
254
- return moments
255
-
256
- def _decode(
257
- self,
258
- z: torch.FloatTensor,
259
- target_shape=None,
260
- timesteps: Optional[torch.Tensor] = None,
261
- ) -> Union[DecoderOutput, torch.FloatTensor]:
262
- z = self.post_quant_conv(z)
263
- if "timesteps" in self.decoder_params:
264
- dec = self.decoder(z, target_shape=target_shape, timesteps=timesteps)
265
- else:
266
- dec = self.decoder(z, target_shape=target_shape)
267
- return dec
268
-
269
- def decode(
270
- self,
271
- z: torch.FloatTensor,
272
- return_dict: bool = True,
273
- target_shape=None,
274
- timesteps: Optional[torch.Tensor] = None,
275
- ) -> Union[DecoderOutput, torch.FloatTensor]:
276
- assert target_shape is not None, "target_shape must be provided for decoding"
277
- if self.use_z_tiling and z.shape[2] > self.z_sample_size > 1:
278
- reduction_factor = int(
279
- self.encoder.patch_size_t
280
- * 2
281
- ** (
282
- len(self.encoder.down_blocks)
283
- - 1
284
- - math.sqrt(self.encoder.patch_size)
285
- )
286
- )
287
- split_size = self.z_sample_size // reduction_factor
288
- num_splits = z.shape[2] // split_size
289
-
290
- # copy target shape, and divide frame dimension (=2) by the context size
291
- target_shape_split = list(target_shape)
292
- target_shape_split[2] = target_shape[2] // num_splits
293
-
294
- decoded_tiles = [
295
- (
296
- self._hw_tiled_decode(z_tile, target_shape_split)
297
- if self.use_hw_tiling
298
- else self._decode(z_tile, target_shape=target_shape_split)
299
- )
300
- for z_tile in torch.tensor_split(z, num_splits, dim=2)
301
- ]
302
- decoded = torch.cat(decoded_tiles, dim=2)
303
- else:
304
- decoded = (
305
- self._hw_tiled_decode(z, target_shape)
306
- if self.use_hw_tiling
307
- else self._decode(z, target_shape=target_shape, timesteps=timesteps)
308
- )
309
-
310
- if not return_dict:
311
- return (decoded,)
312
-
313
- return DecoderOutput(sample=decoded)
314
-
315
- def forward(
316
- self,
317
- sample: torch.FloatTensor,
318
- sample_posterior: bool = False,
319
- return_dict: bool = True,
320
- generator: Optional[torch.Generator] = None,
321
- ) -> Union[DecoderOutput, torch.FloatTensor]:
322
- r"""
323
- Args:
324
- sample (`torch.FloatTensor`): Input sample.
325
- sample_posterior (`bool`, *optional*, defaults to `False`):
326
- Whether to sample from the posterior.
327
- return_dict (`bool`, *optional*, defaults to `True`):
328
- Whether to return a [`DecoderOutput`] instead of a plain tuple.
329
- generator (`torch.Generator`, *optional*):
330
- Generator used to sample from the posterior.
331
- """
332
- x = sample
333
- posterior = self.encode(x).latent_dist
334
- if sample_posterior:
335
- z = posterior.sample(generator=generator)
336
- else:
337
- z = posterior.mode()
338
- dec = self.decode(z, target_shape=sample.shape).sample
339
-
340
- if not return_dict:
341
- return (dec,)
342
-
343
- return DecoderOutput(sample=dec)
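The tiled encode/decode paths above hide seams with the blend_z/blend_v/blend_h helpers, which linearly cross-fade the overlapping strip of two neighbouring tiles. A small self-contained sketch of the same cross-fade along the width axis (toy tensors, not the real tile sizes):

import torch

a = torch.zeros(1, 1, 1, 4, 8)   # "left" tile, shape (B, C, F, H, W)
b = torch.ones(1, 1, 1, 4, 8)    # "right" tile
blend_extent = 4

# Same loop as blend_h above: the weight ramps from tile `a` to tile `b`.
for x in range(blend_extent):
    b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (x / blend_extent)

print(b[0, 0, 0, 0])  # tensor([0.0000, 0.2500, 0.5000, 0.7500, 1.0000, 1.0000, 1.0000, 1.0000])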
xora/models/autoencoders/vae_encode.py DELETED
@@ -1,190 +0,0 @@
1
- import torch
2
- from diffusers import AutoencoderKL
3
- from einops import rearrange
4
- from torch import Tensor
5
-
6
-
7
- from xora.models.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder
8
- from xora.models.autoencoders.video_autoencoder import Downsample3D, VideoAutoencoder
9
-
10
- try:
11
- import torch_xla.core.xla_model as xm
12
- except ImportError:
13
- xm = None
14
-
15
-
16
- def vae_encode(
17
- media_items: Tensor,
18
- vae: AutoencoderKL,
19
- split_size: int = 1,
20
- vae_per_channel_normalize=False,
21
- ) -> Tensor:
22
- """
23
- Encodes media items (images or videos) into latent representations using a specified VAE model.
24
- The function supports processing batches of images or video frames and can handle the processing
25
- in smaller sub-batches if needed.
26
-
27
- Args:
28
- media_items (Tensor): A torch Tensor containing the media items to encode. The expected
29
- shape is (batch_size, channels, height, width) for images or (batch_size, channels,
30
- frames, height, width) for videos.
31
- vae (AutoencoderKL): An instance of the `AutoencoderKL` class from the `diffusers` library,
32
- pre-configured and loaded with the appropriate model weights.
33
- split_size (int, optional): The number of sub-batches to split the input batch into for encoding.
34
- If set to more than 1, the input media items are processed in smaller batches according to
35
- this value. Defaults to 1, which processes all items in a single batch.
36
-
37
- Returns:
38
- Tensor: A torch Tensor of the encoded latent representations. The shape of the tensor is adjusted
39
- to match the input shape, scaled by the model's configuration.
40
-
41
- Examples:
42
- >>> import torch
43
- >>> from diffusers import AutoencoderKL
44
- >>> vae = AutoencoderKL.from_pretrained('your-model-name')
45
- >>> images = torch.rand(10, 3, 8, 256, 256) # Example tensor with 10 videos of 8 frames.
46
- >>> latents = vae_encode(images, vae)
47
- >>> print(latents.shape) # Output shape will depend on the model's latent configuration.
48
-
49
- Note:
50
- In case of a video, the function encodes the media item frame by frame.
51
- """
52
- is_video_shaped = media_items.dim() == 5
53
- batch_size, channels = media_items.shape[0:2]
54
-
55
- if channels != 3:
56
- raise ValueError(f"Expects tensors with 3 channels, got {channels}.")
57
-
58
- if is_video_shaped and not isinstance(
59
- vae, (VideoAutoencoder, CausalVideoAutoencoder)
60
- ):
61
- media_items = rearrange(media_items, "b c n h w -> (b n) c h w")
62
- if split_size > 1:
63
- if len(media_items) % split_size != 0:
64
- raise ValueError(
65
- "Error: The batch size must be divisible by 'train.vae_bs_split'"
66
- )
67
- encode_bs = len(media_items) // split_size
68
- # latents = [vae.encode(image_batch).latent_dist.sample() for image_batch in media_items.split(encode_bs)]
69
- latents = []
70
- if media_items.device.type == "xla":
71
- xm.mark_step()
72
- for image_batch in media_items.split(encode_bs):
73
- latents.append(vae.encode(image_batch).latent_dist.sample())
74
- if media_items.device.type == "xla":
75
- xm.mark_step()
76
- latents = torch.cat(latents, dim=0)
77
- else:
78
- latents = vae.encode(media_items).latent_dist.sample()
79
-
80
- latents = normalize_latents(latents, vae, vae_per_channel_normalize)
81
- if is_video_shaped and not isinstance(
82
- vae, (VideoAutoencoder, CausalVideoAutoencoder)
83
- ):
84
- latents = rearrange(latents, "(b n) c h w -> b c n h w", b=batch_size)
85
- return latents
86
-
87
-
88
- def vae_decode(
89
- latents: Tensor,
90
- vae: AutoencoderKL,
91
- is_video: bool = True,
92
- split_size: int = 1,
93
- vae_per_channel_normalize=False,
94
- ) -> Tensor:
95
- is_video_shaped = latents.dim() == 5
96
- batch_size = latents.shape[0]
97
-
98
- if is_video_shaped and not isinstance(
99
- vae, (VideoAutoencoder, CausalVideoAutoencoder)
100
- ):
101
- latents = rearrange(latents, "b c n h w -> (b n) c h w")
102
- if split_size > 1:
103
- if len(latents) % split_size != 0:
104
- raise ValueError(
105
- "Error: The batch size must be divisible by 'train.vae_bs_split'"
106
- )
107
- encode_bs = len(latents) // split_size
108
- image_batch = [
109
- _run_decoder(latent_batch, vae, is_video, vae_per_channel_normalize)
110
- for latent_batch in latents.split(encode_bs)
111
- ]
112
- images = torch.cat(image_batch, dim=0)
113
- else:
114
- images = _run_decoder(latents, vae, is_video, vae_per_channel_normalize)
115
-
116
- if is_video_shaped and not isinstance(
117
- vae, (VideoAutoencoder, CausalVideoAutoencoder)
118
- ):
119
- images = rearrange(images, "(b n) c h w -> b c n h w", b=batch_size)
120
- return images
121
-
122
-
123
- def _run_decoder(
124
- latents: Tensor, vae: AutoencoderKL, is_video: bool, vae_per_channel_normalize=False
125
- ) -> Tensor:
126
- if isinstance(vae, (VideoAutoencoder, CausalVideoAutoencoder)):
127
- *_, fl, hl, wl = latents.shape
128
- temporal_scale, spatial_scale, _ = get_vae_size_scale_factor(vae)
129
- latents = latents.to(vae.dtype)
130
- image = vae.decode(
131
- un_normalize_latents(latents, vae, vae_per_channel_normalize),
132
- return_dict=False,
133
- target_shape=(
134
- 1,
135
- 3,
136
- fl * temporal_scale if is_video else 1,
137
- hl * spatial_scale,
138
- wl * spatial_scale,
139
- ),
140
- )[0]
141
- else:
142
- image = vae.decode(
143
- un_normalize_latents(latents, vae, vae_per_channel_normalize),
144
- return_dict=False,
145
- )[0]
146
- return image
147
-
148
-
149
- def get_vae_size_scale_factor(vae: AutoencoderKL) -> tuple:
150
- if isinstance(vae, CausalVideoAutoencoder):
151
- spatial = vae.spatial_downscale_factor
152
- temporal = vae.temporal_downscale_factor
153
- else:
154
- down_blocks = len(
155
- [
156
- block
157
- for block in vae.encoder.down_blocks
158
- if isinstance(block.downsample, Downsample3D)
159
- ]
160
- )
161
- spatial = vae.config.patch_size * 2**down_blocks
162
- temporal = (
163
- vae.config.patch_size_t * 2**down_blocks
164
- if isinstance(vae, VideoAutoencoder)
165
- else 1
166
- )
167
-
168
- return (temporal, spatial, spatial)
169
-
170
-
171
- def normalize_latents(
172
- latents: Tensor, vae: AutoencoderKL, vae_per_channel_normalize: bool = False
173
- ) -> Tensor:
174
- return (
175
- (latents - vae.mean_of_means.to(latents.dtype).view(1, -1, 1, 1, 1))
176
- / vae.std_of_means.to(latents.dtype).view(1, -1, 1, 1, 1)
177
- if vae_per_channel_normalize
178
- else latents * vae.config.scaling_factor
179
- )
180
-
181
-
182
- def un_normalize_latents(
183
- latents: Tensor, vae: AutoencoderKL, vae_per_channel_normalize: bool = False
184
- ) -> Tensor:
185
- return (
186
- latents * vae.std_of_means.to(latents.dtype).view(1, -1, 1, 1, 1)
187
- + vae.mean_of_means.to(latents.dtype).view(1, -1, 1, 1, 1)
188
- if vae_per_channel_normalize
189
- else latents / vae.config.scaling_factor
190
- )
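normalize_latents and un_normalize_latents above are exact inverses in both modes. A short sketch with stand-in statistics (the real std_of_means / mean_of_means buffers and vae.config.scaling_factor come from the loaded VAE; the values below are only placeholders):

import torch

latents = torch.randn(1, 4, 2, 8, 8)      # (B, C, F, H, W) latent tensor
mean_of_means = torch.zeros(4)            # stand-in per-channel statistics
std_of_means = torch.full((4,), 2.0)
scaling_factor = 0.13025                  # stand-in for vae.config.scaling_factor

# per-channel mode (vae_per_channel_normalize=True)
normed = (latents - mean_of_means.view(1, -1, 1, 1, 1)) / std_of_means.view(1, -1, 1, 1, 1)
restored = normed * std_of_means.view(1, -1, 1, 1, 1) + mean_of_means.view(1, -1, 1, 1, 1)
assert torch.allclose(latents, restored)

# scaling-factor mode (vae_per_channel_normalize=False)
assert torch.allclose(latents, (latents * scaling_factor) / scaling_factor)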
xora/models/autoencoders/video_autoencoder.py DELETED
@@ -1,1045 +0,0 @@
1
- import json
2
- import os
3
- from functools import partial
4
- from types import SimpleNamespace
5
- from typing import Any, Mapping, Optional, Tuple, Union
6
-
7
- import torch
8
- from einops import rearrange
9
- from torch import nn
10
- from torch.nn import functional
11
-
12
- from diffusers.utils import logging
13
-
14
- from xora.utils.torch_utils import Identity
15
- from xora.models.autoencoders.conv_nd_factory import make_conv_nd, make_linear_nd
16
- from xora.models.autoencoders.pixel_norm import PixelNorm
17
- from xora.models.autoencoders.vae import AutoencoderKLWrapper
18
-
19
- logger = logging.get_logger(__name__)
20
-
21
-
22
- class VideoAutoencoder(AutoencoderKLWrapper):
23
- @classmethod
24
- def from_pretrained(
25
- cls,
26
- pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
27
- *args,
28
- **kwargs,
29
- ):
30
- config_local_path = pretrained_model_name_or_path / "config.json"
31
- config = cls.load_config(config_local_path, **kwargs)
32
- video_vae = cls.from_config(config)
33
- video_vae.to(kwargs["torch_dtype"])
34
-
35
- model_local_path = pretrained_model_name_or_path / "autoencoder.pth"
36
- ckpt_state_dict = torch.load(model_local_path, map_location=torch.device("cpu"))
37
- video_vae.load_state_dict(ckpt_state_dict)
38
-
39
- statistics_local_path = (
40
- pretrained_model_name_or_path / "per_channel_statistics.json"
41
- )
42
- if statistics_local_path.exists():
43
- with open(statistics_local_path, "r") as file:
44
- data = json.load(file)
45
- transposed_data = list(zip(*data["data"]))
46
- data_dict = {
47
- col: torch.tensor(vals)
48
- for col, vals in zip(data["columns"], transposed_data)
49
- }
50
- video_vae.register_buffer("std_of_means", data_dict["std-of-means"])
51
- video_vae.register_buffer(
52
- "mean_of_means",
53
- data_dict.get(
54
- "mean-of-means", torch.zeros_like(data_dict["std-of-means"])
55
- ),
56
- )
57
-
58
- return video_vae
59
-
60
- @staticmethod
61
- def from_config(config):
62
- assert (
63
- config["_class_name"] == "VideoAutoencoder"
64
- ), "config must have _class_name=VideoAutoencoder"
65
- if isinstance(config["dims"], list):
66
- config["dims"] = tuple(config["dims"])
67
-
68
- assert config["dims"] in [2, 3, (2, 1)], "dims must be 2, 3 or (2, 1)"
69
-
70
- double_z = config.get("double_z", True)
71
- latent_log_var = config.get(
72
- "latent_log_var", "per_channel" if double_z else "none"
73
- )
74
- use_quant_conv = config.get("use_quant_conv", True)
75
-
76
- if use_quant_conv and latent_log_var == "uniform":
77
- raise ValueError("uniform latent_log_var requires use_quant_conv=False")
78
-
79
- encoder = Encoder(
80
- dims=config["dims"],
81
- in_channels=config.get("in_channels", 3),
82
- out_channels=config["latent_channels"],
83
- block_out_channels=config["block_out_channels"],
84
- patch_size=config.get("patch_size", 1),
85
- latent_log_var=latent_log_var,
86
- norm_layer=config.get("norm_layer", "group_norm"),
87
- patch_size_t=config.get("patch_size_t", config.get("patch_size", 1)),
88
- add_channel_padding=config.get("add_channel_padding", False),
89
- )
90
-
91
- decoder = Decoder(
92
- dims=config["dims"],
93
- in_channels=config["latent_channels"],
94
- out_channels=config.get("out_channels", 3),
95
- block_out_channels=config["block_out_channels"],
96
- patch_size=config.get("patch_size", 1),
97
- norm_layer=config.get("norm_layer", "group_norm"),
98
- patch_size_t=config.get("patch_size_t", config.get("patch_size", 1)),
99
- add_channel_padding=config.get("add_channel_padding", False),
100
- )
101
-
102
- dims = config["dims"]
103
- return VideoAutoencoder(
104
- encoder=encoder,
105
- decoder=decoder,
106
- latent_channels=config["latent_channels"],
107
- dims=dims,
108
- use_quant_conv=use_quant_conv,
109
- )
110
-
111
- @property
112
- def config(self):
113
- return SimpleNamespace(
114
- _class_name="VideoAutoencoder",
115
- dims=self.dims,
116
- in_channels=self.encoder.conv_in.in_channels
117
- // (self.encoder.patch_size_t * self.encoder.patch_size**2),
118
- out_channels=self.decoder.conv_out.out_channels
119
- // (self.decoder.patch_size_t * self.decoder.patch_size**2),
120
- latent_channels=self.decoder.conv_in.in_channels,
121
- block_out_channels=[
122
- self.encoder.down_blocks[i].res_blocks[-1].conv1.out_channels
123
- for i in range(len(self.encoder.down_blocks))
124
- ],
125
- scaling_factor=1.0,
126
- norm_layer=self.encoder.norm_layer,
127
- patch_size=self.encoder.patch_size,
128
- latent_log_var=self.encoder.latent_log_var,
129
- use_quant_conv=self.use_quant_conv,
130
- patch_size_t=self.encoder.patch_size_t,
131
- add_channel_padding=self.encoder.add_channel_padding,
132
- )
133
-
134
- @property
135
- def is_video_supported(self):
136
- """
137
- Check if the model supports video inputs of shape (B, C, F, H, W). Otherwise, the model only supports 2D images.
138
- """
139
- return self.dims != 2
140
-
141
- @property
142
- def downscale_factor(self):
143
- return self.encoder.downsample_factor
144
-
145
- def to_json_string(self) -> str:
146
- import json
147
-
148
- return json.dumps(self.config.__dict__)
149
-
150
- def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
151
- model_keys = set(name for name, _ in self.named_parameters())
152
-
153
- key_mapping = {
154
- ".resnets.": ".res_blocks.",
155
- "downsamplers.0": "downsample",
156
- "upsamplers.0": "upsample",
157
- }
158
-
159
- converted_state_dict = {}
160
- for key, value in state_dict.items():
161
- for k, v in key_mapping.items():
162
- key = key.replace(k, v)
163
-
164
- if "norm" in key and key not in model_keys:
165
- logger.info(
166
- f"Removing key {key} from state_dict as it is not present in the model"
167
- )
168
- continue
169
-
170
- converted_state_dict[key] = value
171
-
172
- super().load_state_dict(converted_state_dict, strict=strict)
173
-
174
- def last_layer(self):
175
- if hasattr(self.decoder, "conv_out"):
176
- if isinstance(self.decoder.conv_out, nn.Sequential):
177
- last_layer = self.decoder.conv_out[-1]
178
- else:
179
- last_layer = self.decoder.conv_out
180
- else:
181
- last_layer = self.decoder.layers[-1]
182
- return last_layer
183
-
184
-
185
- class Encoder(nn.Module):
186
- r"""
187
- The `Encoder` layer of a variational autoencoder that encodes its input into a latent representation.
188
-
189
- Args:
190
- in_channels (`int`, *optional*, defaults to 3):
191
- The number of input channels.
192
- out_channels (`int`, *optional*, defaults to 3):
193
- The number of output channels.
194
- block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
195
- The number of output channels for each block.
196
- layers_per_block (`int`, *optional*, defaults to 2):
197
- The number of layers per block.
198
- norm_num_groups (`int`, *optional*, defaults to 32):
199
- The number of groups for normalization.
200
- patch_size (`int`, *optional*, defaults to 1):
201
- The patch size to use. Should be a power of 2.
202
- norm_layer (`str`, *optional*, defaults to `group_norm`):
203
- The normalization layer to use. Can be either `group_norm` or `pixel_norm`.
204
- latent_log_var (`str`, *optional*, defaults to `per_channel`):
205
- The number of channels for the log variance. Can be either `per_channel`, `uniform`, or `none`.
206
- """
207
-
208
- def __init__(
209
- self,
210
- dims: Union[int, Tuple[int, int]] = 3,
211
- in_channels: int = 3,
212
- out_channels: int = 3,
213
- block_out_channels: Tuple[int, ...] = (64,),
214
- layers_per_block: int = 2,
215
- norm_num_groups: int = 32,
216
- patch_size: Union[int, Tuple[int]] = 1,
217
- norm_layer: str = "group_norm", # group_norm, pixel_norm
218
- latent_log_var: str = "per_channel",
219
- patch_size_t: Optional[int] = None,
220
- add_channel_padding: Optional[bool] = False,
221
- ):
222
- super().__init__()
223
- self.patch_size = patch_size
224
- self.patch_size_t = patch_size_t if patch_size_t is not None else patch_size
225
- self.add_channel_padding = add_channel_padding
226
- self.layers_per_block = layers_per_block
227
- self.norm_layer = norm_layer
228
- self.latent_channels = out_channels
229
- self.latent_log_var = latent_log_var
230
- if add_channel_padding:
231
- in_channels = in_channels * self.patch_size**3
232
- else:
233
- in_channels = in_channels * self.patch_size_t * self.patch_size**2
234
- self.in_channels = in_channels
235
- output_channel = block_out_channels[0]
236
-
237
- self.conv_in = make_conv_nd(
238
- dims=dims,
239
- in_channels=in_channels,
240
- out_channels=output_channel,
241
- kernel_size=3,
242
- stride=1,
243
- padding=1,
244
- )
245
-
246
- self.down_blocks = nn.ModuleList([])
247
-
248
- for i in range(len(block_out_channels)):
249
- input_channel = output_channel
250
- output_channel = block_out_channels[i]
251
- is_final_block = i == len(block_out_channels) - 1
252
-
253
- down_block = DownEncoderBlock3D(
254
- dims=dims,
255
- in_channels=input_channel,
256
- out_channels=output_channel,
257
- num_layers=self.layers_per_block,
258
- add_downsample=not is_final_block and 2**i >= patch_size,
259
- resnet_eps=1e-6,
260
- downsample_padding=0,
261
- resnet_groups=norm_num_groups,
262
- norm_layer=norm_layer,
263
- )
264
- self.down_blocks.append(down_block)
265
-
266
- self.mid_block = UNetMidBlock3D(
267
- dims=dims,
268
- in_channels=block_out_channels[-1],
269
- num_layers=self.layers_per_block,
270
- resnet_eps=1e-6,
271
- resnet_groups=norm_num_groups,
272
- norm_layer=norm_layer,
273
- )
274
-
275
- # out
276
- if norm_layer == "group_norm":
277
- self.conv_norm_out = nn.GroupNorm(
278
- num_channels=block_out_channels[-1],
279
- num_groups=norm_num_groups,
280
- eps=1e-6,
281
- )
282
- elif norm_layer == "pixel_norm":
283
- self.conv_norm_out = PixelNorm()
284
- self.conv_act = nn.SiLU()
285
-
286
- conv_out_channels = out_channels
287
- if latent_log_var == "per_channel":
288
- conv_out_channels *= 2
289
- elif latent_log_var == "uniform":
290
- conv_out_channels += 1
291
- elif latent_log_var != "none":
292
- raise ValueError(f"Invalid latent_log_var: {latent_log_var}")
293
- self.conv_out = make_conv_nd(
294
- dims, block_out_channels[-1], conv_out_channels, 3, padding=1
295
- )
296
-
297
- self.gradient_checkpointing = False
298
-
299
- @property
300
- def downscale_factor(self):
301
- return (
302
- 2
303
- ** len(
304
- [
305
- block
306
- for block in self.down_blocks
307
- if isinstance(block.downsample, Downsample3D)
308
- ]
309
- )
310
- * self.patch_size
311
- )
312
-
313
- def forward(
314
- self, sample: torch.FloatTensor, return_features=False
315
- ) -> torch.FloatTensor:
316
- r"""The forward method of the `Encoder` class."""
317
-
318
- downsample_in_time = sample.shape[2] != 1
319
-
320
- # patchify
321
- patch_size_t = self.patch_size_t if downsample_in_time else 1
322
- sample = patchify(
323
- sample,
324
- patch_size_hw=self.patch_size,
325
- patch_size_t=patch_size_t,
326
- add_channel_padding=self.add_channel_padding,
327
- )
328
-
329
- sample = self.conv_in(sample)
330
-
331
- checkpoint_fn = (
332
- partial(torch.utils.checkpoint.checkpoint, use_reentrant=False)
333
- if self.gradient_checkpointing and self.training
334
- else lambda x: x
335
- )
336
-
337
- if return_features:
338
- features = []
339
- for down_block in self.down_blocks:
340
- sample = checkpoint_fn(down_block)(
341
- sample, downsample_in_time=downsample_in_time
342
- )
343
- if return_features:
344
- features.append(sample)
345
-
346
- sample = checkpoint_fn(self.mid_block)(sample)
347
-
348
- # post-process
349
- sample = self.conv_norm_out(sample)
350
- sample = self.conv_act(sample)
351
- sample = self.conv_out(sample)
352
-
353
- if self.latent_log_var == "uniform":
354
- last_channel = sample[:, -1:, ...]
355
- num_dims = sample.dim()
356
-
357
- if num_dims == 4:
358
- # For shape (B, C, H, W)
359
- repeated_last_channel = last_channel.repeat(
360
- 1, sample.shape[1] - 2, 1, 1
361
- )
362
- sample = torch.cat([sample, repeated_last_channel], dim=1)
363
- elif num_dims == 5:
364
- # For shape (B, C, F, H, W)
365
- repeated_last_channel = last_channel.repeat(
366
- 1, sample.shape[1] - 2, 1, 1, 1
367
- )
368
- sample = torch.cat([sample, repeated_last_channel], dim=1)
369
- else:
370
- raise ValueError(f"Invalid input shape: {sample.shape}")
371
-
372
- if return_features:
373
- features.append(sample[:, : self.latent_channels, ...])
374
- return sample, features
375
- return sample
376
-
377
-
378
- class Decoder(nn.Module):
379
- r"""
380
- The `Decoder` layer of a variational autoencoder that decodes its latent representation into an output sample.
381
-
382
- Args:
383
- in_channels (`int`, *optional*, defaults to 3):
384
- The number of input channels.
385
- out_channels (`int`, *optional*, defaults to 3):
386
- The number of output channels.
387
- block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
388
- The number of output channels for each block.
389
- layers_per_block (`int`, *optional*, defaults to 2):
390
- The number of layers per block.
391
- norm_num_groups (`int`, *optional*, defaults to 32):
392
- The number of groups for normalization.
393
- patch_size (`int`, *optional*, defaults to 1):
394
- The patch size to use. Should be a power of 2.
395
- norm_layer (`str`, *optional*, defaults to `group_norm`):
396
- The normalization layer to use. Can be either `group_norm` or `pixel_norm`.
397
- """
398
-
399
- def __init__(
400
- self,
401
- dims,
402
- in_channels: int = 3,
403
- out_channels: int = 3,
404
- block_out_channels: Tuple[int, ...] = (64,),
405
- layers_per_block: int = 2,
406
- norm_num_groups: int = 32,
407
- patch_size: int = 1,
408
- norm_layer: str = "group_norm",
409
- patch_size_t: Optional[int] = None,
410
- add_channel_padding: Optional[bool] = False,
411
- ):
412
- super().__init__()
413
- self.patch_size = patch_size
414
- self.patch_size_t = patch_size_t if patch_size_t is not None else patch_size
415
- self.add_channel_padding = add_channel_padding
416
- self.layers_per_block = layers_per_block
417
- if add_channel_padding:
418
- out_channels = out_channels * self.patch_size**3
419
- else:
420
- out_channels = out_channels * self.patch_size_t * self.patch_size**2
421
- self.out_channels = out_channels
422
-
423
- self.conv_in = make_conv_nd(
424
- dims,
425
- in_channels,
426
- block_out_channels[-1],
427
- kernel_size=3,
428
- stride=1,
429
- padding=1,
430
- )
431
-
432
- self.mid_block = None
433
- self.up_blocks = nn.ModuleList([])
434
-
435
- self.mid_block = UNetMidBlock3D(
436
- dims=dims,
437
- in_channels=block_out_channels[-1],
438
- num_layers=self.layers_per_block,
439
- resnet_eps=1e-6,
440
- resnet_groups=norm_num_groups,
441
- norm_layer=norm_layer,
442
- )
443
-
444
- reversed_block_out_channels = list(reversed(block_out_channels))
445
- output_channel = reversed_block_out_channels[0]
446
- for i in range(len(reversed_block_out_channels)):
447
- prev_output_channel = output_channel
448
- output_channel = reversed_block_out_channels[i]
449
-
450
- is_final_block = i == len(block_out_channels) - 1
451
-
452
- up_block = UpDecoderBlock3D(
453
- dims=dims,
454
- num_layers=self.layers_per_block + 1,
455
- in_channels=prev_output_channel,
456
- out_channels=output_channel,
457
- add_upsample=not is_final_block
458
- and 2 ** (len(block_out_channels) - i - 1) > patch_size,
459
- resnet_eps=1e-6,
460
- resnet_groups=norm_num_groups,
461
- norm_layer=norm_layer,
462
- )
463
- self.up_blocks.append(up_block)
464
-
465
- if norm_layer == "group_norm":
466
- self.conv_norm_out = nn.GroupNorm(
467
- num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6
468
- )
469
- elif norm_layer == "pixel_norm":
470
- self.conv_norm_out = PixelNorm()
471
-
472
- self.conv_act = nn.SiLU()
473
- self.conv_out = make_conv_nd(
474
- dims, block_out_channels[0], out_channels, 3, padding=1
475
- )
476
-
477
- self.gradient_checkpointing = False
478
-
479
- def forward(self, sample: torch.FloatTensor, target_shape) -> torch.FloatTensor:
480
- r"""The forward method of the `Decoder` class."""
481
- assert target_shape is not None, "target_shape must be provided"
482
- upsample_in_time = sample.shape[2] < target_shape[2]
483
-
484
- sample = self.conv_in(sample)
485
-
486
- upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
487
-
488
- checkpoint_fn = (
489
- partial(torch.utils.checkpoint.checkpoint, use_reentrant=False)
490
- if self.gradient_checkpointing and self.training
491
- else lambda x: x
492
- )
493
-
494
- sample = checkpoint_fn(self.mid_block)(sample)
495
- sample = sample.to(upscale_dtype)
496
-
497
- for up_block in self.up_blocks:
498
- sample = checkpoint_fn(up_block)(sample, upsample_in_time=upsample_in_time)
499
-
500
- # post-process
501
- sample = self.conv_norm_out(sample)
502
- sample = self.conv_act(sample)
503
- sample = self.conv_out(sample)
504
-
505
- # un-patchify
506
- patch_size_t = self.patch_size_t if upsample_in_time else 1
507
- sample = unpatchify(
508
- sample,
509
- patch_size_hw=self.patch_size,
510
- patch_size_t=patch_size_t,
511
- add_channel_padding=self.add_channel_padding,
512
- )
513
-
514
- return sample
515
-
516
-
517
- class DownEncoderBlock3D(nn.Module):
518
- def __init__(
519
- self,
520
- dims: Union[int, Tuple[int, int]],
521
- in_channels: int,
522
- out_channels: int,
523
- dropout: float = 0.0,
524
- num_layers: int = 1,
525
- resnet_eps: float = 1e-6,
526
- resnet_groups: int = 32,
527
- add_downsample: bool = True,
528
- downsample_padding: int = 1,
529
- norm_layer: str = "group_norm",
530
- ):
531
- super().__init__()
532
- res_blocks = []
533
-
534
- for i in range(num_layers):
535
- in_channels = in_channels if i == 0 else out_channels
536
- res_blocks.append(
537
- ResnetBlock3D(
538
- dims=dims,
539
- in_channels=in_channels,
540
- out_channels=out_channels,
541
- eps=resnet_eps,
542
- groups=resnet_groups,
543
- dropout=dropout,
544
- norm_layer=norm_layer,
545
- )
546
- )
547
-
548
- self.res_blocks = nn.ModuleList(res_blocks)
549
-
550
- if add_downsample:
551
- self.downsample = Downsample3D(
552
- dims,
553
- out_channels,
554
- out_channels=out_channels,
555
- padding=downsample_padding,
556
- )
557
- else:
558
- self.downsample = Identity()
559
-
560
- def forward(
561
- self, hidden_states: torch.FloatTensor, downsample_in_time
562
- ) -> torch.FloatTensor:
563
- for resnet in self.res_blocks:
564
- hidden_states = resnet(hidden_states)
565
-
566
- hidden_states = self.downsample(
567
- hidden_states, downsample_in_time=downsample_in_time
568
- )
569
-
570
- return hidden_states
571
-
572
-
573
- class UNetMidBlock3D(nn.Module):
574
- """
575
- A 3D UNet mid-block [`UNetMidBlock3D`] with multiple residual blocks.
576
-
577
- Args:
578
- in_channels (`int`): The number of input channels.
579
- dropout (`float`, *optional*, defaults to 0.0): The dropout rate.
580
- num_layers (`int`, *optional*, defaults to 1): The number of residual blocks.
581
- resnet_eps (`float`, *optional*, defaults to 1e-6): The epsilon value for the resnet blocks.
582
- resnet_groups (`int`, *optional*, defaults to 32):
583
- The number of groups to use in the group normalization layers of the resnet blocks.
584
-
585
- Returns:
586
- `torch.FloatTensor`: The output of the last residual block, which is a tensor of shape `(batch_size,
587
- in_channels, height, width)`.
588
-
589
- """
590
-
591
- def __init__(
592
- self,
593
- dims: Union[int, Tuple[int, int]],
594
- in_channels: int,
595
- dropout: float = 0.0,
596
- num_layers: int = 1,
597
- resnet_eps: float = 1e-6,
598
- resnet_groups: int = 32,
599
- norm_layer: str = "group_norm",
600
- ):
601
- super().__init__()
602
- resnet_groups = (
603
- resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
604
- )
605
-
606
- self.res_blocks = nn.ModuleList(
607
- [
608
- ResnetBlock3D(
609
- dims=dims,
610
- in_channels=in_channels,
611
- out_channels=in_channels,
612
- eps=resnet_eps,
613
- groups=resnet_groups,
614
- dropout=dropout,
615
- norm_layer=norm_layer,
616
- )
617
- for _ in range(num_layers)
618
- ]
619
- )
620
-
621
- def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
622
- for resnet in self.res_blocks:
623
- hidden_states = resnet(hidden_states)
624
-
625
- return hidden_states
626
-
627
-
628
- class UpDecoderBlock3D(nn.Module):
629
- def __init__(
630
- self,
631
- dims: Union[int, Tuple[int, int]],
632
- in_channels: int,
633
- out_channels: int,
634
- resolution_idx: Optional[int] = None,
635
- dropout: float = 0.0,
636
- num_layers: int = 1,
637
- resnet_eps: float = 1e-6,
638
- resnet_groups: int = 32,
639
- add_upsample: bool = True,
640
- norm_layer: str = "group_norm",
641
- ):
642
- super().__init__()
643
- res_blocks = []
644
-
645
- for i in range(num_layers):
646
- input_channels = in_channels if i == 0 else out_channels
647
-
648
- res_blocks.append(
649
- ResnetBlock3D(
650
- dims=dims,
651
- in_channels=input_channels,
652
- out_channels=out_channels,
653
- eps=resnet_eps,
654
- groups=resnet_groups,
655
- dropout=dropout,
656
- norm_layer=norm_layer,
657
- )
658
- )
659
-
660
- self.res_blocks = nn.ModuleList(res_blocks)
661
-
662
- if add_upsample:
663
- self.upsample = Upsample3D(
664
- dims=dims, channels=out_channels, out_channels=out_channels
665
- )
666
- else:
667
- self.upsample = Identity()
668
-
669
- self.resolution_idx = resolution_idx
670
-
671
- def forward(
672
- self, hidden_states: torch.FloatTensor, upsample_in_time=True
673
- ) -> torch.FloatTensor:
674
- for resnet in self.res_blocks:
675
- hidden_states = resnet(hidden_states)
676
-
677
- hidden_states = self.upsample(hidden_states, upsample_in_time=upsample_in_time)
678
-
679
- return hidden_states
680
-
681
-
682
- class ResnetBlock3D(nn.Module):
683
- r"""
684
- A Resnet block.
685
-
686
- Parameters:
687
- in_channels (`int`): The number of channels in the input.
688
- out_channels (`int`, *optional*, default to be `None`):
689
- The number of output channels for the first conv layer. If None, same as `in_channels`.
690
- dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use.
691
- groups (`int`, *optional*, default to `32`): The number of groups to use for the first normalization layer.
692
- eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization.
693
- """
694
-
695
- def __init__(
696
- self,
697
- dims: Union[int, Tuple[int, int]],
698
- in_channels: int,
699
- out_channels: Optional[int] = None,
700
- conv_shortcut: bool = False,
701
- dropout: float = 0.0,
702
- groups: int = 32,
703
- eps: float = 1e-6,
704
- norm_layer: str = "group_norm",
705
- ):
706
- super().__init__()
707
- self.in_channels = in_channels
708
- out_channels = in_channels if out_channels is None else out_channels
709
- self.out_channels = out_channels
710
- self.use_conv_shortcut = conv_shortcut
711
-
712
- if norm_layer == "group_norm":
713
- self.norm1 = torch.nn.GroupNorm(
714
- num_groups=groups, num_channels=in_channels, eps=eps, affine=True
715
- )
716
- elif norm_layer == "pixel_norm":
717
- self.norm1 = PixelNorm()
718
-
719
- self.non_linearity = nn.SiLU()
720
-
721
- self.conv1 = make_conv_nd(
722
- dims, in_channels, out_channels, kernel_size=3, stride=1, padding=1
723
- )
724
-
725
- if norm_layer == "group_norm":
726
- self.norm2 = torch.nn.GroupNorm(
727
- num_groups=groups, num_channels=out_channels, eps=eps, affine=True
728
- )
729
- elif norm_layer == "pixel_norm":
730
- self.norm2 = PixelNorm()
731
-
732
- self.dropout = torch.nn.Dropout(dropout)
733
-
734
- self.conv2 = make_conv_nd(
735
- dims, out_channels, out_channels, kernel_size=3, stride=1, padding=1
736
- )
737
-
738
- self.conv_shortcut = (
739
- make_linear_nd(
740
- dims=dims, in_channels=in_channels, out_channels=out_channels
741
- )
742
- if in_channels != out_channels
743
- else nn.Identity()
744
- )
745
-
746
- def forward(
747
- self,
748
- input_tensor: torch.FloatTensor,
749
- ) -> torch.FloatTensor:
750
- hidden_states = input_tensor
751
-
752
- hidden_states = self.norm1(hidden_states)
753
-
754
- hidden_states = self.non_linearity(hidden_states)
755
-
756
- hidden_states = self.conv1(hidden_states)
757
-
758
- hidden_states = self.norm2(hidden_states)
759
-
760
- hidden_states = self.non_linearity(hidden_states)
761
-
762
- hidden_states = self.dropout(hidden_states)
763
-
764
- hidden_states = self.conv2(hidden_states)
765
-
766
- input_tensor = self.conv_shortcut(input_tensor)
767
-
768
- output_tensor = input_tensor + hidden_states
769
-
770
- return output_tensor
771
-
772
-
773
- class Downsample3D(nn.Module):
774
- def __init__(
775
- self,
776
- dims,
777
- in_channels: int,
778
- out_channels: int,
779
- kernel_size: int = 3,
780
- padding: int = 1,
781
- ):
782
- super().__init__()
783
- stride: int = 2
784
- self.padding = padding
785
- self.in_channels = in_channels
786
- self.dims = dims
787
- self.conv = make_conv_nd(
788
- dims=dims,
789
- in_channels=in_channels,
790
- out_channels=out_channels,
791
- kernel_size=kernel_size,
792
- stride=stride,
793
- padding=padding,
794
- )
795
-
796
- def forward(self, x, downsample_in_time=True):
797
- conv = self.conv
798
- if self.padding == 0:
799
- if self.dims == 2:
800
- padding = (0, 1, 0, 1)
801
- else:
802
- padding = (0, 1, 0, 1, 0, 1 if downsample_in_time else 0)
803
-
804
- x = functional.pad(x, padding, mode="constant", value=0)
805
-
806
- if self.dims == (2, 1) and not downsample_in_time:
807
- return conv(x, skip_time_conv=True)
808
-
809
- return conv(x)
810
-
811
-
812
- class Upsample3D(nn.Module):
813
- """
814
- An upsampling layer for 3D tensors of shape (B, C, D, H, W).
815
-
816
- :param channels: channels in the inputs and outputs.
817
- """
818
-
819
- def __init__(self, dims, channels, out_channels=None):
820
- super().__init__()
821
- self.dims = dims
822
- self.channels = channels
823
- self.out_channels = out_channels or channels
824
- self.conv = make_conv_nd(
825
- dims, channels, out_channels, kernel_size=3, padding=1, bias=True
826
- )
827
-
828
- def forward(self, x, upsample_in_time):
829
- if self.dims == 2:
830
- x = functional.interpolate(
831
- x, (x.shape[2] * 2, x.shape[3] * 2), mode="nearest"
832
- )
833
- else:
834
- time_scale_factor = 2 if upsample_in_time else 1
835
- # print("before:", x.shape)
836
- b, c, d, h, w = x.shape
837
- x = rearrange(x, "b c d h w -> (b d) c h w")
838
- # height and width interpolate
839
- x = functional.interpolate(
840
- x, (x.shape[2] * 2, x.shape[3] * 2), mode="nearest"
841
- )
842
- _, _, h, w = x.shape
843
-
844
- if not upsample_in_time and self.dims == (2, 1):
845
- x = rearrange(x, "(b d) c h w -> b c d h w ", b=b, h=h, w=w)
846
- return self.conv(x, skip_time_conv=True)
847
-
848
- # Second ** upsampling ** which is essentially treated as a 1D convolution across the 'd' dimension
849
- x = rearrange(x, "(b d) c h w -> (b h w) c 1 d", b=b)
850
-
851
- # (b h w) c 1 d
852
- new_d = x.shape[-1] * time_scale_factor
853
- x = functional.interpolate(x, (1, new_d), mode="nearest")
854
- # (b h w) c 1 new_d
855
- x = rearrange(
856
- x, "(b h w) c 1 new_d -> b c new_d h w", b=b, h=h, w=w, new_d=new_d
857
- )
858
- # b c d h w
859
-
860
- # x = functional.interpolate(
861
- # x, (x.shape[2] * time_scale_factor, x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
862
- # )
863
- # print("after:", x.shape)
864
-
865
- return self.conv(x)
866
-
867
-
868
- def patchify(x, patch_size_hw, patch_size_t=1, add_channel_padding=False):
869
- if patch_size_hw == 1 and patch_size_t == 1:
870
- return x
871
- if x.dim() == 4:
872
- x = rearrange(
873
- x, "b c (h q) (w r) -> b (c r q) h w", q=patch_size_hw, r=patch_size_hw
874
- )
875
- elif x.dim() == 5:
876
- x = rearrange(
877
- x,
878
- "b c (f p) (h q) (w r) -> b (c p r q) f h w",
879
- p=patch_size_t,
880
- q=patch_size_hw,
881
- r=patch_size_hw,
882
- )
883
- else:
884
- raise ValueError(f"Invalid input shape: {x.shape}")
885
-
886
- if (
887
- (x.dim() == 5)
888
- and (patch_size_hw > patch_size_t)
889
- and (patch_size_t > 1 or add_channel_padding)
890
- ):
891
- channels_to_pad = x.shape[1] * (patch_size_hw // patch_size_t) - x.shape[1]
892
- padding_zeros = torch.zeros(
893
- x.shape[0],
894
- channels_to_pad,
895
- x.shape[2],
896
- x.shape[3],
897
- x.shape[4],
898
- device=x.device,
899
- dtype=x.dtype,
900
- )
901
- x = torch.cat([padding_zeros, x], dim=1)
902
-
903
- return x
904
-
905
-
906
- def unpatchify(x, patch_size_hw, patch_size_t=1, add_channel_padding=False):
907
- if patch_size_hw == 1 and patch_size_t == 1:
908
- return x
909
-
910
- if (
911
- (x.dim() == 5)
912
- and (patch_size_hw > patch_size_t)
913
- and (patch_size_t > 1 or add_channel_padding)
914
- ):
915
- channels_to_keep = int(x.shape[1] * (patch_size_t / patch_size_hw))
916
- x = x[:, :channels_to_keep, :, :, :]
917
-
918
- if x.dim() == 4:
919
- x = rearrange(
920
- x, "b (c r q) h w -> b c (h q) (w r)", q=patch_size_hw, r=patch_size_hw
921
- )
922
- elif x.dim() == 5:
923
- x = rearrange(
924
- x,
925
- "b (c p r q) f h w -> b c (f p) (h q) (w r)",
926
- p=patch_size_t,
927
- q=patch_size_hw,
928
- r=patch_size_hw,
929
- )
930
-
931
- return x
932
-
933
-
934
- def create_video_autoencoder_config(
935
- latent_channels: int = 4,
936
- ):
937
- config = {
938
- "_class_name": "VideoAutoencoder",
939
- "dims": (
940
- 2,
941
- 1,
942
- ), # 2 for Conv2d, 3 for Conv3d, (2, 1) for Conv2d followed by Conv1d
943
- "in_channels": 3, # Number of input color channels (e.g., RGB)
944
- "out_channels": 3, # Number of output color channels
945
- "latent_channels": latent_channels, # Number of channels in the latent space representation
946
- "block_out_channels": [
947
- 128,
948
- 256,
949
- 512,
950
- 512,
951
- ], # Number of output channels of each encoder / decoder inner block
952
- "patch_size": 1,
953
- }
954
-
955
- return config
956
-
957
-
958
- def create_video_autoencoder_pathify4x4x4_config(
959
- latent_channels: int = 4,
960
- ):
961
- config = {
962
- "_class_name": "VideoAutoencoder",
963
- "dims": (
964
- 2,
965
- 1,
966
- ), # 2 for Conv2d, 3 for Conv3d, (2, 1) for Conv2d followed by Conv1d
967
- "in_channels": 3, # Number of input color channels (e.g., RGB)
968
- "out_channels": 3, # Number of output color channels
969
- "latent_channels": latent_channels, # Number of channels in the latent space representation
970
- "block_out_channels": [512]
971
- * 4, # Number of output channels of each encoder / decoder inner block
972
- "patch_size": 4,
973
- "latent_log_var": "uniform",
974
- }
975
-
976
- return config
977
-
978
-
979
- def create_video_autoencoder_pathify4x4_config(
980
- latent_channels: int = 4,
981
- ):
982
- config = {
983
- "_class_name": "VideoAutoencoder",
984
- "dims": 2, # 2 for Conv2d, 3 for Conv3d, (2, 1) for Conv2d followed by Conv1d
985
- "in_channels": 3, # Number of input color channels (e.g., RGB)
986
- "out_channels": 3, # Number of output color channels
987
- "latent_channels": latent_channels, # Number of channels in the latent space representation
988
- "block_out_channels": [512]
989
- * 4, # Number of output channels of each encoder / decoder inner block
990
- "patch_size": 4,
991
- "norm_layer": "pixel_norm",
992
- }
993
-
994
- return config
995
-
996
-
997
- def test_vae_patchify_unpatchify():
998
- import torch
999
-
1000
- x = torch.randn(2, 3, 8, 64, 64)
1001
- x_patched = patchify(x, patch_size_hw=4, patch_size_t=4)
1002
- x_unpatched = unpatchify(x_patched, patch_size_hw=4, patch_size_t=4)
1003
- assert torch.allclose(x, x_unpatched)
1004
-
1005
-
1006
- def demo_video_autoencoder_forward_backward():
1007
- # Configuration for the VideoAutoencoder
1008
- config = create_video_autoencoder_pathify4x4x4_config()
1009
-
1010
- # Instantiate the VideoAutoencoder with the specified configuration
1011
- video_autoencoder = VideoAutoencoder.from_config(config)
1012
-
1013
- print(video_autoencoder)
1014
-
1015
- # Print the total number of parameters in the video autoencoder
1016
- total_params = sum(p.numel() for p in video_autoencoder.parameters())
1017
- print(f"Total number of parameters in VideoAutoencoder: {total_params:,}")
1018
-
1019
- # Create a mock input tensor simulating a batch of videos
1020
- # Shape: (batch_size, channels, depth, height, width)
1021
- # E.g., 2 videos, each with 3 color channels, 8 frames, and 64x64 pixels per frame
1022
- input_videos = torch.randn(2, 3, 8, 64, 64)
1023
-
1024
- # Forward pass: encode and decode the input videos
1025
- latent = video_autoencoder.encode(input_videos).latent_dist.mode()
1026
- print(f"input shape={input_videos.shape}")
1027
- print(f"latent shape={latent.shape}")
1028
- reconstructed_videos = video_autoencoder.decode(
1029
- latent, target_shape=input_videos.shape
1030
- ).sample
1031
-
1032
- print(f"reconstructed shape={reconstructed_videos.shape}")
1033
-
1034
- # Calculate the loss (e.g., mean squared error)
1035
- loss = torch.nn.functional.mse_loss(input_videos, reconstructed_videos)
1036
-
1037
- # Perform backward pass
1038
- loss.backward()
1039
-
1040
- print(f"Demo completed with loss: {loss.item()}")
1041
-
1042
-
1043
- # Ensure to call the demo function to execute the forward and backward pass
1044
- if __name__ == "__main__":
1045
- demo_video_autoencoder_forward_backward()
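The patchify/unpatchify pair above is a lossless space(-time)-to-channel reshape. A quick sketch of the shape bookkeeping for a 4x4x4 patch (assuming the two functions above are in scope):

import torch

x = torch.randn(2, 3, 8, 64, 64)                    # (B, C, F, H, W)
x_p = patchify(x, patch_size_hw=4, patch_size_t=4)  # folds 4*4*4 pixels into channels
print(x_p.shape)                                    # torch.Size([2, 192, 2, 16, 16])

x_back = unpatchify(x_p, patch_size_hw=4, patch_size_t=4)
assert torch.allclose(x, x_back)                    # exact round trip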
xora/models/transformers/__init__.py DELETED
File without changes
xora/models/transformers/attention.py DELETED
@@ -1,1206 +0,0 @@
1
- import inspect
2
- from importlib import import_module
3
- from typing import Any, Dict, Optional, Tuple
4
-
5
- import torch
6
- import torch.nn.functional as F
7
- from diffusers.models.activations import GEGLU, GELU, ApproximateGELU
8
- from diffusers.models.attention import _chunked_feed_forward
9
- from diffusers.models.attention_processor import (
10
- LoRAAttnAddedKVProcessor,
11
- LoRAAttnProcessor,
12
- LoRAAttnProcessor2_0,
13
- LoRAXFormersAttnProcessor,
14
- SpatialNorm,
15
- )
16
- from diffusers.models.lora import LoRACompatibleLinear
17
- from diffusers.models.normalization import RMSNorm
18
- from diffusers.utils import deprecate, logging
19
- from diffusers.utils.torch_utils import maybe_allow_in_graph
20
- from einops import rearrange
21
- from torch import nn
22
-
23
- try:
24
- from torch_xla.experimental.custom_kernel import flash_attention
25
- except ImportError:
26
- # workaround for automatic tests. Currently this function is manually patched
27
- # to the torch_xla lib on setup of container
28
- pass
29
-
30
- # code adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py
31
-
32
- logger = logging.get_logger(__name__)
33
-
34
-
35
- @maybe_allow_in_graph
36
- class BasicTransformerBlock(nn.Module):
37
- r"""
38
- A basic Transformer block.
39
-
40
- Parameters:
41
- dim (`int`): The number of channels in the input and output.
42
- num_attention_heads (`int`): The number of heads to use for multi-head attention.
43
- attention_head_dim (`int`): The number of channels in each head.
44
- dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
45
- cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
46
- activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
47
- num_embeds_ada_norm (:
48
- obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
49
- attention_bias (:
50
- obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
51
- only_cross_attention (`bool`, *optional*):
52
- Whether to use only cross-attention layers. In this case two cross attention layers are used.
53
- double_self_attention (`bool`, *optional*):
54
- Whether to use two self-attention layers. In this case no cross attention layers are used.
55
- upcast_attention (`bool`, *optional*):
56
- Whether to upcast the attention computation to float32. This is useful for mixed precision training.
57
- norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
58
- Whether to use learnable elementwise affine parameters for normalization.
59
- qk_norm (`str`, *optional*, defaults to None):
60
- Set to 'layer_norm' or `rms_norm` to perform query and key normalization.
61
- adaptive_norm (`str`, *optional*, defaults to `"single_scale_shift"`):
62
- The type of adaptive norm to use. Can be `"single_scale_shift"`, `"single_scale"` or "none".
63
- standardization_norm (`str`, *optional*, defaults to `"layer_norm"`):
64
- The type of pre-normalization to use. Can be `"layer_norm"` or `"rms_norm"`.
65
- final_dropout (`bool` *optional*, defaults to False):
66
- Whether to apply a final dropout after the last feed-forward layer.
67
- attention_type (`str`, *optional*, defaults to `"default"`):
68
- The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
69
- positional_embeddings (`str`, *optional*, defaults to `None`):
70
- The type of positional embeddings to apply to.
71
- num_positional_embeddings (`int`, *optional*, defaults to `None`):
72
- The maximum number of positional embeddings to apply.
73
- """
74
-
75
- def __init__(
76
- self,
77
- dim: int,
78
- num_attention_heads: int,
79
- attention_head_dim: int,
80
- dropout=0.0,
81
- cross_attention_dim: Optional[int] = None,
82
- activation_fn: str = "geglu",
83
- num_embeds_ada_norm: Optional[int] = None, # pylint: disable=unused-argument
84
- attention_bias: bool = False,
85
- only_cross_attention: bool = False,
86
- double_self_attention: bool = False,
87
- upcast_attention: bool = False,
88
- norm_elementwise_affine: bool = True,
89
- adaptive_norm: str = "single_scale_shift", # 'single_scale_shift', 'single_scale' or 'none'
90
- standardization_norm: str = "layer_norm", # 'layer_norm' or 'rms_norm'
91
- norm_eps: float = 1e-5,
92
- qk_norm: Optional[str] = None,
93
- final_dropout: bool = False,
94
- attention_type: str = "default", # pylint: disable=unused-argument
95
- ff_inner_dim: Optional[int] = None,
96
- ff_bias: bool = True,
97
- attention_out_bias: bool = True,
98
- use_tpu_flash_attention: bool = False,
99
- use_rope: bool = False,
100
- ):
101
- super().__init__()
102
- self.only_cross_attention = only_cross_attention
103
- self.use_tpu_flash_attention = use_tpu_flash_attention
104
- self.adaptive_norm = adaptive_norm
105
-
106
- assert standardization_norm in ["layer_norm", "rms_norm"]
107
- assert adaptive_norm in ["single_scale_shift", "single_scale", "none"]
108
-
109
- make_norm_layer = (
110
- nn.LayerNorm if standardization_norm == "layer_norm" else RMSNorm
111
- )
112
-
113
- # Define 3 blocks. Each block has its own normalization layer.
114
- # 1. Self-Attn
115
- self.norm1 = make_norm_layer(
116
- dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps
117
- )
118
-
119
- self.attn1 = Attention(
120
- query_dim=dim,
121
- heads=num_attention_heads,
122
- dim_head=attention_head_dim,
123
- dropout=dropout,
124
- bias=attention_bias,
125
- cross_attention_dim=cross_attention_dim if only_cross_attention else None,
126
- upcast_attention=upcast_attention,
127
- out_bias=attention_out_bias,
128
- use_tpu_flash_attention=use_tpu_flash_attention,
129
- qk_norm=qk_norm,
130
- use_rope=use_rope,
131
- )
132
-
133
- # 2. Cross-Attn
134
- if cross_attention_dim is not None or double_self_attention:
135
- self.attn2 = Attention(
136
- query_dim=dim,
137
- cross_attention_dim=(
138
- cross_attention_dim if not double_self_attention else None
139
- ),
140
- heads=num_attention_heads,
141
- dim_head=attention_head_dim,
142
- dropout=dropout,
143
- bias=attention_bias,
144
- upcast_attention=upcast_attention,
145
- out_bias=attention_out_bias,
146
- use_tpu_flash_attention=use_tpu_flash_attention,
147
- qk_norm=qk_norm,
148
- use_rope=use_rope,
149
- ) # is self-attn if encoder_hidden_states is none
150
-
151
- if adaptive_norm == "none":
152
- self.attn2_norm = make_norm_layer(
153
- dim, norm_eps, norm_elementwise_affine
154
- )
155
- else:
156
- self.attn2 = None
157
- self.attn2_norm = None
158
-
159
- self.norm2 = make_norm_layer(dim, norm_eps, norm_elementwise_affine)
160
-
161
- # 3. Feed-forward
162
- self.ff = FeedForward(
163
- dim,
164
- dropout=dropout,
165
- activation_fn=activation_fn,
166
- final_dropout=final_dropout,
167
- inner_dim=ff_inner_dim,
168
- bias=ff_bias,
169
- )
170
-
171
- # 5. Scale-shift for PixArt-Alpha.
172
- if adaptive_norm != "none":
173
- num_ada_params = 4 if adaptive_norm == "single_scale" else 6
174
- self.scale_shift_table = nn.Parameter(
175
- torch.randn(num_ada_params, dim) / dim**0.5
176
- )
177
-
178
- # let chunk size default to None
179
- self._chunk_size = None
180
- self._chunk_dim = 0
181
-
182
- def set_use_tpu_flash_attention(self):
183
- r"""
184
- Sets the flag on this object and propagates it down to the children. The flag enforces the use of the TPU
185
- attention kernel.
186
- """
187
- self.use_tpu_flash_attention = True
188
- self.attn1.set_use_tpu_flash_attention()
189
- self.attn2.set_use_tpu_flash_attention()
190
-
191
- def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
192
- # Sets chunk feed-forward
193
- self._chunk_size = chunk_size
194
- self._chunk_dim = dim
195
-
196
- def forward(
197
- self,
198
- hidden_states: torch.FloatTensor,
199
- freqs_cis: Optional[Tuple[torch.FloatTensor, torch.FloatTensor]] = None,
200
- attention_mask: Optional[torch.FloatTensor] = None,
201
- encoder_hidden_states: Optional[torch.FloatTensor] = None,
202
- encoder_attention_mask: Optional[torch.FloatTensor] = None,
203
- timestep: Optional[torch.LongTensor] = None,
204
- cross_attention_kwargs: Dict[str, Any] = None,
205
- class_labels: Optional[torch.LongTensor] = None,
206
- added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
207
- ) -> torch.FloatTensor:
208
- if cross_attention_kwargs is not None:
209
- if cross_attention_kwargs.get("scale", None) is not None:
210
- logger.warning(
211
- "Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored."
212
- )
213
-
214
- # Notice that normalization is always applied before the real computation in the following blocks.
215
- # 0. Self-Attention
216
- batch_size = hidden_states.shape[0]
217
-
218
- norm_hidden_states = self.norm1(hidden_states)
219
-
220
- # Apply ada_norm_single
221
- if self.adaptive_norm in ["single_scale_shift", "single_scale"]:
222
- assert timestep.ndim == 3 # [batch, 1 or num_tokens, embedding_dim]
223
- num_ada_params = self.scale_shift_table.shape[0]
224
- ada_values = self.scale_shift_table[None, None] + timestep.reshape(
225
- batch_size, timestep.shape[1], num_ada_params, -1
226
- )
227
- if self.adaptive_norm == "single_scale_shift":
228
- shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
229
- ada_values.unbind(dim=2)
230
- )
231
- norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
232
- else:
233
- scale_msa, gate_msa, scale_mlp, gate_mlp = ada_values.unbind(dim=2)
234
- norm_hidden_states = norm_hidden_states * (1 + scale_msa)
235
- elif self.adaptive_norm == "none":
236
- scale_msa, gate_msa, scale_mlp, gate_mlp = None, None, None, None
237
- else:
238
- raise ValueError(f"Unknown adaptive norm type: {self.adaptive_norm}")
239
-
240
- norm_hidden_states = norm_hidden_states.squeeze(
241
- 1
242
- ) # TODO: Check if this is needed
243
-
244
- # 1. Prepare GLIGEN inputs
245
- cross_attention_kwargs = (
246
- cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
247
- )
248
-
249
- attn_output = self.attn1(
250
- norm_hidden_states,
251
- freqs_cis=freqs_cis,
252
- encoder_hidden_states=(
253
- encoder_hidden_states if self.only_cross_attention else None
254
- ),
255
- attention_mask=attention_mask,
256
- **cross_attention_kwargs,
257
- )
258
- if gate_msa is not None:
259
- attn_output = gate_msa * attn_output
260
-
261
- hidden_states = attn_output + hidden_states
262
- if hidden_states.ndim == 4:
263
- hidden_states = hidden_states.squeeze(1)
264
-
265
- # 3. Cross-Attention
266
- if self.attn2 is not None:
267
- if self.adaptive_norm == "none":
268
- attn_input = self.attn2_norm(hidden_states)
269
- else:
270
- attn_input = hidden_states
271
- attn_output = self.attn2(
272
- attn_input,
273
- freqs_cis=freqs_cis,
274
- encoder_hidden_states=encoder_hidden_states,
275
- attention_mask=encoder_attention_mask,
276
- **cross_attention_kwargs,
277
- )
278
- hidden_states = attn_output + hidden_states
279
-
280
- # 4. Feed-forward
281
- norm_hidden_states = self.norm2(hidden_states)
282
- if self.adaptive_norm == "single_scale_shift":
283
- norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
284
- elif self.adaptive_norm == "single_scale":
285
- norm_hidden_states = norm_hidden_states * (1 + scale_mlp)
286
- elif self.adaptive_norm == "none":
287
- pass
288
- else:
289
- raise ValueError(f"Unknown adaptive norm type: {self.adaptive_norm}")
290
-
291
- if self._chunk_size is not None:
292
- # "feed_forward_chunk_size" can be used to save memory
293
- ff_output = _chunked_feed_forward(
294
- self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size
295
- )
296
- else:
297
- ff_output = self.ff(norm_hidden_states)
298
- if gate_mlp is not None:
299
- ff_output = gate_mlp * ff_output
300
-
301
- hidden_states = ff_output + hidden_states
302
- if hidden_states.ndim == 4:
303
- hidden_states = hidden_states.squeeze(1)
304
-
305
- return hidden_states
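The adaptive-norm path in the forward pass above modulates the normalized hidden states with per-timestep shift, scale, and gate parameters built from `scale_shift_table` plus the timestep embedding. Below is a minimal standalone sketch of that `single_scale_shift` modulation on toy tensors; the dimensions and random inputs are illustrative assumptions, not values taken from this model.

import torch

# Toy dimensions (illustrative assumptions)
batch, tokens, dim = 2, 8, 16
num_ada_params = 6  # shift/scale/gate for both the attention and MLP branches

hidden = torch.randn(batch, tokens, dim)
scale_shift_table = torch.randn(num_ada_params, dim) / dim**0.5
timestep_emb = torch.randn(batch, 1, num_ada_params * dim)  # one embedding per sample

# Combine the learned table with the timestep embedding, as the block above does
ada = scale_shift_table[None, None] + timestep_emb.reshape(batch, 1, num_ada_params, dim)
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ada.unbind(dim=2)

# Pre-attention modulation and post-attention gating
normed = torch.nn.functional.layer_norm(hidden, (dim,))
modulated = normed * (1 + scale_msa) + shift_msa
attn_out = modulated                 # stand-in for the attention output
hidden = hidden + gate_msa * attn_out
print(hidden.shape)                  # torch.Size([2, 8, 16])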
306
-
307
-
308
- @maybe_allow_in_graph
309
- class Attention(nn.Module):
310
- r"""
311
- A cross attention layer.
312
-
313
- Parameters:
314
- query_dim (`int`):
315
- The number of channels in the query.
316
- cross_attention_dim (`int`, *optional*):
317
- The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
318
- heads (`int`, *optional*, defaults to 8):
319
- The number of heads to use for multi-head attention.
320
- dim_head (`int`, *optional*, defaults to 64):
321
- The number of channels in each head.
322
- dropout (`float`, *optional*, defaults to 0.0):
323
- The dropout probability to use.
324
- bias (`bool`, *optional*, defaults to False):
325
- Set to `True` for the query, key, and value linear layers to contain a bias parameter.
326
- upcast_attention (`bool`, *optional*, defaults to False):
327
- Set to `True` to upcast the attention computation to `float32`.
328
- upcast_softmax (`bool`, *optional*, defaults to False):
329
- Set to `True` to upcast the softmax computation to `float32`.
330
- cross_attention_norm (`str`, *optional*, defaults to `None`):
331
- The type of normalization to use for the cross attention. Can be `None`, `layer_norm`, or `group_norm`.
332
- cross_attention_norm_num_groups (`int`, *optional*, defaults to 32):
333
- The number of groups to use for the group norm in the cross attention.
334
- added_kv_proj_dim (`int`, *optional*, defaults to `None`):
335
- The number of channels to use for the added key and value projections. If `None`, no projection is used.
336
- norm_num_groups (`int`, *optional*, defaults to `None`):
337
- The number of groups to use for the group norm in the attention.
338
- spatial_norm_dim (`int`, *optional*, defaults to `None`):
339
- The number of channels to use for the spatial normalization.
340
- out_bias (`bool`, *optional*, defaults to `True`):
341
- Set to `True` to use a bias in the output linear layer.
342
- scale_qk (`bool`, *optional*, defaults to `True`):
343
- Set to `True` to scale the query and key by `1 / sqrt(dim_head)`.
344
- qk_norm (`str`, *optional*, defaults to None):
345
- Set to 'layer_norm' or `rms_norm` to perform query and key normalization.
346
- only_cross_attention (`bool`, *optional*, defaults to `False`):
347
- Set to `True` to only use cross attention and not added_kv_proj_dim. Can only be set to `True` if
348
- `added_kv_proj_dim` is not `None`.
349
- eps (`float`, *optional*, defaults to 1e-5):
350
- An additional value added to the denominator in group normalization that is used for numerical stability.
351
- rescale_output_factor (`float`, *optional*, defaults to 1.0):
352
- A factor to rescale the output by dividing it with this value.
353
- residual_connection (`bool`, *optional*, defaults to `False`):
354
- Set to `True` to add the residual connection to the output.
355
- _from_deprecated_attn_block (`bool`, *optional*, defaults to `False`):
356
- Set to `True` if the attention block is loaded from a deprecated state dict.
357
- processor (`AttnProcessor`, *optional*, defaults to `None`):
358
- The attention processor to use. If `None`, defaults to `AttnProcessor2_0` if `torch 2.x` is used and
359
- `AttnProcessor` otherwise.
360
- """
361
-
362
- def __init__(
363
- self,
364
- query_dim: int,
365
- cross_attention_dim: Optional[int] = None,
366
- heads: int = 8,
367
- dim_head: int = 64,
368
- dropout: float = 0.0,
369
- bias: bool = False,
370
- upcast_attention: bool = False,
371
- upcast_softmax: bool = False,
372
- cross_attention_norm: Optional[str] = None,
373
- cross_attention_norm_num_groups: int = 32,
374
- added_kv_proj_dim: Optional[int] = None,
375
- norm_num_groups: Optional[int] = None,
376
- spatial_norm_dim: Optional[int] = None,
377
- out_bias: bool = True,
378
- scale_qk: bool = True,
379
- qk_norm: Optional[str] = None,
380
- only_cross_attention: bool = False,
381
- eps: float = 1e-5,
382
- rescale_output_factor: float = 1.0,
383
- residual_connection: bool = False,
384
- _from_deprecated_attn_block: bool = False,
385
- processor: Optional["AttnProcessor"] = None,
386
- out_dim: int = None,
387
- use_tpu_flash_attention: bool = False,
388
- use_rope: bool = False,
389
- ):
390
- super().__init__()
391
- self.inner_dim = out_dim if out_dim is not None else dim_head * heads
392
- self.query_dim = query_dim
393
- self.use_bias = bias
394
- self.is_cross_attention = cross_attention_dim is not None
395
- self.cross_attention_dim = (
396
- cross_attention_dim if cross_attention_dim is not None else query_dim
397
- )
398
- self.upcast_attention = upcast_attention
399
- self.upcast_softmax = upcast_softmax
400
- self.rescale_output_factor = rescale_output_factor
401
- self.residual_connection = residual_connection
402
- self.dropout = dropout
403
- self.fused_projections = False
404
- self.out_dim = out_dim if out_dim is not None else query_dim
405
- self.use_tpu_flash_attention = use_tpu_flash_attention
406
- self.use_rope = use_rope
407
-
408
- # we make use of this private variable to know whether this class is loaded
409
- # with a deprecated state dict so that we can convert it on the fly
410
- self._from_deprecated_attn_block = _from_deprecated_attn_block
411
-
412
- self.scale_qk = scale_qk
413
- self.scale = dim_head**-0.5 if self.scale_qk else 1.0
414
-
415
- if qk_norm is None:
416
- self.q_norm = nn.Identity()
417
- self.k_norm = nn.Identity()
418
- elif qk_norm == "rms_norm":
419
- self.q_norm = RMSNorm(dim_head * heads, eps=1e-5)
420
- self.k_norm = RMSNorm(dim_head * heads, eps=1e-5)
421
- elif qk_norm == "layer_norm":
422
- self.q_norm = nn.LayerNorm(dim_head * heads, eps=1e-5)
423
- self.k_norm = nn.LayerNorm(dim_head * heads, eps=1e-5)
424
- else:
425
- raise ValueError(f"Unsupported qk_norm method: {qk_norm}")
426
-
427
- self.heads = out_dim // dim_head if out_dim is not None else heads
428
- # for slice_size > 0 the attention score computation
429
- # is split across the batch axis to save memory
430
- # You can set slice_size with `set_attention_slice`
431
- self.sliceable_head_dim = heads
432
-
433
- self.added_kv_proj_dim = added_kv_proj_dim
434
- self.only_cross_attention = only_cross_attention
435
-
436
- if self.added_kv_proj_dim is None and self.only_cross_attention:
437
- raise ValueError(
438
- "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`."
439
- )
440
-
441
- if norm_num_groups is not None:
442
- self.group_norm = nn.GroupNorm(
443
- num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True
444
- )
445
- else:
446
- self.group_norm = None
447
-
448
- if spatial_norm_dim is not None:
449
- self.spatial_norm = SpatialNorm(
450
- f_channels=query_dim, zq_channels=spatial_norm_dim
451
- )
452
- else:
453
- self.spatial_norm = None
454
-
455
- if cross_attention_norm is None:
456
- self.norm_cross = None
457
- elif cross_attention_norm == "layer_norm":
458
- self.norm_cross = nn.LayerNorm(self.cross_attention_dim)
459
- elif cross_attention_norm == "group_norm":
460
- if self.added_kv_proj_dim is not None:
461
- # The given `encoder_hidden_states` are initially of shape
462
- # (batch_size, seq_len, added_kv_proj_dim) before being projected
463
- # to (batch_size, seq_len, cross_attention_dim). The norm is applied
464
- # before the projection, so we need to use `added_kv_proj_dim` as
465
- # the number of channels for the group norm.
466
- norm_cross_num_channels = added_kv_proj_dim
467
- else:
468
- norm_cross_num_channels = self.cross_attention_dim
469
-
470
- self.norm_cross = nn.GroupNorm(
471
- num_channels=norm_cross_num_channels,
472
- num_groups=cross_attention_norm_num_groups,
473
- eps=1e-5,
474
- affine=True,
475
- )
476
- else:
477
- raise ValueError(
478
- f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'"
479
- )
480
-
481
- linear_cls = nn.Linear
482
-
483
- self.linear_cls = linear_cls
484
- self.to_q = linear_cls(query_dim, self.inner_dim, bias=bias)
485
-
486
- if not self.only_cross_attention:
487
- # only relevant for the `AddedKVProcessor` classes
488
- self.to_k = linear_cls(self.cross_attention_dim, self.inner_dim, bias=bias)
489
- self.to_v = linear_cls(self.cross_attention_dim, self.inner_dim, bias=bias)
490
- else:
491
- self.to_k = None
492
- self.to_v = None
493
-
494
- if self.added_kv_proj_dim is not None:
495
- self.add_k_proj = linear_cls(added_kv_proj_dim, self.inner_dim)
496
- self.add_v_proj = linear_cls(added_kv_proj_dim, self.inner_dim)
497
-
498
- self.to_out = nn.ModuleList([])
499
- self.to_out.append(linear_cls(self.inner_dim, self.out_dim, bias=out_bias))
500
- self.to_out.append(nn.Dropout(dropout))
501
-
502
- # set attention processor
503
- # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
504
- # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
505
- # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
506
- if processor is None:
507
- processor = AttnProcessor2_0()
508
- self.set_processor(processor)
509
-
510
- def set_use_tpu_flash_attention(self):
511
- r"""
512
- Sets the flag on this object. The flag enforces the use of the TPU attention kernel.
513
- """
514
- self.use_tpu_flash_attention = True
515
-
516
- def set_processor(self, processor: "AttnProcessor") -> None:
517
- r"""
518
- Set the attention processor to use.
519
-
520
- Args:
521
- processor (`AttnProcessor`):
522
- The attention processor to use.
523
- """
524
- # if current processor is in `self._modules` and if passed `processor` is not, we need to
525
- # pop `processor` from `self._modules`
526
- if (
527
- hasattr(self, "processor")
528
- and isinstance(self.processor, torch.nn.Module)
529
- and not isinstance(processor, torch.nn.Module)
530
- ):
531
- logger.info(
532
- f"You are removing possibly trained weights of {self.processor} with {processor}"
533
- )
534
- self._modules.pop("processor")
535
-
536
- self.processor = processor
537
-
538
- def get_processor(
539
- self, return_deprecated_lora: bool = False
540
- ) -> "AttentionProcessor": # noqa: F821
541
- r"""
542
- Get the attention processor in use.
543
-
544
- Args:
545
- return_deprecated_lora (`bool`, *optional*, defaults to `False`):
546
- Set to `True` to return the deprecated LoRA attention processor.
547
-
548
- Returns:
549
- "AttentionProcessor": The attention processor in use.
550
- """
551
- if not return_deprecated_lora:
552
- return self.processor
553
-
554
- # TODO(Sayak, Patrick). The rest of the function is needed to ensure backwards compatible
555
- # serialization format for LoRA Attention Processors. It should be deleted once the integration
556
- # with PEFT is completed.
557
- is_lora_activated = {
558
- name: module.lora_layer is not None
559
- for name, module in self.named_modules()
560
- if hasattr(module, "lora_layer")
561
- }
562
-
563
- # 1. if no layer has a LoRA activated we can return the processor as usual
564
- if not any(is_lora_activated.values()):
565
- return self.processor
566
-
567
- # LoRA is not applied to `add_k_proj` or `add_v_proj`
568
- is_lora_activated.pop("add_k_proj", None)
569
- is_lora_activated.pop("add_v_proj", None)
570
- # 2. else it is not possible that only some layers have LoRA activated
571
- if not all(is_lora_activated.values()):
572
- raise ValueError(
573
- f"Make sure that either all layers or no layers have LoRA activated, but have {is_lora_activated}"
574
- )
575
-
576
- # 3. And we need to merge the current LoRA layers into the corresponding LoRA attention processor
577
- non_lora_processor_cls_name = self.processor.__class__.__name__
578
- lora_processor_cls = getattr(
579
- import_module(__name__), "LoRA" + non_lora_processor_cls_name
580
- )
581
-
582
- hidden_size = self.inner_dim
583
-
584
- # now create a LoRA attention processor from the LoRA layers
585
- if lora_processor_cls in [
586
- LoRAAttnProcessor,
587
- LoRAAttnProcessor2_0,
588
- LoRAXFormersAttnProcessor,
589
- ]:
590
- kwargs = {
591
- "cross_attention_dim": self.cross_attention_dim,
592
- "rank": self.to_q.lora_layer.rank,
593
- "network_alpha": self.to_q.lora_layer.network_alpha,
594
- "q_rank": self.to_q.lora_layer.rank,
595
- "q_hidden_size": self.to_q.lora_layer.out_features,
596
- "k_rank": self.to_k.lora_layer.rank,
597
- "k_hidden_size": self.to_k.lora_layer.out_features,
598
- "v_rank": self.to_v.lora_layer.rank,
599
- "v_hidden_size": self.to_v.lora_layer.out_features,
600
- "out_rank": self.to_out[0].lora_layer.rank,
601
- "out_hidden_size": self.to_out[0].lora_layer.out_features,
602
- }
603
-
604
- if hasattr(self.processor, "attention_op"):
605
- kwargs["attention_op"] = self.processor.attention_op
606
-
607
- lora_processor = lora_processor_cls(hidden_size, **kwargs)
608
- lora_processor.to_q_lora.load_state_dict(self.to_q.lora_layer.state_dict())
609
- lora_processor.to_k_lora.load_state_dict(self.to_k.lora_layer.state_dict())
610
- lora_processor.to_v_lora.load_state_dict(self.to_v.lora_layer.state_dict())
611
- lora_processor.to_out_lora.load_state_dict(
612
- self.to_out[0].lora_layer.state_dict()
613
- )
614
- elif lora_processor_cls == LoRAAttnAddedKVProcessor:
615
- lora_processor = lora_processor_cls(
616
- hidden_size,
617
- cross_attention_dim=self.add_k_proj.weight.shape[0],
618
- rank=self.to_q.lora_layer.rank,
619
- network_alpha=self.to_q.lora_layer.network_alpha,
620
- )
621
- lora_processor.to_q_lora.load_state_dict(self.to_q.lora_layer.state_dict())
622
- lora_processor.to_k_lora.load_state_dict(self.to_k.lora_layer.state_dict())
623
- lora_processor.to_v_lora.load_state_dict(self.to_v.lora_layer.state_dict())
624
- lora_processor.to_out_lora.load_state_dict(
625
- self.to_out[0].lora_layer.state_dict()
626
- )
627
-
628
- # only save if used
629
- if self.add_k_proj.lora_layer is not None:
630
- lora_processor.add_k_proj_lora.load_state_dict(
631
- self.add_k_proj.lora_layer.state_dict()
632
- )
633
- lora_processor.add_v_proj_lora.load_state_dict(
634
- self.add_v_proj.lora_layer.state_dict()
635
- )
636
- else:
637
- lora_processor.add_k_proj_lora = None
638
- lora_processor.add_v_proj_lora = None
639
- else:
640
- raise ValueError(f"{lora_processor_cls} does not exist.")
641
-
642
- return lora_processor
643
-
644
- def forward(
645
- self,
646
- hidden_states: torch.FloatTensor,
647
- freqs_cis: Optional[Tuple[torch.FloatTensor, torch.FloatTensor]] = None,
648
- encoder_hidden_states: Optional[torch.FloatTensor] = None,
649
- attention_mask: Optional[torch.FloatTensor] = None,
650
- **cross_attention_kwargs,
651
- ) -> torch.Tensor:
652
- r"""
653
- The forward method of the `Attention` class.
654
-
655
- Args:
656
- hidden_states (`torch.Tensor`):
657
- The hidden states of the query.
658
- encoder_hidden_states (`torch.Tensor`, *optional*):
659
- The hidden states of the encoder.
660
- attention_mask (`torch.Tensor`, *optional*):
661
- The attention mask to use. If `None`, no mask is applied.
662
- **cross_attention_kwargs:
663
- Additional keyword arguments to pass along to the cross attention.
664
-
665
- Returns:
666
- `torch.Tensor`: The output of the attention layer.
667
- """
668
- # The `Attention` class can call different attention processors / attention functions
669
- # here we simply pass along all tensors to the selected processor class
670
- # For standard processors that are defined here, `**cross_attention_kwargs` is empty
671
-
672
- attn_parameters = set(
673
- inspect.signature(self.processor.__call__).parameters.keys()
674
- )
675
- unused_kwargs = [
676
- k for k, _ in cross_attention_kwargs.items() if k not in attn_parameters
677
- ]
678
- if len(unused_kwargs) > 0:
679
- logger.warning(
680
- f"cross_attention_kwargs {unused_kwargs} are not expected by"
681
- f" {self.processor.__class__.__name__} and will be ignored."
682
- )
683
- cross_attention_kwargs = {
684
- k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters
685
- }
686
-
687
- return self.processor(
688
- self,
689
- hidden_states,
690
- freqs_cis=freqs_cis,
691
- encoder_hidden_states=encoder_hidden_states,
692
- attention_mask=attention_mask,
693
- **cross_attention_kwargs,
694
- )
695
-
696
- def batch_to_head_dim(self, tensor: torch.Tensor) -> torch.Tensor:
697
- r"""
698
- Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size // heads, seq_len, dim * heads]`. `heads`
699
- is the number of heads initialized while constructing the `Attention` class.
700
-
701
- Args:
702
- tensor (`torch.Tensor`): The tensor to reshape.
703
-
704
- Returns:
705
- `torch.Tensor`: The reshaped tensor.
706
- """
707
- head_size = self.heads
708
- batch_size, seq_len, dim = tensor.shape
709
- tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
710
- tensor = tensor.permute(0, 2, 1, 3).reshape(
711
- batch_size // head_size, seq_len, dim * head_size
712
- )
713
- return tensor
714
-
715
- def head_to_batch_dim(self, tensor: torch.Tensor, out_dim: int = 3) -> torch.Tensor:
716
- r"""
717
- Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size, seq_len, heads, dim // heads]` `heads` is
718
- the number of heads initialized while constructing the `Attention` class.
719
-
720
- Args:
721
- tensor (`torch.Tensor`): The tensor to reshape.
722
- out_dim (`int`, *optional*, defaults to `3`): The output dimension of the tensor. If `3`, the tensor is
723
- reshaped to `[batch_size * heads, seq_len, dim // heads]`.
724
-
725
- Returns:
726
- `torch.Tensor`: The reshaped tensor.
727
- """
728
-
729
- head_size = self.heads
730
- if tensor.ndim == 3:
731
- batch_size, seq_len, dim = tensor.shape
732
- extra_dim = 1
733
- else:
734
- batch_size, extra_dim, seq_len, dim = tensor.shape
735
- tensor = tensor.reshape(
736
- batch_size, seq_len * extra_dim, head_size, dim // head_size
737
- )
738
- tensor = tensor.permute(0, 2, 1, 3)
739
-
740
- if out_dim == 3:
741
- tensor = tensor.reshape(
742
- batch_size * head_size, seq_len * extra_dim, dim // head_size
743
- )
744
-
745
- return tensor
746
-
747
- def get_attention_scores(
748
- self,
749
- query: torch.Tensor,
750
- key: torch.Tensor,
751
- attention_mask: torch.Tensor = None,
752
- ) -> torch.Tensor:
753
- r"""
754
- Compute the attention scores.
755
-
756
- Args:
757
- query (`torch.Tensor`): The query tensor.
758
- key (`torch.Tensor`): The key tensor.
759
- attention_mask (`torch.Tensor`, *optional*): The attention mask to use. If `None`, no mask is applied.
760
-
761
- Returns:
762
- `torch.Tensor`: The attention probabilities/scores.
763
- """
764
- dtype = query.dtype
765
- if self.upcast_attention:
766
- query = query.float()
767
- key = key.float()
768
-
769
- if attention_mask is None:
770
- baddbmm_input = torch.empty(
771
- query.shape[0],
772
- query.shape[1],
773
- key.shape[1],
774
- dtype=query.dtype,
775
- device=query.device,
776
- )
777
- beta = 0
778
- else:
779
- baddbmm_input = attention_mask
780
- beta = 1
781
-
782
- attention_scores = torch.baddbmm(
783
- baddbmm_input,
784
- query,
785
- key.transpose(-1, -2),
786
- beta=beta,
787
- alpha=self.scale,
788
- )
789
- del baddbmm_input
790
-
791
- if self.upcast_softmax:
792
- attention_scores = attention_scores.float()
793
-
794
- attention_probs = attention_scores.softmax(dim=-1)
795
- del attention_scores
796
-
797
- attention_probs = attention_probs.to(dtype)
798
-
799
- return attention_probs
800
-
801
- def prepare_attention_mask(
802
- self,
803
- attention_mask: torch.Tensor,
804
- target_length: int,
805
- batch_size: int,
806
- out_dim: int = 3,
807
- ) -> torch.Tensor:
808
- r"""
809
- Prepare the attention mask for the attention computation.
810
-
811
- Args:
812
- attention_mask (`torch.Tensor`):
813
- The attention mask to prepare.
814
- target_length (`int`):
815
- The target length of the attention mask. This is the length of the attention mask after padding.
816
- batch_size (`int`):
817
- The batch size, which is used to repeat the attention mask.
818
- out_dim (`int`, *optional*, defaults to `3`):
819
- The output dimension of the attention mask. Can be either `3` or `4`.
820
-
821
- Returns:
822
- `torch.Tensor`: The prepared attention mask.
823
- """
824
- head_size = self.heads
825
- if attention_mask is None:
826
- return attention_mask
827
-
828
- current_length: int = attention_mask.shape[-1]
829
- if current_length != target_length:
830
- if attention_mask.device.type == "mps":
831
- # HACK: MPS: Does not support padding by greater than dimension of input tensor.
832
- # Instead, we can manually construct the padding tensor.
833
- padding_shape = (
834
- attention_mask.shape[0],
835
- attention_mask.shape[1],
836
- target_length,
837
- )
838
- padding = torch.zeros(
839
- padding_shape,
840
- dtype=attention_mask.dtype,
841
- device=attention_mask.device,
842
- )
843
- attention_mask = torch.cat([attention_mask, padding], dim=2)
844
- else:
845
- # TODO: for pipelines such as stable-diffusion, padding cross-attn mask:
846
- # we want to instead pad by (0, remaining_length), where remaining_length is:
847
- # remaining_length: int = target_length - current_length
848
- # TODO: re-enable tests/models/test_models_unet_2d_condition.py#test_model_xattn_padding
849
- attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
850
-
851
- if out_dim == 3:
852
- if attention_mask.shape[0] < batch_size * head_size:
853
- attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
854
- elif out_dim == 4:
855
- attention_mask = attention_mask.unsqueeze(1)
856
- attention_mask = attention_mask.repeat_interleave(head_size, dim=1)
857
-
858
- return attention_mask
859
-
860
- def norm_encoder_hidden_states(
861
- self, encoder_hidden_states: torch.Tensor
862
- ) -> torch.Tensor:
863
- r"""
864
- Normalize the encoder hidden states. Requires `self.norm_cross` to be specified when constructing the
865
- `Attention` class.
866
-
867
- Args:
868
- encoder_hidden_states (`torch.Tensor`): Hidden states of the encoder.
869
-
870
- Returns:
871
- `torch.Tensor`: The normalized encoder hidden states.
872
- """
873
- assert (
874
- self.norm_cross is not None
875
- ), "self.norm_cross must be defined to call self.norm_encoder_hidden_states"
876
-
877
- if isinstance(self.norm_cross, nn.LayerNorm):
878
- encoder_hidden_states = self.norm_cross(encoder_hidden_states)
879
- elif isinstance(self.norm_cross, nn.GroupNorm):
880
- # Group norm norms along the channels dimension and expects
881
- # input to be in the shape of (N, C, *). In this case, we want
882
- # to norm along the hidden dimension, so we need to move
883
- # (batch_size, sequence_length, hidden_size) ->
884
- # (batch_size, hidden_size, sequence_length)
885
- encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
886
- encoder_hidden_states = self.norm_cross(encoder_hidden_states)
887
- encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
888
- else:
889
- assert False
890
-
891
- return encoder_hidden_states
892
-
893
- @staticmethod
894
- def apply_rotary_emb(
895
- input_tensor: torch.Tensor,
896
- freqs_cis: Tuple[torch.FloatTensor, torch.FloatTensor],
897
- ) -> Tuple[torch.Tensor, torch.Tensor]:
898
- cos_freqs = freqs_cis[0]
899
- sin_freqs = freqs_cis[1]
900
-
901
- t_dup = rearrange(input_tensor, "... (d r) -> ... d r", r=2)
902
- t1, t2 = t_dup.unbind(dim=-1)
903
- t_dup = torch.stack((-t2, t1), dim=-1)
904
- input_tensor_rot = rearrange(t_dup, "... d r -> ... (d r)")
905
-
906
- out = input_tensor * cos_freqs + input_tensor_rot * sin_freqs
907
-
908
- return out
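`apply_rotary_emb` above rotates adjacent channel pairs by precomputed cosine/sine frequencies. A small self-contained sketch of the same rotation follows; the frequencies are made up for illustration, whereas in the model they come from the rotary positional embedding.

import torch
from einops import rearrange

def apply_rotary(x, cos_freqs, sin_freqs):
    # Split the last dimension into pairs, rotate each pair by 90 degrees,
    # then combine with the cos/sin frequencies: x*cos + rotate(x)*sin.
    pairs = rearrange(x, "... (d r) -> ... d r", r=2)
    x1, x2 = pairs.unbind(dim=-1)
    rotated = rearrange(torch.stack((-x2, x1), dim=-1), "... d r -> ... (d r)")
    return x * cos_freqs + rotated * sin_freqs

tokens, dim = 4, 8
x = torch.randn(1, tokens, dim)
angles = torch.linspace(0, 1, tokens)[None, :, None] * torch.ones(1, tokens, dim // 2)
cos_freqs = torch.cos(angles).repeat_interleave(2, dim=-1)
sin_freqs = torch.sin(angles).repeat_interleave(2, dim=-1)
out = apply_rotary(x, cos_freqs, sin_freqs)
print(out.shape)  # torch.Size([1, 4, 8])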
909
-
910
-
911
- class AttnProcessor2_0:
912
- r"""
913
- Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
914
- """
915
-
916
- def __init__(self):
917
- pass
918
-
919
- def __call__(
920
- self,
921
- attn: Attention,
922
- hidden_states: torch.FloatTensor,
923
- freqs_cis: Tuple[torch.FloatTensor, torch.FloatTensor],
924
- encoder_hidden_states: Optional[torch.FloatTensor] = None,
925
- attention_mask: Optional[torch.FloatTensor] = None,
926
- temb: Optional[torch.FloatTensor] = None,
927
- *args,
928
- **kwargs,
929
- ) -> torch.FloatTensor:
930
- if len(args) > 0 or kwargs.get("scale", None) is not None:
931
- deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
932
- deprecate("scale", "1.0.0", deprecation_message)
933
-
934
- residual = hidden_states
935
- if attn.spatial_norm is not None:
936
- hidden_states = attn.spatial_norm(hidden_states, temb)
937
-
938
- input_ndim = hidden_states.ndim
939
-
940
- if input_ndim == 4:
941
- batch_size, channel, height, width = hidden_states.shape
942
- hidden_states = hidden_states.view(
943
- batch_size, channel, height * width
944
- ).transpose(1, 2)
945
-
946
- batch_size, sequence_length, _ = (
947
- hidden_states.shape
948
- if encoder_hidden_states is None
949
- else encoder_hidden_states.shape
950
- )
951
-
952
- if (attention_mask is not None) and (not attn.use_tpu_flash_attention):
953
- attention_mask = attn.prepare_attention_mask(
954
- attention_mask, sequence_length, batch_size
955
- )
956
- # scaled_dot_product_attention expects attention_mask shape to be
957
- # (batch, heads, source_length, target_length)
958
- attention_mask = attention_mask.view(
959
- batch_size, attn.heads, -1, attention_mask.shape[-1]
960
- )
961
-
962
- if attn.group_norm is not None:
963
- hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
964
- 1, 2
965
- )
966
-
967
- query = attn.to_q(hidden_states)
968
- query = attn.q_norm(query)
969
-
970
- if encoder_hidden_states is not None:
971
- if attn.norm_cross:
972
- encoder_hidden_states = attn.norm_encoder_hidden_states(
973
- encoder_hidden_states
974
- )
975
- key = attn.to_k(encoder_hidden_states)
976
- key = attn.k_norm(key)
977
- else: # if no context provided do self-attention
978
- encoder_hidden_states = hidden_states
979
- key = attn.to_k(hidden_states)
980
- key = attn.k_norm(key)
981
- if attn.use_rope:
982
- key = attn.apply_rotary_emb(key, freqs_cis)
983
- query = attn.apply_rotary_emb(query, freqs_cis)
984
-
985
- value = attn.to_v(encoder_hidden_states)
986
-
987
- inner_dim = key.shape[-1]
988
- head_dim = inner_dim // attn.heads
989
-
990
- query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
991
- key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
992
- value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
993
-
994
- # the output of sdp = (batch, num_heads, seq_len, head_dim)
995
-
996
- if attn.use_tpu_flash_attention: # use tpu attention offload 'flash attention'
997
- q_segment_indexes = None
998
- if (
999
- attention_mask is not None
1000
- ): # if a mask is required, both segment_ids fields need to be set
1001
- # attention_mask = torch.squeeze(attention_mask).to(torch.float32)
1002
- attention_mask = attention_mask.to(torch.float32)
1003
- q_segment_indexes = torch.ones(
1004
- batch_size, query.shape[2], device=query.device, dtype=torch.float32
1005
- )
1006
- assert (
1007
- attention_mask.shape[1] == key.shape[2]
1008
- ), f"ERROR: KEY SHAPE must be same as attention mask [{key.shape[2]}, {attention_mask.shape[1]}]"
1009
-
1010
- assert (
1011
- query.shape[2] % 128 == 0
1012
- ), f"ERROR: QUERY SHAPE must be divisible by 128 (TPU limitation) [{query.shape[2]}]"
1013
- assert (
1014
- key.shape[2] % 128 == 0
1015
- ), f"ERROR: KEY SHAPE must be divisible by 128 (TPU limitation) [{key.shape[2]}]"
1016
-
1017
- # run the TPU kernel implemented in jax with pallas
1018
- hidden_states = flash_attention(
1019
- q=query,
1020
- k=key,
1021
- v=value,
1022
- q_segment_ids=q_segment_indexes,
1023
- kv_segment_ids=attention_mask,
1024
- sm_scale=attn.scale,
1025
- )
1026
- else:
1027
- hidden_states = F.scaled_dot_product_attention(
1028
- query,
1029
- key,
1030
- value,
1031
- attn_mask=attention_mask,
1032
- dropout_p=0.0,
1033
- is_causal=False,
1034
- )
1035
-
1036
- hidden_states = hidden_states.transpose(1, 2).reshape(
1037
- batch_size, -1, attn.heads * head_dim
1038
- )
1039
- hidden_states = hidden_states.to(query.dtype)
1040
-
1041
- # linear proj
1042
- hidden_states = attn.to_out[0](hidden_states)
1043
- # dropout
1044
- hidden_states = attn.to_out[1](hidden_states)
1045
-
1046
- if input_ndim == 4:
1047
- hidden_states = hidden_states.transpose(-1, -2).reshape(
1048
- batch_size, channel, height, width
1049
- )
1050
-
1051
- if attn.residual_connection:
1052
- hidden_states = hidden_states + residual
1053
-
1054
- hidden_states = hidden_states / attn.rescale_output_factor
1055
-
1056
- return hidden_states
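On the non-TPU path the processor above delegates to PyTorch's fused scaled dot-product attention after splitting the projections into heads. A minimal sketch of that reshape-attend-merge pattern is shown below; the shapes are illustrative assumptions.

import torch
import torch.nn.functional as F

batch, seq_len, heads, head_dim = 2, 16, 4, 8
inner_dim = heads * head_dim

query = torch.randn(batch, seq_len, inner_dim)
key = torch.randn(batch, seq_len, inner_dim)
value = torch.randn(batch, seq_len, inner_dim)

# (batch, seq, heads*head_dim) -> (batch, heads, seq, head_dim)
q = query.view(batch, -1, heads, head_dim).transpose(1, 2)
k = key.view(batch, -1, heads, head_dim).transpose(1, 2)
v = value.view(batch, -1, heads, head_dim).transpose(1, 2)

out = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)

# Merge heads back: (batch, heads, seq, head_dim) -> (batch, seq, heads*head_dim)
out = out.transpose(1, 2).reshape(batch, -1, inner_dim)
print(out.shape)  # torch.Size([2, 16, 32])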
1057
-
1058
-
1059
- class AttnProcessor:
1060
- r"""
1061
- Default processor for performing attention-related computations.
1062
- """
1063
-
1064
- def __call__(
1065
- self,
1066
- attn: Attention,
1067
- hidden_states: torch.FloatTensor,
1068
- encoder_hidden_states: Optional[torch.FloatTensor] = None,
1069
- attention_mask: Optional[torch.FloatTensor] = None,
1070
- temb: Optional[torch.FloatTensor] = None,
1071
- *args,
1072
- **kwargs,
1073
- ) -> torch.Tensor:
1074
- if len(args) > 0 or kwargs.get("scale", None) is not None:
1075
- deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
1076
- deprecate("scale", "1.0.0", deprecation_message)
1077
-
1078
- residual = hidden_states
1079
-
1080
- if attn.spatial_norm is not None:
1081
- hidden_states = attn.spatial_norm(hidden_states, temb)
1082
-
1083
- input_ndim = hidden_states.ndim
1084
-
1085
- if input_ndim == 4:
1086
- batch_size, channel, height, width = hidden_states.shape
1087
- hidden_states = hidden_states.view(
1088
- batch_size, channel, height * width
1089
- ).transpose(1, 2)
1090
-
1091
- batch_size, sequence_length, _ = (
1092
- hidden_states.shape
1093
- if encoder_hidden_states is None
1094
- else encoder_hidden_states.shape
1095
- )
1096
- attention_mask = attn.prepare_attention_mask(
1097
- attention_mask, sequence_length, batch_size
1098
- )
1099
-
1100
- if attn.group_norm is not None:
1101
- hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
1102
- 1, 2
1103
- )
1104
-
1105
- query = attn.to_q(hidden_states)
1106
-
1107
- if encoder_hidden_states is None:
1108
- encoder_hidden_states = hidden_states
1109
- elif attn.norm_cross:
1110
- encoder_hidden_states = attn.norm_encoder_hidden_states(
1111
- encoder_hidden_states
1112
- )
1113
-
1114
- key = attn.to_k(encoder_hidden_states)
1115
- value = attn.to_v(encoder_hidden_states)
1116
-
1117
- query = attn.head_to_batch_dim(query)
1118
- key = attn.head_to_batch_dim(key)
1119
- value = attn.head_to_batch_dim(value)
1120
-
1121
- query = attn.q_norm(query)
1122
- key = attn.k_norm(key)
1123
-
1124
- attention_probs = attn.get_attention_scores(query, key, attention_mask)
1125
- hidden_states = torch.bmm(attention_probs, value)
1126
- hidden_states = attn.batch_to_head_dim(hidden_states)
1127
-
1128
- # linear proj
1129
- hidden_states = attn.to_out[0](hidden_states)
1130
- # dropout
1131
- hidden_states = attn.to_out[1](hidden_states)
1132
-
1133
- if input_ndim == 4:
1134
- hidden_states = hidden_states.transpose(-1, -2).reshape(
1135
- batch_size, channel, height, width
1136
- )
1137
-
1138
- if attn.residual_connection:
1139
- hidden_states = hidden_states + residual
1140
-
1141
- hidden_states = hidden_states / attn.rescale_output_factor
1142
-
1143
- return hidden_states
1144
-
1145
-
1146
- class FeedForward(nn.Module):
1147
- r"""
1148
- A feed-forward layer.
1149
-
1150
- Parameters:
1151
- dim (`int`): The number of channels in the input.
1152
- dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
1153
- mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
1154
- dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
1155
- activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
1156
- final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
1157
- bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
1158
- """
1159
-
1160
- def __init__(
1161
- self,
1162
- dim: int,
1163
- dim_out: Optional[int] = None,
1164
- mult: int = 4,
1165
- dropout: float = 0.0,
1166
- activation_fn: str = "geglu",
1167
- final_dropout: bool = False,
1168
- inner_dim=None,
1169
- bias: bool = True,
1170
- ):
1171
- super().__init__()
1172
- if inner_dim is None:
1173
- inner_dim = int(dim * mult)
1174
- dim_out = dim_out if dim_out is not None else dim
1175
- linear_cls = nn.Linear
1176
-
1177
- if activation_fn == "gelu":
1178
- act_fn = GELU(dim, inner_dim, bias=bias)
1179
- elif activation_fn == "gelu-approximate":
1180
- act_fn = GELU(dim, inner_dim, approximate="tanh", bias=bias)
1181
- elif activation_fn == "geglu":
1182
- act_fn = GEGLU(dim, inner_dim, bias=bias)
1183
- elif activation_fn == "geglu-approximate":
1184
- act_fn = ApproximateGELU(dim, inner_dim, bias=bias)
1185
- else:
1186
- raise ValueError(f"Unsupported activation function: {activation_fn}")
1187
-
1188
- self.net = nn.ModuleList([])
1189
- # project in
1190
- self.net.append(act_fn)
1191
- # project dropout
1192
- self.net.append(nn.Dropout(dropout))
1193
- # project out
1194
- self.net.append(linear_cls(inner_dim, dim_out, bias=bias))
1195
- # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
1196
- if final_dropout:
1197
- self.net.append(nn.Dropout(dropout))
1198
-
1199
- def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
1200
- compatible_cls = (GEGLU, LoRACompatibleLinear)
1201
- for module in self.net:
1202
- if isinstance(module, compatible_cls):
1203
- hidden_states = module(hidden_states, scale)
1204
- else:
1205
- hidden_states = module(hidden_states)
1206
- return hidden_states
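The `FeedForward` block above is a gated-GELU MLP: project in with GEGLU, apply dropout, project back out. Below is a minimal standalone sketch of the same GEGLU feed-forward pattern built from plain torch layers; the class name `TinyGEGLUFeedForward`, the dimensions, and the 4x expansion factor are assumptions for illustration only.

import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyGEGLUFeedForward(nn.Module):
    """GEGLU MLP: one linear produces both the value and the gate."""

    def __init__(self, dim: int, mult: int = 4, dropout: float = 0.0):
        super().__init__()
        inner_dim = dim * mult
        self.proj_in = nn.Linear(dim, inner_dim * 2)   # value and gate in one projection
        self.dropout = nn.Dropout(dropout)
        self.proj_out = nn.Linear(inner_dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        value, gate = self.proj_in(x).chunk(2, dim=-1)
        x = value * F.gelu(gate)
        return self.proj_out(self.dropout(x))

ff = TinyGEGLUFeedForward(dim=16)
print(ff(torch.randn(2, 8, 16)).shape)  # torch.Size([2, 8, 16])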
xora/models/transformers/embeddings.py DELETED
@@ -1,129 +0,0 @@
1
- # Adapted from: https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/embeddings.py
2
- import math
3
-
4
- import numpy as np
5
- import torch
6
- from einops import rearrange
7
- from torch import nn
8
-
9
-
10
- def get_timestep_embedding(
11
- timesteps: torch.Tensor,
12
- embedding_dim: int,
13
- flip_sin_to_cos: bool = False,
14
- downscale_freq_shift: float = 1,
15
- scale: float = 1,
16
- max_period: int = 10000,
17
- ):
18
- """
19
- This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
20
-
21
- :param timesteps: a 1-D Tensor of N indices, one per batch element.
22
- These may be fractional.
23
- :param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the
24
- embeddings. :return: an [N x dim] Tensor of positional embeddings.
25
- """
26
- assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
27
-
28
- half_dim = embedding_dim // 2
29
- exponent = -math.log(max_period) * torch.arange(
30
- start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
31
- )
32
- exponent = exponent / (half_dim - downscale_freq_shift)
33
-
34
- emb = torch.exp(exponent)
35
- emb = timesteps[:, None].float() * emb[None, :]
36
-
37
- # scale embeddings
38
- emb = scale * emb
39
-
40
- # concat sine and cosine embeddings
41
- emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
42
-
43
- # flip sine and cosine embeddings
44
- if flip_sin_to_cos:
45
- emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
46
-
47
- # zero pad
48
- if embedding_dim % 2 == 1:
49
- emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
50
- return emb
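The function above builds standard sinusoidal timestep embeddings: geometrically spaced frequencies multiplied by the timesteps, then sine and cosine concatenated. A short sketch of the same construction for a toy case follows, with `downscale_freq_shift` taken as 0 for simplicity and the values chosen only for illustration.

import math
import torch

timesteps = torch.tensor([0.0, 10.0, 100.0])
embedding_dim, max_period = 8, 10000

half_dim = embedding_dim // 2
exponent = -math.log(max_period) * torch.arange(half_dim, dtype=torch.float32) / half_dim
freqs = torch.exp(exponent)                      # geometrically spaced frequencies
args = timesteps[:, None] * freqs[None, :]       # (N, half_dim)
emb = torch.cat([torch.sin(args), torch.cos(args)], dim=-1)
print(emb.shape)  # torch.Size([3, 8])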
51
-
52
-
53
- def get_3d_sincos_pos_embed(embed_dim, grid, w, h, f):
54
- """
55
- grid_size: int of the grid height and width return: pos_embed: [grid_size*grid_size, embed_dim] or
56
- [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
57
- """
58
- grid = rearrange(grid, "c (f h w) -> c f h w", h=h, w=w)
59
- grid = rearrange(grid, "c f h w -> c h w f", h=h, w=w)
60
- grid = grid.reshape([3, 1, w, h, f])
61
- pos_embed = get_3d_sincos_pos_embed_from_grid(embed_dim, grid)
62
- pos_embed = pos_embed.transpose(1, 0, 2, 3)
63
- return rearrange(pos_embed, "h w f c -> (f h w) c")
64
-
65
-
66
- def get_3d_sincos_pos_embed_from_grid(embed_dim, grid):
67
- if embed_dim % 3 != 0:
68
- raise ValueError("embed_dim must be divisible by 3")
69
-
70
- # use half of dimensions to encode grid_h
71
- emb_f = get_1d_sincos_pos_embed_from_grid(embed_dim // 3, grid[0]) # (H*W*T, D/3)
72
- emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 3, grid[1]) # (H*W*T, D/3)
73
- emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 3, grid[2]) # (H*W*T, D/3)
74
-
75
- emb = np.concatenate([emb_h, emb_w, emb_f], axis=-1) # (H*W*T, D)
76
- return emb
77
-
78
-
79
- def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
80
- """
81
- embed_dim: output dimension for each position pos: a list of positions to be encoded: size (M,) out: (M, D)
82
- """
83
- if embed_dim % 2 != 0:
84
- raise ValueError("embed_dim must be divisible by 2")
85
-
86
- omega = np.arange(embed_dim // 2, dtype=np.float64)
87
- omega /= embed_dim / 2.0
88
- omega = 1.0 / 10000**omega # (D/2,)
89
-
90
- pos_shape = pos.shape
91
-
92
- pos = pos.reshape(-1)
93
- out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
94
- out = out.reshape([*pos_shape, -1])[0]
95
-
96
- emb_sin = np.sin(out) # (M, D/2)
97
- emb_cos = np.cos(out) # (M, D/2)
98
-
99
- emb = np.concatenate([emb_sin, emb_cos], axis=-1) # (M, D)
100
- return emb
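The 1D helper above is the numpy building block used for each axis of the 3D grid embedding: an outer product of positions and inverse-power frequencies, followed by sine/cosine concatenation. A toy sketch with assumed sizes:

import numpy as np

embed_dim = 8
positions = np.arange(5, dtype=np.float64)          # 5 positions along one axis
omega = np.arange(embed_dim // 2, dtype=np.float64) / (embed_dim / 2.0)
omega = 1.0 / 10000**omega                          # (D/2,) frequencies
angles = np.einsum("m,d->md", positions, omega)     # outer product: (M, D/2)
emb = np.concatenate([np.sin(angles), np.cos(angles)], axis=-1)
print(emb.shape)  # (5, 8)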
101
-
102
-
103
- class SinusoidalPositionalEmbedding(nn.Module):
104
- """Apply positional information to a sequence of embeddings.
105
-
106
- Takes in a sequence of embeddings with shape (batch_size, seq_length, embed_dim) and adds positional embeddings to
107
- them
108
-
109
- Args:
110
- embed_dim: (int): Dimension of the positional embedding.
111
- max_seq_length: Maximum sequence length to apply positional embeddings
112
-
113
- """
114
-
115
- def __init__(self, embed_dim: int, max_seq_length: int = 32):
116
- super().__init__()
117
- position = torch.arange(max_seq_length).unsqueeze(1)
118
- div_term = torch.exp(
119
- torch.arange(0, embed_dim, 2) * (-math.log(10000.0) / embed_dim)
120
- )
121
- pe = torch.zeros(1, max_seq_length, embed_dim)
122
- pe[0, :, 0::2] = torch.sin(position * div_term)
123
- pe[0, :, 1::2] = torch.cos(position * div_term)
124
- self.register_buffer("pe", pe)
125
-
126
- def forward(self, x):
127
- _, seq_length, _ = x.shape
128
- x = x + self.pe[:, :seq_length]
129
- return x
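`SinusoidalPositionalEmbedding` precomputes a fixed sine/cosine table and adds it to the input sequence. A minimal sketch of the same table construction and addition on a toy sequence is shown below; the dimensions are illustrative assumptions.

import math
import torch

embed_dim, max_seq_length = 8, 32
position = torch.arange(max_seq_length).unsqueeze(1).float()
div_term = torch.exp(torch.arange(0, embed_dim, 2).float() * (-math.log(10000.0) / embed_dim))
pe = torch.zeros(1, max_seq_length, embed_dim)
pe[0, :, 0::2] = torch.sin(position * div_term)
pe[0, :, 1::2] = torch.cos(position * div_term)

tokens = torch.randn(2, 10, embed_dim)       # (batch, seq_length, embed_dim)
out = tokens + pe[:, :10]                    # add positions up to the sequence length
print(out.shape)  # torch.Size([2, 10, 8])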
xora/models/transformers/symmetric_patchifier.py DELETED
@@ -1,96 +0,0 @@
1
- from abc import ABC, abstractmethod
2
- from typing import Tuple
3
-
4
- import torch
5
- from diffusers.configuration_utils import ConfigMixin
6
- from einops import rearrange
7
- from torch import Tensor
8
-
9
- from xora.utils.torch_utils import append_dims
10
-
11
-
12
- class Patchifier(ConfigMixin, ABC):
13
- def __init__(self, patch_size: int):
14
- super().__init__()
15
- self._patch_size = (1, patch_size, patch_size)
16
-
17
- @abstractmethod
18
- def patchify(
19
- self, latents: Tensor, frame_rates: Tensor, scale_grid: bool
20
- ) -> Tuple[Tensor, Tensor]:
21
- pass
22
-
23
- @abstractmethod
24
- def unpatchify(
25
- self,
26
- latents: Tensor,
27
- output_height: int,
28
- output_width: int,
29
- output_num_frames: int,
30
- out_channels: int,
31
- ) -> Tuple[Tensor, Tensor]:
32
- pass
33
-
34
- @property
35
- def patch_size(self):
36
- return self._patch_size
37
-
38
- def get_grid(
39
- self, orig_num_frames, orig_height, orig_width, batch_size, scale_grid, device
40
- ):
41
- f = orig_num_frames // self._patch_size[0]
42
- h = orig_height // self._patch_size[1]
43
- w = orig_width // self._patch_size[2]
44
- grid_h = torch.arange(h, dtype=torch.float32, device=device)
45
- grid_w = torch.arange(w, dtype=torch.float32, device=device)
46
- grid_f = torch.arange(f, dtype=torch.float32, device=device)
47
- grid = torch.meshgrid(grid_f, grid_h, grid_w)
48
- grid = torch.stack(grid, dim=0)
49
- grid = grid.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)
50
-
51
- if scale_grid is not None:
52
- for i in range(3):
53
- if isinstance(scale_grid[i], Tensor):
54
- scale = append_dims(scale_grid[i], grid.ndim - 1)
55
- else:
56
- scale = scale_grid[i]
57
- grid[:, i, ...] = grid[:, i, ...] * scale * self._patch_size[i]
58
-
59
- grid = rearrange(grid, "b c f h w -> b c (f h w)", b=batch_size)
60
- return grid
61
-
62
-
63
- class SymmetricPatchifier(Patchifier):
64
- def patchify(
65
- self,
66
- latents: Tensor,
67
- ) -> Tuple[Tensor, Tensor]:
68
- latents = rearrange(
69
- latents,
70
- "b c (f p1) (h p2) (w p3) -> b (f h w) (c p1 p2 p3)",
71
- p1=self._patch_size[0],
72
- p2=self._patch_size[1],
73
- p3=self._patch_size[2],
74
- )
75
- return latents
76
-
77
- def unpatchify(
78
- self,
79
- latents: Tensor,
80
- output_height: int,
81
- output_width: int,
82
- output_num_frames: int,
83
- out_channels: int,
84
- ) -> Tuple[Tensor, Tensor]:
85
- output_height = output_height // self._patch_size[1]
86
- output_width = output_width // self._patch_size[2]
87
- latents = rearrange(
88
- latents,
89
- "b (f h w) (c p q) -> b c f (h p) (w q) ",
90
- f=output_num_frames,
91
- h=output_height,
92
- w=output_width,
93
- p=self._patch_size[1],
94
- q=self._patch_size[2],
95
- )
96
- return latents
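`patchify` and `unpatchify` above are pure einops rearranges between a video latent grid and a flattened token sequence. A toy round-trip sketch follows, with a temporal patch size of 1 and a spatial patch size of 2; all shapes are made up for illustration.

import torch
from einops import rearrange

b, c, f, h, w = 1, 4, 3, 8, 8      # batch, channels, frames, height, width
p1, p2, p3 = 1, 2, 2               # temporal and spatial patch sizes

latents = torch.randn(b, c, f, h, w)

# Flatten (frames, height, width) patches into a token sequence
tokens = rearrange(
    latents,
    "b c (f p1) (h p2) (w p3) -> b (f h w) (c p1 p2 p3)",
    p1=p1, p2=p2, p3=p3,
)
print(tokens.shape)  # torch.Size([1, 48, 16]) -> 3*4*4 tokens, 4*1*2*2 channels each

# Invert the mapping back to the latent grid
restored = rearrange(
    tokens,
    "b (f h w) (c p q) -> b c f (h p) (w q)",
    f=f, h=h // p2, w=w // p3, p=p2, q=p3,
)
print(torch.allclose(latents, restored))  # True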
xora/models/transformers/transformer3d.py DELETED
@@ -1,491 +0,0 @@
1
- # Adapted from: https://github.com/huggingface/diffusers/blob/v0.26.3/src/diffusers/models/transformers/transformer_2d.py
2
- import math
3
- from dataclasses import dataclass
4
- from typing import Any, Dict, List, Optional, Literal
5
-
6
- import torch
7
- from diffusers.configuration_utils import ConfigMixin, register_to_config
8
- from diffusers.models.embeddings import PixArtAlphaTextProjection
9
- from diffusers.models.modeling_utils import ModelMixin
10
- from diffusers.models.normalization import AdaLayerNormSingle
11
- from diffusers.utils import BaseOutput, is_torch_version
12
- from diffusers.utils import logging
13
- from torch import nn
14
-
15
- from xora.models.transformers.attention import BasicTransformerBlock
16
- from xora.models.transformers.embeddings import get_3d_sincos_pos_embed
17
-
18
- logger = logging.get_logger(__name__)
19
-
20
-
21
- @dataclass
22
- class Transformer3DModelOutput(BaseOutput):
23
- """
24
- The output of [`Transformer2DModel`].
25
-
26
- Args:
27
- sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
28
- The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability
29
- distributions for the unnoised latent pixels.
30
- """
31
-
32
- sample: torch.FloatTensor
33
-
34
-
35
- class Transformer3DModel(ModelMixin, ConfigMixin):
36
- _supports_gradient_checkpointing = True
37
-
38
- @register_to_config
39
- def __init__(
40
- self,
41
- num_attention_heads: int = 16,
42
- attention_head_dim: int = 88,
43
- in_channels: Optional[int] = None,
44
- out_channels: Optional[int] = None,
45
- num_layers: int = 1,
46
- dropout: float = 0.0,
47
- norm_num_groups: int = 32,
48
- cross_attention_dim: Optional[int] = None,
49
- attention_bias: bool = False,
50
- num_vector_embeds: Optional[int] = None,
51
- activation_fn: str = "geglu",
52
- num_embeds_ada_norm: Optional[int] = None,
53
- use_linear_projection: bool = False,
54
- only_cross_attention: bool = False,
55
- double_self_attention: bool = False,
56
- upcast_attention: bool = False,
57
- adaptive_norm: str = "single_scale_shift", # 'single_scale_shift' or 'single_scale'
58
- standardization_norm: str = "layer_norm", # 'layer_norm' or 'rms_norm'
59
- norm_elementwise_affine: bool = True,
60
- norm_eps: float = 1e-5,
61
- attention_type: str = "default",
62
- caption_channels: int = None,
63
- project_to_2d_pos: bool = False,
64
- use_tpu_flash_attention: bool = False, # if True uses the TPU attention offload ('flash attention')
65
- qk_norm: Optional[str] = None,
66
- positional_embedding_type: str = "absolute",
67
- positional_embedding_theta: Optional[float] = None,
68
- positional_embedding_max_pos: Optional[List[int]] = None,
69
- timestep_scale_multiplier: Optional[float] = None,
70
- ):
71
- super().__init__()
72
- self.use_tpu_flash_attention = (
73
- use_tpu_flash_attention # FIXME: push config down to the attention modules
74
- )
75
- self.use_linear_projection = use_linear_projection
76
- self.num_attention_heads = num_attention_heads
77
- self.attention_head_dim = attention_head_dim
78
- inner_dim = num_attention_heads * attention_head_dim
79
- self.inner_dim = inner_dim
80
-
81
- self.project_to_2d_pos = project_to_2d_pos
82
-
83
- self.patchify_proj = nn.Linear(in_channels, inner_dim, bias=True)
84
-
85
- self.positional_embedding_type = positional_embedding_type
86
- self.positional_embedding_theta = positional_embedding_theta
87
- self.positional_embedding_max_pos = positional_embedding_max_pos
88
- self.use_rope = self.positional_embedding_type == "rope"
89
- self.timestep_scale_multiplier = timestep_scale_multiplier
90
-
91
- if self.positional_embedding_type == "absolute":
92
- embed_dim_3d = (
93
- math.ceil((inner_dim / 2) * 3) if project_to_2d_pos else inner_dim
94
- )
95
- if self.project_to_2d_pos:
96
- self.to_2d_proj = torch.nn.Linear(embed_dim_3d, inner_dim, bias=False)
97
- self._init_to_2d_proj_weights(self.to_2d_proj)
98
- elif self.positional_embedding_type == "rope":
99
- if positional_embedding_theta is None:
100
- raise ValueError(
101
- "If `positional_embedding_type` type is rope, `positional_embedding_theta` must also be defined"
102
- )
103
- if positional_embedding_max_pos is None:
104
- raise ValueError(
105
- "If `positional_embedding_type` type is rope, `positional_embedding_max_pos` must also be defined"
106
- )
107
-
108
- # 3. Define transformers blocks
109
- self.transformer_blocks = nn.ModuleList(
110
- [
111
- BasicTransformerBlock(
112
- inner_dim,
113
- num_attention_heads,
114
- attention_head_dim,
115
- dropout=dropout,
116
- cross_attention_dim=cross_attention_dim,
117
- activation_fn=activation_fn,
118
- num_embeds_ada_norm=num_embeds_ada_norm,
119
- attention_bias=attention_bias,
120
- only_cross_attention=only_cross_attention,
121
- double_self_attention=double_self_attention,
122
- upcast_attention=upcast_attention,
123
- adaptive_norm=adaptive_norm,
124
- standardization_norm=standardization_norm,
125
- norm_elementwise_affine=norm_elementwise_affine,
126
- norm_eps=norm_eps,
127
- attention_type=attention_type,
128
- use_tpu_flash_attention=use_tpu_flash_attention,
129
- qk_norm=qk_norm,
130
- use_rope=self.use_rope,
131
- )
132
- for d in range(num_layers)
133
- ]
134
- )
135
-
136
- # 4. Define output layers
137
- self.out_channels = in_channels if out_channels is None else out_channels
138
- self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
139
- self.scale_shift_table = nn.Parameter(
140
- torch.randn(2, inner_dim) / inner_dim**0.5
141
- )
142
- self.proj_out = nn.Linear(inner_dim, self.out_channels)
143
-
144
- self.adaln_single = AdaLayerNormSingle(
145
- inner_dim, use_additional_conditions=False
146
- )
147
- if adaptive_norm == "single_scale":
148
- self.adaln_single.linear = nn.Linear(inner_dim, 4 * inner_dim, bias=True)
149
-
150
- self.caption_projection = None
151
- if caption_channels is not None:
152
- self.caption_projection = PixArtAlphaTextProjection(
153
- in_features=caption_channels, hidden_size=inner_dim
154
- )
155
-
156
- self.gradient_checkpointing = False
157
-
158
- def set_use_tpu_flash_attention(self):
159
- r"""
160
- Sets the flag on this object and propagates it down to its children. The flag enforces the use of the TPU
161
- attention kernel.
162
- """
163
- logger.info("ENABLE TPU FLASH ATTENTION -> TRUE")
164
- self.use_tpu_flash_attention = True
165
- # push config down to the attention modules
166
- for block in self.transformer_blocks:
167
- block.set_use_tpu_flash_attention()
168
-
169
- def initialize(self, embedding_std: float, mode: Literal["xora", "legacy"]):
170
- def _basic_init(module):
171
- if isinstance(module, nn.Linear):
172
- torch.nn.init.xavier_uniform_(module.weight)
173
- if module.bias is not None:
174
- nn.init.constant_(module.bias, 0)
175
-
176
- self.apply(_basic_init)
177
-
178
- # Initialize timestep embedding MLP:
179
- nn.init.normal_(
180
- self.adaln_single.emb.timestep_embedder.linear_1.weight, std=embedding_std
181
- )
182
- nn.init.normal_(
183
- self.adaln_single.emb.timestep_embedder.linear_2.weight, std=embedding_std
184
- )
185
- nn.init.normal_(self.adaln_single.linear.weight, std=embedding_std)
186
-
187
- if hasattr(self.adaln_single.emb, "resolution_embedder"):
188
- nn.init.normal_(
189
- self.adaln_single.emb.resolution_embedder.linear_1.weight,
190
- std=embedding_std,
191
- )
192
- nn.init.normal_(
193
- self.adaln_single.emb.resolution_embedder.linear_2.weight,
194
- std=embedding_std,
195
- )
196
- if hasattr(self.adaln_single.emb, "aspect_ratio_embedder"):
197
- nn.init.normal_(
198
- self.adaln_single.emb.aspect_ratio_embedder.linear_1.weight,
199
- std=embedding_std,
200
- )
201
- nn.init.normal_(
202
- self.adaln_single.emb.aspect_ratio_embedder.linear_2.weight,
203
- std=embedding_std,
204
- )
205
-
206
- # Initialize caption embedding MLP:
207
- nn.init.normal_(self.caption_projection.linear_1.weight, std=embedding_std)
208
- nn.init.normal_(self.caption_projection.linear_2.weight, std=embedding_std)
209
-
210
- for block in self.transformer_blocks:
211
- if mode.lower() == "xora":
212
- nn.init.constant_(block.attn1.to_out[0].weight, 0)
213
- nn.init.constant_(block.attn1.to_out[0].bias, 0)
214
-
215
- nn.init.constant_(block.attn2.to_out[0].weight, 0)
216
- nn.init.constant_(block.attn2.to_out[0].bias, 0)
217
-
218
- if mode.lower() == "xora":
219
- nn.init.constant_(block.ff.net[2].weight, 0)
220
- nn.init.constant_(block.ff.net[2].bias, 0)
221
-
222
- # Zero-out output layers:
223
- nn.init.constant_(self.proj_out.weight, 0)
224
- nn.init.constant_(self.proj_out.bias, 0)
225
-
226
- def _set_gradient_checkpointing(self, module, value=False):
227
- if hasattr(module, "gradient_checkpointing"):
228
- module.gradient_checkpointing = value
229
-
230
- @staticmethod
231
- def _init_to_2d_proj_weights(linear_layer):
232
- input_features = linear_layer.weight.data.size(1)
233
- output_features = linear_layer.weight.data.size(0)
234
-
235
- # Start with a zero matrix
236
- identity_like = torch.zeros((output_features, input_features))
237
-
238
- # Fill the diagonal with 1's as much as possible
239
- min_features = min(output_features, input_features)
240
- identity_like[:min_features, :min_features] = torch.eye(min_features)
241
- linear_layer.weight.data = identity_like.to(linear_layer.weight.data.device)
242
-
243
- def get_fractional_positions(self, indices_grid):
244
- fractional_positions = torch.stack(
245
- [
246
- indices_grid[:, i] / self.positional_embedding_max_pos[i]
247
- for i in range(3)
248
- ],
249
- dim=-1,
250
- )
251
- return fractional_positions
252
-
253
- def precompute_freqs_cis(self, indices_grid, spacing="exp"):
254
- dtype = torch.float32 # We need full precision in the freqs_cis computation.
255
- dim = self.inner_dim
256
- theta = self.positional_embedding_theta
257
-
258
- fractional_positions = self.get_fractional_positions(indices_grid)
259
-
260
- start = 1
261
- end = theta
262
- device = fractional_positions.device
263
- if spacing == "exp":
264
- indices = theta ** (
265
- torch.linspace(
266
- math.log(start, theta),
267
- math.log(end, theta),
268
- dim // 6,
269
- device=device,
270
- dtype=dtype,
271
- )
272
- )
273
- indices = indices.to(dtype=dtype)
274
- elif spacing == "exp_2":
275
- indices = 1.0 / theta ** (torch.arange(0, dim, 6, device=device) / dim)
276
- indices = indices.to(dtype=dtype)
277
- elif spacing == "linear":
278
- indices = torch.linspace(start, end, dim // 6, device=device, dtype=dtype)
279
- elif spacing == "sqrt":
280
- indices = torch.linspace(
281
- start**2, end**2, dim // 6, device=device, dtype=dtype
282
- ).sqrt()
283
-
284
- indices = indices * math.pi / 2
285
-
286
- if spacing == "exp_2":
287
- freqs = (
288
- (indices * fractional_positions.unsqueeze(-1))
289
- .transpose(-1, -2)
290
- .flatten(2)
291
- )
292
- else:
293
- freqs = (
294
- (indices * (fractional_positions.unsqueeze(-1) * 2 - 1))
295
- .transpose(-1, -2)
296
- .flatten(2)
297
- )
298
-
299
- cos_freq = freqs.cos().repeat_interleave(2, dim=-1)
300
- sin_freq = freqs.sin().repeat_interleave(2, dim=-1)
301
- if dim % 6 != 0:
302
- cos_padding = torch.ones_like(cos_freq[:, :, : dim % 6])
303
- sin_padding = torch.zeros_like(cos_freq[:, :, : dim % 6])
304
- cos_freq = torch.cat([cos_padding, cos_freq], dim=-1)
305
- sin_freq = torch.cat([sin_padding, sin_freq], dim=-1)
306
- return cos_freq.to(self.dtype), sin_freq.to(self.dtype)
307
-
308
- def forward(
309
- self,
310
- hidden_states: torch.Tensor,
311
- indices_grid: torch.Tensor,
312
- encoder_hidden_states: Optional[torch.Tensor] = None,
313
- timestep: Optional[torch.LongTensor] = None,
314
- class_labels: Optional[torch.LongTensor] = None,
315
- cross_attention_kwargs: Dict[str, Any] = None,
316
- attention_mask: Optional[torch.Tensor] = None,
317
- encoder_attention_mask: Optional[torch.Tensor] = None,
318
- return_dict: bool = True,
319
- ):
320
- """
321
- The [`Transformer3DModel`] forward method.
322
-
323
- Args:
324
- hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
325
- Input `hidden_states`.
326
- indices_grid (`torch.LongTensor` of shape `(batch size, 3, num latent pixels)`):
327
- encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
328
- Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
329
- self-attention.
330
- timestep ( `torch.LongTensor`, *optional*):
331
- Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
332
- class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
333
- Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
334
- `AdaLayerZeroNorm`.
335
- cross_attention_kwargs ( `Dict[str, Any]`, *optional*):
336
- A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
337
- `self.processor` in
338
- [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
339
- attention_mask ( `torch.Tensor`, *optional*):
340
- An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
341
- is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
342
- negative values to the attention scores corresponding to "discard" tokens.
343
- encoder_attention_mask ( `torch.Tensor`, *optional*):
344
- Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:
345
-
346
- * Mask `(batch, sequence_length)` True = keep, False = discard.
347
- * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.
348
-
349
- If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
350
- above. This bias will be added to the cross-attention scores.
351
- return_dict (`bool`, *optional*, defaults to `True`):
352
- Whether or not to return a [`Transformer3DModelOutput`] instead of a plain
353
- tuple.
354
-
355
- Returns:
356
- If `return_dict` is True, a [`Transformer3DModelOutput`] is returned, otherwise a
357
- `tuple` where the first element is the sample tensor.
358
- """
359
- # for tpu attention offload 2d token masks are used. No need to transform.
360
- if not self.use_tpu_flash_attention:
361
- # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
362
- # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
363
- # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
364
- # expects mask of shape:
365
- # [batch, key_tokens]
366
- # adds singleton query_tokens dimension:
367
- # [batch, 1, key_tokens]
368
- # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
369
- # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
370
- # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
371
- if attention_mask is not None and attention_mask.ndim == 2:
372
- # assume that mask is expressed as:
373
- # (1 = keep, 0 = discard)
374
- # convert mask into a bias that can be added to attention scores:
375
- # (keep = +0, discard = -10000.0)
376
- attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
377
- attention_mask = attention_mask.unsqueeze(1)
378
-
379
- # convert encoder_attention_mask to a bias the same way we do for attention_mask
380
- if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
381
- encoder_attention_mask = (
382
- 1 - encoder_attention_mask.to(hidden_states.dtype)
383
- ) * -10000.0
384
- encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
385
-
386
- # 1. Input
387
- hidden_states = self.patchify_proj(hidden_states)
388
-
389
- if self.timestep_scale_multiplier:
390
- timestep = self.timestep_scale_multiplier * timestep
391
-
392
- if self.positional_embedding_type == "absolute":
393
- pos_embed_3d = self.get_absolute_pos_embed(indices_grid).to(
394
- hidden_states.device
395
- )
396
- if self.project_to_2d_pos:
397
- pos_embed = self.to_2d_proj(pos_embed_3d)
398
- hidden_states = (hidden_states + pos_embed).to(hidden_states.dtype)
399
- freqs_cis = None
400
- elif self.positional_embedding_type == "rope":
401
- freqs_cis = self.precompute_freqs_cis(indices_grid)
402
-
403
- batch_size = hidden_states.shape[0]
404
- timestep, embedded_timestep = self.adaln_single(
405
- timestep.flatten(),
406
- {"resolution": None, "aspect_ratio": None},
407
- batch_size=batch_size,
408
- hidden_dtype=hidden_states.dtype,
409
- )
410
- # Second dimension is 1 or number of tokens (if timestep_per_token)
411
- timestep = timestep.view(batch_size, -1, timestep.shape[-1])
412
- embedded_timestep = embedded_timestep.view(
413
- batch_size, -1, embedded_timestep.shape[-1]
414
- )
415
-
416
- # 2. Blocks
417
- if self.caption_projection is not None:
418
- batch_size = hidden_states.shape[0]
419
- encoder_hidden_states = self.caption_projection(encoder_hidden_states)
420
- encoder_hidden_states = encoder_hidden_states.view(
421
- batch_size, -1, hidden_states.shape[-1]
422
- )
423
-
424
- for block in self.transformer_blocks:
425
- if self.training and self.gradient_checkpointing:
426
-
427
- def create_custom_forward(module, return_dict=None):
428
- def custom_forward(*inputs):
429
- if return_dict is not None:
430
- return module(*inputs, return_dict=return_dict)
431
- else:
432
- return module(*inputs)
433
-
434
- return custom_forward
435
-
436
- ckpt_kwargs: Dict[str, Any] = (
437
- {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
438
- )
439
- hidden_states = torch.utils.checkpoint.checkpoint(
440
- create_custom_forward(block),
441
- hidden_states,
442
- freqs_cis,
443
- attention_mask,
444
- encoder_hidden_states,
445
- encoder_attention_mask,
446
- timestep,
447
- cross_attention_kwargs,
448
- class_labels,
449
- **ckpt_kwargs,
450
- )
451
- else:
452
- hidden_states = block(
453
- hidden_states,
454
- freqs_cis=freqs_cis,
455
- attention_mask=attention_mask,
456
- encoder_hidden_states=encoder_hidden_states,
457
- encoder_attention_mask=encoder_attention_mask,
458
- timestep=timestep,
459
- cross_attention_kwargs=cross_attention_kwargs,
460
- class_labels=class_labels,
461
- )
462
-
463
- # 3. Output
464
- scale_shift_values = (
465
- self.scale_shift_table[None, None] + embedded_timestep[:, :, None]
466
- )
467
- shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1]
468
- hidden_states = self.norm_out(hidden_states)
469
- # Modulation
470
- hidden_states = hidden_states * (1 + scale) + shift
471
- hidden_states = self.proj_out(hidden_states)
472
- if not return_dict:
473
- return (hidden_states,)
474
-
475
- return Transformer3DModelOutput(sample=hidden_states)
476
-
477
- def get_absolute_pos_embed(self, grid):
478
- grid_np = grid[0].cpu().numpy()
479
- embed_dim_3d = (
480
- math.ceil((self.inner_dim / 2) * 3)
481
- if self.project_to_2d_pos
482
- else self.inner_dim
483
- )
484
- pos_embed = get_3d_sincos_pos_embed( # (f h w)
485
- embed_dim_3d,
486
- grid_np,
487
- h=int(max(grid_np[1]) + 1),
488
- w=int(max(grid_np[2]) + 1),
489
- f=int(max(grid_np[0] + 1)),
490
- )
491
- return torch.from_numpy(pos_embed).float().unsqueeze(0)
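The `precompute_freqs_cis` method above drives the deleted model's RoPE positional-embedding path. As a standalone illustration of its default "exp" spacing (dim // 6 frequencies per temporal/spatial axis, interleaved cos/sin, identity padding for the remainder), here is a minimal, self-contained sketch; the dimension, theta, and max-position values are arbitrary assumptions, not values taken from any released checkpoint.

import math
import torch

def rope_freqs_3d(indices_grid, dim, theta, max_pos):
    # indices_grid: (batch, 3, n_tokens) integer (frame, row, col) coordinates,
    # mirroring the "exp" spacing branch of precompute_freqs_cis above.
    fractional = torch.stack(
        [indices_grid[:, i] / max_pos[i] for i in range(3)], dim=-1
    )  # (batch, n_tokens, 3), each coordinate normalized to [0, 1]
    freqs_per_axis = dim // 6
    indices = theta ** torch.linspace(math.log(1, theta), math.log(theta, theta), freqs_per_axis)
    indices = indices * math.pi / 2
    # Map positions from [0, 1] to [-1, 1] before multiplying by the frequencies.
    freqs = (indices * (fractional.unsqueeze(-1) * 2 - 1)).transpose(-1, -2).flatten(2)
    cos = freqs.cos().repeat_interleave(2, dim=-1)
    sin = freqs.sin().repeat_interleave(2, dim=-1)
    if dim % 6 != 0:  # pad the leftover channels with an identity rotation (cos=1, sin=0)
        pad = dim % 6
        cos = torch.cat([torch.ones_like(cos[:, :, :pad]), cos], dim=-1)
        sin = torch.cat([torch.zeros_like(sin[:, :, :pad]), sin], dim=-1)
    return cos, sin

# 2 latent frames of a 4x4 latent grid -> 32 tokens with (frame, row, col) coordinates.
grid = torch.stack(torch.meshgrid(
    torch.arange(2), torch.arange(4), torch.arange(4), indexing="ij"
)).reshape(3, -1).unsqueeze(0)
cos, sin = rope_freqs_3d(grid, dim=128, theta=10000.0, max_pos=[20, 2048, 2048])
print(cos.shape, sin.shape)  # torch.Size([1, 32, 128]) for both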
xora/pipelines/__init__.py DELETED
File without changes
xora/pipelines/pipeline_xora_video.py DELETED
@@ -1,1162 +0,0 @@
1
- # Adapted from: https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py
2
- import html
3
- import inspect
4
- import math
5
- import re
6
- import urllib.parse as ul
7
- from typing import Callable, Dict, List, Optional, Tuple, Union
8
-
9
-
10
- import torch
11
- import torch.nn.functional as F
12
- from contextlib import nullcontext
13
- from diffusers.image_processor import VaeImageProcessor
14
- from diffusers.models import AutoencoderKL
15
- from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
16
- from diffusers.schedulers import DPMSolverMultistepScheduler
17
- from diffusers.utils import (
18
- BACKENDS_MAPPING,
19
- deprecate,
20
- is_bs4_available,
21
- is_ftfy_available,
22
- logging,
23
- )
24
- from diffusers.utils.torch_utils import randn_tensor
25
- from einops import rearrange
26
- from transformers import T5EncoderModel, T5Tokenizer
27
-
28
- from xora.models.transformers.transformer3d import Transformer3DModel
29
- from xora.models.transformers.symmetric_patchifier import Patchifier
30
- from xora.models.autoencoders.vae_encode import (
31
- get_vae_size_scale_factor,
32
- vae_decode,
33
- vae_encode,
34
- )
35
- from xora.models.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder
36
- from xora.schedulers.rf import TimestepShifter
37
- from xora.utils.conditioning_method import ConditioningMethod
38
-
39
- logger = logging.get_logger(__name__) # pylint: disable=invalid-name
40
-
41
- if is_bs4_available():
42
- from bs4 import BeautifulSoup
43
-
44
- if is_ftfy_available():
45
- import ftfy
46
-
47
- ASPECT_RATIO_1024_BIN = {
48
- "0.25": [512.0, 2048.0],
49
- "0.28": [512.0, 1856.0],
50
- "0.32": [576.0, 1792.0],
51
- "0.33": [576.0, 1728.0],
52
- "0.35": [576.0, 1664.0],
53
- "0.4": [640.0, 1600.0],
54
- "0.42": [640.0, 1536.0],
55
- "0.48": [704.0, 1472.0],
56
- "0.5": [704.0, 1408.0],
57
- "0.52": [704.0, 1344.0],
58
- "0.57": [768.0, 1344.0],
59
- "0.6": [768.0, 1280.0],
60
- "0.68": [832.0, 1216.0],
61
- "0.72": [832.0, 1152.0],
62
- "0.78": [896.0, 1152.0],
63
- "0.82": [896.0, 1088.0],
64
- "0.88": [960.0, 1088.0],
65
- "0.94": [960.0, 1024.0],
66
- "1.0": [1024.0, 1024.0],
67
- "1.07": [1024.0, 960.0],
68
- "1.13": [1088.0, 960.0],
69
- "1.21": [1088.0, 896.0],
70
- "1.29": [1152.0, 896.0],
71
- "1.38": [1152.0, 832.0],
72
- "1.46": [1216.0, 832.0],
73
- "1.67": [1280.0, 768.0],
74
- "1.75": [1344.0, 768.0],
75
- "2.0": [1408.0, 704.0],
76
- "2.09": [1472.0, 704.0],
77
- "2.4": [1536.0, 640.0],
78
- "2.5": [1600.0, 640.0],
79
- "3.0": [1728.0, 576.0],
80
- "4.0": [2048.0, 512.0],
81
- }
82
-
83
- ASPECT_RATIO_512_BIN = {
84
- "0.25": [256.0, 1024.0],
85
- "0.28": [256.0, 928.0],
86
- "0.32": [288.0, 896.0],
87
- "0.33": [288.0, 864.0],
88
- "0.35": [288.0, 832.0],
89
- "0.4": [320.0, 800.0],
90
- "0.42": [320.0, 768.0],
91
- "0.48": [352.0, 736.0],
92
- "0.5": [352.0, 704.0],
93
- "0.52": [352.0, 672.0],
94
- "0.57": [384.0, 672.0],
95
- "0.6": [384.0, 640.0],
96
- "0.68": [416.0, 608.0],
97
- "0.72": [416.0, 576.0],
98
- "0.78": [448.0, 576.0],
99
- "0.82": [448.0, 544.0],
100
- "0.88": [480.0, 544.0],
101
- "0.94": [480.0, 512.0],
102
- "1.0": [512.0, 512.0],
103
- "1.07": [512.0, 480.0],
104
- "1.13": [544.0, 480.0],
105
- "1.21": [544.0, 448.0],
106
- "1.29": [576.0, 448.0],
107
- "1.38": [576.0, 416.0],
108
- "1.46": [608.0, 416.0],
109
- "1.67": [640.0, 384.0],
110
- "1.75": [672.0, 384.0],
111
- "2.0": [704.0, 352.0],
112
- "2.09": [736.0, 352.0],
113
- "2.4": [768.0, 320.0],
114
- "2.5": [800.0, 320.0],
115
- "3.0": [864.0, 288.0],
116
- "4.0": [1024.0, 256.0],
117
- }
118
-
119
-
120
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
121
- def retrieve_timesteps(
122
- scheduler,
123
- num_inference_steps: Optional[int] = None,
124
- device: Optional[Union[str, torch.device]] = None,
125
- timesteps: Optional[List[int]] = None,
126
- **kwargs,
127
- ):
128
- """
129
- Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
130
- custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
131
-
132
- Args:
133
- scheduler (`SchedulerMixin`):
134
- The scheduler to get timesteps from.
135
- num_inference_steps (`int`):
136
- The number of diffusion steps used when generating samples with a pre-trained model. If used,
137
- `timesteps` must be `None`.
138
- device (`str` or `torch.device`, *optional*):
139
- The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
140
- timesteps (`List[int]`, *optional*):
141
- Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
142
- timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
143
- must be `None`.
144
-
145
- Returns:
146
- `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
147
- second element is the number of inference steps.
148
- """
149
- if timesteps is not None:
150
- accepts_timesteps = "timesteps" in set(
151
- inspect.signature(scheduler.set_timesteps).parameters.keys()
152
- )
153
- if not accepts_timesteps:
154
- raise ValueError(
155
- f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
156
- f" timestep schedules. Please check whether you are using the correct scheduler."
157
- )
158
- scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
159
- timesteps = scheduler.timesteps
160
- num_inference_steps = len(timesteps)
161
- else:
162
- scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
163
- timesteps = scheduler.timesteps
164
- return timesteps, num_inference_steps
165
-
166
-
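To make the branching in `retrieve_timesteps` above concrete, here is a small, hedged usage sketch. It reuses the `DPMSolverMultistepScheduler` import from this module purely for illustration, and the custom-schedule branch assumes a diffusers version whose `set_timesteps` accepts an explicit `timesteps` list; the step counts and values are arbitrary.

scheduler = DPMSolverMultistepScheduler()

# Default branch: the scheduler derives its own schedule from a step count.
timesteps, n_steps = retrieve_timesteps(scheduler, num_inference_steps=20, device="cpu")
print(n_steps, timesteps[:3])

# Custom branch: pass an explicit descending schedule instead of a step count.
# This only works if the installed scheduler's set_timesteps accepts `timesteps`;
# otherwise retrieve_timesteps raises a ValueError, as implemented above.
try:
    timesteps, n_steps = retrieve_timesteps(
        scheduler, timesteps=[999, 749, 499, 249, 0], device="cpu"
    )
    print(n_steps, timesteps)
except ValueError as err:
    print(f"custom schedules unsupported by this scheduler version: {err}")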
167
- class XoraVideoPipeline(DiffusionPipeline):
168
- r"""
169
- Pipeline for text-to-video generation using Xora.
170
-
171
- This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
172
- library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
173
-
174
- Args:
175
- vae ([`AutoencoderKL`]):
176
- Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
177
- text_encoder ([`T5EncoderModel`]):
178
- Frozen text-encoder. This uses
179
- [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the
180
- [t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant.
181
- tokenizer (`T5Tokenizer`):
182
- Tokenizer of class
183
- [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
184
- transformer ([`Transformer3DModel`]):
185
- A text-conditioned `Transformer3DModel` to denoise the encoded video latents.
186
- scheduler ([`SchedulerMixin`]):
187
- A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
188
- """
189
-
190
- bad_punct_regex = re.compile(
191
- r"["
192
- + "#®•©™&@·º½¾¿¡§~"
193
- + r"\)"
194
- + r"\("
195
- + r"\]"
196
- + r"\["
197
- + r"\}"
198
- + r"\{"
199
- + r"\|"
200
- + "\\"
201
- + r"\/"
202
- + r"\*"
203
- + r"]{1,}"
204
- ) # noqa
205
-
206
- _optional_components = ["tokenizer", "text_encoder"]
207
- model_cpu_offload_seq = "text_encoder->transformer->vae"
208
-
209
- def __init__(
210
- self,
211
- tokenizer: T5Tokenizer,
212
- text_encoder: T5EncoderModel,
213
- vae: AutoencoderKL,
214
- transformer: Transformer3DModel,
215
- scheduler: DPMSolverMultistepScheduler,
216
- patchifier: Patchifier,
217
- ):
218
- super().__init__()
219
-
220
- self.register_modules(
221
- tokenizer=tokenizer,
222
- text_encoder=text_encoder,
223
- vae=vae,
224
- transformer=transformer,
225
- scheduler=scheduler,
226
- patchifier=patchifier,
227
- )
228
-
229
- self.video_scale_factor, self.vae_scale_factor, _ = get_vae_size_scale_factor(
230
- self.vae
231
- )
232
- self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
233
-
234
- def mask_text_embeddings(self, emb, mask):
235
- if emb.shape[0] == 1:
236
- keep_index = mask.sum().item()
237
- return emb[:, :, :keep_index, :], keep_index
238
- else:
239
- masked_feature = emb * mask[:, None, :, None]
240
- return masked_feature, emb.shape[2]
241
-
242
- # Adapted from diffusers.pipelines.deepfloyd_if.pipeline_if.encode_prompt
243
- def encode_prompt(
244
- self,
245
- prompt: Union[str, List[str]],
246
- do_classifier_free_guidance: bool = True,
247
- negative_prompt: str = "",
248
- num_images_per_prompt: int = 1,
249
- device: Optional[torch.device] = None,
250
- prompt_embeds: Optional[torch.FloatTensor] = None,
251
- negative_prompt_embeds: Optional[torch.FloatTensor] = None,
252
- prompt_attention_mask: Optional[torch.FloatTensor] = None,
253
- negative_prompt_attention_mask: Optional[torch.FloatTensor] = None,
254
- clean_caption: bool = False,
255
- **kwargs,
256
- ):
257
- r"""
258
- Encodes the prompt into text encoder hidden states.
259
-
260
- Args:
261
- prompt (`str` or `List[str]`, *optional*):
262
- prompt to be encoded
263
- negative_prompt (`str` or `List[str]`, *optional*):
264
- The prompt not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds`
265
- instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). For
266
- this pipeline, this should be "".
267
- do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
268
- whether to use classifier free guidance or not
269
- num_images_per_prompt (`int`, *optional*, defaults to 1):
270
- number of images that should be generated per prompt
271
- device: (`torch.device`, *optional*):
272
- torch device to place the resulting embeddings on
273
- prompt_embeds (`torch.FloatTensor`, *optional*):
274
- Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
275
- provided, text embeddings will be generated from `prompt` input argument.
276
- negative_prompt_embeds (`torch.FloatTensor`, *optional*):
277
- Pre-generated negative text embeddings.
278
- clean_caption (bool, defaults to `False`):
279
- If `True`, the function will preprocess and clean the provided caption before encoding.
280
- """
281
-
282
- if "mask_feature" in kwargs:
283
- deprecation_message = "The use of `mask_feature` is deprecated. It is no longer used in any computation, so removing it does not affect the end results. It will be removed in a future version."
284
- deprecate("mask_feature", "1.0.0", deprecation_message, standard_warn=False)
285
-
286
- if device is None:
287
- device = self._execution_device
288
-
289
- if prompt is not None and isinstance(prompt, str):
290
- batch_size = 1
291
- elif prompt is not None and isinstance(prompt, list):
292
- batch_size = len(prompt)
293
- else:
294
- batch_size = prompt_embeds.shape[0]
295
-
296
- # See Section 3.1. of the paper.
297
- # FIXME: should be configured in the config, not hardcoded. Fix in a separate PR with the rest of the config
298
- max_length = 128  # TPU supports only sequence lengths that are multiples of 128
299
-
300
- if prompt_embeds is None:
301
- prompt = self._text_preprocessing(prompt, clean_caption=clean_caption)
302
- text_inputs = self.tokenizer(
303
- prompt,
304
- padding="max_length",
305
- max_length=max_length,
306
- truncation=True,
307
- add_special_tokens=True,
308
- return_tensors="pt",
309
- )
310
- text_input_ids = text_inputs.input_ids
311
- untruncated_ids = self.tokenizer(
312
- prompt, padding="longest", return_tensors="pt"
313
- ).input_ids
314
-
315
- if untruncated_ids.shape[-1] >= text_input_ids.shape[
316
- -1
317
- ] and not torch.equal(text_input_ids, untruncated_ids):
318
- removed_text = self.tokenizer.batch_decode(
319
- untruncated_ids[:, max_length - 1 : -1]
320
- )
321
- logger.warning(
322
- "The following part of your input was truncated because the text encoder can only handle sequences up to"
323
- f" {max_length} tokens: {removed_text}"
324
- )
325
-
326
- prompt_attention_mask = text_inputs.attention_mask
327
- prompt_attention_mask = prompt_attention_mask.to(device)
328
-
329
- prompt_embeds = self.text_encoder(
330
- text_input_ids.to(device), attention_mask=prompt_attention_mask
331
- )
332
- prompt_embeds = prompt_embeds[0]
333
-
334
- if self.text_encoder is not None:
335
- dtype = self.text_encoder.dtype
336
- elif self.transformer is not None:
337
- dtype = self.transformer.dtype
338
- else:
339
- dtype = None
340
-
341
- prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
342
-
343
- bs_embed, seq_len, _ = prompt_embeds.shape
344
- # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
345
- prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
346
- prompt_embeds = prompt_embeds.view(
347
- bs_embed * num_images_per_prompt, seq_len, -1
348
- )
349
- prompt_attention_mask = prompt_attention_mask.repeat(1, num_images_per_prompt)
350
- prompt_attention_mask = prompt_attention_mask.view(
351
- bs_embed * num_images_per_prompt, -1
352
- )
353
-
354
- # get unconditional embeddings for classifier free guidance
355
- if do_classifier_free_guidance and negative_prompt_embeds is None:
356
- uncond_tokens = [negative_prompt] * batch_size
357
- uncond_tokens = self._text_preprocessing(
358
- uncond_tokens, clean_caption=clean_caption
359
- )
360
- max_length = prompt_embeds.shape[1]
361
- uncond_input = self.tokenizer(
362
- uncond_tokens,
363
- padding="max_length",
364
- max_length=max_length,
365
- truncation=True,
366
- return_attention_mask=True,
367
- add_special_tokens=True,
368
- return_tensors="pt",
369
- )
370
- negative_prompt_attention_mask = uncond_input.attention_mask
371
- negative_prompt_attention_mask = negative_prompt_attention_mask.to(device)
372
-
373
- negative_prompt_embeds = self.text_encoder(
374
- uncond_input.input_ids.to(device),
375
- attention_mask=negative_prompt_attention_mask,
376
- )
377
- negative_prompt_embeds = negative_prompt_embeds[0]
378
-
379
- if do_classifier_free_guidance:
380
- # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
381
- seq_len = negative_prompt_embeds.shape[1]
382
-
383
- negative_prompt_embeds = negative_prompt_embeds.to(
384
- dtype=dtype, device=device
385
- )
386
-
387
- negative_prompt_embeds = negative_prompt_embeds.repeat(
388
- 1, num_images_per_prompt, 1
389
- )
390
- negative_prompt_embeds = negative_prompt_embeds.view(
391
- batch_size * num_images_per_prompt, seq_len, -1
392
- )
393
-
394
- negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(
395
- 1, num_images_per_prompt
396
- )
397
- negative_prompt_attention_mask = negative_prompt_attention_mask.view(
398
- bs_embed * num_images_per_prompt, -1
399
- )
400
- else:
401
- negative_prompt_embeds = None
402
- negative_prompt_attention_mask = None
403
-
404
- return (
405
- prompt_embeds,
406
- prompt_attention_mask,
407
- negative_prompt_embeds,
408
- negative_prompt_attention_mask,
409
- )
410
-
411
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
412
- def prepare_extra_step_kwargs(self, generator, eta):
413
- # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
414
- # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
415
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
416
- # and should be between [0, 1]
417
-
418
- accepts_eta = "eta" in set(
419
- inspect.signature(self.scheduler.step).parameters.keys()
420
- )
421
- extra_step_kwargs = {}
422
- if accepts_eta:
423
- extra_step_kwargs["eta"] = eta
424
-
425
- # check if the scheduler accepts generator
426
- accepts_generator = "generator" in set(
427
- inspect.signature(self.scheduler.step).parameters.keys()
428
- )
429
- if accepts_generator:
430
- extra_step_kwargs["generator"] = generator
431
- return extra_step_kwargs
432
-
433
- def check_inputs(
434
- self,
435
- prompt,
436
- height,
437
- width,
438
- negative_prompt,
439
- prompt_embeds=None,
440
- negative_prompt_embeds=None,
441
- prompt_attention_mask=None,
442
- negative_prompt_attention_mask=None,
443
- ):
444
- if height % 8 != 0 or width % 8 != 0:
445
- raise ValueError(
446
- f"`height` and `width` have to be divisible by 8 but are {height} and {width}."
447
- )
448
-
449
- if prompt is not None and prompt_embeds is not None:
450
- raise ValueError(
451
- f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
452
- " only forward one of the two."
453
- )
454
- elif prompt is None and prompt_embeds is None:
455
- raise ValueError(
456
- "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
457
- )
458
- elif prompt is not None and (
459
- not isinstance(prompt, str) and not isinstance(prompt, list)
460
- ):
461
- raise ValueError(
462
- f"`prompt` has to be of type `str` or `list` but is {type(prompt)}"
463
- )
464
-
465
- if prompt is not None and negative_prompt_embeds is not None:
466
- raise ValueError(
467
- f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
468
- f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
469
- )
470
-
471
- if negative_prompt is not None and negative_prompt_embeds is not None:
472
- raise ValueError(
473
- f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
474
- f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
475
- )
476
-
477
- if prompt_embeds is not None and prompt_attention_mask is None:
478
- raise ValueError(
479
- "Must provide `prompt_attention_mask` when specifying `prompt_embeds`."
480
- )
481
-
482
- if (
483
- negative_prompt_embeds is not None
484
- and negative_prompt_attention_mask is None
485
- ):
486
- raise ValueError(
487
- "Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`."
488
- )
489
-
490
- if prompt_embeds is not None and negative_prompt_embeds is not None:
491
- if prompt_embeds.shape != negative_prompt_embeds.shape:
492
- raise ValueError(
493
- "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
494
- f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
495
- f" {negative_prompt_embeds.shape}."
496
- )
497
- if prompt_attention_mask.shape != negative_prompt_attention_mask.shape:
498
- raise ValueError(
499
- "`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but"
500
- f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`"
501
- f" {negative_prompt_attention_mask.shape}."
502
- )
503
-
504
- # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing
505
- def _text_preprocessing(self, text, clean_caption=False):
506
- if clean_caption and not is_bs4_available():
507
- logger.warn(
508
- BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`")
509
- )
510
- logger.warn("Setting `clean_caption` to False...")
511
- clean_caption = False
512
-
513
- if clean_caption and not is_ftfy_available():
514
- logger.warn(
515
- BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`")
516
- )
517
- logger.warn("Setting `clean_caption` to False...")
518
- clean_caption = False
519
-
520
- if not isinstance(text, (tuple, list)):
521
- text = [text]
522
-
523
- def process(text: str):
524
- if clean_caption:
525
- text = self._clean_caption(text)
526
- text = self._clean_caption(text)
527
- else:
528
- text = text.lower().strip()
529
- return text
530
-
531
- return [process(t) for t in text]
532
-
533
- # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption
534
- def _clean_caption(self, caption):
535
- caption = str(caption)
536
- caption = ul.unquote_plus(caption)
537
- caption = caption.strip().lower()
538
- caption = re.sub("<person>", "person", caption)
539
- # urls:
540
- caption = re.sub(
541
- r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
542
- "",
543
- caption,
544
- ) # regex for urls
545
- caption = re.sub(
546
- r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
547
- "",
548
- caption,
549
- ) # regex for urls
550
- # html:
551
- caption = BeautifulSoup(caption, features="html.parser").text
552
-
553
- # @<nickname>
554
- caption = re.sub(r"@[\w\d]+\b", "", caption)
555
-
556
- # 31C0—31EF CJK Strokes
557
- # 31F0—31FF Katakana Phonetic Extensions
558
- # 3200—32FF Enclosed CJK Letters and Months
559
- # 3300—33FF CJK Compatibility
560
- # 3400—4DBF CJK Unified Ideographs Extension A
561
- # 4DC0—4DFF Yijing Hexagram Symbols
562
- # 4E00—9FFF CJK Unified Ideographs
563
- caption = re.sub(r"[\u31c0-\u31ef]+", "", caption)
564
- caption = re.sub(r"[\u31f0-\u31ff]+", "", caption)
565
- caption = re.sub(r"[\u3200-\u32ff]+", "", caption)
566
- caption = re.sub(r"[\u3300-\u33ff]+", "", caption)
567
- caption = re.sub(r"[\u3400-\u4dbf]+", "", caption)
568
- caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption)
569
- caption = re.sub(r"[\u4e00-\u9fff]+", "", caption)
570
- #######################################################
571
-
572
- # all types of dash --> "-"
573
- caption = re.sub(
574
- r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa
575
- "-",
576
- caption,
577
- )
578
-
579
- # normalize all quotation marks to one standard
580
- caption = re.sub(r"[`´«»“”¨]", '"', caption)
581
- caption = re.sub(r"[‘’]", "'", caption)
582
-
583
- # &quot;
584
- caption = re.sub(r"&quot;?", "", caption)
585
- # &amp
586
- caption = re.sub(r"&amp", "", caption)
587
-
588
- # ip addresses:
589
- caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)
590
-
591
- # article ids:
592
- caption = re.sub(r"\d:\d\d\s+$", "", caption)
593
-
594
- # \n
595
- caption = re.sub(r"\\n", " ", caption)
596
-
597
- # "#123"
598
- caption = re.sub(r"#\d{1,3}\b", "", caption)
599
- # "#12345.."
600
- caption = re.sub(r"#\d{5,}\b", "", caption)
601
- # "123456.."
602
- caption = re.sub(r"\b\d{6,}\b", "", caption)
603
- # filenames:
604
- caption = re.sub(
605
- r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption
606
- )
607
-
608
- #
609
- caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT"""
610
- caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT"""
611
-
612
- caption = re.sub(
613
- self.bad_punct_regex, r" ", caption
614
- ) # ***AUSVERKAUFT***, #AUSVERKAUFT
615
- caption = re.sub(r"\s+\.\s+", r" ", caption) # " . "
616
-
617
- # this-is-my-cute-cat / this_is_my_cute_cat
618
- regex2 = re.compile(r"(?:\-|\_)")
619
- if len(re.findall(regex2, caption)) > 3:
620
- caption = re.sub(regex2, " ", caption)
621
-
622
- caption = ftfy.fix_text(caption)
623
- caption = html.unescape(html.unescape(caption))
624
-
625
- caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640
626
- caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc
627
- caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231
628
-
629
- caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption)
630
- caption = re.sub(r"(free\s)?download(\sfree)?", "", caption)
631
- caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption)
632
- caption = re.sub(
633
- r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption
634
- )
635
- caption = re.sub(r"\bpage\s+\d+\b", "", caption)
636
-
637
- caption = re.sub(
638
- r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption
639
- ) # j2d1a2a...
640
-
641
- caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption)
642
-
643
- caption = re.sub(r"\b\s+\:\s+", r": ", caption)
644
- caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption)
645
- caption = re.sub(r"\s+", " ", caption)
646
-
647
- caption = caption.strip()
648
-
649
- caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption)
650
- caption = re.sub(r"^[\'\_,\-\:;]", r"", caption)
651
- caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption)
652
- caption = re.sub(r"^\.\S+$", "", caption)
653
-
654
- return caption.strip()
655
-
656
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
657
- def prepare_latents(
658
- self,
659
- batch_size,
660
- num_latent_channels,
661
- num_patches,
662
- dtype,
663
- device,
664
- generator,
665
- latents=None,
666
- latents_mask=None,
667
- ):
668
- shape = (
669
- batch_size,
670
- num_patches // math.prod(self.patchifier.patch_size),
671
- num_latent_channels,
672
- )
673
-
674
- if isinstance(generator, list) and len(generator) != batch_size:
675
- raise ValueError(
676
- f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
677
- f" size of {batch_size}. Make sure the batch size matches the length of the generators."
678
- )
679
-
680
- if latents is None:
681
- latents = randn_tensor(
682
- shape, generator=generator, device=device, dtype=dtype
683
- )
684
- elif latents_mask is not None:
685
- noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
686
- latents = latents * latents_mask[..., None] + noise * (
687
- 1 - latents_mask[..., None]
688
- )
689
- else:
690
- latents = latents.to(device)
691
-
692
- # scale the initial noise by the standard deviation required by the scheduler
693
- latents = latents * self.scheduler.init_noise_sigma
694
- return latents
695
-
696
- @staticmethod
697
- def classify_height_width_bin(
698
- height: int, width: int, ratios: dict
699
- ) -> Tuple[int, int]:
700
- """Returns binned height and width."""
701
- ar = float(height / width)
702
- closest_ratio = min(ratios.keys(), key=lambda ratio: abs(float(ratio) - ar))
703
- default_hw = ratios[closest_ratio]
704
- return int(default_hw[0]), int(default_hw[1])
705
-
706
- @staticmethod
707
- def resize_and_crop_tensor(
708
- samples: torch.Tensor, new_width: int, new_height: int
709
- ) -> torch.Tensor:
710
- n_frames, orig_height, orig_width = samples.shape[-3:]
711
-
712
- # Check if resizing is needed
713
- if orig_height != new_height or orig_width != new_width:
714
- ratio = max(new_height / orig_height, new_width / orig_width)
715
- resized_width = int(orig_width * ratio)
716
- resized_height = int(orig_height * ratio)
717
-
718
- # Resize
719
- samples = rearrange(samples, "b c n h w -> (b n) c h w")
720
- samples = F.interpolate(
721
- samples,
722
- size=(resized_height, resized_width),
723
- mode="bilinear",
724
- align_corners=False,
725
- )
726
- samples = rearrange(samples, "(b n) c h w -> b c n h w", n=n_frames)
727
-
728
- # Center Crop
729
- start_x = (resized_width - new_width) // 2
730
- end_x = start_x + new_width
731
- start_y = (resized_height - new_height) // 2
732
- end_y = start_y + new_height
733
- samples = samples[..., start_y:end_y, start_x:end_x]
734
-
735
- return samples
736
-
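To make the resolution-binning helpers above concrete, here is a short, hedged sketch. It assumes the module-level `ASPECT_RATIO_512_BIN` table defined earlier in this file; the requested 768x1344 size and the 8-frame tensor are arbitrary examples.

import torch

# 768/1344 ≈ 0.571 -> closest bin key "0.57" -> a 384x672 generation resolution.
h, w = XoraVideoPipeline.classify_height_width_bin(768, 1344, ratios=ASPECT_RATIO_512_BIN)
print(h, w)  # 384 672

# After decoding, frames generated at the binned size are resized back to the request.
frames = torch.randn(1, 3, 8, h, w)  # (batch, channels, num_frames, height, width)
restored = XoraVideoPipeline.resize_and_crop_tensor(frames, new_width=1344, new_height=768)
print(restored.shape)  # torch.Size([1, 3, 8, 768, 1344])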
737
- @torch.no_grad()
738
- def __call__(
739
- self,
740
- height: int,
741
- width: int,
742
- num_frames: int,
743
- frame_rate: float,
744
- prompt: Union[str, List[str]] = None,
745
- negative_prompt: str = "",
746
- num_inference_steps: int = 20,
747
- timesteps: List[int] = None,
748
- guidance_scale: float = 4.5,
749
- num_images_per_prompt: Optional[int] = 1,
750
- eta: float = 0.0,
751
- generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
752
- latents: Optional[torch.FloatTensor] = None,
753
- prompt_embeds: Optional[torch.FloatTensor] = None,
754
- prompt_attention_mask: Optional[torch.FloatTensor] = None,
755
- negative_prompt_embeds: Optional[torch.FloatTensor] = None,
756
- negative_prompt_attention_mask: Optional[torch.FloatTensor] = None,
757
- output_type: Optional[str] = "pil",
758
- return_dict: bool = True,
759
- callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
760
- clean_caption: bool = True,
761
- media_items: Optional[torch.FloatTensor] = None,
762
- mixed_precision: bool = False,
763
- **kwargs,
764
- ) -> Union[ImagePipelineOutput, Tuple]:
765
- """
766
- Function invoked when calling the pipeline for generation.
767
-
768
- Args:
769
- prompt (`str` or `List[str]`, *optional*):
770
- The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
771
- instead.
772
- negative_prompt (`str` or `List[str]`, *optional*):
773
- The prompt or prompts not to guide the image generation. If not defined, one has to pass
774
- `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
775
- less than `1`).
776
- num_inference_steps (`int`, *optional*, defaults to 20):
777
- The number of denoising steps. More denoising steps usually lead to a higher quality image at the
778
- expense of slower inference.
779
- timesteps (`List[int]`, *optional*):
780
- Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps`
781
- timesteps are used. Must be in descending order.
782
- guidance_scale (`float`, *optional*, defaults to 4.5):
783
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
784
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
785
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
786
- 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
787
- usually at the expense of lower image quality.
788
- num_images_per_prompt (`int`, *optional*, defaults to 1):
789
- The number of images to generate per prompt.
790
- height (`int`):
791
- The height in pixels of the generated image.
792
- width (`int`):
793
- The width in pixels of the generated image.
794
- eta (`float`, *optional*, defaults to 0.0):
795
- Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
796
- [`schedulers.DDIMScheduler`], will be ignored for others.
797
- generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
798
- One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
799
- to make generation deterministic.
800
- latents (`torch.FloatTensor`, *optional*):
801
- Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
802
- generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
803
- tensor will be generated by sampling using the supplied random `generator`.
804
- prompt_embeds (`torch.FloatTensor`, *optional*):
805
- Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
806
- provided, text embeddings will be generated from `prompt` input argument.
807
- prompt_attention_mask (`torch.FloatTensor`, *optional*): Pre-generated attention mask for text embeddings.
808
- negative_prompt_embeds (`torch.FloatTensor`, *optional*):
809
- Pre-generated negative text embeddings. This negative prompt should be "". If not
810
- provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
811
- negative_prompt_attention_mask (`torch.FloatTensor`, *optional*):
812
- Pre-generated attention mask for negative text embeddings.
813
- output_type (`str`, *optional*, defaults to `"pil"`):
814
- The output format of the generated image. Choose between
815
- [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
816
- return_dict (`bool`, *optional*, defaults to `True`):
817
- Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple.
818
- callback_on_step_end (`Callable`, *optional*):
819
- A function that is called at the end of each denoising step during inference. The function is called
820
- with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
821
- callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
822
- `callback_on_step_end_tensor_inputs`.
823
- clean_caption (`bool`, *optional*, defaults to `True`):
824
- Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to
825
- be installed. If the dependencies are not installed, the embeddings will be created from the raw
826
- prompt.
827
- use_resolution_binning (`bool` defaults to `True`):
828
- If set to `True`, the requested height and width are first mapped to the closest resolutions using
829
- `ASPECT_RATIO_1024_BIN`. After the produced latents are decoded into images, they are resized back to
830
- the requested resolution. Useful for generating non-square images.
831
-
832
- Examples:
833
-
834
- Returns:
835
- [`~pipelines.ImagePipelineOutput`] or `tuple`:
836
- If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is
837
- returned where the first element is a list with the generated images
838
- """
839
- if "mask_feature" in kwargs:
840
- deprecation_message = "The use of `mask_feature` is deprecated. It is no longer used in any computation, so removing it does not affect the end results. It will be removed in a future version."
841
- deprecate("mask_feature", "1.0.0", deprecation_message, standard_warn=False)
842
-
843
- is_video = kwargs.get("is_video", False)
844
- self.check_inputs(
845
- prompt,
846
- height,
847
- width,
848
- negative_prompt,
849
- prompt_embeds,
850
- negative_prompt_embeds,
851
- prompt_attention_mask,
852
- negative_prompt_attention_mask,
853
- )
854
-
855
- # 2. Default height and width to transformer
856
- if prompt is not None and isinstance(prompt, str):
857
- batch_size = 1
858
- elif prompt is not None and isinstance(prompt, list):
859
- batch_size = len(prompt)
860
- else:
861
- batch_size = prompt_embeds.shape[0]
862
-
863
- device = self._execution_device
864
-
865
- # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
866
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
867
- # corresponds to doing no classifier free guidance.
868
- do_classifier_free_guidance = guidance_scale > 1.0
869
-
870
- # 3. Encode input prompt
871
- (
872
- prompt_embeds,
873
- prompt_attention_mask,
874
- negative_prompt_embeds,
875
- negative_prompt_attention_mask,
876
- ) = self.encode_prompt(
877
- prompt,
878
- do_classifier_free_guidance,
879
- negative_prompt=negative_prompt,
880
- num_images_per_prompt=num_images_per_prompt,
881
- device=device,
882
- prompt_embeds=prompt_embeds,
883
- negative_prompt_embeds=negative_prompt_embeds,
884
- prompt_attention_mask=prompt_attention_mask,
885
- negative_prompt_attention_mask=negative_prompt_attention_mask,
886
- clean_caption=clean_caption,
887
- )
888
- if do_classifier_free_guidance:
889
- prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
890
- prompt_attention_mask = torch.cat(
891
- [negative_prompt_attention_mask, prompt_attention_mask], dim=0
892
- )
893
-
894
- # 3b. Encode and prepare conditioning data
895
- self.video_scale_factor = self.video_scale_factor if is_video else 1
896
- conditioning_method = kwargs.get("conditioning_method", None)
897
- vae_per_channel_normalize = kwargs.get("vae_per_channel_normalize", False)
898
- init_latents, conditioning_mask = self.prepare_conditioning(
899
- media_items,
900
- num_frames,
901
- height,
902
- width,
903
- conditioning_method,
904
- vae_per_channel_normalize,
905
- )
906
-
907
- # 4. Prepare latents.
908
- latent_height = height // self.vae_scale_factor
909
- latent_width = width // self.vae_scale_factor
910
- latent_num_frames = num_frames // self.video_scale_factor
911
- if isinstance(self.vae, CausalVideoAutoencoder) and is_video:
912
- latent_num_frames += 1
913
- latent_frame_rate = frame_rate / self.video_scale_factor
914
- num_latent_patches = latent_height * latent_width * latent_num_frames
915
- latents = self.prepare_latents(
916
- batch_size=batch_size * num_images_per_prompt,
917
- num_latent_channels=self.transformer.config.in_channels,
918
- num_patches=num_latent_patches,
919
- dtype=prompt_embeds.dtype,
920
- device=device,
921
- generator=generator,
922
- latents=init_latents,
923
- latents_mask=conditioning_mask,
924
- )
925
- if conditioning_mask is not None and is_video:
926
- assert num_images_per_prompt == 1
927
- conditioning_mask = (
928
- torch.cat([conditioning_mask] * 2)
929
- if do_classifier_free_guidance
930
- else conditioning_mask
931
- )
932
-
933
- # 5. Prepare timesteps
934
- retrieve_timesteps_kwargs = {}
935
- if isinstance(self.scheduler, TimestepShifter):
936
- retrieve_timesteps_kwargs["samples"] = latents
937
- timesteps, num_inference_steps = retrieve_timesteps(
938
- self.scheduler,
939
- num_inference_steps,
940
- device,
941
- timesteps,
942
- **retrieve_timesteps_kwargs,
943
- )
944
-
945
- # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
946
- extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
947
-
948
- # 7. Denoising loop
949
- num_warmup_steps = max(
950
- len(timesteps) - num_inference_steps * self.scheduler.order, 0
951
- )
952
-
953
- with self.progress_bar(total=num_inference_steps) as progress_bar:
954
- for i, t in enumerate(timesteps):
955
- latent_model_input = (
956
- torch.cat([latents] * 2) if do_classifier_free_guidance else latents
957
- )
958
- latent_model_input = self.scheduler.scale_model_input(
959
- latent_model_input, t
960
- )
961
-
962
- latent_frame_rates = (
963
- torch.ones(
964
- latent_model_input.shape[0], 1, device=latent_model_input.device
965
- )
966
- * latent_frame_rate
967
- )
968
-
969
- current_timestep = t
970
- if not torch.is_tensor(current_timestep):
971
- # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
972
- # This would be a good case for the `match` statement (Python 3.10+)
973
- is_mps = latent_model_input.device.type == "mps"
974
- if isinstance(current_timestep, float):
975
- dtype = torch.float32 if is_mps else torch.float64
976
- else:
977
- dtype = torch.int32 if is_mps else torch.int64
978
- current_timestep = torch.tensor(
979
- [current_timestep],
980
- dtype=dtype,
981
- device=latent_model_input.device,
982
- )
983
- elif len(current_timestep.shape) == 0:
984
- current_timestep = current_timestep[None].to(
985
- latent_model_input.device
986
- )
987
- # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
988
- current_timestep = current_timestep.expand(
989
- latent_model_input.shape[0]
990
- ).unsqueeze(-1)
991
- scale_grid = (
992
- (
993
- 1 / latent_frame_rates,
994
- self.vae_scale_factor,
995
- self.vae_scale_factor,
996
- )
997
- if self.transformer.use_rope
998
- else None
999
- )
1000
- indices_grid = self.patchifier.get_grid(
1001
- orig_num_frames=latent_num_frames,
1002
- orig_height=latent_height,
1003
- orig_width=latent_width,
1004
- batch_size=latent_model_input.shape[0],
1005
- scale_grid=scale_grid,
1006
- device=latents.device,
1007
- )
1008
-
1009
- if conditioning_mask is not None:
1010
- current_timestep = current_timestep * (1 - conditioning_mask)
1011
- # Choose the appropriate context manager based on `mixed_precision`
1012
- if mixed_precision:
1013
- if "xla" in device.type:
1014
- raise NotImplementedError(
1015
- "Mixed precision is not supported yet on XLA devices."
1016
- )
1017
-
1018
- context_manager = torch.autocast(device.type, dtype=torch.bfloat16)
1019
- else:
1020
- context_manager = nullcontext() # Dummy context manager
1021
-
1022
- # predict noise model_output
1023
- with context_manager:
1024
- noise_pred = self.transformer(
1025
- latent_model_input.to(self.transformer.dtype),
1026
- indices_grid,
1027
- encoder_hidden_states=prompt_embeds.to(self.transformer.dtype),
1028
- encoder_attention_mask=prompt_attention_mask,
1029
- timestep=current_timestep,
1030
- return_dict=False,
1031
- )[0]
1032
-
1033
- # perform guidance
1034
- if do_classifier_free_guidance:
1035
- noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
1036
- noise_pred = noise_pred_uncond + guidance_scale * (
1037
- noise_pred_text - noise_pred_uncond
1038
- )
1039
- current_timestep, _ = current_timestep.chunk(2)
1040
-
1041
- # learned sigma
1042
- if (
1043
- self.transformer.config.out_channels // 2
1044
- == self.transformer.config.in_channels
1045
- ):
1046
- noise_pred = noise_pred.chunk(2, dim=1)[0]
1047
-
1048
- # compute previous image: x_t -> x_t-1
1049
- latents = self.scheduler.step(
1050
- noise_pred,
1051
- t if current_timestep is None else current_timestep,
1052
- latents,
1053
- **extra_step_kwargs,
1054
- return_dict=False,
1055
- )[0]
1056
-
1057
- # call the callback, if provided
1058
- if i == len(timesteps) - 1 or (
1059
- (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
1060
- ):
1061
- progress_bar.update()
1062
-
1063
- if callback_on_step_end is not None:
1064
- callback_on_step_end(self, i, t, {})
1065
-
1066
- latents = self.patchifier.unpatchify(
1067
- latents=latents,
1068
- output_height=latent_height,
1069
- output_width=latent_width,
1070
- output_num_frames=latent_num_frames,
1071
- out_channels=self.transformer.in_channels
1072
- // math.prod(self.patchifier.patch_size),
1073
- )
1074
- if output_type != "latent":
1075
- image = vae_decode(
1076
- latents,
1077
- self.vae,
1078
- is_video,
1079
- vae_per_channel_normalize=kwargs["vae_per_channel_normalize"],
1080
- )
1081
- image = self.image_processor.postprocess(image, output_type=output_type)
1082
-
1083
- else:
1084
- image = latents
1085
-
1086
- # Offload all models
1087
- self.maybe_free_model_hooks()
1088
-
1089
- if not return_dict:
1090
- return (image,)
1091
-
1092
- return ImagePipelineOutput(images=image)
1093
-
1094
- def prepare_conditioning(
1095
- self,
1096
- media_items: torch.Tensor,
1097
- num_frames: int,
1098
- height: int,
1099
- width: int,
1100
- method: ConditioningMethod = ConditioningMethod.UNCONDITIONAL,
1101
- vae_per_channel_normalize: bool = False,
1102
- ) -> Tuple[torch.Tensor, torch.Tensor]:
1103
- """
1104
- Prepare the conditioning data for the video generation. If an input media item is provided, encode it
1105
- and set the conditioning_mask to indicate which tokens to condition on. Input media item should have
1106
- the same height and width as the generated video.
1107
-
1108
- Args:
1109
- media_items (torch.Tensor): media items to condition on (images or videos)
1110
- num_frames (int): number of frames to generate
1111
- height (int): height of the generated video
1112
- width (int): width of the generated video
1113
- method (ConditioningMethod, optional): conditioning method to use. Defaults to ConditioningMethod.UNCONDITIONAL.
1114
- vae_per_channel_normalize (bool, optional): whether to normalize the input to the VAE per channel. Defaults to False.
1115
-
1116
- Returns:
1117
- Tuple[torch.Tensor, torch.Tensor]: the conditioning latents and the conditioning mask
1118
- """
1119
- if media_items is None or method == ConditioningMethod.UNCONDITIONAL:
1120
- return None, None
1121
-
1122
- assert media_items.ndim == 5
1123
- assert height == media_items.shape[-2] and width == media_items.shape[-1]
1124
-
1125
- # Encode the input video and repeat to the required number of frame-tokens
1126
- init_latents = vae_encode(
1127
- media_items.to(dtype=self.vae.dtype, device=self.vae.device),
1128
- self.vae,
1129
- vae_per_channel_normalize=vae_per_channel_normalize,
1130
- ).float()
1131
-
1132
- init_len, target_len = (
1133
- init_latents.shape[2],
1134
- num_frames // self.video_scale_factor,
1135
- )
1136
- if isinstance(self.vae, CausalVideoAutoencoder):
1137
- target_len += 1
1138
- init_latents = init_latents[:, :, :target_len]
1139
- if target_len > init_len:
1140
- repeat_factor = (target_len + init_len - 1) // init_len # Ceiling division
1141
- init_latents = init_latents.repeat(1, 1, repeat_factor, 1, 1)[
1142
- :, :, :target_len
1143
- ]
1144
-
1145
- # Prepare the conditioning mask (1.0 = condition on this token)
1146
- b, n, f, h, w = init_latents.shape
1147
- conditioning_mask = torch.zeros([b, 1, f, h, w], device=init_latents.device)
1148
- if method in [
1149
- ConditioningMethod.FIRST_FRAME,
1150
- ConditioningMethod.FIRST_AND_LAST_FRAME,
1151
- ]:
1152
- conditioning_mask[:, :, 0] = 1.0
1153
- if method in [
1154
- ConditioningMethod.LAST_FRAME,
1155
- ConditioningMethod.FIRST_AND_LAST_FRAME,
1156
- ]:
1157
- conditioning_mask[:, :, -1] = 1.0
1158
-
1159
- # Patchify the init latents and the mask
1160
- conditioning_mask = self.patchifier.patchify(conditioning_mask).squeeze(-1)
1161
- init_latents = self.patchifier.patchify(latents=init_latents)
1162
- return init_latents, conditioning_mask
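
The denoising loop deleted above relies on two small tensor manipulations that are easy to miss in the diff: classifier-free guidance splits the doubled batch and recombines the two halves, and the conditioning mask zeroes the per-token timestep of latents that should stay pinned to the input media. A minimal, self-contained sketch with toy shapes and a hypothetical guidance_scale (not taken from the pipeline itself):

import torch

batch, tokens, channels = 2, 8, 4
guidance_scale = 3.0

# Classifier-free guidance: the model ran on a doubled batch, so its output
# splits into an unconditional half and a text-conditioned half.
noise_pred = torch.randn(2 * batch, tokens, channels)
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
guided = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

# Conditioning: tokens whose mask is 1 get timestep 0, and the scheduler
# deleted further below (rf.py) zeroes dt wherever timestep == 0, so those
# tokens pass through each step untouched.
timestep = torch.full((batch, tokens), 0.7)
conditioning_mask = torch.zeros(batch, tokens)
conditioning_mask[:, 0] = 1.0  # toy stand-in for ConditioningMethod.FIRST_FRAME
effective_timestep = timestep * (1 - conditioning_mask)
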
xora/schedulers/__init__.py DELETED
File without changes
xora/schedulers/rf.py DELETED
@@ -1,261 +0,0 @@
1
- import math
2
- from abc import ABC, abstractmethod
3
- from dataclasses import dataclass
4
- from typing import Callable, Optional, Tuple, Union
5
-
6
- import torch
7
- from diffusers.configuration_utils import ConfigMixin, register_to_config
8
- from diffusers.schedulers.scheduling_utils import SchedulerMixin
9
- from diffusers.utils import BaseOutput
10
- from torch import Tensor
11
-
12
- from xora.utils.torch_utils import append_dims
13
-
14
-
15
- def simple_diffusion_resolution_dependent_timestep_shift(
16
- samples: Tensor,
17
- timesteps: Tensor,
18
- n: int = 32 * 32,
19
- ) -> Tensor:
20
- if len(samples.shape) == 3:
21
- _, m, _ = samples.shape
22
- elif len(samples.shape) in [4, 5]:
23
- m = math.prod(samples.shape[2:])
24
- else:
25
- raise ValueError(
26
- "Samples must have shape (b, t, c), (b, c, h, w) or (b, c, f, h, w)"
27
- )
28
- snr = (timesteps / (1 - timesteps)) ** 2
29
- shift_snr = torch.log(snr) + 2 * math.log(m / n)
30
- shifted_timesteps = torch.sigmoid(0.5 * shift_snr)
31
-
32
- return shifted_timesteps
33
-
34
-
35
- def time_shift(mu: float, sigma: float, t: Tensor):
36
- return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
37
-
38
-
39
- def get_normal_shift(
40
- n_tokens: int,
41
- min_tokens: int = 1024,
42
- max_tokens: int = 4096,
43
- min_shift: float = 0.95,
44
- max_shift: float = 2.05,
45
- ) -> float:
46
- m = (max_shift - min_shift) / (max_tokens - min_tokens)
47
- b = min_shift - m * min_tokens
48
- return m * n_tokens + b
49
-
50
-
51
- def sd3_resolution_dependent_timestep_shift(
52
- samples: Tensor, timesteps: Tensor
53
- ) -> Tensor:
54
- """
55
- Shifts the timestep schedule as a function of the generated resolution.
56
-
57
- In the SD3 paper, the authors empirically show how to shift the timesteps based on the resolution of the target images.
58
- For more details: https://arxiv.org/pdf/2403.03206
59
-
60
- In Flux they later propose a more dynamic resolution dependent timestep shift, see:
61
- https://github.com/black-forest-labs/flux/blob/87f6fff727a377ea1c378af692afb41ae84cbe04/src/flux/sampling.py#L66
62
-
63
-
64
- Args:
65
- samples (Tensor): A batch of samples with shape (batch_size, channels, height, width) or
66
- (batch_size, channels, frame, height, width).
67
- timesteps (Tensor): A batch of timesteps with shape (batch_size,).
68
-
69
- Returns:
70
- Tensor: The shifted timesteps.
71
- """
72
- if len(samples.shape) == 3:
73
- _, m, _ = samples.shape
74
- elif len(samples.shape) in [4, 5]:
75
- m = math.prod(samples.shape[2:])
76
- else:
77
- raise ValueError(
78
- "Samples must have shape (b, t, c), (b, c, h, w) or (b, c, f, h, w)"
79
- )
80
-
81
- shift = get_normal_shift(m)
82
- return time_shift(shift, 1, timesteps)
83
-
84
-
85
- class TimestepShifter(ABC):
86
- @abstractmethod
87
- def shift_timesteps(self, samples: Tensor, timesteps: Tensor) -> Tensor:
88
- pass
89
-
90
-
91
- @dataclass
92
- class RectifiedFlowSchedulerOutput(BaseOutput):
93
- """
94
- Output class for the scheduler's step function output.
95
-
96
- Args:
97
- prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
98
- Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
99
- denoising loop.
100
- pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
101
- The predicted denoised sample (x_{0}) based on the model output from the current timestep.
102
- `pred_original_sample` can be used to preview progress or for guidance.
103
- """
104
-
105
- prev_sample: torch.FloatTensor
106
- pred_original_sample: Optional[torch.FloatTensor] = None
107
-
108
-
109
- class RectifiedFlowScheduler(SchedulerMixin, ConfigMixin, TimestepShifter):
110
- order = 1
111
-
112
- @register_to_config
113
- def __init__(
114
- self,
115
- num_train_timesteps=1000,
116
- shifting: Optional[str] = None,
117
- base_resolution: int = 32**2,
118
- ):
119
- super().__init__()
120
- self.init_noise_sigma = 1.0
121
- self.num_inference_steps = None
122
- self.timesteps = self.sigmas = torch.linspace(
123
- 1, 1 / num_train_timesteps, num_train_timesteps
124
- )
125
- self.delta_timesteps = self.timesteps - torch.cat(
126
- [self.timesteps[1:], torch.zeros_like(self.timesteps[-1:])]
127
- )
128
- self.shifting = shifting
129
- self.base_resolution = base_resolution
130
-
131
- def shift_timesteps(self, samples: Tensor, timesteps: Tensor) -> Tensor:
132
- if self.shifting == "SD3":
133
- return sd3_resolution_dependent_timestep_shift(samples, timesteps)
134
- elif self.shifting == "SimpleDiffusion":
135
- return simple_diffusion_resolution_dependent_timestep_shift(
136
- samples, timesteps, self.base_resolution
137
- )
138
- return timesteps
139
-
140
- def set_timesteps(
141
- self,
142
- num_inference_steps: int,
143
- samples: Tensor,
144
- device: Union[str, torch.device] = None,
145
- ):
146
- """
147
- Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
148
-
149
- Args:
150
- num_inference_steps (`int`): The number of diffusion steps used when generating samples.
151
- samples (`Tensor`): A batch of samples with shape (b, t, c), (b, c, h, w) or (b, c, f, h, w), used to compute the resolution-dependent timestep shift.
152
- device (`Union[str, torch.device]`, *optional*): The device to which the timesteps tensor will be moved.
153
- """
154
- num_inference_steps = min(self.config.num_train_timesteps, num_inference_steps)
155
- timesteps = torch.linspace(1, 1 / num_inference_steps, num_inference_steps).to(
156
- device
157
- )
158
- self.timesteps = self.shift_timesteps(samples, timesteps)
159
- self.delta_timesteps = self.timesteps - torch.cat(
160
- [self.timesteps[1:], torch.zeros_like(self.timesteps[-1:])]
161
- )
162
- self.num_inference_steps = num_inference_steps
163
- self.sigmas = self.timesteps
164
-
165
- def scale_model_input(
166
- self, sample: torch.FloatTensor, timestep: Optional[int] = None
167
- ) -> torch.FloatTensor:
168
- # pylint: disable=unused-argument
169
- """
170
- Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
171
- current timestep.
172
-
173
- Args:
174
- sample (`torch.FloatTensor`): input sample
175
- timestep (`int`, optional): current timestep
176
-
177
- Returns:
178
- `torch.FloatTensor`: scaled input sample
179
- """
180
- return sample
181
-
182
- def step(
183
- self,
184
- model_output: torch.FloatTensor,
185
- timestep: torch.FloatTensor,
186
- sample: torch.FloatTensor,
187
- eta: float = 0.0,
188
- use_clipped_model_output: bool = False,
189
- generator=None,
190
- variance_noise: Optional[torch.FloatTensor] = None,
191
- return_dict: bool = True,
192
- ) -> Union[RectifiedFlowSchedulerOutput, Tuple]:
193
- # pylint: disable=unused-argument
194
- """
195
- Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
196
- process from the learned model outputs (most often the predicted noise).
197
-
198
- Args:
199
- model_output (`torch.FloatTensor`):
200
- The direct output from learned diffusion model.
201
- timestep (`float`):
202
- The current discrete timestep in the diffusion chain.
203
- sample (`torch.FloatTensor`):
204
- A current instance of a sample created by the diffusion process.
205
- eta (`float`):
206
- The weight of noise for added noise in diffusion step.
207
- use_clipped_model_output (`bool`, defaults to `False`):
208
- If `True`, computes "corrected" `model_output` from the clipped predicted original sample. Necessary
209
- because predicted original sample is clipped to [-1, 1] when `self.config.clip_sample` is `True`. If no
210
- clipping has happened, "corrected" `model_output` would coincide with the one provided as input and
211
- `use_clipped_model_output` has no effect.
212
- generator (`torch.Generator`, *optional*):
213
- A random number generator.
214
- variance_noise (`torch.FloatTensor`):
215
- Alternative to generating noise with `generator` by directly providing the noise for the variance
216
- itself. Useful for methods such as [`CycleDiffusion`].
217
- return_dict (`bool`, *optional*, defaults to `True`):
218
- Whether or not to return a [`RectifiedFlowSchedulerOutput`] or `tuple`.
219
-
220
- Returns:
221
- [`~schedulers.scheduling_utils.RectifiedFlowSchedulerOutput`] or `tuple`:
222
- If return_dict is `True`, [`~schedulers.rf_scheduler.RectifiedFlowSchedulerOutput`] is returned,
223
- otherwise a tuple is returned where the first element is the sample tensor.
224
- """
225
- if self.num_inference_steps is None:
226
- raise ValueError(
227
- "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
228
- )
229
-
230
- if timestep.ndim == 0:
231
- # Global timestep
232
- current_index = (self.timesteps - timestep).abs().argmin()
233
- dt = self.delta_timesteps.gather(0, current_index.unsqueeze(0))
234
- else:
235
- # Timestep per token
236
- assert timestep.ndim == 2
237
- current_index = (
238
- (self.timesteps[:, None, None] - timestep[None]).abs().argmin(dim=0)
239
- )
240
- dt = self.delta_timesteps[current_index]
241
- # Special treatment for zero timestep tokens - set dt to 0 so prev_sample = sample
242
- dt = torch.where(timestep == 0.0, torch.zeros_like(dt), dt)[..., None]
243
-
244
- prev_sample = sample - dt * model_output
245
-
246
- if not return_dict:
247
- return (prev_sample,)
248
-
249
- return RectifiedFlowSchedulerOutput(prev_sample=prev_sample)
250
-
251
- def add_noise(
252
- self,
253
- original_samples: torch.FloatTensor,
254
- noise: torch.FloatTensor,
255
- timesteps: torch.FloatTensor,
256
- ) -> torch.FloatTensor:
257
- sigmas = timesteps
258
- sigmas = append_dims(sigmas, original_samples.ndim)
259
- alphas = 1 - sigmas
260
- noisy_samples = alphas * original_samples + sigmas * noise
261
- return noisy_samples
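
The scheduler deleted above is a plain rectified-flow Euler integrator: add_noise() interpolates x_t = (1 - t) * x_0 + t * noise, and step() updates prev_sample = sample - dt * model_output. A small sanity-check sketch (toy tensors, assuming a hypothetically perfect velocity prediction noise - x_0) showing that one Euler step lands exactly on the interpolation at t - dt:

import torch

x0 = torch.randn(1, 4, 8, 8)        # clean latent
noise = torch.randn_like(x0)

t, dt = 0.8, 0.1                    # current timestep and step size
x_t = (1 - t) * x0 + t * noise      # what add_noise() produces (alphas = 1 - sigmas)

v = noise - x0                      # ideal velocity for this interpolation path
x_prev = x_t - dt * v               # the Euler update performed in step()

target = (1 - (t - dt)) * x0 + (t - dt) * noise
assert torch.allclose(x_prev, target, atol=1e-6)

With the transformer's prediction standing in for noise - x_0, repeating this update over self.timesteps from t near 1 down to 0 is exactly the denoising loop in the pipeline hunk above.
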
xora/utils/__init__.py DELETED
File without changes
xora/utils/conditioning_method.py DELETED
@@ -1,8 +0,0 @@
1
- from enum import Enum
2
-
3
-
4
- class ConditioningMethod(Enum):
5
- UNCONDITIONAL = "unconditional"
6
- FIRST_FRAME = "first_frame"
7
- LAST_FRAME = "last_frame"
8
- FIRST_AND_LAST_FRAME = "first_and_last_frame"
xora/utils/torch_utils.py DELETED
@@ -1,25 +0,0 @@
1
- import torch
2
- from torch import nn
3
-
4
-
5
- def append_dims(x: torch.Tensor, target_dims: int) -> torch.Tensor:
6
- """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
7
- dims_to_append = target_dims - x.ndim
8
- if dims_to_append < 0:
9
- raise ValueError(
10
- f"input has {x.ndim} dims but target_dims is {target_dims}, which is less"
11
- )
12
- elif dims_to_append == 0:
13
- return x
14
- return x[(...,) + (None,) * dims_to_append]
15
-
16
-
17
- class Identity(nn.Module):
18
- """A placeholder identity operator that is argument-insensitive."""
19
-
20
- def __init__(self, *args, **kwargs) -> None: # pylint: disable=unused-argument
21
- super().__init__()
22
-
23
- # pylint: disable=unused-argument
24
- def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
25
- return x
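
For reference, append_dims (deleted above) only adds trailing singleton dimensions so that a per-sample vector broadcasts against a full video tensor; this is how add_noise() in rf.py applies one sigma per batch element. A minimal usage sketch with toy shapes, re-declaring the helper locally so the snippet stands alone:

import torch

def append_dims(x: torch.Tensor, target_dims: int) -> torch.Tensor:
    # Same behaviour as the deleted helper: pad trailing singleton dims.
    return x[(...,) + (None,) * (target_dims - x.ndim)]

samples = torch.randn(2, 3, 8, 16, 16)      # (b, c, f, h, w) video latents
sigmas = torch.tensor([0.2, 0.9])           # one timestep per batch element
sigmas = append_dims(sigmas, samples.ndim)  # -> shape (2, 1, 1, 1, 1)
noisy = (1 - sigmas) * samples + sigmas * torch.randn_like(samples)
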