""" Dual Cache implementation - Class to update both Redis and an in-memory cache simultaneously. Has 4 primary methods: - set_cache - get_cache - async_set_cache - async_get_cache """ import asyncio import time import traceback from concurrent.futures import ThreadPoolExecutor from typing import TYPE_CHECKING, Any, List, Optional import litellm from litellm._logging import print_verbose, verbose_logger from .base_cache import BaseCache from .in_memory_cache import InMemoryCache from .redis_cache import RedisCache if TYPE_CHECKING: from opentelemetry.trace import Span as _Span Span = _Span else: Span = Any from collections import OrderedDict class LimitedSizeOrderedDict(OrderedDict): def __init__(self, *args, max_size=100, **kwargs): super().__init__(*args, **kwargs) self.max_size = max_size def __setitem__(self, key, value): # If inserting a new key exceeds max size, remove the oldest item if len(self) >= self.max_size: self.popitem(last=False) super().__setitem__(key, value) class DualCache(BaseCache): """ DualCache is a cache implementation that updates both Redis and an in-memory cache simultaneously. When data is updated or inserted, it is written to both the in-memory cache + Redis. This ensures that even if Redis hasn't been updated yet, the in-memory cache reflects the most recent data. """ def __init__( self, in_memory_cache: Optional[InMemoryCache] = None, redis_cache: Optional[RedisCache] = None, default_in_memory_ttl: Optional[float] = None, default_redis_ttl: Optional[float] = None, default_redis_batch_cache_expiry: Optional[float] = None, default_max_redis_batch_cache_size: int = 100, ) -> None: super().__init__() # If in_memory_cache is not provided, use the default InMemoryCache self.in_memory_cache = in_memory_cache or InMemoryCache() # If redis_cache is not provided, use the default RedisCache self.redis_cache = redis_cache self.last_redis_batch_access_time = LimitedSizeOrderedDict( max_size=default_max_redis_batch_cache_size ) self.redis_batch_cache_expiry = ( default_redis_batch_cache_expiry or litellm.default_redis_batch_cache_expiry or 10 ) self.default_in_memory_ttl = ( default_in_memory_ttl or litellm.default_in_memory_ttl ) self.default_redis_ttl = default_redis_ttl or litellm.default_redis_ttl def update_cache_ttl( self, default_in_memory_ttl: Optional[float], default_redis_ttl: Optional[float] ): if default_in_memory_ttl is not None: self.default_in_memory_ttl = default_in_memory_ttl if default_redis_ttl is not None: self.default_redis_ttl = default_redis_ttl def set_cache(self, key, value, local_only: bool = False, **kwargs): # Update both Redis and in-memory cache try: if self.in_memory_cache is not None: if "ttl" not in kwargs and self.default_in_memory_ttl is not None: kwargs["ttl"] = self.default_in_memory_ttl self.in_memory_cache.set_cache(key, value, **kwargs) if self.redis_cache is not None and local_only is False: self.redis_cache.set_cache(key, value, **kwargs) except Exception as e: print_verbose(e) def increment_cache( self, key, value: int, local_only: bool = False, **kwargs ) -> int: """ Key - the key in cache Value - int - the value you want to increment by Returns - int - the incremented value """ try: result: int = value if self.in_memory_cache is not None: result = self.in_memory_cache.increment_cache(key, value, **kwargs) if self.redis_cache is not None and local_only is False: result = self.redis_cache.increment_cache(key, value, **kwargs) return result except Exception as e: verbose_logger.error(f"LiteLLM Cache: Excepton async add_cache: {str(e)}") raise e def get_cache( self, key, parent_otel_span: Optional[Span] = None, local_only: bool = False, **kwargs, ): # Try to fetch from in-memory cache first try: result = None if self.in_memory_cache is not None: in_memory_result = self.in_memory_cache.get_cache(key, **kwargs) if in_memory_result is not None: result = in_memory_result if result is None and self.redis_cache is not None and local_only is False: # If not found in in-memory cache, try fetching from Redis redis_result = self.redis_cache.get_cache( key, parent_otel_span=parent_otel_span ) if redis_result is not None: # Update in-memory cache with the value from Redis self.in_memory_cache.set_cache(key, redis_result, **kwargs) result = redis_result print_verbose(f"get cache: cache result: {result}") return result except Exception: verbose_logger.error(traceback.format_exc()) def batch_get_cache( self, keys: list, parent_otel_span: Optional[Span] = None, local_only: bool = False, **kwargs, ): received_args = locals() received_args.pop("self") def run_in_new_loop(): """Run the coroutine in a new event loop within this thread.""" new_loop = asyncio.new_event_loop() try: asyncio.set_event_loop(new_loop) return new_loop.run_until_complete( self.async_batch_get_cache(**received_args) ) finally: new_loop.close() asyncio.set_event_loop(None) try: # First, try to get the current event loop _ = asyncio.get_running_loop() # If we're already in an event loop, run in a separate thread # to avoid nested event loop issues with ThreadPoolExecutor(max_workers=1) as executor: future = executor.submit(run_in_new_loop) return future.result() except RuntimeError: # No running event loop, we can safely run in this thread return run_in_new_loop() async def async_get_cache( self, key, parent_otel_span: Optional[Span] = None, local_only: bool = False, **kwargs, ): # Try to fetch from in-memory cache first try: print_verbose( f"async get cache: cache key: {key}; local_only: {local_only}" ) result = None if self.in_memory_cache is not None: in_memory_result = await self.in_memory_cache.async_get_cache( key, **kwargs ) print_verbose(f"in_memory_result: {in_memory_result}") if in_memory_result is not None: result = in_memory_result if result is None and self.redis_cache is not None and local_only is False: # If not found in in-memory cache, try fetching from Redis redis_result = await self.redis_cache.async_get_cache( key, parent_otel_span=parent_otel_span ) if redis_result is not None: # Update in-memory cache with the value from Redis await self.in_memory_cache.async_set_cache( key, redis_result, **kwargs ) result = redis_result print_verbose(f"get cache: cache result: {result}") return result except Exception: verbose_logger.error(traceback.format_exc()) def get_redis_batch_keys( self, current_time: float, keys: List[str], result: List[Any], ) -> List[str]: sublist_keys = [] for key, value in zip(keys, result): if value is None: if ( key not in self.last_redis_batch_access_time or current_time - self.last_redis_batch_access_time[key] >= self.redis_batch_cache_expiry ): sublist_keys.append(key) return sublist_keys async def async_batch_get_cache( self, keys: list, parent_otel_span: Optional[Span] = None, local_only: bool = False, **kwargs, ): try: result = [None for _ in range(len(keys))] if self.in_memory_cache is not None: in_memory_result = await self.in_memory_cache.async_batch_get_cache( keys, **kwargs ) if in_memory_result is not None: result = in_memory_result if None in result and self.redis_cache is not None and local_only is False: """ - for the none values in the result - check the redis cache """ current_time = time.time() sublist_keys = self.get_redis_batch_keys(current_time, keys, result) # Only hit Redis if the last access time was more than 5 seconds ago if len(sublist_keys) > 0: # If not found in in-memory cache, try fetching from Redis redis_result = await self.redis_cache.async_batch_get_cache( sublist_keys, parent_otel_span=parent_otel_span ) if redis_result is not None: # Update in-memory cache with the value from Redis for key, value in redis_result.items(): if value is not None: await self.in_memory_cache.async_set_cache( key, redis_result[key], **kwargs ) # Update the last access time for each key fetched from Redis self.last_redis_batch_access_time[key] = current_time for key, value in redis_result.items(): index = keys.index(key) result[index] = value return result except Exception: verbose_logger.error(traceback.format_exc()) async def async_set_cache(self, key, value, local_only: bool = False, **kwargs): print_verbose( f"async set cache: cache key: {key}; local_only: {local_only}; value: {value}" ) try: if self.in_memory_cache is not None: await self.in_memory_cache.async_set_cache(key, value, **kwargs) if self.redis_cache is not None and local_only is False: await self.redis_cache.async_set_cache(key, value, **kwargs) except Exception as e: verbose_logger.exception( f"LiteLLM Cache: Excepton async add_cache: {str(e)}" ) # async_batch_set_cache async def async_set_cache_pipeline( self, cache_list: list, local_only: bool = False, **kwargs ): """ Batch write values to the cache """ print_verbose( f"async batch set cache: cache keys: {cache_list}; local_only: {local_only}" ) try: if self.in_memory_cache is not None: await self.in_memory_cache.async_set_cache_pipeline( cache_list=cache_list, **kwargs ) if self.redis_cache is not None and local_only is False: await self.redis_cache.async_set_cache_pipeline( cache_list=cache_list, ttl=kwargs.pop("ttl", None), **kwargs ) except Exception as e: verbose_logger.exception( f"LiteLLM Cache: Excepton async add_cache: {str(e)}" ) async def async_increment_cache( self, key, value: float, parent_otel_span: Optional[Span] = None, local_only: bool = False, **kwargs, ) -> float: """ Key - the key in cache Value - float - the value you want to increment by Returns - float - the incremented value """ try: result: float = value if self.in_memory_cache is not None: result = await self.in_memory_cache.async_increment( key, value, **kwargs ) if self.redis_cache is not None and local_only is False: result = await self.redis_cache.async_increment( key, value, parent_otel_span=parent_otel_span, ttl=kwargs.get("ttl", None), ) return result except Exception as e: raise e # don't log if exception is raised async def async_set_cache_sadd( self, key, value: List, local_only: bool = False, **kwargs ) -> None: """ Add value to a set Key - the key in cache Value - str - the value you want to add to the set Returns - None """ try: if self.in_memory_cache is not None: _ = await self.in_memory_cache.async_set_cache_sadd( key, value, ttl=kwargs.get("ttl", None) ) if self.redis_cache is not None and local_only is False: _ = await self.redis_cache.async_set_cache_sadd( key, value, ttl=kwargs.get("ttl", None) ) return None except Exception as e: raise e # don't log, if exception is raised def flush_cache(self): if self.in_memory_cache is not None: self.in_memory_cache.flush_cache() if self.redis_cache is not None: self.redis_cache.flush_cache() def delete_cache(self, key): """ Delete a key from the cache """ if self.in_memory_cache is not None: self.in_memory_cache.delete_cache(key) if self.redis_cache is not None: self.redis_cache.delete_cache(key) async def async_delete_cache(self, key: str): """ Delete a key from the cache """ if self.in_memory_cache is not None: self.in_memory_cache.delete_cache(key) if self.redis_cache is not None: await self.redis_cache.async_delete_cache(key) async def async_get_ttl(self, key: str) -> Optional[int]: """ Get the remaining TTL of a key in in-memory cache or redis """ ttl = await self.in_memory_cache.async_get_ttl(key) if ttl is None and self.redis_cache is not None: ttl = await self.redis_cache.async_get_ttl(key) return ttl