|
import traceback |
|
from typing import Optional |
|
|
|
from fastapi import HTTPException |
|
|
|
import litellm |
|
from litellm._logging import verbose_proxy_logger |
|
from litellm.caching.caching import DualCache |
|
from litellm.integrations.custom_logger import CustomLogger |
|
from litellm.proxy._types import UserAPIKeyAuth |
|
|
|
|
|
class _PROXY_AzureContentSafety(CustomLogger):
    """Proxy hook that screens prompts and completions with Azure AI Content Safety.

    Each piece of text is sent to ``ContentSafetyClient.analyze_text``. The
    per-category severities in the response are compared against configurable
    thresholds. If any category's severity meets or exceeds its threshold, an
    ``HTTPException`` with status 400 is raised.
    """

    def __init__(self, endpoint, api_key, thresholds=None):
        """Create the async Azure Content Safety client.

        Args:
            endpoint: Azure Content Safety resource endpoint.
            api_key: Key wrapped in an ``AzureKeyCredential``.
            thresholds: Optional mapping of ``TextCategory`` -> severity at or
                above which content is blocked. Categories not present fall
                back to a default of 4. The mapping passed in is not modified.

        Raises:
            Exception: if the ``azure-ai-contentsafety`` SDK cannot be imported.
        """
        try:
            # Imported lazily so the SDK is only required when this hook is used.
            from azure.ai.contentsafety.aio import ContentSafetyClient
            from azure.ai.contentsafety.models import (
                AnalyzeTextOptions,
                AnalyzeTextOutputType,
                TextCategory,
            )
            from azure.core.credentials import AzureKeyCredential
            from azure.core.exceptions import HttpResponseError
        except Exception as e:
            raise Exception(
                f"\033[91mAzure Content-Safety not installed, try running 'pip install azure-ai-contentsafety' to fix this error: {e}\n{traceback.format_exc()}\033[0m"
            ) from e

        self.endpoint = endpoint
        self.api_key = api_key
        # Keep the SDK types on the instance so other methods need not re-import.
        self.text_category = TextCategory
        self.analyze_text_options = AnalyzeTextOptions
        self.analyze_text_output_type = AnalyzeTextOutputType
        self.azure_http_error = HttpResponseError

        self.thresholds = self._configure_thresholds(thresholds)

        self.client = ContentSafetyClient(
            self.endpoint, AzureKeyCredential(self.api_key)
        )

    def _configure_thresholds(self, thresholds=None):
        """Return a thresholds dict with a default of 4 for any missing category.

        A new dict is returned; the caller's ``thresholds`` mapping is copied,
        not mutated, so a shared config object is never altered.
        """
        default_thresholds = {
            self.text_category.HATE: 4,
            self.text_category.SELF_HARM: 4,
            self.text_category.SEXUAL: 4,
            self.text_category.VIOLENCE: 4,
        }

        if thresholds is None:
            return default_thresholds

        # User-supplied values override the defaults.
        return {**default_thresholds, **thresholds}

    def _compute_result(self, response):
        """Summarise an ``analyze_text`` response per category.

        Returns:
            ``{category: {"filtered": bool, "severity": int}}`` for every
            ``TextCategory`` that appears in ``response.categories_analysis``
            with a non-None severity. A category with no configured threshold
            is reported but never marked as filtered.
        """
        category_severity = {
            item.category: item.severity for item in response.categories_analysis
        }

        result = {}
        for category in self.text_category:
            severity = category_severity.get(category)
            if severity is None:
                continue
            # .get() so an SDK category without a threshold cannot KeyError.
            threshold = self.thresholds.get(category)
            result[category] = {
                "filtered": threshold is not None and severity >= threshold,
                "severity": severity,
            }

        return result

    async def test_violation(self, content: str, source: Optional[str] = None):
        """Analyse ``content`` and raise if any category crosses its threshold.

        Args:
            content: Text to screen. Empty text is skipped because it cannot
                violate a policy.
            source: Label (e.g. "input" / "output") echoed in the error detail.

        Raises:
            HTTPException: status 400 for the first filtered category.
            HttpResponseError: re-raised from the Azure SDK on API failure.
        """
        verbose_proxy_logger.debug("Testing Azure Content-Safety for: %s", content)

        if not content:
            return

        request = self.analyze_text_options(
            text=content,
            output_type=self.analyze_text_output_type.EIGHT_SEVERITY_LEVELS,
        )

        try:
            response = await self.client.analyze_text(request)
        except self.azure_http_error:
            # Log once with the traceback, then let the caller decide what to do.
            verbose_proxy_logger.debug(
                "Error in Azure Content-Safety: %s", traceback.format_exc()
            )
            raise

        result = self._compute_result(response)
        verbose_proxy_logger.debug("Azure Content-Safety Result: %s", result)

        for key, value in result.items():
            if value["filtered"]:
                raise HTTPException(
                    status_code=400,
                    detail={
                        "error": "Violated content safety policy",
                        "source": source,
                        "category": key,
                        "severity": value["severity"],
                    },
                )

    @staticmethod
    def _iter_message_texts(message):
        """Yield the text segments of a chat message dict.

        Supports string ``content`` and the list-of-parts form, where text
        parts look like ``{"type": "text", "text": "..."}``. Non-text parts
        (e.g. images) and non-dict messages yield nothing.
        """
        content = message.get("content") if isinstance(message, dict) else None
        if isinstance(content, str):
            yield content
        elif isinstance(content, list):
            for part in content:
                if (
                    isinstance(part, dict)
                    and part.get("type") == "text"
                    and isinstance(part.get("text"), str)
                ):
                    yield part["text"]

    async def async_pre_call_hook(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        cache: DualCache,
        data: dict,
        call_type: str,
    ):
        """Screen every text segment of the request messages before the call.

        Policy violations (``HTTPException``) propagate to reject the request.
        Any other error is logged and swallowed, so an Azure outage does not
        block traffic.
        """
        verbose_proxy_logger.debug("Inside Azure Content-Safety Pre-Call Hook")
        try:
            if call_type == "completion" and "messages" in data:
                for m in data["messages"]:
                    for text in self._iter_message_texts(m):
                        await self.test_violation(content=text, source="input")

        except HTTPException:
            raise
        except Exception as e:
            verbose_proxy_logger.error(
                "litellm.proxy.hooks.azure_content_safety.py::async_pre_call_hook(): Exception occured - {}".format(
                    str(e)
                )
            )
            verbose_proxy_logger.debug(traceback.format_exc())

    async def async_post_call_success_hook(
        self,
        data: dict,
        user_api_key_dict: UserAPIKeyAuth,
        response,
    ):
        """Screen the text of every non-streaming choice in a model response.

        Raises:
            HTTPException: if any choice's content violates a threshold.
        """
        verbose_proxy_logger.debug("Inside Azure Content-Safety Post-Call Hook")
        if not isinstance(response, litellm.ModelResponse):
            return

        # Check every choice (n > 1 is possible); tolerate an empty list.
        for choice in response.choices or []:
            if isinstance(choice, litellm.utils.Choices):
                await self.test_violation(
                    content=choice.message.content or "", source="output"
                )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|