File size: 5,640 Bytes
e3278e4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 |
import traceback
from typing import Optional
from fastapi import HTTPException
import litellm
from litellm._logging import verbose_proxy_logger
from litellm.caching.caching import DualCache
from litellm.integrations.custom_logger import CustomLogger
from litellm.proxy._types import UserAPIKeyAuth
class _PROXY_AzureContentSafety(
CustomLogger
): # https://docs.litellm.ai/docs/observability/custom_callback#callback-class
# Class variables or attributes
def __init__(self, endpoint, api_key, thresholds=None):
try:
from azure.ai.contentsafety.aio import ContentSafetyClient
from azure.ai.contentsafety.models import (
AnalyzeTextOptions,
AnalyzeTextOutputType,
TextCategory,
)
from azure.core.credentials import AzureKeyCredential
from azure.core.exceptions import HttpResponseError
except Exception as e:
raise Exception(
f"\033[91mAzure Content-Safety not installed, try running 'pip install azure-ai-contentsafety' to fix this error: {e}\n{traceback.format_exc()}\033[0m"
)
self.endpoint = endpoint
self.api_key = api_key
self.text_category = TextCategory
self.analyze_text_options = AnalyzeTextOptions
self.analyze_text_output_type = AnalyzeTextOutputType
self.azure_http_error = HttpResponseError
self.thresholds = self._configure_thresholds(thresholds)
self.client = ContentSafetyClient(
self.endpoint, AzureKeyCredential(self.api_key)
)
def _configure_thresholds(self, thresholds=None):
default_thresholds = {
self.text_category.HATE: 4,
self.text_category.SELF_HARM: 4,
self.text_category.SEXUAL: 4,
self.text_category.VIOLENCE: 4,
}
if thresholds is None:
return default_thresholds
for key, default in default_thresholds.items():
if key not in thresholds:
thresholds[key] = default
return thresholds
def _compute_result(self, response):
result = {}
category_severity = {
item.category: item.severity for item in response.categories_analysis
}
for category in self.text_category:
severity = category_severity.get(category)
if severity is not None:
result[category] = {
"filtered": severity >= self.thresholds[category],
"severity": severity,
}
return result
async def test_violation(self, content: str, source: Optional[str] = None):
verbose_proxy_logger.debug("Testing Azure Content-Safety for: %s", content)
# Construct a request
request = self.analyze_text_options(
text=content,
output_type=self.analyze_text_output_type.EIGHT_SEVERITY_LEVELS,
)
# Analyze text
try:
response = await self.client.analyze_text(request)
except self.azure_http_error:
verbose_proxy_logger.debug(
"Error in Azure Content-Safety: %s", traceback.format_exc()
)
verbose_proxy_logger.debug(traceback.format_exc())
raise
result = self._compute_result(response)
verbose_proxy_logger.debug("Azure Content-Safety Result: %s", result)
for key, value in result.items():
if value["filtered"]:
raise HTTPException(
status_code=400,
detail={
"error": "Violated content safety policy",
"source": source,
"category": key,
"severity": value["severity"],
},
)
async def async_pre_call_hook(
self,
user_api_key_dict: UserAPIKeyAuth,
cache: DualCache,
data: dict,
call_type: str, # "completion", "embeddings", "image_generation", "moderation"
):
verbose_proxy_logger.debug("Inside Azure Content-Safety Pre-Call Hook")
try:
if call_type == "completion" and "messages" in data:
for m in data["messages"]:
if "content" in m and isinstance(m["content"], str):
await self.test_violation(content=m["content"], source="input")
except HTTPException as e:
raise e
except Exception as e:
verbose_proxy_logger.error(
"litellm.proxy.hooks.azure_content_safety.py::async_pre_call_hook(): Exception occured - {}".format(
str(e)
)
)
verbose_proxy_logger.debug(traceback.format_exc())
async def async_post_call_success_hook(
self,
data: dict,
user_api_key_dict: UserAPIKeyAuth,
response,
):
verbose_proxy_logger.debug("Inside Azure Content-Safety Post-Call Hook")
if isinstance(response, litellm.ModelResponse) and isinstance(
response.choices[0], litellm.utils.Choices
):
await self.test_violation(
content=response.choices[0].message.content or "", source="output"
)
# async def async_post_call_streaming_hook(
# self,
# user_api_key_dict: UserAPIKeyAuth,
# response: str,
# ):
# verbose_proxy_logger.debug("Inside Azure Content-Safety Call-Stream Hook")
# await self.test_violation(content=response, source="output")
|