from typing import Literal, Optional

import aiohttp
from fastapi import HTTPException

import litellm
from litellm._logging import verbose_proxy_logger
from litellm.integrations.custom_logger import CustomLogger
from litellm.proxy._types import UserAPIKeyAuth
from litellm.secret_managers.main import get_secret_str
from litellm.utils import get_formatted_prompt

litellm.set_verbose = True

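# How this guardrail decides whether to run (see should_proceed below):
#   - llm_guard_mode == "all": check every request
#   - llm_guard_mode == "key-specific": check only keys whose permissions set
#     enable_llm_guard_check=True
#   - llm_guard_mode == "request-specific": check only requests whose
#     metadata.permissions set enable_llm_guard_check=True
# The LLM Guard API server is located via the LLM_GUARD_API_BASE env var,
# e.g. (illustrative value only) LLM_GUARD_API_BASE="http://0.0.0.0:8192".

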
class _ENTERPRISE_LLMGuard(CustomLogger):
    def __init__(
        self,
        mock_testing: bool = False,
        mock_redacted_text: Optional[dict] = None,
    ):
        self.mock_redacted_text = mock_redacted_text
        self.llm_guard_mode = litellm.llm_guard_mode
        if mock_testing is True:
            return
        self.llm_guard_api_base = get_secret_str("LLM_GUARD_API_BASE", None)
        if self.llm_guard_api_base is None:
            raise Exception("Missing `LLM_GUARD_API_BASE` from environment")
        elif not self.llm_guard_api_base.endswith("/"):
            self.llm_guard_api_base += "/"

    def print_verbose(self, print_statement):
        try:
            verbose_proxy_logger.debug(print_statement)
            if litellm.set_verbose:
                print(print_statement)
        except Exception:
            pass

    async def moderation_check(self, text: str):
        """
        [TODO] make this more performant for high-throughput scenarios
        """
        try:
            async with aiohttp.ClientSession() as session:
                if self.mock_redacted_text is not None:
                    redacted_text = self.mock_redacted_text
                else:
                    analyze_url = f"{self.llm_guard_api_base}analyze/prompt"
                    verbose_proxy_logger.debug("Making request to: %s", analyze_url)
                    analyze_payload = {"prompt": text}
                    redacted_text = None

                    async with session.post(
                        analyze_url, json=analyze_payload
                    ) as response:
                        redacted_text = await response.json()

                verbose_proxy_logger.info(
                    f"LLM Guard: Received response - {redacted_text}"
                )

                if redacted_text is not None:
                    if (
                        redacted_text.get("is_valid", None) is not None
                        and redacted_text["is_valid"] is not True
                    ):
                        raise HTTPException(
                            status_code=400,
                            detail={"error": "Violated content safety policy"},
                        )
                else:
                    raise HTTPException(
                        status_code=500,
                        detail={
                            "error": f"Invalid content moderation response: {redacted_text}"
                        },
                    )
        except Exception as e:
            verbose_proxy_logger.exception(
                "litellm.enterprise.enterprise_hooks.llm_guard::moderation_check - Exception occurred - {}".format(
                    str(e)
                )
            )
            raise e

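    # NOTE: moderation_check only inspects the "is_valid" field of the
    # analyze/prompt response. An illustrative (assumed, not verified here)
    # LLM Guard response looks like:
    #   {"sanitized_prompt": "hi [REDACTED]", "is_valid": true, "scanners": {"Anonymize": 0.0}}
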
    def should_proceed(self, user_api_key_dict: UserAPIKeyAuth, data: dict) -> bool:
        if self.llm_guard_mode == "key-specific":
            self.print_verbose(
                f"user_api_key_dict.permissions: {user_api_key_dict.permissions}"
            )
            if (
                user_api_key_dict.permissions.get("enable_llm_guard_check", False)
                is True
            ):
                return True
        elif self.llm_guard_mode == "all":
            return True
        elif self.llm_guard_mode == "request-specific":
            self.print_verbose(f"received metadata: {data.get('metadata', {})}")
            metadata = data.get("metadata", {})
            permissions = metadata.get("permissions", {})
            if (
                "enable_llm_guard_check" in permissions
                and permissions["enable_llm_guard_check"] is True
            ):
                return True
        return False

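    # Illustrative request body for llm_guard_mode == "request-specific"; the
    # nested metadata.permissions shape is what should_proceed() reads, the
    # model/messages values are placeholders:
    #   {
    #       "model": "gpt-3.5-turbo",
    #       "messages": [{"role": "user", "content": "hi"}],
    #       "metadata": {"permissions": {"enable_llm_guard_check": True}},
    #   }
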
    async def async_moderation_hook(
        self,
        data: dict,
        user_api_key_dict: UserAPIKeyAuth,
        call_type: Literal[
            "completion",
            "embeddings",
            "image_generation",
            "moderation",
            "audio_transcription",
        ],
    ):
        """
        - Calls the LLM Guard Endpoint
        - Rejects the request if it fails the safety check
        - Uses the sanitized prompt returned
        - LLM Guard can handle things like PII Masking, etc.
        """
        self.print_verbose(
            f"Inside LLM Guard Pre-Call Hook - llm_guard_mode={self.llm_guard_mode}"
        )

        _proceed = self.should_proceed(user_api_key_dict=user_api_key_dict, data=data)
        if _proceed is False:
            return

        self.print_verbose("Making LLM Guard Check")
        try:
            assert call_type in [
                "completion",
                "embeddings",
                "image_generation",
                "moderation",
                "audio_transcription",
            ]
        except Exception:
            self.print_verbose(
                f"Call Type - {call_type}, not in accepted list - ['completion','embeddings','image_generation','moderation','audio_transcription']"
            )
            return data

        formatted_prompt = get_formatted_prompt(data=data, call_type=call_type)
        self.print_verbose(f"LLM Guard, formatted_prompt: {formatted_prompt}")
        return await self.moderation_check(text=formatted_prompt)

    async def async_post_call_streaming_hook(
        self, user_api_key_dict: UserAPIKeyAuth, response: str
    ):
        if response is not None:
            await self.moderation_check(text=response)

        return response

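
if __name__ == "__main__":
    # Minimal local sketch (assumption: run manually, not part of the proxy wiring).
    # mock_testing=True skips the LLM_GUARD_API_BASE lookup, and the mock response
    # below assumes the analyze/prompt shape described above.
    import asyncio

    _guard = _ENTERPRISE_LLMGuard(
        mock_testing=True,
        mock_redacted_text={"is_valid": True, "sanitized_prompt": "hi [REDACTED]"},
    )
    # Passes silently when is_valid is True; raises HTTPException(400) otherwise.
    asyncio.run(_guard.moderation_check(text="hi jane@example.com"))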