Spaces:
Running
Running
Nathan Brake
committed on
Refactor agent type handling to use AgentFramework enum. Remove agent_type from YAML configurations and update telemetry processing to accommodate the new framework structure. Enhance smolagents configuration with file saving capabilities. (#43)
Browse files
- examples/langchain_single_agent_user_confirmation.yaml +0 -1
- examples/openai_single_agent_user_confirmation.yaml +0 -1
- examples/smolagents_single_agent_user_confirmation.yaml +16 -2
- src/surf_spot_finder/config.py +2 -1
- src/surf_spot_finder/evaluation/__init__.py +0 -8
- src/surf_spot_finder/evaluation/evaluate.py +2 -2
- src/surf_spot_finder/evaluation/telemetry/langchain_telemetry.py +4 -3
- src/surf_spot_finder/evaluation/telemetry/openai_telemetry.py +10 -7
- src/surf_spot_finder/evaluation/telemetry/smolagents_telemetry.py +4 -3
- src/surf_spot_finder/evaluation/telemetry/telemetry.py +12 -13
- src/surf_spot_finder/evaluation/test_cases/alpha.yaml +4 -5
examples/langchain_single_agent_user_confirmation.yaml
CHANGED
@@ -14,7 +14,6 @@ main_agent:
|
|
14 |
model_id: gpt-4o
|
15 |
# model_id: ollama/llama3.1:latest
|
16 |
api_key_var: OPENAI_API_KEY
|
17 |
-
agent_type: langchain
|
18 |
tools:
|
19 |
- "surf_spot_finder.tools.driving_hours_to_meters"
|
20 |
- "surf_spot_finder.tools.get_area_lat_lon"
|
|
|
14 |
model_id: gpt-4o
|
15 |
# model_id: ollama/llama3.1:latest
|
16 |
api_key_var: OPENAI_API_KEY
|
|
|
17 |
tools:
|
18 |
- "surf_spot_finder.tools.driving_hours_to_meters"
|
19 |
- "surf_spot_finder.tools.get_area_lat_lon"
|
examples/openai_single_agent_user_confirmation.yaml
CHANGED
@@ -13,7 +13,6 @@ framework: openai
|
|
13 |
main_agent:
|
14 |
model_id: gpt-4o
|
15 |
api_key_var: OPENAI_API_KEY
|
16 |
-
agent_type: openai
|
17 |
tools:
|
18 |
- "surf_spot_finder.tools.driving_hours_to_meters"
|
19 |
- "surf_spot_finder.tools.get_area_lat_lon"
|
|
|
13 |
main_agent:
|
14 |
model_id: gpt-4o
|
15 |
api_key_var: OPENAI_API_KEY
|
|
|
16 |
tools:
|
17 |
- "surf_spot_finder.tools.driving_hours_to_meters"
|
18 |
- "surf_spot_finder.tools.get_area_lat_lon"
|
examples/smolagents_single_agent_user_confirmation.yaml
CHANGED
@@ -1,4 +1,3 @@
|
|
1 |
-
|
2 |
location: Pontevedra
|
3 |
date: 2025-03-27 12:00
|
4 |
max_driving_hours: 2
|
@@ -7,6 +6,8 @@ input_prompt_template: |
|
|
7 |
in a {MAX_DRIVING_HOURS} hour driving radius, at {DATE}?
|
8 |
Find a few options and then discuss it with David de la Iglesia Castro. You should recommend him some choices,
|
9 |
and then confirm the final selection with him.
|
|
|
|
|
10 |
|
11 |
framework: smolagents
|
12 |
|
@@ -14,7 +15,6 @@ main_agent:
|
|
14 |
model_id: openai/gpt-4o
|
15 |
# model_id: ollama/llama3.1:latest
|
16 |
api_key_var: OPENAI_API_KEY
|
17 |
-
agent_type: smolagents
|
18 |
tools:
|
19 |
- "surf_spot_finder.tools.driving_hours_to_meters"
|
20 |
- "surf_spot_finder.tools.get_area_lat_lon"
|
@@ -23,3 +23,17 @@ main_agent:
|
|
23 |
- "surf_spot_finder.tools.get_wind_forecast"
|
24 |
- "any_agent.tools.send_console_message"
|
25 |
- "smolagents.FinalAnswerTool"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
location: Pontevedra
|
2 |
date: 2025-03-27 12:00
|
3 |
max_driving_hours: 2
|
|
|
6 |
in a {MAX_DRIVING_HOURS} hour driving radius, at {DATE}?
|
7 |
Find a few options and then discuss it with David de la Iglesia Castro. You should recommend him some choices,
|
8 |
and then confirm the final selection with him.
|
9 |
+
Once he gives the final selection, save a detailed description of the weather at the chosen location into a file
|
10 |
+
named "final_answer.txt". Also save a file called "history.txt" which has a list of your thought process in the choice.
|
11 |
|
12 |
framework: smolagents
|
13 |
|
|
|
15 |
model_id: openai/gpt-4o
|
16 |
# model_id: ollama/llama3.1:latest
|
17 |
api_key_var: OPENAI_API_KEY
|
|
|
18 |
tools:
|
19 |
- "surf_spot_finder.tools.driving_hours_to_meters"
|
20 |
- "surf_spot_finder.tools.get_area_lat_lon"
|
|
|
23 |
- "surf_spot_finder.tools.get_wind_forecast"
|
24 |
- "any_agent.tools.send_console_message"
|
25 |
- "smolagents.FinalAnswerTool"
|
26 |
+
- command: "docker"
|
27 |
+
args:
|
28 |
+
- "run"
|
29 |
+
- "-i"
|
30 |
+
- "--rm"
|
31 |
+
- "--mount"
|
32 |
+
- "type=bind,src=/tmp/surf-spot-finder,dst=/projects"
|
33 |
+
- "mcp/filesystem"
|
34 |
+
- "/projects"
|
35 |
+
tools:
|
36 |
+
- "read_file"
|
37 |
+
- "write_file"
|
38 |
+
- "directory_tree"
|
39 |
+
- "list_allowed_directories"
|
src/surf_spot_finder/config.py
CHANGED
@@ -1,5 +1,6 @@
|
|
1 |
from typing import Annotated
|
2 |
|
|
|
3 |
from any_agent.schema import AgentSchema
|
4 |
from pydantic import AfterValidator, BaseModel, ConfigDict, FutureDatetime, PositiveInt
|
5 |
import yaml
|
@@ -29,7 +30,7 @@ class Config(BaseModel):
|
|
29 |
INPUT_PROMPT_TEMPLATE
|
30 |
)
|
31 |
|
32 |
-
framework:
|
33 |
|
34 |
main_agent: AgentSchema
|
35 |
managed_agents: list[AgentSchema] | None = None
|
|
|
1 |
from typing import Annotated
|
2 |
|
3 |
+
from any_agent import AgentFramework
|
4 |
from any_agent.schema import AgentSchema
|
5 |
from pydantic import AfterValidator, BaseModel, ConfigDict, FutureDatetime, PositiveInt
|
6 |
import yaml
|
|
|
30 |
INPUT_PROMPT_TEMPLATE
|
31 |
)
|
32 |
|
33 |
+
framework: AgentFramework
|
34 |
|
35 |
main_agent: AgentSchema
|
36 |
managed_agents: list[AgentSchema] | None = None
|
src/surf_spot_finder/evaluation/__init__.py
CHANGED
@@ -1,8 +0,0 @@
|
|
1 |
-
from enum import Enum
|
2 |
-
|
3 |
-
|
4 |
-
class AgentType(str, Enum):
|
5 |
-
LANGCHAIN = "langchain"
|
6 |
-
OPENAI = "openai"
|
7 |
-
OPENAI_MULTI_AGENT = "openai_multi_agent"
|
8 |
-
SMOLAGENTS = "smolagents"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/surf_spot_finder/evaluation/evaluate.py
CHANGED
@@ -63,10 +63,10 @@ def evaluate_telemetry(test_case: TestCase, telemetry_path: str) -> bool:
|
|
63 |
telemetry: List[Dict[str, Any]] = json.loads(f.read())
|
64 |
logger.info(f"Telemetry loaded from {telemetry_path}")
|
65 |
|
66 |
-
|
67 |
|
68 |
# Extract the final answer from the telemetry
|
69 |
-
processor = TelemetryProcessor.create(
|
70 |
hypothesis_answer = processor.extract_hypothesis_answer(trace=telemetry)
|
71 |
|
72 |
# Checkpoint evaluation
|
|
|
63 |
telemetry: List[Dict[str, Any]] = json.loads(f.read())
|
64 |
logger.info(f"Telemetry loaded from {telemetry_path}")
|
65 |
|
66 |
+
agent_framework = TelemetryProcessor.determine_agent_framework(telemetry)
|
67 |
|
68 |
# Extract the final answer from the telemetry
|
69 |
+
processor = TelemetryProcessor.create(agent_framework)
|
70 |
hypothesis_answer = processor.extract_hypothesis_answer(trace=telemetry)
|
71 |
|
72 |
# Checkpoint evaluation
|
src/surf_spot_finder/evaluation/telemetry/langchain_telemetry.py
CHANGED
@@ -1,16 +1,17 @@
|
|
1 |
from typing import Any, Dict, List
|
2 |
import json
|
|
|
3 |
from langchain_core.messages import BaseMessage
|
4 |
|
5 |
-
|
6 |
from surf_spot_finder.evaluation.telemetry import TelemetryProcessor
|
7 |
|
8 |
|
9 |
class LangchainTelemetryProcessor(TelemetryProcessor):
|
10 |
"""Processor for Langchain agent telemetry data."""
|
11 |
|
12 |
-
def
|
13 |
-
return
|
14 |
|
15 |
def extract_hypothesis_answer(self, trace: List[Dict[str, Any]]) -> str:
|
16 |
for span in reversed(trace):
|
|
|
1 |
from typing import Any, Dict, List
|
2 |
import json
|
3 |
+
from any_agent import AgentFramework
|
4 |
from langchain_core.messages import BaseMessage
|
5 |
|
6 |
+
|
7 |
from surf_spot_finder.evaluation.telemetry import TelemetryProcessor
|
8 |
|
9 |
|
10 |
class LangchainTelemetryProcessor(TelemetryProcessor):
|
11 |
"""Processor for Langchain agent telemetry data."""
|
12 |
|
13 |
+
def _get_agent_framework(self) -> AgentFramework:
|
14 |
+
return AgentFramework.LANGCHAIN
|
15 |
|
16 |
def extract_hypothesis_answer(self, trace: List[Dict[str, Any]]) -> str:
|
17 |
for span in reversed(trace):
|
src/surf_spot_finder/evaluation/telemetry/openai_telemetry.py
CHANGED
@@ -1,15 +1,16 @@
|
|
1 |
from typing import Any, Dict, List
|
2 |
import json
|
3 |
|
4 |
-
from
|
|
|
5 |
from surf_spot_finder.evaluation.telemetry import TelemetryProcessor
|
6 |
|
7 |
|
8 |
class OpenAITelemetryProcessor(TelemetryProcessor):
|
9 |
"""Processor for OpenAI agent telemetry data."""
|
10 |
|
11 |
-
def
|
12 |
-
return
|
13 |
|
14 |
def extract_hypothesis_answer(self, trace: List[Dict[str, Any]]) -> str:
|
15 |
for span in reversed(trace):
|
@@ -82,10 +83,10 @@ class OpenAITelemetryProcessor(TelemetryProcessor):
|
|
82 |
|
83 |
# Backward compatibility functions that use the new class structure
|
84 |
def extract_hypothesis_answer(
|
85 |
-
trace: List[Dict[str, Any]],
|
86 |
) -> str:
|
87 |
"""Extract the hypothesis agent final answer from the trace"""
|
88 |
-
processor = TelemetryProcessor.create(
|
89 |
return processor.extract_hypothesis_answer(trace)
|
90 |
|
91 |
|
@@ -97,7 +98,9 @@ def parse_generic_key_value_string(text: str) -> Dict[str, str]:
|
|
97 |
return TelemetryProcessor.parse_generic_key_value_string(text)
|
98 |
|
99 |
|
100 |
-
def extract_evidence(
|
|
|
|
|
101 |
"""Extract relevant telemetry evidence based on the agent type."""
|
102 |
-
processor = TelemetryProcessor.create(
|
103 |
return processor.extract_evidence(telemetry)
|
|
|
1 |
from typing import Any, Dict, List
|
2 |
import json
|
3 |
|
4 |
+
from any_agent import AgentFramework
|
5 |
+
|
6 |
from surf_spot_finder.evaluation.telemetry import TelemetryProcessor
|
7 |
|
8 |
|
9 |
class OpenAITelemetryProcessor(TelemetryProcessor):
|
10 |
"""Processor for OpenAI agent telemetry data."""
|
11 |
|
12 |
+
def _get_agent_framework(self) -> AgentFramework:
|
13 |
+
return AgentFramework.OPENAI
|
14 |
|
15 |
def extract_hypothesis_answer(self, trace: List[Dict[str, Any]]) -> str:
|
16 |
for span in reversed(trace):
|
|
|
83 |
|
84 |
# Backward compatibility functions that use the new class structure
|
85 |
def extract_hypothesis_answer(
|
86 |
+
trace: List[Dict[str, Any]], agent_framework: AgentFramework
|
87 |
) -> str:
|
88 |
"""Extract the hypothesis agent final answer from the trace"""
|
89 |
+
processor = TelemetryProcessor.create(agent_framework)
|
90 |
return processor.extract_hypothesis_answer(trace)
|
91 |
|
92 |
|
|
|
98 |
return TelemetryProcessor.parse_generic_key_value_string(text)
|
99 |
|
100 |
|
101 |
+
def extract_evidence(
|
102 |
+
telemetry: List[Dict[str, Any]], agent_framework: AgentFramework
|
103 |
+
) -> str:
|
104 |
"""Extract relevant telemetry evidence based on the agent type."""
|
105 |
+
processor = TelemetryProcessor.create(agent_framework)
|
106 |
return processor.extract_evidence(telemetry)
|
src/surf_spot_finder/evaluation/telemetry/smolagents_telemetry.py
CHANGED
@@ -1,15 +1,16 @@
|
|
1 |
from typing import Any, Dict, List
|
2 |
import json
|
3 |
|
4 |
-
from
|
|
|
5 |
from surf_spot_finder.evaluation.telemetry import TelemetryProcessor
|
6 |
|
7 |
|
8 |
class SmolagentsTelemetryProcessor(TelemetryProcessor):
|
9 |
"""Processor for SmoL Agents telemetry data."""
|
10 |
|
11 |
-
def
|
12 |
-
return
|
13 |
|
14 |
def extract_hypothesis_answer(self, trace: List[Dict[str, Any]]) -> str:
|
15 |
for span in reversed(trace):
|
|
|
1 |
from typing import Any, Dict, List
|
2 |
import json
|
3 |
|
4 |
+
from any_agent import AgentFramework
|
5 |
+
|
6 |
from surf_spot_finder.evaluation.telemetry import TelemetryProcessor
|
7 |
|
8 |
|
9 |
class SmolagentsTelemetryProcessor(TelemetryProcessor):
|
10 |
"""Processor for SmoL Agents telemetry data."""
|
11 |
|
12 |
+
def _get_agent_framework(self) -> AgentFramework:
|
13 |
+
return AgentFramework.SMOLAGENTS
|
14 |
|
15 |
def extract_hypothesis_answer(self, trace: List[Dict[str, Any]]) -> str:
|
16 |
for span in reversed(trace):
|
src/surf_spot_finder/evaluation/telemetry/telemetry.py
CHANGED
@@ -2,10 +2,9 @@ from typing import Any, Dict, List, ClassVar
|
|
2 |
import json
|
3 |
import re
|
4 |
from abc import ABC, abstractmethod
|
|
|
5 |
from loguru import logger
|
6 |
|
7 |
-
from surf_spot_finder.evaluation import AgentType
|
8 |
-
|
9 |
|
10 |
class TelemetryProcessor(ABC):
|
11 |
"""Base class for processing telemetry data from different agent types."""
|
@@ -13,31 +12,31 @@ class TelemetryProcessor(ABC):
|
|
13 |
MAX_EVIDENCE_LENGTH: ClassVar[int] = 400
|
14 |
|
15 |
@classmethod
|
16 |
-
def create(cls,
|
17 |
"""Factory method to create the appropriate telemetry processor."""
|
18 |
-
if
|
19 |
from surf_spot_finder.evaluation.telemetry.langchain_telemetry import (
|
20 |
LangchainTelemetryProcessor,
|
21 |
)
|
22 |
|
23 |
return LangchainTelemetryProcessor()
|
24 |
-
elif
|
25 |
from surf_spot_finder.evaluation.telemetry.smolagents_telemetry import (
|
26 |
SmolagentsTelemetryProcessor,
|
27 |
)
|
28 |
|
29 |
return SmolagentsTelemetryProcessor()
|
30 |
-
elif
|
31 |
from surf_spot_finder.evaluation.telemetry.openai_telemetry import (
|
32 |
OpenAITelemetryProcessor,
|
33 |
)
|
34 |
|
35 |
return OpenAITelemetryProcessor()
|
36 |
else:
|
37 |
-
raise ValueError(f"Unsupported agent type {
|
38 |
|
39 |
@staticmethod
|
40 |
-
def
|
41 |
"""Determine the agent type based on the trace.
|
42 |
These are not really stable ways to find it, because we're waiting on some
|
43 |
reliable method for determining the agent type. This is a temporary solution.
|
@@ -45,15 +44,15 @@ class TelemetryProcessor(ABC):
|
|
45 |
for span in trace:
|
46 |
if "langchain" in span.get("attributes", {}).get("input.value", ""):
|
47 |
logger.info("Agent type is LANGCHAIN")
|
48 |
-
return
|
49 |
if span.get("attributes", {}).get("smolagents.max_steps"):
|
50 |
logger.info("Agent type is SMOLAGENTS")
|
51 |
-
return
|
52 |
# This is extremely fragile but there currently isn't
|
53 |
# any specific key to indicate the agent type
|
54 |
if span.get("name") == "response":
|
55 |
logger.info("Agent type is OPENAI")
|
56 |
-
return
|
57 |
raise ValueError(
|
58 |
"Could not determine agent type from trace, or agent type not supported"
|
59 |
)
|
@@ -75,7 +74,7 @@ class TelemetryProcessor(ABC):
|
|
75 |
|
76 |
def _format_evidence(self, calls: List[Dict]) -> str:
|
77 |
"""Format extracted data into a standardized output format."""
|
78 |
-
evidence = f"## {self.
|
79 |
|
80 |
for idx, call in enumerate(calls, start=1):
|
81 |
evidence += f"### Call {idx}\n"
|
@@ -96,7 +95,7 @@ class TelemetryProcessor(ABC):
|
|
96 |
return evidence
|
97 |
|
98 |
@abstractmethod
|
99 |
-
def
|
100 |
"""Get the agent type associated with this processor."""
|
101 |
pass
|
102 |
|
|
|
2 |
import json
|
3 |
import re
|
4 |
from abc import ABC, abstractmethod
|
5 |
+
from any_agent import AgentFramework
|
6 |
from loguru import logger
|
7 |
|
|
|
|
|
8 |
|
9 |
class TelemetryProcessor(ABC):
|
10 |
"""Base class for processing telemetry data from different agent types."""
|
|
|
12 |
MAX_EVIDENCE_LENGTH: ClassVar[int] = 400
|
13 |
|
14 |
@classmethod
|
15 |
+
def create(cls, agent_framework: AgentFramework) -> "TelemetryProcessor":
|
16 |
"""Factory method to create the appropriate telemetry processor."""
|
17 |
+
if agent_framework == AgentFramework.LANGCHAIN:
|
18 |
from surf_spot_finder.evaluation.telemetry.langchain_telemetry import (
|
19 |
LangchainTelemetryProcessor,
|
20 |
)
|
21 |
|
22 |
return LangchainTelemetryProcessor()
|
23 |
+
elif agent_framework == AgentFramework.SMOLAGENTS:
|
24 |
from surf_spot_finder.evaluation.telemetry.smolagents_telemetry import (
|
25 |
SmolagentsTelemetryProcessor,
|
26 |
)
|
27 |
|
28 |
return SmolagentsTelemetryProcessor()
|
29 |
+
elif agent_framework == AgentFramework.OPENAI:
|
30 |
from surf_spot_finder.evaluation.telemetry.openai_telemetry import (
|
31 |
OpenAITelemetryProcessor,
|
32 |
)
|
33 |
|
34 |
return OpenAITelemetryProcessor()
|
35 |
else:
|
36 |
+
raise ValueError(f"Unsupported agent type {agent_framework}")
|
37 |
|
38 |
@staticmethod
|
39 |
+
def determine_agent_framework(trace: List[Dict[str, Any]]) -> AgentFramework:
|
40 |
"""Determine the agent type based on the trace.
|
41 |
These are not really stable ways to find it, because we're waiting on some
|
42 |
reliable method for determining the agent type. This is a temporary solution.
|
|
|
44 |
for span in trace:
|
45 |
if "langchain" in span.get("attributes", {}).get("input.value", ""):
|
46 |
logger.info("Agent type is LANGCHAIN")
|
47 |
+
return AgentFramework.LANGCHAIN
|
48 |
if span.get("attributes", {}).get("smolagents.max_steps"):
|
49 |
logger.info("Agent type is SMOLAGENTS")
|
50 |
+
return AgentFramework.SMOLAGENTS
|
51 |
# This is extremely fragile but there currently isn't
|
52 |
# any specific key to indicate the agent type
|
53 |
if span.get("name") == "response":
|
54 |
logger.info("Agent type is OPENAI")
|
55 |
+
return AgentFramework.OPENAI
|
56 |
raise ValueError(
|
57 |
"Could not determine agent type from trace, or agent type not supported"
|
58 |
)
|
|
|
74 |
|
75 |
def _format_evidence(self, calls: List[Dict]) -> str:
|
76 |
"""Format extracted data into a standardized output format."""
|
77 |
+
evidence = f"## {self._get_agent_framework().name} Agent Execution\n\n"
|
78 |
|
79 |
for idx, call in enumerate(calls, start=1):
|
80 |
evidence += f"### Call {idx}\n"
|
|
|
95 |
return evidence
|
96 |
|
97 |
@abstractmethod
|
98 |
+
def _get_agent_framework(self) -> AgentFramework:
|
99 |
"""Get the agent type associated with this processor."""
|
100 |
pass
|
101 |
|
src/surf_spot_finder/evaluation/test_cases/alpha.yaml
CHANGED
@@ -6,13 +6,12 @@ input:
|
|
6 |
location: "Vigo"
|
7 |
date: "2025-03-27 22:00"
|
8 |
max_driving_hours: 3
|
9 |
-
json_tracer: true
|
10 |
|
11 |
|
12 |
-
ground_truth:
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
|
17 |
# Base checkpoints for agent behavior
|
18 |
# These evaluators for these checkpoints
|
|
|
6 |
location: "Vigo"
|
7 |
date: "2025-03-27 22:00"
|
8 |
max_driving_hours: 3
|
|
|
9 |
|
10 |
|
11 |
+
# ground_truth:
|
12 |
+
# - name: "Surf location"
|
13 |
+
# points: 5
|
14 |
+
# value: "Playa de Samil"
|
15 |
|
16 |
# Base checkpoints for agent behavior
|
17 |
# These evaluators for these checkpoints
|