Upload standard.py with huggingface_hub
standard.py CHANGED: +94 -24
@@ -4,11 +4,12 @@ from .card import TaskCard
 from .dataclass import Field, InternalField, NonPositionalField, OptionalField
 from .formats import Format, SystemFormat
 from .logging_utils import get_logger
-from .operator import SourceSequentialOperator, StreamingOperator
+from .operator import SequentialOperator, SourceSequentialOperator, StreamingOperator
 from .operators import AddFields, Augmentor, NullAugmentor, StreamRefiner
 from .recipe import Recipe
 from .schema import ToUnitxtGroup
 from .splitters import Sampler, SeparateSplit, SpreadSplit
+from .stream import MultiStream
 from .system_prompts import EmptySystemPrompt, SystemPrompt
 from .templates import Template
 
@@ -99,15 +100,15 @@ class BaseRecipe(Recipe, SourceSequentialOperator):
     def prepare_refiners(self):
         self.train_refiner.max_instances = self.max_train_instances
         self.train_refiner.apply_to_streams = ["train"]
-        self.steps.append(self.train_refiner)
+        self.processing.steps.append(self.train_refiner)
 
         self.validation_refiner.max_instances = self.max_validation_instances
         self.validation_refiner.apply_to_streams = ["validation"]
-        self.steps.append(self.validation_refiner)
+        self.processing.steps.append(self.validation_refiner)
 
         self.test_refiner.max_instances = self.max_test_instances
         self.test_refiner.apply_to_streams = ["test"]
-        self.steps.append(self.test_refiner)
+        self.processing.steps.append(self.test_refiner)
 
     def prepare_metrics_and_postprocessors(self):
         if self.postprocessors is None:
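Note: the refiners moved above are per-split caps. Below is a minimal sketch of what they do, reusing only construction and call patterns visible elsewhere in this diff (the values and split contents are illustrative, not from the commit):

from unitxt.operator import SequentialOperator
from unitxt.operators import StreamRefiner
from unitxt.stream import MultiStream

# Cap the train split at two instances; other splits pass through untouched.
refiner = StreamRefiner(max_instances=2)
refiner.apply_to_streams = ["train"]

pipeline = SequentialOperator()
pipeline.steps.append(refiner)

ms = MultiStream.from_iterables(
    {
        "train": [{"x": i} for i in range(10)],
        "test": [{"x": i} for i in range(3)],
    }
)
out = pipeline(ms)
print(len(list(out["train"])), len(list(out["test"])))  # expected: 2 3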
@@ -121,9 +122,84 @@ class BaseRecipe(Recipe, SourceSequentialOperator):
             metrics = self.metrics
         return metrics, postprocessors
 
-    def prepare(self):
+    def set_pipelines(self):
+        self.loading = SequentialOperator()
+        self.metadata = SequentialOperator()
+        self.standardization = SequentialOperator()
+        self.processing = SequentialOperator()
+        self.verblization = SequentialOperator()
+        self.finalize = SequentialOperator()
+
         self.steps = [
-            self.card.loader,
+            self.loading,
+            self.metadata,
+            self.standardization,
+            self.processing,
+            self.verblization,
+            self.finalize,
+        ]
+
+        self.inference_instance = SequentialOperator()
+
+        self.inference_instance.steps = [
+            self.metadata,
+            self.processing,
+        ]
+
+        self.inference_demos = SourceSequentialOperator()
+
+        self.inference_demos.steps = [
+            self.loading,
+            self.metadata,
+            self.standardization,
+            self.processing,
+        ]
+
+        self.inference = SequentialOperator()
+
+        self.inference.steps = [self.verblization, self.finalize]
+
+        self._demos_pool_cache = None
+
+    def production_preprocess(self, task_instances):
+        ms = MultiStream.from_iterables({"__inference__": task_instances})
+        return list(self.inference_instance(ms)["__inference__"])
+
+    def production_demos_pool(self):
+        if self.num_demos > 0:
+            if self._demos_pool_cache is None:
+                self._demos_pool_cache = list(
+                    self.inference_demos()[self.demos_pool_name]
+                )
+            return self._demos_pool_cache
+        return []
+
+    def produce(self, task_instances):
+        """Use the recipe in production to produce model ready query from standard task instance."""
+        self.before_process_multi_stream()
+        multi_stream = MultiStream.from_iterables(
+            {
+                "__inference__": self.production_preprocess(task_instances),
+                self.demos_pool_name: self.production_demos_pool(),
+            }
+        )
+        multi_stream = self.inference(multi_stream)
+        return list(multi_stream["__inference__"])
+
+    def prepare(self):
+        self.set_pipelines()
+
+        loader = self.card.loader
+        if self.loader_limit:
+            loader.loader_limit = self.loader_limit
+            logger.info(f"Loader line limit was set to {self.loader_limit}")
+        self.loading.steps.append(loader)
+
+        # This is required in case loader_limit is not enforced by the loader
+        if self.loader_limit:
+            self.loading.steps.append(StreamRefiner(max_instances=self.loader_limit))
+
+        self.metadata.steps.append(
             AddFields(
                 fields={
                     "recipe_metadata": {
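Note: the design point of set_pipelines in the hunk above is that the six sub-pipelines are shared objects. prepare() appends each step exactly once, and every composite that lists a sub-pipeline (self.steps, inference_instance, inference_demos, inference) sees the same steps. A minimal sketch of that sharing, reusing only calls visible in this diff (field values illustrative):

from unitxt.operator import SequentialOperator
from unitxt.operators import AddFields
from unitxt.stream import MultiStream

processing = SequentialOperator()

full = SequentialOperator()
full.steps = [processing]            # loading, metadata, ... omitted here

inference_only = SequentialOperator()
inference_only.steps = [processing]  # same object, so same steps

# A step appended once shows up in both composites.
processing.steps.append(AddFields(fields={"recipe_metadata": {"card": "demo"}}))

ms = MultiStream.from_iterables({"test": [{"x": 0}]})
print(list(inference_only(ms)["test"]))  # [{'x': 0, 'recipe_metadata': {...}}]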
@@ -133,25 +209,19 @@ class BaseRecipe(Recipe, SourceSequentialOperator):
                         "format": self.format,
                     }
                 }
-            ),
-        ]
-
-        if self.loader_limit:
-            self.card.loader.loader_limit = self.loader_limit
-            logger.info(f"Loader line limit was set to {self.loader_limit}")
-            self.steps.append(StreamRefiner(max_instances=self.loader_limit))
+            )
+        )
 
-
-        self.steps.extend(self.card.preprocess_steps)
+        self.standardization.steps.extend(self.card.preprocess_steps)
 
-        self.steps.append(self.card.task)
+        self.processing.steps.append(self.card.task)
 
         if self.augmentor.augment_task_input:
             self.augmentor.set_task_input_fields(self.card.task.augmentable_inputs)
-            self.steps.append(self.augmentor)
+            self.processing.steps.append(self.augmentor)
 
         if self.demos_pool_size is not None:
-            self.steps.append(
+            self.processing.steps.append(
                 CreateDemosPool(
                     from_split=self.demos_taken_from,
                     to_split_names=[self.demos_pool_name, self.demos_taken_from],
@@ -173,23 +243,23 @@ class BaseRecipe(Recipe, SourceSequentialOperator):
 
         self.prepare_refiners()
 
-        self.steps.append(self.template)
+        self.verblization.steps.append(self.template)
         if self.num_demos > 0:
-            self.steps.append(
+            self.verblization.steps.append(
                 AddDemosField(
                     source_stream=self.demos_pool_name,
                     target_field=self.demos_field,
                     sampler=self.sampler,
                 )
             )
-        self.steps.append(self.system_prompt)
-        self.steps.append(self.format)
+        self.verblization.steps.append(self.system_prompt)
+        self.verblization.steps.append(self.format)
         if self.augmentor.augment_model_input:
-            self.steps.append(self.augmentor)
+            self.verblization.steps.append(self.augmentor)
 
         metrics, postprocessors = self.prepare_metrics_and_postprocessors()
 
-        self.steps.append(
+        self.finalize.steps.append(
             ToUnitxtGroup(
                 group="unitxt",
                 metrics=metrics,
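Note: taken together, set_pipelines and produce give the recipe a production entry point: preprocess raw task instances through metadata + processing, then verbalize and finalize them against a cached demos pool. A hedged usage sketch; StandardRecipeWithIndexes is defined elsewhere in standard.py, and the card name and instance fields below are illustrative assumptions, not part of this commit:

from unitxt.standard import StandardRecipeWithIndexes

# Illustrative catalog entries; any registered card/template should work.
recipe = StandardRecipeWithIndexes(
    card="cards.wnli",
    template_card_index=0,
    num_demos=2,
    demos_pool_size=10,
)

# produce() runs production_preprocess on the raw instances, then the
# verbalization + finalize pipelines, reusing the cached demos pool across calls.
instances = [{"text_a": "...", "text_b": "...", "label": "entailment"}]
model_ready = recipe.produce(instances)
print(model_ready[0]["source"])  # the model-ready prompt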