# Copyright (c) 2024, Tri Dao, Albert Gu.

import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from einops import rearrange, repeat

try:
    from causal_conv1d import causal_conv1d_fn
except ImportError:
    causal_conv1d_fn = None

try:
    from mamba_ssm.ops.triton.layernorm_gated import RMSNorm as RMSNormGated, LayerNorm
except ImportError:
    RMSNormGated, LayerNorm = None, None

from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined
from mamba_ssm.ops.triton.ssd_combined import mamba_split_conv1d_scan_combined

class Mamba2Simple(nn.Module):
    def __init__(
        self,
        d_model,
        d_state=64,
        d_conv=4,
        conv_init=None,
        expand=2,
        headdim=128,
        ngroups=1,
        A_init_range=(1, 16),
        dt_min=0.001,
        dt_max=0.1,
        dt_init_floor=1e-4,
        dt_limit=(0.0, float("inf")),
        learnable_init_states=False,
        activation="swish",
        bias=False,
        conv_bias=True,
        # Fused kernel and sharding options
        chunk_size=256,
        use_mem_eff_path=True,
        layer_idx=None,  # Absorb kwarg for general module
        device=None,
        dtype=None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        self.d_conv = d_conv
        self.conv_init = conv_init
        self.expand = expand
        self.d_inner = self.expand * self.d_model
        self.headdim = headdim
        self.ngroups = ngroups
        assert self.d_inner % self.headdim == 0
        self.nheads = self.d_inner // self.headdim
        self.dt_limit = dt_limit
        self.learnable_init_states = learnable_init_states
        self.activation = activation
        self.chunk_size = chunk_size
        self.use_mem_eff_path = use_mem_eff_path
        self.layer_idx = layer_idx

        # Order: [z, x, B, C, dt]
        d_in_proj = 2 * self.d_inner + 2 * self.ngroups * self.d_state + self.nheads
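        # For illustration: with the defaults above and, say, d_model=512 (an arbitrary example),
        # d_inner = 1024 and nheads = 8, so
        # d_in_proj = 1024 (z) + 1024 (x) + 64 (B) + 64 (C) + 8 (dt) = 2184.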
        self.in_proj = nn.Linear(self.d_model, d_in_proj, bias=bias, **factory_kwargs)

        conv_dim = self.d_inner + 2 * self.ngroups * self.d_state
        self.conv1d = nn.Conv1d(
            in_channels=conv_dim,
            out_channels=conv_dim,
            bias=conv_bias,
            kernel_size=d_conv,
            groups=conv_dim,
            padding=d_conv - 1,
            **factory_kwargs,
        )
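        # groups=conv_dim makes this a depthwise conv (one length-d_conv filter per channel).
        # With padding=d_conv - 1 and the output trimmed back to seqlen (see forward / the fused
        # kernel), each kept position depends only on the current and previous d_conv - 1 steps,
        # i.e. the convolution is causal.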
        if self.conv_init is not None:
            nn.init.uniform_(self.conv1d.weight, -self.conv_init, self.conv_init)
        # self.conv1d.weight._no_weight_decay = True

        if self.learnable_init_states:
            self.init_states = nn.Parameter(torch.zeros(self.nheads, self.headdim, self.d_state, **factory_kwargs))
            self.init_states._no_weight_decay = True

        self.act = nn.SiLU()
        # Initialize log dt bias
        dt = torch.exp(
            torch.rand(self.nheads, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min))
            + math.log(dt_min)
        )
        dt = torch.clamp(dt, min=dt_init_floor)
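        # dt is sampled log-uniformly in [dt_min, dt_max]; storing softplus^{-1}(dt) as the bias
        # means softplus(dt + dt_bias) reproduces exactly these values at initialization.
        # softplus(x) = log(1 + exp(x)), so its inverse is
        # x = log(exp(y) - 1) = y + log(1 - exp(-y)) = y + log(-expm1(-y)).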
        # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
        inv_dt = dt + torch.log(-torch.expm1(-dt))
        self.dt_bias = nn.Parameter(inv_dt)
        # Just to be explicit. Without this we already don't put wd on dt_bias because of the check
        # name.endswith("bias") in param_grouping.py
        self.dt_bias._no_weight_decay = True
        # A parameter
        assert A_init_range[0] > 0 and A_init_range[1] >= A_init_range[0]
        A = torch.empty(self.nheads, dtype=torch.float32, device=device).uniform_(*A_init_range)
        A_log = torch.log(A).to(dtype=dtype)
        self.A_log = nn.Parameter(A_log)
        # self.register_buffer("A_log", torch.zeros(self.nheads, dtype=torch.float32, device=device), persistent=True)
        self.A_log._no_weight_decay = True
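        # A is stored as A_log and materialized as A = -exp(A_log) in forward(), so the
        # per-head state-transition rate stays strictly negative (a decaying SSM)
        # no matter what value the parameter takes during training.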

        # D "skip" parameter
        self.D = nn.Parameter(torch.ones(self.nheads, device=device))
        self.D._no_weight_decay = True

        # Extra normalization layer right before output projection
        assert RMSNormGated is not None
        self.norm = RMSNormGated(self.d_inner, eps=1e-5, norm_before_gate=False, **factory_kwargs)

        self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs)

    def forward(self, u, seq_idx=None):
        """
        u: (B, L, D)
        Returns: same shape as u
        """
        batch, seqlen, dim = u.shape

        zxbcdt = self.in_proj(u)  # (B, L, d_in_proj)
        A = -torch.exp(self.A_log)  # (nheads) or (d_inner, d_state)
        initial_states = repeat(self.init_states, "... -> b ...", b=batch) if self.learnable_init_states else None
        dt_limit_kwargs = {} if self.dt_limit == (0.0, float("inf")) else dict(dt_limit=self.dt_limit)
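        # initial_states broadcasts the single learned (nheads, headdim, d_state) state across
        # the batch; dt_limit is only passed down to the scan kernels when it differs from its
        # default of (0.0, inf).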
        if self.use_mem_eff_path:
            # Fully fused path
            out = mamba_split_conv1d_scan_combined(
                zxbcdt,
                rearrange(self.conv1d.weight, "d 1 w -> d w"),
                self.conv1d.bias,
                self.dt_bias,
                A,
                D=self.D,
                chunk_size=self.chunk_size,
                seq_idx=seq_idx,
                activation=self.activation,
                rmsnorm_weight=self.norm.weight,
                rmsnorm_eps=self.norm.eps,
                outproj_weight=self.out_proj.weight,
                outproj_bias=self.out_proj.bias,
                headdim=self.headdim,
                ngroups=self.ngroups,
                norm_before_gate=False,
                initial_states=initial_states,
                **dt_limit_kwargs,
            )
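            # This fused call performs the causal conv, selective scan, gated RMSNorm, and
            # output projection together, so `out` here is already (B, L, d_model).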
        else:
            z, xBC, dt = torch.split(
                zxbcdt, [self.d_inner, self.d_inner + 2 * self.ngroups * self.d_state, self.nheads], dim=-1
            )
            dt = F.softplus(dt + self.dt_bias)  # (B, L, nheads)
            assert self.activation in ["silu", "swish"]

            # 1D Convolution
            if causal_conv1d_fn is None or self.activation not in ["silu", "swish"]:
                xBC = self.act(
                    self.conv1d(xBC.transpose(1, 2)).transpose(1, 2)
                )  # (B, L, self.d_inner + 2 * ngroups * d_state)
                xBC = xBC[:, :seqlen, :]
            else:
                xBC = causal_conv1d_fn(
                    x=xBC.transpose(1, 2),
                    weight=rearrange(self.conv1d.weight, "d 1 w -> d w"),
                    bias=self.conv1d.bias,
                    activation=self.activation,
                ).transpose(1, 2)

            # Split into 3 main branches: X, B, C
            # These correspond to V, K, Q respectively in the SSM/attention duality
            x, B, C = torch.split(xBC, [self.d_inner, self.ngroups * self.d_state, self.ngroups * self.d_state], dim=-1)
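            # Shapes after the split: x is (B, L, d_inner), B and C are (B, L, ngroups * d_state).
            # The rearranges below expose the head dimension for x and the group dimension for
            # B/C before handing everything to the chunked scan kernel.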
            y = mamba_chunk_scan_combined(
                rearrange(x, "b l (h p) -> b l h p", p=self.headdim),
                dt,
                A,
                rearrange(B, "b l (g n) -> b l g n", g=self.ngroups),
                rearrange(C, "b l (g n) -> b l g n", g=self.ngroups),
                chunk_size=self.chunk_size,
                D=self.D,
                z=None,
                seq_idx=seq_idx,
                initial_states=initial_states,
                **dt_limit_kwargs,
            )
            y = rearrange(y, "b l h p -> b l (h p)")

            # Multiply "gate" branch and apply extra normalization layer
            y = self.norm(y, z)
            out = self.out_proj(y)
        return out
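
# Minimal usage sketch (not part of the original module): an illustrative smoke test that
# assumes a CUDA device and that the mamba_ssm / causal_conv1d Triton kernels are installed.
# d_model=512, batch=2, seqlen=256 are arbitrary example values.
if __name__ == "__main__":
    torch.manual_seed(0)
    device = "cuda"  # the fused Triton kernels require a CUDA device
    batch, seqlen, d_model = 2, 256, 512
    model = Mamba2Simple(d_model=d_model, device=device, dtype=torch.bfloat16)
    u = torch.randn(batch, seqlen, d_model, device=device, dtype=torch.bfloat16)
    out = model(u)
    assert out.shape == u.shape  # forward preserves the (B, L, D) shape
    print(out.shape)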