File size: 6,259 Bytes
e3278e4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 |
import importlib
import os
from typing import Dict, List, Optional
import litellm
from litellm import get_secret
from litellm._logging import verbose_proxy_logger
from litellm.proxy.common_utils.callback_utils import initialize_callbacks_on_proxy
# v2 implementation
from litellm.types.guardrails import (
Guardrail,
GuardrailItem,
GuardrailItemSpec,
LakeraCategoryThresholds,
LitellmParams,
)
from .guardrail_registry import guardrail_registry
# Module-level accumulator of every parsed GuardrailItem. NOTE(review): this is
# appended to (never cleared) by initialize_guardrails, so repeated calls — e.g.
# on config reload — accumulate entries across calls; verify that is intended.
all_guardrails: List[GuardrailItem] = []
def initialize_guardrails(
    guardrails_config: List[Dict[str, GuardrailItemSpec]],
    premium_user: bool,
    config_file_path: str,
    litellm_settings: dict,
) -> Dict[str, GuardrailItem]:
    """Parse v1 guardrail config entries and register their proxy callbacks.

    Each config entry maps a guardrail name to its spec (callbacks,
    default_on, etc.). Every entry is turned into a GuardrailItem, appended
    to the module-level ``all_guardrails`` list, and recorded in
    ``litellm.guardrail_name_config_map``. Callbacks of guardrails that are
    ``default_on`` and not already registered in ``litellm.callbacks`` are
    then initialized via ``initialize_callbacks_on_proxy``.

    Args:
        guardrails_config: List of one-key dicts, e.g.
            ``{'prompt_injection': {'callbacks': [...], 'default_on': True}}``.
        premium_user: Forwarded to ``initialize_callbacks_on_proxy``
            (presumably gates enterprise-only callbacks — confirm there).
        config_file_path: Path to the proxy config file, forwarded to
            callback initialization.
        litellm_settings: ``litellm_settings`` section of the proxy config,
            forwarded to callback initialization.

    Returns:
        The (global) ``litellm.guardrail_name_config_map`` mapping
        guardrail name -> GuardrailItem.

    Raises:
        Exception: Re-raises anything thrown during parsing/registration,
            after logging it.
    """
    try:
        verbose_proxy_logger.debug(f"validating guardrails passed {guardrails_config}")

        global all_guardrails
        for item in guardrails_config:
            """
            one item looks like this:
            {'prompt_injection': {'callbacks': ['lakera_prompt_injection', 'prompt_injection_api_2'], 'default_on': True, 'enabled_roles': ['user']}}
            """
            for k, v in item.items():
                # GuardrailItem validates the spec; the dict key is the name.
                guardrail_item = GuardrailItem(**v, guardrail_name=k)
                all_guardrails.append(guardrail_item)
                litellm.guardrail_name_config_map[k] = guardrail_item

        # set appropriate callbacks if they are default on
        default_on_callbacks = set()
        callback_specific_params = {}
        for guardrail in all_guardrails:
            verbose_proxy_logger.debug(guardrail.guardrail_name)
            verbose_proxy_logger.debug(guardrail.default_on)

            # Later guardrails' callback_args overwrite earlier ones on key
            # collision (dict.update semantics).
            callback_specific_params.update(guardrail.callback_args)

            if guardrail.default_on is True:
                # add these to litellm callbacks if they don't exist
                for callback in guardrail.callbacks:
                    if callback not in litellm.callbacks:
                        default_on_callbacks.add(callback)

                    # NOTE(review): indentation was reconstructed here — this
                    # reading places the logging_only check at the per-callback
                    # level (it runs even when the callback was already
                    # registered); confirm against upstream history.
                    if guardrail.logging_only is True:
                        if callback == "presidio":
                            callback_specific_params["presidio"] = {"logging_only": True}  # type: ignore

        default_on_callbacks_list = list(default_on_callbacks)
        if len(default_on_callbacks_list) > 0:
            initialize_callbacks_on_proxy(
                value=default_on_callbacks_list,
                premium_user=premium_user,
                config_file_path=config_file_path,
                litellm_settings=litellm_settings,
                callback_specific_params=callback_specific_params,
            )

        return litellm.guardrail_name_config_map
    except Exception as e:
        verbose_proxy_logger.exception(
            "error initializing guardrails {}".format(str(e))
        )
        raise e
"""
Builds the map of guardrail_name -> guardrail, where each guardrail runs at one
hook stage: <pre_call>, <post_call>, or <during_call>.
"""
def init_guardrails_v2(
    all_guardrails: List[Dict],
    config_file_path: Optional[str] = None,
):
    """Initialize v2-format guardrails and register their callbacks.

    For each guardrail dict (expected keys: ``guardrail_name`` and
    ``litellm_params``), this:
      1. Projects ``litellm_params`` onto the LitellmParams schema.
      2. Resolves ``os.environ/``-prefixed api_key / api_base via get_secret.
      3. Dispatches to a built-in initializer from ``guardrail_registry``
         keyed on ``litellm_params["guardrail"]``, or — when the type string
         contains a dot — dynamically imports ``<file>.<ClassName>`` from a
         module next to the proxy config file and registers an instance as a
         litellm callback.

    Args:
        all_guardrails: Parsed ``guardrails`` section of the proxy config.
            (Shadows the module-level ``all_guardrails`` list inside this
            function.)
        config_file_path: Proxy config path; required only when a custom
            dotted guardrail type must be imported relative to it.

    Raises:
        Exception: If a custom dotted guardrail is configured without
            ``config_file_path``.
        ImportError: If the custom guardrail module cannot be located.
        ValueError: If the guardrail type is neither registered nor dotted.
    """
    guardrail_list = []

    for guardrail in all_guardrails:
        litellm_params_data = guardrail["litellm_params"]
        verbose_proxy_logger.debug("litellm_params= %s", litellm_params_data)

        # Keep only keys declared on LitellmParams; missing keys become None.
        _litellm_params_kwargs = {
            k: litellm_params_data.get(k) for k in LitellmParams.__annotations__.keys()
        }

        litellm_params = LitellmParams(**_litellm_params_kwargs)  # type: ignore

        if (
            "category_thresholds" in litellm_params_data
            and litellm_params_data["category_thresholds"]
        ):
            # Lakera-specific: wrap raw threshold dict in its typed container.
            lakera_category_thresholds = LakeraCategoryThresholds(
                **litellm_params_data["category_thresholds"]
            )
            litellm_params["category_thresholds"] = lakera_category_thresholds

        # Resolve secret references of the form "os.environ/VAR_NAME" before
        # any initializer sees the params.
        if litellm_params["api_key"] and litellm_params["api_key"].startswith(
            "os.environ/"
        ):
            litellm_params["api_key"] = str(get_secret(litellm_params["api_key"]))  # type: ignore

        if litellm_params["api_base"] and litellm_params["api_base"].startswith(
            "os.environ/"
        ):
            litellm_params["api_base"] = str(get_secret(litellm_params["api_base"]))  # type: ignore

        guardrail_type = litellm_params["guardrail"]

        initializer = guardrail_registry.get(guardrail_type)
        if initializer:
            # Built-in guardrail: its registered initializer handles callback
            # registration itself.
            initializer(litellm_params, guardrail)
        elif isinstance(guardrail_type, str) and "." in guardrail_type:
            # Custom guardrail: "<module_file>.<ClassName>" loaded from the
            # directory containing the proxy config file.
            if not config_file_path:
                raise Exception(
                    "GuardrailsAIException - Please pass the config_file_path to initialize_guardrails_v2"
                )
            # NOTE(review): split(".") unpacks exactly two parts — a type
            # string with more than one dot raises ValueError here.
            _file_name, _class_name = guardrail_type.split(".")
            verbose_proxy_logger.debug(
                "Initializing custom guardrail: %s, file_name: %s, class_name: %s",
                guardrail_type,
                _file_name,
                _class_name,
            )

            directory = os.path.dirname(config_file_path)
            module_file_path = os.path.join(directory, _file_name) + ".py"

            # NOTE(review): relies on `importlib.util` being importable via the
            # bare `import importlib` at module top — works in practice because
            # other imports load the submodule; `import importlib.util` would
            # be more robust.
            spec = importlib.util.spec_from_file_location(_class_name, module_file_path)  # type: ignore
            if not spec:
                raise ImportError(
                    f"Could not find a module specification for {module_file_path}"
                )

            module = importlib.util.module_from_spec(spec)  # type: ignore
            spec.loader.exec_module(module)  # type: ignore

            _guardrail_class = getattr(module, _class_name)
            _guardrail_callback = _guardrail_class(
                guardrail_name=guardrail["guardrail_name"],
                event_hook=litellm_params["mode"],
                default_on=litellm_params["default_on"],
            )
            litellm.logging_callback_manager.add_litellm_callback(_guardrail_callback)  # type: ignore
        else:
            raise ValueError(f"Unsupported guardrail: {guardrail_type}")

        parsed_guardrail = Guardrail(
            guardrail_name=guardrail["guardrail_name"],
            litellm_params=litellm_params,
        )
        guardrail_list.append(parsed_guardrail)

    verbose_proxy_logger.info(f"\nGuardrail List:{guardrail_list}\n")
|