Crystalcareai committed
Commit f9819f8 · verified · 1 Parent(s): 8dec3b9

Update modeling_quiet.py

Files changed (1):
  modeling_quiet.py: +1 -6
modeling_quiet.py CHANGED

@@ -774,7 +774,7 @@ class QuietSdpaAttention(QuietAttention):
             raise ValueError(
                 f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
             )
-        attention_mask = attention_mask.to(query_states.dtype)
+
         # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
         # Reference: https://github.com/pytorch/pytorch/issues/112577.
         if query_states.device.type == "cuda" and attention_mask is not None:
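The deleted line cast the additive attention mask to the query dtype just before the SDPA call; a plausible reading is that the cast is redundant, since the 4D mask is already built in the embeddings' dtype upstream (see the second hunk below). For context, the contiguity workaround that the surviving comment refers to usually looks like this sketch; the tensor names follow the hunk, and the surrounding attention module is assumed:

    import torch
    import torch.nn.functional as F

    # Sketch of the torch==2.1.2 SDPA workaround the comment refers to:
    # with a custom attn_mask on CUDA, non-contiguous q/k/v can hit
    # https://github.com/pytorch/pytorch/issues/112577, so the tensors
    # are made contiguous first. Shapes: (bsz, num_heads, len, head_dim).
    def sdpa_with_workaround(query_states, key_states, value_states, attention_mask):
        if query_states.device.type == "cuda" and attention_mask is not None:
            query_states = query_states.contiguous()
            key_states = key_states.contiguous()
            value_states = value_states.contiguous()
        return F.scaled_dot_product_attention(
            query_states, key_states, value_states,
            attn_mask=attention_mask,  # additive float mask, or None
        )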
@@ -1674,15 +1674,12 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
             base_attention_mask = base_attention_mask.view(1, 1, seq_len, seq_len)
             base_attention_mask = base_attention_mask.repeat(input_ids.shape[0], 1, 1, 1)
             attention_mask = base_attention_mask
-            # breakpoint()
         elif attention_mask.dim() == 2:
             if seq_len + past_key_values_length != attention_mask.shape[-1]:
-                # breakpoint()
                 attention_mask = torch.cat(
                     [torch.ones((attention_mask.shape[0], past_key_values_length), dtype=attention_mask.dtype, device=attention_mask.device), attention_mask],
                     dim=-1
                 )
-        # # if the attention mask
         attention_mask = _prepare_4d_causal_attention_mask(
             attention_mask,
             (batch_size, seq_len),
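The torch.cat kept in this hunk left-pads the 2D padding mask with ones so it also covers tokens already sitting in the KV cache, before _prepare_4d_causal_attention_mask expands it to the (bsz, 1, q_len, kv_len) shape the attention layers expect. A minimal runnable sketch of that shape bookkeeping; the sizes and mask values are illustrative, and the helper import path is the one used by recent transformers releases:

    import torch
    from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask

    batch_size, seq_len, past_key_values_length = 2, 3, 4

    # 2D padding mask covering only the *new* tokens: 1 = attend, 0 = pad.
    attention_mask = torch.tensor([[1, 1, 1],
                                   [1, 1, 0]])

    # Cached positions are all real tokens, so prepend ones for them,
    # exactly as the hunk does along the last dimension.
    if seq_len + past_key_values_length != attention_mask.shape[-1]:
        attention_mask = torch.cat(
            [torch.ones((attention_mask.shape[0], past_key_values_length),
                        dtype=attention_mask.dtype, device=attention_mask.device),
             attention_mask],
            dim=-1,
        )
    print(attention_mask.shape)  # torch.Size([2, 7]): past + new tokens

    # Expansion to the 4D additive causal mask; dtype follows inputs_embeds.
    inputs_embeds = torch.randn(batch_size, seq_len, 16)
    mask_4d = _prepare_4d_causal_attention_mask(
        attention_mask, (batch_size, seq_len), inputs_embeds, past_key_values_length
    )
    print(mask_4d.shape)  # torch.Size([2, 1, 3, 7])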
@@ -1700,10 +1697,8 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
             use_cache=use_cache,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
-            # output_router_logits=output_router_logits,
             return_dict=return_dict,
         )
-
         prev_hidden_states = hidden_states
         hidden_states = outputs[0]
         prev_rm_logits = rm_logits # for policy gradient
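The deleted output_router_logits kwarg is a mixture-of-experts flag (Mixtral-style routing); the Quiet classes here appear to follow Mistral's dense decoder, which does not accept it, so removing the commented-out line reads as consistent cleanup. A small sketch of the surviving call pattern on a stand-in dense decoder; the tiny config keeps it runnable offline, and MistralModel is an assumption standing in for the Quiet base model:

    import torch
    from transformers import MistralConfig, MistralModel

    # Tiny stand-in decoder so the snippet runs without a weights download.
    config = MistralConfig(hidden_size=64, intermediate_size=128,
                           num_hidden_layers=2, num_attention_heads=4,
                           num_key_value_heads=2, vocab_size=128)
    model = MistralModel(config)

    input_ids = torch.randint(0, 128, (1, 5))
    outputs = model(input_ids,
                    use_cache=True,
                    output_attentions=False,
                    output_hidden_states=True,
                    return_dict=True)  # no output_router_logits on dense models
    hidden_states = outputs[0]         # as in the hunk: outputs[0] is the hidden states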
 