# xvector-voxceleb1 / audio_processing.py
import math
from packaging import version
from dataclasses import dataclass
from abc import ABC, abstractmethod
import torch
try:
import torchaudio
import torchaudio.functional
import torchaudio.transforms
TORCHAUDIO_VERSION = version.parse(torchaudio.__version__)
TORCHAUDIO_VERSION_MIN = version.parse('0.5')
HAVE_TORCHAUDIO = True
except ModuleNotFoundError:
HAVE_TORCHAUDIO = False
from .logging import logger
from .module import NeuralModule
from .features import FilterbankFeatures, FilterbankFeaturesTA
from .spectrogram_augment import SpecCutout, SpecAugment
class AudioPreprocessor(NeuralModule, ABC):
"""
    An interface for Neural Modules that perform audio pre-processing,
    transforming wav files to features.
"""
def __init__(self, win_length, hop_length):
super().__init__()
self.win_length = win_length
self.hop_length = hop_length
self.torch_windows = {
'hann': torch.hann_window,
'hamming': torch.hamming_window,
'blackman': torch.blackman_window,
'bartlett': torch.bartlett_window,
'ones': torch.ones,
None: torch.ones,
}
        # Normally, calling to(dtype) on a torch.nn.Module casts all of its
        # floating point parameters and buffers to that dtype. The
        # AudioPreprocessor classes, however, have no parameters or buffers
        # of their own. We want the input to the preprocessor to stay
        # float32 while the output is produced in the module's current
        # precision, so we register this empty, non-persistent buffer purely
        # to detect which dtype this module should output at the end of
        # execution.
self.register_buffer("dtype_sentinel_tensor", torch.tensor((), dtype=torch.float32), persistent=False)
@torch.no_grad()
def forward(self, input_signal, length):
processed_signal, processed_length = self.get_features(input_signal.to(torch.float32), length)
processed_signal = processed_signal.to(self.dtype_sentinel_tensor.dtype)
return processed_signal, processed_length
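    # A minimal usage sketch (illustrative only; `wav` and `wav_len` are
    # assumed float32 inputs as in forward()): after calling .half() on the
    # module, the non-persistent sentinel buffer becomes float16, so
    # forward() computes features in float32 but returns them as float16:
    #
    #     preproc = AudioToMelSpectrogramPreprocessor().half()
    #     feats, feat_len = preproc(input_signal=wav, length=wav_len)
    #     assert feats.dtype == torch.float16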
@abstractmethod
def get_features(self, input_signal, length):
# Called by forward(). Subclasses should implement this.
pass
class AudioToMelSpectrogramPreprocessor(AudioPreprocessor):
"""Featurizer module that converts wavs to mel spectrograms.
Args:
sample_rate (int): Sample rate of the input audio data.
Defaults to 16000
window_size (float): Size of window for fft in seconds
Defaults to 0.02
window_stride (float): Stride of window for fft in seconds
Defaults to 0.01
n_window_size (int): Size of window for fft in samples
Defaults to None. Use one of window_size or n_window_size.
n_window_stride (int): Stride of window for fft in samples
Defaults to None. Use one of window_stride or n_window_stride.
window (str): Windowing function for fft. can be one of ['hann',
'hamming', 'blackman', 'bartlett']
Defaults to "hann"
normalize (str): Can be one of ['per_feature', 'all_features']; all
other options disable feature normalization. 'all_features'
normalizes the entire spectrogram to be mean 0 with std 1.
            'per_feature' normalizes per channel / freq instead.
Defaults to "per_feature"
n_fft (int): Length of FT window. If None, it uses the smallest power
of 2 that is larger than n_window_size.
Defaults to None
preemph (float): Amount of pre emphasis to add to audio. Can be
disabled by passing None.
Defaults to 0.97
features (int): Number of mel spectrogram freq bins to output.
Defaults to 64
lowfreq (int): Lower bound on mel basis in Hz.
Defaults to 0
        highfreq (int): Upper bound on mel basis in Hz.
Defaults to None
log (bool): Log features.
Defaults to True
        log_zero_guard_type (str): Guards against taking the log of zero.
            There are two options: "add" or "clamp".
            Defaults to "add".
        log_zero_guard_value (float or str): The value to add with or clamp
            to. Can be a float, or the strings "tiny" or "eps", in which
            case the value is taken from torch.finfo.
            Defaults to 2**-24.
dither (float): Amount of white-noise dithering.
Defaults to 1e-5
pad_to (int): Ensures that the output size of the time dimension is
a multiple of pad_to.
Defaults to 16
frame_splicing (int): Defaults to 1
exact_pad (bool): If True, sets stft center to False and adds padding, such that num_frames = audio_length
// hop_length. Defaults to False.
pad_value (float): The value that shorter mels are padded with.
Defaults to 0
mag_power (float): The power that the linear spectrogram is raised to
prior to multiplication with mel basis.
Defaults to 2 for a power spec
        rng: Random number generator.
        nb_augmentation_prob (float): Probability with which narrowband
            augmentation is applied to samples in the batch.
            Defaults to 0.0
        nb_max_freq (int): Frequency above which all frequencies are masked
            for narrowband augmentation.
            Defaults to 4000
        use_torchaudio: Whether to use the `torchaudio` implementation. Note
            that this class currently always uses the torchaudio-based
            FilterbankFeaturesTA featurizer regardless of this flag.
mel_norm: Normalization used for mel filterbank weights.
Defaults to 'slaney' (area normalization)
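    Example:
        A minimal usage sketch (shapes and values are illustrative, not part
        of the API):

            preprocessor = AudioToMelSpectrogramPreprocessor(features=80)
            wav = torch.randn(4, 16000)        # (batch, samples), 1 s at 16 kHz
            wav_len = torch.full((4,), 16000)
            mel, mel_len = preprocessor(input_signal=wav, length=wav_len)
            # mel: (4, 80, T), mel_len: (4,)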
"""
def __init__(
self,
sample_rate=16000,
window_size=0.02,
window_stride=0.01,
n_window_size=None,
n_window_stride=None,
window="hann",
normalize="per_feature",
n_fft=None,
preemph=0.97,
features=64,
lowfreq=0,
highfreq=None,
log=True,
log_zero_guard_type="add",
log_zero_guard_value=2**-24,
dither=1e-5,
pad_to=16,
frame_splicing=1,
exact_pad=False,
pad_value=0,
mag_power=2.0,
rng=None,
nb_augmentation_prob=0.0,
nb_max_freq=4000,
use_torchaudio: bool = False,
mel_norm="slaney",
):
        self._sample_rate = sample_rate
        if window_size and n_window_size:
            raise ValueError(
                f"{type(self).__name__} received both window_size and n_window_size. Only one should be specified."
            )
        if window_stride and n_window_stride:
            raise ValueError(
                f"{type(self).__name__} received both window_stride and n_window_stride. Only one should be specified."
            )
        # Convert the second-based window arguments to samples before
        # initializing win_length / hop_length on the base class.
        if window_size:
            n_window_size = int(window_size * self._sample_rate)
        if window_stride:
            n_window_stride = int(window_stride * self._sample_rate)
        super().__init__(n_window_size, n_window_stride)
        # Given the long and similar argument list, point to the class and instantiate it by reference
        if not use_torchaudio:
            logger.warning(
                "Only the torchaudio-based FilterbankFeaturesTA featurizer is "
                "currently supported; ignoring use_torchaudio=False."
            )
        featurizer_class = FilterbankFeaturesTA
self.featurizer = featurizer_class(
sample_rate=self._sample_rate,
n_window_size=n_window_size,
n_window_stride=n_window_stride,
window=window,
normalize=normalize,
n_fft=n_fft,
preemph=preemph,
nfilt=features,
lowfreq=lowfreq,
highfreq=highfreq,
log=log,
log_zero_guard_type=log_zero_guard_type,
log_zero_guard_value=log_zero_guard_value,
dither=dither,
pad_to=pad_to,
frame_splicing=frame_splicing,
exact_pad=exact_pad,
pad_value=pad_value,
mag_power=mag_power,
rng=rng,
nb_augmentation_prob=nb_augmentation_prob,
nb_max_freq=nb_max_freq,
mel_norm=mel_norm,
)
def get_features(self, input_signal, length):
        return self.featurizer(input_signal, length)  # returns a tensor of shape (B, D, T)
@property
def filter_banks(self):
return self.featurizer.filter_banks
class AudioToMFCCPreprocessor(AudioPreprocessor):
"""Preprocessor that converts wavs to MFCCs.
Uses torchaudio.transforms.MFCC.
Args:
sample_rate: The sample rate of the audio.
Defaults to 16000.
window_size: Size of window for fft in seconds. Used to calculate the
win_length arg for mel spectrogram.
Defaults to 0.02
        window_stride: Stride of window for fft in seconds. Used to calculate
            the hop_length arg for the mel spectrogram.
Defaults to 0.01
n_window_size: Size of window for fft in samples
Defaults to None. Use one of window_size or n_window_size.
n_window_stride: Stride of window for fft in samples
Defaults to None. Use one of window_stride or n_window_stride.
        window: Windowing function for fft. Can be one of ['hann',
            'hamming', 'blackman', 'bartlett', 'ones']; pass 'ones' or None
            for no window function.
            Defaults to 'hann'
n_fft: Length of FT window. If None, it uses the smallest power of 2
that is larger than n_window_size.
Defaults to None
lowfreq (int): Lower bound on mel basis in Hz.
Defaults to 0
        highfreq (int): Upper bound on mel basis in Hz.
Defaults to None
n_mels: Number of mel filterbanks.
Defaults to 64
n_mfcc: Number of coefficients to retain
Defaults to 64
dct_type: Type of discrete cosine transform to use
norm: Type of norm to use
log: Whether to use log-mel spectrograms instead of db-scaled.
Defaults to True.
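    Example:
        A minimal usage sketch (assumes torchaudio is installed; shapes are
        illustrative):

            preprocessor = AudioToMFCCPreprocessor(n_mfcc=40)
            wav = torch.randn(2, 16000)
            wav_len = torch.tensor([16000, 12000])
            mfcc, mfcc_len = preprocessor(input_signal=wav, length=wav_len)
            # mfcc: (2, 40, T), mfcc_len: (2,)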
"""
def __init__(
self,
sample_rate=16000,
window_size=0.02,
window_stride=0.01,
n_window_size=None,
n_window_stride=None,
window='hann',
n_fft=None,
lowfreq=0.0,
highfreq=None,
n_mels=64,
n_mfcc=64,
dct_type=2,
norm='ortho',
log=True,
):
self._sample_rate = sample_rate
        if not HAVE_TORCHAUDIO:
            raise ModuleNotFoundError(
                "torchaudio is not installed but is necessary for "
                "AudioToMFCCPreprocessor. We recommend you try "
                "building it from source for the PyTorch version you have."
            )
        if window_size and n_window_size:
            raise ValueError(
                f"{type(self).__name__} received both window_size and n_window_size. Only one should be specified."
            )
        if window_stride and n_window_stride:
            raise ValueError(
                f"{type(self).__name__} received both window_stride and n_window_stride. Only one should be specified."
            )
# Get win_length (n_window_size) and hop_length (n_window_stride)
if window_size:
n_window_size = int(window_size * self._sample_rate)
if window_stride:
n_window_stride = int(window_stride * self._sample_rate)
super().__init__(n_window_size, n_window_stride)
mel_kwargs = {}
mel_kwargs['f_min'] = lowfreq
mel_kwargs['f_max'] = highfreq
mel_kwargs['n_mels'] = n_mels
mel_kwargs['n_fft'] = n_fft or 2 ** math.ceil(math.log2(n_window_size))
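        # e.g. with the defaults (window_size=0.02, sample_rate=16000),
        # n_window_size is 320 samples, so n_fft becomes 512.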
mel_kwargs['win_length'] = n_window_size
mel_kwargs['hop_length'] = n_window_stride
# Set window_fn. None defaults to torch.ones.
window_fn = self.torch_windows.get(window, None)
if window_fn is None:
raise ValueError(
f"Window argument for AudioProcessor is invalid: {window}."
f"For no window function, use 'ones' or None."
)
mel_kwargs['window_fn'] = window_fn
# Use torchaudio's implementation of MFCCs as featurizer
self.featurizer = torchaudio.transforms.MFCC(
sample_rate=self._sample_rate,
n_mfcc=n_mfcc,
dct_type=dct_type,
norm=norm,
log_mels=log,
melkwargs=mel_kwargs,
)
def get_features(self, input_signal, length):
features = self.featurizer(input_signal)
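        # Estimate the number of output frames as ceil(num_samples / hop_length).
        # (torchaudio's MFCC uses center=True by default, which may yield one
        # extra frame relative to this estimate.)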
seq_len = torch.ceil(length.to(torch.float32) / self.hop_length).to(dtype=torch.long)
return features, seq_len
class SpectrogramAugmentation(NeuralModule):
"""
Performs time and freq cuts in one of two ways.
SpecAugment zeroes out vertical and horizontal sections as described in
SpecAugment (https://arxiv.org/abs/1904.08779). Arguments for use with
SpecAugment are `freq_masks`, `time_masks`, `freq_width`, and `time_width`.
    SpecCutout zeroes out rectangular regions as described in Cutout
(https://arxiv.org/abs/1708.04552). Arguments for use with Cutout are
`rect_masks`, `rect_freq`, and `rect_time`.
Args:
        freq_masks (int): how many frequency segments should be cut.
            Defaults to 0.
        time_masks (int): how many time segments should be cut.
            Defaults to 0.
        freq_width (int): maximum number of frequencies to be cut in one
            segment.
            Defaults to 10.
        time_width (int): maximum number of time steps to be cut in one
            segment.
            Defaults to 10.
        rect_masks (int): how many rectangular masks should be cut.
            Defaults to 0.
        rect_time (int): maximum size of cut rectangles along the time
            dimension.
            Defaults to 5.
        rect_freq (int): maximum size of cut rectangles along the frequency
            dimension.
            Defaults to 20.
        rng: Random number generator.
            Defaults to None.
        mask_value (float): value to fill masked regions with.
            Defaults to 0.0.
        use_vectorized_spec_augment (bool): use vectorized code for SpecAugment.
            Defaults to True.
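    Example:
        A minimal usage sketch (input of shape (B, D, T); values are
        illustrative):

            augment = SpectrogramAugmentation(freq_masks=2, time_masks=2)
            spec = torch.randn(4, 80, 200)
            spec_len = torch.full((4,), 200)
            augmented = augment(input_spec=spec, length=spec_len)
            # augmented: (4, 80, 200) with masked regions set to mask_value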
"""
def __init__(
self,
freq_masks=0,
time_masks=0,
freq_width=10,
time_width=10,
rect_masks=0,
rect_time=5,
rect_freq=20,
rng=None,
mask_value=0.0,
use_vectorized_spec_augment: bool = True,
):
super().__init__()
if rect_masks > 0:
self.spec_cutout = SpecCutout(
rect_masks=rect_masks,
rect_time=rect_time,
rect_freq=rect_freq,
rng=rng,
)
else:
self.spec_cutout = lambda input_spec: input_spec
if freq_masks + time_masks > 0:
self.spec_augment = SpecAugment(
freq_masks=freq_masks,
time_masks=time_masks,
freq_width=freq_width,
time_width=time_width,
rng=rng,
mask_value=mask_value,
use_vectorized_code=use_vectorized_spec_augment,
)
else:
self.spec_augment = lambda input_spec, length: input_spec
def forward(self, input_spec, length):
augmented_spec = self.spec_cutout(input_spec=input_spec)
augmented_spec = self.spec_augment(input_spec=augmented_spec, length=length)
        return augmented_spec  # returns a tensor of shape (B, D, T)