Crystalcareai commited on
Commit
695a26f
·
verified ·
1 Parent(s): 7a973a6

Update modeling_quiet.py

Browse files
Files changed (1) hide show
  1. modeling_quiet.py +6 -0
modeling_quiet.py CHANGED
@@ -719,6 +719,12 @@ class QuietSdpaAttention(QuietAttention):
719
  key_states = key_states.contiguous()
720
  value_states = value_states.contiguous()
721
 
 
 
 
 
 
 
722
  attn_output = torch.nn.functional.scaled_dot_product_attention(
723
  query_states,
724
  key_states,
 
719
  key_states = key_states.contiguous()
720
  value_states = value_states.contiguous()
721
 
722
+
723
+ # Cast query_states, key_states, and value_states to the same data type as attention_mask
724
+ query_states = query_states.to(attention_mask.dtype)
725
+ key_states = key_states.to(attention_mask.dtype)
726
+ value_states = value_states.to(attention_mask.dtype)
727
+
728
  attn_output = torch.nn.functional.scaled_dot_product_attention(
729
  query_states,
730
  key_states,