Spaces:
Runtime error
Runtime error
import os | |
import gradio as gr | |
import requests | |
import json | |
import re | |
import numexpr | |
import pandas as pd | |
import math | |
from pdfminer.high_level import extract_text | |
from bs4 import BeautifulSoup | |
from typing import Dict, Any, List, Tuple, Optional | |
from dotenv import load_dotenv | |
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig | |
import torch | |
import time | |
import gc | |
# --- Load Environment Variables ---
# Copy key=value pairs from a local .env file (if present) into os.environ first,
# so the lookup below also sees keys defined there.
load_dotenv()
SERPER_API_KEY: Optional[str] = os.getenv("SERPER_API_KEY")  # None when unset; web_search then skips the Serper API

# --- Constants ---
DEFAULT_API_URL: str = "https://agents-course-unit4-scoring.hf.space"  # base URL for /questions and /submit
MODEL_NAME: str = "microsoft/Phi-3-mini-4k-instruct"
MAX_STEPS: int = 6  # model calls allowed per question
MAX_TOKENS: int = 256  # max_new_tokens for each model call
TIMEOUT_PER_QUESTION: int = 45  # seconds; checked before each agent step
MAX_RESULT_LENGTH: int = 500  # character cap applied to tool outputs
# --- Model Loading ---
print("Loading optimized model...")
start_time = time.time()
# float32 costs 4 bytes per parameter, twice what bf16/fp16 need. On a CUDA GPU,
# load in half precision (bf16 when the card supports it). On CPU keep float32,
# because half-precision CPU kernels are limited.
if torch.cuda.is_available():
    _model_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
else:
    _model_dtype = torch.float32
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    trust_remote_code=True,
    torch_dtype=_model_dtype,
    device_map="auto",
    low_cpu_mem_usage=True,
)
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_NAME,
    use_fast=True,
    trust_remote_code=True,
)
# generate() needs a pad id; reuse EOS when the tokenizer defines no pad token.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
print(f"Model loaded in {time.time() - start_time:.2f} seconds")
# --- Enhanced Tools --- | |
def web_search(query: str) -> str:
    """Search the web and return up to three "title: snippet" lines.

    Uses the Serper Google-search API when ``SERPER_API_KEY`` is set. Otherwise it
    falls back to DuckDuckGo through the optional ``duckduckgo_search`` package.
    The joined text is truncated to ``MAX_RESULT_LENGTH`` characters.

    Args:
        query: Free-text search query.

    Returns:
        One of:
        - the formatted results;
        - "No relevant results found" when the backend returned nothing usable;
        - a "Search error: ..." message on failure.
        Errors are returned as text, never raised, so the agent can read them.
    """
    try:
        if SERPER_API_KEY:
            response = requests.post(
                "https://google.serper.dev/search",
                headers={"X-API-KEY": SERPER_API_KEY},
                json={"q": query, "num": 3, "hl": "en", "gl": "us"},
                timeout=10,
            )
            # Surface auth/quota failures instead of reporting "no results".
            response.raise_for_status()
            hits = response.json().get("organic") or []
            lines = [
                f"{hit['title']}: {hit['snippet']}"
                for hit in hits[:3]
                if "title" in hit and "snippet" in hit
            ]
        else:
            # Imported lazily so duckduckgo_search stays an optional dependency.
            # Without it, this branch reports an error instead of raising NameError.
            try:
                from duckduckgo_search import DDGS
            except ImportError:
                return (
                    "Search error: SERPER_API_KEY is not set and the "
                    "duckduckgo_search package is not installed"
                )
            with DDGS() as ddgs:
                hits = list(ddgs.text(query, max_results=3))
            lines = [f"{hit.get('title', '')}: {hit.get('body', '')}" for hit in hits]
        # An empty or fully filtered result list must not be returned as "".
        if not lines:
            return "No relevant results found"
        return "\n".join(lines)[:MAX_RESULT_LENGTH]
    except Exception as e:
        return f"Search error: {str(e)}"
def calculator(expression: str) -> str:
    """Evaluate an arithmetic expression with numexpr and return the result as text.

    Sanitisation rules:
    - Characters other than digits, ``+ - * / ( ) . ^ % ,`` and whitespace are
      dropped, so prose such as "What is 5+3?" still evaluates.
    - An ``e``/``E`` scientific-notation exponent such as "1.5e3" is kept.
    - ``^`` means exponentiation.
    - ``%`` means "per cent" (``/100``).
    - Commas are thousands separators.

    Args:
        expression: The expression to evaluate, possibly embedded in text.

    Returns:
        ``str(float(result))`` on success, "Invalid empty expression" when nothing
        evaluable remains, or a "Calculation error: ..." message.
    """
    try:
        # Drop disallowed characters, but keep e/E when it is an exponent marker:
        # preceded by a digit or '.', followed by an optional sign and a digit.
        # Stripping it would silently turn "1e3" into "13".
        expression = re.sub(
            r"(?<![\d.])[eE]|[eE](?![+-]?\d)|[^\d+\-*/().^%,\seE]",
            "",
            expression,
        )
        if not expression.strip():
            return "Invalid empty expression"
        # numexpr evaluates '^' as bitwise XOR (2^3 == 1); callers mean power.
        expression = expression.replace("^", "**")
        # Handle percentages and thousands separators.
        expression = expression.replace("%", "/100").replace(",", "")
        result = numexpr.evaluate(expression)
        return str(float(result))
    except Exception as e:
        return f"Calculation error: {str(e)}"
def read_pdf(file_path: str) -> str:
    """Extract text from a PDF, collapse whitespace, and truncate it.

    Args:
        file_path: Local path to the PDF, or an ``http(s)://`` URL. A URL is
            downloaded into memory before extraction.

    Returns:
        One of:
        - up to ``MAX_RESULT_LENGTH`` characters of text;
        - "No readable text found in PDF" when only whitespace was extracted
          (e.g. a scanned, image-only PDF);
        - a "PDF read error: ..." message.
    """
    import io  # only needed to wrap downloaded bytes for pdfminer

    try:
        source = file_path
        if isinstance(file_path, str) and file_path.lower().startswith(("http://", "https://")):
            response = requests.get(file_path, timeout=15, headers={"User-Agent": "Mozilla/5.0"})
            response.raise_for_status()
            source = io.BytesIO(response.content)
        # Collapse runs of whitespace (newlines, page-break characters) into one
        # space. Check emptiness AFTER this step so whitespace-only output
        # is reported as unreadable.
        text = re.sub(r"\s+", " ", extract_text(source) or "").strip()
        if not text:
            return "No readable text found in PDF"
        return text[:MAX_RESULT_LENGTH]
    except Exception as e:
        return f"PDF read error: {str(e)}"
def read_webpage(url: str) -> str:
    """Fetch a web page and return its visible text, truncated to ``MAX_RESULT_LENGTH``.

    ``script``, ``style``, ``nav`` and ``footer`` elements are removed before
    the text is extracted. A URL without a scheme (e.g. "example.com/page")
    is fetched over https.

    Args:
        url: Address of the page to read.

    Returns:
        The page text, "No main content found" when nothing remains, or a
        "Webpage read error: ..." message on any failure.
    """
    try:
        url = url.strip()
        # requests rejects scheme-less URLs (MissingSchema); default to https.
        if not re.match(r"^[a-zA-Z][a-zA-Z0-9+.-]*://", url):
            url = "https://" + url
        response = requests.get(url, timeout=10, headers={"User-Agent": "Mozilla/5.0"})
        response.raise_for_status()
        # Pass raw bytes so BeautifulSoup can honour the page's own charset
        # declaration. response.text falls back to ISO-8859-1 for text/* responses
        # whose Content-Type header carries no charset, which garbles UTF-8 pages.
        soup = BeautifulSoup(response.content, "html.parser")
        # Remove non-content elements.
        for element in soup(["script", "style", "nav", "footer"]):
            element.decompose()
        text = soup.get_text(separator="\n", strip=True)
        text = re.sub(r"\n{3,}", "\n\n", text)
        return text[:MAX_RESULT_LENGTH] if text else "No main content found"
    except Exception as e:
        return f"Webpage read error: {str(e)}"
# Registry of agent tools: maps the name the model writes in its JSON tool call
# to the Python function that GAIA_Agent invokes.
TOOLS = dict(
    web_search=web_search,
    calculator=calculator,
    read_pdf=read_pdf,
    read_webpage=read_webpage,
)
# --- Improved GAIA Agent --- | |
class GAIA_Agent:
    """Tool-using agent that loops: prompt the model, run any requested tool, repeat.

    The loop stops when the model writes "Final Answer:" or when the step or
    time budget runs out. The agent relies on module-level ``model``,
    ``tokenizer``, ``TOOLS`` and the ``MAX_*`` / ``TIMEOUT_PER_QUESTION``
    constants.
    """

    # An http(s) URL inside free text or a dict's repr. Quotes, brackets, braces
    # and angle brackets end the match, so "{'link': 'https://x'}" yields
    # "https://x" rather than "https://x'}".
    _URL_PATTERN = re.compile(r"https?://[^\s'\"<>{}\[\]]+")

    def __init__(self):
        self.tools = TOOLS
        # The prompt lists each tool's exact argument name; the model cannot
        # guess them, and a wrong keyword makes the tool call fail.
        self.system_prompt = """You are an advanced GAIA problem solver. Follow these steps:
1. Analyze the question carefully
2. Choose the most appropriate tool
3. Process the results
4. Provide a precise final answer
Available Tools (argument names must match exactly):
- web_search: For general knowledge questions. Args: {"query": "search terms"}
- calculator: For math problems. Args: {"expression": "arithmetic expression"}
- read_pdf: For PDF content extraction. Args: {"file_path": "path to the PDF"}
- read_webpage: For webpage content extraction. Args: {"url": "page URL"}
To use a tool, reply with only the tool call and wait for its result:
```json
{"tool": "tool_name", "args": {"arg_name": "value"}}```
When you know the answer, end with: Final Answer: [your answer]"""

    def __call__(self, question: str) -> str:
        """Answer ``question``.

        Always returns a string. The result is the answer (at most 500 chars),
        a timeout/step-limit message, or "Error: ...".
        """
        start_time = time.time()
        history = [f"Question: {question}"]
        try:
            for _ in range(MAX_STEPS):
                # The time budget is checked between steps only; a running
                # generation is not interrupted.
                if time.time() - start_time > TIMEOUT_PER_QUESTION:
                    return "Timeout: Processing took too long"
                response = self._call_model(self._build_prompt(history))
                if "Final Answer:" in response:
                    answer = response.split("Final Answer:")[-1].strip()
                    return answer[:500]  # Limit answer length
                tool_call = self._parse_tool_call(response)
                if tool_call:
                    tool_name, args = tool_call
                    observation = self._use_tool(tool_name, args)
                    history.append(f"Tool Used: {tool_name}")
                    history.append(f"Tool Result: {observation[:300]}...")  # Truncate long results
                else:
                    history.append(f"Analysis: {response}")
                gc.collect()
            return "Maximum steps reached without final answer"
        except Exception as e:
            return f"Error: {str(e)}"

    def _build_prompt(self, history: List[str]) -> str:
        """Render the system prompt and history in Phi-3 chat format, ending with an open assistant turn."""
        user_turn = "\n".join(history)
        return (
            f"<|system|>\n{self.system_prompt}<|end|>\n"
            f"<|user|>\n{user_turn}<|end|>\n"
            "<|assistant|>\n"
        )

    def _call_model(self, prompt: str) -> str:
        """Generate a completion for ``prompt`` and return only the new text.

        Only the tokens produced after the prompt are decoded. Splitting the
        full decoded sequence on "<|assistant|>" is unreliable:
        ``skip_special_tokens=True`` removes that marker when the tokenizer
        registers it as a special token. The whole prompt would then leak into
        the response, including the system prompt's "Final Answer:" text, which
        would end the loop on the first step.
        """
        inputs = tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=3072,
            padding=False,
        ).to(model.device)  # device_map="auto" may place the model on a GPU
        # Stop at the end-of-turn marker as well as EOS, so the model does not
        # run on into an invented next turn.
        stop_ids = [tokenizer.eos_token_id]
        end_id = tokenizer.convert_tokens_to_ids("<|end|>")
        if isinstance(end_id, int) and end_id not in (tokenizer.unk_token_id, tokenizer.eos_token_id):
            stop_ids.append(end_id)
        generation_config = GenerationConfig(
            max_new_tokens=MAX_TOKENS,
            temperature=0.3,
            top_p=0.9,
            do_sample=True,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=stop_ids,
        )
        with torch.no_grad():
            outputs = model.generate(
                inputs.input_ids,
                generation_config=generation_config,
                attention_mask=inputs.attention_mask,
            )
        new_tokens = outputs[0][inputs.input_ids.shape[-1]:]
        return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

    def _parse_tool_call(self, text: str) -> Optional[Tuple[str, Any]]:
        """Return ``(tool_name, args)`` from the first qualifying JSON object in ``text``.

        A qualifying object has a string "tool" key and an "args" key. It may
        sit inside a ```json fence or appear bare. Every "{" is tried as a start
        position, so surrounding prose or a malformed earlier object does not
        hide a valid call. Returns None when no such object exists.
        """
        decoder = json.JSONDecoder()
        for brace in re.finditer(r"\{", text):
            try:
                candidate, _ = decoder.raw_decode(text, brace.start())
            except ValueError:  # includes json.JSONDecodeError
                continue
            if (
                isinstance(candidate, dict)
                and isinstance(candidate.get("tool"), str)
                and "args" in candidate
            ):
                return candidate["tool"], candidate["args"]
        return None

    @staticmethod
    def _extract_url(value: Any) -> Optional[str]:
        """Return the first http(s) URL in ``str(value)``, without trailing sentence punctuation, or None."""
        match = GAIA_Agent._URL_PATTERN.search(str(value))
        return match.group().rstrip(".,;:!?") if match else None

    def _use_tool(self, tool_name: str, args: Any) -> str:
        """Run ``tool_name`` with ``args`` and return its output as text (max ``MAX_RESULT_LENGTH`` chars).

        How ``args`` is passed:
        - a dict is passed as keyword arguments;
        - a list is passed positionally;
        - ``None`` means no arguments;
        - any other value becomes the single positional argument, because models
          often write e.g. ``"args": "2+2"``.

        Failures are returned as "Tool error: ..." text.
        """
        if tool_name not in self.tools:
            return f"Unknown tool: {tool_name}"
        try:
            # read_webpage needs a "url" keyword. Recover it from a nested "args"
            # dict or from any URL mentioned in the arguments.
            if tool_name == "read_webpage" and not (isinstance(args, dict) and "url" in args):
                nested = args.get("args") if isinstance(args, dict) else None
                if isinstance(nested, dict) and "url" in nested:
                    args = nested
                else:
                    url = self._extract_url(args)
                    if url:
                        args = {"url": url}
            tool = self.tools[tool_name]
            if isinstance(args, dict):
                result = tool(**args)
            elif isinstance(args, list):
                result = tool(*args)
            elif args is None:
                result = tool()
            else:
                result = tool(args)
            return str(result)[:MAX_RESULT_LENGTH]
        except Exception as e:
            return f"Tool error: {str(e)}"
# --- Evaluation Runner --- | |
def run_and_submit_all(profile: gr.OAuthProfile | None):
    """Fetch all questions, answer each with ``GAIA_Agent``, and submit the answers.

    Args:
        profile: Hugging Face OAuth profile injected by Gradio; None when the
            user is not logged in.

    Returns:
        ``(status_message, results)``. ``results`` is a DataFrame with one row
        per answered question, or None when no questions were processed.
    """
    if not profile:
        return "Please login first", None

    agent = GAIA_Agent()
    questions_url = f"{DEFAULT_API_URL}/questions"
    submit_url = f"{DEFAULT_API_URL}/submit"

    try:
        response = requests.get(questions_url, timeout=15)
        response.raise_for_status()
        questions_data = response.json()
    except Exception as e:
        return f"Failed to get questions: {str(e)}", None
    # An error payload such as {"detail": ...} must not be iterated key by key.
    if not isinstance(questions_data, list):
        return f"Failed to get questions: unexpected response {str(questions_data)[:200]}", None

    results = []
    answers = []
    for i, item in enumerate(questions_data):
        if not isinstance(item, dict):
            continue
        task_id = item.get("task_id")
        question = item.get("question")
        if not task_id or not question:
            continue
        print(f"Processing question {i+1}/{len(questions_data)}")
        answer = str(agent(question))
        answers.append({"task_id": task_id, "submitted_answer": answer})
        results.append({
            "Task ID": task_id,
            "Question": question[:100] + "..." if len(question) > 100 else question,
            "Answer": answer[:100] + "..." if len(answer) > 100 else answer,
        })
    results_df = pd.DataFrame(results)

    if not answers:
        return "No answers were produced; nothing submitted.", results_df

    submission = {
        "username": profile.username,
        "agent_code": f"https://huggingface.co/spaces/{os.getenv('SPACE_ID')}",
        "answers": answers,
    }
    try:
        response = requests.post(submit_url, json=submission, timeout=30)
        # Without this, a rejected submission was reported as "Score: N/A".
        response.raise_for_status()
        result = response.json()
        score = result.get("score", "N/A")
    except requests.exceptions.HTTPError as e:
        # Include the server's explanation (e.g. a validation message).
        detail = e.response.text[:300] if e.response is not None else ""
        return f"Submission failed: {str(e)} {detail}".strip(), results_df
    except Exception as e:
        return f"Submission failed: {str(e)}", results_df
    return f"Submitted! Score: {score}", results_df
# --- Gradio Interface --- | |
# Body text shown under the page title.
_INTRO_MARKDOWN = """
Improved version with:
- Better tool utilization
- Increased step/token limits
- Enhanced error handling
"""

with gr.Blocks(title="Enhanced GAIA Agent") as demo:
    gr.Markdown("## 🚀 Enhanced GAIA Agent Evaluation")
    gr.Markdown(_INTRO_MARKDOWN)
    # Login and the run trigger share one row; Gradio supplies the OAuth
    # profile to run_and_submit_all through its parameter annotation.
    with gr.Row():
        gr.LoginButton()
        run_btn = gr.Button("Run Evaluation", variant="primary")
    output_status = gr.Textbox(label="Status")
    results_table = gr.DataFrame(label="Results")
    run_btn.click(fn=run_and_submit_all, outputs=[output_status, results_table])

if __name__ == "__main__":
    # Listen on all interfaces, port 7860.
    demo.launch(server_port=7860, server_name="0.0.0.0")