Commit 95b4916
Parent(s): eb21270

add mlm model and adjust naming
Files changed:
- README.md +5 -0
- config.json +4 -4
- configuration_bert.py → configuration_xlm_roberta.py +1 -1
- convert_roberta_weights_to_flash.py +29 -44
- embedding.py +1 -1
- modeling_bert.py → modeling_xlm_roberta.py +210 -148
- pytorch_model.bin +2 -2
- bert_padding.py → xlm_padding.py +0 -0
README.md ADDED

@@ -0,0 +1,5 @@
+# Converting Weights
+
+```
+python3 -m "xlm-roberta-flash-implementation".convert_roberta_weights_to_flash --output pytorch_model_xlmr_flash.bin
+```
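The converter runs as a module and writes the remapped weights to the path given by --output. A quick way to sanity-check the result, assuming the file is a plain PyTorch state dict (an assumption on our side, not something the README states):

```python
import torch

# Assumed: the converter emits a plain state dict (parameter name -> tensor).
state_dict = torch.load("pytorch_model_xlmr_flash.bin", map_location="cpu")

# Remapped names such as "roberta.encoder.layers.0.mixer.Wqkv.weight" should appear here.
for name, tensor in list(state_dict.items())[:10]:
    print(name, tuple(tensor.shape))
```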
    	
config.json CHANGED

@@ -1,9 +1,9 @@
 {
   "auto_map": {
-    "AutoConfig": "
-    "AutoModel": "
-    "AutoModelForPreTraining": "
-    "AutoModelForMaskedLM": "
+    "AutoConfig": "configuration_xlm_roberta.XLMRobertaFlashConfig",
+    "AutoModel": "modeling_xlm_roberta.XLMRobertaModel",
+    "AutoModelForPreTraining": "modeling_xlm_roberta.XLMRobertaForPreTraining",
+    "AutoModelForMaskedLM": "modeling_xlm_roberta.XLMRobertaForMaskedLM"
   },
   "attention_probs_dropout_prob": 0.1,
   "bos_token_id": 0,
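With these auto_map entries, the stock Auto* factories can resolve the custom classes shipped in the repository once remote code is trusted. A usage sketch; the repository id below is a placeholder:

```python
from transformers import AutoConfig, AutoModel, AutoModelForMaskedLM

repo_id = "org/xlm-roberta-flash-implementation"  # placeholder, not the real repo id

# "AutoConfig" maps to configuration_xlm_roberta.XLMRobertaFlashConfig ...
config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)

# ... and "AutoModel" / "AutoModelForMaskedLM" map to the classes in modeling_xlm_roberta.py.
model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)
mlm_model = AutoModelForMaskedLM.from_pretrained(repo_id, trust_remote_code=True)
```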
    	
configuration_bert.py → configuration_xlm_roberta.py RENAMED

@@ -1,6 +1,6 @@
 from transformers import PretrainedConfig

-class 
+class XLMRobertaFlashConfig(PretrainedConfig):
     def __init__(
             self,
             vocab_size=30522,
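The renamed class follows the usual custom-configuration pattern: subclass PretrainedConfig, store the model hyperparameters, and pass everything else through to the base class. A trimmed sketch of that pattern (only vocab_size=30522 is taken from the diff; the other arguments and defaults are placeholders):

```python
from transformers import PretrainedConfig


class XLMRobertaFlashConfig(PretrainedConfig):
    def __init__(
        self,
        vocab_size=30522,        # default visible in the diff
        hidden_size=768,         # placeholder default
        num_hidden_layers=12,    # placeholder default
        use_flash_attn=True,     # placeholder default
        **kwargs,
    ):
        # Store model hyperparameters, then let PretrainedConfig handle the
        # generic fields (name_or_path, torch_dtype, id2label, ...).
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.use_flash_attn = use_flash_attn
        super().__init__(**kwargs)
```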
    	
convert_roberta_weights_to_flash.py CHANGED

@@ -1,9 +1,10 @@
 import re
 from collections import OrderedDict
-from transformers import
+from transformers import PretrainedConfig
 from transformers import XLMRobertaForMaskedLM

-from
+from .configuration_xlm_roberta import XLMRobertaFlashConfig as BertConfig
+from .modeling_xlm_roberta import XLMRobertaForMaskedLM as BertModel
 import torch

 import click

@@ -16,12 +17,6 @@ def remap_state_dict(state_dict, config: PretrainedConfig):
     Map the state_dict of a Huggingface BERT model to be flash_attn compatible.
     """

-    # Replace Roberta with Bert
-    def key_mapping_roberta(key):
-        return re.sub(r"^roberta.", "bert.", key)
-
-    state_dict = OrderedDict((key_mapping_roberta(k), v) for k, v in state_dict.items())
-
     # LayerNorm
     def key_mapping_ln_gamma_beta(key):
         key = re.sub(r"LayerNorm.gamma$", "LayerNorm.weight", key)

@@ -34,21 +29,21 @@ def remap_state_dict(state_dict, config: PretrainedConfig):

     # Layers
     def key_mapping_layers(key):
-        return re.sub(r"^
+        return re.sub(r"^roberta.encoder.layer.", "roberta.encoder.layers.", key)

     state_dict = OrderedDict((key_mapping_layers(k), v) for k, v in state_dict.items())

     # LayerNorm
     def key_mapping_ln(key):
-        key = re.sub(r"^
+        key = re.sub(r"^roberta.embeddings.LayerNorm.", "roberta.emb_ln.", key)
         key = re.sub(
-            r"^
-            r"
+            r"^roberta.encoder.layers.(\d+).attention.output.LayerNorm.(weight|bias)",
+            r"roberta.encoder.layers.\1.norm1.\2",
             key,
         )
         key = re.sub(
-            r"^
-            r"
+            r"^roberta.encoder.layers.(\d+).output.LayerNorm.(weight|bias)",
+            r"roberta.encoder.layers.\1.norm2.\2",
             key,
         )
         key = re.sub(

@@ -63,13 +58,13 @@ def remap_state_dict(state_dict, config: PretrainedConfig):
     # MLP
     def key_mapping_mlp(key):
         key = re.sub(
-            r"^
-            r"
+            r"^roberta.encoder.layers.(\d+).intermediate.dense.(weight|bias)",
+            r"roberta.encoder.layers.\1.mlp.fc1.\2",
             key,
         )
         key = re.sub(
-            r"^
-            r"
+            r"^roberta.encoder.layers.(\d+).output.dense.(weight|bias)",
+            r"roberta.encoder.layers.\1.mlp.fc2.\2",
             key,
         )
         return key

@@ -79,33 +74,33 @@ def remap_state_dict(state_dict, config: PretrainedConfig):
     # Attention
     last_layer_subset = getattr(config, "last_layer_subset", False)
     for d in range(config.num_hidden_layers):
-        Wq = state_dict.pop(f"
-        Wk = state_dict.pop(f"
-        Wv = state_dict.pop(f"
-        bq = state_dict.pop(f"
-        bk = state_dict.pop(f"
-        bv = state_dict.pop(f"
+        Wq = state_dict.pop(f"roberta.encoder.layers.{d}.attention.self.query.weight")
+        Wk = state_dict.pop(f"roberta.encoder.layers.{d}.attention.self.key.weight")
+        Wv = state_dict.pop(f"roberta.encoder.layers.{d}.attention.self.value.weight")
+        bq = state_dict.pop(f"roberta.encoder.layers.{d}.attention.self.query.bias")
+        bk = state_dict.pop(f"roberta.encoder.layers.{d}.attention.self.key.bias")
+        bv = state_dict.pop(f"roberta.encoder.layers.{d}.attention.self.value.bias")
         if not (last_layer_subset and d == config.num_hidden_layers - 1):
-            state_dict[f"
+            state_dict[f"roberta.encoder.layers.{d}.mixer.Wqkv.weight"] = torch.cat(
                 [Wq, Wk, Wv], dim=0
             )
-            state_dict[f"
+            state_dict[f"roberta.encoder.layers.{d}.mixer.Wqkv.bias"] = torch.cat(
                 [bq, bk, bv], dim=0
             )
         else:
-            state_dict[f"
-            state_dict[f"
+            state_dict[f"roberta.encoder.layers.{d}.mixer.Wq.weight"] = Wq
+            state_dict[f"roberta.encoder.layers.{d}.mixer.Wkv.weight"] = torch.cat(
                 [Wk, Wv], dim=0
             )
-            state_dict[f"
-            state_dict[f"
+            state_dict[f"roberta.encoder.layers.{d}.mixer.Wq.bias"] = bq
+            state_dict[f"roberta.encoder.layers.{d}.mixer.Wkv.bias"] = torch.cat(
                 [bk, bv], dim=0
             )

     def key_mapping_attn(key):
         return re.sub(
-            r"^
-            r"
+            r"^roberta.encoder.layers.(\d+).attention.output.dense.(weight|bias)",
+            r"roberta.encoder.layers.\1.mixer.out_proj.\2",
             key,
         )

@@ -121,8 +116,8 @@ def remap_state_dict(state_dict, config: PretrainedConfig):
     # Word embedding
     pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
     if pad_vocab_size_multiple > 1:
-        word_embeddings = state_dict["
-        state_dict["
+        word_embeddings = state_dict["roberta.embeddings.word_embeddings.weight"]
+        state_dict["roberta.embeddings.word_embeddings.weight"] = F.pad(
            word_embeddings, (0, 0, 0, config.vocab_size - word_embeddings.shape[0])
        )
        decoder_weight = state_dict["cls.predictions.decoder.weight"]

@@ -137,16 +132,6 @@ def remap_state_dict(state_dict, config: PretrainedConfig):
            decoder_bias, (0, config.vocab_size - decoder_bias.shape[0]), value=-100.0
        )

-    # Embeddings
-    def key_remove_bert(key):
-        return re.sub(r"^bert.", "", key)
-
-    state_dict = OrderedDict(
-        (key_remove_bert(k), v)
-        for k, v in state_dict.items()
-        if not k.startswith('lm_head')
-    )
-
     return state_dict
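Taken together, the script now remaps directly within the roberta.* namespace instead of first rewriting keys to bert.*: encoder.layer becomes encoder.layers, the two LayerNorms become norm1/norm2, the MLP projections become mlp.fc1/fc2, and the separate query/key/value projections are fused into one mixer.Wqkv tensor per layer. A stripped-down sketch of the core idea (the last_layer_subset branch and the remaining key mappings are omitted):

```python
import re
from collections import OrderedDict

import torch


def remap_keys_sketch(state_dict, num_hidden_layers):
    # Rename HuggingFace-style "encoder.layer.N" to the flash-attn "encoder.layers.N".
    sd = OrderedDict(
        (re.sub(r"^roberta.encoder.layer.", "roberta.encoder.layers.", k), v)
        for k, v in state_dict.items()
    )
    # Fuse per-layer query/key/value projections into a single Wqkv weight and bias.
    for d in range(num_hidden_layers):
        prefix = f"roberta.encoder.layers.{d}.attention.self"
        Wq, Wk, Wv = (sd.pop(f"{prefix}.{n}.weight") for n in ("query", "key", "value"))
        bq, bk, bv = (sd.pop(f"{prefix}.{n}.bias") for n in ("query", "key", "value"))
        sd[f"roberta.encoder.layers.{d}.mixer.Wqkv.weight"] = torch.cat([Wq, Wk, Wv], dim=0)
        sd[f"roberta.encoder.layers.{d}.mixer.Wqkv.bias"] = torch.cat([bq, bk, bv], dim=0)
    return sd
```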
    	
embedding.py CHANGED

@@ -11,7 +11,7 @@ from torch import Tensor
 from transformers.models.xlm_roberta.modeling_xlm_roberta import create_position_ids_from_input_ids


-class 
+class XLMRobertaEmbeddings(nn.Module):
     def __init__(
         self,
         embed_dim,
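The renamed embeddings class keeps the import of create_position_ids_from_input_ids, the helper RoBERTa-style models use to derive position ids from the input while skipping padding tokens. A minimal illustration of an embedding module built around that helper (the layer layout is illustrative, not the repository's exact implementation):

```python
import torch.nn as nn
from transformers.models.xlm_roberta.modeling_xlm_roberta import (
    create_position_ids_from_input_ids,
)


class EmbeddingsSketch(nn.Module):
    def __init__(self, embed_dim, vocab_size, max_position_embeddings, padding_idx=1):
        super().__init__()
        self.padding_idx = padding_idx
        self.word_embeddings = nn.Embedding(vocab_size, embed_dim, padding_idx=padding_idx)
        self.position_embeddings = nn.Embedding(max_position_embeddings, embed_dim)

    def forward(self, input_ids, position_ids=None):
        if position_ids is None:
            # RoBERTa-style: positions start at padding_idx + 1 and pad tokens are skipped.
            position_ids = create_position_ids_from_input_ids(input_ids, self.padding_idx)
        return self.word_embeddings(input_ids) + self.position_embeddings(position_ids)
```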
    	
        modeling_bert.py → modeling_xlm_roberta.py
    RENAMED
    
    | 
         @@ -13,28 +13,32 @@ import re 
     | 
|
| 13 | 
         
             
            from collections import OrderedDict
         
     | 
| 14 | 
         
             
            from collections.abc import Sequence
         
     | 
| 15 | 
         
             
            from functools import partial
         
     | 
| 16 | 
         
            -
            from typing import Any, Mapping
         
     | 
| 17 | 
         | 
| 18 | 
         
             
            import torch
         
     | 
| 19 | 
         
             
            import torch.nn as nn
         
     | 
| 20 | 
         
             
            import torch.nn.functional as F
         
     | 
| 21 | 
         
             
            from einops import rearrange
         
     | 
| 22 | 
         
            -
            from transformers import  
     | 
| 23 | 
         
             
            from transformers.modeling_utils import PreTrainedModel
         
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 24 | 
         
             
            from transformers.models.bert.modeling_bert import (
         
     | 
| 25 | 
         
             
                BaseModelOutputWithPoolingAndCrossAttentions,
         
     | 
| 26 | 
         
             
                BertForPreTrainingOutput,
         
     | 
| 27 | 
         
             
            )
         
     | 
| 28 | 
         | 
| 29 | 
         
            -
            from  
     | 
| 
         | 
|
| 
         | 
|
| 30 | 
         
             
                index_first_axis,
         
     | 
| 31 | 
         
             
                index_first_axis_residual,
         
     | 
| 32 | 
         
             
                pad_input,
         
     | 
| 33 | 
         
             
                unpad_input,
         
     | 
| 34 | 
         
             
            )
         
     | 
| 35 | 
         
            -
            from . 
     | 
| 36 | 
         
             
            from .block import Block
         
     | 
| 37 | 
         
            -
            from .embedding import  
     | 
| 38 | 
         
             
            from .mha import MHA
         
     | 
| 39 | 
         
             
            from .mlp import FusedMLP, Mlp
         
     | 
| 40 | 
         | 
| 
         @@ -155,8 +159,8 @@ def _init_weights(module, initializer_range=0.02): 
     | 
|
| 155 | 
         
             
                        nn.init.zeros_(module.weight[module.padding_idx])
         
     | 
| 156 | 
         | 
| 157 | 
         | 
| 158 | 
         
            -
            class  
     | 
| 159 | 
         
            -
                def __init__(self, config:  
     | 
| 160 | 
         
             
                    super().__init__()
         
     | 
| 161 | 
         
             
                    self.use_flash_attn = getattr(config, "use_flash_attn", False)
         
     | 
| 162 | 
         
             
                    self.layers = nn.ModuleList(
         
     | 
| 
         @@ -218,7 +222,7 @@ class BertEncoder(nn.Module): 
     | 
|
| 218 | 
         
             
                    return hidden_states
         
     | 
| 219 | 
         | 
| 220 | 
         | 
| 221 | 
         
            -
            class  
     | 
| 222 | 
         
             
                def __init__(self, config):
         
     | 
| 223 | 
         
             
                    super().__init__()
         
     | 
| 224 | 
         
             
                    fused_bias_fc = getattr(config, "fused_bias_fc", False)
         
     | 
| 
         @@ -237,7 +241,7 @@ class BertPooler(nn.Module): 
     | 
|
| 237 | 
         
             
                    return pooled_output
         
     | 
| 238 | 
         | 
| 239 | 
         | 
| 240 | 
         
            -
            class  
     | 
| 241 | 
         
             
                def __init__(self, config):
         
     | 
| 242 | 
         
             
                    super().__init__()
         
     | 
| 243 | 
         
             
                    fused_bias_fc = getattr(config, "fused_bias_fc", False)
         
     | 
| 
         @@ -268,7 +272,7 @@ class BertPredictionHeadTransform(nn.Module): 
     | 
|
| 268 | 
         
             
                    return hidden_states
         
     | 
| 269 | 
         | 
| 270 | 
         | 
| 271 | 
         
            -
            class  
     | 
| 272 | 
         
             
                def __init__(self, config):
         
     | 
| 273 | 
         
             
                    super().__init__()
         
     | 
| 274 | 
         
             
                    fused_bias_fc = getattr(config, "fused_bias_fc", False)
         
     | 
| 
         @@ -276,7 +280,7 @@ class BertLMPredictionHead(nn.Module): 
     | 
|
| 276 | 
         
             
                        raise ImportError("fused_dense is not installed")
         
     | 
| 277 | 
         
             
                    linear_cls = nn.Linear if not fused_bias_fc else FusedDense
         
     | 
| 278 | 
         | 
| 279 | 
         
            -
                    self.transform =  
     | 
| 280 | 
         | 
| 281 | 
         
             
                    # The output weights are the same as the input embeddings, but there is
         
     | 
| 282 | 
         
             
                    # an output-only bias for each token.
         
     | 
| 
         @@ -288,10 +292,10 @@ class BertLMPredictionHead(nn.Module): 
     | 
|
| 288 | 
         
             
                    return hidden_states
         
     | 
| 289 | 
         | 
| 290 | 
         | 
| 291 | 
         
            -
            class  
     | 
| 292 | 
         
             
                def __init__(self, config):
         
     | 
| 293 | 
         
             
                    super().__init__()
         
     | 
| 294 | 
         
            -
                    self.predictions =  
     | 
| 295 | 
         
             
                    self.seq_relationship = nn.Linear(config.hidden_size, 2)
         
     | 
| 296 | 
         | 
| 297 | 
         
             
                def forward(self, sequence_output, pooled_output):
         
     | 
| 
         @@ -300,64 +304,22 @@ class BertPreTrainingHeads(nn.Module): 
     | 
|
| 300 | 
         
             
                    return prediction_scores, seq_relationship_score
         
     | 
| 301 | 
         | 
| 302 | 
         | 
| 303 | 
         
            -
             
     | 
| 304 | 
         
            -
            #     """An abstract class to handle weights initialization and
         
     | 
| 305 | 
         
            -
            #     a simple interface for dowloading and loading pretrained models.
         
     | 
| 306 | 
         
            -
            #     """
         
     | 
| 307 | 
         
            -
            #
         
     | 
| 308 | 
         
            -
            #     def __init__(self, config, *inputs, **kwargs):
         
     | 
| 309 | 
         
            -
            #         super().__init__()
         
     | 
| 310 | 
         
            -
            #         if not isinstance(config, BertConfig):
         
     | 
| 311 | 
         
            -
            #             raise ValueError(
         
     | 
| 312 | 
         
            -
            #                 "Parameter config in `{}(config)` should be an instance of class `BertConfig`. "
         
     | 
| 313 | 
         
            -
            #                 "To create a model from a Google pretrained model use "
         
     | 
| 314 | 
         
            -
            #                 "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
         
     | 
| 315 | 
         
            -
            #                     self.__class__.__name__, self.__class__.__name__
         
     | 
| 316 | 
         
            -
            #                 )
         
     | 
| 317 | 
         
            -
            #             )
         
     | 
| 318 | 
         
            -
            #         self.config = config
         
     | 
| 319 | 
         
            -
            #
         
     | 
| 320 | 
         
            -
            #     @classmethod
         
     | 
| 321 | 
         
            -
            #     def from_pretrained(cls, model_name, config, *inputs, **kwargs):
         
     | 
| 322 | 
         
            -
            #         """
         
     | 
| 323 | 
         
            -
            #         Instantiate a BertPreTrainedModel from a pre-trained model file or a pytorch state dict.
         
     | 
| 324 | 
         
            -
            #         Download and cache the pre-trained model file if needed.
         
     | 
| 325 | 
         
            -
            #
         
     | 
| 326 | 
         
            -
            #         Params:
         
     | 
| 327 | 
         
            -
            #             pretrained_model_name_or_path: either:
         
     | 
| 328 | 
         
            -
            #                 - a path or url to a pretrained model archive containing:
         
     | 
| 329 | 
         
            -
            #                     . `bert_config.json` a configuration file for the model
         
     | 
| 330 | 
         
            -
            #                     . `pytorch_model.bin` a PyTorch dump of a BertForPretraining instance
         
     | 
| 331 | 
         
            -
            #                 - a path or url to a pretrained model archive containing:
         
     | 
| 332 | 
         
            -
            #                     . `bert_config.json` a configuration file for the model
         
     | 
| 333 | 
         
            -
            #                     . `model.chkpt` a TensorFlow checkpoint
         
     | 
| 334 | 
         
            -
            #             *inputs, **kwargs: additional input for the specific Bert class
         
     | 
| 335 | 
         
            -
            #                 (ex: num_labels for BertForSequenceClassification)
         
     | 
| 336 | 
         
            -
            #         """
         
     | 
| 337 | 
         
            -
            #         # Instantiate model.
         
     | 
| 338 | 
         
            -
            #         model = cls(config, *inputs, **kwargs)
         
     | 
| 339 | 
         
            -
            #         load_return = model.load_state_dict(
         
     | 
| 340 | 
         
            -
            #             remap_state_dict(state_dict_from_pretrained(model_name), config), strict=False
         
     | 
| 341 | 
         
            -
            #         )
         
     | 
| 342 | 
         
            -
            #         logger.info(load_return)
         
     | 
| 343 | 
         
            -
            #         return model
         
     | 
| 344 | 
         
            -
             
     | 
| 345 | 
         
            -
            class BertPreTrainedModel(PreTrainedModel):
         
     | 
| 346 | 
         
             
                """An abstract class to handle weights initialization and
         
     | 
| 347 | 
         
             
                a simple interface for dowloading and loading pretrained models.
         
     | 
| 348 | 
         
             
                """
         
     | 
| 349 | 
         
            -
                config_class =  
     | 
| 350 | 
         
            -
                base_model_prefix = " 
     | 
| 351 | 
         
             
                supports_gradient_checkpointing = True
         
     | 
| 352 | 
         | 
| 353 | 
         
             
                def _set_gradient_checkpointing(self, module, value=False):
         
     | 
| 354 | 
         
            -
                    if isinstance(module,  
     | 
| 355 | 
         
             
                        module.gradient_checkpointing = value
         
     | 
| 356 | 
         | 
| 357 | 
         | 
| 358 | 
         | 
| 359 | 
         
            -
            class  
     | 
| 360 | 
         
            -
                def __init__(self, config:  
     | 
| 361 | 
         
             
                    super().__init__(config)
         
     | 
| 362 | 
         
             
                    self.pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
         
     | 
| 363 | 
         
             
                    if config.vocab_size % self.pad_vocab_size_multiple != 0:
         
     | 
| 
         @@ -369,7 +331,7 @@ class BertModel(BertPreTrainedModel): 
     | 
|
| 369 | 
         
             
                        raise ImportError("Triton is not installed")
         
     | 
| 370 | 
         
             
                    assert config.hidden_act in ["gelu", "gelu_new", "gelu_fast", "gelu_pytorch_tanh"]
         
     | 
| 371 | 
         | 
| 372 | 
         
            -
                    self.embeddings =  
     | 
| 373 | 
         
             
                        config.hidden_size,
         
     | 
| 374 | 
         
             
                        config.vocab_size,
         
     | 
| 375 | 
         
             
                        config.max_position_embeddings,
         
     | 
| 
         @@ -378,11 +340,12 @@ class BertModel(BertPreTrainedModel): 
     | 
|
| 378 | 
         
             
                    )
         
     | 
| 379 | 
         
             
                    self.emb_drop = nn.Dropout(config.hidden_dropout_prob)
         
     | 
| 380 | 
         
             
                    self.emb_ln = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
         
     | 
| 381 | 
         
            -
                    self.encoder =  
     | 
| 382 | 
         
            -
                    self.pooler =  
     | 
| 383 | 
         | 
| 384 | 
         
             
                    self.apply(partial(_init_weights, initializer_range=config.initializer_range))
         
     | 
| 385 | 
         | 
| 
         | 
|
| 386 | 
         
             
                def forward(
         
     | 
| 387 | 
         
             
                    self,
         
     | 
| 388 | 
         
             
                    input_ids,
         
     | 
| 
         @@ -390,12 +353,22 @@ class BertModel(BertPreTrainedModel): 
     | 
|
| 390 | 
         
             
                    token_type_ids=None,
         
     | 
| 391 | 
         
             
                    attention_mask=None,
         
     | 
| 392 | 
         
             
                    masked_tokens_mask=None,
         
     | 
| 
         | 
|
| 
         | 
|
| 393 | 
         
             
                ):
         
     | 
| 394 | 
         
            -
                    """If masked_tokens_mask is not None (i.e. last_layer_subset == True in  
     | 
| 395 | 
         
             
                    we only want the output for the masked tokens. This means that we only compute the last
         
     | 
| 396 | 
         
             
                    layer output for these tokens.
         
     | 
| 397 | 
         
             
                    masked_tokens_mask: (batch, seqlen), dtype=torch.bool
         
     | 
| 398 | 
         
             
                    """
         
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 399 | 
         
             
                    hidden_states = self.embeddings(
         
     | 
| 400 | 
         
             
                        input_ids, position_ids=position_ids, token_type_ids=token_type_ids
         
     | 
| 401 | 
         
             
                    )
         
     | 
| 
         @@ -437,111 +410,200 @@ class BertModel(BertPreTrainedModel): 
     | 
|
| 437 | 
         
             
                            sequence_output = sequence_output[masked_tokens_mask[subset_mask]]
         
     | 
| 438 | 
         
             
                        pooled_output = self.pooler(pool_input, pool=False) if self.pooler is not None else None
         
     | 
| 439 | 
         | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 440 | 
         
             
                    return BaseModelOutputWithPoolingAndCrossAttentions(
         
     | 
| 441 | 
         
             
                        last_hidden_state=sequence_output,
         
     | 
| 442 | 
         
             
                        pooler_output=pooled_output,
         
     | 
| 443 | 
         
             
                    )
         
     | 
| 444 | 
         | 
| 445 | 
         | 
| 446 | 
         
            -
            class  
     | 
| 447 | 
         
            -
                 
     | 
| 448 | 
         
            -
             
     | 
| 449 | 
         
            -
             
     | 
| 450 | 
         
             
                    super().__init__(config)
         
     | 
| 451 | 
         
            -
                    # If dense_seq_output, we only need to pass the hidden states for the masked out tokens
         
     | 
| 452 | 
         
            -
                    # (around 15%) to the classifier heads.
         
     | 
| 453 | 
         
            -
                    self.dense_seq_output = getattr(config, "dense_seq_output", False)
         
     | 
| 454 | 
         
            -
                    # If last_layer_subset, we only need the compute the last layer for a subset of tokens
         
     | 
| 455 | 
         
            -
                    # (e.g., the tokens we need to compute the masked LM loss and the next-sentence prediction).
         
     | 
| 456 | 
         
            -
                    self.last_layer_subset = getattr(config, "last_layer_subset", False)
         
     | 
| 457 | 
         
            -
                    if self.last_layer_subset:
         
     | 
| 458 | 
         
            -
                        assert self.dense_seq_output, "last_layer_subset requires dense_seq_output"
         
     | 
| 459 | 
         
            -
                    use_xentropy = getattr(config, "use_xentropy", False)
         
     | 
| 460 | 
         
            -
                    if use_xentropy and CrossEntropyLoss is None:
         
     | 
| 461 | 
         
            -
                        raise ImportError("xentropy_cuda is not installed")
         
     | 
| 462 | 
         
            -
                    loss_cls = (
         
     | 
| 463 | 
         
            -
                        nn.CrossEntropyLoss
         
     | 
| 464 | 
         
            -
                        if not use_xentropy
         
     | 
| 465 | 
         
            -
                        else partial(CrossEntropyLoss, inplace_backward=True)
         
     | 
| 466 | 
         
            -
                    )
         
     | 
| 467 | 
         | 
| 468 | 
         
            -
                     
     | 
| 469 | 
         
            -
             
     | 
| 470 | 
         
            -
             
     | 
| 471 | 
         
            -
             
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 472 | 
         | 
| 473 | 
         
             
                    # Initialize weights and apply final processing
         
     | 
| 474 | 
         
            -
                    self. 
     | 
| 475 | 
         
            -
             
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 476 | 
         | 
| 477 | 
         
            -
                def tie_weights(self):
         
     | 
| 478 | 
         
            -
                    self.cls.predictions.decoder.weight = self.bert.embeddings.word_embeddings.weight
         
     | 
| 479 | 
         | 
| 480 | 
         
             
                def forward(
         
     | 
| 481 | 
         
             
                    self,
         
     | 
| 482 | 
         
            -
                    input_ids,
         
     | 
| 483 | 
         
            -
                     
     | 
| 484 | 
         
            -
                    token_type_ids=None,
         
     | 
| 485 | 
         
            -
                     
     | 
| 486 | 
         
            -
                     
     | 
| 487 | 
         
            -
                     
     | 
| 488 | 
         
            -
             
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 489 | 
         
             
                    """
         
     | 
| 490 | 
         
            -
                     
     | 
| 491 | 
         
            -
                    mask).
         
     | 
| 492 | 
         
            -
                    Outputs:
         
     | 
| 493 | 
         
            -
                        if `labels` and `next_sentence_label` are not `None`:
         
     | 
| 494 | 
         
            -
                            Outputs the total_loss which is the sum of the masked language modeling loss and the next
         
     | 
| 495 | 
         
            -
                            sentence classification loss.
         
     | 
| 496 | 
         
            -
                        if `labels` or `next_sentence_label` is `None`:
         
     | 
| 497 | 
         
            -
                            Outputs a tuple comprising
         
     | 
| 498 | 
         
            -
                            - the masked language modeling logits of shape [batch_size, sequence_length, vocab_size], and
         
     | 
| 499 | 
         
            -
                            - the next sentence classification logits of shape [batch_size, 2].
         
     | 
| 500 | 
         | 
| 501 | 
         
            -
                     
     | 
| 502 | 
         
            -
                    masked_tokens_mask = labels > 0 if (self.last_layer_subset and labels is not None) else None
         
     | 
| 503 | 
         
            -
                    outputs = self.bert(
         
     | 
| 504 | 
         
             
                        input_ids,
         
     | 
| 505 | 
         
            -
                         
     | 
| 506 | 
         
             
                        token_type_ids=token_type_ids,
         
     | 
| 507 | 
         
            -
                         
     | 
| 508 | 
         
            -
                         
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 509 | 
         
             
                    )
         
     | 
| 510 | 
         
            -
                    sequence_output 
     | 
| 511 | 
         
            -
                     
     | 
| 512 | 
         
            -
             
     | 
| 513 | 
         
            -
             
     | 
| 514 | 
         
            -
             
     | 
| 515 | 
         
            -
             
     | 
| 516 | 
         
            -
             
     | 
| 517 | 
         
            -
             
     | 
| 518 | 
         
            -
             
     | 
| 519 | 
         
            -
             
     | 
| 520 | 
         
            -
                    if  
     | 
| 521 | 
         
            -
                         
     | 
| 522 | 
         
            -
             
     | 
| 523 | 
         
            -
             
     | 
| 524 | 
         
            -
             
     | 
| 525 | 
         
            -
             
     | 
| 526 | 
         
            -
             
     | 
| 527 | 
         
            -
                         
     | 
| 528 | 
         
            -
             
     | 
| 529 | 
         
            -
                                rearrange(prediction_scores, "... v -> (...) v"),
         
     | 
| 530 | 
         
            -
                                rearrange(labels, "... -> (...)"),
         
     | 
| 531 | 
         
            -
                            )
         
     | 
| 532 | 
         
            -
                        next_sentence_loss = self.nsp_loss(
         
     | 
| 533 | 
         
            -
                            rearrange(seq_relationship_score, "... t -> (...) t"),
         
     | 
| 534 | 
         
            -
                            rearrange(next_sentence_label, "... -> (...)"),
         
     | 
| 535 | 
         
            -
                        )
         
     | 
| 536 | 
         
            -
                        total_loss = masked_lm_loss.float() + next_sentence_loss.float()
         
     | 
| 537 | 
         
            -
             
     | 
| 538 | 
         
            -
                    return BertForPreTrainingOutput(
         
     | 
| 539 | 
         
            -
                        loss=total_loss,
         
     | 
| 540 | 
         
            -
                        prediction_logits=prediction_scores,
         
     | 
| 541 | 
         
            -
                        seq_relationship_logits=seq_relationship_score,
         
     | 
| 542 | 
         
             
                    )
         
     | 
| 543 | 
         | 
| 544 | 
         | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 545 | 
         
             
            def remap_state_dict(state_dict, config: PretrainedConfig):
         
     | 
| 546 | 
         
             
                """
         
     | 
| 547 | 
         
             
                Map the state_dict of a Huggingface BERT model to be flash_attn compatible.
         
     | 
| 
         | 
|
| 13 | 
         
             
            from collections import OrderedDict
         
     | 
| 14 | 
         
             
            from collections.abc import Sequence
         
     | 
| 15 | 
         
             
            from functools import partial
         
     | 
| 
         | 
|
| 16 | 
         | 
| 17 | 
         
             
            import torch
         
     | 
| 18 | 
         
             
            import torch.nn as nn
         
     | 
| 19 | 
         
             
            import torch.nn.functional as F
         
     | 
| 20 | 
         
             
            from einops import rearrange
         
     | 
| 21 | 
         
            +
            from transformers import PretrainedConfig
         
     | 
| 22 | 
         
             
            from transformers.modeling_utils import PreTrainedModel
         
     | 
| 23 | 
         
            +
            from transformers.modeling_outputs import MaskedLMOutput
         
     | 
| 24 | 
         
            +
            from transformers.models.xlm_roberta.modeling_xlm_roberta import XLMRobertaLMHead
         
     | 
| 25 | 
         
            +
             
     | 
| 26 | 
         
             
            from transformers.models.bert.modeling_bert import (
         
     | 
| 27 | 
         
             
                BaseModelOutputWithPoolingAndCrossAttentions,
         
     | 
| 28 | 
         
             
                BertForPreTrainingOutput,
         
     | 
| 29 | 
         
             
            )
         
     | 
| 30 | 
         | 
| 31 | 
         
            +
            from typing import Optional, Tuple, Union
         
     | 
| 32 | 
         
            +
             
     | 
| 33 | 
         
            +
            from .xlm_padding import (
         
     | 
| 34 | 
         
             
                index_first_axis,
         
     | 
| 35 | 
         
             
                index_first_axis_residual,
         
     | 
| 36 | 
         
             
                pad_input,
         
     | 
| 37 | 
         
             
                unpad_input,
         
     | 
| 38 | 
         
             
            )
         
     | 
| 39 | 
         
            +
            from .configuration_xlm_roberta import XLMRobertaFlashConfig
         
     | 
| 40 | 
         
             
            from .block import Block
         
     | 
| 41 | 
         
            +
            from .embedding import XLMRobertaEmbeddings
         
     | 
| 42 | 
         
             
            from .mha import MHA
         
     | 
| 43 | 
         
             
            from .mlp import FusedMLP, Mlp
         
     | 
| 44 | 
         | 
| 
         | 
|
| 159 | 
         
             
                        nn.init.zeros_(module.weight[module.padding_idx])
         
     | 
| 160 | 
         | 
| 161 | 
         | 
| 162 | 
         
            +
            class XLMRobertaEncoder(nn.Module):
         
     | 
| 163 | 
         
            +
                def __init__(self, config: XLMRobertaFlashConfig):
         
     | 
| 164 | 
         
             
                    super().__init__()
         
     | 
| 165 | 
         
             
                    self.use_flash_attn = getattr(config, "use_flash_attn", False)
         
     | 
| 166 | 
         
             
                    self.layers = nn.ModuleList(
         
     | 
| 
         | 
|
| 222 | 
         
             
                    return hidden_states
         
     | 
| 223 | 
         | 
| 224 | 
         | 
| 225 | 
         
            +
            class XLMRobertaPooler(nn.Module):
         
     | 
| 226 | 
         
             
                def __init__(self, config):
         
     | 
| 227 | 
         
             
                    super().__init__()
         
     | 
| 228 | 
         
             
                    fused_bias_fc = getattr(config, "fused_bias_fc", False)
         
     | 
| 
         | 
|
| 241 |           return pooled_output
| 242 |
| 243 |
| 244 | + class XLMRobertaPredictionHeadTransform(nn.Module):
| 245 |       def __init__(self, config):
| 246 |           super().__init__()
| 247 |           fused_bias_fc = getattr(config, "fused_bias_fc", False)
| 272 |           return hidden_states
| 273 |
| 274 |
| 275 | + class XLMRobertaLMPredictionHead(nn.Module):
| 276 |       def __init__(self, config):
| 277 |           super().__init__()
| 278 |           fused_bias_fc = getattr(config, "fused_bias_fc", False)

| 280 |               raise ImportError("fused_dense is not installed")
| 281 |           linear_cls = nn.Linear if not fused_bias_fc else FusedDense
| 282 |
| 283 | +         self.transform = XLMRobertaPredictionHeadTransform(config)
| 284 |
| 285 |           # The output weights are the same as the input embeddings, but there is
| 286 |           # an output-only bias for each token.
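The comment at lines 285-286 describes the usual RoBERTa-style tying: the LM decoder reuses the input embedding matrix and only keeps its own per-token output bias. The decoder construction itself is elided from this hunk, so the following is only a minimal, self-contained sketch of that pattern (sizes and names are illustrative, not taken from this file):

```python
import torch
import torch.nn as nn

# Minimal sketch of a tied LM prediction head (illustrative only).
embed_dim, vocab_size = 16, 100            # toy sizes, not the real config values
word_embeddings = nn.Embedding(vocab_size, embed_dim)

decoder = nn.Linear(embed_dim, vocab_size, bias=True)
decoder.weight = word_embeddings.weight    # share the embedding matrix with the decoder
# decoder.bias remains a separate, output-only parameter of shape (vocab_size,)

hidden = torch.randn(2, 5, embed_dim)
logits = decoder(hidden)                   # (2, 5, vocab_size)
assert decoder.weight.data_ptr() == word_embeddings.weight.data_ptr()
```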
| 292 |           return hidden_states
| 293 |
| 294 |
| 295 | + class XLMRobertaPreTrainingHeads(nn.Module):
| 296 |       def __init__(self, config):
| 297 |           super().__init__()
| 298 | +         self.predictions = XLMRobertaLMPredictionHead(config)
| 299 |           self.seq_relationship = nn.Linear(config.hidden_size, 2)
| 300 |
| 301 |       def forward(self, sequence_output, pooled_output):

| 304 |           return prediction_scores, seq_relationship_score
| 305 |
| 306 |
| 307 | + class XLMRobertaPreTrainedModel(PreTrainedModel):
| 308 |       """An abstract class to handle weights initialization and
| 309 |       a simple interface for downloading and loading pretrained models.
| 310 |       """
| 311 | +     config_class = XLMRobertaFlashConfig
| 312 | +     base_model_prefix = "roberta"
| 313 |       supports_gradient_checkpointing = True
| 314 |
| 315 |       def _set_gradient_checkpointing(self, module, value=False):
| 316 | +         if isinstance(module, XLMRobertaEncoder):
| 317 |               module.gradient_checkpointing = value
| 318 |
| 319 |
| 320 |
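`supports_gradient_checkpointing` together with the `_set_gradient_checkpointing(module, value)` hook follows the older transformers convention, where `PreTrainedModel.gradient_checkpointing_enable()` walks the submodules and flips the flag on each `XLMRobertaEncoder`. A hedged usage sketch; the checkpoint path is a placeholder, and the exact behaviour depends on the installed transformers version:

```python
# Sketch only (not from this repo): driving the gradient-checkpointing hook through
# the standard transformers PreTrainedModel API.
from transformers import AutoModel

model = AutoModel.from_pretrained(
    "path/to/this/checkpoint",   # placeholder path, not a real repo id
    trust_remote_code=True,      # needed because the modeling code ships with the repo
)
model.gradient_checkpointing_enable()   # trades recompute for activation memory
model.gradient_checkpointing_disable()  # restores the default behaviour
```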
         
            +
            class XLMRobertaModel(XLMRobertaPreTrainedModel):
         
     | 
| 322 | 
         
            +
                def __init__(self, config: XLMRobertaFlashConfig, add_pooling_layer=True):
         
     | 
| 323 | 
         
             
                    super().__init__(config)
         
     | 
| 324 | 
         
             
                    self.pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
         
     | 
| 325 | 
         
             
                    if config.vocab_size % self.pad_vocab_size_multiple != 0:
         
     | 
| 
         | 
|
| 331 | 
         
             
                        raise ImportError("Triton is not installed")
         
     | 
| 332 | 
         
             
                    assert config.hidden_act in ["gelu", "gelu_new", "gelu_fast", "gelu_pytorch_tanh"]
         
     | 
| 333 | 
         | 
| 334 | 
         
            +
                    self.embeddings = XLMRobertaEmbeddings(
         
     | 
| 335 | 
         
             
                        config.hidden_size,
         
     | 
| 336 | 
         
             
                        config.vocab_size,
         
     | 
| 337 | 
         
             
                        config.max_position_embeddings,
         
     | 
| 
         | 
|
| 340 | 
         
             
                    )
         
     | 
| 341 | 
         
             
                    self.emb_drop = nn.Dropout(config.hidden_dropout_prob)
         
     | 
| 342 | 
         
             
                    self.emb_ln = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
         
     | 
| 343 | 
         
            +
                    self.encoder = XLMRobertaEncoder(config)
         
     | 
| 344 | 
         
            +
                    self.pooler = XLMRobertaPooler(config) if add_pooling_layer else None
         
     | 
| 345 | 
         | 
| 346 | 
         
             
                    self.apply(partial(_init_weights, initializer_range=config.initializer_range))
         
     | 
| 347 | 
         | 
| 348 | 
         
            +
             
     | 
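The `pad_vocab_size_multiple` check above guards the common trick of rounding the vocabulary up so the embedding and LM-head matmuls get GPU-friendly shapes; the rounding itself happens in lines elided from this hunk. A generic way to compute such a padded size (not the repository's exact code) is:

```python
import math

def pad_vocab_size(vocab_size: int, multiple: int) -> int:
    """Round vocab_size up to the nearest multiple (e.g. 8) for GPU-friendly shapes."""
    if multiple <= 1:
        return vocab_size
    return int(math.ceil(vocab_size / multiple)) * multiple

assert pad_vocab_size(250002, 8) == 250008
assert pad_vocab_size(250000, 8) == 250000
```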
| 349 |       def forward(
| 350 |           self,
| 351 |           input_ids,

| 353 |           token_type_ids=None,
| 354 |           attention_mask=None,
| 355 |           masked_tokens_mask=None,
| 356 | +         return_dict=None,
| 357 | +         **kwargs,
| 358 |       ):
| 359 | +         """If masked_tokens_mask is not None (i.e. last_layer_subset == True in XLMForPreTraining),
| 360 |           we only want the output for the masked tokens. This means that we only compute the last
| 361 |           layer output for these tokens.
| 362 |           masked_tokens_mask: (batch, seqlen), dtype=torch.bool
| 363 |           """
| 364 | +
| 365 | +         if kwargs:
| 366 | +             for key, value in kwargs.items():
| 367 | +                 if value is not None:
| 368 | +                     logger.warning('Flash attention implementation does not support kwargs: %s', key)
| 369 | +
| 370 | +         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
| 371 | +
| 372 |           hidden_states = self.embeddings(
| 373 |               input_ids, position_ids=position_ids, token_type_ids=token_type_ids
| 374 |           )

| 410 |                   sequence_output = sequence_output[masked_tokens_mask[subset_mask]]
| 411 |               pooled_output = self.pooler(pool_input, pool=False) if self.pooler is not None else None
| 412 |
| 413 | +         if not return_dict:
| 414 | +             return sequence_output, pooled_output
| 415 | +
| 416 |           return BaseModelOutputWithPoolingAndCrossAttentions(
| 417 |               last_hidden_state=sequence_output,
| 418 |               pooler_output=pooled_output,
| 419 |           )
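With the added `return_dict` handling, the encoder behaves like a standard transformers model: a plain `(sequence_output, pooled_output)` tuple when `return_dict=False`, otherwise a `BaseModelOutputWithPoolingAndCrossAttentions`. A hedged usage sketch; the checkpoint path is a placeholder and the `xlm-roberta-base` tokenizer is assumed only for illustration:

```python
# Sketch only: calling the flash encoder like any transformers model.
import torch
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-base")  # assumed tokenizer
model = AutoModel.from_pretrained(
    "path/to/this/checkpoint",   # placeholder path
    trust_remote_code=True,
)

batch = tokenizer(["Hello world", "Bonjour le monde"], padding=True, return_tensors="pt")
with torch.no_grad():
    out = model(batch["input_ids"], attention_mask=batch["attention_mask"].bool())
print(out.last_hidden_state.shape)   # (batch, seqlen, hidden_size)

# return_dict=False yields the bare (sequence_output, pooled_output) tuple instead.
seq, pooled = model(
    batch["input_ids"],
    attention_mask=batch["attention_mask"].bool(),
    return_dict=False,
)
```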
| 420 |
| 421 |
| 422 | + class XLMRobertaForMaskedLM(XLMRobertaPreTrainedModel):
| 423 | +     _tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"]
| 424 | +
| 425 | +     def __init__(self, config):
| 426 |           super().__init__(config)
| 427 |
| 428 | +         if config.is_decoder:
| 429 | +             logger.warning(
| 430 | +                 "If you want to use `XLMRobertaForMaskedLM` make sure `config.is_decoder=False` for "
| 431 | +                 "bi-directional self-attention."
| 432 | +             )
| 433 | +
| 434 | +         self.roberta = XLMRobertaModel(config, add_pooling_layer=False)
| 435 | +         self.lm_head = XLMRobertaLMHead(config)
| 436 |
| 437 |           # Initialize weights and apply final processing
| 438 | +         self.post_init()
| 439 | +
| 440 | +     def get_input_embeddings(self):
| 441 | +         return self.roberta.embeddings.word_embeddings
| 442 | +
| 443 | +     def get_output_embeddings(self):
| 444 | +         return self.lm_head.decoder
| 445 | +
| 446 | +     def set_output_embeddings(self, new_embeddings):
| 447 | +         self.lm_head.decoder = new_embeddings
| 448 |
|
| 449 | 
         | 
| 450 | 
         
             
                def forward(
         
     | 
| 451 | 
         
             
                    self,
         
     | 
| 452 | 
         
            +
                    input_ids: Optional[torch.LongTensor] = None,
         
     | 
| 453 | 
         
            +
                    attention_mask: Optional[torch.FloatTensor] = None,
         
     | 
| 454 | 
         
            +
                    token_type_ids: Optional[torch.LongTensor] = None,
         
     | 
| 455 | 
         
            +
                    position_ids: Optional[torch.LongTensor] = None,
         
     | 
| 456 | 
         
            +
                    head_mask: Optional[torch.FloatTensor] = None,
         
     | 
| 457 | 
         
            +
                    inputs_embeds: Optional[torch.FloatTensor] = None,
         
     | 
| 458 | 
         
            +
                    encoder_hidden_states: Optional[torch.FloatTensor] = None,
         
     | 
| 459 | 
         
            +
                    encoder_attention_mask: Optional[torch.FloatTensor] = None,
         
     | 
| 460 | 
         
            +
                    labels: Optional[torch.LongTensor] = None,
         
     | 
| 461 | 
         
            +
                    output_attentions: Optional[bool] = None,
         
     | 
| 462 | 
         
            +
                    output_hidden_states: Optional[bool] = None,
         
     | 
| 463 | 
         
            +
                    return_dict: Optional[bool] = None,
         
     | 
| 464 | 
         
            +
                ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]:
         
     | 
| 465 | 
         
            +
                    r"""
         
     | 
| 466 | 
         
            +
                    labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
         
     | 
| 467 | 
         
            +
                        Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
         
     | 
| 468 | 
         
            +
                        config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
         
     | 
| 469 | 
         
            +
                        loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
         
     | 
| 470 | 
         
            +
                    kwargs (`Dict[str, any]`, optional, defaults to *{}*):
         
     | 
| 471 | 
         
            +
                        Used to hide legacy arguments that have been deprecated.
         
     | 
| 472 | 
         
             
                    """
         
     | 
| 473 | 
         
            +
                    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
         
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 474 |
| 475 | +         outputs = self.roberta(
| 476 |               input_ids,
| 477 | +             attention_mask=attention_mask,
| 478 |               token_type_ids=token_type_ids,
| 479 | +             position_ids=position_ids,
| 480 | +             head_mask=head_mask,
| 481 | +             inputs_embeds=inputs_embeds,
| 482 | +             encoder_hidden_states=encoder_hidden_states,
| 483 | +             encoder_attention_mask=encoder_attention_mask,
| 484 | +             output_attentions=output_attentions,
| 485 | +             output_hidden_states=output_hidden_states,
| 486 | +             return_dict=return_dict,
| 487 |           )
| 488 | +         sequence_output = outputs[0]
| 489 | +         prediction_scores = self.lm_head(sequence_output)
| 490 | +
| 491 | +         masked_lm_loss = None
| 492 | +         if labels is not None:
| 493 | +             # move labels to correct device to enable model parallelism
| 494 | +             labels = labels.to(prediction_scores.device)
| 495 | +             loss_fct = CrossEntropyLoss()
| 496 | +             masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
| 497 | +
| 498 | +         if not return_dict:
| 499 | +             output = (prediction_scores,) + outputs[2:]
| 500 | +             return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
| 501 | +
| 502 | +         return MaskedLMOutput(
| 503 | +             loss=masked_lm_loss,
| 504 | +             logits=prediction_scores,
| 505 | +             hidden_states=outputs.hidden_states,
| 506 | +             attentions=outputs.attentions,
| 507 |           )
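As the `labels` docstring states, positions set to `-100` are excluded from the masked-LM loss; that is the default `ignore_index` of `nn.CrossEntropyLoss`, which the forward above relies on. A self-contained sketch of how such labels are typically prepared, using made-up token ids instead of a real tokenizer:

```python
import torch
import torch.nn as nn

# Toy setup: batch of 1, sequence of 6 token ids, tiny vocab. All values are made up.
vocab_size, mask_token_id = 32, 31
input_ids = torch.tensor([[5, 9, 12, 7, 3, 2]])

labels = input_ids.clone()
masked_positions = torch.tensor([[False, False, True, False, True, False]])
labels[~masked_positions] = -100                        # ignored by the loss
input_ids = input_ids.masked_fill(masked_positions, mask_token_id)

# The model's loss is equivalent to this: CrossEntropyLoss ignores -100 by default.
logits = torch.randn(1, 6, vocab_size)                  # stand-in for prediction_scores
loss = nn.CrossEntropyLoss()(logits.view(-1, vocab_size), labels.view(-1))
print(loss)                                             # averaged over the two masked positions only
```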
| 508 |
| 509 |
| 510 | + # class XLMRobertaForPreTraining(XLMRobertaPreTrainedModel):
| 511 | + #     def __init__(self, config: XLMRobertaFlashConfig):
| 512 | + #         super().__init__(config)
| 513 | + #         # If dense_seq_output, we only need to pass the hidden states for the masked out tokens
| 514 | + #         # (around 15%) to the classifier heads.
| 515 | + #         self.dense_seq_output = getattr(config, "dense_seq_output", False)
| 516 | + #         # If last_layer_subset, we only need to compute the last layer for a subset of tokens
| 517 | + #         # (e.g., the tokens we need to compute the masked LM loss and the next-sentence prediction).
| 518 | + #         self.last_layer_subset = getattr(config, "last_layer_subset", False)
| 519 | + #         if self.last_layer_subset:
| 520 | + #             assert self.dense_seq_output, "last_layer_subset requires dense_seq_output"
| 521 | + #         use_xentropy = getattr(config, "use_xentropy", False)
| 522 | + #         if use_xentropy and CrossEntropyLoss is None:
| 523 | + #             raise ImportError("xentropy_cuda is not installed")
| 524 | + #         loss_cls = (
| 525 | + #             nn.CrossEntropyLoss
| 526 | + #             if not use_xentropy
| 527 | + #             else partial(CrossEntropyLoss, inplace_backward=True)
| 528 | + #         )
| 529 | + #
| 530 | + #         self.xlm = XLMRobertaModel(config)
| 531 | + #         self.cls = XLMRobertaPreTrainingHeads(config)
| 532 | + #         self.mlm_loss = loss_cls(ignore_index=0)
| 533 | + #         self.nsp_loss = loss_cls(ignore_index=-1)
| 534 | + #
| 535 | + #         # Initialize weights and apply final processing
| 536 | + #         self.apply(partial(_init_weights, initializer_range=config.initializer_range))
| 537 | + #         self.tie_weights()
| 538 | + #
| 539 | + #     def tie_weights(self):
| 540 | + #         self.cls.predictions.decoder.weight = self.xlm.embeddings.word_embeddings.weight
| 541 | + #
| 542 | + #     def forward(
| 543 | + #         self,
| 544 | + #         input_ids,
| 545 | + #         position_ids=None,
| 546 | + #         token_type_ids=None,
| 547 | + #         attention_mask=None,
| 548 | + #         labels=None,
| 549 | + #         next_sentence_label=None,
| 550 | + #     ):
| 551 | + #         """
| 552 | + #         If labels are provided, they must be 0 for masked out tokens (as specified in the attention
| 553 | + #         mask).
| 554 | + #         Outputs:
| 555 | + #             if `labels` and `next_sentence_label` are not `None`:
| 556 | + #                 Outputs the total_loss which is the sum of the masked language modeling loss and the next
| 557 | + #                 sentence classification loss.
| 558 | + #             if `labels` or `next_sentence_label` is `None`:
| 559 | + #                 Outputs a tuple comprising
| 560 | + #                 - the masked language modeling logits of shape [batch_size, sequence_length, vocab_size], and
| 561 | + #                 - the next sentence classification logits of shape [batch_size, 2].
| 562 | + #
| 563 | + #         """
| 564 | + #         masked_tokens_mask = labels > 0 if (self.last_layer_subset and labels is not None) else None
| 565 | + #         outputs = self.xlm(
| 566 | + #             input_ids,
| 567 | + #             position_ids=position_ids,
| 568 | + #             token_type_ids=token_type_ids,
| 569 | + #             attention_mask=attention_mask.bool() if attention_mask is not None else None,
| 570 | + #             masked_tokens_mask=masked_tokens_mask,
| 571 | + #         )
| 572 | + #         sequence_output, pooled_output = outputs.last_hidden_state, outputs.pooler_output
| 573 | + #         if self.dense_seq_output and labels is not None:
| 574 | + #             masked_token_idx = torch.nonzero(labels.flatten() > 0, as_tuple=False).flatten()
| 575 | + #             if not self.last_layer_subset:
| 576 | + #                 sequence_output = index_first_axis(
| 577 | + #                     rearrange(sequence_output, "b s d -> (b s) d"), masked_token_idx
| 578 | + #                 )
| 579 | + #         prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)
| 580 | + #
| 581 | + #         total_loss = None
| 582 | + #         if labels is not None and next_sentence_label is not None:
| 583 | + #             if (
| 584 | + #                 self.dense_seq_output and labels is not None
| 585 | + #             ):  # prediction_scores are already flattened
| 586 | + #                 masked_lm_loss = self.mlm_loss(
| 587 | + #                     prediction_scores, labels.flatten()[masked_token_idx]
| 588 | + #                 )
| 589 | + #             else:
| 590 | + #                 masked_lm_loss = self.mlm_loss(
| 591 | + #                     rearrange(prediction_scores, "... v -> (...) v"),
| 592 | + #                     rearrange(labels, "... -> (...)"),
| 593 | + #                 )
| 594 | + #             next_sentence_loss = self.nsp_loss(
| 595 | + #                 rearrange(seq_relationship_score, "... t -> (...) t"),
| 596 | + #                 rearrange(next_sentence_label, "... -> (...)"),
| 597 | + #             )
| 598 | + #             total_loss = masked_lm_loss.float() + next_sentence_loss.float()
| 599 | + #
| 600 | + #         return BertForPreTrainingOutput(
| 601 | + #             loss=total_loss,
| 602 | + #             prediction_logits=prediction_scores,
| 603 | + #             seq_relationship_logits=seq_relationship_score,
| 604 | + #         )
| 605 | +
| 606 | +
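The commented-out pre-training head documents the `dense_seq_output` trick: only the roughly 15% of positions that are actually masked are gathered out of `sequence_output` before the expensive vocabulary-sized projection. The sketch below reproduces that gather in plain PyTorch (toy shapes, no einops or flash-attn dependency), mirroring the `index_first_axis(rearrange(...))` call in the commented code:

```python
import torch

# Toy shapes: batch 2, seqlen 4, hidden 8. In this scheme labels are 0 for
# positions that do NOT contribute to the MLM loss (hence ignore_index=0 above).
b, s, d = 2, 4, 8
sequence_output = torch.randn(b, s, d)
labels = torch.tensor([[0, 7, 0, 3],
                       [5, 0, 0, 0]])

# Same selection as index_first_axis(rearrange(sequence_output, "b s d -> (b s) d"), idx)
masked_token_idx = torch.nonzero(labels.flatten() > 0, as_tuple=False).flatten()
dense_output = sequence_output.reshape(b * s, d)[masked_token_idx]

print(masked_token_idx.tolist())   # [1, 3, 4]
print(dense_output.shape)          # (num_masked, d): only these rows hit the LM head
```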
| 607 |   def remap_state_dict(state_dict, config: PretrainedConfig):
| 608 |       """
| 609 |       Map the state_dict of a Huggingface BERT model to be flash_attn compatible.
    	
pytorch_model.bin
CHANGED

@@ -1,3 +1,3 @@
| 1 |   version https://git-lfs.github.com/spec/v1
| 2 | - oid sha256:
| 3 | - size
| 1 |   version https://git-lfs.github.com/spec/v1
| 2 | + oid sha256:cfa8fa7c7e120199548fe7149512c0adfe58f6bc13ce19f09b895aa25e8af910
| 3 | + size 1113232188
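The updated Git LFS pointer records the blob's SHA-256 and byte size, so a fully downloaded `pytorch_model.bin` can optionally be checked against it; a small verification sketch (not required by anything in the repository):

```python
import hashlib
from pathlib import Path

# Expected values taken from the LFS pointer above.
EXPECTED_SHA256 = "cfa8fa7c7e120199548fe7149512c0adfe58f6bc13ce19f09b895aa25e8af910"
EXPECTED_SIZE = 1113232188

path = Path("pytorch_model.bin")
assert path.stat().st_size == EXPECTED_SIZE, "size mismatch (is this still just the LFS pointer?)"

h = hashlib.sha256()
with path.open("rb") as f:
    for chunk in iter(lambda: f.read(1 << 20), b""):
        h.update(chunk)
assert h.hexdigest() == EXPECTED_SHA256
print("pytorch_model.bin matches the LFS pointer")
```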
    	
bert_padding.py → xlm_padding.py
RENAMED

File without changes