File size: 6,177 Bytes
e3278e4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
import asyncio
import traceback
from typing import Optional, Union, cast

import litellm
from litellm._logging import verbose_proxy_logger
from litellm.litellm_core_utils.core_helpers import (
    _get_parent_otel_span_from_kwargs,
    get_litellm_metadata_from_kwargs,
)
from litellm.proxy.auth.auth_checks import log_db_metrics
from litellm.types.utils import StandardLoggingPayload
from litellm.utils import get_end_user_id_for_cost_tracking


# Strong references to fire-and-forget tasks spawned by the cost callback.
# The asyncio event loop keeps only weak references to tasks, so a task whose
# result is discarded can be garbage-collected before it finishes.
_background_tasks: set = set()


def _schedule_background_task(coro) -> asyncio.Task:
    """Schedule *coro* on the running loop and keep it alive until it completes."""
    task = asyncio.create_task(coro)
    _background_tasks.add(task)
    # Drop our reference once the task is done so the set does not grow forever.
    task.add_done_callback(_background_tasks.discard)
    return task


@log_db_metrics
async def _PROXY_track_cost_callback(
    kwargs,  # kwargs to completion
    completion_response: litellm.ModelResponse,  # response from completion
    start_time=None,
    end_time=None,  # start/end time for completion
):
    """
    Record the spend of a completed LLM call.

    The response cost is read from ``kwargs["standard_logging_object"]``
    (falling back to ``kwargs["response_cost"]``) and forced to 0.0 on cache
    hits. When at least one of key / user / team / end-user is attached to the
    request, the spend is written via ``update_database``, a cache update is
    scheduled via ``update_cache``, and the customer spend alert is sent.

    If no cost is available for a call that should have one (a non-streaming
    call, or a streaming call whose ``complete_streaming_response`` is present),
    a "Cost tracking failed" error is raised internally.

    Exceptions raised while tracking are caught here, logged, and reported
    through ``proxy_logging_obj.failed_tracking_alert`` (scheduled as a task).

    Args:
        kwargs: kwargs of the completion call as passed to logging callbacks.
        completion_response: response returned by the completion call.
        start_time: start time of the completion call.
        end_time: end time of the completion call.
    """
    from litellm.proxy.proxy_server import (
        prisma_client,
        proxy_logging_obj,
        update_cache,
        update_database,
    )

    verbose_proxy_logger.debug("INSIDE _PROXY_track_cost_callback")
    try:
        # Lazy %-formatting: the (potentially large) streaming response is only
        # stringified when debug logging is actually enabled.
        verbose_proxy_logger.debug(
            "kwargs stream: %s + complete streaming response: %s",
            kwargs.get("stream", None),
            kwargs.get("complete_streaming_response", None),
        )
        parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs=kwargs)
        litellm_params = kwargs.get("litellm_params", {}) or {}
        end_user_id = get_end_user_id_for_cost_tracking(litellm_params)
        metadata = get_litellm_metadata_from_kwargs(kwargs=kwargs)
        user_id = cast(Optional[str], metadata.get("user_api_key_user_id", None))
        team_id = cast(Optional[str], metadata.get("user_api_key_team_id", None))
        org_id = cast(Optional[str], metadata.get("user_api_key_org_id", None))
        key_alias = cast(Optional[str], metadata.get("user_api_key_alias", None))
        end_user_max_budget = metadata.get("user_api_end_user_max_budget", None)
        sl_object: Optional[StandardLoggingPayload] = kwargs.get(
            "standard_logging_object", None
        )
        response_cost = (
            sl_object.get("response_cost", None)
            if sl_object is not None
            else kwargs.get("response_cost", None)
        )

        if response_cost is not None:
            user_api_key = metadata.get("user_api_key", None)
            if kwargs.get("cache_hit", False) is True:
                # Cache hits are still recorded below, but with zero spend.
                response_cost = 0.0
                verbose_proxy_logger.info(
                    "Cache Hit: response_cost %s, for user_id %s",
                    response_cost,
                    user_id,
                )

            verbose_proxy_logger.debug(
                "user_api_key %s, prisma_client: %s", user_api_key, prisma_client
            )
            if not _should_track_cost_callback(
                user_api_key=user_api_key,
                user_id=user_id,
                team_id=team_id,
                end_user_id=end_user_id,
            ):
                raise Exception(
                    "User API key and team id and user id missing from custom callback."
                )

            ## UPDATE DATABASE
            await update_database(
                token=user_api_key,
                response_cost=response_cost,
                user_id=user_id,
                end_user_id=end_user_id,
                team_id=team_id,
                kwargs=kwargs,
                completion_response=completion_response,
                start_time=start_time,
                end_time=end_time,
                org_id=org_id,
            )

            # update cache (fire-and-forget, but keep a reference to the task)
            _schedule_background_task(
                update_cache(
                    token=user_api_key,
                    user_id=user_id,
                    end_user_id=end_user_id,
                    response_cost=response_cost,
                    team_id=team_id,
                    parent_otel_span=parent_otel_span,
                )
            )

            await proxy_logging_obj.slack_alerting_instance.customer_spend_alert(
                token=user_api_key,
                key_alias=key_alias,
                end_user_id=end_user_id,
                response_cost=response_cost,
                max_budget=end_user_max_budget,
            )
        else:
            # `.get` instead of indexing: a missing "stream" key previously
            # raised KeyError and masked the real cost-tracking failure.
            stream = kwargs.get("stream")
            # A streaming call only counts as a failure once its
            # `complete_streaming_response` is present; presumably before that
            # the stream is still in progress and no cost is expected yet.
            if stream is not True or "complete_streaming_response" in kwargs:
                if sl_object is not None:
                    cost_tracking_failure_debug_info: Union[dict, str] = (
                        sl_object.get("response_cost_failure_debug_info")  # type: ignore
                        or "response_cost_failure_debug_info is None in standard_logging_object"
                    )
                else:
                    cost_tracking_failure_debug_info = (
                        "standard_logging_object not found"
                    )
                model = kwargs.get("model")
                raise Exception(
                    f"Cost tracking failed for model={model}.\nDebug info - {cost_tracking_failure_debug_info}\nAdd custom pricing - https://docs.litellm.ai/docs/proxy/custom_pricing"
                )
    except Exception as e:
        error_msg = f"Error in tracking cost callback - {str(e)}\n Traceback:{traceback.format_exc()}"
        model = kwargs.get("model", "")
        # `litellm_params` may be present but None (hence the `or {}`), which
        # would otherwise make this error handler itself raise AttributeError.
        metadata = (kwargs.get("litellm_params", {}) or {}).get("metadata", {})
        error_msg += f"\n Args to _PROXY_track_cost_callback\n model: {model}\n metadata: {metadata}\n"
        _schedule_background_task(
            proxy_logging_obj.failed_tracking_alert(
                error_message=error_msg,
                failing_model=model,
            )
        )
        verbose_proxy_logger.exception("Error in tracking cost callback - %s", str(e))


def _should_track_cost_callback(
    user_api_key: Optional[str],
    user_id: Optional[str],
    team_id: Optional[str],
    end_user_id: Optional[str],
) -> bool:
    """
    Determine if the cost callback should be tracked based on the kwargs
    """
    if (
        user_api_key is not None
        or user_id is not None
        or team_id is not None
        or end_user_id is not None
    ):
        return True
    return False