|
import math |
|
from packaging import version |
|
from dataclasses import dataclass |
|
from abc import ABC, abstractmethod |
|
|
|
import torch |
|
|
|
try: |
|
import torchaudio |
|
import torchaudio.functional |
|
import torchaudio.transforms |
|
|
|
TORCHAUDIO_VERSION = version.parse(torchaudio.__version__) |
|
TORCHAUDIO_VERSION_MIN = version.parse('0.5') |
|
|
|
HAVE_TORCHAUDIO = True |
|
except ModuleNotFoundError: |
|
HAVE_TORCHAUDIO = False |
|
|
|
from .logging import logger |
|
from .module import NeuralModule |
|
from .features import FilterbankFeatures, FilterbankFeaturesTA |
|
from .spectrogram_augment import SpecCutout, SpecAugment |
|
|
|
|
|
class AudioPreprocessor(NeuralModule, ABC):
    """
    Base interface for Neural Modules that pre-process audio, transforming
    raw waveforms into feature representations (e.g. spectrograms).
    Subclasses implement :meth:`get_features`.
    """

    def __init__(self, win_length, hop_length):
        super().__init__()

        self.win_length = win_length
        self.hop_length = hop_length

        # Known window functions, keyed by name. 'ones' and None both mean
        # "no windowing" (a rectangular window of all ones).
        self.torch_windows = {
            'hann': torch.hann_window,
            'hamming': torch.hamming_window,
            'blackman': torch.blackman_window,
            'bartlett': torch.bartlett_window,
            'ones': torch.ones,
            None: torch.ones,
        }

        # Empty, non-persistent buffer whose only purpose is tracking the
        # module's floating-point dtype: it follows .to()/.half()/.float()
        # casts but is never written to checkpoints.
        self.register_buffer("dtype_sentinel_tensor", torch.tensor((), dtype=torch.float32), persistent=False)

    @torch.no_grad()
    def forward(self, input_signal, length):
        """Extract features from ``input_signal`` without tracking gradients.

        The waveform is promoted to float32 before feature extraction; the
        resulting features are cast back to the module's current dtype
        (as tracked by ``dtype_sentinel_tensor``).
        """
        features, feature_lengths = self.get_features(input_signal.to(torch.float32), length)
        features = features.to(self.dtype_sentinel_tensor.dtype)
        return features, feature_lengths

    @abstractmethod
    def get_features(self, input_signal, length):
        """Compute features for a batch of waveforms; implemented by subclasses."""
        pass
|
|
|
|
|
class AudioToMelSpectrogramPreprocessor(AudioPreprocessor):
    """Featurizer module that converts wavs to mel spectrograms.

    Args:
        sample_rate (int): Sample rate of the input audio data.
            Defaults to 16000
        window_size (float): Size of window for fft in seconds
            Defaults to 0.02
        window_stride (float): Stride of window for fft in seconds
            Defaults to 0.01
        n_window_size (int): Size of window for fft in samples
            Defaults to None. Use one of window_size or n_window_size.
        n_window_stride (int): Stride of window for fft in samples
            Defaults to None. Use one of window_stride or n_window_stride.
        window (str): Windowing function for fft. can be one of ['hann',
            'hamming', 'blackman', 'bartlett']
            Defaults to "hann"
        normalize (str): Can be one of ['per_feature', 'all_features']; all
            other options disable feature normalization. 'all_features'
            normalizes the entire spectrogram to be mean 0 with std 1.
            'per_feature' normalizes per channel / freq instead.
            Defaults to "per_feature"
        n_fft (int): Length of FT window. If None, it uses the smallest power
            of 2 that is larger than n_window_size.
            Defaults to None
        preemph (float): Amount of pre emphasis to add to audio. Can be
            disabled by passing None.
            Defaults to 0.97
        features (int): Number of mel spectrogram freq bins to output.
            Defaults to 64
        lowfreq (int): Lower bound on mel basis in Hz.
            Defaults to 0
        highfreq (int): Upper bound on mel basis in Hz.
            Defaults to None
        log (bool): Log features.
            Defaults to True
        log_zero_guard_type(str): Need to avoid taking the log of zero. There
            are two options: "add" or "clamp".
            Defaults to "add".
        log_zero_guard_value(float, or str): Add or clamp requires the number
            to add with or clamp to. log_zero_guard_value can either be a float
            or "tiny" or "eps". torch.finfo is used if "tiny" or "eps" is
            passed.
            Defaults to 2**-24.
        dither (float): Amount of white-noise dithering.
            Defaults to 1e-5
        pad_to (int): Ensures that the output size of the time dimension is
            a multiple of pad_to.
            Defaults to 16
        frame_splicing (int): Defaults to 1
        exact_pad (bool): If True, sets stft center to False and adds padding, such that num_frames = audio_length
            // hop_length. Defaults to False.
        pad_value (float): The value that shorter mels are padded with.
            Defaults to 0
        mag_power (float): The power that the linear spectrogram is raised to
            prior to multiplication with mel basis.
            Defaults to 2 for a power spec
        rng : Random number generator
        nb_augmentation_prob (float) : Probability with which narrowband augmentation would be applied to
            samples in the batch.
            Defaults to 0.0
        nb_max_freq (int) : Frequency above which all frequencies will be masked for narrowband augmentation.
            Defaults to 4000
        use_torchaudio: Whether to use the `torchaudio` implementation.
            Note: only the torchaudio-based featurizer (FilterbankFeaturesTA)
            is wired in here; passing False logs a warning and falls back to it.
        mel_norm: Normalization used for mel filterbank weights.
            Defaults to 'slaney' (area normalization)
    """

    def __init__(
        self,
        sample_rate=16000,
        window_size=0.02,
        window_stride=0.01,
        n_window_size=None,
        n_window_stride=None,
        window="hann",
        normalize="per_feature",
        n_fft=None,
        preemph=0.97,
        features=64,
        lowfreq=0,
        highfreq=None,
        log=True,
        log_zero_guard_type="add",
        log_zero_guard_value=2**-24,
        dither=1e-5,
        pad_to=16,
        frame_splicing=1,
        exact_pad=False,
        pad_value=0,
        mag_power=2.0,
        rng=None,
        nb_augmentation_prob=0.0,
        nb_max_freq=4000,
        use_torchaudio: bool = False,
        mel_norm="slaney",
    ):
        # NOTE(review): super() is called before seconds->samples conversion,
        # so self.win_length/self.hop_length stay None when the seconds-based
        # args are used (the featurizer still gets the computed values).
        # Calling super() later is unsafe because the f"{self}" error messages
        # below invoke nn.Module.__repr__. Left as-is; confirm no caller reads
        # these attributes on this class before changing.
        super().__init__(n_window_size, n_window_stride)

        self._sample_rate = sample_rate
        # Window size/stride may be given in seconds OR in samples, not both.
        if window_size and n_window_size:
            raise ValueError(f"{self} received both window_size and " f"n_window_size. Only one should be specified.")
        if window_stride and n_window_stride:
            raise ValueError(
                f"{self} received both window_stride and " f"n_window_stride. Only one should be specified."
            )
        if window_size:
            n_window_size = int(window_size * self._sample_rate)
        if window_stride:
            n_window_stride = int(window_stride * self._sample_rate)

        # Only the torchaudio-backed featurizer is available in this module.
        # The original if/else assigned FilterbankFeaturesTA on both branches,
        # so the else arm was dead code; keep the warning for callers that
        # explicitly requested the non-torchaudio implementation.
        if not use_torchaudio:
            logger.warning("Current only support FilterbankFeatures with torchaudio.")
        self.featurizer = FilterbankFeaturesTA(
            sample_rate=self._sample_rate,
            n_window_size=n_window_size,
            n_window_stride=n_window_stride,
            window=window,
            normalize=normalize,
            n_fft=n_fft,
            preemph=preemph,
            nfilt=features,
            lowfreq=lowfreq,
            highfreq=highfreq,
            log=log,
            log_zero_guard_type=log_zero_guard_type,
            log_zero_guard_value=log_zero_guard_value,
            dither=dither,
            pad_to=pad_to,
            frame_splicing=frame_splicing,
            exact_pad=exact_pad,
            pad_value=pad_value,
            mag_power=mag_power,
            rng=rng,
            nb_augmentation_prob=nb_augmentation_prob,
            nb_max_freq=nb_max_freq,
            mel_norm=mel_norm,
        )

    def get_features(self, input_signal, length):
        """Delegate mel-spectrogram computation to the featurizer.

        Returns the (features, feature_lengths) pair produced by it.
        """
        return self.featurizer(input_signal, length)

    @property
    def filter_banks(self):
        """Mel filterbank matrix held by the underlying featurizer."""
        return self.featurizer.filter_banks
|
|
|
|
|
class AudioToMFCCPreprocessor(AudioPreprocessor):
    """Preprocessor that converts wavs to MFCCs.
    Uses torchaudio.transforms.MFCC.

    Args:
        sample_rate: The sample rate of the audio.
            Defaults to 16000.
        window_size: Size of window for fft in seconds. Used to calculate the
            win_length arg for mel spectrogram.
            Defaults to 0.02
        window_stride: Stride of window for fft in seconds. Used to calculate
            the hop_length arg for mel spect.
            Defaults to 0.01
        n_window_size: Size of window for fft in samples
            Defaults to None. Use one of window_size or n_window_size.
        n_window_stride: Stride of window for fft in samples
            Defaults to None. Use one of window_stride or n_window_stride.
        window: Windowing function for fft. Can be one of ['hann', 'hamming',
            'blackman', 'bartlett', 'ones'] or None; 'ones' and None apply no
            window. Any other value raises ValueError.
            Defaults to 'hann'
        n_fft: Length of FT window. If None, it uses the smallest power of 2
            that is larger than n_window_size.
            Defaults to None
        lowfreq (int): Lower bound on mel basis in Hz.
            Defaults to 0
        highfreq (int): Upper bound on mel basis in Hz.
            Defaults to None
        n_mels: Number of mel filterbanks.
            Defaults to 64
        n_mfcc: Number of coefficients to retain
            Defaults to 64
        dct_type: Type of discrete cosine transform to use
        norm: Type of norm to use
        log: Whether to use log-mel spectrograms instead of db-scaled.
            Defaults to True.

    Raises:
        ModuleNotFoundError: if torchaudio is not installed.
        ValueError: if both seconds- and samples-based window args are given,
            or if ``window`` is not a recognized window name.
    """

    def __init__(
        self,
        sample_rate=16000,
        window_size=0.02,
        window_stride=0.01,
        n_window_size=None,
        n_window_stride=None,
        window='hann',
        n_fft=None,
        lowfreq=0.0,
        highfreq=None,
        n_mels=64,
        n_mfcc=64,
        dct_type=2,
        norm='ortho',
        log=True,
    ):
        self._sample_rate = sample_rate
        # torchaudio is a hard requirement for this class: warn (for log
        # visibility) and then abort.
        if not HAVE_TORCHAUDIO:
            logger.warning('Could not import torchaudio. Some features might not work.')

            raise ModuleNotFoundError(
                "torchaudio is not installed but is necessary for "
                "AudioToMFCCPreprocessor. We recommend you try "
                "building it from source for the PyTorch version you have."
            )
        # Window size/stride may be given in seconds OR in samples, not both.
        if window_size and n_window_size:
            raise ValueError(f"{self} received both window_size and " f"n_window_size. Only one should be specified.")
        if window_stride and n_window_stride:
            raise ValueError(
                f"{self} received both window_stride and " f"n_window_stride. Only one should be specified."
            )

        # Convert seconds-based args to sample counts.
        if window_size:
            n_window_size = int(window_size * self._sample_rate)
        if window_stride:
            n_window_stride = int(window_stride * self._sample_rate)

        # Sets self.win_length / self.hop_length and self.torch_windows.
        super().__init__(n_window_size, n_window_stride)

        # Keyword arguments forwarded to the underlying MelSpectrogram.
        mel_kwargs = {}

        mel_kwargs['f_min'] = lowfreq
        mel_kwargs['f_max'] = highfreq
        mel_kwargs['n_mels'] = n_mels

        # Default FFT size: smallest power of two >= n_window_size.
        mel_kwargs['n_fft'] = n_fft or 2 ** math.ceil(math.log2(n_window_size))

        mel_kwargs['win_length'] = n_window_size
        mel_kwargs['hop_length'] = n_window_stride

        # torch_windows maps None to torch.ones, so only unrecognized window
        # names reach this error.
        window_fn = self.torch_windows.get(window, None)
        if window_fn is None:
            raise ValueError(
                f"Window argument for AudioProcessor is invalid: {window}."
                f"For no window function, use 'ones' or None."
            )
        mel_kwargs['window_fn'] = window_fn

        # torchaudio's MFCC handles the full waveform->MFCC pipeline.
        self.featurizer = torchaudio.transforms.MFCC(
            sample_rate=self._sample_rate,
            n_mfcc=n_mfcc,
            dct_type=dct_type,
            norm=norm,
            log_mels=log,
            melkwargs=mel_kwargs,
        )

    def get_features(self, input_signal, length):
        """Compute MFCC features and per-example output frame counts."""
        features = self.featurizer(input_signal)
        # Frames per example estimated as ceil(num_samples / hop_length).
        # NOTE(review): presumably this matches torchaudio's framing; for a
        # centered STFT the actual count can be num_samples // hop + 1 —
        # confirm against the featurizer's output length.
        seq_len = torch.ceil(length.to(torch.float32) / self.hop_length).to(dtype=torch.long)
        return features, seq_len
|
|
|
|
|
class SpectrogramAugmentation(NeuralModule):
    """
    Performs time and freq cuts in one of two ways.
    SpecAugment zeroes out vertical and horizontal sections as described in
    SpecAugment (https://arxiv.org/abs/1904.08779). Arguments for use with
    SpecAugment are `freq_masks`, `time_masks`, `freq_width`, and `time_width`.
    SpecCutout zeroes out rectangulars as described in Cutout
    (https://arxiv.org/abs/1708.04552). Arguments for use with Cutout are
    `rect_masks`, `rect_freq`, and `rect_time`.

    Args:
        freq_masks (int): how many frequency segments should be cut.
            Defaults to 0.
        time_masks (int): how many time segments should be cut
            Defaults to 0.
        freq_width (int): maximum number of frequencies to be cut in one
            segment.
            Defaults to 10.
        time_width (int): maximum number of time steps to be cut in one
            segment
            Defaults to 10.
        rect_masks (int): how many rectangular masks should be cut
            Defaults to 0.
        rect_freq (int): maximum size of cut rectangles along the frequency
            dimension
            Defaults to 20.
        rect_time (int): maximum size of cut rectangles along the time
            dimension
            Defaults to 5.
        rng: Random number generator passed to the augmentations.
            Defaults to None.
        mask_value (float): value that masked regions of the spectrogram are
            filled with (SpecAugment only).
            Defaults to 0.0.
        use_vectorized_spec_augment: use vectorized code for Spectrogram augmentation

    """

    def __init__(
        self,
        freq_masks=0,
        time_masks=0,
        freq_width=10,
        time_width=10,
        rect_masks=0,
        rect_time=5,
        rect_freq=20,
        rng=None,
        mask_value=0.0,
        use_vectorized_spec_augment: bool = True,
    ):
        super().__init__()

        # When an augmentation is disabled, substitute an identity callable so
        # forward() can apply both stages unconditionally.
        if rect_masks > 0:
            self.spec_cutout = SpecCutout(
                rect_masks=rect_masks,
                rect_time=rect_time,
                rect_freq=rect_freq,
                rng=rng,
            )

        else:
            self.spec_cutout = lambda input_spec: input_spec
        if freq_masks + time_masks > 0:
            self.spec_augment = SpecAugment(
                freq_masks=freq_masks,
                time_masks=time_masks,
                freq_width=freq_width,
                time_width=time_width,
                rng=rng,
                mask_value=mask_value,
                use_vectorized_code=use_vectorized_spec_augment,
            )
        else:
            self.spec_augment = lambda input_spec, length: input_spec

    def forward(self, input_spec, length):
        """Apply SpecCutout, then SpecAugment, to a batch of spectrograms."""
        augmented_spec = self.spec_cutout(input_spec=input_spec)
        augmented_spec = self.spec_augment(input_spec=augmented_spec, length=length)
        return augmented_spec