Update modeling_quiet.py
modeling_quiet.py: 1 addition, 6 deletions
@@ -774,7 +774,7 @@ class QuietSdpaAttention(QuietAttention):
             raise ValueError(
                 f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
             )
-
+
         # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
         # Reference: https://github.com/pytorch/pytorch/issues/112577.
         if query_states.device.type == "cuda" and attention_mask is not None:
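For context, the CUDA check kept in this hunk guards the usual workaround for the linked PyTorch issue: the tensors are made contiguous before calling scaled_dot_product_attention. Below is a minimal sketch of that pattern; it mirrors the common transformers SDPA workaround and is not necessarily the exact body of QuietSdpaAttention.forward.

    import torch
    import torch.nn.functional as F

    def sdpa_with_contiguity_workaround(query_states, key_states, value_states, attention_mask):
        # torch==2.1.2's memory-efficient SDPA backend mishandles non-contiguous
        # q/k/v when a custom attn_mask is supplied on CUDA
        # (https://github.com/pytorch/pytorch/issues/112577), so force contiguity first.
        if query_states.device.type == "cuda" and attention_mask is not None:
            query_states = query_states.contiguous()
            key_states = key_states.contiguous()
            value_states = value_states.contiguous()
        return F.scaled_dot_product_attention(
            query_states, key_states, value_states, attn_mask=attention_mask
        )

    # Usage sketch with made-up shapes (batch, heads, seq_len, head_dim):
    q = k = v = torch.randn(1, 8, 16, 64)
    out = sdpa_with_contiguity_workaround(q, k, v, attention_mask=None)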
@@ -1674,15 +1674,12 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
                 base_attention_mask = base_attention_mask.view(1, 1, seq_len, seq_len)
                 base_attention_mask = base_attention_mask.repeat(input_ids.shape[0], 1, 1, 1)
                 attention_mask = base_attention_mask
-                # breakpoint()
             elif attention_mask.dim() == 2:
                 if seq_len + past_key_values_length != attention_mask.shape[-1]:
-                    # breakpoint()
                     attention_mask = torch.cat(
                         [torch.ones((attention_mask.shape[0], past_key_values_length), dtype=attention_mask.dtype, device=attention_mask.device), attention_mask],
                         dim=-1
                     )
-                # # if the attention mask
                 attention_mask = _prepare_4d_causal_attention_mask(
                     attention_mask,
                     (batch_size, seq_len),
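The logic kept in this hunk left-pads a 2-D attention mask with ones so that it also covers positions already held in the KV cache, before the mask is expanded to 4-D. A small self-contained illustration of the shape bookkeeping, with sizes made up for the example:

    import torch

    batch_size, seq_len, past_key_values_length = 2, 3, 5

    # 2-D padding mask covering only the current seq_len tokens.
    attention_mask = torch.ones(batch_size, seq_len, dtype=torch.long)

    # Cached positions are always attendable, so prepend ones for them,
    # mirroring the torch.cat call retained in the diff above.
    if seq_len + past_key_values_length != attention_mask.shape[-1]:
        attention_mask = torch.cat(
            [
                torch.ones(
                    (attention_mask.shape[0], past_key_values_length),
                    dtype=attention_mask.dtype,
                    device=attention_mask.device,
                ),
                attention_mask,
            ],
            dim=-1,
        )

    print(attention_mask.shape)  # torch.Size([2, 8]) -> (batch, past + current)

The _prepare_4d_causal_attention_mask helper then turns this (batch, past + current) mask into the additive (batch, 1, seq_len, past + current) causal mask that the attention layers consume.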
@@ -1700,10 +1697,8 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
                 use_cache=use_cache,
                 output_attentions=output_attentions,
                 output_hidden_states=output_hidden_states,
-                # output_router_logits=output_router_logits,
                 return_dict=return_dict,
             )
-
             prev_hidden_states = hidden_states
             hidden_states = outputs[0]
             prev_rm_logits = rm_logits  # for policy gradient