Elron committed on
Commit
2c69fb8
·
verified ·
1 Parent(s): 1ace635

Upload standard.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. standard.py +94 -24
standard.py CHANGED
@@ -4,11 +4,12 @@ from .card import TaskCard
4
  from .dataclass import Field, InternalField, NonPositionalField, OptionalField
5
  from .formats import Format, SystemFormat
6
  from .logging_utils import get_logger
7
- from .operator import SourceSequentialOperator, StreamingOperator
8
  from .operators import AddFields, Augmentor, NullAugmentor, StreamRefiner
9
  from .recipe import Recipe
10
  from .schema import ToUnitxtGroup
11
  from .splitters import Sampler, SeparateSplit, SpreadSplit
 
12
  from .system_prompts import EmptySystemPrompt, SystemPrompt
13
  from .templates import Template
14
 
@@ -99,15 +100,15 @@ class BaseRecipe(Recipe, SourceSequentialOperator):
99
  def prepare_refiners(self):
100
  self.train_refiner.max_instances = self.max_train_instances
101
  self.train_refiner.apply_to_streams = ["train"]
102
- self.steps.append(self.train_refiner)
103
 
104
  self.validation_refiner.max_instances = self.max_validation_instances
105
  self.validation_refiner.apply_to_streams = ["validation"]
106
- self.steps.append(self.validation_refiner)
107
 
108
  self.test_refiner.max_instances = self.max_test_instances
109
  self.test_refiner.apply_to_streams = ["test"]
110
- self.steps.append(self.test_refiner)
111
 
112
  def prepare_metrics_and_postprocessors(self):
113
  if self.postprocessors is None:
@@ -121,9 +122,84 @@ class BaseRecipe(Recipe, SourceSequentialOperator):
121
  metrics = self.metrics
122
  return metrics, postprocessors
123
 
124
- def prepare(self):
 
 
 
 
 
 
 
125
  self.steps = [
126
- self.card.loader,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
127
  AddFields(
128
  fields={
129
  "recipe_metadata": {
@@ -133,25 +209,19 @@ class BaseRecipe(Recipe, SourceSequentialOperator):
133
  "format": self.format,
134
  }
135
  }
136
- ),
137
- ]
138
-
139
- if self.loader_limit:
140
- self.card.loader.loader_limit = self.loader_limit
141
- logger.info(f"Loader line limit was set to {self.loader_limit}")
142
- self.steps.append(StreamRefiner(max_instances=self.loader_limit))
143
 
144
- if self.card.preprocess_steps is not None:
145
- self.steps.extend(self.card.preprocess_steps)
146
 
147
- self.steps.append(self.card.task)
148
 
149
  if self.augmentor.augment_task_input:
150
  self.augmentor.set_task_input_fields(self.card.task.augmentable_inputs)
151
- self.steps.append(self.augmentor)
152
 
153
  if self.demos_pool_size is not None:
154
- self.steps.append(
155
  CreateDemosPool(
156
  from_split=self.demos_taken_from,
157
  to_split_names=[self.demos_pool_name, self.demos_taken_from],
@@ -173,23 +243,23 @@ class BaseRecipe(Recipe, SourceSequentialOperator):
173
 
174
  self.prepare_refiners()
175
 
176
- self.steps.append(self.template)
177
  if self.num_demos > 0:
178
- self.steps.append(
179
  AddDemosField(
180
  source_stream=self.demos_pool_name,
181
  target_field=self.demos_field,
182
  sampler=self.sampler,
183
  )
184
  )
185
- self.steps.append(self.system_prompt)
186
- self.steps.append(self.format)
187
  if self.augmentor.augment_model_input:
188
- self.steps.append(self.augmentor)
189
 
190
  metrics, postprocessors = self.prepare_metrics_and_postprocessors()
191
 
192
- self.steps.append(
193
  ToUnitxtGroup(
194
  group="unitxt",
195
  metrics=metrics,
 
4
  from .dataclass import Field, InternalField, NonPositionalField, OptionalField
5
  from .formats import Format, SystemFormat
6
  from .logging_utils import get_logger
7
+ from .operator import SequentialOperator, SourceSequentialOperator, StreamingOperator
8
  from .operators import AddFields, Augmentor, NullAugmentor, StreamRefiner
9
  from .recipe import Recipe
10
  from .schema import ToUnitxtGroup
11
  from .splitters import Sampler, SeparateSplit, SpreadSplit
12
+ from .stream import MultiStream
13
  from .system_prompts import EmptySystemPrompt, SystemPrompt
14
  from .templates import Template
15
 
 
100
def prepare_refiners(self):
    """Configure the per-split stream refiners and attach them to the processing pipeline.

    Each refiner is capped at its split's max-instance budget, restricted to
    its own stream, and appended to ``self.processing.steps`` in
    train -> validation -> test order.
    """
    split_refiners = [
        (self.train_refiner, self.max_train_instances, "train"),
        (self.validation_refiner, self.max_validation_instances, "validation"),
        (self.test_refiner, self.max_test_instances, "test"),
    ]
    for refiner, max_instances, split in split_refiners:
        refiner.max_instances = max_instances
        refiner.apply_to_streams = [split]
        self.processing.steps.append(refiner)
112
 
113
  def prepare_metrics_and_postprocessors(self):
114
  if self.postprocessors is None:
 
122
  metrics = self.metrics
123
  return metrics, postprocessors
124
 
125
def set_pipelines(self):
    """Build the recipe's named sub-pipelines and wire them into composite operators.

    The full source pipeline (``self.steps``) runs:
    loading -> metadata -> standardization -> processing -> verblization -> finalize.
    The inference entry points reuse the *same* sub-pipeline objects, so any
    step appended later to a sub-pipeline (e.g. in ``prepare``) is visible in
    every composite that contains it.
    """
    # NOTE(review): "verblization" looks like a misspelling of "verbalization";
    # renaming would touch every caller, so it is kept as-is here.
    self.loading = SequentialOperator()
    self.metadata = SequentialOperator()
    self.standardization = SequentialOperator()
    self.processing = SequentialOperator()
    self.verblization = SequentialOperator()
    self.finalize = SequentialOperator()

    # Full dataset-production pipeline, in execution order.
    self.steps = [
        self.loading,
        self.metadata,
        self.standardization,
        self.processing,
        self.verblization,
        self.finalize,
    ]

    # Per-instance inference path: skips loading/standardization because the
    # caller supplies already-standardized task instances.
    self.inference_instance = SequentialOperator()

    self.inference_instance.steps = [
        self.metadata,
        self.processing,
    ]

    # Demos-pool construction path: loads and processes the source data but
    # stops before verbalization/finalization.
    self.inference_demos = SourceSequentialOperator()

    self.inference_demos.steps = [
        self.loading,
        self.metadata,
        self.standardization,
        self.processing,
    ]

    # Final inference stage applied on top of preprocessed instances + demos.
    self.inference = SequentialOperator()

    self.inference.steps = [self.verblization, self.finalize]

    # Lazily-filled cache for the production demos pool (see production_demos_pool).
    self._demos_pool_cache = None
163
+
164
def production_preprocess(self, task_instances):
    """Run the per-instance inference pipeline over raw task instances.

    Wraps *task_instances* in a single-stream ``MultiStream`` under the
    ``"__inference__"`` key, pushes it through ``self.inference_instance``,
    and materializes the resulting stream as a list.
    """
    multi_stream = MultiStream.from_iterables({"__inference__": task_instances})
    processed = self.inference_instance(multi_stream)
    return list(processed["__inference__"])
167
+
168
def production_demos_pool(self):
    """Return the list of demo instances used for production inference.

    Returns an empty list when in-context demos are disabled
    (``self.num_demos`` not > 0). Otherwise the pool is computed once via the
    ``inference_demos`` pipeline and memoized in ``self._demos_pool_cache``.
    """
    if not self.num_demos > 0:
        return []
    if self._demos_pool_cache is None:
        demos_stream = self.inference_demos()[self.demos_pool_name]
        self._demos_pool_cache = list(demos_stream)
    return self._demos_pool_cache
176
+
177
def produce(self, task_instances):
    """Use the recipe in production to produce model ready query from standard task instance.

    Preprocesses *task_instances* through the per-instance pipeline, pairs
    them with the (possibly cached) demos pool, runs the final inference
    stage, and returns the resulting instances as a list.
    """
    self.before_process_multi_stream()
    streams = {
        "__inference__": self.production_preprocess(task_instances),
        self.demos_pool_name: self.production_demos_pool(),
    }
    multi_stream = self.inference(MultiStream.from_iterables(streams))
    return list(multi_stream["__inference__"])
188
+
189
+ def prepare(self):
190
+ self.set_pipelines()
191
+
192
+ loader = self.card.loader
193
+ if self.loader_limit:
194
+ loader.loader_limit = self.loader_limit
195
+ logger.info(f"Loader line limit was set to {self.loader_limit}")
196
+ self.loading.steps.append(loader)
197
+
198
+ # This is required in case loader_limit is not enforced by the loader
199
+ if self.loader_limit:
200
+ self.loading.steps.append(StreamRefiner(max_instances=self.loader_limit))
201
+
202
+ self.metadata.steps.append(
203
  AddFields(
204
  fields={
205
  "recipe_metadata": {
 
209
  "format": self.format,
210
  }
211
  }
212
+ )
213
+ )
 
 
 
 
 
214
 
215
+ self.standardization.steps.extend(self.card.preprocess_steps)
 
216
 
217
+ self.processing.steps.append(self.card.task)
218
 
219
  if self.augmentor.augment_task_input:
220
  self.augmentor.set_task_input_fields(self.card.task.augmentable_inputs)
221
+ self.processing.steps.append(self.augmentor)
222
 
223
  if self.demos_pool_size is not None:
224
+ self.processing.steps.append(
225
  CreateDemosPool(
226
  from_split=self.demos_taken_from,
227
  to_split_names=[self.demos_pool_name, self.demos_taken_from],
 
243
 
244
  self.prepare_refiners()
245
 
246
+ self.verblization.steps.append(self.template)
247
  if self.num_demos > 0:
248
+ self.verblization.steps.append(
249
  AddDemosField(
250
  source_stream=self.demos_pool_name,
251
  target_field=self.demos_field,
252
  sampler=self.sampler,
253
  )
254
  )
255
+ self.verblization.steps.append(self.system_prompt)
256
+ self.verblization.steps.append(self.format)
257
  if self.augmentor.augment_model_input:
258
+ self.verblization.steps.append(self.augmentor)
259
 
260
  metrics, postprocessors = self.prepare_metrics_and_postprocessors()
261
 
262
+ self.finalize.steps.append(
263
  ToUnitxtGroup(
264
  group="unitxt",
265
  metrics=metrics,