""" A threaded shared-memory scheduler See local.py """ from __future__ import annotations import atexit import multiprocessing.pool import sys import threading from collections import defaultdict from collections.abc import Hashable, Mapping, Sequence from concurrent.futures import Executor, ThreadPoolExecutor from threading import Lock, current_thread from dask import config from dask.local import MultiprocessingPoolExecutor, get_async from dask.system import CPU_COUNT def _thread_get_id(): return current_thread().ident main_thread = current_thread() default_pool: Executor | None = None pools: defaultdict[threading.Thread, dict[int, Executor]] = defaultdict(dict) pools_lock = Lock() def pack_exception(e, dumps): return e, sys.exc_info()[2] def get( dsk: Mapping, keys: Sequence[Hashable] | Hashable, cache=None, num_workers=None, pool=None, **kwargs, ): """Threaded cached implementation of dask.get Parameters ---------- dsk: dict A dask dictionary specifying a workflow keys: key or list of keys Keys corresponding to desired data num_workers: integer of thread count The number of threads to use in the ThreadPool that will actually execute tasks cache: dict-like (optional) Temporary storage of results Examples -------- >>> inc = lambda x: x + 1 >>> add = lambda x, y: x + y >>> dsk = {'x': 1, 'y': 2, 'z': (inc, 'x'), 'w': (add, 'z', 'y')} >>> get(dsk, 'w') 4 >>> get(dsk, ['w', 'y']) (4, 2) """ global default_pool pool = pool or config.get("pool", None) num_workers = num_workers or config.get("num_workers", None) thread = current_thread() with pools_lock: if pool is None: if num_workers is None and thread is main_thread: if default_pool is None: default_pool = ThreadPoolExecutor(CPU_COUNT) atexit.register(default_pool.shutdown) pool = default_pool elif thread in pools and num_workers in pools[thread]: pool = pools[thread][num_workers] else: pool = ThreadPoolExecutor(num_workers) atexit.register(pool.shutdown) pools[thread][num_workers] = pool elif isinstance(pool, multiprocessing.pool.Pool): pool = MultiprocessingPoolExecutor(pool) results = get_async( pool.submit, pool._max_workers, dsk, keys, cache=cache, get_id=_thread_get_id, pack_exception=pack_exception, **kwargs, ) # Cleanup pools associated to dead threads with pools_lock: active_threads = set(threading.enumerate()) if thread is not main_thread: for t in list(pools): if t not in active_threads: for p in pools.pop(t).values(): p.shutdown() return results