""" test_phone_encoder.py Desc: Check to make sure that using the Grad-TTS Encoder will work """ import sys sys.path.append('./') import torch import numpy as np import math from text import text_to_sequence, cmudict from text.symbols import symbols from models.utils import intersperse from models.phoneme_encoder import TextEncoder from models.utils import sequence_mask, generate_path, duration_loss, fix_len_compatibility from models import monotonic_align from text import get_arpabet, _symbol_to_id import matplotlib.pyplot as plt def test_cmu_parser(): cmu = cmudict.CMUDict('./resources/cmu_dictionary') text = "Here I go breaking audio models again." x = torch.LongTensor(intersperse(text_to_sequence(text, dictionary=cmu), len(symbols)))[None] x_lengths = torch.LongTensor([x.shape[-1]]) arpabet_example = get_arpabet("Here", cmu) """ test_phone_encoder Desc: function for ensuring that the Text Encoder works with params """ def test_phone_encoder(): # Load in Sample Mel Spec mel = np.load('/data/jiachenlian/VCTK/mels_16k/p225/p225_219.npy') # Of shape (T, C) where C is channel num # Speech Config for Values set below add_blank = True n_feats = 80 n_spks = 1 # 247 for Libri-TTS filelist and 1 for LJSpeech spk_emb_dim = 64 n_feats = 80 n_fft = 1024 sample_rate = 22050 hop_length = 256 win_length = 1024 f_min = 0 f_max = 8000 # encoder parameters n_enc_channels = 192 filter_channels = 768 filter_channels_dp = 256 n_enc_layers = 6 enc_kernel = 3 enc_dropout = 0.1 n_heads = 2 window_size = 4 length_scale = 1.0 # Format for instantiating encoder # encoder = TextEncoder(n_vocab, n_feats, n_enc_channels, # filter_channels, filter_channels_dp, n_heads, # n_enc_layers, enc_kernel, enc_dropout, window_size) # Example Declaration encoder = TextEncoder(len(symbols) + 1, n_feats, n_enc_channels, filter_channels, filter_channels_dp, n_heads, n_enc_layers, enc_kernel, enc_dropout, window_size) # Get Parsed Text cmu = cmudict.CMUDict('./resources/cmu_dictionary') # Example transcript from the same VCTK clip # text = "We used to live with dignity in our country." text = "They did not attack the themes of the book." 
    x = torch.LongTensor(intersperse(text_to_sequence(text, dictionary=cmu), len(symbols)))[None]
    x_lengths = torch.LongTensor([x.shape[-1]])

    # Pass None for the speaker embedding for now (single-speaker setup)
    mu_x, logw, x_mask = encoder(x, x_lengths, None)

    # ---- Inference-time code ----
    w = torch.exp(logw) * x_mask
    w_ceil = torch.ceil(w) * length_scale
    y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
    y_max_length = int(y_lengths.max())
    y_max_length_ = fix_len_compatibility(y_max_length)

    # Using the obtained durations `w`, construct the alignment map `attn`
    y_mask = sequence_mask(y_lengths, y_max_length_).unsqueeze(1).to(x_mask.dtype)
    attn_mask = x_mask.unsqueeze(-1) * y_mask.unsqueeze(2)
    attn = generate_path(w_ceil.squeeze(1), attn_mask.squeeze(1)).unsqueeze(1)

    # Align the encoded text to get mu_y
    mu_y = torch.matmul(attn.squeeze(1).transpose(1, 2), mu_x.transpose(1, 2))
    mu_y = mu_y.transpose(1, 2)
    encoder_outputs = mu_y[:, :, :y_max_length][0].detach().numpy()

    # Plot the phoneme encodings
    plt.figure(figsize=(10, 4))
    plt.imshow(encoder_outputs, aspect='auto', origin='lower',
               extent=[0, encoder_outputs.shape[1], 0, encoder_outputs.shape[0]])
    plt.colorbar(label='Intensity')
    plt.xlabel('Time')
    plt.ylabel('Mel Frequency Bands')
    plt.title('Phoneme Encoding')
    plt.savefig('./assets/example_untrained_phone_encoding.png')

    # ---- Train-time code ----
    # Check that the duration loss works
    y = torch.Tensor(mel.T).unsqueeze(0)
    y_max_length = y.shape[-1]
    y_lengths = torch.LongTensor([y.shape[-1]])

    y_mask = sequence_mask(y_lengths, y_max_length).unsqueeze(1).to(x_mask.dtype)
    attn_mask = x_mask.unsqueeze(-1) * y_mask.unsqueeze(2)

    # Use MAS to find the most likely alignment `attn` between text and mel-spectrogram.
    # log_prior[b, i, j] is the log-likelihood of mel frame j under an isotropic
    # Gaussian centred at mu_x[b, :, i], expanded as const + f*y^2 - 2*f*mu*y + f*mu^2
    # with f = -1/2.
    with torch.no_grad():
        const = -0.5 * math.log(2 * math.pi) * n_feats
        factor = -0.5 * torch.ones(mu_x.shape, dtype=mu_x.dtype, device=mu_x.device)
        y_square = torch.matmul(factor.transpose(1, 2), y ** 2)
        y_mu_double = torch.matmul(2.0 * (factor * mu_x).transpose(1, 2), y)
        mu_square = torch.sum(factor * (mu_x ** 2), 1).unsqueeze(-1)
        log_prior = y_square - y_mu_double + mu_square + const

        attn = monotonic_align.maximum_path(log_prior, attn_mask.squeeze(1))
        attn = attn.detach()
        attn_np = attn.numpy()

    # Loss between predicted log-scaled durations and those obtained from MAS
    logw_ = torch.log(1e-8 + torch.sum(attn.unsqueeze(1), -1)) * x_mask
    dur_loss = duration_loss(logw, logw_, x_lengths)
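
    # For reference, a minimal sketch of `duration_loss`, assuming the
    # Grad-TTS reference implementation (MSE in the log-duration domain,
    # normalised by the true text lengths):
    #
    #   def duration_loss(logw, logw_, lengths):
    #       return torch.sum((logw - logw_) ** 2) / torch.sum(lengths)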
    # Align text with the mel-spectrogram to get mu_y
    mu_y = torch.matmul(attn.squeeze(1).transpose(1, 2), mu_x.transpose(1, 2))
    mu_y = mu_y.transpose(1, 2)
    mu_y_np = mu_y.detach().numpy()

    # Loss between the aligned encoder outputs and the mel-spectrogram
    prior_loss = torch.sum(0.5 * ((y - mu_y) ** 2 + math.log(2 * math.pi)) * y_mask)
    prior_loss = prior_loss / (torch.sum(y_mask) * n_feats)

    # Plot the MAS alignment between text and mel frames
    plt.figure(figsize=(10, 4))
    plt.imshow(attn_np.squeeze(0), aspect='auto', origin='lower',
               extent=[0, attn_np.shape[2], 0, attn_np.shape[1]])
    plt.colorbar(label='Alignment weight')
    plt.xlabel('Mel Frames')
    plt.ylabel('Phoneme Index')
    plt.title('Untrained Duration')
    plt.savefig('./assets/example_duration.png')

    # Plot the aligned encoder outputs
    plt.figure(figsize=(10, 4))
    plt.imshow(mu_y_np.squeeze(0), aspect='auto', origin='lower',
               extent=[0, mu_y_np.shape[2], 0, mu_y_np.shape[1]])
    plt.colorbar(label='Intensity')
    plt.xlabel('Time')
    plt.ylabel('Mel Frequency Bands')
    plt.title('Untrained Alignment')
    plt.savefig('./assets/example_MAS.png')

    # Plot the target mel spectrogram
    plt.figure(figsize=(10, 4))
    plt.imshow(mel.T, aspect='auto', origin='lower',
               extent=[0, mel.shape[0], 0, mel.shape[1]])
    plt.colorbar(label='Intensity')
    plt.xlabel('Time')
    plt.ylabel('Mel Frequency Bands')
    plt.title('Goal Mel Spectrogram')
    plt.savefig('./assets/example_mel.png')


if __name__ == "__main__":
    test_cmu_parser()
    test_phone_encoder()
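
# For reference, a minimal sketch of `fix_len_compatibility`, assuming the
# Grad-TTS reference implementation (not verified against models/utils.py):
# it rounds the decoder length up to a multiple of 2**num_downsamplings so
# the diffusion U-Net can downsample it cleanly.
#
#   def fix_len_compatibility(length, num_downsamplings_in_unet=2):
#       factor = 2 ** num_downsamplings_in_unet
#       if length % factor == 0:
#           return length
#       return (length // factor + 1) * factor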