File size: 1,951 Bytes
54fa0c8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
# source: https://github.com/mponty/bigcode-dataset/tree/main/pii/ner_model_training/utils by @mponty
import numpy as np
from evaluate import load
from scipy.special import softmax
from sklearn.metrics import average_precision_score

# Load the "seqeval" metric once at import time; compute_metrics reuses this
# module-level instance on every call instead of loading it per evaluation.
# NOTE(review): `load` presumably fetches/caches the metric script on first
# use, so importing this module may need network access — confirm.
_seqeval_metric = load("seqeval")


# Entity categories that receive B-/I- tags in the label maps below.
CATEGORIES = [
    "NAME",
    "EMAIL",
    "EMAIL_EXAMPLE",
    "USERNAME",
    "KEY",
    "IP_ADDRESS",
    "PASSWORD",
]
# Categories kept out of CATEGORIES, so they get no entry in LABEL2ID.
IGNORE_CLASS = ["AMBIGUOUS", "ID", "NAME_EXAMPLE", "USERNAME_EXAMPLE"]

# "O" is id 0; each category then takes the next two consecutive ids,
# B-<cat> first and I-<cat> right after it.
LABEL2ID = {"O": 0}
for cat in CATEGORIES:
    # Both lengths are read before the update runs, hence the explicit +1.
    LABEL2ID.update(
        {f"B-{cat}": len(LABEL2ID), f"I-{cat}": len(LABEL2ID) + 1}
    )

# Inverse lookup: integer id -> tag string.
ID2LABEL = dict(zip(LABEL2ID.values(), LABEL2ID.keys()))


def compute_ap(pred, truth):
    """Return the average precision of "is this position an entity?".

    Each position is scored with one minus the softmax probability of class
    index 0 (the "O" label in LABEL2ID), i.e. the probability mass on every
    other class. Positions whose label is -100 are dropped; the remaining
    labels are binarised as ``label != 0``.

    Args:
        pred: Array-like of raw class scores; softmax is taken over the last
            axis. Presumably shaped (batch, seq_len, num_labels) — confirm
            against the caller.
        truth: Array-like of integer label ids with one entry per scored
            position, using -100 for positions to ignore.

    Returns:
        The value returned by ``average_precision_score`` for the binarised
        labels and the non-zero-class probabilities.
    """
    # Probability that a position belongs to any class other than index 0.
    entity_proba = (1 - softmax(pred, axis=-1)[..., 0]).flatten()
    gold = np.array(truth).flatten()

    # Keep only positions that carry a real label.
    keep = gold != -100
    return average_precision_score(gold[keep] != 0, entity_proba[keep])


def compute_metrics(p):
    """Compute aggregate and per-category NER metrics for one evaluation pass.

    Args:
        p: A pair that unpacks into ``(predictions, labels)``. ``predictions``
            holds raw class scores with the classes on axis 2; ``labels``
            holds integer label ids, with -100 marking positions to skip
            (special tokens). Presumably a transformers ``EvalPrediction`` —
            confirm against the trainer setup.

    Returns:
        A flat dict with ``"Avg.Precision"`` (from ``compute_ap``), the
        seqeval overall ``"precision"``, ``"recall"`` and ``"f1"``, plus one
        entry per remaining seqeval result key mapped to that entry's
        ``"f1"`` value.
    """
    scores, gold = p
    # Average precision needs the raw scores, so compute it before argmax.
    avg_prec = compute_ap(scores, gold)
    predicted_ids = np.argmax(scores, axis=2)

    # Convert ids to tag strings, dropping ignored positions (special tokens).
    true_predictions = []
    true_labels = []
    for predicted_row, gold_row in zip(predicted_ids, gold):
        kept_predictions = []
        kept_labels = []
        for predicted_id, gold_id in zip(predicted_row, gold_row):
            if gold_id == -100:
                continue
            kept_predictions.append(ID2LABEL[predicted_id])
            kept_labels.append(ID2LABEL[gold_id])
        true_predictions.append(kept_predictions)
        true_labels.append(kept_labels)

    results = _seqeval_metric.compute(
        predictions=true_predictions,
        references=true_labels,
        zero_division=0,
    )

    # Pull the overall scores out so only per-category entries remain.
    agg_metrics = {"Avg.Precision": avg_prec}
    for key in ("precision", "recall", "f1"):
        agg_metrics[key] = results.pop(f"overall_{key}")
    # Overall accuracy is not reported; remove it before the per-category pass.
    results.pop("overall_accuracy")

    per_cat_metrics = {}
    for name, metrics in results.items():
        per_cat_metrics[name] = metrics["f1"]

    return dict(**agg_metrics, **per_cat_metrics)