LamiaYT commited on
Commit
24ec680
·
1 Parent(s): cac5b18
Files changed (1) hide show
  1. app.py +69 -22
app.py CHANGED
@@ -13,24 +13,35 @@ from urllib.parse import urlparse, parse_qs
13
 
14
  # --- Constants ---
15
  DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
16
- WIKIPEDIA_API_KEY = os.getenv("WIKIPEDIA_API_KEY", "default_key")
17
  MODEL_ID = "HuggingFaceTB/SmolLM-135M-Instruct"
18
 
19
  # --- Initialize Model ---
20
  print("Loading model...")
21
  try:
 
22
  model = AutoModelForCausalLM.from_pretrained(
23
  MODEL_ID,
24
  torch_dtype="auto",
25
  device_map="auto",
26
- attn_implementation="flash_attention_2",
27
  )
28
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
 
 
 
 
 
29
  print("✅ Model loaded successfully")
30
  except Exception as e:
31
  print(f"❌ Failed to load model: {e}")
32
  raise
33
 
 
 
 
 
 
 
34
  # --- Enhanced Tools with Rate Limiting ---
35
 
36
  @tool
@@ -70,6 +81,7 @@ def smart_web_search(query: str) -> str:
70
  except Exception as e:
71
  print(f"Serper API failed: {e}")
72
 
 
73
  if any(term in query.lower() for term in ["wikipedia", "who", "what", "when", "where"]):
74
  return get_wikipedia_info(query)
75
 
@@ -83,10 +95,12 @@ def smart_web_search(query: str) -> str:
83
 
84
  @tool
85
  def get_wikipedia_info(query: str) -> str:
86
- """Enhanced Wikipedia search with API key support."""
87
  try:
 
88
  clean_query = re.sub(r'[^a-zA-Z0-9 ]', '', query)[:100]
89
 
 
90
  params = {
91
  'action': 'query',
92
  'format': 'json',
@@ -97,13 +111,11 @@ def get_wikipedia_info(query: str) -> str:
97
  'utf8': 1
98
  }
99
 
100
- if WIKIPEDIA_API_KEY and WIKIPEDIA_API_KEY != "default_key":
101
- params['apikey'] = WIKIPEDIA_API_KEY
102
-
103
  response = requests.get(
104
  "https://en.wikipedia.org/w/api.php",
105
  params=params,
106
- timeout=10
 
107
  )
108
 
109
  if response.status_code == 200:
@@ -118,9 +130,14 @@ def get_wikipedia_info(query: str) -> str:
118
  if results:
119
  return "\n\n".join(results)
120
 
 
121
  page_title = clean_query.replace(' ', '_')
122
  extract_url = f"https://en.wikipedia.org/api/rest_v1/page/summary/{page_title}"
123
- extract_response = requests.get(extract_url, timeout=8)
 
 
 
 
124
 
125
  if extract_response.status_code == 200:
126
  extract_data = extract_response.json()
@@ -153,6 +170,7 @@ def extract_youtube_details(url: str) -> str:
153
 
154
  results = []
155
 
 
156
  try:
157
  oembed_url = f"https://www.youtube.com/oembed?url=https://www.youtube.com/watch?v={video_id}&format=json"
158
  response = requests.get(oembed_url, timeout=10)
@@ -165,16 +183,18 @@ def extract_youtube_details(url: str) -> str:
165
  except Exception as e:
166
  print(f"oEmbed failed: {e}")
167
 
 
168
  try:
169
  video_url = f"https://www.youtube.com/watch?v={video_id}"
170
  headers = {
171
- 'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36'
172
  }
173
  page_response = requests.get(video_url, headers=headers, timeout=15)
174
 
175
  if page_response.status_code == 200:
176
  content = page_response.text
177
 
 
178
  bird_patterns = [
179
  r'(\d+)\s+bird\s+species',
180
  r'(\d+)\s+species\s+of\s+bird',
@@ -195,6 +215,7 @@ def extract_youtube_details(url: str) -> str:
195
  max_species = max(numbers)
196
  results.append(f"BIRD_SPECIES_COUNT: {max_species}")
197
 
 
198
  view_match = re.search(r'"viewCount":"(\d+)"', content)
199
  if view_match:
200
  views = int(view_match.group(1))
@@ -245,6 +266,7 @@ def solve_advanced_math(problem: str) -> str:
245
  try:
246
  problem_lower = problem.lower()
247
 
 
248
  if "commutative" in problem_lower and "|" in problem:
249
  lines = problem.split('\n')
250
  table_lines = [line for line in lines if '|' in line and any(x in line for x in ['a', 'b', 'c', 'd', 'e'])]
@@ -275,12 +297,14 @@ def solve_advanced_math(problem: str) -> str:
275
  result = sorted(list(breaking_elements))
276
  return ', '.join(result) if result else "No elements break commutativity"
277
 
 
278
  elif "chess" in problem_lower or "move" in problem_lower:
279
  chess_moves = re.findall(r'\b[KQRBN]?[a-h]?[1-8]?x?[a-h][1-8][+#]?\b', problem)
280
  if chess_moves:
281
  return f"Chess moves found: {', '.join(chess_moves)}"
282
  return "Analyze position for best move: check for tactics, threats, and forcing moves"
283
 
 
284
  numbers = re.findall(r'-?\d+\.?\d*', problem)
285
  if numbers:
286
  nums = [float(n) for n in numbers if n.replace('.', '').replace('-', '').isdigit()]
@@ -300,6 +324,7 @@ def solve_advanced_math(problem: str) -> str:
300
  result *= n
301
  return str(result)
302
 
 
303
  if "%" in problem or "percent" in problem_lower:
304
  percentages = re.findall(r'(\d+\.?\d*)%', problem)
305
  if percentages:
@@ -325,14 +350,25 @@ class OptimizedGAIAAgent:
325
  def generate_with_model(self, prompt: str) -> str:
326
  """Generate response using the SmolLM model"""
327
  try:
328
- inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
329
- outputs = model.generate(
330
- **inputs,
331
- max_new_tokens=256,
332
- temperature=0.7,
333
- do_sample=True
334
- )
335
- return tokenizer.decode(outputs[0], skip_special_tokens=True)
 
 
 
 
 
 
 
 
 
 
 
336
  except Exception as e:
337
  print(f"Model generation failed: {e}")
338
  return ""
@@ -341,9 +377,11 @@ class OptimizedGAIAAgent:
341
  """Analyze question type and provide targeted solution"""
342
  question_lower = question.lower()
343
 
 
344
  if "ecnetnes siht dnatsrednu uoy fi" in question_lower:
345
  return decode_reversed_text(question)
346
 
 
347
  if "youtube.com" in question or "youtu.be" in question:
348
  url_match = re.search(r'https?://(?:www\.)?(?:youtube\.com/watch\?v=|youtu\.be/)([a-zA-Z0-9_-]+)', question)
349
  if url_match:
@@ -351,24 +389,29 @@ class OptimizedGAIAAgent:
351
  if "highest number" in question_lower and "bird species" in question_lower:
352
  numbers = re.findall(r'BIRD_SPECIES_COUNT:\s*(\d+)', result)
353
  if numbers:
354
- return max([int(x) for x in numbers])
355
  return result
356
 
 
357
  if any(term in question_lower for term in ["commutative", "operation", "table", "chess", "checkmate"]):
358
  return solve_advanced_math(question)
359
 
 
360
  if any(term in question_lower for term in ["who", "what", "when", "where", "wikipedia", "article"]):
361
  return get_wikipedia_info(question)
362
 
 
363
  if "olympics" in question_lower or "1928" in question:
364
  return get_wikipedia_info("1928 Summer Olympics")
365
 
 
366
  return smart_web_search(question)
367
 
368
  def solve(self, question: str) -> str:
369
  """Main solving method with fallback chain"""
370
  print(f"Solving: {question[:80]}...")
371
 
 
372
  try:
373
  direct_result = self.analyze_and_solve(question)
374
  if direct_result and len(str(direct_result).strip()) > 3:
@@ -376,13 +419,14 @@ class OptimizedGAIAAgent:
376
  except Exception as e:
377
  print(f"Direct analysis failed: {e}")
378
 
 
379
  try:
380
  time.sleep(2)
381
- prompt = f"""Answer the following question using available tools and knowledge:
382
 
383
  Question: {question}
384
 
385
- Think step by step and provide a detailed answer:"""
386
 
387
  result = self.generate_with_model(prompt)
388
  if result and len(str(result).strip()) > 3:
@@ -390,6 +434,7 @@ Think step by step and provide a detailed answer:"""
390
  except Exception as e:
391
  print(f"Model generation failed: {e}")
392
 
 
393
  time.sleep(3)
394
  return smart_web_search(question)
395
 
@@ -455,6 +500,7 @@ def run_evaluation(profile: gr.OAuthProfile | None):
455
 
456
  print(f"{status} Answer: {str(answer)[:100]}")
457
 
 
458
  time.sleep(random.uniform(2, 4))
459
 
460
  except Exception as e:
@@ -472,6 +518,7 @@ def run_evaluation(profile: gr.OAuthProfile | None):
472
  })
473
  print(f"❌ Error: {e}")
474
 
 
475
  space_id = os.getenv("SPACE_ID", "unknown")
476
  submission = {
477
  "username": username,
@@ -507,7 +554,7 @@ def run_evaluation(profile: gr.OAuthProfile | None):
507
  # --- Gradio Interface ---
508
  with gr.Blocks(title="Optimized GAIA Agent", theme=gr.themes.Soft()) as demo:
509
  gr.Markdown("# 🎯 Optimized GAIA Agent")
510
- gr.Markdown("**SmolLM-135M-Instruct • Rate-limited search • Pattern recognition**")
511
 
512
  with gr.Row():
513
  gr.LoginButton()
@@ -532,7 +579,7 @@ with gr.Blocks(title="Optimized GAIA Agent", theme=gr.themes.Soft()) as demo:
532
  if __name__ == "__main__":
533
  print("🎯 Starting Optimized GAIA Agent...")
534
 
535
- env_vars = ["SPACE_ID", "SERPER_API_KEY", "WIKIPEDIA_API_KEY"]
536
  for var in env_vars:
537
  status = "✅" if os.getenv(var) else "⚠️"
538
  print(f"{status} {var}")
 
13
 
14
  # --- Constants ---
15
  DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
 
16
  MODEL_ID = "HuggingFaceTB/SmolLM-135M-Instruct"
17
 
18
  # --- Initialize Model ---
19
  print("Loading model...")
20
  try:
21
+ # Remove flash_attention_2 to avoid dependency issues
22
  model = AutoModelForCausalLM.from_pretrained(
23
  MODEL_ID,
24
  torch_dtype="auto",
25
  device_map="auto",
26
+ # Removed attn_implementation="flash_attention_2"
27
  )
28
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
29
+
30
+ # Add padding token if not present
31
+ if tokenizer.pad_token is None:
32
+ tokenizer.pad_token = tokenizer.eos_token
33
+
34
  print("✅ Model loaded successfully")
35
  except Exception as e:
36
  print(f"❌ Failed to load model: {e}")
37
  raise
38
 
39
+ # --- Tool Decorator ---
40
+ def tool(func):
41
+ """Simple tool decorator"""
42
+ func._is_tool = True
43
+ return func
44
+
45
  # --- Enhanced Tools with Rate Limiting ---
46
 
47
  @tool
 
81
  except Exception as e:
82
  print(f"Serper API failed: {e}")
83
 
84
+ # Fallback to Wikipedia for knowledge queries
85
  if any(term in query.lower() for term in ["wikipedia", "who", "what", "when", "where"]):
86
  return get_wikipedia_info(query)
87
 
 
95
 
96
  @tool
97
  def get_wikipedia_info(query: str) -> str:
98
+ """Enhanced Wikipedia search without API key requirement."""
99
  try:
100
+ # Clean the query
101
  clean_query = re.sub(r'[^a-zA-Z0-9 ]', '', query)[:100]
102
 
103
+ # Use Wikipedia API without API key (public access)
104
  params = {
105
  'action': 'query',
106
  'format': 'json',
 
111
  'utf8': 1
112
  }
113
 
 
 
 
114
  response = requests.get(
115
  "https://en.wikipedia.org/w/api.php",
116
  params=params,
117
+ timeout=10,
118
+ headers={'User-Agent': 'GAIA-Agent/1.0'}
119
  )
120
 
121
  if response.status_code == 200:
 
130
  if results:
131
  return "\n\n".join(results)
132
 
133
+ # Fallback to REST API
134
  page_title = clean_query.replace(' ', '_')
135
  extract_url = f"https://en.wikipedia.org/api/rest_v1/page/summary/{page_title}"
136
+ extract_response = requests.get(
137
+ extract_url,
138
+ timeout=8,
139
+ headers={'User-Agent': 'GAIA-Agent/1.0'}
140
+ )
141
 
142
  if extract_response.status_code == 200:
143
  extract_data = extract_response.json()
 
170
 
171
  results = []
172
 
173
+ # Try oEmbed API first
174
  try:
175
  oembed_url = f"https://www.youtube.com/oembed?url=https://www.youtube.com/watch?v={video_id}&format=json"
176
  response = requests.get(oembed_url, timeout=10)
 
183
  except Exception as e:
184
  print(f"oEmbed failed: {e}")
185
 
186
+ # Try to extract additional info from page
187
  try:
188
  video_url = f"https://www.youtube.com/watch?v={video_id}"
189
  headers = {
190
+ 'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36'
191
  }
192
  page_response = requests.get(video_url, headers=headers, timeout=15)
193
 
194
  if page_response.status_code == 200:
195
  content = page_response.text
196
 
197
+ # Look for bird species mentions
198
  bird_patterns = [
199
  r'(\d+)\s+bird\s+species',
200
  r'(\d+)\s+species\s+of\s+bird',
 
215
  max_species = max(numbers)
216
  results.append(f"BIRD_SPECIES_COUNT: {max_species}")
217
 
218
+ # Extract view count
219
  view_match = re.search(r'"viewCount":"(\d+)"', content)
220
  if view_match:
221
  views = int(view_match.group(1))
 
266
  try:
267
  problem_lower = problem.lower()
268
 
269
+ # Handle commutative operation tables
270
  if "commutative" in problem_lower and "|" in problem:
271
  lines = problem.split('\n')
272
  table_lines = [line for line in lines if '|' in line and any(x in line for x in ['a', 'b', 'c', 'd', 'e'])]
 
297
  result = sorted(list(breaking_elements))
298
  return ', '.join(result) if result else "No elements break commutativity"
299
 
300
+ # Handle chess problems
301
  elif "chess" in problem_lower or "move" in problem_lower:
302
  chess_moves = re.findall(r'\b[KQRBN]?[a-h]?[1-8]?x?[a-h][1-8][+#]?\b', problem)
303
  if chess_moves:
304
  return f"Chess moves found: {', '.join(chess_moves)}"
305
  return "Analyze position for best move: check for tactics, threats, and forcing moves"
306
 
307
+ # Handle basic arithmetic
308
  numbers = re.findall(r'-?\d+\.?\d*', problem)
309
  if numbers:
310
  nums = [float(n) for n in numbers if n.replace('.', '').replace('-', '').isdigit()]
 
324
  result *= n
325
  return str(result)
326
 
327
+ # Handle percentages
328
  if "%" in problem or "percent" in problem_lower:
329
  percentages = re.findall(r'(\d+\.?\d*)%', problem)
330
  if percentages:
 
350
  def generate_with_model(self, prompt: str) -> str:
351
  """Generate response using the SmolLM model"""
352
  try:
353
+ inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True, max_length=512)
354
+
355
+ # Move inputs to same device as model
356
+ inputs = {k: v.to(model.device) for k, v in inputs.items()}
357
+
358
+ with torch.no_grad():
359
+ outputs = model.generate(
360
+ **inputs,
361
+ max_new_tokens=256,
362
+ temperature=0.7,
363
+ do_sample=True,
364
+ pad_token_id=tokenizer.eos_token_id
365
+ )
366
+
367
+ # Decode only the new tokens
368
+ new_tokens = outputs[0][inputs['input_ids'].shape[1]:]
369
+ response = tokenizer.decode(new_tokens, skip_special_tokens=True)
370
+ return response.strip()
371
+
372
  except Exception as e:
373
  print(f"Model generation failed: {e}")
374
  return ""
 
377
  """Analyze question type and provide targeted solution"""
378
  question_lower = question.lower()
379
 
380
+ # Handle reversed text
381
  if "ecnetnes siht dnatsrednu uoy fi" in question_lower:
382
  return decode_reversed_text(question)
383
 
384
+ # Handle YouTube links
385
  if "youtube.com" in question or "youtu.be" in question:
386
  url_match = re.search(r'https?://(?:www\.)?(?:youtube\.com/watch\?v=|youtu\.be/)([a-zA-Z0-9_-]+)', question)
387
  if url_match:
 
389
  if "highest number" in question_lower and "bird species" in question_lower:
390
  numbers = re.findall(r'BIRD_SPECIES_COUNT:\s*(\d+)', result)
391
  if numbers:
392
+ return str(max([int(x) for x in numbers]))
393
  return result
394
 
395
+ # Handle math problems
396
  if any(term in question_lower for term in ["commutative", "operation", "table", "chess", "checkmate"]):
397
  return solve_advanced_math(question)
398
 
399
+ # Handle knowledge questions
400
  if any(term in question_lower for term in ["who", "what", "when", "where", "wikipedia", "article"]):
401
  return get_wikipedia_info(question)
402
 
403
+ # Handle Olympics queries
404
  if "olympics" in question_lower or "1928" in question:
405
  return get_wikipedia_info("1928 Summer Olympics")
406
 
407
+ # Default to web search
408
  return smart_web_search(question)
409
 
410
  def solve(self, question: str) -> str:
411
  """Main solving method with fallback chain"""
412
  print(f"Solving: {question[:80]}...")
413
 
414
+ # Try direct analysis first
415
  try:
416
  direct_result = self.analyze_and_solve(question)
417
  if direct_result and len(str(direct_result).strip()) > 3:
 
419
  except Exception as e:
420
  print(f"Direct analysis failed: {e}")
421
 
422
+ # Try model generation
423
  try:
424
  time.sleep(2)
425
+ prompt = f"""Answer the following question concisely and accurately:
426
 
427
  Question: {question}
428
 
429
+ Answer:"""
430
 
431
  result = self.generate_with_model(prompt)
432
  if result and len(str(result).strip()) > 3:
 
434
  except Exception as e:
435
  print(f"Model generation failed: {e}")
436
 
437
+ # Final fallback to web search
438
  time.sleep(3)
439
  return smart_web_search(question)
440
 
 
500
 
501
  print(f"{status} Answer: {str(answer)[:100]}")
502
 
503
+ # Rate limiting
504
  time.sleep(random.uniform(2, 4))
505
 
506
  except Exception as e:
 
518
  })
519
  print(f"❌ Error: {e}")
520
 
521
+ # Submit results
522
  space_id = os.getenv("SPACE_ID", "unknown")
523
  submission = {
524
  "username": username,
 
554
  # --- Gradio Interface ---
555
  with gr.Blocks(title="Optimized GAIA Agent", theme=gr.themes.Soft()) as demo:
556
  gr.Markdown("# 🎯 Optimized GAIA Agent")
557
+ gr.Markdown("**SmolLM-135M-Instruct • Wikipedia Search • Pattern Recognition**")
558
 
559
  with gr.Row():
560
  gr.LoginButton()
 
579
  if __name__ == "__main__":
580
  print("🎯 Starting Optimized GAIA Agent...")
581
 
582
+ env_vars = ["SPACE_ID", "SERPER_API_KEY"]
583
  for var in env_vars:
584
  status = "✅" if os.getenv(var) else "⚠️"
585
  print(f"{status} {var}")