# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for chamber.categorical_attn."""
from absl.testing import absltest
from absl.testing import parameterized
import numpy as np
from tracr.craft import bases
from tracr.craft import tests_common
from tracr.craft.chamber import categorical_attn
class CategoricalAttnTest(tests_common.VectorFnTestCase):
  """Exercises `categorical_attn.categorical_attn` on three selector patterns.

  Every test method takes arguments beyond `self`, so each one is driven by
  `parameterized.parameters`.

  NOTE(review): the decorators only expand into runnable test cases if
  `tests_common.VectorFnTestCase` derives from `parameterized.TestCase` --
  confirm against tests_common.
  """

  @parameterized.parameters(
      dict(causal=False, input_seq=[1, 2, 3, 4, 5], result_seq=[3, 3, 3, 3, 3]),
      dict(
          causal=True,
          input_seq=[1, 2, 3, 4, 5],
          result_seq=[1, 1.5, 2, 2.5, 3]),
      dict(causal=False, input_seq=[10], result_seq=[10]),
      dict(causal=True, input_seq=[10], result_seq=[10]),
      dict(causal=False, input_seq=[-1, 0, 1], result_seq=[0, 0, 0]),
      dict(causal=True, input_seq=[-1, 0, 1], result_seq=[-1, -0.5, 0]),
  )
  def test_categorical_attn_can_implement_select_all(self, causal, input_seq,
                                                     result_seq):
    """Uses an always-true `attn_fn` and compares against `result_seq`.

    Args:
      causal: Forwarded unchanged as the `causal` argument of
        `categorical_attn`.
      input_seq: Token values; each must lie in `range(-20, 20)` because that
        range is the vocabulary the "input" space is built from.
      result_seq: Expected coefficient on the "output" direction for each
        non-BOS position.
    """
    vocab = range(-20, 20)
    input_space = bases.VectorSpaceWithBasis.from_values("input", vocab)

    output_dir = bases.BasisDirection("output")
    output_space = bases.VectorSpaceWithBasis([output_dir])
    output_vec = output_space.vector_from_basis_direction(output_dir)

    bos_dir = bases.BasisDirection("bos_dimension")
    bos_space = bases.VectorSpaceWithBasis([bos_dir])

    one_dir = bases.BasisDirection("one")
    one_space = bases.VectorSpaceWithBasis([one_dir])

    value_dir = bases.BasisDirection("value")
    value_space = bases.VectorSpaceWithBasis([value_dir])

    input_space = bases.join_vector_spaces(input_space, bos_space, one_space)
    value_space = bases.join_vector_spaces(value_space, bos_space)
    residual_space = bases.join_vector_spaces(input_space, value_space,
                                              output_space)
    one_vec = residual_space.vector_from_basis_direction(one_dir)
    bos_vec = residual_space.vector_from_basis_direction(bos_dir)
    value_vec = residual_space.vector_from_basis_direction(value_dir)

    attn = categorical_attn.categorical_attn(
        key_space=input_space,
        query_space=input_space,
        value_space=value_space,
        output_space=output_space,
        bos_space=bos_space,
        one_space=one_space,
        attn_fn=lambda x, y: True,
        causal=causal)

    # Row 0 is the BOS token; every later row is the one-hot "input" direction
    # for x plus x times the "value" direction.
    test_inputs = [bos_vec + one_vec]
    for x in input_seq:
      test_inputs.append(
          residual_space.vector_from_basis_direction(
              bases.BasisDirection("input", x)) + x * value_vec)
    test_inputs = bases.VectorInBasis.stack(test_inputs)

    # Expect the average of all (previous) tokens
    expected_results = [x * output_vec for x in result_seq]
    expected_results = bases.VectorInBasis.stack(expected_results)

    test_outputs = attn.apply(test_inputs).project(output_space)

    # The BOS row is dropped before comparing; only token rows are checked.
    self.assertVectorAllClose(
        tests_common.strip_bos_token(test_outputs), expected_results)

  @parameterized.parameters(
      dict(causal=False, input_seq=[1, 2, 3, 4, 5], default=0),
      dict(causal=True, input_seq=[1, 2, 3, 4, 5], default=1),
      dict(causal=False, input_seq=[10], default=2),
      dict(causal=True, input_seq=[10], default=-3),
      dict(causal=False, input_seq=[-1, 0, 1], default=-2),
      dict(causal=True, input_seq=[-1, 0, 1], default=-1),
  )
  def test_categorical_attn_can_implement_select_none(self, causal, input_seq,
                                                      default):
    """Uses an always-false `attn_fn` and expects the default output.

    Args:
      causal: Forwarded unchanged as the `causal` argument of
        `categorical_attn`.
      input_seq: Token values; each must lie in `range(-20, 20)` because that
        range is the vocabulary the "input" space is built from.
      default: Scalar multiplied onto the "output" direction to form the
        `default_output` vector expected at every non-BOS position.
    """
    vocab = range(-20, 20)
    input_space = bases.VectorSpaceWithBasis.from_values("input", vocab)

    output_dir = bases.BasisDirection("output")
    output_space = bases.VectorSpaceWithBasis([output_dir])
    default_vec = default * output_space.vector_from_basis_direction(output_dir)

    bos_dir = bases.BasisDirection("bos_dimension")
    bos_space = bases.VectorSpaceWithBasis([bos_dir])

    one_dir = bases.BasisDirection("one")
    one_space = bases.VectorSpaceWithBasis([one_dir])

    value_dir = bases.BasisDirection("value")
    value_space = bases.VectorSpaceWithBasis([value_dir])

    input_space = bases.join_vector_spaces(input_space, bos_space, one_space)
    value_space = bases.join_vector_spaces(value_space, bos_space)
    residual_space = bases.join_vector_spaces(input_space, value_space,
                                              output_space)
    value_vec = residual_space.vector_from_basis_direction(value_dir)
    bos_vec = residual_space.vector_from_basis_direction(bos_dir)
    one_vec = residual_space.vector_from_basis_direction(one_dir)

    attn = categorical_attn.categorical_attn(
        key_space=input_space,
        query_space=input_space,
        value_space=value_space,
        output_space=output_space,
        bos_space=bos_space,
        one_space=one_space,
        attn_fn=lambda x, y: False,
        default_output=default_vec,
        causal=causal,
        always_attend_to_bos=False,
        use_bos_for_default_output=True)

    def make_input(x):
      """Builds the residual vector for one non-BOS token with value x."""
      return (one_vec + x * value_vec +
              residual_space.vector_from_basis_direction(
                  bases.BasisDirection("input", x)))

    test_inputs = bases.VectorInBasis.stack([bos_vec + one_vec] +
                                            [make_input(x) for x in input_seq])

    # Expect the default value
    expected_results = [default_vec for _ in input_seq]
    expected_results = bases.VectorInBasis.stack(expected_results)

    test_outputs = attn.apply(test_inputs).project(output_space)

    # The BOS row is dropped before comparing; only token rows are checked.
    self.assertVectorAllClose(
        tests_common.strip_bos_token(test_outputs), expected_results)

  @parameterized.parameters(
      dict(num_counts=5, input_seq=[1, 4, 3, 2], n=1, result=[4, 3, 2, 1]),
      dict(num_counts=10, input_seq=[5, 8, 9, 2], n=3, result=[2, 5, 8, 9]),
      dict(num_counts=10, input_seq=[5, 8, 9, 2], n=9, result=[8, 9, 2, 5]),
  )
  def test_categorical_attn_can_implement_shift_by_n(self, num_counts,
                                                     input_seq, n, result):
    """Matches query and key values so that outputs are a cyclic shift.

    Position i carries the query value `(i + n) % len(input_seq)` and the key
    value `i`, and `attn_fn` selects pairs whose values are equal.

    Args:
      num_counts: Size of the query, key, aggregation-input and output value
        ranges (`range(num_counts)`); all sequence entries and positions must
        be below it.
      input_seq: Values placed on the aggregation-input directions, one per
        position.
      n: Offset applied (modulo the sequence length) to build the queries.
      result: Expected output-direction value for each position.
    """
    query_prefix = "prefix1"
    key_prefix = "prefix2"
    agg_input_prefix = "prefix3"
    output_prefix = "prefix4"
    bos_direction = bases.BasisDirection("bos")
    one_direction = bases.BasisDirection("one")

    query_space = bases.VectorSpaceWithBasis.from_values(
        query_prefix, range(num_counts))
    key_space = bases.VectorSpaceWithBasis.from_values(key_prefix,
                                                       range(num_counts))
    bos_space = bases.VectorSpaceWithBasis([bos_direction])
    one_space = bases.VectorSpaceWithBasis([one_direction])
    key_space = bases.join_vector_spaces(key_space, bos_space)

    agg_input_space = bases.VectorSpaceWithBasis.from_values(
        agg_input_prefix, range(num_counts))
    agg_input_space = bases.join_vector_spaces(agg_input_space, bos_space)
    output_space = bases.VectorSpaceWithBasis.from_values(
        output_prefix, range(num_counts))

    attn = categorical_attn.categorical_attn(
        query_space=query_space,
        key_space=key_space,
        value_space=agg_input_space,
        output_space=output_space,
        bos_space=bos_space,
        one_space=one_space,
        attn_fn=lambda q, k: q.value == k.value,
        default_output=None,
        always_attend_to_bos=False,
        use_bos_for_default_output=True,
        causal=False)

    residual_space = bases.join_vector_spaces(key_space, query_space,
                                              agg_input_space, output_space,
                                              one_space)

    seq_len = len(input_seq)
    # zip() below would silently truncate on a length mismatch, so fail loudly.
    self.assertEqual(len(result), seq_len)
    query_seq = np.arange(n, seq_len + n) % seq_len
    key_seq = np.arange(seq_len)

    bos_vec = residual_space.vector_from_basis_direction(bos_direction)
    one_vec = residual_space.vector_from_basis_direction(one_direction)

    test_inputs = [bos_vec + one_vec]
    expected_results = []
    for query, key, agg_input, expected in zip(query_seq, key_seq, input_seq,
                                               result):
      test_inputs.append(
          residual_space.vector_from_basis_direction(
              bases.BasisDirection(query_prefix, query)) +
          residual_space.vector_from_basis_direction(
              bases.BasisDirection(key_prefix, key)) +
          residual_space.vector_from_basis_direction(
              bases.BasisDirection(agg_input_prefix, agg_input)))
      expected_results.append(
          residual_space.vector_from_basis_direction(
              bases.BasisDirection(output_prefix, expected)))

    test_inputs = bases.VectorInBasis.stack(test_inputs)
    expected_results = bases.VectorInBasis.stack(expected_results)

    # No projection here: expected vectors live in the full residual space.
    test_outputs = attn.apply(test_inputs)

    self.assertVectorAllClose(
        tests_common.strip_bos_token(test_outputs), expected_results)
# Run every test case in this module when the file is executed as a script.
if __name__ == "__main__":
  absltest.main()