Spaces:
Runtime error
Runtime error
Merge pull request #34 from LightricksResearch/add_atten_to_decoder
Browse files
xora/models/autoencoders/causal_video_autoencoder.py
CHANGED
|
@@ -9,10 +9,12 @@ import numpy as np
|
|
| 9 |
from einops import rearrange
|
| 10 |
from torch import nn
|
| 11 |
from diffusers.utils import logging
|
|
|
|
| 12 |
|
| 13 |
from xora.models.autoencoders.conv_nd_factory import make_conv_nd, make_linear_nd
|
| 14 |
from xora.models.autoencoders.pixel_norm import PixelNorm
|
| 15 |
from xora.models.autoencoders.vae import AutoencoderKLWrapper
|
|
|
|
| 16 |
|
| 17 |
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 18 |
|
|
@@ -212,6 +214,12 @@ class CausalVideoAutoencoder(AutoencoderKLWrapper):
|
|
| 212 |
last_layer = self.decoder.layers[-1]
|
| 213 |
return last_layer
|
| 214 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 215 |
|
| 216 |
class Encoder(nn.Module):
|
| 217 |
r"""
|
|
@@ -485,6 +493,16 @@ class Decoder(nn.Module):
|
|
| 485 |
norm_layer=norm_layer,
|
| 486 |
inject_noise=block_params.get("inject_noise", False),
|
| 487 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 488 |
elif block_name == "res_x_y":
|
| 489 |
output_channel = output_channel // block_params.get("multiplier", 2)
|
| 490 |
block = ResnetBlock3D(
|
|
@@ -562,6 +580,129 @@ class Decoder(nn.Module):
|
|
| 562 |
return sample
|
| 563 |
|
| 564 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 565 |
class UNetMidBlock3D(nn.Module):
|
| 566 |
"""
|
| 567 |
A 3D UNet mid-block [`UNetMidBlock3D`] with multiple residual blocks.
|
|
|
|
| 9 |
from einops import rearrange
|
| 10 |
from torch import nn
|
| 11 |
from diffusers.utils import logging
|
| 12 |
+
import torch.nn.functional as F
|
| 13 |
|
| 14 |
from xora.models.autoencoders.conv_nd_factory import make_conv_nd, make_linear_nd
|
| 15 |
from xora.models.autoencoders.pixel_norm import PixelNorm
|
| 16 |
from xora.models.autoencoders.vae import AutoencoderKLWrapper
|
| 17 |
+
from xora.models.transformers.attention import Attention
|
| 18 |
|
| 19 |
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 20 |
|
|
|
|
| 214 |
last_layer = self.decoder.layers[-1]
|
| 215 |
return last_layer
|
| 216 |
|
| 217 |
+
def set_use_tpu_flash_attention(self):
|
| 218 |
+
for block in self.decoder.up_blocks:
|
| 219 |
+
if isinstance(block, AttentionResBlocks):
|
| 220 |
+
for attention_block in block.attention_blocks:
|
| 221 |
+
attention_block.set_use_tpu_flash_attention()
|
| 222 |
+
|
| 223 |
|
| 224 |
class Encoder(nn.Module):
|
| 225 |
r"""
|
|
|
|
| 493 |
norm_layer=norm_layer,
|
| 494 |
inject_noise=block_params.get("inject_noise", False),
|
| 495 |
)
|
| 496 |
+
elif block_name == "attn_res_x":
|
| 497 |
+
block = AttentionResBlocks(
|
| 498 |
+
dims=dims,
|
| 499 |
+
in_channels=input_channel,
|
| 500 |
+
num_layers=block_params["num_layers"],
|
| 501 |
+
resnet_groups=norm_num_groups,
|
| 502 |
+
norm_layer=norm_layer,
|
| 503 |
+
attention_head_dim=block_params["attention_head_dim"],
|
| 504 |
+
inject_noise=block_params.get("inject_noise", False),
|
| 505 |
+
)
|
| 506 |
elif block_name == "res_x_y":
|
| 507 |
output_channel = output_channel // block_params.get("multiplier", 2)
|
| 508 |
block = ResnetBlock3D(
|
|
|
|
| 580 |
return sample
|
| 581 |
|
| 582 |
|
| 583 |
+
class AttentionResBlocks(nn.Module):
|
| 584 |
+
"""
|
| 585 |
+
A 3D convolution residual block followed by self attention residual block
|
| 586 |
+
|
| 587 |
+
Args:
|
| 588 |
+
dims (`int` or `Tuple[int, int]`): The number of dimensions to use in convolutions.
|
| 589 |
+
in_channels (`int`): The number of input channels.
|
| 590 |
+
dropout (`float`, *optional*, defaults to 0.0): The dropout rate.
|
| 591 |
+
num_layers (`int`, *optional*, defaults to 1): The number of residual blocks.
|
| 592 |
+
resnet_eps (`float`, *optional*, 1e-6 ): The epsilon value for the resnet blocks.
|
| 593 |
+
resnet_groups (`int`, *optional*, defaults to 32):
|
| 594 |
+
The number of groups to use in the group normalization layers of the resnet blocks.
|
| 595 |
+
norm_layer (`str`, *optional*, defaults to `group_norm`): The normalization layer to use.
|
| 596 |
+
attention_head_dim (`int`, *optional*, defaults to 64): The dimension of the attention heads.
|
| 597 |
+
inject_noise (`bool`, *optional*, defaults to `False`): Whether to inject noise or not between convolution layers.
|
| 598 |
+
|
| 599 |
+
Returns:
|
| 600 |
+
`torch.FloatTensor`: The output of the last residual block, which is a tensor of shape `(batch_size,
|
| 601 |
+
in_channels, height, width)`.
|
| 602 |
+
|
| 603 |
+
"""
|
| 604 |
+
|
| 605 |
+
def __init__(
|
| 606 |
+
self,
|
| 607 |
+
dims: Union[int, Tuple[int, int]],
|
| 608 |
+
in_channels: int,
|
| 609 |
+
dropout: float = 0.0,
|
| 610 |
+
num_layers: int = 1,
|
| 611 |
+
resnet_eps: float = 1e-6,
|
| 612 |
+
resnet_groups: int = 32,
|
| 613 |
+
norm_layer: str = "group_norm",
|
| 614 |
+
attention_head_dim: int = 64,
|
| 615 |
+
inject_noise: bool = False,
|
| 616 |
+
):
|
| 617 |
+
super().__init__()
|
| 618 |
+
|
| 619 |
+
if attention_head_dim > in_channels:
|
| 620 |
+
raise ValueError(
|
| 621 |
+
"attention_head_dim must be less than or equal to in_channels"
|
| 622 |
+
)
|
| 623 |
+
|
| 624 |
+
resnet_groups = (
|
| 625 |
+
resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
|
| 626 |
+
)
|
| 627 |
+
|
| 628 |
+
self.res_blocks = []
|
| 629 |
+
self.attention_blocks = []
|
| 630 |
+
for i in range(num_layers):
|
| 631 |
+
self.res_blocks.append(
|
| 632 |
+
ResnetBlock3D(
|
| 633 |
+
dims=dims,
|
| 634 |
+
in_channels=in_channels,
|
| 635 |
+
out_channels=in_channels,
|
| 636 |
+
eps=resnet_eps,
|
| 637 |
+
groups=resnet_groups,
|
| 638 |
+
dropout=dropout,
|
| 639 |
+
norm_layer=norm_layer,
|
| 640 |
+
inject_noise=inject_noise,
|
| 641 |
+
)
|
| 642 |
+
)
|
| 643 |
+
self.attention_blocks.append(
|
| 644 |
+
Attention(
|
| 645 |
+
query_dim=in_channels,
|
| 646 |
+
heads=in_channels // attention_head_dim,
|
| 647 |
+
dim_head=attention_head_dim,
|
| 648 |
+
bias=True,
|
| 649 |
+
out_bias=True,
|
| 650 |
+
qk_norm="rms_norm",
|
| 651 |
+
residual_connection=True,
|
| 652 |
+
)
|
| 653 |
+
)
|
| 654 |
+
|
| 655 |
+
self.res_blocks = nn.ModuleList(self.res_blocks)
|
| 656 |
+
self.attention_blocks = nn.ModuleList(self.attention_blocks)
|
| 657 |
+
|
| 658 |
+
def forward(
|
| 659 |
+
self, hidden_states: torch.FloatTensor, causal: bool = True
|
| 660 |
+
) -> torch.FloatTensor:
|
| 661 |
+
for resnet, attention in zip(self.res_blocks, self.attention_blocks):
|
| 662 |
+
hidden_states = resnet(hidden_states, causal=causal)
|
| 663 |
+
|
| 664 |
+
# Reshape the hidden states to be (batch_size, frames * height * width, channel)
|
| 665 |
+
batch_size, channel, frames, height, width = hidden_states.shape
|
| 666 |
+
hidden_states = hidden_states.view(
|
| 667 |
+
batch_size, channel, frames * height * width
|
| 668 |
+
).transpose(1, 2)
|
| 669 |
+
|
| 670 |
+
if attention.use_tpu_flash_attention:
|
| 671 |
+
# Pad the second dimension to be divisible by block_k_major (block in flash attention)
|
| 672 |
+
seq_len = hidden_states.shape[1]
|
| 673 |
+
block_k_major = 512
|
| 674 |
+
pad_len = (block_k_major - seq_len % block_k_major) % block_k_major
|
| 675 |
+
if pad_len > 0:
|
| 676 |
+
hidden_states = F.pad(
|
| 677 |
+
hidden_states, (0, 0, 0, pad_len), "constant", 0
|
| 678 |
+
)
|
| 679 |
+
|
| 680 |
+
# Create a mask with ones for the original sequence length and zeros for the padded indexes
|
| 681 |
+
mask = torch.ones(
|
| 682 |
+
(hidden_states.shape[0], seq_len),
|
| 683 |
+
device=hidden_states.device,
|
| 684 |
+
dtype=hidden_states.dtype,
|
| 685 |
+
)
|
| 686 |
+
if pad_len > 0:
|
| 687 |
+
mask = F.pad(mask, (0, pad_len), "constant", 0)
|
| 688 |
+
|
| 689 |
+
hidden_states = attention(
|
| 690 |
+
hidden_states,
|
| 691 |
+
attention_mask=None if not attention.use_tpu_flash_attention else mask,
|
| 692 |
+
)
|
| 693 |
+
|
| 694 |
+
if attention.use_tpu_flash_attention:
|
| 695 |
+
# Remove the padding
|
| 696 |
+
if pad_len > 0:
|
| 697 |
+
hidden_states = hidden_states[:, :-pad_len, :]
|
| 698 |
+
|
| 699 |
+
# Reshape the hidden states back to (batch_size, channel, frames, height, width, channel)
|
| 700 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(
|
| 701 |
+
batch_size, channel, frames, height, width
|
| 702 |
+
)
|
| 703 |
+
return hidden_states
|
| 704 |
+
|
| 705 |
+
|
| 706 |
class UNetMidBlock3D(nn.Module):
|
| 707 |
"""
|
| 708 |
A 3D UNet mid-block [`UNetMidBlock3D`] with multiple residual blocks.
|