import asyncio
import contextvars
from functools import partial
from typing import Any, Coroutine, Dict, List, Literal, Optional, Union

import litellm
from litellm._logging import verbose_logger
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.llms.base_llm.rerank.transformation import BaseRerankConfig
from litellm.llms.bedrock.rerank.handler import BedrockRerankHandler
from litellm.llms.custom_httpx.llm_http_handler import BaseLLMHTTPHandler
from litellm.llms.jina_ai.rerank.handler import JinaAIRerank
from litellm.llms.together_ai.rerank.handler import TogetherAIRerank
from litellm.rerank_api.rerank_utils import get_optional_rerank_params
from litellm.secret_managers.main import get_secret, get_secret_str
from litellm.types.rerank import OptionalRerankParams, RerankResponse
from litellm.types.router import *
from litellm.utils import ProviderConfigManager, client, exception_type

####### MODULE-LEVEL HANDLERS ###################
# One shared handler instance per provider family; `rerank` dispatches to these.
together_rerank = TogetherAIRerank()
jina_ai_rerank = JinaAIRerank()
bedrock_rerank = BedrockRerankHandler()
base_llm_http_handler = BaseLLMHTTPHandler()
#################################################


@client
async def arerank(
    model: str,
    query: str,
    documents: List[Union[str, Dict[str, Any]]],
    custom_llm_provider: Optional[
        Literal["cohere", "together_ai", "azure_ai", "infinity", "jina_ai", "bedrock"]
    ] = None,
    top_n: Optional[int] = None,
    rank_fields: Optional[List[str]] = None,
    return_documents: Optional[bool] = None,
    max_chunks_per_doc: Optional[int] = None,
    **kwargs,
) -> Union[RerankResponse, Coroutine[Any, Any, RerankResponse]]:
    """
    Async: Reranks a list of documents based on their relevance to the query.

    Runs the synchronous `rerank` in the event loop's default executor with
    `arerank=True` set in `kwargs`, inside a copy of the current
    `contextvars` context. If `rerank` hands back a coroutine, it is awaited
    here and its result is returned; otherwise the value is returned as-is.

    Args:
        model: Model name, forwarded to `rerank`.
        query: Query the documents are ranked against.
        documents: Documents to rank (strings or dicts).
        custom_llm_provider: Optional explicit provider, forwarded to `rerank`.
        top_n: Forwarded to `rerank`.
        rank_fields: Forwarded to `rerank`.
        return_documents: Forwarded to `rerank`.
        max_chunks_per_doc: Forwarded to `rerank`.
        **kwargs: Forwarded to `rerank`; the `"arerank"` key is overwritten.

    Returns:
        The result of `rerank`, awaited when it is a coroutine.
    """
    # NOTE(review): `return_documents` defaults to None here but to True in
    # `rerank`; because it is forwarded positionally, the async path passes
    # None. Confirm whether that difference is intentional.
    kwargs["arerank"] = True

    # We are inside a coroutine, so a running loop is guaranteed to exist.
    loop = asyncio.get_running_loop()

    func = partial(
        rerank,
        model,
        query,
        documents,
        custom_llm_provider,
        top_n,
        rank_fields,
        return_documents,
        max_chunks_per_doc,
        **kwargs,
    )

    # Run the sync call inside a copy of the caller's context so that context
    # variables set by the caller are visible in the executor thread.
    ctx = contextvars.copy_context()
    func_with_context = partial(ctx.run, func)
    init_response = await loop.run_in_executor(None, func_with_context)

    if asyncio.iscoroutine(init_response):
        return await init_response
    return init_response


@client
def rerank(  # noqa: PLR0915
    model: str,
    query: str,
    documents: List[Union[str, Dict[str, Any]]],
    custom_llm_provider: Optional[
        Literal["cohere", "together_ai", "azure_ai", "infinity", "jina_ai", "bedrock"]
    ] = None,
    top_n: Optional[int] = None,
    rank_fields: Optional[List[str]] = None,
    return_documents: Optional[bool] = True,
    max_chunks_per_doc: Optional[int] = None,
    **kwargs,
) -> Union[RerankResponse, Coroutine[Any, Any, RerankResponse]]:
    """
    Reranks a list of documents based on their relevance to the query.

    Resolves the provider via `litellm.get_llm_provider`, builds the
    provider-specific optional params, updates the logging object, and then
    dispatches to the handler for the resolved provider (`cohere`,
    `azure_ai`, `infinity`, `together_ai`, `jina_ai` or `bedrock`).

    Args:
        model: Model name; may be rewritten by `litellm.get_llm_provider`.
        query: Query the documents are ranked against.
        documents: Documents to rank (strings or dicts).
        custom_llm_provider: Optional explicit provider.
        top_n: Optional rerank parameter, passed through to the provider.
        rank_fields: Optional rerank parameter, passed through to the provider.
        return_documents: Optional rerank parameter, passed through to the
            provider.
        max_chunks_per_doc: Optional rerank parameter, passed through to the
            provider.
        **kwargs: Read for `headers`, `litellm_logging_obj`,
            `litellm_call_id`, `proxy_server_request`, `model_info`,
            `metadata`, `user`, `client`, `drop_params` and the internal
            `arerank` flag (popped). The remaining entries are used to build
            `GenericLiteLLMParams` and are passed as `non_default_params`.

    Returns:
        Whatever the selected provider handler returns. The handler is told
        whether the call is async via `_is_async`.

    Raises:
        Any failure (including an unsupported provider or a missing API
        key / API base) is logged and handed to `exception_type` for mapping.
    """
    headers: Optional[dict] = kwargs.get("headers")  # type: ignore
    litellm_logging_obj: LiteLLMLoggingObj = kwargs.get("litellm_logging_obj")  # type: ignore
    litellm_call_id: Optional[str] = kwargs.get("litellm_call_id", None)
    proxy_server_request = kwargs.get("proxy_server_request", None)
    model_info = kwargs.get("model_info", None)
    metadata = kwargs.get("metadata", {})
    user = kwargs.get("user", None)
    client = kwargs.get("client", None)
    try:
        # Internal flag set by `arerank`; popped so it does not leak into the
        # params built from the remaining kwargs.
        _is_async = kwargs.pop("arerank", False) is True
        optional_params = GenericLiteLLMParams(**kwargs)

        model, _custom_llm_provider, dynamic_api_key, dynamic_api_base = (
            litellm.get_llm_provider(
                model=model,
                custom_llm_provider=custom_llm_provider,
                api_base=optional_params.api_base,
                api_key=optional_params.api_key,
            )
        )

        rerank_provider_config: BaseRerankConfig = (
            ProviderConfigManager.get_provider_rerank_config(
                model=model,
                provider=litellm.LlmProviders(_custom_llm_provider),
            )
        )

        optional_rerank_params: OptionalRerankParams = get_optional_rerank_params(
            rerank_provider_config=rerank_provider_config,
            model=model,
            drop_params=kwargs.get("drop_params") or litellm.drop_params or False,
            query=query,
            documents=documents,
            custom_llm_provider=_custom_llm_provider,
            top_n=top_n,
            rank_fields=rank_fields,
            return_documents=return_documents,
            max_chunks_per_doc=max_chunks_per_doc,
            non_default_params=kwargs,
        )

        # A timeout supplied as a string is converted to a float.
        if isinstance(optional_params.timeout, str):
            optional_params.timeout = float(optional_params.timeout)

        model_response = RerankResponse()

        litellm_logging_obj.update_environment_variables(
            model=model,
            user=user,
            optional_params=dict(optional_rerank_params),
            litellm_params={
                "litellm_call_id": litellm_call_id,
                "proxy_server_request": proxy_server_request,
                "model_info": model_info,
                "metadata": metadata,
                "preset_cache_key": None,
                "stream_response": {},
                **optional_params.model_dump(exclude_unset=True),
            },
            custom_llm_provider=_custom_llm_provider,
        )

        # Dispatch on the resolved provider.
        if _custom_llm_provider == "cohere":
            # The trailing literal guarantees a non-empty api_base, so no
            # None check is needed for this provider.
            api_base: Optional[str] = (
                dynamic_api_base
                or optional_params.api_base
                or litellm.api_base
                or get_secret("COHERE_API_BASE")  # type: ignore
                or "https://api.cohere.com"
            )

            response = base_llm_http_handler.rerank(
                model=model,
                custom_llm_provider=_custom_llm_provider,
                optional_rerank_params=optional_rerank_params,
                logging_obj=litellm_logging_obj,
                timeout=optional_params.timeout,
                api_key=dynamic_api_key or optional_params.api_key,
                api_base=api_base,
                _is_async=_is_async,
                headers=headers or litellm.headers or {},
                client=client,
                model_response=model_response,
            )
        elif _custom_llm_provider == "azure_ai":
            # api_base may still be None here; it is passed to the handler
            # unchanged.
            api_base = (
                dynamic_api_base
                or optional_params.api_base
                or litellm.api_base
                or get_secret("AZURE_AI_API_BASE")  # type: ignore
            )
            response = base_llm_http_handler.rerank(
                model=model,
                custom_llm_provider=_custom_llm_provider,
                optional_rerank_params=optional_rerank_params,
                logging_obj=litellm_logging_obj,
                timeout=optional_params.timeout,
                api_key=dynamic_api_key or optional_params.api_key,
                api_base=api_base,
                _is_async=_is_async,
                headers=headers or litellm.headers or {},
                client=client,
                model_response=model_response,
            )
        elif _custom_llm_provider == "infinity":
            api_base = (
                dynamic_api_base
                or optional_params.api_base
                or litellm.api_base
                or get_secret_str("INFINITY_API_BASE")
            )

            # Unlike cohere there is no built-in default endpoint.
            if api_base is None:
                raise Exception(
                    "Invalid api base. api_base=None. Set in call or via `INFINITY_API_BASE` env var."
                )

            response = base_llm_http_handler.rerank(
                model=model,
                custom_llm_provider=_custom_llm_provider,
                optional_rerank_params=optional_rerank_params,
                logging_obj=litellm_logging_obj,
                timeout=optional_params.timeout,
                api_key=dynamic_api_key or optional_params.api_key,
                api_base=api_base,
                _is_async=_is_async,
                headers=headers or litellm.headers or {},
                client=client,
                model_response=model_response,
            )
        elif _custom_llm_provider == "together_ai":
            api_key: Optional[str] = (
                dynamic_api_key
                or optional_params.api_key
                or litellm.togetherai_api_key
                or get_secret("TOGETHERAI_API_KEY")  # type: ignore
                or litellm.api_key
            )

            if api_key is None:
                raise ValueError(
                    "TogetherAI API key is required, please set 'TOGETHERAI_API_KEY' in your environment"
                )

            response = together_rerank.rerank(
                model=model,
                query=query,
                documents=documents,
                top_n=top_n,
                rank_fields=rank_fields,
                return_documents=return_documents,
                max_chunks_per_doc=max_chunks_per_doc,
                api_key=api_key,
                _is_async=_is_async,
            )
        elif _custom_llm_provider == "jina_ai":
            # Only the key resolved by get_llm_provider is accepted here.
            if dynamic_api_key is None:
                raise ValueError(
                    "Jina AI API key is required, please set 'JINA_AI_API_KEY' in your environment"
                )
            response = jina_ai_rerank.rerank(
                model=model,
                api_key=dynamic_api_key,
                query=query,
                documents=documents,
                top_n=top_n,
                rank_fields=rank_fields,
                return_documents=return_documents,
                max_chunks_per_doc=max_chunks_per_doc,
                _is_async=_is_async,
            )
        elif _custom_llm_provider == "bedrock":
            api_base = (
                dynamic_api_base
                or optional_params.api_base
                or litellm.api_base
                or get_secret("BEDROCK_API_BASE")  # type: ignore
            )

            response = bedrock_rerank.rerank(
                model=model,
                query=query,
                documents=documents,
                top_n=top_n,
                rank_fields=rank_fields,
                return_documents=return_documents,
                max_chunks_per_doc=max_chunks_per_doc,
                _is_async=_is_async,
                optional_params=optional_params.model_dump(exclude_unset=True),
                api_base=api_base,
                logging_obj=litellm_logging_obj,
            )
        else:
            raise ValueError(f"Unsupported provider: {_custom_llm_provider}")

        return response
    except Exception as e:
        verbose_logger.error(f"Error in rerank: {str(e)}")
        # NOTE(review): this passes the caller-supplied `custom_llm_provider`
        # (None when the provider was inferred from the model name), not the
        # resolved `_custom_llm_provider`. Confirm whether that is intended.
        raise exception_type(
            model=model, custom_llm_provider=custom_llm_provider, original_exception=e
        )