|
|
|
|
|
"""
Random volume (amplitude) augmentation for spectrogram tensors.

Reference: https://github.com/wenet-e2e/wenet/blob/main/wenet/dataset/processor.py
"""
|
import random |
|
from typing import List, Tuple |
|
|
|
import torch |
|
import torch.nn as nn |
|
from torch.distributions import uniform |
|
|
|
|
|
class SpecAugment(nn.Module):
    """Randomly rescale the amplitude ("volume") of a tensor such as a spectrogram.

    Each call draws a single scalar factor from ``aug_volume_factor_range`` and
    multiplies every element of the input by it. The input tensor is never
    modified; a new tensor, detached from the autograd graph, is returned.

    Args:
        aug_volume_factor_range: ``(low, high)`` bounds of the scaling factor.
            When ``low == high`` that fixed factor is applied on every call.
    """

    def __init__(self,
                 aug_volume_factor_range: Tuple[float, float] = (0.5, 2.0),
                 ):
        super().__init__()
        self.aug_volume_factor_range = aug_volume_factor_range

    @staticmethod
    def augment_volume(spec: torch.Tensor,
                       factor_range: Tuple[float, float] = (0.5, 2.0)) -> torch.Tensor:
        """Return a detached copy of ``spec`` multiplied by one random factor.

        Args:
            spec: Tensor to scale. It is not modified in place.
            factor_range: ``(low, high)`` bounds; the factor is sampled from
                ``torch.distributions.Uniform(low, high)``. If ``low == high``
                the factor is exactly ``low``.

        Returns:
            A new tensor equal to ``spec * factor`` that does not require grad.

        Raises:
            ValueError: if ``low > high``.
        """
        low, high = factor_range
        if low > high:
            raise ValueError(
                f"factor_range must satisfy low <= high, got {factor_range!r}")
        if low == high:
            # Uniform rejects an empty interval (low >= high) under its default
            # argument validation; a degenerate range just means a fixed gain.
            factor = low
        else:
            factor = uniform.Uniform(low, high).sample()

        # Detach before cloning so the copy is not recorded in the autograd
        # graph; the returned tensor never requires grad.
        spec_ = spec.detach().clone()
        spec_ *= factor
        return spec_

    def forward(self, spec: torch.Tensor) -> torch.Tensor:
        """Apply volume augmentation using this module's configured range."""
        return self.augment_volume(spec, self.aug_volume_factor_range)
|
|
|
|
|
def main():
    """Demo: print a small random tensor before and after volume augmentation."""
    augmenter = SpecAugment()

    original = torch.randn(size=(1, 10, 4))
    print(original)

    augmented = augmenter.forward(original)
    print(augmented)
|
|
|
|
|
if __name__ == "__main__":
    # Run the demo only when executed as a script, not when imported.
    main()
|
|