Spaces:
Sleeping
Sleeping
import math | |
from typing import Iterable, List, Literal, NamedTuple, Optional, Sequence, Tuple, Union | |
import torch | |
import torch._prims as prims | |
import torch._prims_common as utils | |
from torch._decomp import register_decomposition | |
from torch._prims_common import DimsType, ShapeType, TensorLikeType | |
from torch._prims_common.wrappers import _maybe_convert_to_dtype, out_wrapper | |
__all__ = [
    # Forward transforms: 1-D, 2-D and N-D variants of each family
    "fft", "fft2", "fftn",
    "hfft", "hfft2", "hfftn",
    "rfft", "rfft2", "rfftn",
    # Inverse transforms: 1-D, 2-D and N-D variants of each family
    "ifft", "ifft2", "ifftn",
    "ihfft", "ihfft2", "ihfftn",
    "irfft", "irfft2", "irfftn",
    # Helpers
    "fftshift",
    "ifftshift",
]

# Values accepted for the ``norm`` argument of every transform. ``None`` is
# treated the same as "backward" by ``_apply_norm``.
NormType = Union[None, Literal["forward", "backward", "ortho"]]
_NORM_VALUES = {None, "forward", "backward", "ortho"}

# Handle on the aten operator namespace.
aten = torch._ops.ops.aten
def _apply_norm(
    x: TensorLikeType, norm: NormType, signal_numel: int, forward: bool
) -> TensorLikeType:
    """Scale an un-normalized FFT result according to ``norm``.

    Args:
        x: raw output of an FFT primitive.
        norm: one of ``None``, ``"forward"``, ``"backward"`` or ``"ortho"``;
            anything else fails a ``torch._check``.
        signal_numel: number of points in the transformed signal.
        forward: ``True`` for a forward transform, ``False`` for an inverse.

    Returns:
        ``x * (1 / sqrt(signal_numel))`` for ``"ortho"``; ``x * (1 / signal_numel)``
        when the mode places the full scaling on this direction; otherwise
        ``x`` unchanged.
    """
    torch._check(norm in _NORM_VALUES, lambda: f"Invalid normalization mode: {norm}")
    if norm == "ortho":
        return x * (1 / math.sqrt(signal_numel))

    # "backward" (also spelled None) scales the inverse direction only;
    # "forward" scales the forward direction only.
    if forward:
        scale_here = norm == "forward"
    else:
        scale_here = norm is None or norm == "backward"

    if not scale_here:
        return x
    return x * (1 / signal_numel)
def _promote_type_fft( | |
dtype: torch.dtype, require_complex: bool, device: torch.device | |
) -> torch.dtype: | |
"""Helper to promote a dtype to one supported by the FFT primitives""" | |
if dtype.is_complex: | |
return dtype | |
# Promote integral to default float type | |
if not dtype.is_floating_point: | |
dtype = torch.get_default_dtype() | |
allowed_types = [torch.float32, torch.float64] | |
maybe_support_half = device.type in ["cuda", "meta"] | |
if maybe_support_half: | |
allowed_types.append(torch.float16) | |
torch._check(dtype in allowed_types, lambda: f"Unsupported dtype {dtype}") | |
if require_complex: | |
dtype = utils.corresponding_complex_dtype(dtype) | |
return dtype | |
def _maybe_promote_tensor_fft(
    t: TensorLikeType, require_complex: bool = False
) -> TensorLikeType:
    """Convert ``t`` to the dtype selected by ``_promote_type_fft``.

    Args:
        t: input tensor; its dtype and device drive the promotion.
        require_complex: forwarded to ``_promote_type_fft``; when set, the
            target dtype is complex.

    Returns:
        The result of ``_maybe_convert_to_dtype(t, target_dtype)``.
    """
    target_dtype = _promote_type_fft(t.dtype, require_complex, t.device)
    return _maybe_convert_to_dtype(t, target_dtype)  # type: ignore[return-value]
def _resize_fft_input( | |
x: TensorLikeType, dims: Tuple[int, ...], sizes: Tuple[int, ...] | |
) -> TensorLikeType: | |
""" | |
Fixes the shape of x such that x.size(dims[i]) == sizes[i], | |
either by zero-padding, or by slicing x starting from 0. | |
""" | |
assert len(dims) == len(sizes) | |
must_copy = False | |
x_sizes = x.shape | |
pad_amount = [0] * len(x_sizes) * 2 | |
for i in range(len(dims)): | |
if sizes[i] == -1: | |
continue | |
if x_sizes[dims[i]] < sizes[i]: | |
must_copy = True | |
pad_idx = len(pad_amount) - 2 * dims[i] - 1 | |
pad_amount[pad_idx] = sizes[i] - x_sizes[dims[i]] | |
if x_sizes[dims[i]] > sizes[i]: | |
x = x.narrow(dims[i], 0, sizes[i]) | |
return torch.constant_pad_nd(x, pad_amount) if must_copy else x | |
def _fft_c2r(
    func_name: str,
    input: TensorLikeType,
    n: Optional[int],
    dim: int,
    norm: NormType,
    forward: bool,
) -> TensorLikeType:
    """Shared 1-D complex-to-real transform used by ``irfft`` and ``hfft``.

    Args:
        func_name: public name of the caller (not referenced in this body).
        input: complex spectrum; promoted to a complex dtype first.
        n: output length along ``dim``. Defaults to
            ``2 * (input.shape[dim] - 1)``. When given, the input is resized
            along ``dim`` to ``n // 2 + 1`` entries.
        dim: dimension to transform.
        norm: normalization mode (see ``_apply_norm``).
        forward: ``True`` (hfft) conjugates the input before ``prims.fft_c2r``;
            ``False`` (irfft) does not. Also selects the normalization side.

    Returns:
        The normalized output of ``prims.fft_c2r`` with ``last_dim_size``
        equal to the output length.
    """
    x = _maybe_promote_tensor_fft(input, require_complex=True)
    dims = (utils.canonicalize_dim(x.ndim, dim, wrap_scalar=False),)
    if n is None:
        last_dim_size = 2 * (x.shape[dim] - 1)
    else:
        last_dim_size = n
    torch._check(
        last_dim_size >= 1,
        lambda: f"Invalid number of data points ({last_dim_size}) specified",
    )
    if n is not None:
        # Trim/pad the input to the n // 2 + 1 entries consumed for length n.
        x = _resize_fft_input(x, dims=dims, sizes=(last_dim_size // 2 + 1,))
    if forward:
        x = torch.conj(x)
    result = prims.fft_c2r(x, dim=dims, last_dim_size=last_dim_size)
    return _apply_norm(result, norm=norm, signal_numel=last_dim_size, forward=forward)
def _fft_r2c(
    func_name: str,
    input: TensorLikeType,
    n: Optional[int],
    dim: int,
    norm: NormType,
    forward: bool,
    onesided: bool,
) -> TensorLikeType:
    """Shared 1-D real-to-complex transform (rfft, ihfft, and fft/ifft of real input).

    Args:
        func_name: public name of the caller, used in the error message.
        input: non-complex tensor; promoted via ``_maybe_promote_tensor_fft``.
        n: signal length along ``dim``; the input is resized to it when given.
            Defaults to the current size of ``dim``.
        dim: dimension to transform.
        norm: normalization mode (see ``_apply_norm``).
        forward: ``False`` conjugates the normalized result before returning.
        onesided: forwarded to ``prims.fft_r2c``.
    """
    torch._check(
        not input.dtype.is_complex,
        lambda: f"{func_name} expects a floating point input tensor, but got {input.dtype}",
    )
    real = _maybe_promote_tensor_fft(input)
    dims = (utils.canonicalize_dim(real.ndim, dim, wrap_scalar=False),)
    if n is None:
        dim_size = real.shape[dim]
    else:
        dim_size = n
    torch._check(
        dim_size >= 1, lambda: f"Invalid number of data points ({dim_size}) specified"
    )
    if n is not None:
        real = _resize_fft_input(real, dims, (n,))
    spectrum = prims.fft_r2c(real, dim=dims, onesided=onesided)
    spectrum = _apply_norm(spectrum, norm, dim_size, forward)
    if forward:
        return spectrum
    # The inverse direction (ifft/ihfft of real input) returns the conjugate.
    return torch.conj(spectrum)
def _fft_c2c(
    func_name: str,
    input: TensorLikeType,
    n: Optional[int],
    dim: int,
    norm: NormType,
    forward: bool,
) -> TensorLikeType:
    """Shared 1-D complex-to-complex transform (fft/ifft of complex input).

    No dtype promotion happens here: ``input`` must already be complex.

    Args:
        func_name: public name of the caller, used in the error message.
        input: complex tensor to transform.
        n: signal length along ``dim``; the input is resized to it when given.
            Defaults to the current size of ``dim``.
        dim: dimension to transform.
        norm: normalization mode (see ``_apply_norm``).
        forward: direction passed to ``prims.fft_c2c`` and ``_apply_norm``.
    """
    torch._check(
        input.dtype.is_complex,
        lambda: f"{func_name} expects a complex input tensor, but got {input.dtype}",
    )
    axis = utils.canonicalize_dim(input.ndim, dim, wrap_scalar=False)
    dim_size = input.shape[dim] if n is None else n
    torch._check(
        dim_size >= 1, lambda: f"Invalid number of data points ({dim_size}) specified"
    )
    signal = input if n is None else _resize_fft_input(input, (axis,), (n,))
    spectrum = prims.fft_c2c(signal, dim=(axis,), forward=forward)
    return _apply_norm(spectrum, norm, dim_size, forward)
def fft(
    input: TensorLikeType,
    n: Optional[int] = None,
    dim: int = -1,
    norm: NormType = None,
) -> TensorLikeType:
    """1-D forward FFT.

    Complex inputs go through ``_fft_c2c``; other inputs go through
    ``_fft_r2c`` with ``onesided=False``.
    """
    if not input.dtype.is_complex:
        return _fft_r2c("fft", input, n, dim, norm, forward=True, onesided=False)
    return _fft_c2c("fft", input, n, dim, norm, forward=True)
def ifft(
    input: TensorLikeType,
    n: Optional[int] = None,
    dim: int = -1,
    norm: NormType = None,
) -> TensorLikeType:
    """1-D inverse FFT.

    Complex inputs go through ``_fft_c2c``; other inputs go through
    ``_fft_r2c`` with ``onesided=False``. Both use ``forward=False``.
    """
    if not input.dtype.is_complex:
        return _fft_r2c("ifft", input, n, dim, norm, forward=False, onesided=False)
    return _fft_c2c("ifft", input, n, dim, norm, forward=False)
def rfft(
    input: TensorLikeType,
    n: Optional[int] = None,
    dim: int = -1,
    norm: NormType = None,
) -> TensorLikeType:
    """1-D forward FFT of a non-complex input, returning a one-sided result."""
    return _fft_r2c(
        func_name="rfft",
        input=input,
        n=n,
        dim=dim,
        norm=norm,
        forward=True,
        onesided=True,
    )
def irfft(
    input: TensorLikeType,
    n: Optional[int] = None,
    dim: int = -1,
    norm: NormType = None,
) -> TensorLikeType:
    """1-D complex-to-real inverse transform (``_fft_c2r`` with ``forward=False``)."""
    return _fft_c2r(
        func_name="irfft",
        input=input,
        n=n,
        dim=dim,
        norm=norm,
        forward=False,
    )
def hfft(
    input: TensorLikeType,
    n: Optional[int] = None,
    dim: int = -1,
    norm: NormType = None,
) -> TensorLikeType:
    """1-D complex-to-real forward transform (``_fft_c2r`` with ``forward=True``)."""
    return _fft_c2r(
        func_name="hfft",
        input=input,
        n=n,
        dim=dim,
        norm=norm,
        forward=True,
    )
def ihfft(
    input: TensorLikeType,
    n: Optional[int] = None,
    dim: int = -1,
    norm: NormType = None,
) -> TensorLikeType:
    """1-D real-to-complex inverse transform with a one-sided result."""
    return _fft_r2c(
        func_name="ihfft",
        input=input,
        n=n,
        dim=dim,
        norm=norm,
        forward=False,
        onesided=True,
    )
class _ShapeAndDims(NamedTuple):
    """Return type of ``_canonicalize_fft_shape_and_dim_args``."""

    # Signal size to use along each transformed dimension.
    shape: Tuple[int, ...]
    # Dimensions to transform, paired positionally with ``shape``.
    dims: Tuple[int, ...]
def _canonicalize_fft_shape_and_dim_args(
    input: TensorLikeType, shape: Optional[ShapeType], dim: Optional[DimsType]
) -> _ShapeAndDims:
    """Convert the shape and dim arguments into a canonical form where neither are optional.

    Args:
        input: tensor being transformed; only ``input.ndim`` and
            ``input.shape`` are read.
        shape: requested signal sizes (``s`` in the public API). A
            non-sequence value is wrapped in a 1-tuple. ``-1`` entries mean
            "use the input's size along that dimension".
        dim: dimensions to transform. A non-sequence value is wrapped in a
            1-tuple and the result is normalized by ``utils.canonicalize_dims``.

    Returns:
        ``_ShapeAndDims(shape, dims)`` with one size per transformed dim:
          * shape and dim given: sizes from ``shape`` (``-1`` -> input size);
          * only shape given: dims are the last ``len(shape)`` dimensions;
          * only dim given: sizes are the input sizes along those dims;
          * neither given: every dimension at its input size.

    A ``torch._check`` fails if the dims repeat, if ``dim`` and ``shape``
    differ in length, if ``shape`` has more entries than ``input`` has
    dimensions, or if any resulting size is not positive. The checks run in
    that order.
    """
    input_dim = input.ndim
    input_sizes = input.shape
    if dim is not None:
        if not isinstance(dim, Sequence):
            dim = (dim,)
        ret_dims = utils.canonicalize_dims(input_dim, dim, wrap_scalar=False)
        # Check dims are unique
        torch._check(
            len(set(ret_dims)) == len(ret_dims), lambda: "FFT dims must be unique"
        )
    if shape is not None:
        if not isinstance(shape, Sequence):
            shape = (shape,)
        # Has shape, might have dim
        torch._check(
            dim is None or len(dim) == len(shape),
            lambda: "When given, dim and shape arguments must have the same length",
        )
        transform_ndim = len(shape)
        torch._check(
            transform_ndim <= input_dim,
            lambda: f"Got shape with {transform_ndim} values but input tensor "
            f"only has {input_dim} dimensions.",
        )
        # If shape is given, dims defaults to the last len(shape) dimensions
        if dim is None:
            ret_dims = tuple(range(input_dim - transform_ndim, input_dim))
        # Translate any -1 values in shape to the default length
        ret_shape = tuple(
            s if s != -1 else input_sizes[d] for (s, d) in zip(shape, ret_dims)  # type: ignore[possibly-undefined]
        )
    elif dim is None:
        # No shape, no dim
        ret_dims = tuple(range(input_dim))
        ret_shape = tuple(input_sizes)
    else:
        # No shape, has dim
        ret_shape = tuple(input_sizes[d] for d in ret_dims)  # type: ignore[possibly-undefined]
    # Every size must be positive, whether it came from ``shape`` or from the
    # input itself (e.g. a zero-length input dimension is rejected here).
    for n in ret_shape:
        torch._check(n > 0, lambda: f"Invalid number of data points ({n}) specified")
    return _ShapeAndDims(shape=ret_shape, dims=ret_dims)  # type: ignore[possibly-undefined]
def _prod(xs: Iterable[int]) -> int: | |
"""Compute product of a list""" | |
prod = 1 | |
for x in xs: | |
prod *= x | |
return prod | |
def _fftn_c2c(
    function_name: str,
    input: TensorLikeType,
    shape: Tuple[int, ...],
    dim: Tuple[int, ...],
    norm: NormType,
    forward: bool,
) -> TensorLikeType:
    """Shared N-D complex-to-complex transform used by ``fftn`` and ``ifftn``.

    Args:
        function_name: public name of the caller, used in the error message.
        input: complex tensor (a ``torch._check`` fails otherwise).
        shape: canonical signal sizes, one per entry of ``dim``.
        dim: canonical dimensions to transform.
        norm: normalization mode (see ``_apply_norm``).
        forward: direction passed to ``prims.fft_c2c`` and ``_apply_norm``.

    Returns:
        The transform normalized by the total signal size ``_prod(shape)``.
    """
    torch._check(
        input.dtype.is_complex,
        lambda: f"{function_name} expects a complex input tensor, "
        f"but got {input.dtype}",
    )
    total_points = _prod(shape)
    resized = _resize_fft_input(input, dim, shape)
    transformed = prims.fft_c2c(resized, dim=dim, forward=forward)
    return _apply_norm(
        transformed, norm=norm, signal_numel=total_points, forward=forward
    )
def fftn(
    input: TensorLikeType,
    s: Optional[ShapeType] = None,
    dim: Optional[DimsType] = None,
    norm: NormType = None,
) -> TensorLikeType:
    """N-D forward FFT; the input is promoted to a complex dtype first."""
    canonical = _canonicalize_fft_shape_and_dim_args(input, s, dim)
    promoted = _maybe_promote_tensor_fft(input, require_complex=True)
    return _fftn_c2c(
        "fftn", promoted, canonical.shape, canonical.dims, norm, forward=True
    )
def ifftn(
    input: TensorLikeType,
    s: Optional[ShapeType] = None,
    dim: Optional[DimsType] = None,
    norm: NormType = None,
) -> TensorLikeType:
    """N-D inverse FFT; the input is promoted to a complex dtype first."""
    canonical = _canonicalize_fft_shape_and_dim_args(input, s, dim)
    promoted = _maybe_promote_tensor_fft(input, require_complex=True)
    return _fftn_c2c(
        "ifftn", promoted, canonical.shape, canonical.dims, norm, forward=False
    )
def rfftn(
    input: TensorLikeType,
    s: Optional[ShapeType] = None,
    dim: Optional[DimsType] = None,
    norm: NormType = None,
) -> TensorLikeType:
    """N-D forward FFT of a non-complex input with a one-sided result.

    The input is promoted to a supported real dtype, resized to the canonical
    shape, transformed with ``prims.fft_r2c(onesided=True)`` over all selected
    dims, and normalized by the total signal size.
    """
    torch._check(
        not input.dtype.is_complex,
        lambda: f"rfftn expects a real-valued input tensor, but got {input.dtype}",
    )
    sizes, dims = _canonicalize_fft_shape_and_dim_args(input, s, dim)
    x = _maybe_promote_tensor_fft(input, require_complex=False)
    x = _resize_fft_input(x, dims, sizes)
    spectrum = prims.fft_r2c(x, dim=dims, onesided=True)
    return _apply_norm(spectrum, norm=norm, signal_numel=_prod(sizes), forward=True)
def ihfftn(
    input: TensorLikeType,
    s: Optional[ShapeType] = None,
    dim: Optional[DimsType] = None,
    norm: NormType = None,
) -> TensorLikeType:
    """N-D inverse transform of a non-complex input with a one-sided last dim.

    The last selected dim goes through ``prims.fft_r2c(onesided=True)``.
    With a single dim the result is normalized and then conjugated via
    ``prims.conj``. With several dims it is conjugated
    (``prims.conj_physical``), passed through an inverse ``prims.fft_c2c``
    over the remaining dims, and normalized by the total signal size.
    """
    torch._check(
        not input.dtype.is_complex,
        lambda: f"ihfftn expects a real-valued input tensor, but got {input.dtype}",
    )
    sizes, dims = _canonicalize_fft_shape_and_dim_args(input, s, dim)
    torch._check(len(sizes) > 0, lambda: "ihfftn must transform at least one axis")
    x = _maybe_promote_tensor_fft(input, require_complex=False)
    x = _resize_fft_input(x, dims, sizes)

    half = prims.fft_r2c(x, dim=dims[-1:], onesided=True)
    if len(dims) > 1:
        half = prims.conj_physical(half)
        half = prims.fft_c2c(half, dim=dims[:-1], forward=False)
        return _apply_norm(half, norm=norm, signal_numel=_prod(sizes), forward=False)

    half = _apply_norm(half, norm=norm, signal_numel=sizes[0], forward=False)
    return prims.conj(half)
class _CanonicalizeC2rReturn(NamedTuple):
    """Return type of ``_canonicalize_fft_c2r_shape_and_dim_args``."""

    # Sizes the complex input is resized to along ``dim``; the last entry is
    # ``last_dim_size // 2 + 1``.
    shape: Tuple[int, ...]
    # Dimensions to transform.
    dim: Tuple[int, ...]
    # Length of the real output along ``dim[-1]``.
    last_dim_size: int
def _canonicalize_fft_c2r_shape_and_dim_args(
    fname: str,
    input: TensorLikeType,
    s: Optional[ShapeType],
    dim: Optional[DimsType],
) -> _CanonicalizeC2rReturn:
    """Canonicalize shape and dim arguments for n-dimensional c2r transforms.

    Besides the canonical shape and dims from
    ``_canonicalize_fft_shape_and_dim_args``, this computes ``last_dim_size``,
    the length of the real output along ``dim[-1]``:

    * taken from ``s[-1]`` when ``s`` is given and its last entry is not -1;
    * otherwise ``2 * (input.shape[dim[-1]] - 1)``.

    The returned ``shape`` has its last entry replaced by
    ``last_dim_size // 2 + 1``, the number of complex entries read along
    ``dim[-1]``.

    Args:
        fname: public name of the caller, used in error messages.
        input: complex input tensor (only its shape is read here).
        s: requested output signal sizes; a bare (non-sequence) value is
            accepted, consistent with ``_canonicalize_fft_shape_and_dim_args``.
        dim: dimensions to transform.

    A ``torch._check`` fails if no dimension is transformed or if
    ``last_dim_size`` is smaller than 1.
    """
    (shape, dim) = _canonicalize_fft_shape_and_dim_args(input, s, dim)
    torch._check(len(shape) > 0, lambda: f"{fname} must transform at least one axis")

    # The helper above accepts a bare value for ``s`` by wrapping it in a
    # 1-tuple; do the same here so ``s[-1]`` does not fail on a scalar.
    if s is not None and not isinstance(s, Sequence):
        s = (s,)
    if s is None or s[-1] == -1:
        last_dim_size = 2 * (input.shape[dim[-1]] - 1)
    else:
        last_dim_size = shape[-1]

    torch._check(
        last_dim_size >= 1,
        lambda: f"Invalid number of data points ({last_dim_size}) specified",
    )

    shape_list = list(shape)
    shape_list[-1] = last_dim_size // 2 + 1
    return _CanonicalizeC2rReturn(
        shape=tuple(shape_list), dim=dim, last_dim_size=last_dim_size
    )
def irfftn(
    input: TensorLikeType,
    s: Optional[ShapeType] = None,
    dim: Optional[DimsType] = None,
    norm: NormType = None,
) -> TensorLikeType:
    """N-D complex-to-real inverse transform.

    The input is promoted to complex, resized to the canonical c2r shape, and
    transformed by ``prims.fft_c2r`` over all selected dims. Normalization
    uses the product of the OUTPUT sizes along those dims.
    """
    canon = _canonicalize_fft_c2r_shape_and_dim_args("irfftn", input, s, dim)
    x = _maybe_promote_tensor_fft(input, require_complex=True)
    x = _resize_fft_input(x, canon.dim, canon.shape)
    out = prims.fft_c2r(x, dim=canon.dim, last_dim_size=canon.last_dim_size)
    output_points = _prod(out.shape[d] for d in canon.dim)
    return _apply_norm(out, norm, output_points, forward=False)
def hfftn(
    input: TensorLikeType,
    s: Optional[ShapeType] = None,
    dim: Optional[DimsType] = None,
    norm: NormType = None,
) -> TensorLikeType:
    """N-D complex-to-real forward transform.

    All dims except the last go through a forward ``prims.fft_c2c`` (skipped
    for a single dim) and are normalized by their sizes. The result is then
    conjugated and passed through ``prims.fft_c2r`` over the last dim, which
    is normalized by ``last_dim_size``.
    """
    sizes, dims, last_dim_size = _canonicalize_fft_c2r_shape_and_dim_args(
        "hfftn", input, s, dim
    )
    x = _maybe_promote_tensor_fft(input, require_complex=True)
    x = _resize_fft_input(x, dims, sizes)

    if len(dims) > 1:
        x = prims.fft_c2c(x, dim=dims[:-1], forward=True)
    x = _apply_norm(x, norm, _prod(sizes[:-1]), forward=True)

    x = prims.conj_physical(x)
    out = prims.fft_c2r(x, dim=dims[-1:], last_dim_size=last_dim_size)
    return _apply_norm(out, norm, last_dim_size, forward=True)
def fft2(
    input: TensorLikeType,
    s: Optional[ShapeType] = None,
    dim: Optional[DimsType] = (-2, -1),
    norm: NormType = None,
) -> TensorLikeType:
    """2-D forward FFT: ``torch.fft.fftn`` defaulting to the last two dims."""
    return torch.fft.fftn(input, norm=norm, dim=dim, s=s)
def ifft2(
    input: TensorLikeType,
    s: Optional[ShapeType] = None,
    dim: Optional[DimsType] = (-2, -1),
    norm: NormType = None,
) -> TensorLikeType:
    """2-D inverse FFT: ``torch.fft.ifftn`` defaulting to the last two dims."""
    return torch.fft.ifftn(input, norm=norm, dim=dim, s=s)
def rfft2(
    input: TensorLikeType,
    s: Optional[ShapeType] = None,
    dim: Optional[DimsType] = (-2, -1),
    norm: NormType = None,
) -> TensorLikeType:
    """2-D one-sided forward FFT: ``torch.fft.rfftn`` over the last two dims by default."""
    return torch.fft.rfftn(input, norm=norm, dim=dim, s=s)
def irfft2(
    input: TensorLikeType,
    s: Optional[ShapeType] = None,
    dim: Optional[DimsType] = (-2, -1),
    norm: NormType = None,
) -> TensorLikeType:
    """2-D complex-to-real inverse FFT: ``torch.fft.irfftn`` over the last two dims by default."""
    return torch.fft.irfftn(input, norm=norm, dim=dim, s=s)
def hfft2(
    input: TensorLikeType,
    s: Optional[ShapeType] = None,
    dim: Optional[DimsType] = (-2, -1),
    norm: NormType = None,
) -> TensorLikeType:
    """2-D complex-to-real forward FFT: ``torch.fft.hfftn`` over the last two dims by default."""
    return torch.fft.hfftn(input, norm=norm, dim=dim, s=s)
def ihfft2(
    input: TensorLikeType,
    s: Optional[ShapeType] = None,
    dim: Optional[DimsType] = (-2, -1),
    norm: NormType = None,
) -> TensorLikeType:
    """2-D real-to-complex inverse FFT: ``torch.fft.ihfftn`` over the last two dims by default."""
    return torch.fft.ihfftn(input, norm=norm, dim=dim, s=s)
def _default_alldims(dim: Optional[DimsType], x: TensorLikeType) -> List[int]: | |
"""Convert Optional[DimsType] to a simple list, defaulting to all dimensions""" | |
if dim is None: | |
return list(range(x.ndim)) | |
elif not isinstance(dim, Sequence): | |
return [dim] | |
else: | |
return list(dim) | |
def fftshift(input: TensorLikeType, dim: Optional[DimsType] = None) -> TensorLikeType:
    """Roll each selected dimension forward by ``size // 2``.

    Selects all dimensions when ``dim`` is None (see ``_default_alldims``).
    """
    axes = _default_alldims(dim, input)
    offsets = [input.shape[axis] // 2 for axis in axes]
    return torch.roll(input, offsets, axes)
def ifftshift(input: TensorLikeType, dim: Optional[DimsType] = None) -> TensorLikeType:
    """Roll each selected dimension forward by ``(size + 1) // 2``.

    Selects all dimensions when ``dim`` is None (see ``_default_alldims``).
    """
    axes = _default_alldims(dim, input)
    offsets = [(input.shape[axis] + 1) // 2 for axis in axes]
    return torch.roll(input, offsets, axes)