TheNetherWatcher's picture
Upload folder using huggingface_hub
d0ffe9c verified
raw
history blame
14.2 kB
import math
import pdb
import random
from dataclasses import dataclass
from typing import Callable, List, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn.functional as F
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.attention import FeedForward
from diffusers.models.attention_processor import Attention
from diffusers.models.modeling_utils import ModelMixin
from diffusers.utils import BaseOutput, logging
from diffusers.utils.import_utils import is_xformers_available
from einops import rearrange, repeat
from torch import nn
from animatediff.utils.util import zero_rank_print
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
def zero_module(module):
    """Set every parameter of ``module`` to zero in place and return it.

    The zeroing is performed outside of autograd tracking, so the
    parameters keep their ``requires_grad`` flag and no graph is recorded.
    """
    with torch.no_grad():
        for param in module.parameters():
            param.zero_()
    return module
@dataclass
class TemporalTransformer3DModelOutput(BaseOutput):
    """Output container holding a single ``sample`` tensor field."""

    # The only field of this output structure.
    sample: torch.FloatTensor
def get_motion_module(
    in_channels,
    motion_module_type: str,
    motion_module_kwargs: dict
):
    """Build the motion module selected by ``motion_module_type``.

    Args:
        in_channels: Forwarded to the module constructor as ``in_channels``.
        motion_module_type: ``"Vanilla"`` or ``"Conv"``.
        motion_module_kwargs: Extra keyword arguments forwarded verbatim to
            the selected module's constructor.

    Returns:
        The constructed motion module.

    Raises:
        ValueError: If ``motion_module_type`` is not a recognised name.
    """
    if motion_module_type == "Vanilla":
        return VanillaTemporalModule(in_channels=in_channels, **motion_module_kwargs)
    if motion_module_type == "Conv":
        # NOTE(review): ConvTemporalModule is not defined in the visible part
        # of this file -- confirm it exists before relying on this branch.
        return ConvTemporalModule(in_channels=in_channels, **motion_module_kwargs)
    raise ValueError(
        f"Unsupported motion_module_type {motion_module_type!r}; "
        "expected 'Vanilla' or 'Conv'."
    )
class VanillaTemporalModule(nn.Module):
    """Thin wrapper that owns a single ``TemporalTransformer3DModel``.

    The per-head width handed to the transformer is
    ``in_channels // num_attention_heads // temporal_attention_dim_div``.
    When ``zero_initialize`` is true, the transformer's ``proj_out`` layer is
    passed through ``zero_module`` so that all of its parameters start at zero.
    """

    def __init__(
        self,
        in_channels,
        num_attention_heads=8,
        num_transformer_block=2,
        attention_block_types=("Temporal_Self",),
        spatial_position_encoding=False,
        temporal_position_encoding=True,
        temporal_position_encoding_max_len=32,
        temporal_attention_dim_div=1,
        zero_initialize=True,
        causal_temporal_attention=False,
        causal_temporal_attention_mask_type="",
    ):
        super().__init__()

        head_dim = in_channels // num_attention_heads // temporal_attention_dim_div
        transformer = TemporalTransformer3DModel(
            in_channels=in_channels,
            num_attention_heads=num_attention_heads,
            attention_head_dim=head_dim,
            num_layers=num_transformer_block,
            attention_block_types=attention_block_types,
            temporal_position_encoding=temporal_position_encoding,
            temporal_position_encoding_max_len=temporal_position_encoding_max_len,
            spatial_position_encoding=spatial_position_encoding,
            causal_temporal_attention=causal_temporal_attention,
            causal_temporal_attention_mask_type=causal_temporal_attention_mask_type,
        )
        self.temporal_transformer = transformer

        if zero_initialize:
            # Replace the output projection with its zeroed version.
            self.temporal_transformer.proj_out = zero_module(
                self.temporal_transformer.proj_out
            )

    def forward(self, input_tensor, temb=None, encoder_hidden_states=None, attention_mask=None):
        """Run the temporal transformer on ``input_tensor``.

        ``temb`` is accepted for interface compatibility but is not used.
        """
        return self.temporal_transformer(
            input_tensor, encoder_hidden_states, attention_mask
        )
class TemporalTransformer3DModel(nn.Module):
    """Transformer applied along the frame axis of a 5-D feature map.

    ``forward`` group-normalises the input, folds the two trailing spatial
    axes into the batch axis (``"b c f h w -> (b h w) f c"``), projects to
    ``num_attention_heads * attention_head_dim`` features, runs
    ``num_layers`` ``TemporalTransformerBlock`` instances, projects back to
    ``in_channels`` and adds the untouched input as a residual.

    Args:
        in_channels: Channel count of the input and output.
        num_attention_heads: Number of attention heads per block.
        attention_head_dim: Width of each attention head.
        num_layers: Number of transformer blocks.
        attention_block_types: Names of the attention layers in each block.
        dropout, norm_num_groups, cross_attention_dim, activation_fn,
        attention_bias, upcast_attention: Forwarded to each block
            (``norm_num_groups`` also configures the input ``GroupNorm``).
        temporal_position_encoding: Forwarded to each block.
        temporal_position_encoding_max_len: Forwarded to each block.
        spatial_position_encoding: When true, a ``PositionalEncoding2D`` is
            computed over the spatial grid and handed to every block.
        causal_temporal_attention: Must not be ``None``. When truthy, the
            ``attention_mask`` passed to ``forward`` is replaced by the mask
            from ``get_causal_temporal_attention_mask``.
        causal_temporal_attention_mask_type: One of ``"causal"``,
            ``"2-seq"``, ``"0-prev"``, ``"0"``, ``"wo-self"``, ``"circle"``;
            must be non-empty when ``causal_temporal_attention`` is truthy.
    """

    def __init__(
        self,
        in_channels,
        num_attention_heads,
        attention_head_dim,
        num_layers,
        attention_block_types=("Temporal_Self", "Temporal_Self",),
        dropout=0.0,
        norm_num_groups=32,
        cross_attention_dim=768,
        activation_fn="geglu",
        attention_bias=False,
        upcast_attention=False,
        temporal_position_encoding=False,
        temporal_position_encoding_max_len=32,
        spatial_position_encoding=False,
        causal_temporal_attention=None,
        causal_temporal_attention_mask_type="",
    ):
        super().__init__()

        assert causal_temporal_attention is not None, \
            "causal_temporal_attention must be set explicitly"
        self.causal_temporal_attention = causal_temporal_attention

        assert (not causal_temporal_attention) or (causal_temporal_attention_mask_type != ""), \
            "causal_temporal_attention requires a causal_temporal_attention_mask_type"
        self.causal_temporal_attention_mask_type = causal_temporal_attention_mask_type
        # Lazily built and cached by get_causal_temporal_attention_mask().
        self.causal_temporal_attention_mask = None
        self.spatial_position_encoding = spatial_position_encoding

        inner_dim = num_attention_heads * attention_head_dim

        self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
        self.proj_in = nn.Linear(in_channels, inner_dim)

        if spatial_position_encoding:
            self.pos_encoder_2d = PositionalEncoding2D(inner_dim)

        self.transformer_blocks = nn.ModuleList(
            [
                TemporalTransformerBlock(
                    dim=inner_dim,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                    attention_block_types=attention_block_types,
                    dropout=dropout,
                    norm_num_groups=norm_num_groups,
                    cross_attention_dim=cross_attention_dim,
                    activation_fn=activation_fn,
                    attention_bias=attention_bias,
                    upcast_attention=upcast_attention,
                    temporal_position_encoding=temporal_position_encoding,
                    temporal_position_encoding_max_len=temporal_position_encoding_max_len,
                )
                for _ in range(num_layers)
            ]
        )
        self.proj_out = nn.Linear(inner_dim, in_channels)

    @staticmethod
    def _build_binary_mask(mask_type, sequence_length):
        """Return a ``(sequence_length, sequence_length)`` 0/1 mask for ``mask_type``.

        A ``1`` at ``[i, j]`` marks position ``j`` as visible to position ``i``.

        Raises:
            ValueError: If ``mask_type`` is not a recognised name.
        """
        if mask_type == "causal":
            # Lower-triangular: each position sees itself and earlier ones.
            return torch.tril(torch.ones(sequence_length, sequence_length))

        if mask_type == "2-seq":
            # Two diagonal blocks. The second slice deliberately keeps the
            # original ``-sequence_length // 2`` expression (floor of a
            # negative number), which differs from ``-(n // 2)`` for odd n.
            mask = torch.zeros(sequence_length, sequence_length)
            mask[:sequence_length // 2, :sequence_length // 2] = 1
            mask[-sequence_length // 2:, -sequence_length // 2:] = 1
            return mask

        if mask_type == "0-prev":
            # Every position sees index 0 and the index just before it.
            indices = torch.arange(sequence_length)
            indices_prev = (indices - 1).clamp(min=0)
            mask = torch.zeros(sequence_length, sequence_length)
            mask[:, 0] = 1.
            mask[indices, indices_prev] = 1.
            return mask

        if mask_type == "0":
            # Every position sees index 0 only.
            mask = torch.zeros(sequence_length, sequence_length)
            mask[:, 0] = 1
            return mask

        if mask_type == "wo-self":
            # Everything except the diagonal.
            indices = torch.arange(sequence_length)
            mask = torch.ones(sequence_length, sequence_length)
            mask[indices, indices] = 0
            return mask

        if mask_type == "circle":
            # Self and previous index; index 0 additionally sees the last one.
            indices = torch.arange(sequence_length)
            indices_prev = (indices - 1).clamp(min=0)
            mask = torch.eye(sequence_length)
            mask[indices, indices_prev] = 1
            mask[0, -1] = 1
            return mask

        raise ValueError(
            f"Unknown causal_temporal_attention_mask_type: {mask_type!r}"
        )

    def get_causal_temporal_attention_mask(self, hidden_states):
        """Return the additive attention mask matching ``hidden_states``.

        ``hidden_states`` must be 3-D; its first two sizes become the batch
        size and sequence length of the mask. The result has shape
        ``(batch_size, sequence_length, sequence_length)`` with ``0.0`` where
        attention is allowed and ``-inf`` where it is blocked. The mask is
        cached on the instance and rebuilt whenever the required shape or the
        device of ``hidden_states`` changes.

        Raises:
            ValueError: If the configured mask type is not recognised.
        """
        batch_size, sequence_length, dim = hidden_states.shape
        expected_shape = (batch_size, sequence_length, sequence_length)

        cached = self.causal_temporal_attention_mask
        needs_rebuild = (
            cached is None
            or cached.shape != expected_shape
            # The cache is a plain attribute, so it does not follow the module
            # across devices; a shape-only check would hand back a stale mask.
            or cached.device != hidden_states.device
        )

        if needs_rebuild:
            zero_rank_print(f"build attn mask of type {self.causal_temporal_attention_mask_type}")
            mask = self._build_binary_mask(
                self.causal_temporal_attention_mask_type, sequence_length
            )

            # for sanity check
            if dim == 320:
                zero_rank_print(mask)

            # generate attention mask from binary values
            mask = mask.masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
            mask = mask.unsqueeze(0)
            mask = mask.repeat(batch_size, 1, 1)

            self.causal_temporal_attention_mask = mask.to(hidden_states.device)

        return self.causal_temporal_attention_mask

    def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
        """Apply the temporal transformer and add the input as a residual.

        Args:
            hidden_states: 5-D tensor laid out as ``b c f h w``.
            encoder_hidden_states: Forwarded to every transformer block.
            attention_mask: Forwarded to every block unless causal temporal
                attention is enabled, in which case it is replaced.

        Returns:
            A tensor with the same ``b c f h w`` layout as the input.
        """
        assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
        residual = hidden_states

        height, width = hidden_states.shape[-2:]

        hidden_states = self.norm(hidden_states)
        hidden_states = rearrange(hidden_states, "b c f h w -> (b h w) f c")
        hidden_states = self.proj_in(hidden_states)

        pos_encoding = None
        if self.spatial_position_encoding:
            video_length = hidden_states.shape[1]
            # The 2-D encoder expects (batch, x, y, ch); restore the
            # (b h w) f c layout afterwards.
            hidden_states = rearrange(hidden_states, "(b h w) f c -> (b f) h w c", h=height, w=width)
            pos_encoding = self.pos_encoder_2d(hidden_states)
            pos_encoding = rearrange(pos_encoding, "(b f) h w c -> (b h w) f c", f=video_length)
            hidden_states = rearrange(hidden_states, "(b f) h w c -> (b h w) f c", f=video_length)

        if self.causal_temporal_attention:
            attention_mask = self.get_causal_temporal_attention_mask(hidden_states)

        # Transformer Blocks
        for block in self.transformer_blocks:
            hidden_states = block(
                hidden_states,
                pos_encoding=pos_encoding,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=attention_mask,
            )

        hidden_states = self.proj_out(hidden_states)
        hidden_states = rearrange(hidden_states, "(b h w) f c -> b c f h w", h=height, w=width)

        return hidden_states + residual
class TemporalTransformerBlock(nn.Module):
    """One transformer block: a stack of attention layers plus a feed-forward.

    For every name in ``attention_block_types`` a ``TemporalSelfAttention``
    and a ``LayerNorm`` are created. Each attention layer and the final
    ``FeedForward`` are applied pre-norm with a residual connection.

    Args:
        dim: Feature width of the block.
        num_attention_heads: Heads per attention layer.
        attention_head_dim: Width of each head.
        attention_block_types: Layer names; the part before the first ``_``
            is passed as ``attention_mode`` and a ``_Cross`` suffix enables
            ``cross_attention_dim`` for that layer.
        dropout: Dropout used by attention and feed-forward.
        norm_num_groups: Accepted for interface compatibility; not used here.
        cross_attention_dim: Used only for layers whose name ends in ``_Cross``.
        activation_fn: Activation name for the feed-forward.
        attention_bias: Forwarded as ``bias`` to each attention layer.
        upcast_attention: Forwarded to each attention layer.
        temporal_position_encoding: Forwarded to each attention layer.
        temporal_position_encoding_max_len: Forwarded to each attention layer.
    """

    def __init__(
        self,
        dim,
        num_attention_heads,
        attention_head_dim,
        attention_block_types=("Temporal_Self", "Temporal_Self",),
        dropout=0.0,
        norm_num_groups=32,
        cross_attention_dim=768,
        activation_fn="geglu",
        attention_bias=False,
        upcast_attention=False,
        temporal_position_encoding=False,
        temporal_position_encoding_max_len=32,
    ):
        super().__init__()

        attention_blocks = []
        norms = []

        for block_name in attention_block_types:
            attention_blocks.append(
                TemporalSelfAttention(
                    attention_mode=block_name.split("_")[0],
                    cross_attention_dim=cross_attention_dim if block_name.endswith("_Cross") else None,
                    query_dim=dim,
                    heads=num_attention_heads,
                    dim_head=attention_head_dim,
                    dropout=dropout,
                    bias=attention_bias,
                    upcast_attention=upcast_attention,
                    temporal_position_encoding=temporal_position_encoding,
                    temporal_position_encoding_max_len=temporal_position_encoding_max_len,
                )
            )
            norms.append(nn.LayerNorm(dim))

        self.attention_blocks = nn.ModuleList(attention_blocks)
        self.norms = nn.ModuleList(norms)

        self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
        self.ff_norm = nn.LayerNorm(dim)

    def forward(self, hidden_states, pos_encoding=None, encoder_hidden_states=None, attention_mask=None):
        """Apply all attention layers followed by the feed-forward.

        Args:
            hidden_states: Input features; not modified in place.
            pos_encoding: Optional tensor added to ``hidden_states`` before
                every attention layer.
            encoder_hidden_states: Forwarded to each attention layer.
            attention_mask: Forwarded to each attention layer.

        Returns:
            The transformed features.
        """
        for attention_block, norm in zip(self.attention_blocks, self.norms):
            if pos_encoding is not None:
                # Out-of-place add so the caller's tensor is never mutated
                # (an in-place ``+=`` would also fail on a leaf tensor that
                # requires grad). Casting back keeps the dtype the in-place
                # add used to preserve.
                hidden_states = (hidden_states + pos_encoding).to(hidden_states.dtype)

            norm_hidden_states = norm(hidden_states)
            hidden_states = attention_block(
                norm_hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=attention_mask,
            ) + hidden_states

        hidden_states = self.ff(self.ff_norm(hidden_states)) + hidden_states

        return hidden_states
def get_emb(sin_inp):
    """
    Build a base embedding for one dimension by interleaving sine and cosine.

    Each value ``v`` along the last axis of ``sin_inp`` becomes the adjacent
    pair ``(sin(v), cos(v))``, so the last axis doubles in length.
    """
    sines = torch.sin(sin_inp)
    cosines = torch.cos(sin_inp)
    paired = torch.stack([sines, cosines], dim=-1)
    return paired.flatten(-2, -1)
class PositionalEncoding2D(nn.Module):
    """Sinusoidal positional encoding over two axes of a 4-D tensor.

    The most recently computed encoding is cached and returned again while
    the input keeps the same shape, dtype and device.
    """

    def __init__(self, channels):
        """
        :param channels: The last dimension of the tensor you want to apply pos emb to.
        """
        super().__init__()
        self.org_channels = channels
        # Per-axis channel budget: half of the channels, rounded up to even.
        channels = int(np.ceil(channels / 4) * 2)
        self.channels = channels
        inv_freq = 1.0 / (10000 ** (torch.arange(0, channels, 2).float() / channels))
        self.register_buffer("inv_freq", inv_freq)
        # The cache depends on the last input's shape (batch size included),
        # so it must not be written into ``state_dict``: a persistent buffer
        # would appear there after the first forward pass.
        self.register_buffer("cached_penc", None, persistent=False)

    def forward(self, tensor):
        """
        :param tensor: A 4d tensor of size (batch_size, x, y, ch)
        :return: Positional Encoding Matrix of size (batch_size, x, y, ch)
        :raises RuntimeError: if ``tensor`` is not 4d
        """
        if len(tensor.shape) != 4:
            raise RuntimeError("The input tensor has to be 4d!")

        cached = self.cached_penc
        if (
            cached is not None
            and cached.shape == tensor.shape
            # A shape-only check would hand back an encoding with a stale
            # dtype or device when only those change between calls.
            and cached.dtype == tensor.dtype
            and cached.device == tensor.device
        ):
            return cached

        self.cached_penc = None
        batch_size, x, y, orig_ch = tensor.shape
        pos_x = torch.arange(x, device=tensor.device).type(self.inv_freq.type())
        pos_y = torch.arange(y, device=tensor.device).type(self.inv_freq.type())
        sin_inp_x = torch.einsum("i,j->ij", pos_x, self.inv_freq)
        sin_inp_y = torch.einsum("i,j->ij", pos_y, self.inv_freq)
        # emb_x gets an extra axis so it broadcasts across y; emb_y
        # broadcasts across x on assignment below.
        emb_x = get_emb(sin_inp_x).unsqueeze(1)
        emb_y = get_emb(sin_inp_y)
        emb = torch.zeros((x, y, self.channels * 2), device=tensor.device).type(
            tensor.type()
        )
        emb[:, :, : self.channels] = emb_x
        emb[:, :, self.channels : 2 * self.channels] = emb_y

        # Trim to the input's channel count and tile over the batch axis.
        self.cached_penc = emb[None, :, :, :orig_ch].repeat(batch_size, 1, 1, 1)
        return self.cached_penc
class PositionalEncoding(nn.Module):
    """Additive sinusoidal positional encoding along axis 1, then dropout.

    A table ``pe`` of shape ``(1, max_len, d_model)`` is precomputed and
    registered as a buffer: even feature indices hold sines, odd indices hold
    cosines.

    Args:
        d_model: Size of the last (feature) axis; may be even or odd.
        dropout: Dropout probability applied after adding the encoding.
        max_len: Longest supported length of axis 1.
    """

    def __init__(
        self,
        d_model,
        dropout=0.,
        max_len=32,
    ):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        angles = position * div_term
        pe = torch.zeros(1, max_len, d_model)
        pe[0, :, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer odd index than even index, so
        # the cosine columns are truncated to fit.
        pe[0, :, 1::2] = torch.cos(angles)[:, : d_model // 2]
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Add the encoding for the first ``x.size(1)`` positions to ``x``.

        Raises:
            RuntimeError: If ``x.size(1)`` exceeds the table length.
        """
        seq_len = x.size(1)
        max_len = self.pe.size(1)
        if seq_len > max_len:
            # Without this check the add below fails with an opaque broadcast
            # error (or silently broadcasts when the table has one row).
            raise RuntimeError(
                f"Sequence length {seq_len} exceeds the positional encoding "
                f"max_len {max_len}."
            )
        x = x + self.pe[:, :seq_len]
        return self.dropout(x)
class TemporalSelfAttention(Attention):
    """``Attention`` subclass used as the temporal self-attention layer.

    Optionally adds a ``PositionalEncoding`` to the input before attention.
    The attention processor is always invoked with
    ``encoder_hidden_states=None``.

    Args:
        attention_mode: Must be ``"Temporal"``.
        temporal_position_encoding: When true, a ``PositionalEncoding`` of
            width ``query_dim`` is created and applied in ``forward``.
        temporal_position_encoding_max_len: ``max_len`` of that encoding.
        *args, **kwargs: Forwarded to ``Attention.__init__``.
    """

    def __init__(
        self,
        attention_mode=None,
        temporal_position_encoding=False,
        temporal_position_encoding_max_len=32,
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        assert attention_mode == "Temporal", \
            f"Unsupported attention_mode {attention_mode!r}; expected 'Temporal'"

        if temporal_position_encoding:
            # NOTE(review): the positional fallback assumes ``query_dim`` is
            # the first positional parameter of ``Attention.__init__`` --
            # confirm against the installed diffusers version.
            query_dim = kwargs["query_dim"] if "query_dim" in kwargs else args[0]
            self.pos_encoder = PositionalEncoding(
                query_dim,
                max_len=temporal_position_encoding_max_len,
            )
        else:
            self.pos_encoder = None

    def set_use_memory_efficient_attention_xformers(
        self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
    ):
        """Intentionally a no-op: requests to enable xformers are ignored."""
        # disable motion module efficient xformers to avoid bad results, don't know why
        # TODO: fix this bug
        pass

    def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, **cross_attention_kwargs):
        """Apply the optional positional encoding, then run the processor.

        ``encoder_hidden_states`` is accepted for interface compatibility
        but ignored: the processor is always called with
        ``encoder_hidden_states=None``. For the standard processors,
        ``**cross_attention_kwargs`` is empty.
        """
        # add position encoding (skipped when the layer was built without one;
        # calling the ``None`` placeholder would raise ``TypeError``)
        if self.pos_encoder is not None:
            hidden_states = self.pos_encoder(hidden_states)

        # The `Attention` class can call different attention processors /
        # attention functions; all tensors are passed along to the selected
        # processor.
        return self.processor(
            self,
            hidden_states,
            encoder_hidden_states=None,
            attention_mask=attention_mask,
            **cross_attention_kwargs,
        )