from __future__ import annotations

from typing import cast, Callable, Generic, List, Optional, Type, TypeVar, Union

import torch

__all__ = ['Future', 'collect_all', 'wait_all']

T = TypeVar("T")
S = TypeVar("S")


class _PyFutureMeta(type(torch._C.Future), type(Generic)):  # type: ignore[misc, no-redef]
    pass


class Future(torch._C.Future, Generic[T], metaclass=_PyFutureMeta):
    r"""

    Wrapper around a ``torch._C.Future`` which encapsulates an asynchronous

    execution of a callable, e.g. :meth:`~torch.distributed.rpc.rpc_async`. It

    also exposes a set of APIs to add callback functions and set results.



    .. warning:: GPU support is a beta feature, subject to changes.

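    Example (an illustrative, CPU-only sketch of the typical workflow; the
    result is set locally rather than by an RPC)::
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_FUTURES)
        >>> fut = torch.futures.Future()
        >>> chained = fut.then(lambda f: f.wait() + 1)
        >>> fut.set_result(1)
        >>> chained.wait()
        2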
    """

    def __init__(self, *, devices: Optional[List[Union[int, str, torch.device]]] = None):
        r"""

        Create an empty unset ``Future``. If the future is intended to hold

        values containing CUDA tensors, (a superset of) their CUDA devices must

        be specified at construction. (This is only supported if

        ``torch.cuda.is_available()`` returns ``True``). This is needed to

        ensure proper CUDA stream synchronization. The child futures, returned

        by the ``then`` method, will inherit these devices.



        Args:

            devices(``List[Union[int, str, torch.device]]``, optional): the set

                of devices on which tensors contained in this future's value are

                allowed to reside and on which callbacks are allowed to operate.

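        Example (an illustrative sketch; assumes ``torch.cuda.is_available()``
        returns ``True`` and that device ``"cuda:0"`` exists)::
            >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_FUTURES)
            >>> # Hypothetical device choice; omit ``devices`` for CPU-only values.
            >>> fut = torch.futures.Future(devices=["cuda:0"])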
        """
        if devices is None:
            devices = []
        super().__init__([torch.device(d) for d in devices])

    def done(self) -> bool:
        r"""

        Return ``True`` if this ``Future`` is done. A ``Future`` is done if it

        has a result or an exception.



        If the value contains tensors that reside on GPUs, ``Future.done()``

        will return ``True`` even if the asynchronous kernels that are

        populating those tensors haven't yet completed running on the device,

        because at such stage the result is already usable, provided one

        performs the appropriate synchronizations (see :meth:`wait`).

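        Example (a minimal, illustrative sketch in which the result is set
        locally rather than by an asynchronous call)::
            >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_FUTURES)
            >>> fut = torch.futures.Future()
            >>> fut.done()
            False
            >>> fut.set_result(1)
            >>> fut.done()
            True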
        """
        return super().done()

    def wait(self) -> T:
        r"""

        Block until the value of this ``Future`` is ready.



        If the value contains tensors that reside on GPUs, then an additional

        synchronization is performed with the kernels (executing on the device)

        which may be asynchronously populating those tensors. Such sync is

        non-blocking, which means that ``wait()`` will insert the necessary

        instructions in the current streams to ensure that further operations

        enqueued on those streams will be properly scheduled after the async

        kernels but, once that is done, ``wait()`` will return, even if those

        kernels are still running. No further synchronization is required when

        accessing and using the values, as long as one doesn't change streams.



        Returns:

            The value held by this ``Future``. If the function (callback or RPC)

            creating the value has thrown an error, this ``wait`` method will

            also throw an error.

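        Example (a minimal, illustrative CPU-only sketch; the result is set
        locally via :meth:`set_result` rather than by an RPC)::
            >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_FUTURES)
            >>> fut = torch.futures.Future()
            >>> fut.set_result(torch.ones(2))
            >>> fut.wait()
            tensor([1., 1.])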
        """
        return super().wait()

    def value(self) -> T:
        r"""

        Obtain the value of an already-completed future.



        This method should only be called after a call to :meth:`wait` has

        completed, or inside a callback function passed to :meth:`then`. In

        other cases this ``Future`` may not yet hold a value and calling

        ``value()`` could fail.



        If the value contains tensors that reside on GPUs, then this method will

        *not* perform any additional synchronization. This should be done

        beforehand, separately, through a call to :meth:`wait` (except within

        callbacks, for which it's already being taken care of by :meth:`then`).



        Returns:

            The value held by this ``Future``. If the function (callback or RPC)

            creating the value has thrown an error, this ``value()`` method will

            also throw an error.

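        Example (an illustrative sketch; the value is set locally and
        ``wait()`` is called first, as required)::
            >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_FUTURES)
            >>> fut = torch.futures.Future()
            >>> fut.set_result(3)
            >>> fut.wait()
            3
            >>> fut.value()
            3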
        """
        return super().value()

    def then(self, callback: Callable[[Future[T]], S]) -> Future[S]:
        r"""

        Append the given callback function to this ``Future``, which will be run

        when the ``Future`` is completed.  Multiple callbacks can be added to

        the same ``Future``, but the order in which they will be executed cannot

        be guaranteed (to enforce a certain order consider chaining:

        ``fut.then(cb1).then(cb2)``). The callback must take one argument, which

        is the reference to this ``Future``. The callback function can use the

        :meth:`value` method to get the value. Note that if this ``Future`` is

        already completed, the given callback will be run immediately inline.



        If the ``Future``'s value contains tensors that reside on GPUs, the

        callback might be invoked while the async kernels that are populating

        those tensors haven't yet finished executing on the device. However, the

        callback will be invoked with some dedicated streams set as current

        (fetched from a global pool) which will be synchronized with those

        kernels. Hence any operation performed by the callback on these tensors

        will be scheduled on the device after the kernels complete. In other

        words, as long as the callback doesn't switch streams, it can safely

        manipulate the result without any additional synchronization. This is

        similar to the non-blocking behavior of :meth:`wait`.



        Similarly, if the callback returns a value that contains tensors that

        reside on a GPU, it can do so even if the kernels that are producing

        these tensors are still running on the device, as long as the callback

        didn't change streams during its execution. If one wants to change

        streams, one must be careful to re-synchronize them with the original

        streams, that is, those that were current when the callback was invoked.



        Args:

            callback(``Callable``): a ``Callable`` that takes this ``Future`` as

                                    the only argument.



        Returns:

            A new ``Future`` object that holds the return value of the

            ``callback`` and will be marked as completed when the given

            ``callback`` finishes.



        .. note:: Note that if the callback function throws, either

            through the original future being completed with an exception and

            calling ``fut.wait()``, or through other code in the callback, the

            future returned by ``then`` will be marked appropriately with the

            encountered error. However, if this callback later completes

            additional futures, those futures are not marked as completed with

            an error and the user is responsible for handling completion/waiting

            on those futures independently.



        Example::

            >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_FUTURES)

            >>> def callback(fut):

            ...     print(f"RPC return value is {fut.wait()}.")

            >>> fut = torch.futures.Future()

            >>> # The inserted callback will print the return value when

            >>> # receiving the response from "worker1"

            >>> cb_fut = fut.then(callback)

            >>> chain_cb_fut = cb_fut.then(

            ...     lambda x : print(f"Chained cb done. {x.wait()}")

            ... )

            >>> fut.set_result(5)

            RPC return value is 5.

            Chained cb done. None

        """
        return cast(Future[S], super().then(callback))

    def add_done_callback(self, callback: Callable[[Future[T]], None]) -> None:
        r"""

        Append the given callback function to this ``Future``, which will be run

        when the ``Future`` is completed.  Multiple callbacks can be added to

        the same ``Future``, but the order in which they will be executed cannot

        be guaranteed. The callback must take one argument, which is the

        reference to this ``Future``. The callback function can use the

        :meth:`value` method to get the value. Note that if this ``Future`` is

        already completed, the given callback will be run inline.



        We recommend that you use the :meth:`then` method as it provides a way

        to synchronize after your callback has completed. ``add_done_callback``

        can be cheaper if your callback does not return anything. But both

        :meth:`then` and ``add_done_callback`` use the same callback

        registration API under the hood.



        With respect to GPU tensors, this method behaves in the same way as

        :meth:`then`.



        Args:

            callback(``Future``): a ``Callable`` that takes in one argument,

                which is the reference to this ``Future``.



        .. note:: Note that if the callback function throws, either

            through the original future being completed with an exception and

            calling ``fut.wait()``, or through other code in the callback,

            error handling must be carefully taken care of. For example, if

            this callback later completes additional futures, those futures are

            not marked as completed with an error and the user is responsible

            for handling completion/waiting on those futures independently.



        Example::

            >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_FUTURES)

            >>> def callback(fut):

            ...     print("This will run after the future has finished.")

            ...     print(fut.wait())

            >>> fut = torch.futures.Future()

            >>> fut.add_done_callback(callback)

            >>> fut.set_result(5)

            This will run after the future has finished.

            5

        """
        super().add_done_callback(callback)

    def set_result(self, result: T) -> None:
        r"""

        Set the result for this ``Future``, which will mark this ``Future`` as

        completed and trigger all attached callbacks. Note that a ``Future``

        cannot be marked completed twice.



        If the result contains tensors that reside on GPUs, this method can be

        called even if the asynchronous kernels that are populating those

        tensors haven't yet completed running on the device, provided that the

        streams on which those kernels were enqueued are set as the current ones

        when this method is called. Put simply, it's safe to call this method

        immediately after launching those kernels, without any additional

        synchronization, as long as one doesn't change streams in between. This

        method will record events on all the relevant current streams and will

        use them to ensure proper scheduling for all the consumers of this

        ``Future``.



        Args:

            result (object): the result object of this ``Future``.



        Example::

            >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_FUTURES)

            >>> import threading

            >>> import time

            >>> def slow_set_future(fut, value):

            ...     time.sleep(0.5)

            ...     fut.set_result(value)

            >>> fut = torch.futures.Future()

            >>> t = threading.Thread(

            ...     target=slow_set_future,

            ...     args=(fut, torch.ones(2) * 3)

            ... )

            >>> t.start()

            >>> print(fut.wait())

            tensor([3., 3.])

            >>> t.join()

        """
        super().set_result(result)

    def set_exception(self, result: T) -> None:
        r"""

        Set an exception for this ``Future``, which will mark this ``Future`` as

        completed with an error and trigger all attached callbacks. Note that

        when calling wait()/value() on this ``Future``, the exception set here

        will be raised inline.



        Args:

            result (BaseException): the exception for this ``Future``.



        Example::

            >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_FUTURES)

            >>> fut = torch.futures.Future()

            >>> fut.set_exception(ValueError("foo"))

            >>> fut.wait()

            Traceback (most recent call last):

            ...

            ValueError: foo

        """
        assert isinstance(result, Exception), f"{result} is of type {type(result)}, not an Exception."

        def raise_error(fut_result):
            raise fut_result

        super()._set_unwrap_func(raise_error)
        self.set_result(result)  # type: ignore[arg-type]


def collect_all(futures: List[Future]) -> Future[List[Future]]:
    r"""

    Collects the provided :class:`~torch.futures.Future` objects into a single

    combined :class:`~torch.futures.Future` that is completed when all of the

    sub-futures are completed.



    Args:

        futures (list): a list of :class:`~torch.futures.Future` objects.



    Returns:

        Returns a :class:`~torch.futures.Future` object to a list of the passed

        in Futures.



    Example::

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_FUTURES)

        >>> fut0 = torch.futures.Future()

        >>> fut1 = torch.futures.Future()

        >>> fut = torch.futures.collect_all([fut0, fut1])

        >>> fut0.set_result(0)

        >>> fut1.set_result(1)

        >>> fut_list = fut.wait()

        >>> print(f"fut0 result = {fut_list[0].wait()}")

        fut0 result = 0

        >>> print(f"fut1 result = {fut_list[1].wait()}")

        fut1 result = 1

    """
    return cast(Future[List[Future]], torch._C._collect_all(cast(List[torch._C.Future], futures)))


def wait_all(futures: List[Future]) -> List:
    r"""

    Waits for all provided futures to be complete, and returns

    the list of completed values. If any of the futures encounters an error,

    the method will exit early and report the error not waiting for other

    futures to complete.



    Args:

        futures (list): a list of :class:`~torch.futures.Future` object.



    Returns:

        A list of the completed :class:`~torch.futures.Future` results. This

        method will throw an error if ``wait`` on any

        :class:`~torch.futures.Future` throws.

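    Example (an illustrative sketch mirroring the :func:`~torch.futures.collect_all`
    example above; results are set locally)::
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_FUTURES)
        >>> fut0 = torch.futures.Future()
        >>> fut1 = torch.futures.Future()
        >>> fut0.set_result(0)
        >>> fut1.set_result(1)
        >>> torch.futures.wait_all([fut0, fut1])
        [0, 1]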
    """
    return [fut.wait() for fut in torch._C._collect_all(cast(List[torch._C.Future], futures)).wait()]