import os
import re
import sys
from typing import Any, Dict, List, Optional, Tuple

from fastapi import HTTPException, Request, status
from litellm import Router, provider_list
from litellm._logging import verbose_proxy_logger
from litellm.proxy._types import *
from litellm.types.router import CONFIGURABLE_CLIENTSIDE_AUTH_PARAMS


def _get_request_ip_address(
    request: Request, use_x_forwarded_for: Optional[bool] = False
) -> Optional[str]:
    client_ip = None
    if use_x_forwarded_for is True and "x-forwarded-for" in request.headers:
        client_ip = request.headers["x-forwarded-for"]
    elif request.client is not None:
        client_ip = request.client.host
    else:
        client_ip = ""

    return client_ip


def _check_valid_ip(
    allowed_ips: Optional[List[str]],
    request: Request,
    use_x_forwarded_for: Optional[bool] = False,
) -> Tuple[bool, Optional[str]]:
    """
    Returns whether the client IP is allowed, along with the resolved IP address.
    """
    if allowed_ips is None:  # if not set, assume true
        return True, None

    # if general_settings.get("use_x_forwarded_for") is True then use x-forwarded-for
    client_ip = _get_request_ip_address(
        request=request, use_x_forwarded_for=use_x_forwarded_for
    )

    # Check if IP address is allowed
    if client_ip not in allowed_ips:
        return False, client_ip

    return True, client_ip
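
# Illustrative usage (a sketch; `mock_request` is a hypothetical Request whose
# client host resolves to "10.0.0.2"):
#
#   is_allowed, ip = _check_valid_ip(
#       allowed_ips=["10.0.0.1", "10.0.0.2"], request=mock_request
#   )  # -> (True, "10.0.0.2")
#   # with allowed_ips=["10.0.0.1"] the same request yields (False, "10.0.0.2")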


def check_complete_credentials(request_body: dict) -> bool:
    """
    If 'api_base' is in the request body, check that a complete set of credentials
    was given. Prevents malicious requests that only override the api_base.
    """
    given_model: Optional[str] = request_body.get("model")
    if given_model is None:
        return False

    if (
        "sagemaker" in given_model
        or "bedrock" in given_model
        or "vertex_ai" in given_model
        or "vertex_ai_beta" in given_model
    ):
        # complex credentials - easier to make a malicious request
        return False

    if "api_key" in request_body:
        return True

    return False
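
# Illustrative behavior (a sketch):
#   check_complete_credentials({"model": "gpt-4", "api_base": "https://x", "api_key": "sk-.."})  # True
#   check_complete_credentials({"model": "bedrock/claude", "api_base": "https://x", "api_key": "k"})  # False - complex creds
#   check_complete_credentials({"model": "gpt-4", "api_base": "https://x"})  # False - no api_key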


def check_regex_or_str_match(request_body_value: Any, regex_str: str) -> bool:
    """
    Check if request_body_value matches regex_str or is exactly equal to regex_str.
    """
    if re.match(regex_str, request_body_value) or regex_str == request_body_value:
        return True
    return False
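
# Illustrative behavior (a sketch):
#   check_regex_or_str_match("https://my.openai.azure.com", r"https://.*\.openai\.azure\.com")  # True (regex match)
#   check_regex_or_str_match("http://localhost:8000", "http://localhost:8000")  # True (exact match)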


def _is_param_allowed(
    param: str,
    request_body_value: Any,
    configurable_clientside_auth_params: CONFIGURABLE_CLIENTSIDE_AUTH_PARAMS,
) -> bool:
    """
    Check each allowed item: a plain string must equal the param name; a dict is
    treated as {"api_base": <regex or literal>} and matched against the request value.
    """
    if configurable_clientside_auth_params is None:
        return False

    for item in configurable_clientside_auth_params:
        if isinstance(item, str) and param == item:
            return True
        elif isinstance(item, Dict):
            if param == "api_base" and check_regex_or_str_match(
                request_body_value=request_body_value,
                regex_str=item["api_base"],
            ):  # assume item["api_base"] is a regex (or literal) to match against
                return True

    return False
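
# Illustrative behavior (a sketch; the api_base value is hypothetical):
#   _is_param_allowed(
#       param="api_base",
#       request_body_value="https://eu.api.example.com",
#       configurable_clientside_auth_params=[{"api_base": r"https://.*\.example\.com"}],
#   )  # -> True; a plain string entry like ["api_key"] would instead allow that param by name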


def _allow_model_level_clientside_configurable_parameters(
    model: str, param: str, request_body_value: Any, llm_router: Optional[Router]
) -> bool:
    """
    Check if model is allowed to use configurable client-side params
    - get matching model
    - check if 'clientside_configurable_parameters' is set for model
    - check if the param (and its value) is in the allowed list
    """
    if llm_router is None:
        return False
    # check if model is set
    model_info = llm_router.get_model_group_info(model_group=model)
    if model_info is None:
        # check if wildcard model is set
        if model.split("/", 1)[0] in provider_list:
            model_info = llm_router.get_model_group_info(
                model_group=model.split("/", 1)[0]
            )

    if model_info is None or model_info.configurable_clientside_auth_params is None:
        return False

    return _is_param_allowed(
        param=param,
        request_body_value=request_body_value,
        configurable_clientside_auth_params=model_info.configurable_clientside_auth_params,
    )


def is_request_body_safe(
    request_body: dict, general_settings: dict, llm_router: Optional[Router], model: str
) -> bool:
    """
    Check if the request body is safe.

    A malicious user can set the api_base to their own domain and invoke POST /chat/completions to intercept and steal the OpenAI API key.
    Relevant issue: https://huntr.com/bounties/4001e1a2-7b7a-4776-a3ae-e6692ec3d997
    """
    banned_params = ["api_base", "base_url"]

    for param in banned_params:
        if (
            param in request_body
            and not check_complete_credentials(  # allow client-credentials to be passed to proxy
                request_body=request_body
            )
        ):
            if general_settings.get("allow_client_side_credentials") is True:
                return True
            elif (
                _allow_model_level_clientside_configurable_parameters(
                    model=model,
                    param=param,
                    request_body_value=request_body[param],
                    llm_router=llm_router,
                )
                is True
            ):
                return True
            raise ValueError(
                f"Rejected Request: {param} is not allowed in request body. "
                "Enable with `general_settings::allow_client_side_credentials` on proxy config.yaml. "
                "Relevant Issue: https://huntr.com/bounties/4001e1a2-7b7a-4776-a3ae-e6692ec3d997",
            )

    return True
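
# To opt in globally, the proxy config.yaml would look roughly like this
# (a sketch of the setting this function reads, not a full config):
#
#   general_settings:
#     allow_client_side_credentials: true
#
# Without it, a body like {"model": "gpt-4", "api_base": "https://attacker.example"}
# raises ValueError unless complete credentials or a model-level allowlist permit it.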


async def pre_db_read_auth_checks(
    request: Request,
    request_data: dict,
    route: str,
):
    """
    1. Checks if request size is under max_request_size_mb (if set)
    2. Checks if request body is safe (e.g. user has not set api_base in request body)
    3. Checks if IP address is allowed (if set)
    4. Checks if request route is an allowed route on the proxy (if set)

    Returns:
    - True

    Raises:
    - HTTPException if request fails initial auth checks
    """
    from litellm.proxy.proxy_server import general_settings, llm_router, premium_user

    # Check 1. request size
    await check_if_request_size_is_safe(request=request)

    # Check 2. Request body is safe
    is_request_body_safe(
        request_body=request_data,
        general_settings=general_settings,
        llm_router=llm_router,
        model=request_data.get(
            "model", ""
        ),  # [TODO] use model passed in url as well (azure openai routes)
    )

    # Check 3. Check if IP address is allowed
    is_valid_ip, passed_in_ip = _check_valid_ip(
        allowed_ips=general_settings.get("allowed_ips", None),
        use_x_forwarded_for=general_settings.get("use_x_forwarded_for", False),
        request=request,
    )

    if not is_valid_ip:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail=f"Access forbidden: IP address {passed_in_ip} not allowed.",
        )

    # Check 4. Check if request route is an allowed route on the proxy
    if "allowed_routes" in general_settings:
        _allowed_routes = general_settings["allowed_routes"]
        if premium_user is not True:
            verbose_proxy_logger.error(
                f"Trying to set allowed_routes. This is an Enterprise feature. {CommonProxyErrors.not_premium_user.value}"
            )
        if route not in _allowed_routes:
            verbose_proxy_logger.error(
                f"Route {route} not in allowed_routes={_allowed_routes}"
            )
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail=f"Access forbidden: Route {route} not allowed",
            )


def route_in_additonal_public_routes(current_route: str):
    """
    Helper to check if the current route is in the user-defined public_routes on config.yaml

    Parameters:
    - current_route: str - the route the user is trying to call

    Returns:
    - bool - True if the route is defined in public_routes
    - bool - False if the route is not defined in public_routes

    To use this, the litellm config.yaml should have the following in general_settings:

    ```yaml
    general_settings:
      master_key: sk-1234
      public_routes: ["LiteLLMRoutes.public_routes", "/spend/calculate"]
    ```
    """
    # check if user is premium_user - if not do nothing
    from litellm.proxy.proxy_server import general_settings, premium_user

    try:
        if premium_user is not True:
            return False
        # check if this is defined on the config
        if general_settings is None:
            return False

        routes_defined = general_settings.get("public_routes", [])
        if current_route in routes_defined:
            return True

        return False
    except Exception as e:
        verbose_proxy_logger.error(f"route_in_additonal_public_routes: {str(e)}")
        return False


def get_request_route(request: Request) -> str:
    """
    Helper to get the route from the request

    Removes the base url from the path if set, e.g. `/genai/chat/completions` -> `/chat/completions`
    """
    try:
        if hasattr(request, "base_url") and request.url.path.startswith(
            request.base_url.path
        ):
            # remove base_url from path
            return request.url.path[len(request.base_url.path) - 1 :]
        else:
            return request.url.path
    except Exception as e:
        verbose_proxy_logger.debug(
            f"error on get_request_route: {str(e)}, defaulting to request.url.path={request.url.path}"
        )
        return request.url.path
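
# Illustrative behavior (a sketch, assuming the proxy is mounted under a `/genai`
# prefix so request.base_url.path == "/genai/"):
#   request.url.path == "/genai/chat/completions"
#   get_request_route(request)  # -> "/chat/completions"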


async def check_if_request_size_is_safe(request: Request) -> bool:
    """
    Enterprise Only:
    - Checks if the request size is within the limit

    Args:
        request (Request): The incoming request.

    Returns:
        bool: True if the request size is within the limit

    Raises:
        ProxyException: If the request size is too large
    """
    from litellm.proxy.proxy_server import general_settings, premium_user

    max_request_size_mb = general_settings.get("max_request_size_mb", None)
    if max_request_size_mb is not None:
        # Check if premium user
        if premium_user is not True:
            verbose_proxy_logger.warning(
                f"using max_request_size_mb - not checking - this is an enterprise only feature. {CommonProxyErrors.not_premium_user.value}"
            )
            return True

        # Prefer the Content-Length header; only read the body if it is missing
        content_length = request.headers.get("content-length")

        if content_length:
            header_size = int(content_length)
            header_size_mb = bytes_to_mb(bytes_value=header_size)
            verbose_proxy_logger.debug(
                f"content_length request size in MB={header_size_mb}"
            )

            if header_size_mb > max_request_size_mb:
                raise ProxyException(
                    message=f"Request size is too large. Request size is {header_size_mb} MB. Max size is {max_request_size_mb} MB",
                    type=ProxyErrorTypes.bad_request_error.value,
                    code=400,
                    param="content-length",
                )
        else:
            # If Content-Length is not available, read the body
            body = await request.body()
            body_size = len(body)
            request_size_mb = bytes_to_mb(bytes_value=body_size)

            verbose_proxy_logger.debug(
                f"request body request size in MB={request_size_mb}"
            )
            if request_size_mb > max_request_size_mb:
                raise ProxyException(
                    message=f"Request size is too large. Request size is {request_size_mb} MB. Max size is {max_request_size_mb} MB",
                    type=ProxyErrorTypes.bad_request_error.value,
                    code=400,
                    param="content-length",
                )

    return True


async def check_response_size_is_safe(response: Any) -> bool:
    """
    Enterprise Only:
    - Checks if the response size is within the limit

    Args:
        response (Any): The response to check.

    Returns:
        bool: True if the response size is within the limit

    Raises:
        ProxyException: If the response size is too large
    """
    from litellm.proxy.proxy_server import general_settings, premium_user

    max_response_size_mb = general_settings.get("max_response_size_mb", None)
    if max_response_size_mb is not None:
        # Check if premium user
        if premium_user is not True:
            verbose_proxy_logger.warning(
                f"using max_response_size_mb - not checking - this is an enterprise only feature. {CommonProxyErrors.not_premium_user.value}"
            )
            return True

        # sys.getsizeof measures the shallow in-memory size of the object,
        # used here as an approximation of the serialized response size
        response_size_mb = bytes_to_mb(bytes_value=sys.getsizeof(response))
        verbose_proxy_logger.debug(f"response size in MB={response_size_mb}")
        if response_size_mb > max_response_size_mb:
            raise ProxyException(
                message=f"Response size is too large. Response size is {response_size_mb} MB. Max size is {max_response_size_mb} MB",
                type=ProxyErrorTypes.bad_request_error.value,
                code=400,
                param="content-length",
            )

    return True


def bytes_to_mb(bytes_value: int) -> float:
    """
    Helper to convert bytes to MB
    """
    return bytes_value / (1024 * 1024)
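
# Illustrative behavior (a sketch):
#   bytes_to_mb(1024 * 1024)      # -> 1.0
#   bytes_to_mb(5 * 1024 * 1024)  # -> 5.0, i.e. a 5 MB request body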


# helpers used by parallel request limiter to handle model rpm/tpm limits for a given api key
def get_key_model_rpm_limit(
    user_api_key_dict: UserAPIKeyAuth,
) -> Optional[Dict[str, int]]:
    if user_api_key_dict.metadata:
        if "model_rpm_limit" in user_api_key_dict.metadata:
            return user_api_key_dict.metadata["model_rpm_limit"]
    elif user_api_key_dict.model_max_budget:
        model_rpm_limit: Dict[str, Any] = {}
        for model, budget in user_api_key_dict.model_max_budget.items():
            if "rpm_limit" in budget and budget["rpm_limit"] is not None:
                model_rpm_limit[model] = budget["rpm_limit"]
        return model_rpm_limit

    return None
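
# Illustrative behavior (a sketch; the field values are hypothetical):
#   key = UserAPIKeyAuth(metadata={"model_rpm_limit": {"gpt-4": 10}})
#   get_key_model_rpm_limit(key)  # -> {"gpt-4": 10}
#   key = UserAPIKeyAuth(model_max_budget={"gpt-4": {"rpm_limit": 5}})
#   get_key_model_rpm_limit(key)  # -> {"gpt-4": 5}, pulled from per-model budgets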


def get_key_model_tpm_limit(
    user_api_key_dict: UserAPIKeyAuth,
) -> Optional[Dict[str, int]]:
    if user_api_key_dict.metadata:
        if "model_tpm_limit" in user_api_key_dict.metadata:
            return user_api_key_dict.metadata["model_tpm_limit"]
    elif user_api_key_dict.model_max_budget:
        if "tpm_limit" in user_api_key_dict.model_max_budget:
            return user_api_key_dict.model_max_budget["tpm_limit"]

    return None


def is_pass_through_provider_route(route: str) -> bool:
    PROVIDER_SPECIFIC_PASS_THROUGH_ROUTES = [
        "vertex-ai",
    ]

    # check if any of the prefixes are in the route
    for prefix in PROVIDER_SPECIFIC_PASS_THROUGH_ROUTES:
        if prefix in route:
            return True

    return False
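
# Illustrative behavior (a sketch):
#   is_pass_through_provider_route("/vertex-ai/predict")  # -> True
#   is_pass_through_provider_route("/chat/completions")   # -> False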


def should_run_auth_on_pass_through_provider_route(route: str) -> bool:
    """
    Use this to decide if the rest of the LiteLLM Virtual Key auth checks should run on provider pass-through routes,
    ex. /vertex-ai/{endpoint} routes

    Run virtual key auth if all of the following are true:
    - User is a premium_user (only enabled for LiteLLM Enterprise)
    - User has NOT enabled use_client_credentials_pass_through_routes (read from general_settings)
    """
    from litellm.proxy.proxy_server import general_settings, premium_user

    if premium_user is not True:
        return False

    # premium user has opted into using client credentials - skip virtual key auth
    if (
        general_settings.get("use_client_credentials_pass_through_routes", False)
        is True
    ):
        return False

    # only enabled for LiteLLM Enterprise
    return True


def _has_user_setup_sso():
    """
    Check if the user has set up single sign-on (SSO) by verifying the presence of
    a Microsoft, Google, or generic client ID environment variable.
    Returns a boolean indicating whether SSO has been set up.
    """
    microsoft_client_id = os.getenv("MICROSOFT_CLIENT_ID", None)
    google_client_id = os.getenv("GOOGLE_CLIENT_ID", None)
    generic_client_id = os.getenv("GENERIC_CLIENT_ID", None)

    sso_setup = (
        (microsoft_client_id is not None)
        or (google_client_id is not None)
        or (generic_client_id is not None)
    )

    return sso_setup


def get_end_user_id_from_request_body(request_body: dict) -> Optional[str]:
    # openai - check 'user'
    if "user" in request_body and request_body["user"] is not None:
        return str(request_body["user"])

    # anthropic - check 'litellm_metadata'
    end_user_id = request_body.get("litellm_metadata", {}).get("user", None)
    if end_user_id:
        return str(end_user_id)

    metadata = request_body.get("metadata")
    if metadata and "user_id" in metadata and metadata["user_id"] is not None:
        return str(metadata["user_id"])

    return None
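
# Illustrative precedence (a sketch): 'user' wins over 'litellm_metadata.user',
# which wins over 'metadata.user_id':
#   get_end_user_id_from_request_body({"user": "u1", "metadata": {"user_id": "u2"}})  # -> "u1"
#   get_end_user_id_from_request_body({"litellm_metadata": {"user": "u3"}})           # -> "u3"
#   get_end_user_id_from_request_body({})                                             # -> None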