""" Wrap torch's flex attention and handle mess info or potentially refactor """ from functools import partial import torch import numpy as np import torch.nn as nn import torch.nn.functional as F try: from torch.nn.attention.flex_attention import flex_attention, create_block_mask flex_attention_available = True except ImportError: print(f"[Warning] flex attention need pytorch 2.5.0+ but your version is {torch.__version__}") flex_attention_available = False def _causal_mask(b, h, q_idx, kv_idx): return q_idx >= kv_idx def _length_to_offsets(lengths, device): """Converts a list of lengths to a list of offsets. Args: lengths: A list of lengths. """ offsets = [0] offsets.extend(lengths) offsets = torch.tensor(offsets, device=device, dtype=torch.int32) offsets = torch.cumsum(offsets, dim=-1) return offsets def _generate_var_mask_mod(offsets): """Generates mask mods that apply to inputs to flex attention in the sequence stacked format. Args: offsets: This tensor should be of shape(num_documents + 1) this should contain the cumulative counts of document tokens. e.g. if you have 3 documents of length 2, 4, 3 then offsets = [0, 2, 6, 9] Note: What is the sequence stacked format? When assembling batches of inputs, we take multiple sequences and stack them together to form 1 large sequence. We then use masking to ensure that the attention scores are only applied to tokens within the same document. """ def _offsets_to_doc_ids_tensor(offsets): device = offsets.device counts = offsets[1:] - offsets[:-1] return torch.repeat_interleave( torch.arange(len(counts), device=device, dtype=torch.int32), counts ) document_id = _offsets_to_doc_ids_tensor(offsets) def var_mask_mod(b, h, q_idx, kv_idx): same_doc = document_id[q_idx] == document_id[kv_idx] causal_mask = _causal_mask(b, h, q_idx, kv_idx) return same_doc | causal_mask return var_mask_mod def _generate_var_infer_mask_with_kv_cache(lengths): kv_len = sum(lengths) def var_mask_mod(b, h, q_idx, kv_idx): return kv_idx < kv_len return var_mask_mod class FlexAttn(nn.Module): def __init__( self, block_scales:list, mask_type:str, B, H, L:int, auto_padding=False ): """ :param block_scales: accept VAR's block sizes like [(1,1), (2,2), (3,3)] :param mask_type: var/causal :param B: batch size :param H: heads num :param L: sequence length """ super().__init__() if not flex_attention_available: raise NotImplementedError((f"[Error] flex attention need pytorch 2.5.0+ but your version is {torch.__version__}")) self.support_mask_type = ["var", "causal", "var_infer_mask_with_kv_cache"] self.auto_padding = auto_padding self.flex_attention = torch.compile(flex_attention) self.block_scales = block_scales self.lengths = [ x * y * z for x,y,z in block_scales] self.offsets = _length_to_offsets(self.lengths, device='cuda') # if L paded to align 128, block need to cover padding area if self.offsets[-1] < L: self.offsets = torch.cat((self.offsets, torch.tensor([L], device='cuda')), dim=0) if mask_type == "var": self.mask_mod = _generate_var_mask_mod(self.offsets) self.block_mask = create_block_mask(self.mask_mod, B = B, H = H, Q_LEN = L, KV_LEN = L, device = 'cuda', _compile = True) elif mask_type == "causal": self.mask_mod = _causal_mask self.block_mask = create_block_mask(self.mask_mod, B = B, H = H, Q_LEN = L, KV_LEN = L, device = 'cuda', _compile = True) elif mask_type == 'var_infer_mask_with_kv_cache': self.mask_mod = _generate_var_infer_mask_with_kv_cache(self.lengths) self.block_mask = create_block_mask(self.mask_mod, B = B, H = H, Q_LEN = L, KV_LEN = L, device = 'cuda', 
class FlexAttn(nn.Module):
    def __init__(self, block_scales: list, mask_type: str, B, H, L: int, auto_padding=False):
        """
        :param block_scales: VAR's block sizes, e.g. [(1, 1, 1), (2, 2, 1), (3, 3, 1)]
        :param mask_type: one of "var", "causal", "var_infer_mask_with_kv_cache"
        :param B: batch size
        :param H: number of heads
        :param L: sequence length
        :param auto_padding: if True, pad q/k/v to a multiple of 128 in forward()
        """
        super().__init__()
        if not flex_attention_available:
            raise NotImplementedError(f"[Error] flex attention needs pytorch 2.5.0+ but your version is {torch.__version__}")

        self.support_mask_type = ["var", "causal", "var_infer_mask_with_kv_cache"]
        self.auto_padding = auto_padding

        self.flex_attention = torch.compile(flex_attention)

        self.block_scales = block_scales
        self.lengths = [x * y * z for x, y, z in block_scales]

        self.offsets = _length_to_offsets(self.lengths, device='cuda')
        # If L is padded to align to 128, the final block must also cover the padding area.
        if self.offsets[-1] < L:
            self.offsets = torch.cat((self.offsets, torch.tensor([L], device='cuda')), dim=0)

        if mask_type == "var":
            self.mask_mod = _generate_var_mask_mod(self.offsets)
            self.block_mask = create_block_mask(self.mask_mod, B=B, H=H, Q_LEN=L, KV_LEN=L,
                                                device='cuda', _compile=True)
        elif mask_type == "causal":
            self.mask_mod = _causal_mask
            self.block_mask = create_block_mask(self.mask_mod, B=B, H=H, Q_LEN=L, KV_LEN=L,
                                                device='cuda', _compile=True)
        elif mask_type == "var_infer_mask_with_kv_cache":
            self.mask_mod = _generate_var_infer_mask_with_kv_cache(self.lengths)
            self.block_mask = create_block_mask(self.mask_mod, B=B, H=H, Q_LEN=L, KV_LEN=L,
                                                device='cuda', _compile=True)
        else:
            raise NotImplementedError(f"{mask_type} not supported in FlexAttn, supported types: {self.support_mask_type}")

    def forward(self, q, k, v, scale=None):
        if self.auto_padding:
            # Flex attention's default mask block size is 128, so pad the
            # sequence dimension of q/k/v up to the next multiple of 128.
            q_pad_len = (128 - q.shape[-2] % 128) % 128
            kv_pad_len = (128 - k.shape[-2] % 128) % 128
            q_pad = F.pad(q, (0, 0, 0, q_pad_len))
            k_pad = F.pad(k, (0, 0, 0, kv_pad_len))
            v_pad = F.pad(v, (0, 0, 0, kv_pad_len))
            oup = self.flex_attention(q_pad.to(v_pad.dtype), k_pad.to(v_pad.dtype), v_pad,
                                      block_mask=self.block_mask, scale=scale)
            if q_pad_len > 0:
                # Drop the padded query rows from the output.
                oup = oup[:, :, :-q_pad_len]
        else:
            oup = self.flex_attention(q.to(v.dtype), k.to(v.dtype), v,
                                      block_mask=self.block_mask, scale=scale)
        return oup

    def extra_repr(self) -> str:
        return f'block_scales: {self.block_scales}'
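
# Minimal usage sketch (not part of the original module). It assumes a CUDA
# device and PyTorch 2.5+; the batch size, head count, head_dim, and
# block_scales below are made-up illustration values.
if __name__ == "__main__":
    B, H, head_dim = 2, 8, 64
    block_scales = [(1, 1, 1), (2, 2, 1), (3, 3, 1)]  # hypothetical VAR scales
    seq_len = sum(x * y * z for x, y, z in block_scales)  # 1 + 4 + 9 = 14 tokens
    L = (seq_len + 127) // 128 * 128  # mask length rounded up to a multiple of 128

    attn = FlexAttn(block_scales, mask_type="var", B=B, H=H, L=L, auto_padding=True)
    q = torch.randn(B, H, seq_len, head_dim, device="cuda", dtype=torch.bfloat16)
    k = torch.randn_like(q)
    v = torch.randn_like(q)
    out = attn(q, k, v)  # forward() pads to L internally; output is (B, H, seq_len, head_dim)
    print(out.shape)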