File size: 3,275 Bytes
952a3b5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
from dataclasses import dataclass
from typing import Any, ClassVar

from llmdataparser.base_parser import HuggingFaceDatasetParser, HuggingFaceParseEntry
from llmdataparser.prompts import MGSM_SYSTEM_PROMPT


@dataclass(frozen=True, kw_only=True, slots=True)
class MGSMParseEntry(HuggingFaceParseEntry):
    """Immutable parse result for a single MGSM example.

    Adds the MGSM-specific fields below on top of ``HuggingFaceParseEntry``.
    Every field is keyword-only (``kw_only=True``), and instances are frozen.
    """

    # Numeric final answer of the problem (the parser fills it from the row's
    # ``answer_number`` value).
    numerical_answer: int | float
    # Equation-based solution text, or None when the row provides none.
    equation_solution: str | None
    # Language code of the subset the example belongs to.
    language: str

    @classmethod
    def create(
        cls,
        prompt: str,
        answer: str,
        raw_question: str,
        raw_answer: str,
        numerical_answer: int | float,
        equation_solution: str | None,
        task_name: str,
        language: str,
    ) -> "MGSMParseEntry":
        """Construct an entry from its individual field values.

        This is a convenience factory: each argument is forwarded unchanged,
        by keyword, to the dataclass constructor.

        Returns:
            A new ``MGSMParseEntry`` holding exactly the given values.
        """
        return cls(
            # Fields shared with the generic Hugging Face entry.
            prompt=prompt,
            answer=answer,
            raw_question=raw_question,
            raw_answer=raw_answer,
            task_name=task_name,
            # MGSM-specific fields.
            numerical_answer=numerical_answer,
            equation_solution=equation_solution,
            language=language,
        )


class MGSMDatasetParser(HuggingFaceDatasetParser[MGSMParseEntry]):
    """Dataset parser for MGSM (Multilingual Grade School Math).

    Reads from the ``juletxara/mgsm`` Hugging Face source. Each task is a
    language code, and ``"en"`` is the default task.
    """

    _data_source: ClassVar[str] = "juletxara/mgsm"
    _default_task: ClassVar[str] = "en"
    # Language codes this parser exposes as tasks.
    _task_names: ClassVar[list[str]] = [
        "bn", "de", "en", "es", "fr", "ja",
        "ru", "sw", "te", "th", "zh",
    ]
    _default_system_prompt: ClassVar[str] = MGSM_SYSTEM_PROMPT

    def process_entry(
        self, row: dict[str, Any], task_name: str | None = None, **kwargs: Any
    ) -> MGSMParseEntry:
        """
        Turn one raw MGSM row into an ``MGSMParseEntry``.

        Args:
            row: Raw dataset row. Its ``question``, ``answer``,
                ``answer_number`` and ``equation_solution`` keys are read.
            task_name: Language code to tag the entry with. When it is None
                or empty, the code is obtained from
                ``self._get_current_task(row)``.
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            MGSMParseEntry: Entry whose prompt is the system prompt followed by
            the question, and whose ``task_name`` and ``language`` are both the
            resolved language code.
        """
        language_code = task_name or self._get_current_task(row)

        question_text = row["question"]
        # A missing/empty worked solution is normalised to "".
        solution_text = row["answer"] or ""
        final_number = row["answer_number"]
        equation = row["equation_solution"]

        full_prompt = f"{self._system_prompt}\n{question_text}"

        # Prefer the worked solution; otherwise fall back to the bare number.
        reference_answer = solution_text or str(final_number)

        return MGSMParseEntry.create(
            prompt=full_prompt,
            answer=reference_answer,
            raw_question=question_text,
            raw_answer=solution_text,
            numerical_answer=final_number,
            equation_solution=equation,
            task_name=language_code,
            language=language_code,
        )


if __name__ == "__main__":
    # Small manual demo: load one MGSM language subset, parse it, and print
    # selected fields of the first parsed entry.
    from pprint import pprint

    parser = MGSMDatasetParser()
    parser.load(task_name="en")  # Load the English subset
    parser.parse()

    # Accessed as an attribute (not called), as in the original script;
    # presumably a property on the base parser -- confirm in base_parser.
    parsed_data = parser.get_parsed_data
    if not parsed_data:
        # Fail with a clear message instead of an IndexError on parsed_data[0].
        raise SystemExit("No MGSM entries were parsed.")

    first_entry = parsed_data[0]
    for field_name in (
        "prompt",
        "answer",
        "raw_question",
        "numerical_answer",
        "language",
    ):
        pprint(getattr(first_entry, field_name))