jgauthier commited on
Commit
27bb1ab
·
1 Parent(s): af46379

minor code cleanup

Browse files
Files changed (2) hide show
  1. syntaxgym.py +0 -4
  2. test.py +3 -5
syntaxgym.py CHANGED
@@ -197,14 +197,10 @@ class SyntaxGym(evaluate.EvaluationModule):
197
  surps_shifted = surprisals[:, :-1, :]
198
  expected_ids = input_ids[:, 1:]
199
 
200
- # TODO: check this logic
201
- tt = expected_ids.unsqueeze(2)
202
  # reindexed surprisals: B * (T - 1)
203
  surprisals = torch.gather(surps_shifted, 2, expected_ids.unsqueeze(2)) \
204
  .squeeze(2)
205
 
206
- # surprisals is now B * (T - 1)
207
-
208
  #### aggregate
209
  condition_names = item["conditions"]["condition_name"]
210
  region_totals = {condition_name: defaultdict(float)
 
197
  surps_shifted = surprisals[:, :-1, :]
198
  expected_ids = input_ids[:, 1:]
199
 
 
 
200
  # reindexed surprisals: B * (T - 1)
201
  surprisals = torch.gather(surps_shifted, 2, expected_ids.unsqueeze(2)) \
202
  .squeeze(2)
203
 
 
 
204
  #### aggregate
205
  condition_names = item["conditions"]["condition_name"]
206
  region_totals = {condition_name: defaultdict(float)
test.py CHANGED
@@ -14,6 +14,7 @@ def syntaxgym_dataset():
14
 
15
  @pytest.fixture(scope="session")
16
  def syntaxgym_metric():
 
17
  return evaluate.load("./syntaxgym.py")
18
 
19
 
@@ -488,17 +489,14 @@ GPT2_SUBORDINATION_SRC_REFERENCE = \
488
  ('sub_no-matrix', 5): 4.819862633503057}]
489
 
490
 
491
- def test_gpt_subordination_region_totals():
492
  """
493
  Check region-level surprisals against the original syntaxgym-core
494
  implementation, using the same underlying `gpt2` model.
495
  """
496
- reference = ... # TODO
497
 
498
- # TODO work out references
499
  dataset = datasets.load_dataset("cpllab/syntaxgym", "subordination_src-src")
500
- metric = evaluate.load("./syntaxgym.py")
501
- result = metric.compute(suite=dataset["test"], model_id="gpt2")
502
 
503
  from pprint import pprint
504
  pprint(result["region_totals"][0])
 
14
 
15
  @pytest.fixture(scope="session")
16
  def syntaxgym_metric():
17
+ # TODO work out reference
18
  return evaluate.load("./syntaxgym.py")
19
 
20
 
 
489
  ('sub_no-matrix', 5): 4.819862633503057}]
490
 
491
 
492
+ def test_gpt_subordination_region_totals(syntaxgym_metric):
493
  """
494
  Check region-level surprisals against the original syntaxgym-core
495
  implementation, using the same underlying `gpt2` model.
496
  """
 
497
 
 
498
  dataset = datasets.load_dataset("cpllab/syntaxgym", "subordination_src-src")
499
+ result = syntaxgym_metric.compute(suite=dataset["test"], model_id="gpt2")
 
500
 
501
  from pprint import pprint
502
  pprint(result["region_totals"][0])