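# NOTE: the lines originally here ("Spaces: Sleeping Sleeping") were status
# text from the hosting web page, not part of the module.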
import math
from typing import Optional

import torch
import torch.nn as nn

import monotonic_align
from models.text_encoder import TextEncoder
from models.flow_matching import CFMDecoder
from models.reference_encoder import MelStyleEncoder
from models.duration_predictor import DurationPredictor, duration_loss
def sequence_mask(length: torch.Tensor, max_length: Optional[int] = None) -> torch.Tensor:
    """Build a boolean mask marking the valid positions of each sequence.

    Args:
        length: 1-D tensor of sequence lengths, one entry per sequence.
        max_length: Width of the mask. When omitted, ``length.max()`` is used.

    Returns:
        Boolean tensor of shape ``(len(length), max_length)`` where entry
        ``[i, j]`` is True iff ``j < length[i]``.
    """
    if max_length is None:
        max_length = length.max()
    # Positions are created with the dtype/device of `length` so the
    # comparison below needs no casts or device transfers.
    positions = torch.arange(max_length, dtype=length.dtype, device=length.device)
    return positions.unsqueeze(0) < length.unsqueeze(1)
def convert_pad_shape(pad_shape):
    """Flatten a per-dimension pad spec into the order used by ``F.pad``.

    Args:
        pad_shape: Sequence of ``[before, after]`` pairs listed from the
            first dimension to the last.

    Returns:
        A flat list with the pairs in reversed dimension order (last
        dimension first), e.g. ``[[0, 0], [1, 0], [0, 0]]`` becomes
        ``[0, 0, 1, 0, 0, 0]``.
    """
    return [amount for dim_pads in pad_shape[::-1] for amount in dim_pads]
def generate_path(duration, mask):
    """Expand per-token durations into a hard monotonic alignment map.

    Args:
        duration: Tensor of shape ``(b, t_x)`` with the number of output
            frames assigned to each input token.
        mask: Tensor of shape ``(b, t_x, t_y)``; the result is multiplied by
            it, so masked-out positions are zero.

    Returns:
        Tensor of shape ``(b, t_x, t_y)`` with dtype ``mask.dtype``. Row ``i``
        of a batch item is 1 on the frames whose index lies in
        ``[cum_duration[i - 1], cum_duration[i])`` and 0 elsewhere.
    """
    b, t_x, t_y = mask.shape
    cum_duration = torch.cumsum(duration, 1)

    # For every token, mark all frames that come before the END of that token.
    cum_duration_flat = cum_duration.view(b * t_x)
    path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
    path = path.view(b, t_x, t_y)

    # Subtract the previous token's row (rows shifted down by one along the
    # token axis) so only the frames belonging to the token itself remain.
    shifted = torch.nn.functional.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
    path = path - shifted
    return path * mask
# Modified from https://github.com/shivammehta25/Matcha-TTS/blob/main/matcha/models/matcha_tts.py
class StableTTS(nn.Module):
    """Text-to-mel model: text encoder, reference (style) encoder, duration
    predictor and a flow-matching decoder.

    ``forward`` computes the training losses; ``synthesise`` generates a
    mel-spectrogram from text plus a reference mel-spectrogram.
    """

    def __init__(self, n_vocab, mel_channels, hidden_channels, filter_channels, n_heads, n_enc_layers, n_dec_layers, kernel_size, p_dropout, gin_channels):
        super().__init__()

        self.n_vocab = n_vocab
        self.mel_channels = mel_channels

        self.encoder = TextEncoder(n_vocab, mel_channels, hidden_channels, filter_channels, n_heads, n_enc_layers, kernel_size, p_dropout, gin_channels)
        self.ref_encoder = MelStyleEncoder(mel_channels, style_vector_dim=gin_channels, style_kernel_size=3)
        self.dp = DurationPredictor(hidden_channels, filter_channels, kernel_size, p_dropout, gin_channels)
        # The decoder's first argument is twice the mel channel count.
        self.decoder = CFMDecoder(mel_channels + mel_channels, mel_channels, filter_channels, n_heads, n_dec_layers, kernel_size, p_dropout, gin_channels)

    def synthesise(self, x, x_lengths, n_timesteps, temperature=1.0, y=None, length_scale=1.0):
        """
        Generates mel-spectrogram from text. Returns:
            1. encoder outputs
            2. decoder outputs
            3. generated alignment

        Args:
            x (torch.Tensor): batch of texts, converted to a tensor with phoneme embedding ids.
                shape: (batch_size, max_text_length)
            x_lengths (torch.Tensor): lengths of texts in batch.
                shape: (batch_size,)
            n_timesteps (int): number of sampling steps, forwarded to the decoder.
            temperature (float, optional): forwarded to the decoder; controls variance of terminal distribution.
            y (torch.Tensor): mel spectrogram of reference audio
                shape: (batch_size, mel_channels, time)
                NOTE(review): despite the ``None`` default, ``y`` is passed to the
                reference encoder unconditionally -- presumably it is required.
            length_scale (float, optional): controls speech pace.
                Increase value to slow down generated speech and vice versa.

        Returns:
            dict: {
                "encoder_outputs": torch.Tensor, shape: (batch_size, n_feats, max_mel_length),
                # Average mel spectrogram generated by the encoder
                "decoder_outputs": torch.Tensor, shape: (batch_size, n_feats, max_mel_length),
                # Refined mel spectrogram improved by the CFM
                "attn": torch.Tensor, shape: (batch_size, 1, max_text_length, max_mel_length),
                # Alignment map between text and mel spectrogram
            }
        """
        # Style vector from the reference mel (no mask at inference time).
        c = self.ref_encoder(y, None)

        # Get encoder_outputs `mu_x` and log-scaled token durations `logw`
        x, mu_x, x_mask = self.encoder(x, c, x_lengths)
        logw = self.dp(x, x_mask, c)

        w = torch.exp(logw) * x_mask
        w_ceil = torch.ceil(w) * length_scale
        # Every utterance gets at least one output frame.
        y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
        y_max_length = y_lengths.max()

        # Using obtained durations `w` construct alignment map `attn`
        y_mask = sequence_mask(y_lengths, y_max_length).unsqueeze(1).to(x_mask.dtype)
        attn_mask = x_mask.unsqueeze(-1) * y_mask.unsqueeze(2)
        attn = generate_path(w_ceil.squeeze(1), attn_mask.squeeze(1)).unsqueeze(1)

        # Align encoded text and get mu_y
        mu_y = torch.matmul(attn.squeeze(1).transpose(1, 2), mu_x.transpose(1, 2))
        mu_y = mu_y.transpose(1, 2)
        encoder_outputs = mu_y[:, :, :y_max_length]

        # Generate sample tracing the probability flow
        decoder_outputs = self.decoder(mu_y, y_mask, n_timesteps, temperature, c)
        decoder_outputs = decoder_outputs[:, :, :y_max_length]

        return {
            "encoder_outputs": encoder_outputs,
            "decoder_outputs": decoder_outputs,
            # `attn` is 4-D here (batch, 1, text, mel), so trim the LAST (mel)
            # axis. Slicing axis 2 would cut the text axis instead and drop
            # tokens whenever there are fewer frames than tokens.
            "attn": attn[..., :y_max_length],
        }

    def forward(self, x, x_lengths, y, y_lengths):
        """
        Computes 3 losses:
            1. duration loss: loss between predicted token durations and those extracted by Monotonic Alignment Search (MAS).
            2. prior loss: loss between mel-spectrogram and encoder outputs.
            3. flow matching loss: loss between mel-spectrogram and decoder outputs.

        Args:
            x (torch.Tensor): batch of texts, converted to a tensor with phoneme embedding ids.
                shape: (batch_size, max_text_length)
            x_lengths (torch.Tensor): lengths of texts in batch.
                shape: (batch_size,)
            y (torch.Tensor): batch of corresponding mel-spectrograms.
                shape: (batch_size, n_feats, max_mel_length)
            y_lengths (torch.Tensor): lengths of mel-spectrograms in batch.
                shape: (batch_size,)

        Returns:
            tuple: (dur_loss, diff_loss, prior_loss, attn) where ``attn`` is the
            MAS alignment, transposed to put the text axis before the mel axis
            (assuming ``monotonic_align.maximum_path`` preserves the shape of
            its input -- TODO confirm).
        """
        y_mask = sequence_mask(y_lengths, y.size(2)).unsqueeze(1).to(y.dtype)
        c = self.ref_encoder(y, y_mask)

        # Get encoder_outputs `mu_x` and log-scaled token durations `logw`
        x, mu_x, x_mask = self.encoder(x, c, x_lengths)
        logw = self.dp(x, x_mask, c)

        # Use MAS to find most likely alignment `attn` between text and mel-spectrogram.
        # The original author reports that the MAS code from Matcha-TTS / Grad-TTS did not
        # align here, so the formulation from
        # https://github.com/p0p4k/pflowtts_pytorch/blob/master/pflow/models/pflow_tts.py is used.
        with torch.no_grad():
            # Unit precision: every encoder output is treated as the mean of a
            # unit-variance Gaussian, so the four terms below sum to
            # log N(y_frame; mu_x_token, I) for every (frame, token) pair.
            s_p_sq_r = torch.ones_like(mu_x)
            neg_cent1 = torch.sum(
                -0.5 * math.log(2 * math.pi) - torch.zeros_like(mu_x), [1], keepdim=True
            )
            neg_cent2 = torch.einsum("bdt, bds -> bts", -0.5 * (y**2), s_p_sq_r)
            neg_cent3 = torch.einsum("bdt, bds -> bts", y, (mu_x * s_p_sq_r))
            neg_cent4 = torch.sum(
                -0.5 * (mu_x**2) * s_p_sq_r, [1], keepdim=True
            )
            # Indexed (batch, mel frame, text token), per the einsum output "bts".
            neg_cent = neg_cent1 + neg_cent2 + neg_cent3 + neg_cent4

            # Same (frame, token) layout as `neg_cent`, with an extra size-1
            # axis at dim 1 that is squeezed away before the search.
            attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
            attn = (
                monotonic_align.maximum_path(neg_cent, attn_mask.squeeze(1)).unsqueeze(1).detach()
            )

        # Compute loss between predicted log-scaled durations and those obtained from MAS.
        # Summing over dim 2 (mel frames) gives the frame count per token.
        logw_ = torch.log(1e-8 + attn.sum(2)) * x_mask
        dur_loss = duration_loss(logw, logw_, x_lengths)

        # Align encoded text with mel-spectrogram and get mu_y segment.
        # Drop the size-1 axis and put the text axis first.
        attn = attn.squeeze(1).transpose(1, 2)
        # No further squeeze here: `attn` is already 3-D, and squeezing dim 1
        # would collapse the text axis when the batch's max text length is 1.
        mu_y = torch.matmul(attn.transpose(1, 2), mu_x.transpose(1, 2))
        mu_y = mu_y.transpose(1, 2)

        # Compute loss of the decoder
        diff_loss, _ = self.decoder.compute_loss(y, y_mask, mu_y, c)

        # Prior loss: negative log-likelihood of y under N(mu_y, I), averaged
        # over valid frames and mel channels.
        prior_loss = torch.sum(0.5 * ((y - mu_y) ** 2 + math.log(2 * math.pi)) * y_mask)
        prior_loss = prior_loss / (torch.sum(y_mask) * self.mel_channels)

        return dur_loss, diff_loss, prior_loss, attn