import inspect
import warnings
from functools import wraps
from itertools import chain
from typing import Callable, NamedTuple, Optional, overload, Sequence, Tuple

import torch
import torch._prims_common as utils
from torch._prims_common import (
    CustomOutParamAnnotation,
    ELEMENTWISE_TYPE_PROMOTION_KIND,
    Number,
    NumberType,
    ShapeType,
    TensorLike,
    TensorLikeType,
)
from torch.utils import _pytree as pytree
from torch.utils._pytree import tree_flatten, tree_unflatten

@overload
def _maybe_convert_to_dtype(a: TensorLikeType, dtype: torch.dtype) -> TensorLikeType:
    pass

@overload
def _maybe_convert_to_dtype(a: NumberType, dtype: torch.dtype) -> NumberType:
    pass

@overload
def _maybe_convert_to_dtype(a: Sequence, dtype: torch.dtype) -> Sequence:
    pass

@overload
def _maybe_convert_to_dtype(a: None, dtype: torch.dtype) -> None:
    pass

# TODO: implement ref.cast with an option to enforce safe casting
def _maybe_convert_to_dtype(a, dtype):
    if isinstance(a, TensorLike):
        if a.dtype != dtype:
            return a.to(dtype)
        return a
    if isinstance(a, Number):
        return utils.dtype_to_type_ctor(dtype)(a)  # type: ignore[arg-type]
    if isinstance(a, Sequence):
        return tuple(_maybe_convert_to_dtype(x, dtype) for x in a)
    # Pass through None because some functions wrapped with the type promotion
    # wrapper might have optional args
    if a is None:
        return None

    raise ValueError(f"Received type {type(a)} that is neither a tensor nor a number!")
def _maybe_convert_to_type(a: NumberType, typ: type) -> NumberType:
    if not isinstance(a, Number):
        msg = f"Found unknown type {type(a)} when trying to convert scalars!"
        raise ValueError(msg)
    if not utils.is_weakly_lesser_type(type(a), typ):
        msg = f"Scalar {a} of type {type(a)} cannot be safely cast to type {typ}!"
        raise ValueError(msg)

    return typ(a)
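
# Example (illustrative): int -> float is a weakly "lesser" conversion and
# succeeds, while float -> int can lose information and is rejected.
#
#   _maybe_convert_to_type(2, float)  # 2.0
#   _maybe_convert_to_type(2.5, int)  # raises ValueError
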
def _annotation_has_type(*, typ, annotation):
    if hasattr(annotation, "__args__"):
        for a in annotation.__args__:
            if _annotation_has_type(typ=typ, annotation=a):
                return True
        return False

    return typ is annotation
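
# Example (illustrative): the helper recurses through annotations that carry
# __args__ (e.g. Optional/Union/Tuple), so a nested match is found.
#
#   _annotation_has_type(typ=int, annotation=Optional[int])      # True
#   _annotation_has_type(typ=int, annotation=Tuple[str, float])  # False
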
class elementwise_type_promotion_wrapper:
    """
    Adds elementwise type promotion to a Python reference implementation.

    Takes two kwargs, type_promoting_args and type_promotion_kind.

    type_promoting_args must be a string Sequence specifying the argument names of all
    arguments that participate in type promotion (and should be type promoted). If the
    arg specifies a Sequence-type then every element of the Sequence will participate in
    type promotion.

    type_promotion_kind must be one of the kinds specified by ELEMENTWISE_TYPE_PROMOTION_KIND.
    See its documentation for details.

    The result dtype will be coerced to the wrapped function's dtype arg if it is available and
    not None.

    Other type promotion behavior, like validating the Python type of scalar arguments, must
    be handled separately.
    """
    def __init__(
        self,
        *,
        type_promotion_kind: ELEMENTWISE_TYPE_PROMOTION_KIND,
        type_promoting_args: Optional[Sequence[str]] = None,
    ):
        self.type_promoting_arg_names = type_promoting_args
        self.type_promotion_kind = type_promotion_kind
    def __call__(self, fn: Callable) -> Callable:
        sig = inspect.signature(fn)

        @wraps(fn)
        def _fn(*args, **kwargs):
            bound = sig.bind(*args, **kwargs)
            type_promoting_args = tuple(
                bound.arguments[x]
                for x in self.type_promoting_arg_names  # type: ignore[union-attr]
                if x in bound.arguments.keys()
            )

            flattened_type_promoting_args = pytree.arg_tree_leaves(*type_promoting_args)
            compute_dtype, result_dtype = utils.elementwise_dtypes(
                *flattened_type_promoting_args,
                type_promotion_kind=self.type_promotion_kind,
            )

            promoted_args = {
                x: _maybe_convert_to_dtype(bound.arguments[x], compute_dtype)
                for x in self.type_promoting_arg_names  # type: ignore[union-attr]
                if x in bound.arguments.keys()
            }
            bound.arguments.update(promoted_args)

            result = fn(**bound.arguments)

            # Override result_dtype if a dtype arg is present and not None
            if "dtype" in bound.arguments:
                maybe_dtype = bound.arguments["dtype"]
                if maybe_dtype:  # guard against an explicit None
                    result_dtype = maybe_dtype

            if isinstance(result, TensorLike):
                return _maybe_convert_to_dtype(result, result_dtype)
            if isinstance(result, Sequence):
                return tuple(_maybe_convert_to_dtype(x, result_dtype) for x in result)
            raise AssertionError(f"Unhandled result type: {type(result)}")

        _fn.__signature__ = sig  # type: ignore[attr-defined]
        return _fn
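
# Example (illustrative sketch): applying the wrapper to a hypothetical
# reference. `my_add` is made up for illustration.
#
#   @elementwise_type_promotion_wrapper(
#       type_promoting_args=("a", "b"),
#       type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
#   )
#   def my_add(a, b):
#       return torch.add(a, b)
#
#   # An int32 tensor plus a Python float promotes both the computation and
#   # the result to the default float dtype:
#   my_add(torch.ones(2, dtype=torch.int32), 1.0).dtype  # torch.float32
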
# Returns True if resize is necessary
def _resize_output_check(out: TensorLikeType, shape: ShapeType) -> bool:
    # If the shapes are correct there's nothing to do
    if utils.same_shape(out.shape, shape):
        return False
    if out.numel() != 0:
        msg = (
            f"An output with one or more elements was resized since it had shape {str(out.shape)} "
            f"which does not match the required output shape {str(shape)}. "
            "This behavior is deprecated, and in a future PyTorch release outputs will not "
            "be resized unless they have zero elements. "
            "You can explicitly reuse an out tensor t by resizing it, inplace, to zero elements with t.resize_(0)."
        )
        warnings.warn(msg)
    return True
# TODO: handle tuples of tensors
def _maybe_resize_out(out: TensorLikeType, shape: ShapeType):
    if _resize_output_check(out, shape):
        return out.resize_(shape)
    else:
        return out
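
# Example (illustrative): a zero-element out tensor is resized silently,
# while a non-empty, mismatched out tensor is resized with a deprecation
# warning.
#
#   out = torch.empty(0)
#   _maybe_resize_out(out, (2, 3))  # out now has shape (2, 3); no warning
#
#   out = torch.empty(4)
#   _maybe_resize_out(out, (2, 3))  # resized, but emits a deprecation warning
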
def _safe_copy_out(
    *, copy_from: TensorLikeType, copy_to: TensorLikeType, exact_dtype: bool = False
):
    # Checks same device
    if copy_from.device != copy_to.device:
        msg = (
            f"Attempting to copy from device {copy_from.device} "
            f"to device {copy_to.device}, but cross-device copies are not allowed!"
        )
        raise RuntimeError(msg)

    # Checks safe cast
    if exact_dtype:
        torch._check(
            copy_from.dtype == copy_to.dtype,
            lambda: f"Expected out tensor to have dtype {copy_from.dtype} "
            f"but got {copy_to.dtype} instead",
        )
    else:
        torch._check(
            utils.can_safe_cast_to(cast_from=copy_from.dtype, cast_to=copy_to.dtype),
            lambda: f"Attempting to cast from {copy_from.dtype} to out tensor with dtype {copy_to.dtype}, "
            "but this can't be cast because it is not safe!",
        )

    return copy_to.copy_(copy_from)
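
# Example (illustrative): float32 -> float64 is a safe cast, so copying into
# a float64 out tensor succeeds; the reverse direction raises.
#
#   src = torch.ones(2, dtype=torch.float32)
#   dst = torch.empty(2, dtype=torch.float64)
#   _safe_copy_out(copy_from=src, copy_to=dst)  # ok
#   _safe_copy_out(copy_from=dst, copy_to=src)  # raises: unsafe cast
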
def out_wrapper(*out_names: str, exact_dtype: bool = False, pass_is_out: bool = False):
    # The wrapped function needs to convert the output parameters to ensure
    # compatibility between the Python API (which always uses "out" as the
    # parameter name, and where it may be a tuple) and the ATen API (which may
    # have multiple output parameters and use different parameter names such
    # as "grad_input", "indices" or "values").
    default_out_names = ("out",)
    if len(out_names) == 0:
        # Use the default out name
        out_names = default_out_names
    is_tensor = len(out_names) == 1

    def _out_wrapper(fn: Callable) -> Callable:
        """
        Adds the out parameter to a Python reference.
        """
        out_type = (
            TensorLikeType
            if is_tensor
            else Tuple[tuple(TensorLikeType for _ in range(len(out_names)))]
        )
        return_type = (
            TensorLikeType
            if is_tensor
            else NamedTuple(
                f"return_types_{fn.__name__}", [(o, TensorLikeType) for o in out_names]
            )
        )

        sig = inspect.signature(fn)
        factory_kwargs = ("device", "dtype")
        is_factory_fn = all(p in sig.parameters for p in factory_kwargs)
        @wraps(fn)
        def _fn(*args, out=None, **kwargs):
            if is_factory_fn and out is not None:
                for k in factory_kwargs:
                    out_attr = getattr(out, k)
                    if k not in kwargs:
                        kwargs[k] = out_attr
            if pass_is_out:
                result = fn(*args, is_out=(out is not None), **kwargs)
            else:
                result = fn(*args, **kwargs)
            assert (isinstance(result, TensorLike) and is_tensor) or (
                isinstance(result, Tuple)  # type: ignore[arg-type]
                and len(result) == len(out_names)
            )
            if out is not None:
                # Naively you might expect this assert to be true, but
                # it's not:
                #
                #   assert type(out) == type(result)
                #
                # The reason is that functions under this wrapper can
                # get registered to the Meta dispatch key, and that
                # means they can be executed in a context where tensor
                # subclasses are disabled (with no_dispatch), which is a
                # handy way for an is-a tensor subclass (e.g.,
                # FakeTensor) to have the normal meta backend create a
                # meta tensor, to be wrapped once it gets returned.
                # In this situation, you will get a FakeTensor as
                # the output tensor, but not the result--which will
                # be a normal meta tensor, but this is perfectly
                # harmless.
                if is_tensor:
                    assert isinstance(out, TensorLike)
                    # These two operations are done in-place
                    _maybe_resize_out(out, result.shape)
                    _safe_copy_out(copy_from=result, copy_to=out, exact_dtype=exact_dtype)  # type: ignore[arg-type]
                else:
                    assert isinstance(out, Tuple)  # type: ignore[arg-type]
                    torch._check_type(
                        len(out) == len(result),
                        lambda: f"expected tuple of {len(result)} elements but got {len(out)}",
                    )
                    for r, o in zip(result, out):
                        # These two operations are done in-place
                        _maybe_resize_out(o, r.shape)
                        _safe_copy_out(copy_from=r, copy_to=o, exact_dtype=exact_dtype)  # type: ignore[arg-type]
            else:
                out = result
            # mypy does not see through the definition of out_type given that it's in a different scope
            return out if is_tensor else return_type(*out)  # type: ignore[operator]
        out_param = inspect.Parameter(
            "out",
            kind=inspect.Parameter.KEYWORD_ONLY,
            default=None,
            annotation=out_type,
        )
        # The wrapped function's return annotation must be absent or already
        # compatible with out_type, since the signature is rewritten below
        assert isinstance(sig.return_annotation, str) or sig.return_annotation in (
            sig.empty,
            out_type,
        )
        params = chain(sig.parameters.values(), (out_param,))
        _fn.__signature__ = inspect.Signature(  # type: ignore[attr-defined]
            parameters=params, return_annotation=return_type  # type: ignore[arg-type]
        )

        _fn.__annotations__ = fn.__annotations__
        _fn.__annotations__["out"] = out_type
        _fn.__annotations__["return"] = return_type

        # In the special case of having a single tensor out parameter with a
        # name other than out, add a special annotation to name the parameter
        if is_tensor and out_names != default_out_names:
            _fn.__annotations__[CustomOutParamAnnotation] = out_names[0]

        # Add an indicator attribute that can be used in special cases
        # where having a function wrapped by `out_wrapper` is not desirable,
        # e.g. jit
        _fn._torch_decompositions_out_wrapper = (  # type: ignore[attr-defined]
            f"This function is wrapped by {out_wrapper.__module__}.out_wrapper"
        )

        return _fn

    return _out_wrapper
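
# Example (illustrative sketch): a hypothetical single-output reference gains
# a keyword-only `out=` parameter. `my_abs` is made up for illustration.
#
#   @out_wrapper()
#   def my_abs(a: TensorLikeType) -> TensorLikeType:
#       return torch.abs(a)
#
#   buf = torch.empty(3)
#   my_abs(torch.tensor([-1.0, 2.0, -3.0]), out=buf)  # result copied into buf
#
# For multi-output ops, pass the ATen output names, e.g.
# `@out_wrapper("values", "indices")`; the wrapped function then returns a
# named tuple and accepts `out=` as a tuple of tensors.
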
def _maybe_remove_out_wrapper(fn: Callable):
    return inspect.unwrap(
        fn,
        stop=lambda f: not hasattr(f, "_torch_decompositions_out_wrapper"),
    )
def backwards_not_supported(prim):
    def redispatch_prim(args, kwargs):
        with torch._C._AutoDispatchBelowAutograd():
            old = torch._C._dispatch_tls_is_dispatch_key_excluded(
                torch._C.DispatchKey.ADInplaceOrView
            )
            return prim(*args, **kwargs)
    class BackwardsNotSupported(torch.autograd.Function):
        @staticmethod
        def forward(ctx, args_spec, *flat_args):
            args, kwargs = tree_unflatten(flat_args, args_spec)  # type: ignore[arg-type]
            return redispatch_prim(args, kwargs)

        @staticmethod
        def backward(ctx, *args):
            raise RuntimeError("backwards not supported on prim")
    def _autograd_impl(*args, **kwargs):
        flat_args, args_spec = tree_flatten((args, kwargs))
        if torch.is_grad_enabled() and any(
            a.requires_grad for a in flat_args if isinstance(a, torch.Tensor)
        ):
            # TODO: There is a subtle bug here: prims like copy_to
            # return their input argument after mutating it; and custom
            # autograd function will incorrectly turn the result into
            # a view which will fail test_python_ref_executor tests.
            # At the moment, we sidestep this by observing that the
            # unit tests don't ever try to run the executor with
            # autograd, so we don't exercise the buggy case, but if
            # you ever want to feed autograd through this, be aware
            # of it! We need a way of properly implementing autograd
            # for mutating operations in Python to do this.
            return BackwardsNotSupported.apply(args_spec, *flat_args)
        else:
            return redispatch_prim(args, kwargs)

    return _autograd_impl
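
# Example (illustrative): the returned callable runs the prim's forward as
# usual, but backpropagating through the result raises. `prims.sin` is used
# as a stand-in for any prim, assuming the prims namespace has been loaded
# (e.g. via `import torch._prims`).
#
#   impl = backwards_not_supported(torch.ops.prims.sin.default)
#   x = torch.randn(3, requires_grad=True)
#   y = impl(x)         # forward succeeds
#   y.sum().backward()  # raises RuntimeError("backwards not supported on prim")
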
# TODO: when tracing this will add torch tensors and not TensorMeta objects
# to the trace -- we should fix this by adding a tracing context and NumberMeta classes
# TODO: this wrapper is currently untested
def elementwise_unary_scalar_wrapper(fn: Callable) -> Callable:
    """
    Allows unary operators that accept tensors to work with Python numbers.
    """
    sig = inspect.signature(fn)

    @wraps(fn)
    def _fn(*args, **kwargs):
        if len(args) > 0 and isinstance(args[0], Number):
            dtype = utils.type_to_dtype(type(args[0]))
            args_ = list(args)
            args_[0] = torch.tensor(args[0], dtype=dtype)
            result = fn(*args_, **kwargs)
            assert isinstance(result, torch.Tensor)
            return result.item()

        return fn(*args, **kwargs)

    _fn.__signature__ = sig  # type: ignore[attr-defined]
    return _fn
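
# Example (illustrative): a wrapped unary reference accepts a plain Python
# number by round-tripping it through a zero-dim tensor and .item().
# `my_neg` is made up for illustration.
#
#   @elementwise_unary_scalar_wrapper
#   def my_neg(a):
#       return torch.neg(a)
#
#   my_neg(2.0)                # -2.0, a Python float
#   my_neg(torch.tensor(2.0))  # tensor(-2.), the ordinary tensor path
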