# ZeroRVC / zerorvc /hubert /extractor.py
# github-actions[bot]
# Sync from https://github.com/JacobLinCool/zero-rvc
# 2d9b22b
import logging
from typing import Optional

import librosa
import numpy as np
import torch
from transformers import AutoProcessor, HubertModel

from ..constants import SR_16K
logger = logging.getLogger(__name__)
class HubertFeatureExtractor:
    """Extract hidden-state features from audio with a HuBERT model.

    The extractor wraps an already-instantiated ``HubertModel`` together with
    the ``safe-models/ContentVec`` processor. The model may be supplied at
    construction time or attached later through :meth:`load`.
    """

    # Hub identifier of the processor that prepares raw audio for the model.
    PROCESSOR_NAME = "safe-models/ContentVec"
    # Position inside ``hidden_states`` that is returned as the feature.
    # NOTE(review): presumably the output of the 12th transformer layer
    # (index 0 being the embedding output) -- confirm against the checkpoint.
    HIDDEN_STATE_INDEX = 12

    def __init__(self, hubert: Optional[HubertModel] = None, sr=SR_16K):
        """Create the extractor, optionally attaching a model right away.

        Args:
            hubert: Model to use. When ``None`` the extractor stays unloaded
                until :meth:`load` is called.
            sr: Sampling rate passed to the processor and used when loading
                audio files.
        """
        self.sr = sr
        # Declared up front so an unloaded instance has a well-defined state.
        self.hubert = None
        self.device = None
        self.processor = None
        if hubert is not None:
            self.load(hubert)

    def load(self, hubert: HubertModel):
        """Attach ``hubert`` and fetch the matching audio processor.

        The device is taken from the model's first parameter, so inputs are
        later moved to wherever the model currently lives.

        Args:
            hubert: The model whose hidden states will be extracted.
        """
        self.hubert = hubert
        self.device = next(hubert.parameters()).device
        self.processor = AutoProcessor.from_pretrained(self.PROCESSOR_NAME)
        logger.info("HuBERT model is on %s", self.device)

    def is_loaded(self) -> bool:
        """Return ``True`` once a model has been attached via :meth:`load`."""
        # getattr keeps this safe for instances built without __init__.
        return getattr(self, "hubert", None) is not None

    def extract_feature_from(self, y: np.ndarray) -> np.ndarray:
        """Compute features for an in-memory waveform.

        Args:
            y: Audio samples, expected to be sampled at ``self.sr``.
                NOTE(review): presumably a 1-D mono signal -- confirm with
                callers.

        Returns:
            The selected hidden state as a float32 NumPy array, with the
            leading axis dropped when it has size 1. If any NaN is present,
            the array is sanitised with ``np.nan_to_num``.

        Raises:
            RuntimeError: If no model has been loaded yet.
        """
        if not self.is_loaded():
            raise RuntimeError(
                "HuBERT model is not loaded; call load() before extracting "
                "features."
            )

        input_values = self.processor(
            y, sampling_rate=self.sr, return_tensors="pt"
        ).input_values
        input_values = input_values.to(self.device)

        # Pure inference: skip building an autograd graph that would only be
        # thrown away, which keeps memory use down on long recordings.
        with torch.no_grad():
            outputs = self.hubert(input_values, output_hidden_states=True)

        feats = outputs["hidden_states"][self.HIDDEN_STATE_INDEX]
        feats = feats.detach().squeeze(0).float().cpu().numpy()

        # Only sanitise when a NaN actually occurs, matching the established
        # behaviour (nan_to_num would otherwise also rewrite infinities).
        if np.isnan(feats).any():
            feats = np.nan_to_num(feats)
        return feats

    def extract_feature(self, wav_file: str) -> np.ndarray:
        """Load ``wav_file`` at ``self.sr`` and return its features.

        Args:
            wav_file: Path of the audio file to read with ``librosa.load``.

        Returns:
            The result of :meth:`extract_feature_from` on the loaded samples.

        Raises:
            RuntimeError: If no model has been loaded yet.
        """
        y, _ = librosa.load(wav_file, sr=self.sr)
        return self.extract_feature_from(y)