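"""
Aporia AI guardrail hook for the LiteLLM proxy.

Runs prompt moderation during the call (async_moderation_hook) and response
moderation after the call (async_post_call_success_hook) by POSTing to an
Aporia /validate endpoint, and blocks any request that Aporia answers with
action == "block".
"""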

import json
import os
import sys

sys.path.insert(0, os.path.abspath("../.."))  # Adds the parent directory to the system path

from typing import Any, List, Literal, Optional

from fastapi import HTTPException

import litellm
from litellm._logging import verbose_proxy_logger
from litellm.integrations.custom_guardrail import CustomGuardrail
from litellm.litellm_core_utils.logging_utils import (
    convert_litellm_response_object_to_str,
)
from litellm.llms.custom_httpx.http_handler import (
    get_async_httpx_client,
    httpxSpecialProvider,
)
from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.guardrails.guardrail_helpers import should_proceed_based_on_metadata
from litellm.types.guardrails import GuardrailEventHooks

litellm.set_verbose = True

GUARDRAIL_NAME = "aporia"
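
# A sketch of how this guardrail might be wired into a proxy config.yaml.
# The exact schema below is an assumption -- check the LiteLLM guardrails
# docs for the authoritative format:
#
#   guardrails:
#     - guardrail_name: "aporia-pre-guard"
#       litellm_params:
#         guardrail: aporia
#         mode: "during_call"          # runs async_moderation_hook
#         api_key: os.environ/APORIO_API_KEY
#         api_base: os.environ/APORIO_API_BASE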


class AporiaGuardrail(CustomGuardrail):
    def __init__(
        self, api_key: Optional[str] = None, api_base: Optional[str] = None, **kwargs
    ):
        self.async_handler = get_async_httpx_client(
            llm_provider=httpxSpecialProvider.GuardrailCallback
        )
        # NOTE: the env var names intentionally use the existing "APORIO" spelling.
        self.aporia_api_key = api_key or os.environ["APORIO_API_KEY"]
        self.aporia_api_base = api_base or os.environ["APORIO_API_BASE"]
        super().__init__(**kwargs)

    def transform_messages(self, messages: List[dict]) -> List[dict]:
        """Relabel any non-OpenAI role to the default role so Aporia accepts it."""
        supported_openai_roles = ["system", "user", "assistant"]
        default_role = "other"
        new_messages = []
        for m in messages:
            if m.get("role", "") in supported_openai_roles:
                new_messages.append(m)
            else:
                new_messages.append(
                    {
                        "role": default_role,
                        **{key: value for key, value in m.items() if key != "role"},
                    }
                )
        return new_messages
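
    # Example: a message with an unsupported role is relabelled, e.g.
    #   [{"role": "tool", "content": "ok"}] -> [{"role": "other", "content": "ok"}]
    # while system/user/assistant messages pass through unchanged.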

    async def prepare_aporia_request(
        self, new_messages: List[dict], response_string: Optional[str] = None
    ) -> dict:
        """Build the JSON payload for Aporia's /validate endpoint."""
        data: dict[str, Any] = {}
        if new_messages is not None:
            data["messages"] = new_messages
        if response_string is not None:
            data["response"] = response_string

        # Tell Aporia whether to validate the prompt, the response, or both.
        if new_messages and response_string:
            data["validation_target"] = "both"
        elif new_messages:
            data["validation_target"] = "prompt"
        elif response_string:
            data["validation_target"] = "response"

        verbose_proxy_logger.debug("Aporia AI request: %s", data)
        return data
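
    # Example payload for a prompt-only check:
    #   {"messages": [{"role": "user", "content": "hi"}],
    #    "validation_target": "prompt"}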

    async def make_aporia_api_request(
        self, new_messages: List[dict], response_string: Optional[str] = None
    ):
        data = await self.prepare_aporia_request(
            new_messages=new_messages, response_string=response_string
        )

        _json_data = json.dumps(data)

        """
        export APORIO_API_KEY=<your key>
        curl https://gr-prd-trial.aporia.com/some-id \
            -X POST \
            -H "X-APORIA-API-KEY: $APORIO_API_KEY" \
            -H "Content-Type: application/json" \
            -d '{
                "messages": [
                    {
                        "role": "user",
                        "content": "This is a test prompt"
                    }
                ]
            }'
        """

        response = await self.async_handler.post(
            url=self.aporia_api_base + "/validate",
            data=_json_data,
            headers={
                "X-APORIA-API-KEY": self.aporia_api_key,
                "Content-Type": "application/json",
            },
        )
        verbose_proxy_logger.debug("Aporia AI response: %s", response.text)
        if response.status_code == 200:
            # Aporia returns an "action" field; "block" means the policy was violated.
            _json_response = response.json()
            action: str = _json_response.get("action")
            if action == "block":
                raise HTTPException(
                    status_code=400,
                    detail={
                        "error": "Violated guardrail policy",
                        "aporia_ai_response": _json_response,
                    },
                )
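
    # Only the "action" key is inspected above; a blocking response is assumed
    # to look roughly like {"action": "block", ...}. Any other action, or a
    # non-200 status, lets the request proceed untouched.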

    async def async_post_call_success_hook(
        self,
        data: dict,
        user_api_key_dict: UserAPIKeyAuth,
        response,
    ):
        """
        Use this for the post call moderation with Guardrails
        """
        from litellm.proxy.common_utils.callback_utils import (
            add_guardrail_to_applied_guardrails_header,
        )

        event_type: GuardrailEventHooks = GuardrailEventHooks.post_call
        if self.should_run_guardrail(data=data, event_type=event_type) is not True:
            return

        response_str: Optional[str] = convert_litellm_response_object_to_str(response)
        if response_str is not None:
            await self.make_aporia_api_request(
                response_string=response_str, new_messages=data.get("messages", [])
            )

            add_guardrail_to_applied_guardrails_header(
                request_data=data, guardrail_name=self.guardrail_name
            )

    async def async_moderation_hook(
        self,
        data: dict,
        user_api_key_dict: UserAPIKeyAuth,
        call_type: Literal[
            "completion",
            "embeddings",
            "image_generation",
            "moderation",
            "audio_transcription",
        ],
    ):
        from litellm.proxy.common_utils.callback_utils import (
            add_guardrail_to_applied_guardrails_header,
        )

        event_type: GuardrailEventHooks = GuardrailEventHooks.during_call
        if self.should_run_guardrail(data=data, event_type=event_type) is not True:
            return

        if (
            await should_proceed_based_on_metadata(
                data=data,
                guardrail_name=GUARDRAIL_NAME,
            )
            is False
        ):
            return

        new_messages: Optional[List[dict]] = None
        if "messages" in data and isinstance(data["messages"], list):
            new_messages = self.transform_messages(messages=data["messages"])

        if new_messages is not None:
            await self.make_aporia_api_request(new_messages=new_messages)
            add_guardrail_to_applied_guardrails_header(
                request_data=data, guardrail_name=self.guardrail_name
            )
        else:
            verbose_proxy_logger.warning(
                "Aporia AI: not running guardrail. No messages in data"
            )
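

# A minimal standalone sketch of exercising the during-call hook, assuming
# APORIO_API_KEY / APORIO_API_BASE are set and UserAPIKeyAuth can be built
# with defaults; the proxy normally invokes these hooks itself, so this is
# left commented out to avoid side effects on import:
#
#   import asyncio
#
#   async def _demo():
#       guardrail = AporiaGuardrail(guardrail_name=GUARDRAIL_NAME)
#       await guardrail.async_moderation_hook(
#           data={"messages": [{"role": "user", "content": "This is a test prompt"}]},
#           user_api_key_dict=UserAPIKeyAuth(),
#           call_type="completion",
#       )
#
#   asyncio.run(_demo())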