File size: 4,397 Bytes
69ad385
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
#!/usr/bin/python3
# -*- coding: utf-8 -*-
from typing import Dict, List, Optional

import numpy as np
import torch


class CategoricalAccuracyVerbose(object):
    """
    Top-K categorical accuracy that also tracks per-label accuracy.

    Accumulates an overall correct/total count across calls, plus a
    correct/total count per gold label (keyed by the label's string name
    from ``index_to_token``). ``get_metric`` reports the overall
    ``'accuracy'`` alongside one accuracy entry per observed label.
    """

    def __init__(self,
                 index_to_token: Dict[int, str],
                 label_namespace: str = "labels",
                 top_k: int = 1,
                 ) -> None:
        """
        :param index_to_token: mapping from gold label id to its string name,
            used as the per-label key in ``get_metric``'s result dict.
        :param label_namespace: name of the label namespace (kept for API
            compatibility; not used in the computation itself).
        :param top_k: a prediction counts as correct if the gold label is
            among the ``top_k`` highest-scoring classes. Must be > 0.
        :raises AssertionError: if ``top_k`` is not positive.
        """
        if top_k <= 0:
            raise AssertionError("top_k passed to Categorical Accuracy must be > 0")
        self._index_to_token = index_to_token
        self._label_namespace = label_namespace
        self._top_k = top_k
        # Counters are plain Python numbers (never tensors) so that
        # get_metric works without device syncs and comparisons stay bools.
        self.correct_count = 0.
        self.total_count = 0.
        self.label_correct_count: Dict[str, int] = dict()
        self.label_total_count: Dict[str, int] = dict()

    def __call__(self,
                 predictions: torch.Tensor,
                 gold_labels: torch.Tensor,
                 mask: Optional[torch.Tensor] = None):
        """
        Accumulate statistics for one batch.

        :param predictions: tensor of shape (..., num_classes) with
            per-class scores.
        :param gold_labels: integer tensor of shape (...) — i.e. the
            predictions shape minus the final class dimension.
        :param mask: optional tensor broadcastable to ``gold_labels``;
            positions where it is zero are excluded from the overall counts
            (per-label totals still count every gold occurrence).
        :raises AssertionError: if shapes disagree or a gold label id is
            out of range.
        """
        num_classes = predictions.size(-1)
        if gold_labels.dim() != predictions.dim() - 1:
            raise AssertionError("gold_labels must have dimension == predictions.size() - 1 but "
                                     "found tensor of shape: {}".format(predictions.size()))
        if (gold_labels >= num_classes).any():
            raise AssertionError("A gold label passed to Categorical Accuracy contains an id >= {}, "
                                     "the number of classes.".format(num_classes))

        predictions = predictions.view((-1, num_classes))
        gold_labels = gold_labels.view(-1).long()

        # Top K indexes of the predictions (or fewer, if there aren't K of them).
        # Special case topk == 1, because it's common and .max() is much faster than .topk().
        if self._top_k == 1:
            top_k = predictions.max(-1)[1].unsqueeze(-1)
        else:
            top_k = predictions.topk(min(self._top_k, predictions.shape[-1]), -1)[1]

        # Shape (num_instances, top_k): 1.0 where the gold label appears.
        correct = top_k.eq(gold_labels.unsqueeze(-1)).float()

        if mask is not None:
            correct *= mask.view(-1, 1).float()
            # .item() keeps the counter a Python number, not a 0-dim tensor.
            self.total_count += mask.sum().item()
        else:
            self.total_count += gold_labels.numel()
        self.correct_count += correct.sum().item()

        labels: List[int] = np.unique(gold_labels.cpu().numpy()).tolist()
        for label in labels:
            label_mask = (gold_labels == label)

            label_correct = correct * label_mask.view(-1, 1).float()
            label_correct = int(label_correct.sum())
            label_count = int(label_mask.sum())

            label_str = self._index_to_token[label]
            # dict.get with a default collapses the presence-check branches.
            self.label_correct_count[label_str] = (
                self.label_correct_count.get(label_str, 0) + label_correct)
            self.label_total_count[label_str] = (
                self.label_total_count.get(label_str, 0) + label_count)

    def get_metric(self, reset: bool = False):
        """
        Returns
        -------
        A dict with the accumulated overall ``'accuracy'`` plus one entry
        per observed label name with that label's accuracy. Optionally
        resets the counters afterwards.
        """
        result = dict()
        if self.total_count > 1e-12:
            accuracy = float(self.correct_count) / float(self.total_count)
        else:
            accuracy = 0.0
        result['accuracy'] = accuracy

        for label in self.label_total_count.keys():
            # total is >= 1 for every key: labels only enter the dicts when
            # they occur in gold_labels, so no zero-division guard is needed.
            total = self.label_total_count[label]
            correct = self.label_correct_count.get(label, 0.0)
            label_accuracy = correct / total
            result[label] = label_accuracy

        if reset:
            self.reset()
        return result

    def reset(self):
        """Clear all accumulated counts."""
        self.correct_count = 0.0
        self.total_count = 0.0
        self.label_correct_count = dict()
        self.label_total_count = dict()


def demo1():

    categorical_accuracy_verbose = CategoricalAccuracyVerbose(
        index_to_token={0: '0', 1: '1'},
        top_k=2,
    )

    predictions = torch.randn(size=(2, 3), dtype=torch.float32)
    gold_labels = torch.ones(size=(2,), dtype=torch.long)
    # print(predictions)
    # print(gold_labels)

    categorical_accuracy_verbose(
        predictions=predictions,
        gold_labels=gold_labels,
    )
    metric = categorical_accuracy_verbose.get_metric()
    print(metric)
    return


# Script entry point: run the demo only when executed directly, not on import.
if __name__ == '__main__':
    demo1()