Aye10032 committed on
Commit
a526ab2
·
1 Parent(s): db4f564
Files changed (1) hide show
  1. top5_error_rate.py +18 -18
top5_error_rate.py CHANGED
@@ -44,39 +44,39 @@ class Top5ErrorRate(evaluate.Metric):
44
  inputs_description=_KWARGS_DESCRIPTION,
45
  features=datasets.Features(
46
  {
47
- "predictions": datasets.Sequence(datasets.Sequence(datasets.Value("float32"))),
48
- "references": datasets.Sequence(datasets.Value("int32")),
49
  }
50
- if self.config_name == "multilabel"
51
  else {
52
- "predictions": datasets.Sequence(datasets.Value("float32")),
53
- "references": datasets.Value("int32"),
54
  }
55
  ),
56
  reference_urls=[],
57
  )
58
 
59
  def _compute(
60
- self,
61
- *,
62
- predictions: list[list[float]] = None,
63
- references: list[int] = None,
64
- **kwargs,
65
  ) -> Dict[str, Any]:
66
  # to numpy array
67
- outputs = np.array(predictions, dtype=np.float32)
68
  labels = np.array(references)
69
 
70
  # Top-1 ACC
71
  pred = outputs.argmax(axis=1)
72
  acc = (pred == labels).mean()
73
 
74
- # Top-5 Error Rate
75
- top5_indices = outputs.argsort(axis=1)[:, -5:]
76
- correct = (labels.reshape(-1, 1) == top5_indices).any(axis=1)
 
 
 
77
  top5_error_rate = 1 - correct.mean()
78
 
79
- return {
80
- "accuracy": float(acc),
81
- "top5_error_rate": float(top5_error_rate)
82
- }
 
44
  inputs_description=_KWARGS_DESCRIPTION,
45
  features=datasets.Features(
46
  {
47
+ 'predictions': datasets.Sequence(datasets.Sequence(datasets.Value('float32'))),
48
+ 'references': datasets.Sequence(datasets.Value('int32')),
49
  }
50
+ if self.config_name == 'multilabel'
51
  else {
52
+ 'predictions': datasets.Sequence(datasets.Value('float32')),
53
+ 'references': datasets.Value('int32'),
54
  }
55
  ),
56
  reference_urls=[],
57
  )
58
 
59
  def _compute(
60
+ self,
61
+ *,
62
+ predictions: list[list[float]] = None,
63
+ references: list[int] = None,
64
+ **kwargs,
65
  ) -> Dict[str, Any]:
66
  # to numpy array
67
+ outputs = np.array(predictions)
68
  labels = np.array(references)
69
 
70
  # Top-1 ACC
71
  pred = outputs.argmax(axis=1)
72
  acc = (pred == labels).mean()
73
 
74
+ # Top-5 Error rate
75
+ top5_indices = np.argpartition(outputs, -5, axis=1)[:, -5:]
76
+
77
+ # 使用广播机制直接比较
78
+ # 使用np.any的axis参数直接在最后一个维度上检查是否存在匹配
79
+ correct = np.any(top5_indices == labels[:, np.newaxis], axis=1)
80
  top5_error_rate = 1 - correct.mean()
81
 
82
+ return {'accuracy': float(acc), 'top5_error_rate': float(top5_error_rate)}