|
import os
|
|
import sys
|
|
import torch
|
|
import torch.nn.functional as F
|
|
|
|
from torch.nn.utils import remove_weight_norm
|
|
from torch.nn.utils.parametrizations import weight_norm
|
|
|
|
sys.path.append(os.getcwd())
|
|
|
|
from main.library.algorithm.commons import init_weights
|
|
from main.library.algorithm.residuals import ResBlock, LRELU_SLOPE
|
|
|
|
class HiFiGANGenerator(torch.nn.Module):
|
|
def __init__(self, initial_channel, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels=0):
|
|
super(HiFiGANGenerator, self).__init__()
|
|
self.num_kernels = len(resblock_kernel_sizes)
|
|
self.num_upsamples = len(upsample_rates)
|
|
self.conv_pre = torch.nn.Conv1d(initial_channel, upsample_initial_channel, 7, 1, padding=3)
|
|
self.ups_and_resblocks = torch.nn.ModuleList()
|
|
|
|
for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
|
|
self.ups_and_resblocks.append(weight_norm(torch.nn.ConvTranspose1d(upsample_initial_channel // (2**i), upsample_initial_channel // (2 ** (i + 1)), k, u, padding=(k - u) // 2)))
|
|
ch = upsample_initial_channel // (2 ** (i + 1))
|
|
for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
|
|
self.ups_and_resblocks.append(ResBlock(ch, k, d))
|
|
|
|
self.conv_post = torch.nn.Conv1d(ch, 1, 7, 1, padding=3, bias=False)
|
|
self.ups_and_resblocks.apply(init_weights)
|
|
if gin_channels != 0: self.cond = torch.nn.Conv1d(gin_channels, upsample_initial_channel, 1)
|
|
|
|
def forward(self, x, g = None):
|
|
x = self.conv_pre(x)
|
|
if g is not None: x = x + self.cond(g)
|
|
|
|
resblock_idx = 0
|
|
|
|
for _ in range(self.num_upsamples):
|
|
x = self.ups_and_resblocks[resblock_idx](F.leaky_relu(x, LRELU_SLOPE))
|
|
resblock_idx += 1
|
|
xs = 0
|
|
|
|
for _ in range(self.num_kernels):
|
|
xs += self.ups_and_resblocks[resblock_idx](x)
|
|
resblock_idx += 1
|
|
|
|
x = xs / self.num_kernels
|
|
|
|
return torch.tanh(self.conv_post(F.leaky_relu(x)))
|
|
|
|
def __prepare_scriptable__(self):
|
|
for l in self.ups_and_resblocks:
|
|
for hook in l._forward_pre_hooks.values():
|
|
if (hook.__module__ == "torch.nn.utils.parametrizations.weight_norm" and hook.__class__.__name__ == "WeightNorm"): torch.nn.utils.remove_weight_norm(l)
|
|
|
|
return self
|
|
|
|
def remove_weight_norm(self):
|
|
for l in self.ups_and_resblocks:
|
|
remove_weight_norm(l) |