Hugo Larcher
committed on
Commit
·
0c36fb9
1
Parent(s):
7094fd0
Input generation fix
Browse files - modelling_RW.py +3 -1
modelling_RW.py
CHANGED
@@ -490,7 +490,9 @@ class RWPreTrainedModel(PreTrainedModel):
|
|
490 |
# value: [batch_size * num_heads, seq_length, head_dim] -> [batch_size, num_heads, seq_length, head_dim]
|
491 |
return tuple(
|
492 |
(
|
493 |
-
layer_past[0].view(batch_size, num_heads, head_dim, seq_length),
|
|
|
|
|
494 |
layer_past[1].view(batch_size, num_heads, seq_length, head_dim),
|
495 |
)
|
496 |
for layer_past in past_key_value
|
|
|
490 |
# value: [batch_size * num_heads, seq_length, head_dim] -> [batch_size, num_heads, seq_length, head_dim]
|
491 |
return tuple(
|
492 |
(
|
493 |
+
# layer_past[0].view(batch_size, num_heads, head_dim, seq_length),
|
494 |
+
# layer_past[1].view(batch_size, num_heads, seq_length, head_dim),
|
495 |
+
layer_past[0].view(batch_size, num_heads, seq_length, head_dim),
|
496 |
layer_past[1].view(batch_size, num_heads, seq_length, head_dim),
|
497 |
)
|
498 |
for layer_past in past_key_value
|