lvwerra HF Staff commited on
Commit
5995eaa
·
1 Parent(s): 555c32e

Update Space (evaluate main: e4a27243)

Browse files
Files changed (2) hide show
  1. mahalanobis.py +22 -2
  2. requirements.txt +1 -1
mahalanobis.py CHANGED
@@ -13,6 +13,9 @@
13
  # limitations under the License.
14
  """Mahalanobis metric."""
15
 
 
 
 
16
  import datasets
17
  import numpy as np
18
 
@@ -57,13 +60,25 @@ Examples:
57
  """
58
 
59
 
 
 
 
 
 
 
 
 
60
  @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
61
  class Mahalanobis(evaluate.Metric):
62
- def _info(self):
 
 
 
63
  return evaluate.MetricInfo(
64
  description=_DESCRIPTION,
65
  citation=_CITATION,
66
  inputs_description=_KWARGS_DESCRIPTION,
 
67
  features=datasets.Features(
68
  {
69
  "X": datasets.Sequence(datasets.Value("float", id="sequence"), id="X"),
@@ -71,7 +86,12 @@ class Mahalanobis(evaluate.Metric):
71
  ),
72
  )
73
 
74
- def _compute(self, X, reference_distribution):
 
 
 
 
 
75
 
76
  # convert to numpy arrays
77
  X = np.array(X)
 
13
  # limitations under the License.
14
  """Mahalanobis metric."""
15
 
16
+ from dataclasses import dataclass
17
+ from typing import List, Optional
18
+
19
  import datasets
20
  import numpy as np
21
 
 
60
  """
61
 
62
 
63
+ @dataclass
64
+ class MahalanobisConfig(evaluate.info.Config):
65
+
66
+ name: str = "default"
67
+
68
+ reference_distribution: Optional[List] = None
69
+
70
+
71
  @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
72
  class Mahalanobis(evaluate.Metric):
73
+ CONFIG_CLASS = MahalanobisConfig
74
+ ALLOWED_CONFIG_NAMES = ["default"]
75
+
76
+ def _info(self, config):
77
  return evaluate.MetricInfo(
78
  description=_DESCRIPTION,
79
  citation=_CITATION,
80
  inputs_description=_KWARGS_DESCRIPTION,
81
+ config=config,
82
  features=datasets.Features(
83
  {
84
  "X": datasets.Sequence(datasets.Value("float", id="sequence"), id="X"),
 
86
  ),
87
  )
88
 
89
+ def _compute(self, X):
90
+
91
+ if self.config.reference_distribution is None:
92
+ raise ValueError("You need to provide a `reference_distribution`.")
93
+ else:
94
+ reference_distribution = self.config.reference_distribution
95
 
96
  # convert to numpy arrays
97
  X = np.array(X)
requirements.txt CHANGED
@@ -1 +1 @@
1
- git+https://github.com/huggingface/evaluate@80448674f5447a9682afe051db243c4a13bfe4ff
 
1
+ git+https://github.com/huggingface/evaluate@e4a2724377909fe2aeb4357e3971e5a569673b39