|
from abc import ABC, abstractmethod |
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union |
|
|
|
import httpx |
|
|
|
from litellm.types.rerank import OptionalRerankParams, RerankResponse |
|
|
|
from ..chat.transformation import BaseLLMException |
|
|
|
if not TYPE_CHECKING:
    # At runtime the logging class is never imported here (presumably to avoid
    # an import cycle with the logging module — TODO confirm); annotations that
    # use LiteLLMLoggingObj therefore resolve to typing.Any.
    LiteLLMLoggingObj = Any
else:
    # Static type checkers take this branch and see the concrete logging class.
    from litellm.litellm_core_utils.litellm_logging import (
        Logging as _LiteLLMLoggingObj,
    )

    LiteLLMLoggingObj = _LiteLLMLoggingObj
|
|
|
|
|
class BaseRerankConfig(ABC):
    """Abstract interface a provider implements to support reranking.

    Every method here is declared with ``@abstractmethod``, so a concrete
    subclass must override all of them before it can be instantiated. A few
    base bodies carry a fallback value that an override may reach through
    ``super()``; the remaining base bodies are empty and return ``None``.
    """

    @abstractmethod
    def validate_environment(
        self,
        headers: dict,
        model: str,
        api_key: Optional[str] = None,
    ) -> dict:
        """Return the headers dict to use for a rerank request.

        What is validated or added (presumably auth derived from ``api_key``)
        is provider-defined; the base class supplies no behavior.

        Args:
            headers: Headers passed in for the request.
            model: Model name the request targets.
            api_key: Provider API key, if one was supplied.

        Returns:
            The headers dict. The base body is an empty stub returning ``None``.
        """

    @abstractmethod
    def transform_rerank_request(
        self,
        model: str,
        optional_rerank_params: OptionalRerankParams,
        headers: dict,
    ) -> dict:
        """Build the provider-specific request body.

        Args:
            model: Model name the request targets.
            optional_rerank_params: Rerank params already mapped for this provider.
            headers: Request headers.

        Returns:
            The request payload. The base body returns a new empty dict.
        """
        return dict()

    @abstractmethod
    def transform_rerank_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: RerankResponse,
        logging_obj: LiteLLMLoggingObj,
        api_key: Optional[str] = None,
        # NOTE: these ``{}`` defaults are shared mutable objects. They are left
        # as-is because the base body never reads or mutates them, and changing
        # them to ``Optional[dict] = None`` would make existing ``dict = {}``
        # overrides incompatible with this signature for type checkers.
        request_data: dict = {},
        optional_params: dict = {},
        litellm_params: dict = {},
    ) -> RerankResponse:
        """Convert the provider's raw HTTP response into a ``RerankResponse``.

        Args:
            model: Model name the request targeted.
            raw_response: The HTTP response returned by the provider.
            model_response: Response object to populate/return.
            logging_obj: LiteLLM logging object for this call.
            api_key: Provider API key, if one was supplied.
            request_data: The request payload that was sent.
            optional_params: Provider optional params.
            litellm_params: LiteLLM-level params.

        Returns:
            The rerank response. The base body returns ``model_response``
            untouched.
        """
        return model_response

    @abstractmethod
    def get_complete_url(self, api_base: Optional[str], model: str) -> str:
        """Return the full URL the rerank request is sent to.

        This method is abstract, so subclasses must override it. The base
        body is only a fallback reachable via ``super()``. It returns
        ``api_base`` when that value is truthy and ``""`` otherwise. Some
        providers need ``model`` embedded in the URL, which is why it is
        passed in.
        """
        return api_base if api_base else ""

    @abstractmethod
    def get_supported_cohere_rerank_params(self, model: str) -> list:
        """Return the Cohere-style rerank parameter names supported for ``model``.

        The base body is an empty stub returning ``None``.
        """

    @abstractmethod
    def map_cohere_rerank_params(
        self,
        non_default_params: Optional[dict],
        model: str,
        drop_params: bool,
        query: str,
        documents: List[Union[str, Dict[str, Any]]],
        custom_llm_provider: Optional[str] = None,
        top_n: Optional[int] = None,
        rank_fields: Optional[List[str]] = None,
        return_documents: Optional[bool] = True,
        max_chunks_per_doc: Optional[int] = None,
    ) -> OptionalRerankParams:
        """Map Cohere-style rerank arguments onto this provider's params.

        Args:
            non_default_params: Params the caller set explicitly (presumably
                only non-default values — confirm against callers).
            model: Model name the request targets.
            drop_params: Provider-defined flag; presumably whether unsupported
                params are dropped rather than rejected.
            query: The rerank query.
            documents: Documents to rank, as strings or dicts.
            custom_llm_provider: Provider identifier, if any.
            top_n: Number of top results requested.
            rank_fields: Document fields to rank on.
            return_documents: Whether documents should be echoed back.
            max_chunks_per_doc: Per-document chunk limit.

        Returns:
            The mapped params. The base body is an empty stub returning ``None``.
        """

    @abstractmethod
    def get_error_class(
        self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
    ) -> BaseLLMException:
        """Return the exception object describing a failed provider response.

        Args:
            error_message: Error text from the provider.
            status_code: HTTP status code of the failed response.
            headers: Response headers.

        The base body is an empty stub returning ``None``.
        """
|
|