import dataclasses
import functools
import inspect
import sys
import typing
import weakref

from torchgen.model import FunctionSchema, OperatorName, SchemaKind, BaseType, ListType, BaseTy

import torch
import torch._C as _C
import torch.library as library
from torch._library.abstract_impl import AbstractImplCtx
from torch.library import get_ctx

from .autograd import autograd_kernel_indirection, construct_autograd_kernel
""" | |
For a detailed guide on custom ops, please see | |
https://docs.google.com/document/d/1aGWtgxV3HppuxQAdddyPrs74_aEntpkYt9MalnCKnhk | |
This file includes pieces of the implementation of our custom operator API. | |
""" | |
__all__ = ["custom_op", "CustomOp", "get_ctx", "AbstractImplCtx"] | |
SUPPORTED_DEVICE_TYPE_TO_KEY = { | |
"cpu": "CPU", | |
"cuda": "CUDA", | |
} | |
# We will not let users register CustomOps with anything that could look like | |
# PyTorch internals to avoid confusion. | |
RESERVED_NS = { | |
"prim", | |
"prims", | |
"aten", | |
"at", | |
"torch", | |
"pytorch", | |
} | |


def custom_op(
    qualname: str, manual_schema: typing.Optional[str] = None
) -> typing.Callable:
    r"""Creates a new CustomOp object.

    WARNING: if you're a user, please do not use this directly
    (instead use the torch._custom_ops APIs).
    Also please see the following for a detailed guide on custom ops.
    https://docs.google.com/document/d/1aGWtgxV3HppuxQAdddyPrs74_aEntpkYt9MalnCKnhk

    In PyTorch, defining an op (short for "operator") is a two-step process:
    - we need to define (create) the op
    - we need to implement behavior for how the operator interacts with
      various PyTorch subsystems, like CPU/CUDA Tensors, Autograd, etc.

    This entrypoint defines the CustomOp object (the first step);
    you must then perform the second step by calling various methods on
    the CustomOp object.

    This API is used as a decorator (see examples).

    Arguments:
        qualname (str): Should be a string that looks like
            "namespace::operator_name". Operators in PyTorch need a namespace to
            avoid name collisions; a given operator may only be created once.
            If you are writing a Python library, we recommend that the namespace
            be the name of your top-level module. The operator_name must be
            the same as the name of the function you pass to custom_op
            (see examples).
        manual_schema (Optional[str]): Each PyTorch operator needs a schema that
            tells PyTorch the types of the inputs/outputs. If None (default),
            we will infer the schema from the type annotations on the function
            (see examples). Otherwise, if you don't want to use type annotations,
            you may provide the schema string yourself.

    Example::
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
        >>> import numpy as np
        >>> from torch import Tensor
        >>>
        >>> # Step 1: define the CustomOp.
        >>> # We need to provide the decorator a "prototype function"
        >>> # (a function with Python ellipses as the body).
        >>> @custom_op("my_library::numpy_sin")
        >>> def numpy_sin(x: Tensor) -> Tensor:
        >>>     ...
        >>>
        >>> # numpy_sin is now an instance of class CustomOp
        >>> print(type(numpy_sin))
        >>>
        >>> # Step 2: Register an implementation for various PyTorch subsystems
        >>>
        >>> # Register an implementation for CPU tensors
        >>> @numpy_sin.impl('cpu')
        >>> def numpy_sin_impl_cpu(x):
        >>>     return torch.from_numpy(np.sin(x.numpy()))
        >>>
        >>> # Register an implementation for CUDA tensors
        >>> @numpy_sin.impl('cuda')
        >>> def numpy_sin_impl_cuda(x):
        >>>     return torch.from_numpy(np.sin(x.cpu().numpy())).to(x.device)
        >>>
        >>> x = torch.randn(3)
        >>> numpy_sin(x)  # calls numpy_sin_impl_cpu
        >>>
        >>> x_cuda = x.cuda()
        >>> numpy_sin(x_cuda)  # calls numpy_sin_impl_cuda
    """

    def inner(func):
        if not inspect.isfunction(func):
            raise ValueError(
                f"custom_op(...)(func): Expected `func` to be a Python "
                f"function, got: {type(func)}"
            )

        ns, name = parse_qualname(qualname)
        validate_namespace(ns)
        if func.__name__ != name:
            raise ValueError(
                f"custom_op(qualname='{qualname}', ...)(func): expected `func` "
                f"to have name '{name}' but got '{func.__name__}'. "
                f"Please either change the name of `func` or the qualname that "
                f"is passed to `custom_op`"
            )

        schema = infer_schema(func) if manual_schema is None else manual_schema
        schema_str = f"{name}{schema}"
        function_schema = FunctionSchema.parse(schema_str)
        validate_schema(function_schema)
        if manual_schema is not None:
            validate_function_matches_schema(function_schema, func)

        lib = library.Library(ns, "FRAGMENT")
        lib.define(schema_str)
        ophandle = find_ophandle_or_throw(ns, function_schema.name)
        result = CustomOp(lib, ns, function_schema, name, ophandle, _private_access=True)

        result.__name__ = func.__name__
        result.__module__ = func.__module__
        result.__doc__ = func.__doc__

        # Register the Autograd indirection kernel through the method that
        # tracks it, so _registered_autograd_kernel_indirection stays
        # consistent with the dispatcher state (impl_backward relies on this).
        result._register_autograd_kernel_indirection()

        torch._C._dispatch_set_report_error_callback(
            ophandle, functools.partial(report_error_callback, weakref.proxy(result))
        )
        return result

    return inner


# Global dictionary holding references to all CustomOp objects
# Yes, it keeps all CustomOps alive (see NOTE [CustomOp lifetime])
# Used to query the CustomOp associated with a specific C++ dispatcher operator.
# An example usage is FakeTensor: FakeTensor checks if a specific operator
# has an implementation registered via the CustomOp API.
# Indexed by qualname (e.g. aten::foo)
global_registry: typing.Dict[str, "CustomOp"] = {}


class CustomOp:
    r"""Class for custom operators in PyTorch.

    Use the CustomOp API to create user-defined custom operators that behave
    just like regular PyTorch operators (e.g. torch.sin, torch.mm) when it
    comes to various PyTorch subsystems (like torch.compile).

    To construct a `CustomOp`, use `custom_op`.
    """

    def __init__(self, lib, cpp_ns, schema, operator_name, ophandle, *, _private_access=False):
        super().__init__()
        if not _private_access:
            raise RuntimeError(
                "The CustomOp constructor is private and we do not guarantee "
                "BC for it. Please use custom_op(...) to create a CustomOp object"
            )
        name = f"{cpp_ns}::{operator_name}"
        self._schema = schema
        self._cpp_ns = cpp_ns
        self._lib: library.Library = lib
        self._ophandle: _C._DispatchOperatorHandle = ophandle
        # Has the name of the op, e.g. "foo". We cache here for convenience.
        self._opname: str = operator_name
        # This is _opname but with the namespace, e.g. "custom::foo".
        self._qualname: str = name
        self.__name__ = None  # mypy requires this
        # NB: Some of these impls are registered as kernels to DispatchKeys.
        # Modifying the _impls dict directly won't do anything in that case.
        self._impls: typing.Dict[str, typing.Optional[FuncAndLocation]] = {}
        # See NOTE [CustomOp autograd kernel indirection]
        self._registered_autograd_kernel_indirection = False

        global_registry[self._qualname] = self

    def _register_autograd_kernel_indirection(self):
        assert not self._registered_autograd_kernel_indirection
        self._lib.impl(self._opname, autograd_kernel_indirection(weakref.proxy(self)), "Autograd")
        self._registered_autograd_kernel_indirection = True

    # Records the impl and the source location in self._impls.
    # Note that this doesn't cause torch.library to use the impl; that
    # needs to be done in a separate self._lib.impl call.
    def _register_impl(self, kind, func, stacklevel=2):
        if self._has_impl(kind):
            func_and_location = self._impls[kind]
            assert func_and_location is not None  # Pacify mypy
            location = func_and_location.location
            raise RuntimeError(
                f"Attempting to register a {kind} impl for operator {self._qualname} "
                f"that already has a {kind} impl registered from Python at "
                f"{location}. This is not supported."
            )
        frame = inspect.getframeinfo(sys._getframe(stacklevel))
        location = f"{frame.filename}:{frame.lineno}"
        self._impls[kind] = FuncAndLocation(func, location)

    def _get_impl(self, kind):
        return self._impls[kind]

    def _has_impl(self, kind):
        return kind in self._impls

    def _destroy(self):
        # NOTE: [CustomOp lifetime]
        # A CustomOp, once created, lives forever. The mechanism is that the
        # global registry holds a reference to it. However, to make testing
        # easier, we want to be able to destroy CustomOp objects.
        # CustomOp._destroy does the job, though it leaves the CustomOp
        # in a garbage state.
        del self._lib

        opnamespace = getattr(torch.ops, self._cpp_ns)
        if hasattr(opnamespace, self._opname):
            delattr(opnamespace, self._opname)

        del global_registry[self._qualname]

    def __repr__(self):
        return f'<CustomOp(op="{self._qualname}")>'

    def __call__(self, *args, **kwargs):
        # Bypass torch.ops.* and directly do OperatorHandle::callBoxed.
        # Using torch.ops.* is a bit of a pain (it can be slow, and it has
        # lifetime issues from caching operators that make testing CustomOp
        # difficult).
        result = _C._dispatch_call_boxed(self._ophandle, *args, **kwargs)
        return result

    def impl(
        self, device_types: typing.Union[str, typing.Iterable[str]], _stacklevel=2,
    ) -> typing.Callable:
        r"""Register an implementation for a device type for this CustomOp object.

        WARNING: if you're a user, please do not use this directly
        (instead use the torch._custom_ops APIs).
        Also please see the following for a detailed guide on custom ops.
        https://docs.google.com/document/d/1aGWtgxV3HppuxQAdddyPrs74_aEntpkYt9MalnCKnhk

        If the CustomOp is passed multiple Tensor inputs with different device
        types, it will dispatch to the registered implementation for the highest
        priority device type among those present.
        The supported device types, in order of priority, are {'cuda', 'cpu'}.

        This API is used as a decorator (see examples).

        Arguments:
            device_types (str or Iterable[str]): the device type(s) to register the function for.

        Examples::
            >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
            >>> import numpy as np
            >>> from torch import Tensor
            >>>
            >>> @custom_op("my_library::numpy_cos")
            >>> def numpy_cos(x: Tensor) -> Tensor:
            >>>     ...
            >>>
            >>> # Register an implementation for CPU Tensors
            >>> @numpy_cos.impl('cpu')
            >>> def numpy_cos_impl_cpu(x):
            >>>     return torch.from_numpy(np.cos(x.numpy()))
            >>>
            >>> # Register an implementation for CUDA Tensors
            >>> @numpy_cos.impl('cuda')
            >>> def numpy_cos_impl_cuda(x):
            >>>     return torch.from_numpy(np.cos(x.cpu().numpy())).to(x.device)
            >>>
            >>> x = torch.randn(3)
            >>> numpy_cos(x)  # calls numpy_cos_impl_cpu
            >>>
            >>> x_cuda = x.cuda()
            >>> numpy_cos(x_cuda)  # calls numpy_cos_impl_cuda
        """
        if isinstance(device_types, str):
            device_types = [device_types]
        for device_type in device_types:
            validate_device_type(device_type)

        def inner(f):
            for device_type in set(device_types):
                self._check_doesnt_have_library_impl(device_type)
                self._register_impl(device_type, f, stacklevel=_stacklevel)
                dispatch_key = SUPPORTED_DEVICE_TYPE_TO_KEY[device_type]
                library.impl(self._lib, self._opname, dispatch_key)(f)
            return f

        return inner

    def _check_doesnt_have_library_impl(self, device_type):
        if self._has_impl(device_type):
            return
        key = SUPPORTED_DEVICE_TYPE_TO_KEY[device_type]
        if _C._dispatch_has_computed_kernel_for_dispatch_key(self._qualname, key):
            raise RuntimeError(
                f"impl(..., device_types={device_type}): the operator {self._qualname} "
                f"already has an implementation for this device type via a "
                f"pre-existing torch.library or TORCH_LIBRARY registration.")

    def impl_factory(self) -> typing.Callable:
        r"""Register an implementation for a factory function (an operator
        that takes no Tensor inputs and constructs one or more new Tensors).
        def inner(f):
            self._register_impl("factory", f)
            library.impl(self._lib, self._opname, "BackendSelect")(f)
            return f

        return inner

    def impl_abstract(self, _stacklevel=2) -> typing.Callable:
        r"""Register an abstract implementation for this operator.

        WARNING: please do not use this directly (and instead use the torch._custom_ops
        APIs). Also please see the following for a detailed guide on custom ops.
        https://docs.google.com/document/d/1aGWtgxV3HppuxQAdddyPrs74_aEntpkYt9MalnCKnhk

        An "abstract implementation" specifies the behavior of this operator on
        Tensors that carry no data. Given some input Tensors with certain properties
        (sizes/strides/storage_offset/device), it specifies what the properties of
        the output Tensors are.

        The abstract implementation has the same signature as the operator.
        It is run for both FakeTensors and meta tensors. To write an abstract
        implementation, assume that all Tensor inputs to the operator are
        regular CPU/CUDA/Meta tensors, but they do not have storage, and
        you are trying to return regular CPU/CUDA/Meta tensor(s) as output.
        The abstract implementation must consist of only PyTorch operations
        (and may not directly access the storage or data of any input or
        intermediate Tensors).

        This API is used as a decorator (see examples).

        Examples::
            >>> import numpy as np
            >>> from torch import Tensor
            >>>
            >>> # Example 1: an operator without data-dependent output shape
            >>> @custom_op('my_library::custom_linear')
            >>> def custom_linear(x: Tensor, weight: Tensor, bias: Tensor) -> Tensor:
            >>>     ...
            >>>
            >>> @custom_linear.impl_abstract()
            >>> def custom_linear_abstract(x, weight, bias):
            >>>     assert x.dim() == 2
            >>>     assert weight.dim() == 2
            >>>     assert bias.dim() == 1
            >>>     assert x.shape[1] == weight.shape[1]
            >>>     assert weight.shape[0] == bias.shape[0]
            >>>     assert x.device == weight.device
            >>>
            >>>     return (x @ weight.t()) + bias
            >>>
            >>> # Example 2: an operator with data-dependent output shape
            >>> @custom_op('my_library::custom_nonzero')
            >>> def custom_nonzero(x: Tensor) -> Tensor:
            >>>     ...
            >>>
            >>> @custom_nonzero.impl_abstract()
            >>> def custom_nonzero_abstract(x):
            >>>     # Number of nonzero-elements is data-dependent.
            >>>     # Since we cannot peek at the data in an abstract impl,
            >>>     # we use the ctx object to construct a new symint that
            >>>     # represents the data-dependent size.
            >>>     ctx = torch._custom_op.get_ctx()
            >>>     nnz = ctx.create_unbacked_symint()
            >>>     shape = [nnz, x.dim()]
            >>>     result = x.new_empty(shape, dtype=torch.long)
            >>>     return result
            >>>
            >>> @custom_nonzero.impl(['cpu', 'cuda'])
            >>> def custom_nonzero_impl(x):
            >>>     x_np = x.cpu().numpy()
            >>>     res = np.stack(np.nonzero(x_np), axis=1)
            >>>     # unbacked symbolic ints in PyTorch must be >= 2, so we
            >>>     # constrain the range to at least 2
            >>>     if res.shape[0] <= 1:
            >>>         raise RuntimeError("not supported")
            >>>     return torch.tensor(res, device=x.device)
        """
        def inner(f):
            self._check_doesnt_have_library_meta_impl()
            self._register_impl("abstract", f, stacklevel=_stacklevel)
            location = self._get_impl("abstract").location

            qualname = self._qualname

            # Handle DispatchKey.Meta registration
            def f_with_ctx(*args, **kwargs):
                def error_on_ctx():
                    raise RuntimeError(
                        f"Attempted to call get_ctx() for the meta implementation "
                        f"for {qualname}. "
                        f"You have presumably called get_ctx() because the operator "
                        f"has a data-dependent output shape; if so, there is no "
                        f"such meta implementation and this error is the correct "
                        f"behavior. Otherwise, please remove the call to get_ctx() "
                        f"in the implementation registered with impl_abstract "
                        f"at {location}"
                    )

                with torch._library.abstract_impl.set_ctx_getter(error_on_ctx):
                    return f(*args, **kwargs)

            self._lib.impl(self._opname, f_with_ctx, "Meta")
            return f

        return inner

    def _check_can_register_backward(self):
        def error(detail):
            raise RuntimeError(
                f"Cannot use torch._custom_ops APIs to register backward "
                f"formula for {detail}. Got operator "
                f"{self._qualname} with schema: {schema}"
            )

        schema = self._schema
        if schema.kind() != SchemaKind.functional:
            error("non-functional operator")

        rets = schema.returns
        if not rets:
            error("operator with no returns")

        assert len(rets) > 0
        is_non_mutating_view = any(
            r.annotation is not None and not r.annotation.is_write for r in rets
        )
        if is_non_mutating_view:
            error("operator that returns views")

        # We make assumptions about the schema's return types.
        allowed_return_types = {
            BaseType(BaseTy.int): "int",
            BaseType(BaseTy.SymInt): "SymInt",
            BaseType(BaseTy.bool): "bool",
            BaseType(BaseTy.float): "float",
            BaseType(BaseTy.Tensor): "Tensor",
            ListType(BaseType(BaseTy.Tensor), None): "List[Tensor]",
        }
        for ret in schema.returns:
            if ret.type in allowed_return_types:
                continue
            error(f"operator with return not in {list(allowed_return_types.values())} (got {ret.type})")

    def _check_doesnt_have_library_autograd_impl(self):
        if self._registered_autograd_kernel_indirection:
            return

        if _C._dispatch_has_kernel_for_dispatch_key(self._qualname, "CompositeImplicitAutograd"):
            raise RuntimeError(
                f"impl_backward/impl_save_for_backward: the operator {self._qualname} "
                f"already has an implementation for this device type via a "
                f"pre-existing registration to DispatchKey::CompositeImplicitAutograd. "
                f"CompositeImplicitAutograd operators do not need an autograd formula; "
                f"instead, the operator will decompose into its constituents and those "
                f"can have autograd formulas defined on them.")

        # We can improve this by adding "all Autograd<BACKEND> keys", but
        # realistically people will just be using this API for CPU/CUDA for now.
        for key in ["Autograd", "AutogradCPU", "AutogradCUDA"]:
            if _C._dispatch_has_kernel_for_dispatch_key(self._qualname, key):
                raise RuntimeError(
                    f"impl_backward/impl_save_for_backward: "
                    f"the operator {self._qualname} already has an Autograd kernel "
                    f"registered to DispatchKey::{key} via a pre-existing "
                    f"torch.library or TORCH_LIBRARY registration. Please either "
                    f"remove those registrations or don't use the torch._custom_ops APIs")

    def _check_doesnt_have_library_meta_impl(self):
        if self._has_impl("abstract"):
            return

        # If the user's operator is CompositeExplicitAutograd,
        # allow them to impl_abstract. This is being pragmatic
        # (existing custom ops may have CompositeExplicitAutograd
        # registrations that don't work with Meta kernels, so this
        # gives them an escape hatch).
        if (
            _C._dispatch_has_kernel_for_dispatch_key(self._qualname, "CompositeExplicitAutograd")
            and not _C._dispatch_has_kernel_for_dispatch_key(self._qualname, "Meta")
        ):
            return

        # Otherwise, if the user already has a Meta kernel or their
        # op is CompositeImplicitAutograd or some other alias dispatch key,
        # raise.

        # Special case for CompositeImplicitAutograd
        if _C._dispatch_has_kernel_for_dispatch_key(self._qualname, "CompositeImplicitAutograd"):
            raise RuntimeError(
                f"impl_abstract(...): the operator {self._qualname} "
                f"already has an implementation for this device type via a "
                f"pre-existing registration to DispatchKey::CompositeImplicitAutograd. "
                f"CompositeImplicitAutograd operators do not need an abstract impl; "
                f"instead, the operator will decompose into its constituents and those "
                f"can have abstract impls defined on them.")

        if _C._dispatch_has_kernel_for_dispatch_key(self._qualname, "Meta"):
            raise RuntimeError(
                f"impl_abstract(...): the operator {self._qualname} "
                f"already has a DispatchKey::Meta implementation via a "
                f"pre-existing torch.library or TORCH_LIBRARY registration. "
                f"Please either remove that registration or don't call impl_abstract.")

    # NOTE ["backward", "save_for_backward", and "autograd"]
    # As a part of the explicit autograd API, a user must provide us
    # a "save_for_backward" function and a "backward" function.
    # When both of these have been provided, then we automatically
    # construct the "autograd" kernel.
    def _register_autograd_kernel(self):
        assert self._has_impl("backward")
        assert self._has_impl("save_for_backward")
        kernel = construct_autograd_kernel(
            self._schema,
            self._output_differentiability,
            self,
            get_op(self._qualname),
            self._get_impl("save_for_backward").func,
            self._get_impl("backward").func)
        self._register_impl("autograd", kernel)

    def impl_save_for_backward(self, _stacklevel=2):
        r"""Register a function that tells us what to save for backward.

        Please see impl_backward for more details.
        """
        def inner(f):
            self._check_can_register_backward()
            self._check_doesnt_have_library_autograd_impl()
            if not self._registered_autograd_kernel_indirection:
                self._register_autograd_kernel_indirection()
            self._register_impl("save_for_backward", f, stacklevel=_stacklevel)
            if self._has_impl("backward"):
                self._register_autograd_kernel()
            return f

        return inner

    def impl_backward(self, output_differentiability=None, _stacklevel=2):
        r"""Registers a backward formula.

        WARNING: if you're a user, please do not use this directly
        (instead use the torch._custom_ops APIs).
        Also please see the following for a detailed guide on custom ops.
        https://docs.google.com/document/d/1aGWtgxV3HppuxQAdddyPrs74_aEntpkYt9MalnCKnhk

        In order for the CustomOp to work with autograd, you need to register
        a backward formula. There are two pieces to this:
        1. You must give us a function to specify what to save for backward.
           Call this the "save for backward" function.
        2. You must give us a function that computes gradients. Call this the
           "backward" function.

        Use `impl_save_for_backward` to define a "save for backward" function
        that specifies what gets saved for backward. The function should accept
        two arguments ``(inputs, output)`` and return the quantities to be saved
        for backward.
        During runtime, when you call the CustomOp, PyTorch will invoke the
        "save for backward" function with the inputs and output of the CustomOp.

        Use `impl_backward` to define the "backward" function. The backward
        function must accept ``(ctx, saved, *grads)``:
        - ``ctx`` is a context object where we may provide information
        - ``saved`` is exactly what gets returned from the "save for backward"
          function
        - ``grads`` is one or more gradients. The number of gradients matches
          the number of outputs of the CustomOp.

        The backward function must return a dict that maps the name of
        an input to the CustomOp to its corresponding gradient. All inputs that
        were declared to be Tensors in the CustomOp definition must be accounted
        for in the dict. The gradient may be a Tensor or None.
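
        Example (a minimal sketch, assuming the ``numpy_sin`` CustomOp from the
        ``custom_op`` docstring above; the function names are illustrative)::

            >>> @numpy_sin.impl_save_for_backward()
            >>> def numpy_sin_save_for_backward(inputs, output):
            >>>     # `inputs` carries the operator's arguments by name.
            >>>     return inputs.x
            >>>
            >>> @numpy_sin.impl_backward()
            >>> def numpy_sin_backward(ctx, saved, grad_out):
            >>>     return {"x": grad_out * saved.cos()}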
""" | |
        if output_differentiability is not None:
            def yell():
                raise RuntimeError(
                    f"impl_backward(output_differentiability): expected "
                    f"output_differentiability to be a list of bools with "
                    f"length equal to the number of outputs of this CustomOp. "
                    f"Got: {output_differentiability}")

            if not isinstance(output_differentiability, list):
                yell()
            for diff in output_differentiability:
                if not isinstance(diff, bool):
                    yell()
            if len(self._schema.returns) != len(output_differentiability):
                yell()

        def inner(f):
            self._check_can_register_backward()
            self._check_doesnt_have_library_autograd_impl()
            if not self._registered_autograd_kernel_indirection:
                self._register_autograd_kernel_indirection()
            self._register_impl("backward", f, stacklevel=_stacklevel)
            self._output_differentiability = output_differentiability
            if self._has_impl("save_for_backward"):
                self._register_autograd_kernel()
            return f

        return inner


@dataclasses.dataclass
class FuncAndLocation:
    func: typing.Callable
    location: str


def find_ophandle_or_throw(cpp_ns: str, operator_name: OperatorName):
    overload_name = (
        "" if operator_name.overload_name is None else operator_name.overload_name
    )
    return _C._dispatch_find_schema_or_throw(
        f"{cpp_ns}::{str(operator_name.name)}", overload_name
    )


def validate_namespace(ns: str) -> None:
    if "." in ns:
        raise ValueError(
            f'custom_op(..., ns="{ns}"): expected ns to not contain any . (and be a '
            f"valid variable name)"
        )
    if ns in RESERVED_NS:
        raise ValueError(
            f"custom_op(..., ns='{ns}'): '{ns}' is a reserved namespace, "
            f"please choose something else."
        )


def validate_schema(schema: FunctionSchema) -> None:
    if not torch._library.utils.is_functional_schema(schema):
        raise ValueError(
            f"custom_op only supports functional operators "
            f"(ops that do not mutate any inputs, do not return "
            f"views of the inputs, and have at least one return). "
            f"Got the following non-functional schema: {schema}"
        )

    # For simplicity: don't allow self arguments
    if schema.arguments.self_arg is not None:
        raise ValueError(
            f"custom_op does not support arguments named 'self'. Please "
            f"rename your argument. Got: {schema}"
        )


def parse_qualname(qualname: str) -> typing.Tuple[str, str]:
    names = qualname.split("::", 1)
    if len(names) != 2:
        raise ValueError(f"Expected there to be a namespace in {qualname}, i.e. the "
                         f"operator name should look something like ns::foo")
    if '.' in names[1]:
        raise ValueError(f"The torch._custom_ops APIs do not handle overloads, "
                         f"i.e. operator names with '.' in them. "
                         f"Please name your operator something like ns::foo. "
                         f"Got: {qualname}")
    return names[0], names[1]


def validate_device_type(device_type: str) -> None:
    if device_type not in SUPPORTED_DEVICE_TYPE_TO_KEY:
        raise ValueError(
            f"CustomOp.impl(device_types=[{device_type}, ...]): we only support device_type "
            f"in {SUPPORTED_DEVICE_TYPE_TO_KEY.keys()}."
        )


def supported_param(param: inspect.Parameter) -> bool:
    return param.kind in (
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
        inspect.Parameter.KEYWORD_ONLY,
    )


def validate_function_matches_schema(
    schema: FunctionSchema, func: typing.Callable
) -> None:
    sig = inspect.signature(func)

    if not all(supported_param(p) for _, p in sig.parameters.items()):
        raise ValueError(
            f"custom_op(..., manual_schema)(func): positional-only args, "
            f"varargs, and kwargs are not supported. Please rewrite `func` "
            f"to not have them. Got `func` with signature: {sig}"
        )

    if (
        any(
            p.annotation is not inspect.Parameter.empty
            for _, p in sig.parameters.items()
        )
        or sig.return_annotation is not inspect.Signature.empty
    ):
        raise ValueError(
            f"custom_op(..., manual_schema)(func): When passing in a manual "
            f"schema, we expect `func` to have no type annotations to avoid "
            f"ambiguity. Got `func` with signature: {sig}"
        )

    positional = [
        (name, param)
        for name, param in sig.parameters.items()
        if param.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
    ]
    kwargonly = [
        (name, param)
        for name, param in sig.parameters.items()
        if param.kind == inspect.Parameter.KEYWORD_ONLY
    ]

    def error():
        raise ValueError(
            f"custom_op(..., manual_schema)(func): When passing in a manual "
            f"schema, we expect `func`'s signature to match `manual_schema` "
            f"(aside from type annotations). "
            f"func's signature: {sig}, manual_schema: {schema}"
        )

    def error_default_args():
        raise ValueError(
            f"custom_op(..., manual_schema)(func): "
            f"neither func nor manual_schema should have default "
            f"arguments. Got "
            f"func's signature: {sig}, manual_schema: {schema}"
        )

    def compare(sig_args, schema_args):
        if len(sig_args) != len(schema_args):
            error()
        for (name, param), arg in zip(sig_args, schema_args):
            if name != arg.name:
                error()
            if param.default is not inspect.Parameter.empty or arg.default is not None:
                error_default_args()

    compare(positional, schema.arguments.flat_positional)
    compare(kwargonly, schema.arguments.flat_kwarg_only)


def infer_schema(prototype_function: typing.Callable) -> str:
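    r"""Infer an operator schema string from the prototype function's type
    annotations; e.g. ``def f(x: Tensor) -> Tensor`` yields "(Tensor x) -> Tensor".
    """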
    sig = inspect.signature(prototype_function)

    def error_fn(what):
        raise ValueError(
            f"custom_op(...)(func): {what} "
            f"Got func with signature {sig}"
        )

    params = [
        parse_param(name, param, error_fn) for name, param in sig.parameters.items()
    ]
    ret = parse_return(sig.return_annotation, error_fn)
    return f"({', '.join(params)}) -> {ret}"


def parse_param(name, param, error_fn):
    if not supported_param(param):
        error_fn("We do not support positional-only args, varargs, or varkwargs.")

    if param.annotation is inspect.Parameter.empty:
        error_fn(f"Parameter {name} must have a type annotation.")

    if param.annotation not in SUPPORTED_PARAM_TYPES.keys():
        error_fn(
            f"Parameter {name} has unsupported type {param.annotation}. "
            f"The valid types are: {SUPPORTED_PARAM_TYPES.keys()}."
        )

    if param.default is not inspect.Parameter.empty:
        error_fn(
            f"Parameter {name} has a default value; this is not supported. "
            f"If you want to use default values then create a function with "
            f"default values that calls the CustomOp"
        )

    return f"{SUPPORTED_PARAM_TYPES[param.annotation]} {name}"


def derived_types(
    base_type, cpp_type, list_base, optional_base_list, optional_list_base
):
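    # For example, derived_types(torch.Tensor, "Tensor", True, True, False)
    # yields: [(Tensor, "Tensor"), (Optional[Tensor], "Tensor?"),
    # (Sequence[Tensor], "Tensor[]"), (Sequence[Optional[Tensor]], "Tensor?[]")].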
    result = [
        (base_type, cpp_type),
        (typing.Optional[base_type], f"{cpp_type}?"),
    ]
    if list_base:
        result.append((typing.Sequence[base_type], f"{cpp_type}[]"))  # type: ignore[valid-type]
    if optional_base_list:
        result.append((typing.Sequence[typing.Optional[base_type]], f"{cpp_type}?[]"))  # type: ignore[valid-type]
    if optional_list_base:
        result.append((typing.Optional[typing.Sequence[base_type]], f"{cpp_type}[]?"))  # type: ignore[valid-type]
    return result


def get_supported_param_types():
    data = [
        # (python type, schema type, type[] variant, type?[] variant, type[]? variant)
        (torch.Tensor, "Tensor", True, True, False),
        (int, "SymInt", True, False, True),
        (float, "float", True, False, True),
        (bool, "bool", True, False, True),
        (str, "str", False, False, False),
        (torch.types.Number, "Scalar", True, False, False),
        (torch.dtype, "ScalarType", False, False, False),
        (torch.device, "Device", False, False, False),
    ]
    result = []
    for line in data:
        result.extend(derived_types(*line))
    return dict(result)


SUPPORTED_RETURN_TYPES = {
    torch.Tensor: "Tensor",
    typing.List[torch.Tensor]: "Tensor[]",
    int: "SymInt",
    float: "float",
    bool: "bool",
    torch.types.Number: "Scalar",
}


def parse_return(annotation, error_fn):
    origin = typing.get_origin(annotation)
    if origin is not tuple:
        if annotation not in SUPPORTED_RETURN_TYPES.keys():
            error_fn(
                f"Return has unsupported type {annotation}. "
                f"The valid types are: {SUPPORTED_RETURN_TYPES}."
            )
        return SUPPORTED_RETURN_TYPES[annotation]

    args = typing.get_args(annotation)
    for arg in args:
        if arg not in SUPPORTED_RETURN_TYPES:
            error_fn(
                f"Return has unsupported type {annotation}. "
                f"The valid types are: {SUPPORTED_RETURN_TYPES}."
            )

    return "(" + ", ".join([SUPPORTED_RETURN_TYPES[arg] for arg in args]) + ")"


SUPPORTED_PARAM_TYPES = get_supported_param_types()


def report_error_callback(custom_op: typing.Any, key: str) -> None:
    if key == "Undefined":
        raise NotImplementedError(
            f"{custom_op}: There were no Tensor inputs to this operator "
            f"(e.g. you passed an empty list of Tensors). If your operator is a "
            f"factory function (that is, it takes no Tensors and constructs "
            f"a new one), then please use CustomOp.impl_factory to register "
            f"an implementation for it"
        )
    if key == "Meta":
        raise NotImplementedError(
            f"{custom_op}: when running with device='Meta' tensors: there is no "
            f"abstract impl registered for this CustomOp. Please register one via "
            f"CustomOp.impl_abstract to get this CustomOp to work with Meta tensors"
        )
    if key in ("CPU", "CUDA"):
        device = key.lower()
        raise NotImplementedError(
            f"{custom_op}: when running with device='{device}' tensors: there is no "
            f"{device} impl registered for this CustomOp. Please register one via "
            f"CustomOp.impl(device_type='{device}')"
        )
    raise NotImplementedError(
        f"{custom_op}: No implementation for dispatch key {key}. It is likely "
        f"that we have not added this functionality yet, please either open an "
        f"issue or if you're feeling adventurous, use the low-level "
        f"torch.library API"
    )


def custom_op_from_existing(op):
    ns = op.namespace
    lib = torch.library.Library(ns, "FRAGMENT")
    name = op.name().split("::")[-1]
    schema_str = str(op._schema)
    # CustomOp expects the schema string without the namespace
    schema_str = schema_str.split("::")[-1]
    schema = FunctionSchema.parse(schema_str)
    return CustomOp(lib, ns, schema, name, op, _private_access=True)


def get_op(qualname):
    def error_not_found():
        raise ValueError(
            f"Could not find the operator {qualname}. Please make sure you have "
            f"already registered the operator and (if registered from C++) "
            f"loaded it via torch.ops.load_library.")

    ns, name = parse_qualname(qualname)
    if not hasattr(torch.ops, ns):
        error_not_found()
    opnamespace = getattr(torch.ops, ns)
    if not hasattr(opnamespace, name):
        error_not_found()
    packet = getattr(opnamespace, name)
    if not hasattr(packet, 'default'):
        error_not_found()
    return packet.default


def _find_custom_op(qualname, also_check_torch_library=False):
    if qualname in global_registry:
        return global_registry[qualname]
    if not also_check_torch_library:
        raise RuntimeError(
            f'Could not find custom op "{qualname}". Did you register it via '
            f"the torch._custom_ops API?")
    overload = get_op(qualname)
    result = custom_op_from_existing(overload)
    return result


def get_abstract_impl(qualname):
    if qualname not in torch._custom_op.impl.global_registry:
        return None
    custom_op = torch._custom_op.impl.global_registry[qualname]
    if custom_op is None:
        return None
    if not custom_op._has_impl("abstract"):
        return None
    return custom_op._get_impl("abstract").func


def _custom_op_with_schema(qualname, schema, needs_fixed_stride_order=True):
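    r"""Create a CustomOp from an explicit schema string rather than a prototype
    function, and return the corresponding OpOverload (see `get_op`).
    """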
    ns, name = qualname.split("::")
    schema_str = f"{name}{schema}"
    function_schema = FunctionSchema.parse(schema_str)
    validate_schema(function_schema)
    tags = [torch._C.Tag.needs_fixed_stride_order] if needs_fixed_stride_order else []
    lib = library.Library(ns, "FRAGMENT")
    lib.define(schema_str, tags=tags)
    ophandle = find_ophandle_or_throw(ns, function_schema.name)
    result = CustomOp(lib, ns, function_schema, name, ophandle, _private_access=True)
    result._register_autograd_kernel_indirection()
    torch._C._dispatch_set_report_error_callback(
        ophandle, functools.partial(report_error_callback, weakref.proxy(result))
    )
    return get_op(qualname)