Update modeling_gpt2l.py: keep key untransposed and concat the KV cache on dim=-2 (flash attention prep)
modeling_gpt2l.py  +2 -2
@@ -169,11 +169,11 @@ class Attention(nn.Module):
         query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
 
         query = self.split_heads(query)
-        key = self.split_heads(key, k=True)
+        key = self.split_heads(key)  # drop k=True: keep key in (batch, head, seq_len, head_dim), matching query and value
         value = self.split_heads(value)
         if layer_past is not None:
             past_key, past_value = layer_past[0].transpose(-2, -1), layer_past[1]  # transpose back cf below
-            key = torch.cat((past_key, key), dim=-1)
+            key = torch.cat((past_key, key), dim=-2)  # was dim=-1; with key no longer transposed, its sequence axis is -2 (needed for the flash attention patch)
             value = torch.cat((past_value, value), dim=-2)
 
         if use_cache is True:
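Why the concat axis moves: in upstream Hugging Face GPT-2, split_heads(key, k=True) returns the key pre-transposed to (batch, head, head_dim, seq_len) so attention scores can be computed as torch.matmul(query, key) with no extra transpose, which is why the cached keys grew along dim=-1. Flash attention kernels expect query, key, and value all in the same (batch, head, seq_len, head_dim) layout, so once k=True is dropped the key's sequence axis sits at dim=-2, matching value. A minimal shape sketch of the two conventions, assuming those upstream semantics (tensor names and sizes below are illustrative, not from this repo):

# Shape sketch only; names and sizes are illustrative, not from the repo.
import torch

batch, n_head, head_dim = 2, 12, 64
past_len, new_len = 3, 5

# Old convention: key carried as (batch, head, head_dim, seq_len),
# so the growing sequence axis is the last one.
past_key_old = torch.randn(batch, n_head, head_dim, past_len)
key_old = torch.randn(batch, n_head, head_dim, new_len)
merged_old = torch.cat((past_key_old, key_old), dim=-1)
assert merged_old.shape == (batch, n_head, head_dim, past_len + new_len)

# New convention: key carried like query/value, (batch, head, seq_len, head_dim),
# so the sequence axis is second to last.
past_key_new = torch.randn(batch, n_head, past_len, head_dim)
key_new = torch.randn(batch, n_head, new_len, head_dim)
merged_new = torch.cat((past_key_new, key_new), dim=-2)
assert merged_new.shape == (batch, n_head, past_len + new_len, head_dim)

One caveat, under the same upstream assumptions: two neighbouring lines still encode the old layout. The unchanged read-back past_key, past_value = layer_past[0].transpose(-2, -1), layer_past[1] (whose "transpose back cf below" comment is now stale), and, just below this hunk in the upstream file, the cache write present = torch.stack((key.transpose(-2, -1), value)). With key already in (batch, head, seq_len, head_dim), that stack raises a shape mismatch unless its transpose is removed, and once it is, the read-back transpose must go too, or past keys come back as (batch, head, head_dim, seq_len) and the dim=-2 concat fails. The score computation (torch.matmul(q, k) in upstream _attn, which relied on the pre-transposed key) likewise needs an explicit k.transpose(-2, -1), or to be replaced by the flash attention call this commit is preparing for.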