Spaces:
Running
Running
File size: 9,013 Bytes
529ed6b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 |
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Any, Type, TypeVar
from huggingface_hub import HfApi
from huggingface_hub.utils import validate_hf_hub_args
# Type variable bound to HubMixin: lets `from_pretrained` be annotated (`cls: Type[T]` -> `T`) as returning
# an instance of whichever subclass it is invoked on, rather than the base HubMixin type.
T = TypeVar("T", bound="HubMixin")
class HubMixin:
    """
    Lightweight mixin that adds local saving and Hugging Face Hub uploading to an object.

    It plays a role similar to `huggingface_hub.ModelHubMixin`, but makes fewer assumptions about the classes
    that inherit from it (in particular, they are not necessarily models).

    Subclasses must implement `_save_pretrained` (how files are written to disk) and `from_pretrained` (how an
    instance is rebuilt); the base versions of both raise `NotImplementedError`.
    """

    def save_pretrained(
        self,
        save_directory: str | Path,
        *,
        repo_id: str | None = None,
        push_to_hub: bool = False,
        card_kwargs: dict[str, Any] | None = None,
        **push_to_hub_kwargs,
    ) -> str | None:
        """
        Write this object's files into a local directory and optionally upload them to the Hub.

        The directory (and any missing parent directories) is created first, then `_save_pretrained` performs
        the actual writing.

        Args:
            save_directory (`str` or `Path`):
                Local directory that receives the object files.
            repo_id (`str`, *optional*):
                Hub repository to upload to. Only consulted when `push_to_hub=True`; when omitted, the final
                component of `save_directory` is used as the repository id.
            push_to_hub (`bool`, *optional*, defaults to `False`):
                When `True`, call [`~HubMixin.push_to_hub`] after saving locally.
            card_kwargs (`Dict[str, Any]`, *optional*):
                Extra arguments intended to customize the repository card. They are only forwarded to
                [`~HubMixin.push_to_hub`]; this base class never consumes them itself.
                NOTE(review): presumably subclasses overriding these methods make use of them — confirm.
            push_to_hub_kwargs:
                Additional keyword arguments forwarded unchanged to [`~HubMixin.push_to_hub`].

        Returns:
            `str` or `None`: whatever [`~HubMixin.push_to_hub`] returns (the commit URL) when `push_to_hub=True`,
            otherwise `None`.
        """
        target_dir = Path(save_directory)
        target_dir.mkdir(parents=True, exist_ok=True)

        # Delegate the actual file writing (weights, configs, ...) to the subclass.
        self._save_pretrained(target_dir)

        if not push_to_hub:
            return None

        # Without an explicit repository id, name the repository after the local folder.
        destination = target_dir.name if repo_id is None else repo_id
        return self.push_to_hub(repo_id=destination, card_kwargs=card_kwargs, **push_to_hub_kwargs)

    def _save_pretrained(self, save_directory: Path) -> None:
        """
        Hook that subclasses override to write their files.

        Args:
            save_directory (`Path`):
                Directory in which the object files should be written. `save_pretrained` creates it (with
                parents) before invoking this hook.
        """
        raise NotImplementedError

    @classmethod
    @validate_hf_hub_args
    def from_pretrained(
        cls: Type[T],
        pretrained_name_or_path: str | Path,
        *,
        force_download: bool = False,
        resume_download: bool | None = None,
        proxies: dict | None = None,
        token: str | bool | None = None,
        cache_dir: str | Path | None = None,
        local_files_only: bool = False,
        revision: str | None = None,
        **kwargs,
    ) -> T:
        """
        Build an instance from files on the Hugging Face Hub or in a local directory.

        The base implementation raises `NotImplementedError`; each subclass supplies the loading logic and
        decides how the options below are honoured.

        Args:
            pretrained_name_or_path (`str`, `Path`):
                - Either the `repo_id` (string) of an object hosted on the Hub, e.g. `lerobot/diffusion_pusht`.
                - Or a path to a directory holding files written by `.save_pretrained`,
                  e.g. `../path/to/my_model_directory/`.
            force_download (`bool`, *optional*, defaults to `False`):
                Re-download files from the Hub even if they are already cached.
            resume_download (`bool`, *optional*):
                Download option accepted by the signature; the base class does not use it.
            proxies (`Dict[str, str]`, *optional*):
                Proxy servers keyed by protocol or endpoint, e.g. `{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. Used for every request.
            token (`str` or `bool`, *optional*):
                Token used as HTTP bearer authorization for remote files. By default, the token cached by
                `huggingface-cli login` is used.
            cache_dir (`str`, `Path`, *optional*):
                Folder where cached files are stored.
            local_files_only (`bool`, *optional*, defaults to `False`):
                If `True`, never download and only use locally cached files when they exist.
            revision (`str`, *optional*):
                Hub revision to load: a branch name, a git tag or a commit id. Defaults to the latest commit on
                the `main` branch.
            kwargs (`Dict`, *optional*):
                Additional keyword arguments passed to the object at initialization.
        """
        raise NotImplementedError

    @validate_hf_hub_args
    def push_to_hub(
        self,
        repo_id: str,
        *,
        commit_message: str | None = None,
        private: bool | None = None,
        token: str | None = None,
        branch: str | None = None,
        create_pr: bool | None = None,
        allow_patterns: list[str] | str | None = None,
        ignore_patterns: list[str] | str | None = None,
        delete_patterns: list[str] | str | None = None,
        card_kwargs: dict[str, Any] | None = None,
    ) -> str:
        """
        Upload this object's files to a model repository on the Hub in a single commit.

        The repository is created if it does not already exist. The object is serialized with `save_pretrained`
        into a temporary directory, which is then uploaded with `HfApi.upload_folder`. Use `allow_patterns` /
        `ignore_patterns` to filter the pushed files and `delete_patterns` to remove existing remote files in the
        same commit (see the [`upload_folder`] reference for details).

        Args:
            repo_id (`str`):
                Repository to push to (e.g. `"username/my-model"`). The id returned by `create_repo` is the one
                actually used for the upload.
            commit_message (`str`, *optional*):
                Commit message. When omitted it is derived from the class name: `"Upload policy"` if the name
                contains `"Policy"`, else `"Upload config"` if it contains `"Config"`, else `"Upload <ClassName>"`.
            private (`bool`, *optional*):
                Whether a newly created repository should be private. If `None` (default), the repo will be
                public unless the organization's default is private.
            token (`str`, *optional*):
                Token used to build the `HfApi` client. By default, the token cached by `huggingface-cli login`
                is used.
            branch (`str`, *optional*):
                Git branch to push to; passed as `revision` to `upload_folder`. This defaults to `"main"`.
            create_pr (`boolean`, *optional*):
                Whether to open a Pull Request from `branch` with the commit.
            allow_patterns (`List[str]` or `str`, *optional*):
                Only files matching at least one of these patterns are pushed.
            ignore_patterns (`List[str]` or `str`, *optional*):
                Files matching any of these patterns are not pushed.
            delete_patterns (`List[str]` or `str`, *optional*):
                Remote files matching any of these patterns are deleted from the repo in the same commit.
            card_kwargs (`Dict[str, Any]`, *optional*):
                Extra arguments intended to customize the repository card; forwarded to `save_pretrained`.

        Returns:
            The result of `HfApi.upload_folder`: the URL of the commit in the target repository.
        """
        hub_api = HfApi(token=token)
        created_repo = hub_api.create_repo(repo_id=repo_id, private=private, exist_ok=True)
        resolved_repo_id = created_repo.repo_id

        if commit_message is None:
            class_name = self.__class__.__name__
            # Checked in order: the first keyword found in the class name wins.
            keyword_messages = (("Policy", "Upload policy"), ("Config", "Upload config"))
            commit_message = next(
                (message for keyword, message in keyword_messages if keyword in class_name),
                f"Upload {class_name}",
            )

        # Stage every file in a throwaway directory so the whole upload lands in one commit.
        with TemporaryDirectory(ignore_cleanup_errors=True) as scratch_dir:
            # Repo ids usually contain a "/", giving a nested path; save_pretrained creates the parents.
            staging_path = Path(scratch_dir) / resolved_repo_id
            self.save_pretrained(staging_path, card_kwargs=card_kwargs)
            # Return from inside the `with` so the upload completes before the staging dir is removed.
            return hub_api.upload_folder(
                repo_id=resolved_repo_id,
                repo_type="model",
                folder_path=staging_path,
                commit_message=commit_message,
                revision=branch,
                create_pr=create_pr,
                allow_patterns=allow_patterns,
                ignore_patterns=ignore_patterns,
                delete_patterns=delete_patterns,
            )
|