# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn.functional as F

from fairseq import utils
from fairseq.iterative_refinement_generator import DecoderOut
from fairseq.models import register_model, register_model_architecture
from fairseq.models.nat import FairseqNATDecoder, FairseqNATModel, ensemble_decoder
from fairseq.models.transformer import Embedding
from fairseq.modules.transformer_sentence_encoder import init_bert_params


def _mean_pooling(enc_feats, src_masks):
    # enc_feats: T x B x C
    # src_masks: B x T or None
    if src_masks is None:
        enc_feats = enc_feats.mean(0)
    else:
        src_masks = (~src_masks).transpose(0, 1).type_as(enc_feats)
        enc_feats = (
            (enc_feats / src_masks.sum(0)[None, :, None]) * src_masks[:, :, None]
        ).sum(0)
    return enc_feats


def _argmax(x, dim):
    return (x == x.max(dim, keepdim=True)[0]).type_as(x)


def _uniform_assignment(src_lens, trg_lens):
    max_trg_len = trg_lens.max()
    steps = (src_lens.float() - 1) / (trg_lens.float() - 1)  # step-size
    # max_trg_len
    index_t = utils.new_arange(trg_lens, max_trg_len).float()
    index_t = steps[:, None] * index_t[None, :]  # batch_size X max_trg_len
    index_t = torch.round(index_t).long().detach()
    return index_t


@register_model("nonautoregressive_transformer")
class NATransformerModel(FairseqNATModel):
    @property
    def allow_length_beam(self):
        return True

    @staticmethod
    def add_args(parser):
        FairseqNATModel.add_args(parser)

        # length prediction
        parser.add_argument(
            "--src-embedding-copy",
            action="store_true",
            help="copy encoder word embeddings as the initial input of the decoder",
        )
        parser.add_argument(
            "--pred-length-offset",
            action="store_true",
            help="predicting the length difference between the target and source sentences",
        )
        parser.add_argument(
            "--sg-length-pred",
            action="store_true",
            help="stop the gradients back-propagated from the length predictor",
        )
        parser.add_argument(
            "--length-loss-factor",
            type=float,
            help="weights on the length prediction loss",
        )

    @classmethod
    def build_decoder(cls, args, tgt_dict, embed_tokens):
        decoder = NATransformerDecoder(args, tgt_dict, embed_tokens)
        if getattr(args, "apply_bert_init", False):
            decoder.apply(init_bert_params)
        return decoder

    def forward(
        self, src_tokens, src_lengths, prev_output_tokens, tgt_tokens, **kwargs
    ):
        # encoding
        encoder_out = self.encoder(src_tokens, src_lengths=src_lengths, **kwargs)

        # length prediction
        length_out = self.decoder.forward_length(
            normalize=False, encoder_out=encoder_out
        )
        length_tgt = self.decoder.forward_length_prediction(
            length_out, encoder_out, tgt_tokens
        )

        # decoding
        word_ins_out = self.decoder(
            normalize=False,
            prev_output_tokens=prev_output_tokens,
            encoder_out=encoder_out,
        )

        return {
            "word_ins": {
                "out": word_ins_out,
                "tgt": tgt_tokens,
                "mask": tgt_tokens.ne(self.pad),
                "ls": self.args.label_smoothing,
                "nll_loss": True,
            },
            "length": {
                "out": length_out,
                "tgt": length_tgt,
                "factor": self.decoder.length_loss_factor,
            },
        }

    def forward_decoder(self, decoder_out, encoder_out, decoding_format=None, **kwargs):
        step = decoder_out.step
        output_tokens = decoder_out.output_tokens
        output_scores = decoder_out.output_scores
        history = decoder_out.history

        # execute the decoder
        output_masks = output_tokens.ne(self.pad)
        _scores, _tokens = self.decoder(
            normalize=True,
            prev_output_tokens=output_tokens,
            encoder_out=encoder_out,
            step=step,
        ).max(-1)

        output_tokens.masked_scatter_(output_masks, _tokens[output_masks])
        output_scores.masked_scatter_(output_masks, _scores[output_masks])
        if history is not None:
            history.append(output_tokens.clone())

        return decoder_out._replace(
            output_tokens=output_tokens,
            output_scores=output_scores,
            attn=None,
            history=history,
        )

    def initialize_output_tokens(self, encoder_out, src_tokens):
        # length prediction
        length_tgt = self.decoder.forward_length_prediction(
            self.decoder.forward_length(normalize=True, encoder_out=encoder_out),
            encoder_out=encoder_out,
        )

        max_length = length_tgt.clamp_(min=2).max()
        idx_length = utils.new_arange(src_tokens, max_length)

        initial_output_tokens = src_tokens.new_zeros(
            src_tokens.size(0), max_length
        ).fill_(self.pad)
        initial_output_tokens.masked_fill_(
            idx_length[None, :] < length_tgt[:, None], self.unk
        )
        initial_output_tokens[:, 0] = self.bos
        initial_output_tokens.scatter_(1, length_tgt[:, None] - 1, self.eos)

        initial_output_scores = initial_output_tokens.new_zeros(
            *initial_output_tokens.size()
        ).type_as(encoder_out["encoder_out"][0])

        return DecoderOut(
            output_tokens=initial_output_tokens,
            output_scores=initial_output_scores,
            attn=None,
            step=0,
            max_step=0,
            history=None,
        )

    def regenerate_length_beam(self, decoder_out, beam_size):
        output_tokens = decoder_out.output_tokens
        length_tgt = output_tokens.ne(self.pad).sum(1)
        length_tgt = (
            length_tgt[:, None]
            + utils.new_arange(length_tgt, 1, beam_size)
            - beam_size // 2
        )
        length_tgt = length_tgt.view(-1).clamp_(min=2)
        max_length = length_tgt.max()
        idx_length = utils.new_arange(length_tgt, max_length)

        initial_output_tokens = output_tokens.new_zeros(
            length_tgt.size(0), max_length
        ).fill_(self.pad)
        initial_output_tokens.masked_fill_(
            idx_length[None, :] < length_tgt[:, None], self.unk
        )
        initial_output_tokens[:, 0] = self.bos
        initial_output_tokens.scatter_(1, length_tgt[:, None] - 1, self.eos)

        initial_output_scores = initial_output_tokens.new_zeros(
            *initial_output_tokens.size()
        ).type_as(decoder_out.output_scores)

        return decoder_out._replace(
            output_tokens=initial_output_tokens, output_scores=initial_output_scores
        )


class NATransformerDecoder(FairseqNATDecoder):
    def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False):
        super().__init__(
            args, dictionary, embed_tokens, no_encoder_attn=no_encoder_attn
        )
        self.dictionary = dictionary
        self.bos = dictionary.bos()
        self.unk = dictionary.unk()
        self.eos = dictionary.eos()

        self.encoder_embed_dim = args.encoder_embed_dim
        self.sg_length_pred = getattr(args, "sg_length_pred", False)
        self.pred_length_offset = getattr(args, "pred_length_offset", False)
        self.length_loss_factor = getattr(args, "length_loss_factor", 0.1)
        self.src_embedding_copy = getattr(args, "src_embedding_copy", False)
        self.embed_length = Embedding(256, self.encoder_embed_dim, None)

    @ensemble_decoder
    def forward(self, normalize, encoder_out, prev_output_tokens, step=0, **unused):
        features, _ = self.extract_features(
            prev_output_tokens,
            encoder_out=encoder_out,
            embedding_copy=(step == 0) & self.src_embedding_copy,
        )
        decoder_out = self.output_layer(features)
        return F.log_softmax(decoder_out, -1) if normalize else decoder_out

    @ensemble_decoder
    def forward_length(self, normalize, encoder_out):
        enc_feats = encoder_out["encoder_out"][0]  # T x B x C
        if len(encoder_out["encoder_padding_mask"]) > 0:
            src_masks = encoder_out["encoder_padding_mask"][0]  # B x T
        else:
            src_masks = None
        enc_feats = _mean_pooling(enc_feats, src_masks)
        if self.sg_length_pred:
            enc_feats = enc_feats.detach()
        length_out = F.linear(enc_feats, self.embed_length.weight)
        return F.log_softmax(length_out, -1) if normalize else length_out

    def extract_features(
        self,
        prev_output_tokens,
        encoder_out=None,
        early_exit=None,
        embedding_copy=False,
        **unused
    ):
        """
        Similar to *forward* but only return features.

        Inputs:
            prev_output_tokens: Tensor(B, T)
            encoder_out: a dictionary of hidden states and masks

        Returns:
            tuple:
                - the decoder's features of shape `(batch, tgt_len, embed_dim)`
                - a dictionary with any model-specific outputs

        The non-autoregressive decoder attends to all positions at once
        (no causal self-attention mask).
        """
        # embedding
        if embedding_copy:
            src_embd = encoder_out["encoder_embedding"][0]
            if len(encoder_out["encoder_padding_mask"]) > 0:
                src_mask = encoder_out["encoder_padding_mask"][0]
            else:
                src_mask = None
            src_mask = (
                ~src_mask
                if src_mask is not None
                else prev_output_tokens.new_ones(*src_embd.size()[:2]).bool()
            )

            x, decoder_padding_mask = self.forward_embedding(
                prev_output_tokens,
                self.forward_copying_source(
                    src_embd, src_mask, prev_output_tokens.ne(self.padding_idx)
                ),
            )

        else:
            x, decoder_padding_mask = self.forward_embedding(prev_output_tokens)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)
        attn = None
        inner_states = [x]

        # decoder layers
        for i, layer in enumerate(self.layers):

            # early exit from the decoder.
            if (early_exit is not None) and (i >= early_exit):
                break

            x, attn, _ = layer(
                x,
                encoder_out["encoder_out"][0]
                if (encoder_out is not None and len(encoder_out["encoder_out"]) > 0)
                else None,
                encoder_out["encoder_padding_mask"][0]
                if (
                    encoder_out is not None
                    and len(encoder_out["encoder_padding_mask"]) > 0
                )
                else None,
                self_attn_mask=None,
                self_attn_padding_mask=decoder_padding_mask,
            )
            inner_states.append(x)

        if self.layer_norm:
            x = self.layer_norm(x)

        # T x B x C -> B x T x C
        x = x.transpose(0, 1)

        if self.project_out_dim is not None:
            x = self.project_out_dim(x)

        return x, {"attn": attn, "inner_states": inner_states}

    def forward_embedding(self, prev_output_tokens, states=None):
        # embed positions
        positions = (
            self.embed_positions(prev_output_tokens)
            if self.embed_positions is not None
            else None
        )

        # embed tokens and positions
        if states is None:
            x = self.embed_scale * self.embed_tokens(prev_output_tokens)
            if self.project_in_dim is not None:
                x = self.project_in_dim(x)
        else:
            x = states

        if positions is not None:
            x += positions
        x = self.dropout_module(x)
        decoder_padding_mask = prev_output_tokens.eq(self.padding_idx)
        return x, decoder_padding_mask

    def forward_copying_source(self, src_embeds, src_masks, tgt_masks):
        length_sources = src_masks.sum(1)
        length_targets = tgt_masks.sum(1)
        mapped_inputs = _uniform_assignment(length_sources, length_targets).masked_fill(
            ~tgt_masks, 0
        )
        copied_embedding = torch.gather(
            src_embeds,
            1,
            mapped_inputs.unsqueeze(-1).expand(
                *mapped_inputs.size(), src_embeds.size(-1)
            ),
        )
        return copied_embedding

    def forward_length_prediction(self, length_out, encoder_out, tgt_tokens=None):
        enc_feats = encoder_out["encoder_out"][0]  # T x B x C
        if len(encoder_out["encoder_padding_mask"]) > 0:
            src_masks = encoder_out["encoder_padding_mask"][0]  # B x T
        else:
            src_masks = None
        if self.pred_length_offset:
            if src_masks is None:
                src_lengs = enc_feats.new_ones(enc_feats.size(1)).fill_(
                    enc_feats.size(0)
                )
            else:
                src_lengs = (~src_masks).transpose(0, 1).type_as(enc_feats).sum(0)
            src_lengs = src_lengs.long()

        if tgt_tokens is not None:
            # obtain the length target
            tgt_lengs = tgt_tokens.ne(self.padding_idx).sum(1).long()
            if self.pred_length_offset:
                length_tgt = tgt_lengs - src_lengs + 128
            else:
                length_tgt = tgt_lengs
            length_tgt = length_tgt.clamp(min=0, max=255)

        else:
            # predict the length target (greedy for now)
            # TODO: implementing length-beam
            pred_lengs = length_out.max(-1)[1]
            if self.pred_length_offset:
                length_tgt = pred_lengs - 128 + src_lengs
            else:
                length_tgt = pred_lengs

        return length_tgt


@register_model_architecture(
    "nonautoregressive_transformer", "nonautoregressive_transformer"
)
def base_architecture(args):
    args.encoder_embed_path = getattr(args, "encoder_embed_path", None)
    args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
    args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048)
    args.encoder_layers = getattr(args, "encoder_layers", 6)
    args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 8)
    args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False)
    args.encoder_learned_pos = getattr(args, "encoder_learned_pos", False)
    args.decoder_embed_path = getattr(args, "decoder_embed_path", None)
    args.decoder_embed_dim = getattr(args, "decoder_embed_dim", args.encoder_embed_dim)
    args.decoder_ffn_embed_dim = getattr(
        args, "decoder_ffn_embed_dim", args.encoder_ffn_embed_dim
    )
    args.decoder_layers = getattr(args, "decoder_layers", 6)
    args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8)
    args.decoder_normalize_before = getattr(args, "decoder_normalize_before", False)
    args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False)
    args.attention_dropout = getattr(args, "attention_dropout", 0.0)
    args.activation_dropout = getattr(args, "activation_dropout", 0.0)
    args.activation_fn = getattr(args, "activation_fn", "relu")
    args.dropout = getattr(args, "dropout", 0.1)
    args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None)
    args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0)
    args.share_decoder_input_output_embed = getattr(
        args, "share_decoder_input_output_embed", False
    )
    args.share_all_embeddings = getattr(args, "share_all_embeddings", False)
    args.no_token_positional_embeddings = getattr(
        args, "no_token_positional_embeddings", False
    )
    args.adaptive_input = getattr(args, "adaptive_input", False)
    args.apply_bert_init = getattr(args, "apply_bert_init", False)

    args.decoder_output_dim = getattr(
        args, "decoder_output_dim", args.decoder_embed_dim
    )
    args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim)

    # --- special arguments ---
    args.sg_length_pred = getattr(args, "sg_length_pred", False)
    args.pred_length_offset = getattr(args, "pred_length_offset", False)
    args.length_loss_factor = getattr(args, "length_loss_factor", 0.1)
    args.src_embedding_copy = getattr(args, "src_embedding_copy", False)


@register_model_architecture(
    "nonautoregressive_transformer", "nonautoregressive_transformer_wmt_en_de"
)
def nonautoregressive_transformer_wmt_en_de(args):
    base_architecture(args)
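

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the upstream fairseq module): a minimal
# check of the uniform copy mapping that backs --src-embedding-copy. It only
# exercises the module-level helper `_uniform_assignment` on made-up toy
# lengths; in real decoding the mapping is masked with the target padding
# mask inside `forward_copying_source`.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # two toy sentence pairs: source lengths 5 and 3, target lengths 7 and 4
    src_lens = torch.tensor([5, 3])
    trg_lens = torch.tensor([7, 4])

    # each target position is assigned a source position index, spaced
    # uniformly over the source length
    mapping = _uniform_assignment(src_lens, trg_lens)
    print(mapping)

    # shape is (batch, max_trg_len); positions beyond each sentence's target
    # length are only meaningful after masking by the caller
    assert mapping.shape == (2, 7)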