Spaces:
Sleeping
Sleeping
# =============================================================================
# utils/conv_layer.py
# =============================================================================
import torch | |
import torch.nn as nn | |
class Mamba1DConv(nn.Module):
    """Depthwise causal 1D convolution over the sequence dimension.

    Each of the ``d_inner`` channels is convolved independently
    (``groups=d_inner``) with its own kernel of length ``d_conv``. The input
    is padded by ``d_conv - 1`` steps and the trailing surplus is trimmed, so
    the output at step ``t`` depends only on input steps ``<= t`` and the
    sequence length is preserved.

    Args:
        d_inner: Number of channels (used for both input and output).
        d_conv: Kernel length of the convolution.
        bias: Whether the convolution has a learnable bias.
    """

    def __init__(self, d_inner: int, d_conv: int = 4, bias: bool = True):
        super().__init__()
        self.d_conv = d_conv
        self.conv1d = nn.Conv1d(
            in_channels=d_inner,
            out_channels=d_inner,
            kernel_size=d_conv,
            bias=bias,
            groups=d_inner,  # Depthwise convolution
            padding=d_conv - 1,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the causal depthwise convolution.

        Args:
            x: [batch, seq_len, d_inner]

        Returns:
            x: [batch, seq_len, d_inner]
        """
        seq_len = x.shape[1]
        # Conv1d expects [batch, channels, seq_len]
        x = x.transpose(1, 2)  # [batch, d_inner, seq_len]
        x = self.conv1d(x)  # [batch, d_inner, seq_len + d_conv - 1]
        # Keep only the first seq_len steps to drop the right-hand padding.
        # Slicing by length instead of ``:-(d_conv - 1)`` stays correct for
        # d_conv == 1, where the negative-index form would be ``:-0`` and
        # return an empty tensor.
        x = x[:, :, :seq_len]
        x = x.transpose(1, 2)  # [batch, seq_len, d_inner]
        return x