Delete AbLang_bert_model.py
Browse files- AbLang_bert_model.py +0 -34
AbLang_bert_model.py
DELETED
@@ -1,34 +0,0 @@
|
|
1 |
-
from transformers.models.bert.modeling_bert import BertEncoder, BertPooler, BertEmbeddings, BertForMaskedLM, MaskedLMOutput
|
2 |
-
from transformers import BertModel
|
3 |
-
from typing import List, Optional, Tuple, Union
|
4 |
-
import torch
|
5 |
-
|
6 |
-
class BertEmbeddingsV2(BertEmbeddings):
|
7 |
-
def __init__(self, config):
|
8 |
-
super().__init__(config)
|
9 |
-
self.pad_token_id = config.pad_token_id
|
10 |
-
self.position_embeddings = torch.nn.Embedding(config.max_position_embeddings, config.hidden_size, padding_idx=0) # here padding_idx is always 0
|
11 |
-
|
12 |
-
def forward(
|
13 |
-
self,
|
14 |
-
input_ids: torch.LongTensor,
|
15 |
-
token_type_ids: Optional[torch.LongTensor] = None,
|
16 |
-
position_ids: Optional[torch.LongTensor] = None,
|
17 |
-
inputs_embeds: Optional[torch.FloatTensor] = None,
|
18 |
-
past_key_values_length: int = 0,
|
19 |
-
) -> torch.Tensor:
|
20 |
-
inputs_embeds = self.word_embeddings(input_ids)
|
21 |
-
position_ids = self.create_position_ids_from_input_ids(input_ids)
|
22 |
-
position_embeddings = self.position_embeddings(position_ids)
|
23 |
-
embeddings = inputs_embeds + position_embeddings
|
24 |
-
return self.dropout(self.LayerNorm(embeddings))
|
25 |
-
|
26 |
-
def create_position_ids_from_input_ids(self, input_ids: torch.LongTensor) -> torch.Tensor:
|
27 |
-
mask = input_ids.ne(self.pad_token_id).int()
|
28 |
-
return torch.cumsum(mask, dim=1).long() * mask
|
29 |
-
|
30 |
-
|
31 |
-
class BertModelV2(BertModel):
|
32 |
-
def __init__(self, config):
|
33 |
-
super().__init__(config)
|
34 |
-
self.embeddings = BertEmbeddingsV2(config)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|