# Copyright (c) OpenMMLab. All rights reserved.
import random
from typing import List, Optional

from mmengine.evaluator import BaseMetric

from mmpretrain.registry import METRICS


def get_pred_idx(prediction: str, choices: List[str],
                 options: List[str]) -> int:  # noqa
    """Get the index (e.g. 2) from the prediction (e.g. 'C')

    Args:
        prediction (str): The prediction from the model,
            from ['A', 'B', 'C', 'D', 'E']
        choices (List(str)): The choices for the question,
            from ['A', 'B', 'C', 'D', 'E']
        options (List(str)): The options for the question,
            from ['A', 'B', 'C', 'D', 'E']

    Returns:
        int: The index of the prediction, from [0, 1, 2, 3, 4]
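
    Examples:
        A sketch of the expected behaviour; the choice texts below are
        made up for illustration:

        >>> options = ['A', 'B', 'C', 'D', 'E']
        >>> get_pred_idx('C', ['red', 'green', 'blue'], options)
        2
        >>> # An invalid letter falls back to a random valid index.
        >>> get_pred_idx('F', ['red', 'green', 'blue'], options) in (0, 1, 2)
        True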
    """
    if prediction in options[:len(choices)]:
        return options.index(prediction)
    else:
        return random.choice(range(len(choices)))


@METRICS.register_module()
class ScienceQAMetric(BaseMetric):
    """Evaluation Metric for ScienceQA.

    Args:
        options (List[str]): The option letters for each question. Defaults
            to ["A", "B", "C", "D", "E"].
        collect_device (str): Device name used for collecting results from
            different ranks during distributed training. Must be 'cpu' or
            'gpu'. Defaults to 'cpu'.
        prefix (str, optional): The prefix that will be added in the metric
            names to disambiguate homonymous metrics of different evaluators.
            If prefix is not provided in the argument, self.default_prefix
            will be used instead. Defaults to None.
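
    Examples:
        A minimal usage sketch. The data sample below is hypothetical and
        only illustrates the fields consumed by :meth:`process`; in practice
        the metric is usually selected from a config, e.g.
        ``val_evaluator = dict(type='ScienceQAMetric')``.

        >>> metric = ScienceQAMetric()
        >>> sample = dict(
        ...     pred_answer='B',
        ...     choices=['the sun', 'the moon'],
        ...     gt_answer=1,
        ...     grade='grade4',
        ...     subject='natural science',
        ...     hint='',
        ...     has_image=True)
        >>> metric.process(None, [sample])
        >>> metric.compute_metrics(metric.results)['all_acc']
        1.0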
    """

    def __init__(self,
                 options: List[str] = ['A', 'B', 'C', 'D', 'E'],
                 collect_device: str = 'cpu',
                 prefix: Optional[str] = None) -> None:
        super().__init__(collect_device=collect_device, prefix=prefix)
        self.options = options

    def process(self, data_batch, data_samples) -> None:
        """Process one batch of data samples.

        Each data sample should contain the following keys:

        1. pred_answer (str): The prediction from the model,
            e.g. one of ['A', 'B', 'C', 'D', 'E'].
        2. choices (List[str]): The candidate answer texts of the question.
        3. grade (str): The grade of the question, from 'grade1' to 'grade12'.
        4. subject (str): The subject of the question, one of
            ['natural science', 'social science', 'language science'].
        5. gt_answer (int): The index of the ground-truth answer, which is
            compared against the predicted index.
        6. hint (str): The hint of the question.
        7. has_image (bool): Whether the question has an image.

        The processed results should be stored in ``self.results``, which will
        be used to compute the metrics when all batches have been processed.

        Args:
            data_batch: A batch of data from the dataloader.
            data_samples (Sequence[dict]): A batch of outputs from the model.
        """
        for data_sample in data_samples:
            result = dict()
            choices = data_sample.get('choices')
            result['prediction'] = get_pred_idx(
                data_sample.get('pred_answer'), choices, self.options)
            result['grade'] = data_sample.get('grade')
            result['subject'] = data_sample.get('subject')
            result['answer'] = data_sample.get('gt_answer')
            hint = data_sample.get('hint')
            has_image = data_sample.get('has_image', False)
            result['no_context'] = not has_image and len(hint) == 0
            result['has_text'] = len(hint) > 0
            result['has_image'] = has_image

            # Save the result to `self.results`.
            self.results.append(result)

    def compute_metrics(self, results: List) -> dict:
        """Compute the metrics from processed results.

        Args:
            results (List): The processed results of each batch.

        Returns:
            Dict: The computed metrics. The keys are the names of the metrics,
            and the values are corresponding results. Per-category accuracies
            (by subject, context and grade) are only included when at least
            one sample falls into that category.
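
        Examples:
            An illustrative call with two already-processed results; the
            dicts below mirror what :meth:`process` appends:

            >>> results = [
            ...     dict(prediction=1, answer=1, subject='natural science',
            ...          grade='grade4', has_text=False, has_image=True,
            ...          no_context=False),
            ...     dict(prediction=0, answer=2, subject='social science',
            ...          grade='grade8', has_text=True, has_image=False,
            ...          no_context=False),
            ... ]
            >>> ScienceQAMetric().compute_metrics(results)['all_acc']
            0.5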
        """
        # NOTICE: don't access `self.results` from the method.
        metrics = dict()

        all_acc = []
        acc_natural = []
        acc_social = []
        acc_language = []
        acc_has_text = []
        acc_has_image = []
        acc_no_context = []
        acc_grade_1_6 = []
        acc_grade_7_12 = []

        for result in results:
            correct = result['prediction'] == result['answer']
            all_acc.append(correct)
            # different subjects
            if result['subject'] == 'natural science':
                acc_natural.append(correct)
            elif result['subject'] == 'social science':
                acc_social.append(correct)
            elif result['subject'] == 'language science':
                acc_language.append(correct)

            # different context
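            # Note: the buckets below are mutually exclusive; a sample with
            # both a text hint and an image only counts towards ``has_text``.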
            if result['has_text']:
                acc_has_text.append(correct)
            elif result['has_image']:
                acc_has_image.append(correct)
            elif result['no_context']:
                acc_no_context.append(correct)

            # different grade
            if result['grade'] in [
                    'grade1', 'grade2', 'grade3', 'grade4', 'grade5', 'grade6'
            ]:
                acc_grade_1_6.append(correct)
            elif result['grade'] in [
                    'grade7', 'grade8', 'grade9', 'grade10', 'grade11',
                    'grade12'
            ]:
                acc_grade_7_12.append(correct)

        metrics['all_acc'] = sum(all_acc) / len(all_acc)
        if len(acc_natural) > 0:
            metrics['acc_natural'] = sum(acc_natural) / len(acc_natural)
        if len(acc_social) > 0:
            metrics['acc_social'] = sum(acc_social) / len(acc_social)
        if len(acc_language) > 0:
            metrics['acc_language'] = sum(acc_language) / len(acc_language)
        if len(acc_has_text) > 0:
            metrics['acc_has_text'] = sum(acc_has_text) / len(acc_has_text)
        if len(acc_has_image) > 0:
            metrics['acc_has_image'] = sum(acc_has_image) / len(acc_has_image)
        if len(acc_no_context) > 0:
            metrics['acc_no_context'] = sum(acc_no_context) / len(
                acc_no_context)
        if len(acc_grade_1_6) > 0:
            metrics['acc_grade_1_6'] = sum(acc_grade_1_6) / len(acc_grade_1_6)
        if len(acc_grade_7_12) > 0:
            metrics['acc_grade_7_12'] = sum(acc_grade_7_12) / len(
                acc_grade_7_12)

        return metrics