Debito's picture
Upload 8 files
055a9c8 verified
# =============================================================================
# core/stateSpace.py
# =============================================================================
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

from utils.selective_scan import selective_scan_fn
class StateSpaceModel(nn.Module):
    """Selective state-space layer built on top of ``selective_scan_fn``.

    The input is linearly projected into per-position ``delta``, ``B`` and
    ``C`` terms, which are passed to ``selective_scan_fn`` together with the
    learned ``A`` and ``D`` parameters.

    Args:
        d_inner: Size of the last (feature) dimension of the input.
        d_state: Width of the ``B`` and ``C`` projections and second
            dimension of ``A``.
        dt_rank: Width of the low-rank ``delta`` projection. When ``None`` it
            defaults to ``max(16, d_inner // 16)``.
        bias: Accepted for interface compatibility only; no layer in this
            module reads it (``x_proj`` is always bias-free and ``dt_proj``
            always carries a bias).
    """

    def __init__(
        self,
        d_inner: int,
        d_state: int = 16,
        dt_rank: Optional[int] = None,
        bias: bool = False,
    ):
        """Create the parameters and projection layers, then initialise them."""
        super().__init__()
        self.d_inner = d_inner
        self.d_state = d_state
        self.dt_rank = dt_rank if dt_rank is not None else max(16, d_inner // 16)

        # State space parameters.
        # The randn values are overwritten in _init_parameters(); randn is
        # kept (instead of torch.empty) so the RNG stream, and therefore the
        # seeded initialisation of the layers below, stays unchanged.
        self.A_log = nn.Parameter(torch.randn(d_inner, d_state))
        self.D = nn.Parameter(torch.ones(d_inner))

        # Projection layers: x -> (delta_low_rank, B, C) and
        # delta_low_rank -> delta.
        self.x_proj = nn.Linear(d_inner, self.dt_rank + d_state * 2, bias=False)
        self.dt_proj = nn.Linear(self.dt_rank, d_inner, bias=True)

        self._init_parameters()

    def _init_parameters(self):
        """Re-initialise ``A_log`` and the ``dt_proj`` bias in place."""
        # A_log is drawn from U(-4, -1); forward() uses A = -exp(A_log), which
        # is strictly negative for any real A_log.
        nn.init.uniform_(self.A_log, -4.0, -1.0)

        # The dt_proj bias is drawn symmetrically from U(-r, r) with
        # r = dt_rank ** -0.5.
        # NOTE(review): the previous comment said this "encourages large dt
        # values", but a zero-centred range does not bias dt in either
        # direction - confirm the intended initialisation.
        dt_init_std = self.dt_rank ** -0.5
        nn.init.uniform_(self.dt_proj.bias, -dt_init_std, dt_init_std)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the selective scan over ``x``.

        Args:
            x: Tensor of shape ``[batch, seq_len, d_inner]``.

        Returns:
            Whatever ``selective_scan_fn`` returns; documented by the original
            author as ``[batch, seq_len, d_inner]``.

        Raises:
            ValueError: If ``x`` is not 3-dimensional.
        """
        if x.dim() != 3:
            raise ValueError(
                "expected x of shape [batch, seq_len, d_inner], "
                f"got {tuple(x.shape)}"
            )

        # Project x and split the result into the low-rank delta, B and C.
        x_dbl = self.x_proj(x)  # [batch, seq_len, dt_rank + 2*d_state]
        delta, B, C = torch.split(
            x_dbl,
            [self.dt_rank, self.d_state, self.d_state],
            dim=-1,
        )

        # Expand delta from dt_rank to d_inner features.
        delta = self.dt_proj(delta)  # [batch, seq_len, d_inner]

        # A is kept in log space and negated here so it is always negative.
        A = -torch.exp(self.A_log)  # [d_inner, d_state]

        # NOTE(review): the tensor layout selective_scan_fn expects is not
        # visible in this file; inputs are passed as [batch, seq_len, ...] -
        # confirm this matches utils.selective_scan.
        # delta_softplus=True presumably makes the scan apply softplus to
        # delta itself - verify against the helper.
        return selective_scan_fn(
            u=x,
            delta=delta,
            A=A,
            B=B,
            C=C,
            D=self.D,
            delta_softplus=True,
        )