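"""Tests for gagent agents against GAIA dataset questions.

Run with, for example (the test path is assumed here):

    pytest tests/agents/test_agents.py -s

The -s flag surfaces the per-question print output below.
"""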
import json

import pytest

from gagent.agents import BaseAgent, GeminiAgent
from tests.agents.fixtures import (
    agent_factory,
    ollama_agent,
    gemini_agent,
    openai_agent,
)
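
# Importing the fixtures makes them discoverable by pytest in this module.
# For orientation, agent_factory is assumed to be shaped roughly like this
# (a sketch, not the actual tests/agents/fixtures.py):
#
#   @pytest.fixture
#   def agent_factory():
#       def _make(agent_type: str, model_name: str) -> BaseAgent:
#           ...  # construct and return the requested agent
#       return _make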


class TestAgents:
    """Test suite for agents with GAIA data."""

    @staticmethod
    def load_questions():
        """Load questions from questions.json file."""
        with open("exp/questions.json", "r") as f:
            return json.load(f)

    @staticmethod
    def load_validation_data():
        """Load validation data from GAIA dataset metadata."""
        validation_data = {}

        with open("metadata.jsonl", "r") as f:
            for line in f:
                data = json.loads(line)
                validation_data[data["task_id"]] = data["Final answer"]

        return validation_data

    def _run_agent_test(self, agent: BaseAgent, num_questions: int = 2):
        """
        Common test implementation shared by all agent types.

        Args:
            agent: The agent to test
            num_questions: Number of questions to test (default: 2)

        Returns:
            Tuple of (correct_count, total_tested)
        """
        questions = self.load_questions()
        validation_data = self.load_validation_data()

        # Limit number of questions for testing
        questions = questions[:num_questions]

        # Keep track of correct answers
        correct_count = 0
        total_tested = 0
        total_questions = len(questions)
        for i, question_data in enumerate(questions):
            task_id = question_data["task_id"]
            if task_id not in validation_data:
                continue

            question = question_data["question"]
            expected_answer = validation_data[task_id]

            print(f"Testing question {i + 1}: {question[:50]}...")

            # Call the agent with the question
            response = agent.run(question, question_number=i + 1, total_questions=total_questions)

            # Extract the final answer from the response. Agents are expected
            # to finish with "FINAL ANSWER: <answer>"; take the text after the
            # last marker and fall back to the whole response when it is absent.
            if "FINAL ANSWER:" in response:
                answer = response.rsplit("FINAL ANSWER:", 1)[1].strip()
            else:
                answer = response.strip()

            # Check correctness via exact string match against the ground truth
            is_correct = answer == expected_answer
            if is_correct:
                correct_count += 1

            total_tested += 1

            print(f"Expected: {expected_answer}")
            print(f"Got: {answer}")
            print(f"Result: {'✓' if is_correct else '✗'}")
            print("-" * 80)

        # Compute accuracy
        accuracy = correct_count / total_tested if total_tested > 0 else 0
        print(f"Accuracy: {accuracy:.2%} ({correct_count}/{total_tested})")

        return correct_count, total_tested

    # def test_ollama_agent_with_gaia_data(self, ollama_agent: BaseAgent):
    #     """Test the Ollama agent with GAIA dataset questions and validate against ground truth."""
    #     correct_count, total_tested = self._run_agent_test(ollama_agent)

    #     # At least one correct answer is required to pass the test
    #     assert correct_count > 0, "Agent should get at least one answer correct"

    # def test_gemini_agent_with_gaia_data(self, gemini_agent: GeminiAgent):
    #     """Test the Gemini agent with the same GAIA test approach."""
    #     correct_count, total_tested = self._run_agent_test(gemini_agent, num_questions=2)

    #     # At least one correct answer is required to pass the test
    #     assert correct_count > 0, "Agent should get at least one answer correct"

    @pytest.mark.parametrize("agent_type,model_name", [("ollama", "phi4-mini")])
    def test_ollama_with_different_model(self, agent_factory, agent_type, model_name):
        """Test the Ollama agent with an alternative model."""
        agent = agent_factory(agent_type=agent_type, model_name=model_name)
        correct_count, total_tested = self._run_agent_test(agent, num_questions=3)

        # Only verify that questions were actually run, not accuracy
        assert total_tested > 0, "Should have tested at least one question"

    # Can be uncommented when OpenAI API key is available
    # def test_openai_agent_with_gaia_data(self, openai_agent: BaseAgent):
    #     """Test the OpenAI agent with the same GAIA test approach."""
    #     correct_count, total_tested = self._run_agent_test(openai_agent)
    #     assert correct_count > 0, "Agent should get at least one answer correct"