# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Attention head for categorical inputs."""

from typing import Optional

from typing_extensions import Protocol

from tracr.craft import bases
from tracr.craft import transformers
from tracr.craft import vectorspace_fns


class QueryKeyToAttnLogit(Protocol):

  def __call__(self, query: bases.BasisDirection,
               key: bases.BasisDirection) -> bool:
    pass
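

# Illustration only (not part of the original module): a minimal function
# satisfying this protocol could attend wherever query and key carry equal
# values. This assumes the basis directions were constructed with a meaningful
# `value` attribute (as `bases.BasisDirection` provides); treat it as a sketch,
# not a prescribed attention function.
#
#   def match_attn_fn(query: bases.BasisDirection,
#                     key: bases.BasisDirection) -> bool:
#     return query.value == key.value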


def categorical_attn(
    query_space: bases.VectorSpaceWithBasis,
    key_space: bases.VectorSpaceWithBasis,
    value_space: bases.VectorSpaceWithBasis,
    output_space: bases.VectorSpaceWithBasis,
    bos_space: bases.VectorSpaceWithBasis,
    one_space: bases.VectorSpaceWithBasis,
    attn_fn: QueryKeyToAttnLogit,
    default_output: Optional[bases.VectorInBasis] = None,
    causal: bool = False,
    always_attend_to_bos: bool = False,
    use_bos_for_default_output: bool = True,
    softmax_coldness: float = 100.,
) -> transformers.AttentionHead:
"""Returns an attention head for categorical inputs.
Assumes the existence of a beginning of sequence token and attends to it
always with strength 0.5*softmax_coldness. This allows to implement an
arbitrary default value for rows in the attention pattern that are all-zero.
Attends to the BOS token if all other key-query pairs have zero attention.
Hence, the first value in the value sequence will be the default output for
such cases.
Args:
query_space: Vector space containing (categorical) query input.
key_space: Vector space containing (categorical) key input.
value_space: Vector space containing (numerical) value input.
output_space: Vector space which will contain (numerical) output.
bos_space: 1-d space used to identify the beginning of sequence token.
one_space: 1-d space which contains 1 at every position.
attn_fn: A selector function f(query, key) operating on the query/key basis
directions that defines the attention pattern.
default_output: Output to return if attention pattern is all zero.
causal: If True, use masked attention.
always_attend_to_bos: If True, always attend to the BOS token. If False,
only attend to BOS when attending to nothing else.
use_bos_for_default_output: If True, assume BOS is not in the value space
and output a default value when attending to BOS. If False, assume BOS is
in the value space, and map it to the output space like any other token.
softmax_coldness: The inverse temperature of the softmax. Default value is
high which makes the attention close to a hard maximum.
"""
  bases.ensure_dims(bos_space, num_dims=1, name="bos_space")
  bases.ensure_dims(one_space, num_dims=1, name="one_space")
  bos_direction = bos_space.basis[0]
  one_direction = one_space.basis[0]

  # Add bos direction to query, key, and value spaces in case it is missing.
  query_space = bases.join_vector_spaces(query_space, bos_space, one_space)
  key_space = bases.join_vector_spaces(key_space, bos_space)
  value_space = bases.join_vector_spaces(value_space, bos_space)
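
  # If BOS is always attended to, its value must also be routed to the output
  # space, so it stays in the value basis; otherwise it is excluded so that the
  # remaining value directions pair up one-to-one with the output basis.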
  if always_attend_to_bos:
    value_basis = value_space.basis
  else:
    value_basis = [v for v in value_space.basis if v != bos_direction]
  assert len(value_basis) == output_space.num_dims
  value_to_output = dict(zip(value_basis, output_space.basis))

  if default_output is None:
    default_output = output_space.null_vector()
  assert default_output in output_space

  def qk_fun(query: bases.BasisDirection, key: bases.BasisDirection) -> float:
    # We want to enforce the following property on our attention patterns:
    # - if nothing else is attended to, attend to the BOS token.
    # - otherwise, don't attend to the BOS token.
    #
    # We assume that the BOS position always only contains the vector bos + one,
    # and that any other position has bos coefficient 0.
    #
    # We do this as follows:
    # Let Q and K be subspaces of V containing the query and key vectors,
    # both disjoint from the BOS space {bos} and the one space {one}.
    # Suppose we have an attn_fn which defines a bilinear W_QK: V x V -> ℝ,
    # s.t. W_QK(q, k) = 0 whenever either q or k are bos or one.
    #
    # Then define W_new: V x V -> ℝ s.t.:
    # W_new(one, bos) = 0.5, otherwise 0.
    #
    # Now set W_QK' = W_QK + W_new.
    #
    # To evaluate the attention to the BOS position:
    # W_QK'(q, bos + one)
    # = W_QK'(q, bos) + W_QK'(q, one)
    # = W_QK(q, bos) + W_QK(q, one) + W_new(q, bos) + W_new(q, one)
    # = 0 + 0 + W_new(q, bos) + W_new(q, one)
    # = W_new(q, bos) + W_new(q, one)
    # = W_new(q' + one, bos) + W_new(q' + one, one)  where q = one + q'
    # = W_new(q', bos) + W_new(one, bos) + W_new(q', one) + W_new(one, one)
    # = 0 + 0.5 + 0 + 0
    # = 0.5
    #
    # To evaluate the attention to a non-BOS position:
    # W_QK'(0 * bos + q, 0 * bos + k)  # s.t. q ∈ Q + {one}, k ∈ K + {one}
    # = 0*W_QK'(bos, 0*bos + k) + W_QK'(q, 0*bos + k)
    # = W_QK'(q, 0*bos + k)
    # = 0*W_QK'(q, bos) + W_QK'(q, k)
    # = W_QK'(q, k)
    # = W_QK(q, k)  since W_QK' = W_QK on inputs not containing bos.
    # = W_QK(q', k')  since W_QK(x, y) = 0 whenever x or y are one.
    #
    # Since W_QK(q, k) takes values in {0, 1}, a sufficiently high softmax
    # coldness will give us the desired property. QED
    #
    # The following implements this idea.
    # By replacing 0.5 with 1, we can instead enforce a different property: that
    # the BOS token is always attended to in addition to whatever else.
    if key == bos_direction and query == one_direction:
      c = 1. if always_attend_to_bos else 0.5
      return c * softmax_coldness
    elif {key, query}.intersection({one_direction, bos_direction}):
      return 0
    return softmax_coldness * attn_fn(query, key)

  w_qk = vectorspace_fns.ScalarBilinear.from_action(
      query_space,
      key_space,
      qk_fun,
  )
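
  # OV map: route each value direction to its paired output direction; if
  # use_bos_for_default_output is set, attending to BOS instead produces the
  # default_output vector.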
  def ov_fun(input_dir: bases.BasisDirection) -> bases.VectorInBasis:
    if use_bos_for_default_output and input_dir == bos_direction:
      return default_output
    return output_space.vector_from_basis_direction(value_to_output[input_dir])

  w_ov = vectorspace_fns.Linear.from_action(
      value_space,
      output_space,
      ov_fun,
  )

  return transformers.AttentionHead(w_qk, w_ov, causal=causal)
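

# Illustrative usage sketch (not part of the original module). It assumes the
# `from_values` / `from_names` constructors on `bases.VectorSpaceWithBasis`
# and the `value` attribute on basis directions behave as named here; treat
# the exact constructors as assumptions rather than a definitive recipe.
#
#   tokens = ["a", "b", "c"]
#   query_space = bases.VectorSpaceWithBasis.from_values("query", tokens)
#   key_space = bases.VectorSpaceWithBasis.from_values("key", tokens)
#   value_space = bases.VectorSpaceWithBasis.from_values("value", tokens)
#   output_space = bases.VectorSpaceWithBasis.from_values("output", tokens)
#   bos_space = bases.VectorSpaceWithBasis.from_names(["bos"])
#   one_space = bases.VectorSpaceWithBasis.from_names(["one"])
#
#   # Attend from each query position to key positions holding the same token.
#   head = categorical_attn(
#       query_space=query_space,
#       key_space=key_space,
#       value_space=value_space,
#       output_space=output_space,
#       bos_space=bos_space,
#       one_space=one_space,
#       attn_fn=lambda q, k: q.value == k.value,
#   )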