refactor: mbpp parser
- llmdataparser/mbpp_parser.py  +70 -1
- tests/test_mbpp_parser.py  +71 -0
llmdataparser/mbpp_parser.py
CHANGED
@@ -1,7 +1,12 @@
 from dataclasses import dataclass
 from typing import Any, ClassVar
 
-from llmdataparser.base_parser import
+from llmdataparser.base_parser import (
+    DatasetDescription,
+    EvaluationMetric,
+    HuggingFaceDatasetParser,
+    HuggingFaceParseEntry,
+)
 from llmdataparser.prompts import MBPP_SYSTEM_PROMPT
 
 
@@ -83,6 +88,70 @@ class MBPPDatasetParser(HuggingFaceDatasetParser[MBPPParseEntry]):
             source_file=source_file,
         )
 
+    def get_dataset_description(self) -> DatasetDescription:
+        """Returns a description of the MBPP dataset."""
+        return DatasetDescription.create(
+            name="Mostly Basic Python Problems (MBPP)",
+            purpose="A benchmark for evaluating code generation capabilities using entry-level Python programming problems",
+            source="https://github.com/google-research/google-research/tree/master/mbpp",
+            language="English and Python",
+            format="Task descriptions in English with corresponding Python solutions and automated test cases",
+            characteristics=(
+                "Contains approximately 1,000 crowd-sourced Python programming problems "
+                "designed for entry-level programmers. Problems cover programming fundamentals "
+                "and standard library functionality. Each problem includes a task description, "
+                "code solution, and 3 automated test cases. A subset of the data has been "
+                "hand-verified by the authors."
+            ),
+            citation=(
+                "@article{austin2021program,\n"
+                "  title={Program Synthesis with Large Language Models},\n"
+                "  author={Austin, Jacob and Odena, Augustus and Nye, Maxwell and Bosma, Maarten and Michalewski, Henryk and Dohan, David and Jiang, Ellen and Cai, Carrie and Terry, Michael and Le, Quoc and others},\n"
+                "  journal={arXiv preprint arXiv:2108.07732},\n"
+                "  year={2021}\n"
+                "}"
+            ),
+            additional_info={
+                "size": "~1,000 programming problems",
+                "splits": "Available in full or sanitized versions",
+                "test_coverage": "Each problem includes 3 automated test cases",
+                "verification": "Subset of data has been hand-verified by authors",
+            },
+        )
+
+    def get_evaluation_metrics(self) -> list[EvaluationMetric]:
+        """Returns the recommended evaluation metrics for MBPP dataset."""
+        return [
+            EvaluationMetric.create(
+                name="pass@k",
+                type="code_evaluation",
+                description="Percentage of problems where at least one solution in k generations passes all test cases",
+                implementation="custom_pass_at_k",
+                primary=True,
+            ),
+            EvaluationMetric.create(
+                name="test_case_success_rate",
+                type="code_evaluation",
+                description="Percentage of test cases passed across all problems",
+                implementation="custom_test_success_rate",
+                primary=False,
+            ),
+            EvaluationMetric.create(
+                name="syntax_validity",
+                type="code_evaluation",
+                description="Verifies that generated code is syntactically valid Python",
+                implementation="custom_syntax_check",
+                primary=False,
+            ),
+            EvaluationMetric.create(
+                name="code_similarity",
+                type="similarity",
+                description="Similarity between generated code and reference solution",
+                implementation="evaluate.load('code_eval')",
+                primary=False,
+            ),
+        ]
+
 
 if __name__ == "__main__":
     # Example usage
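Below is a minimal usage sketch of the two methods this change adds. It assumes MBPPDatasetParser can be constructed with no arguments, which the diff does not show; the method and attribute names come from the diff and its tests.

    from llmdataparser.mbpp_parser import MBPPDatasetParser

    parser = MBPPDatasetParser()

    # Dataset metadata introduced by get_dataset_description()
    description = parser.get_dataset_description()
    print(description.name)    # "Mostly Basic Python Problems (MBPP)"
    print(description.source)  # google-research MBPP repository URL

    # Recommended metrics; pass@k is the only primary one
    for metric in parser.get_evaluation_metrics():
        print(metric.name, metric.type, metric.primary)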
tests/test_mbpp_parser.py
CHANGED
@@ -152,3 +152,74 @@ def test_custom_system_prompt():
 def test_default_system_prompt(parser):
     """Test parser uses default system prompt when none provided"""
     assert parser._system_prompt == parser._default_system_prompt
+
+
+def test_get_dataset_description(parser):
+    """Test dataset description generation."""
+    description = parser.get_dataset_description()
+
+    assert description.name == "Mostly Basic Python Problems (MBPP)"
+    assert "code generation" in description.purpose.lower()
+    assert "google-research" in description.source
+    assert description.language == "English and Python"
+    assert "task descriptions" in description.format.lower()
+    assert "python solutions" in description.format.lower()
+    assert "1,000" in description.characteristics
+    assert "entry-level programmers" in description.characteristics.lower()
+    assert "3 automated test cases" in description.characteristics
+    assert "hand-verified" in description.characteristics
+    assert "austin2021program" in description.citation
+    assert "Program Synthesis" in description.citation
+
+    # Check additional info
+    assert description.additional_info is not None
+    assert description.additional_info["size"] == "~1,000 programming problems"
+    assert (
+        description.additional_info["splits"]
+        == "Available in full or sanitized versions"
+    )
+    assert (
+        description.additional_info["test_coverage"]
+        == "Each problem includes 3 automated test cases"
+    )
+    assert (
+        description.additional_info["verification"]
+        == "Subset of data has been hand-verified by authors"
+    )
+
+
+def test_get_evaluation_metrics(parser):
+    """Test evaluation metrics generation."""
+    metrics = parser.get_evaluation_metrics()
+
+    # Check total number of metrics
+    assert len(metrics) == 4
+
+    # Check primary metrics
+    primary_metrics = [m for m in metrics if m.primary]
+    assert len(primary_metrics) == 1
+
+    # Verify specific metrics exist with correct properties
+    metric_names = {m.name for m in metrics}
+    assert "pass@k" in metric_names
+    assert "test_case_success_rate" in metric_names
+    assert "syntax_validity" in metric_names
+
+    # Check specific metric properties
+    pass_k_metric = next(m for m in metrics if m.name == "pass@k")
+    assert pass_k_metric.type == "code_evaluation"
+    assert pass_k_metric.primary is True
+    assert "k generations" in pass_k_metric.description.lower()
+    assert "custom_pass_at_k" in pass_k_metric.implementation
+
+    test_case_metric = next(m for m in metrics if m.name == "test_case_success_rate")
+    assert test_case_metric.type == "code_evaluation"
+    assert test_case_metric.primary is False
+    assert "test cases" in test_case_metric.description.lower()
+    assert "custom_test_success_rate" in test_case_metric.implementation
+
+    syntax_metric = next(m for m in metrics if m.name == "syntax_validity")
+    assert syntax_metric.type == "code_evaluation"
+    assert syntax_metric.primary is False
+    assert "syntactically valid" in syntax_metric.description.lower()
+    assert "custom_syntax_check" in syntax_metric.implementation
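The pass@k metric registered above only names "custom_pass_at_k" as its implementation; the estimator itself is not part of this change. For reference, a sketch of the standard unbiased pass@k estimator (Chen et al., 2021), which such an implementation typically follows. This is an assumption about what custom_pass_at_k computes, not code from this repository:

    import math

    def pass_at_k(n: int, c: int, k: int) -> float:
        """Unbiased pass@k for one problem: n samples drawn, c of them passing all tests."""
        if n - c < k:
            return 1.0  # every size-k subset contains at least one correct sample
        # 1 - C(n - c, k) / C(n, k), expanded as a product for numerical stability
        return 1.0 - math.prod(1.0 - k / i for i in range(n - c + 1, n + 1))

    # Example: 10 samples per problem, 3 of them correct -> pass@1 estimate of 0.3
    print(round(pass_at_k(n=10, c=3, k=1), 2))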