Important: Update modeling_mpt.py after Flash Attention 2.7.0

#22
Files changed (1) hide show
  1. modeling_mpt.py +3 -3
modeling_mpt.py CHANGED
@@ -140,9 +140,9 @@ def gen_flash_attn_padding_info(bsz: int, S: int, past_key_len: int, device: tor
140
  key_padding_mask = attention_mask_in_length
141
  query_padding_mask = attention_mask_in_length
142
  unpadding_function = bert_padding.unpad_input_for_concatenated_sequences
143
- (_, indices_q, cu_seqlens_q, max_seqlen_q) = unpadding_function(torch.empty(bsz, S, 1, device=device), query_padding_mask)
144
- (_, indices_k, cu_seqlens_k, max_seqlen_k) = unpadding_function(torch.empty(bsz, past_key_len + S, 1, device=device), key_padding_mask)
145
- (_, indices_v, _, _) = unpadding_function(torch.empty(bsz, past_key_len + S, 1, device=device), key_padding_mask)
146
  flash_attn_padding_info['indices_q'] = indices_q
147
  flash_attn_padding_info['indices_k'] = indices_k
148
  flash_attn_padding_info['indices_v'] = indices_v
 
140
  key_padding_mask = attention_mask_in_length
141
  query_padding_mask = attention_mask_in_length
142
  unpadding_function = bert_padding.unpad_input_for_concatenated_sequences
143
+ (_, indices_q, cu_seqlens_q, max_seqlen_q, *rest) = unpadding_function(torch.empty(bsz, S, 1, device=device), query_padding_mask)
144
+ (_, indices_k, cu_seqlens_k, max_seqlen_k, *rest) = unpadding_function(torch.empty(bsz, past_key_len + S, 1, device=device), key_padding_mask)
145
+ (_, indices_v, _, _, *rest) = unpadding_function(torch.empty(bsz, past_key_len + S, 1, device=device), key_padding_mask)
146
  flash_attn_padding_info['indices_q'] = indices_q
147
  flash_attn_padding_info['indices_k'] = indices_k
148
  flash_attn_padding_info['indices_v'] = indices_v