|
""" |
|
PagerDuty Alerting Integration |
|
|
|
Handles two types of alerts: |
|
- High LLM API Failure Rate. Configure X failed requests within a Y-second window to trigger an alert.

- High Number of Hanging LLM Requests. Configure X hanging (not yet completed) requests within a Y-second window to trigger an alert.
|
""" |
|
|
|
import asyncio |
|
import os |
|
from datetime import datetime, timedelta, timezone |
|
from typing import List, Literal, Optional, Union |
|
|
|
from litellm._logging import verbose_logger |
|
from litellm.caching import DualCache |
|
from litellm.integrations.SlackAlerting.slack_alerting import SlackAlerting |
|
from litellm.llms.custom_httpx.http_handler import ( |
|
AsyncHTTPHandler, |
|
get_async_httpx_client, |
|
httpxSpecialProvider, |
|
) |
|
from litellm.proxy._types import UserAPIKeyAuth |
|
from litellm.types.integrations.pagerduty import ( |
|
AlertingConfig, |
|
PagerDutyInternalEvent, |
|
PagerDutyPayload, |
|
PagerDutyRequestBody, |
|
) |
|
from litellm.types.utils import ( |
|
StandardLoggingPayload, |
|
StandardLoggingPayloadErrorInformation, |
|
) |
|
|
|
# --- Failure-rate alert defaults ---
# Number of failed requests inside the window needed to trigger an alert.
PAGERDUTY_DEFAULT_FAILURE_THRESHOLD = 60
# Length (seconds) of the sliding window over which failures are counted.
PAGERDUTY_DEFAULT_FAILURE_THRESHOLD_WINDOW_SECONDS = 60

# --- Hanging-request alert defaults ---
# Seconds to wait before a still-unfinished request is counted as hanging.
PAGERDUTY_DEFAULT_HANGING_THRESHOLD_SECONDS = 60
# Length (seconds) of the sliding window over which hanging requests are counted.
PAGERDUTY_DEFAULT_HANGING_THRESHOLD_WINDOW_SECONDS = 10 * 60
|
|
|
|
|
class PagerDutyAlerting(SlackAlerting):
    """
    Tracks failed requests and hanging requests separately.

    Each kind of event is kept in its own sliding time window. When the number
    of events inside a window reaches its configured threshold, a critical
    PagerDuty alert is sent and that window is reset.
    """

    def __init__(
        self, alerting_args: Optional[Union[AlertingConfig, dict]] = None, **kwargs
    ):
        """
        Configure PagerDuty alerting.

        Args:
            alerting_args: optional threshold overrides. Keys read here are
                ``failure_threshold``, ``failure_threshold_window_seconds``,
                ``hanging_threshold_seconds``, ``hanging_threshold_fails`` and
                ``hanging_threshold_window_seconds``. Missing keys fall back to
                the module-level ``PAGERDUTY_DEFAULT_*`` constants.
            **kwargs: accepted for interface compatibility; unused.

        Raises:
            ValueError: if the proxy is not running as a premium (Enterprise)
                user, or if the ``PAGERDUTY_API_KEY`` env var is not set.
        """
        from litellm.proxy.proxy_server import CommonProxyErrors, premium_user

        # Check the license first, so a non-premium user gets the licensing
        # error (not a misleading "API key missing" one) and no setup is done.
        if premium_user is not True:
            raise ValueError(
                f"PagerDutyAlerting is only available for LiteLLM Enterprise users. {CommonProxyErrors.not_premium_user.value}"
            )

        super().__init__()
        _api_key = os.getenv("PAGERDUTY_API_KEY")
        if not _api_key:
            raise ValueError("PAGERDUTY_API_KEY is not set")

        # Sent as the ``routing_key`` of every PagerDuty event.
        self.api_key: str = _api_key
        alerting_args = alerting_args or {}
        self.alerting_args: AlertingConfig = AlertingConfig(
            failure_threshold=alerting_args.get(
                "failure_threshold", PAGERDUTY_DEFAULT_FAILURE_THRESHOLD
            ),
            failure_threshold_window_seconds=alerting_args.get(
                "failure_threshold_window_seconds",
                PAGERDUTY_DEFAULT_FAILURE_THRESHOLD_WINDOW_SECONDS,
            ),
            hanging_threshold_seconds=alerting_args.get(
                "hanging_threshold_seconds", PAGERDUTY_DEFAULT_HANGING_THRESHOLD_SECONDS
            ),
            # Number of hanging requests in the window that triggers an alert.
            # It must be copied here because hanging_response_handler reads it
            # from self.alerting_args. The fallback (60) matches the value the
            # handler has always used when the key was absent.
            hanging_threshold_fails=alerting_args.get(
                "hanging_threshold_fails", PAGERDUTY_DEFAULT_HANGING_THRESHOLD_SECONDS
            ),
            hanging_threshold_window_seconds=alerting_args.get(
                "hanging_threshold_window_seconds",
                PAGERDUTY_DEFAULT_HANGING_THRESHOLD_WINDOW_SECONDS,
            ),
        )

        # Sliding windows of recent events; pruned on every threshold check.
        self._failure_events: List[PagerDutyInternalEvent] = []
        self._hanging_events: List[PagerDutyInternalEvent] = []

        # Strong references to in-flight hanging-request watcher tasks. The
        # event loop keeps only weak references to tasks, so without this set
        # a watcher could be garbage-collected while it is still sleeping.
        self._background_tasks: set = set()

    async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
        """
        Record a failure event. Only send an alert to PagerDuty if the
        configured *failure* threshold is reached within the configured window.

        Raises:
            ValueError: if ``kwargs`` has no ``standard_logging_object``.
        """
        now = datetime.now(timezone.utc)
        standard_logging_payload: Optional[StandardLoggingPayload] = kwargs.get(
            "standard_logging_object"
        )
        if not standard_logging_payload:
            raise ValueError(
                "standard_logging_object is required for PagerDutyAlerting"
            )

        # Either section may be missing or None. Fall back to empty dicts so
        # the .get() lookups below simply yield None.
        error_info: Optional[StandardLoggingPayloadErrorInformation] = (
            standard_logging_payload.get("error_information") or {}
        )
        _meta = standard_logging_payload.get("metadata") or {}

        self._failure_events.append(
            PagerDutyInternalEvent(
                failure_event_type="failed_response",
                timestamp=now,
                error_class=error_info.get("error_class"),
                error_code=error_info.get("error_code"),
                error_llm_provider=error_info.get("llm_provider"),
                user_api_key_hash=_meta.get("user_api_key_hash"),
                user_api_key_alias=_meta.get("user_api_key_alias"),
                user_api_key_org_id=_meta.get("user_api_key_org_id"),
                user_api_key_team_id=_meta.get("user_api_key_team_id"),
                user_api_key_user_id=_meta.get("user_api_key_user_id"),
                user_api_key_team_alias=_meta.get("user_api_key_team_alias"),
                user_api_key_end_user_id=_meta.get("user_api_key_end_user_id"),
            )
        )

        # Fallbacks mirror the defaults applied in __init__.
        window_seconds = self.alerting_args.get(
            "failure_threshold_window_seconds",
            PAGERDUTY_DEFAULT_FAILURE_THRESHOLD_WINDOW_SECONDS,
        )
        threshold = self.alerting_args.get(
            "failure_threshold", PAGERDUTY_DEFAULT_FAILURE_THRESHOLD
        )

        await self._send_alert_if_thresholds_crossed(
            events=self._failure_events,
            window_seconds=window_seconds,
            threshold=threshold,
            alert_prefix="High LLM API Failure Rate",
        )

    async def async_pre_call_hook(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        cache: DualCache,
        data: dict,
        call_type: Literal[
            "completion",
            "text_completion",
            "embeddings",
            "image_generation",
            "moderation",
            "audio_transcription",
            "pass_through_endpoint",
            "rerank",
        ],
    ) -> Optional[Union[Exception, str, dict]]:
        """
        Start a background watcher that checks whether this request is
        still unfinished after ``hanging_threshold_seconds``. If it is, the
        watcher treats the request as 'hanging'.

        Never modifies or rejects the request; always returns None.
        """
        verbose_logger.info("Inside Proxy Logging Pre-call hook!")
        task = asyncio.create_task(
            self.hanging_response_handler(
                request_data=data, user_api_key_dict=user_api_key_dict
            )
        )
        # Hold a strong reference until the task finishes (see __init__).
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)
        return None

    async def hanging_response_handler(
        self, request_data: Optional[dict], user_api_key_dict: UserAPIKeyAuth
    ):
        """
        Wait ``hanging_threshold_seconds``, then check whether the request has
        completed. If not, record it as a hanging request and alert if the
        hanging threshold is reached within the hanging window.
        """
        hanging_threshold_seconds = self.alerting_args.get(
            "hanging_threshold_seconds", PAGERDUTY_DEFAULT_HANGING_THRESHOLD_SECONDS
        )
        verbose_logger.debug(
            f"Inside Hanging Response Handler!..sleeping for {hanging_threshold_seconds} seconds"
        )
        await asyncio.sleep(hanging_threshold_seconds)

        if await self._request_is_completed(request_data=request_data):
            return

        self._hanging_events.append(
            PagerDutyInternalEvent(
                failure_event_type="hanging_response",
                timestamp=datetime.now(timezone.utc),
                error_class="HangingRequest",
                error_code="HangingRequest",
                error_llm_provider="HangingRequest",
                user_api_key_hash=user_api_key_dict.api_key,
                user_api_key_alias=user_api_key_dict.key_alias,
                user_api_key_org_id=user_api_key_dict.org_id,
                user_api_key_team_id=user_api_key_dict.team_id,
                user_api_key_user_id=user_api_key_dict.user_id,
                user_api_key_team_alias=user_api_key_dict.team_alias,
                user_api_key_end_user_id=user_api_key_dict.end_user_id,
            )
        )

        window_seconds = self.alerting_args.get(
            "hanging_threshold_window_seconds",
            PAGERDUTY_DEFAULT_HANGING_THRESHOLD_WINDOW_SECONDS,
        )
        # Count of hanging requests needed to alert. __init__ populates this
        # key from the user's alerting_args; the fallback keeps the previous
        # effective default of 60.
        threshold: int = self.alerting_args.get(
            "hanging_threshold_fails", PAGERDUTY_DEFAULT_HANGING_THRESHOLD_SECONDS
        )

        await self._send_alert_if_thresholds_crossed(
            events=self._hanging_events,
            window_seconds=window_seconds,
            threshold=threshold,
            alert_prefix="High Number of Hanging LLM Requests",
        )

    async def _send_alert_if_thresholds_crossed(
        self,
        events: List[PagerDutyInternalEvent],
        window_seconds: int,
        threshold: int,
        alert_prefix: str,
    ):
        """
        1. Prune events older than ``window_seconds`` (in place).
        2. If at least ``threshold`` events remain, reset the window and send
           one alert to PagerDuty summarising them.

        Args:
            events: the caller's event list; it is mutated in place.
            window_seconds: sliding-window length.
            threshold: minimum number of in-window events that triggers an alert.
            alert_prefix: leading text of the alert summary.
        """
        cutoff = datetime.now(timezone.utc) - timedelta(seconds=window_seconds)
        # Drop events that lack a timestamp instead of defaulting to a naive
        # datetime: comparing naive and aware datetimes raises TypeError.
        in_window = [
            e
            for e in events
            if e.get("timestamp") is not None and e["timestamp"] > cutoff
        ]
        events[:] = in_window

        verbose_logger.debug(
            f"Have {len(events)} events in the last {window_seconds} seconds. Threshold is {threshold}"
        )
        if len(events) < threshold:
            return

        error_summaries = self._build_error_summaries(events, max_errors=5)
        alert_message = (
            f"{alert_prefix}: {len(events)} in the last {window_seconds} seconds."
        )
        custom_details = {"recent_errors": error_summaries}

        # Reset the window *before* awaiting the HTTP call. Other coroutines
        # can append to and check this list during the await. Clearing first
        # stops them from re-alerting on the same events, and keeps events
        # they add from being wiped by a late clear().
        events.clear()

        await self.send_alert_to_pagerduty(
            alert_message=alert_message,
            custom_details=custom_details,
        )

    def _build_error_summaries(
        self, events: List[PagerDutyInternalEvent], max_errors: int = 5
    ) -> List[PagerDutyInternalEvent]:
        """
        Return copies of the last ``max_errors`` events with their
        ``timestamp`` removed, for use as the alert's ``custom_details``.

        The datetime timestamp is removed because the summaries are sent as a
        JSON body, and datetime objects are not JSON-serializable. Copies are
        returned so the caller's events are never mutated.

        Returns an empty list when ``max_errors`` <= 0. A bare
        ``events[-0:]`` slice would otherwise return every event.
        """
        if max_errors <= 0:
            return []
        summaries: List[PagerDutyInternalEvent] = []
        for event in events[-max_errors:]:
            summary = event.copy()
            summary.pop("timestamp", None)
            summaries.append(summary)
        return summaries

    async def send_alert_to_pagerduty(self, alert_message: str, custom_details: dict):
        """
        Send a [critical] "trigger" event to PagerDuty.

        https://developer.pagerduty.com/api-reference/YXBpOjI3NDgyNjU-pager-duty-v2-events-api

        Args:
            alert_message: used as the event summary.
            custom_details: attached as ``payload.custom_details``; must be
                JSON-serializable.

        Returns:
            The response of the POST, or None if sending raised. Errors are
            logged and swallowed (best-effort alerting).
        """
        try:
            verbose_logger.debug(f"Sending alert to PagerDuty: {alert_message}")
            async_client: AsyncHTTPHandler = get_async_httpx_client(
                llm_provider=httpxSpecialProvider.LoggingCallback
            )
            payload: PagerDutyRequestBody = PagerDutyRequestBody(
                payload=PagerDutyPayload(
                    summary=alert_message,
                    severity="critical",
                    source="LiteLLM Alert",
                    component="LiteLLM",
                    custom_details=custom_details,
                ),
                routing_key=self.api_key,
                event_action="trigger",
            )

            return await async_client.post(
                url="https://events.pagerduty.com/v2/enqueue",
                json=dict(payload),
                headers={"Content-Type": "application/json"},
            )
        except Exception as e:
            verbose_logger.exception(f"Error sending alert to PagerDuty: {e}")
|
|