"""
Got Valid Token from Cache, DB

Run checks for:

1. If user can call model
2. If user is in budget
3. If end_user (the 'user' param passed to /chat/completions, /embeddings endpoints) is in budget
"""

import asyncio
import re
import time
import traceback
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, cast

from fastapi import status
from pydantic import BaseModel

import litellm
from litellm._logging import verbose_proxy_logger
from litellm.caching.caching import DualCache
from litellm.caching.dual_cache import LimitedSizeOrderedDict
from litellm.litellm_core_utils.get_llm_provider_logic import get_llm_provider
from litellm.proxy._types import (
    DB_CONNECTION_ERROR_TYPES,
    RBAC_ROLES,
    CallInfo,
    LiteLLM_EndUserTable,
    LiteLLM_JWTAuth,
    LiteLLM_OrganizationMembershipTable,
    LiteLLM_OrganizationTable,
    LiteLLM_TeamTable,
    LiteLLM_TeamTableCachedObj,
    LiteLLM_UserTable,
    LiteLLMRoutes,
    LitellmUserRoles,
    ProxyErrorTypes,
    ProxyException,
    RoleBasedPermissions,
    UserAPIKeyAuth,
)
from litellm.proxy.auth.route_checks import RouteChecks
from litellm.proxy.route_llm_request import route_request
from litellm.proxy.utils import PrismaClient, ProxyLogging, log_db_metrics
from litellm.router import Router
from litellm.types.services import ServiceTypes

from .auth_checks_organization import organization_role_based_access_check

if TYPE_CHECKING:
    from opentelemetry.trace import Span as _Span

    Span = _Span
else:
    Span = Any


# Tracks when the DB was last read for a given key, as (value, timestamp) tuples,
# so repeated lookups for entries missing from the DB can be throttled.
last_db_access_time = LimitedSizeOrderedDict(max_size=100)
db_cache_expiry = 5  # refresh DB read after 5s

all_routes = LiteLLMRoutes.openai_routes.value + LiteLLMRoutes.management_routes.value


async def common_checks(
    request_body: dict,
    team_object: Optional[LiteLLM_TeamTable],
    user_object: Optional[LiteLLM_UserTable],
    end_user_object: Optional[LiteLLM_EndUserTable],
    global_proxy_spend: Optional[float],
    general_settings: dict,
    route: str,
    llm_router: Optional[Router],
    proxy_logging_obj: ProxyLogging,
    valid_token: Optional[UserAPIKeyAuth],
) -> bool:
    """
    Common checks across jwt + key-based auth.

    1. If team is blocked
    2. If team can call model
    3. If team is in budget
    4. If user passed in (JWT or key.user_id) - is in budget
    5. If end_user (either via JWT or 'user' passed to /chat/completions, /embeddings endpoint) is in budget
    6. [OPTIONAL] If 'enforce_user_param' enabled - did developer pass in 'user' param for openai endpoints
    7. [OPTIONAL] If 'litellm.max_budget' is set (>0), is proxy under budget
    8. [OPTIONAL] If guardrails modified - is request allowed to change this
    9. Check if request body is safe
    10. [OPTIONAL] Organization checks - if user_object.organization_id is set, run these checks
    """
    _model = request_body.get("model", None)

    # 1. If team is blocked
    if team_object is not None and team_object.blocked is True:
        raise Exception(
            f"Team={team_object.team_id} is blocked. Update via `/team/unblock` if you're the admin."
        )

    # 2. If team can call model
    _team_model_access_check(
        team_object=team_object,
        model=_model,
        llm_router=llm_router,
    )

    # If no team was resolved, check the user's own model access
    if team_object is None and user_object is not None:
        await can_user_call_model(
            model=_model,
            llm_router=llm_router,
            user_object=user_object,
        )

    # 3. If team is in budget
    await _team_max_budget_check(
        team_object=team_object,
        proxy_logging_obj=proxy_logging_obj,
        valid_token=valid_token,
    )

    # 4. If user is in budget (personal budget, when no team is set)
    if (
        (team_object is None or team_object.team_id is None)
        and user_object is not None
        and user_object.max_budget is not None
    ):
        user_budget = user_object.max_budget
        if user_budget < user_object.spend:
            raise litellm.BudgetExceededError(
                current_cost=user_object.spend,
                max_budget=user_budget,
                message=f"ExceededBudget: User={user_object.user_id} over budget. Spend={user_object.spend}, Budget={user_budget}",
            )

    # 5. If end_user ('user' passed to /chat/completions, /embeddings endpoint) is in budget
    if end_user_object is not None and end_user_object.litellm_budget_table is not None:
        end_user_budget = end_user_object.litellm_budget_table.max_budget
        if end_user_budget is not None and end_user_object.spend > end_user_budget:
            raise litellm.BudgetExceededError(
                current_cost=end_user_object.spend,
                max_budget=end_user_budget,
                message=f"ExceededBudget: End User={end_user_object.user_id} over budget. Spend={end_user_object.spend}, Budget={end_user_budget}",
            )

    # 6. [OPTIONAL] If 'enforce_user_param' enabled - did developer pass in 'user' param for openai endpoints
    if (
        general_settings.get("enforce_user_param", None) is not None
        and general_settings["enforce_user_param"] is True
    ):
        if RouteChecks.is_llm_api_route(route=route) and "user" not in request_body:
            raise Exception(
                f"'user' param not passed in. 'enforce_user_param'={general_settings['enforce_user_param']}"
            )

    # 7. [OPTIONAL] If 'litellm.max_budget' is set (>0), is proxy under budget
    if (
        litellm.max_budget > 0
        and global_proxy_spend is not None
        # only run the global budget check for LLM API routes,
        # so the Admin UI keeps working if the proxy crosses its global budget
        and RouteChecks.is_llm_api_route(route=route)
        and route != "/v1/models"
        and route != "/models"
    ):
        if global_proxy_spend > litellm.max_budget:
            raise litellm.BudgetExceededError(
                current_cost=global_proxy_spend, max_budget=litellm.max_budget
            )

    # 8. [OPTIONAL] If guardrails modified - is request allowed to change this
    _request_metadata: dict = request_body.get("metadata", {}) or {}
    if _request_metadata.get("guardrails"):
        # check if team is allowed to modify guardrails
        from litellm.proxy.guardrails.guardrail_helpers import can_modify_guardrails

        can_modify: bool = can_modify_guardrails(team_object)
        if can_modify is False:
            from fastapi import HTTPException

            raise HTTPException(
                status_code=403,
                detail={
                    "error": "Your team does not have permission to modify guardrails."
                },
            )

    # 10. [OPTIONAL] Organization RBAC checks
    organization_role_based_access_check(
        user_object=user_object, route=route, request_body=request_body
    )

    return True


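# Illustrative sketch (not executed): how an auth layer might invoke
# `common_checks` after resolving the key/team/user objects. All values
# below are assumptions for illustration, not real proxy state.
#
#     await common_checks(
#         request_body={"model": "gpt-4o", "user": "end-user-123"},
#         team_object=team,            # e.g. resolved via get_team_object(...)
#         user_object=user,            # e.g. resolved via get_user_object(...)
#         end_user_object=end_user,    # e.g. resolved via get_end_user_object(...)
#         global_proxy_spend=12.5,
#         general_settings={"enforce_user_param": True},
#         route="/chat/completions",
#         llm_router=llm_router,
#         proxy_logging_obj=proxy_logging_obj,
#         valid_token=valid_token,
#     )  # raises (e.g. BudgetExceededError) on failure, returns True otherwise

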
def _allowed_routes_check(user_route: str, allowed_routes: list) -> bool:
    """
    Return True if a user is allowed to access the route. Helper function for `allowed_routes_check`.

    Parameters:
    - user_route: str - the route the user is trying to call
    - allowed_routes: List[str|LiteLLMRoutes] - the list of allowed routes for the user.
    """
    for allowed_route in allowed_routes:
        if (
            allowed_route in LiteLLMRoutes.__members__
            and user_route in LiteLLMRoutes[allowed_route].value
        ):
            return True
        elif allowed_route == user_route:
            return True
    return False


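# Illustrative sketch (not executed): `allowed_routes` may mix LiteLLMRoutes
# member names with literal paths. The route values below are assumptions.
#
#     _allowed_routes_check(
#         user_route="/key/generate",
#         allowed_routes=["management_routes"],   # expands via LiteLLMRoutes
#     )  # -> True if "/key/generate" is in LiteLLMRoutes.management_routes
#
#     _allowed_routes_check(
#         user_route="/custom/route",
#         allowed_routes=["/custom/route"],        # literal match
#     )  # -> True

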
def allowed_routes_check(
    user_role: Literal[
        LitellmUserRoles.PROXY_ADMIN,
        LitellmUserRoles.TEAM,
        LitellmUserRoles.INTERNAL_USER,
    ],
    user_route: str,
    litellm_proxy_roles: LiteLLM_JWTAuth,
) -> bool:
    """
    Check if a user role is allowed to access the requested route.
    """
    if user_role == LitellmUserRoles.PROXY_ADMIN:
        is_allowed = _allowed_routes_check(
            user_route=user_route,
            allowed_routes=litellm_proxy_roles.admin_allowed_routes,
        )
        return is_allowed
    elif user_role == LitellmUserRoles.TEAM:
        if litellm_proxy_roles.team_allowed_routes is None:
            # by default, allow a team to call openai + info routes
            is_allowed = _allowed_routes_check(
                user_route=user_route, allowed_routes=["openai_routes", "info_routes"]
            )
            return is_allowed
        else:
            is_allowed = _allowed_routes_check(
                user_route=user_route,
                allowed_routes=litellm_proxy_roles.team_allowed_routes,
            )
            return is_allowed
    return False


def allowed_route_check_inside_route(
    user_api_key_dict: UserAPIKeyAuth,
    requested_user_id: Optional[str],
) -> bool:
    ret_val = True
    if (
        user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN
        and user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN_VIEW_ONLY
    ):
        ret_val = False
    if requested_user_id is not None and user_api_key_dict.user_id is not None:
        if user_api_key_dict.user_id == requested_user_id:
            ret_val = True
    return ret_val


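# Illustrative sketch (not executed): non-admin keys may only act on their
# own user_id. The ids and field values below are assumptions.
#
#     caller = UserAPIKeyAuth(user_role=LitellmUserRoles.INTERNAL_USER, user_id="u1")
#     allowed_route_check_inside_route(user_api_key_dict=caller, requested_user_id="u1")
#     # -> True (self-access)
#     allowed_route_check_inside_route(user_api_key_dict=caller, requested_user_id="u2")
#     # -> False (not an admin, not self)

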
def get_actual_routes(allowed_routes: list) -> list:
    actual_routes: list = []
    for route_name in allowed_routes:
        try:
            route_value = LiteLLMRoutes[route_name].value
            if isinstance(route_value, set):
                actual_routes.extend(list(route_value))
            else:
                actual_routes.extend(route_value)
        except KeyError:
            # not a LiteLLMRoutes group name - treat it as a literal route
            actual_routes.append(route_name)
    return actual_routes


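# Illustrative sketch (not executed): group names expand, unknown names pass
# through unchanged.
#
#     get_actual_routes(["info_routes", "/my/custom/route"])
#     # -> [...every route in LiteLLMRoutes.info_routes..., "/my/custom/route"]

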
@log_db_metrics
async def get_end_user_object(
    end_user_id: Optional[str],
    prisma_client: Optional[PrismaClient],
    user_api_key_cache: DualCache,
    parent_otel_span: Optional[Span] = None,
    proxy_logging_obj: Optional[ProxyLogging] = None,
) -> Optional[LiteLLM_EndUserTable]:
    """
    Returns end user object, if in db.

    Do an isolated check for the end user in the table vs. a combined key + team + user + end-user check, as the key might come in frequently for different end-users. The larger combined call would slow down query time. This way we cache the constant part (key/team/user info) and only update based on the changing value (end-user).
    """
    if prisma_client is None:
        raise Exception("No db connected")

    if end_user_id is None:
        return None
    _key = "end_user_id:{}".format(end_user_id)

    def check_in_budget(end_user_obj: LiteLLM_EndUserTable):
        if end_user_obj.litellm_budget_table is None:
            return
        end_user_budget = end_user_obj.litellm_budget_table.max_budget
        if end_user_budget is not None and end_user_obj.spend > end_user_budget:
            raise litellm.BudgetExceededError(
                current_cost=end_user_obj.spend, max_budget=end_user_budget
            )

    # check if in cache
    cached_user_obj = await user_api_key_cache.async_get_cache(key=_key)
    if cached_user_obj is not None:
        if isinstance(cached_user_obj, dict):
            return_obj = LiteLLM_EndUserTable(**cached_user_obj)
            check_in_budget(end_user_obj=return_obj)
            return return_obj
        elif isinstance(cached_user_obj, LiteLLM_EndUserTable):
            return_obj = cached_user_obj
            check_in_budget(end_user_obj=return_obj)
            return return_obj
    # else, check db
    try:
        response = await prisma_client.db.litellm_endusertable.find_unique(
            where={"user_id": end_user_id},
            include={"litellm_budget_table": True},
        )

        if response is None:
            raise Exception

        # save the end-user object to cache
        await user_api_key_cache.async_set_cache(
            key="end_user_id:{}".format(end_user_id), value=response
        )

        _response = LiteLLM_EndUserTable(**response.dict())

        check_in_budget(end_user_obj=_response)

        return _response
    except Exception as e:  # end-user not in db
        if isinstance(e, litellm.BudgetExceededError):
            raise e
        return None


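# Illustrative sketch (not executed): the id is an assumption.
#
#     end_user = await get_end_user_object(
#         end_user_id="customer-123",          # the `user` param from the request
#         prisma_client=prisma_client,
#         user_api_key_cache=user_api_key_cache,
#     )
#     # -> LiteLLM_EndUserTable or None; raises BudgetExceededError if over budget

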
def model_in_access_group(
    model: str, team_models: Optional[List[str]], llm_router: Optional[Router]
) -> bool:
    from collections import defaultdict

    if team_models is None:
        return True
    if model in team_models:
        return True

    access_groups: dict[str, list[str]] = defaultdict(list)
    if llm_router:
        access_groups = llm_router.get_model_access_groups(model_name=model)

    if len(access_groups) > 0:  # check if the model is in the team's access groups
        for m in team_models:
            if m in access_groups:
                return True

    # Filter out entries that are access group names, not model names
    filtered_models = [m for m in team_models if m not in access_groups]

    if model in filtered_models:
        return True

    return False


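# Illustrative sketch (not executed): model/group names are assumptions.
# If the router maps model "gpt-4o" to access group "beta-models", a team
# whose `models` list contains "beta-models" can call it.
#
#     model_in_access_group(
#         model="gpt-4o",
#         team_models=["beta-models"],
#         llm_router=router,   # a Router configured with a "beta-models" access group
#     )  # -> True

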
def _should_check_db(
    key: str, last_db_access_time: LimitedSizeOrderedDict, db_cache_expiry: int
) -> bool:
    """
    Prevent calling the db repeatedly for items that don't exist in the db.
    """
    current_time = time.time()
    # if key hasn't been seen yet -> check db
    if key not in last_db_access_time:
        return True
    elif (
        last_db_access_time[key][0] is not None
    ):  # check db for non-null values (refresh operations)
        return True
    elif last_db_access_time[key][0] is None:
        # for null values, only re-check the db after db_cache_expiry seconds
        if current_time - last_db_access_time[key][1] >= db_cache_expiry:
            return True
    return False


def _update_last_db_access_time(
    key: str, value: Optional[Any], last_db_access_time: LimitedSizeOrderedDict
):
    last_db_access_time[key] = (value, time.time())


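# Illustrative sketch (not executed): the cache stores (value, timestamp)
# tuples, so a key recorded with value=None is only re-checked against the
# db once `db_cache_expiry` seconds have elapsed.
#
#     cache = LimitedSizeOrderedDict(max_size=100)
#     _update_last_db_access_time(key="user_id:u1", value=None, last_db_access_time=cache)
#     _should_check_db(key="user_id:u1", last_db_access_time=cache, db_cache_expiry=5)
#     # -> False (checked < 5s ago, and the db said the row doesn't exist)

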
def _get_role_based_permissions(
    rbac_role: RBAC_ROLES,
    general_settings: dict,
    key: Literal["models", "routes"],
) -> Optional[List[str]]:
    """
    Get the role based permissions from the general settings.
    """
    role_based_permissions = cast(
        Optional[List[RoleBasedPermissions]],
        general_settings.get("role_permissions", []),
    )
    if role_based_permissions is None:
        return None

    for role_based_permission in role_based_permissions:
        if role_based_permission.role == rbac_role:
            return getattr(role_based_permission, key)

    return None


def get_role_based_models(
    rbac_role: RBAC_ROLES,
    general_settings: dict,
) -> Optional[List[str]]:
    """
    Get the models allowed for a user role.

    Used by JWT Auth.
    """
    return _get_role_based_permissions(
        rbac_role=rbac_role,
        general_settings=general_settings,
        key="models",
    )


def get_role_based_routes(
    rbac_role: RBAC_ROLES,
    general_settings: dict,
) -> Optional[List[str]]:
    """
    Get the routes allowed for a user role.
    """
    return _get_role_based_permissions(
        rbac_role=rbac_role,
        general_settings=general_settings,
        key="routes",
    )


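# Illustrative sketch (not executed): `general_settings["role_permissions"]`
# holds RoleBasedPermissions entries. The field names follow this module's
# usage (`role`, `models`, `routes`); the role/model values are assumptions.
#
#     settings = {
#         "role_permissions": [
#             RoleBasedPermissions(
#                 role=LitellmUserRoles.INTERNAL_USER,
#                 models=["gpt-4o-mini"],
#                 routes=["/chat/completions"],
#             )
#         ]
#     }
#     get_role_based_models(rbac_role=LitellmUserRoles.INTERNAL_USER, general_settings=settings)
#     # -> ["gpt-4o-mini"]

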
async def _get_fuzzy_user_object(
    prisma_client: PrismaClient,
    sso_user_id: Optional[str] = None,
    user_email: Optional[str] = None,
) -> Optional[LiteLLM_UserTable]:
    """
    Checks if an sso user is in the db.

    Called when a user_id match is not found in the db.

    - Check if sso_user_id is user_id in db
    - Check if sso_user_id is sso_user_id in db
    - Check if user_email is user_email in db
    - If not, create a new user with user_email and sso_user_id and user_id = sso_user_id
    """
    response = None
    if sso_user_id is not None:
        response = await prisma_client.db.litellm_usertable.find_unique(
            where={"sso_user_id": sso_user_id},
            include={"organization_memberships": True},
        )

    if response is None and user_email is not None:
        response = await prisma_client.db.litellm_usertable.find_first(
            where={"user_email": user_email},
            include={"organization_memberships": True},
        )

        # backfill the sso_user_id on the matched row, in the background
        if response is not None and sso_user_id is not None:
            asyncio.create_task(
                prisma_client.db.litellm_usertable.update(
                    where={"user_id": response.user_id},
                    data={"sso_user_id": sso_user_id},
                )
            )

    return response


@log_db_metrics
async def get_user_object(
    user_id: Optional[str],
    prisma_client: Optional[PrismaClient],
    user_api_key_cache: DualCache,
    user_id_upsert: bool,
    parent_otel_span: Optional[Span] = None,
    proxy_logging_obj: Optional[ProxyLogging] = None,
    sso_user_id: Optional[str] = None,
    user_email: Optional[str] = None,
) -> Optional[LiteLLM_UserTable]:
    """
    - Check if user_id in proxy User Table
    - if valid, return LiteLLM_UserTable object with defined limits
    - if not, then raise an error
    """
    if user_id is None:
        return None

    # check if in cache
    cached_user_obj = await user_api_key_cache.async_get_cache(key=user_id)
    if cached_user_obj is not None:
        if isinstance(cached_user_obj, dict):
            return LiteLLM_UserTable(**cached_user_obj)
        elif isinstance(cached_user_obj, LiteLLM_UserTable):
            return cached_user_obj

    # else, check db
    if prisma_client is None:
        raise Exception("No db connected")
    try:
        db_access_time_key = "user_id:{}".format(user_id)
        should_check_db = _should_check_db(
            key=db_access_time_key,
            last_db_access_time=last_db_access_time,
            db_cache_expiry=db_cache_expiry,
        )

        if should_check_db:
            response = await prisma_client.db.litellm_usertable.find_unique(
                where={"user_id": user_id}, include={"organization_memberships": True}
            )

            if response is None:
                response = await _get_fuzzy_user_object(
                    prisma_client=prisma_client,
                    sso_user_id=sso_user_id,
                    user_email=user_email,
                )

        else:
            response = None

        if response is None:
            if user_id_upsert:
                response = await prisma_client.db.litellm_usertable.create(
                    data={"user_id": user_id},
                    include={"organization_memberships": True},
                )
            else:
                raise Exception

        if (
            response.organization_memberships is not None
            and len(response.organization_memberships) > 0
        ):
            # dump each organization membership to a typed object
            _dumped_memberships = [
                LiteLLM_OrganizationMembershipTable(**membership.model_dump())
                for membership in response.organization_memberships
                if membership is not None
            ]
            response.organization_memberships = _dumped_memberships

        _response = LiteLLM_UserTable(**dict(response))
        response_dict = _response.model_dump()

        # save the user object to cache
        await user_api_key_cache.async_set_cache(key=user_id, value=response_dict)

        # save the db read time
        _update_last_db_access_time(
            key=db_access_time_key,
            value=response_dict,
            last_db_access_time=last_db_access_time,
        )

        return _response
    except Exception as e:  # user not in db
        raise ValueError(
            f"User doesn't exist in db. 'user_id'={user_id}. Create user via `/user/new` call. Got error - {e}"
        )


async def _cache_management_object(
    key: str,
    value: BaseModel,
    user_api_key_cache: DualCache,
    proxy_logging_obj: Optional[ProxyLogging],
):
    await user_api_key_cache.async_set_cache(key=key, value=value)


async def _cache_team_object(
    team_id: str,
    team_table: LiteLLM_TeamTableCachedObj,
    user_api_key_cache: DualCache,
    proxy_logging_obj: Optional[ProxyLogging],
):
    key = "team_id:{}".format(team_id)

    # stamp the refresh time, so callers can tell how stale the cached object is
    team_table.last_refreshed_at = time.time()

    await _cache_management_object(
        key=key,
        value=team_table,
        user_api_key_cache=user_api_key_cache,
        proxy_logging_obj=proxy_logging_obj,
    )


async def _cache_key_object(
    hashed_token: str,
    user_api_key_obj: UserAPIKeyAuth,
    user_api_key_cache: DualCache,
    proxy_logging_obj: Optional[ProxyLogging],
):
    key = hashed_token

    # stamp the refresh time, so callers can tell how stale the cached object is
    user_api_key_obj.last_refreshed_at = time.time()

    await _cache_management_object(
        key=key,
        value=user_api_key_obj,
        user_api_key_cache=user_api_key_cache,
        proxy_logging_obj=proxy_logging_obj,
    )


async def _delete_cache_key_object(
    hashed_token: str,
    user_api_key_cache: DualCache,
    proxy_logging_obj: Optional[ProxyLogging],
):
    key = hashed_token

    # delete from the local in-memory cache
    user_api_key_cache.delete_cache(key=key)

    # delete from the internal usage cache (in-memory + redis, if configured)
    if proxy_logging_obj is not None:
        await proxy_logging_obj.internal_usage_cache.dual_cache.async_delete_cache(
            key=key
        )


@log_db_metrics
async def _get_team_db_check(team_id: str, prisma_client: PrismaClient):
    return await prisma_client.db.litellm_teamtable.find_unique(
        where={"team_id": team_id}
    )


async def _get_team_object_from_db(team_id: str, prisma_client: PrismaClient):
    return await prisma_client.db.litellm_teamtable.find_unique(
        where={"team_id": team_id}
    )


async def _get_team_object_from_user_api_key_cache(
    team_id: str,
    prisma_client: PrismaClient,
    user_api_key_cache: DualCache,
    last_db_access_time: LimitedSizeOrderedDict,
    db_cache_expiry: int,
    proxy_logging_obj: Optional[ProxyLogging],
    key: str,
) -> LiteLLM_TeamTableCachedObj:
    db_access_time_key = key
    should_check_db = _should_check_db(
        key=db_access_time_key,
        last_db_access_time=last_db_access_time,
        db_cache_expiry=db_cache_expiry,
    )
    if should_check_db:
        response = await _get_team_db_check(
            team_id=team_id, prisma_client=prisma_client
        )
    else:
        response = None

    if response is None:
        raise Exception

    _response = LiteLLM_TeamTableCachedObj(**response.dict())

    # save the team object to cache
    await _cache_team_object(
        team_id=team_id,
        team_table=_response,
        user_api_key_cache=user_api_key_cache,
        proxy_logging_obj=proxy_logging_obj,
    )

    # save the db read time
    _update_last_db_access_time(
        key=db_access_time_key,
        value=_response,
        last_db_access_time=last_db_access_time,
    )

    return _response


async def _get_team_object_from_cache(
    key: str,
    proxy_logging_obj: Optional[ProxyLogging],
    user_api_key_cache: DualCache,
    parent_otel_span: Optional[Span],
) -> Optional[LiteLLM_TeamTableCachedObj]:
    cached_team_obj: Optional[LiteLLM_TeamTableCachedObj] = None

    # check the internal usage cache (in-memory + redis) first
    if (
        proxy_logging_obj is not None
        and proxy_logging_obj.internal_usage_cache.dual_cache
    ):
        cached_team_obj = (
            await proxy_logging_obj.internal_usage_cache.dual_cache.async_get_cache(
                key=key, parent_otel_span=parent_otel_span
            )
        )

    if cached_team_obj is None:
        cached_team_obj = await user_api_key_cache.async_get_cache(key=key)

    if cached_team_obj is not None:
        if isinstance(cached_team_obj, dict):
            return LiteLLM_TeamTableCachedObj(**cached_team_obj)
        elif isinstance(cached_team_obj, LiteLLM_TeamTableCachedObj):
            return cached_team_obj

    return None


async def get_team_object(
    team_id: str,
    prisma_client: Optional[PrismaClient],
    user_api_key_cache: DualCache,
    parent_otel_span: Optional[Span] = None,
    proxy_logging_obj: Optional[ProxyLogging] = None,
    check_cache_only: Optional[bool] = None,
    check_db_only: Optional[bool] = None,
) -> LiteLLM_TeamTableCachedObj:
    """
    - Check if team_id in proxy Team Table
    - if valid, return LiteLLM_TeamTable object with defined limits
    - if not, then raise an error

    Raises:
    - Exception: if team doesn't exist in db or cache
    """
    if prisma_client is None:
        raise Exception(
            "No DB Connected. See - https://docs.litellm.ai/docs/proxy/virtual_keys"
        )

    # check if in cache
    key = "team_id:{}".format(team_id)

    if not check_db_only:
        cached_team_obj = await _get_team_object_from_cache(
            key=key,
            proxy_logging_obj=proxy_logging_obj,
            user_api_key_cache=user_api_key_cache,
            parent_otel_span=parent_otel_span,
        )

        if cached_team_obj is not None:
            return cached_team_obj

        if check_cache_only:
            raise Exception(
                f"Team doesn't exist in cache + check_cache_only=True. Team={team_id}."
            )

    # else, check db
    try:
        return await _get_team_object_from_user_api_key_cache(
            team_id=team_id,
            prisma_client=prisma_client,
            user_api_key_cache=user_api_key_cache,
            proxy_logging_obj=proxy_logging_obj,
            last_db_access_time=last_db_access_time,
            db_cache_expiry=db_cache_expiry,
            key=key,
        )
    except Exception:
        raise Exception(
            f"Team doesn't exist in db. Team={team_id}. Create team via `/team/new` call."
        )


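# Illustrative sketch (not executed); the team_id is an assumption.
#
#     team = await get_team_object(
#         team_id="team-1",
#         prisma_client=prisma_client,
#         user_api_key_cache=user_api_key_cache,
#         check_db_only=True,   # skip the cache read, force a fresh DB check
#     )
#     # -> LiteLLM_TeamTableCachedObj; raises if the team doesn't exist

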
@log_db_metrics
async def get_key_object(
    hashed_token: str,
    prisma_client: Optional[PrismaClient],
    user_api_key_cache: DualCache,
    parent_otel_span: Optional[Span] = None,
    proxy_logging_obj: Optional[ProxyLogging] = None,
    check_cache_only: Optional[bool] = None,
) -> UserAPIKeyAuth:
    """
    - Check if the hashed token is in the proxy's key table
    - if valid, return UserAPIKeyAuth object with defined limits
    - if not, then raise an error
    """
    if prisma_client is None:
        raise Exception(
            "No DB Connected. See - https://docs.litellm.ai/docs/proxy/virtual_keys"
        )

    # check if in cache
    key = hashed_token

    cached_key_obj: Optional[UserAPIKeyAuth] = await user_api_key_cache.async_get_cache(
        key=key
    )

    if cached_key_obj is not None:
        if isinstance(cached_key_obj, dict):
            return UserAPIKeyAuth(**cached_key_obj)
        elif isinstance(cached_key_obj, UserAPIKeyAuth):
            return cached_key_obj

    if check_cache_only:
        raise Exception(
            f"Key doesn't exist in cache + check_cache_only=True. key={key}."
        )

    # else, check db
    try:
        _valid_token: Optional[BaseModel] = await prisma_client.get_data(
            token=hashed_token,
            table_name="combined_view",
            parent_otel_span=parent_otel_span,
            proxy_logging_obj=proxy_logging_obj,
        )

        if _valid_token is None:
            raise Exception

        _response = UserAPIKeyAuth(**_valid_token.model_dump(exclude_none=True))

        # save the key object to cache
        await _cache_key_object(
            hashed_token=hashed_token,
            user_api_key_obj=_response,
            user_api_key_cache=user_api_key_cache,
            proxy_logging_obj=proxy_logging_obj,
        )

        return _response
    except DB_CONNECTION_ERROR_TYPES as e:
        return await _handle_failed_db_connection_for_get_key_object(e=e)
    except Exception:
        traceback.print_exc()
        raise Exception(
            f"Key doesn't exist in db. key={hashed_token}. Create key via `/key/generate` call."
        )


async def _handle_failed_db_connection_for_get_key_object(
    e: Exception,
) -> UserAPIKeyAuth:
    """
    Handles httpx.ConnectError when reading a Virtual Key from the LiteLLM DB.

    Use this if you don't want failed DB queries to block LLM API requests.

    Returns:
    - UserAPIKeyAuth: if general_settings.allow_requests_on_db_unavailable is True

    Raises:
    - the original exception in all other cases
    """
    from litellm.proxy.proxy_server import (
        general_settings,
        litellm_proxy_admin_name,
        proxy_logging_obj,
    )

    # if the flag is on, let requests through even though the DB is unreachable
    if general_settings.get("allow_requests_on_db_unavailable", False) is True:
        # log this as a DB failure
        proxy_logging_obj.service_logging_obj.service_failure_hook(
            service=ServiceTypes.DB,
            call_type="get_key_object",
            error=e,
            duration=0.0,
        )

        return UserAPIKeyAuth(
            key_name="failed-to-connect-to-db",
            token="failed-to-connect-to-db",
            user_id=litellm_proxy_admin_name,
        )
    else:
        # raise the original exception, the request will fail
        raise e


@log_db_metrics
async def get_org_object(
    org_id: str,
    prisma_client: Optional[PrismaClient],
    user_api_key_cache: DualCache,
    parent_otel_span: Optional[Span] = None,
    proxy_logging_obj: Optional[ProxyLogging] = None,
) -> Optional[LiteLLM_OrganizationTable]:
    """
    - Check if org_id in proxy Org Table
    - if valid, return LiteLLM_OrganizationTable object
    - if not, then raise an error
    """
    if prisma_client is None:
        raise Exception(
            "No DB Connected. See - https://docs.litellm.ai/docs/proxy/virtual_keys"
        )

    # check if in cache (async_get_cache is a coroutine, so it must be awaited)
    cached_org_obj = await user_api_key_cache.async_get_cache(
        key="org_id:{}".format(org_id)
    )
    if cached_org_obj is not None:
        if isinstance(cached_org_obj, dict):
            return LiteLLM_OrganizationTable(**cached_org_obj)
        elif isinstance(cached_org_obj, LiteLLM_OrganizationTable):
            return cached_org_obj

    # else, check db
    try:
        response = await prisma_client.db.litellm_organizationtable.find_unique(
            where={"organization_id": org_id}
        )

        if response is None:
            raise Exception

        return response
    except Exception:
        raise Exception(
            f"Organization doesn't exist in db. Organization={org_id}. Create organization via `/organization/new` call."
        )


async def _can_object_call_model(
    model: str,
    llm_router: Optional[Router],
    models: List[str],
) -> Literal[True]:
    """
    Checks if a token can call a given model.

    Returns:
    - True: if token allowed to call model

    Raises:
    - Exception: if token not allowed to call model
    """
    if model in litellm.model_alias_map:
        model = litellm.model_alias_map[model]

    from collections import defaultdict

    access_groups: Dict[str, List[str]] = defaultdict(list)

    if llm_router:
        access_groups = llm_router.get_model_access_groups(model_name=model)
    if len(access_groups) > 0 and llm_router is not None:
        # check if the model is in any of the token's access groups
        for m in models:
            if m in access_groups:
                return True

    # Filter out entries that are access group names, not model names
    filtered_models = [m for m in models if m not in access_groups]

    verbose_proxy_logger.debug(f"model: {model}; allowed_models: {filtered_models}")

    if _model_matches_any_wildcard_pattern_in_list(
        model=model, allowed_model_list=filtered_models
    ):
        return True

    all_model_access: bool = False

    if (len(filtered_models) == 0 and len(models) == 0) or "*" in filtered_models:
        all_model_access = True

    if (
        model is not None
        and model not in filtered_models
        and all_model_access is False
    ):
        raise ProxyException(
            message=f"API Key not allowed to access model. This token can only access models={models}. Tried to access {model}",
            type=ProxyErrorTypes.key_model_access_denied,
            param="model",
            code=status.HTTP_401_UNAUTHORIZED,
        )

    verbose_proxy_logger.debug(
        f"filtered allowed_models: {filtered_models}; models: {models}"
    )
    return True


async def can_key_call_model(
    model: str,
    llm_model_list: Optional[list],
    valid_token: UserAPIKeyAuth,
    llm_router: Optional[litellm.Router],
) -> Literal[True]:
    """
    Checks if a token can call a given model.

    Returns:
    - True: if token allowed to call model

    Raises:
    - Exception: if token not allowed to call model
    """
    return await _can_object_call_model(
        model=model,
        llm_router=llm_router,
        models=valid_token.models,
    )


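# Illustrative sketch (not executed): a key restricted to ["openai/*"] can
# call "openai/gpt-4o" via the wildcard path. Model names and field values
# are assumptions.
#
#     await can_key_call_model(
#         model="openai/gpt-4o",
#         llm_model_list=None,
#         valid_token=UserAPIKeyAuth(models=["openai/*"]),
#         llm_router=None,
#     )  # -> True; raises ProxyException if the model isn't allowed

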
async def can_user_call_model(
    model: str,
    llm_router: Optional[Router],
    user_object: Optional[LiteLLM_UserTable],
) -> Literal[True]:
    if user_object is None:
        return True

    return await _can_object_call_model(
        model=model,
        llm_router=llm_router,
        models=user_object.models,
    )


async def is_valid_fallback_model(
    model: str,
    llm_router: Optional[Router],
    user_model: Optional[str],
) -> Literal[True]:
    """
    Try to route the fallback model request.

    Raises an error if it can't be routed.

    Helps catch invalid fallback models.
    """
    await route_request(
        data={
            "model": model,
            "messages": [{"role": "user", "content": "Who was Alexander?"}],
        },
        llm_router=llm_router,
        user_model=user_model,
        route_type="acompletion",
    )

    return True


async def _virtual_key_max_budget_check(
    valid_token: UserAPIKeyAuth,
    proxy_logging_obj: ProxyLogging,
    user_obj: Optional[LiteLLM_UserTable] = None,
):
    """
    Raises:
        BudgetExceededError if the token is over its max budget.
        Triggers a budget alert if the token is over its max budget.
    """
    if valid_token.spend is not None and valid_token.max_budget is not None:
        # collect information for alerting
        user_email = None
        if user_obj is not None:
            user_email = user_obj.user_email

        call_info = CallInfo(
            token=valid_token.token,
            spend=valid_token.spend,
            max_budget=valid_token.max_budget,
            user_id=valid_token.user_id,
            team_id=valid_token.team_id,
            user_email=user_email,
            key_alias=valid_token.key_alias,
        )
        asyncio.create_task(
            proxy_logging_obj.budget_alerts(
                type="token_budget",
                user_info=call_info,
            )
        )

        # check if the token is over its max budget
        if valid_token.spend >= valid_token.max_budget:
            raise litellm.BudgetExceededError(
                current_cost=valid_token.spend,
                max_budget=valid_token.max_budget,
            )


async def _virtual_key_soft_budget_check(
    valid_token: UserAPIKeyAuth,
    proxy_logging_obj: ProxyLogging,
):
    """
    Triggers a budget alert if the token is over its soft budget.
    """
    if valid_token.soft_budget and valid_token.spend >= valid_token.soft_budget:
        verbose_proxy_logger.debug(
            "Crossed Soft Budget for token %s, spend %s, soft_budget %s",
            valid_token.token,
            valid_token.spend,
            valid_token.soft_budget,
        )
        call_info = CallInfo(
            token=valid_token.token,
            spend=valid_token.spend,
            max_budget=valid_token.max_budget,
            soft_budget=valid_token.soft_budget,
            user_id=valid_token.user_id,
            team_id=valid_token.team_id,
            team_alias=valid_token.team_alias,
            user_email=None,
            key_alias=valid_token.key_alias,
        )
        asyncio.create_task(
            proxy_logging_obj.budget_alerts(
                type="soft_budget",
                user_info=call_info,
            )
        )


async def _team_max_budget_check(
    team_object: Optional[LiteLLM_TeamTable],
    valid_token: Optional[UserAPIKeyAuth],
    proxy_logging_obj: ProxyLogging,
):
    """
    Check if the team is over its max budget.

    Raises:
        BudgetExceededError if the team is over its max budget.
        Triggers a budget alert if the team is over its max budget.
    """
    if (
        team_object is not None
        and team_object.max_budget is not None
        and team_object.spend is not None
        and team_object.spend > team_object.max_budget
    ):
        if valid_token:
            call_info = CallInfo(
                token=valid_token.token,
                spend=team_object.spend,
                max_budget=team_object.max_budget,
                user_id=valid_token.user_id,
                team_id=valid_token.team_id,
                team_alias=valid_token.team_alias,
            )
            asyncio.create_task(
                proxy_logging_obj.budget_alerts(
                    type="team_budget",
                    user_info=call_info,
                )
            )

        raise litellm.BudgetExceededError(
            current_cost=team_object.spend,
            max_budget=team_object.max_budget,
            message=f"Budget has been exceeded! Team={team_object.team_id} Current cost: {team_object.spend}, Max budget: {team_object.max_budget}",
        )


def _team_model_access_check(
    model: Optional[str],
    team_object: Optional[LiteLLM_TeamTable],
    llm_router: Optional[Router],
):
    """
    Access check for team models.

    Raises:
        Exception if the team is not allowed to call the `model`
    """
    if (
        model is not None
        and team_object is not None
        and team_object.models is not None
        and len(team_object.models) > 0
        and model not in team_object.models
    ):
        if "all-proxy-models" in team_object.models or "*" in team_object.models:
            # the team has access to all models on the proxy
            pass
        # check if the team model is an access group
        elif (
            model_in_access_group(
                model=model, team_models=team_object.models, llm_router=llm_router
            )
            is True
        ):
            pass
        elif model and "*" in model:
            pass
        elif _model_matches_any_wildcard_pattern_in_list(
            model=model, allowed_model_list=team_object.models
        ):
            pass
        else:
            raise ProxyException(
                message=f"Team not allowed to access model. Team={team_object.team_id}, Model={model}. Allowed team models = {team_object.models}",
                type=ProxyErrorTypes.team_model_access_denied,
                param="model",
                code=status.HTTP_401_UNAUTHORIZED,
            )


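# Illustrative sketch (not executed): team/model values are assumptions.
#
#     team = LiteLLM_TeamTable(team_id="t1", models=["gpt-4o-mini", "bedrock/*"])
#     _team_model_access_check(model="bedrock/claude-3", team_object=team, llm_router=None)
#     # passes silently ("bedrock/*" wildcard matches)
#     _team_model_access_check(model="gpt-4o", team_object=team, llm_router=None)
#     # raises ProxyException (no exact, group, or wildcard match)

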
def is_model_allowed_by_pattern(model: str, allowed_model_pattern: str) -> bool:
    """
    Check if a model matches an allowed pattern.
    Handles exact matches and wildcard patterns.

    Args:
        model (str): The model to check (e.g., "bedrock/anthropic.claude-3-5-sonnet-20240620")
        allowed_model_pattern (str): The allowed pattern (e.g., "bedrock/*", "*", "openai/*")

    Returns:
        bool: True if model matches the pattern, False otherwise
    """
    if "*" in allowed_model_pattern:
        pattern = f"^{allowed_model_pattern.replace('*', '.*')}$"
        return bool(re.match(pattern, model))

    return False


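# Illustrative sketch (not executed): "*" is expanded to ".*" and the regex
# is anchored at both ends.
#
#     is_model_allowed_by_pattern("bedrock/anthropic.claude-3", "bedrock/*")  # -> True
#     is_model_allowed_by_pattern("openai/gpt-4o", "bedrock/*")               # -> False
#     is_model_allowed_by_pattern("gpt-4o", "gpt-4o")                         # -> False
#     # (no "*" in the pattern: exact matches are handled by the callers)

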
def _model_matches_any_wildcard_pattern_in_list(
    model: str, allowed_model_list: list
) -> bool:
    """
    Returns True if a model matches any wildcard pattern in a list.

    eg.
    - model=`bedrock/us.amazon.nova-micro-v1:0`, allowed_models=`bedrock/*` returns True
    - model=`bedrock/us.amazon.nova-micro-v1:0`, allowed_models=`bedrock/us.*` returns True
    - model=`bedrockzzzz/us.amazon.nova-micro-v1:0`, allowed_models=`bedrock/*` returns False
    """
    if any(
        _is_wildcard_pattern(allowed_model_pattern)
        and is_model_allowed_by_pattern(
            model=model, allowed_model_pattern=allowed_model_pattern
        )
        for allowed_model_pattern in allowed_model_list
    ):
        return True

    if any(
        _is_wildcard_pattern(allowed_model_pattern)
        and _model_custom_llm_provider_matches_wildcard_pattern(
            model=model, allowed_model_pattern=allowed_model_pattern
        )
        for allowed_model_pattern in allowed_model_list
    ):
        return True

    return False


def _model_custom_llm_provider_matches_wildcard_pattern(
    model: str, allowed_model_pattern: str
) -> bool:
    """
    Returns True for this scenario:
    - `model=gpt-4o`
    - `allowed_model_pattern=openai/*`

    or
    - `model=claude-3-5-sonnet-20240620`
    - `allowed_model_pattern=anthropic/*`
    """
    try:
        model, custom_llm_provider, _, _ = get_llm_provider(model=model)
    except Exception:
        return False

    return is_model_allowed_by_pattern(
        model=f"{custom_llm_provider}/{model}",
        allowed_model_pattern=allowed_model_pattern,
    )


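# Illustrative sketch (not executed): "gpt-4o" resolves to provider "openai"
# via get_llm_provider, so the prefixed name "openai/gpt-4o" matches "openai/*".
#
#     _model_custom_llm_provider_matches_wildcard_pattern(
#         model="gpt-4o", allowed_model_pattern="openai/*"
#     )  # -> True

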
def _is_wildcard_pattern(allowed_model_pattern: str) -> bool:
    """
    Returns True if the pattern is a wildcard pattern.

    Checks if `*` is in the pattern.
    """
    return "*" in allowed_model_pattern