# +-------------------------------------------------------------+
#
# Use GuardrailsAI for your LLM calls
#
# +-------------------------------------------------------------+
# Thank you for using Litellm! - Krrish & Ishaan

import json
from typing import Optional, TypedDict

from fastapi import HTTPException

import litellm
from litellm._logging import verbose_proxy_logger
from litellm.integrations.custom_guardrail import (
    CustomGuardrail,
    log_guardrail_information,
)
from litellm.litellm_core_utils.prompt_templates.common_utils import (
    get_content_from_model_response,
)
from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.common_utils.callback_utils import (
    add_guardrail_to_applied_guardrails_header,
)
from litellm.types.guardrails import GuardrailEventHooks


class GuardrailsAIResponse(TypedDict):
    """Shape of the JSON returned by the Guardrails AI guard's validate endpoint."""

    callId: str
    rawLlmOutput: str
    validatedOutput: str
    validationPassed: bool
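
# For reference, a payload matching the TypedDict above might look like the
# following (values are illustrative, not taken from the Guardrails AI docs):
#
# _example: GuardrailsAIResponse = {
#     "callId": "call-123",
#     "rawLlmOutput": "original model text",
#     "validatedOutput": "validated model text",
#     "validationPassed": True,
# }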


class GuardrailsAI(CustomGuardrail):
    def __init__(
        self,
        guard_name: str,
        api_base: Optional[str] = None,
        **kwargs,
    ):
        if guard_name is None:
            raise Exception(
                "GuardrailsAIException - Please pass the Guardrails AI guard name via 'litellm_params::guard_name'"
            )
        # store kwargs as optional_params
        self.guardrails_ai_api_base = api_base or "http://0.0.0.0:8000"
        self.guardrails_ai_guard_name = guard_name
        self.optional_params = kwargs
        # this guardrail validates LLM output, so only the post_call hook is supported
        supported_event_hooks = [GuardrailEventHooks.post_call]
        super().__init__(supported_event_hooks=supported_event_hooks, **kwargs)
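
    # The parameters above are normally supplied from the proxy config. A sketch
    # of what that YAML might look like (the guard name, mode, and api_base below
    # are illustrative values, not defaults):
    #
    #   guardrails:
    #     - guardrail_name: "guardrails_ai-guard"
    #       litellm_params:
    #         guardrail: guardrails_ai
    #         guard_name: "gibberish_guard"
    #         mode: "post_call"
    #         api_base: "http://0.0.0.0:8000"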

    async def make_guardrails_ai_api_request(self, llm_output: str, request_data: dict):
        """
        POST the LLM output to the Guardrails AI guard's validate endpoint.

        Raises an HTTPException (400) if the guard reports that validation failed.
        """
        from httpx import URL

        data = {
            "llmOutput": llm_output,
            **self.get_guardrail_dynamic_request_body_params(request_data=request_data),
        }
        _json_data = json.dumps(data)
        response = await litellm.module_level_aclient.post(
            url=str(
                URL(self.guardrails_ai_api_base).join(
                    f"guards/{self.guardrails_ai_guard_name}/validate"
                )
            ),
            data=_json_data,
            headers={
                "Content-Type": "application/json",
            },
        )
        verbose_proxy_logger.debug("guardrails_ai response: %s", response)
        _json_response = GuardrailsAIResponse(**response.json())  # type: ignore
        if _json_response.get("validationPassed") is False:
            raise HTTPException(
                status_code=400,
                detail={
                    "error": "Violated guardrail policy",
                    "guardrails_ai_response": _json_response,
                },
            )
        return _json_response
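
    # For illustration, the request above is a plain JSON POST of the form
    # (guard name and output text are placeholder values):
    #
    #   POST http://0.0.0.0:8000/guards/gibberish_guard/validate
    #   Content-Type: application/json
    #
    #   {"llmOutput": "some model output"}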

    @log_guardrail_information
    async def async_post_call_success_hook(
        self,
        data: dict,
        user_api_key_dict: UserAPIKeyAuth,
        response,
    ):
        """
        Runs on the response from the LLM API call.

        It can be used to reject a response.
        """
        event_type: GuardrailEventHooks = GuardrailEventHooks.post_call
        if self.should_run_guardrail(data=data, event_type=event_type) is not True:
            return

        # only validate chat/text completion responses
        if not isinstance(response, litellm.ModelResponse):
            return

        response_str: str = get_content_from_model_response(response)
        if response_str is not None and len(response_str) > 0:
            await self.make_guardrails_ai_api_request(
                llm_output=response_str, request_data=data
            )

            add_guardrail_to_applied_guardrails_header(
                request_data=data, guardrail_name=self.guardrail_name
            )

        return
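

if __name__ == "__main__":
    # Minimal manual smoke test, assuming a Guardrails AI server is running
    # locally with a guard named "gibberish_guard" (both the URL and the guard
    # name below are assumptions for illustration, not defaults of this module).
    import asyncio

    async def _demo() -> None:
        guardrail = GuardrailsAI(
            guard_name="gibberish_guard",
            api_base="http://0.0.0.0:8000",
        )
        # calls the guard's validate endpoint directly; raises HTTPException(400)
        # if the guard rejects the text
        result = await guardrail.make_guardrails_ai_api_request(
            llm_output="Hello, how can I help you today?",
            request_data={},
        )
        print(result)

    asyncio.run(_demo())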