"""
Common base config for all LLM providers
"""

import types
from abc import ABC, abstractmethod
from typing import (
    TYPE_CHECKING,
    Any,
    AsyncIterator,
    Iterator,
    List,
    Optional,
    Type,
    Union,
)

import httpx
from pydantic import BaseModel

from litellm._logging import verbose_logger
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import ModelResponse

from ..base_utils import (
    map_developer_role_to_system_role,
    type_to_response_format_param,
)

if TYPE_CHECKING:
    from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj

    LiteLLMLoggingObj = _LiteLLMLoggingObj
else:
    LiteLLMLoggingObj = Any

class BaseLLMException(Exception):
    """
    Base exception for provider-specific LLM errors.

    Guarantees that `status_code`, `message`, and httpx `request`/`response`
    objects are always set, synthesizing placeholders when the caller does
    not supply them.
    """

    def __init__(
        self,
        status_code: int,
        message: str,
        headers: Optional[Union[dict, httpx.Headers]] = None,
        request: Optional[httpx.Request] = None,
        response: Optional[httpx.Response] = None,
    ):
        self.status_code = status_code
        self.message: str = message
        self.headers = headers
        if request:
            self.request = request
        else:
            # No request available - synthesize a placeholder so error
            # handlers can always rely on `self.request` being set.
            self.request = httpx.Request(
                method="POST", url="https://docs.litellm.ai/docs"
            )
        if response:
            self.response = response
        else:
            # Likewise, synthesize a response carrying the same status code.
            self.response = httpx.Response(
                status_code=status_code, request=self.request
            )
        super().__init__(self.message)

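# Illustrative only (not part of litellm): a sketch of how a provider
# integration might wrap an HTTP failure in `BaseLLMException`. The endpoint
# and the `provider_chat` helper below are hypothetical.
#
#   def provider_chat(client: httpx.Client, payload: dict) -> dict:
#       resp = client.post("https://api.example.com/v1/chat", json=payload)
#       if resp.status_code >= 400:
#           raise BaseLLMException(
#               status_code=resp.status_code,
#               message=resp.text,
#               headers=resp.headers,
#               request=resp.request,
#               response=resp,
#           )
#       return resp.json()
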
class BaseConfig(ABC):
    """
    Abstract base class that every provider config implements: parameter
    mapping, environment validation, and request/response transformation.
    """

    def __init__(self):
        pass

    @classmethod
    def get_config(cls):
        """
        Return the plain (non-callable, non-None) class attributes of this
        config as a dict, skipping dunder and ABC-internal entries.
        """
        return {
            k: v
            for k, v in cls.__dict__.items()
            if not k.startswith("__")
            and not k.startswith("_abc")
            and not isinstance(
                v,
                (
                    types.FunctionType,
                    types.BuiltinFunctionType,
                    classmethod,
                    staticmethod,
                ),
            )
            and v is not None
        }

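    # A minimal sketch of what `get_config` collects, using a hypothetical
    # subclass. Dunder/_abc entries, callables, and `None` values are dropped.
    #
    #   class MyProviderConfig(BaseConfig):
    #       temperature: Optional[float] = 0.7
    #       max_tokens: Optional[int] = None  # dropped: value is None
    #
    #   MyProviderConfig.get_config()  # -> {"temperature": 0.7}
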
    def get_json_schema_from_pydantic_object(
        self, response_format: Optional[Union[Type[BaseModel], dict]]
    ) -> Optional[dict]:
        """
        Convert a pydantic model (or pass through a dict) into an OpenAI-style
        `response_format` parameter.
        """
        return type_to_response_format_param(response_format=response_format)

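    # Hedged example: given a pydantic model, this returns an OpenAI-style
    # `response_format` dict (the exact shape is determined by
    # `type_to_response_format_param`). `Answer` is hypothetical.
    #
    #   class Answer(BaseModel):
    #       text: str
    #
    #   config.get_json_schema_from_pydantic_object(response_format=Answer)
    #   # -> e.g. {"type": "json_schema", "json_schema": {...}}
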
    def should_fake_stream(
        self,
        model: Optional[str],
        stream: Optional[bool],
        custom_llm_provider: Optional[str] = None,
    ) -> bool:
        """
        Returns True if the model/provider should fake stream, i.e. issue a
        non-streaming request and re-chunk the full response client-side.
        """
        return False

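    # Hypothetical override: a provider whose models do not support real
    # streaming might fake it, roughly
    #
    #   def should_fake_stream(self, model, stream, custom_llm_provider=None):
    #       return bool(stream) and model is not None and "o1" in model
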
    def translate_developer_role_to_system_role(
        self,
        messages: List[AllMessageValues],
    ) -> List[AllMessageValues]:
        """
        Translate the `developer` role to the `system` role for non-OpenAI
        providers.

        Overridden by OpenAI/Azure, which support `developer` natively.
        """
        verbose_logger.debug(
            "Translating developer role to system role for non-OpenAI providers."
        )
        return map_developer_role_to_system_role(messages=messages)

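    # Sketch of the effect (the exact mapping lives in
    # `map_developer_role_to_system_role`):
    #
    #   [{"role": "developer", "content": "Always answer in JSON."}]
    #     -> [{"role": "system", "content": "Always answer in JSON."}]
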
    def should_retry_llm_api_inside_llm_translation_on_http_error(
        self, e: httpx.HTTPStatusError, litellm_params: dict
    ) -> bool:
        """
        Returns True if the model/provider should retry the LLM API on an
        UnprocessableEntityError (HTTP 422).

        Overridden by Azure AI, where different models support different
        parameters.
        """
        return False

    def transform_request_on_unprocessable_entity_error(
        self, e: httpx.HTTPStatusError, request_data: dict
    ) -> dict:
        """
        Transform the request data on an UnprocessableEntityError, e.g. by
        dropping the parameters the provider rejected, before retrying.
        """
        return request_data

    @property
    def max_retry_on_unprocessable_entity_error(self) -> int:
        """
        Returns the max retry count for UnprocessableEntityError.

        Used only if `should_retry_llm_api_inside_llm_translation_on_http_error`
        returns True.
        """
        return 0

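    # A rough sketch (not litellm's actual call site) of how the three hooks
    # above compose inside the HTTP translation layer:
    #
    #   for _ in range(config.max_retry_on_unprocessable_entity_error + 1):
    #       try:
    #           response = client.post(url, json=request_data)
    #           response.raise_for_status()
    #           break
    #       except httpx.HTTPStatusError as e:
    #           if not config.should_retry_llm_api_inside_llm_translation_on_http_error(
    #               e, litellm_params
    #           ):
    #               raise
    #           request_data = config.transform_request_on_unprocessable_entity_error(
    #               e, request_data
    #           )
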
    @abstractmethod
    def get_supported_openai_params(self, model: str) -> list:
        """
        Return the OpenAI parameters this provider supports for `model`.
        """
        pass

    @abstractmethod
    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        """
        Map OpenAI parameters onto this provider's equivalents, dropping
        unsupported ones when `drop_params` is True.
        """
        pass

    @abstractmethod
    def validate_environment(
        self,
        headers: dict,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> dict:
        """
        Validate credentials/environment and return the headers to send.
        """
        pass

    def get_complete_url(
        self,
        api_base: str,
        model: str,
        optional_params: dict,
        stream: Optional[bool] = None,
    ) -> str:
        """
        OPTIONAL

        Get the complete URL for the request.

        Some providers need `model` embedded in `api_base`.
        """
        return api_base

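    # Hypothetical override: a provider that routes by model in the URL might
    # implement this roughly as
    #
    #   def get_complete_url(self, api_base, model, optional_params, stream=None):
    #       return f"{api_base.rstrip('/')}/models/{model}/chat"
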
    @abstractmethod
    def transform_request(
        self,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
    ) -> dict:
        """
        Transform the OpenAI-style inputs into this provider's request body.
        """
        pass

    @abstractmethod
    def transform_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: ModelResponse,
        logging_obj: LiteLLMLoggingObj,
        request_data: dict,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        encoding: Any,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> ModelResponse:
        """
        Transform the provider's raw HTTP response into a `ModelResponse`.
        """
        pass

    @abstractmethod
    def get_error_class(
        self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
    ) -> BaseLLMException:
        """
        Map an error response to this provider's `BaseLLMException` (or a
        subclass of it).
        """
        pass

    def get_model_response_iterator(
        self,
        streaming_response: Union[Iterator[str], AsyncIterator[str], ModelResponse],
        sync_stream: bool,
        json_mode: Optional[bool] = False,
    ) -> Any:
        """
        OPTIONAL

        Return an iterator that parses this provider's raw stream into
        litellm chunks. Providers that support streaming override this.
        """
        pass

    @property
    def custom_llm_provider(self) -> Optional[str]:
        return None
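
# A minimal, hypothetical concrete subclass, showing the surface a new
# provider must implement. All names, the endpoint, and the response shape
# are illustrative assumptions, not a real litellm provider.
#
#   class MyProviderConfig(BaseConfig):
#       def get_supported_openai_params(self, model: str) -> list:
#           return ["temperature", "max_tokens", "stream"]
#
#       def map_openai_params(
#           self, non_default_params, optional_params, model, drop_params
#       ):
#           supported = self.get_supported_openai_params(model)
#           for k, v in non_default_params.items():
#               if k in supported:
#                   optional_params[k] = v
#               elif not drop_params:
#                   raise ValueError(f"{k} is not supported by {model}")
#           return optional_params
#
#       def validate_environment(
#           self, headers, model, messages, optional_params,
#           api_key=None, api_base=None,
#       ):
#           if api_key is None:
#               raise ValueError("api_key is required")
#           return {**headers, "Authorization": f"Bearer {api_key}"}
#
#       def transform_request(
#           self, model, messages, optional_params, litellm_params, headers
#       ):
#           return {"model": model, "messages": messages, **optional_params}
#
#       def transform_response(
#           self, model, raw_response, model_response, logging_obj,
#           request_data, messages, optional_params, litellm_params,
#           encoding, api_key=None, json_mode=None,
#       ):
#           completion = raw_response.json()  # assumed provider JSON shape
#           model_response.choices[0].message.content = completion["output"]  # type: ignore[attr-defined]
#           return model_response
#
#       def get_error_class(self, error_message, status_code, headers):
#           return BaseLLMException(
#               status_code=status_code, message=error_message, headers=headers
#           )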