File size: 5,990 Bytes
e3278e4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
# What does this do?
## Fetches a key's Redis cache entries and stores them in memory for 1 minute.
## This reduces the number of Redis GET requests the proxy makes during high traffic.
### [BETA] This is in beta and might change.

import traceback
from typing import Literal, Optional

from fastapi import HTTPException

import litellm
from litellm._logging import verbose_proxy_logger
from litellm.caching.caching import DualCache, InMemoryCache, RedisCache
from litellm.integrations.custom_logger import CustomLogger
from litellm.proxy._types import UserAPIKeyAuth


class _PROXY_BatchRedisRequests(CustomLogger):
    """
    Proxy hook that batch-loads an API key's Redis cache entries into memory.

    ``async_pre_call_hook`` checks whether any in-memory entry exists under the
    ``litellm:<api_key>:<call_type>`` prefix. If none does and ``litellm.cache`` is
    backed by a ``RedisCache``, it SCANs Redis for keys under that prefix, fetches
    them in one batch and stores them in memory with a 60-second TTL.

    On construction (when ``litellm.cache`` is set), ``litellm.cache.async_get_cache``
    is replaced with this class's ``async_get_cache``. Cache reads then consult the
    in-memory copy before falling back to Redis.
    """

    # Taken from the proxy's DualCache on the first `async_pre_call_hook` call.
    # Until then, `async_get_cache` returns None for every lookup (a cache miss).
    in_memory_cache: Optional[InMemoryCache] = None

    def __init__(self):
        # Run CustomLogger's own initialisation before wiring up the cache override.
        super().__init__()
        if litellm.cache is not None:
            # Map litellm's 'get_cache' function to our in-memory-first lookup.
            litellm.cache.async_get_cache = self.async_get_cache

    def print_verbose(
        self, print_statement, debug_level: Literal["INFO", "DEBUG"] = "DEBUG"
    ):
        """
        Log ``print_statement`` at ``debug_level``.

        The statement is also printed to stdout when ``litellm.set_verbose`` is True.
        """
        if debug_level == "DEBUG":
            verbose_proxy_logger.debug(print_statement)
        elif debug_level == "INFO":
            # Previously routed to .debug(), which hid INFO messages at INFO log level.
            verbose_proxy_logger.info(print_statement)
        if litellm.set_verbose is True:
            print(print_statement)  # noqa

    async def async_pre_call_hook(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        cache: DualCache,
        data: dict,
        call_type: str,
    ):
        """
        Warm the in-memory cache with this key's Redis entries and tag the request.

        Checks whether any in-memory key starts with
        ``litellm:<api_key>:<call_type>``. If none does, and ``litellm.cache`` is
        backed by a ``RedisCache``, SCANs Redis for keys matching that prefix
        (``count=100``). It then batch-fetches their values and copies them into the
        in-memory cache with a 60-second TTL.

        Always sets ``data["metadata"]["redis_namespace"]`` to the prefix, creating
        ``data["metadata"]`` if it is absent.

        Args:
            user_api_key_dict: auth info for the caller; only ``api_key`` is read.
                NOTE(review): a None api_key yields the prefix
                "litellm:None:<call_type>", shared by all such callers. Confirm
                this is intended.
            cache: the proxy's DualCache. Its ``in_memory_cache`` is used here and
                remembered on ``self`` for ``async_get_cache``.
            data: the request body; mutated in place.
            call_type: the request type, used as part of the key prefix.

        Raises:
            HTTPException: re-raised unchanged. Any other exception is logged and
                swallowed so that caching problems never block the request.
        """
        try:
            api_key = user_api_key_dict.api_key
            cache_key_name = f"litellm:{api_key}:{call_type}"
            in_memory_cache = cache.in_memory_cache
            self.in_memory_cache = in_memory_cache

            key_value_dict: dict = {}
            # `any` stops at the first matching key instead of walking every entry.
            in_memory_cache_exists = any(
                isinstance(key, str) and key.startswith(cache_key_name)
                for key in in_memory_cache.cache_dict
            )

            if (
                not in_memory_cache_exists
                and litellm.cache is not None
                and litellm.cache.type is not None
                and isinstance(litellm.cache.cache, RedisCache)
            ):
                self.print_verbose(f"cache_key_name: {cache_key_name}")
                # Use the SCAN iterator to fetch keys matching the prefix. If you need
                # the truly "last" key (by time or another criterion), the key naming
                # scheme must make that determinable; sort/filter `keys` here.
                keys = await litellm.cache.cache.async_scan_iter(
                    pattern=cache_key_name, count=100
                )
                self.print_verbose(f"redis keys: {keys}")
                if keys:
                    key_value_dict = await litellm.cache.cache.async_batch_get_cache(
                        key_list=keys
                    )

            ## Add to cache
            if key_value_dict:
                await in_memory_cache.async_set_cache_pipeline(
                    cache_list=list(key_value_dict.items()), ttl=60
                )

            ## Always record the namespace. get_cache_key builds
            ## "<cache_key_name>:<hash>" keys from it, so the async_set_cache logic
            ## doesn't need rewriting. `setdefault` avoids a KeyError when the request
            ## carries no metadata yet.
            data.setdefault("metadata", {})["redis_namespace"] = cache_key_name
        except HTTPException:
            raise
        except Exception as e:
            verbose_proxy_logger.error(
                "litellm.proxy.hooks.batch_redis_get.py::async_pre_call_hook(): Exception occured - {}".format(
                    str(e)
                )
            )
            verbose_proxy_logger.debug(traceback.format_exc())

    async def async_get_cache(self, *args, **kwargs):
        """
        In-memory-first replacement for ``litellm.cache.async_get_cache``.

        Resolution steps:

        - Resolve the cache key: ``kwargs["cache_key"]`` if given, otherwise
          ``litellm.cache.get_cache_key(*args, **kwargs)``.
        - Look the key up in the in-memory cache.
        - On a miss, read it from ``litellm.cache.cache`` (the Redis backend) and
          copy any hit into memory with a 60-second TTL.

        The result is passed through ``litellm.cache._get_cache_logic`` together with
        the request's ``s-max-age``/``s-maxage`` cache-control value (default: no limit).

        Returns None in three cases:

        - no key can be resolved;
        - ``async_pre_call_hook`` has not yet populated ``self.in_memory_cache``;
        - any error occurs, since cache lookups must never block the request.
        """
        try:  # never block execution
            cache_key: Optional[str] = None
            if "cache_key" in kwargs:
                cache_key = kwargs["cache_key"]
            elif litellm.cache is not None:
                # Returns "<cache_key_name>:<hash>". redis_namespace is set in
                # async_pre_call_hook so the async_set_cache logic isn't rewritten.
                cache_key = litellm.cache.get_cache_key(*args, **kwargs)

            if (
                cache_key is None
                or self.in_memory_cache is None
                or litellm.cache is None
            ):
                return None

            # `cache` may be passed explicitly as None; treat that as "no
            # cache-control args" instead of failing on None.get().
            cache_control_args = kwargs.get("cache") or {}
            max_age = cache_control_args.get(
                "s-max-age", cache_control_args.get("s-maxage", float("inf"))
            )

            # NOTE(review): assumes the backend getters take (key, **kwargs) only.
            # The positional args are inputs to get_cache_key; forwarding them would
            # bind to the wrong backend parameters. Confirm against the
            # InMemoryCache/RedisCache signatures.
            cached_result = self.in_memory_cache.get_cache(cache_key, **kwargs)
            if cached_result is None:
                cached_result = await litellm.cache.cache.async_get_cache(
                    cache_key, **kwargs
                )
                if cached_result is not None:
                    await self.in_memory_cache.async_set_cache(
                        cache_key, cached_result, ttl=60
                    )
            return litellm.cache._get_cache_logic(
                cached_result=cached_result, max_age=max_age
            )
        except Exception as e:
            # Best-effort: treat any failure as a cache miss, but leave a trace.
            verbose_proxy_logger.debug(
                "litellm.proxy.hooks.batch_redis_get.py::async_get_cache(): Exception occurred - {}\n{}".format(
                    str(e), traceback.format_exc()
                )
            )
            return None