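# NOTE: "Spaces: / Running / Running" was page-extraction residue found at the
# top of this listing; it is not part of the module and is kept only as a comment.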
from __future__ import annotations | |
from typing import cast, Callable, Generic, List, Optional, Type, TypeVar, Union | |
import torch | |
__all__ = ['Future', 'collect_all', 'wait_all'] | |
# Type variables used by the generic ``Future`` annotations below:
# ``T`` is the type of the value held by a future, ``S`` is the type returned
# by a callback passed to ``Future.then``.
T = TypeVar("T")
S = TypeVar("S")
class _PyFutureMeta(type(torch._C.Future), type(Generic)):  # type: ignore[misc, no-redef]
    """Metaclass derived from the metaclasses of ``torch._C.Future`` and ``Generic``.

    It exists so that a single class can list both ``torch._C.Future`` and
    ``Generic[...]`` as bases: Python requires the metaclass of a class to be
    a subclass of the metaclasses of all of its bases.
    """
class Future(torch._C.Future, Generic[T], metaclass=_PyFutureMeta):
    r"""
    Wrapper around a ``torch._C.Future`` which encapsulates an asynchronous
    execution of a callable, e.g. :meth:`~torch.distributed.rpc.rpc_async`. It
    also exposes a set of APIs to add callback functions and set results.

    .. warning:: GPU support is a beta feature, subject to changes.
    """

    def __init__(self, *, devices: Optional[List[Union[int, str, torch.device]]] = None):
        r"""
        Create an empty unset ``Future``. If the future is intended to hold
        values containing CUDA tensors, (a superset of) their CUDA devices must
        be specified at construction. (This is only supported if
        ``torch.cuda.is_available()`` returns ``True``). This is needed to
        ensure proper CUDA stream synchronization. The child futures, returned
        by the ``then`` method, will inherit these devices.

        Args:
            devices(``List[Union[int, str, torch.device]]``, optional): the set
                of devices on which tensors contained in this future's value are
                allowed to reside and on which callbacks are allowed to operate.
        """
        if devices is None:
            devices = []
        # Normalize every entry (int index, string, or device) to torch.device
        # before handing the list to the underlying C++ future.
        super().__init__([torch.device(d) for d in devices])

    def done(self) -> bool:
        r"""
        Return ``True`` if this ``Future`` is done. A ``Future`` is done if it
        has a result or an exception.

        If the value contains tensors that reside on GPUs, ``Future.done()``
        will return ``True`` even if the asynchronous kernels that are
        populating those tensors haven't yet completed running on the device,
        because at such stage the result is already usable, provided one
        performs the appropriate synchronizations (see :meth:`wait`).
        """
        return super().done()

    def wait(self) -> T:
        r"""
        Block until the value of this ``Future`` is ready.

        If the value contains tensors that reside on GPUs, then an additional
        synchronization is performed with the kernels (executing on the device)
        which may be asynchronously populating those tensors. Such sync is
        non-blocking, which means that ``wait()`` will insert the necessary
        instructions in the current streams to ensure that further operations
        enqueued on those streams will be properly scheduled after the async
        kernels but, once that is done, ``wait()`` will return, even if those
        kernels are still running. No further synchronization is required when
        accessing and using the values, as long as one doesn't change streams.

        Returns:
            The value held by this ``Future``. If the function (callback or RPC)
            creating the value has thrown an error, this ``wait`` method will
            also throw an error.
        """
        return super().wait()

    def value(self) -> T:
        r"""
        Obtain the value of an already-completed future.

        This method should only be called after a call to :meth:`wait` has
        completed, or inside a callback function passed to :meth:`then`. In
        other cases this ``Future`` may not yet hold a value and calling
        ``value()`` could fail.

        If the value contains tensors that reside on GPUs, then this method will
        *not* perform any additional synchronization. This should be done
        beforehand, separately, through a call to :meth:`wait` (except within
        callbacks, for which it's already being taken care of by :meth:`then`).

        Returns:
            The value held by this ``Future``. If the function (callback or RPC)
            creating the value has thrown an error, this ``value()`` method will
            also throw an error.
        """
        return super().value()

    def then(self, callback: Callable[[Future[T]], S]) -> Future[S]:
        r"""
        Append the given callback function to this ``Future``, which will be run
        when the ``Future`` is completed.  Multiple callbacks can be added to
        the same ``Future``, but the order in which they will be executed cannot
        be guaranteed (to enforce a certain order consider chaining:
        ``fut.then(cb1).then(cb2)``). The callback must take one argument, which
        is the reference to this ``Future``. The callback function can use the
        :meth:`value` method to get the value. Note that if this ``Future`` is
        already completed, the given callback will be run immediately inline.

        If the ``Future``'s value contains tensors that reside on GPUs, the
        callback might be invoked while the async kernels that are populating
        those tensors haven't yet finished executing on the device. However, the
        callback will be invoked with some dedicated streams set as current
        (fetched from a global pool) which will be synchronized with those
        kernels. Hence any operation performed by the callback on these tensors
        will be scheduled on the device after the kernels complete. In other
        words, as long as the callback doesn't switch streams, it can safely
        manipulate the result without any additional synchronization. This is
        similar to the non-blocking behavior of :meth:`wait`.

        Similarly, if the callback returns a value that contains tensors that
        reside on a GPU, it can do so even if the kernels that are producing
        these tensors are still running on the device, as long as the callback
        didn't change streams during its execution. If one wants to change
        streams, one must be careful to re-synchronize them with the original
        streams, that is, those that were current when the callback was invoked.

        Args:
            callback(``Callable``): a ``Callable`` that takes this ``Future`` as
                                    the only argument.

        Returns:
            A new ``Future`` object that holds the return value of the
            ``callback`` and will be marked as completed when the given
            ``callback`` finishes.

        .. note:: Note that if the callback function throws, either
            through the original future being completed with an exception and
            calling ``fut.wait()``, or through other code in the callback, the
            future returned by ``then`` will be marked appropriately with the
            encountered error. However, if this callback later completes
            additional futures, those futures are not marked as completed with
            an error and the user is responsible for handling completion/waiting
            on those futures independently.

        Example::

            >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_FUTURES)
            >>> def callback(fut):
            ...     print(f"RPC return value is {fut.wait()}.")
            >>> fut = torch.futures.Future()
            >>> # The inserted callback will print the return value when
            >>> # receiving the response from "worker1"
            >>> cb_fut = fut.then(callback)
            >>> chain_cb_fut = cb_fut.then(
            ...     lambda x : print(f"Chained cb done. {x.wait()}")
            ... )
            >>> fut.set_result(5)
            RPC return value is 5.
            Chained cb done. None
        """
        # The cast only informs the type checker; the object returned by the
        # base class is passed through unchanged.
        return cast(Future[S], super().then(callback))

    def add_done_callback(self, callback: Callable[[Future[T]], None]) -> None:
        r"""
        Append the given callback function to this ``Future``, which will be run
        when the ``Future`` is completed.  Multiple callbacks can be added to
        the same ``Future``, but the order in which they will be executed cannot
        be guaranteed. The callback must take one argument, which is the
        reference to this ``Future``. The callback function can use the
        :meth:`value` method to get the value. Note that if this ``Future`` is
        already completed, the given callback will be run inline.

        We recommend that you use the :meth:`then` method as it provides a way
        to synchronize after your callback has completed. ``add_done_callback``
        can be cheaper if your callback does not return anything. But both
        :meth:`then` and ``add_done_callback`` use the same callback
        registration API under the hood.

        With respect to GPU tensors, this method behaves in the same way as
        :meth:`then`.

        Args:
            callback(``Callable``): a ``Callable`` that takes in one argument,
                which is the reference to this ``Future``.

        .. note:: Note that if the callback function throws, either
            through the original future being completed with an exception and
            calling ``fut.wait()``, or through other code in the callback,
            error handling must be carefully taken care of. For example, if
            this callback later completes additional futures, those futures are
            not marked as completed with an error and the user is responsible
            for handling completion/waiting on those futures independently.

        Example::

            >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_FUTURES)
            >>> def callback(fut):
            ...     print("This will run after the future has finished.")
            ...     print(fut.wait())
            >>> fut = torch.futures.Future()
            >>> fut.add_done_callback(callback)
            >>> fut.set_result(5)
            This will run after the future has finished.
            5
        """
        super().add_done_callback(callback)

    def set_result(self, result: T) -> None:
        r"""
        Set the result for this ``Future``, which will mark this ``Future`` as
        completed and trigger all attached callbacks. Note that a ``Future``
        cannot be marked completed twice.

        If the result contains tensors that reside on GPUs, this method can be
        called even if the asynchronous kernels that are populating those
        tensors haven't yet completed running on the device, provided that the
        streams on which those kernels were enqueued are set as the current ones
        when this method is called. Put simply, it's safe to call this method
        immediately after launching those kernels, without any additional
        synchronization, as long as one doesn't change streams in between. This
        method will record events on all the relevant current streams and will
        use them to ensure proper scheduling for all the consumers of this
        ``Future``.

        Args:
            result (object): the result object of this ``Future``.

        Example::

            >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_FUTURES)
            >>> import threading
            >>> import time
            >>> def slow_set_future(fut, value):
            ...     time.sleep(0.5)
            ...     fut.set_result(value)
            >>> fut = torch.futures.Future()
            >>> t = threading.Thread(
            ...     target=slow_set_future,
            ...     args=(fut, torch.ones(2) * 3)
            ... )
            >>> t.start()
            >>> print(fut.wait())
            tensor([3., 3.])
            >>> t.join()
        """
        super().set_result(result)

    def set_exception(self, result: T) -> None:
        r"""
        Set an exception for this ``Future``, which will mark this ``Future`` as
        completed with an error and trigger all attached callbacks. Note that
        when calling wait()/value() on this ``Future``, the exception set here
        will be raised inline.

        Args:
            result (Exception): the exception for this ``Future``.

        Raises:
            AssertionError: if ``result`` is not an instance of ``Exception``.

        Example::

            >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_FUTURES)
            >>> fut = torch.futures.Future()
            >>> fut.set_exception(ValueError("foo"))
            >>> fut.wait()
            Traceback (most recent call last):
            ...
            ValueError: foo
        """
        # Checked explicitly instead of with ``assert`` so the validation is
        # not stripped when Python runs with -O. AssertionError and its message
        # are kept so callers observe the same failure as before.
        if not isinstance(result, Exception):
            raise AssertionError(f"{result} is of type {type(result)}, not an Exception.")

        def raise_error(fut_result):
            # Unwrap hook: the stored "result" is the exception itself, so
            # unwrapping it re-raises it to whoever reads the future.
            raise fut_result

        # Install the hook before completing the future so that callbacks
        # triggered by set_result already see the raising behavior.
        super()._set_unwrap_func(raise_error)
        self.set_result(result)  # type: ignore[arg-type]
def collect_all(futures: List[Future]) -> Future[List[Future]]:
    r"""
    Collects the provided :class:`~torch.futures.Future` objects into a single
    combined :class:`~torch.futures.Future` that is completed when all of the
    sub-futures are completed.

    Args:
        futures (list): a list of :class:`~torch.futures.Future` objects.

    Returns:
        Returns a :class:`~torch.futures.Future` object to a list of the passed
        in Futures.

    Example::

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_FUTURES)
        >>> fut0 = torch.futures.Future()
        >>> fut1 = torch.futures.Future()
        >>> fut = torch.futures.collect_all([fut0, fut1])
        >>> fut0.set_result(0)
        >>> fut1.set_result(1)
        >>> fut_list = fut.wait()
        >>> print(f"fut0 result = {fut_list[0].wait()}")
        fut0 result = 0
        >>> print(f"fut1 result = {fut_list[1].wait()}")
        fut1 result = 1
    """
    # Both casts are for the type checker only; no objects are converted.
    c_futures = cast(List[torch._C.Future], futures)
    combined = torch._C._collect_all(c_futures)
    return cast(Future[List[Future]], combined)
def wait_all(futures: List[Future]) -> List:
    r"""
    Waits for all provided futures to be complete, and returns
    the list of completed values. If any of the futures encounters an error,
    the method will exit early and report the error not waiting for other
    futures to complete.

    Args:
        futures (list): a list of :class:`~torch.futures.Future` objects.

    Returns:
        A list of the completed :class:`~torch.futures.Future` results. This
        method will throw an error if ``wait`` on any
        :class:`~torch.futures.Future` throws.
    """
    # The cast is for the type checker only; the list is passed through as-is.
    c_futures = cast(List[torch._C.Future], futures)
    # Waiting on the combined future yields the list of sub-futures ...
    completed = torch._C._collect_all(c_futures).wait()
    # ... each of which is then unwrapped to its value, in input order.
    return [done_fut.wait() for done_fut in completed]