|
import random |
|
from typing import Union |
|
|
|
import numpy as np |
|
|
|
import torch |
|
import torch.nn as nn |
|
|
|
|
|
class SpecAugment(nn.Module):
    """
    Zeroes out (cuts) random continuous horizontal or
    vertical segments of the spectrogram as described in
    SpecAugment (https://arxiv.org/abs/1904.08779).

    The spectrogram is indexed as (batch, frequency, time), see
    ``FREQ_AXIS`` and ``TIME_AXIS``.

    params:
    freq_masks - how many frequency segments should be cut
    time_masks - how many time segments should be cut
    freq_width - maximum number of frequencies to be cut in one segment
    time_width - maximum number of time steps to be cut in one segment.
        Can be a positive integer or a float value in the range [0, 1].
        If positive integer value, defines maximum number of time steps
        to be cut in one segment.
        If a float value, defines maximum percentage of timesteps that
        are cut adaptively.
    rng - ``random.Random`` instance used by the legacy implementation to
        draw mask positions and widths. A fresh one is created when None.
        The vectorized implementation draws from ``torch.rand`` instead.
    mask_value - value written into the masked cells.
    use_vectorized_code - GPU-based implementation with batched masking and GPU rng,
        setting it to False reverts to the legacy implementation.
        Fast implementation is inspired by torchaudio:
        https://github.com/pytorch/audio/blob/ea437b31ce316ea3d66fe73768c0dcb94edb79ad/src/torchaudio/functional/functional.py#L816
    """

    FREQ_AXIS = 1  # Frequency dimension of a (batch, frequency, time) spectrogram
    TIME_AXIS = 2  # Time dimension of a (batch, frequency, time) spectrogram

    def __init__(
        self,
        freq_masks: int = 0,
        time_masks: int = 0,
        freq_width: int = 10,
        time_width: Union[int, float] = 10,
        rng: Union[random.Random, None] = None,
        mask_value: float = 0.0,
        use_vectorized_code: bool = True,
    ):
        """
        Store the augmentation settings.

        Raises:
            ValueError: if `time_width` is not an int and lies outside [0, 1].
        """
        super().__init__()

        self._rng = random.Random() if rng is None else rng

        self.freq_masks = freq_masks
        self.time_masks = time_masks

        self.freq_width = freq_width
        self.time_width = time_width

        self.mask_value = mask_value
        self.use_vectorized_code = use_vectorized_code

        if isinstance(time_width, int):
            # Absolute width: the same maximum applies to every sample.
            self.adaptive_temporal_width = False
        else:
            if time_width > 1.0 or time_width < 0.0:
                raise ValueError("If `time_width` is a float value, must be in range [0, 1]")

            # Relative width: the maximum is a fraction of each sample's length.
            self.adaptive_temporal_width = True

    @torch.no_grad()
    def forward(self, input_spec, length):
        """
        Apply frequency and time masking to a batch of spectrograms.

        Args:
            input_spec: tensor indexed as (batch, frequency, time).
            length: tensor with one entry per batch element; used as the
                number of time steps within which time masks are placed.

        Returns:
            A new tensor of the same shape with masked cells set to
            `mask_value`; `input_spec` itself is not modified.
        """
        if self.use_vectorized_code:
            return self._forward_vectorized(input_spec, length)
        else:
            return self._forward_legacy(input_spec, length)

    def _forward_legacy(self, input_spec, length):
        """
        Per-sample Python-loop implementation driven by `self._rng`.

        The mask is built on the CPU as a numpy boolean array and moved to
        the device of `input_spec` once, at the end.
        """
        batch_size, num_freq_bins, _ = input_spec.shape

        # Move lengths to CPU once, before they are indexed repeatedly in the loop.
        lengths_cpu = length.cpu().numpy()

        # `True` cells mark where the spectrogram will be overwritten.
        fill_mask: np.ndarray = np.full(shape=input_spec.shape, fill_value=False)

        # Clamp at 0 so that `freq_width` larger than the number of bins does
        # not hand `randint` an empty range (which raises ValueError). This
        # mirrors the clamping already applied to the time axis below.
        freq_start_upper_bound = max(0, num_freq_bins - self.freq_width)

        # Different mask ranges are drawn for each element of the batch.
        for idx in range(batch_size):
            # Frequency masking. `randint` bounds are inclusive on both ends.
            for _ in range(self.freq_masks):
                start = self._rng.randint(0, freq_start_upper_bound)
                width = self._rng.randint(0, self.freq_width)
                fill_mask[idx, start : start + width, :] = True

            # Maximum time width: either a fraction of this sample's length
            # (at least 1) or the configured absolute value.
            if self.adaptive_temporal_width:
                time_max_width = max(1, int(lengths_cpu[idx] * self.time_width))
            else:
                time_max_width = self.time_width
            time_start_upper_bound = max(1, lengths_cpu[idx] - time_max_width)

            # Time masking.
            for _ in range(self.time_masks):
                start = self._rng.randint(0, time_start_upper_bound)
                width = self._rng.randint(0, time_max_width)
                fill_mask[idx, :, start : start + width] = True

        # Bring the mask to the input's device and fill the spectrogram.
        fill_mask = torch.from_numpy(fill_mask).to(input_spec.device)
        masked_spec = input_spec.masked_fill(mask=fill_mask, value=self.mask_value)
        return masked_spec

    def _forward_vectorized(self, input_spec: torch.Tensor, length: torch.Tensor) -> torch.Tensor:
        """Batched implementation: time masks are applied first, then frequency masks."""
        input_spec = self._apply_masks(
            input_spec=input_spec,
            num_masks=self.time_masks,
            length=length,
            width=self.time_width,
            axis=self.TIME_AXIS,
            mask_value=self.mask_value,
        )

        input_spec = self._apply_masks(
            input_spec=input_spec,
            num_masks=self.freq_masks,
            length=length,
            width=self.freq_width,
            axis=self.FREQ_AXIS,
            mask_value=self.mask_value,
        )
        return input_spec

    def _apply_masks(
        self,
        input_spec: torch.Tensor,
        num_masks: int,
        length: torch.Tensor,
        width: Union[int, float],
        mask_value: float,
        axis: int,
    ) -> torch.Tensor:
        """
        Mask `num_masks` random segments along `axis` for every batch element.

        Args:
            input_spec: tensor indexed as (batch, frequency, time).
            num_masks: number of segments drawn per batch element.
            length: per-sample lengths; only used for the time axis.
            width: maximum segment width. A float is interpreted as a
                fraction of each sample's `length` (time axis only).
            mask_value: value written into the masked cells.
            axis: `FREQ_AXIS` or `TIME_AXIS`.

        Returns:
            A new tensor with the selected segments set to `mask_value`.
        """
        assert axis in (self.FREQ_AXIS, self.TIME_AXIS), (
            f"Axis can only be equal to frequency ({self.FREQ_AXIS}) "
            f"or time ({self.TIME_AXIS}). Received: {axis=}"
        )
        assert not (
            isinstance(width, float) and axis == self.FREQ_AXIS
        ), "Float width supported only with time axis."

        batch_size = input_spec.shape[0]
        axis_length = input_spec.shape[axis]

        # Adaptive width: turn the fraction into a per-sample maximum width of
        # shape (batch, 1), capped at the size of the time axis.
        if axis == self.TIME_AXIS and isinstance(width, float):
            width = torch.clamp(width * length, max=axis_length).unsqueeze(1)

        # Draw widths and (relative) start positions, both of shape (batch, num_masks).
        mask_width = torch.rand((batch_size, num_masks), device=input_spec.device, dtype=torch.float32) * width
        mask_width = mask_width.long()
        mask_start = torch.rand((batch_size, num_masks), device=input_spec.device, dtype=torch.float32)

        if axis == self.TIME_AXIS:
            # Time masks are positioned relative to each sample's own length.
            mask_start = mask_start * (length.unsqueeze(1) - mask_width)
        else:
            # Frequency masks are positioned relative to the full axis size.
            mask_start = mask_start * (axis_length - mask_width)

        mask_start = mask_start.long()
        mask_end = mask_start + mask_width

        # Positions along the masked axis, shape (axis_length,).
        indices = torch.arange(axis_length, device=input_spec.device)

        # Broadcast to (batch, num_masks, axis_length): True where a position
        # falls inside the half-open interval [start, end) of a given mask.
        mask_tensor = (indices >= mask_start.unsqueeze(-1)) & (indices < mask_end.unsqueeze(-1))

        # Union over all masks of a sample -> (batch, axis_length).
        mask_tensor = mask_tensor.any(dim=1)

        # Insert a singleton dimension for the axis that is NOT being masked so
        # that `masked_fill` broadcasts the mask across it; this avoids
        # materialising a full-size boolean mask.
        if axis == self.TIME_AXIS:
            mask_ranges = mask_tensor[:, None, :]
        else:
            mask_ranges = mask_tensor[:, :, None]

        return input_spec.masked_fill(mask=mask_ranges, value=mask_value)
|
|
|
|
|
class SpecCutout(nn.Module):
    """
    Zeroes out (cuts) random rectangles in the spectrogram
    as described in (https://arxiv.org/abs/1708.04552).

    params:
    rect_masks - how many rectangular masks should be cut
    rect_freq - maximum size of cut rectangles along the frequency dimension
    rect_time - maximum size of cut rectangles along the time dimension
    rng - ``random.Random`` instance used to draw rectangle positions and
        sizes. A fresh one is created when None.
    """

    def __init__(self, rect_masks=0, rect_time=5, rect_freq=20, rng=None):
        super().__init__()

        self._rng = random.Random() if rng is None else rng

        self.rect_masks = rect_masks
        self.rect_time = rect_time
        self.rect_freq = rect_freq

    @torch.no_grad()
    def forward(self, input_spec):
        """
        Zero out `rect_masks` random rectangles in every batch element.

        Args:
            input_spec: tensor indexed as (batch, frequency, time).

        Returns:
            The same tensor object that was passed in; the rectangles are
            written into `input_spec` in place.
        """
        sh = input_spec.shape

        # Clamp at 0 so that a spectrogram smaller than the maximum rectangle
        # does not hand `randint` an empty range (which raises ValueError).
        # Slices running past the end of the tensor are simply truncated.
        max_freq_start = max(0, sh[1] - self.rect_freq)
        max_time_start = max(0, sh[2] - self.rect_time)

        for idx in range(sh[0]):
            for _ in range(self.rect_masks):
                # `randint` bounds are inclusive on both ends. The draw order
                # (freq start, time start, freq size, time size) is kept so a
                # seeded `rng` reproduces the same masks.
                rect_x = self._rng.randint(0, max_freq_start)
                rect_y = self._rng.randint(0, max_time_start)

                w_x = self._rng.randint(0, self.rect_freq)
                w_y = self._rng.randint(0, self.rect_time)

                input_spec[idx, rect_x : rect_x + w_x, rect_y : rect_y + w_y] = 0.0

        return input_spec