# +-----------------------------------------------+
#
#   Google Text Moderation
#   https://cloud.google.com/natural-language/docs/moderating-text
#
# +-----------------------------------------------+
#  Thank you users! We ❤️ you! - Krrish & Ishaan

from typing import Literal

import litellm
from litellm.proxy._types import UserAPIKeyAuth
from litellm.integrations.custom_logger import CustomLogger
from fastapi import HTTPException
from litellm._logging import verbose_proxy_logger


class _ENTERPRISE_GoogleTextModeration(CustomLogger):
    # Class variables or attributes
    user_api_key_cache = None
    confidence_categories = [
        "toxic",
        "insult",
        "profanity",
        "derogatory",
        "sexual",
        "death_harm_and_tragedy",
        "violent",
        "firearms_and_weapons",
        "public_safety",
        "health",
        "religion_and_belief",
        "illicit_drugs",
        "war_and_conflict",
        "politics",
        "finance",
        "legal",
    ]  # https://cloud.google.com/natural-language/docs/moderating-text#safety_attribute_confidence_scores

    def __init__(self):
        try:
            from google.cloud import language_v1  # type: ignore
        except Exception:
            raise Exception(
                "Missing google.cloud package. Run `pip install --upgrade google-cloud-language`"
            )

        # Instantiate a client
        self.client = language_v1.LanguageServiceClient()
        self.moderate_text_request = language_v1.ModerateTextRequest
        self.language_document = language_v1.types.Document  # type: ignore
        self.document_type = language_v1.types.Document.Type.PLAIN_TEXT  # type: ignore

        default_confidence_threshold = (
            litellm.google_moderation_confidence_threshold or 0.8
        )  # by default, require a high confidence (80%) to fail

        for category in self.confidence_categories:
            # Prefer a per-category threshold set on the litellm module,
            # falling back to the shared default.
            if hasattr(litellm, f"{category}_confidence_threshold"):
                setattr(
                    self,
                    f"{category}_confidence_threshold",
                    getattr(litellm, f"{category}_confidence_threshold"),
                )
            else:
                setattr(
                    self,
                    f"{category}_confidence_threshold",
                    default_confidence_threshold,
                )
            set_confidence_value = getattr(
                self,
                f"{category}_confidence_threshold",
            )
            verbose_proxy_logger.info(
                f"Google Text Moderation: {category}_confidence_threshold: {set_confidence_value}"
            )

    def print_verbose(self, print_statement):
        try:
            verbose_proxy_logger.debug(print_statement)
            if litellm.set_verbose:
                print(print_statement)  # noqa
        except Exception:
            pass

    async def async_moderation_hook(
        self,
        data: dict,
        user_api_key_dict: UserAPIKeyAuth,
        call_type: Literal[
            "completion",
            "embeddings",
            "image_generation",
            "moderation",
            "audio_transcription",
        ],
    ):
        """
        - Calls Google's Text Moderation API
        - Rejects the request if it fails the safety check
        """
        text = ""  # initialized here so requests without messages don't raise a NameError below
        if "messages" in data and isinstance(data["messages"], list):
            for m in data["messages"]:  # assume messages is a list
                if "content" in m and isinstance(m["content"], str):
                    text += m["content"]

        document = self.language_document(content=text, type_=self.document_type)
        request = self.moderate_text_request(
            document=document,
        )

        # Make the request
        response = self.client.moderate_text(request=request)
        for category in response.moderation_categories:
            category_name = category.name
            category_name = category_name.lower()
            category_name = category_name.replace("&", "and")
            category_name = category_name.replace(",", "")
            category_name = category_name.replace(
                " ", "_"
            )  # e.g. go from 'Firearms & Weapons' to 'firearms_and_weapons'
            if category.confidence > getattr(
                self, f"{category_name}_confidence_threshold"
            ):
                raise HTTPException(
                    status_code=400,
                    detail={
                        "error": f"Violated content safety policy. Category={category}"
                    },
                )

        # Handle the response
        return data


# import asyncio
# google_text_moderation_obj = _ENTERPRISE_GoogleTextModeration()
# asyncio.run(
#     google_text_moderation_obj.async_moderation_hook(
#         data={"messages": [{"role": "user", "content": "Hey, how's it going?"}]},
#         user_api_key_dict=UserAPIKeyAuth(),  # the hook signature also requires these two arguments
#         call_type="completion",
#     )
# )
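
# A minimal configuration sketch, assuming thresholds are set on the `litellm`
# module before the proxy constructs this hook (__init__ reads them via
# getattr at construction time). The 0.6 and 0.3 values below are illustrative,
# not shipped defaults:
#
# import litellm
# litellm.google_moderation_confidence_threshold = 0.6  # fallback for all categories
# litellm.violent_confidence_threshold = 0.3  # stricter override for a single category
# guardrail = _ENTERPRISE_GoogleTextModeration()  # logs each resolved threshold at info level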