import json
import requests

from aiohttp import ClientSession, ClientTimeout
from pydantic import ValidationError
from typing import Dict, Optional, List, AsyncIterator, Iterator

from text_generation.types import (
    StreamResponse,
    Response,
    Request,
    Parameters,
)
from text_generation.errors import parse_error


class Client:
    """Client to make calls to a text-generation-inference instance

    Example:

    ```python
    >>> from text_generation import Client
    >>> client = Client("https://api-inference.huggingface.co/models/bigscience/bloomz")
    >>> client.generate("Why is the sky blue?").generated_text
    ' Rayleigh scattering'
    >>> result = ""
    >>> for response in client.generate_stream("Why is the sky blue?"):
    >>>     if not response.token.special:
    >>>         result += response.token.text
    >>> result
    ' Rayleigh scattering'
    ```
    """

    def __init__(
        self,
        base_url: str,
        headers: Optional[Dict[str, str]] = None,
        cookies: Optional[Dict[str, str]] = None,
        timeout: int = 10,
    ):
        """
        Args:
            base_url (`str`):
                text-generation-inference instance base url
            headers (`Optional[Dict[str, str]]`):
                Additional headers
            cookies (`Optional[Dict[str, str]]`):
                Cookies to include in the requests
            timeout (`int`):
                Timeout in seconds
        """
        self.base_url = base_url
        self.headers = headers
        self.cookies = cookies
        self.timeout = timeout

    def generate(
        self,
        prompt: str,
        do_sample: bool = False,
        max_new_tokens: int = 20,
        best_of: Optional[int] = None,
        repetition_penalty: Optional[float] = None,
        return_full_text: bool = False,
        seed: Optional[int] = None,
        stop_sequences: Optional[List[str]] = None,
        temperature: Optional[float] = None,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        truncate: Optional[int] = None,
        typical_p: Optional[float] = None,
        watermark: bool = False,
    ) -> Response:
| """ | |
| Given a prompt, generate the following text | |
| Args: | |
| prompt (`str`): | |
| Input text | |
| do_sample (`bool`): | |
| Activate logits sampling | |
| max_new_tokens (`int`): | |
| Maximum number of generated tokens | |
| best_of (`int`): | |
| Generate best_of sequences and return the one if the highest token logprobs | |
| repetition_penalty (`float`): | |
| The parameter for repetition penalty. 1.0 means no penalty. See [this | |
| paper](https://arxiv.org/pdf/1909.05858.pdf) for more details. | |
| return_full_text (`bool`): | |
| Whether to prepend the prompt to the generated text | |
| seed (`int`): | |
| Random sampling seed | |
| stop_sequences (`List[str]`): | |
| Stop generating tokens if a member of `stop_sequences` is generated | |
| temperature (`float`): | |
| The value used to module the logits distribution. | |
| top_k (`int`): | |
| The number of highest probability vocabulary tokens to keep for top-k-filtering. | |
| top_p (`float`): | |
| If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or | |
| higher are kept for generation. | |
| truncate (`int`): | |
| Truncate inputs tokens to the given size | |
| typical_p (`float`): | |
| Typical Decoding mass | |
| See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information | |
| watermark (`bool`): | |
| Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226) | |
| Returns: | |
| Response: generated response | |
| """ | |
        # Validate parameters
        parameters = Parameters(
            best_of=best_of,
            details=True,
            do_sample=do_sample,
            max_new_tokens=max_new_tokens,
            repetition_penalty=repetition_penalty,
            return_full_text=return_full_text,
            seed=seed,
            stop=stop_sequences if stop_sequences is not None else [],
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            truncate=truncate,
            typical_p=typical_p,
            watermark=watermark,
        )
        request = Request(inputs=prompt, stream=False, parameters=parameters)

        resp = requests.post(
            self.base_url,
            json=request.dict(),
            headers=self.headers,
            cookies=self.cookies,
            timeout=self.timeout,
        )
        payload = resp.json()
        if resp.status_code != 200:
            raise parse_error(resp.status_code, payload)
        return Response(**payload[0])

    def generate_stream(
        self,
        prompt: str,
        do_sample: bool = False,
        max_new_tokens: int = 20,
        repetition_penalty: Optional[float] = None,
        return_full_text: bool = False,
        seed: Optional[int] = None,
        stop_sequences: Optional[List[str]] = None,
        temperature: Optional[float] = None,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        truncate: Optional[int] = None,
        typical_p: Optional[float] = None,
        watermark: bool = False,
    ) -> Iterator[StreamResponse]:
| """ | |
| Given a prompt, generate the following stream of tokens | |
| Args: | |
| prompt (`str`): | |
| Input text | |
| do_sample (`bool`): | |
| Activate logits sampling | |
| max_new_tokens (`int`): | |
| Maximum number of generated tokens | |
| repetition_penalty (`float`): | |
| The parameter for repetition penalty. 1.0 means no penalty. See [this | |
| paper](https://arxiv.org/pdf/1909.05858.pdf) for more details. | |
| return_full_text (`bool`): | |
| Whether to prepend the prompt to the generated text | |
| seed (`int`): | |
| Random sampling seed | |
| stop_sequences (`List[str]`): | |
| Stop generating tokens if a member of `stop_sequences` is generated | |
| temperature (`float`): | |
| The value used to module the logits distribution. | |
| top_k (`int`): | |
| The number of highest probability vocabulary tokens to keep for top-k-filtering. | |
| top_p (`float`): | |
| If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or | |
| higher are kept for generation. | |
| truncate (`int`): | |
| Truncate inputs tokens to the given size | |
| typical_p (`float`): | |
| Typical Decoding mass | |
| See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information | |
| watermark (`bool`): | |
| Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226) | |
| Returns: | |
| Iterator[StreamResponse]: stream of generated tokens | |
| """ | |
        # Validate parameters
        parameters = Parameters(
            best_of=None,
            details=True,
            do_sample=do_sample,
            max_new_tokens=max_new_tokens,
            repetition_penalty=repetition_penalty,
            return_full_text=return_full_text,
            seed=seed,
            stop=stop_sequences if stop_sequences is not None else [],
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            truncate=truncate,
            typical_p=typical_p,
            watermark=watermark,
        )
        request = Request(inputs=prompt, stream=True, parameters=parameters)

        resp = requests.post(
            self.base_url,
            json=request.dict(),
            headers=self.headers,
            cookies=self.cookies,
            timeout=self.timeout,
            stream=True,
        )

        if resp.status_code != 200:
            raise parse_error(resp.status_code, resp.json())

        # Parse ServerSentEvents
        for byte_payload in resp.iter_lines():
            # Skip line
            if byte_payload == b"\n":
                continue

            payload = byte_payload.decode("utf-8")

            # Event data
            if payload.startswith("data:"):
                # Decode payload (strip the "data:" prefix and the trailing newline)
                json_payload = json.loads(payload.lstrip("data:").rstrip("\n"))
                # Parse payload
                try:
                    response = StreamResponse(**json_payload)
                except ValidationError:
                    # If we failed to parse the payload, then it is an error payload
                    raise parse_error(resp.status_code, json_payload)
                yield response
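
# Note on the stream format parsed above: the server is expected to reply with
# Server-Sent Events, one event per line, each prefixed with "data:" and carrying a
# JSON body that validates against `StreamResponse`. A sketch of such a line, with
# made-up values and only the fields used in the docstring examples (the full schema
# lives in `text_generation.types`):
#
#   data:{"token": {"text": " Rayleigh", "special": false}, "generated_text": null}
#
# Any "data:" line whose body fails `StreamResponse` validation is treated as an error
# payload and re-raised through `parse_error`.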


class AsyncClient:
    """Asynchronous Client to make calls to a text-generation-inference instance

    Example:

    ```python
    >>> from text_generation import AsyncClient
    >>> client = AsyncClient("https://api-inference.huggingface.co/models/bigscience/bloomz")
    >>> response = await client.generate("Why is the sky blue?")
    >>> response.generated_text
    ' Rayleigh scattering'
    >>> result = ""
    >>> async for response in client.generate_stream("Why is the sky blue?"):
    >>>     if not response.token.special:
    >>>         result += response.token.text
    >>> result
    ' Rayleigh scattering'
    ```
    """

    def __init__(
        self,
        base_url: str,
        headers: Optional[Dict[str, str]] = None,
        cookies: Optional[Dict[str, str]] = None,
        timeout: int = 10,
    ):
        """
        Args:
            base_url (`str`):
                text-generation-inference instance base url
            headers (`Optional[Dict[str, str]]`):
                Additional headers
            cookies (`Optional[Dict[str, str]]`):
                Cookies to include in the requests
            timeout (`int`):
                Timeout in seconds
        """
        self.base_url = base_url
        self.headers = headers
        self.cookies = cookies
        self.timeout = ClientTimeout(timeout * 60)

    async def generate(
        self,
        prompt: str,
        do_sample: bool = False,
        max_new_tokens: int = 20,
        best_of: Optional[int] = None,
        repetition_penalty: Optional[float] = None,
        return_full_text: bool = False,
        seed: Optional[int] = None,
        stop_sequences: Optional[List[str]] = None,
        temperature: Optional[float] = None,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        truncate: Optional[int] = None,
        typical_p: Optional[float] = None,
        watermark: bool = False,
    ) -> Response:
| """ | |
| Given a prompt, generate the following text asynchronously | |
| Args: | |
| prompt (`str`): | |
| Input text | |
| do_sample (`bool`): | |
| Activate logits sampling | |
| max_new_tokens (`int`): | |
| Maximum number of generated tokens | |
| best_of (`int`): | |
| Generate best_of sequences and return the one if the highest token logprobs | |
| repetition_penalty (`float`): | |
| The parameter for repetition penalty. 1.0 means no penalty. See [this | |
| paper](https://arxiv.org/pdf/1909.05858.pdf) for more details. | |
| return_full_text (`bool`): | |
| Whether to prepend the prompt to the generated text | |
| seed (`int`): | |
| Random sampling seed | |
| stop_sequences (`List[str]`): | |
| Stop generating tokens if a member of `stop_sequences` is generated | |
| temperature (`float`): | |
| The value used to module the logits distribution. | |
| top_k (`int`): | |
| The number of highest probability vocabulary tokens to keep for top-k-filtering. | |
| top_p (`float`): | |
| If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or | |
| higher are kept for generation. | |
| truncate (`int`): | |
| Truncate inputs tokens to the given size | |
| typical_p (`float`): | |
| Typical Decoding mass | |
| See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information | |
| watermark (`bool`): | |
| Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226) | |
| Returns: | |
| Response: generated response | |
| """ | |
        # Validate parameters
        parameters = Parameters(
            best_of=best_of,
            details=True,
            do_sample=do_sample,
            max_new_tokens=max_new_tokens,
            repetition_penalty=repetition_penalty,
            return_full_text=return_full_text,
            seed=seed,
            stop=stop_sequences if stop_sequences is not None else [],
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            truncate=truncate,
            typical_p=typical_p,
            watermark=watermark,
        )
        request = Request(inputs=prompt, stream=False, parameters=parameters)

        async with ClientSession(
            headers=self.headers, cookies=self.cookies, timeout=self.timeout
        ) as session:
            async with session.post(self.base_url, json=request.dict()) as resp:
                payload = await resp.json()

                if resp.status != 200:
                    raise parse_error(resp.status, payload)
                return Response(**payload[0])

    async def generate_stream(
        self,
        prompt: str,
        do_sample: bool = False,
        max_new_tokens: int = 20,
        repetition_penalty: Optional[float] = None,
        return_full_text: bool = False,
        seed: Optional[int] = None,
        stop_sequences: Optional[List[str]] = None,
        temperature: Optional[float] = None,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        truncate: Optional[int] = None,
        typical_p: Optional[float] = None,
        watermark: bool = False,
    ) -> AsyncIterator[StreamResponse]:
| """ | |
| Given a prompt, generate the following stream of tokens asynchronously | |
| Args: | |
| prompt (`str`): | |
| Input text | |
| do_sample (`bool`): | |
| Activate logits sampling | |
| max_new_tokens (`int`): | |
| Maximum number of generated tokens | |
| repetition_penalty (`float`): | |
| The parameter for repetition penalty. 1.0 means no penalty. See [this | |
| paper](https://arxiv.org/pdf/1909.05858.pdf) for more details. | |
| return_full_text (`bool`): | |
| Whether to prepend the prompt to the generated text | |
| seed (`int`): | |
| Random sampling seed | |
| stop_sequences (`List[str]`): | |
| Stop generating tokens if a member of `stop_sequences` is generated | |
| temperature (`float`): | |
| The value used to module the logits distribution. | |
| top_k (`int`): | |
| The number of highest probability vocabulary tokens to keep for top-k-filtering. | |
| top_p (`float`): | |
| If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or | |
| higher are kept for generation. | |
| truncate (`int`): | |
| Truncate inputs tokens to the given size | |
| typical_p (`float`): | |
| Typical Decoding mass | |
| See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information | |
| watermark (`bool`): | |
| Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226) | |
| Returns: | |
| AsyncIterator[StreamResponse]: stream of generated tokens | |
| """ | |
        # Validate parameters
        parameters = Parameters(
            best_of=None,
            details=True,
            do_sample=do_sample,
            max_new_tokens=max_new_tokens,
            repetition_penalty=repetition_penalty,
            return_full_text=return_full_text,
            seed=seed,
            stop=stop_sequences if stop_sequences is not None else [],
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            truncate=truncate,
            typical_p=typical_p,
            watermark=watermark,
        )
        request = Request(inputs=prompt, stream=True, parameters=parameters)

        async with ClientSession(
            headers=self.headers, cookies=self.cookies, timeout=self.timeout
        ) as session:
            async with session.post(self.base_url, json=request.dict()) as resp:
                if resp.status != 200:
                    raise parse_error(resp.status, await resp.json())

                # Parse ServerSentEvents
                async for byte_payload in resp.content:
                    # Skip line
                    if byte_payload == b"\n":
                        continue

                    payload = byte_payload.decode("utf-8")

                    # Event data
                    if payload.startswith("data:"):
                        # Decode payload (strip the "data:" prefix and the trailing newline)
                        json_payload = json.loads(payload.lstrip("data:").rstrip("\n"))
                        # Parse payload
                        try:
                            response = StreamResponse(**json_payload)
                        except ValidationError:
                            # If we failed to parse the payload, then it is an error payload
                            raise parse_error(resp.status, json_payload)
                        yield response
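

# A minimal end-to-end sketch of driving `AsyncClient` with asyncio (kept as a comment
# so importing this module has no side effects). `ENDPOINT_URL` is a placeholder for a
# reachable text-generation-inference instance:
#
#   import asyncio
#
#   async def main() -> None:
#       client = AsyncClient(ENDPOINT_URL)
#
#       # One-shot generation
#       response = await client.generate("Why is the sky blue?", max_new_tokens=32)
#       print(response.generated_text)
#
#       # Token-by-token streaming
#       async for chunk in client.generate_stream("Why is the sky blue?"):
#           if not chunk.token.special:
#               print(chunk.token.text, end="", flush=True)
#
#   asyncio.run(main())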