from dataclasses import dataclass, field
from typing import Dict, List, Union

import torch
from coqpit import Coqpit
from torch import nn

from TTS.tts.layers.align_tts.mdn import MDNBlock
from TTS.tts.layers.feed_forward.decoder import Decoder
from TTS.tts.layers.feed_forward.duration_predictor import DurationPredictor
from TTS.tts.layers.feed_forward.encoder import Encoder
from TTS.tts.layers.generic.pos_encoding import PositionalEncoding
from TTS.tts.models.base_tts import BaseTTS
from TTS.tts.utils.helpers import generate_path, maximum_path, sequence_mask
from TTS.tts.utils.speakers import SpeakerManager
from TTS.tts.utils.text.tokenizer import TTSTokenizer
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
from TTS.utils.io import load_fsspec


@dataclass
class AlignTTSArgs(Coqpit):
    """
    Args:
        num_chars (int):
            number of unique input characters.
        out_channels (int):
            number of output tensor channels. It is equal to the expected spectrogram size.
        hidden_channels (int):
            number of channels in all the model layers.
        hidden_channels_ffn (int):
            number of channels in transformer's conv layers.
        hidden_channels_dp (int):
            number of channels in duration predictor network.
        num_heads (int):
            number of attention heads in transformer networks.
        num_transformer_layers (int):
            number of layers in encoder and decoder transformer blocks.
        dropout_p (float):
            dropout rate in transformer layers.
        length_scale (int, optional):
            coefficient to set the speech speed. <1 slower, >1 faster. Defaults to 1.
        num_speakers (int, optional):
            number of speakers for multi-speaker training. Defaults to 0.
        external_c (bool, optional):
            enable external speaker embeddings. Defaults to False.
        c_in_channels (int, optional):
            number of channels in speaker embedding vectors. Defaults to 0.
    """

    num_chars: int = None
    out_channels: int = 80
    hidden_channels: int = 256
    hidden_channels_dp: int = 256
    encoder_type: str = "fftransformer"
    encoder_params: dict = field(
        default_factory=lambda: {"hidden_channels_ffn": 1024, "num_heads": 2, "num_layers": 6, "dropout_p": 0.1}
    )
    decoder_type: str = "fftransformer"
    decoder_params: dict = field(
        default_factory=lambda: {"hidden_channels_ffn": 1024, "num_heads": 2, "num_layers": 6, "dropout_p": 0.1}
    )
    length_scale: float = 1.0
    num_speakers: int = 0
    use_speaker_embedding: bool = False
    use_d_vector_file: bool = False
    d_vector_dim: int = 0
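

# Illustrative override (a sketch, not part of the original module): the defaults
# above describe the fftransformer setup, and any field can be replaced per
# experiment when the args are constructed, e.g.
#
#     args = AlignTTSArgs(num_chars=120, hidden_channels=192)
#
# where `num_chars=120` and `hidden_channels=192` are arbitrary example values.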


class AlignTTS(BaseTTS):
    """AlignTTS with modified duration predictor.
    https://arxiv.org/pdf/2003.01950.pdf

    Encoder -> DurationPredictor -> Decoder

    Check :class:`AlignTTSArgs` for the class arguments.

    Paper Abstract:
        Targeting at both high efficiency and performance, we propose AlignTTS to predict the
        mel-spectrum in parallel. AlignTTS is based on a Feed-Forward Transformer which generates mel-spectrum from a
        sequence of characters, and the duration of each character is determined by a duration predictor. Instead of
        adopting the attention mechanism in Transformer TTS to align text to mel-spectrum, the alignment loss is
        presented to consider all possible alignments in training by use of dynamic programming. Experiments on the
        LJSpeech dataset show that our model achieves not only state-of-the-art performance which outperforms
        Transformer TTS by 0.03 in mean opinion score (MOS), but also a high efficiency which is more than 50 times
        faster than real-time.

    Note:
        Original model uses a separate character embedding layer for the duration predictor. However, it causes the
        duration predictor to overfit and prevents learning higher level interactions among characters. Therefore,
        we predict durations based on encoder outputs, which have higher level information about the input
        characters. This enables training without phases as in the original paper.

        Original model uses Transformers in encoder and decoder layers. However, here you can set the architecture
        differently based on your requirements using ```encoder_type``` and ```decoder_type``` parameters.

    Examples:
        >>> from TTS.tts.configs.align_tts_config import AlignTTSConfig
        >>> config = AlignTTSConfig()
        >>> model = AlignTTS(config)
    """

    # pylint: disable=dangerous-default-value
    def __init__(
        self,
        config: "AlignTTSConfig",
        ap: "AudioProcessor" = None,
        tokenizer: "TTSTokenizer" = None,
        speaker_manager: SpeakerManager = None,
    ):
        super().__init__(config, ap, tokenizer, speaker_manager)
        self.speaker_manager = speaker_manager
        self.phase = -1
        self.length_scale = (
            float(config.model_args.length_scale)
            if isinstance(config.model_args.length_scale, int)
            else config.model_args.length_scale
        )

        self.emb = nn.Embedding(self.config.model_args.num_chars, self.config.model_args.hidden_channels)

        self.embedded_speaker_dim = 0
        self.init_multispeaker(config)

        self.pos_encoder = PositionalEncoding(config.model_args.hidden_channels)
        self.encoder = Encoder(
            config.model_args.hidden_channels,
            config.model_args.hidden_channels,
            config.model_args.encoder_type,
            config.model_args.encoder_params,
            self.embedded_speaker_dim,
        )
        self.decoder = Decoder(
            config.model_args.out_channels,
            config.model_args.hidden_channels,
            config.model_args.decoder_type,
            config.model_args.decoder_params,
        )
        self.duration_predictor = DurationPredictor(config.model_args.hidden_channels_dp)

        self.mod_layer = nn.Conv1d(config.model_args.hidden_channels, config.model_args.hidden_channels, 1)

        self.mdn_block = MDNBlock(config.model_args.hidden_channels, 2 * config.model_args.out_channels)

        if self.embedded_speaker_dim > 0 and self.embedded_speaker_dim != config.model_args.hidden_channels:
            self.proj_g = nn.Conv1d(self.embedded_speaker_dim, config.model_args.hidden_channels, 1)

    @staticmethod
    def compute_log_probs(mu, log_sigma, y):
        # pylint: disable=protected-access, c-extension-no-member
        y = y.transpose(1, 2).unsqueeze(1)  # [B, 1, T1, D]
        mu = mu.transpose(1, 2).unsqueeze(2)  # [B, T2, 1, D]
        log_sigma = log_sigma.transpose(1, 2).unsqueeze(2)  # [B, T2, 1, D]
        expanded_y, expanded_mu = torch.broadcast_tensors(y, mu)
        exponential = -0.5 * torch.mean(
            torch._C._nn.mse_loss(expanded_y, expanded_mu, 0) / torch.pow(log_sigma.exp(), 2), dim=-1
        )  # B, L, T
        logp = exponential - 0.5 * log_sigma.mean(dim=-1)
        return logp
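
    # Restating what `compute_log_probs` evaluates (explanatory note derived from
    # the code above, not text from the original file): for every encoder step i
    # and decoder frame j,
    #
    #     logp[b, i, j] = -0.5 * mean_d((y[b, j, d] - mu[b, i, d])**2 / sigma[b, i, d]**2)
    #                     - 0.5 * mean_d(log_sigma[b, i, d])
    #
    # i.e. an (unnormalized) diagonal-Gaussian log-likelihood averaged over the
    # feature dimension, which the monotonic alignment search then maximizes over
    # alignment paths.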

    def compute_align_path(self, mu, log_sigma, y, x_mask, y_mask):
        # find the max alignment path
        attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
        log_p = self.compute_log_probs(mu, log_sigma, y)
        # [B, T_en, T_dec]
        attn = maximum_path(log_p, attn_mask.squeeze(1)).unsqueeze(1)
        dr_mas = torch.sum(attn, -1)
        return dr_mas.squeeze(1), log_p

    @staticmethod
    def generate_attn(dr, x_mask, y_mask=None):
        # compute decode mask from the durations
        if y_mask is None:
            y_lengths = dr.sum(1).long()
            y_lengths[y_lengths < 1] = 1
            y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(dr.dtype)
        attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
        attn = generate_path(dr, attn_mask.squeeze(1)).to(dr.dtype)
        return attn

    def expand_encoder_outputs(self, en, dr, x_mask, y_mask):
        """Generate attention alignment map from durations and
        expand encoder outputs.

        Examples::
            - encoder output: [a,b,c,d]
            - durations: [1, 3, 2, 1]

            - expanded: [a, b, b, b, c, c, d]
            - attention map: [[0, 0, 0, 0, 0, 0, 1],
                              [0, 0, 0, 0, 1, 1, 0],
                              [0, 1, 1, 1, 0, 0, 0],
                              [1, 0, 0, 0, 0, 0, 0]]
        """
        attn = self.generate_attn(dr, x_mask, y_mask)
        o_en_ex = torch.matmul(attn.squeeze(1).transpose(1, 2), en.transpose(1, 2)).transpose(1, 2)
        return o_en_ex, attn

    def format_durations(self, o_dr_log, x_mask):
        o_dr = (torch.exp(o_dr_log) - 1) * x_mask * self.length_scale
        o_dr[o_dr < 1] = 1.0
        o_dr = torch.round(o_dr)
        return o_dr
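
    # Worked example for `format_durations` (illustrative numbers, assuming
    # length_scale == 1 and an all-ones mask): a predicted log-duration of 1.1
    # maps to exp(1.1) - 1 ≈ 2.0 frames, a log-duration of 0.0 maps to 0 frames
    # and is clamped to 1, and torch.round then yields the integer frame counts
    # used to expand the encoder outputs at inference time.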

    @staticmethod
    def _concat_speaker_embedding(o_en, g):
        g_exp = g.expand(-1, -1, o_en.size(-1))  # [B, C, T_en]
        o_en = torch.cat([o_en, g_exp], 1)
        return o_en

    def _sum_speaker_embedding(self, x, g):
        # project g to decoder dim.
        if hasattr(self, "proj_g"):
            g = self.proj_g(g)
        return x + g

    def _forward_encoder(self, x, x_lengths, g=None):
        if hasattr(self, "emb_g"):
            g = nn.functional.normalize(self.speaker_embedding(g))  # [B, C, 1]

        if g is not None:
            g = g.unsqueeze(-1)

        # [B, T, C]
        x_emb = self.emb(x)
        # [B, C, T]
        x_emb = torch.transpose(x_emb, 1, -1)

        # compute sequence masks
        x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.shape[1]), 1).to(x.dtype)

        # encoder pass
        o_en = self.encoder(x_emb, x_mask)

        # speaker conditioning for duration predictor
        if g is not None:
            o_en_dp = self._concat_speaker_embedding(o_en, g)
        else:
            o_en_dp = o_en
        return o_en, o_en_dp, x_mask, g

    def _forward_decoder(self, o_en, o_en_dp, dr, x_mask, y_lengths, g):
        y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(o_en_dp.dtype)
        # expand o_en with durations
        o_en_ex, attn = self.expand_encoder_outputs(o_en, dr, x_mask, y_mask)
        # positional encoding
        if hasattr(self, "pos_encoder"):
            o_en_ex = self.pos_encoder(o_en_ex, y_mask)
        # speaker embedding
        if g is not None:
            o_en_ex = self._sum_speaker_embedding(o_en_ex, g)
        # decoder pass
        o_de = self.decoder(o_en_ex, y_mask, g=g)
        return o_de, attn.transpose(1, 2)

    def _forward_mdn(self, o_en, y, y_lengths, x_mask):
        # MAS potentials and alignment
        mu, log_sigma = self.mdn_block(o_en)
        y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(o_en.dtype)
        dr_mas, logp = self.compute_align_path(mu, log_sigma, y, x_mask, y_mask)
        return dr_mas, mu, log_sigma, logp
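
    # Reading of `_forward_mdn` (explanatory note, not original text): the MDN
    # block maps each encoder frame to a mean and a log-scale (log_sigma) over
    # mel features; `compute_align_path` scores every (character, frame) pair
    # with those Gaussians and runs monotonic alignment search, so `dr_mas`
    # holds the per-character durations implied by the best path and `logp` is
    # the score matrix consumed by the alignment loss.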

    def forward(
        self, x, x_lengths, y, y_lengths, aux_input={"d_vectors": None}, phase=None
    ):  # pylint: disable=unused-argument
        """
        Shapes:
            - x: :math:`[B, T_max]`
            - x_lengths: :math:`[B]`
            - y_lengths: :math:`[B]`
            - dr: :math:`[B, T_max]`
            - g: :math:`[B, C]`
        """
        y = y.transpose(1, 2)
        g = aux_input["d_vectors"] if "d_vectors" in aux_input else None
        o_de, o_dr_log, dr_mas_log, attn, mu, log_sigma, logp = None, None, None, None, None, None, None
        if phase == 0:
            # train encoder and MDN
            o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g)
            dr_mas, mu, log_sigma, logp = self._forward_mdn(o_en, y, y_lengths, x_mask)
            y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(o_en_dp.dtype)
            attn = self.generate_attn(dr_mas, x_mask, y_mask)
        elif phase == 1:
            # train decoder
            o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g)
            dr_mas, _, _, _ = self._forward_mdn(o_en, y, y_lengths, x_mask)
            o_de, attn = self._forward_decoder(
                o_en.detach(), o_en_dp.detach(), dr_mas.detach(), x_mask, y_lengths, g=g
            )
        elif phase == 2:
            # train the whole except duration predictor
            o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g)
            dr_mas, mu, log_sigma, logp = self._forward_mdn(o_en, y, y_lengths, x_mask)
            o_de, attn = self._forward_decoder(o_en, o_en_dp, dr_mas, x_mask, y_lengths, g=g)
        elif phase == 3:
            # train duration predictor
            o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g)
            o_dr_log = self.duration_predictor(o_en_dp.detach(), x_mask)  # detached encoder features, as in the default branch below
            dr_mas, mu, log_sigma, logp = self._forward_mdn(o_en, y, y_lengths, x_mask)
            o_de, attn = self._forward_decoder(o_en, o_en_dp, dr_mas, x_mask, y_lengths, g=g)
            o_dr_log = o_dr_log.squeeze(1)
        else:
            o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g)
            o_dr_log = self.duration_predictor(o_en_dp.detach(), x_mask)
            dr_mas, mu, log_sigma, logp = self._forward_mdn(o_en, y, y_lengths, x_mask)
            o_de, attn = self._forward_decoder(o_en, o_en_dp, dr_mas, x_mask, y_lengths, g=g)
            o_dr_log = o_dr_log.squeeze(1)
        dr_mas_log = torch.log(dr_mas + 1).squeeze(1)
        outputs = {
| "model_outputs": o_de.transpose(1, 2), | |
| "alignments": attn, | |
| "durations_log": o_dr_log, | |
| "durations_mas_log": dr_mas_log, | |
| "mu": mu, | |
| "log_sigma": log_sigma, | |
| "logp": logp, | |
| } | |
| return outputs | |
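
    # Illustrative training-style call (a sketch with made-up shapes; `model` is
    # an assumed, already-built AlignTTS instance, not part of this file). `x`
    # holds token ids and `y` holds target mels in [B, T_dec, C_mel] before the
    # internal transpose.
    #
    #     x = torch.randint(0, 100, (2, 21))      # [B, T_en] token ids
    #     x_lengths = torch.tensor([21, 18])
    #     y = torch.randn(2, 130, 80)              # [B, T_dec, C_mel]
    #     y_lengths = torch.tensor([130, 121])
    #     outputs = model.forward(x, x_lengths, y, y_lengths, phase=None)
    #     outputs["model_outputs"].shape           # [B, T_dec, C_mel]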

    def inference(self, x, aux_input={"d_vectors": None}):  # pylint: disable=unused-argument
        """
        Shapes:
            - x: :math:`[B, T_max]`
            - x_lengths: :math:`[B]`
            - g: :math:`[B, C]`
        """
        g = aux_input["d_vectors"] if "d_vectors" in aux_input else None
        x_lengths = torch.tensor(x.shape[1:2]).to(x.device)
        # pad input to prevent dropping the last word
        # x = torch.nn.functional.pad(x, pad=(0, 5), mode='constant', value=0)
        o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g)
        # o_dr_log = self.duration_predictor(x, x_mask)
        o_dr_log = self.duration_predictor(o_en_dp, x_mask)
        # duration predictor pass
        o_dr = self.format_durations(o_dr_log, x_mask).squeeze(1)
        y_lengths = o_dr.sum(1)
        o_de, attn = self._forward_decoder(o_en, o_en_dp, o_dr, x_mask, y_lengths, g=g)
        outputs = {"model_outputs": o_de.transpose(1, 2), "alignments": attn}
        return outputs
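
    # Illustrative inference call (a sketch; `model`, `tokenizer` and `ap` are
    # assumed, already-built instances and the sentence is an arbitrary example):
    #
    #     token_ids = tokenizer.text_to_ids("hello world")
    #     x = torch.LongTensor(token_ids).unsqueeze(0)   # [1, T_en]
    #     outputs = model.inference(x)
    #     mel = outputs["model_outputs"][0]              # [T_dec, C_mel]
    #     wav = ap.inv_melspectrogram(mel.T.cpu().numpy())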

    def train_step(self, batch: dict, criterion: nn.Module):
        text_input = batch["text_input"]
        text_lengths = batch["text_lengths"]
        mel_input = batch["mel_input"]
        mel_lengths = batch["mel_lengths"]
        d_vectors = batch["d_vectors"]
        speaker_ids = batch["speaker_ids"]

        aux_input = {"d_vectors": d_vectors, "speaker_ids": speaker_ids}
        outputs = self.forward(text_input, text_lengths, mel_input, mel_lengths, aux_input, self.phase)
        loss_dict = criterion(
            outputs["logp"],
            outputs["model_outputs"],
            mel_input,
            mel_lengths,
            outputs["durations_log"],
            outputs["durations_mas_log"],
            text_lengths,
            phase=self.phase,
        )

        return outputs, loss_dict

    def _create_logs(self, batch, outputs, ap):  # pylint: disable=no-self-use
        model_outputs = outputs["model_outputs"]
        alignments = outputs["alignments"]
        mel_input = batch["mel_input"]

        pred_spec = model_outputs[0].data.cpu().numpy()
        gt_spec = mel_input[0].data.cpu().numpy()
        align_img = alignments[0].data.cpu().numpy()

        figures = {
            "prediction": plot_spectrogram(pred_spec, ap, output_fig=False),
            "ground_truth": plot_spectrogram(gt_spec, ap, output_fig=False),
            "alignment": plot_alignment(align_img, output_fig=False),
        }

        # Sample audio
        train_audio = ap.inv_melspectrogram(pred_spec.T)
        return figures, {"audio": train_audio}

    def train_log(
        self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int
    ) -> None:  # pylint: disable=no-self-use
        figures, audios = self._create_logs(batch, outputs, self.ap)
        logger.train_figures(steps, figures)
        logger.train_audios(steps, audios, self.ap.sample_rate)

    def eval_step(self, batch: dict, criterion: nn.Module):
        return self.train_step(batch, criterion)

    def eval_log(self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int) -> None:
        figures, audios = self._create_logs(batch, outputs, self.ap)
        logger.eval_figures(steps, figures)
        logger.eval_audios(steps, audios, self.ap.sample_rate)

    def load_checkpoint(
        self, config, checkpoint_path, eval=False, cache=False
    ):  # pylint: disable=unused-argument, redefined-builtin
        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache)
        self.load_state_dict(state["model"])
        if eval:
            self.eval()
            assert not self.training

    def get_criterion(self):
        from TTS.tts.layers.losses import AlignTTSLoss  # pylint: disable=import-outside-toplevel

        return AlignTTSLoss(self.config)

    @staticmethod
    def _set_phase(config, global_step):
        """Decide the AlignTTS training phase."""
        if isinstance(config.phase_start_steps, list):
            vals = [i < global_step for i in config.phase_start_steps]
            if True not in vals:
                phase = 0
            else:
                phase = (
                    len(config.phase_start_steps)
                    - [i < global_step for i in config.phase_start_steps][::-1].index(True)
                    - 1
                )
        else:
            phase = None
        return phase
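
    # Worked example for `_set_phase` (illustrative numbers): with
    # config.phase_start_steps = [0, 40000, 80000, 160000, 170000] and
    # global_step = 50000, only the first two start steps have been passed, so
    # the phase is the index of the last one, i.e. 1. With
    # phase_start_steps = None the method returns None and `forward` falls back
    # to the joint-training branch.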

    def on_epoch_start(self, trainer):
        """Set AlignTTS training phase on epoch start."""
        self.phase = self._set_phase(trainer.config, trainer.total_steps_done)

    @staticmethod
    def init_from_config(config: "AlignTTSConfig", samples: Union[List[List], List[Dict]] = None):
        """Initiate model from config

        Args:
            config (AlignTTSConfig): Model config.
            samples (Union[List[List], List[Dict]]): Training samples to parse speaker ids for training.
                Defaults to None.
        """
        from TTS.utils.audio import AudioProcessor

        ap = AudioProcessor.init_from_config(config)
        tokenizer, new_config = TTSTokenizer.init_from_config(config)
        speaker_manager = SpeakerManager.init_from_config(config, samples)
        return AlignTTS(new_config, ap, tokenizer, speaker_manager)
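

# Illustrative end-to-end setup (a sketch under default config values; it simply
# combines the class docstring example with `init_from_config` above and adds no
# behavior beyond them):
#
#     from TTS.tts.configs.align_tts_config import AlignTTSConfig
#
#     config = AlignTTSConfig()
#     model = AlignTTS.init_from_config(config, samples=None)
#     # model.ap, model.tokenizer and model.speaker_manager are now built from
#     # the config, as done in `init_from_config`.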