LamiaYT committed on
Commit
aa6f3a8
·
1 Parent(s): 57b9551
Files changed (1) hide show
  1. app.py +117 -82
app.py CHANGED
@@ -9,25 +9,26 @@ from pdfminer.high_level import extract_text
9
  from bs4 import BeautifulSoup
10
  from typing import List, Dict, Optional, Tuple
11
  from dotenv import load_dotenv
12
- from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
13
  import torch
14
  import time
15
  import gc
16
 
17
- # --- Load Environment Variables ---
18
  load_dotenv()
19
  SERPER_API_KEY = os.getenv("SERPER_API_KEY")
 
 
20
 
21
  # --- Constants ---
22
- DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
23
  MAX_STEPS = 6
24
  MAX_TOKENS = 256
25
- MODEL_NAME = "microsoft/Phi-3-mini-4k-instruct"
26
  TIMEOUT_PER_QUESTION = 45
27
  MAX_RESULT_LENGTH = 500
 
28
 
29
- # --- Fixed Model Loading ---
30
- print("Loading model with fixed configuration...")
31
  start_time = time.time()
32
 
33
  model = AutoModelForCausalLM.from_pretrained(
@@ -49,30 +50,35 @@ if tokenizer.pad_token is None:
49
 
50
  print(f"Model loaded in {time.time() - start_time:.2f} seconds")
51
 
52
- # --- Tools Implementation ---
53
  def web_search(query: str) -> str:
54
- """Enhanced web search with better error handling"""
55
  try:
56
- if SERPER_API_KEY:
57
- params = {'q': query, 'num': 3}
58
- headers = {'X-API-KEY': SERPER_API_KEY}
59
- response = requests.post(
60
- 'https://google.serper.dev/search',
61
- headers=headers,
62
- json=params,
63
- timeout=10
64
- )
65
- results = response.json()
66
- if 'organic' in results:
67
- return "\n".join([f"{r['title']}: {r['snippet']}" for r in results['organic'][:3]])[:MAX_RESULT_LENGTH]
68
- return "No search results found"
69
- else:
70
  return "Search API key not configured"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
  except Exception as e:
72
  return f"Search error: {str(e)}"
73
 
74
  def calculator(expression: str) -> str:
75
- """Safe mathematical evaluation"""
76
  try:
77
  expression = re.sub(r'[^\d+\-*/().^%,\s]', '', expression)
78
  if not expression:
@@ -82,17 +88,22 @@ def calculator(expression: str) -> str:
82
  return f"Calculation error: {str(e)}"
83
 
84
  def read_webpage(url: str) -> str:
85
- """Robust webpage content extraction"""
86
  try:
 
 
 
87
  headers = {'User-Agent': 'Mozilla/5.0'}
88
- response = requests.get(url, timeout=10, headers=headers)
89
- soup = BeautifulSoup(response.text, 'html.parser')
90
 
91
- for element in soup(['script', 'style', 'nav', 'footer']):
 
92
  element.decompose()
93
 
94
- text = soup.get_text(separator='\n', strip=True)
95
- return re.sub(r'\n{3,}', '\n\n', text)[:MAX_RESULT_LENGTH]
 
 
96
  except Exception as e:
97
  return f"Webpage error: {str(e)}"
98
 
@@ -102,25 +113,26 @@ TOOLS = {
102
  "read_webpage": read_webpage
103
  }
104
 
105
- # --- Fixed GAIA Agent ---
106
  class GAIA_Agent:
107
  def __init__(self):
108
  self.tools = TOOLS
109
- self.system_prompt = """You are an advanced GAIA problem solver. Follow these steps:
110
  1. Analyze the question
111
- 2. Choose the best tool
112
- 3. Process results
113
- 4. Provide final answer
 
114
 
115
  Tools:
116
- - web_search: For general knowledge
117
- - calculator: For math
118
- - read_webpage: For web content
119
 
120
  Tool format: ```json
121
- {"tool": "tool_name", "args": {"arg1": value}}```
122
 
123
- Always end with: Final Answer: [answer]"""
124
 
125
  def __call__(self, question: str) -> str:
126
  start_time = time.time()
@@ -150,61 +162,75 @@ Always end with: Final Answer: [answer]"""
150
 
151
  return "Maximum steps reached"
152
  except Exception as e:
153
- return f"Error: {str(e)}"
154
 
155
  def _build_prompt(self, history: List[str]) -> str:
156
  return f"<|system|>\n{self.system_prompt}<|end|>\n<|user|>\n" + "\n".join(history) + "<|end|>\n<|assistant|>"
157
 
158
  def _call_model(self, prompt: str) -> str:
159
- inputs = tokenizer(
160
- prompt,
161
- return_tensors="pt",
162
- truncation=True,
163
- max_length=3072,
164
- padding=False
165
- )
166
-
167
- # Fixed generation config without problematic parameters
168
- outputs = model.generate(
169
- inputs.input_ids,
170
- max_new_tokens=MAX_TOKENS,
171
- temperature=0.3,
172
- top_p=0.9,
173
- do_sample=True,
174
- pad_token_id=tokenizer.pad_token_id,
175
- attention_mask=inputs.attention_mask
176
- )
177
-
178
- return tokenizer.decode(outputs[0], skip_special_tokens=True).split("<|assistant|>")[-1].strip()
 
 
 
 
 
 
179
 
180
  def _parse_tool_call(self, text: str) -> Optional[Tuple[str, Dict]]:
181
  try:
182
  json_match = re.search(r'```json\s*({.+?})\s*```', text, re.DOTALL)
183
- if json_match:
184
- tool_call = json.loads(json_match.group(1))
185
- if "tool" in tool_call and "args" in tool_call:
186
- return tool_call["tool"], tool_call["args"]
 
 
 
 
 
 
 
 
187
  except:
188
  return None
189
- return None
190
 
191
  def _use_tool(self, tool_name: str, args: Dict) -> str:
192
  if tool_name not in self.tools:
193
  return f"Unknown tool: {tool_name}"
194
 
195
  try:
196
- # Handle URL extraction for webpage reading
197
  if tool_name == "read_webpage" and "url" not in args:
198
- if "http" in str(args):
199
- url = re.search(r'https?://[^\s]+', str(args)).group()
200
- args = {"url": url}
 
 
201
 
202
  return str(self.tools[tool_name](**args))[:MAX_RESULT_LENGTH]
203
  except Exception as e:
204
  return f"Tool error: {str(e)}"
205
 
206
- # --- Evaluation Runner ---
207
- def run_and_submit_all(profile: gr.OAuthProfile | None):
208
  if not profile:
209
  return "Please login first", None
210
 
@@ -213,8 +239,11 @@ def run_and_submit_all(profile: gr.OAuthProfile | None):
213
  submit_url = f"{DEFAULT_API_URL}/submit"
214
 
215
  try:
216
- response = requests.get(questions_url, timeout=15)
 
217
  questions_data = response.json()
 
 
218
  except Exception as e:
219
  return f"Failed to get questions: {str(e)}", None
220
 
@@ -245,28 +274,34 @@ def run_and_submit_all(profile: gr.OAuthProfile | None):
245
  }
246
 
247
  try:
248
- response = requests.post(submit_url, json=submission, timeout=30)
 
249
  result = response.json()
250
- return f"Submitted! Score: {result.get('score', 'N/A')}", pd.DataFrame(results)
 
 
 
251
  except Exception as e:
252
- return f"Submission failed: {str(e)}", pd.DataFrame(results)
253
 
254
  # --- Gradio Interface ---
255
- with gr.Blocks(title="Fixed GAIA Agent") as demo:
256
- gr.Markdown("## 🛠️ Fixed GAIA Agent")
257
- gr.Markdown("Resolved the 'DynamicCache' error with improved configuration")
258
 
259
  with gr.Row():
260
  gr.LoginButton()
261
  run_btn = gr.Button("Run Evaluation", variant="primary")
262
 
263
- output_status = gr.Textbox(label="Status")
264
  results_table = gr.DataFrame(label="Results")
265
 
266
  run_btn.click(
267
- run_and_submit_all,
268
- outputs=[output_status, results_table]
269
  )
270
 
271
  if __name__ == "__main__":
272
- demo.launch(server_name="0.0.0.0", server_port=7860)
 
 
 
 
9
  from bs4 import BeautifulSoup
10
  from typing import List, Dict, Optional, Tuple
11
  from dotenv import load_dotenv
12
+ from transformers import AutoModelForCausalLM, AutoTokenizer
13
  import torch
14
  import time
15
  import gc
16
 
17
+ # --- Configuration ---
18
  load_dotenv()
19
  SERPER_API_KEY = os.getenv("SERPER_API_KEY")
20
+ MODEL_NAME = "microsoft/Phi-3-mini-4k-instruct"
21
+ DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
22
 
23
  # --- Constants ---
 
24
  MAX_STEPS = 6
25
  MAX_TOKENS = 256
 
26
  TIMEOUT_PER_QUESTION = 45
27
  MAX_RESULT_LENGTH = 500
28
+ MAX_ATTEMPTS = 2
29
 
30
+ # --- Model Initialization ---
31
+ print("Initializing model with fixed cache configuration...")
32
  start_time = time.time()
33
 
34
  model = AutoModelForCausalLM.from_pretrained(
 
50
 
51
  print(f"Model loaded in {time.time() - start_time:.2f} seconds")
52
 
53
+ # --- Tool Implementations ---
54
  def web_search(query: str) -> str:
 
55
  try:
56
+ if not SERPER_API_KEY:
 
 
 
 
 
 
 
 
 
 
 
 
 
57
  return "Search API key not configured"
58
+
59
+ params = {'q': query, 'num': 3}
60
+ headers = {'X-API-KEY': SERPER_API_KEY}
61
+ response = requests.post(
62
+ 'https://google.serper.dev/search',
63
+ headers=headers,
64
+ json=params,
65
+ timeout=10
66
+ )
67
+ response.raise_for_status()
68
+ results = response.json()
69
+
70
+ if 'organic' not in results or not results['organic']:
71
+ return "No relevant results found"
72
+
73
+ output = []
74
+ for r in results['organic'][:3]:
75
+ if 'title' in r and 'snippet' in r:
76
+ output.append(f"Title: {r['title']}\nSnippet: {r['snippet']}")
77
+ return "\n\n".join(output)[:MAX_RESULT_LENGTH]
78
  except Exception as e:
79
  return f"Search error: {str(e)}"
80
 
81
  def calculator(expression: str) -> str:
 
82
  try:
83
  expression = re.sub(r'[^\d+\-*/().^%,\s]', '', expression)
84
  if not expression:
 
88
  return f"Calculation error: {str(e)}"
89
 
90
  def read_webpage(url: str) -> str:
 
91
  try:
92
+ if not re.match(r'^https?://', url):
93
+ return "Invalid URL format"
94
+
95
  headers = {'User-Agent': 'Mozilla/5.0'}
96
+ response = requests.get(url, timeout=15, headers=headers)
97
+ response.raise_for_status()
98
 
99
+ soup = BeautifulSoup(response.text, 'html.parser')
100
+ for element in soup(['script', 'style', 'nav', 'footer', 'aside']):
101
  element.decompose()
102
 
103
+ main_content = soup.find('main') or soup.find('article') or soup
104
+ text = main_content.get_text(separator='\n', strip=True)
105
+ text = re.sub(r'\n{3,}', '\n\n', text)
106
+ return text[:MAX_RESULT_LENGTH]
107
  except Exception as e:
108
  return f"Webpage error: {str(e)}"
109
 
 
113
  "read_webpage": read_webpage
114
  }
115
 
116
+ # --- GAIA Agent Class ---
117
  class GAIA_Agent:
118
  def __init__(self):
119
  self.tools = TOOLS
120
+ self.system_prompt = """You are an advanced problem solver. Follow these steps:
121
  1. Analyze the question
122
+ 2. Select the best tool
123
+ 3. Execute with proper arguments
124
+ 4. Interpret results
125
+ 5. Provide final answer
126
 
127
  Tools:
128
+ - web_search(query): For general knowledge
129
+ - calculator(expression): For math
130
+ - read_webpage(url): For web content
131
 
132
  Tool format: ```json
133
+ {"tool": "tool_name", "args": {"arg": value}}```
134
 
135
+ Always conclude with: Final Answer: [answer]"""
136
 
137
  def __call__(self, question: str) -> str:
138
  start_time = time.time()
 
162
 
163
  return "Maximum steps reached"
164
  except Exception as e:
165
+ return f"Agent error: {str(e)}"
166
 
167
  def _build_prompt(self, history: List[str]) -> str:
168
  return f"<|system|>\n{self.system_prompt}<|end|>\n<|user|>\n" + "\n".join(history) + "<|end|>\n<|assistant|>"
169
 
170
  def _call_model(self, prompt: str) -> str:
171
+ for attempt in range(MAX_ATTEMPTS):
172
+ try:
173
+ inputs = tokenizer(
174
+ prompt,
175
+ return_tensors="pt",
176
+ truncation=True,
177
+ max_length=3072,
178
+ padding=False
179
+ )
180
+
181
+ outputs = model.generate(
182
+ inputs.input_ids,
183
+ max_new_tokens=MAX_TOKENS,
184
+ temperature=0.3,
185
+ top_p=0.9,
186
+ do_sample=True,
187
+ pad_token_id=tokenizer.pad_token_id,
188
+ attention_mask=inputs.attention_mask
189
+ )
190
+
191
+ return tokenizer.decode(outputs[0], skip_special_tokens=True).split("<|assistant|>")[-1].strip()
192
+ except Exception as e:
193
+ if attempt < MAX_ATTEMPTS - 1:
194
+ time.sleep(0.5)
195
+ continue
196
+ return f"Model error: {str(e)}"
197
 
198
  def _parse_tool_call(self, text: str) -> Optional[Tuple[str, Dict]]:
199
  try:
200
  json_match = re.search(r'```json\s*({.+?})\s*```', text, re.DOTALL)
201
+ if not json_match:
202
+ return None
203
+
204
+ tool_call = json.loads(json_match.group(1))
205
+ if not isinstance(tool_call, dict):
206
+ return None
207
+ if "tool" not in tool_call or "args" not in tool_call:
208
+ return None
209
+ if not isinstance(tool_call["args"], dict):
210
+ return None
211
+
212
+ return tool_call["tool"], tool_call["args"]
213
  except:
214
  return None
 
215
 
216
  def _use_tool(self, tool_name: str, args: Dict) -> str:
217
  if tool_name not in self.tools:
218
  return f"Unknown tool: {tool_name}"
219
 
220
  try:
 
221
  if tool_name == "read_webpage" and "url" not in args:
222
+ url_match = re.search(r'https?://[^\s]+', str(args))
223
+ if url_match:
224
+ args = {"url": url_match.group()}
225
+ else:
226
+ return "Missing URL argument"
227
 
228
  return str(self.tools[tool_name](**args))[:MAX_RESULT_LENGTH]
229
  except Exception as e:
230
  return f"Tool error: {str(e)}"
231
 
232
+ # --- Evaluation Function ---
233
+ def run_evaluation(profile: gr.OAuthProfile | None):
234
  if not profile:
235
  return "Please login first", None
236
 
 
239
  submit_url = f"{DEFAULT_API_URL}/submit"
240
 
241
  try:
242
+ response = requests.get(questions_url, timeout=20)
243
+ response.raise_for_status()
244
  questions_data = response.json()
245
+ if not questions_data:
246
+ return "No questions available", None
247
  except Exception as e:
248
  return f"Failed to get questions: {str(e)}", None
249
 
 
274
  }
275
 
276
  try:
277
+ response = requests.post(submit_url, json=submission, timeout=60)
278
+ response.raise_for_status()
279
  result = response.json()
280
+ status = (f" Submission Successful!\n"
281
+ f"Score: {result.get('score', 'N/A')}%\n"
282
+ f"Correct: {result.get('correct_count', '?')}/{result.get('total_attempted', '?')}")
283
+ return status, pd.DataFrame(results)
284
  except Exception as e:
285
+ return f"Submission failed: {str(e)}", pd.DataFrame(results)
286
 
287
  # --- Gradio Interface ---
288
+ with gr.Blocks(title="Fixed GAIA Agent", theme=gr.themes.Soft()) as demo:
289
+ gr.Markdown("# 🚀 GAIA Agent Evaluation")
 
290
 
291
  with gr.Row():
292
  gr.LoginButton()
293
  run_btn = gr.Button("Run Evaluation", variant="primary")
294
 
295
+ status_output = gr.Textbox(label="Status")
296
  results_table = gr.DataFrame(label="Results")
297
 
298
  run_btn.click(
299
+ run_evaluation,
300
+ outputs=[status_output, results_table]
301
  )
302
 
303
  if __name__ == "__main__":
304
+ demo.launch(
305
+ server_name="0.0.0.0",
306
+ server_port=7860
307
+ )