Spaces:
Running
on
Zero
Running
on
Zero
File size: 4,917 Bytes
d1ed09d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 |
from __future__ import annotations
from functools import lru_cache, wraps
from typing import TYPE_CHECKING, Callable, Generic, TypeVar
from dask import config
from dask.compatibility import entry_points
from dask.utils import funcname
if TYPE_CHECKING:
    # Typing-only helpers for ``CreationDispatch.register_inplace``: they let a
    # type checker carry a decorated function's parameter and return types
    # through the decorator. Because this module uses
    # ``from __future__ import annotations``, the annotations that mention
    # these names are not evaluated at import time, so they only need to
    # exist for static analysis (and ``typing_extensions`` is not needed at
    # runtime).
    from typing_extensions import ParamSpec

    BackendFuncParams = ParamSpec("BackendFuncParams")
    BackendFuncReturn = TypeVar("BackendFuncReturn")
class DaskBackendEntrypoint:
    """Base class for collection-backend entrypoints.

    Subclasses provide the collection-creation methods for one library
    backend. Once a collection exists, its own data drives the dispatch of
    compute operations inside individual tasks, so a backend is expected to
    register those data-directed dispatch functions from its ``__init__``.
    """

    @classmethod
    def to_backend_dispatch(cls):
        """Return a dispatch function that moves data onto this backend.

        Abstract hook: the base implementation always raises
        ``NotImplementedError`` and must be overridden by subclasses.
        """
        raise NotImplementedError()

    @staticmethod
    def to_backend(data):
        """Build a new collection on this backend from ``data``.

        Abstract hook: the base implementation always raises
        ``NotImplementedError`` and must be overridden by subclasses.
        """
        raise NotImplementedError()
@lru_cache(maxsize=1)
def detect_entrypoints(entry_point_name):
    """Return a ``{name: entry_point}`` mapping for one entry-point group.

    ``entry_point_name`` is handed unchanged to ``entry_points`` (imported
    from ``dask.compatibility``); ``CreationDispatch.dispatch`` passes
    strings of the form ``"dask.<module>.backends"``. When two entry points
    share a name, the later one wins. Only the most recent argument's result
    is memoized (``maxsize=1``).
    """
    by_name = {}
    for entry_point in entry_points(entry_point_name):
        by_name[entry_point.name] = entry_point
    return by_name
# Type parameter of ``CreationDispatch``: each dispatcher is generic over the
# entrypoint class it stores, bounded (via a string forward reference) by
# ``DaskBackendEntrypoint``.
BackendEntrypointType = TypeVar("BackendEntrypointType", bound="DaskBackendEntrypoint")
class CreationDispatch(Generic[BackendEntrypointType]):
    """Simple backend dispatch for collection-creation functions.

    A ``CreationDispatch`` maps backend labels to instances of
    ``entrypoint_class``. Backends are added with ``register_backend`` or
    discovered lazily from the ``dask.<module_name>.backends`` entry-point
    group. The active backend label is read from the
    ``<module_name>.backend`` config option (falling back to ``default``),
    and attribute access on the dispatcher itself is forwarded to the active
    backend's entrypoint.

    Parameters
    ----------
    module_name : str
        Collection module name. Used to build the config key
        ``"<module_name>.backend"`` and the entry-point group
        ``"dask.<module_name>.backends"``.
    default : str
        Backend label used when the config option is unset or falsy.
    entrypoint_class : type
        Class that every registered backend must be an instance of.
    name : str, optional
        If truthy, stored as the dispatcher's ``__name__``.
    """

    _lookup: dict[str, BackendEntrypointType]
    _module_name: str
    _config_field: str
    _default: str
    _entrypoint_class: type[BackendEntrypointType]

    def __init__(
        self,
        module_name: str,
        default: str,
        entrypoint_class: type[BackendEntrypointType],
        name: str | None = None,
    ):
        self._lookup = {}
        self._module_name = module_name
        self._config_field = f"{module_name}.backend"
        self._default = default
        self._entrypoint_class = entrypoint_class
        if name:
            self.__name__ = name

    def register_backend(
        self, name: str, backend: BackendEntrypointType
    ) -> BackendEntrypointType:
        """Register ``backend`` as the entrypoint for the label ``name``.

        Returns
        -------
        The ``backend`` object that was registered.

        Raises
        ------
        ValueError
            If ``backend`` is not an instance of this dispatcher's
            ``entrypoint_class``.
        """
        if not isinstance(backend, self._entrypoint_class):
            raise ValueError(
                f"This CreationDispatch only supports "
                f"{self._entrypoint_class} registration. "
                f"Got {type(backend)}"
            )
        self._lookup[name] = backend
        return backend

    def dispatch(self, backend: str):
        """Return the entrypoint registered for the label ``backend``.

        If the label has not been registered yet, the
        ``dask.<module_name>.backends`` entry-point group is searched; a
        match there is loaded, instantiated with no arguments, registered,
        and returned.

        Raises
        ------
        ValueError
            If no backend with this label is registered or advertised through
            entry points (or the loaded object fails ``register_backend``'s
            type check).
        """
        try:
            return self._lookup[backend]
        except KeyError:
            pass
        # Not registered yet: fall back to installed entry points. Done
        # outside the ``except`` so any error raised here is not chained to
        # the uninteresting KeyError above.
        entrypoints = detect_entrypoints(f"dask.{self._module_name}.backends")
        if backend in entrypoints:
            return self.register_backend(backend, entrypoints[backend].load()())
        raise ValueError(f"No backend dispatch registered for {backend}")

    @property
    def backend(self) -> str:
        """Return the active collection backend label.

        Reads the ``<module_name>.backend`` config option; if it is unset or
        falsy (e.g. ``None`` or ``""``), the constructor's ``default`` is
        returned instead.
        """
        return config.get(self._config_field, self._default) or self._default

    @backend.setter
    def backend(self, value: str):
        # The backend is controlled only through configuration; direct
        # assignment is always rejected.
        raise RuntimeError(
            f"Set the backend by configuring the {self._config_field} option"
        )

    def register_inplace(
        self,
        backend: str,
        name: str | None = None,
    ) -> Callable[
        [Callable[BackendFuncParams, BackendFuncReturn]],
        Callable[BackendFuncParams, BackendFuncReturn],
    ]:
        """Register a dispatchable function.

        Returns a decorator that attaches the decorated function to the
        ``backend`` entrypoint under ``name`` (default: the function's
        ``__name__``) and returns a wrapper. Calling the wrapper looks that
        attribute up on *this* dispatcher — i.e. on whichever backend is
        active at call time, not necessarily ``backend`` — and calls it.

        Exceptions raised by the dispatched function are re-raised as the
        same exception type with a message naming the function and the
        active backend, chained to the original. If that exception type
        cannot be rebuilt from a single message argument, the original
        exception propagates unchanged instead.
        """

        def decorator(
            fn: Callable[BackendFuncParams, BackendFuncReturn]
        ) -> Callable[BackendFuncParams, BackendFuncReturn]:
            dispatch_name = name or fn.__name__
            dispatcher = self.dispatch(backend)
            setattr(dispatcher, dispatch_name, fn)

            @wraps(fn)
            def wrapper(*args, **kwargs):
                func = getattr(self, dispatch_name)
                try:
                    return func(*args, **kwargs)
                except Exception as e:
                    message = (
                        f"An error occurred while calling the {funcname(func)} "
                        f"method registered to the {self.backend} backend.\n"
                        f"Original Message: {e}"
                    )
                    try:
                        wrapped = type(e)(message)
                    except Exception:
                        # Some exception types (e.g. UnicodeDecodeError) need
                        # several constructor arguments. Adding context must
                        # never mask the real error, so fall back to it.
                        wrapped = None
                    if wrapped is None:
                        # Bare ``raise`` re-raises ``e`` untouched.
                        raise
                    raise wrapped from e

            wrapper.__name__ = dispatch_name
            return wrapper

        return decorator

    def __getattr__(self, item: str):
        """Forward attribute lookups to the active backend's entrypoint.

        Only invoked for names not found by normal attribute lookup. Dunder
        names are not forwarded. Protocol probes such as ``copy``/``pickle``
        asking for ``__setstate__`` or ``__deepcopy__`` would otherwise
        resolve the backend (reading config and possibly raising
        ``ValueError``). On an instance whose ``__dict__`` is still empty
        (as during copy/unpickle), such probes would also recurse through
        ``self.backend`` until ``RecursionError``.
        """
        if item.startswith("__") and item.endswith("__"):
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {item!r}"
            )
        backend = self.dispatch(self.backend)
        return getattr(backend, item)
|