Update modeling_quiet.py
Browse files- modeling_quiet.py +3 -0
modeling_quiet.py
CHANGED
@@ -761,8 +761,11 @@ class QuietSdpaAttention(QuietAttention):
|
|
761 |
|
762 |
bsz, q_len, _ = hidden_states.size()
|
763 |
|
|
|
764 |
query_states = self.q_proj(hidden_states)
|
|
|
765 |
key_states = self.k_proj(hidden_states)
|
|
|
766 |
value_states = self.v_proj(hidden_states)
|
767 |
|
768 |
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
|
|
761 |
|
762 |
bsz, q_len, _ = hidden_states.size()
|
763 |
|
764 |
+
hidden_states = hidden_states.to(self.q_proj.weight.dtype)
|
765 |
query_states = self.q_proj(hidden_states)
|
766 |
+
hidden_states = hidden_states.to(self.k_proj.weight.dtype)
|
767 |
key_states = self.k_proj(hidden_states)
|
768 |
+
hidden_states = hidden_states.to(self.v_proj.weight.dtype)
|
769 |
value_states = self.v_proj(hidden_states)
|
770 |
|
771 |
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|