File size: 3,821 Bytes
e3278e4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
# +------------------------------+
#
#        Banned Keywords
#
# +------------------------------+
#  Thank you users! We ❤️ you! - Krrish & Ishaan
## Reject a call / response if it contains certain keywords


from typing import Literal
import litellm
from litellm.caching.caching import DualCache
from litellm.proxy._types import UserAPIKeyAuth
from litellm.integrations.custom_logger import CustomLogger
from litellm._logging import verbose_proxy_logger
from fastapi import HTTPException


class _ENTERPRISE_BannedKeywords(CustomLogger):
    # Class variables or attributes
    def __init__(self):
        banned_keywords_list = litellm.banned_keywords_list

        if banned_keywords_list is None:
            raise Exception(
                "`banned_keywords_list` can either be a list or filepath. None set."
            )

        if isinstance(banned_keywords_list, list):
            self.banned_keywords_list = banned_keywords_list

        if isinstance(banned_keywords_list, str):  # assume it's a filepath
            try:
                with open(banned_keywords_list, "r") as file:
                    data = file.read()
                    self.banned_keywords_list = data.split("\n")
            except FileNotFoundError:
                raise Exception(
                    f"File not found. banned_keywords_list={banned_keywords_list}"
                )
            except Exception as e:
                raise Exception(
                    f"An error occurred: {str(e)}, banned_keywords_list={banned_keywords_list}"
                )

    def print_verbose(self, print_statement, level: Literal["INFO", "DEBUG"] = "DEBUG"):
        if level == "INFO":
            verbose_proxy_logger.info(print_statement)
        elif level == "DEBUG":
            verbose_proxy_logger.debug(print_statement)

        if litellm.set_verbose is True:
            print(print_statement)  # noqa

    def test_violation(self, test_str: str):
        for word in self.banned_keywords_list:
            if word in test_str.lower():
                raise HTTPException(
                    status_code=400,
                    detail={"error": f"Keyword banned. Keyword={word}"},
                )

    async def async_pre_call_hook(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        cache: DualCache,
        data: dict,
        call_type: str,  # "completion", "embeddings", "image_generation", "moderation"
    ):
        try:
            """
            - check if user id part of call
            - check if user id part of blocked list
            """
            self.print_verbose("Inside Banned Keyword List Pre-Call Hook")
            if call_type == "completion" and "messages" in data:
                for m in data["messages"]:
                    if "content" in m and isinstance(m["content"], str):
                        self.test_violation(test_str=m["content"])

        except HTTPException as e:
            raise e
        except Exception as e:
            verbose_proxy_logger.exception(
                "litellm.enterprise.enterprise_hooks.banned_keywords::async_pre_call_hook - Exception occurred - {}".format(
                    str(e)
                )
            )

    async def async_post_call_success_hook(
        self,
        data: dict,
        user_api_key_dict: UserAPIKeyAuth,
        response,
    ):
        if isinstance(response, litellm.ModelResponse) and isinstance(
            response.choices[0], litellm.utils.Choices
        ):
            for word in self.banned_keywords_list:
                self.test_violation(test_str=response.choices[0].message.content or "")

    async def async_post_call_streaming_hook(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        response: str,
    ):
        self.test_violation(test_str=response)