Crystalcareai commited on
Commit
974e6b8
·
verified ·
1 Parent(s): 9fdcb7b

Update modeling_quiet.py

Browse files
Files changed (1) hide show
  1. modeling_quiet.py +3 -0
modeling_quiet.py CHANGED
@@ -761,8 +761,11 @@ class QuietSdpaAttention(QuietAttention):
761
 
762
  bsz, q_len, _ = hidden_states.size()
763
 
 
764
  query_states = self.q_proj(hidden_states)
 
765
  key_states = self.k_proj(hidden_states)
 
766
  value_states = self.v_proj(hidden_states)
767
 
768
  query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
 
761
 
762
  bsz, q_len, _ = hidden_states.size()
763
 
764
+ hidden_states = hidden_states.to(self.q_proj.weight.dtype)
765
  query_states = self.q_proj(hidden_states)
766
+ hidden_states = hidden_states.to(self.k_proj.weight.dtype)
767
  key_states = self.k_proj(hidden_states)
768
+ hidden_states = hidden_states.to(self.v_proj.weight.dtype)
769
  value_states = self.v_proj(hidden_states)
770
 
771
  query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)