|
|
|
|
|
import traceback |
|
from typing import TYPE_CHECKING, Any, List, Literal, Optional, Tuple, Union |
|
|
|
from pydantic import BaseModel |
|
|
|
from litellm.caching.caching import DualCache |
|
from litellm.proxy._types import UserAPIKeyAuth |
|
from litellm.types.integrations.argilla import ArgillaItem |
|
from litellm.types.llms.openai import AllMessageValues, ChatCompletionRequest |
|
from litellm.types.utils import ( |
|
AdapterCompletionStreamWrapper, |
|
EmbeddingResponse, |
|
ImageResponse, |
|
ModelResponse, |
|
StandardCallbackDynamicParams, |
|
StandardLoggingPayload, |
|
) |
|
|
|
# Alias `Span` to the real OpenTelemetry type only for static analysis;
# at runtime fall back to `Any` so opentelemetry is not a hard dependency.
if TYPE_CHECKING:
    from opentelemetry.trace import Span as _Span

    Span = _Span
else:
    # opentelemetry may not be installed; annotations using `Span` still work.
    Span = Any
|
|
|
|
|
class CustomLogger:
    """
    Base class for custom callback loggers.

    Subclass this and override only the hooks you need: every hook below is a
    no-op (or identity pass-through) by default. Hooks cover the full request
    lifecycle — pre-call, success/failure logging (sync and async), prompt
    management, routing checks, adapter translation, proxy hooks, and
    logging-payload truncation.
    """

    def __init__(self, message_logging: bool = True) -> None:
        # Stored for integrations to consult (e.g. to redact message content
        # from their logs); this base class does not act on it itself.
        self.message_logging = message_logging

    #### SYNC LOGGING HOOKS ####

    def log_pre_api_call(self, model, messages, kwargs):
        """Called before the LLM API request is sent."""
        pass

    def log_post_api_call(self, kwargs, response_obj, start_time, end_time):
        """Called right after the LLM API request returns (success or failure)."""
        pass

    def log_stream_event(self, kwargs, response_obj, start_time, end_time):
        """Called once per streamed chunk."""
        pass

    def log_success_event(self, kwargs, response_obj, start_time, end_time):
        """Called when a call completes successfully."""
        pass

    def log_failure_event(self, kwargs, response_obj, start_time, end_time):
        """Called when a call raises an error."""
        pass

    #### ASYNC LOGGING HOOKS ####

    async def async_log_stream_event(self, kwargs, response_obj, start_time, end_time):
        """Async variant of `log_stream_event`."""
        pass

    async def async_log_pre_api_call(self, model, messages, kwargs):
        """Async variant of `log_pre_api_call`."""
        pass

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        """Async variant of `log_success_event`."""
        pass

    async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
        """Async variant of `log_failure_event`."""
        pass

    #### PROMPT MANAGEMENT HOOKS ####

    async def async_get_chat_completion_prompt(
        self,
        model: str,
        messages: List[AllMessageValues],
        non_default_params: dict,
        prompt_id: str,
        prompt_variables: Optional[dict],
        dynamic_callback_params: StandardCallbackDynamicParams,
    ) -> Tuple[str, List[AllMessageValues], dict]:
        """
        Returns:
        - model: str - the model to use (can be pulled from prompt management tool)
        - messages: List[AllMessageValues] - the messages to use (can be pulled from prompt management tool)
        - non_default_params: dict - update with any optional params (e.g. temperature, max_tokens, etc.) to use (can be pulled from prompt management tool)
        """
        # Default: identity pass-through — no prompt-management tool consulted.
        return model, messages, non_default_params

    def get_chat_completion_prompt(
        self,
        model: str,
        messages: List[AllMessageValues],
        non_default_params: dict,
        prompt_id: str,
        prompt_variables: Optional[dict],
        dynamic_callback_params: StandardCallbackDynamicParams,
    ) -> Tuple[str, List[AllMessageValues], dict]:
        """
        Returns:
        - model: str - the model to use (can be pulled from prompt management tool)
        - messages: List[AllMessageValues] - the messages to use (can be pulled from prompt management tool)
        - non_default_params: dict - update with any optional params (e.g. temperature, max_tokens, etc.) to use (can be pulled from prompt management tool)
        """
        # Sync variant of `async_get_chat_completion_prompt`; identity by default.
        return model, messages, non_default_params

    #### ROUTING HOOKS ####
    # Allows usage-based-routing-v2 to run pre-call rpm checks within the
    # picked deployment's semaphore (concurrency-safe tpm/rpm checks).

    async def async_filter_deployments(
        self,
        model: str,
        healthy_deployments: List,
        messages: Optional[List[AllMessageValues]],
        request_kwargs: Optional[dict] = None,
        parent_otel_span: Optional[Span] = None,
    ) -> List[dict]:
        """Filter the candidate deployment list; default keeps all of them."""
        return healthy_deployments

    async def async_pre_call_check(
        self, deployment: dict, parent_otel_span: Optional[Span]
    ) -> Optional[dict]:
        """Run a concurrency-safe check on the chosen deployment before the call."""
        pass

    def pre_call_check(self, deployment: dict) -> Optional[dict]:
        """Sync variant of `async_pre_call_check`."""
        pass

    async def log_model_group_rate_limit_error(
        self, exception: Exception, original_model_group: Optional[str], kwargs: dict
    ):
        """Called when a model group hits a rate-limit error."""
        pass

    async def log_success_fallback_event(
        self, original_model_group: str, kwargs: dict, original_exception: Exception
    ):
        """Called when a fallback attempt succeeds after `original_exception`."""
        pass

    async def log_failure_fallback_event(
        self, original_model_group: str, kwargs: dict, original_exception: Exception
    ):
        """Called when a fallback attempt also fails."""
        pass

    #### ADAPTER HOOKS ####

    def translate_completion_input_params(
        self, kwargs
    ) -> Optional[ChatCompletionRequest]:
        """
        Translates the input params, from the provider's native format to the litellm.completion() format.
        """
        pass

    def translate_completion_output_params(
        self, response: ModelResponse
    ) -> Optional[BaseModel]:
        """
        Translates the output params, from the OpenAI format to the custom format.
        """
        pass

    def translate_completion_output_params_streaming(
        self, completion_stream: Any
    ) -> Optional[AdapterCompletionStreamWrapper]:
        """
        Translates the streaming chunk, from the OpenAI format to the custom format.
        """
        pass

    #### DATASET HOOKS ####

    async def async_dataset_hook(
        self,
        logged_item: ArgillaItem,
        standard_logging_payload: Optional[StandardLoggingPayload],
    ) -> Optional[ArgillaItem]:
        """
        - Decide if the result should be logged to Argilla.
        - Modify the result before logging to Argilla.
        - Return None if the result should not be logged to Argilla.
        """
        raise NotImplementedError("async_dataset_hook not implemented")

    #### PROXY (CALL) HOOKS ####
    # Control / modify incoming and outgoing data before calling the model.

    async def async_pre_call_hook(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        cache: DualCache,
        data: dict,
        call_type: Literal[
            "completion",
            "text_completion",
            "embeddings",
            "image_generation",
            "moderation",
            "audio_transcription",
            "pass_through_endpoint",
            "rerank",
        ],
    ) -> Optional[
        Union[Exception, str, dict]
    ]:
        """Inspect/modify the request before the model call; may reject it."""
        pass

    async def async_post_call_failure_hook(
        self,
        request_data: dict,
        original_exception: Exception,
        user_api_key_dict: UserAPIKeyAuth,
    ):
        """Called after a proxy call fails."""
        pass

    async def async_post_call_success_hook(
        self,
        data: dict,
        user_api_key_dict: UserAPIKeyAuth,
        response: Union[Any, ModelResponse, EmbeddingResponse, ImageResponse],
    ) -> Any:
        """Called after a proxy call succeeds; may inspect/modify the response."""
        pass

    async def async_logging_hook(
        self, kwargs: dict, result: Any, call_type: str
    ) -> Tuple[dict, Any]:
        """For masking logged request/response. Return a modified version of the request/result."""
        return kwargs, result

    def logging_hook(
        self, kwargs: dict, result: Any, call_type: str
    ) -> Tuple[dict, Any]:
        """For masking logged request/response. Return a modified version of the request/result."""
        return kwargs, result

    async def async_moderation_hook(
        self,
        data: dict,
        user_api_key_dict: UserAPIKeyAuth,
        call_type: Literal[
            "completion",
            "embeddings",
            "image_generation",
            "moderation",
            "audio_transcription",
        ],
    ) -> Any:
        """Run moderation on the request; may raise to block it."""
        pass

    async def async_post_call_streaming_hook(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        response: str,
    ) -> Any:
        """Called per streamed response chunk on the proxy."""
        pass

    #### SINGLE-FUNCTION CALLBACK COMPATIBILITY ####

    def log_input_event(self, model, messages, kwargs, print_verbose, callback_func):
        """Invoke a plain-function callback with pre-call details; never raises."""
        try:
            kwargs["model"] = model
            kwargs["messages"] = messages
            kwargs["log_event_type"] = "pre_api_call"
            callback_func(
                kwargs,
            )
            print_verbose(f"Custom Logger - model call details: {kwargs}")
        except Exception:
            # Logging must never break the request path — report and swallow.
            print_verbose(f"Custom Logger Error - {traceback.format_exc()}")

    async def async_log_input_event(
        self, model, messages, kwargs, print_verbose, callback_func
    ):
        """Async variant of `log_input_event`; `callback_func` is awaited."""
        try:
            kwargs["model"] = model
            kwargs["messages"] = messages
            kwargs["log_event_type"] = "pre_api_call"
            await callback_func(
                kwargs,
            )
            print_verbose(f"Custom Logger - model call details: {kwargs}")
        except Exception:
            # Logging must never break the request path — report and swallow.
            print_verbose(f"Custom Logger Error - {traceback.format_exc()}")

    def log_event(
        self, kwargs, response_obj, start_time, end_time, print_verbose, callback_func
    ):
        """Invoke a plain-function callback with post-call details; never raises."""
        try:
            kwargs["log_event_type"] = "post_api_call"
            callback_func(
                kwargs,
                response_obj,
                start_time,
                end_time,
            )
        except Exception:
            # Logging must never break the request path — report and swallow.
            print_verbose(f"Custom Logger Error - {traceback.format_exc()}")

    async def async_log_event(
        self, kwargs, response_obj, start_time, end_time, print_verbose, callback_func
    ):
        """Async variant of `log_event`; `callback_func` is awaited."""
        try:
            kwargs["log_event_type"] = "post_api_call"
            await callback_func(
                kwargs,
                response_obj,
                start_time,
                end_time,
            )
        except Exception:
            # Logging must never break the request path — report and swallow.
            print_verbose(f"Custom Logger Error - {traceback.format_exc()}")

    #### PAYLOAD TRUNCATION HELPERS ####

    def truncate_standard_logging_payload_content(
        self,
        standard_logging_object: StandardLoggingPayload,
    ):
        """
        Truncate error strings and message content in logging payload

        Some loggers like DataDog/ GCS Bucket have a limit on the size of the payload. (1MB)

        This function truncates the error string and the message content if they exceed a certain length.
        """
        MAX_STR_LENGTH = 10_000

        fields_to_truncate = ["error_str", "messages", "response"]
        for field in fields_to_truncate:
            self._truncate_field(
                standard_logging_object=standard_logging_object,
                field_name=field,
                max_length=MAX_STR_LENGTH,
            )

    def _truncate_field(
        self,
        standard_logging_object: StandardLoggingPayload,
        field_name: str,
        max_length: int,
    ) -> None:
        """
        Helper function to truncate a field in the logging payload

        This converts the field to a string and then truncates it if it exceeds the max length.

        Why convert to string ?
        1. User was sending a poorly formatted list for `messages` field, we could not predict where they would send content
            - Converting to string and then truncating the logged content catches this
        2. We want to avoid modifying the original `messages`, `response`, and `error_str` in the logging payload since these are in kwargs and could be returned to the user
        """
        field_value = standard_logging_object.get(field_name)
        if field_value:
            str_value = str(field_value)
            if len(str_value) > max_length:
                standard_logging_object[field_name] = self._truncate_text(
                    text=str_value, max_length=max_length
                )

    def _truncate_text(self, text: str, max_length: int) -> str:
        """Truncate text if it exceeds max_length"""
        return (
            text[:max_length]
            + "...truncated by litellm, this logger does not support large content"
            if len(text) > max_length
            else text
        )
|
|