import torch import torch.nn as nn from typing import List from dataclasses import asdict from utils.audio import LogMelSpectrogram from config import MelConfig # Adapted from https://github.com/descriptinc/descript-audio-codec/blob/main/dac/nn/loss.py under the MIT license. class MultiScaleMelSpectrogramLoss(nn.Module): def __init__(self, n_mels: List[int] = [5, 10, 20, 40, 80, 160, 320], window_lengths: List[int] = [32, 64, 128, 256, 512, 1024, 2048]): super().__init__() assert len(n_mels) == len(window_lengths), "n_mels and window_lengths must have the same length" self.mel_transforms = nn.ModuleList(self._get_transforms(n_mels, window_lengths)) self.loss_fn = nn.L1Loss() def _get_transforms(self, n_mels, window_lengths): transforms = [] for n_mel, win_length in zip(n_mels, window_lengths): transform = LogMelSpectrogram(**asdict(MelConfig(n_mels=n_mel, n_fft=win_length, win_length=win_length, hop_length=win_length//4))) transforms.append(transform) return transforms def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: return sum(self.loss_fn(mel_transform(x), mel_transform(y)) for mel_transform in self.mel_transforms) class SingleScaleMelSpectrogramLoss(nn.Module): def __init__(self): super().__init__() self.mel_transform = LogMelSpectrogram(**asdict(MelConfig())) self.loss_fn = nn.L1Loss() print('using single mel loss') def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: return self.loss_fn(self.mel_transform(x), self.mel_transform(y)) def feature_loss(fmap_r, fmap_g): loss = 0 for dr, dg in zip(fmap_r, fmap_g): for rl, gl in zip(dr, dg): loss += torch.mean(torch.abs(rl - gl)) return loss*2 def discriminator_loss(disc_real_outputs, disc_generated_outputs): loss = 0 r_losses = [] g_losses = [] for dr, dg in zip(disc_real_outputs, disc_generated_outputs): r_loss = torch.mean((1-dr)**2) g_loss = torch.mean(dg**2) loss += (r_loss + g_loss) r_losses.append(r_loss.item()) g_losses.append(g_loss.item()) return loss, r_losses, g_losses def generator_loss(disc_outputs): loss = 0 gen_losses = [] for dg in disc_outputs: l = torch.mean((1-dg)**2) gen_losses.append(l) loss += l return loss, gen_losses