|
|
|
|
|
import uuid |
|
from datetime import datetime |
|
from functools import wraps |
|
from typing import Optional, Tuple |
|
|
|
from fastapi import HTTPException, Request |
|
|
|
import litellm |
|
from litellm._logging import verbose_logger |
|
from litellm.proxy._types import ( |
|
DeleteCustomerRequest, |
|
DeleteTeamRequest, |
|
DeleteUserRequest, |
|
KeyRequest, |
|
LiteLLM_TeamMembership, |
|
LiteLLM_UserTable, |
|
ManagementEndpointLoggingPayload, |
|
Member, |
|
SSOUserDefinedValues, |
|
UpdateCustomerRequest, |
|
UpdateKeyRequest, |
|
UpdateTeamRequest, |
|
UpdateUserRequest, |
|
UserAPIKeyAuth, |
|
VirtualKeyEvent, |
|
) |
|
from litellm.proxy.common_utils.http_parsing_utils import _read_request_body |
|
from litellm.proxy.utils import PrismaClient |
|
|
|
|
|
def get_new_internal_user_defaults( |
|
user_id: str, user_email: Optional[str] = None |
|
) -> dict: |
|
user_info = litellm.default_internal_user_params or {} |
|
|
|
returned_dict: SSOUserDefinedValues = { |
|
"models": user_info.get("models", None), |
|
"max_budget": user_info.get("max_budget", litellm.max_internal_user_budget), |
|
"budget_duration": user_info.get( |
|
"budget_duration", litellm.internal_user_budget_duration |
|
), |
|
"user_email": user_email or user_info.get("user_email", None), |
|
"user_id": user_id, |
|
"user_role": "internal_user", |
|
} |
|
|
|
non_null_dict = {} |
|
for k, v in returned_dict.items(): |
|
if v is not None: |
|
non_null_dict[k] = v |
|
return non_null_dict |
|
|
|
|
|
async def add_new_member( |
|
new_member: Member, |
|
max_budget_in_team: Optional[float], |
|
prisma_client: PrismaClient, |
|
team_id: str, |
|
user_api_key_dict: UserAPIKeyAuth, |
|
litellm_proxy_admin_name: str, |
|
) -> Tuple[LiteLLM_UserTable, Optional[LiteLLM_TeamMembership]]: |
|
""" |
|
Add a new member to a team |
|
|
|
- add team id to user table |
|
- add team member w/ budget to team member table |
|
|
|
Returns created/existing user + team membership w/ budget id |
|
""" |
|
returned_user: Optional[LiteLLM_UserTable] = None |
|
returned_team_membership: Optional[LiteLLM_TeamMembership] = None |
|
|
|
if new_member.user_id is not None: |
|
new_user_defaults = get_new_internal_user_defaults(user_id=new_member.user_id) |
|
_returned_user = await prisma_client.db.litellm_usertable.upsert( |
|
where={"user_id": new_member.user_id}, |
|
data={ |
|
"update": {"teams": {"push": [team_id]}}, |
|
"create": {"teams": [team_id], **new_user_defaults}, |
|
}, |
|
) |
|
if _returned_user is not None: |
|
returned_user = LiteLLM_UserTable(**_returned_user.model_dump()) |
|
elif new_member.user_email is not None: |
|
new_user_defaults = get_new_internal_user_defaults( |
|
user_id=str(uuid.uuid4()), user_email=new_member.user_email |
|
) |
|
|
|
|
|
existing_user_row: Optional[list] = await prisma_client.get_data( |
|
key_val={"user_email": new_member.user_email}, |
|
table_name="user", |
|
query_type="find_all", |
|
) |
|
if existing_user_row is None or ( |
|
isinstance(existing_user_row, list) and len(existing_user_row) == 0 |
|
): |
|
new_user_defaults["teams"] = [team_id] |
|
_returned_user = await prisma_client.insert_data(data=new_user_defaults, table_name="user") |
|
|
|
if _returned_user is not None: |
|
returned_user = LiteLLM_UserTable(**_returned_user.model_dump()) |
|
elif len(existing_user_row) == 1: |
|
user_info = existing_user_row[0] |
|
_returned_user = await prisma_client.db.litellm_usertable.update( |
|
where={"user_id": user_info.user_id}, |
|
data={"teams": {"push": [team_id]}}, |
|
) |
|
if _returned_user is not None: |
|
returned_user = LiteLLM_UserTable(**_returned_user.model_dump()) |
|
elif len(existing_user_row) > 1: |
|
raise HTTPException( |
|
status_code=400, |
|
detail={ |
|
"error": "Multiple users with this email found in db. Please use 'user_id' instead." |
|
}, |
|
) |
|
|
|
|
|
if ( |
|
max_budget_in_team is not None |
|
and returned_user is not None |
|
and returned_user.user_id is not None |
|
): |
|
|
|
response = await prisma_client.db.litellm_budgettable.create( |
|
data={ |
|
"max_budget": max_budget_in_team, |
|
"created_by": user_api_key_dict.user_id or litellm_proxy_admin_name, |
|
"updated_by": user_api_key_dict.user_id or litellm_proxy_admin_name, |
|
} |
|
) |
|
|
|
_budget_id = response.budget_id |
|
_returned_team_membership = ( |
|
await prisma_client.db.litellm_teammembership.create( |
|
data={ |
|
"team_id": team_id, |
|
"user_id": returned_user.user_id, |
|
"budget_id": _budget_id, |
|
}, |
|
include={"litellm_budget_table": True}, |
|
) |
|
) |
|
|
|
returned_team_membership = LiteLLM_TeamMembership( |
|
**_returned_team_membership.model_dump() |
|
) |
|
|
|
if returned_user is None: |
|
raise Exception("Unable to update user table with membership information!") |
|
|
|
return returned_user, returned_team_membership |
|
|
|
|
|
def _delete_user_id_from_cache(kwargs): |
|
from litellm.proxy.proxy_server import user_api_key_cache |
|
|
|
if kwargs.get("data") is not None: |
|
update_user_request = kwargs.get("data") |
|
if isinstance(update_user_request, UpdateUserRequest): |
|
user_api_key_cache.delete_cache(key=update_user_request.user_id) |
|
|
|
|
|
if isinstance(update_user_request, DeleteUserRequest): |
|
for user_id in update_user_request.user_ids: |
|
user_api_key_cache.delete_cache(key=user_id) |
|
pass |
|
|
|
|
|
def _delete_api_key_from_cache(kwargs): |
|
from litellm.proxy.proxy_server import user_api_key_cache |
|
|
|
if kwargs.get("data") is not None: |
|
update_request = kwargs.get("data") |
|
if isinstance(update_request, UpdateKeyRequest): |
|
user_api_key_cache.delete_cache(key=update_request.key) |
|
|
|
|
|
if isinstance(update_request, KeyRequest): |
|
for key in update_request.keys: |
|
user_api_key_cache.delete_cache(key=key) |
|
pass |
|
|
|
|
|
def _delete_team_id_from_cache(kwargs): |
|
from litellm.proxy.proxy_server import user_api_key_cache |
|
|
|
if kwargs.get("data") is not None: |
|
update_request = kwargs.get("data") |
|
if isinstance(update_request, UpdateTeamRequest): |
|
user_api_key_cache.delete_cache(key=update_request.team_id) |
|
|
|
|
|
if isinstance(update_request, DeleteTeamRequest): |
|
for team_id in update_request.team_ids: |
|
user_api_key_cache.delete_cache(key=team_id) |
|
pass |
|
|
|
|
|
def _delete_customer_id_from_cache(kwargs): |
|
from litellm.proxy.proxy_server import user_api_key_cache |
|
|
|
if kwargs.get("data") is not None: |
|
update_request = kwargs.get("data") |
|
if isinstance(update_request, UpdateCustomerRequest): |
|
user_api_key_cache.delete_cache(key=update_request.user_id) |
|
|
|
|
|
if isinstance(update_request, DeleteCustomerRequest): |
|
for user_id in update_request.user_ids: |
|
user_api_key_cache.delete_cache(key=user_id) |
|
pass |
|
|
|
|
|
async def send_management_endpoint_alert( |
|
request_kwargs: dict, |
|
user_api_key_dict: UserAPIKeyAuth, |
|
function_name: str, |
|
): |
|
""" |
|
Sends a slack alert when: |
|
- A virtual key is created, updated, or deleted |
|
- An internal user is created, updated, or deleted |
|
- A team is created, updated, or deleted |
|
""" |
|
from litellm.proxy.proxy_server import premium_user, proxy_logging_obj |
|
from litellm.types.integrations.slack_alerting import AlertType |
|
|
|
if premium_user is not True: |
|
return |
|
|
|
management_function_to_event_name = { |
|
"generate_key_fn": AlertType.new_virtual_key_created, |
|
"update_key_fn": AlertType.virtual_key_updated, |
|
"delete_key_fn": AlertType.virtual_key_deleted, |
|
|
|
"new_team": AlertType.new_team_created, |
|
"update_team": AlertType.team_updated, |
|
"delete_team": AlertType.team_deleted, |
|
|
|
"new_user": AlertType.new_internal_user_created, |
|
"user_update": AlertType.internal_user_updated, |
|
"delete_user": AlertType.internal_user_deleted, |
|
} |
|
|
|
|
|
if ( |
|
proxy_logging_obj is not None |
|
and proxy_logging_obj.slack_alerting_instance is not None |
|
): |
|
|
|
|
|
if function_name in management_function_to_event_name: |
|
_event_name: AlertType = management_function_to_event_name[function_name] |
|
|
|
key_event = VirtualKeyEvent( |
|
created_by_user_id=user_api_key_dict.user_id or "Unknown", |
|
created_by_user_role=user_api_key_dict.user_role or "Unknown", |
|
created_by_key_alias=user_api_key_dict.key_alias, |
|
request_kwargs=request_kwargs, |
|
) |
|
|
|
|
|
event_name = _event_name.replace("_", " ").title() |
|
await proxy_logging_obj.slack_alerting_instance.send_virtual_key_event_slack( |
|
key_event=key_event, |
|
event_name=event_name, |
|
alert_type=_event_name, |
|
) |
|
|
|
|
|
def management_endpoint_wrapper(func): |
|
""" |
|
This wrapper does the following: |
|
|
|
1. Log I/O, Exceptions to OTEL |
|
2. Create an Audit log for success calls |
|
""" |
|
|
|
@wraps(func) |
|
async def wrapper(*args, **kwargs): |
|
start_time = datetime.now() |
|
_http_request: Optional[Request] = None |
|
try: |
|
result = await func(*args, **kwargs) |
|
end_time = datetime.now() |
|
try: |
|
if kwargs is None: |
|
kwargs = {} |
|
user_api_key_dict: UserAPIKeyAuth = ( |
|
kwargs.get("user_api_key_dict") or UserAPIKeyAuth() |
|
) |
|
|
|
await send_management_endpoint_alert( |
|
request_kwargs=kwargs, |
|
user_api_key_dict=user_api_key_dict, |
|
function_name=func.__name__, |
|
) |
|
_http_request = kwargs.get("http_request", None) |
|
parent_otel_span = getattr(user_api_key_dict, "parent_otel_span", None) |
|
if parent_otel_span is not None: |
|
from litellm.proxy.proxy_server import open_telemetry_logger |
|
|
|
if open_telemetry_logger is not None: |
|
if _http_request: |
|
_route = _http_request.url.path |
|
_request_body: dict = await _read_request_body( |
|
request=_http_request |
|
) |
|
_response = dict(result) if result is not None else None |
|
|
|
logging_payload = ManagementEndpointLoggingPayload( |
|
route=_route, |
|
request_data=_request_body, |
|
response=_response, |
|
start_time=start_time, |
|
end_time=end_time, |
|
) |
|
|
|
await open_telemetry_logger.async_management_endpoint_success_hook( |
|
logging_payload=logging_payload, |
|
parent_otel_span=parent_otel_span, |
|
) |
|
|
|
|
|
_delete_api_key_from_cache(kwargs=kwargs) |
|
_delete_user_id_from_cache(kwargs=kwargs) |
|
_delete_team_id_from_cache(kwargs=kwargs) |
|
_delete_customer_id_from_cache(kwargs=kwargs) |
|
except Exception as e: |
|
|
|
verbose_logger.debug("Error in management endpoint wrapper: %s", str(e)) |
|
pass |
|
|
|
return result |
|
except Exception as e: |
|
end_time = datetime.now() |
|
|
|
if kwargs is None: |
|
kwargs = {} |
|
user_api_key_dict: UserAPIKeyAuth = ( |
|
kwargs.get("user_api_key_dict") or UserAPIKeyAuth() |
|
) |
|
parent_otel_span = getattr(user_api_key_dict, "parent_otel_span", None) |
|
if parent_otel_span is not None: |
|
from litellm.proxy.proxy_server import open_telemetry_logger |
|
|
|
if open_telemetry_logger is not None: |
|
_http_request = kwargs.get("http_request") |
|
if _http_request: |
|
_route = _http_request.url.path |
|
_request_body: dict = await _read_request_body( |
|
request=_http_request |
|
) |
|
logging_payload = ManagementEndpointLoggingPayload( |
|
route=_route, |
|
request_data=_request_body, |
|
response=None, |
|
start_time=start_time, |
|
end_time=end_time, |
|
exception=e, |
|
) |
|
|
|
await open_telemetry_logger.async_management_endpoint_failure_hook( |
|
logging_payload=logging_payload, |
|
parent_otel_span=parent_otel_span, |
|
) |
|
|
|
raise e |
|
|
|
return wrapper |
|
|