|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Contains utilities to handle HTTP requests in Huggingface Hub.""" |
|
|
|
import io |
|
import os |
|
import re |
|
import threading |
|
import time |
|
import uuid |
|
from functools import lru_cache |
|
from http import HTTPStatus |
|
from shlex import quote |
|
from typing import Any, Callable, List, Optional, Tuple, Type, Union |
|
|
|
import requests |
|
from requests import HTTPError, Response |
|
from requests.adapters import HTTPAdapter |
|
from requests.models import PreparedRequest |
|
|
|
from huggingface_hub.errors import OfflineModeIsEnabled |
|
|
|
from .. import constants |
|
from ..errors import ( |
|
BadRequestError, |
|
DisabledRepoError, |
|
EntryNotFoundError, |
|
GatedRepoError, |
|
HfHubHTTPError, |
|
RepositoryNotFoundError, |
|
RevisionNotFoundError, |
|
) |
|
from . import logging |
|
from ._fixes import JSONDecodeError |
|
from ._lfs import SliceFileObj |
|
from ._typing import HTTP_METHOD_T |
|
|
|
|
|
logger = logging.get_logger(__name__) |
|
|
|
|
|
|
|
|
|
X_AMZN_TRACE_ID = "X-Amzn-Trace-Id" |
|
X_REQUEST_ID = "x-request-id" |
|
|
|
REPO_API_REGEX = re.compile( |
|
r""" |
|
# staging or production endpoint |
|
^https://[^/]+ |
|
( |
|
# on /api/repo_type/repo_id |
|
/api/(models|datasets|spaces)/(.+) |
|
| |
|
# or /repo_id/resolve/revision/... |
|
/(.+)/resolve/(.+) |
|
) |
|
""", |
|
flags=re.VERBOSE, |
|
) |
|
|
|
|
|
class UniqueRequestIdAdapter(HTTPAdapter): |
|
X_AMZN_TRACE_ID = "X-Amzn-Trace-Id" |
|
|
|
def add_headers(self, request, **kwargs): |
|
super().add_headers(request, **kwargs) |
|
|
|
|
|
if X_AMZN_TRACE_ID not in request.headers: |
|
request.headers[X_AMZN_TRACE_ID] = request.headers.get(X_REQUEST_ID) or str(uuid.uuid4()) |
|
|
|
|
|
has_token = len(str(request.headers.get("authorization", ""))) > 0 |
|
logger.debug( |
|
f"Request {request.headers[X_AMZN_TRACE_ID]}: {request.method} {request.url} (authenticated: {has_token})" |
|
) |
|
|
|
def send(self, request: PreparedRequest, *args, **kwargs) -> Response: |
|
"""Catch any RequestException to append request id to the error message for debugging.""" |
|
if constants.HF_DEBUG: |
|
logger.debug(f"Send: {_curlify(request)}") |
|
try: |
|
return super().send(request, *args, **kwargs) |
|
except requests.RequestException as e: |
|
request_id = request.headers.get(X_AMZN_TRACE_ID) |
|
if request_id is not None: |
|
|
|
e.args = (*e.args, f"(Request ID: {request_id})") |
|
raise |
|
|
|
|
|
class OfflineAdapter(HTTPAdapter): |
|
def send(self, request: PreparedRequest, *args, **kwargs) -> Response: |
|
raise OfflineModeIsEnabled( |
|
f"Cannot reach {request.url}: offline mode is enabled. To disable it, please unset the `HF_HUB_OFFLINE` environment variable." |
|
) |
|
|
|
|
|
def _default_backend_factory() -> requests.Session: |
|
session = requests.Session() |
|
if constants.HF_HUB_OFFLINE: |
|
session.mount("http://", OfflineAdapter()) |
|
session.mount("https://", OfflineAdapter()) |
|
else: |
|
session.mount("http://", UniqueRequestIdAdapter()) |
|
session.mount("https://", UniqueRequestIdAdapter()) |
|
return session |
|
|
|
|
|
BACKEND_FACTORY_T = Callable[[], requests.Session] |
|
_GLOBAL_BACKEND_FACTORY: BACKEND_FACTORY_T = _default_backend_factory |
|
|
|
|
|
def configure_http_backend(backend_factory: BACKEND_FACTORY_T = _default_backend_factory) -> None: |
|
""" |
|
Configure the HTTP backend by providing a `backend_factory`. Any HTTP calls made by `huggingface_hub` will use a |
|
Session object instantiated by this factory. This can be useful if you are running your scripts in a specific |
|
environment requiring custom configuration (e.g. custom proxy or certifications). |
|
|
|
Use [`get_session`] to get a configured Session. Since `requests.Session` is not guaranteed to be thread-safe, |
|
`huggingface_hub` creates 1 Session instance per thread. They are all instantiated using the same `backend_factory` |
|
set in [`configure_http_backend`]. A LRU cache is used to cache the created sessions (and connections) between |
|
calls. Max size is 128 to avoid memory leaks if thousands of threads are spawned. |
|
|
|
See [this issue](https://github.com/psf/requests/issues/2766) to know more about thread-safety in `requests`. |
|
|
|
Example: |
|
```py |
|
import requests |
|
from huggingface_hub import configure_http_backend, get_session |
|
|
|
# Create a factory function that returns a Session with configured proxies |
|
def backend_factory() -> requests.Session: |
|
session = requests.Session() |
|
session.proxies = {"http": "http://10.10.1.10:3128", "https": "https://10.10.1.11:1080"} |
|
return session |
|
|
|
# Set it as the default session factory |
|
configure_http_backend(backend_factory=backend_factory) |
|
|
|
# In practice, this is mostly done internally in `huggingface_hub` |
|
session = get_session() |
|
``` |
|
""" |
|
global _GLOBAL_BACKEND_FACTORY |
|
_GLOBAL_BACKEND_FACTORY = backend_factory |
|
reset_sessions() |
|
|
|
|
|
def get_session() -> requests.Session: |
|
""" |
|
Get a `requests.Session` object, using the session factory from the user. |
|
|
|
Use [`get_session`] to get a configured Session. Since `requests.Session` is not guaranteed to be thread-safe, |
|
`huggingface_hub` creates 1 Session instance per thread. They are all instantiated using the same `backend_factory` |
|
set in [`configure_http_backend`]. A LRU cache is used to cache the created sessions (and connections) between |
|
calls. Max size is 128 to avoid memory leaks if thousands of threads are spawned. |
|
|
|
See [this issue](https://github.com/psf/requests/issues/2766) to know more about thread-safety in `requests`. |
|
|
|
Example: |
|
```py |
|
import requests |
|
from huggingface_hub import configure_http_backend, get_session |
|
|
|
# Create a factory function that returns a Session with configured proxies |
|
def backend_factory() -> requests.Session: |
|
session = requests.Session() |
|
session.proxies = {"http": "http://10.10.1.10:3128", "https": "https://10.10.1.11:1080"} |
|
return session |
|
|
|
# Set it as the default session factory |
|
configure_http_backend(backend_factory=backend_factory) |
|
|
|
# In practice, this is mostly done internally in `huggingface_hub` |
|
session = get_session() |
|
``` |
|
""" |
|
return _get_session_from_cache(process_id=os.getpid(), thread_id=threading.get_ident()) |
|
|
|
|
|
def reset_sessions() -> None: |
|
"""Reset the cache of sessions. |
|
|
|
Mostly used internally when sessions are reconfigured or an SSLError is raised. |
|
See [`configure_http_backend`] for more details. |
|
""" |
|
_get_session_from_cache.cache_clear() |
|
|
|
|
|
@lru_cache |
|
def _get_session_from_cache(process_id: int, thread_id: int) -> requests.Session: |
|
""" |
|
Create a new session per thread using global factory. Using LRU cache (maxsize 128) to avoid memory leaks when |
|
using thousands of threads. Cache is cleared when `configure_http_backend` is called. |
|
""" |
|
return _GLOBAL_BACKEND_FACTORY() |
|
|
|
|
|
def http_backoff( |
|
method: HTTP_METHOD_T, |
|
url: str, |
|
*, |
|
max_retries: int = 5, |
|
base_wait_time: float = 1, |
|
max_wait_time: float = 8, |
|
retry_on_exceptions: Union[Type[Exception], Tuple[Type[Exception], ...]] = ( |
|
requests.Timeout, |
|
requests.ConnectionError, |
|
), |
|
retry_on_status_codes: Union[int, Tuple[int, ...]] = HTTPStatus.SERVICE_UNAVAILABLE, |
|
**kwargs, |
|
) -> Response: |
|
"""Wrapper around requests to retry calls on an endpoint, with exponential backoff. |
|
|
|
Endpoint call is retried on exceptions (ex: connection timeout, proxy error,...) |
|
and/or on specific status codes (ex: service unavailable). If the call failed more |
|
than `max_retries`, the exception is thrown or `raise_for_status` is called on the |
|
response object. |
|
|
|
Re-implement mechanisms from the `backoff` library to avoid adding an external |
|
dependencies to `hugging_face_hub`. See https://github.com/litl/backoff. |
|
|
|
Args: |
|
method (`Literal["GET", "OPTIONS", "HEAD", "POST", "PUT", "PATCH", "DELETE"]`): |
|
HTTP method to perform. |
|
url (`str`): |
|
The URL of the resource to fetch. |
|
max_retries (`int`, *optional*, defaults to `5`): |
|
Maximum number of retries, defaults to 5 (no retries). |
|
base_wait_time (`float`, *optional*, defaults to `1`): |
|
Duration (in seconds) to wait before retrying the first time. |
|
Wait time between retries then grows exponentially, capped by |
|
`max_wait_time`. |
|
max_wait_time (`float`, *optional*, defaults to `8`): |
|
Maximum duration (in seconds) to wait before retrying. |
|
retry_on_exceptions (`Type[Exception]` or `Tuple[Type[Exception]]`, *optional*): |
|
Define which exceptions must be caught to retry the request. Can be a single type or a tuple of types. |
|
By default, retry on `requests.Timeout` and `requests.ConnectionError`. |
|
retry_on_status_codes (`int` or `Tuple[int]`, *optional*, defaults to `503`): |
|
Define on which status codes the request must be retried. By default, only |
|
HTTP 503 Service Unavailable is retried. |
|
**kwargs (`dict`, *optional*): |
|
kwargs to pass to `requests.request`. |
|
|
|
Example: |
|
``` |
|
>>> from huggingface_hub.utils import http_backoff |
|
|
|
# Same usage as "requests.request". |
|
>>> response = http_backoff("GET", "https://www.google.com") |
|
>>> response.raise_for_status() |
|
|
|
# If you expect a Gateway Timeout from time to time |
|
>>> http_backoff("PUT", upload_url, data=data, retry_on_status_codes=504) |
|
>>> response.raise_for_status() |
|
``` |
|
|
|
<Tip warning={true}> |
|
|
|
When using `requests` it is possible to stream data by passing an iterator to the |
|
`data` argument. On http backoff this is a problem as the iterator is not reset |
|
after a failed call. This issue is mitigated for file objects or any IO streams |
|
by saving the initial position of the cursor (with `data.tell()`) and resetting the |
|
cursor between each call (with `data.seek()`). For arbitrary iterators, http backoff |
|
will fail. If this is a hard constraint for you, please let us know by opening an |
|
issue on [Github](https://github.com/huggingface/huggingface_hub). |
|
|
|
</Tip> |
|
""" |
|
if isinstance(retry_on_exceptions, type): |
|
retry_on_exceptions = (retry_on_exceptions,) |
|
|
|
if isinstance(retry_on_status_codes, int): |
|
retry_on_status_codes = (retry_on_status_codes,) |
|
|
|
nb_tries = 0 |
|
sleep_time = base_wait_time |
|
|
|
|
|
|
|
|
|
io_obj_initial_pos = None |
|
if "data" in kwargs and isinstance(kwargs["data"], (io.IOBase, SliceFileObj)): |
|
io_obj_initial_pos = kwargs["data"].tell() |
|
|
|
session = get_session() |
|
while True: |
|
nb_tries += 1 |
|
try: |
|
|
|
|
|
if io_obj_initial_pos is not None: |
|
kwargs["data"].seek(io_obj_initial_pos) |
|
|
|
|
|
response = session.request(method=method, url=url, **kwargs) |
|
if response.status_code not in retry_on_status_codes: |
|
return response |
|
|
|
|
|
logger.warning(f"HTTP Error {response.status_code} thrown while requesting {method} {url}") |
|
if nb_tries > max_retries: |
|
response.raise_for_status() |
|
|
|
|
|
return response |
|
|
|
except retry_on_exceptions as err: |
|
logger.warning(f"'{err}' thrown while requesting {method} {url}") |
|
|
|
if isinstance(err, requests.ConnectionError): |
|
reset_sessions() |
|
|
|
if nb_tries > max_retries: |
|
raise err |
|
|
|
|
|
logger.warning(f"Retrying in {sleep_time}s [Retry {nb_tries}/{max_retries}].") |
|
time.sleep(sleep_time) |
|
|
|
|
|
sleep_time = min(max_wait_time, sleep_time * 2) |
|
|
|
|
|
def fix_hf_endpoint_in_url(url: str, endpoint: Optional[str]) -> str: |
|
"""Replace the default endpoint in a URL by a custom one. |
|
|
|
This is useful when using a proxy and the Hugging Face Hub returns a URL with the default endpoint. |
|
""" |
|
endpoint = endpoint.rstrip("/") if endpoint else constants.ENDPOINT |
|
|
|
if endpoint not in (constants._HF_DEFAULT_ENDPOINT, constants._HF_DEFAULT_STAGING_ENDPOINT): |
|
url = url.replace(constants._HF_DEFAULT_ENDPOINT, endpoint) |
|
url = url.replace(constants._HF_DEFAULT_STAGING_ENDPOINT, endpoint) |
|
return url |
|
|
|
|
|
def hf_raise_for_status(response: Response, endpoint_name: Optional[str] = None) -> None: |
|
""" |
|
Internal version of `response.raise_for_status()` that will refine a |
|
potential HTTPError. Raised exception will be an instance of `HfHubHTTPError`. |
|
|
|
This helper is meant to be the unique method to raise_for_status when making a call |
|
to the Hugging Face Hub. |
|
|
|
|
|
Example: |
|
```py |
|
import requests |
|
from huggingface_hub.utils import get_session, hf_raise_for_status, HfHubHTTPError |
|
|
|
response = get_session().post(...) |
|
try: |
|
hf_raise_for_status(response) |
|
except HfHubHTTPError as e: |
|
print(str(e)) # formatted message |
|
e.request_id, e.server_message # details returned by server |
|
|
|
# Complete the error message with additional information once it's raised |
|
e.append_to_message("\n`create_commit` expects the repository to exist.") |
|
raise |
|
``` |
|
|
|
Args: |
|
response (`Response`): |
|
Response from the server. |
|
endpoint_name (`str`, *optional*): |
|
Name of the endpoint that has been called. If provided, the error message |
|
will be more complete. |
|
|
|
<Tip warning={true}> |
|
|
|
Raises when the request has failed: |
|
|
|
- [`~utils.RepositoryNotFoundError`] |
|
If the repository to download from cannot be found. This may be because it |
|
doesn't exist, because `repo_type` is not set correctly, or because the repo |
|
is `private` and you do not have access. |
|
- [`~utils.GatedRepoError`] |
|
If the repository exists but is gated and the user is not on the authorized |
|
list. |
|
- [`~utils.RevisionNotFoundError`] |
|
If the repository exists but the revision couldn't be find. |
|
- [`~utils.EntryNotFoundError`] |
|
If the repository exists but the entry (e.g. the requested file) couldn't be |
|
find. |
|
- [`~utils.BadRequestError`] |
|
If request failed with a HTTP 400 BadRequest error. |
|
- [`~utils.HfHubHTTPError`] |
|
If request failed for a reason not listed above. |
|
|
|
</Tip> |
|
""" |
|
try: |
|
response.raise_for_status() |
|
except HTTPError as e: |
|
error_code = response.headers.get("X-Error-Code") |
|
error_message = response.headers.get("X-Error-Message") |
|
|
|
if error_code == "RevisionNotFound": |
|
message = f"{response.status_code} Client Error." + "\n\n" + f"Revision Not Found for url: {response.url}." |
|
raise _format(RevisionNotFoundError, message, response) from e |
|
|
|
elif error_code == "EntryNotFound": |
|
message = f"{response.status_code} Client Error." + "\n\n" + f"Entry Not Found for url: {response.url}." |
|
raise _format(EntryNotFoundError, message, response) from e |
|
|
|
elif error_code == "GatedRepo": |
|
message = ( |
|
f"{response.status_code} Client Error." + "\n\n" + f"Cannot access gated repo for url {response.url}." |
|
) |
|
raise _format(GatedRepoError, message, response) from e |
|
|
|
elif error_message == "Access to this resource is disabled.": |
|
message = ( |
|
f"{response.status_code} Client Error." |
|
+ "\n\n" |
|
+ f"Cannot access repository for url {response.url}." |
|
+ "\n" |
|
+ "Access to this resource is disabled." |
|
) |
|
raise _format(DisabledRepoError, message, response) from e |
|
|
|
elif error_code == "RepoNotFound" or ( |
|
response.status_code == 401 |
|
and error_message != "Invalid credentials in Authorization header" |
|
and response.request is not None |
|
and response.request.url is not None |
|
and REPO_API_REGEX.search(response.request.url) is not None |
|
): |
|
|
|
|
|
|
|
|
|
|
|
message = ( |
|
f"{response.status_code} Client Error." |
|
+ "\n\n" |
|
+ f"Repository Not Found for url: {response.url}." |
|
+ "\nPlease make sure you specified the correct `repo_id` and" |
|
" `repo_type`.\nIf you are trying to access a private or gated repo," |
|
" make sure you are authenticated. For more details, see" |
|
" https://huggingface.co/docs/huggingface_hub/authentication" |
|
) |
|
raise _format(RepositoryNotFoundError, message, response) from e |
|
|
|
elif response.status_code == 400: |
|
message = ( |
|
f"\n\nBad request for {endpoint_name} endpoint:" if endpoint_name is not None else "\n\nBad request:" |
|
) |
|
raise _format(BadRequestError, message, response) from e |
|
|
|
elif response.status_code == 403: |
|
message = ( |
|
f"\n\n{response.status_code} Forbidden: {error_message}." |
|
+ f"\nCannot access content at: {response.url}." |
|
+ "\nMake sure your token has the correct permissions." |
|
) |
|
raise _format(HfHubHTTPError, message, response) from e |
|
|
|
elif response.status_code == 416: |
|
range_header = response.request.headers.get("Range") |
|
message = f"{e}. Requested range: {range_header}. Content-Range: {response.headers.get('Content-Range')}." |
|
raise _format(HfHubHTTPError, message, response) from e |
|
|
|
|
|
|
|
raise _format(HfHubHTTPError, str(e), response) from e |
|
|
|
|
|
def _format(error_type: Type[HfHubHTTPError], custom_message: str, response: Response) -> HfHubHTTPError: |
|
server_errors = [] |
|
|
|
|
|
from_headers = response.headers.get("X-Error-Message") |
|
if from_headers is not None: |
|
server_errors.append(from_headers) |
|
|
|
|
|
try: |
|
|
|
data = response.json() |
|
|
|
error = data.get("error") |
|
if error is not None: |
|
if isinstance(error, list): |
|
|
|
server_errors.extend(error) |
|
else: |
|
|
|
server_errors.append(error) |
|
|
|
errors = data.get("errors") |
|
if errors is not None: |
|
|
|
for error in errors: |
|
if "message" in error: |
|
server_errors.append(error["message"]) |
|
|
|
except JSONDecodeError: |
|
|
|
content_type = response.headers.get("Content-Type", "") |
|
if response.text and "html" not in content_type.lower(): |
|
server_errors.append(response.text) |
|
|
|
|
|
server_errors = [str(line).strip() for line in server_errors if str(line).strip()] |
|
|
|
|
|
|
|
server_errors = list(dict.fromkeys(server_errors)) |
|
|
|
|
|
server_message = "\n".join(server_errors) |
|
|
|
|
|
final_error_message = custom_message |
|
if server_message and server_message.lower() not in custom_message.lower(): |
|
if "\n\n" in custom_message: |
|
final_error_message += "\n" + server_message |
|
else: |
|
final_error_message += "\n\n" + server_message |
|
|
|
request_id = str(response.headers.get(X_REQUEST_ID, "")) |
|
if request_id: |
|
request_id_message = f" (Request ID: {request_id})" |
|
else: |
|
|
|
request_id = str(response.headers.get(X_AMZN_TRACE_ID, "")) |
|
if request_id: |
|
request_id_message = f" (Amzn Trace ID: {request_id})" |
|
if request_id and request_id.lower() not in final_error_message.lower(): |
|
if "\n" in final_error_message: |
|
newline_index = final_error_message.index("\n") |
|
final_error_message = ( |
|
final_error_message[:newline_index] + request_id_message + final_error_message[newline_index:] |
|
) |
|
else: |
|
final_error_message += request_id_message |
|
|
|
|
|
return error_type(final_error_message.strip(), response=response, server_message=server_message or None) |
|
|
|
|
|
def _curlify(request: requests.PreparedRequest) -> str: |
|
"""Convert a `requests.PreparedRequest` into a curl command (str). |
|
|
|
Used for debug purposes only. |
|
|
|
Implementation vendored from https://github.com/ofw/curlify/blob/master/curlify.py. |
|
MIT License Copyright (c) 2016 Egor. |
|
""" |
|
parts: List[Tuple[Any, Any]] = [ |
|
("curl", None), |
|
("-X", request.method), |
|
] |
|
|
|
for k, v in sorted(request.headers.items()): |
|
if k.lower() == "authorization": |
|
v = "<TOKEN>" |
|
parts += [("-H", "{0}: {1}".format(k, v))] |
|
|
|
if request.body: |
|
body = request.body |
|
if isinstance(body, bytes): |
|
body = body.decode("utf-8", errors="ignore") |
|
elif hasattr(body, "read"): |
|
body = "<file-like object>" |
|
if len(body) > 1000: |
|
body = body[:1000] + " ... [truncated]" |
|
parts += [("-d", body.replace("\n", ""))] |
|
|
|
parts += [(None, request.url)] |
|
|
|
flat_parts = [] |
|
for k, v in parts: |
|
if k: |
|
flat_parts.append(quote(k)) |
|
if v: |
|
flat_parts.append(quote(v)) |
|
|
|
return " ".join(flat_parts) |
|
|
|
|
|
|
|
RANGE_REGEX = re.compile(r"^\s*bytes\s*=\s*(\d*)\s*-\s*(\d*)\s*$", re.IGNORECASE) |
|
|
|
|
|
def _adjust_range_header(original_range: Optional[str], resume_size: int) -> Optional[str]: |
|
""" |
|
Adjust HTTP Range header to account for resume position. |
|
""" |
|
if not original_range: |
|
return f"bytes={resume_size}-" |
|
|
|
if "," in original_range: |
|
raise ValueError(f"Multiple ranges detected - {original_range!r}, not supported yet.") |
|
|
|
match = RANGE_REGEX.match(original_range) |
|
if not match: |
|
raise RuntimeError(f"Invalid range format - {original_range!r}.") |
|
start, end = match.groups() |
|
|
|
if not start: |
|
if not end: |
|
raise RuntimeError(f"Invalid range format - {original_range!r}.") |
|
|
|
new_suffix = int(end) - resume_size |
|
new_range = f"bytes=-{new_suffix}" |
|
if new_suffix <= 0: |
|
raise RuntimeError(f"Empty new range - {new_range!r}.") |
|
return new_range |
|
|
|
start = int(start) |
|
new_start = start + resume_size |
|
if end: |
|
end = int(end) |
|
new_range = f"bytes={new_start}-{end}" |
|
if new_start > end: |
|
raise RuntimeError(f"Empty new range - {new_range!r}.") |
|
return new_range |
|
|
|
return f"bytes={new_start}-" |
|
|