TestLLM / litellm /proxy /auth /auth_utils.py
Raju2024's picture
Upload 1072 files
e3278e4 verified
raw
history blame
17.1 kB
import os
import re
import sys
from typing import Any, List, Optional, Tuple
from fastapi import HTTPException, Request, status
from litellm import Router, provider_list
from litellm._logging import verbose_proxy_logger
from litellm.proxy._types import *
from litellm.types.router import CONFIGURABLE_CLIENTSIDE_AUTH_PARAMS
def _get_request_ip_address(
    request: Request, use_x_forwarded_for: Optional[bool] = False
) -> Optional[str]:
    """
    Resolve the IP address to attribute to the caller of `request`.

    Args:
        request: Incoming request; only `headers` and `client` are read.
        use_x_forwarded_for: When exactly True and the request carries an
            "x-forwarded-for" header, the raw header value is returned as-is.

    Returns:
        The raw "x-forwarded-for" header value, otherwise `request.client.host`,
        otherwise an empty string when the request has no client information.
    """
    # The forwarded header is only consulted when explicitly opted in.
    if use_x_forwarded_for is True and "x-forwarded-for" in request.headers:
        return request.headers["x-forwarded-for"]

    if request.client is None:
        return ""

    return request.client.host
def _check_valid_ip(
    allowed_ips: Optional[List[str]],
    request: Request,
    use_x_forwarded_for: Optional[bool] = False,
) -> Tuple[bool, Optional[str]]:
    """
    Check the caller's IP address against an optional allow-list.

    Args:
        allowed_ips: Allowed IP addresses, or None when no restriction is set.
        request: Incoming request the client IP is resolved from.
        use_x_forwarded_for: Forwarded to `_get_request_ip_address`
            (general_settings "use_x_forwarded_for").

    Returns:
        `(is_allowed, client_ip)`. `client_ip` is None when no allow-list is
        configured, because the address is not resolved in that case.
    """
    # No allow-list configured -> every caller is accepted.
    if allowed_ips is None:
        return True, None

    resolved_ip = _get_request_ip_address(
        request=request, use_x_forwarded_for=use_x_forwarded_for
    )

    # Exact membership test against the configured list.
    return resolved_ip in allowed_ips, resolved_ip
def check_complete_credentials(request_body: dict) -> bool:
    """
    Check whether a request that sets `api_base` also brings its own credentials.

    Used to prevent malicious requests that redirect the proxy's own API key
    to an attacker-controlled endpoint.

    Args:
        request_body: The parsed request body.

    Returns:
        True only when the body names a string `model` that is not one of the
        complex-credential providers and contains an `api_key` entry.
        False otherwise, including when `model` is missing or not a string.
    """
    given_model = request_body.get("model")

    # A missing or non-string model cannot be inspected by the substring checks
    # below; treat it as incomplete credentials instead of raising TypeError.
    if not isinstance(given_model, str):
        return False

    # complex credentials - easier to make a malicious request.
    # "vertex_ai" also matches "vertex_ai_beta".
    complex_credential_providers = ("sagemaker", "bedrock", "vertex_ai")
    if any(provider in given_model for provider in complex_credential_providers):
        return False

    return "api_key" in request_body
def check_regex_or_str_match(request_body_value: Any, regex_str: str) -> bool:
    """
    Check if request_body_value equals regex_str or matches it as a regex.

    Args:
        request_body_value: Value taken from the (untrusted) request body.
        regex_str: Configured value, treated as both a literal string and a
            regular expression.

    Returns:
        True on an exact string match or a successful `re.match`.
        False otherwise, including when the value is not a string or the
        pattern is not a valid regular expression.
    """
    # Only a string can equal, or be matched by, a string pattern. Anything
    # else is rejected instead of letting re.match raise TypeError.
    if not isinstance(request_body_value, str):
        return False

    # Literal comparison first, so a configured value that is not a valid regex
    # (e.g. contains an unbalanced parenthesis) still matches itself.
    if regex_str == request_body_value:
        return True

    try:
        # NOTE: re.match only anchors at the start of the string; configured
        # patterns should end with `$` if a full match is intended.
        return re.match(regex_str, request_body_value) is not None
    except re.error:
        # Invalid pattern: nothing can match it, so reject.
        return False
def _is_param_allowed(
    param: str,
    request_body_value: Any,
    configurable_clientside_auth_params: CONFIGURABLE_CLIENTSIDE_AUTH_PARAMS,
) -> bool:
    """
    Decide whether `param` may be set client-side, given the configured entries.

    Each configured entry is either:
    - a str: allows the param with that exact name, whatever its value
    - a dict: only consulted for `api_base`; its "api_base" value is treated as
      a regex / literal that `request_body_value` has to match

    Returns False when nothing is configured or no entry allows the param.
    """
    if configurable_clientside_auth_params is None:
        return False

    for allowed_entry in configurable_clientside_auth_params:
        if isinstance(allowed_entry, str):
            if allowed_entry == param:
                return True
        elif isinstance(allowed_entry, Dict):
            # dict entries only constrain api_base
            if param != "api_base":
                continue
            if check_regex_or_str_match(
                request_body_value=request_body_value,
                regex_str=allowed_entry["api_base"],
            ):  # assume param is a regex
                return True

    return False
def _allow_model_level_clientside_configurable_parameters(
    model: str, param: str, request_body_value: Any, llm_router: Optional[Router]
) -> bool:
    """
    Check if the model's router config allows `param` to be set client-side.

    Steps:
    - look up the model group info for `model`
    - if there is none and the text before the first "/" is a known provider,
      look up the group registered under that provider name (wildcard model)
    - check the group's `configurable_clientside_auth_params` for `param`

    Returns False when there is no router, no matching group, or the group has
    no client-side configurable params.
    """
    if llm_router is None:
        return False

    # check if model is set
    group_info = llm_router.get_model_group_info(model_group=model)

    if group_info is None:
        # check if wildcard model is set
        provider_prefix = model.split("/", 1)[0]
        if provider_prefix in provider_list:
            group_info = llm_router.get_model_group_info(model_group=provider_prefix)

    if group_info is None:
        return False

    allowed_params = group_info.configurable_clientside_auth_params
    if allowed_params is None:
        return False

    return _is_param_allowed(
        param=param,
        request_body_value=request_body_value,
        configurable_clientside_auth_params=allowed_params,
    )
def is_request_body_safe(
    request_body: dict, general_settings: dict, llm_router: Optional[Router], model: str
) -> bool:
    """
    Check if the request body is safe.

    A malicious user can set the api_base to their own domain and invoke POST /chat/completions to intercept and steal the OpenAI API key.
    Relevant issue: https://huntr.com/bounties/4001e1a2-7b7a-4776-a3ae-e6692ec3d997

    Args:
        request_body: The parsed request body.
        general_settings: Proxy general_settings; `allow_client_side_credentials`
            set to True allows every banned param.
        llm_router: Router used to look up model-level client-side params.
        model: Model name the request targets.

    Returns:
        True when no banned param is present, or every banned param present is
        allowed.

    Raises:
        ValueError: if a banned param is present and not allowed.
    """
    banned_params = ["api_base", "base_url"]

    for param in banned_params:
        if param not in request_body:
            continue

        # allow client-credentials to be passed to proxy
        if check_complete_credentials(request_body=request_body):
            continue

        # Global opt-in covers every banned param, so no further checks needed.
        if general_settings.get("allow_client_side_credentials") is True:
            return True

        if (
            _allow_model_level_clientside_configurable_parameters(
                model=model,
                param=param,
                request_body_value=request_body[param],
                llm_router=llm_router,
            )
            is True
        ):
            # Only THIS param was allowed at model level. Keep validating the
            # remaining banned params; returning here would let e.g. a
            # `base_url` ride along unchecked next to an allowed `api_base`.
            continue

        raise ValueError(
            f"Rejected Request: {param} is not allowed in request body. "
            "Enable with `general_settings::allow_client_side_credentials` on proxy config.yaml. "
            "Relevant Issue: https://huntr.com/bounties/4001e1a2-7b7a-4776-a3ae-e6692ec3d997",
        )

    return True
async def pre_db_read_auth_checks(
    request: Request,
    request_data: dict,
    route: str,
):
    """
    Run the auth checks that happen before any database read.

    Checks, in order:
    1. Request size is under max_request_size_mb (if set)
    2. Request body is safe (example: user has not set api_base in request body)
    3. IP address is allowed (if `allowed_ips` is set)
    4. Request route is an allowed route on the proxy (if `allowed_routes` is set)

    Returns:
    - None. The function simply completes without raising when all checks pass.

    Raises:
    - HTTPException (403) if the IP address or the route is not allowed
    - whatever the request-size / request-body checks raise on failure
    """
    from litellm.proxy.proxy_server import general_settings, llm_router, premium_user

    # Check 1. request size
    await check_if_request_size_is_safe(request=request)

    # Check 2. Request body is safe
    # [TODO] use model passed in url as well (azure openai routes)
    requested_model = request_data.get("model", "")
    is_request_body_safe(
        request_body=request_data,
        general_settings=general_settings,
        llm_router=llm_router,
        model=requested_model,
    )

    # Check 3. Check if IP address is allowed
    ip_allowed, caller_ip = _check_valid_ip(
        allowed_ips=general_settings.get("allowed_ips", None),
        use_x_forwarded_for=general_settings.get("use_x_forwarded_for", False),
        request=request,
    )
    if not ip_allowed:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail=f"Access forbidden: IP address {caller_ip} not allowed.",
        )

    # Check 4. Check if request route is an allowed route on the proxy
    if "allowed_routes" not in general_settings:
        return

    permitted_routes = general_settings["allowed_routes"]

    # Non-premium deployments only get an error log here; the route check
    # below still runs for them.
    if premium_user is not True:
        verbose_proxy_logger.error(
            f"Trying to set allowed_routes. This is an Enterprise feature. {CommonProxyErrors.not_premium_user.value}"
        )

    if route in permitted_routes:
        return

    verbose_proxy_logger.error(
        f"Route {route} not in allowed_routes={permitted_routes}"
    )
    raise HTTPException(
        status_code=status.HTTP_403_FORBIDDEN,
        detail=f"Access forbidden: Route {route} not allowed",
    )
def route_in_additonal_public_routes(current_route: str):
    """
    Check whether the user listed `current_route` under public_routes on config.yaml.

    Parameters:
    - current_route: str - the route the user is trying to call

    Returns:
    - bool - True if the proxy runs as premium_user and the route is defined in public_routes
    - bool - False otherwise (also when general_settings is None or the lookup raises)

    To use this, the litellm config.yaml should have the following in general_settings:

    ```yaml
    general_settings:
        master_key: sk-1234
        public_routes: ["LiteLLMRoutes.public_routes", "/spend/calculate"]
    ```
    """
    from litellm.proxy.proxy_server import general_settings, premium_user

    try:
        # Not a premium user, or nothing configured -> no additional public routes.
        if premium_user is not True or general_settings is None:
            return False

        return current_route in general_settings.get("public_routes", [])
    except Exception as e:
        verbose_proxy_logger.error(f"route_in_additonal_public_routes: {str(e)}")
        return False
def get_request_route(request: Request) -> str:
    """
    Return the route of `request`, with the base url path stripped when present.

    e.g. `/genai/chat/completions` -> `/chat/completions`

    Falls back to `request.url.path` when the request has no `base_url`, when
    the path does not start with the base url path, or when stripping fails.
    """
    try:
        if not hasattr(request, "base_url"):
            return request.url.path

        url_path = request.url.path
        base_path = request.base_url.path
        if url_path.startswith(base_path):
            # Strip the base path but keep its last character, so the returned
            # route still begins with that character.
            # presumably base_path always ends with "/" - verify against caller
            return url_path[len(base_path) - 1 :]

        return url_path
    except Exception as e:
        verbose_proxy_logger.debug(
            f"error on get_request_route: {str(e)}, defaulting to request.url.path={request.url.path}"
        )
        return request.url.path
async def check_if_request_size_is_safe(request: Request) -> bool:
    """
    Enterprise Only:
        - Checks if the request size is within `max_request_size_mb`

    The size is taken from the Content-Length header when that header is
    present and non-empty; otherwise the body is read and measured.

    Args:
        request (Request): The incoming request.

    Returns:
        bool: True if no limit is configured, the proxy is not running as
        premium_user (check skipped with a warning), or the size is within
        the limit.

    Raises:
        ProxyException: If the request size is too large
    """
    from litellm.proxy.proxy_server import general_settings, premium_user

    max_request_size_mb = general_settings.get("max_request_size_mb", None)
    if max_request_size_mb is None:
        return True

    # Check if premium user
    if premium_user is not True:
        verbose_proxy_logger.warning(
            f"using max_request_size_mb - not checking - this is an enterprise only feature. {CommonProxyErrors.not_premium_user.value}"
        )
        return True

    content_length = request.headers.get("content-length")
    if content_length:
        measured_mb = bytes_to_mb(bytes_value=int(content_length))
        verbose_proxy_logger.debug(f"content_length request size in MB={measured_mb}")
    else:
        # If Content-Length is not available, read the body
        raw_body = await request.body()
        measured_mb = bytes_to_mb(bytes_value=len(raw_body))
        verbose_proxy_logger.debug(f"request body request size in MB={measured_mb}")

    if measured_mb > max_request_size_mb:
        raise ProxyException(
            message=f"Request size is too large. Request size is {measured_mb} MB. Max size is {max_request_size_mb} MB",
            type=ProxyErrorTypes.bad_request_error.value,
            code=400,
            param="content-length",
        )

    return True
async def check_response_size_is_safe(response: Any) -> bool:
    """
    Enterprise Only:
        - Checks if the response size is within `max_response_size_mb`

    Args:
        response (Any): The response to check.

    Returns:
        bool: True if no limit is configured, the proxy is not running as
        premium_user (check skipped with a warning), or the size is within
        the limit.

    Raises:
        ProxyException: If the response size is too large
    """
    from litellm.proxy.proxy_server import general_settings, premium_user

    max_response_size_mb = general_settings.get("max_response_size_mb", None)
    if max_response_size_mb is None:
        return True

    # Check if premium user
    if premium_user is not True:
        verbose_proxy_logger.warning(
            f"using max_response_size_mb - not checking - this is an enterprise only feature. {CommonProxyErrors.not_premium_user.value}"
        )
        return True

    # NOTE(review): sys.getsizeof reports only the object's own (shallow)
    # size, not the size of anything it references - confirm this is the
    # intended measure for container-like responses.
    shallow_size_bytes = sys.getsizeof(response)
    response_size_mb = bytes_to_mb(bytes_value=shallow_size_bytes)
    verbose_proxy_logger.debug(f"response size in MB={response_size_mb}")

    if response_size_mb > max_response_size_mb:
        raise ProxyException(
            message=f"Response size is too large. Response size is {response_size_mb} MB. Max size is {max_response_size_mb} MB",
            type=ProxyErrorTypes.bad_request_error.value,
            code=400,
            param="content-length",
        )

    return True
def bytes_to_mb(bytes_value: int):
    """
    Convert a byte count to megabytes, where 1 MB = 1024 * 1024 bytes.
    """
    bytes_per_mb = 1024 * 1024
    return bytes_value / bytes_per_mb
# helpers used by parallel request limiter to handle model rpm/tpm limits for a given api key
def get_key_model_rpm_limit(user_api_key_dict: UserAPIKeyAuth) -> Optional[dict]:
    """
    Get the per-model rpm limits configured for an api key.

    Lookup order:
    - `metadata["model_rpm_limit"]`, returned as-is when the key is present
    - otherwise, the non-None `rpm_limit` of every per-model entry in
      `model_max_budget`, collected into `{model: rpm_limit}`

    The `model_max_budget` fallback is used whenever metadata does not carry
    `model_rpm_limit`, including when metadata holds other, unrelated keys.

    Returns:
        The limits dict (possibly empty when `model_max_budget` is set but has
        no rpm limits), or None when neither source is set.
    """
    key_metadata = user_api_key_dict.metadata
    if key_metadata and "model_rpm_limit" in key_metadata:
        return key_metadata["model_rpm_limit"]

    model_max_budget = user_api_key_dict.model_max_budget
    if model_max_budget:
        # Skip entries that are not dicts (e.g. a plain number), since they
        # cannot carry an "rpm_limit" key.
        return {
            model: budget["rpm_limit"]
            for model, budget in model_max_budget.items()
            if isinstance(budget, dict) and budget.get("rpm_limit") is not None
        }

    return None
def get_key_model_tpm_limit(user_api_key_dict: UserAPIKeyAuth) -> Optional[dict]:
    """
    Get the per-model tpm limits configured for an api key.

    Lookup order:
    - `metadata["model_tpm_limit"]`, returned as-is when the key is present
    - `model_max_budget["tpm_limit"]`, returned as-is when present (the lookup
      this helper has always done)
    - otherwise, the non-None `tpm_limit` of every per-model entry in
      `model_max_budget`, collected into `{model: tpm_limit}` - the same
      per-model shape `get_key_model_rpm_limit` reads for `rpm_limit`

    The `model_max_budget` fallbacks are used whenever metadata does not carry
    `model_tpm_limit`, including when metadata holds other, unrelated keys.

    Returns:
        The limits, or None when no tpm limit is configured anywhere.
    """
    key_metadata = user_api_key_dict.metadata
    if key_metadata and "model_tpm_limit" in key_metadata:
        return key_metadata["model_tpm_limit"]

    model_max_budget = user_api_key_dict.model_max_budget
    if model_max_budget:
        # Original top-level lookup, kept so existing configs behave the same.
        if "tpm_limit" in model_max_budget:
            return model_max_budget["tpm_limit"]

        # Per-model entries; anything that is not a dict cannot carry a limit.
        per_model_limits = {
            model: budget["tpm_limit"]
            for model, budget in model_max_budget.items()
            if isinstance(budget, dict) and budget.get("tpm_limit") is not None
        }
        if per_model_limits:
            return per_model_limits

    return None
def is_pass_through_provider_route(route: str) -> bool:
    """
    Return True when `route` contains one of the provider-specific
    pass-through markers (currently only "vertex-ai").
    """
    provider_markers = ("vertex-ai",)

    # substring check: the marker may appear anywhere in the route
    return any(marker in route for marker in provider_markers)
def should_run_auth_on_pass_through_provider_route(route: str) -> bool:
    """
    Decide if the rest of the LiteLLM Virtual Key auth checks should run on
    provider pass-through routes, e.g. /vertex-ai/{endpoint}.

    Virtual key auth runs only when both hold:
    - the proxy runs as premium_user
    - general_settings `use_client_credentials_pass_through_routes` is not True

    `route` itself is not inspected.
    """
    from litellm.proxy.proxy_server import general_settings, premium_user

    # only enabled for LiteLLM Enterprise
    if premium_user is not True:
        return False

    # premium user has opted into using client credentials
    opted_into_client_credentials = (
        general_settings.get("use_client_credentials_pass_through_routes", False)
        is True
    )
    return not opted_into_client_credentials
def _has_user_setup_sso():
    """
    Check if single sign-on (SSO) has been set up.

    SSO counts as set up when at least one of the MICROSOFT_CLIENT_ID,
    GOOGLE_CLIENT_ID or GENERIC_CLIENT_ID environment variables is defined
    (an empty value still counts as defined).

    Returns a boolean indicating whether SSO has been set up.
    """
    sso_client_id_vars = (
        "MICROSOFT_CLIENT_ID",
        "GOOGLE_CLIENT_ID",
        "GENERIC_CLIENT_ID",
    )
    return any(os.getenv(var_name, None) is not None for var_name in sso_client_id_vars)
def get_end_user_id_from_request_body(request_body: dict) -> Optional[str]:
    """
    Extract the end-user id from a request body.

    Lookup order (first hit wins):
    - `user` (openai), when not None
    - `litellm_metadata["user"]` (anthropic), when truthy
    - `metadata["user_id"]`, when not None

    `litellm_metadata` / `metadata` values that are not dicts (e.g. JSON null
    or a string) are ignored rather than raising.

    Returns:
        The id converted to str, or None when no id is found.
    """
    # openai - check 'user'
    if request_body.get("user") is not None:
        return str(request_body["user"])

    # anthropic - check 'litellm_metadata'
    litellm_metadata = request_body.get("litellm_metadata")
    if isinstance(litellm_metadata, dict):
        end_user_id = litellm_metadata.get("user")
        if end_user_id:
            return str(end_user_id)

    metadata = request_body.get("metadata")
    if isinstance(metadata, dict) and metadata.get("user_id") is not None:
        return str(metadata["user_id"])

    return None