from dataclasses import dataclass, field
from typing import Dict, List, Tuple

import numpy as np
import torch
from coqpit import Coqpit
from torch import nn
from torch.nn.utils import weight_norm
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from trainer.trainer_utils import get_optimizer, get_scheduler

from TTS.utils.io import load_fsspec
from TTS.vocoder.datasets import WaveGradDataset
from TTS.vocoder.layers.wavegrad import Conv1d, DBlock, FiLM, UBlock
from TTS.vocoder.models.base_vocoder import BaseVocoder
from TTS.vocoder.utils.generic_utils import plot_results

@dataclass
class WavegradArgs(Coqpit):
    in_channels: int = 80
    out_channels: int = 1
    use_weight_norm: bool = False
    y_conv_channels: int = 32
    x_conv_channels: int = 768
    dblock_out_channels: List[int] = field(default_factory=lambda: [128, 128, 256, 512])
    ublock_out_channels: List[int] = field(default_factory=lambda: [512, 512, 256, 128, 128])
    upsample_factors: List[int] = field(default_factory=lambda: [4, 4, 4, 2, 2])
    upsample_dilations: List[List[int]] = field(
        default_factory=lambda: [[1, 2, 1, 2], [1, 2, 1, 2], [1, 2, 4, 8], [1, 2, 4, 8], [1, 2, 4, 8]]
    )
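
# Usage note (an editorial sketch, not part of the original module): `WavegradArgs`
# is a Coqpit dataclass, so any field can be overridden at construction time, e.g.
#     args = WavegradArgs(use_weight_norm=True, x_conv_channels=512)
# Keep in mind that the product of `upsample_factors` must equal the hop length
# used to extract the conditioning mel spectrograms (256 samples with the
# defaults above), since the model upsamples the spectrogram by that factor.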

class Wavegrad(BaseVocoder):
    """🐸 🌊 WaveGrad 🌊 model.

    Paper - https://arxiv.org/abs/2009.00713

    Examples:
        Initializing the model.

        >>> from TTS.vocoder.configs import WavegradConfig
        >>> config = WavegradConfig()
        >>> model = Wavegrad(config)

    Paper Abstract:
        This paper introduces WaveGrad, a conditional model for waveform generation which estimates gradients of the
        data density. The model is built on prior work on score matching and diffusion probabilistic models. It starts
        from a Gaussian white noise signal and iteratively refines the signal via a gradient-based sampler conditioned
        on the mel-spectrogram. WaveGrad offers a natural way to trade inference speed for sample quality by adjusting
        the number of refinement steps, and bridges the gap between non-autoregressive and autoregressive models in
        terms of audio quality. We find that it can generate high fidelity audio samples using as few as six
        iterations. Experiments reveal WaveGrad to generate high fidelity audio, outperforming adversarial
        non-autoregressive baselines and matching a strong likelihood-based autoregressive baseline using fewer
        sequential operations. Audio samples are available at this https URL.
    """

    # pylint: disable=dangerous-default-value
    def __init__(self, config: Coqpit):
        super().__init__(config)
        self.config = config
        self.use_weight_norm = config.model_params.use_weight_norm
        self.hop_len = np.prod(config.model_params.upsample_factors)
        self.noise_level = None
        self.num_steps = None
        self.beta = None
        self.alpha = None
        self.alpha_hat = None
        self.c1 = None
        self.c2 = None
        self.sigma = None

        # dblocks
        self.y_conv = Conv1d(1, config.model_params.y_conv_channels, 5, padding=2)
        self.dblocks = nn.ModuleList([])
        ic = config.model_params.y_conv_channels
        for oc, df in zip(config.model_params.dblock_out_channels, reversed(config.model_params.upsample_factors)):
            self.dblocks.append(DBlock(ic, oc, df))
            ic = oc

        # film
        self.film = nn.ModuleList([])
        ic = config.model_params.y_conv_channels
        for oc in reversed(config.model_params.ublock_out_channels):
            self.film.append(FiLM(ic, oc))
            ic = oc
        # ublocks
        self.ublocks = nn.ModuleList([])
        ic = config.model_params.x_conv_channels
        for oc, uf, ud in zip(
            config.model_params.ublock_out_channels,
            config.model_params.upsample_factors,
            config.model_params.upsample_dilations,
        ):
            self.ublocks.append(UBlock(ic, oc, uf, ud))
            ic = oc

        self.x_conv = Conv1d(config.model_params.in_channels, config.model_params.x_conv_channels, 3, padding=1)
        self.out_conv = Conv1d(oc, config.model_params.out_channels, 3, padding=1)

        if config.model_params.use_weight_norm:
            self.apply_weight_norm()

    def forward(self, x, spectrogram, noise_scale):
        shift_and_scale = []

        x = self.y_conv(x)
        shift_and_scale.append(self.film[0](x, noise_scale))

        for film, layer in zip(self.film[1:], self.dblocks):
            x = layer(x)
            shift_and_scale.append(film(x, noise_scale))

        x = self.x_conv(spectrogram)
        for layer, (film_shift, film_scale) in zip(self.ublocks, reversed(shift_and_scale)):
            x = layer(x, film_shift, film_scale)
        x = self.out_conv(x)
        return x

    def load_noise_schedule(self, path):
        beta = np.load(path, allow_pickle=True).item()["beta"]  # pylint: disable=unexpected-keyword-arg
        self.compute_noise_level(beta)

    @torch.no_grad()
    def inference(self, x, y_n=None):
        """Compute a waveform from a mel spectrogram by iterative denoising.

        Shapes:
            x: :math:`[B, C, T]`
            y_n: :math:`[B, 1, T]`
        """
        if y_n is None:
            y_n = torch.randn(x.shape[0], 1, self.hop_len * x.shape[-1])
        else:
            y_n = torch.FloatTensor(y_n).unsqueeze(0).unsqueeze(0)
        y_n = y_n.type_as(x)
        sqrt_alpha_hat = self.noise_level.to(x)
        for n in range(len(self.alpha) - 1, -1, -1):
            y_n = self.c1[n] * (y_n - self.c2[n] * self.forward(y_n, x, sqrt_alpha_hat[n].repeat(x.shape[0])))
            if n > 0:
                z = torch.randn_like(y_n)
                y_n += self.sigma[n - 1] * z
            y_n.clamp_(-1.0, 1.0)
        return y_n
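    # NOTE: the number of refinement iterations in `inference` is the length of
    # the active noise schedule set by `compute_noise_level`. Per the paper,
    # shorter schedules (as few as six steps, when tuned) trade sample quality
    # for inference speed.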

    def compute_y_n(self, y_0):
        """Compute noisy audio based on noise schedule"""
        self.noise_level = self.noise_level.to(y_0)
        if len(y_0.shape) == 3:
            y_0 = y_0.squeeze(1)
        # sample a continuous noise level uniformly between two adjacent
        # discrete levels of the schedule, as proposed in the WaveGrad paper
        s = torch.randint(0, self.num_steps - 1, [y_0.shape[0]])
        l_a, l_b = self.noise_level[s], self.noise_level[s + 1]
        noise_scale = l_a + torch.rand(y_0.shape[0]).to(y_0) * (l_b - l_a)
        noise_scale = noise_scale.unsqueeze(1)
        # diffuse the clean audio: y_n = sqrt(alpha_hat) * y_0 + sqrt(1 - alpha_hat) * noise
        noise = torch.randn_like(y_0)
        noisy_audio = noise_scale * y_0 + (1.0 - noise_scale**2) ** 0.5 * noise
        return noise.unsqueeze(1), noisy_audio.unsqueeze(1), noise_scale[:, 0]

    def compute_noise_level(self, beta):
        """Compute noise schedule parameters"""
        self.num_steps = len(beta)
        alpha = 1 - beta
        alpha_hat = np.cumprod(alpha)
        noise_level = alpha_hat**0.5

        # pylint: disable=not-callable
        self.beta = torch.tensor(beta.astype(np.float32))
        self.alpha = torch.tensor(alpha.astype(np.float32))
        self.alpha_hat = torch.tensor(alpha_hat.astype(np.float32))
        self.noise_level = torch.tensor(noise_level.astype(np.float32))

        self.c1 = 1 / self.alpha**0.5
        self.c2 = (1 - self.alpha) / (1 - self.alpha_hat) ** 0.5
        self.sigma = ((1.0 - self.alpha_hat[:-1]) / (1.0 - self.alpha_hat[1:]) * self.beta[1:]) ** 0.5
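        # `c1`, `c2` and `sigma` implement the ancestral sampling update used
        # in `inference`:
        #     y_{n-1} = 1/sqrt(alpha_n) * (y_n - (1 - alpha_n)/sqrt(1 - alpha_hat_n) * noise_hat) + sigma_n * z
        # with z ~ N(0, I) added at every step except the last one.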

    def remove_weight_norm(self):
        for layer in self.dblocks:
            if len(layer.state_dict()) != 0:
                try:
                    nn.utils.remove_weight_norm(layer)
                except ValueError:
                    layer.remove_weight_norm()

        for layer in self.film:
            if len(layer.state_dict()) != 0:
                try:
                    nn.utils.remove_weight_norm(layer)
                except ValueError:
                    layer.remove_weight_norm()

        for layer in self.ublocks:
            if len(layer.state_dict()) != 0:
                try:
                    nn.utils.remove_weight_norm(layer)
                except ValueError:
                    layer.remove_weight_norm()

        nn.utils.remove_weight_norm(self.x_conv)
        nn.utils.remove_weight_norm(self.out_conv)
        nn.utils.remove_weight_norm(self.y_conv)

    def apply_weight_norm(self):
        for layer in self.dblocks:
            if len(layer.state_dict()) != 0:
                layer.apply_weight_norm()
        for layer in self.film:
            if len(layer.state_dict()) != 0:
                layer.apply_weight_norm()
        for layer in self.ublocks:
            if len(layer.state_dict()) != 0:
                layer.apply_weight_norm()
        self.x_conv = weight_norm(self.x_conv)
        self.out_conv = weight_norm(self.out_conv)
        self.y_conv = weight_norm(self.y_conv)

    def load_checkpoint(
        self, config, checkpoint_path, eval=False, cache=False
    ):  # pylint: disable=unused-argument, redefined-builtin
        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache)
        self.load_state_dict(state["model"])
        if eval:
            self.eval()
            assert not self.training
            if self.config.model_params.use_weight_norm:
                self.remove_weight_norm()
            betas = np.linspace(
                config["test_noise_schedule"]["min_val"],
                config["test_noise_schedule"]["max_val"],
                config["test_noise_schedule"]["num_steps"],
            )
            self.compute_noise_level(betas)
        else:
            betas = np.linspace(
                config["train_noise_schedule"]["min_val"],
                config["train_noise_schedule"]["max_val"],
                config["train_noise_schedule"]["num_steps"],
            )
            self.compute_noise_level(betas)

    def train_step(self, batch: Dict, criterion: Dict) -> Tuple[Dict, Dict]:
        # format data
        x = batch["input"]
        y = batch["waveform"]

        # set noise scale
        noise, x_noisy, noise_scale = self.compute_y_n(y)

        # forward pass
        noise_hat = self.forward(x_noisy, x, noise_scale)

        # compute losses
        loss = criterion(noise, noise_hat)
        return {"model_output": noise_hat}, {"loss": loss}

    def train_log(  # pylint: disable=no-self-use
        self, batch: Dict, outputs: Dict, logger: "Logger", assets: Dict, steps: int  # pylint: disable=unused-argument
    ) -> Tuple[Dict, np.ndarray]:
        pass

    @torch.no_grad()
    def eval_step(self, batch: Dict, criterion: nn.Module) -> Tuple[Dict, Dict]:
        return self.train_step(batch, criterion)

    def eval_log(  # pylint: disable=no-self-use
        self, batch: Dict, outputs: Dict, logger: "Logger", assets: Dict, steps: int  # pylint: disable=unused-argument
    ) -> None:
        pass

    def test(self, assets: Dict, test_loader: "DataLoader", outputs=None):  # pylint: disable=unused-argument
        # setup noise schedule and inference
        ap = assets["audio_processor"]
        noise_schedule = self.config["test_noise_schedule"]
        betas = np.linspace(noise_schedule["min_val"], noise_schedule["max_val"], noise_schedule["num_steps"])
        self.compute_noise_level(betas)
        samples = test_loader.dataset.load_test_samples(1)
        for sample in samples:
            x = sample[0]
            x = x[None, :, :].to(next(self.parameters()).device)
            y = sample[1]
            y = y[None, :]
            # compute voice
            y_pred = self.inference(x)
            # compute spectrograms
            figures = plot_results(y_pred, y, ap, "test")
            # Sample audio
            sample_voice = y_pred[0].squeeze(0).detach().cpu().numpy()
        return figures, {"test/audio": sample_voice}

    def get_optimizer(self):
        return get_optimizer(self.config.optimizer, self.config.optimizer_params, self.config.lr, self)

    def get_scheduler(self, optimizer):
        return get_scheduler(self.config.lr_scheduler, self.config.lr_scheduler_params, optimizer)

    @staticmethod
    def get_criterion():
        return torch.nn.L1Loss()

    @staticmethod
    def format_batch(batch: Dict) -> Dict:
        # return a whole audio segment
        m, y = batch[0], batch[1]
        y = y.unsqueeze(1)
        return {"input": m, "waveform": y}

    def get_data_loader(self, config: Coqpit, assets: Dict, is_eval: bool, samples: List, verbose: bool, num_gpus: int):
        ap = assets["audio_processor"]
        dataset = WaveGradDataset(
            ap=ap,
            items=samples,
            seq_len=self.config.seq_len,
            hop_len=ap.hop_length,
            pad_short=self.config.pad_short,
            conv_pad=self.config.conv_pad,
            is_training=not is_eval,
            return_segments=True,
            use_noise_augment=False,
            use_cache=config.use_cache,
            verbose=verbose,
        )
        sampler = DistributedSampler(dataset) if num_gpus > 1 else None
        loader = DataLoader(
            dataset,
            batch_size=self.config.batch_size,
            shuffle=num_gpus <= 1,
            drop_last=False,
            sampler=sampler,
            num_workers=self.config.num_eval_loader_workers if is_eval else self.config.num_loader_workers,
            pin_memory=False,
        )
        return loader

    def on_epoch_start(self, trainer):  # pylint: disable=unused-argument
        noise_schedule = self.config["train_noise_schedule"]
        betas = np.linspace(noise_schedule["min_val"], noise_schedule["max_val"], noise_schedule["num_steps"])
        self.compute_noise_level(betas)

    @staticmethod
    def init_from_config(config: "WavegradConfig"):
        return Wavegrad(config)
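

# A minimal smoke-test sketch added for illustration (not part of the original
# module). It assumes the default `WavegradConfig` values and a dummy mel input;
# the config import is kept local to avoid importing the configs package at
# module load time.
if __name__ == "__main__":
    from TTS.vocoder.configs import WavegradConfig

    config = WavegradConfig()
    model = Wavegrad(config)
    model.eval()

    # A short 50-step linear beta schedule; real runs take the range and step
    # count from `train_noise_schedule` / `test_noise_schedule` in the config.
    betas = np.linspace(1e-6, 1e-2, 50)
    model.compute_noise_level(betas)

    # Dummy conditioning input: 80 mel bands, 10 frames -> 10 * 256 samples out
    # with the default upsample factors (4 * 4 * 4 * 2 * 2 = 256).
    mel = torch.randn(1, config.model_params.in_channels, 10)
    waveform = model.inference(mel)
    print(waveform.shape)  # torch.Size([1, 1, 2560])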