"""Gradio app that runs a tool-using Phi-3 agent over the GAIA course questions.

The module loads the language model at import time, defines four string-in /
string-out tools (web search, calculator, PDF reader, webpage reader), wraps
them in ``GAIA_Agent`` and exposes a Gradio UI that fetches the questions,
answers them and submits the answers to the scoring endpoint.
"""

import gc
import inspect
import json
import math
import os
import re
import time
from typing import Dict, Any, List, Tuple, Optional

import gradio as gr
import numexpr
import pandas as pd
import requests
import torch
from bs4 import BeautifulSoup
from dotenv import load_dotenv
from pdfminer.high_level import extract_text
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

# The DuckDuckGo fallback in web_search() needs DDGS. The original file used the
# name without importing it, so the fallback always failed with a NameError.
# The import is optional: without it the Serper path still works.
try:
    from duckduckgo_search import DDGS
except ImportError:
    DDGS = None

# --- Load Environment Variables ---
load_dotenv()
SERPER_API_KEY = os.getenv("SERPER_API_KEY")

# --- Constants ---
DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
MAX_STEPS = 6  # Increased from 4
MAX_TOKENS = 256  # Increased from 128
MODEL_NAME = "microsoft/Phi-3-mini-4k-instruct"
TIMEOUT_PER_QUESTION = 45  # Increased from 30
MAX_RESULT_LENGTH = 500  # For tool outputs

# --- Model Loading ---
print("Loading optimized model...")
start_time = time.time()

model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    trust_remote_code=True,
    torch_dtype=torch.float32,
    device_map="auto",
    low_cpu_mem_usage=True
)

tokenizer = AutoTokenizer.from_pretrained(
    MODEL_NAME,
    use_fast=True,
    trust_remote_code=True
)

if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

print(f"Model loaded in {time.time() - start_time:.2f} seconds")


# --- Enhanced Tools ---
def web_search(query: str) -> str:
    """Search the web and return up to three "title: snippet" lines.

    Uses the Serper API when ``SERPER_API_KEY`` is set, otherwise falls back to
    DuckDuckGo (if ``duckduckgo_search`` is installed).

    Args:
        query: Search query text.

    Returns:
        Newline-joined results truncated to ``MAX_RESULT_LENGTH`` characters,
        "No relevant results found", or a "Search error: ..." message. Never
        raises.
    """
    try:
        if SERPER_API_KEY:
            params = {'q': query, 'num': 3, 'hl': 'en', 'gl': 'us'}
            headers = {'X-API-KEY': SERPER_API_KEY}
            response = requests.post(
                'https://google.serper.dev/search',
                headers=headers,
                json=params,
                timeout=10
            )
            # Surface HTTP failures (bad key, quota) instead of reporting
            # them as "no results".
            response.raise_for_status()
            results = response.json()

            if 'organic' in results:
                output = [
                    f"{r['title']}: {r['snippet']}"
                    for r in results['organic'][:3]
                    if 'title' in r and 'snippet' in r
                ]
                # Only return the joined text when something usable was found;
                # otherwise fall through to the "no results" message.
                if output:
                    return "\n".join(output)[:MAX_RESULT_LENGTH]
            return "No relevant results found"

        if DDGS is None:
            return ("Search error: SERPER_API_KEY is not set and "
                    "duckduckgo_search is not installed")
        with DDGS() as ddgs:
            results = list(ddgs.text(query, max_results=3))
        return "\n".join(
            f"{r['title']}: {r['body']}" for r in results
        )[:MAX_RESULT_LENGTH]
    except Exception as e:
        return f"Search error: {str(e)}"


def calculator(expression: str) -> str:
    """Evaluate an arithmetic expression and return the result as text.

    The input is reduced to digits, whitespace and the characters
    ``+ - * / ( ) . ^ % ,`` before evaluation, so no identifiers can reach
    ``numexpr.evaluate``. ``%`` is treated as "percent" (``/100``), commas are
    dropped as thousands separators, and ``^`` is treated as exponentiation.

    Args:
        expression: Expression text, e.g. ``"2^10 + 5%"``.

    Returns:
        ``str(float(result))``, "Invalid empty expression", or a
        "Calculation error: ..." message. Never raises.
    """
    try:
        # Clean and validate expression
        expression = re.sub(r'[^\d+\-*/().^%,\s]', '', expression)
        if not expression.strip():
            return "Invalid empty expression"

        # Handle percentages and commas
        expression = expression.replace('%', '/100').replace(',', '')
        # numexpr reads '^' as bitwise XOR; a calculator user means "power".
        expression = expression.replace('^', '**')
        result = numexpr.evaluate(expression)
        return str(float(result))
    except Exception as e:
        return f"Calculation error: {str(e)}"


def read_pdf(file_path: str) -> str:
    """Extract text from a PDF and return a whitespace-condensed excerpt.

    Args:
        file_path: Path passed to ``pdfminer``'s ``extract_text``.

    Returns:
        At most ``MAX_RESULT_LENGTH`` characters of text, "No readable text
        found in PDF", or a "PDF read error: ..." message. Never raises.
    """
    try:
        text = extract_text(file_path)
        if not text:
            return "No readable text found in PDF"

        # Clean and condense text
        text = re.sub(r'\s+', ' ', text).strip()
        return text[:MAX_RESULT_LENGTH]
    except Exception as e:
        return f"PDF read error: {str(e)}"


def read_webpage(url: str) -> str:
    """Download a web page and return an excerpt of its visible text.

    ``script``, ``style``, ``nav`` and ``footer`` elements are removed before
    the text is extracted.

    Args:
        url: Address to fetch with a 10 second timeout.

    Returns:
        At most ``MAX_RESULT_LENGTH`` characters of text, "No main content
        found", or a "Webpage read error: ..." message. Never raises.
    """
    try:
        headers = {'User-Agent': 'Mozilla/5.0'}
        response = requests.get(url, timeout=10, headers=headers)
        response.raise_for_status()

        soup = BeautifulSoup(response.text, 'html.parser')

        # Remove unwanted elements
        for element in soup(['script', 'style', 'nav', 'footer']):
            element.decompose()

        # Get text with better formatting
        text = soup.get_text(separator='\n', strip=True)
        text = re.sub(r'\n{3,}', '\n\n', text)
        return text[:MAX_RESULT_LENGTH] if text else "No main content found"
    except Exception as e:
        return f"Webpage read error: {str(e)}"


TOOLS = {
    "web_search": web_search,
    "calculator": calculator,
    "read_pdf": read_pdf,
    "read_webpage": read_webpage
}


# --- Improved GAIA Agent ---
class GAIA_Agent:
    """Multi-step agent: prompt the model, run a requested tool, repeat.

    Each step the model either emits ``Final Answer: ...`` (which ends the
    run), a fenced JSON tool call (which is executed and its result appended
    to the history), or free text (appended to the history as analysis).
    """

    def __init__(self):
        self.tools = TOOLS
        self.system_prompt = """You are an advanced GAIA problem solver. Follow these steps:
1. Analyze the question carefully
2. Choose the most appropriate tool
3. Process the results
4. Provide a precise final answer

Available Tools:
- web_search: For general knowledge questions
- calculator: For math problems
- read_pdf: For PDF content extraction
- read_webpage: For webpage content extraction

Tool format:
```json
{"tool": "tool_name", "args": {"arg1": value}}```

Always end with: Final Answer: [your answer]"""

    def __call__(self, question: str) -> str:
        """Answer one question.

        Args:
            question: Question text.

        Returns:
            The text after the last "Final Answer:" (at most 500 characters),
            or a timeout / max-steps / "Error: ..." message. Never raises.
        """
        start_time = time.time()
        history = [f"Question: {question}"]

        try:
            for _ in range(MAX_STEPS):
                # The budget is only checked between steps; a single slow
                # generation or tool call can still overrun it.
                if time.time() - start_time > TIMEOUT_PER_QUESTION:
                    return "Timeout: Processing took too long"

                prompt = self._build_prompt(history)
                response = self._call_model(prompt)

                if "Final Answer:" in response:
                    answer = response.split("Final Answer:")[-1].strip()
                    return answer[:500]  # Limit answer length

                tool_call = self._parse_tool_call(response)
                if tool_call:
                    tool_name, args = tool_call
                    observation = self._use_tool(tool_name, args)
                    history.append(f"Tool Used: {tool_name}")
                    history.append(f"Tool Result: {observation[:300]}...")  # Truncate long results
                else:
                    history.append(f"Analysis: {response}")

                gc.collect()

            return "Maximum steps reached without final answer"
        except Exception as e:
            return f"Error: {str(e)}"

    def _build_prompt(self, history: List[str]) -> str:
        """Render the system prompt and the history with Phi-3 chat markers."""
        return f"<|system|>\n{self.system_prompt}<|end|>\n<|user|>\n" + "\n".join(history) + "<|end|>\n<|assistant|>"

    def _call_model(self, prompt: str) -> str:
        """Generate a completion for ``prompt`` and return only the new text.

        Args:
            prompt: Full chat-formatted prompt.

        Returns:
            The decoded completion with special tokens removed and surrounding
            whitespace stripped.
        """
        inputs = tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=3072,
            padding=False
        )
        # The model is placed by device_map="auto"; the inputs must live on
        # the same device as the model's first parameters.
        inputs = inputs.to(model.device)

        generation_config = GenerationConfig(
            max_new_tokens=MAX_TOKENS,
            temperature=0.3,
            top_p=0.9,
            do_sample=True,
            pad_token_id=tokenizer.pad_token_id
        )

        with torch.no_grad():
            outputs = model.generate(
                inputs.input_ids,
                generation_config=generation_config,
                attention_mask=inputs.attention_mask
            )

        # generate() returns prompt + completion for a causal LM. Decode only
        # the tokens after the prompt: decoding everything echoes the system
        # prompt (which itself contains "Final Answer:"), and splitting on
        # "<|assistant|>" is unreliable once special tokens are skipped.
        prompt_length = inputs.input_ids.shape[1]
        completion_ids = outputs[0][prompt_length:]
        return tokenizer.decode(completion_ids, skip_special_tokens=True).strip()

    def _parse_tool_call(self, text: str) -> Optional[Tuple[str, Dict]]:
        """Extract a ``(tool_name, args)`` pair from a fenced JSON block.

        Args:
            text: Model output that may contain a json-fenced object.

        Returns:
            ``(tool, args)`` when a JSON object with both "tool" and "args"
            keys is found; ``None`` when there is no block, the JSON is
            malformed, or a key is missing.
        """
        json_match = re.search(r'```json\s*({.+?})\s*```', text, re.DOTALL)
        if not json_match:
            return None
        try:
            tool_call = json.loads(json_match.group(1))
        except json.JSONDecodeError:
            return None
        if "tool" in tool_call and "args" in tool_call:
            return tool_call["tool"], tool_call["args"]
        return None

    def _use_tool(self, tool_name: str, args: Dict) -> str:
        """Run a registered tool with model-supplied arguments.

        Args:
            tool_name: Key into ``self.tools``.
            args: Keyword arguments for the tool, as parsed from the model.

        Returns:
            The tool output truncated to ``MAX_RESULT_LENGTH`` characters,
            "Unknown tool: ..." or "Tool error: ...". Never raises.
        """
        if tool_name not in self.tools:
            return f"Unknown tool: {tool_name}"

        try:
            # Special handling for URL-containing questions
            if tool_name == "read_webpage" and "url" not in args:
                if "args" in args and isinstance(args["args"], dict) and "url" in args["args"]:
                    args = args["args"]
                else:
                    # Look inside the individual values rather than str(args):
                    # the dict repr would glue the closing quote/brace onto
                    # the matched URL.
                    candidates = args.values() if isinstance(args, dict) else [args]
                    for candidate in candidates:
                        url_match = re.search(r'https?://\S+', str(candidate))
                        if url_match:
                            args = {"url": url_match.group()}
                            break

            tool = self.tools[tool_name]
            # The prompt's example uses the key "arg1", so the model often
            # sends a single argument under a name the tool does not accept.
            # When tool and call each have exactly one argument, map it over.
            if isinstance(args, dict) and len(args) == 1:
                param_names = list(inspect.signature(tool).parameters)
                if len(param_names) == 1 and param_names[0] not in args:
                    args = {param_names[0]: next(iter(args.values()))}

            return str(tool(**args))[:MAX_RESULT_LENGTH]
        except Exception as e:
            return f"Tool error: {str(e)}"


# --- Evaluation Runner ---
def run_and_submit_all(profile: gr.OAuthProfile | None):
    """Fetch all questions, answer them with the agent and submit the answers.

    Args:
        profile: Logged-in user's OAuth profile, or ``None`` when not logged in.

    Returns:
        A ``(status_message, results)`` pair where ``results`` is a
        ``pandas.DataFrame`` of task id / question / answer previews, or
        ``None`` when nothing was processed.
    """
    if not profile:
        return "Please login first", None

    agent = GAIA_Agent()
    questions_url = f"{DEFAULT_API_URL}/questions"
    submit_url = f"{DEFAULT_API_URL}/submit"

    try:
        response = requests.get(questions_url, timeout=15)
        response.raise_for_status()
        questions_data = response.json()
    except Exception as e:
        return f"Failed to get questions: {str(e)}", None

    results = []
    answers = []

    for i, item in enumerate(questions_data):
        task_id = item.get("task_id")
        question = item.get("question")

        if not task_id or not question:
            continue

        print(f"Processing question {i+1}/{len(questions_data)}")
        answer = agent(question)
        answers.append({"task_id": task_id, "submitted_answer": answer})
        results.append({
            "Task ID": task_id,
            "Question": question[:100] + "..." if len(question) > 100 else question,
            "Answer": answer[:100] + "..." if len(answer) > 100 else answer
        })

    submission = {
        "username": profile.username,
        "agent_code": f"https://huggingface.co/spaces/{os.getenv('SPACE_ID')}",
        "answers": answers
    }

    try:
        response = requests.post(submit_url, json=submission, timeout=30)
        # Without this a rejected submission would be reported as
        # "Submitted! Score: N/A".
        response.raise_for_status()
        result = response.json()
        return f"Submitted! Score: {result.get('score', 'N/A')}", pd.DataFrame(results)
    except Exception as e:
        return f"Submission failed: {str(e)}", pd.DataFrame(results)


# --- Gradio Interface ---
with gr.Blocks(title="Enhanced GAIA Agent") as demo:
    gr.Markdown("## 🚀 Enhanced GAIA Agent Evaluation")
    gr.Markdown("""
    Improved version with:
    - Better tool utilization
    - Increased step/token limits
    - Enhanced error handling
    """)

    with gr.Row():
        gr.LoginButton()
        run_btn = gr.Button("Run Evaluation", variant="primary")

    output_status = gr.Textbox(label="Status")
    results_table = gr.DataFrame(label="Results")

    run_btn.click(
        run_and_submit_all,
        outputs=[output_status, results_table]
    )

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)