|
from dataclasses import dataclass |
|
from typing import Any, ClassVar |
|
|
|
from llmdataparser.base_parser import HuggingFaceDatasetParser, HuggingFaceParseEntry |
|
from llmdataparser.prompts import MGSM_SYSTEM_PROMPT |
|
|
|
|
|
@dataclass(frozen=True, kw_only=True, slots=True)
class MGSMParseEntry(HuggingFaceParseEntry):
    """Immutable, keyword-only parse entry carrying MGSM-specific fields.

    Adds the three fields below on top of whatever ``HuggingFaceParseEntry``
    declares (presumably ``prompt``, ``answer``, ``raw_question``,
    ``raw_answer`` and ``task_name``, since ``create`` passes them to the
    constructor -- confirm against the base class).
    """

    # Numeric form of the answer.
    numerical_answer: int | float
    # Equation-style solution text; ``None`` is permitted by the annotation.
    equation_solution: str | None
    # Language associated with this entry.
    language: str

    @classmethod
    def create(
        cls,
        prompt: str,
        answer: str,
        raw_question: str,
        raw_answer: str,
        numerical_answer: int | float,
        equation_solution: str | None,
        task_name: str,
        language: str,
    ) -> "MGSMParseEntry":
        """Build an entry by forwarding every argument to the constructor.

        The dataclass is keyword-only, so all values are handed over by name.
        No validation or transformation is applied here.

        Returns:
            A new instance of ``cls`` populated with the given values.
        """
        field_values: dict[str, Any] = {
            "prompt": prompt,
            "answer": answer,
            "raw_question": raw_question,
            "raw_answer": raw_answer,
            "numerical_answer": numerical_answer,
            "equation_solution": equation_solution,
            "task_name": task_name,
            "language": language,
        }
        return cls(**field_values)
|
|
|
|
|
class MGSMDatasetParser(HuggingFaceDatasetParser[MGSMParseEntry]):
    """Parser for the MGSM (Multilingual Grade School Math) dataset."""

    # Dataset identifier used as the data source.
    _data_source: ClassVar[str] = "juletxara/mgsm"
    # Task used when the caller does not pick one.
    _default_task: ClassVar[str] = "en"
    # Supported task names; each one is a language code.
    _task_names: ClassVar[list[str]] = [
        "bn",
        "de",
        "en",
        "es",
        "fr",
        "ja",
        "ru",
        "sw",
        "te",
        "th",
        "zh",
    ]
    _default_system_prompt: ClassVar[str] = MGSM_SYSTEM_PROMPT

    def process_entry(
        self, row: dict[str, Any], task_name: str | None = None, **kwargs: Any
    ) -> MGSMParseEntry:
        """
        Convert one raw MGSM row into an ``MGSMParseEntry``.

        Args:
            row: Raw record; the keys ``"question"``, ``"answer"``,
                ``"answer_number"`` and ``"equation_solution"`` are read.
            task_name: Language code for the current task. When falsy, the
                task is obtained from ``self._get_current_task(row)``.
            **kwargs: Accepted but not used by this method.

        Returns:
            MGSMParseEntry: Entry whose prompt is the system prompt followed
            by a newline and the question, and whose answer is the row's
            answer text or, when that is empty, the stringified numeric
            answer. The resolved task is stored as both ``task_name`` and
            ``language``.
        """
        language_code = task_name or self._get_current_task(row)

        question = row["question"]
        # A missing/empty answer text is normalised to an empty string.
        answer_text = row["answer"] or ""
        answer_number = row["answer_number"]
        equation = row["equation_solution"]

        full_prompt = f"{self._system_prompt}\n{question}"

        # Prefer the answer text; fall back to the numeric answer as a string.
        final_answer = answer_text or str(answer_number)

        return MGSMParseEntry.create(
            prompt=full_prompt,
            answer=final_answer,
            raw_question=question,
            raw_answer=answer_text,
            numerical_answer=answer_number,
            equation_solution=equation,
            task_name=language_code,
            language=language_code,
        )
|
|
|
|
|
if __name__ == "__main__":
    from pprint import pprint

    # Demo run: load the English task, parse it, and show the first entry.
    parser = MGSMDatasetParser()
    parser.load(task_name="en")
    parser.parse()

    # NOTE(review): accessed without a call -- presumably a property on the
    # base parser; confirm against the base class.
    parsed_data = parser.get_parsed_data
    first_entry = parsed_data[0]
    for field_name in (
        "prompt",
        "answer",
        "raw_question",
        "numerical_answer",
        "language",
    ):
        pprint(getattr(first_entry, field_name))
|
|