"""
Translate from OpenAI's `/v1/chat/completions` to Sagemaker's `/invoke`, in the Huggingface TGI format.
"""

import json
import time
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

from httpx._models import Headers, Response

import litellm
from litellm.litellm_core_utils.asyncify import asyncify
from litellm.litellm_core_utils.prompt_templates.factory import (
    custom_prompt,
    prompt_factory,
)
from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import ModelResponse, Usage

from ..common_utils import SagemakerError

if TYPE_CHECKING:
    from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj

    LiteLLMLoggingObj = _LiteLLMLoggingObj
else:
    LiteLLMLoggingObj = Any


class SagemakerConfig(BaseConfig):
    """
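    Transforms OpenAI-style chat params and messages into the Hugging Face TGI
    format expected by Sagemaker `/invoke` endpoints.
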
    Reference: https://d-uuwbxj1u4cnu.studio.us-west-2.sagemaker.aws/jupyter/default/lab/workspaces/auto-q/tree/DemoNotebooks/meta-textgeneration-llama-2-7b-SDK_1.ipynb
    """

    max_new_tokens: Optional[int] = None
    top_p: Optional[float] = None
    temperature: Optional[float] = None
    return_full_text: Optional[bool] = None

    def __init__(
        self,
        max_new_tokens: Optional[int] = None,
        top_p: Optional[float] = None,
        temperature: Optional[float] = None,
        return_full_text: Optional[bool] = None,
    ) -> None:
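        """
        Store any provided values as class-level attributes so they act as
        defaults picked up by get_config().
        """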
        locals_ = locals()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

    @classmethod
    def get_config(cls):
        return super().get_config()

    def get_error_class(
        self, error_message: str, status_code: int, headers: Union[dict, Headers]
    ) -> BaseLLMException:
        return SagemakerError(
            message=error_message, status_code=status_code, headers=headers
        )

    def get_supported_openai_params(self, model: str) -> List:
        return ["stream", "temperature", "max_tokens", "top_p", "stop", "n"]

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
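        """
        Map supported OpenAI params onto their TGI equivalents. Illustrative example:

            {"max_tokens": 256, "temperature": 0.0, "n": 2}
            -> {"max_new_tokens": 256, "temperature": 0.01, "best_of": 2, "do_sample": True}
        """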
        for param, value in non_default_params.items():
            if param == "temperature":
                if value == 0.0 or value == 0:
                    # TGI rejects temperature=0 ("must be strictly positive"), so
                    # bump it to 0.01 unless `aws_sagemaker_allow_zero_temp` is set.
                    if not non_default_params.get(
                        "aws_sagemaker_allow_zero_temp", False
                    ):
                        value = 0.01
                optional_params["temperature"] = value
            if param == "top_p":
                optional_params["top_p"] = value
            if param == "n":
                optional_params["best_of"] = value
                # best_of requires sampling to be enabled
                optional_params["do_sample"] = True
            if param == "stream":
                optional_params["stream"] = value
            if param == "stop":
                optional_params["stop"] = value
            if param == "max_tokens":
                # TGI rejects max_new_tokens=0, so use 1 as the minimum
                if value == 0:
                    value = 1
                optional_params["max_new_tokens"] = value
        non_default_params.pop("aws_sagemaker_allow_zero_temp", None)
        return optional_params

    def _transform_prompt(
        self,
        model: str,
        messages: List,
        custom_prompt_dict: dict,
        hf_model_name: Optional[str],
    ) -> str:
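        """
        Build the text prompt sent as `inputs`: prefer a custom prompt template
        registered for `model` (or `hf_model_name`), otherwise fall back to
        prompt_factory with a best-guess Hugging Face model name.
        """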
        if model in custom_prompt_dict:
            # the model has a registered custom prompt template
            model_prompt_details = custom_prompt_dict[model]
            prompt = custom_prompt(
                role_dict=model_prompt_details.get("roles", None),
                initial_prompt_value=model_prompt_details.get(
                    "initial_prompt_value", ""
                ),
                final_prompt_value=model_prompt_details.get("final_prompt_value", ""),
                messages=messages,
            )
        elif hf_model_name in custom_prompt_dict:
            # the underlying huggingface model has a registered custom prompt template
            model_prompt_details = custom_prompt_dict[hf_model_name]
            prompt = custom_prompt(
                role_dict=model_prompt_details.get("roles", None),
                initial_prompt_value=model_prompt_details.get(
                    "initial_prompt_value", ""
                ),
                final_prompt_value=model_prompt_details.get("final_prompt_value", ""),
                messages=messages,
            )
        else:
            if hf_model_name is None:
                # default llama-2 deployments to the matching chat/base template
                if "llama-2" in model.lower():
                    if "chat" in model.lower():
                        hf_model_name = "meta-llama/Llama-2-7b-chat-hf"
                    else:
                        hf_model_name = "meta-llama/Llama-2-7b"
            # pass the hf model name to pull its prompt template
            hf_model_name = hf_model_name or model
            prompt: str = prompt_factory(model=hf_model_name, messages=messages)

        return prompt

    def transform_request(
        self,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
    ) -> dict:
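        """
        Build the Sagemaker/TGI request body. Roughly:

            {"inputs": "<prompt>", "parameters": {...}, "stream": True}

        where "stream" is only included when streaming was requested.
        """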
        inference_params = optional_params.copy()
        stream = inference_params.pop("stream", False)
        data: Dict = {"parameters": inference_params}
        if stream is True:
            data["stream"] = True

        custom_prompt_dict = (
            litellm_params.get("custom_prompt_dict", None) or litellm.custom_prompt_dict
        )

        hf_model_name = litellm_params.get("hf_model_name", None)

        prompt = self._transform_prompt(
            model=model,
            messages=messages,
            custom_prompt_dict=custom_prompt_dict,
            hf_model_name=hf_model_name,
        )
        data["inputs"] = prompt

        return data

    async def async_transform_request(
        self,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
    ) -> dict:
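        """Async wrapper: runs the synchronous transform_request via asyncify."""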
        return await asyncify(self.transform_request)(
            model, messages, optional_params, litellm_params, headers
        )

    def transform_response(
        self,
        model: str,
        raw_response: Response,
        model_response: ModelResponse,
        logging_obj: LiteLLMLoggingObj,
        request_data: dict,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        encoding: Any,  # tokenizer-like object exposing .encode(), used for token counts
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> ModelResponse:
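        """
        Parse the TGI-style JSON response (a dict, or a single-element list, with
        `generation` / `generated_text`) into a ModelResponse and compute token
        usage from the prompt and completion text.
        """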
        completion_response = raw_response.json()

        ## LOGGING
        logging_obj.post_call(
            input=messages,
            api_key="",
            original_response=completion_response,
            additional_args={"complete_input_dict": request_data},
        )

        prompt = request_data["inputs"]

        try:
            if isinstance(completion_response, list):
                completion_response_choices = completion_response[0]
            else:
                completion_response_choices = completion_response
            completion_output = ""
            if "generation" in completion_response_choices:
                completion_output += completion_response_choices["generation"]
            elif "generated_text" in completion_response_choices:
                completion_output += completion_response_choices["generated_text"]

            # some deployments echo the prompt back - strip it from the completion
            if completion_output.startswith(prompt) and "<s>" in prompt:
                completion_output = completion_output.replace(prompt, "", 1)

            model_response.choices[0].message.content = completion_output
        except Exception:
            raise SagemakerError(
                message=f"LiteLLM Error: Unable to parse sagemaker RAW RESPONSE {json.dumps(completion_response)}",
                status_code=500,
            )

        ## CALCULATING USAGE
        prompt_tokens = len(encoding.encode(prompt))
        completion_tokens = len(
            encoding.encode(model_response["choices"][0]["message"].get("content", ""))
        )

        model_response.created = int(time.time())
        model_response.model = model
        usage = Usage(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=prompt_tokens + completion_tokens,
        )
        setattr(model_response, "usage", usage)
        return model_response

    def validate_environment(
        self,
        headers: Optional[dict],
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> dict:
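        """Return JSON content-type headers, merged with any caller-supplied headers."""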
        default_headers = {"Content-Type": "application/json"}
        if headers is not None:
            # merge caller-supplied headers over the defaults instead of discarding them
            default_headers.update(headers)
        return default_headers
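

# Illustrative flow (assumed, based on the BaseConfig interface): litellm's Sagemaker
# handler drives this config, calling validate_environment() and transform_request()
# to build the `/invoke` payload, then transform_response() to map the endpoint's
# JSON output back into a ModelResponse.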