# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import Dict, List, Optional, Tuple

import torch
import torch.nn as nn
from torch import Tensor

from fairseq import options, utils
from fairseq.dataclass.utils import gen_parser_from_dataclass
from fairseq.distributed import fsdp_wrap
from fairseq.models import FairseqEncoderDecoderModel
from fairseq.models.transformer import (
    TransformerConfig,
    TransformerDecoderBase,
    TransformerEncoderBase,
)
from fairseq.modules import AdaptiveInput


class TransformerModelBase(FairseqEncoderDecoderModel):
    """
    Transformer model from `"Attention Is All You Need" (Vaswani et al., 2017)
    <https://arxiv.org/abs/1706.03762>`_.

    Args:
        encoder (TransformerEncoder): the encoder
        decoder (TransformerDecoder): the decoder

    The Transformer model provides the following named architectures and
    command-line arguments:

    .. argparse::
        :ref: fairseq.models.transformer_parser
        :prog:
    """

    def __init__(self, cfg, encoder, decoder):
        super().__init__(encoder, decoder)
        self.cfg = cfg
        self.supports_align_args = True
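
    # Illustrative sketch (an assumption, not part of the original file):
    # instances are normally created through the classmethods below rather
    # than by calling __init__ directly, roughly:
    #
    #   cfg = TransformerConfig()
    #   model = TransformerModelBase.build_model(cfg, task)
    #
    # where `task` is any FairseqTask exposing source_dictionary and
    # target_dictionary.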

    @classmethod
    def add_args(cls, parser):
        """Add model-specific arguments to the parser."""
        # we want to build the args recursively in this case.
        gen_parser_from_dataclass(
            parser, TransformerConfig(), delete_default=False, with_prefix=""
        )
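
    # A hedged example of what add_args enables: gen_parser_from_dataclass
    # expands each TransformerConfig field into a CLI flag, so a standalone
    # script could do something like:
    #
    #   import argparse
    #   parser = argparse.ArgumentParser()
    #   TransformerModelBase.add_args(parser)
    #   args = parser.parse_args(["--encoder-embed-dim", "512"])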

    @classmethod
    def build_model(cls, cfg, task):
        """Build a new model instance."""

        # --  TODO T96535332
        #  bug caused by interaction between OmegaConf II and argparsing
        cfg.decoder.input_dim = int(cfg.decoder.input_dim)
        cfg.decoder.output_dim = int(cfg.decoder.output_dim)
        # --

        if cfg.encoder.layers_to_keep:
            cfg.encoder.layers = len(cfg.encoder.layers_to_keep.split(","))
        if cfg.decoder.layers_to_keep:
            cfg.decoder.layers = len(cfg.decoder.layers_to_keep.split(","))

        src_dict, tgt_dict = task.source_dictionary, task.target_dictionary

        if not cfg.adaptive_input:
            if cfg.share_all_embeddings:
                if src_dict != tgt_dict:
                    raise ValueError(
                        "--share-all-embeddings requires a joined dictionary"
                    )
                if cfg.encoder.embed_dim != cfg.decoder.embed_dim:
                    raise ValueError(
                        "--share-all-embeddings requires --encoder-embed-dim "
                        "to match --decoder-embed-dim"
                    )
                if cfg.decoder.embed_path and (
                    cfg.decoder.embed_path != cfg.encoder.embed_path
                ):
                    raise ValueError(
                        "--share-all-embeddings not compatible with "
                        "--decoder-embed-path"
                    )
                # a single embedding matrix is shared by the encoder input,
                # the decoder input and the decoder output projection
                encoder_embed_tokens = cls.build_embedding(
                    cfg, src_dict, cfg.encoder.embed_dim, cfg.encoder.embed_path
                )
                decoder_embed_tokens = encoder_embed_tokens
                cfg.share_decoder_input_output_embed = True
            else:
                encoder_embed_tokens = cls.build_embedding(
                    cfg, src_dict, cfg.encoder.embed_dim, cfg.encoder.embed_path
                )
                decoder_embed_tokens = cls.build_embedding(
                    cfg, tgt_dict, cfg.decoder.embed_dim, cfg.decoder.embed_path
                )
        else:
            # with adaptive inputs the same AdaptiveInput module embeds both
            # source and target tokens, so the embeddings are always tied
            encoder_embed_tokens = AdaptiveInput(
                len(task.source_dictionary),
                task.source_dictionary.pad(),
                cfg.encoder.embed_dim,
                cfg.adaptive_input_factor,
                cfg.encoder.embed_dim,
                options.eval_str_list(cfg.adaptive_input_cutoff, type=int),
                cfg.quant_noise_pq,
                cfg.quant_noise_pq_block_size,
            )
            decoder_embed_tokens = encoder_embed_tokens

        if cfg.tie_adaptive_weights:
            # tying the adaptive softmax to the adaptive input requires that
            # both use the same factor, the same cutoffs and the same dims
            assert cfg.adaptive_input
            assert cfg.adaptive_input_factor == cfg.adaptive_softmax_factor
            assert (
                cfg.adaptive_softmax_cutoff == cfg.adaptive_input_cutoff
            ), "{} != {}".format(cfg.adaptive_softmax_cutoff, cfg.adaptive_input_cutoff)
            assert cfg.decoder_input_dim == cfg.decoder_output_dim

        if cfg.offload_activations:
            cfg.checkpoint_activations = True  # offloading implies checkpointing

        encoder = cls.build_encoder(cfg, src_dict, encoder_embed_tokens)
        decoder = cls.build_decoder(
            cfg, tgt_dict, decoder_embed_tokens, encoder_layers=encoder.layers
        )
        if not cfg.share_all_embeddings:
            # fsdp_wrap is a no-op when --ddp-backend != fully_sharded
            encoder = fsdp_wrap(encoder, min_num_params=cfg.min_params_to_wrap)
            decoder = fsdp_wrap(decoder, min_num_params=cfg.min_params_to_wrap)
        return cls(cfg, encoder, decoder)
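
    # A minimal end-to-end sketch, assuming a hypothetical `task` with a
    # joined dictionary (names here are illustrative, not part of the file):
    #
    #   cfg = TransformerConfig()
    #   cfg.share_all_embeddings = True
    #   model = TransformerModelBase.build_model(cfg, task)
    #   assert model.encoder.embed_tokens is model.decoder.embed_tokens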

    @classmethod
    def build_embedding(cls, cfg, dictionary, embed_dim, path=None):
        num_embeddings = len(dictionary)
        padding_idx = dictionary.pad()

        emb = Embedding(num_embeddings, embed_dim, padding_idx)
        # if provided, load from preloaded dictionaries
        if path:
            embed_dict = utils.parse_embedding(path)
            utils.load_embedding(embed_dict, dictionary, emb)
        return emb
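
    # Note (an assumption based on utils.parse_embedding): `path` is expected
    # to point at a GloVe-style text file, a header line followed by
    # "<token> <v1> <v2> ..." rows, e.g.:
    #
    #   2 4
    #   hello 0.1 0.2 0.3 0.4
    #   world 0.5 0.6 0.7 0.8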

    @classmethod
    def build_encoder(cls, cfg, src_dict, embed_tokens):
        return TransformerEncoderBase(cfg, src_dict, embed_tokens)

    @classmethod
    def build_decoder(cls, cfg, tgt_dict, embed_tokens, encoder_layers=None):
        return TransformerDecoderBase(
            cfg,
            tgt_dict,
            embed_tokens,
            no_encoder_attn=cfg.no_cross_attention,
            encoder_layers=encoder_layers,
        )

    # TorchScript doesn't support optional arguments with variable length
    # (**kwargs). The current workaround is to add the union of all arguments
    # in child classes.
    def forward(
        self,
        src_tokens,
        src_lengths,
        prev_output_tokens,
        return_all_hiddens: bool = True,
        features_only: bool = False,
        alignment_layer: Optional[int] = None,
        alignment_heads: Optional[int] = None,
    ):
        """
        Run the forward pass for an encoder-decoder model.

        Copied from the base class, but without ``**kwargs``,
        which are not supported by TorchScript.
        """
        encoder_out = self.encoder(
            src_tokens, src_lengths=src_lengths, return_all_hiddens=return_all_hiddens
        )
        decoder_out = self.decoder(
            prev_output_tokens,
            encoder_out=encoder_out,
            features_only=features_only,
            alignment_layer=alignment_layer,
            alignment_heads=alignment_heads,
            src_lengths=src_lengths,
            return_all_hiddens=return_all_hiddens,
        )
        return decoder_out
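
    # Illustrative note: the fixed, **kwargs-free signature above is what
    # keeps the model scriptable; under that assumption the model can be
    # exported as, e.g.:
    #
    #   scripted = torch.jit.script(model)
    #   out = scripted(src_tokens, src_lengths, prev_output_tokens)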

    # Since get_normalized_probs is defined on the base Fairseq model, which
    # is not scriptable, we override it here to call the scriptable helper
    # from the base class.
    @torch.jit.export
    def get_normalized_probs(
        self,
        net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]],
        log_probs: bool,
        sample: Optional[Dict[str, Tensor]] = None,
    ):
        """Get normalized probabilities (or log probs) from a net's output."""
        return self.get_normalized_probs_scriptable(net_output, log_probs, sample)


def Embedding(num_embeddings, embedding_dim, padding_idx):
    m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
    nn.init.normal_(m.weight, mean=0, std=embedding_dim**-0.5)
    nn.init.constant_(m.weight[padding_idx], 0)
    return m
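

# Illustrative usage of the helper above (an example, not part of the original
# file): weights are drawn from N(0, embedding_dim ** -0.5) and the padding
# row is zeroed out:
#
#   emb = Embedding(num_embeddings=1000, embedding_dim=512, padding_idx=1)
#   assert emb.weight[1].abs().sum().item() == 0.0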