geored's picture
Upload folder using huggingface_hub
fe41391 verified
raw
history blame
23.5 kB
import inspect
import json
import os
from dataclasses import asdict, is_dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Dict, List, Optional, Type, TypeVar, Union, get_args
from .constants import CONFIG_NAME, PYTORCH_WEIGHTS_NAME, SAFETENSORS_SINGLE_FILE
from .file_download import hf_hub_download
from .hf_api import HfApi
from .utils import (
EntryNotFoundError,
HfHubHTTPError,
SoftTemporaryDirectory,
is_safetensors_available,
is_torch_available,
logging,
validate_hf_hub_args,
)
from .utils._deprecation import _deprecate_arguments
if TYPE_CHECKING:
from _typeshed import DataclassInstance
if is_torch_available():
import torch # type: ignore
if is_safetensors_available():
from safetensors.torch import load_model as load_model_as_safetensor
from safetensors.torch import save_model as save_model_as_safetensor
logger = logging.get_logger(__name__)
# Generic variable that is either ModelHubMixin or a subclass thereof
T = TypeVar("T", bound="ModelHubMixin")
class ModelHubMixin:
"""
A generic mixin to integrate ANY machine learning framework with the Hub.
To integrate your framework, your model class must inherit from this class. Custom logic for saving/loading models
have to be overwritten in [`_from_pretrained`] and [`_save_pretrained`]. [`PyTorchModelHubMixin`] is a good example
of mixin integration with the Hub. Check out our [integration guide](../guides/integrations) for more instructions.
Example:
```python
>>> from dataclasses import dataclass
>>> from huggingface_hub import ModelHubMixin
# Define your model configuration (optional)
>>> @dataclass
... class Config:
... foo: int = 512
... bar: str = "cpu"
# Inherit from ModelHubMixin (and optionally from your framework's model class)
>>> class MyCustomModel(ModelHubMixin):
... def __init__(self, config: Config):
... # define how to initialize your model
... super().__init__()
... ...
...
... def _save_pretrained(self, save_directory: Path) -> None:
... # define how to serialize your model
... ...
...
... @classmethod
... def from_pretrained(
... cls: Type[T],
... pretrained_model_name_or_path: Union[str, Path],
... *,
... force_download: bool = False,
... resume_download: bool = False,
... proxies: Optional[Dict] = None,
... token: Optional[Union[str, bool]] = None,
... cache_dir: Optional[Union[str, Path]] = None,
... local_files_only: bool = False,
... revision: Optional[str] = None,
... **model_kwargs,
... ) -> T:
... # define how to deserialize your model
... ...
>>> model = MyCustomModel(config=Config(foo=256, bar="gpu"))
# Save model weights to local directory
>>> model.save_pretrained("my-awesome-model")
# Push model weights to the Hub
>>> model.push_to_hub("my-awesome-model")
# Download and initialize weights from the Hub
>>> reloaded_model = MyCustomModel.from_pretrained("username/my-awesome-model")
>>> reloaded_model.config
Config(foo=256, bar="gpu")
```
"""
config: Optional[Union[dict, "DataclassInstance"]] = None
# ^ optional config attribute automatically set in `from_pretrained` (if not already set by the subclass)
def __new__(cls, *args, **kwargs) -> "ModelHubMixin":
instance = super().__new__(cls)
# Set `config` attribute if not already set by the subclass
if instance.config is None:
if "config" in kwargs:
instance.config = kwargs["config"]
elif len(args) > 0:
sig = inspect.signature(cls.__init__)
parameters = list(sig.parameters)[1:] # remove `self`
for key, value in zip(parameters, args):
if key == "config":
instance.config = value
break
return instance
def save_pretrained(
self,
save_directory: Union[str, Path],
*,
config: Optional[Union[dict, "DataclassInstance"]] = None,
repo_id: Optional[str] = None,
push_to_hub: bool = False,
**push_to_hub_kwargs,
) -> Optional[str]:
"""
Save weights in local directory.
Args:
save_directory (`str` or `Path`):
Path to directory in which the model weights and configuration will be saved.
config (`dict` or `DataclassInstance`, *optional*):
Model configuration specified as a key/value dictionary or a dataclass instance.
push_to_hub (`bool`, *optional*, defaults to `False`):
Whether or not to push your model to the Huggingface Hub after saving it.
repo_id (`str`, *optional*):
ID of your repository on the Hub. Used only if `push_to_hub=True`. Will default to the folder name if
not provided.
kwargs:
Additional key word arguments passed along to the [`~ModelHubMixin.push_to_hub`] method.
"""
save_directory = Path(save_directory)
save_directory.mkdir(parents=True, exist_ok=True)
# save model weights/files (framework-specific)
self._save_pretrained(save_directory)
# save config (if provided)
if config is None:
config = self.config
if config is not None:
if is_dataclass(config):
config = asdict(config) # type: ignore[arg-type]
(save_directory / CONFIG_NAME).write_text(json.dumps(config, indent=2))
# push to the Hub if required
if push_to_hub:
kwargs = push_to_hub_kwargs.copy() # soft-copy to avoid mutating input
if config is not None: # kwarg for `push_to_hub`
kwargs["config"] = config
if repo_id is None:
repo_id = save_directory.name # Defaults to `save_directory` name
return self.push_to_hub(repo_id=repo_id, **kwargs)
return None
def _save_pretrained(self, save_directory: Path) -> None:
"""
Overwrite this method in subclass to define how to save your model.
Check out our [integration guide](../guides/integrations) for instructions.
Args:
save_directory (`str` or `Path`):
Path to directory in which the model weights and configuration will be saved.
"""
raise NotImplementedError
@classmethod
@validate_hf_hub_args
def from_pretrained(
cls: Type[T],
pretrained_model_name_or_path: Union[str, Path],
*,
force_download: bool = False,
resume_download: bool = False,
proxies: Optional[Dict] = None,
token: Optional[Union[str, bool]] = None,
cache_dir: Optional[Union[str, Path]] = None,
local_files_only: bool = False,
revision: Optional[str] = None,
**model_kwargs,
) -> T:
"""
Download a model from the Huggingface Hub and instantiate it.
Args:
pretrained_model_name_or_path (`str`, `Path`):
- Either the `model_id` (string) of a model hosted on the Hub, e.g. `bigscience/bloom`.
- Or a path to a `directory` containing model weights saved using
[`~transformers.PreTrainedModel.save_pretrained`], e.g., `../path/to/my_model_directory/`.
revision (`str`, *optional*):
Revision of the model on the Hub. Can be a branch name, a git tag or any commit id.
Defaults to the latest commit on `main` branch.
force_download (`bool`, *optional*, defaults to `False`):
Whether to force (re-)downloading the model weights and configuration files from the Hub, overriding
the existing cache.
resume_download (`bool`, *optional*, defaults to `False`):
Whether to delete incompletely received files. Will attempt to resume the download if such a file exists.
proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}`. The proxies are used on every request.
token (`str` or `bool`, *optional*):
The token to use as HTTP bearer authorization for remote files. By default, it will use the token
cached when running `huggingface-cli login`.
cache_dir (`str`, `Path`, *optional*):
Path to the folder where cached files are stored.
local_files_only (`bool`, *optional*, defaults to `False`):
If `True`, avoid downloading the file and return the path to the local cached file if it exists.
model_kwargs (`Dict`, *optional*):
Additional kwargs to pass to the model during initialization.
"""
model_id = str(pretrained_model_name_or_path)
config_file: Optional[str] = None
if os.path.isdir(model_id):
if CONFIG_NAME in os.listdir(model_id):
config_file = os.path.join(model_id, CONFIG_NAME)
else:
logger.warning(f"{CONFIG_NAME} not found in {Path(model_id).resolve()}")
else:
try:
config_file = hf_hub_download(
repo_id=model_id,
filename=CONFIG_NAME,
revision=revision,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
token=token,
local_files_only=local_files_only,
)
except HfHubHTTPError as e:
logger.info(f"{CONFIG_NAME} not found on the HuggingFace Hub: {str(e)}")
config = None
if config_file is not None:
# Read config
with open(config_file, "r", encoding="utf-8") as f:
config = json.load(f)
# Check if class expect a `config` argument
init_parameters = inspect.signature(cls.__init__).parameters
if "config" in init_parameters:
# Check if `config` argument is a dataclass
config_annotation = init_parameters["config"].annotation
if config_annotation is inspect.Parameter.empty:
pass # no annotation
elif is_dataclass(config_annotation):
config = config_annotation(**config) # expect a dataclass
else:
# if Optional/Union annotation => check if a dataclass is in the Union
for _sub_annotation in get_args(config_annotation):
if is_dataclass(_sub_annotation):
config = _sub_annotation(**config)
break
# Forward config to model initialization
model_kwargs["config"] = config
elif any(param.kind == inspect.Parameter.VAR_KEYWORD for param in init_parameters.values()):
# If __init__ accepts **kwargs, let's forward the config as well (as a dict)
model_kwargs["config"] = config
instance = cls._from_pretrained(
model_id=str(model_id),
revision=revision,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
local_files_only=local_files_only,
token=token,
**model_kwargs,
)
# Implicitly set the config as instance attribute if not already set by the class
# This way `config` will be available when calling `save_pretrained` or `push_to_hub`.
if config is not None and instance.config is None:
instance.config = config
return instance
@classmethod
def _from_pretrained(
cls: Type[T],
*,
model_id: str,
revision: Optional[str],
cache_dir: Optional[Union[str, Path]],
force_download: bool,
proxies: Optional[Dict],
resume_download: bool,
local_files_only: bool,
token: Optional[Union[str, bool]],
**model_kwargs,
) -> T:
"""Overwrite this method in subclass to define how to load your model from pretrained.
Use [`hf_hub_download`] or [`snapshot_download`] to download files from the Hub before loading them. Most
args taken as input can be directly passed to those 2 methods. If needed, you can add more arguments to this
method using "model_kwargs". For example [`PyTorchModelHubMixin._from_pretrained`] takes as input a `map_location`
parameter to set on which device the model should be loaded.
Check out our [integration guide](../guides/integrations) for more instructions.
Args:
model_id (`str`):
ID of the model to load from the Huggingface Hub (e.g. `bigscience/bloom`).
revision (`str`, *optional*):
Revision of the model on the Hub. Can be a branch name, a git tag or any commit id. Defaults to the
latest commit on `main` branch.
force_download (`bool`, *optional*, defaults to `False`):
Whether to force (re-)downloading the model weights and configuration files from the Hub, overriding
the existing cache.
resume_download (`bool`, *optional*, defaults to `False`):
Whether to delete incompletely received files. Will attempt to resume the download if such a file exists.
proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint (e.g., `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}`).
token (`str` or `bool`, *optional*):
The token to use as HTTP bearer authorization for remote files. By default, it will use the token
cached when running `huggingface-cli login`.
cache_dir (`str`, `Path`, *optional*):
Path to the folder where cached files are stored.
local_files_only (`bool`, *optional*, defaults to `False`):
If `True`, avoid downloading the file and return the path to the local cached file if it exists.
model_kwargs:
Additional keyword arguments passed along to the [`~ModelHubMixin._from_pretrained`] method.
"""
raise NotImplementedError
@_deprecate_arguments(
version="0.23.0",
deprecated_args=["api_endpoint"],
custom_message="Use `HF_ENDPOINT` environment variable instead.",
)
@validate_hf_hub_args
def push_to_hub(
self,
repo_id: str,
*,
config: Optional[Union[dict, "DataclassInstance"]] = None,
commit_message: str = "Push model using huggingface_hub.",
private: bool = False,
token: Optional[str] = None,
branch: Optional[str] = None,
create_pr: Optional[bool] = None,
allow_patterns: Optional[Union[List[str], str]] = None,
ignore_patterns: Optional[Union[List[str], str]] = None,
delete_patterns: Optional[Union[List[str], str]] = None,
# TODO: remove once deprecated
api_endpoint: Optional[str] = None,
) -> str:
"""
Upload model checkpoint to the Hub.
Use `allow_patterns` and `ignore_patterns` to precisely filter which files should be pushed to the hub. Use
`delete_patterns` to delete existing remote files in the same commit. See [`upload_folder`] reference for more
details.
Args:
repo_id (`str`):
ID of the repository to push to (example: `"username/my-model"`).
config (`dict` or `DataclassInstance`, *optional*):
Model configuration specified as a key/value dictionary or a dataclass instance.
commit_message (`str`, *optional*):
Message to commit while pushing.
private (`bool`, *optional*, defaults to `False`):
Whether the repository created should be private.
api_endpoint (`str`, *optional*):
The API endpoint to use when pushing the model to the hub.
token (`str`, *optional*):
The token to use as HTTP bearer authorization for remote files. By default, it will use the token
cached when running `huggingface-cli login`.
branch (`str`, *optional*):
The git branch on which to push the model. This defaults to `"main"`.
create_pr (`boolean`, *optional*):
Whether or not to create a Pull Request from `branch` with that commit. Defaults to `False`.
allow_patterns (`List[str]` or `str`, *optional*):
If provided, only files matching at least one pattern are pushed.
ignore_patterns (`List[str]` or `str`, *optional*):
If provided, files matching any of the patterns are not pushed.
delete_patterns (`List[str]` or `str`, *optional*):
If provided, remote files matching any of the patterns will be deleted from the repo.
Returns:
The url of the commit of your model in the given repository.
"""
api = HfApi(endpoint=api_endpoint, token=token)
repo_id = api.create_repo(repo_id=repo_id, private=private, exist_ok=True).repo_id
# Push the files to the repo in a single commit
with SoftTemporaryDirectory() as tmp:
saved_path = Path(tmp) / repo_id
self.save_pretrained(saved_path, config=config)
return api.upload_folder(
repo_id=repo_id,
repo_type="model",
folder_path=saved_path,
commit_message=commit_message,
revision=branch,
create_pr=create_pr,
allow_patterns=allow_patterns,
ignore_patterns=ignore_patterns,
delete_patterns=delete_patterns,
)
class PyTorchModelHubMixin(ModelHubMixin):
"""
Implementation of [`ModelHubMixin`] to provide model Hub upload/download capabilities to PyTorch models. The model
is set in evaluation mode by default using `model.eval()` (dropout modules are deactivated). To train the model,
you should first set it back in training mode with `model.train()`.
Example:
```python
>>> from dataclasses import dataclass
>>> import torch
>>> import torch.nn as nn
>>> from huggingface_hub import PyTorchModelHubMixin
>>> @dataclass
... class Config:
... hidden_size: int = 512
... vocab_size: int = 30000
... output_size: int = 4
>>> class MyModel(nn.Module, PyTorchModelHubMixin):
... def __init__(self, config: Config):
... super().__init__()
... self.param = nn.Parameter(torch.rand(config.hidden_size, config.vocab_size))
... self.linear = nn.Linear(config.output_size, config.vocab_size)
... def forward(self, x):
... return self.linear(x + self.param)
>>> model = MyModel()
# Save model weights to local directory
>>> model.save_pretrained("my-awesome-model")
# Push model weights to the Hub
>>> model.push_to_hub("my-awesome-model")
# Download and initialize weights from the Hub
>>> model = MyModel.from_pretrained("username/my-awesome-model")
```
"""
def _save_pretrained(self, save_directory: Path) -> None:
"""Save weights from a Pytorch model to a local directory."""
model_to_save = self.module if hasattr(self, "module") else self # type: ignore
save_model_as_safetensor(model_to_save, str(save_directory / SAFETENSORS_SINGLE_FILE))
@classmethod
def _from_pretrained(
cls,
*,
model_id: str,
revision: Optional[str],
cache_dir: Optional[Union[str, Path]],
force_download: bool,
proxies: Optional[Dict],
resume_download: bool,
local_files_only: bool,
token: Union[str, bool, None],
map_location: str = "cpu",
strict: bool = False,
**model_kwargs,
):
"""Load Pytorch pretrained weights and return the loaded model."""
model = cls(**model_kwargs)
if os.path.isdir(model_id):
print("Loading weights from local directory")
model_file = os.path.join(model_id, SAFETENSORS_SINGLE_FILE)
return cls._load_as_safetensor(model, model_file, map_location, strict)
else:
try:
model_file = hf_hub_download(
repo_id=model_id,
filename=SAFETENSORS_SINGLE_FILE,
revision=revision,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
token=token,
local_files_only=local_files_only,
)
return cls._load_as_safetensor(model, model_file, map_location, strict)
except EntryNotFoundError:
model_file = hf_hub_download(
repo_id=model_id,
filename=PYTORCH_WEIGHTS_NAME,
revision=revision,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
token=token,
local_files_only=local_files_only,
)
return cls._load_as_pickle(model, model_file, map_location, strict)
@classmethod
def _load_as_pickle(cls, model: T, model_file: str, map_location: str, strict: bool) -> T:
state_dict = torch.load(model_file, map_location=torch.device(map_location))
model.load_state_dict(state_dict, strict=strict) # type: ignore
model.eval() # type: ignore
return model
@classmethod
def _load_as_safetensor(cls, model: T, model_file: str, map_location: str, strict: bool) -> T:
load_model_as_safetensor(model, model_file, strict=strict) # type: ignore [arg-type]
if map_location != "cpu":
# TODO: remove this once https://github.com/huggingface/safetensors/pull/449 is merged.
logger.warning(
"Loading model weights on other devices than 'cpu' is not supported natively."
" This means that the model is loaded on 'cpu' first and then copied to the device."
" This leads to a slower loading time."
" Support for loading directly on other devices is planned to be added in future releases."
" See https://github.com/huggingface/huggingface_hub/pull/2086 for more details."
)
model.to(map_location) # type: ignore [attr-defined]
return model