# =============================================================================
# utils/selective_scan.py
# =============================================================================
import torch
import torch.nn.functional as F


def selective_scan_fn(u, delta, A, B, C, D=None, z=None, delta_bias=None,
                      delta_softplus=False):
    """
    Selective scan function - the core recurrence of Mamba's state space model.

    Args:
        u: input sequence [batch, seq_len, d_inner]
        delta: time step [batch, seq_len, d_inner]
        A: state matrix [d_inner, d_state]
        B: input matrix [batch, seq_len, d_state]
        C: output matrix [batch, seq_len, d_state]
        D: skip connection [d_inner] (optional)
        z: gating [batch, seq_len, d_inner] (optional)
        delta_bias: bias added to delta [d_inner] (optional)
        delta_softplus: whether to apply softplus to delta

    Returns:
        y: output [batch, seq_len, d_inner]
    """
    batch_size, seq_len, d_inner = u.shape
    d_state = A.shape[1]

    if delta_bias is not None:
        delta = delta + delta_bias[None, None, :]
    if delta_softplus:
        delta = F.softplus(delta)

    # Discretization: A uses the exact zero-order hold (A_bar = exp(delta * A));
    # B is folded together with u via the simplified Euler step
    # (B_bar * u ~= delta * B * u) used by Mamba.
    deltaA = torch.exp(delta.unsqueeze(-1) * A)  # [batch, seq_len, d_inner, d_state]
    deltaB_u = delta.unsqueeze(-1) * B.unsqueeze(2) * u.unsqueeze(-1)  # [batch, seq_len, d_inner, d_state]

    # Initialize hidden state
    h = torch.zeros(batch_size, d_inner, d_state, device=u.device, dtype=u.dtype)

    outputs = []
    for i in range(seq_len):
        h = deltaA[:, i] * h + deltaB_u[:, i]  # State update: h_t = A_bar_t * h_{t-1} + B_bar_t * u_t
        y = torch.sum(h * C[:, i].unsqueeze(1), dim=-1)  # Output projection: y_t = C_t . h_t
        if D is not None:
            y = y + D * u[:, i]  # Skip connection
        outputs.append(y)

    y = torch.stack(outputs, dim=1)  # [batch, seq_len, d_inner]

    if z is not None:
        y = y * F.silu(z)  # SiLU gating

    return y
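
# -----------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module): exercises
# selective_scan_fn with small, hypothetical shapes to show the expected
# tensor layouts and the output shape. All dimensions below are arbitrary.
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    batch, seq_len, d_inner, d_state = 2, 16, 64, 8

    u = torch.randn(batch, seq_len, d_inner)
    delta = torch.rand(batch, seq_len, d_inner)  # raw time steps; softplus applied below
    A = -torch.rand(d_inner, d_state)            # negative entries keep exp(delta * A) < 1 (stable scan)
    B = torch.randn(batch, seq_len, d_state)
    C = torch.randn(batch, seq_len, d_state)
    D = torch.randn(d_inner)
    z = torch.randn(batch, seq_len, d_inner)

    y = selective_scan_fn(u, delta, A, B, C, D=D, z=z, delta_softplus=True)
    print(y.shape)  # torch.Size([2, 16, 64])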