import re
from typing import Literal, Optional, Union

import litellm
from litellm._logging import verbose_proxy_logger
from litellm.caching.caching import DualCache
from litellm.integrations.custom_guardrail import (
    CustomGuardrail,
    log_guardrail_information,
)
from litellm.proxy._types import UserAPIKeyAuth


class myCustomGuardrail(CustomGuardrail):
    def __init__(
        self,
        **kwargs,
    ):
        # keep any guardrail-specific params passed in from the proxy config
        self.optional_params = kwargs
        super().__init__(**kwargs)
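
    # Note (assumption): extra keys set under `litellm_params` in the proxy
    # config (e.g. a hypothetical `banned_word`) should arrive here through
    # **kwargs and land in self.optional_params.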

    @log_guardrail_information
    async def async_pre_call_hook(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        cache: DualCache,
        data: dict,
        call_type: Literal[
            "completion",
            "text_completion",
            "embeddings",
            "image_generation",
            "moderation",
            "audio_transcription",
            "pass_through_endpoint",
            "rerank",
        ],
    ) -> Optional[Union[Exception, str, dict]]:
        """
        Runs before the LLM API call, on the input only.

        Use this hook if you want to MODIFY the input.
        """

        # if the user input contains the word "litellm" (any casing),
        # mask it before the request is forwarded to the LLM
        _messages = data.get("messages")
        if _messages:
            for message in _messages:
                _content = message.get("content")
                if isinstance(_content, str) and "litellm" in _content.lower():
                    # case-insensitive substitution so "LiteLLM" etc. is
                    # masked too, matching the check above
                    message["content"] = re.sub(
                        "litellm", "********", _content, flags=re.IGNORECASE
                    )

        verbose_proxy_logger.debug(
            "async_pre_call_hook: messages after masking %s", _messages
        )

        return data
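
    # Illustrative sketch: given
    #   data = {"messages": [{"role": "user", "content": "Hi LiteLLM"}]}
    # the hook above returns the same dict with the content masked to
    #   "Hi ********"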

    @log_guardrail_information
    async def async_moderation_hook(
        self,
        data: dict,
        user_api_key_dict: UserAPIKeyAuth,
        call_type: Literal[
            "completion",
            "embeddings",
            "image_generation",
            "moderation",
            "audio_transcription",
        ],
    ):
        """
        Runs in parallel to the LLM API call, on the input only.

        This hook can NOT modify the input; it is only used to reject or
        accept a call before it goes to the LLM API.
        """

        # same trigger as async_pre_call_hook, but instead of masking the
        # word we reject the whole request by raising an exception
        _messages = data.get("messages")
        if _messages:
            for message in _messages:
                _content = message.get("content")
                if isinstance(_content, str) and "litellm" in _content.lower():
                    raise ValueError("Guardrail failed words - `litellm` detected")
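
    # Illustrative sketch: the same "Hi LiteLLM" input would not be masked
    # here; the raised ValueError rejects the request outright instead.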

    @log_guardrail_information
    async def async_post_call_success_hook(
        self,
        data: dict,
        user_api_key_dict: UserAPIKeyAuth,
        response,
    ):
        """
        Runs on the response from the LLM API call.

        It can be used to reject a response: if the response contains the
        word "coffee", we raise an exception.
        """
        verbose_proxy_logger.debug(
            "async_post_call_success_hook response: %s", response
        )
        if isinstance(response, litellm.ModelResponse):
            for choice in response.choices:
                if isinstance(choice, litellm.Choices):
                    verbose_proxy_logger.debug(
                        "async_post_call_success_hook choice: %s", choice
                    )
                    if (
                        choice.message.content
                        and isinstance(choice.message.content, str)
                        and "coffee" in choice.message.content
                    ):
                        raise ValueError("Guardrail failed Coffee Detected")
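

# A minimal sketch of how this guardrail might be wired into the proxy,
# assuming the standard LiteLLM guardrails config format; the guardrail_name
# values are arbitrary, and the module path assumes this file is saved as
# custom_guardrail.py next to your config.yaml:
#
#   guardrails:
#     - guardrail_name: "custom-pre-guard"
#       litellm_params:
#         guardrail: custom_guardrail.myCustomGuardrail
#         mode: "pre_call"      # runs async_pre_call_hook
#     - guardrail_name: "custom-during-guard"
#       litellm_params:
#         guardrail: custom_guardrail.myCustomGuardrail
#         mode: "during_call"   # runs async_moderation_hook
#     - guardrail_name: "custom-post-guard"
#       litellm_params:
#         guardrail: custom_guardrail.myCustomGuardrail
#         mode: "post_call"     # runs async_post_call_success_hook

# Quick local smoke test, a sketch only: it assumes UserAPIKeyAuth and
# DualCache can be constructed with default arguments, which may vary
# across litellm versions.
if __name__ == "__main__":
    import asyncio

    guardrail = myCustomGuardrail()
    masked = asyncio.run(
        guardrail.async_pre_call_hook(
            user_api_key_dict=UserAPIKeyAuth(),
            cache=DualCache(),
            data={"messages": [{"role": "user", "content": "Hi LiteLLM"}]},
            call_type="completion",
        )
    )
    print(masked)  # expect the content masked to "Hi ********"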