from typing import Any

import datasets
import evaluate
import numpy as np

from evaluate.utils.file_utils import add_start_docstrings
_DESCRIPTION = """
The "top-5 error rate" is the fraction of samples for which the target label does not appear among the model's 5 highest-probability predictions. It can be computed with:
Top-5 Error Rate = 1 - Top-5 Accuracy
or equivalently:
Top-5 Error Rate = (Number of incorrect top-5 predictions) / (Total number of cases processed)
Where:
- Top-5 Accuracy: The proportion of cases where the true label is among the model's top 5 predicted classes.
- Incorrect top-5 prediction: The true label is not in the top 5 predicted classes (ranked by probability).
"""
_KWARGS_DESCRIPTION = """
Args:
    predictions (`list` of `list` of `float`): Per-sample class scores or probabilities, one row per example.
    references (`list` of `int`): Ground truth class labels.
Returns:
    top5_error_rate (`float`): Top-5 error rate. Minimum possible value is 0.0; maximum possible value is 1.0.
Examples:
    >>> metric = Top5ErrorRate()  # or evaluate.load(...) with the path where this module is hosted
    >>> scores = [
    ...     [0.05, 0.10, 0.15, 0.20, 0.24, 0.26],
    ...     [0.30, 0.25, 0.20, 0.15, 0.06, 0.04],
    ... ]
    >>> results = metric.compute(predictions=scores, references=[1, 5])
    >>> print(results)
    {'top5_error_rate': 0.5}
"""
_CITATION = """
"""
@add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class Top5ErrorRate(evaluate.Metric):
    def _info(self):
        return evaluate.MetricInfo(
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            features=datasets.Features(
                {
                    # One row of per-class scores per sample.
                    "predictions": datasets.Sequence(datasets.Value("float32")),
                    # One integer class label per sample.
                    "references": datasets.Value("int32"),
                }
            ),
            reference_urls=[],
        )
    def _compute(
        self,
        *,
        predictions: list[list[float]] | None = None,
        references: list[int] | None = None,
        **kwargs,
    ) -> dict[str, Any]:
        # Ensure the inputs are numpy arrays.
        predictions = np.array(predictions)
        references = np.array(references)
        # Indices of the 5 highest-scoring classes for each sample. argsort
        # sorts ascending, so the last 5 columns hold the top-5 classes; ties
        # are broken by the sort's internal order.
        top5_pred = np.argsort(predictions, axis=1)[:, -5:]
        # Count samples whose true label appears among their top-5 predictions,
        # then convert the correct fraction to an error rate.
        correct = 0
        total = len(references)
        for i in range(total):
            if references[i] in top5_pred[i]:
                correct += 1
        error_rate = 1.0 - (correct / total)
        return {
            "top5_error_rate": float(error_rate)
        }
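

# Minimal usage sketch (illustrative; the 6-class scores below are made up for
# the demo). Instantiating the class directly exercises the same compute path
# as loading the metric through `evaluate.load`.
if __name__ == "__main__":
    metric = Top5ErrorRate()
    scores = [
        [0.05, 0.10, 0.15, 0.20, 0.24, 0.26],  # true label 1 is in the top 5
        [0.30, 0.25, 0.20, 0.15, 0.06, 0.04],  # true label 5 is ranked last
    ]
    results = metric.compute(predictions=scores, references=[1, 5])
    print(results)  # {'top5_error_rate': 0.5}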