import torch
import torch.utils._pytree as pytree
from collections import namedtuple
import functools

# NOTE [CustomOp autograd kernel indirection]
# We register `inner` as the autograd kernel for this custom_op.
# `inner` either calls the autograd formula registered by the user,
# or goes into an `autograd_not_implemented` kernel.
#
# The reason this indirection exists is so that we can swap out the autograd
# kernel after registration (the PyTorch dispatcher doesn't actually allow us
# to do this directly). By default we want the `autograd_not_implemented`
# behavior, but the user may later register something that is an actual
# backward formula.
def autograd_kernel_indirection(custom_op):
    autograd_fallback = autograd_not_implemented(custom_op)

    def inner(*args, **kwargs):
        if custom_op._has_impl('autograd'):
            kernel = custom_op._get_impl('autograd').func
            return kernel(*args, **kwargs)
        # As explained in NOTE ["backward", "save_for_backward", and "autograd"],
        # after the user gives us "backward" and "save_for_backward", we generate
        # the "autograd" impl. If the user only provided one, then we tell
        # the user they've done something wrong.
        if custom_op._has_impl('save_for_backward') or custom_op._has_impl('backward'):
            missing = (
                'save_for_backward' if custom_op._has_impl('backward')
                else 'backward'
            )
            found = 'save_for_backward' if missing == 'backward' else 'backward'
            loc = custom_op._get_impl(found).location
            raise RuntimeError(
                f"We found a '{found}' registration for {custom_op} at "
                f"{loc} but were unable to find a '{missing}' registration. "
                f"To use the CustomOp API to register a backward formula, "
                f"please provide us both a backward function and a "
                f"'save for backward' function via `impl_backward` and "
                f"`impl_save_for_backward` respectively.")
        return autograd_fallback(*args, **kwargs)
    return inner
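
# Illustrative sketch (not part of the original module): the same late-binding
# pattern as `inner` above, with a plain dict standing in for the CustomOp's
# impl registry. `_example_kernel_indirection`, `_registry`, and `_fallback`
# are hypothetical names used only for this example.
def _example_kernel_indirection():
    _registry = {}  # stands in for custom_op's impl table; may be filled later

    def _fallback(x):
        raise RuntimeError("no autograd formula registered")

    def _inner(x):
        # The kernel is looked up at call time, so a registration that happens
        # after `_inner` was installed is still picked up -- this is the reason
        # for the indirection described in the NOTE above.
        kernel = _registry.get('autograd', _fallback)
        return kernel(x)

    _registry['autograd'] = lambda x: x * 2  # registered after `_inner` exists
    return _inner(3)  # == 6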

# TODO(#101191): Use the actual C++ autograd not implemented fallback,
# or change the default autograd fallback to the autograd not implemented fallback.
def autograd_not_implemented(custom_op):
    def kernel(*args, **kwargs):
        if torch.is_grad_enabled() and pytree.tree_any(
            lambda x: isinstance(x, torch.Tensor) and x.requires_grad, (args, kwargs)
        ):
            raise RuntimeError("Autograd has not been implemented for operator")
        with torch._C._AutoDispatchBelowAutograd():
            return custom_op(*args, **kwargs)
    return kernel
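
# Illustrative sketch (hypothetical, not used by this module): how the
# requires_grad check in `kernel` above behaves on a mixed (args, kwargs)
# pytree. Any Tensor leaf with requires_grad=True trips the error path.
def _example_requires_grad_check():
    args = (torch.randn(2, requires_grad=True), 1.0)
    kwargs = {"flag": True}
    return pytree.tree_any(
        lambda x: isinstance(x, torch.Tensor) and x.requires_grad, (args, kwargs)
    )  # True, so `kernel` would raise under grad mode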

def mark_non_differentiable(ctx, output, output_differentiability):
    # Output types are restricted to be:
    # - Tensor
    # - Tensor[]
    # - int, bool, Scalar, float
    # See _check_can_register_backward
    if output_differentiability is not None:
        if not isinstance(output, tuple):
            tuple_output = (output,)
        else:
            tuple_output = output  # type: ignore[assignment]
        assert len(output_differentiability) == len(tuple_output)
        non_differentiable_tensors = []
        for idx, (differentiable, out) in enumerate(zip(output_differentiability, tuple_output)):
            if isinstance(out, torch.Tensor):
                if not differentiable:
                    non_differentiable_tensors.append(out)
                continue
            if isinstance(out, list):
                if not differentiable:
                    non_differentiable_tensors.extend(out)
                continue
            if differentiable:
                raise RuntimeError(
                    f"With output_differentiability={output_differentiability}. "
                    f"At idx {idx}, we received an object of type {type(out)} that "
                    f"is not a Tensor, so it cannot be marked as differentiable in "
                    f"output_differentiability.")
        if non_differentiable_tensors:
            ctx.mark_non_differentiable(*non_differentiable_tensors)
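
# Illustrative sketch (hypothetical, not used by this module): a hand-written
# autograd.Function that calls ctx.mark_non_differentiable directly, which is
# what the helper above ultimately does for outputs flagged as
# non-differentiable in output_differentiability.
class _ExampleMarkNonDiff(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        out = x * 2
        idx = torch.argmax(x)  # index-like output: never differentiable
        ctx.mark_non_differentiable(idx)
        return out, idx

    @staticmethod
    def backward(ctx, grad_out, grad_idx):
        # grad_idx corresponds to the non-differentiable output and is unused.
        return grad_out * 2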

def construct_autograd_kernel(
        schema,
        output_differentiability,
        custom_op,
        op_overload,
        save_for_backward_fn,
        backward_fn):

    def apply(*args):
        flat_args, spec = pytree.tree_flatten(args)
        out_spec = None

        def forward(ctx, *flat_args):
            ctx.set_materialize_grads(True)
            args = pytree.tree_unflatten(list(flat_args), spec)
            with torch._C._AutoDispatchBelowAutograd():
                output = op_overload(*args)

            # We use the info about args to give better error messages in backward
            args_info = namedtuple_args(
                schema, pytree.tree_map(type, args))

            save_for_backward_fn_inputs = namedtuple_args(schema, args)
            to_save = save_for_backward_fn(save_for_backward_fn_inputs, output)

            save_pytree_for_backward(ctx, (to_save, args_info))
            mark_non_differentiable(ctx, output, output_differentiability)

            nonlocal out_spec
            flat_output, out_spec = pytree.tree_flatten(output)
            return tuple(flat_output)

        def backward(ctx, *flat_grad_output):
            assert out_spec is not None
            grads = pytree.tree_unflatten(list(flat_grad_output), out_spec)
            saved, args_info = unpack_saved(ctx)
            # There is nothing on the ctx object for now, it is just there so
            # that we can add additional things in the future.
            inner_ctx = object()

            if not isinstance(grads, tuple):
                grads = (grads,)

            grad_inputs_dict = backward_fn(inner_ctx, saved, *grads)

            # Massage the grad_inputs_dict to a form acceptable by
            # autograd.Function.
            validate_grad_inputs_dict(grad_inputs_dict, custom_op, args_info)
            return grad_inputs_dict_to_flat_tuple(grad_inputs_dict, args_info)

        generated_cls = gen_autograd_function(
            custom_op._opname + '_customop', forward, backward)

        flat_output = generated_cls.apply(*flat_args)
        assert out_spec is not None
        return pytree.tree_unflatten(list(flat_output), out_spec)
    return apply
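
# Illustrative sketch (hypothetical, not used by this module) of the
# flatten/capture-spec/unflatten pattern that `apply` above relies on:
# autograd.Function.apply only sees flat positional inputs, so nested
# arguments are flattened outside and rebuilt inside forward.
def _example_flat_apply():
    nested = (torch.randn(3, requires_grad=True),
              [torch.randn(3, requires_grad=True)])
    flat_args, spec = pytree.tree_flatten(nested)

    class _Flat(torch.autograd.Function):
        @staticmethod
        def forward(ctx, *flat):
            x, ys = pytree.tree_unflatten(list(flat), spec)
            ctx.save_for_backward(x, ys[0])
            return x * ys[0]

        @staticmethod
        def backward(ctx, grad_out):
            x, y = ctx.saved_tensors
            # One gradient per *flat* input, in flattening order.
            return grad_out * y, grad_out * x

    return _Flat.apply(*flat_args)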

def gen_autograd_function(name, forward, backward):
    generated_cls = type(
        name,
        (torch.autograd.Function,),
        {
            'forward': staticmethod(forward),
            'backward': staticmethod(backward),
        }
    )
    return generated_cls
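
# Illustrative usage sketch (hypothetical, not used by this module): the class
# built by gen_autograd_function behaves like a hand-written autograd.Function.
def _example_gen_autograd_function():
    def fwd(ctx, x):
        return x * 2

    def bwd(ctx, grad_out):
        return grad_out * 2

    _Double = gen_autograd_function('_ExampleDouble', fwd, bwd)
    x = torch.randn(3, requires_grad=True)
    return _Double.apply(x)  # gradient of its sum() w.r.t. x would be all 2s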

def namedtuple_args_cls(schema):
    attribs = [arg.name for arg in schema.arguments.flat_all]
    name = str(schema.name) + "_args"
    # mypy doesn't support dynamic namedtuple name
    tuple_cls = namedtuple(name, attribs)  # type: ignore[misc]
    return tuple_cls

def namedtuple_args(schema, args):
    assert isinstance(args, tuple)
    tuple_cls = namedtuple_args_cls(schema)
    return tuple_cls(*args)
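
# Illustrative sketch (hypothetical, not used by this module): what the
# args_info namedtuple built in construct_autograd_kernel looks like for an
# op taking (Tensor x, int n) -- a namedtuple of the argument *types*, kept
# around purely so backward can produce good error messages.
def _example_args_info():
    _MySinArgs = namedtuple('mysin_args', ['x', 'n'])  # stands in for the schema-derived class
    args = (torch.randn(3), 2)
    return _MySinArgs(*[type(a) for a in args])
    # -> mysin_args(x=<class 'torch.Tensor'>, n=<class 'int'>)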

def validate_grad_inputs_dict(grad_inputs_dict, forward_op, args_info):
    def error(what):
        backward = forward_op._get_impl('backward')
        raise RuntimeError(
            f"In the backward function defined for {forward_op} at "
            f"{backward.location} using the CustomOp API, {what}")

    if not isinstance(grad_inputs_dict, dict):
        error(f"expected the output of the backward function to be a dict but "
              f"got {type(grad_inputs_dict)}")

    expected_keys = {arg.name for arg in forward_op._schema.arguments.flat_all
                     if arg.type.is_tensor_like()}
    actual_keys = grad_inputs_dict.keys()
    if expected_keys != actual_keys:
        error(f"expected the returned grad_input dict to have keys "
              f"{expected_keys} but got {actual_keys}. The backward "
              f"function must return a gradient (can be None) for each arg "
              f"to the CustomOp that may be a Tensor or Sequence[Tensor]. "
              f"Args declared to be non-Tensor-like types should not appear "
              f"in the grad_input dict")

    for name, grad in grad_inputs_dict.items():
        arg_info = getattr(args_info, name)

        if isinstance(arg_info, list):
            if not isinstance(grad, (tuple, list)):
                error(f"for input '{name}' expected the grad_input dict to "
                      f"hold a list of gradients but got object of type "
                      f"{type(grad)}.")
            if not len(grad) == len(arg_info):
                error(f"for input '{name}' expected the grad_input dict to "
                      f"hold a list of {len(arg_info)} gradients but got "
                      f"{len(grad)}")
            for idx, (g, info) in enumerate(zip(grad, arg_info)):
                if g is None:
                    continue
                if not isinstance(g, torch.Tensor):
                    error(f"for input '{name}' expected the grad_input dict to "
                          f"hold a list of None or Tensor gradients but got "
                          f"object of {type(g)} at index {idx}")
                if not issubclass(info, torch.Tensor):
                    error(f"for input '{name}', got a Tensor as the gradient "
                          f"for the {idx}-th value but expected None because "
                          f"the {idx}-th value was not a Tensor (it was "
                          f"type {arg_info}).")
            continue

        if grad is None:
            continue
        if not isinstance(grad, torch.Tensor):
            error(f"got object of type {type(grad)} as the gradient for input "
                  f"'{name}', "
                  f"but expected the gradient to be either None or a Tensor")
        if not issubclass(arg_info, torch.Tensor):
            error(f"got a Tensor as the gradient for input '{name}' but "
                  f"expected None as the gradient because input '{name}' "
                  f"was not a Tensor (it was type {arg_info}).")

def grad_inputs_dict_to_flat_tuple(grad_inputs_dict, args_info):
    result = []
    for name, arg_info in args_info._asdict().items():
        if name not in grad_inputs_dict:
            result.append(pytree.tree_map(lambda x: None, arg_info))
            continue
        result.append(grad_inputs_dict[name])
    return tuple(pytree.tree_leaves(result))
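
# Illustrative sketch (hypothetical, not used by this module): for an op with
# signature (Tensor x, Tensor[] ys, int n), the backward function returns a
# dict keyed by the Tensor-like args only (the shape validate_grad_inputs_dict
# accepts); it is then flattened into one gradient slot per flat forward
# input, with None for the non-Tensor `n`.
def _example_grad_inputs_dict_to_flat_tuple():
    _MyOpArgs = namedtuple('myop_args', ['x', 'ys', 'n'])  # stands in for args_info
    args_info = _MyOpArgs(x=torch.Tensor, ys=[torch.Tensor, torch.Tensor], n=int)
    grads = {'x': torch.randn(3), 'ys': [torch.randn(3), None]}  # no entry for `n`
    return grad_inputs_dict_to_flat_tuple(grads, args_info)
    # -> (grad_x, grad_ys_0, None, None)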

# Saves "stuff" (a pytree) onto the ctx object. Use unpack_saved to unpack it.
# autograd.Function prefers that users use ctx.save_for_backward to
# save Tensors (to avoid reference cycles) and for non-Tensors to go onto the
# ctx object.
def save_pytree_for_backward(ctx, stuff):
    flat_stuff, spec = pytree.tree_flatten(stuff)
    num_elts = len(flat_stuff)
    tensor_idxs = [idx for idx, thing in enumerate(flat_stuff)
                   if isinstance(thing, torch.Tensor)]
    non_tensor_idxs = [idx for idx, thing in enumerate(flat_stuff)
                       if not isinstance(thing, torch.Tensor)]
    tensors = [thing for thing in flat_stuff if isinstance(thing, torch.Tensor)]
    non_tensors = [thing for thing in flat_stuff if not isinstance(thing, torch.Tensor)]

    ctx.spec = spec
    ctx.num_elts = num_elts
    ctx.save_for_backward(*tensors)
    ctx.tensor_idxs = tensor_idxs
    ctx.saved_non_tensors = non_tensors
    ctx.non_tensor_idxs = non_tensor_idxs

# Inverse operation to save_pytree_for_backward
def unpack_saved(ctx):
    flat_stuff = [None] * ctx.num_elts
    for tensor, idx in zip(ctx.saved_tensors, ctx.tensor_idxs):
        flat_stuff[idx] = tensor
    for non_tensor, idx in zip(ctx.saved_non_tensors, ctx.non_tensor_idxs):
        flat_stuff[idx] = non_tensor
    stuff = pytree.tree_unflatten(flat_stuff, ctx.spec)
    return stuff
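
# Illustrative round-trip sketch (hypothetical, not used by this module):
# stash a mixed Tensor/non-Tensor pytree on ctx in forward and recover the
# same structure in backward via the two helpers above.
class _ExampleSaver(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        save_pytree_for_backward(ctx, {'t': x, 'meta': ('scale', 2)})
        return x * 2

    @staticmethod
    def backward(ctx, grad_out):
        saved = unpack_saved(ctx)  # {'t': x, 'meta': ('scale', 2)}
        return grad_out * saved['meta'][1]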