# mamba-encoder-swarm_app / utils / selective_scan.py
# Provenance: uploaded by Debito ("Upload 3 files", commit 3fb2fb4, verified)
# =============================================================================
# utils/selective_scan.py
# =============================================================================
import torch
import torch.nn.functional as F
from typing import Tuple
def selective_scan_fn(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False):
    """
    Sequential selective scan — the state-space recurrence at the core of Mamba.

    Discretizes the continuous system per time step and runs the linear
    recurrence h_t = A_bar_t * h_{t-1} + B_bar_t * u_t, then projects each
    hidden state to the output channel via C.

    Args:
        u: input sequence [batch, seq_len, d_inner]
        delta: per-step time deltas [batch, seq_len, d_inner]
        A: state transition matrix [d_inner, d_state]
        B: input projection [batch, seq_len, d_state]
        C: output projection [batch, seq_len, d_state]
        D: optional skip-connection weights [d_inner]
        z: optional gate [batch, seq_len, d_inner]; output is multiplied by SiLU(z)
        delta_bias: optional additive bias on delta [d_inner]
        delta_softplus: if True, apply softplus to delta after biasing

    Returns:
        y: output sequence [batch, seq_len, d_inner]
    """
    batch, seq_len, d_inner = u.shape
    d_state = A.shape[1]

    if delta_bias is not None:
        delta = delta + delta_bias[None, None, :]
    if delta_softplus:
        delta = F.softplus(delta)

    # Discretization: A_bar = exp(delta * A) (exact exponential),
    # B_bar * u approximated as delta * B * u (Mamba's simplified Euler step).
    dt = delta.unsqueeze(-1)                    # [batch, seq_len, d_inner, 1]
    A_disc = torch.exp(dt * A)                  # [batch, seq_len, d_inner, d_state]
    Bu = dt * B.unsqueeze(2) * u.unsqueeze(-1)  # [batch, seq_len, d_inner, d_state]

    # Run the recurrence one step at a time, collecting per-step outputs.
    state = u.new_zeros(batch, d_inner, d_state)
    ys = []
    for t in range(seq_len):
        state = A_disc[:, t] * state + Bu[:, t]
        # Contract the hidden state over d_state to get the output channel.
        step_out = (state * C[:, t].unsqueeze(1)).sum(dim=-1)
        if D is not None:
            step_out = step_out + D * u[:, t]  # direct feed-through (skip) path
        ys.append(step_out)

    y = torch.stack(ys, dim=1)  # [batch, seq_len, d_inner]
    if z is not None:
        y = y * F.silu(z)  # gated output
    return y