|
import math |
|
from functools import partial |
|
from inspect import signature |
|
from itertools import chain, permutations, product |
|
|
|
import numpy as np |
|
import pytest |
|
|
|
from sklearn._config import config_context |
|
from sklearn.datasets import make_multilabel_classification |
|
from sklearn.exceptions import UndefinedMetricWarning |
|
from sklearn.metrics import ( |
|
accuracy_score, |
|
average_precision_score, |
|
balanced_accuracy_score, |
|
brier_score_loss, |
|
cohen_kappa_score, |
|
confusion_matrix, |
|
coverage_error, |
|
d2_absolute_error_score, |
|
d2_pinball_score, |
|
d2_tweedie_score, |
|
dcg_score, |
|
det_curve, |
|
explained_variance_score, |
|
f1_score, |
|
fbeta_score, |
|
hamming_loss, |
|
hinge_loss, |
|
jaccard_score, |
|
label_ranking_average_precision_score, |
|
label_ranking_loss, |
|
log_loss, |
|
matthews_corrcoef, |
|
max_error, |
|
mean_absolute_error, |
|
mean_absolute_percentage_error, |
|
mean_gamma_deviance, |
|
mean_pinball_loss, |
|
mean_poisson_deviance, |
|
mean_squared_error, |
|
mean_squared_log_error, |
|
mean_tweedie_deviance, |
|
median_absolute_error, |
|
multilabel_confusion_matrix, |
|
ndcg_score, |
|
precision_recall_curve, |
|
precision_score, |
|
r2_score, |
|
recall_score, |
|
roc_auc_score, |
|
roc_curve, |
|
root_mean_squared_error, |
|
root_mean_squared_log_error, |
|
top_k_accuracy_score, |
|
zero_one_loss, |
|
) |
|
from sklearn.metrics._base import _average_binary_score |
|
from sklearn.metrics.pairwise import ( |
|
additive_chi2_kernel, |
|
chi2_kernel, |
|
cosine_distances, |
|
cosine_similarity, |
|
euclidean_distances, |
|
linear_kernel, |
|
paired_cosine_distances, |
|
paired_euclidean_distances, |
|
polynomial_kernel, |
|
rbf_kernel, |
|
sigmoid_kernel, |
|
) |
|
from sklearn.preprocessing import LabelBinarizer |
|
from sklearn.utils import shuffle |
|
from sklearn.utils._array_api import ( |
|
_atol_for_type, |
|
_convert_to_numpy, |
|
yield_namespace_device_dtype_combinations, |
|
) |
|
from sklearn.utils._testing import ( |
|
_array_api_for_tests, |
|
assert_allclose, |
|
assert_almost_equal, |
|
assert_array_equal, |
|
assert_array_less, |
|
ignore_warnings, |
|
) |
|
from sklearn.utils.fixes import COO_CONTAINERS, parse_version, sp_version |
|
from sklearn.utils.multiclass import type_of_target |
|
from sklearn.utils.validation import _num_samples, check_random_state |
|
|
|
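# Dictionaries mapping a metric name to a callable (possibly a ``partial`` with
# fixed keyword arguments). They are used by the common invariance tests below.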
REGRESSION_METRICS = { |
|
"max_error": max_error, |
|
"mean_absolute_error": mean_absolute_error, |
|
"mean_squared_error": mean_squared_error, |
|
"mean_squared_log_error": mean_squared_log_error, |
|
"mean_pinball_loss": mean_pinball_loss, |
|
"median_absolute_error": median_absolute_error, |
|
"mean_absolute_percentage_error": mean_absolute_percentage_error, |
|
"explained_variance_score": explained_variance_score, |
|
"r2_score": partial(r2_score, multioutput="variance_weighted"), |
|
"root_mean_squared_error": root_mean_squared_error, |
|
"root_mean_squared_log_error": root_mean_squared_log_error, |
|
"mean_normal_deviance": partial(mean_tweedie_deviance, power=0), |
|
"mean_poisson_deviance": mean_poisson_deviance, |
|
"mean_gamma_deviance": mean_gamma_deviance, |
|
"mean_compound_poisson_deviance": partial(mean_tweedie_deviance, power=1.4), |
|
"d2_tweedie_score": partial(d2_tweedie_score, power=1.4), |
|
"d2_pinball_score": d2_pinball_score, |
|
"d2_absolute_error_score": d2_absolute_error_score, |
|
} |
|
|
|
CLASSIFICATION_METRICS = { |
|
"accuracy_score": accuracy_score, |
|
"balanced_accuracy_score": balanced_accuracy_score, |
|
"adjusted_balanced_accuracy_score": partial(balanced_accuracy_score, adjusted=True), |
|
"unnormalized_accuracy_score": partial(accuracy_score, normalize=False), |
|
|
|
|
|
|
|
|
|
"unnormalized_confusion_matrix": confusion_matrix, |
|
"normalized_confusion_matrix": lambda *args, **kwargs: ( |
|
confusion_matrix(*args, **kwargs).astype("float") |
|
/ confusion_matrix(*args, **kwargs).sum(axis=1)[:, np.newaxis] |
|
), |
|
"unnormalized_multilabel_confusion_matrix": multilabel_confusion_matrix, |
|
"unnormalized_multilabel_confusion_matrix_sample": partial( |
|
multilabel_confusion_matrix, samplewise=True |
|
), |
|
"hamming_loss": hamming_loss, |
|
"zero_one_loss": zero_one_loss, |
|
"unnormalized_zero_one_loss": partial(zero_one_loss, normalize=False), |
|
|
|
"jaccard_score": jaccard_score, |
|
"precision_score": precision_score, |
|
"recall_score": recall_score, |
|
"f1_score": f1_score, |
|
"f2_score": partial(fbeta_score, beta=2), |
|
"f0.5_score": partial(fbeta_score, beta=0.5), |
|
"matthews_corrcoef_score": matthews_corrcoef, |
|
"weighted_f0.5_score": partial(fbeta_score, average="weighted", beta=0.5), |
|
"weighted_f1_score": partial(f1_score, average="weighted"), |
|
"weighted_f2_score": partial(fbeta_score, average="weighted", beta=2), |
|
"weighted_precision_score": partial(precision_score, average="weighted"), |
|
"weighted_recall_score": partial(recall_score, average="weighted"), |
|
"weighted_jaccard_score": partial(jaccard_score, average="weighted"), |
|
"micro_f0.5_score": partial(fbeta_score, average="micro", beta=0.5), |
|
"micro_f1_score": partial(f1_score, average="micro"), |
|
"micro_f2_score": partial(fbeta_score, average="micro", beta=2), |
|
"micro_precision_score": partial(precision_score, average="micro"), |
|
"micro_recall_score": partial(recall_score, average="micro"), |
|
"micro_jaccard_score": partial(jaccard_score, average="micro"), |
|
"macro_f0.5_score": partial(fbeta_score, average="macro", beta=0.5), |
|
"macro_f1_score": partial(f1_score, average="macro"), |
|
"macro_f2_score": partial(fbeta_score, average="macro", beta=2), |
|
"macro_precision_score": partial(precision_score, average="macro"), |
|
"macro_recall_score": partial(recall_score, average="macro"), |
|
"macro_jaccard_score": partial(jaccard_score, average="macro"), |
|
"samples_f0.5_score": partial(fbeta_score, average="samples", beta=0.5), |
|
"samples_f1_score": partial(f1_score, average="samples"), |
|
"samples_f2_score": partial(fbeta_score, average="samples", beta=2), |
|
"samples_precision_score": partial(precision_score, average="samples"), |
|
"samples_recall_score": partial(recall_score, average="samples"), |
|
"samples_jaccard_score": partial(jaccard_score, average="samples"), |
|
"cohen_kappa_score": cohen_kappa_score, |
|
} |
|
|
|
|
|
def precision_recall_curve_padded_thresholds(*args, **kwargs):
    """
    The dimensions of the precision-recall pairs and the thresholds array
    returned by precision_recall_curve do not match. See
    :func:`sklearn.metrics.precision_recall_curve`.

    This prevents implicit conversion of the returned triple to a higher
    dimensional np.array of dtype('float64') (it will be of dtype('object')
    instead). This again is needed for assert_array_equal to work correctly.

    As a workaround we pad the thresholds array with NaN values to match
    the dimensions of the precision and recall arrays respectively.
    """
|
precision, recall, thresholds = precision_recall_curve(*args, **kwargs) |
|
|
|
    pad_thresholds = len(precision) - len(thresholds)
|
|
|
return np.array( |
|
[ |
|
precision, |
|
recall, |
|
np.pad( |
|
thresholds.astype(np.float64), |
|
                pad_width=(0, pad_thresholds),
|
mode="constant", |
|
constant_values=[np.nan], |
|
), |
|
] |
|
) |
|
|
|
|
|
CURVE_METRICS = { |
|
"roc_curve": roc_curve, |
|
"precision_recall_curve": precision_recall_curve_padded_thresholds, |
|
"det_curve": det_curve, |
|
} |
|
|
|
THRESHOLDED_METRICS = { |
|
"coverage_error": coverage_error, |
|
"label_ranking_loss": label_ranking_loss, |
|
"log_loss": log_loss, |
|
"unnormalized_log_loss": partial(log_loss, normalize=False), |
|
"hinge_loss": hinge_loss, |
|
"brier_score_loss": brier_score_loss, |
|
"roc_auc_score": roc_auc_score, |
|
"weighted_roc_auc": partial(roc_auc_score, average="weighted"), |
|
"samples_roc_auc": partial(roc_auc_score, average="samples"), |
|
"micro_roc_auc": partial(roc_auc_score, average="micro"), |
|
"ovr_roc_auc": partial(roc_auc_score, average="macro", multi_class="ovr"), |
|
"weighted_ovr_roc_auc": partial( |
|
roc_auc_score, average="weighted", multi_class="ovr" |
|
), |
|
"ovo_roc_auc": partial(roc_auc_score, average="macro", multi_class="ovo"), |
|
"weighted_ovo_roc_auc": partial( |
|
roc_auc_score, average="weighted", multi_class="ovo" |
|
), |
|
"partial_roc_auc": partial(roc_auc_score, max_fpr=0.5), |
|
"average_precision_score": average_precision_score, |
|
"weighted_average_precision_score": partial( |
|
average_precision_score, average="weighted" |
|
), |
|
"samples_average_precision_score": partial( |
|
average_precision_score, average="samples" |
|
), |
|
"micro_average_precision_score": partial(average_precision_score, average="micro"), |
|
"label_ranking_average_precision_score": label_ranking_average_precision_score, |
|
"ndcg_score": ndcg_score, |
|
"dcg_score": dcg_score, |
|
"top_k_accuracy_score": top_k_accuracy_score, |
|
} |
|
|
|
ALL_METRICS = dict() |
|
ALL_METRICS.update(THRESHOLDED_METRICS) |
|
ALL_METRICS.update(CLASSIFICATION_METRICS) |
|
ALL_METRICS.update(REGRESSION_METRICS) |
|
ALL_METRICS.update(CURVE_METRICS) |
|
|
|
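# Metrics that cannot be computed on plain binary targets (ranking metrics and
# sample-wise averages); they are excluded from the binary tests below.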
METRIC_UNDEFINED_BINARY = { |
|
"samples_f0.5_score", |
|
"samples_f1_score", |
|
"samples_f2_score", |
|
"samples_precision_score", |
|
"samples_recall_score", |
|
"samples_jaccard_score", |
|
"coverage_error", |
|
"unnormalized_multilabel_confusion_matrix_sample", |
|
"label_ranking_loss", |
|
"label_ranking_average_precision_score", |
|
"dcg_score", |
|
"ndcg_score", |
|
} |
|
|
|
|
|
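# Metrics that cannot be computed on multiclass targets in the form listed here
# (binary-only scores and curves).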
METRIC_UNDEFINED_MULTICLASS = { |
|
"brier_score_loss", |
|
"micro_roc_auc", |
|
"samples_roc_auc", |
|
"partial_roc_auc", |
|
"roc_auc_score", |
|
"weighted_roc_auc", |
|
"jaccard_score", |
|
|
|
"precision_score", |
|
"recall_score", |
|
"f1_score", |
|
"f2_score", |
|
"f0.5_score", |
|
|
|
"roc_curve", |
|
"precision_recall_curve", |
|
"det_curve", |
|
} |
|
|
|
|
|
METRIC_UNDEFINED_BINARY_MULTICLASS = METRIC_UNDEFINED_BINARY.union( |
|
METRIC_UNDEFINED_MULTICLASS |
|
) |
|
|
|
|
|
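# Classification metrics that expose an ``average`` parameter.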
METRICS_WITH_AVERAGING = { |
|
"precision_score", |
|
"recall_score", |
|
"f1_score", |
|
"f2_score", |
|
"f0.5_score", |
|
"jaccard_score", |
|
} |
|
|
|
|
|
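# Threshold-based metrics that expose an ``average`` parameter.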
THRESHOLDED_METRICS_WITH_AVERAGING = { |
|
"roc_auc_score", |
|
"average_precision_score", |
|
"partial_roc_auc", |
|
} |
|
|
|
|
|
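# Metrics that expose a ``pos_label`` parameter.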
METRICS_WITH_POS_LABEL = { |
|
"roc_curve", |
|
"precision_recall_curve", |
|
"det_curve", |
|
"brier_score_loss", |
|
"precision_score", |
|
"recall_score", |
|
"f1_score", |
|
"f2_score", |
|
"f0.5_score", |
|
"jaccard_score", |
|
"average_precision_score", |
|
"weighted_average_precision_score", |
|
"micro_average_precision_score", |
|
"samples_average_precision_score", |
|
} |
|
|
|
|
|
|
|
|
|
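# Metrics that expose a ``labels`` parameter.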
METRICS_WITH_LABELS = { |
|
"unnormalized_confusion_matrix", |
|
"normalized_confusion_matrix", |
|
"roc_curve", |
|
"precision_recall_curve", |
|
"det_curve", |
|
"precision_score", |
|
"recall_score", |
|
"f1_score", |
|
"f2_score", |
|
"f0.5_score", |
|
"jaccard_score", |
|
"weighted_f0.5_score", |
|
"weighted_f1_score", |
|
"weighted_f2_score", |
|
"weighted_precision_score", |
|
"weighted_recall_score", |
|
"weighted_jaccard_score", |
|
"micro_f0.5_score", |
|
"micro_f1_score", |
|
"micro_f2_score", |
|
"micro_precision_score", |
|
"micro_recall_score", |
|
"micro_jaccard_score", |
|
"macro_f0.5_score", |
|
"macro_f1_score", |
|
"macro_f2_score", |
|
"macro_precision_score", |
|
"macro_recall_score", |
|
"macro_jaccard_score", |
|
"unnormalized_multilabel_confusion_matrix", |
|
"unnormalized_multilabel_confusion_matrix_sample", |
|
"cohen_kappa_score", |
|
} |
|
|
|
|
|
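# Metrics that expose a ``normalize`` parameter.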
METRICS_WITH_NORMALIZE_OPTION = { |
|
"accuracy_score", |
|
"top_k_accuracy_score", |
|
"zero_one_loss", |
|
} |
|
|
|
|
|
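# Threshold-based metrics that support multilabel-indicator targets.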
THRESHOLDED_MULTILABEL_METRICS = { |
|
"log_loss", |
|
"unnormalized_log_loss", |
|
"roc_auc_score", |
|
"weighted_roc_auc", |
|
"samples_roc_auc", |
|
"micro_roc_auc", |
|
"partial_roc_auc", |
|
"average_precision_score", |
|
"weighted_average_precision_score", |
|
"samples_average_precision_score", |
|
"micro_average_precision_score", |
|
"coverage_error", |
|
"label_ranking_loss", |
|
"ndcg_score", |
|
"dcg_score", |
|
"label_ranking_average_precision_score", |
|
} |
|
|
|
|
|
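# Classification metrics that support multilabel-indicator targets.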
MULTILABELS_METRICS = { |
|
"accuracy_score", |
|
"unnormalized_accuracy_score", |
|
"hamming_loss", |
|
"zero_one_loss", |
|
"unnormalized_zero_one_loss", |
|
"weighted_f0.5_score", |
|
"weighted_f1_score", |
|
"weighted_f2_score", |
|
"weighted_precision_score", |
|
"weighted_recall_score", |
|
"weighted_jaccard_score", |
|
"macro_f0.5_score", |
|
"macro_f1_score", |
|
"macro_f2_score", |
|
"macro_precision_score", |
|
"macro_recall_score", |
|
"macro_jaccard_score", |
|
"micro_f0.5_score", |
|
"micro_f1_score", |
|
"micro_f2_score", |
|
"micro_precision_score", |
|
"micro_recall_score", |
|
"micro_jaccard_score", |
|
"unnormalized_multilabel_confusion_matrix", |
|
"samples_f0.5_score", |
|
"samples_f1_score", |
|
"samples_f2_score", |
|
"samples_precision_score", |
|
"samples_recall_score", |
|
"samples_jaccard_score", |
|
} |
|
|
|
|
|
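# Regression metrics that support multioutput targets.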
MULTIOUTPUT_METRICS = { |
|
"mean_absolute_error", |
|
"median_absolute_error", |
|
"mean_squared_error", |
|
"mean_squared_log_error", |
|
"r2_score", |
|
"root_mean_squared_error", |
|
"root_mean_squared_log_error", |
|
"explained_variance_score", |
|
"mean_absolute_percentage_error", |
|
"mean_pinball_loss", |
|
"d2_pinball_score", |
|
"d2_absolute_error_score", |
|
} |
|
|
|
|
|
|
|
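# Symmetric metrics: metric(y_true, y_pred) == metric(y_pred, y_true).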
SYMMETRIC_METRICS = { |
|
"accuracy_score", |
|
"unnormalized_accuracy_score", |
|
"hamming_loss", |
|
"zero_one_loss", |
|
"unnormalized_zero_one_loss", |
|
"micro_jaccard_score", |
|
"macro_jaccard_score", |
|
"jaccard_score", |
|
"samples_jaccard_score", |
|
"f1_score", |
|
"micro_f1_score", |
|
"macro_f1_score", |
|
"weighted_recall_score", |
|
"mean_squared_log_error", |
|
"root_mean_squared_error", |
|
"root_mean_squared_log_error", |
|
|
|
"micro_f0.5_score", |
|
"micro_f1_score", |
|
"micro_f2_score", |
|
"micro_precision_score", |
|
"micro_recall_score", |
|
"matthews_corrcoef_score", |
|
"mean_absolute_error", |
|
"mean_squared_error", |
|
"median_absolute_error", |
|
"max_error", |
|
|
|
"mean_pinball_loss", |
|
"cohen_kappa_score", |
|
"mean_normal_deviance", |
|
} |
|
|
|
|
|
|
|
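# Metrics that are not symmetric with respect to their arguments.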
NOT_SYMMETRIC_METRICS = { |
|
"balanced_accuracy_score", |
|
"adjusted_balanced_accuracy_score", |
|
"explained_variance_score", |
|
"r2_score", |
|
"unnormalized_confusion_matrix", |
|
"normalized_confusion_matrix", |
|
"roc_curve", |
|
"precision_recall_curve", |
|
"det_curve", |
|
"precision_score", |
|
"recall_score", |
|
"f2_score", |
|
"f0.5_score", |
|
"weighted_f0.5_score", |
|
"weighted_f1_score", |
|
"weighted_f2_score", |
|
"weighted_precision_score", |
|
"weighted_jaccard_score", |
|
"unnormalized_multilabel_confusion_matrix", |
|
"macro_f0.5_score", |
|
"macro_f2_score", |
|
"macro_precision_score", |
|
"macro_recall_score", |
|
"hinge_loss", |
|
"mean_gamma_deviance", |
|
"mean_poisson_deviance", |
|
"mean_compound_poisson_deviance", |
|
"d2_tweedie_score", |
|
"d2_pinball_score", |
|
"d2_absolute_error_score", |
|
"mean_absolute_percentage_error", |
|
} |
|
|
|
|
|
|
|
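# Metrics that are excluded from the sample_weight invariance tests.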
METRICS_WITHOUT_SAMPLE_WEIGHT = { |
|
"median_absolute_error", |
|
"max_error", |
|
"ovo_roc_auc", |
|
"weighted_ovo_roc_auc", |
|
} |
|
|
|
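# Metrics that require strictly positive targets.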
METRICS_REQUIRE_POSITIVE_Y = { |
|
"mean_poisson_deviance", |
|
"mean_gamma_deviance", |
|
"mean_compound_poisson_deviance", |
|
"d2_tweedie_score", |
|
} |
|
|
|
|
|
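# Metrics that apply log1p to the targets and therefore require y > -1.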
METRICS_WITH_LOG1P_Y = { |
|
"mean_squared_log_error", |
|
"root_mean_squared_log_error", |
|
} |
|
|
|
|
|
def _require_positive_targets(y1, y2): |
|
"""Make targets strictly positive""" |
|
offset = abs(min(y1.min(), y2.min())) + 1 |
|
y1 += offset |
|
y2 += offset |
|
return y1, y2 |
|
|
|
|
|
def _require_log1p_targets(y1, y2): |
|
"""Make targets strictly larger than -1""" |
|
offset = abs(min(y1.min(), y2.min())) - 0.99 |
|
y1 = y1.astype(np.float64) |
|
y2 = y2.astype(np.float64) |
|
y1 += offset |
|
y2 += offset |
|
return y1, y2 |
|
|
|
|
|
def test_symmetry_consistency(): |
|
|
|
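    """Check that every metric is registered as either symmetric or not
    symmetric (thresholded and undefined metrics are tracked separately) and
    that the two sets are disjoint."""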
assert ( |
|
SYMMETRIC_METRICS |
|
| NOT_SYMMETRIC_METRICS |
|
| set(THRESHOLDED_METRICS) |
|
| METRIC_UNDEFINED_BINARY_MULTICLASS |
|
) == set(ALL_METRICS) |
|
|
|
assert (SYMMETRIC_METRICS & NOT_SYMMETRIC_METRICS) == set() |
|
|
|
|
|
@pytest.mark.parametrize("name", sorted(SYMMETRIC_METRICS)) |
|
def test_symmetric_metric(name): |
|
|
|
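    """Check that metric(y_true, y_pred) == metric(y_pred, y_true)."""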
random_state = check_random_state(0) |
|
y_true = random_state.randint(0, 2, size=(20,)) |
|
y_pred = random_state.randint(0, 2, size=(20,)) |
|
|
|
if name in METRICS_REQUIRE_POSITIVE_Y: |
|
y_true, y_pred = _require_positive_targets(y_true, y_pred) |
|
|
|
elif name in METRICS_WITH_LOG1P_Y: |
|
y_true, y_pred = _require_log1p_targets(y_true, y_pred) |
|
|
|
y_true_bin = random_state.randint(0, 2, size=(20, 25)) |
|
y_pred_bin = random_state.randint(0, 2, size=(20, 25)) |
|
|
|
metric = ALL_METRICS[name] |
|
if name in METRIC_UNDEFINED_BINARY: |
|
if name in MULTILABELS_METRICS: |
|
assert_allclose( |
|
metric(y_true_bin, y_pred_bin), |
|
metric(y_pred_bin, y_true_bin), |
|
err_msg="%s is not symmetric" % name, |
|
) |
|
else: |
|
assert False, "This case is currently unhandled" |
|
else: |
|
assert_allclose( |
|
metric(y_true, y_pred), |
|
metric(y_pred, y_true), |
|
err_msg="%s is not symmetric" % name, |
|
) |
|
|
|
|
|
@pytest.mark.parametrize("name", sorted(NOT_SYMMETRIC_METRICS)) |
|
def test_not_symmetric_metric(name): |
|
|
|
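    """Check that metric(y_true, y_pred) != metric(y_pred, y_true)."""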
random_state = check_random_state(0) |
|
y_true = random_state.randint(0, 2, size=(20,)) |
|
y_pred = random_state.randint(0, 2, size=(20,)) |
|
|
|
if name in METRICS_REQUIRE_POSITIVE_Y: |
|
y_true, y_pred = _require_positive_targets(y_true, y_pred) |
|
|
|
metric = ALL_METRICS[name] |
|
|
|
|
|
with pytest.raises(AssertionError): |
|
assert_array_equal(metric(y_true, y_pred), metric(y_pred, y_true)) |
|
raise ValueError("%s seems to be symmetric" % name) |
|
|
|
|
|
@pytest.mark.parametrize( |
|
"name", sorted(set(ALL_METRICS) - METRIC_UNDEFINED_BINARY_MULTICLASS) |
|
) |
|
def test_sample_order_invariance(name): |
|
random_state = check_random_state(0) |
|
y_true = random_state.randint(0, 2, size=(20,)) |
|
y_pred = random_state.randint(0, 2, size=(20,)) |
|
|
|
if name in METRICS_REQUIRE_POSITIVE_Y: |
|
y_true, y_pred = _require_positive_targets(y_true, y_pred) |
|
elif name in METRICS_WITH_LOG1P_Y: |
|
y_true, y_pred = _require_log1p_targets(y_true, y_pred) |
|
|
|
y_true_shuffle, y_pred_shuffle = shuffle(y_true, y_pred, random_state=0) |
|
|
|
with ignore_warnings(): |
|
metric = ALL_METRICS[name] |
|
assert_allclose( |
|
metric(y_true, y_pred), |
|
metric(y_true_shuffle, y_pred_shuffle), |
|
err_msg="%s is not sample order invariant" % name, |
|
) |
|
|
|
|
|
def test_sample_order_invariance_multilabel_and_multioutput(): |
|
random_state = check_random_state(0) |
|
|
|
|
|
y_true = random_state.randint(0, 2, size=(20, 25)) |
|
y_pred = random_state.randint(0, 2, size=(20, 25)) |
|
y_score = random_state.uniform(size=y_true.shape) |
|
|
|
|
|
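    # Normalize rows to sum to one so that probabilistic metrics (e.g. log_loss)
    # accept y_score.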
y_score /= y_score.sum(axis=1, keepdims=True) |
|
|
|
y_true_shuffle, y_pred_shuffle, y_score_shuffle = shuffle( |
|
y_true, y_pred, y_score, random_state=0 |
|
) |
|
|
|
for name in MULTILABELS_METRICS: |
|
metric = ALL_METRICS[name] |
|
assert_allclose( |
|
metric(y_true, y_pred), |
|
metric(y_true_shuffle, y_pred_shuffle), |
|
err_msg="%s is not sample order invariant" % name, |
|
) |
|
|
|
for name in THRESHOLDED_MULTILABEL_METRICS: |
|
metric = ALL_METRICS[name] |
|
assert_allclose( |
|
metric(y_true, y_score), |
|
metric(y_true_shuffle, y_score_shuffle), |
|
err_msg="%s is not sample order invariant" % name, |
|
) |
|
|
|
for name in MULTIOUTPUT_METRICS: |
|
metric = ALL_METRICS[name] |
|
assert_allclose( |
|
metric(y_true, y_score), |
|
metric(y_true_shuffle, y_score_shuffle), |
|
err_msg="%s is not sample order invariant" % name, |
|
) |
|
assert_allclose( |
|
metric(y_true, y_pred), |
|
metric(y_true_shuffle, y_pred_shuffle), |
|
err_msg="%s is not sample order invariant" % name, |
|
) |
|
|
|
|
|
@pytest.mark.parametrize( |
|
"name", sorted(set(ALL_METRICS) - METRIC_UNDEFINED_BINARY_MULTICLASS) |
|
) |
|
def test_format_invariance_with_1d_vectors(name): |
|
random_state = check_random_state(0) |
|
y1 = random_state.randint(0, 2, size=(20,)) |
|
y2 = random_state.randint(0, 2, size=(20,)) |
|
|
|
if name in METRICS_REQUIRE_POSITIVE_Y: |
|
y1, y2 = _require_positive_targets(y1, y2) |
|
elif name in METRICS_WITH_LOG1P_Y: |
|
y1, y2 = _require_log1p_targets(y1, y2) |
|
|
|
y1_list = list(y1) |
|
y2_list = list(y2) |
|
|
|
y1_1d, y2_1d = np.array(y1), np.array(y2) |
|
assert_array_equal(y1_1d.ndim, 1) |
|
assert_array_equal(y2_1d.ndim, 1) |
|
y1_column = np.reshape(y1_1d, (-1, 1)) |
|
y2_column = np.reshape(y2_1d, (-1, 1)) |
|
y1_row = np.reshape(y1_1d, (1, -1)) |
|
y2_row = np.reshape(y2_1d, (1, -1)) |
|
|
|
with ignore_warnings(): |
|
metric = ALL_METRICS[name] |
|
|
|
measure = metric(y1, y2) |
|
|
|
assert_allclose( |
|
metric(y1_list, y2_list), |
|
measure, |
|
err_msg="%s is not representation invariant with list" % name, |
|
) |
|
|
|
assert_allclose( |
|
metric(y1_1d, y2_1d), |
|
measure, |
|
err_msg="%s is not representation invariant with np-array-1d" % name, |
|
) |
|
|
|
assert_allclose( |
|
metric(y1_column, y2_column), |
|
measure, |
|
err_msg="%s is not representation invariant with np-array-column" % name, |
|
) |
|
|
|
|
|
assert_allclose( |
|
metric(y1_1d, y2_list), |
|
measure, |
|
err_msg="%s is not representation invariant with mix np-array-1d and list" |
|
% name, |
|
) |
|
|
|
assert_allclose( |
|
metric(y1_list, y2_1d), |
|
measure, |
|
err_msg="%s is not representation invariant with mix np-array-1d and list" |
|
% name, |
|
) |
|
|
|
assert_allclose( |
|
metric(y1_1d, y2_column), |
|
measure, |
|
err_msg=( |
|
"%s is not representation invariant with mix " |
|
"np-array-1d and np-array-column" |
|
) |
|
% name, |
|
) |
|
|
|
assert_allclose( |
|
metric(y1_column, y2_1d), |
|
measure, |
|
err_msg=( |
|
"%s is not representation invariant with mix " |
|
"np-array-1d and np-array-column" |
|
) |
|
% name, |
|
) |
|
|
|
assert_allclose( |
|
metric(y1_list, y2_column), |
|
measure, |
|
err_msg=( |
|
"%s is not representation invariant with mix list and np-array-column" |
|
) |
|
% name, |
|
) |
|
|
|
assert_allclose( |
|
metric(y1_column, y2_list), |
|
measure, |
|
err_msg=( |
|
"%s is not representation invariant with mix list and np-array-column" |
|
) |
|
% name, |
|
) |
|
|
|
|
|
with pytest.raises(ValueError): |
|
metric(y1_1d, y2_row) |
|
with pytest.raises(ValueError): |
|
metric(y1_row, y2_1d) |
|
with pytest.raises(ValueError): |
|
metric(y1_list, y2_row) |
|
with pytest.raises(ValueError): |
|
metric(y1_row, y2_list) |
|
with pytest.raises(ValueError): |
|
metric(y1_column, y2_row) |
|
with pytest.raises(ValueError): |
|
metric(y1_row, y2_column) |
|
|
|
|
|
|
|
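        # Metrics with multilabel or multioutput support may legitimately
        # interpret row vectors as a single multi-dimensional sample, so only
        # the remaining metrics are expected to reject them.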
if name not in ( |
|
MULTIOUTPUT_METRICS | THRESHOLDED_MULTILABEL_METRICS | MULTILABELS_METRICS |
|
): |
|
if "roc_auc" in name: |
|
|
|
|
|
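            # roc_auc variants warn and return NaN on such degenerate input
            # instead of raising.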
with pytest.warns(UndefinedMetricWarning): |
|
assert math.isnan(metric(y1_row, y2_row)) |
|
else: |
|
with pytest.raises(ValueError): |
|
metric(y1_row, y2_row) |
|
|
|
|
|
@pytest.mark.parametrize( |
|
"name", sorted(set(CLASSIFICATION_METRICS) - METRIC_UNDEFINED_BINARY_MULTICLASS) |
|
) |
|
def test_classification_invariance_string_vs_numbers_labels(name): |
|
|
|
random_state = check_random_state(0) |
|
y1 = random_state.randint(0, 2, size=(20,)) |
|
y2 = random_state.randint(0, 2, size=(20,)) |
|
|
|
y1_str = np.array(["eggs", "spam"])[y1] |
|
y2_str = np.array(["eggs", "spam"])[y2] |
|
|
|
pos_label_str = "spam" |
|
labels_str = ["eggs", "spam"] |
|
|
|
with ignore_warnings(): |
|
metric = CLASSIFICATION_METRICS[name] |
|
measure_with_number = metric(y1, y2) |
|
|
|
|
|
metric_str = metric |
|
if name in METRICS_WITH_POS_LABEL: |
|
metric_str = partial(metric_str, pos_label=pos_label_str) |
|
|
|
measure_with_str = metric_str(y1_str, y2_str) |
|
|
|
assert_array_equal( |
|
measure_with_number, |
|
measure_with_str, |
|
err_msg="{0} failed string vs number invariance test".format(name), |
|
) |
|
|
|
measure_with_strobj = metric_str(y1_str.astype("O"), y2_str.astype("O")) |
|
assert_array_equal( |
|
measure_with_number, |
|
measure_with_strobj, |
|
err_msg="{0} failed string object vs number invariance test".format(name), |
|
) |
|
|
|
if name in METRICS_WITH_LABELS: |
|
metric_str = partial(metric_str, labels=labels_str) |
|
measure_with_str = metric_str(y1_str, y2_str) |
|
assert_array_equal( |
|
measure_with_number, |
|
measure_with_str, |
|
err_msg="{0} failed string vs number invariance test".format(name), |
|
) |
|
|
|
measure_with_strobj = metric_str(y1_str.astype("O"), y2_str.astype("O")) |
|
assert_array_equal( |
|
measure_with_number, |
|
measure_with_strobj, |
|
err_msg="{0} failed string vs number invariance test".format(name), |
|
) |
|
|
|
|
|
@pytest.mark.parametrize("name", THRESHOLDED_METRICS) |
|
def test_thresholded_invariance_string_vs_numbers_labels(name): |
|
|
|
random_state = check_random_state(0) |
|
y1 = random_state.randint(0, 2, size=(20,)) |
|
y2 = random_state.randint(0, 2, size=(20,)) |
|
|
|
y1_str = np.array(["eggs", "spam"])[y1] |
|
|
|
pos_label_str = "spam" |
|
|
|
with ignore_warnings(): |
|
metric = THRESHOLDED_METRICS[name] |
|
if name not in METRIC_UNDEFINED_BINARY: |
|
|
|
metric_str = metric |
|
if name in METRICS_WITH_POS_LABEL: |
|
metric_str = partial(metric_str, pos_label=pos_label_str) |
|
|
|
measure_with_number = metric(y1, y2) |
|
measure_with_str = metric_str(y1_str, y2) |
|
assert_array_equal( |
|
measure_with_number, |
|
measure_with_str, |
|
err_msg="{0} failed string vs number invariance test".format(name), |
|
) |
|
|
|
measure_with_strobj = metric_str(y1_str.astype("O"), y2) |
|
assert_array_equal( |
|
measure_with_number, |
|
measure_with_strobj, |
|
err_msg="{0} failed string object vs number invariance test".format( |
|
name |
|
), |
|
) |
|
else: |
|
|
|
with pytest.raises(ValueError): |
|
metric(y1_str, y2) |
|
with pytest.raises(ValueError): |
|
metric(y1_str.astype("O"), y2) |
|
|
|
|
|
invalids_nan_inf = [ |
|
([0, 1], [np.inf, np.inf]), |
|
([0, 1], [np.nan, np.nan]), |
|
([0, 1], [np.nan, np.inf]), |
|
([0, 1], [np.inf, 1]), |
|
([0, 1], [np.nan, 1]), |
|
] |
|
|
|
|
|
@pytest.mark.parametrize( |
|
"metric", chain(THRESHOLDED_METRICS.values(), REGRESSION_METRICS.values()) |
|
) |
|
@pytest.mark.parametrize("y_true, y_score", invalids_nan_inf) |
|
def test_regression_thresholded_inf_nan_input(metric, y_true, y_score): |
|
|
|
if metric == coverage_error: |
|
y_true = [y_true] |
|
y_score = [y_score] |
|
with pytest.raises(ValueError, match=r"contains (NaN|infinity)"): |
|
metric(y_true, y_score) |
|
|
|
|
|
@pytest.mark.parametrize("metric", CLASSIFICATION_METRICS.values()) |
|
@pytest.mark.parametrize( |
|
"y_true, y_score", |
|
invalids_nan_inf + |
|
|
|
|
|
|
|
[ |
|
([np.nan, 1, 2], [1, 2, 3]), |
|
([np.inf, 1, 2], [1, 2, 3]), |
|
], |
|
) |
|
def test_classification_inf_nan_input(metric, y_true, y_score):
    """Check that classification metrics raise an error message mentioning
    the occurrence of non-finite values in the target vectors."""
|
if not np.isfinite(y_true).all(): |
|
input_name = "y_true" |
|
if np.isnan(y_true).any(): |
|
unexpected_value = "NaN" |
|
else: |
|
unexpected_value = "infinity or a value too large" |
|
else: |
|
input_name = "y_pred" |
|
if np.isnan(y_score).any(): |
|
unexpected_value = "NaN" |
|
else: |
|
unexpected_value = "infinity or a value too large" |
|
|
|
err_msg = f"Input {input_name} contains {unexpected_value}" |
|
|
|
with pytest.raises(ValueError, match=err_msg): |
|
metric(y_true, y_score) |
|
|
|
|
|
@pytest.mark.parametrize("metric", CLASSIFICATION_METRICS.values()) |
|
def test_classification_binary_continuous_input(metric):
    """Check that classification metrics raise an error message when the
    target vectors mix binary and continuous values."""
|
y_true, y_score = ["a", "b", "a"], [0.1, 0.2, 0.3] |
|
err_msg = ( |
|
"Classification metrics can't handle a mix of binary and continuous targets" |
|
) |
|
with pytest.raises(ValueError, match=err_msg): |
|
metric(y_true, y_score) |
|
|
|
|
|
def check_single_sample(name): |
|
|
|
|
|
|
|
|
|
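    """Check that the metric can be computed on a single sample pair."""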
metric = ALL_METRICS[name] |
|
|
|
|
|
if name in METRICS_REQUIRE_POSITIVE_Y: |
|
values = [1, 2] |
|
elif name in METRICS_WITH_LOG1P_Y: |
|
values = [-0.7, 1] |
|
else: |
|
values = [0, 1] |
|
for i, j in product(values, repeat=2): |
|
metric([i], [j]) |
|
|
|
|
|
def check_single_sample_multioutput(name): |
|
metric = ALL_METRICS[name] |
|
for i, j, k, l in product([0, 1], repeat=4): |
|
metric(np.array([[i, j]]), np.array([[k, l]])) |
|
|
|
|
|
|
|
@pytest.mark.filterwarnings("ignore") |
|
@pytest.mark.parametrize( |
|
"name", |
|
sorted( |
|
set(ALL_METRICS) |
|
|
|
|
|
- METRIC_UNDEFINED_BINARY_MULTICLASS |
|
- set(THRESHOLDED_METRICS) |
|
), |
|
) |
|
def test_single_sample(name): |
|
check_single_sample(name) |
|
|
|
|
|
|
|
@pytest.mark.filterwarnings("ignore") |
|
@pytest.mark.parametrize("name", sorted(MULTIOUTPUT_METRICS | MULTILABELS_METRICS)) |
|
def test_single_sample_multioutput(name): |
|
check_single_sample_multioutput(name) |
|
|
|
|
|
@pytest.mark.parametrize("name", sorted(MULTIOUTPUT_METRICS)) |
|
def test_multioutput_number_of_output_differ(name): |
|
y_true = np.array([[1, 0, 0, 1], [0, 1, 1, 1], [1, 1, 0, 1]]) |
|
y_pred = np.array([[0, 0], [1, 0], [0, 0]]) |
|
|
|
metric = ALL_METRICS[name] |
|
with pytest.raises(ValueError): |
|
metric(y_true, y_pred) |
|
|
|
|
|
@pytest.mark.parametrize("name", sorted(MULTIOUTPUT_METRICS)) |
|
def test_multioutput_regression_invariance_to_dimension_shuffling(name): |
|
|
|
random_state = check_random_state(0) |
|
y_true = random_state.uniform(0, 2, size=(20, 5)) |
|
y_pred = random_state.uniform(0, 2, size=(20, 5)) |
|
|
|
metric = ALL_METRICS[name] |
|
error = metric(y_true, y_pred) |
|
|
|
for _ in range(3): |
|
perm = random_state.permutation(y_true.shape[1]) |
|
assert_allclose( |
|
metric(y_true[:, perm], y_pred[:, perm]), |
|
error, |
|
err_msg="%s is not dimension shuffling invariant" % (name), |
|
) |
|
|
|
|
|
@pytest.mark.filterwarnings("ignore::sklearn.exceptions.UndefinedMetricWarning") |
|
@pytest.mark.parametrize("coo_container", COO_CONTAINERS) |
|
def test_multilabel_representation_invariance(coo_container): |
|
|
|
n_classes = 4 |
|
n_samples = 50 |
|
|
|
_, y1 = make_multilabel_classification( |
|
n_features=1, |
|
n_classes=n_classes, |
|
random_state=0, |
|
n_samples=n_samples, |
|
allow_unlabeled=True, |
|
) |
|
_, y2 = make_multilabel_classification( |
|
n_features=1, |
|
n_classes=n_classes, |
|
random_state=1, |
|
n_samples=n_samples, |
|
allow_unlabeled=True, |
|
) |
|
|
|
|
|
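    # Append an all-zero row so that at least one sample has no label.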
y1 = np.vstack([y1, [[0] * n_classes]]) |
|
y2 = np.vstack([y2, [[0] * n_classes]]) |
|
|
|
y1_sparse_indicator = coo_container(y1) |
|
y2_sparse_indicator = coo_container(y2) |
|
|
|
y1_list_array_indicator = list(y1) |
|
y2_list_array_indicator = list(y2) |
|
|
|
y1_list_list_indicator = [list(a) for a in y1_list_array_indicator] |
|
y2_list_list_indicator = [list(a) for a in y2_list_array_indicator] |
|
|
|
for name in MULTILABELS_METRICS: |
|
metric = ALL_METRICS[name] |
|
|
|
|
|
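        # partial objects have no __name__ or __module__ attribute; set them
        # explicitly.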
if isinstance(metric, partial): |
|
metric.__module__ = "tmp" |
|
metric.__name__ = name |
|
|
|
measure = metric(y1, y2) |
|
|
|
|
|
assert_allclose( |
|
metric(y1_sparse_indicator, y2_sparse_indicator), |
|
measure, |
|
err_msg=( |
|
"%s failed representation invariance between " |
|
"dense and sparse indicator formats." |
|
) |
|
% name, |
|
) |
|
assert_almost_equal( |
|
metric(y1_list_list_indicator, y2_list_list_indicator), |
|
measure, |
|
err_msg=( |
|
"%s failed representation invariance " |
|
"between dense array and list of list " |
|
"indicator formats." |
|
) |
|
% name, |
|
) |
|
assert_almost_equal( |
|
metric(y1_list_array_indicator, y2_list_array_indicator), |
|
measure, |
|
err_msg=( |
|
"%s failed representation invariance " |
|
"between dense and list of array " |
|
"indicator formats." |
|
) |
|
% name, |
|
) |
|
|
|
|
|
@pytest.mark.parametrize("name", sorted(MULTILABELS_METRICS)) |
|
def test_raise_value_error_multilabel_sequences(name): |
|
|
|
multilabel_sequences = [ |
|
[[1], [2], [0, 1]], |
|
[(), (2), (0, 1)], |
|
[[]], |
|
[()], |
|
np.array([[], [1, 2]], dtype="object"), |
|
] |
|
|
|
metric = ALL_METRICS[name] |
|
for seq in multilabel_sequences: |
|
with pytest.raises(ValueError): |
|
metric(seq, seq) |
|
|
|
|
|
@pytest.mark.parametrize("name", sorted(METRICS_WITH_NORMALIZE_OPTION)) |
|
def test_normalize_option_binary_classification(name): |
|
|
|
n_classes = 2 |
|
n_samples = 20 |
|
random_state = check_random_state(0) |
|
|
|
y_true = random_state.randint(0, n_classes, size=(n_samples,)) |
|
y_pred = random_state.randint(0, n_classes, size=(n_samples,)) |
|
y_score = random_state.normal(size=y_true.shape) |
|
|
|
metrics = ALL_METRICS[name] |
|
pred = y_score if name in THRESHOLDED_METRICS else y_pred |
|
measure_normalized = metrics(y_true, pred, normalize=True) |
|
measure_not_normalized = metrics(y_true, pred, normalize=False) |
|
|
|
assert_array_less( |
|
-1.0 * measure_normalized, |
|
0, |
|
        err_msg="We failed to correctly test the normalize option",
|
) |
|
|
|
assert_allclose( |
|
measure_normalized, |
|
measure_not_normalized / n_samples, |
|
err_msg=f"Failed with {name}", |
|
) |
|
|
|
|
|
@pytest.mark.parametrize("name", sorted(METRICS_WITH_NORMALIZE_OPTION)) |
|
def test_normalize_option_multiclass_classification(name): |
|
|
|
n_classes = 4 |
|
n_samples = 20 |
|
random_state = check_random_state(0) |
|
|
|
y_true = random_state.randint(0, n_classes, size=(n_samples,)) |
|
y_pred = random_state.randint(0, n_classes, size=(n_samples,)) |
|
y_score = random_state.uniform(size=(n_samples, n_classes)) |
|
|
|
metrics = ALL_METRICS[name] |
|
pred = y_score if name in THRESHOLDED_METRICS else y_pred |
|
measure_normalized = metrics(y_true, pred, normalize=True) |
|
measure_not_normalized = metrics(y_true, pred, normalize=False) |
|
|
|
assert_array_less( |
|
-1.0 * measure_normalized, |
|
0, |
|
        err_msg="We failed to correctly test the normalize option",
|
) |
|
|
|
assert_allclose( |
|
measure_normalized, |
|
measure_not_normalized / n_samples, |
|
err_msg=f"Failed with {name}", |
|
) |
|
|
|
|
|
@pytest.mark.parametrize( |
|
"name", sorted(METRICS_WITH_NORMALIZE_OPTION.intersection(MULTILABELS_METRICS)) |
|
) |
|
def test_normalize_option_multilabel_classification(name): |
|
|
|
n_classes = 4 |
|
n_samples = 100 |
|
random_state = check_random_state(0) |
|
|
|
|
|
|
|
_, y_true = make_multilabel_classification( |
|
n_features=1, |
|
n_classes=n_classes, |
|
random_state=0, |
|
allow_unlabeled=True, |
|
n_samples=n_samples, |
|
) |
|
_, y_pred = make_multilabel_classification( |
|
n_features=1, |
|
n_classes=n_classes, |
|
random_state=1, |
|
allow_unlabeled=True, |
|
n_samples=n_samples, |
|
) |
|
|
|
y_score = random_state.uniform(size=y_true.shape) |
|
|
|
|
|
y_true += [0] * n_classes |
|
y_pred += [0] * n_classes |
|
|
|
metrics = ALL_METRICS[name] |
|
pred = y_score if name in THRESHOLDED_METRICS else y_pred |
|
measure_normalized = metrics(y_true, pred, normalize=True) |
|
measure_not_normalized = metrics(y_true, pred, normalize=False) |
|
|
|
assert_array_less( |
|
-1.0 * measure_normalized, |
|
0, |
|
        err_msg="We failed to correctly test the normalize option",
|
) |
|
|
|
assert_allclose( |
|
measure_normalized, |
|
measure_not_normalized / n_samples, |
|
err_msg=f"Failed with {name}", |
|
) |
|
|
|
|
|
def _check_averaging( |
|
metric, y_true, y_pred, y_true_binarize, y_pred_binarize, is_multilabel |
|
): |
|
n_samples, n_classes = y_true_binarize.shape |
|
|
|
|
|
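    # No averaging: one score per label.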
label_measure = metric(y_true, y_pred, average=None) |
|
assert_allclose( |
|
label_measure, |
|
[ |
|
metric(y_true_binarize[:, i], y_pred_binarize[:, i]) |
|
for i in range(n_classes) |
|
], |
|
) |
|
|
|
|
|
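    # Micro averaging: score computed on the flattened binarized arrays.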
micro_measure = metric(y_true, y_pred, average="micro") |
|
assert_allclose( |
|
micro_measure, metric(y_true_binarize.ravel(), y_pred_binarize.ravel()) |
|
) |
|
|
|
|
|
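    # Macro averaging: unweighted mean of the per-label scores.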
macro_measure = metric(y_true, y_pred, average="macro") |
|
assert_allclose(macro_measure, np.mean(label_measure)) |
|
|
|
|
|
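    # Weighted averaging: per-label scores weighted by label support.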
weights = np.sum(y_true_binarize, axis=0, dtype=int) |
|
|
|
if np.sum(weights) != 0: |
|
weighted_measure = metric(y_true, y_pred, average="weighted") |
|
assert_allclose(weighted_measure, np.average(label_measure, weights=weights)) |
|
else: |
|
weighted_measure = metric(y_true, y_pred, average="weighted") |
|
assert_allclose(weighted_measure, 0) |
|
|
|
|
|
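    # Sample-based averaging: mean of per-sample scores (multilabel only).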
if is_multilabel: |
|
sample_measure = metric(y_true, y_pred, average="samples") |
|
assert_allclose( |
|
sample_measure, |
|
np.mean( |
|
[ |
|
metric(y_true_binarize[i], y_pred_binarize[i]) |
|
for i in range(n_samples) |
|
] |
|
), |
|
) |
|
|
|
with pytest.raises(ValueError): |
|
metric(y_true, y_pred, average="unknown") |
|
with pytest.raises(ValueError): |
|
metric(y_true, y_pred, average="garbage") |
|
|
|
|
|
def check_averaging(name, y_true, y_true_binarize, y_pred, y_pred_binarize, y_score): |
|
is_multilabel = type_of_target(y_true).startswith("multilabel") |
|
|
|
metric = ALL_METRICS[name] |
|
|
|
if name in METRICS_WITH_AVERAGING: |
|
_check_averaging( |
|
metric, y_true, y_pred, y_true_binarize, y_pred_binarize, is_multilabel |
|
) |
|
elif name in THRESHOLDED_METRICS_WITH_AVERAGING: |
|
_check_averaging( |
|
metric, y_true, y_score, y_true_binarize, y_score, is_multilabel |
|
) |
|
else: |
|
raise ValueError("Metric is not recorded as having an average option") |
|
|
|
|
|
@pytest.mark.parametrize("name", sorted(METRICS_WITH_AVERAGING)) |
|
def test_averaging_multiclass(name): |
|
n_samples, n_classes = 50, 3 |
|
random_state = check_random_state(0) |
|
y_true = random_state.randint(0, n_classes, size=(n_samples,)) |
|
y_pred = random_state.randint(0, n_classes, size=(n_samples,)) |
|
y_score = random_state.uniform(size=(n_samples, n_classes)) |
|
|
|
lb = LabelBinarizer().fit(y_true) |
|
y_true_binarize = lb.transform(y_true) |
|
y_pred_binarize = lb.transform(y_pred) |
|
|
|
check_averaging(name, y_true, y_true_binarize, y_pred, y_pred_binarize, y_score) |
|
|
|
|
|
@pytest.mark.parametrize( |
|
"name", sorted(METRICS_WITH_AVERAGING | THRESHOLDED_METRICS_WITH_AVERAGING) |
|
) |
|
def test_averaging_multilabel(name): |
|
n_samples, n_classes = 40, 5 |
|
_, y = make_multilabel_classification( |
|
n_features=1, |
|
n_classes=n_classes, |
|
random_state=5, |
|
n_samples=n_samples, |
|
allow_unlabeled=False, |
|
) |
|
y_true = y[:20] |
|
y_pred = y[20:] |
|
y_score = check_random_state(0).normal(size=(20, n_classes)) |
|
y_true_binarize = y_true |
|
y_pred_binarize = y_pred |
|
|
|
check_averaging(name, y_true, y_true_binarize, y_pred, y_pred_binarize, y_score) |
|
|
|
|
|
@pytest.mark.parametrize("name", sorted(METRICS_WITH_AVERAGING)) |
|
def test_averaging_multilabel_all_zeroes(name): |
|
y_true = np.zeros((20, 3)) |
|
y_pred = np.zeros((20, 3)) |
|
y_score = np.zeros((20, 3)) |
|
y_true_binarize = y_true |
|
y_pred_binarize = y_pred |
|
|
|
check_averaging(name, y_true, y_true_binarize, y_pred, y_pred_binarize, y_score) |
|
|
|
|
|
def test_averaging_binary_multilabel_all_zeroes(): |
|
y_true = np.zeros((20, 3)) |
|
y_pred = np.zeros((20, 3)) |
|
y_true_binarize = y_true |
|
y_pred_binarize = y_pred |
|
|
|
binary_metric = lambda y_true, y_score, average="macro": _average_binary_score( |
|
precision_score, y_true, y_score, average |
|
) |
|
_check_averaging( |
|
binary_metric, |
|
y_true, |
|
y_pred, |
|
y_true_binarize, |
|
y_pred_binarize, |
|
is_multilabel=True, |
|
) |
|
|
|
|
|
@pytest.mark.parametrize("name", sorted(METRICS_WITH_AVERAGING)) |
|
def test_averaging_multilabel_all_ones(name): |
|
y_true = np.ones((20, 3)) |
|
y_pred = np.ones((20, 3)) |
|
y_score = np.ones((20, 3)) |
|
y_true_binarize = y_true |
|
y_pred_binarize = y_pred |
|
|
|
check_averaging(name, y_true, y_true_binarize, y_pred, y_pred_binarize, y_score) |
|
|
|
|
|
def check_sample_weight_invariance(name, metric, y1, y2): |
|
rng = np.random.RandomState(0) |
|
sample_weight = rng.randint(1, 10, size=len(y1)) |
|
|
|
|
|
|
|
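    # top_k_accuracy_score is trivially perfect for k > 1 on binary problems,
    # so force k=1.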
metric = partial(metric, k=1) if name == "top_k_accuracy_score" else metric |
|
|
|
|
|
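    # sample_weight=None and unit weights must give the same score.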
unweighted_score = metric(y1, y2, sample_weight=None) |
|
|
|
assert_allclose( |
|
unweighted_score, |
|
metric(y1, y2, sample_weight=np.ones(shape=len(y1))), |
|
err_msg="For %s sample_weight=None is not equivalent to sample_weight=ones" |
|
% name, |
|
) |
|
|
|
|
|
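    # Non-uniform weights are expected to change the score.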
weighted_score = metric(y1, y2, sample_weight=sample_weight) |
|
|
|
|
|
with pytest.raises(AssertionError): |
|
assert_allclose(unweighted_score, weighted_score) |
|
raise ValueError( |
|
"Unweighted and weighted scores are unexpectedly " |
|
"almost equal (%s) and (%s) " |
|
"for %s" % (unweighted_score, weighted_score, name) |
|
) |
|
|
|
|
|
weighted_score_list = metric(y1, y2, sample_weight=sample_weight.tolist()) |
|
assert_allclose( |
|
weighted_score, |
|
weighted_score_list, |
|
err_msg=( |
|
"Weighted scores for array and list " |
|
"sample_weight input are not equal (%s != %s) for %s" |
|
) |
|
% (weighted_score, weighted_score_list, name), |
|
) |
|
|
|
|
|
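    # Integer weights must be equivalent to repeating the corresponding samples.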
repeat_weighted_score = metric( |
|
np.repeat(y1, sample_weight, axis=0), |
|
np.repeat(y2, sample_weight, axis=0), |
|
sample_weight=None, |
|
) |
|
assert_allclose( |
|
weighted_score, |
|
repeat_weighted_score, |
|
err_msg="Weighting %s is not equal to repeating samples" % name, |
|
) |
|
|
|
|
|
|
|
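    # Zeroing a weight must be equivalent to removing the corresponding sample.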
sample_weight_subset = sample_weight[1::2] |
|
sample_weight_zeroed = np.copy(sample_weight) |
|
sample_weight_zeroed[::2] = 0 |
|
y1_subset = y1[1::2] |
|
y2_subset = y2[1::2] |
|
weighted_score_subset = metric( |
|
y1_subset, y2_subset, sample_weight=sample_weight_subset |
|
) |
|
weighted_score_zeroed = metric(y1, y2, sample_weight=sample_weight_zeroed) |
|
assert_allclose( |
|
weighted_score_subset, |
|
weighted_score_zeroed, |
|
err_msg=( |
|
"Zeroing weights does not give the same result as " |
|
"removing the corresponding samples (%s != %s) for %s" |
|
) |
|
% (weighted_score_zeroed, weighted_score_subset, name), |
|
) |
|
|
|
if not name.startswith("unnormalized"): |
|
|
|
|
|
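        # For normalized metrics, rescaling all weights by a constant factor
        # must not change the score.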
for scaling in [2, 0.3]: |
|
assert_allclose( |
|
weighted_score, |
|
metric(y1, y2, sample_weight=sample_weight * scaling), |
|
err_msg="%s sample_weight is not invariant under scaling" % name, |
|
) |
|
|
|
|
|
|
|
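    # A sample_weight with an inconsistent number of samples must raise a
    # clear error.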
error_message = ( |
|
r"Found input variables with inconsistent numbers of " |
|
r"samples: \[{}, {}, {}\]".format( |
|
_num_samples(y1), _num_samples(y2), _num_samples(sample_weight) * 2 |
|
) |
|
) |
|
with pytest.raises(ValueError, match=error_message): |
|
metric(y1, y2, sample_weight=np.hstack([sample_weight, sample_weight])) |
|
|
|
|
|
@pytest.mark.parametrize( |
|
"name", |
|
sorted( |
|
set(ALL_METRICS).intersection(set(REGRESSION_METRICS)) |
|
- METRICS_WITHOUT_SAMPLE_WEIGHT |
|
), |
|
) |
|
def test_regression_sample_weight_invariance(name): |
|
n_samples = 50 |
|
random_state = check_random_state(0) |
|
|
|
y_true = random_state.random_sample(size=(n_samples,)) |
|
y_pred = random_state.random_sample(size=(n_samples,)) |
|
metric = ALL_METRICS[name] |
|
check_sample_weight_invariance(name, metric, y_true, y_pred) |
|
|
|
|
|
@pytest.mark.parametrize( |
|
"name", |
|
sorted( |
|
set(ALL_METRICS) |
|
- set(REGRESSION_METRICS) |
|
- METRICS_WITHOUT_SAMPLE_WEIGHT |
|
- METRIC_UNDEFINED_BINARY |
|
), |
|
) |
|
def test_binary_sample_weight_invariance(name): |
|
|
|
n_samples = 50 |
|
random_state = check_random_state(0) |
|
y_true = random_state.randint(0, 2, size=(n_samples,)) |
|
y_pred = random_state.randint(0, 2, size=(n_samples,)) |
|
y_score = random_state.random_sample(size=(n_samples,)) |
|
metric = ALL_METRICS[name] |
|
if name in THRESHOLDED_METRICS: |
|
check_sample_weight_invariance(name, metric, y_true, y_score) |
|
else: |
|
check_sample_weight_invariance(name, metric, y_true, y_pred) |
|
|
|
|
|
@pytest.mark.parametrize( |
|
"name", |
|
sorted( |
|
set(ALL_METRICS) |
|
- set(REGRESSION_METRICS) |
|
- METRICS_WITHOUT_SAMPLE_WEIGHT |
|
- METRIC_UNDEFINED_BINARY_MULTICLASS |
|
), |
|
) |
|
def test_multiclass_sample_weight_invariance(name): |
|
|
|
n_samples = 50 |
|
random_state = check_random_state(0) |
|
y_true = random_state.randint(0, 5, size=(n_samples,)) |
|
y_pred = random_state.randint(0, 5, size=(n_samples,)) |
|
y_score = random_state.random_sample(size=(n_samples, 5)) |
|
metric = ALL_METRICS[name] |
|
if name in THRESHOLDED_METRICS: |
|
|
|
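        # Map the scores to a probability distribution per sample (softmax of
        # -y_score), as required by e.g. log_loss.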
temp = np.exp(-y_score) |
|
y_score_norm = temp / temp.sum(axis=-1).reshape(-1, 1) |
|
check_sample_weight_invariance(name, metric, y_true, y_score_norm) |
|
else: |
|
check_sample_weight_invariance(name, metric, y_true, y_pred) |
|
|
|
|
|
@pytest.mark.parametrize( |
|
"name", |
|
sorted( |
|
(MULTILABELS_METRICS | THRESHOLDED_MULTILABEL_METRICS | MULTIOUTPUT_METRICS) |
|
- METRICS_WITHOUT_SAMPLE_WEIGHT |
|
), |
|
) |
|
def test_multilabel_sample_weight_invariance(name): |
|
|
|
random_state = check_random_state(0) |
|
_, ya = make_multilabel_classification( |
|
n_features=1, n_classes=10, random_state=0, n_samples=50, allow_unlabeled=False |
|
) |
|
_, yb = make_multilabel_classification( |
|
n_features=1, n_classes=10, random_state=1, n_samples=50, allow_unlabeled=False |
|
) |
|
y_true = np.vstack([ya, yb]) |
|
y_pred = np.vstack([ya, ya]) |
|
y_score = random_state.uniform(size=y_true.shape) |
|
|
|
|
|
y_score /= y_score.sum(axis=1, keepdims=True) |
|
|
|
metric = ALL_METRICS[name] |
|
if name in THRESHOLDED_METRICS: |
|
check_sample_weight_invariance(name, metric, y_true, y_score) |
|
else: |
|
check_sample_weight_invariance(name, metric, y_true, y_pred) |
|
|
|
|
|
def test_no_averaging_labels(): |
|
|
|
|
|
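    """Check that the ``labels`` argument with average=None only reorders the
    per-label scores."""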
y_true_multilabel = np.array([[1, 1, 0, 0], [1, 1, 0, 0]]) |
|
y_pred_multilabel = np.array([[0, 0, 1, 1], [0, 1, 1, 0]]) |
|
y_true_multiclass = np.array([0, 1, 2]) |
|
y_pred_multiclass = np.array([0, 2, 3]) |
|
labels = np.array([3, 0, 1, 2]) |
|
_, inverse_labels = np.unique(labels, return_inverse=True) |
|
|
|
for name in METRICS_WITH_AVERAGING: |
|
for y_true, y_pred in [ |
|
[y_true_multiclass, y_pred_multiclass], |
|
[y_true_multilabel, y_pred_multilabel], |
|
]: |
|
if name not in MULTILABELS_METRICS and y_pred.ndim > 1: |
|
continue |
|
|
|
metric = ALL_METRICS[name] |
|
|
|
score_labels = metric(y_true, y_pred, labels=labels, average=None) |
|
score = metric(y_true, y_pred, average=None) |
|
assert_array_equal(score_labels, score[inverse_labels]) |
|
|
|
|
|
@pytest.mark.parametrize( |
|
"name", sorted(MULTILABELS_METRICS - {"unnormalized_multilabel_confusion_matrix"}) |
|
) |
|
def test_multilabel_label_permutations_invariance(name): |
|
random_state = check_random_state(0) |
|
n_samples, n_classes = 20, 4 |
|
|
|
y_true = random_state.randint(0, 2, size=(n_samples, n_classes)) |
|
y_score = random_state.randint(0, 2, size=(n_samples, n_classes)) |
|
|
|
metric = ALL_METRICS[name] |
|
score = metric(y_true, y_score) |
|
|
|
for perm in permutations(range(n_classes), n_classes): |
|
y_score_perm = y_score[:, perm] |
|
y_true_perm = y_true[:, perm] |
|
|
|
current_score = metric(y_true_perm, y_score_perm) |
|
assert_almost_equal(score, current_score) |
|
|
|
|
|
@pytest.mark.parametrize( |
|
"name", sorted(THRESHOLDED_MULTILABEL_METRICS | MULTIOUTPUT_METRICS) |
|
) |
|
def test_thresholded_multilabel_multioutput_permutations_invariance(name): |
|
random_state = check_random_state(0) |
|
n_samples, n_classes = 20, 4 |
|
y_true = random_state.randint(0, 2, size=(n_samples, n_classes)) |
|
y_score = random_state.uniform(size=y_true.shape) |
|
|
|
|
|
y_score /= y_score.sum(axis=1, keepdims=True) |
|
|
|
|
|
|
|
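    # Avoid degenerate rows: no sample should have all labels or no label at all.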
y_true[y_true.sum(1) == 4, 0] = 0 |
|
y_true[y_true.sum(1) == 0, 0] = 1 |
|
|
|
metric = ALL_METRICS[name] |
|
score = metric(y_true, y_score) |
|
|
|
for perm in permutations(range(n_classes), n_classes): |
|
y_score_perm = y_score[:, perm] |
|
y_true_perm = y_true[:, perm] |
|
|
|
current_score = metric(y_true_perm, y_score_perm) |
|
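        # mean_absolute_percentage_error explodes when y_true contains zeros,
        # so only check that the score stays finite and very large.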
if metric == mean_absolute_percentage_error: |
|
assert np.isfinite(current_score) |
|
assert current_score > 1e6 |
|
|
|
|
|
|
|
|
|
else: |
|
assert_almost_equal(score, current_score) |
|
|
|
|
|
@pytest.mark.parametrize( |
|
"name", sorted(set(THRESHOLDED_METRICS) - METRIC_UNDEFINED_BINARY_MULTICLASS) |
|
) |
|
def test_thresholded_metric_permutation_invariance(name): |
|
n_samples, n_classes = 100, 3 |
|
random_state = check_random_state(0) |
|
|
|
y_score = random_state.rand(n_samples, n_classes) |
|
temp = np.exp(-y_score) |
|
y_score = temp / temp.sum(axis=-1).reshape(-1, 1) |
|
y_true = random_state.randint(0, n_classes, size=n_samples) |
|
|
|
metric = ALL_METRICS[name] |
|
score = metric(y_true, y_score) |
|
for perm in permutations(range(n_classes), n_classes): |
|
inverse_perm = np.zeros(n_classes, dtype=int) |
|
inverse_perm[list(perm)] = np.arange(n_classes) |
|
y_score_perm = y_score[:, inverse_perm] |
|
y_true_perm = np.take(perm, y_true) |
|
|
|
current_score = metric(y_true_perm, y_score_perm) |
|
assert_almost_equal(score, current_score) |
|
|
|
|
|
@pytest.mark.parametrize("metric_name", CLASSIFICATION_METRICS) |
|
def test_metrics_consistent_type_error(metric_name): |
|
|
|
|
|
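    """Check that a TypeError is raised when y_true and y_pred have labels of
    different types."""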
rng = np.random.RandomState(42) |
|
y1 = np.array(["spam"] * 3 + ["eggs"] * 2, dtype=object) |
|
y2 = rng.randint(0, 2, size=y1.size) |
|
|
|
err_msg = "Labels in y_true and y_pred should be of the same type." |
|
with pytest.raises(TypeError, match=err_msg): |
|
CLASSIFICATION_METRICS[metric_name](y1, y2) |
|
|
|
|
|
@pytest.mark.parametrize( |
|
"metric, y_pred_threshold", |
|
[ |
|
(average_precision_score, True), |
|
(brier_score_loss, True), |
|
(f1_score, False), |
|
(partial(fbeta_score, beta=1), False), |
|
(jaccard_score, False), |
|
(precision_recall_curve, True), |
|
(precision_score, False), |
|
(recall_score, False), |
|
(roc_curve, True), |
|
], |
|
) |
|
@pytest.mark.parametrize("dtype_y_str", [str, object]) |
|
def test_metrics_pos_label_error_str(metric, y_pred_threshold, dtype_y_str): |
|
|
|
|
|
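    """Check the error raised when pos_label is unspecified or invalid for
    string-labeled targets."""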
rng = np.random.RandomState(42) |
|
y1 = np.array(["spam"] * 3 + ["eggs"] * 2, dtype=dtype_y_str) |
|
y2 = rng.randint(0, 2, size=y1.size) |
|
|
|
if not y_pred_threshold: |
|
y2 = np.array(["spam", "eggs"], dtype=dtype_y_str)[y2] |
|
|
|
err_msg_pos_label_None = ( |
|
"y_true takes value in {'eggs', 'spam'} and pos_label is not " |
|
"specified: either make y_true take value in {0, 1} or {-1, 1} or " |
|
"pass pos_label explicit" |
|
) |
|
err_msg_pos_label_1 = ( |
|
r"pos_label=1 is not a valid label. It should be one of " r"\['eggs', 'spam'\]" |
|
) |
|
|
|
pos_label_default = signature(metric).parameters["pos_label"].default |
|
|
|
err_msg = err_msg_pos_label_1 if pos_label_default == 1 else err_msg_pos_label_None |
|
with pytest.raises(ValueError, match=err_msg): |
|
metric(y1, y2) |
|
|
|
|
|
def check_array_api_metric( |
|
metric, array_namespace, device, dtype_name, a_np, b_np, **metric_kwargs |
|
): |
|
xp = _array_api_for_tests(array_namespace, device) |
|
|
|
a_xp = xp.asarray(a_np, device=device) |
|
b_xp = xp.asarray(b_np, device=device) |
|
|
|
metric_np = metric(a_np, b_np, **metric_kwargs) |
|
|
|
if metric_kwargs.get("sample_weight") is not None: |
|
metric_kwargs["sample_weight"] = xp.asarray( |
|
metric_kwargs["sample_weight"], device=device |
|
) |
|
|
|
multioutput = metric_kwargs.get("multioutput") |
|
if isinstance(multioutput, np.ndarray): |
|
metric_kwargs["multioutput"] = xp.asarray(multioutput, device=device) |
|
|
|
|
|
|
|
|
|
|
|
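    # Without array API dispatch, the metric should still work on array API
    # inputs that NumPy can implicitly convert to arrays; inputs that cannot be
    # converted (e.g. arrays on a non-CPU device) are skipped below.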
try: |
|
np.asarray(a_xp) |
|
np.asarray(b_xp) |
|
numpy_as_array_works = True |
|
except TypeError: |
|
|
|
|
|
|
|
numpy_as_array_works = False |
|
|
|
if numpy_as_array_works: |
|
metric_xp = metric(a_xp, b_xp, **metric_kwargs) |
|
assert_allclose( |
|
metric_xp, |
|
metric_np, |
|
atol=_atol_for_type(dtype_name), |
|
) |
|
metric_xp_mixed_1 = metric(a_np, b_xp, **metric_kwargs) |
|
assert_allclose( |
|
metric_xp_mixed_1, |
|
metric_np, |
|
atol=_atol_for_type(dtype_name), |
|
) |
|
metric_xp_mixed_2 = metric(a_xp, b_np, **metric_kwargs) |
|
assert_allclose( |
|
metric_xp_mixed_2, |
|
metric_np, |
|
atol=_atol_for_type(dtype_name), |
|
) |
|
|
|
with config_context(array_api_dispatch=True): |
|
metric_xp = metric(a_xp, b_xp, **metric_kwargs) |
|
|
|
assert_allclose( |
|
_convert_to_numpy(xp.asarray(metric_xp), xp), |
|
metric_np, |
|
atol=_atol_for_type(dtype_name), |
|
) |
|
|
|
|
|
def check_array_api_binary_classification_metric( |
|
metric, array_namespace, device, dtype_name |
|
): |
|
y_true_np = np.array([0, 0, 1, 1]) |
|
y_pred_np = np.array([0, 1, 0, 1]) |
|
|
|
check_array_api_metric( |
|
metric, |
|
array_namespace, |
|
device, |
|
dtype_name, |
|
a_np=y_true_np, |
|
b_np=y_pred_np, |
|
sample_weight=None, |
|
) |
|
|
|
sample_weight = np.array([0.0, 0.1, 2.0, 1.0], dtype=dtype_name) |
|
|
|
check_array_api_metric( |
|
metric, |
|
array_namespace, |
|
device, |
|
dtype_name, |
|
a_np=y_true_np, |
|
b_np=y_pred_np, |
|
sample_weight=sample_weight, |
|
) |
|
|
|
|
|
def check_array_api_multiclass_classification_metric( |
|
metric, array_namespace, device, dtype_name |
|
): |
|
y_true_np = np.array([0, 1, 2, 3]) |
|
y_pred_np = np.array([0, 1, 0, 2]) |
|
|
|
additional_params = { |
|
"average": ("micro", "macro", "weighted"), |
|
} |
|
metric_kwargs_combinations = _get_metric_kwargs_for_array_api_testing( |
|
metric=metric, |
|
params=additional_params, |
|
) |
|
for metric_kwargs in metric_kwargs_combinations: |
|
check_array_api_metric( |
|
metric, |
|
array_namespace, |
|
device, |
|
dtype_name, |
|
a_np=y_true_np, |
|
b_np=y_pred_np, |
|
sample_weight=None, |
|
**metric_kwargs, |
|
) |
|
|
|
sample_weight = np.array([0.0, 0.1, 2.0, 1.0], dtype=dtype_name) |
|
|
|
check_array_api_metric( |
|
metric, |
|
array_namespace, |
|
device, |
|
dtype_name, |
|
a_np=y_true_np, |
|
b_np=y_pred_np, |
|
sample_weight=sample_weight, |
|
**metric_kwargs, |
|
) |
|
|
|
|
|
def check_array_api_multilabel_classification_metric( |
|
metric, array_namespace, device, dtype_name |
|
): |
|
y_true_np = np.array([[1, 1], [0, 1], [0, 0]], dtype=dtype_name) |
|
y_pred_np = np.array([[1, 1], [1, 1], [1, 1]], dtype=dtype_name) |
|
|
|
additional_params = { |
|
"average": ("micro", "macro", "weighted"), |
|
} |
|
metric_kwargs_combinations = _get_metric_kwargs_for_array_api_testing( |
|
metric=metric, |
|
params=additional_params, |
|
) |
|
for metric_kwargs in metric_kwargs_combinations: |
|
check_array_api_metric( |
|
metric, |
|
array_namespace, |
|
device, |
|
dtype_name, |
|
a_np=y_true_np, |
|
b_np=y_pred_np, |
|
sample_weight=None, |
|
**metric_kwargs, |
|
) |
|
|
|
sample_weight = np.array([0.0, 0.1, 2.0], dtype=dtype_name) |
|
|
|
check_array_api_metric( |
|
metric, |
|
array_namespace, |
|
device, |
|
dtype_name, |
|
a_np=y_true_np, |
|
b_np=y_pred_np, |
|
sample_weight=sample_weight, |
|
**metric_kwargs, |
|
) |
|
|
|
|
|
def check_array_api_regression_metric(metric, array_namespace, device, dtype_name): |
|
func_name = metric.func.__name__ if isinstance(metric, partial) else metric.__name__ |
|
if func_name == "mean_poisson_deviance" and sp_version < parse_version("1.14.0"): |
|
pytest.skip( |
|
"mean_poisson_deviance's dependency `xlogy` is available as of scipy 1.14.0" |
|
) |
|
|
|
y_true_np = np.array([2.0, 0.1, 1.0, 4.0], dtype=dtype_name) |
|
y_pred_np = np.array([0.5, 0.5, 2, 2], dtype=dtype_name) |
|
|
|
metric_kwargs = {} |
|
metric_params = signature(metric).parameters |
|
|
|
if "sample_weight" in metric_params: |
|
metric_kwargs["sample_weight"] = None |
|
|
|
check_array_api_metric( |
|
metric, |
|
array_namespace, |
|
device, |
|
dtype_name, |
|
a_np=y_true_np, |
|
b_np=y_pred_np, |
|
**metric_kwargs, |
|
) |
|
|
|
if "sample_weight" in metric_params: |
|
metric_kwargs["sample_weight"] = np.array( |
|
[0.1, 2.0, 1.5, 0.5], dtype=dtype_name |
|
) |
|
|
|
check_array_api_metric( |
|
metric, |
|
array_namespace, |
|
device, |
|
dtype_name, |
|
a_np=y_true_np, |
|
b_np=y_pred_np, |
|
**metric_kwargs, |
|
) |
|
|
|
|
|
def check_array_api_regression_metric_multioutput( |
|
metric, array_namespace, device, dtype_name |
|
): |
|
y_true_np = np.array([[1, 3, 2], [1, 2, 2]], dtype=dtype_name) |
|
y_pred_np = np.array([[1, 4, 4], [1, 1, 1]], dtype=dtype_name) |
|
|
|
check_array_api_metric( |
|
metric, |
|
array_namespace, |
|
device, |
|
dtype_name, |
|
a_np=y_true_np, |
|
b_np=y_pred_np, |
|
sample_weight=None, |
|
) |
|
|
|
sample_weight = np.array([0.1, 2.0], dtype=dtype_name) |
|
|
|
check_array_api_metric( |
|
metric, |
|
array_namespace, |
|
device, |
|
dtype_name, |
|
a_np=y_true_np, |
|
b_np=y_pred_np, |
|
sample_weight=sample_weight, |
|
) |
|
|
|
check_array_api_metric( |
|
metric, |
|
array_namespace, |
|
device, |
|
dtype_name, |
|
a_np=y_true_np, |
|
b_np=y_pred_np, |
|
multioutput=np.array([0.1, 0.3, 0.7], dtype=dtype_name), |
|
) |
|
|
|
check_array_api_metric( |
|
metric, |
|
array_namespace, |
|
device, |
|
dtype_name, |
|
a_np=y_true_np, |
|
b_np=y_pred_np, |
|
multioutput="raw_values", |
|
) |
|
|
|
|
|
def check_array_api_metric_pairwise(metric, array_namespace, device, dtype_name): |
|
|
|
X_np = np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype=dtype_name) |
|
Y_np = np.array([[0.2, 0.3, 0.4], [0.5, 0.6, 0.7]], dtype=dtype_name) |
|
|
|
metric_kwargs = {} |
|
if "dense_output" in signature(metric).parameters: |
|
metric_kwargs["dense_output"] = False |
|
check_array_api_metric( |
|
metric, |
|
array_namespace, |
|
device, |
|
dtype_name, |
|
a_np=X_np, |
|
b_np=Y_np, |
|
**metric_kwargs, |
|
) |
|
metric_kwargs["dense_output"] = True |
|
|
|
check_array_api_metric( |
|
metric, |
|
array_namespace, |
|
device, |
|
dtype_name, |
|
a_np=X_np, |
|
b_np=Y_np, |
|
**metric_kwargs, |
|
) |
|
|
|
|
|
array_api_metric_checkers = { |
|
accuracy_score: [ |
|
check_array_api_binary_classification_metric, |
|
check_array_api_multiclass_classification_metric, |
|
check_array_api_multilabel_classification_metric, |
|
], |
|
f1_score: [ |
|
check_array_api_binary_classification_metric, |
|
check_array_api_multiclass_classification_metric, |
|
check_array_api_multilabel_classification_metric, |
|
], |
|
multilabel_confusion_matrix: [ |
|
check_array_api_binary_classification_metric, |
|
check_array_api_multiclass_classification_metric, |
|
check_array_api_multilabel_classification_metric, |
|
], |
|
zero_one_loss: [ |
|
check_array_api_binary_classification_metric, |
|
check_array_api_multiclass_classification_metric, |
|
check_array_api_multilabel_classification_metric, |
|
], |
|
mean_tweedie_deviance: [check_array_api_regression_metric], |
|
partial(mean_tweedie_deviance, power=-0.5): [check_array_api_regression_metric], |
|
partial(mean_tweedie_deviance, power=1.5): [check_array_api_regression_metric], |
|
r2_score: [ |
|
check_array_api_regression_metric, |
|
check_array_api_regression_metric_multioutput, |
|
], |
|
cosine_similarity: [check_array_api_metric_pairwise], |
|
mean_absolute_error: [ |
|
check_array_api_regression_metric, |
|
check_array_api_regression_metric_multioutput, |
|
], |
|
mean_squared_error: [ |
|
check_array_api_regression_metric, |
|
check_array_api_regression_metric_multioutput, |
|
], |
|
mean_squared_log_error: [ |
|
check_array_api_regression_metric, |
|
check_array_api_regression_metric_multioutput, |
|
], |
|
d2_tweedie_score: [ |
|
check_array_api_regression_metric, |
|
], |
|
paired_cosine_distances: [check_array_api_metric_pairwise], |
|
mean_poisson_deviance: [check_array_api_regression_metric], |
|
additive_chi2_kernel: [check_array_api_metric_pairwise], |
|
mean_gamma_deviance: [check_array_api_regression_metric], |
|
max_error: [check_array_api_regression_metric], |
|
mean_absolute_percentage_error: [ |
|
check_array_api_regression_metric, |
|
check_array_api_regression_metric_multioutput, |
|
], |
|
chi2_kernel: [check_array_api_metric_pairwise], |
|
paired_euclidean_distances: [check_array_api_metric_pairwise], |
|
cosine_distances: [check_array_api_metric_pairwise], |
|
euclidean_distances: [check_array_api_metric_pairwise], |
|
linear_kernel: [check_array_api_metric_pairwise], |
|
polynomial_kernel: [check_array_api_metric_pairwise], |
|
rbf_kernel: [check_array_api_metric_pairwise], |
|
root_mean_squared_error: [ |
|
check_array_api_regression_metric, |
|
check_array_api_regression_metric_multioutput, |
|
], |
|
root_mean_squared_log_error: [ |
|
check_array_api_regression_metric, |
|
check_array_api_regression_metric_multioutput, |
|
], |
|
sigmoid_kernel: [check_array_api_metric_pairwise], |
|
} |
|
|
|
|
|
def yield_metric_checker_combinations(metric_checkers=array_api_metric_checkers): |
|
for metric, checkers in metric_checkers.items(): |
|
for checker in checkers: |
|
yield metric, checker |
|
|
|
|
|
@pytest.mark.parametrize( |
|
"array_namespace, device, dtype_name", yield_namespace_device_dtype_combinations() |
|
) |
|
@pytest.mark.parametrize("metric, check_func", yield_metric_checker_combinations()) |
|
def test_array_api_compliance(metric, array_namespace, device, dtype_name, check_func): |
|
check_func(metric, array_namespace, device, dtype_name) |
|
|
|
|
|
@pytest.mark.parametrize("df_lib_name", ["pandas", "polars"]) |
|
@pytest.mark.parametrize("metric_name", sorted(ALL_METRICS)) |
|
def test_metrics_dataframe_series(metric_name, df_lib_name): |
|
df_lib = pytest.importorskip(df_lib_name) |
|
|
|
y_pred = df_lib.Series([0.0, 1.0, 0, 1.0]) |
|
y_true = df_lib.Series([1.0, 0.0, 0.0, 0.0]) |
|
|
|
metric = ALL_METRICS[metric_name] |
|
try: |
|
expected_metric = metric(y_pred.to_numpy(), y_true.to_numpy()) |
|
except ValueError: |
|
        pytest.skip(f"{metric_name} cannot deal with 1d inputs")
|
|
|
assert_allclose(metric(y_pred, y_true), expected_metric) |
|
|
|
|
|
def _get_metric_kwargs_for_array_api_testing(metric, params):
    """Build the combinations of keyword arguments from ``params`` that are
    supported by ``metric``, for use in the array API compliance tests."""
|
metric_kwargs_combinations = [{}] |
|
for param, values in params.items(): |
|
if param not in signature(metric).parameters: |
|
continue |
|
|
|
new_combinations = [] |
|
for kwargs in metric_kwargs_combinations: |
|
for value in values: |
|
new_kwargs = kwargs.copy() |
|
new_kwargs[param] = value |
|
new_combinations.append(new_kwargs) |
|
|
|
metric_kwargs_combinations = new_combinations |
|
|
|
return metric_kwargs_combinations |
|
|