# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
| """Attention head for categorical inputs.""" | |
| from typing import Optional | |
| from typing_extensions import Protocol | |
| from tracr.craft import bases | |
| from tracr.craft import transformers | |
| from tracr.craft import vectorspace_fns | |

class QueryKeyToAttnLogit(Protocol):

  def __call__(self, query: bases.BasisDirection,
               key: bases.BasisDirection) -> bool:
    pass
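
# An attn_fn conforming to this protocol maps a pair of query/key basis
# directions to a boolean attention decision. For illustration only (this
# example is not part of the original module), a head that attends to keys
# carrying the same value as the query could use:
#
#   attn_fn = lambda query, key: query.value == key.value
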

def categorical_attn(
    query_space: bases.VectorSpaceWithBasis,
    key_space: bases.VectorSpaceWithBasis,
    value_space: bases.VectorSpaceWithBasis,
    output_space: bases.VectorSpaceWithBasis,
    bos_space: bases.VectorSpaceWithBasis,
    one_space: bases.VectorSpaceWithBasis,
    attn_fn: QueryKeyToAttnLogit,
    default_output: Optional[bases.VectorInBasis] = None,
    causal: bool = False,
    always_attend_to_bos: bool = False,
    use_bos_for_default_output: bool = True,
    softmax_coldness: float = 100.,
) -> transformers.AttentionHead:
| """Returns an attention head for categorical inputs. | |
| Assumes the existence of a beginning of sequence token and attends to it | |
| always with strength 0.5*softmax_coldness. This allows to implement an | |
| arbitrary default value for rows in the attention pattern that are all-zero. | |
| Attends to the BOS token if all other key-query pairs have zero attention. | |
| Hence, the first value in the value sequence will be the default output for | |
| such cases. | |
| Args: | |
| query_space: Vector space containing (categorical) query input. | |
| key_space: Vector space containing (categorical) key input. | |
| value_space: Vector space containing (numerical) value input. | |
| output_space: Vector space which will contain (numerical) output. | |
| bos_space: 1-d space used to identify the beginning of sequence token. | |
| one_space: 1-d space which contains 1 at every position. | |
| attn_fn: A selector function f(query, key) operating on the query/key basis | |
| directions that defines the attention pattern. | |
| default_output: Output to return if attention pattern is all zero. | |
| causal: If True, use masked attention. | |
| always_attend_to_bos: If True, always attend to the BOS token. If False, | |
| only attend to BOS when attending to nothing else. | |
| use_bos_for_default_output: If True, assume BOS is not in the value space | |
| and output a default value when attending to BOS. If False, assume BOS is | |
| in the value space, and map it to the output space like any other token. | |
| softmax_coldness: The inverse temperature of the softmax. Default value is | |
| high which makes the attention close to a hard maximum. | |
| """ | |
  bases.ensure_dims(bos_space, num_dims=1, name="bos_space")
  bases.ensure_dims(one_space, num_dims=1, name="one_space")
  bos_direction = bos_space.basis[0]
  one_direction = one_space.basis[0]

  # Add the bos direction (and, for queries, the one direction) to the query,
  # key, and value spaces in case they are missing.
  query_space = bases.join_vector_spaces(query_space, bos_space, one_space)
  key_space = bases.join_vector_spaces(key_space, bos_space)
  value_space = bases.join_vector_spaces(value_space, bos_space)

  if always_attend_to_bos:
    value_basis = value_space.basis
  else:
    value_basis = [v for v in value_space.basis if v != bos_direction]
  assert len(value_basis) == output_space.num_dims
  value_to_output = dict(zip(value_basis, output_space.basis))

  if default_output is None:
    default_output = output_space.null_vector()
  assert default_output in output_space
  def qk_fun(query: bases.BasisDirection, key: bases.BasisDirection) -> float:

    # We want to enforce the following property on our attention patterns:
    # - if nothing else is attended to, attend to the BOS token.
    # - otherwise, don't attend to the BOS token.
    #
    # We assume that the BOS position always contains only the vector
    # bos + one, and that any other position has bos coefficient 0.
    #
    # We do this as follows:
    # Let Q and K be subspaces of V containing the query and key vectors,
    # both disjoint from the BOS space {bos} and the one space {one}.
    # Suppose we have an attn_fn which defines a bilinear W_QK: V x V -> ℝ,
    # s.t. W_QK(q, k) = 0 whenever either q or k is bos or one.
    #
    # Then define W_new: V x V -> ℝ s.t.:
    # W_new(one, bos) = 0.5, otherwise 0.
    #
    # Now set W_QK' = W_QK + W_new.
    #
    # To evaluate the attention to the BOS position:
    # W_QK'(q, bos + one)
    # = W_QK'(q, bos) + W_QK'(q, one)
    # = W_QK(q, bos) + W_QK(q, one) + W_new(q, bos) + W_new(q, one)
    # = 0 + 0 + W_new(q, bos) + W_new(q, one)
    # = W_new(q, bos) + W_new(q, one)
    # = W_new(q' + one, bos) + W_new(q' + one, one)  where q = one + q'
    # = W_new(q', bos) + W_new(one, bos) + W_new(q', one) + W_new(one, one)
    # = 0 + 0.5 + 0 + 0
    # = 0.5
    #
    # To evaluate the attention to a non-BOS position:
    # W_QK'(0 * bos + q, 0 * bos + k)  # s.t. q ∈ Q+{one}, k ∈ K+{one}
    # = 0*W_QK'(bos, 0*bos + k) + W_QK'(q, 0*bos + k)
    # = W_QK'(q, 0*bos + k)
    # = 0*W_QK'(q, bos) + W_QK'(q, k)
    # = W_QK'(q, k)
    # = W_QK(q, k)  since W_QK' = W_QK on inputs not containing bos.
    # = W_QK(q', k')  since W_QK(x, y) = 0 whenever x or y is one.
    #
    # Since W_QK(q, k) takes values in {0, 1}, a sufficiently high softmax
    # coldness will give us the desired property. QED
    #
    # The following implements this idea.
    # By replacing 0.5 with 1, we can instead enforce a different property:
    # that the BOS token is always attended to in addition to whatever else.
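    #
    # Worked example (added for illustration, using the default
    # softmax_coldness = 100): a query that matches some key under attn_fn
    # gets logit 100 at that key and logit 0.5 * 100 = 50 at BOS, so the
    # softmax puts essentially all weight on the matching key(s). If no key
    # matches, the BOS logit of 50 dominates the remaining logits of 0, so
    # the head attends to BOS and, with use_bos_for_default_output=True,
    # outputs default_output.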

    if key == bos_direction and query == one_direction:
      c = 1. if always_attend_to_bos else 0.5
      return c * softmax_coldness
    elif {key, query}.intersection({one_direction, bos_direction}):
      return 0
    return softmax_coldness * attn_fn(query, key)
  w_qk = vectorspace_fns.ScalarBilinear.from_action(
      query_space,
      key_space,
      qk_fun,
  )

  def ov_fun(input_dir: bases.BasisDirection) -> bases.VectorInBasis:
    if use_bos_for_default_output and input_dir == bos_direction:
      return default_output
    return output_space.vector_from_basis_direction(value_to_output[input_dir])

  w_ov = vectorspace_fns.Linear.from_action(
      value_space,
      output_space,
      ov_fun,
  )

  return transformers.AttentionHead(w_qk, w_ov, causal=causal)
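

# Usage sketch (illustrative only; not part of the original module). The head
# below attends from each query token to the key tokens carrying the same
# value, and falls back to default_output (here the zero vector of the output
# space) when nothing matches. The constructor names and basis values used
# here are assumptions made for illustration:
#
#   from tracr.craft import bases
#   from tracr.craft.chamber import categorical_attn
#
#   tokens = bases.VectorSpaceWithBasis.from_values("tokens", ["a", "b", "c"])
#   out = bases.VectorSpaceWithBasis.from_values("out", ["a", "b", "c"])
#   bos = bases.VectorSpaceWithBasis.from_names(["bos"])
#   one = bases.VectorSpaceWithBasis.from_names(["one"])
#
#   head = categorical_attn.categorical_attn(
#       query_space=tokens,
#       key_space=tokens,
#       value_space=tokens,
#       output_space=out,
#       bos_space=bos,
#       one_space=one,
#       attn_fn=lambda query, key: query.value == key.value,
#   )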