import numpy
import pyloudnorm as pyln
import torch
from torchaudio.transforms import MelSpectrogram
from torchaudio.transforms import Resample


class AudioPreprocessor:

    def __init__(self, input_sr, output_sr=None, cut_silence=False, do_loudnorm=False, device="cpu"):
        """
        The parameters are by default set up to do well
        on a 16kHz signal. A different sampling rate may
        require a different hop_length and n_fft (e.g.
        doubling the sampling rate --> doubling hop_length
        and doubling n_fft).
        """
        self.cut_silence = cut_silence
        self.do_loudnorm = do_loudnorm
        self.device = device
        self.input_sr = input_sr
        self.output_sr = output_sr
        self.meter = pyln.Meter(input_sr)
        self.final_sr = input_sr
        self.wave_to_spectrogram = LogMelSpec(output_sr if output_sr is not None else input_sr).to(device)
        if cut_silence:
            torch.hub._validate_not_a_forked_repo = lambda a, b, c: True  # torch 1.9 has a bug in the hub loading, this is a workaround
            # careful: the VAD model assumes 16kHz or 8kHz audio
            self.silero_model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad',
                                                      model='silero_vad',
                                                      force_reload=False,
                                                      onnx=False,
                                                      verbose=False)
            (self.get_speech_timestamps,
             self.save_audio,
             self.read_audio,
             self.VADIterator,
             self.collect_chunks) = utils
            torch.set_grad_enabled(True)  # finding this issue was very infuriating: silero sets this
            # to False globally during model loading rather than using inference mode or no_grad
            self.silero_model = self.silero_model.to(self.device)
        if output_sr is not None and output_sr != input_sr:
            self.resample = Resample(orig_freq=input_sr, new_freq=output_sr).to(self.device)
            self.final_sr = output_sr
        else:
            self.resample = lambda x: x

    def cut_leading_and_trailing_silence(self, audio):
        """
        https://github.com/snakers4/silero-vad
        """
        with torch.inference_mode():
            speech_timestamps = self.get_speech_timestamps(audio, self.silero_model, sampling_rate=self.final_sr)
        try:
            return audio[speech_timestamps[0]['start']:speech_timestamps[-1]['end']]
        except IndexError:
            print("Audio might be too short to cut silences from front and back.")
            return audio
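
    # A sketch of the structure the VAD returns (values invented for illustration):
    # get_speech_timestamps yields a list of dicts holding sample indices, e.g.
    #   [{'start': 1523, 'end': 30841}, {'start': 35800, 'end': 48000}]
    # so slicing from the first 'start' to the last 'end' keeps everything between
    # the first and the last detected speech, including any pauses in the middle.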

    def normalize_loudness(self, audio):
        """
        Normalize the amplitude of the signal based on
        its measured loudness in decibels, so that
        signals with different magnitudes all end up
        at the same perceived loudness, then peak
        normalize the result.
        """
        try:
            loudness = self.meter.integrated_loudness(audio)
        except ValueError:
            # if the audio is too short, a ValueError is raised
            return audio
        loud_normed = pyln.normalize.loudness(audio, loudness, -30.0)
        peak = numpy.amax(numpy.abs(loud_normed))
        peak_normed = numpy.divide(loud_normed, peak)
        return peak_normed
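
    # A standalone sketch of the same loudness pipeline, assuming a mono numpy
    # signal "wav" at 16kHz ("wav" is a hypothetical variable for illustration):
    #   meter = pyln.Meter(16000)                          # ITU-R BS.1770 meter
    #   wav = pyln.normalize.loudness(wav, meter.integrated_loudness(wav), -30.0)
    #   wav = wav / numpy.amax(numpy.abs(wav))             # peak normalize to [-1, 1]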

    def normalize_audio(self, audio):
        """
        one function to apply them all in an
        order that makes sense
        """
        if self.do_loudnorm:
            audio = self.normalize_loudness(audio)
        audio = torch.tensor(audio, device=self.device, dtype=torch.float32)
        audio = self.resample(audio)
        if self.cut_silence:
            audio = self.cut_leading_and_trailing_silence(audio)
        return audio
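
    # Usage sketch for the full pipeline (hypothetical file name, assuming a mono
    # float array as soundfile returns it):
    #   wav, sr = soundfile.read("example.wav")
    #   ap = AudioPreprocessor(input_sr=sr, output_sr=16000, cut_silence=True, do_loudnorm=True)
    #   clean = ap.normalize_audio(wav)  # loudnorm -> tensor -> resample -> silence cut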

    def audio_to_mel_spec_tensor(self, audio, normalize=False, explicit_sampling_rate=None):
        """
        explicit_sampling_rate is for when normalization
        has already been applied and that included
        resampling: there is no way to automatically
        detect the sampling rate of the incoming audio.
        """
        if not isinstance(audio, torch.Tensor):
            audio = torch.tensor(audio, device=self.device)
        if explicit_sampling_rate is None or explicit_sampling_rate == self.output_sr:
            return self.wave_to_spectrogram(audio.float())
        if explicit_sampling_rate != self.input_sr:
            print("WARNING: a different sampling rate was passed in. This will be very slow if it happens often; consider creating a dedicated audio processor instead.")
            self.resample = Resample(orig_freq=explicit_sampling_rate, new_freq=self.output_sr).to(self.device)
            self.input_sr = explicit_sampling_rate
        audio = self.resample(audio.float())
        return self.wave_to_spectrogram(audio)
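
    # A sketch of the explicit_sampling_rate escape hatch (rates invented for
    # illustration): with a processor built for 16kHz, passing a 48kHz array via
    #   ap.audio_to_mel_spec_tensor(wav_48k, explicit_sampling_rate=48000)
    # rebuilds the resampler on the fly before computing the spectrogram, which
    # is exactly why doing this often is slow.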


class LogMelSpec(torch.nn.Module):

    def __init__(self, sr, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.spec = MelSpectrogram(sample_rate=sr,
                                   n_fft=1024,
                                   win_length=1024,
                                   hop_length=256,
                                   f_min=40.0,
                                   f_max=sr // 2,
                                   pad=0,
                                   n_mels=128,
                                   power=2.0,
                                   normalized=False,
                                   center=True,
                                   pad_mode='reflect',
                                   mel_scale='htk')

    def forward(self, audio):
        melspec = self.spec(audio.float())
        zero_mask = melspec == 0
        melspec[zero_mask] = 1e-8  # replace exact zeros so that the log below is defined
        logmelspec = torch.log10(melspec)
        return logmelspec
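

# A sketch of what LogMelSpec produces, assuming a one second 16kHz input
# (illustrative only, not executed here):
#   spec = LogMelSpec(16000)
#   spec(torch.randn(16000)).shape  # -> torch.Size([128, 63]) with center=True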


if __name__ == '__main__':
    import soundfile
    import librosa.display as lbd
    import matplotlib.pyplot as plt

    wav, sr = soundfile.read("../audios/ad00_0004.wav")
    ap = AudioPreprocessor(input_sr=sr, output_sr=16000, cut_silence=True)
    fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(9, 6))
    lbd.specshow(ap.audio_to_mel_spec_tensor(wav).cpu().numpy(),
                 ax=ax,
                 sr=16000,
                 cmap='GnBu',
                 y_axis='mel',  # 'features' is not a valid librosa axis type; 'mel' matches this spectrogram
                 x_axis=None,
                 hop_length=256)
    plt.show()
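
    # A quick additional sanity check of the normalization path on the same file
    # (a sketch; the resulting length depends on how much silence the VAD trims):
    clean = ap.normalize_audio(wav)
    print(f"resampled to {ap.final_sr}Hz and trimmed: {len(clean)} samples remain")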