# Copyright 2023 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This code is heavily inspired by
# https://github.com/huggingface/transformers/blob/v4.40.1/src/transformers/models/deprecated/transfo_xl/modeling_transfo_xl.py
# and is, for the most part, a direct conversion to JAX/Flax.
"""Relative multi-head attention in Flax.

Heavily inspired by
https://github.com/huggingface/transformers/blob/v4.40.1/src/transformers/models/deprecated/transfo_xl/modeling_transfo_xl.py,
the Flax attention module, and
https://github.com/kimiyoung/transformer-xl/blob/master/pytorch/mem_transformer.py#L143;
for the most part this is a JAX/Flax conversion of that code.
"""
import functools
from typing import Any, Callable, Optional, Tuple

from flax.linen import initializers
from flax.linen.dtypes import promote_dtype
from flax.linen.linear import default_kernel_init
from flax.linen.linear import DenseGeneral
from flax.linen.linear import DotGeneralT
from flax.linen.linear import PrecisionLike
from flax.linen.module import compact
from flax.linen.module import merge_param
from flax.linen.module import Module
import jax
from jax import lax
from jax import random
import jax.numpy as jnp

PRNGKey = Any
Shape = Tuple[int, ...]
Dtype = Any
Array = Any

# Roll each slice of the second-to-last (query) axis by its own shift along a given axis;
# used below to realign the relative-position scores per query position.
roll_vmap = jax.vmap(jnp.roll, in_axes=(-2, 0, None), out_axes=-2)


def _rel_shift(x):
    """Transformer-XL relative-shift trick (zero-pad, reshape, drop the first row).

    Kept as a reference implementation; the attention functions below use `roll_vmap`
    instead.
    """
    zero_pad_shape = x.shape[:-2] + (x.shape[-2], 1)
    zero_pad = jnp.zeros(zero_pad_shape, dtype=x.dtype)
    x_padded = jnp.concatenate([zero_pad, x], axis=-1)
    x_padded_shape = x.shape[:-2] + (x.shape[-1] + 1, x.shape[-2])
    x_padded = x_padded.reshape(x_padded_shape)
    x = jnp.take(x_padded, jnp.arange(1, x_padded.shape[-2]), axis=-2).reshape(x.shape)
    return x
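

# A minimal sketch (illustration only, helper name is hypothetical) of how the roll-based
# shift used below relates to `_rel_shift`: rolling row q of a [q_length, k_length] score
# matrix by q - (q_length - 1) realigns the relative-position scores the same way, except
# that rolled entries wrap around where `_rel_shift` writes zeros; those positions are
# assumed to be removed by the attention mask.
def _example_rel_shift_vs_roll():
    q_len = 2
    scores = jnp.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
    shifts = jnp.arange(q_len) - (q_len - 1)  # [-1, 0], same shifts as in the module
    rolled = roll_vmap(scores, shifts, -1)    # [[2., 3., 1.], [4., 5., 6.]]
    shifted = _rel_shift(scores)              # [[2., 3., 0.], [4., 5., 6.]]
    return rolled, shifted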


def dot_product_attention_weights(
    query: Array,
    key: Array,
    r_pos_embed: Array,
    r_r_bias: Array,
    r_w_bias: Array,
    bias: Optional[Array] = None,
    mask: Optional[Array] = None,
    broadcast_dropout: bool = True,
    dropout_rng: Optional[PRNGKey] = None,
    dropout_rate: float = 0.0,
    deterministic: bool = False,
    dtype: Optional[Dtype] = None,
    precision: PrecisionLike = None,
):
    """Computes dot-product attention weights with relative position terms.

    Used by :func:`dot_product_attention`, which is what you'll most likely use.
    If you want access to the attention weights for introspection, you can call
    this function directly and apply the einsum over values yourself.

    Args:
      query: queries for calculating attention with shape of
        `[batch..., q_length, num_heads, qk_depth_per_head]`.
      key: keys for calculating attention with shape of
        `[batch..., kv_length, num_heads, qk_depth_per_head]`.
      r_pos_embed: projected relative position embeddings with shape of
        `[kv_length, num_heads, qk_depth_per_head]`.
      r_r_bias: relative position bias (Transformer-XL `v`) with shape of
        `[num_heads, qk_depth_per_head]`.
      r_w_bias: content bias (Transformer-XL `u`) with shape of
        `[num_heads, qk_depth_per_head]`.
      bias: bias for the attention weights. This should be broadcastable to the
        shape `[batch..., num_heads, q_length, kv_length]`. This can be used for
        incorporating causal masks, padding masks, proximity bias, etc.
      mask: mask for the attention weights. This should be broadcastable to the
        shape `[batch..., num_heads, q_length, kv_length]`. This can be used for
        incorporating causal masks. Attention weights are masked out if their
        corresponding mask value is `False`.
      broadcast_dropout: bool: use a broadcasted dropout along batch dims.
      dropout_rng: JAX PRNGKey to be used for dropout.
      dropout_rate: dropout rate.
      deterministic: bool, deterministic or not (to apply dropout).
      dtype: the dtype of the computation (default: infer from inputs and params).
      precision: numerical precision of the computation, see `jax.lax.Precision`
        for details.

    Returns:
      Output of shape `[batch..., num_heads, q_length, kv_length]`.
    """
    query, key = promote_dtype(query, key, dtype=dtype)
    dtype = query.dtype

    assert query.ndim == key.ndim, "q, k must have same rank."
    assert query.shape[:-3] == key.shape[:-3], "q, k batch dims must match."
    assert query.shape[-2] == key.shape[-2], "q, k num_heads must match."
    assert query.shape[-1] == key.shape[-1], "q, k depths must match."

    # calculate attention matrix
    depth = query.shape[-1]
    # attn weight shape is (batch..., num_heads, q_length, kv_length)
    # content term: (query + u) . key
    attn_weights = jnp.einsum("...qhd,...khd->...hqk", query + r_w_bias, key, precision=precision)
    # position term: (query + v) . relative position embeddings, realigned per query position
    attn_weights_r = jnp.einsum("...qhd,khd->...hqk", query + r_r_bias, r_pos_embed, precision=precision)
    attn_weights_r = roll_vmap(attn_weights_r, jnp.arange(0, query.shape[-3]) - (query.shape[-3] - 1), -1)
    # attn_weights_r = _rel_shift(attn_weights_r)  # zero-padding variant of the same realignment
    attn_weights = attn_weights + attn_weights_r
    attn_weights = attn_weights / jnp.sqrt(depth).astype(dtype)

    # apply attention bias: masking, dropout, proximity bias, etc.
    if bias is not None:
        attn_weights = attn_weights + bias
    # apply attention mask
    if mask is not None:
        big_neg = jnp.finfo(dtype).min
        attn_weights = jnp.where(mask, attn_weights, big_neg)

    # normalize the attention weights
    attn_weights = jax.nn.softmax(attn_weights).astype(dtype)

    # apply attention dropout
    if not deterministic and dropout_rate > 0.0:
        keep_prob = 1.0 - dropout_rate
        if broadcast_dropout:
            # dropout is broadcast across the batch + head dimensions
            dropout_shape = tuple([1] * (key.ndim - 2)) + attn_weights.shape[-2:]
            keep = random.bernoulli(dropout_rng, keep_prob, dropout_shape)  # type: ignore
        else:
            keep = random.bernoulli(dropout_rng, keep_prob, attn_weights.shape)  # type: ignore
        multiplier = keep.astype(dtype) / jnp.asarray(keep_prob, dtype=dtype)
        attn_weights = attn_weights * multiplier

    return attn_weights


def dot_product_attention(
    query: Array,
    key: Array,
    value: Array,
    r_pos_embed: Array,
    r_r_bias: Array,
    r_w_bias: Array,
    bias: Optional[Array] = None,
    mask: Optional[Array] = None,
    broadcast_dropout: bool = True,
    dropout_rng: Optional[PRNGKey] = None,
    dropout_rate: float = 0.0,
    deterministic: bool = False,
    dtype: Optional[Dtype] = None,
    precision: PrecisionLike = None,
):
    """Computes dot-product attention with relative positions given query, key, and value.

    This is the core function for applying attention based on
    https://arxiv.org/abs/1706.03762, extended with the relative position scheme of
    https://arxiv.org/abs/1901.02860 (Transformer-XL). It calculates the attention
    weights given query and key and combines the values using the attention weights.

    Note: query, key, value needn't have any batch dimensions.

    Args:
      query: queries for calculating attention with shape of
        `[batch..., q_length, num_heads, qk_depth_per_head]`.
      key: keys for calculating attention with shape of
        `[batch..., kv_length, num_heads, qk_depth_per_head]`.
      value: values to be used in attention with shape of
        `[batch..., kv_length, num_heads, v_depth_per_head]`.
      r_pos_embed: projected relative position embeddings with shape of
        `[kv_length, num_heads, qk_depth_per_head]`.
      r_r_bias: relative position bias with shape of `[num_heads, qk_depth_per_head]`.
      r_w_bias: content bias with shape of `[num_heads, qk_depth_per_head]`.
      bias: bias for the attention weights. This should be broadcastable to the
        shape `[batch..., num_heads, q_length, kv_length]`. This can be used for
        incorporating causal masks, padding masks, proximity bias, etc.
      mask: mask for the attention weights. This should be broadcastable to the
        shape `[batch..., num_heads, q_length, kv_length]`. This can be used for
        incorporating causal masks. Attention weights are masked out if their
        corresponding mask value is `False`.
      broadcast_dropout: bool: use a broadcasted dropout along batch dims.
      dropout_rng: JAX PRNGKey to be used for dropout.
      dropout_rate: dropout rate.
      deterministic: bool, deterministic or not (to apply dropout).
      dtype: the dtype of the computation (default: infer from inputs).
      precision: numerical precision of the computation, see `jax.lax.Precision`
        for details.

    Returns:
      Output of shape `[batch..., q_length, num_heads, v_depth_per_head]`.
    """
    query, key, value = promote_dtype(query, key, value, dtype=dtype)
    dtype = query.dtype
    assert key.ndim == query.ndim == value.ndim, "q, k, v must have same rank."
    assert query.shape[:-3] == key.shape[:-3] == value.shape[:-3], "q, k, v batch dims must match."
    assert query.shape[-2] == key.shape[-2] == value.shape[-2], "q, k, v num_heads must match."
    assert key.shape[-3] == value.shape[-3], "k, v lengths must match."

    # compute attention weights
    attn_weights = dot_product_attention_weights(
        query,
        key,
        r_pos_embed,
        r_r_bias,
        r_w_bias,
        bias,
        mask,
        broadcast_dropout,
        dropout_rng,
        dropout_rate,
        deterministic,
        dtype,
        precision,
    )

    # return weighted sum over values for each query position
    return jnp.einsum("...hqk,...khd->...qhd", attn_weights, value, precision=precision)
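

# A minimal usage sketch (illustration only, not used by the module): the shapes and the
# helper name `_example_dot_product_attention` are assumptions for the example. It shows
# the docstring's suggestion of calling `dot_product_attention_weights` directly and
# applying the value einsum yourself, which matches what `dot_product_attention` returns.
def _example_dot_product_attention():
    batch, q_len, kv_len, num_heads, head_dim = 2, 4, 4, 3, 8
    k_q, k_k, k_v, k_r = random.split(random.PRNGKey(0), 4)
    query = random.normal(k_q, (batch, q_len, num_heads, head_dim))
    key = random.normal(k_k, (batch, kv_len, num_heads, head_dim))
    value = random.normal(k_v, (batch, kv_len, num_heads, head_dim))
    r_pos_embed = random.normal(k_r, (kv_len, num_heads, head_dim))
    r_r_bias = jnp.zeros((num_heads, head_dim))
    r_w_bias = jnp.zeros((num_heads, head_dim))
    mask = jnp.tril(jnp.ones((1, 1, q_len, kv_len), dtype=bool))  # simple causal mask

    out = dot_product_attention(
        query, key, value, r_pos_embed, r_r_bias, r_w_bias, mask=mask, deterministic=True
    )
    weights = dot_product_attention_weights(
        query, key, r_pos_embed, r_r_bias, r_w_bias, mask=mask, deterministic=True
    )
    out_manual = jnp.einsum("...hqk,...khd->...qhd", weights, value)
    return out, out_manual  # identical up to numerical precision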


class RelMultiHeadDotProductAttention(Module):
    """Multi-head dot-product attention with Transformer-XL style relative positions.

    Attributes:
      num_heads: number of attention heads. Features (i.e. inputs_q.shape[-1])
        should be divisible by the number of heads.
      dtype: the dtype of the computation (default: infer from inputs and params).
      param_dtype: the dtype passed to parameter initializers (default: float32).
      qkv_features: dimension of the key, query, and value.
      out_features: dimension of the last projection.
      broadcast_dropout: bool: use a broadcasted dropout along batch dims.
      dropout_rate: dropout rate.
      deterministic: if false, the attention weight is masked randomly using
        dropout, whereas if true, the attention weights are deterministic.
      precision: numerical precision of the computation, see `jax.lax.Precision`
        for details.
      kernel_init: initializer for the kernel of the Dense layers.
      bias_init: initializer for the bias of the Dense layers.
      use_bias: bool: whether pointwise QKVO dense transforms use bias.
      attention_fn: dot_product_attention or compatible function. Accepts query,
        key, value, and returns output of shape
        `[bs, dim1, dim2, ..., dimN, num_heads, value_channels]`.
      decode: whether to prepare and use an autoregressive cache.
    """
    num_heads: int
    dtype: Optional[Dtype] = None
    param_dtype: Dtype = jnp.float32
    qkv_features: Optional[int] = None
    out_features: Optional[int] = None
    broadcast_dropout: bool = True
    dropout_rate: float = 0.0
    deterministic: Optional[bool] = None
    precision: PrecisionLike = None
    kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = default_kernel_init
    bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = initializers.zeros_init()
    use_bias: bool = True
    attention_fn: Callable[..., Array] = dot_product_attention
    decode: bool = False
    qkv_dot_general: DotGeneralT = lax.dot_general
    out_dot_general: DotGeneralT = lax.dot_general

    @compact
    def __call__(
        self,
        inputs_q: Array,
        inputs_kv: Array,
        pos_embed: Array,
        mask: Optional[Array] = None,
        deterministic: Optional[bool] = None,
    ):
        """Applies multi-head dot-product attention on the input data.

        Projects the inputs into multi-headed query, key, and value vectors,
        applies dot-product attention, and projects the results to an output vector.

        Args:
          inputs_q: input queries of shape `[batch_sizes..., length, features]`.
          inputs_kv: keys/values of shape `[batch_sizes..., length, features]`.
          pos_embed: relative position embeddings of shape
            `[key/value_length, features]`, projected internally to per-head embeddings.
          mask: attention mask of shape
            `[batch_sizes..., num_heads, query_length, key/value_length]`.
            Attention weights are masked out if their corresponding mask value
            is `False`.
          deterministic: if false, the attention weight is masked randomly using
            dropout, whereas if true, the attention weights are deterministic.

        Returns:
          output of shape `[batch_sizes..., length, features]`.
        """
        features = self.out_features or inputs_q.shape[-1]
        qkv_features = self.qkv_features or inputs_q.shape[-1]
        assert qkv_features % self.num_heads == 0, (
            f"Memory dimension ({qkv_features}) must be divisible by number of"
            f" heads ({self.num_heads})."
        )
        head_dim = qkv_features // self.num_heads

        dense = functools.partial(
            DenseGeneral,
            axis=-1,
            dtype=self.dtype,
            param_dtype=self.param_dtype,
            features=(self.num_heads, head_dim),
            kernel_init=self.kernel_init,
            bias_init=self.bias_init,
            use_bias=self.use_bias,
            precision=self.precision,
            dot_general=self.qkv_dot_general,
        )
        # project inputs_q to multi-headed q/k/v
        # dimensions are then [batch..., length, n_heads, n_features_per_head]
        query, key, value = (
            dense(name="query")(inputs_q),
            dense(name="key")(inputs_kv),
            dense(name="value")(inputs_kv),
        )

        # separate projection for the relative position embeddings (no bias term)
        dense_relpos = functools.partial(
            DenseGeneral,
            axis=-1,
            dtype=self.dtype,
            param_dtype=self.param_dtype,
            features=(self.num_heads, head_dim),
            kernel_init=self.kernel_init,
            use_bias=False,
            precision=self.precision,
            dot_general=self.qkv_dot_general,
        )
        r_pos_embed = dense_relpos(name="pos_embed_mat")(pos_embed)
        # learned global biases shared across positions (Transformer-XL u and v)
        r_r_bias = self.param("r_r_bias", self.bias_init, (self.num_heads, head_dim))
        r_w_bias = self.param("r_w_bias", self.bias_init, (self.num_heads, head_dim))
        # During fast autoregressive decoding, we feed one position at a time,
        # and cache the keys and values step by step.
        if self.decode:
            # detect if we're initializing by absence of existing cache data.
            is_initialized = self.has_variable("cache", "cached_key")
            cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype)
            cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype)
            cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32))
            if is_initialized:
                (
                    *batch_dims,
                    max_length,
                    num_heads,
                    depth_per_head,
                ) = cached_key.value.shape
                # shape check of cached keys against query input
                expected_shape = tuple(batch_dims) + (1, num_heads, depth_per_head)
                if expected_shape != query.shape:
                    raise ValueError(
                        "Autoregressive cache shape error, "
                        "expected query shape %s instead got %s." % (expected_shape, query.shape)
                    )
                # update key, value caches with our new 1d spatial slices
                cur_index = cache_index.value
                indices = (0,) * len(batch_dims) + (cur_index, 0, 0)
                key = lax.dynamic_update_slice(cached_key.value, key, indices)
                value = lax.dynamic_update_slice(cached_value.value, value, indices)
                cached_key.value = key
                cached_value.value = value
                cache_index.value = cache_index.value + 1
                # causal mask for cached decoder self-attention:
                # our single query position should only attend to those key
                # positions that have already been generated and cached,
                # not the remaining zero elements.
                mask = combine_masks(
                    mask,
                    jnp.broadcast_to(
                        jnp.arange(max_length) <= cur_index,
                        tuple(batch_dims) + (1, 1, max_length),
                    ),
                )
        dropout_rng = None
        if self.dropout_rate > 0.0:  # Require `deterministic` only if using dropout.
            m_deterministic = merge_param("deterministic", self.deterministic, deterministic)
            if not m_deterministic:
                dropout_rng = self.make_rng("dropout")
        else:
            m_deterministic = True

        # apply attention
        x = self.attention_fn(
            query,
            key,
            value,
            r_pos_embed,
            r_r_bias,
            r_w_bias,
            mask=mask,
            dropout_rng=dropout_rng,
            dropout_rate=self.dropout_rate,
            broadcast_dropout=self.broadcast_dropout,
            deterministic=m_deterministic,
            dtype=self.dtype,
            precision=self.precision,
        )  # pytype: disable=wrong-keyword-args

        # back to the original inputs dimensions
        out = DenseGeneral(
            features=features,
            axis=(-2, -1),
            kernel_init=self.kernel_init,
            bias_init=self.bias_init,
            use_bias=self.use_bias,
            dtype=self.dtype,
            param_dtype=self.param_dtype,
            precision=self.precision,
            dot_general=self.out_dot_general,
            name="out",  # type: ignore[call-arg]
        )(x)
        return out
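

# A minimal usage sketch (illustration only): initializing and applying the module on
# dummy data. The shapes, the simple sinusoidal embedding table, and the helper name are
# assumptions for the example, not part of the module's API.
def _example_rel_attention_module():
    batch, length, features, num_heads = 2, 8, 16, 4
    # one relative position embedding per key position, with the same feature size as x
    pos = jnp.arange(length)[:, None]
    inv_freq = 1.0 / (10000 ** (jnp.arange(0, features, 2) / features))
    pos_embed = jnp.concatenate([jnp.sin(pos * inv_freq), jnp.cos(pos * inv_freq)], axis=-1)

    attn = RelMultiHeadDotProductAttention(num_heads=num_heads, deterministic=True)
    x = jnp.ones((batch, length, features))
    mask = make_causal_mask(x[..., 0])  # [batch, 1, length, length]
    variables = attn.init(random.PRNGKey(0), x, x, pos_embed, mask)
    y = attn.apply(variables, x, x, pos_embed, mask)  # [batch, length, features]
    return y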


class SelfAttention(RelMultiHeadDotProductAttention):
    """Self-attention special case of multi-head dot-product attention."""

    @compact
    def __call__(  # type: ignore
        self,
        inputs_q: Array,
        pos_embed: Array,
        mask: Optional[Array] = None,
        deterministic: Optional[bool] = None,
    ):
        """Applies multi-head dot-product self-attention on the input data.

        Projects the inputs into multi-headed query, key, and value vectors,
        applies dot-product attention, and projects the results to an output vector.

        Args:
          inputs_q: input queries of shape `[batch_sizes..., length, features]`.
          pos_embed: relative position embeddings of shape `[length, features]`.
          mask: attention mask of shape
            `[batch_sizes..., num_heads, query_length, key/value_length]`.
            Attention weights are masked out if their corresponding mask value
            is `False`.
          deterministic: if false, the attention weight is masked randomly using
            dropout, whereas if true, the attention weights are deterministic.

        Returns:
          output of shape `[batch_sizes..., length, features]`.
        """
        return super().__call__(inputs_q, inputs_q, pos_embed, mask, deterministic=deterministic)


# mask-making utility functions
def make_attention_mask(
    query_input: Array,
    key_input: Array,
    pairwise_fn: Callable[..., Any] = jnp.multiply,
    extra_batch_dims: int = 0,
    dtype: Dtype = jnp.float32,
):
    """Mask-making helper for attention weights.

    In case of 1d inputs (i.e., `[batch..., len_q]`, `[batch..., len_kv]`), the
    attention weights will be `[batch..., heads, len_q, len_kv]` and this
    function will produce `[batch..., 1, len_q, len_kv]`.

    Args:
      query_input: a batched, flat input of query_length size.
      key_input: a batched, flat input of key_length size.
      pairwise_fn: broadcasting elementwise comparison function.
      extra_batch_dims: number of extra batch dims to add singleton axes for,
        none by default.
      dtype: mask return dtype.

    Returns:
      A `[batch..., 1, len_q, len_kv]` shaped mask for 1d attention.
    """
    mask = pairwise_fn(jnp.expand_dims(query_input, axis=-1), jnp.expand_dims(key_input, axis=-2))
    mask = jnp.expand_dims(mask, axis=-3)
    mask = jnp.expand_dims(mask, axis=tuple(range(extra_batch_dims)))
    return mask.astype(dtype)


def make_causal_mask(x: Array, extra_batch_dims: int = 0, dtype: Dtype = jnp.float32) -> Array:
    """Make a causal mask for self-attention.

    In case of 1d inputs (i.e., `[batch..., len]`), the self-attention weights
    will be `[batch..., heads, len, len]` and this function will produce a
    causal mask of shape `[batch..., 1, len, len]`.

    Args:
      x: input array of shape `[batch..., len]`.
      extra_batch_dims: number of batch dims to add singleton axes for,
        none by default.
      dtype: mask return dtype.

    Returns:
      A `[batch..., 1, len, len]` shaped causal mask for 1d attention.
    """
    idxs = jnp.broadcast_to(jnp.arange(x.shape[-1], dtype=jnp.int32), x.shape)
    return make_attention_mask(
        idxs,
        idxs,
        jnp.greater_equal,
        extra_batch_dims=extra_batch_dims,
        dtype=dtype,
    )
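

# A minimal sketch (illustration only, shapes and names are assumptions) of the two mask
# helpers: a causal mask lets query position q attend only to key positions <= q, while
# `make_attention_mask` with the default `jnp.multiply` builds a padding mask from 0/1
# validity flags.
def _example_masks():
    tokens = jnp.array([[5, 7, 9, 0]])           # one sequence, last position is padding
    valid = (tokens > 0).astype(jnp.float32)     # [1, 4] validity flags
    causal = make_causal_mask(tokens)            # [1, 1, 4, 4] lower-triangular mask
    padding = make_attention_mask(valid, valid)  # [1, 1, 4, 4] padding mask
    return causal, padding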


def combine_masks(*masks: Optional[Array], dtype: Dtype = jnp.float32) -> Optional[Array]:
    """Combine attention masks.

    Args:
      *masks: set of attention mask arguments to combine, some can be None.
      dtype: dtype for the returned mask.

    Returns:
      Combined mask, reduced by logical and, or None if no masks are given.
    """
    masks_list = [m for m in masks if m is not None]
    if not masks_list:
        return None
    assert all(
        map(lambda x: x.ndim == masks_list[0].ndim, masks_list)
    ), f"masks must have same rank: {tuple(map(lambda x: x.ndim, masks_list))}"
    mask, *other_masks = masks_list
    for other_mask in other_masks:
        mask = jnp.logical_and(mask, other_mask)
    return mask.astype(dtype)
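

# A minimal sketch (illustration only, names are assumptions) combining a causal mask
# with a padding mask, which is the typical way to build the `mask` argument for the
# attention module above: the result is the elementwise logical AND of the given masks.
def _example_combine_masks():
    tokens = jnp.array([[5, 7, 9, 0]])
    causal = make_causal_mask(tokens)
    padding = make_attention_mask(tokens > 0, tokens > 0)
    mask = combine_masks(causal, padding)  # [1, 1, 4, 4], nonzero only where both allow
    return mask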