import torch
import torch.utils._pytree as pytree
from torch._C import _ExcludeDispatchKeyGuard, DispatchKey, DispatchKeySet
from torch._dispatch.python import suspend_functionalization
from torch._functorch.aot_autograd import AOTConfig, create_joint
from torch._functorch.eager_transforms import (
    _unwrap_all_tensors_from_functional,
    _wrap_all_tensors_to_functional,
    functionalize,
)
from torch._higher_order_ops.cond import (
    _has_potential_branch_input_alias,
    _has_potential_branch_input_mutation,
    UnsupportedAliasMutationException,
)
from torch._ops import HigherOrderOperator
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.proxy_tensor import (
    disable_proxy_modes_tracing,
    make_fx,
    ProxyTorchDispatchMode,
    track_tensor_tree,
)
from torch.multiprocessing.reductions import StorageWeakRef
from torch.utils._python_dispatch import (
    _get_current_dispatch_mode,
    _pop_mode_temporarily,
)

# TODO: We add this to prevent dynamo from tracing into map_wrapper;
# remove the wrapper call when dynamo is ready to handle it.
class MapWrapper(HigherOrderOperator):
    def __call__(self, xs, *args):
        return map_wrapper(xs, *args)


map = MapWrapper("map", _deprecated_global_ns=True)
map_impl = HigherOrderOperator("map_impl", _deprecated_global_ns=True)
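
# Illustrative usage sketch (added commentary; the tensors below are made up):
# `map` applies a callable to each slice of the mapped input along its leading
# dimension and stacks the per-slice outputs, e.g.
#
#   xs = torch.randn(3, 4)
#   y = torch.randn(4)
#   out = map(lambda x, y: (x + y).sin(), xs, y)
#   # roughly torch.stack([(x + y).sin() for x in xs]); out.shape == (3, 4)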

dummy_aot_config = AOTConfig(
    fw_compiler=None,
    bw_compiler=None,
    partition_fn=None,
    decompositions={},
    num_params_buffers=0,
    aot_id=0,
    keep_inference_input_mutations=False,
)


def create_fw_bw_graph(f, num_mapped_args, *args):
    mapped_xs = args[:num_mapped_args]
    pos_args = args[num_mapped_args:]

    # Note: We create "clean" environments for make_fx by suspending all dispatch keys
    # between Autograd and the Python key. Currently we only suspend functionalization, but more can be
    # added when required. We run into two problems if we don't suspend functionalization:
    #
    # 1. make_fx fails to capture operations on the inputs: the inputs are wrapped as _to_functional_tensor_wrapper,
    #    but they will be unwrapped before entering ProxyTorchDispatchMode as part of the dispatching.
    #    However, it's the outside wrapper that the tracer creates proxies for. This causes the tracer to fail
    #    to fetch the proxy for the inputs, and therefore to fail to capture any operations on them.
    #
    # 2. make_fx fails to capture the outputs: the outputs after ProxyTorchDispatchMode are further
    #    wrapped as FunctionalTensorWrapper in the Functionalize key after return. However, the tracer
    #    only associates the inner tensor with a proxy in ProxyTorchDispatchMode. Therefore,
    #    when creating the output node, it fails to associate the wrapped tensor with its proxy.
    #    Instead, it will create a _tensor_constant as the output.
    with suspend_functionalization():
        with disable_proxy_modes_tracing():

            def from_fun(t):
                if isinstance(t, torch.Tensor):
                    if t.dtype != torch.bool:
                        return torch.empty_strided(
                            t.size(),
                            t.stride(),
                            dtype=t.dtype,
                            requires_grad=t.requires_grad,
                        )
                    else:
                        return t.clone()
                return t

            example_xs = [from_fun(xs) for xs in _unstack_pytree(mapped_xs)[0]]
            example_pos_args = [
                from_fun(arg) if isinstance(arg, torch.Tensor) else arg
                for arg in pos_args
            ]
            example_flat_out = pytree.tree_map(
                from_fun, f(*example_xs, *example_pos_args)
            )
            if any(
                not isinstance(out, torch.Tensor)
                for out in example_flat_out
                if out is not None
            ):
                raise RuntimeError(
                    "Expected the outputs of map to only contain tensors or None. "
                    f"Got types {[type(out) for out in example_flat_out]}."
                )
            example_grad = [from_fun(out) for out in example_flat_out]

            fw_graph = make_fx(f)(*example_xs, *example_pos_args)
        def joint_f(*example_args):
            joint_mapped_args = example_args[:joint_num_mapped]
            args = example_args[joint_num_mapped:]

            mapped_input = joint_mapped_args[:num_mapped_args]
            mapped_grads = joint_mapped_args[num_mapped_args:]

            def fw_with_masks(*args):
                fw_out = f(*args)
                return fw_out, [
                    isinstance(ret, torch.Tensor) and ret.requires_grad
                    for ret in fw_out
                ]

            joint = create_joint(fw_with_masks, aot_config=dummy_aot_config)
            _, grads = joint(
                list(mapped_input) + list(args),
                [
                    grad
                    for grad in mapped_grads
                    if grad is not None and grad.requires_grad
                ],
            )

            # In order to keep map functional for the backward graph,
            # we clone outputs that alias inputs.
            input_storage = {
                StorageWeakRef(arg._typed_storage())
                for arg in example_args
                if isinstance(arg, torch.Tensor)
            }

            def maybe_clone(t):
                if (
                    isinstance(t, torch.Tensor)
                    and StorageWeakRef(t._typed_storage()) in input_storage
                ):
                    return t.clone()
                return t

            return pytree.tree_map(maybe_clone, grads)

        joint_num_mapped = len(example_grad) + len(example_xs)
        joint_graph = make_fx(joint_f)(*example_xs, *example_grad, *example_pos_args)
        return fw_graph, joint_graph
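
# Shape sketch (added commentary, assuming a single mapped tensor): for a body
# f(x, y) = x * y with xs of shape (B, N) and a positional arg y of shape (N,),
# `fw_graph` is f traced on one slice x_i of shape (N,), and `joint_graph` maps
# (x_i, grad_out_i, y) to the gradients of that slice's output w.r.t. its
# differentiable inputs, cloning any gradient that aliases an input so the
# backward graph stays functional.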


def map_wrapper(f, xs, *args):
    flat_xs, xs_spec = pytree.tree_flatten(xs)
    if not all(isinstance(t, torch.Tensor) for t in flat_xs):
        raise RuntimeError(f"Mapped xs can only consist of tensors. Got xs {flat_xs}.")

    num_mapped_args = len(flat_xs)
    shapes = [xs.shape for xs in flat_xs]
    leading_dim_size = shapes[0][0]
    if leading_dim_size == 0:
        raise RuntimeError("Leading dimensions of mapped xs cannot be 0.")

    if any(cur_shape[0] != leading_dim_size for cur_shape in shapes):
        raise RuntimeError(
            f"Leading dimensions of mapped xs must be consistent. Got shapes {shapes}."
        )

    out_spec = None

    def flat_fn(*flat_args):
        xs = pytree.tree_unflatten(flat_args[:num_mapped_args], xs_spec)
        unflattened_out = f(xs, *flat_args[num_mapped_args:])
        flat_out, tmp_out_spec = pytree.tree_flatten(unflattened_out)

        nonlocal out_spec
        out_spec = tmp_out_spec
        return flat_out

    return pytree.tree_unflatten(
        map_impl(flat_fn, num_mapped_args, *flat_xs, *args), out_spec
    )


class MapAutogradOp(torch.autograd.Function):
    @staticmethod
    def forward(ctx, fw_graph, joint_graph, num_mapped_args, *flat_args):
        ctx.save_for_backward(*flat_args)
        ctx._joint_graph = joint_graph
        ctx._num_mapped_args = num_mapped_args
        with torch._C._AutoDispatchBelowAutograd():
            return (*map_impl(fw_graph, num_mapped_args, *flat_args),)

    @staticmethod
    def backward(ctx, *flat_grads):
        fw_args = ctx.saved_tensors
        fw_mapped_args = fw_args[: ctx._num_mapped_args]
        pos_args = fw_args[ctx._num_mapped_args :]

        grads = map_impl(
            ctx._joint_graph,
            ctx._num_mapped_args + len(flat_grads),
            *fw_mapped_args,
            *flat_grads,
            *pos_args,
        )
        return None, None, None, *grads


def trace_map(proxy_mode, func_overload, f, num_mapped, *args):
    xs = list(args[:num_mapped])
    pos_args = list(args[num_mapped:])
    leading_dim_size = xs[0].shape[0]

    example_input = _unstack_pytree(xs)[0]
    body_graph = f
    if not isinstance(body_graph, torch.fx.GraphModule):
        body_graph = make_fx(body_graph)(*example_input, *pos_args)

    with disable_proxy_modes_tracing():
        example_outs = body_graph(*example_input, *pos_args)

        def expand_tensor(t):
            if isinstance(t, torch.Tensor):
                return t.expand(leading_dim_size, *t.shape)
            return t

        expanded_outs = pytree.tree_map(expand_tensor, example_outs)

    next_name = None
    i = 0
    while not next_name:
        candidate = f"body_graph_{i}"
        if hasattr(proxy_mode.tracer.root, candidate):
            i += 1
        else:
            next_name = candidate

    proxy_mode.tracer.root.register_module(next_name, body_graph)
    node_args = (body_graph, num_mapped, *args)
    proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, node_args)
    out_proxy = proxy_mode.tracer.create_proxy(
        "call_function", func_overload, proxy_args, {}, name="map_impl"
    )
    return track_tensor_tree(
        expanded_outs, out_proxy, constant=None, tracer=proxy_mode.tracer
    )
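
# Descriptive note (added commentary): trace_map registers the traced body as a
# "body_graph_{i}" submodule on the tracer root and emits a single call_function
# node targeting func_overload (map_impl) into the outer graph; its outputs
# mirror the example outputs expanded along the mapped leading dimension.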


def _unstack_pytree(xs):
    flat_xs, inspec = pytree.tree_flatten(xs)
    if not all(isinstance(xs, torch.Tensor) for xs in flat_xs):
        raise RuntimeError(f"Leaves of xs must be Tensors. Got {flat_xs}.")

    if not all(xs.shape[0] == flat_xs[0].shape[0] for xs in flat_xs):
        raise RuntimeError(
            "Leaves of xs must have the same leading dimension size. "
            f"Got shapes {[xs.shape for xs in flat_xs]}."
        )

    pytrees = []
    for leaves in zip(*flat_xs):
        pytrees.append(pytree.tree_unflatten(leaves, inspec))
    return pytrees
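
# Illustrative example (added commentary):
#   _unstack_pytree({"a": torch.zeros(2, 3), "b": torch.ones(2)})
# returns one pytree per index of the leading dimension, i.e.
#   [{"a": <shape (3,)>, "b": <0-d tensor>}, {"a": <shape (3,)>, "b": <0-d tensor>}]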


def _stack_pytree(pytrees):
    flat_out = []
    out_spec = None
    for pt in pytrees:
        flat_pt, out_spec = pytree.tree_flatten(pt)
        flat_out.append(flat_pt)

    stacked_out = []
    for leaves in zip(*flat_out):
        if all(isinstance(leaf, torch.Tensor) for leaf in leaves):
            stacked_out.append(torch.stack(leaves))
        elif all(leaf is None for leaf in leaves):
            # The backward graph can return None outputs when the forward inputs don't require grad.
            # When we eagerly execute the backward graph, we need to call _stack_pytree on its output,
            # therefore we need to handle None outputs here.
            stacked_out.append(None)
        else:
            raise RuntimeError(f"Cannot stack {leaves}.")
    return pytree.tree_unflatten(stacked_out, out_spec)
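
# Illustrative example (added commentary): _stack_pytree inverts _unstack_pytree
# for pytrees of matching structure, e.g.
#   slices = _unstack_pytree({"a": torch.randn(2, 3)})
#   _stack_pytree(slices)["a"].shape  # torch.Size([2, 3])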


@map_impl.py_impl(DispatchKey.CompositeExplicitAutograd)
def map_dense(f, num_mapped_args, *args):
    xs = args[:num_mapped_args]
    pos_args = args[num_mapped_args:]
    pytrees = []
    for inp in _unstack_pytree(xs):
        pytrees.append(f(*inp, *pos_args))
    return _stack_pytree(pytrees)


@map_impl.py_impl(DispatchKey.Autograd)
def map_autograd(f, num_mapped_args, *args):
    fw_graph, bw_graph = create_fw_bw_graph(f, num_mapped_args, *args)
    flat_out = MapAutogradOp.apply(fw_graph, bw_graph, num_mapped_args, *args)
    return flat_out


@map_impl.py_impl(ProxyTorchDispatchMode)
def map_proxy_torch_dispatch_mode(f, num_mapped, *args):
    mode = _get_current_dispatch_mode()
    assert mode is not None, "Mode should always be enabled for python fallback key"
    with _pop_mode_temporarily() as mode:
        if mode.enable_tracing:
            return trace_map(mode, map_impl, f, num_mapped, *args)
        else:
            return map_impl(f, num_mapped, *args)


@map_impl.py_impl(FakeTensorMode)
def map_fake_tensor_mode(f, num_mapped, *args):
    return map_dense(f, num_mapped, *args)


@map_impl.py_impl(DispatchKey.Functionalize)
def map_func(f, num_mapped, *args):
    reapply_views = torch._C._functionalization_reapply_views_tls()
    xs = args[:num_mapped]
    pos_args = args[num_mapped:]
    unwrapped_xs = _unwrap_all_tensors_from_functional(xs, reapply_views=reapply_views)
    unwrapped_args = _unwrap_all_tensors_from_functional(
        pos_args, reapply_views=reapply_views
    )
    mode = "mutations_and_views" if reapply_views else "mutations"

    with _ExcludeDispatchKeyGuard(DispatchKeySet(DispatchKey.Functionalize)):
        functional_map_fn = functionalize(f, remove=mode)

        with disable_proxy_modes_tracing():
            example_inputs = (*_unstack_pytree(unwrapped_xs)[0], *unwrapped_args)
        if _has_potential_branch_input_mutation(f, example_inputs):
            raise UnsupportedAliasMutationException("torch.map is mutating the input!")

        if _has_potential_branch_input_alias(f, example_inputs):
            raise UnsupportedAliasMutationException("torch.map is aliasing the input!")

        map_return = map_impl(
            functional_map_fn, num_mapped, *unwrapped_xs, *unwrapped_args
        )
        return _wrap_all_tensors_to_functional(map_return, level=0)


@map_impl.py_impl(torch._C._functorch.TransformType.Functionalize)
def map_functionalize(interpreter, f, num_mapped, *args):
    """
    Functionalization implementation for torch.map. Currently:
      1. We don't allow any input mutation inside the map function.
      2. Our check for the above condition is not exhaustive.
    """
    xs = args[:num_mapped]
    pos_args = args[num_mapped:]
    reapply_views = interpreter.functionalize_add_back_views()
    mode = "mutations_and_views" if reapply_views else "mutations"
    # At this point we will see functionalized tensors, so we need to unwrap them first.
    unwrapped_xs = _unwrap_all_tensors_from_functional(xs, reapply_views=reapply_views)
    unwrapped_args = _unwrap_all_tensors_from_functional(
        pos_args, reapply_views=reapply_views
    )

    functional_map_fn = functionalize(f, remove=mode)

    with interpreter.lower():
        with disable_proxy_modes_tracing():
            example_inputs = (*_unstack_pytree(unwrapped_xs)[0], *unwrapped_args)
        if _has_potential_branch_input_mutation(f, example_inputs):
            raise UnsupportedAliasMutationException("torch.map is mutating the input!")
        if _has_potential_branch_input_alias(f, example_inputs):
            raise UnsupportedAliasMutationException("torch.map is aliasing the input!")

        map_return = map_impl(
            functional_map_fn, num_mapped, *unwrapped_xs, *unwrapped_args
        )
        return _wrap_all_tensors_to_functional(map_return, level=interpreter.level())


# TODO(voz): Make this automatic for keys; this is very ugly atm.
map_impl.fallthrough(DispatchKey.PythonDispatcher)
map_impl.fallthrough(DispatchKey.PythonTLSSnapshot)
map_impl.fallthrough(DispatchKey.ADInplaceOrView)
map_impl.fallthrough(DispatchKey.BackendSelect)
map_impl.fallthrough(DispatchKey.AutocastCPU)