# Copyright (c) ByteDance, Inc. and its affiliates.
# Copyright (c) Chutong Meng
#
# This source code is licensed under the CC BY-NC license found in the
# LICENSE file in the root directory of this source tree.
# Based on AudioDec (https://github.com/facebookresearch/AudioDec)

import torch
import torch.nn as nn

from RepCodec.repcodec.layers.conv_layer import Conv1d, ConvTranspose1d
from RepCodec.repcodec.modules.residual_unit import ResidualUnit


class DecoderBlock(nn.Module):
    """Decoder block: one convolution followed by a stack of residual units.

    The leading convolution depends on ``stride``:

    * ``stride == 1``: a plain ``Conv1d`` with a fixed kernel size of 3.
    * otherwise: a ``ConvTranspose1d`` with kernel size ``2 * stride`` and the
      given stride.

    Args:
        in_channels: Number of channels of the incoming feature map.
        out_channels: Number of channels produced by the leading convolution
            and kept by every residual unit.
        stride: Stride of the leading convolution; selects the layer type as
            described above.
        dilations: One residual unit is created per entry, using that entry
            as its dilation.
        unit_kernel_size: Kernel size forwarded to every residual unit.
        bias: Whether the leading convolution uses a bias term.
    """

    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            stride: int,
            dilations=(1, 1),
            unit_kernel_size=3,
            bias=True
    ) -> None:
        super().__init__()

        if stride == 1:
            self.conv = Conv1d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=3,  # fix kernel=3 when stride=1 for unchanged shape
                stride=stride,
                bias=bias,
            )
        else:
            self.conv = ConvTranspose1d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=(2 * stride),
                stride=stride,
                bias=bias,
            )

        # Units are built in the order of `dilations`, after `self.conv`, so
        # parameter names (`res_units.<i>...`) and the order in which modules
        # are initialised stay the same as before.
        self.res_units = torch.nn.ModuleList(
            [
                ResidualUnit(
                    out_channels,
                    out_channels,
                    kernel_size=unit_kernel_size,
                    dilation=dilation,
                )
                for dilation in dilations
            ]
        )
        # Kept as a public attribute for code that reads it.
        self.num_res = len(self.res_units)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the leading convolution, then each residual unit in order."""
        x = self.conv(x)
        for res_unit in self.res_units:
            x = res_unit(x)
        return x


class Decoder(nn.Module):
    """Decoder: input convolution, a chain of ``DecoderBlock``s, output convolution.

    Block ``i`` takes ``int(decode_channels * channel_ratios[i])`` channels and
    emits ``int(decode_channels * channel_ratios[i + 1])`` channels, except the
    last block, which emits ``decode_channels`` channels.

    Args:
        code_dim: Number of channels of the input ``z``.
        output_channels: Number of channels produced by the final convolution.
        decode_channels: Base channel width, scaled by ``channel_ratios``.
        channel_ratios: Per-block multipliers applied to ``decode_channels``.
            Must have the same length as ``strides``.
        strides: Stride of each ``DecoderBlock``; one block is built per entry.
        kernel_size: Kernel size of the input and output convolutions.
        bias: Whether the leading convolution of each block uses a bias term
            (the input and output convolutions never do).
        block_dilations: Dilations forwarded to every ``DecoderBlock``.
        unit_kernel_size: Residual-unit kernel size forwarded to every block.

    Raises:
        AssertionError: If ``channel_ratios`` and ``strides`` differ in length.
    """

    def __init__(
            self,
            code_dim: int,
            output_channels: int,
            decode_channels: int,
            channel_ratios=(1, 1),
            strides=(1, 1),
            kernel_size=3,
            bias=True,
            block_dilations=(1, 1),
            unit_kernel_size=3,
    ) -> None:
        super().__init__()
        assert len(channel_ratios) == len(strides), (
            f"channel_ratios and strides must have the same length, "
            f"got {len(channel_ratios)} and {len(strides)}"
        )

        self.conv1 = Conv1d(
            in_channels=code_dim,
            out_channels=int(decode_channels * channel_ratios[0]),
            kernel_size=kernel_size,
            stride=1,
            bias=False
        )

        last_idx = len(channel_ratios) - 1
        blocks = []
        for idx, stride in enumerate(strides):
            in_channels = int(decode_channels * channel_ratios[idx])
            if idx < last_idx:
                out_channels = int(decode_channels * channel_ratios[idx + 1])
            else:
                # The final block always returns to the base width.
                out_channels = decode_channels
            blocks.append(
                DecoderBlock(
                    in_channels,
                    out_channels,
                    stride,
                    dilations=block_dilations,
                    unit_kernel_size=unit_kernel_size,
                    bias=bias,
                )
            )
        self.conv_blocks = torch.nn.ModuleList(blocks)
        # Kept as a public attribute for code that reads it.
        self.num_blocks = len(self.conv_blocks)

        # `out_channels` here is the width emitted by the last block above.
        self.conv2 = Conv1d(out_channels, output_channels, kernel_size, 1, bias=False)

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        """Run ``z`` through the input convolution, every block, and the output convolution."""
        x = self.conv1(z)
        for block in self.conv_blocks:
            x = block(x)
        x = self.conv2(x)
        return x