File size: 936 Bytes
2d9b22b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import os

import torch
from transformers import HubertModel


def load_hubert(
    hubert: str | os.PathLike | HubertModel | None = None,
    device: torch.device | str = torch.device("cpu"),
) -> HubertModel:
    """
    Load a Hubert model, or move an already-loaded one to ``device``.

    Args:
        hubert: What to load. One of:
            - a ``HubertModel`` instance, which is moved to ``device`` and
              returned (no loading happens);
            - a ``str`` or path-like object naming a local model directory or
              a Hugging Face Hub repository id, passed to
              ``HubertModel.from_pretrained``;
            - ``None`` to load the default ``"safe-models/ContentVec"`` model
              via ``HubertModel.from_pretrained`` (which downloads it if it is
              not already cached).
        device: Device to place the model on; anything accepted by
            ``torch.nn.Module.to`` works, e.g. ``torch.device("cuda")`` or
            ``"cuda"``.

    Returns:
        HubertModel: The model, placed on ``device``.

    Raises:
        TypeError: If ``hubert`` is not a ``HubertModel``, ``str``,
            path-like object, or ``None``.
        Exceptions raised by ``HubertModel.from_pretrained`` (e.g. ``OSError``
        when the model cannot be found locally or on the Hub) propagate
        unchanged.
    """
    if isinstance(hubert, HubertModel):
        # Already loaded: only relocate it.
        return hubert.to(device)

    if hubert is None:
        source = "safe-models/ContentVec"
    elif isinstance(hubert, (str, os.PathLike)):
        # from_pretrained accepts both str and os.PathLike. Previously a
        # pathlib.Path fell through and silently loaded the default model.
        source = hubert
    else:
        raise TypeError(
            "hubert must be a HubertModel, a str/os.PathLike model path or "
            f"repo id, or None; got {type(hubert).__name__}"
        )

    return HubertModel.from_pretrained(source).to(device)