Aye10032 committed
Commit ab21cd8 · 1 Parent(s): 5453f99

Files changed (6):
  1. README.md +15 -0
  2. app.py +47 -3
  3. pyproject.toml +17 -0
  4. requirements.txt +0 -0
  5. top5_error_rate.py +29 -18
  6. uv.lock +21 -1
README.md CHANGED
@@ -23,3 +23,18 @@ Top-5 Error Rate = (Number of incorrect top-5 predictions) / (Total number of ca
 Where:
 - Top-5 Accuracy: The proportion of cases where the true label is among the model's top 5 predicted classes.
 - Incorrect top-5 prediction: The true label is not in the top 5 predicted classes (ranked by probability).
+
+## How to Use
+
+At minimum, this metric requires predictions and references as inputs.
+
+```python
+accuracy_metric = evaluate.load("Aye10032/top5_error_rate")
+results = accuracy_metric.compute(references=[0], predictions=[[0, 1, 2, 3, 4]])
+print(results)
+```
+
+The output is:
+
+```
+{'accuracy': 0.0, 'top5_error_rate': 0.0}
+```
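A note on the shapes used above: each row of `predictions` holds per-class scores that the metric ranks, while `references` holds one integer label per sample. A minimal sketch of a batched call, assuming the `Aye10032/top5_error_rate` space ID from the README and invented scores over six classes:

```python
import evaluate

metric = evaluate.load("Aye10032/top5_error_rate")

# Two hypothetical samples over six classes; scores are made up.
results = metric.compute(
    references=[0, 2],
    predictions=[
        [0.70, 0.10, 0.10, 0.05, 0.03, 0.02],  # label 0 is the top-1 class
        [0.50, 0.20, 0.10, 0.10, 0.05, 0.05],  # label 2 misses top-1 but sits in the top 5
    ],
)
print(results)
```

With six classes, the five highest-scoring ones are kept per row, so both labels land inside the top 5 (error rate 0.0) while top-1 accuracy is 0.5.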
app.py CHANGED
@@ -1,6 +1,50 @@
+import sys
+from pathlib import Path
+
 import evaluate
-from evaluate.utils import launch_gradio_widget
+import gradio as gr
+import polars as pl
+from evaluate import parse_readme
+
+metric = evaluate.load("Aye10032/top5_error_rate")
+
+
+def compute(data):
+    print(data)
+    result = {
+        "predictions": [list(map(int, pred.split(","))) for pred in data["predictions"]],
+        "references": data["references"].cast(pl.Int64).to_list(),
+    }
+    print(result)
+    return metric.compute(**result)
+
+
+local_path = Path(sys.path[0])
+
+default_value = pl.DataFrame({
+    'predictions': ['1,2,3,4,5', '1,2,3,4,5', '1,2,3,4,5'],
+    'references': ['0', '1', '2']
+})
 
-module = evaluate.load("Aye10032/top5_error_rate")
-launch_gradio_widget(module)
+iface = gr.Interface(
+    fn=compute,
+    inputs=gr.Dataframe(
+        headers=['predictions', 'references'],
+        col_count=2,
+        row_count=1,
+        datatype='str',
+        type='polars',
+        value=default_value
+    ),
+    outputs=gr.Textbox(label=metric.name),
+    description=(
+        metric.info.description
+        + "\nIf this is a text-based metric, make sure to wrap your input in double quotes."
+        " Alternatively you can use a JSON-formatted list as input."
+    ),
+    title=f"Metric: {metric.name}",
+    article=parse_readme(local_path / "README.md"),
+)
+
+iface.launch()
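The `compute` wrapper above converts each Gradio dataframe row of comma-separated strings into the keyword arguments the metric expects. A standalone sketch of just that parsing step, using the widget's column names and invented row values:

```python
import polars as pl

# Rows as the widget delivers them: comma-separated scores plus string labels.
data = pl.DataFrame({
    'predictions': ['1,2,3,4,5', '5,4,3,2,1'],
    'references': ['4', '0'],
})

# Mirror the conversion done in compute(): split each prediction string into
# a list of ints, and cast the reference column from strings to integers.
kwargs = {
    'predictions': [list(map(int, pred.split(','))) for pred in data['predictions']],
    'references': data['references'].cast(pl.Int64).to_list(),
}
print(kwargs)
# {'predictions': [[1, 2, 3, 4, 5], [5, 4, 3, 2, 1]], 'references': [4, 0]}
```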
 
pyproject.toml CHANGED
@@ -6,4 +6,21 @@ readme = "README.md"
 requires-python = ">=3.13"
 dependencies = [
     "evaluate[template]>=0.4.3",
+    "gradio>=5.24.0",
+    "polars>=1.27.1",
 ]
+
+[tool.ruff]
+# Allow lines to be as long as 100.
+line-length = 100
+extend-exclude = ["log", "data"]
+
+[tool.ruff.format]
+# Use single quotes.
+quote-style = "single"
+# Format code snippets inside docstrings.
+docstring-code-format = true
+
+[tool.ruff.lint]
+# On top of the default `select` (`E4`, `E7`, `E9`, and `F`), enable isort (`I`).
+extend-select = ["I"]
requirements.txt CHANGED
Binary files a/requirements.txt and b/requirements.txt differ
 
top5_error_rate.py CHANGED
@@ -2,6 +2,7 @@ from typing import Dict, Any
 
 import datasets
 import evaluate
+import numpy as np
 from evaluate.utils.file_utils import add_start_docstrings
 
 _DESCRIPTION = """
@@ -14,21 +15,22 @@ Top-5 Error Rate = (Number of incorrect top-5 predictions) / (Total number of ca
 - Incorrect top-5 prediction: The true label is not in the top 5 predicted classes (ranked by probability).
 """
 
-
 _KWARGS_DESCRIPTION = """
 Args:
-    predictions (`list` of `list` of `int`): Predicted labels.
+    predictions (`list` of `list` of `float`): Per-class prediction scores. Each inner list holds one score per class; the top 5 classes are taken by score.
     references (`list` of `int`): Ground truth labels.
 Returns:
-    accuracy (`float` or `int`): Accuracy score. Minimum possible value is 0. Maximum possible value is 1.0, or the number of examples input.
+    accuracy (`float`): Top-1 accuracy score.
+    top5_error_rate (`float`): Top-5 Error Rate score. Minimum possible value is 0. Maximum possible value is 1.0.
 Examples:
-    >>> accuracy_metric = evaluate.load("accuracy")
-    >>> results = accuracy_metric.compute(references=[0, 1, 2, 0, 1, 2], predictions=[0, 1, 1, 2, 1, 0])
+    >>> metric = evaluate.load("top5_error_rate")
+    >>> results = metric.compute(
+    ...     references=[0, 1, 2],
+    ...     predictions=[[0, 1, 2, 3, 4], [1, 0, 2, 3, 4], [2, 0, 1, 3, 4]]
+    ... )
     >>> print(results)
-    {'accuracy': 0.5}
+    {'accuracy': 0.0, 'top5_error_rate': 0.0}
 """
 
-
 _CITATION = """
 """
 
@@ -42,7 +44,7 @@ class Top5ErrorRate(evaluate.Metric):
         inputs_description=_KWARGS_DESCRIPTION,
         features=datasets.Features(
            {
-                "predictions": datasets.Sequence(list[datasets.Value("int32")]),
+                "predictions": datasets.Sequence(datasets.Sequence(datasets.Value("float32"))),
                 "references": datasets.Sequence(datasets.Value("int32")),
             }
             if self.config_name == "multilabel"
@@ -55,17 +57,26 @@ class Top5ErrorRate(evaluate.Metric):
         )
 
     def _compute(
         self,
         *,
-        predictions: list[list[int]] = None,
+        predictions: list[list[float]] = None,
         references: list[int] = None,
         **kwargs,
     ) -> Dict[str, Any]:
-        total = len(references)
-        correct = sum(1 for pred, ref in zip(predictions, references) if ref in pred)
-
-        error_rate = 1.0 - (correct / total)
+        # Convert inputs to numpy arrays.
+        outputs = np.array(predictions)
+        labels = np.array(references)
+
+        # Top-1 accuracy: the highest-scoring class must match the label.
+        pred = outputs.argmax(axis=1)
+        acc = (pred == labels).mean()
+
+        # Top-5 error rate: the label must appear among the 5 highest-scoring classes.
+        top5_indices = outputs.argsort(axis=1)[:, -5:]
+        correct = (labels.reshape(-1, 1) == top5_indices).any(axis=1)
+        top5_error_rate = 1 - correct.mean()
 
         return {
-            "top5_error_rate": float(error_rate)
-        }
+            "accuracy": float(acc),
+            "top5_error_rate": float(top5_error_rate),
+        }
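The heart of the new `_compute` is the `argsort`/`any` membership test. A toy run of just that step, with invented scores, shows how a label falls out of the top 5:

```python
import numpy as np

# Two samples over six classes; scores are invented for illustration.
outputs = np.array([
    [0.01, 0.20, 0.30, 0.15, 0.18, 0.10],  # label 0 has the lowest score of all six
    [0.05, 0.06, 0.10, 0.20, 0.30, 0.29],  # label 5 scores second-highest
])
labels = np.array([0, 5])

# argsort ranks ascending, so the last five columns are the top-5 classes per row.
top5_indices = outputs.argsort(axis=1)[:, -5:]

# A sample counts as correct when its label appears in its top-5 row.
correct = (labels.reshape(-1, 1) == top5_indices).any(axis=1)
print(1 - correct.mean())  # 0.5: sample 0 misses the top 5, sample 1 makes it
```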
uv.lock CHANGED
@@ -736,6 +736,20 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/cf/6c/41c21c6c8af92b9fea313aa47c75de49e2f9a467964ee33eb0135d47eb64/pillow-11.1.0-cp313-cp313t-win_arm64.whl", hash = "sha256:67cd427c68926108778a9005f2a04adbd5e67c442ed21d95389fe1d595458756", size = 2377651 },
 ]
 
+[[package]]
+name = "polars"
+version = "1.27.1"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/e1/96/56ab877d3d690bd8e67f5c6aabfd3aa8bc7c33ee901767905f564a6ade36/polars-1.27.1.tar.gz", hash = "sha256:94fcb0216b56cd0594aa777db1760a41ad0dfffed90d2ca446cf9294d2e97f02", size = 4555382 }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/a0/f4/be965ca4e1372805d0d2313bb4ed8eae88804fc3bfeb6cb0a07c53191bdb/polars-1.27.1-cp39-abi3-macosx_10_12_x86_64.whl", hash = "sha256:ba7ad4f8046d00dd97c1369e46a4b7e00ffcff5d38c0f847ee4b9b1bb182fb18", size = 34756840 },
+    { url = "https://files.pythonhosted.org/packages/c0/1a/ae019d323e83c6e8a9b4323f3fea94e047715847dfa4c4cbaf20a6f8444e/polars-1.27.1-cp39-abi3-macosx_11_0_arm64.whl", hash = "sha256:339e3948748ad6fa7a42e613c3fb165b497ed797e93fce1aa2cddf00fbc16cac", size = 31616000 },
+    { url = "https://files.pythonhosted.org/packages/20/c1/c65924c0ca186f481c02b531f1ec66c34f9bbecc11d70246562bb4949876/polars-1.27.1-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f801e0d9da198eb97cfb4e8af4242b8396878ff67b655c71570b7e333102b72b", size = 35388976 },
+    { url = "https://files.pythonhosted.org/packages/88/c2/37720b8794935f1e77bde439564fa421a41f5fed8111aeb7b9ca0ebafb2d/polars-1.27.1-cp39-abi3-manylinux_2_24_aarch64.whl", hash = "sha256:4d18a29c65222451818b63cd397b2e95c20412ea0065d735a20a4a79a7b26e8a", size = 32586083 },
+    { url = "https://files.pythonhosted.org/packages/41/3d/1bb108eb278c1eafb303f78c515fb71c9828944eba3fb5c0ac432b9fad28/polars-1.27.1-cp39-abi3-win_amd64.whl", hash = "sha256:a4f832cf478b282d97f8bf86eeae2df66fa1384de1c49bc61f7224a10cc6a5df", size = 35602500 },
+    { url = "https://files.pythonhosted.org/packages/0f/5c/cc23daf0a228d6fadbbfc8a8c5165be33157abe5b9d72af3e127e0542857/polars-1.27.1-cp39-abi3-win_arm64.whl", hash = "sha256:4f238ee2e3c5660345cb62c0f731bbd6768362db96c058098359ecffa42c3c6c", size = 31891470 },
+]
+
 [[package]]
 name = "propcache"
 version = "0.3.1"
@@ -1055,10 +1069,16 @@ version = "0.1.0"
 source = { virtual = "." }
 dependencies = [
     { name = "evaluate", extra = ["template"] },
+    { name = "gradio" },
+    { name = "polars" },
 ]
 
 [package.metadata]
-requires-dist = [{ name = "evaluate", extras = ["template"], specifier = ">=0.4.3" }]
+requires-dist = [
+    { name = "evaluate", extras = ["template"], specifier = ">=0.4.3" },
+    { name = "gradio", specifier = ">=5.24.0" },
+    { name = "polars", specifier = ">=1.27.1" },
+]
 
 [[package]]
 name = "tqdm"