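# NOTE: page-header residue from the Hugging Face file viewer, kept as comments
# so the module remains valid Python:
#   JeffYang52415's picture
#   feat: add gsm8k parser
#   424ff6a unverified
#   raw
#   history blame
#   3 kB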
from dataclasses import dataclass
from typing import Any, ClassVar
from llmdataparser.base_parser import HuggingFaceDatasetParser, HuggingFaceParseEntry
from llmdataparser.prompts import GSM8K_SYSTEM_PROMPT
@dataclass(frozen=True, kw_only=True, slots=True)
class GSM8KParseEntry(HuggingFaceParseEntry):
    """Immutable, keyword-only parse entry carrying the GSM8K-specific fields.

    Extends ``HuggingFaceParseEntry`` with the solution text, the numeric
    final answer, and the name of the task the row was parsed under.
    """

    # Solution text (the parser stores the part of the raw answer before '####').
    solution: str
    # Final answer as a number (the parser stores an int when the value is whole).
    numerical_answer: int | float
    # Task name this entry was parsed under.
    task_name: str

    @classmethod
    def create(
        cls,
        prompt: str,
        answer: str,
        raw_question: str,
        raw_answer: str,
        solution: str,
        numerical_answer: int | float,
        task_name: str,
    ) -> "GSM8KParseEntry":
        """Construct an entry, forwarding every argument to the same-named field.

        Args:
            prompt: Prompt text for the entry.
            answer: Answer text for the entry.
            raw_question: Question text as found in the dataset row.
            raw_answer: Answer text as found in the dataset row.
            solution: Solution text.
            numerical_answer: Final answer as a number.
            task_name: Task name the entry belongs to.

        Returns:
            A new instance of ``cls`` populated with the given values.
        """
        field_values: dict[str, Any] = {
            "prompt": prompt,
            "answer": answer,
            "raw_question": raw_question,
            "raw_answer": raw_answer,
            "solution": solution,
            "numerical_answer": numerical_answer,
            "task_name": task_name,
        }
        return cls(**field_values)
class GSM8KDatasetParser(HuggingFaceDatasetParser[GSM8KParseEntry]):
    """Parser for the GSM8K dataset."""

    _data_source: ClassVar[str] = "openai/gsm8k"
    _task_names: ClassVar[list[str]] = ["main", "socratic"]
    _default_task: ClassVar[str] = "main"
    _default_system_prompt: ClassVar[str] = GSM8K_SYSTEM_PROMPT

    def process_entry(
        self, row: dict[str, Any], task_name: str | None = None, **kwargs: Any
    ) -> GSM8KParseEntry:
        """Process a single GSM8K entry.

        Args:
            row: Dataset row; must provide the ``"question"`` and ``"answer"``
                keys.
            task_name: Task the row belongs to. When falsy, the task is
                obtained from ``self._get_current_task(row)``.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            A ``GSM8KParseEntry`` whose ``answer`` is the string form of the
            parsed number, whose ``solution`` is the text before the first
            ``####`` marker, and whose ``prompt`` is the system prompt
            followed by a newline and the question.

        Raises:
            KeyError: If ``row`` lacks ``"question"`` or ``"answer"``.
            ValueError: If the text after the last ``####`` marker (with
                thousands separators removed) is not a valid number.
        """
        task = task_name or self._get_current_task(row)
        raw_question = row["question"]
        raw_answer = row["answer"]

        # GSM8K puts the final answer after '####'. Split once and reuse the
        # pieces: the first segment is the solution, the last is the number.
        segments = raw_answer.split("####")
        solution = segments[0].strip()
        # Drop thousands separators ("1,234" -> "1234") before converting.
        numerical_str = segments[-1].strip().replace(",", "")

        numerical_answer: int | float
        try:
            # Parse integers directly: routing them through float() first
            # silently loses precision for values above 2**53.
            numerical_answer = int(numerical_str)
        except ValueError:
            try:
                as_float = float(numerical_str)
            except ValueError as err:
                raise ValueError(
                    f"Could not convert '{numerical_str}' to number"
                ) from err
            # Collapse whole-valued floats such as "18.0" or "1e3" to int.
            numerical_answer = int(as_float) if as_float.is_integer() else as_float

        prompt = f"{self._system_prompt}\n{raw_question}"

        return GSM8KParseEntry.create(
            prompt=prompt,
            answer=str(numerical_answer),
            raw_question=raw_question,
            raw_answer=raw_answer,
            solution=solution,
            numerical_answer=numerical_answer,  # Always int or float here
            task_name=task,  # Explicit argument or the parser's current task
        )
if __name__ == "__main__":
    from pprint import pprint

    # Manual smoke run: call load() and parse() on a fresh parser, then
    # pretty-print each field of the first parsed entry.
    parser = GSM8KDatasetParser()
    parser.load()
    parser.parse()

    parsed_data = parser.get_parsed_data
    first_entry = parsed_data[0]
    for field_name in (
        "prompt",
        "answer",
        "raw_question",
        "raw_answer",
        "solution",
        "numerical_answer",
    ):
        pprint(getattr(first_entry, field_name))