# NOTE: the following lines are file-viewer page residue, not part of the module:
#   Raju2024's picture
#   Upload 1072 files
#   e3278e4 verified
#   raw / history / blame
#   3.82 kB
import re
from typing import Literal, Optional, Union

import litellm
from litellm._logging import verbose_proxy_logger
from litellm.caching.caching import DualCache
from litellm.integrations.custom_guardrail import (
    CustomGuardrail,
    log_guardrail_information,
)
from litellm.proxy._types import UserAPIKeyAuth
class myCustomGuardrail(CustomGuardrail):
    """Example guardrail showing three hook points.

    * ``async_pre_call_hook`` masks the word ``litellm`` in request messages.
    * ``async_moderation_hook`` rejects requests that contain ``litellm``.
    * ``async_post_call_success_hook`` rejects responses that contain
      ``coffee``.
    """

    # One shared, case-insensitive pattern so that detection and masking always
    # agree ("LiteLLM", "LITELLM", ... are all caught). re.ASCII restricts the
    # case folding to ASCII letters.
    _BLOCKED_WORD_PATTERN = re.compile(r"litellm", re.IGNORECASE | re.ASCII)
    _MASK = "********"

    def __init__(
        self,
        **kwargs,
    ):
        """Keep the raw keyword arguments and forward them to the base class."""
        # store kwargs as optional_params
        self.optional_params = kwargs

        super().__init__(**kwargs)

    @log_guardrail_information
    async def async_pre_call_hook(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        cache: DualCache,
        data: dict,
        call_type: Literal[
            "completion",
            "text_completion",
            "embeddings",
            "image_generation",
            "moderation",
            "audio_transcription",
            "pass_through_endpoint",
            "rerank",
        ],
    ) -> Optional[Union[Exception, str, dict]]:
        """
        Runs before the LLM API call
        Runs on only Input
        Use this if you want to MODIFY the input

        Every string ``content`` in ``data["messages"]`` has each occurrence of
        ``litellm`` (in any letter case) replaced by ``********``. The messages
        are modified in place and ``data`` is returned.
        """
        _messages = data.get("messages")
        if _messages:
            for message in _messages:
                _content = message.get("content")
                if not isinstance(_content, str):
                    # Non-string content (e.g. None or a list) is left untouched.
                    continue
                masked, hits = self._BLOCKED_WORD_PATTERN.subn(self._MASK, _content)
                if hits:
                    message["content"] = masked

        verbose_proxy_logger.debug(
            "async_pre_call_hook: Message after masking %s", _messages
        )

        return data

    @log_guardrail_information
    async def async_moderation_hook(
        self,
        data: dict,
        user_api_key_dict: UserAPIKeyAuth,
        call_type: Literal[
            "completion",
            "embeddings",
            "image_generation",
            "moderation",
            "audio_transcription",
        ],
    ):
        """
        Runs in parallel to LLM API call
        Runs on only Input
        This can NOT modify the input, only used to reject or accept a call before going to LLM API

        Raises:
            ValueError: if any string message content contains ``litellm``
                (in any letter case).
        """
        # Same detection as async_pre_call_hook, but this hook rejects the
        # request instead of masking the word.
        _messages = data.get("messages")
        if _messages:
            for message in _messages:
                _content = message.get("content")
                if isinstance(_content, str) and self._BLOCKED_WORD_PATTERN.search(
                    _content
                ):
                    raise ValueError("Guardrail failed words - `litellm` detected")

    @log_guardrail_information
    async def async_post_call_success_hook(
        self,
        data: dict,
        user_api_key_dict: UserAPIKeyAuth,
        response,
    ):
        """
        Runs on response from LLM API call
        It can be used to reject a response
        If a response contains the word "coffee" -> we will raise an exception

        The match on "coffee" is case-sensitive. Only ``litellm.ModelResponse``
        objects are inspected; any other response type passes through.

        Raises:
            ValueError: if any choice's message content contains ``coffee``.
        """
        verbose_proxy_logger.debug(
            "async_post_call_success_hook response: %s", response
        )
        if isinstance(response, litellm.ModelResponse):
            for choice in response.choices:
                if not isinstance(choice, litellm.Choices):
                    continue
                verbose_proxy_logger.debug(
                    "async_post_call_success_hook choice: %s", choice
                )
                if (
                    choice.message.content
                    and isinstance(choice.message.content, str)
                    and "coffee" in choice.message.content
                ):
                    raise ValueError("Guardrail failed Coffee Detected")