|
""" |
|
Send logs to Argilla for annotation |
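
Example configuration (a minimal sketch; the callback name "argilla" and the
Argilla field names "user_input"/"llm_output" are illustrative assumptions,
not guarantees of this module):

    import litellm

    # map Argilla dataset fields -> supported payload fields
    litellm.argilla_transformation_object = {
        "user_input": "messages",
        "llm_output": "response",
    }
    litellm.callbacks = ["argilla"]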
|
""" |
|
|
|
import asyncio |
|
import json |
|
import os |
|
import random |
|
import types |
|
from typing import Any, Dict, List, Optional |
|
|
|
import httpx |
|
from pydantic import BaseModel |
|
|
|
import litellm |
|
from litellm._logging import verbose_logger |
|
from litellm.integrations.custom_batch_logger import CustomBatchLogger |
|
from litellm.integrations.custom_logger import CustomLogger |
|
from litellm.llms.custom_httpx.http_handler import ( |
|
get_async_httpx_client, |
|
httpxSpecialProvider, |
|
) |
|
from litellm.types.integrations.argilla import ( |
|
SUPPORTED_PAYLOAD_FIELDS, |
|
ArgillaCredentialsObject, |
|
ArgillaItem, |
|
) |
|
from litellm.types.utils import StandardLoggingPayload |
|
|
|
|
|
def is_serializable(value): |
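    """
    Return True if the value can be safely serialized for logging, i.e. it is
    not a coroutine, function, generator, or Pydantic model instance.
    """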
|
non_serializable_types = ( |
|
types.CoroutineType, |
|
types.FunctionType, |
|
types.GeneratorType, |
|
BaseModel, |
|
) |
|
return not isinstance(value, non_serializable_types) |
|
|
|
|
|
class ArgillaLogger(CustomBatchLogger): |
|
def __init__( |
|
self, |
|
argilla_api_key: Optional[str] = None, |
|
argilla_dataset_name: Optional[str] = None, |
|
argilla_base_url: Optional[str] = None, |
|
**kwargs, |
|
): |
|
if litellm.argilla_transformation_object is None: |
|
raise Exception( |
|
"'litellm.argilla_transformation_object' is required, to log your payload to Argilla." |
|
) |
|
self.validate_argilla_transformation_object( |
|
litellm.argilla_transformation_object |
|
) |
|
self.argilla_transformation_object = litellm.argilla_transformation_object |
|
self.default_credentials = self.get_credentials_from_env( |
|
argilla_api_key=argilla_api_key, |
|
argilla_dataset_name=argilla_dataset_name, |
|
argilla_base_url=argilla_base_url, |
|
) |
|
        _sampling_rate: Optional[str] = os.getenv("ARGILLA_SAMPLING_RATE")
        try:
            # accept float values like "0.5"; default to 1.0 (log everything)
            # when the env var is unset or not a valid number
            self.sampling_rate: float = (
                float(_sampling_rate) if _sampling_rate is not None else 1.0
            )
        except ValueError:
            self.sampling_rate = 1.0
|
|
|
self.async_httpx_client = get_async_httpx_client( |
|
llm_provider=httpxSpecialProvider.LoggingCallback |
|
) |
|
_batch_size = ( |
|
os.getenv("ARGILLA_BATCH_SIZE", None) or litellm.argilla_batch_size |
|
) |
|
if _batch_size: |
|
self.batch_size = int(_batch_size) |
|
asyncio.create_task(self.periodic_flush()) |
|
self.flush_lock = asyncio.Lock() |
|
super().__init__(**kwargs, flush_lock=self.flush_lock) |
|
|
|
def validate_argilla_transformation_object( |
|
self, argilla_transformation_object: Dict[str, Any] |
|
): |
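        """
        Ensure the transformation object is a dict whose values are all
        supported payload fields.

        Example shape (keys are Argilla dataset field names; the names shown
        are illustrative only):

            {"user_input": "messages", "llm_output": "response"}
        """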
|
if not isinstance(argilla_transformation_object, dict): |
|
raise Exception( |
|
"'argilla_transformation_object' must be a dictionary, to log your payload to Argilla." |
|
) |
|
|
|
for v in argilla_transformation_object.values(): |
|
if v not in SUPPORTED_PAYLOAD_FIELDS: |
|
raise Exception( |
|
f"All values in argilla_transformation_object must be a key in SUPPORTED_PAYLOAD_FIELDS, {v} is not a valid key." |
|
) |
|
|
|
def get_credentials_from_env( |
|
self, |
|
argilla_api_key: Optional[str], |
|
argilla_dataset_name: Optional[str], |
|
argilla_base_url: Optional[str], |
|
) -> ArgillaCredentialsObject: |
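        """
        Resolve credentials from the given arguments or the ARGILLA_API_KEY,
        ARGILLA_BASE_URL, and ARGILLA_DATASET_NAME environment variables, then
        look up the dataset id for the configured dataset name when it exists.
        """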
|
|
|
_credentials_api_key = argilla_api_key or os.getenv("ARGILLA_API_KEY") |
|
if _credentials_api_key is None: |
|
raise Exception("Invalid Argilla API Key given. _credentials_api_key=None.") |
|
|
|
_credentials_base_url = ( |
|
argilla_base_url |
|
or os.getenv("ARGILLA_BASE_URL") |
|
or "http://localhost:6900/" |
|
) |
|
if _credentials_base_url is None: |
|
raise Exception( |
|
"Invalid Argilla Base URL given. _credentials_base_url=None." |
|
) |
|
|
|
_credentials_dataset_name = ( |
|
argilla_dataset_name |
|
or os.getenv("ARGILLA_DATASET_NAME") |
|
or "litellm-completion" |
|
) |
|
if _credentials_dataset_name is None: |
|
raise Exception("Invalid Argilla Dataset give. Value=None.") |
|
else: |
|
dataset_response = litellm.module_level_client.get( |
|
url=f"{_credentials_base_url}/api/v1/me/datasets?name={_credentials_dataset_name}", |
|
headers={"X-Argilla-Api-Key": _credentials_api_key}, |
|
) |
|
json_response = dataset_response.json() |
|
if ( |
|
"items" in json_response |
|
and isinstance(json_response["items"], list) |
|
and len(json_response["items"]) > 0 |
|
): |
|
_credentials_dataset_name = json_response["items"][0]["id"] |
|
|
|
return ArgillaCredentialsObject( |
|
ARGILLA_API_KEY=_credentials_api_key, |
|
ARGILLA_BASE_URL=_credentials_base_url, |
|
ARGILLA_DATASET_NAME=_credentials_dataset_name, |
|
) |
|
|
|
def get_chat_messages( |
|
self, payload: StandardLoggingPayload |
|
) -> List[Dict[str, Any]]: |
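        """
        Return the payload's chat messages as a list of dicts; a single
        message dict is wrapped in a list.
        """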
|
payload_messages = payload.get("messages", None) |
|
|
|
if payload_messages is None: |
|
raise Exception("No chat messages found in payload.") |
|
|
|
if ( |
|
isinstance(payload_messages, list) |
|
and len(payload_messages) > 0 |
|
and isinstance(payload_messages[0], dict) |
|
): |
|
return payload_messages |
|
elif isinstance(payload_messages, dict): |
|
return [payload_messages] |
|
else: |
|
raise Exception(f"Invalid chat messages format: {payload_messages}") |
|
|
|
def get_str_response(self, payload: StandardLoggingPayload) -> str: |
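        """
        Return the response as a string; for a chat-completion style dict the
        first choice's message content is extracted.
        """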
|
response = payload["response"] |
|
|
|
if response is None: |
|
raise Exception("No response found in payload.") |
|
|
|
if isinstance(response, str): |
|
return response |
|
elif isinstance(response, dict): |
|
return ( |
|
response.get("choices", [{}])[0].get("message", {}).get("content", "") |
|
) |
|
else: |
|
raise Exception(f"Invalid response format: {response}") |
|
|
|
def _prepare_log_data( |
|
self, kwargs, response_obj, start_time, end_time |
|
) -> Optional[ArgillaItem]: |
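        """
        Build an ArgillaItem from the request's standard logging payload,
        mapping each configured Argilla field to the chat messages, the string
        response, or another payload key. A produced item looks roughly like
        (field names are illustrative):

            {"fields": {"user_input": [...], "llm_output": "..."}}
        """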
|
try: |
|
|
|
payload: Optional[StandardLoggingPayload] = kwargs.get( |
|
"standard_logging_object", None |
|
) |
|
|
|
if payload is None: |
|
raise Exception("Error logging request payload. Payload=none.") |
|
|
|
argilla_message = self.get_chat_messages(payload) |
|
argilla_response = self.get_str_response(payload) |
|
argilla_item: ArgillaItem = {"fields": {}} |
|
for k, v in self.argilla_transformation_object.items(): |
|
if v == "messages": |
|
argilla_item["fields"][k] = argilla_message |
|
elif v == "response": |
|
argilla_item["fields"][k] = argilla_response |
|
else: |
|
argilla_item["fields"][k] = payload.get(v, None) |
|
|
|
return argilla_item |
|
except Exception: |
|
raise |
|
|
|
def _send_batch(self): |
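        """
        Synchronously send queued records to the Argilla records/bulk endpoint
        and clear the queue; errors are logged rather than raised.
        """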
|
if not self.log_queue: |
|
return |
|
|
|
argilla_api_base = self.default_credentials["ARGILLA_BASE_URL"] |
|
argilla_dataset_name = self.default_credentials["ARGILLA_DATASET_NAME"] |
|
|
|
url = f"{argilla_api_base}/api/v1/datasets/{argilla_dataset_name}/records/bulk" |
|
|
|
argilla_api_key = self.default_credentials["ARGILLA_API_KEY"] |
|
|
|
headers = {"X-Argilla-Api-Key": argilla_api_key} |
|
|
|
try: |
|
response = litellm.module_level_client.post( |
|
url=url, |
|
                json={"items": self.log_queue},  # same payload shape as async_send_batch
|
headers=headers, |
|
) |
|
|
|
if response.status_code >= 300: |
|
verbose_logger.error( |
|
f"Argilla Error: {response.status_code} - {response.text}" |
|
) |
|
else: |
|
verbose_logger.debug( |
|
f"Batch of {len(self.log_queue)} runs successfully created" |
|
) |
|
|
|
self.log_queue.clear() |
|
except Exception: |
|
verbose_logger.exception("Argilla Layer Error - Error sending batch.") |
|
|
|
def log_success_event(self, kwargs, response_obj, start_time, end_time): |
|
try: |
|
            sampling_rate = self.sampling_rate
|
random_sample = random.random() |
|
if random_sample > sampling_rate: |
|
verbose_logger.info( |
|
"Skipping Langsmith logging. Sampling rate={}, random_sample={}".format( |
|
sampling_rate, random_sample |
|
) |
|
) |
|
return |
|
verbose_logger.debug( |
|
"Langsmith Sync Layer Logging - kwargs: %s, response_obj: %s", |
|
kwargs, |
|
response_obj, |
|
) |
|
data = self._prepare_log_data(kwargs, response_obj, start_time, end_time) |
|
if data is None: |
|
return |
|
|
|
self.log_queue.append(data) |
|
verbose_logger.debug( |
|
f"Langsmith, event added to queue. Will flush in {self.flush_interval} seconds..." |
|
) |
|
|
|
if len(self.log_queue) >= self.batch_size: |
|
self._send_batch() |
|
|
|
except Exception: |
|
verbose_logger.exception("Langsmith Layer Error - log_success_event error") |
|
|
|
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time): |
|
try: |
|
sampling_rate = self.sampling_rate |
|
random_sample = random.random() |
|
if random_sample > sampling_rate: |
|
verbose_logger.info( |
|
"Skipping Langsmith logging. Sampling rate={}, random_sample={}".format( |
|
sampling_rate, random_sample |
|
) |
|
) |
|
return |
|
verbose_logger.debug( |
|
"Langsmith Async Layer Logging - kwargs: %s, response_obj: %s", |
|
kwargs, |
|
response_obj, |
|
) |
|
payload: Optional[StandardLoggingPayload] = kwargs.get( |
|
"standard_logging_object", None |
|
) |
|
|
|
data = self._prepare_log_data(kwargs, response_obj, start_time, end_time) |
|
|
|
|
|
for callback in litellm.callbacks: |
|
if isinstance(callback, CustomLogger): |
|
try: |
|
if data is None: |
|
break |
|
data = await callback.async_dataset_hook(data, payload) |
|
except NotImplementedError: |
|
pass |
|
|
|
if data is None: |
|
return |
|
|
|
self.log_queue.append(data) |
|
verbose_logger.debug( |
|
"Langsmith logging: queue length %s, batch size %s", |
|
len(self.log_queue), |
|
self.batch_size, |
|
) |
|
if len(self.log_queue) >= self.batch_size: |
|
await self.flush_queue() |
|
except Exception: |
|
verbose_logger.exception( |
|
"Argilla Layer Error - error logging async success event." |
|
) |
|
|
|
async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time): |
|
sampling_rate = self.sampling_rate |
|
random_sample = random.random() |
|
if random_sample > sampling_rate: |
|
verbose_logger.info( |
|
"Skipping Langsmith logging. Sampling rate={}, random_sample={}".format( |
|
sampling_rate, random_sample |
|
) |
|
) |
|
return |
|
verbose_logger.info("Langsmith Failure Event Logging!") |
|
try: |
|
data = self._prepare_log_data(kwargs, response_obj, start_time, end_time) |
|
self.log_queue.append(data) |
|
verbose_logger.debug( |
|
"Langsmith logging: queue length %s, batch size %s", |
|
len(self.log_queue), |
|
self.batch_size, |
|
) |
|
if len(self.log_queue) >= self.batch_size: |
|
await self.flush_queue() |
|
except Exception: |
|
verbose_logger.exception( |
|
"Langsmith Layer Error - error logging async failure event." |
|
) |
|
|
|
async def async_send_batch(self): |
|
""" |
|
        Sends queued records from self.log_queue to the Argilla records/bulk endpoint.
|
|
|
Returns: None |
|
|
|
Raises: Does not raise an exception, will only verbose_logger.exception() |
|
""" |
|
if not self.log_queue: |
|
return |
|
|
|
argilla_api_base = self.default_credentials["ARGILLA_BASE_URL"] |
|
argilla_dataset_name = self.default_credentials["ARGILLA_DATASET_NAME"] |
|
|
|
url = f"{argilla_api_base}/api/v1/datasets/{argilla_dataset_name}/records/bulk" |
|
|
|
argilla_api_key = self.default_credentials["ARGILLA_API_KEY"] |
|
|
|
headers = {"X-Argilla-Api-Key": argilla_api_key} |
|
|
|
try: |
|
response = await self.async_httpx_client.put( |
|
url=url, |
|
data=json.dumps( |
|
{ |
|
"items": self.log_queue, |
|
} |
|
), |
|
headers=headers, |
|
timeout=60000, |
|
) |
|
response.raise_for_status() |
|
|
|
if response.status_code >= 300: |
|
verbose_logger.error( |
|
f"Argilla Error: {response.status_code} - {response.text}" |
|
) |
|
else: |
|
verbose_logger.debug( |
|
"Batch of %s runs successfully created", len(self.log_queue) |
|
) |
|
except httpx.HTTPStatusError: |
|
verbose_logger.exception("Argilla HTTP Error") |
|
except Exception: |
|
verbose_logger.exception("Argilla Layer Error") |
|
|