Spaces:
Running
Running
#!/usr/bin/python3 | |
# -*- coding: utf-8 -*- | |
""" | |
https://github.com/wenet-e2e/wenet/blob/main/wenet/dataset/processor.py | |
""" | |
import random | |
from typing import List, Tuple | |
import torch | |
import torch.nn as nn | |
from torch.distributions import uniform | |
class SpecAugment(nn.Module):
    """Random volume perturbation for spectrogram-like tensors.

    Each call scales the whole input tensor by one scalar factor drawn
    uniformly from ``aug_volume_factor_range``.

    Args:
        aug_volume_factor_range: ``(low, high)`` bounds of the uniform
            distribution the scaling factor is sampled from.
    """

    def __init__(self,
                 aug_volume_factor_range: Tuple[float, float] = (0.5, 2.0),
                 ):
        super().__init__()
        self.aug_volume_factor_range = aug_volume_factor_range

    # Declared static: the helper reads no instance state. Without this
    # decorator, ``self.augment_volume(...)`` would bind ``self`` to ``spec``
    # and the call would fail with a TypeError.
    @staticmethod
    def augment_volume(spec: torch.Tensor,
                       factor_range: Tuple[float, float] = (0.5, 2.0),
                       ) -> torch.Tensor:
        """Return a copy of ``spec`` multiplied by a single random factor.

        Args:
            spec: tensor to scale. It is not modified.
            factor_range: ``(low, high)`` bounds of the uniform distribution
                the factor is sampled from.

        Returns:
            A new tensor, detached from the autograd graph, equal to
            ``spec * factor``.
        """
        # One scalar sample, so every element gets the same gain.
        factor = uniform.Uniform(*factor_range).sample()
        # Scale a detached copy so the caller's tensor is left untouched.
        spec_ = spec.clone().detach()
        spec_ *= factor
        return spec_

    def forward(self, spec: torch.Tensor) -> torch.Tensor:
        """Apply the random volume perturbation to ``spec``."""
        return self.augment_volume(spec, self.aug_volume_factor_range)
def main():
    """Demo: print a random tensor before and after volume augmentation."""
    spec_augment = SpecAugment()
    spec = torch.randn(size=(1, 10, 4))
    print(spec)
    # Invoke the module via __call__ rather than .forward() so that any
    # registered nn.Module hooks are run.
    spec_ = spec_augment(spec)
    print(spec_)
if __name__ == '__main__':
    # Run the demo only when executed as a script, not when imported.
    main()