File size: 5,046 Bytes
e3278e4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 |
# +-----------------------------------------------+
#
# Google Text Moderation
# https://cloud.google.com/natural-language/docs/moderating-text
#
# +-----------------------------------------------+
# Thank you users! We ❤️ you! - Krrish & Ishaan
import asyncio
import functools
from typing import Literal

from fastapi import HTTPException

import litellm
from litellm.proxy._types import UserAPIKeyAuth
from litellm.integrations.custom_logger import CustomLogger
from litellm._logging import verbose_proxy_logger
class _ENTERPRISE_GoogleTextModeration(CustomLogger):
    """Proxy moderation hook that screens chat messages with Google's Text Moderation API.

    Every moderation category returned by Google is compared against a
    confidence threshold. If any category's confidence is strictly greater than
    its threshold, the request is rejected with an HTTP 400.

    Thresholds are resolved once, at construction time, from ``litellm``:
    1. ``litellm.<category>_confidence_threshold``, when present and not ``None``;
    2. otherwise ``litellm.google_moderation_confidence_threshold``, when present
       and not ``None``;
    3. otherwise ``0.8``.
    """

    user_api_key_cache = None

    # Normalised names of Google's moderation categories. Each one gets a
    # ``<name>_confidence_threshold`` attribute on the instance.
    confidence_categories = [
        "toxic",
        "insult",
        "profanity",
        "derogatory",
        "sexual",
        "death_harm_and_tragedy",
        "violent",
        "firearms_and_weapons",
        "public_safety",
        "health",
        "religion_and_belief",
        "illicit_drugs",
        "war_and_conflict",
        "politics",
        "finance",
        "legal",
    ]  # https://cloud.google.com/natural-language/docs/moderating-text#safety_attribute_confidence_scores

    # Used when no global threshold is configured. By default a high
    # confidence (80%) is required to fail a request.
    _FALLBACK_CONFIDENCE_THRESHOLD = 0.8

    def __init__(self):
        """Create the Google client and resolve per-category thresholds.

        Raises:
            ImportError: if ``google.cloud.language_v1`` cannot be imported.
        """
        try:
            from google.cloud import language_v1  # type: ignore
        except ImportError as e:
            raise ImportError(
                "Missing google.cloud package. Run `pip install --upgrade google-cloud-language`"
            ) from e

        # Instantiates a client
        self.client = language_v1.LanguageServiceClient()
        self.moderate_text_request = language_v1.ModerateTextRequest
        self.language_document = language_v1.types.Document  # type: ignore
        self.document_type = language_v1.types.Document.Type.PLAIN_TEXT  # type: ignore

        # Use `is None` rather than `or` so an explicitly configured 0.0 is honoured.
        default_confidence_threshold = getattr(
            litellm, "google_moderation_confidence_threshold", None
        )
        if default_confidence_threshold is None:
            default_confidence_threshold = self._FALLBACK_CONFIDENCE_THRESHOLD
        # Kept so that categories missing from `confidence_categories` can still be checked.
        self._default_confidence_threshold = default_confidence_threshold

        for category in self.confidence_categories:
            attr_name = f"{category}_confidence_threshold"
            # A per-category setting of None means "not configured". Storing None
            # would make the `>` comparison in the hook raise TypeError.
            threshold = getattr(litellm, attr_name, None)
            if threshold is None:
                threshold = default_confidence_threshold
            setattr(self, attr_name, threshold)
            verbose_proxy_logger.info(
                f"Google Text Moderation: {category}_confidence_threshold: {threshold}"
            )

    def print_verbose(self, print_statement):
        """Log at debug level and, if ``litellm.set_verbose`` is set, also print.

        Deliberately best-effort: any failure while logging is swallowed.
        """
        try:
            verbose_proxy_logger.debug(print_statement)
            if litellm.set_verbose:
                print(print_statement)  # noqa
        except Exception:
            pass

    @staticmethod
    def _extract_text(messages: list) -> str:
        """Return the text of all messages, one message per line.

        Accepts both plain-string ``content`` and OpenAI-style lists of content
        parts. Only ``{"type": "text", "text": <str>}`` parts contribute, so
        text sent as a content-part list is not silently skipped by moderation.
        """
        chunks = []
        for message in messages:
            if not isinstance(message, dict):
                continue
            content = message.get("content")
            if isinstance(content, str):
                chunks.append(content)
            elif isinstance(content, list):
                for part in content:
                    if (
                        isinstance(part, dict)
                        and part.get("type") == "text"
                        and isinstance(part.get("text"), str)
                    ):
                        chunks.append(part["text"])
        # Separate messages so that words at message boundaries are not fused together.
        return "\n".join(chunks)

    @staticmethod
    def _normalize_category_name(name: str) -> str:
        """Convert a Google category label to its attribute-name form.

        For example, 'Firearms & Weapons' becomes 'firearms_and_weapons' and
        'Death, Harm & Tragedy' becomes 'death_harm_and_tragedy'.
        """
        return (
            name.lower().replace("&", "and").replace(",", "").replace(" ", "_")
        )

    def _get_confidence_threshold(self, category_name: str) -> float:
        """Return the rejection threshold for a normalised category name.

        A category not listed in ``confidence_categories`` falls back to the
        default threshold instead of raising ``AttributeError``.
        """
        default = getattr(
            self, "_default_confidence_threshold", self._FALLBACK_CONFIDENCE_THRESHOLD
        )
        return getattr(self, f"{category_name}_confidence_threshold", default)

    async def async_moderation_hook(
        self,
        data: dict,
        user_api_key_dict: UserAPIKeyAuth,
        call_type: Literal[
            "completion",
            "embeddings",
            "image_generation",
            "moderation",
            "audio_transcription",
        ],
    ):
        """
        - Calls Google's Text Moderation API
        - Rejects request if it fails safety check

        Args:
            data: the incoming request body. Only ``data["messages"]`` is inspected.
            user_api_key_dict: the caller's key info (unused here).
            call_type: the kind of proxy call (unused here).

        Returns:
            ``data``, unchanged, when no category exceeds its threshold or there
            is no message text to moderate.

        Raises:
            HTTPException: with status 400 if any category's confidence exceeds
                its threshold.
        """
        messages = data.get("messages")
        if not isinstance(messages, list):
            # No chat messages means nothing to moderate.
            return data

        text = self._extract_text(messages)
        if not text.strip():
            # Do not send an empty document to the API.
            return data

        document = self.language_document(content=text, type_=self.document_type)
        request = self.moderate_text_request(document=document)

        # The Google client call is synchronous. Run it in the default executor
        # so the proxy's event loop is not blocked while waiting on the network.
        loop = asyncio.get_running_loop()
        response = await loop.run_in_executor(
            None, functools.partial(self.client.moderate_text, request=request)
        )

        for category in response.moderation_categories:
            category_name = self._normalize_category_name(category.name)
            if category.confidence > self._get_confidence_threshold(category_name):
                raise HTTPException(
                    status_code=400,
                    detail={
                        "error": f"Violated content safety policy. Category={category}"
                    },
                )
        # Handle the response
        return data
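# Example usage (async_moderation_hook requires user_api_key_dict and call_type):
# google_text_moderation_obj = _ENTERPRISE_GoogleTextModeration()
# asyncio.run(
#     google_text_moderation_obj.async_moderation_hook(
#         data={"messages": [{"role": "user", "content": "Hey, how's it going?"}]},
#         user_api_key_dict=UserAPIKeyAuth(),
#         call_type="completion",
#     )
# )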
|