|
""" |
|
Utility function to facilitate testing. |
|
|
|
""" |
|
import os |
|
import sys |
|
import pathlib |
|
import platform |
|
import re |
|
import gc |
|
import operator |
|
import warnings |
|
from functools import partial, wraps |
|
import shutil |
|
import contextlib |
|
from tempfile import mkdtemp, mkstemp |
|
from unittest.case import SkipTest |
|
from warnings import WarningMessage |
|
import pprint |
|
import sysconfig |
|
import concurrent.futures |
|
import threading |
|
import importlib.metadata |
|
|
|
import numpy as np |
|
from numpy._core import ( |
|
intp, float32, empty, arange, array_repr, ndarray, isnat, array) |
|
from numpy import isfinite, isnan, isinf |
|
import numpy.linalg._umath_linalg |
|
from numpy._utils import _rename_parameter |
|
from numpy._core.tests._natype import pd_NA |
|
|
|
from io import StringIO |
|
|
|
|
|
# Public names exported by ``from <this module> import *``.  Note that
# ``SkipTest`` is not defined here: it is re-exported from ``unittest.case``.
__all__ = [
    'assert_equal', 'assert_almost_equal', 'assert_approx_equal',
    'assert_array_equal', 'assert_array_less', 'assert_string_equal',
    'assert_array_almost_equal', 'assert_raises', 'build_err_msg',
    'decorate_methods', 'jiffies', 'memusage', 'print_assert_equal',
    'rundocs', 'runstring', 'verbose', 'measure',
    'assert_', 'assert_array_almost_equal_nulp', 'assert_raises_regex',
    'assert_array_max_ulp', 'assert_warns', 'assert_no_warnings',
    'assert_allclose', 'IgnoreException', 'clear_and_catch_warnings',
    'SkipTest', 'KnownFailureException', 'temppath', 'tempdir', 'IS_PYPY',
    'HAS_REFCOUNT', "IS_WASM", 'suppress_warnings', 'assert_array_compare',
    'assert_no_gc_cycles', 'break_cycles', 'HAS_LAPACK64', 'IS_PYSTON',
    'IS_MUSL', 'check_support_sve', 'NOGIL_BUILD',
    'IS_EDITABLE', 'IS_INSTALLED', 'NUMPY_ROOT', 'run_threaded', 'IS_64BIT',
]
|
|
|
|
|
class KnownFailureException(Exception):
    """Exception to raise in order to flag a test as a known failure."""
|
|
|
|
|
# Second name for KnownFailureException (presumably kept for backward
# compatibility with older imports).
KnownFailureTest = KnownFailureException
# Module-level verbosity flag; exported via __all__. Its readers are not
# visible in this part of the file.
verbose = 0

# Directory of the numpy package that was actually imported (the parent of
# numpy's ``__init__`` file).
NUMPY_ROOT = pathlib.Path(np.__file__).parent
|
|
|
# Determine whether the imported numpy comes from an installed distribution
# (IS_INSTALLED) and, if so, whether that install is editable (IS_EDITABLE).
try:
    np_dist = importlib.metadata.distribution('numpy')
except importlib.metadata.PackageNotFoundError:
    # No "numpy" distribution metadata is visible at all.
    IS_INSTALLED = IS_EDITABLE = False
else:
    IS_INSTALLED = True
    try:
        if sys.version_info >= (3, 13):
            IS_EDITABLE = np_dist.origin.dir_info.editable
        else:
            # Backport of ``Distribution.origin`` for older Pythons: parse
            # the distribution's ``direct_url.json`` ourselves, turning every
            # JSON object into an attribute-accessible namespace.
            import json, types
            origin = json.loads(
                np_dist.read_text('direct_url.json') or '{}',
                object_hook=lambda data: types.SimpleNamespace(**data),
            )
            IS_EDITABLE = origin.dir_info.editable
    except AttributeError:
        # Missing origin / ``dir_info`` / ``editable`` information: we cannot
        # tell that it is editable, so treat it as a regular install.
        IS_EDITABLE = False

    # The metadata found above may describe a different numpy than the one
    # actually imported (presumably e.g. a build put on PYTHONPATH without
    # its own .dist-info while another numpy is installed). Only trust it
    # when the distribution's ``numpy`` directory is the one we imported.
    if not IS_EDITABLE and np_dist.locate_file('numpy') != NUMPY_ROOT:
        IS_INSTALLED = False
|
|
|
IS_WASM = platform.machine() in ["wasm32", "wasm64"] |
|
IS_PYPY = sys.implementation.name == 'pypy' |
|
IS_PYSTON = hasattr(sys, "pyston_version_info") |
|
HAS_REFCOUNT = getattr(sys, 'getrefcount', None) is not None and not IS_PYSTON |
|
HAS_LAPACK64 = numpy.linalg._umath_linalg._ilp64 |
|
|
|
IS_MUSL = False |
|
|
|
|
|
|
|
|
|
_v = sysconfig.get_config_var('HOST_GNU_TYPE') or '' |
|
if 'musl' in _v: |
|
IS_MUSL = True |
|
|
|
NOGIL_BUILD = bool(sysconfig.get_config_var("Py_GIL_DISABLED")) |
|
IS_64BIT = np.dtype(np.intp).itemsize == 8 |
|
|
|
def assert_(val, msg=''):
    """
    Assert that works in release mode.

    The ``assert`` statement is compiled away when Python runs with ``-O``;
    this function always performs the check.

    Parameters
    ----------
    val : object
        Value whose truthiness is tested.
    msg : str or callable, optional
        Message for the ``AssertionError``. ``msg()`` is tried first, so a
        callable defers building an expensive message until a failure. If
        that call raises ``TypeError`` (e.g. ``msg`` is a plain string),
        ``msg`` itself is used.

    Raises
    ------
    AssertionError
        If ``val`` is falsy.
    """
    __tracebackhide__ = True  # Hide traceback for py.test
    if val:
        return
    try:
        message = msg()
    except TypeError:
        # Not callable: use the message object as given.
        message = msg
    raise AssertionError(message)
|
|
|
|
|
if os.name == 'nt':

    def GetPerformanceAttributes(object, counter, instance=None,
                                 inum=-1, format=None, machine=None):
        """
        Return the formatted value of one Windows performance counter.

        Builds a PDH counter path from ``(machine, object, instance, None,
        inum, counter)``, collects a single sample and returns its value.
        Needs the ``win32pdh`` module (pywin32), which is imported lazily.

        Parameters
        ----------
        object : str
            Performance object name, e.g. ``"Process"``.
        counter : str
            Counter name, e.g. ``"Virtual Bytes"``.
        instance : str, optional
            Object instance name (e.g. a process name).
        inum : int, optional
            Instance index passed to ``MakeCounterPath``.
        format : int, optional
            A ``win32pdh.PDH_FMT_*`` flag; defaults to ``PDH_FMT_LONG``.
        machine : str, optional
            Machine name, passed through to ``MakeCounterPath``.
        """
        # ``object`` and ``format`` shadow builtins but are part of the
        # public signature (callers may pass them by keyword), so they stay.
        import win32pdh
        if format is None:
            format = win32pdh.PDH_FMT_LONG
        path = win32pdh.MakeCounterPath((machine, object, instance, None,
                                         inum, counter))
        hq = win32pdh.OpenQuery()
        try:
            hc = win32pdh.AddCounter(hq, path)
            try:
                win32pdh.CollectQueryData(hq)
                # Returns (counter type, value); only the value is needed.
                _, val = win32pdh.GetFormattedCounterValue(hc, format)
                return val
            finally:
                win32pdh.RemoveCounter(hc)
        finally:
            win32pdh.CloseQuery(hq)

    def memusage(processName="python", instance=0):
        """
        Return the "Virtual Bytes" performance counter of a process.

        Parameters
        ----------
        processName : str, optional
            Process instance name as seen by the performance counters.
        instance : int, optional
            Instance index, to pick between processes sharing a name.
        """
        import win32pdh
        return GetPerformanceAttributes("Process", "Virtual Bytes",
                                        processName, instance,
                                        win32pdh.PDH_FMT_LONG, None)
elif sys.platform[:5] == 'linux':

    def memusage(_proc_pid_stat=None):
        """
        Return virtual memory size in bytes of the running python.

        Reads field 23 (``vsize``) of ``/proc/<pid>/stat``. Returns ``None``
        when the file cannot be read or parsed.

        Parameters
        ----------
        _proc_pid_stat : str, optional
            Path of the stat file to read. By default the calling process's
            file, resolved at call time so that forked children report their
            own usage instead of their parent's.
        """
        if _proc_pid_stat is None:
            _proc_pid_stat = f'/proc/{os.getpid()}/stat'
        try:
            with open(_proc_pid_stat) as f:
                line = f.readline()
            # Field 2 (``comm``) is parenthesised and may contain spaces, so
            # only split what follows its closing paren. That remainder
            # starts at field 3, i.e. field N lives at index N - 3.
            _, paren, rest = line.rpartition(')')
            if paren:
                return int(rest.split()[23 - 3])
            return int(line.split(' ')[22])
        except Exception:
            # Best effort: report "unknown" rather than failing the caller.
            return None
else:
    def memusage():
        """
        Return memory usage of running python. [Not implemented]

        Raises
        ------
        NotImplementedError
            Always; no memory probe exists for this platform.
        """
        raise NotImplementedError
|
|
|
|
|
if sys.platform[:5] == 'linux':
    def jiffies(_proc_pid_stat=None, _load_time=[]):
        """
        Return number of jiffies elapsed.

        Returns the ``utime`` field (field 14) of ``/proc/<pid>/stat``: the
        time this process has been scheduled in user mode, in clock ticks
        (typically 1/100ths of a second). See man 5 proc. If the file cannot
        be read or parsed, falls back to wall-clock time since the first
        call, in 1/100ths of a second.

        Parameters
        ----------
        _proc_pid_stat : str, optional
            Path of the stat file to read. By default the calling process's
            file, resolved at call time so that forked children report their
            own time instead of their parent's.
        _load_time : list, optional
            Deliberately shared mutable default: it remembers the time of the
            first call for the fallback. Not meant to be passed by callers.
        """
        import time
        if not _load_time:
            _load_time.append(time.time())
        if _proc_pid_stat is None:
            _proc_pid_stat = f'/proc/{os.getpid()}/stat'
        try:
            with open(_proc_pid_stat) as f:
                line = f.readline()
            # Field 2 (``comm``) is parenthesised and may contain spaces, so
            # only split what follows its closing paren. That remainder
            # starts at field 3, i.e. field N lives at index N - 3.
            _, paren, rest = line.rpartition(')')
            if paren:
                return int(rest.split()[14 - 3])
            return int(line.split(' ')[13])
        except Exception:
            return int(100 * (time.time() - _load_time[0]))
else:
    def jiffies(_load_time=[]):
        """
        Return number of jiffies elapsed.

        Without ``/proc``, this is the wall-clock time since the first call
        (so the first call returns about 0), in 1/100ths of a second.

        Parameters
        ----------
        _load_time : list, optional
            Deliberately shared mutable default: it remembers the time of the
            first call. Not meant to be passed by callers.
        """
        import time
        if not _load_time:
            _load_time.append(time.time())
        return int(100 * (time.time() - _load_time[0]))
|
|
|
|
|
def build_err_msg(arrays, err_msg, header='Items are not equal:',
                  verbose=True, names=('ACTUAL', 'DESIRED'), precision=8):
    """
    Assemble the failure message shared by the ``assert_*`` helpers.

    Parameters
    ----------
    arrays : sequence
        Objects to show in the message, paired positionally with `names`.
    err_msg : object
        Extra context, converted with ``str``. A short single-line message is
        appended to the header line; anything else goes on its own line.
    header : str, optional
        First line of the message.
    verbose : bool, optional
        If True, append one ``" <name>: <repr>"`` line per object.
    names : sequence of str, optional
        Labels for the objects in `arrays`.
    precision : int, optional
        Digits passed to ``array_repr`` when formatting ndarray objects.

    Returns
    -------
    str
        The assembled message; it always starts with a newline.
    """
    text = str(err_msg)
    lines = ['\n' + header]
    if text:
        short_enough = '\n' not in text and len(text) < 79 - len(header)
        if short_enough:
            lines[0] = f'{lines[0]} {text}'
        else:
            lines.append(text)

    if not verbose:
        return '\n'.join(lines)

    for idx, obj in enumerate(arrays):
        formatter = (partial(array_repr, precision=precision)
                     if isinstance(obj, ndarray) else repr)
        try:
            rendered = formatter(obj)
        except Exception as exc:
            rendered = f'[repr failed for <{type(obj).__name__}>: {exc}]'
        # A repr spanning more than four lines is cut to its first three.
        if rendered.count('\n') > 3:
            rendered = '\n'.join(rendered.splitlines()[:3]) + '...'
        lines.append(f' {names[idx]}: {rendered}')
    return '\n'.join(lines)
|
|
|
|
|
def assert_equal(actual, desired, err_msg='', verbose=True, *, strict=False):
    """
    Raises an AssertionError if two objects are not equal.

    Given two objects (scalars, lists, tuples, dictionaries or numpy arrays),
    check that all elements of these objects are equal. An exception is raised
    at the first conflicting values.

    This function handles NaN comparisons as if NaN was a "normal" number.
    That is, AssertionError is not raised if both objects have NaNs in the same
    positions.  This is in contrast to the IEEE standard on NaNs, which says
    that NaN compared to anything must return False.

    Parameters
    ----------
    actual : array_like
        The object to check.
    desired : array_like
        The expected object.
    err_msg : str, optional
        The error message to be printed in case of failure.
    verbose : bool, optional
        If True, the conflicting values are appended to the error message.
    strict : bool, optional
        If True and either of the `actual` and `desired` arguments is an array,
        raise an ``AssertionError`` when either the shape or the data type of
        the arguments does not match. If neither argument is an array, this
        parameter has no effect. In particular, it is not passed on when the
        items of dicts, lists or tuples are compared.

        .. versionadded:: 2.0.0

    Raises
    ------
    AssertionError
        If actual and desired are not equal.

    See Also
    --------
    assert_allclose
    assert_array_almost_equal_nulp
    assert_array_max_ulp

    Notes
    -----
    By default, when one of `actual` and `desired` is a scalar and the other is
    an array, the function checks that each element of the array is equal to
    the scalar. This behaviour can be disabled by setting ``strict=True``.

    Examples
    --------
    >>> np.testing.assert_equal([4, 5], [4, 6])
    Traceback (most recent call last):
        ...
    AssertionError:
    Items are not equal:
    item=1
     ACTUAL: 5
     DESIRED: 6

    The following comparison does not raise an exception.  There are NaNs
    in the inputs, but they are in the same positions.

    >>> np.testing.assert_equal(np.array([1.0, 2.0, np.nan]), [1, 2, np.nan])

    As mentioned in the Notes section, `assert_equal` has special
    handling for scalars when one of the arguments is an array.
    Here, the test checks that each value in `x` is 3:

    >>> x = np.full((2, 5), fill_value=3)
    >>> np.testing.assert_equal(x, 3)

    Use `strict` to raise an AssertionError when comparing a scalar with an
    array of a different shape:

    >>> np.testing.assert_equal(x, 3, strict=True)
    Traceback (most recent call last):
        ...
    AssertionError:
    Arrays are not equal
    <BLANKLINE>
    (shapes (2, 5), () mismatch)
     ACTUAL: array([[3, 3, 3, 3, 3],
           [3, 3, 3, 3, 3]])
     DESIRED: array(3)

    The `strict` parameter also ensures that the array data types match:

    >>> x = np.array([2, 2, 2])
    >>> y = np.array([2., 2., 2.], dtype=np.float32)
    >>> np.testing.assert_equal(x, y, strict=True)
    Traceback (most recent call last):
        ...
    AssertionError:
    Arrays are not equal
    <BLANKLINE>
    (dtypes int64, float32 mismatch)
     ACTUAL: array([2, 2, 2])
     DESIRED: array([2., 2., 2.], dtype=float32)
    """
    __tracebackhide__ = True  # Hide traceback for py.test
    if isinstance(desired, dict):
        if not isinstance(actual, dict):
            raise AssertionError(repr(type(actual)))
        assert_equal(len(actual), len(desired), err_msg, verbose)
        # ``strict`` is not forwarded to item comparisons (see docstring).
        for k in desired:
            if k not in actual:
                raise AssertionError(repr(k))
            assert_equal(actual[k], desired[k], f'key={k!r}\n{err_msg}',
                         verbose)
        return
    if isinstance(desired, (list, tuple)) and isinstance(actual, (list, tuple)):
        assert_equal(len(actual), len(desired), err_msg, verbose)
        for k, desired_item in enumerate(desired):
            assert_equal(actual[k], desired_item, f'item={k!r}\n{err_msg}',
                         verbose)
        return
    from numpy._core import isscalar, signbit
    from numpy import iscomplexobj, real, imag
    if isinstance(actual, ndarray) or isinstance(desired, ndarray):
        return assert_array_equal(actual, desired, err_msg, verbose,
                                  strict=strict)
    msg = build_err_msg([actual, desired], err_msg, verbose=verbose)

    # Complex numbers: compare real and imaginary parts separately (each
    # through assert_equal, so nan/signed-zero handling applies per part).
    # Objects that iscomplexobj cannot inspect are treated as non-complex.
    try:
        usecomplex = iscomplexobj(actual) or iscomplexobj(desired)
    except (ValueError, TypeError):
        usecomplex = False

    if usecomplex:
        if iscomplexobj(actual):
            actualr = real(actual)
            actuali = imag(actual)
        else:
            actualr = actual
            actuali = 0
        if iscomplexobj(desired):
            desiredr = real(desired)
            desiredi = imag(desired)
        else:
            desiredr = desired
            desiredi = 0
        try:
            assert_equal(actualr, desiredr)
            assert_equal(actuali, desiredi)
        except AssertionError:
            raise AssertionError(msg)

    # A scalar never equals a non-scalar (e.g. ``[np.nan]`` vs ``np.nan``).
    if isscalar(desired) != isscalar(actual):
        raise AssertionError(msg)

    # NaT handling: two NaTs are equal only if their dtype types match.
    try:
        isdesnat = isnat(desired)
        isactnat = isnat(actual)
        dtypes_match = (np.asarray(desired).dtype.type ==
                        np.asarray(actual).dtype.type)
        if isdesnat and isactnat:
            if dtypes_match:
                return
            else:
                raise AssertionError(msg)

    except (TypeError, ValueError, NotImplementedError):
        pass

    # NaN and signed-zero handling.
    try:
        isdesnan = isnan(desired)
        isactnan = isnan(actual)
        if isdesnan and isactnan:
            return  # both nan, so equal

        array_actual = np.asarray(actual)
        array_desired = np.asarray(desired)
        if (array_actual.dtype.char in 'Mm' or
                array_desired.dtype.char in 'Mm'):
            # Datetime/timedelta values skip the signed-zero check below;
            # raising here just jumps to the ``except`` clause.
            raise NotImplementedError('cannot compare to a scalar '
                                      'with a different type')

        # 0.0 and -0.0 compare equal, so check the sign bit explicitly.
        if desired == 0 and actual == 0:
            if signbit(desired) != signbit(actual):
                raise AssertionError(msg)

    except (TypeError, ValueError, NotImplementedError):
        pass

    try:
        # Plain ``==`` comparison for everything else.
        if not (desired == actual):
            raise AssertionError(msg)
    except (DeprecationWarning, FutureWarning) as e:
        # A warning raised as an error about an elementwise ``==`` means the
        # two objects are not comparable, which counts as "not equal".
        if 'elementwise == comparison' in e.args[0]:
            raise AssertionError(msg)
        else:
            raise
|
|
|
|
|
def print_assert_equal(test_string, actual, desired):
    """
    Raise an AssertionError with pretty-printed values if two objects differ.

    Equality is decided by ``actual == desired``. On failure, the message
    holds `test_string` followed by ``pprint``-formatted versions of both
    objects.

    Parameters
    ----------
    test_string : str
        Label placed at the start of the AssertionError message.
    actual : object
        The object compared against `desired`.
    desired : object
        The expected result.

    Raises
    ------
    AssertionError
        If ``actual == desired`` is falsy.

    Examples
    --------
    >>> np.testing.print_assert_equal('Test XYZ of func xyz', [0, 1], [0, 1])
    >>> np.testing.print_assert_equal('Test XYZ of func xyz', [0, 1], [0, 2])
    Traceback (most recent call last):
    ...
    AssertionError: Test XYZ of func xyz failed
    ACTUAL:
    [0, 1]
    DESIRED:
    [0, 2]

    """
    __tracebackhide__ = True  # Hide traceback for py.test
    import pprint

    if actual == desired:
        return
    parts = [
        test_string,
        ' failed\nACTUAL: \n',
        pprint.pformat(actual),
        '\nDESIRED: \n',
        pprint.pformat(desired),
        '\n',
    ]
    raise AssertionError(''.join(parts))
|
|
|
|
|
def assert_almost_equal(actual, desired, decimal=7, err_msg='', verbose=True): |
|
""" |
|
Raises an AssertionError if two items are not equal up to desired |
|
precision. |
|
|
|
.. note:: It is recommended to use one of `assert_allclose`, |
|
`assert_array_almost_equal_nulp` or `assert_array_max_ulp` |
|
instead of this function for more consistent floating point |
|
comparisons. |
|
|
|
The test verifies that the elements of `actual` and `desired` satisfy:: |
|
|
|
abs(desired-actual) < float64(1.5 * 10**(-decimal)) |
|
|
|
That is a looser test than originally documented, but agrees with what the |
|
actual implementation in `assert_array_almost_equal` did up to rounding |
|
vagaries. An exception is raised at conflicting values. For ndarrays this |
|
delegates to assert_array_almost_equal |
|
|
|
Parameters |
|
---------- |
|
actual : array_like |
|
The object to check. |
|
desired : array_like |
|
The expected object. |
|
decimal : int, optional |
|
Desired precision, default is 7. |
|
err_msg : str, optional |
|
The error message to be printed in case of failure. |
|
verbose : bool, optional |
|
If True, the conflicting values are appended to the error message. |
|
|
|
Raises |
|
------ |
|
AssertionError |
|
If actual and desired are not equal up to specified precision. |
|
|
|
See Also |
|
-------- |
|
assert_allclose: Compare two array_like objects for equality with desired |
|
relative and/or absolute precision. |
|
assert_array_almost_equal_nulp, assert_array_max_ulp, assert_equal |
|
|
|
Examples |
|
-------- |
|
>>> from numpy.testing import assert_almost_equal |
|
>>> assert_almost_equal(2.3333333333333, 2.33333334) |
|
>>> assert_almost_equal(2.3333333333333, 2.33333334, decimal=10) |
|
Traceback (most recent call last): |
|
... |
|
AssertionError: |
|
Arrays are not almost equal to 10 decimals |
|
ACTUAL: 2.3333333333333 |
|
DESIRED: 2.33333334 |
|
|
|
>>> assert_almost_equal(np.array([1.0,2.3333333333333]), |
|
... np.array([1.0,2.33333334]), decimal=9) |
|
Traceback (most recent call last): |
|
... |
|
AssertionError: |
|
Arrays are not almost equal to 9 decimals |
|
<BLANKLINE> |
|
Mismatched elements: 1 / 2 (50%) |
|
Max absolute difference among violations: 6.66669964e-09 |
|
Max relative difference among violations: 2.85715698e-09 |
|
ACTUAL: array([1. , 2.333333333]) |
|
DESIRED: array([1. , 2.33333334]) |
|
|
|
""" |
|
__tracebackhide__ = True |
|
from numpy._core import ndarray |
|
from numpy import iscomplexobj, real, imag |
|
|
|
|
|
|
|
|
|
try: |
|
usecomplex = iscomplexobj(actual) or iscomplexobj(desired) |
|
except ValueError: |
|
usecomplex = False |
|
|
|
def _build_err_msg(): |
|
header = ('Arrays are not almost equal to %d decimals' % decimal) |
|
return build_err_msg([actual, desired], err_msg, verbose=verbose, |
|
header=header) |
|
|
|
if usecomplex: |
|
if iscomplexobj(actual): |
|
actualr = real(actual) |
|
actuali = imag(actual) |
|
else: |
|
actualr = actual |
|
actuali = 0 |
|
if iscomplexobj(desired): |
|
desiredr = real(desired) |
|
desiredi = imag(desired) |
|
else: |
|
desiredr = desired |
|
desiredi = 0 |
|
try: |
|
assert_almost_equal(actualr, desiredr, decimal=decimal) |
|
assert_almost_equal(actuali, desiredi, decimal=decimal) |
|
except AssertionError: |
|
raise AssertionError(_build_err_msg()) |
|
|
|
if isinstance(actual, (ndarray, tuple, list)) \ |
|
or isinstance(desired, (ndarray, tuple, list)): |
|
return assert_array_almost_equal(actual, desired, decimal, err_msg) |
|
try: |
|
|
|
|
|
|
|
if not (isfinite(desired) and isfinite(actual)): |
|
if isnan(desired) or isnan(actual): |
|
if not (isnan(desired) and isnan(actual)): |
|
raise AssertionError(_build_err_msg()) |
|
else: |
|
if not desired == actual: |
|
raise AssertionError(_build_err_msg()) |
|
return |
|
except (NotImplementedError, TypeError): |
|
pass |
|
if abs(desired - actual) >= np.float64(1.5 * 10.0**(-decimal)): |
|
raise AssertionError(_build_err_msg()) |
|
|
|
|
|
def assert_approx_equal(actual, desired, significant=7, err_msg='',
                        verbose=True):
    """
    Raises an AssertionError if two items are not equal up to significant
    digits.

    .. note:: It is recommended to use one of `assert_allclose`,
              `assert_array_almost_equal_nulp` or `assert_array_max_ulp`
              instead of this function for more consistent floating point
              comparisons.

    Given two numbers, check that they are approximately equal.
    Approximately equal is defined as the number of significant digits
    that agree. Both inputs are converted with ``float`` first.

    Parameters
    ----------
    actual : scalar
        The object to check.
    desired : scalar
        The expected object.
    significant : int, optional
        Desired precision, default is 7.
    err_msg : str, optional
        The error message to be printed in case of failure.
    verbose : bool, optional
        If True, the conflicting values are appended to the error message.

    Raises
    ------
    AssertionError
      If actual and desired are not equal up to specified precision.

    See Also
    --------
    assert_allclose: Compare two array_like objects for equality with desired
                     relative and/or absolute precision.
    assert_array_almost_equal_nulp, assert_array_max_ulp, assert_equal

    Examples
    --------
    >>> np.testing.assert_approx_equal(0.12345677777777e-20, 0.1234567e-20)
    >>> np.testing.assert_approx_equal(0.12345670e-20, 0.12345671e-20,
    ...                                significant=8)
    >>> np.testing.assert_approx_equal(0.12345670e-20, 0.12345672e-20,
    ...                                significant=8)
    Traceback (most recent call last):
        ...
    AssertionError:
    Items are not equal to 8 significant digits:
     ACTUAL: 1.234567e-21
     DESIRED: 1.2345672e-21

    the evaluated condition that raises the exception is

    >>> abs(0.12345670e-20/1e-21 - 0.12345672e-20/1e-21) >= 10**-(8-1)
    True

    """
    __tracebackhide__ = True  # Hide traceback for py.test

    (actual, desired) = map(float, (actual, desired))
    if desired == actual:
        return

    def _build_err_msg():
        # Built lazily (like assert_almost_equal) so reprs are only computed
        # on failure; it sees the float-converted values.
        header = 'Items are not equal to %d significant digits:' % significant
        return build_err_msg([actual, desired], err_msg, header=header,
                             verbose=verbose)

    # Normalize both numbers by a common power of ten so the scaled values
    # lie in (-10.0, 10.0). ``scale`` is a numpy float64, so the divisions use
    # numpy semantics: x/0 gives inf/nan (with a RuntimeWarning) rather than
    # ZeroDivisionError. divide/invalid are silenced so that log10(0) (the
    # scale underflowed to zero) and inf/inf (an infinite input) do not leak
    # warnings that ``-W error`` would turn into exceptions. The numerical
    # results are unchanged.
    with np.errstate(divide='ignore', invalid='ignore'):
        scale = 0.5 * (np.abs(desired) + np.abs(actual))
        scale = np.power(10, np.floor(np.log10(scale)))
        try:
            sc_desired = desired / scale
        except ZeroDivisionError:
            sc_desired = 0.0
        try:
            sc_actual = actual / scale
        except ZeroDivisionError:
            sc_actual = 0.0
    try:
        # Non-finite values are handled here: if either is nan, both must
        # be nan; otherwise (infinities) they must compare exactly equal.
        if not (isfinite(desired) and isfinite(actual)):
            if isnan(desired) or isnan(actual):
                if not (isnan(desired) and isnan(actual)):
                    raise AssertionError(_build_err_msg())
            else:
                if not desired == actual:
                    raise AssertionError(_build_err_msg())
            return
    except (TypeError, NotImplementedError):
        pass
    with np.errstate(invalid='ignore'):
        # inf - inf can occur when the scale underflowed to zero.
        too_far = (np.abs(sc_desired - sc_actual)
                   >= np.power(10., -(significant - 1)))
    if too_far:
        raise AssertionError(_build_err_msg())
|
|
|
|
|
def assert_array_compare(comparison, x, y, err_msg='', verbose=True, header='',
                         precision=6, equal_nan=True, equal_inf=True,
                         *, strict=False, names=('ACTUAL', 'DESIRED')):
    """
    Check that ``comparison(x, y)`` holds element-wise, raising otherwise.

    Parameters
    ----------
    comparison : callable
        Binary element-wise predicate (e.g. ``operator.__eq__``). It is
        expected to return a Python bool or an array supporting ``ravel()``.
    x, y : array_like
        Operands; converted with ``np.asanyarray``.
    err_msg : str, optional
        Extra text for the failure message.
    verbose : bool, optional
        If True, the operands are shown in the failure message.
    header : str, optional
        First line of the failure message.
    precision : int, optional
        Digits used by ``build_err_msg`` when formatting ndarray operands.
    equal_nan : bool, optional
        If True, NaNs (numeric dtypes), NaTs (datetime/timedelta operands of
        the same dtype type) and NaN-like NA objects (dtype char ``'T'``)
        must occur at the same positions; those positions are then excluded
        from ``comparison``.
    equal_inf : bool, optional
        For numeric dtypes, +inf and -inf must occur at the same positions;
        those positions are then excluded from ``comparison``.
    strict : bool, optional
        If True, shapes and dtypes must match exactly. Otherwise shapes must
        match unless one operand is 0-d.
    names : tuple of str, optional
        Labels for ``x`` and ``y`` in the failure message.

    Raises
    ------
    AssertionError
        On shape/dtype mismatch, special-value position mismatch, or when
        ``comparison`` is not True everywhere. The message reports the
        mismatch count and, when computable, the max absolute and relative
        differences among the failing elements.
    ValueError
        If any ``ValueError`` occurs during the check; it is re-raised with
        the formatted traceback prepended to the header.
    """
    __tracebackhide__ = True  # Hide traceback for py.test
    # Note: ``isnan``, ``all`` and ``max`` here shadow the module-level /
    # builtin names for the rest of this function.
    from numpy._core import (array2string, isnan, inf, errstate,
                            all, max, object_)

    x = np.asanyarray(x)
    y = np.asanyarray(y)

    # Keep the unfiltered operands; the mismatch message shows these even
    # after special values have been masked out of ``x`` and ``y``.
    ox, oy = x, y

    def isnumber(x):
        return x.dtype.char in '?bhilqpBHILQPefdgFDG'

    def istime(x):
        return x.dtype.char in "Mm"

    def isvstring(x):
        return x.dtype.char == "T"

    def func_assert_same_pos(x, y, func=isnan, hasval='nan'):
        """Handling nan/inf.

        Apply ``func`` to x and y and raise AssertionError unless both results
        are True at the same locations. Returns the flag (a 0-d bool if
        either result is scalar, else the per-element mask of ``y``).

        """
        __tracebackhide__ = True  # Hide traceback for py.test

        x_id = func(x)
        y_id = func(y)
        # NOTE(review): wrapping in ``np.bool(...)``, using the ``.all()``
        # method and testing ``!= True`` (instead of ``not``) look like
        # work-arounds for ndarray subclasses whose ``==``/``all()`` return
        # Python bools or masked results -- confirm before simplifying.
        if np.bool(x_id == y_id).all() != True:
            msg = build_err_msg(
                [x, y],
                err_msg + '\n%s location mismatch:'
                % (hasval), verbose=verbose, header=header,
                names=names,
                precision=precision)
            raise AssertionError(msg)
        # The positions matched above, so a scalar flag holds for every
        # element; return it as a 0-d bool. Otherwise return the mask.
        if isinstance(x_id, bool) or x_id.ndim == 0:
            return np.bool(x_id)
        elif isinstance(y_id, bool) or y_id.ndim == 0:
            return np.bool(y_id)
        else:
            return y_id

    try:
        # Shape (and, if strict, dtype) check.
        if strict:
            cond = x.shape == y.shape and x.dtype == y.dtype
        else:
            cond = (x.shape == () or y.shape == ()) or x.shape == y.shape
        if not cond:
            if x.shape != y.shape:
                reason = f'\n(shapes {x.shape}, {y.shape} mismatch)'
            else:
                reason = f'\n(dtypes {x.dtype}, {y.dtype} mismatch)'
            msg = build_err_msg([x, y],
                                err_msg
                                + reason,
                                verbose=verbose, header=header,
                                names=names,
                                precision=precision)
            raise AssertionError(msg)

        # ``flagged`` marks special values (nan/inf/NaT/NA) that must match
        # positionally and are excluded from ``comparison``.
        flagged = np.bool(False)
        if isnumber(x) and isnumber(y):
            if equal_nan:
                flagged = func_assert_same_pos(x, y, func=isnan, hasval='nan')

            if equal_inf:
                flagged |= func_assert_same_pos(x, y,
                                                func=lambda xy: xy == +inf,
                                                hasval='+inf')
                flagged |= func_assert_same_pos(x, y,
                                                func=lambda xy: xy == -inf,
                                                hasval='-inf')

        elif istime(x) and istime(y):
            # NaTs are only matched when both operands are datetime64, or
            # both are timedelta64.
            if equal_nan and x.dtype.type == y.dtype.type:
                flagged = func_assert_same_pos(x, y, func=isnat, hasval="NaT")

        elif isvstring(x) and isvstring(y):
            dt = x.dtype
            if equal_nan and dt == y.dtype and hasattr(dt, 'na_object'):
                is_nan = (isinstance(dt.na_object, float) and
                          np.isnan(dt.na_object))
                bool_errors = 0
                try:
                    bool(dt.na_object)
                except TypeError:
                    bool_errors = 1
                if is_nan or bool_errors:
                    # NaN-like NA object (a float nan, or one whose bool()
                    # raises TypeError): match it positionally via isnan.
                    flagged = func_assert_same_pos(
                        x, y, func=isnan, hasval=x.dtype.na_object)

        if flagged.ndim > 0:
            x, y = x[~flagged], y[~flagged]
            # Only compare if unflagged values remain.
            if x.size == 0:
                return
        elif flagged:
            # A scalar True flag covers every element: nothing to compare.
            return

        val = comparison(x, y)
        invalids = np.logical_not(val)

        if isinstance(val, bool):
            cond = val
            reduced = array([val])
        else:
            reduced = val.ravel()
            cond = reduced.all()

        # NOTE(review): ``!= True`` rather than ``not cond`` presumably keeps
        # reductions that return a masked/non-bool object from failing --
        # confirm before changing.
        if cond != True:
            n_mismatch = reduced.size - reduced.sum(dtype=intp)
            # Percentages are relative to the full (pre-masking) size.
            n_elements = flagged.size if flagged.ndim != 0 else reduced.size
            percent_mismatch = 100 * n_mismatch / n_elements
            remarks = [
                'Mismatched elements: {} / {} ({:.3g}%)'.format(
                    n_mismatch, n_elements, percent_mismatch)]

            with errstate(all='ignore'):
                # Difference statistics are best effort: operands that do not
                # support subtraction (TypeError) simply get no extra remarks.
                with contextlib.suppress(TypeError):
                    error = abs(x - y)
                    if np.issubdtype(x.dtype, np.unsignedinteger):
                        # Unsigned subtraction wraps around; the smaller of
                        # |x - y| and |y - x| is the true distance.
                        error2 = abs(y - x)
                        np.minimum(error, error2, out=error)

                    reduced_error = error[invalids]
                    max_abs_error = max(reduced_error)
                    if getattr(error, 'dtype', object_) == object_:
                        remarks.append(
                            'Max absolute difference among violations: '
                            + str(max_abs_error))
                    else:
                        remarks.append(
                            'Max absolute difference among violations: '
                            + array2string(max_abs_error))

                    # Relative error |x - y| / |y|, computed only where
                    # y != 0 to avoid dividing by zero; inf if no such
                    # failing element exists.
                    nonzero = np.bool(y != 0)
                    nonzero_and_invalid = np.logical_and(invalids, nonzero)

                    if all(~nonzero_and_invalid):
                        max_rel_error = array(inf)
                    else:
                        nonzero_invalid_error = error[nonzero_and_invalid]
                        broadcasted_y = np.broadcast_to(y, error.shape)
                        nonzero_invalid_y = broadcasted_y[nonzero_and_invalid]
                        max_rel_error = max(nonzero_invalid_error
                                            / abs(nonzero_invalid_y))

                    if getattr(error, 'dtype', object_) == object_:
                        remarks.append(
                            'Max relative difference among violations: '
                            + str(max_rel_error))
                    else:
                        remarks.append(
                            'Max relative difference among violations: '
                            + array2string(max_rel_error))
            err_msg = str(err_msg)
            err_msg += '\n' + '\n'.join(remarks)
            msg = build_err_msg([ox, oy], err_msg,
                                verbose=verbose, header=header,
                                names=names,
                                precision=precision)
            raise AssertionError(msg)
    except ValueError:
        # Re-raise with the original traceback embedded in the message.
        import traceback
        efmt = traceback.format_exc()
        header = f'error during assertion:\n\n{efmt}\n\n{header}'

        msg = build_err_msg([x, y], err_msg, verbose=verbose, header=header,
                            names=names, precision=precision)
        raise ValueError(msg)
|
|
|
|
|
@_rename_parameter(['x', 'y'], ['actual', 'desired'], dep_version='2.0.0')
def assert_array_equal(actual, desired, err_msg='', verbose=True, *,
                       strict=False):
    """
    Raises an AssertionError if two array_like objects are not equal.

    The shapes of the two objects must agree and every pair of corresponding
    elements must compare equal (a scalar on either side is handled
    specially, see the Notes). A failure is reported on a shape mismatch or
    on conflicting values. Unlike ordinary NumPy comparisons, NaNs are treated
    like regular numbers: NaNs at the same positions in both objects do not
    cause a failure.

    Be careful when testing floating point results for exact equality;
    tolerance-based helpers are usually the better choice.

    .. note:: When either `actual` or `desired` is already an instance of
        `numpy.ndarray` and `desired` is not a ``dict``, the behavior of
        ``assert_equal(actual, desired)`` is identical to the behavior of this
        function. Otherwise, this function performs `np.asanyarray` on the
        inputs before comparison, whereas `assert_equal` defines special
        comparison rules for common Python types. For example, only
        `assert_equal` can be used to compare nested Python lists. In new code,
        consider using only `assert_equal`, explicitly converting either
        `actual` or `desired` to arrays if the behavior of `assert_array_equal`
        is desired.

    Parameters
    ----------
    actual : array_like
        The actual object to check.
    desired : array_like
        The desired, expected object.
    err_msg : str, optional
        The error message to be printed in case of failure.
    verbose : bool, optional
        If True, the conflicting values are appended to the error message.
    strict : bool, optional
        If True, raise an AssertionError when either the shape or the data
        type of the array_like objects does not match. The special
        handling for scalars mentioned in the Notes section is disabled.

        .. versionadded:: 1.24.0

    Raises
    ------
    AssertionError
        If actual and desired objects are not equal.

    See Also
    --------
    assert_allclose: Compare two array_like objects for equality with desired
                     relative and/or absolute precision.
    assert_array_almost_equal_nulp, assert_array_max_ulp, assert_equal

    Notes
    -----
    When one of `actual` and `desired` is a scalar and the other is array_like,
    every element of the array_like object is compared with that scalar. Use
    the `strict` parameter to turn this behaviour off.

    Examples
    --------
    The first assert does not raise an exception:

    >>> np.testing.assert_array_equal([1.0,2.33333,np.nan],
    ...                               [np.exp(0),2.33333, np.nan])

    Assert fails with numerical imprecision with floats:

    >>> np.testing.assert_array_equal([1.0,np.pi,np.nan],
    ...                               [1, np.sqrt(np.pi)**2, np.nan])
    Traceback (most recent call last):
        ...
    AssertionError:
    Arrays are not equal
    <BLANKLINE>
    Mismatched elements: 1 / 3 (33.3%)
    Max absolute difference among violations: 4.4408921e-16
    Max relative difference among violations: 1.41357986e-16
     ACTUAL: array([1.      , 3.141593,      nan])
     DESIRED: array([1.      , 3.141593,      nan])

    Use `assert_allclose` or one of the nulp (number of floating point values)
    functions for these cases instead:

    >>> np.testing.assert_allclose([1.0,np.pi,np.nan],
    ...                            [1, np.sqrt(np.pi)**2, np.nan],
    ...                            rtol=1e-10, atol=0)

    As mentioned in the Notes section, `assert_array_equal` has special
    handling for scalars. Here the test checks that each value in `x` is 3:

    >>> x = np.full((2, 5), fill_value=3)
    >>> np.testing.assert_array_equal(x, 3)

    Use `strict` to raise an AssertionError when comparing a scalar with an
    array:

    >>> np.testing.assert_array_equal(x, 3, strict=True)
    Traceback (most recent call last):
        ...
    AssertionError:
    Arrays are not equal
    <BLANKLINE>
    (shapes (2, 5), () mismatch)
     ACTUAL: array([[3, 3, 3, 3, 3],
           [3, 3, 3, 3, 3]])
     DESIRED: array(3)

    The `strict` parameter also ensures that the array data types match:

    >>> x = np.array([2, 2, 2])
    >>> y = np.array([2., 2., 2.], dtype=np.float32)
    >>> np.testing.assert_array_equal(x, y, strict=True)
    Traceback (most recent call last):
        ...
    AssertionError:
    Arrays are not equal
    <BLANKLINE>
    (dtypes int64, float32 mismatch)
     ACTUAL: array([2, 2, 2])
     DESIRED: array([2., 2., 2.], dtype=float32)
    """
    __tracebackhide__ = True  # Hide traceback for py.test
    # ``operator.eq`` is the same function object as ``operator.__eq__``.
    compare_kwargs = dict(err_msg=err_msg, verbose=verbose,
                          header='Arrays are not equal', strict=strict)
    assert_array_compare(operator.eq, actual, desired, **compare_kwargs)
|
|
|
|
|
@_rename_parameter(['x', 'y'], ['actual', 'desired'], dep_version='2.0.0')
def assert_array_almost_equal(actual, desired, decimal=6, err_msg='',
                              verbose=True):
    """
    Raise an AssertionError unless two array_like objects agree to
    `decimal` decimal places.

    .. note:: For more consistent floating point comparisons prefer
              `assert_allclose`, `assert_array_almost_equal_nulp` or
              `assert_array_max_ulp` over this function.

    The shapes must be identical and each pair of elements of ``actual``
    and ``desired`` must satisfy::

        abs(desired-actual) < 1.5 * 10**(-decimal)

    This is looser than the criterion originally documented, but it is
    what the implementation has always checked, up to rounding details.
    A shape mismatch or any conflicting value raises. Unlike ordinary
    NumPy comparisons, NaNs are compared like numbers: no assertion is
    raised when both objects have NaNs in the same positions.

    Parameters
    ----------
    actual : array_like
        The object being checked.
    desired : array_like
        The expected object.
    decimal : int, optional
        Number of decimal places that must agree; 6 by default.
    err_msg : str, optional
        Message included in the error raised on failure.
    verbose : bool, optional
        When True, the conflicting values are appended to the error message.

    Raises
    ------
    AssertionError
        If `actual` and `desired` do not agree to the requested precision.

    See Also
    --------
    assert_allclose: Compare two array_like objects for equality with desired
                     relative and/or absolute precision.
    assert_array_almost_equal_nulp, assert_array_max_ulp, assert_equal

    Examples
    --------
    the first assert does not raise an exception

    >>> np.testing.assert_array_almost_equal([1.0,2.333,np.nan],
    ...                                      [1.0,2.333,np.nan])

    >>> np.testing.assert_array_almost_equal([1.0,2.33333,np.nan],
    ...                                      [1.0,2.33339,np.nan], decimal=5)
    Traceback (most recent call last):
        ...
    AssertionError:
    Arrays are not almost equal to 5 decimals
    <BLANKLINE>
    Mismatched elements: 1 / 3 (33.3%)
    Max absolute difference among violations: 6.e-05
    Max relative difference among violations: 2.57136612e-05
     ACTUAL: array([1.     , 2.33333,     nan])
     DESIRED: array([1.     , 2.33339,     nan])

    >>> np.testing.assert_array_almost_equal([1.0,2.33333,np.nan],
    ...                                      [1.0,2.33333, 5], decimal=5)
    Traceback (most recent call last):
        ...
    AssertionError:
    Arrays are not almost equal to 5 decimals
    <BLANKLINE>
    nan location mismatch:
     ACTUAL: array([1.     , 2.33333,     nan])
     DESIRED: array([1.     , 2.33333, 5.     ])

    """
    __tracebackhide__ = True  # Hide traceback for py.test
    from numpy._core import number, result_type
    from numpy._core.numerictypes import issubdtype
    from numpy._core.fromnumeric import any as npany

    def compare(x, y):
        try:
            if npany(isinf(x)) or npany(isinf(y)):
                x_is_inf = isinf(x)
                y_is_inf = isinf(y)
                # Infinities must occupy the same positions on both sides.
                if not (x_is_inf == y_is_inf).all():
                    return False
                # Single element on each side: both are +/-inf, so compare
                # them directly.
                if x.size == y.size == 1:
                    return x == y
                # Otherwise drop the (matching) infinities before taking
                # differences.
                x = x[~x_is_inf]
                y = y[~y_is_inf]
        except (TypeError, NotImplementedError):
            pass

        # Make y inexact so abs() cannot overflow on the most negative
        # integer; x is promoted along with it in the subtraction.
        y = np.asanyarray(y, result_type(y, 1.))
        abs_diff = abs(x - y)

        if not issubdtype(abs_diff.dtype, number):
            abs_diff = abs_diff.astype(np.float64)  # object arrays

        return abs_diff < 1.5 * 10.0**(-decimal)

    header = 'Arrays are not almost equal to %d decimals' % decimal
    assert_array_compare(compare, actual, desired, err_msg=err_msg,
                         verbose=verbose, header=header, precision=decimal)
|
|
|
|
|
def assert_array_less(x, y, err_msg='', verbose=True, *, strict=False):
    """
    Raise an AssertionError unless one array_like object is element-wise
    strictly less than another.

    The two objects must have equal shapes, and every element of `x` must be
    strictly smaller than the matching element of `y` (see the Notes for how
    a scalar is handled). Mismatched shapes or wrongly ordered values cause
    an exception. Unlike ordinary NumPy comparisons, no assertion is raised
    when both objects have NaNs in the same positions.

    Parameters
    ----------
    x : array_like
        The object expected to be smaller.
    y : array_like
        The object expected to be larger.
    err_msg : string
        Message included in the error raised on failure.
    verbose : bool
        When True, the conflicting values are appended to the error message.
    strict : bool, optional
        When True, an AssertionError is raised if the shapes or the data
        types of the array_like objects differ, and the scalar handling
        described in the Notes is turned off.

        .. versionadded:: 2.0.0

    Raises
    ------
    AssertionError
        If x is not strictly smaller than y, element-wise.

    See Also
    --------
    assert_array_equal: tests objects for equality
    assert_array_almost_equal: test objects for equality up to precision

    Notes
    -----
    If one of `x` and `y` is a scalar and the other is array_like, the
    scalar is compared as though it had been broadcast to the array's
    shape. Pass ``strict=True`` to disable this.

    Examples
    --------
    This assertion passes: each finite element of `x` is strictly less than
    its counterpart in `y`, and the NaNs line up.

    >>> x = [1.0, 1.0, np.nan]
    >>> y = [1.1, 2.0, np.nan]
    >>> np.testing.assert_array_less(x, y)

    After this change the zeroth element of `x` is no longer strictly less
    than the zeroth element of `y`, so the assertion fails.

    >>> y[0] = 1
    >>> np.testing.assert_array_less(x, y)
    Traceback (most recent call last):
        ...
    AssertionError:
    Arrays are not strictly ordered `x < y`
    <BLANKLINE>
    Mismatched elements: 1 / 3 (33.3%)
    Max absolute difference among violations: 0.
    Max relative difference among violations: 0.
     x: array([ 1.,  1., nan])
     y: array([ 1.,  2., nan])

    With a scalar `y`, every element of `x` is compared against it, and
    the assertion passes.

    >>> x = [1.0, 4.0]
    >>> y = 5.0
    >>> np.testing.assert_array_less(x, y)

    ``strict=True`` makes the same call fail because the shapes differ.

    >>> np.testing.assert_array_less(x, y, strict=True)
    Traceback (most recent call last):
        ...
    AssertionError:
    Arrays are not strictly ordered `x < y`
    <BLANKLINE>
    (shapes (2,), () mismatch)
     x: array([1., 4.])
     y: array(5.)

    ``strict=True`` also rejects arrays whose dtypes differ.

    >>> y = [5, 5]
    >>> np.testing.assert_array_less(x, y, strict=True)
    Traceback (most recent call last):
        ...
    AssertionError:
    Arrays are not strictly ordered `x < y`
    <BLANKLINE>
    (dtypes float64, int64 mismatch)
     x: array([1., 4.])
     y: array([5, 5])
    """
    __tracebackhide__ = True  # Hide traceback for py.test
    header = 'Arrays are not strictly ordered `x < y`'
    labels = ('x', 'y')
    assert_array_compare(
        operator.__lt__,
        x,
        y,
        err_msg=err_msg,
        verbose=verbose,
        header=header,
        # Infinities take part in the ordering check instead of having to
        # match position-for-position.
        equal_inf=False,
        strict=strict,
        names=labels,
    )
|
|
|
|
|
def runstring(astr, dict):
    """Execute the Python source `astr` with `dict` as its global namespace.

    Names bound by the executed code end up in `dict`. Always returns None.
    """
    namespace = dict
    exec(astr, namespace)
|
|
|
|
|
def assert_string_equal(actual, desired):
    """
    Test if two strings are equal.

    If the given strings are equal, `assert_string_equal` does nothing.
    If they are not equal, an AssertionError is raised, and the diff
    between the strings is shown.

    Parameters
    ----------
    actual : str
        The string to test for equality against the expected string.
    desired : str
        The expected string.

    Raises
    ------
    AssertionError
        If either argument is not a `str` (the message is the offending
        type), or if the strings differ. In the latter case the message is
        a line-based diff produced by `difflib.Differ`, with lines common to
        both strings left out. Lines present in only one of the strings are
        reported as well.

    Examples
    --------
    >>> np.testing.assert_string_equal('abc', 'abc')
    >>> np.testing.assert_string_equal('abc', 'abcd')
    Traceback (most recent call last):
      File "<stdin>", line 1, in <module>
    ...
    AssertionError: Differences in strings:
    - abc+ abcd?    +

    """
    __tracebackhide__ = True  # Hide traceback for py.test
    import difflib

    if not isinstance(actual, str):
        raise AssertionError(repr(type(actual)))
    if not isinstance(desired, str):
        raise AssertionError(repr(type(desired)))
    if desired == actual:
        return

    diff = list(difflib.Differ().compare(actual.splitlines(True),
                                         desired.splitlines(True)))
    n = len(diff)
    diff_list = []
    i = 0
    # Walk the diff by index instead of pop(0): that was quadratic, and
    # popping past the end raised IndexError when a '- ' line was last.
    while i < n:
        d1 = diff[i]
        i += 1
        if d1.startswith('  '):
            # Line common to both strings: not part of the report.
            continue
        if d1.startswith('- '):
            # A changed line is reported by Differ as '- old',
            # optional '? hints', '+ new', optional '? hints'.
            group = [d1]
            j = i
            if j < n and diff[j].startswith('? '):
                group.append(diff[j])
                j += 1
            if j < n and diff[j].startswith('+ '):
                d2 = diff[j]
                group.append(d2)
                j += 1
                if j < n and diff[j].startswith('? '):
                    group.append(diff[j])
                    j += 1
                i = j
                # A pair with identical content is not a real difference.
                if d2[2:] != d1[2:]:
                    diff_list.extend(group)
                continue
        # Unpaired '- ' (line only in `actual`), '+ ' (line only in
        # `desired`) or a stray '? ' hint: report it verbatim instead of
        # failing with an unhelpful message.
        diff_list.append(d1)

    if not diff_list:
        return
    msg = f"Differences in strings:\n{''.join(diff_list).rstrip()}"
    raise AssertionError(msg)
|
|
|
|
|
def rundocs(filename=None, raise_on_error=True):
    """
    Run doctests found in the given file.

    By default `rundocs` raises an AssertionError on failure.

    Parameters
    ----------
    filename : str, optional
        The path to the file for which the doctests are run. If None (the
        default), the ``__file__`` of the calling module is used.
    raise_on_error : bool
        Whether to raise an AssertionError when a doctest fails. Default is
        True. When False, the doctest runner writes its failure reports to
        standard output instead.

    Raises
    ------
    AssertionError
        If `raise_on_error` is True and at least one doctest fails. The
        message contains the collected failure reports.

    Notes
    -----
    The doctests can be run by the user/developer by adding the ``doctests``
    argument to the ``test()`` call. For example, to run all tests (including
    doctests) for ``numpy.lib``:

    >>> np.lib.test(doctests=True)  # doctest: +SKIP
    """
    import doctest
    import importlib.util

    if filename is None:
        f = sys._getframe(1)
        filename = f.f_globals['__file__']
    name = os.path.splitext(os.path.basename(filename))[0]

    # Load the file as a standalone module with importlib. This is what
    # numpy.distutils.misc_util.exec_mod_from_location did. It is inlined
    # because numpy.distutils is deprecated and unavailable on newer
    # Pythons (>= 3.12).
    spec = importlib.util.spec_from_file_location(name, filename)
    m = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(m)

    tests = doctest.DocTestFinder().find(m)
    runner = doctest.DocTestRunner(verbose=False)

    msg = []
    # Collect the runner's report only when it will be used in the raised
    # error; out=None lets doctest print to stdout.
    out = msg.append if raise_on_error else None

    for test in tests:
        runner.run(test, out=out)

    if runner.failures > 0 and raise_on_error:
        raise AssertionError("Some doctests failed:\n%s" % "\n".join(msg))
|
|
|
|
|
def check_support_sve(__cache=[]):
    """
    Report whether ``lscpu`` lists the ``sve`` CPU feature (gh-22982).

    ``lscpu`` is run at most once. The outcome is memoised in the mutable
    default list ``__cache`` and returned on later calls. If the command
    cannot be run (OSError or SubprocessError), the feature is treated as
    unsupported.
    """
    if not __cache:
        import subprocess
        try:
            proc = subprocess.run('lscpu', capture_output=True, text=True)
        except (OSError, subprocess.SubprocessError):
            supported = False
        else:
            supported = 'sve' in proc.stdout
        __cache.append(supported)
    return __cache[0]
|
|
|
|
|
|
|
|
|
|
|
import unittest |
|
|
|
|
|
class _Dummy(unittest.TestCase): |
|
def nop(self): |
|
pass |
|
|
|
|
|
_d = _Dummy('nop') |
|
|
|
|
|
def assert_raises(*args, **kwargs):
    """
    assert_raises(exception_class, callable, *args, **kwargs)
    assert_raises(exception_class)

    Fail unless calling `callable` with positional arguments `args` and
    keyword arguments `kwargs` raises `exception_class`. An exception of any
    other type is not caught: it propagates, and the test is reported as
    erroring, as for any other unexpected exception.

    When only the exception class is given, `assert_raises` can be used as
    a context manager:

    >>> from numpy.testing import assert_raises
    >>> with assert_raises(ZeroDivisionError):
    ...     1 / 0

    which is equivalent to

    >>> def div(x, y):
    ...     return x / y
    >>> assert_raises(ZeroDivisionError, div, 1, 0)

    """
    __tracebackhide__ = True  # Hide traceback for py.test
    delegate = _d.assertRaises
    return delegate(*args, **kwargs)
|
|
|
|
|
def assert_raises_regex(exception_class, expected_regexp, *args, **kwargs):
    """
    assert_raises_regex(exception_class, expected_regexp, callable, *args,
                        **kwargs)
    assert_raises_regex(exception_class, expected_regexp)

    Fail unless calling `callable` with positional arguments `args` and
    keyword arguments `kwargs` raises `exception_class` with a message that
    matches `expected_regexp`.

    Like `assert_raises`, it can also be used as a context manager.
    """
    __tracebackhide__ = True  # Hide traceback for py.test
    delegate = _d.assertRaisesRegex
    return delegate(exception_class, expected_regexp, *args, **kwargs)
|
|
|
|
|
def decorate_methods(cls, decorator, testmatch=None): |
|
""" |
|
Apply a decorator to all methods in a class matching a regular expression. |
|
|
|
The given decorator is applied to all public methods of `cls` that are |
|
matched by the regular expression `testmatch` |
|
(``testmatch.search(methodname)``). Methods that are private, i.e. start |
|
with an underscore, are ignored. |
|
|
|
Parameters |
|
---------- |
|
cls : class |
|
Class whose methods to decorate. |
|
decorator : function |
|
Decorator to apply to methods |
|
testmatch : compiled regexp or str, optional |
|
The regular expression. Default value is None, in which case the |
|
nose default (``re.compile(r'(?:^|[\\b_\\.%s-])[Tt]est' % os.sep)``) |
|
is used. |
|
If `testmatch` is a string, it is compiled to a regular expression |
|
first. |
|
|
|
""" |
|
if testmatch is None: |
|
testmatch = re.compile(r'(?:^|[\\b_\\.%s-])[Tt]est' % os.sep) |
|
else: |
|
testmatch = re.compile(testmatch) |
|
cls_attr = cls.__dict__ |
|
|
|
|
|
from inspect import isfunction |
|
|
|
methods = [_m for _m in cls_attr.values() if isfunction(_m)] |
|
for function in methods: |
|
try: |
|
if hasattr(function, 'compat_func_name'): |
|
funcname = function.compat_func_name |
|
else: |
|
funcname = function.__name__ |
|
except AttributeError: |
|
|
|
continue |
|
if testmatch.search(funcname) and not funcname.startswith('_'): |
|
setattr(cls, funcname, decorator(function)) |
|
return |
|
|
|
|
|
def measure(code_str, times=1, label=None):
    """
    Time how long it takes to run a code string in the caller's namespace.

    `code_str` is compiled once with the builtin ``compile`` and then run
    `times` times with the caller's globals and locals. The timer resolves
    10 milliseconds, so very fast code should be repeated many times to get
    a meaningful figure.

    Parameters
    ----------
    code_str : str
        Source code to time.
    times : int, optional
        Number of executions; 1 by default. Compilation happens only once.
    label : str, optional
        Text identifying `code_str`. It becomes part of the filename passed
        as the second argument of ``compile``, so it appears in run-time
        error messages.

    Returns
    -------
    elapsed : float
        Total time in seconds taken by all `times` executions.

    Examples
    --------
    >>> times = 10
    >>> etime = np.testing.measure('for i in range(1000): np.sqrt(i**2)', times=times)
    >>> print("Time for a single execution : ", etime / times, "s")  # doctest: +SKIP
    Time for a single execution :  0.005 s

    """
    caller = sys._getframe(1)
    caller_locals = caller.f_locals
    caller_globals = caller.f_globals

    compiled = compile(code_str, f'Test name: {label} ', 'exec')
    runs = 0
    start = jiffies()
    while runs < times:
        runs += 1
        exec(compiled, caller_globals, caller_locals)
    # jiffies are hundredths of a second.
    return 0.01 * (jiffies() - start)
|
|
|
|
|
def _assert_valid_refcount(op):
    """
    Verify that ``op`` (e.g. a ufunc) does not drop references to the int
    object ``1``.

    ``op`` is applied 15 times to a 100x100 integer array and itself, with
    the garbage collector disabled. The refcount of ``1`` must not be lower
    afterwards than before. Returns True immediately on interpreters without
    reference counting. Used by a few regression tests.
    """
    if not HAS_REFCOUNT:
        return True

    import gc
    import numpy as np

    lhs = np.arange(100 * 100).reshape(100, 100)
    rhs = lhs
    one = 1

    gc.disable()
    try:
        baseline = sys.getrefcount(one)
        for _ in range(15):
            d = op(lhs, rhs)
        assert_(sys.getrefcount(one) >= baseline)
    finally:
        gc.enable()
    del d  # the result is intentionally unused; keeps pyflakes quiet
|
|
|
|
|
def assert_allclose(actual, desired, rtol=1e-7, atol=0, equal_nan=True,
                    err_msg='', verbose=True, *, strict=False):
    """
    Raise an AssertionError unless two objects are equal within a tolerance.

    Both array_like objects must have the same shape, and all elements must
    be equal within tolerance (see the Notes for how a scalar is handled).
    A shape mismatch or any conflicting value raises. Unlike ordinary NumPy
    comparisons, NaNs are compared like numbers: no assertion is raised when
    both objects have NaNs in the same positions.

    The check is equivalent to ``allclose(actual, desired, rtol, atol)``
    (``allclose`` has different defaults). The difference between `actual`
    and `desired` is compared to ``atol + rtol * abs(desired)``.

    Parameters
    ----------
    actual : array_like
        The array that was obtained.
    desired : array_like
        The array that was expected.
    rtol : float, optional
        Relative tolerance.
    atol : float, optional
        Absolute tolerance.
    equal_nan : bool, optional.
        When True, NaNs compare equal to each other.
    err_msg : str, optional
        Message included in the error raised on failure.
    verbose : bool, optional
        When True, the conflicting values are appended to the error message.
    strict : bool, optional
        When True, an ``AssertionError`` is raised if the shapes or the data
        types of the arguments differ, and the scalar handling described in
        the Notes is turned off.

        .. versionadded:: 2.0.0

    Raises
    ------
    AssertionError
        If actual and desired are not equal up to specified precision.

    See Also
    --------
    assert_array_almost_equal_nulp, assert_array_max_ulp

    Notes
    -----
    If one of `actual` and `desired` is a scalar and the other is
    array_like, the scalar is compared as though it had been broadcast to
    the array's shape. Pass ``strict=True`` to disable this.

    Examples
    --------
    >>> x = [1e-5, 1e-3, 1e-1]
    >>> y = np.arccos(np.cos(x))
    >>> np.testing.assert_allclose(x, y, rtol=1e-5, atol=0)

    Thanks to the scalar handling described in the Notes, the following
    checks that `numpy.sin` is nearly zero at integer multiples of π.

    >>> x = np.arange(3) * np.pi
    >>> np.testing.assert_allclose(np.sin(x), 0, atol=1e-15)

    ``strict=True`` turns comparing an array of one or more dimensions
    against a scalar into an ``AssertionError``.

    >>> np.testing.assert_allclose(np.sin(x), 0, atol=1e-15, strict=True)
    Traceback (most recent call last):
        ...
    AssertionError:
    Not equal to tolerance rtol=1e-07, atol=1e-15
    <BLANKLINE>
    (shapes (3,), () mismatch)
     ACTUAL: array([ 0.000000e+00,  1.224647e-16, -2.449294e-16])
     DESIRED: array(0)

    ``strict=True`` also requires the data types to match:

    >>> y = np.zeros(3, dtype=np.float32)
    >>> np.testing.assert_allclose(np.sin(x), y, atol=1e-15, strict=True)
    Traceback (most recent call last):
        ...
    AssertionError:
    Not equal to tolerance rtol=1e-07, atol=1e-15
    <BLANKLINE>
    (dtypes float64, float32 mismatch)
     ACTUAL: array([ 0.000000e+00,  1.224647e-16, -2.449294e-16])
     DESIRED: array([0., 0., 0.], dtype=float32)

    """
    __tracebackhide__ = True  # Hide traceback for py.test
    import numpy as np

    def within_tolerance(x, y):
        return np._core.numeric.isclose(x, y, rtol=rtol, atol=atol,
                                        equal_nan=equal_nan)

    actual = np.asanyarray(actual)
    desired = np.asanyarray(desired)
    header = f'Not equal to tolerance rtol={rtol:g}, atol={atol:g}'
    assert_array_compare(
        within_tolerance, actual, desired,
        err_msg=str(err_msg),
        verbose=verbose,
        header=header,
        equal_nan=equal_nan,
        strict=strict,
    )
|
|
|
|
|
def assert_array_almost_equal_nulp(x, y, nulp=1):
    """
    Compare two arrays relatively to their spacing.

    This is a relatively robust method to compare two arrays whose amplitude
    is variable.

    Parameters
    ----------
    x, y : array_like
        Input arrays. Plain sequences such as lists are accepted as well as
        ndarrays (and ndarray subclasses, which are preserved).
    nulp : int, optional
        The maximum number of unit in the last place for tolerance (see Notes).
        Default is 1.

    Returns
    -------
    None

    Raises
    ------
    AssertionError
        If the spacing between `x` and `y` for one or more elements is larger
        than `nulp`.

    See Also
    --------
    assert_array_max_ulp : Check that all items of arrays differ in at most
        N Units in the Last Place.
    spacing : Return the distance between x and the nearest adjacent number.

    Notes
    -----
    An assertion is raised if the following condition is not met::

        abs(x - y) <= nulp * spacing(maximum(abs(x), abs(y)))

    Examples
    --------
    >>> x = np.array([1., 1e-10, 1e-20])
    >>> eps = np.finfo(x.dtype).eps
    >>> np.testing.assert_array_almost_equal_nulp(x, x*eps/2 + x)

    >>> np.testing.assert_array_almost_equal_nulp(x, x*eps + x)
    Traceback (most recent call last):
      ...
    AssertionError: Arrays are not equal to 1 ULP (max is 2)

    """
    __tracebackhide__ = True  # Hide traceback for py.test
    import numpy as np
    # Convert up front: ``x - y`` below is not defined for Python lists, so
    # array_like inputs other than ndarrays used to fail with TypeError.
    x = np.asanyarray(x)
    y = np.asanyarray(y)
    ax = np.abs(x)
    ay = np.abs(y)
    # Tolerance is `nulp` spacings of the larger magnitude, element-wise.
    ref = nulp * np.spacing(np.where(ax > ay, ax, ay))
    if not np.all(np.abs(x - y) <= ref):
        if np.iscomplexobj(x) or np.iscomplexobj(y):
            # nulp_diff does not support complex input.
            msg = f"Arrays are not equal to {nulp} ULP"
        else:
            max_nulp = np.max(nulp_diff(x, y))
            msg = f"Arrays are not equal to {nulp} ULP (max is {max_nulp:g})"
        raise AssertionError(msg)
|
|
|
|
|
def assert_array_max_ulp(a, b, maxulp=1, dtype=None):
    """
    Assert that corresponding items of two arrays are at most `maxulp`
    units in the last place apart.

    Parameters
    ----------
    a, b : array_like
        The arrays to compare.
    maxulp : int, optional
        Largest allowed distance, in units in the last place, between
        elements of `a` and `b`. 1 by default.
    dtype : dtype, optional
        If given, `a` and `b` are converted to this data type first.
        None by default.

    Returns
    -------
    ret : ndarray
        For each pair of items, the number of representable floating point
        numbers separating them.

    Raises
    ------
    AssertionError
        If any pair of elements is more than `maxulp` apart.

    Notes
    -----
    The ULP computation does not distinguish between different NaN bit
    patterns (the ULP difference between 0x7fc00000 and 0xffc00000 is
    zero).

    See Also
    --------
    assert_array_almost_equal_nulp : Compare two arrays relatively to their
        spacing.

    Examples
    --------
    >>> a = np.linspace(0., 1., 100)
    >>> res = np.testing.assert_array_max_ulp(a, np.arcsin(np.sin(a)))

    """
    __tracebackhide__ = True  # Hide traceback for py.test
    import numpy as np
    ulps = nulp_diff(a, b, dtype)
    if np.all(ulps <= maxulp):
        return ulps
    raise AssertionError("Arrays are not almost equal up to %g "
                         "ULP (max difference is %g ULP)" %
                         (maxulp, np.max(ulps)))
|
|
|
|
|
def nulp_diff(x, y, dtype=None):
    """For each item in x and y, return the number of representable floating
    points between them.

    Parameters
    ----------
    x : array_like
        first input array
    y : array_like
        second input array
    dtype : dtype, optional
        Data-type to convert `x` and `y` to if given. Default is None.

    Returns
    -------
    nulp : array_like
        number of representable floating point numbers between each item in x
        and y. The inputs are wrapped in an extra leading axis of length 1
        before being compared, so the result carries that axis too.

    Raises
    ------
    NotImplementedError
        If either input is complex.
    ValueError
        If `x` and `y` do not have the same shape, or (from `integer_repr`)
        if their common type is not float16, float32 or float64.

    Notes
    -----
    For computing the ULP difference, this API does not differentiate between
    various representations of NAN (ULP difference between 0x7fc00000 and 0xffc00000
    is zero).

    Examples
    --------
    ``np.finfo(np.float64).eps`` is the gap between 1.0 and the next larger
    double, so exactly one ULP separates ``1`` and ``1 + eps``:

    >>> from numpy.testing._private.utils import nulp_diff
    >>> nulp_diff(1, 1 + np.finfo(np.float64).eps)
    array([1.])
    """
    import numpy as np
    # Compare with None explicitly: dtype *instances* such as
    # np.dtype('float32') are falsy (their len() is the number of fields),
    # so ``if dtype:`` silently ignored them.
    if dtype is not None:
        x = np.asarray(x, dtype=dtype)
        y = np.asarray(y, dtype=dtype)
    else:
        x = np.asarray(x)
        y = np.asarray(y)

    t = np.common_type(x, y)
    if np.iscomplexobj(x) or np.iscomplexobj(y):
        raise NotImplementedError("_nulp not implemented for complex array")

    # Copy into arrays of the common float type (with a leading axis).
    x = np.array([x], dtype=t)
    y = np.array([y], dtype=t)

    # Canonicalise every NaN to the same bit pattern so NaN payloads and
    # signs do not contribute to the ULP distance.
    x[np.isnan(x)] = np.nan
    y[np.isnan(y)] = np.nan

    if not x.shape == y.shape:
        raise ValueError("Arrays do not have the same shape: %s - %s" %
                         (x.shape, y.shape))

    def _diff(rx, ry, vdt):
        diff = np.asarray(rx - ry, dtype=vdt)
        return np.abs(diff)

    rx = integer_repr(x)
    ry = integer_repr(y)
    return _diff(rx, ry, t)
|
|
|
|
|
def _integer_repr(x, vdt, comp): |
|
|
|
|
|
|
|
|
|
rx = x.view(vdt) |
|
if not (rx.size == 1): |
|
rx[rx < 0] = comp - rx[rx < 0] |
|
else: |
|
if rx < 0: |
|
rx = comp - rx |
|
|
|
return rx |
|
|
|
|
|
def integer_repr(x):
    """Return the signed-magnitude interpretation of the binary representation
    of x.

    Supported dtypes are float16, float32 and float64. Any other dtype
    raises ValueError.
    """
    import numpy as np
    # (float dtype, same-width signed int, sign-bit position)
    layouts = (
        (np.float16, np.int16, 15),
        (np.float32, np.int32, 31),
        (np.float64, np.int64, 63),
    )
    for float_type, int_type, sign_bit in layouts:
        if x.dtype == float_type:
            return _integer_repr(x, int_type, int_type(-2**sign_bit))
    raise ValueError(f'Unsupported dtype {x.dtype}')
|
|
|
|
|
@contextlib.contextmanager
def _assert_warns_context(warning_class, name=None):
    """Context manager asserting that `warning_class` is emitted in its body.

    `name`, if given, is mentioned in the failure message.
    """
    __tracebackhide__ = True  # Hide traceback for py.test
    with suppress_warnings() as sup:
        recorded = sup.record(warning_class)
        yield
        if len(recorded) == 0:
            suffix = '' if name is None else f' when calling {name}'
            raise AssertionError("No warning raised" + suffix)
|
|
|
|
|
def assert_warns(warning_class, *args, **kwargs):
    """
    Fail unless the given callable emits the specified warning.

    The callable is invoked with positional arguments `args` and keyword
    arguments `kwargs` and must emit a warning of class `warning_class`.
    Warnings of other types are not caught.

    Called with only the warning class, it returns a context manager::

        with assert_warns(SomeWarning):
            do_something()

    Context-manager use was added in NumPy v1.11.0.

    Parameters
    ----------
    warning_class : class
        The warning class that `func` is expected to emit.
    func : callable, optional
        The callable under test.
    *args : Arguments
        Positional arguments for `func`.
    **kwargs : Kwargs
        Keyword arguments for `func`.

    Returns
    -------
    The value returned by `func`.

    Raises
    ------
    RuntimeError
        If keyword arguments are given without a callable. The message
        points to ``pytest.warns`` when ``match`` is among them.

    Examples
    --------
    >>> import warnings
    >>> def deprecated_func(num):
    ...     warnings.warn("Please upgrade", DeprecationWarning)
    ...     return num*num
    >>> with np.testing.assert_warns(DeprecationWarning):
    ...     assert deprecated_func(4) == 16
    >>> # or passing a func
    >>> ret = np.testing.assert_warns(DeprecationWarning, deprecated_func, 4)
    >>> assert ret == 16
    """
    if not args:
        if not kwargs:
            return _assert_warns_context(warning_class)
        if "match" in kwargs:
            raise RuntimeError(
                "assert_warns does not use 'match' kwarg, "
                "use pytest.warns instead"
            )
        raise RuntimeError("assert_warns(...) needs at least one arg")

    func, *func_args = args
    with _assert_warns_context(warning_class, name=func.__name__):
        return func(*func_args, **kwargs)
|
|
|
|
|
@contextlib.contextmanager |
|
def _assert_no_warnings_context(name=None): |
|
__tracebackhide__ = True |
|
with warnings.catch_warnings(record=True) as l: |
|
warnings.simplefilter('always') |
|
yield |
|
if len(l) > 0: |
|
name_str = f' when calling {name}' if name is not None else '' |
|
raise AssertionError(f'Got warnings{name_str}: {l}') |
|
|
|
|
|
def assert_no_warnings(*args, **kwargs):
    """
    Fail if the given callable emits any warning.

    Called with no arguments, it returns a context manager::

        with assert_no_warnings():
            do_something()

    Context-manager use was added in NumPy v1.11.0.

    Parameters
    ----------
    func : callable
        The callable under test.
    \\*args : Arguments
        Positional arguments for `func`.
    \\*\\*kwargs : Kwargs
        Keyword arguments for `func`.

    Returns
    -------
    The value returned by `func`.

    """
    if args:
        func, *call_args = args
        with _assert_no_warnings_context(name=func.__name__):
            return func(*call_args, **kwargs)
    return _assert_no_warnings_context()
|
|
|
|
|
def _gen_alignment_data(dtype=float32, type='binary', max_size=24): |
|
""" |
|
generator producing data with different alignment and offsets |
|
to test simd vectorization |
|
|
|
Parameters |
|
---------- |
|
dtype : dtype |
|
data type to produce |
|
type : string |
|
'unary': create data for unary operations, creates one input |
|
and output array |
|
'binary': create data for unary operations, creates two input |
|
and output array |
|
max_size : integer |
|
maximum size of data to produce |
|
|
|
Returns |
|
------- |
|
if type is 'unary' yields one output, one input array and a message |
|
containing information on the data |
|
if type is 'binary' yields one output array, two input array and a message |
|
containing information on the data |
|
|
|
""" |
|
ufmt = 'unary offset=(%d, %d), size=%d, dtype=%r, %s' |
|
bfmt = 'binary offset=(%d, %d, %d), size=%d, dtype=%r, %s' |
|
for o in range(3): |
|
for s in range(o + 2, max(o + 3, max_size)): |
|
if type == 'unary': |
|
inp = lambda: arange(s, dtype=dtype)[o:] |
|
out = empty((s,), dtype=dtype)[o:] |
|
yield out, inp(), ufmt % (o, o, s, dtype, 'out of place') |
|
d = inp() |
|
yield d, d, ufmt % (o, o, s, dtype, 'in place') |
|
yield out[1:], inp()[:-1], ufmt % \ |
|
(o + 1, o, s - 1, dtype, 'out of place') |
|
yield out[:-1], inp()[1:], ufmt % \ |
|
(o, o + 1, s - 1, dtype, 'out of place') |
|
yield inp()[:-1], inp()[1:], ufmt % \ |
|
(o, o + 1, s - 1, dtype, 'aliased') |
|
yield inp()[1:], inp()[:-1], ufmt % \ |
|
(o + 1, o, s - 1, dtype, 'aliased') |
|
if type == 'binary': |
|
inp1 = lambda: arange(s, dtype=dtype)[o:] |
|
inp2 = lambda: arange(s, dtype=dtype)[o:] |
|
out = empty((s,), dtype=dtype)[o:] |
|
yield out, inp1(), inp2(), bfmt % \ |
|
(o, o, o, s, dtype, 'out of place') |
|
d = inp1() |
|
yield d, d, inp2(), bfmt % \ |
|
(o, o, o, s, dtype, 'in place1') |
|
d = inp2() |
|
yield d, inp1(), d, bfmt % \ |
|
(o, o, o, s, dtype, 'in place2') |
|
yield out[1:], inp1()[:-1], inp2()[:-1], bfmt % \ |
|
(o + 1, o, o, s - 1, dtype, 'out of place') |
|
yield out[:-1], inp1()[1:], inp2()[:-1], bfmt % \ |
|
(o, o + 1, o, s - 1, dtype, 'out of place') |
|
yield out[:-1], inp1()[:-1], inp2()[1:], bfmt % \ |
|
(o, o, o + 1, s - 1, dtype, 'out of place') |
|
yield inp1()[1:], inp1()[:-1], inp2()[:-1], bfmt % \ |
|
(o + 1, o, o, s - 1, dtype, 'aliased') |
|
yield inp1()[:-1], inp1()[1:], inp2()[:-1], bfmt % \ |
|
(o, o + 1, o, s - 1, dtype, 'aliased') |
|
yield inp1()[:-1], inp1()[:-1], inp2()[1:], bfmt % \ |
|
(o, o, o + 1, s - 1, dtype, 'aliased') |
|
|
|
|
|
class IgnoreException(Exception):
    """Exception for conditions to be ignored because a feature is disabled."""
|
|
|
|
|
@contextlib.contextmanager
def tempdir(*args, **kwargs):
    """Context manager yielding the path of a fresh temporary directory.

    All arguments are forwarded unchanged to `tempfile.mkdtemp`. The
    directory and everything in it are deleted when the context exits,
    including when an exception is raised.
    """
    path = mkdtemp(*args, **kwargs)
    try:
        yield path
    finally:
        shutil.rmtree(path)
|
|
|
|
|
@contextlib.contextmanager
def temppath(*args, **kwargs):
    """Context manager yielding the path of a closed temporary file.

    The arguments are forwarded unchanged to `tempfile.mkstemp`. The file
    handle returned by ``mkstemp`` is closed immediately, so only the path
    is handed out. The file is deleted on exit, so it should be closed by
    then.

    On Windows a temporary file cannot be opened again while it is still
    open. That is why the file is closed after creation, before the path is
    yielded for reopening.
    """
    handle, path = mkstemp(*args, **kwargs)
    os.close(handle)
    try:
        yield path
    finally:
        os.remove(path)
|
|
|
|
|
class clear_and_catch_warnings(warnings.catch_warnings):
    """ Context manager that resets warning registry for catching warnings

    Whenever a warning fires, Python stores a ``__warningregistry__`` entry
    on the *calling* module.  That entry can prevent the same warning from
    being triggered again in that module, no matter how the warnings filters
    are configured.  This context manager takes a sequence of `modules` as a
    keyword argument and:

    * on entry, saves and then empties the ``__warningregistry__`` of each
      given module that has one;
    * on exit, empties those registries again and restores the saved copies.

    Inside the context any warning can therefore be triggered afresh, and the
    registry state seen outside the context is left as it was.

    For compatibility with Python 3.0, please treat all arguments as
    keyword-only.

    Parameters
    ----------
    record : bool, optional
        Passed to ``warnings.catch_warnings``.  When true, warnings are
        captured by a custom ``warnings.showwarning()`` and appended to the
        list returned by the context manager; otherwise the context manager
        returns None.  The list items have attributes mirroring the
        arguments of ``showwarning()``.
    modules : sequence, optional
        Modules whose warning registry is reset on entry and restored on
        exit.  The class attribute ``class_modules`` is always added to
        this set.  For correct behaviour, every 'ignore' filter should be
        restricted to one of these modules.

    Examples
    --------
    >>> import warnings
    >>> with np.testing.clear_and_catch_warnings(
    ...         modules=[np._core.fromnumeric]):
    ...     warnings.simplefilter('always')
    ...     warnings.filterwarnings('ignore', module='numpy._core.fromnumeric')
    ...     # do something that raises a warning but ignore those in
    ...     # numpy._core.fromnumeric
    """
    # Subclasses may list modules that should always be handled.
    class_modules = ()

    def __init__(self, record=False, modules=()):
        self.modules = set(modules).union(self.class_modules)
        # Maps module -> copy of its registry taken in __enter__.
        self._warnreg_copies = {}
        super().__init__(record=record)

    def __enter__(self):
        registered = [module for module in self.modules
                      if hasattr(module, '__warningregistry__')]
        for module in registered:
            registry = module.__warningregistry__
            self._warnreg_copies[module] = registry.copy()
            registry.clear()
        return super().__enter__()

    def __exit__(self, *exc_info):
        super().__exit__(*exc_info)
        for module in self.modules:
            if not hasattr(module, '__warningregistry__'):
                continue
            registry = module.__warningregistry__
            # Drop anything recorded inside the context, then put back
            # whatever was saved on entry (if anything was).
            registry.clear()
            if module in self._warnreg_copies:
                registry.update(self._warnreg_copies[module])
|
|
|
|
|
class suppress_warnings:
    """
    Context manager and decorator doing much the same as
    ``warnings.catch_warnings``.

    However, it also provides a filter mechanism to work around
    https://bugs.python.org/issue4180.

    This bug causes Python before 3.4 to not reliably show warnings again
    after they have been ignored once (even within catch_warnings). It
    means that no "ignore" filter can be used easily, since following
    tests might need to see the warning. Additionally it allows easier
    specificity for testing warnings and can be nested.

    Parameters
    ----------
    forwarding_rule : str, optional
        One of "always", "once", "module", or "location". Analogous to
        the usual warnings module filter mode, it is useful to reduce
        noise mostly on the outmost level. Unsuppressed and unrecorded
        warnings will be forwarded based on this rule. Defaults to "always".
        "location" is equivalent to the warnings "default", match by exact
        location the warning originated from.

    Notes
    -----
    Filters added inside the context manager will be discarded again
    when leaving it. Upon entering all filters defined outside a
    context will be applied automatically.

    When a recording filter is added, matching warnings are stored in the
    ``log`` attribute as well as in the list returned by ``record``.

    If filters are added and the ``module`` keyword is given, the
    warning registry of this module will additionally be cleared when
    applying it, entering the context, or exiting it. This could cause
    warnings to appear a second time after leaving the context if they
    were configured to be printed once (default) and were already
    printed before the context was entered.

    Nesting this context manager will work as expected when the
    forwarding rule is "always" (default). Unfiltered and unrecorded
    warnings will be passed out and be matched by the outer level.
    On the outmost level they will be printed (or caught by another
    warnings context). The forwarding rule argument can modify this
    behaviour.

    Like ``catch_warnings`` this context manager is not threadsafe.

    Examples
    --------

    With a context manager::

        with np.testing.suppress_warnings() as sup:
            sup.filter(DeprecationWarning, "Some text")
            sup.filter(module=np.ma.core)
            log = sup.record(FutureWarning, "Does this occur?")
            command_giving_warnings()
            # The FutureWarning was given once, the filtered warnings were
            # ignored. All other warnings abide outside settings (may be
            # printed/error)
            assert_(len(log) == 1)
            assert_(len(sup.log) == 1)  # also stored in log attribute

    Or as a decorator::

        sup = np.testing.suppress_warnings()
        sup.filter(module=np.ma.core)  # module must match exactly
        @sup
        def some_function():
            # do something which causes a warning in np.ma.core
            pass
    """
    def __init__(self, forwarding_rule="always"):
        self._entered = False

        # Suppressions registered outside any ``with`` block; they are
        # re-applied every time the context is entered.  Each entry is
        # (category, message, compiled_pattern, module, record_list_or_None).
        self._suppressions = []

        if forwarding_rule not in {"always", "module", "once", "location"}:
            raise ValueError("unsupported forwarding rule.")
        self._forwarding_rule = forwarding_rule

    def _clear_registries(self):
        """Make sure freshly installed filters take effect immediately."""
        if hasattr(warnings, "_filters_mutated"):
            # When this private hook exists, notify the warnings machinery
            # that the filters changed instead of clearing registries.
            warnings._filters_mutated()
            return
        # Fallback: clear the registries of all modules used by filters.
        for module in self._tmp_modules:
            if hasattr(module, "__warningregistry__"):
                module.__warningregistry__.clear()

    def _filter(self, category=Warning, message="", module=None, record=False):
        """Register a suppression; install it right away if entered.

        Returns the list that will collect matching warnings when `record`
        is true, otherwise None.
        """
        if record:
            record = []  # The log where to store warnings
        else:
            record = None
        if self._entered:
            if module is None:
                warnings.filterwarnings(
                    "always", category=category, message=message)
            else:
                # Anchor the module name so submodules do not match.
                module_regex = module.__name__.replace('.', r'\.') + '$'
                warnings.filterwarnings(
                    "always", category=category, message=message,
                    module=module_regex)
                self._tmp_modules.add(module)
                self._clear_registries()

            # Filters added inside the context are forgotten on exit.
            self._tmp_suppressions.append(
                (category, message, re.compile(message, re.I), module, record))
        else:
            self._suppressions.append(
                (category, message, re.compile(message, re.I), module, record))

        return record

    def filter(self, category=Warning, message="", module=None):
        """
        Add a new suppressing filter or apply it if the state is entered.

        Parameters
        ----------
        category : class, optional
            Warning class to filter
        message : string, optional
            Regular expression matching the warning message.
        module : module, optional
            Module to filter for. Note that the module (and its file)
            must match exactly and cannot be a submodule. This may make
            it unreliable for external modules.

        Notes
        -----
        When added within a context, filters are only added inside
        the context and will be forgotten when the context is exited.
        """
        self._filter(category=category, message=message, module=module,
                     record=False)

    def record(self, category=Warning, message="", module=None):
        """
        Append a new recording filter or apply it if the state is entered.

        All warnings matching will be appended to the ``log`` attribute.

        Parameters
        ----------
        category : class, optional
            Warning class to filter
        message : string, optional
            Regular expression matching the warning message.
        module : module, optional
            Module to filter for. Note that the module (and its file)
            must match exactly and cannot be a submodule. This may make
            it unreliable for external modules.

        Returns
        -------
        log : list
            A list which will be filled with all matched warnings.

        Notes
        -----
        When added within a context, filters are only added inside
        the context and will be forgotten when the context is exited.
        """
        return self._filter(category=category, message=message, module=module,
                            record=True)

    def __enter__(self):
        if self._entered:
            raise RuntimeError("cannot enter suppress_warnings twice.")

        # Save the global hook and filter list; work on a copy of the list
        # so everything added inside the context is discarded on exit.
        self._orig_show = warnings.showwarning
        self._filters = warnings.filters
        warnings.filters = self._filters[:]

        self._entered = True
        self._tmp_suppressions = []
        self._tmp_modules = set()
        self._forwarded = set()

        self.log = []  # reset global log (no need to keep same list)

        for cat, mess, _, mod, log in self._suppressions:
            if log is not None:
                del log[:]  # clear the log
            if mod is None:
                warnings.filterwarnings(
                    "always", category=cat, message=mess)
            else:
                module_regex = mod.__name__.replace('.', r'\.') + '$'
                warnings.filterwarnings(
                    "always", category=cat, message=mess,
                    module=module_regex)
                self._tmp_modules.add(mod)
        warnings.showwarning = self._showwarning
        self._clear_registries()

        return self

    def __exit__(self, *exc_info):
        warnings.showwarning = self._orig_show
        warnings.filters = self._filters
        self._clear_registries()
        self._entered = False
        del self._orig_show
        del self._filters

    @staticmethod
    def _message_text(message):
        """Return the text that suppression patterns are matched against.

        This is normally ``message.args[0]``.  Warnings created without
        arguments (``UserWarning()``) or with a non-string first argument
        fall back to ``str(message)``.  Previously these raised IndexError
        or TypeError from inside ``warnings.warn``.
        """
        args = message.args
        if args and isinstance(args[0], str):
            return args[0]
        return str(message)

    def _forward(self, message, category, filename, lineno, args, kwargs,
                 use_warnmsg):
        """Pass an unsuppressed warning to the hook active on entry."""
        if use_warnmsg is None:
            self._orig_show(message, category, filename, lineno,
                            *args, **kwargs)
        else:
            # No separate "showwarnmsg" hook is saved on entry, so unpack
            # the WarningMessage and route it through the saved
            # ``showwarning`` instead.
            self._orig_show(use_warnmsg.message, use_warnmsg.category,
                            use_warnmsg.filename, use_warnmsg.lineno,
                            use_warnmsg.file, use_warnmsg.line)

    def _showwarning(self, message, category, filename, lineno,
                     *args, use_warnmsg=None, **kwargs):
        text = self._message_text(message)
        # Most recently added suppressions take precedence.
        for cat, _, pattern, mod, rec in (
                self._suppressions + self._tmp_suppressions)[::-1]:
            if (issubclass(category, cat) and
                    pattern.match(text) is not None):
                if mod is None:
                    # Message and category match, either recorded or ignored
                    if rec is not None:
                        msg = WarningMessage(message, category, filename,
                                             lineno, **kwargs)
                        self.log.append(msg)
                        rec.append(msg)
                    return
                # The module's file path must start with the reported
                # filename (prefix match, presumably to tolerate compiled
                # file suffixes).
                elif mod.__file__.startswith(filename):
                    # The message and module (filename) match
                    if rec is not None:
                        msg = WarningMessage(message, category, filename,
                                             lineno, **kwargs)
                        self.log.append(msg)
                        rec.append(msg)
                    return

        # There is no filter in place, so pass to the outside handler
        # unless we should only pass it once
        if self._forwarding_rule == "always":
            self._forward(message, category, filename, lineno, args, kwargs,
                          use_warnmsg)
            return

        if self._forwarding_rule == "once":
            signature = (message.args, category)
        elif self._forwarding_rule == "module":
            signature = (message.args, category, filename)
        elif self._forwarding_rule == "location":
            signature = (message.args, category, filename, lineno)

        if signature in self._forwarded:
            return
        self._forwarded.add(signature)
        self._forward(message, category, filename, lineno, args, kwargs,
                      use_warnmsg)

    def __call__(self, func):
        """
        Function decorator to apply certain suppressions to a whole
        function.
        """
        @wraps(func)
        def new_func(*args, **kwargs):
            with self:
                return func(*args, **kwargs)

        return new_func
|
|
|
|
|
@contextlib.contextmanager
def _assert_no_gc_cycles_context(name=None):
    """Context manager asserting that its body creates no reference cycles.

    The garbage collector is disabled and drained before the body runs.
    The body then runs with ``gc.DEBUG_SAVEALL``, and a final collection
    reveals objects that only cycles kept alive.  If any are found,
    AssertionError is raised, listing them.  When ``HAS_REFCOUNT`` is
    false the body simply runs without any check.

    name : str, optional
        Included in the failure message as "when calling <name>".
    """
    __tracebackhide__ = True  # Hide traceback for py.test

    if not HAS_REFCOUNT:
        yield
        return

    assert_(gc.isenabled())
    gc.disable()
    saved_debug_flags = gc.get_debug()
    try:
        # Drain pre-existing garbage so only cycles from the body remain.
        for _attempt in range(100):
            if gc.collect() == 0:
                break
        else:
            raise RuntimeError(
                "Unable to fully collect garbage - perhaps a __del__ method "
                "is creating more reference cycles?")

        # DEBUG_SAVEALL keeps collected objects in gc.garbage for reporting.
        gc.set_debug(gc.DEBUG_SAVEALL)
        yield
        collected_count = gc.collect()
        cyclic_objects = gc.garbage[:]
    finally:
        # Always restore the collector's original state.
        del gc.garbage[:]
        gc.set_debug(saved_debug_flags)
        gc.enable()

    if not collected_count:
        return

    suffix = f' when calling {name}' if name is not None else ''
    details = ''.join(
        "\n {} object with id={}:\n {}".format(
            type(obj).__name__,
            id(obj),
            pprint.pformat(obj).replace('\n', '\n ')
        ) for obj in cyclic_objects
    )
    raise AssertionError(
        "Reference cycles were found{}: {} objects were collected, "
        "of which {} are shown below:{}"
        .format(suffix, collected_count, len(cyclic_objects), details)
    )
|
|
|
|
|
def assert_no_gc_cycles(*args, **kwargs):
    """
    Fail if the given callable produces any reference cycles.

    If called with all arguments omitted, may be used as a context manager::

        with assert_no_gc_cycles():
            do_something()

    Parameters
    ----------
    func : callable
        The callable to test.
    \\*args : Arguments
        Arguments passed to `func`.
    \\*\\*kwargs : Kwargs
        Keyword arguments passed to `func`.

    Returns
    -------
    Nothing. The result is deliberately discarded to ensure that all cycles
    are found.

    """
    if not args:
        return _assert_no_gc_cycles_context()

    func = args[0]
    args = args[1:]
    # Callables such as functools.partial objects have no __name__; use
    # their repr so the failure message still identifies them instead of
    # crashing with AttributeError.
    name = getattr(func, '__name__', None)
    if name is None:
        name = repr(func)
    with _assert_no_gc_cycles_context(name=name):
        func(*args, **kwargs)
|
|
|
|
|
def break_cycles():
    """
    Break reference cycles by running the garbage collector.

    A ``__del__`` method may call into other objects, including their own
    ``__del__``.  On PyPy the interpreter only runs those between calls to
    ``gc.collect``, so several collections are needed there before every
    cycle is released.  Elsewhere a single collection is performed.
    """
    gc.collect()
    if not IS_PYPY:
        return
    # Four extra passes on PyPy, one after another.
    for _ in range(4):
        gc.collect()
|
|
|
|
|
def requires_memory(free_bytes):
    """Decorator skipping a test unless `free_bytes` of memory are free.

    A MemoryError raised by the decorated test is reported through
    ``pytest.xfail`` instead of as an error.
    """
    import pytest

    def decorator(func):
        @wraps(func)
        def wrapper(*a, **kw):
            reason = check_free_memory(free_bytes)
            if reason is not None:
                pytest.skip(reason)

            try:
                return func(*a, **kw)
            except MemoryError:
                # Ran out of memory despite the check: expected failure.
                pytest.xfail("MemoryError raised")

        return wrapper

    return decorator
|
|
|
|
|
def check_free_memory(free_bytes):
    """
    Check whether `free_bytes` amount of memory is currently free.

    The available amount comes from the ``NPY_AVAILABLE_MEM`` environment
    variable when it is set (parsed with ``_parse_size``).  Otherwise it
    comes from ``_get_mem_available()``.

    Parameters
    ----------
    free_bytes : int or float
        Number of bytes required.

    Returns
    -------
    str or None
        None if enough memory is available, otherwise a message explaining
        why the requirement is not met.

    Raises
    ------
    ValueError
        If ``NPY_AVAILABLE_MEM`` is set but cannot be parsed.
    """
    env_var = 'NPY_AVAILABLE_MEM'
    env_value = os.environ.get(env_var)
    if env_value is not None:
        try:
            mem_free = _parse_size(env_value)
        except ValueError as exc:
            # Chain explicitly so the parse error shows up as the cause.
            raise ValueError(
                f'Invalid environment variable {env_var}: {exc}') from exc

        msg = (f'{free_bytes / 1e9} GB memory required, but environment variable '
               f'NPY_AVAILABLE_MEM={env_value} set')
    else:
        mem_free = _get_mem_available()

        if mem_free is None:
            msg = ("Could not determine available memory; set NPY_AVAILABLE_MEM "
                   "environment variable (e.g. NPY_AVAILABLE_MEM=16GB) to run "
                   "the test.")
            # Unknown memory is treated as "not enough" by the check below.
            mem_free = -1
        else:
            free_bytes_gb = free_bytes / 1e9
            mem_free_gb = mem_free / 1e9
            msg = f'{free_bytes_gb} GB memory required, but {mem_free_gb} GB available'

    return msg if mem_free < free_bytes else None
|
|
|
|
|
def _parse_size(size_str):
    """Convert a memory size string ('12 GB', '1.5GiB', ...) to an int.

    The result is a whole number of bytes.

    Decimal suffixes ``k``, ``m``, ``g``, ``t`` (optionally followed by
    ``b``) are powers of 1000.  Binary suffixes ``kib``, ``mib``, ``gib``,
    ``tib`` are powers of 1024.  A bare number or a ``b`` suffix means bytes.
    Matching ignores case and surrounding whitespace.  The number may be an
    integer or a decimal such as ``1.5``, ``.5`` or ``2.``.

    Raises
    ------
    ValueError
        If `size_str` is not a number followed by a known suffix.
    """
    suffixes = {'': 1, 'b': 1,
                'k': 1000, 'm': 1000**2, 'g': 1000**3, 't': 1000**4,
                'kb': 1000, 'mb': 1000**2, 'gb': 1000**3, 'tb': 1000**4,
                'kib': 1024, 'mib': 1024**2, 'gib': 1024**3, 'tib': 1024**4}

    # Longest suffixes first so e.g. 'kib' is tried before 'k'; the empty
    # suffix sorts last (sorted() is stable).
    suffix_alt = '|'.join(sorted(suffixes, key=len, reverse=True))
    size_re = re.compile(
        r'^\s*(\d+(?:\.\d*)?|\.\d+)\s*(' + suffix_alt + r')\s*$')

    # Lower-casing the input makes every matched suffix a key of `suffixes`.
    m = size_re.match(size_str.lower())
    if not m:
        raise ValueError(f'value {size_str!r} not a valid size')
    return int(float(m.group(1)) * suffixes[m.group(2)])
|
|
|
|
|
def _get_mem_available():
    """Return available memory in bytes, or None if unknown.

    Uses ``psutil`` when importable.  On Linux it falls back to parsing
    ``/proc/meminfo``, preferring ``MemAvailable`` and otherwise using
    ``MemFree + Cached``.  Returns None when none of these sources yields
    an answer, including when ``/proc/meminfo`` cannot be read.
    """
    try:
        import psutil
        return psutil.virtual_memory().available
    except (ImportError, AttributeError):
        pass

    if sys.platform.startswith('linux'):
        info = {}
        try:
            with open('/proc/meminfo') as f:
                for line in f:
                    fields = line.split()
                    # Expected shape: "Key: value [kB]"; skip anything else.
                    if len(fields) < 2:
                        continue
                    try:
                        value = int(fields[1])
                    except ValueError:
                        continue
                    info[fields[0].strip(':').lower()] = value * 1024
        except OSError:
            # e.g. a sandbox without /proc: availability is unknown.
            return None

        if 'memavailable' in info:
            return info['memavailable']
        if 'memfree' in info and 'cached' in info:
            return info['memfree'] + info['cached']

    return None
|
|
|
|
|
def _no_tracing(func):
    """
    Decorator that disables ``sys.settrace`` tracing while `func` runs.

    Tests that check reference counts need this, because an active tracer
    itself affects the counts.  The previous trace function is reinstated
    afterwards, even if `func` raises.  If ``sys`` has no ``gettrace``,
    `func` is returned unchanged.
    """
    if hasattr(sys, 'gettrace'):
        @wraps(func)
        def wrapper(*args, **kwargs):
            saved_tracer = sys.gettrace()
            try:
                sys.settrace(None)
                return func(*args, **kwargs)
            finally:
                sys.settrace(saved_tracer)
        return wrapper
    return func
|
|
|
|
|
def _get_glibc_version():
    """Return the glibc version string (e.g. '2.35'), or '0.0' if unknown."""
    try:
        ver = os.confstr('CS_GNU_LIBC_VERSION').rsplit(' ')[1]
    except Exception:
        # Best effort: non-glibc platforms, unknown confstr names, or an
        # unexpected format all mean "unknown".
        ver = '0.0'

    return ver


def _version_tuple(ver):
    """Turn a dotted version like '2.17' into (2, 17) for numeric comparison.

    Parsing stops at the first component that is not purely digits.
    """
    parts = []
    for piece in ver.split('.'):
        if not piece.isdecimal():
            break
        parts.append(int(piece))
    return tuple(parts)


_glibcver = _get_glibc_version()


def _glibc_older_than(x):
    """Return True if the detected glibc is strictly older than version `x`.

    Versions are compared numerically, component by component.  Comparing
    them as strings would claim '2.9' is newer than '2.17'.  Always False
    when the glibc version is unknown.
    """
    return (_glibcver != '0.0'
            and _version_tuple(_glibcver) < _version_tuple(x))
|
|
|
|
|
def run_threaded(func, max_workers=8, pass_count=False,
                 pass_barrier=False, outer_iterations=1,
                 prepare_args=None):
    """Call `func` once per worker, concurrently, `outer_iterations` times.

    Each round gets a fresh thread pool of `max_workers` threads.
    Extra positional arguments come from ``prepare_args()`` (a list) when
    given.  With `pass_barrier`, a shared ``threading.Barrier(max_workers)``
    is appended to those arguments.  With `pass_count`, each call receives
    its worker index as the first argument.  Exceptions raised by `func`
    are re-raised through ``Future.result()``.
    """
    for _round in range(outer_iterations):
        with concurrent.futures.ThreadPoolExecutor(
                max_workers=max_workers) as pool:
            extra = [] if prepare_args is None else prepare_args()
            if pass_barrier:
                barrier = threading.Barrier(max_workers)
                extra.append(barrier)

            if pass_count:
                calls = [(func, idx, *extra) for idx in range(max_workers)]
            else:
                calls = [(func, *extra)] * max_workers

            futures = []
            try:
                for call in calls:
                    futures.append(pool.submit(*call))
            finally:
                # If not every worker was submitted, threads already waiting
                # on the barrier would presumably never be released; abort it.
                if pass_barrier and len(futures) < max_workers:
                    barrier.abort()

            for fut in futures:
                fut.result()
|
|
|
|
|
def get_stringdtype_dtype(na_object, coerce=True):
    """Build a ``StringDType``, passing `na_object` unless it is "unset".

    The sentinel string ``"unset"`` means no ``na_object`` is passed to the
    constructor.  ``pd_NA`` is tested by identity before the ``!=``
    comparison, because comparing it with a string does not produce a plain
    bool.
    """
    pass_na = na_object is pd_NA or na_object != "unset"
    if not pass_na:
        return np.dtypes.StringDType(coerce=coerce)
    return np.dtypes.StringDType(na_object=na_object, coerce=coerce)
|
|