|
import importlib |
|
import os |
|
from typing import Dict, List, Optional |
|
|
|
import litellm |
|
from litellm import get_secret |
|
from litellm._logging import verbose_proxy_logger |
|
from litellm.proxy.common_utils.callback_utils import initialize_callbacks_on_proxy |
|
|
|
|
|
from litellm.types.guardrails import ( |
|
Guardrail, |
|
GuardrailItem, |
|
GuardrailItemSpec, |
|
LakeraCategoryThresholds, |
|
LitellmParams, |
|
) |
|
|
|
from .guardrail_registry import guardrail_registry |
|
|
|
# Module-level accumulator of every GuardrailItem parsed by
# initialize_guardrails(). That function appends to this list on each call;
# nothing in this module clears it.
all_guardrails: List[GuardrailItem] = []
|
|
|
|
|
def initialize_guardrails(
    guardrails_config: List[Dict[str, GuardrailItemSpec]],
    premium_user: bool,
    config_file_path: str,
    litellm_settings: dict,
) -> Dict[str, GuardrailItem]:
    """
    Parse guardrail config entries and enable the callbacks of default-on guardrails.

    Every entry of ``guardrails_config`` is a mapping of guardrail name to its
    spec, for example::

        {'prompt_injection': {'callbacks': ['lakera_prompt_injection', 'prompt_injection_api_2'], 'default_on': True, 'enabled_roles': ['user']}}

    Each spec is turned into a ``GuardrailItem``, appended to the module-level
    ``all_guardrails`` list and stored in ``litellm.guardrail_name_config_map``
    under its name.

    Args:
        guardrails_config: List of ``{guardrail_name: spec}`` mappings.
        premium_user: Forwarded unchanged to ``initialize_callbacks_on_proxy``.
        config_file_path: Forwarded unchanged to ``initialize_callbacks_on_proxy``.
        litellm_settings: Forwarded unchanged to ``initialize_callbacks_on_proxy``.

    Returns:
        ``litellm.guardrail_name_config_map``.

    Raises:
        Exception: Any error raised while initializing is logged and re-raised.
    """
    global all_guardrails
    try:
        verbose_proxy_logger.debug(f"validating guardrails passed {guardrails_config}")

        for config_entry in guardrails_config:
            for name, spec in config_entry.items():
                parsed_item = GuardrailItem(**spec, guardrail_name=name)
                all_guardrails.append(parsed_item)
                litellm.guardrail_name_config_map[name] = parsed_item

        # Callbacks of default-on guardrails that are not registered on litellm yet.
        callbacks_to_enable = set()
        params_by_callback = {}
        # Iterates the module-level list, so items from earlier calls are included.
        for registered in all_guardrails:
            verbose_proxy_logger.debug(registered.guardrail_name)
            verbose_proxy_logger.debug(registered.default_on)

            params_by_callback.update(registered.callback_args)

            if registered.default_on is not True:
                continue

            for callback_name in registered.callbacks:
                if callback_name not in litellm.callbacks:
                    callbacks_to_enable.add(callback_name)

                # A logging-only guardrail replaces any presidio params gathered so far.
                if registered.logging_only is True and callback_name == "presidio":
                    params_by_callback["presidio"] = {"logging_only": True}

        if callbacks_to_enable:
            initialize_callbacks_on_proxy(
                value=list(callbacks_to_enable),
                premium_user=premium_user,
                config_file_path=config_file_path,
                litellm_settings=litellm_settings,
                callback_specific_params=params_by_callback,
            )

        return litellm.guardrail_name_config_map
    except Exception as e:
        verbose_proxy_logger.exception(
            "error initializing guardrails {}".format(str(e))
        )
        raise e
|
|
|
|
|
""" |
|
Map guardrail_name: <pre_call>, <post_call>, <during_call>
|
|
|
""" |
|
|
|
|
|
def _resolve_secret_reference(value):
    """Return ``value`` resolved through ``get_secret`` when it is an ``os.environ/`` reference."""
    # Non-string or empty values (e.g. None for an unset param) pass through untouched.
    if isinstance(value, str) and value.startswith("os.environ/"):
        return str(get_secret(value))
    return value


def _load_custom_guardrail_class(guardrail_type: str, config_file_path: str):
    """Load the class named by ``<file_name>.<class_name>`` from a file next to the config file."""
    # `import importlib` alone does not guarantee that the `importlib.util`
    # submodule has been loaded, so import it explicitly here.
    import importlib.util

    name_parts = guardrail_type.split(".")
    if len(name_parts) != 2:
        raise ValueError(
            f"Custom guardrail must be of the form '<file_name>.<class_name>', got: {guardrail_type}"
        )
    _file_name, _class_name = name_parts
    verbose_proxy_logger.debug(
        "Initializing custom guardrail: %s, file_name: %s, class_name: %s",
        guardrail_type,
        _file_name,
        _class_name,
    )

    # The custom guardrail file is looked up in the config file's directory.
    directory = os.path.dirname(config_file_path)
    module_file_path = os.path.join(directory, _file_name) + ".py"

    spec = importlib.util.spec_from_file_location(_class_name, module_file_path)
    # A spec without a loader cannot be executed; treat it like a missing spec.
    if spec is None or spec.loader is None:
        raise ImportError(
            f"Could not find a module specification for {module_file_path}"
        )

    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return getattr(module, _class_name)


def init_guardrails_v2(
    all_guardrails: List[Dict],
    config_file_path: Optional[str] = None,
):
    """
    Initialize every guardrail described in ``all_guardrails``.

    For each entry, the ``litellm_params`` mapping is converted into a
    ``LitellmParams`` (keys it does not provide are set to ``None``), and
    ``api_key`` / ``api_base`` values starting with ``os.environ/`` are
    resolved through ``get_secret``. The guardrail is then initialized by one
    of two routes:

    * the initializer registered in ``guardrail_registry`` for its
      ``guardrail`` type, or
    * when the type is a string containing a dot, a custom class loaded from
      ``<file_name>.py`` in the directory of ``config_file_path``. The class
      is instantiated and added as a litellm callback.

    Args:
        all_guardrails: Guardrail definitions; each dict must have the keys
            ``"guardrail_name"`` and ``"litellm_params"``.
        config_file_path: Path of the proxy config file. Required only for
            custom (``<file_name>.<class_name>``) guardrails.

    Returns:
        None. The parsed guardrails are only logged.

    Raises:
        Exception: A custom guardrail is configured but ``config_file_path``
            was not provided.
        ValueError: The guardrail type is unsupported, or a custom guardrail
            type is not exactly ``<file_name>.<class_name>``.
        ImportError: No loadable module spec could be created for the custom
            guardrail file.
    """
    guardrail_list: List[Guardrail] = []

    for guardrail in all_guardrails:
        litellm_params_data = guardrail["litellm_params"]
        verbose_proxy_logger.debug("litellm_params= %s", litellm_params_data)

        # Populate every declared LitellmParams key; missing ones become None.
        litellm_params = LitellmParams(
            **{k: litellm_params_data.get(k) for k in LitellmParams.__annotations__}
        )

        category_thresholds = litellm_params_data.get("category_thresholds")
        if category_thresholds:
            litellm_params["category_thresholds"] = LakeraCategoryThresholds(
                **category_thresholds
            )

        litellm_params["api_key"] = _resolve_secret_reference(
            litellm_params["api_key"]
        )
        litellm_params["api_base"] = _resolve_secret_reference(
            litellm_params["api_base"]
        )

        guardrail_type = litellm_params["guardrail"]
        initializer = guardrail_registry.get(guardrail_type)

        if initializer:
            initializer(litellm_params, guardrail)
        elif isinstance(guardrail_type, str) and "." in guardrail_type:
            if not config_file_path:
                raise Exception(
                    "GuardrailsAIException - Please pass the config_file_path to init_guardrails_v2"
                )

            _guardrail_class = _load_custom_guardrail_class(
                guardrail_type, config_file_path
            )
            _guardrail_callback = _guardrail_class(
                guardrail_name=guardrail["guardrail_name"],
                event_hook=litellm_params["mode"],
                default_on=litellm_params["default_on"],
            )
            litellm.logging_callback_manager.add_litellm_callback(_guardrail_callback)
        else:
            raise ValueError(f"Unsupported guardrail: {guardrail_type}")

        guardrail_list.append(
            Guardrail(
                guardrail_name=guardrail["guardrail_name"],
                litellm_params=litellm_params,
            )
        )

    verbose_proxy_logger.info("\nGuardrail List:%s\n", guardrail_list)
|
|