refactor: tmlu/tw_legal parser
- llmdataparser/tmlu_parser.py +0 -21
- llmdataparser/tw_legal_parser.py +35 -1
- tests/test_tmlu_parser.py +1 -14
- tests/test_tw_legal_parser.py +24 -0
llmdataparser/tmlu_parser.py
CHANGED
@@ -169,27 +169,6 @@ class TMLUDatasetParser(HuggingFaceDatasetParser[TMLUParseEntry]):
                 implementation="custom_subject_accuracy",
                 primary=True,
             ),
-            EvaluationMetric.create(
-                name="per_difficulty_accuracy",
-                type="classification",
-                description="Accuracy broken down by test difficulty levels",
-                implementation="custom_difficulty_accuracy",
-                primary=False,
-            ),
-            EvaluationMetric.create(
-                name="confusion_matrix",
-                type="classification",
-                description="Distribution of predicted vs actual answers",
-                implementation="datasets.load_metric('confusion_matrix')",
-                primary=False,
-            ),
-            EvaluationMetric.create(
-                name="explanation_quality",
-                type="text",
-                description="Quality assessment of model explanations when available",
-                implementation="custom_explanation_metric",
-                primary=False,
-            ),
         ]
 
 
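For reference, a minimal usage sketch of the trimmed metric list (assumptions: TMLUDatasetParser can be constructed with no arguments; only the two primary metrics, accuracy and per_subject_accuracy, remain after this commit):

from llmdataparser.tmlu_parser import TMLUDatasetParser

parser = TMLUDatasetParser()  # assumed no-argument constructor
metrics = parser.get_evaluation_metrics()

# After this commit only the two primary metrics remain.
primary_names = [m.name for m in metrics if m.primary]
print(primary_names)  # expected to contain 'accuracy' and 'per_subject_accuracy'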
llmdataparser/tw_legal_parser.py
CHANGED
@@ -1,7 +1,12 @@
 from dataclasses import dataclass
 from typing import Any, Final
 
-from llmdataparser.base_parser import HuggingFaceDatasetParser, HuggingFaceParseEntry
+from llmdataparser.base_parser import (
+    DatasetDescription,
+    EvaluationMetric,
+    HuggingFaceDatasetParser,
+    HuggingFaceParseEntry,
+)
 from llmdataparser.prompts import TW_LEGAL_SYSTEM_PROMPT
 
 TW_LEGAL_VALID_ANSWERS: Final[set[str]] = {"A", "B", "C", "D"}
@@ -70,6 +75,35 @@ class TWLegalDatasetParser(HuggingFaceDatasetParser[TWLegalParseEntry]):
             task_name=task,
         )
 
+    def get_dataset_description(self) -> DatasetDescription:
+        """Returns description of the Taiwan Legal Benchmark dataset."""
+        return DatasetDescription.create(
+            name="Taiwan Legal Benchmark",
+            language="Traditional Chinese",
+            purpose="Evaluate models on Taiwan-specific legal knowledge and understanding",
+            source="Taiwan Bar Examination questions",
+            format="Multiple choice questions (A/B/C/D)",
+            characteristics=(
+                "Contains questions from Taiwan's bar examination, testing understanding "
+                "of Taiwan's legal system, terminology, and concepts"
+            ),
+            citation="""
+                url={https://huggingface.co/datasets/lianghsun/tw-legal-benchmark-v1}
+            }""",
+        )
+
+    def get_evaluation_metrics(self) -> list[EvaluationMetric]:
+        """Returns recommended evaluation metrics for Taiwan Legal Benchmark."""
+        return [
+            EvaluationMetric.create(
+                name="accuracy",
+                type="classification",
+                description="Overall percentage of correctly answered legal questions",
+                implementation="datasets.load_metric('accuracy')",
+                primary=True,
+            ),
+        ]
+
 
 if __name__ == "__main__":
     # Example usage
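A minimal usage sketch for the two new methods (assumption: TWLegalDatasetParser can be instantiated without arguments, mirroring how the tests' tw_legal_parser fixture presumably builds it):

from llmdataparser.tw_legal_parser import TWLegalDatasetParser

parser = TWLegalDatasetParser()  # assumed no-argument constructor

description = parser.get_dataset_description()
print(description.name)      # "Taiwan Legal Benchmark"
print(description.language)  # "Traditional Chinese"

metrics = parser.get_evaluation_metrics()
assert len(metrics) == 1 and metrics[0].name == "accuracy"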
tests/test_tmlu_parser.py
CHANGED
@@ -187,23 +187,10 @@ def test_get_evaluation_metrics(tmlu_parser):
     """Test evaluation metrics generation."""
     metrics = tmlu_parser.get_evaluation_metrics()
 
-    assert len(metrics) == 5
+    assert len(metrics) == 2  # Check total number of metrics
 
     # Check primary metrics
     primary_metrics = [m for m in metrics if m.primary]
     assert len(primary_metrics) == 2
     assert any(m.name == "accuracy" for m in primary_metrics)
     assert any(m.name == "per_subject_accuracy" for m in primary_metrics)
-
-    # Check specific metric properties
-    accuracy_metric = next(m for m in metrics if m.name == "accuracy")
-    assert accuracy_metric.type == "classification"
-    assert "datasets.load_metric('accuracy')" in accuracy_metric.implementation
-
-    # Check non-primary metrics
-    non_primary_metrics = {m.name for m in metrics if not m.primary}
-    assert non_primary_metrics == {
-        "per_difficulty_accuracy",
-        "confusion_matrix",
-        "explanation_quality",
-    }
tests/test_tw_legal_parser.py
CHANGED
@@ -138,3 +138,27 @@ def test_system_prompt_override(tw_legal_parser):
 
     entry = parser.process_entry(test_entry)
     assert custom_prompt in entry.prompt
+
+
+def test_get_dataset_description(tw_legal_parser):
+    """Test getting dataset description for Taiwan Legal parser."""
+    description = tw_legal_parser.get_dataset_description()
+
+    assert description.name == "Taiwan Legal Benchmark"
+    assert description.language == "Traditional Chinese"
+    assert "Taiwan's legal system" in description.characteristics
+    assert (
+        "huggingface.co/datasets/lianghsun/tw-legal-benchmark-v1"
+        in description.citation
+    )
+
+
+def test_get_evaluation_metrics(tw_legal_parser):
+    """Test getting evaluation metrics for Taiwan Legal parser."""
+    metrics = tw_legal_parser.get_evaluation_metrics()
+
+    assert len(metrics) == 1
+    metric = metrics[0]
+    assert metric.name == "accuracy"
+    assert metric.type == "classification"
+    assert metric.primary is True