Update modeling_quiet.py
Browse files: modeling_quiet.py (+6 -0)
modeling_quiet.py
CHANGED
@@ -719,6 +719,12 @@ class QuietSdpaAttention(QuietAttention):
|
|
719 |
key_states = key_states.contiguous()
|
720 |
value_states = value_states.contiguous()
|
721 |
|
|
|
|
|
|
|
|
|
|
|
|
|
722 |
attn_output = torch.nn.functional.scaled_dot_product_attention(
|
723 |
query_states,
|
724 |
key_states,
|
|
|
719 |
key_states = key_states.contiguous()
|
720 |
value_states = value_states.contiguous()
|
721 |
|
722 |
+
|
723 |
+
# Cast query_states, key_states, and value_states to the same data type as attention_mask
|
724 |
+
query_states = query_states.to(attention_mask.dtype)
|
725 |
+
key_states = key_states.to(attention_mask.dtype)
|
726 |
+
value_states = value_states.to(attention_mask.dtype)
|
727 |
+
|
728 |
attn_output = torch.nn.functional.scaled_dot_product_attention(
|
729 |
query_states,
|
730 |
key_states,
|