import logging

import torch
import torch.nn.functional as F
from torch import Tensor, nn
from torch.nn.utils.parametrizations import weight_norm

from ..hparams import HParams
from .mrstft import get_stft_cfgs

logger = logging.getLogger(__name__)


class PeriodNetwork(nn.Module):
    """Discriminates waveforms folded at a fixed period (looks at every period-th sample)."""

    def __init__(self, period):
        super().__init__()
        self.period = period
        wn = weight_norm
        self.convs = nn.ModuleList(
            [
                wn(nn.Conv2d(1, 64, (5, 1), (3, 1), padding=(2, 0))),
                wn(nn.Conv2d(64, 128, (5, 1), (3, 1), padding=(2, 0))),
                wn(nn.Conv2d(128, 256, (5, 1), (3, 1), padding=(2, 0))),
                wn(nn.Conv2d(256, 512, (5, 1), (3, 1), padding=(2, 0))),
                wn(nn.Conv2d(512, 1024, (5, 1), 1, padding=(2, 0))),
            ]
        )
        self.conv_post = wn(nn.Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))

    def forward(self, x):
        """
        Args:
            x: [B, 1, T]
        """
        assert x.dim() == 3, f"(B, 1, T) is expected, but got {x.shape}."

        # 1d to 2d: fold the waveform into [B, 1, T // period, period]
        b, c, t = x.shape
        if t % self.period != 0:
            # pad first so that T is divisible by the period
            n_pad = self.period - (t % self.period)
            x = F.pad(x, (0, n_pad), "reflect")
            t = t + n_pad
        x = x.view(b, c, t // self.period, self.period)

        for conv in self.convs:
            x = conv(x)
            x = F.leaky_relu(x, 0.2)
        x = self.conv_post(x)
        x = torch.flatten(x, 1, -1)

        return x


class SpecNetwork(nn.Module):
    """Discriminates STFT magnitude spectrograms computed with a given STFT config."""

    def __init__(self, stft_cfg: dict):
        super().__init__()
        wn = weight_norm
        self.stft_cfg = stft_cfg
        self.convs = nn.ModuleList(
            [
                wn(nn.Conv2d(1, 32, (3, 9), padding=(1, 4))),
                wn(nn.Conv2d(32, 32, (3, 9), stride=(1, 2), padding=(1, 4))),
                wn(nn.Conv2d(32, 32, (3, 9), stride=(1, 2), padding=(1, 4))),
                wn(nn.Conv2d(32, 32, (3, 9), stride=(1, 2), padding=(1, 4))),
                wn(nn.Conv2d(32, 32, (3, 3), padding=(1, 1))),
            ]
        )
        self.conv_post = wn(nn.Conv2d(32, 1, (3, 3), padding=(1, 1)))

    def forward(self, x):
        """
        Args:
            x: [B, 1, T]
        """
        x = self.spectrogram(x)
        x = x.unsqueeze(1)
        for conv in self.convs:
            x = conv(x)
            x = F.leaky_relu(x, 0.2)
        x = self.conv_post(x)
        x = x.flatten(1, -1)
        return x

    def spectrogram(self, x):
        """
        Args:
            x: [B, 1, T]
        """
        x = x.squeeze(1)
        dtype = x.dtype
        stft_cfg = dict(self.stft_cfg)
        # Compute the STFT in float32 for numerical stability, cast back afterwards
        x = torch.stft(x.float(), center=False, return_complex=False, **stft_cfg)
        mag = x.norm(p=2, dim=-1)  # [B, F, TT], magnitude from the (real, imag) pair
        mag = mag.to(dtype)  # [B, F, TT]
        return mag


class MD(nn.ModuleList):
    """Base class for multi-network discriminators; averages per-network losses."""

    def __init__(self, items: list):
        super().__init__([self._create_network(x) for x in items])
        self._loss_type = None

    def loss_type_(self, loss_type):
        self._loss_type = loss_type

    def _create_network(self, _):
        raise NotImplementedError

    def _forward_each(self, d, x, y):
        """Loss of one sub-network d on input x with target label y (0 = fake, 1 = real)."""
        assert self._loss_type is not None, "loss_type is not set."
        loss_type = self._loss_type

        if loss_type == "hinge":
            if y == 0:
                # d(x) should be small -> -1
                loss = F.relu(1 + d(x)).mean()
            elif y == 1:
                # d(x) should be large -> 1
                loss = F.relu(1 - d(x)).mean()
            else:
                raise ValueError(f"Invalid y: {y}")
        elif loss_type == "wgan":
            if y == 0:
                loss = d(x).mean()
            elif y == 1:
                loss = -d(x).mean()
            else:
                raise ValueError(f"Invalid y: {y}")
        else:
            raise ValueError(f"Invalid loss_type: {loss_type}")

        return loss

    def forward(self, x, y) -> Tensor:
        losses = [self._forward_each(d, x, y) for d in self]
        return torch.stack(losses).mean()


class MPD(MD):
    """Multi-period discriminator: one PeriodNetwork per period."""

    def __init__(self):
        super().__init__([2, 3, 7, 13, 17])

    def _create_network(self, period):
        return PeriodNetwork(period)


class MRD(MD):
    """Multi-resolution discriminator: one SpecNetwork per STFT config."""

    def __init__(self, stft_cfgs):
        super().__init__(stft_cfgs)

    def _create_network(self, stft_cfg):
        return SpecNetwork(stft_cfg)


class Discriminator(nn.Module):
    @property
    def wav_rate(self):
        return self.hp.wav_rate

    def __init__(self, hp: HParams):
        super().__init__()
        self.hp = hp
        self.stft_cfgs = get_stft_cfgs(hp)
        self.mpd = MPD()
        self.mrd = MRD(self.stft_cfgs)
        self.dummy_float: Tensor
        # Empty buffer that tracks the module's device/dtype, so inputs can be moved to match
        self.register_buffer("dummy_float", torch.zeros(0), persistent=False)

    def loss_type_(self, loss_type):
        self.mpd.loss_type_(loss_type)
        self.mrd.loss_type_(loss_type)

    def forward(self, fake, real=None):
        """
        Args:
            fake: [B, T]
            real: [B, T]
        """
        fake = fake.to(self.dummy_float)

        if real is None:
            # Generator step: only fake audio is scored, with a wgan-style loss
            self.loss_type_("wgan")
        else:
            # Discriminator step: hinge loss on both real and fake audio
            length_difference = (fake.shape[-1] - real.shape[-1]) / real.shape[-1]
            assert length_difference < 0.05, f"length_difference should be smaller than 5%, got {length_difference:.1%}"
            self.loss_type_("hinge")
            real = real.to(self.dummy_float)
            # Trim both signals to their common length
            fake = fake[..., : real.shape[-1]]
            real = real[..., : fake.shape[-1]]

        losses = {}

        assert fake.dim() == 2, f"(B, T) is expected, but got {fake.shape}."
        assert real is None or real.dim() == 2, f"(B, T) is expected, but got {real.shape}."

        fake = fake.unsqueeze(1)

        if real is None:
            losses["mpd"] = self.mpd(fake, 1)
            losses["mrd"] = self.mrd(fake, 1)
        else:
            real = real.unsqueeze(1)
            losses["mpd_fake"] = self.mpd(fake, 0)
            losses["mpd_real"] = self.mpd(real, 1)
            losses["mrd_fake"] = self.mrd(fake, 0)
            losses["mrd_real"] = self.mrd(real, 1)

        return losses
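

# Illustrative usage (a sketch, not part of the original module). The full
# Discriminator needs a project-specific HParams and the STFT configs produced
# by get_stft_cfgs, roughly:
#
#     disc = Discriminator(hp)
#     d_losses = disc(fake.detach(), real)  # hinge losses, discriminator step
#     g_losses = disc(fake)                 # wgan-style losses, generator step
#
# Since constructing HParams is project-specific, the runnable smoke test below
# exercises MPD alone with a dummy waveform.
if __name__ == "__main__":
    mpd = MPD()
    mpd.loss_type_("hinge")
    wav = torch.randn(2, 1, 16000)  # dummy [B, 1, T] waveform
    print("fake loss:", mpd(wav, 0).item())  # penalizes d(fake) above -1
    print("real loss:", mpd(wav, 1).item())  # penalizes d(real) below +1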