Spaces:
Sleeping
Sleeping
import os.path as op | |
from typing import BinaryIO, Optional, Tuple, Union | |
import numpy as np | |
def get_waveform(
    path_or_fp: Union[str, BinaryIO], normalization=True
) -> Tuple[np.ndarray, int]:
    """Get the waveform and sample rate of a 16-bit mono-channel WAV or FLAC.

    The bit depth and channel count are not checked here; the samples are
    returned exactly as ``soundfile.read`` delivers them for ``float32``.

    Args:
        path_or_fp (str or BinaryIO): the path or file-like object
        normalization (bool): Normalize values to [-1, 1] (Default: True)

    Returns:
        A ``(waveform, sample_rate)`` tuple. When ``normalization`` is False
        the samples are multiplied by ``2 ** 15``.

    Raises:
        ValueError: if ``path_or_fp`` is a string whose extension is neither
            ``.flac`` nor ``.wav`` (compared case-insensitively).
        ImportError: if the ``soundfile`` package cannot be imported.
    """
    if isinstance(path_or_fp, str):
        ext = op.splitext(op.basename(path_or_fp))[1]
        # Compare case-insensitively so that e.g. "clip.WAV" is accepted;
        # the error message still reports the extension as given.
        if ext.lower() not in {".flac", ".wav"}:
            raise ValueError(f"Unsupported audio format: {ext}")

    # Imported lazily so the module can be loaded without soundfile installed.
    try:
        import soundfile as sf
    except ImportError as err:
        raise ImportError(
            "Please install soundfile to load WAV/FLAC file"
        ) from err

    waveform, sample_rate = sf.read(path_or_fp, dtype="float32")
    if not normalization:
        waveform *= 2 ** 15  # denormalized to 16-bit signed integers
    return waveform, sample_rate
def _get_kaldi_fbank(waveform, sample_rate, n_bins=80) -> Optional[np.ndarray]:
    """Compute mel-filter bank features with PyKaldi, if it is available.

    Args:
        waveform: samples handed to ``kaldi.matrix.Vector``
        sample_rate: sampling frequency set on the frame extraction options
        n_bins (int): number of mel bins (Default: 80)

    Returns:
        The features as a NumPy array, or None when an ImportError is raised
        (i.e. PyKaldi is not installed).
    """
    try:
        from kaldi.feat.mel import MelBanksOptions
        from kaldi.feat.fbank import FbankOptions, Fbank
        from kaldi.feat.window import FrameExtractionOptions
        from kaldi.matrix import Vector

        # Framing configuration: only the sampling frequency is overridden.
        frame_config = FrameExtractionOptions()
        frame_config.samp_freq = sample_rate

        # Mel configuration: only the number of bins is overridden.
        mel_config = MelBanksOptions()
        mel_config.num_bins = n_bins

        fbank_config = FbankOptions()
        fbank_config.frame_opts = frame_config
        fbank_config.mel_opts = mel_config

        extractor = Fbank(opts=fbank_config)
        return extractor.compute(Vector(waveform), 1.0).numpy()
    except ImportError:
        return None
def _get_torchaudio_fbank(waveform, sample_rate, n_bins=80) -> Optional[np.ndarray]:
    """Get mel-filter bank features via TorchAudio.

    Args:
        waveform (np.ndarray): either a 1-D array of samples, or a 2-D array
            that is treated as channels-first (C x D) and mixed down to one
            channel. NOTE(review): callers must supply channels-first data;
            this function does not transpose.
        sample_rate: sampling frequency passed to torchaudio
        n_bins (int): number of mel bins (Default: 80)

    Returns:
        The features as a NumPy array, or None when an ImportError is raised
        (i.e. torch / torchaudio is not installed).
    """
    try:
        import torch
        import torchaudio.compliance.kaldi as ta_kaldi
        import torchaudio.sox_effects as ta_sox

        waveform = torch.from_numpy(waveform)
        if len(waveform.shape) == 1:
            # Mono channel: D -> 1 x D
            waveform = waveform.unsqueeze(0)
        else:
            # Merge multiple channels to one: C x D -> 1 x D
            # `effects` is a list of effects, each itself a list of strings,
            # so the single "channels 1" effect must be nested.
            waveform, _ = ta_sox.apply_effects_tensor(
                waveform, sample_rate, [["channels", "1"]]
            )

        features = ta_kaldi.fbank(
            waveform, num_mel_bins=n_bins, sample_frequency=sample_rate
        )
        return features.numpy()
    except ImportError:
        return None
def get_fbank(path_or_fp: Union[str, BinaryIO], n_bins=80) -> np.ndarray:
    """Get mel-filter bank features via PyKaldi or TorchAudio. Prefer PyKaldi
    (faster CPP implementation) to TorchAudio (Python implementation). Note that
    Kaldi/TorchAudio requires 16-bit signed integers as inputs and hence the
    waveform should not be normalized.

    Multi-channel audio is averaged down to a single channel before feature
    extraction.

    Args:
        path_or_fp (str or BinaryIO): the path or file-like object
        n_bins (int): number of mel bins (Default: 80)

    Returns:
        The filterbank features as a NumPy array.

    Raises:
        ImportError: if neither PyKaldi nor TorchAudio is available.
    """
    sound, sample_rate = get_waveform(path_or_fp, normalization=False)
    if sound.ndim > 1:
        # soundfile lays multi-channel audio out as (frames, channels), which
        # is not what the helpers below expect (the TorchAudio helper assumes
        # channels-first), so mix down to mono here by averaging channels.
        sound = sound.mean(axis=1)

    # Try PyKaldi first; each helper returns None when its backend is missing.
    features = _get_kaldi_fbank(sound, sample_rate, n_bins)
    if features is None:
        features = _get_torchaudio_fbank(sound, sample_rate, n_bins)
    if features is None:
        raise ImportError(
            "Please install pyKaldi or torchaudio to enable "
            "online filterbank feature extraction"
        )
    return features