File size: 3,821 Bytes
e3278e4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 |
# +------------------------------+
#
# Banned Keywords
#
# +------------------------------+
# Thank you users! We ❤️ you! - Krrish & Ishaan
## Reject a call / response if it contains certain keywords
from typing import Literal
import litellm
from litellm.caching.caching import DualCache
from litellm.proxy._types import UserAPIKeyAuth
from litellm.integrations.custom_logger import CustomLogger
from litellm._logging import verbose_proxy_logger
from fastapi import HTTPException
class _ENTERPRISE_BannedKeywords(CustomLogger):
    """Reject a call / response if its text contains a banned keyword.

    The keyword list is read once, at construction time, from
    ``litellm.banned_keywords_list``. Matching is a case-insensitive
    substring test (see ``test_violation``).
    """

    def __init__(self):
        """Load the banned keywords from ``litellm.banned_keywords_list``.

        The setting may be a list of keywords, or a string that is treated
        as the path of a text file holding one keyword per line.

        Raises:
            Exception: if the setting is ``None``, is neither a list nor a
                string, or the file cannot be found / read.
        """
        banned_keywords_list = litellm.banned_keywords_list
        if banned_keywords_list is None:
            raise Exception(
                "`banned_keywords_list` can either be a list or filepath. None set."
            )

        if isinstance(banned_keywords_list, list):
            self.banned_keywords_list = banned_keywords_list
        elif isinstance(banned_keywords_list, str):  # assume it's a filepath
            try:
                # Explicit encoding so non-ASCII keywords are read the same
                # way regardless of the platform's locale default.
                with open(banned_keywords_list, "r", encoding="utf-8") as file:
                    data = file.read()
            except FileNotFoundError as e:
                raise Exception(
                    f"File not found. banned_keywords_list={banned_keywords_list}"
                ) from e
            except Exception as e:
                raise Exception(
                    f"An error occurred: {str(e)}, banned_keywords_list={banned_keywords_list}"
                ) from e
            # Drop blank lines (e.g. the one produced by a trailing newline):
            # an empty keyword is a substring of every string and would make
            # every request and response fail the check.
            self.banned_keywords_list = [
                line for line in data.split("\n") if line.strip()
            ]
        else:
            # Fail at construction instead of leaving the attribute unset and
            # hitting an AttributeError on the first checked request.
            raise Exception(
                "`banned_keywords_list` can either be a list or filepath. "
                f"Got type={type(banned_keywords_list).__name__}"
            )

    def print_verbose(self, print_statement, level: Literal["INFO", "DEBUG"] = "DEBUG"):
        """Log ``print_statement`` on the proxy logger at the given level.

        The statement is additionally printed to stdout when
        ``litellm.set_verbose`` is ``True``.
        """
        if level == "INFO":
            verbose_proxy_logger.info(print_statement)
        elif level == "DEBUG":
            verbose_proxy_logger.debug(print_statement)
        if litellm.set_verbose is True:
            print(print_statement)  # noqa

    def test_violation(self, test_str: str):
        """Raise if ``test_str`` contains any banned keyword.

        Both the text and each keyword are lowercased before the substring
        test, so matching is case-insensitive in both directions.

        Args:
            test_str: the text to scan.

        Raises:
            HTTPException: status 400, naming the first matching keyword
                as it appears in the configured list.
        """
        lowered_text = test_str.lower()  # hoisted: invariant across keywords
        for word in self.banned_keywords_list:
            # Empty or non-string entries cannot be meaningful keywords; an
            # empty string in particular would match every input.
            if not isinstance(word, str) or not word:
                continue
            if word.lower() in lowered_text:
                raise HTTPException(
                    status_code=400,
                    detail={"error": f"Keyword banned. Keyword={word}"},
                )

    async def async_pre_call_hook(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        cache: DualCache,
        data: dict,
        call_type: str,  # "completion", "embeddings", "image_generation", "moderation"
    ):
        """Check the messages of a completion request for banned keywords.

        Only ``call_type == "completion"`` requests carrying a ``"messages"``
        entry are inspected, and only messages whose ``"content"`` is a
        plain string; any other content shape is skipped.

        Raises:
            HTTPException: propagated from ``test_violation`` on a match.
                Any other exception is logged and suppressed, so an
                unexpected payload does not fail the request.
        """
        try:
            self.print_verbose("Inside Banned Keyword List Pre-Call Hook")
            if call_type == "completion" and "messages" in data:
                for m in data["messages"]:
                    if "content" in m and isinstance(m["content"], str):
                        self.test_violation(test_str=m["content"])
        except HTTPException:
            # A keyword violation must reach the caller unchanged.
            raise
        except Exception as e:
            verbose_proxy_logger.exception(
                "litellm.enterprise.enterprise_hooks.banned_keywords::async_pre_call_hook - Exception occurred - {}".format(
                    str(e)
                )
            )

    async def async_post_call_success_hook(
        self,
        data: dict,
        user_api_key_dict: UserAPIKeyAuth,
        response,
    ):
        """Check the first choice of a model response for banned keywords.

        Responses that are not a ``litellm.ModelResponse``, have no choices,
        or whose first choice is not a ``litellm.utils.Choices`` are ignored.
        A missing message content is treated as the empty string.

        Raises:
            HTTPException: propagated from ``test_violation`` on a match.
        """
        if not isinstance(response, litellm.ModelResponse) or not response.choices:
            return
        # NOTE(review): only the first choice is inspected; additional
        # choices (n > 1) are not scanned — confirm this is intended.
        first_choice = response.choices[0]
        if isinstance(first_choice, litellm.utils.Choices):
            # test_violation already iterates over every keyword, so a single
            # call is enough.
            self.test_violation(test_str=first_choice.message.content or "")

    async def async_post_call_streaming_hook(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        response: str,
    ):
        """Check a streamed response string for banned keywords.

        Raises:
            HTTPException: propagated from ``test_violation`` on a match.
        """
        self.test_violation(test_str=response)
|