from typing import Literal, Optional, Union

import litellm
from litellm._logging import verbose_proxy_logger
from litellm.caching.caching import DualCache
from litellm.integrations.custom_guardrail import (
    CustomGuardrail,
    log_guardrail_information,
)
from litellm.proxy._types import UserAPIKeyAuth


class myCustomGuardrail(CustomGuardrail):
    def __init__(
        self,
        **kwargs,
    ):
        # store kwargs as optional_params
        self.optional_params = kwargs

        super().__init__(**kwargs)

    @log_guardrail_information
    async def async_pre_call_hook(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        cache: DualCache,
        data: dict,
        call_type: Literal[
            "completion",
            "text_completion",
            "embeddings",
            "image_generation",
            "moderation",
            "audio_transcription",
            "pass_through_endpoint",
            "rerank",
        ],
    ) -> Optional[Union[Exception, str, dict]]:
        """
        Runs before the LLM API call.
        Runs on input only.

        Use this if you want to MODIFY the input.
        """

        # In this guardrail, if a user inputs `litellm` we mask it before
        # the request is sent to the LLM.
        _messages = data.get("messages")
        if _messages:
            for message in _messages:
                _content = message.get("content")
                if isinstance(_content, str):
                    if "litellm" in _content.lower():
                        _content = _content.replace("litellm", "********")
                        message["content"] = _content

        verbose_proxy_logger.debug(
            "async_pre_call_hook: Message after masking %s", _messages
        )

        return data

    @log_guardrail_information
    async def async_moderation_hook(
        self,
        data: dict,
        user_api_key_dict: UserAPIKeyAuth,
        call_type: Literal[
            "completion",
            "embeddings",
            "image_generation",
            "moderation",
            "audio_transcription",
        ],
    ):
        """
        Runs in parallel to the LLM API call.
        Runs on input only.

        This can NOT modify the input; it is only used to reject or accept
        a call before it reaches the LLM API.
        """

        # This works like async_pre_call_hook, but runs in parallel with the
        # LLM API call. Here, if a user inputs `litellm`, we reject the call.
        _messages = data.get("messages")
        if _messages:
            for message in _messages:
                _content = message.get("content")
                if isinstance(_content, str):
                    if "litellm" in _content.lower():
                        raise ValueError("Guardrail failed words - `litellm` detected")

    @log_guardrail_information
    async def async_post_call_success_hook(
        self,
        data: dict,
        user_api_key_dict: UserAPIKeyAuth,
        response,
    ):
        """
        Runs on the response from the LLM API call.

        It can be used to reject a response. If a response contains the word
        "coffee", we raise an exception.
        """
        verbose_proxy_logger.debug(
            "async_post_call_success_hook response: %s", response
        )

        if isinstance(response, litellm.ModelResponse):
            for choice in response.choices:
                if isinstance(choice, litellm.Choices):
                    verbose_proxy_logger.debug(
                        "async_post_call_success_hook choice: %s", choice
                    )
                    if (
                        choice.message.content
                        and isinstance(choice.message.content, str)
                        and "coffee" in choice.message.content
                    ):
                        raise ValueError("Guardrail failed Coffee Detected")
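
# ---------------------------------------------------------------------------
# Usage notes (illustrative, not part of the guardrail itself).
#
# To register this guardrail on the LiteLLM proxy, a config.yaml entry along
# these lines is typically used -- the module path ("custom_guardrail") is an
# assumption and should point at wherever this file actually lives:
#
#   guardrails:
#     - guardrail_name: "custom-guard"
#       litellm_params:
#         guardrail: custom_guardrail.myCustomGuardrail
#         mode: "pre_call"  # or "during_call" / "post_call"
#
# The block below is a minimal local-testing sketch that exercises
# async_pre_call_hook directly; in production the proxy invokes these hooks
# itself. The sample payload and the no-argument constructors are assumptions
# made purely for illustration.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import asyncio

    async def _demo() -> None:
        guardrail = myCustomGuardrail()
        data = {
            "messages": [{"role": "user", "content": "Tell me about litellm"}]
        }
        masked = await guardrail.async_pre_call_hook(
            user_api_key_dict=UserAPIKeyAuth(),
            cache=DualCache(),
            data=data,
            call_type="completion",
        )
        assert isinstance(masked, dict)
        # The hook masks the word "litellm", so this prints:
        #   Tell me about ********
        print(masked["messages"][0]["content"])

    asyncio.run(_demo())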