File size: 3,288 Bytes
6fdc19a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import streamlit as st
from surf_spot_finder.tools import (
    driving_hours_to_meters,
    get_area_lat_lon,
    get_surfing_spots,
    get_wave_forecast,
    get_wind_forecast,
)
from surf_spot_finder.config import Config
from any_agent import AgentConfig, AnyAgent, TracingConfig
from any_agent.evaluation import evaluate, TraceEvaluationResult


async def run_agent(user_inputs) -> None:
    """Run the surf spot finder agent and render its progress in Streamlit.

    Builds an agent from ``user_inputs``, runs it on the prompt produced by
    ``Config.input_prompt_template``, writes the final output and the trace
    spans to the page and, when the config carries evaluation cases, evaluates
    the trace against each case and writes the per-checkpoint results.

    Args:
        user_inputs: Subscriptable object providing the keys ``"model_id"``,
            ``"location"``, ``"max_driving_hours"``, ``"date"`` and
            ``"framework"``.
    """
    st.write("Running surf spot finder...")

    # Model IDs containing "huggingface" get an extra request header; every
    # other model is created without additional model arguments.
    if "huggingface" in user_inputs["model_id"]:
        model_args = {
            "extra_headers": {"X-HF-Bill-To": "mozilla-ai"},
        }
    else:
        model_args = {}

    agent_config = AgentConfig(
        model_id=user_inputs["model_id"],
        model_args=model_args,
        tools=[
            get_wind_forecast,
            get_wave_forecast,
            get_area_lat_lon,
            get_surfing_spots,
            driving_hours_to_meters,
        ],
    )
    config = Config(
        location=user_inputs["location"],
        max_driving_hours=user_inputs["max_driving_hours"],
        date=user_inputs["date"],
        framework=user_inputs["framework"],
        main_agent=agent_config,
        managed_agents=[],
        evaluation_cases=None,
    )

    agent = await AnyAgent.create_async(
        agent_framework=config.framework,
        agent_config=config.main_agent,
        managed_agents=config.managed_agents,
        tracing=TracingConfig(console=True, cost_info=True),
    )

    query = config.input_prompt_template.format(
        LOCATION=config.location,
        MAX_DRIVING_HOURS=config.max_driving_hours,
        DATE=config.date,
    )
    st.write("Running agent with query:\n", query)

    with st.spinner("Running..."):
        try:
            agent_trace = await agent.run_async(query)
        finally:
            # Always shut the agent down, even when the run raises, so a
            # failed run does not skip the agent's cleanup.
            agent.exit()

    st.write("Final output from agent:\n", agent_trace.final_output)

    # Display the agent trace
    with st.expander("Agent Trace", expanded=True):
        st.write(agent_trace.spans)

    if config.evaluation_cases is not None:
        st.write("Found evaluation cases, running trace evaluation")
        for case in config.evaluation_cases:
            st.write("Evaluating case: ", case)
            result: TraceEvaluationResult = evaluate(
                evaluation_case=case,
                trace=agent_trace,
                agent_framework=config.framework,
            )
            for list_of_checkpoints in (
                result.checkpoint_results,
                result.direct_results,
                result.hypothesis_answer_results,
            ):
                for checkpoint in list_of_checkpoints:
                    # A passed checkpoint earns all of its points, a failed
                    # one earns none.
                    earned = checkpoint.points if checkpoint.passed else 0
                    msg = (
                        f"Checkpoint: {checkpoint.criteria}\n"
                        f"\tPassed: {checkpoint.passed}\n"
                        f"\tReason: {checkpoint.reason}\n"
                        f"\tScore: {'%d/%d' % (earned, checkpoint.points)}"
                    )
                    st.write(msg)
            st.write("==========================")
            # st.write does not apply printf-style formatting to its
            # arguments, so the message must be formatted before the call.
            st.write(f"Overall Score: {int(100 * result.score)}%")
            st.write("==========================")
    st.write("Surf spot finder finished running.")