Higobeatz's picture
freevc plugin
0dabde8
raw
history blame contribute delete
830 Bytes
import torch, torchaudio
from .hubert.hubert import HubertSoft
from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present
import librosa
def get_soft_model(model_path):
    """Build a ``HubertSoft`` model and load its weights from a checkpoint.

    The checkpoint at ``model_path`` must deserialize to a mapping whose
    ``"hubert"`` entry is the model's state dict. A leading ``"module."``
    prefix on the state-dict keys (such as the one left by a
    ``DataParallel``/``DistributedDataParallel`` wrapper) is stripped before
    loading.

    Args:
        model_path: Path (or file-like object) accepted by ``torch.load``.

    Returns:
        The ``HubertSoft`` instance with weights loaded, switched to eval
        mode. It is left on the device it was constructed on; callers move
        it with ``.to(device)`` as needed.
    """
    hubert = HubertSoft()

    # Deserialize onto CPU so that a checkpoint saved from GPU tensors can
    # still be loaded on a machine without CUDA, and so no GPU memory is
    # allocated just to copy the weights into the model.
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    checkpoint = torch.load(model_path, map_location="cpu")

    state_dict = checkpoint["hubert"]
    # In-place: drops the "module." prefix from every key if present.
    consume_prefix_in_state_dict_if_present(state_dict, "module.")
    hubert.load_state_dict(state_dict)

    hubert.eval()
    return hubert
@torch.no_grad()
def get_hubert_soft_content(hmodel, wav_16k_tensor, device='cuda'):
    """Run ``hmodel.units`` on a waveform tensor and return the result on CPU.

    The input is moved to ``device`` and given an extra dimension at index 1
    before being handed to the model. Gradient tracking is disabled for the
    whole call.

    Args:
        hmodel: Object exposing a ``units(tensor)`` method (e.g. the model
            returned by ``get_soft_model``).
        wav_16k_tensor: Waveform tensor; presumably 16 kHz audio shaped
            ``(batch, samples)`` -- TODO confirm against callers.
        device: Device the input is moved to before inference. It should
            match the device ``hmodel`` lives on.

    Returns:
        The tensor produced by ``hmodel.units``, moved to CPU.
    """
    model_input = wav_16k_tensor.to(device)
    # Insert a singleton dimension at index 1 (presumably the channel axis).
    model_input = model_input.unsqueeze(1)
    return hmodel.units(model_input).cpu()