from typing import Any, Dict, List, Optional

import litellm
from litellm import get_secret
from litellm._logging import verbose_proxy_logger
from litellm.proxy._types import CommonProxyErrors, LiteLLMPromptInjectionParams
from litellm.proxy.utils import get_instance_fn

blue_color_code = "\033[94m"
reset_color_code = "\033[0m"


def initialize_callbacks_on_proxy(
    value: Any,
    premium_user: bool,
    config_file_path: str,
    litellm_settings: dict,
    callback_specific_params: dict = {},
):
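    """
    Initialize the callbacks configured on the proxy and attach them to `litellm.callbacks`.

    `value` is either a list of callback names / importable custom callback paths, or a
    single callback reference. Known names are mapped to their handler objects; anything
    else is imported via `get_instance_fn`. Enterprise-only callbacks raise an Exception
    when `premium_user` is not True.
    """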
    from litellm.proxy.proxy_server import prisma_client

    verbose_proxy_logger.debug(
        f"{blue_color_code}initializing callbacks={value} on proxy{reset_color_code}"
    )
    if isinstance(value, list):
        imported_list: List[Any] = []
        for callback in value:
            if (
                isinstance(callback, str)
                and callback in litellm._known_custom_logger_compatible_callbacks
            ):
                imported_list.append(callback)
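            # "presidio" -> optional PII-masking hook; extra init params may be passed
            # via callback_specific_params["presidio"]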
            elif isinstance(callback, str) and callback == "presidio":
                from litellm.proxy.guardrails.guardrail_hooks.presidio import (
                    _OPTIONAL_PresidioPIIMasking,
                )

                presidio_logging_only: Optional[bool] = litellm_settings.get(
                    "presidio_logging_only", None
                )
                if presidio_logging_only is not None:
                    presidio_logging_only = bool(presidio_logging_only)

                _presidio_params = {}
                if "presidio" in callback_specific_params and isinstance(
                    callback_specific_params["presidio"], dict
                ):
                    _presidio_params = callback_specific_params["presidio"]

                params: Dict[str, Any] = {
                    "logging_only": presidio_logging_only,
                    **_presidio_params,
                }
                pii_masking_object = _OPTIONAL_PresidioPIIMasking(**params)
                imported_list.append(pii_masking_object)
            elif isinstance(callback, str) and callback == "llamaguard_moderations":
                from enterprise.enterprise_hooks.llama_guard import (
                    _ENTERPRISE_LlamaGuard,
                )

                if premium_user is not True:
                    raise Exception(
                        "Trying to use Llama Guard. "
                        + CommonProxyErrors.not_premium_user.value
                    )

                llama_guard_object = _ENTERPRISE_LlamaGuard()
                imported_list.append(llama_guard_object)
            elif isinstance(callback, str) and callback == "hide_secrets":
                from enterprise.enterprise_hooks.secret_detection import (
                    _ENTERPRISE_SecretDetection,
                )

                if premium_user is not True:
                    raise Exception(
                        "Trying to use secret hiding. "
                        + CommonProxyErrors.not_premium_user.value
                    )

                _secret_detection_object = _ENTERPRISE_SecretDetection()
                imported_list.append(_secret_detection_object)
            elif isinstance(callback, str) and callback == "openai_moderations":
                from enterprise.enterprise_hooks.openai_moderation import (
                    _ENTERPRISE_OpenAI_Moderation,
                )

                if premium_user is not True:
                    raise Exception(
                        "Trying to use OpenAI Moderations Check. "
                        + CommonProxyErrors.not_premium_user.value
                    )

                openai_moderations_object = _ENTERPRISE_OpenAI_Moderation()
                imported_list.append(openai_moderations_object)
            elif isinstance(callback, str) and callback == "lakera_prompt_injection":
                from litellm.proxy.guardrails.guardrail_hooks.lakera_ai import (
                    lakeraAI_Moderation,
                )

                init_params = {}
                if "lakera_prompt_injection" in callback_specific_params:
                    init_params = callback_specific_params["lakera_prompt_injection"]
                lakera_moderations_object = lakeraAI_Moderation(**init_params)
                imported_list.append(lakera_moderations_object)
            elif isinstance(callback, str) and callback == "aporia_prompt_injection":
                from litellm.proxy.guardrails.guardrail_hooks.aporia_ai import (
                    AporiaGuardrail,
                )

                aporia_guardrail_object = AporiaGuardrail()
                imported_list.append(aporia_guardrail_object)
            elif isinstance(callback, str) and callback == "google_text_moderation":
                from enterprise.enterprise_hooks.google_text_moderation import (
                    _ENTERPRISE_GoogleTextModeration,
                )

                if premium_user is not True:
                    raise Exception(
                        "Trying to use Google Text Moderation. "
                        + CommonProxyErrors.not_premium_user.value
                    )

                google_text_moderation_obj = _ENTERPRISE_GoogleTextModeration()
                imported_list.append(google_text_moderation_obj)
            elif isinstance(callback, str) and callback == "llmguard_moderations":
                from enterprise.enterprise_hooks.llm_guard import _ENTERPRISE_LLMGuard

                if premium_user is not True:
                    raise Exception(
                        "Trying to use LLM Guard. "
                        + CommonProxyErrors.not_premium_user.value
                    )

                llm_guard_moderation_obj = _ENTERPRISE_LLMGuard()
                imported_list.append(llm_guard_moderation_obj)
            elif isinstance(callback, str) and callback == "blocked_user_check":
                from enterprise.enterprise_hooks.blocked_user_list import (
                    _ENTERPRISE_BlockedUserList,
                )

                if premium_user is not True:
                    raise Exception(
                        "Trying to use ENTERPRISE BlockedUser. "
                        + CommonProxyErrors.not_premium_user.value
                    )

                blocked_user_list = _ENTERPRISE_BlockedUserList(
                    prisma_client=prisma_client
                )
                imported_list.append(blocked_user_list)
            elif isinstance(callback, str) and callback == "banned_keywords":
                from enterprise.enterprise_hooks.banned_keywords import (
                    _ENTERPRISE_BannedKeywords,
                )

                if premium_user is not True:
                    raise Exception(
                        "Trying to use ENTERPRISE BannedKeyword. "
                        + CommonProxyErrors.not_premium_user.value
                    )

                banned_keywords_obj = _ENTERPRISE_BannedKeywords()
                imported_list.append(banned_keywords_obj)
            elif isinstance(callback, str) and callback == "detect_prompt_injection":
                from litellm.proxy.hooks.prompt_injection_detection import (
                    _OPTIONAL_PromptInjectionDetection,
                )

                prompt_injection_params = None
                if "prompt_injection_params" in litellm_settings:
                    prompt_injection_params_in_config = litellm_settings[
                        "prompt_injection_params"
                    ]
                    prompt_injection_params = LiteLLMPromptInjectionParams(
                        **prompt_injection_params_in_config
                    )

                prompt_injection_detection_obj = _OPTIONAL_PromptInjectionDetection(
                    prompt_injection_params=prompt_injection_params,
                )
                imported_list.append(prompt_injection_detection_obj)
            elif isinstance(callback, str) and callback == "batch_redis_requests":
                from litellm.proxy.hooks.batch_redis_get import (
                    _PROXY_BatchRedisRequests,
                )

                batch_redis_obj = _PROXY_BatchRedisRequests()
                imported_list.append(batch_redis_obj)
            elif isinstance(callback, str) and callback == "azure_content_safety":
                from litellm.proxy.hooks.azure_content_safety import (
                    _PROXY_AzureContentSafety,
                )

                azure_content_safety_params = litellm_settings[
                    "azure_content_safety_params"
                ]
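                # Values of the form "os.environ/<VAR>" are resolved to their secret values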
                for k, v in azure_content_safety_params.items():
                    if (
                        v is not None
                        and isinstance(v, str)
                        and v.startswith("os.environ/")
                    ):
                        azure_content_safety_params[k] = get_secret(v)

                azure_content_safety_obj = _PROXY_AzureContentSafety(
                    **azure_content_safety_params,
                )
                imported_list.append(azure_content_safety_obj)
            else:
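                # Not a recognized callback name - treat it as a path to a custom
                # callback and import it via get_instance_fn (resolved relative to
                # the config file)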
                verbose_proxy_logger.debug(
                    f"{blue_color_code} attempting to import custom callback={callback} {reset_color_code}"
                )
                imported_list.append(
                    get_instance_fn(
                        value=callback,
                        config_file_path=config_file_path,
                    )
                )
        if isinstance(litellm.callbacks, list):
            litellm.callbacks.extend(imported_list)
        else:
            litellm.callbacks = imported_list
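
        # If the prometheus callback is configured, also expose the /metrics endpoint
        # on the proxy's FastAPI app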
        if "prometheus" in value:
            if premium_user is not True:
                verbose_proxy_logger.warning(
                    f"Prometheus metrics are only available for premium users. {CommonProxyErrors.not_premium_user.value}"
                )
            from litellm.proxy.proxy_server import app

            verbose_proxy_logger.debug("Starting Prometheus Metrics on /metrics")
            from prometheus_client import make_asgi_app

            metrics_app = make_asgi_app()
            app.mount("/metrics", metrics_app)
    else:
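        # value is a single callback (e.g. a module path to a custom handler):
        # import it and overwrite litellm.callbacks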
        litellm.callbacks = [
            get_instance_fn(
                value=value,
                config_file_path=config_file_path,
            )
        ]
    verbose_proxy_logger.debug(
        f"{blue_color_code} Initialized Callbacks - {litellm.callbacks} {reset_color_code}"
    )


def get_model_group_from_litellm_kwargs(kwargs: dict) -> Optional[str]:
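    """Return the `model_group` stored in kwargs["litellm_params"]["metadata"], if any."""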
    _litellm_params = kwargs.get("litellm_params", None) or {}
    _metadata = _litellm_params.get("metadata", None) or {}
    _model_group = _metadata.get("model_group", None)
    if _model_group is not None:
        return _model_group

    return None


def get_model_group_from_request_data(data: dict) -> Optional[str]:
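    """Return the `model_group` stored in the request's `metadata`, if any."""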
    _metadata = data.get("metadata", None) or {}
    _model_group = _metadata.get("model_group", None)
    if _model_group is not None:
        return _model_group

    return None


def get_remaining_tokens_and_requests_from_request_data(data: Dict) -> Dict[str, str]:
    """
    Build the x-litellm-key-remaining-requests-{model_group} and
    x-litellm-key-remaining-tokens-{model_group} response headers.

    Returns {} when no rpm/tpm limit is set for the api_key + model group.
    """
    headers = {}
    _metadata = data.get("metadata", None) or {}
    model_group = get_model_group_from_request_data(data)

    remaining_requests_variable_name = f"litellm-key-remaining-requests-{model_group}"
    remaining_requests = _metadata.get(remaining_requests_variable_name, None)
    if remaining_requests:
        headers[f"x-litellm-key-remaining-requests-{model_group}"] = remaining_requests

    remaining_tokens_variable_name = f"litellm-key-remaining-tokens-{model_group}"
    remaining_tokens = _metadata.get(remaining_tokens_variable_name, None)
    if remaining_tokens:
        headers[f"x-litellm-key-remaining-tokens-{model_group}"] = remaining_tokens

    return headers


def get_logging_caching_headers(request_data: Dict) -> Optional[Dict]:
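    """
    Build logging/caching response headers (`x-litellm-applied-guardrails`,
    `x-litellm-semantic-similarity`) from the request's `metadata`.
    """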
    _metadata = request_data.get("metadata", None) or {}
    headers = {}
    if "applied_guardrails" in _metadata:
        headers["x-litellm-applied-guardrails"] = ",".join(
            _metadata["applied_guardrails"]
        )

    if "semantic-similarity" in _metadata:
        headers["x-litellm-semantic-similarity"] = str(_metadata["semantic-similarity"])

    return headers


def add_guardrail_to_applied_guardrails_header(
    request_data: Dict, guardrail_name: Optional[str]
):
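    """
    Append `guardrail_name` to request_data["metadata"]["applied_guardrails"],
    creating the list (and the metadata dict) if needed. Mutates request_data in place.
    """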
    if guardrail_name is None:
        return
    _metadata = request_data.get("metadata", None) or {}
    if "applied_guardrails" in _metadata:
        _metadata["applied_guardrails"].append(guardrail_name)
    else:
        _metadata["applied_guardrails"] = [guardrail_name]
    # Ensure the (possibly newly created) metadata dict is stored back on the request
    request_data["metadata"] = _metadata