File size: 4,921 Bytes
e3278e4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
# +------------------------------+
#
#        Blocked User List
#
# +------------------------------+
#  Thank you users! We ❤️ you! - Krrish & Ishaan
## This accepts a list of user id's for whom calls will be rejected


from typing import Optional, Literal
import litellm
from litellm.proxy.utils import PrismaClient
from litellm.caching.caching import DualCache
from litellm.proxy._types import UserAPIKeyAuth, LiteLLM_EndUserTable
from litellm.integrations.custom_logger import CustomLogger
from litellm._logging import verbose_proxy_logger
from fastapi import HTTPException


class _ENTERPRISE_BlockedUserList(CustomLogger):
    # Class variables or attributes
    def __init__(self, prisma_client: Optional[PrismaClient]):
        self.prisma_client = prisma_client

        blocked_user_list = litellm.blocked_user_list
        if blocked_user_list is None:
            self.blocked_user_list = None
            return

        if isinstance(blocked_user_list, list):
            self.blocked_user_list = blocked_user_list

        if isinstance(blocked_user_list, str):  # assume it's a filepath
            try:
                with open(blocked_user_list, "r") as file:
                    data = file.read()
                    self.blocked_user_list = data.split("\n")
            except FileNotFoundError:
                raise Exception(
                    f"File not found. blocked_user_list={blocked_user_list}"
                )
            except Exception as e:
                raise Exception(
                    f"An error occurred: {str(e)}, blocked_user_list={blocked_user_list}"
                )

    def print_verbose(self, print_statement, level: Literal["INFO", "DEBUG"] = "DEBUG"):
        if level == "INFO":
            verbose_proxy_logger.info(print_statement)
        elif level == "DEBUG":
            verbose_proxy_logger.debug(print_statement)

        if litellm.set_verbose is True:
            print(print_statement)  # noqa

    async def async_pre_call_hook(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        cache: DualCache,
        data: dict,
        call_type: str,
    ):
        try:
            """
            - check if user id part of call
            - check if user id part of blocked list
                - if blocked list is none or user not in blocked list
                - check if end-user in cache
                - check if end-user in db
            """
            self.print_verbose("Inside Blocked User List Pre-Call Hook")
            if "user_id" in data or "user" in data:
                user = data.get("user_id", data.get("user", ""))
                if (
                    self.blocked_user_list is not None
                    and user in self.blocked_user_list
                ):
                    raise HTTPException(
                        status_code=400,
                        detail={
                            "error": f"User blocked from making LLM API Calls. User={user}"
                        },
                    )

                cache_key = f"litellm:end_user_id:{user}"
                end_user_cache_obj: Optional[LiteLLM_EndUserTable] = cache.get_cache(  # type: ignore
                    key=cache_key
                )
                if end_user_cache_obj is None and self.prisma_client is not None:
                    # check db
                    end_user_obj = (
                        await self.prisma_client.db.litellm_endusertable.find_unique(
                            where={"user_id": user}
                        )
                    )
                    if end_user_obj is None:  # user not in db - assume not blocked
                        end_user_obj = LiteLLM_EndUserTable(user_id=user, blocked=False)
                    cache.set_cache(key=cache_key, value=end_user_obj, ttl=60)
                    if end_user_obj is not None and end_user_obj.blocked == True:
                        raise HTTPException(
                            status_code=400,
                            detail={
                                "error": f"User blocked from making LLM API Calls. User={user}"
                            },
                        )
                elif (
                    end_user_cache_obj is not None
                    and end_user_cache_obj.blocked == True
                ):
                    raise HTTPException(
                        status_code=400,
                        detail={
                            "error": f"User blocked from making LLM API Calls. User={user}"
                        },
                    )

        except HTTPException as e:
            raise e
        except Exception as e:
            verbose_proxy_logger.exception(
                "litellm.enterprise.enterprise_hooks.blocked_user_list::async_pre_call_hook - Exception occurred - {}".format(
                    str(e)
                )
            )