Spaces:
Runtime error
Runtime error
| import torch | |
| import torch.nn as nn | |
| def mag_phase_stft(y, n_fft, hop_size, win_size, compress_factor=1.0, center=True, addeps=False): | |
| """ | |
| Compute magnitude and phase using STFT. | |
| Args: | |
| y (torch.Tensor): Input audio signal. | |
| n_fft (int): FFT size. | |
| hop_size (int): Hop size. | |
| win_size (int): Window size. | |
| compress_factor (float, optional): Magnitude compression factor. Defaults to 1.0. | |
| center (bool, optional): Whether to center the signal before padding. Defaults to True. | |
| eps (bool, optional): Whether adding epsilon to magnitude and phase or not. Defaults to False. | |
| Returns: | |
| tuple: Magnitude, phase, and complex representation of the STFT. | |
| """ | |
| #eps = torch.finfo(y.dtype).eps | |
| eps = 1e-10 | |
| hann_window = torch.hann_window(win_size).to(y.device) | |
| stft_spec = torch.stft( | |
| y, n_fft, | |
| hop_length=hop_size, | |
| win_length=win_size, | |
| window=hann_window, | |
| center=center, | |
| pad_mode='reflect', | |
| normalized=False, | |
| return_complex=True) | |
| if addeps==False: | |
| mag = torch.abs(stft_spec) | |
| pha = torch.angle(stft_spec) | |
| else: | |
| real_part = stft_spec.real | |
| imag_part = stft_spec.imag | |
| mag = torch.sqrt(real_part.pow(2) + imag_part.pow(2) + eps) | |
| pha = torch.atan2(imag_part + eps, real_part + eps) | |
| # Compress the magnitude | |
| mag = torch.pow(mag, compress_factor) | |
| com = torch.stack((mag * torch.cos(pha), mag * torch.sin(pha)), dim=-1) | |
| return mag, pha, com | |
| def mag_phase_istft(mag, pha, n_fft, hop_size, win_size, compress_factor=1.0, center=True): | |
| """ | |
| Inverse STFT to reconstruct the audio signal from magnitude and phase. | |
| Args: | |
| mag (torch.Tensor): Magnitude of the STFT. | |
| pha (torch.Tensor): Phase of the STFT. | |
| n_fft (int): FFT size. | |
| hop_size (int): Hop size. | |
| win_size (int): Window size. | |
| compress_factor (float, optional): Magnitude compression factor. Defaults to 1.0. | |
| center (bool, optional): Whether to center the signal before padding. Defaults to True. | |
| Returns: | |
| torch.Tensor: Reconstructed audio signal. | |
| """ | |
| mag = torch.pow(mag, 1.0 / compress_factor) | |
| com = torch.complex(mag * torch.cos(pha), mag * torch.sin(pha)) | |
| hann_window = torch.hann_window(win_size).to(com.device) | |
| wav = torch.istft( | |
| com, | |
| n_fft, | |
| hop_length=hop_size, | |
| win_length=win_size, | |
| window=hann_window, | |
| center=center) | |
| return wav | |