# (web-page residue, not part of the module): Spaces: Running / Running
# Copyright 2024 The HuggingFace Inc. team. All rights reserved. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
from pathlib import Path | |
from tempfile import TemporaryDirectory | |
from typing import Any, Type, TypeVar | |
from huggingface_hub import HfApi | |
from huggingface_hub.utils import validate_hf_hub_args | |
T = TypeVar("T", bound="HubMixin") | |
class HubMixin: | |
""" | |
A Mixin containing the functionality to push an object to the hub. | |
This is similar to huggingface_hub.ModelHubMixin but is lighter and makes less assumptions about its | |
subclasses (in particular, the fact that it's not necessarily a model). | |
The inheriting classes must implement '_save_pretrained' and 'from_pretrained'. | |
""" | |
def save_pretrained( | |
self, | |
save_directory: str | Path, | |
*, | |
repo_id: str | None = None, | |
push_to_hub: bool = False, | |
card_kwargs: dict[str, Any] | None = None, | |
**push_to_hub_kwargs, | |
) -> str | None: | |
""" | |
Save object in local directory. | |
Args: | |
save_directory (`str` or `Path`): | |
Path to directory in which the object will be saved. | |
push_to_hub (`bool`, *optional*, defaults to `False`): | |
Whether or not to push your object to the Huggingface Hub after saving it. | |
repo_id (`str`, *optional*): | |
ID of your repository on the Hub. Used only if `push_to_hub=True`. Will default to the folder name if | |
not provided. | |
card_kwargs (`Dict[str, Any]`, *optional*): | |
Additional arguments passed to the card template to customize the card. | |
push_to_hub_kwargs: | |
Additional key word arguments passed along to the [`~HubMixin.push_to_hub`] method. | |
Returns: | |
`str` or `None`: url of the commit on the Hub if `push_to_hub=True`, `None` otherwise. | |
""" | |
save_directory = Path(save_directory) | |
save_directory.mkdir(parents=True, exist_ok=True) | |
# save object (weights, files, etc.) | |
self._save_pretrained(save_directory) | |
# push to the Hub if required | |
if push_to_hub: | |
if repo_id is None: | |
repo_id = save_directory.name # Defaults to `save_directory` name | |
return self.push_to_hub(repo_id=repo_id, card_kwargs=card_kwargs, **push_to_hub_kwargs) | |
return None | |
def _save_pretrained(self, save_directory: Path) -> None: | |
""" | |
Overwrite this method in subclass to define how to save your object. | |
Args: | |
save_directory (`str` or `Path`): | |
Path to directory in which the object files will be saved. | |
""" | |
raise NotImplementedError | |
def from_pretrained( | |
cls: Type[T], | |
pretrained_name_or_path: str | Path, | |
*, | |
force_download: bool = False, | |
resume_download: bool | None = None, | |
proxies: dict | None = None, | |
token: str | bool | None = None, | |
cache_dir: str | Path | None = None, | |
local_files_only: bool = False, | |
revision: str | None = None, | |
**kwargs, | |
) -> T: | |
""" | |
Download the object from the Huggingface Hub and instantiate it. | |
Args: | |
pretrained_name_or_path (`str`, `Path`): | |
- Either the `repo_id` (string) of the object hosted on the Hub, e.g. `lerobot/diffusion_pusht`. | |
- Or a path to a `directory` containing the object files saved using `.save_pretrained`, | |
e.g., `../path/to/my_model_directory/`. | |
revision (`str`, *optional*): | |
Revision on the Hub. Can be a branch name, a git tag or any commit id. | |
Defaults to the latest commit on `main` branch. | |
force_download (`bool`, *optional*, defaults to `False`): | |
Whether to force (re-)downloading the files from the Hub, overriding the existing cache. | |
proxies (`Dict[str, str]`, *optional*): | |
A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', | |
'http://hostname': 'foo.bar:4012'}`. The proxies are used on every request. | |
token (`str` or `bool`, *optional*): | |
The token to use as HTTP bearer authorization for remote files. By default, it will use the token | |
cached when running `huggingface-cli login`. | |
cache_dir (`str`, `Path`, *optional*): | |
Path to the folder where cached files are stored. | |
local_files_only (`bool`, *optional*, defaults to `False`): | |
If `True`, avoid downloading the file and return the path to the local cached file if it exists. | |
kwargs (`Dict`, *optional*): | |
Additional kwargs to pass to the object during initialization. | |
""" | |
raise NotImplementedError | |
def push_to_hub( | |
self, | |
repo_id: str, | |
*, | |
commit_message: str | None = None, | |
private: bool | None = None, | |
token: str | None = None, | |
branch: str | None = None, | |
create_pr: bool | None = None, | |
allow_patterns: list[str] | str | None = None, | |
ignore_patterns: list[str] | str | None = None, | |
delete_patterns: list[str] | str | None = None, | |
card_kwargs: dict[str, Any] | None = None, | |
) -> str: | |
""" | |
Upload model checkpoint to the Hub. | |
Use `allow_patterns` and `ignore_patterns` to precisely filter which files should be pushed to the hub. Use | |
`delete_patterns` to delete existing remote files in the same commit. See [`upload_folder`] reference for more | |
details. | |
Args: | |
repo_id (`str`): | |
ID of the repository to push to (example: `"username/my-model"`). | |
commit_message (`str`, *optional*): | |
Message to commit while pushing. | |
private (`bool`, *optional*): | |
Whether the repository created should be private. | |
If `None` (default), the repo will be public unless the organization's default is private. | |
token (`str`, *optional*): | |
The token to use as HTTP bearer authorization for remote files. By default, it will use the token | |
cached when running `huggingface-cli login`. | |
branch (`str`, *optional*): | |
The git branch on which to push the model. This defaults to `"main"`. | |
create_pr (`boolean`, *optional*): | |
Whether or not to create a Pull Request from `branch` with that commit. Defaults to `False`. | |
allow_patterns (`List[str]` or `str`, *optional*): | |
If provided, only files matching at least one pattern are pushed. | |
ignore_patterns (`List[str]` or `str`, *optional*): | |
If provided, files matching any of the patterns are not pushed. | |
delete_patterns (`List[str]` or `str`, *optional*): | |
If provided, remote files matching any of the patterns will be deleted from the repo. | |
card_kwargs (`Dict[str, Any]`, *optional*): | |
Additional arguments passed to the card template to customize the card. | |
Returns: | |
The url of the commit of your object in the given repository. | |
""" | |
api = HfApi(token=token) | |
repo_id = api.create_repo(repo_id=repo_id, private=private, exist_ok=True).repo_id | |
if commit_message is None: | |
if "Policy" in self.__class__.__name__: | |
commit_message = "Upload policy" | |
elif "Config" in self.__class__.__name__: | |
commit_message = "Upload config" | |
else: | |
commit_message = f"Upload {self.__class__.__name__}" | |
# Push the files to the repo in a single commit | |
with TemporaryDirectory(ignore_cleanup_errors=True) as tmp: | |
saved_path = Path(tmp) / repo_id | |
self.save_pretrained(saved_path, card_kwargs=card_kwargs) | |
return api.upload_folder( | |
repo_id=repo_id, | |
repo_type="model", | |
folder_path=saved_path, | |
commit_message=commit_message, | |
revision=branch, | |
create_pr=create_pr, | |
allow_patterns=allow_patterns, | |
ignore_patterns=ignore_patterns, | |
delete_patterns=delete_patterns, | |
) | |