|
import asyncio |
|
import copy |
|
import hashlib |
|
import importlib |
|
import json |
|
import os |
|
import smtplib |
|
import threading |
|
import time |
|
import traceback |
|
from datetime import datetime, timedelta |
|
from email.mime.multipart import MIMEMultipart |
|
from email.mime.text import MIMEText |
|
from typing import TYPE_CHECKING, Any, List, Literal, Optional, Union, overload |
|
|
|
from litellm.litellm_core_utils.duration_parser import duration_in_seconds |
|
from litellm.proxy._types import ( |
|
DB_CONNECTION_ERROR_TYPES, |
|
CommonProxyErrors, |
|
ProxyErrorTypes, |
|
ProxyException, |
|
) |
|
|
|
try: |
|
import backoff |
|
except ImportError: |
|
raise ImportError( |
|
"backoff is not installed. Please install it via 'pip install backoff'" |
|
) |
|
|
|
from fastapi import HTTPException, status |
|
|
|
import litellm |
|
import litellm.litellm_core_utils |
|
import litellm.litellm_core_utils.litellm_logging |
|
from litellm import EmbeddingResponse, ImageResponse, ModelResponse, Router |
|
from litellm._logging import verbose_proxy_logger |
|
from litellm._service_logger import ServiceLogging, ServiceTypes |
|
from litellm.caching.caching import DualCache, RedisCache |
|
from litellm.exceptions import RejectedRequestError |
|
from litellm.integrations.custom_guardrail import CustomGuardrail |
|
from litellm.integrations.custom_logger import CustomLogger |
|
from litellm.integrations.SlackAlerting.slack_alerting import SlackAlerting |
|
from litellm.integrations.SlackAlerting.utils import _add_langfuse_trace_id_to_alert |
|
from litellm.litellm_core_utils.litellm_logging import Logging |
|
from litellm.llms.custom_httpx.httpx_handler import HTTPHandler |
|
from litellm.proxy._types import ( |
|
AlertType, |
|
CallInfo, |
|
LiteLLM_VerificationTokenView, |
|
Member, |
|
ResetTeamBudgetRequest, |
|
UserAPIKeyAuth, |
|
) |
|
from litellm.proxy.db.create_views import ( |
|
create_missing_views, |
|
should_create_missing_views, |
|
) |
|
from litellm.proxy.db.log_db_metrics import log_db_metrics |
|
from litellm.proxy.db.prisma_client import PrismaWrapper |
|
from litellm.proxy.hooks.cache_control_check import _PROXY_CacheControlCheck |
|
from litellm.proxy.hooks.max_budget_limiter import _PROXY_MaxBudgetLimiter |
|
from litellm.proxy.hooks.parallel_request_limiter import ( |
|
_PROXY_MaxParallelRequestsHandler, |
|
) |
|
from litellm.proxy.litellm_pre_call_utils import LiteLLMProxyRequestSetup |
|
from litellm.secret_managers.main import str_to_bool |
|
from litellm.types.integrations.slack_alerting import DEFAULT_ALERT_TYPES |
|
from litellm.types.utils import CallTypes, LoggedLiteLLMParams |
|
|
|
if TYPE_CHECKING: |
|
from opentelemetry.trace import Span as _Span |
|
|
|
Span = _Span |
|
else: |
|
Span = Any |
|
|
|
|
|
def print_verbose(print_statement): |
|
""" |
|
Prints the given `print_statement` to the console if `litellm.set_verbose` is True. |
|
Also logs the `print_statement` at the debug level using `verbose_proxy_logger`. |
|
|
|
:param print_statement: The statement to be printed and logged. |
|
:type print_statement: Any |
|
""" |
|
    # traceback is imported at module level; format_exc() logs "NoneType: None"
    # when there is no active exception
    verbose_proxy_logger.debug("{}\n{}".format(print_statement, traceback.format_exc()))
|
if litellm.set_verbose: |
|
print(f"LiteLLM Proxy: {print_statement}") |
|
|
|
|
|
def safe_deep_copy(data): |
|
""" |
|
Safe Deep Copy |
|
|
|
    The LiteLLM request can contain objects that cannot be pickled / deep copied
    (e.g. the parent OpenTelemetry span), so a plain copy.deepcopy would fail.

    Use this function to safely deep copy the LiteLLM request.
|
""" |
|
if litellm.safe_memory_mode is True: |
|
return data |
|
|
|
    # Step 1: pop the non-picklable OTEL span off the request before deep copying
    litellm_parent_otel_span: Optional[Any] = None
|
if isinstance(data, dict): |
|
|
|
if "metadata" in data and "litellm_parent_otel_span" in data["metadata"]: |
|
litellm_parent_otel_span = data["metadata"].pop("litellm_parent_otel_span") |
|
new_data = copy.deepcopy(data) |
|
|
|
|
|
    # Step 2: restore the OTEL span on the original request after the deep copy
    if isinstance(data, dict) and litellm_parent_otel_span is not None:
|
if "metadata" in data: |
|
data["metadata"]["litellm_parent_otel_span"] = litellm_parent_otel_span |
|
return new_data |
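
# Example (illustrative sketch, not part of the proxy runtime): safe_deep_copy
# lets callers mutate a copy of the request without clobbering the original and
# without trying to deep copy the non-picklable OTEL span. `span` is a placeholder.
#
#   request_data = {"model": "gpt-4", "metadata": {"litellm_parent_otel_span": span}}
#   copied = safe_deep_copy(request_data)
#   copied["metadata"]["user"] = "sanitized"  # request_data is left untouched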
|
|
|
|
|
class InternalUsageCache: |
|
def __init__(self, dual_cache: DualCache): |
|
self.dual_cache: DualCache = dual_cache |
|
|
|
async def async_get_cache( |
|
self, |
|
key, |
|
litellm_parent_otel_span: Union[Span, None], |
|
local_only: bool = False, |
|
**kwargs, |
|
) -> Any: |
|
return await self.dual_cache.async_get_cache( |
|
key=key, |
|
local_only=local_only, |
|
parent_otel_span=litellm_parent_otel_span, |
|
**kwargs, |
|
) |
|
|
|
async def async_set_cache( |
|
self, |
|
key, |
|
value, |
|
litellm_parent_otel_span: Union[Span, None], |
|
local_only: bool = False, |
|
**kwargs, |
|
) -> None: |
|
return await self.dual_cache.async_set_cache( |
|
key=key, |
|
value=value, |
|
local_only=local_only, |
|
litellm_parent_otel_span=litellm_parent_otel_span, |
|
**kwargs, |
|
) |
|
|
|
async def async_batch_set_cache( |
|
self, |
|
cache_list: List, |
|
litellm_parent_otel_span: Union[Span, None], |
|
local_only: bool = False, |
|
**kwargs, |
|
) -> None: |
|
return await self.dual_cache.async_set_cache_pipeline( |
|
cache_list=cache_list, |
|
local_only=local_only, |
|
litellm_parent_otel_span=litellm_parent_otel_span, |
|
**kwargs, |
|
) |
|
|
|
async def async_batch_get_cache( |
|
self, |
|
keys: list, |
|
parent_otel_span: Optional[Span] = None, |
|
local_only: bool = False, |
|
): |
|
return await self.dual_cache.async_batch_get_cache( |
|
keys=keys, |
|
parent_otel_span=parent_otel_span, |
|
local_only=local_only, |
|
) |
|
|
|
async def async_increment_cache( |
|
self, |
|
key, |
|
value: float, |
|
litellm_parent_otel_span: Union[Span, None], |
|
local_only: bool = False, |
|
**kwargs, |
|
): |
|
return await self.dual_cache.async_increment_cache( |
|
key=key, |
|
value=value, |
|
local_only=local_only, |
|
parent_otel_span=litellm_parent_otel_span, |
|
**kwargs, |
|
) |
|
|
|
def set_cache( |
|
self, |
|
key, |
|
value, |
|
local_only: bool = False, |
|
**kwargs, |
|
) -> None: |
|
return self.dual_cache.set_cache( |
|
key=key, |
|
value=value, |
|
local_only=local_only, |
|
**kwargs, |
|
) |
|
|
|
def get_cache( |
|
self, |
|
key, |
|
local_only: bool = False, |
|
**kwargs, |
|
) -> Any: |
|
return self.dual_cache.get_cache( |
|
key=key, |
|
local_only=local_only, |
|
**kwargs, |
|
) |
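
    # Usage sketch (hypothetical key/values): the wrapper forwards to the
    # underlying DualCache while threading the OTEL span through for tracing.
    #
    #   usage_cache = InternalUsageCache(dual_cache=DualCache(default_in_memory_ttl=1))
    #   await usage_cache.async_set_cache(
    #       key="request_status:abc123",
    #       value="success",
    #       litellm_parent_otel_span=None,
    #       local_only=True,
    #       ttl=120,
    #   )
    #   status = await usage_cache.async_get_cache(
    #       key="request_status:abc123", litellm_parent_otel_span=None, local_only=True
    #   )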
|
|
|
|
|
|
|
class ProxyLogging: |
|
""" |
|
Logging/Custom Handlers for proxy. |
|
|
|
Implemented mainly to: |
|
- log successful/failed db read/writes |
|
- support the max parallel request integration |
|
""" |
|
|
|
def __init__( |
|
self, |
|
user_api_key_cache: DualCache, |
|
premium_user: bool = False, |
|
): |
|
|
|
self.call_details: dict = {} |
|
self.call_details["user_api_key_cache"] = user_api_key_cache |
|
self.internal_usage_cache: InternalUsageCache = InternalUsageCache( |
|
dual_cache=DualCache(default_in_memory_ttl=1) |
|
) |
|
self.max_parallel_request_limiter = _PROXY_MaxParallelRequestsHandler( |
|
self.internal_usage_cache |
|
) |
|
self.max_budget_limiter = _PROXY_MaxBudgetLimiter() |
|
self.cache_control_check = _PROXY_CacheControlCheck() |
|
self.alerting: Optional[List] = None |
|
self.alerting_threshold: float = 300 |
|
self.alert_types: List[AlertType] = DEFAULT_ALERT_TYPES |
|
self.alert_to_webhook_url: Optional[dict] = None |
|
self.slack_alerting_instance: SlackAlerting = SlackAlerting( |
|
alerting_threshold=self.alerting_threshold, |
|
alerting=self.alerting, |
|
internal_usage_cache=self.internal_usage_cache.dual_cache, |
|
) |
|
self.premium_user = premium_user |
|
self.service_logging_obj = ServiceLogging() |
|
|
|
def startup_event( |
|
self, |
|
llm_router: Optional[Router], |
|
redis_usage_cache: Optional[RedisCache], |
|
): |
|
"""Initialize logging and alerting on proxy startup""" |
|
|
|
self.slack_alerting_instance.update_values(llm_router=llm_router) |
|
|
|
|
|
        self.update_values(redis_cache=redis_usage_cache)
|
|
|
        self._init_litellm_callbacks(llm_router=llm_router)
|
|
|
if ( |
|
self.slack_alerting_instance is not None |
|
and "daily_reports" in self.slack_alerting_instance.alert_types |
|
): |
|
asyncio.create_task( |
|
self.slack_alerting_instance._run_scheduled_daily_report( |
|
llm_router=llm_router |
|
) |
|
) |
|
|
|
def update_values( |
|
self, |
|
alerting: Optional[List] = None, |
|
alerting_threshold: Optional[float] = None, |
|
redis_cache: Optional[RedisCache] = None, |
|
alert_types: Optional[List[AlertType]] = None, |
|
alerting_args: Optional[dict] = None, |
|
alert_to_webhook_url: Optional[dict] = None, |
|
): |
|
updated_slack_alerting: bool = False |
|
if alerting is not None: |
|
self.alerting = alerting |
|
updated_slack_alerting = True |
|
if alerting_threshold is not None: |
|
self.alerting_threshold = alerting_threshold |
|
updated_slack_alerting = True |
|
if alert_types is not None: |
|
self.alert_types = alert_types |
|
updated_slack_alerting = True |
|
if alert_to_webhook_url is not None: |
|
self.alert_to_webhook_url = alert_to_webhook_url |
|
updated_slack_alerting = True |
|
|
|
if updated_slack_alerting is True: |
|
self.slack_alerting_instance.update_values( |
|
alerting=self.alerting, |
|
alerting_threshold=self.alerting_threshold, |
|
alert_types=self.alert_types, |
|
alerting_args=alerting_args, |
|
alert_to_webhook_url=self.alert_to_webhook_url, |
|
) |
|
|
|
if self.alerting is not None and "slack" in self.alerting: |
|
|
|
|
|
if "daily_reports" in self.alert_types: |
|
litellm.logging_callback_manager.add_litellm_callback(self.slack_alerting_instance) |
|
litellm.logging_callback_manager.add_litellm_success_callback( |
|
self.slack_alerting_instance.response_taking_too_long_callback |
|
) |
|
|
|
if redis_cache is not None: |
|
self.internal_usage_cache.dual_cache.redis_cache = redis_cache |
|
|
|
def _init_litellm_callbacks(self, llm_router: Optional[Router] = None): |
|
litellm.logging_callback_manager.add_litellm_callback(self.max_parallel_request_limiter) |
|
litellm.logging_callback_manager.add_litellm_callback(self.max_budget_limiter) |
|
litellm.logging_callback_manager.add_litellm_callback(self.cache_control_check) |
|
litellm.logging_callback_manager.add_litellm_callback(self.service_logging_obj) |
|
for callback in litellm.callbacks: |
|
if isinstance(callback, str): |
|
callback = litellm.litellm_core_utils.litellm_logging._init_custom_logger_compatible_class( |
|
callback, |
|
internal_usage_cache=self.internal_usage_cache.dual_cache, |
|
llm_router=llm_router, |
|
) |
|
if callback is None: |
|
continue |
|
if callback not in litellm.input_callback: |
|
litellm.input_callback.append(callback) |
|
if callback not in litellm.success_callback: |
|
litellm.logging_callback_manager.add_litellm_success_callback(callback) |
|
if callback not in litellm.failure_callback: |
|
litellm.logging_callback_manager.add_litellm_failure_callback(callback) |
|
if callback not in litellm._async_success_callback: |
|
litellm.logging_callback_manager.add_litellm_async_success_callback(callback) |
|
if callback not in litellm._async_failure_callback: |
|
litellm.logging_callback_manager.add_litellm_async_failure_callback(callback) |
|
if callback not in litellm.service_callback: |
|
litellm.service_callback.append(callback) |
|
|
|
if ( |
|
len(litellm.input_callback) > 0 |
|
or len(litellm.success_callback) > 0 |
|
or len(litellm.failure_callback) > 0 |
|
): |
|
callback_list = list( |
|
set( |
|
litellm.input_callback |
|
+ litellm.success_callback |
|
+ litellm.failure_callback |
|
) |
|
) |
|
litellm.litellm_core_utils.litellm_logging.set_callbacks( |
|
callback_list=callback_list |
|
) |
|
|
|
async def update_request_status( |
|
self, litellm_call_id: str, status: Literal["success", "fail"] |
|
): |
|
|
|
if self.alerting is None: |
|
return |
|
|
|
|
|
        # cache TTL = alerting threshold + a 100s buffer, so the hanging-request
        # check can still read the status before the cache entry expires
        alerting_threshold: float = self.alerting_threshold
        alerting_threshold += 100
|
|
|
await self.internal_usage_cache.async_set_cache( |
|
key="request_status:{}".format(litellm_call_id), |
|
value=status, |
|
local_only=True, |
|
ttl=alerting_threshold, |
|
litellm_parent_otel_span=None, |
|
) |
|
|
|
async def process_pre_call_hook_response(self, response, data, call_type): |
|
if isinstance(response, Exception): |
|
raise response |
|
if isinstance(response, dict): |
|
return response |
|
if isinstance(response, str): |
|
if call_type in ["completion", "text_completion"]: |
|
raise RejectedRequestError( |
|
message=response, |
|
model=data.get("model", ""), |
|
llm_provider="", |
|
request_data=data, |
|
) |
|
else: |
|
raise HTTPException(status_code=400, detail={"error": response}) |
|
return data |
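
    # Contract sketch for pre-call hook responses (illustrative values):
    #   - dict  -> replaces the request body
    #   - str   -> rejects the request (RejectedRequestError for completion /
    #              text_completion routes, HTTP 400 for everything else)
    #   - Exception -> re-raised as-is
    #
    #   await proxy_logging_obj.process_pre_call_hook_response(
    #       response="flagged by moderation", data={"model": "gpt-4"}, call_type="embeddings"
    #   )  # -> raises HTTPException(status_code=400)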
|
|
|
|
|
@overload |
|
async def pre_call_hook( |
|
self, |
|
user_api_key_dict: UserAPIKeyAuth, |
|
data: None, |
|
call_type: Literal[ |
|
"completion", |
|
"text_completion", |
|
"embeddings", |
|
"image_generation", |
|
"moderation", |
|
"audio_transcription", |
|
"pass_through_endpoint", |
|
"rerank", |
|
], |
|
) -> None: |
|
pass |
|
|
|
@overload |
|
async def pre_call_hook( |
|
self, |
|
user_api_key_dict: UserAPIKeyAuth, |
|
data: dict, |
|
call_type: Literal[ |
|
"completion", |
|
"text_completion", |
|
"embeddings", |
|
"image_generation", |
|
"moderation", |
|
"audio_transcription", |
|
"pass_through_endpoint", |
|
"rerank", |
|
], |
|
) -> dict: |
|
pass |
|
|
|
async def pre_call_hook( |
|
self, |
|
user_api_key_dict: UserAPIKeyAuth, |
|
data: Optional[dict], |
|
call_type: Literal[ |
|
"completion", |
|
"text_completion", |
|
"embeddings", |
|
"image_generation", |
|
"moderation", |
|
"audio_transcription", |
|
"pass_through_endpoint", |
|
"rerank", |
|
], |
|
) -> Optional[dict]: |
|
""" |
|
        Allows users to modify/reject the incoming request to the proxy, without having to deal with parsing the request body.
|
|
|
Covers: |
|
1. /chat/completions |
|
2. /embeddings |
|
3. /image/generation |
|
""" |
|
verbose_proxy_logger.debug("Inside Proxy Logging Pre-call hook!") |
|
|
|
self._init_response_taking_too_long_task(data=data) |
|
|
|
if data is None: |
|
return None |
|
|
|
try: |
|
for callback in litellm.callbacks: |
|
|
|
_callback = None |
|
if isinstance(callback, str): |
|
_callback = litellm.litellm_core_utils.litellm_logging.get_custom_logger_compatible_class( |
|
callback |
|
) |
|
else: |
|
_callback = callback |
|
if _callback is not None and isinstance(_callback, CustomGuardrail): |
|
from litellm.types.guardrails import GuardrailEventHooks |
|
|
|
if ( |
|
_callback.should_run_guardrail( |
|
data=data, event_type=GuardrailEventHooks.pre_call |
|
) |
|
is not True |
|
): |
|
continue |
|
|
|
response = await _callback.async_pre_call_hook( |
|
user_api_key_dict=user_api_key_dict, |
|
cache=self.call_details["user_api_key_cache"], |
|
data=data, |
|
call_type=call_type, |
|
) |
|
if response is not None: |
|
data = await self.process_pre_call_hook_response( |
|
response=response, data=data, call_type=call_type |
|
) |
|
|
|
elif ( |
|
_callback is not None |
|
and isinstance(_callback, CustomLogger) |
|
and "async_pre_call_hook" in vars(_callback.__class__) |
|
and _callback.__class__.async_pre_call_hook |
|
!= CustomLogger.async_pre_call_hook |
|
): |
|
response = await _callback.async_pre_call_hook( |
|
user_api_key_dict=user_api_key_dict, |
|
cache=self.call_details["user_api_key_cache"], |
|
data=data, |
|
call_type=call_type, |
|
) |
|
if response is not None: |
|
data = await self.process_pre_call_hook_response( |
|
response=response, data=data, call_type=call_type |
|
) |
|
|
|
return data |
|
except Exception as e: |
|
raise e |
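
    # Example (hedged sketch): a minimal handler the loop above would invoke.
    # Only subclasses that actually override async_pre_call_hook are called.
    #
    #   class _ExampleHandler(CustomLogger):
    #       async def async_pre_call_hook(self, user_api_key_dict, cache, data, call_type):
    #           data.setdefault("metadata", {})["tagged_by"] = "example-handler"
    #           return data  # returning a dict replaces the request body
    #
    #   litellm.callbacks.append(_ExampleHandler())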
|
|
|
async def during_call_hook( |
|
self, |
|
data: dict, |
|
user_api_key_dict: UserAPIKeyAuth, |
|
call_type: Literal[ |
|
"completion", |
|
"embeddings", |
|
"image_generation", |
|
"moderation", |
|
"audio_transcription", |
|
], |
|
): |
|
""" |
|
Runs the CustomGuardrail's async_moderation_hook() |
|
""" |
|
for callback in litellm.callbacks: |
|
try: |
|
if isinstance(callback, CustomGuardrail): |
|
|
|
|
|
|
|
|
|
|
|
                    # V1 guardrails (backwards compatibility): "pre_call" means
                    # this guardrail already ran in the pre-call hook, so skip it
                    if callback.event_hook is None and hasattr(
|
callback, "moderation_check" |
|
): |
|
if callback.moderation_check == "pre_call": |
|
return |
|
                    else:
                        # V2 guardrails: respect the configured event hook
                        from litellm.types.guardrails import GuardrailEventHooks
|
|
|
if ( |
|
callback.should_run_guardrail( |
|
data=data, event_type=GuardrailEventHooks.during_call |
|
) |
|
is not True |
|
): |
|
continue |
|
await callback.async_moderation_hook( |
|
data=data, |
|
user_api_key_dict=user_api_key_dict, |
|
call_type=call_type, |
|
) |
|
except Exception as e: |
|
raise e |
|
return data |
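
    # Example (hedged sketch, assuming CustomGuardrail accepts an event_hook kwarg):
    # a guardrail wired for the during_call hook that the loop above would run.
    #
    #   from litellm.types.guardrails import GuardrailEventHooks
    #
    #   class _ExampleGuardrail(CustomGuardrail):
    #       def __init__(self):
    #           super().__init__(event_hook=GuardrailEventHooks.during_call)
    #
    #       async def async_moderation_hook(self, data, user_api_key_dict, call_type):
    #           if "forbidden" in str(data.get("messages", "")):
    #               raise HTTPException(status_code=400, detail={"error": "blocked"})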
|
|
|
async def failed_tracking_alert( |
|
self, |
|
error_message: str, |
|
failing_model: str, |
|
): |
|
if self.alerting is None: |
|
return |
|
|
|
if self.slack_alerting_instance: |
|
await self.slack_alerting_instance.failed_tracking_alert( |
|
error_message=error_message, |
|
failing_model=failing_model, |
|
) |
|
|
|
async def budget_alerts( |
|
self, |
|
type: Literal[ |
|
"token_budget", |
|
"user_budget", |
|
"soft_budget", |
|
"team_budget", |
|
"proxy_budget", |
|
"projected_limit_exceeded", |
|
], |
|
user_info: CallInfo, |
|
): |
|
if self.alerting is None: |
|
|
|
return |
|
await self.slack_alerting_instance.budget_alerts( |
|
type=type, |
|
user_info=user_info, |
|
) |
|
|
|
async def alerting_handler( |
|
self, |
|
message: str, |
|
level: Literal["Low", "Medium", "High"], |
|
alert_type: AlertType, |
|
request_data: Optional[dict] = None, |
|
): |
|
""" |
|
Alerting based on thresholds: - https://github.com/BerriAI/litellm/issues/1298 |
|
|
|
- Responses taking too long |
|
- Requests are hanging |
|
- Calls are failing |
|
- DB Read/Writes are failing |
|
- Proxy Close to max budget |
|
- Key Close to max budget |
|
|
|
Parameters: |
|
        level: str - Low|Medium|High - if calls might fail (Medium) or are failing (High); currently, no alerts are sent at 'Low'.
        message: str - what the alert is about
|
""" |
|
if self.alerting is None: |
|
return |
|
|
|
        current_time = datetime.now().strftime("%H:%M:%S")
|
_proxy_base_url = os.getenv("PROXY_BASE_URL", None) |
|
formatted_message = ( |
|
f"Level: `{level}`\nTimestamp: `{current_time}`\n\nMessage: {message}" |
|
) |
|
if _proxy_base_url is not None: |
|
formatted_message += f"\n\nProxy URL: `{_proxy_base_url}`" |
|
|
|
extra_kwargs = {} |
|
alerting_metadata = {} |
|
if request_data is not None: |
|
_url = await _add_langfuse_trace_id_to_alert(request_data=request_data) |
|
|
|
if _url is not None: |
|
extra_kwargs["🪢 Langfuse Trace"] = _url |
|
formatted_message += "\n\n🪢 Langfuse Trace: {}".format(_url) |
|
if ( |
|
"metadata" in request_data |
|
and request_data["metadata"].get("alerting_metadata", None) is not None |
|
and isinstance(request_data["metadata"]["alerting_metadata"], dict) |
|
): |
|
alerting_metadata = request_data["metadata"]["alerting_metadata"] |
|
for client in self.alerting: |
|
if client == "slack": |
|
await self.slack_alerting_instance.send_alert( |
|
message=message, |
|
level=level, |
|
alert_type=alert_type, |
|
user_info=None, |
|
alerting_metadata=alerting_metadata, |
|
**extra_kwargs, |
|
) |
|
elif client == "sentry": |
|
if litellm.utils.sentry_sdk_instance is not None: |
|
litellm.utils.sentry_sdk_instance.capture_message(formatted_message) |
|
else: |
|
raise Exception("Missing SENTRY_DSN from environment") |
|
|
|
async def failure_handler( |
|
self, original_exception, duration: float, call_type: str, traceback_str="" |
|
): |
|
""" |
|
Log failed db read/writes |
|
|
|
Currently only logs exceptions to sentry |
|
""" |
|
|
|
if AlertType.db_exceptions not in self.alert_types: |
|
return |
|
if isinstance(original_exception, HTTPException): |
|
if isinstance(original_exception.detail, str): |
|
error_message = original_exception.detail |
|
elif isinstance(original_exception.detail, dict): |
|
error_message = json.dumps(original_exception.detail) |
|
else: |
|
error_message = str(original_exception) |
|
else: |
|
error_message = str(original_exception) |
|
if isinstance(traceback_str, str): |
|
error_message += traceback_str[:1000] |
|
asyncio.create_task( |
|
self.alerting_handler( |
|
message=f"DB read/write call failed: {error_message}", |
|
level="High", |
|
alert_type=AlertType.db_exceptions, |
|
request_data={}, |
|
) |
|
) |
|
|
|
if hasattr(self, "service_logging_obj"): |
|
await self.service_logging_obj.async_service_failure_hook( |
|
service=ServiceTypes.DB, |
|
duration=duration, |
|
error=error_message, |
|
call_type=call_type, |
|
) |
|
|
|
if litellm.utils.capture_exception: |
|
litellm.utils.capture_exception(error=original_exception) |
|
|
|
async def post_call_failure_hook( |
|
self, |
|
request_data: dict, |
|
original_exception: Exception, |
|
user_api_key_dict: UserAPIKeyAuth, |
|
error_type: Optional[ProxyErrorTypes] = None, |
|
route: Optional[str] = None, |
|
): |
|
""" |
|
        Allows users to raise custom exceptions/log when a call fails, without having to deal with parsing the request body.
|
|
|
Covers: |
|
1. /chat/completions |
|
2. /embeddings |
|
3. /image/generation |
|
""" |
|
|
|
|
|
await self.update_request_status( |
|
litellm_call_id=request_data.get("litellm_call_id", ""), status="fail" |
|
) |
|
if AlertType.llm_exceptions in self.alert_types and not isinstance( |
|
original_exception, HTTPException |
|
): |
|
""" |
|
Just alert on LLM API exceptions. Do not alert on user errors |
|
|
|
Related issue - https://github.com/BerriAI/litellm/issues/3395 |
|
""" |
|
litellm_debug_info = getattr(original_exception, "litellm_debug_info", None) |
|
exception_str = str(original_exception) |
|
if litellm_debug_info is not None: |
|
exception_str += litellm_debug_info |
|
|
|
asyncio.create_task( |
|
self.alerting_handler( |
|
message=f"LLM API call failed: `{exception_str}`", |
|
level="High", |
|
alert_type=AlertType.llm_exceptions, |
|
request_data=request_data, |
|
) |
|
) |
|
|
|
|
|
if self._is_proxy_only_error( |
|
original_exception=original_exception, error_type=error_type |
|
): |
|
await self._handle_logging_proxy_only_error( |
|
request_data=request_data, |
|
user_api_key_dict=user_api_key_dict, |
|
route=route, |
|
original_exception=original_exception, |
|
) |
|
|
|
for callback in litellm.callbacks: |
|
try: |
|
_callback: Optional[CustomLogger] = None |
|
if isinstance(callback, str): |
|
_callback = litellm.litellm_core_utils.litellm_logging.get_custom_logger_compatible_class( |
|
callback |
|
) |
|
else: |
|
_callback = callback |
|
if _callback is not None and isinstance(_callback, CustomLogger): |
|
await _callback.async_post_call_failure_hook( |
|
request_data=request_data, |
|
user_api_key_dict=user_api_key_dict, |
|
original_exception=original_exception, |
|
) |
|
except Exception as e: |
|
raise e |
|
return |
|
|
|
def _is_proxy_only_error( |
|
self, |
|
original_exception: Exception, |
|
error_type: Optional[ProxyErrorTypes] = None, |
|
) -> bool: |
|
""" |
|
Return True if the error is a Proxy Only Error |
|
|
|
Prevents double logging of LLM API exceptions |
|
|
|
e.g should only return True for: |
|
- Authentication Errors from user_api_key_auth |
|
- HTTP HTTPException (rate limit errors) |
|
""" |
|
return isinstance(original_exception, HTTPException) or ( |
|
error_type == ProxyErrorTypes.auth_error |
|
) |
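
    # Illustrative truth table: HTTPExceptions (e.g. rate-limit / auth failures
    # raised by the proxy itself) and ProxyErrorTypes.auth_error count as
    # proxy-only errors; raw LLM API exceptions do not, since those are already
    # logged by the normal LLM logging path and would be double-counted here.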
|
|
|
async def _handle_logging_proxy_only_error( |
|
self, |
|
request_data: dict, |
|
user_api_key_dict: UserAPIKeyAuth, |
|
route: Optional[str] = None, |
|
original_exception: Optional[Exception] = None, |
|
): |
|
""" |
|
Handle logging for proxy only errors by calling `litellm_logging_obj.async_failure_handler` |
|
|
|
Is triggered when self._is_proxy_only_error() returns True |
|
""" |
|
litellm_logging_obj: Optional[Logging] = request_data.get( |
|
"litellm_logging_obj", None |
|
) |
|
if litellm_logging_obj is None: |
|
import uuid |
|
|
|
request_data["litellm_call_id"] = str(uuid.uuid4()) |
|
user_api_key_logged_metadata = ( |
|
LiteLLMProxyRequestSetup.get_sanitized_user_information_from_key( |
|
user_api_key_dict=user_api_key_dict |
|
) |
|
) |
|
|
|
litellm_logging_obj, data = litellm.utils.function_setup( |
|
original_function=route or "IGNORE_THIS", |
|
rules_obj=litellm.utils.Rules(), |
|
start_time=datetime.now(), |
|
**request_data, |
|
) |
|
if "metadata" not in request_data: |
|
request_data["metadata"] = {} |
|
request_data["metadata"].update(user_api_key_logged_metadata) |
|
|
|
if litellm_logging_obj is not None: |
|
|
|
_optional_params = {} |
|
_litellm_params = {} |
|
|
|
litellm_param_keys = LoggedLiteLLMParams.__annotations__.keys() |
|
for k, v in request_data.items(): |
|
if k in litellm_param_keys: |
|
_litellm_params[k] = v |
|
elif k != "model" and k != "user": |
|
_optional_params[k] = v |
|
|
|
litellm_logging_obj.update_environment_variables( |
|
model=request_data.get("model", ""), |
|
user=request_data.get("user", ""), |
|
optional_params=_optional_params, |
|
litellm_params=_litellm_params, |
|
) |
|
|
|
input: Union[list, str, dict] = "" |
|
if "messages" in request_data and isinstance( |
|
request_data["messages"], list |
|
): |
|
input = request_data["messages"] |
|
litellm_logging_obj.model_call_details["messages"] = input |
|
litellm_logging_obj.call_type = CallTypes.acompletion.value |
|
elif "prompt" in request_data and isinstance(request_data["prompt"], str): |
|
input = request_data["prompt"] |
|
litellm_logging_obj.model_call_details["prompt"] = input |
|
litellm_logging_obj.call_type = CallTypes.atext_completion.value |
|
elif "input" in request_data and isinstance(request_data["input"], list): |
|
input = request_data["input"] |
|
litellm_logging_obj.model_call_details["input"] = input |
|
litellm_logging_obj.call_type = CallTypes.aembedding.value |
|
litellm_logging_obj.pre_call( |
|
input=input, |
|
api_key="", |
|
) |
|
|
|
|
|
await litellm_logging_obj.async_failure_handler( |
|
exception=original_exception, |
|
traceback_exception=traceback.format_exc(), |
|
) |
|
|
|
threading.Thread( |
|
target=litellm_logging_obj.failure_handler, |
|
args=( |
|
original_exception, |
|
traceback.format_exc(), |
|
), |
|
).start() |
|
|
|
async def post_call_success_hook( |
|
self, |
|
data: dict, |
|
response: Union[ModelResponse, EmbeddingResponse, ImageResponse], |
|
user_api_key_dict: UserAPIKeyAuth, |
|
): |
|
""" |
|
Allow user to modify outgoing data |
|
|
|
Covers: |
|
1. /chat/completions |
|
""" |
|
|
|
for callback in litellm.callbacks: |
|
try: |
|
_callback: Optional[CustomLogger] = None |
|
if isinstance(callback, str): |
|
_callback = litellm.litellm_core_utils.litellm_logging.get_custom_logger_compatible_class( |
|
callback |
|
) |
|
else: |
|
_callback = callback |
|
|
|
if _callback is not None: |
|
|
|
|
|
                    if isinstance(_callback, CustomGuardrail):
|
|
|
from litellm.types.guardrails import GuardrailEventHooks |
|
|
|
if ( |
|
                            _callback.should_run_guardrail(
|
data=data, event_type=GuardrailEventHooks.post_call |
|
) |
|
is not True |
|
): |
|
continue |
|
|
|
                        await _callback.async_post_call_success_hook(
|
user_api_key_dict=user_api_key_dict, |
|
data=data, |
|
response=response, |
|
) |
|
|
|
|
|
|
|
elif isinstance(_callback, CustomLogger): |
|
await _callback.async_post_call_success_hook( |
|
user_api_key_dict=user_api_key_dict, |
|
data=data, |
|
response=response, |
|
) |
|
except Exception as e: |
|
raise e |
|
return response |
|
|
|
async def async_post_call_streaming_hook( |
|
self, |
|
response: Union[ModelResponse, EmbeddingResponse, ImageResponse], |
|
user_api_key_dict: UserAPIKeyAuth, |
|
): |
|
""" |
|
Allow user to modify outgoing streaming data -> per chunk |
|
|
|
Covers: |
|
1. /chat/completions |
|
""" |
|
response_str: Optional[str] = None |
|
if isinstance(response, ModelResponse): |
|
response_str = litellm.get_response_string(response_obj=response) |
|
if response_str is not None: |
|
for callback in litellm.callbacks: |
|
try: |
|
_callback: Optional[CustomLogger] = None |
|
if isinstance(callback, str): |
|
_callback = litellm.litellm_core_utils.litellm_logging.get_custom_logger_compatible_class( |
|
callback |
|
) |
|
else: |
|
_callback = callback |
|
if _callback is not None and isinstance(_callback, CustomLogger): |
|
await _callback.async_post_call_streaming_hook( |
|
user_api_key_dict=user_api_key_dict, response=response_str |
|
) |
|
except Exception as e: |
|
raise e |
|
return response |
|
|
|
async def post_call_streaming_hook( |
|
self, |
|
response: str, |
|
user_api_key_dict: UserAPIKeyAuth, |
|
): |
|
""" |
|
        - Check the outgoing streaming response up to that point
|
- Run through moderation check |
|
- Reject request if it fails moderation check |
|
""" |
|
new_response = copy.deepcopy(response) |
|
for callback in litellm.callbacks: |
|
try: |
|
if isinstance(callback, CustomLogger): |
|
await callback.async_post_call_streaming_hook( |
|
user_api_key_dict=user_api_key_dict, response=new_response |
|
) |
|
except Exception as e: |
|
raise e |
|
return new_response |
|
|
|
def _init_response_taking_too_long_task(self, data: Optional[dict] = None): |
|
""" |
|
Initialize the response taking too long task if user is using slack alerting |
|
|
|
Only run task if user is using slack alerting |
|
|
|
        This handles checking whether a request is hanging for too long
|
""" |
|
|
|
if ( |
|
self.slack_alerting_instance |
|
and self.slack_alerting_instance.alerting is not None |
|
): |
|
asyncio.create_task( |
|
self.slack_alerting_instance.response_taking_too_long(request_data=data) |
|
) |
|
|
|
|
|
|
|
|
|
|
|
def on_backoff(details): |
|
|
|
print_verbose(f"Backing off... this was attempt #{details['tries']}") |
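
# The @backoff.on_exception decorators below retry transient DB failures with
# exponential backoff; on_backoff only logs each retry attempt. Pattern sketch:
#
#   @backoff.on_exception(backoff.expo, Exception, max_tries=3, max_time=10, on_backoff=on_backoff)
#   async def flaky_db_call():
#       ...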
|
|
|
|
|
def jsonify_object(data: dict) -> dict: |
|
db_data = copy.deepcopy(data) |
|
|
|
for k, v in db_data.items(): |
|
if isinstance(v, dict): |
|
try: |
|
db_data[k] = json.dumps(v) |
|
except Exception: |
|
|
|
db_data[k] = "failed-to-serialize-json" |
|
return db_data |
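
# Example (illustrative values): nested dicts are stringified so Prisma can
# store them in Json/String columns; non-dict values pass through untouched.
#
#   jsonify_object({"metadata": {"team": "a"}, "spend": 1.5})
#   # -> {"metadata": '{"team": "a"}', "spend": 1.5}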
|
|
|
|
|
class PrismaClient: |
|
user_list_transactons: dict = {} |
|
end_user_list_transactons: dict = {} |
|
key_list_transactons: dict = {} |
|
team_list_transactons: dict = {} |
|
team_member_list_transactons: dict = {} |
|
org_list_transactons: dict = {} |
|
spend_log_transactions: List = [] |
|
|
|
def __init__( |
|
self, |
|
database_url: str, |
|
proxy_logging_obj: ProxyLogging, |
|
http_client: Optional[Any] = None, |
|
): |
|
|
|
self.proxy_logging_obj = proxy_logging_obj |
|
self.iam_token_db_auth: Optional[bool] = str_to_bool( |
|
os.getenv("IAM_TOKEN_DB_AUTH") |
|
) |
|
verbose_proxy_logger.debug("Creating Prisma Client..") |
|
try: |
|
from prisma import Prisma |
|
except Exception: |
|
raise Exception("Unable to find Prisma binaries.") |
|
if http_client is not None: |
|
self.db = PrismaWrapper( |
|
original_prisma=Prisma(http=http_client), |
|
iam_token_db_auth=( |
|
self.iam_token_db_auth |
|
if self.iam_token_db_auth is not None |
|
else False |
|
), |
|
) |
|
else: |
|
self.db = PrismaWrapper( |
|
original_prisma=Prisma(), |
|
iam_token_db_auth=( |
|
self.iam_token_db_auth |
|
if self.iam_token_db_auth is not None |
|
else False |
|
), |
|
) |
|
verbose_proxy_logger.debug("Success - Created Prisma Client") |
|
|
|
def hash_token(self, token: str): |
|
|
|
hashed_token = hashlib.sha256(token.encode()).hexdigest() |
|
|
|
return hashed_token |
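
    # Example (hypothetical key): keys are stored hashed, so a raw "sk-..." key
    # maps to its stored form deterministically, i.e.
    #
    #   client.hash_token("sk-1234") == hashlib.sha256(b"sk-1234").hexdigest()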
|
|
|
def jsonify_object(self, data: dict) -> dict: |
|
db_data = copy.deepcopy(data) |
|
|
|
for k, v in db_data.items(): |
|
if isinstance(v, dict): |
|
try: |
|
db_data[k] = json.dumps(v) |
|
except Exception: |
|
|
|
db_data[k] = "failed-to-serialize-json" |
|
return db_data |
|
|
|
@backoff.on_exception( |
|
backoff.expo, |
|
Exception, |
|
max_tries=3, |
|
max_time=10, |
|
on_backoff=on_backoff, |
|
) |
|
async def check_view_exists(self): |
|
""" |
|
        Checks if the LiteLLM_VerificationTokenView and MonthlyGlobalSpend views exist in the user's db.

        LiteLLM_VerificationTokenView: This view is used for getting the token + team data in user_api_key_auth

        MonthlyGlobalSpend: This view is used for the admin view to see global spend for this month

        If a view doesn't exist, it will be created.
|
""" |
|
|
|
|
|
|
|
|
|
try: |
|
expected_views = [ |
|
"LiteLLM_VerificationTokenView", |
|
"MonthlyGlobalSpend", |
|
"Last30dKeysBySpend", |
|
"Last30dModelsBySpend", |
|
"MonthlyGlobalSpendPerKey", |
|
"MonthlyGlobalSpendPerUserPerKey", |
|
"Last30dTopEndUsersSpend", |
|
"DailyTagSpend", |
|
] |
|
required_view = "LiteLLM_VerificationTokenView" |
|
expected_views_str = ", ".join(f"'{view}'" for view in expected_views) |
|
pg_schema = os.getenv("DATABASE_SCHEMA", "public") |
|
ret = await self.db.query_raw( |
|
f""" |
|
WITH existing_views AS ( |
|
SELECT viewname |
|
FROM pg_views |
|
WHERE schemaname = '{pg_schema}' AND viewname IN ( |
|
{expected_views_str} |
|
) |
|
) |
|
SELECT |
|
(SELECT COUNT(*) FROM existing_views) AS view_count, |
|
ARRAY_AGG(viewname) AS view_names |
|
FROM existing_views |
|
""" |
|
) |
|
expected_total_views = len(expected_views) |
|
if ret[0]["view_count"] == expected_total_views: |
|
verbose_proxy_logger.info("All necessary views exist!") |
|
return |
|
else: |
|
|
|
if ret[0]["view_names"] and required_view not in ret[0]["view_names"]: |
|
await self.health_check() |
|
await self.db.execute_raw( |
|
""" |
|
CREATE VIEW "LiteLLM_VerificationTokenView" AS |
|
SELECT |
|
v.*, |
|
t.spend AS team_spend, |
|
t.max_budget AS team_max_budget, |
|
t.tpm_limit AS team_tpm_limit, |
|
t.rpm_limit AS team_rpm_limit |
|
FROM "LiteLLM_VerificationToken" v |
|
LEFT JOIN "LiteLLM_TeamTable" t ON v.team_id = t.team_id; |
|
""" |
|
) |
|
|
|
verbose_proxy_logger.info( |
|
"LiteLLM_VerificationTokenView Created in DB!" |
|
) |
|
else: |
|
should_create_views = await should_create_missing_views(db=self.db) |
|
if should_create_views: |
|
await create_missing_views(db=self.db) |
|
else: |
|
|
|
|
|
ret_view_names_set = ( |
|
set(ret[0]["view_names"]) if ret[0]["view_names"] else set() |
|
) |
|
expected_views_set = set(expected_views) |
|
|
|
missing_views = expected_views_set - ret_view_names_set |
|
|
|
verbose_proxy_logger.warning( |
|
"\n\n\033[93mNot all views exist in db, needed for UI 'Usage' tab. Missing={}.\nRun 'create_views.py' from https://github.com/BerriAI/litellm/tree/main/db_scripts to create missing views.\033[0m\n".format( |
|
missing_views |
|
) |
|
) |
|
|
|
except Exception: |
|
raise |
|
return |
|
|
|
@log_db_metrics |
|
@backoff.on_exception( |
|
backoff.expo, |
|
Exception, |
|
max_tries=1, |
|
max_time=2, |
|
on_backoff=on_backoff, |
|
) |
|
async def get_generic_data( |
|
self, |
|
key: str, |
|
value: Any, |
|
table_name: Literal["users", "keys", "config", "spend"], |
|
): |
|
""" |
|
Generic implementation of get data |
|
""" |
|
start_time = time.time() |
|
try: |
|
if table_name == "users": |
|
response = await self.db.litellm_usertable.find_first( |
|
where={key: value} |
|
) |
|
elif table_name == "keys": |
|
response = await self.db.litellm_verificationtoken.find_first( |
|
where={key: value} |
|
) |
|
elif table_name == "config": |
|
response = await self.db.litellm_config.find_first( |
|
where={key: value} |
|
) |
|
elif table_name == "spend": |
|
                response = await self.db.litellm_spendlogs.find_first(
|
where={key: value} |
|
) |
|
return response |
|
except Exception as e: |
|
            error_msg = f"LiteLLM Prisma Client Exception get_generic_data: {str(e)}"
|
verbose_proxy_logger.error(error_msg) |
|
error_msg = error_msg + "\nException Type: {}".format(type(e)) |
|
error_traceback = error_msg + "\n" + traceback.format_exc() |
|
end_time = time.time() |
|
_duration = end_time - start_time |
|
asyncio.create_task( |
|
self.proxy_logging_obj.failure_handler( |
|
original_exception=e, |
|
duration=_duration, |
|
traceback_str=error_traceback, |
|
call_type="get_generic_data", |
|
) |
|
) |
|
|
|
raise e |
|
|
|
@backoff.on_exception( |
|
backoff.expo, |
|
Exception, |
|
max_tries=3, |
|
max_time=10, |
|
on_backoff=on_backoff, |
|
) |
|
@log_db_metrics |
|
async def get_data( |
|
self, |
|
token: Optional[Union[str, list]] = None, |
|
user_id: Optional[str] = None, |
|
user_id_list: Optional[list] = None, |
|
team_id: Optional[str] = None, |
|
team_id_list: Optional[list] = None, |
|
key_val: Optional[dict] = None, |
|
table_name: Optional[ |
|
Literal[ |
|
"user", |
|
"key", |
|
"config", |
|
"spend", |
|
"team", |
|
"user_notification", |
|
"combined_view", |
|
] |
|
] = None, |
|
query_type: Literal["find_unique", "find_all"] = "find_unique", |
|
expires: Optional[datetime] = None, |
|
reset_at: Optional[datetime] = None, |
|
offset: Optional[int] = None, |
|
        limit: Optional[int] = None,  # number of rows to fetch when query_type == "find_all"
|
parent_otel_span: Optional[Span] = None, |
|
proxy_logging_obj: Optional[ProxyLogging] = None, |
|
): |
|
args_passed_in = locals() |
|
start_time = time.time() |
|
hashed_token: Optional[str] = None |
|
try: |
|
response: Any = None |
|
if (token is not None and table_name is None) or ( |
|
table_name is not None and table_name == "key" |
|
): |
|
|
|
if token is not None: |
|
if isinstance(token, str): |
|
hashed_token = _hash_token_if_needed(token=token) |
|
verbose_proxy_logger.debug( |
|
f"PrismaClient: find_unique for token: {hashed_token}" |
|
) |
|
if query_type == "find_unique" and hashed_token is not None: |
|
if token is None: |
|
raise HTTPException( |
|
status_code=400, |
|
detail={"error": f"No token passed in. Token={token}"}, |
|
) |
|
response = await self.db.litellm_verificationtoken.find_unique( |
|
where={"token": hashed_token}, |
|
include={"litellm_budget_table": True}, |
|
) |
|
if response is not None: |
|
|
|
if response.expires is not None and isinstance( |
|
response.expires, datetime |
|
): |
|
response.expires = response.expires.isoformat() |
|
else: |
|
|
|
raise HTTPException( |
|
status_code=status.HTTP_401_UNAUTHORIZED, |
|
detail=f"Authentication Error: invalid user key - user key does not exist in db. User Key={token}", |
|
) |
|
elif query_type == "find_all" and user_id is not None: |
|
response = await self.db.litellm_verificationtoken.find_many( |
|
where={"user_id": user_id}, |
|
include={"litellm_budget_table": True}, |
|
) |
|
if response is not None and len(response) > 0: |
|
for r in response: |
|
if isinstance(r.expires, datetime): |
|
r.expires = r.expires.isoformat() |
|
elif query_type == "find_all" and team_id is not None: |
|
response = await self.db.litellm_verificationtoken.find_many( |
|
where={"team_id": team_id}, |
|
include={"litellm_budget_table": True}, |
|
) |
|
if response is not None and len(response) > 0: |
|
for r in response: |
|
if isinstance(r.expires, datetime): |
|
r.expires = r.expires.isoformat() |
|
elif ( |
|
query_type == "find_all" |
|
and expires is not None |
|
and reset_at is not None |
|
): |
|
response = await self.db.litellm_verificationtoken.find_many( |
|
where={ |
|
"OR": [ |
|
{"expires": None}, |
|
{"expires": {"gt": expires}}, |
|
], |
|
"budget_reset_at": {"lt": reset_at}, |
|
} |
|
) |
|
if response is not None and len(response) > 0: |
|
for r in response: |
|
if isinstance(r.expires, datetime): |
|
r.expires = r.expires.isoformat() |
|
elif query_type == "find_all": |
|
where_filter: dict = {} |
|
if token is not None: |
|
where_filter["token"] = {} |
|
if isinstance(token, str): |
|
token = _hash_token_if_needed(token=token) |
|
where_filter["token"]["in"] = [token] |
|
elif isinstance(token, list): |
|
hashed_tokens = [] |
|
for t in token: |
|
assert isinstance(t, str) |
|
if t.startswith("sk-"): |
|
new_token = self.hash_token(token=t) |
|
hashed_tokens.append(new_token) |
|
else: |
|
hashed_tokens.append(t) |
|
where_filter["token"]["in"] = hashed_tokens |
|
response = await self.db.litellm_verificationtoken.find_many( |
|
order={"spend": "desc"}, |
|
where=where_filter, |
|
include={"litellm_budget_table": True}, |
|
) |
|
if response is not None: |
|
return response |
|
else: |
|
|
|
raise HTTPException( |
|
status_code=status.HTTP_401_UNAUTHORIZED, |
|
detail="Authentication Error: invalid user key - token does not exist", |
|
) |
|
elif (user_id is not None and table_name is None) or ( |
|
table_name is not None and table_name == "user" |
|
): |
|
if query_type == "find_unique": |
|
if key_val is None: |
|
key_val = {"user_id": user_id} |
|
response = await self.db.litellm_usertable.find_unique( |
|
where=key_val |
|
) |
|
elif query_type == "find_all" and key_val is not None: |
|
response = await self.db.litellm_usertable.find_many( |
|
where=key_val |
|
) |
|
elif query_type == "find_all" and reset_at is not None: |
|
response = await self.db.litellm_usertable.find_many( |
|
where={ |
|
"budget_reset_at": {"lt": reset_at}, |
|
} |
|
) |
|
elif query_type == "find_all" and user_id_list is not None: |
|
response = await self.db.litellm_usertable.find_many( |
|
where={"user_id": {"in": user_id_list}} |
|
) |
|
elif query_type == "find_all": |
|
if expires is not None: |
|
response = await self.db.litellm_usertable.find_many( |
|
order={"spend": "desc"}, |
|
where={ |
|
"OR": [ |
|
{"expires": None}, |
|
{"expires": {"gt": expires}}, |
|
], |
|
}, |
|
) |
|
else: |
|
|
|
sql_query = """ |
|
SELECT |
|
u.*, |
|
json_agg(v.key_alias) AS key_aliases |
|
FROM |
|
"LiteLLM_UserTable" u |
|
LEFT JOIN "LiteLLM_VerificationToken" v ON u.user_id = v.user_id |
|
GROUP BY |
|
u.user_id |
|
ORDER BY u.spend DESC |
|
LIMIT $1 |
|
OFFSET $2 |
|
""" |
|
response = await self.db.query_raw(sql_query, limit, offset) |
|
return response |
|
elif table_name == "spend": |
|
verbose_proxy_logger.debug( |
|
"PrismaClient: get_data: table_name == 'spend'" |
|
) |
|
if key_val is not None: |
|
if query_type == "find_unique": |
|
response = await self.db.litellm_spendlogs.find_unique( |
|
where={ |
|
key_val["key"]: key_val["value"], |
|
} |
|
) |
|
elif query_type == "find_all": |
|
response = await self.db.litellm_spendlogs.find_many( |
|
where={ |
|
key_val["key"]: key_val["value"], |
|
} |
|
) |
|
return response |
|
else: |
|
response = await self.db.litellm_spendlogs.find_many( |
|
order={"startTime": "desc"}, |
|
) |
|
return response |
|
elif table_name == "team": |
|
if query_type == "find_unique": |
|
response = await self.db.litellm_teamtable.find_unique( |
|
where={"team_id": team_id}, |
|
include={"litellm_model_table": True}, |
|
) |
|
elif query_type == "find_all" and reset_at is not None: |
|
response = await self.db.litellm_teamtable.find_many( |
|
where={ |
|
"budget_reset_at": {"lt": reset_at}, |
|
} |
|
) |
|
elif query_type == "find_all" and user_id is not None: |
|
response = await self.db.litellm_teamtable.find_many( |
|
where={ |
|
"members": {"has": user_id}, |
|
}, |
|
include={"litellm_budget_table": True}, |
|
) |
|
elif query_type == "find_all" and team_id_list is not None: |
|
response = await self.db.litellm_teamtable.find_many( |
|
where={"team_id": {"in": team_id_list}} |
|
) |
|
elif query_type == "find_all" and team_id_list is None: |
|
response = await self.db.litellm_teamtable.find_many(take=20) |
|
return response |
|
elif table_name == "user_notification": |
|
if query_type == "find_unique": |
|
response = await self.db.litellm_usernotifications.find_unique( |
|
where={"user_id": user_id} |
|
) |
|
elif query_type == "find_all": |
|
response = await self.db.litellm_usernotifications.find_many() |
|
return response |
|
elif table_name == "combined_view": |
|
|
|
if token is not None: |
|
if isinstance(token, str): |
|
hashed_token = _hash_token_if_needed(token=token) |
|
verbose_proxy_logger.debug( |
|
f"PrismaClient: find_unique for token: {hashed_token}" |
|
) |
|
if query_type == "find_unique": |
|
if token is None: |
|
raise HTTPException( |
|
status_code=400, |
|
detail={"error": f"No token passed in. Token={token}"}, |
|
) |
|
|
|
sql_query = f""" |
|
SELECT |
|
v.*, |
|
t.spend AS team_spend, |
|
t.max_budget AS team_max_budget, |
|
t.tpm_limit AS team_tpm_limit, |
|
t.rpm_limit AS team_rpm_limit, |
|
t.models AS team_models, |
|
t.metadata AS team_metadata, |
|
t.blocked AS team_blocked, |
|
                        t.team_alias AS team_alias,
|
t.members_with_roles AS team_members_with_roles, |
|
t.organization_id as org_id, |
|
tm.spend AS team_member_spend, |
|
m.aliases AS team_model_aliases, |
|
                        -- LiteLLM_BudgetTable (b.*) columns
|
b.max_budget AS litellm_budget_table_max_budget, |
|
b.tpm_limit AS litellm_budget_table_tpm_limit, |
|
b.rpm_limit AS litellm_budget_table_rpm_limit, |
|
b.model_max_budget as litellm_budget_table_model_max_budget, |
|
b.soft_budget as litellm_budget_table_soft_budget |
|
FROM "LiteLLM_VerificationToken" AS v |
|
LEFT JOIN "LiteLLM_TeamTable" AS t ON v.team_id = t.team_id |
|
LEFT JOIN "LiteLLM_TeamMembership" AS tm ON v.team_id = tm.team_id AND tm.user_id = v.user_id |
|
LEFT JOIN "LiteLLM_ModelTable" m ON t.model_id = m.id |
|
LEFT JOIN "LiteLLM_BudgetTable" AS b ON v.budget_id = b.budget_id |
|
                    WHERE v.token = '{hashed_token if hashed_token is not None else token}'
|
""" |
|
|
|
print_verbose("sql_query being made={}".format(sql_query)) |
|
response = await self.db.query_first(query=sql_query) |
|
|
|
if response is not None: |
|
if response["team_models"] is None: |
|
response["team_models"] = [] |
|
if response["team_blocked"] is None: |
|
response["team_blocked"] = False |
|
|
|
team_member: Optional[Member] = None |
|
if ( |
|
response["team_members_with_roles"] is not None |
|
and response["user_id"] is not None |
|
): |
|
|
|
""" |
|
[ |
|
{ |
|
"role": "admin", |
|
"user_id": "default_user_id", |
|
"user_email": null |
|
}, |
|
{ |
|
"role": "user", |
|
"user_id": null, |
|
"user_email": "[email protected]" |
|
} |
|
] |
|
""" |
|
for tm in response["team_members_with_roles"]: |
|
if tm.get("user_id") is not None and response[ |
|
"user_id" |
|
] == tm.get("user_id"): |
|
team_member = Member(**tm) |
|
response["team_member"] = team_member |
|
response = LiteLLM_VerificationTokenView( |
|
**response, last_refreshed_at=time.time() |
|
) |
|
|
|
if response.expires is not None and isinstance( |
|
response.expires, datetime |
|
): |
|
response.expires = response.expires.isoformat() |
|
return response |
|
except Exception as e: |
|
            prisma_query_info = f"LiteLLM Prisma Client Exception: Error with `get_data`. Args passed in: {args_passed_in}"
|
error_msg = prisma_query_info + str(e) |
|
print_verbose(error_msg) |
|
error_traceback = error_msg + "\n" + traceback.format_exc() |
|
verbose_proxy_logger.debug(error_traceback) |
|
end_time = time.time() |
|
_duration = end_time - start_time |
|
|
|
asyncio.create_task( |
|
self.proxy_logging_obj.failure_handler( |
|
original_exception=e, |
|
duration=_duration, |
|
call_type="get_data", |
|
traceback_str=error_traceback, |
|
) |
|
) |
|
raise e |
|
|
|
def jsonify_team_object(self, db_data: dict): |
|
db_data = self.jsonify_object(data=db_data) |
|
if db_data.get("members_with_roles", None) is not None and isinstance( |
|
db_data["members_with_roles"], list |
|
): |
|
db_data["members_with_roles"] = json.dumps(db_data["members_with_roles"]) |
|
return db_data |
|
|
|
|
|
@backoff.on_exception( |
|
backoff.expo, |
|
Exception, |
|
max_tries=3, |
|
max_time=10, |
|
on_backoff=on_backoff, |
|
) |
|
async def insert_data( |
|
self, |
|
data: dict, |
|
table_name: Literal[ |
|
"user", "key", "config", "spend", "team", "user_notification" |
|
], |
|
): |
|
""" |
|
        Add a row to the given table. If it already exists, do nothing (upsert with an empty update).
|
""" |
|
start_time = time.time() |
|
try: |
|
verbose_proxy_logger.debug("PrismaClient: insert_data: %s", data) |
|
if table_name == "key": |
|
token = data["token"] |
|
hashed_token = self.hash_token(token=token) |
|
db_data = self.jsonify_object(data=data) |
|
db_data["token"] = hashed_token |
|
print_verbose( |
|
"PrismaClient: Before upsert into litellm_verificationtoken" |
|
) |
|
new_verification_token = await self.db.litellm_verificationtoken.upsert( |
|
where={ |
|
"token": hashed_token, |
|
}, |
|
data={ |
|
"create": {**db_data}, |
|
"update": {}, |
|
}, |
|
include={"litellm_budget_table": True}, |
|
) |
|
verbose_proxy_logger.info("Data Inserted into Keys Table") |
|
return new_verification_token |
|
elif table_name == "user": |
|
db_data = self.jsonify_object(data=data) |
|
try: |
|
new_user_row = await self.db.litellm_usertable.upsert( |
|
where={"user_id": data["user_id"]}, |
|
data={ |
|
"create": {**db_data}, |
|
"update": {}, |
|
}, |
|
) |
|
except Exception as e: |
|
if ( |
|
"Foreign key constraint failed on the field: `LiteLLM_UserTable_organization_id_fkey (index)`" |
|
in str(e) |
|
): |
|
raise HTTPException( |
|
status_code=400, |
|
detail={ |
|
"error": f"Foreign Key Constraint failed. Organization ID={db_data['organization_id']} does not exist in LiteLLM_OrganizationTable. Create via `/organization/new`." |
|
}, |
|
) |
|
raise e |
|
verbose_proxy_logger.info("Data Inserted into User Table") |
|
return new_user_row |
|
elif table_name == "team": |
|
db_data = self.jsonify_object(data=data) |
|
if db_data.get("members_with_roles", None) is not None and isinstance( |
|
db_data["members_with_roles"], list |
|
): |
|
db_data["members_with_roles"] = json.dumps( |
|
db_data["members_with_roles"] |
|
) |
|
new_team_row = await self.db.litellm_teamtable.upsert( |
|
where={"team_id": data["team_id"]}, |
|
data={ |
|
"create": {**db_data}, |
|
"update": {}, |
|
}, |
|
) |
|
verbose_proxy_logger.info("Data Inserted into Team Table") |
|
return new_team_row |
|
elif table_name == "config": |
|
""" |
|
For each param, |
|
get the existing table values |
|
|
|
Add the new values |
|
|
|
Update DB |
|
""" |
|
tasks = [] |
|
for k, v in data.items(): |
|
updated_data = v |
|
updated_data = json.dumps(updated_data) |
|
updated_table_row = self.db.litellm_config.upsert( |
|
where={"param_name": k}, |
|
data={ |
|
"create": {"param_name": k, "param_value": updated_data}, |
|
"update": {"param_value": updated_data}, |
|
}, |
|
) |
|
|
|
tasks.append(updated_table_row) |
|
await asyncio.gather(*tasks) |
|
verbose_proxy_logger.info("Data Inserted into Config Table") |
|
elif table_name == "spend": |
|
db_data = self.jsonify_object(data=data) |
|
new_spend_row = await self.db.litellm_spendlogs.upsert( |
|
where={"request_id": data["request_id"]}, |
|
data={ |
|
"create": {**db_data}, |
|
"update": {}, |
|
}, |
|
) |
|
verbose_proxy_logger.info("Data Inserted into Spend Table") |
|
return new_spend_row |
|
elif table_name == "user_notification": |
|
db_data = self.jsonify_object(data=data) |
|
new_user_notification_row = ( |
|
await self.db.litellm_usernotifications.upsert( |
|
where={"request_id": data["request_id"]}, |
|
data={ |
|
"create": {**db_data}, |
|
"update": {}, |
|
}, |
|
) |
|
) |
|
verbose_proxy_logger.info("Data Inserted into Model Request Table") |
|
return new_user_notification_row |
|
|
|
except Exception as e: |
|
            error_msg = f"LiteLLM Prisma Client Exception in insert_data: {str(e)}"
|
print_verbose(error_msg) |
|
error_traceback = error_msg + "\n" + traceback.format_exc() |
|
end_time = time.time() |
|
_duration = end_time - start_time |
|
asyncio.create_task( |
|
self.proxy_logging_obj.failure_handler( |
|
original_exception=e, |
|
duration=_duration, |
|
call_type="insert_data", |
|
traceback_str=error_traceback, |
|
) |
|
) |
|
raise e |
|
|
|
|
|
@backoff.on_exception( |
|
backoff.expo, |
|
Exception, |
|
max_tries=3, |
|
max_time=10, |
|
on_backoff=on_backoff, |
|
) |
|
async def update_data( |
|
self, |
|
token: Optional[str] = None, |
|
data: dict = {}, |
|
data_list: Optional[List] = None, |
|
user_id: Optional[str] = None, |
|
team_id: Optional[str] = None, |
|
query_type: Literal["update", "update_many"] = "update", |
|
table_name: Optional[Literal["user", "key", "config", "spend", "team"]] = None, |
|
update_key_values: Optional[dict] = None, |
|
update_key_values_custom_query: Optional[dict] = None, |
|
): |
|
""" |
|
Update existing data |
|
""" |
|
verbose_proxy_logger.debug( |
|
f"PrismaClient: update_data, table_name: {table_name}" |
|
) |
|
start_time = time.time() |
|
try: |
|
db_data = self.jsonify_object(data=data) |
|
if update_key_values is not None: |
|
update_key_values = self.jsonify_object(data=update_key_values) |
|
if token is not None: |
|
print_verbose(f"token: {token}") |
|
|
|
token = _hash_token_if_needed(token=token) |
|
db_data["token"] = token |
|
response = await self.db.litellm_verificationtoken.update( |
|
where={"token": token}, |
|
data={**db_data}, |
|
) |
|
verbose_proxy_logger.debug( |
|
"\033[91m" |
|
+ f"DB Token Table update succeeded {response}" |
|
+ "\033[0m" |
|
) |
|
_data: dict = {} |
|
if response is not None: |
|
try: |
|
_data = response.model_dump() |
|
except Exception: |
|
_data = response.dict() |
|
return {"token": token, "data": _data} |
|
            elif user_id is not None or (
                table_name is not None
                and table_name == "user"
                and query_type == "update"
            ):
|
""" |
|
                If the data contains spend + user info, update the user table with the spend info as well
|
""" |
|
if user_id is None: |
|
user_id = db_data["user_id"] |
|
if update_key_values is None: |
|
if update_key_values_custom_query is not None: |
|
update_key_values = update_key_values_custom_query |
|
else: |
|
update_key_values = db_data |
|
update_user_row = await self.db.litellm_usertable.upsert( |
|
where={"user_id": user_id}, |
|
data={ |
|
"create": {**db_data}, |
|
"update": { |
|
**update_key_values |
|
}, |
|
}, |
|
) |
|
verbose_proxy_logger.info( |
|
"\033[91m" |
|
+ f"DB User Table - update succeeded {update_user_row}" |
|
+ "\033[0m" |
|
) |
|
return {"user_id": user_id, "data": update_user_row} |
|
            elif team_id is not None or (
                table_name is not None
                and table_name == "team"
                and query_type == "update"
            ):
|
""" |
|
                If the data contains spend + team info, update the team table with the spend info as well
|
""" |
|
if team_id is None: |
|
team_id = db_data["team_id"] |
|
if update_key_values is None: |
|
update_key_values = db_data |
|
if "team_id" not in db_data and team_id is not None: |
|
db_data["team_id"] = team_id |
|
if "members_with_roles" in db_data and isinstance( |
|
db_data["members_with_roles"], list |
|
): |
|
db_data["members_with_roles"] = json.dumps( |
|
db_data["members_with_roles"] |
|
) |
|
if "members_with_roles" in update_key_values and isinstance( |
|
update_key_values["members_with_roles"], list |
|
): |
|
update_key_values["members_with_roles"] = json.dumps( |
|
update_key_values["members_with_roles"] |
|
) |
|
update_team_row = await self.db.litellm_teamtable.upsert( |
|
where={"team_id": team_id}, |
|
data={ |
|
"create": {**db_data}, |
|
"update": { |
|
**update_key_values |
|
}, |
|
}, |
|
) |
|
verbose_proxy_logger.info( |
|
"\033[91m" |
|
+ f"DB Team Table - update succeeded {update_team_row}" |
|
+ "\033[0m" |
|
) |
|
return {"team_id": team_id, "data": update_team_row} |
|
elif ( |
|
table_name is not None |
|
and table_name == "key" |
|
and query_type == "update_many" |
|
and data_list is not None |
|
and isinstance(data_list, list) |
|
): |
|
""" |
|
Batch write update queries |
|
""" |
|
batcher = self.db.batch_() |
|
for idx, t in enumerate(data_list): |
|
|
|
if t.token.startswith("sk-"): |
|
t.token = self.hash_token(token=t.token) |
|
try: |
|
data_json = self.jsonify_object( |
|
data=t.model_dump(exclude_none=True) |
|
) |
|
except Exception: |
|
data_json = self.jsonify_object(data=t.dict(exclude_none=True)) |
|
batcher.litellm_verificationtoken.update( |
|
where={"token": t.token}, |
|
data={**data_json}, |
|
) |
|
await batcher.commit() |
|
print_verbose( |
|
"\033[91m" + "DB Token Table update succeeded" + "\033[0m" |
|
) |
|
elif ( |
|
table_name is not None |
|
and table_name == "user" |
|
and query_type == "update_many" |
|
and data_list is not None |
|
and isinstance(data_list, list) |
|
): |
|
""" |
|
Batch write update queries |
|
""" |
|
batcher = self.db.batch_() |
|
for idx, user in enumerate(data_list): |
|
try: |
|
data_json = self.jsonify_object( |
|
data=user.model_dump(exclude_none=True) |
|
) |
|
except Exception: |
|
data_json = self.jsonify_object(data=user.dict()) |
|
batcher.litellm_usertable.upsert( |
|
where={"user_id": user.user_id}, |
|
data={ |
|
"create": {**data_json}, |
|
"update": { |
|
**data_json |
|
}, |
|
}, |
|
) |
|
await batcher.commit() |
|
verbose_proxy_logger.info( |
|
"\033[91m" + "DB User Table Batch update succeeded" + "\033[0m" |
|
) |
|
elif ( |
|
table_name is not None |
|
and table_name == "team" |
|
and query_type == "update_many" |
|
and data_list is not None |
|
and isinstance(data_list, list) |
|
): |
|
|
|
batcher = self.db.batch_() |
|
for idx, team in enumerate(data_list): |
|
try: |
|
data_json = self.jsonify_object( |
|
data=team.model_dump(exclude_none=True) |
|
) |
|
except Exception: |
|
data_json = self.jsonify_object( |
|
data=team.dict(exclude_none=True) |
|
) |
|
batcher.litellm_teamtable.upsert( |
|
where={"team_id": team.team_id}, |
|
data={ |
|
"create": {**data_json}, |
|
"update": { |
|
**data_json |
|
}, |
|
}, |
|
) |
|
await batcher.commit() |
|
verbose_proxy_logger.info( |
|
"\033[91m" + "DB Team Table Batch update succeeded" + "\033[0m" |
|
) |
|
|
|
except Exception as e: |
|
import traceback |
|
|
|
error_msg = f"LiteLLM Prisma Client Exception - update_data: {str(e)}" |
|
print_verbose(error_msg) |
|
error_traceback = error_msg + "\n" + traceback.format_exc() |
|
end_time = time.time() |
|
_duration = end_time - start_time |
|
asyncio.create_task( |
|
self.proxy_logging_obj.failure_handler( |
|
original_exception=e, |
|
duration=_duration, |
|
call_type="update_data", |
|
traceback_str=error_traceback, |
|
) |
|
) |
|
raise e |
|
|
|
|
|
@backoff.on_exception( |
|
backoff.expo, |
|
Exception, |
|
max_tries=3, |
|
max_time=10, |
|
on_backoff=on_backoff, |
|
) |
|
async def delete_data( |
|
self, |
|
tokens: Optional[List] = None, |
|
team_id_list: Optional[List] = None, |
|
table_name: Optional[Literal["user", "key", "config", "spend", "team"]] = None, |
|
user_id: Optional[str] = None, |
|
): |
|
""" |
|
Allow a user to delete key(s).

Ensure the user owns those keys, unless admin.
|
""" |
|
start_time = time.time() |
|
try: |
|
if tokens is not None and isinstance(tokens, List): |
|
hashed_tokens = [] |
|
for token in tokens: |
|
if isinstance(token, str) and token.startswith("sk-"): |
|
hashed_token = self.hash_token(token=token) |
|
else: |
|
hashed_token = token |
|
hashed_tokens.append(hashed_token) |
|
filter_query: dict = {} |
|
if user_id is not None: |
|
filter_query = { |
|
"AND": [{"token": {"in": hashed_tokens}}, {"user_id": user_id}] |
|
} |
|
else: |
|
filter_query = {"token": {"in": hashed_tokens}} |
|
|
|
deleted_tokens = await self.db.litellm_verificationtoken.delete_many( |
|
where=filter_query |
|
) |
|
verbose_proxy_logger.debug("deleted_tokens: %s", deleted_tokens) |
|
return {"deleted_keys": deleted_tokens} |
|
elif ( |
|
table_name == "team" |
|
and team_id_list is not None |
|
and isinstance(team_id_list, List) |
|
): |
|
|
|
await self.db.litellm_teamtable.delete_many( |
|
where={"team_id": {"in": team_id_list}} |
|
) |
|
return {"deleted_teams": team_id_list} |
|
elif ( |
|
table_name == "key" |
|
and team_id_list is not None |
|
and isinstance(team_id_list, List) |
|
): |
|
|
|
await self.db.litellm_verificationtoken.delete_many( |
|
where={"team_id": {"in": team_id_list}} |
|
) |
|
except Exception as e: |
|
import traceback |
|
|
|
error_msg = f"LiteLLM Prisma Client Exception - delete_data: {str(e)}" |
|
print_verbose(error_msg) |
|
error_traceback = error_msg + "\n" + traceback.format_exc() |
|
end_time = time.time() |
|
_duration = end_time - start_time |
|
asyncio.create_task( |
|
self.proxy_logging_obj.failure_handler( |
|
original_exception=e, |
|
duration=_duration, |
|
call_type="delete_data", |
|
traceback_str=error_traceback, |
|
) |
|
) |
|
raise e |
|
|
|
|
|
@backoff.on_exception( |
|
backoff.expo, |
|
Exception, |
|
max_tries=3, |
|
max_time=10, |
|
on_backoff=on_backoff, |
|
) |
|
async def connect(self): |
|
start_time = time.time() |
|
try: |
|
verbose_proxy_logger.debug( |
|
"PrismaClient: connect() called Attempting to Connect to DB" |
|
) |
|
if self.db.is_connected() is False: |
|
verbose_proxy_logger.debug( |
|
"PrismaClient: DB not connected, Attempting to Connect to DB" |
|
) |
|
await self.db.connect() |
|
except Exception as e: |
|
import traceback |
|
|
|
error_msg = f"LiteLLM Prisma Client Exception connect(): {str(e)}" |
|
print_verbose(error_msg) |
|
error_traceback = error_msg + "\n" + traceback.format_exc() |
|
end_time = time.time() |
|
_duration = end_time - start_time |
|
asyncio.create_task( |
|
self.proxy_logging_obj.failure_handler( |
|
original_exception=e, |
|
duration=_duration, |
|
call_type="connect", |
|
traceback_str=error_traceback, |
|
) |
|
) |
|
raise e |
|
|
|
|
|
@backoff.on_exception( |
|
backoff.expo, |
|
Exception, |
|
max_tries=3, |
|
max_time=10, |
|
on_backoff=on_backoff, |
|
) |
|
async def disconnect(self): |
|
start_time = time.time() |
|
try: |
|
await self.db.disconnect() |
|
except Exception as e: |
|
import traceback |
|
|
|
error_msg = f"LiteLLM Prisma Client Exception disconnect(): {str(e)}" |
|
print_verbose(error_msg) |
|
error_traceback = error_msg + "\n" + traceback.format_exc() |
|
end_time = time.time() |
|
_duration = end_time - start_time |
|
asyncio.create_task( |
|
self.proxy_logging_obj.failure_handler( |
|
original_exception=e, |
|
duration=_duration, |
|
call_type="disconnect", |
|
traceback_str=error_traceback, |
|
) |
|
) |
|
raise e |
|
|
|
async def health_check(self): |
|
""" |
|
Health check endpoint for the prisma client |
|
""" |
|
start_time = time.time() |
|
try: |
|
sql_query = "SELECT 1" |
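# A minimal connectivity probe: any successful result proves the DB is reachable.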
|
|
|
|
|
|
|
response = await self.db.query_raw(sql_query) |
|
return response |
|
except Exception as e: |
|
import traceback |
|
|
|
error_msg = f"LiteLLM Prisma Client Exception disconnect(): {str(e)}" |
|
print_verbose(error_msg) |
|
error_traceback = error_msg + "\n" + traceback.format_exc() |
|
end_time = time.time() |
|
_duration = end_time - start_time |
|
asyncio.create_task( |
|
self.proxy_logging_obj.failure_handler( |
|
original_exception=e, |
|
duration=_duration, |
|
call_type="health_check", |
|
traceback_str=error_traceback, |
|
) |
|
) |
|
raise e |
|
|
|
async def _get_spend_logs_row_count(self) -> int: |
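"""
Return an approximate row count for LiteLLM_SpendLogs, using the Postgres
planner estimate in pg_class.reltuples instead of a full COUNT(*) scan.
"""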
|
try: |
|
sql_query = """ |
|
SELECT reltuples::BIGINT |
|
FROM pg_class |
|
WHERE oid = '"LiteLLM_SpendLogs"'::regclass; |
|
""" |
|
result = await self.db.query_raw(query=sql_query) |
|
return result[0]["reltuples"] |
|
except Exception as e: |
|
verbose_proxy_logger.error( |
|
f"Error getting LiteLLM_SpendLogs row count: {e}" |
|
) |
|
return 0 |
|
|
|
async def _set_spend_logs_row_count_in_proxy_state(self) -> None: |
|
""" |
|
Set the `LiteLLM_SpendLogs` row count in proxy state.
|
|
|
This is used later to determine if we should run expensive UI Usage queries. |
|
""" |
|
from litellm.proxy.proxy_server import proxy_state |
|
|
|
_num_spend_logs_rows = await self._get_spend_logs_row_count() |
|
proxy_state.set_proxy_state_variable( |
|
variable_name="spend_logs_row_count", |
|
value=_num_spend_logs_rows, |
|
) |
|
|
|
|
|
|
|
def get_instance_fn(value: str, config_file_path: Optional[str] = None) -> Any: |
|
module_name = value |
|
instance_name = None |
|
try: |
|
|
|
parts = value.split(".") |
|
|
|
|
|
module_name = ".".join(parts[:-1]) |
|
instance_name = parts[-1] |
|
|
|
|
|
if config_file_path is not None: |
|
directory = os.path.dirname(config_file_path) |
|
module_file_path = os.path.join(directory, *module_name.split(".")) |
|
module_file_path += ".py" |
|
|
|
import importlib.util  # ensure the submodule is loaded; `import importlib` alone doesn't guarantee it

spec = importlib.util.spec_from_file_location(module_name, module_file_path)
|
if spec is None: |
|
raise ImportError( |
|
f"Could not find a module specification for {module_file_path}" |
|
) |
|
module = importlib.util.module_from_spec(spec) |
|
spec.loader.exec_module(module) |
|
else: |
|
|
|
module = importlib.import_module(module_name) |
|
|
|
|
|
instance = getattr(module, instance_name) |
|
|
|
return instance |
|
except ImportError as e: |
|
|
|
if instance_name and module_name: |
|
raise ImportError( |
|
f"Could not import {instance_name} from {module_name}" |
|
) from e |
|
else: |
|
raise e |
|
except Exception as e: |
|
raise e |
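# Illustrative usage (module and attribute names are hypothetical):
#   handler = get_instance_fn(
#       value="custom_callbacks.proxy_handler", config_file_path="proxy_config.yaml"
#   )
# This resolves ./custom_callbacks.py relative to the config file and returns
# its `proxy_handler` attribute.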
|
|
|
|
|
|
|
async def _cache_user_row(user_id: str, cache: DualCache, db: PrismaClient): |
|
""" |
|
Check if a user_id exists in cache, |
|
if not retrieve it. |
|
""" |
|
cache_key = f"{user_id}_user_api_key_user_id" |
|
response = cache.get_cache(key=cache_key) |
|
if response is None: |
|
user_row = await db.get_data(user_id=user_id) |
|
if user_row is not None: |
|
print_verbose(f"User Row: {user_row}, type = {type(user_row)}") |
|
if hasattr(user_row, "model_dump_json") and callable( |
|
getattr(user_row, "model_dump_json") |
|
): |
|
cache_value = user_row.model_dump_json() |
|
cache.set_cache( |
|
key=cache_key, value=cache_value, ttl=600 |
|
) |
|
return |
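# Cached user rows expire after ttl=600 seconds, bounding staleness to 10 minutes.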
|
|
|
|
|
async def send_email(receiver_email, subject, html): |
|
""" |
|
smtp_host, |
|
smtp_port, |
|
smtp_username, |
|
smtp_password, |
|
sender_name, |
|
sender_email, |
|
""" |
|
|
|
|
|
smtp_host = os.getenv("SMTP_HOST") |
|
smtp_port = int(os.getenv("SMTP_PORT", "587")) |
|
smtp_username = os.getenv("SMTP_USERNAME") |
|
smtp_password = os.getenv("SMTP_PASSWORD") |
|
sender_email = os.getenv("SMTP_SENDER_EMAIL", None) |
|
if sender_email is None: |
|
raise ValueError("Trying to use SMTP, but SMTP_SENDER_EMAIL is not set") |
|
|
|
|
|
email_message = MIMEMultipart() |
|
email_message["From"] = sender_email |
|
email_message["To"] = receiver_email |
|
email_message["Subject"] = subject |
|
verbose_proxy_logger.debug( |
|
"sending email from %s to %s", sender_email, receiver_email |
|
) |
|
|
|
if smtp_host is None: |
|
raise ValueError("Trying to use SMTP, but SMTP_HOST is not set") |
|
|
|
if smtp_username is None: |
|
raise ValueError("Trying to use SMTP, but SMTP_USERNAME is not set") |
|
|
|
if smtp_password is None: |
|
raise ValueError("Trying to use SMTP, but SMTP_PASSWORD is not set") |
|
|
|
|
|
email_message.attach(MIMEText(html, "html")) |
|
|
|
try: |
|
|
|
with smtplib.SMTP(smtp_host, smtp_port) as server: |
|
if os.getenv("SMTP_TLS", "True") != "False": |
|
server.starttls() |
|
|
|
|
|
server.login(smtp_username, smtp_password) |
|
|
|
|
|
server.send_message(email_message) |
|
|
|
except Exception as e: |
|
print_verbose("An error occurred while sending the email:" + str(e)) |
|
|
|
|
|
def hash_token(token: str): |
|
import hashlib |
|
|
|
|
|
hashed_token = hashlib.sha256(token.encode()).hexdigest() |
|
|
|
return hashed_token |
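# Example: hash_token("sk-my-secret") returns a 64-character hex SHA-256
# digest, so the DB stores only hashed keys, never raw API keys.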
|
|
|
|
|
def _hash_token_if_needed(token: str) -> str: |
|
""" |
|
Hash the token if it's a string and starts with "sk-" |
|
|
|
Else return the token as is |
|
""" |
|
if token.startswith("sk-"): |
|
return hash_token(token=token) |
|
else: |
|
return token |
|
|
|
|
|
async def reset_budget(prisma_client: PrismaClient): |
|
""" |
|
Gets all the non-expired keys for a db, which need spend to be reset |
|
|
|
Resets their spend |
|
|
|
Updates db |
|
""" |
|
if prisma_client is not None: |
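# Example: a key with budget_duration "30d" whose budget_reset_at has passed
# gets spend zeroed and budget_reset_at pushed forward by duration_in_seconds("30d").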
|
|
|
now = datetime.utcnow() |
|
keys_to_reset = await prisma_client.get_data( |
|
table_name="key", query_type="find_all", expires=now, reset_at=now |
|
) |
|
|
|
if keys_to_reset is not None and len(keys_to_reset) > 0: |
|
for key in keys_to_reset: |
|
key.spend = 0.0 |
|
duration_s = duration_in_seconds(duration=key.budget_duration) |
|
key.budget_reset_at = now + timedelta(seconds=duration_s) |
|
|
|
await prisma_client.update_data( |
|
query_type="update_many", data_list=keys_to_reset, table_name="key" |
|
) |
|
|
|
|
|
now = datetime.utcnow() |
|
users_to_reset = await prisma_client.get_data( |
|
table_name="user", query_type="find_all", reset_at=now |
|
) |
|
|
|
if users_to_reset is not None and len(users_to_reset) > 0: |
|
for user in users_to_reset: |
|
user.spend = 0.0 |
|
duration_s = duration_in_seconds(duration=user.budget_duration) |
|
user.budget_reset_at = now + timedelta(seconds=duration_s) |
|
|
|
await prisma_client.update_data( |
|
query_type="update_many", data_list=users_to_reset, table_name="user" |
|
) |
|
|
|
|
|
now = datetime.utcnow() |
|
teams_to_reset = await prisma_client.get_data( |
|
table_name="team", |
|
query_type="find_all", |
|
reset_at=now, |
|
) |
|
|
|
if teams_to_reset is not None and len(teams_to_reset) > 0: |
|
team_reset_requests = [] |
|
for team in teams_to_reset: |
|
duration_s = duration_in_seconds(duration=team.budget_duration) |
|
reset_team_budget_request = ResetTeamBudgetRequest( |
|
team_id=team.team_id, |
|
spend=0.0, |
|
budget_reset_at=now + timedelta(seconds=duration_s), |
|
updated_at=now, |
|
) |
|
team_reset_requests.append(reset_team_budget_request) |
|
await prisma_client.update_data( |
|
query_type="update_many", |
|
data_list=team_reset_requests, |
|
table_name="team", |
|
) |
|
|
|
|
|
class ProxyUpdateSpend: |
|
@staticmethod |
|
async def update_end_user_spend( |
|
n_retry_times: int, prisma_client: PrismaClient, proxy_logging_obj: ProxyLogging |
|
): |
|
for i in range(n_retry_times + 1): |
|
start_time = time.time() |
|
try: |
|
async with prisma_client.db.tx( |
|
timeout=timedelta(seconds=60) |
|
) as transaction: |
|
async with transaction.batch_() as batcher: |
|
for ( |
|
end_user_id, |
|
response_cost, |
|
) in prisma_client.end_user_list_transactons.items(): |
|
if litellm.max_end_user_budget is not None: |
|
pass |
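# placeholder branch: per-end-user max-budget enforcement is not applied here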
|
batcher.litellm_endusertable.upsert( |
|
where={"user_id": end_user_id}, |
|
data={ |
|
"create": { |
|
"user_id": end_user_id, |
|
"spend": response_cost, |
|
"blocked": False, |
|
}, |
|
"update": {"spend": {"increment": response_cost}}, |
|
}, |
|
) |
|
|
|
break |
|
except DB_CONNECTION_ERROR_TYPES as e: |
|
if i >= n_retry_times: |
|
_raise_failed_update_spend_exception( |
|
e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj |
|
) |
|
|
|
await asyncio.sleep(2**i) |
|
except Exception as e: |
|
_raise_failed_update_spend_exception( |
|
e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj |
|
) |
|
finally: |
|
prisma_client.end_user_list_transactons = {}  # always clear the queue, even if the flush failed
|
|
|
@staticmethod |
|
async def update_spend_logs( |
|
n_retry_times: int, |
|
prisma_client: PrismaClient, |
|
db_writer_client: Optional[HTTPHandler], |
|
proxy_logging_obj: ProxyLogging, |
|
): |
|
BATCH_SIZE = 100 |
|
MAX_LOGS_PER_INTERVAL = 1000
|
|
|
logs_to_process = prisma_client.spend_log_transactions[:MAX_LOGS_PER_INTERVAL] |
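# At most MAX_LOGS_PER_INTERVAL queued logs are written per invocation; the
# remainder stays in spend_log_transactions for the next flush.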
|
start_time = time.time() |
|
try: |
|
for i in range(n_retry_times + 1): |
|
try: |
|
base_url = os.getenv("SPEND_LOGS_URL", None) |
|
if ( |
|
len(logs_to_process) > 0 |
|
and base_url is not None |
|
and db_writer_client is not None |
|
): |
|
if not base_url.endswith("/"): |
|
base_url += "/" |
|
verbose_proxy_logger.debug("base_url: {}".format(base_url)) |
|
response = await db_writer_client.post( |
|
url=base_url + "spend/update", |
|
data=json.dumps(logs_to_process), |
|
headers={"Content-Type": "application/json"}, |
|
) |
|
if response.status_code == 200: |
|
prisma_client.spend_log_transactions = ( |
|
prisma_client.spend_log_transactions[ |
|
len(logs_to_process) : |
|
] |
|
) |
|
else: |
|
for j in range(0, len(logs_to_process), BATCH_SIZE): |
|
batch = logs_to_process[j : j + BATCH_SIZE] |
|
batch_with_dates = [ |
|
prisma_client.jsonify_object({**entry}) |
|
for entry in batch |
|
] |
|
await prisma_client.db.litellm_spendlogs.create_many( |
|
data=batch_with_dates, skip_duplicates=True |
|
) |
|
verbose_proxy_logger.debug( |
|
f"Flushed {len(batch)} logs to the DB." |
|
) |
|
|
|
prisma_client.spend_log_transactions = ( |
|
prisma_client.spend_log_transactions[len(logs_to_process) :] |
|
) |
|
verbose_proxy_logger.debug( |
|
f"{len(logs_to_process)} logs processed. Remaining in queue: {len(prisma_client.spend_log_transactions)}" |
|
) |
|
break |
|
except DB_CONNECTION_ERROR_TYPES: |
|
|
if i >= n_retry_times: |
|
raise |
|
await asyncio.sleep(2**i) |
|
except Exception as e: |
|
prisma_client.spend_log_transactions = prisma_client.spend_log_transactions[ |
|
len(logs_to_process) : |
|
] |
|
_raise_failed_update_spend_exception( |
|
e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj |
|
) |
|
|
|
|
|
async def update_spend( |
|
prisma_client: PrismaClient, |
|
db_writer_client: Optional[HTTPHandler], |
|
proxy_logging_obj: ProxyLogging, |
|
): |
|
""" |
|
Batch write updates to db. |
|
|
|
Triggered every minute. |
|
|
|
Requires: |
|
user_id_list: dict, |
|
keys_list: list, |
|
team_list: list, |
|
spend_logs: list, |
|
""" |
|
n_retry_times = 3 |
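# Each flush below retries DB connection errors with exponential backoff,
# sleeping 2**i seconds between attempts.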
|
i = None |
|
|
|
if len(prisma_client.user_list_transactons.keys()) > 0: |
|
for i in range(n_retry_times + 1): |
|
start_time = time.time() |
|
try: |
|
async with prisma_client.db.tx( |
|
timeout=timedelta(seconds=60) |
|
) as transaction: |
|
async with transaction.batch_() as batcher: |
|
for ( |
|
user_id, |
|
response_cost, |
|
) in prisma_client.user_list_transactons.items(): |
|
batcher.litellm_usertable.update_many( |
|
where={"user_id": user_id}, |
|
data={"spend": {"increment": response_cost}}, |
|
) |
|
prisma_client.user_list_transactons = {}  # cleared only after the batch commits
|
break |
|
except DB_CONNECTION_ERROR_TYPES as e: |
|
if i >= n_retry_times: |
|
_raise_failed_update_spend_exception( |
|
e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj |
|
) |
|
|
|
await asyncio.sleep(2**i) |
|
except Exception as e: |
|
_raise_failed_update_spend_exception( |
|
e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj |
|
) |
|
|
|
|
|
verbose_proxy_logger.debug( |
|
"End-User Spend transactions: {}".format( |
|
len(prisma_client.end_user_list_transactons.keys()) |
|
) |
|
) |
|
if len(prisma_client.end_user_list_transactons.keys()) > 0: |
|
await ProxyUpdateSpend.update_end_user_spend( |
|
n_retry_times=n_retry_times, |
|
prisma_client=prisma_client, |
|
proxy_logging_obj=proxy_logging_obj, |
|
) |
|
|
|
verbose_proxy_logger.debug( |
|
"KEY Spend transactions: {}".format( |
|
len(prisma_client.key_list_transactons.keys()) |
|
) |
|
) |
|
if len(prisma_client.key_list_transactons.keys()) > 0: |
|
for i in range(n_retry_times + 1): |
|
start_time = time.time() |
|
try: |
|
async with prisma_client.db.tx( |
|
timeout=timedelta(seconds=60) |
|
) as transaction: |
|
async with transaction.batch_() as batcher: |
|
for ( |
|
token, |
|
response_cost, |
|
) in prisma_client.key_list_transactons.items(): |
|
batcher.litellm_verificationtoken.update_many( |
|
where={"token": token}, |
|
data={"spend": {"increment": response_cost}}, |
|
) |
|
prisma_client.key_list_transactons = {}
|
break |
|
except DB_CONNECTION_ERROR_TYPES as e: |
|
if i >= n_retry_times: |
|
_raise_failed_update_spend_exception( |
|
e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj |
|
) |
|
|
|
await asyncio.sleep(2**i) |
|
except Exception as e: |
|
_raise_failed_update_spend_exception( |
|
e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj |
|
) |
|
|
|
|
|
verbose_proxy_logger.debug( |
|
"Team Spend transactions: {}".format( |
|
len(prisma_client.team_list_transactons.keys()) |
|
) |
|
) |
|
if len(prisma_client.team_list_transactons.keys()) > 0: |
|
for i in range(n_retry_times + 1): |
|
start_time = time.time() |
|
try: |
|
async with prisma_client.db.tx( |
|
timeout=timedelta(seconds=60) |
|
) as transaction: |
|
async with transaction.batch_() as batcher: |
|
for ( |
|
team_id, |
|
response_cost, |
|
) in prisma_client.team_list_transactons.items(): |
|
verbose_proxy_logger.debug( |
|
"Updating spend for team id={} by {}".format( |
|
team_id, response_cost |
|
) |
|
) |
|
batcher.litellm_teamtable.update_many( |
|
where={"team_id": team_id}, |
|
data={"spend": {"increment": response_cost}}, |
|
) |
|
prisma_client.team_list_transactons = {}
|
break |
|
except DB_CONNECTION_ERROR_TYPES as e: |
|
if i >= n_retry_times: |
|
_raise_failed_update_spend_exception( |
|
e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj |
|
) |
|
|
|
await asyncio.sleep(2**i) |
|
except Exception as e: |
|
_raise_failed_update_spend_exception( |
|
e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj |
|
) |
|
|
|
|
|
if len(prisma_client.team_member_list_transactons.keys()) > 0: |
|
for i in range(n_retry_times + 1): |
|
start_time = time.time() |
|
try: |
|
async with prisma_client.db.tx( |
|
timeout=timedelta(seconds=60) |
|
) as transaction: |
|
async with transaction.batch_() as batcher: |
|
for ( |
|
key, |
|
response_cost, |
|
) in prisma_client.team_member_list_transactons.items(): |
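# keys are assumed to encode ids positionally, e.g.
# "team_id::<team_id>::user_id::<user_id>"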
|
|
|
team_id = key.split("::")[1] |
|
user_id = key.split("::")[3] |
|
|
|
batcher.litellm_teammembership.update_many( |
|
where={"team_id": team_id, "user_id": user_id}, |
|
data={"spend": {"increment": response_cost}}, |
|
) |
|
prisma_client.team_member_list_transactons = {}
|
break |
|
except DB_CONNECTION_ERROR_TYPES as e: |
|
if i >= n_retry_times: |
|
_raise_failed_update_spend_exception( |
|
e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj |
|
) |
|
|
|
await asyncio.sleep(2**i) |
|
except Exception as e: |
|
_raise_failed_update_spend_exception( |
|
e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj |
|
) |
|
|
|
|
|
if len(prisma_client.org_list_transactons.keys()) > 0: |
|
for i in range(n_retry_times + 1): |
|
start_time = time.time() |
|
try: |
|
async with prisma_client.db.tx( |
|
timeout=timedelta(seconds=60) |
|
) as transaction: |
|
async with transaction.batch_() as batcher: |
|
for ( |
|
org_id, |
|
response_cost, |
|
) in prisma_client.org_list_transactons.items(): |
|
batcher.litellm_organizationtable.update_many( |
|
where={"organization_id": org_id}, |
|
data={"spend": {"increment": response_cost}}, |
|
) |
|
prisma_client.org_list_transactons = {}
|
break |
|
except DB_CONNECTION_ERROR_TYPES as e: |
|
if i >= n_retry_times: |
|
_raise_failed_update_spend_exception( |
|
e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj |
|
) |
|
|
|
await asyncio.sleep(2**i) |
|
except Exception as e: |
|
_raise_failed_update_spend_exception( |
|
e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj |
|
) |
|
|
|
|
|
verbose_proxy_logger.debug( |
|
"Spend Logs transactions: {}".format(len(prisma_client.spend_log_transactions)) |
|
) |
|
|
|
if len(prisma_client.spend_log_transactions) > 0: |
|
await ProxyUpdateSpend.update_spend_logs( |
|
n_retry_times=n_retry_times, |
|
prisma_client=prisma_client, |
|
proxy_logging_obj=proxy_logging_obj, |
|
db_writer_client=db_writer_client, |
|
) |
|
|
|
|
|
def _raise_failed_update_spend_exception( |
|
e: Exception, start_time: float, proxy_logging_obj: ProxyLogging |
|
): |
|
""" |
|
Raise an exception for failed update spend logs |
|
|
|
- Calls proxy_logging_obj.failure_handler to log the error |
|
- Ensures error messages says "Non-Blocking" |
|
""" |
|
import traceback |
|
|
|
error_msg = ( |
|
f"[Non-Blocking]LiteLLM Prisma Client Exception - update spend logs: {str(e)}" |
|
) |
|
error_traceback = error_msg + "\n" + traceback.format_exc() |
|
end_time = time.time() |
|
_duration = end_time - start_time |
|
asyncio.create_task( |
|
proxy_logging_obj.failure_handler( |
|
original_exception=e, |
|
duration=_duration, |
|
call_type="update_spend", |
|
traceback_str=error_traceback, |
|
) |
|
) |
|
raise e |
|
|
|
|
|
def _is_projected_spend_over_limit( |
|
current_spend: float, soft_budget_limit: Optional[float] |
|
): |
|
from datetime import date |
|
|
|
if soft_budget_limit is None: |
|
|
|
return False |
|
|
|
today = date.today() |
|
|
|
|
|
if today.month == 12: |
|
end_month = date(today.year + 1, 1, 1) - timedelta(days=1) |
|
else: |
|
end_month = date(today.year, today.month + 1, 1) - timedelta(days=1) |
|
|
|
remaining_days = (end_month - today).days |
|
|
|
|
|
if today.day == 1: |
|
daily_spend_estimate = current_spend |
|
else: |
|
daily_spend_estimate = current_spend / (today.day - 1) |
|
|
|
|
|
projected_spend = current_spend + (daily_spend_estimate * remaining_days) |
|
|
|
if projected_spend > soft_budget_limit: |
|
print_verbose("Projected spend exceeds soft budget limit!") |
|
return True |
|
return False |
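# Worked example: on the 16th of a 30-day month with $150 spent,
# daily_spend_estimate = 150 / 15 = $10, remaining_days = 14, and
# projected_spend = 150 + 10 * 14 = $290, which trips a $250 soft budget.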
|
|
|
|
|
def _get_projected_spend_over_limit( |
|
current_spend: float, soft_budget_limit: Optional[float] |
|
) -> Optional[tuple]: |
|
import datetime |
|
|
|
if soft_budget_limit is None: |
|
return None |
|
|
|
today = datetime.date.today()
if today.month == 12:
    end_month = datetime.date(today.year + 1, 1, 1) - datetime.timedelta(days=1)
else:
    end_month = datetime.date(today.year, today.month + 1, 1) - datetime.timedelta(days=1)
remaining_days = (end_month - today).days

# avoid ZeroDivisionError on the 1st of the month (mirrors _is_projected_spend_over_limit)
daily_spend = current_spend if today.day == 1 else current_spend / (today.day - 1)
|
projected_spend = daily_spend * remaining_days |
|
|
|
if projected_spend > soft_budget_limit: |
|
approx_days = soft_budget_limit / daily_spend |
|
limit_exceed_date = today + datetime.timedelta(days=approx_days) |
|
|
|
|
|
return projected_spend, limit_exceed_date |
|
|
|
return None |
|
|
|
|
|
def _is_valid_team_configs(team_id=None, team_config=None, request_data=None):
"""
Raise if the request's model is not in the team's allowed `models` list.
Note: pops "models" from team_config, mutating it in place.
"""
|
if team_id is None or team_config is None or request_data is None: |
|
return |
|
|
|
if "models" in team_config: |
|
valid_models = team_config.pop("models") |
|
model_in_request = request_data["model"] |
|
if model_in_request not in valid_models: |
|
raise Exception( |
|
f"Invalid model for team {team_id}: {model_in_request}. Valid models for team are: {valid_models}\n" |
|
) |
|
return |
|
|
|
|
|
def _to_ns(dt): |
|
return int(dt.timestamp() * 1e9) |
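# datetime -> integer nanoseconds since the epoch (e.g. the resolution
# OpenTelemetry span timestamps use).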
|
|
|
|
|
def get_error_message_str(e: Exception) -> str: |
|
error_message = "" |
|
if isinstance(e, HTTPException): |
|
if isinstance(e.detail, str): |
|
error_message = e.detail |
|
elif isinstance(e.detail, dict): |
|
error_message = json.dumps(e.detail) |
|
elif hasattr(e, "message"): |
|
_error = getattr(e, "message", None) |
|
if isinstance(_error, str): |
|
error_message = _error |
|
elif isinstance(_error, dict): |
|
error_message = json.dumps(_error) |
|
else: |
|
error_message = str(e) |
|
else: |
|
error_message = str(e) |
|
return error_message |
|
|
|
|
|
def _get_redoc_url() -> str: |
|
""" |
|
Get the redoc URL from the environment variables. |
|
|
|
- If REDOC_URL is set, return it. |
|
- Otherwise, default to "/redoc". |
|
""" |
|
return os.getenv("REDOC_URL", "/redoc") |
|
|
|
|
|
def _get_docs_url() -> Optional[str]: |
|
""" |
|
Get the docs URL from the environment variables. |
|
|
|
- If DOCS_URL is set, return it. |
|
- If NO_DOCS is True, return None. |
|
- Otherwise, default to "/". |
|
""" |
|
docs_url = os.getenv("DOCS_URL", None) |
|
if docs_url: |
|
return docs_url |
|
|
|
if os.getenv("NO_DOCS", "False") == "True": |
|
return None |
|
|
|
|
|
return "/" |
|
|
|
|
|
def handle_exception_on_proxy(e: Exception) -> ProxyException: |
|
""" |
|
Returns an Exception as ProxyException, this ensures all exceptions are OpenAI API compatible |
|
""" |
|
from fastapi import status |
|
|
|
if isinstance(e, HTTPException): |
|
return ProxyException( |
|
message=getattr(e, "detail", f"error({str(e)})"), |
|
type=ProxyErrorTypes.internal_server_error, |
|
param=getattr(e, "param", "None"), |
|
code=getattr(e, "status_code", status.HTTP_500_INTERNAL_SERVER_ERROR), |
|
) |
|
elif isinstance(e, ProxyException): |
|
return e |
|
return ProxyException( |
|
message="Internal Server Error, " + str(e), |
|
type=ProxyErrorTypes.internal_server_error, |
|
param=getattr(e, "param", "None"), |
|
code=status.HTTP_500_INTERNAL_SERVER_ERROR, |
|
) |
|
|
|
|
|
def _premium_user_check(): |
|
""" |
|
Raises an HTTPException if the user is not a premium user |
|
""" |
|
from litellm.proxy.proxy_server import premium_user |
|
|
|
if not premium_user: |
|
raise HTTPException( |
|
status_code=403, |
|
detail={ |
|
"error": f"This feature is only available for LiteLLM Enterprise users. {CommonProxyErrors.not_premium_user.value}" |
|
}, |
|
) |
|
|