Hugo Larcher committed
Commit 6725288 · Parent(s): 124d35c
Input generation fix tf

Files changed: modelling_RW.py (+50, -39)

modelling_RW.py (CHANGED)
@@ -271,44 +271,52 @@ class Attention(nn.Module):
             # concatenate along seq_length dimension:
             # - key: [batch_size * self.num_heads, kv_length, head_dim]
             # - value: [batch_size * self.num_heads, kv_length, head_dim]
-            past_key = past_key.permute(0, 2, 1)
             key_layer = torch.cat((past_key, key_layer), dim=1)
             value_layer = torch.cat((past_value, value_layer), dim=1)

         _, kv_length, _ = key_layer.shape

         if use_cache is True:
-            [1 removed line; content not captured in this view]
-            present = (key_layer_permute, value_layer)
         else:
             present = None

         if alibi is None:
-            [3 removed lines; content not captured in this view]

-            [2 removed lines; content not captured in this view]
-                    query_layer_, key_layer_, value_layer_, attention_mask, 0.0, is_causal=False
                 )
             else:
                 attn_output = F.scaled_dot_product_attention(
-                    query_layer_, key_layer_, value_layer_,
                 )

-            [2 removed lines; content not captured in this view]
-            attn_output =

             output_tensor = self.dense(attn_output)

-            [3 removed lines; content not captured in this view]
         else:
-            [1 removed line; content not captured in this view]
-            matmul_result = query_layer @ key_layer.transpose(-1, -2)

             # change view to [batch_size, num_heads, q_length, kv_length]
             attention_scores = matmul_result.view(batch_size, self.num_heads, q_length, kv_length)
@@ -318,35 +326,34 @@ class Attention(nn.Module):
             # `float16` has a minimum value of -65504.0, whereas `bfloat16` and `float32` have a minimum value of `-3.4e+38`
             if input_dtype == torch.float16 or input_dtype == torch.bfloat16:
                 attention_scores = attention_scores.to(torch.float32)
-            #
-            [5 removed lines; content not captured in this view]
-            )
             # [batch_size, num_heads, q_length, kv_length]
             attention_probs = self.attention_dropout(attention_probs)

             if head_mask is not None:
                 attention_probs = attention_probs * head_mask

-            # change view [batch_size
-            attention_probs_reshaped = attention_probs.view(batch_size

             # matmul: [batch_size * num_heads, q_length, head_dim]
-            context_layer = attention_probs_reshaped @

             # change view [batch_size, num_heads, q_length, head_dim]
             context_layer = self._merge_heads(context_layer)

             output_tensor = self.dense(context_layer)

-            outputs = (output_tensor, present)
             if output_attentions:
-            [3 removed lines; content not captured in this view]


 class MLP(nn.Module):
@@ -562,6 +569,8 @@ class RWModel(RWPreTrainedModel):
             expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask | combined_attention_mask
         )

     def set_input_embeddings(self, new_embeddings: torch.Tensor):
         self.word_embeddings = new_embeddings

@@ -606,6 +615,8 @@ class RWModel(RWPreTrainedModel):

         if past_key_values is None:
             past_key_values = tuple([None] * len(self.h))

         # Prepare head mask if needed
         # 1.0 in head_mask indicate we keep the head
@@ -623,13 +634,11 @@ class RWModel(RWPreTrainedModel):
         all_hidden_states = () if output_hidden_states else None

         # Compute alibi tensor: check build_alibi_tensor documentation
-        seq_length_with_past = seq_length
         past_key_values_length = 0
         if past_key_values[0] is not None:
-            past_key_values_length = past_key_values[0][0].shape[
-            seq_length_with_past = seq_length_with_past + past_key_values_length
         if attention_mask is None:
-            attention_mask = torch.ones((batch_size,
         else:
             attention_mask = attention_mask.to(hidden_states.device)

@@ -695,6 +704,9 @@ class RWModel(RWPreTrainedModel):
         if output_hidden_states:
             all_hidden_states = all_hidden_states + (hidden_states,)

         if not return_dict:
             return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)

@@ -822,7 +834,6 @@ class RWForCausalLM(RWPreTrainedModel):

         Output shares the same memory storage as `past`.
         """
-        standardized_past = self._convert_to_standard_cache(past, batch_size=len(beam_idx))

         # Get a copy of `beam_idx` on all the devices where we need those indices.
         device_to_beam_idx = {
@@ -833,9 +844,9 @@ class RWForCausalLM(RWPreTrainedModel):
                 layer_past[0].index_select(0, device_to_beam_idx[layer_past[0].device]),
                 layer_past[1].index_select(0, device_to_beam_idx[layer_past[0].device]),
             )
-            for layer_past in
         )
-        return


 class RWForSequenceClassification(RWPreTrainedModel):

modelling_RW.py after this commit (the updated regions; added lines marked with +):

Lines 271-322:
             # concatenate along seq_length dimension:
             # - key: [batch_size * self.num_heads, kv_length, head_dim]
             # - value: [batch_size * self.num_heads, kv_length, head_dim]
             key_layer = torch.cat((past_key, key_layer), dim=1)
             value_layer = torch.cat((past_value, value_layer), dim=1)

         _, kv_length, _ = key_layer.shape

         if use_cache is True:
+            present = (key_layer, value_layer)
         else:
             present = None

+        attention_mask_float = (attention_mask * 1.0).masked_fill(attention_mask, float("-1e9")).to(query_layer.dtype)
+
+        query_layer_ = query_layer.reshape(batch_size, self.num_heads, -1, self.head_dim)
+        key_layer_ = key_layer.reshape(batch_size, self.num_kv, -1, self.head_dim)
+        value_layer_ = value_layer.reshape(batch_size, self.num_kv, -1, self.head_dim)
+
         if alibi is None:
+            if output_attentions:
+                # F.scaled_dot_product_attention doesn't return the attention weights, so we have
+                # to do it by hand if we want them
+                attention_scores = query_layer_ @ key_layer_.transpose(-1, -2)
+                attention_scores /= math.sqrt(self.head_dim)

+                attention_scores = F.softmax(
+                    attention_scores + attention_mask_float, dim=-1, dtype=hidden_states.dtype
                 )
+                attn_output = attention_scores @ value_layer_
             else:
                 attn_output = F.scaled_dot_product_attention(
+                    query_layer_, key_layer_, value_layer_, attention_mask_float, 0.0, is_causal=False
                 )
+                attention_scores = None

+            attn_output = attn_output.view(batch_size, self.num_heads, q_length, self.head_dim)
+            attn_output = attn_output.permute(0, 2, 1, 3)
+            attn_output = attn_output.reshape(batch_size, q_length, self.num_heads * self.head_dim)

             output_tensor = self.dense(attn_output)

+            if output_attentions:
+                return output_tensor, present, attention_scores
+            else:
+                return output_tensor, present
+
         else:
+            matmul_result = query_layer_ @ key_layer_.transpose(-1, -2)

             # change view to [batch_size, num_heads, q_length, kv_length]
             attention_scores = matmul_result.view(batch_size, self.num_heads, q_length, kv_length)

Lines 326-359:
             # `float16` has a minimum value of -65504.0, whereas `bfloat16` and `float32` have a minimum value of `-3.4e+38`
             if input_dtype == torch.float16 or input_dtype == torch.bfloat16:
                 attention_scores = attention_scores.to(torch.float32)
+            # Matt (HF) note: We could possibly use F.scaled_dot_product_attention here too, by
+            # adding (alibi * self.inv_norm_factor) to attention_mask_float. I think this would be mathematically
+            # equivalent and more performant, but there might be a numerical difference. If you're reading this
+            # and you'd like to experiment and maybe file a PR, feel free!
+            attention_logits = attention_scores + alibi.view(batch_size, self.num_heads, 1, -1)
+            attention_logits *= self.inv_norm_factor
+            attention_probs = F.softmax(attention_logits + attention_mask_float, dim=-1, dtype=hidden_states.dtype)
             # [batch_size, num_heads, q_length, kv_length]
             attention_probs = self.attention_dropout(attention_probs)

             if head_mask is not None:
                 attention_probs = attention_probs * head_mask

+            # change view [batch_size, num_heads, q_length, kv_length]
+            attention_probs_reshaped = attention_probs.view(batch_size, self.num_heads, q_length, kv_length)

             # matmul: [batch_size * num_heads, q_length, head_dim]
+            context_layer = (attention_probs_reshaped @ value_layer_).flatten(0, 1)

             # change view [batch_size, num_heads, q_length, head_dim]
             context_layer = self._merge_heads(context_layer)

             output_tensor = self.dense(context_layer)

             if output_attentions:
+                return output_tensor, present, attention_probs
+            else:
+                return output_tensor, present


 class MLP(nn.Module):
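In the updated `alibi is None` path above, the boolean mask is first turned into the additive `attention_mask_float` and the layers are reshaped to [batch, heads, seq, head_dim]; attention then either goes through `F.scaled_dot_product_attention` or, when `output_attentions` is requested, is computed by hand because the fused kernel does not return the weights. The "Matt (HF) note" above further suggests folding `alibi * self.inv_norm_factor` into the additive mask. The following standalone sketch, with made-up shapes and values (it is not code from this repo), illustrates both points: the manual softmax path matches the fused call, and adding a bias to the logits is the same as passing it through `attn_mask`.

# Standalone sketch (toy shapes and values, not code from modelling_RW.py).
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, heads, q_len, kv_len, head_dim = 2, 4, 5, 5, 8
q = torch.randn(batch, heads, q_len, head_dim)
k = torch.randn(batch, heads, kv_len, head_dim)
v = torch.randn(batch, heads, kv_len, head_dim)

# Boolean mask with True = "masked", turned into an additive float mask as in the diff.
bool_mask = torch.zeros(batch, 1, q_len, kv_len, dtype=torch.bool)
bool_mask[..., -1] = True  # pretend the last key position is padding
mask_float = torch.zeros(bool_mask.shape).masked_fill(bool_mask, float("-1e9"))

# Manual path, as in the output_attentions branch: softmax(q k^T / sqrt(d) + mask) @ v
scores = (q @ k.transpose(-1, -2)) / math.sqrt(head_dim)
manual = F.softmax(scores + mask_float, dim=-1) @ v

# Fused path, as in the else branch.
fused = F.scaled_dot_product_attention(q, k, v, attn_mask=mask_float, dropout_p=0.0, is_causal=False)
print(torch.allclose(manual, fused, atol=1e-5))  # True

# The note's idea: an alibi-style bias added to the logits can instead be added to attn_mask.
alibi_bias = torch.randn(1, heads, 1, kv_len)  # stands in for alibi * inv_norm_factor
biased_manual = F.softmax(scores + alibi_bias + mask_float, dim=-1) @ v
biased_fused = F.scaled_dot_product_attention(
    q, k, v, attn_mask=mask_float + alibi_bias, dropout_p=0.0, is_causal=False
)
print(torch.allclose(biased_manual, biased_fused, atol=1e-5))  # True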
Lines 569-576:
             expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask | combined_attention_mask
         )

+        return combined_attention_mask
+
     def set_input_embeddings(self, new_embeddings: torch.Tensor):
         self.word_embeddings = new_embeddings

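The two added lines above restore the previously missing return of the combined boolean mask, which the attention code later converts into `attention_mask_float` via `masked_fill(..., float("-1e9"))`. As a toy illustration of the convention the visible lines imply (True means "do not attend", masks are merged with `|`), here is a hedged sketch; the shapes and the exact causal/padding construction are assumptions, not taken from this file.

import torch

def toy_combined_mask(padding_mask: torch.Tensor, q_len: int) -> torch.Tensor:
    # padding_mask follows the user-facing convention (1 = real token, 0 = padding).
    batch, kv_len = padding_mask.shape
    past = kv_len - q_len
    i = torch.arange(q_len)[:, None]
    j = torch.arange(kv_len)[None, :]
    causal = (j > i + past)[None, None, :, :].expand(batch, 1, q_len, kv_len)          # future positions
    padding = (padding_mask == 0)[:, None, None, :].expand(batch, 1, q_len, kv_len)    # pad positions
    return causal | padding  # True = masked, mirroring `expanded_attn_mask | combined_attention_mask`

padding_mask = torch.tensor([[1, 1, 1, 1], [0, 1, 1, 1]])  # second sequence starts with a pad token
combined = toy_combined_mask(padding_mask, q_len=4)
additive = torch.zeros(combined.shape).masked_fill(combined, float("-1e9"))
print(combined[1, 0])  # boolean mask for the padded sequence
print(additive[1, 0])  # what attention_mask_float would look like for it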
Lines 615-622:

         if past_key_values is None:
             past_key_values = tuple([None] * len(self.h))
+        else:
+            past_key_values = self._convert_to_rw_cache(past_key_values)

         # Prepare head mask if needed
         # 1.0 in head_mask indicate we keep the head
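`_convert_to_rw_cache` is called here but its body is not part of this diff. Purely as a hypothetical sketch of what such a conversion could look like, assuming the incoming cache uses the standard [batch, num_heads, seq_len, head_dim] layout and the RW layout is the fused [batch * num_heads, seq_len, head_dim] suggested by the shape comments and the `shape[1]` read below (for multi-query variants the fused head count would be the number of KV heads):

import torch
from typing import Tuple

Cache = Tuple[Tuple[torch.Tensor, torch.Tensor], ...]

def to_rw_cache_sketch(past: Cache) -> Cache:
    converted = []
    for key, value in past:
        batch, num_heads, seq_len, head_dim = key.shape  # standard layout (assumed)
        converted.append(
            (
                key.reshape(batch * num_heads, seq_len, head_dim),    # seq_len lands on dim 1,
                value.reshape(batch * num_heads, seq_len, head_dim),  # hence the shape[1] access below
            )
        )
    return tuple(converted)

past = ((torch.randn(2, 4, 7, 8), torch.randn(2, 4, 7, 8)),)
rw = to_rw_cache_sketch(past)
print(rw[0][0].shape)     # torch.Size([8, 7, 8])
print(rw[0][0].shape[1])  # 7 == past_key_values_length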
Lines 634-644:
         all_hidden_states = () if output_hidden_states else None

         # Compute alibi tensor: check build_alibi_tensor documentation
         past_key_values_length = 0
         if past_key_values[0] is not None:
+            past_key_values_length = past_key_values[0][0].shape[1]  # 1 because RW-cache, not standard format
         if attention_mask is None:
+            attention_mask = torch.ones((batch_size, seq_length + past_key_values_length), device=hidden_states.device)
         else:
             attention_mask = attention_mask.to(hidden_states.device)

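A small toy illustration of the sizing in the two added lines above: during incremental decoding the default attention mask has to cover the cached key positions plus the new input tokens, and the cache length is read from dim 1 of the RW-layout key.

import torch

batch_size, seq_length = 2, 1                 # one new token per sequence
past_key = torch.randn(batch_size * 4, 7, 8)  # RW-cache key: [batch * heads, kv_length, head_dim] (toy shape)
past_key_values_length = past_key.shape[1]    # 7, as in `past_key_values[0][0].shape[1]`

attention_mask = torch.ones((batch_size, seq_length + past_key_values_length))
print(attention_mask.shape)  # torch.Size([2, 8]) -> covers 7 cached positions + 1 new token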
Lines 704-712:
         if output_hidden_states:
             all_hidden_states = all_hidden_states + (hidden_states,)

+        if presents is not None:
+            presents = self._convert_cache_to_standard_format(presents, batch_size)
+
         if not return_dict:
             return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)

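`_convert_cache_to_standard_format` is likewise only referenced, not defined, in this diff. A hypothetical inverse of the earlier sketch, assuming it splits the fused batch*heads dimension back out with the given batch_size:

import torch
from typing import Tuple

Cache = Tuple[Tuple[torch.Tensor, torch.Tensor], ...]

def to_standard_cache_sketch(presents: Cache, batch_size: int) -> Cache:
    converted = []
    for key, value in presents:
        fused, seq_len, head_dim = key.shape   # RW layout (assumed): [batch * heads, seq_len, head_dim]
        num_heads = fused // batch_size
        converted.append(
            (
                key.reshape(batch_size, num_heads, seq_len, head_dim),
                value.reshape(batch_size, num_heads, seq_len, head_dim),
            )
        )
    return tuple(converted)

presents = ((torch.randn(8, 7, 64), torch.randn(8, 7, 64)),)
print(to_standard_cache_sketch(presents, batch_size=2)[0][0].shape)  # torch.Size([2, 4, 7, 64])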
Lines 834-839:

         Output shares the same memory storage as `past`.
         """

         # Get a copy of `beam_idx` on all the devices where we need those indices.
         device_to_beam_idx = {

Lines 844-852:
                 layer_past[0].index_select(0, device_to_beam_idx[layer_past[0].device]),
                 layer_past[1].index_select(0, device_to_beam_idx[layer_past[0].device]),
             )
+            for layer_past in past
         )
+        return reordered_past


 class RWForSequenceClassification(RWPreTrainedModel):
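Finally, a toy illustration of what the reordering comprehension in `_reorder_cache` does: `index_select(0, beam_idx)` re-rows each cached tensor so that beam b continues from old beam `beam_idx[b]`. Since the model now returns caches in the standard format (see the earlier hunk), dim 0 is the batch-of-beams dimension; the per-device `device_to_beam_idx` bookkeeping is left out here.

import torch

beam_idx = torch.tensor([2, 0, 0])  # beams 1 and 2 both continue from old beam 0
layer_past = (torch.randn(3, 4, 7, 8), torch.randn(3, 4, 7, 8))  # toy [batch, heads, seq, head_dim]
past = (layer_past,)

reordered_past = tuple(
    (
        lp[0].index_select(0, beam_idx),
        lp[1].index_select(0, beam_idx),
    )
    for lp in past
)
print(torch.equal(reordered_past[0][0][1], past[0][0][0]))  # True: new beam 1 copies old beam 0's keys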