Spaces:
Sleeping
Sleeping
import torch | |
import numpy as np | |
import torch.nn as nn | |
# Pick whichever ``autocast`` this torch build provides. ``torch_amp_new``
# records which import succeeded, because the two variants are called
# differently later in this file: the ``torch.amp`` one is given a device type
# as its first argument, the ``torch.cuda.amp`` one is not.
try:
    from torch.amp import autocast

    torch_amp_new = True
except ImportError:
    # Only a failed import should trigger the fallback; a bare ``except``
    # would also swallow KeyboardInterrupt/SystemExit and unrelated errors.
    from torch.cuda.amp import autocast

    torch_amp_new = False
from torchaudio.transforms import AmplitudeToDB, MelSpectrogram | |
class FeatureExtractor(nn.Module):
    """Convert input audio into a dB-scaled, optionally normalized mel spectrogram."""

    def __init__(
        self,
        cfg,
    ):
        """
        Build the mel-spectrogram, amplitude-to-dB and normalization stages.
        Args:
            cfg: Configuration object. The fields read here are
                ``cfg.audio.sample_rate`` and, under ``cfg.melspec``,
                ``n_fft``, ``hop_length``, ``win_length``, ``n_mels``,
                ``f_min``, ``f_max``, ``power``, ``top_db`` and ``norm``.
                ``norm`` selects the normalizer: ``"mean_std"``, ``"min_max"``
                or ``"simple"``; any other value leaves the dB spectrogram
                unchanged (``nn.Identity``).
        """
        super().__init__()
        mel_cfg = cfg.melspec

        self.audio2melspec = MelSpectrogram(
            sample_rate=cfg.audio.sample_rate,
            n_fft=mel_cfg.n_fft,
            win_length=mel_cfg.win_length,
            hop_length=mel_cfg.hop_length,
            f_min=mel_cfg.f_min,
            f_max=mel_cfg.f_max,
            n_mels=mel_cfg.n_mels,
            power=mel_cfg.power,
        )
        self.amplitude_to_db = AmplitudeToDB(top_db=mel_cfg.top_db)

        # Resolve the normalizer by name; unknown names fall back to identity.
        norm_name = mel_cfg.norm
        normalizer_cls = nn.Identity
        for key, candidate in (
            ("mean_std", MeanStdNorm),
            ("min_max", MinMaxNorm),
            ("simple", SimpleNorm),
        ):
            if norm_name == key:
                normalizer_cls = candidate
                break
        self.normalizer = normalizer_cls()

    def forward(self, x):
        """
        Run the feature extraction pipeline on a batch of audio.
        Args:
            x (torch.Tensor): Input audio data; it is cast to float before
                the spectrogram is computed.
        Returns:
            torch.Tensor: Normalized dB-scaled mel spectrogram.
        """
        # Autocast is explicitly disabled for the whole pipeline; the call
        # signature depends on which autocast implementation was imported.
        if torch_amp_new:
            amp_disabled = autocast("cuda", enabled=False)
        else:
            amp_disabled = autocast(enabled=False)

        with amp_disabled:
            features = self.audio2melspec(x.float())
            features = self.amplitude_to_db(features)
            features = self.normalizer(features)
        return features
class MinMaxNorm(nn.Module):
    """Rescale each sample so its minimum maps to 0 and its maximum to ~1."""

    def __init__(self, eps=1e-6):
        """
        Module for performing min-max normalization on input data.
        Args:
            eps (float, optional): Small value to avoid division by zero. Defaults to 1e-6.
        """
        super().__init__()
        self.eps = eps

    def forward(self, X):
        """
        Forward pass of the min-max normalization module.

        The minimum and maximum are taken jointly over dimensions 1 and 2, so
        every index along dimension 0 is normalized independently.
        Args:
            X (torch.Tensor): Input data with at least three dimensions.
        Returns:
            torch.Tensor: ``(X - min) / (max - min + eps)``; a constant input
            yields all zeros.
        """
        # amin -> minimum, amax -> maximum. Swapping them inverts the output
        # and makes eps reduce the denominator's magnitude instead of guarding it.
        min_ = torch.amin(X, dim=(1, 2), keepdim=True)
        max_ = torch.amax(X, dim=(1, 2), keepdim=True)
        return (X - min_) / (max_ - min_ + self.eps)
class SimpleNorm(nn.Module):
    """Fixed affine rescaling: subtract 40, then divide by 80."""

    def __init__(self):
        """
        Module for performing simple normalization on input data.
        """
        super().__init__()

    def forward(self, x):
        """
        Apply the fixed shift-and-scale to the input.
        Args:
            x (torch.Tensor): Input data.
        Returns:
            torch.Tensor: ``(x - 40) / 80``, computed element-wise.
        """
        shifted = x - 40
        scaled = shifted / 80
        return scaled
class MeanStdNorm(nn.Module):
    """Standardize each sample using its own mean and standard deviation."""

    def __init__(self, eps=1e-6):
        """
        Module for performing mean and standard deviation normalization on input data.
        Args:
            eps (float, optional): Small value added to the standard deviation
                to avoid division by zero. Defaults to 1e-6.
        """
        super().__init__()
        self.eps = eps

    def forward(self, X):
        """
        Subtract the per-sample mean and divide by the per-sample std.

        The mean is taken over dimensions 1 and 2; the standard deviation is
        taken over everything after dimension 0, flattened into one axis.
        Args:
            X (torch.Tensor): Input data.
        Returns:
            torch.Tensor: ``(X - mean) / (std + eps)``.
        """
        center = X.mean(dim=(1, 2), keepdim=True)

        # Flatten everything after dim 0, then restore two singleton axes so
        # the result broadcasts against X.
        batch = X.size(0)
        flat = X.reshape(batch, -1)
        spread = flat.std(dim=1, keepdim=True)[:, :, None]

        centered = X - center
        return centered / (spread + self.eps)