File size: 10,730 Bytes
574b6ca
 
 
d591a7a
086b425
d591a7a
 
0f20e93
d591a7a
 
d66e9b7
9f29ca9
8c139ea
f0b3f91
 
8c139ea
 
9f29ca9
 
 
757ebd9
d66e9b7
3db6293
d66e9b7
 
9f29ca9
d66e9b7
 
e80aab9
d66e9b7
 
f0b3f91
9f29ca9
d591a7a
 
f0b3f91
cccb073
d66e9b7
 
8c139ea
 
 
d66e9b7
d3c0517
d66e9b7
d591a7a
25405da
8c139ea
 
9f29ca9
d66e9b7
d591a7a
d66e9b7
d591a7a
d66e9b7
d591a7a
9f29ca9
d66e9b7
 
9f29ca9
 
 
8c139ea
d66e9b7
9f29ca9
 
d66e9b7
 
 
 
 
 
 
 
9f29ca9
 
d66e9b7
 
 
 
d591a7a
 
d66e9b7
d591a7a
d66e9b7
 
 
 
 
 
 
 
cccb073
d66e9b7
 
d591a7a
 
d66e9b7
d591a7a
8c139ea
d66e9b7
 
 
 
 
 
 
 
d591a7a
 
d66e9b7
d591a7a
d66e9b7
 
 
 
d591a7a
d66e9b7
 
 
 
 
 
 
 
 
 
 
 
d591a7a
 
 
d66e9b7
d591a7a
 
 
 
d66e9b7
 
bbb34b9
d591a7a
d66e9b7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d591a7a
 
cccb073
d66e9b7
0f20e93
8c139ea
 
cccb073
d66e9b7
d3c0517
d66e9b7
 
cccb073
d3c0517
d66e9b7
 
c66203c
d66e9b7
 
 
 
 
 
8c139ea
d66e9b7
cccb073
d66e9b7
8c139ea
d66e9b7
8c139ea
d66e9b7
d591a7a
d66e9b7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8c139ea
d66e9b7
 
d591a7a
d66e9b7
d591a7a
d66e9b7
 
 
 
 
19b7914
d66e9b7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
03ca047
d66e9b7
d591a7a
cccb073
d66e9b7
d3c0517
d66e9b7
 
 
19b7914
eccf8e4
d66e9b7
 
a39e119
d66e9b7
 
d3c0517
 
8c139ea
d66e9b7
bbb34b9
d66e9b7
8c139ea
d66e9b7
f96a820
8c139ea
d66e9b7
 
d3c0517
d66e9b7
 
 
 
 
 
d3c0517
d66e9b7
 
 
 
 
d3c0517
e80aab9
d66e9b7
d3c0517
d66e9b7
7963312
d66e9b7
7963312
d66e9b7
 
 
 
 
 
 
 
 
8c139ea
d66e9b7
 
 
8c139ea
d66e9b7
 
9f29ca9
d66e9b7
 
 
 
e80aab9
 
d66e9b7
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
import os
import gradio as gr
import requests
import json
import re
import numexpr
import pandas as pd
import math
from pdfminer.high_level import extract_text
from bs4 import BeautifulSoup
from typing import Dict, Any, List, Tuple, Optional
from dotenv import load_dotenv
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
import torch
import time
import gc

# --- Load Environment Variables ---
# Read variables from a local .env file (if one exists) into the process
# environment before any of them are looked up below.
load_dotenv()
# Serper API key used by web_search(); None when the variable is not set.
SERPER_API_KEY = os.getenv("SERPER_API_KEY")

# --- Constants ---
# Base URL of the course scoring service (run_and_submit_all uses its
# /questions and /submit endpoints).
DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
# Maximum number of model/tool iterations GAIA_Agent runs per question.
MAX_STEPS = 6  # Increased from 4
# max_new_tokens for every model.generate() call.
MAX_TOKENS = 256  # Increased from 128
# Hugging Face Hub id of the model and tokenizer loaded below.
MODEL_NAME = "microsoft/Phi-3-mini-4k-instruct"
# Per-question wall-clock budget in seconds. GAIA_Agent checks it only at the
# start of each step, so a single slow generation can overrun it.
TIMEOUT_PER_QUESTION = 45  # Increased from 30
# Character cap applied to tool outputs (the tool functions and
# GAIA_Agent._use_tool slice their results to this length).
MAX_RESULT_LENGTH = 500  # For tool outputs

# --- Model Loading ---
# NOTE: this runs at import time, so importing the module loads the model
# before the Gradio interface further down is built.
print("Loading optimized model...")
start_time = time.time()

# trust_remote_code=True permits running model code shipped in the Hub repo.
# float32 weights take twice the memory of half precision; presumably chosen
# for CPU compatibility — TODO confirm. With device_map="auto" the weights may
# be placed on an accelerator.
# NOTE(review): generate() inputs may then need to be moved to model.device.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    trust_remote_code=True,
    torch_dtype=torch.float32,
    device_map="auto",
    low_cpu_mem_usage=True
)

tokenizer = AutoTokenizer.from_pretrained(
    MODEL_NAME,
    use_fast=True,
    trust_remote_code=True
)

# Fall back to EOS as the padding token so that tokenizer.pad_token_id (passed
# to GenerationConfig in GAIA_Agent._call_model) is not None.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

print(f"Model loaded in {time.time() - start_time:.2f} seconds")

# --- Enhanced Tools ---
def web_search(query: str) -> str:
    """Search the web for *query* and return a short text summary.

    Uses the Serper Google Search API (requires ``SERPER_API_KEY``). Up to
    three organic results are rendered as ``"<title>: <snippet>"``, joined by
    newlines and truncated to ``MAX_RESULT_LENGTH`` characters.

    Failures never propagate: they are reported as a ``"Search error: ..."``
    string so the agent loop can feed them back to the model.

    Args:
        query: Free-text search query.

    Returns:
        The formatted results, ``"No relevant results found"`` when the API
        returned nothing usable, or a ``"Search error: ..."`` message.
    """
    if not SERPER_API_KEY:
        # The former fallback used a `DDGS` (duckduckgo_search) client that is
        # never imported in this file, so it always died with a NameError.
        # Report the actual configuration problem instead.
        return ("Search error: SERPER_API_KEY is not set and no fallback "
                "search backend is configured")
    try:
        response = requests.post(
            'https://google.serper.dev/search',
            headers={'X-API-KEY': SERPER_API_KEY},
            json={'q': query, 'num': 3, 'hl': 'en', 'gl': 'us'},
            timeout=10,
        )
        # Surface HTTP failures (invalid key, quota exceeded) instead of
        # misreporting them as "no results".
        response.raise_for_status()
        results = response.json()

        output = [
            f"{r['title']}: {r['snippet']}"
            for r in results.get('organic', [])[:3]
            if 'title' in r and 'snippet' in r
        ]
        # Previously an empty string was returned when no result had both a
        # title and a snippet.
        if not output:
            return "No relevant results found"
        return "\n".join(output)[:MAX_RESULT_LENGTH]
    except Exception as e:
        return f"Search error: {str(e)}"

def calculator(expression: str) -> str:
    """Evaluate an arithmetic expression and return the result as a float string.

    Every character other than digits, ``+ - * / ( ) . ^ % ,`` and whitespace
    is stripped first, so words in the input (function names, units) are
    silently dropped. ``^`` is treated as exponentiation, ``%`` as "divided by
    100", and commas are removed (thousands separators).

    Args:
        expression: Expression text produced by the model.

    Returns:
        ``str(float(result))`` on success, ``"Invalid empty expression"`` when
        nothing evaluable remains, or a ``"Calculation error: ..."`` message.
    """
    try:
        # Keep only characters that can form an arithmetic expression; strip so
        # whitespace-only leftovers are caught by the emptiness check below
        # instead of failing inside numexpr.
        expression = re.sub(r'[^\d+\-*/().^%,\s]', '', expression).strip()
        if not expression:
            return "Invalid empty expression"

        # `^` is not exponentiation in numexpr's Python-style syntax, so
        # "2^3" never produced 8; map it to `**`. Percentages and thousands
        # separators are handled as before.
        expression = (
            expression.replace('^', '**')
            .replace('%', '/100')
            .replace(',', '')
        )
        result = numexpr.evaluate(expression)
        return str(float(result))
    except Exception as e:
        return f"Calculation error: {str(e)}"

def read_pdf(file_path: str) -> str:
    """Extract text from a PDF, collapse whitespace, and truncate it.

    Args:
        file_path: Path of the PDF, passed straight to pdfminer's
            ``extract_text``.

    Returns:
        Up to ``MAX_RESULT_LENGTH`` characters of whitespace-normalized text,
        ``"No readable text found in PDF"`` when extraction yields nothing but
        whitespace, or a ``"PDF read error: ..."`` message.
    """
    try:
        raw_text = extract_text(file_path)
    except Exception as e:
        return f"PDF read error: {str(e)}"

    # Normalize *before* the emptiness check. Extraction can yield only
    # whitespace (e.g. page-break form feeds), which is truthy but holds no
    # text; previously that returned "" instead of the message below.
    text = re.sub(r'\s+', ' ', raw_text or '').strip()
    if not text:
        return "No readable text found in PDF"
    return text[:MAX_RESULT_LENGTH]

def read_webpage(url: str) -> str:
    """Fetch *url* and return its visible text, truncated.

    Script, style, nav and footer elements are removed before extraction, and
    runs of three or more newlines are squeezed down to a single blank line.

    Returns:
        Up to ``MAX_RESULT_LENGTH`` characters of page text, ``"No main
        content found"`` when the page has no text, or a
        ``"Webpage read error: ..."`` message on any failure.
    """
    try:
        resp = requests.get(url, timeout=10, headers={'User-Agent': 'Mozilla/5.0'})
        resp.raise_for_status()

        page = BeautifulSoup(resp.text, 'html.parser')
        for tag in page(['script', 'style', 'nav', 'footer']):
            tag.decompose()

        body_text = re.sub(
            r'\n{3,}',
            '\n\n',
            page.get_text(separator='\n', strip=True),
        )
        if not body_text:
            return "No main content found"
        return body_text[:MAX_RESULT_LENGTH]
    except Exception as e:
        return f"Webpage read error: {str(e)}"

# Registry mapping the tool names the model may emit to their implementations.
TOOLS = dict(
    web_search=web_search,
    calculator=calculator,
    read_pdf=read_pdf,
    read_webpage=read_webpage,
)

# --- Improved GAIA Agent ---
class GAIA_Agent:
    """Small iterative agent answering a question with the global model and TOOLS.

    Each step renders the conversation history into a Phi-3 chat prompt and
    generates a reply. It then does one of three things:

    - returns the text after ``"Final Answer:"``;
    - runs the JSON tool call found in the reply;
    - records the reply as analysis for the next step.
    """

    def __init__(self):
        self.tools = TOOLS
        self.system_prompt = """You are an advanced GAIA problem solver. Follow these steps:
1. Analyze the question carefully
2. Choose the most appropriate tool
3. Process the results
4. Provide a precise final answer

Available Tools:
- web_search: For general knowledge questions
- calculator: For math problems
- read_pdf: For PDF content extraction
- read_webpage: For webpage content extraction

Tool format: ```json
{"tool": "tool_name", "args": {"arg1": "value"}}```

Always end with: Final Answer: [your answer]"""

    def __call__(self, question: str) -> str:
        """Answer *question*; the returned answer is capped at 500 characters.

        Returns the final-answer text, or a status string on timeout, step
        exhaustion, or an unexpected error.
        """
        start_time = time.time()
        history = [f"Question: {question}"]

        try:
            for _ in range(MAX_STEPS):
                # Checked only between steps; one slow generation can overrun it.
                if time.time() - start_time > TIMEOUT_PER_QUESTION:
                    return "Timeout: Processing took too long"

                prompt = self._build_prompt(history)
                response = self._call_model(prompt)

                if "Final Answer:" in response:
                    answer = response.split("Final Answer:")[-1].strip()
                    return answer[:500]  # Limit answer length

                tool_call = self._parse_tool_call(response)
                if tool_call:
                    tool_name, args = tool_call
                    observation = self._use_tool(tool_name, args)
                    history.append(f"Tool Used: {tool_name}")
                    # Truncate long results; add "..." only when text was cut.
                    snippet = observation[:300]
                    if len(observation) > 300:
                        snippet += "..."
                    history.append(f"Tool Result: {snippet}")
                else:
                    history.append(f"Analysis: {response}")

                gc.collect()

            return "Maximum steps reached without final answer"
        except Exception as e:
            return f"Error: {str(e)}"

    def _build_prompt(self, history: List[str]) -> str:
        """Render the system prompt and *history* as one Phi-3 chat-format prompt."""
        return f"<|system|>\n{self.system_prompt}<|end|>\n<|user|>\n" + "\n".join(history) + "<|end|>\n<|assistant|>"

    def _call_model(self, prompt: str) -> str:
        """Generate a reply to *prompt*; return only the newly generated text."""
        inputs = tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=3072,
            padding=False
        )
        # With device_map="auto" the model may sit on an accelerator; the
        # input tensors must be on the same device as its first parameters.
        inputs = inputs.to(model.device)

        generation_config = GenerationConfig(
            max_new_tokens=MAX_TOKENS,
            temperature=0.3,
            top_p=0.9,
            do_sample=True,
            pad_token_id=tokenizer.pad_token_id
        )

        with torch.no_grad():
            outputs = model.generate(
                inputs.input_ids,
                generation_config=generation_config,
                attention_mask=inputs.attention_mask
            )

        # generate() returns prompt + continuation. Previously the whole
        # sequence was decoded and split on "<|assistant|>". That marker can be
        # removed by skip_special_tokens, which leaves the echoed system prompt
        # (containing "Final Answer:" and a tool example) in the "response".
        # Decoding only the new tokens avoids that entirely.
        prompt_length = inputs.input_ids.shape[1]
        return tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True).strip()

    def _parse_tool_call(self, text: str) -> Optional[Tuple[str, Dict]]:
        """Return ``(tool_name, args)`` from the first valid fenced JSON tool call.

        A tool call is a ```json fenced object with a string ``"tool"`` key and
        an ``"args"`` key. Blocks that are not valid JSON, or not of that
        shape, are skipped rather than aborting the search.
        Returns ``None`` when no valid call is found.
        """
        for match in re.finditer(r'```json\s*({.+?})\s*```', text, re.DOTALL):
            try:
                tool_call = json.loads(match.group(1))
            except json.JSONDecodeError:
                continue
            if (
                isinstance(tool_call, dict)
                and isinstance(tool_call.get("tool"), str)
                and "args" in tool_call
            ):
                return tool_call["tool"], tool_call["args"]
        return None

    def _use_tool(self, tool_name: str, args: Dict) -> str:
        """Run tool *tool_name* with *args* and return its output as a truncated string.

        Dict args are normally passed as keyword arguments. Two shapes are
        tolerated as well:

        - a single entry under an unexpected key, whose value is passed
          positionally;
        - a non-dict value, which is passed as the single argument.

        For ``read_webpage`` a URL is recovered from nested or free-form args.
        Errors are returned as ``"Tool error: ..."``.
        """
        import inspect  # local: only needed here

        if tool_name not in self.tools:
            return f"Unknown tool: {tool_name}"

        tool = self.tools[tool_name]
        try:
            # Special handling for URL-containing questions
            if tool_name == "read_webpage" and not (isinstance(args, dict) and "url" in args):
                nested = args.get("args") if isinstance(args, dict) else None
                if isinstance(nested, dict) and "url" in nested:
                    args = nested
                else:
                    # Stop at quotes/brackets so repr punctuation from
                    # str(dict) (e.g. "'}") is not captured into the URL.
                    url_match = re.search(r"https?://[^\s'\"}\]]+", str(args))
                    if url_match:
                        args = {"url": url_match.group()}

            if isinstance(args, dict):
                accepted = inspect.signature(tool).parameters
                if len(args) == 1 and next(iter(args)) not in accepted:
                    # e.g. {"q": ...} for web_search: pass the value positionally.
                    result = tool(next(iter(args.values())))
                else:
                    result = tool(**args)
            else:
                # Models sometimes emit a bare value instead of an object.
                result = tool(args)
            return str(result)[:MAX_RESULT_LENGTH]
        except Exception as e:
            return f"Tool error: {str(e)}"

# --- Evaluation Runner ---
def run_and_submit_all(profile: gr.OAuthProfile | None):
    """Fetch the GAIA questions, answer each with GAIA_Agent, and submit the answers.

    Args:
        profile: Hugging Face OAuth profile of the logged-in user; ``None``
            when nobody is logged in.

    Returns:
        ``(status_message, results)`` for the two Gradio outputs. ``results``
        is a DataFrame of truncated questions/answers, or ``None`` when no
        questions were processed.
    """
    if not profile:
        return "Please login first", None

    agent = GAIA_Agent()
    questions_url = f"{DEFAULT_API_URL}/questions"
    submit_url = f"{DEFAULT_API_URL}/submit"

    try:
        response = requests.get(questions_url, timeout=15)
        # Don't try to parse an HTTP error body as the question list.
        response.raise_for_status()
        questions_data = response.json()
    except Exception as e:
        return f"Failed to get questions: {str(e)}", None

    if not isinstance(questions_data, list) or not questions_data:
        return "Failed to get questions: empty or malformed question list", None

    results = []
    answers = []
    total = len(questions_data)

    for i, item in enumerate(questions_data, start=1):
        # Skip malformed entries instead of crashing the whole run on .get().
        if not isinstance(item, dict):
            continue
        task_id = item.get("task_id")
        question = item.get("question")

        if not task_id or not question:
            continue

        print(f"Processing question {i}/{total}")
        answer = agent(question)

        answers.append({"task_id": task_id, "submitted_answer": answer})
        results.append({
            "Task ID": task_id,
            "Question": question[:100] + "..." if len(question) > 100 else question,
            "Answer": answer[:100] + "..." if len(answer) > 100 else answer
        })

    if not answers:
        return "No valid questions to answer; nothing submitted", pd.DataFrame(results)

    submission = {
        "username": profile.username,
        # NOTE(review): becomes ".../spaces/None" when SPACE_ID is unset
        # (e.g. running locally).
        "agent_code": f"https://huggingface.co/spaces/{os.getenv('SPACE_ID')}",
        "answers": answers
    }

    try:
        response = requests.post(submit_url, json=submission, timeout=30)
        # Report rejected submissions instead of "Submitted! Score: N/A".
        response.raise_for_status()
        result = response.json()
        return f"Submitted! Score: {result.get('score', 'N/A')}", pd.DataFrame(results)
    except Exception as e:
        return f"Submission failed: {str(e)}", pd.DataFrame(results)

# --- Gradio Interface ---
with gr.Blocks(title="Enhanced GAIA Agent") as demo:
    gr.Markdown("## 🚀 Enhanced GAIA Agent Evaluation")
    gr.Markdown("""
    Improved version with:
    - Better tool utilization
    - Increased step/token limits
    - Enhanced error handling
    """)
    
    with gr.Row():
        gr.LoginButton()
        run_btn = gr.Button("Run Evaluation", variant="primary")
    
    output_status = gr.Textbox(label="Status")
    results_table = gr.DataFrame(label="Results")
    
    run_btn.click(
        run_and_submit_all,
        outputs=[output_status, results_table]
    )

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)