import os.path as op from typing import BinaryIO, Optional, Tuple, Union import numpy as np def get_waveform( path_or_fp: Union[str, BinaryIO], normalization=True ) -> Tuple[np.ndarray, int]: """Get the waveform and sample rate of a 16-bit mono-channel WAV or FLAC. Args: path_or_fp (str or BinaryIO): the path or file-like object normalization (bool): Normalize values to [-1, 1] (Default: True) """ if isinstance(path_or_fp, str): ext = op.splitext(op.basename(path_or_fp))[1] if ext not in {".flac", ".wav"}: raise ValueError(f"Unsupported audio format: {ext}") try: import soundfile as sf except ImportError: raise ImportError("Please install soundfile to load WAV/FLAC file") waveform, sample_rate = sf.read(path_or_fp, dtype="float32") if not normalization: waveform *= 2 ** 15 # denormalized to 16-bit signed integers return waveform, sample_rate def _get_kaldi_fbank(waveform, sample_rate, n_bins=80) -> Optional[np.ndarray]: """Get mel-filter bank features via PyKaldi.""" try: from kaldi.feat.mel import MelBanksOptions from kaldi.feat.fbank import FbankOptions, Fbank from kaldi.feat.window import FrameExtractionOptions from kaldi.matrix import Vector mel_opts = MelBanksOptions() mel_opts.num_bins = n_bins frame_opts = FrameExtractionOptions() frame_opts.samp_freq = sample_rate opts = FbankOptions() opts.mel_opts = mel_opts opts.frame_opts = frame_opts fbank = Fbank(opts=opts) features = fbank.compute(Vector(waveform), 1.0).numpy() return features except ImportError: return None def _get_torchaudio_fbank(waveform, sample_rate, n_bins=80) -> Optional[np.ndarray]: """Get mel-filter bank features via TorchAudio.""" try: import torch import torchaudio.compliance.kaldi as ta_kaldi import torchaudio.sox_effects as ta_sox waveform = torch.from_numpy(waveform) if len(waveform.shape) == 1: # Mono channel: D -> 1 x D waveform = waveform.unsqueeze(0) else: # Merge multiple channels to one: C x D -> 1 x D waveform, _ = ta_sox.apply_effects_tensor(waveform, sample_rate, ['channels', '1']) features = ta_kaldi.fbank( waveform, num_mel_bins=n_bins, sample_frequency=sample_rate ) return features.numpy() except ImportError: return None def get_fbank(path_or_fp: Union[str, BinaryIO], n_bins=80) -> np.ndarray: """Get mel-filter bank features via PyKaldi or TorchAudio. Prefer PyKaldi (faster CPP implementation) to TorchAudio (Python implementation). Note that Kaldi/TorchAudio requires 16-bit signed integers as inputs and hence the waveform should not be normalized.""" sound, sample_rate = get_waveform(path_or_fp, normalization=False) features = _get_kaldi_fbank(sound, sample_rate, n_bins) if features is None: features = _get_torchaudio_fbank(sound, sample_rate, n_bins) if features is None: raise ImportError( "Please install pyKaldi or torchaudio to enable " "online filterbank feature extraction" ) return features