from typing import List, Literal, Optional, Tuple, Union, cast

import litellm
from litellm.secret_managers.main import get_secret_str
from litellm.types.llms.openai import AllMessageValues, ChatCompletionImageObject
from litellm.types.utils import ProviderSpecificModelInfo

from ...openai.chat.gpt_transformation import OpenAIGPTConfig


class FireworksAIConfig(OpenAIGPTConfig):
    """
    Reference: https://docs.fireworks.ai/api-reference/post-chatcompletions

    The class `FireworksAIConfig` provides configuration for the Fireworks AI
    Chat Completions API. The parameters below mirror the OpenAI
    chat-completions parameters, plus the Fireworks-specific
    `prompt_truncate_length` and `context_length_exceeded_behavior`.
    """

    tools: Optional[list] = None
    tool_choice: Optional[Union[str, dict]] = None
    max_tokens: Optional[int] = None
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    top_k: Optional[int] = None
    frequency_penalty: Optional[float] = None
    presence_penalty: Optional[float] = None
    n: Optional[int] = None
    stop: Optional[Union[str, list]] = None
    response_format: Optional[dict] = None
    user: Optional[str] = None
    logprobs: Optional[int] = None

    # Fireworks AI-only parameters
    prompt_truncate_length: Optional[int] = None
    context_length_exceeded_behavior: Optional[Literal["error", "truncate"]] = None

    def __init__(
        self,
        tools: Optional[list] = None,
        tool_choice: Optional[Union[str, dict]] = None,
        max_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        frequency_penalty: Optional[float] = None,
        presence_penalty: Optional[float] = None,
        n: Optional[int] = None,
        stop: Optional[Union[str, list]] = None,
        response_format: Optional[dict] = None,
        user: Optional[str] = None,
        logprobs: Optional[int] = None,
        prompt_truncate_length: Optional[int] = None,
        context_length_exceeded_behavior: Optional[
            Literal["error", "truncate"]
        ] = None,
    ) -> None:
        locals_ = locals().copy()
        for key, value in locals_.items():
            # Persist every non-None argument on the class so that
            # get_config() reflects user-supplied overrides.
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

    @classmethod
    def get_config(cls):
        return super().get_config()

    def get_supported_openai_params(self, model: str):
        return [
            "stream",
            "tools",
            "tool_choice",
            "max_completion_tokens",
            "max_tokens",
            "temperature",
            "top_p",
            "top_k",
            "frequency_penalty",
            "presence_penalty",
            "n",
            "stop",
            "response_format",
            "user",
            "logprobs",
            "prompt_truncate_length",
            "context_length_exceeded_behavior",
        ]
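
    # Hedged sketch (not part of the original module): the supported-params
    # list can be queried before mapping, e.g.:
    #
    #   config = FireworksAIConfig()
    #   "top_k" in config.get_supported_openai_params(model="llama-v3p1-8b-instruct")
    #   # -> True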

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        supported_openai_params = self.get_supported_openai_params(model=model)
        for param, value in non_default_params.items():
            if param == "tool_choice":
                if value == "required":
                    # Fireworks AI expects "any" where OpenAI uses "required".
                    optional_params["tool_choice"] = "any"
                else:
                    # Pass other tool_choice values through unchanged.
                    optional_params["tool_choice"] = value
            elif (
                param == "response_format" and value.get("type", None) == "json_schema"
            ):
                # Fireworks AI nests the JSON schema under a "json_object"
                # response format.
                optional_params["response_format"] = {
                    "type": "json_object",
                    "schema": value["json_schema"]["schema"],
                }
            elif param == "max_completion_tokens":
                optional_params["max_tokens"] = value
            elif param in supported_openai_params:
                if value is not None:
                    optional_params[param] = value
        return optional_params
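
    # Hedged example (not part of the original module) of the mapping above:
    #
    #   FireworksAIConfig().map_openai_params(
    #       non_default_params={"tool_choice": "required", "max_completion_tokens": 256},
    #       optional_params={},
    #       model="llama-v3p1-8b-instruct",  # hypothetical model id
    #       drop_params=False,
    #   )
    #   # -> {"tool_choice": "any", "max_tokens": 256}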

    def _add_transform_inline_image_block(
        self,
        content: ChatCompletionImageObject,
        model: str,
        disable_add_transform_inline_image_block: Optional[bool],
    ) -> ChatCompletionImageObject:
        """
        Add '#transform=inline' to the image_url (allows non-vision models to
        parse documents/images/etc.).

        - ignore if model is a vision model
        - ignore if user has disabled this feature
        """
        if "vision" in model or disable_add_transform_inline_image_block:
            return content
        if isinstance(content["image_url"], str):
            content["image_url"] = f"{content['image_url']}#transform=inline"
        elif isinstance(content["image_url"], dict):
            content["image_url"]["url"] = f"{content['image_url']['url']}#transform=inline"
        return content
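
    # Hedged example (not part of the original module) of the URL rewrite:
    #
    #   block = {"type": "image_url", "image_url": "https://example.com/doc.pdf"}
    #   FireworksAIConfig()._add_transform_inline_image_block(
    #       content=block,
    #       model="llama-v3p1-8b-instruct",  # hypothetical non-vision model id
    #       disable_add_transform_inline_image_block=None,
    #   )
    #   # -> {"type": "image_url", "image_url": "https://example.com/doc.pdf#transform=inline"}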

    def _transform_messages_helper(
        self, messages: List[AllMessageValues], model: str, litellm_params: dict
    ) -> List[AllMessageValues]:
        """
        Append '#transform=inline' to the url of each image_url block in user messages.
        """
        disable_add_transform_inline_image_block = cast(
            Optional[bool],
            litellm_params.get("disable_add_transform_inline_image_block")
            or litellm.disable_add_transform_inline_image_block,
        )
        for message in messages:
            if message["role"] == "user":
                _message_content = message.get("content")
                if _message_content is not None and isinstance(_message_content, list):
                    for content in _message_content:
                        if content["type"] == "image_url":
                            # Mutates the content block in place.
                            content = self._add_transform_inline_image_block(
                                content=content,
                                model=model,
                                disable_add_transform_inline_image_block=disable_add_transform_inline_image_block,
                            )
        return messages
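
    # Hedged usage note (not part of the original module): either a truthy
    # per-request flag or the module-level flag disables the rewrite, e.g.:
    #
    #   litellm_params = {"disable_add_transform_inline_image_block": True}
    #   # leaves every image_url block untouched, regardless of
    #   # litellm.disable_add_transform_inline_image_block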

    def get_provider_info(self, model: str) -> ProviderSpecificModelInfo:
        provider_specific_model_info = ProviderSpecificModelInfo(
            supports_function_calling=True,
            supports_prompt_caching=True,
            supports_pdf_input=True,
            supports_vision=True,
        )
        return provider_specific_model_info

    def transform_request(
        self,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
    ) -> dict:
        # Fireworks AI expects fully-qualified model names of the form
        # "accounts/<account>/models/<model>".
        if not model.startswith("accounts/"):
            model = f"accounts/fireworks/models/{model}"
        messages = self._transform_messages_helper(
            messages=messages, model=model, litellm_params=litellm_params
        )
        return super().transform_request(
            model=model,
            messages=messages,
            optional_params=optional_params,
            litellm_params=litellm_params,
            headers=headers,
        )
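
    # Hedged example (not part of the original module) of the prefixing above:
    #
    #   "llama-v3p1-8b-instruct"            -> "accounts/fireworks/models/llama-v3p1-8b-instruct"
    #   "accounts/acme/models/custom-model" -> unchanged (already fully qualified)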

    def _get_openai_compatible_provider_info(
        self, api_base: Optional[str], api_key: Optional[str]
    ) -> Tuple[Optional[str], Optional[str]]:
        api_base = (
            api_base
            or get_secret_str("FIREWORKS_API_BASE")
            or "https://api.fireworks.ai/inference/v1"
        )
        dynamic_api_key = api_key or (
            get_secret_str("FIREWORKS_API_KEY")
            or get_secret_str("FIREWORKS_AI_API_KEY")
            or get_secret_str("FIREWORKSAI_API_KEY")
            or get_secret_str("FIREWORKS_AI_TOKEN")
        )
        return api_base, dynamic_api_key
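
    # Hedged usage sketch (not part of the original module): resolution order is
    # explicit argument, then secret manager / environment, then the default base:
    #
    #   import os
    #   os.environ["FIREWORKS_AI_API_KEY"] = "fw-..."  # any of the four names works
    #   FireworksAIConfig()._get_openai_compatible_provider_info(api_base=None, api_key=None)
    #   # -> ("https://api.fireworks.ai/inference/v1", "fw-...")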

    def get_models(self, api_key: Optional[str] = None, api_base: Optional[str] = None):
        api_base, api_key = self._get_openai_compatible_provider_info(
            api_base=api_base, api_key=api_key
        )
        if api_base is None or api_key is None:
            raise ValueError(
                "FIREWORKS_API_BASE or FIREWORKS_API_KEY is not set. Please set the environment variable to query Fireworks AI's `/models` endpoint."
            )

        account_id = get_secret_str("FIREWORKS_ACCOUNT_ID")
        if account_id is None:
            raise ValueError(
                "FIREWORKS_ACCOUNT_ID is not set. Please set the environment variable to query Fireworks AI's `/models` endpoint."
            )

        response = litellm.module_level_client.get(
            url=f"{api_base}/v1/accounts/{account_id}/models",
            headers={"Authorization": f"Bearer {api_key}"},
        )

        if response.status_code != 200:
            raise ValueError(
                f"Failed to fetch models from Fireworks AI. Status code: {response.status_code}, Response: {response.json()}"
            )

        models = response.json()["models"]

        return ["fireworks_ai/" + model["name"] for model in models]

    @staticmethod
    def get_api_key(api_key: Optional[str] = None) -> Optional[str]:
        return api_key or (
            get_secret_str("FIREWORKS_API_KEY")
            or get_secret_str("FIREWORKS_AI_API_KEY")
            or get_secret_str("FIREWORKSAI_API_KEY")
            or get_secret_str("FIREWORKS_AI_TOKEN")
        )