Spaces:
Paused
Paused
File size: 936 Bytes
2d9b22b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 |
import os

import torch
from transformers import HubertModel
def load_hubert(
    hubert: str | os.PathLike | HubertModel | None = None,
    device: torch.device | str = torch.device("cpu"),
) -> HubertModel:
    """Load a Hubert model and move it to ``device``.

    A pre-loaded model is moved to ``device`` and returned. A path or Hub
    repo id is loaded with ``HubertModel.from_pretrained``. ``None`` loads
    the default ``safe-models/ContentVec`` model.

    Args:
        hubert: One of:
            - a pre-loaded ``HubertModel``;
            - a local directory or Hugging Face Hub repo id (``str`` or
              ``os.PathLike``, e.g. ``pathlib.Path``);
            - ``None`` to use the default ``safe-models/ContentVec`` model.
        device: Target device, as a ``torch.device`` or a device string
            such as ``"cuda:0"``.

    Returns:
        HubertModel: The model, placed on ``device``.

    Raises:
        TypeError: If ``hubert`` is none of the supported types.
        OSError: Propagated from ``HubertModel.from_pretrained`` when the
            path or repo id cannot be resolved.
    """
    if isinstance(hubert, HubertModel):
        # nn.Module.to() moves parameters in place and returns the same module.
        return hubert.to(device)

    if hubert is None:
        source = "safe-models/ContentVec"
    elif isinstance(hubert, (str, os.PathLike)):
        # Normalise pathlib.Path and similar objects to a plain string. They
        # previously fell through to the default-model branch and silently
        # loaded the wrong weights.
        source = os.fspath(hubert)
    else:
        raise TypeError(
            "hubert must be a str, os.PathLike, HubertModel or None, "
            f"got {type(hubert).__name__}"
        )

    return HubertModel.from_pretrained(source).to(device)
|