mean_pooling fix
Browse files
- semscore.py +4 -3
semscore.py
CHANGED
|
@@ -89,7 +89,8 @@ class SemScore(evaluate.Metric):
|
|
| 89 |
self.model.eval()
|
| 90 |
self.tokenizer = AutoTokenizer.from_pretrained(checkpoint)
|
| 91 |
|
| 92 |
-
|
|
|
|
| 93 |
"""Mean pooling over all tokens - take attention mask into account for correct averaging"""
|
| 94 |
token_embeddings = model_output[0]
|
| 95 |
input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
|
|
@@ -122,8 +123,8 @@ class SemScore(evaluate.Metric):
|
|
| 122 |
encoded_preds = self.tokenizer(batch_preds, padding=True, truncation=True, return_tensors='pt')
|
| 123 |
model_output_refs = self.model(**encoded_refs.to(device))
|
| 124 |
model_output_preds = self.model(**encoded_preds.to(device))
|
| 125 |
-
batch_pooled_refs = self.
|
| 126 |
-
batch_pooled_preds = self.
|
| 127 |
pooled_refs.append(batch_pooled_refs)
|
| 128 |
pooled_preds.append(batch_pooled_preds)
|
| 129 |
pooled_refs, pooled_preds = torch.cat(pooled_refs), torch.cat(pooled_preds)
|
|
|
|
| 89 |
self.model.eval()
|
| 90 |
self.tokenizer = AutoTokenizer.from_pretrained(checkpoint)
|
| 91 |
|
| 92 |
+
@staticmethod
|
| 93 |
+
def _mean_pooling(model_output, attention_mask):
|
| 94 |
"""Mean pooling over all tokens - take attention mask into account for correct averaging"""
|
| 95 |
token_embeddings = model_output[0]
|
| 96 |
input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
|
|
|
|
| 123 |
encoded_preds = self.tokenizer(batch_preds, padding=True, truncation=True, return_tensors='pt')
|
| 124 |
model_output_refs = self.model(**encoded_refs.to(device))
|
| 125 |
model_output_preds = self.model(**encoded_preds.to(device))
|
| 126 |
+
batch_pooled_refs = self._mean_pooling(model_output_refs, encoded_refs['attention_mask'])
|
| 127 |
+
batch_pooled_preds = self._mean_pooling(model_output_preds, encoded_preds['attention_mask'])
|
| 128 |
pooled_refs.append(batch_pooled_refs)
|
| 129 |
pooled_preds.append(batch_pooled_preds)
|
| 130 |
pooled_refs, pooled_preds = torch.cat(pooled_refs), torch.cat(pooled_preds)
|