Hugo Larcher
commited on
Commit
·
5904f81
1
Parent(s):
0c36fb9
Input generation fix
Browse files- modelling_RW.py +1 -1
modelling_RW.py
CHANGED
@@ -508,7 +508,7 @@ class RWPreTrainedModel(PreTrainedModel):
|
|
508 |
# value: [batch_size, num_heads, seq_length, head_dim] -> [batch_size * num_heads, seq_length, head_dim]
|
509 |
return tuple(
|
510 |
(
|
511 |
-
layer_past[0].view(batch_size_times_num_heads,
|
512 |
layer_past[1].view(batch_size_times_num_heads, seq_length, head_dim),
|
513 |
)
|
514 |
for layer_past in past_key_value
|
|
|
508 |
# value: [batch_size, num_heads, seq_length, head_dim] -> [batch_size * num_heads, seq_length, head_dim]
|
509 |
return tuple(
|
510 |
(
|
511 |
+
layer_past[0].view(batch_size_times_num_heads, seq_length, head_dim),
|
512 |
layer_past[1].view(batch_size_times_num_heads, seq_length, head_dim),
|
513 |
)
|
514 |
for layer_past in past_key_value
|