"""
test_phone_encoder.py
Desc: Check to make sure that using the Grad-TTS Encoder will work
"""
import sys
sys.path.append('./')
import math

import torch
import numpy as np
import matplotlib.pyplot as plt

from text import text_to_sequence, cmudict, get_arpabet, _symbol_to_id
from text.symbols import symbols
from models.phoneme_encoder import TextEncoder
from models.utils import intersperse, sequence_mask, generate_path, duration_loss, fix_len_compatibility
from models import monotonic_align


def test_cmu_parser():
    cmu = cmudict.CMUDict('./resources/cmu_dictionary')
    text = "Here I go breaking audio models again."
    x = torch.LongTensor(intersperse(text_to_sequence(text, dictionary=cmu), len(symbols)))[None]
    x_lengths = torch.LongTensor([x.shape[-1]])
    arpabet_example = get_arpabet("Here", cmu)
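    # `intersperse` pads the blank-token id (len(symbols)) between every symbol
    # id, so the interspersed sequence always has odd length 2 * n + 1.
    assert x.shape[-1] % 2 == 1
    # `get_arpabet` should return the braced ARPAbet pronunciation when the word
    # is in the dictionary, e.g. something like "{HH IH1 R}" for "Here".
    assert isinstance(arpabet_example, str) and len(arpabet_example) > 0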
"""
test_phone_encoder
Desc: function for ensuring that the Text Encoder works with params
"""
def test_phone_encoder():
    # Load in a sample mel spectrogram
    mel = np.load('/data/jiachenlian/VCTK/mels_16k/p225/p225_219.npy')  # Shape (T, C), where C is the mel-channel count

    # Speech config values
    add_blank = True
    n_feats = 80
    n_spks = 1  # 247 for the Libri-TTS filelist, 1 for LJSpeech
    spk_emb_dim = 64
    n_fft = 1024
    sample_rate = 22050
    hop_length = 256
    win_length = 1024
    f_min = 0
    f_max = 8000

    # Encoder parameters
    n_enc_channels = 192
    filter_channels = 768
    filter_channels_dp = 256
    n_enc_layers = 6
    enc_kernel = 3
    enc_dropout = 0.1
    n_heads = 2
    window_size = 4
    length_scale = 1.0

    # Signature for instantiating the encoder:
    # encoder = TextEncoder(n_vocab, n_feats, n_enc_channels,
    #                       filter_channels, filter_channels_dp, n_heads,
    #                       n_enc_layers, enc_kernel, enc_dropout, window_size)
    # Example declaration; the vocabulary gets one extra slot for the blank
    # token that `intersperse` inserts (its id is len(symbols)).
    encoder = TextEncoder(len(symbols) + 1, n_feats, n_enc_channels,
                          filter_channels, filter_channels_dp, n_heads,
                          n_enc_layers, enc_kernel, enc_dropout, window_size)

    # Get parsed text
    cmu = cmudict.CMUDict('./resources/cmu_dictionary')
    # Example transcript from the same VCTK clip
    # text = "We used to live with dignity in our country."
    text = "They did not attack the themes of the book."
    x = torch.LongTensor(intersperse(text_to_sequence(text, dictionary=cmu), len(symbols)))[None]
    x_lengths = torch.LongTensor([x.shape[-1]])
    mu_x, logw, x_mask = encoder(x, x_lengths, None)  # spk is None for the single-speaker case
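    # Expected output shapes (following the Grad-TTS conventions):
    #   mu_x:   (B, n_feats, T_x) - per-token means of the Gaussian prior
    #   logw:   (B, 1, T_x)       - predicted log durations
    #   x_mask: (B, 1, T_x)       - padding mask over the token axis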

    # Inference-time code
    w = torch.exp(logw) * x_mask
    w_ceil = torch.ceil(w) * length_scale
    y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
    y_max_length = int(y_lengths.max())
    y_max_length_ = fix_len_compatibility(y_max_length)
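    # In Grad-TTS, fix_len_compatibility rounds the target length up so the
    # decoder's U-Net can downsample it evenly; the extra frames are masked out.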
    # Using the obtained durations `w`, construct the alignment map `attn`
    y_mask = sequence_mask(y_lengths, y_max_length_).unsqueeze(1).to(x_mask.dtype)
    attn_mask = x_mask.unsqueeze(-1) * y_mask.unsqueeze(2)
    attn = generate_path(w_ceil.squeeze(1), attn_mask.squeeze(1)).unsqueeze(1)
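    # generate_path (as in Grad-TTS) turns durations into a hard binary
    # alignment: token i claims w_ceil[i] consecutive frames, so durations
    # [2, 1] would yield rows [1, 1, 0] and [0, 0, 1] over three frames.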
    # Align the encoded text to get mu_y
    mu_y = torch.matmul(attn.squeeze(1).transpose(1, 2), mu_x.transpose(1, 2))
    mu_y = mu_y.transpose(1, 2)
    encoder_outputs = mu_y[:, :, :y_max_length][0].detach().numpy()

    # Plot the phoneme encodings
    plt.figure(figsize=(10, 4))
    plt.imshow(encoder_outputs, aspect='auto', origin='lower',
               extent=[0, encoder_outputs.shape[1], 0, encoder_outputs.shape[0]])
    plt.colorbar(label='Intensity')
    plt.xlabel('Time')
    plt.ylabel('Mel Frequency Bands')
    plt.title('Phoneme Encoding')
    plt.savefig('./assets/example_untrained_phone_encoding.png')

    # Train-time code
    # Check that the duration loss works
    y = torch.Tensor(mel.T).unsqueeze(0)  # (1, C, T)
    y_lengths = torch.LongTensor([y.shape[-1]])
    y_max_length = y.shape[-1]
    y_mask = sequence_mask(y_lengths, y_max_length).unsqueeze(1).to(x_mask.dtype)
    attn_mask = x_mask.unsqueeze(-1) * y_mask.unsqueeze(2)
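    # Note: at train time the target length comes from the ground-truth mel
    # (y.shape[-1]) rather than from the predicted durations used above.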
    # Use MAS to find the most likely alignment `attn` between text and mel-spectrogram
    with torch.no_grad():
        const = -0.5 * math.log(2 * math.pi) * n_feats
        factor = -0.5 * torch.ones(mu_x.shape, dtype=mu_x.dtype, device=mu_x.device)
        y_square = torch.matmul(factor.transpose(1, 2), y ** 2)
        y_mu_double = torch.matmul(2.0 * (factor * mu_x).transpose(1, 2), y)
        mu_square = torch.sum(factor * (mu_x ** 2), 1).unsqueeze(-1)
        log_prior = y_square - y_mu_double + mu_square + const
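        # Derivation: factor is -0.5 everywhere, so
        #   y_square    = -0.5 * sum_c y_c^2
        #   y_mu_double = -sum_c mu_c * y_c
        #   mu_square   = -0.5 * sum_c mu_c^2
        # hence log_prior[i, j] = -0.5 * sum_c (y_c[j] - mu_c[i])^2 + const,
        # i.e. the log-likelihood of frame j under N(mu_x[:, i], I), computed
        # with batched matmuls instead of an explicit T_x x T_y loop.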
        attn = monotonic_align.maximum_path(log_prior, attn_mask.squeeze(1))
        attn = attn.detach()
    attn_np = attn.numpy()

    # Compute loss between predicted log-scaled durations and those obtained from MAS
    logw_ = torch.log(1e-8 + torch.sum(attn.unsqueeze(1), -1)) * x_mask
    dur_loss = duration_loss(logw, logw_, x_lengths)
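    # Summing `attn` over the frame axis gives the number of frames MAS assigned
    # to each token, so logw_ holds the "ground-truth" log durations; in
    # Grad-TTS, duration_loss is an MSE against the predicted logw, normalized
    # by the input lengths.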

    # Align the text with the mel-spectrogram to get mu_y
    mu_y = torch.matmul(attn.squeeze(1).transpose(1, 2), mu_x.transpose(1, 2))
    mu_y = mu_y.transpose(1, 2)
    mu_y_np = mu_y.detach().numpy()
    # Compute loss between the aligned encoder outputs and the mel-spectrogram
    prior_loss = torch.sum(0.5 * ((y - mu_y) ** 2 + math.log(2 * math.pi)) * y_mask)
    prior_loss = prior_loss / (torch.sum(y_mask) * n_feats)
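    # prior_loss is the negative log-likelihood of the mel frames under
    # N(mu_y, I), averaged over the valid frame/channel entries picked out by
    # y_mask: 0.5 * ((y - mu)^2 + log(2*pi)) is -log N(y; mu, 1) per dimension.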

    # Plot the MAS alignment map (input tokens vs. mel frames)
    plt.figure(figsize=(10, 4))
    plt.imshow(attn_np.squeeze(0), aspect='auto', origin='lower',
               extent=[0, attn_np.shape[2], 0, attn_np.shape[1]])
    plt.colorbar(label='Alignment')
    plt.xlabel('Time')
    plt.ylabel('Input Token Index')
    plt.title('Untrained Duration')
    plt.savefig('./assets/example_duration.png')

    # Plot the aligned encoder outputs (mu_y)
    plt.figure(figsize=(10, 4))
    plt.imshow(mu_y_np.squeeze(0), aspect='auto', origin='lower',
               extent=[0, mu_y_np.shape[2], 0, mu_y_np.shape[1]])
    plt.colorbar(label='Intensity')
    plt.xlabel('Time')
    plt.ylabel('Mel Frequency Bands')
    plt.title('Untrained Alignment')
    plt.savefig('./assets/example_MAS.png')

    # Plot the example mel
    plt.figure(figsize=(10, 4))
    plt.imshow(mel.T, aspect='auto', origin='lower',
               extent=[0, mel.shape[0], 0, mel.shape[1]])
    plt.colorbar(label='Intensity')
    plt.xlabel('Time')
    plt.ylabel('Mel Frequency Bands')
    plt.title('Goal Mel Spectrogram')
    plt.savefig('./assets/example_mel.png')


if __name__ == "__main__":
    test_cmu_parser()
    test_phone_encoder()