Spaces:
Runtime error
Runtime error
Update value dtypes
Browse files- peak_signal_to_noise_ratio.py +28 -13
peak_signal_to_noise_ratio.py
CHANGED
@@ -57,13 +57,32 @@ class StructuralSimilarityIndexMeasure(evaluate.Metric):
|
|
57 |
description=_DESCRIPTION,
|
58 |
citation=_CITATION,
|
59 |
inputs_description=_KWARGS_DESCRIPTION,
|
60 |
-
features=datasets.Features(
|
61 |
-
"predictions": datasets.Sequence(datasets.Array2D(dtype="float32")),
|
62 |
-
"references": datasets.Sequence(datasets.Array2D(dtype="float32")),
|
63 |
-
}),
|
64 |
reference_urls=["https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio"],
|
65 |
)
|
66 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
67 |
def _compute(
|
68 |
self,
|
69 |
predictions,
|
@@ -72,12 +91,8 @@ class StructuralSimilarityIndexMeasure(evaluate.Metric):
|
|
72 |
sample_weight=None,
|
73 |
) -> Dict[str, float]:
|
74 |
samples = zip(predictions, references)
|
75 |
-
return
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
)),
|
81 |
-
weights=sample_weight
|
82 |
-
)
|
83 |
-
}
|
|
|
57 |
description=_DESCRIPTION,
|
58 |
citation=_CITATION,
|
59 |
inputs_description=_KWARGS_DESCRIPTION,
|
60 |
+
features=datasets.Features(self._get_feature_types()),
|
|
|
|
|
|
|
61 |
reference_urls=["https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio"],
|
62 |
)
|
63 |
|
64 |
+
def _get_feature_types(self):
|
65 |
+
if self.config_name == "multilist":
|
66 |
+
return {
|
67 |
+
# 1st Seq - num_samples, 2nd Seq - Height, 3rd Seq - Width
|
68 |
+
"predictions": datasets.Sequence(
|
69 |
+
datasets.Sequence(datasets.Sequence(datasets.Value("float32")))
|
70 |
+
),
|
71 |
+
"references": datasets.Sequence(
|
72 |
+
datasets.Sequence(datasets.Sequence(datasets.Value("float32")))
|
73 |
+
),
|
74 |
+
}
|
75 |
+
else:
|
76 |
+
return {
|
77 |
+
# 1st Seq - Height, 2rd Seq - Width
|
78 |
+
"predictions": datasets.Sequence(
|
79 |
+
datasets.Sequence(datasets.Value("float32"))
|
80 |
+
),
|
81 |
+
"references": datasets.Sequence(
|
82 |
+
datasets.Sequence(datasets.Value("float32"))
|
83 |
+
),
|
84 |
+
}
|
85 |
+
|
86 |
def _compute(
|
87 |
self,
|
88 |
predictions,
|
|
|
91 |
sample_weight=None,
|
92 |
) -> Dict[str, float]:
|
93 |
samples = zip(predictions, references)
|
94 |
+
return np.average(
|
95 |
+
list(map(lambda args: peak_signal_noise_ratio(*args, data_range), samples)),
|
96 |
+
weights=sample_weight
|
97 |
+
)
|
98 |
+
|
|
|
|
|
|
|
|