File size: 4,853 Bytes
fea07c2
 
 
 
a9fb876
fea07c2
0b1aa61
fea07c2
 
 
 
 
 
 
a9fb876
fea07c2
a9fb876
fea07c2
 
 
 
 
a9fb876
fea07c2
 
 
 
 
a9fb876
fea07c2
 
 
 
 
 
a9fb876
fea07c2
 
a9fb876
fea07c2
 
 
 
 
 
 
a9fb876
fea07c2
 
a9fb876
fea07c2
 
 
 
a9fb876
fea07c2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a9fb876
fea07c2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a9fb876
fea07c2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
from typing import Any, Dict, List, ClassVar
import json
import re
from abc import ABC, abstractmethod
from any_agent import AgentFramework
from loguru import logger


class TelemetryProcessor(ABC):
    """Base class for processing telemetry data from different agent types.

    Subclasses implement the framework-specific pieces
    (``extract_hypothesis_answer``, ``_extract_telemetry_data`` and
    ``_get_agent_framework``); this base class provides the factory,
    framework detection from a trace, and shared evidence formatting.
    """

    # Top-level string values of an extracted call that are longer than this
    # many characters are cut to this length and suffixed with "..." in the
    # formatted evidence.
    MAX_EVIDENCE_LENGTH: ClassVar[int] = 400

    @classmethod
    def create(cls, agent_framework: AgentFramework) -> "TelemetryProcessor":
        """Factory method to create the appropriate telemetry processor.

        Args:
            agent_framework: The framework whose traces will be processed.

        Returns:
            A new processor instance for ``agent_framework``.

        Raises:
            ValueError: If there is no processor for ``agent_framework``.
        """
        # Concrete processors are imported lazily, so only the module for the
        # requested framework is loaded.
        if agent_framework == AgentFramework.LANGCHAIN:
            from surf_spot_finder.evaluation.telemetry.langchain_telemetry import (
                LangchainTelemetryProcessor,
            )

            return LangchainTelemetryProcessor()
        elif agent_framework == AgentFramework.SMOLAGENTS:
            from surf_spot_finder.evaluation.telemetry.smolagents_telemetry import (
                SmolagentsTelemetryProcessor,
            )

            return SmolagentsTelemetryProcessor()
        elif agent_framework == AgentFramework.OPENAI:
            from surf_spot_finder.evaluation.telemetry.openai_telemetry import (
                OpenAITelemetryProcessor,
            )

            return OpenAITelemetryProcessor()
        else:
            raise ValueError(f"Unsupported agent type {agent_framework}")

    @staticmethod
    def determine_agent_framework(trace: List[Dict[str, Any]]) -> AgentFramework:
        """Determine the agent type based on the trace.

        These are not really stable ways to find it, because we're waiting on some
        reliable method for determining the agent type. This is a temporary solution.

        Spans are inspected in order; the first span that matches any of the
        following heuristics decides the result:

        * ``attributes["input.value"]`` contains ``"langchain"`` -> LANGCHAIN
        * ``attributes["smolagents.max_steps"]`` is truthy -> SMOLAGENTS
        * the span's ``name`` is ``"response"`` -> OPENAI

        Args:
            trace: The trace as a list of span dictionaries.

        Returns:
            The detected agent framework.

        Raises:
            ValueError: If no span matches any heuristic.
        """
        for span in trace:
            # ``or`` also covers keys that are present but set to None, which
            # would otherwise raise instead of falling through to other checks.
            attributes = span.get("attributes") or {}
            input_value = attributes.get("input.value") or ""
            if "langchain" in input_value:
                logger.info("Agent type is LANGCHAIN")
                return AgentFramework.LANGCHAIN
            if attributes.get("smolagents.max_steps"):
                logger.info("Agent type is SMOLAGENTS")
                return AgentFramework.SMOLAGENTS
            # This is extremely fragile but there currently isn't
            # any specific key to indicate the agent type
            if span.get("name") == "response":
                logger.info("Agent type is OPENAI")
                return AgentFramework.OPENAI
        raise ValueError(
            "Could not determine agent type from trace, or agent type not supported"
        )

    @abstractmethod
    def extract_hypothesis_answer(self, trace: List[Dict[str, Any]]) -> str:
        """Extract the hypothesis agent final answer from the trace."""

    @abstractmethod
    def _extract_telemetry_data(self, telemetry: List[Dict[str, Any]]) -> List[Dict]:
        """Extract the agent-specific data from telemetry."""

    def extract_evidence(self, telemetry: List[Dict[str, Any]]) -> str:
        """Extract relevant telemetry evidence.

        Args:
            telemetry: The raw trace spans.

        Returns:
            Markdown-formatted evidence produced by ``_format_evidence`` from
            the calls returned by ``_extract_telemetry_data``.
        """
        calls = self._extract_telemetry_data(telemetry)
        return self._format_evidence(calls)

    def _format_evidence(self, calls: List[Dict]) -> str:
        """Format extracted data into a standardized output format.

        The output starts with a ``## <FRAMEWORK> Agent Execution`` heading,
        followed by one ``### Call <n>`` section (numbered from 1) per call,
        each containing the call dict pretty-printed as JSON.

        Only top-level string values are truncated to
        ``MAX_EVIDENCE_LENGTH`` characters (plus ``"..."``); nested values are
        emitted as-is.

        Args:
            calls: The extracted call dictionaries.

        Returns:
            The formatted evidence string.
        """
        limit = self.MAX_EVIDENCE_LENGTH
        parts = [f"## {self._get_agent_framework().name} Agent Execution\n\n"]

        for idx, call in enumerate(calls, start=1):
            truncated = {
                key: (
                    value[:limit] + "..."
                    if isinstance(value, str) and len(value) > limit
                    else value
                )
                for key, value in call.items()
            }
            parts.append(f"### Call {idx}\n")
            # Use ensure_ascii=False to prevent escaping Unicode characters
            parts.append(json.dumps(truncated, indent=2, ensure_ascii=False) + "\n\n")

        # Join once instead of repeated ``+=`` to avoid quadratic concatenation.
        return "".join(parts)

    @abstractmethod
    def _get_agent_framework(self) -> AgentFramework:
        """Get the agent type associated with this processor."""

    @staticmethod
    def parse_generic_key_value_string(text: str) -> Dict[str, str]:
        """
        Parse a string that has items of a dict with key-value pairs separated by '='.
        Only splits on '=' signs, handling quoted strings properly.

        Keys are word characters (``\\w+``). A value is either a single- or
        double-quoted string (quotes are stripped, and the string may span
        multiple lines), or an unquoted run without quotes or '='. A value ends
        where the next ``<whitespace>key=`` begins, or at the end of the text.

        Example: ``"a=1 b='x y'"`` -> ``{"a": "1", "b": "x y"}``.

        Args:
            text: The string to parse.

        Returns:
            A mapping of key to (unquoted) value; later duplicate keys win.
        """
        pattern = r"(\w+)=('.*?'|\".*?\"|[^'\"=]*?)(?=\s+\w+=|\s*$)"
        result: Dict[str, str] = {}

        # DOTALL lets quoted values contain newlines (e.g. multi-line message
        # content); without it such keys were silently dropped.
        for key, value in re.findall(pattern, text, flags=re.DOTALL):
            # Remove surrounding matching quotes if present
            if len(value) >= 2 and value[0] == value[-1] and value[0] in "'\"":
                value = value[1:-1]
            result[key] = value

        return result