torch.nn.functional fix
Browse files- semscore.py +1 -1
semscore.py
CHANGED
@@ -129,7 +129,7 @@ class SemScore(evaluate.Metric):
|
|
129 |
pooled_preds.append(batch_pooled_preds)
|
130 |
pooled_refs, pooled_preds = torch.cat(pooled_refs), torch.cat(pooled_preds)
|
131 |
|
132 |
-
similarities = torch.nn.functional.
|
133 |
similarities = similarities * 100
|
134 |
semscore = torch.mean(similarities)
|
135 |
|
|
|
129 |
pooled_preds.append(batch_pooled_preds)
|
130 |
pooled_refs, pooled_preds = torch.cat(pooled_refs), torch.cat(pooled_preds)
|
131 |
|
132 |
+
similarities = torch.nn.functional.cosine_similarity(pooled_refs, pooled_preds)
|
133 |
similarities = similarities * 100
|
134 |
semscore = torch.mean(similarities)
|
135 |
|