Spaces:
Paused
Paused
| import torch | |
| from .rmvpe import RMVPE | |
| def load_rmvpe( | |
| rmvpe: str | RMVPE | None = None, device: torch.device = torch.device("cpu") | |
| ) -> RMVPE: | |
| """ | |
| Load the RMVPE model from a file or download it if necessary. | |
| If a loaded model is provided, it will be returned as is. | |
| Args: | |
| rmvpe (str | RMVPE | None): The path to the RMVPE model file or the pre-loaded RMVPE model. If None, the default model will be downloaded. | |
| device (torch.device): The device to load the model on. | |
| Returns: | |
| RMVPE: The loaded RMVPE model. | |
| Raises: | |
| If the model file does not exist. | |
| """ | |
| if isinstance(rmvpe, RMVPE): | |
| return rmvpe.to(device) | |
| if isinstance(rmvpe, str): | |
| model = RMVPE.from_pretrained(rmvpe).to(device) | |
| return model | |
| return RMVPE.from_pretrained("safe-models/RMVPE").to(device) | |