Update modeling_quiet.py

modeling_quiet.py  CHANGED  (+37 -29)
@@ -35,7 +35,6 @@ from matplotlib.colors import LinearSegmentedColormap, LogNorm
 import warnings
 from collections import defaultdict
 from typing import List, Optional, Tuple, Union
-import pdb
 
 import torch
 import torch.nn.functional as F
@@ -47,7 +46,7 @@ from transformers.activations import ACT2FN
 from transformers.cache_utils import Cache, DynamicCache
 from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa
 from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
-from transformers.
+from transformers.models.utils import PreTrainedModel
 from transformers.utils import (
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
@@ -271,14 +270,22 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
 
 
 class QuietAttention(nn.Module):
+    """
+    Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
+    and "Generating Long Sequences with Sparse Transformers".
+    """
+
     def __init__(self, config: QuietConfig, layer_idx: Optional[int] = None):
         super().__init__()
         self.config = config
         self.layer_idx = layer_idx
         if layer_idx is None:
             logger.warning_once(
-                f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will
+                f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
+                "lead to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
+                "when creating this class."
             )
+
         self.hidden_size = config.hidden_size
         self.num_heads = config.num_attention_heads
         self.head_dim = self.hidden_size // self.num_heads
@@ -289,17 +296,20 @@ class QuietAttention(nn.Module):
         self.is_causal = True
         self.attention_dropout = config.attention_dropout
 
-        if self.head_dim * self.num_heads != self.hidden_size:
+        if (self.head_dim * self.num_heads) != self.hidden_size:
             raise ValueError(
-                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}
+                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
+                f" and `num_heads`: {self.num_heads})."
             )
-
         self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
         self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
         self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
         self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
+
         self.rotary_emb = QuietRotaryEmbedding(
-            self.head_dim,
+            self.head_dim,
+            max_position_embeddings=self.max_position_embeddings,
+            base=self.rope_theta,
         )
 
     def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
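For reference, the divisibility check and the narrower k/v projections in this hunk reflect grouped-query attention: queries use num_attention_heads heads of width head_dim = hidden_size // num_heads, while keys and values use the smaller num_key_value_heads. Below is a minimal sketch of the resulting projection shapes, with hypothetical config values (4096/32/8 are illustrative only, not taken from this checkpoint):

# Sketch of the projection shapes set up in QuietAttention.__init__,
# using hypothetical config values.
import torch
import torch.nn as nn

hidden_size = 4096           # config.hidden_size
num_heads = 32               # config.num_attention_heads
num_key_value_heads = 8      # config.num_key_value_heads (grouped-query attention)
head_dim = hidden_size // num_heads  # must divide exactly, hence the ValueError above

q_proj = nn.Linear(hidden_size, num_heads * head_dim, bias=False)
k_proj = nn.Linear(hidden_size, num_key_value_heads * head_dim, bias=False)
v_proj = nn.Linear(hidden_size, num_key_value_heads * head_dim, bias=False)

x = torch.randn(2, 10, hidden_size)  # (batch, seq_len, hidden)
q = q_proj(x).view(2, 10, num_heads, head_dim).transpose(1, 2)
k = k_proj(x).view(2, 10, num_key_value_heads, head_dim).transpose(1, 2)
print(q.shape, k.shape)  # torch.Size([2, 32, 10, 128]) torch.Size([2, 8, 10, 128])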
@@ -317,9 +327,8 @@ class QuietAttention(nn.Module):
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         if "padding_mask" in kwargs:
             warnings.warn(
-                "`padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
+                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
             )
-
         bsz, q_len, _ = hidden_states.size()
 
         query_states = self.q_proj(hidden_states)
@@ -334,49 +343,52 @@ class QuietAttention(nn.Module):
         if past_key_value is not None:
             if self.layer_idx is None:
                 raise ValueError(
-                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__}
+                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
+                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
+                    "with a layer index."
                 )
             kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
-
         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
 
         if past_key_value is not None:
-            cache_kwargs = {"sin": sin, "cos": cos} #
+            cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
 
+        # repeat k/v heads if n_kv_heads < n_heads
         key_states = repeat_kv(key_states, self.num_key_value_groups)
         value_states = repeat_kv(value_states, self.num_key_value_groups)
 
         attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
-
 
         if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
             raise ValueError(
-                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is
+                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
+                f" {attn_weights.size()}"
             )
 
         if attention_mask is not None:
-            if attention_mask.
-
-
-
-            elif attention_mask.dim() == 3:
-                attention_mask = attention_mask.unsqueeze(1)
+            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
+                raise ValueError(
+                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
+                )
 
             attn_weights = attn_weights + attention_mask
 
+        # upcast attention to fp32
         attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
         attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
-
         attn_output = torch.matmul(attn_weights, value_states)
 
         if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
             raise ValueError(
-                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is
+                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
+                f" {attn_output.size()}"
             )
 
-        attn_output = attn_output.transpose(1, 2).contiguous()
+        attn_output = attn_output.transpose(1, 2).contiguous()
+        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+
         attn_output = self.o_proj(attn_output)
 
         if not output_attentions:
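This hunk is the eager attention path: k/v heads are repeated up to the query head count, scores are QK^T / sqrt(head_dim) plus an additive mask, softmax is done in fp32, and the weighted values are transposed and reshaped back to (bsz, q_len, hidden_size). A self-contained sketch of the same computation with hypothetical shapes (no KV cache, no dropout); the repeat_kv here mirrors the helper defined earlier in the file:

# Standalone sketch of the eager attention math above (illustrative shapes only).
import math
import torch

bsz, num_heads, num_kv_heads, q_len, head_dim = 2, 8, 2, 5, 16
n_rep = num_heads // num_kv_heads

def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    # (bsz, num_kv_heads, seq, head_dim) -> (bsz, num_kv_heads * n_rep, seq, head_dim)
    b, h, s, d = x.shape
    return x[:, :, None, :, :].expand(b, h, n_rep, s, d).reshape(b, h * n_rep, s, d)

q = torch.randn(bsz, num_heads, q_len, head_dim)
k = repeat_kv(torch.randn(bsz, num_kv_heads, q_len, head_dim), n_rep)
v = repeat_kv(torch.randn(bsz, num_kv_heads, q_len, head_dim), n_rep)

# additive causal mask: 0 where attending is allowed, dtype-min where it is not
mask = torch.full((q_len, q_len), torch.finfo(q.dtype).min).triu(1)[None, None]

attn_weights = q @ k.transpose(2, 3) / math.sqrt(head_dim) + mask
attn_weights = torch.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype)  # upcast softmax
attn_output = (attn_weights @ v).transpose(1, 2).reshape(bsz, q_len, num_heads * head_dim)
print(attn_output.shape)  # torch.Size([2, 5, 128])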
@@ -737,7 +749,6 @@ class QuietSdpaAttention(QuietAttention):
         value_states = repeat_kv(value_states, self.num_key_value_groups)
 
         if attention_mask is not None:
-            print("Attention mask shape:", attention_mask.size())
             if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                 raise ValueError(
                     f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
@@ -761,7 +772,6 @@ class QuietSdpaAttention(QuietAttention):
         )
 
         attn_output = attn_output.transpose(1, 2).contiguous()
-        pdb.set_trace()
         attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
 
         attn_output = self.o_proj(attn_output)
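The two QuietSdpaAttention hunks only strip debugging leftovers (a print and a pdb.set_trace); the class itself delegates the attention math to PyTorch's fused scaled_dot_product_attention. A minimal sketch of that path with hypothetical shapes, omitting the KV-cache and GQA handling of the real forward:

# Sketch of the SDPA path that QuietSdpaAttention wraps.
import torch
import torch.nn.functional as F

bsz, num_heads, q_len, head_dim = 2, 8, 5, 16
q = torch.randn(bsz, num_heads, q_len, head_dim)
k = torch.randn(bsz, num_heads, q_len, head_dim)
v = torch.randn(bsz, num_heads, q_len, head_dim)

# is_causal=True builds the causal mask internally; alternatively pass a
# (bsz, 1, q_len, kv_seq_len) additive mask via attn_mask=...
attn_output = F.scaled_dot_product_attention(q, k, v, is_causal=True)
attn_output = attn_output.transpose(1, 2).reshape(bsz, q_len, num_heads * head_dim)
print(attn_output.shape)  # torch.Size([2, 5, 128])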
@@ -780,7 +790,7 @@ class QuietDecoderLayer(nn.Module):
     def __init__(self, config: QuietConfig, layer_idx: int):
         super().__init__()
         self.hidden_size = config.hidden_size
-
+
         self.self_attn = QUIET_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
 
         self.mlp = QuietMLP(config)
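QUIET_ATTENTION_CLASSES maps config._attn_implementation to one of the attention classes above, so each decoder layer picks the eager, SDPA, or flash implementation from the config. A hedged usage sketch of how that flag is normally set at load time; the checkpoint path is a placeholder, not a reference to a real repo:

# How the attention implementation is usually selected when loading custom code
# like modeling_quiet.py; the model path below is a placeholder.
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "path/to/quiet-star-checkpoint",  # placeholder
    attn_implementation="sdpa",       # routes to QuietSdpaAttention via QUIET_ATTENTION_CLASSES
    trust_remote_code=True,           # required because the modeling file ships with the repo
)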
@@ -1067,14 +1077,12 @@ class QuietModel(QuietPreTrainedModel):
         elif self._attn_implementation == "sdpa" and not output_attentions and attention_mask.dim() == 2 and False:
             # output_attentions=True can not be supported when using SDPA, and we fall back on
             # the manual implementation that requires a 4D causal mask in all cases.
-
             attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
                 attention_mask,
                 (batch_size, seq_length),
                 inputs_embeds,
                 past_key_values_length,
            )
-            print("Attention mask shape in _prepare_4d_causal_attention_mask_for_sdpa:", attention_mask.size())
         elif attention_mask is None or attention_mask.dim() == 2:
             # 4d mask is passed through the layers
             attention_mask = _prepare_4d_causal_attention_mask(
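Both mask helpers expand a 2D padding mask of shape (batch_size, seq_length) into the 4D additive mask of shape (batch_size, 1, q_len, kv_seq_len) that the attention layers check for. Below is a rough hand-rolled equivalent for illustration only; the private Transformers helpers cover more cases (past key values, dtype/device handling, SDPA-specific mask skipping):

# Approximate construction of a 4D additive causal mask from a 2D padding mask.
import torch

batch_size, seq_length = 2, 4
dtype = torch.float32
min_value = torch.finfo(dtype).min

# 2D padding mask: 1 = real token, 0 = padding
attention_mask_2d = torch.tensor([[1, 1, 1, 1],
                                  [1, 1, 0, 0]])

causal = torch.full((seq_length, seq_length), min_value, dtype=dtype).triu(1)
mask_4d = causal[None, None, :, :].repeat(batch_size, 1, 1, 1)
mask_4d = mask_4d.masked_fill(attention_mask_2d[:, None, None, :] == 0, min_value)

print(mask_4d.shape)  # torch.Size([2, 1, 4, 4]); 0 = attend, dtype-min = masked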
@@ -2368,4 +2376,4 @@ class QuietForSequenceClassification(QuietPreTrainedModel):
             past_key_values=transformer_outputs.past_key_values,
             hidden_states=transformer_outputs.hidden_states,
             attentions=transformer_outputs.attentions,
-        )
+        )