import torch
import numpy as np
import torch.nn as nn

try:
    # Prefer the newer torch.amp API, which takes a device_type argument.
    from torch.amp import autocast

    torch_amp_new = True
except ImportError:
    # Fall back to the older torch.cuda.amp API on earlier PyTorch versions.
    from torch.cuda.amp import autocast

    torch_amp_new = False

from torchaudio.transforms import AmplitudeToDB, MelSpectrogram

class FeatureExtractor(nn.Module):
    def __init__(self, cfg):
        """
        Feature extraction module that converts raw audio into a normalized
        log-mel spectrogram.

        Args:
            cfg: Configuration object exposing `cfg.audio.sample_rate` and a
                `cfg.melspec` namespace with the spectrogram parameters
                (n_fft, hop_length, win_length, n_mels, f_min, f_max, power,
                top_db) and the normalization method `cfg.melspec.norm`
                ("mean_std", "min_max", or "simple"; any other value disables
                normalization).
        """
        super().__init__()
        self.audio2melspec = MelSpectrogram(
            n_fft=cfg.melspec.n_fft,
            hop_length=cfg.melspec.hop_length,
            win_length=cfg.melspec.win_length,
            n_mels=cfg.melspec.n_mels,
            sample_rate=cfg.audio.sample_rate,
            f_min=cfg.melspec.f_min,
            f_max=cfg.melspec.f_max,
            power=cfg.melspec.power,
        )
        self.amplitude_to_db = AmplitudeToDB(top_db=cfg.melspec.top_db)
        if cfg.melspec.norm == "mean_std":
            self.normalizer = MeanStdNorm()
        elif cfg.melspec.norm == "min_max":
            self.normalizer = MinMaxNorm()
        elif cfg.melspec.norm == "simple":
            self.normalizer = SimpleNorm()
        else:
            self.normalizer = nn.Identity()

    def forward(self, x):
        """
        Forward pass of the feature extractor.

        Args:
            x (torch.Tensor): Input audio waveform of shape (..., time).

        Returns:
            torch.Tensor: Normalized log-mel spectrogram of shape (..., n_mels, frames).
        """
        # Compute the spectrogram in full precision: the mel filterbank and
        # dB conversion are numerically sensitive under autocast/fp16.
        with (
            autocast("cuda", enabled=False)
            if torch_amp_new
            else autocast(enabled=False)
        ):
            melspec = self.audio2melspec(x.float())
            melspec = self.amplitude_to_db(melspec)
            melspec = self.normalizer(melspec)
        return melspec

class MinMaxNorm(nn.Module):
    def __init__(self, eps=1e-6):
        """
        Module for performing per-sample min-max normalization on input data.

        Args:
            eps (float, optional): Small value to avoid division by zero. Defaults to 1e-6.
        """
        super().__init__()
        self.eps = eps

    def forward(self, X):
        """
        Scale each sample to the [0, 1] range using its own minimum and maximum.

        Args:
            X (torch.Tensor): Input data of shape (batch, n_mels, frames).

        Returns:
            torch.Tensor: Normalized data.
        """
        min_ = torch.amin(X, dim=(1, 2), keepdim=True)
        max_ = torch.amax(X, dim=(1, 2), keepdim=True)
        return (X - min_) / (max_ - min_ + self.eps)

class SimpleNorm(nn.Module):
    def __init__(self):
        """
        Module for performing a simple fixed affine normalization on dB-scaled input data.
        """
        super().__init__()

    def forward(self, x):
        """
        Forward pass of the simple normalization module.

        Args:
            x (torch.Tensor): Input data in dB.

        Returns:
            torch.Tensor: Normalized data.
        """
        # Fixed shift/scale: dB values in [0, 80] map to roughly [-0.5, 0.5].
        return (x - 40) / 80

class MeanStdNorm(nn.Module):
    def __init__(self, eps=1e-6):
        """
        Module for performing per-sample mean and standard deviation normalization.

        Args:
            eps (float, optional): Small value to avoid division by zero. Defaults to 1e-6.
        """
        super().__init__()
        self.eps = eps

    def forward(self, X):
        """
        Forward pass of the mean and standard deviation normalization module.

        Args:
            X (torch.Tensor): Input data of shape (batch, n_mels, frames).

        Returns:
            torch.Tensor: Normalized data.
        """
        mean = X.mean((1, 2), keepdim=True)
        std = X.reshape(X.size(0), -1).std(1, keepdim=True).unsqueeze(-1)
        return (X - mean) / (std + self.eps)
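

# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal example of wiring up FeatureExtractor. The config layout
# (cfg.audio.sample_rate and the cfg.melspec.* fields) mirrors the attributes
# accessed in __init__, but the concrete values below are assumptions chosen
# for demonstration, not the original project's settings.
if __name__ == "__main__":
    from types import SimpleNamespace

    cfg = SimpleNamespace(
        audio=SimpleNamespace(sample_rate=32000),
        melspec=SimpleNamespace(
            n_fft=2048,
            hop_length=512,
            win_length=2048,
            n_mels=128,
            f_min=20,
            f_max=16000,
            power=2.0,
            top_db=80.0,
            norm="min_max",
        ),
    )

    extractor = FeatureExtractor(cfg)
    # Batch of 4 random 5-second waveforms, shape (batch, samples).
    waveform = torch.randn(4, 5 * cfg.audio.sample_rate)
    features = extractor(waveform)
    print(features.shape)  # (batch, n_mels, frames)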