|
import torch, torchaudio |
|
from .hubert.hubert import HubertSoft |
|
from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present |
|
import librosa |
|
|
|
|
|
def get_soft_model(model_path):
    """Build a ``HubertSoft`` model and load its weights from a checkpoint.

    Args:
        model_path: Checkpoint location handed straight to ``torch.load``.
            The loaded object must be a mapping with a ``"hubert"`` entry
            that holds the model's state dict.

    Returns:
        The ``HubertSoft`` instance with the checkpoint weights loaded and
        switched to eval mode. This function does not move the model to any
        device; callers do that themselves if needed.

    Raises:
        KeyError: If the checkpoint has no ``"hubert"`` entry.
        RuntimeError: If the state dict does not match the model
            (raised by ``load_state_dict``).
    """
    hubert = HubertSoft()

    # Deserialize onto the CPU: without map_location a checkpoint saved from
    # CUDA tensors cannot be loaded on a CPU-only host, and on GPU hosts the
    # weights would be staged on the GPU only to be copied into this module.
    # NOTE(review): torch.load unpickles its input -- only pass trusted files.
    checkpoint = torch.load(model_path, map_location="cpu")
    state_dict = checkpoint["hubert"]

    # Strip a leading "module." from parameter names (in place) so the keys
    # match the bare model; presumably left by a DataParallel/DDP wrapper at
    # save time -- confirm against the training code.
    consume_prefix_in_state_dict_if_present(state_dict, "module.")

    hubert.load_state_dict(state_dict)
    hubert.eval()
    return hubert
|
|
|
|
|
@torch.no_grad()
def get_hubert_soft_content(hmodel, wav_16k_tensor, device='cuda'):
    """Run ``hmodel.units`` on a waveform tensor and return the result on CPU.

    Gradient tracking is disabled for the whole call.

    Args:
        hmodel: Object exposing a ``units(tensor)`` method (e.g. the model
            returned by ``get_soft_model``). It is not moved to ``device``
            here, so it must already live there.
        wav_16k_tensor: Waveform tensor; presumably shaped (batch, samples)
            and sampled at 16 kHz -- TODO confirm against callers.
        device: Device the input is transferred to before inference.

    Returns:
        The output of ``hmodel.units`` copied back to the CPU.
    """
    on_device = wav_16k_tensor.to(device)
    # Insert a singleton axis at position 1 before handing the batch over.
    model_input = on_device.unsqueeze(1)
    return hmodel.units(model_input).cpu()