import atexit
import os
import warnings

import numpy as np
import pytest
from scipy import sparse

from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.tree import DecisionTreeClassifier
from sklearn.utils._testing import (
    TempMemmap,
    _convert_container,
    _delete_folder,
    _get_warnings_filters_info_list,
    assert_allclose,
    assert_allclose_dense_sparse,
    assert_docstring_consistency,
    assert_run_python_script_without_output,
    check_docstring_parameters,
    create_memmap_backed_data,
    ignore_warnings,
    raises,
    set_random_state,
    skip_if_no_numpydoc,
    turn_warnings_into_errors,
)
from sklearn.utils.deprecation import deprecated
from sklearn.utils.fixes import (
    _IS_WASM,
    CSC_CONTAINERS,
    CSR_CONTAINERS,
    parse_version,
    sp_version,
)
from sklearn.utils.metaestimators import available_if


def test_set_random_state():
    lda = LinearDiscriminantAnalysis()
    tree = DecisionTreeClassifier()

    # LinearDiscriminantAnalysis has no random_state parameter: smoke test
    # that set_random_state is a no-op for it.
    set_random_state(lda, 3)
    set_random_state(tree, 3)
    assert tree.random_state == 3


@pytest.mark.parametrize("csr_container", CSR_CONTAINERS)
def test_assert_allclose_dense_sparse(csr_container):
    x = np.arange(9).reshape(3, 3)
    msg = "Not equal to tolerance "
    y = csr_container(x)
    # Basic comparison, dense and sparse.
    for X in [x, y]:
        with pytest.raises(AssertionError, match=msg):
            assert_allclose_dense_sparse(X, X * 2)
        assert_allclose_dense_sparse(X, X)

    with pytest.raises(ValueError, match="Can only compare two sparse"):
        assert_allclose_dense_sparse(x, y)

    A = sparse.diags(np.ones(5), offsets=0).tocsr()
    B = csr_container(np.ones((1, 5)))
    with pytest.raises(AssertionError, match="Arrays are not equal"):
        assert_allclose_dense_sparse(B, A)


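# `ignore_warnings` can be used three ways: called directly on a callable, as
# a (parametrized) decorator, and as a context manager. A minimal sketch of
# the decorator form (the decorated function here is hypothetical):
#
#     @ignore_warnings(category=DeprecationWarning)
#     def quiet_function():
#         warnings.warn("silenced", DeprecationWarning)
#
# The test below exercises all three forms, with single and multiple warning
# categories.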
def test_ignore_warning():
    # Check that the ignore_warnings decorator and context manager work as
    # expected.
    def _warning_function():
        warnings.warn("deprecation warning", DeprecationWarning)

    def _multiple_warning_function():
        # Both warnings are raised during the execution of the function.
        warnings.warn("deprecation warning", DeprecationWarning)
        warnings.warn("deprecation warning")

    # Check the function form: the ignored warnings must not escape, hence
    # must not be turned into errors. Note the trailing `()`: the wrapped
    # function has to be called for the check to exercise anything.
    with warnings.catch_warnings():
        warnings.simplefilter("error")

        ignore_warnings(_warning_function)()
        ignore_warnings(_warning_function, category=DeprecationWarning)()

    with pytest.warns(DeprecationWarning):
        ignore_warnings(_warning_function, category=UserWarning)()

    with pytest.warns() as record:
        ignore_warnings(_multiple_warning_function, category=FutureWarning)()
    assert len(record) == 2
    assert isinstance(record[0].message, DeprecationWarning)
    assert isinstance(record[1].message, UserWarning)

    with pytest.warns() as record:
        ignore_warnings(_multiple_warning_function, category=UserWarning)()
    assert len(record) == 1
    assert isinstance(record[0].message, DeprecationWarning)

    with warnings.catch_warnings():
        warnings.simplefilter("error")

        ignore_warnings(
            _warning_function, category=(DeprecationWarning, UserWarning)
        )()

    # Check the decorator form.
    @ignore_warnings
    def decorator_no_warning():
        _warning_function()
        _multiple_warning_function()

    @ignore_warnings(category=(DeprecationWarning, UserWarning))
    def decorator_no_warning_multiple():
        _multiple_warning_function()

    @ignore_warnings(category=DeprecationWarning)
    def decorator_no_deprecation_warning():
        _warning_function()

    @ignore_warnings(category=UserWarning)
    def decorator_no_user_warning():
        _warning_function()

    @ignore_warnings(category=DeprecationWarning)
    def decorator_no_deprecation_multiple_warning():
        _multiple_warning_function()

    @ignore_warnings(category=UserWarning)
    def decorator_no_user_multiple_warning():
        _multiple_warning_function()

    with warnings.catch_warnings():
        warnings.simplefilter("error")

        decorator_no_warning()
        decorator_no_warning_multiple()
        decorator_no_deprecation_warning()

    with pytest.warns(DeprecationWarning):
        decorator_no_user_warning()
    with pytest.warns(UserWarning):
        decorator_no_deprecation_multiple_warning()
    with pytest.warns(DeprecationWarning):
        decorator_no_user_multiple_warning()

    # Check the context manager form.
    def context_manager_no_warning():
        with ignore_warnings():
            _warning_function()

    def context_manager_no_warning_multiple():
        with ignore_warnings(category=(DeprecationWarning, UserWarning)):
            _multiple_warning_function()

    def context_manager_no_deprecation_warning():
        with ignore_warnings(category=DeprecationWarning):
            _warning_function()

    def context_manager_no_user_warning():
        with ignore_warnings(category=UserWarning):
            _warning_function()

    def context_manager_no_deprecation_multiple_warning():
        with ignore_warnings(category=DeprecationWarning):
            _multiple_warning_function()

    def context_manager_no_user_multiple_warning():
        with ignore_warnings(category=UserWarning):
            _multiple_warning_function()

    with warnings.catch_warnings():
        warnings.simplefilter("error")

        context_manager_no_warning()
        context_manager_no_warning_multiple()
        context_manager_no_deprecation_warning()

    with pytest.warns(DeprecationWarning):
        context_manager_no_user_warning()
    with pytest.warns(UserWarning):
        context_manager_no_deprecation_multiple_warning()
    with pytest.warns(DeprecationWarning):
        context_manager_no_user_multiple_warning()

    # Check that passing a warning class instead of a callable raises a
    # helpful error.
    warning_class = UserWarning
    match = "'obj' should be a callable.+you should use 'category=UserWarning'"

    with pytest.raises(ValueError, match=match):
        silence_warnings_func = ignore_warnings(warning_class)(_warning_function)
        silence_warnings_func()

    with pytest.raises(ValueError, match=match):

        @ignore_warnings(warning_class)
        def test():
            pass


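# Fixtures for the docstring checking tests below. Several of these functions
# deliberately carry malformed numpydoc docstrings (bad section names, wrong
# parameter order, missing or extra parameters): do not "fix" them.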
def f_ok(a, b):
    """Function f

    Parameters
    ----------
    a : int
        Parameter a
    b : float
        Parameter b

    Returns
    -------
    c : list
        Parameter c
    """
    c = a + b
    return c


def f_bad_sections(a, b):
    """Function f

    Parameters
    ----------
    a : int
        Parameter a
    b : float
        Parameter b

    Results
    -------
    c : list
        Parameter c
    """
    c = a + b
    return c


def f_bad_order(b, a):
    """Function f

    Parameters
    ----------
    a : int
        Parameter a
    b : float
        Parameter b

    Returns
    -------
    c : list
        Parameter c
    """
    c = a + b
    return c


def f_too_many_param_docstring(a, b):
    """Function f

    Parameters
    ----------
    a : int
        Parameter a
    b : int
        Parameter b
    c : int
        Parameter c

    Returns
    -------
    d : list
        Parameter c
    """
    d = a + b
    return d


def f_missing(a, b):
    """Function f

    Parameters
    ----------
    a : int
        Parameter a

    Returns
    -------
    c : list
        Parameter c
    """
    c = a + b
    return c


def f_check_param_definition(a, b, c, d, e):
    """Function f

    Parameters
    ----------
    a: int
        Parameter a
    b:
        Parameter b
    c :
        This is parsed correctly in numpydoc 1.2
    d:int
        Parameter d
    e
        No typespec is allowed without colon
    """
    return a + b + c + d


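# The classes below extend the docstring fixtures to methods, including
# delegated methods exposed via `available_if` and `@deprecated` methods,
# which `check_docstring_parameters` must handle as well.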
class Klass:
    def f_missing(self, X, y):
        pass

    def f_bad_sections(self, X, y):
        """Function f

        Parameter
        ---------
        a : int
            Parameter a
        b : float
            Parameter b

        Results
        -------
        c : list
            Parameter c
        """
        pass


class MockEst:
    def __init__(self):
        """MockEstimator"""

    def fit(self, X, y):
        return X

    def predict(self, X):
        return X

    def predict_proba(self, X):
        return X

    def score(self, X):
        return 1.0


class MockMetaEstimator:
    def __init__(self, delegate):
        """MetaEstimator to check if doctest on delegated methods work.

        Parameters
        ---------
        delegate : estimator
            Delegated estimator.
        """
        self.delegate = delegate

    @available_if(lambda self: hasattr(self.delegate, "predict"))
    def predict(self, X):
        """This is available only if delegate has predict.

        Parameters
        ----------
        y : ndarray
            Parameter y
        """
        return self.delegate.predict(X)

    @available_if(lambda self: hasattr(self.delegate, "score"))
    @deprecated("Testing a deprecated delegated method")
    def score(self, X):
        """This is available only if delegate has score.

        Parameters
        ---------
        y : ndarray
            Parameter y
        """

    @available_if(lambda self: hasattr(self.delegate, "predict_proba"))
    def predict_proba(self, X):
        """This is available only if delegate has predict_proba.

        Parameters
        ---------
        X : ndarray
            Parameter X
        """
        return X

    @deprecated("Testing deprecated function with wrong params")
    def fit(self, X, y):
        """Incorrect docstring but should not be tested"""


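# `check_docstring_parameters` returns a list of error strings, one per
# detected inconsistency; an empty list means docstring and signature agree.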
@skip_if_no_numpydoc
def test_check_docstring_parameters():
    incorrect = check_docstring_parameters(f_ok)
    assert incorrect == []
    incorrect = check_docstring_parameters(f_ok, ignore=["b"])
    assert incorrect == []
    incorrect = check_docstring_parameters(f_missing, ignore=["b"])
    assert incorrect == []
    with pytest.raises(RuntimeError, match="Unknown section Results"):
        check_docstring_parameters(f_bad_sections)
    with pytest.raises(RuntimeError, match="Unknown section Parameter"):
        check_docstring_parameters(Klass.f_bad_sections)

    incorrect = check_docstring_parameters(f_check_param_definition)
    mock_meta = MockMetaEstimator(delegate=MockEst())
    mock_meta_name = mock_meta.__class__.__name__
    assert incorrect == [
        (
            "sklearn.utils.tests.test_testing.f_check_param_definition There "
            "was no space between the param name and colon ('a: int')"
        ),
        (
            "sklearn.utils.tests.test_testing.f_check_param_definition There "
            "was no space between the param name and colon ('b:')"
        ),
        (
            "sklearn.utils.tests.test_testing.f_check_param_definition There "
            "was no space between the param name and colon ('d:int')"
        ),
    ]

    messages = [
        [
            "In function: sklearn.utils.tests.test_testing.f_bad_order",
            (
                "There's a parameter name mismatch in function docstring w.r.t."
                " function signature, at index 0 diff: 'b' != 'a'"
            ),
            "Full diff:",
            "- ['b', 'a']",
            "+ ['a', 'b']",
        ],
        [
            "In function: "
            + "sklearn.utils.tests.test_testing.f_too_many_param_docstring",
            (
                "Parameters in function docstring have more items w.r.t. function"
                " signature, first extra item: c"
            ),
            "Full diff:",
            "- ['a', 'b']",
            "+ ['a', 'b', 'c']",
            "?          +++++",
        ],
        [
            "In function: sklearn.utils.tests.test_testing.f_missing",
            (
                "Parameters in function docstring have less items w.r.t. function"
                " signature, first missing item: b"
            ),
            "Full diff:",
            "- ['a', 'b']",
            "+ ['a']",
        ],
        [
            "In function: sklearn.utils.tests.test_testing.Klass.f_missing",
            (
                "Parameters in function docstring have less items w.r.t. function"
                " signature, first missing item: X"
            ),
            "Full diff:",
            "- ['X', 'y']",
            "+ []",
        ],
        [
            "In function: "
            + f"sklearn.utils.tests.test_testing.{mock_meta_name}.predict",
            (
                "There's a parameter name mismatch in function docstring w.r.t."
                " function signature, at index 0 diff: 'X' != 'y'"
            ),
            "Full diff:",
            "- ['X']",
            "?   ^",
            "+ ['y']",
            "?   ^",
        ],
        [
            "In function: "
            + f"sklearn.utils.tests.test_testing.{mock_meta_name}."
            + "predict_proba",
            "potentially wrong underline length... ",
            "Parameters ",
            "--------- in ",
        ],
        [
            "In function: "
            + f"sklearn.utils.tests.test_testing.{mock_meta_name}.score",
            "potentially wrong underline length... ",
            "Parameters ",
            "--------- in ",
        ],
        [
            "In function: " + f"sklearn.utils.tests.test_testing.{mock_meta_name}.fit",
            (
                "Parameters in function docstring have less items w.r.t. function"
                " signature, first missing item: X"
            ),
            "Full diff:",
            "- ['X', 'y']",
            "+ []",
        ],
    ]

    for msg, f in zip(
        messages,
        [
            f_bad_order,
            f_too_many_param_docstring,
            f_missing,
            Klass.f_missing,
            mock_meta.predict,
            mock_meta.predict_proba,
            mock_meta.score,
            mock_meta.fit,
        ],
    ):
        incorrect = check_docstring_parameters(f)
        assert msg == incorrect, '\n"%s"\n not in \n"%s"' % (msg, incorrect)


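# Fixtures for `assert_docstring_consistency`: f_one, f_two and f_three share
# parameter 'a', diverge on the description of 'b', and f_two/f_three add an
# extra parameter 'e' with differing type specifications.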
def f_one(a, b):
    """Function one.

    Parameters
    ----------
    a : int, float
        Parameter a.
        Second line.

    b : str
        Parameter b.

    Returns
    -------
    c : int
        Returning

    d : int
        Returning
    """
    pass


def f_two(a, b):
    """Function two.

    Parameters
    ----------
    a : int, float
        Parameter a.
        Second line.

    b : str
        Parameter bb.

    e : int
        Extra parameter.

    Returns
    -------
    c : int
        Returning

    d : int
        Returning
    """
    pass


def f_three(a, b):
    """Function two.

    Parameters
    ----------
    a : int, float
        Parameter a.

    b : str
        Parameter B!

    e :
        Extra parameter.

    Returns
    -------
    c : int
        Returning.

    d : int
        Returning
    """
    pass


@skip_if_no_numpydoc
def test_assert_docstring_consistency_object_type():
    """Check error raised when `objects` has an incorrect type."""
    with pytest.raises(TypeError, match="All 'objects' must be one of"):
        assert_docstring_consistency(["string", f_one])


@skip_if_no_numpydoc
@pytest.mark.parametrize(
    "objects, kwargs, error",
    [
        (
            [f_one, f_two],
            {"include_params": ["a"], "exclude_params": ["b"]},
            "The 'exclude_params' argument",
        ),
        (
            [f_one, f_two],
            {"include_returns": False, "exclude_returns": ["c"]},
            "The 'exclude_returns' argument",
        ),
    ],
)
def test_assert_docstring_consistency_arg_checks(objects, kwargs, error):
    """Check that `assert_docstring_consistency` argument checking is correct."""
    with pytest.raises(TypeError, match=error):
        assert_docstring_consistency(objects, **kwargs)


@skip_if_no_numpydoc
@pytest.mark.parametrize(
    "objects, kwargs, error, warn",
    [
        pytest.param(
            [f_one, f_two], {"include_params": ["a"]}, "", "", id="whitespace"
        ),
        pytest.param([f_one, f_two], {"include_returns": True}, "", "", id="incl_all"),
        pytest.param(
            [f_one, f_two, f_three],
            {"include_params": ["a"]},
            (
                r"The description of Parameter 'a' is inconsistent between "
                r"\['f_one',\n'f_two'\]"
            ),
            "",
            id="2-1 group",
        ),
        pytest.param(
            [f_one, f_two, f_three],
            {"include_params": ["b"]},
            (
                r"The description of Parameter 'b' is inconsistent between "
                r"\['f_one'\] and\n\['f_two'\] and"
            ),
            "",
            id="1-1-1 group",
        ),
        pytest.param(
            [f_two, f_three],
            {"include_params": ["e"]},
            (
                r"The type specification of Parameter 'e' is inconsistent between\n"
                r"\['f_two'\] and"
            ),
            "",
            id="empty type",
        ),
        pytest.param(
            [f_one, f_two],
            {"include_params": True, "exclude_params": ["b"]},
            "",
            r"Checking was skipped for Parameters: \['e'\]",
            id="skip warn",
        ),
    ],
)
def test_assert_docstring_consistency(objects, kwargs, error, warn):
    """Check `assert_docstring_consistency` gives correct results."""
    if error:
        with pytest.raises(AssertionError, match=error):
            assert_docstring_consistency(objects, **kwargs)
    elif warn:
        with pytest.warns(UserWarning, match=warn):
            assert_docstring_consistency(objects, **kwargs)
    else:
        assert_docstring_consistency(objects, **kwargs)


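# f_four, f_five and f_six document the same 'labels' parameter with small
# wording differences; they drive the error-message and `descr_regex_pattern`
# tests below.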
def f_four(labels):
    """Function four.

    Parameters
    ----------

    labels : array-like, default=None
        The set of labels to include when `average != 'binary'`, and their
        order if `average is None`. Labels present in the data can be excluded.
    """
    pass


def f_five(labels):
    """Function five.

    Parameters
    ----------

    labels : array-like, default=None
        The set of labels to include when `average != 'binary'`, and their
        order if `average is None`. This is an extra line. Labels present in the
        data can be excluded.
    """
    pass


def f_six(labels):
    """Function six.

    Parameters
    ----------

    labels : array-like, default=None
        The group of labels to add when `average != 'binary'`, and the
        order if `average is None`. Labels present on them datas can be excluded.
    """
    pass


@skip_if_no_numpydoc
def test_assert_docstring_consistency_error_msg():
    """Check `assert_docstring_consistency` difference message."""
    msg = r"""The description of Parameter 'labels' is inconsistent between
\['f_four'\] and \['f_five'\] and \['f_six'\]:

\*\*\* \['f_four'\]
--- \['f_five'\]
\*\*\*\*\*\*\*\*\*\*\*\*\*\*\*

\*\*\* 10,25 \*\*\*\*

--- 10,30 ----

  'binary'`, and their order if `average is None`.
\+ This is an extra line.
  Labels present in the data can be excluded.

\*\*\* \['f_four'\]
--- \['f_six'\]
\*\*\*\*\*\*\*\*\*\*\*\*\*\*\*

\*\*\* 1,25 \*\*\*\*

  The
! set
  of labels to
! include
  when `average != 'binary'`, and
! their
  order if `average is None`. Labels present
! in the data
  can be excluded.
--- 1,25 ----

  The
! group
  of labels to
! add
  when `average != 'binary'`, and
! the
  order if `average is None`. Labels present
! on them datas
  can be excluded."""

    with pytest.raises(AssertionError, match=msg):
        assert_docstring_consistency([f_four, f_five, f_six], include_params=True)


@skip_if_no_numpydoc
def test_assert_docstring_consistency_descr_regex_pattern():
    """Check that `assert_docstring_consistency` `descr_regex_pattern` works."""
    # A regex that matches the full parameter description of all three
    # fixtures, with alternations covering their wording differences.
    regex_full = (
        r"The (set|group) "
        + r"of labels to (include|add) "
        + r"when `average \!\= 'binary'`, and (their|the) "
        + r"order if `average is None`\."
        + r"[\s\w]*\.* "
        + r"Labels present (on|in) "
        + r"(them|the) "
        + r"datas? can be excluded\."
    )

    assert_docstring_consistency(
        [f_four, f_five, f_six],
        include_params=True,
        descr_regex_pattern=" ".join(regex_full.split()),
    )

    # Matching just a few alternate words also works.
    regex_words = r"(labels|average|binary)"
    assert_docstring_consistency(
        [f_four, f_five, f_six],
        include_params=True,
        descr_regex_pattern=" ".join(regex_words.split()),
    )

    # An error is raised when the regex does not match for some function.
    regex_error = r"The set of labels to include when.+"
    msg = r"The description of Parameter 'labels' in \['f_six'\] does not match"
    with pytest.raises(AssertionError, match=msg):
        assert_docstring_consistency(
            [f_four, f_five, f_six],
            include_params=True,
            descr_regex_pattern=" ".join(regex_error.split()),
        )


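# `RegistrationCounter` stands in for `atexit.register` so the memmap tests
# can assert that each created memmap registers exactly one cleanup callback
# (a partial wrapping `_delete_folder`).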
class RegistrationCounter:
    def __init__(self):
        self.nb_calls = 0

    def __call__(self, to_register_func):
        self.nb_calls += 1
        assert to_register_func.func is _delete_folder


def check_memmap(input_array, mmap_data, mmap_mode="r"):
    assert isinstance(mmap_data, np.memmap)
    writeable = mmap_mode != "r"
    assert mmap_data.flags.writeable is writeable
    np.testing.assert_array_equal(input_array, mmap_data)


def test_tempmemmap(monkeypatch):
    registration_counter = RegistrationCounter()
    monkeypatch.setattr(atexit, "register", registration_counter)

    input_array = np.ones(3)
    with TempMemmap(input_array) as data:
        check_memmap(input_array, data)
        temp_folder = os.path.dirname(data.filename)
    if os.name != "nt":
        assert not os.path.exists(temp_folder)
    assert registration_counter.nb_calls == 1

    mmap_mode = "r+"
    with TempMemmap(input_array, mmap_mode=mmap_mode) as data:
        check_memmap(input_array, data, mmap_mode=mmap_mode)
        temp_folder = os.path.dirname(data.filename)
    if os.name != "nt":
        assert not os.path.exists(temp_folder)
    assert registration_counter.nb_calls == 2


@pytest.mark.xfail(_IS_WASM, reason="memmap not fully supported")
def test_create_memmap_backed_data(monkeypatch):
    registration_counter = RegistrationCounter()
    monkeypatch.setattr(atexit, "register", registration_counter)

    input_array = np.ones(3)
    data = create_memmap_backed_data(input_array)
    check_memmap(input_array, data)
    assert registration_counter.nb_calls == 1

    data, folder = create_memmap_backed_data(input_array, return_folder=True)
    check_memmap(input_array, data)
    assert folder == os.path.dirname(data.filename)
    assert registration_counter.nb_calls == 2

    mmap_mode = "r+"
    data = create_memmap_backed_data(input_array, mmap_mode=mmap_mode)
    check_memmap(input_array, data, mmap_mode)
    assert registration_counter.nb_calls == 3

    input_list = [input_array, input_array + 1, input_array + 2]
    mmap_data_list = create_memmap_backed_data(input_list)
    for input_array, data in zip(input_list, mmap_data_list):
        check_memmap(input_array, data)
    assert registration_counter.nb_calls == 4

    output_data, other = create_memmap_backed_data([input_array, "not-an-array"])
    check_memmap(input_array, output_data)
    assert other == "not-an-array"


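# `_convert_container` is the helper used across the test suite to build a
# given container type from list data; a minimal usage sketch (the data here
# is illustrative):
#
#     X = _convert_container([[1, 2], [3, 4]], "dataframe", ["a", "b"])
#
# The parametrization below checks the conversion and the resulting dtype for
# every supported constructor name.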
@pytest.mark.parametrize(
    "constructor_name, container_type",
    [
        ("list", list),
        ("tuple", tuple),
        ("array", np.ndarray),
        ("sparse", sparse.csr_matrix),
        # Using `zip` only keeps the sparse containers that are available in
        # the installed scipy version.
        *zip(["sparse_csr", "sparse_csr_array"], CSR_CONTAINERS),
        *zip(["sparse_csc", "sparse_csc_array"], CSC_CONTAINERS),
        ("dataframe", lambda: pytest.importorskip("pandas").DataFrame),
        ("series", lambda: pytest.importorskip("pandas").Series),
        ("index", lambda: pytest.importorskip("pandas").Index),
        ("slice", slice),
    ],
)
@pytest.mark.parametrize(
    "dtype, superdtype",
    [
        (np.int32, np.integer),
        (np.int64, np.integer),
        (np.float32, np.floating),
        (np.float64, np.floating),
    ],
)
def test_convert_container(
    constructor_name,
    container_type,
    dtype,
    superdtype,
):
    """Check that we convert the container to the right type of array with the
    right data type."""
    if constructor_name in ("dataframe", "polars", "series", "polars_series", "index"):
        # Delay the import of pandas within the container type to only skip
        # this test instead of the whole file.
        container_type = container_type()
    container = [0, 1]

    container_converted = _convert_container(
        container,
        constructor_name,
        dtype=dtype,
    )
    assert isinstance(container_converted, container_type)

    if constructor_name in ("list", "tuple", "index"):
        # list and tuple use Python class dtypes (int, float) and a pandas
        # index always uses high precision (np.int64, np.float64): check
        # against the parent dtype instead of an exact match.
        assert np.issubdtype(type(container_converted[0]), superdtype)
    elif hasattr(container_converted, "dtype"):
        assert container_converted.dtype == dtype
    elif hasattr(container_converted, "dtypes"):
        assert container_converted.dtypes[0] == dtype


def test_convert_container_categories_pandas():
    pytest.importorskip("pandas")
    df = _convert_container(
        [["x"]], "dataframe", ["A"], categorical_feature_names=["A"]
    )
    assert df.dtypes.iloc[0] == "category"


def test_convert_container_categories_polars():
    pl = pytest.importorskip("polars")
    df = _convert_container([["x"]], "polars", ["A"], categorical_feature_names=["A"])
    assert df.schema["A"] == pl.Categorical()


def test_convert_container_categories_pyarrow():
    pa = pytest.importorskip("pyarrow")
    df = _convert_container([["x"]], "pyarrow", ["A"], categorical_feature_names=["A"])
    assert type(df.schema[0].type) is pa.DictionaryType


@pytest.mark.skipif(
    sp_version >= parse_version("1.8"),
    reason="sparse arrays are available as of scipy 1.8.0",
)
@pytest.mark.parametrize("constructor_name", ["sparse_csr_array", "sparse_csc_array"])
@pytest.mark.parametrize("dtype", [np.int32, np.int64, np.float32, np.float64])
def test_convert_container_raise_when_sparray_not_available(constructor_name, dtype):
    """Check that converting to a sparse array raises an explicit error when
    sparse arrays are not supported (scipy<1.8.0)."""
    container = [0, 1]

    with pytest.raises(
        ValueError,
        match=f"only available with scipy>=1.8.0, got {sp_version}",
    ):
        _convert_container(container, constructor_name, dtype=dtype)


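# `raises` is a souped-up variant of `pytest.raises`: `match` may be a list of
# patterns, `may_pass=True` tolerates the absence of an exception, and
# `err_msg` customizes the failure message. A minimal sketch:
#
#     with raises(ValueError, match=["bad value", "invalid"]):
#         int("not a number")
#
# Each block below covers one combination of these options.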
def test_raises():
    # Tests for the `raises` context manager.

    # Proper type, no match.
    with raises(TypeError):
        raise TypeError()

    # Proper type, proper match.
    with raises(TypeError, match="how are you") as cm:
        raise TypeError("hello how are you")
    assert cm.raised_and_matched

    # Proper type, proper match with multiple patterns.
    with raises(TypeError, match=["not this one", "how are you"]) as cm:
        raise TypeError("hello how are you")
    assert cm.raised_and_matched

    # Wrong type, no match: the unexpected exception propagates.
    with pytest.raises(ValueError, match="this will be raised"):
        with raises(TypeError) as cm:
            raise ValueError("this will be raised")
    assert not cm.raised_and_matched

    # Wrong type, no match, with a custom failure message.
    with pytest.raises(AssertionError, match="the failure message"):
        with raises(TypeError, err_msg="the failure message") as cm:
            raise ValueError()
    assert not cm.raised_and_matched

    # Wrong type, proper match: the match is ignored and the unexpected
    # exception propagates.
    with pytest.raises(ValueError, match="this will be raised"):
        with raises(TypeError, match="this is ignored") as cm:
            raise ValueError("this will be raised")
    assert not cm.raised_and_matched

    # Proper type but bad match.
    with pytest.raises(
        AssertionError, match="should contain one of the following patterns"
    ):
        with raises(TypeError, match="hello") as cm:
            raise TypeError("Bad message")
    assert not cm.raised_and_matched

    # Proper type but bad match, with a custom failure message.
    with pytest.raises(AssertionError, match="the failure message"):
        with raises(TypeError, match="hello", err_msg="the failure message") as cm:
            raise TypeError("Bad message")
    assert not cm.raised_and_matched

    # No exception raised.
    with pytest.raises(AssertionError, match="Did not raise"):
        with raises(TypeError) as cm:
            pass
    assert not cm.raised_and_matched

    # No exception raised, but may_pass=True tolerates it.
    with raises(TypeError, match="hello", may_pass=True) as cm:
        pass
    assert not cm.raised_and_matched

    # Multiple exception types.
    with raises((TypeError, ValueError)):
        raise TypeError()
    with raises((TypeError, ValueError)):
        raise ValueError()
    with pytest.raises(AssertionError):
        with raises((TypeError, ValueError)):
            pass


def test_float32_aware_assert_allclose():
    # The relative tolerance for float32 inputs is 1e-4.
    assert_allclose(np.array([1.0 + 2e-5], dtype=np.float32), 1.0)
    with pytest.raises(AssertionError):
        assert_allclose(np.array([1.0 + 2e-4], dtype=np.float32), 1.0)

    # The relative tolerance for other inputs is left to 1e-7 as in the
    # original numpy version.
    assert_allclose(np.array([1.0 + 2e-8], dtype=np.float64), 1.0)
    with pytest.raises(AssertionError):
        assert_allclose(np.array([1.0 + 2e-7], dtype=np.float64), 1.0)

    # atol is left to 0.0 by default, even for float32.
    with pytest.raises(AssertionError):
        assert_allclose(np.array([1e-5], dtype=np.float32), 0.0)
    assert_allclose(np.array([1e-5], dtype=np.float32), 0.0, atol=2e-5)


@pytest.mark.xfail(_IS_WASM, reason="cannot start subprocess")
def test_assert_run_python_script_without_output():
    code = "x = 1"
    assert_run_python_script_without_output(code)

    code = "print('something to stdout')"
    with pytest.raises(AssertionError, match="Expected no output"):
        assert_run_python_script_without_output(code)

    code = "print('something to stdout')"
    with pytest.raises(
        AssertionError,
        match="output was not supposed to match.+got.+something to stdout",
    ):
        assert_run_python_script_without_output(code, pattern="to.+stdout")

    code = "\n".join(["import sys", "print('something to stderr', file=sys.stderr)"])
    with pytest.raises(
        AssertionError,
        match="output was not supposed to match.+got.+something to stderr",
    ):
        assert_run_python_script_without_output(code, pattern="to.+stderr")


@pytest.mark.parametrize(
    "constructor_name",
    [
        "sparse_csr",
        "sparse_csc",
        pytest.param(
            "sparse_csr_array",
            marks=pytest.mark.skipif(
                sp_version < parse_version("1.8"),
                reason="sparse arrays are available as of scipy 1.8.0",
            ),
        ),
        pytest.param(
            "sparse_csc_array",
            marks=pytest.mark.skipif(
                sp_version < parse_version("1.8"),
                reason="sparse arrays are available as of scipy 1.8.0",
            ),
        ),
    ],
)
def test_convert_container_sparse_to_sparse(constructor_name):
    """Non-regression test to check that we can still convert a sparse container
    from a given format to another format.
    """
    X_sparse = sparse.random(10, 10, density=0.1, format="csr")
    _convert_container(X_sparse, constructor_name)


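# Helper shared by the two tests below: depending on the warning filter's
# action and on whether warnings are turned into errors (via the
# SKLEARN_WARNINGS_AS_ERRORS environment variable or
# turn_warnings_into_errors), a warning should either raise or be silently
# ignored.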
def check_warnings_as_errors(warning_info, warnings_as_errors):
    if warning_info.action == "error" and warnings_as_errors:
        with pytest.raises(warning_info.category, match=warning_info.message):
            warnings.warn(
                message=warning_info.message,
                category=warning_info.category,
            )
    if warning_info.action == "ignore":
        with warnings.catch_warnings(record=True) as record:
            message = warning_info.message

            # Special treatment for the pandas Pyarrow DeprecationWarning,
            # whose filter message is not a verbatim prefix of the warning.
            if "Pyarrow" in message:
                message = "\nPyarrow will become a required dependency"

            warnings.warn(
                message=message,
                category=warning_info.category,
            )
        # With warnings turned into errors, the "ignore" filter is installed
        # and nothing is recorded; otherwise exactly one warning is recorded.
        assert len(record) == (0 if warnings_as_errors else 1)
        if record:
            assert str(record[0].message) == message
            assert record[0].category == warning_info.category


@pytest.mark.parametrize("warning_info", _get_warnings_filters_info_list())
def test_sklearn_warnings_as_errors(warning_info):
    warnings_as_errors = os.environ.get("SKLEARN_WARNINGS_AS_ERRORS", "0") != "0"
    check_warnings_as_errors(warning_info, warnings_as_errors=warnings_as_errors)


@pytest.mark.parametrize("warning_info", _get_warnings_filters_info_list())
def test_turn_warnings_into_errors(warning_info):
    with warnings.catch_warnings():
        turn_warnings_into_errors()
        check_warnings_as_errors(warning_info, warnings_as_errors=True)