import os
import re
import sys
from typing import Any, Dict, List, Optional, Tuple

from fastapi import HTTPException, Request, status

from litellm import Router, provider_list
from litellm._logging import verbose_proxy_logger
from litellm.proxy._types import *
from litellm.types.router import CONFIGURABLE_CLIENTSIDE_AUTH_PARAMS


def _get_request_ip_address(
    request: Request, use_x_forwarded_for: Optional[bool] = False
) -> Optional[str]:
    """
    Helper to get the client IP address from the request, optionally trusting the `x-forwarded-for` header.
    """
    client_ip = None
    if use_x_forwarded_for is True and "x-forwarded-for" in request.headers:
        client_ip = request.headers["x-forwarded-for"]
    elif request.client is not None:
        client_ip = request.client.host
    else:
        client_ip = ""

    return client_ip


def _check_valid_ip(
    allowed_ips: Optional[List[str]],
    request: Request,
    use_x_forwarded_for: Optional[bool] = False,
) -> Tuple[bool, Optional[str]]:
    """
    Returns whether the client IP is allowed, along with the IP that was checked.
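
    Illustrative proxy config feeding `allowed_ips` (values are placeholders):

    ```yaml
    general_settings:
      allowed_ips: ["127.0.0.1", "192.168.1.10"]
    ```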
    """
    if allowed_ips is None:
        return True, None

    client_ip = _get_request_ip_address(
        request=request, use_x_forwarded_for=use_x_forwarded_for
    )

    if client_ip not in allowed_ips:
        return False, client_ip

    return True, client_ip


def check_complete_credentials(request_body: dict) -> bool:
    """
    If 'api_base' is in the request body, check whether complete credentials were also given. Prevents malicious attacks.
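
    Illustrative request body that passes this check (placeholder values; a non-excluded
    model plus the caller's own api_key):

        {"model": "gpt-4o", "api_base": "https://example.com/v1", "api_key": "sk-..."}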
    """
    given_model: Optional[str] = request_body.get("model")
    if given_model is None:
        return False

    if (
        "sagemaker" in given_model
        or "bedrock" in given_model
        or "vertex_ai" in given_model
        or "vertex_ai_beta" in given_model
    ):
        # these providers need more than an api_key, so an api_key alone is not complete
        return False

    if "api_key" in request_body:
        return True

    return False


def check_regex_or_str_match(request_body_value: Any, regex_str: str) -> bool:
    """
    Check if request_body_value matches regex_str or is exactly equal to it.
    """
    if re.match(regex_str, request_body_value) or regex_str == request_body_value:
        return True
    return False


def _is_param_allowed(
    param: str,
    request_body_value: Any,
    configurable_clientside_auth_params: CONFIGURABLE_CLIENTSIDE_AUTH_PARAMS,
) -> bool:
    """
    Check each allowed item (a str or dict) and decide whether request_body_value is permitted.
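
    Illustrative allowed-values list (shape inferred from the checks below; a plain string
    allows a param outright, a dict pins `api_base` to a regex or exact string):

        ["api_base", {"api_base": "https://allowed\\.example\\.com/.*"}]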
    """
    if configurable_clientside_auth_params is None:
        return False

    for item in configurable_clientside_auth_params:
        if isinstance(item, str) and param == item:
            return True
        elif isinstance(item, Dict):
            if param == "api_base" and check_regex_or_str_match(
                request_body_value=request_body_value,
                regex_str=item["api_base"],
            ):
                return True

    return False


def _allow_model_level_clientside_configurable_parameters(
    model: str, param: str, request_body_value: Any, llm_router: Optional[Router]
) -> bool:
    """
    Check if model is allowed to use configurable client-side params:
    - get the matching model
    - check if 'clientside_configurable_parameters' is set for the model
    - check if the given param/value pair is in that allowed list
    """
    if llm_router is None:
        return False

    model_info = llm_router.get_model_group_info(model_group=model)
    if model_info is None:
        # fall back to the provider prefix, e.g. 'openai' for 'openai/gpt-4o'
        if model.split("/", 1)[0] in provider_list:
            model_info = llm_router.get_model_group_info(
                model_group=model.split("/", 1)[0]
            )

        if model_info is None:
            return False

    if model_info is None or model_info.configurable_clientside_auth_params is None:
        return False

    return _is_param_allowed(
        param=param,
        request_body_value=request_body_value,
        configurable_clientside_auth_params=model_info.configurable_clientside_auth_params,
    )


def is_request_body_safe(
    request_body: dict, general_settings: dict, llm_router: Optional[Router], model: str
) -> bool:
    """
    Check if the request body is safe.

    A malicious user can set the api_base to their own domain and invoke POST /chat/completions to intercept and steal the OpenAI API key.
    Relevant issue: https://huntr.com/bounties/4001e1a2-7b7a-4776-a3ae-e6692ec3d997
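
    Illustrative proxy config opting in globally (key name taken from the error message
    below; only enable this if you trust callers):

    ```yaml
    general_settings:
      allow_client_side_credentials: true
    ```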
    """
    banned_params = ["api_base", "base_url"]

    for param in banned_params:
        if param in request_body and not check_complete_credentials(
            request_body=request_body
        ):
            if general_settings.get("allow_client_side_credentials") is True:
                return True
            elif (
                _allow_model_level_clientside_configurable_parameters(
                    model=model,
                    param=param,
                    request_body_value=request_body[param],
                    llm_router=llm_router,
                )
                is True
            ):
                return True
            raise ValueError(
                f"Rejected Request: {param} is not allowed in request body. "
                "Enable with `general_settings::allow_client_side_credentials` on proxy config.yaml. "
                "Relevant Issue: https://huntr.com/bounties/4001e1a2-7b7a-4776-a3ae-e6692ec3d997",
            )

    return True


async def pre_db_read_auth_checks(
    request: Request,
    request_data: dict,
    route: str,
):
    """
    1. Checks if the request size is under max_request_size_mb (if set)
    2. Checks if the request body is safe (e.g. the user has not set api_base in the request body)
    3. Checks if the IP address is allowed (if set)
    4. Checks if the request route is an allowed route on the proxy (if set)
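
    Illustrative `general_settings` covering these checks (keys taken from the lookups
    below; values are placeholders):

    ```yaml
    general_settings:
      max_request_size_mb: 10
      allowed_ips: ["127.0.0.1"]
      allowed_routes: ["/chat/completions"]
    ```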

    Returns:
    - True

    Raises:
    - HTTPException if request fails initial auth checks
    """
    from litellm.proxy.proxy_server import general_settings, llm_router, premium_user

    # 1. request size
    await check_if_request_size_is_safe(request=request)

    # 2. request body safety
    is_request_body_safe(
        request_body=request_data,
        general_settings=general_settings,
        llm_router=llm_router,
        model=request_data.get("model", ""),
    )

    # 3. IP address
    is_valid_ip, passed_in_ip = _check_valid_ip(
        allowed_ips=general_settings.get("allowed_ips", None),
        use_x_forwarded_for=general_settings.get("use_x_forwarded_for", False),
        request=request,
    )

    if not is_valid_ip:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail=f"Access forbidden: IP address {passed_in_ip} not allowed.",
        )

    # 4. allowed routes
    if "allowed_routes" in general_settings:
        _allowed_routes = general_settings["allowed_routes"]
        if premium_user is not True:
            verbose_proxy_logger.error(
                f"Trying to set allowed_routes. This is an Enterprise feature. {CommonProxyErrors.not_premium_user.value}"
            )
        if route not in _allowed_routes:
            verbose_proxy_logger.error(
                f"Route {route} not in allowed_routes={_allowed_routes}"
            )
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail=f"Access forbidden: Route {route} not allowed",
            )


def route_in_additonal_public_routes(current_route: str):
    """
    Helper to check if the user defined public_routes on config.yaml

    Parameters:
    - current_route: str - the route the user is trying to call

    Returns:
    - bool - True if the route is defined in public_routes
    - bool - False if the route is not defined in public_routes

    In order to use this, the litellm config.yaml should have the following in general_settings:

    ```yaml
    general_settings:
      master_key: sk-1234
      public_routes: ["LiteLLMRoutes.public_routes", "/spend/calculate"]
    ```
    """
    from litellm.proxy.proxy_server import general_settings, premium_user

    try:
        # defining public routes is an enterprise-only feature
        if premium_user is not True:
            return False

        if general_settings is None:
            return False

        routes_defined = general_settings.get("public_routes", [])
        if current_route in routes_defined:
            return True

        return False
    except Exception as e:
        verbose_proxy_logger.error(f"route_in_additonal_public_routes: {str(e)}")
        return False


def get_request_route(request: Request) -> str:
    """
    Helper to get the route from the request.

    Removes the base url from the path if set, e.g. `/genai/chat/completions` -> `/chat/completions`
    """
    try:
        if hasattr(request, "base_url") and request.url.path.startswith(
            request.base_url.path
        ):
            # strip the base url prefix, keeping the leading slash
            return request.url.path[len(request.base_url.path) - 1 :]
        else:
            return request.url.path
    except Exception as e:
        verbose_proxy_logger.debug(
            f"error on get_request_route: {str(e)}, defaulting to request.url.path={request.url.path}"
        )
        return request.url.path


async def check_if_request_size_is_safe(request: Request) -> bool:
    """
    Enterprise Only:
    - Checks if the request size is within the limit

    Args:
        request (Request): The incoming request.

    Returns:
        bool: True if the request size is within the limit

    Raises:
        ProxyException: If the request size is too large
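
    Illustrative proxy config (key read from general_settings below):

    ```yaml
    general_settings:
      max_request_size_mb: 10
    ```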
    """
    from litellm.proxy.proxy_server import general_settings, premium_user

    max_request_size_mb = general_settings.get("max_request_size_mb", None)
    if max_request_size_mb is not None:
        # enterprise-gated: warn and skip the check for non-premium users
        if premium_user is not True:
            verbose_proxy_logger.warning(
                f"using max_request_size_mb - not checking - this is an enterprise only feature. {CommonProxyErrors.not_premium_user.value}"
            )
            return True

        # prefer the content-length header; fall back to reading the body
        content_length = request.headers.get("content-length")

        if content_length:
            header_size = int(content_length)
            header_size_mb = bytes_to_mb(bytes_value=header_size)
            verbose_proxy_logger.debug(
                f"content_length request size in MB={header_size_mb}"
            )

            if header_size_mb > max_request_size_mb:
                raise ProxyException(
                    message=f"Request size is too large. Request size is {header_size_mb} MB. Max size is {max_request_size_mb} MB",
                    type=ProxyErrorTypes.bad_request_error.value,
                    code=400,
                    param="content-length",
                )
        else:
            # no content-length header; read the body to measure it
            body = await request.body()
            body_size = len(body)
            request_size_mb = bytes_to_mb(bytes_value=body_size)

            verbose_proxy_logger.debug(
                f"request body request size in MB={request_size_mb}"
            )
            if request_size_mb > max_request_size_mb:
                raise ProxyException(
                    message=f"Request size is too large. Request size is {request_size_mb} MB. Max size is {max_request_size_mb} MB",
                    type=ProxyErrorTypes.bad_request_error.value,
                    code=400,
                    param="content-length",
                )

    return True


async def check_response_size_is_safe(response: Any) -> bool:
    """
    Enterprise Only:
    - Checks if the response size is within the limit

    Args:
        response (Any): The response to check.

    Returns:
        bool: True if the response size is within the limit

    Raises:
        ProxyException: If the response size is too large
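
    Illustrative proxy config (key read from general_settings below):

    ```yaml
    general_settings:
      max_response_size_mb: 10
    ```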
    """
    from litellm.proxy.proxy_server import general_settings, premium_user

    max_response_size_mb = general_settings.get("max_response_size_mb", None)
    if max_response_size_mb is not None:
        # enterprise-gated: warn and skip the check for non-premium users
        if premium_user is not True:
            verbose_proxy_logger.warning(
                f"using max_response_size_mb - not checking - this is an enterprise only feature. {CommonProxyErrors.not_premium_user.value}"
            )
            return True

        response_size_mb = bytes_to_mb(bytes_value=sys.getsizeof(response))
        verbose_proxy_logger.debug(f"response size in MB={response_size_mb}")
        if response_size_mb > max_response_size_mb:
            raise ProxyException(
                message=f"Response size is too large. Response size is {response_size_mb} MB. Max size is {max_response_size_mb} MB",
                type=ProxyErrorTypes.bad_request_error.value,
                code=400,
                param="content-length",
            )

    return True


def bytes_to_mb(bytes_value: int):
    """
    Helper to convert bytes to MB
    """
    return bytes_value / (1024 * 1024)


def get_key_model_rpm_limit(user_api_key_dict: UserAPIKeyAuth) -> Optional[dict]:
    """
    Return per-model RPM limits for a key: from key metadata if set, else derived from model_max_budget.
    """
    if user_api_key_dict.metadata:
        if "model_rpm_limit" in user_api_key_dict.metadata:
            return user_api_key_dict.metadata["model_rpm_limit"]
    elif user_api_key_dict.model_max_budget:
        model_rpm_limit: Dict[str, Any] = {}
        for model, budget in user_api_key_dict.model_max_budget.items():
            if "rpm_limit" in budget and budget["rpm_limit"] is not None:
                model_rpm_limit[model] = budget["rpm_limit"]
        return model_rpm_limit

    return None


def get_key_model_tpm_limit(user_api_key_dict: UserAPIKeyAuth) -> Optional[dict]:
    """
    Return per-model TPM limits for a key: from key metadata if set, else from model_max_budget.
    """
    if user_api_key_dict.metadata:
        if "model_tpm_limit" in user_api_key_dict.metadata:
            return user_api_key_dict.metadata["model_tpm_limit"]
    elif user_api_key_dict.model_max_budget:
        if "tpm_limit" in user_api_key_dict.model_max_budget:
            return user_api_key_dict.model_max_budget["tpm_limit"]

    return None


def is_pass_through_provider_route(route: str) -> bool:
    PROVIDER_SPECIFIC_PASS_THROUGH_ROUTES = [
        "vertex-ai",
    ]

    # check if the route contains a provider-specific pass-through prefix, e.g. /vertex-ai/{endpoint}
    for prefix in PROVIDER_SPECIFIC_PASS_THROUGH_ROUTES:
        if prefix in route:
            return True

    return False


def should_run_auth_on_pass_through_provider_route(route: str) -> bool:
    """
    Use this to decide if the rest of the LiteLLM Virtual Key auth checks should run on
    provider pass-through routes, e.g. /vertex-ai/{endpoint} routes.

    Virtual key auth is skipped when either of the following is true:
    - User is not a premium_user
    - User has enabled general_settings.use_client_credentials_pass_through_routes
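
    Illustrative config that opts these routes into client credentials (key read from
    general_settings below):

    ```yaml
    general_settings:
      use_client_credentials_pass_through_routes: true
    ```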
    """
    from litellm.proxy.proxy_server import general_settings, premium_user

    if premium_user is not True:
        return False

    if (
        general_settings.get("use_client_credentials_pass_through_routes", False)
        is True
    ):
        return False

    return True


def _has_user_setup_sso():
    """
    Check if the user has set up single sign-on (SSO) by verifying the presence of a
    Microsoft client ID, Google client ID, or generic client ID environment variable.
    Returns a boolean indicating whether SSO has been set up.
    """
    microsoft_client_id = os.getenv("MICROSOFT_CLIENT_ID", None)
    google_client_id = os.getenv("GOOGLE_CLIENT_ID", None)
    generic_client_id = os.getenv("GENERIC_CLIENT_ID", None)

    sso_setup = (
        (microsoft_client_id is not None)
        or (google_client_id is not None)
        or (generic_client_id is not None)
    )

    return sso_setup


def get_end_user_id_from_request_body(request_body: dict) -> Optional[str]:
    """
    Get the end-user id from the request body, checking in order: 'user', 'litellm_metadata.user', then 'metadata.user_id'.
    """
    if "user" in request_body and request_body["user"] is not None:
        return str(request_body["user"])

    # fallbacks: litellm_metadata.user, then metadata.user_id
    end_user_id = request_body.get("litellm_metadata", {}).get("user", None)
    if end_user_id:
        return str(end_user_id)
    metadata = request_body.get("metadata")
    if metadata and "user_id" in metadata and metadata["user_id"] is not None:
        return str(metadata["user_id"])
    return None