#!/usr/bin/env python3

import logging
import math
from typing import Dict, List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from fairseq import checkpoint_utils, utils
from fairseq.data.data_utils import lengths_to_padding_mask
from fairseq.models import (
    FairseqEncoder,
    FairseqEncoderDecoderModel,
    register_model,
    register_model_architecture,
)
from fairseq.models.transformer import Embedding, TransformerDecoder
from fairseq.modules import LayerNorm, PositionalEmbedding, TransformerEncoderLayer
from torch import Tensor

logger = logging.getLogger(__name__)

@register_model("convtransformer")
class ConvTransformerModel(FairseqEncoderDecoderModel):
    """
    Transformer-based Speech translation model from ESPNet-ST
    https://arxiv.org/abs/2004.10234
    """

    def __init__(self, encoder, decoder):
        super().__init__(encoder, decoder)

    @staticmethod
    def add_args(parser):
        """Add model-specific arguments to the parser."""
        parser.add_argument(
            "--input-feat-per-channel",
            type=int,
            metavar="N",
            help="encoder input dimension per input channel",
        )
        parser.add_argument(
            "--activation-fn",
            choices=utils.get_available_activation_fns(),
            help="activation function to use",
        )
        parser.add_argument(
            "--dropout", type=float, metavar="D", help="dropout probability"
        )
        parser.add_argument(
            "--attention-dropout",
            type=float,
            metavar="D",
            help="dropout probability for attention weights",
        )
        parser.add_argument(
            "--activation-dropout",
            "--relu-dropout",
            type=float,
            metavar="D",
            help="dropout probability after activation in FFN.",
        )
        parser.add_argument(
            "--encoder-embed-dim",
            type=int,
            metavar="N",
            help="encoder embedding dimension",
        )
        parser.add_argument(
            "--encoder-ffn-embed-dim",
            type=int,
            metavar="N",
            help="encoder embedding dimension for FFN",
        )
        parser.add_argument(
            "--encoder-layers", type=int, metavar="N", help="num encoder layers"
        )
        parser.add_argument(
            "--encoder-attention-heads",
            type=int,
            metavar="N",
            help="num encoder attention heads",
        )
        parser.add_argument(
            "--encoder-normalize-before",
            action="store_true",
            help="apply layernorm before each encoder block",
        )
        parser.add_argument(
            "--decoder-embed-dim",
            type=int,
            metavar="N",
            help="decoder embedding dimension",
        )
        parser.add_argument(
            "--decoder-ffn-embed-dim",
            type=int,
            metavar="N",
            help="decoder embedding dimension for FFN",
        )
        parser.add_argument(
            "--decoder-layers", type=int, metavar="N", help="num decoder layers"
        )
        parser.add_argument(
            "--decoder-attention-heads",
            type=int,
            metavar="N",
            help="num decoder attention heads",
        )
        parser.add_argument(
            "--decoder-normalize-before",
            action="store_true",
            help="apply layernorm before each decoder block",
        )
        parser.add_argument(
            "--decoder-output-dim",
            type=int,
            metavar="N",
            help="decoder output dimension (extra linear layer if different from decoder embed dim)",
        )
        parser.add_argument(
            "--share-decoder-input-output-embed",
            action="store_true",
            help="share decoder input and output embeddings",
        )
        parser.add_argument(
            "--layernorm-embedding",
            action="store_true",
            help="add layernorm to embedding",
        )
        parser.add_argument(
            "--no-scale-embedding",
            action="store_true",
            help="if True, don't scale embeddings",
        )
        parser.add_argument(
            "--load-pretrained-encoder-from",
            type=str,
            metavar="STR",
            help="model to take encoder weights from (for initialization)",
        )
        parser.add_argument(
            "--load-pretrained-decoder-from",
            type=str,
            metavar="STR",
            help="model to take decoder weights from (for initialization)",
        )
        parser.add_argument(
            "--conv-out-channels",
            type=int,
            metavar="INT",
            help="the number of output channels of conv layer",
        )

    @classmethod
    def build_encoder(cls, args):
        encoder = ConvTransformerEncoder(args)
        if getattr(args, "load_pretrained_encoder_from", None):
            encoder = checkpoint_utils.load_pretrained_component_from_model(
                component=encoder, checkpoint=args.load_pretrained_encoder_from
            )
        return encoder

    @classmethod
    def build_decoder(cls, args, task, embed_tokens):
        decoder = TransformerDecoderNoExtra(args, task.target_dictionary, embed_tokens)
        if getattr(args, "load_pretrained_decoder_from", None):
            decoder = checkpoint_utils.load_pretrained_component_from_model(
                component=decoder, checkpoint=args.load_pretrained_decoder_from
            )
        return decoder

    @classmethod
    def build_model(cls, args, task):
        """Build a new model instance."""

        # make sure all arguments are present in older models
        base_architecture(args)

        def build_embedding(dictionary, embed_dim):
            num_embeddings = len(dictionary)
            padding_idx = dictionary.pad()
            return Embedding(num_embeddings, embed_dim, padding_idx)

        decoder_embed_tokens = build_embedding(
            task.target_dictionary, args.decoder_embed_dim
        )
        encoder = cls.build_encoder(args)
        decoder = cls.build_decoder(args, task, decoder_embed_tokens)
        return cls(encoder, decoder)

    @staticmethod
    @torch.jit.unused
    def set_batch_first(lprobs):
        lprobs.batch_first = True

    def get_normalized_probs(
        self,
        net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]],
        log_probs: bool,
        sample: Optional[Dict[str, Tensor]] = None,
    ):
        # net_output['encoder_out'] is a (B, T, D) tensor
        lprobs = self.get_normalized_probs_scriptable(net_output, log_probs, sample)
        if self.training:
            self.set_batch_first(lprobs)
        return lprobs

    def output_layout(self):
        return "BTD"

    """
    The forward method inherited from the base class takes a **kwargs argument,
    which is not supported in torchscript. This method overrides the forward
    method definition without **kwargs.
    """

    def forward(self, src_tokens, src_lengths, prev_output_tokens):
        encoder_out = self.encoder(src_tokens=src_tokens, src_lengths=src_lengths)
        decoder_out = self.decoder(
            prev_output_tokens=prev_output_tokens, encoder_out=encoder_out
        )
        return decoder_out

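# Illustrative note (not part of the upstream module): per the forward signature
# above and the encoder below, src_tokens is expected to be a padded batch of
# speech features of shape (batch, time, input_feat_per_channel), src_lengths a
# (batch,)-shaped LongTensor of frame counts, and prev_output_tokens the usual
# (batch, tgt_len) shifted target tokens consumed by fairseq decoders.
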
class ConvTransformerEncoder(FairseqEncoder):
    """Conv + Transformer encoder"""

    def __init__(self, args):
        """Construct an Encoder object."""
        super().__init__(None)

        self.dropout = args.dropout
        self.embed_scale = (
            1.0 if args.no_scale_embedding else math.sqrt(args.encoder_embed_dim)
        )
        self.padding_idx = 1
        self.in_channels = 1
        self.input_dim = args.input_feat_per_channel
        self.conv = torch.nn.Sequential(
            torch.nn.Conv2d(1, args.conv_out_channels, 3, stride=2, padding=3 // 2),
            torch.nn.ReLU(),
            torch.nn.Conv2d(
                args.conv_out_channels,
                args.conv_out_channels,
                3,
                stride=2,
                padding=3 // 2,
            ),
            torch.nn.ReLU(),
        )
        transformer_input_dim = self.infer_conv_output_dim(
            self.in_channels, self.input_dim, args.conv_out_channels
        )
        self.out = torch.nn.Linear(transformer_input_dim, args.encoder_embed_dim)
        self.embed_positions = PositionalEmbedding(
            args.max_source_positions,
            args.encoder_embed_dim,
            self.padding_idx,
            learned=False,
        )

        self.transformer_layers = nn.ModuleList([])
        self.transformer_layers.extend(
            [TransformerEncoderLayer(args) for i in range(args.encoder_layers)]
        )
        if args.encoder_normalize_before:
            self.layer_norm = LayerNorm(args.encoder_embed_dim)
        else:
            self.layer_norm = None

    def pooling_ratio(self):
        # The two stride-2 convolutions above subsample the time axis by a factor of 4.
        return 4

    def infer_conv_output_dim(self, in_channels, input_dim, out_channels):
        # Run a dummy batch through the conv stack to determine the flattened
        # per-frame feature size that feeds the linear projection.
        sample_seq_len = 200
        sample_bsz = 10
        x = torch.randn(sample_bsz, in_channels, sample_seq_len, input_dim)
        x = torch.nn.Conv2d(1, out_channels, 3, stride=2, padding=3 // 2)(x)
        x = torch.nn.Conv2d(out_channels, out_channels, 3, stride=2, padding=3 // 2)(x)
        x = x.transpose(1, 2)
        mb, seq = x.size()[:2]
        return x.contiguous().view(mb, seq, -1).size(-1)

    def forward(self, src_tokens, src_lengths):
        """Encode input sequence.

        Args:
            src_tokens (torch.Tensor): padded batch of speech features,
                of shape ``(batch, time, input_feat_per_channel)``
            src_lengths (torch.Tensor): source sequence lengths of shape ``(batch,)``

        Returns:
            dict: encoder output with position-embedded features and padding mask
        """
        bsz, max_seq_len, _ = src_tokens.size()
        x = (
            src_tokens.view(bsz, max_seq_len, self.in_channels, self.input_dim)
            .transpose(1, 2)
            .contiguous()
        )
        x = self.conv(x)
        bsz, _, output_seq_len, _ = x.size()
        x = x.transpose(1, 2).transpose(0, 1).contiguous().view(output_seq_len, bsz, -1)
        x = self.out(x)
        x = self.embed_scale * x

        # Rescale the source lengths to the subsampled time axis and clamp them
        # to the actual output length before building the padding mask.
        subsampling_factor = int(max_seq_len * 1.0 / output_seq_len + 0.5)
        input_len_0 = (src_lengths.float() / subsampling_factor).ceil().long()
        input_len_1 = x.size(0) * torch.ones([src_lengths.size(0)]).long().to(
            input_len_0.device
        )
        input_lengths = torch.min(input_len_0, input_len_1)
        encoder_padding_mask = lengths_to_padding_mask(input_lengths)

        positions = self.embed_positions(encoder_padding_mask).transpose(0, 1)
        x += positions
        x = F.dropout(x, p=self.dropout, training=self.training)

        for layer in self.transformer_layers:
            x = layer(x, encoder_padding_mask)

        if not encoder_padding_mask.any():
            maybe_encoder_padding_mask = None
        else:
            maybe_encoder_padding_mask = encoder_padding_mask

        return {
            "encoder_out": [x],
            "encoder_padding_mask": [maybe_encoder_padding_mask]
            if maybe_encoder_padding_mask is not None
            else [],
            "encoder_embedding": [],
            "encoder_states": [],
            "src_tokens": [],
            "src_lengths": [],
        }

    @torch.jit.export
    def reorder_encoder_out(self, encoder_out: Dict[str, List[Tensor]], new_order):
        """
        Reorder encoder output according to *new_order*.

        Args:
            encoder_out: output from the ``forward()`` method
            new_order (LongTensor): desired order

        Returns:
            *encoder_out* rearranged according to *new_order*
        """
        new_encoder_out = [encoder_out["encoder_out"][0].index_select(1, new_order)]
        if len(encoder_out["encoder_padding_mask"]) == 0:
            new_encoder_padding_mask = []
        else:
            new_encoder_padding_mask = [
                (encoder_out["encoder_padding_mask"][0]).index_select(0, new_order)
            ]
        if len(encoder_out["encoder_embedding"]) == 0:
            new_encoder_embedding = []
        else:
            new_encoder_embedding = [
                (encoder_out["encoder_embedding"][0]).index_select(0, new_order)
            ]
        encoder_states = encoder_out["encoder_states"]
        if len(encoder_states) > 0:
            for idx, state in enumerate(encoder_states):
                encoder_states[idx] = state.index_select(1, new_order)
        return {
            "encoder_out": new_encoder_out,
            "encoder_padding_mask": new_encoder_padding_mask,
            "encoder_embedding": new_encoder_embedding,
            "encoder_states": encoder_states,
            "src_tokens": [],
            "src_lengths": [],
        }

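# Illustrative shape flow through the encoder above (a sketch, assuming the
# default input_feat_per_channel=80 and conv_out_channels=C): features
# (B, T, 80) are reshaped to (B, 1, T, 80); the two stride-2 convolutions give
# (B, C, ~T/4, 20); the result is flattened to (~T/4, B, C * 20), projected to
# (~T/4, B, encoder_embed_dim), and fed time-first to the transformer layers,
# consistent with pooling_ratio() == 4.
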
class TransformerDecoderNoExtra(TransformerDecoder):
    def extract_features(
        self,
        prev_output_tokens,
        encoder_out: Optional[Dict[str, List[Tensor]]],
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
        full_context_alignment: bool = False,
        alignment_layer: Optional[int] = None,
        alignment_heads: Optional[int] = None,
    ):
        # call scriptable method from parent class
        x, _ = self.extract_features_scriptable(
            prev_output_tokens,
            encoder_out,
            incremental_state,
            full_context_alignment,
            alignment_layer,
            alignment_heads,
        )
        return x, None

@register_model_architecture(model_name="convtransformer", arch_name="convtransformer")
def base_architecture(args):
    args.input_feat_per_channel = getattr(args, "input_feat_per_channel", 80)
    args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
    args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048)
    args.encoder_layers = getattr(args, "encoder_layers", 6)
    args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 8)
    args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False)
    args.decoder_embed_dim = getattr(args, "decoder_embed_dim", args.encoder_embed_dim)
    args.decoder_ffn_embed_dim = getattr(
        args, "decoder_ffn_embed_dim", args.encoder_ffn_embed_dim
    )
    args.decoder_layers = getattr(args, "decoder_layers", 6)
    args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8)
    args.decoder_normalize_before = getattr(args, "decoder_normalize_before", False)
    args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False)
    args.attention_dropout = getattr(args, "attention_dropout", 0.0)
    args.activation_dropout = getattr(args, "activation_dropout", 0.0)
    args.activation_fn = getattr(args, "activation_fn", "relu")
    args.dropout = getattr(args, "dropout", 0.1)
    args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None)
    args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0)
    args.share_decoder_input_output_embed = getattr(
        args, "share_decoder_input_output_embed", False
    )
    args.no_token_positional_embeddings = getattr(
        args, "no_token_positional_embeddings", False
    )
    args.adaptive_input = getattr(args, "adaptive_input", False)
    args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0.0)
    args.decoder_output_dim = getattr(
        args, "decoder_output_dim", args.decoder_embed_dim
    )
    args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim)
    args.no_scale_embedding = getattr(args, "no_scale_embedding", False)
    args.quant_noise_pq = getattr(args, "quant_noise_pq", 0)
    args.max_source_positions = getattr(args, "max_source_positions", 3000)
    args.max_target_positions = getattr(args, "max_target_positions", 1024)
    args.tie_adaptive_weights = getattr(args, "tie_adaptive_weights", False)
    args.conv_out_channels = getattr(args, "conv_out_channels", args.encoder_embed_dim)

@register_model_architecture("convtransformer", "convtransformer_espnet")
def convtransformer_espnet(args):
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 256)
args.encoder_layers = getattr(args, "encoder_layers", 12)
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 4)
args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 4)
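

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the upstream fairseq module): a minimal,
# standalone check of the conv front-end arithmetic used above. It assumes the
# default input_feat_per_channel=80 and a hypothetical conv_out_channels=256,
# relies only on torch, and is guarded so it never runs on import.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    feat_dim = 80  # matches the default --input-feat-per-channel
    out_channels = 256  # hypothetical --conv-out-channels for this sketch
    seq_len, bsz = 200, 2

    conv = torch.nn.Sequential(
        torch.nn.Conv2d(1, out_channels, 3, stride=2, padding=1),
        torch.nn.ReLU(),
        torch.nn.Conv2d(out_channels, out_channels, 3, stride=2, padding=1),
        torch.nn.ReLU(),
    )
    # (batch, channels=1, time, feature), as laid out in ConvTransformerEncoder.forward
    dummy = torch.randn(bsz, 1, seq_len, feat_dim)
    with torch.no_grad():
        out = conv(dummy)

    # Two stride-2 convolutions shrink the time axis by ~4x (pooling_ratio()).
    print("time axis:", seq_len, "->", out.size(2))
    # The flattened per-frame dimension fed to the Linear projection.
    print("flattened dim:", out.size(1) * out.size(3))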