Update modeling_quiet.py
Browse files- modeling_quiet.py +1 -1
modeling_quiet.py
CHANGED
@@ -768,7 +768,7 @@ class QuietSdpaAttention(QuietAttention):
|
|
768 |
attn_mask=attention_mask.to(query_states.device) if attention_mask is not None else None,
|
769 |
dropout_p=self.attention_dropout if self.training else 0.0,
|
770 |
# The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
|
771 |
-
|
772 |
)
|
773 |
|
774 |
attn_output = attn_output.transpose(1, 2).contiguous()
|
|
|
768 |
attn_mask=attention_mask.to(query_states.device) if attention_mask is not None else None,
|
769 |
dropout_p=self.attention_dropout if self.training else 0.0,
|
770 |
# The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
|
771 |
+
causal=self.is_causal and attention_mask is None and q_len > 1,
|
772 |
)
|
773 |
|
774 |
attn_output = attn_output.transpose(1, 2).contiguous()
|