# Scraped from a Hugging Face Spaces file view (Space status: Running;
# file size: 3,288 bytes; commit 6fdc19a; 95 lines in the original file).
import streamlit as st
from surf_spot_finder.tools import (
driving_hours_to_meters,
get_area_lat_lon,
get_surfing_spots,
get_wave_forecast,
get_wind_forecast,
)
from surf_spot_finder.config import Config
from any_agent import AgentConfig, AnyAgent, TracingConfig
from any_agent.evaluation import evaluate, TraceEvaluationResult
def _build_config(user_inputs):
    """Build the surf-spot-finder ``Config`` from the values collected in the UI.

    Args:
        user_inputs: Mapping read with the keys ``"model_id"``, ``"location"``,
            ``"max_driving_hours"``, ``"date"`` and ``"framework"``.

    Returns:
        A ``Config`` whose main agent carries the surf-spot tools, with no
        managed agents and ``evaluation_cases=None``.
    """
    model_args = {}
    if "huggingface" in user_inputs["model_id"]:
        # Hugging Face-hosted models get an X-HF-Bill-To header naming
        # mozilla-ai as the account to bill.
        model_args["extra_headers"] = {"X-HF-Bill-To": "mozilla-ai"}

    agent_config = AgentConfig(
        model_id=user_inputs["model_id"],
        model_args=model_args,
        tools=[
            get_wind_forecast,
            get_wave_forecast,
            get_area_lat_lon,
            get_surfing_spots,
            driving_hours_to_meters,
        ],
    )
    return Config(
        location=user_inputs["location"],
        max_driving_hours=user_inputs["max_driving_hours"],
        date=user_inputs["date"],
        framework=user_inputs["framework"],
        main_agent=agent_config,
        managed_agents=[],
        evaluation_cases=None,
    )


def _format_checkpoint(checkpoint):
    """Return the multi-line text displayed for one evaluation checkpoint.

    The score line reads ``points/points`` when the checkpoint passed and
    ``0/points`` otherwise.
    """
    earned = checkpoint.points if checkpoint.passed else 0
    # Format the score on its own: applying ``%`` to the full message would
    # also interpret any ``%`` characters inside the criteria or reason text.
    score = "%d/%d" % (earned, checkpoint.points)
    return (
        f"Checkpoint: {checkpoint.criteria}\n"
        f"\tPassed: {checkpoint.passed}\n"
        f"\tReason: {checkpoint.reason}\n"
        f"\tScore: {score}"
    )


def _format_overall_score(score):
    """Return the overall-score line for an evaluation result's ``score``.

    ``score`` is multiplied by 100 and rounded to a whole percentage; it is
    presumably a fraction in [0, 1] (TODO confirm against
    ``TraceEvaluationResult``). Rounding, rather than ``%d`` truncation,
    keeps e.g. 0.29 (100 * 0.29 == 28.999...) from displaying as 28%.
    """
    return f"Overall Score: {100 * score:.0f}%"


def _display_evaluation(config, agent_trace):
    """Evaluate ``agent_trace`` against each of ``config.evaluation_cases`` and render the results."""
    st.write("Found evaluation cases, running trace evaluation")
    for case in config.evaluation_cases:
        st.write("Evaluating case: ", case)
        result: TraceEvaluationResult = evaluate(
            evaluation_case=case,
            trace=agent_trace,
            agent_framework=config.framework,
        )
        for checkpoints in (
            result.checkpoint_results,
            result.direct_results,
            result.hypothesis_answer_results,
        ):
            for checkpoint in checkpoints:
                st.write(_format_checkpoint(checkpoint))
        st.write("==========================")
        # st.write does not %-format its arguments, so build the string first.
        st.write(_format_overall_score(result.score))
        st.write("==========================")


async def run_agent(user_inputs):
    """Run the surf spot finder agent and render its progress and results in Streamlit.

    Args:
        user_inputs: Mapping with ``"model_id"``, ``"location"``,
            ``"max_driving_hours"``, ``"date"`` and ``"framework"`` entries,
            used to build the agent configuration and the query prompt.
    """
    st.write("Running surf spot finder...")
    config = _build_config(user_inputs)

    agent = await AnyAgent.create_async(
        agent_framework=config.framework,
        agent_config=config.main_agent,
        managed_agents=config.managed_agents,
        tracing=TracingConfig(console=True, cost_info=True),
    )

    query = config.input_prompt_template.format(
        LOCATION=config.location,
        MAX_DRIVING_HOURS=config.max_driving_hours,
        DATE=config.date,
    )
    st.write("Running agent with query:\n", query)

    try:
        with st.spinner("Running..."):
            agent_trace = await agent.run_async(query)
    finally:
        # Always shut the agent down, even if the run raises.
        agent.exit()

    st.write("Final output from agent:\n", agent_trace.final_output)

    with st.expander("Agent Trace", expanded=True):
        st.write(agent_trace.spans)

    # _build_config currently sets evaluation_cases=None, so this branch only
    # runs if that configuration changes.
    if config.evaluation_cases is not None:
        _display_evaluation(config, agent_trace)

    st.write("Surf spot finder finished running.")
|