# =============================================================================
# core/stateSpace.py
# =============================================================================
import torch
import torch.nn as nn
import torch.nn.functional as F

from utils.selective_scan import selective_scan_fn
class StateSpaceModel(nn.Module):
    def __init__(self, d_inner: int, d_state: int = 16, dt_rank: int | None = None, bias: bool = False):
        super().__init__()
        self.d_inner = d_inner
        self.d_state = d_state
        self.dt_rank = dt_rank if dt_rank is not None else max(16, d_inner // 16)
        # State space parameters: A is stored in log form (see forward) and D
        # is a learned per-channel skip connection
        self.A_log = nn.Parameter(torch.randn(d_inner, d_state))
        self.D = nn.Parameter(torch.ones(d_inner))
        # Projection layers: x_proj produces delta, B, and C from the input
        # (its bias is controlled by the `bias` argument); dt_proj always keeps
        # a bias because the dt initialization below relies on it
        self.x_proj = nn.Linear(d_inner, self.dt_rank + d_state * 2, bias=bias)
        self.dt_proj = nn.Linear(self.dt_rank, d_inner, bias=True)
        # Initialize parameters
        self._init_parameters()
    def _init_parameters(self):
        # Initialize A_log so that A = -exp(A_log) (computed in forward) is a
        # small negative value, which keeps the recurrence stable
        nn.init.uniform_(self.A_log, -4.0, -1.0)
        # Initialize dt_proj bias with small values scaled by dt_rank; the
        # softplus applied inside the scan maps the resulting dt to strictly
        # positive step sizes
        dt_init_std = self.dt_rank**-0.5
        with torch.no_grad():
            self.dt_proj.bias.uniform_(-dt_init_std, dt_init_std)
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: [batch, seq_len, d_inner]
        Returns:
            y: [batch, seq_len, d_inner]
        """
        batch_size, seq_len, d_inner = x.shape
        # Project x to get delta, B, C
        x_dbl = self.x_proj(x)  # [batch, seq_len, dt_rank + 2*d_state]
        delta, B, C = torch.split(
            x_dbl,
            [self.dt_rank, self.d_state, self.d_state],
            dim=-1
        )
        # Project delta to d_inner
        delta = self.dt_proj(delta)  # [batch, seq_len, d_inner]
        # Recover A from its log parameterization; the negation guarantees
        # every entry is negative, keeping the discretized recurrence stable
        A = -torch.exp(self.A_log)  # [d_inner, d_state]
        # Apply the selective scan; the local utils implementation is assumed
        # to accept the [batch, seq_len, ...] layout used here (the fused
        # mamba_ssm kernel, by contrast, expects channel-first tensors)
        y = selective_scan_fn(
            u=x,
            delta=delta,
            A=A,
            B=B,
            C=C,
            D=self.D,
            delta_softplus=True
        )
        return y
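

# =============================================================================
# Reference sketch: naive selective scan (illustrative, not the fused kernel)
# =============================================================================
# The `selective_scan_fn` imported above lives in a local utils module whose
# source is not shown here. The function below is a minimal pure-PyTorch
# sketch of what such a scan is assumed to compute: the standard discretized
# (zero-order hold) state space recurrence
#     h_t = exp(dt_t * A) * h_{t-1} + dt_t * B_t * u_t
#     y_t = <h_t, C_t> + D * u_t
# in the same [batch, seq_len, ...] layout used by StateSpaceModel.forward.
# The name `naive_selective_scan` and this interface are assumptions for
# illustration; the sequential loop is far slower than a fused kernel but is
# handy as a correctness reference.
def naive_selective_scan(u, delta, A, B, C, D, delta_softplus=True):
    """u, delta: [b, l, d]; A: [d, n]; B, C: [b, l, n]; D: [d] -> y: [b, l, d]."""
    if delta_softplus:
        delta = F.softplus(delta)  # map dt to strictly positive step sizes
    batch, seq_len, d_inner = u.shape
    d_state = A.shape[-1]
    # Zero-order-hold discretization, broadcast over the state dimension
    deltaA = torch.exp(delta.unsqueeze(-1) * A)                       # [b, l, d, n]
    deltaBu = delta.unsqueeze(-1) * B.unsqueeze(2) * u.unsqueeze(-1)  # [b, l, d, n]
    h = u.new_zeros(batch, d_inner, d_state)
    ys = []
    for t in range(seq_len):
        h = deltaA[:, t] * h + deltaBu[:, t]           # recurrent state update
        ys.append((h * C[:, t].unsqueeze(1)).sum(-1))  # read out state with C_t
    y = torch.stack(ys, dim=1)                         # [b, l, d]
    return y + u * D                                   # per-channel skip via D


if __name__ == "__main__":
    # Smoke test for the reference scan with small, hypothetical sizes
    b, l, d, n = 2, 32, 64, 16
    u = torch.randn(b, l, d)
    delta = torch.randn(b, l, d)
    A = -torch.exp(torch.randn(d, n))  # negative entries, as in forward
    B_in, C_in = torch.randn(b, l, n), torch.randn(b, l, n)
    D_skip = torch.ones(d)
    y = naive_selective_scan(u, delta, A, B_in, C_in, D_skip)
    print(y.shape)  # torch.Size([2, 32, 64])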