""" | |
A threaded shared-memory scheduler | |
See local.py | |
""" | |

from __future__ import annotations

import atexit
import multiprocessing.pool
import sys
import threading
from collections import defaultdict
from collections.abc import Hashable, Mapping, Sequence
from concurrent.futures import Executor, ThreadPoolExecutor
from threading import Lock, current_thread

from dask import config
from dask.local import MultiprocessingPoolExecutor, get_async
from dask.system import CPU_COUNT
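

# Helper passed to get_async: identifies which worker thread produced a result.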
def _thread_get_id():
    return current_thread().ident
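

# Module-level executor registry: one process-wide default pool for the main
# thread, plus per-thread pools cached by requested worker count.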
main_thread = current_thread()
default_pool: Executor | None = None
pools: defaultdict[threading.Thread, dict[int, Executor]] = defaultdict(dict)
pools_lock = Lock()
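

# Keep the traceback alongside the exception so get_async can re-raise it
# with full context in the calling thread.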
def pack_exception(e, dumps):
    return e, sys.exc_info()[2]


def get(
    dsk: Mapping,
    keys: Sequence[Hashable] | Hashable,
    cache=None,
    num_workers=None,
    pool=None,
    **kwargs,
):
"""Threaded cached implementation of dask.get | |
Parameters | |
---------- | |
dsk: dict | |
A dask dictionary specifying a workflow | |
keys: key or list of keys | |
Keys corresponding to desired data | |
num_workers: integer of thread count | |
The number of threads to use in the ThreadPool that will actually execute tasks | |
cache: dict-like (optional) | |
Temporary storage of results | |
Examples | |
-------- | |
>>> inc = lambda x: x + 1 | |
>>> add = lambda x, y: x + y | |
>>> dsk = {'x': 1, 'y': 2, 'z': (inc, 'x'), 'w': (add, 'z', 'y')} | |
>>> get(dsk, 'w') | |
4 | |
>>> get(dsk, ['w', 'y']) | |
(4, 2) | |
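
    A fixed-size pool can also be requested via ``num_workers`` (the count of
    4 here is purely illustrative):

    >>> get(dsk, 'w', num_workers=4)
    4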
""" | |
    global default_pool
    pool = pool or config.get("pool", None)
    num_workers = num_workers or config.get("num_workers", None)
    thread = current_thread()

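    # Select an executor under the lock: reuse the process-wide default pool
    # on the main thread, or cache one pool per (thread, num_workers) pair so
    # repeated calls do not spawn a fresh pool each time.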
    with pools_lock:
        if pool is None:
            if num_workers is None and thread is main_thread:
                if default_pool is None:
                    default_pool = ThreadPoolExecutor(CPU_COUNT)
                    atexit.register(default_pool.shutdown)
                pool = default_pool
            elif thread in pools and num_workers in pools[thread]:
                pool = pools[thread][num_workers]
            else:
                pool = ThreadPoolExecutor(num_workers)
                atexit.register(pool.shutdown)
                pools[thread][num_workers] = pool
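        # A raw multiprocessing.pool.Pool is wrapped so it exposes the
        # Executor interface that get_async expects.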
        elif isinstance(pool, multiprocessing.pool.Pool):
            pool = MultiprocessingPoolExecutor(pool)

    results = get_async(
        pool.submit,
        pool._max_workers,
        dsk,
        keys,
        cache=cache,
        get_id=_thread_get_id,
        pack_exception=pack_exception,
        **kwargs,
    )

    # Clean up pools associated with dead threads
    with pools_lock:
        active_threads = set(threading.enumerate())
        if thread is not main_thread:
            for t in list(pools):
                if t not in active_threads:
                    for p in pools.pop(t).values():
                        p.shutdown()

    return results