Seb1101 commited on
Commit
ba23032
·
verified ·
1 Parent(s): 3b1edbb

Create agent.py

Browse files
Files changed (1) hide show
  1. agent.py +311 -0
agent.py ADDED
@@ -0,0 +1,311 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ from datetime import datetime, timedelta
4
+ from typing import TypedDict, Annotated
5
+ import sympy as sp
6
+ from sympy import *
7
+ import math
8
+ from langchain_openai import ChatOpenAI
9
+ from langchain_community.tools.tavily_search import TavilySearchResults
10
+ from langchain_core.messages import HumanMessage, SystemMessage
11
+ from langgraph.graph import StateGraph, MessagesState, START, END
12
+ from langgraph.prebuilt import ToolNode
13
+ from langgraph.checkpoint.memory import MemorySaver
14
+ import json
15
+
16
+ # Load environment variables
17
+ from dotenv import load_dotenv
18
+ load_dotenv()
19
+
20
+ def read_system_prompt():
21
+ """Read the system prompt from file"""
22
+ try:
23
+ with open('system_prompt.txt', 'r') as f:
24
+ return f.read().strip()
25
+ except FileNotFoundError:
26
+ return """You are a helpful assistant tasked with answering questions using a set of tools.
27
+ Now, I will ask you a question. Report your thoughts, and finish your answer with the following template:
28
+ FINAL ANSWER: [YOUR FINAL ANSWER].
29
+ YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise. If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise. If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string.
30
+ Your answer should only start with "FINAL ANSWER: ", then follows with the answer."""
31
+
32
+ def math_calculator(expression: str) -> str:
33
+ """
34
+ Advanced mathematical calculator that can handle complex expressions,
35
+ equations, symbolic math, calculus, and more using SymPy.
36
+ """
37
+ try:
38
+ # Clean the expression
39
+ expression = expression.strip()
40
+
41
+ # Handle common mathematical operations and functions
42
+ expression = expression.replace('^', '**') # Convert ^ to **
43
+ expression = expression.replace('ln', 'log') # Natural log
44
+
45
+ # Try to evaluate as a symbolic expression first
46
+ try:
47
+ result = sp.sympify(expression)
48
+
49
+ # If it's a symbolic expression that can be simplified
50
+ simplified = sp.simplify(result)
51
+
52
+ # Try to get numerical value
53
+ try:
54
+ numerical = float(simplified.evalf())
55
+ return str(numerical)
56
+ except:
57
+ return str(simplified)
58
+
59
+ except:
60
+ # Fall back to basic evaluation
61
+ # Replace common math functions
62
+ safe_expression = expression
63
+ for func in ['sin', 'cos', 'tan', 'sqrt', 'log', 'exp', 'abs']:
64
+ safe_expression = safe_expression.replace(func, f'math.{func}')
65
+
66
+ # Evaluate safely
67
+ result = eval(safe_expression, {"__builtins__": {}}, {
68
+ "math": math,
69
+ "pi": math.pi,
70
+ "e": math.e
71
+ })
72
+ return str(result)
73
+
74
+ except Exception as e:
75
+ return f"Error calculating '{expression}': {str(e)}"
76
+
77
+ def date_time_processor(query: str) -> str:
78
+ """
79
+ Process date and time related queries, calculations, and conversions.
80
+ """
81
+ try:
82
+ current_time = datetime.now()
83
+ query_lower = query.lower()
84
+
85
+ # Current date/time queries
86
+ if 'current' in query_lower or 'today' in query_lower or 'now' in query_lower:
87
+ if 'date' in query_lower:
88
+ return current_time.strftime('%Y-%m-%d')
89
+ elif 'time' in query_lower:
90
+ return current_time.strftime('%H:%M:%S')
91
+ else:
92
+ return current_time.strftime('%Y-%m-%d %H:%M:%S')
93
+
94
+ # Day of week queries
95
+ if 'day of week' in query_lower or 'what day' in query_lower:
96
+ return current_time.strftime('%A')
97
+
98
+ # Year queries
99
+ if 'year' in query_lower and 'current' in query_lower:
100
+ return str(current_time.year)
101
+
102
+ # Month queries
103
+ if 'month' in query_lower and 'current' in query_lower:
104
+ return current_time.strftime('%B')
105
+
106
+ # Date arithmetic (simple cases)
107
+ if 'days ago' in query_lower:
108
+ days_match = re.search(r'(\d+)\s+days?\s+ago', query_lower)
109
+ if days_match:
110
+ days = int(days_match.group(1))
111
+ past_date = current_time - timedelta(days=days)
112
+ return past_date.strftime('%Y-%m-%d')
113
+
114
+ if 'days from now' in query_lower or 'days later' in query_lower:
115
+ days_match = re.search(r'(\d+)\s+days?\s+(?:from now|later)', query_lower)
116
+ if days_match:
117
+ days = int(days_match.group(1))
118
+ future_date = current_time + timedelta(days=days)
119
+ return future_date.strftime('%Y-%m-%d')
120
+
121
+ # If no specific pattern matched, return current datetime
122
+ return f"Current date and time: {current_time.strftime('%Y-%m-%d %H:%M:%S')}"
123
+
124
+ except Exception as e:
125
+ return f"Error processing date/time query: {str(e)}"
126
+
127
+ # Define the agent state
128
+ class AgentState(TypedDict):
129
+ messages: Annotated[list, "The messages in the conversation"]
130
+
131
+ class GAIAAgent:
132
+ def __init__(self):
133
+ # Check for required API keys
134
+ openai_key = os.getenv("OPENAI_API_KEY")
135
+ tavily_key = os.getenv("TAVILY_API_KEY")
136
+
137
+ if not openai_key:
138
+ raise ValueError("OPENAI_API_KEY environment variable is required")
139
+ if not tavily_key:
140
+ raise ValueError("TAVILY_API_KEY environment variable is required")
141
+
142
+ print("✅ API keys found - initializing agent...")
143
+
144
+ # Initialize LLM (using OpenAI GPT-4)
145
+ self.llm = ChatOpenAI(
146
+ model="gpt-4o-mini",
147
+ temperature=0,
148
+ openai_api_key=openai_key
149
+ )
150
+
151
+ # Initialize tools
152
+ self.search_tool = TavilySearchResults(
153
+ max_results=5,
154
+ tavily_api_key=tavily_key
155
+ )
156
+
157
+ # Create tool list
158
+ self.tools = [self.search_tool]
159
+
160
+ # Create LLM with tools
161
+ self.llm_with_tools = self.llm.bind_tools(self.tools)
162
+
163
+ # Build the graph
164
+ self.graph = self._build_graph()
165
+
166
+ self.system_prompt = read_system_prompt()
167
+
168
+ def _build_graph(self):
169
+ """Build the LangGraph workflow"""
170
+
171
+ def agent_node(state: AgentState):
172
+ """Main agent reasoning node"""
173
+ messages = state["messages"]
174
+
175
+ # Add system message if not present
176
+ if not any(isinstance(msg, SystemMessage) for msg in messages):
177
+ system_msg = SystemMessage(content=self.system_prompt)
178
+ messages = [system_msg] + messages
179
+
180
+ # Get the last human message to check if it needs special processing
181
+ last_human_msg = None
182
+ for msg in reversed(messages):
183
+ if isinstance(msg, HumanMessage):
184
+ last_human_msg = msg.content
185
+ break
186
+
187
+ # Check if this is a math problem
188
+ if last_human_msg and self._is_math_problem(last_human_msg):
189
+ math_result = math_calculator(last_human_msg)
190
+ enhanced_msg = f"Math calculation result: {math_result}\n\nOriginal question: {last_human_msg}\n\nProvide your final answer based on this calculation."
191
+ messages[-1] = HumanMessage(content=enhanced_msg)
192
+
193
+ # Check if this is a date/time problem
194
+ elif last_human_msg and self._is_datetime_problem(last_human_msg):
195
+ datetime_result = date_time_processor(last_human_msg)
196
+ enhanced_msg = f"Date/time processing result: {datetime_result}\n\nOriginal question: {last_human_msg}\n\nProvide your final answer based on this information."
197
+ messages[-1] = HumanMessage(content=enhanced_msg)
198
+
199
+ response = self.llm_with_tools.invoke(messages)
200
+ return {"messages": messages + [response]}
201
+
202
+ def tool_node(state: AgentState):
203
+ """Tool execution node"""
204
+ messages = state["messages"]
205
+ last_message = messages[-1]
206
+
207
+ # Execute tool calls
208
+ tool_node_instance = ToolNode(self.tools)
209
+ result = tool_node_instance.invoke(state)
210
+ return result
211
+
212
+ def should_continue(state: AgentState):
213
+ """Decide whether to continue or end"""
214
+ last_message = state["messages"][-1]
215
+
216
+ # If the last message has tool calls, continue to tools
217
+ if hasattr(last_message, 'tool_calls') and last_message.tool_calls:
218
+ return "tools"
219
+
220
+ # If we have a final answer, end
221
+ if hasattr(last_message, 'content') and "FINAL ANSWER:" in last_message.content:
222
+ return "end"
223
+
224
+ # Otherwise continue
225
+ return "end"
226
+
227
+ # Build the graph
228
+ workflow = StateGraph(AgentState)
229
+
230
+ # Add nodes
231
+ workflow.add_node("agent", agent_node)
232
+ workflow.add_node("tools", tool_node)
233
+
234
+ # Add edges
235
+ workflow.add_edge(START, "agent")
236
+ workflow.add_conditional_edges("agent", should_continue, {
237
+ "tools": "tools",
238
+ "end": END
239
+ })
240
+ workflow.add_edge("tools", "agent")
241
+
242
+ # Compile
243
+ memory = MemorySaver()
244
+ return workflow.compile(checkpointer=memory)
245
+
246
+ def _is_math_problem(self, text: str) -> bool:
247
+ """Check if the text contains mathematical expressions"""
248
+ math_indicators = [
249
+ '+', '-', '*', '/', '^', '=', 'calculate', 'compute',
250
+ 'solve', 'equation', 'integral', 'derivative', 'sum',
251
+ 'sqrt', 'log', 'sin', 'cos', 'tan', 'exp'
252
+ ]
253
+ text_lower = text.lower()
254
+ return any(indicator in text_lower for indicator in math_indicators) or \
255
+ re.search(r'\d+[\+\-\*/\^]\d+', text) is not None
256
+
257
+ def _is_datetime_problem(self, text: str) -> bool:
258
+ """Check if the text contains date/time related queries"""
259
+ datetime_indicators = [
260
+ 'date', 'time', 'day', 'month', 'year', 'today', 'yesterday',
261
+ 'tomorrow', 'current', 'now', 'ago', 'later', 'when'
262
+ ]
263
+ text_lower = text.lower()
264
+ return any(indicator in text_lower for indicator in datetime_indicators)
265
+
266
+ def __call__(self, question: str) -> str:
267
+ """Process a question and return the answer"""
268
+ try:
269
+ print(f"Processing question: {question[:100]}...")
270
+
271
+ # Create initial state
272
+ initial_state = {
273
+ "messages": [HumanMessage(content=question)]
274
+ }
275
+
276
+ # Run the graph
277
+ config = {"configurable": {"thread_id": "gaia_thread"}}
278
+ final_state = self.graph.invoke(initial_state, config)
279
+
280
+ # Extract the final answer
281
+ last_message = final_state["messages"][-1]
282
+ response_content = last_message.content if hasattr(last_message, 'content') else str(last_message)
283
+
284
+ # Extract just the final answer part
285
+ final_answer = self._extract_final_answer(response_content)
286
+
287
+ print(f"Final answer: {final_answer}")
288
+ return final_answer
289
+
290
+ except Exception as e:
291
+ print(f"Error processing question: {e}")
292
+ return f"Error: {str(e)}"
293
+
294
+ def _extract_final_answer(self, response: str) -> str:
295
+ """Extract the final answer from the response"""
296
+ if "FINAL ANSWER:" in response:
297
+ # Find the final answer part
298
+ parts = response.split("FINAL ANSWER:")
299
+ if len(parts) > 1:
300
+ answer = parts[-1].strip()
301
+ # Remove any trailing punctuation or explanations
302
+ answer = answer.split('\n')[0].strip()
303
+ return answer
304
+
305
+ # If no FINAL ANSWER format found, return the whole response
306
+ return response.strip()
307
+
308
+ # Create a function to get the agent (for use in app.py)
309
+ def create_agent():
310
+ """Factory function to create the GAIA agent"""
311
+ return GAIAAgent()