|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os
from typing import Literal, Optional, Union

from fastapi import HTTPException

from litellm import DualCache
from litellm._logging import verbose_proxy_logger
from litellm.integrations.custom_guardrail import CustomGuardrail
from litellm.llms.custom_httpx.http_handler import (
    get_async_httpx_client,
    httpxSpecialProvider,
)
from litellm.proxy._types import UserAPIKeyAuth
|
|
|
|
|
class AimGuardrailMissingSecrets(Exception):
    """Raised when the Aim guardrail is constructed without an API key.

    The key may come either from the guardrail's ``api_key`` argument or from
    the ``AIM_API_KEY`` environment variable; this error signals both were absent.
    """
|
|
|
|
|
class AimGuardrail(CustomGuardrail):
    """Guardrail that screens request messages with the Aim API before a call.

    Each request's ``messages`` are POSTed to ``{api_base}/detect/openai``. If
    Aim reports a detection, the request is rejected with an HTTP 400.
    """

    def __init__(
        self, api_key: Optional[str] = None, api_base: Optional[str] = None, **kwargs
    ):
        """Configure the HTTP client, credentials and endpoint.

        Args:
            api_key: Aim API key; falls back to the ``AIM_API_KEY`` environment
                variable.
            api_base: Aim API base URL; falls back to ``AIM_API_BASE`` and then
                to ``https://api.aim.security``.
            **kwargs: Forwarded unchanged to ``CustomGuardrail.__init__``.

        Raises:
            AimGuardrailMissingSecrets: if no API key is available from either
                the argument or the environment.
        """
        self.async_handler = get_async_httpx_client(
            llm_provider=httpxSpecialProvider.GuardrailCallback
        )
        self.api_key = api_key or os.environ.get("AIM_API_KEY")
        if not self.api_key:
            msg = (
                "Couldn't get Aim api key, either set the `AIM_API_KEY` in the environment or "
                "pass it as a parameter to the guardrail in the config file"
            )
            raise AimGuardrailMissingSecrets(msg)
        self.api_base = (
            api_base or os.environ.get("AIM_API_BASE") or "https://api.aim.security"
        )
        super().__init__(**kwargs)

    async def async_pre_call_hook(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        cache: DualCache,
        data: dict,
        call_type: Literal[
            "completion",
            "text_completion",
            "embeddings",
            "image_generation",
            "moderation",
            "audio_transcription",
            "pass_through_endpoint",
            "rerank",
        ],
    ) -> Union[Exception, str, dict, None]:
        """Send the request's messages to Aim and block the call on detection.

        Args:
            user_api_key_dict: Caller's key info (unused here; part of the hook
                interface).
            cache: Proxy cache (unused here; part of the hook interface).
            data: The incoming request body. ``data["messages"]`` is forwarded
                to Aim; an optional ``x-aim-user-email`` value is read from
                ``data["metadata"]["headers"]`` and passed through as a header.
            call_type: Kind of proxied call (not inspected by this hook).

        Returns:
            ``data`` unchanged when Aim reports no detection.

        Raises:
            HTTPException: status 400 when Aim reports a detection.
            Any error raised by the HTTP client or ``raise_for_status()``
            propagates, as does ``KeyError`` if the response lacks ``detected``.
        """
        verbose_proxy_logger.debug("Inside AIM Pre-Call Hook")

        # `metadata` / `headers` may be present but set to None; treat as empty.
        metadata = data.get("metadata") or {}
        request_headers = metadata.get("headers") or {}
        user_email = request_headers.get("x-aim-user-email")

        # Build headers without dict `|` so the module imports on Python < 3.9.
        headers = {"Authorization": f"Bearer {self.api_key}"}
        if user_email:
            headers["x-aim-user-email"] = user_email

        # Strip a trailing slash so a base like "https://host/" doesn't yield "//detect".
        response = await self.async_handler.post(
            f"{self.api_base.rstrip('/')}/detect/openai",
            headers=headers,
            json={"messages": data.get("messages", [])},
        )
        response.raise_for_status()
        res = response.json()

        # Deliberately strict: a response without "detected" raises instead of
        # silently letting the request through (fail closed).
        detected = res["detected"]
        # "details" is only used for logging, so tolerate its absence.
        verbose_proxy_logger.info(
            "Aim: detected: %s, enabled policies: %s",
            detected,
            list((res.get("details") or {}).keys()),
        )
        if detected:
            # Still block with a 400 even if Aim omitted the message text.
            raise HTTPException(
                status_code=400,
                detail=res.get("detection_message") or "Request blocked by Aim guardrail",
            )
        return data
|
|