Spaces:
Paused
Paused
import logging | |
import librosa | |
import numpy as np | |
from transformers import AutoProcessor, HubertModel | |
from ..constants import SR_16K | |
logger = logging.getLogger(__name__) | |
class HubertFeatureExtractor:
    """Extract features from audio using a HuBERT model.

    The model can be supplied to the constructor or attached later with
    ``load``; ``is_loaded`` reports whether a model has been attached.
    """

    def __init__(self, hubert: HubertModel = None, sr=SR_16K):
        """Create the extractor.

        Args:
            hubert: Optional model to attach immediately. When ``None`` the
                extractor stays unloaded until ``load`` is called.
            sr: Sampling rate used when reading audio files and when
                invoking the processor.
        """
        self.sr = sr
        if hubert is not None:
            self.load(hubert)

    def load(self, hubert: HubertModel):
        """Attach ``hubert`` and fetch the input processor.

        Stores the model, records the device of its first parameter, and
        loads the processor from the ``safe-models/ContentVec`` checkpoint.

        Args:
            hubert: The model to run during feature extraction.
        """
        self.hubert = hubert
        # Inputs are later moved to wherever the model's parameters live.
        self.device = next(hubert.parameters()).device
        self.processor = AutoProcessor.from_pretrained("safe-models/ContentVec")
        logger.info(f"HuBERT model is on {self.device}")

    def is_loaded(self) -> bool:
        """Return ``True`` once ``load`` has attached a model."""
        return hasattr(self, "hubert")

    def extract_feature_from(self, y: np.ndarray) -> np.ndarray:
        """Compute features for an in-memory waveform.

        Args:
            y: Audio samples; presumably mono at ``self.sr`` Hz, as produced
                by ``extract_feature`` -- TODO confirm for other callers.

        Returns:
            The entry at index 12 of the model's ``hidden_states`` output,
            with a size-1 leading axis removed, converted to a float32 NumPy
            array on the CPU. If any NaN is present the array is passed
            through ``np.nan_to_num``.

        Raises:
            AttributeError: If no model has been attached via ``load``.
        """
        # Imported here so this method does not depend on a file-level
        # ``torch`` import being present.
        import torch

        input_values = self.processor(
            y, sampling_rate=self.sr, return_tensors="pt"
        ).input_values
        input_values = input_values.to(self.device)

        # Inference only: skip autograd bookkeeping so no graph is built or
        # retained for the forward pass.
        with torch.no_grad():
            outputs = self.hubert(input_values, output_hidden_states=True)

        feats = outputs["hidden_states"][12]
        feats = feats.squeeze(0).float().cpu().detach().numpy()
        if np.isnan(feats).any():
            feats = np.nan_to_num(feats)
        return feats

    def extract_feature(self, wav_file: str) -> np.ndarray:
        """Load ``wav_file`` at ``self.sr`` and return its features.

        Args:
            wav_file: Path to the audio file, read with ``librosa.load``.

        Returns:
            The result of ``extract_feature_from`` on the loaded samples.
        """
        y, _ = librosa.load(wav_file, sr=self.sr)
        return self.extract_feature_from(y)