import itertools
import os
import re
import tempfile
from dataclasses import dataclass
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple, Union
from urllib.parse import quote, unquote

import fsspec

from ._commit_api import CommitOperationCopy, CommitOperationDelete
from .constants import DEFAULT_REVISION, ENDPOINT, REPO_TYPE_MODEL, REPO_TYPES_MAPPING, REPO_TYPES_URL_PREFIXES
from .hf_api import HfApi
from .utils import (
    EntryNotFoundError,
    HFValidationError,
    RepositoryNotFoundError,
    RevisionNotFoundError,
    hf_raise_for_status,
    http_backoff,
    paginate,
    parse_datetime,
)

# Regex used to match special revisions with "/" in them (see #1710)
SPECIAL_REFS_REVISION_REGEX = re.compile(
    r"""
    (^refs\/convert\/parquet) # `refs/convert/parquet` revisions
    |
    (^refs\/pr\/\d+)          # PR revisions
    """,
    re.VERBOSE,
)
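
# Illustrative only: special refs such as "refs/convert/parquet" or "refs/pr/6" (hypothetical PR
# number) are expected to match this pattern, while plain branch or tag names like "main" or
# "v1.0" are not.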


@dataclass
class HfFileSystemResolvedPath:
    """Data structure containing information about a resolved Hugging Face file system path."""

    repo_type: str
    repo_id: str
    revision: str
    path_in_repo: str

    def unresolve(self) -> str:
        return (
            f"{REPO_TYPES_URL_PREFIXES.get(self.repo_type, '') + self.repo_id}@{safe_revision(self.revision)}/{self.path_in_repo}"
            .rstrip("/")
        )
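
    # Illustrative only (hypothetical repo): unresolve() on a resolved path with repo_type="dataset",
    # repo_id="my-username/my-dataset", revision="main", path_in_repo="data/train.json" would return
    # "datasets/my-username/my-dataset@main/data/train.json".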


class HfFileSystem(fsspec.AbstractFileSystem):
    """
    Access a remote Hugging Face Hub repository as if it were a local file system.

    Args:
        endpoint (`str`, *optional*):
            The endpoint to use. If not provided, the default one (https://huggingface.co) is used.
        token (`str`, *optional*):
            Authentication token, obtained with [`HfApi.login`] method. Will default to the stored token.

    Usage:

    ```python
    >>> from huggingface_hub import HfFileSystem

    >>> fs = HfFileSystem()

    >>> # List files
    >>> fs.glob("my-username/my-model/*.bin")
    ['my-username/my-model/pytorch_model.bin']

    >>> fs.ls("datasets/my-username/my-dataset", detail=False)
    ['datasets/my-username/my-dataset/.gitattributes', 'datasets/my-username/my-dataset/README.md', 'datasets/my-username/my-dataset/data.json']

    >>> # Read/write files
    >>> with fs.open("my-username/my-model/pytorch_model.bin") as f:
    ...     data = f.read()
    >>> with fs.open("my-username/my-model/pytorch_model.bin", "wb") as f:
    ...     f.write(data)
    ```
    """

    root_marker = ""
    protocol = "hf"

    def __init__(
        self,
        *args,
        endpoint: Optional[str] = None,
        token: Optional[str] = None,
        **storage_options,
    ):
        super().__init__(*args, **storage_options)
        self.endpoint = endpoint or ENDPOINT
        self.token = token
        self._api = HfApi(endpoint=endpoint, token=token)
        # Maps (repo_type, repo_id, revision) to a 2-tuple with:
        #  * the 1st element indicating whether the repository and the revision exist
        #  * the 2nd element being the exception raised if the repository or revision doesn't exist
        self._repo_and_revision_exists_cache: Dict[
            Tuple[str, str, Optional[str]], Tuple[bool, Optional[Exception]]
        ] = {}

    def _repo_and_revision_exist(
        self, repo_type: str, repo_id: str, revision: Optional[str]
    ) -> Tuple[bool, Optional[Exception]]:
        if (repo_type, repo_id, revision) not in self._repo_and_revision_exists_cache:
            try:
                self._api.repo_info(repo_id, revision=revision, repo_type=repo_type)
            except (RepositoryNotFoundError, HFValidationError) as e:
                self._repo_and_revision_exists_cache[(repo_type, repo_id, revision)] = False, e
                self._repo_and_revision_exists_cache[(repo_type, repo_id, None)] = False, e
            except RevisionNotFoundError as e:
                self._repo_and_revision_exists_cache[(repo_type, repo_id, revision)] = False, e
                self._repo_and_revision_exists_cache[(repo_type, repo_id, None)] = True, None
            else:
                self._repo_and_revision_exists_cache[(repo_type, repo_id, revision)] = True, None
                self._repo_and_revision_exists_cache[(repo_type, repo_id, None)] = True, None
        return self._repo_and_revision_exists_cache[(repo_type, repo_id, revision)]

    def exists(self, path, **kwargs):
        """Is there a file at the given path

        Exact same implementation as in fsspec except that instead of catching all exceptions, we only catch when it's
        not a `NotImplementedError` (which we do want to raise). Catching a `NotImplementedError` can lead to undesired
        behavior.

        Adapted from https://github.com/fsspec/filesystem_spec/blob/f5d24b80a0768bf07a113647d7b4e74a3a2999e0/fsspec/spec.py#L649C1-L656C25
        """
        try:
            self.info(path, **kwargs)
            return True
        except Exception as e:  # noqa: E722
            if isinstance(e, NotImplementedError):
                raise
            # any exception allowed bar FileNotFoundError?
            return False

    def resolve_path(self, path: str, revision: Optional[str] = None) -> HfFileSystemResolvedPath:
        def _align_revision_in_path_with_revision(
            revision_in_path: Optional[str], revision: Optional[str]
        ) -> Optional[str]:
            if revision is not None:
                if revision_in_path is not None and revision_in_path != revision:
                    raise ValueError(
                        f'Revision specified in path ("{revision_in_path}") and in `revision` argument ("{revision}")'
                        " are not the same."
                    )
            else:
                revision = revision_in_path
            return revision

        path = self._strip_protocol(path)
        if not path:
            # can't list repositories at root
            raise NotImplementedError("Access to repositories lists is not implemented.")
        elif path.split("/")[0] + "/" in REPO_TYPES_URL_PREFIXES.values():
            if "/" not in path:
                # can't list repositories at the repository type level
                raise NotImplementedError("Access to repositories lists is not implemented.")
            repo_type, path = path.split("/", 1)
            repo_type = REPO_TYPES_MAPPING[repo_type]
        else:
            repo_type = REPO_TYPE_MODEL
        if path.count("/") > 0:
            if "@" in path:
                repo_id, revision_in_path = path.split("@", 1)
                if "/" in revision_in_path:
                    match = SPECIAL_REFS_REVISION_REGEX.search(revision_in_path)
                    if match is not None and revision in (None, match.group()):
                        # Handle `refs/convert/parquet` and PR revisions separately
                        path_in_repo = SPECIAL_REFS_REVISION_REGEX.sub("", revision_in_path).lstrip("/")
                        revision_in_path = match.group()
                    else:
                        revision_in_path, path_in_repo = revision_in_path.split("/", 1)
                else:
                    path_in_repo = ""
                revision_in_path = unquote(revision_in_path)
                revision = _align_revision_in_path_with_revision(revision_in_path, revision)
                repo_and_revision_exist, err = self._repo_and_revision_exist(repo_type, repo_id, revision)
                if not repo_and_revision_exist:
                    raise FileNotFoundError(path) from err
            else:
                repo_id_with_namespace = "/".join(path.split("/")[:2])
                path_in_repo_with_namespace = "/".join(path.split("/")[2:])
                repo_id_without_namespace = path.split("/")[0]
                path_in_repo_without_namespace = "/".join(path.split("/")[1:])
                repo_id = repo_id_with_namespace
                path_in_repo = path_in_repo_with_namespace
                repo_and_revision_exist, err = self._repo_and_revision_exist(repo_type, repo_id, revision)
                if not repo_and_revision_exist:
                    if isinstance(err, (RepositoryNotFoundError, HFValidationError)):
                        repo_id = repo_id_without_namespace
                        path_in_repo = path_in_repo_without_namespace
                        repo_and_revision_exist, _ = self._repo_and_revision_exist(repo_type, repo_id, revision)
                        if not repo_and_revision_exist:
                            raise FileNotFoundError(path) from err
                    else:
                        raise FileNotFoundError(path) from err
        else:
            repo_id = path
            path_in_repo = ""
            if "@" in path:
                repo_id, revision_in_path = path.split("@", 1)
                revision_in_path = unquote(revision_in_path)
                revision = _align_revision_in_path_with_revision(revision_in_path, revision)
            repo_and_revision_exist, _ = self._repo_and_revision_exist(repo_type, repo_id, revision)
            if not repo_and_revision_exist:
                raise NotImplementedError("Access to repositories lists is not implemented.")

        revision = revision if revision is not None else DEFAULT_REVISION
        return HfFileSystemResolvedPath(repo_type, repo_id, revision, path_in_repo)
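
    # Illustrative example (hypothetical repo names): a path such as
    # "datasets/my-username/my-dataset@refs%2Fconvert%2Fparquet/en/train.parquet" (or its unquoted
    # form with "refs/convert/parquet") would be expected to resolve to repo_type="dataset",
    # repo_id="my-username/my-dataset", revision="refs/convert/parquet" and
    # path_in_repo="en/train.parquet", provided the repository and revision exist on the Hub.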

    def invalidate_cache(self, path: Optional[str] = None) -> None:
        if not path:
            self.dircache.clear()
            self._repo_and_revision_exists_cache.clear()
        else:
            path = self.resolve_path(path).unresolve()
            while path:
                self.dircache.pop(path, None)
                path = self._parent(path)

    def _open(
        self,
        path: str,
        mode: str = "rb",
        revision: Optional[str] = None,
        **kwargs,
    ) -> "HfFileSystemFile":
        if mode == "ab":
            raise NotImplementedError("Appending to remote files is not yet supported.")
        return HfFileSystemFile(self, path, mode=mode, revision=revision, **kwargs)

    def _rm(self, path: str, revision: Optional[str] = None, **kwargs) -> None:
        resolved_path = self.resolve_path(path, revision=revision)
        self._api.delete_file(
            path_in_repo=resolved_path.path_in_repo,
            repo_id=resolved_path.repo_id,
            token=self.token,
            repo_type=resolved_path.repo_type,
            revision=resolved_path.revision,
            commit_message=kwargs.get("commit_message"),
            commit_description=kwargs.get("commit_description"),
        )
        self.invalidate_cache(path=resolved_path.unresolve())

    def rm(
        self,
        path: str,
        recursive: bool = False,
        maxdepth: Optional[int] = None,
        revision: Optional[str] = None,
        **kwargs,
    ) -> None:
        resolved_path = self.resolve_path(path, revision=revision)
        root_path = REPO_TYPES_URL_PREFIXES.get(resolved_path.repo_type, "") + resolved_path.repo_id
        paths = self.expand_path(path, recursive=recursive, maxdepth=maxdepth, revision=resolved_path.revision)
        paths_in_repo = [path[len(root_path) + 1 :] for path in paths if not self.isdir(path)]
        operations = [CommitOperationDelete(path_in_repo=path_in_repo) for path_in_repo in paths_in_repo]
        commit_message = f"Delete {path} "
        commit_message += "recursively " if recursive else ""
        commit_message += f"up to depth {maxdepth} " if maxdepth is not None else ""
        # TODO: use `commit_description` to list all the deleted paths?
        self._api.create_commit(
            repo_id=resolved_path.repo_id,
            repo_type=resolved_path.repo_type,
            token=self.token,
            operations=operations,
            revision=resolved_path.revision,
            commit_message=kwargs.get("commit_message", commit_message),
            commit_description=kwargs.get("commit_description"),
        )
        self.invalidate_cache(path=resolved_path.unresolve())

    def ls(
        self, path: str, detail: bool = True, refresh: bool = False, revision: Optional[str] = None, **kwargs
    ) -> List[Union[str, Dict[str, Any]]]:
        """List the contents of a directory."""
        resolved_path = self.resolve_path(path, revision=revision)
        revision_in_path = "@" + safe_revision(resolved_path.revision)
        has_revision_in_path = revision_in_path in path
        path = resolved_path.unresolve()
        if path not in self.dircache or refresh:
            path_prefix = (
                HfFileSystemResolvedPath(
                    resolved_path.repo_type, resolved_path.repo_id, resolved_path.revision, ""
                ).unresolve()
                + "/"
            )
            tree_path = path
            tree_iter = self._iter_tree(tree_path, revision=resolved_path.revision)
            try:
                tree_item = next(tree_iter)
            except EntryNotFoundError:
                if "/" in resolved_path.path_in_repo:
                    tree_path = self._parent(path)
                    tree_iter = self._iter_tree(tree_path, revision=resolved_path.revision)
                else:
                    raise
            else:
                tree_iter = itertools.chain([tree_item], tree_iter)
            child_infos = []
            for tree_item in tree_iter:
                child_info = {
                    "name": path_prefix + tree_item["path"],
                    "size": tree_item["size"],
                    "type": tree_item["type"],
                }
                if tree_item["type"] == "file":
                    child_info.update(
                        {
                            "blob_id": tree_item["oid"],
                            "lfs": tree_item.get("lfs"),
                            "last_modified": parse_datetime(tree_item["lastCommit"]["date"]),
                        },
                    )
                child_infos.append(child_info)
            self.dircache[tree_path] = child_infos
        out = self._ls_from_cache(path)
        if not has_revision_in_path:
            out = [{**o, "name": o["name"].replace(revision_in_path, "", 1)} for o in out]
        return out if detail else [o["name"] for o in out]

    def _iter_tree(self, path: str, revision: Optional[str] = None):
        # TODO: use HfApi.list_files_info instead when it supports "lastCommit" and "expand=True"
        # See https://github.com/huggingface/moon-landing/issues/5993
        resolved_path = self.resolve_path(path, revision=revision)
        path = f"{self._api.endpoint}/api/{resolved_path.repo_type}s/{resolved_path.repo_id}/tree/{safe_quote(resolved_path.revision)}/{resolved_path.path_in_repo}".rstrip(
            "/"
        )
        headers = self._api._build_hf_headers()
        yield from paginate(path, params={"expand": True}, headers=headers)
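
    # Illustrative request (hypothetical repo): listing "datasets/my-username/my-dataset" at the
    # "main" revision would paginate over "<endpoint>/api/datasets/my-username/my-dataset/tree/main",
    # where each returned entry carries "path", "size" and "type", plus "oid", "lfs" and
    # "lastCommit" for files (the fields consumed by `ls` above).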

    def cp_file(self, path1: str, path2: str, revision: Optional[str] = None, **kwargs) -> None:
        resolved_path1 = self.resolve_path(path1, revision=revision)
        resolved_path2 = self.resolve_path(path2, revision=revision)

        same_repo = (
            resolved_path1.repo_type == resolved_path2.repo_type and resolved_path1.repo_id == resolved_path2.repo_id
        )

        # TODO: Wait for https://github.com/huggingface/huggingface_hub/issues/1083 to be resolved to simplify this logic
        if same_repo and self.info(path1, revision=resolved_path1.revision)["lfs"] is not None:
            commit_message = f"Copy {path1} to {path2}"
            self._api.create_commit(
                repo_id=resolved_path1.repo_id,
                repo_type=resolved_path1.repo_type,
                revision=resolved_path2.revision,
                commit_message=kwargs.get("commit_message", commit_message),
                commit_description=kwargs.get("commit_description", ""),
                operations=[
                    CommitOperationCopy(
                        src_path_in_repo=resolved_path1.path_in_repo,
                        path_in_repo=resolved_path2.path_in_repo,
                        src_revision=resolved_path1.revision,
                    )
                ],
            )
        else:
            with self.open(path1, "rb", revision=resolved_path1.revision) as f:
                content = f.read()
            commit_message = f"Copy {path1} to {path2}"
            self._api.upload_file(
                path_or_fileobj=content,
                path_in_repo=resolved_path2.path_in_repo,
                repo_id=resolved_path2.repo_id,
                token=self.token,
                repo_type=resolved_path2.repo_type,
                revision=resolved_path2.revision,
                commit_message=kwargs.get("commit_message", commit_message),
                commit_description=kwargs.get("commit_description"),
            )
        self.invalidate_cache(path=resolved_path1.unresolve())
        self.invalidate_cache(path=resolved_path2.unresolve())

    def modified(self, path: str, **kwargs) -> datetime:
        info = self.info(path, **kwargs)
        if "last_modified" not in info:
            raise IsADirectoryError(path)
        return info["last_modified"]

    def info(self, path: str, **kwargs) -> Dict[str, Any]:
        resolved_path = self.resolve_path(path)
        if not resolved_path.path_in_repo:
            revision_in_path = "@" + safe_revision(resolved_path.revision)
            has_revision_in_path = revision_in_path in path
            name = resolved_path.unresolve()
            name = name.replace(revision_in_path, "", 1) if not has_revision_in_path else name
            return {"name": name, "size": 0, "type": "directory"}
        return super().info(path, **kwargs)


class HfFileSystemFile(fsspec.spec.AbstractBufferedFile):
    def __init__(self, fs: HfFileSystem, path: str, revision: Optional[str] = None, **kwargs):
        super().__init__(fs, path, **kwargs)
        self.fs: HfFileSystem
        self.resolved_path = fs.resolve_path(path, revision=revision)

    def _fetch_range(self, start: int, end: int) -> bytes:
        headers = {
            "range": f"bytes={start}-{end - 1}",
            **self.fs._api._build_hf_headers(),
        }
        url = (
            f"{self.fs.endpoint}/{REPO_TYPES_URL_PREFIXES.get(self.resolved_path.repo_type, '') + self.resolved_path.repo_id}/resolve/{safe_quote(self.resolved_path.revision)}/{safe_quote(self.resolved_path.path_in_repo)}"
        )
        r = http_backoff("GET", url, headers=headers)
        hf_raise_for_status(r)
        return r.content

    def _initiate_upload(self) -> None:
        self.temp_file = tempfile.NamedTemporaryFile(prefix="hffs-", delete=False)

    def _upload_chunk(self, final: bool = False) -> None:
        self.buffer.seek(0)
        block = self.buffer.read()
        self.temp_file.write(block)
        if final:
            self.temp_file.close()
            self.fs._api.upload_file(
                path_or_fileobj=self.temp_file.name,
                path_in_repo=self.resolved_path.path_in_repo,
                repo_id=self.resolved_path.repo_id,
                token=self.fs.token,
                repo_type=self.resolved_path.repo_type,
                revision=self.resolved_path.revision,
                commit_message=self.kwargs.get("commit_message"),
                commit_description=self.kwargs.get("commit_description"),
            )
            os.remove(self.temp_file.name)
            self.fs.invalidate_cache(
                path=self.resolved_path.unresolve(),
            )
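
# Illustrative write flow (hypothetical repo, mirroring the class docstring above): opening a path
# in "wb" mode buffers writes into a local temporary file and uploads it with `HfApi.upload_file`
# once the final chunk is flushed on close:
#     with HfFileSystem().open("my-username/my-model/pytorch_model.bin", "wb") as f:
#         f.write(data)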


def safe_revision(revision: str) -> str:
    return revision if SPECIAL_REFS_REVISION_REGEX.match(revision) else safe_quote(revision)


def safe_quote(s: str) -> str:
    return quote(s, safe="")
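

# Illustrative behavior (assuming `urllib.parse.quote` with safe=""):
#     safe_quote("refs/pr/1")     -> "refs%2Fpr%2F1"
#     safe_revision("refs/pr/1")  -> "refs/pr/1"   (special refs are kept readable in paths)
#     safe_revision("my branch")  -> "my%20branch"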