Update modeling_quiet.py
modeling_quiet.py  CHANGED  (+94 -60)
@@ -432,7 +432,7 @@ class QuietFlashAttention2(QuietAttention):
        super().__init__(*args, **kwargs)

        # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
-       # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right
+       # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
        # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
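Note: the `_flash_attn_uses_top_left_mask` flag set above is what the new `_flash_attention_forward` further down keys on when deciding whether to request causal masking from flash-attn during single-token decoding. A minimal sketch of that interaction, assuming the `is_flash_attn_greater_or_equal_2_10` helper from `transformers.utils`; the `resolve_causal` wrapper is illustrative and not part of the model code:

from transformers.utils import is_flash_attn_greater_or_equal_2_10

# flash_attn < 2.1 aligns the causal mask top-left, so causal masking must be
# skipped for the query_length == 1 decode step to avoid a misaligned mask.
uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()

def resolve_causal(is_causal: bool, query_length: int) -> bool:
    if not uses_top_left_mask:
        return is_causal                    # flash_attn >= 2.1: bottom-right aligned mask
    return is_causal and query_length != 1  # flash_attn < 2.1: drop causal masking at decode

print(resolve_causal(True, 1), resolve_causal(True, 16))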
@@ -533,72 +533,104 @@ class QuietFlashAttention2(QuietAttention):
            if torch.is_autocast_enabled():
                target_dtype = torch.get_autocast_gpu_dtype()
            # Handle the case where the model is quantized
+           elif hasattr(self.config, "_pre_quantization_dtype"):
+               target_dtype = self.config._pre_quantization_dtype
            else:
-               target_dtype =
+               target_dtype = self.q_proj.weight.dtype
+
+           logger.warning_once(
+               f"The input hidden states seems to be silently casted in float32, this might be related to"
+               f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
+               f" {target_dtype}."
+           )
+
            query_states = query_states.to(target_dtype)
            key_states = key_states.to(target_dtype)
            value_states = value_states.to(target_dtype)

-       #
-
-
-
-       # Compute the causal mask
-       causal_mask = torch.tril(torch.ones((q_len, kv_seq_len), dtype=torch.bool, device=query_states.device))
-       # Invert the mask
-       causal_mask = ~causal_mask
-       else:
-           causal_mask = torch.triu(
-               torch.ones((q_len, kv_seq_len), dtype=torch.bool, device=query_states.device), diagonal=1
-           )
-       else:
-           causal_mask = None
+       # Reashape to the expected shape for Flash Attention
+       query_states = query_states.transpose(1, 2)
+       key_states = key_states.transpose(1, 2)
+       value_states = value_states.transpose(1, 2)

-
-
-
-
-       attention_mask
+       attn_output = self._flash_attention_forward(
+           query_states,
+           key_states,
+           value_states,
+           attention_mask,
+           q_len,
+           dropout=dropout_rate,
+           use_sliding_windows=use_sliding_windows,
+       )

-
-
-
-
+       attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
+       attn_output = self.o_proj(attn_output)
+
+       if not output_attentions:
+           attn_weights = None

-
-
+       return attn_output, attn_weights, past_key_value
+
+   def _flash_attention_forward(
+       self,
+       query_states,
+       key_states,
+       value_states,
+       attention_mask,
+       query_length,
+       dropout=0.0,
+       softmax_scale=None,
+       use_sliding_windows=False,
+   ):
+       """
+       Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
+       first unpad the input, then computes the attention scores and pad the final attention scores.
+       Args:
+           query_states (`torch.Tensor`):
+               Input query states to be passed to Flash Attention API
+           key_states (`torch.Tensor`):
+               Input key states to be passed to Flash Attention API
+           value_states (`torch.Tensor`):
+               Input value states to be passed to Flash Attention API
+           attention_mask (`torch.Tensor`):
+               The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
+               position of padding tokens and 1 for the position of non-padding tokens.
+           dropout (`int`, *optional*):
+               Attention dropout
+           softmax_scale (`float`, *optional*):
+               The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
+           use_sliding_windows (`bool`, *optional*):
+               Whether to activate sliding window attention.
+       """
+       if not self._flash_attn_uses_top_left_mask:
+           causal = self.is_causal
+       else:
+           # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
+           causal = self.is_causal and query_length != 1
+
+       # Ensure attention_mask has the correct shape and values
+       if attention_mask is not None:
+           if attention_mask.dim() == 4:
+               # Convert 4D attention mask to 2D
+               attention_mask = attention_mask.squeeze(1).squeeze(1)
+           elif attention_mask.dim() != 2:
+               raise ValueError(
+                   f"Invalid attention mask dimension: {attention_mask.dim()}. Expected 2D or 4D mask."
+               )
+
+           # Ensure attention_mask has values of 0 and 1
+           attention_mask = attention_mask.to(torch.bool).to(torch.int32)

-       #
+       # Contains at least one padding token in the sequence
        if attention_mask is not None:
-
-           (
-               query_states,
-
-               value_states,
-               indices_q,
-               cu_seq_lens,
-               max_seq_lens,
-           ) = self._upad_input(query_states, key_states, value_states, attention_mask, q_len)
+           batch_size = query_states.shape[0]
+           query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
+               query_states, key_states, value_states, attention_mask, query_length
+           )

            cu_seqlens_q, cu_seqlens_k = cu_seq_lens
            max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens

-           # Create the cu_seqlens_q and cu_seqlens_k tensors
-           q_max_s, k_max_s = query_states.shape[1], key_states.shape[1]
-           qkv_max_s = max(q_max_s, k_max_s)
-
-           q_seqlens = torch.full((batch_size,), q_max_s, dtype=torch.int32, device=query_states.device)
-           k_seqlens = torch.full((batch_size,), k_max_s, dtype=torch.int32, device=key_states.device)
-
-           # Adjust the attention mask to match the sequence lengths
-           if attention_mask is not None:
-               q_seqlens = attention_mask.sum(dim=1).int()
-               k_seqlens = attention_mask.sum(dim=1).int()
-
-           # Convert seqlens to cumulative sequence lengths
-           cu_seqlens_q = torch.cat([torch.zeros(1, dtype=torch.int32, device=q_seqlens.device), q_seqlens.cumsum(dim=0)])
-           cu_seqlens_k = torch.cat([torch.zeros(1, dtype=torch.int32, device=k_seqlens.device), k_seqlens.cumsum(dim=0)])
-
            if not use_sliding_windows:
                attn_output_unpad = flash_attn_varlen_func(
                    query_states,
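The deleted block above built `cu_seqlens_q`/`cu_seqlens_k` and the per-batch maxima by hand; the rewritten path delegates that to `self._upad_input`, which unpads the batch with flash-attn's padding utilities. For intuition, a small self-contained sketch (plain PyTorch, not the model code) of what those cumulative sequence lengths look like for a padded batch:

import torch

# Two sequences of real lengths 3 and 2, padded to length 4 (1 = token, 0 = pad).
attention_mask = torch.tensor([[1, 1, 1, 0],
                               [1, 1, 0, 0]], dtype=torch.int32)

seqlens = attention_mask.sum(dim=1, dtype=torch.int32)              # tensor([3, 2])
cu_seqlens = torch.nn.functional.pad(
    torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0)
)                                                                    # tensor([0, 3, 5])
max_seqlen_in_batch = int(seqlens.max())                             # 3

# flash_attn_varlen_func consumes the unpadded (total_tokens, heads, head_dim)
# tensors plus these offsets and maxima for the query and key/value sides.
print(cu_seqlens, max_seqlen_in_batch)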
@@ -606,8 +638,8 @@ class QuietFlashAttention2(QuietAttention):
                    value_states,
                    cu_seqlens_q=cu_seqlens_q,
                    cu_seqlens_k=cu_seqlens_k,
-                   max_seqlen_q=
-                   max_seqlen_k=
+                   max_seqlen_q=max_seqlen_in_batch_q,
+                   max_seqlen_k=max_seqlen_in_batch_k,
                    dropout_p=dropout,
                    softmax_scale=softmax_scale,
                    causal=causal,
@@ -619,8 +651,8 @@ class QuietFlashAttention2(QuietAttention):
                    value_states,
                    cu_seqlens_q=cu_seqlens_q,
                    cu_seqlens_k=cu_seqlens_k,
-                   max_seqlen_q=
-                   max_seqlen_k=
+                   max_seqlen_q=max_seqlen_in_batch_q,
+                   max_seqlen_k=max_seqlen_in_batch_k,
                    dropout_p=dropout,
                    softmax_scale=softmax_scale,
                    causal=causal,
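Both `flash_attn_varlen_func` call sites above previously left `max_seqlen_q`/`max_seqlen_k` unfilled; the fix forwards the per-batch maxima returned by `_upad_input`. As a reminder of what those values and the `cu_seqlens` offsets delimit, a hedged sketch with placeholder shapes (plain PyTorch, continuing the toy example above):

import torch

num_heads, head_dim = 2, 4
cu_seqlens = torch.tensor([0, 3, 5], dtype=torch.int32)   # offsets into the packed batch
max_seqlen_in_batch = 3                                   # longest real sequence
packed = torch.randn(int(cu_seqlens[-1]), num_heads, head_dim)

# Each (start, end) pair slices one unpadded sequence out of the packed tensor;
# max_seqlen_in_batch bounds every such slice.
for start, end in zip(cu_seqlens[:-1].tolist(), cu_seqlens[1:].tolist()):
    chunk = packed[start:end]                             # (seq_len, num_heads, head_dim)
    assert chunk.shape[0] <= max_seqlen_in_batch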
@@ -663,7 +695,8 @@ class QuietFlashAttention2(QuietAttention):
        indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)

        key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
-       value_layer= index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
+       value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
+
        if query_length == kv_seq_len:
            query_layer = index_first_axis(
                query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k
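`index_first_axis` comes from flash-attn's padding utilities and gathers the rows of a flattened `(batch_size * kv_seq_len, num_heads, head_dim)` tensor that correspond to real tokens. A rough pure-PyTorch equivalent of the corrected `value_layer` line, with illustrative shapes only:

import torch

batch_size, kv_seq_len, num_heads, head_dim = 2, 4, 1, 8
attention_mask = torch.tensor([[1, 1, 1, 0],
                               [1, 1, 0, 0]], dtype=torch.bool)
value_layer = torch.randn(batch_size, kv_seq_len, num_heads, head_dim)

# indices of non-padding positions in the flattened batch, as _get_unpad_data returns them
indices_k = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()

# equivalent of index_first_axis(value_layer.reshape(batch*seq, heads, dim), indices_k)
unpadded_values = value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim)[indices_k]
print(unpadded_values.shape)   # torch.Size([5, 1, 8])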
@@ -691,6 +724,8 @@ class QuietFlashAttention2(QuietAttention):
            (cu_seqlens_q, cu_seqlens_k),
            (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
        )
+
+
# Copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Quiet
class QuietSdpaAttention(QuietAttention):
    """
@@ -768,7 +803,7 @@ class QuietSdpaAttention(QuietAttention):
            attn_mask=attention_mask.to(query_states.device) if attention_mask is not None else None,
            dropout_p=self.attention_dropout if self.training else 0.0,
            # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
-
+           is_causal=self.is_causal and attention_mask is None and q_len > 1,
        )

        attn_output = attn_output.transpose(1, 2).contiguous()
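The SDPA branch now sets `is_causal` explicitly. PyTorch's `scaled_dot_product_attention` does not allow an explicit `attn_mask` together with `is_causal=True`, which is why the flag is only enabled when `attention_mask is None` (and `q_len > 1`, matching the comment above). A minimal sketch of the two valid configurations, with toy shapes rather than the model's:

import torch
import torch.nn.functional as F

q = torch.randn(1, 2, 5, 8)   # (batch, num_heads, q_len, head_dim)
k = torch.randn(1, 2, 5, 8)
v = torch.randn(1, 2, 5, 8)

# Either let SDPA build the causal mask itself ...
out_causal = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# ... or pass an explicit mask and leave is_causal at its default of False.
mask = torch.ones(1, 1, 5, 5, dtype=torch.bool).tril()
out_masked = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
print(out_causal.shape, out_masked.shape)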
@@ -1814,7 +1849,6 @@ class QuietForCausalLM(QuietPreTrainedModel):
            dim=-1
        )

-
        # print((new_rm_tokens > self.vocab_size - 1).any().item())
        new_rm_tokens = torch.clamp(new_rm_tokens, 0, self.vocab_size - 1)
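For context, the unchanged `torch.clamp` line in the hunk above bounds `new_rm_tokens` to the valid vocabulary range before any embedding lookup, so out-of-range ids cannot index past the embedding table. A toy illustration with a hypothetical vocabulary size:

import torch

vocab_size = 32002                               # hypothetical; the real value comes from the model config
new_rm_tokens = torch.tensor([[17, 31999, 32004, -3]])
new_rm_tokens = torch.clamp(new_rm_tokens, 0, vocab_size - 1)
print(new_rm_tokens)                             # tensor([[   17, 31999, 32001,     0]])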