# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Attention head for categorical inputs.""" | |
from typing import Optional | |
from typing_extensions import Protocol | |
from tracr.craft import bases | |
from tracr.craft import transformers | |
from tracr.craft import vectorspace_fns | |
class QueryKeyToAttnLogit(Protocol): | |
def __call__(self, query: bases.BasisDirection, | |
key: bases.BasisDirection) -> bool: | |
pass | |
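
# Illustrative example (added here, not part of the original module): a
# function satisfying the QueryKeyToAttnLogit protocol that attends wherever
# the query and key basis directions carry the same value attribute, e.g. for
# a head that matches identical tokens.
def example_equality_attn_fn(query: bases.BasisDirection,
                             key: bases.BasisDirection) -> bool:
  return query.value == key.value
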
def categorical_attn(
    query_space: bases.VectorSpaceWithBasis,
    key_space: bases.VectorSpaceWithBasis,
    value_space: bases.VectorSpaceWithBasis,
    output_space: bases.VectorSpaceWithBasis,
    bos_space: bases.VectorSpaceWithBasis,
    one_space: bases.VectorSpaceWithBasis,
    attn_fn: QueryKeyToAttnLogit,
    default_output: Optional[bases.VectorInBasis] = None,
    causal: bool = False,
    always_attend_to_bos: bool = False,
    use_bos_for_default_output: bool = True,
    softmax_coldness: float = 100.,
) -> transformers.AttentionHead:
"""Returns an attention head for categorical inputs. | |
Assumes the existence of a beginning of sequence token and attends to it | |
always with strength 0.5*softmax_coldness. This allows to implement an | |
arbitrary default value for rows in the attention pattern that are all-zero. | |
Attends to the BOS token if all other key-query pairs have zero attention. | |
Hence, the first value in the value sequence will be the default output for | |
such cases. | |
Args: | |
query_space: Vector space containing (categorical) query input. | |
key_space: Vector space containing (categorical) key input. | |
value_space: Vector space containing (numerical) value input. | |
output_space: Vector space which will contain (numerical) output. | |
bos_space: 1-d space used to identify the beginning of sequence token. | |
one_space: 1-d space which contains 1 at every position. | |
attn_fn: A selector function f(query, key) operating on the query/key basis | |
directions that defines the attention pattern. | |
default_output: Output to return if attention pattern is all zero. | |
causal: If True, use masked attention. | |
always_attend_to_bos: If True, always attend to the BOS token. If False, | |
only attend to BOS when attending to nothing else. | |
use_bos_for_default_output: If True, assume BOS is not in the value space | |
and output a default value when attending to BOS. If False, assume BOS is | |
in the value space, and map it to the output space like any other token. | |
softmax_coldness: The inverse temperature of the softmax. Default value is | |
high which makes the attention close to a hard maximum. | |
""" | |
  bases.ensure_dims(bos_space, num_dims=1, name="bos_space")
  bases.ensure_dims(one_space, num_dims=1, name="one_space")
  bos_direction = bos_space.basis[0]
  one_direction = one_space.basis[0]

  # Add the bos direction to the query, key, and value spaces in case it is
  # missing.
  query_space = bases.join_vector_spaces(query_space, bos_space, one_space)
  key_space = bases.join_vector_spaces(key_space, bos_space)
  value_space = bases.join_vector_spaces(value_space, bos_space)

  if always_attend_to_bos:
    value_basis = value_space.basis
  else:
    value_basis = [v for v in value_space.basis if v != bos_direction]
  assert len(value_basis) == output_space.num_dims
  value_to_output = dict(zip(value_basis, output_space.basis))
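  # The mapping above pairs each value direction with an output direction
  # one-to-one, in basis order. For example, with token values {"a", "b", "c"}
  # in both spaces, value "a" maps to output "a", and so on, so the OV map
  # below simply relabels basis directions. (Illustrative values, added note;
  # not from the original code.)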

  if default_output is None:
    default_output = output_space.null_vector()
  assert default_output in output_space
  def qk_fun(query: bases.BasisDirection, key: bases.BasisDirection) -> float:

    # We want to enforce the following property on our attention patterns:
    # - if nothing else is attended to, attend to the BOS token.
    # - otherwise, don't attend to the BOS token.
    #
    # We assume that the BOS position always contains only the vector
    # bos + one, and that any other position has bos coefficient 0.
    #
    # We do this as follows:
    # Let Q and K be subspaces of V containing the query and key vectors,
    # both disjoint from the BOS space {bos} and the one space {one}.
    # Suppose we have an attn_fn which defines a bilinear W_QK: V x V -> ℝ,
    # s.t. W_QK(q, k) = 0 whenever either q or k is bos or one.
    #
    # Then define W_new: V x V -> ℝ s.t.:
    # W_new(one, bos) = 0.5, otherwise 0.
    #
    # Now set W_QK' = W_QK + W_new.
    #
    # To evaluate the attention to the BOS position:
    #   W_QK'(q, bos + one)
    #   = W_QK'(q, bos) + W_QK'(q, one)
    #   = W_QK(q, bos) + W_QK(q, one) + W_new(q, bos) + W_new(q, one)
    #   = 0 + 0 + W_new(q, bos) + W_new(q, one)
    #   = W_new(q, bos) + W_new(q, one)
    #   = W_new(q' + one, bos) + W_new(q' + one, one)    where q = one + q'
    #   = W_new(q', bos) + W_new(one, bos) + W_new(q', one) + W_new(one, one)
    #   = 0 + 0.5 + 0 + 0
    #   = 0.5
    #
    # To evaluate the attention to a non-BOS position:
    #   W_QK'(0 * bos + q, 0 * bos + k)    s.t. q ∈ Q+{one}, k ∈ K+{one}
    #   = 0*W_QK'(bos, 0*bos + k) + W_QK'(q, 0*bos + k)
    #   = W_QK'(q, 0*bos + k)
    #   = 0*W_QK'(q, bos) + W_QK'(q, k)
    #   = W_QK'(q, k)
    #   = W_QK(q, k)    since W_QK' = W_QK on inputs not containing bos.
    #   = W_QK(q', k')  since W_QK(x, y) = 0 whenever x or y is one.
    #
    # Since W_QK(q, k) takes values in {0, 1}, a sufficiently high softmax
    # coldness will give us the desired property. QED
    #
    # The following implements this idea.
    # By replacing 0.5 with 1, we can instead enforce a different property:
    # that the BOS token is always attended to in addition to whatever else.
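    #
    # (Added note, not in the original comments: with the default
    # softmax_coldness of 100, attended positions get logit 100 and
    # non-attended ones logit 0, so a non-attended position receives roughly
    # e^-100, about 4e-44, of the weight of an attended one, i.e. the pattern
    # is effectively hard.)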
    if key == bos_direction and query == one_direction:
      c = 1. if always_attend_to_bos else 0.5
      return c * softmax_coldness
    elif {key, query}.intersection({one_direction, bos_direction}):
      return 0
    return softmax_coldness * attn_fn(query, key)
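
  # ScalarBilinear.from_action materialises qk_fun as a matrix by evaluating
  # it on every pair of query/key basis directions. (Descriptive comment added
  # here; see tracr.craft.vectorspace_fns for the definition.)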
  w_qk = vectorspace_fns.ScalarBilinear.from_action(
      query_space,
      key_space,
      qk_fun,
  )

  def ov_fun(input_dir: bases.BasisDirection) -> bases.VectorInBasis:
    if use_bos_for_default_output and input_dir == bos_direction:
      return default_output
    return output_space.vector_from_basis_direction(value_to_output[input_dir])
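
  # Likewise, Linear.from_action below materialises ov_fun as a matrix by
  # applying it to each basis direction of the value space. (Descriptive
  # comment added here.)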
  w_ov = vectorspace_fns.Linear.from_action(
      value_space,
      output_space,
      ov_fun,
  )

  return transformers.AttentionHead(w_qk, w_ov, causal=causal)
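

if __name__ == "__main__":
  # Minimal usage sketch (illustrative, not part of the original module; the
  # space and token names are made up). Builds a head that attends from each
  # query position to the key position carrying the same token, and copies
  # that token to the output.
  _tokens = bases.VectorSpaceWithBasis.from_values("tokens", ["a", "b", "c"])
  _output = bases.VectorSpaceWithBasis.from_values("out", ["a", "b", "c"])
  _bos = bases.VectorSpaceWithBasis.from_names(["bos"])
  _one = bases.VectorSpaceWithBasis.from_names(["one"])

  _head = categorical_attn(
      query_space=_tokens,
      key_space=_tokens,
      value_space=_tokens,
      output_space=_output,
      bos_space=_bos,
      one_space=_one,
      attn_fn=lambda q, k: q.value == k.value,
  )
  print("Constructed attention head:", _head)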