Commit by Nathan Brake: support for anyagent (#41)
Files changed:
- .gitignore +0 -1
- examples/langchain_single_agent_user_confirmation.yaml +15 -12
- examples/openai_single_agent_user_confirmation.yaml +14 -11
- examples/smolagents_single_agent_user_confirmation.yaml +14 -11
- pyproject.toml +0 -2
- src/surf_spot_finder/config.py +15 -0
- src/surf_spot_finder/evaluation/__init__.py +8 -0
- src/surf_spot_finder/evaluation/evaluate.py +34 -21
- src/surf_spot_finder/evaluation/telemetry/langchain_telemetry.py +1 -1
- src/surf_spot_finder/evaluation/telemetry/openai_telemetry.py +1 -1
- src/surf_spot_finder/evaluation/telemetry/smolagents_telemetry.py +1 -1
- src/surf_spot_finder/evaluation/telemetry/telemetry.py +2 -1
- src/surf_spot_finder/evaluation/test_case.py +9 -7
.gitignore (CHANGED)
@@ -168,4 +168,3 @@ cython_debug/
 .vscode/

 output
-telemetry_output
examples/langchain_single_agent_user_confirmation.yaml (CHANGED)
@@ -1,14 +1,17 @@
-(11 removed lines; content not captured in this view)
+
+location: Pontevedra
+date: 2025-03-27 12:00
+max_driving_hours: 2
+input_prompt_template: |
+  According to the forecast, what will be the best spot to surf around {LOCATION},
+  in a {MAX_DRIVING_HOURS} hour driving radius, at {DATE}?
+  Find a few options and then discuss it with David de la Iglesia Castro. You should recommend him some choices,
+  and then confirm the final selection with him.
+
+framework: langchain
+
+main_agent:
+  model_id: gpt-4o
   # model_id: ollama/llama3.1:latest
   api_key_var: OPENAI_API_KEY
   agent_type: langchain
@@ -18,4 +21,4 @@ agent:
   - "surf_spot_finder.tools.get_surfing_spots"
   - "surf_spot_finder.tools.get_wave_forecast"
   - "surf_spot_finder.tools.get_wind_forecast"
-  - "any_agent.tools.
+  - "any_agent.tools.send_console_message"
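Note: the {LOCATION}, {MAX_DRIVING_HOURS} and {DATE} placeholders in input_prompt_template are plain str.format fields, which the new run() in evaluate.py (further down) fills from the config. A minimal sketch of that substitution, reusing the values from this example file:

# Sketch only: mirrors the config.input_prompt_template.format(...) call added
# in evaluate.py; template text and values are copied from this example config.
template = (
    "According to the forecast, what will be the best spot to surf around {LOCATION},\n"
    "in a {MAX_DRIVING_HOURS} hour driving radius, at {DATE}?"
)
query = template.format(
    LOCATION="Pontevedra",
    MAX_DRIVING_HOURS=2,
    DATE="2025-03-27 12:00",
)
print(query)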
examples/openai_single_agent_user_confirmation.yaml (CHANGED)
@@ -1,13 +1,16 @@
-(10 removed lines; content not captured in this view)
+
+location: Pontevedra
+date: 2025-03-27 12:00
+max_driving_hours: 2
+input_prompt_template: |
+  According to the forecast, what will be the best spot to surf around {LOCATION},
+  in a {MAX_DRIVING_HOURS} hour driving radius, at {DATE}?
+  Find a few options and then discuss it with David de la Iglesia Castro. You should recommend him some choices,
+  and then confirm the final selection with him.
+
+framework: openai
+
+main_agent:
   model_id: gpt-4o
   api_key_var: OPENAI_API_KEY
   agent_type: openai
@@ -17,4 +20,4 @@ agent:
   - "surf_spot_finder.tools.get_surfing_spots"
   - "surf_spot_finder.tools.get_wave_forecast"
   - "surf_spot_finder.tools.get_wind_forecast"
-  - "any_agent.tools.
+  - "any_agent.tools.send_console_message"
examples/smolagents_single_agent_user_confirmation.yaml (CHANGED)
@@ -1,13 +1,16 @@
-(10 removed lines; content not captured in this view)
+
+location: Pontevedra
+date: 2025-03-27 12:00
+max_driving_hours: 2
+input_prompt_template: |
+  According to the forecast, what will be the best spot to surf around {LOCATION},
+  in a {MAX_DRIVING_HOURS} hour driving radius, at {DATE}?
+  Find a few options and then discuss it with David de la Iglesia Castro. You should recommend him some choices,
+  and then confirm the final selection with him.
+
+framework: smolagents
+
+main_agent:
   model_id: openai/gpt-4o
   # model_id: ollama/llama3.1:latest
   api_key_var: OPENAI_API_KEY
@@ -18,5 +21,5 @@ agent:
   - "surf_spot_finder.tools.get_surfing_spots"
   - "surf_spot_finder.tools.get_wave_forecast"
   - "surf_spot_finder.tools.get_wind_forecast"
-  - "any_agent.tools.
+  - "any_agent.tools.send_console_message"
   - "smolagents.FinalAnswerTool"
pyproject.toml (CHANGED)
@@ -55,5 +55,3 @@ dev = [
 [project.scripts]
 surf-spot-finder = "surf_spot_finder.cli:main"
 surf-spot-finder-evaluate = "surf_spot_finder.evaluation.evaluate:main"
-# TODO maybe this would be lumigator
-start-phoenix = "phoenix.server.main:main"
src/surf_spot_finder/config.py (CHANGED)
@@ -2,6 +2,7 @@ from typing import Annotated

 from any_agent.schema import AgentSchema
 from pydantic import AfterValidator, BaseModel, ConfigDict, FutureDatetime, PositiveInt
+import yaml


 INPUT_PROMPT_TEMPLATE = """
@@ -32,3 +33,17 @@ class Config(BaseModel):

     main_agent: AgentSchema
     managed_agents: list[AgentSchema] | None = None
+
+    @classmethod
+    def from_yaml(cls, yaml_path: str) -> "Config":
+        """Create a Config instance from a YAML file.
+
+        Args:
+            yaml_path: Path to the YAML configuration file
+
+        Returns:
+            Config: A new Config instance populated with values from the YAML file
+        """
+        with open(yaml_path, "r") as f:
+            data = yaml.safe_load(f)
+        return cls(**data)
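Note: a minimal usage sketch of the new Config.from_yaml classmethod against one of the example files in this commit; it assumes the script runs from the repository root and that AgentSchema exposes the model_id field seen in the YAML:

# Sketch only: load an example config with the new classmethod.
from surf_spot_finder.config import Config

config = Config.from_yaml("examples/openai_single_agent_user_confirmation.yaml")
print(config.framework)            # "openai"
print(config.main_agent.model_id)  # "gpt-4o" (assumes AgentSchema has a model_id field)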
src/surf_spot_finder/evaluation/__init__.py (CHANGED)
@@ -0,0 +1,8 @@
+from enum import Enum
+
+
+class AgentType(str, Enum):
+    LANGCHAIN = "langchain"
+    OPENAI = "openai"
+    OPENAI_MULTI_AGENT = "openai_multi_agent"
+    SMOLAGENTS = "smolagents"
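Note: because AgentType subclasses str, the framework strings used in the YAML configs map directly onto enum members. A small illustration:

# Sketch only: str-backed enum lookup and comparison.
from surf_spot_finder.evaluation import AgentType

agent_type = AgentType("langchain")   # built from the raw config string
assert agent_type is AgentType.LANGCHAIN
assert agent_type == "langchain"      # str subclass, so it compares equal to the plain string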
src/surf_spot_finder/evaluation/evaluate.py (CHANGED)
@@ -6,7 +6,6 @@ from typing import Any, Dict, List, Optional
 from loguru import logger
 from fire import Fire
 import pandas as pd
-from surf_spot_finder.cli import find_surf_spot
 from surf_spot_finder.config import (
     Config,
 )
@@ -17,13 +16,15 @@ from surf_spot_finder.evaluation.evaluators import (
     HypothesisEvaluator,
 )
 from surf_spot_finder.evaluation.test_case import TestCase
+from any_agent import load_agent, run_agent
+from any_agent.tracing import get_tracer_provider, setup_tracing

 logger.remove()
 logger = logger.opt(ansi=True)
 logger.add(sys.stdout, colorize=True, format="{message}")


-def run_agent(test_case: TestCase, agent_config_path: str) -> str:
+def run(test_case: TestCase, agent_config_path: str) -> str:
     input_data = test_case.input

     logger.info("Loading config")
@@ -31,20 +32,30 @@ def run_agent(test_case: TestCase, agent_config_path: str) -> str:
     config.location = input_data.location
     config.date = input_data.date
     config.max_driving_hours = input_data.max_driving_hours
-(10 removed lines; content not captured in this view)
-        tools=config.tools,
-        input_prompt_template=config.input_prompt_template,
+    logger.info("Setting up tracing")
+    tracer_provider, tracing_path = get_tracer_provider(project_name="surf-spot-finder")
+    setup_tracing(tracer_provider, config.framework)
+
+    logger.info(f"Loading {config.framework} agent")
+    logger.info(f"{config.managed_agents}")
+    agent = load_agent(
+        framework=config.framework,
+        main_agent=config.main_agent,
+        managed_agents=config.managed_agents,
     )

+    query = config.input_prompt_template.format(
+        LOCATION=config.location,
+        MAX_DRIVING_HOURS=config.max_driving_hours,
+        DATE=config.date,
+    )
+    logger.info(f"Running agent with query:\n{query}")
+    run_agent(agent, query)
+
+    logger.success("Done!")
+
+    return tracing_path
+

 def evaluate_telemetry(test_case: TestCase, telemetry_path: str) -> bool:
     # load the json file
@@ -75,12 +86,14 @@ def evaluate_telemetry(test_case: TestCase, telemetry_path: str) -> bool:
     )

     # Direct answer evaluation (new)
-(6 removed lines; content not captured in this view)
+    if test_case.ground_truth:
+        direct_evaluator = QuestionAnsweringSquadEvaluator()
+        direct_results = direct_evaluator.evaluate(
+            hypothesis_answer=hypothesis_answer,
+            ground_truth_answer=test_case.ground_truth,
+        )
+    else:
+        direct_results = []
     # Combine all results
     verification_results = (
         checkpoint_results + hypothesis_answer_results + direct_results
@@ -171,7 +184,7 @@ def evaluate(
         assert (
             agent_config_path is not None
         ), "Agent config path must be provided if running agent"
-        telemetry_path =
+        telemetry_path = run(test_case, agent_config_path)
     else:
         logger.info(f"Using provided telemetry file: {telemetry_path}")
     logger.info(
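Note: the new flow in this file is run-the-agent, then score its trace. A rough sketch of how the two functions compose; how a TestCase instance is constructed is not shown in this diff, so that part is left as a placeholder:

# Sketch only: run() returns the tracing path, which evaluate_telemetry() consumes.
from surf_spot_finder.evaluation.evaluate import run, evaluate_telemetry

test_case = ...  # build a TestCase here; its loader is not part of this diff
telemetry_path = run(test_case, agent_config_path="examples/openai_single_agent_user_confirmation.yaml")
evaluate_telemetry(test_case, telemetry_path)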
src/surf_spot_finder/evaluation/telemetry/langchain_telemetry.py (CHANGED)
@@ -2,7 +2,7 @@ from typing import Any, Dict, List
 import json
 from langchain_core.messages import BaseMessage

-from surf_spot_finder.
+from surf_spot_finder.evaluation import AgentType
 from surf_spot_finder.evaluation.telemetry import TelemetryProcessor
src/surf_spot_finder/evaluation/telemetry/openai_telemetry.py (CHANGED)
@@ -1,7 +1,7 @@
 from typing import Any, Dict, List
 import json

-from surf_spot_finder.
+from surf_spot_finder.evaluation import AgentType
 from surf_spot_finder.evaluation.telemetry import TelemetryProcessor
src/surf_spot_finder/evaluation/telemetry/smolagents_telemetry.py (CHANGED)
@@ -1,7 +1,7 @@
 from typing import Any, Dict, List
 import json

-from surf_spot_finder.
+from surf_spot_finder.evaluation import AgentType
 from surf_spot_finder.evaluation.telemetry import TelemetryProcessor
src/surf_spot_finder/evaluation/telemetry/telemetry.py (CHANGED)
@@ -3,7 +3,8 @@ import json
 import re
 from abc import ABC, abstractmethod
 from loguru import logger
-(removed line; content not captured in this view)
+
+from surf_spot_finder.evaluation import AgentType


 class TelemetryProcessor(ABC):
src/surf_spot_finder/evaluation/test_case.py (CHANGED)
@@ -11,7 +11,6 @@ class InputModel(BaseModel):
     location: str
     date: str
     max_driving_hours: int
-    json_tracer: bool


 class CheckpointCriteria(BaseModel):
@@ -53,12 +52,15 @@ class TestCase(BaseModel):
            }
        )

-(6 removed lines; content not captured in this view)
+        if "ground_truth" in test_case_dict:
+            add_gt_final_answer_criteria(test_case_dict["ground_truth"])
+            test_case_dict["final_answer_criteria"] = final_answer_criteria
+            # remove the points from the ground_truth list but keep the name and value
+            test_case_dict["ground_truth"] = [
+                item
+                for item in test_case_dict["ground_truth"]
+                if isinstance(item, dict)
+            ]

         test_case_dict["test_case_path"] = test_case_path
         # verify that the llm_judge is a valid litellm model
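Note: the list comprehension added above only filters by type: dict entries in ground_truth are kept as-is (name, value, and any points they carry), while anything else is dropped. A tiny illustration with made-up data:

# Illustration only: the ground_truth entries below are invented; the real schema
# is not fully shown in this diff. The filter keeps dicts and drops other items.
ground_truth = [
    {"name": "surf spot", "value": "Playa de Patos", "points": 1},
    "a stray non-dict entry",
]
ground_truth = [item for item in ground_truth if isinstance(item, dict)]
print(ground_truth)  # -> [{'name': 'surf spot', 'value': 'Playa de Patos', 'points': 1}]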