from typing import Dict, Union

import librosa
import torch
from torch import nn
from torch.nn import functional as F

class TorchSTFT(nn.Module):  # pylint: disable=abstract-method
    """Some of the audio processing functions using Torch for faster batch processing.

    Args:
        n_fft (int):
            FFT window size for STFT.
        hop_length (int):
            number of frames between STFT columns.
        win_length (int):
            STFT window length.
        pad_wav (bool, optional):
            If True, pad the audio on both sides with `(n_fft - hop_length) / 2` samples. Defaults to False.
        window (str, optional):
            The name of a function to create a window tensor that is applied/multiplied to each frame/window. Defaults to "hann_window".
        sample_rate (int, optional):
            target audio sampling rate. Defaults to None.
        mel_fmin (int, optional):
            minimum filter frequency for computing melspectrograms. Defaults to 0.
        mel_fmax (int, optional):
            maximum filter frequency for computing melspectrograms. Defaults to None.
        n_mels (int, optional):
            number of melspectrogram dimensions. Defaults to 80.
        use_mel (bool, optional):
            If True, compute melspectrograms, otherwise linear spectrograms. Defaults to False.
        do_amp_to_db (bool, optional):
            enable/disable amplitude-to-dB conversion of spectrograms. Defaults to False.
        spec_gain (float, optional):
            gain applied when converting amplitude to dB. Defaults to 1.0.
        power (float, optional):
            Exponent for the magnitude spectrogram, e.g., 1 for energy, 2 for power, etc. Defaults to None.
        use_htk (bool, optional):
            Use the HTK formula in the mel filter instead of Slaney. Defaults to False.
        mel_norm (None, 'slaney', or number, optional):
            If 'slaney', divide the triangular mel weights by the width of the mel band
            (area normalization).
            If numeric, use `librosa.util.normalize` to normalize each filter to unit l_p norm.
            See `librosa.util.normalize` for a full description of supported norm values
            (including `+-np.inf`).
            Otherwise, leave all the triangles aiming for a peak value of 1.0. Defaults to "slaney".
        normalized (bool, optional):
            Passed through to `torch.stft` to normalize the STFT output. Defaults to False.
    """

    def __init__(
        self,
        n_fft,
        hop_length,
        win_length,
        pad_wav=False,
        window="hann_window",
        sample_rate=None,
        mel_fmin=0,
        mel_fmax=None,
        n_mels=80,
        use_mel=False,
        do_amp_to_db=False,
        spec_gain=1.0,
        power=None,
        use_htk=False,
        mel_norm="slaney",
        normalized=False,
    ):
        super().__init__()
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_length = win_length
        self.pad_wav = pad_wav
        self.sample_rate = sample_rate
        self.mel_fmin = mel_fmin
        self.mel_fmax = mel_fmax
        self.n_mels = n_mels
        self.use_mel = use_mel
        self.do_amp_to_db = do_amp_to_db
        self.spec_gain = spec_gain
        self.power = power
        self.use_htk = use_htk
        self.mel_norm = mel_norm
        self.window = nn.Parameter(getattr(torch, window)(win_length), requires_grad=False)
        self.mel_basis = None
        self.normalized = normalized
        if use_mel:
            self._build_mel_basis()

    def __call__(self, x):
        """Compute spectrogram frames by torch-based STFT.

        Args:
            x (Tensor): input waveform

        Returns:
            Tensor: spectrogram frames.

        Shapes:
            x: :math:`[B, T]` or :math:`[B, 1, T]`
        """
        if x.ndim == 2:
            x = x.unsqueeze(1)
        if self.pad_wav:
            padding = int((self.n_fft - self.hop_length) / 2)
            x = torch.nn.functional.pad(x, (padding, padding), mode="reflect")
        # B x D x T x 2
        o = torch.stft(
            x.squeeze(1),
            self.n_fft,
            self.hop_length,
            self.win_length,
            self.window.to(x.device),
            center=True,
            pad_mode="reflect",  # compatible with audio.py
            normalized=self.normalized,
            onesided=True,
            return_complex=False,
        )
        M = o[:, :, :, 0]
        P = o[:, :, :, 1]
        S = torch.sqrt(torch.clamp(M**2 + P**2, min=1e-8))
        if self.power is not None:
            S = S**self.power
        if self.use_mel:
            S = torch.matmul(self.mel_basis.to(x), S)
        if self.do_amp_to_db:
            S = self._amp_to_db(S, spec_gain=self.spec_gain)
        return S

    def _build_mel_basis(self):
        mel_basis = librosa.filters.mel(
            sr=self.sample_rate,
            n_fft=self.n_fft,
            n_mels=self.n_mels,
            fmin=self.mel_fmin,
            fmax=self.mel_fmax,
            htk=self.use_htk,
            norm=self.mel_norm,
        )
        self.mel_basis = torch.from_numpy(mel_basis).float()

    @staticmethod
    def _amp_to_db(x, spec_gain=1.0):
        return torch.log(torch.clamp(x, min=1e-5) * spec_gain)

    @staticmethod
    def _db_to_amp(x, spec_gain=1.0):
        return torch.exp(x) / spec_gain
#################################
# GENERATOR LOSSES
#################################

class STFTLoss(nn.Module):
    """STFT loss. Generated and real waveforms are converted to spectrograms
    and compared with L1 and spectral convergence losses.
    It is from the ParallelWaveGAN paper https://arxiv.org/pdf/1910.11480.pdf"""

    def __init__(self, n_fft, hop_length, win_length):
        super().__init__()
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_length = win_length
        self.stft = TorchSTFT(n_fft, hop_length, win_length)

    def forward(self, y_hat, y):
        y_hat_M = self.stft(y_hat)
        y_M = self.stft(y)
        # magnitude loss
        loss_mag = F.l1_loss(torch.log(y_M), torch.log(y_hat_M))
        # spectral convergence loss
        loss_sc = torch.norm(y_M - y_hat_M, p="fro") / torch.norm(y_M, p="fro")
        return loss_mag, loss_sc
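
# For reference, the two terms computed above are
#     L_mag = || log|STFT(y)| - log|STFT(y_hat)| ||_1
#     L_sc  = || |STFT(y)| - |STFT(y_hat)| ||_F / || |STFT(y)| ||_F
# as defined in the ParallelWaveGAN paper linked in the docstring.
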

class MultiScaleSTFTLoss(torch.nn.Module):
    """Multi-scale STFT loss. Generated and real waveforms are converted to
    spectrograms at several STFT resolutions and compared with L1 and spectral
    convergence losses.
    It is from the ParallelWaveGAN paper https://arxiv.org/pdf/1910.11480.pdf"""

    def __init__(self, n_ffts=(1024, 2048, 512), hop_lengths=(120, 240, 50), win_lengths=(600, 1200, 240)):
        super().__init__()
        self.loss_funcs = torch.nn.ModuleList()
        for n_fft, hop_length, win_length in zip(n_ffts, hop_lengths, win_lengths):
            self.loss_funcs.append(STFTLoss(n_fft, hop_length, win_length))

    def forward(self, y_hat, y):
        N = len(self.loss_funcs)
        loss_sc = 0
        loss_mag = 0
        for f in self.loss_funcs:
            lm, lsc = f(y_hat, y)
            loss_mag += lm
            loss_sc += lsc
        loss_sc /= N
        loss_mag /= N
        return loss_mag, loss_sc
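
# Usage sketch (illustrative; a batch of 1-second 22.05 kHz waveforms is an
# arbitrary assumption):
#
#     criterion = MultiScaleSTFTLoss()
#     y = torch.randn(4, 22050)
#     y_hat = torch.randn(4, 22050)
#     loss_mag, loss_sc = criterion(y_hat, y)
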

class L1SpecLoss(nn.Module):
    """L1 loss over spectrograms as described in the HiFiGAN paper https://arxiv.org/pdf/2010.05646.pdf"""

    def __init__(
        self, sample_rate, n_fft, hop_length, win_length, mel_fmin=None, mel_fmax=None, n_mels=None, use_mel=True
    ):
        super().__init__()
        self.use_mel = use_mel
        self.stft = TorchSTFT(
            n_fft,
            hop_length,
            win_length,
            sample_rate=sample_rate,
            mel_fmin=mel_fmin,
            mel_fmax=mel_fmax,
            n_mels=n_mels,
            use_mel=use_mel,
        )

    def forward(self, y_hat, y):
        y_hat_M = self.stft(y_hat)
        y_M = self.stft(y)
        # magnitude loss
        loss_mag = F.l1_loss(torch.log(y_M), torch.log(y_hat_M))
        return loss_mag

class MultiScaleSubbandSTFTLoss(MultiScaleSTFTLoss):
    """Multi-scale STFT loss for multi-band model outputs.
    From the Multi-Band MelGAN paper https://arxiv.org/abs/2005.05106"""

    # pylint: disable=no-self-use
    def forward(self, y_hat, y):
        # flatten the subband channels into the batch dimension: [B, C, T] -> [B * C, T]
        y_hat = y_hat.view(-1, 1, y_hat.shape[2])
        y = y.view(-1, 1, y.shape[2])
        return super().forward(y_hat.squeeze(1), y.squeeze(1))

class MSEGLoss(nn.Module):
    """Mean Squared Generator Loss"""

    # pylint: disable=no-self-use
    def forward(self, score_fake):
        # push the discriminator's scores for generated audio towards 1 ("real")
        loss_fake = F.mse_loss(score_fake, score_fake.new_ones(score_fake.shape))
        return loss_fake

class HingeGLoss(nn.Module):
    """Hinge Generator Loss"""

    # pylint: disable=no-self-use
    def forward(self, score_fake):
        # TODO: this might be wrong
        loss_fake = torch.mean(F.relu(1.0 - score_fake))
        return loss_fake
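
# Note on the TODO above: the standard hinge generator loss (e.g. Geometric
# GAN / SAGAN) is simply -score_fake.mean(); relu(1 - score) is the term a
# hinge *discriminator* applies to real samples, which is likely why this
# implementation was flagged.
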
##################################
# DISCRIMINATOR LOSSES
##################################

class MSEDLoss(nn.Module):
    """Mean Squared Discriminator Loss"""

    def __init__(self):
        super().__init__()
        self.loss_func = nn.MSELoss()

    # pylint: disable=no-self-use
    def forward(self, score_fake, score_real):
        loss_real = self.loss_func(score_real, score_real.new_ones(score_real.shape))
        loss_fake = self.loss_func(score_fake, score_fake.new_zeros(score_fake.shape))
        loss_d = loss_real + loss_fake
        return loss_d, loss_real, loss_fake

class HingeDLoss(nn.Module):
    """Hinge Discriminator Loss"""

    # pylint: disable=no-self-use
    def forward(self, score_fake, score_real):
        loss_real = torch.mean(F.relu(1.0 - score_real))
        loss_fake = torch.mean(F.relu(1.0 + score_fake))
        loss_d = loss_real + loss_fake
        return loss_d, loss_real, loss_fake

class MelganFeatureLoss(nn.Module):
    """Feature-matching loss: L1 distance between the discriminator's intermediate
    feature maps for generated and real audio, as in the MelGAN paper
    https://arxiv.org/abs/1910.06711.

    `fake_feats` and `real_feats` are nested lists: one list of feature-map
    tensors per discriminator."""

    def __init__(self):
        super().__init__()
        self.loss_func = nn.L1Loss()

    # pylint: disable=no-self-use
    def forward(self, fake_feats, real_feats):
        loss_feats = 0
        num_feats = 0
        for idx, _ in enumerate(fake_feats):
            for fake_feat, real_feat in zip(fake_feats[idx], real_feats[idx]):
                loss_feats += self.loss_func(fake_feat, real_feat)
                num_feats += 1
        loss_feats = loss_feats / num_feats
        return loss_feats
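
# Usage sketch (illustrative shapes; two discriminators with two feature maps
# each are an arbitrary assumption):
#
#     fm_loss = MelganFeatureLoss()
#     fake = [[torch.randn(1, 16, 100), torch.randn(1, 32, 50)] for _ in range(2)]
#     real = [[torch.randn(1, 16, 100), torch.randn(1, 32, 50)] for _ in range(2)]
#     loss = fm_loss(fake, real)
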
#####################################
# LOSS WRAPPERS
#####################################

def _apply_G_adv_loss(scores_fake, loss_func):
    """Compute the generator's adversarial loss and normalize it over the
    number of discriminators."""
    adv_loss = 0
    if isinstance(scores_fake, list):
        for score_fake in scores_fake:
            fake_loss = loss_func(score_fake)
            adv_loss += fake_loss
        adv_loss /= len(scores_fake)
    else:
        fake_loss = loss_func(scores_fake)
        adv_loss = fake_loss
    return adv_loss

def _apply_D_loss(scores_fake, scores_real, loss_func):
    """Compute the discriminator loss and normalize the loss values over the
    number of scales (discriminators)."""
    loss = 0
    real_loss = 0
    fake_loss = 0
    if isinstance(scores_fake, list):
        # multi-scale loss
        for score_fake, score_real in zip(scores_fake, scores_real):
            total_loss, real_loss_, fake_loss_ = loss_func(score_fake=score_fake, score_real=score_real)
            loss += total_loss
            real_loss += real_loss_
            fake_loss += fake_loss_
        # normalize loss values with the number of scales (discriminators)
        loss /= len(scores_fake)
        real_loss /= len(scores_real)
        fake_loss /= len(scores_fake)
    else:
        # single-scale loss
        total_loss, real_loss, fake_loss = loss_func(scores_fake, scores_real)
        loss = total_loss
    return loss, real_loss, fake_loss
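
# Usage sketch (illustrative; two discriminator scales and a batch of four
# scores are arbitrary assumptions):
#
#     d_loss_fn = MSEDLoss()
#     scores_fake = [torch.randn(4, 1), torch.randn(4, 1)]
#     scores_real = [torch.randn(4, 1), torch.randn(4, 1)]
#     loss, loss_real, loss_fake = _apply_D_loss(scores_fake, scores_real, d_loss_fn)
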
##################################
# MODEL LOSSES
##################################

class GeneratorLoss(nn.Module):
    """Generator Loss Wrapper. It sets the relevant loss functions and computes
    losses based on the hard-coded configuration below. It allows experimenting
    with different combinations of loss functions by changing the flags and
    weights in `__init__`."""

    def __init__(self):
        super().__init__()
        # hard-coded loss configuration: which losses are active ...
        self.use_stft_loss = False
        self.use_subband_stft_loss = False
        self.use_mse_gan_loss = True
        self.use_hinge_gan_loss = False
        self.use_feat_match_loss = True
        self.use_l1_spec_loss = True
        # ... and their weights
        self.stft_loss_weight = 0
        self.subband_stft_loss_weight = 0
        self.mse_gan_loss_weight = 1
        self.hinge_gan_loss_weight = 0
        self.feat_match_loss_weight = 108
        self.l1_spec_loss_weight = 45
        # NOTE: only the losses needed by the enabled flags are instantiated here;
        # enabling `use_stft_loss`, `use_subband_stft_loss`, or `use_hinge_gan_loss`
        # also requires creating `self.stft_loss`, `self.subband_stft_loss`, or
        # `self.hinge_loss` respectively.
        self.mse_loss = MSEGLoss()
        self.feat_match_loss = MelganFeatureLoss()
        self.l1_spec_loss = L1SpecLoss(
            use_mel=True,
            sample_rate=24000,
            n_fft=1024,
            hop_length=256,
            win_length=1024,
            n_mels=100,
            mel_fmin=0.0,
            mel_fmax=None,
        )

    def forward(
        self, y_hat=None, y=None, scores_fake=None, feats_fake=None, feats_real=None, y_hat_sub=None, y_sub=None
    ):
        gen_loss = 0
        adv_loss = 0
        return_dict = {}

        # STFT loss
        if self.use_stft_loss:
            stft_loss_mg, stft_loss_sc = self.stft_loss(y_hat[:, :, : y.size(2)].squeeze(1), y.squeeze(1))
            return_dict["G_stft_loss_mg"] = stft_loss_mg
            return_dict["G_stft_loss_sc"] = stft_loss_sc
            gen_loss = gen_loss + self.stft_loss_weight * (stft_loss_mg + stft_loss_sc)

        # L1 spec loss
        if self.use_l1_spec_loss:
            l1_spec_loss = self.l1_spec_loss(y_hat, y)
            return_dict["G_l1_spec_loss"] = l1_spec_loss
            gen_loss = gen_loss + self.l1_spec_loss_weight * l1_spec_loss

        # subband STFT loss
        if self.use_subband_stft_loss:
            subband_stft_loss_mg, subband_stft_loss_sc = self.subband_stft_loss(y_hat_sub, y_sub)
            return_dict["G_subband_stft_loss_mg"] = subband_stft_loss_mg
            return_dict["G_subband_stft_loss_sc"] = subband_stft_loss_sc
            gen_loss = gen_loss + self.subband_stft_loss_weight * (subband_stft_loss_mg + subband_stft_loss_sc)

        # multi-scale MSE adversarial loss
        if self.use_mse_gan_loss and scores_fake is not None:
            mse_fake_loss = _apply_G_adv_loss(scores_fake, self.mse_loss)
            return_dict["G_mse_fake_loss"] = mse_fake_loss
            adv_loss = adv_loss + self.mse_gan_loss_weight * mse_fake_loss

        # multi-scale hinge adversarial loss
        if self.use_hinge_gan_loss and scores_fake is not None:
            hinge_fake_loss = _apply_G_adv_loss(scores_fake, self.hinge_loss)
            return_dict["G_hinge_fake_loss"] = hinge_fake_loss
            adv_loss = adv_loss + self.hinge_gan_loss_weight * hinge_fake_loss

        # feature-matching loss
        if self.use_feat_match_loss and feats_fake is not None:
            feat_match_loss = self.feat_match_loss(feats_fake, feats_real)
            return_dict["G_feat_match_loss"] = feat_match_loss
            adv_loss = adv_loss + self.feat_match_loss_weight * feat_match_loss

        return_dict["loss"] = gen_loss + adv_loss
        return_dict["G_gen_loss"] = gen_loss
        return_dict["G_adv_loss"] = adv_loss
        return return_dict

class DiscriminatorLoss(nn.Module):
    """Like ``GeneratorLoss``, but for the discriminator."""

    def __init__(self):
        super().__init__()
        self.use_mse_gan_loss = True
        self.mse_loss = MSEDLoss()

    def forward(self, scores_fake, scores_real):
        loss = 0
        return_dict = {}

        if self.use_mse_gan_loss:
            mse_D_loss, mse_D_real_loss, mse_D_fake_loss = _apply_D_loss(
                scores_fake=scores_fake, scores_real=scores_real, loss_func=self.mse_loss
            )
            return_dict["D_mse_gan_loss"] = mse_D_loss
            return_dict["D_mse_gan_real_loss"] = mse_D_real_loss
            return_dict["D_mse_gan_fake_loss"] = mse_D_fake_loss
            loss += mse_D_loss

        return_dict["loss"] = loss
        return return_dict
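
if __name__ == "__main__":
    # Minimal smoke test with random tensors. The shapes are illustrative
    # assumptions only: one discriminator scale and 1 second of audio at the
    # 24 kHz rate used by the hard-coded GeneratorLoss config above.
    y = torch.randn(2, 1, 24000)
    y_hat = torch.randn(2, 1, 24000)
    scores_fake = [torch.randn(2, 1)]
    scores_real = [torch.randn(2, 1)]
    feats_fake = [[torch.randn(2, 16, 100)]]
    feats_real = [[torch.randn(2, 16, 100)]]

    g_losses = GeneratorLoss()(
        y_hat=y_hat, y=y, scores_fake=scores_fake, feats_fake=feats_fake, feats_real=feats_real
    )
    d_losses = DiscriminatorLoss()(scores_fake=scores_fake, scores_real=scores_real)
    print(g_losses["loss"].item(), d_losses["loss"].item())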