Joschka Strueber commited on
Commit
2d8352e
·
1 Parent(s): 7c4f6b6

[Fix] error in label filtering

Browse files
Files changed (1) hide show
  1. src/dataloading.py +3 -5
src/dataloading.py CHANGED
@@ -103,11 +103,9 @@ def filter_labels(dataset_name, doc):
103
  labels = []
104
  test_target, target_key = get_test_target(doc[0])
105
  if "answer_index" in doc[0].keys():
106
- for d in doc:
107
- labels.append(d["answer_index"])
108
- elif test_target.starts_with("("):
109
- for d in doc:
110
- labels.append(opt_in_pars_to_index(d[target_key]))
111
  elif dataset_name in ["bbh_boolean_expression"]:
112
  for d in doc:
113
  if d[target_key] == "True":
 
103
  labels = []
104
  test_target, target_key = get_test_target(doc[0])
105
  if "answer_index" in doc[0].keys():
106
+ labels = [d["answer_index"] for d in doc]
107
+ elif test_target.startswith("("):
108
+ labels = [opt_in_pars_to_index(d[target_key]) for d in doc]
 
 
109
  elif dataset_name in ["bbh_boolean_expression"]:
110
  for d in doc:
111
  if d[target_key] == "True":