|
import warnings |
|
import sys |
|
import os |
|
import itertools |
|
import pytest |
|
import weakref |
|
import re |
|
|
|
import numpy as np |
|
import numpy._core._multiarray_umath as ncu |
|
from numpy.testing import ( |
|
assert_equal, assert_array_equal, assert_almost_equal, |
|
assert_array_almost_equal, assert_array_less, build_err_msg, |
|
assert_raises, assert_warns, assert_no_warnings, assert_allclose, |
|
assert_approx_equal, assert_array_almost_equal_nulp, assert_array_max_ulp, |
|
clear_and_catch_warnings, suppress_warnings, assert_string_equal, assert_, |
|
tempdir, temppath, assert_no_gc_cycles, HAS_REFCOUNT |
|
) |
|
|
|
|
|
class _GenericTest:
    """Equality checks shared by the concrete assertion test classes.

    Every method goes through ``self._assert_func``; the subclasses in this
    file bind it in ``setup_method``.
    """

    def _test_equal(self, a, b):
        """Require the assertion under test to accept ``a`` and ``b``."""
        self._assert_func(a, b)

    def _test_not_equal(self, a, b):
        """Require the assertion under test to reject ``a`` and ``b``."""
        assert_raises(AssertionError, self._assert_func, a, b)

    def test_array_rank1_eq(self):
        """Two rank-1 arrays with the same contents compare equal."""
        self._test_equal(np.array([1, 2]), np.array([1, 2]))

    def test_array_rank1_noteq(self):
        """Two rank-1 arrays with different contents compare unequal."""
        self._test_not_equal(np.array([1, 2]), np.array([2, 2]))

    def test_array_rank2_eq(self):
        """Two rank-2 arrays with the same contents compare equal."""
        rows = [[1, 2], [3, 4]]
        self._test_equal(np.array(rows), np.array(rows))

    def test_array_diffshape(self):
        """Arrays whose shapes differ compare unequal."""
        flat = np.array([1, 2])
        stacked = np.array([[1, 2], [1, 2]])
        self._test_not_equal(flat, stacked)

    def test_objarray(self):
        """An object array of ones compares equal to the scalar 1."""
        self._test_equal(np.array([1, 1], dtype=object), 1)

    def test_array_likes(self):
        """A list and a tuple with the same items compare equal."""
        self._test_equal([1, 2, 3], (1, 2, 3))
|
|
|
|
|
class TestArrayEqual(_GenericTest):
    """Generic equality checks plus ``assert_array_equal``-specific cases."""

    def setup_method(self):
        """Select ``assert_array_equal`` as the assertion under test."""
        self._assert_func = assert_array_equal

    def _check_constant_arrays(self, shape):
        """Compare constant-filled arrays of ``shape`` across many dtypes.

        For every dtype an array filled with 1 must equal its copy, and an
        array filled with 0 must differ from it.
        """
        # Single-character numeric/bool type codes first, then one-character
        # bytes and str dtypes -- the same order the rank tests always used.
        for t in [*'?bhilqpBHILQPfdgFDG', 'S1', 'U1']:
            a = np.empty(shape, t)
            a.fill(1)
            b = a.copy()
            c = a.copy()
            c.fill(0)
            self._test_equal(a, b)
            self._test_not_equal(c, b)

    def test_generic_rank1(self):
        """Test rank 1 array for all dtypes."""
        self._check_constant_arrays(2)

    def test_0_ndim_array(self):
        """Test 0-d arrays built from very large integers and from a float."""
        x = np.array(473963742225900817127911193656584771)
        y = np.array(18535119325151578301457182298393896)

        with pytest.raises(AssertionError) as exc_info:
            self._assert_func(x, y)
        msg = str(exc_info.value)
        assert_('Mismatched elements: 1 / 1 (100%)\n'
                in msg)

        y = x
        self._assert_func(x, y)

        x = np.array(4395065348745.5643764887869876)
        y = np.array(0)
        expected_msg = ('Mismatched elements: 1 / 1 (100%)\n'
                        'Max absolute difference among violations: '
                        '4.39506535e+12\n'
                        'Max relative difference among violations: inf\n')
        with pytest.raises(AssertionError, match=re.escape(expected_msg)):
            self._assert_func(x, y)

        x = y
        self._assert_func(x, y)

    def test_generic_rank3(self):
        """Test rank 3 array for all dtypes."""
        self._check_constant_arrays((4, 2, 3))

    def test_nan_array(self):
        """Test arrays with nan values in them."""
        a = np.array([1, 2, np.nan])
        b = np.array([1, 2, np.nan])

        self._test_equal(a, b)

        c = np.array([1, 2, 3])
        self._test_not_equal(c, b)

    def test_string_arrays(self):
        """Test string arrays with equal and with differing contents."""
        a = np.array(['floupi', 'floupa'])
        b = np.array(['floupi', 'floupa'])

        self._test_equal(a, b)

        c = np.array(['floupipi', 'floupa'])

        self._test_not_equal(c, b)

    def test_recarrays(self):
        """Test record arrays."""
        a = np.empty(2, [('floupi', float), ('floupa', float)])
        a['floupi'] = [1, 2]
        a['floupa'] = [1, 2]
        b = a.copy()

        self._test_equal(a, b)

        c = np.empty(2, [('floupipi', float),
                         ('floupi', float), ('floupa', float)])
        c['floupipi'] = a['floupi'].copy()
        c['floupa'] = a['floupa'].copy()

        # Arrays with a different set of fields are expected to raise
        # TypeError rather than to be reported as merely unequal.
        with pytest.raises(TypeError):
            self._test_not_equal(c, b)

    def test_masked_nan_inf(self):
        """Masked entries compare equal to nan/inf in a plain array."""
        a = np.ma.MaskedArray([3., 4., 6.5], mask=[False, True, False])
        b = np.array([3., np.nan, 6.5])
        self._test_equal(a, b)
        self._test_equal(b, a)
        a = np.ma.MaskedArray([3., 4., 6.5], mask=[True, False, False])
        b = np.array([np.inf, 4., 6.5])
        self._test_equal(a, b)
        self._test_equal(b, a)

    def test_subclass_that_overrides_eq(self):
        """Test a subclass whose ``==`` / ``!=`` return a plain ``bool``."""
        class MyArray(np.ndarray):
            def __eq__(self, other):
                return bool(np.equal(self, other).all())

            def __ne__(self, other):
                return not self == other

        a = np.array([1., 2.]).view(MyArray)
        b = np.array([2., 3.]).view(MyArray)
        # ``assert_`` takes (value, message); passing ``bool`` as the second
        # argument would make this check vacuous, so compare the type here.
        assert_(type(a == a) is bool)
        assert_(a == a)
        assert_(a != b)
        self._test_equal(a, a)
        self._test_not_equal(a, b)
        self._test_not_equal(b, a)

        expected_msg = ('Mismatched elements: 1 / 2 (50%)\n'
                        'Max absolute difference among violations: 1.\n'
                        'Max relative difference among violations: 0.5')
        with pytest.raises(AssertionError, match=re.escape(expected_msg)):
            self._test_equal(a, b)

        c = np.array([0., 2.9]).view(MyArray)
        expected_msg = ('Mismatched elements: 1 / 2 (50%)\n'
                        'Max absolute difference among violations: 2.\n'
                        'Max relative difference among violations: inf')
        with pytest.raises(AssertionError, match=re.escape(expected_msg)):
            self._test_equal(b, c)

    def test_subclass_that_does_not_implement_npall(self):
        """Test a subclass for which ``np.all`` raises ``TypeError``."""
        class MyArray(np.ndarray):
            def __array_function__(self, *args, **kwargs):
                return NotImplemented

        a = np.array([1., 2.]).view(MyArray)
        b = np.array([2., 3.]).view(MyArray)
        # Confirm the premise first: np.all is unusable on this subclass.
        with assert_raises(TypeError):
            np.all(a)
        self._test_equal(a, a)
        self._test_not_equal(a, b)
        self._test_not_equal(b, a)

    def test_suppress_overflow_warnings(self):
        """A mismatch must surface as AssertionError with errstate 'raise'."""
        # 1e-40 lies below the smallest normal float32; even with every
        # floating-point error escalated, only AssertionError is expected.
        with pytest.raises(AssertionError):
            with np.errstate(all="raise"):
                np.testing.assert_array_equal(
                    np.array([1, 2, 3], np.float32),
                    np.array([1, 1e-40, 3], np.float32))

    def test_array_vs_scalar_is_equal(self):
        """Test comparing an array with a scalar when all values are equal."""
        a = np.array([1., 1., 1.])
        b = 1.

        self._test_equal(a, b)

    def test_array_vs_array_not_equal(self):
        """Test comparing two arrays when not all values are equal."""
        a = np.array([34986, 545676, 439655, 563766])
        b = np.array([34986, 545676, 439655, 0])

        expected_msg = ('Mismatched elements: 1 / 4 (25%)\n'
                        'Max absolute difference among violations: 563766\n'
                        'Max relative difference among violations: inf')
        with pytest.raises(AssertionError, match=re.escape(expected_msg)):
            self._assert_func(a, b)

        a = np.array([34986, 545676, 439655.2, 563766])
        expected_msg = ('Mismatched elements: 2 / 4 (50%)\n'
                        'Max absolute difference among violations: '
                        '563766.\n'
                        'Max relative difference among violations: '
                        '4.54902139e-07')
        with pytest.raises(AssertionError, match=re.escape(expected_msg)):
            self._assert_func(a, b)

    def test_array_vs_scalar_strict(self):
        """Test comparing an array with a scalar with strict option."""
        a = np.array([1., 1., 1.])
        b = 1.

        with pytest.raises(AssertionError):
            self._assert_func(a, b, strict=True)

    def test_array_vs_array_strict(self):
        """Test comparing two arrays with strict option."""
        a = np.array([1., 1., 1.])
        b = np.array([1., 1., 1.])

        self._assert_func(a, b, strict=True)

    def test_array_vs_float_array_strict(self):
        """Test comparing an integer and a float array with strict option."""
        a = np.array([1, 1, 1])
        b = np.array([1., 1., 1.])

        with pytest.raises(AssertionError):
            self._assert_func(a, b, strict=True)
|
|
|
|
|
class TestBuildErrorMessage:
    """Tests for the text assembled by ``build_err_msg``."""

    def test_build_err_msg_defaults(self):
        """Default call: header, message, then both arrays labelled."""
        x = np.array([1.00001, 2.00002, 3.00003])
        y = np.array([1.00002, 2.00003, 3.00004])
        err_msg = 'There is a mismatch'

        a = build_err_msg([x, y], err_msg)
        b = ('\nItems are not equal: There is a mismatch\n ACTUAL: array(['
             '1.00001, 2.00002, 3.00003])\n DESIRED: array([1.00002, '
             '2.00003, 3.00004])')
        assert_equal(a, b)

    def test_build_err_msg_no_verbose(self):
        """``verbose=False`` leaves the arrays out of the message."""
        x = np.array([1.00001, 2.00002, 3.00003])
        y = np.array([1.00002, 2.00003, 3.00004])
        err_msg = 'There is a mismatch'

        a = build_err_msg([x, y], err_msg, verbose=False)
        b = '\nItems are not equal: There is a mismatch'
        assert_equal(a, b)

    def test_build_err_msg_custom_names(self):
        """``names`` replaces the labels printed in front of each array."""
        x = np.array([1.00001, 2.00002, 3.00003])
        y = np.array([1.00002, 2.00003, 3.00004])
        err_msg = 'There is a mismatch'

        a = build_err_msg([x, y], err_msg, names=('FOO', 'BAR'))
        b = ('\nItems are not equal: There is a mismatch\n FOO: array(['
             '1.00001, 2.00002, 3.00003])\n BAR: array([1.00002, 2.00003, '
             '3.00004])')
        assert_equal(a, b)

    def test_build_err_msg_custom_precision(self):
        """``precision`` controls how many digits the array repr shows."""
        x = np.array([1.000000001, 2.00002, 3.00003])
        y = np.array([1.000000002, 2.00003, 3.00004])
        err_msg = 'There is a mismatch'

        a = build_err_msg([x, y], err_msg, precision=10)
        # The array repr pads shorter elements with trailing spaces so the
        # columns line up with the 9-decimal first element; that padding is
        # part of the expected text.
        b = ('\nItems are not equal: There is a mismatch\n ACTUAL: array(['
             '1.000000001, 2.00002    , 3.00003    ])\n DESIRED: array(['
             '1.000000002, 2.00003    , 3.00004    ])')
        assert_equal(a, b)
|
|
|
|
|
class TestEqual(TestArrayEqual):
    """Re-run the array-equality tests with ``assert_equal``, plus scalar cases."""

    def setup_method(self):
        """Select ``assert_equal`` as the assertion under test."""
        self._assert_func = assert_equal

    def test_nan_items(self):
        """nan matches nan, scalar or wrapped, but not across the two forms."""
        nan = np.nan
        self._assert_func(nan, nan)
        self._assert_func([nan], [nan])
        self._test_not_equal(nan, [nan])
        self._test_not_equal(nan, 1)

    def test_inf_items(self):
        """inf matches inf, scalar or wrapped, but not across the two forms."""
        inf = np.inf
        self._assert_func(inf, inf)
        self._assert_func([inf], [inf])
        self._test_not_equal(inf, [inf])

    def test_datetime(self):
        """datetime64 values compare by instant, whatever their unit."""
        first_day_s = np.datetime64("2017-01-01", "s")

        # Same instant, same and different unit.
        self._test_equal(first_day_s, np.datetime64("2017-01-01", "s"))
        self._test_equal(first_day_s, np.datetime64("2017-01-01", "m"))

        # A different instant, same and different unit.
        self._test_not_equal(first_day_s, np.datetime64("2017-01-02", "s"))
        self._test_not_equal(first_day_s, np.datetime64("2017-01-02", "m"))

    def test_nat_items(self):
        """NaT matches NaT of the same kind only (datetime vs timedelta)."""
        dts = [np.datetime64("NaT"),
               np.datetime64("NaT", "s"),
               np.datetime64("NaT", "ns")]
        tds = [np.timedelta64("NaT"),
               np.timedelta64("NaT", "s"),
               np.timedelta64("NaT", "ns")]

        # Within one kind the unit is irrelevant, but scalar vs list is not.
        for same_kind in (dts, tds):
            for a, b in itertools.product(same_kind, repeat=2):
                self._assert_func(a, b)
                self._assert_func([a], [b])
                self._test_not_equal([a], b)

        real_values = (np.datetime64("2017-01-01", "s"),
                       np.timedelta64(123, "s"))
        for a, b in itertools.product(tds, dts):
            self._test_not_equal(a, b)
            self._test_not_equal(a, [b])
            self._test_not_equal([a], [b])
            # A NaT of either kind never matches a real value.
            for real in real_values:
                for nat in (a, b):
                    self._test_not_equal([nat], real)

    def test_non_numeric(self):
        """Plain strings compare by content."""
        self._assert_func('ab', 'ab')
        self._test_not_equal('ab', 'abb')

    def test_complex_item(self):
        """Complex scalars compare per component, nan included."""
        self._assert_func(complex(1, 2), complex(1, 2))
        self._assert_func(complex(1, np.nan), complex(1, np.nan))
        mismatches = [
            (complex(1, np.nan), complex(1, 2)),
            (complex(np.nan, 1), complex(1, np.nan)),
            (complex(np.nan, np.inf), complex(np.nan, 2)),
        ]
        for actual, desired in mismatches:
            self._test_not_equal(actual, desired)

    def test_negative_zero(self):
        """Positive and negative zero are reported as different."""
        self._test_not_equal(ncu.PZERO, ncu.NZERO)

    def test_complex(self):
        """Complex arrays: nan components match themselves, not real values."""
        with_nan = np.array([complex(1, 2), complex(1, np.nan)])
        without_nan = np.array([complex(1, 2), complex(1, 2)])
        self._assert_func(with_nan, with_nan)
        self._test_not_equal(with_nan, without_nan)

    def test_object(self):
        """An array of datetime objects differs from its reversal."""
        import datetime
        days = np.array([datetime.datetime(2000, 1, day) for day in (1, 2)])
        self._test_not_equal(days, days[::-1])
|
|
|
|
|
class TestArrayAlmostEqual(_GenericTest):
    """Generic checks plus cases specific to ``assert_array_almost_equal``."""

    def setup_method(self):
        """Select ``assert_array_almost_equal`` as the assertion under test."""
        self._assert_func = assert_array_almost_equal

    def test_closeness(self):
        """With ``decimal=0`` a gap of 1.499999 passes and 1.5 fails."""
        expected_msg = ('Mismatched elements: 1 / 1 (100%)\n'
                        'Max absolute difference among violations: 1.5\n'
                        'Max relative difference among violations: inf')
        pattern = re.escape(expected_msg)

        # Scalars.
        with pytest.raises(AssertionError, match=pattern):
            self._assert_func(1.5, 0.0, decimal=0)

        # One-element sequences: same threshold, same message.
        self._assert_func([1.499999], [0.0], decimal=0)
        with pytest.raises(AssertionError, match=pattern):
            self._assert_func([1.5], [0.0], decimal=0)

        a = [1.4999999, 0.00003]
        b = [1.49999991, 0]
        expected_msg = ('Mismatched elements: 1 / 2 (50%)\n'
                        'Max absolute difference among violations: 3.e-05\n'
                        'Max relative difference among violations: inf')
        with pytest.raises(AssertionError, match=re.escape(expected_msg)):
            self._assert_func(a, b, decimal=7)

        # Swapping the arguments changes only the relative difference.
        expected_msg = ('Mismatched elements: 1 / 2 (50%)\n'
                        'Max absolute difference among violations: 3.e-05\n'
                        'Max relative difference among violations: 1.')
        with pytest.raises(AssertionError, match=re.escape(expected_msg)):
            self._assert_func(b, a, decimal=7)

    def test_simple(self):
        """Values 1e-4 apart pass at 3 and 4 decimals and fail at 5."""
        x = np.array([1234.2222])
        y = np.array([1234.2223])

        for digits in (3, 4):
            self._assert_func(x, y, decimal=digits)

        expected_msg = ('Mismatched elements: 1 / 1 (100%)\n'
                        'Max absolute difference among violations: '
                        '1.e-04\n'
                        'Max relative difference among violations: '
                        '8.10226812e-08')
        with pytest.raises(AssertionError, match=re.escape(expected_msg)):
            self._assert_func(x, y, decimal=5)

    def test_array_vs_scalar(self):
        """Failure messages when one side is a scalar, in both orders."""
        a = [5498.42354, 849.54345, 0.00]
        b = 5498.42354
        expected_msg = ('Mismatched elements: 2 / 3 (66.7%)\n'
                        'Max absolute difference among violations: '
                        '5498.42354\n'
                        'Max relative difference among violations: 1.')
        with pytest.raises(AssertionError, match=re.escape(expected_msg)):
            self._assert_func(a, b, decimal=9)

        expected_msg = ('Mismatched elements: 2 / 3 (66.7%)\n'
                        'Max absolute difference among violations: '
                        '5498.42354\n'
                        'Max relative difference among violations: 5.4722099')
        with pytest.raises(AssertionError, match=re.escape(expected_msg)):
            self._assert_func(b, a, decimal=9)

        a = [5498.42354, 0.00]
        expected_msg = ('Mismatched elements: 1 / 2 (50%)\n'
                        'Max absolute difference among violations: '
                        '5498.42354\n'
                        'Max relative difference among violations: inf')
        with pytest.raises(AssertionError, match=re.escape(expected_msg)):
            self._assert_func(b, a, decimal=7)

        b = 0
        expected_msg = ('Mismatched elements: 1 / 2 (50%)\n'
                        'Max absolute difference among violations: '
                        '5498.42354\n'
                        'Max relative difference among violations: inf')
        with pytest.raises(AssertionError, match=re.escape(expected_msg)):
            self._assert_func(a, b, decimal=7)

    def test_nan(self):
        """nan only matches nan."""
        anan = np.array([np.nan])
        aone = np.array([1])
        ainf = np.array([np.inf])
        self._assert_func(anan, anan)
        for actual, desired in ((anan, aone), (anan, ainf), (ainf, anan)):
            assert_raises(AssertionError, self._assert_func, actual, desired)

    def test_inf(self):
        """inf matches neither a finite value nor -inf."""
        a = np.array([[1., 2.], [3., 4.]])
        b = a.copy()
        a[0, 0] = np.inf
        assert_raises(AssertionError, self._assert_func, a, b)
        b[0, 0] = -np.inf
        assert_raises(AssertionError, self._assert_func, a, b)

    def test_subclass(self):
        """Masked entries never count as mismatches."""
        plain = np.array([[1., 2.], [3., 4.]])
        masked = np.ma.masked_array([[1., 2.], [0., 4.]],
                                    [[False, False], [True, False]])
        self._assert_func(plain, masked)
        self._assert_func(masked, plain)
        self._assert_func(masked, masked)

        # Fully masked operands of several shapes, each paired with the
        # plain array it is compared against (in both argument orders).
        fully_masked_cases = [
            (np.ma.MaskedArray(3.5, mask=True), np.array([3., 4., 6.5])),
            (np.ma.masked, np.array([3., 4., 6.5])),
            (np.ma.MaskedArray([3., 4., 6.5], mask=[True, True, True]),
             np.array([1., 2., 3.])),
            (np.ma.MaskedArray([3., 4., 6.5], mask=[True, True, True]),
             np.array(1.)),
        ]
        for a, b in fully_masked_cases:
            self._test_equal(a, b)
            self._test_equal(b, a)

    def test_subclass_2(self):
        """A subclass with plain-ndarray comparisons and a custom ``all``."""
        class MyArray(np.ndarray):
            def __eq__(self, other):
                result = super().__eq__(other)
                return result.view(np.ndarray)

            def __lt__(self, other):
                result = super().__lt__(other)
                return result.view(np.ndarray)

            def all(self, *args, **kwargs):
                return all(self)

        a = np.array([1., 2.]).view(MyArray)
        self._assert_func(a, a)

        # The builtin all() has to work on the subclass as well.
        z = np.array([True, True]).view(MyArray)
        all(z)
        b = np.array([1., 202]).view(MyArray)
        expected_msg = ('Mismatched elements: 1 / 2 (50%)\n'
                        'Max absolute difference among violations: 200.\n'
                        'Max relative difference among violations: 0.99009')
        with pytest.raises(AssertionError, match=re.escape(expected_msg)):
            self._assert_func(a, b)

    def test_subclass_that_cannot_be_bool(self):
        """The assertion must pass even when the subclass's ``all`` raises."""
        class MyArray(np.ndarray):
            def __eq__(self, other):
                result = super().__eq__(other)
                return result.view(np.ndarray)

            def __lt__(self, other):
                result = super().__lt__(other)
                return result.view(np.ndarray)

            def all(self, *args, **kwargs):
                raise NotImplementedError

        a = np.array([1., 2.]).view(MyArray)
        self._assert_func(a, a)
|
|
|
|
|
class TestAlmostEqual(_GenericTest):
    """Generic checks plus cases specific to ``assert_almost_equal``."""

    def setup_method(self):
        """Select ``assert_almost_equal`` as the assertion under test."""
        self._assert_func = assert_almost_equal

    def test_closeness(self):
        """With ``decimal=0`` a gap of 1.499999 passes and 1.5 fails."""
        # Scalars.
        self._assert_func(1.499999, 0.0, decimal=0)
        assert_raises(AssertionError,
                      lambda: self._assert_func(1.5, 0.0, decimal=0))

        # One-element sequences.
        self._assert_func([1.499999], [0.0], decimal=0)
        assert_raises(AssertionError,
                      lambda: self._assert_func([1.5], [0.0], decimal=0))

    def test_nan_item(self):
        """nan only matches nan."""
        self._assert_func(np.nan, np.nan)
        assert_raises(AssertionError,
                      lambda: self._assert_func(np.nan, 1))
        assert_raises(AssertionError,
                      lambda: self._assert_func(np.nan, np.inf))
        assert_raises(AssertionError,
                      lambda: self._assert_func(np.inf, np.nan))

    def test_inf_item(self):
        """inf only matches inf of the same sign."""
        self._assert_func(np.inf, np.inf)
        self._assert_func(-np.inf, -np.inf)
        assert_raises(AssertionError,
                      lambda: self._assert_func(np.inf, 1))
        assert_raises(AssertionError,
                      lambda: self._assert_func(-np.inf, np.inf))

    def test_simple_item(self):
        """Two different integers are rejected."""
        self._test_not_equal(1, 2)

    def test_complex_item(self):
        """Complex scalars compare per component, nan and inf included."""
        self._assert_func(complex(1, 2), complex(1, 2))
        self._assert_func(complex(1, np.nan), complex(1, np.nan))
        self._assert_func(complex(np.inf, np.nan), complex(np.inf, np.nan))
        self._test_not_equal(complex(1, np.nan), complex(1, 2))
        self._test_not_equal(complex(np.nan, 1), complex(1, np.nan))
        self._test_not_equal(complex(np.nan, np.inf), complex(np.nan, 2))

    def test_complex(self):
        """Complex arrays: a nan component must sit in the same part."""
        x = np.array([complex(1, 2), complex(1, np.nan)])
        z = np.array([complex(1, 2), complex(np.nan, 1)])
        y = np.array([complex(1, 2), complex(1, 2)])
        self._assert_func(x, x)
        self._test_not_equal(x, y)
        self._test_not_equal(x, z)

    def test_error_message(self):
        """Check the message is formatted correctly for the decimal value.
           Also check the message when input includes inf or nan (gh12200)"""
        x = np.array([1.00000000001, 2.00000000002, 3.00003])
        y = np.array([1.00000000002, 2.00000000003, 3.00004])

        # decimal=12: every element differs and the arrays are printed with
        # enough digits to show it.  The array repr pads the short third
        # element with trailing spaces; the padding is part of the message.
        expected_msg = ('Mismatched elements: 3 / 3 (100%)\n'
                        'Max absolute difference among violations: 1.e-05\n'
                        'Max relative difference among violations: '
                        '3.33328889e-06\n'
                        ' ACTUAL: array([1.00000000001, '
                        '2.00000000002, '
                        '3.00003      ])\n'
                        ' DESIRED: array([1.00000000002, 2.00000000003, '
                        '3.00004      ])')
        with pytest.raises(AssertionError, match=re.escape(expected_msg)):
            self._assert_func(x, y, decimal=12)

        # Default decimal: only the third element differs, and the first two
        # print as "1." / "2." padded to the width of the third.
        expected_msg = ('Mismatched elements: 1 / 3 (33.3%)\n'
                        'Max absolute difference among violations: 1.e-05\n'
                        'Max relative difference among violations: '
                        '3.33328889e-06\n'
                        ' ACTUAL: array([1.     , 2.     , 3.00003])\n'
                        ' DESIRED: array([1.     , 2.     , 3.00004])')
        with pytest.raises(AssertionError, match=re.escape(expected_msg)):
            self._assert_func(x, y)

        # Input containing inf: the finite element is left-padded to the
        # width of "inf" in the array repr.
        x = np.array([np.inf, 0])
        y = np.array([np.inf, 1])
        expected_msg = ('Mismatched elements: 1 / 2 (50%)\n'
                        'Max absolute difference among violations: 1.\n'
                        'Max relative difference among violations: 1.\n'
                        ' ACTUAL: array([inf,  0.])\n'
                        ' DESIRED: array([inf,  1.])')
        with pytest.raises(AssertionError, match=re.escape(expected_msg)):
            self._assert_func(x, y)

        # Desired values of zero: the relative difference is reported as inf.
        x = np.array([1, 2])
        y = np.array([0, 0])
        expected_msg = ('Mismatched elements: 2 / 2 (100%)\n'
                        'Max absolute difference among violations: 2\n'
                        'Max relative difference among violations: inf')
        with pytest.raises(AssertionError, match=re.escape(expected_msg)):
            self._assert_func(x, y)

    def test_error_message_2(self):
        """Check the message is formatted correctly
        when either x or y is a scalar."""
        x = 2
        y = np.ones(20)
        expected_msg = ('Mismatched elements: 20 / 20 (100%)\n'
                        'Max absolute difference among violations: 1.\n'
                        'Max relative difference among violations: 1.')
        with pytest.raises(AssertionError, match=re.escape(expected_msg)):
            self._assert_func(x, y)

        y = 2
        x = np.ones(20)
        expected_msg = ('Mismatched elements: 20 / 20 (100%)\n'
                        'Max absolute difference among violations: 1.\n'
                        'Max relative difference among violations: 0.5')
        with pytest.raises(AssertionError, match=re.escape(expected_msg)):
            self._assert_func(x, y)

    def test_subclass_that_cannot_be_bool(self):
        """The assertion must pass even when the subclass's ``all`` raises."""
        class MyArray(np.ndarray):
            def __eq__(self, other):
                return super().__eq__(other).view(np.ndarray)

            def __lt__(self, other):
                return super().__lt__(other).view(np.ndarray)

            def all(self, *args, **kwargs):
                raise NotImplementedError

        a = np.array([1., 2.]).view(MyArray)
        self._assert_func(a, a)
|
|
|
|
|
class TestApproxEqual:
    """Tests for ``assert_approx_equal`` (comparison by significant digits)."""

    def setup_method(self):
        """Select ``assert_approx_equal`` as the assertion under test."""
        self._assert_func = assert_approx_equal

    def test_simple_0d_arrays(self):
        """0-d arrays 0.01 apart agree to 6 significant digits, not to 7."""
        x = np.array(1234.22)
        y = np.array(1234.23)

        self._assert_func(x, y, significant=5)
        self._assert_func(x, y, significant=6)
        assert_raises(AssertionError,
                      lambda: self._assert_func(x, y, significant=7))

    def test_simple_items(self):
        """Plain floats 0.01 apart agree to 6 significant digits, not to 7."""
        x = 1234.22
        y = 1234.23

        self._assert_func(x, y, significant=4)
        self._assert_func(x, y, significant=5)
        self._assert_func(x, y, significant=6)
        assert_raises(AssertionError,
                      lambda: self._assert_func(x, y, significant=7))

    def test_nan_array(self):
        """0-d arrays: nan only matches nan."""
        anan = np.array(np.nan)
        aone = np.array(1)
        ainf = np.array(np.inf)
        self._assert_func(anan, anan)
        assert_raises(AssertionError, lambda: self._assert_func(anan, aone))
        assert_raises(AssertionError, lambda: self._assert_func(anan, ainf))
        assert_raises(AssertionError, lambda: self._assert_func(ainf, anan))

    def test_nan_items(self):
        """Scalar items: nan only matches nan."""
        # Plain scalars on purpose: the 0-d array inputs are already covered
        # by test_nan_array, so this test exercises the item path.
        anan = np.nan
        aone = 1
        ainf = np.inf
        self._assert_func(anan, anan)
        assert_raises(AssertionError, lambda: self._assert_func(anan, aone))
        assert_raises(AssertionError, lambda: self._assert_func(anan, ainf))
        assert_raises(AssertionError, lambda: self._assert_func(ainf, anan))
|
|
|
|
|
class TestArrayAssertLess: |
|
|
|
def setup_method(self): |
|
self._assert_func = assert_array_less |
|
|
|
def test_simple_arrays(self): |
|
x = np.array([1.1, 2.2]) |
|
y = np.array([1.2, 2.3]) |
|
|
|
self._assert_func(x, y) |
|
assert_raises(AssertionError, lambda: self._assert_func(y, x)) |
|
|
|
y = np.array([1.0, 2.3]) |
|
|
|
assert_raises(AssertionError, lambda: self._assert_func(x, y)) |
|
assert_raises(AssertionError, lambda: self._assert_func(y, x)) |
|
|
|
a = np.array([1, 3, 6, 20]) |
|
b = np.array([2, 4, 6, 8]) |
|
|
|
expected_msg = ('Mismatched elements: 2 / 4 (50%)\n' |
|
'Max absolute difference among violations: 12\n' |
|
'Max relative difference among violations: 1.5') |
|
with pytest.raises(AssertionError, match=re.escape(expected_msg)): |
|
self._assert_func(a, b) |
|
|
|
def test_rank2(self): |
|
x = np.array([[1.1, 2.2], [3.3, 4.4]]) |
|
y = np.array([[1.2, 2.3], [3.4, 4.5]]) |
|
|
|
self._assert_func(x, y) |
|
expected_msg = ('Mismatched elements: 4 / 4 (100%)\n' |
|
'Max absolute difference among violations: 0.1\n' |
|
'Max relative difference among violations: 0.09090909') |
|
with pytest.raises(AssertionError, match=re.escape(expected_msg)): |
|
self._assert_func(y, x) |
|
|
|
y = np.array([[1.0, 2.3], [3.4, 4.5]]) |
|
assert_raises(AssertionError, lambda: self._assert_func(x, y)) |
|
assert_raises(AssertionError, lambda: self._assert_func(y, x)) |
|
|
|
def test_rank3(self): |
|
x = np.ones(shape=(2, 2, 2)) |
|
y = np.ones(shape=(2, 2, 2))+1 |
|
|
|
self._assert_func(x, y) |
|
assert_raises(AssertionError, lambda: self._assert_func(y, x)) |
|
|
|
y[0, 0, 0] = 0 |
|
expected_msg = ('Mismatched elements: 1 / 8 (12.5%)\n' |
|
'Max absolute difference among violations: 1.\n' |
|
'Max relative difference among violations: inf') |
|
with pytest.raises(AssertionError, match=re.escape(expected_msg)): |
|
self._assert_func(x, y) |
|
|
|
assert_raises(AssertionError, lambda: self._assert_func(y, x)) |
|
|
|
def test_simple_items(self): |
|
x = 1.1 |
|
y = 2.2 |
|
|
|
self._assert_func(x, y) |
|
expected_msg = ('Mismatched elements: 1 / 1 (100%)\n' |
|
'Max absolute difference among violations: 1.1\n' |
|
'Max relative difference among violations: 1.') |
|
with pytest.raises(AssertionError, match=re.escape(expected_msg)): |
|
self._assert_func(y, x) |
|
|
|
y = np.array([2.2, 3.3]) |
|
|
|
self._assert_func(x, y) |
|
assert_raises(AssertionError, lambda: self._assert_func(y, x)) |
|
|
|
y = np.array([1.0, 3.3]) |
|
|
|
assert_raises(AssertionError, lambda: self._assert_func(x, y)) |
|
|
|
def test_simple_items_and_array(self): |
|
x = np.array([[621.345454, 390.5436, 43.54657, 626.4535], |
|
[54.54, 627.3399, 13., 405.5435], |
|
[543.545, 8.34, 91.543, 333.3]]) |
|
y = 627.34 |
|
self._assert_func(x, y) |
|
|
|
y = 8.339999 |
|
self._assert_func(y, x) |
|
|
|
x = np.array([[3.4536, 2390.5436, 435.54657, 324525.4535], |
|
[5449.54, 999090.54, 130303.54, 405.5435], |
|
[543.545, 8.34, 91.543, 999090.53999]]) |
|
y = 999090.54 |
|
|
|
expected_msg = ('Mismatched elements: 1 / 12 (8.33%)\n' |
|
'Max absolute difference among violations: 0.\n' |
|
'Max relative difference among violations: 0.') |
|
with pytest.raises(AssertionError, match=re.escape(expected_msg)): |
|
self._assert_func(x, y) |
|
|
|
expected_msg = ('Mismatched elements: 12 / 12 (100%)\n' |
|
'Max absolute difference among violations: ' |
|
'999087.0864\n' |
|
'Max relative difference among violations: ' |
|
'289288.5934676') |
|
with pytest.raises(AssertionError, match=re.escape(expected_msg)): |
|
self._assert_func(y, x) |
|
|
|
def test_zeroes(self): |
|
x = np.array([546456., 0, 15.455]) |
|
y = np.array(87654.) |
|
|
|
expected_msg = ('Mismatched elements: 1 / 3 (33.3%)\n' |
|
'Max absolute difference among violations: 458802.\n' |
|
'Max relative difference among violations: 5.23423917') |
|
with pytest.raises(AssertionError, match=re.escape(expected_msg)): |
|
self._assert_func(x, y) |
|
|
|
expected_msg = ('Mismatched elements: 2 / 3 (66.7%)\n' |
|
'Max absolute difference among violations: 87654.\n' |
|
'Max relative difference among violations: ' |
|
'5670.5626011') |
|
with pytest.raises(AssertionError, match=re.escape(expected_msg)): |
|
self._assert_func(y, x) |
|
|
|
y = 0 |
|
|
|
expected_msg = ('Mismatched elements: 3 / 3 (100%)\n' |
|
'Max absolute difference among violations: 546456.\n' |
|
'Max relative difference among violations: inf') |
|
with pytest.raises(AssertionError, match=re.escape(expected_msg)): |
|
self._assert_func(x, y) |
|
|
|
expected_msg = ('Mismatched elements: 1 / 3 (33.3%)\n' |
|
'Max absolute difference among violations: 0.\n' |
|
'Max relative difference among violations: inf') |
|
with pytest.raises(AssertionError, match=re.escape(expected_msg)): |
|
self._assert_func(y, x) |
|
|
|
def test_nan_noncompare(self): |
|
anan = np.array(np.nan) |
|
aone = np.array(1) |
|
ainf = np.array(np.inf) |
|
self._assert_func(anan, anan) |
|
assert_raises(AssertionError, lambda: self._assert_func(aone, anan)) |
|
assert_raises(AssertionError, lambda: self._assert_func(anan, aone)) |
|
assert_raises(AssertionError, lambda: self._assert_func(anan, ainf)) |
|
assert_raises(AssertionError, lambda: self._assert_func(ainf, anan)) |
|
|
|
def test_nan_noncompare_array(self): |
|
x = np.array([1.1, 2.2, 3.3]) |
|
anan = np.array(np.nan) |
|
|
|
assert_raises(AssertionError, lambda: self._assert_func(x, anan)) |
|
assert_raises(AssertionError, lambda: self._assert_func(anan, x)) |
|
|
|
x = np.array([1.1, 2.2, np.nan]) |
|
|
|
assert_raises(AssertionError, lambda: self._assert_func(x, anan)) |
|
assert_raises(AssertionError, lambda: self._assert_func(anan, x)) |
|
|
|
y = np.array([1.0, 2.0, np.nan]) |
|
|
|
self._assert_func(y, x) |
|
assert_raises(AssertionError, lambda: self._assert_func(x, y)) |
|
|
|
def test_inf_compare(self): |
|
aone = np.array(1) |
|
ainf = np.array(np.inf) |
|
|
|
self._assert_func(aone, ainf) |
|
self._assert_func(-ainf, aone) |
|
self._assert_func(-ainf, ainf) |
|
assert_raises(AssertionError, lambda: self._assert_func(ainf, aone)) |
|
assert_raises(AssertionError, lambda: self._assert_func(aone, -ainf)) |
|
assert_raises(AssertionError, lambda: self._assert_func(ainf, ainf)) |
|
assert_raises(AssertionError, lambda: self._assert_func(ainf, -ainf)) |
|
assert_raises(AssertionError, lambda: self._assert_func(-ainf, -ainf)) |
|
|
|
def test_inf_compare_array(self): |
|
x = np.array([1.1, 2.2, np.inf]) |
|
ainf = np.array(np.inf) |
|
|
|
assert_raises(AssertionError, lambda: self._assert_func(x, ainf)) |
|
assert_raises(AssertionError, lambda: self._assert_func(ainf, x)) |
|
assert_raises(AssertionError, lambda: self._assert_func(x, -ainf)) |
|
assert_raises(AssertionError, lambda: self._assert_func(-x, -ainf)) |
|
assert_raises(AssertionError, lambda: self._assert_func(-ainf, -x)) |
|
self._assert_func(-ainf, x) |
|
|
|
def test_strict(self): |
|
"""Test the behavior of the `strict` option.""" |
|
x = np.zeros(3) |
|
y = np.ones(()) |
|
self._assert_func(x, y) |
|
with pytest.raises(AssertionError): |
|
self._assert_func(x, y, strict=True) |
|
y = np.broadcast_to(y, x.shape) |
|
self._assert_func(x, y) |
|
with pytest.raises(AssertionError): |
|
self._assert_func(x, y.astype(np.float32), strict=True) |
|
|
|
|
|
class TestWarns:
    """Checks for ``assert_warns`` and ``assert_no_warnings``."""

    def test_warn(self):
        """Call form: results are forwarded and the filter list is unchanged."""
        def f():
            warnings.warn("yo")
            return 3

        before_filters = sys.modules['warnings'].filters[:]
        assert_equal(assert_warns(UserWarning, f), 3)
        # This is the live filter list, not a snapshot, so the comparison at
        # the end reflects the state after every call made in this test.
        after_filters = sys.modules['warnings'].filters

        assert_raises(AssertionError, assert_no_warnings, f)
        assert_equal(assert_no_warnings(lambda x: x, 1), 1)

        assert_equal(before_filters, after_filters,
                     "assert_warns does not preserve warnings state")

    def test_context_manager(self):
        """Context-manager form: same checks as ``test_warn``."""
        before_filters = sys.modules['warnings'].filters[:]
        with assert_warns(UserWarning):
            warnings.warn("yo")
        # Live reference, see test_warn.
        after_filters = sys.modules['warnings'].filters

        def no_warnings():
            with assert_no_warnings():
                warnings.warn("yo")

        assert_raises(AssertionError, no_warnings)
        assert_equal(before_filters, after_filters,
                     "assert_warns does not preserve warnings state")

    def test_args(self):
        """Keyword arguments are forwarded in call form and rejected otherwise."""
        def f(a=0, b=1):
            warnings.warn("yo")
            return a + b

        assert assert_warns(UserWarning, f, b=20) == 20

        # Inspect the raised exception itself (``exc.value``) rather than the
        # ExceptionInfo wrapper, so the checks look only at its message.
        with pytest.raises(RuntimeError) as exc:
            with assert_warns(UserWarning, match="A"):
                warnings.warn("B", UserWarning)
        assert "assert_warns" in str(exc.value)
        assert "pytest.warns" in str(exc.value)

        with pytest.raises(RuntimeError) as exc:
            with assert_warns(UserWarning, wrong="A"):
                warnings.warn("B", UserWarning)
        assert "assert_warns" in str(exc.value)
        assert "pytest.warns" not in str(exc.value)

    def test_warn_wrong_warning(self):
        """A warning of another category must not be swallowed by ``assert_warns``."""
        def f():
            warnings.warn("yo", DeprecationWarning)

        with warnings.catch_warnings():
            # Turn the DeprecationWarning into an exception so that it is
            # observable here if assert_warns lets it through.
            warnings.simplefilter("error", DeprecationWarning)
            try:
                assert_warns(UserWarning, f)
            except DeprecationWarning:
                pass
            else:
                raise AssertionError("wrong warning caught by assert_warn")
|
|
|
|
|
class TestAssertAllclose:
    """Behavioural checks for ``assert_allclose``."""

    @staticmethod
    def _fails_with(report, actual, desired, **kwargs):
        """Require ``assert_allclose`` to fail with ``report`` in its message."""
        with pytest.raises(AssertionError, match=re.escape(report)):
            assert_allclose(actual, desired, **kwargs)

    def test_simple(self):
        """Scalars and small arrays against default and explicit tolerances."""
        x = 1e-3
        y = 1e-9
        z = 0

        assert_allclose(x, y, atol=1)
        with assert_raises(AssertionError):
            assert_allclose(x, y)

        self._fails_with(
            'Mismatched elements: 1 / 1 (100%)\n'
            'Max absolute difference among violations: 0.001\n'
            'Max relative difference among violations: 999999.',
            x, y)

        # The reported relative difference depends on which side is zero.
        self._fails_with(
            'Mismatched elements: 1 / 1 (100%)\n'
            'Max absolute difference among violations: 1.e-09\n'
            'Max relative difference among violations: inf',
            y, z)
        self._fails_with(
            'Mismatched elements: 1 / 1 (100%)\n'
            'Max absolute difference among violations: 1.e-09\n'
            'Max relative difference among violations: 1.',
            z, y)

        a = np.array([x, y, x, y])
        b = np.array([x, y, x, x])

        assert_allclose(a, b, atol=1)
        with assert_raises(AssertionError):
            assert_allclose(a, b)

        b[-1] = y * (1 + 1e-8)
        assert_allclose(a, b)
        with assert_raises(AssertionError):
            assert_allclose(a, b, rtol=1e-9)

        # Swapping the operands changes the outcome for the same rtol.
        assert_allclose(6, 10, rtol=0.5)
        with assert_raises(AssertionError):
            assert_allclose(10, 6, rtol=0.5)

        b = np.array([x, y, x, x])
        c = np.array([x, y, x, z])
        self._fails_with(
            'Mismatched elements: 1 / 4 (25%)\n'
            'Max absolute difference among violations: 0.001\n'
            'Max relative difference among violations: inf',
            b, c)
        self._fails_with(
            'Mismatched elements: 1 / 4 (25%)\n'
            'Max absolute difference among violations: 0.001\n'
            'Max relative difference among violations: 1.',
            c, b)

    def test_min_int(self):
        """The most negative ``np.int_`` value compares close to itself."""
        lowest = np.array([np.iinfo(np.int_).min], dtype=np.int_)
        assert_allclose(lowest, lowest)

    def test_report_fail_percentage(self):
        """The failure report contains the mismatch count and percentage."""
        a = np.array([1, 1, 1, 1])
        b = np.array([1, 1, 1, 2])

        self._fails_with(
            'Mismatched elements: 1 / 4 (25%)\n'
            'Max absolute difference among violations: 1\n'
            'Max relative difference among violations: 0.5',
            a, b)

    def test_equal_nan(self):
        """NaN matches NaN when ``equal_nan=True`` is passed."""
        first = np.array([np.nan])
        second = np.array([np.nan])
        assert_allclose(first, second, equal_nan=True)

    def test_not_equal_nan(self):
        """NaN does not match NaN when ``equal_nan=False`` is passed."""
        first = np.array([np.nan])
        second = np.array([np.nan])
        with assert_raises(AssertionError):
            assert_allclose(first, second, equal_nan=False)

    def test_equal_nan_default(self):
        """None of the array assertions raises on NaN-vs-NaN with default arguments."""
        first = np.array([np.nan])
        second = np.array([np.nan])
        for check in (assert_array_equal, assert_array_almost_equal,
                      assert_array_less, assert_allclose):
            check(first, second)

    def test_report_max_relative_error(self):
        """The report states the maximum relative difference."""
        a = np.array([0, 1])
        b = np.array([0, 2])
        self._fails_with('Max relative difference among violations: 0.5',
                         a, b)

    def test_timedelta(self):
        """A timedelta64 array containing NaT compares close to itself."""
        deltas = np.array([[1, 2, 3, "NaT"]], dtype="m8[ns]")
        assert_allclose(deltas, deltas)

    def test_error_message_unsigned(self):
        """Check the message is formatted correctly when overflow can occur
        (gh21768)"""
        x = np.asarray([0, 1, 8], dtype='uint8')
        y = np.asarray([4, 4, 4], dtype='uint8')
        self._fails_with('Max absolute difference among violations: 4',
                         x, y, atol=3)

    def test_strict(self):
        """Check what ``strict=True`` adds on top of the default comparison."""
        ones = np.ones(3)
        scalar_one = np.ones(())

        # Shapes differ: accepted by default, rejected with strict=True.
        assert_allclose(ones, scalar_one)
        with pytest.raises(AssertionError):
            assert_allclose(ones, scalar_one, strict=True)

        # Same shape, different dtype in the strict call.
        assert_allclose(ones, ones)
        with pytest.raises(AssertionError):
            assert_allclose(ones, ones.astype(np.float32), strict=True)
|
|
|
|
|
class TestArrayAlmostEqualNulp:
    """Checks for ``assert_array_almost_equal_nulp`` plus NaN handling of
    ``assert_array_max_ulp``.

    Every dtype runs the same recipe, so the recipe lives in private helpers
    and each public test only selects the dtype and the step scale.
    """

    # NULP tolerance handed to assert_array_almost_equal_nulp in every test.
    _NULP = 5

    @staticmethod
    def _signed_decades(dtype, low=-20, high=20, num=50):
        """Return ``-10**e`` followed by ``10**e`` for ``num`` exponents ``e``
        evenly spaced on ``[low, high]``, computed in ``dtype``."""
        exponents = np.linspace(low, high, num, dtype=dtype)
        magnitudes = 10**exponents
        return np.r_[-magnitudes, magnitudes]

    @classmethod
    def _nudged(cls, x, scale):
        """Return ``x`` with its magnitude grown (via eps) and shrunk (via
        epsneg) by ``_NULP * scale`` relative steps."""
        info = np.finfo(x.dtype)
        grown = x + x*info.eps*cls._NULP*scale
        shrunk = x - x*info.epsneg*cls._NULP*scale
        return grown, shrunk

    @classmethod
    def _check(cls, actual, desired, close):
        """Assert that the pair is within ``_NULP`` (``close``) or is not."""
        if close:
            assert_array_almost_equal_nulp(actual, desired, cls._NULP)
        else:
            assert_raises(AssertionError, assert_array_almost_equal_nulp,
                          actual, desired, cls._NULP)

    @classmethod
    def _check_real(cls, x, scale, close):
        """Compare ``x`` with both nudged copies of itself."""
        for y in cls._nudged(x, scale):
            cls._check(x, y, close)

    @classmethod
    def _check_complex(cls, x, one_part_scale, both_parts_scale, close):
        """Compare ``x + x*1j`` with copies whose parts were nudged.

        ``one_part_scale`` applies when only the real or only the imaginary
        part moves; ``both_parts_scale`` applies when both move together.
        """
        xi = x + x*1j
        for y in cls._nudged(x, one_part_scale):
            cls._check(xi, x + y*1j, close)
            cls._check(xi, y + x*1j, close)
        for y in cls._nudged(x, both_parts_scale):
            cls._check(xi, y + y*1j, close)

    @staticmethod
    def _check_nan_bits_ignored(float_type, uint_type, mask):
        """Two NaNs whose bit patterns differ by ``mask`` are 0 ULP apart."""
        offset = uint_type(mask)
        nan1_bits = np.array(np.nan, dtype=float_type).view(uint_type)
        # Flip the low bits of the default NaN's bit pattern.
        nan2_bits = nan1_bits ^ offset
        nan1 = nan1_bits.view(float_type)
        nan2 = nan2_bits.view(float_type)
        assert_array_max_ulp(nan1, nan2, 0)

    def test_float64_pass(self):
        """Half of ``_NULP`` relative steps stays within tolerance."""
        self._check_real(self._signed_decades(np.float64), 0.5, close=True)

    def test_float64_fail(self):
        """Twice ``_NULP`` relative steps exceeds the tolerance."""
        self._check_real(self._signed_decades(np.float64), 2., close=False)

    def test_float64_ignore_nan(self):
        """float64 NaNs with different bit patterns compare as equal."""
        self._check_nan_bits_ignored(np.float64, np.uint64, 0xffffffff)

    def test_float32_pass(self):
        """float32 variant of ``test_float64_pass``."""
        self._check_real(self._signed_decades(np.float32), 0.5, close=True)

    def test_float32_fail(self):
        """float32 variant of ``test_float64_fail``."""
        self._check_real(self._signed_decades(np.float32), 2., close=False)

    def test_float32_ignore_nan(self):
        """float32 NaNs with different bit patterns compare as equal."""
        self._check_nan_bits_ignored(np.float32, np.uint32, 0xffff)

    def test_float16_pass(self):
        """float16 variant, on the narrower exponent range [-4, 4]."""
        x = self._signed_decades(np.float16, -4, 4, 10)
        self._check_real(x, 0.5, close=True)

    def test_float16_fail(self):
        """float16 variant of the failing case, exponent range [-4, 4]."""
        x = self._signed_decades(np.float16, -4, 4, 10)
        self._check_real(x, 2., close=False)

    def test_float16_ignore_nan(self):
        """float16 NaNs with different bit patterns compare as equal."""
        self._check_nan_bits_ignored(np.float16, np.uint16, 0xff)

    def test_complex128_pass(self):
        """Complex values built from float64 stay within tolerance."""
        x = self._signed_decades(np.float64)
        # A smaller step is used when both parts are perturbed at once.
        self._check_complex(x, 0.5, 0.25, close=True)

    def test_complex128_fail(self):
        """Complex values built from float64 exceed the tolerance."""
        x = self._signed_decades(np.float64)
        self._check_complex(x, 2., 1., close=False)

    def test_complex64_pass(self):
        """Complex values built from float32 stay within tolerance."""
        x = self._signed_decades(np.float32)
        self._check_complex(x, 0.5, 0.25, close=True)

    def test_complex64_fail(self):
        """Complex values built from float32 exceed the tolerance."""
        x = self._signed_decades(np.float32)
        self._check_complex(x, 2., 1., close=False)
|
|
|
|
|
class TestULP:
    """Checks for ``assert_array_max_ulp``."""

    def test_equal(self):
        """An array is zero ULP away from itself."""
        sample = np.random.randn(10)
        assert_array_max_ulp(sample, sample, maxulp=0)

    def test_single(self):
        """float32 values near one, shifted by eps, stay within 20 ULP."""
        base = np.ones(10).astype(np.float32)
        base += 0.01 * np.random.randn(10).astype(np.float32)
        step = np.finfo(np.float32).eps
        assert_array_max_ulp(base, base + step, maxulp=20)

    def test_double(self):
        """float64 values near one, shifted by eps, stay within 200 ULP."""
        base = np.ones(10).astype(np.float64)
        base += 0.01 * np.random.randn(10).astype(np.float64)
        step = np.finfo(np.float64).eps
        assert_array_max_ulp(base, base + step, maxulp=200)

    def test_inf(self):
        """Infinity is within 200 ULP of the largest finite value."""
        for float_type in (np.float32, np.float64):
            infinite = np.array([np.inf]).astype(float_type)
            largest = np.array([np.finfo(float_type).max])
            assert_array_max_ulp(infinite, largest, maxulp=200)

    def test_nan(self):
        """NaN is rejected against every non-NaN value, even with a huge maxulp."""
        for float_type, maxulp in ((np.float32, 1e6), (np.float64, 1e12)):
            nan = np.array([np.nan]).astype(float_type)
            others = [
                np.array([np.inf]).astype(float_type),
                np.array([np.finfo(float_type).max]),
                np.array([np.finfo(float_type).tiny]),
                np.array([0.0]).astype(float_type),
                np.array([-0.0]).astype(float_type),
            ]
            for other in others:
                with assert_raises(AssertionError):
                    assert_array_max_ulp(nan, other, maxulp=maxulp)
|
|
|
|
|
class TestStringEqual:
    """Checks for ``assert_string_equal``."""

    def test_simple(self):
        """Equal strings pass; differing strings fail with a diff message."""
        for text in ("hello", "hello\nmultiline"):
            assert_string_equal(text, text)

        with pytest.raises(AssertionError) as exc_info:
            assert_string_equal("foo\nbar", "hello\nbar")
        # Only the differing line shows up in the report.
        assert_equal(str(exc_info.value),
                     "Differences in strings:\n- foo\n+ hello")

        with assert_raises(AssertionError):
            assert_string_equal("foo", "hello")

    def test_regex(self):
        """Regex metacharacters in the operands are compared literally."""
        assert_string_equal("a+*b", "a+*b")

        with assert_raises(AssertionError):
            assert_string_equal("aaa", "a+b")
|
|
|
|
|
def assert_warn_len_equal(mod, n_in_context):
    """Assert that ``mod`` holds ``n_in_context`` warning-registry entries.

    A missing ``__warningregistry__`` attribute is treated as an empty
    registry, and a ``'version'`` key, when present, is not counted.
    """
    registry = getattr(mod, '__warningregistry__', {})
    entries = len(registry) - (1 if 'version' in registry else 0)
    assert_equal(entries, n_in_context)
|
|
|
|
|
def test_warn_len_equal_call_scenarios():
    """Drive ``assert_warn_len_equal`` with and without a registry attribute."""
    class _NoRegistry:
        pass

    # No ``__warningregistry__`` attribute at all.
    assert_warn_len_equal(mod=_NoRegistry(), n_in_context=0)

    class _WithRegistry:
        def __init__(self):
            self.__warningregistry__ = {'warning1': 1,
                                        'warning2': 2}

    # Attribute present with two entries.
    assert_warn_len_equal(mod=_WithRegistry(), n_in_context=2)
|
|
|
|
|
def _get_fresh_mod():
    """Return this module after emptying its warning registry, if it has one."""
    this_module = sys.modules[__name__]
    try:
        wipe = this_module.__warningregistry__.clear
    except AttributeError:
        # Nothing to clear: the module has no usable registry attribute.
        pass
    else:
        wipe()
    return this_module
|
|
|
|
|
def test_clear_and_catch_warnings():
    """Check the registry of this module around ``clear_and_catch_warnings``."""
    this_module = _get_fresh_mod()
    assert_equal(getattr(this_module, '__warningregistry__', {}), {})

    # Module listed explicitly: the registry is empty afterwards.
    with clear_and_catch_warnings(modules=[this_module]):
        warnings.simplefilter('ignore')
        warnings.warn('Some warning')
    assert_equal(this_module.__warningregistry__, {})

    # No module listed: still no counted entries afterwards.
    with clear_and_catch_warnings():
        warnings.simplefilter('ignore')
        warnings.warn('Some warning')
    assert_warn_len_equal(this_module, 0)

    # Seed the registry by hand with two entries.
    this_module.__warningregistry__ = {'warning1': 1,
                                       'warning2': 2}

    # Module listed explicitly: both seeded entries are present afterwards.
    with clear_and_catch_warnings(modules=[this_module]):
        warnings.simplefilter('ignore')
        warnings.warn('Another warning')
    assert_warn_len_equal(this_module, 2)

    # No module listed: the seeded entries are gone afterwards.
    with clear_and_catch_warnings():
        warnings.simplefilter('ignore')
        warnings.warn('Another warning')
    assert_warn_len_equal(this_module, 0)
|
|
|
|
|
def test_suppress_warnings_module():
    """Check module-based filtering of ``suppress_warnings``."""
    this_module = _get_fresh_mod()
    assert_equal(getattr(this_module, '__warningregistry__', {}), {})

    def warn_other_module():
        # stacklevel=2 attributes the warning to the caller of ``warn``,
        # which here is the code behind np.apply_along_axis.
        def warn(arr):
            warnings.warn("Some warning 2", stacklevel=2)
            return arr
        np.apply_along_axis(warn, 0, [0])

    assert_warn_len_equal(this_module, 0)
    with suppress_warnings() as recorder:
        recorder.record(UserWarning)
        # Filter out what is attributed to the apply_along_axis module.
        recorder.filter(module=np.lib._shape_base_impl)
        warnings.warn("Some warning")
        warn_other_module()

    # Only the warning raised directly in this module was recorded.
    assert_equal(len(recorder.log), 1)
    assert_equal(recorder.log[0].message.args[0], "Some warning")
    assert_warn_len_equal(this_module, 0)

    # Filtering this module, twice in a row with the same instance.
    reusable = suppress_warnings()
    for _ in range(2):
        reusable.filter(module=this_module)
        with reusable:
            warnings.warn('Some warning')
        assert_warn_len_equal(this_module, 0)

    # No module specified at all.
    with suppress_warnings():
        warnings.simplefilter('ignore')
        warnings.warn('Some warning')
    assert_warn_len_equal(this_module, 0)
|
|
|
|
|
def test_suppress_warnings_type():
    """Check category-based filtering of ``suppress_warnings``."""
    this_module = _get_fresh_mod()
    assert_equal(getattr(this_module, '__warningregistry__', {}), {})

    # Filter registered inside the context.
    with suppress_warnings() as scoped:
        scoped.filter(UserWarning)
        warnings.warn('Some warning')
    assert_warn_len_equal(this_module, 0)

    # Filter registered before entering the context.
    reusable = suppress_warnings()
    reusable.filter(UserWarning)
    with reusable:
        warnings.warn('Some warning')
    assert_warn_len_equal(this_module, 0)

    # Same instance again, now with an additional module filter.
    reusable.filter(module=this_module)
    with reusable:
        warnings.warn('Some warning')
    assert_warn_len_equal(this_module, 0)

    # No filter specified at all.
    with suppress_warnings():
        warnings.simplefilter('ignore')
        warnings.warn('Some warning')
    assert_warn_len_equal(this_module, 0)
|
|
|
|
|
def test_suppress_warnings_decorate_no_record():
    """A ``suppress_warnings`` instance works as a function decorator."""
    suppressor = suppress_warnings()
    suppressor.filter(UserWarning)

    def warn(category):
        warnings.warn('Some warning', category)

    # Equivalent to decorating ``warn`` with ``@suppressor``.
    warn = suppressor(warn)

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        warn(UserWarning)
        warn(RuntimeWarning)
        # Exactly one of the two warnings reaches the recorder.
        assert_equal(len(caught), 1)
|
|
|
|
|
def test_suppress_warnings_record():
    """Check the logs kept by ``suppress_warnings.record``."""
    recorder = suppress_warnings()
    catch_all = recorder.record()

    # Run the same scenario twice with the same instance; the second pass
    # shows that nothing from the first pass is left in the logs.
    for _ in range(2):
        with recorder:
            specific = recorder.record(message='Some other warning 2')
            recorder.filter(message='Some warning')
            warnings.warn('Some warning')
            warnings.warn('Some other warning')
            warnings.warn('Some other warning 2')

            assert_equal(len(recorder.log), 2)
            assert_equal(len(catch_all), 1)
            assert_equal(len(specific), 1)
            assert_equal(specific[0].message.args[0], 'Some other warning 2')

    # Nested contexts: each instance logs one of the two warnings.
    with suppress_warnings() as outer:
        outer.record()
        with suppress_warnings() as inner:
            inner.record(message='Some warning')
            warnings.warn('Some warning')
            warnings.warn('Some other warning')
            assert_equal(len(inner.log), 1)
        assert_equal(len(outer.log), 1)
|
|
|
|
|
def test_suppress_warnings_forwarding():
    """Check each forwarding rule of a nested ``suppress_warnings``.

    In every scenario the outer instance records what the inner instance
    forwards, and exactly two records are expected.
    """
    def warn_other_module():
        # stacklevel=2 attributes the warning to the caller of ``warn``,
        # which here is the code behind np.apply_along_axis.
        def warn(arr):
            warnings.warn("Some warning", stacklevel=2)
            return arr
        np.apply_along_axis(warn, 0, [0])

    # NOTE: the warn calls below are kept on separate source lines on
    # purpose; do not fold them into a shared helper.

    with suppress_warnings() as outer:
        outer.record()
        with suppress_warnings("always"):
            for _ in range(2):
                warnings.warn("Some warning")

        assert_equal(len(outer.log), 2)

    with suppress_warnings() as outer:
        outer.record()
        with suppress_warnings("location"):
            for _ in range(2):
                warnings.warn("Some warning")
                warnings.warn("Some warning")

        assert_equal(len(outer.log), 2)

    with suppress_warnings() as outer:
        outer.record()
        with suppress_warnings("module"):
            for _ in range(2):
                warnings.warn("Some warning")
                warnings.warn("Some warning")
                warn_other_module()

        assert_equal(len(outer.log), 2)

    with suppress_warnings() as outer:
        outer.record()
        with suppress_warnings("once"):
            for _ in range(2):
                warnings.warn("Some warning")
                warnings.warn("Some other warning")
                warn_other_module()

        assert_equal(len(outer.log), 2)
|
|
|
|
|
def test_tempdir():
    """The directory from ``tempdir`` is gone after the block, even on error."""
    with tempdir() as workdir:
        # Put a file inside so the directory is not empty on exit.
        with open(os.path.join(workdir, 'tmp'), 'w'):
            pass
    assert_(not os.path.isdir(workdir))

    # The exception must propagate and the directory must still be removed.
    with assert_raises(ValueError):
        with tempdir() as workdir:
            raise ValueError
    assert_(not os.path.isdir(workdir))
|
|
|
|
|
def test_temppath():
    """The file from ``temppath`` is gone after the block, even on error."""
    with temppath() as scratch:
        with open(scratch, 'w'):
            pass
    assert_(not os.path.isfile(scratch))

    # The exception must propagate and the file must still be removed.
    with assert_raises(ValueError):
        with temppath() as scratch:
            raise ValueError
    assert_(not os.path.isfile(scratch))
|
|
|
|
|
class my_cacw(clear_and_catch_warnings):
    """``clear_and_catch_warnings`` subclass that lists this module in ``class_modules``."""

    # One-element tuple holding the module this class is defined in.
    class_modules = (sys.modules[__name__],)
|
|
|
|
|
def test_clear_and_catch_warnings_inherit():
    """The module named by ``my_cacw.class_modules`` ends up with an empty registry."""
    this_module = _get_fresh_mod()
    with my_cacw():
        warnings.simplefilter('ignore')
        warnings.warn('Some warning')
    assert_equal(this_module.__warningregistry__, {})
|
|
|
|
|
@pytest.mark.skipif(not HAS_REFCOUNT, reason="Python lacks refcounts")
class TestAssertNoGcCycles:
    """Checks for ``assert_no_gc_cycles`` in call and context-manager form."""

    def test_passes(self):
        """Code that builds no reference cycle is accepted."""
        def build_acyclic():
            outer = []
            outer.append([])
            return outer

        with assert_no_gc_cycles():
            build_acyclic()

        assert_no_gc_cycles(build_acyclic)

    def test_asserts(self):
        """Code that builds a self-referencing list is rejected."""
        def build_cyclic():
            node = []
            node.append(node)
            node.append(node)
            return node

        with assert_raises(AssertionError):
            with assert_no_gc_cycles():
                build_cyclic()

        with assert_raises(AssertionError):
            assert_no_gc_cycles(build_cyclic)

    @pytest.mark.slow
    def test_fails(self):
        """
        Test that in cases where the garbage cannot be collected, we raise an
        error, instead of hanging forever trying to clear it.
        """

        class ReferenceCycleInDel:
            """
            An object that not only contains a reference cycle, but creates new
            cycles whenever it's garbage-collected and its __del__ runs
            """
            make_cycle = True

            def __init__(self):
                self.cycle = self

            def __del__(self):
                # Drop the self-reference held by this instance ...
                self.cycle = None

                if ReferenceCycleInDel.make_cycle:
                    # ... and, while enabled, create a fresh cyclic instance.
                    ReferenceCycleInDel()

        try:
            probe = weakref.ref(ReferenceCycleInDel())
            try:
                with assert_raises(RuntimeError):
                    assert_no_gc_cycles(lambda: None)
            except AssertionError:
                # No RuntimeError was raised. If the probe object is still
                # alive, skip instead of failing; otherwise report the
                # failure.
                if probe() is not None:
                    pytest.skip("GC does not call __del__ on cyclic objects")
                raise

        finally:
            # Stop __del__ from creating further instances.
            ReferenceCycleInDel.make_cycle = False
|
|
|
|
|
@pytest.mark.parametrize('assert_func', [assert_array_equal,
                                         assert_array_almost_equal])
def test_xy_rename(assert_func):
    """Check the ``actual``/``desired`` keywords and the old ``x``/``y`` ones.

    Parameters
    ----------
    assert_func : callable
        The assertion function under test, supplied by ``parametrize``.
    """
    # Positional and new-style keyword calls work and fail as usual.
    assert_func(1, 1)
    assert_func(actual=1, desired=1)

    assert_message = "Arrays are not..."
    with pytest.raises(AssertionError, match=assert_message):
        assert_func(1, 2)
    with pytest.raises(AssertionError, match=assert_message):
        assert_func(actual=1, desired=2)

    # The old keyword names still work but emit a DeprecationWarning.
    dep_message = 'Use of keyword argument...'
    with pytest.warns(DeprecationWarning, match=dep_message):
        assert_func(x=1, desired=1)
    with pytest.warns(DeprecationWarning, match=dep_message):
        assert_func(1, y=1)

    # Passing an argument both positionally and by its old keyword warns and
    # then raises TypeError. Each call gets its own block: a statement placed
    # after a raising call inside ``pytest.raises`` would never execute.
    type_message = '...got multiple values for argument'
    with (pytest.warns(DeprecationWarning, match=dep_message),
          pytest.raises(TypeError, match=type_message)):
        assert_func(1, x=1)
    with (pytest.warns(DeprecationWarning, match=dep_message),
          pytest.raises(TypeError, match=type_message)):
        assert_func(1, 2, y=2)
|
|