|
""" |
|
BETA |
|
|
|
This is the PubSub logger for GCS PubSub, this sends LiteLLM SpendLogs Payloads to GCS PubSub. |
|
|
|
Users can use this instead of sending their SpendLogs to their Postgres database. |
|
""" |
|
|
|
import asyncio |
|
import json |
|
import os |
|
import traceback |
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional |
|
|
|
if TYPE_CHECKING: |
|
from litellm.proxy._types import SpendLogsPayload |
|
else: |
|
SpendLogsPayload = Any |
|
|
|
from litellm._logging import verbose_logger |
|
from litellm.integrations.custom_batch_logger import CustomBatchLogger |
|
from litellm.llms.custom_httpx.http_handler import ( |
|
get_async_httpx_client, |
|
httpxSpecialProvider, |
|
) |
|
|
|
|
|
class GcsPubSubLogger(CustomBatchLogger): |
|
def __init__( |
|
self, |
|
project_id: Optional[str] = None, |
|
topic_id: Optional[str] = None, |
|
credentials_path: Optional[str] = None, |
|
**kwargs, |
|
): |
|
""" |
|
Initialize Google Cloud Pub/Sub publisher |
|
|
|
Args: |
|
project_id (str): Google Cloud project ID |
|
topic_id (str): Pub/Sub topic ID |
|
credentials_path (str, optional): Path to Google Cloud credentials JSON file |
|
""" |
|
from litellm.proxy.utils import _premium_user_check |
|
|
|
_premium_user_check() |
|
|
|
self.async_httpx_client = get_async_httpx_client( |
|
llm_provider=httpxSpecialProvider.LoggingCallback |
|
) |
|
|
|
self.project_id = project_id or os.getenv("GCS_PUBSUB_PROJECT_ID") |
|
self.topic_id = topic_id or os.getenv("GCS_PUBSUB_TOPIC_ID") |
|
self.path_service_account_json = credentials_path or os.getenv( |
|
"GCS_PATH_SERVICE_ACCOUNT" |
|
) |
|
|
|
if not self.project_id or not self.topic_id: |
|
raise ValueError("Both project_id and topic_id must be provided") |
|
|
|
self.flush_lock = asyncio.Lock() |
|
super().__init__(**kwargs, flush_lock=self.flush_lock) |
|
asyncio.create_task(self.periodic_flush()) |
|
self.log_queue: List[SpendLogsPayload] = [] |
|
|
|
async def construct_request_headers(self) -> Dict[str, str]: |
|
"""Construct authorization headers using Vertex AI auth""" |
|
from litellm import vertex_chat_completion |
|
|
|
_auth_header, vertex_project = ( |
|
await vertex_chat_completion._ensure_access_token_async( |
|
credentials=self.path_service_account_json, |
|
project_id=None, |
|
custom_llm_provider="vertex_ai", |
|
) |
|
) |
|
|
|
auth_header, _ = vertex_chat_completion._get_token_and_url( |
|
model="pub-sub", |
|
auth_header=_auth_header, |
|
vertex_credentials=self.path_service_account_json, |
|
vertex_project=vertex_project, |
|
vertex_location=None, |
|
gemini_api_key=None, |
|
stream=None, |
|
custom_llm_provider="vertex_ai", |
|
api_base=None, |
|
) |
|
|
|
headers = { |
|
"Authorization": f"Bearer {auth_header}", |
|
"Content-Type": "application/json", |
|
} |
|
return headers |
|
|
|
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time): |
|
""" |
|
Async Log success events to GCS PubSub Topic |
|
|
|
- Creates a SpendLogsPayload |
|
- Adds to batch queue |
|
- Flushes based on CustomBatchLogger settings |
|
|
|
Raises: |
|
Raises a NON Blocking verbose_logger.exception if an error occurs |
|
""" |
|
from litellm.proxy.spend_tracking.spend_tracking_utils import ( |
|
get_logging_payload, |
|
) |
|
from litellm.proxy.utils import _premium_user_check |
|
|
|
_premium_user_check() |
|
|
|
try: |
|
verbose_logger.debug( |
|
"PubSub: Logging - Enters logging function for model %s", kwargs |
|
) |
|
spend_logs_payload = get_logging_payload( |
|
kwargs=kwargs, |
|
response_obj=response_obj, |
|
start_time=start_time, |
|
end_time=end_time, |
|
) |
|
self.log_queue.append(spend_logs_payload) |
|
|
|
if len(self.log_queue) >= self.batch_size: |
|
await self.async_send_batch() |
|
|
|
except Exception as e: |
|
verbose_logger.exception( |
|
f"PubSub Layer Error - {str(e)}\n{traceback.format_exc()}" |
|
) |
|
pass |
|
|
|
async def async_send_batch(self): |
|
""" |
|
Sends the batch of messages to Pub/Sub |
|
""" |
|
try: |
|
if not self.log_queue: |
|
return |
|
|
|
verbose_logger.debug( |
|
f"PubSub - about to flush {len(self.log_queue)} events" |
|
) |
|
|
|
for message in self.log_queue: |
|
await self.publish_message(message) |
|
|
|
except Exception as e: |
|
verbose_logger.exception( |
|
f"PubSub Error sending batch - {str(e)}\n{traceback.format_exc()}" |
|
) |
|
finally: |
|
self.log_queue.clear() |
|
|
|
async def publish_message( |
|
self, message: SpendLogsPayload |
|
) -> Optional[Dict[str, Any]]: |
|
""" |
|
Publish message to Google Cloud Pub/Sub using REST API |
|
|
|
Args: |
|
message: Message to publish (dict or string) |
|
|
|
Returns: |
|
dict: Published message response |
|
""" |
|
try: |
|
headers = await self.construct_request_headers() |
|
|
|
|
|
if isinstance(message, str): |
|
message_data = message |
|
else: |
|
message_data = json.dumps(message, default=str) |
|
|
|
|
|
import base64 |
|
|
|
encoded_message = base64.b64encode(message_data.encode("utf-8")).decode( |
|
"utf-8" |
|
) |
|
|
|
|
|
request_body = {"messages": [{"data": encoded_message}]} |
|
|
|
url = f"https://pubsub.googleapis.com/v1/projects/{self.project_id}/topics/{self.topic_id}:publish" |
|
|
|
response = await self.async_httpx_client.post( |
|
url=url, headers=headers, json=request_body |
|
) |
|
|
|
if response.status_code not in [200, 202]: |
|
verbose_logger.error("Pub/Sub publish error: %s", str(response.text)) |
|
raise Exception(f"Failed to publish message: {response.text}") |
|
|
|
verbose_logger.debug("Pub/Sub response: %s", response.text) |
|
return response.json() |
|
|
|
except Exception as e: |
|
verbose_logger.error("Pub/Sub publish error: %s", str(e)) |
|
return None |
|
|