jpxkqx commited on
Commit
33e1bb1
·
1 Parent(s): 7b56406

Update value dtypes

Browse files
Files changed (1) hide show
  1. peak_signal_to_noise_ratio.py +28 -13
peak_signal_to_noise_ratio.py CHANGED
@@ -57,13 +57,32 @@ class StructuralSimilarityIndexMeasure(evaluate.Metric):
57
  description=_DESCRIPTION,
58
  citation=_CITATION,
59
  inputs_description=_KWARGS_DESCRIPTION,
60
- features=datasets.Features({
61
- "predictions": datasets.Sequence(datasets.Array2D(dtype="float32")),
62
- "references": datasets.Sequence(datasets.Array2D(dtype="float32")),
63
- }),
64
  reference_urls=["https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio"],
65
  )
66
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
  def _compute(
68
  self,
69
  predictions,
@@ -72,12 +91,8 @@ class StructuralSimilarityIndexMeasure(evaluate.Metric):
72
  sample_weight=None,
73
  ) -> Dict[str, float]:
74
  samples = zip(predictions, references)
75
- return {
76
- "psnr": np.average(
77
- list(map(
78
- lambda args: peak_signal_noise_ratio(*args, data_range),
79
- samples
80
- )),
81
- weights=sample_weight
82
- )
83
- }
 
57
  description=_DESCRIPTION,
58
  citation=_CITATION,
59
  inputs_description=_KWARGS_DESCRIPTION,
60
+ features=datasets.Features(self._get_feature_types()),
 
 
 
61
  reference_urls=["https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio"],
62
  )
63
 
64
+ def _get_feature_types(self):
65
+ if self.config_name == "multilist":
66
+ return {
67
+ # 1st Seq - num_samples, 2nd Seq - Height, 3rd Seq - Width
68
+ "predictions": datasets.Sequence(
69
+ datasets.Sequence(datasets.Sequence(datasets.Value("float32")))
70
+ ),
71
+ "references": datasets.Sequence(
72
+ datasets.Sequence(datasets.Sequence(datasets.Value("float32")))
73
+ ),
74
+ }
75
+ else:
76
+ return {
77
+ # 1st Seq - Height, 2rd Seq - Width
78
+ "predictions": datasets.Sequence(
79
+ datasets.Sequence(datasets.Value("float32"))
80
+ ),
81
+ "references": datasets.Sequence(
82
+ datasets.Sequence(datasets.Value("float32"))
83
+ ),
84
+ }
85
+
86
  def _compute(
87
  self,
88
  predictions,
 
91
  sample_weight=None,
92
  ) -> Dict[str, float]:
93
  samples = zip(predictions, references)
94
+ return np.average(
95
+ list(map(lambda args: peak_signal_noise_ratio(*args, data_range), samples)),
96
+ weights=sample_weight
97
+ )
98
+