|
|
|
|
|
|
|
import numpy as np |
|
|
|
from . import check_consistent_length |
|
from ._optional_dependencies import check_matplotlib_support |
|
from ._response import _get_response_values_binary |
|
from .multiclass import type_of_target |
|
from .validation import _check_pos_label_consistency |
|
|
|
|
|
class _BinaryClassifierCurveDisplayMixin:
    """Mixin class to be used in Displays requiring a binary classifier.

    The aim of this class is to centralize some validations regarding the estimator and
    the target and gather the response of the estimator.
    """

    def _validate_plot_params(self, *, ax=None, name=None):
        """Resolve the axes, figure and legend name used by ``plot``.

        Parameters
        ----------
        ax : matplotlib axes, default=None
            Axes to draw on. When `None`, a new figure and axes are created
            with `plt.subplots()`.
        name : str, default=None
            Name to use for the curve. When `None`, `self.estimator_name` is
            used instead.

        Returns
        -------
        ax : matplotlib axes
            The axes to draw on.
        figure : matplotlib figure
            The figure owning `ax` (i.e. `ax.figure`).
        name : str
            The resolved name.
        """
        check_matplotlib_support(f"{self.__class__.__name__}.plot")
        # Imported lazily, only after the support check above has run.
        import matplotlib.pyplot as plt

        if ax is None:
            _, ax = plt.subplots()

        name = self.estimator_name if name is None else name
        return ax, ax.figure, name

    @classmethod
    def _validate_and_get_response_values(
        cls, estimator, X, y, *, response_method="auto", pos_label=None, name=None
    ):
        """Gather the response of a binary classifier for ``from_estimator``.

        Parameters
        ----------
        estimator : estimator instance
            Estimator forwarded to `_get_response_values_binary`.
        X : array-like
            Input data forwarded to `_get_response_values_binary`.
        y : array-like
            Target values. Not used by this method; kept in the signature for
            callers.
        response_method : str, default="auto"
            Forwarded to `_get_response_values_binary`.
        pos_label : object, default=None
            Forwarded to `_get_response_values_binary`.
        name : str, default=None
            Name to use for the curve. When `None`, the class name of
            `estimator` is used.

        Returns
        -------
        y_pred : object
            First value returned by `_get_response_values_binary`.
        pos_label : object
            Second value returned by `_get_response_values_binary`.
        name : str
            The resolved name.
        """
        check_matplotlib_support(f"{cls.__name__}.from_estimator")

        name = estimator.__class__.__name__ if name is None else name

        y_pred, pos_label = _get_response_values_binary(
            estimator,
            X,
            response_method=response_method,
            pos_label=pos_label,
        )

        return y_pred, pos_label, name

    @classmethod
    def _validate_from_predictions_params(
        cls, y_true, y_pred, *, sample_weight=None, pos_label=None, name=None
    ):
        """Validate the inputs of ``from_predictions``.

        Parameters
        ----------
        y_true : array-like
            True targets; `type_of_target` must report them as `"binary"`.
        y_pred : array-like
            Predictions; checked for consistent length with `y_true`.
        sample_weight : array-like, default=None
            Sample weights; checked for consistent length with `y_true`.
        pos_label : object, default=None
            Passed to `_check_pos_label_consistency` together with `y_true`.
        name : str, default=None
            Name to use for the curve. When `None`, `"Classifier"` is used.

        Returns
        -------
        pos_label : object
            The value returned by `_check_pos_label_consistency`.
        name : str
            The resolved name.

        Raises
        ------
        ValueError
            If `y_true` is not a binary target.
        """
        check_matplotlib_support(f"{cls.__name__}.from_predictions")

        # Compute the target type once and reuse it in the error message.
        target_type = type_of_target(y_true)
        if target_type != "binary":
            raise ValueError(
                f"The target y is not binary. Got {target_type} type of target."
            )

        check_consistent_length(y_true, y_pred, sample_weight)
        pos_label = _check_pos_label_consistency(pos_label, y_true)

        name = name if name is not None else "Classifier"

        return pos_label, name
|
|
|
|
|
def _validate_score_name(score_name, scoring, negate_score): |
|
"""Validate the `score_name` parameter. |
|
|
|
If `score_name` is provided, we just return it as-is. |
|
If `score_name` is `None`, we use `Score` if `negate_score` is `False` and |
|
`Negative score` otherwise. |
|
If `score_name` is a string or a callable, we infer the name. We replace `_` by |
|
spaces and capitalize the first letter. We remove `neg_` and replace it by |
|
`"Negative"` if `negate_score` is `False` or just remove it otherwise. |
|
""" |
|
if score_name is not None: |
|
return score_name |
|
elif scoring is None: |
|
return "Negative score" if negate_score else "Score" |
|
else: |
|
score_name = scoring.__name__ if callable(scoring) else scoring |
|
if negate_score: |
|
if score_name.startswith("neg_"): |
|
score_name = score_name[4:] |
|
else: |
|
score_name = f"Negative {score_name}" |
|
elif score_name.startswith("neg_"): |
|
score_name = f"Negative {score_name[4:]}" |
|
score_name = score_name.replace("_", " ") |
|
return score_name.capitalize() |
|
|
|
|
|
def _interval_max_min_ratio(data): |
|
"""Compute the ratio between the largest and smallest inter-point distances. |
|
|
|
A value larger than 5 typically indicates that the parameter range would |
|
better be displayed with a log scale while a linear scale would be more |
|
suitable otherwise. |
|
""" |
|
diff = np.diff(np.sort(data)) |
|
return diff.max() / diff.min() |
|
|
|
|
|
def _validate_style_kwargs(default_style_kwargs, user_style_kwargs): |
|
"""Create valid style kwargs by avoiding Matplotlib alias errors. |
|
|
|
Matplotlib raises an error when, for example, 'color' and 'c', or 'linestyle' and |
|
'ls', are specified together. To avoid this, we automatically keep only the one |
|
specified by the user and raise an error if the user specifies both. |
|
|
|
Parameters |
|
---------- |
|
default_style_kwargs : dict |
|
The Matplotlib style kwargs used by default in the scikit-learn display. |
|
user_style_kwargs : dict |
|
The user-defined Matplotlib style kwargs. |
|
|
|
Returns |
|
------- |
|
valid_style_kwargs : dict |
|
The validated style kwargs taking into account both default and user-defined |
|
Matplotlib style kwargs. |
|
""" |
|
|
|
invalid_to_valid_kw = { |
|
"ls": "linestyle", |
|
"c": "color", |
|
"ec": "edgecolor", |
|
"fc": "facecolor", |
|
"lw": "linewidth", |
|
"mec": "markeredgecolor", |
|
"mfcalt": "markerfacecoloralt", |
|
"ms": "markersize", |
|
"mew": "markeredgewidth", |
|
"mfc": "markerfacecolor", |
|
"aa": "antialiased", |
|
"ds": "drawstyle", |
|
"font": "fontproperties", |
|
"family": "fontfamily", |
|
"name": "fontname", |
|
"size": "fontsize", |
|
"stretch": "fontstretch", |
|
"style": "fontstyle", |
|
"variant": "fontvariant", |
|
"weight": "fontweight", |
|
"ha": "horizontalalignment", |
|
"va": "verticalalignment", |
|
"ma": "multialignment", |
|
} |
|
for invalid_key, valid_key in invalid_to_valid_kw.items(): |
|
if invalid_key in user_style_kwargs and valid_key in user_style_kwargs: |
|
raise TypeError( |
|
f"Got both {invalid_key} and {valid_key}, which are aliases of one " |
|
"another" |
|
) |
|
valid_style_kwargs = default_style_kwargs.copy() |
|
|
|
for key in user_style_kwargs.keys(): |
|
if key in invalid_to_valid_kw: |
|
valid_style_kwargs[invalid_to_valid_kw[key]] = user_style_kwargs[key] |
|
else: |
|
valid_style_kwargs[key] = user_style_kwargs[key] |
|
|
|
return valid_style_kwargs |
|
|
|
|
|
def _despine(ax): |
|
"""Remove the top and right spines of the plot. |
|
|
|
Parameters |
|
---------- |
|
ax : matplotlib.axes.Axes |
|
The axes of the plot to despine. |
|
""" |
|
for s in ["top", "right"]: |
|
ax.spines[s].set_visible(False) |
|
for s in ["bottom", "left"]: |
|
ax.spines[s].set_bounds(0, 1) |
|
|