"""
Dual Cache implementation - Class to update both Redis and an in-memory cache simultaneously.

Has 4 primary methods:
    - set_cache
    - get_cache
    - async_set_cache
    - async_get_cache
"""

import asyncio
import time
import traceback
from collections import OrderedDict
from concurrent.futures import ThreadPoolExecutor
from typing import TYPE_CHECKING, Any, List, Optional

import litellm
from litellm._logging import print_verbose, verbose_logger

from .base_cache import BaseCache
from .in_memory_cache import InMemoryCache
from .redis_cache import RedisCache

if TYPE_CHECKING:
    from opentelemetry.trace import Span as _Span

    Span = _Span
else:
    Span = Any

class LimitedSizeOrderedDict(OrderedDict):
    """An OrderedDict that evicts its oldest entry once max_size is reached."""

    def __init__(self, *args, max_size=100, **kwargs):
        super().__init__(*args, **kwargs)
        self.max_size = max_size

    def __setitem__(self, key, value):
        # If the dict is full, evict the oldest (first-inserted) item
        if len(self) >= self.max_size:
            self.popitem(last=False)
        super().__setitem__(key, value)
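
# A minimal sketch of the eviction behavior above (illustrative values only):
#
#   d = LimitedSizeOrderedDict(max_size=2)
#   d["a"] = 1
#   d["b"] = 2
#   d["c"] = 3          # "a", the oldest entry, is evicted
#   assert list(d) == ["b", "c"]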


class DualCache(BaseCache):
    """
    DualCache is a cache implementation that updates both Redis and an in-memory cache simultaneously.
    When data is updated or inserted, it is written to both the in-memory cache and Redis.
    This ensures that even if Redis hasn't been updated yet, the in-memory cache reflects the most recent data.
    """

    def __init__(
        self,
        in_memory_cache: Optional[InMemoryCache] = None,
        redis_cache: Optional[RedisCache] = None,
        default_in_memory_ttl: Optional[float] = None,
        default_redis_ttl: Optional[float] = None,
        default_redis_batch_cache_expiry: Optional[float] = None,
        default_max_redis_batch_cache_size: int = 100,
    ) -> None:
        super().__init__()
        # If in_memory_cache is not provided, use a default InMemoryCache
        self.in_memory_cache = in_memory_cache or InMemoryCache()
        # If redis_cache is not provided, Redis is skipped entirely
        self.redis_cache = redis_cache
        self.last_redis_batch_access_time = LimitedSizeOrderedDict(
            max_size=default_max_redis_batch_cache_size
        )
        self.redis_batch_cache_expiry = (
            default_redis_batch_cache_expiry
            or litellm.default_redis_batch_cache_expiry
            or 10
        )
        self.default_in_memory_ttl = (
            default_in_memory_ttl or litellm.default_in_memory_ttl
        )
        self.default_redis_ttl = default_redis_ttl or litellm.default_redis_ttl

    def update_cache_ttl(
        self, default_in_memory_ttl: Optional[float], default_redis_ttl: Optional[float]
    ):
        if default_in_memory_ttl is not None:
            self.default_in_memory_ttl = default_in_memory_ttl

        if default_redis_ttl is not None:
            self.default_redis_ttl = default_redis_ttl

    def set_cache(self, key, value, local_only: bool = False, **kwargs):
        # Update both the in-memory cache and Redis
        try:
            if self.in_memory_cache is not None:
                if "ttl" not in kwargs and self.default_in_memory_ttl is not None:
                    kwargs["ttl"] = self.default_in_memory_ttl

                self.in_memory_cache.set_cache(key, value, **kwargs)

            if self.redis_cache is not None and local_only is False:
                self.redis_cache.set_cache(key, value, **kwargs)
        except Exception as e:
            print_verbose(f"LiteLLM Cache: Exception in set_cache: {str(e)}")

    def increment_cache(
        self, key, value: int, local_only: bool = False, **kwargs
    ) -> int:
        """
        Key - the key in cache

        Value - int - the value you want to increment by

        Returns - int - the incremented value
        """
        try:
            result: int = value
            if self.in_memory_cache is not None:
                result = self.in_memory_cache.increment_cache(key, value, **kwargs)

            if self.redis_cache is not None and local_only is False:
                result = self.redis_cache.increment_cache(key, value, **kwargs)

            return result
        except Exception as e:
            verbose_logger.error(f"LiteLLM Cache: Exception in increment_cache: {str(e)}")
            raise e
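
    # e.g. a hypothetical request counter (key name illustrative):
    #   rpm = cache.increment_cache("rpm:user-42", 1)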

    def get_cache(
        self,
        key,
        parent_otel_span: Optional[Span] = None,
        local_only: bool = False,
        **kwargs,
    ):
        try:
            # Try to fetch from the in-memory cache first
            result = None
            if self.in_memory_cache is not None:
                in_memory_result = self.in_memory_cache.get_cache(key, **kwargs)

                if in_memory_result is not None:
                    result = in_memory_result

            if result is None and self.redis_cache is not None and local_only is False:
                # Not found in-memory, try fetching from Redis
                redis_result = self.redis_cache.get_cache(
                    key, parent_otel_span=parent_otel_span
                )

                if redis_result is not None:
                    # Update the in-memory cache with the value from Redis
                    self.in_memory_cache.set_cache(key, redis_result, **kwargs)

                result = redis_result

            print_verbose(f"get cache: cache result: {result}")
            return result
        except Exception:
            verbose_logger.error(traceback.format_exc())

    def batch_get_cache(
        self,
        keys: list,
        parent_otel_span: Optional[Span] = None,
        local_only: bool = False,
        **kwargs,
    ):
        received_args = locals()
        received_args.pop("self")
        # locals() captures the **kwargs mapping as a single "kwargs" entry;
        # flatten it so the original keyword arguments are forwarded as-is
        received_kwargs = received_args.pop("kwargs", {})

        def run_in_new_loop():
            """Run the coroutine in a new event loop within this thread."""
            new_loop = asyncio.new_event_loop()
            try:
                asyncio.set_event_loop(new_loop)
                return new_loop.run_until_complete(
                    self.async_batch_get_cache(**received_args, **received_kwargs)
                )
            finally:
                new_loop.close()
                asyncio.set_event_loop(None)

        try:
            # If we're already inside a running event loop, run the coroutine
            # in a separate thread to avoid nested-event-loop issues
            _ = asyncio.get_running_loop()
            with ThreadPoolExecutor(max_workers=1) as executor:
                future = executor.submit(run_in_new_loop)
                return future.result()
        except RuntimeError:
            # No running event loop, so we can safely run in this thread
            return run_in_new_loop()

    async def async_get_cache(
        self,
        key,
        parent_otel_span: Optional[Span] = None,
        local_only: bool = False,
        **kwargs,
    ):
        try:
            print_verbose(
                f"async get cache: cache key: {key}; local_only: {local_only}"
            )
            # Try to fetch from the in-memory cache first
            result = None
            if self.in_memory_cache is not None:
                in_memory_result = await self.in_memory_cache.async_get_cache(
                    key, **kwargs
                )

                print_verbose(f"in_memory_result: {in_memory_result}")
                if in_memory_result is not None:
                    result = in_memory_result

            if result is None and self.redis_cache is not None and local_only is False:
                # Not found in-memory, try fetching from Redis
                redis_result = await self.redis_cache.async_get_cache(
                    key, parent_otel_span=parent_otel_span
                )

                if redis_result is not None:
                    # Update the in-memory cache with the value from Redis
                    await self.in_memory_cache.async_set_cache(
                        key, redis_result, **kwargs
                    )

                result = redis_result

            print_verbose(f"get cache: cache result: {result}")
            return result
        except Exception:
            verbose_logger.error(traceback.format_exc())

    def get_redis_batch_keys(
        self,
        current_time: float,
        keys: List[str],
        result: List[Any],
    ) -> List[str]:
        """
        Return the subset of keys that missed the in-memory cache and that have
        either never been looked up in Redis, or whose last Redis batch lookup
        is older than redis_batch_cache_expiry.
        """
        sublist_keys = []
        for key, value in zip(keys, result):
            if value is None:
                if (
                    key not in self.last_redis_batch_access_time
                    or current_time - self.last_redis_batch_access_time[key]
                    >= self.redis_batch_cache_expiry
                ):
                    sublist_keys.append(key)
        return sublist_keys
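
    # Example of the expiry gate above: with redis_batch_cache_expiry=10, a key
    # that missed both caches at t=100.0 is re-queried from Redis only once
    # t >= 110.0; in between, the recorded lookup time suppresses repeat queries.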

    async def async_batch_get_cache(
        self,
        keys: list,
        parent_otel_span: Optional[Span] = None,
        local_only: bool = False,
        **kwargs,
    ):
        try:
            result = [None for _ in range(len(keys))]
            if self.in_memory_cache is not None:
                in_memory_result = await self.in_memory_cache.async_batch_get_cache(
                    keys, **kwargs
                )

                if in_memory_result is not None:
                    result = in_memory_result

            if None in result and self.redis_cache is not None and local_only is False:
                # For the None values in the result, check the Redis cache
                current_time = time.time()
                sublist_keys = self.get_redis_batch_keys(current_time, keys, result)

                if len(sublist_keys) > 0:
                    redis_result = await self.redis_cache.async_batch_get_cache(
                        sublist_keys, parent_otel_span=parent_otel_span
                    )

                    if redis_result is not None:
                        for key, value in redis_result.items():
                            if value is not None:
                                # Update the in-memory cache with the value from Redis
                                await self.in_memory_cache.async_set_cache(
                                    key, redis_result[key], **kwargs
                                )
                            # Record the lookup time even on a miss, so repeat
                            # lookups are suppressed until the expiry window passes
                            self.last_redis_batch_access_time[key] = current_time

                        for key, value in redis_result.items():
                            index = keys.index(key)
                            result[index] = value

            return result
        except Exception:
            verbose_logger.error(traceback.format_exc())
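
    # Result shape sketch (illustrative keys/values): results align positionally
    # with the input keys, with None for misses, e.g.
    #   await cache.async_batch_get_cache(["a", "b"])  # -> ["value-a", None]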

    async def async_set_cache(self, key, value, local_only: bool = False, **kwargs):
        print_verbose(
            f"async set cache: cache key: {key}; local_only: {local_only}; value: {value}"
        )
        try:
            if self.in_memory_cache is not None:
                await self.in_memory_cache.async_set_cache(key, value, **kwargs)

            if self.redis_cache is not None and local_only is False:
                await self.redis_cache.async_set_cache(key, value, **kwargs)
        except Exception as e:
            verbose_logger.exception(
                f"LiteLLM Cache: Exception in async_set_cache: {str(e)}"
            )

    async def async_set_cache_pipeline(
        self, cache_list: list, local_only: bool = False, **kwargs
    ):
        """
        Batch write values to the cache
        """
        print_verbose(
            f"async batch set cache: cache keys: {cache_list}; local_only: {local_only}"
        )
        try:
            if self.in_memory_cache is not None:
                await self.in_memory_cache.async_set_cache_pipeline(
                    cache_list=cache_list, **kwargs
                )

            if self.redis_cache is not None and local_only is False:
                await self.redis_cache.async_set_cache_pipeline(
                    cache_list=cache_list, ttl=kwargs.pop("ttl", None), **kwargs
                )
        except Exception as e:
            verbose_logger.exception(
                f"LiteLLM Cache: Exception in async_set_cache_pipeline: {str(e)}"
            )

    async def async_increment_cache(
        self,
        key,
        value: float,
        parent_otel_span: Optional[Span] = None,
        local_only: bool = False,
        **kwargs,
    ) -> float:
        """
        Key - the key in cache

        Value - float - the value you want to increment by

        Returns - float - the incremented value
        """
        result: float = value
        if self.in_memory_cache is not None:
            result = await self.in_memory_cache.async_increment(key, value, **kwargs)

        if self.redis_cache is not None and local_only is False:
            result = await self.redis_cache.async_increment(
                key,
                value,
                parent_otel_span=parent_otel_span,
                ttl=kwargs.get("ttl", None),
            )

        return result

    async def async_set_cache_sadd(
        self, key, value: List, local_only: bool = False, **kwargs
    ) -> None:
        """
        Add values to a set

        Key - the key in cache

        Value - List - the values you want to add to the set

        Returns - None
        """
        if self.in_memory_cache is not None:
            _ = await self.in_memory_cache.async_set_cache_sadd(
                key, value, ttl=kwargs.get("ttl", None)
            )

        if self.redis_cache is not None and local_only is False:
            _ = await self.redis_cache.async_set_cache_sadd(
                key, value, ttl=kwargs.get("ttl", None)
            )

        return None
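
    # e.g. tracking a set of seen model names (key and values illustrative):
    #   await cache.async_set_cache_sadd("models_seen", ["gpt-4o"], ttl=3600)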

    def flush_cache(self):
        if self.in_memory_cache is not None:
            self.in_memory_cache.flush_cache()
        if self.redis_cache is not None:
            self.redis_cache.flush_cache()

    def delete_cache(self, key):
        """
        Delete a key from the cache
        """
        if self.in_memory_cache is not None:
            self.in_memory_cache.delete_cache(key)
        if self.redis_cache is not None:
            self.redis_cache.delete_cache(key)

    async def async_delete_cache(self, key: str):
        """
        Delete a key from the cache
        """
        if self.in_memory_cache is not None:
            self.in_memory_cache.delete_cache(key)
        if self.redis_cache is not None:
            await self.redis_cache.async_delete_cache(key)

    async def async_get_ttl(self, key: str) -> Optional[int]:
        """
        Get the remaining TTL of a key in the in-memory cache or Redis
        """
        ttl = await self.in_memory_cache.async_get_ttl(key)
        if ttl is None and self.redis_cache is not None:
            ttl = await self.redis_cache.async_get_ttl(key)
        return ttl
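

if __name__ == "__main__":
    # Minimal smoke test - a sketch exercising only the in-memory path, so it
    # runs without a Redis instance; keys and values here are illustrative.
    async def _demo() -> None:
        cache = DualCache(default_in_memory_ttl=60.0)
        await cache.async_set_cache("greeting", "hello")
        print(await cache.async_get_cache("greeting"))  # -> hello
        print(await cache.async_increment_cache("counter", 1.0))  # -> 1.0

    asyncio.run(_demo())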