# ALeLacheur's picture
# uploading audio diffusion attacks
# 5a9b731
# raw
# history blame
# 6.7 kB
"""
test_phone_encoder.py
Desc: Check to make sure that using the Grad-TTS Encoder will work
"""
import sys
sys.path.append('./')
import torch
import numpy as np
import math
from text import text_to_sequence, cmudict
from text.symbols import symbols
from models.utils import intersperse
from models.phoneme_encoder import TextEncoder
from models.utils import sequence_mask, generate_path, duration_loss, fix_len_compatibility
from models import monotonic_align
from text import get_arpabet, _symbol_to_id
import matplotlib.pyplot as plt
def test_cmu_parser(cmu_path='./resources/cmu_dictionary'):
    """Check that text is parsed into valid token ids using the CMU dictionary.

    Args:
        cmu_path: path to the CMU pronouncing dictionary loaded by
            ``cmudict.CMUDict``. Defaults to the repository resource.
    """
    cmu = cmudict.CMUDict(cmu_path)
    text = "Here I go breaking audio models again."
    x = torch.LongTensor(
        intersperse(text_to_sequence(text, dictionary=cmu), len(symbols)))[None]
    x_lengths = torch.LongTensor([x.shape[-1]])

    # A batch containing a single, non-empty sequence whose recorded length
    # matches the tensor.
    assert x.dim() == 2 and x.shape[0] == 1
    assert x.shape[-1] > 0
    assert int(x_lengths[0]) == x.shape[-1]
    # The encoder is built with a vocabulary of len(symbols) + 1: the extra
    # id (len(symbols)) is the value interspersed between the phoneme ids.
    assert int(x.min()) >= 0 and int(x.max()) <= len(symbols)

    arpabet_example = get_arpabet("Here", cmu)
    assert arpabet_example, "get_arpabet returned an empty result for 'Here'"
"""
test_phone_encoder
Desc: function for ensuring that the Text Encoder works with params
"""
def test_phone_encoder():
# Load in Sample Mel Spec
mel = np.load('/data/jiachenlian/VCTK/mels_16k/p225/p225_219.npy') # Of shape (T, C) where C is channel num
# Speech Config for Values set below
add_blank = True
n_feats = 80
n_spks = 1 # 247 for Libri-TTS filelist and 1 for LJSpeech
spk_emb_dim = 64
n_feats = 80
n_fft = 1024
sample_rate = 22050
hop_length = 256
win_length = 1024
f_min = 0
f_max = 8000
# encoder parameters
n_enc_channels = 192
filter_channels = 768
filter_channels_dp = 256
n_enc_layers = 6
enc_kernel = 3
enc_dropout = 0.1
n_heads = 2
window_size = 4
length_scale = 1.0
# Format for instantiating encoder
# encoder = TextEncoder(n_vocab, n_feats, n_enc_channels,
# filter_channels, filter_channels_dp, n_heads,
# n_enc_layers, enc_kernel, enc_dropout, window_size)
# Example Declaration
encoder = TextEncoder(len(symbols) + 1, n_feats, n_enc_channels,
filter_channels, filter_channels_dp, n_heads,
n_enc_layers, enc_kernel, enc_dropout, window_size)
# Get Parsed Text
cmu = cmudict.CMUDict('./resources/cmu_dictionary')
# Example transcript from the same VCTK clip
# text = "We used to live with dignity in our country."
text = "They did not attack the themes of the book."
x = torch.LongTensor(intersperse(text_to_sequence(text, dictionary=cmu), len(symbols)))[None]
x_lengths = torch.LongTensor([x.shape[-1]])
mu_x, logw, x_mask = encoder(x, x_lengths, None) # Pass in None for spk rn
# Inference Time Code
w = torch.exp(logw) * x_mask
w_ceil = torch.ceil(w) * length_scale
y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
y_max_length = int(y_lengths.max())
y_max_length_ = fix_len_compatibility(y_max_length)
# Using obtained durations `w` construct alignment map `attn`
y_mask = sequence_mask(y_lengths, y_max_length_).unsqueeze(1).to(x_mask.dtype)
attn_mask = x_mask.unsqueeze(-1) * y_mask.unsqueeze(2)
attn = generate_path(w_ceil.squeeze(1), attn_mask.squeeze(1)).unsqueeze(1)
# Align encoded text and get mu_y
mu_y = torch.matmul(attn.squeeze(1).transpose(1, 2), mu_x.transpose(1, 2))
mu_y = mu_y.transpose(1, 2)
encoder_outputs = mu_y[:, :, :y_max_length][0].detach().numpy()
# Plotting the phoneme encodings
plt.figure(figsize=(10, 4))
plt.imshow(encoder_outputs, aspect='auto', origin='lower',
extent=[0, encoder_outputs.shape[1], 0, encoder_outputs.shape[0]])
plt.colorbar(label='Intensity')
plt.xlabel('Time')
plt.ylabel('Mel Frequency Bands')
plt.title('Phoneme Encoding')
plt.savefig('./assets/example_untrained_phone_encoding.png')
# Train Time Code
# Test out that duration loss works
y = torch.Tensor(mel.T).unsqueeze(0)
y_lengths = [y.shape[-1]]
y_max_length = y.shape[-1]
y_lengths = torch.LongTensor([y.shape[-1]])
y_mask = sequence_mask(y_lengths, y_max_length).unsqueeze(1).to(x_mask)
attn_mask = x_mask.unsqueeze(-1) * y_mask.unsqueeze(2)
# Use MAS to find most likely alignment `attn` between text and mel-spectrogram
with torch.no_grad():
const = -0.5 * math.log(2 * math.pi) * n_feats
factor = -0.5 * torch.ones(mu_x.shape, dtype=mu_x.dtype, device=mu_x.device)
y_square = torch.matmul(factor.transpose(1, 2), y ** 2)
y_mu_double = torch.matmul(2.0 * (factor * mu_x).transpose(1, 2), y)
mu_square = torch.sum(factor * (mu_x ** 2), 1).unsqueeze(-1)
log_prior = y_square - y_mu_double + mu_square + const
attn = monotonic_align.maximum_path(log_prior, attn_mask.squeeze(1))
attn = attn.detach()
attn_np = attn.numpy()
# Compute loss between predicted log-scaled durations and those obtained from MAS
logw_ = torch.log(1e-8 + torch.sum(attn.unsqueeze(1), -1)) * x_mask
dur_loss = duration_loss(logw, logw_, x_lengths)
# Align text with mel-spec to get mu_y
mu_y = torch.matmul(attn.squeeze(1).transpose(1, 2), mu_x.transpose(1, 2))
mu_y = mu_y.transpose(1, 2)
mu_y_np = mu_y.detach().numpy()
# Compute loss between aligned encoder outputs and mel-spectrogram
prior_loss = torch.sum(0.5 * ((y - mu_y) ** 2 + math.log(2 * math.pi)) * y_mask)
prior_loss = prior_loss / (torch.sum(y_mask) * n_feats)
# Plot the Aligned Text with Mel-Spec
plt.figure(figsize=(10, 4))
plt.imshow(attn_np.squeeze(0), aspect='auto', origin='lower',
extent=[0, attn_np.shape[2], 0, attn_np.shape[1]])
plt.colorbar(label='Intensity')
plt.xlabel('Time')
plt.ylabel('Mel Frequency Bands')
plt.title('Untrained Duration')
plt.savefig('./assets/example_duration.png')
# Plot the Aligned Text with Mel-Spec
plt.figure(figsize=(10, 4))
plt.imshow(mu_y_np.squeeze(0), aspect='auto', origin='lower',
extent=[0, mu_y_np.shape[2], 0, mu_y_np.shape[1]])
plt.colorbar(label='Intensity')
plt.xlabel('Time')
plt.ylabel('Mel Frequency Bands')
plt.title('Untrained Alignment')
plt.savefig('./assets/example_MAS.png')
# Plot Example Mel
plt.figure(figsize=(10, 4))
plt.imshow(mel.T, aspect='auto', origin='lower',
extent=[0, mel.shape[0], 0, mel.shape[1]])
plt.colorbar(label='Intensity')
plt.xlabel('Time')
plt.ylabel('Mel Frequency Bands')
plt.title('Goal Mel Spectrogram')
plt.savefig('./assets/example_mel.png')
if __name__ == "__main__":
test_cmu_parser()
test_phone_encoder()