Crystalcareai committed (verified)
Commit f5c1913 · Parent: 3d0a2d9

Update modeling_quiet.py

Files changed (1): modeling_quiet.py (+94, -60)
modeling_quiet.py CHANGED
@@ -432,7 +432,7 @@ class QuietFlashAttention2(QuietAttention):
         super().__init__(*args, **kwargs)

         # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
-        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
+        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
         # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
         self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()

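The comment kept in this hunk refers to the causal-mask alignment difference between flash_attn versions. As a quick illustration (not part of the commit; plain PyTorch with made-up shapes), this is what top-left vs. bottom-right alignment looks like when q_len != kv_seq_len, e.g. when decoding with a KV cache:

```python
import torch

q_len, kv_len = 2, 5  # e.g. 2 new query tokens attending over 5 cached key/value tokens

# Top-left alignment (flash_attn < 2.1): query row i may attend key columns 0..i
top_left = torch.tril(torch.ones(q_len, kv_len, dtype=torch.int32))

# Bottom-right alignment (flash_attn >= 2.1): query row i may attend columns 0..(kv_len - q_len + i)
bottom_right = torch.tril(torch.ones(q_len, kv_len, dtype=torch.int32), diagonal=kv_len - q_len)

print(top_left)      # [[1, 0, 0, 0, 0], [1, 1, 0, 0, 0]] -> drops most of the cache
print(bottom_right)  # [[1, 1, 1, 1, 0], [1, 1, 1, 1, 1]] -> what cached decoding needs
```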
@@ -533,72 +533,104 @@ class QuietFlashAttention2(QuietAttention):
             if torch.is_autocast_enabled():
                 target_dtype = torch.get_autocast_gpu_dtype()
             # Handle the case where the model is quantized
+            elif hasattr(self.config, "_pre_quantization_dtype"):
+                target_dtype = self.config._pre_quantization_dtype
             else:
-                target_dtype = torch.float16
+                target_dtype = self.q_proj.weight.dtype
+
+            logger.warning_once(
+                f"The input hidden states seems to be silently casted in float32, this might be related to"
+                f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
+                f" {target_dtype}."
+            )
+
             query_states = query_states.to(target_dtype)
             key_states = key_states.to(target_dtype)
             value_states = value_states.to(target_dtype)

-        # Compute the causal mask
-        causal = self.config.causal
-        if causal:
-            if self._flash_attn_uses_top_left_mask:
-                # Compute the causal mask
-                causal_mask = torch.tril(torch.ones((q_len, kv_seq_len), dtype=torch.bool, device=query_states.device))
-                # Invert the mask
-                causal_mask = ~causal_mask
-            else:
-                causal_mask = torch.triu(
-                    torch.ones((q_len, kv_seq_len), dtype=torch.bool, device=query_states.device), diagonal=1
-                )
-        else:
-            causal_mask = None
+        # Reashape to the expected shape for Flash Attention
+        query_states = query_states.transpose(1, 2)
+        key_states = key_states.transpose(1, 2)
+        value_states = value_states.transpose(1, 2)

-        # Compute the attention mask
-        if attention_mask is not None:
-            if attention_mask.dim() == 2:
-                attention_mask = attention_mask[:, None, :]
-            attention_mask = attention_mask.to(torch.bool)
+        attn_output = self._flash_attention_forward(
+            query_states,
+            key_states,
+            value_states,
+            attention_mask,
+            q_len,
+            dropout=dropout_rate,
+            use_sliding_windows=use_sliding_windows,
+        )

-        if causal:
-            attention_mask = attention_mask & causal_mask
-        else:
-            attention_mask = attention_mask
+        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
+        attn_output = self.o_proj(attn_output)
+
+        if not output_attentions:
+            attn_weights = None

-        # Compute the softmax scale
-        softmax_scale = self.head_dim**-0.5
+        return attn_output, attn_weights, past_key_value
+
+    def _flash_attention_forward(
+        self,
+        query_states,
+        key_states,
+        value_states,
+        attention_mask,
+        query_length,
+        dropout=0.0,
+        softmax_scale=None,
+        use_sliding_windows=False,
+    ):
+        """
+        Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
+        first unpad the input, then computes the attention scores and pad the final attention scores.
+        Args:
+            query_states (`torch.Tensor`):
+                Input query states to be passed to Flash Attention API
+            key_states (`torch.Tensor`):
+                Input key states to be passed to Flash Attention API
+            value_states (`torch.Tensor`):
+                Input value states to be passed to Flash Attention API
+            attention_mask (`torch.Tensor`):
+                The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
+                position of padding tokens and 1 for the position of non-padding tokens.
+            dropout (`int`, *optional*):
+                Attention dropout
+            softmax_scale (`float`, *optional*):
+                The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
+            use_sliding_windows (`bool`, *optional*):
+                Whether to activate sliding window attention.
+        """
+        if not self._flash_attn_uses_top_left_mask:
+            causal = self.is_causal
+        else:
+            # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
+            causal = self.is_causal and query_length != 1
+
+        # Ensure attention_mask has the correct shape and values
+        if attention_mask is not None:
+            if attention_mask.dim() == 4:
+                # Convert 4D attention mask to 2D
+                attention_mask = attention_mask.squeeze(1).squeeze(1)
+            elif attention_mask.dim() != 2:
+                raise ValueError(
+                    f"Invalid attention mask dimension: {attention_mask.dim()}. Expected 2D or 4D mask."
+                )
+
+            # Ensure attention_mask has values of 0 and 1
+            attention_mask = attention_mask.to(torch.bool).to(torch.int32)

-        # Compute the attention scores
+        # Contains at least one padding token in the sequence
         if attention_mask is not None:
-            # Unpad the input
-            (
-                query_states,
-                key_states,
-                value_states,
-                indices_q,
-                cu_seq_lens,
-                max_seq_lens,
-            ) = self._upad_input(query_states, key_states, value_states, attention_mask, q_len)
+            batch_size = query_states.shape[0]
+            query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
+                query_states, key_states, value_states, attention_mask, query_length
+            )

             cu_seqlens_q, cu_seqlens_k = cu_seq_lens
             max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens

-            # Create the cu_seqlens_q and cu_seqlens_k tensors
-            q_max_s, k_max_s = query_states.shape[1], key_states.shape[1]
-            qkv_max_s = max(q_max_s, k_max_s)
-
-            q_seqlens = torch.full((batch_size,), q_max_s, dtype=torch.int32, device=query_states.device)
-            k_seqlens = torch.full((batch_size,), k_max_s, dtype=torch.int32, device=key_states.device)
-
-            # Adjust the attention mask to match the sequence lengths
-            if attention_mask is not None:
-                q_seqlens = attention_mask.sum(dim=1).int()
-                k_seqlens = attention_mask.sum(dim=1).int()
-
-            # Convert seqlens to cumulative sequence lengths
-            cu_seqlens_q = torch.cat([torch.zeros(1, dtype=torch.int32, device=q_seqlens.device), q_seqlens.cumsum(dim=0)])
-            cu_seqlens_k = torch.cat([torch.zeros(1, dtype=torch.int32, device=k_seqlens.device), k_seqlens.cumsum(dim=0)])
-
             if not use_sliding_windows:
                 attn_output_unpad = flash_attn_varlen_func(
                     query_states,
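The rewritten forward path casts to the runtime dtype, transposes to the layout Flash Attention expects, and delegates to the new `_flash_attention_forward`, which unpads the batch before calling the varlen kernel. The sketch below is illustrative only (plain PyTorch, not code from this file); it mirrors the metadata that helpers like `_get_unpad_data`/`_upad_input` are expected to hand to `flash_attn_varlen_func`:

```python
# Sketch: how a 2-D padding mask maps to varlen metadata (indices, cu_seqlens, max_seqlen).
import torch
import torch.nn.functional as F

attention_mask = torch.tensor([[1, 1, 1, 0, 0],
                               [1, 1, 1, 1, 1]], dtype=torch.int32)  # (batch, seq_len)

seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)              # [3, 5]
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()   # positions of real tokens
max_seqlen_in_batch = int(seqlens_in_batch.max())                             # 5
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))  # [0, 3, 8]

# Tokens from all sequences are then packed into one (total_tokens, num_heads, head_dim)
# tensor; cu_seqlens and max_seqlen tell the varlen kernel where each sequence starts and ends.
```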
@@ -606,8 +638,8 @@ class QuietFlashAttention2(QuietAttention):
                     value_states,
                     cu_seqlens_q=cu_seqlens_q,
                     cu_seqlens_k=cu_seqlens_k,
-                    max_seqlen_q=qkv_max_s,
-                    max_seqlen_k=qkv_max_s,
+                    max_seqlen_q=max_seqlen_in_batch_q,
+                    max_seqlen_k=max_seqlen_in_batch_k,
                     dropout_p=dropout,
                     softmax_scale=softmax_scale,
                     causal=causal,
@@ -619,8 +651,8 @@ class QuietFlashAttention2(QuietAttention):
                     value_states,
                     cu_seqlens_q=cu_seqlens_q,
                     cu_seqlens_k=cu_seqlens_k,
-                    max_seqlen_q=qkv_max_s,
-                    max_seqlen_k=qkv_max_s,
+                    max_seqlen_q=max_seqlen_in_batch_q,
+                    max_seqlen_k=max_seqlen_in_batch_k,
                     dropout_p=dropout,
                     softmax_scale=softmax_scale,
                     causal=causal,
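Both varlen calls now pass the per-batch maxima returned by the unpad step (`max_seqlen_in_batch_q` / `max_seqlen_in_batch_k`) instead of the old `qkv_max_s`, which was derived from padded tensor shapes. A minimal sketch (assumed mask values, not from the file) of why the two can differ:

```python
# Sketch: the longest *unpadded* sequence can be shorter than the padded length.
import torch

attention_mask = torch.tensor([[1, 1, 1, 0, 0],
                               [1, 1, 0, 0, 0]])
padded_len = attention_mask.shape[1]                         # 5 (what qkv_max_s amounted to)
max_seqlen_in_batch = int(attention_mask.sum(dim=-1).max())  # 3 (what the varlen kernel needs)

assert max_seqlen_in_batch <= padded_len
```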
@@ -663,7 +695,8 @@ class QuietFlashAttention2(QuietAttention):
        indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)

        key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
-        value_layer= index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
+        value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
+
        if query_length == kv_seq_len:
            query_layer = index_first_axis(
                query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k
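This hunk routes `value_layer` through the same gather as `key_layer`. For reference, a plain-PyTorch illustration of what a row gather like `index_first_axis` does (an equivalent sketch, not the flash_attn implementation):

```python
# Sketch: keep only the non-padded rows of a (batch * seq_len, num_heads, head_dim) tensor.
import torch

batch_size, kv_seq_len, num_heads, head_dim = 2, 4, 3, 8
value_layer = torch.randn(batch_size, kv_seq_len, num_heads, head_dim)
attention_mask = torch.tensor([[1, 1, 0, 0],
                               [1, 1, 1, 1]])

indices_k = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
flat = value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim)
value_unpadded = flat[indices_k]   # (total_real_tokens, num_heads, head_dim)

assert value_unpadded.shape[0] == int(attention_mask.sum())
```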
@@ -691,6 +724,8 @@ class QuietFlashAttention2(QuietAttention):
            (cu_seqlens_q, cu_seqlens_k),
            (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
        )
+
+
 # Copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Quiet
 class QuietSdpaAttention(QuietAttention):
     """
@@ -768,7 +803,7 @@ class QuietSdpaAttention(QuietAttention):
            attn_mask=attention_mask.to(query_states.device) if attention_mask is not None else None,
            dropout_p=self.attention_dropout if self.training else 0.0,
            # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
-            causal=self.is_causal and attention_mask is None and q_len > 1,
+            is_causal=self.is_causal and attention_mask is None and q_len > 1,
        )

        attn_output = attn_output.transpose(1, 2).contiguous()
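The SDPA branch previously passed `causal=`, which `torch.nn.functional.scaled_dot_product_attention` does not accept; the keyword is `is_causal`. A minimal sketch of the corrected call (shapes are illustrative only):

```python
import torch
import torch.nn.functional as F

bsz, num_heads, q_len, head_dim = 1, 4, 6, 16
query_states = torch.randn(bsz, num_heads, q_len, head_dim)
key_states = torch.randn(bsz, num_heads, q_len, head_dim)
value_states = torch.randn(bsz, num_heads, q_len, head_dim)

attn_output = F.scaled_dot_product_attention(
    query_states,
    key_states,
    value_states,
    attn_mask=None,
    dropout_p=0.0,
    is_causal=True,   # the keyword is `is_causal`; passing `causal=` raises a TypeError
)
```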
@@ -1814,7 +1849,6 @@ class QuietForCausalLM(QuietPreTrainedModel):
            dim=-1
        )

-
        # print((new_rm_tokens > self.vocab_size - 1).any().item())
        new_rm_tokens = torch.clamp(new_rm_tokens, 0, self.vocab_size - 1)
