# coding=utf-8
# Copyright 2022-present, the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains utilities to handle HTTP requests in Hugging Face Hub."""

import io
import os
import re
import threading
import time
import uuid
from functools import lru_cache
from http import HTTPStatus
from typing import Callable, Optional, Tuple, Type, Union

import requests
from requests import HTTPError, Response
from requests.adapters import HTTPAdapter
from requests.models import PreparedRequest

from huggingface_hub.errors import OfflineModeIsEnabled

from .. import constants
from ..errors import (
    BadRequestError,
    DisabledRepoError,
    EntryNotFoundError,
    GatedRepoError,
    HfHubHTTPError,
    RepositoryNotFoundError,
    RevisionNotFoundError,
)
from . import logging
from ._fixes import JSONDecodeError
from ._lfs import SliceFileObj
from ._typing import HTTP_METHOD_T


logger = logging.get_logger(__name__)

# Both headers are used by the Hub to debug failed requests.
# `X_AMZN_TRACE_ID` is better as it also works to debug on CloudFront and ALB.
# If `X_AMZN_TRACE_ID` is set, the Hub will use it as well.
X_AMZN_TRACE_ID = "X-Amzn-Trace-Id"
X_REQUEST_ID = "x-request-id"

REPO_API_REGEX = re.compile(
    r"""
        # staging or production endpoint
        ^https://[^/]+
        (
            # on /api/repo_type/repo_id
            /api/(models|datasets|spaces)/(.+)
            |
            # or /repo_id/resolve/revision/...
            /(.+)/resolve/(.+)
        )
    """,
    flags=re.VERBOSE,
)
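# Illustrative (non-exhaustive) examples of URLs the pattern above is meant to match, assuming the
# production endpoint:
#   https://huggingface.co/api/models/gpt2
#   https://huggingface.co/gpt2/resolve/main/config.json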


class UniqueRequestIdAdapter(HTTPAdapter):
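    """`HTTPAdapter` that attaches a unique `X-Amzn-Trace-Id` header to every outgoing request and logs it.

    The id reuses an existing `x-request-id` header when present, otherwise a random UUID is generated. Any
    `requests.RequestException` raised while sending is re-raised with the request id appended to its message,
    which makes failed requests easier to correlate with server-side logs.
    """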
    X_AMZN_TRACE_ID = "X-Amzn-Trace-Id"

    def add_headers(self, request, **kwargs):
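        """Attach a unique request id header (reusing `x-request-id` if already set) and log the outgoing request."""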
        super().add_headers(request, **kwargs)

        # Add random request ID => easier for server-side debug
        if X_AMZN_TRACE_ID not in request.headers:
            request.headers[X_AMZN_TRACE_ID] = request.headers.get(X_REQUEST_ID) or str(uuid.uuid4())

        # Add debug log
        has_token = str(request.headers.get("authorization", "")).startswith("Bearer hf_")
        logger.debug(
            f"Request {request.headers[X_AMZN_TRACE_ID]}: {request.method} {request.url} (authenticated: {has_token})"
        )

    def send(self, request: PreparedRequest, *args, **kwargs) -> Response:
        """Catch any RequestException to append request id to the error message for debugging."""
        try:
            return super().send(request, *args, **kwargs)
        except requests.RequestException as e:
            request_id = request.headers.get(X_AMZN_TRACE_ID)
            if request_id is not None:
                # Taken from https://stackoverflow.com/a/58270258
                e.args = (*e.args, f"(Request ID: {request_id})")
            raise


class OfflineAdapter(HTTPAdapter):
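    """`HTTPAdapter` that blocks all outgoing requests when offline mode (`HF_HUB_OFFLINE`) is enabled."""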

    def send(self, request: PreparedRequest, *args, **kwargs) -> Response:
        raise OfflineModeIsEnabled(
            f"Cannot reach {request.url}: offline mode is enabled. To disable it, please unset the `HF_HUB_OFFLINE` environment variable."
        )


def _default_backend_factory() -> requests.Session:
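    """Default session factory: mount `OfflineAdapter`s when `HF_HUB_OFFLINE` is set, `UniqueRequestIdAdapter`s otherwise."""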
    session = requests.Session()
    if constants.HF_HUB_OFFLINE:
        session.mount("http://", OfflineAdapter())
        session.mount("https://", OfflineAdapter())
    else:
        session.mount("http://", UniqueRequestIdAdapter())
        session.mount("https://", UniqueRequestIdAdapter())
    return session


BACKEND_FACTORY_T = Callable[[], requests.Session]
_GLOBAL_BACKEND_FACTORY: BACKEND_FACTORY_T = _default_backend_factory


def configure_http_backend(backend_factory: BACKEND_FACTORY_T = _default_backend_factory) -> None:
    """
    Configure the HTTP backend by providing a `backend_factory`. Any HTTP calls made by `huggingface_hub` will use a
    Session object instantiated by this factory. This can be useful if you are running your scripts in a specific
    environment requiring custom configuration (e.g. custom proxy or certificates).

    Use [`get_session`] to get a configured Session. Since `requests.Session` is not guaranteed to be thread-safe,
    `huggingface_hub` creates one Session instance per thread. They are all instantiated using the same
    `backend_factory` set in [`configure_http_backend`]. An LRU cache is used to cache the created sessions (and
    connections) between calls. Max size is 128 to avoid memory leaks if thousands of threads are spawned.

    See [this issue](https://github.com/psf/requests/issues/2766) to know more about thread-safety in `requests`.

    Example:
    ```py
    import requests
    from huggingface_hub import configure_http_backend, get_session

    # Create a factory function that returns a Session with configured proxies
    def backend_factory() -> requests.Session:
        session = requests.Session()
        session.proxies = {"http": "http://10.10.1.10:3128", "https": "https://10.10.1.11:1080"}
        return session

    # Set it as the default session factory
    configure_http_backend(backend_factory=backend_factory)

    # In practice, this is mostly done internally in `huggingface_hub`
    session = get_session()
    ```
    """
    global _GLOBAL_BACKEND_FACTORY
    _GLOBAL_BACKEND_FACTORY = backend_factory
    reset_sessions()


def get_session() -> requests.Session:
    """
    Get a `requests.Session` object, using the session factory from the user.

    Use [`get_session`] to get a configured Session. Since `requests.Session` is not guaranteed to be thread-safe,
    `huggingface_hub` creates one Session instance per thread. They are all instantiated using the same
    `backend_factory` set in [`configure_http_backend`]. An LRU cache is used to cache the created sessions (and
    connections) between calls. Max size is 128 to avoid memory leaks if thousands of threads are spawned.

    See [this issue](https://github.com/psf/requests/issues/2766) to know more about thread-safety in `requests`.

    Example:
    ```py
    import requests
    from huggingface_hub import configure_http_backend, get_session

    # Create a factory function that returns a Session with configured proxies
    def backend_factory() -> requests.Session:
        session = requests.Session()
        session.proxies = {"http": "http://10.10.1.10:3128", "https": "https://10.10.1.11:1080"}
        return session

    # Set it as the default session factory
    configure_http_backend(backend_factory=backend_factory)

    # In practice, this is mostly done internally in `huggingface_hub`
    session = get_session()
    ```
    """
    return _get_session_from_cache(process_id=os.getpid(), thread_id=threading.get_ident())


def reset_sessions() -> None:
    """Reset the cache of sessions.

    Mostly used internally when sessions are reconfigured or an SSLError is raised.
    See [`configure_http_backend`] for more details.
    """
    _get_session_from_cache.cache_clear()


@lru_cache
def _get_session_from_cache(process_id: int, thread_id: int) -> requests.Session:
    """
    Create a new session per thread using global factory. Using LRU cache (maxsize 128) to avoid memory leaks when
    using thousands of threads. Cache is cleared when `configure_http_backend` is called.
    """
    return _GLOBAL_BACKEND_FACTORY()


def http_backoff(
    method: HTTP_METHOD_T,
    url: str,
    *,
    max_retries: int = 5,
    base_wait_time: float = 1,
    max_wait_time: float = 8,
    retry_on_exceptions: Union[Type[Exception], Tuple[Type[Exception], ...]] = (
        requests.Timeout,
        requests.ConnectionError,
    ),
    retry_on_status_codes: Union[int, Tuple[int, ...]] = HTTPStatus.SERVICE_UNAVAILABLE,
    **kwargs,
) -> Response:
    """Wrapper around requests to retry calls on an endpoint, with exponential backoff.

    Endpoint call is retried on exceptions (ex: connection timeout, proxy error,...)
    and/or on specific status codes (ex: service unavailable). If the call failed more
    than `max_retries`, the exception is thrown or `raise_for_status` is called on the
    response object.

    Re-implements mechanisms from the `backoff` library to avoid adding an external
    dependency to `huggingface_hub`. See https://github.com/litl/backoff.

    Args:
        method (`Literal["GET", "OPTIONS", "HEAD", "POST", "PUT", "PATCH", "DELETE"]`):
            HTTP method to perform.
        url (`str`):
            The URL of the resource to fetch.
        max_retries (`int`, *optional*, defaults to `5`):
            Maximum number of retries.
        base_wait_time (`float`, *optional*, defaults to `1`):
            Duration (in seconds) to wait before retrying the first time.
            Wait time between retries then grows exponentially, capped by
            `max_wait_time`.
        max_wait_time (`float`, *optional*, defaults to `8`):
            Maximum duration (in seconds) to wait before retrying.
        retry_on_exceptions (`Type[Exception]` or `Tuple[Type[Exception]]`, *optional*):
            Define which exceptions must be caught to retry the request. Can be a single type or a tuple of types.
            By default, retry on `requests.Timeout` and `requests.ConnectionError`.
        retry_on_status_codes (`int` or `Tuple[int]`, *optional*, defaults to `503`):
            Define on which status codes the request must be retried. By default, only
            HTTP 503 Service Unavailable is retried.
        **kwargs (`dict`, *optional*):
            kwargs to pass to `requests.request`.

    Example:
    ```
    >>> from huggingface_hub.utils import http_backoff

    # Same usage as "requests.request".
    >>> response = http_backoff("GET", "https://www.google.com")
    >>> response.raise_for_status()

    # If you expect a Gateway Timeout from time to time
    >>> response = http_backoff("PUT", upload_url, data=data, retry_on_status_codes=504)
    >>> response.raise_for_status()
    ```

    <Tip warning={true}>

    When using `requests` it is possible to stream data by passing an iterator to the
    `data` argument. On http backoff this is a problem as the iterator is not reset
    after a failed call. This issue is mitigated for file objects or any IO streams
    by saving the initial position of the cursor (with `data.tell()`) and resetting the
    cursor between each call (with `data.seek()`). For arbitrary iterators, http backoff
    will fail. If this is a hard constraint for you, please let us know by opening an
    issue on [Github](https://github.com/huggingface/huggingface_hub).

    </Tip>
    """
    if isinstance(retry_on_exceptions, type):  # Tuple from single exception type
        retry_on_exceptions = (retry_on_exceptions,)

    if isinstance(retry_on_status_codes, int):  # Tuple from single status code
        retry_on_status_codes = (retry_on_status_codes,)

    nb_tries = 0
    sleep_time = base_wait_time

    # If `data` is used and is a file object (or any IO), it will be consumed on the
    # first HTTP request. We need to save the initial position so that the full content
    # of the file is re-sent on http backoff. See warning tip in docstring.
    io_obj_initial_pos = None
    if "data" in kwargs and isinstance(kwargs["data"], (io.IOBase, SliceFileObj)):
        io_obj_initial_pos = kwargs["data"].tell()

    session = get_session()
    while True:
        nb_tries += 1
        try:
            # If `data` is used and is a file object (or any IO), set back cursor to
            # initial position.
            if io_obj_initial_pos is not None:
                kwargs["data"].seek(io_obj_initial_pos)

            # Perform request and return if status_code is not in the retry list.
            response = session.request(method=method, url=url, **kwargs)
            if response.status_code not in retry_on_status_codes:
                return response

            # Wrong status code returned (HTTP 503 for instance)
            logger.warning(f"HTTP Error {response.status_code} thrown while requesting {method} {url}")
            if nb_tries > max_retries:
                response.raise_for_status()  # Will raise uncaught exception
                # We return response to avoid infinite loop in the corner case where the
                # user asks for a retry on a status code that doesn't raise_for_status.
                return response

        except retry_on_exceptions as err:
            logger.warning(f"'{err}' thrown while requesting {method} {url}")

            if isinstance(err, requests.ConnectionError):
                reset_sessions()  # In case of SSLError it's best to reset the shared requests.Session objects

            if nb_tries > max_retries:
                raise err

        # Sleep for X seconds
        logger.warning(f"Retrying in {sleep_time}s [Retry {nb_tries}/{max_retries}].")
        time.sleep(sleep_time)

        # Update sleep time for next retry
        sleep_time = min(max_wait_time, sleep_time * 2)  # Exponential backoff


def fix_hf_endpoint_in_url(url: str, endpoint: Optional[str]) -> str:
    """Replace the default endpoint in a URL with a custom one.

    This is useful when using a proxy and the Hugging Face Hub returns a URL with the default endpoint.
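
    Example (illustrative; `https://my-proxy.example` is a placeholder custom endpoint):
    ```py
    >>> fix_hf_endpoint_in_url("https://huggingface.co/api/models/gpt2", endpoint="https://my-proxy.example")
    'https://my-proxy.example/api/models/gpt2'
    ```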
    """
    endpoint = endpoint or constants.ENDPOINT

    # check if a proxy has been set => if yes, update the returned URL to use the proxy
    if endpoint not in (None, constants._HF_DEFAULT_ENDPOINT, constants._HF_DEFAULT_STAGING_ENDPOINT):
        url = url.replace(constants._HF_DEFAULT_ENDPOINT, endpoint)
        url = url.replace(constants._HF_DEFAULT_STAGING_ENDPOINT, endpoint)
    return url


def hf_raise_for_status(response: Response, endpoint_name: Optional[str] = None) -> None:
    """
    Internal version of `response.raise_for_status()` that will refine a
    potential HTTPError. Raised exception will be an instance of `HfHubHTTPError`.

    This helper is meant to be the unique method to raise_for_status when making a call
    to the Hugging Face Hub.

    Example:
    ```py
    import requests
    from huggingface_hub.utils import get_session, hf_raise_for_status, HfHubHTTPError

    response = get_session().post(...)
    try:
        hf_raise_for_status(response)
    except HfHubHTTPError as e:
        print(str(e)) # formatted message
        e.request_id, e.server_message # details returned by server

        # Complete the error message with additional information once it's raised
        e.append_to_message("\n`create_commit` expects the repository to exist.")
        raise
    ```

    Args:
        response (`Response`):
            Response from the server.
        endpoint_name (`str`, *optional*):
            Name of the endpoint that has been called. If provided, the error message
            will be more complete.

    <Tip warning={true}>

    Raises when the request has failed:

        - [`~utils.RepositoryNotFoundError`]
            If the repository to download from cannot be found. This may be because it
            doesn't exist, because `repo_type` is not set correctly, or because the repo
            is `private` and you do not have access.
        - [`~utils.GatedRepoError`]
            If the repository exists but is gated and the user is not on the authorized
            list.
        - [`~utils.RevisionNotFoundError`]
            If the repository exists but the revision couldn't be found.
        - [`~utils.EntryNotFoundError`]
            If the repository exists but the entry (e.g. the requested file) couldn't be
            found.
        - [`~utils.BadRequestError`]
            If request failed with an HTTP 400 BadRequest error.
        - [`~utils.HfHubHTTPError`]
            If request failed for a reason not listed above.

    </Tip>
    """
    try:
        response.raise_for_status()
    except HTTPError as e:
        error_code = response.headers.get("X-Error-Code")
        error_message = response.headers.get("X-Error-Message")

        if error_code == "RevisionNotFound":
            message = f"{response.status_code} Client Error." + "\n\n" + f"Revision Not Found for url: {response.url}."
            raise _format(RevisionNotFoundError, message, response) from e

        elif error_code == "EntryNotFound":
            message = f"{response.status_code} Client Error." + "\n\n" + f"Entry Not Found for url: {response.url}."
            raise _format(EntryNotFoundError, message, response) from e

        elif error_code == "GatedRepo":
            message = (
                f"{response.status_code} Client Error." + "\n\n" + f"Cannot access gated repo for url {response.url}."
            )
            raise _format(GatedRepoError, message, response) from e

        elif error_message == "Access to this resource is disabled.":
            message = (
                f"{response.status_code} Client Error."
                + "\n\n"
                + f"Cannot access repository for url {response.url}."
                + "\n"
                + "Access to this resource is disabled."
            )
            raise _format(DisabledRepoError, message, response) from e

        elif error_code == "RepoNotFound" or (
            response.status_code == 401
            and response.request is not None
            and response.request.url is not None
            and REPO_API_REGEX.search(response.request.url) is not None
        ):
            # 401 is misleading as it is returned for:
            #    - private and gated repos if user is not authenticated
            #    - missing repos
            # => for now, we process them as `RepoNotFound` anyway.
            # See https://gist.github.com/Wauplin/46c27ad266b15998ce56a6603796f0b9
            message = (
                f"{response.status_code} Client Error."
                + "\n\n"
                + f"Repository Not Found for url: {response.url}."
                + "\nPlease make sure you specified the correct `repo_id` and"
                " `repo_type`.\nIf you are trying to access a private or gated repo,"
                " make sure you are authenticated."
            )
            raise _format(RepositoryNotFoundError, message, response) from e

        elif response.status_code == 400:
            message = (
                f"\n\nBad request for {endpoint_name} endpoint:" if endpoint_name is not None else "\n\nBad request:"
            )
            raise _format(BadRequestError, message, response) from e

        elif response.status_code == 403:
            message = (
                f"\n\n{response.status_code} Forbidden: {error_message}."
                + f"\nCannot access content at: {response.url}."
                + "\nMake sure your token has the correct permissions."
            )
            raise _format(HfHubHTTPError, message, response) from e

        elif response.status_code == 416:
            range_header = response.request.headers.get("Range")
            message = f"{e}. Requested range: {range_header}. Content-Range: {response.headers.get('Content-Range')}."
            raise _format(HfHubHTTPError, message, response) from e

        # Convert `HTTPError` into a `HfHubHTTPError` to display request information
        # as well (request id and/or server error message)
        raise _format(HfHubHTTPError, str(e), response) from e


def _format(error_type: Type[HfHubHTTPError], custom_message: str, response: Response) -> HfHubHTTPError:
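    """Build (without raising) an `HfHubHTTPError` of type `error_type`.

    The `custom_message` is enriched with the server error message, if any (taken from the `X-Error-Message` header
    or from the JSON/text body), and with the request id from the `x-request-id` header.
    """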
    server_errors = []

    # Retrieve server error from header
    from_headers = response.headers.get("X-Error-Message")
    if from_headers is not None:
        server_errors.append(from_headers)

    # Retrieve server error from body
    try:
        # Case errors are returned in a JSON format
        data = response.json()

        error = data.get("error")
        if error is not None:
            if isinstance(error, list):
                # Case {'error': ['my error 1', 'my error 2']}
                server_errors.extend(error)
            else:
                # Case {'error': 'my error'}
                server_errors.append(error)

        errors = data.get("errors")
        if errors is not None:
            # Case {'errors': [{'message': 'my error 1'}, {'message': 'my error 2'}]}
            for error in errors:
                if "message" in error:
                    server_errors.append(error["message"])

    except JSONDecodeError:
        # Case error is directly returned as text
        if response.text:
            server_errors.append(response.text)

    # Strip all server messages
    server_errors = [line.strip() for line in server_errors if line.strip()]

    # Deduplicate server messages (keep order)
    # taken from https://stackoverflow.com/a/17016257
    server_errors = list(dict.fromkeys(server_errors))

    # Format server error
    server_message = "\n".join(server_errors)

    # Add server error to custom message
    final_error_message = custom_message
    if server_message and server_message.lower() not in custom_message.lower():
        if "\n\n" in custom_message:
            final_error_message += "\n" + server_message
        else:
            final_error_message += "\n\n" + server_message

    # Add Request ID
    request_id = str(response.headers.get(X_REQUEST_ID, ""))
    if len(request_id) > 0 and request_id.lower() not in final_error_message.lower():
        request_id_message = f" (Request ID: {request_id})"
        if "\n" in final_error_message:
            newline_index = final_error_message.index("\n")
            final_error_message = (
                final_error_message[:newline_index] + request_id_message + final_error_message[newline_index:]
            )
        else:
            final_error_message += request_id_message

    # Return
    return error_type(final_error_message.strip(), response=response, server_message=server_message or None)