import weakref

import torch
import torch.utils._pytree as pytree
from torch._C import _ExcludeDispatchKeyGuard, DispatchKey, DispatchKeySet
from torch._ops import OpOverload
from torch.library import Library
from torchgen.model import (
    BaseTy,
    BaseType,
    FunctionSchema,
    OperatorName,
    OptionalType,
    SchemaKind,
)

from .autograd import autograd_not_implemented


def register_functional_op(
    lib: Library,
    new_op_name: str,
    mutable_op: OpOverload,
) -> None:
"""Given a mutable operator, registers the functional variant. | |
This API also correctly links the functional variant with the mutable | |
operator for the purposes of functionalization. | |
All of the new registrations are performed on the ``lib`` passed in. | |
Arguments: | |
lib (Library): Should be a torch.library.Library object that has | |
the same namespace as ``mutable_op``'s namespace. | |
lib will be used to register the new functional op as well | |
as a functionalization kernel for the ``mutable_op`` | |
If you don't have a library handy, use | |
``torch.library.Library(ns, 'FRAGMENT')`` to construct one. | |
new_op_name (str): The name of the functional operator (without the | |
namespace). If no namespace, the new functional variant will be | |
accessible under ``torch.ops.{lib.ns}.new_op_name``. | |
mutable_op (OpOverload): The mutable custom operator. Note | |
that you may need to add a `.default` to it, like | |
`torch.ops.aten.abs_.default`. | |
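
    Example (a minimal sketch; the ``mylib`` namespace and the ``my_op`` /
    ``my_op_functional`` names below are hypothetical)::

        >>> # xdoctest: +SKIP
        >>> lib = torch.library.Library("mylib", "FRAGMENT")
        >>> lib.define("my_op(Tensor x, Tensor(a!) buf) -> ()")
        >>> def my_op_impl(x, buf):
        ...     buf.add_(x)
        >>> lib.impl("my_op", my_op_impl, "CompositeExplicitAutograd")
        >>> register_functional_op(
        ...     lib, "my_op_functional", torch.ops.mylib.my_op.default)
        >>> # The functional variant does not modify ``buf``; it returns the
        >>> # updated value as an output instead.
        >>> x, buf = torch.randn(3), torch.zeros(3)
        >>> new_buf = torch.ops.mylib.my_op_functional(x, buf)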
""" | |
    validate(mutable_op)
    schema = functional_schema(new_op_name, mutable_op)
    lib.define(schema)

    functional_impl = construct_functional_impl(mutable_op)
    lib.impl(new_op_name, functional_impl, 'CompositeExplicitAutograd')

    functional_op = getattr(getattr(torch.ops, lib.ns), new_op_name).default

    # There's no easy way for us to generate the autograd kernel, so we
    # use autograd_not_implemented. Also, this makes it so that the user
    # is unable to register an autograd formula themselves. This shouldn't
    # be a problem if the user doesn't use the functional op directly
    # in their program, but we may need to revisit this in the future.
    lib.impl(new_op_name, autograd_not_implemented(functional_op), 'Autograd')

    f_kernel = construct_functionalization_kernel(weakref.proxy(mutable_op), functional_op)

    lib.impl(mutable_op, f_kernel, 'Functionalize')


def construct_functional_impl(mutable_op):
    def functional_impl(*args):
        # Strategy:
        # - clone args that would have been mutated
        # - run mutable_op
        # - return the cloned args as additional outputs
        new_args = []
        extra_rets = []
        for is_write, arg in zip(mutable_args(mutable_op), args):
            if is_write:
                cloned = arg.clone() if arg is not None else None
                new_args.append(cloned)
                extra_rets.append(cloned)
            else:
                new_args.append(arg)
        result = mutable_op(*new_args)
        if result is None:
            return tuple(extra_rets)
        if isinstance(result, tuple):
            return (*result, *extra_rets)
        return (result, *extra_rets)
    return functional_impl
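

# For illustration only: given the hypothetical ``mylib::my_op(Tensor x, Tensor(a!) buf) -> ()``
# from the docstring above, the impl constructed here behaves roughly like:
#
#   def my_op_functional_impl(x, buf):
#       new_buf = buf.clone()               # clone the would-be-mutated arg
#       torch.ops.mylib.my_op(x, new_buf)   # run the mutable op on the clone
#       return (new_buf,)                   # return the updated value as an extra output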


def construct_functionalization_kernel(mutable_op, functional_op):
    def kernel(*args):
        # There's nothing to be functionalized!
        # We can still end up here because DispatchKey::Functionalize is a mode key
        if pytree.tree_all_only(torch.Tensor, lambda x: not torch._is_functional_tensor(x), args):
            with _ExcludeDispatchKeyGuard(DispatchKeySet(DispatchKey.Functionalize)):
                return mutable_op(*args)

        # NB: This differs from the codegen -- codegen handles cases where there
        # are mixed FunctionalTensorWrapper and non-FunctionalTensorWrapper.
        # This only really matters for XLA (mixed CPU-XLA tensors) and
        # running functionalization without the PT2 stack (which guarantees to us that
        # all tensors are FunctionalTensorWrapper).
        if not pytree.tree_all_only(torch.Tensor, torch._is_functional_tensor, args):
            raise RuntimeError(
                f"{mutable_op}: expected all args to be FunctionalTensorWrapper")

        # Sync and unwrap the FunctionalTensorWrapper inputs so we can call the
        # functional op on the underlying tensors.
        unwrapped_args = []
        for arg in args:
            if isinstance(arg, torch.Tensor) and torch._is_functional_tensor(arg):
                torch._sync(arg)
                unwrapped = torch._from_functional_tensor(arg)
                unwrapped_args.append(unwrapped)
            else:
                unwrapped_args.append(arg)

        with _ExcludeDispatchKeyGuard(DispatchKeySet(DispatchKey.Functionalize)):
            output = functional_op(*unwrapped_args)

        # The functional op returns the mutable op's original outputs followed by
        # the new values of the mutated args (see construct_functional_impl).
        num_actual_output = len(mutable_op._schema.returns)
        actual_output = pytree.tree_map(
            torch._to_functional_tensor, output[:num_actual_output])

        new_values_to_propagate = output[num_actual_output:]
        inputs_to_replace = [arg for is_write, arg in zip(mutable_args(mutable_op), args)
                             if is_write]
        assert len(new_values_to_propagate) == len(inputs_to_replace)
        for new_value, arg in zip(new_values_to_propagate, inputs_to_replace):
            # Optional Tensor args that were not passed have nothing to propagate.
            if arg is None and new_value is None:
                continue
            torch._C._propagate_xla_data(arg, new_value)
            torch._C._replace_(arg, new_value)
            torch._C._commit_update(arg)
            torch._sync(arg)

        if len(actual_output) == 1:
            return actual_output[0]
        elif len(actual_output) == 0:
            return None
        return actual_output
    return kernel
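

# For illustration only: the kernel above is what reroutes the mutable op when a
# program is run under functionalization. A rough sketch (reusing the hypothetical
# ``mylib::my_op`` from the docstring example):
#
#   def f(x, buf):
#       torch.ops.mylib.my_op(x, buf)   # in-place write into buf
#       return buf * 2
#
#   # Under torch.func.functionalize, the call to my_op above hits the
#   # 'Functionalize' kernel, which runs the functional variant and then
#   # propagates the new value of ``buf`` back into its FunctionalTensorWrapper.
#   out = torch.func.functionalize(f)(torch.randn(3), torch.zeros(3))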


def validate(mutable_op: OpOverload):
    if not isinstance(mutable_op, OpOverload):
        raise TypeError(
            f"register_functional_op(mutable_op): expected mutable_op to be instance of "
            f"OpOverload but got {type(mutable_op)}")

    # There are generally three types of "in-place" or "mutable" ops.
    # Each of them has its own convention (see the schema examples below):
    # - inplace (first input modified in-place and returned as only output)
    # - out= (some args modified in-place and returned as outputs)
    # - mutable (some args modified in-place but none of those returned as outputs)
    # In theory we can support all three, but we'll just support the last
    # option right now for simplicity.
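    # For example (the first two schemas are real ATen ops; the third is the only
    # kind accepted here, shown with a hypothetical custom op):
    # - inplace: aten::abs_(Tensor(a!) self) -> Tensor(a!)
    # - out=:    aten::abs.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
    # - mutable: mylib::my_op(Tensor x, Tensor(a!) buf) -> ()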
    schema = FunctionSchema.parse(str(mutable_op._schema))
    if schema.kind() != SchemaKind.mutable:
        raise RuntimeError("Expected op to be mutable (as opposed to functional, inplace or out)")
    for ret in schema.returns:
        # construct_functionalization_kernel assumes this for simplicity
        if ret.annotation is not None:
            raise NotImplementedError(
                "NYI: register_functional_op(op) where op returns a mutated or aliased value. "
                "Please file an issue (and as a workaround, modify your operator to "
                "not return the mutated value or aliases)")
    for arg in schema.arguments.flat_all:
        # construct_functionalization_kernel assumes this for simplicity
        if arg.type.is_tensor_like() and (
            arg.type != BaseType(BaseTy.Tensor)
            and arg.type != OptionalType(BaseType(BaseTy.Tensor))
        ):
            raise NotImplementedError(
                "NYI: register_functional_op(op) where op has a List[Tensor] input. "
                "Please file an issue.")


def functional_schema(new_op_name, op: OpOverload):
    schema = FunctionSchema.parse(str(op._schema))
    schema = schema.signature().with_name(OperatorName.parse(new_op_name))
    return str(schema)
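

# For illustration only: for the hypothetical ``mylib::my_op(Tensor x, Tensor(a!) buf) -> ()``
# and new_op_name="my_op_functional", functional_schema returns (roughly)
# ``my_op_functional(Tensor x, Tensor buf) -> Tensor``: the mutability annotation is
# dropped and the mutated arg becomes an output.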


def mutable_args(op: OpOverload):
    return tuple(False if arg.alias_info is None else arg.alias_info.is_write
                 for arg in op._schema.arguments)
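

# For illustration only: for the hypothetical ``mylib::my_op(Tensor x, Tensor(a!) buf) -> ()``,
# mutable_args returns (False, True), since only ``buf`` carries a write annotation.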