github-actions[bot] committed
Commit 27f8cfc · 1 Parent(s): 980d57f
Sync with https://github.com/mozilla-ai/surf-spot-finder
Files changed:
- app.py +122 -12
- components/__init__.py +0 -0
- components/inputs.py +156 -0
- components/sidebar.py +9 -0
- constants.py +52 -0
- services/__init__.py +0 -0
- services/agent.py +168 -0
app.py
CHANGED
@@ -1,39 +1,149 @@
+from components.sidebar import ssf_sidebar
+from constants import DEFAULT_TOOLS
 import streamlit as st
 import asyncio
 import nest_asyncio
-from
-
-
+from services.agent import (
+    configure_agent,
+    display_evaluation_results,
+    display_output,
+    evaluate_agent,
+    run_agent,
+)
 
 nest_asyncio.apply()
 
 # Set page config
 st.set_page_config(page_title="Surf Spot Finder", page_icon="🏄", layout="wide")
 
-#
-st.title("🏄 Surf Spot Finder")
+# Allow a user to resize the sidebar to take up most of the screen to make editing eval cases easier
 st.markdown(
-    "
+    """
+    <style>
+    /* When sidebar is expanded, adjust main content */
+    section[data-testid="stSidebar"][aria-expanded="true"] {
+        max-width: 99% !important;
+    }
+    </style>
+    """,
+    unsafe_allow_html=True,
 )
 
-# Sidebar
 with st.sidebar:
-
-    st.markdown("Built using [Any-Agent](https://github.com/mozilla-ai/any-agent)")
-    user_inputs = get_user_inputs()
+    user_inputs = ssf_sidebar()
     is_valid = user_inputs is not None
-    run_button = st.button("Run", disabled=not is_valid, type="primary")
+    run_button = st.button("Run Agent 🤖", disabled=not is_valid, type="primary")
 
 
 # Main content
 async def main():
+    # Handle agent execution button click
     if run_button:
-        await
+        agent, agent_config = await configure_agent(user_inputs)
+        agent_trace, execution_time = await run_agent(agent, agent_config)
+
+        await display_output(agent_trace, execution_time)
+
+        evaluation_result = await evaluate_agent(agent_config, agent_trace)
+
+        await display_evaluation_results(evaluation_result)
     else:
+        st.title("🏄 Surf Spot Finder")
+        st.markdown(
+            "Find the best surfing spots based on your location and preferences! [Github Repo](https://github.com/mozilla-ai/surf-spot-finder)"
+        )
         st.info(
             "🏄 Configure your search parameters in the sidebar and click Run to start!"
         )
 
+        # Display tools in a more organized way
+        st.markdown("### 🛠️ Available Tools")
+
+        st.markdown("""
+        The AI Agent built for this project has a few tools available for use in order to find the perfect surf spot.
+        The agent is given the freedom to use (or not use) these tools in order to accomplish the task.
+        """)
+
+        weather_tools = [
+            tool
+            for tool in DEFAULT_TOOLS
+            if "forecast" in tool.__name__ or "weather" in tool.__name__
+        ]
+        for tool in weather_tools:
+            with st.expander(f"🌤️ {tool.__name__}"):
+                st.markdown(tool.__doc__ or "No description available")
+        location_tools = [
+            tool
+            for tool in DEFAULT_TOOLS
+            if "lat" in tool.__name__
+            or "lon" in tool.__name__
+            or "area" in tool.__name__
+        ]
+        for tool in location_tools:
+            with st.expander(f"📍 {tool.__name__}"):
+                st.markdown(tool.__doc__ or "No description available")
+
+        web_tools = [
+            tool
+            for tool in DEFAULT_TOOLS
+            if "web" in tool.__name__ or "search" in tool.__name__
+        ]
+        for tool in web_tools:
+            with st.expander(f"🌐 {tool.__name__}"):
+                st.markdown(tool.__doc__ or "No description available")
+
+        # add a check that all tools were listed
+        if len(weather_tools) + len(location_tools) + len(web_tools) != len(
+            DEFAULT_TOOLS
+        ):
+            st.warning(
+                "Some tools are not listed. Please check the code for more details."
+            )
+
+        # Add Custom Evaluation explanation section
+        st.markdown("### 📊 Custom Evaluation")
+        st.markdown("""
+        The Surf Spot Finder includes a powerful evaluation system that allows you to customize how the agent's performance is assessed.
+        You can find these settings in the sidebar under the "Custom Evaluation" expander.
+        """)
+
+        with st.expander("Learn more about Custom Evaluation"):
+            st.markdown("""
+            #### What is Custom Evaluation?
+            The Custom Evaluation feature uses an LLM-as-a-Judge approach to evaluate how well the agent performs its task.
+            An LLM will be given the complete agent trace (not just the final answer), and will assess the agent's performance based on the criteria you set.
+            You can customize:
+
+            - **Evaluation Model**: Choose which LLM should act as the judge
+            - **Evaluation Criteria**: Define specific checkpoints that the agent should meet
+            - **Scoring System**: Assign points to each criterion
+
+            #### How to Use Custom Evaluation
+
+            1. **Select an Evaluation Model**: Choose which LLM you want to use as the judge
+            2. **Edit Checkpoints**: Use the data editor to:
+               - Add new evaluation criteria
+               - Modify existing criteria
+               - Adjust point values
+               - Remove criteria you don't want to evaluate
+
+            #### Example Criteria
+            You can evaluate things like:
+            - Tool usage and success
+            - Order of operations
+            - Quality of final recommendations
+            - Response completeness
+            - Number of steps taken
+
+            #### Tips for Creating Good Evaluation Criteria
+            - Be specific about what you want to evaluate
+            - Use clear, unambiguous language
+            - Consider both process (how the agent works) and outcome (what it produces)
+            - Assign appropriate point values based on importance
+
+            The evaluation results will be displayed after each agent run, showing how well the agent met your custom criteria.
+            """)
+
 
 if __name__ == "__main__":
     loop = asyncio.new_event_loop()
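Aside: the new tool listing in app.py buckets DEFAULT_TOOLS into weather, location, and web groups by substring-matching each function's __name__, and warns if any tool falls through. A quick illustrative check, not part of the commit, using the five tool names from the constants.py diff further down:

    # Mirror the bucket checks from app.py on the default tool names.
    names = ["get_wind_forecast", "get_wave_forecast", "get_area_lat_lon", "search_web", "visit_webpage"]
    weather = [n for n in names if "forecast" in n or "weather" in n]         # wind + wave forecasts
    location = [n for n in names if "lat" in n or "lon" in n or "area" in n]  # get_area_lat_lon
    web = [n for n in names if "web" in n or "search" in n]                   # search_web, visit_webpage
    assert len(weather) + len(location) + len(web) == len(names)  # 2 + 1 + 2 == 5, so the warning never fires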
components/__init__.py
ADDED
File without changes
components/inputs.py
ADDED
@@ -0,0 +1,156 @@
from datetime import datetime, timedelta
import json
import requests
import streamlit as st
from any_agent import AgentFramework
from any_agent.tracing.trace import _is_tracing_supported
from any_agent.evaluation import EvaluationCase
from any_agent.evaluation.schemas import CheckpointCriteria
import pandas as pd
from constants import DEFAULT_EVALUATION_CASE, MODEL_OPTIONS

from pydantic import BaseModel, ConfigDict


class UserInputs(BaseModel):
    model_config = ConfigDict(extra="forbid")
    model_id: str
    location: str
    max_driving_hours: int
    date: datetime
    framework: str
    evaluation_case: EvaluationCase
    run_evaluation: bool


@st.cache_resource
def get_area(area_name: str) -> dict:
    """Get the area from Nominatim.

    Uses the [Nominatim API](https://nominatim.org/release-docs/develop/api/Search/).

    Args:
        area_name (str): The name of the area.

    Returns:
        dict: The area found.
    """
    response = requests.get(
        f"https://nominatim.openstreetmap.org/search?q={area_name}&format=json",
        headers={"User-Agent": "Mozilla/5.0"},
        timeout=5,
    )
    response.raise_for_status()
    response_json = json.loads(response.content.decode())
    return response_json


def get_user_inputs() -> UserInputs:
    default_val = "Los Angeles California, US"

    col1, col2 = st.columns([3, 1])
    with col1:
        location = st.text_input("Enter a location", value=default_val)
    with col2:
        if location:
            location_check = get_area(location)
            if not location_check:
                st.error("❌")
            else:
                st.success("✅")

    max_driving_hours = st.number_input(
        "Enter the maximum driving hours", min_value=1, value=2
    )

    col_date, col_time = st.columns([2, 1])
    with col_date:
        date = st.date_input(
            "Select a date in the future", value=datetime.now() + timedelta(days=1)
        )
    with col_time:
        # default to 9am
        time = st.selectbox(
            "Select a time",
            [datetime.strptime(f"{i:02d}:00", "%H:%M").time() for i in range(24)],
            index=9,
        )
    date = datetime.combine(date, time)

    supported_frameworks = [
        framework for framework in AgentFramework if _is_tracing_supported(framework)
    ]

    framework = st.selectbox(
        "Select the agent framework to use",
        supported_frameworks,
        index=3,
        format_func=lambda x: x.name,
    )

    model_id = st.selectbox(
        "Select the model to use",
        MODEL_OPTIONS,
        index=0,
        format_func=lambda x: "/".join(x.split("/")[-3:]),
    )

    # Add evaluation case section
    with st.expander("Custom Evaluation"):
        evaluation_model_id = st.selectbox(
            "Select the model to use for LLM-as-a-Judge evaluation",
            MODEL_OPTIONS,
            index=2,
            format_func=lambda x: "/".join(x.split("/")[-3:]),
        )
        evaluation_case = DEFAULT_EVALUATION_CASE
        evaluation_case.llm_judge = evaluation_model_id
        # make this an editable json section
        # convert the checkpoints to a df series so that it can be edited
        checkpoints = evaluation_case.checkpoints
        checkpoints_df = pd.DataFrame(
            [checkpoint.model_dump() for checkpoint in checkpoints]
        )
        checkpoints_df = st.data_editor(
            checkpoints_df,
            column_config={
                "points": st.column_config.NumberColumn(label="Points"),
                "criteria": st.column_config.TextColumn(label="Criteria"),
            },
            hide_index=True,
            num_rows="dynamic",
        )
        # for each checkpoint, convert it back to a CheckpointCriteria object
        new_ckpts = []

        # don't let a user add more than 20 checkpoints
        if len(checkpoints_df) > 20:
            st.error(
                "You can only add up to 20 checkpoints for the purpose of this demo."
            )
            checkpoints_df = checkpoints_df[:20]

        for _, row in checkpoints_df.iterrows():
            if row["criteria"] == "":
                continue
            try:
                # Don't let people write essays for criteria in this demo
                if len(row["criteria"].split(" ")) > 100:
                    raise ValueError("Criteria is too long")
                new_crit = CheckpointCriteria(
                    criteria=row["criteria"], points=row["points"]
                )
                new_ckpts.append(new_crit)
            except Exception as e:
                st.error(f"Error creating checkpoint: {e}")
        evaluation_case.checkpoints = new_ckpts

    return UserInputs(
        model_id=model_id,
        location=location,
        max_driving_hours=max_driving_hours,
        date=date,
        framework=framework,
        evaluation_case=evaluation_case,
        run_evaluation=st.checkbox("Run Evaluation", value=True),
    )
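Aside: get_area above is a thin wrapper over the public Nominatim search endpoint. A standalone sketch of the same lookup, shown outside Streamlit for illustration only (the example query string is made up; the display_name/lat/lon fields come from Nominatim's documented JSON response):

    import requests

    # Same endpoint and headers as get_area uses.
    resp = requests.get(
        "https://nominatim.openstreetmap.org/search",
        params={"q": "Santa Cruz, California", "format": "json"},
        headers={"User-Agent": "Mozilla/5.0"},
        timeout=5,
    )
    resp.raise_for_status()
    places = resp.json()  # list of candidate places; an empty list means the location check fails
    if places:
        print(places[0]["display_name"], places[0]["lat"], places[0]["lon"])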
components/sidebar.py
ADDED
@@ -0,0 +1,9 @@
from components.inputs import UserInputs, get_user_inputs
import streamlit as st


def ssf_sidebar() -> UserInputs:
    st.markdown("### Configuration")
    st.markdown("Built using [Any-Agent](https://github.com/mozilla-ai/any-agent)")
    user_inputs = get_user_inputs()
    return user_inputs
constants.py
CHANGED
@@ -1,3 +1,11 @@
+from any_agent.evaluation import EvaluationCase
+from surf_spot_finder.tools import (
+    get_area_lat_lon,
+    get_wave_forecast,
+    get_wind_forecast,
+)
+from any_agent.tools.web_browsing import search_web, visit_webpage
+
 MODEL_OPTIONS = [
     # "huggingface/novita/deepseek-ai/DeepSeek-V3",
     # "huggingface/novita/meta-llama/Llama-3.3-70B-Instruct",
@@ -13,3 +21,47 @@ MODEL_OPTIONS = [
 
 # Hugginface API Provider Error:
 # Must alternate between assistant/user, which meant that the 'tool' role made it puke
+
+
+DEFAULT_EVALUATION_CASE = EvaluationCase(
+    llm_judge=MODEL_OPTIONS[0],
+    checkpoints=[
+        {
+            "criteria": "Check if the agent considered at least three surf spot options",
+            "points": 1,
+        },
+        {
+            "criteria": "Check if the agent gathered wind and wave forecasts for each surf spot being evaluated.",
+            "points": 1,
+        },
+        {
+            "criteria": "Check if the agent used any web search tools to explore which surf spots should be considered",
+            "points": 1,
+        },
+        {
+            "criteria": "Check if the final answer contains any description about the weather (air temp, chance of rain, etc) at the chosen location",
+            "points": 1,
+        },
+        {
+            "criteria": "Check if the final answer includes one of the surf spots evaluated by tools",
+            "points": 1,
+        },
+        {
+            "criteria": "Check if the final answer includes information about some alternative surf spots if the user is not satisfied with the chosen one",
+            "points": 1,
+        },
+        {
+            "criteria": "Check that the agent completed in fewer than 10 calls",
+            "points": 1,
+        },
+    ],
+)
+
+
+DEFAULT_TOOLS = [
+    get_wind_forecast,
+    get_wave_forecast,
+    get_area_lat_lon,
+    search_web,
+    visit_webpage,
+]
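Aside: each default checkpoint is worth one point, so the maximum score the judge can award with these defaults is 7. A small sanity-check sketch (it assumes EvaluationCase validates the checkpoint dicts into objects exposing a points attribute, which is how the rest of this diff treats them):

    from constants import DEFAULT_EVALUATION_CASE, DEFAULT_TOOLS

    # Total available points across the default checkpoints (7) and the number of default tools (5).
    max_points = sum(cp.points for cp in DEFAULT_EVALUATION_CASE.checkpoints)
    print(max_points, len(DEFAULT_TOOLS))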
services/__init__.py
ADDED
File without changes
services/agent.py
ADDED
@@ -0,0 +1,168 @@
import json
from components.inputs import UserInputs
from constants import DEFAULT_TOOLS
import streamlit as st
import time
from surf_spot_finder.config import Config
from any_agent import AgentConfig, AnyAgent, TracingConfig
from any_agent.tracing.trace import AgentTrace, TotalTokenUseAndCost
from any_agent.tracing.otel_types import StatusCode
from any_agent.evaluation import evaluate, TraceEvaluationResult


async def display_evaluation_results(result: TraceEvaluationResult):
    all_results = (
        result.checkpoint_results
        + result.hypothesis_answer_results
        + result.direct_results
    )

    # Create columns for better layout
    col1, col2 = st.columns(2)

    with col1:
        st.markdown("#### Criteria Results")
        for checkpoint in all_results:
            if checkpoint.passed:
                st.success(f"✅ {checkpoint.criteria}")
            else:
                st.error(f"❌ {checkpoint.criteria}")

    with col2:
        st.markdown("#### Overall Score")
        total_points = sum([result.points for result in all_results])
        if total_points == 0:
            msg = "Total points is 0, cannot calculate score."
            raise ValueError(msg)
        passed_points = sum([result.points for result in all_results if result.passed])

        # Create a nice score display
        st.markdown(f"### {passed_points}/{total_points}")
        percentage = (passed_points / total_points) * 100
        st.progress(percentage / 100)
        st.markdown(f"**{percentage:.1f}%**")


@st.cache_resource
async def evaluate_agent(
    config: Config, agent_trace: AgentTrace
) -> TraceEvaluationResult:
    assert (
        len(config.evaluation_cases) == 1
    ), "Only one evaluation case is supported in the demo"
    st.markdown("### 📊 Evaluation Results")

    with st.spinner("Evaluating results..."):
        case = config.evaluation_cases[0]
        result: TraceEvaluationResult = evaluate(
            evaluation_case=case,
            trace=agent_trace,
            agent_framework=config.framework,
        )
    return result


async def configure_agent(user_inputs: UserInputs) -> tuple[AnyAgent, Config]:
    if "huggingface" in user_inputs.model_id:
        model_args = {
            "extra_headers": {"X-HF-Bill-To": "mozilla-ai"},
            "temperature": 0.0,
        }
    else:
        model_args = {}

    agent_config = AgentConfig(
        model_id=user_inputs.model_id,
        model_args=model_args,
        tools=DEFAULT_TOOLS,
    )

    config = Config(
        location=user_inputs.location,
        max_driving_hours=user_inputs.max_driving_hours,
        date=user_inputs.date,
        framework=user_inputs.framework,
        main_agent=agent_config,
        managed_agents=[],
        evaluation_cases=[user_inputs.evaluation_case],
    )

    agent = await AnyAgent.create_async(
        agent_framework=config.framework,
        agent_config=config.main_agent,
        managed_agents=config.managed_agents,
        tracing=TracingConfig(console=True, cost_info=True),
    )
    return agent, config


async def display_output(agent_trace: AgentTrace, execution_time: float):
    cost: TotalTokenUseAndCost = agent_trace.get_total_cost()
    with st.expander("### 📊 Results", expanded=True):
        time_col, cost_col, tokens_col = st.columns(3)
        with time_col:
            st.info(f"⏱️ Execution Time: {execution_time:.2f} seconds")
        with cost_col:
            st.info(f"💰 Estimated Cost: ${cost.total_cost:.6f}")
        with tokens_col:
            st.info(f"📦 Total Tokens: {cost.total_tokens:,}")
        st.markdown("#### Final Output")
        st.info(agent_trace.final_output)

    # Display the agent trace in a more organized way
    with st.expander("### 🧩 Agent Trace"):
        for span in agent_trace.spans:
            # Header with name and status
            col1, col2 = st.columns([4, 1])
            with col1:
                st.markdown(f"**{span.name}**")
                if span.attributes:
                    # st.json(span.attributes, expanded=False)
                    if "input.value" in span.attributes:
                        try:
                            input_value = json.loads(span.attributes["input.value"])
                            if isinstance(input_value, list) and len(input_value) > 0:
                                st.write(f"Input: {input_value[-1]}")
                            else:
                                st.write(f"Input: {input_value}")
                        except Exception:  # noqa: E722
                            st.write(f"Input: {span.attributes['input.value']}")
                    if "output.value" in span.attributes:
                        try:
                            output_value = json.loads(span.attributes["output.value"])
                            if isinstance(output_value, list) and len(output_value) > 0:
                                st.write(f"Output: {output_value[-1]}")
                            else:
                                st.write(f"Output: {output_value}")
                        except Exception:  # noqa: E722
                            st.write(f"Output: {span.attributes['output.value']}")
            with col2:
                status_color = (
                    "green" if span.status.status_code == StatusCode.OK else "red"
                )
                st.markdown(
                    f"<span style='color: {status_color}'>● {span.status.status_code.name}</span>",
                    unsafe_allow_html=True,
                )


@st.cache_resource
async def run_agent(agent, config) -> tuple[AgentTrace, float]:
    st.markdown("#### 🔍 Running Surf Spot Finder with query")

    query = config.input_prompt_template.format(
        LOCATION=config.location,
        MAX_DRIVING_HOURS=config.max_driving_hours,
        DATE=config.date,
    )

    st.code(query, language="text")

    start_time = time.time()
    with st.spinner("🤖 Analyzing surf spots..."):
        agent_trace: AgentTrace = await agent.run_async(query)
        agent.exit()

    end_time = time.time()
    execution_time = end_time - start_time
    return agent_trace, execution_time
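Aside: these service functions are composed by app.py's main(), roughly in the order below. This is a condensed sketch of the flow already shown in the app.py diff above, not additional behavior; user_inputs would come from ssf_sidebar():

    import asyncio
    from services.agent import (
        configure_agent,
        run_agent,
        display_output,
        evaluate_agent,
        display_evaluation_results,
    )

    async def run_once(user_inputs):
        # configure -> run -> display -> evaluate, as wired in app.py
        agent, config = await configure_agent(user_inputs)
        agent_trace, execution_time = await run_agent(agent, config)
        await display_output(agent_trace, execution_time)
        result = await evaluate_agent(config, agent_trace)
        await display_evaluation_results(result)

    # e.g. asyncio.new_event_loop().run_until_complete(run_once(user_inputs))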