File size: 3,126 Bytes
b65e855 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 |
from dataclasses import dataclass
from typing import Any, ClassVar
from llmdataparser.base_parser import HuggingFaceDatasetParser, HuggingFaceParseEntry
@dataclass(frozen=True, kw_only=True, slots=True)
class MATHParseEntry(HuggingFaceParseEntry):
    """Parsed MATH record.

    Adds MATH-specific metadata (difficulty level, task name, reference
    solution) on top of the shared entry fields ``prompt``, ``answer``,
    ``raw_question`` and ``raw_answer`` inherited from the base entry class.
    """

    # Difficulty label of the problem.
    level: str
    # Task (subject) name the problem is filed under.
    task_name: str
    # Reference solution text.
    solution: str

    @classmethod
    def create(
        cls,
        prompt: str,
        answer: str,
        raw_question: str,
        raw_answer: str,
        level: str,
        task_name: str,
        solution: str,
    ) -> "MATHParseEntry":
        """Construct an entry from the given field values.

        The dataclass is keyword-only (``kw_only=True``), so this factory is
        the convenient way to build one from positional arguments; every
        value is forwarded to the constructor by field name.

        Returns:
            A new, immutable ``MATHParseEntry``.
        """
        init_kwargs = {
            "prompt": prompt,
            "answer": answer,
            "raw_question": raw_question,
            "raw_answer": raw_answer,
            "level": level,
            "task_name": task_name,
            "solution": solution,
        }
        return cls(**init_kwargs)
class MATHDatasetParser(HuggingFaceDatasetParser[MATHParseEntry]):
    """Parser for the MATH competition-mathematics dataset (``lighteval/MATH``)."""

    _data_source: ClassVar[str] = "lighteval/MATH"
    # Dataset configurations: the seven MATH subjects plus the combined "all".
    # MATH has no "calculus" subject; its remaining subject is
    # "Counting & Probability", published as "counting_and_probability".
    _task_names: ClassVar[list[str]] = [
        "algebra",
        "geometry",
        "counting_and_probability",
        "prealgebra",
        "intermediate_algebra",
        "number_theory",
        "precalculus",
        "all",
    ]
    _default_task: ClassVar[str] = "all"
    _default_system_prompt: ClassVar[
        str
    ] = "Solve the following mathematics problem step by step:"
    # Only "Level 1" .. "Level 5" are accepted; any other value is reported
    # as "Unknown" by process_entry.
    _valid_levels: ClassVar[set[str]] = {f"Level {i}" for i in range(1, 6)}

    @staticmethod
    def _normalize_task_name(entry_type: str) -> str:
        """Convert a human-readable subject label to snake_case task naming.

        Examples: "Intermediate Algebra" -> "intermediate_algebra",
        "Counting & Probability" -> "counting_and_probability". Labels that are
        already snake_case (e.g. "algebra") are returned unchanged.
        """
        # split() with no separator collapses any run of whitespace, so the
        # padding around the substituted "and" never yields empty segments.
        return "_".join(entry_type.lower().replace("&", " and ").split())

    def _get_task_from_entry(self, data_entry: dict[str, Any]) -> str:
        """Return the task name for a single row.

        Uses the row's ``type`` field when, once normalized, it names a known
        task. Otherwise falls back to the currently loaded task, then to the
        default task.
        """
        entry_type = data_entry.get("type")
        if isinstance(entry_type, str) and entry_type:
            normalized = self._normalize_task_name(entry_type)
            if normalized in self._task_names:
                return normalized
        # _current_task is not defined in this class; presumably it is set by
        # the base parser when a task is loaded.
        return self._current_task or self._default_task

    def process_entry(
        self, row: dict[str, Any], task_name: str | None = None, **kwargs: Any
    ) -> MATHParseEntry:
        """Process a single MATH dataset row into a ``MATHParseEntry``.

        Args:
            row: One dataset record. ``problem`` and ``solution`` are required;
                ``level`` and ``type`` are optional.
            task_name: Explicit task label. When falsy, the task is derived
                from the row's ``type`` field via ``_get_task_from_entry``.
            **kwargs: Accepted for interface compatibility; unused here.

        Returns:
            The parsed entry. ``answer``, ``raw_answer`` and ``solution`` all
            hold the row's full ``solution`` text. ``prompt`` is the system
            prompt followed by the problem on a new line.

        Raises:
            KeyError: If ``problem`` or ``solution`` is missing from ``row``.
        """
        # Derive the subject from the row itself when no explicit task is
        # given, so rows read from the combined "all" configuration keep
        # their individual subject instead of all being labelled "all".
        task = task_name or self._get_task_from_entry(row)

        # Validate and normalize the level. The isinstance check also keeps an
        # unhashable value from raising TypeError during the set lookup.
        raw_level = row.get("level")
        if isinstance(raw_level, str) and raw_level in self._valid_levels:
            level = raw_level
        else:
            level = "Unknown"

        problem = row["problem"]
        solution = row["solution"]
        return MATHParseEntry.create(
            prompt=f"{self._system_prompt}\n{problem}",
            answer=solution,
            raw_question=problem,
            raw_answer=solution,
            level=level,
            task_name=task,
            solution=solution,
        )
if __name__ == "__main__":
    # Demo: build a MATH parser, load and parse the data, then show the
    # first parsed entry (if any).
    math_parser = MATHDatasetParser()
    math_parser.load()  # load the dataset
    math_parser.parse()  # parse all splits

    # get_parsed_data is read as an attribute here, not called.
    if entries := math_parser.get_parsed_data:
        sample = entries[0]
        print("\nExample parsed entry:")
        print(f"Task: {sample.task_name}")
        print(f"Level: {sample.level}")
        print(f"Question: {sample.raw_question}")
        print(f"Solution: {sample.solution}")
|