File size: 19,297 Bytes
34c5bf3
ca2b63a
 
 
34c5bf3
 
574b6ca
 
 
 
a42d6f7
51e7f46
26e4907
34c5bf3
10e9b7d
a42d6f7
 
 
 
 
 
 
 
34c5bf3
a42d6f7
 
 
 
 
757ebd9
e80aab9
3db6293
e80aab9
34c5bf3
 
 
 
 
 
 
 
 
 
 
 
 
31243f4
34c5bf3
a42d6f7
34c5bf3
 
8f6825e
34c5bf3
 
51e7f46
34c5bf3
 
51e7f46
34c5bf3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6ea9560
34c5bf3
 
 
 
 
ca2b63a
 
34c5bf3
 
 
ca2b63a
 
34c5bf3
 
 
 
 
 
 
 
ca2b63a
 
 
34c5bf3
 
 
a42d6f7
 
34c5bf3
a42d6f7
757ebd9
 
34c5bf3
 
26e4907
34c5bf3
6ea9560
34c5bf3
 
 
 
 
 
 
 
 
 
 
26e4907
34c5bf3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
757ebd9
6ea9560
 
ca2b63a
34c5bf3
 
 
26e4907
8f6825e
34c5bf3
8f6825e
34c5bf3
8f6825e
6ea9560
34c5bf3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6ea9560
34c5bf3
6ea9560
34c5bf3
6ea9560
8f6825e
 
ca2b63a
34c5bf3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6ea9560
34c5bf3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6ea9560
 
34c5bf3
6ea9560
34c5bf3
6ea9560
34c5bf3
 
 
 
6ea9560
34c5bf3
 
6ea9560
 
34c5bf3
 
 
 
 
 
8f6825e
c549c70
34c5bf3
 
 
 
 
c549c70
34c5bf3
 
 
 
 
26e4907
34c5bf3
 
 
 
757ebd9
8f6825e
34c5bf3
8f6825e
 
6ea9560
51e7f46
ca2b63a
34c5bf3
6ea9560
8f6825e
6ea9560
8f6825e
6ea9560
8f6825e
3c4371f
6ea9560
7e4a06b
31243f4
 
6ea9560
8f6825e
 
34c5bf3
31243f4
34c5bf3
 
 
31243f4
34c5bf3
 
 
757ebd9
6ea9560
 
36ed51a
3c4371f
8f6825e
eccf8e4
6ea9560
8f6825e
7d65c66
31243f4
6ea9560
7d65c66
6ea9560
e80aab9
6ea9560
7d65c66
 
a42d6f7
6ea9560
34c5bf3
6ea9560
 
a42d6f7
31243f4
8f6825e
a42d6f7
8f6825e
31243f4
a42d6f7
6ea9560
 
34c5bf3
a42d6f7
31243f4
34c5bf3
8f6825e
 
34c5bf3
 
 
8f6825e
34c5bf3
6ea9560
 
26e4907
6ea9560
8f6825e
26e4907
8f6825e
a42d6f7
26e4907
34c5bf3
 
a42d6f7
51e7f46
6ea9560
34c5bf3
8f6825e
51e7f46
31243f4
6ea9560
34c5bf3
6ea9560
26e4907
6ea9560
8f6825e
26e4907
6ea9560
a42d6f7
26e4907
34c5bf3
8f6825e
a42d6f7
31243f4
6ea9560
 
8f6825e
a42d6f7
6ea9560
26e4907
a42d6f7
 
 
e80aab9
34c5bf3
e80aab9
8f6825e
a42d6f7
8f6825e
 
 
6ea9560
8f6825e
6ea9560
34c5bf3
8f6825e
6ea9560
34c5bf3
6ea9560
 
34c5bf3
8f6825e
6ea9560
 
34c5bf3
 
 
 
 
 
 
6ea9560
8f6825e
6ea9560
8f6825e
a42d6f7
7d65c66
8f6825e
26e4907
 
e80aab9
6ea9560
34c5bf3
 
26e4907
34c5bf3
 
 
 
 
 
 
26e4907
34c5bf3
8f6825e
 
a42d6f7
6ea9560
a42d6f7
8f6825e
 
34c5bf3
8f6825e
6ea9560
8f6825e
a42d6f7
8f6825e
6ea9560
34c5bf3
6ea9560
a42d6f7
 
 
34c5bf3
26e4907
a42d6f7
e80aab9
8f6825e
31243f4
8f6825e
e80aab9
 
 
34c5bf3
 
a42d6f7
 
8f6825e
 
a42d6f7
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
# app.py - Improved GAIA Agent with GPT-NeoX-20B + LoRA
import ast
import gc
import json
import operator
import os
import re
import traceback

import gradio as gr
import pandas as pd
import requests
import torch
from llama_index.core.agent import ReActAgent
from llama_index.core.tools import FunctionTool
from llama_index.llms.huggingface import HuggingFaceLLM
from peft import LoraConfig, get_peft_model
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

# Import real tool dependencies
try:
    from duckduckgo_search import DDGS
except ImportError:
    print("Warning: duckduckgo_search not installed. Web search will be limited.")
    DDGS = None

try:
    from sympy import sympify, solve, simplify, N, symbols
    from sympy.core.sympify import SympifyError
except ImportError:
    print("Warning: sympy not installed. Math calculator will be limited.")
    sympify = None
    SympifyError = Exception

# --- Constants ---
# Base URL of the course scoring service; run_and_submit_all builds the
# "/questions" and "/submit" endpoints on top of it.
DEFAULT_API_URL: str = "https://agents-course-unit4-scoring.hf.space"

def print_trainable_parameters(model):
    """Print how many of ``model``'s parameters are trainable.

    Args:
        model: Any object whose ``named_parameters()`` yields ``(name, param)``
            pairs, where each ``param`` has ``numel()`` and ``requires_grad``
            (e.g. a ``torch.nn.Module`` or a PEFT-wrapped model).

    Returns:
        A ``(trainable, total)`` tuple of parameter counts, so callers can use
        the numbers without parsing the printed line.
    """
    trainable_parameters = 0
    all_parameters = 0
    for _, param in model.named_parameters():
        count = param.numel()
        all_parameters += count
        if param.requires_grad:
            trainable_parameters += count
    # A model without parameters would otherwise raise ZeroDivisionError here.
    percent = 100 * trainable_parameters / all_parameters if all_parameters else 0.0
    print(
        f"Trainable: {trainable_parameters} || All: {all_parameters} || Trainable %: {percent:.2f}%"
    )
    return trainable_parameters, all_parameters

class ImprovedGAIAAgent:
    """GAIA question-answering agent built on a 4-bit GPT-NeoX-20B with LoRA.

    Construction loads the model onto the GPU, registers the web-search,
    calculator and fact-checking tools, and builds a LlamaIndex ReAct agent.
    Calling the instance answers one question. It falls back to direct tool
    use when the agent raises or returns an explicit failure marker.
    """

    # Operators accepted by the restricted arithmetic evaluator (_safe_eval).
    _BIN_OPS = {
        ast.Add: operator.add,
        ast.Sub: operator.sub,
        ast.Mult: operator.mul,
        ast.Div: operator.truediv,
        ast.FloorDiv: operator.floordiv,
        ast.Mod: operator.mod,
        ast.Pow: operator.pow,
    }
    _UNARY_OPS = {ast.UAdd: operator.pos, ast.USub: operator.neg}
    # Largest exponent _safe_eval will compute; stops inputs such as
    # "9**9**9" from hanging the process.
    _MAX_EXPONENT = 10_000
    # The only function names an expression may call.
    _ALLOWED_FUNCS = frozenset({
        "sqrt", "sin", "cos", "tan", "log", "exp",
        "abs", "Abs", "factorial", "floor", "ceiling",
    })
    # Query phrasings that trigger year/number extraction from search results.
    _HINT_TRIGGERS = ("how many", "when was", "who is", "what year")
    # Keywords that mark a question as a calculation in the direct fallback.
    _MATH_KEYWORDS = ("calculate", "compute", "math", "percentage", "average")
    # An operator between two digits ("3 + 4", "10/2"). A hyphen or slash in
    # ordinary prose does not count.
    _MATH_PATTERN = re.compile(r"\d\s*[-+*/=^×÷]\s*\d")
    # Lookup-style phrasings answered with search plus fact checking.
    _LOOKUP_TERMS = ("who is", "when was", "where is", "what is", "how many")
    # Agent outputs (lower-cased, stripped) treated as "no answer".
    _FAILURE_ANSWERS = frozenset({"", "error", "none", "unknown"})

    def __init__(self):
        """Check for a GPU, configure quantization and LoRA, then load
        everything.

        Raises:
            RuntimeError: If CUDA is not available.
        """
        print("🚀 Initializing Improved GAIA Agent with GPT-NeoX-20B...")

        if not torch.cuda.is_available():
            raise RuntimeError("❌ CUDA required for GPT-NeoX-20B. Please use a GPU environment.")

        gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1e9
        print(f"🔥 GPU Memory: {gpu_memory:.1f}GB")

        # Model configuration
        self.model_name = "EleutherAI/gpt-neox-20b"

        # 4-bit NF4 quantization with double quantization, computing in bfloat16.
        self.bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16
        )

        # LoRA adapters on GPT-NeoX's attention and MLP projection layers.
        self.lora_config = LoraConfig(
            r=16,
            lora_alpha=32,
            target_modules=["query_key_value", "dense", "dense_h_to_4h", "dense_4h_to_h"],
            lora_dropout=0.1,
            bias="none",
            task_type="CAUSAL_LM"
        )

        self.load_model()
        self.setup_tools()
        self.create_agent()

    def load_model(self):
        """Load the tokenizer and the 4-bit model, attach LoRA adapters, and
        wrap the result in a LlamaIndex ``HuggingFaceLLM``.

        Sets ``self.tokenizer``, ``self.model`` and ``self.llm``.
        """
        print("📥 Loading tokenizer...")
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)

        # Reuse EOS as the padding token when the tokenizer defines none.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        print("📥 Loading model with 4-bit quantization...")
        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_name,
            quantization_config=self.bnb_config,
            device_map="auto",
            trust_remote_code=True,
            torch_dtype=torch.bfloat16
        )

        print("🔧 Applying LoRA configuration...")
        # NOTE(review): these adapters are freshly created here; confirm
        # whether trained adapter weights are meant to be loaded.
        self.model = get_peft_model(self.model, self.lora_config)
        print_trainable_parameters(self.model)

        print("🔗 Creating LlamaIndex LLM wrapper...")
        self.llm = HuggingFaceLLM(
            model=self.model,
            tokenizer=self.tokenizer,
            # Name the checkpoint explicitly so the wrapper's metadata refers
            # to the model actually loaded, not the library's default name.
            model_name=self.model_name,
            tokenizer_name=self.model_name,
            context_window=2048,  # GPT-NeoX-20B context length
            max_new_tokens=512,
            generate_kwargs={
                "temperature": 0.1,
                "do_sample": True,
                "top_p": 0.9,
                "repetition_penalty": 1.1,
                "pad_token_id": self.tokenizer.eos_token_id,
            },
            # HuggingFaceLLM takes the system prompt as ``system_prompt``; it
            # has no ``system_message`` parameter.
            system_prompt="""You are a helpful AI assistant that can search the web and perform calculations. 
When answering questions:
1. Think step by step
2. Use tools when you need current information or calculations
3. Be precise and factual
4. For numerical answers, provide exact numbers when possible
5. Always show your reasoning

Available tools: web_search, math_calculator, fact_checker"""
        )

    def setup_tools(self):
        """Register the bound tool methods the ReAct agent may call."""
        self.tools = [
            FunctionTool.from_defaults(
                fn=self.enhanced_web_search,
                name="web_search",
                description="Search the web for current information, facts, people, events, or recent data. Use specific keywords."
            ),
            FunctionTool.from_defaults(
                fn=self.advanced_calculator,
                name="math_calculator",
                description="Perform mathematical calculations, solve equations, handle percentages, averages, and complex math operations."
            ),
            FunctionTool.from_defaults(
                fn=self.fact_checker,
                name="fact_checker",
                description="Verify facts and get detailed information about people, places, events, or concepts."
            )
        ]

    @classmethod
    def _extract_hints(cls, query: str, results) -> str:
        """Return extra "Extracted years/numbers" lines for ``query``, or "".

        Args:
            query: The search query. Its wording decides which hints apply.
            results: Search hits, as dicts that may carry a ``"body"`` text.

        Returns:
            Zero, one or two ``"\\n\\nExtracted ...: a, b"`` sections. Values are
            de-duplicated in first-seen order, and numbers are capped at five.
        """
        q = query.lower()
        if not any(trigger in q for trigger in cls._HINT_TRIGGERS):
            return ""

        all_text = " ".join((r.get("body") or "") for r in results)
        hints = ""

        # Non-capturing group: with "(19|20)" findall returned only "19"/"20".
        years = re.findall(r"\b(?:19|20)\d{2}\b", all_text)
        if years and ("when" in q or "what year" in q):
            hints += f"\n\nExtracted years: {', '.join(dict.fromkeys(years))}"

        numbers = re.findall(r"\b\d+\b", all_text)
        if numbers and "how many" in q:
            # dict.fromkeys de-duplicates while keeping order; a set cannot be sliced.
            hints += f"\n\nExtracted numbers: {', '.join(list(dict.fromkeys(numbers))[:5])}"

        return hints

    def enhanced_web_search(self, query: str) -> str:
        """Search DuckDuckGo and return the top hits as readable text.

        Args:
            query: Free-text search query.

        Returns:
            Up to 8 results (title, body trimmed to 300 characters, URL), plus
            extracted years/numbers for some question phrasings. Failures are
            returned as text instead of raised, so the agent can read them.
        """
        print(f"🔍 Enhanced search: {query}")

        if not DDGS:
            return "Web search unavailable - duckduckgo_search not installed"

        try:
            with DDGS() as ddgs:
                results = list(ddgs.text(query, max_results=8, region='wt-wt'))

            if not results:
                return f"No results found for: {query}"

            formatted_results = []
            for i, result in enumerate(results, 1):
                title = result.get('title', 'No title')
                body = (result.get('body') or '').strip()
                url = result.get('href', '')
                if len(body) > 300:
                    body = body[:300] + "..."
                formatted_results.append(f"Result {i}: {title}\nContent: {body}\nSource: {url}\n")

            search_summary = f"Search results for '{query}':\n\n" + "\n".join(formatted_results)
            return search_summary + self._extract_hints(query, results)

        except Exception as e:
            print(f"❌ Search error: {e}")
            return f"Search failed: {str(e)}"

    @staticmethod
    def _normalize_expression(expression: str) -> str:
        """Rewrite ``^``, ``×`` and ``÷`` as Python operators, and make implicit
        multiplication before a parenthesis explicit ("2(3)" -> "2*(3)")."""
        clean = expression.replace('^', '**').replace('×', '*').replace('÷', '/')
        return re.sub(r'(\d)\s*\(', r'\1*(', clean)

    @classmethod
    def _validate_expression(cls, expr: str) -> None:
        """Reject anything that is not a plain math expression.

        Only numeric literals, names not starting with ``_``, arithmetic
        operators and calls to ``_ALLOWED_FUNCS`` are accepted. Both sympify
        and eval execute Python code, so attribute access, subscripts,
        strings, lambdas and other calls must never reach them.

        Raises:
            SyntaxError: If ``expr`` is not a Python expression.
            ValueError: If ``expr`` contains a disallowed construct.
        """
        tree = ast.parse(expr, mode="eval")
        for node in ast.walk(tree):
            if isinstance(node, (ast.Expression, ast.BinOp, ast.UnaryOp,
                                 ast.operator, ast.unaryop, ast.expr_context)):
                continue
            if isinstance(node, ast.Constant) and type(node.value) in (int, float, complex):
                continue
            if isinstance(node, ast.Name) and not node.id.startswith("_"):
                continue
            if (isinstance(node, ast.Call) and isinstance(node.func, ast.Name)
                    and node.func.id in cls._ALLOWED_FUNCS and not node.keywords):
                continue
            raise ValueError(f"unsupported syntax: {type(node).__name__}")

    @classmethod
    def _safe_eval(cls, expr: str):
        """Evaluate a purely numeric arithmetic expression without ``eval``.

        Raises:
            SyntaxError: If ``expr`` is not a Python expression.
            ValueError: For names, calls or other unsupported elements, or an
                exponent larger than ``_MAX_EXPONENT``.
        """
        return cls._eval_node(ast.parse(expr, mode="eval").body)

    @classmethod
    def _eval_node(cls, node):
        """Recursively evaluate one AST node for _safe_eval."""
        if isinstance(node, ast.Constant) and type(node.value) in (int, float, complex):
            return node.value
        if isinstance(node, ast.UnaryOp) and type(node.op) in cls._UNARY_OPS:
            return cls._UNARY_OPS[type(node.op)](cls._eval_node(node.operand))
        if isinstance(node, ast.BinOp) and type(node.op) in cls._BIN_OPS:
            left = cls._eval_node(node.left)
            right = cls._eval_node(node.right)
            if isinstance(node.op, ast.Pow) and abs(right) > cls._MAX_EXPONENT:
                raise ValueError("exponent too large")
            return cls._BIN_OPS[type(node.op)](left, right)
        raise ValueError(f"unsupported expression element: {type(node).__name__}")

    def advanced_calculator(self, expression: str) -> str:
        """Evaluate a math expression, symbolically when SymPy is installed.

        Args:
            expression: Expression text, e.g. "2^10", "3 × 4" or "sqrt(2)".

        Returns:
            ``"Calculation: <expression> = <result>"`` on success. A
            ``"≈ <numeric>"`` suffix is added for symbolic results. On any
            failure it returns ``"Could not calculate '<expression>': <reason>"``.
        """
        print(f"🧮 Advanced calculation: {expression}")

        try:
            clean_expr = self._normalize_expression(expression)
            # Input comes from the LLM or the raw question text, so it is
            # untrusted; vet it before sympify (which uses eval internally).
            self._validate_expression(clean_expr)

            if sympify:
                try:
                    expr = sympify(clean_expr, evaluate=False)
                    result = simplify(expr)
                    numerical = N(result, 15)  # 15 significant digits
                    if result.is_number:
                        return f"Calculation: {expression} = {numerical}"
                    return f"Calculation: {expression} = {result} ≈ {numerical}"
                except SympifyError:
                    pass  # fall back to plain arithmetic below

            result = self._safe_eval(clean_expr)
            return f"Calculation: {expression} = {result}"

        except Exception as e:
            return f"Could not calculate '{expression}': {str(e)}"

    def fact_checker(self, query: str) -> str:
        """Run a couple of query variations through web search and combine them.

        Args:
            query: The claim or subject to look up.

        Returns:
            Each successful search as ``"Search: <q>\\n<results>"``, separated
            by a rule of ``=`` characters. If every search failed or found
            nothing, returns ``"Could not verify facts about: <query>"``.
        """
        print(f"✅ Fact checking: {query}")

        follow_up = (
            f"{query} biography"
            if any(word in query.lower() for word in ['who is', 'person', 'artist'])
            else f"{query} information"
        )
        search_variations = [query, f"{query} facts", follow_up]
        failure_prefixes = ("No results found", "Search failed", "Web search unavailable")

        all_results = []
        for search_query in search_variations[:2]:  # limit searches to avoid rate limiting
            result = self.enhanced_web_search(search_query)
            if not result.startswith(failure_prefixes):
                all_results.append(f"Search: {search_query}\n{result}")

        if not all_results:
            return f"Could not verify facts about: {query}"
        # Parenthesised as a whole so the rule sits *between* results; before,
        # precedence made it a one-off prefix.
        separator = "\n\n" + "=" * 50 + "\n\n"
        return separator.join(all_results)

    def create_agent(self):
        """Build the ReAct agent over ``self.tools`` and ``self.llm``.

        Raises:
            Exception: Whatever agent construction raised, after logging it.
        """
        print("🤖 Creating ReAct agent...")
        try:
            self.agent = ReActAgent.from_tools(
                tools=self.tools,
                llm=self.llm,
                verbose=True,
                max_iterations=5,  # allow several tool/reason rounds
                react_chat_formatter=None,  # use the default formatter
            )
            print("✅ ReAct Agent created successfully")
        except Exception as e:
            print(f"❌ Agent creation failed: {e}")
            traceback.print_exc()
            raise

    def __call__(self, question: str) -> str:
        """Answer ``question`` with the agent, or with direct tool use if the
        agent raises or returns a failure marker."""
        print("\n" + "=" * 60)
        print(f"🤔 Processing: {question}")
        print("=" * 60)

        try:
            response = self.agent.query(question)
            answer = str(response).strip()
        except Exception as e:
            print(f"❌ Agent error: {e}")
            print("🔄 Falling back to direct approach...")
            return self._direct_approach(question)

        # GAIA answers are often very short ("42", "Paris"), so only empty
        # output or explicit failure words trigger the fallback.
        if answer.lower() in self._FAILURE_ANSWERS:
            print("⚠️ Poor response, trying direct approach...")
            return self._direct_approach(question)

        print(f"✅ Agent response: {answer[:200]}...")
        return answer

    def _direct_approach(self, question: str) -> str:
        """Answer with tools directly, choosing the tool from the wording.

        Questions that look like calculations go to the calculator first. If
        it cannot parse them (typical for prose), the search paths are used.
        """
        question_lower = question.lower()

        looks_like_math = (
            any(term in question_lower for term in self._MATH_KEYWORDS)
            or self._MATH_PATTERN.search(question) is not None
        )
        if looks_like_math:
            math_result = self.advanced_calculator(question)
            if not math_result.startswith("Could not calculate"):
                return math_result

        if any(term in question_lower for term in self._LOOKUP_TERMS):
            search_result = self.enhanced_web_search(question)
            fact_result = self.fact_checker(question)
            return f"{search_result}\n\nFact Check:\n{fact_result}"

        return self.enhanced_web_search(question)

def cleanup_memory():
    """Free unreachable Python objects and return cached GPU memory.

    ``gc.collect()`` runs first, so tensors that are only kept alive by
    unreachable reference cycles are released before
    ``torch.cuda.empty_cache()`` hands unused cached blocks back to the
    driver. Without CUDA, only the garbage collection happens.
    """
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    print("🧹 Memory cleaned")

def _truncate(text: str, limit: int) -> str:
    """Return ``text`` cut to ``limit`` characters, appending "..." only when
    something was actually removed."""
    return text[:limit] + ("..." if len(text) > limit else "")


def run_and_submit_all(profile: gr.OAuthProfile | None):
    """Answer every GAIA question with a freshly built agent and submit them.

    Args:
        profile: The logged-in Hugging Face profile, supplied by Gradio (the
            click handler passes no explicit inputs); falsy when nobody is
            logged in.

    Returns:
        A ``(status_text, results)`` pair for the two Gradio outputs.
        ``results`` is a ``pandas.DataFrame`` of per-question answers, or
        ``None`` when the run stopped before any question was processed.
    """
    if not profile:
        return "❌ Please login to Hugging Face first", None

    username = profile.username
    print(f"👤 User: {username}")

    # API endpoints
    api_url = DEFAULT_API_URL
    questions_url = f"{api_url}/questions"
    submit_url = f"{api_url}/submit"

    cleanup_memory()

    # NOTE(review): the 20B model is loaded again on every click; consider
    # caching the agent if repeated runs are expected.
    try:
        print("🚀 Initializing Improved GAIA Agent...")
        agent = ImprovedGAIAAgent()
        print("✅ Agent initialized successfully")
    except Exception as e:
        error_msg = f"❌ Agent initialization failed: {str(e)}\n{traceback.format_exc()}"
        print(error_msg)
        return error_msg, None

    # Link to this Space's code, sent with the submission.
    space_id = os.getenv("SPACE_ID", "unknown")
    agent_code = f"https://huggingface.co/spaces/{space_id}/tree/main"

    try:
        print("📥 Fetching questions...")
        response = requests.get(questions_url, timeout=30)
        response.raise_for_status()
        questions_data = response.json()
    except Exception as e:
        return f"❌ Failed to fetch questions: {str(e)}", None

    # Anything but a non-empty list would either submit nothing or crash the
    # loop below on ``item.get``.
    if not isinstance(questions_data, list) or not questions_data:
        return f"❌ Unexpected questions payload: {str(questions_data)[:200]}", None
    print(f"📋 Got {len(questions_data)} questions")

    results_log = []
    answers_payload = []

    print("\n" + "=" * 50)
    print("🚀 STARTING GAIA EVALUATION")
    print("=" * 50)

    for i, item in enumerate(questions_data, 1):
        if not isinstance(item, dict):
            continue
        task_id = item.get("task_id")
        question_text = item.get("question")
        if not task_id or not question_text:
            continue

        print(f"\n📝 Question {i}/{len(questions_data)}")
        print(f"🆔 ID: {task_id}")
        print(f"❓ Question: {question_text}")

        try:
            answer = agent(question_text)
            # GAIA answers are often very short ("3", "Paris"); only replace
            # output that is actually empty.
            if not answer or not answer.strip():
                answer = f"Unable to determine answer for: {question_text[:100]}..."
            print(f"✅ Answer: {answer[:200]}...")
        except Exception as e:
            print(f"❌ Error processing {task_id}: {e}")
            answer = f"Processing error: {str(e)[:150]}"

        answers_payload.append({
            "task_id": task_id,
            "submitted_answer": answer
        })
        results_log.append({
            "Task ID": task_id,
            "Question": _truncate(question_text, 150),
            "Answer": _truncate(answer, 200)
        })

        # Release cached GPU memory every few questions.
        if i % 3 == 0:
            cleanup_memory()

    print(f"\n📤 Submitting {len(answers_payload)} answers...")

    submission_data = {
        "username": username,
        "agent_code": agent_code,
        "answers": answers_payload
    }

    try:
        response = requests.post(submit_url, json=submission_data, timeout=180)
        response.raise_for_status()
        result_data = response.json()

        # ``or 0`` also covers an explicit null. Otherwise ``score >= 30``
        # would raise and report a successful submission as failed.
        score = result_data.get('score') or 0
        correct = result_data.get('correct_count', 0)
        total = result_data.get('total_attempted', len(answers_payload))
        message = result_data.get('message', '')

        final_status = f"""🎉 IMPROVED GAIA EVALUATION COMPLETE!

👤 User: {username}
🤖 Model: GPT-NeoX-20B + LoRA + 4-bit Quantization
📊 Final Score: {score}%
✅ Correct: {correct}/{total}
🎯 Target: 30%+ {'🎉 ACHIEVED!' if score >= 30 else '📈 Significant improvement expected!'}

📝 Message: {message}

🔧 Improvements Made:
- ✅ Proper causal LM (GPT-NeoX-20B) instead of encoder-decoder
- ✅ 4-bit quantization for memory efficiency
- ✅ LoRA for better parameter efficiency
- ✅ Enhanced tools with fact checking
- ✅ Better reasoning prompts
- ✅ Multi-strategy search approach
"""

        print(f"\n🏆 FINAL SCORE: {score}%")
        return final_status, pd.DataFrame(results_log)

    except Exception as e:
        error_msg = f"❌ Submission failed: {str(e)}"
        print(error_msg)
        return error_msg, pd.DataFrame(results_log)

# --- Gradio Interface ---
# Page layout: title, feature summary, HF login, one button that runs the whole
# evaluation, and two outputs (status text and a per-question results table).
with gr.Blocks(title="Improved GAIA Agent", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🚀 Improved GAIA Agent - GPT-NeoX-20B + LoRA")
    gr.Markdown("""
    **Major Improvements:**
    - 🧠 **GPT-NeoX-20B**: 20B parameter causal language model (vs 220M FLAN-T5)
    - ⚡ **4-bit Quantization**: Memory efficient loading with BitsAndBytes
    - 🎯 **LoRA**: Parameter-efficient fine-tuning ready
    - 🔍 **Enhanced Tools**: Multi-strategy search + fact checking + advanced math
    - 🤖 **Better ReAct**: Improved reasoning prompts and error handling
    - 📈 **Expected**: Significant improvement over 0% baseline

    **Requirements**: CUDA GPU with 16GB+ VRAM
    """)

    with gr.Row():
        gr.LoginButton()

    with gr.Row():
        run_button = gr.Button(
            "🚀 Run Improved GAIA Evaluation",
            variant="primary",
            size="lg"
        )

    status_output = gr.Textbox(
        label="📊 Evaluation Results",
        lines=15,
        interactive=False
    )

    results_table = gr.DataFrame(
        label="📝 Detailed Results",
        wrap=True
    )

    # No explicit inputs: run_and_submit_all's ``profile`` parameter is typed
    # as gr.OAuthProfile, so the login state is expected to come from Gradio.
    run_button.click(
        fn=run_and_submit_all,
        outputs=[status_output, results_table]
    )

if __name__ == "__main__":
    print("🚀 Starting Improved GAIA Agent...")
    print("💪 Using GPT-NeoX-20B + LoRA + 4-bit Quantization")
    # Bind on all interfaces at port 7860 and show errors in the UI.
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        show_error=True
    )