""" | |
BSD 3-Clause License | |
Copyright (c) 2018, NVIDIA Corporation | |
All rights reserved. | |
Redistribution and use in source and binary forms, with or without | |
modification, are permitted provided that the following conditions are met: | |
* Redistributions of source code must retain the above copyright notice, this | |
list of conditions and the following disclaimer. | |
* Redistributions in binary form must reproduce the above copyright notice, | |
this list of conditions and the following disclaimer in the documentation | |
and/or other materials provided with the distribution. | |
* Neither the name of the copyright holder nor the names of its | |
contributors may be used to endorse or promote products derived from | |
this software without specific prior written permission. | |
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | |
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | |
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE | |
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE | |
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL | |
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR | |
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER | |
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, | |
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE | |
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |
""" | |
from math import sqrt

import torch
from torch.autograd import Variable
from torch import nn
from torch.nn import functional as F

from training.tacotron2_model.layers import ConvNorm, LinearNorm
from training.tacotron2_model.utils import to_gpu, get_mask_from_lengths, get_x

class LocationLayer(nn.Module):
    def __init__(self, attention_n_filters, attention_kernel_size, attention_dim):
        super(LocationLayer, self).__init__()
        padding = int((attention_kernel_size - 1) / 2)
        self.location_conv = ConvNorm(
            2, attention_n_filters, kernel_size=attention_kernel_size, padding=padding, bias=False, stride=1, dilation=1
        )
        self.location_dense = LinearNorm(attention_n_filters, attention_dim, bias=False, w_init_gain="tanh")

    def forward(self, attention_weights_cat):
        processed_attention = self.location_conv(attention_weights_cat)
        processed_attention = processed_attention.transpose(1, 2)
        processed_attention = self.location_dense(processed_attention)
        return processed_attention
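
# Note (added comment): the Attention module below implements the location-sensitive
# attention used by Tacotron 2. For each encoder timestep i it scores roughly
#     e_i = v^T tanh(W * query + U * location_features_i + V * memory_i)
# where the location features come from LocationLayer applied to the previous and
# cumulative attention weights, and padded positions are masked to -inf before the
# softmax so they receive zero attention weight.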

class Attention(nn.Module):
    def __init__(
        self,
        attention_rnn_dim,
        embedding_dim,
        attention_dim,
        attention_location_n_filters,
        attention_location_kernel_size,
    ):
        super(Attention, self).__init__()
        self.query_layer = LinearNorm(attention_rnn_dim, attention_dim, bias=False, w_init_gain="tanh")
        self.memory_layer = LinearNorm(embedding_dim, attention_dim, bias=False, w_init_gain="tanh")
        self.v = LinearNorm(attention_dim, 1, bias=False)
        self.location_layer = LocationLayer(attention_location_n_filters, attention_location_kernel_size, attention_dim)
        self.score_mask_value = -float("inf")

    def get_alignment_energies(self, query, processed_memory, attention_weights_cat):
        """
        PARAMS
        ------
        query: decoder output (batch, n_mel_channels * n_frames_per_step)
        processed_memory: processed encoder outputs (B, T_in, attention_dim)
        attention_weights_cat: cumulative and prev. att weights (B, 2, max_time)

        RETURNS
        -------
        alignment (batch, max_time)
        """
        processed_query = self.query_layer(query.unsqueeze(1))
        processed_attention_weights = self.location_layer(attention_weights_cat)
        energies = self.v(torch.tanh(processed_query + processed_attention_weights + processed_memory))
        energies = energies.squeeze(-1)
        return energies

    def forward(self, attention_hidden_state, memory, processed_memory, attention_weights_cat, mask):
        """
        PARAMS
        ------
        attention_hidden_state: attention rnn last output
        memory: encoder outputs
        processed_memory: processed encoder outputs
        attention_weights_cat: previous and cumulative attention weights
        mask: binary mask for padded data
        """
        alignment = self.get_alignment_energies(attention_hidden_state, processed_memory, attention_weights_cat)

        if mask is not None:
            alignment.data.masked_fill_(mask, self.score_mask_value)

        attention_weights = F.softmax(alignment, dim=1)
        attention_context = torch.bmm(attention_weights.unsqueeze(1), memory)
        attention_context = attention_context.squeeze(1)
        return attention_context, attention_weights
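
# Note (added comment): the prenet below applies dropout with training=True, so it stays
# active even at inference time. This matches the Tacotron 2 paper, which keeps prenet
# dropout on during synthesis to introduce output variation.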

class Prenet(nn.Module):
    def __init__(self, in_dim, sizes):
        super(Prenet, self).__init__()
        in_sizes = [in_dim] + sizes[:-1]
        self.layers = nn.ModuleList(
            [LinearNorm(in_size, out_size, bias=False) for (in_size, out_size) in zip(in_sizes, sizes)]
        )

    def forward(self, x):
        for linear in self.layers:
            x = F.dropout(F.relu(linear(x)), p=0.5, training=True)
        return x
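
# Note (added comment): the postnet below refines the decoder output; Tacotron2.forward
# and Tacotron2.inference add its output to the decoder mels as a residual correction
# (mel_outputs_postnet = mel_outputs + postnet(mel_outputs)).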

class Postnet(nn.Module):
    """Postnet
    - Five 1-d convolutions with 512 channels and kernel size 5
    """

    def __init__(self, n_mel_channels, postnet_embedding_dim, postnet_kernel_size, postnet_n_convolutions):
        super(Postnet, self).__init__()
        self.convolutions = nn.ModuleList()

        self.convolutions.append(
            nn.Sequential(
                ConvNorm(
                    n_mel_channels,
                    postnet_embedding_dim,
                    kernel_size=postnet_kernel_size,
                    stride=1,
                    padding=int((postnet_kernel_size - 1) / 2),
                    dilation=1,
                    w_init_gain="tanh",
                ),
                nn.BatchNorm1d(postnet_embedding_dim),
            )
        )

        for i in range(1, postnet_n_convolutions - 1):
            self.convolutions.append(
                nn.Sequential(
                    ConvNorm(
                        postnet_embedding_dim,
                        postnet_embedding_dim,
                        kernel_size=postnet_kernel_size,
                        stride=1,
                        padding=int((postnet_kernel_size - 1) / 2),
                        dilation=1,
                        w_init_gain="tanh",
                    ),
                    nn.BatchNorm1d(postnet_embedding_dim),
                )
            )

        self.convolutions.append(
            nn.Sequential(
                ConvNorm(
                    postnet_embedding_dim,
                    n_mel_channels,
                    kernel_size=postnet_kernel_size,
                    stride=1,
                    padding=int((postnet_kernel_size - 1) / 2),
                    dilation=1,
                    w_init_gain="linear",
                ),
                nn.BatchNorm1d(n_mel_channels),
            )
        )

    def forward(self, x):
        for i in range(len(self.convolutions) - 1):
            x = F.dropout(torch.tanh(self.convolutions[i](x)), 0.5, self.training)
        x = F.dropout(self.convolutions[-1](x), 0.5, self.training)
        return x

class Encoder(nn.Module):
    """Encoder module:
    - Three 1-d convolution banks
    - Bidirectional LSTM
    """

    def __init__(self, encoder_kernel_size, encoder_n_convolutions, encoder_embedding_dim):
        super(Encoder, self).__init__()

        convolutions = []
        for _ in range(encoder_n_convolutions):
            conv_layer = nn.Sequential(
                ConvNorm(
                    encoder_embedding_dim,
                    encoder_embedding_dim,
                    kernel_size=encoder_kernel_size,
                    stride=1,
                    padding=int((encoder_kernel_size - 1) / 2),
                    dilation=1,
                    w_init_gain="relu",
                ),
                nn.BatchNorm1d(encoder_embedding_dim),
            )
            convolutions.append(conv_layer)
        self.convolutions = nn.ModuleList(convolutions)

        self.lstm = nn.LSTM(
            encoder_embedding_dim, int(encoder_embedding_dim / 2), 1, batch_first=True, bidirectional=True
        )

    def forward(self, x, input_lengths):
        for conv in self.convolutions:
            x = F.dropout(F.relu(conv(x)), 0.5, self.training)
        x = x.transpose(1, 2)

        # pytorch tensors are not reversible, hence the conversion to numpy
        input_lengths = input_lengths.cpu().numpy()
        x = nn.utils.rnn.pack_padded_sequence(x, input_lengths, batch_first=True)

        self.lstm.flatten_parameters()
        outputs, _ = self.lstm(x)
        outputs, _ = nn.utils.rnn.pad_packed_sequence(outputs, batch_first=True)
        return outputs

    def inference(self, x):
        for conv in self.convolutions:
            x = F.dropout(F.relu(conv(x)), 0.5, self.training)
        x = x.transpose(1, 2)
        self.lstm.flatten_parameters()
        outputs, _ = self.lstm(x)
        return outputs
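
# Note (added comment): each Decoder.decode() step runs the autoregressive loop used by
# Tacotron 2: the prenet-processed previous mel frame is concatenated with the current
# attention context and fed to the attention LSTM cell; a new attention context is then
# computed, both are fed to the decoder LSTM cell, and a linear projection produces the
# next mel frame plus a "gate" stop-token logit.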

class Decoder(nn.Module):
    def __init__(
        self,
        n_mel_channels,
        n_frames_per_step,
        encoder_embedding_dim,
        attention_dim,
        attention_rnn_dim,
        attention_location_n_filters,
        attention_location_kernel_size,
        decoder_rnn_dim,
        prenet_dim,
        max_decoder_steps,
        gate_threshold,
        p_attention_dropout,
        p_decoder_dropout,
    ):
        super(Decoder, self).__init__()
        self.n_mel_channels = n_mel_channels
        self.n_frames_per_step = n_frames_per_step
        self.encoder_embedding_dim = encoder_embedding_dim
        self.attention_rnn_dim = attention_rnn_dim
        self.decoder_rnn_dim = decoder_rnn_dim
        self.prenet_dim = prenet_dim
        self.max_decoder_steps = max_decoder_steps
        self.gate_threshold = gate_threshold
        self.p_attention_dropout = p_attention_dropout
        self.p_decoder_dropout = p_decoder_dropout

        self.prenet = Prenet(n_mel_channels * n_frames_per_step, [prenet_dim, prenet_dim])
        self.attention_rnn = nn.LSTMCell(prenet_dim + encoder_embedding_dim, attention_rnn_dim)
        self.attention_layer = Attention(
            attention_rnn_dim,
            encoder_embedding_dim,
            attention_dim,
            attention_location_n_filters,
            attention_location_kernel_size,
        )
        self.decoder_rnn = nn.LSTMCell(attention_rnn_dim + encoder_embedding_dim, decoder_rnn_dim, 1)
        self.linear_projection = LinearNorm(decoder_rnn_dim + encoder_embedding_dim, n_mel_channels * n_frames_per_step)
        self.gate_layer = LinearNorm(decoder_rnn_dim + encoder_embedding_dim, 1, bias=True, w_init_gain="sigmoid")

    def get_go_frame(self, memory):
        """Gets all zeros frames to use as first decoder input
        PARAMS
        ------
        memory: encoder outputs

        RETURNS
        -------
        decoder_input: all zeros frames
        """
        B = memory.size(0)
        decoder_input = Variable(memory.data.new(B, self.n_mel_channels * self.n_frames_per_step).zero_())
        return decoder_input

    def initialize_decoder_states(self, memory, mask):
        """Initializes attention rnn states, decoder rnn states, attention
        weights, attention cumulative weights, attention context, stores memory
        and stores processed memory
        PARAMS
        ------
        memory: Encoder outputs
        mask: Mask for padded data if training, expects None for inference
        """
        B = memory.size(0)
        MAX_TIME = memory.size(1)

        self.attention_hidden = Variable(memory.data.new(B, self.attention_rnn_dim).zero_())
        self.attention_cell = Variable(memory.data.new(B, self.attention_rnn_dim).zero_())

        self.decoder_hidden = Variable(memory.data.new(B, self.decoder_rnn_dim).zero_())
        self.decoder_cell = Variable(memory.data.new(B, self.decoder_rnn_dim).zero_())

        self.attention_weights = Variable(memory.data.new(B, MAX_TIME).zero_())
        self.attention_weights_cum = Variable(memory.data.new(B, MAX_TIME).zero_())
        self.attention_context = Variable(memory.data.new(B, self.encoder_embedding_dim).zero_())

        self.memory = memory
        self.processed_memory = self.attention_layer.memory_layer(memory)
        self.mask = mask

    def parse_decoder_inputs(self, decoder_inputs):
        """Prepares decoder inputs, i.e. mel outputs
        PARAMS
        ------
        decoder_inputs: inputs used for teacher-forced training, i.e. mel-specs

        RETURNS
        -------
        inputs: processed decoder inputs
        """
        # (B, n_mel_channels, T_out) -> (B, T_out, n_mel_channels)
        decoder_inputs = decoder_inputs.transpose(1, 2)
        decoder_inputs = decoder_inputs.view(
            decoder_inputs.size(0), int(decoder_inputs.size(1) / self.n_frames_per_step), -1
        )
        # (B, T_out, n_mel_channels) -> (T_out, B, n_mel_channels)
        decoder_inputs = decoder_inputs.transpose(0, 1)
        return decoder_inputs

    def parse_decoder_outputs(self, mel_outputs, gate_outputs, alignments):
        """Prepares decoder outputs for output
        PARAMS
        ------
        mel_outputs: list of per-step mel outputs
        gate_outputs: gate output energies
        alignments: list of per-step attention weights

        RETURNS
        -------
        mel_outputs: mel outputs reshaped to (B, n_mel_channels, T_out)
        gate_outputs: gate output energies
        alignments: attention weights reshaped to (B, T_out, max_time)
        """
        # (T_out, B) -> (B, T_out)
        alignments = torch.stack(alignments).transpose(0, 1)
        # (T_out, B) -> (B, T_out)
        gate_outputs = torch.stack(gate_outputs).transpose(0, 1)
        gate_outputs = gate_outputs.contiguous()
        # (T_out, B, n_mel_channels) -> (B, T_out, n_mel_channels)
        mel_outputs = torch.stack(mel_outputs).transpose(0, 1).contiguous()
        # decouple frames per step
        mel_outputs = mel_outputs.view(mel_outputs.size(0), -1, self.n_mel_channels)
        # (B, T_out, n_mel_channels) -> (B, n_mel_channels, T_out)
        mel_outputs = mel_outputs.transpose(1, 2)
        return mel_outputs, gate_outputs, alignments

    def decode(self, decoder_input):
        """Decoder step using stored states, attention and memory
        PARAMS
        ------
        decoder_input: previous mel output

        RETURNS
        -------
        mel_output: predicted mel frame(s) for this step
        gate_output: gate output energies
        attention_weights: attention weights for this step
        """
        cell_input = torch.cat((decoder_input, self.attention_context), -1)
        self.attention_hidden, self.attention_cell = self.attention_rnn(
            cell_input, (self.attention_hidden, self.attention_cell)
        )
        self.attention_hidden = F.dropout(self.attention_hidden, self.p_attention_dropout, self.training)

        attention_weights_cat = torch.cat(
            (self.attention_weights.unsqueeze(1), self.attention_weights_cum.unsqueeze(1)), dim=1
        )
        self.attention_context, self.attention_weights = self.attention_layer(
            self.attention_hidden, self.memory, self.processed_memory, attention_weights_cat, self.mask
        )

        self.attention_weights_cum += self.attention_weights
        decoder_input = torch.cat((self.attention_hidden, self.attention_context), -1)
        self.decoder_hidden, self.decoder_cell = self.decoder_rnn(
            decoder_input, (self.decoder_hidden, self.decoder_cell)
        )
        self.decoder_hidden = F.dropout(self.decoder_hidden, self.p_decoder_dropout, self.training)

        decoder_hidden_attention_context = torch.cat((self.decoder_hidden, self.attention_context), dim=1)
        decoder_output = self.linear_projection(decoder_hidden_attention_context)

        gate_prediction = self.gate_layer(decoder_hidden_attention_context)
        return decoder_output, gate_prediction, self.attention_weights

    def forward(self, memory, decoder_inputs, memory_lengths, device):
        """Decoder forward pass for training
        PARAMS
        ------
        memory: Encoder outputs
        decoder_inputs: Decoder inputs for teacher forcing, i.e. mel-specs
        memory_lengths: Encoder output lengths for attention masking.

        RETURNS
        -------
        mel_outputs: mel outputs from the decoder
        gate_outputs: gate outputs from the decoder
        alignments: sequence of attention weights from the decoder
        """
        decoder_input = self.get_go_frame(memory).unsqueeze(0)
        decoder_inputs = self.parse_decoder_inputs(decoder_inputs)
        decoder_inputs = torch.cat((decoder_input, decoder_inputs), dim=0)
        decoder_inputs = self.prenet(decoder_inputs)

        self.initialize_decoder_states(memory, mask=~get_mask_from_lengths(memory_lengths, device))

        mel_outputs, gate_outputs, alignments = [], [], []
        while len(mel_outputs) < decoder_inputs.size(0) - 1:
            decoder_input = decoder_inputs[len(mel_outputs)]
            mel_output, gate_output, attention_weights = self.decode(decoder_input)
            mel_outputs += [mel_output.squeeze(1)]
            gate_outputs += [gate_output.squeeze(1)]
            alignments += [attention_weights]

        mel_outputs, gate_outputs, alignments = self.parse_decoder_outputs(mel_outputs, gate_outputs, alignments)
        return mel_outputs, gate_outputs, alignments

    def inference(self, memory, max_decoder_steps=None):
        """Decoder inference
        PARAMS
        ------
        memory: Encoder outputs
        max_decoder_steps: optional override for the maximum number of decoder steps

        RETURNS
        -------
        mel_outputs: mel outputs from the decoder
        gate_outputs: gate outputs from the decoder
        alignments: sequence of attention weights from the decoder
        """
        if not max_decoder_steps:
            # Use default max decoder steps if not given
            max_decoder_steps = self.max_decoder_steps

        decoder_input = self.get_go_frame(memory)
        self.initialize_decoder_states(memory, mask=None)

        mel_outputs, gate_outputs, alignments = [], [], []
        while True:
            decoder_input = self.prenet(decoder_input)
            mel_output, gate_output, alignment = self.decode(decoder_input)

            mel_outputs += [mel_output.squeeze(1)]
            gate_outputs += [gate_output]
            alignments += [alignment]

            if torch.sigmoid(gate_output.data) > self.gate_threshold:
                break
            elif len(mel_outputs) == max_decoder_steps:
                raise Exception(
                    "Warning! Reached max decoder steps. Either the model is low quality or the given sentence is too short/long"
                )

            decoder_input = mel_output

        mel_outputs, gate_outputs, alignments = self.parse_decoder_outputs(mel_outputs, gate_outputs, alignments)
        return mel_outputs, gate_outputs, alignments
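
# Note (added comment): Tacotron2 below ties the pieces together: a symbol embedding feeds
# the convolutional + BiLSTM Encoder, the attention-based Decoder generates mel frames
# autoregressively, and the Postnet adds a residual refinement. parse_output masks padded
# frames when mask_padding is enabled.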

class Tacotron2(nn.Module):
    def __init__(
        self,
        mask_padding=True,
        fp16_run=False,
        n_mel_channels=80,
        n_symbols=148,
        symbols_embedding_dim=512,
        encoder_kernel_size=5,
        encoder_n_convolutions=3,
        encoder_embedding_dim=512,
        attention_rnn_dim=1024,
        attention_dim=128,
        attention_location_n_filters=32,
        attention_location_kernel_size=31,
        decoder_rnn_dim=1024,
        prenet_dim=256,
        max_decoder_steps=1000,
        gate_threshold=0.5,
        p_attention_dropout=0.1,
        p_decoder_dropout=0.1,
        postnet_embedding_dim=512,
        postnet_kernel_size=5,
        postnet_n_convolutions=5,
    ):
        super(Tacotron2, self).__init__()
        self.mask_padding = mask_padding
        self.fp16_run = fp16_run
        self.n_mel_channels = n_mel_channels
        self.n_frames_per_step = 1
        self.embedding = nn.Embedding(n_symbols, symbols_embedding_dim)
        std = sqrt(2.0 / (n_symbols + symbols_embedding_dim))
        val = sqrt(3.0) * std  # uniform bounds for std
        self.embedding.weight.data.uniform_(-val, val)
        self.encoder = Encoder(encoder_kernel_size, encoder_n_convolutions, encoder_embedding_dim)
        self.decoder = Decoder(
            n_mel_channels,
            self.n_frames_per_step,
            encoder_embedding_dim,
            attention_dim,
            attention_rnn_dim,
            attention_location_n_filters,
            attention_location_kernel_size,
            decoder_rnn_dim,
            prenet_dim,
            max_decoder_steps,
            gate_threshold,
            p_attention_dropout,
            p_decoder_dropout,
        )
        self.postnet = Postnet(n_mel_channels, postnet_embedding_dim, postnet_kernel_size, postnet_n_convolutions)

    def parse_batch(self, batch):
        text_padded, input_lengths, mel_padded, gate_padded, output_lengths = batch
        text_padded = to_gpu(text_padded).long()
        input_lengths = to_gpu(input_lengths).long()
        max_len = torch.max(input_lengths.data).item()
        mel_padded = to_gpu(mel_padded).float()
        gate_padded = to_gpu(gate_padded).float()
        output_lengths = to_gpu(output_lengths).long()
        return ((text_padded, input_lengths, mel_padded, max_len, output_lengths), (mel_padded, gate_padded))

    def parse_output(self, outputs, output_lengths, mask_size, alignment_mask_size, device):
        if self.mask_padding:
            mask = ~get_mask_from_lengths(output_lengths, device, mask_size)
            mask = mask.expand(self.n_mel_channels, mask.size(0), mask.size(1))
            mask = mask.permute(1, 0, 2)

            outputs[0].data.masked_fill_(mask, 0.0)
            outputs[1].data.masked_fill_(mask, 0.0)
            outputs[2].data.masked_fill_(mask[:, 0, :], 1e3)  # gate energies

        if outputs[3].size(2) != alignment_mask_size:
            outputs[3] = nn.ConstantPad1d((0, alignment_mask_size - outputs[3].size(2)), 0)(outputs[3])

        return outputs

    def forward(self, inputs, mask_size, alignment_mask_size):
        text_inputs, text_lengths, mels, output_lengths = get_x(inputs)
        device = text_inputs.device
        text_lengths, output_lengths = text_lengths.data, output_lengths.data

        embedded_inputs = self.embedding(text_inputs).transpose(1, 2)
        encoder_outputs = self.encoder(embedded_inputs, text_lengths)
        mel_outputs, gate_outputs, alignments = self.decoder(
            encoder_outputs, mels, memory_lengths=text_lengths, device=device
        )
        mel_outputs_postnet = self.postnet(mel_outputs)
        mel_outputs_postnet = mel_outputs + mel_outputs_postnet

        return self.parse_output(
            [mel_outputs, mel_outputs_postnet, gate_outputs, alignments],
            output_lengths,
            mask_size,
            alignment_mask_size,
            device,
        )

    def inference(self, inputs, max_decoder_steps=None):
        embedded_inputs = self.embedding(inputs).transpose(1, 2)
        encoder_outputs = self.encoder.inference(embedded_inputs)
        mel_outputs, gate_outputs, alignments = self.decoder.inference(encoder_outputs, max_decoder_steps)

        mel_outputs_postnet = self.postnet(mel_outputs)
        mel_outputs_postnet = mel_outputs + mel_outputs_postnet

        return [mel_outputs, mel_outputs_postnet, gate_outputs, alignments]
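

if __name__ == "__main__":
    # Minimal smoke-test sketch (added example, not part of the upstream code). It assumes
    # the sibling training.tacotron2_model.layers/utils modules are importable, builds the
    # model with its default hyperparameters on CPU, and runs inference on a dummy symbol
    # sequence. With random (untrained) weights the stop gate may fire at any step or never,
    # so this only exercises the pipeline and checks that the tensor shapes line up.
    model = Tacotron2()
    model.eval()
    dummy_text = torch.randint(0, 148, (1, 20), dtype=torch.long)  # one fake sentence of 20 symbols
    try:
        with torch.no_grad():
            mel, mel_postnet, gates, alignments = model.inference(dummy_text, max_decoder_steps=50)
        print(mel_postnet.shape)  # expected: torch.Size([1, 80, T]) with T <= 50
    except Exception as err:
        # An untrained gate may never cross gate_threshold, in which case Decoder.inference
        # raises after max_decoder_steps; either outcome runs the full encoder/decoder path.
        print(err)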