File size: 3,126 Bytes
b65e855
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
from dataclasses import dataclass
from typing import Any, ClassVar

from llmdataparser.base_parser import HuggingFaceDatasetParser, HuggingFaceParseEntry


@dataclass(frozen=True, kw_only=True, slots=True)
class MATHParseEntry(HuggingFaceParseEntry):
    """Immutable parsed record for a single MATH example.

    Adds the MATH-specific difficulty level, task name and worked solution to
    the common entry fields (prompt, answer, raw_question, raw_answer) that
    ``create`` forwards to the base entry type.
    """

    # Difficulty label such as "Level 3"; MATHDatasetParser stores "Unknown"
    # for anything outside Level 1-5.
    level: str
    # Task the example was parsed under.
    task_name: str
    # Worked solution text taken from the dataset row.
    solution: str

    @classmethod
    def create(
        cls,
        prompt: str,
        answer: str,
        raw_question: str,
        raw_answer: str,
        level: str,
        task_name: str,
        solution: str,
    ) -> "MATHParseEntry":
        """Factory that accepts the fields positionally or by keyword.

        The dataclass constructor itself is keyword-only (``kw_only=True``),
        so this helper exists to give callers a fixed positional order.
        """
        field_values = {
            "prompt": prompt,
            "answer": answer,
            "raw_question": raw_question,
            "raw_answer": raw_answer,
            "level": level,
            "task_name": task_name,
            "solution": solution,
        }
        return cls(**field_values)


class MATHDatasetParser(HuggingFaceDatasetParser[MATHParseEntry]):
    """Parser for the MATH dataset (``lighteval/MATH``).

    Each raw row must provide ``problem`` and ``solution``; ``level`` and
    ``type`` are optional and are normalised when present.
    """

    _data_source: ClassVar[str] = "lighteval/MATH"
    _task_names: ClassVar[list[str]] = [
        "algebra",
        "geometry",
        # NOTE(review): MATH has no calculus subject and this name does not
        # appear to correspond to a lighteval/MATH config. Kept so existing
        # callers are unaffected; confirm and remove.
        "calculus",
        "prealgebra",
        "intermediate_algebra",
        "number_theory",
        "precalculus",
        # MATH subject labelled "Counting & Probability" in the raw rows.
        "counting_and_probability",
        "all",
    ]
    _default_task: ClassVar[str] = "all"
    _default_system_prompt: ClassVar[
        str
    ] = "Solve the following mathematics problem step by step:"
    _valid_levels: ClassVar[set[str]] = {
        f"Level {i}" for i in range(1, 6)
    }  # Levels 1-5 are valid

    def _get_task_from_entry(self, data_entry: dict[str, Any]) -> str:
        """Return the task name recorded in ``data_entry``, or the loaded task.

        MATH rows label their subject in a human-readable ``type`` field
        (e.g. ``"Intermediate Algebra"``, ``"Counting & Probability"``). A
        verbatim comparison against the snake_case names in ``_task_names``
        therefore never matches. The value is normalised first: it is
        lower-cased, ``&`` is spelled out as ``and``, and whitespace runs are
        joined with ``_``. Values that are already snake_case task names map
        to themselves.

        Args:
            data_entry: One raw dataset row.

        Returns:
            The matching task name. Otherwise ``_current_task``, or
            ``_default_task`` when ``_current_task`` is empty.
        """
        entry_type = data_entry.get("type")
        if isinstance(entry_type, str) and entry_type:
            normalized = "_".join(entry_type.lower().replace("&", " and ").split())
            if normalized in self._task_names:
                return normalized
        return self._current_task or self._default_task

    def process_entry(
        self, row: dict[str, Any], task_name: str | None = None, **kwargs: Any
    ) -> MATHParseEntry:
        """Convert one raw MATH row into a :class:`MATHParseEntry`.

        Args:
            row: Raw dataset row. ``problem`` and ``solution`` are required;
                ``level`` and ``type`` are optional.
            task_name: Explicit task name. When falsy, the task is derived
                from ``row`` via ``_get_current_task``.
            **kwargs: Ignored.

        Returns:
            The parsed entry. ``answer``, ``raw_answer`` and ``solution`` all
            hold the row's full solution text. ``level`` is ``"Unknown"``
            unless the row's level is one of ``"Level 1"`` .. ``"Level 5"``.

        Raises:
            KeyError: If ``problem`` or ``solution`` is missing from ``row``.
        """
        task = task_name or self._get_current_task(row)
        problem = row["problem"]
        solution = row["solution"]

        # Anything outside Level 1-5 (missing, None, malformed labels or
        # non-string values) collapses to "Unknown". The isinstance guard
        # keeps unhashable values from raising TypeError on the set lookup.
        level = row.get("level")
        if not (isinstance(level, str) and level in self._valid_levels):
            level = "Unknown"

        return MATHParseEntry.create(
            prompt=f"{self._system_prompt}\n{problem}",
            answer=solution,
            raw_question=problem,
            raw_answer=solution,
            level=level,
            task_name=task,
            solution=solution,
        )


if __name__ == "__main__":
    # Demo run: download MATH, parse every split, then show the first entry.
    math_parser = MATHDatasetParser()
    math_parser.load()
    math_parser.parse()

    # Read as an attribute rather than called; presumably a property defined
    # on the base parser class (TODO confirm).
    entries = math_parser.get_parsed_data

    if entries:
        example = entries[0]
        report_lines = [
            "\nExample parsed entry:",
            f"Task: {example.task_name}",
            f"Level: {example.level}",
            f"Question: {example.raw_question}",
            f"Solution: {example.solution}",
        ]
        print("\n".join(report_lines))