from einops import rearrange
import torch
from modules.Utilities import util
import torch.nn as nn
from modules.Attention import Attention
from modules.Device import Device
from modules.cond import Activation
from modules.cond import cast
from modules.sample import sampling_util

if Device.xformers_enabled():
    # No module-level setup is required when xformers is available; the check is a no-op here.
    pass

ops = cast.disable_weight_init

_ATTN_PRECISION = "fp32"
class FeedForward(nn.Module):
    """#### FeedForward neural network module.

    #### Args:
        - `dim` (int): The input dimension.
        - `dim_out` (int, optional): The output dimension. Defaults to None.
        - `mult` (int, optional): The multiplier for the inner dimension. Defaults to 4.
        - `glu` (bool, optional): Whether to use Gated Linear Units. Defaults to False.
        - `dropout` (float, optional): The dropout rate. Defaults to 0.0.
        - `dtype` (torch.dtype, optional): The data type. Defaults to None.
        - `device` (torch.device, optional): The device. Defaults to None.
        - `operations` (object, optional): The operations module. Defaults to `ops`.
    """

    def __init__(
        self,
        dim: int,
        dim_out: int = None,
        mult: int = 4,
        glu: bool = False,
        dropout: float = 0.0,
        dtype: torch.dtype = None,
        device: torch.device = None,
        operations: object = ops,
    ):
        super().__init__()
        inner_dim = int(dim * mult)
        dim_out = util.default(dim_out, dim)
        project_in = (
            nn.Sequential(
                operations.Linear(dim, inner_dim, dtype=dtype, device=device), nn.GELU()
            )
            if not glu
            else Activation.GEGLU(dim, inner_dim)
        )
        self.net = nn.Sequential(
            project_in,
            nn.Dropout(dropout),
            operations.Linear(inner_dim, dim_out, dtype=dtype, device=device),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """#### Forward pass of the FeedForward network.

        #### Args:
            - `x` (torch.Tensor): The input tensor.

        #### Returns:
            - `torch.Tensor`: The output tensor.
        """
        return self.net(x)
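
# Usage sketch (illustrative only; the sizes below are assumptions, not fixed by this
# module): FeedForward keeps the token layout and maps (B, N, dim) -> (B, N, dim_out),
# with dim_out defaulting to dim.
#
#     ff = FeedForward(dim=320, glu=True)
#     y = ff(torch.randn(2, 64, 320))   # y.shape == (2, 64, 320)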
class BasicTransformerBlock(nn.Module):
    """#### Basic Transformer block.

    #### Args:
        - `dim` (int): The input dimension.
        - `n_heads` (int): The number of attention heads.
        - `d_head` (int): The dimension of each attention head.
        - `dropout` (float, optional): The dropout rate. Defaults to 0.0.
        - `context_dim` (int, optional): The context dimension. Defaults to None.
        - `gated_ff` (bool, optional): Whether to use Gated FeedForward. Defaults to True.
        - `checkpoint` (bool, optional): Whether to use checkpointing. Defaults to True.
        - `ff_in` (bool, optional): Whether to use FeedForward input. Defaults to False.
        - `inner_dim` (int, optional): The inner dimension. Defaults to None.
        - `disable_self_attn` (bool, optional): Whether to disable self-attention. Defaults to False.
        - `disable_temporal_crossattention` (bool, optional): Whether to disable temporal cross-attention. Defaults to False.
        - `switch_temporal_ca_to_sa` (bool, optional): Whether to switch temporal cross-attention to self-attention. Defaults to False.
        - `dtype` (torch.dtype, optional): The data type. Defaults to None.
        - `device` (torch.device, optional): The device. Defaults to None.
        - `operations` (object, optional): The operations module. Defaults to `ops`.
    """

    def __init__(
        self,
        dim: int,
        n_heads: int,
        d_head: int,
        dropout: float = 0.0,
        context_dim: int = None,
        gated_ff: bool = True,
        checkpoint: bool = True,
        ff_in: bool = False,
        inner_dim: int = None,
        disable_self_attn: bool = False,
        disable_temporal_crossattention: bool = False,
        switch_temporal_ca_to_sa: bool = False,
        dtype: torch.dtype = None,
        device: torch.device = None,
        operations: object = ops,
    ):
        super().__init__()
        self.ff_in = ff_in or inner_dim is not None
        if inner_dim is None:
            inner_dim = dim
        self.is_res = inner_dim == dim
        self.disable_self_attn = disable_self_attn
        self.attn1 = Attention.CrossAttention(
            query_dim=inner_dim,
            heads=n_heads,
            dim_head=d_head,
            dropout=dropout,
            context_dim=context_dim if self.disable_self_attn else None,
            dtype=dtype,
            device=device,
            operations=operations,
        )  # is a self-attention if not self.disable_self_attn
        self.ff = FeedForward(
            inner_dim,
            dim_out=dim,
            dropout=dropout,
            glu=gated_ff,
            dtype=dtype,
            device=device,
            operations=operations,
        )
        context_dim_attn2 = None
        if not switch_temporal_ca_to_sa:
            context_dim_attn2 = context_dim
        self.attn2 = Attention.CrossAttention(
            query_dim=inner_dim,
            context_dim=context_dim_attn2,
            heads=n_heads,
            dim_head=d_head,
            dropout=dropout,
            dtype=dtype,
            device=device,
            operations=operations,
        )  # is self-attention if context is None
        self.norm2 = operations.LayerNorm(inner_dim, dtype=dtype, device=device)
        self.norm1 = operations.LayerNorm(inner_dim, dtype=dtype, device=device)
        self.norm3 = operations.LayerNorm(inner_dim, dtype=dtype, device=device)
        self.checkpoint = checkpoint
        self.n_heads = n_heads
        self.d_head = d_head
        self.switch_temporal_ca_to_sa = switch_temporal_ca_to_sa
    def forward(
        self,
        x: torch.Tensor,
        context: torch.Tensor = None,
        transformer_options: dict = {},
    ) -> torch.Tensor:
        """#### Forward pass of the Basic Transformer block.

        #### Args:
            - `x` (torch.Tensor): The input tensor.
            - `context` (torch.Tensor, optional): The context tensor. Defaults to None.
            - `transformer_options` (dict, optional): Additional transformer options. Defaults to {}.

        #### Returns:
            - `torch.Tensor`: The output tensor.
        """
        return sampling_util.checkpoint(
            self._forward,
            (x, context, transformer_options),
            self.parameters(),
            self.checkpoint,
        )

    def _forward(
        self,
        x: torch.Tensor,
        context: torch.Tensor = None,
        transformer_options: dict = {},
    ) -> torch.Tensor:
        """#### Internal forward pass of the Basic Transformer block.

        #### Args:
            - `x` (torch.Tensor): The input tensor.
            - `context` (torch.Tensor, optional): The context tensor. Defaults to None.
            - `transformer_options` (dict, optional): Additional transformer options. Defaults to {}.

        #### Returns:
            - `torch.Tensor`: The output tensor.
        """
        extra_options = {}
        block = transformer_options.get("block", None)
        block_index = transformer_options.get("block_index", 0)
        transformer_patches_replace = {}

        for k in transformer_options:
            extra_options[k] = transformer_options[k]
        extra_options["n_heads"] = self.n_heads
        extra_options["dim_head"] = self.d_head

        n = self.norm1(x)
        context_attn1 = None
        value_attn1 = None

        # Guard against a missing "block" tag so the block can also be run standalone.
        transformer_block = (
            (block[0], block[1], block_index) if block is not None else None
        )

        # No replacement patches are registered here, so the stock attention path is
        # always taken.
        attn1_replace_patch = transformer_patches_replace.get("attn1", {})
        block_attn1 = transformer_block
        if block_attn1 not in attn1_replace_patch:
            block_attn1 = block
        n = self.attn1(n, context=context_attn1, value=value_attn1)
        x += n

        if self.attn2 is not None:
            n = self.norm2(x)
            context_attn2 = context
            value_attn2 = None
            attn2_replace_patch = transformer_patches_replace.get("attn2", {})
            block_attn2 = transformer_block
            if block_attn2 not in attn2_replace_patch:
                block_attn2 = block
            n = self.attn2(n, context=context_attn2, value=value_attn2)

        x += n

        if self.is_res:
            x_skip = x
        x = self.ff(self.norm3(x))
        if self.is_res:
            x += x_skip

        return x
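
# Usage sketch (illustrative only; the dims below are assumptions, not fixed by this
# module): the block expects flattened tokens of shape (B, N, dim) plus an optional
# cross-attention context of shape (B, M, context_dim). The "block" entry in
# transformer_options is the (block_type, block_id) tag the UNet normally supplies.
#
#     blk = BasicTransformerBlock(dim=320, n_heads=8, d_head=40, context_dim=768)
#     x = torch.randn(1, 4096, 320)
#     ctx = torch.randn(1, 77, 768)
#     y = blk(x, context=ctx, transformer_options={"block": ("input", 1)})
#     # y.shape == (1, 4096, 320)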
class SpatialTransformer(nn.Module):
    """#### Spatial Transformer module.

    #### Args:
        - `in_channels` (int): The number of input channels.
        - `n_heads` (int): The number of attention heads.
        - `d_head` (int): The dimension of each attention head.
        - `depth` (int, optional): The depth of the transformer. Defaults to 1.
        - `dropout` (float, optional): The dropout rate. Defaults to 0.0.
        - `context_dim` (int, optional): The context dimension. Defaults to None.
        - `disable_self_attn` (bool, optional): Whether to disable self-attention. Defaults to False.
        - `use_linear` (bool, optional): Whether to use linear projections. Defaults to False.
        - `use_checkpoint` (bool, optional): Whether to use checkpointing. Defaults to True.
        - `dtype` (torch.dtype, optional): The data type. Defaults to None.
        - `device` (torch.device, optional): The device. Defaults to None.
        - `operations` (object, optional): The operations module. Defaults to `ops`.
    """

    def __init__(
        self,
        in_channels: int,
        n_heads: int,
        d_head: int,
        depth: int = 1,
        dropout: float = 0.0,
        context_dim: int = None,
        disable_self_attn: bool = False,
        use_linear: bool = False,
        use_checkpoint: bool = True,
        dtype: torch.dtype = None,
        device: torch.device = None,
        operations: object = ops,
    ):
        super().__init__()
        if util.exists(context_dim) and not isinstance(context_dim, list):
            context_dim = [context_dim] * depth
        self.in_channels = in_channels
        inner_dim = n_heads * d_head
        self.norm = operations.GroupNorm(
            num_groups=32,
            num_channels=in_channels,
            eps=1e-6,
            affine=True,
            dtype=dtype,
            device=device,
        )
        if not use_linear:
            self.proj_in = operations.Conv2d(
                in_channels,
                inner_dim,
                kernel_size=1,
                stride=1,
                padding=0,
                dtype=dtype,
                device=device,
            )
        else:
            self.proj_in = operations.Linear(
                in_channels, inner_dim, dtype=dtype, device=device
            )
        self.transformer_blocks = nn.ModuleList(
            [
                BasicTransformerBlock(
                    inner_dim,
                    n_heads,
                    d_head,
                    dropout=dropout,
                    context_dim=context_dim[d],
                    disable_self_attn=disable_self_attn,
                    checkpoint=use_checkpoint,
                    dtype=dtype,
                    device=device,
                    operations=operations,
                )
                for d in range(depth)
            ]
        )
        if not use_linear:
            self.proj_out = operations.Conv2d(
                inner_dim,
                in_channels,
                kernel_size=1,
                stride=1,
                padding=0,
                dtype=dtype,
                device=device,
            )
        else:
            self.proj_out = operations.Linear(
                in_channels, inner_dim, dtype=dtype, device=device
            )
        self.use_linear = use_linear
    def forward(
        self,
        x: torch.Tensor,
        context: torch.Tensor = None,
        transformer_options: dict = {},
    ) -> torch.Tensor:
        """#### Forward pass of the Spatial Transformer.

        #### Args:
            - `x` (torch.Tensor): The input tensor.
            - `context` (torch.Tensor, optional): The context tensor. Defaults to None.
            - `transformer_options` (dict, optional): Additional transformer options. Defaults to {}.

        #### Returns:
            - `torch.Tensor`: The output tensor.
        """
        # note: if no context is given, cross-attention defaults to self-attention
        if not isinstance(context, list):
            context = [context] * len(self.transformer_blocks)
        b, c, h, w = x.shape
        x_in = x
        x = self.norm(x)
        if not self.use_linear:
            x = self.proj_in(x)
        x = rearrange(x, "b c h w -> b (h w) c").contiguous()
        if self.use_linear:
            x = self.proj_in(x)
        for i, block in enumerate(self.transformer_blocks):
            transformer_options["block_index"] = i
            x = block(x, context=context[i], transformer_options=transformer_options)
        if self.use_linear:
            x = self.proj_out(x)
        x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous()
        if not self.use_linear:
            x = self.proj_out(x)
        return x + x_in
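
# Usage sketch (illustrative only; channel/context sizes are assumptions): the module
# consumes an NCHW feature map, flattens it to (B, H*W, C) for the transformer blocks,
# and reshapes back, so the output shape matches the input.
#
#     st = SpatialTransformer(in_channels=320, n_heads=8, d_head=40, depth=1,
#                             context_dim=768, use_linear=True)
#     x = torch.randn(1, 320, 32, 32)
#     ctx = torch.randn(1, 77, 768)
#     y = st(x, context=ctx, transformer_options={"block": ("input", 1)})
#     # y.shape == (1, 320, 32, 32)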
def count_blocks(state_dict_keys: list, prefix_string: str) -> int:
    """#### Count the number of blocks in a state dictionary.

    #### Args:
        - `state_dict_keys` (list): The list of state dictionary keys.
        - `prefix_string` (str): The prefix template to match, with a `{}` placeholder for the block index.

    #### Returns:
        - `int`: The number of blocks.
    """
    count = 0
    while True:
        c = False
        for k in state_dict_keys:
            if k.startswith(prefix_string.format(count)):
                c = True
                break
        if c is False:
            break
        count += 1
    return count
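
# Usage sketch (illustrative only; the keys are made up): the prefix template is
# formatted with an increasing index until no key matches, so the return value is the
# number of consecutive blocks found.
#
#     keys = ["input_blocks.0.0.weight", "input_blocks.1.0.weight", "input_blocks.2.0.weight"]
#     count_blocks(keys, "input_blocks.{}.")   # -> 3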
def calculate_transformer_depth(
    prefix: str, state_dict_keys: list, state_dict: dict
) -> tuple:
    """#### Calculate the depth of a transformer.

    #### Args:
        - `prefix` (str): The prefix string.
        - `state_dict_keys` (list): The list of state dictionary keys.
        - `state_dict` (dict): The state dictionary.

    #### Returns:
        - `tuple`: The transformer depth, context dimension, use of linear projections, and time-stack flag, or `None` if no transformer blocks are found under the prefix.
    """
    context_dim = None
    use_linear_in_transformer = False
    transformer_prefix = prefix + "1.transformer_blocks."
    transformer_keys = sorted(
        list(filter(lambda a: a.startswith(transformer_prefix), state_dict_keys))
    )
    if len(transformer_keys) > 0:
        last_transformer_depth = count_blocks(
            state_dict_keys, transformer_prefix + "{}"
        )
        context_dim = state_dict[
            "{}0.attn2.to_k.weight".format(transformer_prefix)
        ].shape[1]
        use_linear_in_transformer = (
            len(state_dict["{}1.proj_in.weight".format(prefix)].shape) == 2
        )
        time_stack = (
            "{}1.time_stack.0.attn1.to_q.weight".format(prefix) in state_dict
            or "{}1.time_mix_blocks.0.attn1.to_q.weight".format(prefix) in state_dict
        )
        return (
            last_transformer_depth,
            context_dim,
            use_linear_in_transformer,
            time_stack,
        )
    return None
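
# Usage sketch (illustrative only; the prefix follows the usual UNet key layout and is
# an assumption about the checkpoint being probed, and load_state_dict_somehow is a
# hypothetical loader): given the prefix of one UNet level, the helper reports how many
# transformer blocks sit under "<prefix>1.", the cross-attention context width, whether
# the projections are linear, and whether a temporal (time_stack / time_mix_blocks)
# variant is present.
#
#     sd = load_state_dict_somehow()  # hypothetical
#     info = calculate_transformer_depth(
#         "model.diffusion_model.input_blocks.1.", list(sd.keys()), sd
#     )
#     if info is not None:
#         depth, context_dim, use_linear, time_stack = info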