aynetdia committed
Commit 5fadac7 · Parent(s): 0b695c3

mean_pooling fix

Files changed (1):
  1. semscore.py +4 -3
semscore.py CHANGED

@@ -89,7 +89,8 @@ class SemScore(evaluate.Metric):
         self.model.eval()
         self.tokenizer = AutoTokenizer.from_pretrained(checkpoint)
 
-    def mean_pooling(model_output, attention_mask):
+    @staticmethod
+    def _mean_pooling(model_output, attention_mask):
         """Mean pooling over all tokens - take attention mask into account for correct averaging"""
         token_embeddings = model_output[0]
         input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
@@ -122,8 +123,8 @@ class SemScore(evaluate.Metric):
         encoded_preds = self.tokenizer(batch_preds, padding=True, truncation=True, return_tensors='pt')
         model_output_refs = self.model(**encoded_refs.to(device))
         model_output_preds = self.model(**encoded_preds.to(device))
-        batch_pooled_refs = self.mean_pooling(model_output_refs, encoded_refs['attention_mask'])
-        batch_pooled_preds = self.mean_pooling(model_output_preds, encoded_preds['attention_mask'])
+        batch_pooled_refs = self._mean_pooling(model_output_refs, encoded_refs['attention_mask'])
+        batch_pooled_preds = self._mean_pooling(model_output_preds, encoded_preds['attention_mask'])
         pooled_refs.append(batch_pooled_refs)
         pooled_preds.append(batch_pooled_preds)
         pooled_refs, pooled_preds = torch.cat(pooled_refs), torch.cat(pooled_preds)
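
Why the fix was needed: in the old code, mean_pooling was defined inside the class without a self parameter, yet the call sites invoked it as self.mean_pooling(...). Python passes the instance implicitly as the first positional argument on such calls, so every invocation would raise TypeError: mean_pooling() takes 2 positional arguments but 3 were given. Declaring the method a @staticmethod suppresses the implicit self, and the leading underscore in _mean_pooling marks it as an internal helper.

Below is a minimal, self-contained sketch of the fixed method, runnable outside semscore.py. The first lines of the body are verbatim from the diff; the final sum/clamp line follows the standard sentence-transformers mean-pooling recipe and is an assumption here, since the hunk cuts the method off after input_mask_expanded. PoolingDemo and the fake tensors are illustrative stand-ins, not part of the source.

import torch

class PoolingDemo:
    @staticmethod
    def _mean_pooling(model_output, attention_mask):
        """Mean pooling over all tokens - take attention mask into account for correct averaging"""
        token_embeddings = model_output[0]  # (batch, seq_len, hidden), as in model(...)[0]
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        # Assumed tail of the method (standard sentence-transformers recipe):
        # sum the embeddings of real tokens, then divide by the real-token count.
        return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(
            input_mask_expanded.sum(1), min=1e-9
        )

demo = PoolingDemo()
fake_output = (torch.randn(2, 5, 8),)                   # stand-in for a transformer output tuple
mask = torch.tensor([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]])  # 1 = real token, 0 = padding
pooled = demo._mean_pooling(fake_output, mask)           # static call binds no implicit self
print(pooled.shape)                                      # torch.Size([2, 8])

Calling demo._mean_pooling(...) now works exactly like the self._mean_pooling(...) call sites in the patched semscore.py: the staticmethod descriptor means no instance is prepended, so the two declared parameters line up with the two arguments passed.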