# Copyright (c) 2024, Tri Dao, Albert Gu.

from typing import Optional

import torch
from torch import nn, Tensor

from mamba_ssm.ops.triton.layer_norm import RMSNorm, layer_norm_fn


class Block(nn.Module):
    def __init__(
        self, dim, mixer_cls, mlp_cls, norm_cls=nn.LayerNorm, fused_add_norm=False, residual_in_fp32=False
    ):
        """
        Simple block wrapping a mixer class with LayerNorm/RMSNorm and a residual connection.

        This Block has a slightly different structure compared to a regular
        prenorm Transformer block.
        The standard block is: LN -> MHA/MLP -> Add.
        [Ref: https://arxiv.org/abs/2002.04745]
        Here we have: Add -> LN -> Mixer, returning both
        the hidden_states (output of the mixer) and the residual.
        This is purely for performance reasons, as we can fuse add and LayerNorm.
        The residual needs to be provided (except for the very first block).
        A minimal usage sketch is included at the bottom of this file.
        """
        super().__init__()
        self.residual_in_fp32 = residual_in_fp32
        self.fused_add_norm = fused_add_norm
        self.norm = norm_cls(dim)
        self.mixer = mixer_cls(dim)
        if mlp_cls is not nn.Identity:
            self.norm2 = norm_cls(dim)
            self.mlp = mlp_cls(dim)
        else:
            self.mlp = None
        if self.fused_add_norm:
            assert RMSNorm is not None, "RMSNorm import failed"
            assert isinstance(
                self.norm, (nn.LayerNorm, RMSNorm)
            ), "Only LayerNorm and RMSNorm are supported for fused_add_norm"

    def forward(
        self, hidden_states: Tensor, residual: Optional[Tensor] = None, inference_params=None, **mixer_kwargs
    ):
        r"""Pass the input through the encoder layer.

        Args:
            hidden_states: the sequence to the encoder layer (required).
            residual: running residual stream from the previous block; when provided,
                the block computes hidden_states = Mixer(LN(hidden_states + residual)).
        """
        if not self.fused_add_norm:
            residual = (hidden_states + residual) if residual is not None else hidden_states
            hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype))
            if self.residual_in_fp32:
                residual = residual.to(torch.float32)
        else:
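            # Fused path: layer_norm_fn performs the residual add and the normalization in one fused kernel.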
            hidden_states, residual = layer_norm_fn(
                hidden_states,
                self.norm.weight,
                self.norm.bias,
                residual=residual,
                prenorm=True,
                residual_in_fp32=self.residual_in_fp32,
                eps=self.norm.eps,
                is_rms_norm=isinstance(self.norm, RMSNorm),
            )
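        # Sequence mixing (e.g. Mamba or attention, depending on mixer_cls) on the normalized hidden states.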
        hidden_states = self.mixer(hidden_states, inference_params=inference_params, **mixer_kwargs)
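        # Optional second Add -> LN -> MLP sub-block (only present when mlp_cls is not nn.Identity).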
        if self.mlp is not None:
            if not self.fused_add_norm:
                residual = hidden_states + residual
                hidden_states = self.norm2(residual.to(dtype=self.norm2.weight.dtype))
                if self.residual_in_fp32:
                    residual = residual.to(torch.float32)
            else:
                hidden_states, residual = layer_norm_fn(
                    hidden_states,
                    self.norm2.weight,
                    self.norm2.bias,
                    residual=residual,
                    prenorm=True,
                    residual_in_fp32=self.residual_in_fp32,
                    eps=self.norm2.eps,
                    is_rms_norm=isinstance(self.norm2, RMSNorm),
                )
            hidden_states = self.mlp(hidden_states)
        return hidden_states, residual
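
    # Delegates allocation of the per-layer decoding cache to the mixer (e.g. conv/SSM state for Mamba).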
    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
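

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only, not part of the library API): shows how a Block
    # is built from a mixer class and chained via the (hidden_states, residual) pair.
    # `ToyMixer` below is a hypothetical stand-in so this runs on CPU without Triton;
    # in the real models, mixer_cls is typically functools.partial(Mamba, layer_idx=...).
    from functools import partial

    class ToyMixer(nn.Module):
        """Hypothetical mixer: a single Linear layer standing in for Mamba/attention."""

        def __init__(self, d_model):
            super().__init__()
            self.proj = nn.Linear(d_model, d_model)

        def forward(self, x, inference_params=None, **kwargs):
            return self.proj(x)

    dim = 64
    block = Block(
        dim,
        mixer_cls=ToyMixer,
        mlp_cls=nn.Identity,                       # no MLP sub-block
        norm_cls=partial(nn.LayerNorm, eps=1e-5),
        fused_add_norm=False,                      # pure-PyTorch path for this CPU demo
        residual_in_fp32=True,
    )

    x = torch.randn(2, 16, dim)
    hidden_states, residual = block(x)                          # first block: residual is None
    hidden_states, residual = block(hidden_states, residual)    # subsequent blocks reuse the residual
    print(hidden_states.shape, residual.shape)                  # torch.Size([2, 16, 64]) twice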