# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import torch.nn.functional as F  # noqa: N812
from packaging.version import Version

if Version(torch.__version__) >= Version("2.5.0"):
    # Flex attention is only available from torch 2.5 onwards,
    # hence `>=` so that 2.5.0 itself is included.
    from torch.nn.attention.flex_attention import (
        _mask_mod_signature,
        _round_up_to_multiple,
        create_block_mask,
        create_mask,
        flex_attention,
    )


# @torch.compile(dynamic=False)
def flex_attention_forward(
    attention_mask: torch.Tensor,
    batch_size: int,
    head_dim: int,
    query_states: torch.Tensor,
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    scaling=None,
):
    """
    This is defined outside of any class to keep torch.compile happy.
    """
    original_dtype = query_states.dtype
    num_att_heads = 8
    num_key_value_heads = 1
    num_key_value_groups = num_att_heads // num_key_value_heads

    # Repeat the single key/value head so that each of the `num_att_heads` query
    # heads has a matching key/value head (grouped-query attention, GQA).
    key_states = key_states[:, :, :, None, :]
    key_states = key_states.expand(
        batch_size, key_states.shape[1], num_key_value_heads, num_key_value_groups, head_dim
    )
    key_states = key_states.reshape(
        batch_size, key_states.shape[1], num_key_value_heads * num_key_value_groups, head_dim
    )

    value_states = value_states[:, :, :, None, :]
    value_states = value_states.expand(
        batch_size, value_states.shape[1], num_key_value_heads, num_key_value_groups, head_dim
    )
    value_states = value_states.reshape(
        batch_size, value_states.shape[1], num_key_value_heads * num_key_value_groups, head_dim
    )

    # [B, SEQ, H, HEAD_DIM] -> [B, H, SEQ, HEAD_DIM]
    query_states = query_states.transpose(1, 2)
    key_states = key_states.transpose(1, 2)
    value_states = value_states.transpose(1, 2)

    query_states = query_states.to(torch.float32)
    key_states = key_states.to(torch.float32)
    value_states = value_states.to(torch.float32)

    causal_mask = attention_mask
    if causal_mask is not None:
        causal_mask = causal_mask[:, None, :, : key_states.shape[2]]

        if causal_mask.shape[1] == 1 and query_states.shape[1] > 1:
            causal_mask = causal_mask.expand(-1, query_states.shape[1], -1, -1)

    def precomputed_mask_factory(precomputed_mask: torch.Tensor) -> _mask_mod_signature:
        def mask_mod(b, h, q_idx, kv_idx):
            # Danger zone: if b, h, q_idx or kv_idx exceed the mask's shape,
            # a device-side assert occurs.
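            # Note: `create_mask`/`create_block_mask` vmap this function over index
            # tensors, so each (logical) call selects one element of the mask.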
            return precomputed_mask[b][h][q_idx][kv_idx]

        return mask_mod

    b_mask, h_mask, q_len, kv_len = causal_mask.shape  # the shape of the mask

    block_size = 128
    q_len_rounded = _round_up_to_multiple(q_len, block_size)
    kv_len_rounded = _round_up_to_multiple(kv_len, block_size)

    # *CRITICAL*: pad the mask up to the rounded lengths, else we get a CUDA index error.
    pad_q = q_len_rounded - q_len
    pad_k = kv_len_rounded - kv_len

    padded_causal_mask = F.pad(causal_mask, (0, pad_k, 0, pad_q), value=0.0)
    mask_mod_fn_orig = precomputed_mask_factory(padded_causal_mask)

    mask_4d = create_mask(
        mod_fn=mask_mod_fn_orig,
        B=b_mask,
        H=h_mask,
        Q_LEN=q_len_rounded,
        KV_LEN=kv_len_rounded,
        device=causal_mask.device,
        _compile=False,
    )

    mask_mod_fn_padded = precomputed_mask_factory(mask_4d)
    block_mask = create_block_mask(
        mask_mod=mask_mod_fn_padded,
        B=b_mask,
        H=h_mask,
        Q_LEN=q_len_rounded,
        KV_LEN=kv_len_rounded,
        BLOCK_SIZE=block_size,
        device=causal_mask.device,
        _compile=False,
    )

    # The mask is applied inside the kernel, ideally more efficiently than score_mod.
    attn_output, attention_weights = flex_attention(
        query_states,
        key_states,
        value_states,
        block_mask=block_mask,
        enable_gqa=True,  # because we shaped query/key states for GQA
        scale=head_dim**-0.5 if scaling is None else scaling,
        return_lse=True,  # the second output is the per-row log-sum-exp, unused here
    )

    attn_output = attn_output.to(dtype=original_dtype)
    attn_output = attn_output.transpose(1, 2).contiguous()  # [B, Q_LEN, H, HEAD_DIM]
    attn_output = attn_output.reshape(
        batch_size,
        -1,
        attn_output.shape[2] * attn_output.shape[3],  # merges [H, HEAD_DIM]
    )

    return attn_output
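

if __name__ == "__main__":
    # Minimal smoke test: an illustrative sketch, not part of the original module.
    # It assumes a boolean attention mask of shape [batch, q_len, kv_len] (to match
    # the slicing logic above) and the hard-coded 8 query heads / 1 key-value head.
    # Requires torch >= 2.5 for flex attention.
    batch_size, seq_len, head_dim = 1, 4, 16

    query = torch.randn(batch_size, seq_len, 8, head_dim)  # [B, SEQ, NUM_ATT_HEADS, HEAD_DIM]
    key = torch.randn(batch_size, seq_len, 1, head_dim)  # [B, SEQ, NUM_KV_HEADS, HEAD_DIM]
    value = torch.randn(batch_size, seq_len, 1, head_dim)

    # Lower-triangular (causal) boolean mask, broadcast over the batch.
    mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool)).expand(batch_size, -1, -1)

    out = flex_attention_forward(mask, batch_size, head_dim, query, key, value)
    print(out.shape)  # expected: torch.Size([1, 4, 128]), i.e. [B, SEQ, NUM_ATT_HEADS * HEAD_DIM]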