Hugo Larcher committed on
Commit 6725288 · 1 Parent(s): 124d35c

Input generation fix tf

Files changed (1)
  1. modelling_RW.py +50 -39
modelling_RW.py CHANGED
@@ -271,44 +271,52 @@ class Attention(nn.Module):
             # concatenate along seq_length dimension:
             # - key: [batch_size * self.num_heads, kv_length, head_dim]
             # - value: [batch_size * self.num_heads, kv_length, head_dim]
-            past_key = past_key.permute(0, 2, 1)
             key_layer = torch.cat((past_key, key_layer), dim=1)
             value_layer = torch.cat((past_value, value_layer), dim=1)
 
         _, kv_length, _ = key_layer.shape
 
         if use_cache is True:
-            key_layer_permute = key_layer.permute(0, 2, 1)
-            present = (key_layer_permute, value_layer)
+            present = (key_layer, value_layer)
         else:
             present = None
 
+        attention_mask_float = (attention_mask * 1.0).masked_fill(attention_mask, float("-1e9")).to(query_layer.dtype)
+
+        query_layer_ = query_layer.reshape(batch_size, self.num_heads, -1, self.head_dim)
+        key_layer_ = key_layer.reshape(batch_size, self.num_kv, -1, self.head_dim)
+        value_layer_ = value_layer.reshape(batch_size, self.num_kv, -1, self.head_dim)
+
         if alibi is None:
-            query_layer_ = query_layer.reshape(batch_size, self.num_heads, -1, self.head_dim)
-            key_layer_ = key_layer.reshape(batch_size, self.num_kv, -1, self.head_dim)
-            value_layer_ = value_layer.reshape(batch_size, self.num_kv, -1, self.head_dim)
-
-            if attention_mask is not None:
-                attn_output = F.scaled_dot_product_attention(
-                    query_layer_, key_layer_, value_layer_, attention_mask, 0.0, is_causal=False
+            if output_attentions:
+                # F.scaled_dot_product_attention doesn't return the attention weights, so we have
+                # to do it by hand if we want them
+                attention_scores = query_layer_ @ key_layer_.transpose(-1, -2)
+                attention_scores /= math.sqrt(self.head_dim)
+
+                attention_scores = F.softmax(
+                    attention_scores + attention_mask_float, dim=-1, dtype=hidden_states.dtype
                 )
+                attn_output = attention_scores @ value_layer_
             else:
                 attn_output = F.scaled_dot_product_attention(
-                    query_layer_, key_layer_, value_layer_, None, 0.0, is_causal=True
+                    query_layer_, key_layer_, value_layer_, attention_mask_float, 0.0, is_causal=False
                 )
+                attention_scores = None
 
-            x = attn_output.view(batch_size, self.num_heads, q_length, self.head_dim)
-            x = x.permute(0, 2, 1, 3)
-            attn_output = x.reshape(batch_size, q_length, self.num_heads * self.head_dim)
+            attn_output = attn_output.view(batch_size, self.num_heads, q_length, self.head_dim)
+            attn_output = attn_output.permute(0, 2, 1, 3)
+            attn_output = attn_output.reshape(batch_size, q_length, self.num_heads * self.head_dim)
 
             output_tensor = self.dense(attn_output)
 
-            outputs = (output_tensor, present)
-            assert not output_attentions  # not supported.
-            return outputs
+            if output_attentions:
+                return output_tensor, present, attention_scores
+            else:
+                return output_tensor, present
+
         else:
-            attention_mask_float = (attention_mask * 1.0).masked_fill(attention_mask, -1e9).to(torch.bfloat16)
-            matmul_result = query_layer @ key_layer.transpose(-1, -2)
+            matmul_result = query_layer_ @ key_layer_.transpose(-1, -2)
 
             # change view to [batch_size, num_heads, q_length, kv_length]
             attention_scores = matmul_result.view(batch_size, self.num_heads, q_length, kv_length)
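The hunk above replaces the unconditional F.scaled_dot_product_attention call with a branch: when `output_attentions` is requested, the attention weights are computed by hand so they can be returned, otherwise the fused kernel is used with the same additive float mask. A minimal, self-contained sketch (shapes and the toy mask are illustrative assumptions, not taken from the model) showing that the two paths agree when dropout is disabled:

```python
import math
import torch
import torch.nn.functional as F

# Illustrative shapes only; the real model reshapes to [batch, num_heads, q_length, head_dim].
batch, heads, q_len, kv_len, head_dim = 2, 4, 5, 5, 16
q = torch.randn(batch, heads, q_len, head_dim)
k = torch.randn(batch, heads, kv_len, head_dim)
v = torch.randn(batch, heads, kv_len, head_dim)

# Additive float mask: 0 where attending is allowed, -1e9 where masked out (all allowed here).
bool_mask = torch.zeros(batch, 1, q_len, kv_len, dtype=torch.bool)
mask_float = torch.zeros(batch, 1, q_len, kv_len).masked_fill(bool_mask, float("-1e9"))

# Manual path, as in the `output_attentions` branch of the diff.
scores = q @ k.transpose(-1, -2)
scores /= math.sqrt(head_dim)
probs = F.softmax(scores + mask_float, dim=-1)
out_manual = probs @ v

# Fused path, as in the `else` branch of the diff (dropout_p=0.0, is_causal=False).
out_sdpa = F.scaled_dot_product_attention(q, k, v, mask_float, 0.0, is_causal=False)

print(torch.allclose(out_manual, out_sdpa, atol=1e-5))  # True, up to numerical noise
```

The manual branch also keeps `attention_scores` around so it can be returned as the third element of the output tuple, which is exactly what the fused kernel cannot provide.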
@@ -318,35 +326,34 @@ class Attention(nn.Module):
             # `float16` has a minimum value of -65504.0, whereas `bfloat16` and `float32` have a minimum value of `-3.4e+38`
             if input_dtype == torch.float16 or input_dtype == torch.bfloat16:
                 attention_scores = attention_scores.to(torch.float32)
-            # attn_weights = torch.masked_fill(attention_scores, attention_mask, torch.finfo(attention_scores.dtype).min)
-            attention_probs = F.softmax(
-                (attention_scores + alibi.view(batch_size, self.num_heads, 1,
-                                               -1)) * self.inv_norm_factor + attention_mask_float,
-                dim=-1,
-                dtype=hidden_states.dtype,
-            )
+            # Matt (HF) note: We could possibly use F.scaled_dot_product_attention here too, by
+            # adding (alibi * self.inv_norm_factor) to attention_mask_float. I think this would be mathematically
+            # equivalent and more performant, but there might be a numerical difference. If you're reading this
+            # and you'd like to experiment and maybe file a PR, feel free!
+            attention_logits = attention_scores + alibi.view(batch_size, self.num_heads, 1, -1)
+            attention_logits *= self.inv_norm_factor
+            attention_probs = F.softmax(attention_logits + attention_mask_float, dim=-1, dtype=hidden_states.dtype)
             # [batch_size, num_heads, q_length, kv_length]
             attention_probs = self.attention_dropout(attention_probs)
 
             if head_mask is not None:
                 attention_probs = attention_probs * head_mask
 
-            # change view [batch_size x num_heads, q_length, kv_length]
-            attention_probs_reshaped = attention_probs.view(batch_size * self.num_heads, q_length, kv_length)
+            # change view [batch_size, num_heads, q_length, kv_length]
+            attention_probs_reshaped = attention_probs.view(batch_size, self.num_heads, q_length, kv_length)
 
             # matmul: [batch_size * num_heads, q_length, head_dim]
-            context_layer = attention_probs_reshaped @ value_layer
+            context_layer = (attention_probs_reshaped @ value_layer_).flatten(0, 1)
 
             # change view [batch_size, num_heads, q_length, head_dim]
             context_layer = self._merge_heads(context_layer)
 
             output_tensor = self.dense(context_layer)
 
-            outputs = (output_tensor, present)
             if output_attentions:
-                outputs += (attention_probs,)
-
-            return outputs
+                return output_tensor, present, attention_probs
+            else:
+                return output_tensor, present
 
 
 class MLP(nn.Module):
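The new comment in this hunk floats the idea of folding `alibi * self.inv_norm_factor` into the additive float mask so that the alibi branch could also go through F.scaled_dot_product_attention. The sketch below (toy shapes, a random stand-in for the real alibi tensor, and the assumption that `inv_norm_factor` equals 1/sqrt(head_dim), i.e. the same scale SDPA applies internally) illustrates why the two formulations are algebraically the same:

```python
import math
import torch
import torch.nn.functional as F

batch, heads, q_len, kv_len, head_dim = 2, 4, 5, 5, 16
inv_norm_factor = 1.0 / math.sqrt(head_dim)  # assumption: matches SDPA's internal 1/sqrt(d) scaling

q = torch.randn(batch, heads, q_len, head_dim)
k = torch.randn(batch, heads, kv_len, head_dim)
v = torch.randn(batch, heads, kv_len, head_dim)
alibi = torch.randn(batch, heads, 1, kv_len)       # toy stand-in for the real alibi bias
mask_float = torch.zeros(batch, 1, q_len, kv_len)  # no masked positions, for simplicity

# Manual path, as written in the diff: add alibi, scale, add the mask, softmax, matmul.
logits = q @ k.transpose(-1, -2) + alibi
logits = logits * inv_norm_factor
probs = F.softmax(logits + mask_float, dim=-1)
out_manual = probs @ v

# Suggested alternative: pre-add (alibi * inv_norm_factor) to the float mask and let SDPA do the rest.
out_sdpa = F.scaled_dot_product_attention(
    q, k, v, attn_mask=mask_float + alibi * inv_norm_factor, dropout_p=0.0, is_causal=False
)

print(torch.allclose(out_manual, out_sdpa, atol=1e-5))  # True when inv_norm_factor == 1/sqrt(head_dim)
```

As the comment in the diff warns, the fused kernel may differ slightly in floating-point terms even though the algebra is identical.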
@@ -562,6 +569,8 @@ class RWModel(RWPreTrainedModel):
             expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask | combined_attention_mask
         )
 
+        return combined_attention_mask
+
     def set_input_embeddings(self, new_embeddings: torch.Tensor):
         self.word_embeddings = new_embeddings
 
@@ -606,6 +615,8 @@ class RWModel(RWPreTrainedModel):
 
         if past_key_values is None:
             past_key_values = tuple([None] * len(self.h))
+        else:
+            past_key_values = self._convert_to_rw_cache(past_key_values)
 
         # Prepare head mask if needed
         # 1.0 in head_mask indicate we keep the head
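This hunk (together with the ones below) routes incoming `past_key_values` through `self._convert_to_rw_cache` and converts the outgoing `presents` back with `self._convert_cache_to_standard_format`; the helpers themselves are not part of this diff. The following is only a hedged sketch of the idea, with hypothetical helper names, assuming from the diff's comments that the RW cache keeps each layer's key/value fused as [batch_size * num_heads, kv_length, head_dim] (kv_length on dim 1) while the standard HF cache uses [batch_size, num_heads, kv_length, head_dim]:

```python
from typing import Tuple
import torch

# Hypothetical helpers, not the file's actual methods: they only illustrate the assumed
# layout difference between the RW cache and the standard HF cache.
Cache = Tuple[Tuple[torch.Tensor, torch.Tensor], ...]


def to_rw_cache(standard: Cache) -> Cache:
    # [batch, num_heads, kv_length, head_dim] -> [batch * num_heads, kv_length, head_dim]
    return tuple((k.flatten(0, 1), v.flatten(0, 1)) for k, v in standard)


def to_standard_cache(rw: Cache, batch_size: int) -> Cache:
    # [batch * num_heads, kv_length, head_dim] -> [batch, num_heads, kv_length, head_dim]
    b_times_h, kv_length, head_dim = rw[0][0].shape
    num_heads = b_times_h // batch_size
    return tuple(
        (
            k.view(batch_size, num_heads, kv_length, head_dim),
            v.view(batch_size, num_heads, kv_length, head_dim),
        )
        for k, v in rw
    )


# Round-trip check on a toy 3-layer cache.
std = tuple((torch.randn(2, 4, 7, 16), torch.randn(2, 4, 7, 16)) for _ in range(3))
rw = to_rw_cache(std)
assert rw[0][0].shape == (8, 7, 16)
back = to_standard_cache(rw, batch_size=2)
assert all(torch.equal(a[0], b[0]) and torch.equal(a[1], b[1]) for a, b in zip(std, back))
```

Under this assumption the conversions are pure reshapes, which also explains why the hunk below reads the past length from `shape[1]` of the RW-format cache.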
@@ -623,13 +634,11 @@ class RWModel(RWPreTrainedModel):
         all_hidden_states = () if output_hidden_states else None
 
         # Compute alibi tensor: check build_alibi_tensor documentation
-        seq_length_with_past = seq_length
         past_key_values_length = 0
         if past_key_values[0] is not None:
-            past_key_values_length = past_key_values[0][0].shape[2]
-            seq_length_with_past = seq_length_with_past + past_key_values_length
+            past_key_values_length = past_key_values[0][0].shape[1]  # 1 because RW-cache, not standard format
         if attention_mask is None:
-            attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
+            attention_mask = torch.ones((batch_size, seq_length + past_key_values_length), device=hidden_states.device)
         else:
             attention_mask = attention_mask.to(hidden_states.device)
 
@@ -695,6 +704,9 @@ class RWModel(RWPreTrainedModel):
         if output_hidden_states:
             all_hidden_states = all_hidden_states + (hidden_states,)
 
+        if presents is not None:
+            presents = self._convert_cache_to_standard_format(presents, batch_size)
+
         if not return_dict:
             return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
 
@@ -822,7 +834,6 @@ class RWForCausalLM(RWPreTrainedModel):
 
         Output shares the same memory storage as `past`.
         """
-        standardized_past = self._convert_to_standard_cache(past, batch_size=len(beam_idx))
 
         # Get a copy of `beam_idx` on all the devices where we need those indices.
         device_to_beam_idx = {
@@ -833,9 +844,9 @@ class RWForCausalLM(RWPreTrainedModel):
                 layer_past[0].index_select(0, device_to_beam_idx[layer_past[0].device]),
                 layer_past[1].index_select(0, device_to_beam_idx[layer_past[0].device]),
             )
-            for layer_past in standardized_past
+            for layer_past in past
        )
-        return self._convert_to_rw_cache(reordered_past)
+        return reordered_past
 
 
 class RWForSequenceClassification(RWPreTrainedModel):
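With `presents` now returned in the standard [batch_size, num_heads, kv_length, head_dim] layout, the last two hunks let `_reorder_cache` do a plain `index_select` over the batch dimension without any cache-format round trip. A small sketch with toy tensors (single device, hypothetical shapes) of what that reordering does during beam search:

```python
import torch

# Toy standard-format cache for 2 layers: (key, value) with shape [batch, num_heads, kv_len, head_dim].
batch, num_heads, kv_len, head_dim = 4, 2, 3, 8
past = tuple(
    (torch.randn(batch, num_heads, kv_len, head_dim), torch.randn(batch, num_heads, kv_len, head_dim))
    for _ in range(2)
)

# beam_idx says which batch row each beam should copy its cache from after a beam-search step.
beam_idx = torch.tensor([2, 2, 0, 1])

reordered = tuple(
    (
        layer_past[0].index_select(0, beam_idx.to(layer_past[0].device)),
        layer_past[1].index_select(0, beam_idx.to(layer_past[0].device)),
    )
    for layer_past in past
)

# Beam 0 now carries the cache that previously belonged to batch row 2.
assert torch.equal(reordered[0][0][0], past[0][0][2])
```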
 