# RASP-Synthesis / tracr/craft/chamber/categorical_attn_test.py
# Vladimir Mikulik — "add typing_extensions to list of deps." (commit d4d39d0)
# File-viewer metadata: raw · history · blame · 9.37 kB
# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for chamber.categorical_attn."""
from absl.testing import absltest
from absl.testing import parameterized
import numpy as np
from tracr.craft import bases
from tracr.craft import tests_common
from tracr.craft.chamber import categorical_attn
class CategoricalAttnTest(tests_common.VectorFnTestCase):
  """Tests for `categorical_attn.categorical_attn`."""

  @parameterized.parameters([
      dict(causal=False, input_seq=[1, 2, 3, 4, 5], result_seq=[3, 3, 3, 3, 3]),
      dict(
          causal=True,
          input_seq=[1, 2, 3, 4, 5],
          result_seq=[1, 1.5, 2, 2.5, 3]),
      dict(causal=False, input_seq=[10], result_seq=[10]),
      dict(causal=True, input_seq=[10], result_seq=[10]),
      dict(causal=False, input_seq=[-1, 0, 1], result_seq=[0, 0, 0]),
      dict(causal=True, input_seq=[-1, 0, 1], result_seq=[-1, -0.5, 0]),
  ])
  def test_categorical_attn_can_implement_select_all(self, causal, input_seq,
                                                     result_seq):
    """An always-true `attn_fn` averages the values of the attended tokens.

    Each non-BOS token carries a one-hot "input" direction plus its own value
    `x` along the "value" direction. With `causal=False` every position is
    expected to output the mean over the whole input sequence. With
    `causal=True`, position i is expected to output the mean over input tokens
    0..i. The BOS token is not counted in either mean (see `result_seq`).
    """
    vocab = range(-20, 20)
    input_space = bases.VectorSpaceWithBasis.from_values("input", vocab)

    output_dir = bases.BasisDirection("output")
    output_space = bases.VectorSpaceWithBasis([output_dir])
    output_vec = output_space.vector_from_basis_direction(output_dir)

    bos_dir = bases.BasisDirection("bos_dimension")
    bos_space = bases.VectorSpaceWithBasis([bos_dir])

    one_dir = bases.BasisDirection("one")
    one_space = bases.VectorSpaceWithBasis([one_dir])

    value_dir = bases.BasisDirection("value")
    value_space = bases.VectorSpaceWithBasis([value_dir])

    input_space = bases.join_vector_spaces(input_space, bos_space, one_space)
    value_space = bases.join_vector_spaces(value_space, bos_space)
    residual_space = bases.join_vector_spaces(input_space, value_space,
                                              output_space)

    one_vec = residual_space.vector_from_basis_direction(one_dir)
    bos_vec = residual_space.vector_from_basis_direction(bos_dir)
    value_vec = residual_space.vector_from_basis_direction(value_dir)

    attn = categorical_attn.categorical_attn(
        key_space=input_space,
        query_space=input_space,
        value_space=value_space,
        output_space=output_space,
        bos_space=bos_space,
        one_space=one_space,
        attn_fn=lambda x, y: True,
        causal=causal)

    def make_input(x):
      # One-hot "input" direction for token `x`, plus `x` on the value axis.
      return (residual_space.vector_from_basis_direction(
          bases.BasisDirection("input", x)) + x * value_vec)

    # Position 0 is the BOS token (bos + one); the input tokens follow it.
    test_inputs = bases.VectorInBasis.stack([bos_vec + one_vec] +
                                            [make_input(x) for x in input_seq])

    # Expect the average of all (previous) tokens.
    expected_results = bases.VectorInBasis.stack(
        [x * output_vec for x in result_seq])

    test_outputs = attn.apply(test_inputs).project(output_space)

    self.assertVectorAllClose(
        tests_common.strip_bos_token(test_outputs), expected_results)

  @parameterized.parameters([
      dict(causal=False, input_seq=[1, 2, 3, 4, 5], default=0),
      dict(causal=True, input_seq=[1, 2, 3, 4, 5], default=1),
      dict(causal=False, input_seq=[10], default=2),
      dict(causal=True, input_seq=[10], default=-3),
      dict(causal=False, input_seq=[-1, 0, 1], default=-2),
      dict(causal=True, input_seq=[-1, 0, 1], default=-1),
  ])
  def test_categorical_attn_can_implement_select_none(self, causal, input_seq,
                                                      default):
    """An always-false `attn_fn` yields `default_output` at every position.

    The head is built with `always_attend_to_bos=False` and
    `use_bos_for_default_output=True`. Every non-BOS position is expected to
    output `default * output`, for both causal and non-causal heads.
    """
    vocab = range(-20, 20)
    input_space = bases.VectorSpaceWithBasis.from_values("input", vocab)

    output_dir = bases.BasisDirection("output")
    output_space = bases.VectorSpaceWithBasis([output_dir])
    default_vec = default * output_space.vector_from_basis_direction(output_dir)

    bos_dir = bases.BasisDirection("bos_dimension")
    bos_space = bases.VectorSpaceWithBasis([bos_dir])

    one_dir = bases.BasisDirection("one")
    one_space = bases.VectorSpaceWithBasis([one_dir])

    value_dir = bases.BasisDirection("value")
    value_space = bases.VectorSpaceWithBasis([value_dir])

    input_space = bases.join_vector_spaces(input_space, bos_space, one_space)
    value_space = bases.join_vector_spaces(value_space, bos_space)
    residual_space = bases.join_vector_spaces(input_space, value_space,
                                              output_space)

    value_vec = residual_space.vector_from_basis_direction(value_dir)
    bos_vec = residual_space.vector_from_basis_direction(bos_dir)
    one_vec = residual_space.vector_from_basis_direction(one_dir)

    attn = categorical_attn.categorical_attn(
        key_space=input_space,
        query_space=input_space,
        value_space=value_space,
        output_space=output_space,
        bos_space=bos_space,
        one_space=one_space,
        attn_fn=lambda x, y: False,
        default_output=default_vec,
        causal=causal,
        always_attend_to_bos=False,
        use_bos_for_default_output=True)

    def make_input(x):
      # "one" direction, `x` on the value axis, and the one-hot "input" token.
      return (one_vec + x * value_vec +
              residual_space.vector_from_basis_direction(
                  bases.BasisDirection("input", x)))

    # Position 0 is the BOS token (bos + one); the input tokens follow it.
    test_inputs = bases.VectorInBasis.stack([bos_vec + one_vec] +
                                            [make_input(x) for x in input_seq])

    # Expect the default value at every (non-BOS) position.
    expected_results = bases.VectorInBasis.stack([default_vec] *
                                                 len(input_seq))

    test_outputs = attn.apply(test_inputs).project(output_space)

    self.assertVectorAllClose(
        tests_common.strip_bos_token(test_outputs), expected_results)

  @parameterized.parameters([
      dict(num_counts=5, input_seq=[1, 4, 3, 2], n=1, result=[4, 3, 2, 1]),
      dict(num_counts=10, input_seq=[5, 8, 9, 2], n=3, result=[2, 5, 8, 9])
  ])
  def test_categorical_attn_can_implement_shift_by_n(self, num_counts,
                                                     input_seq, n, result):
    """Matching query value to key value implements a cyclic shift by `n`.

    Position i has key `i` and query `(i + n) % seq_len`. With
    `attn_fn = q.value == k.value` it should therefore read from position
    `(i + n) % seq_len`. The expected output is the one-hot "prefix4"
    direction `result[i]`, where `result[i] == input_seq[(i + n) % seq_len]`
    in the parameter table.
    """
    query_prefix = "prefix1"
    key_prefix = "prefix2"
    agg_input_prefix = "prefix3"
    output_prefix = "prefix4"

    bos_direction = bases.BasisDirection("bos")
    one_direction = bases.BasisDirection("one")

    query_space = bases.VectorSpaceWithBasis.from_values(
        query_prefix, range(num_counts))
    key_space = bases.VectorSpaceWithBasis.from_values(key_prefix,
                                                       range(num_counts))
    bos_space = bases.VectorSpaceWithBasis([bos_direction])
    one_space = bases.VectorSpaceWithBasis([one_direction])
    key_space = bases.join_vector_spaces(key_space, bos_space)

    agg_input_space = bases.VectorSpaceWithBasis.from_values(
        agg_input_prefix, range(num_counts))
    agg_input_space = bases.join_vector_spaces(agg_input_space, bos_space)
    output_space = bases.VectorSpaceWithBasis.from_values(
        output_prefix, range(num_counts))

    attn = categorical_attn.categorical_attn(
        query_space=query_space,
        key_space=key_space,
        value_space=agg_input_space,
        output_space=output_space,
        bos_space=bos_space,
        one_space=one_space,
        attn_fn=lambda q, k: q.value == k.value,
        default_output=None,
        always_attend_to_bos=False,
        use_bos_for_default_output=True,
        causal=False)

    residual_space = bases.join_vector_spaces(key_space, query_space,
                                              agg_input_space, output_space,
                                              one_space)

    seq_len = len(input_seq)
    # Query at position i points at key position (i + n) % seq_len.
    query_seq = np.arange(n, seq_len + n) % seq_len
    key_seq = np.arange(seq_len)

    bos_vec = residual_space.vector_from_basis_direction(bos_direction)
    one_vec = residual_space.vector_from_basis_direction(one_direction)

    test_inputs = [bos_vec + one_vec]
    expected_results = []
    for query_val, key_val, input_val, result_val in zip(
        query_seq, key_seq, input_seq, result):
      test_inputs.append(
          residual_space.vector_from_basis_direction(
              bases.BasisDirection(query_prefix, query_val)) +
          residual_space.vector_from_basis_direction(
              bases.BasisDirection(key_prefix, key_val)) +
          residual_space.vector_from_basis_direction(
              bases.BasisDirection(agg_input_prefix, input_val)))
      expected_results.append(
          residual_space.vector_from_basis_direction(
              bases.BasisDirection(output_prefix, result_val)))

    test_inputs = bases.VectorInBasis.stack(test_inputs)
    expected_results = bases.VectorInBasis.stack(expected_results)

    test_outputs = attn.apply(test_inputs)

    self.assertVectorAllClose(
        tests_common.strip_bos_token(test_outputs), expected_results)
# Run the absltest entry point when this module is executed directly.
if __name__ == "__main__":
  absltest.main()