|
import pytest |
|
|
|
import numpy as np |
|
|
|
from scipy.optimize._bracket import _ELIMITS |
|
from scipy.optimize.elementwise import bracket_root, bracket_minimum |
|
import scipy._lib._elementwise_iterative_method as eim |
|
from scipy import stats |
|
from scipy._lib._array_api_no_0d import (xp_assert_close, xp_assert_equal, |
|
xp_assert_less, array_namespace) |
|
from scipy._lib._array_api import xp_ravel |
|
from scipy.conftest import array_api_compatible |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _bracket_root(*args, **kwargs): |
|
res = bracket_root(*args, **kwargs) |
|
res.xl, res.xr = res.bracket |
|
res.fl, res.fr = res.f_bracket |
|
del res.bracket |
|
del res.f_bracket |
|
return res |
|
|
|
|
|
def _bracket_minimum(*args, **kwargs): |
|
res = bracket_minimum(*args, **kwargs) |
|
res.xl, res.xm, res.xr = res.bracket |
|
res.fl, res.fm, res.fr = res.f_bracket |
|
del res.bracket |
|
del res.f_bracket |
|
return res |
|
|
|
|
|
# Reasons attached to the `skip_xp_backends` markers on the test classes in
# this module: these backends cannot run tests that rely on item assignment.
array_api_strict_skip_reason, jax_skip_reason = (
    'Array API does not support fancy indexing assignment.',
    'JAX arrays do not support item assignment.',
)
|
|
|
@pytest.mark.skip_xp_backends('array_api_strict', reason=array_api_strict_skip_reason)
@pytest.mark.skip_xp_backends('jax.numpy', reason=jax_skip_reason)
@array_api_compatible
@pytest.mark.usefixtures("skip_xp_backends")
class TestBracketRoot:
    """Tests for `bracket_root`, called through the `_bracket_root` wrapper."""

    @pytest.mark.parametrize("seed", (615655101, 3141866013, 238075752))
    @pytest.mark.parametrize("use_xmin", (False, True))
    @pytest.mark.parametrize("other_side", (False, True))
    @pytest.mark.parametrize("fix_one_side", (False, True))
    def test_nfev_expected(self, seed, use_xmin, other_side, fix_one_side, xp):
        # For f(x) = x (root at 0) and a random initial bracket at or to the
        # right of 0, compute in closed form the number of expansions `n` and
        # the final bracket [l, u]. Then compare them with the reported
        # counts and bracket.
        rng = np.random.default_rng(seed)
        xl0, d, factor = xp.asarray(rng.random(size=3) * [1e5, 10, 5])
        factor = 1 + factor  # `factor` must be greater than 1
        xr0 = xl0 + d

        def f(x):
            f.count += 1
            return x

        if use_xmin:
            # With a finite `xmin`, the left end approaches `xmin`
            # geometrically: xmin + (xl0 - xmin) * factor**-k.
            xmin = xp.asarray(-rng.random())
            n = xp.ceil(xp.log(-(xl0 - xmin) / xmin) / xp.log(factor))
            l, u = xmin + (xl0 - xmin)*factor**-n, xmin + (xl0 - xmin)*factor**-(n - 1)
            kwargs = dict(xl0=xl0, xr0=xr0, factor=factor, xmin=xmin)
        else:
            # Without limits, the left end moves as xr0 - d * factor**k.
            n = xp.ceil(xp.log(xr0/d) / xp.log(factor))
            l, u = xr0 - d*factor**n, xr0 - d*factor**(n-1)
            kwargs = dict(xl0=xl0, xr0=xr0, factor=factor)

        if other_side:
            # Mirror the problem about 0 so that the search proceeds rightward.
            kwargs['xl0'], kwargs['xr0'] = -kwargs['xr0'], -kwargs['xl0']
            l, u = -u, -l
            if 'xmin' in kwargs:
                kwargs['xmax'] = -kwargs.pop('xmin')

        if fix_one_side:
            # Pin the end farther from the root at its initial value. Only the
            # other end then moves, so `nfev` grows by one per iteration
            # instead of two (see the assertions below).
            if other_side:
                kwargs['xmin'] = -xr0
            else:
                kwargs['xmax'] = xr0

        f.count = 0
        res = _bracket_root(f, **kwargs)

        # `nfev` counts evaluated points; `f.count` counts calls of `f`.
        if not fix_one_side:
            assert res.nfev == 2*(res.nit+1) == 2*(f.count-1) == 2*(n + 1)
        else:
            assert res.nfev == (res.nit+1)+1 == (f.count-1)+1 == (n+1)+1

        # The bracket and its function values match the closed-form result.
        bracket = xp.asarray([res.xl, res.xr])
        xp_assert_close(bracket, xp.asarray([l, u]))
        f_bracket = xp.asarray([res.fl, res.fr])
        xp_assert_close(f_bracket, f(bracket))

        # The result is an ordered bracket across which `f` changes sign.
        assert res.xr > res.xl
        signs = xp.sign(f_bracket)
        assert signs[0] == -signs[1]
        assert res.status == 0
        assert res.success

    def f(self, q, p):
        # CDF of `stats._stats_py._SimpleNormal` minus `p`; its root is the
        # point where that CDF equals `p`.
        return stats._stats_py._SimpleNormal().cdf(q) - p

    @pytest.mark.parametrize('p', [0.6, np.linspace(0.05, 0.95, 10)])
    @pytest.mark.parametrize('xmin', [-5, None])
    @pytest.mark.parametrize('xmax', [5, None])
    @pytest.mark.parametrize('factor', [1.2, 2])
    def test_basic(self, p, xmin, xmax, factor, xp):
        # The returned bracket must straddle a sign change of `self.f`.
        res = _bracket_root(self.f, xp.asarray(-0.01), 0.01, xmin=xmin, xmax=xmax,
                            factor=factor, args=(xp.asarray(p),))
        xp_assert_equal(-xp.sign(res.fl), xp.sign(res.fr))

    @pytest.mark.parametrize('shape', [tuple(), (12,), (3, 4), (3, 2, 2)])
    def test_vectorization(self, shape, xp):
        # A single vectorized call must agree with element-by-element scalar
        # calls (made through `np.vectorize`) for inputs of several shapes.
        p = np.linspace(-0.05, 1.05, 12).reshape(shape) if shape else np.float64(0.6)
        args = (p,)
        maxiter = 10

        @np.vectorize
        def bracket_root_single(xl0, xr0, xmin, xmax, factor, p):
            return _bracket_root(self.f, xl0, xr0, xmin=xmin, xmax=xmax,
                                 factor=factor, args=(p,),
                                 maxiter=maxiter)

        def f(*args, **kwargs):
            f.f_evals += 1
            return self.f(*args, **kwargs)
        f.f_evals = 0

        rng = np.random.default_rng(2348234)
        xl0 = -rng.random(size=shape)
        xr0 = rng.random(size=shape)
        xmin, xmax = 1e3*xl0, 1e3*xr0
        if shape:  # make a random subset of the elements unbounded
            i = rng.random(size=shape) > 0.5
            xmin[i], xmax[i] = -np.inf, np.inf
        factor = rng.random(size=shape) + 1.5
        refs = bracket_root_single(xl0, xr0, xmin, xmax, factor, p).ravel()
        xl0, xr0, xmin, xmax, factor = (xp.asarray(xl0), xp.asarray(xr0),
                                        xp.asarray(xmin), xp.asarray(xmax),
                                        xp.asarray(factor))
        args = tuple(map(xp.asarray, args))
        res = _bracket_root(f, xl0, xr0, xmin=xmin, xmax=xmax, factor=factor,
                            args=args, maxiter=maxiter)

        attrs = ['xl', 'xr', 'fl', 'fr', 'success', 'nfev', 'nit']
        for attr in attrs:
            ref_attr = [xp.asarray(getattr(ref, attr)) for ref in refs]
            res_attr = getattr(res, attr)
            xp_assert_close(xp_ravel(res_attr, xp=xp), xp.stack(ref_attr))
            xp_assert_equal(res_attr.shape, shape)

        xp_test = array_namespace(xp.asarray(1.))
        assert res.success.dtype == xp_test.bool
        if shape:
            # The first and last `p` lie outside [0, 1], which a CDF never
            # attains, so only the interior elements are required to succeed.
            assert xp.all(res.success[1:-1])
        assert res.status.dtype == xp.int32
        assert res.nfev.dtype == xp.int32
        assert res.nit.dtype == xp.int32
        assert xp.max(res.nit) == f.f_evals - 2
        xp_assert_less(res.xl, res.xr)
        xp_assert_close(res.fl, xp.asarray(self.f(res.xl, *args)))
        xp_assert_close(res.fr, xp.asarray(self.f(res.xr, *args)))

    def test_flags(self, xp):
        # Each element is constructed to end with a different status code.
        def f(xs, js):
            funcs = [lambda x: x - 1.5,       # root reachable: converges
                     lambda x: x - 1000,      # root outside [xmin, xmax]
                     lambda x: x - 1000,      # root too far for maxiter=3
                     lambda x: x * xp.nan,    # NaN function values
                     lambda x: x]             # invalid xl0/xr0 and limits

            return [funcs[int(j)](x) for x, j in zip(xs, js)]

        args = (xp.arange(5, dtype=xp.int64),)
        res = _bracket_root(f,
                            xl0=xp.asarray([-1., -1., -1., -1., 4.]),
                            xr0=xp.asarray([1, 1, 1, 1, -4]),
                            xmin=xp.asarray([-xp.inf, -1, -xp.inf, -xp.inf, 6]),
                            xmax=xp.asarray([xp.inf, 1, xp.inf, xp.inf, 2]),
                            args=args, maxiter=3)

        ref_flags = xp.asarray([eim._ECONVERGED,
                                _ELIMITS,
                                eim._ECONVERR,
                                eim._EVALUEERR,
                                eim._EINPUTERR],
                               dtype=xp.int32)

        xp_assert_equal(res.status, ref_flags)

    @pytest.mark.parametrize("root", (0.622, [0.622, 0.623]))
    @pytest.mark.parametrize('xmin', [-5, None])
    @pytest.mark.parametrize('xmax', [5, None])
    @pytest.mark.parametrize("dtype", ("float16", "float32", "float64"))
    def test_dtype(self, root, xmin, xmax, dtype, xp):
        # The bracket and function values keep the dtype of the inputs.
        dtype = getattr(xp, dtype)
        xp_test = array_namespace(xp.asarray(1.))

        xmin = xmin if xmin is None else xp.asarray(xmin, dtype=dtype)
        xmax = xmax if xmax is None else xp.asarray(xmax, dtype=dtype)
        root = xp.asarray(root, dtype=dtype)

        def f(x, root):
            return xp_test.astype((x - root) ** 3, dtype)

        bracket = xp.asarray([-0.01, 0.01], dtype=dtype)
        res = _bracket_root(f, *bracket, xmin=xmin, xmax=xmax, args=(root,))
        assert xp.all(res.success)
        assert res.xl.dtype == res.xr.dtype == dtype
        assert res.fl.dtype == res.fr.dtype == dtype

    def test_input_validation(self, xp):
        # Invalid arguments raise with informative messages.
        message = '`func` must be callable.'
        with pytest.raises(ValueError, match=message):
            _bracket_root(None, -4, 4)

        message = '...must be numeric and real.'
        with pytest.raises(ValueError, match=message):
            _bracket_root(lambda x: x, -4+1j, 4)
        with pytest.raises(ValueError, match=message):
            _bracket_root(lambda x: x, -4, 'hello')
        with pytest.raises(ValueError, match=message):
            _bracket_root(lambda x: x, -4, 4, xmin=np)
        with pytest.raises(ValueError, match=message):
            _bracket_root(lambda x: x, -4, 4, xmax=object())
        with pytest.raises(ValueError, match=message):
            _bracket_root(lambda x: x, -4, 4, factor=sum)

        message = "All elements of `factor` must be greater than 1."
        with pytest.raises(ValueError, match=message):
            _bracket_root(lambda x: x, -4, 4, factor=0.5)

        # The exception type for non-broadcastable inputs may vary by backend.
        message = "broadcast"
        with pytest.raises(Exception, match=message):
            _bracket_root(lambda x: x, xp.asarray([-2, -3]), xp.asarray([3, 4, 5]))

        message = '`maxiter` must be a non-negative integer.'
        with pytest.raises(ValueError, match=message):
            _bracket_root(lambda x: x, -4, 4, maxiter=1.5)
        with pytest.raises(ValueError, match=message):
            _bracket_root(lambda x: x, -4, 4, maxiter=-1)
        with pytest.raises(ValueError, match=message):
            _bracket_root(lambda x: x, -4, 4, maxiter="shrubbery")

    def test_special_cases(self, xp):
        # Edge cases and other special situations.
        xp_test = array_namespace(xp.asarray(1.))

        # `f` asserts that it only ever receives real floating-point arrays.
        def f(x):
            assert xp_test.isdtype(x.dtype, "real floating")
            return x ** 99 - 1

        res = _bracket_root(f, xp.asarray(-7.), xp.asarray(5.))
        assert res.success

        # With `maxiter=0`, the initial bracket is returned unchanged and
        # only the two initial points are evaluated.
        def f(x):
            return x - 10

        bracket = (xp.asarray(-3.), xp.asarray(5.))
        res = _bracket_root(f, *bracket, maxiter=0)
        xp_assert_equal(res.xl, bracket[0])
        xp_assert_equal(res.xr, bracket[1])
        assert res.nit == 0
        assert res.nfev == 2
        assert res.status == -2

        # `args` may be a single array rather than a tuple.
        def f(x, c):
            return c*x - 1

        res = _bracket_root(f, xp.asarray(-1.), xp.asarray(1.),
                            args=xp.asarray(3.))
        assert res.success
        xp_assert_close(res.fl, f(res.xl, 3))

        # The initial bracket [-10, 20] already straddles the root of
        # f(x) = x; check how many times `f` is called.
        def f(x):
            f.count += 1
            return x

        f.count = 0
        _bracket_root(f, xp.asarray(-10), xp.asarray(20))
        assert f.count == 2

        # With factor=2, the first expansion of the left end lands exactly on
        # the root: 10 - (10 - 5)*2 == 0.
        f.count = 0
        res = _bracket_root(f, xp.asarray(5.), xp.asarray(10.),
                            factor=2)

        assert res.nfev == 4
        xp_assert_close(res.xl, xp.asarray(0.), atol=1e-15)
        xp_assert_close(res.xr, xp.asarray(5.), atol=1e-15)

        # A root located exactly at `xmin` (or `xmax`) is found.
        with np.errstate(over='ignore'):
            res = _bracket_root(f, xp.asarray(5.), xp.asarray(10.),
                                xmin=0)
        xp_assert_close(res.xl, xp.asarray(0.), atol=1e-15)

        with np.errstate(over='ignore'):
            res = _bracket_root(f, xp.asarray(-10.), xp.asarray(-5.),
                                xmax=0)
        xp_assert_close(res.xr, xp.asarray(0.), atol=1e-15)

        # The root at 0 lies outside [xmin, inf) = [1, inf), so the search fails.
        with np.errstate(over='ignore'):
            res = _bracket_root(f, xp.asarray(5.), xp.asarray(10.),
                                xmin=1)
        assert not res.success
|
|
|
|
|
@pytest.mark.skip_xp_backends('array_api_strict', reason=array_api_strict_skip_reason)
@pytest.mark.skip_xp_backends('jax.numpy', reason=jax_skip_reason)
@array_api_compatible
@pytest.mark.usefixtures("skip_xp_backends")
class TestBracketMinimum:
    """Tests for `bracket_minimum`, called through the `_bracket_minimum` wrapper."""

    def init_f(self):
        """Return f(x, a, b) = (x - a)**2 + b with a call counter.

        The minimizer of `f` is `a`. The number of calls is stored on the
        function as ``f.count``, starting at 0.
        """
        def f(x, a, b):
            f.count += 1
            return (x - a)**2 + b
        f.count = 0
        return f

    def assert_valid_bracket(self, result, xp):
        """Assert that every element of `result` is a valid minimum bracket.

        A valid bracket satisfies ``xl < xm < xr`` and ``fl >= fm <= fr``,
        with at least one of the two function-value inequalities strict.
        """
        assert xp.all(
            (result.xl < result.xm) & (result.xm < result.xr)
        )
        # `&` binds tighter than `|`: (fl >= fm and fr > fm) or
        # (fl > fm and fr >= fm).
        assert xp.all(
            (result.fl >= result.fm) & (result.fr > result.fm)
            | (result.fl > result.fm) & (result.fr >= result.fm)
        )

    def get_kwargs(
        self, *, xl0=None, xr0=None, factor=None, xmin=None, xmax=None, args=None
    ):
        """Return a dict containing only the keyword arguments that are not None."""
        names = ("xl0", "xr0", "xmin", "xmax", "factor", "args")
        return {
            name: val for name, val in zip(names, (xl0, xr0, xmin, xmax, factor, args))
            if val is not None
        }

    @pytest.mark.parametrize(
        "seed",
        (
            307448016549685229886351382450158984917,
            11650702770735516532954347931959000479,
            113767103358505514764278732330028568336,
        )
    )
    @pytest.mark.parametrize("use_xmin", (False, True))
    @pytest.mark.parametrize("other_side", (False, True))
    def test_nfev_expected(self, seed, use_xmin, other_side, xp):
        # For f(x) = x**2 (a = b = 0) and a random initial bracket at or to the
        # right of the minimum at 0, compute in closed form the number of
        # iterations `n` and the final bracket (lower, middle, upper). Then
        # compare them with the result.
        rng = np.random.default_rng(seed)
        args = (xp.asarray(0.), xp.asarray(0.))
        xl0, d1, d2, factor = xp.asarray(rng.random(size=4) * [1e5, 10, 10, 5])
        xm0 = xl0 + d1
        xr0 = xm0 + d2
        factor += 1  # `factor` must be greater than 1

        if use_xmin:
            # With a finite `xmin`, the left end approaches `xmin`
            # geometrically: xmin + (xl0 - xmin) * factor**-k.
            xmin = xp.asarray(-rng.random() * 5, dtype=xp.float64)
            n = int(xp.ceil(xp.log(-(xl0 - xmin) / xmin) / xp.log(factor)))
            lower = xmin + (xl0 - xmin)*factor**-n
            middle = xmin + (xl0 - xmin)*factor**-(n-1)
            upper = xmin + (xl0 - xmin)*factor**-(n-2) if n > 1 else xm0
            # If f(middle) > f(lower), `f` is still decreasing at `lower`.
            # One more iteration is needed, which shifts the bracket one step.
            if middle**2 > lower**2:
                n += 1
                lower, middle, upper = (
                    xmin + (xl0 - xmin)*factor**-n, lower, middle
                )
        else:
            # Without limits, the left end moves as xl0 - d1 * factor**k.
            xmin = None
            n = int(xp.ceil(xp.log(xl0 / d1) / xp.log(factor)))
            lower = xl0 - d1*factor**n
            middle = xl0 - d1*factor**(n-1) if n > 1 else xl0
            upper = xl0 - d1*factor**(n-2) if n > 1 else xm0
            # Same adjustment as above when `f` is still decreasing at `lower`.
            if middle**2 > lower**2:
                n += 1
                lower, middle, upper = (
                    xl0 - d1*factor**n, lower, middle
                )
        f = self.init_f()

        xmax = None
        if other_side:
            # Mirror the problem about 0 so that the search proceeds rightward.
            xl0, xm0, xr0 = -xr0, -xm0, -xl0
            xmin, xmax = None, -xmin if xmin is not None else None
            lower, middle, upper = -upper, -middle, -lower

        kwargs = self.get_kwargs(
            xl0=xl0, xr0=xr0, xmin=xmin, xmax=xmax, factor=factor, args=args
        )
        result = _bracket_minimum(f, xp.asarray(xm0), **kwargs)

        # Three initial points are evaluated, plus one new point per iteration.
        assert result.nfev == result.nit + 3
        # `nfev` agrees with the number of calls that `f` itself counted.
        assert result.nfev == f.count
        assert result.nit == n

        # The bracket and its function values match the closed-form result.
        xp_assert_close(result.xl, lower)
        xp_assert_close(result.xm, middle)
        xp_assert_close(result.xr, upper)
        xp_assert_close(result.fl, f(lower, *args))
        xp_assert_close(result.fm, f(middle, *args))
        xp_assert_close(result.fr, f(upper, *args))

        self.assert_valid_bracket(result, xp)
        assert result.status == 0
        assert result.success

    def test_flags(self, xp):
        # Each element is constructed to end with a different status code.
        def f(xs, js):
            funcs = [lambda x: (x - 1.5)**2,  # interior minimum: converges
                     lambda x: x,             # decreasing toward xmin = -1
                     lambda x: x,             # unbounded below: maxiter hit
                     lambda x: xp.nan,        # NaN function values
                     lambda x: x**2]          # invalid points and xmin

            return [funcs[int(j)](x) for x, j in zip(xs, js)]

        args = (xp.arange(5, dtype=xp.int64),)
        xl0 = xp.asarray([-1.0, -1.0, -1.0, -1.0, 6.0])
        xm0 = xp.asarray([0.0, 0.0, 0.0, 0.0, 4.0])
        xr0 = xp.asarray([1.0, 1.0, 1.0, 1.0, 2.0])
        xmin = xp.asarray([-xp.inf, -1.0, -xp.inf, -xp.inf, 8.0])

        result = _bracket_minimum(f, xm0, xl0=xl0, xr0=xr0, xmin=xmin,
                                  args=args, maxiter=3)

        reference_flags = xp.asarray([eim._ECONVERGED, _ELIMITS,
                                      eim._ECONVERR, eim._EVALUEERR,
                                      eim._EINPUTERR], dtype=xp.int32)
        xp_assert_equal(result.status, reference_flags)

    @pytest.mark.parametrize("minimum", (0.622, [0.622, 0.623]))
    @pytest.mark.parametrize("dtype", ("float16", "float32", "float64"))
    @pytest.mark.parametrize("xmin", [-5, None])
    @pytest.mark.parametrize("xmax", [5, None])
    def test_dtypes(self, minimum, xmin, xmax, dtype, xp):
        # The bracket and function values keep the dtype of the inputs.
        dtype = getattr(xp, dtype)
        xp_test = array_namespace(xp.asarray(1.))
        xmin = xmin if xmin is None else xp.asarray(xmin, dtype=dtype)
        xmax = xmax if xmax is None else xp.asarray(xmax, dtype=dtype)
        minimum = xp.asarray(minimum, dtype=dtype)

        def f(x, minimum):
            return xp_test.astype((x - minimum)**2, dtype)

        xl0, xm0, xr0 = [-0.01, 0.0, 0.01]
        result = _bracket_minimum(
            f, xp.asarray(xm0, dtype=dtype), xl0=xp.asarray(xl0, dtype=dtype),
            xr0=xp.asarray(xr0, dtype=dtype), xmin=xmin, xmax=xmax, args=(minimum, )
        )
        assert xp.all(result.success)
        assert result.xl.dtype == result.xm.dtype == result.xr.dtype == dtype
        assert result.fl.dtype == result.fm.dtype == result.fr.dtype == dtype

    @pytest.mark.skip_xp_backends(np_only=True, reason="str/object arrays")
    def test_input_validation(self, xp):
        # Invalid arguments raise with informative messages.
        message = '`func` must be callable.'
        with pytest.raises(ValueError, match=message):
            _bracket_minimum(None, -4, xl0=4)

        message = '...must be numeric and real.'
        with pytest.raises(ValueError, match=message):
            _bracket_minimum(lambda x: x**2, xp.asarray(4+1j))
        with pytest.raises(ValueError, match=message):
            _bracket_minimum(lambda x: x**2, xp.asarray(-4), xl0='hello')
        with pytest.raises(ValueError, match=message):
            _bracket_minimum(lambda x: x**2, xp.asarray(-4),
                             xr0='farcical aquatic ceremony')
        with pytest.raises(ValueError, match=message):
            _bracket_minimum(lambda x: x**2, xp.asarray(-4), xmin=np)
        with pytest.raises(ValueError, match=message):
            _bracket_minimum(lambda x: x**2, xp.asarray(-4), xmax=object())
        with pytest.raises(ValueError, match=message):
            _bracket_minimum(lambda x: x**2, xp.asarray(-4), factor=sum)

        message = "All elements of `factor` must be greater than 1."
        with pytest.raises(ValueError, match=message):
            _bracket_minimum(lambda x: x, xp.asarray(-4), factor=0.5)

        message = "shape mismatch: objects cannot be broadcast"
        with pytest.raises(ValueError, match=message):
            _bracket_minimum(lambda x: x**2, xp.asarray([-2, -3]), xl0=[-3, -4, -5])

        message = '`maxiter` must be a non-negative integer.'
        with pytest.raises(ValueError, match=message):
            _bracket_minimum(lambda x: x**2, xp.asarray(-4), xr0=4, maxiter=1.5)
        with pytest.raises(ValueError, match=message):
            _bracket_minimum(lambda x: x**2, xp.asarray(-4), xr0=4, maxiter=-1)
        with pytest.raises(ValueError, match=message):
            _bracket_minimum(lambda x: x**2, xp.asarray(-4), xr0=4, maxiter="ekki")

    @pytest.mark.parametrize("xl0", [0.0, None])
    @pytest.mark.parametrize("xm0", (0.05, 0.1, 0.15))
    @pytest.mark.parametrize("xr0", (0.2, 0.4, 0.6, None))
    # `args` is (a, b) for f(x) = (x - a)**2 + b. The minimizer `a` takes
    # values on both sides of the initial points, at several distances.
    @pytest.mark.parametrize(
        "args",
        (
            (1.2, 0), (-0.5, 0), (0.1, 0), (0.2, 0), (3.6, 0), (21.4, 0),
            (121.6, 0), (5764.1, 0), (-6.4, 0), (-12.9, 0), (-146.2, 0)
        )
    )
    def test_scalar_no_limits(self, xl0, xm0, xr0, args, xp):
        # Without limits, a valid bracket is found for every case.
        f = self.init_f()
        kwargs = self.get_kwargs(xl0=xl0, xr0=xr0, args=tuple(map(xp.asarray, args)))
        result = _bracket_minimum(f, xp.asarray(xm0, dtype=xp.float64), **kwargs)
        self.assert_valid_bracket(result, xp)
        assert result.status == 0
        assert result.success
        assert result.nfev == f.count

    @pytest.mark.parametrize(
        # Initial points with a lower limit `xmin = 0`.
        "xl0,xm0,xr0,xmin",
        (
            # `xl0` supplied.
            (0.5, 0.75, 1.0, 0.0),
            (1.0, 2.5, 4.0, 0.0),
            (2.0, 4.0, 6.0, 0.0),
            (12.0, 16.0, 20.0, 0.0),
            # `xl0` omitted, so its default value is used.
            (None, 0.75, 1.0, 0.0),
            (None, 2.5, 4.0, 0.0),
            (None, 4.0, 6.0, 0.0),
            (None, 16.0, 20.0, 0.0),
        )
    )
    @pytest.mark.parametrize(
        "args", (
            # Minimizer `a` at or extremely close to the limit ...
            (0.0, 0.0),
            (1e-300, 0.0),
            (1e-20, 0.0),
            # ... and progressively farther above it.
            (0.1, 0.0),
            (0.2, 0.0),
            (0.4, 0.0)
        )
    )
    def test_scalar_with_limit_left(self, xl0, xm0, xr0, xmin, args, xp):
        # The search proceeds toward `xmin` and still yields a valid bracket.
        f = self.init_f()
        kwargs = self.get_kwargs(xl0=xl0, xr0=xr0, xmin=xmin,
                                 args=tuple(map(xp.asarray, args)))
        result = _bracket_minimum(f, xp.asarray(xm0), **kwargs)
        self.assert_valid_bracket(result, xp)
        assert result.status == 0
        assert result.success
        assert result.nfev == f.count

    @pytest.mark.parametrize(
        # Initial points with an upper limit `xmax = 1`.
        "xl0,xm0,xr0,xmax",
        (
            # `xr0` supplied.
            (0.2, 0.3, 0.4, 1.0),
            (0.05, 0.075, 0.1, 1.0),
            (-0.2, -0.1, 0.0, 1.0),
            (-21.2, -17.7, -14.2, 1.0),
            # `xr0` omitted, so its default value is used.
            (0.2, 0.3, None, 1.0),
            (0.05, 0.075, None, 1.0),
            (-0.2, -0.1, None, 1.0),
            (-21.2, -17.7, None, 1.0),
        )
    )
    @pytest.mark.parametrize(
        "args", (
            # Minimizer `a` one ulp below the limit ...
            (0.9999999999999999, 0.0),
            # ... and progressively farther below it.
            (0.9, 0.0),
            (0.7, 0.0),
            (0.5, 0.0)
        )
    )
    def test_scalar_with_limit_right(self, xl0, xm0, xr0, xmax, args, xp):
        # The search proceeds toward `xmax` and still yields a valid bracket.
        f = self.init_f()
        args = tuple(xp.asarray(arg, dtype=xp.float64) for arg in args)
        kwargs = self.get_kwargs(xl0=xl0, xr0=xr0, xmax=xmax, args=args)
        result = _bracket_minimum(f, xp.asarray(xm0, dtype=xp.float64), **kwargs)
        self.assert_valid_bracket(result, xp)
        assert result.status == 0
        assert result.success
        assert result.nfev == f.count

    @pytest.mark.parametrize(
        "xl0,xm0,xr0,xmin,xmax,args",
        (
            # Minimizer exactly at `xmax`; initial points to its left.
            (0.2, 0.3, 0.4, None, 1.0, (1.0, 0.0)),
            # Minimizer exactly at `xmin`; initial points to its right.
            (1.4, 1.95, 2.5, 0.3, None, (0.3, 0.0)),
            # Minimizer at `xmax`, farther from the initial points.
            (2.6, 3.25, 3.9, None, 99.4, (99.4, 0)),
            # Minimizer at `xmin`, farther from the initial points.
            (4, 4.5, 5, -26.3, None, (-26.3, 0)),
            # The same four cases with `xl0` and `xr0` omitted.
            (None, 0.3, None, None, 1.0, (1.0, 0.0)),
            (None, 1.95, None, 0.3, None, (0.3, 0.0)),
            (None, 3.25, None, None, 99.4, (99.4, 0)),
            (None, 4.5, None, -26.3, None, (-26.3, 0)),
        )
    )
    def test_minimum_at_boundary_point(self, xl0, xm0, xr0, xmin, xmax, args, xp):
        # When the minimizer lies exactly on a limit, the search stops with
        # status -1 and reports the limit as one endpoint of the bracket.
        f = self.init_f()
        kwargs = self.get_kwargs(xl0=xl0, xr0=xr0, xmin=xmin, xmax=xmax,
                                 args=tuple(map(xp.asarray, args)))
        result = _bracket_minimum(f, xp.asarray(xm0), **kwargs)
        assert result.status == -1
        assert args[0] in (result.xl, result.xr)
        assert result.nfev == f.count

    @pytest.mark.parametrize('shape', [tuple(), (12, ), (3, 4), (3, 2, 2)])
    def test_vectorization(self, shape, xp):
        # A single vectorized call must agree with element-by-element scalar
        # calls (made through `np.vectorize`) for inputs of several shapes.
        a = np.linspace(-0.05, 1.05, 12).reshape(shape) if shape else 0.6
        args = (a, 0.)
        maxiter = 10

        @np.vectorize
        def bracket_minimum_single(xm0, xl0, xr0, xmin, xmax, factor, a):
            return _bracket_minimum(self.init_f(), xm0, xl0=xl0, xr0=xr0, xmin=xmin,
                                    xmax=xmax, factor=factor, maxiter=maxiter,
                                    args=(a, 0.0))

        f = self.init_f()

        rng = np.random.default_rng(2348234)
        xl0 = -rng.random(size=shape)
        xr0 = rng.random(size=shape)
        xm0 = xl0 + rng.random(size=shape) * (xr0 - xl0)
        xmin, xmax = 1e3*xl0, 1e3*xr0
        if shape:  # make a random subset of the elements unbounded
            i = rng.random(size=shape) > 0.5
            xmin[i], xmax[i] = -np.inf, np.inf
        factor = rng.random(size=shape) + 1.5
        refs = bracket_minimum_single(xm0, xl0, xr0, xmin, xmax, factor, a).ravel()

        # Convert every input to the array namespace under test.
        xm0, xl0, xr0, xmin, xmax, factor = (
            xp.asarray(xm0), xp.asarray(xl0), xp.asarray(xr0),
            xp.asarray(xmin), xp.asarray(xmax), xp.asarray(factor)
        )
        args = tuple(xp.asarray(arg, dtype=xp.float64) for arg in args)
        res = _bracket_minimum(f, xm0, xl0=xl0, xr0=xr0, xmin=xmin,
                               xmax=xmax, factor=factor, args=args, maxiter=maxiter)

        attrs = ['xl', 'xm', 'xr', 'fl', 'fm', 'fr', 'success', 'nfev', 'nit']
        for attr in attrs:
            ref_attr = [xp.asarray(getattr(ref, attr)) for ref in refs]
            res_attr = getattr(res, attr)
            xp_assert_close(xp_ravel(res_attr, xp=xp), xp.stack(ref_attr))
            xp_assert_equal(res_attr.shape, shape)

        xp_test = array_namespace(xp.asarray(1.))
        assert res.success.dtype == xp_test.bool
        if shape:
            assert xp.all(res.success[1:-1])
        assert res.status.dtype == xp.int32
        assert res.nfev.dtype == xp.int32
        assert res.nit.dtype == xp.int32
        assert xp.max(res.nit) == f.count - 3
        self.assert_valid_bracket(res, xp)
        xp_assert_close(res.fl, f(res.xl, *args))
        xp_assert_close(res.fm, f(res.xm, *args))
        xp_assert_close(res.fr, f(res.xr, *args))

    def test_special_cases(self, xp):
        # Edge cases and other special situations.
        xp_test = array_namespace(xp.asarray(1.))

        # `f` asserts that it only ever receives arrays of a numeric dtype.
        def f(x):
            assert xp_test.isdtype(x.dtype, "numeric")
            return x ** 98 - 1

        result = _bracket_minimum(f, xp.asarray(-7., dtype=xp.float64), xr0=5)
        assert result.success

        # With `maxiter=0`, the initial points are returned unchanged.
        def f(x):
            return x**2 - 10

        xl0, xm0, xr0 = xp.asarray(-3.), xp.asarray(-1.), xp.asarray(2.)
        result = _bracket_minimum(f, xm0, xl0=xl0, xr0=xr0, maxiter=0)
        xp_assert_equal(result.xl, xl0)
        xp_assert_equal(result.xm, xm0)
        xp_assert_equal(result.xr, xr0)

        # `args` may be a single array rather than a tuple.
        def f(x, c):
            return c*x**2 - 1

        result = _bracket_minimum(f, xp.asarray(-1.), args=xp.asarray(3.))
        assert result.success
        xp_assert_close(result.fl, f(result.xl, 3))

        # The initial points already form a valid bracket of x**2
        # (f = 1, 0.04, 1), so only the three initial evaluations occur.
        f = self.init_f()
        xl0, xm0, xr0 = xp.asarray(-1.0), xp.asarray(-0.2), xp.asarray(1.0)
        args = (xp.asarray(0.), xp.asarray(0.))
        result = _bracket_minimum(f, xm0, xl0=xl0, xr0=xr0, args=args)
        assert f.count == 3

        xp_assert_equal(result.xl, xl0)
        xp_assert_equal(result.xm, xm0)
        xp_assert_equal(result.xr, xr0)
        xp_assert_equal(result.fl, f(xl0, *args))
        xp_assert_equal(result.fm, f(xm0, *args))
        xp_assert_equal(result.fr, f(xr0, *args))

    def test_gh_20562_left(self, xp):
        # Regression test for gh-20562. For x > 0, -1/(c*x) with
        # c = log(xmax) - log(xmin) > 0 is increasing, so the minimum over
        # [xmin, xmax] is at `xmin`, and the bracket must reach it.
        xmin, xmax = xp.asarray(0.21933608), xp.asarray(1.39713606)

        def f(x):
            log_a, log_b = xp.log(xmin), xp.log(xmax)
            return -((log_b - log_a)*x)**-1

        result = _bracket_minimum(f, xp.asarray(0.5535723499480897), xmin=xmin,
                                  xmax=xmax)
        assert xmin == result.xl

    def test_gh_20562_right(self, xp):
        # Mirror image of `test_gh_20562_left`. For x < 0, 1/(c*x) with c > 0
        # is decreasing, so the minimum over [xmin, xmax] is at `xmax`.
        xmin, xmax = xp.asarray(-1.39713606), xp.asarray(-0.21933608)

        def f(x):
            log_a, log_b = xp.log(-xmax), xp.log(-xmin)
            return ((log_b - log_a)*x)**-1

        result = _bracket_minimum(f, xp.asarray(-0.5535723499480897),
                                  xmin=xmin, xmax=xmax)
        assert xmax == result.xr
|
|