"""Async API.

This module contains the API for parallelism in TorchScript, notably:

    * torch.jit.fork
    * torch.jit.wait

This is not intended to be imported directly; please use the functionality
exposed in `torch.jit`.
"""

import torch
from torch._jit_internal import Future
from torch.jit._builtins import _register_builtin

from torch.utils import set_module

set_module(Future, "torch.jit")


def fork(func, *args, **kwargs):
    r"""
    Create an asynchronous task executing `func` and return a reference to the result of that execution.

    When run in TorchScript, `fork` returns immediately, so the return value of `func` may not have been computed yet.
    To force completion of the task and access the return value, invoke `torch.jit.wait` on the Future. A call to
    `fork` with a `func` that returns `T` has type `torch.jit.Future[T]`. `fork` calls can be arbitrarily
    nested, and `func` may be invoked with positional and keyword arguments.

    Asynchronous execution only occurs when run in TorchScript. If run in pure Python,
    `fork` will not execute in parallel. `fork` will also not execute in parallel when invoked
    while tracing; however, the `fork` and `wait` calls will be captured in the exported IR graph.

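    Example (sketch: nested fork calls):

    The sketch below illustrates the nesting described above: a task launched by `fork` may
    itself call `fork`. The functions `inner`, `outer`, and `top` are hypothetical names chosen
    for illustration, and the sketch assumes a PyTorch build with TorchScript available.

    ```python
    import torch
    from torch import Tensor


    def inner(x: Tensor) -> Tensor:
        return x * 2


    def outer(x: Tensor) -> Tensor:
        # a task launched by fork can launch further tasks of its own
        fut = torch.jit.fork(inner, x)
        return torch.jit.wait(fut) + 1


    @torch.jit.script
    def top(x: Tensor) -> Tensor:
        # scripting `top` also compiles `outer` and `inner`, which are reached through fork
        fut = torch.jit.fork(outer, x)
        return torch.jit.wait(fut)
    ```

    For example, `top(torch.tensor(3.0))` computes `3.0 * 2 + 1` and returns `tensor(7.)`.
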
    .. warning::
        `fork` tasks execute in a non-deterministic order. We recommend only spawning
        parallel fork tasks for pure functions that do not modify their inputs,
        module attributes, or global state.

    Args:
        func (callable or torch.nn.Module): A Python function or `torch.nn.Module`
            that will be invoked. If executed in TorchScript, it will execute asynchronously;
            otherwise, it will not. Traced invocations of fork will be captured in the IR.
        ``*args``, ``**kwargs``: arguments to invoke `func` with.
    Returns:
        `torch.jit.Future[T]`: a reference to the execution of `func`. The value `T`
        can only be accessed by forcing completion of `func` through `torch.jit.wait`.

    Example (fork a free function):

    .. code-block:: python

        import torch
        from torch import Tensor

        def foo(a: Tensor, b: int) -> Tensor:
            return a + b

        def bar(a):
            fut: torch.jit.Future[Tensor] = torch.jit.fork(foo, a, b=2)
            return torch.jit.wait(fut)

        script_bar = torch.jit.script(bar)
        input = torch.tensor(2)
        # only the scripted version executes asynchronously
        assert script_bar(input) == bar(input)
        # trace is not run asynchronously, but fork is captured in IR
        graph = torch.jit.trace(bar, (input,)).graph
        assert "fork" in str(graph)

    Example (fork a module method):

    .. code-block:: python

        import torch
        from torch import Tensor

        class AddMod(torch.nn.Module):
            def forward(self, a: Tensor, b: int):
                return a + b

        class Mod(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.mod = AddMod()

            def forward(self, input):
                # fork the submodule's forward, passing this module's input
                fut = torch.jit.fork(self.mod, input, b=2)
                return torch.jit.wait(fut)

        input = torch.tensor(2)
        mod = Mod()
        assert mod(input) == torch.jit.script(mod).forward(input)
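
    Example (sketch: fan out several tasks, then wait on all of them):

    To let tasks actually overlap, a common pattern is to launch every task first and only
    then call `torch.jit.wait`. This is a minimal sketch, assuming TorchScript is available;
    `work` and `fan_out` are hypothetical names, and the number of tasks is arbitrary.

    ```python
    from typing import List

    import torch
    from torch import Tensor


    def work(x: Tensor, scale: float) -> Tensor:
        # hypothetical stand-in for an expensive, side-effect-free computation
        return torch.mm(x, x.t()) * scale


    @torch.jit.script
    def fan_out(x: Tensor, n: int) -> Tensor:
        # launch every task before waiting on any of them
        futures: List[torch.jit.Future[Tensor]] = []
        for i in range(n):
            futures.append(torch.jit.fork(work, x, float(i)))
        # then force completion of each task and combine the results
        results: List[Tensor] = []
        for fut in futures:
            results.append(torch.jit.wait(fut))
        return torch.stack(results).sum(dim=0)
    ```

    Calling `torch.jit.wait` right after each `fork` would block on every task before the
    next one is launched, so the two loops are kept separate on purpose.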

    """
    return torch._C.fork(func, *args, **kwargs)


def wait(future):
    r"""
    Force completion of a `torch.jit.Future[T]` asynchronous task, returning the result of the task.

    See :func:`~fork` for docs and examples.

    Args:
        future (torch.jit.Future[T]): an asynchronous task reference, created through `torch.jit.fork`
    Returns:
        `T`: the return value of the completed task
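
    Example (sketch: `fork` and `wait` called from eager Python):

    Outside TorchScript, `fork` does not run in parallel. This sketch assumes that in eager
    mode `fork` runs `func` before returning, so `wait` simply hands back the stored result.
    The function `square` is a hypothetical example.

    ```python
    import torch
    from torch import Tensor


    def square(x: Tensor) -> Tensor:
        return x * x


    # in eager mode this call runs `square` right away rather than in parallel
    fut = torch.jit.fork(square, torch.tensor(3.0))
    # wait returns the task's result, a tensor holding 9.0
    result = torch.jit.wait(fut)
    ```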

    """
    return torch._C.wait(future)


_register_builtin(wait, "aten::wait")