import ast
import hashlib
import json
import time
import traceback
from enum import Enum
from typing import Any, Dict, List, Optional, Set, Union

from openai.types.audio.transcription_create_params import TranscriptionCreateParams
from openai.types.chat.completion_create_params import (
    CompletionCreateParamsNonStreaming,
    CompletionCreateParamsStreaming,
)
from openai.types.completion_create_params import (
    CompletionCreateParamsNonStreaming as TextCompletionCreateParamsNonStreaming,
)
from openai.types.completion_create_params import (
    CompletionCreateParamsStreaming as TextCompletionCreateParamsStreaming,
)
from openai.types.embedding_create_params import EmbeddingCreateParams
from pydantic import BaseModel

import litellm
from litellm._logging import verbose_logger
from litellm.types.caching import *
from litellm.types.rerank import RerankRequest
from litellm.types.utils import all_litellm_params

from .base_cache import BaseCache
from .disk_cache import DiskCache
from .dual_cache import DualCache
from .in_memory_cache import InMemoryCache
from .qdrant_semantic_cache import QdrantSemanticCache
from .redis_cache import RedisCache
from .redis_semantic_cache import RedisSemanticCache
from .s3_cache import S3Cache


def print_verbose(print_statement):
    try:
        verbose_logger.debug(print_statement)
        if litellm.set_verbose:
            print(print_statement)
    except Exception:
        pass


class CacheMode(str, Enum):
    default_on = "default_on"
    default_off = "default_off"


class Cache:
    def __init__(
        self,
        type: Optional[LiteLLMCacheType] = LiteLLMCacheType.LOCAL,
        mode: Optional[CacheMode] = CacheMode.default_on,
        host: Optional[str] = None,
        port: Optional[str] = None,
        password: Optional[str] = None,
        namespace: Optional[str] = None,
        ttl: Optional[float] = None,
        default_in_memory_ttl: Optional[float] = None,
        default_in_redis_ttl: Optional[float] = None,
        similarity_threshold: Optional[float] = None,
        supported_call_types: Optional[List[CachingSupportedCallTypes]] = [
            "completion",
            "acompletion",
            "embedding",
            "aembedding",
            "atranscription",
            "transcription",
            "atext_completion",
            "text_completion",
            "arerank",
            "rerank",
        ],
        s3_bucket_name: Optional[str] = None,
        s3_region_name: Optional[str] = None,
        s3_api_version: Optional[str] = None,
        s3_use_ssl: Optional[bool] = True,
        s3_verify: Optional[Union[bool, str]] = None,
        s3_endpoint_url: Optional[str] = None,
        s3_aws_access_key_id: Optional[str] = None,
        s3_aws_secret_access_key: Optional[str] = None,
        s3_aws_session_token: Optional[str] = None,
        s3_config: Optional[Any] = None,
        s3_path: Optional[str] = None,
        redis_semantic_cache_use_async=False,
        redis_semantic_cache_embedding_model="text-embedding-ada-002",
        redis_flush_size: Optional[int] = None,
        redis_startup_nodes: Optional[List] = None,
        disk_cache_dir=None,
        qdrant_api_base: Optional[str] = None,
        qdrant_api_key: Optional[str] = None,
        qdrant_collection_name: Optional[str] = None,
        qdrant_quantization_config: Optional[str] = None,
        qdrant_semantic_cache_embedding_model="text-embedding-ada-002",
        **kwargs,
    ):
""" |
|
Initializes the cache based on the given type. |
|
|
|
Args: |
|
type (str, optional): The type of cache to initialize. Can be "local", "redis", "redis-semantic", "qdrant-semantic", "s3" or "disk". Defaults to "local". |
|
|
|
# Redis Cache Args |
|
host (str, optional): The host address for the Redis cache. Required if type is "redis". |
|
port (int, optional): The port number for the Redis cache. Required if type is "redis". |
|
password (str, optional): The password for the Redis cache. Required if type is "redis". |
|
namespace (str, optional): The namespace for the Redis cache. Required if type is "redis". |
|
ttl (float, optional): The ttl for the Redis cache |
|
redis_flush_size (int, optional): The number of keys to flush at a time. Defaults to 1000. Only used if batch redis set caching is used. |
|
redis_startup_nodes (list, optional): The list of startup nodes for the Redis cache. Defaults to None. |
|
|
|
# Qdrant Cache Args |
|
qdrant_api_base (str, optional): The url for your qdrant cluster. Required if type is "qdrant-semantic". |
|
qdrant_api_key (str, optional): The api_key for the local or cloud qdrant cluster. |
|
qdrant_collection_name (str, optional): The name for your qdrant collection. Required if type is "qdrant-semantic". |
|
similarity_threshold (float, optional): The similarity threshold for semantic-caching, Required if type is "redis-semantic" or "qdrant-semantic". |
|
|
|
# Disk Cache Args |
|
disk_cache_dir (str, optional): The directory for the disk cache. Defaults to None. |
|
|
|
# S3 Cache Args |
|
s3_bucket_name (str, optional): The bucket name for the s3 cache. Defaults to None. |
|
s3_region_name (str, optional): The region name for the s3 cache. Defaults to None. |
|
s3_api_version (str, optional): The api version for the s3 cache. Defaults to None. |
|
s3_use_ssl (bool, optional): The use ssl for the s3 cache. Defaults to True. |
|
s3_verify (bool, optional): The verify for the s3 cache. Defaults to None. |
|
s3_endpoint_url (str, optional): The endpoint url for the s3 cache. Defaults to None. |
|
s3_aws_access_key_id (str, optional): The aws access key id for the s3 cache. Defaults to None. |
|
s3_aws_secret_access_key (str, optional): The aws secret access key for the s3 cache. Defaults to None. |
|
s3_aws_session_token (str, optional): The aws session token for the s3 cache. Defaults to None. |
|
s3_config (dict, optional): The config for the s3 cache. Defaults to None. |
|
|
|
# Common Cache Args |
|
supported_call_types (list, optional): List of call types to cache for. Defaults to cache == on for all call types. |
|
**kwargs: Additional keyword arguments for redis.Redis() cache |
|
|
|
Raises: |
|
ValueError: If an invalid cache type is provided. |
|
|
|
Returns: |
|
None. Cache is set as a litellm param |
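
        Example:
            Illustrative usage only; the Redis host, port, and password below are
            placeholder values for a locally running Redis instance:

                import litellm

                litellm.cache = Cache(
                    type="redis",
                    host="localhost",
                    port="6379",
                    password="<your-redis-password>",
                    ttl=600,  # cached entries expire after 10 minutes
                )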
        """
        if type == LiteLLMCacheType.REDIS:
            self.cache: BaseCache = RedisCache(
                host=host,
                port=port,
                password=password,
                redis_flush_size=redis_flush_size,
                startup_nodes=redis_startup_nodes,
                **kwargs,
            )
        elif type == LiteLLMCacheType.REDIS_SEMANTIC:
            self.cache = RedisSemanticCache(
                host=host,
                port=port,
                password=password,
                similarity_threshold=similarity_threshold,
                use_async=redis_semantic_cache_use_async,
                embedding_model=redis_semantic_cache_embedding_model,
                **kwargs,
            )
        elif type == LiteLLMCacheType.QDRANT_SEMANTIC:
            self.cache = QdrantSemanticCache(
                qdrant_api_base=qdrant_api_base,
                qdrant_api_key=qdrant_api_key,
                collection_name=qdrant_collection_name,
                similarity_threshold=similarity_threshold,
                quantization_config=qdrant_quantization_config,
                embedding_model=qdrant_semantic_cache_embedding_model,
            )
        elif type == LiteLLMCacheType.LOCAL:
            self.cache = InMemoryCache()
        elif type == LiteLLMCacheType.S3:
            self.cache = S3Cache(
                s3_bucket_name=s3_bucket_name,
                s3_region_name=s3_region_name,
                s3_api_version=s3_api_version,
                s3_use_ssl=s3_use_ssl,
                s3_verify=s3_verify,
                s3_endpoint_url=s3_endpoint_url,
                s3_aws_access_key_id=s3_aws_access_key_id,
                s3_aws_secret_access_key=s3_aws_secret_access_key,
                s3_aws_session_token=s3_aws_session_token,
                s3_config=s3_config,
                s3_path=s3_path,
                **kwargs,
            )
        elif type == LiteLLMCacheType.DISK:
            self.cache = DiskCache(disk_cache_dir=disk_cache_dir)

        if "cache" not in litellm.input_callback:
            litellm.input_callback.append("cache")
        if "cache" not in litellm.success_callback:
            litellm.logging_callback_manager.add_litellm_success_callback("cache")
        if "cache" not in litellm._async_success_callback:
            litellm.logging_callback_manager.add_litellm_async_success_callback(
                "cache"
            )

        self.supported_call_types = supported_call_types
        self.type = type
        self.namespace = namespace
        self.redis_flush_size = redis_flush_size
        self.ttl = ttl
        self.mode: CacheMode = mode or CacheMode.default_on

        if self.type == LiteLLMCacheType.LOCAL and default_in_memory_ttl is not None:
            self.ttl = default_in_memory_ttl

        if (
            self.type == LiteLLMCacheType.REDIS
            or self.type == LiteLLMCacheType.REDIS_SEMANTIC
        ) and default_in_redis_ttl is not None:
            self.ttl = default_in_redis_ttl

        if self.namespace is not None and isinstance(self.cache, RedisCache):
            self.cache.namespace = self.namespace

    def get_cache_key(self, **kwargs) -> str:
        """
        Get the cache key for the given arguments.

        Args:
            **kwargs: kwargs to litellm.completion() or embedding()

        Returns:
            str: The cache key generated from the arguments.
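
        Example:
            Illustrative sketch; the exact hex digest depends on the kwargs passed:

                cache = Cache(type="local")
                key = cache.get_cache_key(
                    model="gpt-3.5-turbo",
                    messages=[{"role": "user", "content": "hello"}],
                )
                # `key` is a SHA-256 hex digest of the relevant call params,
                # optionally prefixed with the configured namespace.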
""" |
        cache_key = ""

        preset_cache_key = self._get_preset_cache_key_from_kwargs(**kwargs)
        if preset_cache_key is not None:
            verbose_logger.debug("\nReturning preset cache key: %s", preset_cache_key)
            return preset_cache_key

        combined_kwargs = self._get_relevant_args_to_use_for_cache_key()
        litellm_param_kwargs = all_litellm_params
        for param in kwargs:
            if param in combined_kwargs:
                param_value: Optional[str] = self._get_param_value(param, kwargs)
                if param_value is not None:
                    cache_key += f"{str(param)}: {str(param_value)}"
            elif param not in litellm_param_kwargs:
                if litellm.enable_caching_on_provider_specific_optional_params is True:
                    if kwargs[param] is None:
                        continue
                    param_value = kwargs[param]
                    cache_key += f"{str(param)}: {str(param_value)}"

        verbose_logger.debug("\nCreated cache key: %s", cache_key)
        hashed_cache_key = Cache._get_hashed_cache_key(cache_key)
        hashed_cache_key = self._add_redis_namespace_to_cache_key(
            hashed_cache_key, **kwargs
        )
        self._set_preset_cache_key_in_kwargs(
            preset_cache_key=hashed_cache_key, **kwargs
        )
        return hashed_cache_key

    def _get_param_value(
        self,
        param: str,
        kwargs: dict,
    ) -> Optional[str]:
        """
        Get the value for the given param from kwargs
        """
        if param == "model":
            return self._get_model_param_value(kwargs)
        elif param == "file":
            return self._get_file_param_value(kwargs)
        return kwargs[param]

    def _get_model_param_value(self, kwargs: dict) -> str:
        """
        Handles getting the value for the 'model' param from kwargs

        1. If caching groups are set, then return the caching group as the model https://docs.litellm.ai/docs/routing#caching-across-model-groups
        2. Else if a model_group is set, then return the model_group as the model. This is used for all requests sent through the litellm.Router()
        3. Else use the `model` passed in kwargs
        """
        metadata: Dict = kwargs.get("metadata", {}) or {}
        litellm_params: Dict = kwargs.get("litellm_params", {}) or {}
        metadata_in_litellm_params: Dict = litellm_params.get("metadata", {}) or {}
        model_group: Optional[str] = metadata.get(
            "model_group"
        ) or metadata_in_litellm_params.get("model_group")
        caching_group = self._get_caching_group(metadata, model_group)
        return caching_group or model_group or kwargs["model"]

    def _get_caching_group(
        self, metadata: dict, model_group: Optional[str]
    ) -> Optional[str]:
        caching_groups: Optional[List] = metadata.get("caching_groups", [])
        if caching_groups:
            for group in caching_groups:
                if model_group in group:
                    return str(group)
        return None

    def _get_file_param_value(self, kwargs: dict) -> str:
        """
        Handles getting the value for the 'file' param from kwargs. Used for `transcription` requests
        """
        file = kwargs.get("file")
        metadata = kwargs.get("metadata", {})
        litellm_params = kwargs.get("litellm_params", {})
        return (
            metadata.get("file_checksum")
            or getattr(file, "name", None)
            or metadata.get("file_name")
            or litellm_params.get("file_name")
        )

    def _get_preset_cache_key_from_kwargs(self, **kwargs) -> Optional[str]:
        """
        Get the preset cache key from kwargs["litellm_params"]

        We use _get_preset_cache_keys for two reasons

        1. optional params like max_tokens, get transformed for bedrock -> max_new_tokens
        2. avoid doing duplicate / repeated work
        """
        if kwargs:
            if "litellm_params" in kwargs:
                return kwargs["litellm_params"].get("preset_cache_key", None)
        return None

    def _set_preset_cache_key_in_kwargs(self, preset_cache_key: str, **kwargs) -> None:
        """
        Set the calculated cache key in kwargs

        This is used to avoid doing duplicate / repeated work

        Placed in kwargs["litellm_params"]
        """
        if kwargs:
            if "litellm_params" in kwargs:
                kwargs["litellm_params"]["preset_cache_key"] = preset_cache_key

    def _get_relevant_args_to_use_for_cache_key(self) -> Set[str]:
        """
        Gets the supported kwargs for each call type and combines them
        """
        chat_completion_kwargs = self._get_litellm_supported_chat_completion_kwargs()
        text_completion_kwargs = self._get_litellm_supported_text_completion_kwargs()
        embedding_kwargs = self._get_litellm_supported_embedding_kwargs()
        transcription_kwargs = self._get_litellm_supported_transcription_kwargs()
        rerank_kwargs = self._get_litellm_supported_rerank_kwargs()
        exclude_kwargs = self._get_kwargs_to_exclude_from_cache_key()

        combined_kwargs = chat_completion_kwargs.union(
            text_completion_kwargs,
            embedding_kwargs,
            transcription_kwargs,
            rerank_kwargs,
        )
        combined_kwargs = combined_kwargs.difference(exclude_kwargs)
        return combined_kwargs

    def _get_litellm_supported_chat_completion_kwargs(self) -> Set[str]:
        """
        Get the litellm supported chat completion kwargs

        This follows the OpenAI API Spec
        """
        all_chat_completion_kwargs = set(
            CompletionCreateParamsNonStreaming.__annotations__.keys()
        ).union(set(CompletionCreateParamsStreaming.__annotations__.keys()))
        return all_chat_completion_kwargs

    def _get_litellm_supported_text_completion_kwargs(self) -> Set[str]:
        """
        Get the litellm supported text completion kwargs

        This follows the OpenAI API Spec
        """
        all_text_completion_kwargs = set(
            TextCompletionCreateParamsNonStreaming.__annotations__.keys()
        ).union(set(TextCompletionCreateParamsStreaming.__annotations__.keys()))
        return all_text_completion_kwargs

    def _get_litellm_supported_rerank_kwargs(self) -> Set[str]:
        """
        Get the litellm supported rerank kwargs
        """
        return set(RerankRequest.model_fields.keys())

    def _get_litellm_supported_embedding_kwargs(self) -> Set[str]:
        """
        Get the litellm supported embedding kwargs

        This follows the OpenAI API Spec
        """
        return set(EmbeddingCreateParams.__annotations__.keys())

    def _get_litellm_supported_transcription_kwargs(self) -> Set[str]:
        """
        Get the litellm supported transcription kwargs

        This follows the OpenAI API Spec
        """
        return set(TranscriptionCreateParams.__annotations__.keys())

    def _get_kwargs_to_exclude_from_cache_key(self) -> Set[str]:
        """
        Get the kwargs to exclude from the cache key
        """
        return set(["metadata"])

    @staticmethod
    def _get_hashed_cache_key(cache_key: str) -> str:
        """
        Get the hashed cache key for the given cache key.

        Use hashlib to create a sha256 hash of the cache key

        Args:
            cache_key (str): The cache key to hash.

        Returns:
            str: The hashed cache key.
        """
        hash_object = hashlib.sha256(cache_key.encode())
        hash_hex = hash_object.hexdigest()
        verbose_logger.debug("Hashed cache key (SHA-256): %s", hash_hex)
        return hash_hex

    def _add_redis_namespace_to_cache_key(self, hash_hex: str, **kwargs) -> str:
        """
        If a redis namespace is provided, add it to the cache key

        Args:
            hash_hex (str): The hashed cache key.
            **kwargs: Additional keyword arguments.

        Returns:
            str: The final hashed cache key with the redis namespace.
        """
        namespace = kwargs.get("metadata", {}).get("redis_namespace") or self.namespace
        if namespace:
            hash_hex = f"{namespace}:{hash_hex}"
        verbose_logger.debug("Final hashed key: %s", hash_hex)
        return hash_hex

    def generate_streaming_content(self, content):
        chunk_size = 5
        for i in range(0, len(content), chunk_size):
            yield {
                "choices": [
                    {
                        "delta": {
                            "role": "assistant",
                            "content": content[i : i + chunk_size],
                        }
                    }
                ]
            }
            time.sleep(0.02)

    def _get_cache_logic(
        self,
        cached_result: Optional[Any],
        max_age: Optional[float],
    ):
        """
        Common get cache logic across sync + async implementations
        """
        # Check if a timestamp was stored alongside the cached response
        if (
            cached_result is not None
            and isinstance(cached_result, dict)
            and "timestamp" in cached_result
        ):
            timestamp = cached_result["timestamp"]
            current_time = time.time()

            # Calculate the age of the cached response
            response_age = current_time - timestamp

            # Discard the cached response if it is older than the allowed max age
            if max_age is not None and response_age > max_age:
                return None

            # Deserialize the cached response (stored as a dict or a JSON string)
            cached_response = cached_result.get("response")
            try:
                if isinstance(cached_response, dict):
                    pass
                else:
                    cached_response = json.loads(cached_response)
            except Exception:
                cached_response = ast.literal_eval(cached_response)
            return cached_response
        return cached_result

    def get_cache(self, **kwargs):
        """
        Retrieves the cached result for the given arguments.

        Args:
            **kwargs: kwargs to litellm.completion() or embedding()

        Returns:
            The cached result if it exists, otherwise None.
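
        Example:
            Illustrative sketch; a per-request `cache` dict can bound how old a
            cached response may be via `s-maxage` (or `s-max-age`), in seconds:

                cached = litellm.cache.get_cache(
                    model="gpt-3.5-turbo",
                    messages=[{"role": "user", "content": "hello"}],
                    cache={"s-maxage": 60},  # ignore entries older than 60 seconds
                )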
""" |
        try:
            if self.should_use_cache(**kwargs) is not True:
                return
            messages = kwargs.get("messages", [])
            if "cache_key" in kwargs:
                cache_key = kwargs["cache_key"]
            else:
                cache_key = self.get_cache_key(**kwargs)
            if cache_key is not None:
                cache_control_args = kwargs.get("cache", {})
                max_age = cache_control_args.get(
                    "s-max-age", cache_control_args.get("s-maxage", float("inf"))
                )
                cached_result = self.cache.get_cache(cache_key, messages=messages)
                return self._get_cache_logic(
                    cached_result=cached_result, max_age=max_age
                )
        except Exception:
            print_verbose(f"An exception occurred: {traceback.format_exc()}")
            return None

    async def async_get_cache(self, **kwargs):
        """
        Async get cache implementation.

        Used for embedding calls in async wrapper
        """
        try:
            if self.should_use_cache(**kwargs) is not True:
                return

            kwargs.get("messages", [])
            if "cache_key" in kwargs:
                cache_key = kwargs["cache_key"]
            else:
                cache_key = self.get_cache_key(**kwargs)
            if cache_key is not None:
                cache_control_args = kwargs.get("cache", {})
                max_age = cache_control_args.get(
                    "s-max-age", cache_control_args.get("s-maxage", float("inf"))
                )
                cached_result = await self.cache.async_get_cache(cache_key, **kwargs)
                return self._get_cache_logic(
                    cached_result=cached_result, max_age=max_age
                )
        except Exception:
            print_verbose(f"An exception occurred: {traceback.format_exc()}")
            return None

    def _add_cache_logic(self, result, **kwargs):
        """
        Common implementation across sync + async add_cache functions
        """
        try:
            if "cache_key" in kwargs:
                cache_key = kwargs["cache_key"]
            else:
                cache_key = self.get_cache_key(**kwargs)
            if cache_key is not None:
                if isinstance(result, BaseModel):
                    result = result.model_dump_json()

                # Apply the default TTL configured on the Cache object
                if self.ttl is not None:
                    kwargs["ttl"] = self.ttl

                # Per-request cache-control overrides, e.g. cache={"ttl": 30}
                _cache_kwargs = kwargs.get("cache", None)
                if isinstance(_cache_kwargs, dict):
                    for k, v in _cache_kwargs.items():
                        if k == "ttl":
                            kwargs["ttl"] = v

                cached_data = {"timestamp": time.time(), "response": result}
                return cache_key, cached_data, kwargs
            else:
                raise Exception("cache key is None")
        except Exception as e:
            raise e

    def add_cache(self, result, **kwargs):
        """
        Adds a result to the cache.

        Args:
            **kwargs: kwargs to litellm.completion() or embedding()

        Returns:
            None
        """
        try:
            if self.should_use_cache(**kwargs) is not True:
                return
            cache_key, cached_data, kwargs = self._add_cache_logic(
                result=result, **kwargs
            )
            self.cache.set_cache(cache_key, cached_data, **kwargs)
        except Exception as e:
            verbose_logger.exception(f"LiteLLM Cache: Exception in add_cache: {str(e)}")

    async def async_add_cache(self, result, **kwargs):
        """
        Async implementation of add_cache
        """
        try:
            if self.should_use_cache(**kwargs) is not True:
                return
            if self.type == "redis" and self.redis_flush_size is not None:
                # High traffic: buffer results in memory and flush them in batches
                await self.batch_cache_write(result, **kwargs)
            else:
                cache_key, cached_data, kwargs = self._add_cache_logic(
                    result=result, **kwargs
                )
                await self.cache.async_set_cache(cache_key, cached_data, **kwargs)
        except Exception as e:
            verbose_logger.exception(f"LiteLLM Cache: Exception in add_cache: {str(e)}")

    async def async_add_cache_pipeline(self, result, **kwargs):
        """
        Async implementation of add_cache for Embedding calls

        Does a bulk write, to avoid opening too many connections
        """
        try:
            if self.should_use_cache(**kwargs) is not True:
                return

            # Apply the default TTL configured on the Cache object
            if self.ttl is not None:
                kwargs["ttl"] = self.ttl

            cache_list = []
            for idx, i in enumerate(kwargs["input"]):
                preset_cache_key = self.get_cache_key(**{**kwargs, "input": i})
                kwargs["cache_key"] = preset_cache_key
                embedding_response = result.data[idx]
                cache_key, cached_data, kwargs = self._add_cache_logic(
                    result=embedding_response,
                    **kwargs,
                )
                cache_list.append((cache_key, cached_data))

            await self.cache.async_set_cache_pipeline(cache_list=cache_list, **kwargs)
        except Exception as e:
            verbose_logger.exception(f"LiteLLM Cache: Exception in add_cache: {str(e)}")

    def should_use_cache(self, **kwargs):
        """
        Returns True if the cache should be used for this LLM API call.

        If the cache mode is default_on, this is always True.
        If the cache mode is default_off, this is only True when the caller has
        explicitly opted in to caching for this request.
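
        Example:
            Illustrative sketch; with mode="default_off", a caller opts in per
            request by passing a cache dict:

                cache = Cache(type="local", mode="default_off")
                cache.should_use_cache(
                    model="gpt-3.5-turbo",
                    cache={"use-cache": True},
                )  # -> True (False without the explicit opt-in)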
""" |
        if self.mode == CacheMode.default_on:
            return True

        _cache = kwargs.get("cache", None)
        verbose_logger.debug(
            "should_use_cache: kwargs: %s; _cache: %s", kwargs, _cache
        )
        if _cache and isinstance(_cache, dict):
            if _cache.get("use-cache", False) is True:
                return True
        return False

    async def batch_cache_write(self, result, **kwargs):
        cache_key, cached_data, kwargs = self._add_cache_logic(result=result, **kwargs)
        await self.cache.batch_cache_write(cache_key, cached_data, **kwargs)

    async def ping(self):
        cache_ping = getattr(self.cache, "ping", None)
        if cache_ping:
            return await cache_ping()
        return None

    async def delete_cache_keys(self, keys):
        cache_delete_cache_keys = getattr(self.cache, "delete_cache_keys", None)
        if cache_delete_cache_keys:
            return await cache_delete_cache_keys(keys)
        return None

    async def disconnect(self):
        if hasattr(self.cache, "disconnect"):
            await self.cache.disconnect()

    def _supports_async(self) -> bool:
        """
        Internal method to check if the cache type supports async get/set operations

        Only the S3 cache does NOT support async operations
        """
        if self.type and self.type == LiteLLMCacheType.S3:
            return False
        return True


def enable_cache(
    type: Optional[LiteLLMCacheType] = LiteLLMCacheType.LOCAL,
    host: Optional[str] = None,
    port: Optional[str] = None,
    password: Optional[str] = None,
    supported_call_types: Optional[List[CachingSupportedCallTypes]] = [
        "completion",
        "acompletion",
        "embedding",
        "aembedding",
        "atranscription",
        "transcription",
        "atext_completion",
        "text_completion",
        "arerank",
        "rerank",
    ],
    **kwargs,
):
    """
    Enable cache with the specified configuration.

    Args:
        type (Optional[LiteLLMCacheType]): The type of cache to enable ("local", "redis", "redis-semantic", "qdrant-semantic", "s3" or "disk"). Defaults to "local".
        host (Optional[str]): The host address of the cache server. Defaults to None.
        port (Optional[str]): The port number of the cache server. Defaults to None.
        password (Optional[str]): The password for the cache server. Defaults to None.
        supported_call_types (Optional[List[CachingSupportedCallTypes]]):
            The call types to cache for. Defaults to all supported call types.
        **kwargs: Additional keyword arguments.

    Returns:
        None

    Raises:
        None
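
    Example:
        Illustrative usage only; the Redis connection values below are placeholders:

            import litellm

            litellm.enable_cache(
                type="redis",
                host="localhost",
                port="6379",
                password="<your-redis-password>",
            )
            # Subsequent litellm.completion() calls with identical params can now
            # be served from the cache.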
""" |
    print_verbose("LiteLLM: Enabling Cache")
    if "cache" not in litellm.input_callback:
        litellm.input_callback.append("cache")
    if "cache" not in litellm.success_callback:
        litellm.logging_callback_manager.add_litellm_success_callback("cache")
    if "cache" not in litellm._async_success_callback:
        litellm.logging_callback_manager.add_litellm_async_success_callback("cache")

    if litellm.cache is None:
        litellm.cache = Cache(
            type=type,
            host=host,
            port=port,
            password=password,
            supported_call_types=supported_call_types,
            **kwargs,
        )
    print_verbose(f"LiteLLM: Cache enabled, litellm.cache={litellm.cache}")
    print_verbose(f"LiteLLM Cache: {vars(litellm.cache)}")


def update_cache(
    type: Optional[LiteLLMCacheType] = LiteLLMCacheType.LOCAL,
    host: Optional[str] = None,
    port: Optional[str] = None,
    password: Optional[str] = None,
    supported_call_types: Optional[List[CachingSupportedCallTypes]] = [
        "completion",
        "acompletion",
        "embedding",
        "aembedding",
        "atranscription",
        "transcription",
        "atext_completion",
        "text_completion",
        "arerank",
        "rerank",
    ],
    **kwargs,
):
    """
    Update the cache for LiteLLM.

    Args:
        type (Optional[LiteLLMCacheType]): The type of cache ("local", "redis", "redis-semantic", "qdrant-semantic", "s3" or "disk"). Defaults to "local".
        host (Optional[str]): The host of the cache. Defaults to None.
        port (Optional[str]): The port of the cache. Defaults to None.
        password (Optional[str]): The password for the cache. Defaults to None.
        supported_call_types (Optional[List[CachingSupportedCallTypes]]):
            The call types to cache for. Defaults to all supported call types.
        **kwargs: Additional keyword arguments for the cache.

    Returns:
        None
    """
    print_verbose("LiteLLM: Updating Cache")
    litellm.cache = Cache(
        type=type,
        host=host,
        port=port,
        password=password,
        supported_call_types=supported_call_types,
        **kwargs,
    )
    print_verbose(f"LiteLLM: Cache Updated, litellm.cache={litellm.cache}")
    print_verbose(f"LiteLLM Cache: {vars(litellm.cache)}")


def disable_cache():
    """
    Disable the cache used by LiteLLM.

    This function disables the cache used by the LiteLLM module. It removes the
    cache-related callbacks from the input_callback, success_callback, and
    _async_success_callback lists. It also sets the litellm.cache attribute to None.

    Parameters:
        None

    Returns:
        None
    """
    from contextlib import suppress

    print_verbose("LiteLLM: Disabling Cache")
    with suppress(ValueError):
        litellm.input_callback.remove("cache")
        litellm.success_callback.remove("cache")
        litellm._async_success_callback.remove("cache")

    litellm.cache = None
    print_verbose(f"LiteLLM: Cache disabled, litellm.cache={litellm.cache}")