# Copyright (c) ONNX Project Contributors

# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations

import numpy as np

from onnx.reference.op_run import OpRun


def _fft(x: np.ndarray, fft_length: int, axis: int) -> np.ndarray:
    """Compute the FFT and return the real representation of the complex result."""
    transformed = np.fft.fft(x, n=fft_length, axis=axis)
    real_frequencies = np.real(transformed)
    imaginary_frequencies = np.imag(transformed)
    # Stack the real and imaginary parts along a new trailing axis of size 2.
    return np.concatenate(
        (real_frequencies[..., np.newaxis], imaginary_frequencies[..., np.newaxis]),
        axis=-1,
    )
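

# Example (illustrative, not normative): for a real-valued batch x of shape
# (3, 8), _fft(x, fft_length=8, axis=1) has shape (3, 8, 2), with the real
# part of np.fft.fft(x, n=8, axis=1) stored in [..., 0] and the imaginary
# part in [..., 1].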


def _cfft(
    x: np.ndarray,
    fft_length: int,
    axis: int,
    onesided: bool,
    normalize: bool,
) -> np.ndarray:
    if x.shape[-1] == 1:
        # The input contains only the real part.
        signal = x
    else:
        # The input is a real representation of a complex signal: even
        # positions of the last axis hold the real parts, odd positions
        # hold the imaginary parts.
        slices = [slice(0, dim) for dim in x.shape]
        slices[-1] = slice(0, x.shape[-1], 2)
        real = x[tuple(slices)]
        slices[-1] = slice(1, x.shape[-1], 2)
        imag = x[tuple(slices)]
        signal = real + 1j * imag
    complex_signals = np.squeeze(signal, -1)
    result = _fft(complex_signals, fft_length, axis=axis)
    # Post-process the result based on the arguments.
    if onesided:
        # Keep only the non-negative frequency bins along the transformed axis.
        slices = [slice(0, a) for a in result.shape]
        slices[axis] = slice(0, result.shape[axis] // 2 + 1)
        result = result[tuple(slices)]
    if normalize:
        result /= fft_length
    return result
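

# Example (illustrative): for a real signal x of shape (1, 16, 1) (a trailing
# axis of size 1 means "real only"), _cfft(x, 16, axis=1, onesided=True,
# normalize=False) keeps the first 16 // 2 + 1 = 9 frequency bins along
# axis 1, so the result has shape (1, 9, 2).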


def _ifft(x: np.ndarray, fft_length: int, axis: int, onesided: bool) -> np.ndarray:
    signals = np.fft.ifft(x, fft_length, axis=axis)
    real_signals = np.real(signals)
    imaginary_signals = np.imag(signals)
    merged = np.concatenate(
        (real_signals[..., np.newaxis], imaginary_signals[..., np.newaxis]),
        axis=-1,
    )
    if onesided:
        slices = [slice(a) for a in merged.shape]
        slices[axis] = slice(0, merged.shape[axis] // 2 + 1)
        return merged[tuple(slices)]
    return merged
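

# Example (illustrative): for a complex array x of shape (2, 8),
# _ifft(x, fft_length=8, axis=1, onesided=False) applies np.fft.ifft along
# axis 1 and packs the result into shape (2, 8, 2): real parts in [..., 0],
# imaginary parts in [..., 1].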


def _cifft(
    x: np.ndarray, fft_length: int, axis: int, onesided: bool = False
) -> np.ndarray:
    if x.shape[-1] == 1:
        # The input contains only the real part.
        frequencies = x
    else:
        # Split the last axis into real and imaginary parts.
        slices = [slice(0, dim) for dim in x.shape]
        slices[-1] = slice(0, x.shape[-1], 2)
        real = x[tuple(slices)]
        slices[-1] = slice(1, x.shape[-1], 2)
        imag = x[tuple(slices)]
        frequencies = real + 1j * imag
    complex_frequencies = np.squeeze(frequencies, -1)
    return _ifft(complex_frequencies, fft_length, axis=axis, onesided=onesided)
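

# Example (illustrative): _cifft expects the same packed layout as _cfft,
# namely a trailing axis of size 1 for purely real input, or of size 2 for
# (real, imaginary) pairs. For x of shape (1, 16, 2), _cifft(x, 16, axis=1)
# rebuilds the complex spectrum, inverts it along axis 1, and returns an
# array of shape (1, 16, 2).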


class DFT_17(OpRun):
    def _run(self, x: np.ndarray, dft_length: int | None = None, axis: int = 1, inverse: bool = False, onesided: bool = False) -> tuple[np.ndarray]:  # type: ignore
        # Convert to a positive axis.
        axis = axis % len(x.shape)
        if dft_length is None:
            dft_length = x.shape[axis]
        if inverse:  # type: ignore
            result = _cifft(x, dft_length, axis=axis, onesided=onesided)
        else:
            result = _cfft(x, dft_length, axis=axis, onesided=onesided, normalize=False)
        return (result.astype(x.dtype),)
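

# DFT_20 below differs from DFT_17 only in the default value of ``axis``
# (-2 instead of 1); the body of _run is otherwise identical.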
class DFT_20(OpRun):
    def _run(self, x: np.ndarray, dft_length: int | None = None, axis: int = -2, inverse: bool = False, onesided: bool = False) -> tuple[np.ndarray]:  # type: ignore
        # Convert to a positive axis.
        axis = axis % len(x.shape)
        if dft_length is None:
            dft_length = x.shape[axis]
        if inverse:  # type: ignore
            result = _cifft(x, dft_length, axis=axis, onesided=onesided)
        else:
            result = _cfft(x, dft_length, axis=axis, onesided=onesided, normalize=False)
        return (result.astype(x.dtype),)
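

if __name__ == "__main__":
    # Minimal smoke-test sketch (illustrative only): a forward _cfft followed
    # by _cifft should recover the original real signal up to floating-point
    # error. The shapes and tolerance below are assumptions, not part of the
    # operator definition.
    rng = np.random.default_rng(0)
    signal = rng.standard_normal((1, 16, 1)).astype(np.float32)
    spectrum = _cfft(signal, 16, axis=1, onesided=False, normalize=False)
    recovered = _cifft(spectrum, 16, axis=1)
    np.testing.assert_allclose(recovered[..., 0], signal[..., 0], atol=1e-5)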