# tests/test_evaluator.py
"""
Tests the correct computation of evaluation scores from BinaryClassificationEvaluator,
LabelAccuracyEvaluator and ParaphraseMiningEvaluator.
"""
import csv
import gzip
import os
import unittest

import numpy as np
from sklearn.metrics import accuracy_score, f1_score
from torch.utils.data import DataLoader

from sentence_transformers import InputExample, SentenceTransformer, evaluation, losses, util

class EvaluatorTest(unittest.TestCase):
    def test_BinaryClassificationEvaluator_find_best_f1_and_threshold(self):
        """Tests that the F1 score for the computed threshold is correct"""
        y_true = np.random.randint(0, 2, 1000)
        y_pred_cosine = np.random.randn(1000)
        best_f1, best_precision, best_recall, threshold = evaluation.BinaryClassificationEvaluator.find_best_f1_and_threshold(
            y_pred_cosine, y_true, high_score_more_similar=True
        )
        y_pred_labels = [1 if pred >= threshold else 0 for pred in y_pred_cosine]
        sklearn_f1score = f1_score(y_true, y_pred_labels)
        assert np.abs(best_f1 - sklearn_f1score) < 1e-6
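        # Why the assert holds: find_best_f1_and_threshold scans the cut points of the
        # sorted scores and returns a threshold between two consecutive scores, so
        # re-thresholding the raw scores at that value reproduces the evaluator's
        # positive/negative split and sklearn's f1_score must agree.
        #
        # Hedged usage sketch (assumed constructor signature, not exercised by this test):
        # in practice the evaluator is driven by a model on labeled sentence pairs rather
        # than on precomputed scores, e.g.:
        #
        #     bin_eval = evaluation.BinaryClassificationEvaluator(
        #         sentences1=["Hello World", "The cat sits outside"],
        #         sentences2=["Hello World!", "A man plays guitar"],
        #         labels=[1, 0],
        #     )
        #     bin_eval(model)  # main score: best metric over the scanned thresholds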

    def test_BinaryClassificationEvaluator_find_best_accuracy_and_threshold(self):
        """Tests that the accuracy for the computed threshold is correct"""
        y_true = np.random.randint(0, 2, 1000)
        y_pred_cosine = np.random.randn(1000)
        max_acc, threshold = evaluation.BinaryClassificationEvaluator.find_best_acc_and_threshold(
            y_pred_cosine, y_true, high_score_more_similar=True
        )
        y_pred_labels = [1 if pred >= threshold else 0 for pred in y_pred_cosine]
        sklearn_acc = accuracy_score(y_true, y_pred_labels)
        assert np.abs(max_acc - sklearn_acc) < 1e-6
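        # Same reasoning as the F1 test: find_best_acc_and_threshold picks the threshold
        # (a midpoint between consecutive sorted scores) that maximizes accuracy, so
        # applying that threshold to the raw scores reproduces the evaluator's split and
        # sklearn's accuracy_score matches to floating-point precision.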

    def test_LabelAccuracyEvaluator(self):
        """Tests that the LabelAccuracyEvaluator can be loaded correctly"""
        model = SentenceTransformer('paraphrase-distilroberta-base-v1')

        nli_dataset_path = 'datasets/AllNLI.tsv.gz'
        if not os.path.exists(nli_dataset_path):
            util.http_get('https://sbert.net/datasets/AllNLI.tsv.gz', nli_dataset_path)

        label2int = {"contradiction": 0, "entailment": 1, "neutral": 2}
        dev_samples = []
        with gzip.open(nli_dataset_path, 'rt', encoding='utf8') as fIn:
            reader = csv.DictReader(fIn, delimiter='\t', quoting=csv.QUOTE_NONE)
            for row in reader:
                if row['split'] == 'train':
                    label_id = label2int[row['label']]
                    dev_samples.append(InputExample(texts=[row['sentence1'], row['sentence2']], label=label_id))
                    if len(dev_samples) >= 100:
                        break

        train_loss = losses.SoftmaxLoss(
            model=model,
            sentence_embedding_dimension=model.get_sentence_embedding_dimension(),
            num_labels=len(label2int),
        )

        dev_dataloader = DataLoader(dev_samples, shuffle=False, batch_size=16)
        evaluator = evaluation.LabelAccuracyEvaluator(dev_dataloader, softmax_model=train_loss)
        acc = evaluator(model)
        assert acc > 0.2
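        # Hedged sketch (assumes the classic SentenceTransformer.fit API): during real
        # training the same evaluator is typically passed to fit() so label accuracy is
        # reported after each epoch, e.g.:
        #
        #     train_dataloader = DataLoader(dev_samples, shuffle=True, batch_size=16)
        #     model.fit(
        #         train_objectives=[(train_dataloader, train_loss)],
        #         evaluator=evaluator,
        #         epochs=1,
        #         warmup_steps=100,
        #     )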

    def test_ParaphraseMiningEvaluator(self):
        """Tests that the ParaphraseMiningEvaluator can be loaded"""
        model = SentenceTransformer('paraphrase-distilroberta-base-v1')
        sentences = {
            0: "Hello World",
            1: "Hello World!",
            2: "The cat is on the table",
            3: "On the table the cat is",
        }
        data_eval = evaluation.ParaphraseMiningEvaluator(sentences, [(0, 1), (2, 3)])
        score = data_eval(model)
        assert score > 0.99
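        # The score is the average precision of the mined pairs against the gold
        # duplicates (0, 1) and (2, 3). Hedged sketch (assumes util.paraphrase_mining
        # keeps this signature): the underlying mining step can also be run directly:
        #
        #     pairs = util.paraphrase_mining(model, list(sentences.values()), top_k=1)
        #     # -> list of [cosine_score, idx_i, idx_j], sorted by decreasing score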