# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional

import torch

from fairseq.distributed import fsdp_wrap
from fairseq.models import (
    register_model,
    register_model_architecture,
)
from fairseq.models.transformer import TransformerEncoder
from fairseq.models.xmod.hub_interface import XMODHubInterface
from fairseq.models.xmod.transformer_layer_xmod import XMODTransformerEncoderLayerBase
from fairseq.modules.checkpoint_activations import checkpoint_wrapper
from fairseq.modules.transformer_sentence_encoder import init_bert_params

from ..roberta.model import RobertaEncoder, base_architecture
from ..roberta.model_xlmr import XLMRModel

# Default parameter-count threshold below which FSDP does not wrap a module.
DEFAULT_MIN_PARAMS_TO_WRAP = int(1e8)


@register_model("xmod")
class XMODModel(XLMRModel):
    @classmethod
    def hub_models(cls):
        return {
            "xmod.base": "https://dl.fbaipublicfiles.com/fairseq/models/xmod/xmod.base.81.1M.tar.gz",
            "xmod.large.prenorm": "https://dl.fbaipublicfiles.com/fairseq/models/xmod/xmod.large.prenorm.81.500k.tar.gz",
            "xmod.base.13.125k": "https://dl.fbaipublicfiles.com/fairseq/models/xmod/xmod.base.13.125k.tar.gz",
            "xmod.base.30.125k": "https://dl.fbaipublicfiles.com/fairseq/models/xmod/xmod.base.30.125k.tar.gz",
            "xmod.base.30.195k": "https://dl.fbaipublicfiles.com/fairseq/models/xmod/xmod.base.30.195k.tar.gz",
            "xmod.base.60.125k": "https://dl.fbaipublicfiles.com/fairseq/models/xmod/xmod.base.60.125k.tar.gz",
            "xmod.base.60.265k": "https://dl.fbaipublicfiles.com/fairseq/models/xmod/xmod.base.60.265k.tar.gz",
            "xmod.base.75.125k": "https://dl.fbaipublicfiles.com/fairseq/models/xmod/xmod.base.75.125k.tar.gz",
            "xmod.base.75.269k": "https://dl.fbaipublicfiles.com/fairseq/models/xmod/xmod.base.75.269k.tar.gz",
        }
    @classmethod
    def from_pretrained(
        cls,
        model_name_or_path,
        checkpoint_file="model.pt",
        data_name_or_path=".",
        bpe="sentencepiece",
        **kwargs,
    ):
        from fairseq import hub_utils

        x = hub_utils.from_pretrained(
            model_name_or_path,
            checkpoint_file,
            data_name_or_path,
            archive_map=cls.hub_models(),
            bpe=bpe,
            load_checkpoint_heads=True,
            **kwargs,
        )
        return XMODHubInterface(x["args"], x["task"], x["models"][0])
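
    # A minimal usage sketch (illustrative, not part of the original file):
    # load one of the checkpoints listed in hub_models() through the hub
    # interface. `encode` comes from the RoBERTa hub interface; passing
    # `lang_id` to extract_features is an assumption about XMODHubInterface.
    #
    #   xmod = XMODModel.from_pretrained("xmod.base")
    #   xmod.eval()  # disable dropout for deterministic features
    #   tokens = xmod.encode("Hello world!")
    #   features = xmod.extract_features(tokens, lang_id="en_XX")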
    @classmethod
    def build_model(cls, args, task):
        """Build a new model instance."""
        from omegaconf import OmegaConf

        if OmegaConf.is_config(args):
            OmegaConf.set_struct(args, False)

        # make sure all arguments are present
        base_architecture(args)

        if not hasattr(args, "max_positions"):
            if not hasattr(args, "tokens_per_sample"):
                args.tokens_per_sample = task.max_positions()
            args.max_positions = args.tokens_per_sample

        encoder = XMODEncoder(args, task.source_dictionary)

        if OmegaConf.is_config(args):
            OmegaConf.set_struct(args, True)

        return cls(args, encoder)
    def forward(
        self,
        src_tokens,
        features_only=False,
        return_all_hiddens=False,
        classification_head_name=None,
        lang_id=None,
        **kwargs,
    ):
        if classification_head_name is not None:
            features_only = True

        x, extra = self.encoder(
            src_tokens, features_only, return_all_hiddens, lang_id=lang_id, **kwargs
        )

        if classification_head_name is not None:
            x = self.classification_heads[classification_head_name](x)
        return x, extra
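
    # Illustrative call (shapes and head name are assumptions, not defined in
    # this file): `lang_id` routes the batch through the corresponding
    # language adapter inside each encoder layer.
    #
    #   logits, extra = model(
    #       src_tokens,  # LongTensor of shape (batch, src_len)
    #       classification_head_name="sentence_classification_head",
    #       lang_id=lang_id,
    #   )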


class XMODEncoder(RobertaEncoder):
    """XMOD encoder."""

    def build_encoder(self, args, dictionary, embed_tokens):
        encoder = XMODTransformerEncoder(args, dictionary, embed_tokens)
        encoder.apply(init_bert_params)
        return encoder
    def forward(
        self,
        src_tokens,
        features_only=False,
        return_all_hiddens=False,
        masked_tokens=None,
        lang_id=None,
        **unused,
    ):
        """
        Args:
            src_tokens (LongTensor): input tokens of shape `(batch, src_len)`
            features_only (bool, optional): skip LM head and just return
                features. If True, the output will be of shape
                `(batch, src_len, embed_dim)`.
            return_all_hiddens (bool, optional): also return all of the
                intermediate hidden states (default: False).
            lang_id: language identifier used to select the language-specific
                adapter modules in each encoder layer.

        Returns:
            tuple:
                - the LM output of shape `(batch, src_len, vocab)`
                - a dictionary of additional data, where 'inner_states'
                  is a list of hidden states. Note that the hidden
                  states have shape `(src_len, batch, embed_dim)`.
        """
        x, extra = self.extract_features(
            src_tokens, return_all_hiddens=return_all_hiddens, lang_id=lang_id
        )
        if not features_only:
            x = self.output_layer(x, masked_tokens=masked_tokens)
        return x, extra
    def extract_features(
        self, src_tokens, return_all_hiddens=False, lang_id=None, **kwargs
    ):
        encoder_out = self.sentence_encoder(
            src_tokens,
            return_all_hiddens=return_all_hiddens,
            lang_id=lang_id,
            token_embeddings=kwargs.get("token_embeddings", None),
        )
        # T x B x C -> B x T x C
        features = encoder_out["encoder_out"][0].transpose(0, 1)
        inner_states = encoder_out["encoder_states"] if return_all_hiddens else None
        return features, {"inner_states": inner_states}


class XMODTransformerEncoder(TransformerEncoder):
    def build_encoder_layer(self, cfg):
        layer = XMODTransformerEncoderLayerBase(cfg)
        checkpoint = cfg.checkpoint_activations
        if checkpoint:
            offload_to_cpu = cfg.offload_activations
            layer = checkpoint_wrapper(layer, offload_to_cpu=offload_to_cpu)
        # if we are checkpointing, enforce that FSDP always wraps the
        # checkpointed layer, regardless of layer size
        min_params_to_wrap = cfg.min_params_to_wrap if not checkpoint else 0
        layer = fsdp_wrap(layer, min_num_params=min_params_to_wrap)
        return layer
    def forward(
        self,
        src_tokens,
        src_lengths: Optional[torch.Tensor] = None,
        return_all_hiddens: bool = False,
        token_embeddings: Optional[torch.Tensor] = None,
        lang_id=None,
    ):
        """
        Args:
            src_tokens (LongTensor): tokens in the source language of shape
                `(batch, src_len)`
            src_lengths (torch.LongTensor): lengths of each source sentence of
                shape `(batch)`
            return_all_hiddens (bool, optional): also return all of the
                intermediate hidden states (default: False).
            token_embeddings (torch.Tensor, optional): precomputed embeddings;
                default `None` will recompute embeddings
            lang_id: language identifier forwarded to each encoder layer to
                select its language-specific adapter.

        Returns:
            dict:
                - **encoder_out** (Tensor): the last encoder layer's output of
                  shape `(src_len, batch, embed_dim)`
                - **encoder_padding_mask** (ByteTensor): the positions of
                  padding elements of shape `(batch, src_len)`
                - **encoder_embedding** (Tensor): the (scaled) embedding lookup
                  of shape `(batch, src_len, embed_dim)`
                - **encoder_states** (List[Tensor]): all intermediate
                  hidden states of shape `(src_len, batch, embed_dim)`.
                  Only populated if *return_all_hiddens* is True.
        """
        return self.forward_scriptable(
            src_tokens,
            src_lengths,
            return_all_hiddens,
            token_embeddings,
            lang_id=lang_id,
        )
    # TorchScript doesn't support super(), so a scriptable subclass can't call
    # into the base class's implementation. The current workaround is to define
    # a helper method under a different name and have the scriptable subclass
    # call that helper instead.
    def forward_scriptable(
        self,
        src_tokens,
        src_lengths: Optional[torch.Tensor] = None,
        return_all_hiddens: bool = False,
        token_embeddings: Optional[torch.Tensor] = None,
        lang_id=None,
    ):
        """
        Args:
            src_tokens (LongTensor): tokens in the source language of shape
                `(batch, src_len)`
            src_lengths (torch.LongTensor): lengths of each source sentence of
                shape `(batch)`
            return_all_hiddens (bool, optional): also return all of the
                intermediate hidden states (default: False).
            token_embeddings (torch.Tensor, optional): precomputed embeddings;
                default `None` will recompute embeddings
            lang_id: language identifier forwarded to each encoder layer to
                select its language-specific adapter.

        Returns:
            dict:
                - **encoder_out** (Tensor): the last encoder layer's output of
                  shape `(src_len, batch, embed_dim)`
                - **encoder_padding_mask** (ByteTensor): the positions of
                  padding elements of shape `(batch, src_len)`
                - **encoder_embedding** (Tensor): the (scaled) embedding lookup
                  of shape `(batch, src_len, embed_dim)`
                - **encoder_states** (List[Tensor]): all intermediate
                  hidden states of shape `(src_len, batch, embed_dim)`.
                  Only populated if *return_all_hiddens* is True.
        """
        # compute padding mask
        encoder_padding_mask = src_tokens.eq(self.padding_idx)
        has_pads = src_tokens.device.type == "xla" or encoder_padding_mask.any()

        x, encoder_embedding = self.forward_embedding(src_tokens, token_embeddings)

        # account for padding while computing the representation
        if has_pads:
            x = x * (1 - encoder_padding_mask.unsqueeze(-1).type_as(x))

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        encoder_states = []
        if return_all_hiddens:
            encoder_states.append(x)

        # encoder layers
        for layer in self.layers:
            x = layer(
                x,
                encoder_padding_mask=encoder_padding_mask if has_pads else None,
                lang_id=lang_id,
            )
            if return_all_hiddens:
                assert encoder_states is not None
                encoder_states.append(x)

        if self.layer_norm is not None:
            x = self.layer_norm(x)

        # The PyTorch Mobile lite interpreter does not support returning a
        # NamedTuple from `forward`, so we use a dictionary instead.
        # TorchScript does not support mixed dictionary value types, so every
        # value is a list; an empty list stands in for None.
        src_lengths = (
            src_tokens.ne(self.padding_idx)
            .sum(dim=1, dtype=torch.int32)
            .reshape(-1, 1)
            .contiguous()
        )
        return {
            "encoder_out": [x],  # T x B x C
            "encoder_padding_mask": [encoder_padding_mask],  # B x T
            "encoder_embedding": [encoder_embedding],  # B x T x C
            "encoder_states": encoder_states,  # List[T x B x C]
            "src_tokens": [],
            "src_lengths": [src_lengths],
        }
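
    # A minimal consumption sketch (illustrative, not part of the original
    # file): unwrap the list-valued dictionary and mean-pool the final-layer
    # features over non-padding positions.
    #
    #   out = encoder(src_tokens, lang_id=lang_id)
    #   x = out["encoder_out"][0].transpose(0, 1)  # T x B x C -> B x T x C
    #   keep = (~out["encoder_padding_mask"][0]).unsqueeze(-1).type_as(x)
    #   pooled = (x * keep).sum(dim=1) / keep.sum(dim=1)  # B x C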


@register_model_architecture("xmod", "xmod_base_13")
def roberta_base_architecture(args):
    args.ffn_modules = getattr(args, "ffn_modules", False)
    args.adapter_modules = getattr(args, "adapter_modules", True)
    args.adapter_layer_norm = getattr(args, "adapter_layer_norm", False)
    args.adapter_reuse_layer_norm = getattr(args, "adapter_reuse_layer_norm", True)
    args.ln_before_adapter = getattr(args, "ln_before_adapter", True)
    args.languages = getattr(
        args,
        "languages",
        [
            "ar_AR",
            "en_XX",
            "fi_FI",
            "fr_XX",
            "hi_IN",
            "id_ID",
            "ka_GE",
            "ko_KR",
            "ru_RU",
            "sw_KE",
            "ta_IN",
            "th_TH",
            "vi_VN",
        ],
    )
    base_architecture(args)


@register_model_architecture("xmod", "xmod_base_30")
def roberta_base_architecture(args):
    args.ffn_modules = getattr(args, "ffn_modules", False)
    args.adapter_modules = getattr(args, "adapter_modules", True)
    args.adapter_layer_norm = getattr(args, "adapter_layer_norm", False)
    args.adapter_reuse_layer_norm = getattr(args, "adapter_reuse_layer_norm", True)
    args.ln_before_adapter = getattr(args, "ln_before_adapter", True)
    args.languages = getattr(
        args,
        "languages",
        [
            "ar_AR",
            "cs_CZ",
            "en_XX",
            "eu_ES",
            "fi_FI",
            "fr_XX",
            "hi_IN",
            "hr_HR",
            "hu_HU",
            "hy_AM",
            "id_ID",
            "it_IT",
            "ka_GE",
            "ko_KR",
            "lt_LT",
            "ml_IN",
            "mn_MN",
            "ms_MY",
            "pl_PL",
            "ro_RO",
            "ru_RU",
            "si_LK",
            "sk_SK",
            "sq_AL",
            "sv_SE",
            "sw_KE",
            "ta_IN",
            "th_TH",
            "tl_XX",
            "vi_VN",
        ],
    )
    base_architecture(args)


@register_model_architecture("xmod", "xmod_base_60")
def roberta_base_architecture(args):
    args.ffn_modules = getattr(args, "ffn_modules", False)
    args.adapter_modules = getattr(args, "adapter_modules", True)
    args.adapter_layer_norm = getattr(args, "adapter_layer_norm", False)
    args.adapter_reuse_layer_norm = getattr(args, "adapter_reuse_layer_norm", True)
    args.ln_before_adapter = getattr(args, "ln_before_adapter", True)
    args.languages = getattr(
        args,
        "languages",
        [
            "af_ZA",
            "am_ET",
            "ar_AR",
            "be_BY",
            "bn_IN",
            "ca_ES",
            "cs_CZ",
            "cy_GB",
            "da_DK",
            "en_XX",
            "eo_EO",
            "et_EE",
            "eu_ES",
            "fa_IR",
            "fi_FI",
            "fr_XX",
            "ga_IE",
            "gl_ES",
            "gu_IN",
            "ha_NG",
            "hi_IN",
            "hr_HR",
            "hu_HU",
            "hy_AM",
            "id_ID",
            "is_IS",
            "it_IT",
            "ka_GE",
            "ko_KR",
            "ku_TR",
            "la_VA",
            "lt_LT",
            "lv_LV",
            "mk_MK",
            "ml_IN",
            "mn_MN",
            "ms_MY",
            "ne_NP",
            "nl_XX",
            "no_XX",
            "pl_PL",
            "ps_AF",
            "pt_XX",
            "ro_RO",
            "ru_RU",
            "sa_IN",
            "sd_PK",
            "si_LK",
            "sk_SK",
            "sl_SI",
            "so_SO",
            "sq_AL",
            "sr_RS",
            "sv_SE",
            "sw_KE",
            "ta_IN",
            "te_IN",
            "th_TH",
            "tl_XX",
            "vi_VN",
        ],
    )
    base_architecture(args)


@register_model_architecture("xmod", "xmod_base_75")
def roberta_base_architecture(args):
    args.ffn_modules = getattr(args, "ffn_modules", False)
    args.adapter_modules = getattr(args, "adapter_modules", True)
    args.adapter_layer_norm = getattr(args, "adapter_layer_norm", False)
    args.adapter_reuse_layer_norm = getattr(args, "adapter_reuse_layer_norm", True)
    args.ln_before_adapter = getattr(args, "ln_before_adapter", True)
    args.languages = getattr(
        args,
        "languages",
        [
            "af_ZA",
            "am_ET",
            "ar_AR",
            "as_IN",
            "be_BY",
            "bn_IN",
            "br_FR",
            "bs_BA",
            "ca_ES",
            "cs_CZ",
            "cy_GB",
            "da_DK",
            "en_XX",
            "eo_EO",
            "et_EE",
            "eu_ES",
            "fa_IR",
            "fi_FI",
            "fr_XX",
            "fy_NL",
            "ga_IE",
            "gd_GB",
            "gl_ES",
            "gu_IN",
            "ha_NG",
            "hi_IN",
            "hr_HR",
            "hu_HU",
            "hy_AM",
            "id_ID",
            "is_IS",
            "it_IT",
            "jv_ID",
            "ka_GE",
            "kn_IN",
            "ko_KR",
            "ku_TR",
            "la_VA",
            "lt_LT",
            "lv_LV",
            "mg_MG",
            "mk_MK",
            "ml_IN",
            "mn_MN",
            "mr_IN",
            "ms_MY",
            "ne_NP",
            "nl_XX",
            "no_XX",
            "om_KE",
            "or_IN",
            "pa_IN",
            "pl_PL",
            "ps_AF",
            "pt_XX",
            "ro_RO",
            "ru_RU",
            "sa_IN",
            "sd_PK",
            "si_LK",
            "sk_SK",
            "sl_SI",
            "so_SO",
            "sq_AL",
            "sr_RS",
            "su_ID",
            "sv_SE",
            "sw_KE",
            "ta_IN",
            "te_IN",
            "th_TH",
            "tl_XX",
            "vi_VN",
            "xh_ZA",
            "yi_DE",
        ],
    )
    base_architecture(args)


@register_model_architecture("xmod", "xmod_base")
def roberta_base_architecture(args):
    args.ffn_modules = getattr(args, "ffn_modules", False)
    args.adapter_modules = getattr(args, "adapter_modules", True)
    args.adapter_layer_norm = getattr(args, "adapter_layer_norm", False)
    args.adapter_reuse_layer_norm = getattr(args, "adapter_reuse_layer_norm", True)
    args.ln_before_adapter = getattr(args, "ln_before_adapter", True)
    args.languages = getattr(
        args,
        "languages",
        [
            "en_XX",
            "id_ID",
            "vi_VN",
            "ru_RU",
            "fa_IR",
            "sv_SE",
            "ja_XX",
            "fr_XX",
            "de_DE",
            "ro_RO",
            "ko_KR",
            "hu_HU",
            "es_XX",
            "fi_FI",
            "uk_UA",
            "da_DK",
            "pt_XX",
            "no_XX",
            "th_TH",
            "pl_PL",
            "bg_BG",
            "nl_XX",
            "zh_CN",
            "he_IL",
            "el_GR",
            "it_IT",
            "sk_SK",
            "hr_HR",
            "tr_TR",
            "ar_AR",
            "cs_CZ",
            "lt_LT",
            "hi_IN",
            "zh_TW",
            "ca_ES",
            "ms_MY",
            "sl_SI",
            "lv_LV",
            "ta_IN",
            "bn_IN",
            "et_EE",
            "az_AZ",
            "sq_AL",
            "sr_RS",
            "kk_KZ",
            "ka_GE",
            "tl_XX",
            "ur_PK",
            "is_IS",
            "hy_AM",
            "ml_IN",
            "mk_MK",
            "be_BY",
            "la_VA",
            "te_IN",
            "eu_ES",
            "gl_ES",
            "mn_MN",
            "kn_IN",
            "ne_NP",
            "sw_KE",
            "si_LK",
            "mr_IN",
            "af_ZA",
            "gu_IN",
            "cy_GB",
            "eo_EO",
            "km_KH",
            "ky_KG",
            "uz_UZ",
            "ps_AF",
            "pa_IN",
            "ga_IE",
            "ha_NG",
            "am_ET",
            "lo_LA",
            "ku_TR",
            "so_SO",
            "my_MM",
            "or_IN",
            "sa_IN",
        ],
    )
    base_architecture(args)


@register_model_architecture("xmod", "xmod_large_prenorm")
def roberta_base_architecture(args):
    args.ffn_modules = getattr(args, "ffn_modules", False)
    args.adapter_modules = getattr(args, "adapter_modules", True)
    args.adapter_layer_norm = getattr(args, "adapter_layer_norm", True)
    args.adapter_reuse_layer_norm = getattr(args, "adapter_reuse_layer_norm", False)
    args.ln_before_adapter = getattr(args, "ln_before_adapter", False)
    # args.bottleneck = getattr(args, "bottleneck", 8)
    args.bottleneck = getattr(args, "bottleneck", 4)
    args.languages = getattr(
        args,
        "languages",
        [
            "en_XX",
            "id_ID",
            "vi_VN",
            "ru_RU",
            "fa_IR",
            "sv_SE",
            "ja_XX",
            "fr_XX",
            "de_DE",
            "ro_RO",
            "ko_KR",
            "hu_HU",
            "es_XX",
            "fi_FI",
            "uk_UA",
            "da_DK",
            "pt_XX",
            "no_XX",
            "th_TH",
            "pl_PL",
            "bg_BG",
            "nl_XX",
            "zh_CN",
            "he_IL",
            "el_GR",
            "it_IT",
            "sk_SK",
            "hr_HR",
            "tr_TR",
            "ar_AR",
            "cs_CZ",
            "lt_LT",
            "hi_IN",
            "zh_TW",
            "ca_ES",
            "ms_MY",
            "sl_SI",
            "lv_LV",
            "ta_IN",
            "bn_IN",
            "et_EE",
            "az_AZ",
            "sq_AL",
            "sr_RS",
            "kk_KZ",
            "ka_GE",
            "tl_XX",
            "ur_PK",
            "is_IS",
            "hy_AM",
            "ml_IN",
            "mk_MK",
            "be_BY",
            "la_VA",
            "te_IN",
            "eu_ES",
            "gl_ES",
            "mn_MN",
            "kn_IN",
            "ne_NP",
            "sw_KE",
            "si_LK",
            "mr_IN",
            "af_ZA",
            "gu_IN",
            "cy_GB",
            "eo_EO",
            "km_KH",
            "ky_KG",
            "uz_UZ",
            "ps_AF",
            "pa_IN",
            "ga_IE",
            "ha_NG",
            "am_ET",
            "lo_LA",
            "ku_TR",
            "so_SO",
            "my_MM",
            "or_IN",
            "sa_IN",
        ],
    )
    args.encoder_normalize_before = getattr(args, "encoder_normalize_before", True)
    args.encoder_layers = getattr(args, "encoder_layers", 24)
    args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1024)
    args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 4096)
    args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 16)
    base_architecture(args)
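

# A hedged sanity check (illustrative, not part of the original file): once
# this module has been imported, the architecture names used in the
# @register_model_architecture decorators above should appear in fairseq's
# registry.
#
#   from fairseq.models import ARCH_MODEL_REGISTRY
#   assert "xmod_base" in ARCH_MODEL_REGISTRY
#   assert "xmod_large_prenorm" in ARCH_MODEL_REGISTRY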