File size: 10,790 Bytes
6b8a59c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 |
# Copyright 2023 The HuggingFace Team. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
from utils import FloatTensor
import mlx.core as mx
# Lowest finite value of each supported floating-point dtype, keyed by dtype name.
# Stand-in for `torch.finfo(dtype).min`.
_FINFO_MIN = {
    "float16": -65504.0,
    "bfloat16": -3.3895313892515355e38,
    "float32": -3.4028234663852886e38,
    "float64": -1.7976931348623157e308,
}


def get_finfo_min(dtype: mx.Dtype):
    """
    Return the lowest finite value representable by a floating-point `dtype`.

    Args:
        dtype: an MLX floating-point dtype (e.g. `mx.float16`, `mx.bfloat16`, `mx.float32`), or its plain name
            such as `"float32"`.

    Returns:
        The minimum finite value as a Python float.

    Raises:
        ValueError: if `dtype` is not one of float16, bfloat16, float32 or float64.
    """
    dtype_str = str(dtype)
    # MLX dtypes stringify with a module prefix (e.g. "mlx.core.float32"), so comparing the full
    # string against "float32" never matches; keep only the last dotted component.
    name = dtype_str.rsplit(".", 1)[-1]
    try:
        return _FINFO_MIN[name]
    except KeyError:
        raise ValueError(f"Unsupported data type: {dtype_str}") from None
class AttentionMaskConverter:
    """
    Utility for building additive attention masks as MLX arrays.

    It can:
        - create a causal 4D mask, optionally restricted to a sliding window (`to_causal_4d`);
        - convert a 2D padding mask of shape `(batch_size, key_value_length)` into a 4D mask of shape
          `(batch_size, 1, query_length, key_value_length)` (`to_4d`), merging in the causal mask when
          `is_causal` is set.

    In every returned mask, `0` marks a position that may be attended to and the lowest finite value of the
    requested dtype marks a position that must be ignored, so the mask can be added to attention scores.

    Parameters:
        is_causal (`bool`):
            Whether each query may only attend to keys at or before its own position.
        sliding_window (`int`, *optional*):
            If set, keys more than `sliding_window` positions behind a query are masked as well.
            Must be strictly positive.
    """

    is_causal: bool
    sliding_window: Optional[int]

    def __init__(self, is_causal: bool, sliding_window: Optional[int] = None):
        self.is_causal = is_causal
        self.sliding_window = sliding_window

        if self.sliding_window is not None and self.sliding_window <= 0:
            raise ValueError(
                f"Make sure that when passing `sliding_window` that its value is a strictly positive integer, not `{self.sliding_window}`"
            )

    def to_causal_4d(
        self,
        batch_size: int,
        query_length: int,
        key_value_length: int,
        dtype: mx.Dtype,
        device: Union[mx.Device, "str"] = "cpu",
    ) -> Optional[mx.array]:
        """
        Creates a causal 4D mask of (bsz, head_dim=1, query_length, key_value_length) shape and adds a large
        negative bias to the upper right hand triangular matrix (causal mask).

        Returns `None` when `query_length == 1` and no sliding window is configured: a single query placed after
        all cached keys may attend to every key, so no mask is needed.

        `device` is kept for API compatibility; the MLX mask construction ignores it.

        Raises:
            ValueError: if this converter is not causal.
        """
        if not self.is_causal:
            raise ValueError(f"Please use `to_causal_4d` only if {self.__class__} has `is_causal` set to True.")

        input_shape = (batch_size, query_length)
        past_key_values_length = key_value_length - query_length

        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
        causal_4d_mask = None
        if input_shape[-1] > 1 or self.sliding_window is not None:
            causal_4d_mask = self._make_causal_mask(
                input_shape,
                dtype,
                device=device,
                past_key_values_length=past_key_values_length,
                sliding_window=self.sliding_window,
            )

        return causal_4d_mask

    def to_4d(
        self,
        attention_mask_2d: mx.array,
        query_length: int,
        dtype: mx.Dtype,
        key_value_length: Optional[int] = None,
    ) -> mx.array:
        """
        Converts a 2D attention mask to a 4D attention mask by expanding it to (bsz, head_dim=1, query_length,
        key_value_length) shape and adding a large negative bias to not-attended positions. If the converter is
        causal, the causal mask is merged in as well.

        `attention_mask_2d` uses nonzero = attend, 0 = ignore.

        Raises:
            ValueError: if the converter is causal, a causal mask is needed, and `key_value_length` is missing.
            NotImplementedError: if a sliding window is configured on a non-causal converter.
        """
        input_shape = (attention_mask_2d.shape[0], query_length)

        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
        causal_4d_mask = None
        if (input_shape[-1] > 1 or self.sliding_window is not None) and self.is_causal:
            if key_value_length is None:
                raise ValueError(
                    "This attention mask converter is causal. Make sure to pass `key_value_length` to correctly create a causal mask."
                )

            past_key_values_length = key_value_length - query_length
            causal_4d_mask = self._make_causal_mask(
                input_shape,
                dtype,
                past_key_values_length=past_key_values_length,
                sliding_window=self.sliding_window,
            )
        elif self.sliding_window is not None:
            raise NotImplementedError("Sliding window is currently only implemented for causal masking")

        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
        expanded_attn_mask = self._expand_mask(attention_mask_2d, dtype, tgt_len=input_shape[-1])
        if causal_4d_mask is not None:
            # Write the minimum into the causal mask wherever the padding mask is nonzero (i.e. masked),
            # instead of adding the two masks: min + min would overflow to -inf.
            expanded_attn_mask = mx.where(
                expanded_attn_mask.astype(mx.bool_), get_finfo_min(dtype), causal_4d_mask
            )

        return expanded_attn_mask

    @staticmethod
    def _make_causal_mask(
        input_ids_shape: Tuple[int, int],
        dtype: mx.Dtype,
        device: Optional[Union[mx.Device, str]] = None,
        past_key_values_length: int = 0,
        sliding_window: Optional[int] = None,
    ):
        """
        Build an additive causal mask of shape `(bsz, 1, tgt_len, tgt_len + past_key_values_length)`.

        Query row `i` is treated as absolute position `i + past_key_values_length`. Key column `j` is set to the
        lowest finite value of `dtype` when `j` lies after that position. With `sliding_window`, it is also set
        when `j` lies more than `sliding_window` positions before it. All other entries are `0`.

        Args:
            input_ids_shape: `(bsz, tgt_len)`.
            dtype: dtype of the returned mask.
            device: accepted for backward compatibility only; ignored.
            past_key_values_length: number of key positions preceding the first query.
            sliding_window: optional attention window size.
        """
        bsz, tgt_len = input_ids_shape
        src_len = tgt_len + past_key_values_length

        rows = mx.arange(tgt_len)[:, None]
        cols = mx.arange(src_len)[None, :]
        # offset[i, j] = j - i; the key is in the future of query i iff j > i + past_key_values_length.
        offset = cols - rows
        masked = offset > past_key_values_length

        if sliding_window is not None:
            # Equivalent to tril(ones, k=diagonal): drop keys more than `sliding_window` steps behind the query.
            diagonal = past_key_values_length - sliding_window - 1
            masked = mx.logical_or(masked, offset <= diagonal)

        # Build both fill values in `dtype` so large minima (e.g. float64) are not rounded through float32.
        min_value = mx.array(get_finfo_min(dtype), dtype=dtype)
        zero = mx.array(0, dtype=dtype)
        mask = mx.where(masked, min_value, zero)

        return mx.broadcast_to(mask[None, None, :, :], (bsz, 1, tgt_len, src_len))

    @staticmethod
    def _expand_mask(mask: mx.array, dtype: mx.Dtype, tgt_len: Optional[int] = None):
        """
        Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.

        Positions where `mask == 1` become `0`. Every other position is set to the lowest finite value of `dtype`.
        """
        bsz, src_len = mask.shape
        tgt_len = tgt_len if tgt_len is not None else src_len

        expanded_mask = mx.broadcast_to(mask[:, None, None, :], (bsz, 1, tgt_len, src_len)).astype(dtype)

        inverted_mask = 1.0 - expanded_mask

        return mx.where(inverted_mask.astype(mx.bool_), get_finfo_min(dtype), inverted_mask)

    @staticmethod
    def _unmask_unattended(
        expanded_mask: FloatTensor,
        min_dtype: float,
    ):
        # fmt: off
        """
        Attend to all tokens in fully masked rows of the expanded attention mask, for example the first rows when
        using left padding. In PyTorch this is required by the memory-efficient path of
        F.scaled_dot_product_attention.
        `expanded_mask` is [bsz, num_masks, tgt_seq_len, src_seq_len] or [bsz, tgt_seq_len, src_seq_len], holding
        additive values where `min_dtype` marks a masked position.
        The dimension num_masks of `expanded_mask` is most often 1, but it can also be the number of heads in the case of alibi attention bias.

        Illustrated with 1 = attend and 0 = masked, if `expanded_mask` is (e.g. here left-padding case)
        ```
        [[[[0, 0, 0],
           [0, 0, 0],
           [0, 0, 1]]],
         [[[1, 0, 0],
           [1, 1, 0],
           [1, 1, 1]]],
         [[[0, 0, 0],
           [0, 1, 0],
           [0, 1, 1]]]]
        ```
        then the modified `expanded_mask` will be
        ```
        [[[[1, 1, 1],   <-- modified
           [1, 1, 1],   <-- modified
           [0, 0, 1]]],
         [[[1, 0, 0],
           [1, 1, 0],
           [1, 1, 1]]],
         [[[1, 1, 1],   <-- modified
           [0, 1, 0],
           [0, 1, 1]]]]
        ```
        In the additive float representation actually used here, a row that is entirely `min_dtype` is
        replaced by zeros.

        Raises:
            ValueError: if `expanded_mask` is boolean.
        """
        # fmt: on
        if expanded_mask.dtype == mx.bool_:
            raise ValueError(
                "AttentionMaskConverter._unmask_unattended expects a float `expanded_mask`, got a BoolTensor."
            )

        fully_masked_rows = mx.all(expanded_mask == min_dtype, axis=-1, keepdims=True)
        return expanded_mask * mx.logical_not(fully_masked_rows)
def _prepare_4d_causal_attention_mask(
    attention_mask: Optional[mx.array],
    input_shape: Union[mx.array, Tuple, List],
    inputs_embeds: mx.array,
    past_key_values_length: int,
    sliding_window: Optional[int] = None,
):
    """
    Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
    `(batch_size, key_value_length)`.

    Args:
        attention_mask (`mx.array` or `None`):
            A 2D attention mask of shape `(batch_size, key_value_length)`, or an already-4D mask of shape
            `(batch_size, 1, query_length, key_value_length)`. In both cases nonzero = attend, 0 = ignore.
        input_shape (`tuple(int)` or `list(int)`):
            The input shape should be a tuple that defines `(batch_size, query_length)`.
        inputs_embeds (`mx.array`):
            The embedded inputs; only its dtype is used, as the dtype of the returned mask.
        past_key_values_length (`int`):
            The length of the key value cache.
        sliding_window (`int`, *optional*):
            If the model uses windowed attention, a sliding window should be passed.

    Returns:
        An additive 4D mask (0 = attend, dtype minimum = ignore). Returns `None` when no mask was given, the query
        length is 1, and there is no sliding window.

    Raises:
        ValueError: if a 4D `attention_mask` does not have the expected shape.
    """
    attn_mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window)
    key_value_length = input_shape[-1] + past_key_values_length
    dtype = inputs_embeds.dtype

    if attention_mask is not None and len(attention_mask.shape) == 2:
        # 2D padding mask: expand it and merge in the causal (and sliding-window) mask.
        return attn_mask_converter.to_4d(
            attention_mask, input_shape[-1], key_value_length=key_value_length, dtype=dtype
        )

    if attention_mask is not None and len(attention_mask.shape) == 4:
        expected_shape = (input_shape[0], 1, input_shape[1], key_value_length)
        if tuple(attention_mask.shape) != expected_shape:
            raise ValueError(
                f"Incorrect 4D attention_mask shape: {tuple(attention_mask.shape)}; expected: {expected_shape}."
            )
        # The 4D mask has the correct shape: invert it and fill the ignored positions with the dtype minimum.
        inverted_mask = 1.0 - attention_mask
        return mx.where(inverted_mask.astype(mx.bool_), get_finfo_min(dtype), inverted_mask)

    # No usable mask: build a purely causal one. The device argument is omitted because MLX arrays carry no
    # `.device` attribute; `to_causal_4d` ignores it anyway.
    return attn_mask_converter.to_causal_4d(input_shape[0], input_shape[-1], key_value_length, dtype=dtype)