"""
Tests the correct computation of evaluation scores from BinaryClassificationEvaluator
"""
from sentence_transformers import SentenceTransformer, InputExample, evaluation, losses, util
import unittest
from sklearn.metrics import f1_score, accuracy_score
import numpy as np
import gzip
import csv
from torch.utils.data import DataLoader
import os


class EvaluatorTest(unittest.TestCase):
    def test_BinaryClassificationEvaluator_find_best_f1_and_threshold(self):
        """Tests that the F1 score for the computed threshold is correct"""
        y_true = np.random.randint(0, 2, 1000)
        y_pred_cosine = np.random.randn(1000)

        best_f1, best_precision, best_recall, threshold = evaluation.BinaryClassificationEvaluator.find_best_f1_and_threshold(
            y_pred_cosine, y_true, high_score_more_similar=True
        )
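
        # Re-derive hard labels from the returned threshold and verify that
        # scikit-learn's F1 score for those labels matches the reported best F1.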
        y_pred_labels = [1 if pred >= threshold else 0 for pred in y_pred_cosine]
        sklearn_f1score = f1_score(y_true, y_pred_labels)
        assert np.abs(best_f1 - sklearn_f1score) < 1e-6

    def test_BinaryClassificationEvaluator_find_best_accuracy_and_threshold(self):
        """Tests that the accuracy for the computed threshold is correct"""
        y_true = np.random.randint(0, 2, 1000)
        y_pred_cosine = np.random.randn(1000)

        max_acc, threshold = evaluation.BinaryClassificationEvaluator.find_best_acc_and_threshold(
            y_pred_cosine, y_true, high_score_more_similar=True
        )
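
        # As above: labels thresholded at the returned value must reproduce the
        # accuracy reported by the evaluator.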
        y_pred_labels = [1 if pred >= threshold else 0 for pred in y_pred_cosine]
        sklearn_acc = accuracy_score(y_true, y_pred_labels)
        assert np.abs(max_acc - sklearn_acc) < 1e-6

    def test_LabelAccuracyEvaluator(self):
        """Tests that the LabelAccuracyEvaluator can be loaded correctly"""
        model = SentenceTransformer('paraphrase-distilroberta-base-v1')

        nli_dataset_path = 'datasets/AllNLI.tsv.gz'
        if not os.path.exists(nli_dataset_path):
            util.http_get('https://sbert.net/datasets/AllNLI.tsv.gz', nli_dataset_path)
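
        # Build a small dev set from the first 100 train rows of AllNLI.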
        label2int = {"contradiction": 0, "entailment": 1, "neutral": 2}
        dev_samples = []
        with gzip.open(nli_dataset_path, 'rt', encoding='utf8') as fIn:
            reader = csv.DictReader(fIn, delimiter='\t', quoting=csv.QUOTE_NONE)
            for row in reader:
                if row['split'] == 'train':
                    label_id = label2int[row['label']]
                    dev_samples.append(InputExample(texts=[row['sentence1'], row['sentence2']], label=label_id))
                    if len(dev_samples) >= 100:
                        break
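
        # LabelAccuracyEvaluator evaluates through the classification head, so the
        # SoftmaxLoss module that holds it is passed in as softmax_model.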
        train_loss = losses.SoftmaxLoss(
            model=model,
            sentence_embedding_dimension=model.get_sentence_embedding_dimension(),
            num_labels=len(label2int),
        )

        dev_dataloader = DataLoader(dev_samples, shuffle=False, batch_size=16)
        evaluator = evaluation.LabelAccuracyEvaluator(dev_dataloader, softmax_model=train_loss)
        acc = evaluator(model)
        assert acc > 0.2

    def test_ParaphraseMiningEvaluator(self):
        """Tests that the ParaphraseMiningEvaluator can be loaded"""
        model = SentenceTransformer('paraphrase-distilroberta-base-v1')
        sentences = {0: "Hello World", 1: "Hello World!", 2: "The cat is on the table", 3: "On the table the cat is"}
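        # Pairs (0, 1) and (2, 3) are the gold duplicates; a paraphrase model
        # should mine them with near-perfect average precision.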
        data_eval = evaluation.ParaphraseMiningEvaluator(sentences, [(0, 1), (2, 3)])
        score = data_eval(model)
        assert score > 0.99
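

if __name__ == '__main__':
    # Allow running this test module directly; under pytest this block is ignored.
    unittest.main()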