LamiaYT's picture
fixing
aa6f3a8
raw
history blame
10.1 kB
import os
import gradio as gr
import requests
import json
import re
import numexpr
import pandas as pd
from pdfminer.high_level import extract_text
from bs4 import BeautifulSoup
from typing import List, Dict, Optional, Tuple
from dotenv import load_dotenv
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import time
import gc
# --- Configuration ---
# Load variables from a .env file (if one exists) before the environment is read.
load_dotenv()

# API key for the Serper search endpoint; None when unset, which web_search
# reports as "Search API key not configured".
SERPER_API_KEY = os.environ.get("SERPER_API_KEY")
# Checkpoint passed to the model and tokenizer loaders.
MODEL_NAME = "microsoft/Phi-3-mini-4k-instruct"
# Base URL of the scoring service (exposes /questions and /submit).
DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"

# --- Constants ---
MAX_STEPS = 6  # upper bound on model/tool iterations per question
MAX_TOKENS = 256  # max_new_tokens for each generation call
TIMEOUT_PER_QUESTION = 45  # seconds of wall-clock time allowed per question
MAX_RESULT_LENGTH = 500  # tool outputs are cut to this many characters
MAX_ATTEMPTS = 2  # generation attempts before a model error is reported
# --- Model Initialization ---
print("Initializing model with fixed cache configuration...")
start_time = time.time()

# Full-precision weights; device placement is left to device_map="auto".
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    device_map="auto",
    torch_dtype=torch.float32,
    low_cpu_mem_usage=True,
    trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_NAME,
    trust_remote_code=True,
    use_fast=True,
)

# Generation is later called with pad_token_id=tokenizer.pad_token_id, so fall
# back to the end-of-sequence token when the tokenizer defines no pad token.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

print(f"Model loaded in {time.time() - start_time:.2f} seconds")
# --- Tool Implementations ---
def web_search(query: str) -> str:
    """Search the web through the Serper API and return a short text digest.

    Args:
        query: Free-text search query.

    Returns:
        Up to three "Title/Snippet" entries separated by blank lines and cut
        to MAX_RESULT_LENGTH characters. When the API key is missing, no
        usable result exists, or the request fails, a human-readable message
        is returned instead. This function never raises.
    """
    try:
        if not SERPER_API_KEY:
            return "Search API key not configured"

        response = requests.post(
            'https://google.serper.dev/search',
            headers={'X-API-KEY': SERPER_API_KEY},
            json={'q': query, 'num': 3},
            timeout=10,
        )
        response.raise_for_status()
        results = response.json()
        if 'organic' not in results or not results['organic']:
            return "No relevant results found"

        entries = [
            f"Title: {hit['title']}\nSnippet: {hit['snippet']}"
            for hit in results['organic'][:3]
            if 'title' in hit and 'snippet' in hit
        ]
        # Results that lack a title or snippet are skipped; if that leaves
        # nothing, say so instead of handing the agent an empty string.
        if not entries:
            return "No relevant results found"
        return "\n\n".join(entries)[:MAX_RESULT_LENGTH]
    except Exception as e:
        # Tool boundary: the agent expects a string observation, not an exception.
        return f"Search error: {str(e)}"
def calculator(expression: str) -> str:
    """Evaluate an arithmetic expression and return the result as text.

    Every character other than digits, ``+ - * / ( ) . ^ % ,`` and whitespace
    is removed before evaluation, so no identifiers ever reach the evaluator.

    Args:
        expression: Arithmetic expression, e.g. ``"2 * (3 + 4)"``. A caret is
            interpreted as exponentiation (``2^3`` is 8).

    Returns:
        The evaluated result converted with ``str``, "Invalid empty
        expression" when nothing evaluable remains after filtering, or
        "Calculation error: ..." when evaluation fails. Never raises.
    """
    try:
        # Strip surrounding whitespace too, so blank input is reported as
        # empty instead of being passed on to the evaluator.
        cleaned = re.sub(r'[^\d+\-*/().^%,\s]', '', expression).strip()
        if not cleaned:
            return "Invalid empty expression"
        # numexpr reads ^ as bitwise XOR; callers asking for math mean a power.
        cleaned = cleaned.replace('^', '**')
        return str(numexpr.evaluate(cleaned))
    except Exception as e:
        # Tool boundary: report the failure as an observation string.
        return f"Calculation error: {str(e)}"
def read_webpage(url: str) -> str:
    """Download a web page and return the start of its readable text.

    Args:
        url: Address of the page; it has to begin with ``http://`` or ``https://``.

    Returns:
        Text of the ``<main>`` element, else the ``<article>`` element, else
        the whole document, with script/style/nav/footer/aside removed, runs
        of three or more newlines collapsed to two, and the result cut to
        MAX_RESULT_LENGTH characters. "Invalid URL format" is returned for
        any other scheme and "Webpage error: ..." for any failure.
    """
    try:
        if re.match(r'^https?://', url) is None:
            return "Invalid URL format"

        page = requests.get(url, timeout=15, headers={'User-Agent': 'Mozilla/5.0'})
        page.raise_for_status()

        document = BeautifulSoup(page.text, 'html.parser')
        # Remove elements that carry no article text.
        for noise in document.find_all(['script', 'style', 'nav', 'footer', 'aside']):
            noise.decompose()

        container = document.find('main') or document.find('article') or document
        extracted = container.get_text(separator='\n', strip=True)
        collapsed = re.sub(r'\n{3,}', '\n\n', extracted)
        return collapsed[:MAX_RESULT_LENGTH]
    except Exception as e:
        return f"Webpage error: {str(e)}"
# Registry of callable tools, keyed by the name the model uses in a tool call.
TOOLS = dict(
    web_search=web_search,
    calculator=calculator,
    read_webpage=read_webpage,
)
# --- GAIA Agent Class ---
class GAIA_Agent:
    """Question-answering agent that alternates model calls and tool calls.

    Each step sends the conversation so far to the language model. A reply
    containing "Final Answer:" ends the run; a reply containing a JSON tool
    call runs that tool and records its result; anything else is recorded as
    a thought. The loop is bounded by MAX_STEPS and TIMEOUT_PER_QUESTION.
    """

    def __init__(self):
        self.tools = TOOLS
        self.system_prompt = """You are an advanced problem solver. Follow these steps:
1. Analyze the question
2. Select the best tool
3. Execute with proper arguments
4. Interpret results
5. Provide final answer
Tools:
- web_search(query): For general knowledge
- calculator(expression): For math
- read_webpage(url): For web content
Tool format: ```json
{"tool": "tool_name", "args": {"arg": value}}```
Always conclude with: Final Answer: [answer]"""

    def __call__(self, question: str) -> str:
        """Answer *question* and return the answer text.

        Returns:
            The text after the last "Final Answer:" in a model reply (at most
            500 characters), or one of the status strings "Timeout: ...",
            "Maximum steps reached", "Agent error: ...". Never raises.
        """
        start_time = time.time()
        history = [f"Question: {question}"]
        try:
            for _ in range(MAX_STEPS):
                # The deadline is only checked between steps; a step that is
                # already running is not interrupted.
                if time.time() - start_time > TIMEOUT_PER_QUESTION:
                    return "Timeout: Processing took too long"

                response = self._call_model(self._build_prompt(history))
                if "Final Answer:" in response:
                    return response.split("Final Answer:")[-1].strip()[:500]

                tool_call = self._parse_tool_call(response)
                if tool_call:
                    tool_name, args = tool_call
                    observation = self._use_tool(tool_name, args)
                    history.append(f"Tool: {tool_name}")
                    history.append(f"Result: {observation[:300]}...")
                else:
                    history.append(f"Thought: {response}")
                gc.collect()
            return "Maximum steps reached"
        except Exception as e:
            return f"Agent error: {str(e)}"

    def _build_prompt(self, history: List[str]) -> str:
        """Render the system prompt and newline-joined history as one chat prompt."""
        return f"<|system|>\n{self.system_prompt}<|end|>\n<|user|>\n" + "\n".join(history) + "<|end|>\n<|assistant|>"

    def _call_model(self, prompt: str) -> str:
        """Generate a completion for *prompt* and return only the new text.

        Generation is attempted up to MAX_ATTEMPTS times with a short pause
        between attempts; after the last failure "Model error: ..." is
        returned instead of raising.
        """
        for attempt in range(MAX_ATTEMPTS):
            try:
                inputs = tokenizer(
                    prompt,
                    return_tensors="pt",
                    truncation=True,
                    max_length=3072,
                    padding=False,
                ).to(model.device)  # keep inputs on the same device as the model
                prompt_length = inputs.input_ids.shape[-1]
                outputs = model.generate(
                    inputs.input_ids,
                    attention_mask=inputs.attention_mask,
                    max_new_tokens=MAX_TOKENS,
                    temperature=0.3,
                    top_p=0.9,
                    do_sample=True,
                    pad_token_id=tokenizer.pad_token_id,
                )
                # The output sequence starts with the prompt tokens. Decode
                # only what follows them: splitting the decoded text on
                # "<|assistant|>" is unreliable because skip_special_tokens
                # can remove that marker, which would leave the system prompt
                # (including its "Final Answer:" line) in the reply.
                completion = outputs[0][prompt_length:]
                return tokenizer.decode(completion, skip_special_tokens=True).strip()
            except Exception as e:
                if attempt < MAX_ATTEMPTS - 1:
                    time.sleep(0.5)
                    continue
                return f"Model error: {str(e)}"
        # Only reachable when MAX_ATTEMPTS is not positive.
        return "Model error: no generation attempt was made"

    def _parse_tool_call(self, text: str) -> Optional[Tuple[str, Dict]]:
        """Extract a ``(tool_name, args)`` pair from a model reply.

        A fenced ```json block is tried first. If that yields nothing, every
        JSON object embedded in the text is tried in order, which also covers
        replies without a fence or with a fence that was never closed.

        Returns:
            The tool name and its argument dict, or None when the reply holds
            no object with a "tool" key and a dict-valued "args" key.
        """
        if not isinstance(text, str):
            return None

        fenced = re.search(r'```json\s*({.+?})\s*```', text, re.DOTALL)
        if fenced:
            try:
                call = self._as_tool_call(json.loads(fenced.group(1)))
            except ValueError:  # json.JSONDecodeError is a ValueError
                call = None
            if call is not None:
                return call

        decoder = json.JSONDecoder()
        start = text.find('{')
        while start != -1:
            try:
                candidate, _ = decoder.raw_decode(text, start)
            except ValueError:
                candidate = None
            call = self._as_tool_call(candidate)
            if call is not None:
                return call
            start = text.find('{', start + 1)
        return None

    @staticmethod
    def _as_tool_call(candidate) -> Optional[Tuple[str, Dict]]:
        """Return ``(tool, args)`` if *candidate* has the tool-call shape, else None."""
        if not isinstance(candidate, dict):
            return None
        if "tool" not in candidate or "args" not in candidate:
            return None
        if not isinstance(candidate["args"], dict):
            return None
        return candidate["tool"], candidate["args"]

    @staticmethod
    def _recover_url(args: Dict) -> Optional[str]:
        """Find an http(s) URL among *args* when no "url" key was supplied."""
        # Look inside string values first so that the quotes and braces of the
        # dict's repr can never become part of the URL.
        for value in args.values():
            if isinstance(value, str):
                found = re.search(r'https?://[^\s]+', value)
                if found:
                    return found.group()
        found = re.search(r'https?://[^\s\'"]+', str(args))
        return found.group() if found else None

    @staticmethod
    def _bind_args(tool, args: Dict) -> Dict:
        """Rebind a single mis-named argument to a one-parameter tool.

        The prompt's format example uses the placeholder key "arg", so a
        reply may carry the right value under the wrong name.
        """
        import inspect  # local import: only this helper needs it

        try:
            parameters = list(inspect.signature(tool).parameters)
        except (TypeError, ValueError):
            return args
        if len(parameters) == 1 and len(args) == 1 and parameters[0] not in args:
            (value,) = args.values()
            return {parameters[0]: value}
        return args

    def _use_tool(self, tool_name: str, args: Dict) -> str:
        """Run the named tool with *args* and return its output as text.

        Returns:
            The tool output cut to MAX_RESULT_LENGTH characters, or one of
            "Unknown tool: ...", "Missing URL argument", "Tool error: ...".
        """
        if not isinstance(tool_name, str) or tool_name not in self.tools:
            return f"Unknown tool: {tool_name}"
        tool = self.tools[tool_name]
        try:
            if tool_name == "read_webpage" and "url" not in args:
                url = self._recover_url(args)
                if url is None:
                    return "Missing URL argument"
                args = {"url": url}
            else:
                args = self._bind_args(tool, args)
            return str(tool(**args))[:MAX_RESULT_LENGTH]
        except Exception as e:
            return f"Tool error: {str(e)}"
# --- Evaluation Function ---
def _preview(text: str, limit: int = 100) -> str:
    """Return *text* cut to *limit* characters, followed by "..." when it was cut."""
    return text[:limit] + "..." if len(text) > limit else text


def run_evaluation(profile: gr.OAuthProfile | None):
    """Answer every question from the scoring service and submit the answers.

    Args:
        profile: OAuth profile of the logged-in user, or None when nobody is
            logged in.

    Returns:
        A ``(status, results)`` pair. ``status`` is a message for the UI.
        ``results`` is a DataFrame with one row per answered question
        (Task ID, Question, Answer; the texts shortened to 100 characters), or
        None when the run stopped before any question was processed.
    """
    if not profile:
        return "Please login first", None

    agent = GAIA_Agent()
    questions_url = f"{DEFAULT_API_URL}/questions"
    submit_url = f"{DEFAULT_API_URL}/submit"

    try:
        response = requests.get(questions_url, timeout=20)
        response.raise_for_status()
        questions_data = response.json()
        if not questions_data:
            return "No questions available", None
    except Exception as e:
        return f"Failed to get questions: {str(e)}", None

    results = []
    answers = []
    for i, item in enumerate(questions_data):
        # Skip malformed entries instead of failing the whole run.
        if not isinstance(item, dict):
            continue
        task_id = item.get("task_id")
        question = item.get("question")
        if not task_id or not question:
            continue
        print(f"Processing question {i+1}/{len(questions_data)}")
        answer = agent(question)
        answers.append({"task_id": task_id, "submitted_answer": answer})
        results.append({
            "Task ID": task_id,
            "Question": _preview(question),
            "Answer": _preview(answer),
        })

    # Nothing usable came back from the service: do not post an empty submission.
    if not answers:
        return "No valid questions to answer; nothing was submitted", pd.DataFrame(results)

    submission = {
        "username": profile.username,
        "agent_code": f"https://huggingface.co/spaces/{os.getenv('SPACE_ID')}",
        "answers": answers
    }
    try:
        response = requests.post(submit_url, json=submission, timeout=60)
        response.raise_for_status()
        result = response.json()
        status = (f"✅ Submission Successful!\n"
                  f"Score: {result.get('score', 'N/A')}%\n"
                  f"Correct: {result.get('correct_count', '?')}/{result.get('total_attempted', '?')}")
        return status, pd.DataFrame(results)
    except Exception as e:
        return f"❌ Submission failed: {str(e)}", pd.DataFrame(results)
# --- Gradio Interface ---
with gr.Blocks(title="Fixed GAIA Agent", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🚀 GAIA Agent Evaluation")

    # Login and run controls share one row.
    with gr.Row():
        gr.LoginButton()
        run_btn = gr.Button("Run Evaluation", variant="primary")

    # Output widgets filled from the (status, results) pair of run_evaluation.
    status_output = gr.Textbox(label="Status")
    results_table = gr.DataFrame(label="Results")

    run_btn.click(run_evaluation, outputs=[status_output, results_table])

if __name__ == "__main__":
    # Listen on all interfaces, port 7860.
    demo.launch(server_name="0.0.0.0", server_port=7860)