import json
import threading
from typing import Optional

from litellm._logging import verbose_logger
from litellm.integrations.custom_logger import CustomLogger

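# A minimal registration sketch (hedged: this assumes the standard litellm
# callback mechanism; the experiment name and model are arbitrary examples):
#
#     import litellm
#     import mlflow
#
#     mlflow.set_experiment("litellm-traces")  # traces land in this experiment
#     litellm.callbacks = [MlflowLogger()]     # register an instance of this logger
#     litellm.completion(
#         model="gpt-4o-mini",
#         messages=[{"role": "user", "content": "Hello"}],
#     )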

class MlflowLogger(CustomLogger):
    """Log LiteLLM success and failure events as MLflow traces and spans."""

    def __init__(self):
        # Imported lazily so that MLflow remains an optional dependency.
        from mlflow.tracking import MlflowClient

        self._client = MlflowClient()

        self._stream_id_to_span = {}
        self._lock = threading.Lock()  # lock for _stream_id_to_span

    def log_success_event(self, kwargs, response_obj, start_time, end_time):
        self._handle_success(kwargs, response_obj, start_time, end_time)

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        self._handle_success(kwargs, response_obj, start_time, end_time)

    def _handle_success(self, kwargs, response_obj, start_time, end_time):
        """
        Log the success event as an MLflow span.
        Note that this method is called asynchronously in a background thread.
        """
        from mlflow.entities import SpanStatusCode

        try:
            verbose_logger.debug("MLflow logging start for success event")

            if kwargs.get("stream"):
                self._handle_stream_event(kwargs, response_obj, start_time, end_time)
            else:
                span = self._start_span_or_trace(kwargs, start_time)
                end_time_ns = int(end_time.timestamp() * 1e9)  # MLflow expects nanoseconds
                self._extract_and_set_chat_attributes(span, kwargs, response_obj)
                self._end_span_or_trace(
                    span=span,
                    outputs=response_obj,
                    status=SpanStatusCode.OK,
                    end_time_ns=end_time_ns,
                )
        except Exception as e:
            verbose_logger.debug(f"MLflow Logging Error - {e}", stack_info=True)

    def _extract_and_set_chat_attributes(self, span, kwargs, response_obj):
        try:
            from mlflow.tracing.utils import set_span_chat_messages, set_span_chat_tools
        except ImportError:
            return

        inputs = self._construct_input(kwargs)
        input_messages = inputs.get("messages", [])
        output_messages = [
            c.message.model_dump(exclude_none=True)
            for c in getattr(response_obj, "choices", [])
        ]
        if messages := [*input_messages, *output_messages]:
            set_span_chat_messages(span, messages)
        if tools := inputs.get("tools"):
            set_span_chat_tools(span, tools)

    def log_failure_event(self, kwargs, response_obj, start_time, end_time):
        self._handle_failure(kwargs, response_obj, start_time, end_time)

    async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
        self._handle_failure(kwargs, response_obj, start_time, end_time)

    def _handle_failure(self, kwargs, response_obj, start_time, end_time):
        """
        Log the failure event as an MLflow span.
        Note that this method is called *synchronously* unlike the success handler.
        """
        from mlflow.entities import SpanEvent, SpanStatusCode

        try:
            span = self._start_span_or_trace(kwargs, start_time)

            end_time_ns = int(end_time.timestamp() * 1e9)

            # Record exception info as event
            if exception := kwargs.get("exception"):
                span.add_event(SpanEvent.from_exception(exception))  # type: ignore

            self._extract_and_set_chat_attributes(span, kwargs, response_obj)
            self._end_span_or_trace(
                span=span,
                outputs=response_obj,
                status=SpanStatusCode.ERROR,
                end_time_ns=end_time_ns,
            )

        except Exception as e:
            verbose_logger.debug(f"MLflow Logging Error - {e}", stack_info=True)

    def _handle_stream_event(self, kwargs, response_obj, start_time, end_time):
        """
        Handle the success event for a streaming response. For streaming calls,
        the log_success_event handler is triggered for every chunk of the stream.
        We create a single span for the entire stream request as follows:

        1. For the first chunk, start a new span and store it in the map.
        2. For subsequent chunks, add the chunk as an event to the span.
        3. For the final chunk, end the span and remove the span from the map.
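
        A consumption sketch (hedged; the model name is an arbitrary example).
        The per-chunk callback fires only as the caller iterates the stream:

            resp = litellm.completion(
                model="gpt-4o-mini",
                messages=[{"role": "user", "content": "hi"}],
                stream=True,
            )
            for chunk in resp:  # each chunk triggers log_success_event
                ...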
        """
        from mlflow.entities import SpanStatusCode

        litellm_call_id = kwargs.get("litellm_call_id")

        if litellm_call_id not in self._stream_id_to_span:
            with self._lock:
                # Check again after acquiring lock
                if litellm_call_id not in self._stream_id_to_span:
                    # Start a new span for the first chunk of the stream
                    span = self._start_span_or_trace(kwargs, start_time)
                    self._stream_id_to_span[litellm_call_id] = span

        # Add chunk as event to the span
        span = self._stream_id_to_span[litellm_call_id]
        self._add_chunk_events(span, response_obj)

        # If this is the final chunk, end the span. Only the final chunk
        # carries complete_streaming_response, which aggregates the full response.
        if final_response := kwargs.get("complete_streaming_response"):
            end_time_ns = int(end_time.timestamp() * 1e9)

            self._extract_and_set_chat_attributes(span, kwargs, final_response)
            self._end_span_or_trace(
                span=span,
                outputs=final_response,
                status=SpanStatusCode.OK,
                end_time_ns=end_time_ns,
            )

            # Remove the stream_id from the map
            with self._lock:
                self._stream_id_to_span.pop(litellm_call_id)

    def _add_chunk_events(self, span, response_obj):
        from mlflow.entities import SpanEvent

        try:
            for choice in response_obj.choices:
                span.add_event(
                    SpanEvent(
                        name="streaming_chunk",
                        attributes={"delta": json.dumps(choice.delta.model_dump())},
                    )
                )
        except Exception:
            verbose_logger.debug("Error adding chunk events to span", stack_info=True)

    def _construct_input(self, kwargs):
        """Construct span inputs from the call arguments and optional parameters."""
        inputs = {"messages": kwargs.get("messages")}
        if tools := kwargs.get("tools"):
            inputs["tools"] = tools

        # Lift selected optional params into the span inputs.
        # Note: pop() removes each key from kwargs["optional_params"] as a side effect.
        for key in ["functions", "tools", "stream", "tool_choice", "user"]:
            if value := kwargs.get("optional_params", {}).pop(key, None):
                inputs[key] = value
        return inputs

    def _extract_attributes(self, kwargs):
        """
        Extract span attributes from kwargs.

        In recent versions of litellm, the standard_logging_object contains
        canonical information for logging. If it is not present, we extract a
        subset of attributes from the other kwargs.
        """
        attributes = {
            "litellm_call_id": kwargs.get("litellm_call_id"),
            "call_type": kwargs.get("call_type"),
            "model": kwargs.get("model"),
        }
        standard_obj = kwargs.get("standard_logging_object")
        if standard_obj:
            attributes.update(
                {
                    "api_base": standard_obj.get("api_base"),
                    "cache_hit": standard_obj.get("cache_hit"),
                    "usage": {
                        "completion_tokens": standard_obj.get("completion_tokens"),
                        "prompt_tokens": standard_obj.get("prompt_tokens"),
                        "total_tokens": standard_obj.get("total_tokens"),
                    },
                    "raw_llm_response": standard_obj.get("response"),
                    "response_cost": standard_obj.get("response_cost"),
                    "saved_cache_cost": standard_obj.get("saved_cache_cost"),
                }
            )
        else:
            litellm_params = kwargs.get("litellm_params", {})
            attributes.update(
                {
                    "model": kwargs.get("model"),
                    "cache_hit": kwargs.get("cache_hit"),
                    "custom_llm_provider": kwargs.get("custom_llm_provider"),
                    "api_base": litellm_params.get("api_base"),
                    "response_cost": kwargs.get("response_cost"),
                }
            )
        return attributes

    def _get_span_type(self, call_type: Optional[str]) -> str:
        from mlflow.entities import SpanType

        if call_type in ["completion", "acompletion"]:
            return SpanType.LLM
        elif call_type == "embeddings":
            return SpanType.EMBEDDING
        else:
            # Default to LLM for any other call type.
            return SpanType.LLM

    def _start_span_or_trace(self, kwargs, start_time):
        """
        Start an MLflow span or a trace.

        If there is an active span, we start a new span as a child of
        that span. Otherwise, we start a new trace.
        """
        import mlflow

        call_type = kwargs.get("call_type", "completion")
        span_name = f"litellm-{call_type}"
        span_type = self._get_span_type(call_type)
        start_time_ns = int(start_time.timestamp() * 1e9)

        inputs = self._construct_input(kwargs)
        attributes = self._extract_attributes(kwargs)

        if active_span := mlflow.get_current_active_span():  # type: ignore
            return self._client.start_span(
                name=span_name,
                request_id=active_span.request_id,
                parent_id=active_span.span_id,
                span_type=span_type,
                inputs=inputs,
                attributes=attributes,
                start_time_ns=start_time_ns,
            )
        else:
            return self._client.start_trace(
                name=span_name,
                span_type=span_type,
                inputs=inputs,
                attributes=attributes,
                start_time_ns=start_time_ns,
            )

    def _end_span_or_trace(self, span, outputs, end_time_ns, status):
        """End an MLflow span or a trace."""
        # A span with no parent is the root span; closing it ends the whole trace.
        if span.parent_id is None:
            self._client.end_trace(
                request_id=span.request_id,
                outputs=outputs,
                status=status,
                end_time_ns=end_time_ns,
            )
        else:
            self._client.end_span(
                request_id=span.request_id,
                span_id=span.span_id,
                outputs=outputs,
                status=status,
                end_time_ns=end_time_ns,
            )
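
# Nesting sketch (hedged): because _start_span_or_trace consults
# mlflow.get_current_active_span(), a LiteLLM call made inside an explicit
# MLflow span is recorded as a child span rather than a separate trace.
# The span and model names below are arbitrary examples:
#
#     import litellm
#     import mlflow
#
#     litellm.callbacks = [MlflowLogger()]
#     with mlflow.start_span(name="agent"):  # parent span
#         litellm.completion(
#             model="gpt-4o-mini",
#             messages=[{"role": "user", "content": "Hello"}],
#         )  # recorded as a child of "agent"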