Spaces:
Runtime error
Runtime error
""" | |
test_phone_encoder.py | |
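Desc: Smoke tests checking that the Grad-TTS text (phoneme) encoder works end to end: CMU-dictionary text parsing, inference-time duration-based alignment, and training-time MAS alignment with duration and prior losses.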
""" | |
import sys | |
sys.path.append('./') | |
import torch | |
import numpy as np | |
import math | |
from text import text_to_sequence, cmudict | |
from text.symbols import symbols | |
from models.utils import intersperse | |
from models.phoneme_encoder import TextEncoder | |
from models.utils import sequence_mask, generate_path, duration_loss, fix_len_compatibility | |
from models import monotonic_align | |
from text import get_arpabet, _symbol_to_id | |
import matplotlib.pyplot as plt | |
def test_cmu_parser(cmu_path='./resources/cmu_dictionary'):
    """Check that a sentence converts to an encoder-ready token batch.

    Parses a sample sentence with the CMU pronouncing dictionary, intersperses
    the blank token (id ``len(symbols)``) between symbol ids, and does a
    single-word ARPAbet lookup.

    Args:
        cmu_path: Path to the CMU dictionary file. Has a default so pytest
            does not treat it as a fixture.
    """
    cmu = cmudict.CMUDict(cmu_path)
    text = "Here I go breaking audio models again."

    sequence = text_to_sequence(text, dictionary=cmu)
    assert len(sequence) > 0, "non-empty text produced an empty symbol sequence"

    x = torch.LongTensor(intersperse(sequence, len(symbols)))[None]
    x_lengths = torch.LongTensor([x.shape[-1]])

    # [None] adds a leading batch axis of size 1.
    assert x.dim() == 2 and x.shape[0] == 1, f"unexpected batch shape {tuple(x.shape)}"
    assert int(x_lengths[0]) == x.shape[-1]
    # The encoder in test_phone_encoder is built with n_vocab = len(symbols) + 1
    # (symbol ids plus the blank id len(symbols)), so every id must fit in it.
    assert 0 <= int(x.min()) and int(x.max()) <= len(symbols), (
        f"token ids outside [0, {len(symbols)}]")

    arpabet_example = get_arpabet("Here", cmu)
    # NOTE(review): the exact return format of get_arpabet is defined outside
    # this file; only check that the lookup yields something non-empty.
    assert arpabet_example, "get_arpabet returned an empty result for 'Here'"
""" | |
test_phone_encoder | |
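Desc: Builds the Grad-TTS TextEncoder from the hyperparameters defined below, runs its inference-time and training-time (MAS alignment + losses) code paths, and saves diagnostic plots.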
""" | |
def _save_heatmap(data, extent, title, path, xlabel='Time',
                  ylabel='Mel Frequency Bands'):
    """Render a 2-D array as a heat map, save it to ``path``, and close it.

    Args:
        data: 2-D array passed to ``plt.imshow``.
        extent: ``[x0, x1, y0, y1]`` axis extent for ``imshow``.
        title: Plot title.
        path: Output file passed to ``plt.savefig``.
        xlabel: X-axis label.
        ylabel: Y-axis label.
    """
    fig = plt.figure(figsize=(10, 4))
    try:
        plt.imshow(data, aspect='auto', origin='lower', extent=extent)
        plt.colorbar(label='Intensity')
        plt.xlabel(xlabel)
        plt.ylabel(ylabel)
        plt.title(title)
        plt.savefig(path)
    finally:
        # pyplot keeps every figure alive until it is closed; without this,
        # each call leaks a figure for the lifetime of the process.
        plt.close(fig)


def test_phone_encoder(
        mel_path='/data/jiachenlian/VCTK/mels_16k/p225/p225_219.npy',
        cmu_path='./resources/cmu_dictionary',
        out_dir='./assets'):
    """Smoke-test the Grad-TTS ``TextEncoder`` at inference and training time.

    Steps:
      1. Build a ``TextEncoder`` with the hyperparameters below and encode a
         CMU-dict-parsed sentence.
      2. Inference path: convert the predicted log-durations into a hard
         alignment with ``generate_path`` and expand the encoder output to
         frame rate.
      3. Training path: align the encoder output to a real mel-spectrogram
         with monotonic alignment search (MAS), then compute the duration and
         prior losses.
      4. Save four diagnostic heat maps into ``out_dir``.
      5. Check that both losses are finite.

    Args:
        mel_path: ``.npy`` mel-spectrogram laid out as ``(T, n_feats)``.
        cmu_path: CMU pronouncing dictionary file.
        out_dir: Directory for the PNG plots; created if missing.

    All arguments have defaults, so pytest does not treat them as fixtures.

    Raises:
        ValueError: if the loaded mel is not 2-D with ``n_feats`` columns.
    """
    import os

    # ---- Config: only values actually consumed below ----
    add_blank = True  # intersperse a blank token between symbol ids
    n_feats = 80
    # encoder parameters
    n_enc_channels = 192
    filter_channels = 768
    filter_channels_dp = 256
    n_enc_layers = 6
    enc_kernel = 3
    enc_dropout = 0.1
    n_heads = 2
    window_size = 4
    length_scale = 1.0

    # ---- Inputs ----
    mel = np.load(mel_path)
    # The MAS matmuls below contract mu_x's n_feats channels against the mel's
    # channels; fail early with a clear message instead of a shape error there.
    if mel.ndim != 2 or mel.shape[1] != n_feats:
        raise ValueError(
            f'expected a mel of shape (T, {n_feats}) in {mel_path}, got {mel.shape}')
    os.makedirs(out_dir, exist_ok=True)

    # n_vocab is len(symbols) + 1 because the blank token id is len(symbols).
    encoder = TextEncoder(len(symbols) + 1, n_feats, n_enc_channels,
                          filter_channels, filter_channels_dp, n_heads,
                          n_enc_layers, enc_kernel, enc_dropout, window_size)
    # NOTE(review): the encoder is never switched to eval(), so if it is an
    # nn.Module, dropout stays active in both paths below.

    cmu = cmudict.CMUDict(cmu_path)
    # NOTE(review): presumably the transcript of the clip at mel_path — verify.
    text = "They did not attack the themes of the book."
    tokens = text_to_sequence(text, dictionary=cmu)
    if add_blank:
        tokens = intersperse(tokens, len(symbols))
    x = torch.LongTensor(tokens)[None]
    x_lengths = torch.LongTensor([x.shape[-1]])
    mu_x, logw, x_mask = encoder(x, x_lengths, None)  # no speaker input

    # ---- Inference path: predicted durations -> hard alignment -> mu_y ----
    w = torch.exp(logw) * x_mask
    w_ceil = torch.ceil(w) * length_scale
    y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
    y_max_length = int(y_lengths.max())
    y_max_length_ = fix_len_compatibility(y_max_length)
    y_mask = sequence_mask(y_lengths, y_max_length_).unsqueeze(1).to(x_mask.dtype)
    attn_mask = x_mask.unsqueeze(-1) * y_mask.unsqueeze(2)
    attn = generate_path(w_ceil.squeeze(1), attn_mask.squeeze(1)).unsqueeze(1)
    mu_y = torch.matmul(attn.squeeze(1).transpose(1, 2),
                        mu_x.transpose(1, 2)).transpose(1, 2)
    # Drop the padding that fix_len_compatibility added.
    encoder_outputs = mu_y[:, :, :y_max_length][0].detach().numpy()
    _save_heatmap(encoder_outputs,
                  [0, encoder_outputs.shape[1], 0, encoder_outputs.shape[0]],
                  'Phoneme Encoding',
                  os.path.join(out_dir, 'example_untrained_phone_encoding.png'))

    # ---- Training path: MAS alignment against the real mel + losses ----
    y = torch.Tensor(mel.T).unsqueeze(0)  # (1, n_feats, T_y)
    y_max_length = y.shape[-1]
    y_lengths = torch.LongTensor([y_max_length])
    y_mask = sequence_mask(y_lengths, y_max_length).unsqueeze(1).to(x_mask)
    attn_mask = x_mask.unsqueeze(-1) * y_mask.unsqueeze(2)
    with torch.no_grad():
        # Per (token, frame) log-likelihood under a unit-variance Gaussian
        # centred on mu_x: -0.5 * ||y - mu||^2 + const, expanded into three
        # matmul-friendly terms.
        const = -0.5 * math.log(2 * math.pi) * n_feats
        factor = -0.5 * torch.ones(mu_x.shape, dtype=mu_x.dtype, device=mu_x.device)
        y_square = torch.matmul(factor.transpose(1, 2), y ** 2)
        y_mu_double = torch.matmul(2.0 * (factor * mu_x).transpose(1, 2), y)
        mu_square = torch.sum(factor * (mu_x ** 2), 1).unsqueeze(-1)
        log_prior = y_square - y_mu_double + mu_square + const
        attn = monotonic_align.maximum_path(log_prior, attn_mask.squeeze(1))
        attn = attn.detach()

    # Duration loss between predicted log-durations and MAS-derived ones
    # (frames per token = row sums of the alignment).
    logw_ = torch.log(1e-8 + torch.sum(attn.unsqueeze(1), -1)) * x_mask
    dur_loss = duration_loss(logw, logw_, x_lengths)
    # Expand encoder outputs along the MAS alignment and score against the mel.
    mu_y = torch.matmul(attn.squeeze(1).transpose(1, 2),
                        mu_x.transpose(1, 2)).transpose(1, 2)
    prior_loss = torch.sum(0.5 * ((y - mu_y) ** 2 + math.log(2 * math.pi)) * y_mask)
    prior_loss = prior_loss / (torch.sum(y_mask) * n_feats)

    # ---- Diagnostic plots ----
    attn_np = attn.numpy()
    mu_y_np = mu_y.detach().numpy()
    # Rows of the alignment are text tokens, not mel bands.
    _save_heatmap(attn_np.squeeze(0), [0, attn_np.shape[2], 0, attn_np.shape[1]],
                  'Untrained Duration', os.path.join(out_dir, 'example_duration.png'),
                  ylabel='Phoneme Index')
    _save_heatmap(mu_y_np.squeeze(0), [0, mu_y_np.shape[2], 0, mu_y_np.shape[1]],
                  'Untrained Alignment', os.path.join(out_dir, 'example_MAS.png'))
    _save_heatmap(mel.T, [0, mel.shape[0], 0, mel.shape[1]],
                  'Goal Mel Spectrogram', os.path.join(out_dir, 'example_mel.png'))

    # The test previously computed both losses without checking them.
    assert torch.isfinite(dur_loss).all(), f'duration loss is not finite: {dur_loss}'
    assert torch.isfinite(prior_loss).all(), f'prior loss is not finite: {prior_loss}'
if __name__ == "__main__":
    # Direct-execution entry point (pytest collects the tests on its own):
    # run the CMU parsing check first, then the encoder smoke test.
    for _smoke_test in (test_cmu_parser, test_phone_encoder):
        _smoke_test()