# Copyright (c) ByteDance, Inc. and its affiliates.
# Copyright (c) Chutong Meng
#
# This source code is licensed under the CC BY-NC license found in the
# LICENSE file in the root directory of this source tree.
# Based on AudioDec (https://github.com/facebookresearch/AudioDec)

import torch.nn as nn


class Conv1d1x1(nn.Conv1d):
    """1x1 Conv1d.

    A plain ``nn.Conv1d`` whose kernel size is fixed to 1; every other
    convolution hyper-parameter keeps the ``nn.Conv1d`` default.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        bias: Whether the convolution has a learnable bias.
    """

    def __init__(self, in_channels: int, out_channels: int, bias: bool = True):
        super().__init__(in_channels, out_channels, kernel_size=1, bias=bias)


class Conv1d(nn.Module):
    """Wrapper around ``nn.Conv1d`` that derives the padding when not given.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Size of the convolution kernel.
        stride: Stride of the convolution.
        padding: Zero-padding added to both sides. Any negative value
            (the default) selects ``(kernel_size - 1) // 2 * dilation``,
            which keeps the length unchanged for an odd ``kernel_size``
            with ``stride=1``.
        dilation: Spacing between kernel elements.
        groups: Number of blocked connections from input to output channels.
        bias: Whether the convolution has a learnable bias.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        padding: int = -1,
        dilation: int = 1,
        groups: int = 1,
        bias: bool = True,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        # A negative padding is the "choose for me" sentinel.
        if padding < 0:
            padding = (kernel_size - 1) // 2 * dilation
        self.dilation = dilation
        self.conv = nn.Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )

    def forward(self, x):
        """Apply the wrapped convolution.

        Args:
            x (Tensor): Float tensor variable with the shape (B, C, T).

        Returns:
            Tensor: Float tensor variable with the shape (B, C', T'), where
                C' is ``out_channels`` and T' follows from the stride,
                padding, dilation and kernel size given at construction.
        """
        return self.conv(x)


class ConvTranspose1d(nn.Module):
    """Wrapper around ``nn.ConvTranspose1d`` with derived padding defaults.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Size of the convolution kernel.
        stride: Stride (upsampling factor) of the transposed convolution.
        padding: Padding passed to ``nn.ConvTranspose1d``. Any negative
            value (the default) selects ``(stride + 1) // 2``.
        output_padding: Extra size added to one side of the output. Any
            negative value (the default) selects ``1`` for an odd stride
            and ``0`` for an even one.
        groups: Number of blocked connections from input to output channels.
        bias: Whether the convolution has a learnable bias.

    With both defaults and ``kernel_size == 2 * stride`` the output length
    works out to ``T * stride`` under the ``nn.ConvTranspose1d`` length
    formula.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int,
        padding: int = -1,
        output_padding: int = -1,
        groups: int = 1,
        bias: bool = True,
    ):
        super().__init__()
        # Recorded for parity with Conv1d; plain ints, so state_dict is unaffected.
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        # Negative values are the "choose for me" sentinels.
        if padding < 0:
            padding = (stride + 1) // 2
        if output_padding < 0:
            # NOTE(review): for stride == 1 this yields output_padding == 1;
            # torch appears to require output_padding < max(stride, dilation),
            # so that combination may fail at forward time -- confirm it is
            # never used with the defaults.
            output_padding = 1 if stride % 2 else 0
        self.deconv = nn.ConvTranspose1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            output_padding=output_padding,
            groups=groups,
            bias=bias,
        )

    def forward(self, x):
        """Apply the wrapped transposed convolution.

        Args:
            x (Tensor): Float tensor variable with the shape (B, C, T).

        Returns:
            Tensor: Float tensor variable with the shape (B, C', T').
        """
        return self.deconv(x)