import os
import sys

import torch
import torch.nn.functional as F
from torch.nn.utils import remove_weight_norm
from torch.nn.utils.parametrizations import weight_norm
from torch.nn.utils.parametrize import is_parametrized, remove_parametrizations

sys.path.append(os.getcwd())

from main.library.algorithm.commons import init_weights
from main.library.algorithm.residuals import ResBlock, LRELU_SLOPE


class HiFiGANGenerator(torch.nn.Module):
    """Upsampling 1-D convolutional generator with a HiFi-GAN style layout.

    Pipeline (see ``forward``):
    ``conv_pre`` -> for each stage [leaky ReLU -> weight-normed ConvTranspose1d
    -> average of ``num_kernels`` ResBlocks] -> leaky ReLU -> ``conv_post``
    (1 output channel) -> tanh.

    Args:
        initial_channel: Channel count of the input passed to ``forward``.
        resblock_kernel_sizes: Kernel size of each ResBlock; every stage gets
            one ResBlock per entry.
        resblock_dilation_sizes: Dilation setting passed to each ResBlock,
            paired with ``resblock_kernel_sizes``.
        upsample_rates: Stride of each transposed convolution; one stage per
            entry.
        upsample_initial_channel: Channel width after ``conv_pre``; each stage
            floor-halves it.
        upsample_kernel_sizes: Kernel size of each transposed convolution,
            paired with ``upsample_rates``.
        gin_channels: Channel count of the optional conditioning tensor ``g``;
            0 means no conditioning layer is created.

    Raises:
        ValueError: If ``upsample_kernel_sizes`` has fewer entries than
            ``upsample_rates``, or ``resblock_dilation_sizes`` fewer than
            ``resblock_kernel_sizes``. Such configs used to build a module
            list that ``forward`` indexes incorrectly.
    """

    def __init__(
        self,
        initial_channel,
        resblock_kernel_sizes,
        resblock_dilation_sizes,
        upsample_rates,
        upsample_initial_channel,
        upsample_kernel_sizes,
        gin_channels=0,
    ):
        super().__init__()
        # forward() consumes exactly num_upsamples stages of (1 + num_kernels)
        # modules. zip() below would silently truncate a shorter partner list
        # and desynchronise that walk, so reject it up front.
        if len(upsample_kernel_sizes) < len(upsample_rates):
            raise ValueError("upsample_kernel_sizes needs one entry per upsample rate")
        if len(resblock_dilation_sizes) < len(resblock_kernel_sizes):
            raise ValueError("resblock_dilation_sizes needs one entry per resblock kernel size")

        self.num_kernels = len(resblock_kernel_sizes)
        self.num_upsamples = len(upsample_rates)
        self.conv_pre = torch.nn.Conv1d(initial_channel, upsample_initial_channel, 7, 1, padding=3)

        # Flat list laid out per stage as
        #   [upsample, resblock_0, ..., resblock_{num_kernels-1}]
        # forward() walks it with a running index, so this order is load-bearing
        # (it also fixes the state_dict key names).
        self.ups_and_resblocks = torch.nn.ModuleList()
        # Channel width feeding conv_post; stays defined when there are no stages.
        ch = upsample_initial_channel
        for i, (rate, up_kernel) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
            in_ch = upsample_initial_channel // (2**i)
            ch = upsample_initial_channel // (2 ** (i + 1))
            self.ups_and_resblocks.append(
                weight_norm(
                    torch.nn.ConvTranspose1d(
                        in_ch, ch, up_kernel, rate, padding=(up_kernel - rate) // 2
                    )
                )
            )
            for res_kernel, res_dilation in zip(resblock_kernel_sizes, resblock_dilation_sizes):
                self.ups_and_resblocks.append(ResBlock(ch, res_kernel, res_dilation))

        self.conv_post = torch.nn.Conv1d(ch, 1, 7, 1, padding=3, bias=False)
        # init_weights is applied to the upsample/ResBlock list only;
        # conv_pre and conv_post keep PyTorch's default initialisation.
        self.ups_and_resblocks.apply(init_weights)

        if gin_channels != 0:
            self.cond = torch.nn.Conv1d(gin_channels, upsample_initial_channel, 1)

    def forward(self, x, g=None):
        """Run the generator.

        Args:
            x: Input with ``initial_channel`` channels in Conv1d layout
                (presumably ``(batch, initial_channel, frames)`` — confirm with
                callers).
            g: Optional conditioning with ``gin_channels`` channels. It is added
                to the ``conv_pre`` output, so it must broadcast against it
                (presumably ``(batch, gin_channels, 1)`` — TODO confirm).
                Passing ``g`` when the model was built with ``gin_channels=0``
                fails because ``self.cond`` does not exist.

        Returns:
            A 1-channel tensor squashed to [-1, 1] by ``tanh``.
        """
        x = self.conv_pre(x)
        if g is not None:
            x = x + self.cond(g)

        layer_idx = 0
        for _ in range(self.num_upsamples):
            x = self.ups_and_resblocks[layer_idx](F.leaky_relu(x, LRELU_SLOPE))
            layer_idx += 1
            # Average the outputs of this stage's ResBlocks (all fed the same x).
            xs = 0
            for _ in range(self.num_kernels):
                xs += self.ups_and_resblocks[layer_idx](x)
                layer_idx += 1
            x = xs / self.num_kernels

        # NOTE(review): this activation uses PyTorch's default negative slope,
        # not LRELU_SLOPE. Trained checkpoints depend on it, so confirm before
        # changing.
        return torch.tanh(self.conv_post(F.leaky_relu(x)))

    @staticmethod
    def _strip_weight_norm(layer):
        """Remove weight norm applied directly to ``layer.weight``, if any.

        Handles both the parametrization-based API used in ``__init__``
        (``torch.nn.utils.parametrizations.weight_norm``) and the legacy
        hook-based ``torch.nn.utils.weight_norm``.

        Returns:
            True if a weight norm was found and removed, else False.
        """
        if is_parametrized(layer, "weight"):
            # Bake the current normalised weight into a plain Parameter.
            remove_parametrizations(layer, "weight", leave_parametrized=True)
            return True
        # Snapshot the hooks: the legacy remover deletes from this dict.
        for hook in list(layer._forward_pre_hooks.values()):
            if (
                hook.__module__ == "torch.nn.utils.weight_norm"
                and hook.__class__.__name__ == "WeightNorm"
                and getattr(hook, "name", None) == "weight"
            ):
                remove_weight_norm(layer)
                return True
        return False

    def __prepare_scriptable__(self):
        """Strip weight norm from the upsample layers before TorchScript scripting.

        Parametrized modules are not scriptable, so weight norm is removed in
        place. ResBlocks are left to their own handling. Calling this again is
        a no-op. Returns ``self``.
        """
        for layer in self.ups_and_resblocks:
            self._strip_weight_norm(layer)
        return self

    def remove_weight_norm(self):
        """Permanently remove weight norm from every upsample layer and ResBlock.

        Layers whose own ``weight`` is normalised are stripped directly. Other
        layers delegate to their own ``remove_weight_norm`` method when they
        define one (presumably the ResBlocks — confirm in residuals.py).
        Anything else goes through the legacy remover, which raises
        ``ValueError`` exactly as before.
        """
        for layer in self.ups_and_resblocks:
            if self._strip_weight_norm(layer):
                continue
            if hasattr(layer, "remove_weight_norm"):
                layer.remove_weight_norm()
            else:
                # Module-level torch helper (not this method). No weight norm
                # was found above, so torch raises its usual ValueError here.
                remove_weight_norm(layer)