from dataclasses import dataclass
from functools import partial
import torch
from torch.multiprocessing.reductions import StorageWeakRef
import torch.utils._pytree as pytree
from torch._C import DispatchKey, DispatchKeySet, ExcludeDispatchKeyGuard
from torch._functorch.eager_transforms import _unwrap_all_tensors_from_functional, _wrap_all_tensors_to_functional, functionalize
from torch._ops import PyOperator
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.proxy_tensor import (
disable_proxy_modes_tracing,
ProxyTorchDispatchMode,
make_fx,
track_tensor_tree,
unwrap_proxy,
)
from torch.fx.passes.shape_prop import _extract_tensor_metadata
from torch.utils._python_dispatch import (
_get_current_dispatch_mode,
_pop_mode_temporarily,
)
from torch.utils._pytree import tree_flatten

@dataclass
class UnsupportedAliasMutationException(RuntimeError):
    """Raised when a cond() branch might mutate or alias its inputs."""
    reason: str


"""
We're going to define a `cond` operation.
In order to do this, we need implementations for each of the dispatch keys.
"""
cond = PyOperator("cond")
def trace_cond(proxy_mode, func_overload, pred, true_fn, false_fn, operands):
assert isinstance(operands, (list, tuple)), "Cond operands must be a list or tuple of tensors"
assert all(isinstance(o, torch.Tensor) for o in operands), "Cond operands must be a list of tensors"
with disable_proxy_modes_tracing():
true_graph = make_fx(true_fn)(*operands)
false_graph = make_fx(false_fn)(*operands)
true_outs = []
false_outs = []
for node in true_graph.graph.nodes:
if node.op == 'output':
true_outs.extend(node.args)
for node in false_graph.graph.nodes:
if node.op == 'output':
false_outs.extend(node.args)
flat_true_outs, _ = pytree.tree_flatten(true_outs)
flat_false_outs, _ = pytree.tree_flatten(false_outs)
    assert len(flat_true_outs) == len(flat_false_outs)
    for true_out, false_out in zip(flat_true_outs, flat_false_outs):
        assert true_out.meta['tensor_meta'] == false_out.meta['tensor_meta']
# There are probably better ways - I know that create_arg has some self incrementing name
# magic to it, but since we explicitly have to get the name for register_module,
# I was not sure how to do that. This kinda simulates it.
next_name = None
i = 0
while not next_name:
candidate = f"true_graph_{i}"
if hasattr(proxy_mode.tracer.root, candidate):
i += 1
else:
next_name = candidate
true_name = next_name
false_name = f"false_graph_{i}"
    assert not hasattr(proxy_mode.tracer.root, false_name)
proxy_mode.tracer.root.register_module(true_name, true_graph)
proxy_mode.tracer.root.register_module(false_name, false_graph)
args = (pred, true_graph, false_graph, operands)
proxy_args = pytree.tree_map(partial(unwrap_proxy, proxy_mode), args)
out_proxy = proxy_mode.tracer.create_proxy('call_function', func_overload, proxy_args, {},
name="conditional")
    # At this point, we're *guaranteed* that whether an output came from the
    # true or false branch is indistinguishable. So, as this is just for tracing
    # purposes, choose the false branch.
    # TODO: Uhh.... it shouldn't matter, but changing this to true_fn results in
    # a FakeTensorMode error:
    # `Current active mode <class 'torch._subclasses.fake_tensor.FakeTensorMode'> not registered`
    out = false_fn(*operands)
return track_tensor_tree(out, out_proxy, constant=None, tracer=proxy_mode.tracer)
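
# For reference, a sketch (illustrative, not verbatim) of what trace_cond
# records: the enclosing GraphModule gains the two branch graphs as submodules
# (true_graph_0 / false_graph_0, via the name search above) plus a single
# call_function node named "conditional" whose args are
# (pred, true_graph_0, false_graph_0, operands).
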
@cond.py_impl(DispatchKey.CUDA)
@cond.py_impl(DispatchKey.CPU)
def cond_dense(pred, true_fn, false_fn, operands):
mode = _get_current_dispatch_mode()
assert (mode is None), "Mode should never be enabled for CPU/CUDA key"
if pred:
return true_fn(*operands)
else:
return false_fn(*operands)

@cond.py_impl(DispatchKey.AutogradCUDA)
@cond.py_impl(DispatchKey.AutogradCPU)
def cond_autograd(pred, true_fn, false_fn, *operands):
    # TODO: support autograd
    flat_operands, _ = tree_flatten([true_fn, false_fn, operands])
    assert all(not f.requires_grad for f in flat_operands
               if isinstance(f, torch.Tensor))
    # Keep the guard alive so both Autograd keys stay excluded while we
    # re-dispatch; excluding only AutogradCPU would re-enter this handler
    # (and recurse) on CUDA tensors, since this impl is registered for both.
    guard = ExcludeDispatchKeyGuard(
        DispatchKeySet(DispatchKey.AutogradCPU) | DispatchKeySet(DispatchKey.AutogradCUDA)
    )
    return cond(pred, true_fn, false_fn, *operands)

@cond.py_impl(ProxyTorchDispatchMode)
def inner(pred, true_fn, false_fn, operands):
mode = _get_current_dispatch_mode()
assert (mode is not None), "Mode should always be enabled for python fallback key"
with _pop_mode_temporarily() as mode:
res = trace_cond(mode, cond, pred, true_fn, false_fn, operands)
return res

@cond.py_impl(FakeTensorMode)
def cond_fake_tensor_mode(pred, true_fn, false_fn, operands):
true_outs = true_fn(*operands)
flat_true_outs, _ = pytree.tree_flatten(true_outs)
flat_false_outs, _ = pytree.tree_flatten(false_fn(*operands))
if len(flat_true_outs) != len(flat_false_outs):
raise RuntimeError("Unmatched number of outputs from cond() branches.")
for true_out, false_out in zip(flat_true_outs, flat_false_outs):
true_meta = _extract_tensor_metadata(true_out)
false_meta = _extract_tensor_metadata(false_out)
if true_meta != false_meta:
raise RuntimeError(
f"Unmatched tensor metadata from cond() branches.\ntrue branch: {true_meta}, false branch: {false_meta}")
return true_outs
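
# Illustrative mismatch (hypothetical): with operands = [x] for x of shape
# (4, 1), branches such as
#   true_fn  = lambda x: x.view(-1)   # shape (4,)
#   false_fn = lambda x: x            # shape (4, 1)
# produce different tensor metadata and trip the RuntimeError above.
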
# We cannot directly call fallthrough here due to issue #89037.
@cond.py_impl(DispatchKey.PythonDispatcher)
def cond_python_dispatcher(*args):
_ = ExcludeDispatchKeyGuard(DispatchKeySet(DispatchKey.PythonDispatcher))
return cond(*args)

def _has_potential_branch_input_mutation(branch, fake_inputs):
    """
    Dispatch-trace the branch with fake inputs and check whether the
    resulting graph applies a mutable op to an input. This is a bit
    restrictive, as the branch must be traceable.
    """
    try:
        gm = make_fx(branch)(*fake_inputs)
    except UnsupportedAliasMutationException:
        # this can happen when a nested cond is functionalized
        return True
input_nodes = set()
for node in gm.graph.nodes:
if node.op == "placeholder":
input_nodes.add(node)
if node.op == "call_function":
target = node.target
if isinstance(target, torch._ops.OpOverload) and target._schema.is_mutable:
for arg in node.args:
if arg in input_nodes:
return True
return False
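
# Illustrative branch (hypothetical) that the check above flags:
def _example_mutating_branch(x):
    # aten::add_ is a mutable OpOverload; traced under make_fx, its node takes
    # the input placeholder as an argument, so the check returns True.
    x.add_(1)
    return x
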

def _has_potential_branch_input_alias(branch, fake_inputs):
    """
    Dispatch-trace the branch with fake inputs and check whether the
    resulting graph has an output aliasing a branch input. This is a bit
    restrictive, as the branch must be traceable.
    """
    try:
        gm = make_fx(branch)(*fake_inputs)
    except UnsupportedAliasMutationException:
        # this can happen when a nested cond is functionalized
        return True
input_storages = set()
for node in gm.graph.nodes:
if node.op == "placeholder":
input_storages.add(StorageWeakRef(node.meta['val']._typed_storage()))
outs, _ = pytree.tree_flatten(gm(*fake_inputs))
for out in outs:
if isinstance(out, torch.Tensor) and StorageWeakRef(out._typed_storage()) in input_storages:
return True
return False
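
# Illustrative branch (hypothetical) that the check above flags:
def _example_aliasing_branch(x):
    # The returned view shares storage with the input placeholder, so the
    # StorageWeakRef comparison above finds a match and returns True.
    return x.view(-1)
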

@cond.py_impl(torch._C._functorch.TransformType.Functionalize)
def cond_functionalize(interpreter, pred, true_fn, false_fn, inputs):
    """
    Functionalization implementation for torch.cond. Currently:
    1. We don't allow any input mutation inside the branches.
    2. Our check for the above condition is not exhaustive.
    """
reapply_views = interpreter.functionalize_add_back_views()
mode = 'mutations_and_views' if reapply_views else 'mutations'
# At this point, we will see functionalized tensors, so need to unwrap them first
unwrapped_inputs = _unwrap_all_tensors_from_functional(inputs, reapply_views=reapply_views)
unwrapped_pred = _unwrap_all_tensors_from_functional(pred, reapply_views=reapply_views)
functional_true_fn = functionalize(true_fn, remove=mode)
functional_false_fn = functionalize(false_fn, remove=mode)
with interpreter.lower():
fake_tensor_mode = FakeTensorMode()
        with fake_tensor_mode as ft_mode:
            def convert(x):
                return ft_mode.fake_tensor_converter(ft_mode, x)

            for branch in [functional_true_fn, functional_false_fn]:
                fake_inputs = pytree.tree_map_only(torch.Tensor, convert, unwrapped_inputs)
                if _has_potential_branch_input_mutation(branch, fake_inputs):
                    raise UnsupportedAliasMutationException("One of torch.cond's branches "
                                                            "might be modifying the input!")
            for branch in [true_fn, false_fn]:
                fake_inputs = pytree.tree_map_only(torch.Tensor, convert, unwrapped_inputs)
                if _has_potential_branch_input_alias(branch, fake_inputs):
                    raise UnsupportedAliasMutationException("One of torch.cond's branches "
                                                            "might be aliasing the input!")
cond_return = cond(unwrapped_pred, functional_true_fn, functional_false_fn, unwrapped_inputs)
return _wrap_all_tensors_to_functional(cond_return, level=interpreter.level())
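
# Illustrative sketch (hypothetical helper) of what the checks above reject:
# a branch that mutates its input in place.
def _example_rejected_by_functionalization():
    def true_fn(x):
        x.mul_(2)  # in-place mutation of a branch input
        return x

    def false_fn(x):
        return x * 2

    x = torch.randn(4)
    # Under functionalize(), this call is expected to raise
    # UnsupportedAliasMutationException rather than trace the mutation.
    return functionalize(lambda p, y: cond(p, true_fn, false_fn, [y]))(
        torch.tensor(True), x
    )
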
# TODO(voz): Make this automatic for keys, this is very ugly atm
cond.fallthrough(DispatchKey.PythonTLSSnapshot)
cond.fallthrough(DispatchKey.ADInplaceOrView)
cond.fallthrough(DispatchKey.BackendSelect)