Spaces:
Running
on
Zero
Running
on
Zero
| import torch | |
| import torch.nn as nn | |
class MossFormerDecoder(nn.ConvTranspose1d):
    """A decoder layer that consists of ConvTranspose1d.

    All constructor arguments are forwarded unchanged to
    ``torch.nn.ConvTranspose1d``.

    Arguments
    ---------
    kernel_size : int
        Length of filters.
    in_channels : int
        Number of input channels.
    out_channels : int
        Number of output channels.

    Example
    -------
    >>> x = torch.randn(2, 100, 1000)
    >>> decoder = MossFormerDecoder(kernel_size=4, in_channels=100, out_channels=1)
    >>> h = decoder(x)
    >>> h.shape
    torch.Size([2, 1003])
    """

    def __init__(self, *args, **kwargs):
        super(MossFormerDecoder, self).__init__(*args, **kwargs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the decoded output.

        Arguments
        ---------
        x : torch.Tensor
            Input tensor with dimensionality [B, N, L],
            where B = batch size, N = number of filters, L = time points.
            A 2-D input [B, L] is reshaped to [B, 1, L] before decoding.

        Returns
        -------
        torch.Tensor
            The transposed-convolution output with singleton dimensions
            squeezed. If squeezing every singleton dimension would leave a
            1-D tensor (e.g. batch size 1 with one output channel), only
            dimension 1 is squeezed, so a [1, 1, T] output becomes [1, T].

        Raises
        ------
        RuntimeError
            If ``x`` is neither 2-D nor 3-D.
        """
        if x.dim() not in (2, 3):
            # Use type(self): nn.Module instances have no ``__name__``
            # attribute, so ``self.__name__`` raised AttributeError here.
            raise RuntimeError(
                "{} accepts 2/3D tensor as input".format(type(self).__name__)
            )
        if x.dim() == 2:
            # [B, L] -> [B, 1, L]; the convolution then only accepts this
            # input when the layer was built with in_channels == 1.
            x = torch.unsqueeze(x, 1)
        x = super().forward(x)

        squeezed = torch.squeeze(x)
        if squeezed.dim() == 1:
            # A full squeeze would also drop the batch dimension; squeeze
            # only dimension 1 so a [1, 1, T] output becomes [1, T].
            return torch.squeeze(x, dim=1)
        return squeezed