import math
from typing import Tuple

import torch
import torch.nn as nn
from torch import Tensor
from torchaudio.transforms import SpecAugment
from torchvision.transforms import functional as F


class AugmentLayer(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        # Initialize MixUp
        self.mixup = MixUp(
            alpha=cfg.augment.mixup_alpha,
            num_classes=cfg.num_classes,
            p=cfg.augment.mixup_p,
            inplace=True,
        )
        # Initialize SpecAugment time/frequency masking
        self.time_freq_mask = SpecAugment(
            n_time_masks=cfg.augment.n_time_masks,
            time_mask_param=cfg.augment.time_mask_param,
            n_freq_masks=cfg.augment.n_freq_masks,
            freq_mask_param=cfg.augment.freq_mask_param,
            p=cfg.augment.time_freq_mask_p,
            zero_masking=True,
        )

    def forward(self, spec, y=None):
        # Apply MixUp when targets are provided (CutMix would need a channel
        # dimension first; see the commented reshapes below)
        if y is not None:
            # img = spec.unsqueeze(1)  # shape: (batch_size, 1, n_mels, n_frames)
            spec, y = self.mixup(spec, y)
            # spec = img.squeeze(1)  # shape: (batch_size, n_mels, n_frames)
        # Apply time and frequency masking
        spec = self.time_freq_mask(spec)
        return spec, y
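

# Usage sketch (added for illustration): the cfg layout below is an assumption
# that mirrors the attribute accesses in AugmentLayer.__init__; it is not part
# of the original module.
def _example_usage() -> None:
    from types import SimpleNamespace

    cfg = SimpleNamespace(
        num_classes=10,
        augment=SimpleNamespace(
            mixup_alpha=1.0,
            mixup_p=0.5,
            n_time_masks=2,
            time_mask_param=30,
            n_freq_masks=2,
            freq_mask_param=13,
            time_freq_mask_p=1.0,
        ),
    )
    layer = AugmentLayer(cfg)
    spec = torch.randn(8, 64, 256)               # (batch, n_mels, n_frames)
    y = torch.randint(0, cfg.num_classes, (8,))  # integer class targets
    spec_aug, y_soft = layer(spec, y)            # y_soft: (8, 10) soft labels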


class MixUp(torch.nn.Module):
    """Randomly apply MixUp to the provided batch and targets.
    The class implements the data augmentations as described in the paper
    `"mixup: Beyond Empirical Risk Minimization" <https://arxiv.org/abs/1710.09412>`_.

    Args:
        num_classes (int): number of classes used for one-hot encoding.
        p (float): probability of the batch being transformed. Default value is 0.5.
        alpha (float): hyperparameter of the Beta distribution used for mixup.
            Default value is 1.0.
        inplace (bool): boolean to make this transform inplace. Default set to False.
    """

    def __init__(
        self,
        num_classes: int,
        p: float = 0.5,
        alpha: float = 1.0,
        inplace: bool = False,
    ) -> None:
        super().__init__()
        if num_classes < 1:
            raise ValueError(
                f"Please provide a valid positive value for the num_classes. Got num_classes={num_classes}"
            )
        if alpha <= 0:
            raise ValueError(f"Alpha param must be positive. Got alpha={alpha}")
        self.num_classes = num_classes
        self.p = p
        self.alpha = alpha
        self.inplace = inplace

    def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
        """
        Args:
            batch (Tensor): Float tensor of size (B, F, T) for spectrograms
                or (B, N) for raw waveforms.
            target (Tensor): Integer tensor of size (B,)
        Returns:
            Tuple[Tensor, Tensor]: Randomly transformed batch and targets.
        """
        if batch.ndim != 3 and batch.ndim != 2:
            raise ValueError(
                f"Batch ndim should be 3 (b, f, t) or 2 (b, n). Got {batch.ndim}"
            )
        if target.ndim != 1:
            raise ValueError(f"Target ndim should be 1. Got {target.ndim}")
        if not batch.is_floating_point():
            raise TypeError(f"Batch dtype should be a float tensor. Got {batch.dtype}.")
        if target.dtype != torch.int64 and self.num_classes > 1:
            raise TypeError(f"Target dtype should be torch.int64. Got {target.dtype}")
        if not self.inplace:
            batch = batch.clone()
            target = target.clone()
        if target.ndim == 1 and self.num_classes > 1:
            target = torch.nn.functional.one_hot(target, num_classes=self.num_classes)
            target = target.to(dtype=batch.dtype)
        if torch.rand(1).item() >= self.p:
            return batch, target
        # It's faster to roll the batch by one instead of shuffling it to create sample pairs
        batch_rolled = batch.roll(1, 0)
        target_rolled = target.roll(1, 0)
        # Implemented as described in the mixup paper, page 3:
        # lambda ~ Beta(alpha, alpha), sampled here via a 2-dim Dirichlet
        lambda_param = float(
            torch._sample_dirichlet(torch.tensor([self.alpha, self.alpha]))[0]
        )
        # Convex combination: x = lambda * x_i + (1 - lambda) * x_j, same for targets
        batch_rolled.mul_(1.0 - lambda_param)
        batch.mul_(lambda_param).add_(batch_rolled)
        target_rolled.mul_(1.0 - lambda_param)
        target.mul_(lambda_param).add_(target_rolled)
        return batch, target

    def __repr__(self) -> str:
        s = (
            f"{self.__class__.__name__}("
            f"num_classes={self.num_classes}"
            f", p={self.p}"
            f", alpha={self.alpha}"
            f", inplace={self.inplace}"
            f")"
        )
        return s
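

# Note (added): torch._sample_dirichlet, used above, is a private PyTorch API.
# A public-API equivalent draws lambda ~ Beta(alpha, alpha) directly, since the
# marginals of Dirichlet([alpha, alpha]) are Beta(alpha, alpha). The helper
# below is a sketch and is not wired into MixUp or CutMix.
def _sample_beta_lambda(alpha: float) -> float:
    return float(torch.distributions.Beta(alpha, alpha).sample())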


# TODO: the cut height should be fixed to 1 for raw audio input of shape
# (bs, n_samples); CutMix below currently assumes a 4-dim (B, C, H, W) batch.
class CutMix(torch.nn.Module):
    """Randomly apply CutMix to the provided batch and targets.
    The class implements the data augmentations as described in the paper
    `"CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features"
    <https://arxiv.org/abs/1905.04899>`_.

    Args:
        num_classes (int): number of classes used for one-hot encoding.
        p (float): probability of the batch being transformed. Default value is 0.5.
        alpha (float): hyperparameter of the Beta distribution used for cutmix.
            Default value is 1.0.
        inplace (bool): boolean to make this transform inplace. Default set to False.
    """

    def __init__(
        self,
        num_classes: int,
        p: float = 0.5,
        alpha: float = 1.0,
        inplace: bool = False,
    ) -> None:
        super().__init__()
        if num_classes < 1:
            raise ValueError(
                f"Please provide a valid positive value for the num_classes. Got num_classes={num_classes}"
            )
        if alpha <= 0:
            raise ValueError(f"Alpha param must be positive. Got alpha={alpha}")
        self.num_classes = num_classes
        self.p = p
        self.alpha = alpha
        self.inplace = inplace

    def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
        """
        Args:
            batch (Tensor): Float tensor of size (B, C, H, W)
            target (Tensor): Integer tensor of size (B,)
        Returns:
            Tuple[Tensor, Tensor]: Randomly transformed batch and targets.
        """
        if batch.ndim != 4:
            raise ValueError(f"Batch ndim should be 4. Got {batch.ndim}")
        if target.ndim != 1:
            raise ValueError(f"Target ndim should be 1. Got {target.ndim}")
        if not batch.is_floating_point():
            raise TypeError(f"Batch dtype should be a float tensor. Got {batch.dtype}.")
        if target.dtype != torch.int64 and self.num_classes > 1:
            raise TypeError(f"Target dtype should be torch.int64. Got {target.dtype}")
        if not self.inplace:
            batch = batch.clone()
            target = target.clone()
        if target.ndim == 1 and self.num_classes > 1:
            target = torch.nn.functional.one_hot(target, num_classes=self.num_classes)
            target = target.to(dtype=batch.dtype)
        if torch.rand(1).item() >= self.p:
            return batch, target
        # It's faster to roll the batch by one instead of shuffling it to create sample pairs
        batch_rolled = batch.roll(1, 0)
        target_rolled = target.roll(1, 0)
        # Implemented as described in the CutMix paper, page 12
        # (with minor corrections on typos)
        lambda_param = float(
            torch._sample_dirichlet(torch.tensor([self.alpha, self.alpha]))[0]
        )
        _, H, W = F.get_dimensions(batch)
        # Pick a random box center, with half-extents chosen so the box covers
        # roughly (1 - lambda) of the image area
        r_x = torch.randint(W, (1,))
        r_y = torch.randint(H, (1,))
        r = 0.5 * math.sqrt(1.0 - lambda_param)
        r_w_half = int(r * W)
        r_h_half = int(r * H)
        x1 = int(torch.clamp(r_x - r_w_half, min=0))
        y1 = int(torch.clamp(r_y - r_h_half, min=0))
        x2 = int(torch.clamp(r_x + r_w_half, max=W))
        y2 = int(torch.clamp(r_y + r_h_half, max=H))
        # Paste the rolled batch into the box, then recompute lambda from the
        # exact (clamped) box area
        batch[:, :, y1:y2, x1:x2] = batch_rolled[:, :, y1:y2, x1:x2]
        lambda_param = float(1.0 - (x2 - x1) * (y2 - y1) / (W * H))
        target_rolled.mul_(1.0 - lambda_param)
        target.mul_(lambda_param).add_(target_rolled)
        return batch, target

    def __repr__(self) -> str:
        s = (
            f"{self.__class__.__name__}("
            f"num_classes={self.num_classes}"
            f", p={self.p}"
            f", alpha={self.alpha}"
            f", inplace={self.inplace}"
            f")"
        )
        return s
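

# Quick demo (added; shapes are assumptions): CutMix on an image-like batch.
# A spectrogram batch would first need a channel axis, as the TODO above notes.
if __name__ == "__main__":
    cutmix = CutMix(num_classes=10, p=1.0, alpha=1.0)
    x = torch.randn(8, 1, 64, 256)           # (B, C, H, W)
    y = torch.randint(0, 10, (8,))
    x_mixed, y_soft = cutmix(x, y)
    print(x_mixed.shape, y_soft.shape)       # torch.Size([8, 1, 64, 256]) torch.Size([8, 10])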