File size: 3,464 Bytes
2359bda
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
"""
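Tests the evaluators in sentence_transformers.evaluation: the best-threshold F1 and accuracy computed by BinaryClassificationEvaluator, and loading/running LabelAccuracyEvaluator and ParaphraseMiningEvaluator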
"""
from sentence_transformers import SentenceTransformer, evaluation, util, losses, LoggingHandler
import logging
import unittest
from sklearn.metrics import f1_score, accuracy_score
import numpy as np
import gzip
import csv
from sentence_transformers import InputExample
from torch.utils.data import DataLoader
import os

class EvaluatorTest(unittest.TestCase):
    """Tests for the evaluators in ``sentence_transformers.evaluation``."""

    # Fixed seed so a failing run of the randomized threshold tests can be reproduced exactly.
    RANDOM_SEED = 42
    NUM_RANDOM_SAMPLES = 1000

    @classmethod
    def _random_binary_data(cls):
        """Return reproducible random ``(labels, scores)`` arrays for the threshold tests.

        ``labels`` holds integers in {0, 1}; ``scores`` holds standard-normal floats.
        Continuous scores make exact ties practically impossible, so labelling with
        ``score >= threshold`` below does not depend on how ties would be resolved.
        """
        rng = np.random.RandomState(cls.RANDOM_SEED)
        labels = rng.randint(0, 2, cls.NUM_RANDOM_SAMPLES)
        scores = rng.randn(cls.NUM_RANDOM_SAMPLES)
        return labels, scores

    def test_BinaryClassificationEvaluator_find_best_f1_and_threshold(self):
        """Tests that the F1 score for the computed threshold is correct"""
        y_true, y_pred_cosine = self._random_binary_data()
        best_f1, _best_precision, _best_recall, threshold = evaluation.BinaryClassificationEvaluator.find_best_f1_and_threshold(
            y_pred_cosine, y_true, high_score_more_similar=True)
        # Higher scores mean "more similar", so everything at or above the threshold is labelled positive.
        y_pred_labels = (y_pred_cosine >= threshold).astype(int)
        sklearn_f1score = f1_score(y_true, y_pred_labels)
        self.assertAlmostEqual(best_f1, sklearn_f1score, delta=1e-6)

    def test_BinaryClassificationEvaluator_find_best_accuracy_and_threshold(self):
        """Tests that the Acc score for the computed threshold is correct"""
        y_true, y_pred_cosine = self._random_binary_data()
        max_acc, threshold = evaluation.BinaryClassificationEvaluator.find_best_acc_and_threshold(
            y_pred_cosine, y_true, high_score_more_similar=True)
        # Same labelling rule as the F1 test: score >= threshold -> positive.
        y_pred_labels = (y_pred_cosine >= threshold).astype(int)
        sklearn_acc = accuracy_score(y_true, y_pred_labels)
        self.assertAlmostEqual(max_acc, sklearn_acc, delta=1e-6)

    def test_LabelAccuracyEvaluator(self):
        """Tests that the LabelAccuracyEvaluator can be loaded correctly"""
        model = SentenceTransformer('paraphrase-distilroberta-base-v1')

        nli_dataset_path = 'datasets/AllNLI.tsv.gz'
        if not os.path.exists(nli_dataset_path):
            util.http_get('https://sbert.net/datasets/AllNLI.tsv.gz', nli_dataset_path)

        label2int = {"contradiction": 0, "entailment": 1, "neutral": 2}
        dev_samples = []
        with gzip.open(nli_dataset_path, 'rt', encoding='utf8') as fIn:
            reader = csv.DictReader(fIn, delimiter='\t', quoting=csv.QUOTE_NONE)
            for row in reader:
                # Only the first 100 'train' rows are used: this test checks that the
                # evaluator runs end to end, not model quality.
                if row['split'] == 'train':
                    label_id = label2int[row['label']]
                    dev_samples.append(InputExample(texts=[row['sentence1'], row['sentence2']], label=label_id))
                    if len(dev_samples) >= 100:
                        break

        train_loss = losses.SoftmaxLoss(model=model, sentence_embedding_dimension=model.get_sentence_embedding_dimension(), num_labels=len(label2int))

        dev_dataloader = DataLoader(dev_samples, shuffle=False, batch_size=16)
        evaluator = evaluation.LabelAccuracyEvaluator(dev_dataloader, softmax_model=train_loss)
        acc = evaluator(model)
        # The softmax head is created above and never trained, so only a loose sanity bound is checked.
        self.assertGreater(acc, 0.2)

    def test_ParaphraseMiningEvaluator(self):
        """Tests that the ParaphraseMiningEvaluator can be loaded"""
        model = SentenceTransformer('paraphrase-distilroberta-base-v1')
        sentences = {0: "Hello World", 1: "Hello World!", 2: "The cat is on the table", 3: "On the table the cat is"}
        # (0, 1) and (2, 3) are the gold paraphrase pairs among the four sentences.
        data_eval = evaluation.ParaphraseMiningEvaluator(sentences, [(0, 1), (2, 3)])
        score = data_eval(model)
        self.assertGreater(score, 0.99)