import os
import requests

from typing import Dict, Optional, List

from huggingface_hub.utils import build_hf_headers

from text_generation import Client, AsyncClient, __version__
from text_generation.types import DeployedModel
from text_generation.errors import NotSupportedError, parse_error

INFERENCE_ENDPOINT = os.environ.get(
    "HF_INFERENCE_ENDPOINT", "https://api-inference.huggingface.co"
)
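
# Since INFERENCE_ENDPOINT is read once at import time, the base endpoint can
# be redirected via the environment before this module is imported, e.g. to
# target a private deployment (sketch; the URL below is illustrative):
#
#     HF_INFERENCE_ENDPOINT=https://my-tgi.example.com python app.py
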
def deployed_models(headers: Optional[Dict] = None) -> List[DeployedModel]:
    """
    Get all currently deployed models with text-generation-inference support

    Returns:
        List[DeployedModel]: list of all currently deployed models
    """
    resp = requests.get(
        "https://api-inference.huggingface.co/framework/text-generation-inference",
        headers=headers,
        timeout=5,
    )

    payload = resp.json()
    if resp.status_code != 200:
        raise parse_error(resp.status_code, payload)

    models = [DeployedModel(**raw_deployed_model) for raw_deployed_model in payload]
    return models
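
# Usage sketch (assumes network access to the public Inference API; the
# fields printed come from text_generation.types.DeployedModel):
#
#     for model in deployed_models():
#         print(model.model_id, model.sha)
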
def check_model_support(repo_id: str, headers: Optional[Dict] = None) -> bool:
    """
    Check if a given model is supported by text-generation-inference

    Returns:
        bool: whether the model is supported by this client
    """
    resp = requests.get(
        f"https://api-inference.huggingface.co/status/{repo_id}",
        headers=headers,
        timeout=5,
    )

    payload = resp.json()
    if resp.status_code != 200:
        raise parse_error(resp.status_code, payload)

    framework = payload["framework"]
    supported = framework == "text-generation-inference"
    return supported
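
# Usage sketch (model id is illustrative; note that a non-200 status response,
# e.g. for an unknown repo, raises via parse_error rather than returning False):
#
#     if check_model_support("bigscience/bloomz"):
#         client = InferenceAPIClient("bigscience/bloomz")
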
class InferenceAPIClient(Client):
    """Client to make calls to the HuggingFace Inference API.

    Only supports a subset of the available text-generation or text2text-generation
    models that are served using text-generation-inference

    Example:

    ```python
    >>> from text_generation import InferenceAPIClient

    >>> client = InferenceAPIClient("bigscience/bloomz")
    >>> client.generate("Why is the sky blue?").generated_text
    ' Rayleigh scattering'

    >>> result = ""
    >>> for response in client.generate_stream("Why is the sky blue?"):
    ...     if not response.token.special:
    ...         result += response.token.text
    >>> result
    ' Rayleigh scattering'
    ```
    """
    def __init__(self, repo_id: str, token: Optional[str] = None, timeout: int = 10):
        """
        Init headers and API information

        Args:
            repo_id (`str`):
                Id of repository (e.g. `bigscience/bloom`).
            token (`str`, `optional`):
                The Hugging Face API token to use as HTTP bearer authorization.
                You can find your token at https://huggingface.co/settings/tokens.
                Alternatively, you can list both your organization and personal
                API tokens using `HfApi().whoami(token)`.
            timeout (`int`):
                Timeout in seconds
        """
        headers = build_hf_headers(
            token=token, library_name="text-generation", library_version=__version__
        )

        # Text Generation Inference client only supports a subset of the available hub models
        if not check_model_support(repo_id, headers):
            raise NotSupportedError(repo_id)

        base_url = f"{INFERENCE_ENDPOINT}/models/{repo_id}"

        super(InferenceAPIClient, self).__init__(
            base_url, headers=headers, timeout=timeout
        )
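
# Sketch: parameterized generation with the synchronous client. The keyword
# arguments follow text_generation.Client.generate; the token value is a
# placeholder:
#
#     client = InferenceAPIClient("bigscience/bloomz", token="hf_xxx")
#     response = client.generate(
#         "Why is the sky blue?",
#         max_new_tokens=20,
#         temperature=0.7,
#     )
#     print(response.generated_text)
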
class InferenceAPIAsyncClient(AsyncClient):
    """Asynchronous Client to make calls to the HuggingFace Inference API.

    Only supports a subset of the available text-generation or text2text-generation
    models that are served using text-generation-inference

    Example:

    ```python
    >>> from text_generation import InferenceAPIAsyncClient

    >>> client = InferenceAPIAsyncClient("bigscience/bloomz")
    >>> response = await client.generate("Why is the sky blue?")
    >>> response.generated_text
    ' Rayleigh scattering'

    >>> result = ""
    >>> async for response in client.generate_stream("Why is the sky blue?"):
    ...     if not response.token.special:
    ...         result += response.token.text
    >>> result
    ' Rayleigh scattering'
    ```
    """
    def __init__(self, repo_id: str, token: Optional[str] = None, timeout: int = 10):
        """
        Init headers and API information

        Args:
            repo_id (`str`):
                Id of repository (e.g. `bigscience/bloom`).
            token (`str`, `optional`):
                The Hugging Face API token to use as HTTP bearer authorization.
                You can find your token at https://huggingface.co/settings/tokens.
                Alternatively, you can list both your organization and personal
                API tokens using `HfApi().whoami(token)`.
            timeout (`int`):
                Timeout in seconds
        """
        headers = build_hf_headers(
            token=token, library_name="text-generation", library_version=__version__
        )

        # Text Generation Inference client only supports a subset of the available hub models
        if not check_model_support(repo_id, headers):
            raise NotSupportedError(repo_id)

        base_url = f"{INFERENCE_ENDPOINT}/models/{repo_id}"

        super(InferenceAPIAsyncClient, self).__init__(
            base_url, headers=headers, timeout=timeout
        )
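
# Sketch: driving the async client from a script with asyncio (standard
# library, Python 3.7+); model id and parameters are illustrative:
#
#     import asyncio
#
#     async def main():
#         client = InferenceAPIAsyncClient("bigscience/bloomz")
#         response = await client.generate("Why is the sky blue?", max_new_tokens=20)
#         print(response.generated_text)
#
#     asyncio.run(main())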