File size: 5,089 Bytes
3dd84f8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import torch
from torch import nn


class ISTFT(nn.Module):
    """
    Custom implementation of ISTFT since torch.istft doesn't allow custom padding (other than `center=True`) with
    windowing. This is because the NOLA (Nonzero Overlap Add) check fails at the edges.
    See issue: https://github.com/pytorch/pytorch/issues/62323
    Specifically, in the context of neural vocoding we are interested in "same" padding analogous to CNNs.
    The NOLA constraint is met as we trim padded samples anyway.

    Args:
        n_fft (int): Size of Fourier transform.
        hop_length (int): The distance between neighboring sliding window frames.
        win_length (int): The size of window frame and STFT filter. In "same" mode this must equal ``n_fft``:
            each inverse-FFT frame has ``n_fft`` samples and is multiplied by the window and overlap-added
            with a kernel of ``win_length`` samples.
        padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".

    Raises:
        ValueError: If ``padding`` is not "center" or "same".
    """

    def __init__(self, n_fft: int, hop_length: int, win_length: int, padding: str = "same"):
        super().__init__()
        if padding not in ["center", "same"]:
            raise ValueError("Padding must be 'center' or 'same'.")
        self.padding = padding
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_length = win_length
        # Registered as a buffer so it follows the module across devices/dtypes and is saved in state_dict.
        window = torch.hann_window(win_length)
        self.register_buffer("window", window)

    def forward(self, spec: torch.Tensor) -> torch.Tensor:
        """
        Compute the Inverse Short Time Fourier Transform (ISTFT) of a complex spectrogram.

        Args:
            spec (Tensor): Input complex spectrogram of shape (B, N, T), where B is the batch size,
                N is the number of frequency bins, and T is the number of time frames.

        Returns:
            Tensor: Reconstructed time-domain signal of shape (B, L). In "same" mode
            L = (T - 1) * hop_length + win_length - 2 * ((win_length - hop_length) // 2).

        Raises:
            ValueError: If ``spec`` is not 3-D in "same" mode, or ``self.padding`` is invalid.
            RuntimeError: If the overlap-added window envelope is (near) zero inside the kept
                range, i.e. the NOLA condition does not hold for this window/hop combination.
        """
        if self.padding == "center":
            # Fallback to pytorch native implementation
            return torch.istft(spec, self.n_fft, self.hop_length, self.win_length, self.window, center=True)
        elif self.padding == "same":
            pad = (self.win_length - self.hop_length) // 2
        else:
            raise ValueError("Padding must be 'center' or 'same'.")

        if spec.dim() != 3:
            raise ValueError("Expected a 3D tensor as input")
        _, _, num_frames = spec.shape

        # Inverse FFT of every frame along the frequency axis, then apply the synthesis window.
        frames = torch.fft.irfft(spec, self.n_fft, dim=1, norm="backward")
        frames = frames * self.window[None, :, None]

        # Overlap and Add. Keep samples [pad, output_size - pad): an explicit end index is
        # required because `x[pad:-pad]` is empty when pad == 0.
        output_size = (num_frames - 1) * self.hop_length + self.win_length
        keep = slice(pad, output_size - pad)
        y = torch.nn.functional.fold(
            frames, output_size=(1, output_size), kernel_size=(1, self.win_length), stride=(1, self.hop_length),
        )[:, 0, 0, keep]

        # Window envelope: overlap-add of the squared window, used for WOLA normalization.
        window_sq = self.window.square().expand(1, num_frames, -1).transpose(1, 2)
        window_envelope = torch.nn.functional.fold(
            window_sq, output_size=(1, output_size), kernel_size=(1, self.win_length), stride=(1, self.hop_length),
        )[0, 0, 0, keep]

        # Normalize. An explicit check (not `assert`) so it still runs under `python -O`.
        if not (window_envelope > 1e-11).all():
            raise RuntimeError(
                "NOLA constraint violated: window envelope is ~0 inside the output range."
            )
        return y / window_envelope

class ISTFTHead(nn.Module):
    """
    ISTFT Head module for predicting STFT complex coefficients.

    A linear layer maps each hidden vector to a log-magnitude and a phase value per rfft bin; the
    resulting complex spectrogram is inverted with ``ISTFT`` (window length equal to ``n_fft``).

    Args:
        dim (int): Hidden dimension of the model.
        n_fft (int): Size of Fourier transform.
        hop_length (int): The distance between neighboring sliding window frames, which should align with
            the resolution of the input features.
        padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
    """

    def __init__(self, dim: int, n_fft: int, hop_length: int, padding: str = "same"):
        super().__init__()
        # An n_fft-point real FFT has n_fft // 2 + 1 bins; predict one log-magnitude and one phase per bin.
        # For even n_fft this equals the former `n_fft + 2`, so existing checkpoints keep the same shapes;
        # for odd n_fft the former value split into unequal halves in `forward`.
        out_dim = 2 * (n_fft // 2 + 1)
        self.out = torch.nn.Linear(dim, out_dim)
        self.istft = ISTFT(n_fft=n_fft, hop_length=hop_length, win_length=n_fft, padding=padding)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the ISTFTHead module.

        Args:
            x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
                L is the sequence length, and H denotes the model dimension.

        Returns:
            Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
        """
        # (B, L, H) -> (B, 2 * bins, L): channels-first layout expected by the ISTFT (bins on dim 1).
        coeffs = self.out(x).transpose(1, 2)
        log_mag, phase = coeffs.chunk(2, dim=1)
        mag = torch.exp(log_mag)
        mag = torch.clip(mag, max=1e2)  # safeguard to prevent excessively large magnitudes
        # Build the complex value directly from cos/sin of the predicted phase. This wraps the phase
        # implicitly; recomputing it via atan2 and exp(1j * phase) would give the same result at extra cost.
        spec = mag * (torch.cos(phase) + 1j * torch.sin(phase))
        return self.istft(spec)