File size: 4,380 Bytes
7885a28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
"""Test loaders for common functionality."""

import inspect
import os

import numpy as np
import pytest

import sklearn.datasets


def is_pillow_installed():
    """Report whether the optional ``PIL`` (Pillow) package can be imported.

    Returns
    -------
    bool
        True when ``import PIL`` succeeds, False when it raises ImportError.
    """
    try:
        import PIL  # noqa
    except ImportError:
        return False
    return True


# Extra pytest marks applied to specific dataset functions, keyed first by the
# parameter under test and then by the function name. The parametrization
# helper in this module looks entries up by these two keys and appends the
# mark to the generated test case.
FETCH_PYTEST_MARKERS = {
    "return_X_y": {
        "fetch_20newsgroups": pytest.mark.xfail(
            reason="X is a list and does not have a shape argument"
        ),
        "fetch_openml": pytest.mark.xfail(
            reason="fetch_openml requires a dataset name or id"
        ),
        "fetch_lfw_people": pytest.mark.skipif(
            not is_pillow_installed(), reason="pillow is not installed"
        ),
    },
    "as_frame": {
        "fetch_openml": pytest.mark.xfail(
            reason="fetch_openml requires a dataset name or id"
        ),
    },
}


def check_pandas_dependency_message(fetch_func):
    """Check the error raised by ``fetch_func(as_frame=True)`` without pandas.

    The calling test is skipped when pandas can be imported. Otherwise the
    call must raise an ImportError whose message matches
    ``"<function name> with as_frame=True requires pandas"``.

    Parameters
    ----------
    fetch_func : callable
        Dataset function accepting an ``as_frame`` keyword argument.
    """
    try:
        import pandas  # noqa
    except ImportError:
        pandas_available = False
    else:
        pandas_available = True

    if pandas_available:
        pytest.skip("This test requires pandas to not be installed")

    # pandas must be imported lazily by the dataset function, and an
    # informative error message must be raised when it is missing.
    func_name = fetch_func.__name__
    with pytest.raises(
        ImportError, match=f"{func_name} with as_frame=True requires pandas"
    ):
        fetch_func(as_frame=True)


def check_return_X_y(bunch, dataset_func):
    """Check that ``dataset_func(return_X_y=True)`` agrees with ``bunch``.

    The call must return a tuple whose first element has the same shape as
    ``bunch.data`` and whose second element has the same shape as
    ``bunch.target``.

    Parameters
    ----------
    bunch : object
        Reference object exposing ``data`` and ``target`` attributes, each
        with a ``shape``.
    dataset_func : callable
        Dataset function accepting a ``return_X_y`` keyword argument.
    """
    result = dataset_func(return_X_y=True)
    assert isinstance(result, tuple)

    data_shape = result[0].shape
    assert data_shape == bunch.data.shape

    target_shape = result[1].shape
    assert target_shape == bunch.target.shape


def check_as_frame(
    bunch, dataset_func, expected_data_dtype=None, expected_target_dtype=None
):
    """Check the pandas containers returned by ``dataset_func(as_frame=True)``.

    The calling test is skipped when pandas is not importable.

    Parameters
    ----------
    bunch : object
        Reference result of ``dataset_func()``; its ``data`` and ``target``
        shapes are compared against the frame-based result.
    dataset_func : callable
        Dataset function accepting ``as_frame`` and ``return_X_y`` keyword
        arguments.
    expected_data_dtype : dtype-like, default=None
        When not None, every dtype of the returned ``data`` must equal it.
    expected_target_dtype : dtype-like, default=None
        When not None, every dtype of the returned ``target`` must equal it.
    """
    pd = pytest.importorskip("pandas")
    frame_bunch = dataset_func(as_frame=True)
    assert hasattr(frame_bunch, "frame")
    assert isinstance(frame_bunch.frame, pd.DataFrame)
    assert isinstance(frame_bunch.data, pd.DataFrame)
    assert frame_bunch.data.shape == bunch.data.shape
    # A target with more than one dimension must be a DataFrame, otherwise a
    # Series.
    if frame_bunch.target.ndim > 1:
        assert isinstance(frame_bunch.target, pd.DataFrame)
    else:
        assert isinstance(frame_bunch.target, pd.Series)
    # Only the number of samples is compared for the target.
    assert frame_bunch.target.shape[0] == bunch.target.shape[0]
    if expected_data_dtype is not None:
        assert np.all(frame_bunch.data.dtypes == expected_data_dtype)
    if expected_target_dtype is not None:
        assert np.all(frame_bunch.target.dtypes == expected_target_dtype)

    # Test for return_X_y and as_frame=True
    frame_X, frame_y = dataset_func(as_frame=True, return_X_y=True)
    assert isinstance(frame_X, pd.DataFrame)
    if frame_y.ndim > 1:
        # Check the target itself here; frame_X was already verified above.
        assert isinstance(frame_y, pd.DataFrame)
    else:
        assert isinstance(frame_y, pd.Series)


def _skip_network_tests():
    """Return True when network-dependent tests should be skipped.

    Reads the ``SKLEARN_SKIP_NETWORK_TESTS`` environment variable; an unset
    variable is treated as ``"1"``, so skipping is the default.
    """
    flag = os.environ.get("SKLEARN_SKIP_NETWORK_TESTS", "1")
    return flag == "1"


def _generate_func_supporting_param(param, dataset_type=("load", "fetch")):
    """Yield a ``pytest.param`` for each dataset function accepting ``param``.

    Functions of ``sklearn.datasets`` are kept when their name starts with
    one of the ``dataset_type`` prefixes and ``param`` appears in their
    signature.

    Parameters
    ----------
    param : str
        Name of the parameter the function signature must contain.
    dataset_type : iterable of str, default=("load", "fetch")
        Accepted name prefixes.

    Yields
    ------
    pytest.param
        ``(name, function)`` carrying a network skipif mark plus any mark
        registered for that function in ``FETCH_PYTEST_MARKERS[param]``.
    """
    extra_marks = FETCH_PYTEST_MARKERS.get(param, {})
    prefixes = tuple(dataset_type)
    for name, func in inspect.getmembers(sklearn.datasets, inspect.isfunction):
        if not name.startswith(prefixes):
            continue
        if param not in inspect.signature(func).parameters:
            continue

        # Functions named "fetch*" are skipped when network tests are
        # disabled.
        needs_skip = name.startswith("fetch") and _skip_network_tests()
        network_mark = pytest.mark.skipif(
            condition=needs_skip,
            reason="Skip because fetcher requires internet network",
        )
        if name in extra_marks:
            marks = [network_mark, extra_marks[name]]
        else:
            marks = [network_mark]

        yield pytest.param(name, func, marks=marks)


@pytest.mark.parametrize(
    "name, dataset_func", _generate_func_supporting_param("return_X_y")
)
def test_common_check_return_X_y(name, dataset_func):
    """Run the ``return_X_y`` check on every dataset function that has it."""
    # `name` is not used in the body; the default call provides the reference
    # bunch that the return_X_y result is compared against.
    check_return_X_y(dataset_func(), dataset_func)


@pytest.mark.parametrize(
    "name, dataset_func", _generate_func_supporting_param("as_frame")
)
def test_common_check_as_frame(name, dataset_func):
    """Run the ``as_frame`` check on every dataset function that has it."""
    # `name` is not used in the body; the default call provides the reference
    # bunch that the as_frame result is compared against.
    check_as_frame(dataset_func(), dataset_func)


@pytest.mark.parametrize(
    "name, dataset_func", _generate_func_supporting_param("as_frame")
)
def test_common_check_pandas_dependency(name, dataset_func):
    """Check the missing-pandas error of every function taking ``as_frame``.

    Delegates to ``check_pandas_dependency_message``; ``name`` is not used in
    the body.
    """
    check_pandas_dependency_message(dataset_func)