Spaces:
Runtime error
Runtime error
File size: 6,245 Bytes
8a6cf24 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 |
from functools import partial
import torch
import torch.utils._pytree as pytree
from torch._C import DispatchKey, DispatchKeySet, ExcludeDispatchKeyGuard
from torch._functorch.eager_transforms import _unwrap_all_tensors_from_functional, _wrap_all_tensors_to_functional, functionalize
from torch._ops import PyOperator
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.proxy_tensor import (
disable_proxy_modes_tracing,
make_fx,
ProxyTorchDispatchMode,
track_tensor_tree,
unwrap_proxy,
)
from torch.utils._python_dispatch import (
_get_current_dispatch_mode,
_pop_mode_temporarily,
)
from torch.utils._pytree import tree_flatten
from ._cond import _has_potential_branch_input_alias, _has_potential_branch_input_mutation, UnsupportedAliasMutationException
# The ``map`` operator object. The kernels defined below attach themselves to
# it through ``map.py_impl(...)`` and ``map.fallthrough(...)``, and several of
# them re-invoke it as ``map(...)``.
# NOTE: this module-level name shadows the builtin ``map`` for the rest of
# this file.
map = PyOperator("map")
def trace_map(proxy_mode, func_overload, f, xs, *args):
    """Record a ``map`` call as a single ``call_function`` node in the proxy graph.

    The body ``f`` is traced once with ``make_fx`` on the first slice
    ``xs[0]`` (with proxy tracing disabled). The resulting graph is registered
    on the tracer root under the first unused ``body_graph_<i>`` name. A
    ``call_function`` node targeting ``func_overload`` is then created, and the
    traced body is executed on every slice to build the output tensor that is
    tracked against that node.

    Args:
        proxy_mode: dispatch mode whose ``tracer`` owns the graph being built.
        func_overload: operator used as the node's ``call_function`` target.
        f: per-slice function to trace.
        xs: tensor iterated over its leading dimension.
        *args: extra tensor operands passed unchanged to every call.

    Returns:
        The output of ``track_tensor_tree`` for the stacked result.

    Raises:
        ValueError: if ``xs`` is not a tensor, is 0-d, has an empty leading
            dimension, or if any element of ``args`` is not a tensor.
    """
    # -- validate inputs ---------------------------------------------------
    if not isinstance(xs, torch.Tensor):
        raise ValueError("map() must loop over a tensor")
    if len(xs.shape) == 0 or xs.shape[0] == 0:
        raise ValueError("map() cannot be traced with scalar tensors or zero dimension tensors")
    if any(not isinstance(operand, torch.Tensor) for operand in args):
        raise ValueError("map() operands must be a list of tensors or modules")

    # -- trace the body on one slice ---------------------------------------
    # Proxy tracing is disabled so make_fx produces a standalone body graph
    # rather than recording into the outer graph.
    with disable_proxy_modes_tracing():
        body_graph = make_fx(f)(xs[0], *args)

    # -- attach the body graph under the first free body_graph_<i> name ----
    tracer = proxy_mode.tracer
    suffix = 0
    while hasattr(tracer.root, f"body_graph_{suffix}"):
        suffix += 1
    tracer.root.register_module(f"body_graph_{suffix}", body_graph)

    # -- emit the map node -------------------------------------------------
    proxy_args = pytree.tree_map(
        partial(unwrap_proxy, proxy_mode),
        (body_graph, xs, *args),
    )
    out_proxy = tracer.create_proxy(
        'call_function', func_overload, proxy_args, {}, name="map"
    )

    # -- compute the concrete output ---------------------------------------
    per_slice = [body_graph(x, *args) for x in xs]
    # Implementation notes: we need to use new_empty() + copy_() here instead of stack() directly
    # because stack([...]) takes a fixed size list which will specialize dynamic shape here.
    # Meanwhile we want to preserve the looped over dimension as symbolic shape, such that:
    # ys: Tensor[s0, ...] = map(xs: Tensor[s0, ...], *args)
    template = per_slice[0]
    result = template.new_empty([xs.shape[0], *template.shape])
    result.copy_(torch.stack(per_slice))
    return track_tensor_tree(result, out_proxy, constant=None, tracer=tracer)
@map.py_impl(DispatchKey.CUDA)
@map.py_impl(DispatchKey.CPU)
def map_cpu(f, xs, *args):
    """Eager CPU/CUDA kernel for ``map``.

    Calls ``f(x, *args)`` for each slice ``x`` obtained by iterating ``xs``
    and stacks the per-slice results with ``torch.stack``.
    No dispatch mode may be active when this kernel runs.
    """
    assert (_get_current_dispatch_mode() is None), "Mode should never be enabled for CPU/CUDA key"
    results = []
    for x in xs:
        results.append(f(x, *args))
    return torch.stack(results)
@map.py_impl(DispatchKey.AutogradCUDA)
@map.py_impl(DispatchKey.AutogradCPU)
def map_autograd(f, xs, *args):
    """Autograd kernel for ``map``: reject grad-requiring inputs, then redispatch.

    Autograd through ``map`` is not implemented yet. Every tensor found by
    ``tree_flatten([f, xs, args])`` must therefore have
    ``requires_grad=False``. The call is then re-issued with
    ``DispatchKey.AutogradCPU`` added to the excluded dispatch keys.

    Raises:
        AssertionError: if any flattened tensor operand requires grad.
    """
    # TODO: support autograd
    flat_operands, _ = tree_flatten([f, xs, args])
    # Use a distinct name so the comprehension does not shadow the parameter ``f``.
    assert not any(
        isinstance(operand, torch.Tensor) and operand.requires_grad
        for operand in flat_operands
    ), "map() does not support autograd yet: no input tensor may require grad"

    # The guard keeps the key excluded only while it is alive. It is held in a
    # dedicated local so it stays alive until this function returns (it used to
    # be bound to ``_``, which had just held the treespec above).
    # NOTE(review): only AutogradCPU is excluded, even when entered via
    # AutogradCUDA. This relies on the exclude set removing the autograd
    # functionality for every backend. Confirm against c10 DispatchKeySet.
    _autograd_exclude_guard = ExcludeDispatchKeyGuard(DispatchKeySet(DispatchKey.AutogradCPU))
    return map(f, xs, *args)
@map.py_impl(ProxyTorchDispatchMode)
def map_proxy_torch_dispatch_mode(f, xs, *args):
    """ProxyTorchDispatchMode kernel: trace ``map`` into the current graph.

    A dispatch mode must be active on entry. The mode yielded by
    ``_pop_mode_temporarily()`` is handed to ``trace_map``, which records the
    node and produces the tracked result. Presumably this is the proxy mode,
    taken off the mode stack while tracing.
    """
    assert (_get_current_dispatch_mode() is not None), "Mode should always be enabled for python fallback key"
    with _pop_mode_temporarily() as proxy_mode:
        return trace_map(proxy_mode, map, f, xs, *args)
@map.py_impl(FakeTensorMode)
def map_fake_tensor_mode(f, xs, *args):
    """FakeTensorMode kernel: allocate an output shaped like the real result.

    ``f`` is evaluated on every slice of ``xs``. The first result serves as
    the template for ``new_empty`` with shape ``[xs.shape[0], *template.shape]``.
    The returned tensor's contents are uninitialized.
    """
    per_slice = [f(x, *args) for x in xs]
    template = per_slice[0]
    out_shape = [xs.shape[0], *template.shape]
    return template.new_empty(out_shape)
# NOTE: a plain map.fallthrough() registration cannot be used for the
# PythonDispatcher key (issue #89037), so an explicit pass-through kernel
# is registered instead.
@map.py_impl(DispatchKey.PythonDispatcher)
def map_python_dispatcher(*args):
    """PythonDispatcher kernel: re-invoke ``map`` with this key excluded.

    All positional arguments are forwarded unchanged. The exclusion guard is
    held in a local so it stays in effect for the duration of the inner call.
    """
    _python_dispatcher_guard = ExcludeDispatchKeyGuard(DispatchKeySet(DispatchKey.PythonDispatcher))
    return map(*args)
@map.py_impl(torch._C._functorch.TransformType.Functionalize)
def map_functionalize(interpreter, f, xs, *args):
    """
    Functionalization implementation for torch.map. Currently:
    1. We don't allow any input mutation (or input aliasing) inside the map function
    2. Our check for above condition is not exhaustive

    Args:
        interpreter: functorch Functionalize interpreter; this function uses its
            ``functionalize_add_back_views()``, ``lower()`` and ``level()``.
        f: per-slice function; it is wrapped with ``functionalize``.
        xs: tensor being mapped over (functional-wrapped on entry).
        *args: extra operands, unwrapped the same way as ``xs``.

    Returns:
        The result of ``map`` on the unwrapped inputs, re-wrapped via
        ``_wrap_all_tensors_to_functional`` at ``interpreter.level()``.

    Raises:
        UnsupportedAliasMutationException: if the checks, run on fake inputs
            for a single slice, report that ``f`` may mutate or alias its inputs.
    """
    reapply_views = interpreter.functionalize_add_back_views()
    # What functionalize() should remove depends on whether views are re-applied.
    mode = 'mutations_and_views' if reapply_views else 'mutations'
    # At this point, we will see functionalized tensors, so need to unwrap them first
    unwrapped_xs = _unwrap_all_tensors_from_functional(xs, reapply_views=reapply_views)
    unwrapped_args = _unwrap_all_tensors_from_functional(args, reapply_views=reapply_views)
    functional_map_fn = functionalize(f, remove=mode)
    # NOTE(review): the original indentation of this block was lost. The
    # checks run inside FakeTensorMode, while the real map() call below runs
    # outside it (on the real unwrapped tensors) but still inside
    # interpreter.lower(). Confirm against upstream.
    with interpreter.lower():
        fake_tensor_mode = FakeTensorMode()
        with fake_tensor_mode as ft_mode:
            # Returns fake inputs for a single map function call
            def get_fake_inputs(unwrapped_xs, unwrapped_args):
                # Convert all of xs, then keep only the first slice below.
                fake_xs = ft_mode.fake_tensor_converter(ft_mode, unwrapped_xs)
                # Only tensor leaves are converted; other leaves pass through.
                fake_args = pytree.tree_map_only(
                    torch.Tensor,
                    lambda x: ft_mode.fake_tensor_converter(ft_mode, x),
                    unwrapped_args,
                )
                # assumes fake_args is a tuple (args arrives as a tuple) - TODO confirm
                return (fake_xs[0],) + fake_args
            fake_inputs = get_fake_inputs(unwrapped_xs, unwrapped_args)
            if _has_potential_branch_input_mutation(functional_map_fn, fake_inputs):
                raise UnsupportedAliasMutationException(
                    "torch.map is mutating the input!"
                )
            if _has_potential_branch_input_alias(functional_map_fn, fake_inputs):
                raise UnsupportedAliasMutationException(
                    "torch.map is aliasing the input!"
                )
        map_return = map(functional_map_fn, unwrapped_xs, *unwrapped_args)
        return _wrap_all_tensors_to_functional(map_return, level=interpreter.level())
# TODO(voz) Make this automatic for keys, this is very ugly atm
# Register ``map`` as a fallthrough for dispatch keys that have no dedicated
# kernel in this file. The registration order matches the original.
for _fallthrough_key in (
    DispatchKey.PythonTLSSnapshot,
    DispatchKey.ADInplaceOrView,
    DispatchKey.BackendSelect,
):
    map.fallthrough(_fallthrough_key)
# Drop the loop variable so the module namespace is unchanged.
del _fallthrough_key
|