import functools


def async_execution(fn):
    r"""
    A decorator for a function indicating that the return value of the function
    is guaranteed to be a :class:`~torch.futures.Future` object and this
    function can run asynchronously on the RPC callee. More specifically, the
    callee extracts the :class:`~torch.futures.Future` returned by the wrapped
    function and installs subsequent processing steps as a callback to that
    :class:`~torch.futures.Future`. The installed callback will read the value
    from the :class:`~torch.futures.Future` once it completes and send the
    value back as the RPC response. That also means the returned
    :class:`~torch.futures.Future` only exists on the callee side and is never
    sent through RPC. This decorator is useful when the wrapped function's
    (``fn``) execution needs to pause and resume, e.g., because it calls
    :meth:`~torch.distributed.rpc.rpc_async` or waits for other signals.

    .. note:: To enable asynchronous execution, applications must pass the
        function object returned by this decorator to RPC APIs. If RPC detects
        attributes installed by this decorator, it knows that this function
        returns a ``Future`` object and will handle that accordingly.
        However, this does not mean this decorator has to be the outermost one
        when defining a function. For example, when combined with
        ``@staticmethod`` or ``@classmethod``, ``@rpc.functions.async_execution``
        needs to be the inner decorator to allow the target function to be
        recognized as a static or class method. This target function can still
        execute asynchronously because, when accessed, the static or class
        method preserves attributes installed by
        ``@rpc.functions.async_execution``.

    Example::
        The returned :class:`~torch.futures.Future` object can come from
        :meth:`~torch.distributed.rpc.rpc_async`,
        :meth:`~torch.futures.Future.then`, or :class:`~torch.futures.Future`
        constructor. The example below shows directly using the
        :class:`~torch.futures.Future` returned by
        :meth:`~torch.futures.Future.then`.

        >>> import torch
        >>> from torch.distributed import rpc
        >>>
        >>> # omitting setup and shutdown RPC
        >>>
        >>> # On all workers
        >>> @rpc.functions.async_execution
        >>> def async_add_chained(to, x, y, z):
        >>>     # This function runs on "worker1" and returns immediately when
        >>>     # the callback is installed through the `then(cb)` API. In the
        >>>     # meantime, the `rpc_async` to "worker2" can run concurrently.
        >>>     # When the return value of that `rpc_async` arrives at
        >>>     # "worker1", "worker1" will run the lambda function accordingly
        >>>     # and set the value for the previously returned `Future`, which
        >>>     # will then trigger RPC to send the result back to "worker0".
        >>>     return rpc.rpc_async(to, torch.add, args=(x, y)).then(
        >>>         lambda fut: fut.wait() + z
        >>>     )
        >>>
        >>> # On worker0
        >>> # xdoctest: +SKIP
        >>> ret = rpc.rpc_sync(
        >>>     "worker1",
        >>>     async_add_chained,
        >>>     args=("worker2", torch.ones(2), 1, 1)
        >>> )
        >>> print(ret)  # prints tensor([3., 3.])

        When combined with TorchScript decorators, this decorator must be the
        outermost one.

        >>> import torch
        >>> from torch import Tensor
        >>> from torch.futures import Future
        >>> from torch.distributed import rpc
        >>>
        >>> # omitting setup and shutdown RPC
        >>>
        >>> # On all workers
        >>> @torch.jit.script
        >>> def script_add(x: Tensor, y: Tensor) -> Tensor:
        >>>     return x + y
        >>>
        >>> @rpc.functions.async_execution
        >>> @torch.jit.script
        >>> def async_add(to: str, x: Tensor, y: Tensor) -> Future[Tensor]:
        >>>     return rpc.rpc_async(to, script_add, (x, y))
        >>>
        >>> # On worker0
        >>> # Both `x` and `y` are annotated as Tensor, so pass tensors for both.
        >>> ret = rpc.rpc_sync(
        >>>     "worker1",
        >>>     async_add,
        >>>     args=("worker2", torch.ones(2), torch.ones(2))
        >>> )
        >>> print(ret)  # prints tensor([2., 2.])

        When combined with a static or class method, this decorator must be
        the inner one.

        >>> import torch
        >>> from torch.distributed import rpc
        >>>
        >>> # omitting setup and shutdown RPC
        >>>
        >>> # On all workers
        >>> class AsyncExecutionClass:
        >>>
        >>>     @staticmethod
        >>>     @rpc.functions.async_execution
        >>>     def static_async_add(to, x, y, z):
        >>>         return rpc.rpc_async(to, torch.add, args=(x, y)).then(
        >>>             lambda fut: fut.wait() + z
        >>>         )
        >>>
        >>>     @classmethod
        >>>     @rpc.functions.async_execution
        >>>     def class_async_add(cls, to, x, y, z):
        >>>         ret_fut = torch.futures.Future()
        >>>         rpc.rpc_async(to, torch.add, args=(x, y)).then(
        >>>             lambda fut: ret_fut.set_result(fut.wait() + z)
        >>>         )
        >>>         return ret_fut
        >>>
        >>>     @rpc.functions.async_execution
        >>>     def bound_async_add(self, to, x, y, z):
        >>>         return rpc.rpc_async(to, torch.add, args=(x, y)).then(
        >>>             lambda fut: fut.wait() + z
        >>>         )
        >>>
        >>> # On worker0
        >>> ret = rpc.rpc_sync(
        >>>     "worker1",
        >>>     AsyncExecutionClass.static_async_add,
        >>>     args=("worker2", torch.ones(2), 1, 2)
        >>> )
        >>> print(ret)  # prints tensor([4., 4.])
        >>>
        >>> ret = rpc.rpc_sync(
        >>>     "worker1",
        >>>     AsyncExecutionClass.class_async_add,
        >>>     args=("worker2", torch.ones(2), 1, 2)
        >>> )
        >>> print(ret)  # prints tensor([4., 4.])

        This decorator also works with RRef helpers, i.e.,
        :meth:`torch.distributed.rpc.RRef.rpc_sync`,
        :meth:`torch.distributed.rpc.RRef.rpc_async`, and
        :meth:`torch.distributed.rpc.RRef.remote`.

        >>> import torch
        >>> from torch.distributed import rpc
        >>>
        >>> # reuse the AsyncExecutionClass class above
        >>> rref = rpc.remote("worker1", AsyncExecutionClass)
        >>> ret = rref.rpc_sync().static_async_add("worker2", torch.ones(2), 1, 2)
        >>> print(ret)  # prints tensor([4., 4.])
        >>>
        >>> rref = rpc.remote("worker1", AsyncExecutionClass)
        >>> ret = rref.rpc_async().static_async_add("worker2", torch.ones(2), 1, 2).wait()
        >>> print(ret)  # prints tensor([4., 4.])
        >>>
        >>> rref = rpc.remote("worker1", AsyncExecutionClass)
        >>> ret = rref.remote().static_async_add("worker2", torch.ones(2), 1, 2).to_here()
        >>> print(ret)  # prints tensor([4., 4.])
    """
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        return fn(*args, **kwargs)
    # Can't declare and use attributes of function objects (mypy#2087)
    wrapper._wrapped_async_rpc_function = fn  # type: ignore[attr-defined]
    return wrapper
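

# Illustrative sketch, not part of the original module. The note in the
# docstring above says that RPC detects attributes installed by the decorator
# and that static and class methods preserve them when accessed. The helper
# names below (`_has_async_marker`, `_marker_demo`) are hypothetical, and how
# the RPC layer really performs its check is not shown in this file. This only
# inspects the `_wrapped_async_rpc_function` attribute that `wrapper` sets.
def _has_async_marker(fn):
    """Return True if ``fn`` exposes the marker attribute (assumed check)."""
    return hasattr(fn, "_wrapped_async_rpc_function")


def _marker_demo():
    """Report whether the marker is visible through static/class methods."""

    def mark(fn):
        # Minimal stand-in that sets the same attribute as `async_execution`,
        # defined locally so this sketch does not depend on the rest of the
        # module.
        def wrapper(*args, **kwargs):
            return fn(*args, **kwargs)
        wrapper._wrapped_async_rpc_function = fn
        return wrapper

    class Demo:
        @staticmethod
        @mark
        def static_fn(x):
            return x

        @classmethod
        @mark
        def class_fn(cls, x):
            return x

        def plain_fn(self, x):
            return x

    # Accessing a staticmethod through the class yields the wrapper itself;
    # accessing a classmethod yields a bound method, which forwards attribute
    # lookups to the underlying wrapper. An undecorated function has no marker.
    return {
        "static": _has_async_marker(Demo.static_fn),
        "class": _has_async_marker(Demo.class_fn),
        "plain": _has_async_marker(Demo.plain_fn),
    }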