Crystalcareai committed on
Commit aaccd41 · verified · 1 Parent(s): 0e2d305

Update modeling_quiet.py

Files changed (1)
  1. modeling_quiet.py +37 -29
modeling_quiet.py CHANGED
@@ -35,7 +35,6 @@ from matplotlib.colors import LinearSegmentedColormap, LogNorm
 import warnings
 from collections import defaultdict
 from typing import List, Optional, Tuple, Union
-import pdb
 
 import torch
 import torch.nn.functional as F
@@ -47,7 +46,7 @@ from transformers.activations import ACT2FN
 from transformers.cache_utils import Cache, DynamicCache
 from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa
 from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
-from transformers.modeling_utils import PreTrainedModel
+from transformers.models.utils import PreTrainedModel
 from transformers.utils import (
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
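A note on the rewritten import above: in released transformers versions, `PreTrainedModel` is exported from `transformers.modeling_utils`, so whether `transformers.models.utils` resolves depends on the installed version. A defensive sketch, illustrative only and not part of the commit:

    # Fallback import, assuming the module path may differ across transformers versions.
    try:
        from transformers.models.utils import PreTrainedModel  # path used in this hunk
    except ImportError:
        from transformers.modeling_utils import PreTrainedModel  # canonical location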
@@ -271,14 +270,22 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
 
 
 class QuietAttention(nn.Module):
+    """
+    Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
+    and "Generating Long Sequences with Sparse Transformers".
+    """
+
     def __init__(self, config: QuietConfig, layer_idx: Optional[int] = None):
         super().__init__()
         self.config = config
         self.layer_idx = layer_idx
         if layer_idx is None:
             logger.warning_once(
-                f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` when creating this class."
+                f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
+                "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
+                "when creating this class."
             )
+
         self.hidden_size = config.hidden_size
         self.num_heads = config.num_attention_heads
         self.head_dim = self.hidden_size // self.num_heads
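The warning kept above matters because the cache update later in forward() indexes by layer. A minimal construction sketch, assuming `QuietConfig` from this repository's configuration_quiet.py exposes the usual `num_hidden_layers` field; hypothetical usage, not part of the commit:

    # Passing layer_idx explicitly silences the warning and keeps k/v caching usable.
    config = QuietConfig()  # hypothetical default construction
    attn_layers = [QuietAttention(config, layer_idx=i) for i in range(config.num_hidden_layers)]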
@@ -289,17 +296,20 @@ class QuietAttention(nn.Module):
         self.is_causal = True
         self.attention_dropout = config.attention_dropout
 
-        if self.head_dim * self.num_heads != self.hidden_size:
+        if (self.head_dim * self.num_heads) != self.hidden_size:
             raise ValueError(
-                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size} and `num_heads`: {self.num_heads})."
+                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
+                f" and `num_heads`: {self.num_heads})."
             )
-
         self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
         self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
         self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
         self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
+
         self.rotary_emb = QuietRotaryEmbedding(
-            self.head_dim, max_position_embeddings=self.max_position_embeddings, base=self.rope_theta
+            self.head_dim,
+            max_position_embeddings=self.max_position_embeddings,
+            base=self.rope_theta,
         )
 
     def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
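For the divisibility check above, a worked example with illustrative sizes, not taken from the actual QuietConfig:

    # head_dim is the per-head width, so it has to divide hidden_size exactly.
    hidden_size, num_heads = 2048, 32
    head_dim = hidden_size // num_heads          # 64
    assert head_dim * num_heads == hidden_size   # holds; num_heads = 30 would trip the ValueError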
@@ -317,9 +327,8 @@ class QuietAttention(nn.Module):
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         if "padding_mask" in kwargs:
             warnings.warn(
-                "`padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
+                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
             )
-
         bsz, q_len, _ = hidden_states.size()
 
         query_states = self.q_proj(hidden_states)
@@ -334,49 +343,52 @@ class QuietAttention(nn.Module):
         if past_key_value is not None:
             if self.layer_idx is None:
                 raise ValueError(
-                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} for auto-regressive decoding with k/v caching, please make sure to initialize the attention class with a layer index."
+                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
+                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
+                    "with a layer index."
                 )
             kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
-
         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
 
         if past_key_value is not None:
-            cache_kwargs = {"sin": sin, "cos": cos}  # required by original DynamicCache.update() function
+            cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
 
+        # repeat k/v heads if n_kv_heads < n_heads
         key_states = repeat_kv(key_states, self.num_key_value_groups)
         value_states = repeat_kv(value_states, self.num_key_value_groups)
 
         attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
-
 
         if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
             raise ValueError(
-                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is {attn_weights.size()}"
+                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
+                f" {attn_weights.size()}"
             )
 
         if attention_mask is not None:
-            if attention_mask.dim() == 2:
-                attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
-                attention_mask = attention_mask.to(dtype=hidden_states.dtype)  # fp16 compatibility
-                attention_mask = (1.0 - attention_mask) * -10000.0
-            elif attention_mask.dim() == 3:
-                attention_mask = attention_mask.unsqueeze(1)
+            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
+                raise ValueError(
+                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
+                )
 
             attn_weights = attn_weights + attention_mask
 
+        # upcast attention to fp32
         attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
         attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
-
         attn_output = torch.matmul(attn_weights, value_states)
 
         if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
             raise ValueError(
-                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is {attn_output.size()}"
+                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
+                f" {attn_output.size()}"
            )
 
-        attn_output = attn_output.transpose(1, 2).contiguous().view(bsz, q_len, self.hidden_size)
+        attn_output = attn_output.transpose(1, 2).contiguous()
+        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+
         attn_output = self.o_proj(attn_output)
 
         if not output_attentions:
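The substantive change in this hunk: the eager attention path no longer reshapes 2D or 3D masks itself (the removed `unsqueeze` / `(1.0 - attention_mask) * -10000.0` logic); it now expects the caller to have already built an additive 4D mask of shape `(bsz, 1, q_len, kv_seq_len)`, which `QuietModel.forward` produces via `_prepare_4d_causal_attention_mask`. A standalone sketch of the padding half of that expansion (the causal triangle is merged separately by the helper), using `torch.finfo(dtype).min` as an illustrative fill value instead of the removed hard-coded `-10000.0`:

    import torch

    def expand_padding_mask(mask_2d: torch.Tensor, q_len: int, dtype=torch.float32) -> torch.Tensor:
        # mask_2d: (bsz, kv_seq_len) with 1 for real tokens and 0 for padding.
        bsz, kv_seq_len = mask_2d.shape
        # Broadcast to (bsz, 1, q_len, kv_seq_len), then turn it into an additive mask:
        # 0.0 where attention is allowed, a large negative number where it is blocked.
        expanded = mask_2d[:, None, None, :].expand(bsz, 1, q_len, kv_seq_len).to(dtype)
        return (1.0 - expanded) * torch.finfo(dtype).min

    # Two sequences, queries of length 3, keys of length 4; the second sequence has one pad token.
    mask = torch.tensor([[1, 1, 1, 1], [1, 1, 1, 0]])
    print(expand_padding_mask(mask, q_len=3).shape)  # torch.Size([2, 1, 3, 4])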
@@ -737,7 +749,6 @@ class QuietSdpaAttention(QuietAttention):
         value_states = repeat_kv(value_states, self.num_key_value_groups)
 
         if attention_mask is not None:
-            print("Attention mask shape:", attention_mask.size())
             if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                 raise ValueError(
                     f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
@@ -761,7 +772,6 @@ class QuietSdpaAttention(QuietAttention):
         )
 
         attn_output = attn_output.transpose(1, 2).contiguous()
-        pdb.set_trace()
         attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
 
         attn_output = self.o_proj(attn_output)
@@ -780,7 +790,7 @@ class QuietDecoderLayer(nn.Module):
     def __init__(self, config: QuietConfig, layer_idx: int):
         super().__init__()
         self.hidden_size = config.hidden_size
-        print("Attention mask shape in decoder before QuietAttention:", attention_mask.size())
+
         self.self_attn = QUIET_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
 
         self.mlp = QuietMLP(config)
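Dropping the print above is more than cleanup: `attention_mask` is a `forward()` argument, not a name that exists inside `QuietDecoderLayer.__init__`, so the removed line raised a `NameError` as soon as a decoder layer was constructed. A tiny standalone reproduction of that failure mode (illustrative class, not from the file):

    class Demo:
        def __init__(self):
            # Mirrors the removed line: the name only exists later, as a forward() parameter.
            print("Attention mask shape:", attention_mask.size())

    try:
        Demo()
    except NameError as err:
        print(err)  # name 'attention_mask' is not defined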
@@ -1067,14 +1077,12 @@ class QuietModel(QuietPreTrainedModel):
         elif self._attn_implementation == "sdpa" and not output_attentions and attention_mask.dim() == 2 and False:
             # output_attentions=True can not be supported when using SDPA, and we fall back on
             # the manual implementation that requires a 4D causal mask in all cases.
-
             attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
                 attention_mask,
                 (batch_size, seq_length),
                 inputs_embeds,
                 past_key_values_length,
             )
-            print("Attention mask shape in _prepare_4d_causal_attention_mask_for_sdpa:", attention_mask.size())
         elif attention_mask is None or attention_mask.dim() == 2:
             # 4d mask is passed through the layers
             attention_mask = _prepare_4d_causal_attention_mask(
@@ -2368,4 +2376,4 @@ class QuietForSequenceClassification(QuietPreTrainedModel):
             past_key_values=transformer_outputs.past_key_values,
             hidden_states=transformer_outputs.hidden_states,
             attentions=transformer_outputs.attentions,
-        )
+        )
 