import torch
import torch.nn.functional as F
import torchaudio
import torchaudio.transforms as transforms


class LogMelSpectrogram(torch.nn.Module):
    """Compute a log-compressed magnitude mel spectrogram of a waveform.

    The mel spectrogram is produced by ``torchaudio.transforms.MelSpectrogram``
    configured with ``power=1.0`` (magnitude), Slaney mel scale and Slaney
    filter normalisation, and ``center=False``. Framing alignment is handled
    explicitly in :meth:`forward` by reflect-padding the signal.

    Args:
        sr: Sample rate of the input audio, in Hz.
        frame_length: Analysis window length; also used as the FFT size.
        hop_length: Step between successive frames, in samples.
        n_mel: Number of mel bands.
        f_min: Lowest frequency covered by the mel filterbank, in Hz.
        f_max: Highest frequency covered by the mel filterbank, in Hz.
    """

    def __init__(
        self,
        sr=24000,
        frame_length=1920,
        hop_length=480,
        n_mel=128,
        f_min=0,
        f_max=12000,
    ):
        super().__init__()
        self.frame_length = frame_length
        self.hop_length = hop_length
        self.mel = transforms.MelSpectrogram(
            sample_rate=sr,
            n_fft=frame_length,
            win_length=frame_length,
            hop_length=hop_length,
            center=False,
            power=1.0,
            norm="slaney",
            n_mels=n_mel,
            mel_scale="slaney",
            f_min=f_min,
            f_max=f_max,
        )

    @torch.no_grad()
    def forward(self, x, target_length=None):
        """Return the log-mel spectrogram of ``x``.

        Args:
            x: Waveform tensor whose last dimension is time. The usual input
                is ``(batch, time)``; a 1-D ``(time,)`` waveform and a
                ``(batch, channels, time)`` tensor are also accepted.
            target_length: Desired number of output frames. ``None`` keeps
                however many frames the transform produces. If larger than
                that, the extra frames are filled with the log floor
                ``log(1e-5)``; if smaller, trailing frames are dropped.

        Returns:
            Tensor of shape ``(..., n_mel, target_length)`` holding
            ``log(max(mel, 1e-5))``. A 1-D input yields
            ``(n_mel, target_length)``.
        """
        # Non-constant padding needs at least one leading dimension, so a bare
        # 1-D waveform gets a temporary batch axis that is removed on return.
        unbatched = x.dim() == 1
        if unbatched:
            x = x.unsqueeze(0)

        # With center=False the transform does no padding of its own. Padding
        # (frame_length - hop_length) // 2 on both sides makes the frame count
        # roughly time // hop_length.
        pad = (self.frame_length - self.hop_length) // 2
        x = F.pad(x, (pad, pad), "reflect")
        mel = self.mel(x)

        n_frames = mel.shape[-1]
        if target_length is None:
            target_length = n_frames
        # Copy only the frames that fit, so a short target truncates instead
        # of raising a shape-mismatch error.
        keep = min(n_frames, target_length)

        # Allocate directly on the mel tensor's device rather than creating
        # the buffer on CPU and moving it afterwards.
        logmel = torch.zeros(*mel.shape[:-1], target_length, device=mel.device)
        logmel[..., :keep] = mel[..., :keep]

        # Clamp before the log so silent and padded frames map to a finite
        # floor instead of -inf.
        logmel = torch.log(torch.clamp(logmel, min=1e-5))
        return logmel.squeeze(0) if unbatched else logmel