Spaces:
Paused
Paused
import os

import torch
from transformers import HubertModel
def load_hubert(
    hubert: str | os.PathLike | HubertModel | None = None,
    device: torch.device | str = torch.device("cpu"),
) -> HubertModel:
    """
    Load a Hubert model, reusing a pre-loaded instance when one is given.

    Args:
        hubert: One of
            - a ``HubertModel`` instance, which is moved to ``device`` and
              returned;
            - a local model directory or Hugging Face Hub repo id (``str`` or
              ``os.PathLike``), passed to ``HubertModel.from_pretrained``;
            - ``None``, in which case the default ``"safe-models/ContentVec"``
              checkpoint is loaded via ``from_pretrained`` (which downloads it
              if it is not already cached).
        device: The device (``torch.device`` or a string such as ``"cuda"``)
            to move the model to.

    Returns:
        HubertModel: The model, placed on ``device``.

    Raises:
        TypeError: If ``hubert`` is not one of the supported types.
        OSError: Typically propagated from ``HubertModel.from_pretrained`` when
            the given path / repo id cannot be resolved.
    """
    if isinstance(hubert, HubertModel):
        # nn.Module.to() moves parameters in place and returns the same module.
        return hubert.to(device)

    if hubert is None:
        source = "safe-models/ContentVec"
    elif isinstance(hubert, (str, os.PathLike)):
        # Accept pathlib.Path and other PathLike objects too; previously they
        # fell through and the default model was silently loaded instead.
        source = hubert
    else:
        raise TypeError(
            "hubert must be a str, os.PathLike, HubertModel or None, "
            f"got {type(hubert).__name__}"
        )

    return HubertModel.from_pretrained(source).to(device)