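"""Literal AI logging integration for LiteLLM.

Buffers completed and failed LLM calls in a batch queue and ships them to the
Literal AI GraphQL API as `ingestStep` mutations.
"""
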
import asyncio
import os
import uuid
from typing import List, Optional

import httpx

from litellm._logging import verbose_logger
from litellm.integrations.custom_batch_logger import CustomBatchLogger
from litellm.llms.custom_httpx.http_handler import (
    HTTPHandler,
    get_async_httpx_client,
    httpxSpecialProvider,
)
from litellm.types.utils import StandardLoggingPayload


class LiteralAILogger(CustomBatchLogger):
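    """Batch logger that forwards LiteLLM call logs to Literal AI."""
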
    def __init__(
        self,
        literalai_api_key=None,
        literalai_api_url="https://cloud.getliteral.ai",
        env=None,
        **kwargs,
    ):
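        """Configure credentials, HTTP clients, and batching.

        Falls back to the LITERAL_API_URL and LITERAL_API_KEY environment
        variables when no explicit values are passed; LITERAL_BATCH_SIZE,
        if set, controls how many steps are buffered before a flush.
        """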
        self.literalai_api_url = os.getenv("LITERAL_API_URL") or literalai_api_url
        self.headers = {
            "Content-Type": "application/json",
            "x-api-key": literalai_api_key or os.getenv("LITERAL_API_KEY"),
            "x-client-name": "litellm",
        }
        if env:
            self.headers["x-env"] = env
        self.async_httpx_client = get_async_httpx_client(
            llm_provider=httpxSpecialProvider.LoggingCallback
        )
        self.sync_http_handler = HTTPHandler()
        batch_size = os.getenv("LITERAL_BATCH_SIZE", None)
        self.flush_lock = asyncio.Lock()
        super().__init__(
            **kwargs,
            flush_lock=self.flush_lock,
            batch_size=int(batch_size) if batch_size else None,
        )

    def log_success_event(self, kwargs, response_obj, start_time, end_time):
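        """Queue a successful call and flush synchronously once the batch is full."""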
        try:
            verbose_logger.debug(
                "Literal AI Layer Logging - kwargs: %s, response_obj: %s",
                kwargs,
                response_obj,
            )
            data = self._prepare_log_data(kwargs, response_obj, start_time, end_time)
            self.log_queue.append(data)
            verbose_logger.debug(
                "Literal AI logging: queue length %s, batch size %s",
                len(self.log_queue),
                self.batch_size,
            )
            if len(self.log_queue) >= self.batch_size:
                self._send_batch()
        except Exception:
            verbose_logger.exception(
                "Literal AI Layer Error - error logging success event."
            )

    def log_failure_event(self, kwargs, response_obj, start_time, end_time):
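        """Queue a failed call and flush synchronously once the batch is full."""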
        verbose_logger.info("Literal AI Failure Event Logging!")
        try:
            data = self._prepare_log_data(kwargs, response_obj, start_time, end_time)
            self.log_queue.append(data)
            verbose_logger.debug(
                "Literal AI logging: queue length %s, batch size %s",
                len(self.log_queue),
                self.batch_size,
            )
            if len(self.log_queue) >= self.batch_size:
                self._send_batch()
        except Exception:
            verbose_logger.exception(
                "Literal AI Layer Error - error logging failure event."
            )

    def _send_batch(self):
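        """Synchronously POST every queued step to the Literal AI GraphQL endpoint."""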
        if not self.log_queue:
            return

        url = f"{self.literalai_api_url}/api/graphql"
        query = self._steps_query_builder(self.log_queue)
        variables = self._steps_variables_builder(self.log_queue)
        try:
            response = self.sync_http_handler.post(
                url=url,
                json={
                    "query": query,
                    "variables": variables,
                },
                headers=self.headers,
            )

            if response.status_code >= 300:
                verbose_logger.error(
                    f"Literal AI Error: {response.status_code} - {response.text}"
                )
            else:
                verbose_logger.debug(
                    f"Batch of {len(self.log_queue)} runs successfully created"
                )
        except Exception:
            verbose_logger.exception("Literal AI Layer Error")

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
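        """Queue a successful call and flush via the async queue once the batch is full."""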
        try:
            verbose_logger.debug(
                "Literal AI Async Layer Logging - kwargs: %s, response_obj: %s",
                kwargs,
                response_obj,
            )
            data = self._prepare_log_data(kwargs, response_obj, start_time, end_time)
            self.log_queue.append(data)
            verbose_logger.debug(
                "Literal AI logging: queue length %s, batch size %s",
                len(self.log_queue),
                self.batch_size,
            )
            if len(self.log_queue) >= self.batch_size:
                await self.flush_queue()
        except Exception:
            verbose_logger.exception(
                "Literal AI Layer Error - error logging async success event."
            )

    async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
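        """Queue a failed call and flush via the async queue once the batch is full."""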
        verbose_logger.info("Literal AI Failure Event Logging!")
        try:
            data = self._prepare_log_data(kwargs, response_obj, start_time, end_time)
            self.log_queue.append(data)
            verbose_logger.debug(
                "Literal AI logging: queue length %s, batch size %s",
                len(self.log_queue),
                self.batch_size,
            )
            if len(self.log_queue) >= self.batch_size:
                await self.flush_queue()
        except Exception:
            verbose_logger.exception(
                "Literal AI Layer Error - error logging async failure event."
            )

    async def async_send_batch(self):
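        """Asynchronously POST every queued step to the Literal AI GraphQL endpoint."""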
        if not self.log_queue:
            return

        url = f"{self.literalai_api_url}/api/graphql"
        query = self._steps_query_builder(self.log_queue)
        variables = self._steps_variables_builder(self.log_queue)

        try:
            response = await self.async_httpx_client.post(
                url=url,
                json={
                    "query": query,
                    "variables": variables,
                },
                headers=self.headers,
            )
            if response.status_code >= 300:
                verbose_logger.error(
                    f"Literal AI Error: {response.status_code} - {response.text}"
                )
            else:
                verbose_logger.debug(
                    f"Batch of {len(self.log_queue)} runs successfully created"
                )
        except httpx.HTTPStatusError as e:
            verbose_logger.exception(
                f"Literal AI HTTP Error: {e.response.status_code} - {e.response.text}"
            )
        except Exception:
            verbose_logger.exception("Literal AI Layer Error")

    def _prepare_log_data(self, kwargs, response_obj, start_time, end_time) -> dict:
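        """Convert a LiteLLM standard logging payload into a Literal AI step dict."""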
        logging_payload: Optional[StandardLoggingPayload] = kwargs.get(
            "standard_logging_object", None
        )

        if logging_payload is None:
            raise ValueError("standard_logging_object not found in kwargs")
        clean_metadata = logging_payload["metadata"]
        metadata = kwargs.get("litellm_params", {}).get("metadata", {})

        settings = logging_payload["model_parameters"]
        messages = logging_payload["messages"]
        response = logging_payload["response"]
        choices: List = []
        if isinstance(response, dict) and "choices" in response:
            choices = response["choices"]
        message_completion = choices[0]["message"] if choices else None
        prompt_id = None
        variables = None

        if messages and isinstance(messages, list) and isinstance(messages[0], dict):
            for message in messages:
                if literal_prompt := getattr(message, "__literal_prompt__", None):
                    prompt_id = literal_prompt.get("prompt_id")
                    variables = literal_prompt.get("variables")
                    message["uuid"] = literal_prompt.get("uuid")
                    message["templated"] = True

        tools = settings.pop("tools", None)

        step = {
            "id": metadata.get("step_id", str(uuid.uuid4())),
            "error": logging_payload["error_str"],
            "name": kwargs.get("model", ""),
            "threadId": metadata.get("literalai_thread_id", None),
            "parentId": metadata.get("literalai_parent_id", None),
            "rootRunId": metadata.get("literalai_root_run_id", None),
            "input": None,
            "output": None,
            "type": "llm",
            "tags": metadata.get("tags", metadata.get("literalai_tags", None)),
            "startTime": str(start_time),
            "endTime": str(end_time),
            "metadata": clean_metadata,
            "generation": {
                "inputTokenCount": logging_payload["prompt_tokens"],
                "outputTokenCount": logging_payload["completion_tokens"],
                "tokenCount": logging_payload["total_tokens"],
                "promptId": prompt_id,
                "variables": variables,
                "provider": kwargs.get("custom_llm_provider", "litellm"),
                "model": kwargs.get("model", ""),
                "duration": (end_time - start_time).total_seconds(),
                "settings": settings,
                "messages": messages,
                "messageCompletion": message_completion,
                "tools": tools,
            },
        }
        return step

    def _steps_query_variables_builder(self, steps):
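        """Build the GraphQL variable declarations, one indexed set per queued step."""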
        generated = ""
        for id in range(len(steps)):
            generated += f"""$id_{id}: String!
        $threadId_{id}: String
        $rootRunId_{id}: String
        $type_{id}: StepType
        $startTime_{id}: DateTime
        $endTime_{id}: DateTime
        $error_{id}: String
        $input_{id}: Json
        $output_{id}: Json
        $metadata_{id}: Json
        $parentId_{id}: String
        $name_{id}: String
        $tags_{id}: [String!]
        $generation_{id}: GenerationPayloadInput
        $scores_{id}: [ScorePayloadInput!]
        $attachments_{id}: [AttachmentPayloadInput!]
        """
        return generated

    def _steps_ingest_steps_builder(self, steps):
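        """Build one `ingestStep` mutation field per queued step."""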
        generated = ""
        for id in range(len(steps)):
            generated += f"""
        step{id}: ingestStep(
            id: $id_{id}
            threadId: $threadId_{id}
            rootRunId: $rootRunId_{id}
            startTime: $startTime_{id}
            endTime: $endTime_{id}
            type: $type_{id}
            error: $error_{id}
            input: $input_{id}
            output: $output_{id}
            metadata: $metadata_{id}
            parentId: $parentId_{id}
            name: $name_{id}
            tags: $tags_{id}
            generation: $generation_{id}
            scores: $scores_{id}
            attachments: $attachments_{id}
        ) {{
            ok
            message
        }}
        """
        return generated

    def _steps_query_builder(self, steps):
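        """Assemble the full `AddStep` mutation for the queued steps."""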
        return f"""
        mutation AddStep({self._steps_query_variables_builder(steps)}) {{
            {self._steps_ingest_steps_builder(steps)}
        }}
        """

    def _steps_variables_builder(self, steps):
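        """Map each step's non-null fields to index-suffixed GraphQL variables."""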
        def serialize_step(event, id):
            result = {}

            for key, value in event.items():
                if value is not None:
                    result[f"{key}_{id}"] = value

            return result

        variables = {}
        for i in range(len(steps)):
            step = steps[i]
            variables.update(serialize_step(step, i))
        return variables