NohTow committed
Commit 29f5554 · Parent(s): ce9aa51

Revert dict

Files changed (1):
  1. modeling_flexbert.py (+9 -27)
modeling_flexbert.py CHANGED
@@ -50,7 +50,7 @@ import os
 import sys
 import warnings
 from dataclasses import dataclass
-from typing import List, Optional, Tuple, Union, Dict
+from typing import List, Optional, Tuple, Union
 
 # Add folder root to path to allow us to use relative imports regardless of what directory the script is run from
 sys.path.append(os.path.dirname(os.path.realpath(__file__)))
@@ -944,36 +944,18 @@ class FlexBertModel(FlexBertPreTrainedModel):
 
     def forward(
         self,
-        features: Dict[str, torch.Tensor],
-        # input_ids: torch.Tensor,
-        # attention_mask: Optional[torch.Tensor] = None,
-        # position_ids: Optional[torch.Tensor] = None,
-        # indices: Optional[torch.Tensor] = None,
-        # cu_seqlens: Optional[torch.Tensor] = None,
-        # max_seqlen: Optional[int] = None,
+        input_ids: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        indices: Optional[torch.Tensor] = None,
+        cu_seqlens: Optional[torch.Tensor] = None,
+        max_seqlen: Optional[int] = None,
         **kwargs,
     ) -> Tuple[Union[List[torch.Tensor], torch.Tensor], Optional[torch.Tensor]]:
-
-        if features["attention_mask"] is None:
+        if attention_mask is None:
             attention_mask = torch.ones_like(input_ids)
-        else:
-            attention_mask = features["attention_mask"]
-        input_ids = features["input_ids"]
-        if "position_ids" not in features:
-            position_ids = None
-        else:
-            position_ids = features["position_ids"]
+
         embedding_output = self.embeddings(input_ids, position_ids)
-        if "indices" not in features:
-            indices = None
-        else:
-            indices = features["indices"]
-        if "cu_seqlens" not in features:
-            cu_seqlens = None
-        else:
-            cu_seqlens = features["cu_seqlens"]
-        if "max_seqlen" not in features:
-            max_seqlen = None
 
         encoder_outputs = self.encoder(
             hidden_states=embedding_output,
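
A minimal sketch of what this revert means at the call site, assuming `model` is an already-loaded FlexBertModel; the vocab size and tensor shapes below are illustrative only:

    import torch

    # Illustrative inputs; 30522 and (2, 128) are made-up values.
    input_ids = torch.randint(0, 30522, (2, 128))  # (batch_size, seq_len)
    attention_mask = torch.ones_like(input_ids)

    # Before this commit, inputs were bundled into a single features dict:
    #   model({"input_ids": input_ids, "attention_mask": attention_mask})
    # After the revert, each input is an explicit keyword argument again, and
    # the optional unpadding arguments (position_ids, indices, cu_seqlens,
    # max_seqlen) simply default to None when not supplied:
    outputs = model(input_ids=input_ids, attention_mask=attention_mask)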