Commit
·
30e6a10
1
Parent(s):
2e3ebcb
change config name
Browse files
- config.json +1 -1
- configuration_bert.py +1 -1
- modeling_bert.py +3 -2
config.json
CHANGED
|
@@ -1,6 +1,6 @@
|
|
| 1 |
{
|
| 2 |
"auto_map": {
|
| 3 |
-
"AutoConfig": "configuration_bert.
|
| 4 |
"AutoModel": "modeling_bert.BertModel",
|
| 5 |
"AutoModelForPreTraining": "modeling_bert.BertForPreTraining",
|
| 6 |
"AutoModelForMaskedLM": "modeling_bert.BertForPreTraining"
|
|
|
|
| 1 |
{
|
| 2 |
"auto_map": {
|
| 3 |
+
"AutoConfig": "configuration_bert.XLMFlashConfig",
|
| 4 |
"AutoModel": "modeling_bert.BertModel",
|
| 5 |
"AutoModelForPreTraining": "modeling_bert.BertForPreTraining",
|
| 6 |
"AutoModelForMaskedLM": "modeling_bert.BertForPreTraining"
|
configuration_bert.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
| 1 |
from transformers import PretrainedConfig
|
| 2 |
|
| 3 |
-
class
|
| 4 |
def __init__(
|
| 5 |
self,
|
| 6 |
vocab_size=30522,
|
|
|
|
| 1 |
from transformers import PretrainedConfig
|
| 2 |
|
| 3 |
+
class XLMFlashConfig(PretrainedConfig):
|
| 4 |
def __init__(
|
| 5 |
self,
|
| 6 |
vocab_size=30522,
|
modeling_bert.py
CHANGED
|
@@ -19,7 +19,7 @@ import torch
|
|
| 19 |
import torch.nn as nn
|
| 20 |
import torch.nn.functional as F
|
| 21 |
from einops import rearrange
|
| 22 |
-
from transformers import BertConfig, PretrainedConfig
|
| 23 |
from transformers.modeling_utils import PreTrainedModel
|
| 24 |
from transformers.models.bert.modeling_bert import (
|
| 25 |
BaseModelOutputWithPoolingAndCrossAttentions,
|
|
@@ -32,6 +32,7 @@ from .bert_padding import (
|
|
| 32 |
pad_input,
|
| 33 |
unpad_input,
|
| 34 |
)
|
|
|
|
| 35 |
from .block import Block
|
| 36 |
from .embedding import BertEmbeddings
|
| 37 |
from .mha import MHA
|
|
@@ -345,7 +346,7 @@ class BertPreTrainedModel(PreTrainedModel):
|
|
| 345 |
"""An abstract class to handle weights initialization and
|
| 346 |
a simple interface for dowloading and loading pretrained models.
|
| 347 |
"""
|
| 348 |
-
config_class =
|
| 349 |
base_model_prefix = "bert"
|
| 350 |
supports_gradient_checkpointing = True
|
| 351 |
|
|
|
|
| 19 |
import torch.nn as nn
|
| 20 |
import torch.nn.functional as F
|
| 21 |
from einops import rearrange
|
| 22 |
+
from transformers import BertConfig, PretrainedConfig
|
| 23 |
from transformers.modeling_utils import PreTrainedModel
|
| 24 |
from transformers.models.bert.modeling_bert import (
|
| 25 |
BaseModelOutputWithPoolingAndCrossAttentions,
|
|
|
|
| 32 |
pad_input,
|
| 33 |
unpad_input,
|
| 34 |
)
|
| 35 |
+
from .configuration_bert import XLMFlashConfig
|
| 36 |
from .block import Block
|
| 37 |
from .embedding import BertEmbeddings
|
| 38 |
from .mha import MHA
|
|
|
|
| 346 |
"""An abstract class to handle weights initialization and
|
| 347 |
a simple interface for dowloading and loading pretrained models.
|
| 348 |
"""
|
| 349 |
+
config_class = XLMFlashConfig
|
| 350 |
base_model_prefix = "bert"
|
| 351 |
supports_gradient_checkpointing = True
|
| 352 |
|