Aye10032 commited on
Commit
9bf4c3e
·
1 Parent(s): dbf1a3f
Files changed (1) hide show
  1. top5_error_rate.py +3 -14
top5_error_rate.py CHANGED
@@ -53,24 +53,13 @@ class Top5ErrorRate(evaluate.Metric):
53
  def _compute(
54
  self,
55
  *,
56
- predictions: list[list[float]] = None,
57
- references: list[float] = None,
58
  **kwargs,
59
  ) -> Dict[str, Any]:
60
- # 确保输入是numpy数组
61
- predictions = np.array(predictions)
62
- references = np.array(references)
63
 
64
- # 获取每个样本的top-5预测类别
65
- top5_pred = np.argsort(predictions, axis=1)[:, -5:]
66
-
67
- # 计算top-5错误率
68
- correct = 0
69
  total = len(references)
70
-
71
- for i in range(total):
72
- if references[i] in top5_pred[i]:
73
- correct += 1
74
 
75
  error_rate = 1.0 - (correct / total)
76
 
 
53
  def _compute(
54
  self,
55
  *,
56
+ predictions: list[list[int]] = None,
57
+ references: list[int] = None,
58
  **kwargs,
59
  ) -> Dict[str, Any]:
 
 
 
60
 
 
 
 
 
 
61
  total = len(references)
62
+ correct = sum(1 for pred, ref in zip(predictions, references) if ref in pred)
 
 
 
63
 
64
  error_rate = 1.0 - (correct / total)
65