github-actions[bot] commited on
Commit
89845e5
Β·
1 Parent(s): a80cc10

Sync with https://github.com/mozilla-ai/surf-spot-finder

Browse files
Files changed (2) hide show
  1. components/inputs.py +5 -10
  2. services/agent.py +24 -2
components/inputs.py CHANGED
@@ -48,16 +48,11 @@ def get_area(area_name: str) -> dict:
48
  def get_user_inputs() -> UserInputs:
49
  default_val = "Los Angeles California, US"
50
 
51
- col1, col2 = st.columns([3, 1])
52
- with col1:
53
- location = st.text_input("Enter a location", value=default_val)
54
- with col2:
55
- if location:
56
- location_check = get_area(location)
57
- if not location_check:
58
- st.error("❌")
59
- else:
60
- st.success("βœ…")
61
 
62
  max_driving_hours = st.number_input(
63
  "Enter the maximum driving hours", min_value=1, value=2
 
48
  def get_user_inputs() -> UserInputs:
49
  default_val = "Los Angeles California, US"
50
 
51
+ location = st.text_input("Enter a location", value=default_val)
52
+ if location:
53
+ location_check = get_area(location)
54
+ if not location_check:
55
+ st.error("❌ Invalid location")
 
 
 
 
 
56
 
57
  max_driving_hours = st.number_input(
58
  "Enter the maximum driving hours", min_value=1, value=2
services/agent.py CHANGED
@@ -4,7 +4,7 @@ from constants import DEFAULT_TOOLS
4
  import streamlit as st
5
  import time
6
  from surf_spot_finder.config import Config
7
- from any_agent import AgentConfig, AnyAgent, TracingConfig
8
  from any_agent.tracing.trace import AgentTrace, TotalTokenUseAndCost
9
  from any_agent.tracing.otel_types import StatusCode
10
  from any_agent.evaluation import evaluate, TraceEvaluationResult
@@ -70,9 +70,15 @@ async def configure_agent(user_inputs: UserInputs) -> tuple[AnyAgent, Config]:
70
  else:
71
  model_args = {}
72
 
 
 
 
 
 
73
  agent_config = AgentConfig(
74
  model_id=user_inputs.model_id,
75
  model_args=model_args,
 
76
  tools=DEFAULT_TOOLS,
77
  )
78
 
@@ -155,10 +161,26 @@ async def run_agent(agent, config) -> tuple[AgentTrace, float]:
155
  )
156
 
157
  st.code(query, language="text")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
158
 
159
  start_time = time.time()
160
  with st.spinner("πŸ€” Analyzing surf spots..."):
161
- agent_trace: AgentTrace = await agent.run_async(query)
162
  agent.exit()
163
 
164
  end_time = time.time()
 
4
  import streamlit as st
5
  import time
6
  from surf_spot_finder.config import Config
7
+ from any_agent import AgentConfig, AnyAgent, TracingConfig, AgentFramework
8
  from any_agent.tracing.trace import AgentTrace, TotalTokenUseAndCost
9
  from any_agent.tracing.otel_types import StatusCode
10
  from any_agent.evaluation import evaluate, TraceEvaluationResult
 
70
  else:
71
  model_args = {}
72
 
73
+ if user_inputs.framework == AgentFramework.AGNO:
74
+ agent_args = {"tool_call_limit": 20}
75
+ else:
76
+ agent_args = {}
77
+
78
  agent_config = AgentConfig(
79
  model_id=user_inputs.model_id,
80
  model_args=model_args,
81
+ agent_args=agent_args,
82
  tools=DEFAULT_TOOLS,
83
  )
84
 
 
161
  )
162
 
163
  st.code(query, language="text")
164
+ kwargs = {}
165
+ if (
166
+ config.framework == AgentFramework.OPENAI
167
+ or config.framework == AgentFramework.TINYAGENT
168
+ ):
169
+ kwargs["max_turns"] = 20
170
+ elif config.framework == AgentFramework.SMOLAGENTS:
171
+ kwargs["max_steps"] = 20
172
+ if config.framework == AgentFramework.LANGCHAIN:
173
+ from langchain_core.runnables import RunnableConfig
174
+
175
+ kwargs["config"] = RunnableConfig(recursion_limit=20)
176
+ elif config.framework == AgentFramework.GOOGLE:
177
+ from google.adk.agents.run_config import RunConfig
178
+
179
+ kwargs["run_config"] = RunConfig(max_llm_calls=20)
180
 
181
  start_time = time.time()
182
  with st.spinner("πŸ€” Analyzing surf spots..."):
183
+ agent_trace: AgentTrace = await agent.run_async(query, **kwargs)
184
  agent.exit()
185
 
186
  end_time = time.time()