from __future__ import absolute_import, division, print_function

import os
import warnings
from random import random
from time import sleep
from uuid import uuid4

import pytest

from .. import Parallel, delayed, parallel_backend, parallel_config
from .._dask import DaskDistributedBackend
from ..parallel import AutoBatchingMixin, ThreadingBackend
from .common import np, with_numpy
from .test_parallel import (
    _recursive_backend_info,
    _test_deadlock_with_generator,
    _test_parallel_unordered_generator_returns_fastest_first,  # noqa: E501
)

distributed = pytest.importorskip("distributed")
dask = pytest.importorskip("dask")

# These imports need to come after the pytest.importorskip calls, hence the noqa: E402
from distributed import Client, LocalCluster, get_client  # noqa: E402
from distributed.metrics import time  # noqa: E402

# Note: pytest requires all fixtures used in the tests, as well as their
# dependencies, to be imported manually.
from distributed.utils_test import cleanup, cluster, inc  # noqa: E402, F401


@pytest.fixture(scope="function", autouse=True)
def avoid_dask_env_leaks(tmp_path):
    # When starting a dask nanny, the environment variables might change.
    # This fixture makes sure the environment is reset after the test.
    from joblib._parallel_backends import ParallelBackendBase

    old_value = {
        k: os.environ.get(k) for k in ParallelBackendBase.MAX_NUM_THREADS_VARS
    }
    yield
    # Reset the environment variables to their original values
    for k, v in old_value.items():
        if v is None:
            os.environ.pop(k, None)
        else:
            os.environ[k] = v


def noop(*args, **kwargs):
    pass


def slow_raise_value_error(condition, duration=0.05):
    sleep(duration)
    if condition:
        raise ValueError("condition evaluated to True")


def count_events(event_name, client):
    # Count, for each worker, the number of entries named `event_name` in the
    # worker log.
    worker_events = client.run(lambda dask_worker: dask_worker.log)
    event_counts = {}
    for w, events in worker_events.items():
        event_counts[w] = len(
            [event for event in list(events) if event[1] == event_name]
        )
    return event_counts


def test_simple(loop):
    with cluster() as (s, [a, b]):
        with Client(s["address"], loop=loop) as client:  # noqa: F841
            with parallel_config(backend="dask"):
                seq = Parallel()(delayed(inc)(i) for i in range(10))
                assert seq == [inc(i) for i in range(10)]

                with pytest.raises(ValueError):
                    Parallel()(
                        delayed(slow_raise_value_error)(i == 3) for i in range(10)
                    )

                seq = Parallel()(delayed(inc)(i) for i in range(10))
                assert seq == [inc(i) for i in range(10)]


def test_dask_backend_uses_autobatching(loop):
    assert (
        DaskDistributedBackend.compute_batch_size
        is AutoBatchingMixin.compute_batch_size
    )

    with cluster() as (s, [a, b]):
        with Client(s["address"], loop=loop) as client:  # noqa: F841
            with parallel_config(backend="dask"):
                with Parallel() as parallel:
                    # The backend should be initialized with a default
                    # batch size of 1:
                    backend = parallel._backend
                    assert isinstance(backend, DaskDistributedBackend)
                    assert backend.parallel is parallel
                    assert backend._effective_batch_size == 1

                    # Launch many short tasks that should trigger
                    # auto-batching:
                    parallel(delayed(lambda: None)() for _ in range(int(1e4)))
                    assert backend._effective_batch_size > 10


@pytest.mark.parametrize("n_jobs", [2, -1])
@pytest.mark.parametrize("context", [parallel_config, parallel_backend])
def test_parallel_unordered_generator_returns_fastest_first_with_dask(n_jobs, context):
    with distributed.Client(n_workers=2, threads_per_worker=2), context("dask"):
        _test_parallel_unordered_generator_returns_fastest_first(None, n_jobs)


@with_numpy
@pytest.mark.parametrize("n_jobs", [2, -1])
@pytest.mark.parametrize("return_as", ["generator", "generator_unordered"])
"generator_unordered"]) @pytest.mark.parametrize("context", [parallel_config, parallel_backend]) def test_deadlock_with_generator_and_dask(context, return_as, n_jobs): with distributed.Client(n_workers=2, threads_per_worker=2), context("dask"): _test_deadlock_with_generator(None, return_as, n_jobs) @with_numpy @pytest.mark.parametrize("context", [parallel_config, parallel_backend]) def test_nested_parallelism_with_dask(context): with distributed.Client(n_workers=2, threads_per_worker=2): # 10 MB of data as argument to trigger implicit scattering data = np.ones(int(1e7), dtype=np.uint8) for i in range(2): with context("dask"): backend_types_and_levels = _recursive_backend_info(data=data) assert len(backend_types_and_levels) == 4 assert all( name == "DaskDistributedBackend" for name, _ in backend_types_and_levels ) # No argument with context("dask"): backend_types_and_levels = _recursive_backend_info() assert len(backend_types_and_levels) == 4 assert all( name == "DaskDistributedBackend" for name, _ in backend_types_and_levels ) def random2(): return random() def test_dont_assume_function_purity(loop): with cluster() as (s, [a, b]): with Client(s["address"], loop=loop) as client: # noqa: F841 with parallel_config(backend="dask"): x, y = Parallel()(delayed(random2)() for i in range(2)) assert x != y @pytest.mark.parametrize("mixed", [True, False]) def test_dask_funcname(loop, mixed): from joblib._dask import Batch if not mixed: tasks = [delayed(inc)(i) for i in range(4)] batch_repr = "batch_of_inc_4_calls" else: tasks = [delayed(abs)(i) if i % 2 else delayed(inc)(i) for i in range(4)] batch_repr = "mixed_batch_of_inc_4_calls" assert repr(Batch(tasks)) == batch_repr with cluster() as (s, [a, b]): with Client(s["address"], loop=loop) as client: with parallel_config(backend="dask"): _ = Parallel(batch_size=2, pre_dispatch="all")(tasks) def f(dask_scheduler): return list(dask_scheduler.transition_log) batch_repr = batch_repr.replace("4", "2") log = client.run_on_scheduler(f) assert all("batch_of_inc" in tup[0] for tup in log) def test_no_undesired_distributed_cache_hit(): # Dask has a pickle cache for callables that are called many times. Because # the dask backends used to wrap both the functions and the arguments # under instances of the Batch callable class this caching mechanism could # lead to bugs as described in: https://github.com/joblib/joblib/pull/1055 # The joblib-dask backend has been refactored to avoid bundling the # arguments as an attribute of the Batch instance to avoid this problem. # This test serves as non-regression problem. # Use a large number of input arguments to give the AutoBatchingMixin # enough tasks to kick-in. lists = [[] for _ in range(100)] np = pytest.importorskip("numpy") X = np.arange(int(1e6)) def isolated_operation(list_, data=None): if data is not None: np.testing.assert_array_equal(data, X) list_.append(uuid4().hex) return list_ cluster = LocalCluster(n_workers=1, threads_per_worker=2) client = Client(cluster) try: with parallel_config(backend="dask"): # dispatches joblib.parallel.BatchedCalls res = Parallel()(delayed(isolated_operation)(list_) for list_ in lists) # The original arguments should not have been mutated as the mutation # happens in the dask worker process. assert lists == [[] for _ in range(100)] # Here we did not pass any large numpy array as argument to # isolated_operation so no scattering event should happen under the # hood. 
        counts = count_events("receive-from-scatter", client)
        assert sum(counts.values()) == 0
        assert all([len(r) == 1 for r in res])

        with parallel_config(backend="dask"):
            # Append a large array which will be scattered by dask, and
            # dispatch joblib._dask.Batch
            res = Parallel()(
                delayed(isolated_operation)(list_, data=X) for list_ in lists
            )

        # This time, auto-scattering should have kicked in.
        counts = count_events("receive-from-scatter", client)
        assert sum(counts.values()) > 0
        assert all([len(r) == 1 for r in res])
    finally:
        client.close(timeout=30)
        cluster.close(timeout=30)


class CountSerialized(object):
    def __init__(self, x):
        self.x = x
        self.count = 0

    def __add__(self, other):
        return self.x + getattr(other, "x", other)

    __radd__ = __add__

    def __reduce__(self):
        self.count += 1
        return (CountSerialized, (self.x,))


def add5(a, b, c, d=0, e=0):
    return a + b + c + d + e


def test_manual_scatter(loop):
    # Let's check that the number of times scattered and non-scattered
    # variables are serialized is consistent between `joblib.Parallel` calls
    # and equivalent native `client.submit` calls.
    # The number of serializations can vary from one dask version to another,
    # so this test only checks that `joblib.Parallel` does not add more
    # serialization steps than a native `client.submit` call, but does not
    # check for an exact number of serialization steps.
    w, x, y, z = (CountSerialized(i) for i in range(4))
    f = delayed(add5)
    tasks = [f(x, y, z, d=4, e=5) for _ in range(10)]
    tasks += [
        f(x, z, y, d=5, e=4),
        f(y, x, z, d=x, e=5),
        f(z, z, x, d=z, e=y),
    ]
    expected = [func(*args, **kwargs) for func, args, kwargs in tasks]

    with cluster() as (s, _):
        with Client(s["address"], loop=loop) as client:  # noqa: F841
            with parallel_config(backend="dask", scatter=[w, x, y]):
                results_parallel = Parallel(batch_size=1)(tasks)
                assert results_parallel == expected

            # Check that an error is raised for bad arguments, as scatter must
            # take a list/tuple
            with pytest.raises(TypeError):
                with parallel_config(backend="dask", loop=loop, scatter=1):
                    pass

    # Scattered variables are only serialized during scatter. Check with an
    # extra variable as this count can vary from one dask version to another.
    n_serialization_scatter_with_parallel = w.count
    assert x.count == n_serialization_scatter_with_parallel
    assert y.count == n_serialization_scatter_with_parallel
    n_serialization_with_parallel = z.count

    # Reset the cluster and the serialization counts
    for var in (w, x, y, z):
        var.count = 0

    with cluster() as (s, _):
        with Client(s["address"], loop=loop) as client:  # noqa: F841
            scattered = dict()
            for obj in w, x, y:
                scattered[id(obj)] = client.scatter(obj, broadcast=True)
            results_native = [
                client.submit(
                    func,
                    *(scattered.get(id(arg), arg) for arg in args),
                    **dict(
                        (key, scattered.get(id(value), value))
                        for (key, value) in kwargs.items()
                    ),
                    key=str(uuid4()),
                ).result()
                for (func, args, kwargs) in tasks
            ]
            assert results_native == expected

    # Now check that the number of serialization steps is the same for joblib
    # and native dask calls.
    n_serialization_scatter_native = w.count
    assert x.count == n_serialization_scatter_native
    assert y.count == n_serialization_scatter_native
    assert n_serialization_scatter_with_parallel == n_serialization_scatter_native

    distributed_version = tuple(int(v) for v in distributed.__version__.split("."))
    if distributed_version < (2023, 4):
        # Prior to 2023.4, the serialization added an extra call to
        # __reduce__ for the last job `f(z, z, x, d=z, e=y)`, because `z`
        # appears both in the args and in the kwargs, which is not the case
        # when running with joblib. Cope with this discrepancy.
        assert z.count == n_serialization_with_parallel + 1
    else:
        assert z.count == n_serialization_with_parallel


# When the same IOLoop is used for multiple clients in a row, use
# loop_in_thread instead of loop to prevent the Client from closing it. See
# dask/distributed #4112
def test_auto_scatter(loop_in_thread):
    np = pytest.importorskip("numpy")
    data1 = np.ones(int(1e4), dtype=np.uint8)
    data2 = np.ones(int(1e4), dtype=np.uint8)
    data_to_process = ([data1] * 3) + ([data2] * 3)

    with cluster() as (s, [a, b]):
        with Client(s["address"], loop=loop_in_thread) as client:
            with parallel_config(backend="dask"):
                # Passing the same data as arg and kwarg triggers a single
                # scatter operation whose result is reused.
                Parallel()(
                    delayed(noop)(data, data, i, opt=data)
                    for i, data in enumerate(data_to_process)
                )
            # By default, large arrays are automatically scattered with
            # broadcast=1, which means that one worker must directly receive
            # the data from the scatter operation once.
            counts = count_events("receive-from-scatter", client)
            assert counts[a["address"]] + counts[b["address"]] == 2

    with cluster() as (s, [a, b]):
        with Client(s["address"], loop=loop_in_thread) as client:
            with parallel_config(backend="dask"):
                Parallel()(delayed(noop)(data1[:3], i) for i in range(5))
            # Small arrays are passed within the task definition without going
            # through a scatter operation.
            counts = count_events("receive-from-scatter", client)
            assert counts[a["address"]] == 0
            assert counts[b["address"]] == 0


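# Note: the ``retry_no`` parameter below is not used inside the test body;
# parametrizing on it simply makes pytest run the test twice (presumably to
# exercise the nested scattering code path more than once).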
@pytest.mark.parametrize("retry_no", list(range(2)))
def test_nested_scatter(loop, retry_no):
    np = pytest.importorskip("numpy")

    NUM_INNER_TASKS = 10
    NUM_OUTER_TASKS = 10

    def my_sum(x, i, j):
        return np.sum(x)

    def outer_function_joblib(array, i):
        client = get_client()  # noqa
        with parallel_config(backend="dask"):
            results = Parallel()(
                delayed(my_sum)(array[j:], i, j) for j in range(NUM_INNER_TASKS)
            )
        return sum(results)

    with cluster() as (s, [a, b]):
        with Client(s["address"], loop=loop) as _:
            with parallel_config(backend="dask"):
                my_array = np.ones(10000)
                _ = Parallel()(
                    delayed(outer_function_joblib)(my_array[i:], i)
                    for i in range(NUM_OUTER_TASKS)
                )


def test_nested_backend_context_manager(loop_in_thread):
    def get_nested_pids():
        pids = set(Parallel(n_jobs=2)(delayed(os.getpid)() for _ in range(2)))
        pids |= set(Parallel(n_jobs=2)(delayed(os.getpid)() for _ in range(2)))
        return pids

    with cluster() as (s, [a, b]):
        with Client(s["address"], loop=loop_in_thread) as client:
            with parallel_config(backend="dask"):
                pid_groups = Parallel(n_jobs=2)(
                    delayed(get_nested_pids)() for _ in range(10)
                )
                for pid_group in pid_groups:
                    assert len(set(pid_group)) <= 2

        # No deadlocks
        with Client(s["address"], loop=loop_in_thread) as client:  # noqa: F841
            with parallel_config(backend="dask"):
                pid_groups = Parallel(n_jobs=2)(
                    delayed(get_nested_pids)() for _ in range(10)
                )
                for pid_group in pid_groups:
                    assert len(set(pid_group)) <= 2


def test_nested_backend_context_manager_implicit_n_jobs(loop):
    # Check that Parallel with no explicit n_jobs value automatically selects
    # all the dask workers, including in nested calls.
    def _backend_type(p):
        return p._backend.__class__.__name__

    def get_nested_implicit_n_jobs():
        with Parallel() as p:
            return _backend_type(p), p.n_jobs

    with cluster() as (s, [a, b]):
        with Client(s["address"], loop=loop) as client:  # noqa: F841
            with parallel_config(backend="dask"):
                with Parallel() as p:
                    assert _backend_type(p) == "DaskDistributedBackend"
                    assert p.n_jobs == -1
                    all_nested_n_jobs = p(
                        delayed(get_nested_implicit_n_jobs)() for _ in range(2)
                    )
                    for backend_type, nested_n_jobs in all_nested_n_jobs:
                        assert backend_type == "DaskDistributedBackend"
                        assert nested_n_jobs == -1


def test_errors(loop):
    with pytest.raises(ValueError) as info:
        with parallel_config(backend="dask"):
            pass
    assert "create a dask client" in str(info.value).lower()


def test_correct_nested_backend(loop):
    with cluster() as (s, [a, b]):
        with Client(s["address"], loop=loop) as client:  # noqa: F841
            # No requirement: the nested backend should still be the dask
            # backend
            with parallel_config(backend="dask"):
                result = Parallel(n_jobs=2)(
                    delayed(outer)(nested_require=None) for _ in range(1)
                )
                assert isinstance(result[0][0][0], DaskDistributedBackend)

            # Require sharedmem: the nested backend should fall back to
            # threading
            with parallel_config(backend="dask"):
                result = Parallel(n_jobs=2)(
                    delayed(outer)(nested_require="sharedmem") for _ in range(1)
                )
                assert isinstance(result[0][0][0], ThreadingBackend)


def outer(nested_require):
    return Parallel(n_jobs=2, prefer="threads")(
        delayed(middle)(nested_require) for _ in range(1)
    )


def middle(require):
    return Parallel(n_jobs=2, require=require)(delayed(inner)() for _ in range(1))


def inner():
    return Parallel()._backend


def test_secede_with_no_processes(loop):
    # https://github.com/dask/distributed/issues/1775
    with Client(loop=loop, processes=False, set_as_default=True):
        with parallel_config(backend="dask"):
            Parallel(n_jobs=4)(delayed(id)(i) for i in range(2))


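# Helper returning the address of the dask worker executing the task. It is
# used by test_dask_backend_keywords below to check that the ``workers``
# keyword pins all the tasks to the requested worker.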
def _worker_address(_):
    from distributed import get_worker

    return get_worker().address


def test_dask_backend_keywords(loop):
    with cluster() as (s, [a, b]):
        with Client(s["address"], loop=loop) as client:  # noqa: F841
            with parallel_config(backend="dask", workers=a["address"]):
                seq = Parallel()(delayed(_worker_address)(i) for i in range(10))
                assert seq == [a["address"]] * 10

            with parallel_config(backend="dask", workers=b["address"]):
                seq = Parallel()(delayed(_worker_address)(i) for i in range(10))
                assert seq == [b["address"]] * 10


def test_scheduler_tasks_cleanup(loop):
    with Client(processes=False, loop=loop) as client:
        with parallel_config(backend="dask"):
            Parallel()(delayed(inc)(i) for i in range(10))

        start = time()
        while client.cluster.scheduler.tasks:
            sleep(0.01)
            assert time() < start + 5

        assert not client.futures


@pytest.mark.parametrize("cluster_strategy", ["adaptive", "late_scaling"])
@pytest.mark.skipif(
    distributed.__version__ <= "2.1.1" and distributed.__version__ >= "1.28.0",
    reason="distributed bug - https://github.com/dask/distributed/pull/2841",
)
def test_wait_for_workers(cluster_strategy):
    cluster = LocalCluster(n_workers=0, processes=False, threads_per_worker=2)
    client = Client(cluster)
    if cluster_strategy == "adaptive":
        cluster.adapt(minimum=0, maximum=2)
    elif cluster_strategy == "late_scaling":
        # Tell the cluster to start workers, but this is a non-blocking call
        # and new workers might take time to connect. In this case the
        # Parallel call should wait for at least one worker to come up before
        # starting to schedule work.
        cluster.scale(2)
    try:
        with parallel_config(backend="dask"):
            # The following should wait a bit for at least one worker to
            # become available.
            Parallel()(delayed(inc)(i) for i in range(10))
    finally:
        client.close()
        cluster.close()


def test_wait_for_workers_timeout():
    # Start a cluster with 0 workers:
    cluster = LocalCluster(n_workers=0, processes=False, threads_per_worker=2)
    client = Client(cluster)
    try:
        with parallel_config(backend="dask", wait_for_workers_timeout=0.1):
            # Short timeout: the backend raises a meaningful TimeoutError:
            msg = "DaskDistributedBackend has no worker after 0.1 seconds."
            with pytest.raises(TimeoutError, match=msg):
                Parallel()(delayed(inc)(i) for i in range(10))

        with parallel_config(backend="dask", wait_for_workers_timeout=0):
            # No timeout: fallback to generic joblib failure:
            msg = "DaskDistributedBackend has no active worker"
            with pytest.raises(RuntimeError, match=msg):
                Parallel()(delayed(inc)(i) for i in range(10))
    finally:
        client.close()
        cluster.close()


@pytest.mark.parametrize("backend", ["loky", "multiprocessing"])
def test_joblib_warning_inside_dask_daemonic_worker(backend):
    cluster = LocalCluster(n_workers=2)
    client = Client(cluster)
    try:

        def func_using_joblib_parallel():
            # Somehow, trying to check the warning type here (e.g. with
            # pytest.warns(UserWarning)) makes the test hang. Work-around:
            # return the warning record to the client so that the warning
            # check is done client-side.
            with warnings.catch_warnings(record=True) as record:
                Parallel(n_jobs=2, backend=backend)(
                    delayed(inc)(i) for i in range(10)
                )
            return record

        fut = client.submit(func_using_joblib_parallel)
        record = fut.result()

        assert len(record) == 1
        warning = record[0].message
        assert isinstance(warning, UserWarning)
        assert "distributed.worker.daemon" in str(warning)
    finally:
        client.close(timeout=30)
        cluster.close(timeout=30)