#!/usr/bin/env python3

import logging
import math
from typing import Dict, List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from fairseq import checkpoint_utils, utils
from fairseq.data.data_utils import lengths_to_padding_mask
from fairseq.models import (
    FairseqEncoder,
    FairseqEncoderDecoderModel,
    register_model,
    register_model_architecture,
)
from fairseq.models.transformer import Embedding, TransformerDecoder
from fairseq.modules import LayerNorm, PositionalEmbedding, TransformerEncoderLayer
from torch import Tensor


logger = logging.getLogger(__name__)


@register_model("convtransformer")
class ConvTransformerModel(FairseqEncoderDecoderModel):
    """
    Transformer-based Speech translation model from ESPNet-ST
    https://arxiv.org/abs/2004.10234
    """

    def __init__(self, encoder, decoder):
        super().__init__(encoder, decoder)

    @staticmethod
    def add_args(parser):
        """Add model-specific arguments to the parser."""
        parser.add_argument(
            "--input-feat-per-channel",
            type=int,
            metavar="N",
            help="encoder input dimension per input channel",
        )
        parser.add_argument(
            "--activation-fn",
            choices=utils.get_available_activation_fns(),
            help="activation function to use",
        )
        parser.add_argument(
            "--dropout", type=float, metavar="D", help="dropout probability"
        )
        parser.add_argument(
            "--attention-dropout",
            type=float,
            metavar="D",
            help="dropout probability for attention weights",
        )
        parser.add_argument(
            "--activation-dropout",
            "--relu-dropout",
            type=float,
            metavar="D",
            help="dropout probability after activation in FFN.",
        )
        parser.add_argument(
            "--encoder-embed-dim",
            type=int,
            metavar="N",
            help="encoder embedding dimension",
        )
        parser.add_argument(
            "--encoder-ffn-embed-dim",
            type=int,
            metavar="N",
            help="encoder embedding dimension for FFN",
        )
        parser.add_argument(
            "--encoder-layers", type=int, metavar="N", help="num encoder layers"
        )
        parser.add_argument(
            "--encoder-attention-heads",
            type=int,
            metavar="N",
            help="num encoder attention heads",
        )
        parser.add_argument(
            "--encoder-normalize-before",
            action="store_true",
            help="apply layernorm before each encoder block",
        )
        parser.add_argument(
            "--decoder-embed-dim",
            type=int,
            metavar="N",
            help="decoder embedding dimension",
        )
        parser.add_argument(
            "--decoder-ffn-embed-dim",
            type=int,
            metavar="N",
            help="decoder embedding dimension for FFN",
        )
        parser.add_argument(
            "--decoder-layers", type=int, metavar="N", help="num decoder layers"
        )
        parser.add_argument(
            "--decoder-attention-heads",
            type=int,
            metavar="N",
            help="num decoder attention heads",
        )
        parser.add_argument(
            "--decoder-normalize-before",
            action="store_true",
            help="apply layernorm before each decoder block",
        )
        parser.add_argument(
            "--decoder-output-dim",
            type=int,
            metavar="N",
            help="decoder output dimension (extra linear layer if different from decoder embed dim)",
        )
        parser.add_argument(
            "--share-decoder-input-output-embed",
            action="store_true",
            help="share decoder input and output embeddings",
        )
        parser.add_argument(
            "--layernorm-embedding",
            action="store_true",
            help="add layernorm to embedding",
        )
        parser.add_argument(
            "--no-scale-embedding",
            action="store_true",
            help="if True, dont scale embeddings",
        )
        parser.add_argument(
            "--load-pretrained-encoder-from",
            type=str,
            metavar="STR",
            help="model to take encoder weights from (for initialization)",
        )
        parser.add_argument(
            "--load-pretrained-decoder-from",
            type=str,
            metavar="STR",
            help="model to take decoder weights from (for initialization)",
        )
        parser.add_argument(
            "--conv-out-channels",
            type=int,
            metavar="INT",
            help="the number of output channels of conv layer",
        )

    @classmethod
    def build_encoder(cls, args):
        """Build the encoder, optionally initialised from a pretrained checkpoint.

        When ``args.load_pretrained_encoder_from`` is set (and truthy), the
        freshly built encoder is passed through
        ``checkpoint_utils.load_pretrained_component_from_model``.
        """
        encoder = ConvTransformerEncoder(args)
        if getattr(args, "load_pretrained_encoder_from", None):
            encoder = checkpoint_utils.load_pretrained_component_from_model(
                component=encoder, checkpoint=args.load_pretrained_encoder_from
            )
        return encoder

    @classmethod
    def build_decoder(cls, args, task, embed_tokens):
        """Build the decoder, optionally initialised from a pretrained checkpoint.

        The decoder is built over ``task.target_dictionary`` with the given
        ``embed_tokens``; ``args.load_pretrained_decoder_from`` is handled the
        same way as its encoder counterpart in ``build_encoder``.
        """
        decoder = TransformerDecoderNoExtra(args, task.target_dictionary, embed_tokens)
        if getattr(args, "load_pretrained_decoder_from", None):
            decoder = checkpoint_utils.load_pretrained_component_from_model(
                component=decoder, checkpoint=args.load_pretrained_decoder_from
            )
        return decoder

    @classmethod
    def build_model(cls, args, task):
        """Build a new model instance."""

        # make sure all arguments are present in older models
        base_architecture(args)

        def build_embedding(dictionary, embed_dim):
            num_embeddings = len(dictionary)
            padding_idx = dictionary.pad()
            return Embedding(num_embeddings, embed_dim, padding_idx)

        decoder_embed_tokens = build_embedding(
            task.target_dictionary, args.decoder_embed_dim
        )
        encoder = cls.build_encoder(args)
        decoder = cls.build_decoder(args, task, decoder_embed_tokens)
        return cls(encoder, decoder)

    @staticmethod
    @torch.jit.unused
    def set_batch_first(lprobs):
        """Tag *lprobs* with ``batch_first = True`` (excluded from TorchScript)."""
        lprobs.batch_first = True

    def get_normalized_probs(
        self,
        net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]],
        log_probs: bool,
        sample: Optional[Dict[str, Tensor]] = None,
    ):
        """Return normalized (log-)probabilities for *net_output*.

        Delegates to ``get_normalized_probs_scriptable`` and, in training
        mode only, marks the result as batch-first via ``set_batch_first``.
        """
        # net_output['encoder_out'] is a (B, T, D) tensor
        lprobs = self.get_normalized_probs_scriptable(net_output, log_probs, sample)
        if self.training:
            self.set_batch_first(lprobs)
        return lprobs

    def output_layout(self):
        """Return the layout tag of the model output (``"BTD"``)."""
        return "BTD"

    def forward(self, src_tokens, src_lengths, prev_output_tokens):
        """Run the encoder and then the decoder.

        The forward method inherited from the base class has a **kwargs
        argument in its input, which is not supported in torchscript. This
        method overrides the forward method definition without **kwargs.

        Args:
            src_tokens: encoder input, forwarded as ``src_tokens``
            src_lengths: encoder input lengths, forwarded as ``src_lengths``
            prev_output_tokens: decoder input tokens

        Returns:
            whatever ``self.decoder`` returns for the given inputs
        """
        encoder_out = self.encoder(src_tokens=src_tokens, src_lengths=src_lengths)
        decoder_out = self.decoder(
            prev_output_tokens=prev_output_tokens, encoder_out=encoder_out
        )
        return decoder_out


class ConvTransformerEncoder(FairseqEncoder):
    """Conv + Transformer encoder"""

    def __init__(self, args):
        """Construct an Encoder object."""
        super().__init__(None)

        self.dropout = args.dropout
        self.embed_scale = (
            1.0 if args.no_scale_embedding else math.sqrt(args.encoder_embed_dim)
        )
        self.padding_idx = 1
        self.in_channels = 1
        self.input_dim = args.input_feat_per_channel
        # Two stride-2 3x3 convolutions; see pooling_ratio() for the resulting
        # nominal time reduction.
        self.conv = torch.nn.Sequential(
            torch.nn.Conv2d(
                self.in_channels, args.conv_out_channels, 3, stride=2, padding=3 // 2
            ),
            torch.nn.ReLU(),
            torch.nn.Conv2d(
                args.conv_out_channels,
                args.conv_out_channels,
                3,
                stride=2,
                padding=3 // 2,
            ),
            torch.nn.ReLU(),
        )
        transformer_input_dim = self.infer_conv_output_dim(
            self.in_channels, self.input_dim, args.conv_out_channels
        )
        self.out = torch.nn.Linear(transformer_input_dim, args.encoder_embed_dim)
        self.embed_positions = PositionalEmbedding(
            args.max_source_positions,
            args.encoder_embed_dim,
            self.padding_idx,
            learned=False,
        )

        self.transformer_layers = nn.ModuleList([])
        self.transformer_layers.extend(
            [TransformerEncoderLayer(args) for _ in range(args.encoder_layers)]
        )
        # NOTE(review): layer_norm is constructed here but forward() never
        # applies it. Left untouched because applying it would change the
        # outputs of existing checkpoints -- confirm the intended behaviour.
        if args.encoder_normalize_before:
            self.layer_norm = LayerNorm(args.encoder_embed_dim)
        else:
            self.layer_norm = None

    def pooling_ratio(self):
        """Return the nominal time-axis subsampling factor of the conv stack."""
        return 4

    def infer_conv_output_dim(self, in_channels, input_dim, out_channels):
        """Infer the flattened per-frame feature size produced by the conv stack.

        A random probe batch is pushed through two throw-away convolutions
        configured like the Conv2d layers of ``self.conv`` (the ReLUs are
        omitted; they do not affect shapes).

        Args:
            in_channels: number of input channels of the first convolution
            input_dim: feature dimension per input channel
            out_channels: number of output channels of both convolutions

        Returns:
            int: size of the last dimension after merging the channel and
            feature axes of the conv output
        """
        sample_seq_len = 200
        sample_bsz = 10
        # Only shapes matter here, so skip autograd bookkeeping.
        with torch.no_grad():
            x = torch.randn(sample_bsz, in_channels, sample_seq_len, input_dim)
            x = torch.nn.Conv2d(
                in_channels, out_channels, 3, stride=2, padding=3 // 2
            )(x)
            x = torch.nn.Conv2d(
                out_channels, out_channels, 3, stride=2, padding=3 // 2
            )(x)
        x = x.transpose(1, 2)
        mb, seq = x.size()[:2]
        return x.contiguous().view(mb, seq, -1).size(-1)

    def forward(self, src_tokens, src_lengths):
        """Encode input sequence.

        Args:
            src_tokens (Tensor): input features; viewed internally as
                ``(bsz, max_seq_len, in_channels, input_dim)``
            src_lengths (Tensor): one length per batch element, measured
                along the ``max_seq_len`` axis of *src_tokens*

        Returns:
            dict with the keys:
                - **encoder_out**: single-element list holding the encoder
                  output of shape ``(output_seq_len, bsz, encoder_embed_dim)``
                - **encoder_padding_mask**: single-element list holding the
                  padding mask, or an empty list when no position is padded
                - **encoder_embedding**, **encoder_states**, **src_tokens**,
                  **src_lengths**: always empty lists
        """
        bsz, max_seq_len, _ = src_tokens.size()
        x = (
            src_tokens.view(bsz, max_seq_len, self.in_channels, self.input_dim)
            .transpose(1, 2)
            .contiguous()
        )
        x = self.conv(x)
        bsz, _, output_seq_len, _ = x.size()
        # (B, C, T, D) -> (T, B, C * D)
        x = x.transpose(1, 2).transpose(0, 1).contiguous().view(output_seq_len, bsz, -1)
        x = self.out(x)
        x = self.embed_scale * x

        # Effective subsampling factor, rounded to the nearest integer.
        subsampling_factor = int(max_seq_len * 1.0 / output_seq_len + 0.5)
        input_len_0 = (src_lengths.float() / subsampling_factor).ceil().long()
        # Never report a length longer than the actual conv output.
        input_lengths = torch.clamp(input_len_0, max=output_seq_len)

        encoder_padding_mask = lengths_to_padding_mask(input_lengths)

        positions = self.embed_positions(encoder_padding_mask).transpose(0, 1)
        x += positions
        x = F.dropout(x, p=self.dropout, training=self.training)

        for layer in self.transformer_layers:
            x = layer(x, encoder_padding_mask)

        # NOTE(review): self.layer_norm (built when encoder_normalize_before
        # is set) is not applied here; see the note in __init__.

        if not encoder_padding_mask.any():
            maybe_encoder_padding_mask = None
        else:
            maybe_encoder_padding_mask = encoder_padding_mask

        return {
            "encoder_out": [x],
            "encoder_padding_mask": [maybe_encoder_padding_mask]
            if maybe_encoder_padding_mask is not None
            else [],
            "encoder_embedding": [],
            "encoder_states": [],
            "src_tokens": [],
            "src_lengths": [],
        }

    @torch.jit.export
    def reorder_encoder_out(self, encoder_out: Dict[str, List[Tensor]], new_order):
        """
        Reorder encoder output according to *new_order*.

        The input dict and the lists it holds are left unmodified; every
        returned list is newly created.

        Args:
            encoder_out: output from the ``forward()`` method
            new_order (LongTensor): desired order

        Returns:
            *encoder_out* rearranged according to *new_order*
        """
        # encoder_out and encoder_states are selected along dim 1, while the
        # padding mask and the embedding are selected along dim 0.
        new_encoder_out = [encoder_out["encoder_out"][0].index_select(1, new_order)]
        if len(encoder_out["encoder_padding_mask"]) == 0:
            new_encoder_padding_mask = []
        else:
            new_encoder_padding_mask = [
                (encoder_out["encoder_padding_mask"][0]).index_select(0, new_order)
            ]
        if len(encoder_out["encoder_embedding"]) == 0:
            new_encoder_embedding = []
        else:
            new_encoder_embedding = [
                (encoder_out["encoder_embedding"][0]).index_select(0, new_order)
            ]

        # Build a fresh list instead of overwriting the caller's list in place.
        new_encoder_states: List[Tensor] = []
        for state in encoder_out["encoder_states"]:
            new_encoder_states.append(state.index_select(1, new_order))

        return {
            "encoder_out": new_encoder_out,
            "encoder_padding_mask": new_encoder_padding_mask,
            "encoder_embedding": new_encoder_embedding,
            "encoder_states": new_encoder_states,
            "src_tokens": [],
            "src_lengths": [],
        }


class TransformerDecoderNoExtra(TransformerDecoder):
    """TransformerDecoder whose ``extract_features`` returns ``None`` as extra."""

    def extract_features(
        self,
        prev_output_tokens,
        encoder_out: Optional[Dict[str, List[Tensor]]],
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
        full_context_alignment: bool = False,
        alignment_layer: Optional[int] = None,
        alignment_heads: Optional[int] = None,
    ):
        """Return ``(features, None)``.

        All arguments are forwarded unchanged to
        ``extract_features_scriptable``; the second element of its result is
        discarded and replaced by ``None``.
        """
        # call scriptable method from parent class
        x, _ = self.extract_features_scriptable(
            prev_output_tokens,
            encoder_out,
            incremental_state,
            full_context_alignment,
            alignment_layer,
            alignment_heads,
        )
        return x, None


@register_model_architecture(model_name="convtransformer", arch_name="convtransformer")
def base_architecture(args):
    """Fill in the default hyper-parameters of the ``convtransformer`` arch.

    Only attributes missing from *args* receive a default; attributes that
    are already present are kept as they are.
    """
    args.input_feat_per_channel = getattr(args, "input_feat_per_channel", 80)
    args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
    args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048)
    args.encoder_layers = getattr(args, "encoder_layers", 6)
    args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 8)
    args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False)
    args.decoder_embed_dim = getattr(args, "decoder_embed_dim", args.encoder_embed_dim)
    args.decoder_ffn_embed_dim = getattr(
        args, "decoder_ffn_embed_dim", args.encoder_ffn_embed_dim
    )
    args.decoder_layers = getattr(args, "decoder_layers", 6)
    args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8)
    args.decoder_normalize_before = getattr(args, "decoder_normalize_before", False)
    args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False)
    args.attention_dropout = getattr(args, "attention_dropout", 0.0)
    args.activation_dropout = getattr(args, "activation_dropout", 0.0)
    args.activation_fn = getattr(args, "activation_fn", "relu")
    args.dropout = getattr(args, "dropout", 0.1)
    args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None)
    args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0)
    args.share_decoder_input_output_embed = getattr(
        args, "share_decoder_input_output_embed", False
    )
    args.no_token_positional_embeddings = getattr(
        args, "no_token_positional_embeddings", False
    )
    args.adaptive_input = getattr(args, "adaptive_input", False)
    args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0.0)

    args.decoder_output_dim = getattr(
        args, "decoder_output_dim", args.decoder_embed_dim
    )
    args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim)
    args.no_scale_embedding = getattr(args, "no_scale_embedding", False)
    args.quant_noise_pq = getattr(args, "quant_noise_pq", 0)
    args.max_source_positions = getattr(args, "max_source_positions", 3000)
    args.max_target_positions = getattr(args, "max_target_positions", 1024)
    args.tie_adaptive_weights = getattr(args, "tie_adaptive_weights", False)
    args.conv_out_channels = getattr(args, "conv_out_channels", args.encoder_embed_dim)


@register_model_architecture("convtransformer", "convtransformer_espnet")
def convtransformer_espnet(args):
    """Defaults for the ``convtransformer_espnet`` architecture variant.

    Only the four values below are set here; ``ConvTransformerModel.build_model``
    calls ``base_architecture`` to fill in everything else.
    """
    args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 256)
    args.encoder_layers = getattr(args, "encoder_layers", 12)
    args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 4)
    args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 4)