# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging

import torch.nn as nn
from fairseq.model_parallel.modules import (
    ModelParallelTransformerDecoderLayer,
    ModelParallelTransformerEncoderLayer,
)
from fairseq.models import register_model
from fairseq.models.transformer import (
    TransformerDecoder,
    TransformerEncoder,
    TransformerModel,
)


try:
    from fairseq.model_parallel.megatron.mpu import (
        copy_to_model_parallel_region,
        gather_from_model_parallel_region,
        VocabParallelEmbedding,
    )

    has_megatron_submodule = True
except (ImportError, ModuleNotFoundError):
    has_megatron_submodule = False
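
# Notes on the Megatron mpu primitives used below (based on Megatron-LM's
# documented semantics, stated here as a reading aid rather than as part of
# the original file): copy_to_model_parallel_region is an identity in the
# forward pass and all-reduces gradients in the backward pass;
# gather_from_model_parallel_region all-gathers the vocab-sharded logits
# along the last dimension; VocabParallelEmbedding splits the embedding
# table across model-parallel ranks along the vocabulary dimension.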

logger = logging.getLogger(__name__)


@register_model("model_parallel_transformer")
class ModelParallelTransformerModel(TransformerModel):
    """
    Model parallel Transformer model.
    """

    @classmethod
    def build_embedding(cls, args, dictionary, embed_dim, path=None):
        if not has_megatron_submodule:
            raise ImportError(
                "\n\nPlease install the megatron submodule:"
                "\n\n  git submodule update --init "
                "fairseq/model_parallel/megatron"
            )
        # pad the vocabulary so it shards evenly across model-parallel ranks
        # (the extra factor of 8 keeps each shard's size aligned, which is
        # friendlier to tensor cores)
        dictionary.pad_to_multiple_(args.model_parallel_size * 8)
        num_embeddings = len(dictionary)
        padding_idx = dictionary.pad()

        def _vocab_init(tensor, **kwargs):
            nn.init.normal_(tensor, mean=0, std=num_embeddings ** -0.5)
            # zero the embedding at the padding index (fairseq dictionaries
            # use index 1 for <pad>)
            nn.init.constant_(tensor[1], 0)

        emb = VocabParallelEmbedding(
            num_embeddings, embed_dim, padding_idx, init_method=_vocab_init
        )
        # if provided, load from preloaded dictionaries
        if path:
            raise NotImplementedError(
                "Loading of embedding from path is not supported for model parallel"
            )
        return emb

    @classmethod
    def build_encoder(cls, args, src_dict, embed_tokens):
        return ModelParallelTransformerEncoder(args, src_dict, embed_tokens)

    @classmethod
    def build_decoder(cls, args, tgt_dict, embed_tokens):
        return ModelParallelTransformerDecoder(
            args,
            tgt_dict,
            embed_tokens,
            no_encoder_attn=getattr(args, "no_cross_attention", False),
        )


class ModelParallelTransformerEncoder(TransformerEncoder):
    """
    Model parallel Transformer encoder consisting of *args.encoder_layers* layers. Each layer
    is a :class:`ModelParallelTransformerEncoderLayer`.
    """

    def __init__(self, args, dictionary, embed_tokens):
        super().__init__(args, dictionary, embed_tokens)

        if args.no_final_layer_norm:
            self.layer_norm = None

    def build_encoder_layer(self, args):
        return ModelParallelTransformerEncoderLayer(args)


class ModelParallelTransformerDecoder(TransformerDecoder):
    """
    Model Parallel Transformer decoder consisting of *args.decoder_layers* layers. Each layer
    is a :class:`ModelParallelTransformerDecoderLayer`.
    """

    def build_decoder_layer(self, args, no_encoder_attn=False):
        return ModelParallelTransformerDecoderLayer(args, no_encoder_attn)

    def output_layer(self, features, **kwargs):
        """Project features to the vocabulary size."""
        if not self.share_input_output_embed:
            raise NotImplementedError(
                "Model parallel training currently requires --share-decoder-input-output-embed"
            )

        # identity in the forward pass; all-reduces gradients in the backward pass
        features = copy_to_model_parallel_region(features)

        # project back to size of vocabulary
        x = self.output_projection(features)

        # vocab_parallel_cross_entropy consumes the vocab-sharded logits
        # directly; any other criterion needs the full vocabulary dimension
        # gathered back onto each rank
        if getattr(self.args, "criterion") != "vocab_parallel_cross_entropy":
            x = gather_from_model_parallel_region(x).contiguous()
        return x
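
# A minimal usage sketch (not part of the original file): to expose the model
# on the fairseq command line, an architecture would be registered in the
# usual fairseq fashion. The architecture name below is hypothetical; the
# default mirrors the `args.no_final_layer_norm` flag the encoder reads above.
#
#   from fairseq.models import register_model_architecture
#   from fairseq.models.transformer import base_architecture
#
#   @register_model_architecture(
#       "model_parallel_transformer", "model_parallel_transformer_base"
#   )
#   def model_parallel_transformer_base(args):
#       args.no_final_layer_norm = getattr(args, "no_final_layer_norm", False)
#       base_architecture(args)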