|
from dataclasses import dataclass |
|
from typing import Any, ClassVar |
|
|
|
from llmdataparser.base_parser import HuggingFaceDatasetParser, HuggingFaceParseEntry |
|
from llmdataparser.prompts import GSM8K_SYSTEM_PROMPT |
|
|
|
|
|
@dataclass(frozen=True, kw_only=True, slots=True)
class GSM8KParseEntry(HuggingFaceParseEntry):
    """Immutable, keyword-only record for one parsed GSM8K example.

    Adds three GSM8K-specific fields to whatever ``HuggingFaceParseEntry``
    already declares (presumably ``prompt``, ``answer``, ``raw_question``
    and ``raw_answer``, since ``create`` forwards them -- confirm against
    the base class).
    """

    # Worked-solution text for the example.
    solution: str
    # Final answer as a number.
    numerical_answer: int | float
    # Name of the dataset task/config this entry belongs to.
    task_name: str

    @classmethod
    def create(
        cls,
        prompt: str,
        answer: str,
        raw_question: str,
        raw_answer: str,
        solution: str,
        numerical_answer: int | float,
        task_name: str,
    ) -> "GSM8KParseEntry":
        """Build an entry from explicit field values.

        The dataclass is keyword-only, so every argument is forwarded to the
        constructor by name.

        Returns:
            A new instance of ``cls`` populated with the given values.
        """
        field_values = {
            "prompt": prompt,
            "answer": answer,
            "raw_question": raw_question,
            "raw_answer": raw_answer,
            "solution": solution,
            "numerical_answer": numerical_answer,
            "task_name": task_name,
        }
        return cls(**field_values)
|
|
|
|
|
class GSM8KDatasetParser(HuggingFaceDatasetParser[GSM8KParseEntry]):
    """Parser for the GSM8K dataset.

    Each row's ``answer`` text is split on the ``####`` marker: the text
    before the first marker becomes the solution, and the text after the
    last marker is parsed as the numerical answer.
    """

    _data_source: ClassVar[str] = "openai/gsm8k"
    _task_names: ClassVar[list[str]] = ["main", "socratic"]
    _default_task: ClassVar[str] = "main"
    _default_system_prompt: ClassVar[str] = GSM8K_SYSTEM_PROMPT

    def process_entry(
        self, row: dict[str, Any], task_name: str | None = None, **kwargs: Any
    ) -> GSM8KParseEntry:
        """Process a single GSM8K entry.

        Args:
            row: Dataset row; must contain ``"question"`` and ``"answer"`` keys.
            task_name: Task to record on the entry. When falsy, the task is
                obtained from ``self._get_current_task(row)``.
            **kwargs: Accepted for interface compatibility; unused here.

        Returns:
            A ``GSM8KParseEntry`` whose ``answer`` is the string form of the
            parsed numerical answer.

        Raises:
            KeyError: If ``row`` lacks ``"question"`` or ``"answer"``.
            ValueError: If the text after the last ``####`` marker cannot be
                parsed as a number.
        """
        task = task_name or self._get_current_task(row)
        raw_question = row["question"]
        raw_answer = row["answer"]

        # Split once; the first piece is the solution, the last the final answer.
        parts = raw_answer.split("####")
        solution = parts[0].strip()
        # Drop thousands separators (e.g. "1,234") before numeric conversion.
        numerical_str = parts[-1].strip().replace(",", "")

        numerical_answer: int | float
        try:
            # Parse integers directly: going through float() first would
            # silently round values that exceed float's 53-bit precision.
            numerical_answer = int(numerical_str)
        except ValueError:
            try:
                as_float = float(numerical_str)
            except ValueError as err:
                raise ValueError(
                    f"Could not convert '{numerical_str}' to number"
                ) from err
            # Collapse integral floats such as "5.0" or "1e3" to int.
            numerical_answer = int(as_float) if as_float.is_integer() else as_float

        prompt = f"{self._system_prompt}\n{raw_question}"

        return GSM8KParseEntry.create(
            prompt=prompt,
            answer=str(numerical_answer),
            raw_question=raw_question,
            raw_answer=raw_answer,
            solution=solution,
            numerical_answer=numerical_answer,
            task_name=task,
        )
|
|
|
|
|
if __name__ == "__main__":
    from pprint import pprint

    # Manual smoke run: load the dataset, parse it, and show the first entry.
    parser = GSM8KDatasetParser()
    parser.load()
    parser.parse()

    parsed_data = parser.get_parsed_data
    first_entry = parsed_data[0]

    # Print each field of the first entry in the same order as before.
    for field_name in (
        "prompt",
        "answer",
        "raw_question",
        "raw_answer",
        "solution",
        "numerical_answer",
    ):
        pprint(getattr(first_entry, field_name))
|
|