LamiaYT committed on
Commit
695f802
·
1 Parent(s): 72ec5e1
Files changed (3) hide show
  1. agent.py +267 -88
  2. app.py +281 -214
  3. requirements.txt +13 -4
agent.py CHANGED
@@ -8,15 +8,12 @@ load_dotenv()
8
  os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
9
 
10
  # Load keys from environment
11
- groq_api_key = os.getenv("GROQ_API_KEY")
12
- serper_api_key = os.getenv("SERPER_API_KEY")
13
  hf_token = os.getenv("HUGGINGFACE_INFERENCE_TOKEN")
 
14
 
15
  # ---- Imports ----
16
  from langgraph.graph import START, StateGraph, MessagesState
17
  from langgraph.prebuilt import tools_condition, ToolNode
18
- from langchain_google_genai import ChatGoogleGenerativeAI
19
- from langchain_groq import ChatGroq
20
  from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings
21
  from langchain_community.tools.tavily_search import TavilySearchResults
22
  from langchain_community.document_loaders import WikipediaLoader, ArxivLoader
@@ -29,152 +26,334 @@ from langchain.vectorstores import Chroma
29
  from langchain.embeddings import HuggingFaceEmbeddings
30
  from langchain.schema import Document
31
  import json
 
 
 
 
 
32
 
33
- # ---- Tools ----
34
 
35
  @tool
36
- def multiply(a: int, b: int) -> int:
 
37
  return a * b
38
 
39
  @tool
40
- def add(a: int, b: int) -> int:
 
41
  return a + b
42
 
43
  @tool
44
- def subtract(a: int, b: int) -> int:
 
45
  return a - b
46
 
47
  @tool
48
- def divide(a: int, b: int) -> float:
 
49
  if b == 0:
50
  raise ValueError("Cannot divide by zero.")
51
  return a / b
52
 
53
  @tool
54
  def modulus(a: int, b: int) -> int:
 
55
  return a % b
56
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
  @tool
58
  def wiki_search(query: str) -> str:
59
- search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
60
- formatted = "\n\n---\n\n".join(
61
- [
62
- f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
 
 
 
 
63
  for doc in search_docs
64
- ]
65
- )
66
- return {"wiki_results": formatted}
 
67
 
68
  @tool
69
  def web_search(query: str) -> str:
70
- search_docs = TavilySearchResults(max_results=3).invoke(query=query)
71
- formatted = "\n\n---\n\n".join(
72
- [
73
- f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
 
 
 
 
74
  for doc in search_docs
75
- ]
76
- )
77
- return {"web_results": formatted}
 
78
 
79
  @tool
80
- def arvix_search(query: str) -> str:
81
- search_docs = ArxivLoader(query=query, load_max_docs=3).load()
82
- formatted = "\n\n---\n\n".join(
83
- [
84
- f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content[:1000]}\n</Document>'
 
 
 
 
85
  for doc in search_docs
86
- ]
87
- )
88
- return {"arvix_results": formatted}
 
89
 
90
- # ---- Embedding & Vector Store Setup ----
91
-
92
- embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
93
-
94
- json_QA = []
95
- with open('metadata.jsonl', 'r') as jsonl_file:
96
- for line in jsonl_file:
97
- json_QA.append(json.loads(line))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98
 
99
- documents = [
100
- Document(
101
- page_content=f"Question : {sample['Question']}\n\nFinal answer : {sample['Final answer']}",
102
- metadata={"source": sample["task_id"]}
103
- )
104
- for sample in json_QA
105
- ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
106
 
107
- vector_store = Chroma.from_documents(
108
- documents=documents,
109
- embedding=embeddings,
110
- persist_directory="./chroma_db",
111
- collection_name="my_collection"
112
- )
113
- vector_store.persist()
114
- print("Documents inserted:", vector_store._collection.count())
115
 
116
  @tool
117
  def similar_question_search(query: str) -> str:
118
- matched_docs = vector_store.similarity_search(query, 3)
119
- formatted = "\n\n---\n\n".join(
120
- [
121
- f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content[:1000]}\n</Document>'
 
 
 
 
 
 
 
122
  for doc in matched_docs
123
- ]
124
- )
125
- return {"similar_questions": formatted}
126
-
127
- # ---- System Prompt ----
128
 
 
129
  system_prompt = """
130
- You are a helpful assistant tasked with answering questions using a set of tools.
131
- Now, I will ask you a question. Report your thoughts, and finish your answer with the following template:
132
- FINAL ANSWER: [YOUR FINAL ANSWER].
133
- YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings...
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
134
  """
135
 
136
  sys_msg = SystemMessage(content=system_prompt)
137
 
138
- # ---- Tool List ----
139
-
140
  tools = [
141
- multiply, add, subtract, divide, modulus,
142
- wiki_search, web_search, arvix_search, similar_question_search
 
 
 
143
  ]
144
 
145
  # ---- Graph Definition ----
146
-
147
- def build_graph(provider: str = "groq"):
148
- if provider == "groq":
149
- llm = ChatGroq(model="qwen-qwq-32b", temperature=0, api_key=groq_api_key)
150
- elif provider == "google":
151
- llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
152
- elif provider == "huggingface":
153
- llm = ChatHuggingFace(
154
- llm=HuggingFaceEndpoint(repo_id="mosaicml/mpt-30b", temperature=0)
 
 
 
 
155
  )
 
156
  else:
157
- raise ValueError("Invalid provider: choose 'groq', 'google', or 'huggingface'.")
158
 
159
  llm_with_tools = llm.bind_tools(tools)
160
 
161
  def assistant(state: MessagesState):
162
- return {"messages": [llm_with_tools.invoke(state["messages"])]}
 
 
 
 
 
 
 
 
 
163
 
164
  def retriever(state: MessagesState):
165
- similar = vector_store.similarity_search(state["messages"][0].content)
166
- if similar:
167
- example_msg = HumanMessage(content=f"Here is a similar question:\n\n{similar[0].page_content}")
168
- return {"messages": [sys_msg] + state["messages"] + [example_msg]}
169
- return {"messages": [sys_msg] + state["messages"]}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
170
 
 
171
  builder = StateGraph(MessagesState)
172
  builder.add_node("retriever", retriever)
173
  builder.add_node("assistant", assistant)
174
  builder.add_node("tools", ToolNode(tools))
 
 
175
  builder.add_edge(START, "retriever")
176
  builder.add_edge("retriever", "assistant")
177
  builder.add_conditional_edges("assistant", tools_condition)
178
  builder.add_edge("tools", "assistant")
179
 
180
- return builder.compile()
 
8
  os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
9
 
10
  # Load keys from environment
 
 
11
  hf_token = os.getenv("HUGGINGFACE_INFERENCE_TOKEN")
12
+ serper_api_key = os.getenv("SERPER_API_KEY")
13
 
14
  # ---- Imports ----
15
  from langgraph.graph import START, StateGraph, MessagesState
16
  from langgraph.prebuilt import tools_condition, ToolNode
 
 
17
  from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings
18
  from langchain_community.tools.tavily_search import TavilySearchResults
19
  from langchain_community.document_loaders import WikipediaLoader, ArxivLoader
 
26
  from langchain.embeddings import HuggingFaceEmbeddings
27
  from langchain.schema import Document
28
  import json
29
+ import requests
30
+ from typing import List, Dict, Any
31
+ import re
32
+ import math
33
+ from datetime import datetime
34
 
35
+ # ---- Enhanced Tools ----
36
 
37
  @tool
38
def multiply(a: float, b: float) -> float:
    """Multiply two numbers and return their product."""
    product = a * b
    return product
41
 
42
  @tool
43
def add(a: float, b: float) -> float:
    """Add two numbers and return their sum."""
    total = a + b
    return total
46
 
47
  @tool
48
def subtract(a: float, b: float) -> float:
    """Subtract the second number from the first and return the difference."""
    difference = a - b
    return difference
51
 
52
  @tool
53
def divide(a: float, b: float) -> float:
    """Divide the first number by the second; a zero divisor raises ValueError."""
    # Happy path first; the guard for a zero divisor follows.
    if b != 0:
        return a / b
    raise ValueError("Cannot divide by zero.")
58
 
59
  @tool
60
def modulus(a: int, b: int) -> int:
    """Calculate the remainder of the first integer divided by the second."""
    remainder = a % b
    return remainder
63
 
64
+ @tool
65
def power(a: float, b: float) -> float:
    """Raise the first number to the power of the second."""
    # Two-argument pow() is the functional spelling of the ** operator.
    return pow(a, b)
68
+
69
+ @tool
70
def square_root(a: float) -> float:
    """Calculate the square root of a non-negative number.

    Args:
        a: The number to take the square root of.

    Returns:
        The non-negative square root of ``a``.

    Raises:
        ValueError: If ``a`` is negative.
    """
    # math.sqrt already raises ValueError for negatives, but with the opaque
    # message "math domain error"; spell the problem out for the caller.
    if a < 0:
        raise ValueError("Cannot take the square root of a negative number.")
    return math.sqrt(a)
73
+
74
+ @tool
75
def factorial(n: int) -> int:
    """Calculate the factorial of a non-negative integer.

    Args:
        n: The integer whose factorial is wanted.

    Returns:
        ``n!`` (``1`` for ``n`` equal to 0 or 1).

    Raises:
        ValueError: If ``n`` is negative.
    """
    # Keep the explicit guard so the error message stays the same as before.
    if n < 0:
        raise ValueError("Factorial is not defined for negative numbers")
    # Standard-library implementation instead of a hand-rolled loop.
    return math.factorial(n)
85
+
86
+ @tool
87
def gcd(a: int, b: int) -> int:
    """Calculate the greatest common divisor of two integers.

    Args:
        a: First integer.
        b: Second integer.

    Returns:
        The non-negative greatest common divisor; ``gcd(0, 0)`` is ``0``.
    """
    # math.gcd always returns a non-negative value, whereas a bare Euclid loop
    # built on Python's % can return a negative result for negative inputs.
    return math.gcd(a, b)
92
+
93
+ @tool
94
def lcm(a: int, b: int) -> int:
    """Calculate the least common multiple of two integers.

    Args:
        a: First integer.
        b: Second integer.

    Returns:
        The non-negative least common multiple; ``0`` if either input is ``0``.
    """
    # The LCM involving zero is zero; returning early also avoids dividing by
    # gcd(0, 0) == 0.
    if a == 0 or b == 0:
        return 0
    # Use math.gcd directly rather than the module-level ``gcd``, which is
    # wrapped by the @tool decorator and is no longer a plain two-argument
    # function.
    return abs(a * b) // math.gcd(a, b)
97
+
98
+ @tool
99
def percentage(part: float, whole: float) -> float:
    """Calculate what percentage ``part`` is of ``whole``."""
    ratio = part / whole
    return ratio * 100
102
+
103
+ @tool
104
def compound_interest(principal: float, rate: float, time: float, n: int = 1) -> float:
    """Calculate the final amount after compound interest (principal plus interest).

    Note that the return value is the accumulated balance, not the interest
    earned; subtract ``principal`` to get the interest alone.

    Args:
        principal: Starting amount.
        rate: Interest rate per time unit as a decimal (0.05 for 5%).
        time: Number of time units the money is invested for.
        n: Number of compounding periods per time unit; must be positive.

    Returns:
        ``principal * (1 + rate / n) ** (n * time)``.

    Raises:
        ValueError: If ``n`` is not positive.
    """
    if n <= 0:
        raise ValueError("Compounding periods per time unit (n) must be positive.")
    return principal * (1 + rate / n) ** (n * time)
107
+
108
  @tool
109
def wiki_search(query: str) -> str:
    """Search Wikipedia and return the matching pages as tagged text."""
    try:
        pages = WikipediaLoader(query=query, load_max_docs=3).load()
        if not pages:
            return "No Wikipedia results found."

        # Render each page as a <Document> snippet, capped at 2000 characters.
        chunks = []
        for page in pages:
            source = page.metadata.get("source", "Wikipedia")
            title = page.metadata.get("title", "Unknown")
            excerpt = page.page_content[:2000]
            chunks.append(
                f'<Document source="{source}" title="{title}"/>\n{excerpt}\n</Document>'
            )
        return "\n\n---\n\n".join(chunks)
    except Exception as exc:
        # Best effort: report the failure as text so the agent can carry on.
        return f"Wikipedia search error: {str(exc)}"
123
 
124
  @tool
125
def web_search(query: str) -> str:
    """Search the web using Tavily.

    Args:
        query: The search query.

    Returns:
        Up to three results rendered as ``<Document>`` snippets separated by
        ``---``, or a short message when nothing was found or the search failed.
    """
    try:
        # ``invoke`` takes the tool input as its first positional argument;
        # calling ``invoke(query=query)`` raises TypeError before any search runs.
        search_docs = TavilySearchResults(max_results=3).invoke(query)
        if not search_docs:
            return "No web search results found."

        # NOTE(review): the Tavily tool appears to hand back an error string
        # instead of a result list when the API call fails - confirm against
        # the installed langchain_community version.
        if isinstance(search_docs, str):
            return f"Web search error: {search_docs}"

        formatted = "\n\n---\n\n".join([
            f'<Document source="{doc.get("url", "Unknown")}" title="{doc.get("title", "Unknown")}"/>\n{doc.get("content", "")[:2000]}\n</Document>'
            for doc in search_docs
        ])
        return formatted
    except Exception as e:
        # Best effort: report the failure as text so the agent can carry on.
        return f"Web search error: {str(e)}"
139
 
140
  @tool
141
def arxiv_search(query: str) -> str:
    """Search ArXiv for academic papers and return them as tagged text."""
    try:
        papers = ArxivLoader(query=query, load_max_docs=2).load()
        if not papers:
            return "No ArXiv results found."

        # Render each paper as a <Document> snippet, capped at 1500 characters.
        chunks = []
        for paper in papers:
            source = paper.metadata.get("source", "ArXiv")
            title = paper.metadata.get("Title", "Unknown")
            excerpt = paper.page_content[:1500]
            chunks.append(
                f'<Document source="{source}" title="{title}"/>\n{excerpt}\n</Document>'
            )
        return "\n\n---\n\n".join(chunks)
    except Exception as exc:
        # Best effort: report the failure as text so the agent can carry on.
        return f"ArXiv search error: {str(exc)}"
155
 
156
+ @tool
157
def serper_search(query: str) -> str:
    """Enhanced web search using Serper API.

    Args:
        query: The search query.

    Returns:
        The top three organic results rendered as ``<Document>`` snippets
        separated by ``---``, or a short message when the API key is missing,
        nothing was found, or the request failed.
    """
    if not serper_api_key:
        return "Serper API key not available"

    try:
        url = "https://google.serper.dev/search"
        payload = json.dumps({
            "q": query,
            "num": 5
        })
        headers = {
            'X-API-KEY': serper_api_key,
            'Content-Type': 'application/json'
        }

        # A timeout keeps a stalled connection from hanging the whole agent run.
        response = requests.post(url, headers=headers, data=payload, timeout=30)
        # Surface HTTP errors (bad key, quota exhausted) instead of reporting
        # them as an empty result set.
        response.raise_for_status()
        results = response.json()

        organic = results.get('organic') if isinstance(results, dict) else None
        if not organic:
            return "No search results found"

        formatted = "\n\n---\n\n".join([
            f'<Document source="{result.get("link", "Unknown")}" title="{result.get("title", "Unknown")}"/>\n{result.get("snippet", "")}\n</Document>'
            for result in organic[:3]
        ])
        return formatted
    except Exception as e:
        # Best effort: report the failure as text so the agent can carry on.
        return f"Serper search error: {str(e)}"
186
 
187
+ # ---- Embedding & Vector Store Setup ----
188
def setup_vector_store():
    """Build the Chroma vector store of known question/answer pairs.

    Reads ``metadata.jsonl`` (one JSON object per line) when it exists, and
    indexes every record that has both a ``Question`` and a ``Final answer``.
    When no usable records are found, an empty store is opened instead.

    Returns:
        The Chroma vector store, or ``None`` if setup failed for any reason.
    """
    try:
        embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")

        documents = []
        if os.path.exists('metadata.jsonl'):
            # Explicit encoding so the result does not depend on the platform locale.
            with open('metadata.jsonl', 'r', encoding='utf-8') as jsonl_file:
                for line_number, line in enumerate(jsonl_file, start=1):
                    if not line.strip():  # Skip empty lines
                        continue
                    try:
                        sample = json.loads(line)
                    except json.JSONDecodeError as e:
                        # One bad record should not discard the whole knowledge base.
                        print(f"Skipping malformed line {line_number} in metadata.jsonl: {e}")
                        continue
                    if not isinstance(sample, dict):
                        continue
                    if sample.get('Question') and sample.get('Final answer'):
                        documents.append(
                            Document(
                                page_content=f"Question: {sample.get('Question', '')}\n\nFinal answer: {sample.get('Final answer', '')}",
                                metadata={"source": sample.get("task_id", "unknown")}
                            )
                        )

        if documents:
            # NOTE(review): this adds the documents to ./chroma_db on every
            # start-up; if that directory survives restarts the collection may
            # accumulate duplicates - confirm against the deployment.
            vector_store = Chroma.from_documents(
                documents=documents,
                embedding=embeddings,
                persist_directory="./chroma_db",
                collection_name="my_collection"
            )
            # NOTE(review): some Chroma wrappers no longer expose persist();
            # only call it when the installed one does.
            if hasattr(vector_store, "persist"):
                vector_store.persist()
            print(f"Vector store created with {len(documents)} documents")
            return vector_store

        # Create empty vector store if no data
        vector_store = Chroma(
            embedding_function=embeddings,
            persist_directory="./chroma_db",
            collection_name="my_collection"
        )
        print("Empty vector store created")
        return vector_store

    except Exception as e:
        # Callers treat None as "vector store not available".
        print(f"Vector store setup error: {e}")
        return None
233
 
234
+ vector_store = setup_vector_store()
 
 
 
 
 
 
 
235
 
236
  @tool
237
def similar_question_search(query: str) -> str:
    """Search the knowledge base for questions similar to the query."""
    if not vector_store:
        return "Vector store not available"

    try:
        hits = vector_store.similarity_search(query, 3)
        if not hits:
            return "No similar questions found"

        # Render each hit as a <Document> snippet, capped at 1000 characters.
        rendered = []
        for hit in hits:
            origin = hit.metadata.get("source", "Unknown")
            excerpt = hit.page_content[:1000]
            rendered.append(
                f'<Document source="{origin}" />\n{excerpt}\n</Document>'
            )
        return "\n\n---\n\n".join(rendered)
    except Exception as err:
        # Best effort: report the failure as text so the agent can carry on.
        return f"Similar question search error: {str(err)}"
 
254
 
255
+ # ---- Enhanced System Prompt ----
256
  system_prompt = """
257
+ You are an expert assistant capable of solving complex questions using available tools. You have access to:
258
+
259
+ 1. Mathematical tools: add, subtract, multiply, divide, modulus, power, square_root, factorial, gcd, lcm, percentage, compound_interest
260
+ 2. Search tools: wiki_search, web_search, arxiv_search, serper_search, similar_question_search
261
+
262
+ IMPORTANT INSTRUCTIONS:
263
+ 1. Break down complex questions into smaller steps
264
+ 2. Use tools systematically to gather information and perform calculations
265
+ 3. For mathematical problems, show your work step by step
266
+ 4. For factual questions, search for current and accurate information
267
+ 5. Cross-reference information from multiple sources when possible
268
+ 6. Be precise with numbers - avoid rounding unless necessary
269
+
270
+ When providing your final answer, use this exact format:
271
+ FINAL ANSWER: [YOUR ANSWER]
272
+
273
+ Rules for the final answer:
274
+ - Numbers: Use plain digits without commas, units, or symbols (unless specifically requested)
275
+ - Strings: Use exact names without articles or abbreviations
276
+ - Lists: Comma-separated values following the above rules
277
+ - Be concise and accurate
278
+
279
+ Think step by step and use the available tools to ensure accuracy.
280
  """
281
 
282
  sys_msg = SystemMessage(content=system_prompt)
283
 
284
+ # ---- Enhanced Tool List ----
 
285
  tools = [
286
+ # Math tools
287
+ multiply, add, subtract, divide, modulus, power, square_root,
288
+ factorial, gcd, lcm, percentage, compound_interest,
289
+ # Search tools
290
+ wiki_search, web_search, arxiv_search, serper_search, similar_question_search
291
  ]
292
 
293
  # ---- Graph Definition ----
294
def build_graph(provider: str = "huggingface"):
    """Build and compile the agent graph.

    The graph runs ``retriever -> assistant`` and then loops
    ``assistant <-> tools`` until the assistant stops requesting tool calls.

    Args:
        provider: LLM backend to use; only ``"huggingface"`` is supported.

    Returns:
        The compiled LangGraph graph.

    Raises:
        ValueError: If ``provider`` is not ``"huggingface"``.
    """
    if provider != "huggingface":
        raise ValueError("Only 'huggingface' provider is supported in this version.")

    # Imported here so this function does not rely on the module-level import list.
    from langchain_core.messages import AIMessage

    # NOTE(review): DialoGPT is a conversational model and may not support tool
    # calling - confirm the repo_id (alternatives previously suggested here:
    # "google/flan-t5-xl", "bigscience/bloom-7b1").
    endpoint = HuggingFaceEndpoint(
        repo_id="microsoft/DialoGPT-large",
        temperature=0.1,
        # Generation settings that are declared fields of HuggingFaceEndpoint
        # are passed explicitly; the endpoint rejects declared fields that are
        # supplied through ``model_kwargs``.
        max_new_tokens=1024,
        return_full_text=False,
        huggingfacehub_api_token=hf_token,
    )
    llm = ChatHuggingFace(llm=endpoint)

    llm_with_tools = llm.bind_tools(tools)

    def assistant(state: MessagesState):
        """Call the LLM on the conversation so far, with the system prompt first."""
        try:
            # The system prompt is prepended at call time rather than stored in
            # the state: messages returned by a node are appended by the state
            # reducer, which would place it after the user's question.
            response = llm_with_tools.invoke([sys_msg] + state["messages"])
            return {"messages": [response]}
        except Exception as e:
            print(f"Assistant error: {e}")
            # This node speaks for the assistant, so the fallback is an AI
            # message; it carries no tool calls, which ends the run.
            fallback_msg = AIMessage(content=f"I encountered an error: {str(e)}. Let me try a simpler approach.")
            return {"messages": [fallback_msg]}

    def retriever(state: MessagesState):
        """Add similar known questions to the conversation as extra context."""
        messages = state["messages"]
        user_query = messages[-1].content if messages else ""

        if not vector_store or not user_query:
            return {"messages": []}

        try:
            similar = vector_store.similarity_search(user_query, k=2)
        except Exception as e:
            # Retrieval is optional context; carry on without it.
            print(f"Retriever error: {e}")
            return {"messages": []}

        if not similar:
            return {"messages": []}

        # Include every retrieved example, not just the first one.
        examples = "\n\n".join(doc.page_content for doc in similar)
        context_msg = HumanMessage(
            content=f"Here are similar questions for context:\n\n{examples}"
        )
        # Only the new message is returned; the reducer appends it to the state.
        return {"messages": [context_msg]}

    # Build the graph
    builder = StateGraph(MessagesState)
    builder.add_node("retriever", retriever)
    builder.add_node("assistant", assistant)
    builder.add_node("tools", ToolNode(tools))

    # Define edges
    builder.add_edge(START, "retriever")
    builder.add_edge("retriever", "assistant")
    builder.add_conditional_edges("assistant", tools_condition)
    builder.add_edge("tools", "assistant")

    return builder.compile()
app.py CHANGED
@@ -1,235 +1,302 @@
1
  import os
2
- from dotenv import load_dotenv
3
-
4
- # Load environment variables
5
- load_dotenv()
6
-
7
- # Set protobuf implementation to avoid C++ extension issues
8
- os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
9
-
10
- # Load keys from environment
11
- hf_token = os.getenv("HUGGINGFACE_INFERENCE_TOKEN")
12
- serper_api_key = os.getenv("SERPER_API_KEY")
13
-
14
- # ---- Updated Imports ----
15
- from langgraph.graph import START, StateGraph, MessagesState
16
- from langgraph.prebuilt import tools_condition, ToolNode
17
- from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
18
- from langchain_community.embeddings import HuggingFaceEmbeddings # Updated import
19
- from langchain_community.tools.tavily_search import TavilySearchResults
20
- from langchain_community.document_loaders import WikipediaLoader, ArxivLoader
21
- from langchain_community.vectorstores import Chroma
22
- from langchain_core.documents import Document
23
- from langchain_core.messages import SystemMessage, HumanMessage
24
- from langchain_core.tools import tool
25
- from langchain.tools.retriever import create_retriever_tool
26
- import json
27
-
28
- # ---- Tools ----
29
-
30
- @tool
31
- def multiply(a: int, b: int) -> int:
32
- """Multiply two numbers together."""
33
- return a * b
34
-
35
- @tool
36
- def add(a: int, b: int) -> int:
37
- """Add two numbers together."""
38
- return a + b
39
-
40
- @tool
41
- def subtract(a: int, b: int) -> int:
42
- """Subtract the second number from the first."""
43
- return a - b
44
-
45
- @tool
46
- def divide(a: int, b: int) -> float:
47
- """Divide the first number by the second. Returns float or error if dividing by zero."""
48
- if b == 0:
49
- raise ValueError("Cannot divide by zero.")
50
- return a / b
51
-
52
- @tool
53
- def modulus(a: int, b: int) -> int:
54
- """Returns the remainder after division of the first number by the second."""
55
- return a % b
56
-
57
- @tool
58
- def wiki_search(query: str) -> str:
59
- """Search Wikipedia for information. Useful for factual questions about people, places, events, etc."""
60
- try:
61
- search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
62
- formatted = "\n\n---\n\n".join(
63
- [
64
- f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
65
- for doc in search_docs
66
- ]
67
- )
68
- return {"wiki_results": formatted}
69
- except Exception as e:
70
- return f"Wikipedia search failed: {str(e)}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
 
72
- @tool
73
- def web_search(query: str) -> str:
74
- """Search the web for current information. Useful when you need recent or non-Wikipedia information."""
 
 
75
  try:
76
- search = TavilySearchResults(max_results=3)
77
- search_docs = search.invoke(query)
78
- formatted = "\n\n---\n\n".join(
79
- [
80
- f'<Document source="{doc["url"]}"/>\n{doc["content"]}\n</Document>'
81
- for doc in search_docs
82
- ]
83
- )
84
- return {"web_results": formatted}
85
  except Exception as e:
86
- return f"Web search failed: {str(e)}"
 
 
 
 
 
87
 
88
- @tool
89
- def arxiv_search(query: str) -> str:
90
- """Search academic papers on ArXiv. Useful for technical or scientific questions."""
91
  try:
92
- search_docs = ArxivLoader(query=query, load_max_docs=2).load()
93
- formatted = "\n\n---\n\n".join(
94
- [
95
- f'<Document source="{doc.metadata["source"]}"/>\n{doc.page_content[:1000]}\n</Document>'
96
- for doc in search_docs
97
- ]
98
- )
99
- return {"arxiv_results": formatted}
 
 
 
 
 
 
100
  except Exception as e:
101
- return f"ArXiv search failed: {str(e)}"
102
-
103
- # ---- Updated Embedding & Vector Store Setup ----
104
- try:
105
- embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
106
- except Exception as e:
107
- print(f"Error loading embeddings: {e}")
108
- raise
109
-
110
- # Load QA pairs
111
- json_QA = []
112
- try:
113
- with open('metadata.jsonl', 'r') as jsonl_file:
114
- for line in jsonl_file:
115
- json_QA.append(json.loads(line))
116
- except Exception as e:
117
- print(f"Error loading metadata.jsonl: {e}")
118
- json_QA = []
119
-
120
- documents = [
121
- Document(
122
- page_content=f"Question: {sample['Question']}\n\nAnswer: {sample['Final answer']}",
123
- metadata={"source": sample["task_id"], "question": sample["Question"], "answer": sample["Final answer"]}
124
- )
125
- for sample in json_QA
126
- ]
127
-
128
- try:
129
- vector_store = Chroma.from_documents(
130
- documents=documents,
131
- embedding=embeddings,
132
- persist_directory="./chroma_db",
133
- collection_name="qa_collection"
134
- )
135
- vector_store.persist()
136
- print(f"Documents inserted: {len(documents)}")
137
- except Exception as e:
138
- print(f"Error creating vector store: {e}")
139
- raise
140
-
141
- @tool
142
- def similar_question_search(query: str) -> str:
143
- """Search for similar questions that have been answered before. Always check here first before using other tools."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
144
  try:
145
- matched_docs = vector_store.similarity_search(query, k=3)
146
- formatted = "\n\n---\n\n".join(
147
- [
148
- f'<Question: {doc.metadata["question"]}>\n<Answer: {doc.metadata["answer"]}>\n</Document>'
149
- for doc in matched_docs
150
- ]
 
 
 
 
151
  )
152
- return {"similar_questions": formatted}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
153
  except Exception as e:
154
- return f"Similar question search failed: {str(e)}"
 
 
 
155
 
156
- # ---- System Prompt ----
157
 
158
- system_prompt = """
159
- You are an expert question-answering assistant. Follow these steps for each question:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
160
 
161
- 1. FIRST check for similar questions using the similar_question_search tool
162
- 2. If a similar question exists with a clear answer, use that answer
163
- 3. If not, determine which tools might help answer the question
164
- 4. Use the tools systematically to gather information
165
- 5. Combine information from multiple sources if needed
166
- 6. Format your final answer precisely as:
167
- FINAL ANSWER: [your answer here]
168
 
169
- Rules for answers:
170
- - Numbers: plain digits only (no commas, units, or symbols)
171
- - Strings: minimal words, no articles, full names
172
- - Lists: comma-separated with no extra formatting
173
- - Be concise but accurate
174
- """
175
 
176
- sys_msg = SystemMessage(content=system_prompt)
 
177
 
178
- # ---- Tool List ----
 
 
 
179
 
180
- tools = [
181
- similar_question_search, # Check this first
182
- multiply, add, subtract, divide, modulus, # Math tools
183
- wiki_search, web_search, arxiv_search # Information tools
184
- ]
 
 
185
 
186
- # ---- Graph Definition ----
 
 
 
 
187
 
188
- def build_graph():
189
- try:
190
- # Using a powerful HuggingFace model
191
- llm = ChatHuggingFace(
192
- llm=HuggingFaceEndpoint(
193
- repo_id="mistralai/Mixtral-8x7B-Instruct-v0.1",
194
- temperature=0,
195
- max_new_tokens=512,
196
- huggingfacehub_api_token=hf_token
197
- )
198
- )
199
-
200
- llm_with_tools = llm.bind_tools(tools)
201
-
202
- def assistant(state: MessagesState):
203
- return {"messages": [llm_with_tools.invoke(state["messages"])]}
204
-
205
- def retriever(state: MessagesState):
206
- try:
207
- # First try to find similar questions
208
- similar = vector_store.similarity_search(state["messages"][-1].content, k=2)
209
- if similar:
210
- example_msg = HumanMessage(
211
- content=f"Here are similar questions and their answers:\n\n" +
212
- "\n\n".join([f"Q: {doc.metadata['question']}\nA: {doc.metadata['answer']}"
213
- for doc in similar])
214
- )
215
- return {"messages": [sys_msg] + state["messages"] + [example_msg]}
216
- return {"messages": [sys_msg] + state["messages"]}
217
- except Exception as e:
218
- print(f"Retriever error: {e}")
219
- return {"messages": [sys_msg] + state["messages"]}
220
-
221
- builder = StateGraph(MessagesState)
222
- builder.add_node("retriever", retriever)
223
- builder.add_node("assistant", assistant)
224
- builder.add_node("tools", ToolNode(tools))
225
 
226
- builder.add_edge(START, "retriever")
227
- builder.add_edge("retriever", "assistant")
228
- builder.add_conditional_edges("assistant", tools_condition)
229
- builder.add_edge("tools", "assistant")
230
 
231
- return builder.compile()
232
-
233
- except Exception as e:
234
- print(f"Error building graph: {e}")
235
- raise
 
1
  import os
2
+ import gradio as gr
3
+ import requests
4
+ import inspect
5
+ import pandas as pd
6
+ from agent import build_graph
7
+ from langchain_core.messages import HumanMessage
8
+ import time
9
+
10
+ # (Keep Constants as is)
11
+ # --- Constants ---
12
+ DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
13
+
14
+ # --- Improved Agent Definition ---
15
class BasicAgent:
    """Callable wrapper that answers a question by running the agent graph."""

    def __init__(self):
        """Build the agent graph; any failure is logged and re-raised."""
        print("BasicAgent initialized.")
        try:
            self.graph = build_graph()
            print("Graph built successfully.")
        except Exception as e:
            print(f"Error building graph: {e}")
            # Bare raise keeps the original traceback intact.
            raise

    @staticmethod
    def _clean_answer(answer: str) -> str:
        """Strip the reasoning, the final-answer marker and role prefixes."""
        import re  # local import: only this helper needs it

        # Match the marker case-insensitively, since the model may write it as
        # "Final Answer:"; keep only the text after the last occurrence.
        parts = re.split(r"FINAL ANSWER:", answer, flags=re.IGNORECASE)
        if len(parts) > 1:
            answer = parts[-1].strip()

        # Additional cleanup
        return answer.replace("Assistant: ", "").strip()

    def __call__(self, question: str) -> str:
        """Run the graph on ``question`` and return the cleaned final answer.

        The graph is invoked up to three times; if every attempt fails, an
        error description is returned instead of raising.
        """
        print(f"Agent received question (first 100 chars): {question[:100]}...")

        try:
            # Clean the question
            question = question.strip()

            # Wrap the question in a HumanMessage
            messages = [HumanMessage(content=question)]

            # Invoke the graph with retry mechanism
            max_retries = 3
            for attempt in range(max_retries):
                try:
                    result = self.graph.invoke({"messages": messages})

                    if 'messages' not in result or not result['messages']:
                        return "No response generated"

                    answer = result['messages'][-1].content
                    if not isinstance(answer, str):
                        # Non-text content is returned as its string form.
                        return str(answer)

                    answer = self._clean_answer(answer)
                    print(f"Agent answer (first 100 chars): {answer[:100]}...")
                    return answer

                except Exception as e:
                    print(f"Attempt {attempt + 1} failed: {e}")
                    if attempt == max_retries - 1:
                        return f"Error processing question: {str(e)}"
                    time.sleep(1)  # Brief pause before retry

        except Exception as e:
            print(f"Error in agent call: {e}")
            return f"Agent error: {str(e)}"
69
+
70
+
71
+ def run_and_submit_all(profile: gr.OAuthProfile | None):
72
+ """
73
+ Fetches all questions, runs the BasicAgent on them, submits all answers,
74
+ and displays the results.
75
+ """
76
+ # --- Determine HF Space Runtime URL and Repo URL ---
77
+ space_id = os.getenv("SPACE_ID") # Get the SPACE_ID for sending link to the code
78
+
79
+ if profile:
80
+ username = f"{profile.username}"
81
+ print(f"User logged in: {username}")
82
+ else:
83
+ print("User not logged in.")
84
+ return "Please Login to Hugging Face with the button.", None
85
 
86
+ api_url = DEFAULT_API_URL
87
+ questions_url = f"{api_url}/questions"
88
+ submit_url = f"{api_url}/submit"
89
+
90
+ # 1. Instantiate Agent (modify this part to create your agent)
91
  try:
92
+ print("Initializing agent...")
93
+ agent = BasicAgent()
94
+ print("Agent initialized successfully.")
 
 
 
 
 
 
95
  except Exception as e:
96
+ print(f"Error instantiating agent: {e}")
97
+ return f"Error initializing agent: {e}", None
98
+
99
+ # In the case of an app running as a Hugging Face space, this link points toward your codebase
100
+ agent_code = f"https://huggingface.co/spaces/{space_id}/tree/main"
101
+ print(f"Agent code URL: {agent_code}")
102
 
103
+ # 2. Fetch Questions
104
+ print(f"Fetching questions from: {questions_url}")
 
105
  try:
106
+ response = requests.get(questions_url, timeout=30)
107
+ response.raise_for_status()
108
+ questions_data = response.json()
109
+ if not questions_data:
110
+ print("Fetched questions list is empty.")
111
+ return "Fetched questions list is empty or invalid format.", None
112
+ print(f"Fetched {len(questions_data)} questions.")
113
+ except requests.exceptions.RequestException as e:
114
+ print(f"Error fetching questions: {e}")
115
+ return f"Error fetching questions: {e}", None
116
+ except requests.exceptions.JSONDecodeError as e:
117
+ print(f"Error decoding JSON response from questions endpoint: {e}")
118
+ print(f"Response text: {response.text[:500]}")
119
+ return f"Error decoding server response for questions: {e}", None
120
  except Exception as e:
121
+ print(f"An unexpected error occurred fetching questions: {e}")
122
+ return f"An unexpected error occurred fetching questions: {e}", None
123
+
124
+ # 3. Run your Agent with better error handling
125
+ results_log = []
126
+ answers_payload = []
127
+ print(f"Running agent on {len(questions_data)} questions...")
128
+
129
+ for i, item in enumerate(questions_data):
130
+ task_id = item.get("task_id")
131
+ question_text = item.get("question")
132
+
133
+ if not task_id or question_text is None:
134
+ print(f"Skipping item with missing task_id or question: {item}")
135
+ continue
136
+
137
+ print(f"Processing question {i+1}/{len(questions_data)}: {task_id}")
138
+
139
+ try:
140
+ # Add timeout and better error handling for individual questions
141
+ start_time = time.time()
142
+ submitted_answer = agent(question_text)
143
+ end_time = time.time()
144
+
145
+ print(f"Question {i+1} completed in {end_time - start_time:.2f} seconds")
146
+
147
+ # Validate the answer
148
+ if not submitted_answer or submitted_answer.strip() == "":
149
+ submitted_answer = "No answer generated"
150
+
151
+ answers_payload.append({
152
+ "task_id": task_id,
153
+ "submitted_answer": str(submitted_answer).strip()
154
+ })
155
+
156
+ results_log.append({
157
+ "Task ID": task_id,
158
+ "Question": question_text[:200] + "..." if len(question_text) > 200 else question_text,
159
+ "Submitted Answer": str(submitted_answer).strip()
160
+ })
161
+
162
+ except Exception as e:
163
+ print(f"Error running agent on task {task_id}: {e}")
164
+ error_answer = f"AGENT ERROR: {str(e)}"
165
+ answers_payload.append({
166
+ "task_id": task_id,
167
+ "submitted_answer": error_answer
168
+ })
169
+ results_log.append({
170
+ "Task ID": task_id,
171
+ "Question": question_text[:200] + "..." if len(question_text) > 200 else question_text,
172
+ "Submitted Answer": error_answer
173
+ })
174
+
175
+ if not answers_payload:
176
+ print("Agent did not produce any answers to submit.")
177
+ return "Agent did not produce any answers to submit.", pd.DataFrame(results_log)
178
+
179
+ # 4. Prepare Submission
180
+ submission_data = {
181
+ "username": username.strip(),
182
+ "agent_code": agent_code,
183
+ "answers": answers_payload
184
+ }
185
+ status_update = f"Agent finished. Submitting {len(answers_payload)} answers for user '{username}'..."
186
+ print(status_update)
187
+
188
+ # 5. Submit with better error handling
189
+ print(f"Submitting {len(answers_payload)} answers to: {submit_url}")
190
  try:
191
+ response = requests.post(submit_url, json=submission_data, timeout=120)
192
+ response.raise_for_status()
193
+ result_data = response.json()
194
+
195
+ final_status = (
196
+ f"Submission Successful!\n"
197
+ f"User: {result_data.get('username', 'Unknown')}\n"
198
+ f"Overall Score: {result_data.get('score', 'N/A')}% "
199
+ f"({result_data.get('correct_count', '?')}/{result_data.get('total_attempted', '?')} correct)\n"
200
+ f"Message: {result_data.get('message', 'No message received.')}"
201
  )
202
+ print("Submission successful.")
203
+ print(f"Score: {result_data.get('score', 'N/A')}%")
204
+
205
+ results_df = pd.DataFrame(results_log)
206
+ return final_status, results_df
207
+
208
+ except requests.exceptions.HTTPError as e:
209
+ error_detail = f"Server responded with status {e.response.status_code}."
210
+ try:
211
+ error_json = e.response.json()
212
+ error_detail += f" Detail: {error_json.get('detail', e.response.text)}"
213
+ except requests.exceptions.JSONDecodeError:
214
+ error_detail += f" Response: {e.response.text[:500]}"
215
+ status_message = f"Submission Failed: {error_detail}"
216
+ print(status_message)
217
+ results_df = pd.DataFrame(results_log)
218
+ return status_message, results_df
219
+
220
+ except requests.exceptions.Timeout:
221
+ status_message = "Submission Failed: The request timed out."
222
+ print(status_message)
223
+ results_df = pd.DataFrame(results_log)
224
+ return status_message, results_df
225
+
226
+ except requests.exceptions.RequestException as e:
227
+ status_message = f"Submission Failed: Network error - {e}"
228
+ print(status_message)
229
+ results_df = pd.DataFrame(results_log)
230
+ return status_message, results_df
231
+
232
  except Exception as e:
233
+ status_message = f"An unexpected error occurred during submission: {e}"
234
+ print(status_message)
235
+ results_df = pd.DataFrame(results_log)
236
+ return status_message, results_df
237
 
 
238
 
239
+ # --- Build Gradio Interface using Blocks ---
240
with gr.Blocks() as demo:
    # Header: page title followed by the usage notes rendered as Markdown.
    gr.Markdown(value="# Enhanced Agent Evaluation Runner")
    gr.Markdown(
        value="""
        **Instructions:**
        1. Please clone this space, then modify the code to define your agent's logic, tools, and necessary packages.
        2. Log in to your Hugging Face account using the button below. This uses your HF username for submission.
        3. Click 'Run Evaluation & Submit All Answers' to fetch questions, run your agent, submit answers, and see the score.

        **Improvements in this version:**
        - Enhanced mathematical tools (factorial, gcd, lcm, compound interest, etc.)
        - Better search tools with error handling
        - Improved HuggingFace model integration
        - Better answer processing and cleanup
        - Enhanced error handling and retry mechanisms

        ---
        **Note:** The evaluation process may take some time as the agent processes all questions systematically.
        """
    )

    # Hugging Face login control; placed above the action button.
    gr.LoginButton()

    # Single action button that triggers the whole fetch / run / submit flow.
    run_button = gr.Button(
        value="Run Evaluation & Submit All Answers",
        variant="primary",
    )

    # Read-only outputs: a status text box and a table of per-question answers.
    status_output = gr.Textbox(
        interactive=False,
        lines=5,
        label="Run Status / Submission Result",
    )
    results_table = gr.DataFrame(
        wrap=True,
        label="Questions and Agent Answers",
    )

    # No explicit inputs are wired; the handler's two return values map onto
    # the status box and the results table, in that order.
    run_button.click(
        run_and_submit_all,
        outputs=[status_output, results_table],
    )
272
 
273
def space_runtime_url(space_host):
    """Return the public ``https://`` URL for a Space from its SPACE_HOST value.

    SPACE_HOST normally already carries the ``.hf.space`` domain
    (e.g. ``user-space.hf.space``), so the suffix is appended only when it is
    missing. Appending it unconditionally produced ``...hf.space.hf.space``.

    Args:
        space_host: Value of the ``SPACE_HOST`` environment variable.

    Returns:
        The runtime URL as a string.
    """
    host = space_host.strip()
    if not host.endswith(".hf.space"):
        host += ".hf.space"
    return f"https://{host}"


if __name__ == "__main__":
    # Startup banner, environment diagnostics, then launch the Gradio app.
    print("\n" + "-"*30 + " Enhanced App Starting " + "-"*30)

    # Check for environment variables
    space_host_startup = os.getenv("SPACE_HOST")
    space_id_startup = os.getenv("SPACE_ID")
    hf_token = os.getenv("HUGGINGFACE_INFERENCE_TOKEN")

    if space_host_startup:
        print(f"✅ SPACE_HOST found: {space_host_startup}")
        print(f" Runtime URL should be: {space_runtime_url(space_host_startup)}")
    else:
        print("ℹ️ SPACE_HOST environment variable not found (running locally?).")

    if space_id_startup:
        print(f"✅ SPACE_ID found: {space_id_startup}")
        print(f" Repo URL: https://huggingface.co/spaces/{space_id_startup}")
        print(f" Repo Tree URL: https://huggingface.co/spaces/{space_id_startup}/tree/main")
    else:
        print("ℹ️ SPACE_ID environment variable not found (running locally?).")

    # The token is only reported here, never printed.
    if hf_token:
        print(" HUGGINGFACE_INFERENCE_TOKEN found")
    else:
        print("⚠️ HUGGINGFACE_INFERENCE_TOKEN not found - this may cause issues")

    # Closing rule matches the banner width: 2 * 30 dashes plus the title text.
    print("-"*(60 + len(" Enhanced App Starting ")) + "\n")

    print("Launching Enhanced Gradio Interface for Agent Evaluation...")
    demo.launch(debug=True, share=False)
 
requirements.txt CHANGED
@@ -3,18 +3,27 @@ requests
3
  langchain
4
  langchain-community
5
  langchain-core
6
- langchain-google-genai
7
  langchain-huggingface
8
- langchain-groq
9
- langchain-tavily
10
  langchain-chroma
 
11
  langgraph
12
  sentence-transformers
13
  huggingface_hub
 
 
14
  supabase
15
  arxiv
16
  pymupdf
17
  wikipedia
18
  pgvector
19
  python-dotenv
20
- protobuf==3.20.3
 
 
 
 
 
 
 
 
 
 
3
  langchain
4
  langchain-community
5
  langchain-core
 
6
  langchain-huggingface
 
 
7
  langchain-chroma
8
+ langchain-tavily
9
  langgraph
10
  sentence-transformers
11
  huggingface_hub
12
+ transformers
13
+ torch
14
  supabase
15
  arxiv
16
  pymupdf
17
  wikipedia
18
  pgvector
19
  python-dotenv
20
+ protobuf==3.20.3
21
+ chromadb
22
+ tiktoken
23
+ numpy
24
+ pandas
25
+ scipy
26
+ sympy
27
+ python-dateutil
28
+ beautifulsoup4
29
+ lxml