Spaces:
Running
on
Zero
Running
on
Zero
| import math | |
| import torch | |
| from typing import Sequence | |
| from typing import Union | |
def mask_along_axis(
    spec: torch.Tensor,
    spec_lengths: torch.Tensor,
    mask_width_range: Sequence[int] = (0, 30),
    dim: int = 1,
    num_mask: int = 2,
    replace_with_zero: bool = True,
):
    """Mask randomly placed bands of ``spec`` along the time or frequency axis.

    For every batch item, ``num_mask`` bands are drawn independently; a band's
    width is sampled from ``[mask_width_range[0], mask_width_range[1])`` and its
    start position is sampled so that the widest band drawn in this call still
    fits inside the axis (when the axis is long enough).

    Args:
        spec: (Batch, Length, Freq) or (Batch, Channel, Length, Freq). A 4-D
            input is flattened to (Batch * Channel, Length, Freq) for masking
            and restored to its original shape before returning.
        spec_lengths: Returned unchanged; lengths are not used for masking.
        mask_width_range: ``(low, high)`` band width bounds; ``high`` is
            exclusive.
        dim: Axis of the 3-D view to mask: 1 for Length, 2 for Freq.
        num_mask: Number of bands drawn per batch item. ``0`` (or less)
            disables masking.
        replace_with_zero: Fill masked cells with ``0.0`` if True, otherwise
            with the mean of ``spec``.

    Returns:
        Tuple ``(masked_spec, spec_lengths)``. When ``spec`` does not require
        grad, the fill is done in place, so the input tensor itself is modified
        (unless a non-contiguous 4-D input forced a copy while flattening).

    Raises:
        ValueError: If ``dim`` is neither 1 nor 2.
    """
    # Any other axis would skip both unsqueeze branches below and leave a
    # (B, D) mask that is broadcast against the wrong dimensions.
    if dim not in (1, 2):
        raise ValueError(f"dim must be 1 (time) or 2 (freq): {dim}")

    # Nothing to draw / nothing to fill. This also avoids calling ``.max()``
    # on an empty tensor, which raises in PyTorch.
    if num_mask <= 0 or spec.numel() == 0:
        return spec, spec_lengths

    org_size = spec.size()
    if spec.dim() == 4:
        # (Batch, Channel, Length, Freq) -> (Batch * Channel, Length, Freq).
        # ``reshape`` is a view when possible and only copies for
        # non-contiguous input, where ``view`` would have failed outright.
        spec = spec.reshape(-1, spec.size(2), spec.size(3))

    batch = spec.shape[0]
    # axis_len = Length or Freq
    axis_len = spec.shape[dim]

    # mask_length: (B, num_mask, 1)
    mask_length = torch.randint(
        mask_width_range[0],
        mask_width_range[1],
        (batch, num_mask),
        device=spec.device,
    ).unsqueeze(2)

    # Exclusive upper bound for start positions; at least 1 so that randint
    # always has a non-empty range even when the axis is shorter than a band.
    max_start = max(1, axis_len - int(mask_length.max()))

    # mask_pos: (B, num_mask, 1)
    mask_pos = torch.randint(
        0, max_start, (batch, num_mask), device=spec.device
    ).unsqueeze(2)

    # aran: (1, 1, D)
    aran = torch.arange(axis_len, device=spec.device)[None, None, :]
    # mask: (B, num_mask, D), True inside [pos, pos + length)
    mask = (mask_pos <= aran) & (aran < (mask_pos + mask_length))
    # Union over the bands: (B, num_mask, D) -> (B, D)
    mask = mask.any(dim=1)

    if dim == 1:
        # mask: (B, Length, 1)
        mask = mask.unsqueeze(2)
    else:
        # mask: (B, 1, Freq)
        mask = mask.unsqueeze(1)

    value = 0.0 if replace_with_zero else spec.mean()

    if spec.requires_grad:
        # In-place modification is not safe for a tensor tracked by autograd.
        spec = spec.masked_fill(mask, value)
    else:
        spec = spec.masked_fill_(mask, value)

    return spec.view(org_size), spec_lengths
def mask_along_axis_lfr(
    spec: torch.Tensor,
    spec_lengths: torch.Tensor,
    mask_width_range: Sequence[int] = (0, 30),
    dim: int = 1,
    num_mask: int = 2,
    replace_with_zero: bool = True,
    lfr_rate: int = 1,
):
    """Mask random bands along an axis, replicated across ``lfr_rate`` segments.

    The masked axis is treated as ``lfr_rate`` consecutive segments of
    ``spec.shape[dim] // lfr_rate`` elements each. Band widths and start
    positions are drawn once per batch item for a single segment and the same
    bands are then applied at the same offset inside every segment.
    (Presumably the segments are the frames stacked by a low-frame-rate
    front-end; that is not visible from this function.)

    Args:
        spec: (Batch, Length, Freq) or (Batch, Channel, Length, Freq). A 4-D
            input is flattened to (Batch * Channel, Length, Freq) for masking
            and restored to its original shape before returning.
        spec_lengths: Returned unchanged; lengths are not used for masking.
        mask_width_range: ``(low, high)`` band width bounds; ``high`` is
            exclusive.
        dim: Axis of the 3-D view to mask: 1 for Length, 2 for Freq.
        num_mask: Number of bands drawn per batch item and segment. ``0`` (or
            less) disables masking.
        replace_with_zero: Fill masked cells with ``0.0`` if True, otherwise
            with the mean of ``spec``.
        lfr_rate: Number of equally sized segments the axis is split into.
            Elements left over when the axis length is not a multiple of
            ``lfr_rate`` are not covered by any segment.

    Returns:
        Tuple ``(masked_spec, spec_lengths)``. When ``spec`` does not require
        grad, the fill is done in place, so the input tensor itself is modified
        (unless a non-contiguous 4-D input forced a copy while flattening).

    Raises:
        ValueError: If ``dim`` is neither 1 nor 2, or ``lfr_rate`` is below 1.
    """
    # Any other axis would skip both unsqueeze branches below and leave a
    # (B, D) mask that is broadcast against the wrong dimensions.
    if dim not in (1, 2):
        raise ValueError(f"dim must be 1 (time) or 2 (freq): {dim}")
    # lfr_rate is used as a divisor and as a repeat count.
    if lfr_rate < 1:
        raise ValueError(f"lfr_rate must be a positive integer: {lfr_rate}")

    # Nothing to draw / nothing to fill. This also avoids calling ``.max()``
    # on an empty tensor, which raises in PyTorch.
    if num_mask <= 0 or spec.numel() == 0:
        return spec, spec_lengths

    org_size = spec.size()
    if spec.dim() == 4:
        # (Batch, Channel, Length, Freq) -> (Batch * Channel, Length, Freq).
        # ``reshape`` is a view when possible and only copies for
        # non-contiguous input, where ``view`` would have failed outright.
        spec = spec.reshape(-1, spec.size(2), spec.size(3))

    batch = spec.shape[0]
    axis_len = spec.shape[dim]
    # Length of one segment; bands are sampled inside a single segment.
    seg_len = axis_len // lfr_rate

    # mask_length: (B, num_mask, 1)
    mask_length = torch.randint(
        mask_width_range[0],
        mask_width_range[1],
        (batch, num_mask),
        device=spec.device,
    ).unsqueeze(2)

    # Exclusive upper bound for start positions; at least 1 so that randint
    # always has a non-empty range even when a segment is shorter than a band.
    max_start = max(1, seg_len - int(mask_length.max()))

    # mask_pos: (B, num_mask, 1)
    mask_pos = torch.randint(
        0, max_start, (batch, num_mask), device=spec.device
    ).unsqueeze(2)

    if lfr_rate > 1:
        # Start offset of every segment: (lfr_rate,)
        seg_offset = torch.arange(lfr_rate, device=spec.device) * seg_len
        # (B, 1, num_mask) + (1, lfr_rate, 1) -> (B, lfr_rate, num_mask),
        # flattened segment-major to (B, lfr_rate * num_mask, 1).
        mask_pos = (
            mask_pos.squeeze(2).unsqueeze(1) + seg_offset[None, :, None]
        ).reshape(batch, -1, 1)
        # Tile the widths in the same segment-major order as the positions.
        mask_length = mask_length.repeat(1, lfr_rate, 1)

    # aran: (1, 1, D) over the full axis, not a single segment
    aran = torch.arange(axis_len, device=spec.device)[None, None, :]
    # mask: (B, lfr_rate * num_mask, D), True inside [pos, pos + length)
    mask = (mask_pos <= aran) & (aran < (mask_pos + mask_length))
    # Union over the bands: -> (B, D)
    mask = mask.any(dim=1)

    if dim == 1:
        # mask: (B, Length, 1)
        mask = mask.unsqueeze(2)
    else:
        # mask: (B, 1, Freq)
        mask = mask.unsqueeze(1)

    value = 0.0 if replace_with_zero else spec.mean()

    if spec.requires_grad:
        # In-place modification is not safe for a tensor tracked by autograd.
        spec = spec.masked_fill(mask, value)
    else:
        spec = spec.masked_fill_(mask, value)

    return spec.view(org_size), spec_lengths
class MaskAlongAxis(torch.nn.Module):
    """Module wrapper around ``mask_along_axis`` with a fixed configuration.

    Args:
        mask_width_range: ``(low, high)`` bounds for the sampled band width.
            A single int ``n`` is shorthand for ``(0, n)``.
        num_mask: Number of bands drawn per batch item.
        dim: Axis to mask: ``"time"`` (stored as 1), ``"freq"`` (stored as 2),
            or an int. Ints other than 1 and 2 are accepted here and labelled
            ``"unknown"`` in the repr.
        replace_with_zero: Fill masked cells with zero if True, otherwise with
            the mean of the input.

    Raises:
        TypeError: If ``mask_width_range`` does not have exactly two elements.
        AssertionError: If the upper width bound is not greater than the lower.
        ValueError: If ``dim`` is a string other than ``"time"`` / ``"freq"``.
    """

    def __init__(
        self,
        mask_width_range: Union[int, Sequence[int]] = (0, 30),
        num_mask: int = 2,
        dim: Union[int, str] = "time",
        replace_with_zero: bool = True,
    ):
        # torch.nn.Module must be initialised before attributes are assigned
        # on the instance.
        super().__init__()

        if isinstance(mask_width_range, int):
            mask_width_range = (0, mask_width_range)
        if len(mask_width_range) != 2:
            raise TypeError(
                f"mask_width_range must be a tuple of int and int values: "
                f"{mask_width_range}",
            )
        assert mask_width_range[1] > mask_width_range[0]

        # Normalise the string spelling to the axis index of (Batch, Length, Freq).
        if isinstance(dim, str):
            if dim == "time":
                dim = 1
            elif dim == "freq":
                dim = 2
            else:
                raise ValueError("dim must be int, 'time' or 'freq'")

        # Human-readable axis name, used only by extra_repr().
        if dim == 1:
            self.mask_axis = "time"
        elif dim == 2:
            self.mask_axis = "freq"
        else:
            self.mask_axis = "unknown"

        self.mask_width_range = mask_width_range
        self.num_mask = num_mask
        self.dim = dim
        self.replace_with_zero = replace_with_zero

    def extra_repr(self):
        """Return the configuration summary shown inside ``repr(module)``."""
        return (
            f"mask_width_range={self.mask_width_range}, "
            f"num_mask={self.num_mask}, axis={self.mask_axis}"
        )

    def forward(
        self, spec: torch.Tensor, spec_lengths: Union[torch.Tensor, None] = None
    ):
        """Apply the configured masking to ``spec``.

        Args:
            spec: (Batch, Length, Freq)
            spec_lengths: Passed through to the result unchanged.

        Returns:
            The ``(spec, spec_lengths)`` tuple produced by ``mask_along_axis``.
        """
        return mask_along_axis(
            spec,
            spec_lengths,
            mask_width_range=self.mask_width_range,
            dim=self.dim,
            num_mask=self.num_mask,
            replace_with_zero=self.replace_with_zero,
        )
class MaskAlongAxisVariableMaxWidth(torch.nn.Module):
    """Mask input spec along a specified axis with variable maximum width.

    Formula:
        max_width = max_width_ratio * seq_len

    Args:
        mask_width_ratio_range: ``(low, high)`` ratios of the masked axis
            length that bound the sampled band width. A single number ``r`` is
            shorthand for ``(0.0, r)``.
        num_mask: Number of bands drawn per batch item.
        dim: Axis to mask: ``"time"`` (stored as 1), ``"freq"`` (stored as 2),
            or an int. Ints other than 1 and 2 are accepted here and labelled
            ``"unknown"`` in the repr.
        replace_with_zero: Fill masked cells with zero if True, otherwise with
            the mean of the input.

    Raises:
        TypeError: If ``mask_width_ratio_range`` does not have exactly two
            elements.
        AssertionError: If the upper ratio is not greater than the lower.
        ValueError: If ``dim`` is a string other than ``"time"`` / ``"freq"``.
    """

    def __init__(
        self,
        mask_width_ratio_range: Union[float, Sequence[float]] = (0.0, 0.05),
        num_mask: int = 2,
        dim: Union[int, str] = "time",
        replace_with_zero: bool = True,
    ):
        # torch.nn.Module must be initialised before attributes are assigned
        # on the instance.
        super().__init__()

        # Accept ints as well (e.g. ``1`` for "up to the full length"); a bare
        # int would otherwise fall through to ``len()`` and crash.
        if isinstance(mask_width_ratio_range, (int, float)):
            mask_width_ratio_range = (0.0, mask_width_ratio_range)
        if len(mask_width_ratio_range) != 2:
            raise TypeError(
                f"mask_width_ratio_range must be a tuple of float and float values: "
                f"{mask_width_ratio_range}",
            )
        assert mask_width_ratio_range[1] > mask_width_ratio_range[0]

        # Normalise the string spelling to the axis index of (Batch, Length, Freq).
        if isinstance(dim, str):
            if dim == "time":
                dim = 1
            elif dim == "freq":
                dim = 2
            else:
                raise ValueError("dim must be int, 'time' or 'freq'")

        # Human-readable axis name, used only by extra_repr().
        if dim == 1:
            self.mask_axis = "time"
        elif dim == 2:
            self.mask_axis = "freq"
        else:
            self.mask_axis = "unknown"

        self.mask_width_ratio_range = mask_width_ratio_range
        self.num_mask = num_mask
        self.dim = dim
        self.replace_with_zero = replace_with_zero

    def extra_repr(self):
        """Return the configuration summary shown inside ``repr(module)``."""
        return (
            f"mask_width_ratio_range={self.mask_width_ratio_range}, "
            f"num_mask={self.num_mask}, axis={self.mask_axis}"
        )

    def forward(
        self, spec: torch.Tensor, spec_lengths: Union[torch.Tensor, None] = None
    ):
        """Mask ``spec`` with a width range scaled by the masked axis length.

        Args:
            spec: (Batch, Length, Freq) or (Batch, Channel, Length, Freq)
            spec_lengths: Passed through to the result unchanged.

        Returns:
            ``(spec, spec_lengths)``. The input is returned untouched when the
            scaled width range is empty (upper bound not above lower bound).
        """
        # ``self.dim`` indexes the 3-D (Batch, Length, Freq) layout that
        # ``mask_along_axis`` works on; a 4-D input carries an extra Channel
        # axis in front of Length, so the same axis sits one position later.
        axis = self.dim + 1 if spec.dim() == 4 else self.dim
        max_seq_len = spec.shape[axis]

        low_ratio, high_ratio = self.mask_width_ratio_range
        min_mask_width = max(0, math.floor(max_seq_len * low_ratio))
        max_mask_width = min(max_seq_len, math.floor(max_seq_len * high_ratio))

        # torch.randint needs low < high; skip masking for an empty range.
        if max_mask_width <= min_mask_width:
            return spec, spec_lengths

        return mask_along_axis(
            spec,
            spec_lengths,
            mask_width_range=(min_mask_width, max_mask_width),
            dim=self.dim,
            num_mask=self.num_mask,
            replace_with_zero=self.replace_with_zero,
        )
class MaskAlongAxisLFR(torch.nn.Module):
    """Module wrapper around ``mask_along_axis_lfr`` with a fixed configuration.

    Args:
        mask_width_range: ``(low, high)`` bounds for the sampled band width.
            A single int ``n`` is shorthand for ``(0, n)``.
        num_mask: Number of bands drawn per batch item.
        dim: Axis to mask: ``"time"`` (stored as 1), ``"freq"`` (stored as 2),
            or an int. Ints other than 1 and 2 are accepted here and labelled
            ``"unknown"`` in the repr.
        replace_with_zero: Fill masked cells with zero if True, otherwise with
            the mean of the input.
        lfr_rate: Low-frame-rate factor forwarded to ``mask_along_axis_lfr``.
            It is forced to 1 whenever the time axis is selected.

    Raises:
        TypeError: If ``mask_width_range`` does not have exactly two elements.
        AssertionError: If the upper width bound is not greater than the lower.
        ValueError: If ``dim`` is a string other than ``"time"`` / ``"freq"``.
    """

    def __init__(
        self,
        mask_width_range: Union[int, Sequence[int]] = (0, 30),
        num_mask: int = 2,
        dim: Union[int, str] = "time",
        replace_with_zero: bool = True,
        lfr_rate: int = 1,
    ):
        # torch.nn.Module must be initialised before attributes are assigned
        # on the instance.
        super().__init__()

        if isinstance(mask_width_range, int):
            mask_width_range = (0, mask_width_range)
        if len(mask_width_range) != 2:
            raise TypeError(
                f"mask_width_range must be a tuple of int and int values: "
                f"{mask_width_range}",
            )
        assert mask_width_range[1] > mask_width_range[0]

        # Normalise the string spelling to the axis index of (Batch, Length, Freq).
        if isinstance(dim, str):
            if dim == "time":
                dim = 1
            elif dim == "freq":
                dim = 2
            else:
                raise ValueError("dim must be int, 'time' or 'freq'")

        # Human-readable axis name, used only by extra_repr().
        if dim == 1:
            self.mask_axis = "time"
            # Segment replication is only applied off the time axis; this
            # single override covers both dim="time" and dim=1.
            lfr_rate = 1
        elif dim == 2:
            self.mask_axis = "freq"
        else:
            self.mask_axis = "unknown"

        self.mask_width_range = mask_width_range
        self.num_mask = num_mask
        self.dim = dim
        self.replace_with_zero = replace_with_zero
        self.lfr_rate = lfr_rate

    def extra_repr(self):
        """Return the configuration summary shown inside ``repr(module)``."""
        return (
            f"mask_width_range={self.mask_width_range}, "
            f"num_mask={self.num_mask}, axis={self.mask_axis}"
        )

    def forward(
        self, spec: torch.Tensor, spec_lengths: Union[torch.Tensor, None] = None
    ):
        """Apply the configured masking to ``spec``.

        Args:
            spec: (Batch, Length, Freq)
            spec_lengths: Passed through to the result unchanged.

        Returns:
            The ``(spec, spec_lengths)`` tuple produced by
            ``mask_along_axis_lfr``.
        """
        return mask_along_axis_lfr(
            spec,
            spec_lengths,
            mask_width_range=self.mask_width_range,
            dim=self.dim,
            num_mask=self.num_mask,
            replace_with_zero=self.replace_with_zero,
            lfr_rate=self.lfr_rate,
        )