"""
BSD 3-Clause License
Copyright (c) 2018, NVIDIA Corporation
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
* Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""
from math import sqrt
import torch
from torch.autograd import Variable
from torch import nn
from torch.nn import functional as F
from training.tacotron2_model.layers import ConvNorm, LinearNorm
from training.tacotron2_model.utils import to_gpu, get_mask_from_lengths, get_x
class LocationLayer(nn.Module):
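    """Location layer:
    - convolves the stacked previous and cumulative attention weights
    - projects the result to the attention dimension
    """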
def __init__(self, attention_n_filters, attention_kernel_size, attention_dim):
super(LocationLayer, self).__init__()
padding = int((attention_kernel_size - 1) / 2)
self.location_conv = ConvNorm(
2, attention_n_filters, kernel_size=attention_kernel_size, padding=padding, bias=False, stride=1, dilation=1
)
self.location_dense = LinearNorm(attention_n_filters, attention_dim, bias=False, w_init_gain="tanh")
def forward(self, attention_weights_cat):
processed_attention = self.location_conv(attention_weights_cat)
processed_attention = processed_attention.transpose(1, 2)
processed_attention = self.location_dense(processed_attention)
return processed_attention
class Attention(nn.Module):
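    """Location-sensitive attention:
    - scores each encoder timestep from the attention RNN state, the processed
      encoder outputs and the location features
    - returns the attention context vector and the attention weights
    """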
def __init__(
self,
attention_rnn_dim,
embedding_dim,
attention_dim,
attention_location_n_filters,
attention_location_kernel_size,
):
super(Attention, self).__init__()
self.query_layer = LinearNorm(attention_rnn_dim, attention_dim, bias=False, w_init_gain="tanh")
self.memory_layer = LinearNorm(embedding_dim, attention_dim, bias=False, w_init_gain="tanh")
self.v = LinearNorm(attention_dim, 1, bias=False)
self.location_layer = LocationLayer(attention_location_n_filters, attention_location_kernel_size, attention_dim)
self.score_mask_value = -float("inf")
def get_alignment_energies(self, query, processed_memory, attention_weights_cat):
"""
PARAMS
------
        query: attention RNN output (batch, attention_rnn_dim)
processed_memory: processed encoder outputs (B, T_in, attention_dim)
attention_weights_cat: cumulative and prev. att weights (B, 2, max_time)
RETURNS
-------
alignment (batch, max_time)
"""
processed_query = self.query_layer(query.unsqueeze(1))
processed_attention_weights = self.location_layer(attention_weights_cat)
energies = self.v(torch.tanh(processed_query + processed_attention_weights + processed_memory))
energies = energies.squeeze(-1)
return energies
def forward(self, attention_hidden_state, memory, processed_memory, attention_weights_cat, mask):
"""
PARAMS
------
attention_hidden_state: attention rnn last output
memory: encoder outputs
processed_memory: processed encoder outputs
        attention_weights_cat: previous and cumulative attention weights
mask: binary mask for padded data
"""
alignment = self.get_alignment_energies(attention_hidden_state, processed_memory, attention_weights_cat)
if mask is not None:
alignment.data.masked_fill_(mask, self.score_mask_value)
attention_weights = F.softmax(alignment, dim=1)
attention_context = torch.bmm(attention_weights.unsqueeze(1), memory)
attention_context = attention_context.squeeze(1)
return attention_context, attention_weights
class Prenet(nn.Module):
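    """Prenet:
    - stack of fully-connected ReLU layers applied to the previous mel frame
      before it enters the attention RNN
    """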
def __init__(self, in_dim, sizes):
super(Prenet, self).__init__()
in_sizes = [in_dim] + sizes[:-1]
self.layers = nn.ModuleList(
[LinearNorm(in_size, out_size, bias=False) for (in_size, out_size) in zip(in_sizes, sizes)]
)
def forward(self, x):
for linear in self.layers:
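            # dropout is kept active even at inference (training=True), as in the Tacotron 2 paper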
x = F.dropout(F.relu(linear(x)), p=0.5, training=True)
return x
class Postnet(nn.Module):
"""Postnet
    - Five 1-d convolutions with 512 channels and kernel size 5
"""
def __init__(self, n_mel_channels, postnet_embedding_dim, postnet_kernel_size, postnet_n_convolutions):
super(Postnet, self).__init__()
self.convolutions = nn.ModuleList()
self.convolutions.append(
nn.Sequential(
ConvNorm(
n_mel_channels,
postnet_embedding_dim,
kernel_size=postnet_kernel_size,
stride=1,
padding=int((postnet_kernel_size - 1) / 2),
dilation=1,
w_init_gain="tanh",
),
nn.BatchNorm1d(postnet_embedding_dim),
)
)
for i in range(1, postnet_n_convolutions - 1):
self.convolutions.append(
nn.Sequential(
ConvNorm(
postnet_embedding_dim,
postnet_embedding_dim,
kernel_size=postnet_kernel_size,
stride=1,
padding=int((postnet_kernel_size - 1) / 2),
dilation=1,
w_init_gain="tanh",
),
nn.BatchNorm1d(postnet_embedding_dim),
)
)
self.convolutions.append(
nn.Sequential(
ConvNorm(
postnet_embedding_dim,
n_mel_channels,
kernel_size=postnet_kernel_size,
stride=1,
padding=int((postnet_kernel_size - 1) / 2),
dilation=1,
w_init_gain="linear",
),
nn.BatchNorm1d(n_mel_channels),
)
)
def forward(self, x):
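        # tanh after every convolution block except the last, which keeps a linear activation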
for i in range(len(self.convolutions) - 1):
x = F.dropout(torch.tanh(self.convolutions[i](x)), 0.5, self.training)
x = F.dropout(self.convolutions[-1](x), 0.5, self.training)
return x
class Encoder(nn.Module):
"""Encoder module:
    - Three 1-d convolution layers
- Bidirectional LSTM
"""
def __init__(self, encoder_kernel_size, encoder_n_convolutions, encoder_embedding_dim):
super(Encoder, self).__init__()
convolutions = []
for _ in range(encoder_n_convolutions):
conv_layer = nn.Sequential(
ConvNorm(
encoder_embedding_dim,
encoder_embedding_dim,
kernel_size=encoder_kernel_size,
stride=1,
padding=int((encoder_kernel_size - 1) / 2),
dilation=1,
w_init_gain="relu",
),
nn.BatchNorm1d(encoder_embedding_dim),
)
convolutions.append(conv_layer)
self.convolutions = nn.ModuleList(convolutions)
self.lstm = nn.LSTM(
encoder_embedding_dim, int(encoder_embedding_dim / 2), 1, batch_first=True, bidirectional=True
)
def forward(self, x, input_lengths):
for conv in self.convolutions:
x = F.dropout(F.relu(conv(x)), 0.5, self.training)
x = x.transpose(1, 2)
        # PyTorch tensors are not reversible, hence the conversion to numpy
input_lengths = input_lengths.cpu().numpy()
x = nn.utils.rnn.pack_padded_sequence(x, input_lengths, batch_first=True)
self.lstm.flatten_parameters()
outputs, _ = self.lstm(x)
outputs, _ = nn.utils.rnn.pad_packed_sequence(outputs, batch_first=True)
return outputs
def inference(self, x):
for conv in self.convolutions:
x = F.dropout(F.relu(conv(x)), 0.5, self.training)
x = x.transpose(1, 2)
self.lstm.flatten_parameters()
outputs, _ = self.lstm(x)
return outputs
class Decoder(nn.Module):
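    """Autoregressive decoder:
    - prenet, attention LSTM and location-sensitive attention
    - decoder LSTM, linear projection to mel frames and a gate (stop token) layer
    """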
def __init__(
self,
n_mel_channels,
n_frames_per_step,
encoder_embedding_dim,
attention_dim,
attention_rnn_dim,
attention_location_n_filters,
attention_location_kernel_size,
decoder_rnn_dim,
prenet_dim,
max_decoder_steps,
gate_threshold,
p_attention_dropout,
p_decoder_dropout,
):
super(Decoder, self).__init__()
self.n_mel_channels = n_mel_channels
self.n_frames_per_step = n_frames_per_step
self.encoder_embedding_dim = encoder_embedding_dim
self.attention_rnn_dim = attention_rnn_dim
self.decoder_rnn_dim = decoder_rnn_dim
self.prenet_dim = prenet_dim
self.max_decoder_steps = max_decoder_steps
self.gate_threshold = gate_threshold
self.p_attention_dropout = p_attention_dropout
self.p_decoder_dropout = p_decoder_dropout
self.prenet = Prenet(n_mel_channels * n_frames_per_step, [prenet_dim, prenet_dim])
self.attention_rnn = nn.LSTMCell(prenet_dim + encoder_embedding_dim, attention_rnn_dim)
self.attention_layer = Attention(
attention_rnn_dim,
encoder_embedding_dim,
attention_dim,
attention_location_n_filters,
attention_location_kernel_size,
)
self.decoder_rnn = nn.LSTMCell(attention_rnn_dim + encoder_embedding_dim, decoder_rnn_dim, 1)
self.linear_projection = LinearNorm(decoder_rnn_dim + encoder_embedding_dim, n_mel_channels * n_frames_per_step)
self.gate_layer = LinearNorm(decoder_rnn_dim + encoder_embedding_dim, 1, bias=True, w_init_gain="sigmoid")
def get_go_frame(self, memory):
"""Gets all zeros frames to use as first decoder input
PARAMS
------
memory: decoder outputs
RETURNS
-------
decoder_input: all zeros frames
"""
B = memory.size(0)
decoder_input = Variable(memory.data.new(B, self.n_mel_channels * self.n_frames_per_step).zero_())
return decoder_input
def initialize_decoder_states(self, memory, mask):
"""Initializes attention rnn states, decoder rnn states, attention
weights, attention cumulative weights, attention context, stores memory
and stores processed memory
PARAMS
------
memory: Encoder outputs
mask: Mask for padded data if training, expects None for inference
"""
B = memory.size(0)
MAX_TIME = memory.size(1)
self.attention_hidden = Variable(memory.data.new(B, self.attention_rnn_dim).zero_())
self.attention_cell = Variable(memory.data.new(B, self.attention_rnn_dim).zero_())
self.decoder_hidden = Variable(memory.data.new(B, self.decoder_rnn_dim).zero_())
self.decoder_cell = Variable(memory.data.new(B, self.decoder_rnn_dim).zero_())
self.attention_weights = Variable(memory.data.new(B, MAX_TIME).zero_())
self.attention_weights_cum = Variable(memory.data.new(B, MAX_TIME).zero_())
self.attention_context = Variable(memory.data.new(B, self.encoder_embedding_dim).zero_())
self.memory = memory
self.processed_memory = self.attention_layer.memory_layer(memory)
self.mask = mask
def parse_decoder_inputs(self, decoder_inputs):
"""Prepares decoder inputs, i.e. mel outputs
PARAMS
------
        decoder_inputs: inputs used for teacher-forced training, i.e. mel-specs
RETURNS
-------
inputs: processed decoder inputs
"""
# (B, n_mel_channels, T_out) -> (B, T_out, n_mel_channels)
decoder_inputs = decoder_inputs.transpose(1, 2)
decoder_inputs = decoder_inputs.view(
decoder_inputs.size(0), int(decoder_inputs.size(1) / self.n_frames_per_step), -1
)
# (B, T_out, n_mel_channels) -> (T_out, B, n_mel_channels)
decoder_inputs = decoder_inputs.transpose(0, 1)
return decoder_inputs
def parse_decoder_outputs(self, mel_outputs, gate_outputs, alignments):
"""Prepares decoder outputs for output
PARAMS
------
mel_outputs:
gate_outputs: gate output energies
alignments:
RETURNS
-------
mel_outputs:
gate_outpust: gate output energies
alignments:
"""
# (T_out, B) -> (B, T_out)
alignments = torch.stack(alignments).transpose(0, 1)
# (T_out, B) -> (B, T_out)
gate_outputs = torch.stack(gate_outputs).transpose(0, 1)
gate_outputs = gate_outputs.contiguous()
# (T_out, B, n_mel_channels) -> (B, T_out, n_mel_channels)
mel_outputs = torch.stack(mel_outputs).transpose(0, 1).contiguous()
# decouple frames per step
mel_outputs = mel_outputs.view(mel_outputs.size(0), -1, self.n_mel_channels)
# (B, T_out, n_mel_channels) -> (B, n_mel_channels, T_out)
mel_outputs = mel_outputs.transpose(1, 2)
return mel_outputs, gate_outputs, alignments
def decode(self, decoder_input):
"""Decoder step using stored states, attention and memory
PARAMS
------
decoder_input: previous mel output
RETURNS
-------
mel_output:
gate_output: gate output energies
attention_weights:
"""
cell_input = torch.cat((decoder_input, self.attention_context), -1)
self.attention_hidden, self.attention_cell = self.attention_rnn(
cell_input, (self.attention_hidden, self.attention_cell)
)
self.attention_hidden = F.dropout(self.attention_hidden, self.p_attention_dropout, self.training)
attention_weights_cat = torch.cat(
(self.attention_weights.unsqueeze(1), self.attention_weights_cum.unsqueeze(1)), dim=1
)
self.attention_context, self.attention_weights = self.attention_layer(
self.attention_hidden, self.memory, self.processed_memory, attention_weights_cat, self.mask
)
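        # accumulate attention weights so the location features reflect coverage so far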
self.attention_weights_cum += self.attention_weights
decoder_input = torch.cat((self.attention_hidden, self.attention_context), -1)
self.decoder_hidden, self.decoder_cell = self.decoder_rnn(
decoder_input, (self.decoder_hidden, self.decoder_cell)
)
self.decoder_hidden = F.dropout(self.decoder_hidden, self.p_decoder_dropout, self.training)
decoder_hidden_attention_context = torch.cat((self.decoder_hidden, self.attention_context), dim=1)
decoder_output = self.linear_projection(decoder_hidden_attention_context)
gate_prediction = self.gate_layer(decoder_hidden_attention_context)
return decoder_output, gate_prediction, self.attention_weights
def forward(self, memory, decoder_inputs, memory_lengths, device):
"""Decoder forward pass for training
PARAMS
------
memory: Encoder outputs
decoder_inputs: Decoder inputs for teacher forcing. i.e. mel-specs
memory_lengths: Encoder output lengths for attention masking.
RETURNS
-------
mel_outputs: mel outputs from the decoder
gate_outputs: gate outputs from the decoder
alignments: sequence of attention weights from the decoder
"""
decoder_input = self.get_go_frame(memory).unsqueeze(0)
decoder_inputs = self.parse_decoder_inputs(decoder_inputs)
decoder_inputs = torch.cat((decoder_input, decoder_inputs), dim=0)
decoder_inputs = self.prenet(decoder_inputs)
self.initialize_decoder_states(memory, mask=~get_mask_from_lengths(memory_lengths, device))
mel_outputs, gate_outputs, alignments = [], [], []
while len(mel_outputs) < decoder_inputs.size(0) - 1:
decoder_input = decoder_inputs[len(mel_outputs)]
mel_output, gate_output, attention_weights = self.decode(decoder_input)
mel_outputs += [mel_output.squeeze(1)]
gate_outputs += [gate_output.squeeze(1)]
alignments += [attention_weights]
mel_outputs, gate_outputs, alignments = self.parse_decoder_outputs(mel_outputs, gate_outputs, alignments)
return mel_outputs, gate_outputs, alignments
def inference(self, memory, max_decoder_steps=None):
"""Decoder inference
PARAMS
------
memory: Encoder outputs
RETURNS
-------
mel_outputs: mel outputs from the decoder
gate_outputs: gate outputs from the decoder
alignments: sequence of attention weights from the decoder
"""
if not max_decoder_steps:
# Use default max decoder steps if not given
max_decoder_steps = self.max_decoder_steps
decoder_input = self.get_go_frame(memory)
self.initialize_decoder_states(memory, mask=None)
mel_outputs, gate_outputs, alignments = [], [], []
while True:
decoder_input = self.prenet(decoder_input)
mel_output, gate_output, alignment = self.decode(decoder_input)
mel_outputs += [mel_output.squeeze(1)]
gate_outputs += [gate_output]
alignments += [alignment]
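            # stop once the gate (stop token) prediction exceeds the threshold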
if torch.sigmoid(gate_output.data) > self.gate_threshold:
break
elif len(mel_outputs) == max_decoder_steps:
                raise Exception(
                    "Reached max decoder steps. Either the model is not sufficiently trained "
                    "or the input sentence is too short or too long"
                )
decoder_input = mel_output
mel_outputs, gate_outputs, alignments = self.parse_decoder_outputs(mel_outputs, gate_outputs, alignments)
return mel_outputs, gate_outputs, alignments
class Tacotron2(nn.Module):
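    """Tacotron 2:
    - symbol embedding and convolutional + BiLSTM encoder
    - attention-based autoregressive decoder
    - convolutional postnet predicting a residual that refines the mel spectrogram
    """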
def __init__(
self,
mask_padding=True,
fp16_run=False,
n_mel_channels=80,
n_symbols=148,
symbols_embedding_dim=512,
encoder_kernel_size=5,
encoder_n_convolutions=3,
encoder_embedding_dim=512,
attention_rnn_dim=1024,
attention_dim=128,
attention_location_n_filters=32,
attention_location_kernel_size=31,
decoder_rnn_dim=1024,
prenet_dim=256,
max_decoder_steps=1000,
gate_threshold=0.5,
p_attention_dropout=0.1,
p_decoder_dropout=0.1,
postnet_embedding_dim=512,
postnet_kernel_size=5,
postnet_n_convolutions=5,
):
super(Tacotron2, self).__init__()
self.mask_padding = mask_padding
self.fp16_run = fp16_run
self.n_mel_channels = n_mel_channels
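        # this implementation generates a single mel frame per decoder step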
self.n_frames_per_step = 1
self.embedding = nn.Embedding(n_symbols, symbols_embedding_dim)
std = sqrt(2.0 / (n_symbols + symbols_embedding_dim))
val = sqrt(3.0) * std # uniform bounds for std
self.embedding.weight.data.uniform_(-val, val)
self.encoder = Encoder(encoder_kernel_size, encoder_n_convolutions, encoder_embedding_dim)
self.decoder = Decoder(
n_mel_channels,
self.n_frames_per_step,
encoder_embedding_dim,
attention_dim,
attention_rnn_dim,
attention_location_n_filters,
attention_location_kernel_size,
decoder_rnn_dim,
prenet_dim,
max_decoder_steps,
gate_threshold,
p_attention_dropout,
p_decoder_dropout,
)
self.postnet = Postnet(n_mel_channels, postnet_embedding_dim, postnet_kernel_size, postnet_n_convolutions)
def parse_batch(self, batch):
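        """Unpacks a collated training batch, moves its tensors to the GPU and splits it into model inputs and targets."""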
text_padded, input_lengths, mel_padded, gate_padded, output_lengths = batch
text_padded = to_gpu(text_padded).long()
input_lengths = to_gpu(input_lengths).long()
max_len = torch.max(input_lengths.data).item()
mel_padded = to_gpu(mel_padded).float()
gate_padded = to_gpu(gate_padded).float()
output_lengths = to_gpu(output_lengths).long()
return ((text_padded, input_lengths, mel_padded, max_len, output_lengths), (mel_padded, gate_padded))
def parse_output(self, outputs, output_lengths, mask_size, alignment_mask_size, device):
if self.mask_padding:
mask = ~get_mask_from_lengths(output_lengths, device, mask_size)
mask = mask.expand(self.n_mel_channels, mask.size(0), mask.size(1))
mask = mask.permute(1, 0, 2)
outputs[0].data.masked_fill_(mask, 0.0)
outputs[1].data.masked_fill_(mask, 0.0)
outputs[2].data.masked_fill_(mask[:, 0, :], 1e3) # gate energies
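        # pad the alignments along the encoder time axis up to alignment_mask_size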
if outputs[3].size(2) != alignment_mask_size:
outputs[3] = nn.ConstantPad1d((0, alignment_mask_size - outputs[3].size(2)), 0)(outputs[3])
return outputs
def forward(self, inputs, mask_size, alignment_mask_size):
text_inputs, text_lengths, mels, output_lengths = get_x(inputs)
device = text_inputs.device
text_lengths, output_lengths = text_lengths.data, output_lengths.data
embedded_inputs = self.embedding(text_inputs).transpose(1, 2)
encoder_outputs = self.encoder(embedded_inputs, text_lengths)
mel_outputs, gate_outputs, alignments = self.decoder(
encoder_outputs, mels, memory_lengths=text_lengths, device=device
)
mel_outputs_postnet = self.postnet(mel_outputs)
mel_outputs_postnet = mel_outputs + mel_outputs_postnet
return self.parse_output(
[mel_outputs, mel_outputs_postnet, gate_outputs, alignments],
output_lengths,
mask_size,
alignment_mask_size,
device,
)
def inference(self, inputs, max_decoder_steps=None):
embedded_inputs = self.embedding(inputs).transpose(1, 2)
encoder_outputs = self.encoder.inference(embedded_inputs)
mel_outputs, gate_outputs, alignments = self.decoder.inference(encoder_outputs, max_decoder_steps)
mel_outputs_postnet = self.postnet(mel_outputs)
mel_outputs_postnet = mel_outputs + mel_outputs_postnet
return [mel_outputs, mel_outputs_postnet, gate_outputs, alignments]
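
# Minimal usage sketch (illustrative only; assumes a trained checkpoint that stores
# the weights under a "state_dict" key and text already converted to symbol IDs,
# with the path and IDs below as placeholders):
#
#     model = Tacotron2()
#     model.load_state_dict(torch.load("checkpoint.pt", map_location="cpu")["state_dict"])
#     model.eval()
#     sequence = torch.LongTensor([[12, 25, 38, 11, 7]])  # encoded text, shape (1, T)
#     with torch.no_grad():
#         mel, mel_postnet, gate, alignments = model.inference(sequence)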