File size: 3,415 Bytes
18bf871
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
from dataclasses import dataclass
from typing import Any, ClassVar

from llmdataparser.base_parser import HuggingFaceDatasetParser, HuggingFaceParseEntry
from llmdataparser.prompts import MBPP_SYSTEM_PROMPT


@dataclass(frozen=True, kw_only=True, slots=True)
class MBPPParseEntry(HuggingFaceParseEntry):
    """Immutable parsed record for one MBPP problem.

    Extends the generic Hugging Face entry with the MBPP-specific metadata
    fields declared below.
    """

    # Identifier of the problem; ``create`` requires it to be an ``int``.
    task_id: int
    # Test cases for the reference solution (presumably assert statements).
    test_list: list[str]
    # Setup code associated with the tests; may be an empty string.
    test_setup_code: str
    # Extra test cases kept separately from ``test_list``; may be empty.
    challenge_test_list: list[str]
    # Name of the originating source file; may be an empty string.
    source_file: str

    @classmethod
    def create(
        cls,
        prompt: str,
        answer: str,
        raw_question: str,
        task_id: int,
        test_list: list[str],
        test_setup_code: str,
        challenge_test_list: list[str],
        task_name: str,
        source_file: str,
    ) -> "MBPPParseEntry":
        """Validate the inputs and construct an ``MBPPParseEntry``.

        ``answer`` is stored as both ``answer`` and ``raw_answer``, because
        the MBPP code solution is used verbatim as the raw answer.

        Raises:
            ValueError: If ``task_id`` is not an ``int``.
        """
        has_integer_id = isinstance(task_id, int)
        if has_integer_id:
            return cls(
                task_name=task_name,
                source_file=source_file,
                prompt=prompt,
                raw_question=raw_question,
                answer=answer,
                raw_answer=answer,  # the code solution doubles as the raw answer
                task_id=task_id,
                test_list=test_list,
                test_setup_code=test_setup_code,
                challenge_test_list=challenge_test_list,
            )
        raise ValueError("Task ID must be an integer")


class MBPPDatasetParser(HuggingFaceDatasetParser[MBPPParseEntry]):
    """Parser for the MBPP (Mostly Basic Python Programming) dataset.

    Reads ``google-research-datasets/mbpp``. The task names are ``"full"``
    (the default) and ``"sanitized"``.
    """

    _data_source: ClassVar[str] = "google-research-datasets/mbpp"
    _default_task: ClassVar[str] = "full"  # Can be 'full' or 'sanitized'
    _task_names: ClassVar[list[str]] = ["full", "sanitized"]
    _default_system_prompt: ClassVar[str] = MBPP_SYSTEM_PROMPT

    def process_entry(
        self, row: dict[str, Any], task_name: str | None = None, **kwargs: Any
    ) -> MBPPParseEntry:
        """Process a single MBPP entry into an ``MBPPParseEntry``.

        Args:
            row: One raw dataset record.
                - The task description is read from ``"text"``, or from
                  ``"prompt"`` when ``"text"`` is absent.
                - ``"code"``, ``"task_id"`` and ``"test_list"`` are required.
                - ``"test_setup_code"``, ``"challenge_test_list"`` and
                  ``"source_file"`` are optional. They default to ``""``,
                  ``[]`` and ``""`` respectively.
            task_name: Task name to record on the entry. When falsy, it falls
                back to ``self._get_current_task(row)``.
            **kwargs: Unused by this implementation.

        Returns:
            The populated ``MBPPParseEntry``. Its prompt is the system prompt
            followed by a blank line and ``"Task: <description>"``.

        Raises:
            KeyError: If a required field is missing, including a row that
                has neither ``"text"`` nor ``"prompt"``.
            ValueError: From ``MBPPParseEntry.create`` when ``task_id`` is
                not an ``int``.
        """
        # The description key differs between rows (presumably "text" in the
        # 'full' config and "prompt" in 'sanitized' -- confirm against the
        # dataset card). Fail loudly rather than formatting "Task: None".
        if "text" in row:
            raw_question = row["text"]
        elif "prompt" in row:
            raw_question = row["prompt"]
        else:
            raise KeyError("MBPP row has neither a 'text' nor a 'prompt' field")

        answer = row["code"]
        task_id = row["task_id"]
        test_list = row["test_list"]
        test_setup_code = row.get("test_setup_code", "")
        challenge_test_list = row.get("challenge_test_list", [])

        # Combine system prompt with the task description
        prompt = f"{self._system_prompt}\n\nTask: {raw_question}"

        # Use task_name if provided, otherwise use default
        task = task_name or self._get_current_task(row)
        source_file = row.get("source_file", "")

        return MBPPParseEntry.create(
            prompt=prompt,
            answer=answer,
            raw_question=raw_question,
            task_id=task_id,
            test_list=test_list,
            test_setup_code=test_setup_code,
            challenge_test_list=challenge_test_list,
            task_name=task,
            source_file=source_file,
        )


if __name__ == "__main__":
    # Demo: load MBPP, parse every split, then show the first parsed entry.
    demo_parser = MBPPDatasetParser()
    demo_parser.load()
    demo_parser.parse()  # no split argument -> all splits

    # Attribute access, not a call (presumably a property on the base parser).
    entries = demo_parser.get_parsed_data

    if entries:
        first = entries[0]
        for line in (
            "\nExample parsed entry:",
            f"Task ID: {first.task_id}",
            f"Task: {first.raw_question}",
            f"Solution:\n{first.answer}",
            f"Test Cases:\n{first.test_list}",
        ):
            print(line)