import json
from typing import Optional, List, AsyncIterator
from aiohttp import ClientSession
from openai.types.chat import ChatCompletionMessageParam
from pydantic import ValidationError
from text_generation import AsyncClient
from text_generation.errors import parse_error
from text_generation.types import Request, Parameters
from text_generation.types import Response, StreamResponse
from api.adapter import get_prompt_adapter
from api.utils.compat import model_dump
class TGIEngine:
def __init__(
self,
model: AsyncClient,
model_name: str,
prompt_name: Optional[str] = None,
):
"""
Initializes the TGIEngine object.
Args:
            model: The AsyncClient connected to the text-generation-inference server.
model_name: The name of the model.
prompt_name: The name of the prompt (optional).
"""
self.model = model
self.model_name = model_name.lower()
self.prompt_name = prompt_name.lower() if prompt_name is not None else None
self.prompt_adapter = get_prompt_adapter(self.model_name, prompt_name=self.prompt_name)
def apply_chat_template(
self, messages: List[ChatCompletionMessageParam],
) -> str:
"""
Applies a chat template to the given messages and returns the processed output.
Args:
messages: A list of ChatCompletionMessageParam objects representing the chat messages.
Returns:
str: The processed output as a string.
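        Example (a minimal sketch; the resulting prompt layout depends on the
        prompt adapter registered for ``model_name``, and ``engine`` stands for
        an already constructed ``TGIEngine``):

            prompt = engine.apply_chat_template(
                [
                    {"role": "system", "content": "You are a helpful assistant."},
                    {"role": "user", "content": "Hello!"},
                ]
            )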
"""
return self.prompt_adapter.apply_chat_template(messages)
async def generate(
self,
prompt: str,
do_sample: bool = True,
max_new_tokens: int = 20,
best_of: Optional[int] = None,
repetition_penalty: Optional[float] = None,
return_full_text: bool = False,
seed: Optional[int] = None,
stop_sequences: Optional[List[str]] = None,
temperature: Optional[float] = None,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
truncate: Optional[int] = None,
typical_p: Optional[float] = None,
watermark: bool = False,
decoder_input_details: bool = True,
top_n_tokens: Optional[int] = None,
) -> Response:
"""
Given a prompt, generate the following text asynchronously
Args:
prompt (`str`):
Input text
do_sample (`bool`):
Activate logits sampling
max_new_tokens (`int`):
Maximum number of generated tokens
best_of (`int`):
                Generate best_of sequences and return the one with the highest token logprobs
repetition_penalty (`float`):
The parameter for repetition penalty. 1.0 means no penalty. See [this
paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
return_full_text (`bool`):
Whether to prepend the prompt to the generated text
seed (`int`):
Random sampling seed
stop_sequences (`List[str]`):
Stop generating tokens if a member of `stop_sequences` is generated
temperature (`float`):
                The value used to modulate the logits distribution.
top_k (`int`):
The number of the highest probability vocabulary tokens to keep for top-k-filtering.
top_p (`float`):
If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
higher are kept for generation.
truncate (`int`):
                Truncate input tokens to the given size
typical_p (`float`):
Typical Decoding mass
See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
watermark (`bool`):
Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
decoder_input_details (`bool`):
Return the decoder input token logprobs and ids
top_n_tokens (`int`):
Return the `n` most likely tokens at each step
Returns:
Response: generated response
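        Example (a minimal sketch; ``engine`` stands for an already constructed
        ``TGIEngine`` backed by a running text-generation-inference server):

            response = await engine.generate(
                prompt="What is Deep Learning?",
                max_new_tokens=64,
                do_sample=True,
                temperature=0.7,
            )
            print(response.generated_text)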
"""
# Validate parameters
parameters = Parameters(
best_of=best_of,
details=True,
decoder_input_details=decoder_input_details,
do_sample=do_sample,
max_new_tokens=max_new_tokens,
repetition_penalty=repetition_penalty,
return_full_text=return_full_text,
seed=seed,
stop=stop_sequences if stop_sequences is not None else [],
temperature=temperature,
top_k=top_k,
top_p=top_p,
truncate=truncate,
typical_p=typical_p,
watermark=watermark,
top_n_tokens=top_n_tokens,
)
request = Request(inputs=prompt, stream=False, parameters=parameters)
async with ClientSession(
headers=self.model.headers, cookies=self.model.cookies, timeout=self.model.timeout
) as session:
async with session.post(f"{self.model.base_url}/generate", json=model_dump(request)) as resp:
payload = await resp.json()
if resp.status != 200:
raise parse_error(resp.status, payload)
return Response(**payload)
async def generate_stream(
self,
prompt: str,
do_sample: bool = False,
max_new_tokens: int = 20,
best_of: Optional[int] = 1,
repetition_penalty: Optional[float] = None,
return_full_text: bool = False,
seed: Optional[int] = None,
stop_sequences: Optional[List[str]] = None,
temperature: Optional[float] = None,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
truncate: Optional[int] = None,
typical_p: Optional[float] = None,
watermark: bool = False,
top_n_tokens: Optional[int] = None,
) -> AsyncIterator[StreamResponse]:
"""
Given a prompt, generate the following stream of tokens asynchronously
Args:
prompt (`str`):
Input text
do_sample (`bool`):
Activate logits sampling
max_new_tokens (`int`):
Maximum number of generated tokens
best_of (`int`):
                Generate best_of sequences and return the one with the highest token logprobs
repetition_penalty (`float`):
The parameter for repetition penalty. 1.0 means no penalty. See [this
paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
return_full_text (`bool`):
Whether to prepend the prompt to the generated text
seed (`int`):
Random sampling seed
stop_sequences (`List[str]`):
Stop generating tokens if a member of `stop_sequences` is generated
temperature (`float`):
                The value used to modulate the logits distribution.
top_k (`int`):
The number of the highest probability vocabulary tokens to keep for top-k-filtering.
top_p (`float`):
If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
higher are kept for generation.
truncate (`int`):
                Truncate input tokens to the given size
typical_p (`float`):
Typical Decoding mass
See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
watermark (`bool`):
Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
top_n_tokens (`int`):
Return the `n` most likely tokens at each step
        Yields:
            StreamResponse: stream of generated tokens
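        Example (a minimal sketch; ``engine`` stands for an already constructed
        ``TGIEngine`` backed by a running text-generation-inference server):

            async for chunk in engine.generate_stream(
                prompt="What is Deep Learning?",
                max_new_tokens=64,
            ):
                if not chunk.token.special:
                    print(chunk.token.text, end="", flush=True)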
"""
# Validate parameters
parameters = Parameters(
best_of=best_of,
details=True,
do_sample=do_sample,
max_new_tokens=max_new_tokens,
repetition_penalty=repetition_penalty,
return_full_text=return_full_text,
seed=seed,
stop=stop_sequences if stop_sequences is not None else [],
temperature=temperature,
top_k=top_k,
top_p=top_p,
truncate=truncate,
typical_p=typical_p,
watermark=watermark,
top_n_tokens=top_n_tokens,
)
request = Request(inputs=prompt, parameters=parameters)
async with ClientSession(
headers=self.model.headers, cookies=self.model.cookies, timeout=self.model.timeout
) as session:
async with session.post(f"{self.model.base_url}/generate_stream", json=model_dump(request)) as resp:
if resp.status != 200:
raise parse_error(resp.status, await resp.json())
# Parse ServerSentEvents
async for byte_payload in resp.content:
                    # Skip empty lines that separate server-sent events
if byte_payload == b"\n":
continue
payload = byte_payload.decode("utf-8")
# Event data
if payload.startswith("data:"):
# Decode payload
                        json_payload = json.loads(payload.lstrip("data:").rstrip("\n"))
# Parse payload
try:
response = StreamResponse(**json_payload)
except ValidationError:
# If we failed to parse the payload, then it is an error payload
raise parse_error(resp.status, json_payload)
yield response
@property
def stop(self):
"""
Gets the stop property of the prompt adapter.
Returns:
The stop property of the prompt adapter, or None if it does not exist.
"""
return self.prompt_adapter.stop if hasattr(self.prompt_adapter, "stop") else None
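

if __name__ == "__main__":
    # Minimal manual smoke test, not part of the engine itself. This is a
    # sketch under assumptions: a text-generation-inference server listening
    # on http://127.0.0.1:8080 and a prompt adapter registered for the
    # (hypothetical) model name "llama-2"; adjust both for your deployment.
    import asyncio

    async def _demo() -> None:
        client = AsyncClient("http://127.0.0.1:8080")
        engine = TGIEngine(model=client, model_name="llama-2")
        prompt = engine.apply_chat_template(
            [{"role": "user", "content": "What is Deep Learning?"}]
        )
        response = await engine.generate(prompt, max_new_tokens=64)
        print(response.generated_text)

    asyncio.run(_demo())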