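"""Tests for the TMLU dataset parser.

Covers TMLUParseEntry creation and validation, entry processing, parser
initialization and string representation, dataset description fields, and
evaluation-metric definitions. Tests marked ``@pytest.mark.integration``
load the real ``miulab/tmlu`` dataset from Hugging Face and may require
network access.
"""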
import pytest

from llmdataparser.tmlu_parser import TMLUDatasetParser, TMLUParseEntry


@pytest.fixture
def tmlu_parser():
    """Create a TMLU parser instance for testing."""
    return TMLUDatasetParser()


@pytest.fixture
def sample_tmlu_entries():
    """Create sample TMLU dataset entries for testing."""
    return [
        {
            "question": "閱讀下文,選出依序最適合填入□內的選項:",
            "A": "張揚/綢繆未雨/奏疏",
            "B": "抽搐/煮繭抽絲/奏疏",
            "C": "張揚/煮繭抽絲/進貢",
            "D": "抽搐/綢繆未雨/進貢",
            "answer": "B",
            "explanation": "根據文意,選項B最為恰當。",
            "metadata": {
                "timestamp": "2023-10-09T18:27:20.304623",
                "source": "AST chinese - 108",
                "explanation_source": "",
            },
        },
        {
            "question": "下列何者是質數?",
            "A": "21",
            "B": "27",
            "C": "31",
            "D": "33",
            "answer": "C",
            "explanation": "31是質數,其他選項都是合數。",
            "metadata": {
                "timestamp": "2023-10-09T18:27:20.304623",
                "source": "AST mathematics - 108",
                "explanation_source": "",
            },
        },
    ]


def test_tmlu_parse_entry_creation_valid():
    """Test valid creation of TMLUParseEntry."""
    entry = TMLUParseEntry.create(
        prompt="Test prompt",
        answer="A",
        raw_question="Test question",
        raw_choices=["choice1", "choice2", "choice3", "choice4"],
        raw_answer="A",
        task_name="AST_chinese",
        explanation="Test explanation",
        metadata={"source": "test"},
    )
    assert isinstance(entry, TMLUParseEntry)
    assert entry.prompt == "Test prompt"
    assert entry.answer == "A"
    assert entry.raw_choices == ["choice1", "choice2", "choice3", "choice4"]
    assert entry.explanation == "Test explanation"
    assert entry.metadata == {"source": "test"}


@pytest.mark.parametrize("invalid_answer", ["E", "F", "1", "", None])
def test_tmlu_parse_entry_creation_invalid(invalid_answer):
    """Test invalid answer handling in TMLUParseEntry creation."""
    with pytest.raises(
        ValueError, match="Invalid answer_letter.*must be one of A, B, C, D"
    ):
        TMLUParseEntry.create(
            prompt="Test prompt",
            answer=invalid_answer,
            raw_question="Test question",
            raw_choices=["choice1", "choice2", "choice3", "choice4"],
            raw_answer=invalid_answer,
            task_name="AST_chinese",
        )


def test_process_entry(tmlu_parser, sample_tmlu_entries):
    """Test processing entries in TMLU parser."""
    entry = tmlu_parser.process_entry(sample_tmlu_entries[0], task_name="AST_chinese")

    assert isinstance(entry, TMLUParseEntry)
    assert entry.answer == "B"
    assert entry.task_name == "AST_chinese"
    assert len(entry.raw_choices) == 4
    assert entry.explanation == "根據文意,選項B最為恰當。"
    assert "AST chinese - 108" in entry.metadata["source"]


def test_tmlu_parser_initialization(tmlu_parser):
    """Test TMLU parser initialization and properties."""
    assert isinstance(tmlu_parser.task_names, list)
    assert len(tmlu_parser.task_names) == 37  # Total number of tasks
    assert tmlu_parser._data_source == "miulab/tmlu"
    assert tmlu_parser._default_task == "AST_chinese"
    assert "AST_chinese" in tmlu_parser.task_names
    assert "GSAT_mathematics" in tmlu_parser.task_names
    assert (
        tmlu_parser.get_huggingface_link
        == "https://huggingface.co/datasets/miulab/tmlu"
    )


@pytest.mark.integration
def test_load_dataset(tmlu_parser):
    """Test loading the TMLU dataset."""
    tmlu_parser.load(task_name="AST_chinese", split="test")
    assert tmlu_parser.raw_data is not None
    assert tmlu_parser.split_names == ["test"]
    assert tmlu_parser._current_task == "AST_chinese"


def test_parser_string_representation(tmlu_parser):
    """Test string representation of TMLU parser."""
    repr_str = str(tmlu_parser)
    assert "TMLUDatasetParser" in repr_str
    assert "miulab/tmlu" in repr_str
    assert "not loaded" in repr_str


@pytest.mark.integration
def test_different_tasks_parsing(tmlu_parser):
    """Test parsing different tasks of the dataset."""
    # Load and parse AST_chinese
    tmlu_parser.load(task_name="AST_chinese", split="test")
    tmlu_parser.parse(split_names="test", force=True)
    chinese_count = len(tmlu_parser.get_parsed_data)

    # Load and parse AST_mathematics
    tmlu_parser.load(task_name="AST_mathematics", split="test")
    tmlu_parser.parse(split_names="test", force=True)
    math_count = len(tmlu_parser.get_parsed_data)

    assert chinese_count > 0
    assert math_count > 0


def test_system_prompt_override():
    """Test overriding the default system prompt."""
    custom_prompt = "Custom system prompt for testing"
    parser = TMLUDatasetParser(system_prompt=custom_prompt)

    test_entry = {
        "question": "Test question",
        "A": "Choice A",
        "B": "Choice B",
        "C": "Choice C",
        "D": "Choice D",
        "answer": "A",
        "explanation": "Test explanation",
        "metadata": {"source": "test"},
    }
    entry = parser.process_entry(test_entry)
    assert custom_prompt in entry.prompt


def test_metadata_handling(tmlu_parser, sample_tmlu_entries):
    """Test proper handling of metadata in entries."""
    entry = tmlu_parser.process_entry(sample_tmlu_entries[0])

    assert "timestamp" in entry.metadata
    assert "source" in entry.metadata
    assert "explanation_source" in entry.metadata
    assert entry.metadata["source"] == "AST chinese - 108"


def test_dataset_description(tmlu_parser):
    """Test dataset description contains all required fields."""
    description = tmlu_parser.get_dataset_description()

    required_fields = [
        "name",
        "version",
        "language",
        "purpose",
        "source",
        "format",
        "size",
        "domain",
        "characteristics",
        "reference",
    ]
    for field in required_fields:
        assert field in description, f"Missing required field: {field}"

    assert description["language"] == "Traditional Chinese"
    assert "TMLU" in description["name"]
    assert "miulab/tmlu" in description["reference"]
    assert "AST" in description["characteristics"]
    assert "GSAT" in description["characteristics"]


def test_evaluation_metrics(tmlu_parser):
    """Test evaluation metrics structure and content."""
    metrics = tmlu_parser.get_evaluation_metrics()

    # Check if we have metrics defined
    assert len(metrics) > 0

    # Check structure of each metric
    required_metric_fields = [
        "name",
        "type",
        "description",
        "implementation",
        "primary",
    ]
    for metric in metrics:
        for field in required_metric_fields:
            assert field in metric, f"Missing required field in metric: {field}"

        # Type checks
        assert isinstance(metric["name"], str)
        assert isinstance(metric["type"], str)
        assert isinstance(metric["description"], str)
        assert isinstance(metric["implementation"], str)
        assert isinstance(metric["primary"], bool)

    # Check for TMLU-specific metrics
    metric_names = {m["name"] for m in metrics}
    expected_metrics = {
        "accuracy",
        "per_subject_accuracy",
        "per_difficulty_accuracy",
        "explanation_quality",
    }
    for expected in expected_metrics:
        assert expected in metric_names, f"Missing expected metric: {expected}"

    # Verify primary metrics
    primary_metrics = [m for m in metrics if m["primary"]]
    # Should have at least accuracy and per_subject_accuracy
    assert len(primary_metrics) >= 2
    assert any(m["name"] == "accuracy" for m in primary_metrics)
    assert any(m["name"] == "per_subject_accuracy" for m in primary_metrics)
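

# Convenience entry point: allows running this test module directly with
# ``python`` as an alternative to invoking pytest on it.
if __name__ == "__main__":
    raise SystemExit(pytest.main([__file__]))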