# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
BART: Denoising Sequence-to-Sequence Pre-training for
Natural Language Generation, Translation, and Comprehension
"""
import logging
from typing import Optional

import torch
import torch.nn as nn

from fairseq import utils
from fairseq.models import register_model, register_model_architecture
from fairseq.models.transformer import TransformerModel
from fairseq.modules.transformer_sentence_encoder import init_bert_params

from .hub_interface import BARTHubInterface

logger = logging.getLogger(__name__)


@register_model("bart")
class BARTModel(TransformerModel):
    __jit_unused_properties__ = ["supported_targets"]

    @classmethod
    def hub_models(cls):
        return {
            "bart.base": "http://dl.fbaipublicfiles.com/fairseq/models/bart.base.tar.gz",
            "bart.large": "http://dl.fbaipublicfiles.com/fairseq/models/bart.large.tar.gz",
            "bart.large.mnli": "http://dl.fbaipublicfiles.com/fairseq/models/bart.large.mnli.tar.gz",
            "bart.large.cnn": "http://dl.fbaipublicfiles.com/fairseq/models/bart.large.cnn.tar.gz",
            "bart.large.xsum": "http://dl.fbaipublicfiles.com/fairseq/models/bart.large.xsum.tar.gz",
        }

    def __init__(self, args, encoder, decoder):
        super().__init__(args, encoder, decoder)

        # We follow BERT's random weight initialization
        self.apply(init_bert_params)

        self.classification_heads = nn.ModuleDict()
        if hasattr(self.encoder, "dictionary"):
            self.eos: int = self.encoder.dictionary.eos()

    @staticmethod
    def add_args(parser):
        super(BARTModel, BARTModel).add_args(parser)
        parser.add_argument(
            "--pooler-dropout",
            type=float,
            metavar="D",
            help="dropout probability in the masked_lm pooler layers",
        )
        parser.add_argument(
            "--pooler-activation-fn",
            choices=utils.get_available_activation_fns(),
            help="activation function to use for pooler layer",
        )
        parser.add_argument(
            "--spectral-norm-classification-head",
            action="store_true",
            help="Apply spectral normalization on the classification head",
        )
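
    # A hedged CLI sketch of how these flags are consumed via fairseq-train
    # (the data path and any other required task flags are placeholders):
    #
    #     fairseq-train data-bin/my_task --arch bart_large \
    #         --pooler-activation-fn tanh --pooler-dropout 0.1 \
    #         --spectral-norm-classification-head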

    @property
    def supported_targets(self):
        return {"self"}

    def forward(
        self,
        src_tokens,
        src_lengths,
        prev_output_tokens,
        features_only: bool = False,
        classification_head_name: Optional[str] = None,
        token_embeddings: Optional[torch.Tensor] = None,
        return_all_hiddens: bool = True,
        alignment_layer: Optional[int] = None,
        alignment_heads: Optional[int] = None,
    ):
        if classification_head_name is not None:
            features_only = True

        encoder_out = self.encoder(
            src_tokens,
            src_lengths=src_lengths,
            token_embeddings=token_embeddings,
            return_all_hiddens=return_all_hiddens,
        )
        x, extra = self.decoder(
            prev_output_tokens,
            encoder_out=encoder_out,
            features_only=features_only,
            alignment_layer=alignment_layer,
            alignment_heads=alignment_heads,
            src_lengths=src_lengths,
            return_all_hiddens=return_all_hiddens,
        )
        eos: int = self.eos
        if classification_head_name is not None:
            sentence_representation = x[src_tokens.eq(eos), :].view(
                x.size(0), -1, x.size(-1)
            )[:, -1, :]
            for k, head in self.classification_heads.items():
                # TorchScript only supports iterating over a ModuleDict,
                # not indexing it by key.
                if k == classification_head_name:
                    x = head(sentence_representation)
                    break
        return x, extra
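
    # A sketch of the pooling above (shapes hedged): for classification the
    # source and target carry the same sequence, so the source-side <eos>
    # mask can index the decoder features x of shape (batch, seq_len,
    # embed_dim), and view(...)[:, -1, :] keeps the feature at the *last*
    # <eos> of each sequence:
    #
    #     x = torch.randn(2, 7, 1024)  # decoder features
    #     sent = x[src_tokens.eq(eos), :].view(2, -1, 1024)[:, -1, :]
    #     sent.shape                   # -> (2, 1024)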

    @classmethod
    def from_pretrained(
        cls,
        model_name_or_path,
        checkpoint_file="model.pt",
        data_name_or_path=".",
        bpe="gpt2",
        sample_break_mode="eos",
        **kwargs,
    ):
        from fairseq import hub_utils

        x = hub_utils.from_pretrained(
            model_name_or_path,
            checkpoint_file,
            data_name_or_path,
            archive_map=cls.hub_models(),
            bpe=bpe,
            load_checkpoint_heads=True,
            sample_break_mode=sample_break_mode,
            **kwargs,
        )
        return BARTHubInterface(x["args"], x["task"], x["models"][0])
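
    # A hedged usage sketch: keys from hub_models() resolve through the
    # archive_map, and sample()/fill_mask() are assumed from fairseq's
    # BARTHubInterface:
    #
    #     bart = BARTModel.from_pretrained("bart.large")
    #     bart.eval()
    #     hypos = bart.sample(["BART is a denoising autoencoder."], beam=4)
    #     fills = bart.fill_mask(["The cat <mask> on the mat."], topk=3)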

    def register_classification_head(
        self, name, num_classes=None, inner_dim=None, **kwargs
    ):
        """Register a classification head."""
        logger.info("Registering classification head: {0}".format(name))
        if name in self.classification_heads:
            prev_num_classes = self.classification_heads[name].out_proj.out_features
            prev_inner_dim = self.classification_heads[name].dense.out_features
            if num_classes != prev_num_classes or inner_dim != prev_inner_dim:
                logger.warning(
                    're-registering head "{}" with num_classes {} (prev: {}) '
                    "and inner_dim {} (prev: {})".format(
                        name, num_classes, prev_num_classes, inner_dim, prev_inner_dim
                    )
                )
        self.classification_heads[name] = BARTClassificationHead(
            input_dim=self.args.encoder_embed_dim,
            inner_dim=inner_dim or self.args.encoder_embed_dim,
            num_classes=num_classes,
            activation_fn=self.args.pooler_activation_fn,
            pooler_dropout=self.args.pooler_dropout,
            do_spectral_norm=getattr(
                self.args, "spectral_norm_classification_head", False
            ),
        )
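
    # A typical fine-tuning flow, as a sketch ("my_head" is a made-up name):
    #
    #     model.register_classification_head("my_head", num_classes=3)
    #     logits, _ = model(src_tokens, src_lengths, prev_output_tokens,
    #                       classification_head_name="my_head")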

    def upgrade_state_dict_named(self, state_dict, name):
        super().upgrade_state_dict_named(state_dict, name)

        prefix = name + "." if name != "" else ""
        current_head_names = (
            []
            if not hasattr(self, "classification_heads")
            else self.classification_heads.keys()
        )

        # Handle new classification heads present in the state dict.
        keys_to_delete = []
        for k in state_dict.keys():
            if not k.startswith(prefix + "classification_heads."):
                continue

            head_name = k[len(prefix + "classification_heads.") :].split(".")[0]
            num_classes = state_dict[
                prefix + "classification_heads." + head_name + ".out_proj.weight"
            ].size(0)
            inner_dim = state_dict[
                prefix + "classification_heads." + head_name + ".dense.weight"
            ].size(0)

            if getattr(self.args, "load_checkpoint_heads", False):
                if head_name not in current_head_names:
                    self.register_classification_head(head_name, num_classes, inner_dim)
            else:
                if head_name not in current_head_names:
                    logger.warning(
                        "deleting classification head ({}) from checkpoint "
                        "not present in current model: {}".format(head_name, k)
                    )
                    keys_to_delete.append(k)
                elif (
                    num_classes
                    != self.classification_heads[head_name].out_proj.out_features
                    or inner_dim
                    != self.classification_heads[head_name].dense.out_features
                ):
                    logger.warning(
                        "deleting classification head ({}) from checkpoint "
                        "with different dimensions than current model: {}".format(
                            head_name, k
                        )
                    )
                    keys_to_delete.append(k)
        for k in keys_to_delete:
            del state_dict[k]

        def truncate_emb(key):
            if key in state_dict:
                state_dict[key] = state_dict[key][:-1, :]

        # When finetuning on a translation task, remove the last row of the
        # embedding matrix, which corresponds to the mask_idx token.
        loaded_dict_size = state_dict["encoder.embed_tokens.weight"].size(0)
        if (
            loaded_dict_size == len(self.encoder.dictionary) + 1
            and "<mask>" not in self.encoder.dictionary
        ):
            truncate_emb("encoder.embed_tokens.weight")
            truncate_emb("decoder.embed_tokens.weight")
            truncate_emb("encoder.output_projection.weight")
            truncate_emb("decoder.output_projection.weight")

        # When continuing pretraining on a new set of languages for mBART,
        # add extra language embeddings at the end of embed_tokens.
        # Note: newly added languages are assumed to have been added at the end.
        if self.args.task == "multilingual_denoising" and loaded_dict_size < len(
            self.encoder.dictionary
        ):
            logger.info(
                "Adding extra language embeddings not found in pretrained model for "
                "continued pretraining of MBART on new set of languages."
            )
            loaded_mask_token_embedding = state_dict["encoder.embed_tokens.weight"][
                -1, :
            ]

            num_langids_to_add = len(self.encoder.dictionary) - loaded_dict_size
            embed_dim = state_dict["encoder.embed_tokens.weight"].size(1)

            new_lang_embed_to_add = torch.zeros(num_langids_to_add, embed_dim)
            nn.init.normal_(new_lang_embed_to_add, mean=0, std=embed_dim**-0.5)
            new_lang_embed_to_add = new_lang_embed_to_add.to(
                dtype=state_dict["encoder.embed_tokens.weight"].dtype,
            )

            state_dict["encoder.embed_tokens.weight"] = torch.cat(
                [
                    state_dict["encoder.embed_tokens.weight"][
                        : loaded_dict_size - 1, :
                    ],
                    new_lang_embed_to_add,
                    loaded_mask_token_embedding.unsqueeze(0),
                ]
            )
            state_dict["decoder.embed_tokens.weight"] = torch.cat(
                [
                    state_dict["decoder.embed_tokens.weight"][
                        : loaded_dict_size - 1, :
                    ],
                    new_lang_embed_to_add,
                    loaded_mask_token_embedding.unsqueeze(0),
                ]
            )
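
            # Worked example of the concatenation above (numbers are made
            # up): if the checkpoint vocabulary had 250001 rows with <mask>
            # in the last row and the new dictionary has 250004 entries,
            # then 3 freshly initialized language embeddings are spliced in
            # between the 250000 original rows and the re-appended <mask>
            # row, keeping <mask> as the final entry.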

        # Copy any newly-added classification heads into the state dict
        # with their current weights.
        if hasattr(self, "classification_heads"):
            cur_state = self.classification_heads.state_dict()
            for k, v in cur_state.items():
                if prefix + "classification_heads." + k not in state_dict:
                    logger.info("Overwriting " + prefix + "classification_heads." + k)
                    state_dict[prefix + "classification_heads." + k] = v

    def set_beam_size(self, beam):
        """Set beam size for efficient beamable enc-dec attention."""
        beamable = False
        for layer in self.decoder.layers:
            if layer.encoder_attn is not None:
                if hasattr(layer.encoder_attn, "set_beam_size"):
                    layer.encoder_attn.set_beam_size(beam)
                    beamable = True
        if beamable:
            self.encoder.reorder_encoder_out = self.encoder._reorder_encoder_out


class BARTClassificationHead(nn.Module):
    """Head for sentence-level classification tasks."""

    def __init__(
        self,
        input_dim,
        inner_dim,
        num_classes,
        activation_fn,
        pooler_dropout,
        do_spectral_norm=False,
    ):
        super().__init__()
        self.dense = nn.Linear(input_dim, inner_dim)
        self.activation_fn = utils.get_activation_fn(activation_fn)
        self.dropout = nn.Dropout(p=pooler_dropout)
        self.out_proj = nn.Linear(inner_dim, num_classes)
        if do_spectral_norm:
            self.out_proj = torch.nn.utils.spectral_norm(self.out_proj)

    def forward(self, features, **kwargs):
        x = features
        x = self.dropout(x)
        x = self.dense(x)
        x = self.activation_fn(x)
        x = self.dropout(x)
        x = self.out_proj(x)
        return x
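

# A minimal standalone sketch of the head (dimensions are made up; inside
# BARTModel the input is the decoder feature at the final <eos> token):
#
#     head = BARTClassificationHead(
#         input_dim=1024, inner_dim=1024, num_classes=3,
#         activation_fn="tanh", pooler_dropout=0.0,
#     )
#     logits = head(torch.randn(8, 1024))  # -> shape (8, 3)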


@register_model_architecture("bart", "bart_large")
def bart_large_architecture(args):
    args.encoder_embed_path = getattr(args, "encoder_embed_path", None)
    args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1024)
    args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 4 * 1024)
    args.encoder_layers = getattr(args, "encoder_layers", 12)
    args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 16)
    args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False)
    args.encoder_learned_pos = getattr(args, "encoder_learned_pos", True)
    args.decoder_embed_path = getattr(args, "decoder_embed_path", None)
    args.decoder_embed_dim = getattr(args, "decoder_embed_dim", args.encoder_embed_dim)
    args.decoder_ffn_embed_dim = getattr(
        args, "decoder_ffn_embed_dim", args.encoder_ffn_embed_dim
    )
    args.decoder_layers = getattr(args, "decoder_layers", 12)
    args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 16)
    args.decoder_normalize_before = getattr(args, "decoder_normalize_before", False)
    args.decoder_learned_pos = getattr(args, "decoder_learned_pos", True)
    args.attention_dropout = getattr(args, "attention_dropout", 0.0)
    args.relu_dropout = getattr(args, "relu_dropout", 0.0)
    args.dropout = getattr(args, "dropout", 0.1)
    args.max_target_positions = getattr(args, "max_target_positions", 1024)
    args.max_source_positions = getattr(args, "max_source_positions", 1024)
    args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None)
    args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0)
    args.share_decoder_input_output_embed = getattr(
        args, "share_decoder_input_output_embed", True
    )
    args.share_all_embeddings = getattr(args, "share_all_embeddings", True)
    args.decoder_output_dim = getattr(
        args, "decoder_output_dim", args.decoder_embed_dim
    )
    args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim)
    args.no_scale_embedding = getattr(args, "no_scale_embedding", True)
    args.layernorm_embedding = getattr(args, "layernorm_embedding", True)
    args.activation_fn = getattr(args, "activation_fn", "gelu")
    args.pooler_activation_fn = getattr(args, "pooler_activation_fn", "tanh")
    args.pooler_dropout = getattr(args, "pooler_dropout", 0.0)


@register_model_architecture("bart", "bart_base")
def bart_base_architecture(args):
    args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 768)
    args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 4 * 768)
    args.encoder_layers = getattr(args, "encoder_layers", 6)
    args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 12)
    args.decoder_layers = getattr(args, "decoder_layers", 6)
    args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 12)
    bart_large_architecture(args)


@register_model_architecture("bart", "mbart_large")
def mbart_large_architecture(args):
    args.no_scale_embedding = getattr(args, "no_scale_embedding", False)
    bart_large_architecture(args)


@register_model_architecture("bart", "mbart_base")
def mbart_base_architecture(args):
    args.no_scale_embedding = getattr(args, "no_scale_embedding", False)
    bart_base_architecture(args)


@register_model_architecture("bart", "mbart_base_wmt20")
def mbart_base_wmt20_architecture(args):
    args.layernorm_embedding = getattr(args, "layernorm_embedding", False)
    mbart_base_architecture(args)
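

# To inspect which defaults a preset fills in, one can call it on a bare
# namespace, as a sketch (the decorators above also expose these presets to
# the CLI via --arch; any attribute not already set on args gets its default):
#
#     import argparse
#     args = argparse.Namespace()
#     bart_base_architecture(args)
#     print(args.encoder_embed_dim, args.encoder_layers)  # 768 6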