feat: add tmlu parser
Files changed:
- llmdataparser/__init__.py (+2, -0)
- llmdataparser/prompts.py (+15, -2)
- llmdataparser/tmlu_parser.py (+202, -0)
- tests/test_tmlu_parser.py (+245, -0)
llmdataparser/__init__.py
CHANGED
@@ -15,6 +15,7 @@ from .mmlu_parser import (
     MMLUReduxDatasetParser,
     TMMLUPlusDatasetParser,
 )
+from .tmlu_parser import TMLUDatasetParser
 from .tw_legal_parser import TWLegalDatasetParser


@@ -56,3 +57,4 @@ ParserRegistry.register_parser("bbh", BBHDatasetParser)
 ParserRegistry.register_parser("mbpp", MBPPDatasetParser)
 ParserRegistry.register_parser("ifeval", IFEvalDatasetParser)
 ParserRegistry.register_parser("twlegal", TWLegalDatasetParser)
+ParserRegistry.register_parser("tmlu", TMLUDatasetParser)
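With the import and registration above, the new parser is reachable from the package root. A minimal smoke check of the wiring (the `task_names` property is taken from the new tests; the printed values are illustrative):

```python
# Quick sanity check of the new export and registry key.
from llmdataparser import TMLUDatasetParser

parser = TMLUDatasetParser()
print(parser.task_names[:3])  # e.g. ['AST_chinese', 'AST_mathematics', 'AST_biology']
```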
llmdataparser/prompts.py
CHANGED
@@ -118,7 +118,7 @@ BBH_SYSTEM_PROMPT: Final[str] = textwrap.dedent(
     3. Consider all relationships and constraints mentioned in the problem
     4. Apply structured thinking to reach a valid conclusion
     5. Choose the answer that logically follows from the given information
-    6. Respond with ONLY the letter (A, B, C, etc.) or "True"/"False" or "Yes"/"No" - no explanations or additional text
+    6. Respond with ONLY the letter (A, B, C, etc.), "True"/"False", "Yes"/"No", and so on - no explanations or additional text
     """
 )

@@ -141,7 +141,7 @@ MBPP_SYSTEM_PROMPT: Final[str] = textwrap.dedent(

 TW_LEGAL_SYSTEM_PROMPT: Final[str] = textwrap.dedent(
     """\
-    You are an expert lawyer with deep knowledge of Taiwan's legal system.
+    You are an expert lawyer with deep knowledge of Taiwan's legal system. For each question, you will analyze legal scenarios or concepts based on Taiwan's laws and regulations. Your task is to select the most appropriate answer that aligns with Taiwan's legal principles.

     Instructions:
     1. Carefully analyze the legal question and all options
@@ -151,3 +151,16 @@ TW_LEGAL_SYSTEM_PROMPT: Final[str] = textwrap.dedent(
     5. Respond with ONLY the letter (A, B, C, or D) - no explanations or additional text
     """
 )
+
+TMLU_SYSTEM_PROMPT: Final[str] = textwrap.dedent(
+    """\
+    You are an expert evaluator with deep knowledge of Taiwan's educational system and professional fields. For each question, analyze it carefully and select the most appropriate answer based on your understanding of the subject matter.
+
+    Instructions:
+    1. Carefully read and understand the question
+    2. Consider all answer options thoroughly
+    3. Apply subject-specific knowledge and reasoning
+    4. Select the single most accurate answer
+    5. Respond with ONLY the letter (A, B, C, or D) - no explanations or additional text
+    """
+)
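A note on the `textwrap.dedent` pattern used throughout this file: the backslash after the opening triple quote suppresses the leading newline, and `dedent` strips the common indentation, so the stored prompt starts flush at column zero. A standard-library-only illustration:

```python
import textwrap

prompt = textwrap.dedent(
    """\
    Instructions:
    1. Carefully read and understand the question
    """
)
# No leading newline, no leading spaces: dedent() removed the common indent.
assert prompt.startswith("Instructions:")
```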
llmdataparser/tmlu_parser.py
ADDED
@@ -0,0 +1,202 @@
+from dataclasses import dataclass
+from typing import Any, Dict, Final, List
+
+from llmdataparser.base_parser import HuggingFaceDatasetParser, HuggingFaceParseEntry
+from llmdataparser.prompts import TMLU_SYSTEM_PROMPT
+
+TMLU_VALID_ANSWERS: Final[set[str]] = {"A", "B", "C", "D"}
+TMLU_VALID_ANSWER_STR: Final[str] = ", ".join(sorted(TMLU_VALID_ANSWERS))
+
+
+@dataclass(frozen=True, kw_only=True, slots=True)
+class TMLUParseEntry(HuggingFaceParseEntry):
+    """Custom entry class for TMLU, with fields specific to this dataset parser."""
+
+    raw_choices: list[str]
+    explanation: str
+    metadata: dict[str, Any]
+
+    @classmethod
+    def create(
+        cls,
+        prompt: str,
+        answer: str,
+        raw_question: str,
+        raw_choices: list[str],
+        raw_answer: str,
+        task_name: str,
+        explanation: str = "",
+        metadata: dict[str, Any] | None = None,  # avoid a mutable default argument
+    ) -> "TMLUParseEntry":
+        if answer not in TMLU_VALID_ANSWERS:
+            raise ValueError(
+                f"Invalid answer_letter '{answer}'; must be one of {TMLU_VALID_ANSWER_STR}"
+            )
+        return cls(
+            prompt=prompt,
+            answer=answer,
+            raw_question=raw_question,
+            raw_answer=raw_answer,
+            raw_choices=raw_choices,
+            task_name=task_name,
+            explanation=explanation,
+            metadata=metadata if metadata is not None else {},
+        )
+
+
+class TMLUDatasetParser(HuggingFaceDatasetParser[TMLUParseEntry]):
+    """Parser for the TMLU dataset."""
+
+    _data_source = "miulab/tmlu"
+    _default_task = "AST_chinese"
+    _task_names = [
+        "AST_chinese",
+        "AST_mathematics",
+        "AST_biology",
+        "AST_chemistry",
+        "AST_physics",
+        "AST_civics",
+        "AST_geography",
+        "AST_history",
+        "GSAT_chinese",
+        "GSAT_chemistry",
+        "GSAT_biology",
+        "GSAT_physics",
+        "GSAT_earth_science",
+        "GSAT_mathematics",
+        "GSAT_geography",
+        "GSAT_history",
+        "GSAT_civics",
+        "CAP_mathematics",
+        "CAP_biology",
+        "CAP_physics",
+        "CAP_chemistry",
+        "CAP_earth_science",
+        "CAP_civics",
+        "CAP_history",
+        "CAP_geography",
+        "CAP_chinese",
+        "driving_rule",
+        "basic_traditional_chinese_medicine",
+        "clinical_traditional_chinese_medicine",
+        "lawyer_qualification",
+        "nutritionist",
+        "tour_leader",
+        "tour_guide",
+        "taiwan_tourist_resources",
+        "clinical_psychologist",
+        "teacher_qualification",
+        "accountant",
+    ]
+    _default_system_prompt = TMLU_SYSTEM_PROMPT
+
+    def process_entry(
+        self, row: dict[str, Any], task_name: str | None = None, **kwargs: Any
+    ) -> TMLUParseEntry:
+        """Process a single TMLU entry."""
+        task = task_name or self._get_current_task(row)
+        # Extract choices in A-D order and render them as "A. ..." lines
+        raw_choices = [row["A"], row["B"], row["C"], row["D"]]
+        choices = "\n".join(
+            f"{chr(65 + i)}. {choice}" for i, choice in enumerate(raw_choices)
+        )
+        raw_question = row["question"]
+        raw_answer = row["answer"]
+        explanation = row.get("explanation", "")
+        metadata = row.get("metadata", {})
+
+        prompt = f"{self._system_prompt}\nQuestion: {raw_question}\n{choices}\nAnswer:"
+
+        return TMLUParseEntry.create(
+            prompt=prompt,
+            answer=raw_answer,
+            raw_question=raw_question,
+            raw_choices=raw_choices,
+            raw_answer=raw_answer,
+            task_name=task,
+            explanation=explanation,
+            metadata=metadata,
+        )
+
+    def get_dataset_description(self) -> Dict[str, str]:
+        """Returns description of the TMLU dataset."""
+        return {
+            "name": "Taiwan Multiple-choice Language Understanding (TMLU)",
+            "version": "1.0",
+            "language": "Traditional Chinese",
+            "purpose": "Evaluate models on Taiwan-specific educational and professional knowledge",
+            "source": "Various Taiwan standardized tests and professional certifications",
+            "format": "Multiple choice questions (A/B/C/D)",
+            "size": "Multiple subjects across different test types",
+            "domain": "Education and Professional Certification",
+            "characteristics": (
+                "Covers various subjects including Advanced Subjects Test (AST), "
+                "General Scholastic Ability Test (GSAT), College Admission Practice (CAP), "
+                "and professional certifications"
+            ),
+            "reference": "https://huggingface.co/datasets/miulab/tmlu",
+        }
+
+    def get_evaluation_metrics(self) -> List[Dict[str, Any]]:
+        """Returns recommended evaluation metrics for TMLU."""
+        return [
+            {
+                "name": "accuracy",
+                "type": "classification",
+                "description": "Overall percentage of correctly answered questions",
+                "implementation": "datasets.load_metric('accuracy')",
+                "primary": True,
+            },
+            {
+                "name": "per_subject_accuracy",
+                "type": "classification",
+                "description": "Accuracy broken down by subject areas (AST, GSAT, CAP, etc.)",
+                "implementation": "custom_subject_accuracy",
+                "primary": True,
+            },
+            {
+                "name": "per_difficulty_accuracy",
+                "type": "classification",
+                "description": "Accuracy broken down by test difficulty levels",
+                "implementation": "custom_difficulty_accuracy",
+                "primary": False,
+            },
+            {
+                "name": "confusion_matrix",
+                "type": "classification",
+                "description": "Distribution of predicted vs actual answers",
+                "implementation": "datasets.load_metric('confusion_matrix')",
+                "primary": False,
+            },
+            {
+                "name": "explanation_quality",
+                "type": "text",
+                "description": "Quality assessment of model explanations when available",
+                "implementation": "custom_explanation_metric",
+                "primary": False,
+            },
+        ]
+
+
+if __name__ == "__main__":
+    # Example usage
+    parser = TMLUDatasetParser()
+    parser.load()
+    parser.parse()
+
+    # Get parsed data with correct type
+    parsed_data = parser.get_parsed_data
+
+    # Print example entry
+    if parsed_data:
+        example = parsed_data[0]
+        print("\nExample parsed entry:")
+        print(f"Task: {example.task_name}")
+        print(f"Question: {example.raw_question}")
+        print("Choices:")
+        for i, choice in enumerate(example.raw_choices):
+            print(f"{chr(65 + i)}. {choice}")
+        print(f"Correct Answer: {example.answer}")
+        if example.explanation:
+            print(f"Explanation: {example.explanation}")
+        print(f"Metadata: {example.metadata}")
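The metrics above reference `custom_subject_accuracy` only by name; no implementation ships in this commit. A minimal sketch of what such a helper could look like, grouping results by task name (the function name, signature, and grading rule are illustrative assumptions):

```python
from collections import defaultdict


def custom_subject_accuracy(
    predictions: list[str], references: list[str], task_names: list[str]
) -> dict[str, float]:
    """Hypothetical per-subject accuracy: fraction of exact letter matches per task."""
    correct: dict[str, int] = defaultdict(int)
    total: dict[str, int] = defaultdict(int)
    for pred, ref, task in zip(predictions, references, task_names):
        total[task] += 1
        if pred.strip().upper() == ref.strip().upper():
            correct[task] += 1
    return {task: correct[task] / total[task] for task in total}
```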
tests/test_tmlu_parser.py
ADDED
@@ -0,0 +1,245 @@
+import pytest
+
+from llmdataparser.tmlu_parser import TMLUDatasetParser, TMLUParseEntry
+
+
+@pytest.fixture
+def tmlu_parser():
+    """Create a TMLU parser instance for testing."""
+    return TMLUDatasetParser()
+
+
+@pytest.fixture
+def sample_tmlu_entries():
+    """Create sample TMLU dataset entries for testing."""
+    return [
+        {
+            "question": "閱讀下文,選出依序最適合填入□內的選項:",
+            "A": "張揚/綢繆未雨/奏疏",
+            "B": "抽搐/煮繭抽絲/奏疏",
+            "C": "張揚/煮繭抽絲/進貢",
+            "D": "抽搐/綢繆未雨/進貢",
+            "answer": "B",
+            "explanation": "根據文意,選項B最為恰當。",
+            "metadata": {
+                "timestamp": "2023-10-09T18:27:20.304623",
+                "source": "AST chinese - 108",
+                "explanation_source": "",
+            },
+        },
+        {
+            "question": "下列何者是質數?",
+            "A": "21",
+            "B": "27",
+            "C": "31",
+            "D": "33",
+            "answer": "C",
+            "explanation": "31是質數,其他選項都是合數。",
+            "metadata": {
+                "timestamp": "2023-10-09T18:27:20.304623",
+                "source": "AST mathematics - 108",
+                "explanation_source": "",
+            },
+        },
+    ]
+
+
+def test_tmlu_parse_entry_creation_valid():
+    """Test valid creation of TMLUParseEntry."""
+    entry = TMLUParseEntry.create(
+        prompt="Test prompt",
+        answer="A",
+        raw_question="Test question",
+        raw_choices=["choice1", "choice2", "choice3", "choice4"],
+        raw_answer="A",
+        task_name="AST_chinese",
+        explanation="Test explanation",
+        metadata={"source": "test"},
+    )
+    assert isinstance(entry, TMLUParseEntry)
+    assert entry.prompt == "Test prompt"
+    assert entry.answer == "A"
+    assert entry.raw_choices == ["choice1", "choice2", "choice3", "choice4"]
+    assert entry.explanation == "Test explanation"
+    assert entry.metadata == {"source": "test"}
+
+
+@pytest.mark.parametrize("invalid_answer", ["E", "F", "1", "", None])
+def test_tmlu_parse_entry_creation_invalid(invalid_answer):
+    """Test invalid answer handling in TMLUParseEntry creation."""
+    with pytest.raises(
+        ValueError, match="Invalid answer_letter.*must be one of A, B, C, D"
+    ):
+        TMLUParseEntry.create(
+            prompt="Test prompt",
+            answer=invalid_answer,
+            raw_question="Test question",
+            raw_choices=["choice1", "choice2", "choice3", "choice4"],
+            raw_answer=invalid_answer,
+            task_name="AST_chinese",
+        )
+
+
+def test_process_entry(tmlu_parser, sample_tmlu_entries):
+    """Test processing entries in TMLU parser."""
+    entry = tmlu_parser.process_entry(sample_tmlu_entries[0], task_name="AST_chinese")
+
+    assert isinstance(entry, TMLUParseEntry)
+    assert entry.answer == "B"
+    assert entry.task_name == "AST_chinese"
+    assert len(entry.raw_choices) == 4
+    assert entry.explanation == "根據文意,選項B最為恰當。"
+    assert "AST chinese - 108" in entry.metadata["source"]
+
+
+def test_tmlu_parser_initialization(tmlu_parser):
+    """Test TMLU parser initialization and properties."""
+    assert isinstance(tmlu_parser.task_names, list)
+    assert len(tmlu_parser.task_names) == 37  # Total number of tasks
+    assert tmlu_parser._data_source == "miulab/tmlu"
+    assert tmlu_parser._default_task == "AST_chinese"
+    assert "AST_chinese" in tmlu_parser.task_names
+    assert "GSAT_mathematics" in tmlu_parser.task_names
+    assert (
+        tmlu_parser.get_huggingface_link
+        == "https://huggingface.co/datasets/miulab/tmlu"
+    )
+
+
+@pytest.mark.integration
+def test_load_dataset(tmlu_parser):
+    """Test loading the TMLU dataset."""
+    tmlu_parser.load(task_name="AST_chinese", split="test")
+    assert tmlu_parser.raw_data is not None
+    assert tmlu_parser.split_names == ["test"]
+    assert tmlu_parser._current_task == "AST_chinese"
+
+
+def test_parser_string_representation(tmlu_parser):
+    """Test string representation of TMLU parser."""
+    repr_str = str(tmlu_parser)
+    assert "TMLUDatasetParser" in repr_str
+    assert "miulab/tmlu" in repr_str
+    assert "not loaded" in repr_str
+
+
+@pytest.mark.integration
+def test_different_tasks_parsing(tmlu_parser):
+    """Test parsing different tasks of the dataset."""
+    # Load and parse AST_chinese
+    tmlu_parser.load(task_name="AST_chinese", split="test")
+    tmlu_parser.parse(split_names="test", force=True)
+    chinese_count = len(tmlu_parser.get_parsed_data)
+
+    # Load and parse AST_mathematics
+    tmlu_parser.load(task_name="AST_mathematics", split="test")
+    tmlu_parser.parse(split_names="test", force=True)
+    math_count = len(tmlu_parser.get_parsed_data)
+
+    assert chinese_count > 0
+    assert math_count > 0
+
+
+def test_system_prompt_override(tmlu_parser):
+    """Test overriding the default system prompt."""
+    custom_prompt = "Custom system prompt for testing"
+    parser = TMLUDatasetParser(system_prompt=custom_prompt)
+
+    test_entry = {
+        "question": "Test question",
+        "A": "Choice A",
+        "B": "Choice B",
+        "C": "Choice C",
+        "D": "Choice D",
+        "answer": "A",
+        "explanation": "Test explanation",
+        "metadata": {"source": "test"},
+    }
+
+    entry = parser.process_entry(test_entry)
+    assert custom_prompt in entry.prompt
+
+
+def test_metadata_handling(tmlu_parser, sample_tmlu_entries):
+    """Test proper handling of metadata in entries."""
+    entry = tmlu_parser.process_entry(sample_tmlu_entries[0])
+
+    assert "timestamp" in entry.metadata
+    assert "source" in entry.metadata
+    assert "explanation_source" in entry.metadata
+    assert entry.metadata["source"] == "AST chinese - 108"
+
+
+def test_dataset_description(tmlu_parser):
+    """Test dataset description contains all required fields."""
+    description = tmlu_parser.get_dataset_description()
+
+    required_fields = [
+        "name",
+        "version",
+        "language",
+        "purpose",
+        "source",
+        "format",
+        "size",
+        "domain",
+        "characteristics",
+        "reference",
+    ]
+
+    for field in required_fields:
+        assert field in description, f"Missing required field: {field}"
+
+    assert description["language"] == "Traditional Chinese"
+    assert "TMLU" in description["name"]
+    assert "miulab/tmlu" in description["reference"]
+    assert "AST" in description["characteristics"]
+    assert "GSAT" in description["characteristics"]
+
+
+def test_evaluation_metrics(tmlu_parser):
+    """Test evaluation metrics structure and content."""
+    metrics = tmlu_parser.get_evaluation_metrics()
+
+    # Check if we have metrics defined
+    assert len(metrics) > 0
+
+    # Check structure of each metric
+    required_metric_fields = [
+        "name",
+        "type",
+        "description",
+        "implementation",
+        "primary",
+    ]
+
+    for metric in metrics:
+        for field in required_metric_fields:
+            assert field in metric, f"Missing required field in metric: {field}"
+
+        # Type checks
+        assert isinstance(metric["name"], str)
+        assert isinstance(metric["type"], str)
+        assert isinstance(metric["description"], str)
+        assert isinstance(metric["implementation"], str)
+        assert isinstance(metric["primary"], bool)
+
+    # Check for TMLU-specific metrics
+    metric_names = {m["name"] for m in metrics}
+    expected_metrics = {
+        "accuracy",
+        "per_subject_accuracy",
+        "per_difficulty_accuracy",
+        "explanation_quality",
+    }
+
+    for expected in expected_metrics:
+        assert expected in metric_names, f"Missing expected metric: {expected}"
+
+    # Verify primary metrics
+    primary_metrics = [m for m in metrics if m["primary"]]
+    assert (
+        len(primary_metrics) >= 2
+    )  # Should have at least accuracy and per_subject_accuracy
+    assert any(m["name"] == "accuracy" for m in primary_metrics)
+    assert any(m["name"] == "per_subject_accuracy" for m in primary_metrics)