Important: Update modeling_mpt.py after Flash Attention 2.7.0
#22
by KingRei - opened
- modeling_mpt.py +3 -3
modeling_mpt.py CHANGED
@@ -140,9 +140,9 @@ def gen_flash_attn_padding_info(bsz: int, S: int, past_key_len: int, device: tor
         key_padding_mask = attention_mask_in_length
         query_padding_mask = attention_mask_in_length
         unpadding_function = bert_padding.unpad_input_for_concatenated_sequences
-    (_, indices_q, cu_seqlens_q, max_seqlen_q) = unpadding_function(torch.empty(bsz, S, 1, device=device), query_padding_mask)
-    (_, indices_k, cu_seqlens_k, max_seqlen_k) = unpadding_function(torch.empty(bsz, past_key_len + S, 1, device=device), key_padding_mask)
-    (_, indices_v, _, _) = unpadding_function(torch.empty(bsz, past_key_len + S, 1, device=device), key_padding_mask)
+    (_, indices_q, cu_seqlens_q, max_seqlen_q, *rest) = unpadding_function(torch.empty(bsz, S, 1, device=device), query_padding_mask)
+    (_, indices_k, cu_seqlens_k, max_seqlen_k, *rest) = unpadding_function(torch.empty(bsz, past_key_len + S, 1, device=device), key_padding_mask)
+    (_, indices_v, _, _, *rest) = unpadding_function(torch.empty(bsz, past_key_len + S, 1, device=device), key_padding_mask)
     flash_attn_padding_info['indices_q'] = indices_q
     flash_attn_padding_info['indices_k'] = indices_k
     flash_attn_padding_info['indices_v'] = indices_v
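For context: as far as I can tell, Flash Attention 2.7.0 added a trailing value to the tuple returned by bert_padding.unpad_input, so the old strict 4-tuple unpacking raises "ValueError: too many values to unpack" after upgrading. The starred *rest target absorbs any extra return values and is simply an empty list when the helper still returns exactly four, so the same lines work on both old and new releases. A minimal sketch of the pattern, assuming only that flash-attn is installed (the batch size, sequence length, and mask below are illustrative, not taken from modeling_mpt.py):

import torch
from flash_attn import bert_padding

# Illustrative values only; in modeling_mpt.py these come from the model inputs.
bsz, S = 2, 8
query_padding_mask = torch.ones(bsz, S, dtype=torch.bool)

# flash-attn < 2.7.0 returns 4 values here; 2.7.0+ returns an extra trailing
# value. *rest is empty in the first case and collects the extras in the
# second, so the same unpacking works on both versions.
(_, indices_q, cu_seqlens_q, max_seqlen_q, *rest) = bert_padding.unpad_input(
    torch.empty(bsz, S, 1), query_padding_mask)

The same unpacking stays valid when unpadding_function is bert_padding.unpad_input_for_concatenated_sequences, which is why all three lines in the hunk above are updated the same way.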