import contextlib
import ctypes
import importlib
import inspect
import sys
import types
from typing import Any, Callable, Dict, Set, Type, Union

import torch._C
import torch.utils._pytree as pytree
from torch import _utils_internal
from torch._functorch.pyfunctorch import dispatch_functorch
from torch.utils._python_dispatch import TorchDispatchMode

# Query `hasattr` only once.
_SET_GLOBAL_FLAGS = hasattr(sys, "getdlopenflags") and hasattr(sys, "setdlopenflags")


@contextlib.contextmanager
def dl_open_guard():
    """
    Context manager to set the RTLD_GLOBAL dynamic linker flag while we open a
    shared library to load custom operators.
    """
    if not _SET_GLOBAL_FLAGS:
        yield
        return
    old_flags = sys.getdlopenflags()
    sys.setdlopenflags(old_flags | ctypes.RTLD_GLOBAL)
    try:
        yield
    finally:
        sys.setdlopenflags(old_flags)
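
# Usage sketch (illustrative only, not executed here): dl_open_guard is meant to
# wrap a raw dlopen so that the custom-op library can resolve symbols exported by
# libtorch via RTLD_GLOBAL. The shared-library path below is a hypothetical
# placeholder.
#
#     with dl_open_guard():
#         ctypes.CDLL("/path/to/libmy_custom_ops.so")  # registers ops on load
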
class OperatorBase:
    """
    Base class for OpOverload (which represents C++ ATen operators) and HigherOrderOperator
    (which represents Python-only operators that are unrepresentable in TorchScript).
    """

    def __init__(self):
        # The dispatch cache precomputes a mapping of dispatch key that the
        # dispatcher wants to dispatch to, to an actual implementation of the
        # dispatch key. Confusingly, the actual implementation could *also* be a
        # dispatch key, but in this case, this refers to the C++ kernel that
        # was registered to some dispatch key. Aliases are permitted in the
        # latter but not the former; for example, you might look up the
        # entry for AutogradCPU, and this maps you to the Autograd key for
        # the generic autograd kernel that works for all devices. Since this
        # is the Python dispatcher, you can also put an arbitrary Python
        # callable to call instead. This handler gets precisely the
        # args/kwargs that the operator was __call__'ed with.
        # NB: This name is hard-coded in torch/csrc/autograd/python_variable.cpp
        # for use with OpOverload; cache lookup is done entirely from C++
        # for speed.
        # TODO: The cache is NOT currently used by HigherOrderOperator, but it should!
        self._dispatch_cache: Dict[
            torch._C.DispatchKey, Union[torch._C.DispatchKey, Callable[..., Any]]
        ] = {}

        # This table allows you to override the behavior of a particular
        # dispatch key to call a custom Python function, rather than the
        # ordinary C++ configured behavior. This is the raison d'etre of
        # the Python dispatcher: to let you program the dispatcher from Python
        # in case you need something unusual, and don't want to clobber
        # the existing registrations using the Python operator registration
        # API.
        self.py_kernels: Dict[torch._C.DispatchKey, Callable[..., Any]] = {}

        # This table allows you to override the behavior of a particular
        # operator for a particular TorchDispatchMode. In practice,
        # we are using this mostly for ProxyTensorMode. Modes can be
        # thought of as an open world extension of dispatch keys, so it
        # makes sense that you should be able to register handlers for them,
        # the same way you can register dispatch keys.
        self.python_key_mode_table: Dict[
            Type[TorchDispatchMode], Callable[..., Any]
        ] = {}

        # This table allows you to override the behavior of functorch
        # transformations. NB: this currently only does something for
        # HigherOrderOperator
        self.functorch_table = {}

    def __call__(self, *args, **kwargs):
        raise NotImplementedError()

    def has_kernel_for_dispatch_key(self, k):
        return k in self.py_kernels

    def has_kernel_for_any_dispatch_key(self, ks):
        for k in self.py_kernels:
            if not torch._C._dispatch_is_alias_key(k) and ks.has(k):
                return True
        return False

    def py_impl(self, k):
        def inner(fn):
            if inspect.isclass(k) and issubclass(k, TorchDispatchMode):
                assert k not in self.python_key_mode_table
                # TODO(voz): Should we replace setting torch._C.DispatchKey.Python entirely with setting mode keys?
                self.python_key_mode_table[k] = fn
                self._dispatch_cache.clear()
                return fn

            if isinstance(k, torch._C._functorch.TransformType):
                assert k not in self.functorch_table
                self.functorch_table[k] = fn
                return fn

            assert isinstance(k, torch._C.DispatchKey)
            assert (
                k != torch._C.DispatchKey.Python
            ), "Please register a mode for the torch._C.DispatchKey.Python key instead."

            if k in self.py_kernels:
                raise RuntimeError(
                    f"Trying to override a python impl for {k} on operator {self.name()}"
                )
            self.py_kernels[k] = fn
            self._dispatch_cache.clear()
            return fn

        return inner
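
    # Usage sketch (illustrative only; `my_op`, the kernels, and SomeTorchDispatchMode
    # below are hypothetical): py_impl is used as a decorator factory to attach a
    # Python kernel either to a dispatch key or to a TorchDispatchMode subclass,
    # filling in py_kernels or python_key_mode_table respectively.
    #
    #     @my_op.py_impl(torch._C.DispatchKey.CompositeExplicitAutograd)
    #     def my_op_dense(*args, **kwargs):
    #         ...  # plain eager implementation
    #
    #     @my_op.py_impl(SomeTorchDispatchMode)
    #     def my_op_under_mode(mode, *args, **kwargs):
    #         ...  # for HigherOrderOperators the active mode is passed first
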
    # Registers an implementation to all **3** variants of functionalization that we have:
    # - DispatchKey.Functionalize
    # - functorch.TransformType.Functionalize
    # - FunctionalTensorMode
    # Example:
    #   @py_functionalize_impl
    #   def functionalize_rule(ctx, inner_f, *args):
    #       args_unwrapped = ctx.unwrap_tensors(args)
    #       with ctx.redispatch_to_next():
    #           out = ctx.functionalize(inner_f)(*args_unwrapped)
    #           return ctx.wrap_tensors(out)
    def py_functionalize_impl(self, fn):
        from torch._subclasses.functional_tensor import (
            CppFunctionalizeAPI as _CppFunctionalizeAPI,
            FunctorchFunctionalizeAPI as _FunctorchFunctionalizeAPI,
            PythonFunctionalizeAPI as _PythonFunctionalizeAPI,
        )

        # Construct our three flavors of functionalization,
        # each of which has slightly different wrap/unwrap/redispatch policies
        def functionalize_dk_fn(*args, **kwargs):
            return fn(_CppFunctionalizeAPI(), *args, **kwargs)

        def functionalize_dispatch_mode_fn(mode, *args, **kwargs):
            return fn(_PythonFunctionalizeAPI(mode), *args, **kwargs)

        def functionalize_functorch_fn(interpreter, *args, **kwargs):
            return fn(_FunctorchFunctionalizeAPI(interpreter), *args, **kwargs)

        self.py_impl(torch._C.DispatchKey.Functionalize)(functionalize_dk_fn)
        self.py_impl(torch._subclasses.functional_tensor.FunctionalTensorMode)(
            functionalize_dispatch_mode_fn
        )
        self.py_impl(torch._C._functorch.TransformType.Functionalize)(
            functionalize_functorch_fn
        )
        return fn

    def name(self):
        raise NotImplementedError()


is_included_in_alias = torch._C._dispatch_is_included_in_alias

DispatchKey = torch._C.DispatchKey


# Equivalent to computeDispatchTableEntryWithDebug
def resolve_key(op: OperatorBase, k: DispatchKey):  # type: ignore[valid-type]
    # 1. (Direct) operator registration
    if op.has_kernel_for_dispatch_key(k):
        return k
    # 2.1 Use CompositeExplicitAutogradNonFunctional kernel if available
    cand = DispatchKey.CompositeExplicitAutogradNonFunctional
    if (
        k == DispatchKey.Undefined or is_included_in_alias(k, cand)
    ) and op.has_kernel_for_dispatch_key(cand):
        return cand
    # 2.2 Use CompositeExplicitAutograd kernel if available
    cand = DispatchKey.CompositeExplicitAutograd
    if (
        k == DispatchKey.Undefined or is_included_in_alias(k, cand)
    ) and op.has_kernel_for_dispatch_key(cand):
        return cand
    has_backend_kernel = op.has_kernel_for_any_dispatch_key(
        torch._C._dispatch_get_backend_keyset_from_autograd(k)
    ) or op.has_kernel_for_dispatch_key(DispatchKey.CompositeExplicitAutograd)
    # 2.3. Use CompositeImplicitAutograd kernel if available
    cand = DispatchKey.CompositeImplicitAutogradNestedTensor
    if (
        (k != DispatchKey.Undefined and is_included_in_alias(k, cand))
        and op.has_kernel_for_dispatch_key(cand)
        and not has_backend_kernel
    ):
        return cand
    cand = DispatchKey.CompositeImplicitAutograd
    if (
        k == DispatchKey.Undefined or is_included_in_alias(k, cand)
    ) and op.has_kernel_for_dispatch_key(cand):
        if k == DispatchKey.AutogradOther and op.has_kernel_for_any_dispatch_key(
            torch._C._dispatch_autogradother_backends
        ):
            raise RuntimeError("ambiguous autogradother kernel")
        elif not has_backend_kernel:
            return cand
    # 2.4. For autograd backend keys, use kernel from DispatchKey::Autograd if available
    cand = DispatchKey.Autograd
    if is_included_in_alias(k, cand) and op.has_kernel_for_dispatch_key(cand):
        return cand
    # 2.5 Use kernel from DispatchKey::FuncTorchBatchedDecomposition if available
    cand = DispatchKey.FuncTorchBatchedDecomposition
    if is_included_in_alias(k, cand) and op.has_kernel_for_dispatch_key(cand):
        return cand
    # Backend fallback
    if torch._C._dispatch_has_backend_fallback(k):
        # The dispatch key itself will implicitly route to backend fallback.
        # This is probably not great for the pure Python implementation.
        return k
    raise NotImplementedError(f"could not find kernel for {op} at dispatch key {k}")
_higher_order_ops: Dict[str, "HigherOrderOperator"] = {}

_HIGHER_ORDER_OP_DEFAULT_FALLTHROUGH_DISPATCH_KEYS = [
    DispatchKey.PythonDispatcher,  # type: ignore[attr-defined]
    DispatchKey.PythonTLSSnapshot,  # type: ignore[attr-defined]
    DispatchKey.ADInplaceOrView,
    DispatchKey.BackendSelect,
    DispatchKey.AutocastCPU,  # type: ignore[attr-defined]
    DispatchKey.AutocastCUDA,  # type: ignore[attr-defined]
]


class HigherOrderOperator(OperatorBase):
    # The HigherOrderOperator will appear as torch.ops.higher_order.{name}
    #
    # If you're creating a new HigherOrderOperator, please do not change the
    # default. Adding operators to the global torch.ops namespace is a bad
    # practice due to name collisions.
    def __init__(self, name):
        super().__init__()
        self._name = name

        # Make _OpNamespace not scream; this whole name-based association needs a good hard look
        self.__name__ = name
        _higher_order_ops[name] = self
        self._ns = "higher_order"

        # For a normal HigherOrderOperator instance, we will change its __module__ from torch._ops to
        # torch._ops.higher_order.
        # For an instance of a subclass of HigherOrderOperator (e.g. a customized higher order op),
        # the __module__ attribute will be kept unchanged.
        if self.__class__ is HigherOrderOperator:
            self_name_space = "." + self.namespace if self.namespace else ""
            self.__module__ = self.__module__ + self_name_space
        self.non_fallthrough_keys = torch._C._dispatch_keyset_full()

        for dispatch_key in _HIGHER_ORDER_OP_DEFAULT_FALLTHROUGH_DISPATCH_KEYS:
            self.fallthrough(dispatch_key)

        # [NOTE] We have to register a pre-dispatch key implementation
        # because sometimes HOPs use aot-dispatch tracing to detect certain
        # mutations. This is problematic when we are functionalizing HOPs
        # during pre-dispatch because when the inner tracer starts, it will see
        # that the PreDispatch key is still active. In that case, we just redispatch
        # it to the next key. This is only safe to do when the PreDispatch key stack has no
        # active modes.
        # TODO (tmanlaibaatar) Make it a generic fallback mechanism
        def _(*args, **kwargs):
            if _len_torch_dispatch_stack_pre_dispatch() == 0:
                with torch._C._ExcludeDispatchKeyGuard(
                    torch._C.DispatchKeySet(DispatchKey.PreDispatch)
                ):
                    return self(*args, **kwargs)
            raise AssertionError(
                """
                Can't directly invoke HOP implementation at PreDispatch key
                if there are active modes on PreDispatch mode stack.
                """
            )

        self.py_impl(torch._C.DispatchKey.PreDispatch)(_)

    def py_impl(self, k):
        if isinstance(k, torch._C.DispatchKey) and not self.non_fallthrough_keys.has(k):
            self.non_fallthrough_keys = self.non_fallthrough_keys.add(k)
        return super().py_impl(k)

    @property
    def namespace(self):
        return self._ns

    def fallthrough(self, dispatch_key):
        self.non_fallthrough_keys = self.non_fallthrough_keys.remove(dispatch_key)

    def dispatch(self, dispatch_key, *args, **kwargs):
        from torch.utils._python_dispatch import _get_current_dispatch_mode

        if dispatch_key in self._dispatch_cache:
            kernel = self._dispatch_cache[dispatch_key]
            assert not isinstance(kernel, torch._C.DispatchKey)
            return kernel(*args, **kwargs)

        if dispatch_key == torch._C.DispatchKey.FuncTorchDynamicLayerFrontMode:
            return dispatch_functorch(self, args, kwargs)

        if dispatch_key == torch._C.DispatchKey.Python:
            # The place to handle ProxyTorchDispatchMode, FakeTensorMode, etc
            from torch.utils._python_dispatch import _pop_mode_temporarily

            curr_mode = _get_current_dispatch_mode()
            assert (
                curr_mode is not None
            ), "Illegal invocation of dispatch on torch._C.DispatchKey.Python without a mode."
            assert (
                type(curr_mode) in self.python_key_mode_table
            ), f"Current active mode {curr_mode} not registered"
            handler = self.python_key_mode_table[type(curr_mode)]
            with _pop_mode_temporarily() as mode:
                return handler(mode, *args, **kwargs)

        functionality_key = torch._C._to_functionality_key(dispatch_key)  # type: ignore[attr-defined]
        if functionality_key == torch._C.DispatchKey.PreDispatch:
            from torch.utils._python_dispatch import _pop_mode_temporarily

            # The check for Python in the exclude set is so we properly respect `with no_dispatch()`
            # calls inside of a mode.
            if (
                _len_torch_dispatch_stack_pre_dispatch() > 0
            ) and not torch._C._dispatch_tls_is_dispatch_key_excluded(
                DispatchKey.Python
            ):
                curr_mode = _get_current_dispatch_mode_pre_dispatch()
                assert (
                    curr_mode is not None
                ), "Illegal invocation of dispatch on torch._C.DispatchKey.PreDispatch without a mode."
                assert (
                    type(curr_mode) in self.python_key_mode_table
                ), f"Current active mode {curr_mode} not registered"
                handler = self.python_key_mode_table[type(curr_mode)]
                with _pop_mode_temporarily(functionality_key) as mode:
                    return handler(mode, *args, **kwargs)

        final_key = resolve_key(self, dispatch_key)

        # This can currently fail due to backend fallbacks. You just have to
        # register them by hand for HigherOrderOperator.
        if final_key not in self.py_kernels:
            raise NotImplementedError(
                f"could not find kernel for HigherOrderOperator {self._name} "
                f"at dispatch key {final_key} (resolved from {dispatch_key})"
            )
        self._dispatch_cache[dispatch_key] = self.py_kernels[final_key]
        kernel = self.py_kernels[final_key]
        # It's illegal to register a DispatchKey to py_kernels, since there's no
        # C++ kernel to call into
        assert not isinstance(kernel, torch._C.DispatchKey)
        return kernel(*args, **kwargs)

    def __call__(self, *args, **kwargs):
        # Dynamo already traces the body of a HigherOrderOp when it first sees one,
        # so there is no need to trace into it here.
        import torch._dynamo
        from torch._dynamo import disable

        @disable
        def wrapper():
            flat_args = _to_flat_tuple(args, kwargs)
            if torch.overrides.has_torch_function(flat_args):
                return torch.overrides.handle_torch_function(
                    self, flat_args, *args, **kwargs
                )

            dispatch_key_set = _compute_keyset(args, kwargs, self.non_fallthrough_keys)
            return self.dispatch(
                dispatch_key_set.highestPriorityTypeId(), *args, **kwargs
            )

        return wrapper()

    def __str__(self):
        return f"{self.name()}"

    def name(self):
        return self._name
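
# Usage sketch (illustrative only; `my_hop` and its kernel are hypothetical):
# a HigherOrderOperator is typically created as a module-level singleton and then
# given per-dispatch-key Python kernels via py_impl; it then shows up as
# torch.ops.higher_order.my_hop and calls route through __call__ -> dispatch().
#
#     my_hop = HigherOrderOperator("my_hop")
#
#     @my_hop.py_impl(torch._C.DispatchKey.CompositeExplicitAutograd)
#     def my_hop_dense(fn, *args):
#         return fn(*args)
#
#     # In practice an Autograd kernel (e.g. via py_impl(DispatchKey.Autograd)) and
#     # kernels for tracing modes are usually registered as well before the HOP is
#     # usable end to end, e.g. my_hop(torch.sin, torch.randn(3)).
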
def _to_flat_tuple(args, kwargs):
    return pytree.arg_tree_leaves(*args, **kwargs)


def _compute_keyset(args, kwargs, non_fallthrough_keys):
    tensors = _get_tensors(args, kwargs)
    return key_extractor(tensors, non_fallthrough_keys)


def _get_tensors(args, kwargs):
    flat_all = _to_flat_tuple(args, kwargs)
    tensor_args = [t for t in flat_all if isinstance(t, torch.Tensor)]
    return tuple(tensor_args)


# Note - this should maintain identical impl to the C++ dispatcher key extraction logic
# at ATen/core/dispatch/DispatchKeyExtractor.h
def key_extractor(tensors, key_mask):
    key_set = torch._C._dispatch_tls_local_include_set()
    for tensor in tensors:
        key_set = key_set | torch._C._dispatch_keys(tensor)
    key_set = key_set - torch._C._dispatch_tls_local_exclude_set()
    key_set = key_set & key_mask
    return key_set
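
# Illustrative sketch of what key_extractor computes for a plain CPU tensor outside
# of any mode (the exact keys are version-dependent; this is indicative only):
#
#     # t = torch.randn(2)
#     # ks = key_extractor((t,), torch._C._dispatch_keyset_full())
#     # `ks` is (TLS include set | t's key set) - TLS exclude set, masked by key_mask;
#     # ks.highestPriorityTypeId() is then typically an autograd key such as
#     # DispatchKey.AutogradCPU.
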
# Mode stack for the PreDispatch key.
# It always has two slots, with priority given to FunctionalTensorMode over
# ProxyTorchDispatchMode: slot 0 belongs to ProxyTorchDispatchMode and
# slot 1 belongs to FunctionalTensorMode.
class _ModeStackStateForPreDispatch:
    def __init__(self):
        self.__infra_modes = [None, None]

    def set(self, index, mode):
        assert index < len(self.__infra_modes)
        self.__infra_modes[index] = mode

    def get(self, index):
        assert index < len(self.__infra_modes)
        return self.__infra_modes[index]

    def count(self):
        return len([i for i in self.__infra_modes if i is not None])


_mode_stack_state_for_pre_dispatch = _ModeStackStateForPreDispatch()


def unset_mode_pre_dispatch(mode_key):
    current_mode_stack_pre_dispatch = mode_stack_state_for_pre_dispatch()
    assert mode_key in (
        torch._C._TorchDispatchModeKey.PROXY,
        torch._C._TorchDispatchModeKey.FUNCTIONAL,
    )
    if mode_key == torch._C._TorchDispatchModeKey.PROXY:
        current_mode = current_mode_stack_pre_dispatch.get(0)
        mode_stack_state_for_pre_dispatch().set(0, None)
        return current_mode
    else:
        current_mode = current_mode_stack_pre_dispatch.get(1)
        mode_stack_state_for_pre_dispatch().set(1, None)
        return current_mode


def _set_mode_pre_dispatch(mode):
    from torch._subclasses.functional_tensor import FunctionalTensorMode
    from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode

    assert isinstance(mode, (FunctionalTensorMode, ProxyTorchDispatchMode))
    if isinstance(mode, FunctionalTensorMode):
        current_mode = mode_stack_state_for_pre_dispatch().get(1)
        assert current_mode is None
        mode_stack_state_for_pre_dispatch().set(1, mode)
        return

    current_mode = mode_stack_state_for_pre_dispatch().get(0)
    assert current_mode is None
    mode_stack_state_for_pre_dispatch().set(0, mode)


def _pop_mode_from_pre_dispatch():
    mode_stack = mode_stack_state_for_pre_dispatch()
    if mode_stack.get(1) is not None:
        res = mode_stack.get(1)
        mode_stack.set(1, None)
        return res

    if mode_stack.get(0) is not None:
        res = mode_stack.get(0)
        mode_stack.set(0, None)
        return res

    raise AssertionError("Trying to pop empty mode stack")


def _len_torch_dispatch_stack_pre_dispatch():
    return mode_stack_state_for_pre_dispatch().count()


def _get_dispatch_mode_pre_dispatch(mode_key):
    assert mode_key in (
        torch._C._TorchDispatchModeKey.PROXY,
        torch._C._TorchDispatchModeKey.FUNCTIONAL,
    )
    if mode_key == torch._C._TorchDispatchModeKey.PROXY:
        return mode_stack_state_for_pre_dispatch().get(0)
    return mode_stack_state_for_pre_dispatch().get(1)


def _get_current_dispatch_mode_pre_dispatch():
    stack_len = mode_stack_state_for_pre_dispatch().count()
    if stack_len == 2:
        return mode_stack_state_for_pre_dispatch().get(1)
    if stack_len == 1:
        return (
            mode_stack_state_for_pre_dispatch().get(1)
            if mode_stack_state_for_pre_dispatch().get(1) is not None
            else mode_stack_state_for_pre_dispatch().get(0)
        )
    return None


def mode_stack_state_for_pre_dispatch():
    global _mode_stack_state_for_pre_dispatch
    return _mode_stack_state_for_pre_dispatch
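
# Illustrative sketch of the slot invariants above (the modes below are assumed to
# be already-constructed FunctionalTensorMode / ProxyTorchDispatchMode instances):
#
#     # _set_mode_pre_dispatch(proxy_mode)       # stored in slot 0
#     # _set_mode_pre_dispatch(functional_mode)  # stored in slot 1
#     # _len_torch_dispatch_stack_pre_dispatch() == 2
#     # _pop_mode_from_pre_dispatch() is functional_mode   # slot 1 is popped first
#     # _pop_mode_from_pre_dispatch() is proxy_mode
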
cached_ops: Set["OpOverload"] = set() | |
def add_cached_op(op_overload): | |
global cached_ops | |
cached_ops.add(op_overload) | |
def reset_cached_ops(): | |
global cached_ops | |
cached_ops.clear() | |
def get_cached_ops(): | |
global cached_ops | |
return cached_ops | |
# Each OpOverload object contains pointer to a a specific operator overload, a pointer to the parent `OpOverloadPacket` object. | |
# You can obtain an OpOverload object through attribute query on OpOverloadPacket. | |
class OpOverload(OperatorBase): | |
def __init__(self, overloadpacket, op, op_dk, schema, tags): | |
super().__init__() | |
self._op = op | |
self._op_dk = op_dk | |
self._schema = schema | |
self._overloadpacket = overloadpacket | |
self._tags = tags | |
self._overloadname = ( | |
"default" if schema.overload_name == "" else schema.overload_name | |
) | |
self._name = self._schema.name | |
if schema.overload_name: | |
self._name += "." + schema.overload_name | |
self.__name__ = f"{self._schema.name.split('::')[1]}.{self._overloadname}" | |
self.__module__ = overloadpacket.__module__ | |
op.__module__ = overloadpacket.__module__ | |
self.__qualname__ = self._name | |
self.__annotations__ = {} | |
# If the OpOverload was constructed from a Library.def in Python. | |
self._defined_in_python = self.__qualname__ in torch.library._defs | |
# Logic replicated from aten/src/ATen/native/MathBitsFallback.h | |
is_write = None | |
for a in self._schema.arguments: | |
if a.alias_info is None: | |
continue | |
if is_write is None: | |
is_write = a.alias_info.is_write | |
else: | |
# We will conservatively call mixed mutable/non-mutable | |
# aliased inputs as NOT a view | |
is_write = a.alias_info.is_write or is_write | |
self.is_view = is_write is not None and not is_write | |
# it's a no-op since OpOverload object is immutable and must be unique for a given op overload. | |
def __deepcopy__(self, memo=None): | |
return self | |
def __repr__(self): | |
return "<OpOverload(op='{}.{}', overload='{}')>".format( | |
*self._schema.name.split("::"), self._overloadname | |
) | |
def __call__(self_, *args, **kwargs): # noqa: B902 | |
# use `self_` to avoid naming collide with aten ops arguments that | |
# are named "self". This way, all the aten ops can be called by kwargs. | |
return self_._op(*args, **kwargs) | |
def __hash__(self): | |
return hash(self._op) | |
# `my_namespace.my_op_name.overload_name` | |
def __str__(self): | |
return "{}.{}.{}".format(*self._schema.name.split("::"), self._overloadname) | |
def has_kernel_for_dispatch_key(self, k): | |
return super().has_kernel_for_dispatch_key( | |
k | |
) or torch._C._dispatch_has_kernel_for_dispatch_key(self.name(), k) | |
def has_kernel_for_any_dispatch_key(self, ks): | |
return torch._C._dispatch_has_kernel_for_any_dispatch_key( | |
self.name(), ks | |
) or super().has_kernel_for_any_dispatch_key(ks) | |
def namespace(self): | |
return self._schema.name.split("::")[0] | |
def _handle(self): | |
return torch._C._dispatch_find_schema_or_throw( | |
self._schema.name, self._schema.overload_name | |
) | |
def decompose(self, *args, **kwargs): | |
dk = torch._C.DispatchKey.CompositeImplicitAutograd | |
if dk in self.py_kernels: | |
# NB: This branch is not too necessary anymore, because we can | |
# apply Python CompositeImplicitAutograd *before* tracing | |
# using Python dispatcher (also taking advantage of the autograd | |
# formula). But it's included for completeness | |
return self.py_kernels[dk](*args, **kwargs) | |
elif torch._C._dispatch_has_kernel_for_dispatch_key(self.name(), dk): | |
return self._op_dk(dk, *args, **kwargs) | |
else: | |
return NotImplemented | |
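
    # Usage sketch (illustrative; `my_composite_op` is a hypothetical overload that
    # has a CompositeImplicitAutograd kernel): decompose() runs that decomposition
    # if one exists, and returns NotImplemented otherwise so callers can fall back.
    #
    #     # out = my_composite_op.decompose(*args, **kwargs)
    #     # if out is NotImplemented:
    #     #     out = my_composite_op(*args, **kwargs)
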
    # Remove a dispatch key from the dispatch cache. This will force it to get
    # recomputed the next time. Does nothing if the key is not in the cache.
    # WARNING: if you register a dispatch key to py_kernels of an OpOverload,
    # calling _del_dispatch on that key is NOT sufficient to apply your change,
    # because a single registration may affect MULTIPLE dispatch keys (e.g.,
    # registering Autograd affects AutogradCPU). del_dispatch is to be used
    # only if you are specifically modifying how get_dispatch handles a
    # particular input 'key'.
    def _uncache_dispatch(self, key):
        self._dispatch_cache.pop(key, None)

    # This implements the pre-computation logic for the Python dispatcher.
    def _get_dispatch(self, key):
        # This is only called upon a cache miss
        assert key not in self._dispatch_cache, f"{self} {key}"

        if key == torch._C.DispatchKey.Python:
            if not self.python_key_mode_table:
                self._dispatch_cache[key] = key
                add_cached_op(self)
                return key

            def handler(*args, **kwargs):
                from torch.utils._python_dispatch import _get_current_dispatch_mode

                # TODO: We also need to handle tensor subclasses here
                # TODO(voz): We should walk all the nodes here / turn it into a list, topmode is ok for now.
                curr_mode = type(_get_current_dispatch_mode())
                assert (
                    curr_mode is not None
                ), "Illegal invocation of dispatch on torch._C.DispatchKey.Python without a mode."
                if curr_mode not in self.python_key_mode_table:
                    # TODO: This path is slow, should generally encourage this
                    # case to not happen
                    return self._op_dk(key, *args, **kwargs)
                # TODO(voz): The idea behind this is that we do not yet support dispatch by key + mode, only key.
                return self.python_key_mode_table[curr_mode](*args, **kwargs)

            self._dispatch_cache[key] = handler
            add_cached_op(self)
            return handler

        functionality_key = torch._C._to_functionality_key(key)  # type: ignore[attr-defined]
        if functionality_key == torch._C.DispatchKey.PreDispatch:
            curr_stack_len = _len_torch_dispatch_stack_pre_dispatch()
            # The check for Python in the exclude set is so we properly respect `with no_dispatch()`
            # calls inside of a mode.
            if (
                curr_stack_len > 0
                and not torch._C._dispatch_tls_is_dispatch_key_excluded(
                    DispatchKey.Python
                )
            ):

                def handler(*args, **kwargs):
                    @contextlib.contextmanager
                    def _temporarily_pop_modes_from_pre_dispatch():
                        top_mode = _pop_mode_from_pre_dispatch()
                        try:
                            yield top_mode
                        finally:
                            _set_mode_pre_dispatch(top_mode)

                    with _temporarily_pop_modes_from_pre_dispatch() as curr_mode:
                        assert isinstance(curr_mode, TorchDispatchMode)
                        overload_types = []
                        args_flattened, _ = torch.utils._pytree.tree_flatten(
                            (args, kwargs.values())
                        )
                        for a in args_flattened:
                            # TODO: need to double check the semantics of the "types" argument to torch_dispatch.
                            # It's generated in PyInterpreter.cpp, but seems to be generated in two places,
                            # where in one case we only include tensors with the python key, and in another
                            # we include **all** tensors.
                            if isinstance(a, torch.Tensor) and torch._C._dispatch_keys(
                                a
                            ).has(torch._C.DispatchKey.Python):
                                overload_types.append(type(a))
                        # TODO: check that I got these args correct (in C++, we pass in "0000"??)
                        return curr_mode.__torch_dispatch__(
                            self, overload_types, args, kwargs
                        )

                # Note [Not Caching Per-Dispatch-Key Mode Handlers]
                # Note that we're not caching this handler. There isn't really a point, since the slow bit
                # is the handler itself (in python).
                # Also, not caching means that we don't have to reset the cache when any existing
                # modes go out of scope (which in of itself takes time to loop through all operators).
                return handler

        final_key = resolve_key(self, key)

        # See Note [Not Caching Per-Dispatch-Key Mode Handlers]
        cache_result = key != torch._C.DispatchKey.PreDispatch

        # TODO: We could potentially have lots of debugging wrappers against
        # dispatch keys; design some general registration mechanism instead of
        # having an if statement for each of them
        if key == torch._C.DispatchKey.Functionalize:
            import torch._dispatch.python as pydispatch

            if pydispatch.CROSSREF_FUNCTIONALIZE:
                handler = pydispatch.make_crossref_functionalize(self, final_key)
                if cache_result:
                    self._dispatch_cache[key] = handler
                    add_cached_op(self)
                return handler

        # print(self, key, final_key)
        r = self.py_kernels.get(final_key, final_key)
        if cache_result:
            self._dispatch_cache[key] = r
            add_cached_op(self)
        return r

    def name(self):
        return self._name

    @property
    def overloadpacket(self):
        return self._overloadpacket

    @property
    def op(self):
        return self._op

    @property
    def tags(self):
        return self._tags

    # TODO: add more methods to expose information about input and output arguments


# OpOverloadPacket class contains a pointer to a base unresolved operator that doesn't
# correspond to a specific operator overload.
# You can obtain an OpOverload object from it through attribute query.
class OpOverloadPacket:
    def __init__(self, qualified_op_name, op_name, op, overload_names):
        # These attributes are accessible on the object through the properties
        # defined below but are immutable
        self._qualified_op_name = qualified_op_name
        self.__name__ = op_name
        self._op = op
        self._overload_names = overload_names
        self._dir = []

    # it's a no-op since the OpOverloadPacket object is immutable and must be unique for a given op.
    def __deepcopy__(self, memo=None):
        return self

    def __repr__(self):
        return "<OpOverloadPacket(op='{}.{}')>".format(
            *self._qualified_op_name.split("::")
        )

    def __hash__(self):
        return hash(self._op)

    def __str__(self):
        return "{}.{}".format(*self._qualified_op_name.split("::"))

    @property
    def op(self):
        return self._op

    def __getattr__(self, key):
        # It is not a valid op_name when __file__ is passed in
        if key == "__file__":
            return "torch.ops"

        # ensure that queries for dunder attributes that do not exist on the
        # OpOverloadPacket, but do exist on self._op, do not unnecessarily call
        # `_get_operation_overload` (which is an expensive operation).
        # This is done to prevent any potential slowdown. This list can be extended
        # if there exist other attributes like `__name__` that only exist on self._op
        # and not on the OpOverloadPacket.
        # This is ok since we are guaranteed that an overload name for an aten op can't start with '__'
        try:
            if key.startswith("__"):
                return getattr(self._op, key)
        except AttributeError:
            # for consistency because it seems weird to
            # throw an attribute error with a message containing
            # an object name different from the one the attribute
            # query was performed on.
            raise AttributeError(
                f"'{str(self)}' can't have an overload name beginning with '__' and the "
                f"underlying op {str(self._op)} has no attribute {key} either."
            ) from None

        try:
            # This is ok since we are guaranteed that an overload name for an aten op can't be 'default'
            use_key = "" if key == "default" else key
            # TODO: disallow access to overloads registered by JIT
            op_, op_dk_, tags = torch._C._get_operation_overload(
                self._qualified_op_name, use_key
            )
            schema = torch._C._get_schema(self._qualified_op_name, use_key)
            overload = OpOverload(self, op_, op_dk_, schema, tags)
            # cache the overload object
            setattr(self, key, overload)
            self._dir.append(key)
            return overload
        except RuntimeError:
            raise AttributeError(
                f"The underlying op of '{str(self)}' has no overload name '{key}'"
            ) from None

    def __iter__(self):
        return iter(self._dir)

    def __call__(self_, *args, **kwargs):  # noqa: B902
        # use `self_` to avoid naming collisions with aten ops arguments that
        # are named "self". This way, all the aten ops can be called by kwargs.

        # overloading __call__ to ensure torch.ops.foo.bar()
        # is still callable from JIT
        # We save the function ptr as the `op` attribute on
        # OpOverloadPacket to access it here.
        return self_._op(*args, **(kwargs or {}))

    # TODO: use this to make a __dir__
    def overloads(self):
        return [n if n else "default" for n in self._overload_names]
# Resolution of torch.fn is different from torch.ops.aten.fn:
# torch.fn uses the Python argparser, matches with the
# appropriate schema, and calls into the unboxed version of the method.
# torch.ops.aten.fn resolution is done via the mechanism defined in JIT:
# JIT creates a stack of all the overloads and then tries to match the
# correct one at runtime and always calls into the boxed version of the method.
# Autograd codegen creates VariableType, TracerType,
# inplace or view type and python bindings.
# Aten codegen generates tensor methods for the tensor class.


# _OpNamespace is a subclass of ModuleType because TorchScript
# allows attribute lookups on modules only. Since we want torch.ops.foo.bar()
# to work from script, we need to ensure ops and foo are modules
class _OpNamespace(types.ModuleType):
    """
    An op namespace to dynamically bind Operators into Python.

    Say a user has created a custom Operator called "my_namespace::my_op". To
    call this op, the user will write torch.ops.my_namespace.my_op(...).
    At startup, this operation will not yet be bound into Python. Instead, the
    following sequence of magic tricks will occur:

    1. `torch.ops.my_namespace` will invoke the `__getattr__` magic method
       on the `torch.ops` object, which will create a new `_OpNamespace`
       object called `my_namespace` and set it as an attribute on the `ops`
       object.
    2. `torch.ops.my_namespace.my_op` will then invoke `__getattr__` on
       the `my_namespace` object, which will retrieve the operation via
       `torch.get_operation`, a function bound from C++, and then in a similar
       fashion bind this new object onto the `my_namespace` object.
    3. `torch.ops.my_namespace.my_op(...)` then calls this new operation
       and subsequent accesses will incur no further lookup (the namespace and
       operation will already exist).
    """

    def __init__(self, name):
        super().__init__("torch.ops." + name)
        self.name = name
        self._dir = []

    def __iter__(self):
        return iter(self._dir)

    def __getattr__(self, op_name):
        # It is not a valid op_name when __file__ is passed in
        if op_name == "__file__":
            return "torch.ops"
        elif op_name in ["__origin__", "__self__"]:
            raise AttributeError(
                f"Invalid attribute '{op_name}' for '_OpNamespace' '{self.name}'"
            )

        # Get the op `my_namespace::my_op` if available. This will also check
        # for overloads and raise an exception if there are more than one.
        namespace_name = self.name
        qualified_op_name = f"{namespace_name}::{op_name}"
        try:
            op, overload_names = torch._C._jit_get_operation(qualified_op_name)
            if op is None:
                raise AttributeError(
                    f"'_OpNamespace' '{self.name}' object has no attribute '{op_name}'"
                )
        except RuntimeError as e:
            # Turn this into AttributeError so getattr(obj, key, default)
            # works (this is called by TorchScript with __origin__)
            raise AttributeError(
                f"'_OpNamespace' '{self.name}' object has no attribute '{op_name}'"
            ) from e

        # let the script frontend know that op is identical to the builtin op
        # with qualified_op_name
        torch.jit._builtins._register_builtin(op, qualified_op_name)
        op.__module__ = self.__module__ + "." + namespace_name
        opoverloadpacket = OpOverloadPacket(
            qualified_op_name, op_name, op, overload_names
        )
        opoverloadpacket.__module__ = self.__module__ + "." + namespace_name
        # cache the opoverloadpacket to ensure that each op corresponds to
        # a unique OpOverloadPacket object
        setattr(self, op_name, opoverloadpacket)
        self._dir.append(op_name)
        return opoverloadpacket


class _PyOpNamespace(_OpNamespace):
    def __init__(self, name, ops):
        super().__init__(name)
        self._ops = ops

    def __getattr__(self, name):
        # Following _OpNamespace.__getattr__, we cache the op on the _PyOpNamespace object.
        op = self._ops.get(name, None)
        if op is None:
            raise AttributeError(
                f"'_PyOpNamespace' '{self.name}' object has no attribute '{name}'"
            )
        setattr(self, name, op)
        return op


class _Ops(types.ModuleType):
    __file__ = "_ops.py"

    def __init__(self):
        super().__init__("torch.ops")
        self.loaded_libraries = set()
        self._higher_order_op_namespace = _PyOpNamespace(
            "torch.ops.higher_order", _higher_order_ops
        )
        self._dir = []

    def __getattr__(self, name):
        # Check if the name is a HigherOrderOperator namespace
        if name == "higher_order":
            return self._higher_order_op_namespace

        # Here we are creating `torch.ops.my_namespace`
        namespace = _OpNamespace(name)
        setattr(self, name, namespace)
        self._dir.append(name)
        return namespace

    def __iter__(self):
        return iter(self._dir)

    def import_module(self, module):
        """
        Imports a Python module that has torch.library registrations.

        Generally, to extend PyTorch with custom operators, a user will
        create a Python module whose import triggers registration of
        the custom operators via a torch.ops.load_library call or a call
        to one or more torch.library.* APIs.

        It is unexpected for Python modules to have side effects, so some
        linters and formatters will complain. Use this API to import Python
        modules that contain these torch.library side effects.

        Args:
            module (str): The name of the Python module to import
        """
        importlib.import_module(module)

    def load_library(self, path):
        """
        Loads a shared library from the given path into the current process.

        The library being loaded may run global initialization code to register
        custom operators with the PyTorch JIT runtime. This allows dynamically
        loading custom operators. For this, you should compile your operator
        and the static registration code into a shared library object, and then
        call ``torch.ops.load_library('path/to/libcustom.so')`` to load the
        shared object.

        After the library is loaded, it is added to the
        ``torch.ops.loaded_libraries`` attribute, a set that may be inspected
        for the paths of all libraries loaded using this function.

        Args:
            path (str): A path to a shared library to load.
        """
        if torch._running_with_deploy():
            return

        path = _utils_internal.resolve_library_path(path)
        with dl_open_guard():
            # Import the shared library into the process, thus running its
            # static (global) initialization code in order to register custom
            # operators with the JIT.
            ctypes.CDLL(path)
        self.loaded_libraries.add(path)


# The ops "namespace"
ops = _Ops()
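
# Usage sketch (illustrative only; the library path and module name below are
# hypothetical placeholders):
#
#     # torch.ops.load_library("build/libmy_custom_ops.so")  # dlopen + static registration
#     # torch.ops.import_module("my_package._custom_ops")    # triggers torch.library registrations
#     # torch.ops.my_namespace.my_op(torch.randn(3))          # resolved via _OpNamespace.__getattr__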