# sentence_transformers/cross_encoder/evaluation/CEBinaryClassificationEvaluator.py
import csv
import logging
import os
from typing import List

import numpy as np
from sklearn.metrics import average_precision_score

from ... import InputExample
from ...evaluation import BinaryClassificationEvaluator

logger = logging.getLogger(__name__)
class CEBinaryClassificationEvaluator:
    """
    This evaluator can be used with the CrossEncoder class. Given sentence pairs and binary labels (0 and 1),
    it computes the average precision as well as the best possible accuracy and F1 score, together with the
    decision thresholds that achieve them.
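
    Example (a minimal usage sketch; the model name and sentence pairs below are
    illustrative assumptions, not part of this module)::

        from sentence_transformers import CrossEncoder

        model = CrossEncoder("cross-encoder/stsb-distilroberta-base")
        evaluator = CEBinaryClassificationEvaluator(
            sentence_pairs=[
                ["A man is eating food.", "A man eats something."],
                ["A man is eating food.", "The girl is carrying a baby."],
            ],
            labels=[1, 0],
            name="dev",
        )
        average_precision = evaluator(model)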
""" | |
    def __init__(self, sentence_pairs: List[List[str]], labels: List[int], name: str = '',
                 show_progress_bar: bool = False, write_csv: bool = True):
        assert len(sentence_pairs) == len(labels)
        for label in labels:
            assert label == 0 or label == 1

        self.sentence_pairs = sentence_pairs
        self.labels = np.asarray(labels)
        self.name = name

        # Passing show_progress_bar=None explicitly enables the progress bar
        # whenever the logger runs at INFO or DEBUG level.
        if show_progress_bar is None:
            show_progress_bar = logger.getEffectiveLevel() in (logging.INFO, logging.DEBUG)
        self.show_progress_bar = show_progress_bar

        self.csv_file = "CEBinaryClassificationEvaluator" + ("_" + name if name else '') + "_results.csv"
        self.csv_headers = ["epoch", "steps", "Accuracy", "Accuracy_Threshold", "F1", "F1_Threshold",
                            "Precision", "Recall", "Average_Precision"]
        self.write_csv = write_csv
    @classmethod
    def from_input_examples(cls, examples: List[InputExample], **kwargs):
        sentence_pairs = []
        labels = []

        for example in examples:
            sentence_pairs.append(example.texts)
            labels.append(example.label)
        return cls(sentence_pairs, labels, **kwargs)
    def __call__(self, model, output_path: str = None, epoch: int = -1, steps: int = -1) -> float:
        if epoch != -1:
            if steps == -1:
                out_txt = " after epoch {}:".format(epoch)
            else:
                out_txt = " in epoch {} after {} steps:".format(epoch, steps)
        else:
            out_txt = ":"

        logger.info("CEBinaryClassificationEvaluator: Evaluating the model on " + self.name + " dataset" + out_txt)
        pred_scores = model.predict(self.sentence_pairs, convert_to_numpy=True, show_progress_bar=self.show_progress_bar)

        acc, acc_threshold = BinaryClassificationEvaluator.find_best_acc_and_threshold(pred_scores, self.labels, True)
        f1, precision, recall, f1_threshold = BinaryClassificationEvaluator.find_best_f1_and_threshold(pred_scores, self.labels, True)
        ap = average_precision_score(self.labels, pred_scores)

        logger.info("Accuracy: {:.2f}\t(Threshold: {:.4f})".format(acc * 100, acc_threshold))
        logger.info("F1: {:.2f}\t(Threshold: {:.4f})".format(f1 * 100, f1_threshold))
        logger.info("Precision: {:.2f}".format(precision * 100))
        logger.info("Recall: {:.2f}".format(recall * 100))
        logger.info("Average Precision: {:.2f}\n".format(ap * 100))
        if output_path is not None and self.write_csv:
            csv_path = os.path.join(output_path, self.csv_file)
            output_file_exists = os.path.isfile(csv_path)
            # newline="" lets the csv module control line endings (avoids blank rows on Windows)
            with open(csv_path, mode="a" if output_file_exists else 'w', encoding="utf-8", newline="") as f:
                writer = csv.writer(f)
                if not output_file_exists:
                    writer.writerow(self.csv_headers)
                writer.writerow([epoch, steps, acc, acc_threshold, f1, f1_threshold, precision, recall, ap])

        return ap
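
# A minimal end-to-end sketch of how this evaluator is typically wired up via
# from_input_examples(). Everything below is illustrative: the model name, the toy
# pairs/labels, and the output directory are assumptions for demonstration only.
#
#     from sentence_transformers import CrossEncoder, InputExample
#     from sentence_transformers.cross_encoder.evaluation import CEBinaryClassificationEvaluator
#
#     dev_samples = [
#         InputExample(texts=["A man is eating food.", "A man eats something."], label=1),
#         InputExample(texts=["A man is eating food.", "The girl is carrying a baby."], label=0),
#     ]
#     evaluator = CEBinaryClassificationEvaluator.from_input_examples(dev_samples, name="dev")
#
#     model = CrossEncoder("distilroberta-base", num_labels=1)
#     average_precision = evaluator(model, output_path="output/")  # also appends a row to the results CSV
#
# During training, the same evaluator can be passed to CrossEncoder.fit(evaluator=...)
# to evaluate the model periodically and after every epoch.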