# Scraped page header (not part of the module): "Spaces: Sleeping".
"""Async API.

This module contains the API for parallelism in TorchScript, notably:
    * torch.jit.fork
    * torch.jit.wait

This is not intended to be imported directly; please use the exposed
functionalities in `torch.jit`.
"""
import torch | |
from torch._jit_internal import Future | |
from torch.jit._builtins import _register_builtin | |
from torch.utils import set_module | |
set_module(Future, "torch.jit") | |
def fork(func, *args, **kwargs):
    r"""
    Create an asynchronous task executing `func` and a reference to the value of the result of this execution.

    `fork` will return immediately, so the return value of `func` may not have been computed yet. To force completion
    of the task and access the return value invoke `torch.jit.wait` on the Future. `fork` invoked
    with a `func` which returns `T` is typed as `torch.jit.Future[T]`. `fork` calls can be arbitrarily
    nested, and may be invoked with positional and keyword arguments.
    Asynchronous execution will only occur when run in TorchScript. If run in pure python,
    `fork` will not execute in parallel. `fork` will also not execute in parallel when invoked
    while tracing, however the `fork` and `wait` calls will be captured in the exported IR Graph.

    .. warning::
        `fork` tasks will execute non-deterministically. We recommend only spawning
        parallel fork tasks for pure functions that do not modify their inputs,
        module attributes, or global state.

    Args:
        func (callable or torch.nn.Module):  A Python function or `torch.nn.Module`
            that will be invoked. If executed in TorchScript, it will execute asynchronously,
            otherwise it will not. Traced invocations of fork will be captured in the IR.
        ``*args``, ``**kwargs``: arguments to invoke `func` with.

    Returns:
        `torch.jit.Future[T]`: a reference to the execution of `func`. The value `T`
        can only be accessed by forcing completion of `func` through `torch.jit.wait`.

    Example (fork a free function):

    .. code-block:: python

        import torch
        from torch import Tensor

        def foo(a : Tensor, b : int) -> Tensor:
            return a + b

        def bar(a):
            fut : torch.jit.Future[Tensor] = torch.jit.fork(foo, a, b=2)
            return torch.jit.wait(fut)

        script_bar = torch.jit.script(bar)
        input = torch.tensor(2)
        # only the scripted version executes asynchronously
        assert script_bar(input) == bar(input)
        # trace is not run asynchronously, but fork is captured in IR
        graph = torch.jit.trace(bar, (input,)).graph
        assert "fork" in str(graph)

    Example (fork a module method):

    .. code-block:: python

        import torch
        from torch import Tensor

        class AddMod(torch.nn.Module):
            def forward(self, a: Tensor, b : int):
                return a + b

        class Mod(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.mod = AddMod()

            def forward(self, input):
                fut = torch.jit.fork(self.mod, input, b=2)
                return torch.jit.wait(fut)

        input = torch.tensor(2)
        mod = Mod()
        assert mod(input) == torch.jit.script(mod).forward(input)
    """
    # Delegate to the C++ binding, forwarding the callable and all of its
    # positional/keyword arguments unchanged.
    return torch._C.fork(func, *args, **kwargs)
def wait(future):
    r"""
    Force completion of a `torch.jit.Future[T]` asynchronous task, returning the result of the task.

    See :func:`~fork` for docs and examples.

    Args:
        future (torch.jit.Future[T]): an asynchronous task reference, created through `torch.jit.fork`

    Returns:
        `T`: the return value of the completed task
    """
    # Delegate to the C++ binding, which yields the value held by the future.
    return torch._C.wait(future)
# Associate the Python-level ``wait`` with the ``aten::wait`` operator name.
_register_builtin(wait, "aten::wait")