from typing import Dict, Any

import datasets
import evaluate
import numpy as np
from evaluate.utils.file_utils import add_start_docstrings
_DESCRIPTION = """
The "top-5 error" is the percentage of times that the target label does not appear among the 5 highest-probability predictions. It can be computed with:
Top-5 Error Rate = 1 - Top-5 Accuracy
or equivalently:
Top-5 Error Rate = (Number of incorrect top-5 predictions) / (Total number of cases processed)
Where:
- Top-5 Accuracy: The proportion of cases where the true label is among the model's top 5 predicted classes.
- Incorrect top-5 prediction: The true label is not in the top 5 predicted classes (ranked by probability).
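For example, if the true label falls outside the top-5 predictions for 2 out of 10 samples, then Top-5 Error Rate = 2 / 10 = 0.2.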
""" | |
_KWARGS_DESCRIPTION = """ | |
Args: | |
predictions (`list` of `list` of `int`): Predicted labels. Each inner list should contain the top-5 predicted class indices. | |
references (`list` of `int`): Ground truth labels. | |
Returns: | |
top5_error_rate (`float`): Top-5 Error Rate score. Minimum possible value is 0. Maximum possible value is 1.0. | |
Examples: | |
>>> metric = evaluate.load("top5_error_rate") | |
>>> results = metric.compute( | |
... references=[0, 1, 2], | |
... predictions=[[0, 1, 2, 3, 4], [1, 0, 2, 3, 4], [2, 0, 1, 3, 4]] | |
... ) | |
>>> print(results) | |
{'top5_error_rate': 0.0} | |
""" | |
_CITATION = """
"""
@add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class Top5ErrorRate(evaluate.Metric):
    def _info(self):
        return evaluate.MetricInfo(
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            features=datasets.Features(
                {
                    'predictions': datasets.Sequence(datasets.Sequence(datasets.Value('float32'))),
                    'references': datasets.Sequence(datasets.Value('int32')),
                }
                if self.config_name == 'multilabel'
                else {
                    'predictions': datasets.Sequence(datasets.Value('float32')),
                    'references': datasets.Value('int32'),
                }
            ),
            reference_urls=[],
        )
    def _compute(
        self,
        *,
        predictions: list[list[float]] = None,
        references: list[int] = None,
        **kwargs,
    ) -> Dict[str, Any]:
        # Convert inputs to numpy arrays: (n_samples, n_classes) scores and (n_samples,) labels.
        outputs = np.array(predictions)
        labels = np.array(references)
        # Top-1 accuracy: the highest-scoring class must match the label.
        pred = outputs.argmax(axis=1)
        acc = (pred == labels).mean()
        # Top-5 error rate: argpartition returns the indices of the 5 highest
        # scores per row (in no particular order, which is all we need here).
        top5_indices = np.argpartition(outputs, -5, axis=1)[:, -5:]
        # Broadcast the labels against the top-5 indices and check each row for a match.
        correct = np.any(top5_indices == labels[:, np.newaxis], axis=1)
        top5_error_rate = 1 - correct.mean()
        return {'accuracy': float(acc), 'top5_error_rate': float(top5_error_rate)}
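

# Minimal local sanity check (a sketch, not part of the evaluate loading flow):
# instantiating the class directly and calling _compute bypasses evaluate.load()
# and the Features validation above. The sample scores are made up for illustration.
if __name__ == '__main__':
    metric = Top5ErrorRate()
    results = metric._compute(
        predictions=[
            [0.6, 0.2, 0.1, 0.05, 0.05],
            [0.1, 0.7, 0.1, 0.05, 0.05],
            [0.05, 0.1, 0.7, 0.1, 0.05],
        ],
        references=[0, 1, 2],
    )
    print(results)  # expected: {'accuracy': 1.0, 'top5_error_rate': 0.0}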