Spaces:
Sleeping
Sleeping
File size: 2,935 Bytes
c61ccee |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 |
# mypy: ignore-errors
from __future__ import annotations
import functools
import torch
from . import _dtypes_impl, _util
from ._normalizations import ArrayLike, normalizer
def upcast(func):
    """Decorate an fft function so its array argument is cast to the default dtype.

    NumPy fft casts inputs to 64 bit and *returns 64-bit results*. This
    decorator reproduces that. The first parameter of ``func`` (the array
    ``a`` in this module) is cast with ``_util.cast_if_needed``:

    * complex inputs go to ``_dtypes_impl.default_dtypes().complex_dtype``;
    * all other inputs go to ``_dtypes_impl.default_dtypes().float_dtype``.

    The array may be passed positionally or by keyword (e.g. ``fft(a=x)``).
    Either way it is cast before ``func`` is called.
    """
    import inspect

    # Resolve the array parameter's name once, at decoration time, so a
    # keyword-passed array (``fft(a=x)``) can be found and cast as well.
    # The original ``wrapped(tensor, ...)`` signature raised TypeError here.
    first_name = next(iter(inspect.signature(func).parameters), None)

    def _cast(tensor):
        # Pick the default complex dtype for complex input, else the default
        # float dtype (integer/bool/float inputs all become floating point).
        defaults = _dtypes_impl.default_dtypes()
        target_dtype = (
            defaults.complex_dtype if tensor.is_complex() else defaults.float_dtype
        )
        return _util.cast_if_needed(tensor, target_dtype)

    @functools.wraps(func)
    def wrapped(*args, **kwds):
        if args:
            args = (_cast(args[0]),) + args[1:]
        elif first_name in kwds:
            kwds[first_name] = _cast(kwds[first_name])
        # If the array argument is missing entirely, ``func`` raises its own
        # TypeError for the missing argument.
        return func(*args, **kwds)

    return wrapped
@normalizer
@upcast
def fft(a: ArrayLike, n=None, axis=-1, norm=None):
    """One-dimensional discrete Fourier transform of ``a`` along ``axis``.

    ``a`` is first cast by ``upcast`` to the default complex/float dtype. The
    transform is delegated to ``torch.fft.fft``, with ``axis`` forwarded as
    torch's ``dim``.
    """
    return torch.fft.fft(a, n=n, dim=axis, norm=norm)
@normalizer
@upcast
def ifft(a: ArrayLike, n=None, axis=-1, norm=None):
    """One-dimensional inverse discrete Fourier transform of ``a`` along ``axis``.

    ``a`` is first cast by ``upcast``. The transform is delegated to
    ``torch.fft.ifft``, with ``axis`` forwarded as torch's ``dim``.
    """
    return torch.fft.ifft(a, n=n, dim=axis, norm=norm)
@normalizer
@upcast
def rfft(a: ArrayLike, n=None, axis=-1, norm=None):
    """One-dimensional FFT of real input ``a`` along ``axis``.

    ``a`` is first cast by ``upcast``. The transform is delegated to
    ``torch.fft.rfft``, with ``axis`` forwarded as torch's ``dim``.
    """
    return torch.fft.rfft(a, n=n, dim=axis, norm=norm)
@normalizer
@upcast
def irfft(a: ArrayLike, n=None, axis=-1, norm=None):
    """Inverse of ``rfft`` along ``axis``.

    ``a`` is first cast by ``upcast``. The transform is delegated to
    ``torch.fft.irfft``, with ``axis`` forwarded as torch's ``dim``.
    """
    return torch.fft.irfft(a, n=n, dim=axis, norm=norm)
@normalizer
@upcast
def fftn(a: ArrayLike, s=None, axes=None, norm=None):
    """N-dimensional discrete Fourier transform of ``a`` over ``axes``.

    ``a`` is first cast by ``upcast``. The transform is delegated to
    ``torch.fft.fftn``: ``s`` is the output signal size and ``axes`` is
    forwarded as torch's ``dim``.
    """
    return torch.fft.fftn(a, s=s, dim=axes, norm=norm)
@normalizer
@upcast
def ifftn(a: ArrayLike, s=None, axes=None, norm=None):
    """N-dimensional inverse discrete Fourier transform of ``a`` over ``axes``.

    ``a`` is first cast by ``upcast``. The transform is delegated to
    ``torch.fft.ifftn``, with ``axes`` forwarded as torch's ``dim``.
    """
    return torch.fft.ifftn(a, s=s, dim=axes, norm=norm)
@normalizer
@upcast
def rfftn(a: ArrayLike, s=None, axes=None, norm=None):
    """N-dimensional FFT of real input ``a`` over ``axes``.

    ``a`` is first cast by ``upcast``. The transform is delegated to
    ``torch.fft.rfftn``, with ``axes`` forwarded as torch's ``dim``.
    """
    return torch.fft.rfftn(a, s=s, dim=axes, norm=norm)
@normalizer
@upcast
def irfftn(a: ArrayLike, s=None, axes=None, norm=None):
    """Inverse of ``rfftn`` over ``axes``.

    ``a`` is first cast by ``upcast``. The transform is delegated to
    ``torch.fft.irfftn``, with ``axes`` forwarded as torch's ``dim``.
    """
    return torch.fft.irfftn(a, s=s, dim=axes, norm=norm)
@normalizer
@upcast
def fft2(a: ArrayLike, s=None, axes=(-2, -1), norm=None):
    """Two-dimensional discrete Fourier transform of ``a``.

    By default the transform runs over the last two axes. ``a`` is first
    cast by ``upcast``, and the work is delegated to ``torch.fft.fft2``
    with ``axes`` forwarded as ``dim``.
    """
    return torch.fft.fft2(a, s=s, dim=axes, norm=norm)
@normalizer
@upcast
def ifft2(a: ArrayLike, s=None, axes=(-2, -1), norm=None):
    """Two-dimensional inverse discrete Fourier transform of ``a``.

    By default the transform runs over the last two axes. ``a`` is first
    cast by ``upcast``, and the work is delegated to ``torch.fft.ifft2``
    with ``axes`` forwarded as ``dim``.
    """
    return torch.fft.ifft2(a, s=s, dim=axes, norm=norm)
@normalizer
@upcast
def rfft2(a: ArrayLike, s=None, axes=(-2, -1), norm=None):
    """Two-dimensional FFT of real input ``a``.

    By default the transform runs over the last two axes. ``a`` is first
    cast by ``upcast``, and the work is delegated to ``torch.fft.rfft2``
    with ``axes`` forwarded as ``dim``.
    """
    return torch.fft.rfft2(a, s=s, dim=axes, norm=norm)
@normalizer
@upcast
def irfft2(a: ArrayLike, s=None, axes=(-2, -1), norm=None):
    """Inverse of ``rfft2``.

    By default the transform runs over the last two axes. ``a`` is first
    cast by ``upcast``, and the work is delegated to ``torch.fft.irfft2``
    with ``axes`` forwarded as ``dim``.
    """
    return torch.fft.irfft2(a, s=s, dim=axes, norm=norm)
@normalizer
@upcast
def hfft(a: ArrayLike, n=None, axis=-1, norm=None):
    """FFT of a Hermitian-symmetric signal ``a`` along ``axis``.

    ``a`` is first cast by ``upcast``. The transform is delegated to
    ``torch.fft.hfft``, with ``axis`` forwarded as torch's ``dim``.
    """
    return torch.fft.hfft(a, n=n, dim=axis, norm=norm)
@normalizer
@upcast
def ihfft(a: ArrayLike, n=None, axis=-1, norm=None):
    """Inverse of ``hfft`` along ``axis``.

    ``a`` is first cast by ``upcast``. The transform is delegated to
    ``torch.fft.ihfft``, with ``axis`` forwarded as torch's ``dim``.
    """
    return torch.fft.ihfft(a, n=n, dim=axis, norm=norm)
@normalizer
def fftfreq(n, d=1.0):
    """Return the DFT sample frequencies for a window of length ``n``.

    ``d`` is the sample spacing. The values come from ``torch.fft.fftfreq``.
    """
    # Without an explicit dtype, torch.fft.fftfreq uses torch's global default
    # dtype (float32 unless changed). NumPy returns float64 here. Use the
    # torch._numpy default float dtype, the same dtype ``upcast`` targets.
    return torch.fft.fftfreq(n, d, dtype=_dtypes_impl.default_dtypes().float_dtype)
@normalizer
def rfftfreq(n, d=1.0):
    """Return the sample frequencies for ``rfft`` of a window of length ``n``.

    ``d`` is the sample spacing. The values come from ``torch.fft.rfftfreq``.
    """
    # Without an explicit dtype, torch.fft.rfftfreq uses torch's global default
    # dtype (float32 unless changed). NumPy returns float64 here. Use the
    # torch._numpy default float dtype, the same dtype ``upcast`` targets.
    return torch.fft.rfftfreq(
        n, d, dtype=_dtypes_impl.default_dtypes().float_dtype
    )
@normalizer
def fftshift(x: ArrayLike, axes=None):
    """Shift the zero-frequency component of ``x`` to the center.

    ``axes`` is forwarded to ``torch.fft.fftshift`` as ``dim``. No dtype
    upcast is applied, so the input dtype is kept.
    """
    return torch.fft.fftshift(x, dim=axes)
@normalizer
def ifftshift(x: ArrayLike, axes=None):
    """Undo ``fftshift`` on ``x``.

    ``axes`` is forwarded to ``torch.fft.ifftshift`` as ``dim``. No dtype
    upcast is applied, so the input dtype is kept.
    """
    return torch.fft.ifftshift(x, dim=axes)
|