# Debito's picture
# Upload 3 files
# 3fb2fb4 verified
# =============================================================================
# utils/conv_layer.py
# =============================================================================
import torch
import torch.nn as nn
class Mamba1DConv(nn.Module):
    """Depthwise causal 1D convolution over the sequence axis.

    Each of the ``d_inner`` channels is convolved independently
    (``groups=d_inner``) with a kernel of length ``d_conv``.  The convolution
    pads ``d_conv - 1`` positions on both ends and the surplus trailing
    outputs are dropped, so output position ``t`` only depends on input
    positions ``t - d_conv + 1 .. t`` and the sequence length is preserved.
    """

    def __init__(self, d_inner: int, d_conv: int = 4, bias: bool = True):
        """
        Args:
            d_inner: number of channels (input and output).
            d_conv: kernel length along the sequence axis; must be >= 1.
            bias: whether the convolution has a learnable bias.
        Raises:
            ValueError: if ``d_conv`` is smaller than 1.
        """
        super().__init__()
        if d_conv < 1:
            raise ValueError(f"d_conv must be >= 1, got {d_conv}")
        self.d_conv = d_conv
        self.conv1d = nn.Conv1d(
            in_channels=d_inner,
            out_channels=d_inner,
            kernel_size=d_conv,
            bias=bias,
            groups=d_inner,  # Depthwise convolution
            padding=d_conv - 1,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: [batch, seq_len, d_inner]
        Returns:
            x: [batch, seq_len, d_inner]
        """
        seq_len = x.shape[1]
        # Conv1d expects [batch, channels, seq_len]
        x = x.transpose(1, 2)  # [batch, d_inner, seq_len]
        x = self.conv1d(x)  # [batch, d_inner, seq_len + d_conv - 1]
        # Keep only the first seq_len outputs (drops the right-hand padding).
        # Slicing by length instead of `:-(d_conv - 1)` stays correct when
        # d_conv == 1, where the negative-stop form would be `:0` (empty).
        x = x[:, :, :seq_len]
        x = x.transpose(1, 2)  # [batch, seq_len, d_inner]
        return x