|
""" |
|
Support for Snowflake REST API |
|
""" |
|
|
|
from typing import TYPE_CHECKING, Any, List, Optional, Tuple |
|
|
|
import httpx |
|
|
|
from litellm.secret_managers.main import get_secret_str |
|
from litellm.types.llms.openai import AllMessageValues |
|
from litellm.types.utils import ModelResponse |
|
|
|
from ...openai_like.chat.transformation import OpenAIGPTConfig |
|
|
|
# The concrete logging class is only imported while type checking; at runtime
# the alias is simply ``Any`` so annotations below never trigger that import.
if not TYPE_CHECKING:
    LiteLLMLoggingObj = Any
else:
    from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj

    LiteLLMLoggingObj = _LiteLLMLoggingObj
|
|
|
|
|
class SnowflakeConfig(OpenAIGPTConfig):
    """
    Provider config for the Snowflake Cortex REST `inference:complete` endpoint.

    source: https://docs.snowflake.com/en/sql-reference/functions/complete-snowflake-cortex
    """

    @classmethod
    def get_config(cls):
        """Return the configuration produced by the parent class's `get_config`."""
        return super().get_config()

    def get_supported_openai_params(self, model: str) -> List[str]:
        """Return the OpenAI parameter names forwarded to Snowflake (same for every model)."""
        return ["temperature", "max_tokens", "top_p", "response_format"]

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        """
        If any supported_openai_params are in non_default_params, add them to optional_params, so they are used in API call

        Args:
            non_default_params (dict): Non-default parameters to filter.
            optional_params (dict): Optional parameters to update (mutated in place).
            model (str): Model name for parameter support check.
            drop_params (bool): Unused here; unsupported parameters are always skipped.

        Returns:
            dict: Updated optional_params with supported non-default parameters.
        """
        supported_openai_params = self.get_supported_openai_params(model)
        optional_params.update(
            {
                param: value
                for param, value in non_default_params.items()
                if param in supported_openai_params
            }
        )
        return optional_params

    def transform_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: ModelResponse,
        logging_obj: LiteLLMLoggingObj,
        request_data: dict,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        encoding: Any,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> ModelResponse:
        """
        Build a `ModelResponse` from the raw Snowflake HTTP response.

        The decoded JSON body is reported to `logging_obj.post_call` and then
        expanded into a new `ModelResponse`; the `model_response` argument is
        not used. The returned model name is prefixed with "snowflake/".
        """
        response_json = raw_response.json()

        # The API key is deliberately logged as an empty string.
        logging_obj.post_call(
            input=messages,
            api_key="",
            original_response=response_json,
            additional_args={"complete_input_dict": request_data},
        )

        returned_response = ModelResponse(**response_json)
        returned_response.model = "snowflake/" + (returned_response.model or "")

        if model is not None:
            returned_response._hidden_params["model"] = model
        return returned_response

    def validate_environment(
        self,
        headers: dict,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> dict:
        """
        Return headers to use for Snowflake completion request

        Snowflake REST API Ref: https://docs.snowflake.com/en/user-guide/snowflake-cortex/cortex-llm-rest-api#api-reference
        Expected headers:
        {
            "Content-Type": "application/json",
            "Accept": "application/json",
            "Authorization": "Bearer " + <JWT>,
            "X-Snowflake-Authorization-Token-Type": "KEYPAIR_JWT"
        }

        Raises:
            ValueError: If `api_key` (the JWT) is None.
        """
        if api_key is None:
            raise ValueError("Missing Snowflake JWT key")

        # `headers` is updated in place and also returned.
        headers.update(
            {
                "Content-Type": "application/json",
                "Accept": "application/json",
                "Authorization": "Bearer " + api_key,
                "X-Snowflake-Authorization-Token-Type": "KEYPAIR_JWT",
            }
        )
        return headers

    def _resolve_default_api_base(self) -> Optional[str]:
        """
        Resolve the endpoint to use when the caller supplied no `api_base`.

        Precedence: the `SNOWFLAKE_API_BASE` secret, then a URL derived from
        the `SNOWFLAKE_ACCOUNT_ID` secret. Returns None when neither is set,
        instead of interpolating "None" into the hostname.
        """
        configured_base = get_secret_str("SNOWFLAKE_API_BASE")
        if configured_base:
            return configured_base

        account_id = get_secret_str("SNOWFLAKE_ACCOUNT_ID")
        if not account_id:
            return None
        return f"https://{account_id}.snowflakecomputing.com/api/v2/cortex/inference:complete"

    def _get_openai_compatible_provider_info(
        self, api_base: Optional[str], api_key: Optional[str]
    ) -> Tuple[Optional[str], Optional[str]]:
        """
        Resolve the API base and JWT, falling back to configured secrets.

        Returns:
            (api_base, api_key): either element may be None when neither the
            argument nor the corresponding secret is available.
        """
        # An explicit argument wins. Previously the account-based f-string was
        # always truthy, which made the SNOWFLAKE_API_BASE fallback unreachable.
        api_base = api_base or self._resolve_default_api_base()
        dynamic_api_key = api_key or get_secret_str("SNOWFLAKE_JWT")
        return api_base, dynamic_api_key

    def get_complete_url(
        self,
        api_base: Optional[str],
        api_key: Optional[str],
        model: str,
        optional_params: dict,
        litellm_params: dict,
        stream: Optional[bool] = None,
    ) -> str:
        """
        If api_base is not provided, use the default Snowflake Cortex
        `inference:complete` endpoint (SNOWFLAKE_API_BASE, else the URL built
        from SNOWFLAKE_ACCOUNT_ID).

        Raises:
            ValueError: If no `api_base` is given and none can be resolved.
        """
        resolved_base = api_base or self._resolve_default_api_base()
        if not resolved_base:
            raise ValueError(
                "Missing Snowflake API base: pass `api_base` or set "
                "SNOWFLAKE_API_BASE / SNOWFLAKE_ACCOUNT_ID"
            )
        return resolved_base

    def transform_request(
        self,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
    ) -> dict:
        """
        Build the JSON request body for the Snowflake completion call.

        "stream" and "extra_body" are popped from `optional_params` (the
        caller's dict is mutated); the remaining optional params and the
        contents of `extra_body` are merged into the top level of the body,
        with `extra_body` keys taking precedence on conflict.
        """
        stream: bool = optional_params.pop("stream", None) or False
        # `extra_body` may be present but explicitly None; `**None` would raise.
        extra_body = optional_params.pop("extra_body", None) or {}
        return {
            "model": model,
            "messages": messages,
            "stream": stream,
            **optional_params,
            **extra_body,
        }
|
|