from __future__ import annotations

from functools import lru_cache, wraps
from typing import TYPE_CHECKING, Callable, Generic, TypeVar

from dask import config
from dask.compatibility import entry_points
from dask.utils import funcname

if TYPE_CHECKING:
    from typing_extensions import ParamSpec

    BackendFuncParams = ParamSpec("BackendFuncParams")
    BackendFuncReturn = TypeVar("BackendFuncReturn")


class DaskBackendEntrypoint:
    """Base Collection-Backend Entrypoint Class

    Most methods in this class correspond to collection-creation
    for a specific library backend. Once a collection is created,
    the existing data will be used to dispatch compute operations
    within individual tasks. The backend is responsible for
    ensuring that these data-directed dispatch functions are
    registered when ``__init__`` is called.
    """

    @classmethod
    def to_backend_dispatch(cls):
        """Return a dispatch function to move data to this backend"""
        raise NotImplementedError

    @staticmethod
    def to_backend(data):
        """Create a new collection with this backend"""
        raise NotImplementedError


@lru_cache(maxsize=1)
def detect_entrypoints(entry_point_name):
    """Return a ``{name: entry_point}`` mapping for an entry-point group.

    The result is memoized with ``lru_cache(maxsize=1)``, so only the
    mapping for the most recently requested group is kept.
    """
    entrypoints = entry_points(entry_point_name)
    return {ep.name: ep for ep in entrypoints}


BackendEntrypointType = TypeVar(
    "BackendEntrypointType",
    bound="DaskBackendEntrypoint",
)


class CreationDispatch(Generic[BackendEntrypointType]):
    """Simple backend dispatch for collection-creation functions

    Backend entrypoint objects are stored by label. The active label is
    read from the dask config field ``"<module_name>.backend"`` and falls
    back to ``default``. Attributes that are not found on the dispatch
    object itself are looked up on the active backend entrypoint.
    """

    _lookup: dict[str, BackendEntrypointType]
    _module_name: str
    _config_field: str
    _default: str
    _entrypoint_class: type[BackendEntrypointType]

    def __init__(
        self,
        module_name: str,
        default: str,
        entrypoint_class: type[BackendEntrypointType],
        name: str | None = None,
    ):
        """Create an empty dispatch.

        Parameters
        ----------
        module_name
            Used to build the config field (``f"{module_name}.backend"``)
            and the entry-point group (``f"dask.{module_name}.backends"``).
        default
            Backend label used when the config field is unset or falsy.
        entrypoint_class
            Class that every registered backend must be an instance of.
        name
            Optional value stored as ``__name__`` on this object.
        """
        self._lookup = {}
        self._module_name = module_name
        self._config_field = f"{module_name}.backend"
        self._default = default
        self._entrypoint_class = entrypoint_class
        if name:
            self.__name__ = name

    def register_backend(
        self, name: str, backend: BackendEntrypointType
    ) -> BackendEntrypointType:
        """Register a backend entrypoint object under a backend label

        Returns the registered ``backend``. Raises ``ValueError`` when
        ``backend`` is not an instance of the entrypoint class this
        dispatch was created with.
        """
        if not isinstance(backend, self._entrypoint_class):
            raise ValueError(
                f"This CreationDispatch only supports "
                f"{self._entrypoint_class} registration. "
                f"Got {type(backend)}"
            )
        self._lookup[name] = backend
        return backend

    def dispatch(self, backend: str):
        """Return the desired backend entrypoint

        Explicitly registered backends take priority. Otherwise the
        ``dask.<module_name>.backends`` entry-point group is searched; a
        matching entry point is loaded, instantiated and registered before
        being returned. Raises ``ValueError`` when neither source knows
        ``backend``.
        """
        try:
            return self._lookup[backend]
        except KeyError:
            pass

        # Check entrypoints for the specified backend
        entrypoints = detect_entrypoints(f"dask.{self._module_name}.backends")
        if backend in entrypoints:
            return self.register_backend(backend, entrypoints[backend].load()())

        raise ValueError(f"No backend dispatch registered for {backend}")

    @property
    def backend(self) -> str:
        """Return the desired collection backend"""
        return config.get(self._config_field, self._default) or self._default

    @backend.setter
    def backend(self, value: str):
        # The active backend is driven by configuration only.
        raise RuntimeError(
            f"Set the backend by configuring the {self._config_field} option"
        )

    def register_inplace(
        self,
        backend: str,
        name: str | None = None,
    ) -> Callable[
        [Callable[BackendFuncParams, BackendFuncReturn]],
        Callable[BackendFuncParams, BackendFuncReturn],
    ]:
        """Register dispatchable function

        Returns a decorator that stores the decorated function as an
        attribute of the entrypoint registered for ``backend`` (under
        ``name``, or the function's own ``__name__`` when ``name`` is not
        given). The decorator returns a wrapper which, on each call, looks
        the attribute up on the *currently configured* backend and calls it.

        Exceptions raised by the call are re-raised as the same type with a
        message naming the function and backend, chained to the original.
        If the exception type cannot be constructed from a single message,
        the original exception is re-raised unchanged.
        """

        def decorator(
            fn: Callable[BackendFuncParams, BackendFuncReturn]
        ) -> Callable[BackendFuncParams, BackendFuncReturn]:
            dispatch_name = name or fn.__name__
            dispatcher = self.dispatch(backend)
            setattr(dispatcher, dispatch_name, fn)

            @wraps(fn)
            def wrapper(*args, **kwargs):
                func = getattr(self, dispatch_name)
                try:
                    return func(*args, **kwargs)
                except Exception as e:
                    try:
                        annotated = type(e)(
                            f"An error occurred while calling the {funcname(func)} "
                            f"method registered to the {self.backend} backend.\n"
                            f"Original Message: {e}"
                        )
                    except TypeError:
                        # Some exception classes need more than a single
                        # message argument (e.g. UnicodeDecodeError).
                        annotated = None
                    if annotated is None:
                        # Do not let a constructor failure hide the real error.
                        raise
                    raise annotated from e

            wrapper.__name__ = dispatch_name
            return wrapper

        return decorator

    def __getattr__(self, item: str):
        """
        Return the appropriate attribute for the current backend

        Dunder names are never delegated and raise ``AttributeError``.
        Protocol probes such as ``__setstate__`` or ``__deepcopy__`` can
        arrive before ``__init__`` has populated this object; delegating
        them would re-enter ``__getattr__`` through ``self.backend`` and
        recurse without end.
        """
        if item.startswith("__") and item.endswith("__"):
            raise AttributeError(item)
        backend = self.dispatch(self.backend)
        return getattr(backend, item)