# Copyright (c) OpenMMLab. All rights reserved.
import random
from typing import List, Optional

from mmengine.evaluator import BaseMetric

from mmpretrain.registry import METRICS
def get_pred_idx(prediction: str, choices: List[str],
                 options: List[str]) -> int:  # noqa
    """Get the index (e.g. 2) from the prediction (e.g. 'C').

    Args:
        prediction (str): The prediction from the model,
            from ['A', 'B', 'C', 'D', 'E'].
        choices (List[str]): The candidate answers for the question; only
            its length is used to restrict the valid options.
        options (List[str]): The option letters for the question,
            e.g. ['A', 'B', 'C', 'D', 'E'].

    Returns:
        int: The index of the prediction, from [0, 1, 2, 3, 4]. If the
            prediction is not a valid option, a random index is returned.
    """
    if prediction in options[:len(choices)]:
        return options.index(prediction)
    else:
        return random.choice(range(len(choices)))
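

# Illustrative behaviour of ``get_pred_idx`` (example values only, not taken
# from the ScienceQA dataset): with choices=['rain', 'snow', 'hail'] and
# options=['A', 'B', 'C', 'D', 'E'], a prediction of 'C' maps to index 2,
# while an unrecognised prediction such as 'Z' falls back to a random index
# in range(len(choices)).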


@METRICS.register_module()
class ScienceQAMetric(BaseMetric):
    """Evaluation Metric for ScienceQA.

    Args:
        options (List[str]): Option letters for each question. Defaults to
            ``['A', 'B', 'C', 'D', 'E']``.
        collect_device (str): Device name used for collecting results from
            different ranks during distributed training. Must be 'cpu' or
            'gpu'. Defaults to 'cpu'.
        prefix (str, optional): The prefix that will be added in the metric
            names to disambiguate homonymous metrics of different evaluators.
            If prefix is not provided in the argument, self.default_prefix
            will be used instead. Defaults to None.
    """

    def __init__(self,
                 options: List[str] = ['A', 'B', 'C', 'D', 'E'],
                 collect_device: str = 'cpu',
                 prefix: Optional[str] = None) -> None:
        super().__init__(collect_device=collect_device, prefix=prefix)
        self.options = options

    def process(self, data_batch, data_samples) -> None:
        """Process one batch of data samples.

        data_samples should contain the following keys:
            1. pred_answer (str): The prediction from the model,
                from ['A', 'B', 'C', 'D', 'E']
            2. choices (List[str]): The candidate answers for the question
            3. grade (str): The grade for the question, from 'grade1' to
                'grade12'
            4. subject (str): The subject for the question, from
                ['natural science', 'social science', 'language science']
            5. gt_answer (int): The ground truth answer index for the
                question, from [0, 1, 2, 3, 4]
            6. hint (str): The hint for the question
            7. has_image (bool): Whether or not the question has an image

        The processed results should be stored in ``self.results``, which
        will be used to compute the metrics when all batches have been
        processed.

        Args:
            data_batch: A batch of data from the dataloader.
            data_samples (Sequence[dict]): A batch of outputs from the model.
        """
        for data_sample in data_samples:
            result = dict()

            choices = data_sample.get('choices')
            result['prediction'] = get_pred_idx(
                data_sample.get('pred_answer'), choices, self.options)
            result['grade'] = data_sample.get('grade')
            result['subject'] = data_sample.get('subject')
            result['answer'] = data_sample.get('gt_answer')
            hint = data_sample.get('hint')
            has_image = data_sample.get('has_image', False)
            result['no_context'] = not has_image and len(hint) == 0
            result['has_text'] = len(hint) > 0
            result['has_image'] = has_image

            # Save the result to `self.results`.
            self.results.append(result)

    def compute_metrics(self, results: List) -> dict:
        """Compute the metrics from processed results.

        Args:
            results (List): The processed results of each batch.

        Returns:
            Dict: The computed metrics. The keys are the names of the
            metrics, and the values are corresponding results.
        """
        # NOTICE: don't access `self.results` from the method.
        metrics = dict()

        all_acc = []
        acc_natural = []
        acc_social = []
        acc_language = []
        acc_has_text = []
        acc_has_image = []
        acc_no_context = []
        acc_grade_1_6 = []
        acc_grade_7_12 = []

        for result in results:
            correct = result['prediction'] == result['answer']
            all_acc.append(correct)

            # different subjects
            if result['subject'] == 'natural science':
                acc_natural.append(correct)
            elif result['subject'] == 'social science':
                acc_social.append(correct)
            elif result['subject'] == 'language science':
                acc_language.append(correct)

            # different contexts
            if result['has_text']:
                acc_has_text.append(correct)
            elif result['has_image']:
                acc_has_image.append(correct)
            elif result['no_context']:
                acc_no_context.append(correct)

            # different grades
            if result['grade'] in [
                    'grade1', 'grade2', 'grade3', 'grade4', 'grade5', 'grade6'
            ]:
                acc_grade_1_6.append(correct)
            elif result['grade'] in [
                    'grade7', 'grade8', 'grade9', 'grade10', 'grade11',
                    'grade12'
            ]:
                acc_grade_7_12.append(correct)

        metrics['all_acc'] = sum(all_acc) / len(all_acc)
        if len(acc_natural) > 0:
            metrics['acc_natural'] = sum(acc_natural) / len(acc_natural)
        if len(acc_social) > 0:
            metrics['acc_social'] = sum(acc_social) / len(acc_social)
        if len(acc_language) > 0:
            metrics['acc_language'] = sum(acc_language) / len(acc_language)
        if len(acc_has_text) > 0:
            metrics['acc_has_text'] = sum(acc_has_text) / len(acc_has_text)
        if len(acc_has_image) > 0:
            metrics['acc_has_image'] = sum(acc_has_image) / len(acc_has_image)
        if len(acc_no_context) > 0:
            metrics['acc_no_context'] = sum(acc_no_context) / len(
                acc_no_context)
        if len(acc_grade_1_6) > 0:
            metrics['acc_grade_1_6'] = sum(acc_grade_1_6) / len(acc_grade_1_6)
        if len(acc_grade_7_12) > 0:
            metrics['acc_grade_7_12'] = sum(acc_grade_7_12) / len(
                acc_grade_7_12)

        return metrics
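

# Minimal usage sketch (not part of the original module): it feeds a couple
# of hand-made data samples through ``process`` and ``compute_metrics``
# directly, bypassing the usual mmengine evaluation loop. The field values
# below are illustrative assumptions, not real ScienceQA data.
if __name__ == '__main__':
    metric = ScienceQAMetric()

    fake_samples = [
        dict(
            pred_answer='B',
            choices=['rain', 'snow'],
            grade='grade3',
            subject='natural science',
            gt_answer=1,  # index of the correct choice
            hint='Look at the weather map.',
            has_image=True),
        dict(
            pred_answer='A',
            choices=['metaphor', 'simile', 'idiom'],
            grade='grade8',
            subject='language science',
            gt_answer=2,
            hint='',
            has_image=False),
    ]

    # ``data_batch`` is unused by ``process``, so ``None`` is fine here.
    metric.process(data_batch=None, data_samples=fake_samples)
    print(metric.compute_metrics(metric.results))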