Update modeling_gpt2_mq.py
Browse files- modeling_gpt2_mq.py +1 -1
modeling_gpt2_mq.py
CHANGED
@@ -244,7 +244,7 @@ class GPT2MQAttention(nn.Module):
|
|
244 |
attention_mask = encoder_attention_mask
|
245 |
else:
|
246 |
query = self.q_attn(hidden_states)
|
247 |
-
key, value = self.kv_attn(hidden_states).split(self.
|
248 |
|
249 |
|
250 |
batch_size, seq_length = query.shape[:2]
|
|
|
244 |
attention_mask = encoder_attention_mask
|
245 |
else:
|
246 |
query = self.q_attn(hidden_states)
|
247 |
+
key, value = self.kv_attn(hidden_states).split(self.head_dim, dim=2)
|
248 |
|
249 |
|
250 |
batch_size, seq_length = query.shape[:2]
|