Revert dict
modeling_flexbert.py (+9 -27)
@@ -50,7 +50,7 @@ import os
 import sys
 import warnings
 from dataclasses import dataclass
-from typing import List, Optional, Tuple, Union
+from typing import List, Optional, Tuple, Union
 
 # Add folder root to path to allow us to use relative imports regardless of what directory the script is run from
 sys.path.append(os.path.dirname(os.path.realpath(__file__)))
@@ -944,36 +944,18 @@ class FlexBertModel(FlexBertPreTrainedModel):
 
     def forward(
         self,
-
-
-
-
-
-
-        # max_seqlen: Optional[int] = None,
+        input_ids: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        indices: Optional[torch.Tensor] = None,
+        cu_seqlens: Optional[torch.Tensor] = None,
+        max_seqlen: Optional[int] = None,
         **kwargs,
     ) -> Tuple[Union[List[torch.Tensor], torch.Tensor], Optional[torch.Tensor]]:
-
-        if features["attention_mask"] is None:
+        if attention_mask is None:
             attention_mask = torch.ones_like(input_ids)
-
-        attention_mask = features["attention_mask"]
-        input_ids = features["input_ids"]
-        if "position_ids" not in features:
-            position_ids = None
-        else:
-            position_ids = features["position_ids"]
+
         embedding_output = self.embeddings(input_ids, position_ids)
-        if "indices" not in features:
-            indices = None
-        else:
-            indices = features["indices"]
-        if "cu_seqlens" not in features:
-            cu_seqlens = None
-        else:
-            cu_seqlens = features["cu_seqlens"]
-        if "max_seqlen" not in features:
-            max_seqlen = None
 
         encoder_outputs = self.encoder(
             hidden_states=embedding_output,