Update modeling_quiet.py
modeling_quiet.py  CHANGED  +33 -9
@@ -44,7 +44,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
 from transformers.activations import ACT2FN
 from transformers.cache_utils import Cache, DynamicCache
-from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask,
+from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask,
 from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
 from transformers.modeling_utils import PreTrainedModel
 from transformers.utils import (

@@ -134,6 +134,34 @@ def save_tokens_with_rewards_to_pdf(input_ids, token_rewards, tokenizer, output_
         previous_text = current_text
     c.showPage()
     c.save()
+
+def _prepare_4d_causal_attention_mask_for_sdpa(
+    attn_mask: Optional[torch.Tensor],
+    shape: Tuple[int, int],
+    inputs_embeds: Optional[torch.Tensor] = None,
+    past_key_values_length: int = 0,
+) -> torch.Tensor:
+    batch_size, seq_len = shape
+    if attn_mask is None:
+        attn_mask = torch.ones((batch_size, seq_len), dtype=torch.bool, device=inputs_embeds.device)
+    else:
+        attn_mask = attn_mask.bool()
+
+    # Extend the attention mask to account for past key/value states
+    if past_key_values_length > 0:
+        extended_attn_mask = torch.cat(
+            [
+                attn_mask.new_zeros(batch_size, seq_len, past_key_values_length),
+                attn_mask.unsqueeze(2),
+            ],
+            dim=2,
+        )
+        attn_mask = extended_attn_mask
+
+    attn_mask = attn_mask.unsqueeze(1).unsqueeze(1)
+    causal_mask = torch.tril(torch.ones(seq_len, seq_len + past_key_values_length, device=attn_mask.device)).bool()
+    attn_mask = attn_mask & causal_mask.unsqueeze(0).unsqueeze(0)
+    return attn_mask
 
 
 # Copied from transformers.models.llama.modeling_llama._get_unpad_data
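The added `_prepare_4d_causal_attention_mask_for_sdpa` shadows the transformers helper of the same name and returns a boolean 4D mask. Below is a minimal sketch of calling it in the no-cache case; it assumes modeling_quiet.py is importable from the working directory and is illustrative only, not part of the commit:

import torch

# Illustrative usage only; assumes modeling_quiet.py is on the Python path.
from modeling_quiet import _prepare_4d_causal_attention_mask_for_sdpa

batch_size, seq_len = 2, 5
padding_mask = torch.ones(batch_size, seq_len, dtype=torch.long)
padding_mask[1, -1] = 0  # mark the last position of the second sequence as padding

# With past_key_values_length=0 the helper ANDs the broadcast padding mask with a
# lower-triangular causal mask and returns a boolean mask of shape (batch, 1, seq, seq).
mask_4d = _prepare_4d_causal_attention_mask_for_sdpa(
    padding_mask,
    (batch_size, seq_len),
    inputs_embeds=None,
    past_key_values_length=0,
)
print(mask_4d.shape)  # torch.Size([2, 1, 5, 5])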
@@ -1070,10 +1098,11 @@ class QuietModel(QuietPreTrainedModel):
                     " this may lead to unexpected behaviour for Flash Attention version of Quiet. Make sure to "
                     " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
                 )
-
+
+        if self._attn_implementation == "flash_attention_2":
             # 2d mask is passed through the layers
             attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
-        elif self._attn_implementation ==
+        elif self._attn_implementation == "sdpa" and not output_attentions and attention_mask.dim() == 2 and False:
             # output_attentions=True can not be supported when using SDPA, and we fall back on
             # the manual implementation that requires a 4D causal mask in all cases.
             attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(

@@ -1082,12 +1111,7 @@ class QuietModel(QuietPreTrainedModel):
                 inputs_embeds,
                 past_key_values_length,
             )
-
-            # Check the shape of the attention mask
-            if attention_mask is not None and attention_mask.dim() == 2:
-                # Reshape the attention mask to 4D
-                attention_mask = attention_mask.unsqueeze(1).unsqueeze(1)
-
+        elif attention_mask is None or attention_mask.dim() == 2:
             # 4d mask is passed through the layers
             attention_mask = _prepare_4d_causal_attention_mask(
                 attention_mask,
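After this change, a missing or 2D attention mask reaches the standard `_prepare_4d_causal_attention_mask` path, since the trailing `and False` in the SDPA condition keeps that branch from executing. The sketch below shows what that default branch produces; it is illustrative only and assumes a transformers version (such as the 4.36/4.37 line) that still exposes this private helper with this signature:

import torch
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask

batch_size, seq_len, hidden_size = 2, 5, 8
attention_mask = torch.ones(batch_size, seq_len, dtype=torch.long)
attention_mask[1, -1] = 0  # one padding position
inputs_embeds = torch.zeros(batch_size, seq_len, hidden_size)  # used only for dtype/device

# Expands the 2D padding mask into an additive float mask of shape
# (batch, 1, seq_len, seq_len) with large negative values at masked positions.
mask_4d = _prepare_4d_causal_attention_mask(
    attention_mask,
    (batch_size, seq_len),
    inputs_embeds,
    0,  # past_key_values_length
)
print(mask_4d.shape)  # torch.Size([2, 1, 5, 5])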