import logging
from typing import Callable

import torch
import torch.nn as nn

from modules.Utilities import util
from modules.Attention import AttentionMethods
from modules.Device import Device
from modules.cond import cast


def Normalize(
    in_channels: int, dtype: torch.dtype = None, device: torch.device = None
) -> torch.nn.GroupNorm:
    """#### Create a GroupNorm layer over the input channels.

    #### Args:
        - `in_channels` (int): The number of input channels.
        - `dtype` (torch.dtype, optional): The data type. Defaults to `None`.
        - `device` (torch.device, optional): The device. Defaults to `None`.

    #### Returns:
        - `torch.nn.GroupNorm`: A GroupNorm layer with 32 groups over the input channels.
    """
    return torch.nn.GroupNorm(
        num_groups=32,
        num_channels=in_channels,
        eps=1e-6,
        affine=True,
        dtype=dtype,
        device=device,
    )
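

# Illustrative usage sketch (not part of the original module): the channel count
# and spatial size below are assumptions chosen for the example. `in_channels`
# must be divisible by the 32 groups used above.
#
#   norm = Normalize(64)                       # GroupNorm with 32 groups over 64 channels
#   y = norm(torch.randn(1, 64, 32, 32))       # output keeps the (N, C, H, W) shape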


if Device.xformers_enabled():
    logging.info("Using xformers cross attention")
    optimized_attention = AttentionMethods.attention_xformers
else:
    logging.info("Using pytorch cross attention")
    optimized_attention = AttentionMethods.attention_pytorch

optimized_attention_masked = optimized_attention


def optimized_attention_for_device() -> Callable:
    """#### Get the optimized attention function for a device.

    Currently always returns the PyTorch attention implementation.

    #### Returns:
        - `Callable`: The optimized attention function.
    """
    return AttentionMethods.attention_pytorch
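

# Illustrative usage sketch (not part of the original module): `optimized_attention`
# is called elsewhere in this file as `optimized_attention(q, k, v, heads)`, with
# q/k/v laid out as (batch, tokens, heads * dim_head). The concrete shapes below
# are assumptions for the example only.
#
#   heads, dim_head = 8, 64
#   q = torch.randn(1, 77, heads * dim_head)
#   k = torch.randn(1, 77, heads * dim_head)
#   v = torch.randn(1, 77, heads * dim_head)
#   out = optimized_attention(q, k, v, heads)  # -> (1, 77, heads * dim_head)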


class CrossAttention(nn.Module):
    """#### Cross attention module, which applies attention across the query and context.

    #### Args:
        - `query_dim` (int): The query dimension.
        - `context_dim` (int, optional): The context dimension. Defaults to `None`.
        - `heads` (int, optional): The number of heads. Defaults to `8`.
        - `dim_head` (int, optional): The head dimension. Defaults to `64`.
        - `dropout` (float, optional): The dropout rate. Defaults to `0.0`.
        - `dtype` (torch.dtype, optional): The data type. Defaults to `None`.
        - `device` (torch.device, optional): The device. Defaults to `None`.
        - `operations` (cast.disable_weight_init, optional): The operations. Defaults to `cast.disable_weight_init`.
    """

    def __init__(
        self,
        query_dim: int,
        context_dim: int = None,
        heads: int = 8,
        dim_head: int = 64,
        dropout: float = 0.0,
        dtype: torch.dtype = None,
        device: torch.device = None,
        operations: cast.disable_weight_init = cast.disable_weight_init,
    ):
        super().__init__()
        inner_dim = dim_head * heads
        context_dim = util.default(context_dim, query_dim)

        self.heads = heads
        self.dim_head = dim_head

        self.to_q = operations.Linear(
            query_dim, inner_dim, bias=False, dtype=dtype, device=device
        )
        self.to_k = operations.Linear(
            context_dim, inner_dim, bias=False, dtype=dtype, device=device
        )
        self.to_v = operations.Linear(
            context_dim, inner_dim, bias=False, dtype=dtype, device=device
        )
        self.to_out = nn.Sequential(
            operations.Linear(inner_dim, query_dim, dtype=dtype, device=device),
            nn.Dropout(dropout),
        )

    def forward(
        self,
        x: torch.Tensor,
        context: torch.Tensor = None,
        value: torch.Tensor = None,
        mask: torch.Tensor = None,
    ) -> torch.Tensor:
        """#### Forward pass of the cross attention module.

        #### Args:
            - `x` (torch.Tensor): The input tensor.
            - `context` (torch.Tensor, optional): The context tensor. Defaults to `None`.
            - `value` (torch.Tensor, optional): The value tensor. Unused in this implementation; `context` provides both keys and values. Defaults to `None`.
            - `mask` (torch.Tensor, optional): The mask tensor. Unused in this implementation. Defaults to `None`.

        #### Returns:
            - `torch.Tensor`: The output tensor.
        """
        q = self.to_q(x)
        # Self-attention when no context is given: fall back to the input itself.
        context = util.default(context, x)
        k = self.to_k(context)
        v = self.to_v(context)

        out = optimized_attention(q, k, v, self.heads)
        return self.to_out(out)
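

# Illustrative usage sketch (not part of the original module): the dimensions are
# assumptions picked for the example (a 320-dim query attending over a 77-token,
# 768-dim context), not values mandated by this class.
#
#   attn = CrossAttention(query_dim=320, context_dim=768, heads=8, dim_head=40)
#   x = torch.randn(1, 4096, 320)        # query tokens
#   ctx = torch.randn(1, 77, 768)        # conditioning tokens
#   y = attn(x, context=ctx)             # -> (1, 4096, 320)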


class AttnBlock(nn.Module):
    """#### Attention block, which applies attention to the input tensor.

    #### Args:
        - `in_channels` (int): The input channels.
    """

    def __init__(self, in_channels: int):
        super().__init__()
        self.in_channels = in_channels

        self.norm = Normalize(in_channels)
        self.q = cast.disable_weight_init.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.k = cast.disable_weight_init.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.v = cast.disable_weight_init.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.proj_out = cast.disable_weight_init.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )

        if Device.xformers_enabled_vae():
            logging.info("Using xformers attention in VAE")
            self.optimized_attention = AttentionMethods.xformers_attention
        else:
            logging.info("Using pytorch attention in VAE")
            self.optimized_attention = AttentionMethods.pytorch_attention

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """#### Forward pass of the attention block.

        #### Args:
            - `x` (torch.Tensor): The input tensor.

        #### Returns:
            - `torch.Tensor`: The output tensor.
        """
        h_ = x
        h_ = self.norm(h_)
        # 1x1 convolutions project the normalized features to queries, keys and values.
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        h_ = self.optimized_attention(q, k, v)
        h_ = self.proj_out(h_)

        # Residual connection around the attention block.
        return x + h_
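

# Illustrative usage sketch (not part of the original module): the channel count
# and spatial size are assumptions for the example. `in_channels` must be
# divisible by the 32 groups used in `Normalize`.
#
#   block = AttnBlock(512)
#   h = torch.randn(1, 512, 16, 16)
#   out = block(h)                       # -> (1, 512, 16, 16), with residual added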


def make_attn(in_channels: int, attn_type: str = "vanilla") -> AttnBlock:
    """#### Make an attention block.

    #### Args:
        - `in_channels` (int): The input channels.
        - `attn_type` (str, optional): The attention type. Currently ignored; a vanilla `AttnBlock` is always returned. Defaults to `"vanilla"`.

    #### Returns:
        - `AttnBlock`: A class instance of the attention block.
    """
    return AttnBlock(in_channels)
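

# Illustrative usage sketch (not part of the original module): `make_attn` is a
# thin factory around `AttnBlock`; the channel count below is an assumption.
#
#   attn = make_attn(512, attn_type="vanilla")   # equivalent to AttnBlock(512)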