SentenceTransformer
				/
					sentence_transformers
								
							/cross_encoder
								
							/evaluation
								
							/CEBinaryAccuracyEvaluator.py
					
			| import logging | |
| import os | |
| import csv | |
| from typing import List | |
| from ... import InputExample | |
| import numpy as np | |
| logger = logging.getLogger(__name__) | |
class CEBinaryAccuracyEvaluator:
    """
    Accuracy evaluator for use with the CrossEncoder class.

    It is designed for CrossEncoders with a single output. Each predicted score
    is turned into a binary label with a fixed threshold (score > threshold
    -> 1, otherwise 0) and the accuracy of those labels against the gold labels
    is reported.

    See CEBinaryClassificationEvaluator for an evaluator that determines the
    optimal threshold automatically.
    """

    def __init__(self, sentence_pairs: List[List[str]], labels: List[int], name: str = '', threshold: float = 0.5, write_csv: bool = True):
        """
        :param sentence_pairs: Text pairs that are passed to ``model.predict``.
        :param labels: Gold labels (0 or 1), one per entry of ``sentence_pairs``.
        :param name: Optional dataset name; used in the log message and as a
            suffix of the CSV file name.
        :param threshold: Scores strictly greater than this value count as label 1.
        :param write_csv: If True, results are appended to a CSV file whenever
            an ``output_path`` is given to ``__call__``.
        """
        self.sentence_pairs = sentence_pairs
        self.labels = labels
        self.name = name
        self.threshold = threshold

        self.csv_file = "CEBinaryAccuracyEvaluator" + ("_" + name if name else '') + "_results.csv"
        self.csv_headers = ["epoch", "steps", "Accuracy"]
        self.write_csv = write_csv

    @classmethod
    def from_input_examples(cls, examples: List["InputExample"], **kwargs):
        """
        Alternate constructor: build the evaluator from objects exposing
        ``texts`` and ``label`` attributes.

        :param examples: Examples whose ``texts`` become the sentence pairs and
            whose ``label`` become the gold labels.
        :param kwargs: Forwarded unchanged to ``__init__`` (e.g. ``name``,
            ``threshold``, ``write_csv``).
        :return: A new evaluator instance.
        """
        sentence_pairs = [example.texts for example in examples]
        labels = [example.label for example in examples]
        return cls(sentence_pairs, labels, **kwargs)

    def __call__(self, model, output_path: str = None, epoch: int = -1, steps: int = -1) -> float:
        """
        Score the stored sentence pairs with ``model`` and return the accuracy.

        :param model: Object with a ``predict(sentence_pairs, convert_to_numpy=True,
            show_progress_bar=False)`` method returning one score per pair.
        :param output_path: Directory for the CSV results file. It is created if
            missing. Nothing is written when this is None or ``write_csv`` is False.
        :param epoch: Epoch number, used for logging and the CSV row (-1 = not given).
        :param steps: Step number within the epoch, used for logging and the
            CSV row (-1 = not given).
        :return: Fraction of pairs whose thresholded prediction equals the gold label.
        """
        if epoch == -1:
            out_txt = ":"
        elif steps == -1:
            out_txt = " after epoch {}:".format(epoch)
        else:
            out_txt = " in epoch {} after {} steps:".format(epoch, steps)

        logger.info("CEBinaryAccuracyEvaluator: Evaluating the model on %s dataset%s", self.name, out_txt)

        pred_scores = model.predict(self.sentence_pairs, convert_to_numpy=True, show_progress_bar=False)
        pred_labels = np.asarray(pred_scores) > self.threshold

        assert len(pred_labels) == len(self.labels), \
            "Got {} predictions for {} labels".format(len(pred_labels), len(self.labels))

        # Convert to a built-in float so the return value matches the annotation
        # and is serialised identically regardless of the installed NumPy version.
        acc = float(np.sum(pred_labels == np.asarray(self.labels)) / len(self.labels))

        logger.info("Accuracy: %.2f", acc * 100)

        if output_path is not None and self.write_csv:
            # Make sure the target directory exists before opening the file.
            os.makedirs(output_path, exist_ok=True)
            csv_path = os.path.join(output_path, self.csv_file)
            output_file_exists = os.path.isfile(csv_path)
            # newline="" is required by the csv module; without it Windows
            # translates the writer's "\r\n" into "\r\r\n" (blank rows).
            with open(csv_path, mode="a" if output_file_exists else 'w', newline="", encoding="utf-8") as f:
                writer = csv.writer(f)
                if not output_file_exists:
                    writer.writerow(self.csv_headers)

                writer.writerow([epoch, steps, acc])

        return acc