crumb committed on
Commit 5cb957c · verified · 1 Parent(s): 8224b23

Update modeling_gpt2l.py

Files changed (1)
  1. modeling_gpt2l.py +2 -2
modeling_gpt2l.py CHANGED

@@ -169,11 +169,11 @@ class Attention(nn.Module):
         query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
 
         query = self.split_heads(query)
-        key = self.split_heads(key, k=True)
+        key = self.split_heads(key) # Dude what? @ the k=True
         value = self.split_heads(value)
         if layer_past is not None:
             past_key, past_value = layer_past[0].transpose(-2, -1), layer_past[1] # transpose back cf below
-            key = torch.cat((past_key, key), dim=-1)
+            key = torch.cat((past_key, key), dim=-2) # this was dim=-1??? I'm trying to patch in flash attention and this is giving me TROUBLE
             value = torch.cat((past_value, value), dim=-2)
 
         if use_cache is True:
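
For context, a minimal sketch of why the two changes go together. It assumes the upstream GPT-2-style split_heads, where k=True permutes keys to (batch, head, head_dim, seq); the shapes below are illustrative only. Dropping k=True leaves keys in the same (batch, head, seq, head_dim) layout as queries and values, which is the layout flash-attention kernels expect, so the KV-cache concatenation has to move from the last dimension to the sequence dimension at dim=-2.

import torch

# Illustrative shapes: batch=1, n_head=2, seq=4, head_dim=8 (hypothetical values).

# Old layout: split_heads(key, k=True) permuted keys to (batch, head, head_dim, seq),
# so the KV cache grew along the last dimension, dim=-1.
old_key = torch.randn(1, 2, 8, 4)
old_cache = torch.cat((old_key, old_key), dim=-1)  # -> (1, 2, 8, 8), seq doubled

# New layout: split_heads(key) keeps keys as (batch, head, seq, head_dim),
# matching query/value, so the cache now grows along the seq dimension, dim=-2.
new_key = torch.randn(1, 2, 4, 8)
new_cache = torch.cat((new_key, new_key), dim=-2)  # -> (1, 2, 8, 8), seq doubled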