Seb1101 committed on
Commit 4cc5535 · verified · 1 Parent(s): 580b491

Update agent.py: replace the LangGraph StateGraph/ToolNode workflow with a single-pass flow that enhances the question (math, date/time, or web-search context) and invokes the LLM once.

Files changed (1)
  1. agent.py +82 -138
agent.py CHANGED
@@ -7,10 +7,6 @@ import math
  from langchain_openai import ChatOpenAI
  from langchain_community.tools.tavily_search import TavilySearchResults
  from langchain_core.messages import HumanMessage, SystemMessage
- from langgraph.graph import StateGraph, MessagesState, START, END
- from langgraph.prebuilt import ToolNode
- from langgraph.checkpoint.memory import MemorySaver
- import json

  # Load environment variables
  from dotenv import load_dotenv
@@ -123,9 +119,7 @@ def date_time_processor(query: str) -> str:
      except Exception as e:
          return f"Error processing date/time query: {str(e)}"

- # Define the agent state
- class AgentState(TypedDict):
-     messages: Annotated[list, "The messages in the conversation"]
+ # Removed LangGraph dependencies - using simpler approach

  class GAIAAgent:
      def __init__(self):
@@ -150,138 +144,41 @@ class GAIAAgent:
              openai_api_key=openai_key
          )

-         # Initialize tools only if we have Tavily key
-         self.tools = []
+         # Initialize search tool if available
          if self.has_search:
              self.search_tool = TavilySearchResults(
                  max_results=5,
                  tavily_api_key=tavily_key
              )
-             self.tools = [self.search_tool]
-
-         # Create LLM with tools (only if we have tools)
-         if self.tools:
-             self.llm_with_tools = self.llm.bind_tools(self.tools)
          else:
-             self.llm_with_tools = self.llm
-
-         # Build the graph
-         self.graph = self._build_graph()
+             self.search_tool = None

          self.system_prompt = read_system_prompt()

-     def _build_graph(self):
-         """Build the LangGraph workflow"""
-
-         def agent_node(state: AgentState):
-             """Main agent reasoning node"""
-             messages = state["messages"]
-
-             # Add system message if not present at the beginning
-             if not any(isinstance(msg, SystemMessage) for msg in messages):
-                 system_msg = SystemMessage(content=self.system_prompt)
-                 messages = [system_msg] + messages
-
-             # Get the original question (the first HumanMessage)
-             original_question = None
-             for msg in messages:
-                 if isinstance(msg, HumanMessage):
-                     original_question = msg.content
-                     break
-
-             # Check if this is a fresh question (not after tool calls)
-             last_msg = messages[-1]
-             is_fresh_question = isinstance(last_msg, HumanMessage)
-
-             # Only do special processing for fresh questions
-             if is_fresh_question and original_question:
-                 # Check if this is a math problem
-                 if self._is_math_problem(original_question):
-                     try:
-                         math_result = math_calculator(original_question)
-                         enhanced_msg = f"Question: {original_question}\n\nMath calculation result: {math_result}\n\nBased on this calculation, provide your final answer using the format: FINAL ANSWER: [your answer]"
-                         messages[-1] = HumanMessage(content=enhanced_msg)
-                     except Exception as e:
-                         print(f"Math calculation error: {e}")
-
-                 # Check if this is a date/time problem
-                 elif self._is_datetime_problem(original_question):
-                     try:
-                         datetime_result = date_time_processor(original_question)
-                         enhanced_msg = f"Question: {original_question}\n\nDate/time processing result: {datetime_result}\n\nBased on this information, provide your final answer using the format: FINAL ANSWER: [your answer]"
-                         messages[-1] = HumanMessage(content=enhanced_msg)
-                     except Exception as e:
-                         print(f"DateTime processing error: {e}")
-
-             try:
-                 response = self.llm_with_tools.invoke(messages)
-                 return {"messages": messages + [response]}
-             except Exception as e:
-                 print(f"LLM invocation error: {e}")
-                 # Return a simple response on error
-                 error_response = HumanMessage(content=f"FINAL ANSWER: Error processing question: {str(e)}")
-                 return {"messages": messages + [error_response]}
-
-         def tool_node(state: AgentState):
-             """Tool execution node"""
-             try:
-                 tool_node_instance = ToolNode(self.tools)
-                 result = tool_node_instance.invoke(state)
-                 return result
-             except Exception as e:
-                 print(f"Tool execution error: {e}")
-                 # Add an error message and continue
-                 messages = state["messages"]
-                 error_msg = HumanMessage(content=f"Tool execution failed: {str(e)}. Please provide your best answer without tools.")
-                 return {"messages": messages + [error_msg]}
-
-         def should_continue(state: AgentState):
-             """Decide whether to continue or end"""
-             try:
-                 last_message = state["messages"][-1]
-
-                 # If we don't have tools, just end
-                 if not self.tools:
-                     return "end"
-
-                 # If the last message has tool calls, continue to tools
-                 if hasattr(last_message, 'tool_calls') and last_message.tool_calls:
-                     return "tools"
-
-                 # If we have a final answer, end
-                 if (hasattr(last_message, 'content') and
-                         last_message.content and
-                         "FINAL ANSWER:" in str(last_message.content)):
-                     return "end"
-
-                 # Check if we've had too many iterations (prevent infinite loops)
-                 if len(state["messages"]) > 10:
-                     return "end"
-
-                 # Otherwise end (be conservative)
-                 return "end"
-
-             except Exception as e:
-                 print(f"Should continue error: {e}")
-                 return "end"
-
-         # Build the graph
-         workflow = StateGraph(AgentState)
-
-         # Add nodes
-         workflow.add_node("agent", agent_node)
-         workflow.add_node("tools", tool_node)
-
-         # Add edges
-         workflow.add_edge(START, "agent")
-         workflow.add_conditional_edges("agent", should_continue, {
-             "tools": "tools",
-             "end": END
-         })
-         workflow.add_edge("tools", "agent")
-
-         # Compile without checkpointer to avoid state issues
-         return workflow.compile()
+     def _search_web(self, query: str) -> str:
+         """Perform web search if available"""
+         if not self.search_tool:
+             return "Web search not available (no Tavily API key)"
+
+         try:
+             results = self.search_tool.invoke({"query": query})
+             if results and len(results) > 0:
+                 # Format the results nicely
+                 formatted_results = []
+                 for i, result in enumerate(results[:3], 1):  # Top 3 results
+                     if isinstance(result, dict):
+                         title = result.get('title', 'No title')
+                         content = result.get('content', 'No content')
+                         url = result.get('url', 'No URL')
+                         formatted_results.append(f"{i}. {title}\n {content}\n Source: {url}")
+                     else:
+                         formatted_results.append(f"{i}. {str(result)}")
+
+                 return "\n\n".join(formatted_results)
+             else:
+                 return "No search results found"
+         except Exception as e:
+             return f"Search error: {str(e)}"

      def _is_math_problem(self, text: str) -> bool:
          """Check if the text contains mathematical expressions"""
@@ -315,19 +212,20 @@ class GAIAAgent:
          ]):
              return "Unable to process files or media attachments"

-         # Create initial state
-         initial_state = {
-             "messages": [HumanMessage(content=question)]
-         }
+         # Build the prompt based on question type
+         enhanced_question = self._enhance_question(question)

-         # Run the graph
-         final_state = self.graph.invoke(initial_state)
+         # Create messages
+         messages = [
+             SystemMessage(content=self.system_prompt),
+             HumanMessage(content=enhanced_question)
+         ]

-         # Extract the final answer
-         last_message = final_state["messages"][-1]
-         response_content = last_message.content if hasattr(last_message, 'content') else str(last_message)
+         # Get response from LLM
+         response = self.llm.invoke(messages)
+         response_content = response.content if hasattr(response, 'content') else str(response)

-         # Extract just the final answer part
+         # Extract the final answer
          final_answer = self._extract_final_answer(response_content)

          print(f"Final answer: {final_answer}")
@@ -343,6 +241,52 @@ class GAIAAgent:
          else:
              return f"Unable to process question due to technical error"

+     def _enhance_question(self, question: str) -> str:
+         """Enhance the question with relevant context and tools"""
+         try:
+             # Check if this is a math problem
+             if self._is_math_problem(question):
+                 try:
+                     math_result = math_calculator(question)
+                     return f"Question: {question}\n\nMath calculation result: {math_result}\n\nBased on this calculation, provide your final answer using the format: FINAL ANSWER: [your answer]"
+                 except Exception as e:
+                     print(f"Math calculation error: {e}")
+
+             # Check if this is a date/time problem
+             elif self._is_datetime_problem(question):
+                 try:
+                     datetime_result = date_time_processor(question)
+                     return f"Question: {question}\n\nDate/time processing result: {datetime_result}\n\nBased on this information, provide your final answer using the format: FINAL ANSWER: [your answer]"
+                 except Exception as e:
+                     print(f"DateTime processing error: {e}")
+
+             # Check if this needs web search
+             elif self._needs_web_search(question):
+                 try:
+                     search_result = self._search_web(question)
+                     return f"Question: {question}\n\nWeb search results:\n{search_result}\n\nBased on this information, provide your final answer using the format: FINAL ANSWER: [your answer]"
+                 except Exception as e:
+                     print(f"Web search error: {e}")
+
+             # For other questions, just add the format instruction
+             return f"Question: {question}\n\nProvide your final answer using the format: FINAL ANSWER: [your answer]"
+
+         except Exception as e:
+             print(f"Question enhancement error: {e}")
+             return f"Question: {question}\n\nProvide your final answer using the format: FINAL ANSWER: [your answer]"
+
+     def _needs_web_search(self, text: str) -> bool:
+         """Check if the question likely needs web search"""
+         search_indicators = [
+             'who', 'what', 'when', 'where', 'which', 'published', 'article',
+             'wikipedia', 'latest', 'recent', 'current', 'news', 'website',
+             'url', 'http', 'www', 'competition', 'olympics', 'award',
+             'winner', 'recipient', 'author', 'published in', 'paper',
+             'study', 'research', 'species', 'city', 'country'
+         ]
+         text_lower = text.lower()
+         return any(indicator in text_lower for indicator in search_indicators)
+
      def _extract_final_answer(self, response: str) -> str:
          """Extract the final answer from the response"""
          if "FINAL ANSWER:" in response: