# What this does:
## Gets a key's Redis cache and stores it in memory for 1 minute.
## This reduces the number of Redis GET requests the proxy makes during high traffic.
### [BETA] this is in Beta and might change.

import traceback
from typing import Literal, Optional

from fastapi import HTTPException

import litellm
from litellm._logging import verbose_proxy_logger
from litellm.caching.caching import DualCache, InMemoryCache, RedisCache
from litellm.integrations.custom_logger import CustomLogger
from litellm.proxy._types import UserAPIKeyAuth
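
# Illustrative wiring (an assumption, not part of this module): the proxy is
# expected to create this hook and register it as a litellm callback so that
# async_pre_call_hook runs before each request, roughly:
#
#     batch_redis_obj = _PROXY_BatchRedisRequests()
#     litellm.callbacks.append(batch_redis_obj)
#
# The exact registration path depends on how the proxy loads its hooks.
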
class _PROXY_BatchRedisRequests(CustomLogger):
    # Class variables or attributes
    in_memory_cache: Optional[InMemoryCache] = None
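
    # __init__ points litellm.cache.async_get_cache at this hook's
    # async_get_cache, so every proxy cache read goes through the
    # in-memory-first lookup below.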
    def __init__(self):
        if litellm.cache is not None:
            litellm.cache.async_get_cache = (
                self.async_get_cache
            )  # map the litellm 'get_cache' function to our custom function

    def print_verbose(
        self, print_statement, debug_level: Literal["INFO", "DEBUG"] = "DEBUG"
    ):
        if debug_level == "DEBUG":
            verbose_proxy_logger.debug(print_statement)
        elif debug_level == "INFO":
            verbose_proxy_logger.info(print_statement)

        if litellm.set_verbose is True:
            print(print_statement)  # noqa

    async def async_pre_call_hook(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        cache: DualCache,
        data: dict,
        call_type: str,
    ):
        try:
            """
            Get the user key.

            Check if a key starting with `litellm:<api_key>:<call_type>` exists in-memory.

            If not, get the relevant cache from Redis.
            """
            api_key = user_api_key_dict.api_key
            cache_key_name = f"litellm:{api_key}:{call_type}"
            self.in_memory_cache = cache.in_memory_cache

            key_value_dict = {}
            in_memory_cache_exists = False
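            # If any in-memory key already carries this prefix, the Redis values for
            # this API key / call type were fetched within the last minute, so the
            # Redis SCAN below can be skipped.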
            for key in cache.in_memory_cache.cache_dict.keys():
                if isinstance(key, str) and key.startswith(cache_key_name):
                    in_memory_cache_exists = True

            if in_memory_cache_exists is False and litellm.cache is not None:
                """
                - Check if `litellm.Cache` is redis
                - Get the relevant values
                """
                if litellm.cache.type is not None and isinstance(
                    litellm.cache.cache, RedisCache
                ):
                    # Initialize an empty list to store the keys
                    keys = []
                    self.print_verbose(f"cache_key_name: {cache_key_name}")
                    # Use the SCAN iterator to fetch keys matching the pattern
                    keys = await litellm.cache.cache.async_scan_iter(
                        pattern=cache_key_name, count=100
                    )

                    # If the truly "latest" value is needed, the key naming or storage
                    # strategy must make that determinable; sort or filter the scanned
                    # keys here as needed.
self.print_verbose(f"redis keys: {keys}")
if len(keys) > 0:
key_value_dict = (
await litellm.cache.cache.async_batch_get_cache(
key_list=keys
)
)
## Add to cache
if len(key_value_dict.items()) > 0:
await cache.in_memory_cache.async_set_cache_pipeline(
cache_list=list(key_value_dict.items()), ttl=60
)
## Set cache namespace if it's a miss
data["metadata"]["redis_namespace"] = cache_key_name
        except HTTPException as e:
            raise e
        except Exception as e:
            verbose_proxy_logger.error(
                "litellm.proxy.hooks.batch_redis_get.py::async_pre_call_hook(): Exception occurred - {}".format(
                    str(e)
                )
            )
            verbose_proxy_logger.debug(traceback.format_exc())

    async def async_get_cache(self, *args, **kwargs):
        """
        - Check if the cache key is in-memory
        - Else:
            - fetch the missing cache key from Redis
            - update the in-memory cache
            - return the Redis cache result
        """
        try:  # never block execution
            cache_key: Optional[str] = None
            if "cache_key" in kwargs:
                cache_key = kwargs["cache_key"]
            elif litellm.cache is not None:
                cache_key = litellm.cache.get_cache_key(
                    *args, **kwargs
                )  # returns "<cache_key_name>:<hash>" - we pass redis_namespace in async_pre_call_hook. Done to avoid rewriting the async_set_cache logic

            if (
                cache_key is not None
                and self.in_memory_cache is not None
                and litellm.cache is not None
            ):
                cache_control_args = kwargs.get("cache", {})
                max_age = cache_control_args.get(
                    "s-max-age", cache_control_args.get("s-maxage", float("inf"))
                )
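
                # Check the in-memory copy first; only on a miss do we pay for a
                # Redis round trip, and that result is then kept in memory for 60s.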
                cached_result = self.in_memory_cache.get_cache(
                    cache_key, *args, **kwargs
                )
                if cached_result is None:
                    cached_result = await litellm.cache.cache.async_get_cache(
                        cache_key, *args, **kwargs
                    )
                    if cached_result is not None:
                        await self.in_memory_cache.async_set_cache(
                            cache_key, cached_result, ttl=60
                        )

                return litellm.cache._get_cache_logic(
                    cached_result=cached_result, max_age=max_age
                )
        except Exception:
            return None