# =============================================================================
# core/mamba.py
# =============================================================================
import torch
import torch.nn as nn
import torch.nn.functional as F

from core.stateSpace import StateSpaceModel
from utils.conv_layer import Mamba1DConv
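
# Note (inferred from how these modules are used below, not from their own
# source): Mamba1DConv is expected to map [batch, seq_len, d_inner] back to
# [batch, seq_len, d_inner] with causal padding, and StateSpaceModel to run
# its scan over d_state while preserving the [batch, seq_len, d_inner] shape.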


class RMSNorm(nn.Module):
    def __init__(self, d_model: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(d_model))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        norm = x.norm(dim=-1, keepdim=True) * (x.shape[-1] ** -0.5)
        return x / (norm + self.eps) * self.weight
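
# For reference, RMSNorm as written computes, per token over the feature dim:
#   rms(x) = ||x||_2 / sqrt(d_model)   (i.e. sqrt(mean(x_i^2)))
#   out    = x / (rms(x) + eps) * weight
# Unlike LayerNorm there is no mean subtraction and no bias term; note this
# variant adds eps to the RMS itself rather than under the square root.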


class MambaBlock(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config

        # Projections
        self.in_proj = nn.Linear(config.d_model, config.d_inner * 2, bias=config.bias)
        self.out_proj = nn.Linear(config.d_inner, config.d_model, bias=config.bias)

        # Convolution for local context
        self.conv1d = Mamba1DConv(config.d_inner, config.d_conv, config.conv_bias)

        # State space model
        self.ssm = StateSpaceModel(
            d_inner=config.d_inner,
            d_state=config.d_state,
            dt_rank=config.dt_rank,
            bias=config.bias,
        )

        # Activation
        self.act = F.silu

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: [batch, seq_len, d_model]
        Returns:
            output: [batch, seq_len, d_model]
        """
        batch_size, seq_len, d_model = x.shape

        # Input projection, then split into the SSM path (x) and gate path (z)
        xz = self.in_proj(x)        # [batch, seq_len, 2*d_inner]
        x, z = xz.chunk(2, dim=-1)  # each [batch, seq_len, d_inner]

        # Convolution for local context, followed by SiLU
        x = self.act(self.conv1d(x))

        # State space model
        y = self.ssm(x)

        # Multiplicative gating with the z branch
        y = y * self.act(z)

        # Output projection back to d_model
        output = self.out_proj(y)
        return output
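
# Shape sketch of one block (illustrative summary, not from the original):
#   x: [B, L, d_model] --in_proj--> [B, L, 2*d_inner] --chunk--> x, z
#   x: conv1d + silu -> ssm -> y: [B, L, d_inner]
#   y: y * silu(z) --out_proj--> [B, L, d_model]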


class MambaLayer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.norm = RMSNorm(config.d_model)
        self.mamba_block = MambaBlock(config)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Pre-norm architecture: normalize, transform, add the residual back
        residual = x
        x = self.norm(x)
        x = self.mamba_block(x)
        return x + residual
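

if __name__ == "__main__":
    # Minimal smoke test -- a sketch, not part of the original module. It only
    # runs if core.stateSpace and utils.conv_layer are importable. MambaConfig
    # below is a hypothetical stand-in assembled from the config fields this
    # file actually reads; the real project config may define more.
    from dataclasses import dataclass

    @dataclass
    class MambaConfig:
        d_model: int = 64
        d_inner: int = 128
        d_state: int = 16
        d_conv: int = 4
        dt_rank: int = 4
        bias: bool = False
        conv_bias: bool = True

    layer = MambaLayer(MambaConfig())
    x = torch.randn(2, 10, 64)
    y = layer(x)
    # The pre-norm residual design keeps output shape equal to input shape.
    assert y.shape == x.shape
    print(y.shape)  # torch.Size([2, 10, 64])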