File size: 3,864 Bytes
dc2106c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
# Copyright (c) ONNX Project Contributors

# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations

import numpy as np

from onnx.reference.op_run import OpRun


def _fft(x: np.ndarray, fft_length: int, axis: int) -> np.ndarray:
    """Compute the FFT of ``x`` and return the complex spectrum in real form.

    Args:
        x: Real or complex input array.
        fft_length: Number of points of the transform, forwarded as ``n`` to
            ``np.fft.fft`` (zero-pads or truncates ``x`` along ``axis``).
        axis: Axis of ``x`` along which the transform is computed.

    Returns:
        Array of shape ``x.shape`` (with ``axis`` resized to ``fft_length``)
        plus a trailing dimension of size 2: index 0 holds the real parts and
        index 1 the imaginary parts.
    """
    spectrum = np.fft.fft(x, n=fft_length, axis=axis)
    # Stacking on a new last axis is the same as appending a length-1 axis to
    # each part and concatenating them along it.
    return np.stack((spectrum.real, spectrum.imag), axis=-1)


def _cfft(
    x: np.ndarray,
    fft_length: int,
    axis: int,
    onesided: bool,
    normalize: bool,
) -> np.ndarray:
    """Compute the forward DFT of an input stored in real/complex layout.

    Args:
        x: Input whose last dimension has size 1 (real signal) or 2
            (``x[..., 0]`` real part, ``x[..., 1]`` imaginary part).
        fft_length: Number of points of the transform, forwarded to ``_fft``
            (and from there as ``n`` to ``np.fft.fft``).
        axis: Axis to transform, indexed in the signal *without* the trailing
            real/imaginary dimension. For a non-negative value this is the
            same index as in ``x``.
        onesided: If true, keep only the first ``fft_length // 2 + 1``
            frequency bins along ``axis``.
        normalize: If true, divide the result by ``fft_length``.

    Returns:
        Array with the leading dimensions of ``x`` (``axis`` resized) and a
        trailing dimension of size 2 holding the real and imaginary parts.

    Raises:
        ValueError: If the last dimension of ``x`` is neither 1 nor 2.
    """
    last = x.shape[-1]
    if last == 1:
        # The input contains only the real part.
        signal = x[..., 0]
    elif last == 2:
        # The input is a real representation of a complex signal.
        signal = x[..., 0] + 1j * x[..., 1]
    else:
        raise ValueError(
            "The last dimension of the input must be 1 (real) or 2 (complex), "
            f"got {last}."
        )

    result = _fft(signal, fft_length, axis=axis)

    # Post process the result based on arguments.
    if onesided:
        # ``result`` has one more (trailing real/imag) dimension than
        # ``signal``. Resolve a negative ``axis`` against ``signal`` -- the
        # frame the FFT used -- so the slice hits the transformed axis instead
        # of the real/imag one. The FFT call above already validated ``axis``.
        freq_axis = axis % signal.ndim
        slices = [slice(None)] * result.ndim
        slices[freq_axis] = slice(0, result.shape[freq_axis] // 2 + 1)
        result = result[tuple(slices)]
    if normalize:
        result /= fft_length
    return result


def _ifft(x: np.ndarray, fft_length: int, axis: int, onesided: bool) -> np.ndarray:
    """Compute the inverse FFT of ``x`` and return the result in real form.

    Args:
        x: Frequency-domain array (real or complex).
        fft_length: Number of points of the inverse transform, forwarded as
            ``n`` to ``np.fft.ifft`` (zero-pads or truncates along ``axis``).
        axis: Axis of ``x`` to transform; negative values count from the end
            of ``x``'s dimensions.
        onesided: If true, keep only the first ``fft_length // 2 + 1``
            entries along ``axis``.

    Returns:
        Array of shape ``x.shape`` (with ``axis`` resized) plus a trailing
        dimension of size 2: index 0 holds the real parts and index 1 the
        imaginary parts.
    """
    signals = np.fft.ifft(x, fft_length, axis=axis)
    merged = np.stack((signals.real, signals.imag), axis=-1)
    if not onesided:
        return merged
    # ``merged`` has one more (trailing) dimension than ``signals``, so a
    # negative ``axis`` must be resolved against ``signals``. Otherwise, e.g.,
    # -1 would select the real/imag dimension instead of the transformed one.
    # ``np.fft.ifft`` has already validated ``axis`` at this point.
    freq_axis = axis % signals.ndim
    slices = [slice(None)] * merged.ndim
    slices[freq_axis] = slice(0, merged.shape[freq_axis] // 2 + 1)
    return merged[tuple(slices)]


def _cifft(
    x: np.ndarray, fft_length: int, axis: int, onesided: bool = False
) -> np.ndarray:
    """Compute the inverse DFT of an input stored in real/complex layout.

    Args:
        x: Frequency-domain input whose last dimension has size 1 (real part
            only) or 2 (``x[..., 0]`` real part, ``x[..., 1]`` imaginary part).
        fft_length: Number of points of the inverse transform, forwarded to
            ``_ifft``.
        axis: Axis to transform, indexed in the array *without* the trailing
            real/imaginary dimension. For a non-negative value this is the
            same index as in ``x``.
        onesided: Forwarded to ``_ifft``.

    Returns:
        The result of ``_ifft``: the inverse transform with a trailing
        dimension of size 2 holding the real and imaginary parts.

    Raises:
        ValueError: If the last dimension of ``x`` is neither 1 nor 2.
    """
    last = x.shape[-1]
    if last == 1:
        # Only the real part of the spectrum is provided.
        frequencies = x[..., 0]
    elif last == 2:
        # Real representation of a complex spectrum.
        frequencies = x[..., 0] + 1j * x[..., 1]
    else:
        raise ValueError(
            "The last dimension of the input must be 1 (real) or 2 (complex), "
            f"got {last}."
        )
    return _ifft(frequencies, fft_length, axis=axis, onesided=onesided)


class DFT_17(OpRun):
    """Reference implementation of the ``DFT`` operator, opset 17."""

    def _run(  # type: ignore
        self,
        x: np.ndarray,
        dft_length: int | None = None,
        axis: int = 1,
        inverse: bool = False,
        onesided: bool = False,
    ) -> tuple[np.ndarray]:
        """Run a forward or inverse DFT of ``x`` along ``axis``.

        Args:
            x: Input in real/complex layout (trailing dimension of size 1 or
                2, as consumed by ``_cfft`` / ``_cifft``).
            dft_length: Transform length; when omitted, the size of ``x``
                along ``axis`` is used.
            axis: Axis to transform; wrapped into ``[0, x.ndim)`` with ``%``.
            inverse: Compute the inverse transform when true.
            onesided: Forwarded to the helper, which then keeps only the
                first ``n // 2 + 1`` bins along ``axis``.

        Returns:
            A one-element tuple holding the result cast to ``x.dtype``.
        """
        # Map a negative axis onto its non-negative equivalent.
        wrapped_axis = axis % x.ndim
        length = x.shape[wrapped_axis] if dft_length is None else dft_length
        if not inverse:
            output = _cfft(
                x, length, axis=wrapped_axis, onesided=onesided, normalize=False
            )
        else:
            output = _cifft(x, length, axis=wrapped_axis, onesided=onesided)
        return (output.astype(x.dtype),)


class DFT_20(OpRun):
    """Reference implementation of the ``DFT`` operator, opset 20."""

    def _run(  # type: ignore
        self,
        x: np.ndarray,
        dft_length: int | None = None,
        axis: int = -2,
        inverse: bool = False,
        onesided: bool = False,
    ) -> tuple[np.ndarray]:
        """Run a forward or inverse DFT of ``x`` along ``axis``.

        Args:
            x: Input in real/complex layout (trailing dimension of size 1 or
                2, as consumed by ``_cfft`` / ``_cifft``).
            dft_length: Transform length; when omitted, the size of ``x``
                along ``axis`` is used.
            axis: Axis to transform; wrapped into ``[0, x.ndim)`` with ``%``,
                so the default -2 selects ``x.ndim - 2``.
            inverse: Compute the inverse transform when true.
            onesided: Forwarded to the helper, which then keeps only the
                first ``n // 2 + 1`` bins along ``axis``.

        Returns:
            A one-element tuple holding the result cast to ``x.dtype``.
        """
        # Map a negative axis onto its non-negative equivalent.
        target_axis = axis % x.ndim
        n_points = dft_length if dft_length is not None else x.shape[target_axis]
        if not inverse:
            transformed = _cfft(
                x, n_points, axis=target_axis, onesided=onesided, normalize=False
            )
        else:
            transformed = _cifft(x, n_points, axis=target_axis, onesided=onesided)
        return (transformed.astype(x.dtype),)