File size: 4,917 Bytes
d1ed09d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
from __future__ import annotations

from functools import lru_cache, wraps
from typing import TYPE_CHECKING, Callable, Generic, TypeVar

from dask import config
from dask.compatibility import entry_points
from dask.utils import funcname

if TYPE_CHECKING:
    from typing_extensions import ParamSpec

    BackendFuncParams = ParamSpec("BackendFuncParams")
    BackendFuncReturn = TypeVar("BackendFuncReturn")


class DaskBackendEntrypoint:
    """Base class for collection-backend entrypoints

    A backend subclasses this and provides collection-creation methods for
    one library. After a collection exists, its own data drives dispatch of
    compute operations inside individual tasks. Subclasses must therefore
    register those data-directed dispatch functions from ``__init__``.
    """

    @staticmethod
    def to_backend(data):
        """Build a new collection that uses this backend (abstract)"""
        raise NotImplementedError()

    @classmethod
    def to_backend_dispatch(cls):
        """Give back the dispatch function that moves data to this backend (abstract)"""
        raise NotImplementedError()


@lru_cache(maxsize=1)
def detect_entrypoints(entry_point_name):
    """Index the entry points found for ``entry_point_name`` by name

    Collects the objects yielded by ``entry_points(entry_point_name)`` into a
    dict keyed by each object's ``name``; on duplicate names the last one
    wins. ``entry_point_name`` is presumably an entry-point group name
    (callers pass e.g. ``"dask.<module>.backends"``) — confirm against
    ``dask.compatibility.entry_points``.

    The result is memoized with ``lru_cache(maxsize=1)``, so only the most
    recently requested ``entry_point_name`` stays cached.
    """
    found = {}
    for entrypoint in entry_points(entry_point_name):
        found[entrypoint.name] = entrypoint
    return found


# Type variable for the entrypoint type a CreationDispatch holds. The bound is
# a forward-reference string, so it does not need the class object at runtime.
BackendEntrypointType = TypeVar("BackendEntrypointType", bound="DaskBackendEntrypoint")


class CreationDispatch(Generic[BackendEntrypointType]):
    """Simple backend dispatch for collection-creation functions

    Backend entrypoint instances are stored by label in ``_lookup``. The
    active label is read from the ``"<module_name>.backend"`` config field
    (falling back to ``default``). Any attribute not found on this object
    itself is looked up on the active backend entrypoint.

    Parameters
    ----------
    module_name : str
        Used to build the config field ``"<module_name>.backend"`` and the
        entry-point lookup key ``"dask.<module_name>.backends"``.
    default : str
        Backend label used when the config field is unset or falsy.
    entrypoint_class : type
        Every registered backend must be an instance of this class.
    name : str, optional
        If truthy, stored as ``__name__`` on this object.
    """

    _lookup: dict[str, BackendEntrypointType]  # backend label -> entrypoint
    _module_name: str
    _config_field: str  # config key holding the active backend label
    _default: str  # label used when the config field is unset/falsy
    _entrypoint_class: type[BackendEntrypointType]  # required backend type

    def __init__(
        self,
        module_name: str,
        default: str,
        entrypoint_class: type[BackendEntrypointType],
        name: str | None = None,
    ):
        self._lookup = {}
        self._module_name = module_name
        self._config_field = f"{module_name}.backend"
        self._default = default
        self._entrypoint_class = entrypoint_class
        if name:
            self.__name__ = name

    def register_backend(
        self, name: str, backend: BackendEntrypointType
    ) -> BackendEntrypointType:
        """Register a backend entrypoint instance under the label ``name``

        Returns
        -------
        The registered ``backend``, so the call can be used inline.

        Raises
        ------
        ValueError
            If ``backend`` is not an instance of the ``entrypoint_class``
            given to ``__init__``.
        """
        if not isinstance(backend, self._entrypoint_class):
            raise ValueError(
                f"This CreationDispatch only supports "
                f"{self._entrypoint_class} registration. "
                f"Got {type(backend)}"
            )
        self._lookup[name] = backend
        return backend

    def dispatch(self, backend: str):
        """Return the desired backend entrypoint

        A label that is not registered yet is searched for among the
        ``"dask.<module_name>.backends"`` entry points. A match is loaded,
        instantiated with no arguments, registered, and returned.

        Raises
        ------
        ValueError
            If ``backend`` is neither registered nor an available entry point.
        """
        try:
            impl = self._lookup[backend]
        except KeyError:
            # Check entrypoints for the specified backend
            entrypoints = detect_entrypoints(f"dask.{self._module_name}.backends")
            if backend in entrypoints:
                return self.register_backend(backend, entrypoints[backend].load()())
        else:
            return impl
        raise ValueError(f"No backend dispatch registered for {backend}")

    @property
    def backend(self) -> str:
        """Return the desired collection backend

        This is the config value at ``self._config_field``. It falls back to
        the default label when that value is missing or falsy.
        """
        return config.get(self._config_field, self._default) or self._default

    @backend.setter
    def backend(self, value: str):
        # The backend is config-driven only; direct assignment is refused.
        raise RuntimeError(
            f"Set the backend by configuring the {self._config_field} option"
        )

    def register_inplace(
        self,
        backend: str,
        name: str | None = None,
    ) -> Callable[
        [Callable[BackendFuncParams, BackendFuncReturn]],
        Callable[BackendFuncParams, BackendFuncReturn],
    ]:
        """Register dispatchable function

        Returns a decorator. When applied to ``fn``, the decorator stores
        ``fn`` as attribute ``name`` (default ``fn.__name__``) on the
        entrypoint registered for ``backend``. It then returns a wrapper that
        looks that attribute up on the *currently configured* backend at each
        call.

        If the dispatched call raises, the error is re-raised as the same
        exception type with a message naming the function and backend,
        chained to the original. If that exception type cannot be built from
        a single message argument, the original exception propagates
        unchanged.

        Raises
        ------
        ValueError
            At decoration time, if ``backend`` cannot be dispatched.
        """

        def decorator(
            fn: Callable[BackendFuncParams, BackendFuncReturn]
        ) -> Callable[BackendFuncParams, BackendFuncReturn]:
            dispatch_name = name or fn.__name__
            dispatcher = self.dispatch(backend)
            setattr(dispatcher, dispatch_name, fn)

            @wraps(fn)
            def wrapper(*args, **kwargs):
                # Resolved per call so that changing the configured backend
                # takes effect without re-decorating.
                func = getattr(self, dispatch_name)
                try:
                    return func(*args, **kwargs)
                except Exception as e:
                    message = (
                        f"An error occurred while calling the {funcname(func)} "
                        f"method registered to the {self.backend} backend.\n"
                        f"Original Message: {e}"
                    )
                    try:
                        annotated = type(e)(message)
                    except Exception:
                        # Some exception types need extra constructor
                        # arguments (e.g. UnicodeDecodeError). Building one
                        # from a single message would raise here and hide the
                        # real error, so re-raise the original instead.
                        annotated = None
                    if annotated is None:
                        raise
                    raise annotated from e

            wrapper.__name__ = dispatch_name
            return wrapper

        return decorator

    def __getattr__(self, item: str):
        """
        Return the appropriate attribute for the current backend

        Raises
        ------
        AttributeError
            If ``item`` is a dunder name or one of this dispatcher's own
            attributes, or if the current backend lacks ``item``. The
            dispatcher's own attributes only reach ``__getattr__`` when the
            instance has not been initialised, e.g. while ``copy`` or
            ``pickle`` rebuilds it.
        """
        # Never forward protocol probes or our own state to a backend. For an
        # uninitialised instance, forwarding ``_config_field``/``backend``
        # would re-enter this method through ``self.backend`` and recurse
        # until RecursionError. For dunder probes (``__deepcopy__``,
        # ``__setstate__``, ...), forwarding would trigger backend dispatch
        # and could raise ValueError instead of AttributeError.
        if (item.startswith("__") and item.endswith("__")) or item in {
            "_lookup",
            "_module_name",
            "_config_field",
            "_default",
            "_entrypoint_class",
            "backend",
        }:
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {item!r}"
            )
        backend = self.dispatch(self.backend)
        return getattr(backend, item)