| """ Official evaluation script for CUAD dataset. """ | |
| import argparse | |
| import json | |
| import re | |
| import string | |
| import sys | |
| import numpy as np | |
| IOU_THRESH = 0.5 | |


def get_jaccard(prediction, ground_truth):
    remove_tokens = [".", ",", ";", ":"]
    for token in remove_tokens:
        ground_truth = ground_truth.replace(token, "")
        prediction = prediction.replace(token, "")
    ground_truth, prediction = ground_truth.lower(), prediction.lower()
    ground_truth, prediction = ground_truth.replace("/", " "), prediction.replace("/", " ")
    ground_truth, prediction = set(ground_truth.split(" ")), set(prediction.split(" "))
    intersection = ground_truth.intersection(prediction)
    union = ground_truth.union(prediction)
    jaccard = len(intersection) / len(union)
    return jaccard
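# Worked example of the token-level Jaccard above:
# get_jaccard("Company A; Company B", "Company A") compares {"company", "a", "b"} with
# {"company", "a"}, giving 2 / 3, which clears IOU_THRESH.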


def normalize_answer(s):
    """Lower text and remove punctuation, articles and extra whitespace."""

    def remove_articles(text):
        return re.sub(r"\b(a|an|the)\b", " ", text)

    def white_space_fix(text):
        return " ".join(text.split())

    def remove_punc(text):
        exclude = set(string.punctuation)
        return "".join(ch for ch in text if ch not in exclude)

    def lower(text):
        return text.lower()

    return white_space_fix(remove_articles(remove_punc(lower(s))))
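# Worked example of the normalization above:
# normalize_answer("The Supplier, Inc.") -> "supplier inc"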


def compute_precision_recall(predictions, ground_truths, qa_id):
    tp, fp, fn = 0, 0, 0

    substr_ok = "Parties" in qa_id

    # first check if ground truth is empty
    if len(ground_truths) == 0:
        if len(predictions) > 0:
            fp += len(predictions)  # false positive for each one
    else:
        for ground_truth in ground_truths:
            assert len(ground_truth) > 0
            # check if there is a match
            match_found = False
            for pred in predictions:
                if substr_ok:
                    is_match = get_jaccard(pred, ground_truth) >= IOU_THRESH or ground_truth in pred
                else:
                    is_match = get_jaccard(pred, ground_truth) >= IOU_THRESH
                if is_match:
                    match_found = True

            if match_found:
                tp += 1
            else:
                fn += 1

        # now also get any fps by looping through preds
        for pred in predictions:
            # Check if there's a match. If so, don't count it (don't want to double count based on the above),
            # but if there's no match, then this is a false positive.
            # (Note: we get the true positives in the above loop instead of this loop so that we don't double count
            # multiple predictions that are matched with the same answer.)
            match_found = False
            for ground_truth in ground_truths:
                assert len(ground_truth) > 0
                if substr_ok:
                    is_match = get_jaccard(pred, ground_truth) >= IOU_THRESH or ground_truth in pred
                else:
                    is_match = get_jaccard(pred, ground_truth) >= IOU_THRESH
                if is_match:
                    match_found = True

            if not match_found:
                fp += 1

    precision = tp / (tp + fp) if tp + fp > 0 else np.nan
    recall = tp / (tp + fn) if tp + fn > 0 else np.nan

    return precision, recall
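# Example of the counting above: one ground truth and two predictions, exactly one of which
# matches, gives tp=1, fn=0, fp=1, i.e. precision 0.5 and recall 1.0.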


def process_precisions(precisions):
    """
    Processes precisions to ensure that precision and recall don't both get worse.
    Assumes the list of precisions is sorted in order of increasing recall.
    """
    precision_best = precisions[::-1]
    for i in range(1, len(precision_best)):
        precision_best[i] = max(precision_best[i - 1], precision_best[i])
    precisions = precision_best[::-1]
    return precisions
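# This is the usual interpolated-precision envelope: a running maximum taken from the
# high-recall end. For example, [0.9, 0.4, 0.6, 0.5] (ordered by increasing recall)
# becomes [0.9, 0.6, 0.6, 0.5].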


def get_aupr(precisions, recalls):
    processed_precisions = process_precisions(precisions)
    aupr = np.trapz(processed_precisions, recalls)
    if np.isnan(aupr):
        return 0
    return aupr
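# np.trapz integrates the interpolated precision over recall with the trapezoidal rule,
# i.e. the area under the precision-recall curve (AUPR).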


def get_prec_at_recall(precisions, recalls, recall_thresh):
    """Assumes recalls are sorted in increasing order"""
    processed_precisions = process_precisions(precisions)
    prec_at_recall = 0
    for prec, recall in zip(processed_precisions, recalls):
        if recall >= recall_thresh:
            prec_at_recall = prec
            break
    return prec_at_recall
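# Returns the interpolated precision at the lowest recall value that reaches recall_thresh,
# or 0 if no question reaches that recall.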


def exact_match_score(prediction, ground_truth):
    return normalize_answer(prediction) == normalize_answer(ground_truth)


def metric_max_over_ground_truths(metric_fn, predictions, ground_truths):
    score = 0
    for pred in predictions:
        for ground_truth in ground_truths:
            score = metric_fn(pred, ground_truth)
            if score == 1:  # break the loop when one prediction matches the ground truth
                break
        if score == 1:
            break
    return score


def compute_score(dataset, predictions):
    f1 = exact_match = total = 0
    precisions = []
    recalls = []
    for article in dataset:
        for paragraph in article["paragraphs"]:
            for qa in paragraph["qas"]:
                total += 1
                if qa["id"] not in predictions:
                    message = "Unanswered question " + qa["id"] + " will receive score 0."
                    print(message, file=sys.stderr)
                    continue

                ground_truths = list(map(lambda x: x["text"], qa["answers"]))
                prediction = predictions[qa["id"]]

                precision, recall = compute_precision_recall(prediction, ground_truths, qa["id"])
                precisions.append(precision)
                recalls.append(recall)

                if precision == 0 and recall == 0:
                    f1 += 0
                else:
                    f1 += 2 * (precision * recall) / (precision + recall)

                exact_match += metric_max_over_ground_truths(exact_match_score, prediction, ground_truths)

    precisions = [x for _, x in sorted(zip(recalls, precisions))]
    recalls.sort()

    f1 = 100.0 * f1 / total
    exact_match = 100.0 * exact_match / total
    aupr = get_aupr(precisions, recalls)

    prec_at_90_recall = get_prec_at_recall(precisions, recalls, recall_thresh=0.9)
    prec_at_80_recall = get_prec_at_recall(precisions, recalls, recall_thresh=0.8)

    return {
        "exact_match": exact_match,
        "f1": f1,
        "aupr": aupr,
        "prec_at_80_recall": prec_at_80_recall,
        "prec_at_90_recall": prec_at_90_recall,
    }
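# `dataset` is the "data" list of a SQuAD-style CUAD JSON file; `predictions` maps each
# qa["id"] to a list of predicted clause strings for that question (an empty list means
# no clause was extracted).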


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Evaluation for CUAD")
    parser.add_argument("dataset_file", help="Dataset file")
    parser.add_argument("prediction_file", help="Prediction File")
    args = parser.parse_args()

    with open(args.dataset_file) as dataset_file:
        dataset_json = json.load(dataset_file)
        dataset = dataset_json["data"]

    with open(args.prediction_file) as prediction_file:
        predictions = json.load(prediction_file)

    print(json.dumps(compute_score(dataset, predictions)))
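# Example invocation (a sketch; the script and file names below are placeholders):
#   python evaluate.py cuad_test.json predictions.json
# where cuad_test.json holds the CUAD data under a top-level "data" key and predictions.json
# looks like {"<qa_id>": ["predicted clause text", "..."], ...}.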