# Modified from https://github.com/echocatzh/conv-stft/blob/master/conv_stft/conv_stft.py

# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# MIT License
#
# Copyright (c) 2020 Shimin Zhang
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import torch as th
import torch.nn.functional as F
from scipy.signal import check_COLA, get_window

# torch.fft.rfft and native complex tensors are available from torch 1.7.0 onwards.
support_clp_op = None
if th.__version__ >= "1.7.0":
    from torch.fft import rfft as fft

    support_clp_op = True
else:
    from torch import rfft as fft


class STFT(th.nn.Module):
    def __init__(
        self,
        win_len=1024,
        win_hop=512,
        fft_len=1024,
        enframe_mode="continue",
        win_type="hann",
        win_sqrt=False,
        pad_center=True,
    ):
        """
        Implementation of STFT and inverse STFT using 1D convolution and 1D transposed convolution.
        The signal can be framed in two ways, `break` and `continue`:
        the `break` mode is a Kaldi-like framing, the `continue` mode is a librosa-like framing.

        More information about `perfect reconstruction`:
        1. https://ww2.mathworks.cn/help/signal/ref/stft.html
        2. https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.get_window.html

        Args:
            win_len (int): Number of points in one frame. Defaults to 1024.
            win_hop (int): Stride (hop size) between frames, in samples. Defaults to 512.
            fft_len (int): Number of DFT points. Defaults to 1024.
            enframe_mode (str, optional): Framing mode, `break` or `continue`. Defaults to 'continue'.
            win_type (str, optional): The type of window to create. Defaults to 'hann'.
            win_sqrt (bool, optional): Use the square root of the window for both analysis
                and synthesis. Defaults to False.
            pad_center (bool, optional): Pad the signal by `fft_len // 2` on both sides so that
                frames are centered, which is needed for perfect reconstruction. Defaults to True.
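
        Example (illustrative sketch, not part of the original docstring; shapes assume
        the default parameters above)::

            stft = STFT(win_len=1024, win_hop=512, fft_len=1024)
            wav = th.randn(4, 16000)                  # [num_batch, num_samples]
            mag, phase = stft.transform(wav, return_type="magphase")
            # mag, phase: [num_batch, fft_len // 2 + 1, num_frames] = [4, 513, 32]
            rec = stft.inverse(mag, phase)            # [num_batch, num_frames * win_hop]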
""" super(STFT, self).__init__() assert enframe_mode in ["break", "continue"] assert fft_len >= win_len self.win_len = win_len self.win_hop = win_hop self.fft_len = fft_len self.mode = enframe_mode self.win_type = win_type self.win_sqrt = win_sqrt self.pad_center = pad_center self.pad_amount = self.fft_len // 2 en_k, fft_k, ifft_k, ola_k = self.__init_kernel__() self.register_buffer("en_k", en_k) self.register_buffer("fft_k", fft_k) self.register_buffer("ifft_k", ifft_k) self.register_buffer("ola_k", ola_k) def __init_kernel__(self): """ Generate enframe_kernel, fft_kernel, ifft_kernel and overlap-add kernel. ** enframe_kernel: Using conv1d layer and identity matrix. ** fft_kernel: Using linear layer for matrix multiplication. In fact, enframe_kernel and fft_kernel can be combined, But for the sake of readability, I took the two apart. ** ifft_kernel, pinv of fft_kernel. ** overlap-add kernel, just like enframe_kernel, but transposed. Returns: tuple: four kernels. """ enframed_kernel = th.eye(self.fft_len)[:, None, :] if support_clp_op: tmp = fft(th.eye(self.fft_len)) fft_kernel = th.stack([tmp.real, tmp.imag], dim=2) else: fft_kernel = fft(th.eye(self.fft_len), 1) if self.mode == "break": enframed_kernel = th.eye(self.win_len)[:, None, :] fft_kernel = fft_kernel[: self.win_len] fft_kernel = th.cat((fft_kernel[:, :, 0], fft_kernel[:, :, 1]), dim=1) ifft_kernel = th.pinverse(fft_kernel)[:, None, :] window = get_window(self.win_type, self.win_len) self.perfect_reconstruct = check_COLA(window, self.win_len, self.win_len - self.win_hop) window = th.FloatTensor(window) if self.mode == "continue": left_pad = (self.fft_len - self.win_len) // 2 right_pad = left_pad + (self.fft_len - self.win_len) % 2 window = F.pad(window, (left_pad, right_pad)) if self.win_sqrt: self.padded_window = window window = th.sqrt(window) else: self.padded_window = window**2 fft_kernel = fft_kernel.T * window ifft_kernel = ifft_kernel * window ola_kernel = th.eye(self.fft_len)[: self.win_len, None, :] if self.mode == "continue": ola_kernel = th.eye(self.fft_len)[:, None, : self.fft_len] return enframed_kernel, fft_kernel, ifft_kernel, ola_kernel def is_perfect(self): """ Whether the parameters win_len, win_hop and win_sqrt obey constants overlap-add(COLA) Returns: bool: Return true if parameters obey COLA. """ return self.perfect_reconstruct and self.pad_center def transform(self, inputs, return_type="complex"): """Take input data (audio) to STFT domain. Args: inputs (tensor): Tensor of floats, with shape (num_batch, num_samples) return_type (str, optional): return (mag, phase) when `magphase`, return (real, imag) when `realimag` and complex(real, imag) when `complex`. Defaults to 'complex'. Returns: tuple: (mag, phase) when `magphase`, return (real, imag) when `realimag`. 
                complex(real, imag) when `complex` (the default). Each element has shape
                [num_batch, num_frequencies, num_frames].
        """
        assert return_type in ["magphase", "realimag", "complex"]
        if inputs.dim() == 2:
            inputs = th.unsqueeze(inputs, 1)
        self.num_samples = inputs.size(-1)
        if self.pad_center:
            inputs = F.pad(inputs, (self.pad_amount, self.pad_amount), mode="reflect")
        # Frame the signal with the conv1d identity kernel, then apply the DFT as a matrix multiply.
        enframe_inputs = F.conv1d(inputs, self.en_k, stride=self.win_hop)
        outputs = th.transpose(enframe_inputs, 1, 2)
        outputs = F.linear(outputs, self.fft_k)
        outputs = th.transpose(outputs, 1, 2)
        dim = self.fft_len // 2 + 1
        real = outputs[:, :dim, :]
        imag = outputs[:, dim:, :]
        if return_type == "realimag":
            return real, imag
        elif return_type == "complex":
            assert support_clp_op
            return th.complex(real, imag)
        else:
            mags = th.sqrt(real**2 + imag**2)
            phase = th.atan2(imag, real)
            return mags, phase

    def inverse(self, input1, input2=None, input_type="magphase"):
        """Call the inverse STFT (iSTFT), given tensors produced by the `transform` function.

        Args:
            input1 (tensor): Magnitude/real part of the STFT with shape [num_batch, num_frequencies, num_frames]
            input2 (tensor): Phase/imaginary part of the STFT with shape [num_batch, num_frequencies, num_frames]
            input_type (str, optional): Mathematical meaning of the input tensors, `magphase` or `realimag`.
                Defaults to 'magphase'.

        Returns:
            tensor: Reconstructed audio given magnitude and phase, of shape [num_batch, num_samples]
        """
        assert input_type in ["magphase", "realimag"]
        if input_type == "realimag":
            real, imag = None, None
            if support_clp_op and th.is_complex(input1):
                real, imag = input1.real, input1.imag
            else:
                real, imag = input1, input2
        else:
            real = input1 * th.cos(input2)
            imag = input1 * th.sin(input2)
        inputs = th.cat([real, imag], dim=1)
        # Inverse DFT plus overlap-add via a transposed convolution.
        outputs = F.conv_transpose1d(inputs, self.ifft_k, stride=self.win_hop)
        # Overlap-add the (squared) window to normalize the synthesis.
        t = (self.padded_window[None, :, None]).repeat(1, 1, inputs.size(-1))
        t = t.to(inputs.device)
        coff = F.conv_transpose1d(t, self.ola_k, stride=self.win_hop)

        num_frames = input1.size(-1)
        num_samples = num_frames * self.win_hop

        rm_start, rm_end = self.pad_amount, self.pad_amount + num_samples

        outputs = outputs[..., rm_start:rm_end]
        coff = coff[..., rm_start:rm_end]
        coffidx = th.where(coff > 1e-8)
        outputs[coffidx] = outputs[coffidx] / (coff[coffidx])
        return outputs.squeeze(dim=1)

    def forward(self, inputs):
        """Take input data (audio) to the STFT domain and then back to audio.

        Args:
            inputs (tensor): Tensor of floats, with shape [num_batch, num_samples]

        Returns:
            tensor: Reconstructed audio given magnitude and phase, of shape [num_batch, num_samples]
        """
        mag, phase = self.transform(inputs, return_type="magphase")
        rec_wav = self.inverse(mag, phase)
        return rec_wav
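

# The sketch below is not part of the original module; it is a minimal usage example
# showing the analysis/synthesis round trip. The batch size and signal length
# (2 x 16000 samples) are arbitrary assumptions for illustration.
if __name__ == "__main__":
    stft = STFT(win_len=1024, win_hop=512, fft_len=1024, win_type="hann")
    wav = th.randn(2, 16000)  # hypothetical [num_batch, num_samples] test signal

    # Analysis to magnitude/phase, then synthesis back to audio.
    mag, phase = stft.transform(wav, return_type="magphase")
    rec = stft.inverse(mag, phase, input_type="magphase")

    # With a COLA-satisfying window and center padding, the round trip should be
    # near-lossless over the original signal length.
    print("COLA satisfied:", stft.is_perfect())
    print("max reconstruction error:", (wav - rec[..., : wav.size(-1)]).abs().max().item())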