JeffYang52415 committed on
Commit
2e6d41b
·
unverified ·
1 Parent(s): 793be05

refactor: tmlu/tw_legal parser

Browse files
llmdataparser/tmlu_parser.py CHANGED
@@ -169,27 +169,6 @@ class TMLUDatasetParser(HuggingFaceDatasetParser[TMLUParseEntry]):
169
  implementation="custom_subject_accuracy",
170
  primary=True,
171
  ),
172
- EvaluationMetric.create(
173
- name="per_difficulty_accuracy",
174
- type="classification",
175
- description="Accuracy broken down by test difficulty levels",
176
- implementation="custom_difficulty_accuracy",
177
- primary=False,
178
- ),
179
- EvaluationMetric.create(
180
- name="confusion_matrix",
181
- type="classification",
182
- description="Distribution of predicted vs actual answers",
183
- implementation="datasets.load_metric('confusion_matrix')",
184
- primary=False,
185
- ),
186
- EvaluationMetric.create(
187
- name="explanation_quality",
188
- type="text",
189
- description="Quality assessment of model explanations when available",
190
- implementation="custom_explanation_metric",
191
- primary=False,
192
- ),
193
  ]
194
 
195
 
 
169
  implementation="custom_subject_accuracy",
170
  primary=True,
171
  ),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
172
  ]
173
 
174
 
llmdataparser/tw_legal_parser.py CHANGED
@@ -1,7 +1,12 @@
1
  from dataclasses import dataclass
2
  from typing import Any, Final
3
 
4
- from llmdataparser.base_parser import HuggingFaceDatasetParser, HuggingFaceParseEntry
 
 
 
 
 
5
  from llmdataparser.prompts import TW_LEGAL_SYSTEM_PROMPT
6
 
7
  TW_LEGAL_VALID_ANSWERS: Final[set[str]] = {"A", "B", "C", "D"}
@@ -70,6 +75,35 @@ class TWLegalDatasetParser(HuggingFaceDatasetParser[TWLegalParseEntry]):
70
  task_name=task,
71
  )
72
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
 
74
  if __name__ == "__main__":
75
  # Example usage
 
1
  from dataclasses import dataclass
2
  from typing import Any, Final
3
 
4
+ from llmdataparser.base_parser import (
5
+ DatasetDescription,
6
+ EvaluationMetric,
7
+ HuggingFaceDatasetParser,
8
+ HuggingFaceParseEntry,
9
+ )
10
  from llmdataparser.prompts import TW_LEGAL_SYSTEM_PROMPT
11
 
12
  TW_LEGAL_VALID_ANSWERS: Final[set[str]] = {"A", "B", "C", "D"}
 
75
  task_name=task,
76
  )
77
 
78
+ def get_dataset_description(self) -> DatasetDescription:
79
+ """Returns description of the Taiwan Legal Benchmark dataset."""
80
+ return DatasetDescription.create(
81
+ name="Taiwan Legal Benchmark",
82
+ language="Traditional Chinese",
83
+ purpose="Evaluate models on Taiwan-specific legal knowledge and understanding",
84
+ source="Taiwan Bar Examination questions",
85
+ format="Multiple choice questions (A/B/C/D)",
86
+ characteristics=(
87
+ "Contains questions from Taiwan's bar examination, testing understanding "
88
+ "of Taiwan's legal system, terminology, and concepts"
89
+ ),
90
+ citation="""
91
+ url={https://huggingface.co/datasets/lianghsun/tw-legal-benchmark-v1}
92
+ }""",
93
+ )
94
+
95
+ def get_evaluation_metrics(self) -> list[EvaluationMetric]:
96
+ """Returns recommended evaluation metrics for Taiwan Legal Benchmark."""
97
+ return [
98
+ EvaluationMetric.create(
99
+ name="accuracy",
100
+ type="classification",
101
+ description="Overall percentage of correctly answered legal questions",
102
+ implementation="datasets.load_metric('accuracy')",
103
+ primary=True,
104
+ ),
105
+ ]
106
+
107
 
108
  if __name__ == "__main__":
109
  # Example usage
tests/test_tmlu_parser.py CHANGED
@@ -187,23 +187,10 @@ def test_get_evaluation_metrics(tmlu_parser):
187
  """Test evaluation metrics generation."""
188
  metrics = tmlu_parser.get_evaluation_metrics()
189
 
190
- assert len(metrics) == 5 # Check total number of metrics
191
 
192
  # Check primary metrics
193
  primary_metrics = [m for m in metrics if m.primary]
194
  assert len(primary_metrics) == 2
195
  assert any(m.name == "accuracy" for m in primary_metrics)
196
  assert any(m.name == "per_subject_accuracy" for m in primary_metrics)
197
-
198
- # Check specific metric properties
199
- accuracy_metric = next(m for m in metrics if m.name == "accuracy")
200
- assert accuracy_metric.type == "classification"
201
- assert "datasets.load_metric('accuracy')" in accuracy_metric.implementation
202
-
203
- # Check non-primary metrics
204
- non_primary_metrics = {m.name for m in metrics if not m.primary}
205
- assert non_primary_metrics == {
206
- "per_difficulty_accuracy",
207
- "confusion_matrix",
208
- "explanation_quality",
209
- }
 
187
  """Test evaluation metrics generation."""
188
  metrics = tmlu_parser.get_evaluation_metrics()
189
 
190
+ assert len(metrics) == 2 # Check total number of metrics
191
 
192
  # Check primary metrics
193
  primary_metrics = [m for m in metrics if m.primary]
194
  assert len(primary_metrics) == 2
195
  assert any(m.name == "accuracy" for m in primary_metrics)
196
  assert any(m.name == "per_subject_accuracy" for m in primary_metrics)
 
 
 
 
 
 
 
 
 
 
 
 
 
tests/test_tw_legal_parser.py CHANGED
@@ -138,3 +138,27 @@ def test_system_prompt_override(tw_legal_parser):
138
 
139
  entry = parser.process_entry(test_entry)
140
  assert custom_prompt in entry.prompt
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
138
 
139
  entry = parser.process_entry(test_entry)
140
  assert custom_prompt in entry.prompt
141
+
142
+
143
+ def test_get_dataset_description(tw_legal_parser):
144
+ """Test getting dataset description for Taiwan Legal parser."""
145
+ description = tw_legal_parser.get_dataset_description()
146
+
147
+ assert description.name == "Taiwan Legal Benchmark"
148
+ assert description.language == "Traditional Chinese"
149
+ assert "Taiwan's legal system" in description.characteristics
150
+ assert (
151
+ "huggingface.co/datasets/lianghsun/tw-legal-benchmark-v1"
152
+ in description.citation
153
+ )
154
+
155
+
156
+ def test_get_evaluation_metrics(tw_legal_parser):
157
+ """Test getting evaluation metrics for Taiwan Legal parser."""
158
+ metrics = tw_legal_parser.get_evaluation_metrics()
159
+
160
+ assert len(metrics) == 1
161
+ metric = metrics[0]
162
+ assert metric.name == "accuracy"
163
+ assert metric.type == "classification"
164
+ assert metric.primary is True