File size: 5,640 Bytes
e3278e4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
import traceback
from typing import Optional

from fastapi import HTTPException

import litellm
from litellm._logging import verbose_proxy_logger
from litellm.caching.caching import DualCache
from litellm.integrations.custom_logger import CustomLogger
from litellm.proxy._types import UserAPIKeyAuth


class _PROXY_AzureContentSafety(
    CustomLogger
):  # https://docs.litellm.ai/docs/observability/custom_callback#callback-class
    """LiteLLM proxy callback that screens text with Azure AI Content Safety.

    String message contents of incoming ``completion`` requests and the first
    choice of a ``litellm.ModelResponse`` are sent to the Azure service. When
    the severity reported for a category reaches that category's configured
    threshold, a ``fastapi.HTTPException`` with status code 400 is raised.
    """

    def __init__(self, endpoint, api_key, thresholds=None):
        """Import the Azure SDK lazily and build the async client.

        Args:
            endpoint: Passed as the first argument to ``ContentSafetyClient``.
            api_key: Wrapped in ``AzureKeyCredential`` for the client.
            thresholds: Optional mapping of text category to the minimum
                severity that is rejected. Categories that are missing fall
                back to the defaults in ``_configure_thresholds``.

        Raises:
            Exception: If the ``azure-ai-contentsafety`` imports fail.
        """
        super().__init__()
        try:
            # Imported here so the Azure SDK is only required when this hook
            # is actually instantiated.
            from azure.ai.contentsafety.aio import ContentSafetyClient
            from azure.ai.contentsafety.models import (
                AnalyzeTextOptions,
                AnalyzeTextOutputType,
                TextCategory,
            )
            from azure.core.credentials import AzureKeyCredential
            from azure.core.exceptions import HttpResponseError
        except Exception as e:
            raise Exception(
                f"\033[91mAzure Content-Safety not installed, try running 'pip install azure-ai-contentsafety' to fix this error: {e}\n{traceback.format_exc()}\033[0m"
            ) from e
        self.endpoint = endpoint
        self.api_key = api_key
        # Keep references to the lazily imported SDK types so the other
        # methods can use them without importing again.
        self.text_category = TextCategory
        self.analyze_text_options = AnalyzeTextOptions
        self.analyze_text_output_type = AnalyzeTextOutputType
        self.azure_http_error = HttpResponseError

        self.thresholds = self._configure_thresholds(thresholds)

        self.client = ContentSafetyClient(
            self.endpoint, AzureKeyCredential(self.api_key)
        )

    def _configure_thresholds(self, thresholds=None):
        """Return the effective per-category severity thresholds.

        Args:
            thresholds: Optional mapping of category to threshold. Entries
                given here take precedence over the defaults.

        Returns:
            A new dict holding the default threshold (4) for HATE, SELF_HARM,
            SEXUAL and VIOLENCE, overridden by any entries in ``thresholds``.
            The caller's mapping is never modified.
        """
        default_thresholds = {
            self.text_category.HATE: 4,
            self.text_category.SELF_HARM: 4,
            self.text_category.SEXUAL: 4,
            self.text_category.VIOLENCE: 4,
        }

        if thresholds is None:
            return default_thresholds

        # Merge into a fresh dict instead of writing the defaults into the
        # caller's object.
        return {**default_thresholds, **thresholds}

    def _compute_result(self, response):
        """Compare the severities in ``response`` with the thresholds.

        Args:
            response: Object whose ``categories_analysis`` items expose
                ``category`` and ``severity`` attributes.

        Returns:
            A dict mapping each category that has a non-``None`` severity to
            ``{"filtered": bool, "severity": severity}``, where ``filtered``
            is true when the severity is greater than or equal to the
            category's threshold.
        """
        result = {}

        category_severity = {
            item.category: item.severity for item in response.categories_analysis
        }
        for category in self.text_category:
            severity = category_severity.get(category)
            # Categories the response did not score are left out entirely.
            if severity is not None:
                result[category] = {
                    "filtered": severity >= self.thresholds[category],
                    "severity": severity,
                }

        return result

    async def test_violation(self, content: str, source: Optional[str] = None):
        """Analyze ``content`` and raise if any category is filtered.

        Args:
            content: Text sent to the Azure client for analysis.
            source: Label copied into the error detail (this class passes
                ``"input"`` or ``"output"``).

        Raises:
            HTTPException: Status 400 for the first category whose severity
                reaches its threshold.
            HttpResponseError: Re-raised after logging when the Azure request
                fails.
        """
        verbose_proxy_logger.debug("Testing Azure Content-Safety for: %s", content)

        # Construct a request
        request = self.analyze_text_options(
            text=content,
            output_type=self.analyze_text_output_type.EIGHT_SEVERITY_LEVELS,
        )

        # Analyze text
        try:
            response = await self.client.analyze_text(request)
        except self.azure_http_error:
            verbose_proxy_logger.debug(
                "Error in Azure Content-Safety: %s", traceback.format_exc()
            )
            raise

        result = self._compute_result(response)
        verbose_proxy_logger.debug("Azure Content-Safety Result: %s", result)

        for key, value in result.items():
            if value["filtered"]:
                raise HTTPException(
                    status_code=400,
                    detail={
                        "error": "Violated content safety policy",
                        "source": source,
                        "category": key,
                        "severity": value["severity"],
                    },
                )

    async def async_pre_call_hook(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        cache: DualCache,
        data: dict,
        call_type: str,  # "completion", "embeddings", "image_generation", "moderation"
    ):
        """Screen the string contents of request messages before the call.

        Only ``call_type == "completion"`` requests that carry a ``messages``
        entry are inspected, and only messages whose ``content`` is a ``str``.

        Raises:
            HTTPException: Propagated from ``test_violation`` when a message
                violates the policy.
        """
        verbose_proxy_logger.debug("Inside Azure Content-Safety Pre-Call Hook")
        try:
            if call_type == "completion" and "messages" in data:
                for m in data["messages"]:
                    # Non-string content (for example a list of parts) is
                    # skipped and therefore not analyzed.
                    if "content" in m and isinstance(m["content"], str):
                        await self.test_violation(content=m["content"], source="input")

        except HTTPException:
            # Policy violations must reach the caller unchanged.
            raise
        except Exception as e:
            # Any other failure (including Azure request errors) is logged
            # and swallowed, so the request continues without being blocked.
            verbose_proxy_logger.error(
                "litellm.proxy.hooks.azure_content_safety.py::async_pre_call_hook(): Exception occured - {}".format(
                    str(e)
                )
            )
            verbose_proxy_logger.debug(traceback.format_exc())

    async def async_post_call_success_hook(
        self,
        data: dict,
        user_api_key_dict: UserAPIKeyAuth,
        response,
    ):
        """Screen the first choice of a model response after the call.

        Responses that are not a ``litellm.ModelResponse``, that have no
        choices, or whose first choice is not a ``litellm.utils.Choices`` are
        ignored. Unlike the pre-call hook, errors raised by ``test_violation``
        are not caught here.
        """
        verbose_proxy_logger.debug("Inside Azure Content-Safety Post-Call Hook")
        # An empty ``choices`` list would otherwise raise IndexError below.
        if not isinstance(response, litellm.ModelResponse) or not response.choices:
            return

        first_choice = response.choices[0]
        if isinstance(first_choice, litellm.utils.Choices):
            await self.test_violation(
                content=first_choice.message.content or "", source="output"
            )

    # async def async_post_call_streaming_hook(
    #    self,
    #    user_api_key_dict: UserAPIKeyAuth,
    #    response: str,
    # ):
    #    verbose_proxy_logger.debug("Inside Azure Content-Safety Call-Stream Hook")
    #    await self.test_violation(content=response, source="output")