|
""" |
|
Supports using JWTs for authenticating into the proxy.
|
|
|
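Supports proxy-admin access (via the configured admin scope) as well as team, internal-user, organization and end-user identities resolved from token claims.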
|
|
|
For proxy-admin access, the JWT must include the configured admin scope (e.g. 'litellm_proxy_admin') in its `scope` claim.
|
""" |
|
|
|
import json |
|
import os |
|
from typing import Any, List, Literal, Optional, Set, Tuple, cast |
|
|
|
from cryptography import x509 |
|
from cryptography.hazmat.backends import default_backend |
|
from cryptography.hazmat.primitives import serialization |
|
from fastapi import HTTPException |
|
|
|
from litellm._logging import verbose_proxy_logger |
|
from litellm.caching.caching import DualCache |
|
from litellm.litellm_core_utils.dot_notation_indexing import get_nested_value |
|
from litellm.llms.custom_httpx.httpx_handler import HTTPHandler |
|
from litellm.proxy._types import ( |
|
RBAC_ROLES, |
|
JWKKeyValue, |
|
JWTAuthBuilderResult, |
|
JWTKeyItem, |
|
LiteLLM_EndUserTable, |
|
LiteLLM_JWTAuth, |
|
LiteLLM_OrganizationTable, |
|
LiteLLM_TeamTable, |
|
LiteLLM_UserTable, |
|
LitellmUserRoles, |
|
Span, |
|
) |
|
from litellm.proxy.utils import PrismaClient, ProxyLogging |
|
|
|
from .auth_checks import ( |
|
_allowed_routes_check, |
|
allowed_routes_check, |
|
get_actual_routes, |
|
get_end_user_object, |
|
get_org_object, |
|
get_role_based_models, |
|
get_role_based_routes, |
|
get_team_object, |
|
get_user_object, |
|
) |
|
|
|
|
|
class JWTHandler:
    """
    Verifies JWTs presented to the proxy and extracts identity claims from them.

    Which claims are read (user id, team id(s), org id, end-user id, e-mail,
    roles, scopes, object id) is controlled by the `LiteLLM_JWTAuth` settings
    passed to `update_environment()`. That method must be called before any of
    the claim accessors below are used.

    Original design notes:
    - treat the sub id passed in as the user id
    - return an error if id making request doesn't exist in proxy user table
    - track spend against the user id
    - if role="litellm_proxy_user" -> allow making calls + info. Can not edit budgets
    """

    prisma_client: Optional[PrismaClient]
    user_api_key_cache: DualCache
    # Populated by update_environment(); every claim accessor reads it.
    litellm_jwtauth: LiteLLM_JWTAuth

    def __init__(
        self,
    ) -> None:
        # HTTP client used by get_public_key() to fetch the signing keys.
        self.http_handler = HTTPHandler()
        # Clock-skew tolerance forwarded to jwt.decode(); 0 means no tolerance.
        self.leeway = 0

    def update_environment(
        self,
        prisma_client: Optional[PrismaClient],
        user_api_key_cache: DualCache,
        litellm_jwtauth: LiteLLM_JWTAuth,
        leeway: int = 0,
    ) -> None:
        """
        Inject the proxy's runtime dependencies and JWT-auth configuration.

        Args:
            prisma_client: optional DB client.
            user_api_key_cache: cache used (among others) to store the fetched
                public keys under "litellm_jwt_auth_keys".
            litellm_jwtauth: settings naming the JWT claims to read.
            leeway: clock-skew tolerance forwarded to jwt.decode().
        """
        self.prisma_client = prisma_client
        self.user_api_key_cache = user_api_key_cache
        self.litellm_jwtauth = litellm_jwtauth
        self.leeway = leeway

    def is_jwt(self, token: str):
        """
        Cheap structural check: True if `token` has three dot-separated
        segments (header.payload.signature). Verifies nothing.
        """
        parts = token.split(".")
        return len(parts) == 3

    def _rbac_role_from_role_mapping(self, token: dict) -> Optional[RBAC_ROLES]:
        """
        Returns the RBAC role the token 'belongs' to based on role mappings.

        Args:
            token (dict): The JWT token containing role information

        Returns:
            Optional[RBAC_ROLES]: The mapped internal RBAC role if a mapping exists,
            None otherwise

        Note:
            The function handles both single string roles and lists of roles from the JWT.
            If multiple mappings match the JWT roles, the first matching mapping is returned.
        """
        if self.litellm_jwtauth.role_mappings is None:
            return None

        jwt_role = self.get_jwt_role(token=token, default_value=None)
        if not jwt_role:
            return None

        # A single string claim is ONE role. set("admin") would instead yield
        # {"a", "d", "m", "i", "n"} and never match a mapping.
        if isinstance(jwt_role, str):
            jwt_role_set = {jwt_role}
        else:
            jwt_role_set = set(jwt_role)

        for role_mapping in self.litellm_jwtauth.role_mappings:
            if role_mapping.role in jwt_role_set:
                return role_mapping.internal_role

        return None

    def get_rbac_role(self, token: dict) -> Optional[RBAC_ROLES]:
        """
        Returns the RBAC role the token 'belongs' to.

        RBAC roles allowed to make requests:
        - PROXY_ADMIN: can make requests to all routes
        - TEAM: can make requests to routes associated with a team
        - INTERNAL_USER: can make requests to routes associated with a user

        Resolves: https://github.com/BerriAI/litellm/issues/6793

        Checks are applied in order: admin scope, team id claim, user id claim,
        allowed user roles, then configured role mappings.

        Returns:
        - PROXY_ADMIN: if token is admin
        - TEAM: if token is associated with a team
        - INTERNAL_USER: if token is associated with a user
        - None: if token is not associated with a team or user
        """
        scopes = self.get_scopes(token=token)
        is_admin = self.is_admin(scopes=scopes)
        user_roles = self.get_user_roles(token=token, default_value=None)

        if is_admin:
            return LitellmUserRoles.PROXY_ADMIN
        elif self.get_team_id(token=token, default_value=None) is not None:
            return LitellmUserRoles.TEAM
        elif self.get_user_id(token=token, default_value=None) is not None:
            return LitellmUserRoles.INTERNAL_USER
        elif user_roles is not None and self.is_allowed_user_role(
            user_roles=user_roles
        ):
            return LitellmUserRoles.INTERNAL_USER
        elif rbac_role := self._rbac_role_from_role_mapping(token=token):
            return rbac_role

        return None

    def is_admin(self, scopes: list) -> bool:
        """True if the configured admin scope is among `scopes`."""
        return self.litellm_jwtauth.admin_jwt_scope in scopes

    def get_team_ids_from_jwt(self, token: dict) -> List[str]:
        """
        Return the team ids listed in the claim named by `team_ids_jwt_field`.

        Returns [] when the field is not configured, or when the claim is
        absent or null in the token. A single string claim is wrapped in a list
        so callers never iterate over its characters.
        """
        field = self.litellm_jwtauth.team_ids_jwt_field
        if field is None:
            return []
        team_ids = token.get(field)
        if team_ids is None:
            return []
        if isinstance(team_ids, str):
            return [team_ids]
        return team_ids

    def get_end_user_id(
        self, token: dict, default_value: Optional[str]
    ) -> Optional[str]:
        """
        Return the claim named by `end_user_id_jwt_field`.

        Returns None when the field is not configured, and `default_value` when
        it is configured but missing from the token.
        """
        try:
            if self.litellm_jwtauth.end_user_id_jwt_field is not None:
                user_id = token[self.litellm_jwtauth.end_user_id_jwt_field]
            else:
                user_id = None
        except KeyError:
            user_id = default_value

        return user_id

    def is_required_team_id(self) -> bool:
        """
        Returns:
        - True: if 'team_id_jwt_field' is set
        - False: if not
        """
        return self.litellm_jwtauth.team_id_jwt_field is not None

    def is_enforced_email_domain(self) -> bool:
        """
        Returns:
        - True: if 'user_allowed_email_domain' is set (to a string)
        - False: if 'user_allowed_email_domain' is None
        """
        return isinstance(self.litellm_jwtauth.user_allowed_email_domain, str)

    def get_team_id(self, token: dict, default_value: Optional[str]) -> Optional[str]:
        """
        Return the team id for the token.

        Uses the claim named by `team_id_jwt_field` if configured, otherwise
        `team_id_default`, otherwise None. `default_value` is returned only when
        the configured claim is missing from the token.
        """
        try:
            if self.litellm_jwtauth.team_id_jwt_field is not None:
                team_id = token[self.litellm_jwtauth.team_id_jwt_field]
            elif self.litellm_jwtauth.team_id_default is not None:
                team_id = self.litellm_jwtauth.team_id_default
            else:
                team_id = None
        except KeyError:
            team_id = default_value
        return team_id

    def is_upsert_user_id(self, valid_user_email: Optional[bool] = None) -> bool:
        """
        Returns:
        - True: if 'user_id_upsert' is set AND valid_user_email is not False
        - False: if not
        """
        if valid_user_email is False:
            return False
        return self.litellm_jwtauth.user_id_upsert

    def get_user_id(self, token: dict, default_value: Optional[str]) -> Optional[str]:
        """
        Return the claim named by `user_id_jwt_field`.

        Returns `default_value` when the field is not configured or the claim is
        missing.
        """
        try:
            if self.litellm_jwtauth.user_id_jwt_field is not None:
                user_id = token[self.litellm_jwtauth.user_id_jwt_field]
            else:
                user_id = default_value
        except KeyError:
            user_id = default_value
        return user_id

    def get_user_roles(
        self, token: dict, default_value: Optional[List[str]]
    ) -> Optional[List[str]]:
        """
        Returns the user role from the token.

        Set via 'user_roles_jwt_field' in the config. The field is resolved with
        `get_nested_value`, presumably as a dotted path into nested claims.
        """
        try:
            if self.litellm_jwtauth.user_roles_jwt_field is not None:
                user_roles = get_nested_value(
                    data=token,
                    key_path=self.litellm_jwtauth.user_roles_jwt_field,
                    default=default_value,
                )
            else:
                user_roles = default_value
        except KeyError:
            user_roles = default_value
        return user_roles

    def get_jwt_role(
        self, token: dict, default_value: Optional[List[str]]
    ) -> Optional[List[str]]:
        """
        Generic implementation of `get_user_roles` that can be used for both user and team roles.

        Returns the jwt role from the token. The claim may be a single string
        rather than a list.

        Set via 'roles_jwt_field' in the config.
        """
        try:
            if self.litellm_jwtauth.roles_jwt_field is not None:
                user_roles = get_nested_value(
                    data=token,
                    key_path=self.litellm_jwtauth.roles_jwt_field,
                    default=default_value,
                )
            else:
                user_roles = default_value
        except KeyError:
            user_roles = default_value
        return user_roles

    def is_allowed_user_role(self, user_roles: Optional[List[str]]) -> bool:
        """
        True if any of `user_roles` is listed in 'user_allowed_roles'.

        Returns False when either `user_roles` or 'user_allowed_roles' is None.
        """
        if (
            user_roles is not None
            and self.litellm_jwtauth.user_allowed_roles is not None
            and any(
                role in self.litellm_jwtauth.user_allowed_roles for role in user_roles
            )
        ):
            return True
        return False

    def get_user_email(
        self, token: dict, default_value: Optional[str]
    ) -> Optional[str]:
        """
        Return the claim named by `user_email_jwt_field`.

        Returns None when the field is not configured, and `default_value` when
        the claim is missing.
        """
        try:
            if self.litellm_jwtauth.user_email_jwt_field is not None:
                user_email = token[self.litellm_jwtauth.user_email_jwt_field]
            else:
                user_email = None
        except KeyError:
            user_email = default_value
        return user_email

    def get_object_id(self, token: dict, default_value: Optional[str]) -> Optional[str]:
        """
        Return the claim named by `object_id_jwt_field`.

        Returns `default_value` when the field is not configured or the claim is
        missing.
        """
        try:
            if self.litellm_jwtauth.object_id_jwt_field is not None:
                object_id = token[self.litellm_jwtauth.object_id_jwt_field]
            else:
                object_id = default_value
        except KeyError:
            object_id = default_value
        return object_id

    def get_org_id(self, token: dict, default_value: Optional[str]) -> Optional[str]:
        """
        Return the claim named by `org_id_jwt_field`.

        Returns None when the field is not configured, and `default_value` when
        the claim is missing.
        """
        try:
            if self.litellm_jwtauth.org_id_jwt_field is not None:
                org_id = token[self.litellm_jwtauth.org_id_jwt_field]
            else:
                org_id = None
        except KeyError:
            org_id = default_value
        return org_id

    def get_scopes(self, token: dict) -> list:
        """
        Return the token's `scope` claim as a list.

        A space-separated string is split and a list is returned as-is. A
        missing claim yields [].

        Raises:
            Exception: if `scope` has any other type.
        """
        try:
            if isinstance(token["scope"], str):
                # Space-separated scope string, e.g. "a b c".
                scopes = token["scope"].split()
            elif isinstance(token["scope"], list):
                scopes = token["scope"]
            else:
                raise Exception(
                    f"Unmapped scope type - {type(token['scope'])}. Supported types - list, str."
                )
        except KeyError:
            scopes = []
        return scopes

    async def get_public_key(self, kid: Optional[str]) -> dict:
        """
        Return the public key whose key id matches `kid`.

        The key set is fetched from the URL in the JWT_PUBLIC_KEY_URL env var.
        It is cached in `user_api_key_cache` under "litellm_jwt_auth_keys" for
        `litellm_jwtauth.public_key_ttl`. A response containing a "keys" entry
        is unwrapped; any other response is used as the key material directly.

        Raises:
            Exception: if JWT_PUBLIC_KEY_URL is unset or no key matches `kid`.
        """
        keys_url = os.getenv("JWT_PUBLIC_KEY_URL")

        if keys_url is None:
            raise Exception("Missing JWT Public Key URL from environment.")

        cached_keys = await self.user_api_key_cache.async_get_cache(
            "litellm_jwt_auth_keys"
        )
        if cached_keys is None:
            response = await self.http_handler.get(keys_url)

            # Parse the body once and reuse it.
            response_json = response.json()
            if "keys" in response_json:
                keys: JWKKeyValue = response_json["keys"]
            else:
                keys = response_json

            await self.user_api_key_cache.async_set_cache(
                key="litellm_jwt_auth_keys",
                value=keys,
                ttl=self.litellm_jwtauth.public_key_ttl,
            )
        else:
            keys = cached_keys

        public_key = self.parse_keys(keys=keys, kid=kid)
        if public_key is None:
            raise Exception(
                f"No matching public key found. kid={kid}, keys_url={keys_url}, cached_keys={cached_keys}, len(keys)={len(keys)}"
            )
        return cast(dict, public_key)

    def parse_keys(self, keys: JWKKeyValue, kid: Optional[str]) -> Optional[JWTKeyItem]:
        """
        Select the key matching `kid` from `keys`.

        - One key (a dict, or a one-element list): it is returned if its "kid"
          equals `kid`, or if `kid` is None.
        - Several keys: the key whose "kid" equals `kid` is returned (the last
          one, if duplicated). None is returned when `kid` is None.

        Returns None when nothing matches.
        """
        public_key: Optional[JWTKeyItem] = None
        if len(keys) == 1:
            if isinstance(keys, dict) and (keys.get("kid", None) == kid or kid is None):
                public_key = keys
            elif isinstance(keys, list) and (
                keys[0].get("kid", None) == kid or kid is None
            ):
                public_key = keys[0]
        elif len(keys) > 1:
            for key in keys:
                if isinstance(key, dict):
                    key_kid = key.get("kid", None)
                else:
                    key_kid = None
                if (
                    kid is not None
                    and isinstance(key, dict)
                    and key_kid is not None
                    and key_kid == kid
                ):
                    public_key = key

        return public_key

    def is_allowed_domain(self, user_email: str) -> bool:
        """
        True if the part after the last "@" in `user_email` equals the
        configured `user_allowed_email_domain`. Always True when no domain is
        configured.
        """
        if self.litellm_jwtauth.user_allowed_email_domain is None:
            return True

        email_domain = user_email.split("@")[-1]
        return email_domain == self.litellm_jwtauth.user_allowed_email_domain

    async def auth_jwt(self, token: str) -> dict:
        """
        Verify `token`'s signature and claims and return its decoded payload.

        The signing key is looked up via `get_public_key()` using the header's
        `kid`. A dict key is converted with `RSAAlgorithm.from_jwk`; a string
        key is treated as a PEM-encoded X.509 certificate. The audience is
        checked against JWT_AUDIENCE when that env var is set and skipped
        otherwise. `self.leeway` is applied on both key paths.

        Raises:
            Exception: "Token Expired", "Validation fails: ...", or
                "Invalid JWT Submitted" when no usable key was obtained.
        """
        # Explicit allow-list of RSA-based algorithms; the token header's `alg`
        # alone is never trusted.
        algorithms = ["RS256", "RS384", "RS512", "PS256", "PS384", "PS512"]

        audience = os.getenv("JWT_AUDIENCE")
        decode_options = None
        if audience is None:
            decode_options = {"verify_aud": False}

        import jwt
        from jwt.algorithms import RSAAlgorithm

        header = jwt.get_unverified_header(token)

        verbose_proxy_logger.debug("header: %s", header)

        kid = header.get("kid", None)

        # NOTE: get_public_key() casts to dict, but the key material may be a
        # PEM string depending on the key source; both shapes are handled below.
        public_key = await self.get_public_key(kid=kid)

        if public_key is not None and isinstance(public_key, dict):
            # Forward only the RSA JWK members that from_jwk needs.
            jwk = {}
            if "kty" in public_key:
                jwk["kty"] = public_key["kty"]
            if "kid" in public_key:
                jwk["kid"] = public_key["kid"]
            if "n" in public_key:
                jwk["n"] = public_key["n"]
            if "e" in public_key:
                jwk["e"] = public_key["e"]

            public_key_rsa = RSAAlgorithm.from_jwk(json.dumps(jwk))

            try:
                payload = jwt.decode(
                    token,
                    public_key_rsa,
                    algorithms=algorithms,
                    options=decode_options,
                    audience=audience,
                    leeway=self.leeway,
                )
                return payload

            except jwt.ExpiredSignatureError as e:
                raise Exception("Token Expired") from e
            except Exception as e:
                raise Exception(f"Validation fails: {str(e)}") from e
        elif public_key is not None and isinstance(public_key, str):
            try:
                cert = x509.load_pem_x509_certificate(
                    public_key.encode(), default_backend()
                )

                # Extract the certificate's public key in PEM SubjectPublicKeyInfo form.
                key = cert.public_key().public_bytes(
                    serialization.Encoding.PEM,
                    serialization.PublicFormat.SubjectPublicKeyInfo,
                )

                # Apply the same clock-skew leeway as the JWK path above.
                payload = jwt.decode(
                    token,
                    key,
                    algorithms=algorithms,
                    audience=audience,
                    options=decode_options,
                    leeway=self.leeway,
                )
                return payload

            except jwt.ExpiredSignatureError as e:
                raise Exception("Token Expired") from e
            except Exception as e:
                raise Exception(f"Validation fails: {str(e)}") from e

        raise Exception("Invalid JWT Submitted")

    async def close(self):
        """Close the underlying HTTP client."""
        await self.http_handler.close()
|
|
|
|
|
class JWTAuthManager:
    """Manages JWT authentication and authorization operations"""

    @staticmethod
    def can_rbac_role_call_route(
        rbac_role: RBAC_ROLES,
        general_settings: dict,
        route: str,
    ) -> Literal[True]:
        """
        Checks if user is allowed to access the route, based on their role.

        Returns True when no role-based routes are configured for `rbac_role`.

        Raises:
            HTTPException(403): if the route is not in the role's allowed routes.
        """
        role_based_routes = get_role_based_routes(
            rbac_role=rbac_role, general_settings=general_settings
        )

        if role_based_routes is None or route is None:
            return True

        is_allowed = _allowed_routes_check(
            user_route=route,
            allowed_routes=role_based_routes,
        )

        if not is_allowed:
            raise HTTPException(
                status_code=403,
                detail=f"Role={rbac_role} not allowed to call route={route}. Allowed routes={role_based_routes}",
            )

        return True

    @staticmethod
    def can_rbac_role_call_model(
        rbac_role: RBAC_ROLES,
        general_settings: dict,
        model: Optional[str],
    ) -> Literal[True]:
        """
        Checks if user is allowed to access the model, based on their role.

        Returns True when no role-based models are configured or no model was
        requested.

        Raises:
            HTTPException(403): if `model` is not in the role's allowed models.
        """
        role_based_models = get_role_based_models(
            rbac_role=rbac_role, general_settings=general_settings
        )
        if role_based_models is None or model is None:
            return True

        if model not in role_based_models:
            raise HTTPException(
                status_code=403,
                detail=f"Role={rbac_role} not allowed to call model={model}. Allowed models={role_based_models}",
            )

        return True

    @staticmethod
    async def check_rbac_role(
        jwt_handler: JWTHandler,
        jwt_valid_token: dict,
        general_settings: dict,
        request_data: dict,
        route: str,
        rbac_role: Optional[RBAC_ROLES],
    ) -> None:
        """
        Validate RBAC role and model access permissions.

        Only active when `litellm_jwtauth.enforce_rbac` is True. In that case a
        token without a resolvable role is rejected, and the role's model and
        route restrictions are applied.

        Raises:
            HTTPException(403): on an unmatched role or a disallowed model/route.
        """
        if jwt_handler.litellm_jwtauth.enforce_rbac is True:
            if rbac_role is None:
                raise HTTPException(
                    status_code=403,
                    detail="Unmatched token passed in. enforce_rbac is set to True. Token must belong to a proxy admin, team, or user.",
                )
            JWTAuthManager.can_rbac_role_call_model(
                rbac_role=rbac_role,
                general_settings=general_settings,
                model=request_data.get("model"),
            )
            JWTAuthManager.can_rbac_role_call_route(
                rbac_role=rbac_role,
                general_settings=general_settings,
                route=route,
            )

    @staticmethod
    async def check_admin_access(
        jwt_handler: JWTHandler,
        scopes: list,
        route: str,
        user_id: Optional[str],
        org_id: Optional[str],
        api_key: str,
    ) -> Optional[JWTAuthBuilderResult]:
        """
        Check admin status and route access permissions.

        Returns:
            None if the scopes do not contain the admin scope. Otherwise, a
            proxy-admin JWTAuthBuilderResult.

        Raises:
            Exception: if the token is admin but the route is not allowed for admins.
        """
        if not jwt_handler.is_admin(scopes=scopes):
            return None

        is_allowed = allowed_routes_check(
            user_role=LitellmUserRoles.PROXY_ADMIN,
            user_route=route,
            litellm_proxy_roles=jwt_handler.litellm_jwtauth,
        )
        if not is_allowed:
            allowed_routes: List[Any] = jwt_handler.litellm_jwtauth.admin_allowed_routes
            actual_routes = get_actual_routes(allowed_routes=allowed_routes)
            raise Exception(
                f"Admin not allowed to access this route. Route={route}, Allowed Routes={actual_routes}"
            )

        return JWTAuthBuilderResult(
            is_proxy_admin=True,
            team_object=None,
            user_object=None,
            end_user_object=None,
            org_object=None,
            token=api_key,
            team_id=None,
            user_id=user_id,
            end_user_id=None,
            org_id=org_id,
        )

    @staticmethod
    async def find_and_validate_specific_team_id(
        jwt_handler: JWTHandler,
        jwt_valid_token: dict,
        prisma_client: Optional[PrismaClient],
        user_api_key_cache: DualCache,
        parent_otel_span: Optional[Span],
        proxy_logging_obj: ProxyLogging,
    ) -> Tuple[Optional[str], Optional[LiteLLM_TeamTable]]:
        """
        Find and validate the single team id carried by the token.

        Raises:
            Exception: if `team_id_jwt_field` is configured but yields no team id.

        Returns:
            (team_id, team_object). Both are None when there is no team id.
        """
        individual_team_id = jwt_handler.get_team_id(
            token=jwt_valid_token, default_value=None
        )

        if not individual_team_id and jwt_handler.is_required_team_id() is True:
            raise Exception(
                f"No team id found in token. Checked team_id field '{jwt_handler.litellm_jwtauth.team_id_jwt_field}'"
            )

        team_object: Optional[LiteLLM_TeamTable] = None
        if individual_team_id:
            team_object = await get_team_object(
                team_id=individual_team_id,
                prisma_client=prisma_client,
                user_api_key_cache=user_api_key_cache,
                parent_otel_span=parent_otel_span,
                proxy_logging_obj=proxy_logging_obj,
            )

        return individual_team_id, team_object

    @staticmethod
    def get_all_team_ids(jwt_handler: JWTHandler, jwt_valid_token: dict) -> Set[str]:
        """
        Return the de-duplicated team ids from the token's `team_ids_jwt_field`
        (group) claim. The single `team_id_jwt_field` claim is NOT included; it
        is handled by find_and_validate_specific_team_id().
        """
        team_ids_from_groups = jwt_handler.get_team_ids_from_jwt(token=jwt_valid_token)

        all_team_ids = set(team_ids_from_groups)

        return all_team_ids

    @staticmethod
    async def find_team_with_model_access(
        team_ids: Set[str],
        requested_model: Optional[str],
        route: str,
        jwt_handler: JWTHandler,
        prisma_client: Optional[PrismaClient],
        user_api_key_cache: DualCache,
        parent_otel_span: Optional[Span],
        proxy_logging_obj: ProxyLogging,
    ) -> Tuple[Optional[str], Optional[LiteLLM_TeamTable]]:
        """
        Find the first team that can call `requested_model` on `route`.

        A team qualifies when its `models` list contains the model (or "*"), or
        when no model was requested, AND team routes allow `route`. Teams are
        checked in sorted id order so the choice is stable across processes.
        Teams that fail to load are skipped; the failure is logged at debug level.

        Raises:
            HTTPException(403): if a model was requested and no team qualifies.

        Returns:
            (team_id, team_object), or (None, None) if nothing qualifies and no
            model was requested.
        """
        if not team_ids:
            return None, None

        # Iterating a set of str directly follows hash-randomised order, which
        # would make "first team" differ between processes.
        for team_id in sorted(team_ids, key=str):
            try:
                team_object = await get_team_object(
                    team_id=team_id,
                    prisma_client=prisma_client,
                    user_api_key_cache=user_api_key_cache,
                    parent_otel_span=parent_otel_span,
                    proxy_logging_obj=proxy_logging_obj,
                )

                if team_object and team_object.models is not None:
                    team_models = team_object.models
                    if isinstance(team_models, list) and (
                        not requested_model
                        or requested_model in team_models
                        or "*" in team_models
                    ):
                        is_allowed = allowed_routes_check(
                            user_role=LitellmUserRoles.TEAM,
                            user_route=route,
                            litellm_proxy_roles=jwt_handler.litellm_jwtauth,
                        )
                        if is_allowed:
                            return team_id, team_object
            except Exception as e:
                # Best-effort lookup: skip this team, but leave a trace so a
                # misconfigured or missing team is diagnosable.
                verbose_proxy_logger.debug(
                    "JWT auth: skipping team_id=%s during model-access lookup: %s",
                    team_id,
                    str(e),
                )
                continue

        if requested_model:
            raise HTTPException(
                status_code=403,
                detail=f"No team has access to the requested model: {requested_model}. Checked teams={team_ids}",
            )

        return None, None

    @staticmethod
    async def get_user_info(
        jwt_handler: JWTHandler,
        jwt_valid_token: dict,
    ) -> Tuple[Optional[str], Optional[str], Optional[bool]]:
        """
        Get user id, user email and email-domain validation status.

        `valid_user_email` is None when no allowed e-mail domain is configured.
        Otherwise it is False if the token has no e-mail, or the result of the
        domain check. The user id falls back to the e-mail.
        """
        user_email = jwt_handler.get_user_email(
            token=jwt_valid_token, default_value=None
        )
        valid_user_email = None
        if jwt_handler.is_enforced_email_domain():
            valid_user_email = (
                False
                if user_email is None
                else jwt_handler.is_allowed_domain(user_email=user_email)
            )
        user_id = jwt_handler.get_user_id(
            token=jwt_valid_token, default_value=user_email
        )
        return user_id, user_email, valid_user_email

    @staticmethod
    async def get_objects(
        user_id: Optional[str],
        user_email: Optional[str],
        org_id: Optional[str],
        end_user_id: Optional[str],
        valid_user_email: Optional[bool],
        jwt_handler: JWTHandler,
        prisma_client: Optional[PrismaClient],
        user_api_key_cache: DualCache,
        parent_otel_span: Optional[Span],
        proxy_logging_obj: ProxyLogging,
    ) -> Tuple[
        Optional[LiteLLM_UserTable],
        Optional[LiteLLM_OrganizationTable],
        Optional[LiteLLM_EndUserTable],
    ]:
        """
        Get user, org, and end user objects.

        Each object is fetched only when its id is truthy and stays None
        otherwise. User upsert is allowed only if configured and the e-mail
        check did not fail.
        """
        org_object: Optional[LiteLLM_OrganizationTable] = None
        if org_id:
            org_object = await get_org_object(
                org_id=org_id,
                prisma_client=prisma_client,
                user_api_key_cache=user_api_key_cache,
                parent_otel_span=parent_otel_span,
                proxy_logging_obj=proxy_logging_obj,
            )

        user_object: Optional[LiteLLM_UserTable] = None
        if user_id:
            user_object = await get_user_object(
                user_id=user_id,
                prisma_client=prisma_client,
                user_api_key_cache=user_api_key_cache,
                user_id_upsert=jwt_handler.is_upsert_user_id(
                    valid_user_email=valid_user_email
                ),
                parent_otel_span=parent_otel_span,
                proxy_logging_obj=proxy_logging_obj,
                user_email=user_email,
                sso_user_id=user_id,
            )

        end_user_object: Optional[LiteLLM_EndUserTable] = None
        if end_user_id:
            end_user_object = await get_end_user_object(
                end_user_id=end_user_id,
                prisma_client=prisma_client,
                user_api_key_cache=user_api_key_cache,
                parent_otel_span=parent_otel_span,
                proxy_logging_obj=proxy_logging_obj,
            )

        return user_object, org_object, end_user_object

    @staticmethod
    def validate_object_id(
        user_id: Optional[str],
        team_id: Optional[str],
        enforce_rbac: bool,
        is_proxy_admin: bool,
    ) -> Literal[True]:
        """If enforce_rbac is true, validate that a valid rbac id is returned for spend tracking"""
        if enforce_rbac and not is_proxy_admin and not user_id and not team_id:
            raise HTTPException(
                status_code=403,
                detail="No user or team id found in token. enforce_rbac is set to True. Token must belong to a proxy admin, team, or user.",
            )
        return True

    @staticmethod
    async def auth_builder(
        api_key: str,
        jwt_handler: JWTHandler,
        request_data: dict,
        general_settings: dict,
        route: str,
        prisma_client: Optional[PrismaClient],
        user_api_key_cache: DualCache,
        parent_otel_span: Optional[Span],
        proxy_logging_obj: ProxyLogging,
    ) -> JWTAuthBuilderResult:
        """
        Main authentication and authorization builder.

        Verifies the JWT, enforces RBAC, resolves identity ids and objects, and
        returns the resulting JWTAuthBuilderResult. A token with the admin scope
        short-circuits with a proxy-admin result.
        """
        # 1. Verify signature/claims and decode.
        jwt_valid_token: dict = await jwt_handler.auth_jwt(token=api_key)

        # 2. Resolve the RBAC role and apply role-based model/route limits.
        rbac_role = jwt_handler.get_rbac_role(token=jwt_valid_token)
        await JWTAuthManager.check_rbac_role(
            jwt_handler,
            jwt_valid_token,
            general_settings,
            request_data,
            route,
            rbac_role,
        )

        # 3. Extract identity claims.
        scopes = jwt_handler.get_scopes(token=jwt_valid_token)
        user_id, user_email, valid_user_email = await JWTAuthManager.get_user_info(
            jwt_handler, jwt_valid_token
        )
        org_id = jwt_handler.get_org_id(token=jwt_valid_token, default_value=None)
        end_user_id = jwt_handler.get_end_user_id(
            token=jwt_valid_token, default_value=None
        )
        team_id: Optional[str] = None
        team_object: Optional[LiteLLM_TeamTable] = None
        object_id = jwt_handler.get_object_id(token=jwt_valid_token, default_value=None)

        # An explicit object id becomes the team/user id for the resolved role.
        if rbac_role and object_id:
            if rbac_role == LitellmUserRoles.TEAM:
                team_id = object_id
            elif rbac_role == LitellmUserRoles.INTERNAL_USER:
                user_id = object_id

        # 4. Admin-scoped tokens return early (after the admin route check).
        admin_result = await JWTAuthManager.check_admin_access(
            jwt_handler, scopes, route, user_id, org_id, api_key
        )
        if admin_result:
            return admin_result

        # 5. Resolve the team: the explicit team id claim first, then the
        #    first group team with access to the requested model.
        if not team_id:
            team_id, team_object = (
                await JWTAuthManager.find_and_validate_specific_team_id(
                    jwt_handler,
                    jwt_valid_token,
                    prisma_client,
                    user_api_key_cache,
                    parent_otel_span,
                    proxy_logging_obj,
                )
            )

        if not team_object and not team_id:
            all_team_ids = JWTAuthManager.get_all_team_ids(jwt_handler, jwt_valid_token)
            team_id, team_object = await JWTAuthManager.find_team_with_model_access(
                team_ids=all_team_ids,
                requested_model=request_data.get("model"),
                route=route,
                jwt_handler=jwt_handler,
                prisma_client=prisma_client,
                user_api_key_cache=user_api_key_cache,
                parent_otel_span=parent_otel_span,
                proxy_logging_obj=proxy_logging_obj,
            )

        # 6. Load user / org / end-user objects.
        user_object, org_object, end_user_object = await JWTAuthManager.get_objects(
            user_id=user_id,
            user_email=user_email,
            org_id=org_id,
            end_user_id=end_user_id,
            valid_user_email=valid_user_email,
            jwt_handler=jwt_handler,
            prisma_client=prisma_client,
            user_api_key_cache=user_api_key_cache,
            parent_otel_span=parent_otel_span,
            proxy_logging_obj=proxy_logging_obj,
        )

        # 7. Make sure spend can be attributed. check_rbac_role() reads
        #    enforce_rbac from litellm_jwtauth, so honour it here too. The
        #    legacy general_settings key is kept for backward compatibility.
        #    Role-mapped proxy admins are exempt, as in check_rbac_role().
        enforce_rbac = bool(
            jwt_handler.litellm_jwtauth.enforce_rbac
            or general_settings.get("enforce_rbac", False)
        )
        JWTAuthManager.validate_object_id(
            user_id=user_id,
            team_id=team_id,
            enforce_rbac=enforce_rbac,
            is_proxy_admin=rbac_role == LitellmUserRoles.PROXY_ADMIN,
        )

        return JWTAuthBuilderResult(
            is_proxy_admin=False,
            team_id=team_id,
            team_object=team_object,
            user_id=user_id,
            user_object=user_object,
            org_id=org_id,
            org_object=org_object,
            end_user_id=end_user_id,
            end_user_object=end_user_object,
            token=api_key,
        )
|
|