from typing import Dict, Any

import datasets
import evaluate
import numpy as np
from evaluate.utils.file_utils import add_start_docstrings
_DESCRIPTION = """ | |
The "top-5 error" is the percentage of times that the target label does not appear among the 5 highest-probability predictions. It can be computed with: | |
Top-5 Error Rate = 1 - Top-5 Accuracy | |
or equivalently: | |
Top-5 Error Rate = (Number of incorrect top-5 predictions) / (Total number of cases processed) | |
Where: | |
- Top-5 Accuracy: The proportion of cases where the true label is among the model's top 5 predicted classes. | |
- Incorrect top-5 prediction: The true label is not in the top 5 predicted classes (ranked by probability). | |
""" | |
_KWARGS_DESCRIPTION = """ | |
Args: | |
predictions (`list` of `list` of `int`): Predicted labels. Each inner list should contain the top-5 predicted class indices. | |
references (`list` of `int`): Ground truth labels. | |
Returns: | |
top5_error_rate (`float`): Top-5 Error Rate score. Minimum possible value is 0. Maximum possible value is 1.0. | |
Examples: | |
>>> metric = evaluate.load("top5_error_rate") | |
>>> results = metric.compute( | |
... references=[0, 1, 2], | |
... predictions=[[0, 1, 2, 3, 4], [1, 0, 2, 3, 4], [2, 0, 1, 3, 4]] | |
... ) | |
>>> print(results) | |
{'top5_error_rate': 0.0} | |
""" | |
_CITATION = """ | |
""" | |
@add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class Top5ErrorRate(evaluate.Metric):
    def _info(self):
        return evaluate.MetricInfo(
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            features=datasets.Features(
                {
                    "predictions": datasets.Sequence(datasets.Sequence(datasets.Value("float32"))),
                    "references": datasets.Sequence(datasets.Value("int32")),
                }
                if self.config_name == "multilabel"
                else {
                    "predictions": datasets.Sequence(datasets.Value("float32")),
                    "references": datasets.Value("int32"),
                }
            ),
            reference_urls=[],
        )
    def _compute(
        self,
        *,
        predictions: list[list[float]] = None,
        references: list[int] = None,
        **kwargs,
    ) -> Dict[str, Any]:
        # Convert per-class prediction scores and reference labels to numpy arrays.
        outputs = np.array(predictions, dtype=np.float32)
        labels = np.array(references)
        # Top-1 accuracy: the highest-scoring class matches the reference label.
        pred = outputs.argmax(axis=1)
        acc = (pred == labels).mean()
        # Top-5 error rate: fraction of examples whose reference label is not
        # among the 5 highest-scoring classes.
        top5_indices = outputs.argsort(axis=1)[:, -5:]
        correct = (labels.reshape(-1, 1) == top5_indices).any(axis=1)
        top5_error_rate = 1 - correct.mean()
        return {
            "accuracy": float(acc),
            "top5_error_rate": float(top5_error_rate),
        }
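
# A minimal sanity-check sketch for the computation above, not part of the metric
# itself. It assumes the class can be instantiated directly (without going through
# evaluate.load), and the 8-class scores below are made up purely for illustration.
if __name__ == "__main__":
    demo_references = [0, 1, 7]
    demo_predictions = [
        # True class 0 has the highest score: correct for both top-1 and top-5.
        [0.50, 0.10, 0.10, 0.10, 0.10, 0.05, 0.03, 0.02],
        # True class 1 has the highest score: correct for both top-1 and top-5.
        [0.10, 0.50, 0.10, 0.10, 0.10, 0.05, 0.03, 0.02],
        # True class 7 has the lowest score: outside the top 5, so it counts as an error.
        [0.30, 0.20, 0.20, 0.10, 0.10, 0.05, 0.03, 0.02],
    ]
    metric = Top5ErrorRate()
    print(metric.compute(predictions=demo_predictions, references=demo_references))
    # Expected: accuracy of 2/3 and top5_error_rate of 1/3 (roughly 0.667 and 0.333).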