import csv
import logging
import os
from typing import List, Optional

import numpy as np
from sklearn.metrics import average_precision_score

from ... import InputExample
from ...evaluation import BinaryClassificationEvaluator

logger = logging.getLogger(__name__)


class CEBinaryClassificationEvaluator:
    """
    This evaluator can be used with the CrossEncoder class. Given sentence pairs and binary labels (0 and 1),
    it computes the average precision and the best possible F1 score (together with the best accuracy,
    precision, recall and the thresholds at which they are reached).

    Results are logged and, optionally, appended as one row per call to a CSV file.
    """

    def __init__(
        self,
        sentence_pairs: List[List[str]],
        labels: List[int],
        name: str = '',
        show_progress_bar: Optional[bool] = False,
        write_csv: bool = True,
    ):
        """
        :param sentence_pairs: The text pairs that are passed to ``model.predict``.
        :param labels: One binary label (0 or 1) per sentence pair.
        :param name: Optional dataset name; used in the log message and in the CSV file name.
        :param show_progress_bar: Forwarded to ``model.predict``. If None, the progress bar is
            enabled when the effective log level is INFO or DEBUG.
        :param write_csv: If True, ``__call__`` appends its results to a CSV file in ``output_path``.
        :raises AssertionError: If the lengths differ or a label is neither 0 nor 1.
        """
        assert len(sentence_pairs) == len(labels)
        for label in labels:
            assert (label == 0 or label == 1)

        self.sentence_pairs = sentence_pairs
        self.labels = np.asarray(labels)
        self.name = name

        if show_progress_bar is None:
            show_progress_bar = (logger.getEffectiveLevel() == logging.INFO or logger.getEffectiveLevel() == logging.DEBUG)
        self.show_progress_bar = show_progress_bar

        self.csv_file = "CEBinaryClassificationEvaluator" + ("_" + name if name else '') + "_results.csv"
        self.csv_headers = ["epoch", "steps", "Accuracy", "Accuracy_Threshold", "F1", "F1_Threshold", "Precision", "Recall", "Average_Precision"]
        self.write_csv = write_csv

    @classmethod
    def from_input_examples(cls, examples: List[InputExample], **kwargs):
        """
        Alternate constructor: build the evaluator from a list of examples.

        Each example contributes its ``texts`` as the sentence pair and its ``label`` as the label.
        Any extra keyword arguments are passed on to ``__init__``.
        """
        sentence_pairs = []
        labels = []

        for example in examples:
            sentence_pairs.append(example.texts)
            labels.append(example.label)
        return cls(sentence_pairs, labels, **kwargs)

    def __call__(self, model, output_path: Optional[str] = None, epoch: int = -1, steps: int = -1) -> float:
        """
        Score all sentence pairs with ``model`` and compute the classification metrics.

        :param model: Object exposing ``predict(sentence_pairs, convert_to_numpy=..., show_progress_bar=...)``.
        :param output_path: Directory for the results CSV. Nothing is written if this is None
            or if ``write_csv`` is False.
        :param epoch: Epoch number, used for the log message and the CSV row; -1 means "not given".
        :param steps: Step number within the epoch, used the same way; -1 means "not given".
        :return: The average precision score.
        """
        if epoch != -1:
            if steps == -1:
                out_txt = " after epoch {}:".format(epoch)
            else:
                out_txt = " in epoch {} after {} steps:".format(epoch, steps)
        else:
            out_txt = ":"

        logger.info("CEBinaryClassificationEvaluator: Evaluating the model on " + self.name + " dataset" + out_txt)
        pred_scores = model.predict(self.sentence_pairs, convert_to_numpy=True, show_progress_bar=self.show_progress_bar)

        # NOTE(review): the trailing ``True`` is presumably the "higher score means positive"
        # flag of the BinaryClassificationEvaluator helpers -- confirm against their signatures.
        acc, acc_threshold = BinaryClassificationEvaluator.find_best_acc_and_threshold(pred_scores, self.labels, True)
        f1, precision, recall, f1_threshold = BinaryClassificationEvaluator.find_best_f1_and_threshold(pred_scores, self.labels, True)
        ap = average_precision_score(self.labels, pred_scores)

        logger.info("Accuracy:           {:.2f}\t(Threshold: {:.4f})".format(acc * 100, acc_threshold))
        logger.info("F1:                 {:.2f}\t(Threshold: {:.4f})".format(f1 * 100, f1_threshold))
        logger.info("Precision:          {:.2f}".format(precision * 100))
        logger.info("Recall:             {:.2f}".format(recall * 100))
        logger.info("Average Precision:  {:.2f}\n".format(ap * 100))

        if output_path is not None and self.write_csv:
            csv_path = os.path.join(output_path, self.csv_file)
            output_file_exists = os.path.isfile(csv_path)
            # newline="" lets the csv module control line endings itself; without it,
            # platforms that translate "\n" produce blank rows between records.
            with open(csv_path, mode="a" if output_file_exists else 'w', newline="", encoding="utf-8") as f:
                writer = csv.writer(f)
                # The header is written only when the file is first created.
                if not output_file_exists:
                    writer.writerow(self.csv_headers)

                writer.writerow([epoch, steps, acc, acc_threshold, f1, f1_threshold, precision, recall, ap])

        return ap