import math
from typing import Tuple

import torch
import torch.nn as nn
from torch import Tensor
from torchaudio.transforms import SpecAugment
from torchvision.transforms import functional as F


class AugmentLayer(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg

        # Initialize MixUp
        self.mixup = MixUp(
            alpha=cfg.augment.mixup_alpha,
            num_classes=cfg.num_classes,
            p=cfg.augment.mixup_p,
            inplace=True,
        )

        # Initialize other augmentations.
        # Note: SpecAugment's ``p`` is the maximum proportion of time steps
        # that can be masked, not an apply-probability.
        self.time_freq_mask = SpecAugment(
            n_time_masks=cfg.augment.n_time_masks,
            time_mask_param=cfg.augment.time_mask_param,
            n_freq_masks=cfg.augment.n_freq_masks,
            freq_mask_param=cfg.augment.freq_mask_param,
            p=cfg.augment.time_freq_mask_p,
            zero_masking=True,
        )

    def forward(self, spec, y=None):
        # Apply MixUp (the CutMix class below could be combined with it via
        # RandomChoice, but only MixUp is wired up here)
        if y is not None:
            # img = spec.unsqueeze(1)  # shape: (batch_size, 1, n_mels, n_frames)
            spec, y = self.mixup(spec, y)
            # spec = img.squeeze(1)  # shape: (batch_size, n_mels, n_frames)

        # Apply TimeMasking and FrequencyMasking
        spec = self.time_freq_mask(spec)

        return spec, y


class MixUp(torch.nn.Module):
    """Randomly apply MixUp to the provided batch and targets.

    The class implements the data augmentations as described in the paper
    `"mixup: Beyond Empirical Risk Minimization" <https://arxiv.org/abs/1710.09412>`_.

    Args:
        num_classes (int): number of classes used for one-hot encoding.
        p (float): probability of the batch being transformed. Default value is 0.5.
        alpha (float): hyperparameter of the Beta distribution used for mixup.
            Default value is 1.0.
        inplace (bool): boolean to make this transform inplace. Default set to False.
    """

    def __init__(
        self,
        num_classes: int,
        p: float = 0.5,
        alpha: float = 1.0,
        inplace: bool = False,
    ) -> None:
        super().__init__()

        if num_classes < 1:
            raise ValueError(
                f"Please provide a valid positive value for the num_classes. Got num_classes={num_classes}"
            )

        if alpha <= 0:
            raise ValueError("Alpha param must be positive.")

        self.num_classes = num_classes
        self.p = p
        self.alpha = alpha
        self.inplace = inplace

    def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
        """
        Args:
            batch (Tensor): Float tensor of size (B, F, T) or (B, N)
            target (Tensor): Integer tensor of size (B,)

        Returns:
            Tuple[Tensor, Tensor]: Randomly transformed batch and target.
        """
        if batch.ndim != 3 and batch.ndim != 2:
            raise ValueError(
                f"Batch ndim should be 3 (b, f, t) or 2 (b, n). Got {batch.ndim}"
            )
        if target.ndim != 1:
            raise ValueError(f"Target ndim should be 1. Got {target.ndim}")
        if not batch.is_floating_point():
            raise TypeError(f"Batch dtype should be a float tensor. Got {batch.dtype}.")
        if target.dtype != torch.int64 and self.num_classes > 1:
            raise TypeError(f"Target dtype should be torch.int64. Got {target.dtype}")

        if not self.inplace:
            batch = batch.clone()
            target = target.clone()

        if target.ndim == 1 and self.num_classes > 1:
            target = torch.nn.functional.one_hot(target, num_classes=self.num_classes)
            target = target.to(dtype=batch.dtype)

        if torch.rand(1).item() >= self.p:
            return batch, target

        # It's faster to roll the batch by one instead of shuffling it to create sample pairs
        batch_rolled = batch.roll(1, 0)
        target_rolled = target.roll(1, 0)

        # Implemented as in the mixup paper, page 3.
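        # The first component of a Dirichlet([alpha, alpha]) draw is
        # distributed as Beta(alpha, alpha), so ``lambda_param`` below is the
        # mixing coefficient lambda ~ Beta(alpha, alpha) from the paper; the
        # private ``torch._sample_dirichlet`` helper is the same trick used in
        # torchvision's reference implementation.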
        lambda_param = float(
            torch._sample_dirichlet(torch.tensor([self.alpha, self.alpha]))[0]
        )
        batch_rolled.mul_(1.0 - lambda_param)
        batch.mul_(lambda_param).add_(batch_rolled)

        target_rolled.mul_(1.0 - lambda_param)
        target.mul_(lambda_param).add_(target_rolled)

        return batch, target

    def __repr__(self) -> str:
        s = (
            f"{self.__class__.__name__}("
            f"num_classes={self.num_classes}"
            f", p={self.p}"
            f", alpha={self.alpha}"
            f", inplace={self.inplace}"
            f")"
        )
        return s


# Todo: height of spec should be 1, adjust it for audio input (bs, n_samples)
class CutMix(torch.nn.Module):
    """Randomly apply CutMix to the provided batch and targets.

    The class implements the data augmentations as described in the paper
    `"CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features"
    <https://arxiv.org/abs/1905.04899>`_.

    Args:
        num_classes (int): number of classes used for one-hot encoding.
        p (float): probability of the batch being transformed. Default value is 0.5.
        alpha (float): hyperparameter of the Beta distribution used for cutmix.
            Default value is 1.0.
        inplace (bool): boolean to make this transform inplace. Default set to False.
    """

    def __init__(
        self,
        num_classes: int,
        p: float = 0.5,
        alpha: float = 1.0,
        inplace: bool = False,
    ) -> None:
        super().__init__()

        if num_classes < 1:
            raise ValueError(
                "Please provide a valid positive value for the num_classes."
            )

        if alpha <= 0:
            raise ValueError("Alpha param must be positive.")

        self.num_classes = num_classes
        self.p = p
        self.alpha = alpha
        self.inplace = inplace

    def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
        """
        Args:
            batch (Tensor): Float tensor of size (B, C, H, W)
            target (Tensor): Integer tensor of size (B,)

        Returns:
            Tuple[Tensor, Tensor]: Randomly transformed batch and target.
        """
        if batch.ndim != 4:
            raise ValueError(f"Batch ndim should be 4. Got {batch.ndim}")
        if target.ndim != 1:
            raise ValueError(f"Target ndim should be 1. Got {target.ndim}")
        if not batch.is_floating_point():
            raise TypeError(f"Batch dtype should be a float tensor. Got {batch.dtype}.")
        if target.dtype != torch.int64 and self.num_classes > 1:
            raise TypeError(f"Target dtype should be torch.int64. Got {target.dtype}")

        if not self.inplace:
            batch = batch.clone()
            target = target.clone()

        if target.ndim == 1 and self.num_classes > 1:
            target = torch.nn.functional.one_hot(target, num_classes=self.num_classes)
            target = target.to(dtype=batch.dtype)

        if torch.rand(1).item() >= self.p:
            return batch, target

        # It's faster to roll the batch by one instead of shuffling it to create image pairs
        batch_rolled = batch.roll(1, 0)
        target_rolled = target.roll(1, 0)

        # Implemented as in the cutmix paper, page 12 (with minor corrections on typos).
        lambda_param = float(
            torch._sample_dirichlet(torch.tensor([self.alpha, self.alpha]))[0]
        )
        _, H, W = F.get_dimensions(batch)

        # Sample a random box center, then clip a box of area (1 - lambda) * H * W
        # to the image bounds.
        r_x = torch.randint(W, (1,))
        r_y = torch.randint(H, (1,))

        r = 0.5 * math.sqrt(1.0 - lambda_param)
        r_w_half = int(r * W)
        r_h_half = int(r * H)

        x1 = int(torch.clamp(r_x - r_w_half, min=0))
        y1 = int(torch.clamp(r_y - r_h_half, min=0))
        x2 = int(torch.clamp(r_x + r_w_half, max=W))
        y2 = int(torch.clamp(r_y + r_h_half, max=H))

        batch[:, :, y1:y2, x1:x2] = batch_rolled[:, :, y1:y2, x1:x2]
        # Recompute lambda from the actual (clipped) box area.
        lambda_param = float(1.0 - (x2 - x1) * (y2 - y1) / (W * H))

        target_rolled.mul_(1.0 - lambda_param)
        target.mul_(lambda_param).add_(target_rolled)

        return batch, target

    def __repr__(self) -> str:
        s = (
            f"{self.__class__.__name__}("
            f"num_classes={self.num_classes}"
            f", p={self.p}"
            f", alpha={self.alpha}"
            f", inplace={self.inplace}"
            f")"
        )
        return s
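
# --- Minimal usage sketch (illustrative only) --------------------------------
# The block below shows how ``AugmentLayer`` expects to be driven. The config
# attribute names come from the constructor above, but the concrete values and
# the ``SimpleNamespace`` stand-in are assumptions; in the real project the
# config would come from the training setup.
if __name__ == "__main__":
    from types import SimpleNamespace

    cfg = SimpleNamespace(
        num_classes=10,
        augment=SimpleNamespace(
            mixup_alpha=1.0,
            mixup_p=0.5,
            n_time_masks=2,
            time_mask_param=30,
            n_freq_masks=2,
            freq_mask_param=10,
            time_freq_mask_p=1.0,
        ),
    )

    layer = AugmentLayer(cfg)
    spec = torch.randn(8, 64, 200)  # (batch_size, n_mels, n_frames)
    y = torch.randint(0, cfg.num_classes, (8,))

    # Targets come back one-hot (and possibly mixed), whether or not MixUp fired.
    out_spec, out_y = layer(spec, y)
    print(out_spec.shape, out_y.shape)  # torch.Size([8, 64, 200]) torch.Size([8, 10])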