JeffYang52415 committed
Commit 58d5612 · unverified · 1 Parent(s): 299e68a

refactor: bbh/tmlu test case

Files changed (2):
  1. tests/test_bbh_parser.py  +29 -65
  2. tests/test_tmlu_parser.py +26 -62
tests/test_bbh_parser.py CHANGED
@@ -160,74 +160,38 @@ def test_different_tasks_parsing(bbh_parser, task_name):
     assert all(isinstance(entry.answer, str) for entry in parsed_data)


-def test_get_evaluation_metrics(bbh_parser):
-    """Test evaluation metrics structure and content."""
-    metrics = bbh_parser.get_evaluation_metrics()
-
-    # Check basic structure
-    assert isinstance(metrics, list)
-    assert len(metrics) > 0
-
-    # Check each metric has required fields
-    required_fields = ["name", "type", "description", "implementation", "primary"]
-    for metric in metrics:
-        for field in required_fields:
-            assert field in metric, f"Missing field {field} in metric {metric['name']}"
-
-        # Check field types
-        assert isinstance(metric["name"], str)
-        assert isinstance(metric["type"], str)
-        assert isinstance(metric["description"], str)
-        assert isinstance(metric["implementation"], str)
-        assert isinstance(metric["primary"], bool)
-
-    # Check specific metrics exist
-    metric_names = {m["name"] for m in metrics}
-    expected_metrics = {
-        "accuracy",
-        "human_eval_delta",
-        "per_task_accuracy",
-        "exact_match",
-    }
-    assert expected_metrics.issubset(metric_names)
-
-    # Check primary metrics
-    primary_metrics = {m["name"] for m in metrics if m["primary"]}
-    assert "accuracy" in primary_metrics
-    assert "human_eval_delta" in primary_metrics
-
-
-def test_dataset_description_citation_format(bbh_parser):
-    """Test that the citation in dataset description is properly formatted."""
+def test_get_dataset_description(bbh_parser):
+    """Test dataset description generation."""
     description = bbh_parser.get_dataset_description()
-    citation = description["citation"]
-
-    # Check citation structure
-    assert citation.startswith("@article{")
-    assert "title=" in citation
-    assert "author=" in citation
-    assert "journal=" in citation
-    assert "year=" in citation

-    # Check specific author formatting
-    assert "Suzgun, Mirac" in citation
-    assert "Wei, Jason" in citation
-    assert "and Wei, Jason" in citation  # Should be last author
-    assert "and and" not in citation  # No double "and"
+    assert description.name == "Big Bench Hard (BBH)"
+    assert "challenging BIG-Bench tasks" in description.purpose
+    assert description.language == "English"
+    assert description.format == "Multiple choice questions with single correct answers"
+    assert "Tasks require complex multi-step reasoning" in description.characteristics
+    assert "suzgun2022challenging" in description.citation
+    assert description.additional_info is not None
+    assert "model_performance" in description.additional_info
+    assert "size" in description.additional_info


-def test_evaluation_metrics_implementations(bbh_parser):
-    """Test that evaluation metric implementations are properly specified."""
+def test_get_evaluation_metrics(bbh_parser):
+    """Test evaluation metrics generation."""
     metrics = bbh_parser.get_evaluation_metrics()

-    for metric in metrics:
-        impl = metric["implementation"]
-
-        if "evaluate.load" in impl:
-            # Check standard metric format
-            assert impl.startswith("evaluate.load('")
-            assert impl.endswith("')")
-        elif "custom_" in impl:
-            # Check custom metric format
-            assert impl.startswith("custom_")
-            assert len(impl) > 7  # More than just "custom_"
+    assert len(metrics) == 4  # Check total number of metrics
+
+    # Check primary metrics
+    primary_metrics = [m for m in metrics if m.primary]
+    assert len(primary_metrics) == 2
+    assert any(m.name == "accuracy" for m in primary_metrics)
+    assert any(m.name == "human_eval_delta" for m in primary_metrics)
+
+    # Check specific metric properties
+    accuracy_metric = next(m for m in metrics if m.name == "accuracy")
+    assert accuracy_metric.type == "classification"
+    assert "evaluate.load('accuracy')" in accuracy_metric.implementation
+
+    # Check non-primary metrics
+    assert any(m.name == "per_task_accuracy" and not m.primary for m in metrics)
+    assert any(m.name == "exact_match" and not m.primary for m in metrics)
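
The rewritten tests replace dict-key lookups like metric["name"] with attribute access (description.name, metric.primary), which implies the parsers now return structured objects rather than plain dicts. A minimal sketch of what those return types might look like, assuming dataclasses; the field names are taken from the assertions in the diff, while the class names, defaults, and layout are assumptions, since the real definitions are not part of this change:

# Hedged sketch, not the repository's actual definitions: field names are
# inferred from the test assertions; everything else is an assumption.
from dataclasses import dataclass
from typing import Any, Dict, Optional


@dataclass(frozen=True)
class DatasetDescription:
    name: str
    purpose: str
    source: str
    language: str
    format: str
    characteristics: str
    citation: str
    additional_info: Optional[Dict[str, Any]] = None


@dataclass(frozen=True)
class EvaluationMetric:
    name: str
    type: str
    description: str
    implementation: str
    primary: bool

A shape like this also explains why the old per-field presence and isinstance loops could be deleted: a misspelled or missing field now fails at construction time (or under static analysis) instead of needing to be asserted in every test.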
tests/test_tmlu_parser.py CHANGED
@@ -170,76 +170,40 @@ def test_metadata_handling(tmlu_parser, sample_tmlu_entries):
     assert entry.metadata["source"] == "AST chinese - 108"


-def test_dataset_description(tmlu_parser):
-    """Test dataset description contains all required fields."""
+def test_get_dataset_description(tmlu_parser):
+    """Test dataset description generation."""
     description = tmlu_parser.get_dataset_description()

-    required_fields = [
-        "name",
-        "version",
-        "language",
-        "purpose",
-        "source",
-        "format",
-        "size",
-        "domain",
-        "characteristics",
-        "reference",
-    ]
-
-    for field in required_fields:
-        assert field in description, f"Missing required field: {field}"
-
-    assert description["language"] == "Traditional Chinese"
-    assert "TMLU" in description["name"]
-    assert "miulab/tmlu" in description["reference"]
-    assert "AST" in description["characteristics"]
-    assert "GSAT" in description["characteristics"]
+    assert description.name == "Taiwan Multiple-choice Language Understanding (TMLU)"
+    assert description.language == "Traditional Chinese"
+    assert "Taiwan-specific educational" in description.purpose
+    assert "Various Taiwan standardized tests" in description.source
+    assert description.format == "Multiple choice questions (A/B/C/D)"
+    assert "Advanced Subjects Test (AST)" in description.characteristics
+    assert "DBLP:journals/corr/abs-2403-20180" in description.citation


-def test_evaluation_metrics(tmlu_parser):
-    """Test evaluation metrics structure and content."""
+def test_get_evaluation_metrics(tmlu_parser):
+    """Test evaluation metrics generation."""
     metrics = tmlu_parser.get_evaluation_metrics()

-    # Check if we have metrics defined
-    assert len(metrics) > 0
+    assert len(metrics) == 5  # Check total number of metrics

-    # Check structure of each metric
-    required_metric_fields = [
-        "name",
-        "type",
-        "description",
-        "implementation",
-        "primary",
-    ]
+    # Check primary metrics
+    primary_metrics = [m for m in metrics if m.primary]
+    assert len(primary_metrics) == 2
+    assert any(m.name == "accuracy" for m in primary_metrics)
+    assert any(m.name == "per_subject_accuracy" for m in primary_metrics)

-    for metric in metrics:
-        for field in required_metric_fields:
-            assert field in metric, f"Missing required field in metric: {field}"
-
-        # Type checks
-        assert isinstance(metric["name"], str)
-        assert isinstance(metric["type"], str)
-        assert isinstance(metric["description"], str)
-        assert isinstance(metric["implementation"], str)
-        assert isinstance(metric["primary"], bool)
-
-    # Check for TMLU-specific metrics
-    metric_names = {m["name"] for m in metrics}
-    expected_metrics = {
-        "accuracy",
-        "per_subject_accuracy",
+    # Check specific metric properties
+    accuracy_metric = next(m for m in metrics if m.name == "accuracy")
+    assert accuracy_metric.type == "classification"
+    assert "datasets.load_metric('accuracy')" in accuracy_metric.implementation
+
+    # Check non-primary metrics
+    non_primary_metrics = {m.name for m in metrics if not m.primary}
+    assert non_primary_metrics == {
         "per_difficulty_accuracy",
+        "confusion_matrix",
         "explanation_quality",
     }
-
-    for expected in expected_metrics:
-        assert expected in metric_names, f"Missing expected metric: {expected}"
-
-    # Verify primary metrics
-    primary_metrics = [m for m in metrics if m["primary"]]
-    assert (
-        len(primary_metrics) >= 2
-    )  # Should have at least accuracy and per_subject_accuracy
-    assert any(m["name"] == "accuracy" for m in primary_metrics)
-    assert any(m["name"] == "per_subject_accuracy" for m in primary_metrics)
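
Both test modules rely on bbh_parser and tmlu_parser fixtures defined outside this diff (typically in a conftest.py). A minimal sketch of such fixtures under that assumption; the import paths and zero-argument constructors are hypothetical:

# Hypothetical conftest.py sketch; module paths and constructor signatures
# are assumptions, as the fixtures are not shown in this commit.
import pytest

from llmdataparser.bbh_parser import BBHDatasetParser  # assumed path
from llmdataparser.tmlu_parser import TMLUDatasetParser  # assumed path


@pytest.fixture
def bbh_parser():
    # A fresh parser instance per test keeps the tests independent.
    return BBHDatasetParser()


@pytest.fixture
def tmlu_parser():
    return TMLUDatasetParser()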