from typing import Optional

import torch
from torch import Tensor
import torch.nn as nn
import torchaudio


class LinearSpectrogram(nn.Module):
    """Magnitude spectrogram built on ``torch.stft`` with a Hann window.

    Args:
        n_fft: FFT size passed to ``torch.stft``.
        win_length: Window length; also the size of the Hann window buffer.
        hop_length: Hop between successive STFT frames.
        pad: Number of samples padded on each side of the waveform before
            the STFT.
        center: Forwarded to ``torch.stft`` as ``center``.
        pad_mode: Padding mode used both for the explicit padding and as the
            ``pad_mode`` argument of ``torch.stft``.
    """

    def __init__(self, n_fft, win_length, hop_length, pad, center, pad_mode):
        super().__init__()
        self.n_fft = n_fft
        self.win_length = win_length
        self.hop_length = hop_length
        self.pad = pad
        self.center = center
        self.pad_mode = pad_mode
        # Registered as a buffer so the window follows the module across
        # ``.to(...)`` calls.
        self.register_buffer("window", torch.hann_window(win_length))

    def forward(self, waveform: Tensor) -> Tensor:
        """Return the magnitude spectrogram of ``waveform``.

        A 3-D input has its dimension 1 squeezed first.
        NOTE(review): presumably the input is ``(batch, time)`` or
        ``(batch, 1, time)`` -- confirm against callers.
        """
        if waveform.ndim == 3:
            waveform = waveform.squeeze(1)

        # A temporary dimension is inserted for the padding call and removed
        # again right after.
        padded = torch.nn.functional.pad(
            waveform.unsqueeze(1), (self.pad, self.pad), self.pad_mode
        ).squeeze(1)

        complex_spec = torch.stft(
            padded,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            win_length=self.win_length,
            window=self.window,
            center=self.center,
            pad_mode=self.pad_mode,
            normalized=False,
            onesided=True,
            return_complex=True,
        )

        # sqrt(re^2 + im^2 + 1e-6): the constant keeps the value under the
        # square root strictly positive.
        real_imag = torch.view_as_real(complex_spec)
        return torch.sqrt(real_imag.pow(2).sum(-1) + 1e-6)


class LogMelSpectrogram(nn.Module):
    """Log-compressed mel spectrogram.

    Chains :class:`LinearSpectrogram`, ``torchaudio.transforms.MelScale`` and
    a clamped natural logarithm.

    Args:
        sample_rate: Sample rate handed to ``MelScale``.
        n_fft: FFT size; the mel filterbank is built for ``n_fft // 2 + 1``
            frequency bins.
        win_length: Window length of the underlying STFT.
        hop_length: Hop length of the underlying STFT.
        f_min: Lower frequency bound handed to ``MelScale``.
        f_max: Upper frequency bound handed to ``MelScale``.
        pad: Per-side padding applied by :class:`LinearSpectrogram`.
        n_mels: Number of mel bins.
        center: ``center`` flag of the underlying STFT.
        pad_mode: Padding mode of the underlying STFT.
        mel_scale: Mel scale name handed to ``MelScale``.
    """

    def __init__(self, sample_rate, n_fft, win_length, hop_length, f_min,
                 f_max, pad, n_mels, center, pad_mode, mel_scale):
        super().__init__()
        self.sample_rate = sample_rate
        self.n_fft = n_fft
        self.win_length = win_length
        self.hop_length = hop_length
        self.f_min = f_min
        self.f_max = f_max
        self.pad = pad
        self.n_mels = n_mels
        self.center = center
        self.pad_mode = pad_mode

        self.spectrogram = LinearSpectrogram(
            n_fft, win_length, hop_length, pad, center, pad_mode
        )

        # ``MelScale`` only accepts ``None`` or "slaney" for ``norm``, so the
        # scale name cannot be forwarded there as-is: area normalisation is
        # enabled for the slaney scale (unchanged behaviour) and left off for
        # any other scale name.
        norm = "slaney" if mel_scale == "slaney" else None

        # ``self.mel_scale`` holds the transform module, not the scale name.
        self.mel_scale = torchaudio.transforms.MelScale(
            n_mels,
            sample_rate,
            f_min,
            f_max,
            n_stft=(n_fft // 2) + 1,
            norm=norm,
            mel_scale=mel_scale,
        )

    def compress(self, x: Tensor) -> Tensor:
        """Return ``log(x)`` after clamping ``x`` to a minimum of 1e-5."""
        return torch.log(torch.clamp(x, min=1e-5))

    def decompress(self, x: Tensor) -> Tensor:
        """Return ``exp(x)``; values clamped by :meth:`compress` are not restored."""
        return torch.exp(x)

    def forward(self, x: Tensor) -> Tensor:
        """Return the log-mel spectrogram of waveform ``x``."""
        linear_spec = self.spectrogram(x)
        mel_spec = self.mel_scale(linear_spec)
        return self.compress(mel_spec)


def load_and_resample_audio(audio_path, target_sr, device='cpu') -> Optional[Tensor]:
    """Load an audio file, keep its first channel and resample it.

    Args:
        audio_path: Path handed to ``torchaudio.load``.
        target_sr: Desired sample rate.
        device: Device the loaded tensor is moved to.

    Returns:
        A tensor whose first dimension has size 1 when the file has more than
        one channel (otherwise the tensor as loaded), resampled to
        ``target_sr`` if needed; ``None`` if loading failed, in which case
        the error message is printed.
    """
    try:
        y, sr = torchaudio.load(audio_path)
    except Exception as e:
        # Best-effort loader: report the problem and let the caller handle
        # the missing result.
        print(str(e))
        return None

    # ``Tensor.to`` returns a new tensor, so the result has to be rebound.
    y = y.to(device)

    # Convert to mono by keeping only the first channel (no averaging).
    if y.size(0) > 1:
        y = y[0, :].unsqueeze(0)  # shape: [2, time] -> [time] -> [1, time]

    # Resample audio to the target sample rate.
    if sr != target_sr:
        y = torchaudio.functional.resample(y, sr, target_sr)

    return y