|
import os |
|
|
|
from joblib._parallel_backends import ( |
|
LokyBackend, |
|
MultiprocessingBackend, |
|
ThreadingBackend, |
|
) |
|
from joblib.parallel import ( |
|
BACKENDS, |
|
DEFAULT_BACKEND, |
|
EXTERNAL_BACKENDS, |
|
Parallel, |
|
delayed, |
|
parallel_backend, |
|
parallel_config, |
|
) |
|
from joblib.test.common import np, with_multiprocessing, with_numpy |
|
from joblib.test.test_parallel import check_memmap |
|
from joblib.testing import parametrize, raises |
|
|
|
|
|
@parametrize("context", [parallel_config, parallel_backend])
def test_global_parallel_backend(context):
    """Calling the context factory without ``with`` installs a global backend.

    The threading backend must stay active until ``unregister()`` is called.
    Afterwards ``Parallel`` must again use a backend of the same type as before.
    """
    backend_before = Parallel()._backend

    global_ctx = context("threading")
    try:
        active_backend = Parallel()._backend
        assert isinstance(active_backend, ThreadingBackend)
    finally:
        # Always restore the previous configuration, even if the check failed.
        global_ctx.unregister()

    restored_backend = Parallel()._backend
    assert type(restored_backend) is type(backend_before)
|
|
|
|
|
@parametrize("context", [parallel_config, parallel_backend])
def test_external_backends(context):
    """A backend known only through EXTERNAL_BACKENDS is registered on demand.

    ``register_foo`` is the registration hook. It is expected to run when
    the context asks for "foo" and must make "foo" resolve to
    ThreadingBackend. Both global registries are restored afterwards.
    Otherwise the next run (e.g. the other parametrization) would find "foo"
    already in BACKENDS and never exercise the hook.
    """
    hook_calls = []

    def register_foo():
        hook_calls.append("foo")
        BACKENDS["foo"] = ThreadingBackend

    # Start from a clean registry so the lazy-registration path is taken.
    BACKENDS.pop("foo", None)
    EXTERNAL_BACKENDS["foo"] = register_foo
    try:
        with context("foo"):
            assert isinstance(Parallel()._backend, ThreadingBackend)
        assert hook_calls, "the external registration hook was never called"
    finally:
        # Undo both registrations so no global state leaks into other tests.
        del EXTERNAL_BACKENDS["foo"]
        BACKENDS.pop("foo", None)
|
|
|
|
|
@with_numpy
@with_multiprocessing
def test_parallel_config_no_backend(tmpdir):
    """parallel_config settings apply even when it names no backend.

    n_jobs, max_nbytes and temp_folder come from the config only. A
    process-based Parallel must pick them up. With max_nbytes=1 the input
    arrays are expected to be memmapped into ``tmpdir``.
    """
    config_kwargs = dict(n_jobs=2, max_nbytes=1, temp_folder=tmpdir)
    with parallel_config(**config_kwargs):
        with Parallel(prefer="processes") as runner:
            assert isinstance(runner._backend, LokyBackend)
            assert runner.n_jobs == 2

            # check_memmap presumably raises for non-memmap inputs, so this
            # call only succeeds when memmapping is active; the dumped
            # arrays should then show up in the configured temp folder.
            inputs = [np.random.random(10)] * 2
            runner(delayed(check_memmap)(arr) for arr in inputs)
            assert len(os.listdir(tmpdir)) > 0
|
|
|
|
|
@with_numpy
@with_multiprocessing
def test_parallel_config_params_explicit_set(tmpdir):
    """Arguments passed directly to Parallel override parallel_config.

    The config requests n_jobs=3 and max_nbytes=1. Parallel is built with
    n_jobs=2 and max_nbytes="1M", and the explicit values must win.
    Under the "1M" threshold the small arrays are not memmapped, so
    check_memmap raises TypeError.
    """
    with parallel_config(n_jobs=3, max_nbytes=1, temp_folder=tmpdir):
        with Parallel(n_jobs=2, prefer="processes", max_nbytes="1M") as runner:
            assert isinstance(runner._backend, LokyBackend)
            assert runner.n_jobs == 2

            # Memmapping must be effectively disabled by the explicit
            # max_nbytes, so plain ndarrays reach check_memmap.
            expected_msg = "Expected np.memmap instance"
            with raises(TypeError, match=expected_msg):
                runner(
                    delayed(check_memmap)(arr)
                    for arr in [np.random.random(10)] * 2
                )
|
|
|
|
|
@parametrize("param", ["prefer", "require"])
def test_parallel_config_bad_params(param):
    """An invalid ``prefer`` or ``require`` value yields a ValueError."""
    invalid_setting = {param: "wrong"}
    expected_msg = f"{param}=wrong is not a valid"
    with raises(ValueError, match=expected_msg):
        with parallel_config(**invalid_setting):
            Parallel()
|
|
|
|
|
def test_parallel_config_constructor_params():
    """Backend constructor parameters need an appropriate ``backend`` argument.

    Without any backend, both ``inner_max_num_threads`` and extra backend
    parameters are rejected. When the backend is given as a class (not a
    name), extra backend parameters are rejected as well.
    """
    cases = [
        ({"inner_max_num_threads": 1}, "only supported when backend is not None"),
        ({"backend_param": 1}, "only supported when backend is not None"),
        (
            {"backend": BACKENDS[DEFAULT_BACKEND], "backend_param": 1},
            "only supported when backend is a string",
        ),
    ]
    for config_kwargs, expected_msg in cases:
        with raises(ValueError, match=expected_msg):
            with parallel_config(**config_kwargs):
                pass
|
|
|
|
|
def test_parallel_config_nested():
    """Nested parallel_config contexts keep the settings of outer contexts."""
    # Only n_jobs is configured: Parallel keeps the default backend.
    with parallel_config(n_jobs=2):
        runner = Parallel()
        assert isinstance(runner._backend, BACKENDS[DEFAULT_BACKEND])
        assert runner.n_jobs == 2

    # An inner n_jobs setting must not reset the backend chosen outside.
    with parallel_config(backend="threading"), parallel_config(n_jobs=2):
        runner = Parallel()
        assert isinstance(runner._backend, ThreadingBackend)
        assert runner.n_jobs == 2

    # Likewise, the outer verbosity survives the inner context.
    with parallel_config(verbose=100), parallel_config(n_jobs=2):
        runner = Parallel()
        assert runner.verbose == 100
        assert runner.n_jobs == 2
|
|
|
|
|
@with_numpy
@with_multiprocessing
@parametrize(
    "backend",
    ["multiprocessing", "threading", MultiprocessingBackend(), ThreadingBackend()],
)
@parametrize("context", [parallel_config, parallel_backend])
def test_threadpool_limitation_in_child_context_error(context, backend):
    """The multiprocessing and threading backends refuse inner_max_num_threads.

    The rejection must happen whether the backend is passed by name or as an
    instance, and with either context factory.
    """
    expected_pattern = r"does not acc.*inner_max_num_threads"
    with raises(AssertionError, match=expected_pattern):
        context(backend, inner_max_num_threads=1)
|
|
|
|
|
@parametrize("context", [parallel_config, parallel_backend])
def test_parallel_n_jobs_none(context):
    """``Parallel(n_jobs=None)`` defers to the n_jobs of the active context.

    If the context sets n_jobs, Parallel must report that value. Otherwise
    it must report the same n_jobs as a plain ``Parallel()`` in that context.
    """
    with context(backend="threading", n_jobs=2):
        with Parallel(n_jobs=None) as runner:
            assert runner.n_jobs == 2

    with context(backend="threading"):
        baseline_n_jobs = Parallel().n_jobs
        with Parallel(n_jobs=None) as runner:
            assert runner.n_jobs == baseline_n_jobs
|
|
|
|
|
@parametrize("context", [parallel_config, parallel_backend])
def test_parallel_config_n_jobs_none(context):
    """Passing n_jobs=None to a context is an explicit setting, not "unset".

    The inner context's n_jobs=None overrides the outer n_jobs=2. With the
    threading backend, Parallel is then expected to report n_jobs == 1.
    """
    outer_kwargs = {"backend": "threading", "n_jobs": 2}
    inner_kwargs = {"backend": "threading", "n_jobs": None}
    with context(**outer_kwargs), context(**inner_kwargs):
        with Parallel() as runner:
            assert runner.n_jobs == 1
|
|