|
from typing import TYPE_CHECKING, Any, List, Optional, Union

import httpx

import litellm
from litellm.litellm_core_utils.prompt_templates.common_utils import (
    convert_content_list_to_str,
)
from litellm.litellm_core_utils.prompt_templates.factory import (
    custom_prompt,
    prompt_factory,
)
from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import ModelResponse, Usage
from litellm.utils import token_counter

from ..common_utils import ReplicateError

if TYPE_CHECKING:
    from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj

    LoggingClass = LiteLLMLoggingObj
else:
    LoggingClass = Any
|
|
|
|
|
class ReplicateConfig(BaseConfig):
    """
    Reference: https://replicate.com/meta/llama-2-70b-chat/api

    - `prompt` (string): The prompt to send to the model.

    - `system_prompt` (string): The system prompt to send to the model. This is prepended to the prompt and helps guide system behavior. Default value: `You are a helpful assistant`.

    - `max_new_tokens` (integer): Maximum number of tokens to generate. Typically, a word is made up of 2-3 tokens. Default value: `128`.

    - `min_new_tokens` (integer): Minimum number of tokens to generate. To disable, set to `-1`. A word is usually 2-3 tokens. Default value: `-1`.

    - `temperature` (number): Adjusts the randomness of outputs. Values greater than 1 increase randomness, 0 is deterministic, and 0.75 is a reasonable starting value. Default value: `0.75`.

    - `top_p` (number): During text decoding, samples from the top `p` percentage of most likely tokens. Reduce this to ignore less probable tokens. Default value: `0.9`.

    - `top_k` (integer): During text decoding, samples from the top `k` most likely tokens. Reduce this to ignore less probable tokens. Default value: `50`.

    - `stop_sequences` (string): A comma-separated list of sequences at which to stop generation. For example, inputting '<end>,<stop>' will cease generation at the first occurrence of either '<end>' or '<stop>'.

    - `seed` (integer): The seed for the random generator. Leave it blank to randomize the seed.

    - `debug` (boolean): If set to `True`, provides debugging output in logs.

    Please note that Replicate's mapping of these parameters can be inconsistent across different models, so not all of these parameters may be available for every model.
    """
|
|
|
    system_prompt: Optional[str] = None
    max_new_tokens: Optional[int] = None
    min_new_tokens: Optional[int] = None
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    top_k: Optional[int] = None
    stop_sequences: Optional[str] = None
    seed: Optional[int] = None
    debug: Optional[bool] = None

    def __init__(
        self,
        system_prompt: Optional[str] = None,
        max_new_tokens: Optional[int] = None,
        min_new_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        stop_sequences: Optional[str] = None,
        seed: Optional[int] = None,
        debug: Optional[bool] = None,
    ) -> None:
        locals_ = locals()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

    @classmethod
    def get_config(cls):
        return super().get_config()
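
    # Example (illustrative, not from the upstream code): instantiating the config
    # sets class-level defaults, which get_config() exposes and transform_request()
    # later merges into any request that does not override them.
    #
    #   litellm.ReplicateConfig(max_new_tokens=256, temperature=0.2)
    #   litellm.ReplicateConfig.get_config()
    #   # -> {..., "max_new_tokens": 256, "temperature": 0.2}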
|
|
|
    def get_supported_openai_params(self, model: str) -> list:
        return [
            "stream",
            "temperature",
            "max_tokens",
            "top_p",
            "stop",
            "seed",
            "tools",
            "tool_choice",
            "functions",
            "function_call",
        ]
|
|
|
    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        for param, value in non_default_params.items():
            if param == "stream":
                optional_params["stream"] = value
            if param == "max_tokens":
                if "vicuna" in model or "flan" in model:
                    optional_params["max_length"] = value
                elif "meta/codellama-13b" in model:
                    optional_params["max_tokens"] = value
                else:
                    optional_params["max_new_tokens"] = value
            if param == "temperature":
                optional_params["temperature"] = value
            if param == "top_p":
                optional_params["top_p"] = value
            if param == "stop":
                optional_params["stop_sequences"] = value

        return optional_params
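
    # Example (illustrative): how common OpenAI-style parameters are renamed for
    # Replicate. The model string below is an arbitrary placeholder.
    #
    #   ReplicateConfig().map_openai_params(
    #       non_default_params={"max_tokens": 256, "temperature": 0.2, "stop": ["###"]},
    #       optional_params={},
    #       model="meta/llama-2-70b-chat",
    #       drop_params=False,
    #   )
    #   # -> {"max_new_tokens": 256, "temperature": 0.2, "stop_sequences": ["###"]}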
|
|
|
|
|
    def model_to_version_id(self, model: str) -> str:
        if ":" in model:
            split_model = model.split(":")
            return split_model[1]
        return model
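
    # Example (illustrative): a fully-qualified Replicate model string of the form
    # "owner/model-name:<version-hash>" resolves to just the version hash, while a
    # bare "owner/model-name" is returned unchanged.
    #
    #   ReplicateConfig().model_to_version_id("meta/llama-2-70b-chat:<version-hash>")
    #   # -> "<version-hash>"
    #   ReplicateConfig().model_to_version_id("meta/llama-2-70b-chat")
    #   # -> "meta/llama-2-70b-chat"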
|
|
|
    def get_error_class(
        self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
    ) -> BaseLLMException:
        return ReplicateError(
            status_code=status_code, message=error_message, headers=headers
        )
|
|
|
    def get_complete_url(
        self,
        api_base: str,
        model: str,
        optional_params: dict,
        stream: Optional[bool] = None,
    ) -> str:
        version_id = self.model_to_version_id(model)
        base_url = api_base
        if "deployments" in version_id:
            version_id = version_id.replace("deployments/", "")
            base_url = f"https://api.replicate.com/v1/deployments/{version_id}"
        else:
            base_url = f"https://api.replicate.com/v1/models/{version_id}"

        base_url = f"{base_url}/predictions"
        return base_url
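
    # Example (illustrative) of the prediction URLs this produces; the model strings
    # are placeholders.
    #
    #   "meta/llama-2-70b-chat"
    #       -> "https://api.replicate.com/v1/models/meta/llama-2-70b-chat/predictions"
    #   "deployments/my-org/my-deployment"
    #       -> "https://api.replicate.com/v1/deployments/my-org/my-deployment/predictions"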
|
|
|
    def transform_request(
        self,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
    ) -> dict:
        # Fill in any config defaults that the caller did not override.
        config = litellm.ReplicateConfig.get_config()
        for k, v in config.items():
            if k not in optional_params:
                optional_params[k] = v

        system_prompt = None
        if optional_params is not None and "supports_system_prompt" in optional_params:
            supports_sys_prompt = optional_params.pop("supports_system_prompt")
        else:
            supports_sys_prompt = False

        if supports_sys_prompt:
            for i in range(len(messages)):
                if messages[i]["role"] == "system":
                    first_sys_message = messages.pop(i)
                    system_prompt = convert_content_list_to_str(first_sys_message)
                    break

        if model in litellm.custom_prompt_dict:
            # Use the custom prompt template registered for this model.
            model_prompt_details = litellm.custom_prompt_dict[model]
            prompt = custom_prompt(
                role_dict=model_prompt_details.get("roles", {}),
                initial_prompt_value=model_prompt_details.get(
                    "initial_prompt_value", ""
                ),
                final_prompt_value=model_prompt_details.get("final_prompt_value", ""),
                bos_token=model_prompt_details.get("bos_token", ""),
                eos_token=model_prompt_details.get("eos_token", ""),
                messages=messages,
            )
        else:
            prompt = prompt_factory(model=model, messages=messages)

        if prompt is None or not isinstance(prompt, str):
            raise ReplicateError(
                status_code=400,
                message="LiteLLM Error - prompt is not a string - {}".format(prompt),
                headers={},
            )

        if system_prompt is not None:
            input_data = {
                "prompt": prompt,
                "system_prompt": system_prompt,
                **optional_params,
            }
        else:
            input_data = {"prompt": prompt, **optional_params}

        version_id = self.model_to_version_id(model)
        request_data: dict = {"input": input_data}
        if ":" in version_id and len(version_id) > 64:
            model_parts = version_id.split(":")
            if len(model_parts) > 1 and len(model_parts[1]) == 64:
                request_data["version"] = model_parts[1]

        return request_data
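
    # Example (illustrative) of the request body this produces for a model that
    # supports a system prompt; the actual prompt text depends on the model's
    # prompt template, and a top-level "version" key may also be set depending on
    # the model string.
    #
    #   {
    #       "input": {
    #           "prompt": "...",
    #           "system_prompt": "You are a helpful assistant",
    #           "max_new_tokens": 256,
    #       }
    #   }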
|
|
|
    def transform_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: ModelResponse,
        logging_obj: LoggingClass,
        request_data: dict,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        encoding: Any,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> ModelResponse:
        logging_obj.post_call(
            input=messages,
            api_key=api_key,
            original_response=raw_response.text,
            additional_args={"complete_input_dict": request_data},
        )
        raw_response_json = raw_response.json()
        if raw_response_json.get("status") != "succeeded":
            raise ReplicateError(
                status_code=422,
                message="LiteLLM Error - prediction not succeeded - {}".format(
                    raw_response_json
                ),
                headers=raw_response.headers,
            )
        outputs = raw_response_json.get("output", [])
        response_str = "".join(outputs)
        if len(response_str) == 0:
            response_str = " "

        if len(response_str) >= 1:
            model_response.choices[0].message.content = response_str

        prompt_tokens = token_counter(model=model, messages=messages)
        completion_tokens = token_counter(
            model=model,
            text=response_str,
            count_response_tokens=True,
        )
        model_response.model = "replicate/" + model
        usage = Usage(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=prompt_tokens + completion_tokens,
        )
        setattr(model_response, "usage", usage)

        return model_response
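
    # Example (illustrative) of a successful Replicate prediction payload: the
    # "output" list of text chunks is joined into the assistant message, and token
    # usage is estimated locally via token_counter().
    #
    #   {"status": "succeeded", "output": ["Hello", " there", "!"], ...}
    #   # -> model_response.choices[0].message.content == "Hello there!"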
|
|
|
    def get_prediction_url(self, response: httpx.Response) -> str:
        """
        Extracts the prediction's polling URL (`urls.get`) from Replicate's
        create-prediction response.

        response json: {
            ...,
            "urls": {
                "cancel": "https://api.replicate.com/v1/predictions/gqsmqmp1pdrj00cknr08dgmvb4/cancel",
                "get": "https://api.replicate.com/v1/predictions/gqsmqmp1pdrj00cknr08dgmvb4",
                "stream": "https://stream-b.svc.rno2.c.replicate.net/v1/streams/eot4gbydowuin4snhncydwxt57dfwgsc3w3snycx5nid7oef7jga"
            }
        }
        """
        response_json = response.json()
        prediction_url = response_json.get("urls", {}).get("get")
        if prediction_url is None:
            raise ReplicateError(
                status_code=400,
                message="LiteLLM Error - prediction url is None - {}".format(
                    response_json
                ),
                headers=response.headers,
            )
        return prediction_url
|
|
|
    def validate_environment(
        self,
        headers: dict,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> dict:
        headers = {
            "Authorization": f"Token {api_key}",
            "Content-Type": "application/json",
        }
        return headers
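
    # Example (illustrative): the headers sent with every Replicate request, where
    # <api-key> stands in for the user's Replicate API token.
    #
    #   {"Authorization": "Token <api-key>", "Content-Type": "application/json"}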
|
|