LamiaYT committed on
Commit 57b9551 · 1 Parent(s): d66e9b7

fixing ver3

Files changed (1)
  1. app.py +43 -89
app.py CHANGED
@@ -5,10 +5,9 @@ import json
 import re
 import numexpr
 import pandas as pd
-import math
 from pdfminer.high_level import extract_text
 from bs4 import BeautifulSoup
-from typing import Dict, Any, List, Tuple, Optional
+from typing import List, Dict, Optional, Tuple
 from dotenv import load_dotenv
 from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
 import torch
@@ -21,14 +20,14 @@ SERPER_API_KEY = os.getenv("SERPER_API_KEY")
 
 # --- Constants ---
 DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
-MAX_STEPS = 6 # Increased from 4
-MAX_TOKENS = 256 # Increased from 128
+MAX_STEPS = 6
+MAX_TOKENS = 256
 MODEL_NAME = "microsoft/Phi-3-mini-4k-instruct"
-TIMEOUT_PER_QUESTION = 45 # Increased from 30
-MAX_RESULT_LENGTH = 500 # For tool outputs
+TIMEOUT_PER_QUESTION = 45
+MAX_RESULT_LENGTH = 500
 
-# --- Model Loading ---
-print("Loading optimized model...")
+# --- Fixed Model Loading ---
+print("Loading model with fixed configuration...")
 start_time = time.time()
 
 model = AutoModelForCausalLM.from_pretrained(
@@ -50,12 +49,12 @@ if tokenizer.pad_token is None:
 
 print(f"Model loaded in {time.time() - start_time:.2f} seconds")
 
-# --- Enhanced Tools ---
+# --- Tools Implementation ---
 def web_search(query: str) -> str:
-    """Enhanced web search with better result parsing"""
+    """Enhanced web search with better error handling"""
     try:
         if SERPER_API_KEY:
-            params = {'q': query, 'num': 3, 'hl': 'en', 'gl': 'us'}
+            params = {'q': query, 'num': 3}
             headers = {'X-API-KEY': SERPER_API_KEY}
             response = requests.post(
                 'https://google.serper.dev/search',
@@ -64,97 +63,64 @@ def web_search(query: str) -> str:
                 timeout=10
             )
             results = response.json()
-
             if 'organic' in results:
-                output = []
-                for r in results['organic'][:3]:
-                    if 'title' in r and 'snippet' in r:
-                        output.append(f"{r['title']}: {r['snippet']}")
-                return "\n".join(output)[:MAX_RESULT_LENGTH]
-            return "No relevant results found"
+                return "\n".join([f"{r['title']}: {r['snippet']}" for r in results['organic'][:3]])[:MAX_RESULT_LENGTH]
+            return "No search results found"
         else:
-            with DDGS() as ddgs:
-                results = [r for r in ddgs.text(query, max_results=3)]
-                return "\n".join([f"{r['title']}: {r['body']}" for r in results])[:MAX_RESULT_LENGTH]
+            return "Search API key not configured"
     except Exception as e:
         return f"Search error: {str(e)}"
 
 def calculator(expression: str) -> str:
-    """More robust calculator with validation"""
+    """Safe mathematical evaluation"""
     try:
-        # Clean and validate expression
         expression = re.sub(r'[^\d+\-*/().^%,\s]', '', expression)
         if not expression:
             return "Invalid empty expression"
-
-        # Handle percentages and commas
-        expression = expression.replace('%', '/100').replace(',', '')
-        result = numexpr.evaluate(expression)
-        return str(float(result))
+        return str(numexpr.evaluate(expression))
     except Exception as e:
         return f"Calculation error: {str(e)}"
 
-def read_pdf(file_path: str) -> str:
-    """PDF reader with better text extraction"""
-    try:
-        text = extract_text(file_path)
-        if not text:
-            return "No readable text found in PDF"
-
-        # Clean and condense text
-        text = re.sub(r'\s+', ' ', text).strip()
-        return text[:MAX_RESULT_LENGTH]
-    except Exception as e:
-        return f"PDF read error: {str(e)}"
-
 def read_webpage(url: str) -> str:
-    """Improved webpage reader with better content extraction"""
+    """Robust webpage content extraction"""
     try:
         headers = {'User-Agent': 'Mozilla/5.0'}
         response = requests.get(url, timeout=10, headers=headers)
-        response.raise_for_status()
-
         soup = BeautifulSoup(response.text, 'html.parser')
 
-        # Remove unwanted elements
         for element in soup(['script', 'style', 'nav', 'footer']):
             element.decompose()
 
-        # Get text with better formatting
         text = soup.get_text(separator='\n', strip=True)
-        text = re.sub(r'\n{3,}', '\n\n', text)
-
-        return text[:MAX_RESULT_LENGTH] if text else "No main content found"
+        return re.sub(r'\n{3,}', '\n\n', text)[:MAX_RESULT_LENGTH]
     except Exception as e:
-        return f"Webpage read error: {str(e)}"
+        return f"Webpage error: {str(e)}"
 
 TOOLS = {
     "web_search": web_search,
     "calculator": calculator,
-    "read_pdf": read_pdf,
     "read_webpage": read_webpage
 }
 
-# --- Improved GAIA Agent ---
+# --- Fixed GAIA Agent ---
 class GAIA_Agent:
     def __init__(self):
         self.tools = TOOLS
         self.system_prompt = """You are an advanced GAIA problem solver. Follow these steps:
-1. Analyze the question carefully
-2. Choose the most appropriate tool
-3. Process the results
-4. Provide a precise final answer
+1. Analyze the question
+2. Choose the best tool
+3. Process results
+4. Provide final answer
 
-Available Tools:
-- web_search: For general knowledge questions
-- calculator: For math problems
-- read_pdf: For PDF content extraction
-- read_webpage: For webpage content extraction
+Tools:
+- web_search: For general knowledge
+- calculator: For math
+- read_webpage: For web content
 
 Tool format: ```json
 {"tool": "tool_name", "args": {"arg1": value}}```
 
-Always end with: Final Answer: [your answer]"""
+Always end with: Final Answer: [answer]"""
 
     def __call__(self, question: str) -> str:
         start_time = time.time()
@@ -169,21 +135,20 @@ Always end with: Final Answer: [your answer]"""
                 response = self._call_model(prompt)
 
                 if "Final Answer:" in response:
-                    answer = response.split("Final Answer:")[-1].strip()
-                    return answer[:500] # Limit answer length
+                    return response.split("Final Answer:")[-1].strip()[:500]
 
                 tool_call = self._parse_tool_call(response)
                 if tool_call:
                     tool_name, args = tool_call
                     observation = self._use_tool(tool_name, args)
-                    history.append(f"Tool Used: {tool_name}")
-                    history.append(f"Tool Result: {observation[:300]}...") # Truncate long results
+                    history.append(f"Tool: {tool_name}")
+                    history.append(f"Result: {observation[:300]}...")
                 else:
-                    history.append(f"Analysis: {response}")
+                    history.append(f"Thought: {response}")
 
                 gc.collect()
 
-            return "Maximum steps reached without final answer"
+            return "Maximum steps reached"
         except Exception as e:
             return f"Error: {str(e)}"
 
@@ -199,21 +164,17 @@ Always end with: Final Answer: [your answer]"""
             padding=False
         )
 
-        generation_config = GenerationConfig(
+        # Fixed generation config without problematic parameters
+        outputs = model.generate(
+            inputs.input_ids,
             max_new_tokens=MAX_TOKENS,
             temperature=0.3,
             top_p=0.9,
             do_sample=True,
-            pad_token_id=tokenizer.pad_token_id
+            pad_token_id=tokenizer.pad_token_id,
+            attention_mask=inputs.attention_mask
         )
 
-        with torch.no_grad():
-            outputs = model.generate(
-                inputs.input_ids,
-                generation_config=generation_config,
-                attention_mask=inputs.attention_mask
-            )
-
         return tokenizer.decode(outputs[0], skip_special_tokens=True).split("<|assistant|>")[-1].strip()
 
     def _parse_tool_call(self, text: str) -> Optional[Tuple[str, Dict]]:
@@ -232,11 +193,9 @@ Always end with: Final Answer: [your answer]"""
             return f"Unknown tool: {tool_name}"
 
         try:
-            # Special handling for URL-containing questions
+            # Handle URL extraction for webpage reading
            if tool_name == "read_webpage" and "url" not in args:
-                if "args" in args and isinstance(args["args"], dict) and "url" in args["args"]:
-                    args = args["args"]
-                elif "http" in str(args):
+                if "http" in str(args):
                     url = re.search(r'https?://[^\s]+', str(args)).group()
                     args = {"url": url}
 
@@ -293,14 +252,9 @@ def run_and_submit_all(profile: gr.OAuthProfile | None):
         return f"Submission failed: {str(e)}", pd.DataFrame(results)
 
 # --- Gradio Interface ---
-with gr.Blocks(title="Enhanced GAIA Agent") as demo:
-    gr.Markdown("## 🚀 Enhanced GAIA Agent Evaluation")
-    gr.Markdown("""
-    Improved version with:
-    - Better tool utilization
-    - Increased step/token limits
-    - Enhanced error handling
-    """)
+with gr.Blocks(title="Fixed GAIA Agent") as demo:
+    gr.Markdown("## 🛠️ Fixed GAIA Agent")
+    gr.Markdown("Resolved the 'DynamicCache' error with improved configuration")
 
     with gr.Row():
         gr.LoginButton()
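
Note (reader's aside, not part of the commit): the substantive change in `_call_model` above is that the separate `GenerationConfig` object and the `torch.no_grad()` wrapper are dropped, and the sampling arguments are passed directly to `model.generate`. Below is a minimal standalone sketch of that pattern, assuming the same Phi-3 checkpoint; the prompt string and loading options are illustrative, not taken from the diff.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_NAME = "microsoft/Phi-3-mini-4k-instruct"

# Depending on the installed transformers version, trust_remote_code=True may be needed here.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)

if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token  # same pad-token fallback used in app.py

# Hypothetical prompt; the real app builds this from the question and tool history.
prompt = "<|user|>\nWhat is 17 * 24?<|end|>\n<|assistant|>"
inputs = tokenizer(prompt, return_tensors="pt", truncation=True, padding=False)

# Sampling arguments go straight to generate() instead of through a GenerationConfig
# wrapped in torch.no_grad(), mirroring the fixed _call_model above.
outputs = model.generate(
    inputs.input_ids,
    attention_mask=inputs.attention_mask,
    max_new_tokens=256,
    temperature=0.3,
    top_p=0.9,
    do_sample=True,
    pad_token_id=tokenizer.pad_token_id,
)

print(tokenizer.decode(outputs[0], skip_special_tokens=True).split("<|assistant|>")[-1].strip())
```

Whether this actually clears the 'DynamicCache' error mentioned in the new Gradio description depends on the installed transformers version; the diff itself only changes how the generation arguments are supplied.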
 