Elron commited on
Commit
a987537
·
verified ·
1 Parent(s): 3cbd8f0

Upload templates.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. templates.py +28 -0
templates.py CHANGED
@@ -161,6 +161,7 @@ class MultipleChoiceTemplate(Template):
161
  source_choice_format: str = "{choice_numeral}. {choice_text}"
162
  target_choice_format: str = "{choice_numeral}"
163
  enumerator: str = "capitals"
 
164
 
165
  def prepare(self):
166
  super().prepare()
@@ -229,6 +230,18 @@ class MultipleChoiceTemplate(Template):
229
  inputs = self.prepare_multiple_choice_inputs(inputs)
230
  return super().inputs_to_instruction_and_target_prefix(inputs)
231
 
 
 
 
 
 
 
 
 
 
 
 
 
232
  def outputs_to_target_and_references(self, outputs: Dict[str, object]) -> str:
233
  target = outputs[self.target_field]
234
 
@@ -251,9 +264,24 @@ class MultipleChoiceTemplate(Template):
251
 
252
  return target, [target]
253
 
 
 
 
 
 
 
 
 
 
 
 
 
 
254
  def process(
255
  self, instance: Dict[str, Any], stream_name: Optional[str] = None
256
  ) -> Dict[str, Any]:
 
 
257
  result = super().process(instance, stream_name)
258
  if "options" not in result["outputs"]:
259
  result["outputs"]["options"] = self.inputs_to_choices(
 
161
  source_choice_format: str = "{choice_numeral}. {choice_text}"
162
  target_choice_format: str = "{choice_numeral}"
163
  enumerator: str = "capitals"
164
+ shuffle_choices: bool = False
165
 
166
  def prepare(self):
167
  super().prepare()
 
230
  inputs = self.prepare_multiple_choice_inputs(inputs)
231
  return super().inputs_to_instruction_and_target_prefix(inputs)
232
 
233
+ def outputs_to_target_index(self, outputs: Dict[str, object]) -> str:
234
+ target = outputs[self.target_field]
235
+
236
+ if not isinstance(target, int):
237
+ try:
238
+ return outputs[self.choices_field].index(target)
239
+ except ValueError as e:
240
+ raise ValueError(
241
+ f"MultipleChoiceTemplate could not locate textual target '{target}' in choices list: {outputs[self.choices_field]}"
242
+ ) from e
243
+ return target
244
+
245
  def outputs_to_target_and_references(self, outputs: Dict[str, object]) -> str:
246
  target = outputs[self.target_field]
247
 
 
264
 
265
  return target, [target]
266
 
267
+ def _shuffle_choices(self, instance):
268
+ target_index = self.outputs_to_target_index(instance["outputs"])
269
+ original_label_choice = instance["outputs"][self.choices_field][target_index]
270
+ choices = instance["inputs"][self.choices_field]
271
+ random_generator = new_random_generator(
272
+ {**instance["inputs"], **instance["outputs"]}
273
+ )
274
+ random_generator.shuffle(choices)
275
+ instance["inputs"][self.choices_field] = choices
276
+ instance["outputs"][self.choices_field] = choices
277
+ instance["outputs"][self.target_field] = choices.index(original_label_choice)
278
+ return instance
279
+
280
  def process(
281
  self, instance: Dict[str, Any], stream_name: Optional[str] = None
282
  ) -> Dict[str, Any]:
283
+ if self.shuffle_choices:
284
+ instance = self._shuffle_choices(instance)
285
  result = super().process(instance, stream_name)
286
  if "options" not in result["outputs"]:
287
  result["outputs"]["options"] = self.inputs_to_choices(