"""NSF-HiFiGAN generator: a sine-excitation source module (SineGen /
SourceModuleHnNSF) feeding a HiFi-GAN-style upsampling stack."""

import os
import sys
import math

import torch
import numpy as np
import torch.nn.functional as F

from torch.utils.checkpoint import checkpoint
from torch.nn.utils.parametrizations import weight_norm
# weight_norm above is the parametrization-based API, so its inverse is
# remove_parametrizations, not the legacy torch.nn.utils.remove_weight_norm.
from torch.nn.utils.parametrize import remove_parametrizations

sys.path.append(os.getcwd())

from main.library.algorithm.commons import init_weights
from main.library.algorithm.residuals import ResBlock, LRELU_SLOPE

class SineGen(torch.nn.Module):
    """Sine-wave source generator: emits a fundamental plus `harmonic_num`
    overtones; voiced frames carry sine + low noise, unvoiced frames noise only."""

    def __init__(self, samp_rate, harmonic_num=0, sine_amp=0.1, noise_std=0.003, voiced_threshold=0, flag_for_pulse=False):
        super(SineGen, self).__init__()
        self.sine_amp = sine_amp
        self.noise_std = noise_std
        self.harmonic_num = harmonic_num
        self.dim = self.harmonic_num + 1  # fundamental + harmonics
        self.sampling_rate = samp_rate
        self.voiced_threshold = voiced_threshold

    def _f02uv(self, f0):
        # Voiced/unvoiced mask: 1.0 where f0 exceeds the threshold, else 0.0.
        return torch.ones_like(f0) * (f0 > self.voiced_threshold)

    def _f02sine(self, f0, upp):
        # Per-sample phase within each frame, upsampled by a factor of `upp`.
        rad = f0 / self.sampling_rate * torch.arange(1, upp + 1, dtype=f0.dtype, device=f0.device)
        # Carry the fractional phase of each frame's last sample across frame
        # boundaries (cumulative sum) so the sinusoid stays continuous.
        rad += F.pad((torch.fmod(rad[:, :-1, -1:].float() + 0.5, 1.0) - 0.5).cumsum(dim=1).fmod(1.0).to(f0), (0, 0, 1, 0), mode="constant")
        rad = rad.reshape(f0.shape[0], -1, 1)
        # Scale the phase track for each harmonic: 1x, 2x, ..., dim x the fundamental.
        rad *= torch.arange(1, self.dim + 1, dtype=f0.dtype, device=f0.device).reshape(1, 1, -1)
        # Random initial phase per harmonic; the fundamental keeps phase 0.
        rand_ini = torch.rand(1, 1, self.dim, device=f0.device)
        rand_ini[..., 0] = 0
        rad += rand_ini

        return torch.sin(2 * np.pi * rad)

    def forward(self, f0, upp):
        with torch.no_grad():
            f0 = f0.unsqueeze(-1)
            sine_waves = self._f02sine(f0, upp) * self.sine_amp
            # Upsample the frame-level voiced/unvoiced mask to sample level.
            uv = F.interpolate(self._f02uv(f0).transpose(2, 1), scale_factor=float(upp), mode="nearest").transpose(2, 1)
            # Voiced: sine + gaussian noise (std = noise_std); unvoiced: noise
            # only, at roughly a third of the sine amplitude.
            sine_waves = sine_waves * uv + ((uv * self.noise_std + (1 - uv) * self.sine_amp / 3) * torch.randn_like(sine_waves))

        return sine_waves

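# Shape sketch (illustrative, not from the original source): given frame-level
# F0 of shape (B, T), SineGen.forward(f0, upp) unsqueezes to (B, T, 1) and
# returns (B, T * upp, harmonic_num + 1) sample-level sinusoids, which
# SourceModuleHnNSF below reduces to a single (B, T * upp, 1) excitation.
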
class SourceModuleHnNSF(torch.nn.Module):
    """Harmonic-plus-noise source module: collapses the stacked SineGen
    harmonics into a single excitation channel."""

    def __init__(self, sample_rate, harmonic_num=0, sine_amp=0.1, add_noise_std=0.003, voiced_threshold=0):
        super(SourceModuleHnNSF, self).__init__()
        self.sine_amp = sine_amp
        self.noise_std = add_noise_std
        self.l_sin_gen = SineGen(sample_rate, harmonic_num, sine_amp, add_noise_std, voiced_threshold)
        # Merge fundamental + harmonics into one channel, then squash with tanh.
        self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
        self.l_tanh = torch.nn.Tanh()

    def forward(self, x, upsample_factor=1):
        return self.l_tanh(self.l_linear(self.l_sin_gen(x, upsample_factor).to(dtype=self.l_linear.weight.dtype)))

class HiFiGANNRFGenerator(torch.nn.Module):
    """HiFi-GAN generator driven by an NSF-style harmonic source."""

    def __init__(self, initial_channel, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels, sr, checkpointing=False):
        super(HiFiGANNRFGenerator, self).__init__()
        self.num_kernels = len(resblock_kernel_sizes)
        self.num_upsamples = len(upsample_rates)
        self.upp = math.prod(upsample_rates)  # total frames -> samples factor
        self.f0_upsamp = torch.nn.Upsample(scale_factor=self.upp)  # not used in forward
        self.m_source = SourceModuleHnNSF(sample_rate=sr, harmonic_num=0)

        self.conv_pre = torch.nn.Conv1d(initial_channel, upsample_initial_channel, 7, 1, padding=3)
        self.checkpointing = checkpointing

        self.ups = torch.nn.ModuleList()
        self.noise_convs = torch.nn.ModuleList()

        # Channel count halves at every upsampling stage.
        channels = [upsample_initial_channel // (2 ** (i + 1)) for i in range(self.num_upsamples)]
        # Stride that downsamples the sample-rate harmonic source to the
        # temporal resolution of stage i (product of the remaining rates).
        stride_f0s = [math.prod(upsample_rates[i + 1:]) if i + 1 < self.num_upsamples else 1 for i in range(self.num_upsamples)]

        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
            self.ups.append(weight_norm(torch.nn.ConvTranspose1d(upsample_initial_channel // (2**i), channels[i], k, u, padding=((k - u) // 2) if u % 2 == 0 else (u // 2 + u % 2), output_padding=u % 2)))

            stride = stride_f0s[i]
            kernel = 1 if stride == 1 else stride * 2 - stride % 2
            self.noise_convs.append(torch.nn.Conv1d(1, channels[i], kernel_size=kernel, stride=stride, padding=0 if stride == 1 else (kernel - stride) // 2))

        # num_kernels parallel ResBlocks per upsampling stage.
        self.resblocks = torch.nn.ModuleList([ResBlock(channels[i], k, d) for i in range(len(self.ups)) for k, d in zip(resblock_kernel_sizes, resblock_dilation_sizes)])
        self.conv_post = torch.nn.Conv1d(channels[-1], 1, 7, 1, padding=3, bias=False)

        self.ups.apply(init_weights)
        if gin_channels != 0:
            self.cond = torch.nn.Conv1d(gin_channels, upsample_initial_channel, 1)

    def forward(self, x, f0, g=None):
        # Harmonic excitation at sample rate: (B, 1, T * upp).
        har_source = self.m_source(f0, self.upp).transpose(1, 2)
        x = self.conv_pre(x)
        if g is not None:
            x = x + self.cond(g)  # global (e.g. speaker) conditioning

        for i, (ups, noise_convs) in enumerate(zip(self.ups, self.noise_convs)):
            x = F.leaky_relu(x, LRELU_SLOPE)
            # The num_kernels ResBlocks belonging to stage i.
            blocks = self.resblocks[i * self.num_kernels : (i + 1) * self.num_kernels]

            if self.training and self.checkpointing:
                # Recompute activations in backward to save memory.
                x = checkpoint(ups, x, use_reentrant=False) + noise_convs(har_source)
                xs = sum(checkpoint(resblock, x, use_reentrant=False) for resblock in blocks)
            else:
                x = ups(x) + noise_convs(har_source)
                xs = sum(resblock(x) for resblock in blocks)

            # Average the parallel residual branches.
            x = xs / self.num_kernels

        return torch.tanh(self.conv_post(F.leaky_relu(x)))

    def remove_weight_norm(self):
        # The transposed convolutions were wrapped with the parametrization-based
        # weight_norm, so undo it with remove_parametrizations; the legacy
        # torch.nn.utils.remove_weight_norm would not find the parametrization.
        for l in self.ups:
            remove_parametrizations(l, "weight")

        for l in self.resblocks:
            l.remove_weight_norm()
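

if __name__ == "__main__":
    # Minimal smoke test, not part of the original module. The hyperparameters
    # below are illustrative assumptions (typical HiFi-GAN/RVC-style settings);
    # adjust them to the project's actual configuration.
    model = HiFiGANNRFGenerator(
        initial_channel=192,
        resblock_kernel_sizes=[3, 7, 11],
        resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
        upsample_rates=[10, 10, 2, 2],
        upsample_initial_channel=512,
        upsample_kernel_sizes=[16, 16, 4, 4],
        gin_channels=256,
        sr=40000,
    )
    model.eval()

    x = torch.randn(1, 192, 50)        # (batch, feature channels, frames)
    f0 = torch.full((1, 50), 220.0)    # frame-level F0 in Hz
    g = torch.randn(1, 256, 1)         # global conditioning embedding

    with torch.no_grad():
        audio = model(x, f0, g)
    print(audio.shape)                 # (1, 1, 50 * 400) with the rates above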