# +-------------------------------------------------------------+
#
# Use GuardrailsAI for your LLM calls
#
# +-------------------------------------------------------------+
# Thank you for using Litellm! - Krrish & Ishaan
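#
# Example proxy config for wiring this guardrail up (an illustrative sketch,
# not taken from this file -- guard_name / api_base map to the __init__
# arguments below and the mode matches the supported post_call hook; adjust
# the names to your deployment):
#
#   guardrails:
#     - guardrail_name: "guardrails_ai-guard"
#       litellm_params:
#         guardrail: guardrails_ai
#         guard_name: "gibberish_guard"      # guard configured on your Guardrails AI server
#         mode: "post_call"
#         api_base: "http://0.0.0.0:8000"    # where the Guardrails AI server is listening
#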
import json
from typing import Optional, TypedDict
from fastapi import HTTPException
import litellm
from litellm._logging import verbose_proxy_logger
from litellm.integrations.custom_guardrail import (
    CustomGuardrail,
    log_guardrail_information,
)
from litellm.litellm_core_utils.prompt_templates.common_utils import (
    get_content_from_model_response,
)
from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.common_utils.callback_utils import (
    add_guardrail_to_applied_guardrails_header,
)
from litellm.types.guardrails import GuardrailEventHooks


class GuardrailsAIResponse(TypedDict):
    callId: str
    rawLlmOutput: str
    validatedOutput: str
    validationPassed: bool
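
# Shape of a Guardrails AI validate response as parsed below -- a sketch that
# simply mirrors the TypedDict fields; an actual server may return extra keys:
#   {"callId": "...", "rawLlmOutput": "...", "validatedOutput": "...", "validationPassed": true}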


class GuardrailsAI(CustomGuardrail):
    """Post-call guardrail that validates LLM output against a Guardrails AI guard."""

    def __init__(
        self,
        guard_name: str,
        api_base: Optional[str] = None,
        **kwargs,
    ):
        if guard_name is None:
            raise Exception(
                "GuardrailsAIException - Please pass the Guardrails AI guard name via 'litellm_params::guard_name'"
            )
        # store kwargs as optional_params
        self.guardrails_ai_api_base = api_base or "http://0.0.0.0:8000"
        self.guardrails_ai_guard_name = guard_name
        self.optional_params = kwargs
        supported_event_hooks = [GuardrailEventHooks.post_call]
        super().__init__(supported_event_hooks=supported_event_hooks, **kwargs)

    async def make_guardrails_ai_api_request(self, llm_output: str, request_data: dict):
        """
        POST the LLM output to the Guardrails AI server's
        guards/{guard_name}/validate endpoint.

        Raises an HTTP 400 error if the guard reports that validation failed.
        """
        from httpx import URL

        data = {
            "llmOutput": llm_output,
            **self.get_guardrail_dynamic_request_body_params(request_data=request_data),
        }
        _json_data = json.dumps(data)
        response = await litellm.module_level_aclient.post(
            url=str(
                URL(self.guardrails_ai_api_base).join(
                    f"guards/{self.guardrails_ai_guard_name}/validate"
                )
            ),
            data=_json_data,
            headers={
                "Content-Type": "application/json",
            },
        )
        verbose_proxy_logger.debug("guardrails_ai response: %s", response)
        _json_response = GuardrailsAIResponse(**response.json())  # type: ignore
        if _json_response.get("validationPassed") is False:
            raise HTTPException(
                status_code=400,
                detail={
                    "error": "Violated guardrail policy",
                    "guardrails_ai_response": _json_response,
                },
            )
        return _json_response

    @log_guardrail_information
    async def async_post_call_success_hook(
        self,
        data: dict,
        user_api_key_dict: UserAPIKeyAuth,
        response,
    ):
        """
        Runs on response from LLM API call

        It can be used to reject a response
        """
        event_type: GuardrailEventHooks = GuardrailEventHooks.post_call
        if self.should_run_guardrail(data=data, event_type=event_type) is not True:
            return

        if not isinstance(response, litellm.ModelResponse):
            return

        response_str: str = get_content_from_model_response(response)
        if response_str is not None and len(response_str) > 0:
            await self.make_guardrails_ai_api_request(
                llm_output=response_str, request_data=data
            )

            add_guardrail_to_applied_guardrails_header(
                request_data=data, guardrail_name=self.guardrail_name
            )

        return
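
# Example request that would exercise this guardrail once it is registered on
# the proxy (illustrative sketch; the "guardrails" request field and the guard
# name come from your proxy configuration, not from this module):
#
#   curl http://localhost:4000/v1/chat/completions \
#     -H "Authorization: Bearer sk-..." \
#     -H "Content-Type: application/json" \
#     -d '{
#           "model": "gpt-4o",
#           "messages": [{"role": "user", "content": "hi"}],
#           "guardrails": ["guardrails_ai-guard"]
#         }'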