"""
A threaded shared-memory scheduler
See local.py
"""
from __future__ import annotations

import atexit
import multiprocessing.pool
import sys
import threading
from collections import defaultdict
from collections.abc import Hashable, Mapping, Sequence
from concurrent.futures import Executor, ThreadPoolExecutor
from threading import Lock, current_thread

from dask import config
from dask.local import MultiprocessingPoolExecutor, get_async
from dask.system import CPU_COUNT


def _thread_get_id():
    return current_thread().ident


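# Shared scheduler state: ``default_pool`` is created lazily on first use from
# the main thread, while ``pools`` caches per-thread executors keyed by the
# requested worker count. All access is guarded by ``pools_lock``.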
main_thread = current_thread()
default_pool: Executor | None = None
pools: defaultdict[threading.Thread, dict[int, Executor]] = defaultdict(dict)
pools_lock = Lock()


def pack_exception(e, dumps):
    # Package the exception together with its traceback so get_async can
    # re-raise it with full context in the calling thread. ``dumps`` is unused
    # here but kept for API compatibility with serializing schedulers.
    return e, sys.exc_info()[2]


def get(
    dsk: Mapping,
    keys: Sequence[Hashable] | Hashable,
    cache=None,
    num_workers=None,
    pool=None,
    **kwargs,
):
"""Threaded cached implementation of dask.get
Parameters
----------
dsk: dict
A dask dictionary specifying a workflow
keys: key or list of keys
Keys corresponding to desired data
num_workers: integer of thread count
The number of threads to use in the ThreadPool that will actually execute tasks
cache: dict-like (optional)
Temporary storage of results
Examples
--------
>>> inc = lambda x: x + 1
>>> add = lambda x, y: x + y
>>> dsk = {'x': 1, 'y': 2, 'z': (inc, 'x'), 'w': (add, 'z', 'y')}
>>> get(dsk, 'w')
4
>>> get(dsk, ['w', 'y'])
(4, 2)
"""
    global default_pool

    # Explicit arguments win; otherwise fall back to the dask config
    pool = pool or config.get("pool", None)
    num_workers = num_workers or config.get("num_workers", None)
    thread = current_thread()

    with pools_lock:
        if pool is None:
            if num_workers is None and thread is main_thread:
                # Lazily create one shared default pool sized to the CPU count
                if default_pool is None:
                    default_pool = ThreadPoolExecutor(CPU_COUNT)
                    atexit.register(default_pool.shutdown)
                pool = default_pool
            elif thread in pools and num_workers in pools[thread]:
                # Reuse a cached pool for this thread / worker-count pair
                pool = pools[thread][num_workers]
            else:
                pool = ThreadPoolExecutor(num_workers)
                atexit.register(pool.shutdown)
                pools[thread][num_workers] = pool
        elif isinstance(pool, multiprocessing.pool.Pool):
            # Adapt a raw multiprocessing pool to the Executor interface
            pool = MultiprocessingPoolExecutor(pool)

    results = get_async(
        pool.submit,
        pool._max_workers,
        dsk,
        keys,
        cache=cache,
        get_id=_thread_get_id,
        pack_exception=pack_exception,
        **kwargs,
    )
    # Clean up pools associated with dead threads
    with pools_lock:
        active_threads = set(threading.enumerate())
        if thread is not main_thread:
            for t in list(pools):
                if t not in active_threads:
                    for p in pools.pop(t).values():
                        p.shutdown()

    return results
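

# A minimal usage sketch (illustrative, not part of the library API): build a
# tiny task graph and execute it on this scheduler with an explicit thread
# count. ``inc`` and ``double`` are hypothetical helpers defined for the demo.
if __name__ == "__main__":
    def inc(x):
        return x + 1

    def double(x):
        return 2 * x

    demo_graph = {"a": 1, "b": (inc, "a"), "c": (double, "b")}
    # Runs on a fresh 2-thread pool, cached for the main thread
    print(get(demo_graph, "c", num_workers=2))  # prints 4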