# xvector-voxceleb1 / spectrogram_augment.py
# Provenance: uploaded by yangwang825 ("Upload XVectorForSequenceClassification", commit 99269d6, verified)
import random
from typing import Union
import numpy as np
import torch
import torch.nn as nn
class SpecAugment(nn.Module):
    """
    Zeroes out (cuts) random continuous horizontal or
    vertical segments of the spectrogram as described in
    SpecAugment (https://arxiv.org/abs/1904.08779).

    params:
    freq_masks - how many frequency segments should be cut
    time_masks - how many time segments should be cut
    freq_width - maximum number of frequencies to be cut in one segment
    time_width - maximum number of time steps to be cut in one segment.
        Can be a positive integer or a float value in the range [0, 1].
        If a positive integer value, defines the maximum number of time steps
        to be cut in one segment.
        If a float value, defines the maximum percentage of timesteps that
        are cut adaptively (per example, relative to its valid length).
    use_vectorized_code - batched masking implementation using torch's RNG;
        setting it to False reverts to the legacy per-example implementation.

    NOTE(review): the two code paths draw randomness from different sources
    (legacy: `self._rng`, vectorized: torch's global RNG), so they do not
    produce identical masks for the same seed.

    Fast implementation is inspired by torchaudio:
    https://github.com/pytorch/audio/blob/ea437b31ce316ea3d66fe73768c0dcb94edb79ad/src/torchaudio/functional/functional.py#L816
    """

    FREQ_AXIS = 1  # Frequency axis in the spectrogram tensor (batch, freq, time)
    TIME_AXIS = 2  # Time axis in the spectrogram tensor

    def __init__(
        self,
        freq_masks: int = 0,
        time_masks: int = 0,
        freq_width: int = 10,
        time_width: Union[int, float] = 10,
        rng: Union[random.Random, None] = None,
        mask_value: float = 0.0,
        use_vectorized_code: bool = True,
    ):
        super().__init__()
        # Private RNG is only consumed by the legacy path (_forward_legacy).
        self._rng = random.Random() if rng is None else rng
        self.freq_masks = freq_masks
        self.time_masks = time_masks
        self.freq_width = freq_width
        self.time_width = time_width
        self.mask_value = mask_value
        self.use_vectorized_code = use_vectorized_code
        # A float time_width means "fraction of each example's valid length";
        # validate it up front so a misconfiguration fails at construction time.
        if isinstance(time_width, int):
            self.adaptive_temporal_width = False
        else:
            if time_width > 1.0 or time_width < 0.0:
                raise ValueError("If `time_width` is a float value, must be in range [0, 1]")
            self.adaptive_temporal_width = True

    @torch.no_grad()
    def forward(self, input_spec, length):
        """Apply frequency and time masking to `input_spec` (batch, freq, time).

        `length` holds each example's number of valid time steps; it bounds
        where time masks may start. Returns the masked spectrogram (gradients
        are disabled for the whole operation).
        """
        if self.use_vectorized_code:
            return self._forward_vectorized(input_spec, length)
        else:
            return self._forward_legacy(input_spec, length)

    def _forward_legacy(self, input_spec, length):
        """Per-example implementation: Python RNG + a numpy mask built on CPU,
        then moved to the input's device for a single masked_fill."""
        batch_size, num_freq_bins, _ = input_spec.shape
        # Move lengths to CPU before repeated indexing
        lengths_cpu = length.cpu().numpy()
        # Generate a numpy boolean mask. `True` elements represent where the input spec will be augmented.
        fill_mask: np.ndarray = np.full(shape=input_spec.shape, fill_value=False)
        freq_start_upper_bound = num_freq_bins - self.freq_width
        # Choose different mask ranges for each element of the batch
        for idx in range(batch_size):
            # Set freq masking: each mask is a band [start, start + width) of rows.
            for _ in range(self.freq_masks):
                start = self._rng.randint(0, freq_start_upper_bound)
                width = self._rng.randint(0, self.freq_width)
                fill_mask[idx, start : start + width, :] = True
            # Derive time width, sometimes based percentage of input length.
            if self.adaptive_temporal_width:
                time_max_width = max(1, int(lengths_cpu[idx] * self.time_width))
            else:
                time_max_width = self.time_width
            # Mask starts are confined to the example's valid region.
            time_start_upper_bound = max(1, lengths_cpu[idx] - time_max_width)
            # Set time masking: each mask is a band [start, start + width) of columns.
            for _ in range(self.time_masks):
                start = self._rng.randint(0, time_start_upper_bound)
                width = self._rng.randint(0, time_max_width)
                fill_mask[idx, :, start : start + width] = True
        # Bring the mask to device and fill spec
        fill_mask = torch.from_numpy(fill_mask).to(input_spec.device)
        masked_spec = input_spec.masked_fill(mask=fill_mask, value=self.mask_value)
        return masked_spec

    def _forward_vectorized(self, input_spec: torch.Tensor, length: torch.Tensor) -> torch.Tensor:
        """Batched implementation: one _apply_masks call per axis, masks for the
        whole batch drawn at once with torch.rand on the input's device."""
        # time masks
        input_spec = self._apply_masks(
            input_spec=input_spec,
            num_masks=self.time_masks,
            length=length,
            width=self.time_width,
            axis=self.TIME_AXIS,
            mask_value=self.mask_value,
        )
        # freq masks
        input_spec = self._apply_masks(
            input_spec=input_spec,
            num_masks=self.freq_masks,
            length=length,
            width=self.freq_width,
            axis=self.FREQ_AXIS,
            mask_value=self.mask_value,
        )
        return input_spec

    def _apply_masks(
        self,
        input_spec: torch.Tensor,
        num_masks: int,
        length: torch.Tensor,
        width: Union[int, float],
        mask_value: float,
        axis: int,
    ) -> torch.Tensor:
        """Draw `num_masks` random bands per batch element along `axis` and fill
        them with `mask_value`. Returns a new tensor (input is not modified)."""
        assert axis in (
            self.FREQ_AXIS,
            self.TIME_AXIS,
        ), f"Axis can be only be equal to frequency \
        ({self.FREQ_AXIS}) or time ({self.TIME_AXIS}). Received: {axis=}"
        assert not (
            isinstance(width, float) and axis == self.FREQ_AXIS
        ), "Float width supported \
        only with time axis."
        batch_size = input_spec.shape[0]
        axis_length = input_spec.shape[axis]
        # If width is float then it is transformed into a tensor
        # of per-example widths (fraction of each valid length, capped at axis size).
        if axis == self.TIME_AXIS and isinstance(width, float):
            width = torch.clamp(width * length, max=axis_length).unsqueeze(1)
        # Generate [0-1) random numbers and then scale the tensors.
        # NOTE(review): `.long()` truncates, so an integer `width` yields mask
        # widths in [0, width - 1] — slightly narrower than the legacy path's
        # inclusive randint(0, width); presumably acceptable for augmentation.
        mask_width = torch.rand((batch_size, num_masks), device=input_spec.device, dtype=torch.float32) * width
        mask_width = mask_width.long()
        mask_start = torch.rand((batch_size, num_masks), device=input_spec.device, dtype=torch.float32)
        if axis == self.TIME_AXIS:
            # length can only be used for the time axis
            mask_start = mask_start * (length.unsqueeze(1) - mask_width)
        else:
            mask_start = mask_start * (axis_length - mask_width)
        mask_start = mask_start.long()
        mask_end = mask_start + mask_width
        # Create mask values using vectorized indexing
        indices = torch.arange(axis_length, device=input_spec.device)
        # Create a mask_tensor with all the indices.
        # The mask_tensor shape is (batch_size, num_masks, axis_length).
        mask_tensor = (indices >= mask_start.unsqueeze(-1)) & (indices < mask_end.unsqueeze(-1))
        # Reduce masks to one mask
        mask_tensor = mask_tensor.any(dim=1)
        # Create a final mask that aligns with the full tensor:
        # broadcast the per-axis band across the other (freq or time) axis.
        mask = torch.zeros_like(input_spec, dtype=torch.bool)
        if axis == self.TIME_AXIS:
            mask_ranges = mask_tensor[:, None, :]
        else:  # axis == self.FREQ_AXIS
            mask_ranges = mask_tensor[:, :, None]
        mask[:, :, :] = mask_ranges
        # Apply the mask value
        return input_spec.masked_fill(mask=mask, value=mask_value)
class SpecCutout(nn.Module):
    """
    Zeroes out (cuts) random rectangles in the spectrogram
    as described in Cutout (https://arxiv.org/abs/1708.04552).

    params:
    rect_masks - how many rectangular masks should be cut per example
    rect_freq - maximum size of cut rectangles along the frequency dimension
    rect_time - maximum size of cut rectangles along the time dimension
    rng - optional random.Random instance for reproducible masks
    """

    def __init__(self, rect_masks=0, rect_time=5, rect_freq=20, rng=None):
        super().__init__()
        # Fall back to a fresh (unseeded) RNG when none is supplied.
        self._rng = random.Random() if rng is None else rng
        self.rect_masks = rect_masks
        self.rect_time = rect_time
        self.rect_freq = rect_freq

    @torch.no_grad()
    def forward(self, input_spec):
        """Zero out `rect_masks` random rectangles, in place, in each batch
        element of `input_spec` (shape: batch x freq x time). Returns the
        same (modified) tensor."""
        batch_size, num_freq, num_time = input_spec.shape
        # Clamp the rectangle extents to the actual spectrogram size so that
        # small inputs (fewer bins than the configured rectangle) do not make
        # randint() fail on a negative upper bound; for inputs at least as
        # large as the configured sizes this is identical to the original.
        max_freq = min(self.rect_freq, num_freq)
        max_time = min(self.rect_time, num_time)
        for idx in range(batch_size):
            for _ in range(self.rect_masks):
                # Draw the rectangle's top-left corner, then its extent
                # (same RNG draw order as before, so a seeded rng reproduces
                # the original masks on previously-valid inputs).
                rect_x = self._rng.randint(0, num_freq - max_freq)
                rect_y = self._rng.randint(0, num_time - max_time)
                w_x = self._rng.randint(0, max_freq)
                w_y = self._rng.randint(0, max_time)
                input_spec[idx, rect_x : rect_x + w_x, rect_y : rect_y + w_y] = 0.0
        return input_spec