# +-------------------------------------------------------------+
#
#           Use AporiaAI for your LLM calls
#
# +-------------------------------------------------------------+
#  Thank you users! We ❤️ you! - Krrish & Ishaan

import sys
import os

sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path

import json
from typing import Any, List, Literal, Optional

import litellm
from fastapi import HTTPException

from litellm._logging import verbose_proxy_logger
from litellm.integrations.custom_guardrail import CustomGuardrail
from litellm.litellm_core_utils.logging_utils import (
    convert_litellm_response_object_to_str,
)
from litellm.llms.custom_httpx.http_handler import (
    get_async_httpx_client,
    httpxSpecialProvider,
)
from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.guardrails.guardrail_helpers import should_proceed_based_on_metadata
from litellm.types.guardrails import GuardrailEventHooks

litellm.set_verbose = True

GUARDRAIL_NAME = "aporia"


class AporiaGuardrail(CustomGuardrail):
    def __init__(
        self, api_key: Optional[str] = None, api_base: Optional[str] = None, **kwargs
    ):
        self.async_handler = get_async_httpx_client(
            llm_provider=httpxSpecialProvider.GuardrailCallback
        )
        self.aporia_api_key = api_key or os.environ["APORIO_API_KEY"]
        self.aporia_api_base = api_base or os.environ["APORIO_API_BASE"]
        super().__init__(**kwargs)

    #### CALL HOOKS - proxy only ####

    def transform_messages(self, messages: List[dict]) -> List[dict]:
        """Map unsupported message roles (e.g. 'tool') to a generic 'other' role."""
        supported_openai_roles = ["system", "user", "assistant"]
        default_role = "other"  # for unsupported roles - e.g. tool
        new_messages = []
        for m in messages:
            if m.get("role", "") in supported_openai_roles:
                new_messages.append(m)
            else:
                new_messages.append(
                    {
                        "role": default_role,
                        **{key: value for key, value in m.items() if key != "role"},
                    }
                )

        return new_messages

    async def prepare_aporia_request(
        self, new_messages: List[dict], response_string: Optional[str] = None
    ) -> dict:
        """Build the payload for the Aporia validate endpoint and set the validation target."""
        data: dict[str, Any] = {}
        if new_messages is not None:
            data["messages"] = new_messages
        if response_string is not None:
            data["response"] = response_string

        # Set validation target
        if new_messages and response_string:
            data["validation_target"] = "both"
        elif new_messages:
            data["validation_target"] = "prompt"
        elif response_string:
            data["validation_target"] = "response"

        verbose_proxy_logger.debug("Aporia AI request: %s", data)
        return data

    async def make_aporia_api_request(
        self, new_messages: List[dict], response_string: Optional[str] = None
    ):
        """Call the Aporia validate endpoint; raise HTTPException if the policy action is 'block'."""
        data = await self.prepare_aporia_request(
            new_messages=new_messages, response_string=response_string
        )

        _json_data = json.dumps(data)

        """
        export APORIO_API_KEY=
        curl https://gr-prd-trial.aporia.com/some-id \
            -X POST \
            -H "X-APORIA-API-KEY: $APORIO_API_KEY" \
            -H "Content-Type: application/json" \
            -d '{
                "messages": [
                    {
                        "role": "user",
                        "content": "This is a test prompt"
                    }
                ]
            }'
        """

        response = await self.async_handler.post(
            url=self.aporia_api_base + "/validate",
            data=_json_data,
            headers={
                "X-APORIA-API-KEY": self.aporia_api_key,
                "Content-Type": "application/json",
            },
        )
        verbose_proxy_logger.debug("Aporia AI response: %s", response.text)
        if response.status_code == 200:
            # check if the response was flagged
            _json_response = response.json()
            action: str = _json_response.get(
                "action"
            )  # possible values are modify, passthrough, block, rephrase
            if action == "block":
                raise HTTPException(
                    status_code=400,
                    detail={
                        "error": "Violated guardrail policy",
                        "aporia_ai_response": _json_response,
                    },
                )
    async def async_post_call_success_hook(
        self,
        data: dict,
        user_api_key_dict: UserAPIKeyAuth,
        response,
    ):
        """
        Use this for post-call moderation with Guardrails: validates the LLM response.
        """
        from litellm.proxy.common_utils.callback_utils import (
            add_guardrail_to_applied_guardrails_header,
        )

        event_type: GuardrailEventHooks = GuardrailEventHooks.post_call
        if self.should_run_guardrail(data=data, event_type=event_type) is not True:
            return

        response_str: Optional[str] = convert_litellm_response_object_to_str(response)
        if response_str is not None:
            await self.make_aporia_api_request(
                response_string=response_str, new_messages=data.get("messages", [])
            )

            add_guardrail_to_applied_guardrails_header(
                request_data=data, guardrail_name=self.guardrail_name
            )

    async def async_moderation_hook(  ### 👈 KEY CHANGE ###
        self,
        data: dict,
        user_api_key_dict: UserAPIKeyAuth,
        call_type: Literal[
            "completion",
            "embeddings",
            "image_generation",
            "moderation",
            "audio_transcription",
        ],
    ):
        """
        Use this for during-call moderation with Guardrails: validates the prompt in parallel with the LLM call.
        """
        from litellm.proxy.common_utils.callback_utils import (
            add_guardrail_to_applied_guardrails_header,
        )

        event_type: GuardrailEventHooks = GuardrailEventHooks.during_call
        if self.should_run_guardrail(data=data, event_type=event_type) is not True:
            return

        # old implementation - backwards compatibility
        if (
            await should_proceed_based_on_metadata(
                data=data,
                guardrail_name=GUARDRAIL_NAME,
            )
            is False
        ):
            return

        new_messages: Optional[List[dict]] = None
        if "messages" in data and isinstance(data["messages"], list):
            new_messages = self.transform_messages(messages=data["messages"])

        if new_messages is not None:
            await self.make_aporia_api_request(new_messages=new_messages)
            add_guardrail_to_applied_guardrails_header(
                request_data=data, guardrail_name=self.guardrail_name
            )
        else:
            verbose_proxy_logger.warning(
                "Aporia AI: not running guardrail. No messages in data"
            )
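

# ---------------------------------------------------------------------------
# Illustrative local sketch (not part of the guardrail itself): shows the
# payload that prepare_aporia_request() builds for the Aporia validate
# endpoint, using placeholder credentials so no real API call is made.
# Assumes CustomGuardrail can be constructed here without extra kwargs.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import asyncio

    guardrail = AporiaGuardrail(
        api_key="dummy-key", api_base="https://example.invalid"  # placeholders only
    )

    sample_messages = [
        {"role": "user", "content": "This is a test prompt"},
        {"role": "tool", "content": "tool output"},  # unsupported role -> "other"
    ]

    # transform_messages() maps unsupported roles to the generic "other" role
    transformed = guardrail.transform_messages(messages=sample_messages)

    # prepare_aporia_request() only builds the dict; it does not hit the API
    payload = asyncio.run(
        guardrail.prepare_aporia_request(
            new_messages=transformed, response_string="This is a test response"
        )
    )
    print(json.dumps(payload, indent=2))  # validation_target == "both"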