Spaces:
Running
Running
File size: 12,065 Bytes
c61ccee |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 |
import torch
import torch.utils._pytree as pytree
from collections import namedtuple
import functools
# NOTE [CustomOp autograd kernel indirection]
# We register `inner` as the autograd kernel for this custom_op.
# `inner` either calls the autograd formula registered by the user,
# or goes into an `autograd_not_implemented` kernel.
#
# This indirection exists so that we can swap out the autograd kernel
# (the PyTorch dispatcher doesn't actually allow us to do this).
# By default, we want the `autograd_not_implemented` behavior, but the
# user may later come and register something that is actually a
# backward formula.
def autograd_kernel_indirection(custom_op):
    """Return the callable used as the autograd kernel for ``custom_op``.

    The returned ``inner`` function picks an implementation on every call:

    * if ``custom_op`` has an ``'autograd'`` impl, that impl's ``func`` is
      called with the given arguments;
    * otherwise, if a ``'save_for_backward'`` or ``'backward'`` impl is
      registered, a ``RuntimeError`` is raised naming the registration that
      was found and the one that is missing;
    * otherwise the ``autograd_not_implemented`` fallback built for
      ``custom_op`` is called.

    Args:
        custom_op: object exposing ``_has_impl(kind)`` and
            ``_get_impl(kind)``; the returned impls expose ``func`` and
            ``location``.

    Returns:
        The ``inner(*args, **kwargs)`` dispatching function.
    """
    # Built once, up front; only used when nothing else is registered.
    autograd_fallback = autograd_not_implemented(custom_op)

    def inner(*args, **kwargs):
        if custom_op._has_impl('autograd'):
            kernel = custom_op._get_impl('autograd').func
            return kernel(*args, **kwargs)

        # As explained in NOTE ["backward", "save_for_backward", and "autograd"],
        # after the user gives us "backward" and "save_for_backward", we generate
        # the "autograd" impl. If the user only provided one, then we tell
        # the user they've done something wrong.
        if custom_op._has_impl('save_for_backward') or custom_op._has_impl('backward'):
            missing = (
                'save_for_backward' if custom_op._has_impl('backward')
                else 'backward'
            )
            found = 'save_for_backward' if missing == 'backward' else 'backward'
            loc = custom_op._get_impl(found).location
            raise RuntimeError(
                f"We found a '{found}' registration for {custom_op} at "
                f"{loc} but were unable to find a '{missing}' registration. "
                f"To use the CustomOp API to register a backward formula, "
                f"please provide us both a backward function and a "
                f"'save for backward' function via `impl_backward` and "
                f"`impl_save_for_backward` respectively.")

        return autograd_fallback(*args, **kwargs)

    return inner
# TODO(#101191): Use the actual C++ autograd not implemented fallback,
# or change the default autograd fallback to the autograd not implemented fallback.
def autograd_not_implemented(custom_op):
def kernel(*args, **kwargs):
if torch.is_grad_enabled() and pytree.tree_any(
lambda x: isinstance(x, torch.Tensor) and x.requires_grad, (args, kwargs)
):
raise RuntimeError("Autograd has not been implemented for operator")
with torch._C._AutoDispatchBelowAutograd():
return custom_op(*args, **kwargs)
return kernel
def mark_non_differentiable(ctx, output, output_differentiability):
    """Mark outputs flagged as non-differentiable on ``ctx``.

    Does nothing when ``output_differentiability`` is ``None``.

    Args:
        ctx: object exposing ``mark_non_differentiable(*tensors)``.
        output: a single output or a tuple of outputs. A non-tuple is
            treated as a one-element tuple.
        output_differentiability: ``None`` or a sequence of truthy/falsy
            flags, one per output.

    Raises:
        AssertionError: if the number of flags differs from the number of
            outputs.
        RuntimeError: if an output that is neither a Tensor nor a list is
            flagged as differentiable.
    """
    # Output types are restricted to be:
    # - Tensor
    # - Tensor[]
    # - int, bool, Scalar, float
    # See _check_can_register_backward
    if output_differentiability is None:
        return

    if isinstance(output, tuple):
        tuple_output = output
    else:
        tuple_output = (output,)
    assert len(output_differentiability) == len(tuple_output)

    non_differentiable_tensors = []
    for idx, (differentiable, out) in enumerate(zip(output_differentiability, tuple_output)):
        if isinstance(out, torch.Tensor):
            if not differentiable:
                non_differentiable_tensors.append(out)
            continue
        if isinstance(out, list):
            # A list output contributes every one of its elements.
            if not differentiable:
                non_differentiable_tensors.extend(out)
            continue
        if differentiable:
            raise RuntimeError(
                f"With output_differentiability={output_differentiability}. "
                f"At idx {idx}, we received an object of type {type(out)} that "
                f"is not a Tensor, so it cannot be marked as differentiable in "
                f"output_differentiability.")

    # Only touch ctx when there is something to mark.
    if non_differentiable_tensors:
        ctx.mark_non_differentiable(*non_differentiable_tensors)
def construct_autograd_kernel(
        schema,
        output_differentiability,
        custom_op,
        op_overload,
        save_for_backward_fn,
        backward_fn):
    """Build an autograd kernel from user-provided save/backward functions.

    Args:
        schema: operator schema, used to build the named tuples of arguments
            passed to ``save_for_backward_fn`` and kept for validation.
        output_differentiability: forwarded to ``mark_non_differentiable``.
        custom_op: the custom op; ``_opname`` names the generated
            autograd.Function class and the object is used when validating
            the gradients returned by ``backward_fn``.
        op_overload: callable run as the forward computation, inside
            ``torch._C._AutoDispatchBelowAutograd()``.
        save_for_backward_fn: called as ``fn(inputs, output)``; its result
            is stashed on the ctx for the backward pass.
        backward_fn: called as ``fn(inner_ctx, saved, *grads)``; expected to
            return a dict of gradients keyed by argument name.

    Returns:
        ``apply(*args)``, which runs the op through a freshly generated
        ``torch.autograd.Function`` and returns the output in its original
        pytree structure.
    """
    def apply(*args):
        # autograd.Function.apply is fed flat leaves, so flatten the inputs
        # and remember how to rebuild them.
        flat_args, spec = pytree.tree_flatten(args)
        # Filled in by `forward`; needed to unflatten outputs and grads.
        out_spec = None

        def forward(ctx, *flat_args):
            nonlocal out_spec
            ctx.set_materialize_grads(True)
            args = pytree.tree_unflatten(list(flat_args), spec)
            with torch._C._AutoDispatchBelowAutograd():
                output = op_overload(*args)

            # We use the info about args to give better error messages in backward
            args_info = namedtuple_args(
                schema, pytree.tree_map(type, args))

            save_for_backward_fn_inputs = namedtuple_args(schema, args)
            to_save = save_for_backward_fn(save_for_backward_fn_inputs, output)

            save_pytree_for_backward(ctx, (to_save, args_info))
            mark_non_differentiable(ctx, output, output_differentiability)

            flat_output, out_spec = pytree.tree_flatten(output)
            return tuple(flat_output)

        def backward(ctx, *flat_grad_output):
            assert out_spec is not None
            grads = pytree.tree_unflatten(list(flat_grad_output), out_spec)
            saved, args_info = unpack_saved(ctx)
            # There is nothing on the ctx object for now, it is just there so
            # that we can add additional things in the future.
            inner_ctx = object()
            # backward_fn takes one positional grad per output.
            if not isinstance(grads, tuple):
                grads = (grads,)
            grad_inputs_dict = backward_fn(inner_ctx, saved, *grads)

            # Massage the grad_inputs_dict to a form acceptable by
            # autograd.Function.
            validate_grad_inputs_dict(grad_inputs_dict, custom_op, args_info)
            return grad_inputs_dict_to_flat_tuple(grad_inputs_dict, args_info)

        generated_cls = gen_autograd_function(
            custom_op._opname + '_customop', forward, backward)

        flat_output = generated_cls.apply(*flat_args)
        assert out_spec is not None
        return pytree.tree_unflatten(list(flat_output), out_spec)

    return apply
def gen_autograd_function(name, forward, backward):
    """Create a ``torch.autograd.Function`` subclass at runtime.

    Args:
        name: class name given to the generated subclass.
        forward: function installed as the static ``forward`` method.
        backward: function installed as the static ``backward`` method.

    Returns:
        The newly created subclass of ``torch.autograd.Function``.
    """
    namespace = {
        'forward': staticmethod(forward),
        'backward': staticmethod(backward),
    }
    return type(name, (torch.autograd.Function,), namespace)
@functools.lru_cache
def namedtuple_args_cls(schema):
    """Return the namedtuple class describing the arguments of ``schema``.

    The class is named ``str(schema.name) + "_args"`` and has one field per
    entry of ``schema.arguments.flat_all``, in that order. Results are
    memoized with ``functools.lru_cache``, so ``schema`` must be hashable.
    """
    field_names = [arg.name for arg in schema.arguments.flat_all]
    cls_name = str(schema.name) + "_args"
    # mypy doesn't support dynamic namedtuple name
    return namedtuple(cls_name, field_names)  # type: ignore[misc]
def namedtuple_args(schema, args):
    """Pack the positional tuple ``args`` into the namedtuple for ``schema``.

    Args:
        schema: schema handed to ``namedtuple_args_cls`` to obtain the
            namedtuple class.
        args: tuple of values, one per field of that class.

    Returns:
        An instance of the schema's namedtuple class holding ``args``.

    Raises:
        AssertionError: if ``args`` is not a tuple.
    """
    assert isinstance(args, tuple)
    return namedtuple_args_cls(schema)(*args)
def validate_grad_inputs_dict(grad_inputs_dict, forward_op, args_info):
    """Check the dict returned by a user backward function.

    Args:
        grad_inputs_dict: value returned by the backward function; must be a
            dict mapping argument names to gradients.
        forward_op: the custom op. ``_schema.arguments.flat_all`` supplies
            the expected keys (args whose type ``is_tensor_like()``) and
            ``_get_impl('backward').location`` is quoted in error messages.
        args_info: object with one attribute per argument name holding the
            runtime type of that argument, or a list of types for list
            arguments.

    Raises:
        RuntimeError: if the value is not a dict, its keys differ from the
            expected ones, a gradient is neither ``None`` nor a Tensor, a
            list argument does not get a same-length list/tuple of
            gradients, or a Tensor gradient is given for a non-Tensor input.
    """
    def error(what):
        backward = forward_op._get_impl('backward')
        raise RuntimeError(
            f"In the backward function defined for {forward_op} at "
            f"{backward.location} using the CustomOp API, {what}")

    if not isinstance(grad_inputs_dict, dict):
        error(f"expected the output of the backward function to be a dict but "
              f"got {type(grad_inputs_dict)}")

    expected_keys = {arg.name for arg in forward_op._schema.arguments.flat_all
                     if arg.type.is_tensor_like()}
    actual_keys = grad_inputs_dict.keys()
    if expected_keys != actual_keys:
        error(f"expected the returned grad_input dict to have keys "
              f"{expected_keys} but got {actual_keys}. The backward "
              f"function must return a gradient (can be None) for each arg "
              f"to the CustomOp that may be a Tensor or Sequence[Tensor]. "
              f"Args declared to be non-Tensor-like types should not appear "
              f"in the grad_input dict")

    for name, grad in grad_inputs_dict.items():
        arg_info = getattr(args_info, name)

        if isinstance(arg_info, list):
            # List argument: validate the gradients element by element.
            if not isinstance(grad, (tuple, list)):
                error(f"for input '{name}' expected the grad_input dict to "
                      f"hold a list of gradients but got object of type "
                      f"{type(grad)}.")
            if not len(grad) == len(arg_info):
                error(f"for input '{name}' expected the grad_input dict to "
                      f"hold a list of {len(arg_info)} gradients but got "
                      f"{len(grad)}")
            for idx, (g, info) in enumerate(zip(grad, arg_info)):
                if g is None:
                    continue
                if not isinstance(g, torch.Tensor):
                    error(f"for input '{name}' expected the grad_input dict to "
                          f"hold a list of None or Tensor gradients but got "
                          f"object of {type(g)} at index {idx}")
                if not issubclass(info, torch.Tensor):
                    # Report the type of the offending element, not the
                    # whole list of element types.
                    error(f"for input '{name}', got a Tensor as the gradient "
                          f"for the {idx}-th value but expected None because "
                          f"the {idx}-th value was not a Tensor (it was "
                          f"type {info})")
            continue

        if grad is None:
            continue
        if not isinstance(grad, torch.Tensor):
            error(f"got object of type {type(grad)} as the gradient for input "
                  f"'{name}', "
                  f"but expected the gradient to be either None or a Tensor")
        if not issubclass(arg_info, torch.Tensor):
            error(f"got a Tensor as the gradient for input '{name}' but "
                  f"expected None as the gradient because input '{name}' "
                  f"was not a Tensor (it was type {arg_info}).")
def grad_inputs_dict_to_flat_tuple(grad_inputs_dict, args_info):
result = []
for name, arg_info in args_info._asdict().items():
if name not in grad_inputs_dict:
result.append(pytree.tree_map(lambda x: None, arg_info))
continue
result.append(grad_inputs_dict[name])
return tuple(pytree.tree_leaves(result))
# Saves "stuff" (a pytree) onto the ctx object. Use unpack_saved to unpack it.
# autograd.Function prefers that users use ctx.save_for_backward to
# save Tensors (to avoid reference cycles) and for non-Tensors to go onto the
# ctx object.
def save_pytree_for_backward(ctx, stuff):
    """Store the pytree ``stuff`` on ``ctx`` for later retrieval.

    ``stuff`` is flattened; Tensor leaves go through
    ``ctx.save_for_backward`` and all other leaves are kept in
    ``ctx.saved_non_tensors``. The tree spec, the total leaf count and the
    flat position of every leaf are recorded on ``ctx`` (``spec``,
    ``num_elts``, ``tensor_idxs``, ``non_tensor_idxs``) so that
    ``unpack_saved`` can rebuild the original structure.

    Args:
        ctx: object exposing ``save_for_backward(*tensors)`` and accepting
            new attributes.
        stuff: arbitrary pytree to save.
    """
    flat_stuff, spec = pytree.tree_flatten(stuff)

    # Partition the leaves in one pass, remembering each leaf's position.
    tensors, tensor_idxs = [], []
    non_tensors, non_tensor_idxs = [], []
    for idx, thing in enumerate(flat_stuff):
        if isinstance(thing, torch.Tensor):
            tensors.append(thing)
            tensor_idxs.append(idx)
        else:
            non_tensors.append(thing)
            non_tensor_idxs.append(idx)

    ctx.spec = spec
    ctx.num_elts = len(flat_stuff)
    ctx.save_for_backward(*tensors)
    ctx.tensor_idxs = tensor_idxs
    ctx.saved_non_tensors = non_tensors
    ctx.non_tensor_idxs = non_tensor_idxs
# Inverse operation to save_pytree_for_backward
def unpack_saved(ctx):
    """Rebuild the pytree stored by ``save_pytree_for_backward``.

    Reads ``num_elts``, ``saved_tensors``, ``tensor_idxs``,
    ``saved_non_tensors``, ``non_tensor_idxs`` and ``spec`` from ``ctx``,
    puts every saved leaf back at its recorded flat position and
    unflattens the result with ``ctx.spec``.

    Returns:
        The reconstructed pytree.
    """
    flat_stuff = [None] * ctx.num_elts
    for tensor, idx in zip(ctx.saved_tensors, ctx.tensor_idxs):
        flat_stuff[idx] = tensor
    for non_tensor, idx in zip(ctx.saved_non_tensors, ctx.non_tensor_idxs):
        flat_stuff[idx] = non_tensor
    return pytree.tree_unflatten(flat_stuff, ctx.spec)
|