File size: 4,380 Bytes
7885a28 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 |
"""Test loaders for common functionality."""
import inspect
import os
import numpy as np
import pytest
import sklearn.datasets
def is_pillow_installed():
    """Report whether the optional Pillow package (``PIL``) can be imported.

    Returns
    -------
    bool
        ``True`` when ``import PIL`` succeeds, ``False`` when it raises
        ``ImportError``.
    """
    try:
        import PIL  # noqa
    except ImportError:
        return False
    else:
        return True
# Extra pytest marks for specific dataset functions, keyed first by the
# parameter under test ("return_X_y" / "as_frame") and then by the function
# name; _generate_func_supporting_param attaches them to the matching params.
FETCH_PYTEST_MARKERS = {
    "return_X_y": {
        "fetch_20newsgroups": pytest.mark.xfail(
            reason="X is a list and does not have a shape attribute"
        ),
        "fetch_openml": pytest.mark.xfail(
            reason="fetch_openml requires a dataset name or id"
        ),
        # Evaluated once at import time.
        "fetch_lfw_people": pytest.mark.skipif(
            not is_pillow_installed(), reason="pillow is not installed"
        ),
    },
    "as_frame": {
        "fetch_openml": pytest.mark.xfail(
            reason="fetch_openml requires a dataset name or id"
        ),
    },
}
def check_pandas_dependency_message(fetch_func):
    """Check the error raised by ``fetch_func(as_frame=True)`` without pandas.

    If pandas can be imported, the calling test is skipped. Otherwise,
    calling ``fetch_func(as_frame=True)`` must raise ``ImportError`` with a
    message matching ``"<function name> with as_frame=True requires pandas"``.

    Parameters
    ----------
    fetch_func : callable
        Dataset loader accepting an ``as_frame`` keyword argument.
    """
    try:
        import pandas  # noqa
    except ImportError:
        # pandas is missing: the loader is expected to import it lazily and
        # report the missing dependency with an informative message.
        message = f"{fetch_func.__name__} with as_frame=True requires pandas"
        with pytest.raises(ImportError, match=message):
            fetch_func(as_frame=True)
    else:
        pytest.skip("This test requires pandas to not be installed")
def check_return_X_y(bunch, dataset_func):
    """Compare ``dataset_func(return_X_y=True)`` with a reference *bunch*.

    The call must return a tuple whose items 0 and 1 have the same shapes as
    ``bunch.data`` and ``bunch.target`` respectively.

    Parameters
    ----------
    bunch : object
        Reference result exposing ``data`` and ``target`` with ``.shape``
        (callers in this module pass ``dataset_func()``).
    dataset_func : callable
        Dataset loader accepting a ``return_X_y`` keyword argument.
    """
    X_y = dataset_func(return_X_y=True)
    assert isinstance(X_y, tuple)
    # Item 0 pairs with bunch.data, item 1 with bunch.target, checked in order.
    for position, attribute in enumerate(("data", "target")):
        assert X_y[position].shape == getattr(bunch, attribute).shape
def check_as_frame(
    bunch, dataset_func, expected_data_dtype=None, expected_target_dtype=None
):
    """Check that ``as_frame=True`` yields pandas containers matching *bunch*.

    The calling test is skipped when pandas is not installed.

    Parameters
    ----------
    bunch : object
        Reference result exposing ``data`` and ``target`` (callers in this
        module pass ``dataset_func()``).
    dataset_func : callable
        Dataset loader accepting ``as_frame`` and ``return_X_y`` keywords.
    expected_data_dtype : dtype, optional
        When given, every column of the returned ``data`` frame must have
        this dtype.
    expected_target_dtype : dtype, optional
        When given, the returned target's dtype(s) must equal this dtype.
    """
    pd = pytest.importorskip("pandas")

    frame_bunch = dataset_func(as_frame=True)
    assert hasattr(frame_bunch, "frame")
    assert isinstance(frame_bunch.frame, pd.DataFrame)
    assert isinstance(frame_bunch.data, pd.DataFrame)
    assert frame_bunch.data.shape == bunch.data.shape
    # A target with more than one dimension must be a DataFrame, else a Series.
    if frame_bunch.target.ndim > 1:
        assert isinstance(frame_bunch.target, pd.DataFrame)
    else:
        assert isinstance(frame_bunch.target, pd.Series)
    assert frame_bunch.target.shape[0] == bunch.target.shape[0]
    if expected_data_dtype is not None:
        assert np.all(frame_bunch.data.dtypes == expected_data_dtype)
    if expected_target_dtype is not None:
        assert np.all(frame_bunch.target.dtypes == expected_target_dtype)

    # Test for return_X_y and as_frame=True
    frame_X, frame_y = dataset_func(as_frame=True, return_X_y=True)
    assert isinstance(frame_X, pd.DataFrame)
    if frame_y.ndim > 1:
        # The 2-D branch must check frame_y itself, not frame_X again.
        assert isinstance(frame_y, pd.DataFrame)
    else:
        assert isinstance(frame_y, pd.Series)
def _skip_network_tests():
    """Return whether tests that need internet access should be skipped.

    Reads the ``SKLEARN_SKIP_NETWORK_TESTS`` environment variable. The result
    is ``True`` when the variable is unset or equal to ``"1"``, and ``False``
    for any other value.
    """
    setting = os.getenv("SKLEARN_SKIP_NETWORK_TESTS", "1")
    return setting == "1"
def _generate_func_supporting_param(param, dataset_type=("load", "fetch")):
    """Yield a ``pytest.param`` for each dataset function accepting *param*.

    Parameters
    ----------
    param : str
        Name of the keyword argument a function's signature must contain.
    dataset_type : str or iterable of str, default=("load", "fetch")
        Name prefix(es) a function in ``sklearn.datasets`` must start with to
        be considered. A single string is treated as one prefix.

    Yields
    ------
    ParameterSet
        ``pytest.param(name, func, marks=marks)``. ``marks`` always holds a
        ``skipif`` that skips ``fetch*`` functions when ``_skip_network_tests()``
        is true, plus ``FETCH_PYTEST_MARKERS[param][name]`` when present.
    """
    # Without this, a bare string would be split into single-character
    # prefixes and match far more names than intended.
    if isinstance(dataset_type, str):
        dataset_type = (dataset_type,)
    prefixes = tuple(dataset_type)
    markers_fetch = FETCH_PYTEST_MARKERS.get(param, {})
    for name, obj in inspect.getmembers(sklearn.datasets, inspect.isfunction):
        # Filter on the name first so unrelated helpers never have their
        # signatures inspected.
        if not name.startswith(prefixes):
            continue
        if param not in inspect.signature(obj).parameters:
            continue
        # check if we should skip if we don't have network support
        marks = [
            pytest.mark.skipif(
                condition=name.startswith("fetch") and _skip_network_tests(),
                reason="Skip because fetcher requires internet network",
            )
        ]
        if name in markers_fetch:
            marks.append(markers_fetch[name])
        yield pytest.param(name, obj, marks=marks)
@pytest.mark.parametrize(
    ("name", "dataset_func"), _generate_func_supporting_param("return_X_y")
)
def test_common_check_return_X_y(name, dataset_func):
    """Run ``check_return_X_y`` for each loader that accepts ``return_X_y``.

    ``name`` is part of the parametrization but is not used in the body.
    """
    check_return_X_y(dataset_func(), dataset_func)
@pytest.mark.parametrize(
    ("name", "dataset_func"), _generate_func_supporting_param("as_frame")
)
def test_common_check_as_frame(name, dataset_func):
    """Run ``check_as_frame`` for each loader that accepts ``as_frame``.

    ``name`` is part of the parametrization but is not used in the body.
    """
    check_as_frame(dataset_func(), dataset_func)
@pytest.mark.parametrize(
    ("name", "dataset_func"), _generate_func_supporting_param("as_frame")
)
def test_common_check_pandas_dependency(name, dataset_func):
    """Check the missing-pandas error for each loader accepting ``as_frame``.

    ``name`` is part of the parametrization but is not used in the body.
    """
    check_pandas_dependency_message(dataset_func)
|