|
""" |
|
Test the parallel module. |
|
""" |
|
|
import mmap |
|
import os |
|
import re |
|
import sys |
|
import threading |
|
import time |
|
import warnings |
|
import weakref |
|
from contextlib import nullcontext |
|
from math import sqrt |
|
from multiprocessing import TimeoutError |
|
from pickle import PicklingError |
|
from time import sleep |
|
from traceback import format_exception |
|
|
|
import pytest |
|
|
|
import joblib |
|
from joblib import dump, load, parallel |
|
from joblib._multiprocessing_helpers import mp |
|
from joblib.test.common import ( |
|
IS_GIL_DISABLED, |
|
np, |
|
with_multiprocessing, |
|
with_numpy, |
|
) |
|
from joblib.testing import check_subprocess_call, parametrize, raises, skipif, warns |
|
|
|
if mp is not None: |
|
|
|
from joblib.externals.loky import get_reusable_executor |
|
|
|
from queue import Queue |
|
|
|
try: |
|
import posix |
|
except ImportError: |
|
posix = None |
|
|
|
try: |
|
from ._openmp_test_helper.parallel_sum import parallel_sum |
|
except ImportError: |
|
parallel_sum = None |
|
|
|
try: |
|
import distributed |
|
except ImportError: |
|
distributed = None |
|
|
|
from joblib._parallel_backends import ( |
|
LokyBackend, |
|
MultiprocessingBackend, |
|
ParallelBackendBase, |
|
SequentialBackend, |
|
ThreadingBackend, |
|
) |
|
from joblib.parallel import ( |
|
BACKENDS, |
|
Parallel, |
|
cpu_count, |
|
delayed, |
|
effective_n_jobs, |
|
mp, |
|
parallel_backend, |
|
parallel_config, |
|
register_parallel_backend, |
|
) |
|
|
|
RETURN_GENERATOR_BACKENDS = BACKENDS.copy() |
|
RETURN_GENERATOR_BACKENDS.pop("multiprocessing", None) |
|
|
|
ALL_VALID_BACKENDS = [None] + sorted(BACKENDS.keys()) |
|
|
|
ALL_VALID_BACKENDS += [BACKENDS[backend_str]() for backend_str in BACKENDS] |
|
if mp is None: |
|
PROCESS_BACKENDS = [] |
|
else: |
|
PROCESS_BACKENDS = ["multiprocessing", "loky"] |
|
PARALLEL_BACKENDS = PROCESS_BACKENDS + ["threading"] |
|
|
|
if hasattr(mp, "get_context"): |
|
|
|
ALL_VALID_BACKENDS.append(mp.get_context("spawn")) |
|
|
|
|
|
def get_default_backend_instance(): |
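    # Note: despite its name, this helper returns the default backend *class*.
    # It is resolved lazily because tests can change parallel.DEFAULT_BACKEND
    # at runtime.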
|
|
|
|
|
|
|
|
|
return BACKENDS[parallel.DEFAULT_BACKEND] |
|
|
|
|
|
def get_workers(backend): |
|
return getattr(backend, "_pool", getattr(backend, "_workers", None)) |
|
|
|
|
|
def division(x, y): |
|
return x / y |
|
|
|
|
|
def square(x): |
|
return x**2 |
|
|
|
|
|
class MyExceptionWithFinickyInit(Exception): |
|
"""An exception class with non trivial __init__""" |
|
|
|
def __init__(self, a, b, c, d): |
|
pass |
|
|
|
|
|
def exception_raiser(x, custom_exception=False): |
|
if x == 7: |
|
raise ( |
|
MyExceptionWithFinickyInit("a", "b", "c", "d") |
|
if custom_exception |
|
else ValueError |
|
) |
|
return x |
|
|
|
|
|
def interrupt_raiser(x): |
|
time.sleep(0.05) |
|
raise KeyboardInterrupt |
|
|
|
|
|
def f(x, y=0, z=0): |
|
"""A module-level function so that it can be spawn with |
|
multiprocessing. |
|
""" |
|
return x**2 + y + z |
|
|
|
|
|
def _active_backend_type(): |
|
return type(parallel.get_active_backend()[0]) |
|
|
|
|
|
def parallel_func(inner_n_jobs, backend): |
|
return Parallel(n_jobs=inner_n_jobs, backend=backend)( |
|
delayed(square)(i) for i in range(3) |
|
) |
|
|
|
|
|
|
|
def test_cpu_count(): |
|
assert cpu_count() > 0 |
|
|
|
|
|
def test_effective_n_jobs(): |
|
assert effective_n_jobs() > 0 |
|
|
|
|
|
@parametrize("context", [parallel_config, parallel_backend]) |
|
@pytest.mark.parametrize( |
|
"backend_n_jobs, expected_n_jobs", |
|
[(3, 3), (-1, effective_n_jobs(n_jobs=-1)), (None, 1)], |
|
ids=["positive-int", "negative-int", "None"], |
|
) |
|
@with_multiprocessing |
|
def test_effective_n_jobs_None(context, backend_n_jobs, expected_n_jobs): |
|
|
|
|
|
with context("threading", n_jobs=backend_n_jobs): |
|
|
|
|
|
assert effective_n_jobs(n_jobs=None) == expected_n_jobs |
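    # Outside of the context, n_jobs=None falls back to the default of 1.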
|
|
|
assert effective_n_jobs(n_jobs=None) == 1 |
|
|
|
|
|
|
|
|
|
|
|
|
|
@parametrize("backend", ALL_VALID_BACKENDS) |
|
@parametrize("n_jobs", [1, 2, -1, -2]) |
|
@parametrize("verbose", [2, 11, 100]) |
|
def test_simple_parallel(backend, n_jobs, verbose): |
|
assert [square(x) for x in range(5)] == Parallel( |
|
n_jobs=n_jobs, backend=backend, verbose=verbose |
|
)(delayed(square)(x) for x in range(5)) |
|
|
|
|
|
@parametrize("backend", ALL_VALID_BACKENDS) |
|
@parametrize("n_jobs", [1, 2]) |
|
def test_parallel_pretty_print(backend, n_jobs): |
|
n_tasks = 100 |
|
pattern = re.compile(r"(Done\s+\d+ out of \d+ \|)") |
|
|
|
class ParallelLog(Parallel): |
|
messages = [] |
|
|
|
def _print(self, msg): |
|
self.messages.append(msg) |
|
|
|
executor = ParallelLog(n_jobs=n_jobs, backend=backend, verbose=10000) |
|
executor([delayed(f)(i) for i in range(n_tasks)]) |
|
lens = set() |
|
for message in executor.messages: |
|
if s := pattern.search(message): |
|
a, b = s.span() |
|
lens.add(b - a) |
|
assert len(lens) == 1 |
|
|
|
|
|
@parametrize("backend", ALL_VALID_BACKENDS) |
|
def test_main_thread_renamed_no_warning(backend, monkeypatch): |
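    # Check that no false-positive warning is raised when the main thread is
    # renamed: some backends warn when they detect that they are not running
    # from the main thread.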
|
|
|
|
|
|
|
|
|
monkeypatch.setattr( |
|
target=threading.current_thread(), |
|
name="name", |
|
value="some_new_name_for_the_main_thread", |
|
) |
|
|
|
with warnings.catch_warnings(record=True) as warninfo: |
|
results = Parallel(n_jobs=2, backend=backend)( |
|
delayed(square)(x) for x in range(3) |
|
) |
|
assert results == [0, 1, 4] |
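    # Filter out warnings that are unrelated to the thread renaming, such as
    # loky worker timeouts and deprecation warnings.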
|
|
|
|
|
|
|
|
|
warninfo = [ |
|
w |
|
for w in warninfo |
|
if "worker timeout" not in str(w.message) |
|
and not isinstance(w.message, DeprecationWarning) |
|
] |
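    # The multiprocessing backend can also warn that calling fork() in a
    # multi-threaded process may lead to deadlocks; this is unrelated to the
    # renaming of the main thread.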
|
|
|
|
|
|
|
|
|
if backend in [None, "multiprocessing"] or isinstance( |
|
backend, MultiprocessingBackend |
|
): |
|
message_part = "multi-threaded, use of fork() may lead to deadlocks" |
|
warninfo = [w for w in warninfo if message_part not in str(w.message)] |
|
|
|
|
|
|
|
|
|
assert len(warninfo) == 0 |
|
|
|
|
|
def _assert_warning_nested(backend, inner_n_jobs, expected): |
|
with warnings.catch_warnings(record=True) as warninfo: |
|
warnings.simplefilter("always") |
|
parallel_func(backend=backend, inner_n_jobs=inner_n_jobs) |
|
|
|
warninfo = [w.message for w in warninfo] |
|
if expected: |
|
if warninfo: |
|
warnings_are_correct = all( |
|
"backed parallel loops cannot" in each.args[0] for each in warninfo |
|
) |
|
|
|
|
|
warnings_have_the_right_length = ( |
|
len(warninfo) >= 1 if IS_GIL_DISABLED else len(warninfo) == 1 |
|
) |
|
return warnings_are_correct and warnings_have_the_right_length |
|
|
|
return False |
|
else: |
|
assert not warninfo |
|
return True |
|
|
|
|
|
@with_multiprocessing |
|
@parametrize( |
|
"parent_backend,child_backend,expected", |
|
[ |
|
("loky", "multiprocessing", True), |
|
("loky", "loky", False), |
|
("multiprocessing", "multiprocessing", True), |
|
("multiprocessing", "loky", True), |
|
("threading", "multiprocessing", True), |
|
("threading", "loky", True), |
|
], |
|
) |
|
def test_nested_parallel_warnings(parent_backend, child_backend, expected): |
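    # No warning is expected from the nested calls when inner_n_jobs=1.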
|
|
|
Parallel(n_jobs=2, backend=parent_backend)( |
|
delayed(_assert_warning_nested)( |
|
backend=child_backend, inner_n_jobs=1, expected=False |
|
) |
|
for _ in range(5) |
|
) |
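    # With inner_n_jobs=2, whether a "cannot be nested" warning is raised
    # depends on the parent/child backend combination (see the parametrize
    # values above).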
|
|
|
|
|
res = Parallel(n_jobs=2, backend=parent_backend)( |
|
delayed(_assert_warning_nested)( |
|
backend=child_backend, inner_n_jobs=2, expected=expected |
|
) |
|
for _ in range(5) |
|
) |
|
|
|
|
|
|
|
if parent_backend == "threading": |
|
assert any(res) |
|
else: |
|
assert all(res) |
|
|
|
|
|
@with_multiprocessing |
|
@parametrize("backend", ["loky", "multiprocessing", "threading"]) |
|
def test_background_thread_parallelism(backend): |
|
is_run_parallel = [False] |
|
|
|
def background_thread(is_run_parallel): |
|
with warnings.catch_warnings(record=True) as warninfo: |
|
Parallel(n_jobs=2)(delayed(sleep)(0.1) for _ in range(4)) |
|
print(len(warninfo)) |
|
is_run_parallel[0] = len(warninfo) == 0 |
|
|
|
t = threading.Thread(target=background_thread, args=(is_run_parallel,)) |
|
t.start() |
|
t.join() |
|
assert is_run_parallel[0] |
|
|
|
|
|
def nested_loop(backend): |
|
Parallel(n_jobs=2, backend=backend)(delayed(square)(0.01) for _ in range(2)) |
|
|
|
|
|
@parametrize("child_backend", BACKENDS) |
|
@parametrize("parent_backend", BACKENDS) |
|
def test_nested_loop(parent_backend, child_backend): |
|
Parallel(n_jobs=2, backend=parent_backend)( |
|
delayed(nested_loop)(child_backend) for _ in range(2) |
|
) |
|
|
|
|
|
def raise_exception(backend): |
|
raise ValueError |
|
|
|
|
|
@with_multiprocessing |
|
def test_nested_loop_with_exception_with_loky(): |
|
with raises(ValueError): |
|
with Parallel(n_jobs=2, backend="loky") as parallel: |
|
parallel([delayed(nested_loop)("loky"), delayed(raise_exception)("loky")]) |
|
|
|
|
|
def test_mutate_input_with_threads(): |
|
"""Input is mutable when using the threading backend""" |
|
q = Queue(maxsize=5) |
|
Parallel(n_jobs=2, backend="threading")(delayed(q.put)(1) for _ in range(5)) |
|
assert q.full() |
|
|
|
|
|
@parametrize("n_jobs", [1, 2, 3]) |
|
def test_parallel_kwargs(n_jobs): |
|
"""Check the keyword argument processing of pmap.""" |
|
lst = range(10) |
|
assert [f(x, y=1) for x in lst] == Parallel(n_jobs=n_jobs)( |
|
delayed(f)(x, y=1) for x in lst |
|
) |
|
|
|
|
|
@parametrize("backend", PARALLEL_BACKENDS) |
|
def test_parallel_as_context_manager(backend): |
|
lst = range(10) |
|
expected = [f(x, y=1) for x in lst] |
|
|
|
with Parallel(n_jobs=4, backend=backend) as p: |
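        # While the context manager is active, the backend keeps its pool of
        # workers alive so that consecutive calls can reuse them.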
|
|
|
|
|
managed_backend = p._backend |
|
|
|
|
|
|
|
assert expected == p(delayed(f)(x, y=1) for x in lst) |
|
assert expected == p(delayed(f)(x, y=1) for x in lst) |
|
|
|
|
|
if mp is not None: |
|
assert get_workers(managed_backend) is get_workers(p._backend) |
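    # Exiting the context manager terminates the pool of workers managed by
    # the backend.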
|
|
|
|
|
|
|
if mp is not None: |
|
assert get_workers(p._backend) is None |
|
|
|
|
|
assert expected == p(delayed(f)(x, y=1) for x in lst) |
|
if mp is not None: |
|
assert get_workers(p._backend) is None |
|
|
|
|
|
@with_multiprocessing |
|
def test_parallel_pickling(): |
|
"""Check that pmap captures the errors when it is passed an object |
|
that cannot be pickled. |
|
""" |
|
|
|
class UnpicklableObject(object): |
|
def __reduce__(self): |
|
raise RuntimeError("123") |
|
|
|
with raises(PicklingError, match=r"the task to send"): |
|
Parallel(n_jobs=2, backend="loky")( |
|
delayed(id)(UnpicklableObject()) for _ in range(10) |
|
) |
|
|
|
|
|
@with_numpy |
|
@with_multiprocessing |
|
@parametrize("byteorder", ["<", ">", "="]) |
|
@parametrize("max_nbytes", [1, "1M"]) |
|
def test_parallel_byteorder_corruption(byteorder, max_nbytes): |
|
def inspect_byteorder(x): |
|
return x, x.dtype.byteorder |
|
|
|
x = np.arange(6).reshape((2, 3)).view(f"{byteorder}i4") |
|
|
|
initial_np_byteorder = x.dtype.byteorder |
|
|
|
result = Parallel(n_jobs=2, backend="loky", max_nbytes=max_nbytes)( |
|
delayed(inspect_byteorder)(x) for _ in range(3) |
|
) |
|
|
|
for x_returned, byteorder_in_worker in result: |
|
assert byteorder_in_worker == initial_np_byteorder |
|
assert byteorder_in_worker == x_returned.dtype.byteorder |
|
np.testing.assert_array_equal(x, x_returned) |
|
|
|
|
|
@parametrize("backend", PARALLEL_BACKENDS) |
|
def test_parallel_timeout_success(backend): |
|
|
|
assert ( |
|
len( |
|
Parallel(n_jobs=2, backend=backend, timeout=30)( |
|
delayed(sleep)(0.001) for x in range(10) |
|
) |
|
) |
|
== 10 |
|
) |
|
|
|
|
|
@with_multiprocessing |
|
@parametrize("backend", PARALLEL_BACKENDS) |
|
def test_parallel_timeout_fail(backend): |
|
|
|
with raises(TimeoutError): |
|
Parallel(n_jobs=2, backend=backend, timeout=0.01)( |
|
delayed(sleep)(10) for x in range(10) |
|
) |
|
|
|
|
|
@with_multiprocessing |
|
@parametrize("backend", set(RETURN_GENERATOR_BACKENDS) - {"sequential"}) |
|
@parametrize("return_as", ["generator", "generator_unordered"]) |
|
def test_parallel_timeout_fail_with_generator(backend, return_as): |
|
|
|
|
|
with raises(TimeoutError): |
|
list( |
|
Parallel(n_jobs=2, backend=backend, return_as=return_as, timeout=0.1)( |
|
delayed(sleep)(10) for x in range(10) |
|
) |
|
) |
|
|
|
|
|
list( |
|
Parallel(n_jobs=2, backend=backend, return_as=return_as, timeout=10)( |
|
delayed(sleep)(0.01) for x in range(10) |
|
) |
|
) |
|
|
|
|
|
@with_multiprocessing |
|
@parametrize("backend", PROCESS_BACKENDS) |
|
def test_error_capture(backend): |
|
|
|
|
|
if mp is not None: |
|
with raises(ZeroDivisionError): |
|
Parallel(n_jobs=2, backend=backend)( |
|
[delayed(division)(x, y) for x, y in zip((0, 1), (1, 0))] |
|
) |
|
|
|
with raises(KeyboardInterrupt): |
|
Parallel(n_jobs=2, backend=backend)( |
|
[delayed(interrupt_raiser)(x) for x in (1, 0)] |
|
) |
|
|
|
|
|
with Parallel(n_jobs=2, backend=backend) as parallel: |
|
assert get_workers(parallel._backend) is not None |
|
original_workers = get_workers(parallel._backend) |
|
|
|
with raises(ZeroDivisionError): |
|
parallel([delayed(division)(x, y) for x, y in zip((0, 1), (1, 0))]) |
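        # The managed pool should still be usable after the exception, with a
        # fresh set of workers.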
|
|
|
|
|
|
|
assert get_workers(parallel._backend) is not None |
|
|
|
|
|
assert get_workers(parallel._backend) is not original_workers |
|
|
|
assert [f(x, y=1) for x in range(10)] == parallel( |
|
delayed(f)(x, y=1) for x in range(10) |
|
) |
|
|
|
original_workers = get_workers(parallel._backend) |
|
with raises(KeyboardInterrupt): |
|
parallel([delayed(interrupt_raiser)(x) for x in (1, 0)]) |
|
|
|
|
|
assert get_workers(parallel._backend) is not None |
|
|
|
|
|
assert get_workers(parallel._backend) is not original_workers |
|
|
|
assert [f(x, y=1) for x in range(10)] == parallel( |
|
delayed(f)(x, y=1) for x in range(10) |
|
), ( |
|
parallel._iterating, |
|
parallel.n_completed_tasks, |
|
parallel.n_dispatched_tasks, |
|
parallel._aborting, |
|
) |
|
|
|
|
|
|
|
assert get_workers(parallel._backend) is None |
|
else: |
|
with raises(KeyboardInterrupt): |
|
Parallel(n_jobs=2)([delayed(interrupt_raiser)(x) for x in (1, 0)]) |
|
|
|
|
|
|
|
with raises(ZeroDivisionError): |
|
Parallel(n_jobs=2)([delayed(division)(x, y) for x, y in zip((0, 1), (1, 0))]) |
|
|
|
with raises(MyExceptionWithFinickyInit): |
|
Parallel(n_jobs=2, verbose=0)( |
|
(delayed(exception_raiser)(i, custom_exception=True) for i in range(30)) |
|
) |
|
|
|
|
|
@with_multiprocessing |
|
@parametrize("backend", BACKENDS) |
|
def test_error_in_task_iterator(backend): |
|
def my_generator(raise_at=0): |
|
for i in range(20): |
|
if i == raise_at: |
|
raise ValueError("Iterator Raising Error") |
|
yield i |
|
|
|
with Parallel(n_jobs=2, backend=backend) as p: |
|
|
|
with raises(ValueError, match="Iterator Raising Error"): |
|
p(delayed(square)(i) for i in my_generator(raise_at=0)) |
|
|
|
|
|
|
|
with raises(ValueError, match="Iterator Raising Error"): |
|
p(delayed(square)(i) for i in my_generator(raise_at=5)) |
|
|
|
|
|
with raises(ValueError, match="Iterator Raising Error"): |
|
p(delayed(square)(i) for i in my_generator(raise_at=19)) |
|
|
|
|
|
def consumer(queue, item): |
|
queue.append("Consumed %s" % item) |
|
|
|
|
|
@parametrize("backend", BACKENDS) |
|
@parametrize( |
|
"batch_size, expected_queue", |
|
[ |
|
( |
|
1, |
|
[ |
|
"Produced 0", |
|
"Consumed 0", |
|
"Produced 1", |
|
"Consumed 1", |
|
"Produced 2", |
|
"Consumed 2", |
|
"Produced 3", |
|
"Consumed 3", |
|
"Produced 4", |
|
"Consumed 4", |
|
"Produced 5", |
|
"Consumed 5", |
|
], |
|
), |
|
( |
|
4, |
|
[ |
|
"Produced 0", |
|
"Produced 1", |
|
"Produced 2", |
|
"Produced 3", |
|
"Consumed 0", |
|
"Consumed 1", |
|
"Consumed 2", |
|
"Consumed 3", |
|
|
|
"Produced 4", |
|
"Produced 5", |
|
"Consumed 4", |
|
"Consumed 5", |
|
], |
|
), |
|
], |
|
) |
|
def test_dispatch_one_job(backend, batch_size, expected_queue): |
|
"""Test that with only one job, Parallel does act as a iterator.""" |
|
queue = list() |
|
|
|
def producer(): |
|
for i in range(6): |
|
queue.append("Produced %i" % i) |
|
yield i |
|
|
|
Parallel(n_jobs=1, batch_size=batch_size, backend=backend)( |
|
delayed(consumer)(queue, x) for x in producer() |
|
) |
|
assert queue == expected_queue |
|
assert len(queue) == 12 |
|
|
|
|
|
@with_multiprocessing |
|
@parametrize("backend", PARALLEL_BACKENDS) |
|
def test_dispatch_multiprocessing(backend): |
|
"""Check that using pre_dispatch Parallel does indeed dispatch items |
|
lazily. |
|
""" |
|
manager = mp.Manager() |
|
queue = manager.list() |
|
|
|
def producer(): |
|
for i in range(6): |
|
queue.append("Produced %i" % i) |
|
yield i |
|
|
|
Parallel(n_jobs=2, batch_size=1, pre_dispatch=3, backend=backend)( |
|
delayed(consumer)(queue, "any") for _ in producer() |
|
) |
|
|
|
queue_contents = list(queue) |
|
assert queue_contents[0] == "Produced 0" |
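    # Only 3 tasks are pre-dispatched out of 6. The 4th task is dispatched
    # only after one of the first 3 tasks has completed.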
|
|
|
|
|
|
|
first_consumption_index = queue_contents[:4].index("Consumed any") |
|
assert first_consumption_index > -1 |
|
|
|
produced_3_index = queue_contents.index("Produced 3") |
|
assert produced_3_index > first_consumption_index |
|
|
|
assert len(queue) == 12 |
|
|
|
|
|
def test_batching_auto_threading(): |
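    # With the low-overhead threading backend, batch_size="auto" is expected
    # to keep the effective batch size at 1 (i.e. no batching).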
|
|
|
|
|
|
|
|
|
with Parallel(n_jobs=2, batch_size="auto", backend="threading") as p: |
|
p(delayed(id)(i) for i in range(5000)) |
|
assert p._backend.compute_batch_size() == 1 |
|
|
|
|
|
@with_multiprocessing |
|
@parametrize("backend", PROCESS_BACKENDS) |
|
def test_batching_auto_subprocesses(backend): |
|
with Parallel(n_jobs=2, batch_size="auto", backend=backend) as p: |
|
p(delayed(id)(i) for i in range(5000)) |
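        # With many short tasks on a process-based backend, auto-batching
        # should settle on a strictly positive batch size.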
|
|
|
|
|
|
|
|
|
assert p._backend.compute_batch_size() > 0 |
|
|
|
|
|
def test_exception_dispatch(): |
|
"""Make sure that exception raised during dispatch are indeed captured""" |
|
with raises(ValueError): |
|
Parallel(n_jobs=2, pre_dispatch=16, verbose=0)( |
|
delayed(exception_raiser)(i) for i in range(30) |
|
) |
|
|
|
|
|
def nested_function_inner(i): |
|
Parallel(n_jobs=2)(delayed(exception_raiser)(j) for j in range(30)) |
|
|
|
|
|
def nested_function_outer(i): |
|
Parallel(n_jobs=2)(delayed(nested_function_inner)(j) for j in range(30)) |
|
|
|
|
|
@with_multiprocessing |
|
@parametrize("backend", PARALLEL_BACKENDS) |
|
@pytest.mark.xfail(reason="https://github.com/joblib/loky/pull/255") |
|
def test_nested_exception_dispatch(backend): |
|
"""Ensure errors for nested joblib cases gets propagated |
|
|
|
We rely on the Python 3 built-in __cause__ system that already |
|
report this kind of information to the user. |
|
""" |
|
with raises(ValueError) as excinfo: |
|
Parallel(n_jobs=2, backend=backend)( |
|
delayed(nested_function_outer)(i) for i in range(30) |
|
) |
|
|
|
|
|
|
|
report_lines = format_exception(excinfo.type, excinfo.value, excinfo.tb) |
|
report = "".join(report_lines) |
|
assert "nested_function_outer" in report |
|
assert "nested_function_inner" in report |
|
assert "exception_raiser" in report |
|
|
|
assert type(excinfo.value) is ValueError |
|
|
|
|
|
class FakeParallelBackend(SequentialBackend): |
|
"""Pretends to run concurrently while running sequentially.""" |
|
|
|
def configure(self, n_jobs=1, parallel=None, **backend_args): |
|
self.n_jobs = self.effective_n_jobs(n_jobs) |
|
self.parallel = parallel |
|
return n_jobs |
|
|
|
def effective_n_jobs(self, n_jobs=1): |
|
if n_jobs < 0: |
|
n_jobs = max(mp.cpu_count() + 1 + n_jobs, 1) |
|
return n_jobs |
|
|
|
|
|
def test_invalid_backend(): |
|
with raises(ValueError, match="Invalid backend:"): |
|
Parallel(backend="unit-testing") |
|
|
|
with raises(ValueError, match="Invalid backend:"): |
|
with parallel_config(backend="unit-testing"): |
|
pass |
|
|
|
with raises(ValueError, match="Invalid backend:"): |
|
with parallel_config(backend="unit-testing"): |
|
pass |
|
|
|
|
|
@parametrize("backend", ALL_VALID_BACKENDS) |
|
def test_invalid_njobs(backend): |
|
with raises(ValueError) as excinfo: |
|
Parallel(n_jobs=0, backend=backend)._initialize_backend() |
|
assert "n_jobs == 0 in Parallel has no meaning" in str(excinfo.value) |
|
|
|
with raises(ValueError) as excinfo: |
|
Parallel(n_jobs=0.5, backend=backend)._initialize_backend() |
|
assert "n_jobs == 0 in Parallel has no meaning" in str(excinfo.value) |
|
|
|
with raises(ValueError) as excinfo: |
|
Parallel(n_jobs="2.3", backend=backend)._initialize_backend() |
|
assert "n_jobs could not be converted to int" in str(excinfo.value) |
|
|
|
with raises(ValueError) as excinfo: |
|
Parallel(n_jobs="invalid_str", backend=backend)._initialize_backend() |
|
assert "n_jobs could not be converted to int" in str(excinfo.value) |
|
|
|
|
|
@with_multiprocessing |
|
@parametrize("backend", PARALLEL_BACKENDS) |
|
@parametrize("n_jobs", ["2", 2.3, 2]) |
|
def test_njobs_converted_to_int(backend, n_jobs): |
|
p = Parallel(n_jobs=n_jobs, backend=backend) |
|
assert p._effective_n_jobs() == 2 |
|
|
|
res = p(delayed(square)(i) for i in range(10)) |
|
assert all(r == square(i) for i, r in enumerate(res)) |
|
|
|
|
|
def test_register_parallel_backend(): |
|
try: |
|
register_parallel_backend("test_backend", FakeParallelBackend) |
|
assert "test_backend" in BACKENDS |
|
assert BACKENDS["test_backend"] == FakeParallelBackend |
|
finally: |
|
del BACKENDS["test_backend"] |
|
|
|
|
|
def test_overwrite_default_backend(): |
|
default_backend_orig = parallel.DEFAULT_BACKEND |
|
assert _active_backend_type() == get_default_backend_instance() |
|
try: |
|
register_parallel_backend("threading", BACKENDS["threading"], make_default=True) |
|
assert _active_backend_type() == ThreadingBackend |
|
finally: |
|
|
|
parallel.DEFAULT_BACKEND = default_backend_orig |
|
assert _active_backend_type() == get_default_backend_instance() |
|
|
|
|
|
@skipif(mp is not None, reason="Only without multiprocessing") |
|
def test_backend_no_multiprocessing(): |
|
with warns(UserWarning, match="joblib backend '.*' is not available on.*"): |
|
Parallel(backend="loky")(delayed(square)(i) for i in range(3)) |
|
|
|
|
|
with parallel_config(backend="loky"): |
|
Parallel()(delayed(square)(i) for i in range(3)) |
|
|
|
|
|
def check_backend_context_manager(context, backend_name): |
|
with context(backend_name, n_jobs=3): |
|
active_backend, active_n_jobs = parallel.get_active_backend() |
|
assert active_n_jobs == 3 |
|
assert effective_n_jobs(3) == 3 |
|
p = Parallel() |
|
assert p.n_jobs == 3 |
|
if backend_name == "multiprocessing": |
|
assert type(active_backend) is MultiprocessingBackend |
|
assert type(p._backend) is MultiprocessingBackend |
|
elif backend_name == "loky": |
|
assert type(active_backend) is LokyBackend |
|
assert type(p._backend) is LokyBackend |
|
elif backend_name == "threading": |
|
assert type(active_backend) is ThreadingBackend |
|
assert type(p._backend) is ThreadingBackend |
|
elif backend_name.startswith("test_"): |
|
assert type(active_backend) is FakeParallelBackend |
|
assert type(p._backend) is FakeParallelBackend |
|
|
|
|
|
all_backends_for_context_manager = PARALLEL_BACKENDS[:] |
|
all_backends_for_context_manager.extend(["test_backend_%d" % i for i in range(3)]) |
|
|
|
|
|
@with_multiprocessing |
|
@parametrize("backend", all_backends_for_context_manager) |
|
@parametrize("context", [parallel_backend, parallel_config]) |
|
def test_backend_context_manager(monkeypatch, backend, context): |
|
if backend not in BACKENDS: |
|
monkeypatch.setitem(BACKENDS, backend, FakeParallelBackend) |
|
|
|
assert _active_backend_type() == get_default_backend_instance() |
|
|
|
check_backend_context_manager(context, backend) |
|
|
|
|
|
assert _active_backend_type() == get_default_backend_instance() |
|
|
|
|
|
Parallel(n_jobs=2, backend="threading")( |
|
delayed(check_backend_context_manager)(context, b) |
|
for b in all_backends_for_context_manager |
|
        if not b.startswith("test_")
|
) |
|
|
|
|
|
assert _active_backend_type() == get_default_backend_instance() |
|
|
|
|
|
class ParameterizedParallelBackend(SequentialBackend): |
|
"""Pretends to run conncurrently while running sequentially.""" |
|
|
|
def __init__(self, param=None): |
|
if param is None: |
|
raise ValueError("param should not be None") |
|
self.param = param |
|
|
|
|
|
@parametrize("context", [parallel_config, parallel_backend]) |
|
def test_parameterized_backend_context_manager(monkeypatch, context): |
|
monkeypatch.setitem(BACKENDS, "param_backend", ParameterizedParallelBackend) |
|
assert _active_backend_type() == get_default_backend_instance() |
|
|
|
with context("param_backend", param=42, n_jobs=3): |
|
active_backend, active_n_jobs = parallel.get_active_backend() |
|
assert type(active_backend) is ParameterizedParallelBackend |
|
assert active_backend.param == 42 |
|
assert active_n_jobs == 3 |
|
p = Parallel() |
|
assert p.n_jobs == 3 |
|
assert p._backend is active_backend |
|
results = p(delayed(sqrt)(i) for i in range(5)) |
|
assert results == [sqrt(i) for i in range(5)] |
|
|
|
|
|
assert _active_backend_type() == get_default_backend_instance() |
|
|
|
|
|
@parametrize("context", [parallel_config, parallel_backend]) |
|
def test_directly_parameterized_backend_context_manager(context): |
|
assert _active_backend_type() == get_default_backend_instance() |
|
|
|
|
|
|
|
with context(ParameterizedParallelBackend(param=43), n_jobs=5): |
|
active_backend, active_n_jobs = parallel.get_active_backend() |
|
assert type(active_backend) is ParameterizedParallelBackend |
|
assert active_backend.param == 43 |
|
assert active_n_jobs == 5 |
|
p = Parallel() |
|
assert p.n_jobs == 5 |
|
assert p._backend is active_backend |
|
results = p(delayed(sqrt)(i) for i in range(5)) |
|
assert results == [sqrt(i) for i in range(5)] |
|
|
|
|
|
assert _active_backend_type() == get_default_backend_instance() |
|
|
|
|
|
def sleep_and_return_pid(): |
|
sleep(0.1) |
|
return os.getpid() |
|
|
|
|
|
def get_nested_pids(): |
|
assert _active_backend_type() == ThreadingBackend |
|
|
|
|
|
assert Parallel()._effective_n_jobs() == 1 |
|
|
|
|
|
return Parallel(n_jobs=2)(delayed(sleep_and_return_pid)() for _ in range(2)) |
|
|
|
|
|
class MyBackend(joblib._parallel_backends.LokyBackend): |
|
"""Backend to test backward compatibility with older backends""" |
|
|
|
    def get_nested_backend(self):
|
|
|
return super(MyBackend, self).get_nested_backend()[0] |
|
|
|
|
|
register_parallel_backend("back_compat_backend", MyBackend) |
|
|
|
|
|
@with_multiprocessing |
|
@parametrize("backend", ["threading", "loky", "multiprocessing", "back_compat_backend"]) |
|
@parametrize("context", [parallel_config, parallel_backend]) |
|
def test_nested_backend_context_manager(context, backend): |
|
|
|
|
|
|
|
with context(backend): |
|
pid_groups = Parallel(n_jobs=2)(delayed(get_nested_pids)() for _ in range(10)) |
|
for pid_group in pid_groups: |
|
assert len(set(pid_group)) == 1 |
|
|
|
|
|
@with_multiprocessing |
|
@parametrize("n_jobs", [2, -1, None]) |
|
@parametrize("backend", PARALLEL_BACKENDS) |
|
@parametrize("context", [parallel_config, parallel_backend]) |
|
def test_nested_backend_in_sequential(backend, n_jobs, context): |
|
|
|
|
|
|
|
def check_nested_backend(expected_backend_type, expected_n_job): |
|
|
|
|
|
assert _active_backend_type() == BACKENDS[expected_backend_type] |
|
|
|
|
|
|
|
expected_n_job = effective_n_jobs(expected_n_job) |
|
assert Parallel()._effective_n_jobs() == expected_n_job |
|
|
|
Parallel(n_jobs=1)( |
|
delayed(check_nested_backend)(parallel.DEFAULT_BACKEND, 1) for _ in range(10) |
|
) |
|
|
|
with context(backend, n_jobs=n_jobs): |
|
Parallel(n_jobs=1)( |
|
delayed(check_nested_backend)(backend, n_jobs) for _ in range(10) |
|
) |
|
|
|
|
|
def check_nesting_level(context, inner_backend, expected_level): |
|
with context(inner_backend) as ctx: |
|
if context is parallel_config: |
|
backend = ctx["backend"] |
|
if context is parallel_backend: |
|
backend = ctx[0] |
|
assert backend.nesting_level == expected_level |
|
|
|
|
|
@with_multiprocessing |
|
@parametrize("outer_backend", PARALLEL_BACKENDS) |
|
@parametrize("inner_backend", PARALLEL_BACKENDS) |
|
@parametrize("context", [parallel_config, parallel_backend]) |
|
def test_backend_nesting_level(context, outer_backend, inner_backend): |
|
|
|
check_nesting_level(context, outer_backend, 0) |
|
|
|
Parallel(n_jobs=2, backend=outer_backend)( |
|
delayed(check_nesting_level)(context, inner_backend, 1) for _ in range(10) |
|
) |
|
|
|
with context(inner_backend, n_jobs=2): |
|
Parallel()( |
|
delayed(check_nesting_level)(context, inner_backend, 1) for _ in range(10) |
|
) |
|
|
|
|
|
@with_multiprocessing |
|
@parametrize("context", [parallel_config, parallel_backend]) |
|
@parametrize("with_retrieve_callback", [True, False]) |
|
def test_retrieval_context(context, with_retrieve_callback): |
|
import contextlib |
|
|
|
class MyBackend(ThreadingBackend): |
|
i = 0 |
|
supports_retrieve_callback = with_retrieve_callback |
|
|
|
@contextlib.contextmanager |
|
def retrieval_context(self): |
|
self.i += 1 |
|
yield |
|
|
|
register_parallel_backend("retrieval", MyBackend) |
|
|
|
def nested_call(n): |
|
return Parallel(n_jobs=2)(delayed(id)(i) for i in range(n)) |
|
|
|
with context("retrieval") as ctx: |
|
Parallel(n_jobs=2)(delayed(nested_call)(i) for i in range(5)) |
|
if context is parallel_config: |
|
assert ctx["backend"].i == 1 |
|
if context is parallel_backend: |
|
assert ctx[0].i == 1 |
|
|
|
|
|
|
|
|
|
|
|
|
|
@parametrize("batch_size", [0, -1, 1.42]) |
|
def test_invalid_batch_size(batch_size): |
|
with raises(ValueError): |
|
Parallel(batch_size=batch_size) |
|
|
|
|
|
@parametrize( |
|
"n_tasks, n_jobs, pre_dispatch, batch_size", |
|
[ |
|
(2, 2, "all", "auto"), |
|
(2, 2, "n_jobs", "auto"), |
|
(10, 2, "n_jobs", "auto"), |
|
(517, 2, "n_jobs", "auto"), |
|
(10, 2, "n_jobs", "auto"), |
|
(10, 4, "n_jobs", "auto"), |
|
(200, 12, "n_jobs", "auto"), |
|
(25, 12, "2 * n_jobs", 1), |
|
(250, 12, "all", 1), |
|
(250, 12, "2 * n_jobs", 7), |
|
(200, 12, "2 * n_jobs", "auto"), |
|
], |
|
) |
|
def test_dispatch_race_condition(n_tasks, n_jobs, pre_dispatch, batch_size): |
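    # Check that results stay correct (no lost or duplicated tasks) for
    # combinations of pre_dispatch and batch_size that stress the dispatch
    # machinery.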
|
|
|
|
|
|
|
params = {"n_jobs": n_jobs, "pre_dispatch": pre_dispatch, "batch_size": batch_size} |
|
expected = [square(i) for i in range(n_tasks)] |
|
results = Parallel(**params)(delayed(square)(i) for i in range(n_tasks)) |
|
assert results == expected |
|
|
|
|
|
@with_multiprocessing |
|
def test_default_mp_context(): |
|
mp_start_method = mp.get_start_method() |
|
p = Parallel(n_jobs=2, backend="multiprocessing") |
|
context = p._backend_kwargs.get("context") |
|
start_method = context.get_start_method() |
|
assert start_method == mp_start_method |
|
|
|
|
|
@with_numpy |
|
@with_multiprocessing |
|
@parametrize("backend", PROCESS_BACKENDS) |
|
def test_no_blas_crash_or_freeze_with_subprocesses(backend): |
|
if backend == "multiprocessing": |
|
|
|
|
|
backend = mp.get_context("spawn") |
|
|
|
|
|
|
|
|
|
|
|
rng = np.random.RandomState(42) |
|
|
|
|
|
|
|
a = rng.randn(1000, 1000) |
|
np.dot(a, a.T) |
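    # The matrix product above initializes the BLAS thread pool in the parent
    # process; check that computing in worker processes neither crashes nor
    # freezes.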
|
|
|
|
|
|
|
Parallel(n_jobs=2, backend=backend)(delayed(np.dot)(a, a.T) for i in range(2)) |
|
|
|
|
|
UNPICKLABLE_CALLABLE_SCRIPT_TEMPLATE_NO_MAIN = """\ |
|
from joblib import Parallel, delayed |
|
|
|
def square(x): |
|
return x ** 2 |
|
|
|
backend = "{}" |
|
if backend == "spawn": |
|
from multiprocessing import get_context |
|
backend = get_context(backend) |
|
|
|
print(Parallel(n_jobs=2, backend=backend)( |
|
delayed(square)(i) for i in range(5))) |
|
""" |
|
|
|
|
|
@with_multiprocessing |
|
@parametrize("backend", PROCESS_BACKENDS) |
|
def test_parallel_with_interactively_defined_functions(backend): |
|
|
|
|
|
if backend == "multiprocessing" and mp.get_start_method() != "fork": |
|
pytest.skip( |
|
"Require fork start method to use interactively defined " |
|
"functions with multiprocessing." |
|
) |
|
code = UNPICKLABLE_CALLABLE_SCRIPT_TEMPLATE_NO_MAIN.format(backend) |
|
check_subprocess_call( |
|
[sys.executable, "-c", code], timeout=10, stdout_regex=r"\[0, 1, 4, 9, 16\]" |
|
) |
|
|
|
|
|
UNPICKLABLE_CALLABLE_SCRIPT_TEMPLATE_MAIN = """\ |
|
import sys |
|
# Make sure that joblib is importable in the subprocess launching this |
|
# script. This is needed in case we run the tests from the joblib root |
|
# folder without having installed joblib |
|
sys.path.insert(0, {joblib_root_folder!r}) |
|
|
|
from joblib import Parallel, delayed |
|
|
|
def run(f, x): |
|
return f(x) |
|
|
|
{define_func} |
|
|
|
if __name__ == "__main__": |
|
backend = "{backend}" |
|
if backend == "spawn": |
|
from multiprocessing import get_context |
|
backend = get_context(backend) |
|
|
|
callable_position = "{callable_position}" |
|
if callable_position == "delayed": |
|
print(Parallel(n_jobs=2, backend=backend)( |
|
delayed(square)(i) for i in range(5))) |
|
elif callable_position == "args": |
|
print(Parallel(n_jobs=2, backend=backend)( |
|
delayed(run)(square, i) for i in range(5))) |
|
else: |
|
print(Parallel(n_jobs=2, backend=backend)( |
|
delayed(run)(f=square, x=i) for i in range(5))) |
|
""" |
|
|
|
SQUARE_MAIN = """\ |
|
def square(x): |
|
return x ** 2 |
|
""" |
|
SQUARE_LOCAL = """\ |
|
def gen_square(): |
|
def square(x): |
|
return x ** 2 |
|
return square |
|
square = gen_square() |
|
""" |
|
SQUARE_LAMBDA = """\ |
|
square = lambda x: x ** 2 |
|
""" |
|
|
|
|
|
@with_multiprocessing |
|
@parametrize("backend", PROCESS_BACKENDS + ([] if mp is None else ["spawn"])) |
|
@parametrize("define_func", [SQUARE_MAIN, SQUARE_LOCAL, SQUARE_LAMBDA]) |
|
@parametrize("callable_position", ["delayed", "args", "kwargs"]) |
|
def test_parallel_with_unpicklable_functions_in_args( |
|
backend, define_func, callable_position, tmpdir |
|
): |
|
if backend in ["multiprocessing", "spawn"] and ( |
|
define_func != SQUARE_MAIN or sys.platform == "win32" |
|
): |
|
pytest.skip("Not picklable with pickle") |
|
code = UNPICKLABLE_CALLABLE_SCRIPT_TEMPLATE_MAIN.format( |
|
define_func=define_func, |
|
backend=backend, |
|
callable_position=callable_position, |
|
joblib_root_folder=os.path.dirname(os.path.dirname(joblib.__file__)), |
|
) |
|
code_file = tmpdir.join("unpicklable_func_script.py") |
|
code_file.write(code) |
|
check_subprocess_call( |
|
[sys.executable, code_file.strpath], |
|
timeout=10, |
|
stdout_regex=r"\[0, 1, 4, 9, 16\]", |
|
) |
|
|
|
|
|
INTERACTIVE_DEFINED_FUNCTION_AND_CLASS_SCRIPT_CONTENT = """\ |
|
import sys |
|
import faulthandler |
|
# Make sure that joblib is importable in the subprocess launching this |
|
# script. This is needed in case we run the tests from the joblib root |
|
# folder without having installed joblib |
|
sys.path.insert(0, {joblib_root_folder!r}) |
|
|
|
from joblib import Parallel, delayed |
|
from functools import partial |
|
|
|
class MyClass: |
|
'''Class defined in the __main__ namespace''' |
|
def __init__(self, value): |
|
self.value = value |
|
|
|
|
|
def square(x, ignored=None, ignored2=None): |
|
'''Function defined in the __main__ namespace''' |
|
return x.value ** 2 |
|
|
|
|
|
square2 = partial(square, ignored2='something') |
|
|
|
# Here, we do not need the `if __name__ == "__main__":` safeguard when |
|
# using the default `loky` backend (even on Windows). |
|
|
|
# To make debugging easier |
|
faulthandler.dump_traceback_later(30, exit=True) |
|
|
|
# The following baroque function call is meant to check that joblib |
|
# introspection rightfully uses cloudpickle instead of the (faster) pickle |
|
# module of the standard library when necessary. In particular cloudpickle is |
|
# necessary for functions and instances of classes interactively defined in the |
|
# __main__ module. |
|
|
|
print(Parallel(backend="loky", n_jobs=2)( |
|
delayed(square2)(MyClass(i), ignored=[dict(a=MyClass(1))]) |
|
for i in range(5) |
|
)) |
|
""".format(joblib_root_folder=os.path.dirname(os.path.dirname(joblib.__file__))) |
|
|
|
|
|
@with_multiprocessing |
|
def test_parallel_with_interactively_defined_functions_loky(tmpdir): |
|
|
|
|
|
|
|
script = tmpdir.join("joblib_interactively_defined_function.py") |
|
script.write(INTERACTIVE_DEFINED_FUNCTION_AND_CLASS_SCRIPT_CONTENT) |
|
check_subprocess_call( |
|
[sys.executable, script.strpath], |
|
stdout_regex=r"\[0, 1, 4, 9, 16\]", |
|
timeout=None, |
|
) |
|
|
|
|
|
INTERACTIVELY_DEFINED_SUBCLASS_WITH_METHOD_SCRIPT_CONTENT = """\ |
|
import sys |
|
# Make sure that joblib is importable in the subprocess launching this |
|
# script. This is needed in case we run the tests from the joblib root |
|
# folder without having installed joblib |
|
sys.path.insert(0, {joblib_root_folder!r}) |
|
|
|
from joblib import Parallel, delayed, hash |
|
import multiprocessing as mp |
|
mp.util.log_to_stderr(5) |
|
|
|
class MyList(list): |
|
    '''MyList is interactively defined but MyList.append is a built-in'''
|
def __hash__(self): |
|
# XXX: workaround limitation in cloudpickle |
|
return hash(self).__hash__() |
|
|
|
l = MyList() |
|
|
|
print(Parallel(backend="loky", n_jobs=2)( |
|
delayed(l.append)(i) for i in range(3) |
|
)) |
|
""".format(joblib_root_folder=os.path.dirname(os.path.dirname(joblib.__file__))) |
|
|
|
|
|
@with_multiprocessing |
|
def test_parallel_with_interactively_defined_bound_method_loky(tmpdir): |
|
script = tmpdir.join("joblib_interactive_bound_method_script.py") |
|
script.write(INTERACTIVELY_DEFINED_SUBCLASS_WITH_METHOD_SCRIPT_CONTENT) |
|
check_subprocess_call( |
|
[sys.executable, script.strpath], |
|
stdout_regex=r"\[None, None, None\]", |
|
stderr_regex=r"LokyProcess", |
|
timeout=15, |
|
) |
|
|
|
|
|
def test_parallel_with_exhausted_iterator(): |
|
exhausted_iterator = iter([]) |
|
assert Parallel(n_jobs=2)(exhausted_iterator) == [] |
|
|
|
|
|
def check_memmap(a): |
|
if not isinstance(a, np.memmap): |
|
raise TypeError("Expected np.memmap instance, got %r", type(a)) |
|
return a.copy() |
|
|
|
|
|
@with_numpy |
|
@with_multiprocessing |
|
@parametrize("backend", PROCESS_BACKENDS) |
|
def test_auto_memmap_on_arrays_from_generator(backend): |
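    # Check that arrays yielded one at a time by a generator are correctly
    # auto-memmapped (max_nbytes=1 forces memmapping) and round-tripped
    # through the workers.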
|
|
|
|
|
|
|
|
|
def generate_arrays(n): |
|
for i in range(n): |
|
yield np.ones(10, dtype=np.float32) * i |
|
|
|
|
|
|
|
results = Parallel(n_jobs=2, max_nbytes=1, backend=backend)( |
|
delayed(check_memmap)(a) for a in generate_arrays(100) |
|
) |
|
for result, expected in zip(results, generate_arrays(len(results))): |
|
np.testing.assert_array_equal(expected, result) |
|
|
|
|
|
|
|
|
|
results = Parallel(n_jobs=4, max_nbytes=1, backend=backend)( |
|
delayed(check_memmap)(a) for a in generate_arrays(100) |
|
) |
|
for result, expected in zip(results, generate_arrays(len(results))): |
|
np.testing.assert_array_equal(expected, result) |
|
|
|
|
|
def identity(arg): |
|
return arg |
|
|
|
|
|
@with_numpy |
|
@with_multiprocessing |
|
def test_memmap_with_big_offset(tmpdir): |
|
fname = tmpdir.join("test.mmap").strpath |
|
size = mmap.ALLOCATIONGRANULARITY |
|
obj = [np.zeros(size, dtype="uint8"), np.ones(size, dtype="uint8")] |
|
dump(obj, fname) |
|
memmap = load(fname, mmap_mode="r") |
|
(result,) = Parallel(n_jobs=2)(delayed(identity)(memmap) for _ in [0]) |
|
assert isinstance(memmap[1], np.memmap) |
|
assert memmap[1].offset > size |
|
np.testing.assert_array_equal(obj, result) |
|
|
|
|
|
def test_warning_about_timeout_not_supported_by_backend(): |
|
with warnings.catch_warnings(record=True) as warninfo: |
|
Parallel(n_jobs=1, timeout=1)(delayed(square)(i) for i in range(50)) |
|
assert len(warninfo) == 1 |
|
w = warninfo[0] |
|
assert isinstance(w.message, UserWarning) |
|
assert str(w.message) == ( |
|
"The backend class 'SequentialBackend' does not support timeout. " |
|
"You have set 'timeout=1' in Parallel but the 'timeout' parameter " |
|
"will not be used." |
|
) |
|
|
|
|
|
def set_list_value(input_list, index, value): |
|
input_list[index] = value |
|
return value |
|
|
|
|
|
@pytest.mark.parametrize("n_jobs", [1, 2, 4]) |
|
def test_parallel_return_order_with_return_as_generator_parameter(n_jobs): |
|
|
|
|
|
|
|
input_list = [0] * 5 |
|
result = Parallel(n_jobs=n_jobs, return_as="generator", backend="threading")( |
|
delayed(set_list_value)(input_list, i, i) for i in range(5) |
|
) |
|
|
|
|
|
result = list(result) |
|
|
|
assert all(v == r for v, r in zip(input_list, result)) |
|
|
|
|
|
def _sqrt_with_delay(e, delay): |
|
if delay: |
|
sleep(30) |
|
return sqrt(e) |
|
|
|
|
|
|
|
|
|
|
|
|
|
def _test_parallel_unordered_generator_returns_fastest_first(backend, n_jobs): |
|
|
|
|
|
|
|
result = Parallel(n_jobs=n_jobs, return_as="generator_unordered", backend=backend)( |
|
delayed(_sqrt_with_delay)(i**2, (i == 1)) for i in range(10) |
|
) |
|
|
|
quickly_returned = sorted(next(result) for _ in range(9)) |
|
|
|
expected_quickly_returned = [0] + list(range(2, 10)) |
|
|
|
assert all(v == r for v, r in zip(expected_quickly_returned, quickly_returned)) |
|
|
|
del result |
|
|
|
|
|
@pytest.mark.parametrize("n_jobs", [2, 4]) |
|
|
|
|
|
|
|
@with_multiprocessing |
|
@parametrize("backend", set(RETURN_GENERATOR_BACKENDS) - {"sequential"}) |
|
def test_parallel_unordered_generator_returns_fastest_first(backend, n_jobs): |
|
_test_parallel_unordered_generator_returns_fastest_first(backend, n_jobs) |
|
|
|
|
|
@parametrize("backend", ALL_VALID_BACKENDS) |
|
@parametrize("n_jobs", [1, 2, -2, -1]) |
|
def test_abort_backend(n_jobs, backend): |
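    # The first task raises a TypeError (time.sleep("a")); the remaining
    # long-running tasks should be aborted quickly instead of running to
    # completion.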
|
delays = ["a"] + [10] * 100 |
|
with raises(TypeError): |
|
t_start = time.time() |
|
Parallel(n_jobs=n_jobs, backend=backend)(delayed(time.sleep)(i) for i in delays) |
|
dt = time.time() - t_start |
|
assert dt < 20 |
|
|
|
|
|
def get_large_object(arg): |
|
result = np.ones(int(5 * 1e5), dtype=bool) |
|
result[0] = False |
|
return result |
|
|
|
|
|
|
|
|
|
|
|
|
|
def _test_deadlock_with_generator(backend, return_as, n_jobs): |
|
|
|
|
|
with Parallel(n_jobs=n_jobs, backend=backend, return_as=return_as) as parallel: |
|
result = parallel(delayed(get_large_object)(i) for i in range(10)) |
|
next(result) |
|
next(result) |
|
del result |
|
|
|
|
|
@with_numpy |
|
@parametrize("backend", RETURN_GENERATOR_BACKENDS) |
|
@parametrize("return_as", ["generator", "generator_unordered"]) |
|
@parametrize("n_jobs", [1, 2, -2, -1]) |
|
def test_deadlock_with_generator(backend, return_as, n_jobs): |
|
_test_deadlock_with_generator(backend, return_as, n_jobs) |
|
|
|
|
|
@parametrize("backend", RETURN_GENERATOR_BACKENDS) |
|
@parametrize("return_as", ["generator", "generator_unordered"]) |
|
@parametrize("n_jobs", [1, 2, -2, -1]) |
|
def test_multiple_generator_call(backend, return_as, n_jobs): |
|
|
|
|
|
|
|
with raises(RuntimeError, match="This Parallel instance is already running"): |
|
parallel = Parallel(n_jobs, backend=backend, return_as=return_as) |
|
g = parallel(delayed(sleep)(1) for _ in range(10)) |
|
t_start = time.time() |
|
gen2 = parallel(delayed(id)(i) for i in range(100)) |
|
|
|
|
|
assert time.time() - t_start < 2, ( |
|
"The error should be raised immediately when submitting a new task " |
|
"but it took more than 2s." |
|
) |
|
|
|
del g |
|
|
|
|
|
@parametrize("backend", RETURN_GENERATOR_BACKENDS) |
|
@parametrize("return_as", ["generator", "generator_unordered"]) |
|
@parametrize("n_jobs", [1, 2, -2, -1]) |
|
def test_multiple_generator_call_managed(backend, return_as, n_jobs): |
|
|
|
|
|
|
|
with Parallel(n_jobs, backend=backend, return_as=return_as) as parallel: |
|
g = parallel(delayed(sleep)(10) for _ in range(10)) |
|
t_start = time.time() |
|
with raises(RuntimeError, match="This Parallel instance is already running"): |
|
g2 = parallel(delayed(id)(i) for i in range(100)) |
|
|
|
|
|
assert time.time() - t_start < 2, ( |
|
"The error should be raised immediately when submitting a new task " |
|
"but it took more than 2s." |
|
) |
|
|
|
del g |
|
|
|
|
|
@parametrize("backend", RETURN_GENERATOR_BACKENDS) |
|
@parametrize("return_as_1", ["generator", "generator_unordered"]) |
|
@parametrize("return_as_2", ["generator", "generator_unordered"]) |
|
@parametrize("n_jobs", [1, 2, -2, -1]) |
|
def test_multiple_generator_call_separated(backend, return_as_1, return_as_2, n_jobs): |
|
|
|
g = Parallel(n_jobs, backend=backend, return_as=return_as_1)( |
|
delayed(sqrt)(i**2) for i in range(10) |
|
) |
|
g2 = Parallel(n_jobs, backend=backend, return_as=return_as_2)( |
|
delayed(sqrt)(i**2) for i in range(10, 20) |
|
) |
|
|
|
if return_as_1 == "generator_unordered": |
|
g = sorted(g) |
|
|
|
if return_as_2 == "generator_unordered": |
|
g2 = sorted(g2) |
|
|
|
assert all(res == i for res, i in zip(g, range(10))) |
|
assert all(res == i for res, i in zip(g2, range(10, 20))) |
|
|
|
|
|
@parametrize( |
|
"backend, error", |
|
[ |
|
("loky", True), |
|
("threading", False), |
|
("sequential", False), |
|
], |
|
) |
|
@parametrize("return_as_1", ["generator", "generator_unordered"]) |
|
@parametrize("return_as_2", ["generator", "generator_unordered"]) |
|
def test_multiple_generator_call_separated_gc(backend, return_as_1, return_as_2, error): |
|
if (backend == "loky") and (mp is None): |
|
pytest.skip("Requires multiprocessing") |
|
|
|
|
|
|
|
parallel = Parallel(2, backend=backend, return_as=return_as_1) |
|
g = parallel(delayed(sleep)(10) for i in range(10)) |
|
g_wr = weakref.finalize(g, lambda: print("Generator collected")) |
|
ctx = ( |
|
raises(RuntimeError, match="The executor underlying Parallel") |
|
if error |
|
else nullcontext() |
|
) |
|
with ctx: |
|
|
|
|
|
|
|
|
|
t_start = time.time() |
|
g = Parallel(2, backend=backend, return_as=return_as_2)( |
|
delayed(sqrt)(i**2) for i in range(10, 20) |
|
) |
|
|
|
if return_as_2 == "generator_unordered": |
|
g = sorted(g) |
|
|
|
assert all(res == i for res, i in zip(g, range(10, 20))) |
|
|
|
assert time.time() - t_start < 5 |
|
|
|
|
|
retry = 0 |
|
while g_wr.alive and retry < 3: |
|
retry += 1 |
|
time.sleep(0.5) |
|
assert time.time() - t_start < 5 |
|
|
|
if parallel._effective_n_jobs() != 1: |
|
|
|
|
|
assert parallel._aborting |
|
|
|
|
|
@with_numpy |
|
@with_multiprocessing |
|
@parametrize("backend", PROCESS_BACKENDS) |
|
def test_memmapping_leaks(backend, tmpdir): |
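    # Check that the temporary folder used to memmap the input arrays is
    # cleaned up, both in managed (context manager) and unmanaged usage.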
|
|
|
|
|
tmpdir = tmpdir.strpath |
|
|
|
|
|
|
|
with Parallel(n_jobs=2, max_nbytes=1, backend=backend, temp_folder=tmpdir) as p: |
|
p(delayed(check_memmap)(a) for a in [np.random.random(10)] * 2) |
|
|
|
|
|
assert len(os.listdir(tmpdir)) > 0 |
|
|
|
|
|
|
|
for _ in range(100): |
|
if not os.listdir(tmpdir): |
|
break |
|
sleep(0.1) |
|
else: |
|
raise AssertionError("temporary directory of Parallel was not removed") |
|
|
|
|
|
p = Parallel(n_jobs=2, max_nbytes=1, backend=backend) |
|
p(delayed(check_memmap)(a) for a in [np.random.random(10)] * 2) |
|
|
|
for _ in range(100): |
|
if not os.listdir(tmpdir): |
|
break |
|
sleep(0.1) |
|
else: |
|
raise AssertionError("temporary directory of Parallel was not removed") |
|
|
|
|
|
@parametrize( |
|
"backend", ([None, "threading"] if mp is None else [None, "loky", "threading"]) |
|
) |
|
def test_lambda_expression(backend): |
|
|
|
results = Parallel(n_jobs=2, backend=backend)( |
|
delayed(lambda x: x**2)(i) for i in range(10) |
|
) |
|
assert results == [i**2 for i in range(10)] |
|
|
|
|
|
@with_multiprocessing |
|
@parametrize("backend", PROCESS_BACKENDS) |
|
def test_backend_batch_statistics_reset(backend): |
|
"""Test that a parallel backend correctly resets its batch statistics.""" |
|
n_jobs = 2 |
|
n_inputs = 500 |
|
task_time = 2.0 / n_inputs |
|
|
|
p = Parallel(verbose=10, n_jobs=n_jobs, backend=backend) |
|
p(delayed(time.sleep)(task_time) for i in range(n_inputs)) |
|
assert p._backend._effective_batch_size == p._backend._DEFAULT_EFFECTIVE_BATCH_SIZE |
|
assert ( |
|
p._backend._smoothed_batch_duration |
|
== p._backend._DEFAULT_SMOOTHED_BATCH_DURATION |
|
) |
|
|
|
p(delayed(time.sleep)(task_time) for i in range(n_inputs)) |
|
assert p._backend._effective_batch_size == p._backend._DEFAULT_EFFECTIVE_BATCH_SIZE |
|
assert ( |
|
p._backend._smoothed_batch_duration |
|
== p._backend._DEFAULT_SMOOTHED_BATCH_DURATION |
|
) |
|
|
|
|
|
@with_multiprocessing |
|
@parametrize("context", [parallel_config, parallel_backend]) |
|
def test_backend_hinting_and_constraints(context): |
|
for n_jobs in [1, 2, -1]: |
|
assert type(Parallel(n_jobs=n_jobs)._backend) is get_default_backend_instance() |
|
|
|
p = Parallel(n_jobs=n_jobs, prefer="threads") |
|
assert type(p._backend) is ThreadingBackend |
|
|
|
p = Parallel(n_jobs=n_jobs, prefer="processes") |
|
assert type(p._backend) is LokyBackend |
|
|
|
p = Parallel(n_jobs=n_jobs, require="sharedmem") |
|
assert type(p._backend) is ThreadingBackend |
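    # An explicit backend selection takes precedence over the "prefer" hint.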
|
|
|
|
|
|
|
p = Parallel(n_jobs=2, backend="loky", prefer="threads") |
|
assert type(p._backend) is LokyBackend |
|
|
|
with context("loky", n_jobs=2): |
|
|
|
|
|
p = Parallel(prefer="threads") |
|
assert type(p._backend) is LokyBackend |
|
assert p.n_jobs == 2 |
|
|
|
with context("loky", n_jobs=2): |
|
|
|
p = Parallel(n_jobs=3, prefer="threads") |
|
assert type(p._backend) is LokyBackend |
|
assert p.n_jobs == 3 |
|
|
|
with context("loky", n_jobs=2): |
|
|
|
|
|
|
|
|
|
p = Parallel(require="sharedmem") |
|
assert type(p._backend) is ThreadingBackend |
|
assert p.n_jobs == 1 |
|
|
|
with context("loky", n_jobs=2): |
|
p = Parallel(n_jobs=3, require="sharedmem") |
|
assert type(p._backend) is ThreadingBackend |
|
assert p.n_jobs == 3 |
|
|
|
|
|
@parametrize("context", [parallel_config, parallel_backend]) |
|
def test_backend_hinting_and_constraints_with_custom_backends(capsys, context): |
|
|
|
|
|
class MyCustomThreadingBackend(ParallelBackendBase): |
|
supports_sharedmem = True |
|
use_threads = True |
|
|
|
def apply_async(self): |
|
pass |
|
|
|
def effective_n_jobs(self, n_jobs): |
|
return n_jobs |
|
|
|
with context(MyCustomThreadingBackend()): |
|
p = Parallel(n_jobs=2, prefer="processes") |
|
assert type(p._backend) is MyCustomThreadingBackend |
|
|
|
p = Parallel(n_jobs=2, require="sharedmem") |
|
assert type(p._backend) is MyCustomThreadingBackend |
|
|
|
class MyCustomProcessingBackend(ParallelBackendBase): |
|
supports_sharedmem = False |
|
use_threads = False |
|
|
|
def apply_async(self): |
|
pass |
|
|
|
def effective_n_jobs(self, n_jobs): |
|
return n_jobs |
|
|
|
with context(MyCustomProcessingBackend()): |
|
p = Parallel(n_jobs=2, prefer="processes") |
|
assert type(p._backend) is MyCustomProcessingBackend |
|
|
|
out, err = capsys.readouterr() |
|
assert out == "" |
|
assert err == "" |
|
|
|
p = Parallel(n_jobs=2, require="sharedmem", verbose=10) |
|
assert type(p._backend) is ThreadingBackend |
|
|
|
out, err = capsys.readouterr() |
|
expected = ( |
|
"Using ThreadingBackend as joblib backend " |
|
"instead of MyCustomProcessingBackend as the latter " |
|
"does not provide shared memory semantics." |
|
) |
|
assert out.strip() == expected |
|
assert err == "" |
|
|
|
with raises(ValueError): |
|
Parallel(backend=MyCustomProcessingBackend(), require="sharedmem") |
|
|
|
|
|
def test_invalid_backend_hinting_and_constraints(): |
|
with raises(ValueError): |
|
Parallel(prefer="invalid") |
|
|
|
with raises(ValueError): |
|
Parallel(require="invalid") |
|
|
|
with raises(ValueError): |
|
|
|
|
|
Parallel(prefer="processes", require="sharedmem") |
|
|
|
if mp is not None: |
|
|
|
|
|
with raises(ValueError): |
|
Parallel(backend="loky", require="sharedmem") |
|
with raises(ValueError): |
|
Parallel(backend="multiprocessing", require="sharedmem") |
|
|
|
|
|
def _recursive_backend_info(limit=3, **kwargs): |
|
"""Perform nested parallel calls and introspect the backend on the way""" |
|
|
|
with Parallel(n_jobs=2) as p: |
|
this_level = [(type(p._backend).__name__, p._backend.nesting_level)] |
|
if limit == 0: |
|
return this_level |
|
results = p( |
|
delayed(_recursive_backend_info)(limit=limit - 1, **kwargs) |
|
for i in range(1) |
|
) |
|
return this_level + results[0] |
|
|
|
|
|
@with_multiprocessing |
|
@parametrize("backend", ["loky", "threading"]) |
|
@parametrize("context", [parallel_config, parallel_backend]) |
|
def test_nested_parallelism_limit(context, backend): |
|
with context(backend, n_jobs=2): |
|
backend_types_and_levels = _recursive_backend_info() |
|
|
|
top_level_backend_type = backend.title() + "Backend" |
|
expected_types_and_levels = [ |
|
(top_level_backend_type, 0), |
|
("ThreadingBackend", 1), |
|
("SequentialBackend", 2), |
|
("SequentialBackend", 2), |
|
] |
|
assert backend_types_and_levels == expected_types_and_levels |
|
|
|
|
|
def _recursive_parallel(nesting_limit=None): |
|
"""A horrible function that does recursive parallel calls""" |
|
return Parallel()(delayed(_recursive_parallel)() for i in range(2)) |
|
|
|
|
|
@pytest.mark.no_cover |
|
@parametrize("context", [parallel_config, parallel_backend]) |
|
@parametrize("backend", (["threading"] if mp is None else ["loky", "threading"])) |
|
def test_thread_bomb_mitigation(context, backend): |
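    # Unbounded recursive parallelism should fail with an error instead of
    # spawning workers without limit.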
|
|
|
|
|
|
|
with context(backend, n_jobs=2): |
|
with raises(BaseException) as excinfo: |
|
_recursive_parallel() |
|
exc = excinfo.value |
|
if backend == "loky": |
|
|
|
|
|
from joblib.externals.loky.process_executor import TerminatedWorkerError |
|
|
|
if isinstance(exc, (TerminatedWorkerError, PicklingError)): |
|
|
|
|
|
|
|
|
|
|
|
|
|
pytest.xfail("Loky worker crash when serializing RecursionError") |
|
|
|
assert isinstance(exc, RecursionError) |
|
|
|
|
|
def _run_parallel_sum(): |
|
env_vars = {} |
|
for var in [ |
|
"OMP_NUM_THREADS", |
|
"OPENBLAS_NUM_THREADS", |
|
"MKL_NUM_THREADS", |
|
"VECLIB_MAXIMUM_THREADS", |
|
"NUMEXPR_NUM_THREADS", |
|
"NUMBA_NUM_THREADS", |
|
"ENABLE_IPC", |
|
]: |
|
env_vars[var] = os.environ.get(var) |
|
return env_vars, parallel_sum(100) |
|
|
|
|
|
@parametrize("backend", ([None, "loky"] if mp is not None else [None])) |
|
@skipif(parallel_sum is None, reason="Need OpenMP helper compiled") |
|
def test_parallel_thread_limit(backend): |
|
results = Parallel(n_jobs=2, backend=backend)( |
|
delayed(_run_parallel_sum)() for _ in range(2) |
|
) |
|
expected_num_threads = max(cpu_count() // 2, 1) |
|
for worker_env_vars, omp_num_threads in results: |
|
assert omp_num_threads == expected_num_threads |
|
for name, value in worker_env_vars.items(): |
|
if name.endswith("_THREADS"): |
|
assert value == str(expected_num_threads) |
|
else: |
|
assert name == "ENABLE_IPC" |
|
assert value == "1" |
|
|
|
|
|
@parametrize("context", [parallel_config, parallel_backend]) |
|
@skipif(distributed is not None, reason="This test requires dask to not be installed")
|
def test_dask_backend_when_dask_not_installed(context): |
|
with raises(ValueError, match="Please install dask"): |
|
context("dask") |
|
|
|
|
|
@parametrize("context", [parallel_config, parallel_backend]) |
|
def test_zero_worker_backend(context): |
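    # A backend reporting zero available workers should make Parallel raise a
    # RuntimeError instead of hanging.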
|
|
|
|
|
class ZeroWorkerBackend(ThreadingBackend): |
|
def configure(self, *args, **kwargs): |
|
return 0 |
|
|
|
def apply_async(self, func, callback=None): |
|
raise TimeoutError("No worker available") |
|
|
|
def effective_n_jobs(self, n_jobs): |
|
return 0 |
|
|
|
expected_msg = "ZeroWorkerBackend has no active worker" |
|
with context(ZeroWorkerBackend()): |
|
with pytest.raises(RuntimeError, match=expected_msg): |
|
Parallel(n_jobs=2)(delayed(id)(i) for i in range(2)) |
|
|
|
|
|
def test_globals_update_at_each_parallel_call(): |
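    # Check that changes to a global variable of the __main__ module are
    # shipped to the workers at each Parallel call.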
|
|
|
|
|
|
|
|
|
global MY_GLOBAL_VARIABLE |
|
MY_GLOBAL_VARIABLE = "original value" |
|
|
|
def check_globals(): |
|
global MY_GLOBAL_VARIABLE |
|
return MY_GLOBAL_VARIABLE |
|
|
|
assert check_globals() == "original value" |
|
|
|
workers_global_variable = Parallel(n_jobs=2)( |
|
delayed(check_globals)() for i in range(2) |
|
) |
|
assert set(workers_global_variable) == {"original value"} |
|
|
|
|
|
|
|
MY_GLOBAL_VARIABLE = "changed value" |
|
assert check_globals() == "changed value" |
|
|
|
workers_global_variable = Parallel(n_jobs=2)( |
|
delayed(check_globals)() for i in range(2) |
|
) |
|
assert set(workers_global_variable) == {"changed value"} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _check_numpy_threadpool_limits(): |
|
import numpy as np |
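    # Call BLAS on a matrix-matrix multiplication to make sure the thread pool
    # of the underlying BLAS implementation is initialized.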
|
|
|
|
|
|
|
|
|
a = np.random.randn(100, 100) |
|
np.dot(a, a) |
|
threadpoolctl = pytest.importorskip("threadpoolctl") |
|
return threadpoolctl.threadpool_info() |
|
|
|
|
|
def _parent_max_num_threads_for(child_module, parent_info): |
|
for parent_module in parent_info: |
|
if parent_module["filepath"] == child_module["filepath"]: |
|
return parent_module["num_threads"] |
|
raise ValueError( |
|
"An unexpected module was loaded in child:\n{}".format(child_module) |
|
) |
|
|
|
|
|
def check_child_num_threads(workers_info, parent_info, num_threads): |
|
|
|
|
|
|
|
for child_threadpool_info in workers_info: |
|
for child_module in child_threadpool_info: |
|
parent_max_num_threads = _parent_max_num_threads_for( |
|
child_module, parent_info |
|
) |
|
expected = {min(num_threads, parent_max_num_threads), num_threads} |
|
assert child_module["num_threads"] in expected |
|
|
|
|
|
@with_numpy |
|
@with_multiprocessing |
|
@parametrize("n_jobs", [2, 4, -2, -1]) |
|
def test_threadpool_limitation_in_child_loky(n_jobs): |
|
|
|
|
|
|
|
|
|
parent_info = _check_numpy_threadpool_limits() |
|
if len(parent_info) == 0: |
|
pytest.skip(reason="Need a version of numpy linked to BLAS") |
|
|
|
workers_threadpool_infos = Parallel(backend="loky", n_jobs=n_jobs)( |
|
delayed(_check_numpy_threadpool_limits)() for i in range(2) |
|
) |
|
|
|
n_jobs = effective_n_jobs(n_jobs) |
|
if n_jobs == 1: |
|
expected_child_num_threads = parent_info[0]["num_threads"] |
|
else: |
|
expected_child_num_threads = max(cpu_count() // n_jobs, 1) |
|
|
|
check_child_num_threads( |
|
workers_threadpool_infos, parent_info, expected_child_num_threads |
|
) |
|
|
|
|
|
@with_numpy |
|
@with_multiprocessing |
|
@parametrize("inner_max_num_threads", [1, 2, 4, None]) |
|
@parametrize("n_jobs", [2, -1]) |
|
@parametrize("context", [parallel_config, parallel_backend]) |
|
def test_threadpool_limitation_in_child_context(context, n_jobs, inner_max_num_threads): |
|
|
|
|
|
|
|
|
|
parent_info = _check_numpy_threadpool_limits() |
|
if len(parent_info) == 0: |
|
pytest.skip(reason="Need a version of numpy linked to BLAS") |
|
|
|
with context("loky", inner_max_num_threads=inner_max_num_threads): |
|
workers_threadpool_infos = Parallel(n_jobs=n_jobs)( |
|
delayed(_check_numpy_threadpool_limits)() for i in range(2) |
|
) |
|
|
|
n_jobs = effective_n_jobs(n_jobs) |
|
if n_jobs == 1: |
|
expected_child_num_threads = parent_info[0]["num_threads"] |
|
elif inner_max_num_threads is None: |
|
expected_child_num_threads = max(cpu_count() // n_jobs, 1) |
|
else: |
|
expected_child_num_threads = inner_max_num_threads |
|
|
|
check_child_num_threads( |
|
workers_threadpool_infos, parent_info, expected_child_num_threads |
|
) |
|
|
|
|
|
@with_multiprocessing |
|
@parametrize("n_jobs", [2, -1]) |
|
@parametrize("var_name", ["OPENBLAS_NUM_THREADS", "MKL_NUM_THREADS", "OMP_NUM_THREADS"]) |
|
@parametrize("context", [parallel_config, parallel_backend]) |
|
def test_threadpool_limitation_in_child_override(context, n_jobs, var_name): |
|
|
|
|
|
|
|
|
|
if effective_n_jobs(n_jobs) == 1: |
|
pytest.skip("Skip test when n_jobs == 1") |
|
|
|
|
|
|
|
get_reusable_executor(reuse=True).shutdown() |
|
|
|
def _get_env(var_name): |
|
return os.environ.get(var_name) |
|
|
|
original_var_value = os.environ.get(var_name) |
|
try: |
|
os.environ[var_name] = "4" |
|
|
|
results = Parallel(n_jobs=n_jobs)(delayed(_get_env)(var_name) for i in range(2)) |
|
assert results == ["4", "4"] |
|
|
|
with context("loky", inner_max_num_threads=1): |
|
results = Parallel(n_jobs=n_jobs)( |
|
delayed(_get_env)(var_name) for i in range(2) |
|
) |
|
assert results == ["1", "1"] |
|
|
|
finally: |
|
if original_var_value is None: |
|
del os.environ[var_name] |
|
else: |
|
os.environ[var_name] = original_var_value |
|
|
|
|
|
@with_multiprocessing |
|
@parametrize("n_jobs", [2, 4, -1]) |
|
def test_loky_reuse_workers(n_jobs): |
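    # Check that consecutive Parallel calls with the loky backend reuse the
    # same executor (and thus the same worker processes).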
|
|
|
|
|
|
|
def parallel_call(n_jobs): |
|
x = range(10) |
|
Parallel(n_jobs=n_jobs)(delayed(sum)(x) for i in range(10)) |
|
|
|
|
|
parallel_call(n_jobs) |
|
first_executor = get_reusable_executor(reuse=True) |
|
|
|
|
|
|
|
for _ in range(10): |
|
parallel_call(n_jobs) |
|
executor = get_reusable_executor(reuse=True) |
|
assert executor == first_executor |
|
|
|
|
|
def _set_initialized(status): |
|
status[os.getpid()] = "initialized" |
|
|
|
|
|
def _check_status(status, n_jobs, wait_workers=False): |
|
pid = os.getpid() |
|
state = status.get(pid, None) |
|
assert state in ("initialized", "started"), ( |
|
f"worker should have been in initialized state, got {state}" |
|
) |
|
if not wait_workers: |
|
return |
|
|
|
status[pid] = "started" |
|
|
|
deadline = time.time() + 30 |
|
n_started = len([pid for pid, v in status.items() if v == "started"]) |
|
while time.time() < deadline and n_started < n_jobs: |
|
time.sleep(0.1) |
|
n_started = len([pid for pid, v in status.items() if v == "started"]) |
|
|
|
if time.time() >= deadline: |
|
raise TimeoutError("Waited more than 30s to start all the workers") |
|
|
|
return pid |
|
|
|
|
|
@with_multiprocessing |
|
@parametrize("n_jobs", [2, 4]) |
|
@parametrize("backend", PROCESS_BACKENDS) |
|
@parametrize("context", [parallel_config, parallel_backend]) |
|
def test_initializer_context(n_jobs, backend, context): |
|
manager = mp.Manager() |
|
status = manager.dict() |
|
|
|
|
|
with context( |
|
backend=backend, |
|
n_jobs=n_jobs, |
|
initializer=_set_initialized, |
|
initargs=(status,), |
|
): |
|
|
|
Parallel()(delayed(_check_status)(status, n_jobs) for i in range(100)) |
|
|
|
|
|
@with_multiprocessing |
|
@parametrize("n_jobs", [2, 4]) |
|
@parametrize("backend", PROCESS_BACKENDS) |
|
def test_initializer_parallel(n_jobs, backend): |
|
manager = mp.Manager() |
|
status = manager.dict() |
|
|
|
|
|
|
|
Parallel( |
|
backend=backend, |
|
n_jobs=n_jobs, |
|
initializer=_set_initialized, |
|
initargs=(status,), |
|
)(delayed(_check_status)(status, n_jobs) for i in range(100)) |
|
|
|
|
|
@with_multiprocessing |
|
@pytest.mark.parametrize("n_jobs", [2, 4]) |
|
def test_initializer_reused(n_jobs): |
|
|
|
|
|
n_repetitions = 3 |
|
manager = mp.Manager() |
|
status = manager.dict() |
|
|
|
pids = set() |
|
for i in range(n_repetitions): |
|
results = Parallel( |
|
backend="loky", |
|
n_jobs=n_jobs, |
|
initializer=_set_initialized, |
|
initargs=(status,), |
|
)( |
|
delayed(_check_status)(status, n_jobs, wait_workers=True) |
|
for i in range(n_jobs) |
|
) |
|
pids = pids.union(set(results)) |
|
assert len(pids) == n_jobs, ( |
|
"The workers should be reused when the initializer is the same" |
|
) |
|
|
|
|
|
@with_multiprocessing |
|
@pytest.mark.parametrize("n_jobs", [2, 4]) |
|
def test_initializer_not_reused(n_jobs): |
|
|
|
|
|
|
|
n_repetitions = 3 |
|
manager = mp.Manager() |
|
|
|
pids = set() |
|
for i in range(n_repetitions): |
|
status = manager.dict() |
|
results = Parallel( |
|
backend="loky", |
|
n_jobs=n_jobs, |
|
initializer=_set_initialized, |
|
initargs=(status,), |
|
)( |
|
delayed(_check_status)(status, n_jobs, wait_workers=True) |
|
for i in range(n_jobs) |
|
) |
|
pids = pids.union(set(results)) |
|
assert len(pids) == n_repetitions * n_jobs, ( |
|
"The workers should not be reused when the initializer arguments change" |
|
) |
|
|