File size: 13,217 Bytes
7885a28 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 |
from __future__ import absolute_import, division, print_function
import asyncio
import concurrent.futures
import contextlib
import time
import weakref
from uuid import uuid4
from ._utils import (
_retrieve_traceback_capturing_wrapped_call,
_TracebackCapturingWrapper,
)
from .parallel import AutoBatchingMixin, ParallelBackendBase, parallel_config
try:
import dask
import distributed
except ImportError:
dask = None
distributed = None
if dask is not None and distributed is not None:
from dask.distributed import (
Client,
as_completed,
get_client,
rejoin,
secede,
)
from dask.sizeof import sizeof
from dask.utils import funcname
from distributed.utils import thread_state
try:
# asyncio.TimeoutError, Python3-only error thrown by recent versions of
# distributed
from distributed.utils import TimeoutError as _TimeoutError
except ImportError:
from tornado.gen import TimeoutError as _TimeoutError
def is_weakrefable(obj):
    """Tell whether a weak reference to ``obj`` can be created.

    Returns True when ``weakref.ref(obj)`` succeeds and False when it raises
    ``TypeError``.
    """
    try:
        weakref.ref(obj)
    except TypeError:
        return False
    else:
        return True
class _WeakKeyDictionary:
"""A variant of weakref.WeakKeyDictionary for unhashable objects.
This datastructure is used to store futures for broadcasted data objects
such as large numpy arrays or pandas dataframes that are not hashable and
therefore cannot be used as keys of traditional python dicts.
Furthermore using a dict with id(array) as key is not safe because the
Python is likely to reuse id of recently collected arrays.
"""
def __init__(self):
self._data = {}
def __getitem__(self, obj):
ref, val = self._data[id(obj)]
if ref() is not obj:
# In case of a race condition with on_destroy.
raise KeyError(obj)
return val
def __setitem__(self, obj, value):
key = id(obj)
try:
ref, _ = self._data[key]
if ref() is not obj:
# In case of race condition with on_destroy.
raise KeyError(obj)
except KeyError:
# Insert the new entry in the mapping along with a weakref
# callback to automatically delete the entry from the mapping
# as soon as the object used as key is garbage collected.
def on_destroy(_):
del self._data[key]
ref = weakref.ref(obj, on_destroy)
self._data[key] = ref, value
def __len__(self):
return len(self._data)
def clear(self):
self._data.clear()
def _funcname(x):
    """Return ``funcname`` of ``x``, or of its first task when ``x`` is a list.

    When ``x`` is a list, ``x[0][0]`` is used instead of ``x`` itself. Any
    error raised while extracting that element is ignored and ``x`` is then
    passed to ``funcname`` unchanged.
    """
    # Best effort: fall back to describing ``x`` as-is if unwrapping fails.
    with contextlib.suppress(Exception):
        if isinstance(x, list):
            x = x[0][0]
    return funcname(x)
def _make_tasks_summary(tasks):
    """Summarize a list of (func, args, kwargs) function calls.

    Returns a 3-tuple made of the number of tasks, a boolean that is False
    only when all the tasks share exactly one function, and the name computed
    by ``_funcname`` for the list of tasks.
    """
    distinct_funcs = set()
    for func, _args, _kwargs in tasks:
        distinct_funcs.add(func)
    mixed = len(distinct_funcs) != 1
    return len(tasks), mixed, _funcname(tasks)
class Batch:
    """dask-compatible wrapper that executes a batch of tasks"""

    def __init__(self, tasks):
        # Only metadata about the tasks is kept on the instance: it makes
        # Batch calls easier to introspect when debugging.
        summary = _make_tasks_summary(tasks)
        self._num_tasks, self._mixed, self._funcname = summary

    def __call__(self, tasks=None):
        """Run each (func, args, kwargs) of ``tasks`` and return the results.

        The calls are executed in order under ``parallel_config`` with the
        "dask" backend.
        """
        with parallel_config(backend="dask"):
            results = [func(*args, **kwargs) for func, args, kwargs in tasks]
        return results

    def __repr__(self):
        prefix = "mixed_" if self._mixed else ""
        return f"{prefix}batch_of_{self._funcname}_{self._num_tasks}_calls"
def _joblib_probe_task():
# Noop used by the joblib connector to probe when workers are ready.
pass
class DaskDistributedBackend(AutoBatchingMixin, ParallelBackendBase):
    """joblib parallel backend that runs batches of calls on a dask cluster.

    Batches are submitted through a ``dask.distributed`` client. A coroutine
    scheduled on the client's event loop (``_collect``) watches the submitted
    dask futures and forwards each outcome to the
    ``concurrent.futures.Future`` that ``apply_async`` returned to the caller.
    """

    MIN_IDEAL_BATCH_DURATION = 0.2
    MAX_IDEAL_BATCH_DURATION = 1.0
    supports_retrieve_callback = True
    default_n_jobs = -1

    def __init__(
        self,
        scheduler_host=None,
        scatter=None,
        client=None,
        loop=None,
        wait_for_workers_timeout=10,
        **submit_kwargs,
    ):
        """Bind the backend to a dask client and pre-scatter shared data.

        Parameters
        ----------
        scheduler_host : optional
            Passed to ``Client`` to create a new client when ``client`` is
            None.
        scatter : list or tuple, optional
            Objects scattered to the cluster up front (``broadcast=True``).
            The resulting futures are indexed by the ``id`` of each object.
        client : optional
            Client to use. When None, a client is created from
            ``scheduler_host`` if given, else retrieved with ``get_client()``.
        loop : optional
            Only forwarded to ``Client`` when a client is created from
            ``scheduler_host``.
        wait_for_workers_timeout : default 10
            Timeout used by ``effective_n_jobs`` when the client reports no
            core. A falsy value disables the wait.
        **submit_kwargs
            Extra keyword arguments forwarded to every ``client.submit`` call
            made by ``apply_async``.

        Raises
        ------
        ValueError
            If ``distributed`` could not be imported, or if no client was
            given and ``get_client()`` failed.
        TypeError
            If ``scatter`` is neither None, a list nor a tuple.
        """
        super().__init__()

        if distributed is None:
            msg = (
                "You are trying to use 'dask' as a joblib parallel backend "
                "but dask is not installed. Please install dask "
                "to fix this error."
            )
            raise ValueError(msg)

        if client is None:
            if scheduler_host:
                client = Client(scheduler_host, loop=loop, set_as_default=False)
            else:
                try:
                    client = get_client()
                except ValueError as e:
                    msg = (
                        "To use Joblib with Dask first create a Dask Client"
                        "\n\n"
                        "    from dask.distributed import Client\n"
                        "    client = Client()\n"
                        "or\n"
                        "    client = Client('scheduler-address:8786')"
                    )
                    raise ValueError(msg) from e

        self.client = client

        if scatter is not None and not isinstance(scatter, (list, tuple)):
            raise TypeError(
                "scatter must be a list/tuple, got `%s`" % type(scatter).__name__
            )

        if scatter is not None and len(scatter) > 0:
            # Keep a reference to the scattered data to keep the ids the same
            self._scatter = list(scatter)
            scattered = self.client.scatter(scatter, broadcast=True)
            self.data_futures = {id(x): f for x, f in zip(scatter, scattered)}
        else:
            self._scatter = []
            self.data_futures = {}
        self.wait_for_workers_timeout = wait_for_workers_timeout
        self.submit_kwargs = submit_kwargs
        self.waiting_futures = as_completed(
            [], loop=client.loop, with_results=True, raise_errors=False
        )
        # Both mappings are keyed by the dask future of a submitted batch.
        self._results = {}  # dask future -> concurrent.futures.Future
        self._callbacks = {}  # dask future -> completion callback (or None)

    async def _collect(self):
        """Forward completed dask futures to their concurrent futures.

        Runs for as long as ``self._continue`` is true. Each time the
        iteration over ``self.waiting_futures`` ends, the coroutine sleeps
        for 10 ms before iterating again.
        """
        while self._continue:
            async for future, result in self.waiting_futures:
                cf_future = self._results.pop(future)
                callback = self._callbacks.pop(future)
                if future.status == "error":
                    # On error the result is a (type, exception, traceback)
                    # triple: only the exception is propagated.
                    _typ, exc, _tb = result
                    cf_future.set_exception(exc)
                else:
                    cf_future.set_result(result)
                    # apply_async accepts callback=None: calling None here
                    # would raise inside this coroutine and stop the
                    # collection of all the remaining results.
                    if callback is not None:
                        callback(result)
            await asyncio.sleep(0.01)

    def __reduce__(self):
        """Pickle as a bare ``DaskDistributedBackend()`` call.

        No instance state (client, scattered data, submit kwargs) is carried
        over: the constructor runs again with its default arguments on
        unpickling.
        """
        return (DaskDistributedBackend, ())

    def get_nested_backend(self):
        """Return ``(backend, n_jobs)`` for nested calls.

        The backend is a new instance sharing this backend's client and
        ``n_jobs`` is -1.
        """
        return DaskDistributedBackend(client=self.client), -1

    def configure(self, n_jobs=1, parallel=None, **backend_args):
        """Remember ``parallel`` and return ``effective_n_jobs(n_jobs)``.

        ``backend_args`` are accepted but not used.
        """
        self.parallel = parallel
        return self.effective_n_jobs(n_jobs)

    def start_call(self):
        """Start the result collection and reset the per-call futures."""
        self._continue = True
        self.client.loop.add_callback(self._collect)
        self.call_data_futures = _WeakKeyDictionary()

    def stop_call(self):
        """Ask the collection coroutine to stop and drop per-call futures."""
        # The explicit call to clear is required to break a cycling reference
        # to the futures.
        self._continue = False
        # wait for the future collection routine (self._backend._collect) to
        # finish in order to limit asyncio warnings due to aborting _collect
        # during a following backend termination call
        time.sleep(0.01)
        self.call_data_futures.clear()

    def effective_n_jobs(self, n_jobs):
        """Return the total number of cores reported by the client.

        The ``n_jobs`` argument is not used. When the client reports no core
        and ``wait_for_workers_timeout`` is truthy, a probe task is submitted
        and waited for before counting the cores again.

        Raises
        ------
        TimeoutError
            If the probe task did not complete within
            ``wait_for_workers_timeout``.
        """
        effective_n_jobs = sum(self.client.ncores().values())
        if effective_n_jobs != 0 or not self.wait_for_workers_timeout:
            return effective_n_jobs

        # If there is no worker, schedule a probe task to wait for the workers
        # to come up and be available. If the dask cluster is in adaptive mode
        # this task might cause the cluster to provision some workers.
        try:
            self.client.submit(_joblib_probe_task).result(
                timeout=self.wait_for_workers_timeout
            )
        except _TimeoutError as e:
            error_msg = (
                "DaskDistributedBackend has no worker after {} seconds. "
                "Make sure that workers are started and can properly connect "
                "to the scheduler and increase the joblib/dask connection "
                "timeout with:\n\n"
                "parallel_config(backend='dask', wait_for_workers_timeout={})"
            ).format(
                self.wait_for_workers_timeout,
                max(10, 2 * self.wait_for_workers_timeout),
            )
            raise TimeoutError(error_msg) from e
        return sum(self.client.ncores().values())

    async def _to_func_args(self, func):
        """Build a ``Batch`` from ``func.items``, replacing data by futures.

        ``func.items`` is iterated as ``(f, args, kwargs)`` triples. Each
        positional or keyword argument value is replaced by a dask future
        when one is known for it (pre-scattered in ``__init__`` or scattered
        earlier during the current call), or when it is weak-referenceable
        and ``sizeof`` reports more than 1e3, in which case it is scattered
        here. Other values are passed through unchanged.

        Returns
        -------
        tuple
            ``(Batch(tasks), tasks)`` where ``tasks`` is the list of
            ``(f, args, kwargs)`` with substituted arguments.
        """
        # Futures that are dynamically generated during a single call to
        # Parallel.__call__.
        call_data_futures = getattr(self, "call_data_futures", None)

        async def maybe_to_futures(args):
            out = []
            for arg in args:
                f = self.data_futures.get(id(arg))
                if f is None and call_data_futures is not None:
                    try:
                        f = await call_data_futures[arg]
                    except KeyError:
                        pass
                    if f is None:
                        if is_weakrefable(arg) and sizeof(arg) > 1e3:
                            # Automatically scatter large objects to some of
                            # the workers to avoid duplicated data transfers.
                            # Rely on automated inter-worker data stealing if
                            # more workers need to reuse this data
                            # concurrently.
                            # set hash=False - nested scatter calls (i.e
                            # calling client.scatter inside a dask worker)
                            # using hash=True often raise CancelledError,
                            # see dask/distributed#3703
                            _coro = self.client.scatter(
                                arg, asynchronous=True, hash=False
                            )
                            # Centralize the scattering of identical arguments
                            # between concurrent apply_async callbacks by
                            # exposing the running coroutine in
                            # call_data_futures before it completes.
                            t = asyncio.Task(_coro)
                            call_data_futures[arg] = t

                            f = await t

                if f is not None:
                    out.append(f)
                else:
                    out.append(arg)
            return out

        tasks = []
        for f, args, kwargs in func.items:
            args = list(await maybe_to_futures(args))
            kwargs = dict(zip(kwargs.keys(), await maybe_to_futures(kwargs.values())))
            tasks.append((f, args, kwargs))

        return (Batch(tasks), tasks)

    def apply_async(self, func, callback=None):
        """Schedule ``func`` on the cluster and return a future for it.

        The submission itself happens asynchronously on the client's event
        loop. The returned ``concurrent.futures.Future`` also exposes a
        ``get`` attribute aliasing ``result``.
        """
        cf_future = concurrent.futures.Future()
        cf_future.get = cf_future.result  # achieve AsyncResult API

        async def f(func, callback):
            batch, tasks = await self._to_func_args(func)
            # The uuid suffix makes the key different for every submission.
            key = f"{batch!r}-{uuid4().hex}"

            dask_future = self.client.submit(
                _TracebackCapturingWrapper(batch),
                tasks=tasks,
                key=key,
                **self.submit_kwargs,
            )
            self.waiting_futures.add(dask_future)
            self._callbacks[dask_future] = callback
            self._results[dask_future] = cf_future

        self.client.loop.add_callback(f, func, callback)

        return cf_future

    def retrieve_result_callback(self, out):
        """Return ``_retrieve_traceback_capturing_wrapped_call(out)``."""
        return _retrieve_traceback_capturing_wrapped_call(out)

    def abort_everything(self, ensure_ready=True):
        """Tell the client to cancel any task submitted via this instance

        joblib.Parallel will never access those results

        The ``ensure_ready`` argument is accepted but not used.
        """
        with self.waiting_futures.lock:
            self.waiting_futures.futures.clear()
            # Drain the results that were already queued for retrieval.
            while not self.waiting_futures.queue.empty():
                self.waiting_futures.queue.get()

    @contextlib.contextmanager
    def retrieval_context(self):
        """Override ParallelBackendBase.retrieval_context to avoid deadlocks.

        This removes thread from the worker's thread pool (using 'secede').
        Seceding avoids deadlock in nested parallelism settings.

        The thread rejoins the pool when the context exits, including when
        the body of the ``with`` statement raises.
        """
        # See 'joblib.Parallel.__call__' and 'joblib.Parallel.retrieve' for how
        # this is used.
        seceded = hasattr(thread_state, "execution_state")
        if seceded:
            # we are in a worker. Secede to avoid deadlock.
            secede()

        try:
            yield
        finally:
            # Always pair a successful secede() with a rejoin(), even if
            # retrieving the results failed.
            if seceded:
                rejoin()
|