File size: 11,411 Bytes
7885a28 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 |
"""Common tests for metaestimators"""
import functools
from contextlib import suppress
from inspect import signature
import numpy as np
import pytest
from sklearn.base import BaseEstimator, is_regressor
from sklearn.datasets import make_classification
from sklearn.ensemble import BaggingClassifier
from sklearn.exceptions import NotFittedError
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.feature_selection import RFE, RFECV
from sklearn.linear_model import LogisticRegression, Ridge
from sklearn.model_selection import GridSearchCV, RandomizedSearchCV
from sklearn.pipeline import Pipeline, make_pipeline
from sklearn.preprocessing import MaxAbsScaler, StandardScaler
from sklearn.semi_supervised import SelfTrainingClassifier
from sklearn.utils import all_estimators
from sklearn.utils._test_common.instance_generator import _construct_instances
from sklearn.utils._testing import SkipTest, set_random_state
from sklearn.utils.estimator_checks import (
_enforce_estimator_tags_X,
_enforce_estimator_tags_y,
)
from sklearn.utils.validation import check_is_fitted
class DelegatorData:
    """Describe how one meta-estimator wraps a sub-estimator for delegation tests.

    Parameters
    ----------
    name : str
        Name of the meta-estimator, used in assertion messages.
    construct : callable
        Called with the sub-estimator; returns the meta-estimator instance to
        check. A meta-estimator class can be passed directly when it accepts
        the sub-estimator as its first positional argument.
    skip_methods : sequence of str, default=()
        Methods implemented by the meta-estimator itself (not dependent on the
        sub-estimator) that must not be checked for delegation.
    fit_args : tuple or None, default=None
        Positional arguments used to fit the meta-estimator and call its
        methods. If None, ``make_classification(random_state=0)`` is used.
    """

    def __init__(
        self,
        name,
        construct,
        skip_methods=(),
        fit_args=None,
    ):
        self.name = name
        self.construct = construct
        # Build the default dataset per instance rather than once at class
        # definition time, so instances never share the same numpy arrays.
        if fit_args is None:
            fit_args = make_classification(random_state=0)
        self.fit_args = fit_args
        self.skip_methods = skip_methods
# For the following meta estimators we check for the existence of relevant
# methods only if the sub estimator also contains them. Any methods that
# are implemented in the meta estimator themselves and are not dependent
# on the sub estimator are specified in the `skip_methods` parameter.
#
# Each entry's `construct` is called with the sub-estimator and must return
# the meta-estimator to check. Classes such as RFE or BaggingClassifier are
# passed directly and therefore receive the sub-estimator as their first
# positional argument.
DELEGATING_METAESTIMATORS = [
    DelegatorData("Pipeline", lambda est: Pipeline([("est", est)])),
    DelegatorData(
        "GridSearchCV",
        # `param` is a constructor argument of the dummy SubEstimator used in
        # test_metaestimator_delegation; a single-value grid keeps fit cheap.
        lambda est: GridSearchCV(est, param_grid={"param": [5]}, cv=2),
        skip_methods=["score"],
    ),
    DelegatorData(
        "RandomizedSearchCV",
        lambda est: RandomizedSearchCV(
            est, param_distributions={"param": [5]}, cv=2, n_iter=1
        ),
        skip_methods=["score"],
    ),
    DelegatorData("RFE", RFE, skip_methods=["transform", "inverse_transform"]),
    DelegatorData(
        "RFECV", RFECV, skip_methods=["transform", "inverse_transform", "score"]
    ),
    DelegatorData(
        "BaggingClassifier",
        BaggingClassifier,
        skip_methods=[
            "transform",
            "inverse_transform",
            "score",
            "predict_proba",
            "predict_log_proba",
            "predict",
        ],
    ),
    DelegatorData(
        "SelfTrainingClassifier",
        lambda est: SelfTrainingClassifier(est),
        skip_methods=["transform", "inverse_transform", "predict_proba"],
    ),
]
def test_metaestimator_delegation():
    """Meta-estimators must expose a method if and only if their sub-estimator does.

    For every entry of ``DELEGATING_METAESTIMATORS`` and every public,
    non-``fit*`` method of the dummy sub-estimator (minus the entry's
    ``skip_methods``), check that the method:

    * is present on the meta-estimator and raises ``NotFittedError`` when
      called before ``fit``;
    * can be called on the meta-estimator once it has been fitted;
    * is absent from the meta-estimator when the sub-estimator hides it.
    """

    def hideable(func):
        # Expose ``func`` as a property. Looking it up raises AttributeError
        # (so ``hasattr`` is False) when the instance's ``hidden_method``
        # equals the function's name; otherwise it returns ``func`` bound to
        # the instance.
        @property
        def accessor(instance):
            if instance.hidden_method == func.__name__:
                raise AttributeError("%r is hidden" % instance.hidden_method)
            return functools.partial(func, instance)

        return accessor

    class SubEstimator(BaseEstimator):
        """Dummy estimator whose public methods can be hidden one at a time."""

        def __init__(self, param=1, hidden_method=None):
            self.param = param
            self.hidden_method = hidden_method

        def fit(self, X, y=None, *args, **kwargs):
            # Trailing-underscore attributes mark the estimator as fitted
            # (presumably what check_is_fitted looks for).
            self.coef_ = np.arange(X.shape[1])
            self.classes_ = []
            return True

        def _check_fit(self):
            check_is_fitted(self)

        @hideable
        def inverse_transform(self, X, *args, **kwargs):
            self._check_fit()
            return X

        @hideable
        def transform(self, X, *args, **kwargs):
            self._check_fit()
            return X

        @hideable
        def predict(self, X, *args, **kwargs):
            self._check_fit()
            return np.ones(X.shape[0])

        @hideable
        def predict_proba(self, X, *args, **kwargs):
            self._check_fit()
            return np.ones(X.shape[0])

        @hideable
        def predict_log_proba(self, X, *args, **kwargs):
            self._check_fit()
            return np.ones(X.shape[0])

        @hideable
        def decision_function(self, X, *args, **kwargs):
            self._check_fit()
            return np.ones(X.shape[0])

        @hideable
        def score(self, X, y, *args, **kwargs):
            self._check_fit()
            return 1.0

    # Public, non-fit methods of SubEstimator, in alphabetical order.
    public_methods = sorted(
        name
        for name in SubEstimator.__dict__
        if not (name.startswith("_") or name.startswith("fit"))
    )

    def invoke(estimator, method_name, fit_args):
        # ``score`` takes (X, y); every other checked method takes only X.
        bound_method = getattr(estimator, method_name)
        if method_name == "score":
            return bound_method(fit_args[0], fit_args[1])
        return bound_method(fit_args[0])

    for case in DELEGATING_METAESTIMATORS:
        checked = [m for m in public_methods if m not in case.skip_methods]

        # 1. Methods are delegated and fail with NotFittedError before fit.
        sub_estimator = SubEstimator()
        meta_estimator = case.construct(sub_estimator)
        for name in checked:
            assert hasattr(sub_estimator, name)
            assert hasattr(meta_estimator, name), (
                "%s does not have method %r when its delegate does"
                % (case.name, name)
            )
            with pytest.raises(NotFittedError):
                invoke(meta_estimator, name, case.fit_args)

        # 2. Once fitted, every delegated method can be called (smoke test).
        meta_estimator.fit(*case.fit_args)
        for name in checked:
            invoke(meta_estimator, name, case.fit_args)

        # 3. Hiding a method on the sub-estimator hides it on the meta-estimator.
        for name in checked:
            sub_estimator = SubEstimator(hidden_method=name)
            meta_estimator = case.construct(sub_estimator)
            assert not hasattr(sub_estimator, name)
            assert not hasattr(meta_estimator, name), (
                "%s has method %r when its delegate does not" % (case.name, name)
            )
def _get_instance_with_pipeline(meta_estimator, init_params):
    """Re-instantiate ``meta_estimator``'s class around text-processing pipelines.

    Every inner estimator is ``make_pipeline(TfidfVectorizer(), ...)`` ending
    in ``Ridge`` for regressors and ``LogisticRegression`` otherwise.

    Parameters
    ----------
    meta_estimator : estimator instance
        Only its class and whether it is a regressor are used.
    init_params : set of str
        Names of the constructor parameters of ``type(meta_estimator)``; they
        decide how the inner estimator(s) are passed.

    Returns
    -------
    estimator instance or None
        A new instance of ``type(meta_estimator)``, or ``None`` when
        ``init_params`` contains none of the recognised parameter names.
    """
    meta_class = type(meta_estimator)

    if {"estimator", "base_estimator", "regressor"} & init_params:
        # A single wrapped estimator, possibly tuned by a *SearchCV.
        if is_regressor(meta_estimator):
            pipe = make_pipeline(TfidfVectorizer(), Ridge())
            grid = {"ridge__alpha": [0.1, 1.0]}
        else:
            pipe = make_pipeline(TfidfVectorizer(), LogisticRegression())
            grid = {"logisticregression__C": [0.1, 1.0]}

        is_search_cv = any(
            name in init_params for name in ("param_grid", "param_distributions")
        )
        if not is_search_cv:
            return meta_class(pipe)
        # The grid is passed positionally; assumes the second constructor
        # parameter is param_grid / param_distributions.
        search_kwargs = {"n_iter": 2} if "n_iter" in init_params else {}
        return meta_class(pipe, grid, **search_kwargs)

    if "transformer_list" in init_params:
        # FeatureUnion: a list of named text-vectorizing pipelines.
        return meta_class(
            [
                ("trans1", make_pipeline(TfidfVectorizer(), MaxAbsScaler())),
                (
                    "trans2",
                    make_pipeline(TfidfVectorizer(), StandardScaler(with_mean=False)),
                ),
            ]
        )

    if "estimators" in init_params:
        # Stacking / voting: a list of named pipelines.
        if is_regressor(meta_estimator):
            named_pipes = [
                ("est1", make_pipeline(TfidfVectorizer(), Ridge(alpha=0.1))),
                ("est2", make_pipeline(TfidfVectorizer(), Ridge(alpha=1))),
            ]
        else:
            named_pipes = [
                (
                    "est1",
                    make_pipeline(TfidfVectorizer(), LogisticRegression(C=0.1)),
                ),
                ("est2", make_pipeline(TfidfVectorizer(), LogisticRegression(C=1))),
            ]
        return meta_class(named_pipes)

    # None of the recognised constructor parameters is accepted.
    return None
def _generate_meta_estimator_instances_with_pipeline():
    """Generate instances of meta-estimators fed with a pipeline.

    Are considered meta-estimators all estimators whose constructor accepts at
    least one of "estimator", "base_estimator", "regressor",
    "transformer_list" or "estimators". Every instance that
    ``_construct_instances`` builds for such a class is rebuilt around text
    pipelines by ``_get_instance_with_pipeline`` and yielded.

    A ``SkipTest`` raised while constructing or wrapping the instances of a
    class is suppressed. The remaining instances of that class are then not
    generated, and iteration continues with the next estimator class.

    Yields
    ------
    estimator instance
        A meta-estimator wrapping one or more text-processing pipelines.
    """
    meta_params = {
        "estimator",
        "base_estimator",
        "regressor",
        "transformer_list",
        "estimators",
    }
    for _, Estimator in sorted(all_estimators()):
        sig = set(signature(Estimator).parameters)
        if not sig & meta_params:
            continue
        with suppress(SkipTest):
            for meta_estimator in _construct_instances(Estimator):
                yield _get_instance_with_pipeline(meta_estimator, sig)
# TODO: remove data validation for the following estimators
# They should be able to work on any data and delegate data validation to
# their inner estimator(s).
#
# Entries are class names, compared against `est.__class__.__name__` when
# building DATA_VALIDATION_META_ESTIMATORS; listed classes are excluded from
# test_meta_estimators_delegate_data_validation.
DATA_VALIDATION_META_ESTIMATORS_TO_IGNORE = [
    "AdaBoostClassifier",
    "AdaBoostRegressor",
    "BaggingClassifier",
    "BaggingRegressor",
    "ClassifierChain",  # data validation is necessary
    "FrozenEstimator",  # this estimator cannot be tested like others.
    "IterativeImputer",
    "OneVsOneClassifier",  # input validation can't be avoided
    "RANSACRegressor",
    "RFE",
    "RFECV",
    "RegressorChain",  # data validation is necessary
    "SelfTrainingClassifier",
    "SequentialFeatureSelector",  # not applicable (2D data mandatory)
]
# Pipeline-wrapped meta-estimator instances whose class is not listed in
# DATA_VALIDATION_META_ESTIMATORS_TO_IGNORE.
DATA_VALIDATION_META_ESTIMATORS = [
    candidate
    for candidate in _generate_meta_estimator_instances_with_pipeline()
    if candidate.__class__.__name__
    not in DATA_VALIDATION_META_ESTIMATORS_TO_IGNORE
]
def _get_meta_estimator_id(estimator):
return estimator.__class__.__name__
@pytest.mark.parametrize(
    "estimator", DATA_VALIDATION_META_ESTIMATORS, ids=_get_meta_estimator_id
)
def test_meta_estimators_delegate_data_validation(estimator):
    """Meta-estimators must delegate data validation to their inner estimator(s).

    The meta-estimator is fitted on a plain list of short strings, which is
    valid input for the first step of the wrapped pipeline(s) but is not
    tabular data.
    """
    generator = np.random.RandomState(0)
    set_random_state(estimator)
    n_samples = 30

    # X is drawn before y to keep the random stream order fixed.
    X = generator.choice(np.array(["aa", "bb", "cc"], dtype=object), size=n_samples)
    y = (
        generator.normal(size=n_samples)
        if is_regressor(estimator)
        else generator.randint(3, size=n_samples)
    )

    # Converting to lists makes sure array-likes (not only ndarrays) work.
    X, y = (
        _enforce_estimator_tags_X(estimator, X).tolist(),
        _enforce_estimator_tags_y(estimator, y).tolist(),
    )

    # fit must not raise a data-validation error: X is a valid input data
    # structure for the first step of the pipeline passed as base estimator.
    estimator.fit(X, y)

    # The data is not tabular, so n_features_in_ must not be defined.
    assert not hasattr(estimator, "n_features_in_")
|