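"""
Use Guardrails AI for your LLM calls.

Post-call guardrail that sends the model's output to a Guardrails AI
server's validate endpoint and rejects responses that fail validation.
"""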
import json
from typing import Optional, TypedDict

from fastapi import HTTPException

import litellm
from litellm._logging import verbose_proxy_logger
from litellm.integrations.custom_guardrail import (
    CustomGuardrail,
    log_guardrail_information,
)
from litellm.litellm_core_utils.prompt_templates.common_utils import (
    get_content_from_model_response,
)
from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.common_utils.callback_utils import (
    add_guardrail_to_applied_guardrails_header,
)
from litellm.types.guardrails import GuardrailEventHooks


class GuardrailsAIResponse(TypedDict):
    """Response body returned by a Guardrails AI guard's validate endpoint."""

    callId: str
    rawLlmOutput: str
    validatedOutput: str
    validationPassed: bool
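
# An example validate response (illustrative values, not from a real server):
#
#   {
#       "callId": "call-123",
#       "rawLlmOutput": "raw model output",
#       "validatedOutput": "validated (possibly fixed) output",
#       "validationPassed": true
#   }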


class GuardrailsAI(CustomGuardrail):
    def __init__(
        self,
        guard_name: str,
        api_base: Optional[str] = None,
        **kwargs,
    ):
        if guard_name is None:
            raise Exception(
                "GuardrailsAIException - Please pass the Guardrails AI guard name via 'litellm_params::guard_name'"
            )

        self.guardrails_ai_api_base = api_base or "http://0.0.0.0:8000"
        self.guardrails_ai_guard_name = guard_name
        self.optional_params = kwargs
        supported_event_hooks = [GuardrailEventHooks.post_call]
        super().__init__(supported_event_hooks=supported_event_hooks, **kwargs)

    async def make_guardrails_ai_api_request(self, llm_output: str, request_data: dict):
        """
        POST ``llm_output`` to the configured Guardrails AI guard's validate
        endpoint; raises ``HTTPException(400)`` if validation fails.
        """
        from httpx import URL

        data = {
            "llmOutput": llm_output,
            **self.get_guardrail_dynamic_request_body_params(request_data=request_data),
        }
        _json_data = json.dumps(data)
        response = await litellm.module_level_aclient.post(
            url=str(
                URL(self.guardrails_ai_api_base).join(
                    f"guards/{self.guardrails_ai_guard_name}/validate"
                )
            ),
            data=_json_data,
            headers={
                "Content-Type": "application/json",
            },
        )
        verbose_proxy_logger.debug("guardrails_ai response: %s", response)
        _json_response = GuardrailsAIResponse(**response.json())
        if _json_response.get("validationPassed") is False:
            raise HTTPException(
                status_code=400,
                detail={
                    "error": "Violated guardrail policy",
                    "guardrails_ai_response": _json_response,
                },
            )
        return _json_response
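
    # Illustrative request: with guard_name="my-guard" (hypothetical) and the
    # default api_base, the method above POSTs
    #   {"llmOutput": "<model output>"}  (plus any dynamic request-body params)
    # to http://0.0.0.0:8000/guards/my-guard/validate.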

    @log_guardrail_information
    async def async_post_call_success_hook(
        self,
        data: dict,
        user_api_key_dict: UserAPIKeyAuth,
        response,
    ):
        """
        Runs on the response from the LLM API call.

        It can be used to reject a response.
        """
        event_type: GuardrailEventHooks = GuardrailEventHooks.post_call
        if self.should_run_guardrail(data=data, event_type=event_type) is not True:
            return

        if not isinstance(response, litellm.ModelResponse):
            return

        response_str: str = get_content_from_model_response(response)
        if response_str is not None and len(response_str) > 0:
            await self.make_guardrails_ai_api_request(
                llm_output=response_str, request_data=data
            )

        add_guardrail_to_applied_guardrails_header(
            request_data=data, guardrail_name=self.guardrail_name
        )

        return
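

# Minimal usage sketch (names and values below are illustrative, not taken
# from this module):
#
#   guardrail = GuardrailsAI(
#       guard_name="gibberish-guard",
#       guardrail_name="guardrails_ai-guard",
#   )
#
# On the LiteLLM proxy this guardrail is typically wired up via config.yaml,
# along the lines of:
#
#   guardrails:
#     - guardrail_name: "guardrails_ai-guard"
#       litellm_params:
#         guardrail: guardrails_ai
#         guard_name: "gibberish-guard"
#         mode: "post_call"
#         api_base: "http://0.0.0.0:8000"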