|
|
|
|
|
import asyncio |
|
import logging |
|
import random |
|
from typing import List, Optional |
|
|
|
import litellm |
|
|
|
logger = logging.getLogger(__name__) |
|
from litellm.constants import HEALTH_CHECK_TIMEOUT_SECONDS |
|
|
|
# Keys that `_clean_endpoint_data` removes from endpoint data before it is
# displayed: request payload fields and provider credentials.
ILLEGAL_DISPLAY_PARAMS = [
    # request payload
    "messages",
    "api_key",
    "prompt",
    "input",
    # provider credentials
    "vertex_credentials",
    "aws_access_key_id",
    "aws_secret_access_key",
]

# The only keys `_clean_endpoint_data` keeps when called with `details=False`.
MINIMAL_DISPLAY_PARAMS = [
    "model",
    "mode_error",
]
|
|
|
|
|
def _get_random_llm_message():
    """
    Build a short single-turn chat payload to use as a health-check prompt.

    Returns a one-element list holding a "user" message whose content is
    picked at random from a fixed pair of prompts.
    """
    prompt = random.choice(["Hey how's it going?", "What's 1 + 1?"])
    return [{"role": "user", "content": prompt}]
|
|
|
|
|
def _clean_endpoint_data(endpoint_data: dict, details: Optional[bool] = True):
    """
    Reduce an endpoint's data to what may be shown to users.

    With `details=False` only the keys listed in MINIMAL_DISPLAY_PARAMS are
    kept. For any other value of `details` (including None) every key is
    kept except those listed in ILLEGAL_DISPLAY_PARAMS.

    Returns a new dict; `endpoint_data` itself is not modified.
    """
    if details is False:
        return {
            key: value
            for key, value in endpoint_data.items()
            if key in MINIMAL_DISPLAY_PARAMS
        }

    return {
        key: value
        for key, value in endpoint_data.items()
        if key not in ILLEGAL_DISPLAY_PARAMS
    }
|
|
|
|
|
def filter_deployments_by_id(
    model_list: List,
) -> List:
    """
    Return the deployments of `model_list` with duplicate ids removed.

    The id is read from `deployment["model_info"]["id"]`. The first
    deployment seen for each id is kept, in input order. Deployments whose
    `model_info` or id is missing or falsy are left out of the result.
    """
    unique_deployments = []
    known_ids = set()

    for entry in model_list:
        info = entry.get("model_info") or {}
        deployment_id = info.get("id")
        # No usable id, or an id we already kept: nothing to add.
        if not deployment_id or deployment_id in known_ids:
            continue
        known_ids.add(deployment_id)
        unique_deployments.append(entry)

    return unique_deployments
|
|
|
|
|
async def run_with_timeout(task, timeout):
    """
    Await `task`, giving up after `timeout` seconds.

    Args:
        task: The awaitable to run (coroutine, Task or Future).
        timeout: Maximum number of seconds to wait, as accepted by
            `asyncio.wait_for`.

    Returns:
        The result of `task`, or `{"error": "Timeout exceeded"}` if the
        timeout elapsed first.

    Raises:
        Any exception raised by `task` itself (other than a timeout) is
        propagated unchanged.
    """
    # Wrap the awaitable up front so we always hold a handle that supports
    # `.cancel()` / `.done()`; a bare coroutine object has neither.
    future = asyncio.ensure_future(task)
    try:
        return await asyncio.wait_for(future, timeout)
    except asyncio.TimeoutError:
        # `wait_for` cancels the awaited future on timeout. Cancel again only
        # as a safety net, and deliberately leave every other task in the
        # event loop alone: they do not belong to this health check.
        if not future.done():
            future.cancel()
        return {"error": "Timeout exceeded"}
|
|
|
|
|
async def _perform_health_check(model_list: list, details: Optional[bool] = True):
    """
    Perform a health check for each model in the list.

    Every deployment is checked concurrently via `litellm.ahealth_check`,
    each wrapped in `run_with_timeout`. The timeout is the deployment's
    `model_info["health_check_timeout"]` when set, otherwise
    `HEALTH_CHECK_TIMEOUT_SECONDS`.

    Args:
        model_list: Deployments to check; each must have a `litellm_params`
            dict and may have a `model_info` dict.
        details: Forwarded to `_clean_endpoint_data` to select how much of
            each endpoint's data is returned.

    Returns:
        A `(healthy_endpoints, unhealthy_endpoints)` tuple of lists. A check
        is healthy when its result is a dict without an "error" key; dict
        results with an "error" key and non-dict results (e.g. exceptions
        captured by `gather`) are unhealthy.
    """
    tasks = []
    # Params actually sent for each deployment, index-aligned with `tasks`.
    checked_params = []

    for model in model_list:
        # `or {}` also covers an explicit `"model_info": None`.
        model_info = model.get("model_info") or {}
        mode = model_info.get("mode", None)
        # Work on a shallow copy: the helper injects `messages` and may
        # override `model`, and those health-check-only values must not leak
        # into the caller's deployment config.
        litellm_params = _update_litellm_params_for_health_check(
            model_info, dict(model["litellm_params"])
        )
        timeout = model_info.get("health_check_timeout") or HEALTH_CHECK_TIMEOUT_SECONDS

        tasks.append(
            run_with_timeout(
                litellm.ahealth_check(
                    litellm_params,
                    mode=mode,
                    prompt="test from litellm",
                    input=["test from litellm"],
                ),
                timeout,
            )
        )
        checked_params.append(litellm_params)

    # return_exceptions=True: one failing check must not abort the others.
    results = await asyncio.gather(*tasks, return_exceptions=True)

    healthy_endpoints = []
    unhealthy_endpoints = []

    for is_healthy, litellm_params in zip(results, checked_params):
        if not isinstance(is_healthy, dict):
            # Typically an exception object returned by `gather`.
            unhealthy_endpoints.append(_clean_endpoint_data(litellm_params, details))
            continue

        endpoint = _clean_endpoint_data({**litellm_params, **is_healthy}, details)
        if "error" not in is_healthy:
            healthy_endpoints.append(endpoint)
        else:
            unhealthy_endpoints.append(endpoint)

    return healthy_endpoints, unhealthy_endpoints
|
|
|
|
|
def _update_litellm_params_for_health_check(
    model_info: dict, litellm_params: dict
) -> dict:
    """
    Prepare `litellm_params` for a health-check call.

    The dict is modified in place and the same object is returned:

    - `messages` is set to a short message from `_get_random_llm_message()`
    - `model` is replaced by `model_info["health_check_model"]` when that key
      is present and not None.
      Doc: https://docs.litellm.ai/docs/proxy/health#wildcard-routes
    """
    litellm_params["messages"] = _get_random_llm_message()

    override_model = model_info.get("health_check_model", None)
    if override_model is None:
        return litellm_params

    litellm_params["model"] = override_model
    return litellm_params
|
|
|
|
|
async def perform_health_check(
    model_list: list,
    model: Optional[str] = None,
    cli_model: Optional[str] = None,
    details: Optional[bool] = True,
):
    """
    Perform a health check on the configured deployments.

    Args:
        model_list: Deployments to check. When empty/None, a single
            deployment is synthesized from `cli_model` if that is given.
        model: Optional filter. Deployments whose `litellm_params["model"]`
            equals it are checked; if none match, deployments whose
            `model_name` equals it are checked instead.
        cli_model: Model name used only when `model_list` is empty.
        details: Forwarded to `_perform_health_check` to control how much
            endpoint data is returned.

    Returns:
        A `(healthy_endpoints, unhealthy_endpoints)` tuple of lists; both are
        empty when there is nothing to check.
    """
    if not model_list:
        if not cli_model:
            return [], []
        model_list = [
            {
                "model_name": cli_model,
                "litellm_params": {"model": cli_model},
                # `filter_deployments_by_id` below drops any deployment
                # without a `model_info` id, so give the synthesized entry
                # one; otherwise this fallback would never be checked.
                "model_info": {"id": cli_model},
            }
        ]

    if model is not None:
        matching = [x for x in model_list if x["litellm_params"]["model"] == model]
        if not matching:
            # Nothing matched the underlying model; fall back to the alias.
            matching = [x for x in model_list if x["model_name"] == model]
        model_list = matching

    # Check each deployment id only once.
    model_list = filter_deployments_by_id(model_list=model_list)

    return await _perform_health_check(model_list, details)
|
|