# =============================================================================
# utils/selective_scan.py
# =============================================================================
import torch
import torch.nn.functional as F

def selective_scan_fn(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False):
    """

    Selective scan function - core of Mamba's state space model

    

    Args:

        u: input sequence [batch, seq_len, d_inner]

        delta: time step [batch, seq_len, d_inner]

        A: state matrix [d_inner, d_state]

        B: input matrix [batch, seq_len, d_state]

        C: output matrix [batch, seq_len, d_state]

        D: skip connection [d_inner]

        z: gating [batch, seq_len, d_inner] (optional)

        delta_bias: bias for delta (optional)

        delta_softplus: whether to apply softplus to delta

    

    Returns:

        y: output [batch, seq_len, d_inner]

    """
    batch_size, seq_len, d_inner = u.shape
    d_state = A.shape[1]
    
    if delta_bias is not None:
        delta = delta + delta_bias[None, None, :]
    
    if delta_softplus:
        delta = F.softplus(delta)
    
    # Discretization: zero-order hold for A (A_bar = exp(delta * A)); B is
    # discretized with the simplified Euler step, so B_bar * u is approximated as delta * B * u.
    deltaA = torch.exp(delta.unsqueeze(-1) * A)  # [batch, seq_len, d_inner, d_state]
    deltaB_u = delta.unsqueeze(-1) * B.unsqueeze(2) * u.unsqueeze(-1)  # [batch, seq_len, d_inner, d_state]
    
    # Initialize hidden state
    h = torch.zeros(batch_size, d_inner, d_state, device=u.device, dtype=u.dtype)
    
    outputs = []
    for i in range(seq_len):
        h = deltaA[:, i] * h + deltaB_u[:, i]  # State update
        y = torch.sum(h * C[:, i].unsqueeze(1), dim=-1)  # Output projection
        if D is not None:
            y = y + D * u[:, i]
        outputs.append(y)
    
    y = torch.stack(outputs, dim=1)  # [batch, seq_len, d_inner]
    
    if z is not None:
        y = y * F.silu(z)
    
    return y
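

# =============================================================================
# Example usage - an illustrative sketch only. The tensor sizes below are
# arbitrary placeholders, not taken from any particular Mamba configuration.
# =============================================================================
if __name__ == "__main__":
    batch, seq_len, d_inner, d_state = 2, 16, 8, 4

    u = torch.randn(batch, seq_len, d_inner)
    delta = torch.randn(batch, seq_len, d_inner)
    # A is typically kept negative so exp(delta * A) stays contractive.
    A = -torch.rand(d_inner, d_state)
    B = torch.randn(batch, seq_len, d_state)
    C = torch.randn(batch, seq_len, d_state)
    D = torch.randn(d_inner)
    z = torch.randn(batch, seq_len, d_inner)

    y = selective_scan_fn(u, delta, A, B, C, D=D, z=z, delta_softplus=True)
    print(y.shape)  # torch.Size([2, 16, 8])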