"""MLflow logging integration for litellm.

Logs litellm success and failure events as MLflow traces/spans via the
MLflow tracing client.
"""

import json
import threading
from typing import Optional

from litellm._logging import verbose_logger
from litellm.integrations.custom_logger import CustomLogger


class MlflowLogger(CustomLogger):
    def __init__(self):
        # Imported lazily so that mlflow is only required when this logger
        # is actually used.
        from mlflow.tracking import MlflowClient

        self._client = MlflowClient()

        # Maps litellm_call_id -> in-progress span for streaming requests.
        self._stream_id_to_span = {}
        self._lock = threading.Lock()

    def log_success_event(self, kwargs, response_obj, start_time, end_time):
        self._handle_success(kwargs, response_obj, start_time, end_time)

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        self._handle_success(kwargs, response_obj, start_time, end_time)

    def _handle_success(self, kwargs, response_obj, start_time, end_time):
        """
        Log the success event as an MLflow span.
        Note that this method is called asynchronously in a background thread.
        """
        from mlflow.entities import SpanStatusCode

        try:
            verbose_logger.debug("MLflow logging start for success event")

            if kwargs.get("stream"):
                self._handle_stream_event(kwargs, response_obj, start_time, end_time)
            else:
                span = self._start_span_or_trace(kwargs, start_time)
                # MLflow expects span timestamps in nanoseconds.
                end_time_ns = int(end_time.timestamp() * 1e9)
                self._extract_and_set_chat_attributes(span, kwargs, response_obj)
                self._end_span_or_trace(
                    span=span,
                    outputs=response_obj,
                    status=SpanStatusCode.OK,
                    end_time_ns=end_time_ns,
                )
        except Exception:
            verbose_logger.debug("MLflow Logging Error", stack_info=True)

    def _extract_and_set_chat_attributes(self, span, kwargs, response_obj):
        # The chat-attribute helpers are only available in recent versions of
        # mlflow; skip silently when they cannot be imported.
        try:
            from mlflow.tracing.utils import set_span_chat_messages, set_span_chat_tools
        except ImportError:
            return

        inputs = self._construct_input(kwargs)
        input_messages = inputs.get("messages", [])
        output_messages = [
            c.message.model_dump(exclude_none=True)
            for c in getattr(response_obj, "choices", [])
        ]
        if messages := [*input_messages, *output_messages]:
            set_span_chat_messages(span, messages)
        if tools := inputs.get("tools"):
            set_span_chat_tools(span, tools)

    def log_failure_event(self, kwargs, response_obj, start_time, end_time):
        self._handle_failure(kwargs, response_obj, start_time, end_time)

    async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
        self._handle_failure(kwargs, response_obj, start_time, end_time)

    def _handle_failure(self, kwargs, response_obj, start_time, end_time):
        """
        Log the failure event as an MLflow span.
        Note that this method is called *synchronously*, unlike the success handler.
        """
        from mlflow.entities import SpanEvent, SpanStatusCode

        try:
            span = self._start_span_or_trace(kwargs, start_time)
            end_time_ns = int(end_time.timestamp() * 1e9)

            # Record the exception (if any) as a span event so the stack
            # trace is visible on the span.
            if exception := kwargs.get("exception"):
                span.add_event(SpanEvent.from_exception(exception))

            self._extract_and_set_chat_attributes(span, kwargs, response_obj)
            self._end_span_or_trace(
                span=span,
                outputs=response_obj,
                status=SpanStatusCode.ERROR,
                end_time_ns=end_time_ns,
            )
        except Exception as e:
            verbose_logger.debug(f"MLflow Logging Error - {e}", stack_info=True)

    def _handle_stream_event(self, kwargs, response_obj, start_time, end_time):
        """
        Handle the success event for a streaming response. For streaming calls,
        the log_success_event handler is triggered for every chunk of the stream.
        We create a single span for the entire stream request as follows:

        1. For the first chunk, start a new span and store it in the map.
        2. For subsequent chunks, add the chunk as an event to the span.
        3. For the final chunk, end the span and remove it from the map.
        """
        from mlflow.entities import SpanStatusCode

        litellm_call_id = kwargs.get("litellm_call_id")

        # Double-checked locking: only the first chunk for a given call id
        # starts a span, even when chunks arrive on multiple threads.
        if litellm_call_id not in self._stream_id_to_span:
            with self._lock:
                if litellm_call_id not in self._stream_id_to_span:
                    span = self._start_span_or_trace(kwargs, start_time)
                    self._stream_id_to_span[litellm_call_id] = span

        span = self._stream_id_to_span[litellm_call_id]
        self._add_chunk_events(span, response_obj)

        # litellm attaches the assembled response under
        # "complete_streaming_response" only on the final chunk, so this
        # branch ends the span exactly once per stream.
        if final_response := kwargs.get("complete_streaming_response"):
            end_time_ns = int(end_time.timestamp() * 1e9)

            self._extract_and_set_chat_attributes(span, kwargs, final_response)
            self._end_span_or_trace(
                span=span,
                outputs=final_response,
                status=SpanStatusCode.OK,
                end_time_ns=end_time_ns,
            )

            with self._lock:
                self._stream_id_to_span.pop(litellm_call_id)

    def _add_chunk_events(self, span, response_obj):
        from mlflow.entities import SpanEvent

        try:
            for choice in response_obj.choices:
                # Each chunk's delta is serialized to JSON, e.g.
                # '{"role": "assistant", "content": "Hel"}' (shape illustrative).
                span.add_event(
                    SpanEvent(
                        name="streaming_chunk",
                        attributes={"delta": json.dumps(choice.delta.model_dump())},
                    )
                )
        except Exception:
            verbose_logger.debug("Error adding chunk events to span", stack_info=True)

    def _construct_input(self, kwargs):
        """Construct span inputs with optional parameters."""
        inputs = {"messages": kwargs.get("messages")}
        if tools := kwargs.get("tools"):
            inputs["tools"] = tools

        # Note: pop() removes these keys from the caller's optional_params
        # dict as they are recorded.
        for key in ["functions", "tools", "stream", "tool_choice", "user"]:
            if value := kwargs.get("optional_params", {}).pop(key, None):
                inputs[key] = value
        return inputs

    def _extract_attributes(self, kwargs):
        """
        Extract span attributes from kwargs.

        With the latest version of litellm, the standard_logging_object contains
        canonical information for logging. If it is not present, we extract a
        subset of attributes from other kwargs.
        """
        attributes = {
            "litellm_call_id": kwargs.get("litellm_call_id"),
            "call_type": kwargs.get("call_type"),
            "model": kwargs.get("model"),
        }

        standard_obj = kwargs.get("standard_logging_object")
        if standard_obj:
            attributes.update(
                {
                    "api_base": standard_obj.get("api_base"),
                    "cache_hit": standard_obj.get("cache_hit"),
                    "usage": {
                        "completion_tokens": standard_obj.get("completion_tokens"),
                        "prompt_tokens": standard_obj.get("prompt_tokens"),
                        "total_tokens": standard_obj.get("total_tokens"),
                    },
                    "raw_llm_response": standard_obj.get("response"),
                    "response_cost": standard_obj.get("response_cost"),
                    "saved_cache_cost": standard_obj.get("saved_cache_cost"),
                }
            )
        else:
            litellm_params = kwargs.get("litellm_params", {})
            attributes.update(
                {
                    "model": kwargs.get("model"),
                    "cache_hit": kwargs.get("cache_hit"),
                    "custom_llm_provider": kwargs.get("custom_llm_provider"),
                    "api_base": litellm_params.get("api_base"),
                    "response_cost": kwargs.get("response_cost"),
                }
            )
        return attributes

    def _get_span_type(self, call_type: Optional[str]) -> str:
        from mlflow.entities import SpanType

        if call_type in ["completion", "acompletion"]:
            return SpanType.LLM
        # litellm reports embedding calls as "embedding" / "aembedding".
        elif call_type in ["embedding", "aembedding"]:
            return SpanType.EMBEDDING
        else:
            return SpanType.LLM

    def _start_span_or_trace(self, kwargs, start_time):
        """
        Start an MLflow span or a trace.

        If there is an active span, we start a new span as a child of
        that span. Otherwise, we start a new trace.
        """
        import mlflow

        call_type = kwargs.get("call_type", "completion")
        span_name = f"litellm-{call_type}"
        span_type = self._get_span_type(call_type)
        start_time_ns = int(start_time.timestamp() * 1e9)

        inputs = self._construct_input(kwargs)
        attributes = self._extract_attributes(kwargs)

        if active_span := mlflow.get_current_active_span():
            # Attach as a child span of the caller's active trace.
            return self._client.start_span(
                name=span_name,
                request_id=active_span.request_id,
                parent_id=active_span.span_id,
                span_type=span_type,
                inputs=inputs,
                attributes=attributes,
                start_time_ns=start_time_ns,
            )
        else:
            # No active trace: start a new one rooted at this span.
            return self._client.start_trace(
                name=span_name,
                span_type=span_type,
                inputs=inputs,
                attributes=attributes,
                start_time_ns=start_time_ns,
            )
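
    # Illustrative example (not part of this module): when a litellm call runs
    # inside an active MLflow span, _start_span_or_trace attaches the litellm
    # span as a child of it; otherwise it opens a fresh trace.
    #
    #   with mlflow.start_span(name="my-agent"):           # parent span
    #       litellm.completion(model=..., messages=...)    # child "litellm-completion"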

    def _end_span_or_trace(self, span, outputs, end_time_ns, status):
        """End an MLflow span or a trace."""
        # A span without a parent is the root of its trace, so the whole
        # trace is ended; otherwise only this span is closed.
        if span.parent_id is None:
            self._client.end_trace(
                request_id=span.request_id,
                outputs=outputs,
                status=status,
                end_time_ns=end_time_ns,
            )
        else:
            self._client.end_span(
                request_id=span.request_id,
                span_id=span.span_id,
                outputs=outputs,
                status=status,
                end_time_ns=end_time_ns,
            )
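
# Usage sketch (illustrative; the model name and prompt are placeholders):
# MlflowLogger is a litellm CustomLogger, so an instance can be registered on
# litellm.callbacks, after which completion calls are logged to MLflow.
#
#   import litellm
#
#   litellm.callbacks = [MlflowLogger()]
#   litellm.completion(
#       model="gpt-4o-mini",
#       messages=[{"role": "user", "content": "Hello!"}],
#   )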