|
|
|
|
|
|
|
import copy |
|
import os |
|
from datetime import datetime |
|
from typing import Optional, Dict |
|
|
|
import httpx |
|
from pydantic import BaseModel |
|
|
|
import litellm |
|
from litellm import verbose_logger |
|
from litellm.integrations.custom_logger import CustomLogger |
|
from litellm.llms.custom_httpx.http_handler import ( |
|
HTTPHandler, |
|
get_async_httpx_client, |
|
httpxSpecialProvider, |
|
) |
|
from litellm.utils import print_verbose |
|
|
|
# Shared HTTP clients, reused by every BraintrustLogger instance in the process.
global_braintrust_http_handler = get_async_httpx_client(llm_provider=httpxSpecialProvider.LoggingCallback)

global_braintrust_sync_http_handler = HTTPHandler()

# Default Braintrust REST API base; can be overridden per-logger via `api_base`.
API_BASE = "https://api.braintrustdata.com/v1"
|
|
|
|
|
def get_utc_datetime():
    """Return the current time as a timezone-aware UTC ``datetime``.

    Previously this fell back to the deprecated ``datetime.utcnow()`` on
    Python < 3.11, which returns a *naive* datetime (while the ``dt.UTC``
    branch returned an aware one). ``datetime.now(timezone.utc)`` is
    available on every supported version and always returns an aware value.
    """
    from datetime import datetime, timezone

    return datetime.now(timezone.utc)
|
|
|
|
|
class BraintrustLogger(CustomLogger):
    """LiteLLM callback logger that ships request/response events to Braintrust.

    Implements both the sync (``log_success_event``) and async
    (``async_log_success_event``) success callbacks. Events are POSTed to
    ``{api_base}/project_logs/{project_id}/insert``. The project id is
    resolved from ``metadata["project_id"]``, then ``metadata["project_name"]``
    (created on demand and cached), then a default project named "litellm".
    """

    def __init__(self, api_key: Optional[str] = None, api_base: Optional[str] = None) -> None:
        super().__init__()
        self.validate_environment(api_key=api_key)
        self.api_base = api_base or API_BASE
        # Lazily populated by create_[sync_]default_project_and_experiment().
        self.default_project_id: Optional[str] = None
        # validate_environment() above guarantees one of these is non-None.
        self.api_key: str = api_key or os.getenv("BRAINTRUST_API_KEY")
        self.headers = {
            "Authorization": "Bearer " + self.api_key,
            "Content-Type": "application/json",
        }
        # project_name -> project_id, shared by the sync and async lookups.
        self._project_id_cache: Dict[str, str] = {}

    def validate_environment(self, api_key: Optional[str]):
        """
        Expects

        BRAINTRUST_API_KEY

        in the environment when no explicit ``api_key`` is supplied.
        Raises ``Exception`` listing the missing key(s).
        """
        missing_keys = []
        if api_key is None and os.getenv("BRAINTRUST_API_KEY", None) is None:
            missing_keys.append("BRAINTRUST_API_KEY")

        if len(missing_keys) > 0:
            raise Exception("Missing keys={} in environment.".format(missing_keys))

    def get_project_id_sync(self, project_name: str) -> str:
        """
        Get project ID from name, using cache if available.
        If project doesn't exist, creates it.
        """
        if project_name in self._project_id_cache:
            return self._project_id_cache[project_name]

        try:
            # NOTE(review): the async variant posts to /project/register while
            # this posts to /project — confirm which endpoint is canonical.
            response = global_braintrust_sync_http_handler.post(
                f"{self.api_base}/project", headers=self.headers, json={"name": project_name}
            )
            project_dict = response.json()
            project_id = project_dict["id"]
            self._project_id_cache[project_name] = project_id
            return project_id
        except httpx.HTTPStatusError as e:
            raise Exception(f"Failed to register project: {e.response.text}")

    async def get_project_id_async(self, project_name: str) -> str:
        """
        Async version of get_project_id_sync
        """
        if project_name in self._project_id_cache:
            return self._project_id_cache[project_name]

        try:
            response = await global_braintrust_http_handler.post(
                f"{self.api_base}/project/register", headers=self.headers, json={"name": project_name}
            )
            project_dict = response.json()
            project_id = project_dict["id"]
            self._project_id_cache[project_name] = project_id
            return project_id
        except httpx.HTTPStatusError as e:
            raise Exception(f"Failed to register project: {e.response.text}")

    @staticmethod
    def add_metadata_from_header(litellm_params: dict, metadata: dict) -> dict:
        """
        Adds metadata from proxy request headers to Braintrust logging if keys
        start with "braintrust", overwriting any matching key already present
        in ``litellm_params.metadata``.

        For example, send ``headers: { ..., braintrust_custom_key: value }``
        via the proxy request to attach ``custom_key`` to the logged event.
        """
        if litellm_params is None:
            return metadata

        if litellm_params.get("proxy_server_request") is None:
            return metadata

        if metadata is None:
            metadata = {}

        proxy_headers = litellm_params.get("proxy_server_request", {}).get("headers", {}) or {}

        for metadata_param_key in proxy_headers:
            if metadata_param_key.startswith("braintrust"):
                # Strip only the leading "braintrust" prefix.
                trace_param_key = metadata_param_key.replace("braintrust", "", 1)
                if trace_param_key in metadata:
                    verbose_logger.warning(f"Overwriting Braintrust `{trace_param_key}` from request header")
                else:
                    verbose_logger.debug(f"Found Braintrust `{trace_param_key}` in request header")
                metadata[trace_param_key] = proxy_headers.get(metadata_param_key)

        return metadata

    async def create_default_project_and_experiment(self):
        """Create (or fetch) the default "litellm" project and remember its id."""
        project = await global_braintrust_http_handler.post(
            f"{self.api_base}/project", headers=self.headers, json={"name": "litellm"}
        )

        project_dict = project.json()

        self.default_project_id = project_dict["id"]

    def create_sync_default_project_and_experiment(self):
        """Sync version of create_default_project_and_experiment()."""
        project = global_braintrust_sync_http_handler.post(
            f"{self.api_base}/project", headers=self.headers, json={"name": "litellm"}
        )

        project_dict = project.json()

        self.default_project_id = project_dict["id"]

    def log_success_event(self, kwargs, response_obj, start_time, end_time):
        """Sync success callback: build a Braintrust event and POST it."""
        verbose_logger.debug("REACHES BRAINTRUST SUCCESS")
        try:
            litellm_call_id = kwargs.get("litellm_call_id")
            prompt = {"messages": kwargs.get("messages")}
            output = None
            choices = []
            # Derive `output`/`choices` from the response type. Embeddings have
            # no loggable output; image responses log their `data` payload.
            if response_obj is not None and (
                kwargs.get("call_type", None) == "embedding" or isinstance(response_obj, litellm.EmbeddingResponse)
            ):
                output = None
            elif response_obj is not None and isinstance(response_obj, litellm.ModelResponse):
                output = response_obj["choices"][0]["message"].json()
                choices = response_obj["choices"]
            elif response_obj is not None and isinstance(response_obj, litellm.TextCompletionResponse):
                output = response_obj.choices[0].text
                choices = response_obj.choices
            elif response_obj is not None and isinstance(response_obj, litellm.ImageResponse):
                output = response_obj["data"]

            litellm_params = kwargs.get("litellm_params", {})
            metadata = litellm_params.get("metadata", {}) or {}
            metadata = self.add_metadata_from_header(litellm_params, metadata)
            clean_metadata = {}
            try:
                # Deep-copy so mutations below can't leak back to the caller.
                metadata = copy.deepcopy(metadata)
            except Exception:
                # Some values (clients, locks, ...) are not deep-copyable;
                # keep only plain, copyable values.
                new_metadata = {}
                for key, value in metadata.items():
                    if isinstance(value, (list, dict, str, int, float)):
                        new_metadata[key] = copy.deepcopy(value)
                metadata = new_metadata

            # Resolve project id: explicit id > named project > default project.
            project_id = metadata.get("project_id")
            if project_id is None:
                project_name = metadata.get("project_name")
                project_id = self.get_project_id_sync(project_name) if project_name else None

            if project_id is None:
                if self.default_project_id is None:
                    self.create_sync_default_project_and_experiment()
                project_id = self.default_project_id

            tags = []
            if isinstance(metadata, dict):
                for key, value in metadata.items():
                    # Promote keys listed in litellm.langfuse_default_tags to tags.
                    if (
                        litellm.langfuse_default_tags is not None
                        and isinstance(litellm.langfuse_default_tags, list)
                        and key in litellm.langfuse_default_tags
                    ):
                        tags.append(f"{key}:{value}")

                    # Drop proxy bookkeeping keys that aren't useful metadata.
                    if key in ["headers", "endpoint", "caching_groups", "previous_models"]:
                        continue
                    clean_metadata[key] = value

            cost = kwargs.get("response_cost", None)
            if cost is not None:
                clean_metadata["litellm_response_cost"] = cost

            metrics: Optional[dict] = None
            usage_obj = getattr(response_obj, "usage", None)
            if usage_obj and isinstance(usage_obj, litellm.Usage):
                metrics = {
                    "prompt_tokens": usage_obj.prompt_tokens,
                    "completion_tokens": usage_obj.completion_tokens,
                    "total_tokens": usage_obj.total_tokens,
                    "total_cost": cost,
                    "start": start_time.timestamp(),
                    "end": end_time.timestamp(),
                }
                # Fix: time_to_first_token was previously reported as
                # end - start (total latency). Compute the real TTFT the same
                # way the async path does, and omit it when unavailable.
                api_call_start_time = kwargs.get("api_call_start_time")
                completion_start_time = kwargs.get("completion_start_time")
                if api_call_start_time is not None and completion_start_time is not None:
                    metrics["time_to_first_token"] = (
                        completion_start_time.timestamp() - api_call_start_time.timestamp()
                    )

            request_data = {
                "id": litellm_call_id,
                "input": prompt["messages"],
                "metadata": clean_metadata,
                "tags": tags,
                "span_attributes": {"name": "Chat Completion", "type": "llm"},
            }
            # Fix: `choices` defaults to [] so the old `is not None` check was
            # always true and image/embedding outputs were replaced with [].
            if choices:
                request_data["output"] = [choice.dict() for choice in choices]
            else:
                request_data["output"] = output

            if metrics is not None:
                request_data["metrics"] = metrics

            try:
                print_verbose(f"global_braintrust_sync_http_handler.post: {global_braintrust_sync_http_handler.post}")
                global_braintrust_sync_http_handler.post(
                    url=f"{self.api_base}/project_logs/{project_id}/insert",
                    json={"events": [request_data]},
                    headers=self.headers,
                )
            except httpx.HTTPStatusError as e:
                raise Exception(e.response.text)
        except Exception as e:
            raise e

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        """Async success callback: build a Braintrust event and POST it."""
        verbose_logger.debug("REACHES BRAINTRUST SUCCESS")
        try:
            litellm_call_id = kwargs.get("litellm_call_id")
            prompt = {"messages": kwargs.get("messages")}
            output = None
            choices = []
            # Derive `output`/`choices` from the response type (see sync path).
            if response_obj is not None and (
                kwargs.get("call_type", None) == "embedding" or isinstance(response_obj, litellm.EmbeddingResponse)
            ):
                output = None
            elif response_obj is not None and isinstance(response_obj, litellm.ModelResponse):
                output = response_obj["choices"][0]["message"].json()
                choices = response_obj["choices"]
            elif response_obj is not None and isinstance(response_obj, litellm.TextCompletionResponse):
                output = response_obj.choices[0].text
                choices = response_obj.choices
            elif response_obj is not None and isinstance(response_obj, litellm.ImageResponse):
                output = response_obj["data"]

            litellm_params = kwargs.get("litellm_params", {})
            metadata = litellm_params.get("metadata", {}) or {}
            metadata = self.add_metadata_from_header(litellm_params, metadata)
            clean_metadata = {}
            # Sanitize metadata into JSON-friendly values: keep scalars/lists,
            # serialize pydantic models, ISO-format datetimes nested in dicts.
            # NOTE(review): nested dict values are converted in place, which
            # mutates the caller's metadata dict — confirm this is acceptable.
            new_metadata = {}
            for key, value in metadata.items():
                if isinstance(value, (list, str, int, float)):
                    new_metadata[key] = value
                elif isinstance(value, BaseModel):
                    new_metadata[key] = value.model_dump_json()
                elif isinstance(value, dict):
                    for k, v in value.items():
                        if isinstance(v, datetime):
                            value[k] = v.isoformat()
                    new_metadata[key] = value
            # Fix: new_metadata was previously built and then discarded.
            metadata = new_metadata

            # Resolve project id: explicit id > named project > default project.
            project_id = metadata.get("project_id")
            if project_id is None:
                project_name = metadata.get("project_name")
                project_id = await self.get_project_id_async(project_name) if project_name else None

            if project_id is None:
                if self.default_project_id is None:
                    await self.create_default_project_and_experiment()
                project_id = self.default_project_id

            tags = []
            if isinstance(metadata, dict):
                for key, value in metadata.items():
                    # Promote keys listed in litellm.langfuse_default_tags to tags.
                    if (
                        litellm.langfuse_default_tags is not None
                        and isinstance(litellm.langfuse_default_tags, list)
                        and key in litellm.langfuse_default_tags
                    ):
                        tags.append(f"{key}:{value}")

                    # Drop proxy bookkeeping keys that aren't useful metadata.
                    if key in ["headers", "endpoint", "caching_groups", "previous_models"]:
                        continue
                    clean_metadata[key] = value

            cost = kwargs.get("response_cost", None)
            if cost is not None:
                clean_metadata["litellm_response_cost"] = cost

            metrics: Optional[dict] = None
            usage_obj = getattr(response_obj, "usage", None)
            if usage_obj and isinstance(usage_obj, litellm.Usage):
                metrics = {
                    "prompt_tokens": usage_obj.prompt_tokens,
                    "completion_tokens": usage_obj.completion_tokens,
                    "total_tokens": usage_obj.total_tokens,
                    "total_cost": cost,
                    "start": start_time.timestamp(),
                    "end": end_time.timestamp(),
                }

            api_call_start_time = kwargs.get("api_call_start_time")
            completion_start_time = kwargs.get("completion_start_time")

            # Fix: guard on `metrics` — previously this raised TypeError when
            # usage was absent but both timestamps were present.
            if metrics is not None and api_call_start_time is not None and completion_start_time is not None:
                metrics["time_to_first_token"] = completion_start_time.timestamp() - api_call_start_time.timestamp()

            request_data = {
                "id": litellm_call_id,
                "input": prompt["messages"],
                "metadata": clean_metadata,
                "tags": tags,
                "span_attributes": {"name": "Chat Completion", "type": "llm"},
            }
            # Fix: `choices` defaults to [] so the old `is not None` check was
            # always true and image/embedding outputs were replaced with [].
            if choices:
                request_data["output"] = [choice.dict() for choice in choices]
            else:
                request_data["output"] = output

            if metrics is not None:
                request_data["metrics"] = metrics

            try:
                await global_braintrust_http_handler.post(
                    url=f"{self.api_base}/project_logs/{project_id}/insert",
                    json={"events": [request_data]},
                    headers=self.headers,
                )
            except httpx.HTTPStatusError as e:
                raise Exception(e.response.text)
        except Exception as e:
            raise e

    def log_failure_event(self, kwargs, response_obj, start_time, end_time):
        # Failures are delegated to the base CustomLogger implementation.
        return super().log_failure_event(kwargs, response_obj, start_time, end_time)
|
|