"""
BETA

This is the Pub/Sub logger for GCS PubSub; it sends LiteLLM SpendLogs payloads to a Google Cloud Pub/Sub topic.

Users can use this instead of sending their SpendLogs to their Postgres database.
"""

import asyncio
import base64
import json
import os
import traceback
from typing import TYPE_CHECKING, Any, Dict, List, Optional

if TYPE_CHECKING:
    from litellm.proxy._types import SpendLogsPayload
else:
    SpendLogsPayload = Any

from litellm._logging import verbose_logger
from litellm.integrations.custom_batch_logger import CustomBatchLogger
from litellm.llms.custom_httpx.http_handler import (
    get_async_httpx_client,
    httpxSpecialProvider,
)


class GcsPubSubLogger(CustomBatchLogger):
    def __init__(
        self,
        project_id: Optional[str] = None,
        topic_id: Optional[str] = None,
        credentials_path: Optional[str] = None,
        **kwargs,
    ):
        """
        Initialize Google Cloud Pub/Sub publisher

        Args:
            project_id (str, optional): Google Cloud project ID (falls back to the GCS_PUBSUB_PROJECT_ID env var)
            topic_id (str, optional): Pub/Sub topic ID (falls back to the GCS_PUBSUB_TOPIC_ID env var)
            credentials_path (str, optional): Path to Google Cloud credentials JSON file (falls back to the GCS_PATH_SERVICE_ACCOUNT env var)
        """
        from litellm.proxy.utils import _premium_user_check

        _premium_user_check()

        self.async_httpx_client = get_async_httpx_client(
            llm_provider=httpxSpecialProvider.LoggingCallback
        )

        self.project_id = project_id or os.getenv("GCS_PUBSUB_PROJECT_ID")
        self.topic_id = topic_id or os.getenv("GCS_PUBSUB_TOPIC_ID")
        self.path_service_account_json = credentials_path or os.getenv(
            "GCS_PATH_SERVICE_ACCOUNT"
        )

        if not self.project_id or not self.topic_id:
            raise ValueError("Both project_id and topic_id must be provided")

        self.flush_lock = asyncio.Lock()
        super().__init__(**kwargs, flush_lock=self.flush_lock)
        self.log_queue: List[SpendLogsPayload] = []
        asyncio.create_task(self.periodic_flush())

    async def construct_request_headers(self) -> Dict[str, str]:
        """Construct authorization headers using Vertex AI auth"""
        from litellm import vertex_chat_completion

        _auth_header, vertex_project = (
            await vertex_chat_completion._ensure_access_token_async(
                credentials=self.path_service_account_json,
                project_id=None,
                custom_llm_provider="vertex_ai",
            )
        )

        auth_header, _ = vertex_chat_completion._get_token_and_url(
            model="pub-sub",
            auth_header=_auth_header,
            vertex_credentials=self.path_service_account_json,
            vertex_project=vertex_project,
            vertex_location=None,
            gemini_api_key=None,
            stream=None,
            custom_llm_provider="vertex_ai",
            api_base=None,
        )

        headers = {
            "Authorization": f"Bearer {auth_header}",
            "Content-Type": "application/json",
        }
        return headers

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        """
        Async Log success events to GCS PubSub Topic

        - Creates a SpendLogsPayload
        - Adds to batch queue
        - Flushes based on CustomBatchLogger settings

        Raises:
            Nothing; errors are caught and logged via a non-blocking
            verbose_logger.exception
        """
        from litellm.proxy.spend_tracking.spend_tracking_utils import (
            get_logging_payload,
        )
        from litellm.proxy.utils import _premium_user_check

        _premium_user_check()

        try:
            verbose_logger.debug(
                "PubSub: Logging - Enters logging function for model %s",
                kwargs.get("model"),
            )
            spend_logs_payload = get_logging_payload(
                kwargs=kwargs,
                response_obj=response_obj,
                start_time=start_time,
                end_time=end_time,
            )
            self.log_queue.append(spend_logs_payload)

            if len(self.log_queue) >= self.batch_size:
                await self.async_send_batch()

        except Exception as e:
            verbose_logger.exception(
                f"PubSub Layer Error - {str(e)}\n{traceback.format_exc()}"
            )

    async def async_send_batch(self):
        """
        Sends the batch of messages to Pub/Sub
        """
        try:
            if not self.log_queue:
                return

            verbose_logger.debug(
                f"PubSub - about to flush {len(self.log_queue)} events"
            )

            for message in self.log_queue:
                await self.publish_message(message)

        except Exception as e:
            verbose_logger.exception(
                f"PubSub Error sending batch - {str(e)}\n{traceback.format_exc()}"
            )
        finally:
            self.log_queue.clear()

    async def publish_message(
        self, message: SpendLogsPayload
    ) -> Optional[Dict[str, Any]]:
        """
        Publish message to Google Cloud Pub/Sub using REST API

        Args:
            message: Message to publish (SpendLogsPayload dict or string)

        Returns:
            Optional[dict]: Pub/Sub publish response, or None if publishing failed
        """
        try:
            headers = await self.construct_request_headers()

            # Prepare message data
            if isinstance(message, str):
                message_data = message
            else:
                message_data = json.dumps(message, default=str)

            # Base64-encode the message, as required by the Pub/Sub REST API

            encoded_message = base64.b64encode(message_data.encode("utf-8")).decode(
                "utf-8"
            )

            # Construct request body
            request_body = {"messages": [{"data": encoded_message}]}

            url = f"https://pubsub.googleapis.com/v1/projects/{self.project_id}/topics/{self.topic_id}:publish"

            response = await self.async_httpx_client.post(
                url=url, headers=headers, json=request_body
            )
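            # On success the Pub/Sub v1 publish endpoint should return a JSON body
            # of the form {"messageIds": ["..."]}; anything outside 200/202 is
            # treated as a failure below.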

            if response.status_code not in [200, 202]:
                verbose_logger.error("Pub/Sub publish error: %s", str(response.text))
                raise Exception(f"Failed to publish message: {response.text}")

            verbose_logger.debug("Pub/Sub response: %s", response.text)
            return response.json()

        except Exception as e:
            verbose_logger.error("Pub/Sub publish error: %s", str(e))
            return None