Aye10032 commited on
Commit
2f2b372
·
1 Parent(s): 81704c2
Files changed (1) hide show
  1. top5_error_rate.py +2 -6
top5_error_rate.py CHANGED
@@ -44,12 +44,12 @@ class Top5ErrorRate(evaluate.Metric):
44
  inputs_description=_KWARGS_DESCRIPTION,
45
  features=datasets.Features(
46
  {
47
- "predictions": datasets.Sequence(datasets.Value("float32")),
48
  "references": datasets.Sequence(datasets.Value("int32")),
49
  }
50
  if self.config_name == "multilabel"
51
  else {
52
- "predictions": datasets.Value("float32"),
53
  "references": datasets.Value("int32"),
54
  }
55
  ),
@@ -63,13 +63,9 @@ class Top5ErrorRate(evaluate.Metric):
63
  references: list[int] = None,
64
  **kwargs,
65
  ) -> Dict[str, Any]:
66
- print(predictions)
67
- print(references)
68
  # to numpy array
69
  outputs = np.array(predictions, dtype=np.float32)
70
  labels = np.array(references)
71
- print(outputs)
72
- print(labels)
73
 
74
  # Top-1 ACC
75
  pred = outputs.argmax(axis=1)
 
44
  inputs_description=_KWARGS_DESCRIPTION,
45
  features=datasets.Features(
46
  {
47
+ "predictions": datasets.Sequence(datasets.Sequence(datasets.Value("float32"))),
48
  "references": datasets.Sequence(datasets.Value("int32")),
49
  }
50
  if self.config_name == "multilabel"
51
  else {
52
+ "predictions": datasets.Sequence(datasets.Value("float32")),
53
  "references": datasets.Value("int32"),
54
  }
55
  ),
 
63
  references: list[int] = None,
64
  **kwargs,
65
  ) -> Dict[str, Any]:
 
 
66
  # to numpy array
67
  outputs = np.array(predictions, dtype=np.float32)
68
  labels = np.array(references)
 
 
69
 
70
  # Top-1 ACC
71
  pred = outputs.argmax(axis=1)