import torch
import torch.nn as nn
import logging

from modules.Utilities import util
from modules.Attention import AttentionMethods
from modules.Device import Device
from modules.cond import cast


def Normalize(
    in_channels: int, dtype: torch.dtype = None, device: torch.device = None
) -> torch.nn.GroupNorm:
    """#### Normalize the input channels.

    #### Args:
        - `in_channels` (int): The input channels.
        - `dtype` (torch.dtype, optional): The data type. Defaults to `None`.
        - `device` (torch.device, optional): The device. Defaults to `None`.

    #### Returns:
        - `torch.nn.GroupNorm`: The normalized input channels.
    """
    return torch.nn.GroupNorm(
        num_groups=32,
        num_channels=in_channels,
        eps=1e-6,
        affine=True,
        dtype=dtype,
        device=device,
    )


if Device.xformers_enabled():
    logging.info("Using xformers cross attention")
    optimized_attention = AttentionMethods.attention_xformers
else:
    logging.info("Using pytorch cross attention")
    optimized_attention = AttentionMethods.attention_pytorch

optimized_attention_masked = optimized_attention


def optimized_attention_for_device() -> AttentionMethods.attention_pytorch:
    """#### Get the optimized attention for a device.

    #### Returns:
        - `function`: The optimized attention function.
    """
    return AttentionMethods.attention_pytorch


class CrossAttention(nn.Module):
    """#### Cross attention module, which applies attention across the query and context.

    #### Args:
        - `query_dim` (int): The query dimension.
        - `context_dim` (int, optional): The context dimension. Defaults to `None`.
        - `heads` (int, optional): The number of heads. Defaults to `8`.
        - `dim_head` (int, optional): The head dimension. Defaults to `64`.
        - `dropout` (float, optional): The dropout rate. Defaults to `0.0`.
        - `dtype` (torch.dtype, optional): The data type. Defaults to `None`.
        - `device` (torch.device, optional): The device. Defaults to `None`.
        - `operations` (cast.disable_weight_init, optional): The operations. Defaults to `cast.disable_weight_init`.
    """

    def __init__(
        self,
        query_dim: int,
        context_dim: int = None,
        heads: int = 8,
        dim_head: int = 64,
        dropout: float = 0.0,
        dtype: torch.dtype = None,
        device: torch.device = None,
        operations: cast.disable_weight_init = cast.disable_weight_init,
    ):
        super().__init__()
        inner_dim = dim_head * heads
        context_dim = util.default(context_dim, query_dim)

        self.heads = heads
        self.dim_head = dim_head

        self.to_q = operations.Linear(
            query_dim, inner_dim, bias=False, dtype=dtype, device=device
        )
        self.to_k = operations.Linear(
            context_dim, inner_dim, bias=False, dtype=dtype, device=device
        )
        self.to_v = operations.Linear(
            context_dim, inner_dim, bias=False, dtype=dtype, device=device
        )

        self.to_out = nn.Sequential(
            operations.Linear(inner_dim, query_dim, dtype=dtype, device=device),
            nn.Dropout(dropout),
        )

    def forward(
        self,
        x: torch.Tensor,
        context: torch.Tensor = None,
        value: torch.Tensor = None,
        mask: torch.Tensor = None,
    ) -> torch.Tensor:
        """#### Forward pass of the cross attention module.

        #### Args:
            - `x` (torch.Tensor): The input tensor.
            - `context` (torch.Tensor, optional): The context tensor. Defaults to `None`.
            - `value` (torch.Tensor, optional): The value tensor. Defaults to `None`.
            - `mask` (torch.Tensor, optional): The mask tensor. Defaults to `None`.

        #### Returns:
            - `torch.Tensor`: The output tensor.
        """
        q = self.to_q(x)
        context = util.default(context, x)
        k = self.to_k(context)
        v = self.to_v(context)

        out = optimized_attention(q, k, v, self.heads)
        return self.to_out(out)
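

# --- Usage sketch (illustrative only, not part of the original module) --------
# A minimal sketch of how CrossAttention is typically driven: `x` carries the
# query tokens and `context` the key/value tokens (e.g. a text-encoder output).
# The sizes below (query_dim=320, context_dim=768, 8 heads of 40) are example
# assumptions, not values mandated by this module.
def _cross_attention_usage_example() -> torch.Tensor:
    """#### Hypothetical helper demonstrating a CrossAttention forward pass."""
    attn = CrossAttention(query_dim=320, context_dim=768, heads=8, dim_head=40)
    x = torch.randn(2, 64, 320)  # (batch, query tokens, query_dim)
    context = torch.randn(2, 77, 768)  # (batch, context tokens, context_dim)
    return attn(x, context=context)  # same shape as `x`: (2, 64, 320)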


class AttnBlock(nn.Module):
    """#### Attention block, which applies attention to the input tensor.

    #### Args:
        - `in_channels` (int): The input channels.
    """

    def __init__(self, in_channels: int):
        super().__init__()
        self.in_channels = in_channels

        self.norm = Normalize(in_channels)
        self.q = cast.disable_weight_init.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.k = cast.disable_weight_init.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.v = cast.disable_weight_init.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.proj_out = cast.disable_weight_init.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )

        if Device.xformers_enabled_vae():
            logging.info("Using xformers attention in VAE")
            self.optimized_attention = AttentionMethods.xformers_attention
        else:
            logging.info("Using pytorch attention in VAE")
            self.optimized_attention = AttentionMethods.pytorch_attention

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """#### Forward pass of the attention block.

        #### Args:
            - `x` (torch.Tensor): The input tensor.

        #### Returns:
            - `torch.Tensor`: The output tensor.
        """
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)
        h_ = self.optimized_attention(q, k, v)
        h_ = self.proj_out(h_)
        return x + h_


def make_attn(in_channels: int, attn_type: str = "vanilla") -> AttnBlock:
    """#### Make an attention block.

    #### Args:
        - `in_channels` (int): The input channels.
        - `attn_type` (str, optional): The attention type. Defaults to "vanilla".

    #### Returns:
        - `AttnBlock`: A class instance of the attention block.
    """
    return AttnBlock(in_channels)
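

# --- Smoke test (illustrative only, not part of the original module) ----------
# A minimal sketch of the VAE-style self-attention path: make_attn builds an
# AttnBlock that expects a (batch, channels, height, width) feature map whose
# channel count is divisible by the 32 GroupNorm groups used by Normalize.
# The 1x128x32x32 shape below is an example assumption.
if __name__ == "__main__":
    block = make_attn(128)  # "vanilla" is the only attention type provided here
    feats = torch.randn(1, 128, 32, 32)
    out = block(feats)  # residual attention keeps the input shape
    print(out.shape)  # torch.Size([1, 128, 32, 32])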