Tzktz's picture
Upload 7664 files
6fc683c verified
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import math
from collections.abc import Iterable
import torch
import torch.nn as nn
from fairseq import utils
from fairseq.models import (
FairseqEncoder,
FairseqEncoderModel,
FairseqIncrementalDecoder,
FairseqEncoderDecoderModel,
register_model,
register_model_architecture,
)
from fairseq.modules import LinearizedConvolution
from examples.speech_recognition.data.data_utils import lengths_to_encoder_padding_mask
from fairseq.modules import TransformerDecoderLayer, TransformerEncoderLayer, VGGBlock
@register_model("asr_vggtransformer")
class VGGTransformerModel(FairseqEncoderDecoderModel):
    """
    Transformers with convolutional context for ASR
    https://arxiv.org/abs/1904.11660

    Sequence-to-sequence model pairing a :class:`VGGTransformerEncoder`
    with a :class:`TransformerDecoder`. The layer configurations are passed
    on the command line as Python expressions (see :meth:`add_args`).
    """

    def __init__(self, encoder, decoder):
        super().__init__(encoder, decoder)

    @staticmethod
    def add_args(parser):
        """Add model-specific arguments to the parser."""
        parser.add_argument(
            "--input-feat-per-channel",
            type=int,
            metavar="N",
            help="encoder input dimension per input channel",
        )
        parser.add_argument(
            "--vggblock-enc-config",
            type=str,
            metavar="EXPR",
            help="""
    an array of tuples each containing the configuration of one vggblock:
    [(out_channels,
      conv_kernel_size,
      pooling_kernel_size,
      num_conv_layers,
      use_layer_norm), ...])
            """,
        )
        parser.add_argument(
            "--transformer-enc-config",
            type=str,
            metavar="EXPR",
            help="""
    a tuple containing the configuration of the encoder transformer layers
    configurations:
    [(input_dim,
      num_heads,
      ffn_dim,
      normalize_before,
      dropout,
      attention_dropout,
      relu_dropout), ...]
            """,
        )
        parser.add_argument(
            "--enc-output-dim",
            type=int,
            metavar="N",
            help="""
    encoder output dimension, can be None. If specified, projecting the
    transformer output to the specified dimension""",
        )
        parser.add_argument(
            "--in-channels",
            type=int,
            metavar="N",
            help="number of encoder input channels",
        )
        parser.add_argument(
            "--tgt-embed-dim",
            type=int,
            metavar="N",
            help="embedding dimension of the decoder target tokens",
        )
        parser.add_argument(
            "--transformer-dec-config",
            type=str,
            metavar="EXPR",
            help="""
    a tuple containing the configuration of the decoder transformer layers
    configurations:
    [(input_dim,
      num_heads,
      ffn_dim,
      normalize_before,
      dropout,
      attention_dropout,
      relu_dropout), ...]
            """,
        )
        parser.add_argument(
            "--conv-dec-config",
            type=str,
            metavar="EXPR",
            help="""
    an array of tuples for the decoder 1-D convolution config
        [(out_channels, conv_kernel_size, use_layer_norm), ...]""",
        )

    @staticmethod
    def _parse_config(value):
        """Turn a configuration value into a Python object.

        Strings are evaluated as Python expressions; anything else (e.g. a
        tuple supplied as a programmatic default) is returned unchanged,
        because ``eval`` only accepts strings and would raise ``TypeError``.
        """
        if isinstance(value, str):
            # NOTE(security): eval() runs arbitrary code, so these options
            # must only come from trusted command lines / checkpoints.
            # ast.literal_eval is not a drop-in replacement because the
            # configs use expressions such as "((256, 3, True),) * 4".
            return eval(value)
        return value

    @classmethod
    def build_encoder(cls, args, task):
        """Construct the VGG + Transformer encoder described by ``args``."""
        return VGGTransformerEncoder(
            input_feat_per_channel=args.input_feat_per_channel,
            vggblock_config=cls._parse_config(args.vggblock_enc_config),
            transformer_config=cls._parse_config(args.transformer_enc_config),
            encoder_output_dim=args.enc_output_dim,
            in_channels=args.in_channels,
        )

    @classmethod
    def build_decoder(cls, args, task):
        """Construct the decoder over ``task.target_dictionary``."""
        return TransformerDecoder(
            dictionary=task.target_dictionary,
            embed_dim=args.tgt_embed_dim,
            transformer_config=cls._parse_config(args.transformer_dec_config),
            conv_config=cls._parse_config(args.conv_dec_config),
            encoder_output_dim=args.enc_output_dim,
        )

    @classmethod
    def build_model(cls, args, task):
        """Build a new model instance."""
        # make sure that all args are properly defaulted
        # (in case there are any new ones)
        base_architecture(args)

        encoder = cls.build_encoder(args, task)
        decoder = cls.build_decoder(args, task)
        return cls(encoder, decoder)

    def get_normalized_probs(self, net_output, log_probs, sample=None):
        """Return (log-)probabilities from the parent class, tagged batch-first."""
        # The decoder in this file emits (B, T, vocab); the tensor is tagged
        # with ``batch_first`` so consumers can tell the layout
        # (presumably read by the criterion -- confirm against callers).
        lprobs = super().get_normalized_probs(net_output, log_probs, sample)
        lprobs.batch_first = True
        return lprobs
# One entry per VGG block, in the order the encoder unpacks them:
# (out_channels, conv_kernel_size, pooling_kernel_size, num_conv_layers,
#  use_layer_norm)
DEFAULT_ENC_VGGBLOCK_CONFIG = (
    (32, 3, 2, 2, False),
    (32, 3, 2, 2, False),
)

# One entry per transformer layer. Reading the first entry left to right:
#   256  -> embedding dimension
#   4    -> number of heads
#   1024 -> FFN dimension
#   True -> apply layerNorm before (dropout + residual) instead of after
#   0.2  -> dropout: dropout after MultiheadAttention and second FC
#   0.2  -> attention_dropout: dropout in MultiheadAttention
#   0.2  -> relu_dropout: dropout after ReLU
DEFAULT_ENC_TRANSFORMER_CONFIG = (
    (256, 4, 1024, True, 0.2, 0.2, 0.2),
    (256, 4, 1024, True, 0.2, 0.2, 0.2),
)

# Same field layout as the encoder transformer config, with 2 heads.
DEFAULT_DEC_TRANSFORMER_CONFIG = (
    (256, 2, 1024, True, 0.2, 0.2, 0.2),
    (256, 2, 1024, True, 0.2, 0.2, 0.2),
)

# One entry per decoder convolution:
# (out_channels, conv_kernel_size, use_layer_norm)
DEFAULT_DEC_CONV_CONFIG = (
    (256, 3, True),
    (256, 3, True),
)
# TODO: replace the one-liner transformer encoder config with explicit args
# to get rid of this transformation
def prepare_transformer_encoder_params(
    input_dim,
    num_heads,
    ffn_dim,
    normalize_before,
    dropout,
    attention_dropout,
    relu_dropout,
):
    """Pack one encoder-layer config tuple into an ``argparse.Namespace``.

    The positional values are stored under the attribute names
    ``encoder_embed_dim``, ``encoder_attention_heads``,
    ``attention_dropout``, ``dropout``, ``activation_dropout``,
    ``encoder_normalize_before`` and ``encoder_ffn_embed_dim``; the result
    is what this file hands to ``TransformerEncoderLayer``.
    """
    return argparse.Namespace(
        encoder_embed_dim=input_dim,
        encoder_attention_heads=num_heads,
        attention_dropout=attention_dropout,
        dropout=dropout,
        activation_dropout=relu_dropout,
        encoder_normalize_before=normalize_before,
        encoder_ffn_embed_dim=ffn_dim,
    )
def prepare_transformer_decoder_params(
    input_dim,
    num_heads,
    ffn_dim,
    normalize_before,
    dropout,
    attention_dropout,
    relu_dropout,
):
    """Pack one decoder-layer config tuple into an ``argparse.Namespace``.

    The positional values are stored under the attribute names
    ``decoder_embed_dim``, ``decoder_attention_heads``,
    ``attention_dropout``, ``dropout``, ``activation_dropout``,
    ``decoder_normalize_before`` and ``decoder_ffn_embed_dim``; the result
    is what this file hands to ``TransformerDecoderLayer``.
    """
    return argparse.Namespace(
        decoder_embed_dim=input_dim,
        decoder_attention_heads=num_heads,
        attention_dropout=attention_dropout,
        dropout=dropout,
        activation_dropout=relu_dropout,
        decoder_normalize_before=normalize_before,
        decoder_ffn_embed_dim=ffn_dim,
    )
class VGGTransformerEncoder(FairseqEncoder):
    """VGG + Transformer encoder"""

    def __init__(
        self,
        input_feat_per_channel,
        vggblock_config=DEFAULT_ENC_VGGBLOCK_CONFIG,
        transformer_config=DEFAULT_ENC_TRANSFORMER_CONFIG,
        encoder_output_dim=512,
        in_channels=1,
        transformer_context=None,
        transformer_sampling=None,
    ):
        """constructor for VGGTransformerEncoder

        Args:
            - input_feat_per_channel: feature dim (not including stacked,
              just base feature)
            - in_channels: # input channels (e.g., if stack 8 feature vector
              together, this is 8)
            - vggblock_config: configuration of vggblock, see comments on
              DEFAULT_ENC_VGGBLOCK_CONFIG
            - transformer_config: configuration of transformer layer, see
              comments on DEFAULT_ENC_TRANSFORMER_CONFIG
            - encoder_output_dim: final transformer output embedding dimension
            - transformer_context: (left, right) if set, self-attention will
              be focused on (t-left, t+right)
            - transformer_sampling: an iterable of int, must match with
              len(transformer_config), transformer_sampling[i] indicates
              sampling factor for i-th transformer layer, after multihead att
              and feedforward part

        Raises:
            ValueError: if ``vggblock_config`` is not iterable, or if the
                transformer config / context / sampling fail validation.
        """
        super().__init__(None)

        self.num_vggblocks = 0
        if vggblock_config is not None:
            if not isinstance(vggblock_config, Iterable):
                raise ValueError("vggblock_config is not iterable")
            self.num_vggblocks = len(vggblock_config)

        self.conv_layers = nn.ModuleList()
        # Keep the *original* input geometry: forward() and
        # infer_conv_output_dim() need it, while the local variables below
        # are overwritten as blocks are stacked.
        self.in_channels = in_channels
        self.input_dim = input_feat_per_channel

        if vggblock_config is not None:
            for (
                out_channels,
                conv_kernel_size,
                pooling_kernel_size,
                num_conv_layers,
                layer_norm,
            ) in vggblock_config:
                self.conv_layers.append(
                    VGGBlock(
                        in_channels,
                        out_channels,
                        conv_kernel_size,
                        pooling_kernel_size,
                        num_conv_layers,
                        input_dim=input_feat_per_channel,
                        layer_norm=layer_norm,
                    )
                )
                # the next block consumes this block's output
                in_channels = out_channels
                input_feat_per_channel = self.conv_layers[-1].output_dim

        # transformer_input_dim is the output dimension of VGG part
        transformer_input_dim = self.infer_conv_output_dim(
            self.in_channels, self.input_dim
        )

        self.validate_transformer_config(transformer_config)
        self.transformer_context = self.parse_transformer_context(transformer_context)
        self.transformer_sampling = self.parse_transformer_sampling(
            transformer_sampling, len(transformer_config)
        )

        # Stack the transformer layers; whenever the width changes between
        # consecutive stages a Linear adapter is inserted first.
        self.transformer_layers = nn.ModuleList()
        prev_dim = transformer_input_dim
        for config in transformer_config:
            layer_dim = config[0]
            if prev_dim != layer_dim:
                self.transformer_layers.append(Linear(prev_dim, layer_dim))
            self.transformer_layers.append(
                TransformerEncoderLayer(prepare_transformer_encoder_params(*config))
            )
            prev_dim = layer_dim

        self.encoder_output_dim = encoder_output_dim
        self.transformer_layers.extend(
            [
                Linear(prev_dim, encoder_output_dim),
                LayerNorm(encoder_output_dim),
            ]
        )

    def forward(self, src_tokens, src_lengths, **kwargs):
        """
        src_tokens: padded tensor (B, T, C * feat)
        src_lengths: tensor of original lengths of input utterances (B,)

        Returns a dict with "encoder_out" (T, B, C) and
        "encoder_padding_mask" ((T, B), or None when nothing is padded).
        """
        bsz, max_seq_len, _ = src_tokens.size()
        x = src_tokens.view(bsz, max_seq_len, self.in_channels, self.input_dim)
        x = x.transpose(1, 2).contiguous()
        # (B, C, T, feat)

        for conv_layer in self.conv_layers:
            x = conv_layer(x)

        bsz, _, output_seq_len, _ = x.size()

        # (B, C, T, feat) -> (B, T, C, feat) -> (T, B, C, feat) -> (T, B, C * feat)
        x = x.transpose(1, 2).transpose(0, 1)
        x = x.contiguous().view(output_seq_len, bsz, -1)

        # Rounded ratio between input and output time steps.
        subsampling_factor = int(max_seq_len * 1.0 / output_seq_len + 0.5)
        # TODO: shouldn't subsampling_factor determined in advance ?
        input_lengths = (src_lengths.float() / subsampling_factor).ceil().long()

        encoder_padding_mask, _ = lengths_to_encoder_padding_mask(
            input_lengths, batch_first=True
        )
        if not encoder_padding_mask.any():
            encoder_padding_mask = None

        attn_mask = self.lengths_to_attn_mask(input_lengths, subsampling_factor)

        # transformer_sampling is indexed by transformer layer only, so the
        # interleaved Linear/LayerNorm modules must not advance this counter.
        transformer_layer_idx = 0
        for layer in self.transformer_layers:
            if isinstance(layer, TransformerEncoderLayer):
                x = layer(x, encoder_padding_mask, attn_mask)

                sampling_factor = self.transformer_sampling[transformer_layer_idx]
                if sampling_factor != 1:
                    x, encoder_padding_mask, attn_mask = self.slice(
                        x, encoder_padding_mask, attn_mask, sampling_factor
                    )

                transformer_layer_idx += 1
            else:
                x = layer(x)

        # encoder_padding_mask is a (T x B) tensor, its [t, b] elements indicate
        # whether encoder_output[t, b] is valid or not (valid=0, invalid=1)
        return {
            "encoder_out": x,  # (T, B, C)
            "encoder_padding_mask": encoder_padding_mask.t()
            if encoder_padding_mask is not None
            else None,
            # (B, T) --> (T, B)
        }

    def infer_conv_output_dim(self, in_channels, input_dim):
        """Return the flattened per-frame feature size after the conv stack.

        A random dummy batch is pushed through ``self.conv_layers`` and the
        last dimension of the (batch, time, -1) view is returned.
        """
        sample_seq_len = 200
        sample_bsz = 10
        x = torch.randn(sample_bsz, in_channels, sample_seq_len, input_dim)
        # Shape probing only: no autograd graph is needed.
        with torch.no_grad():
            for conv_layer in self.conv_layers:
                x = conv_layer(x)
        x = x.transpose(1, 2)
        mb, seq = x.size()[:2]
        return x.contiguous().view(mb, seq, -1).size(-1)

    def validate_transformer_config(self, transformer_config):
        """Check that every layer's input_dim is divisible by its num_heads.

        Raises:
            ValueError: naming the offending config, dimension and head count.
        """
        for config in transformer_config:
            input_dim, num_heads = config[:2]
            if input_dim % num_heads != 0:
                msg = (
                    "ERROR in transformer config {}: "
                    "input dimension {} "
                    "not dividable by number of heads {}"
                ).format(config, input_dim, num_heads)
                raise ValueError(msg)

    def parse_transformer_context(self, transformer_context):
        """
        transformer_context can be the following:
        -   None; indicates no context is used, i.e.,
            transformer can access full context
        -   a tuple/list of two int; indicates left and right context,
            any number <0 indicates infinite context
                * e.g., (5, 6) indicates that for query at x_t, transformer can
                access [t-5, t+6] (inclusive)
                * e.g., (-1, 6) indicates that for query at x_t, transformer can
                access [0, t+6] (inclusive)

        Returns None when both sides are unlimited, otherwise a
        (left_context, right_context) tuple in which an unlimited side is None.
        """
        if transformer_context is None:
            return None

        if not isinstance(transformer_context, Iterable):
            raise ValueError("transformer context must be Iterable if it is not None")

        if len(transformer_context) != 2:
            raise ValueError("transformer context must have length 2")

        left_context = transformer_context[0]
        if left_context < 0:
            left_context = None

        right_context = transformer_context[1]
        if right_context < 0:
            right_context = None

        if left_context is None and right_context is None:
            return None

        return (left_context, right_context)

    def parse_transformer_sampling(self, transformer_sampling, num_layers):
        """
        parsing transformer sampling configuration

        Args:
            - transformer_sampling, accepted input:
                * None, indicating no sampling
                * an Iterable with int (>0) as element
            - num_layers, expected number of transformer layers, must match with
              the length of transformer_sampling if it is not None

        Returns:
            - A tuple with length num_layers when the input is None,
              otherwise the validated input itself

        Raises:
            ValueError: on a non-iterable input, a length mismatch, a
                non-int element or an element smaller than 1.
        """
        if transformer_sampling is None:
            return (1,) * num_layers

        if not isinstance(transformer_sampling, Iterable):
            raise ValueError(
                "transformer_sampling must be an iterable if it is not None"
            )

        if len(transformer_sampling) != num_layers:
            # The two literals are concatenated by the parser, so .format()
            # sees both placeholders.
            raise ValueError(
                "transformer_sampling {} does not match with the number "
                "of layers {}".format(transformer_sampling, num_layers)
            )

        for layer, value in enumerate(transformer_sampling):
            if not isinstance(value, int):
                raise ValueError(
                    "Invalid value in transformer_sampling: {!r}".format(value)
                )
            if value < 1:
                raise ValueError(
                    "{} layer's subsampling is {}.".format(layer, value)
                    + " This is not allowed! "
                )
        return transformer_sampling

    def slice(self, embedding, padding_mask, attn_mask, sampling_factor):
        """
        Keep every ``sampling_factor``-th time step.

        embedding is a (T, B, D) tensor
        padding_mask is a (B, T) tensor or None
        attn_mask is a (T, T) tensor or None
        """
        embedding = embedding[::sampling_factor, :, :]
        if padding_mask is not None:
            padding_mask = padding_mask[:, ::sampling_factor]
        if attn_mask is not None:
            attn_mask = attn_mask[::sampling_factor, ::sampling_factor]

        return embedding, padding_mask, attn_mask

    def lengths_to_attn_mask(self, input_lengths, subsampling_factor=1):
        """
        create attention mask according to sequence lengths and transformer
        context

        Args:
            - input_lengths: (B, )-shape Int/Long tensor; input_lengths[b] is
              the length of b-th sequence
            - subsampling_factor: int
                * Note that the left_context and right_context is specified in
                  the input frame-level while input to transformer may already
                  go through subsampling (e.g., the use of striding in vggblock)
                  we use subsampling_factor to scale the left/right context

        Return:
            - a (T, T) binary tensor or None, where T is max(input_lengths)
                * if self.transformer_context is None, None
                * if left_context is None,
                    * attn_mask[t, t + right_context + 1:] = 1
                    * others = 0
                * if right_context is None,
                    * attn_mask[t, 0:t - left_context] = 1
                    * others = 0
                * else
                    * attn_mask[t, t - left_context: t + right_context + 1] = 0
                    * others = 1
        """
        if self.transformer_context is None:
            return None

        maxT = torch.max(input_lengths).item()
        attn_mask = torch.zeros(maxT, maxT)

        left_context, right_context = self.transformer_context
        if left_context is not None:
            left_context = math.ceil(left_context / subsampling_factor)
        if right_context is not None:
            right_context = math.ceil(right_context / subsampling_factor)

        for t in range(maxT):
            if left_context is not None:
                # everything strictly before t - left_context is hidden
                attn_mask[t, : max(0, t - left_context)] = 1
            if right_context is not None:
                # Everything strictly after t + right_context is hidden. The
                # start index is deliberately NOT clamped to maxT - 1: an
                # out-of-range start gives an empty slice, whereas clamping
                # would mask the last frame (including its attention to
                # itself) for every row whose window reaches the end.
                attn_mask[t, t + right_context + 1 :] = 1

        return attn_mask.to(input_lengths.device)

    def reorder_encoder_out(self, encoder_out, new_order):
        """Reorder the batch dimension (dim 1) of the encoder output in place."""
        encoder_out["encoder_out"] = encoder_out["encoder_out"].index_select(
            1, new_order
        )
        if encoder_out["encoder_padding_mask"] is not None:
            encoder_out["encoder_padding_mask"] = encoder_out[
                "encoder_padding_mask"
            ].index_select(1, new_order)
        return encoder_out
class TransformerDecoder(FairseqIncrementalDecoder):
    """
    Decoder made of a token embedding, a stack of ``LinearizedConv1d``
    layers (each optionally followed by LayerNorm, always by ReLU), a stack
    of :class:`TransformerDecoderLayer` and a final vocabulary projection.

    Args:
        dictionary (~fairseq.data.Dictionary): decoding dictionary
        embed_dim (int, optional): target token embedding dimension
        transformer_config (optional): one tuple per transformer layer, as
            consumed by ``prepare_transformer_decoder_params``
        conv_config (optional): one ``(out_channels, kernel_size,
            use_layer_norm)`` tuple per convolution layer
        encoder_output_dim (int, optional): accepted for interface symmetry;
            not used by this constructor
    """

    # NOTE(review): the default for ``transformer_config`` is the *encoder*
    # default although DEFAULT_DEC_TRANSFORMER_CONFIG exists; left unchanged
    # to preserve behavior -- confirm whether this is intended.
    def __init__(
        self,
        dictionary,
        embed_dim=512,
        transformer_config=DEFAULT_ENC_TRANSFORMER_CONFIG,
        conv_config=DEFAULT_DEC_CONV_CONFIG,
        encoder_output_dim=512,
    ):
        super().__init__(dictionary)
        vocab_size = len(dictionary)
        self.padding_idx = dictionary.pad()
        self.embed_tokens = Embedding(vocab_size, embed_dim, self.padding_idx)

        self.conv_layers = nn.ModuleList()
        conv_in_channels = embed_dim
        for out_channels, kernel_size, layer_norm in conv_config:
            self.conv_layers.append(
                LinearizedConv1d(
                    conv_in_channels,
                    out_channels,
                    kernel_size,
                    padding=kernel_size - 1,
                )
            )
            if layer_norm:
                self.conv_layers.append(nn.LayerNorm(out_channels))
            self.conv_layers.append(nn.ReLU())
            conv_in_channels = out_channels

        # Transformer stack; a Linear adapter is inserted whenever the width
        # changes between consecutive stages.
        self.layers = nn.ModuleList()
        prev_dim = conv_in_channels
        for config in transformer_config:
            layer_dim = config[0]
            if prev_dim != layer_dim:
                self.layers.append(Linear(prev_dim, layer_dim))
            self.layers.append(
                TransformerDecoderLayer(prepare_transformer_decoder_params(*config))
            )
            prev_dim = layer_dim
        self.fc_out = Linear(prev_dim, vocab_size)

    def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None):
        """
        Args:
            prev_output_tokens (LongTensor): previous decoder outputs of shape
                `(batch, tgt_len)`, for input feeding/teacher forcing
            encoder_out (dict, optional): output from the encoder with keys
                "encoder_out" and "encoder_padding_mask", used for
                encoder-side attention
            incremental_state (dict): dictionary used for storing state during
                :ref:`Incremental decoding`
        Returns:
            tuple:
                - the last decoder layer's output of shape `(batch, tgt_len,
                  vocab)`
                - ``None`` (attention weights are not returned)
        """
        incremental = incremental_state is not None

        target_padding_mask = (
            None if incremental else prev_output_tokens == self.padding_idx
        )

        if incremental:
            # only the newest token is fed during step-wise decoding
            prev_output_tokens = prev_output_tokens[:, -1:]

        # embed tokens
        x = self.embed_tokens(prev_output_tokens)

        # training only: B x T x C -> T x B x C
        x = self._transpose_if_training(x, incremental_state)

        for layer in self.conv_layers:
            if isinstance(layer, LinearizedConvolution):
                x = layer(x, incremental_state)
            else:
                x = layer(x)

        # inference only: B x T x C -> T x B x C
        x = self._transpose_if_inference(x, incremental_state)

        # encoder_out is optional; resolve its pieces once instead of
        # indexing a possibly-None value inside the loop.
        encoder_states = None
        encoder_padding_mask = None
        if encoder_out is not None:
            encoder_states = encoder_out["encoder_out"]
            if encoder_out["encoder_padding_mask"] is not None:
                encoder_padding_mask = encoder_out["encoder_padding_mask"].t()

        # decoder layers
        for layer in self.layers:
            if isinstance(layer, TransformerDecoderLayer):
                # ``*_`` tolerates layers that return extra values besides
                # the hidden states.
                x, *_ = layer(
                    x,
                    encoder_states,
                    encoder_padding_mask,
                    incremental_state,
                    self_attn_mask=(
                        None if incremental else self.buffered_future_mask(x)
                    ),
                    self_attn_padding_mask=target_padding_mask,
                )
            else:
                x = layer(x)

        # T x B x C -> B x T x C
        x = x.transpose(0, 1)

        x = self.fc_out(x)

        return x, None

    def buffered_future_mask(self, tensor):
        """Return a cached (dim, dim) upper-triangular -inf mask for ``tensor``.

        ``dim`` is ``tensor.size(0)``. The cache is rebuilt when missing or on
        another device, and grown when it is too small.
        """
        dim = tensor.size(0)
        if (
            not hasattr(self, "_future_mask")
            or self._future_mask is None
            or self._future_mask.device != tensor.device
        ):
            self._future_mask = torch.triu(
                utils.fill_with_neg_inf(tensor.new(dim, dim)), 1
            )
        if self._future_mask.size(0) < dim:
            self._future_mask = torch.triu(
                utils.fill_with_neg_inf(self._future_mask.resize_(dim, dim)), 1
            )
        return self._future_mask[:dim, :dim]

    def _transpose_if_training(self, x, incremental_state):
        """Swap dims 0 and 1 when no incremental state is given (training)."""
        if incremental_state is None:
            x = x.transpose(0, 1)
        return x

    def _transpose_if_inference(self, x, incremental_state):
        """Swap dims 0 and 1 when an incremental state is given (inference)."""
        # ``is not None`` (rather than truthiness) mirrors
        # _transpose_if_training, so an empty state dict still counts as
        # inference and exactly one of the two helpers transposes.
        if incremental_state is not None:
            x = x.transpose(0, 1)
        return x
@register_model("asr_vggtransformer_encoder")
class VGGTransformerEncoderModel(FairseqEncoderModel):
    """Encoder-only VGG-Transformer model.

    Wraps a :class:`VGGTransformerEncoderOnly`, whose output is projected to
    the size of the task's target dictionary.
    """

    def __init__(self, encoder):
        super().__init__(encoder)

    @staticmethod
    def add_args(parser):
        """Add model-specific arguments to the parser."""
        parser.add_argument(
            "--input-feat-per-channel",
            type=int,
            metavar="N",
            help="encoder input dimension per input channel",
        )
        parser.add_argument(
            "--vggblock-enc-config",
            type=str,
            metavar="EXPR",
            help="""
    an array of tuples each containing the configuration of one vggblock
    [(out_channels, conv_kernel_size, pooling_kernel_size, num_conv_layers,
      use_layer_norm), ...]
    """,
        )
        parser.add_argument(
            "--transformer-enc-config",
            type=str,
            metavar="EXPR",
            help="""
    a tuple containing the configuration of the Transformer layers
    configurations:
    [(input_dim,
      num_heads,
      ffn_dim,
      normalize_before,
      dropout,
      attention_dropout,
      relu_dropout), ...]""",
        )
        parser.add_argument(
            "--enc-output-dim",
            type=int,
            metavar="N",
            help="encoder output dimension, projecting the transformer output",
        )
        parser.add_argument(
            "--in-channels",
            type=int,
            metavar="N",
            help="number of encoder input channels",
        )
        parser.add_argument(
            "--transformer-context",
            type=str,
            metavar="EXPR",
            help="""
    either None or a tuple of two ints, indicating left/right context a
    transformer can have access to""",
        )
        parser.add_argument(
            "--transformer-sampling",
            type=str,
            metavar="EXPR",
            help="""
    either None or a tuple of ints, indicating sampling factor in each layer""",
        )

    @classmethod
    def build_model(cls, args, task):
        """Build a new model instance."""
        base_architecture_enconly(args)
        # NOTE(security): the EXPR options are evaluated with eval(), which
        # executes arbitrary code; they must only come from trusted command
        # lines / checkpoints.
        encoder = VGGTransformerEncoderOnly(
            vocab_size=len(task.target_dictionary),
            input_feat_per_channel=args.input_feat_per_channel,
            vggblock_config=eval(args.vggblock_enc_config),
            transformer_config=eval(args.transformer_enc_config),
            encoder_output_dim=args.enc_output_dim,
            in_channels=args.in_channels,
            transformer_context=eval(args.transformer_context),
            transformer_sampling=eval(args.transformer_sampling),
        )
        return cls(encoder)

    def get_normalized_probs(self, net_output, log_probs, sample=None):
        """Return (log-)probabilities as a batch-first (B, T, D) tensor."""
        # net_output['encoder_out'] is a (T, B, D) tensor
        lprobs = super().get_normalized_probs(net_output, log_probs, sample)
        # lprobs is a (T, B, D) tensor
        # we need to transpose to get a (B, T, D) tensor
        lprobs = lprobs.transpose(0, 1).contiguous()
        lprobs.batch_first = True
        return lprobs
class VGGTransformerEncoderOnly(VGGTransformerEncoder):
    """VGG-Transformer encoder with an extra projection to the vocabulary."""

    def __init__(
        self,
        vocab_size,
        input_feat_per_channel,
        vggblock_config=DEFAULT_ENC_VGGBLOCK_CONFIG,
        transformer_config=DEFAULT_ENC_TRANSFORMER_CONFIG,
        encoder_output_dim=512,
        in_channels=1,
        transformer_context=None,
        transformer_sampling=None,
    ):
        """Build the parent encoder, then add ``fc_out`` of width ``vocab_size``."""
        encoder_kwargs = dict(
            input_feat_per_channel=input_feat_per_channel,
            vggblock_config=vggblock_config,
            transformer_config=transformer_config,
            encoder_output_dim=encoder_output_dim,
            in_channels=in_channels,
            transformer_context=transformer_context,
            transformer_sampling=transformer_sampling,
        )
        super().__init__(**encoder_kwargs)
        self.fc_out = Linear(self.encoder_output_dim, vocab_size)

    def forward(self, src_tokens, src_lengths, **kwargs):
        """
        src_tokens: padded tensor (B, T, C * feat)
        src_lengths: tensor of original lengths of input utterances (B,)
        """
        base_out = super().forward(src_tokens, src_lengths)
        # No log_softmax here: model.get_normalized_probs applies it later.
        projected = self.fc_out(base_out["encoder_out"])
        padding_mask = base_out["encoder_padding_mask"]
        return {
            "encoder_out": projected,  # (T, B, C)
            "encoder_padding_mask": padding_mask,  # (T, B)
        }

    def max_positions(self):
        """Maximum input length supported by the encoder."""
        # an arbitrary large number
        return (1e6, 1e6)
def Embedding(num_embeddings, embedding_dim, padding_idx):
    """Create an ``nn.Embedding`` with the given padding index.

    No custom initialization is applied; the module keeps whatever
    ``nn.Embedding`` sets up by default.
    """
    return nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
def Linear(in_features, out_features, bias=True, dropout=0):
    """Linear layer (input: N x T x C)

    Thin factory around ``nn.Linear``. ``dropout`` is accepted but not used,
    and no custom weight initialization is applied.
    """
    return nn.Linear(in_features, out_features, bias=bias)
def LinearizedConv1d(in_channels, out_channels, kernel_size, dropout=0, **kwargs):
    """Weight-normalized Conv1d layer optimized for decoding"""
    conv = LinearizedConvolution(in_channels, out_channels, kernel_size, **kwargs)
    # Normal init whose std shrinks with kernel width * input channels and
    # with the dropout rate; the bias starts at zero.
    fan_in = conv.kernel_size[0] * in_channels
    init_std = math.sqrt((4 * (1.0 - dropout)) / fan_in)
    nn.init.normal_(conv.weight, mean=0, std=init_std)
    nn.init.constant_(conv.bias, 0)
    return nn.utils.weight_norm(conv, dim=2)
def LayerNorm(embedding_dim):
    """Create an ``nn.LayerNorm`` over the given embedding dimension."""
    return nn.LayerNorm(embedding_dim)
# seq2seq models
def base_architecture(args):
    """Fill in defaults for every seq2seq option missing from ``args``.

    Attributes already present on ``args`` are left untouched. The
    ``*_config`` defaults are stored as *strings* (the ``repr`` of the
    module-level default tuples) because the model builders in this file
    pass these options to ``eval()``, which only accepts strings; this also
    matches what the command line and the registered architectures supply.
    """
    args.input_feat_per_channel = getattr(args, "input_feat_per_channel", 40)
    args.vggblock_enc_config = getattr(
        args, "vggblock_enc_config", repr(DEFAULT_ENC_VGGBLOCK_CONFIG)
    )
    args.transformer_enc_config = getattr(
        args, "transformer_enc_config", repr(DEFAULT_ENC_TRANSFORMER_CONFIG)
    )
    args.enc_output_dim = getattr(args, "enc_output_dim", 512)
    args.in_channels = getattr(args, "in_channels", 1)
    args.tgt_embed_dim = getattr(args, "tgt_embed_dim", 128)
    # NOTE(review): the decoder default reuses the *encoder* transformer
    # config although DEFAULT_DEC_TRANSFORMER_CONFIG exists; kept as-is to
    # preserve behavior -- confirm whether this is intended.
    args.transformer_dec_config = getattr(
        args, "transformer_dec_config", repr(DEFAULT_ENC_TRANSFORMER_CONFIG)
    )
    args.conv_dec_config = getattr(
        args, "conv_dec_config", repr(DEFAULT_DEC_CONV_CONFIG)
    )
    args.transformer_context = getattr(args, "transformer_context", "None")
@register_model_architecture("asr_vggtransformer", "vggtransformer_1")
def vggtransformer_1(args):
    """Apply the ``vggtransformer_1`` defaults to options missing from ``args``."""
    fallbacks = (
        ("input_feat_per_channel", 80),
        ("vggblock_enc_config", "[(64, 3, 2, 2, True), (128, 3, 2, 2, True)]"),
        (
            "transformer_enc_config",
            "((1024, 16, 4096, True, 0.15, 0.15, 0.15),) * 14",
        ),
        ("enc_output_dim", 1024),
        ("tgt_embed_dim", 128),
        ("conv_dec_config", "((256, 3, True),) * 4"),
        (
            "transformer_dec_config",
            "((1024, 16, 4096, True, 0.15, 0.15, 0.15),) * 4",
        ),
    )
    # Values already set on args win; only absent options get the fallback.
    for option, fallback in fallbacks:
        setattr(args, option, getattr(args, option, fallback))
@register_model_architecture("asr_vggtransformer", "vggtransformer_2")
def vggtransformer_2(args):
    """Apply the ``vggtransformer_2`` defaults to options missing from ``args``."""
    fallbacks = (
        ("input_feat_per_channel", 80),
        ("vggblock_enc_config", "[(64, 3, 2, 2, True), (128, 3, 2, 2, True)]"),
        (
            "transformer_enc_config",
            "((1024, 16, 4096, True, 0.15, 0.15, 0.15),) * 16",
        ),
        ("enc_output_dim", 1024),
        ("tgt_embed_dim", 512),
        ("conv_dec_config", "((256, 3, True),) * 4"),
        (
            "transformer_dec_config",
            "((1024, 16, 4096, True, 0.15, 0.15, 0.15),) * 6",
        ),
    )
    # Values already set on args win; only absent options get the fallback.
    for option, fallback in fallbacks:
        setattr(args, option, getattr(args, option, fallback))
@register_model_architecture("asr_vggtransformer", "vggtransformer_base")
def vggtransformer_base(args):
    """Apply the ``vggtransformer_base`` defaults to options missing from ``args``."""
    fallbacks = (
        ("input_feat_per_channel", 80),
        ("vggblock_enc_config", "[(64, 3, 2, 2, True), (128, 3, 2, 2, True)]"),
        (
            "transformer_enc_config",
            "((512, 8, 2048, True, 0.15, 0.15, 0.15),) * 12",
        ),
        ("enc_output_dim", 512),
        ("tgt_embed_dim", 512),
        ("conv_dec_config", "((256, 3, True),) * 4"),
        (
            "transformer_dec_config",
            "((512, 8, 2048, True, 0.15, 0.15, 0.15),) * 6",
        ),
    )
    # Values already set on args win; only absent options get the fallback.
    for option, fallback in fallbacks:
        setattr(args, option, getattr(args, option, fallback))
# Size estimations:
# Encoder:
# - vggblock param: 64*1*3*3 + 64*64*3*3 + 128*64*3*3 + 128*128*3*3 = 258K
# Transformer:
# - input dimension adapter: 2560 x 512 -> 1.31M
# - transformer_layers (x12) --> 37.74M
# * MultiheadAttention: 512*512*3 (in_proj) + 512*512 (out_proj) = 1.048M
# * FFN weight: 512*2048*2 = 2.097M
# - output dimension adapter: 512 x 512 -> 0.26 M
# Decoder:
# - LinearizedConv1d: 512 * 256 * 3 + 256 * 256 * 3 * 3
# - transformer_layer: (x6) --> 25.16M
# * MultiheadAttention (self-attention): 512*512*3 + 512*512 = 1.048M
# * MultiheadAttention (encoder-attention): 512*512*3 + 512*512 = 1.048M
# * FFN: 512*2048*2 = 2.097M
# Final FC:
# - FC: 512*5000 = 2.56M (assuming vocab size 5K)
# In total:
# ~65 M
# CTC models
def base_architecture_enconly(args):
    """Fill in defaults for every encoder-only option missing from ``args``.

    Attributes already present on ``args`` are kept; the config-style
    options default to strings.
    """
    fallbacks = (
        ("input_feat_per_channel", 40),
        ("vggblock_enc_config", "[(32, 3, 2, 2, True)] * 2"),
        ("transformer_enc_config", "((256, 4, 1024, True, 0.2, 0.2, 0.2),) * 2"),
        ("enc_output_dim", 512),
        ("in_channels", 1),
        ("transformer_context", "None"),
        ("transformer_sampling", "None"),
    )
    for option, fallback in fallbacks:
        setattr(args, option, getattr(args, option, fallback))
@register_model_architecture("asr_vggtransformer_encoder", "vggtransformer_enc_1")
def vggtransformer_enc_1(args):
    """Apply the ``vggtransformer_enc_1`` defaults to options missing from ``args``."""
    # vggtransformer_1 is the same as vggtransformer_enc_big, except the number
    # of layers is increased to 16
    # keep it here for backward compatibility purpose
    fallbacks = (
        ("input_feat_per_channel", 80),
        ("vggblock_enc_config", "[(64, 3, 2, 2, True), (128, 3, 2, 2, True)]"),
        (
            "transformer_enc_config",
            "((1024, 16, 4096, True, 0.15, 0.15, 0.15),) * 16",
        ),
        ("enc_output_dim", 1024),
    )
    # Values already set on args win; only absent options get the fallback.
    for option, fallback in fallbacks:
        setattr(args, option, getattr(args, option, fallback))