from __future__ import annotations

import json
import math
import os
import traceback
import types
from datetime import datetime, timezone
from typing import TYPE_CHECKING, Any

from langchain_core.documents import Document
from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage
from loguru import logger
from openinference.semconv.trace import OpenInferenceMimeTypeValues, SpanAttributes
from opentelemetry.semconv.trace import SpanAttributes as OTELSpanAttributes
from opentelemetry.trace import Span, Status, StatusCode
from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
from typing_extensions import override

from langflow.schema.data import Data
from langflow.schema.message import Message
from langflow.services.tracing.base import BaseTracer

if TYPE_CHECKING:
    from collections.abc import Sequence
    from uuid import UUID

    from langchain.callbacks.base import BaseCallbackHandler
    from opentelemetry.propagators.textmap import CarrierT
    from opentelemetry.util.types import AttributeValue

    from langflow.graph.vertex.base import Vertex
    from langflow.services.tracing.schema import Log
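
# Configuration is read from environment variables in setup_arize_phoenix():
#   ARIZE_API_KEY + ARIZE_SPACE_ID   -> enable export to Arize (via GRPCSpanExporter)
#   ARIZE_COLLECTOR_ENDPOINT         -> defaults to https://otlp.arize.com
#   PHOENIX_API_KEY                  -> enables export to Phoenix (via HTTPSpanExporter)
#   PHOENIX_COLLECTOR_ENDPOINT       -> defaults to https://app.phoenix.arize.com
#   ARIZE_PHOENIX_BATCH              -> use BatchSpanProcessor instead of SimpleSpanProcessor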


class ArizePhoenixTracer(BaseTracer):
    flow_id: str

    def __init__(self, trace_name: str, trace_type: str, project_name: str, trace_id: UUID):
        """Initializes the ArizePhoenixTracer instance and sets up a root span."""
        self.trace_name = trace_name
        self.trace_type = trace_type
        self.project_name = project_name
        self.trace_id = trace_id
        self.flow_id = trace_name.split(" - ")[-1]

        try:
            self._ready = self.setup_arize_phoenix()
            if not self._ready:
                return

            self.tracer = self.tracer_provider.get_tracer(__name__)
            self.propagator = TraceContextTextMapPropagator()
            self.carrier: dict[Any, CarrierT] = {}

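            # Open a short-lived root span named after the flow and inject its context
            # into the carrier; add_trace() later extracts that context so every child
            # span is parented under this trace.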
            with self.tracer.start_as_current_span(
                name=self.flow_id,
                start_time=self._get_current_timestamp(),
            ) as root_span:
                root_span.set_attribute(SpanAttributes.OPENINFERENCE_SPAN_KIND, self.trace_type)
                root_span.set_status(Status(StatusCode.OK))
                self.propagator.inject(carrier=self.carrier)

            self.child_spans: dict[str, Span] = {}
        except Exception:  # noqa: BLE001
            logger.opt(exception=True).debug("Error setting up Arize/Phoenix tracer")
            self._ready = False

    def ready(self):
        """Indicates if the tracer is ready for usage."""
        return self._ready

    def setup_arize_phoenix(self) -> bool:
        """Configures Arize/Phoenix specific environment variables and registers the tracer provider."""
        arize_phoenix_batch = os.getenv("ARIZE_PHOENIX_BATCH", "False").lower() in (
            "true",
            "t",
            "yes",
            "y",
            "1",
        )

        # Arize Config
        arize_api_key = os.getenv("ARIZE_API_KEY", None)
        arize_space_id = os.getenv("ARIZE_SPACE_ID", None)
        arize_collector_endpoint = os.getenv("ARIZE_COLLECTOR_ENDPOINT", "https://otlp.arize.com")
        enable_arize_tracing = bool(arize_api_key and arize_space_id)
        arize_endpoint = f"{arize_collector_endpoint}/v1"
        arize_headers = {
            "api_key": arize_api_key,
            "space_id": arize_space_id,
            "authorization": f"Bearer {arize_api_key}",
        }

        # Phoenix Config
        phoenix_api_key = os.getenv("PHOENIX_API_KEY", None)
        phoenix_collector_endpoint = os.getenv("PHOENIX_COLLECTOR_ENDPOINT", "https://app.phoenix.arize.com")
        enable_phoenix_tracing = bool(phoenix_api_key)
        phoenix_endpoint = f"{phoenix_collector_endpoint}/v1/traces"
        phoenix_headers = {
            "api_key": phoenix_api_key,
            "authorization": f"Bearer {phoenix_api_key}",
        }
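        # Each backend is enabled independently: Arize needs both ARIZE_API_KEY and
        # ARIZE_SPACE_ID, while Phoenix only needs PHOENIX_API_KEY. With neither set,
        # the tracer stays disabled and ready() returns False.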
        if not (enable_arize_tracing or enable_phoenix_tracing):
            return False

        try:
            from phoenix.otel import (
                PROJECT_NAME,
                BatchSpanProcessor,
                GRPCSpanExporter,
                HTTPSpanExporter,
                Resource,
                SimpleSpanProcessor,
                TracerProvider,
            )

            project_name = self.project_name or self.flow_id
            attributes = {PROJECT_NAME: project_name, "model_id": project_name}
            resource = Resource.create(attributes=attributes)
            tracer_provider = TracerProvider(resource=resource, verbose=False)
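            # ARIZE_PHOENIX_BATCH picks the span processor: BatchSpanProcessor queues
            # spans and exports them in the background, while SimpleSpanProcessor
            # exports each span synchronously as soon as it ends.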
            span_processor = BatchSpanProcessor if arize_phoenix_batch else SimpleSpanProcessor

            if enable_arize_tracing:
                tracer_provider.add_span_processor(
                    span_processor=span_processor(
                        span_exporter=GRPCSpanExporter(endpoint=arize_endpoint, headers=arize_headers),
                    )
                )
            if enable_phoenix_tracing:
                tracer_provider.add_span_processor(
                    span_processor=span_processor(
                        span_exporter=HTTPSpanExporter(
                            endpoint=phoenix_endpoint,
                            headers=phoenix_headers,
                        ),
                    )
                )

            self.tracer_provider = tracer_provider
        except ImportError:
            logger.exception(
                "Could not import arize-phoenix-otel. Please install it with `pip install arize-phoenix-otel`."
            )
            return False

        try:
            from openinference.instrumentation.langchain import LangChainInstrumentor

            LangChainInstrumentor().instrument(tracer_provider=self.tracer_provider, skip_dep_check=True)
        except ImportError:
            logger.exception(
                "Could not import LangChainInstrumentor. "
                "Please install it with `pip install openinference-instrumentation-langchain`."
            )
            return False

        return True

    def add_trace(
        self,
        trace_id: str,
        trace_name: str,
        trace_type: str,
        inputs: dict[str, Any],
        metadata: dict[str, Any] | None = None,
        vertex: Vertex | None = None,
    ) -> None:
        """Adds a trace span, attaching inputs and metadata as attributes."""
        if not self._ready:
            return

        span_context = self.propagator.extract(carrier=self.carrier)
        child_span = self.tracer.start_span(
            name=trace_name,
            context=span_context,
            start_time=self._get_current_timestamp(),
        )

        if trace_type == "prompt":
            child_span.set_attribute(SpanAttributes.OPENINFERENCE_SPAN_KIND, "chain")
        else:
            child_span.set_attribute(SpanAttributes.OPENINFERENCE_SPAN_KIND, trace_type)
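        # Group spans by chat session when a non-empty session_id is supplied;
        # otherwise fall back to the flow_id so all spans of the run share a session.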
if "session_id" in inputs and len(inputs["session_id"]) > 0 and inputs["session_id"] != self.flow_id: | |
child_span.set_attribute(SpanAttributes.SESSION_ID, inputs["session_id"]) | |
else: | |
child_span.set_attribute(SpanAttributes.SESSION_ID, self.flow_id) | |
processed_inputs = self._convert_to_arize_phoenix_types(inputs) if inputs else {} | |
if processed_inputs: | |
child_span.set_attribute(SpanAttributes.INPUT_VALUE, self._safe_json_dumps(processed_inputs)) | |
child_span.set_attribute(SpanAttributes.INPUT_MIME_TYPE, OpenInferenceMimeTypeValues.JSON.value) | |
processed_metadata = self._convert_to_arize_phoenix_types(metadata) if metadata else {} | |
if processed_metadata: | |
for key, value in processed_metadata.items(): | |
child_span.set_attribute(f"{SpanAttributes.METADATA}.{key}", value) | |
self.child_spans[trace_id] = child_span | |

    def end_trace(
        self,
        trace_id: str,
        trace_name: str,
        outputs: dict[str, Any] | None = None,
        error: Exception | None = None,
        logs: Sequence[Log | dict] = (),
    ) -> None:
        """Ends a trace span, attaching outputs, errors, and logs as attributes."""
        if not self._ready or trace_id not in self.child_spans:
            return

        child_span = self.child_spans[trace_id]

        processed_outputs = self._convert_to_arize_phoenix_types(outputs) if outputs else {}
        if processed_outputs:
            child_span.set_attribute(SpanAttributes.OUTPUT_VALUE, self._safe_json_dumps(processed_outputs))
            child_span.set_attribute(SpanAttributes.OUTPUT_MIME_TYPE, OpenInferenceMimeTypeValues.JSON.value)

        logs_dicts = [log if isinstance(log, dict) else log.model_dump() for log in logs]
        processed_logs = (
            self._convert_to_arize_phoenix_types({log.get("name"): log for log in logs_dicts}) if logs else {}
        )
        if processed_logs:
            for key, value in processed_logs.items():
                child_span.set_attribute(f"logs.{key}", value)

        if error:
            error_string = self._error_to_string(error)
            child_span.set_status(Status(StatusCode.ERROR, error_string))
            child_span.set_attribute("error.message", error_string)
            if isinstance(error, Exception):
                child_span.record_exception(error)
            else:
                exception_type = error.__class__.__name__
                exception_message = str(error)
                if not exception_message:
                    exception_message = repr(error)
                attributes: dict[str, AttributeValue] = {
                    OTELSpanAttributes.EXCEPTION_TYPE: exception_type,
                    OTELSpanAttributes.EXCEPTION_MESSAGE: exception_message,
                    OTELSpanAttributes.EXCEPTION_ESCAPED: False,
                    OTELSpanAttributes.EXCEPTION_STACKTRACE: error_string,
                }
                child_span.add_event(name="exception", attributes=attributes)
        else:
            child_span.set_status(Status(StatusCode.OK))

        child_span.end(end_time=self._get_current_timestamp())
        self.child_spans.pop(trace_id)

    def end(
        self,
        inputs: dict[str, Any],
        outputs: dict[str, Any],
        error: Exception | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> None:
        """Ends tracing with the specified inputs, outputs, errors, and metadata as attributes."""
        if not self._ready:
            return

    def _convert_to_arize_phoenix_types(self, io_dict: dict[str, Any]):
        """Converts data types to Arize/Phoenix compatible formats."""
        return {key: self._convert_to_arize_phoenix_type(value) for key, value in io_dict.items()}

    def _convert_to_arize_phoenix_type(self, value):
        """Recursively converts a value to an Arize/Phoenix compatible type."""
        if isinstance(value, dict):
            value = {key: self._convert_to_arize_phoenix_type(val) for key, val in value.items()}
        elif isinstance(value, list):
            value = [self._convert_to_arize_phoenix_type(v) for v in value]
        elif isinstance(value, Message):
            value = value.text
        elif isinstance(value, Data):
            value = value.get_text()
        elif isinstance(value, (BaseMessage | HumanMessage | SystemMessage)):
            value = value.content
        elif isinstance(value, Document):
            value = value.page_content
        elif isinstance(value, (types.GeneratorType | types.NoneType)):
            value = str(value)
        elif isinstance(value, float) and not math.isfinite(value):
            value = "NaN"
        return value

    def _error_to_string(self, error: Exception | None):
        """Converts an error to a string with traceback details."""
        error_message = None
        if error:
            # format_exception returns a list of lines; join them so the span
            # attribute carries the formatted traceback rather than a list repr.
            string_stacktrace = "".join(traceback.format_exception(error))
            error_message = f"{error.__class__.__name__}: {error}\n\n{string_stacktrace}"
        return error_message

    def _get_current_timestamp(self) -> int:
        """Gets the current UTC timestamp in nanoseconds."""
        return int(datetime.now(timezone.utc).timestamp() * 1_000_000_000)

    def _safe_json_dumps(self, obj: Any, **kwargs: Any) -> str:
        """A convenience wrapper around `json.dumps` that ensures that any object can be safely encoded."""
        return json.dumps(obj, default=str, ensure_ascii=False, **kwargs)

    def get_langchain_callback(self) -> BaseCallbackHandler | None:
        """Returns the LangChain callback handler if applicable."""
        return None
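

# Illustrative usage sketch (not how Langflow's tracing service wires this up in
# production; shown only to clarify the call sequence). Assumes PHOENIX_API_KEY is
# set so setup_arize_phoenix() enables at least one exporter; the UUID and component
# names below are hypothetical.
#
#   from uuid import uuid4
#
#   tracer = ArizePhoenixTracer(
#       trace_name="My Flow - 123e4567-e89b-12d3-a456-426614174000",
#       trace_type="chain",
#       project_name="my-project",
#       trace_id=uuid4(),
#   )
#   if tracer.ready():
#       tracer.add_trace("vertex-1", "Prompt", "prompt", inputs={"session_id": ""})
#       tracer.end_trace("vertex-1", "Prompt", outputs={"text": "hello"})
#       tracer.end(inputs={}, outputs={})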