File size: 2,995 Bytes
da35c69
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
from dataclasses import dataclass
from typing import Any

from llmdataparser.base_parser import HuggingFaceDatasetParser, ParseEntry
from llmdataparser.prompts import MMLU_SYSTEM_PROMPT


@dataclass(frozen=True)
class MMLUParseEntry(ParseEntry):
    """
    Immutable record holding one parsed MMLU item.

    Instances are normally built through `create`, which validates the
    answer letter before constructing the entry.
    """

    # Prompt text associated with this entry.
    prompt: str
    # Expected answer; `create` only accepts "A", "B", "C" or "D".
    answer_letter: str

    @classmethod
    def create(cls, prompt: str, answer_letter: str) -> "MMLUParseEntry":
        """
        Build an entry after checking that the answer letter is valid.

        Args:
            prompt: Prompt text to store on the entry.
            answer_letter: Expected answer; must be "A", "B", "C" or "D".

        Returns:
            MMLUParseEntry: A new entry carrying the given values.

        Raises:
            ValueError: If `answer_letter` is not one of the four letters.
        """
        allowed_letters = {"A", "B", "C", "D"}
        if answer_letter in allowed_letters:
            return cls(prompt=prompt, answer_letter=answer_letter)

        raise ValueError(
            f"Invalid answer_letter '{answer_letter}'; must be one of 'A', 'B', 'C', 'D'."
        )


class MMLUDatasetParser(HuggingFaceDatasetParser[MMLUParseEntry]):
    """
    Parser that converts rows of the MMLU dataset into `MMLUParseEntry` objects.

    `parse` reads from `self.raw_data` and validates requested names against
    `self.task_names`; neither is populated in this class, so they are
    presumably filled in by the base class's loading logic (confirm against
    `HuggingFaceDatasetParser`).
    """

    # Dataset identifier; not read in this class (presumably consumed by the
    # base class when loading -- confirm against `HuggingFaceDatasetParser`).
    _data_source = "cais/mmlu"

    def __init__(self, system_prompt: str = MMLU_SYSTEM_PROMPT):
        """
        Set up an empty parser.

        Args:
            system_prompt: Text placed at the start of every generated prompt.
        """
        # Initialise the base class exactly once, before adding our own state.
        super().__init__()
        self.parsed_data: list[MMLUParseEntry] = []
        self.task_names: list[str] = []
        self.subject_list: set[str] = set()
        self.system_prompt: str = system_prompt

    def parse(self, split_names: str | list[str] | None = None, **kwargs: Any) -> None:
        """
        Parse the requested tasks of the loaded dataset into entries.

        Parsed entries are stored in `self._parsed_data`, and the subject of
        every row is recorded in `self.subject_list`. A short summary is
        printed for each task and once at the end.

        Args:
            split_names: A single task name, a list of task names, or None to
                parse every name in `self.task_names`.
            **kwargs: Forwarded unchanged to `process_entry`.

        Raises:
            ValueError: If no data has been loaded, or if a requested name is
                not in `self.task_names`.
        """
        self.parsed_data.clear()
        if self.raw_data is None:
            raise ValueError("No data loaded. Please load the dataset first.")

        # Entries are appended to `_parsed_data`, so that is the list that has
        # to be reset; otherwise repeated parse() calls accumulate duplicates.
        self._parsed_data.clear()

        if split_names is None:
            split_names = self.task_names
        elif isinstance(split_names, str):
            split_names = [split_names]

        for split_name in split_names:
            if split_name not in self.task_names:
                raise ValueError(f"Task '{split_name}' not found in the dataset.")

            dataset_split = self.raw_data[split_name]
            # Start at zero so an empty split reports 0 instead of reading an
            # unbound (or stale, from the previous split) loop variable.
            parsed_count = 0
            for parsed_count, entry in enumerate(dataset_split, start=1):
                data_entry = self.process_entry(entry, **kwargs)
                self._parsed_data.append(data_entry)
                self.subject_list.add(entry.get("subject", "Unknown"))
            print(f"Parsed {parsed_count} data points from task '{split_name}'.")

        print(
            f"Number of subjects: {len(self.subject_list)}. "
            "For more details, please check the `self.subject_list` attribute."
        )

    def process_entry(self, row: dict[str, Any], **kwargs: Any) -> MMLUParseEntry:
        """
        Generate a prompt and expected answer from the given row.

        Args:
            row (dict[str, Any]): A data point to be formatted. The keys
                "question", "choices" (an iterable of choice texts) and
                "answer" (an integer index into the choices) are read.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            MMLUParseEntry: The formatted entry object.

        Raises:
            ValueError: Propagated from `MMLUParseEntry.create` when the
                answer index does not map to 'A', 'B', 'C' or 'D'.
        """
        first_letter = ord("A")
        # Label each choice with consecutive letters: "A. ...", "B. ...", ...
        choices = "\n".join(
            f"{chr(first_letter + i)}. {choice}"
            for i, choice in enumerate(row["choices"])
        )
        prompt = (
            f"{self.system_prompt}\nQuestion: {row['question']}\n{choices}\nAnswer:"
        )
        # Convert the answer index to its letter: 0 -> 'A', 1 -> 'B', ...
        answer_letter = chr(first_letter + row["answer"])

        return MMLUParseEntry.create(prompt, answer_letter)