# What is this?
## Helper utils for the management endpoints (keys/users/teams)

import uuid
from datetime import datetime
from functools import wraps
from typing import Optional, Tuple

from fastapi import HTTPException, Request

import litellm
from litellm._logging import verbose_logger
from litellm.proxy._types import (  # key, user, team and customer request types
    DeleteCustomerRequest,
    DeleteTeamRequest,
    DeleteUserRequest,
    KeyRequest,
    LiteLLM_TeamMembership,
    LiteLLM_UserTable,
    ManagementEndpointLoggingPayload,
    Member,
    SSOUserDefinedValues,
    UpdateCustomerRequest,
    UpdateKeyRequest,
    UpdateTeamRequest,
    UpdateUserRequest,
    UserAPIKeyAuth,
    VirtualKeyEvent,
)
from litellm.proxy.common_utils.http_parsing_utils import _read_request_body
from litellm.proxy.utils import PrismaClient


def get_new_internal_user_defaults(
    user_id: str, user_email: Optional[str] = None
) -> dict:
    """
    Build the field values for a newly created internal user.

    Values are read from `litellm.default_internal_user_params` when it is set.
    The budget fields fall back to `litellm.max_internal_user_budget` and
    `litellm.internal_user_budget_duration`. An explicit `user_email` argument
    takes precedence over a configured one, and the role is always
    "internal_user".

    Args:
        user_id: id to assign to the new user.
        user_email: optional email for the new user.

    Returns:
        A dict containing only the fields whose value is not None.
    """
    user_info = litellm.default_internal_user_params or {}

    returned_dict: SSOUserDefinedValues = {
        "models": user_info.get("models", None),
        "max_budget": user_info.get("max_budget", litellm.max_internal_user_budget),
        "budget_duration": user_info.get(
            "budget_duration", litellm.internal_user_budget_duration
        ),
        "user_email": user_email or user_info.get("user_email", None),
        "user_id": user_id,
        "user_role": "internal_user",
    }

    # Only keep the fields that actually have a value.
    return {k: v for k, v in returned_dict.items() if v is not None}


async def add_new_member(
    new_member: Member,
    max_budget_in_team: Optional[float],
    prisma_client: PrismaClient,
    team_id: str,
    user_api_key_dict: UserAPIKeyAuth,
    litellm_proxy_admin_name: str,
) -> Tuple[LiteLLM_UserTable, Optional[LiteLLM_TeamMembership]]:
    """
    Add a new member to a team.

    - add team id to user table
    - add team member w/ budget to team member table

    The member is matched by `new_member.user_id` when set (upserted: the team
    id is pushed onto an existing user, or a new user is created with it).
    Otherwise the member is matched by `new_member.user_email`. A new user is
    inserted if no user has that email, and the single matching user is
    updated if exactly one does.

    Args:
        new_member: the member to add (identified by user_id or user_email).
        max_budget_in_team: if set, a budget row and a team-membership row
            linking user, team and budget are created.
        prisma_client: DB client.
        team_id: team to add the member to.
        user_api_key_dict: caller auth; its user_id is recorded as the budget's
            created_by/updated_by.
        litellm_proxy_admin_name: fallback created_by/updated_by when the caller
            has no user_id.

    Returns:
        Created/existing user + team membership w/ budget id. The membership is
        None unless a budget was created.

    Raises:
        HTTPException: (400) if several users share `new_member.user_email`.
        Exception: if no user row could be created/updated, e.g. when the
            member has neither `user_id` nor `user_email`.
    """
    returned_user: Optional[LiteLLM_UserTable] = None
    returned_team_membership: Optional[LiteLLM_TeamMembership] = None

    ## ADD TEAM ID, to USER TABLE IF NEW ##
    if new_member.user_id is not None:
        new_user_defaults = get_new_internal_user_defaults(user_id=new_member.user_id)
        _returned_user = await prisma_client.db.litellm_usertable.upsert(
            where={"user_id": new_member.user_id},
            data={
                "update": {"teams": {"push": [team_id]}},
                "create": {"teams": [team_id], **new_user_defaults},  # type: ignore
            },
        )
        if _returned_user is not None:
            returned_user = LiteLLM_UserTable(**_returned_user.model_dump())
    elif new_member.user_email is not None:
        new_user_defaults = get_new_internal_user_defaults(
            user_id=str(uuid.uuid4()), user_email=new_member.user_email
        )
        ## user email is not unique acc. to prisma schema -> future improvement
        ### for now: check if it exists in db, if not - insert it
        existing_user_row: Optional[list] = await prisma_client.get_data(
            key_val={"user_email": new_member.user_email},
            table_name="user",
            query_type="find_all",
        )
        if existing_user_row is None or (
            isinstance(existing_user_row, list) and len(existing_user_row) == 0
        ):
            # No user with this email yet -> create one already in the team.
            new_user_defaults["teams"] = [team_id]
            _returned_user = await prisma_client.insert_data(data=new_user_defaults, table_name="user")  # type: ignore

            if _returned_user is not None:
                returned_user = LiteLLM_UserTable(**_returned_user.model_dump())
        elif len(existing_user_row) == 1:
            # Exactly one user with this email -> add the team to that user.
            user_info = existing_user_row[0]
            _returned_user = await prisma_client.db.litellm_usertable.update(
                where={"user_id": user_info.user_id},  # type: ignore
                data={"teams": {"push": [team_id]}},
            )
            if _returned_user is not None:
                returned_user = LiteLLM_UserTable(**_returned_user.model_dump())
        elif len(existing_user_row) > 1:
            raise HTTPException(
                status_code=400,
                detail={
                    "error": "Multiple users with this email found in db. Please use 'user_id' instead."
                },
            )

    # Check if trying to set a budget for team member
    if (
        max_budget_in_team is not None
        and returned_user is not None
        and returned_user.user_id is not None
    ):
        # create a new budget item for this member
        response = await prisma_client.db.litellm_budgettable.create(
            data={
                "max_budget": max_budget_in_team,
                "created_by": user_api_key_dict.user_id or litellm_proxy_admin_name,
                "updated_by": user_api_key_dict.user_id or litellm_proxy_admin_name,
            }
        )

        _budget_id = response.budget_id
        # Link user, team and budget; include the budget row in the result.
        _returned_team_membership = (
            await prisma_client.db.litellm_teammembership.create(
                data={
                    "team_id": team_id,
                    "user_id": returned_user.user_id,
                    "budget_id": _budget_id,
                },
                include={"litellm_budget_table": True},
            )
        )

        returned_team_membership = LiteLLM_TeamMembership(
            **_returned_team_membership.model_dump()
        )

    if returned_user is None:
        raise Exception("Unable to update user table with membership information!")

    return returned_user, returned_team_membership


def _delete_user_id_from_cache(kwargs):
    """
    Evict cached user entries touched by an update/delete user request.

    `kwargs` are the management endpoint's keyword arguments; the request model
    is read from `kwargs["data"]`. Other request types are ignored.
    """
    from litellm.proxy.proxy_server import user_api_key_cache

    if kwargs.get("data") is not None:
        update_user_request = kwargs.get("data")
        if isinstance(update_user_request, UpdateUserRequest):
            user_api_key_cache.delete_cache(key=update_user_request.user_id)

        # delete user request
        if isinstance(update_user_request, DeleteUserRequest):
            for user_id in update_user_request.user_ids:
                user_api_key_cache.delete_cache(key=user_id)


def _delete_api_key_from_cache(kwargs):
    """
    Evict cached virtual-key entries touched by an update/delete key request.

    `kwargs` are the management endpoint's keyword arguments; the request model
    is read from `kwargs["data"]`. Other request types are ignored.
    """
    from litellm.proxy.proxy_server import user_api_key_cache

    if kwargs.get("data") is not None:
        update_request = kwargs.get("data")
        if isinstance(update_request, UpdateKeyRequest):
            user_api_key_cache.delete_cache(key=update_request.key)

        # delete key request
        if isinstance(update_request, KeyRequest):
            for key in update_request.keys:
                user_api_key_cache.delete_cache(key=key)


def _delete_team_id_from_cache(kwargs):
    """
    Evict cached team entries touched by an update/delete team request.

    `kwargs` are the management endpoint's keyword arguments; the request model
    is read from `kwargs["data"]`. Other request types are ignored.
    """
    from litellm.proxy.proxy_server import user_api_key_cache

    if kwargs.get("data") is not None:
        update_request = kwargs.get("data")
        if isinstance(update_request, UpdateTeamRequest):
            user_api_key_cache.delete_cache(key=update_request.team_id)

        # delete team request
        if isinstance(update_request, DeleteTeamRequest):
            for team_id in update_request.team_ids:
                user_api_key_cache.delete_cache(key=team_id)


def _delete_customer_id_from_cache(kwargs):
    """
    Evict cached customer entries touched by an update/delete customer request.

    `kwargs` are the management endpoint's keyword arguments; the request model
    is read from `kwargs["data"]`. Other request types are ignored.
    """
    from litellm.proxy.proxy_server import user_api_key_cache

    if kwargs.get("data") is not None:
        update_request = kwargs.get("data")
        if isinstance(update_request, UpdateCustomerRequest):
            user_api_key_cache.delete_cache(key=update_request.user_id)

        # delete customer request
        if isinstance(update_request, DeleteCustomerRequest):
            for user_id in update_request.user_ids:
                user_api_key_cache.delete_cache(key=user_id)


async def send_management_endpoint_alert(
    request_kwargs: dict,
    user_api_key_dict: UserAPIKeyAuth,
    function_name: str,
):
    """
    Send a Slack alert for a management event (premium users only).

    Sends a slack alert when:
    - A virtual key is created, updated, or deleted
    - An internal user is created, updated, or deleted
    - A team is created, updated, or deleted

    Args:
        request_kwargs: keyword arguments the management endpoint was called
            with; forwarded in the alert payload.
        user_api_key_dict: caller auth, used to attribute the event.
        function_name: name of the endpoint function; selects the alert type.
            Names not in the mapping below send no alert.
    """
    from litellm.proxy.proxy_server import premium_user, proxy_logging_obj
    from litellm.types.integrations.slack_alerting import AlertType

    if premium_user is not True:
        return

    management_function_to_event_name = {
        "generate_key_fn": AlertType.new_virtual_key_created,
        "update_key_fn": AlertType.virtual_key_updated,
        "delete_key_fn": AlertType.virtual_key_deleted,
        # Team events
        "new_team": AlertType.new_team_created,
        "update_team": AlertType.team_updated,
        "delete_team": AlertType.team_deleted,
        # Internal User events
        "new_user": AlertType.new_internal_user_created,
        "user_update": AlertType.internal_user_updated,
        "delete_user": AlertType.internal_user_deleted,
    }

    # Check if alerting is enabled
    if (
        proxy_logging_obj is not None
        and proxy_logging_obj.slack_alerting_instance is not None
    ):
        # Key, team and internal-user events all use the same alert payload
        if function_name in management_function_to_event_name:
            _event_name: AlertType = management_function_to_event_name[function_name]

            key_event = VirtualKeyEvent(
                created_by_user_id=user_api_key_dict.user_id or "Unknown",
                created_by_user_role=user_api_key_dict.user_role or "Unknown",
                created_by_key_alias=user_api_key_dict.key_alias,
                request_kwargs=request_kwargs,
            )

            # replace all "_" with " " and capitalize
            event_name = _event_name.replace("_", " ").title()
            await proxy_logging_obj.slack_alerting_instance.send_virtual_key_event_slack(
                key_event=key_event,
                event_name=event_name,
                alert_type=_event_name,
            )


def _delete_management_entities_from_cache(kwargs: dict) -> None:
    """Evict cached keys/users/teams/customers touched by a management request."""
    _delete_api_key_from_cache(kwargs=kwargs)
    _delete_user_id_from_cache(kwargs=kwargs)
    _delete_team_id_from_cache(kwargs=kwargs)
    _delete_customer_id_from_cache(kwargs=kwargs)


async def _log_management_endpoint_to_otel(
    request_kwargs: dict,
    start_time: datetime,
    end_time: datetime,
    result=None,
    exception: Optional[Exception] = None,
) -> None:
    """
    Log one management endpoint call to OTEL.

    Does nothing unless all three hold:
    - the caller's `user_api_key_dict` carries a `parent_otel_span`;
    - an OTEL logger is configured;
    - the endpoint received an `http_request`.

    Calls the failure hook when `exception` is set, otherwise the success hook
    with `result` converted via `dict(...)`.
    """
    user_api_key_dict: UserAPIKeyAuth = (
        request_kwargs.get("user_api_key_dict") or UserAPIKeyAuth()
    )
    parent_otel_span = getattr(user_api_key_dict, "parent_otel_span", None)
    if parent_otel_span is None:
        return

    from litellm.proxy.proxy_server import open_telemetry_logger

    if open_telemetry_logger is None:
        return

    _http_request: Optional[Request] = request_kwargs.get("http_request", None)
    if not _http_request:
        return

    _route = _http_request.url.path
    _request_body: dict = await _read_request_body(request=_http_request)

    if exception is not None:
        logging_payload = ManagementEndpointLoggingPayload(
            route=_route,
            request_data=_request_body,
            response=None,
            start_time=start_time,
            end_time=end_time,
            exception=exception,
        )
        await open_telemetry_logger.async_management_endpoint_failure_hook(  # type: ignore
            logging_payload=logging_payload,
            parent_otel_span=parent_otel_span,
        )
    else:
        _response = dict(result) if result is not None else None
        logging_payload = ManagementEndpointLoggingPayload(
            route=_route,
            request_data=_request_body,
            response=_response,
            start_time=start_time,
            end_time=end_time,
        )
        await open_telemetry_logger.async_management_endpoint_success_hook(  # type: ignore
            logging_payload=logging_payload,
            parent_otel_span=parent_otel_span,
        )


def management_endpoint_wrapper(func):
    """
    Decorator for the async key/user/team/customer management endpoints.

    It reads the endpoint's keyword arguments `data` (request model),
    `user_api_key_dict` and `http_request`.

    On success it:
    1. Evicts cached entries for any key/user/team/customer the request
       updated or deleted.
    2. Sends a management-event Slack alert (see
       `send_management_endpoint_alert`).
    3. Logs the request/response to OTEL (see `_log_management_endpoint_to_otel`).

    On failure it logs the exception to OTEL and re-raises the endpoint's
    original exception.

    All of this post-processing is non-blocking. Errors in any step are logged
    at debug level and never change the endpoint's result or its exception.
    """

    @wraps(func)
    async def wrapper(*args, **kwargs):
        start_time = datetime.now()
        try:
            result = await func(*args, **kwargs)
        except Exception as e:
            end_time = datetime.now()
            try:
                await _log_management_endpoint_to_otel(
                    request_kwargs=kwargs,
                    start_time=start_time,
                    end_time=end_time,
                    exception=e,
                )
            except Exception as logging_exception:
                # A logging failure must never replace the endpoint's own
                # exception (e.g. turn an HTTP 400 into an unrelated error).
                verbose_logger.debug(
                    "Error in management endpoint wrapper: %s", str(logging_exception)
                )
            raise e

        end_time = datetime.now()

        # Cache eviction runs first and on its own: a failing alert or OTEL
        # call must not leave stale keys/users/teams/customers in the cache.
        try:
            _delete_management_entities_from_cache(kwargs=kwargs)
        except Exception as e:
            # Non-Blocking Exception
            verbose_logger.debug("Error in management endpoint wrapper: %s", str(e))

        try:
            await send_management_endpoint_alert(
                request_kwargs=kwargs,
                user_api_key_dict=kwargs.get("user_api_key_dict") or UserAPIKeyAuth(),
                function_name=func.__name__,
            )
        except Exception as e:
            # Non-Blocking Exception
            verbose_logger.debug("Error in management endpoint wrapper: %s", str(e))

        try:
            await _log_management_endpoint_to_otel(
                request_kwargs=kwargs,
                start_time=start_time,
                end_time=end_time,
                result=result,
            )
        except Exception as e:
            # Non-Blocking Exception
            verbose_logger.debug("Error in management endpoint wrapper: %s", str(e))

        return result

    return wrapper