jwieting commited on
Commit
5268492
·
1 Parent(s): b81b79a

Update modeling_paragram_sp.py

Browse files
Files changed (1) hide show
  1. modeling_paragram_sp.py +2 -2
modeling_paragram_sp.py CHANGED
@@ -10,7 +10,7 @@ class ParagramSPModel(BertPreTrainedModel):
10
  # Initialize weights and apply final processing
11
  self.post_init()
12
 
13
- def filter_input_ids(input_ids):
14
  output = []
15
  len = input_ids.shape[1]
16
  for ids in input_ids.shape[0]:
@@ -26,7 +26,7 @@ class ParagramSPModel(BertPreTrainedModel):
26
  def forward(self, input_ids, attention_mask):
27
  print(input_ids)
28
  print(attention_mask)
29
- input_ids = filter_input_ids(input_ids)
30
  attention_mask = input_ids > 0
31
  embeddings = self.word_embeddings(input_ids)
32
  masked_embeddings = embeddings * attention_mask[:, :, None]
 
10
  # Initialize weights and apply final processing
11
  self.post_init()
12
 
13
+ def filter_input_ids(self, input_ids):
14
  output = []
15
  len = input_ids.shape[1]
16
  for ids in input_ids.shape[0]:
 
26
  def forward(self, input_ids, attention_mask):
27
  print(input_ids)
28
  print(attention_mask)
29
+ input_ids = self.filter_input_ids(input_ids)
30
  attention_mask = input_ids > 0
31
  embeddings = self.word_embeddings(input_ids)
32
  masked_embeddings = embeddings * attention_mask[:, :, None]