Spaces:
Sleeping
Sleeping
minor code cleanup
Browse files- syntaxgym.py +0 -4
- test.py +3 -5
syntaxgym.py
CHANGED
@@ -197,14 +197,10 @@ class SyntaxGym(evaluate.EvaluationModule):
|
|
197 |
surps_shifted = surprisals[:, :-1, :]
|
198 |
expected_ids = input_ids[:, 1:]
|
199 |
|
200 |
-
# TODO: check this logic
|
201 |
-
tt = expected_ids.unsqueeze(2)
|
202 |
# reindexed surprisals: B * (T - 1)
|
203 |
surprisals = torch.gather(surps_shifted, 2, expected_ids.unsqueeze(2)) \
|
204 |
.squeeze(2)
|
205 |
|
206 |
-
# surprisals is now B * (T - 1)
|
207 |
-
|
208 |
#### aggregate
|
209 |
condition_names = item["conditions"]["condition_name"]
|
210 |
region_totals = {condition_name: defaultdict(float)
|
|
|
197 |
surps_shifted = surprisals[:, :-1, :]
|
198 |
expected_ids = input_ids[:, 1:]
|
199 |
|
|
|
|
|
200 |
# reindexed surprisals: B * (T - 1)
|
201 |
surprisals = torch.gather(surps_shifted, 2, expected_ids.unsqueeze(2)) \
|
202 |
.squeeze(2)
|
203 |
|
|
|
|
|
204 |
#### aggregate
|
205 |
condition_names = item["conditions"]["condition_name"]
|
206 |
region_totals = {condition_name: defaultdict(float)
|
test.py
CHANGED
@@ -14,6 +14,7 @@ def syntaxgym_dataset():
|
|
14 |
|
15 |
@pytest.fixture(scope="session")
|
16 |
def syntaxgym_metric():
|
|
|
17 |
return evaluate.load("./syntaxgym.py")
|
18 |
|
19 |
|
@@ -488,17 +489,14 @@ GPT2_SUBORDINATION_SRC_REFERENCE = \
|
|
488 |
('sub_no-matrix', 5): 4.819862633503057}]
|
489 |
|
490 |
|
491 |
-
def test_gpt_subordination_region_totals():
|
492 |
"""
|
493 |
Check region-level surprisals against the original syntaxgym-core
|
494 |
implementation, using the same underlying `gpt2` model.
|
495 |
"""
|
496 |
-
reference = ... # TODO
|
497 |
|
498 |
-
# TODO work out references
|
499 |
dataset = datasets.load_dataset("cpllab/syntaxgym", "subordination_src-src")
|
500 |
-
|
501 |
-
result = metric.compute(suite=dataset["test"], model_id="gpt2")
|
502 |
|
503 |
from pprint import pprint
|
504 |
pprint(result["region_totals"][0])
|
|
|
14 |
|
15 |
@pytest.fixture(scope="session")
|
16 |
def syntaxgym_metric():
|
17 |
+
# TODO work out reference
|
18 |
return evaluate.load("./syntaxgym.py")
|
19 |
|
20 |
|
|
|
489 |
('sub_no-matrix', 5): 4.819862633503057}]
|
490 |
|
491 |
|
492 |
+
def test_gpt_subordination_region_totals(syntaxgym_metric):
|
493 |
"""
|
494 |
Check region-level surprisals against the original syntaxgym-core
|
495 |
implementation, using the same underlying `gpt2` model.
|
496 |
"""
|
|
|
497 |
|
|
|
498 |
dataset = datasets.load_dataset("cpllab/syntaxgym", "subordination_src-src")
|
499 |
+
result = syntaxgym_metric.compute(suite=dataset["test"], model_id="gpt2")
|
|
|
500 |
|
501 |
from pprint import pprint
|
502 |
pprint(result["region_totals"][0])
|