Kano001's picture
Upload 2707 files
dc2106c verified
raw
history blame
3.86 kB
# Copyright (c) ONNX Project Contributors
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
import numpy as np
from onnx.reference.op_run import OpRun
def _fft(x: np.ndarray, fft_length: int, axis: int) -> np.ndarray:
"""Compute the FFT return the real representation of the complex result."""
transformed = np.fft.fft(x, n=fft_length, axis=axis)
real_frequencies = np.real(transformed)
imaginary_frequencies = np.imag(transformed)
return np.concatenate(
(real_frequencies[..., np.newaxis], imaginary_frequencies[..., np.newaxis]),
axis=-1,
)
def _cfft(
    x: np.ndarray,
    fft_length: int,
    axis: int,
    onesided: bool,
    normalize: bool,
) -> np.ndarray:
    """Compute the forward DFT of a real or real-encoded complex tensor.

    Args:
        x: Input whose last dimension is 1 (real signal) or 2 (real part at
            index 0, imaginary part at index 1).
        fft_length: Number of points of the transform along ``axis``; the
            signal is cropped or zero-padded to this length, as in
            ``np.fft.fft``.
        axis: Axis of ``x`` to transform. Negative values count from the end
            of ``x`` (including its trailing component dimension).
        onesided: If true, keep only the first ``fft_length // 2 + 1``
            frequencies along ``axis``.
        normalize: If true, divide the result by ``fft_length``.

    Returns:
        Array of the same rank as ``x`` whose last dimension has size 2 and
        holds the real and imaginary parts of the spectrum.

    Raises:
        ValueError: If the last dimension of ``x`` is neither 1 nor 2.
    """
    if axis < 0:
        # Resolve against ``x``: the transform runs on the signal without the
        # component dimension while the onesided slice runs on the result,
        # which has it again, so a negative index would hit different axes.
        axis += x.ndim
    components = x.shape[-1]
    if components == 1:
        # The input contains only the real part.
        signal = x[..., 0]
    elif components == 2:
        # The input is a real representation of a complex signal.
        signal = x[..., 0] + 1j * x[..., 1]
    else:
        raise ValueError(
            "The last dimension of the DFT input must be 1 (real) or 2 "
            f"(complex), got {components}."
        )
    result = _fft(signal, fft_length, axis=axis)
    # Post process the result based on arguments.
    if onesided:
        keep = [slice(None)] * result.ndim
        keep[axis] = slice(0, result.shape[axis] // 2 + 1)
        result = result[tuple(keep)]
    if normalize:
        result /= fft_length
    return result
def _ifft(x: np.ndarray, fft_length: int, axis: int, onesided: bool) -> np.ndarray:
signals = np.fft.ifft(x, fft_length, axis=axis)
real_signals = np.real(signals)
imaginary_signals = np.imag(signals)
merged = np.concatenate(
(real_signals[..., np.newaxis], imaginary_signals[..., np.newaxis]),
axis=-1,
)
if onesided:
slices = [slice(a) for a in merged.shape]
slices[axis] = slice(0, merged.shape[axis] // 2 + 1)
return merged[tuple(slices)]
return merged
def _cifft(
    x: np.ndarray, fft_length: int, axis: int, onesided: bool = False
) -> np.ndarray:
    """Compute the inverse DFT of a real or real-encoded complex tensor.

    Args:
        x: Input whose last dimension is 1 (real frequencies) or 2 (real part
            at index 0, imaginary part at index 1).
        fft_length: Number of points of the inverse transform along ``axis``.
        axis: Axis of ``x`` to transform. Negative values count from the end
            of ``x`` (including its trailing component dimension).
        onesided: Forwarded to ``_ifft``, which then truncates the output
            along ``axis`` to its first ``fft_length // 2 + 1`` samples.

    Returns:
        Array of the same rank as ``x`` whose last dimension has size 2 and
        holds the real and imaginary parts of the inverse transform.

    Raises:
        ValueError: If the last dimension of ``x`` is neither 1 nor 2.
    """
    if axis < 0:
        # Resolve against ``x`` before the component dimension is dropped,
        # so the index keeps pointing at the same dimension afterwards.
        axis += x.ndim
    components = x.shape[-1]
    if components == 1:
        frequencies = x[..., 0]
    elif components == 2:
        frequencies = x[..., 0] + 1j * x[..., 1]
    else:
        raise ValueError(
            "The last dimension of the DFT input must be 1 (real) or 2 "
            f"(complex), got {components}."
        )
    return _ifft(frequencies, fft_length, axis=axis, onesided=onesided)
def _to_int(value: int | np.ndarray) -> int:
    """Return ``value`` as a Python ``int``.

    Accepts Python ints, NumPy integer scalars and single-element integer
    arrays (0-d or 1-d), which is how scalar operator inputs may arrive.
    """
    return int(np.asarray(value).item())


def _run_dft(
    x: np.ndarray,
    dft_length: int | np.ndarray | None,
    axis: int | np.ndarray,
    inverse: bool,
    onesided: bool,
) -> tuple[np.ndarray]:
    """Evaluate the DFT; shared by ``DFT_17`` and ``DFT_20``.

    Args:
        x: Input tensor whose last dimension is 1 (real) or 2 (real, imag).
        dft_length: Transform length, or None to use the size of ``axis``.
        axis: Axis to transform; negative values count from the end of ``x``.
        inverse: If true, compute the inverse transform.
        onesided: Passed through to the forward/inverse helper.

    Returns:
        One-element tuple holding the result cast back to ``x.dtype``.
    """
    # Convert to positive axis.
    axis = _to_int(axis) % x.ndim
    if dft_length is None:
        dft_length = x.shape[axis]
    else:
        dft_length = _to_int(dft_length)
    if inverse:
        result = _cifft(x, dft_length, axis=axis, onesided=onesided)
    else:
        result = _cfft(x, dft_length, axis=axis, onesided=onesided, normalize=False)
    return (result.astype(x.dtype),)


class DFT_17(OpRun):
    """DFT reference kernel for the version-17 signature (default ``axis=1``)."""

    def _run(self, x: np.ndarray, dft_length: int | None = None, axis: int = 1, inverse: bool = False, onesided: bool = False) -> tuple[np.ndarray]:  # type: ignore
        return _run_dft(x, dft_length, axis, inverse, onesided)


class DFT_20(OpRun):
    """DFT reference kernel for the version-20 signature.

    The default ``axis=-2`` selects the last dimension before the trailing
    real/imaginary component dimension.
    """

    def _run(self, x: np.ndarray, dft_length: int | None = None, axis: int = -2, inverse: bool = False, onesided: bool = False) -> tuple[np.ndarray]:  # type: ignore
        if axis is None:
            # An omitted optional ``axis`` falls back to the default.
            axis = -2
        return _run_dft(x, dft_length, axis, inverse, onesided)