import contextlib
import functools
import warnings
from typing import Callable, Optional

import torch
from torch._library.utils import Kernel, RegistrationHandle


class AbstractImplHolder:
    """A holder where one can register an abstract impl to."""

    def __init__(self, qualname: str):
        self.qualname: str = qualname
        self.kernel: Optional[Kernel] = None
        self.lib: Optional[torch.library.Library] = None

    def register(self, func: Callable, source: str) -> RegistrationHandle:
        """Register an abstract impl.

        Returns a RegistrationHandle that one can use to de-register this
        abstract impl.
        """
        if self.kernel is not None:
            raise RuntimeError(
                f"impl_abstract(...): the operator {self.qualname} "
                f"already has an abstract impl registered at "
                f"{self.kernel.source}."
            )
        if torch._C._dispatch_has_kernel_for_dispatch_key(self.qualname, "Meta"):
            raise RuntimeError(
                f"impl_abstract(...): the operator {self.qualname} "
                f"already has a DispatchKey::Meta implementation via a "
                f"pre-existing torch.library or TORCH_LIBRARY registration. "
                f"Please either remove that registration or don't call "
                f"impl_abstract."
            )

        if torch._C._dispatch_has_kernel_for_dispatch_key(
            self.qualname, "CompositeImplicitAutograd"
        ):
            raise RuntimeError(
                f"impl_abstract(...): the operator {self.qualname} "
                f"already has an implementation for this device type via a "
                f"pre-existing registration to "
                f"DispatchKey::CompositeImplicitAutograd. "
                f"CompositeImplicitAutograd operators do not need an abstract "
                f"impl; instead, the operator will decompose into its "
                f"constituents and those can have abstract impls defined on them."
            )

        # Store the kernel in this holder
        self.kernel = Kernel(func, source)

        # Also register the abstract impl to the Meta dispatch key
        if self.lib is None:
            ns = self.qualname.split("::")[0]
            self.lib = torch.library.Library(ns, "FRAGMENT")
        meta_kernel = construct_meta_kernel(self.qualname, self)
        self.lib.impl(self.qualname, meta_kernel, "Meta")

        def deregister_abstract_impl():
            if self.lib:
                self.lib._destroy()
                self.lib = None
            self.kernel = None

        return RegistrationHandle(deregister_abstract_impl)
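

# Illustrative sketch (not part of the library): how a holder ties an abstract
# impl to an operator and how the returned handle undoes the registration.
# The namespace, operator schema, and the RegistrationHandle.destroy() call
# below are assumptions for demonstration only.
def _example_holder_roundtrip() -> None:
    lib = torch.library.Library("mymodule", "FRAGMENT")
    lib.define("custom_op(Tensor x) -> Tensor")

    holder = AbstractImplHolder("mymodule::custom_op")

    def custom_op_abstract(x):
        # Abstract impls compute output metadata (shape/dtype/device) only.
        return x.new_empty(x.shape)

    handle = holder.register(custom_op_abstract, source="example.py:1")
    # ... the operator can now be traced under torch.compile / FakeTensor ...
    handle.destroy()  # invokes deregister_abstract_impl, removing the Meta kernel

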
def construct_meta_kernel(
    qualname: str, abstract_impl_holder: AbstractImplHolder
) -> Callable:
    assert abstract_impl_holder.kernel is not None

    def meta_kernel(*args, **kwargs):
        assert abstract_impl_holder.kernel is not None
        source = abstract_impl_holder.kernel.source

        def error_on_ctx():
            raise RuntimeError(
                f"Attempted to call get_ctx() for the meta implementation "
                f"for {qualname} (implemented at {source}). "
                f"You have presumably called get_ctx() because the operator "
                f"has a data-dependent output shape; if so, there is no "
                f"such meta implementation and this error is the correct "
                f"behavior."
            )

        with set_ctx_getter(error_on_ctx):
            return abstract_impl_holder.kernel(*args, **kwargs)

    return meta_kernel
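

# Illustrative sketch (not part of the library): the kernel produced by
# construct_meta_kernel forwards to the registered abstract impl while
# installing error_on_ctx, so any get_ctx() call inside it raises. We set
# holder.kernel directly (bypassing register()) only to keep the sketch
# self-contained; the operator name is hypothetical.
def _example_meta_kernel_forbids_ctx() -> None:
    holder = AbstractImplHolder("mymodule::ctx_free_op")
    holder.kernel = Kernel(lambda x: x.new_empty(x.shape), source="example.py:2")
    meta_kernel = construct_meta_kernel("mymodule::ctx_free_op", holder)

    x = torch.empty(3, device="meta")
    meta_kernel(x)  # fine: this abstract impl never calls get_ctx()

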
def get_none():
    return None


global_ctx_getter: Callable = get_none


@contextlib.contextmanager
def set_ctx_getter(ctx_getter):
    global global_ctx_getter
    prev = global_ctx_getter
    try:
        global_ctx_getter = ctx_getter
        yield
    finally:
        global_ctx_getter = prev
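

# Illustrative sketch (not part of the library): set_ctx_getter temporarily
# swaps the module-wide getter that get_ctx()-style callers read, and restores
# the previous getter on exit. The getter below is hypothetical.
def _example_ctx_getter_scoping() -> None:
    def fake_ctx_getter():
        return "fake ctx"

    assert global_ctx_getter() is None  # default getter (get_none) returns None
    with set_ctx_getter(fake_ctx_getter):
        assert global_ctx_getter() == "fake ctx"
    assert global_ctx_getter() is None  # previous getter restored on exit

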
class AbstractImplCtx:
    """
    Context object for writing abstract implementations for custom operators.
    """

    def __init__(self, _shape_env, _op):
        self._shape_env = _shape_env
        self._op = _op

    def create_unbacked_symint(self, *, min=2, max=None) -> torch.SymInt:
        warnings.warn(
            "create_unbacked_symint is deprecated, please use new_dynamic_size instead"
        )
        return self.new_dynamic_size(min=min, max=max)

    def new_dynamic_size(self, *, min=0, max=None) -> torch.SymInt:
        """Constructs a new symint (symbolic int) representing a data-dependent value.

        This is useful for writing the abstract implementation (which is necessary
        for torch.compile) for a CustomOp where an output Tensor has a size
        that depends on the data of the input Tensors.

        Args:
            min (int): A statically known inclusive lower bound for this symint. Default: 0
            max (Optional[int]): A statically known inclusive upper bound for this
                symint. Default: None

        .. warning::

            It is important that the ``min`` and ``max`` (if not None) values are set
            correctly; otherwise, there will be undefined behavior under
            torch.compile. Note that torch.compile specializes on 0/1 sizes; this is
            why the deprecated ``create_unbacked_symint`` defaults ``min`` to 2,
            whereas ``new_dynamic_size`` defaults ``min`` to 0.

            You must also verify that your implementation on concrete Tensors
            (e.g. CPU/CUDA) only returns Tensors where the size that corresponds
            to the symint also respects these constraints.
            The easiest way to do this is to add an assertion in the CPU/CUDA/etc
            implementation that the size follows these bounds.

        Example::

            >>> # An operator with data-dependent output shape
            >>> lib = torch.library.Library("mymodule", "FRAGMENT")
            >>> lib.define("mymodule::custom_nonzero(Tensor x) -> Tensor")
            >>>
            >>> @torch.library.impl_abstract("mymodule::custom_nonzero")
            >>> def custom_nonzero_abstract(x):
            >>>     # Number of nonzero-elements is data-dependent.
            >>>     # Since we cannot peek at the data in an abstract impl,
            >>>     # we use the ctx object to construct a new symint that
            >>>     # represents the data-dependent size.
            >>>     ctx = torch.library.get_ctx()
            >>>     nnz = ctx.new_dynamic_size()
            >>>     shape = [nnz, x.dim()]
            >>>     result = x.new_empty(shape, dtype=torch.int64)
            >>>     return result
            >>>
            >>> @torch.library.impl(lib, "custom_nonzero", "CPU")
            >>> def custom_nonzero_cpu(x):
            >>>     x_np = x.numpy()
            >>>     res = np.stack(np.nonzero(x_np), axis=1)
            >>>     return torch.tensor(res, device=x.device)
        """
        if (
            self._shape_env is None
            or not self._shape_env.allow_dynamic_output_shape_ops
        ):
            raise torch._subclasses.fake_tensor.DynamicOutputShapeException(self._op)

        if isinstance(min, torch.SymInt) or isinstance(max, torch.SymInt):
            raise ValueError(
                f"ctx.new_dynamic_size(min={min}, max={max}): expected "
                f"min and max to be statically known ints but got SymInt. "
                f"This is not supported."
            )

        if min < 0:
            raise ValueError(
                f"ctx.new_dynamic_size(min={min}, ...): expected min to be "
                f"greater than or equal to 0: this API can only create "
                f"non-negative sizes."
            )

        result = self._shape_env.create_unbacked_symint()
        torch.fx.experimental.symbolic_shapes._constrain_range_for_size(
            result, min=min, max=max
        )
        return result
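

# Illustrative sketch (not part of the library): following the docstring's
# advice above, the concrete (e.g. CPU) implementation can assert that the
# data-dependent size stays within the bounds passed to ctx.new_dynamic_size().
# The bounds here are hypothetical and assume the abstract impl declared
# nnz = ctx.new_dynamic_size(min=1, max=x.numel()).
def _example_bounded_concrete_impl(x: torch.Tensor) -> torch.Tensor:
    result = x.nonzero()
    nnz = result.shape[0]
    # Assert the concrete size respects the bounds declared in the abstract impl.
    assert 1 <= nnz <= x.numel()
    return result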