File size: 3,131 Bytes
c6e9c8c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e6be0c8
c6e9c8c
e6be0c8
c6e9c8c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e6be0c8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
from typing import List

from .card import TaskCard
from .dataclass import InternalField
from .formats import ICLFormat
from .instructions import Instruction
from .operator import SourceSequntialOperator, StreamingOperator
from .recipe import Recipe
from .renderers import StandardRenderer
from .schema import ToUnitxtGroup
from .splitters import Sampler, SeparateSplit, SpreadSplit
from .templates import Template


class StandardRecipe(Recipe, SourceSequntialOperator):
    """Recipe that assembles the standard processing pipeline from a task card.

    ``prepare`` fills ``self.steps`` in this order:

    1. the card's loader;
    2. the card's preprocessing steps, if the card defines any;
    3. the card's task;
    4. when ``demos_pool_size`` is set, a ``SeparateSplit`` over the
       ``demos_taken_from`` split producing ``demos_pool_name`` and
       ``demos_taken_from`` splits;
    5. when ``num_demos`` is set, a ``SpreadSplit`` that reads from the
       ``demos_pool_name`` stream and writes into ``demos_field`` using the
       recipe's sampler (or, if unset, the card's sampler);
    6. a ``StandardRenderer`` built from ``instruction``, ``template``,
       ``format`` and ``demos_field``;
    7. a ``ToUnitxtGroup`` step with group ``"unitxt"``, the task's metrics and
       the renderer's postprocessors.
    """

    # Source of the loader, preprocessing steps, task, metrics and default sampler.
    card: TaskCard
    template: Template = None
    instruction: Instruction = None
    # NOTE(review): this default instance is created once, at class definition
    # time; confirm the dataclass machinery copies it per instance if ICLFormat
    # carries mutable state.
    format: ICLFormat = ICLFormat()

    # When not None, size passed to SeparateSplit to carve out the demos pool.
    demos_pool_size: int = None
    # When not None, number of demonstrations the sampler is sized to.
    num_demos: int = None

    demos_pool_name: str = "demos_pool"
    demos_taken_from: str = "train"
    demos_field: str = "demos"
    # When set, used instead of the card's sampler.
    sampler: Sampler = None

    steps: List[StreamingOperator] = InternalField(default_factory=list)

    def prepare(self):
        """Build ``self.steps``, the ordered list of pipeline operators.

        Raises:
            ValueError: if ``num_demos`` is set but neither this recipe nor the
                card provides a sampler.
        """
        self.steps = [self.card.loader]

        if self.card.preprocess_steps is not None:
            self.steps.extend(self.card.preprocess_steps)

        self.steps.append(self.card.task)

        if self.demos_pool_size is not None:
            self.steps.append(
                SeparateSplit(
                    from_split=self.demos_taken_from,
                    # Only one size is given for two target names; presumably the
                    # remainder of the source split keeps its original name.
                    # TODO(review): confirm against SeparateSplit's semantics.
                    to_split_names=[self.demos_pool_name, self.demos_taken_from],
                    # Coerced because the value may arrive as a string.
                    to_split_sizes=[int(self.demos_pool_size)],
                )
            )

        if self.num_demos is not None:
            # A recipe-level sampler takes precedence over the card's.
            sampler = self.sampler if self.sampler is not None else self.card.sampler
            if sampler is None:
                # Previously this surfaced as an opaque AttributeError on None.
                raise ValueError(
                    "num_demos is set but no sampler is available: "
                    "set `sampler` on the recipe or on the card"
                )

            # Coerced like demos_pool_size, in case the value arrives as a string.
            # NOTE(review): set_size mutates the sampler in place; when it is the
            # card's sampler, that object is shared with the card.
            sampler.set_size(int(self.num_demos))

            self.steps.append(
                SpreadSplit(
                    source_stream=self.demos_pool_name,
                    target_field=self.demos_field,
                    sampler=sampler,
                )
            )

        render = StandardRenderer(
            instruction=self.instruction,
            template=self.template,
            format=self.format,
            demos_field=self.demos_field,
        )
        self.steps.append(render)

        self.steps.append(
            ToUnitxtGroup(
                group="unitxt",
                metrics=self.card.task.metrics,
                postprocessors=render.get_postprocessors(),
            )
        )


class StandardRecipeWithIndexes(StandardRecipe):
    """``StandardRecipe`` whose template and instruction may be chosen by card index.

    Set ``template_card_index`` to use ``card.templates[index]`` as the
    template, and ``instruction_card_index`` to use
    ``card.instructions[index]`` as the instruction. Each index is mutually
    exclusive with passing the corresponding object directly.
    """

    # Index into card.instructions; mutually exclusive with `instruction`.
    instruction_card_index: int = None
    # Index into card.templates; mutually exclusive with `template`.
    template_card_index: int = None

    def prepare(self):
        """Resolve the card indexes into ``template``/``instruction``, then build steps.

        Raises:
            AssertionError: if both ``template`` and ``template_card_index`` are
                set, or both ``instruction`` and ``instruction_card_index`` are
                set. The error is raised explicitly rather than via ``assert``,
                so the check is not stripped under ``python -O``. The exception
                type and messages are unchanged for existing callers.
        """
        if self.template_card_index is not None and self.template is not None:
            raise AssertionError("Specify either template or template_card_index")
        if self.template_card_index is not None:
            # int() because the index may arrive as a string.
            self.template = self.card.templates[int(self.template_card_index)]

        if self.instruction_card_index is not None and self.instruction is not None:
            raise AssertionError(
                "Specify either instruction or instruction_card_index"
            )
        if self.instruction_card_index is not None:
            self.instruction = self.card.instructions[int(self.instruction_card_index)]

        super().prepare()