Spaces:
Runtime error
Runtime error
| import torch | |
| import torchaudio | |
| from torchaudio import transforms as taT, functional as taF | |
| import torch.nn as nn | |
class WaveformTrainingPipeline(torch.nn.Module):
    """Training-time waveform pipeline: resample, fix the length, optionally add noise.

    Args:
        input_freq: Sample rate (Hz) of the waveforms passed to ``forward``.
        resample_freq: Sample rate (Hz) the waveforms are resampled to.
        expected_duration: Clip length in seconds; multiplied by
            ``resample_freq`` to get the sample count every clip is padded or
            truncated to.
        snr_mean: Mean (in dB) of the normal distribution from which the
            signal-to-noise ratio is drawn on every ``add_noise`` call.
        noise_path: Optional path to a noise recording. ``None`` disables the
            noise augmentation.
    """

    def __init__(
        self,
        input_freq=16000,
        resample_freq=16000,
        expected_duration=6,
        snr_mean=6.0,
        noise_path=None,
    ):
        super().__init__()
        self.input_freq = input_freq
        self.snr_mean = snr_mean
        # Must be set before get_noise(), which resamples the noise to this rate.
        self.resample_frequency = resample_freq
        # A non-persistent buffer follows .to(device) with the module without
        # adding a key to state_dict().
        self.register_buffer("noise", self.get_noise(noise_path), persistent=False)
        self.resample = taT.Resample(input_freq, resample_freq)
        self.preprocess_waveform = WaveformPreprocessing(
            resample_freq * expected_duration
        )

    def get_noise(self, path) -> "torch.Tensor | None":
        """Load the noise recording at ``path`` as a single-channel tensor.

        The noise is resampled to ``self.resample_frequency`` because
        ``forward`` mixes it into the waveform *after* the waveform has been
        resampled to that rate.

        Returns:
            The noise tensor, or ``None`` when ``path`` is ``None``.

        Raises:
            ValueError: If the file contains no samples.
        """
        if path is None:
            return None
        noise, sr = torchaudio.load(path)
        if noise.shape[1] == 0:
            # An empty recording would make add_noise divide by zero when it
            # works out how many times to tile the noise.
            raise ValueError(f"Noise file {path!r} contains no samples.")
        if noise.shape[0] > 1:
            # Average over dim 0 to collapse multiple channels into one.
            noise = noise.mean(0, keepdim=True)
        if sr != self.resample_frequency:
            noise = taF.resample(noise, sr, self.resample_frequency)
        return noise

    def add_noise(self, waveform: torch.Tensor) -> torch.Tensor:
        """Mix the stored noise into ``waveform`` at a randomly drawn SNR.

        The SNR (dB) is sampled from ``Normal(self.snr_mean, 1.5)`` and
        clamped to at least 1 dB. The waveform is rescaled so that its L2 norm
        relative to the noise matches that SNR, and the result is the average
        of the rescaled waveform and the noise.

        Raises:
            RuntimeError: If no noise file was provided.
        """
        if self.noise is None:
            raise RuntimeError(
                "Cannot add noise because a noise file was not provided."
            )
        num_samples = waveform.shape[1]
        # Tile the noise until it is at least as long as the waveform, then trim.
        num_repeats = num_samples // self.noise.shape[1] + 1
        noise = self.noise.repeat(1, num_repeats)[:, :num_samples]
        noise_norm = noise.norm(p=2)
        # Clamp so an all-zero waveform yields a finite scale instead of NaN.
        signal_norm = waveform.norm(p=2).clamp_min(torch.finfo(waveform.dtype).eps)
        # Sampled on the default device (keeps the RNG stream unchanged), then
        # moved next to the waveform so the arithmetic below stays on one device.
        snr_db = torch.normal(self.snr_mean, 1.5, (1,)).clamp_min(1.0)
        snr_db = snr_db.to(waveform.device)
        # The norms are amplitudes, not powers, so dB -> ratio is 10^(dB / 20).
        snr = 10 ** (snr_db / 20)
        scale = snr * noise_norm / signal_norm
        return (scale * waveform + noise) / 2

    def forward(self, waveform: torch.Tensor) -> torch.Tensor:
        """Resample and length-normalise ``waveform``; add noise if configured."""
        waveform = self.resample(waveform)
        waveform = self.preprocess_waveform(waveform)
        if self.noise is not None:
            waveform = self.add_noise(waveform)
        return waveform
class SpectrogramTrainingPipeline(WaveformTrainingPipeline):
    """Waveform pipeline followed by spectrogram conversion and masking.

    Args:
        freq_mask_size: Argument passed to ``FrequencyMasking``.
        time_mask_size: Argument passed to ``TimeMasking``.
        mask_count: Number of rounds in which one frequency mask and then one
            time mask are applied.
        *args, **kwargs: Forwarded to ``WaveformTrainingPipeline``.
    """

    def __init__(
        self, freq_mask_size=10, time_mask_size=80, mask_count=2, *args, **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.mask_count = mask_count
        # The spectrogram is computed at the rate the parent resamples to.
        self.audio_to_spectrogram = AudioToSpectrogram(
            sample_rate=self.resample_frequency,
        )
        self.freq_mask = taT.FrequencyMasking(freq_mask_size)
        self.time_mask = taT.TimeMasking(time_mask_size)

    def forward(self, waveform: torch.Tensor) -> torch.Tensor:
        """Run the waveform pipeline, convert to a spectrogram, then mask it."""
        processed = super().forward(waveform)
        augmented = self.audio_to_spectrogram(processed)
        # Each round applies the frequency mask first, then the time mask.
        round_masks = (self.freq_mask, self.time_mask)
        for _ in range(self.mask_count):
            for apply_mask in round_masks:
                augmented = apply_mask(augmented)
        return augmented
class WaveformPreprocessing(torch.nn.Module):
    """Reduce a waveform to one channel and force it to a fixed sample count.

    Args:
        expected_sample_length: Number of samples (size of dim 1) every
            output waveform will have.
    """

    def __init__(self, expected_sample_length: int):
        super().__init__()
        self.expected_sample_length = expected_sample_length

    def forward(self, waveform: torch.Tensor) -> torch.Tensor:
        """Average away extra channels (dim 0), then pad or truncate dim 1."""
        has_extra_channels = waveform.shape[0] > 1
        mono = waveform.mean(0, keepdim=True) if has_extra_channels else waveform
        return self._rectify_duration(mono)

    def _rectify_duration(self, waveform: torch.Tensor):
        """Zero-pad on the right or truncate so dim 1 has the expected length.

        A waveform that already has the expected length is returned as-is.
        """
        target = self.expected_sample_length
        shortfall = target - waveform.shape[1]
        if shortfall > 0:
            # Too short: append zeros at the end of the last dimension.
            return torch.nn.functional.pad(
                waveform, (0, shortfall), mode="constant", value=0.0
            )
        if shortfall < 0:
            # Too long: keep only the leading samples.
            return waveform[:, :target]
        return waveform
class AudioToSpectrogram:
    """Convert a waveform into a normalised, dB-scaled mel spectrogram.

    Args:
        sample_rate: Sample rate (Hz) passed to ``MelSpectrogram``.
    """

    def __init__(
        self,
        sample_rate=16000,
    ):
        self.spec = taT.MelSpectrogram(
            sample_rate=sample_rate, n_mels=128, n_fft=1024
        )  # Note: this doesn't work on mps right now.
        self.to_db = taT.AmplitudeToDB()

    def __call__(self, waveform: torch.Tensor) -> torch.Tensor:
        """Return the mel spectrogram of ``waveform`` in dB, normalised.

        The result is centred on its global mean and divided by twice its
        global standard deviation.
        """
        spectrogram = self.spec(waveform)
        spectrogram = self.to_db(spectrogram)
        # Normalize. The std is clamped so a constant spectrogram (e.g. from
        # an all-silent clip) normalises to zeros instead of 0 / 0 = NaN.
        std = spectrogram.std().clamp_min(1e-6)
        spectrogram = (spectrogram - spectrogram.mean()) / (2 * std)
        return spectrogram