Seb1101 committed on
Commit
580b491
·
verified ·
1 Parent(s): d5ed1fc

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +106 -54
agent.py CHANGED
@@ -3,7 +3,6 @@ import re
3
  from datetime import datetime, timedelta
4
  from typing import TypedDict, Annotated
5
  import sympy as sp
6
- from sympy import *
7
  import math
8
  from langchain_openai import ChatOpenAI
9
  from langchain_community.tools.tavily_search import TavilySearchResults
@@ -137,9 +136,12 @@ class GAIAAgent:
137
  if not openai_key:
138
  raise ValueError("OPENAI_API_KEY environment variable is required")
139
  if not tavily_key:
140
- raise ValueError("TAVILY_API_KEY environment variable is required")
 
 
 
141
 
142
- print("✅ API keys found - initializing agent...")
143
 
144
  # Initialize LLM (using OpenAI GPT-4)
145
  self.llm = ChatOpenAI(
@@ -148,17 +150,20 @@ class GAIAAgent:
148
  openai_api_key=openai_key
149
  )
150
 
151
- # Initialize tools
152
- self.search_tool = TavilySearchResults(
153
- max_results=5,
154
- tavily_api_key=tavily_key
155
- )
156
-
157
- # Create tool list
158
- self.tools = [self.search_tool]
159
 
160
- # Create LLM with tools
161
- self.llm_with_tools = self.llm.bind_tools(self.tools)
 
 
 
162
 
163
  # Build the graph
164
  self.graph = self._build_graph()
@@ -172,57 +177,93 @@ class GAIAAgent:
172
  """Main agent reasoning node"""
173
  messages = state["messages"]
174
 
175
- # Add system message if not present
176
  if not any(isinstance(msg, SystemMessage) for msg in messages):
177
  system_msg = SystemMessage(content=self.system_prompt)
178
  messages = [system_msg] + messages
179
 
180
- # Get the last human message to check if it needs special processing
181
- last_human_msg = None
182
- for msg in reversed(messages):
183
  if isinstance(msg, HumanMessage):
184
- last_human_msg = msg.content
185
  break
186
 
187
- # Check if this is a math problem
188
- if last_human_msg and self._is_math_problem(last_human_msg):
189
- math_result = math_calculator(last_human_msg)
190
- enhanced_msg = f"Math calculation result: {math_result}\n\nOriginal question: {last_human_msg}\n\nProvide your final answer based on this calculation."
191
- messages[-1] = HumanMessage(content=enhanced_msg)
192
 
193
- # Check if this is a date/time problem
194
- elif last_human_msg and self._is_datetime_problem(last_human_msg):
195
- datetime_result = date_time_processor(last_human_msg)
196
- enhanced_msg = f"Date/time processing result: {datetime_result}\n\nOriginal question: {last_human_msg}\n\nProvide your final answer based on this information."
197
- messages[-1] = HumanMessage(content=enhanced_msg)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
198
 
199
- response = self.llm_with_tools.invoke(messages)
200
- return {"messages": messages + [response]}
 
 
 
 
 
 
201
 
202
  def tool_node(state: AgentState):
203
  """Tool execution node"""
204
- messages = state["messages"]
205
- last_message = messages[-1]
206
-
207
- # Execute tool calls
208
- tool_node_instance = ToolNode(self.tools)
209
- result = tool_node_instance.invoke(state)
210
- return result
 
 
 
211
 
212
  def should_continue(state: AgentState):
213
  """Decide whether to continue or end"""
214
- last_message = state["messages"][-1]
215
-
216
- # If the last message has tool calls, continue to tools
217
- if hasattr(last_message, 'tool_calls') and last_message.tool_calls:
218
- return "tools"
219
-
220
- # If we have a final answer, end
221
- if hasattr(last_message, 'content') and "FINAL ANSWER:" in last_message.content:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
222
  return "end"
223
-
224
- # Otherwise continue
225
- return "end"
226
 
227
  # Build the graph
228
  workflow = StateGraph(AgentState)
@@ -239,9 +280,8 @@ class GAIAAgent:
239
  })
240
  workflow.add_edge("tools", "agent")
241
 
242
- # Compile
243
- memory = MemorySaver()
244
- return workflow.compile(checkpointer=memory)
245
 
246
  def _is_math_problem(self, text: str) -> bool:
247
  """Check if the text contains mathematical expressions"""
@@ -268,14 +308,20 @@ class GAIAAgent:
268
  try:
269
  print(f"Processing question: {question[:100]}...")
270
 
 
 
 
 
 
 
 
271
  # Create initial state
272
  initial_state = {
273
  "messages": [HumanMessage(content=question)]
274
  }
275
 
276
  # Run the graph
277
- config = {"configurable": {"thread_id": "gaia_thread"}}
278
- final_state = self.graph.invoke(initial_state, config)
279
 
280
  # Extract the final answer
281
  last_message = final_state["messages"][-1]
@@ -289,7 +335,13 @@ class GAIAAgent:
289
 
290
  except Exception as e:
291
  print(f"Error processing question: {e}")
292
- return f"Error: {str(e)}"
 
 
 
 
 
 
293
 
294
  def _extract_final_answer(self, response: str) -> str:
295
  """Extract the final answer from the response"""
 
3
  from datetime import datetime, timedelta
4
  from typing import TypedDict, Annotated
5
  import sympy as sp
 
6
  import math
7
  from langchain_openai import ChatOpenAI
8
  from langchain_community.tools.tavily_search import TavilySearchResults
 
136
  if not openai_key:
137
  raise ValueError("OPENAI_API_KEY environment variable is required")
138
  if not tavily_key:
139
+ print("⚠️ TAVILY_API_KEY not found - web search will be disabled")
140
+ self.has_search = False
141
+ else:
142
+ self.has_search = True
143
 
144
+ print("✅ Initializing GAIA agent...")
145
 
146
  # Initialize LLM (using OpenAI GPT-4)
147
  self.llm = ChatOpenAI(
 
150
  openai_api_key=openai_key
151
  )
152
 
153
+ # Initialize tools only if we have Tavily key
154
+ self.tools = []
155
+ if self.has_search:
156
+ self.search_tool = TavilySearchResults(
157
+ max_results=5,
158
+ tavily_api_key=tavily_key
159
+ )
160
+ self.tools = [self.search_tool]
161
 
162
+ # Create LLM with tools (only if we have tools)
163
+ if self.tools:
164
+ self.llm_with_tools = self.llm.bind_tools(self.tools)
165
+ else:
166
+ self.llm_with_tools = self.llm
167
 
168
  # Build the graph
169
  self.graph = self._build_graph()
 
177
  """Main agent reasoning node"""
178
  messages = state["messages"]
179
 
180
+ # Add system message if not present at the beginning
181
  if not any(isinstance(msg, SystemMessage) for msg in messages):
182
  system_msg = SystemMessage(content=self.system_prompt)
183
  messages = [system_msg] + messages
184
 
185
+ # Get the original question (the first HumanMessage)
186
+ original_question = None
187
+ for msg in messages:
188
  if isinstance(msg, HumanMessage):
189
+ original_question = msg.content
190
  break
191
 
192
+ # Check if this is a fresh question (not after tool calls)
193
+ last_msg = messages[-1]
194
+ is_fresh_question = isinstance(last_msg, HumanMessage)
 
 
195
 
196
+ # Only do special processing for fresh questions
197
+ if is_fresh_question and original_question:
198
+ # Check if this is a math problem
199
+ if self._is_math_problem(original_question):
200
+ try:
201
+ math_result = math_calculator(original_question)
202
+ enhanced_msg = f"Question: {original_question}\n\nMath calculation result: {math_result}\n\nBased on this calculation, provide your final answer using the format: FINAL ANSWER: [your answer]"
203
+ messages[-1] = HumanMessage(content=enhanced_msg)
204
+ except Exception as e:
205
+ print(f"Math calculation error: {e}")
206
+
207
+ # Check if this is a date/time problem
208
+ elif self._is_datetime_problem(original_question):
209
+ try:
210
+ datetime_result = date_time_processor(original_question)
211
+ enhanced_msg = f"Question: {original_question}\n\nDate/time processing result: {datetime_result}\n\nBased on this information, provide your final answer using the format: FINAL ANSWER: [your answer]"
212
+ messages[-1] = HumanMessage(content=enhanced_msg)
213
+ except Exception as e:
214
+ print(f"DateTime processing error: {e}")
215
 
216
+ try:
217
+ response = self.llm_with_tools.invoke(messages)
218
+ return {"messages": messages + [response]}
219
+ except Exception as e:
220
+ print(f"LLM invocation error: {e}")
221
+ # Return a simple response on error
222
+ error_response = HumanMessage(content=f"FINAL ANSWER: Error processing question: {str(e)}")
223
+ return {"messages": messages + [error_response]}
224
 
225
  def tool_node(state: AgentState):
226
  """Tool execution node"""
227
+ try:
228
+ tool_node_instance = ToolNode(self.tools)
229
+ result = tool_node_instance.invoke(state)
230
+ return result
231
+ except Exception as e:
232
+ print(f"Tool execution error: {e}")
233
+ # Add an error message and continue
234
+ messages = state["messages"]
235
+ error_msg = HumanMessage(content=f"Tool execution failed: {str(e)}. Please provide your best answer without tools.")
236
+ return {"messages": messages + [error_msg]}
237
 
238
  def should_continue(state: AgentState):
239
  """Decide whether to continue or end"""
240
+ try:
241
+ last_message = state["messages"][-1]
242
+
243
+ # If we don't have tools, just end
244
+ if not self.tools:
245
+ return "end"
246
+
247
+ # If the last message has tool calls, continue to tools
248
+ if hasattr(last_message, 'tool_calls') and last_message.tool_calls:
249
+ return "tools"
250
+
251
+ # If we have a final answer, end
252
+ if (hasattr(last_message, 'content') and
253
+ last_message.content and
254
+ "FINAL ANSWER:" in str(last_message.content)):
255
+ return "end"
256
+
257
+ # Check if we've had too many iterations (prevent infinite loops)
258
+ if len(state["messages"]) > 10:
259
+ return "end"
260
+
261
+ # Otherwise end (be conservative)
262
+ return "end"
263
+
264
+ except Exception as e:
265
+ print(f"Should continue error: {e}")
266
  return "end"
 
 
 
267
 
268
  # Build the graph
269
  workflow = StateGraph(AgentState)
 
280
  })
281
  workflow.add_edge("tools", "agent")
282
 
283
+ # Compile without checkpointer to avoid state issues
284
+ return workflow.compile()
 
285
 
286
  def _is_math_problem(self, text: str) -> bool:
287
  """Check if the text contains mathematical expressions"""
 
308
  try:
309
  print(f"Processing question: {question[:100]}...")
310
 
311
+ # Check for file/media requirements that we can't handle
312
+ if any(indicator in question.lower() for indicator in [
313
+ 'attached', 'audio', 'video', 'image', 'file', 'mp3', 'pdf',
314
+ 'excel', 'spreadsheet', 'listen to', 'watch', 'download'
315
+ ]):
316
+ return "Unable to process files or media attachments"
317
+
318
  # Create initial state
319
  initial_state = {
320
  "messages": [HumanMessage(content=question)]
321
  }
322
 
323
  # Run the graph
324
+ final_state = self.graph.invoke(initial_state)
 
325
 
326
  # Extract the final answer
327
  last_message = final_state["messages"][-1]
 
335
 
336
  except Exception as e:
337
  print(f"Error processing question: {e}")
338
+ # Try to provide a meaningful fallback
339
+ if "api" in str(e).lower() or "key" in str(e).lower():
340
+ return "Error: API key configuration issue"
341
+ elif "tool" in str(e).lower():
342
+ return "Error: Tool execution issue"
343
+ else:
344
+ return f"Unable to process question due to technical error"
345
 
346
  def _extract_final_answer(self, response: str) -> str:
347
  """Extract the final answer from the response"""