refactor: bbh/tmlu test case
- tests/test_bbh_parser.py +29 -65
- tests/test_tmlu_parser.py +26 -62
tests/test_bbh_parser.py
CHANGED
@@ -160,74 +160,38 @@ def test_different_tasks_parsing(bbh_parser, task_name):
     assert all(isinstance(entry.answer, str) for entry in parsed_data)


-def ...
-    """Test ...
-    metrics = bbh_parser.get_evaluation_metrics()
-
-    # Check basic structure
-    assert isinstance(metrics, list)
-    assert len(metrics) > 0
-
-    # Check each metric has required fields
-    required_fields = ["name", "type", "description", "implementation", "primary"]
-    for metric in metrics:
-        for field in required_fields:
-            assert field in metric, f"Missing field {field} in metric {metric['name']}"
-
-        # Check field types
-        assert isinstance(metric["name"], str)
-        assert isinstance(metric["type"], str)
-        assert isinstance(metric["description"], str)
-        assert isinstance(metric["implementation"], str)
-        assert isinstance(metric["primary"], bool)
-
-    # Check specific metrics exist
-    metric_names = {m["name"] for m in metrics}
-    expected_metrics = {
-        "accuracy",
-        "human_eval_delta",
-        "per_task_accuracy",
-        "exact_match",
-    }
-    assert expected_metrics.issubset(metric_names)
-
-    # Check primary metrics
-    primary_metrics = {m["name"] for m in metrics if m["primary"]}
-    assert "accuracy" in primary_metrics
-    assert "human_eval_delta" in primary_metrics
-
-
-def test_dataset_description_citation_format(bbh_parser):
-    """Test that the citation in dataset description is properly formatted."""
-    description = bbh_parser.get_dataset_description()
-    citation = description["citation"]
-
-    # Check citation structure
-    assert citation.startswith("@article{")
-    assert "title=" in citation
-    assert "author=" in citation
-    assert "journal=" in citation
-    assert "year=" in citation
-
-    assert ...
-    assert ...
-    assert ...
-    assert ...
-
-
-def ...
-    """Test ...
-    metrics = bbh_parser.get_evaluation_metrics()
-
-    ...
+def test_get_dataset_description(bbh_parser):
+    """Test dataset description generation."""
+    description = bbh_parser.get_dataset_description()
+
+    assert description.name == "Big Bench Hard (BBH)"
+    assert "challenging BIG-Bench tasks" in description.purpose
+    assert description.language == "English"
+    assert description.format == "Multiple choice questions with single correct answers"
+    assert "Tasks require complex multi-step reasoning" in description.characteristics
+    assert "suzgun2022challenging" in description.citation
+    assert description.additional_info is not None
+    assert "model_performance" in description.additional_info
+    assert "size" in description.additional_info
+
+
+def test_get_evaluation_metrics(bbh_parser):
+    """Test evaluation metrics generation."""
+    metrics = bbh_parser.get_evaluation_metrics()
+
+    assert len(metrics) == 4  # Check total number of metrics
+
+    # Check primary metrics
+    primary_metrics = [m for m in metrics if m.primary]
+    assert len(primary_metrics) == 2
+    assert any(m.name == "accuracy" for m in primary_metrics)
+    assert any(m.name == "human_eval_delta" for m in primary_metrics)
+
+    # Check specific metric properties
+    accuracy_metric = next(m for m in metrics if m.name == "accuracy")
+    assert accuracy_metric.type == "classification"
+    assert "evaluate.load('accuracy')" in accuracy_metric.implementation
+
+    # Check non-primary metrics
+    assert any(m.name == "per_task_accuracy" and not m.primary for m in metrics)
+    assert any(m.name == "exact_match" and not m.primary for m in metrics)
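The refactored assertions read attributes such as description.name and metric.primary instead of indexing into dicts, which suggests the parsers now return structured objects rather than plain dictionaries. A minimal sketch of what such containers could look like, assuming dataclasses named DatasetDescription and EvaluationMetric (names and field lists here are inferred from the asserts above, not taken from the repository):

from dataclasses import dataclass
from typing import Any, Dict, Optional


@dataclass
class DatasetDescription:
    # Fields inferred from the attribute accesses in the new tests;
    # the real class in this project may define more or fewer fields.
    name: str
    purpose: str
    source: str
    language: str
    format: str
    characteristics: str
    citation: str
    additional_info: Optional[Dict[str, Any]] = None


@dataclass
class EvaluationMetric:
    name: str
    type: str
    description: str
    implementation: str
    primary: bool

With value objects like these, the old dict-based structure checks (required-field loops and isinstance assertions) become unnecessary, which is why the new tests can assert exact field values and metric counts directly.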
tests/test_tmlu_parser.py
CHANGED
@@ -170,76 +170,40 @@ def test_metadata_handling(tmlu_parser, sample_tmlu_entries):
     assert entry.metadata["source"] == "AST chinese - 108"


-def ...
-    """Test dataset description ...
-    description = tmlu_parser.get_dataset_description()
-
-    required_fields = [
-        ...
-        "size",
-        "domain",
-        "characteristics",
-        "reference",
-    ]
-
-    for field in required_fields:
-        assert field in description, f"Missing required field: {field}"
-
-    assert description["language"] == "Traditional Chinese"
-    assert "TMLU" in description["name"]
-    assert "miulab/tmlu" in description["reference"]
-    assert "AST" in description["characteristics"]
-    assert "GSAT" in description["characteristics"]
-
-
-def ...
-    """Test evaluation metrics structure and content."""
-    metrics = tmlu_parser.get_evaluation_metrics()
-
-    # Check ...
-    assert len(metrics) > 0
-
-    # Check ...
-    ...
-        "implementation",
-        "primary",
-    ]
-
-    ...
-        assert isinstance(metric["implementation"], str)
-        assert isinstance(metric["primary"], bool)
-
-    # Check for TMLU-specific metrics
-    metric_names = {m["name"] for m in metrics}
-    expected_metrics = {
-        "accuracy",
-        "per_subject_accuracy",
-        "per_difficulty_accuracy",
-        "explanation_quality",
-    }
-
-    for expected in expected_metrics:
-        assert expected in metric_names, f"Missing expected metric: {expected}"
-
-    # Verify primary metrics
-    primary_metrics = [m for m in metrics if m["primary"]]
-    assert (
-        len(primary_metrics) >= 2
-    )  # Should have at least accuracy and per_subject_accuracy
-    assert any(m["name"] == "accuracy" for m in primary_metrics)
-    assert any(m["name"] == "per_subject_accuracy" for m in primary_metrics)
+def test_get_dataset_description(tmlu_parser):
+    """Test dataset description generation."""
+    description = tmlu_parser.get_dataset_description()
+
+    assert description.name == "Taiwan Multiple-choice Language Understanding (TMLU)"
+    assert description.language == "Traditional Chinese"
+    assert "Taiwan-specific educational" in description.purpose
+    assert "Various Taiwan standardized tests" in description.source
+    assert description.format == "Multiple choice questions (A/B/C/D)"
+    assert "Advanced Subjects Test (AST)" in description.characteristics
+    assert "DBLP:journals/corr/abs-2403-20180" in description.citation
+
+
+def test_get_evaluation_metrics(tmlu_parser):
+    """Test evaluation metrics generation."""
+    metrics = tmlu_parser.get_evaluation_metrics()
+
+    assert len(metrics) == 5  # Check total number of metrics
+
+    # Check primary metrics
+    primary_metrics = [m for m in metrics if m.primary]
+    assert len(primary_metrics) == 2
+    assert any(m.name == "accuracy" for m in primary_metrics)
+    assert any(m.name == "per_subject_accuracy" for m in primary_metrics)
+
+    # Check specific metric properties
+    accuracy_metric = next(m for m in metrics if m.name == "accuracy")
+    assert accuracy_metric.type == "classification"
+    assert "datasets.load_metric('accuracy')" in accuracy_metric.implementation
+
+    # Check non-primary metrics
+    non_primary_metrics = {m.name for m in metrics if not m.primary}
+    assert non_primary_metrics == {
+        "per_difficulty_accuracy",
+        "confusion_matrix",
+        "explanation_quality",
+    }
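Both new tests only check that the implementation string names a metric loader (evaluate.load('accuracy') for BBH, datasets.load_metric('accuracy') for TMLU); they never execute it. As a rough sketch of what actually running such an accuracy metric looks like with the Hugging Face evaluate library (illustrative values only; note that datasets.load_metric has been deprecated in favor of evaluate.load in recent releases):

import evaluate

# Hypothetical predictions and gold labels for a four-option
# multiple-choice task, encoded as option indices.
predictions = [0, 2, 1, 3]
references = [0, 2, 2, 3]

accuracy = evaluate.load("accuracy")
result = accuracy.compute(predictions=predictions, references=references)
print(result)  # {'accuracy': 0.75}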