File size: 11,682 Bytes
e3278e4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 |
"""
PagerDuty Alerting Integration
Handles two types of alerts:
- High LLM API Failure Rate: alert when X requests fail within a Y-second window.
- High Number of Hanging LLM Requests: alert when X requests are still incomplete after the hanging threshold, within a Y-second window.
"""
import asyncio
import os
from datetime import datetime, timedelta, timezone
from typing import List, Literal, Optional, Union
from litellm._logging import verbose_logger
from litellm.caching import DualCache
from litellm.integrations.SlackAlerting.slack_alerting import SlackAlerting
from litellm.llms.custom_httpx.http_handler import (
AsyncHTTPHandler,
get_async_httpx_client,
httpxSpecialProvider,
)
from litellm.proxy._types import UserAPIKeyAuth
from litellm.types.integrations.pagerduty import (
AlertingConfig,
PagerDutyInternalEvent,
PagerDutyPayload,
PagerDutyRequestBody,
)
from litellm.types.utils import (
StandardLoggingPayload,
StandardLoggingPayloadErrorInformation,
)
# Defaults for PagerDutyAlerting's `alerting_args`.
# Failure alert: number of failed requests within the window below that triggers an alert.
PAGERDUTY_DEFAULT_FAILURE_THRESHOLD: int = 60
PAGERDUTY_DEFAULT_FAILURE_THRESHOLD_WINDOW_SECONDS: int = 60
# Hanging alert: seconds to wait before an incomplete request counts as hanging,
# and the window (10 minutes) over which hanging requests are counted.
PAGERDUTY_DEFAULT_HANGING_THRESHOLD_SECONDS: int = 60
PAGERDUTY_DEFAULT_HANGING_THRESHOLD_WINDOW_SECONDS: int = 10 * 60
class PagerDutyAlerting(SlackAlerting):
    """
    Tracks failed requests and hanging requests separately.
    If threshold is crossed for either type, triggers a PagerDuty alert.

    - Failures are recorded by ``async_log_failure_event``.
    - Hangs are detected by a background task started in ``async_pre_call_hook``:
      it sleeps ``hanging_threshold_seconds`` and records a hang if the request
      has not completed by then.

    Each event list is pruned to its time window. Once it holds at least the
    configured number of events, one alert is sent and the list is reset.
    """

    # Fallback for ``hanging_threshold_fails``: the number of hanging requests within
    # the hanging window that triggers an alert. 60 is the value the handler
    # previously fell back to, so default behavior is unchanged.
    _DEFAULT_HANGING_THRESHOLD_FAILS: int = 60

    def __init__(
        self, alerting_args: Optional[Union[AlertingConfig, dict]] = None, **kwargs
    ):
        """
        Read the PagerDuty routing key from the environment and resolve thresholds.

        Args:
            alerting_args: Optional threshold overrides. Recognized keys are
                ``failure_threshold``, ``failure_threshold_window_seconds``,
                ``hanging_threshold_seconds``, ``hanging_threshold_fails`` and
                ``hanging_threshold_window_seconds``. Missing keys fall back to
                the module-level defaults.
            **kwargs: Accepted for interface compatibility; unused.

        Raises:
            ValueError: If ``PAGERDUTY_API_KEY`` is not set, or if the proxy is
                not running as a premium (Enterprise) user.
        """
        from litellm.proxy.proxy_server import CommonProxyErrors, premium_user

        super().__init__()
        _api_key = os.getenv("PAGERDUTY_API_KEY")
        if not _api_key:
            raise ValueError("PAGERDUTY_API_KEY is not set")
        # Sent as the Events API ``routing_key`` in send_alert_to_pagerduty.
        self.api_key: str = _api_key

        alerting_args = alerting_args or {}
        self.alerting_args: AlertingConfig = AlertingConfig(
            failure_threshold=alerting_args.get(
                "failure_threshold", PAGERDUTY_DEFAULT_FAILURE_THRESHOLD
            ),
            failure_threshold_window_seconds=alerting_args.get(
                "failure_threshold_window_seconds",
                PAGERDUTY_DEFAULT_FAILURE_THRESHOLD_WINDOW_SECONDS,
            ),
            hanging_threshold_seconds=alerting_args.get(
                "hanging_threshold_seconds", PAGERDUTY_DEFAULT_HANGING_THRESHOLD_SECONDS
            ),
            # This key used to be dropped here, so the hang-count threshold read by
            # hanging_response_handler could never be configured.
            hanging_threshold_fails=alerting_args.get(
                "hanging_threshold_fails", self._DEFAULT_HANGING_THRESHOLD_FAILS
            ),
            hanging_threshold_window_seconds=alerting_args.get(
                "hanging_threshold_window_seconds",
                PAGERDUTY_DEFAULT_HANGING_THRESHOLD_WINDOW_SECONDS,
            ),
        )

        # Separate storage for failures vs. hangs
        self._failure_events: List[PagerDutyInternalEvent] = []
        self._hanging_events: List[PagerDutyInternalEvent] = []
        # Strong references to in-flight hang-detection tasks (see _track_background_task).
        self._hanging_response_tasks: set = set()

        # premium user check
        if premium_user is not True:
            raise ValueError(
                f"PagerDutyAlerting is only available for LiteLLM Enterprise users. {CommonProxyErrors.not_premium_user.value}"
            )

    # ------------------ MAIN LOGIC ------------------ #

    async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
        """
        Record a failure event. Only send an alert to PagerDuty if the
        configured *failure* threshold is exceeded in the specified window.

        Raises:
            ValueError: If ``kwargs`` has no ``standard_logging_object``.
        """
        now = datetime.now(timezone.utc)
        standard_logging_payload: Optional[StandardLoggingPayload] = kwargs.get(
            "standard_logging_object"
        )
        if not standard_logging_payload:
            raise ValueError(
                "standard_logging_object is required for PagerDutyAlerting"
            )

        # Extract error details (either section may be absent or None).
        error_info: Optional[StandardLoggingPayloadErrorInformation] = (
            standard_logging_payload.get("error_information") or {}
        )
        _meta = standard_logging_payload.get("metadata") or {}

        self._failure_events.append(
            PagerDutyInternalEvent(
                failure_event_type="failed_response",
                timestamp=now,
                error_class=error_info.get("error_class"),
                error_code=error_info.get("error_code"),
                error_llm_provider=error_info.get("llm_provider"),
                user_api_key_hash=_meta.get("user_api_key_hash"),
                user_api_key_alias=_meta.get("user_api_key_alias"),
                user_api_key_org_id=_meta.get("user_api_key_org_id"),
                user_api_key_team_id=_meta.get("user_api_key_team_id"),
                user_api_key_user_id=_meta.get("user_api_key_user_id"),
                user_api_key_team_alias=_meta.get("user_api_key_team_alias"),
                user_api_key_end_user_id=_meta.get("user_api_key_end_user_id"),
            )
        )

        # Prune + possibly alert. The fallbacks match the defaults used in __init__.
        window_seconds = self.alerting_args.get(
            "failure_threshold_window_seconds",
            PAGERDUTY_DEFAULT_FAILURE_THRESHOLD_WINDOW_SECONDS,
        )
        threshold = self.alerting_args.get(
            "failure_threshold", PAGERDUTY_DEFAULT_FAILURE_THRESHOLD
        )

        # If threshold is crossed, send PD alert for failures
        await self._send_alert_if_thresholds_crossed(
            events=self._failure_events,
            window_seconds=window_seconds,
            threshold=threshold,
            alert_prefix="High LLM API Failure Rate",
        )

    async def async_pre_call_hook(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        cache: DualCache,
        data: dict,
        call_type: Literal[
            "completion",
            "text_completion",
            "embeddings",
            "image_generation",
            "moderation",
            "audio_transcription",
            "pass_through_endpoint",
            "rerank",
        ],
    ) -> Optional[Union[Exception, str, dict]]:
        """
        Start a background watcher that classifies the request as 'hanging'
        if it has not finished after ``hanging_threshold_seconds``.

        Returns:
            None. The request data is not modified.
        """
        verbose_logger.info("Inside Proxy Logging Pre-call hook!")
        task = asyncio.create_task(
            self.hanging_response_handler(
                request_data=data, user_api_key_dict=user_api_key_dict
            )
        )
        self._track_background_task(task)
        return None

    def _track_background_task(self, task: "asyncio.Task") -> None:
        """
        Hold a strong reference to ``task`` until it finishes.

        The event loop keeps only weak references to tasks. An unreferenced,
        fire-and-forget ``asyncio.create_task`` result can therefore be
        garbage-collected before it completes.
        """
        tasks = getattr(self, "_hanging_response_tasks", None)
        if tasks is None:
            # Fallback in case __init__ did not run on this instance.
            tasks = set()
            self._hanging_response_tasks = tasks
        tasks.add(task)
        task.add_done_callback(tasks.discard)

    async def hanging_response_handler(
        self, request_data: Optional[dict], user_api_key_dict: UserAPIKeyAuth
    ):
        """
        Checks if request completed by the time 'hanging_threshold_seconds' elapses.
        If not, we classify it as a hanging request and alert once
        'hanging_threshold_fails' hangs fall inside 'hanging_threshold_window_seconds'.
        """
        hanging_threshold_seconds = self.alerting_args.get(
            "hanging_threshold_seconds", PAGERDUTY_DEFAULT_HANGING_THRESHOLD_SECONDS
        )
        verbose_logger.debug(
            f"Inside Hanging Response Handler!..sleeping for {hanging_threshold_seconds} seconds"
        )
        await asyncio.sleep(hanging_threshold_seconds)

        if await self._request_is_completed(request_data=request_data):
            return  # It's not hanging if completed

        # Otherwise, record it as hanging
        self._hanging_events.append(
            PagerDutyInternalEvent(
                failure_event_type="hanging_response",
                timestamp=datetime.now(timezone.utc),
                error_class="HangingRequest",
                error_code="HangingRequest",
                error_llm_provider="HangingRequest",
                user_api_key_hash=user_api_key_dict.api_key,
                user_api_key_alias=user_api_key_dict.key_alias,
                user_api_key_org_id=user_api_key_dict.org_id,
                user_api_key_team_id=user_api_key_dict.team_id,
                user_api_key_user_id=user_api_key_dict.user_id,
                user_api_key_team_alias=user_api_key_dict.team_alias,
                user_api_key_end_user_id=user_api_key_dict.end_user_id,
            )
        )

        # Prune + possibly alert
        window_seconds = self.alerting_args.get(
            "hanging_threshold_window_seconds",
            PAGERDUTY_DEFAULT_HANGING_THRESHOLD_WINDOW_SECONDS,
        )
        # Number of hangs (not seconds) that must fall inside the window.
        threshold: int = self.alerting_args.get(
            "hanging_threshold_fails", self._DEFAULT_HANGING_THRESHOLD_FAILS
        )

        # If threshold is crossed, send PD alert for hangs
        await self._send_alert_if_thresholds_crossed(
            events=self._hanging_events,
            window_seconds=window_seconds,
            threshold=threshold,
            alert_prefix="High Number of Hanging LLM Requests",
        )

    # ------------------ HELPERS ------------------ #

    async def _send_alert_if_thresholds_crossed(
        self,
        events: List[PagerDutyInternalEvent],
        window_seconds: int,
        threshold: int,
        alert_prefix: str,
    ):
        """
        1. Prune (in place) events older than ``window_seconds``.
        2. If at least ``threshold`` events remain, build an alert and send it to PagerDuty.
        3. Reset ``events`` so the same events do not trigger a second alert.
        """
        cutoff = datetime.now(timezone.utc) - timedelta(seconds=window_seconds)
        # Events without a timestamp are dropped. Comparing them through a naive
        # fallback such as datetime.min against the aware cutoff would raise TypeError.
        events[:] = [
            e
            for e in events
            if e.get("timestamp") is not None and e["timestamp"] > cutoff
        ]

        # Check threshold
        verbose_logger.debug(
            f"Have {len(events)} events in the last {window_seconds} seconds. Threshold is {threshold}"
        )
        if len(events) < threshold:
            return

        # Build short summary of last N events
        error_summaries = self._build_error_summaries(events, max_errors=5)
        alert_message = (
            f"{alert_prefix}: {len(events)} in the last {window_seconds} seconds."
        )
        # Reset before awaiting the HTTP call. Other coroutines run while this one
        # is suspended; this way they cannot re-alert on these events, and events
        # they record during the send are not discarded afterwards.
        events.clear()
        await self.send_alert_to_pagerduty(
            alert_message=alert_message,
            custom_details={"recent_errors": error_summaries},
        )

    def _build_error_summaries(
        self, events: List[PagerDutyInternalEvent], max_errors: int = 5
    ) -> List[PagerDutyInternalEvent]:
        """
        Return copies of the last ``max_errors`` events with ``timestamp`` removed.

        The tracked events themselves are not modified. The timestamp is a
        datetime, which the default JSON encoder cannot serialize, and these
        summaries end up in the request's JSON body.
        """
        if max_errors <= 0:
            # events[-0:] would be the whole list, not an empty one.
            return []
        summaries: List[PagerDutyInternalEvent] = []
        for event in events[-max_errors:]:
            summary = event.copy()
            summary.pop("timestamp", None)
            summaries.append(summary)
        return summaries

    async def send_alert_to_pagerduty(self, alert_message: str, custom_details: dict):
        """
        Send [critical] Alert to PagerDuty
        https://developer.pagerduty.com/api-reference/YXBpOjI3NDgyNjU-pager-duty-v2-events-api

        Returns:
            The response from the HTTP client, or None if sending raised an
            exception (which is logged rather than propagated).
        """
        try:
            verbose_logger.debug(f"Sending alert to PagerDuty: {alert_message}")
            async_client: AsyncHTTPHandler = get_async_httpx_client(
                llm_provider=httpxSpecialProvider.LoggingCallback
            )
            payload: PagerDutyRequestBody = PagerDutyRequestBody(
                payload=PagerDutyPayload(
                    summary=alert_message,
                    severity="critical",
                    source="LiteLLM Alert",
                    component="LiteLLM",
                    custom_details=custom_details,
                ),
                routing_key=self.api_key,
                event_action="trigger",
            )
            return await async_client.post(
                url="https://events.pagerduty.com/v2/enqueue",
                json=dict(payload),
                headers={"Content-Type": "application/json"},
            )
        except Exception as e:
            verbose_logger.exception(f"Error sending alert to PagerDuty: {e}")
|