from abc import ABC, abstractmethod
from dataclasses import dataclass
from functools import lru_cache
from typing import Any, ClassVar, Generic, TypeVar

import datasets

T = TypeVar("T", bound="ParseEntry")

VALID_CATEGORIES = {
    "Math",
    "General Knowledge and Reasoning",
    "Programming",
    "MultiLingual",
    "Taiwan",
    "Advanced Reasoning",
    "Legal",
}


@dataclass(frozen=True, kw_only=True, slots=True)
class ParseEntry:
    """A simple base class for entries, customizable by each dataset parser."""

    question: str
    answer: str
    raw_question: str
    raw_answer: str
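
# A minimal construction sketch; the field values below are illustrative,
# not drawn from any real dataset:
#
#     entry = ParseEntry(
#         question="Q: What is 12 * 12?\nA:",
#         answer="144",
#         raw_question="What is 12 * 12?",
#         raw_answer="144",
#     )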


@dataclass(frozen=True, kw_only=True, slots=True)
class DatasetDescription:
    """Standardized description of a dataset."""

    name: str
    purpose: str
    source: str
    language: str
    format: str
    category: list[str]
    characteristics: str
    citation: str | None = None
    additional_info: dict[str, Any] | None = None

    @classmethod
    def create(
        cls,
        name: str,
        purpose: str,
        source: str,
        language: str,
        format: str,
        category: list[str],
        characteristics: str,
        citation: str | None = None,
        additional_info: dict[str, Any] | None = None,
    ) -> "DatasetDescription":
        # Validate with an explicit exception rather than `assert`, which is
        # stripped when Python runs with -O.
        for item in category:
            if item not in VALID_CATEGORIES:
                raise ValueError(
                    f"Category '{item}' is not a valid category. "
                    f"Valid categories are: {VALID_CATEGORIES}"
                )
        return cls(
            name=name,
            purpose=purpose,
            source=source,
            language=language,
            format=format,
            category=category,
            characteristics=characteristics,
            citation=citation,
            additional_info=additional_info,
        )
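
# A usage sketch with illustrative values (the dataset named here is
# hypothetical, not a real source); a category outside VALID_CATEGORIES
# raises ValueError:
#
#     description = DatasetDescription.create(
#         name="example-qa",
#         purpose="Benchmark short-answer QA.",
#         source="https://example.com/dataset",
#         language="English",
#         format="question-answer pairs",
#         category=["Math"],
#         characteristics="Single-turn questions with short answers.",
#     )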


@dataclass(frozen=True, kw_only=True, slots=True)
class EvaluationMetric:
    """Description of an evaluation metric for a dataset."""

    name: str
    type: str
    description: str
    implementation: str
    primary: bool

    @classmethod
    def create(
        cls, name: str, type: str, description: str, implementation: str, primary: bool
    ) -> "EvaluationMetric":
        return cls(
            name=name,
            type=type,
            description=description,
            implementation=implementation,
            primary=primary,
        )
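
# A usage sketch; the metric shown is illustrative only:
#
#     metric = EvaluationMetric.create(
#         name="exact_match",
#         type="string",
#         description="Predicted answer must equal the gold answer exactly.",
#         implementation="str equality after whitespace normalization",
#         primary=True,
#     )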


class DatasetParser(Generic[T], ABC):
    """
    Abstract base class defining the interface for all dataset parsers.
    """

    def __init__(self) -> None:
        self._parsed_data: list[T] = []

    @abstractmethod
    def load(self, **kwargs: Any) -> None:
        """Load the raw dataset into memory."""

    @abstractmethod
    def parse(self, split_names: str | list[str] | None = None, **kwargs: Any) -> None:
        """
        Parse the loaded dataset into self._parsed_data.
        """

    @property
    def get_parsed_data(self) -> list[T]:
        if not getattr(self, "_parsed_data", None):
            raise ValueError(
                "No parsed data available. Call load() and then parse() first."
            )
        return self._parsed_data

    @abstractmethod
    def process_entry(
        self, row: dict[str, Any], task_name: str | None = None, **kwargs: Any
    ) -> T:
        """
        Process a single entry from the dataset.

        Args:
            row: A dictionary representing a single entry from the dataset.
            task_name: Optional task name for the entry.
            **kwargs: Additional keyword arguments.

        Returns:
            T: The processed entry, typically an instance of a subclass of ParseEntry.
        """

    @abstractmethod
    def get_dataset_description(self) -> DatasetDescription:
        """Returns a standardized description of the dataset."""

    @abstractmethod
    def get_evaluation_metrics(self) -> list[EvaluationMetric]:
        """Returns the recommended evaluation metrics for the dataset."""


@dataclass(frozen=True, kw_only=True, slots=True)
class HuggingFaceParseEntry(ParseEntry):
    """ParseEntry with an additional task_name field."""

    task_name: str


class HuggingFaceDatasetParser(DatasetParser[T]):
    """
    Base class for parsers that use datasets from Hugging Face.
    """

    # Class-level attributes that concrete subclasses must define.
    _data_source: ClassVar[str]
    _task_names: ClassVar[list[str]]
    _default_task: ClassVar[str]
    _hidden_task_names: ClassVar[list[str]] = []

    def __init__(self, **kwargs: Any) -> None:
        """
        Initialize a HuggingFaceDatasetParser.

        Args:
            **kwargs: Additional keyword arguments passed to the parent class.
        """
        super().__init__()
        self.raw_data: dict[str, Any] | None = None
        self.split_names: list[str] = []
        self._current_task: str = ""

    def _get_current_task(self, data_entry: dict[str, Any] | None = None) -> str:
        """
        Get the currently loaded task name.

        Args:
            data_entry: Optional dictionary containing entry data that might
                include task information.

        Returns:
            str: The task name from either the data entry (if available) or the
                currently set task.
        """
        # Subclasses may define _get_task_from_entry to derive the task from an
        # individual row; otherwise fall back to the current or default task.
        if data_entry is not None and hasattr(self, "_get_task_from_entry"):
            try:
                task = self._get_task_from_entry(data_entry)
                if isinstance(task, str):
                    return task
            except (KeyError, AttributeError):
                pass
        return self._current_task or self._default_task

    @property
    def task_names(self) -> list[str]:
        """Get all available task names."""
        return self._task_names

    @property
    def total_tasks(self) -> int:
        """Get the total number of available tasks."""
        return len(self._task_names)

    @property
    def get_huggingface_link(self) -> str:
        """Get the URL of the dataset's page on the Hugging Face Hub."""
        return "https://huggingface.co/datasets/" + self._data_source

    @staticmethod
    @lru_cache(maxsize=3)
    def load_dataset_cached(
        data_source: str,
        task_name: str = "default",
        trust_remote_code: bool = True,
        **kwargs: Any,
    ) -> datasets.Dataset | datasets.DatasetDict:
        """
        Cached static method to load a dataset from Hugging Face.

        Note: all arguments must be hashable for lru_cache; a DatasetDict is
        returned when no explicit split is requested.
        """
        return datasets.load_dataset(
            data_source, task_name, trust_remote_code=trust_remote_code, **kwargs
        )
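
    # Because of the cache above, repeated loads of the same configuration are
    # free after the first call; e.g. (the dataset name is illustrative):
    #
    #     ds = HuggingFaceDatasetParser.load_dataset_cached("user/dataset", "default")
    #     same = HuggingFaceDatasetParser.load_dataset_cached("user/dataset", "default")
    #     assert ds is same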
    def parse(
        self,
        split_names: str | list[str] | None = None,
        force: bool = False,
        **kwargs: Any,
    ) -> None:
        """
        Parse the loaded dataset splits into structured entries.

        Args:
            split_names: Dataset splits to parse. Can be:
                - None: Parse all available splits
                - str: Parse a single split (e.g., "train")
                - list[str]: Parse multiple splits (e.g., ["train", "test"])
            force: If True, overwrites existing parsed data without confirmation.
                If False and parsed data exists, prompts for confirmation.
            **kwargs: Additional keyword arguments passed to process_entry.

        Raises:
            ValueError: If no data is loaded or if a specified split name
                doesn't exist.
        """
        if self.raw_data is None:
            raise ValueError("No data loaded. Please load the dataset first.")

        if self._parsed_data and not force:
            response = input(
                f"Found {len(self._parsed_data)} existing parsed entries. "
                "Do you want to overwrite them? [y/N]: "
            ).lower()
            if response not in ("y", "yes"):
                print("Parsing cancelled. Existing data preserved.")
                return

        self._parsed_data.clear()

        if split_names is None:
            split_names = self.split_names
        elif isinstance(split_names, str):
            split_names = [split_names]

        for split_name in split_names:
            if split_name not in self.split_names:
                raise ValueError(f"Split '{split_name}' not found in the dataset.")

            dataset_split = self.raw_data[split_name]
            total_entries = len(dataset_split)
            print(f"Processing {split_name} split with {total_entries} entries...")

            # Count successes separately so the summary stays accurate when
            # entries fail, and so an empty split does not leave the loop
            # variable undefined.
            processed = 0
            for index, entry in enumerate(dataset_split, start=1):
                try:
                    task_name = self._get_current_task(data_entry=entry)
                    parsed_entry = self.process_entry(entry, task_name, **kwargs)
                    self._parsed_data.append(parsed_entry)
                    processed += 1

                    if index % 100 == 0:
                        print(
                            f"Processed {index}/{total_entries} entries from '{split_name}'"
                        )
                except Exception as e:
                    print(f"Error processing entry {index} in {split_name}: {e}")
                    continue

            print(
                f"Completed parsing {processed}/{total_entries} entries from '{split_name}'"
            )

        print(f"Total parsed entries: {len(self._parsed_data)}")
    def load(
        self,
        task_name: str | None = None,
        trust_remote_code: bool = True,
        split: str | None = None,
        **kwargs: Any,
    ) -> None:
        """
        Load the dataset using the Hugging Face datasets library.
        """
        self._current_task = task_name or self._default_task

        raw_data = self.load_dataset_cached(
            self._data_source,
            task_name=self._current_task,
            trust_remote_code=trust_remote_code,
            split=split,
            **kwargs,
        )

        # With an explicit split we get a single Dataset; otherwise a
        # DatasetDict keyed by split name.
        if split:
            self.raw_data = {split: raw_data}
            self.split_names = [split]
        else:
            self.raw_data = raw_data
            self.split_names = list(raw_data.keys())

        print(
            f"Loaded dataset with {len(self.split_names)} splits: {', '.join(self.split_names)}."
        )
    def __repr__(self) -> str:
        status = "loaded" if self.raw_data is not None else "not loaded"
        parsed_count = len(self._parsed_data) if self._parsed_data else 0
        return (
            f"{self.__class__.__name__}("
            f"data_source='{self._data_source}', "
            f"task='{self._current_task}', "
            f"status='{status}', "
            f"parsed_entries={parsed_count}"
            ")"
        )

    def __str__(self) -> str:
        return self.__repr__()
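

# End-to-end usage sketch for a concrete subclass (the parser class named here
# is hypothetical; substitute a real implementation):
#
#     parser = SomeHuggingFaceParser()  # subclass defining _data_source etc.
#     parser.load(task_name="default", split="test")
#     parser.parse(split_names="test", force=True)
#     entries = parser.get_parsed_data  # list of HuggingFaceParseEntry
#     print(parser.get_dataset_description())
#     print(parser.get_evaluation_metrics())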