import json
import os
from pathlib import Path
from typing import Dict, List, Optional, Type, TypeVar, Union

import requests

from .constants import CONFIG_NAME, PYTORCH_WEIGHTS_NAME
from .file_download import hf_hub_download, is_torch_available
from .hf_api import HfApi
from .utils import SoftTemporaryDirectory, logging, validate_hf_hub_args


if is_torch_available():
    import torch  # type: ignore

logger = logging.get_logger(__name__)

# Generic variable that is either ModelHubMixin or a subclass thereof
T = TypeVar("T", bound="ModelHubMixin")


class ModelHubMixin:
| """ | |
| A generic mixin to integrate ANY machine learning framework with the Hub. | |
| To integrate your framework, your model class must inherit from this class. Custom logic for saving/loading models | |
| have to be overwritten in [`_from_pretrained`] and [`_save_pretrained`]. [`PyTorchModelHubMixin`] is a good example | |
| of mixin integration with the Hub. Check out our [integration guide](../guides/integrations) for more instructions. | |
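
    Example (a minimal sketch of an integration; `MyModel` and its toy JSON weights format are hypothetical):

    ```python
    >>> import json
    >>> from pathlib import Path
    >>> from huggingface_hub import ModelHubMixin, hf_hub_download


    >>> class MyModel(ModelHubMixin):
    ...     def __init__(self, weights=None, **kwargs):
    ...         # `**kwargs` swallows extra keys (e.g. `config`) injected by `from_pretrained`.
    ...         self.weights = weights or {}
    ...
    ...     def _save_pretrained(self, save_directory: Path) -> None:
    ...         # Serialize the (toy) weights next to the config file.
    ...         (save_directory / "weights.json").write_text(json.dumps(self.weights))
    ...
    ...     @classmethod
    ...     def _from_pretrained(cls, *, model_id, revision, cache_dir, force_download, proxies,
    ...                          resume_download, local_files_only, token, **model_kwargs):
    ...         # Fetch the weights file from the Hub, then instantiate the model.
    ...         path = hf_hub_download(repo_id=model_id, filename="weights.json", revision=revision,
    ...                                cache_dir=cache_dir, force_download=force_download, proxies=proxies,
    ...                                resume_download=resume_download, local_files_only=local_files_only,
    ...                                token=token)
    ...         return cls(weights=json.loads(Path(path).read_text()), **model_kwargs)
    ```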
| """ | |
| def save_pretrained( | |
| self, | |
| save_directory: Union[str, Path], | |
| *, | |
| config: Optional[dict] = None, | |
| repo_id: Optional[str] = None, | |
| push_to_hub: bool = False, | |
| **kwargs, | |
| ) -> Optional[str]: | |
| """ | |
| Save weights in local directory. | |
| Args: | |
| save_directory (`str` or `Path`): | |
| Path to directory in which the model weights and configuration will be saved. | |
| config (`dict`, *optional*): | |
| Model configuration specified as a key/value dictionary. | |
| push_to_hub (`bool`, *optional*, defaults to `False`): | |
| Whether or not to push your model to the Huggingface Hub after saving it. | |
| repo_id (`str`, *optional*): | |
| ID of your repository on the Hub. Used only if `push_to_hub=True`. Will default to the folder name if | |
| not provided. | |
| kwargs: | |
| Additional key word arguments passed along to the [`~ModelHubMixin._from_pretrained`] method. | |
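
        Example (illustrative; directory name, config values, and repo id are hypothetical):

        ```python
        >>> model.save_pretrained("path/to/my-awesome-model")  # save locally only
        >>> model.save_pretrained(
        ...     "path/to/my-awesome-model",
        ...     config={"hidden_size": 128},
        ...     push_to_hub=True,
        ...     repo_id="username/my-awesome-model",
        ... )  # save locally, then upload the saved files to the Hub
        ```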
| """ | |
| save_directory = Path(save_directory) | |
| save_directory.mkdir(parents=True, exist_ok=True) | |
| # saving model weights/files | |
| self._save_pretrained(save_directory) | |
| # saving config | |
| if isinstance(config, dict): | |
| (save_directory / CONFIG_NAME).write_text(json.dumps(config)) | |
| if push_to_hub: | |
| kwargs = kwargs.copy() # soft-copy to avoid mutating input | |
| if config is not None: # kwarg for `push_to_hub` | |
| kwargs["config"] = config | |
| if repo_id is None: | |
| repo_id = save_directory.name # Defaults to `save_directory` name | |
| return self.push_to_hub(repo_id=repo_id, **kwargs) | |
| return None | |

    def _save_pretrained(self, save_directory: Path) -> None:
        """
        Overwrite this method in subclass to define how to save your model.
        Check out our [integration guide](../guides/integrations) for instructions.

        Args:
            save_directory (`Path`):
                Path to directory in which the model weights and configuration will be saved.
        """
        raise NotImplementedError

    @classmethod
    @validate_hf_hub_args
    def from_pretrained(
        cls: Type[T],
        pretrained_model_name_or_path: Union[str, Path],
        *,
        force_download: bool = False,
        resume_download: bool = False,
        proxies: Optional[Dict] = None,
        token: Optional[Union[str, bool]] = None,
        cache_dir: Optional[Union[str, Path]] = None,
        local_files_only: bool = False,
        revision: Optional[str] = None,
        **model_kwargs,
    ) -> T:
        """
        Download a model from the Hugging Face Hub and instantiate it.

        Args:
            pretrained_model_name_or_path (`str`, `Path`):
                - Either the `model_id` (string) of a model hosted on the Hub, e.g. `bigscience/bloom`.
                - Or a path to a `directory` containing model weights saved using
                  [`~transformers.PreTrainedModel.save_pretrained`], e.g., `../path/to/my_model_directory/`.
            revision (`str`, *optional*):
                Revision of the model on the Hub. Can be a branch name, a git tag or any commit id.
                Defaults to the latest commit on the `main` branch.
            force_download (`bool`, *optional*, defaults to `False`):
                Whether to force (re-)downloading the model weights and configuration files from the Hub, overriding
                the existing cache.
            resume_download (`bool`, *optional*, defaults to `False`):
                Whether to delete incompletely received files. Will attempt to resume the download if such a file exists.
            proxies (`Dict[str, str]`, *optional*):
                A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on every request.
            token (`str` or `bool`, *optional*):
                The token to use as HTTP bearer authorization for remote files. By default, it will use the token
                cached when running `huggingface-cli login`.
            cache_dir (`str`, `Path`, *optional*):
                Path to the folder where cached files are stored.
            local_files_only (`bool`, *optional*, defaults to `False`):
                If `True`, avoid downloading the file and return the path to the local cached file if it exists.
            model_kwargs (`Dict`, *optional*):
                Additional kwargs to pass to the model during initialization.
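
        Example (illustrative; `MyModel` is a hypothetical subclass and the repo id is made up):

        ```python
        >>> model = MyModel.from_pretrained("username/my-awesome-model")  # from the Hub
        >>> model = MyModel.from_pretrained("path/to/local-directory")  # from a local save
        ```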
| """ | |
| model_id = pretrained_model_name_or_path | |
| config_file: Optional[str] = None | |
| if os.path.isdir(model_id): | |
| if CONFIG_NAME in os.listdir(model_id): | |
| config_file = os.path.join(model_id, CONFIG_NAME) | |
| else: | |
| logger.warning(f"{CONFIG_NAME} not found in {Path(model_id).resolve()}") | |
| elif isinstance(model_id, str): | |
| try: | |
| config_file = hf_hub_download( | |
| repo_id=str(model_id), | |
| filename=CONFIG_NAME, | |
| revision=revision, | |
| cache_dir=cache_dir, | |
| force_download=force_download, | |
| proxies=proxies, | |
| resume_download=resume_download, | |
| token=token, | |
| local_files_only=local_files_only, | |
| ) | |
| except requests.exceptions.RequestException: | |
| logger.warning(f"{CONFIG_NAME} not found in HuggingFace Hub.") | |
| if config_file is not None: | |
| with open(config_file, "r", encoding="utf-8") as f: | |
| config = json.load(f) | |
| model_kwargs.update({"config": config}) | |
| return cls._from_pretrained( | |
| model_id=str(model_id), | |
| revision=revision, | |
| cache_dir=cache_dir, | |
| force_download=force_download, | |
| proxies=proxies, | |
| resume_download=resume_download, | |
| local_files_only=local_files_only, | |
| token=token, | |
| **model_kwargs, | |
| ) | |

    @classmethod
    def _from_pretrained(
        cls: Type[T],
        *,
        model_id: str,
        revision: Optional[str],
        cache_dir: Optional[Union[str, Path]],
        force_download: bool,
        proxies: Optional[Dict],
        resume_download: bool,
        local_files_only: bool,
        token: Optional[Union[str, bool]],
        **model_kwargs,
    ) -> T:
        """Overwrite this method in subclass to define how to load your model from pretrained.

        Use [`hf_hub_download`] or [`snapshot_download`] to download files from the Hub before loading them. Most
        args taken as input can be directly passed to those 2 methods. If needed, you can add more arguments to this
        method using `model_kwargs`. For example, [`PyTorchModelHubMixin._from_pretrained`] takes as input a `map_location`
        parameter to set on which device the model should be loaded.

        Check out our [integration guide](../guides/integrations) for more instructions.

        Args:
            model_id (`str`):
                ID of the model to load from the Hugging Face Hub (e.g. `bigscience/bloom`).
            revision (`str`, *optional*):
                Revision of the model on the Hub. Can be a branch name, a git tag or any commit id. Defaults to the
                latest commit on the `main` branch.
            force_download (`bool`, *optional*, defaults to `False`):
                Whether to force (re-)downloading the model weights and configuration files from the Hub, overriding
                the existing cache.
            resume_download (`bool`, *optional*, defaults to `False`):
                Whether to delete incompletely received files. Will attempt to resume the download if such a file exists.
            proxies (`Dict[str, str]`, *optional*):
                A dictionary of proxy servers to use by protocol or endpoint (e.g., `{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`).
            token (`str` or `bool`, *optional*):
                The token to use as HTTP bearer authorization for remote files. By default, it will use the token
                cached when running `huggingface-cli login`.
            cache_dir (`str`, `Path`, *optional*):
                Path to the folder where cached files are stored.
            local_files_only (`bool`, *optional*, defaults to `False`):
                If `True`, avoid downloading the file and return the path to the local cached file if it exists.
            model_kwargs:
                Additional keyword arguments forwarded from [`~ModelHubMixin.from_pretrained`] and passed to the
                model during initialization.
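
        Example (a minimal sketch of an override that adds a custom `device` argument; the `model.bin`
        filename and the `load_weights` helper are hypothetical):

        ```python
        >>> @classmethod
        ... def _from_pretrained(cls, *, model_id, revision, cache_dir, force_download, proxies,
        ...                      resume_download, local_files_only, token, device="cpu", **model_kwargs):
        ...     # Download (or resolve from the local cache) a single weights file from the Hub.
        ...     weights_path = hf_hub_download(
        ...         repo_id=model_id, filename="model.bin", revision=revision, cache_dir=cache_dir,
        ...         force_download=force_download, proxies=proxies, resume_download=resume_download,
        ...         local_files_only=local_files_only, token=token,
        ...     )
        ...     model = cls(**model_kwargs)
        ...     model.load_weights(weights_path, device=device)  # hypothetical helper
        ...     return model
        ```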
| """ | |
| raise NotImplementedError | |

    @validate_hf_hub_args
    def push_to_hub(
        self,
        repo_id: str,
        *,
        config: Optional[dict] = None,
        commit_message: str = "Push model using huggingface_hub.",
        private: bool = False,
        api_endpoint: Optional[str] = None,
        token: Optional[str] = None,
        branch: Optional[str] = None,
        create_pr: Optional[bool] = None,
        allow_patterns: Optional[Union[List[str], str]] = None,
        ignore_patterns: Optional[Union[List[str], str]] = None,
        delete_patterns: Optional[Union[List[str], str]] = None,
    ) -> str:
        """
        Upload model checkpoint to the Hub.

        Use `allow_patterns` and `ignore_patterns` to precisely filter which files should be pushed to the Hub. Use
        `delete_patterns` to delete existing remote files in the same commit. See the [`upload_folder`] reference for
        more details.

        Args:
            repo_id (`str`):
                ID of the repository to push to (example: `"username/my-model"`).
            config (`dict`, *optional*):
                Configuration object to be saved alongside the model weights.
            commit_message (`str`, *optional*):
                Message to commit while pushing.
            private (`bool`, *optional*, defaults to `False`):
                Whether the repository created should be private.
            api_endpoint (`str`, *optional*):
                The API endpoint to use when pushing the model to the hub.
            token (`str`, *optional*):
                The token to use as HTTP bearer authorization for remote files. By default, it will use the token
                cached when running `huggingface-cli login`.
            branch (`str`, *optional*):
                The git branch on which to push the model. This defaults to `"main"`.
            create_pr (`bool`, *optional*):
                Whether or not to create a Pull Request from `branch` with that commit. Defaults to `False`.
            allow_patterns (`List[str]` or `str`, *optional*):
                If provided, only files matching at least one pattern are pushed.
            ignore_patterns (`List[str]` or `str`, *optional*):
                If provided, files matching any of the patterns are not pushed.
            delete_patterns (`List[str]` or `str`, *optional*):
                If provided, remote files matching any of the patterns will be deleted from the repo.

        Returns:
            The url of the commit of your model in the given repository.
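
        Example (illustrative; the repo id and patterns are hypothetical):

        ```python
        >>> model.push_to_hub(
        ...     "username/my-awesome-model",
        ...     commit_message="Upload new weights",
        ...     allow_patterns=["*.bin", "*.json"],  # push only weights and config
        ...     delete_patterns=["*.onnx"],  # remove stale exports in the same commit
        ... )
        ```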
| """ | |
| api = HfApi(endpoint=api_endpoint, token=token) | |
| repo_id = api.create_repo(repo_id=repo_id, private=private, exist_ok=True).repo_id | |
| # Push the files to the repo in a single commit | |
| with SoftTemporaryDirectory() as tmp: | |
| saved_path = Path(tmp) / repo_id | |
| self.save_pretrained(saved_path, config=config) | |
| return api.upload_folder( | |
| repo_id=repo_id, | |
| repo_type="model", | |
| folder_path=saved_path, | |
| commit_message=commit_message, | |
| revision=branch, | |
| create_pr=create_pr, | |
| allow_patterns=allow_patterns, | |
| ignore_patterns=ignore_patterns, | |
| delete_patterns=delete_patterns, | |
| ) | |


class PyTorchModelHubMixin(ModelHubMixin):
    """
    Implementation of [`ModelHubMixin`] to provide model Hub upload/download capabilities to PyTorch models. The model
    is set in evaluation mode by default using `model.eval()` (dropout modules are deactivated). To train the model,
    you should first set it back in training mode with `model.train()`.

    Example:

    ```python
    >>> import torch
    >>> import torch.nn as nn
    >>> from huggingface_hub import PyTorchModelHubMixin


    >>> class MyModel(nn.Module, PyTorchModelHubMixin):
    ...     def __init__(self):
    ...         super().__init__()
    ...         self.param = nn.Parameter(torch.rand(3, 4))
    ...         self.linear = nn.Linear(4, 5)
    ...
    ...     def forward(self, x):
    ...         return self.linear(x + self.param)

    >>> model = MyModel()

    >>> # Save model weights to a local directory
    >>> model.save_pretrained("my-awesome-model")

    >>> # Push model weights to the Hub
    >>> model.push_to_hub("my-awesome-model")

    >>> # Download and initialize weights from the Hub
    >>> model = MyModel.from_pretrained("username/my-awesome-model")
    ```
    """

    def _save_pretrained(self, save_directory: Path) -> None:
        """Save weights from a PyTorch model to a local directory."""
        model_to_save = self.module if hasattr(self, "module") else self  # type: ignore
        torch.save(model_to_save.state_dict(), save_directory / PYTORCH_WEIGHTS_NAME)

    @classmethod
    def _from_pretrained(
        cls,
        *,
        model_id: str,
        revision: Optional[str],
        cache_dir: Optional[Union[str, Path]],
        force_download: bool,
        proxies: Optional[Dict],
        resume_download: bool,
        local_files_only: bool,
        token: Union[str, bool, None],
        map_location: str = "cpu",
        strict: bool = False,
        **model_kwargs,
    ):
        """Load PyTorch pretrained weights and return the loaded model."""
        if os.path.isdir(model_id):
            logger.info("Loading weights from local directory")
            model_file = os.path.join(model_id, PYTORCH_WEIGHTS_NAME)
        else:
            model_file = hf_hub_download(
                repo_id=model_id,
                filename=PYTORCH_WEIGHTS_NAME,
                revision=revision,
                cache_dir=cache_dir,
                force_download=force_download,
                proxies=proxies,
                resume_download=resume_download,
                token=token,
                local_files_only=local_files_only,
            )

        # Instantiate the model, load the checkpoint on the requested device, and switch to eval mode
        model = cls(**model_kwargs)
        state_dict = torch.load(model_file, map_location=torch.device(map_location))
        model.load_state_dict(state_dict, strict=strict)  # type: ignore
        model.eval()  # type: ignore
        return model
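

# Usage note (illustrative, not part of the library): extra keyword arguments such as
# `map_location` and `strict` above are forwarded through `from_pretrained`'s `**model_kwargs`
# down to `_from_pretrained`, so a PyTorch model can be loaded onto a specific device with,
# e.g. (hypothetical repo id):
#
#   model = MyModel.from_pretrained("username/my-awesome-model", map_location="cuda")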