# File size: 8,540 bytes — commit 99269d6
import random
from typing import Union
import numpy as np
import torch
import torch.nn as nn
class SpecAugment(nn.Module):
    """
    Zeroes out (cuts) random continuous horizontal (time) or vertical
    (frequency) segments of the spectrogram, as described in
    SpecAugment (https://arxiv.org/abs/1904.08779).

    params:
    freq_masks - how many frequency segments should be cut
    time_masks - how many time segments should be cut
    freq_width - maximum number of frequencies to be cut in one segment
    time_width - maximum number of time steps to be cut in one segment.
        Can be a positive integer or a float value in the range [0, 1].
        If positive integer value, defines maximum number of time steps
        to be cut in one segment.
        If a float value, defines maximum percentage of timesteps that
        are cut adaptively.
    use_vectorized_code - GPU-based implementation with batched masking and GPU rng,
        setting it to False reverts to the legacy implementation.

    Fast implementation is inspired by torchaudio:
    https://github.com/pytorch/audio/blob/ea437b31ce316ea3d66fe73768c0dcb94edb79ad/src/torchaudio/functional/functional.py#L816
    """

    # Spectrogram tensors are indexed as (batch, frequency, time) throughout.
    FREQ_AXIS = 1  # Frequency axis in the spectrogram tensor
    TIME_AXIS = 2  # Time axis in the spectrogram tensor

    def __init__(
        self,
        freq_masks: int = 0,
        time_masks: int = 0,
        freq_width: int = 10,
        time_width: Union[int, float] = 10,
        rng: Union[random.Random, None] = None,
        mask_value: float = 0.0,
        use_vectorized_code: bool = True,
    ):
        super().__init__()
        # `rng` is only consulted by the legacy path; the vectorized path
        # draws from torch's global generator on the input's device instead.
        self._rng = random.Random() if rng is None else rng
        self.freq_masks = freq_masks
        self.time_masks = time_masks
        self.freq_width = freq_width
        self.time_width = time_width
        self.mask_value = mask_value
        self.use_vectorized_code = use_vectorized_code
        if isinstance(time_width, int):
            self.adaptive_temporal_width = False
        else:
            if time_width > 1.0 or time_width < 0.0:
                raise ValueError("If `time_width` is a float value, must be in range [0, 1]")
            # Float width: the per-example time mask width is derived from
            # each sequence's valid length at forward time.
            self.adaptive_temporal_width = True

    @torch.no_grad()
    def forward(self, input_spec, length):
        """Apply SpecAugment masking.

        Args:
            input_spec: spectrogram tensor of shape (batch, freq, time).
            length: per-example valid lengths along the time axis.

        Returns:
            Tensor of the same shape with masked regions set to `mask_value`.
        """
        if self.use_vectorized_code:
            return self._forward_vectorized(input_spec, length)
        else:
            return self._forward_legacy(input_spec, length)

    def _forward_legacy(self, input_spec, length):
        """CPU-rng reference implementation: builds a numpy bool mask per example."""
        batch_size, num_freq_bins, _ = input_spec.shape
        # Move lengths to CPU before repeated indexing
        lengths_cpu = length.cpu().numpy()
        # Generate a numpy boolean mask. `True` elements represent where the input spec will be augmented.
        fill_mask: np.ndarray = np.full(shape=input_spec.shape, fill_value=False)
        # NOTE(review): this can go negative when num_freq_bins < freq_width,
        # which would make randint raise ValueError — confirm inputs are always
        # at least freq_width bins tall.
        freq_start_upper_bound = num_freq_bins - self.freq_width
        # Choose different mask ranges for each element of the batch
        for idx in range(batch_size):
            # Set freq masking
            for _ in range(self.freq_masks):
                start = self._rng.randint(0, freq_start_upper_bound)
                width = self._rng.randint(0, self.freq_width)
                fill_mask[idx, start : start + width, :] = True
            # Derive time width, sometimes based percentage of input length.
            if self.adaptive_temporal_width:
                time_max_width = max(1, int(lengths_cpu[idx] * self.time_width))
            else:
                time_max_width = self.time_width
            time_start_upper_bound = max(1, lengths_cpu[idx] - time_max_width)
            # Set time masking
            for _ in range(self.time_masks):
                start = self._rng.randint(0, time_start_upper_bound)
                width = self._rng.randint(0, time_max_width)
                fill_mask[idx, :, start : start + width] = True
        # Bring the mask to device and fill spec
        fill_mask = torch.from_numpy(fill_mask).to(input_spec.device)
        masked_spec = input_spec.masked_fill(mask=fill_mask, value=self.mask_value)
        return masked_spec

    def _forward_vectorized(self, input_spec: torch.Tensor, length: torch.Tensor) -> torch.Tensor:
        """Batched masking using torch rng on the input's device."""
        # time masks
        input_spec = self._apply_masks(
            input_spec=input_spec,
            num_masks=self.time_masks,
            length=length,
            width=self.time_width,
            axis=self.TIME_AXIS,
            mask_value=self.mask_value,
        )
        # freq masks
        input_spec = self._apply_masks(
            input_spec=input_spec,
            num_masks=self.freq_masks,
            length=length,
            width=self.freq_width,
            axis=self.FREQ_AXIS,
            mask_value=self.mask_value,
        )
        return input_spec

    def _apply_masks(
        self,
        input_spec: torch.Tensor,
        num_masks: int,
        length: torch.Tensor,
        width: Union[int, float],
        mask_value: float,
        axis: int,
    ) -> torch.Tensor:
        """Draw `num_masks` random spans per example along `axis` and fill them
        with `mask_value`. Returns a new tensor; `input_spec` is not modified."""
        assert axis in (
            self.FREQ_AXIS,
            self.TIME_AXIS,
        ), f"Axis can be only be equal to frequency \
            ({self.FREQ_AXIS}) or time ({self.TIME_AXIS}). Received: {axis=}"
        assert not (
            isinstance(width, float) and axis == self.FREQ_AXIS
        ), "Float width supported \
            only with time axis."
        batch_size = input_spec.shape[0]
        axis_length = input_spec.shape[axis]
        # If width is float then it is transformed into a tensor
        # of per-example widths (fraction of each valid length, capped at
        # the full axis length), shaped (batch, 1) for broadcasting below.
        if axis == self.TIME_AXIS and isinstance(width, float):
            width = torch.clamp(width * length, max=axis_length).unsqueeze(1)
        # Generate [0-1) random numbers and then scale the tensors.
        # Use float32 dtype for begin/end mask markers before they are quantized to long.
        mask_width = torch.rand((batch_size, num_masks), device=input_spec.device, dtype=torch.float32) * width
        mask_width = mask_width.long()
        mask_start = torch.rand((batch_size, num_masks), device=input_spec.device, dtype=torch.float32)
        if axis == self.TIME_AXIS:
            # length can only be used for the time axis
            # NOTE(review): with an integer time_width, mask_width may exceed a
            # short example's length, making (length - mask_width) negative and
            # mask_start negative after quantization — confirm this edge case
            # is acceptable for the shortest expected sequences.
            mask_start = mask_start * (length.unsqueeze(1) - mask_width)
        else:
            mask_start = mask_start * (axis_length - mask_width)
        mask_start = mask_start.long()
        mask_end = mask_start + mask_width
        # Create mask values using vectorized indexing
        indices = torch.arange(axis_length, device=input_spec.device)
        # Create a mask_tensor with all the indices.
        # The mask_tensor shape is (batch_size, num_masks, axis_length).
        mask_tensor = (indices >= mask_start.unsqueeze(-1)) & (indices < mask_end.unsqueeze(-1))
        # Reduce masks to one mask
        mask_tensor = mask_tensor.any(dim=1)
        # Create a final mask that aligns with the full tensor
        mask = torch.zeros_like(input_spec, dtype=torch.bool)
        if axis == self.TIME_AXIS:
            mask_ranges = mask_tensor[:, None, :]
        else:  # axis == self.FREQ_AXIS
            mask_ranges = mask_tensor[:, :, None]
        mask[:, :, :] = mask_ranges
        # Apply the mask value
        return input_spec.masked_fill(mask=mask, value=mask_value)
class SpecCutout(nn.Module):
    """
    Zeroes out (cuts) random rectangles in the spectrogram
    as described in (https://arxiv.org/abs/1708.04552).

    params:
    rect_masks - how many rectangular masks should be cut
    rect_freq - maximum size of cut rectangles along the frequency dimension
    rect_time - maximum size of cut rectangles along the time dimension
    rng - optional random.Random instance; a fresh one is created if omitted.
    """

    def __init__(self, rect_masks=0, rect_time=5, rect_freq=20, rng=None):
        super().__init__()
        self._rng = random.Random() if rng is None else rng
        self.rect_masks = rect_masks
        self.rect_time = rect_time
        self.rect_freq = rect_freq

    @torch.no_grad()
    def forward(self, input_spec):
        """Cut `rect_masks` random rectangles out of each example, in place.

        Args:
            input_spec: spectrogram tensor of shape (batch, freq, time).

        Returns:
            The same tensor, with the cut rectangles zeroed.
        """
        batch_size, num_freq, num_time = input_spec.shape
        # Clamp the start-position upper bounds at 0 so that spectrograms
        # smaller than the maximum rectangle size do not crash randint with a
        # negative bound (the rectangle is anchored at 0 and clipped by the
        # slice instead). Bounds are loop-invariant, so hoist them.
        freq_start_bound = max(0, num_freq - self.rect_freq)
        time_start_bound = max(0, num_time - self.rect_time)
        for idx in range(batch_size):
            for _ in range(self.rect_masks):
                # Draw order (x, y, w_x, w_y) matches the original so a seeded
                # rng produces identical rectangles on valid-size inputs.
                rect_x = self._rng.randint(0, freq_start_bound)
                rect_y = self._rng.randint(0, time_start_bound)
                w_x = self._rng.randint(0, self.rect_freq)
                w_y = self._rng.randint(0, self.rect_time)
                input_spec[idx, rect_x : rect_x + w_x, rect_y : rect_y + w_y] = 0.0
        return input_spec