RaymondLi commited on
Commit
88dbd99
·
1 Parent(s): e3c5922

Update modeling_gpt2_mq.py

Browse files
Files changed (1) hide show
  1. modeling_gpt2_mq.py +1 -1
modeling_gpt2_mq.py CHANGED
@@ -244,7 +244,7 @@ class GPT2MQAttention(nn.Module):
244
  attention_mask = encoder_attention_mask
245
  else:
246
  query = self.q_attn(hidden_states)
247
- key, value = self.kv_attn(hidden_states).split(self.split_size, dim=2)
248
 
249
 
250
  batch_size, seq_length = query.shape[:2]
 
244
  attention_mask = encoder_attention_mask
245
  else:
246
  query = self.q_attn(hidden_states)
247
+ key, value = self.kv_attn(hidden_states).split(self.head_dim, dim=2)
248
 
249
 
250
  batch_size, seq_length = query.shape[:2]