Commit 6ed7950 (unverified)
Author: JeffYang52415
Parent: dd0b07f

feat: add tmlu parser

llmdataparser/__init__.py CHANGED
@@ -15,6 +15,7 @@ from .mmlu_parser import (
    MMLUReduxDatasetParser,
    TMMLUPlusDatasetParser,
)
+ from .tmlu_parser import TMLUDatasetParser
from .tw_legal_parser import TWLegalDatasetParser


@@ -56,3 +57,4 @@ ParserRegistry.register_parser("bbh", BBHDatasetParser)
ParserRegistry.register_parser("mbpp", MBPPDatasetParser)
ParserRegistry.register_parser("ifeval", IFEvalDatasetParser)
ParserRegistry.register_parser("twlegal", TWLegalDatasetParser)
+ ParserRegistry.register_parser("tmlu", TMLUDatasetParser)
llmdataparser/prompts.py CHANGED
@@ -118,7 +118,7 @@ BBH_SYSTEM_PROMPT: Final[str] = textwrap.dedent(
    3. Consider all relationships and constraints mentioned in the problem
    4. Apply structured thinking to reach a valid conclusion
    5. Choose the answer that logically follows from the given information
-     6. Respond with ONLY the letter (A, B, C, etc.) or "True"/"False" or "Yes"/"No" - no explanations or additional text
+     6. Respond with ONLY the letter (A, B, C, etc.) or "True"/"False" or "Yes"/"No" and so on - no explanations or additional text
    """
)

@@ -141,7 +141,7 @@ MBPP_SYSTEM_PROMPT: Final[str] = textwrap.dedent(

TW_LEGAL_SYSTEM_PROMPT: Final[str] = textwrap.dedent(
    """\
-     You are an expert lawyer with deep knowledge of Taiwan's legal system. You are taking the Taiwan Bar Examination (臺灣律師資格考試). For each question, you will analyze legal scenarios or concepts based on Taiwan's laws and regulations. Your task is to select the most appropriate answer that aligns with Taiwan's legal principles.
+     You are an expert lawyer with deep knowledge of Taiwan's legal system. For each question, you will analyze legal scenarios or concepts based on Taiwan's laws and regulations. Your task is to select the most appropriate answer that aligns with Taiwan's legal principles.

    Instructions:
    1. Carefully analyze the legal question and all options
@@ -151,3 +151,16 @@ TW_LEGAL_SYSTEM_PROMPT: Final[str] = textwrap.dedent(
    5. Respond with ONLY the letter (A, B, C, or D) - no explanations or additional text
    """
)
+
+ TMLU_SYSTEM_PROMPT: Final[str] = textwrap.dedent(
+     """\
+     You are an expert evaluator with deep knowledge of Taiwan's educational system and professional fields. For each question, analyze it carefully and select the most appropriate answer based on your understanding of the subject matter.
+
+     Instructions:
+     1. Carefully read and understand the question
+     2. Consider all answer options thoroughly
+     3. Apply subject-specific knowledge and reasoning
+     4. Select the single most accurate answer
+     5. Respond with ONLY the letter (A, B, C, or D) - no explanations or additional text
+     """
+ )
llmdataparser/tmlu_parser.py ADDED
@@ -0,0 +1,202 @@
+ from dataclasses import dataclass
+ from typing import Any, Dict, Final, List
+
+ from llmdataparser.base_parser import HuggingFaceDatasetParser, HuggingFaceParseEntry
+ from llmdataparser.prompts import TMLU_SYSTEM_PROMPT
+
+ TMLU_VALID_ANSWERS: Final[set[str]] = {"A", "B", "C", "D"}
+ TMLU_VALID_ANSWER_STR: Final[str] = ", ".join(sorted(TMLU_VALID_ANSWERS))
+
+
+ @dataclass(frozen=True, kw_only=True, slots=True)
+ class TMLUParseEntry(HuggingFaceParseEntry):
+     """Custom entry class for TMLU, with fields specific to this dataset parser."""
+
+     raw_choices: list[str]
+     explanation: str
+     metadata: dict[str, Any]
+
+     @classmethod
+     def create(
+         cls,
+         prompt: str,
+         answer: str,
+         raw_question: str,
+         raw_choices: list[str],
+         raw_answer: str,
+         task_name: str,
+         explanation: str = "",
+         metadata: dict[str, Any] = {},
+     ) -> "TMLUParseEntry":
+         if answer not in TMLU_VALID_ANSWERS:
+             raise ValueError(
+                 f"Invalid answer_letter '{answer}'; must be one of {TMLU_VALID_ANSWER_STR}"
+             )
+         return cls(
+             prompt=prompt,
+             answer=answer,
+             raw_question=raw_question,
+             raw_answer=raw_answer,
+             raw_choices=raw_choices,
+             task_name=task_name,
+             explanation=explanation,
+             metadata=metadata,
+         )
+
+
+ class TMLUDatasetParser(HuggingFaceDatasetParser[TMLUParseEntry]):
+     """Parser for the TMLU dataset."""
+
+     _data_source = "miulab/tmlu"
+     _default_task = "AST_chinese"
+     _task_names = [
+         "AST_chinese",
+         "AST_mathematics",
+         "AST_biology",
+         "AST_chemistry",
+         "AST_physics",
+         "AST_civics",
+         "AST_geography",
+         "AST_history",
+         "GSAT_chinese",
+         "GSAT_chemistry",
+         "GSAT_biology",
+         "GSAT_physics",
+         "GSAT_earth_science",
+         "GSAT_mathematics",
+         "GSAT_geography",
+         "GSAT_history",
+         "GSAT_civics",
+         "CAP_mathematics",
+         "CAP_biology",
+         "CAP_physics",
+         "CAP_chemistry",
+         "CAP_earth_science",
+         "CAP_civics",
+         "CAP_history",
+         "CAP_geography",
+         "CAP_chinese",
+         "driving_rule",
+         "basic_traditional_chinese_medicine",
+         "clinical_traditional_chinese_medicine",
+         "lawyer_qualification",
+         "nutritionist",
+         "tour_leader",
+         "tour_guide",
+         "taiwan_tourist_resources",
+         "clinical_psychologist",
+         "teacher_qualification",
+         "accountant",
+     ]
+     _default_system_prompt = TMLU_SYSTEM_PROMPT
+
+     def process_entry(
+         self, row: dict[str, Any], task_name: str | None = None, **kwargs: Any
+     ) -> TMLUParseEntry:
+         """Process a single TMLU entry."""
+         task = task_name or self._get_current_task(row)
+         # Extract choices in order
+         raw_choices = [row["A"], row["B"], row["C"], row["D"]]
+         choices = "\n".join(
+             f"{chr(65 + i)}. {choice}" for i, choice in enumerate(raw_choices)
+         )
+         raw_question = row["question"]
+         raw_answer = row["answer"]
+         explanation = row.get("explanation", "")
+         metadata = row.get("metadata", {})
+
+         prompt = f"{self._system_prompt}\nQuestion: {raw_question}\n{choices}\nAnswer:"
+
+         return TMLUParseEntry.create(
+             prompt=prompt,
+             answer=raw_answer,
+             raw_question=raw_question,
+             raw_choices=raw_choices,
+             raw_answer=raw_answer,
+             task_name=task,
+             explanation=explanation,
+             metadata=metadata,
+         )
+
+     def get_dataset_description(self) -> Dict[str, str]:
+         """Returns description of the TMLU dataset."""
+         return {
+             "name": "Taiwan Multiple-choice Language Understanding (TMLU)",
+             "version": "1.0",
+             "language": "Traditional Chinese",
+             "purpose": "Evaluate models on Taiwan-specific educational and professional knowledge",
+             "source": "Various Taiwan standardized tests and professional certifications",
+             "format": "Multiple choice questions (A/B/C/D)",
+             "size": "Multiple subjects across different test types",
+             "domain": "Education and Professional Certification",
+             "characteristics": (
+                 "Covers various subjects including Advanced Subjects Test (AST), "
+                 "General Scholastic Ability Test (GSAT), College Admission Practice (CAP), "
+                 "and professional certifications"
+             ),
+             "reference": "https://huggingface.co/datasets/miulab/tmlu",
+         }
+
+     def get_evaluation_metrics(self) -> List[Dict[str, Any]]:
+         """Returns recommended evaluation metrics for TMLU."""
+         return [
+             {
+                 "name": "accuracy",
+                 "type": "classification",
+                 "description": "Overall percentage of correctly answered questions",
+                 "implementation": "datasets.load_metric('accuracy')",
+                 "primary": True,
+             },
+             {
+                 "name": "per_subject_accuracy",
+                 "type": "classification",
+                 "description": "Accuracy broken down by subject areas (AST, GSAT, CAP, etc.)",
+                 "implementation": "custom_subject_accuracy",
+                 "primary": True,
+             },
+             {
+                 "name": "per_difficulty_accuracy",
+                 "type": "classification",
+                 "description": "Accuracy broken down by test difficulty levels",
+                 "implementation": "custom_difficulty_accuracy",
+                 "primary": False,
+             },
+             {
+                 "name": "confusion_matrix",
+                 "type": "classification",
+                 "description": "Distribution of predicted vs actual answers",
+                 "implementation": "datasets.load_metric('confusion_matrix')",
+                 "primary": False,
+             },
+             {
+                 "name": "explanation_quality",
+                 "type": "text",
+                 "description": "Quality assessment of model explanations when available",
+                 "implementation": "custom_explanation_metric",
+                 "primary": False,
+             },
+         ]
+
+
+ if __name__ == "__main__":
+     # Example usage
+     parser = TMLUDatasetParser()
+     parser.load()
+     parser.parse()
+
+     # Get parsed data with correct type
+     parsed_data = parser.get_parsed_data
+
+     # Print example entry
+     if parsed_data:
+         example = parsed_data[0]
+         print("\nExample parsed entry:")
+         print(f"Task: {example.task_name}")
+         print(f"Question: {example.raw_question}")
+         print("Choices:")
+         for i, choice in enumerate(example.raw_choices):
+             print(f"{chr(65 + i)}. {choice}")
+         print(f"Correct Answer: {example.answer}")
+         if example.explanation:
+             print(f"Explanation: {example.explanation}")
+         print(f"Metadata: {example.metadata}")
tests/test_tmlu_parser.py ADDED
@@ -0,0 +1,245 @@
+ import pytest
+
+ from llmdataparser.tmlu_parser import TMLUDatasetParser, TMLUParseEntry
+
+
+ @pytest.fixture
+ def tmlu_parser():
+     """Create a TMLU parser instance for testing."""
+     return TMLUDatasetParser()
+
+
+ @pytest.fixture
+ def sample_tmlu_entries():
+     """Create sample TMLU dataset entries for testing."""
+     return [
+         {
+             "question": "閱讀下文,選出依序最適合填入□內的選項:",
+             "A": "張揚/綢繆未雨/奏疏",
+             "B": "抽搐/煮繭抽絲/奏疏",
+             "C": "張揚/煮繭抽絲/進貢",
+             "D": "抽搐/綢繆未雨/進貢",
+             "answer": "B",
+             "explanation": "根據文意,選項B最為恰當。",
+             "metadata": {
+                 "timestamp": "2023-10-09T18:27:20.304623",
+                 "source": "AST chinese - 108",
+                 "explanation_source": "",
+             },
+         },
+         {
+             "question": "下列何者是質數?",
+             "A": "21",
+             "B": "27",
+             "C": "31",
+             "D": "33",
+             "answer": "C",
+             "explanation": "31是質數,其他選項都是合數。",
+             "metadata": {
+                 "timestamp": "2023-10-09T18:27:20.304623",
+                 "source": "AST mathematics - 108",
+                 "explanation_source": "",
+             },
+         },
+     ]
+
+
+ def test_tmlu_parse_entry_creation_valid():
+     """Test valid creation of TMLUParseEntry."""
+     entry = TMLUParseEntry.create(
+         prompt="Test prompt",
+         answer="A",
+         raw_question="Test question",
+         raw_choices=["choice1", "choice2", "choice3", "choice4"],
+         raw_answer="A",
+         task_name="AST_chinese",
+         explanation="Test explanation",
+         metadata={"source": "test"},
+     )
+     assert isinstance(entry, TMLUParseEntry)
+     assert entry.prompt == "Test prompt"
+     assert entry.answer == "A"
+     assert entry.raw_choices == ["choice1", "choice2", "choice3", "choice4"]
+     assert entry.explanation == "Test explanation"
+     assert entry.metadata == {"source": "test"}
+
+
+ @pytest.mark.parametrize("invalid_answer", ["E", "F", "1", "", None])
+ def test_tmlu_parse_entry_creation_invalid(invalid_answer):
+     """Test invalid answer handling in TMLUParseEntry creation."""
+     with pytest.raises(
+         ValueError, match="Invalid answer_letter.*must be one of A, B, C, D"
+     ):
+         TMLUParseEntry.create(
+             prompt="Test prompt",
+             answer=invalid_answer,
+             raw_question="Test question",
+             raw_choices=["choice1", "choice2", "choice3", "choice4"],
+             raw_answer=invalid_answer,
+             task_name="AST_chinese",
+         )
+
+
+ def test_process_entry(tmlu_parser, sample_tmlu_entries):
+     """Test processing entries in TMLU parser."""
+     entry = tmlu_parser.process_entry(sample_tmlu_entries[0], task_name="AST_chinese")
+
+     assert isinstance(entry, TMLUParseEntry)
+     assert entry.answer == "B"
+     assert entry.task_name == "AST_chinese"
+     assert len(entry.raw_choices) == 4
+     assert entry.explanation == "根據文意,選項B最為恰當。"
+     assert "AST chinese - 108" in entry.metadata["source"]
+
+
+ def test_tmlu_parser_initialization(tmlu_parser):
+     """Test TMLU parser initialization and properties."""
+     assert isinstance(tmlu_parser.task_names, list)
+     assert len(tmlu_parser.task_names) == 37  # Total number of tasks
+     assert tmlu_parser._data_source == "miulab/tmlu"
+     assert tmlu_parser._default_task == "AST_chinese"
+     assert "AST_chinese" in tmlu_parser.task_names
+     assert "GSAT_mathematics" in tmlu_parser.task_names
+     assert (
+         tmlu_parser.get_huggingface_link
+         == "https://huggingface.co/datasets/miulab/tmlu"
+     )
+
+
+ @pytest.mark.integration
+ def test_load_dataset(tmlu_parser):
+     """Test loading the TMLU dataset."""
+     tmlu_parser.load(task_name="AST_chinese", split="test")
+     assert tmlu_parser.raw_data is not None
+     assert tmlu_parser.split_names == ["test"]
+     assert tmlu_parser._current_task == "AST_chinese"
+
+
+ def test_parser_string_representation(tmlu_parser):
+     """Test string representation of TMLU parser."""
+     repr_str = str(tmlu_parser)
+     assert "TMLUDatasetParser" in repr_str
+     assert "miulab/tmlu" in repr_str
+     assert "not loaded" in repr_str
+
+
+ @pytest.mark.integration
+ def test_different_tasks_parsing(tmlu_parser):
+     """Test parsing different tasks of the dataset."""
+     # Load and parse AST_chinese
+     tmlu_parser.load(task_name="AST_chinese", split="test")
+     tmlu_parser.parse(split_names="test", force=True)
+     chinese_count = len(tmlu_parser.get_parsed_data)
+
+     # Load and parse AST_mathematics
+     tmlu_parser.load(task_name="AST_mathematics", split="test")
+     tmlu_parser.parse(split_names="test", force=True)
+     math_count = len(tmlu_parser.get_parsed_data)
+
+     assert chinese_count > 0
+     assert math_count > 0
+
+
+ def test_system_prompt_override(tmlu_parser):
+     """Test overriding the default system prompt."""
+     custom_prompt = "Custom system prompt for testing"
+     parser = TMLUDatasetParser(system_prompt=custom_prompt)
+
+     test_entry = {
+         "question": "Test question",
+         "A": "Choice A",
+         "B": "Choice B",
+         "C": "Choice C",
+         "D": "Choice D",
+         "answer": "A",
+         "explanation": "Test explanation",
+         "metadata": {"source": "test"},
+     }
+
+     entry = parser.process_entry(test_entry)
+     assert custom_prompt in entry.prompt
+
+
+ def test_metadata_handling(tmlu_parser, sample_tmlu_entries):
+     """Test proper handling of metadata in entries."""
+     entry = tmlu_parser.process_entry(sample_tmlu_entries[0])
+
+     assert "timestamp" in entry.metadata
+     assert "source" in entry.metadata
+     assert "explanation_source" in entry.metadata
+     assert entry.metadata["source"] == "AST chinese - 108"
+
+
+ def test_dataset_description(tmlu_parser):
+     """Test dataset description contains all required fields."""
+     description = tmlu_parser.get_dataset_description()
+
+     required_fields = [
+         "name",
+         "version",
+         "language",
+         "purpose",
+         "source",
+         "format",
+         "size",
+         "domain",
+         "characteristics",
+         "reference",
+     ]
+
+     for field in required_fields:
+         assert field in description, f"Missing required field: {field}"
+
+     assert description["language"] == "Traditional Chinese"
+     assert "TMLU" in description["name"]
+     assert "miulab/tmlu" in description["reference"]
+     assert "AST" in description["characteristics"]
+     assert "GSAT" in description["characteristics"]
+
+
+ def test_evaluation_metrics(tmlu_parser):
+     """Test evaluation metrics structure and content."""
+     metrics = tmlu_parser.get_evaluation_metrics()
+
+     # Check if we have metrics defined
+     assert len(metrics) > 0
+
+     # Check structure of each metric
+     required_metric_fields = [
+         "name",
+         "type",
+         "description",
+         "implementation",
+         "primary",
+     ]
+
+     for metric in metrics:
+         for field in required_metric_fields:
+             assert field in metric, f"Missing required field in metric: {field}"
+
+         # Type checks
+         assert isinstance(metric["name"], str)
+         assert isinstance(metric["type"], str)
+         assert isinstance(metric["description"], str)
+         assert isinstance(metric["implementation"], str)
+         assert isinstance(metric["primary"], bool)
+
+     # Check for TMLU-specific metrics
+     metric_names = {m["name"] for m in metrics}
+     expected_metrics = {
+         "accuracy",
+         "per_subject_accuracy",
+         "per_difficulty_accuracy",
+         "explanation_quality",
+     }
+
+     for expected in expected_metrics:
+         assert expected in metric_names, f"Missing expected metric: {expected}"
+
+     # Verify primary metrics
+     primary_metrics = [m for m in metrics if m["primary"]]
+     assert (
+         len(primary_metrics) >= 2
+     )  # Should have at least accuracy and per_subject_accuracy
+     assert any(m["name"] == "accuracy" for m in primary_metrics)
+     assert any(m["name"] == "per_subject_accuracy" for m in primary_metrics)