Spaces:
Sleeping
Sleeping
"""Manifest class.""" | |
import asyncio | |
import copy | |
import logging | |
from typing import ( | |
Any, | |
Dict, | |
Generator, | |
Iterator, | |
List, | |
Optional, | |
Tuple, | |
Type, | |
Union, | |
cast, | |
) | |
import numpy as np | |
from manifest.caches.noop import NoopCache | |
from manifest.caches.postgres import PostgresCache | |
from manifest.caches.redis import RedisCache | |
from manifest.caches.sqlite import SQLiteCache | |
from manifest.clients.client import Client | |
from manifest.clients.huggingface import HuggingFaceClient | |
from manifest.connections.client_pool import ( | |
CLIENT_CONSTRUCTORS, | |
ClientConnection, | |
ClientConnectionPool, | |
) | |
from manifest.request import LMChatRequest, LMScoreRequest, Request | |
from manifest.response import ModelChoices, Response, Usage, Usages | |
# Cap the "openai" logger at WARNING so its lower-severity records are dropped.
logging.getLogger("openai").setLevel(logging.WARNING)
# Module-level logger used by the Manifest class below.
logger = logging.getLogger(__name__)

# Registry mapping the ``cache_name`` accepted by ``Manifest.__init__`` to the
# cache class that is instantiated for it.
CACHE_CONSTRUCTORS = {
    "redis": RedisCache,
    "sqlite": SQLiteCache,
    "noop": NoopCache,
    "postgres": PostgresCache,
}
class Manifest:
    """Manifest session object.

    Pairs a pool of model clients with a response cache and exposes the
    prompt-running entry points (``run``, ``arun_batch``, ``score_prompt``).
    """
def __init__(
    self,
    client_name: Optional[str] = None,
    client_connection: Optional[str] = None,
    client_pool: Optional[List[ClientConnection]] = None,
    client_pool_schedule: str = "round_robin",
    cache_name: str = "noop",
    cache_connection: Optional[str] = None,
    stop_token: str = "",
    **kwargs: Any,
):
    """
    Initialize manifest.

    Exactly one of ``client_name`` and ``client_pool`` must be provided.

    Args:
        client_name: name of client.
        client_connection: connection string for client.
        client_pool: list of client connections for multi-client.
        client_pool_schedule: schedule for client pool.
        cache_name: name of cache.
        cache_connection: connection string for cache.
        stop_token: stop token prompt generation.
            Can be overridden in run

    Remaining kwargs sent to client and cache. Anything neither of them
    removes from the kwargs dict is rejected.

    Raises:
        ValueError: if neither or both of ``client_name`` and
            ``client_pool`` are given, if ``cache_name`` is unknown, or if
            unrecognized keyword arguments remain.
    """
    if not client_name and not client_pool:
        raise ValueError(
            "Must specify client_name or client_pool. "
            f"Choices are {list(CLIENT_CONSTRUCTORS.keys())}"
        )
    if client_name and client_pool:
        raise ValueError("Cannot specify both client_name and client_pool")
    # Validate the cache name before any client is constructed so that a bad
    # name fails fast instead of leaving an already-built client pool behind.
    if cache_name not in CACHE_CONSTRUCTORS:
        raise ValueError(
            f"Unknown cache name: {cache_name}. "
            f"Choices are {list(CACHE_CONSTRUCTORS.keys())}"
        )
    if client_name:
        # A single named client is wrapped into a one-element pool.
        client_pool = [
            ClientConnection(
                client_name=client_name,
                client_connection=client_connection,
                # Remove engine from kwargs
                engine=kwargs.pop("engine", None),
            )
        ]
    # kwargs is passed as a dict (not unpacked) so the client and cache
    # constructors can "pop" the arguments they consume.
    self.client_pool = ClientConnectionPool(
        client_pool, client_pool_schedule, client_args=kwargs
    )
    self.cache = CACHE_CONSTRUCTORS[cache_name](  # type: ignore
        cache_connection, self.client_pool.request_type, cache_args=kwargs
    )
    # Whatever is still present was consumed by neither the clients nor the cache.
    if kwargs:
        raise ValueError(f"{list(kwargs.items())} arguments are not recognized.")
    self.stop_token = stop_token
def close(self) -> None:
    """Close the client pool and the cache.

    The cache is closed even when closing the client pool raises, so a
    failing client cannot leave the cache open; the pool's exception still
    propagates to the caller afterwards.
    """
    try:
        self.client_pool.close()
    finally:
        self.cache.close()
def _validate_kwargs(self, kwargs: Dict, request_params: Request) -> None: | |
"""Validate kwargs. | |
Args: | |
kwargs: kwargs to validate. | |
request_params: request object to validate against. | |
""" | |
# Check for invalid kwargs | |
non_request_kwargs = [ | |
(k, v) for k, v in kwargs.items() if k not in request_params.__dict__ | |
] | |
if len(non_request_kwargs) > 0: | |
raise ValueError( | |
f"{list(non_request_kwargs)} arguments are not recognized." | |
) | |
# Warn for valid but unused kwargs | |
request_unused_kwargs = [ | |
(k, v) for k, v in kwargs.items() if k not in non_request_kwargs | |
] | |
if len(request_unused_kwargs) > 0: | |
logger.warning(f"{list(request_unused_kwargs)} arguments are unused.") | |
return | |
def _split_cached_requests( | |
self, | |
request: Request, | |
client: Client, | |
overwrite_cache: bool, | |
) -> Tuple[Dict[int, Response], Request]: | |
"""Split a request into cached responses and Requests to run. | |
Args: | |
request: request object. | |
overwrite_cache: whether to overwrite cache. | |
Returns: | |
cached_idx_to_response: dict of cached responses. | |
new_request: request object with only prompts to run. | |
""" | |
cached_idx_to_response: Dict[int, Response] = {} | |
new_request = copy.deepcopy(request) | |
if not overwrite_cache: | |
if isinstance(new_request.prompt, list) and not isinstance( | |
request, LMChatRequest | |
): | |
new_request.prompt = [] | |
for idx, prompt_str in enumerate(request.prompt): | |
single_request = copy.deepcopy(request) | |
single_request.prompt = prompt_str | |
possible_response = self.cache.get( | |
client.get_cache_key(single_request) | |
) | |
if possible_response: | |
cached_idx_to_response[idx] = possible_response | |
else: | |
new_request.prompt.append(prompt_str) | |
# Chat or single string requests are not broken down into | |
# subprompts for caching. | |
elif (isinstance(new_request.prompt, str)) or ( | |
isinstance(new_request.prompt, list) | |
and isinstance(request, LMChatRequest) | |
): | |
possible_response = self.cache.get(client.get_cache_key(new_request)) | |
if possible_response: | |
cached_idx_to_response[0] = possible_response | |
new_request.prompt = None | |
else: | |
raise ValueError( | |
f"Invalid prompt type: {type(new_request.prompt)}" | |
f" with request type: {type(request)}" | |
) | |
return cached_idx_to_response, new_request | |
def _stitch_responses_and_cache(
    self,
    request: Request,
    client: Client,
    response: Union[Response, None],
    cached_idx_to_response: Dict[int, Response],
) -> Response:
    """Stitch together the cached and uncached responses.

    Walks the prompts in their original order, taking each prompt's choices
    and usages either from its cache hit or from the next slice of the
    freshly run ``response``. Every freshly run prompt is also written to
    the cache under its own key.

    Args:
        request: request that was run (used for ``n`` and as the template
            for per-prompt cache keys).
        client: client whose ``get_cache_key`` builds the cache keys.
        response: response for the prompts that were run, or None when
            everything was served from cache.
        cached_idx_to_response: cached responses keyed by prompt index.

    Returns:
        A single Response combining cached and new results.

    Raises:
        ValueError: if the run request's prompt is neither str nor list.
    """
    is_chat = isinstance(request, LMChatRequest)
    merged_choices = []
    merged_usages = []
    merged_prompts: List[Union[str, List[str], List[Dict]]] = []
    single_completion_output = False

    # Total prompts = cache hits + prompts covered by the fresh response.
    total_prompts = len(cached_idx_to_response)
    if response:
        ran_prompt = response.get_request_obj().prompt
        if isinstance(ran_prompt, str):
            single_completion_output = True
            total_prompts += 1
        elif isinstance(ran_prompt, list) and not is_chat:
            total_prompts += len(ran_prompt)
        elif isinstance(ran_prompt, list):
            # Chat request: the whole message list counts as one prompt.
            assert len(cached_idx_to_response) <= 1
            total_prompts += 1
        else:
            raise ValueError(
                f"Invalid prompt type: {type(ran_prompt)}"
                f" with request type: {type(request)}"
            )

    response_type = None
    request_type: Type[Request] = None
    # Position of the next unconsumed prompt inside the fresh response.
    fresh_idx = 0
    for idx in range(total_prompts):
        if idx in cached_idx_to_response:
            hit = cached_idx_to_response[idx]
            response_type = hit._response_type
            request_type = hit._request_type
            merged_prompts.append(hit.get_request_obj().prompt)
            hit_choices = hit.get_response_obj().choices
            if request.n == 1:
                assert (
                    len(hit_choices) == 1
                ), "cached response should have only one choice"
            merged_choices.extend(hit_choices)
            hit_usages = hit.get_usage_obj().usages
            if hit_usages:
                merged_usages.extend(hit_usages)
            continue

        assert response is not None, "response should not be None"
        response = cast(Response, response)
        response_type = response._response_type
        request_type = response._request_type
        # the choices list in the response is a flat one.
        # length is request.n * num_prompts
        lo = fresh_idx * request.n
        hi = (fresh_idx + 1) * request.n
        current_choices = response.get_response_obj().choices[lo:hi]
        merged_choices.extend(current_choices)

        ran_prompt = response.get_request_obj().prompt
        if isinstance(ran_prompt, list) and not is_chat:
            prompt: Union[str, List[str], List[Dict]] = ran_prompt[fresh_idx]
        elif isinstance(ran_prompt, list):
            # Chat request.
            # We will only have fresh_idx == 0 here as we can only
            # support single chat requests.
            assert request.n == 1
            assert total_prompts <= 1
            prompt = ran_prompt
        else:
            prompt = str(ran_prompt)

        usages: Optional[List[Usage]] = None
        fresh_usages = response.get_usage_obj().usages
        if fresh_usages:
            usages = fresh_usages[lo:hi]
            merged_usages.extend(usages)
        merged_prompts.append(prompt)

        # Cache this prompt's slice of the response under its own key.
        prompt_request = copy.deepcopy(request)
        prompt_request.prompt = prompt  # type: ignore
        cache_key = client.get_cache_key(prompt_request)
        prompt_response = copy.deepcopy(response)
        prompt_response._response.choices = current_choices
        prompt_response._usages = Usages(usages=(usages or []))
        self.cache.set(cache_key, prompt_response.to_dict(drop_request=True))
        fresh_idx += 1

    stitched_request = copy.deepcopy(request)
    if len(merged_prompts) > 1 or not single_completion_output:
        stitched_request.prompt = merged_prompts  # type: ignore
    else:
        # A lone, freshly run string prompt is kept as a bare string.
        stitched_request.prompt = merged_prompts[0]  # type: ignore
    return Response(
        response=ModelChoices(choices=merged_choices),
        cached=len(cached_idx_to_response) > 0,
        request=stitched_request,
        usages=Usages(usages=merged_usages),
        response_type=response_type,
        request_type=request_type,
    )
def run(
    self,
    prompt: Union[str, List[str], List[Dict[str, str]]],
    overwrite_cache: bool = False,
    stop_token: Optional[str] = None,
    return_response: bool = False,
    stream: bool = False,
    **kwargs: Any,
) -> Union[
    str,
    List[str],
    np.ndarray,
    List[np.ndarray],
    Response,
    Iterator[str],
    Iterator[Response],
]:
    """
    Run the prompt.

    Dispatches to the streaming run, the chat run, or the standard
    (single / batch) run depending on ``stream`` and the prompt's shape.

    Args:
        prompt: prompt(s) to run.
        overwrite_cache: whether to overwrite cache.
        stop_token: stop token for prompt generation.
            Default is self.stop_token.
            "" for no stop token.
        return_response: whether to return Response object.
        stream: whether to stream the prompt. Only supported
            for single string prompts and LMs.

    Returns:
        response from prompt.

    Raises:
        ValueError: if the prompt is not a str or list, is an empty list,
            or is not supported by the selected client for the requested
            mode (streaming / chat).
    """
    if not isinstance(prompt, (list, str)):
        raise ValueError(
            f"Invalid prompt type: {type(prompt)}. "
            "Prompt must be a string or list of strings "
            "or list of dicts."
        )
    if isinstance(prompt, list) and not prompt:
        raise ValueError("Prompt cannot be empty list")

    # Get the client to run
    client = self.client_pool.get_next_client()

    if stream:
        if not client.supports_streaming_inference():
            raise ValueError(
                f"Client {client} does not support streaming inference."
            )
        if not isinstance(prompt, str):
            raise ValueError(
                "Stream is only supported for single string prompts. "
                "It will soon be supported for chat dictionary prompts, too."
            )
        return self._run_stream(
            prompt=cast(str, prompt),
            client=client,
            overwrite_cache=overwrite_cache,
            stop_token=stop_token,
            return_response=return_response,
            **kwargs,
        )

    # A list whose first element is a dict is treated as a chat prompt.
    is_chat_prompt = isinstance(prompt, list) and isinstance(prompt[0], dict)
    if not is_chat_prompt:
        return self._run(
            prompt=cast(Union[str, List[str]], prompt),
            client=client,
            overwrite_cache=overwrite_cache,
            stop_token=stop_token,
            return_response=return_response,
            **kwargs,
        )

    if not client.IS_CHAT:
        raise ValueError(
            f"Client {client} does not support dict chat prompt. "
            "Please use a chat model."
        )
    if stop_token:
        logger.warning(
            "stop_token is not supported for chat prompt. "
            "Ignoring stop_token."
        )
    return self._run_chat(
        prompt=cast(List[Dict[str, str]], prompt),
        client=client,
        overwrite_cache=overwrite_cache,
        return_response=return_response,
        **kwargs,
    )
def _run(
    self,
    prompt: Union[str, List[str]],
    client: Client,
    overwrite_cache: bool = False,
    stop_token: Optional[str] = None,
    return_response: bool = False,
    **kwargs: Any,
) -> Union[str, List[str], np.ndarray, List[np.ndarray], Response]:
    """
    Run a string prompt or a batch of string prompts on one client.

    Args:
        prompt: prompt(s) to run.
        client: client to run.
        overwrite_cache: whether to overwrite cache.
        stop_token: stop token for prompt generation.
            Default is self.stop_token.
            "" for no stop token.
        return_response: whether to return Response object.

    Returns:
        response from prompt.

    Raises:
        ValueError: if a batch is requested with n > 1, or if kwargs
            contains unrecognized arguments.
    """
    is_batch = isinstance(prompt, list)
    if stop_token is None:
        stop_token = self.stop_token

    # kwargs goes in as a dict so the client can "pop" the arguments it uses.
    full_request = client.get_request(prompt, kwargs)
    # Avoid nested list of results - enforce n = 1 for batch
    if is_batch and full_request.n > 1:
        raise ValueError("Batch mode does not support n > 1.")
    self._validate_kwargs(kwargs, full_request)

    cache_hits, pending_request = self._split_cached_requests(
        full_request, client, overwrite_cache
    )

    # Only call the client when a prompt (or non-empty list) is left to run.
    fresh_response = None
    if pending_request.prompt:
        self.client_pool.start_timer()
        fresh_response = client.run_request(pending_request)
        self.client_pool.end_timer()

    stitched = self._stitch_responses_and_cache(
        request=pending_request,
        client=client,
        response=fresh_response,
        cached_idx_to_response=cache_hits,
    )
    if return_response:
        return stitched
    # Extract text results
    return stitched.get_response(stop_token, is_batch)
def _run_chat(
    self,
    prompt: List[Dict[str, str]],
    client: Client,
    overwrite_cache: bool = False,
    return_response: bool = False,
    **kwargs: Any,
) -> Union[str, Response]:
    """
    Run a single chat prompt (a list of message dicts) on one client.

    Args:
        prompt: prompt dictionary to run.
        client: client to run.
        overwrite_cache: whether to overwrite cache.
        return_response: whether to return Response object.

    Returns:
        response from prompt.

    Raises:
        ValueError: if n > 1 is requested, or if kwargs contains
            unrecognized arguments.
    """
    is_batch = False

    # Build a request for an empty prompt so the client handles all kwargs,
    # then swap in the chat messages and re-wrap as a chat request.
    base_request = client.get_request("", kwargs)
    chat_fields = base_request.to_dict()
    chat_fields["prompt"] = prompt
    chat_request = LMChatRequest(**chat_fields)

    # Avoid nested list of results - enforce n = 1 for chat
    if chat_request.n > 1:
        raise ValueError("Chat mode does not support n > 1.")
    self._validate_kwargs(kwargs, chat_request)

    cache_hits, chat_request = self._split_cached_requests(  # type: ignore
        chat_request, client, overwrite_cache
    )

    # Only call the client when the prompt was not served from cache.
    fresh_response = None
    if chat_request.prompt:
        self.client_pool.start_timer()
        fresh_response = client.run_chat_request(chat_request)
        self.client_pool.end_timer()

    stitched = self._stitch_responses_and_cache(
        request=chat_request,
        client=client,
        response=fresh_response,
        cached_idx_to_response=cache_hits,
    )
    if return_response:
        return stitched
    # Extract text results; no stop token is applied for chat.
    return cast(str, stitched.get_response("", is_batch))
def _run_stream(
    self,
    prompt: str,
    client: Client,
    overwrite_cache: bool = False,
    stop_token: Optional[str] = None,
    return_response: bool = False,
    **kwargs: Any,
) -> Union[Generator[str, None, None], Generator[Response, None, None]]:
    """
    Run the prompt in a stream.

    Yields each streamed piece as it arrives (or replays a cached response
    through ``as_iter``). Once a non-cached stream is exhausted, the pieces
    are merged and written to the cache.

    Args:
        prompt: prompt(s) to run.
        client: client to run.
        overwrite_cache: whether to overwrite cache.
        stop_token: stop token for prompt generation.
            Default is self.stop_token.
            "" for no stop token.
            Note: the yielded pieces are extracted with an empty stop
            token, so this value is not applied to them.
        return_response: whether to return Response object.

    Returns:
        response from prompt.

    Raises:
        ValueError: if n > 1 is requested, or if kwargs contains
            unrecognized arguments.
    """
    is_batch = False
    if stop_token is None:
        stop_token = self.stop_token

    # kwargs goes in as a dict so the client can "pop" the arguments it uses.
    full_request = client.get_request(prompt, kwargs)
    # Avoid nested list of results - enforce n = 1 for stream
    if full_request.n > 1:
        raise ValueError("Stream mode does not support n > 1.")
    self._validate_kwargs(kwargs, full_request)

    cache_hits, pending_request = self._split_cached_requests(
        full_request, client, overwrite_cache
    )

    # Because we are streaming a single prompt, we have either a cached
    # response or a prompt to run - never both.
    is_cached = not pending_request.prompt
    if is_cached:
        assert len(cache_hits) == 1
        token_iter = cache_hits[0].as_iter()
    else:
        assert len(cache_hits) == 0
        token_iter = client.run_streaming_request(pending_request)

    streamed = []
    # Start timing metrics
    self.client_pool.start_timer()
    for token_response in token_iter:
        streamed.append(token_response)
        if return_response:
            yield token_response
        else:
            yield cast(
                Union[str, Response], token_response.get_response("", is_batch)
            )
    self.client_pool.end_timer()

    if is_cached:
        # Replayed from cache: nothing new to store.
        return
    merged = Response.union_all(streamed, as_single_lmchoice=True)
    self._stitch_responses_and_cache(
        request=pending_request,
        client=client,
        response=merged,
        cached_idx_to_response=cache_hits,
    )
async def arun_batch(
    self,
    prompts: List[str],
    overwrite_cache: bool = False,
    stop_token: Optional[str] = None,
    return_response: bool = False,
    chunk_size: int = -1,
    verbose: bool = False,
    **kwargs: Any,
) -> Union[List[str], List[np.ndarray], Response]:
    """
    Run a batch of prompts with async.

    If the client pool is a single client, all prompts will be sent
    to one client and batch_size (which is passed in as kwargs) will
    determine how the prompts are split.

    If the client pool is a pool of clients, the prompts will be split
    into chunks and sent to the clients. Each client will split the
    chunk into batch_size prompts to send to the model.

    Args:
        prompts: prompts to run.
        overwrite_cache: whether to overwrite cache.
        stop_token: stop token for prompt generation.
            Default is self.stop_token.
            "" for no stop token.
        return_response: whether to return Response object.
        chunk_size: number of prompts to send to a client in chunks.
            For each chunk, the client will split the chunk into
            batch_sized prompts to send to the model.
            For a single manifest client, there is no impact to
            setting chunk_size. For a client pool, chunk_size
            can be used to distribute the load across the clients.
            A non-positive value sends all prompts as a single chunk.
        verbose: whether to print progress of async tasks.

    Returns:
        response from prompt.

    Raises:
        ValueError: if ``prompts`` is not a non-empty list whose first
            element is a string.
    """
    if not isinstance(prompts, list):
        raise ValueError("Prompts must be a list of strings.")
    if not prompts:
        raise ValueError("Prompts must not be empty.")
    if not isinstance(prompts[0], str):
        raise ValueError("Prompts must be a list of strings.")

    # Split the prompts into chunks for connection pool; each chunk is
    # paired with the next client handed out by the pool.
    prompt_chunks: List[Tuple[Client, List[str]]]
    if chunk_size > 0:
        prompt_chunks = [
            (self.client_pool.get_next_client(), prompts[start : start + chunk_size])
            for start in range(0, len(prompts), chunk_size)
        ]
    else:
        prompt_chunks = [(self.client_pool.get_next_client(), prompts)]

    # Run the chunks
    tasks = [
        asyncio.create_task(
            self._arun_batch_client(
                prompts=chunk,
                client=chunk_client,
                overwrite_cache=overwrite_cache,
                verbose=verbose,
                **kwargs,
            )
        )
        for chunk_client, chunk in prompt_chunks
    ]
    logger.info(f"Running {len(tasks)} tasks across all clients.")
    chunk_responses = await asyncio.gather(*tasks)
    final_response = Response.union_all(chunk_responses)

    if stop_token is None:
        stop_token = self.stop_token
    if return_response:
        return final_response
    # Extract text results
    return cast(
        Union[List[str], List[np.ndarray]],
        final_response.get_response(stop_token, True),
    )
async def _arun_batch_client(
    self,
    prompts: List[str],
    client: Client,
    overwrite_cache: bool = False,
    verbose: bool = False,
    **kwargs: Any,
) -> Response:
    """
    Run a batch of prompts with async for single client.

    Args:
        prompts: prompts to run.
        client: client to run.
        overwrite_cache: whether to overwrite cache.
        verbose: whether to print progress of async tasks.

    Returns:
        response from prompt.

    Raises:
        ValueError: if n > 1 is requested, or if kwargs contains
            unrecognized arguments.
    """
    # kwargs goes in as a dict so the client can "pop" the arguments it uses.
    full_request = client.get_request(prompts, kwargs)
    # Avoid nested list of results - enforce n = 1 for batch
    if full_request.n > 1:
        raise ValueError("Batch mode does not support n > 1.")
    self._validate_kwargs(kwargs, full_request)

    cache_hits, pending_request = self._split_cached_requests(
        full_request, client, overwrite_cache
    )

    # Only call the client when at least one prompt missed the cache.
    fresh_response = None
    if pending_request.prompt:
        self.client_pool.start_timer()
        fresh_response = await client.arun_batch_request(
            pending_request, verbose=verbose
        )
        self.client_pool.end_timer()

    return self._stitch_responses_and_cache(
        request=pending_request,
        client=client,
        response=fresh_response,
        cached_idx_to_response=cache_hits,
    )
def score_prompt(
    self,
    prompt: Union[str, List[str]],
    overwrite_cache: bool = False,
    **kwargs: Any,
) -> Dict:
    """
    Score the prompt via forward pass of the model - no sampling or generation.

    Returns the response object with logits of the prompt.

    Args:
        prompt: prompt(s) to run.
        overwrite_cache: whether to overwrite cache.

    Returns:
        response from prompt, as a dict.

    Raises:
        ValueError: if n > 1 is requested, if kwargs contains unrecognized
            arguments, or if the selected client has no
            ``run_score_prompt_request`` method.
    """
    client = self.client_pool.get_next_client()
    # Must pass kwargs as dict for client "pop" methods removed used arguments
    request_params = client.get_request(prompt, kwargs)
    request_params_as_score = LMScoreRequest(**request_params.to_dict())
    if request_params_as_score.n > 1:
        raise ValueError("Sequence scoring does not support n > 1.")
    self._validate_kwargs(kwargs, request_params_as_score)
    cached_idx_to_response, request_params_as_score = self._split_cached_requests(  # type: ignore # noqa: E501
        request_params_as_score, client, overwrite_cache
    )
    # If not None value or empty list - run new request
    response = None
    if request_params_as_score.prompt:
        # Guard only the attribute lookup: an AttributeError raised *inside*
        # the scoring call is a genuine failure and must not be reported as
        # "unsupported client".
        try:
            run_score = cast(HuggingFaceClient, client).run_score_prompt_request
        except AttributeError as err:
            raise ValueError(
                "`score_prompt` only supported for HF models."
            ) from err
        response = run_score(request_params_as_score)
    final_response = self._stitch_responses_and_cache(
        request=request_params_as_score,
        client=client,
        response=response,
        cached_idx_to_response=cached_idx_to_response,
    )
    return final_response.to_dict()