import torch
import torch.nn as nn
import torch.nn.functional as F

from esm.layers.attention import MultiHeadAttention
from esm.layers.geom_attention import (
    GeometricReasoningOriginalImpl,
)
from esm.utils.structure.affine3d import Affine3D


def swiglu_correction_fn(expansion_ratio: float, d_model: int) -> int:
    """Return the hidden width used by the SwiGLU feed-forward network.

    Computes ``(expansion_ratio * d_model + 255) // 256 * 256``, i.e. the
    expanded width rounded up to a multiple of 256 (an exact ceiling whenever
    ``expansion_ratio * d_model`` is a whole number).
    """
    # Round the hidden dimension up to a multiple of 256 after applying the
    # expansion ratio.
    return int(((expansion_ratio * d_model) + 255) // 256 * 256)


class SwiGLU(nn.Module):
    """
    SwiGLU activation function as an nn.Module, allowing it to be used within nn.Sequential.
    This module splits the input tensor along the last dimension and applies the SiLU (Swish)
    activation function to the first half, then multiplies it by the second half.
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The last dimension is expected to be even so that the two halves
        # produced by chunk() have matching sizes.
        x1, x2 = x.chunk(2, dim=-1)
        return F.silu(x1) * x2


def swiglu_ln_ffn(d_model: int, expansion_ratio: float, bias: bool) -> nn.Sequential:
    """Build a LayerNorm -> Linear -> SwiGLU -> Linear feed-forward stack.

    Parameters
    ----------
    d_model : int
        Input and output feature size.
    expansion_ratio : float
        Multiplier for the hidden width. The width is adjusted by
        ``swiglu_correction_fn``.
    bias : bool
        Whether both Linear layers use bias terms.

    Returns
    -------
    nn.Sequential
        The feed-forward stack. No residual connection is applied here.
    """
    hidden_dim = swiglu_correction_fn(expansion_ratio, d_model)
    return nn.Sequential(
        nn.LayerNorm(d_model),
        # Project to twice the hidden width: SwiGLU splits it into a gate
        # half and a value half.
        nn.Linear(d_model, hidden_dim * 2, bias=bias),
        SwiGLU(),
        nn.Linear(hidden_dim, d_model, bias=bias),
    )


def gelu_ln_ffn(d_model: int, expansion_ratio: float, bias: bool) -> nn.Sequential:
    """Build a LayerNorm -> Linear -> GELU -> Linear feed-forward stack.

    The hidden width is ``int(expansion_ratio * d_model)`` (truncated, not
    rounded to a multiple of 256 as in the SwiGLU variant).

    Parameters
    ----------
    d_model : int
        Input and output feature size.
    expansion_ratio : float
        Multiplier for the hidden width.
    bias : bool
        Whether both Linear layers use bias terms.

    Returns
    -------
    nn.Sequential
        The feed-forward stack. No residual connection is applied here.
    """
    hidden_dim = int(expansion_ratio * d_model)
    return nn.Sequential(
        nn.LayerNorm(d_model),
        nn.Linear(d_model, hidden_dim, bias=bias),
        nn.GELU(),
        nn.Linear(hidden_dim, d_model, bias=bias),
    )


class UnifiedTransformerBlock(nn.Module):
    """
    A unified transformer block that can optionally incorporate geometric attention.

    The block applies, in order and each as a residual branch: standard
    multi-head attention (if ``use_plain_attn``), geometric attention (if
    ``use_geom_attn``), and a feed-forward network. Each branch output is
    divided by ``residue_scaling_factor`` before being added to the running
    hidden state.

    Parameters
    ----------
    d_model : int
        The dimensionality of the input and output features of the transformer block.
    n_heads : int
        The number of attention heads in the multi-head attention mechanism
        (used only when ``use_plain_attn`` is True).
    use_geom_attn : bool, optional
        Whether to use geometric attention in addition to the standard
        multi-head attention. Defaults to False.
    use_plain_attn : bool, optional
        Whether to use the standard multi-head attention. Defaults to True.
    v_heads : int, optional
        The number of heads to use for the geometric attention mechanism.
        Must be specified if ``use_geom_attn`` is True.
    bias : bool, optional
        Whether the attention modules and feed-forward Linear layers use bias
        terms. Defaults to False.
    expansion_ratio : float, optional
        Hidden-width multiplier for the feed-forward network. Defaults to 4.0.
    residue_scaling_factor : float, optional
        Divisor applied to every residual branch output before it is added.
        Defaults to 1.
    mask_and_zero_frameless : bool, optional
        Forwarded unchanged to ``GeometricReasoningOriginalImpl``. Defaults to False.
    qk_layernorm : bool, optional
        Forwarded unchanged to ``MultiHeadAttention``. Defaults to True.
    ffn_type : str, optional
        Feed-forward variant, either ``"swiglu"`` or ``"gelu"``.
        Defaults to ``"swiglu"``.

    Raises
    ------
    ValueError
        If ``use_geom_attn`` is True but ``v_heads`` is None, or if
        ``ffn_type`` is not one of the supported values.
    """

    def __init__(
        self,
        d_model: int,
        n_heads: int,
        use_geom_attn: bool = False,
        use_plain_attn: bool = True,
        v_heads: int | None = None,
        bias: bool = False,
        expansion_ratio: float = 4.0,
        residue_scaling_factor: float = 1,
        mask_and_zero_frameless: bool = False,
        qk_layernorm: bool = True,
        ffn_type: str = "swiglu",  # swiglu | gelu
    ):
        super().__init__()
        self.use_plain_attn = use_plain_attn
        if self.use_plain_attn:
            self.attn = MultiHeadAttention(
                d_model, n_heads, bias, qk_layernorm=qk_layernorm
            )
        self.use_geom_attn = use_geom_attn
        if self.use_geom_attn:
            if v_heads is None:
                raise ValueError("v_heads must be specified when use_geom_attn is True")
            self.geom_attn = GeometricReasoningOriginalImpl(
                c_s=d_model,
                v_heads=v_heads,
                bias=bias,
                mask_and_zero_frameless=mask_and_zero_frameless,
            )
        if ffn_type == "swiglu":
            self.ffn = swiglu_ln_ffn(d_model, expansion_ratio, bias)
        elif ffn_type == "gelu":
            self.ffn = gelu_ln_ffn(d_model, expansion_ratio, bias)
        else:
            raise ValueError(f"Unknown ffn_type: {ffn_type}")
        self.scaling_factor = residue_scaling_factor

    def forward(
        self,
        x: torch.Tensor,
        sequence_id: torch.Tensor,
        frames: Affine3D,
        frames_mask: torch.Tensor,
        chain_id: torch.Tensor,
    ) -> torch.Tensor:
        """
        Forward pass for the UnifiedTransformerBlock.

        Parameters
        ----------
        x : torch.Tensor[float]
            Input tensor to the transformer block, typically the output from the previous layer.
        sequence_id : torch.Tensor[int]
            Tensor containing sequence IDs for each element in the batch, used for attention masking.
        frames : Affine3D
            Affine3D containing geometric frame information for geometric attention.
        frames_mask : torch.Tensor[bool]
            Boolean mask tensor indicating valid frames for geometric attention.
        chain_id : torch.Tensor[int]
            Tensor containing chain IDs for each element, used for attention masking in geometric attention.

        Returns
        -------
        torch.Tensor[float]
            The output tensor after applying the transformer block operations.
        """
        # Each enabled sub-layer is a residual branch scaled by 1 / scaling_factor.
        if self.use_plain_attn:
            r1 = self.attn(x, sequence_id)
            x = x + r1 / self.scaling_factor

        if self.use_geom_attn:
            r2 = self.geom_attn(x, frames, frames_mask, sequence_id, chain_id)
            x = x + r2 / self.scaling_factor

        r3 = self.ffn(x) / self.scaling_factor
        x = x + r3

        return x