Spaces:
Running
Running
import torch | |
import torch._prims_common as utils | |
# Utilities should come BEFORE this import | |
from torch._decomp import register_decomposition | |
from torch._prims_common import TensorLikeType | |
from torch._prims_common.wrappers import out_wrapper | |
from torch._refs import _broadcast_shapes | |
# Data conversion references. | |
# | |
# Note: this module breaks the usual _refs to torch naming scheme where | |
# _refs.foo.bar is a ref for torch.foo.bar. The following definitions are not | |
# part of _refs/__init__.py to avoid name clashes with Python builtin types | |
# (like int). | |
__all__ = [ | |
# dtypes | |
"bfloat16", | |
"bool", | |
"byte", | |
"cdouble", | |
"cfloat", | |
"chalf", | |
"char", | |
"double", | |
"float", | |
"half", | |
"int", | |
"long", | |
"short", | |
# misc | |
"complex", | |
"polar", | |
] | |
def _make_conversion_method(name: str, dtype: torch.dtype): | |
def fn( | |
self: TensorLikeType, memory_format: torch.memory_format = torch.preserve_format | |
) -> TensorLikeType: | |
return self.to(dtype, memory_format=memory_format) # type: ignore[call-overload] | |
fn.__name__ = name | |
return fn | |
bfloat16 = _make_conversion_method("bfloat16", torch.bfloat16) | |
bool = _make_conversion_method("bool", torch.bool) | |
byte = _make_conversion_method("byte", torch.uint8) | |
cdouble = _make_conversion_method("cdouble", torch.cdouble) | |
cfloat = _make_conversion_method("cfloat", torch.cfloat) | |
chalf = _make_conversion_method("chalf", torch.complex32) | |
char = _make_conversion_method("char", torch.int8) | |
double = _make_conversion_method("double", torch.double) | |
float = _make_conversion_method("float", torch.float) | |
half = _make_conversion_method("half", torch.half) | |
int = _make_conversion_method("int", torch.int) | |
long = _make_conversion_method("long", torch.long) | |
short = _make_conversion_method("short", torch.short) | |
# Note: complex has type promotion tests disabled due to different semantics. | |
# exact_dtype is for compat with complex_check_dtype from core. | |
def complex(real: TensorLikeType, imag: TensorLikeType) -> TensorLikeType: | |
allowed_dtypes = (torch.float32, torch.float64, torch.float16) | |
torch._check( | |
real.dtype in allowed_dtypes and imag.dtype in allowed_dtypes, | |
lambda: ( | |
f"Expected both inputs to be Half, Float or Double tensors but got " | |
f"{real.dtype} and {imag.dtype}" | |
), | |
) | |
torch._check( | |
real.dtype == imag.dtype, | |
lambda: ( | |
f"Expected object of scalar type {real.dtype} but got " | |
f"scalar type {imag.dtype} for second argument" | |
), | |
) | |
result_dtype = utils.corresponding_complex_dtype(real.dtype) # type: ignore[arg-type] | |
common_shape = _broadcast_shapes(real.shape, imag.shape) | |
result = real.new_empty( | |
common_shape, | |
dtype=result_dtype, | |
layout=real.layout, | |
device=real.device, | |
# pin_memory=real.is_pinned(), # NYI | |
) | |
result.real = real | |
result.imag = imag | |
return result | |
# Note: polar has type promotion tests disabled due to different semantics. | |
# exact_dtype is for compat with complex_check_dtype from core. | |
def polar(abs: TensorLikeType, angle: TensorLikeType) -> TensorLikeType: | |
result = torch.complex(abs, angle) | |
result.real = abs * torch.cos(angle) | |
result.imag = abs * torch.sin(angle) | |
return result | |