Update modeling_quiet.py
modeling_quiet.py  +53 -76  CHANGED
@@ -23,7 +23,6 @@ import math
 import copy
 import os
 import time
-import pandas as pd
 import seaborn as sns
 import matplotlib.pyplot as plt
 import wandb
@@ -69,73 +68,6 @@ logger = logging.get_logger(__name__)
 
 _CONFIG_FOR_DOC = "QuietConfig"
 
-from reportlab.pdfgen import canvas
-from reportlab.lib.pagesizes import letter
-from reportlab.lib.colors import HexColor
-
-def save_tokens_with_rewards_to_pdf(input_ids, token_rewards, tokenizer, output_file="text.pdf", eps=0.2, eps2=0.5):
-    c = canvas.Canvas(output_file, pagesize=letter)
-    c.setFont("Courier", 8)
-    x, y = 50, 750
-    previous_text = ""
-    current_text = ""
-    for token_idx, reward in enumerate(token_rewards):
-        current_text = tokenizer.decode(input_ids[: token_idx + 1])
-        if current_text != previous_text:
-            diff_text = current_text[len(previous_text) :]
-            if "\n" in diff_text:
-                lines = diff_text.split("\n")
-                for line_idx, line in enumerate(lines):
-                    if line_idx > 0:
-                        x = 50
-                        y -= 12
-                    if abs(reward) < eps:
-                        opacity = 0
-                    elif abs(reward) > eps2:
-                        opacity = 0.8
-                    else:
-                        opacity = 0.8 * (abs(reward) - eps) / (eps2 - eps)
-                    text_width = c.stringWidth(line)
-                    if reward > 0:
-                        highlight_color = HexColor("#4CCD99")
-                    else:
-                        highlight_color = HexColor("#FFC700")
-                    highlight_color.alpha = opacity
-                    c.setFillColor(highlight_color)
-                    c.rect(x, y - 2, text_width, 10, fill=True, stroke=False)
-                    c.setFillColor(HexColor("#000000"))
-                    c.drawString(x, y, line)
-                    x += text_width
-            else:
-                if abs(reward) < eps:
-                    opacity = 0
-                elif abs(reward) > eps2:
-                    opacity = 0.8
-                else:
-                    opacity = 0.8 * (abs(reward) - eps) / (eps2 - eps)
-                text_width = c.stringWidth(diff_text)
-                if reward > 0:
-                    highlight_color = HexColor("#4CCD99")
-                else:
-                    highlight_color = HexColor("#FFC700")
-                highlight_color.alpha = opacity
-                c.setFillColor(highlight_color)
-                c.rect(x, y - 2, text_width, 10, fill=True, stroke=False)
-                c.setFillColor(HexColor("#000000"))
-                c.drawString(x, y, diff_text)
-                x += text_width
-            if x > 550:
-                x = 50
-                y -= 12
-                if y < 50:
-                    c.showPage()
-                    y = 750
-                    x = 50
-            previous_text = current_text
-    c.showPage()
-    c.save()
-
-
 # Copied from transformers.models.llama.modeling_llama._get_unpad_data
 def _get_unpad_data(attention_mask):
     seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
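
The removed helper mapped each token's reward magnitude to a highlight opacity with a thresholded linear ramp: below `eps` the highlight is invisible, above `eps2` it saturates at 0.8, and in between it scales linearly. A minimal standalone sketch of that mapping, with a hypothetical function name and no reportlab dependency:

```python
def reward_to_opacity(reward: float, eps: float = 0.2, eps2: float = 0.5) -> float:
    """Same ramp the removed PDF helper used: 0 below eps, 0.8 above eps2, linear in between."""
    magnitude = abs(reward)
    if magnitude < eps:
        return 0.0
    if magnitude > eps2:
        return 0.8
    return 0.8 * (magnitude - eps) / (eps2 - eps)


# Rewards near the eps threshold are barely highlighted; large ones saturate.
print([round(reward_to_opacity(r), 3) for r in (0.1, 0.3, 0.45, 0.9)])
# [0.0, 0.267, 0.667, 0.8]
```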
@@ -325,12 +257,22 @@ class QuietAttention(nn.Module):
         use_cache: bool = False,
         **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+
+        if past_key_value is not None:
+            expected_attention_mask_size = (bsz, 1, q_len, q_len + past_key_value.get_usable_length(q_len, self.layer_idx))
+            if attention_mask.size() != expected_attention_mask_size:
+                # Assuming the attention mask is larger than expected, slice it to match the expected size
+                attention_mask = attention_mask[:, :, :, -expected_attention_mask_size[-1]:]
+
         if "padding_mask" in kwargs:
             warnings.warn(
                 "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
             )
         bsz, q_len, _ = hidden_states.size()
 
+        query_states = query_states.to(attention_mask.dtype)
+        key_states = key_states.to(attention_mask.dtype)
+        value_states = value_states.to(attention_mask.dtype)
         query_states = self.q_proj(hidden_states)
         key_states = self.k_proj(hidden_states)
         value_states = self.v_proj(hidden_states)
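
The new guard in `QuietAttention.forward` trims an oversized 4D mask down to the last `q_len + cached` key positions, where the cached length comes from `past_key_value.get_usable_length(...)`. A rough standalone illustration of that slicing, with invented shapes and `expected_kv_len` standing in for the expected key length:

```python
import torch

bsz, q_len, cached_len = 2, 3, 5
expected_kv_len = q_len + cached_len  # stands in for q_len + past_key_value.get_usable_length(...)

# A mask that is wider than expected, e.g. built for the full padded sequence.
attention_mask = torch.zeros(bsz, 1, q_len, 12)

if attention_mask.size(-1) != expected_kv_len:
    # Keep only the trailing key positions, mirroring the sliced indexing in the diff.
    attention_mask = attention_mask[:, :, :, -expected_kv_len:]

print(attention_mask.shape)  # torch.Size([2, 1, 3, 8])
```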
@@ -368,11 +310,16 @@ class QuietAttention(nn.Module):
         )
 
         if attention_mask is not None:
-            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
+            if attention_mask.dim() == 3:
+                attention_mask = attention_mask.unsqueeze(1)
+            elif attention_mask.dim() == 2:
+                attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
+
+            if attention_mask.size(0) != bsz or attention_mask.size(-1) != kv_seq_len:
                 raise ValueError(
-                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
+                    f"Attention mask should be of size ({bsz}, 1, q_len, {kv_seq_len}), but is {attention_mask.size()}"
                 )
-
+
             attn_weights = attn_weights + attention_mask
 
         # upcast attention to fp32
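
Both attention classes now accept 2D `(bsz, kv_seq_len)` or 3D `(bsz, q_len, kv_seq_len)` masks and unsqueeze them into a 4D layout that broadcasts against the `(bsz, num_heads, q_len, kv_seq_len)` attention scores. A small sketch of that broadcasting, with made-up sizes and an additive-style mask:

```python
import torch

bsz, num_heads, q_len, kv_seq_len = 2, 4, 3, 3
attn_weights = torch.randn(bsz, num_heads, q_len, kv_seq_len)

# 2D additive padding mask: 0 keeps a position, a large negative value masks it out.
attention_mask = torch.zeros(bsz, kv_seq_len)
attention_mask[:, -1] = torch.finfo(attn_weights.dtype).min

if attention_mask.dim() == 3:
    attention_mask = attention_mask.unsqueeze(1)                # (bsz, 1, q_len, kv_seq_len)
elif attention_mask.dim() == 2:
    attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)   # (bsz, 1, 1, kv_seq_len)

attn_weights = attn_weights + attention_mask  # broadcasts over heads (and queries for a 2D mask)
print(attn_weights.shape)  # torch.Size([2, 4, 3, 3])
```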
@@ -749,11 +696,21 @@ class QuietSdpaAttention(QuietAttention):
         value_states = repeat_kv(value_states, self.num_key_value_groups)
 
         if attention_mask is not None:
-            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
+            if attention_mask.dim() == 3:
+                attention_mask = attention_mask.unsqueeze(1)
+            elif attention_mask.dim() == 2:
+                attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
+
+        if attention_mask is not None:
+            if attention_mask.dim() == 3:
+                attention_mask = attention_mask.unsqueeze(1)
+            elif attention_mask.dim() == 2:
+                attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
+
+            if attention_mask.size(0) != bsz or attention_mask.size(-1) != kv_seq_len:
                 raise ValueError(
-                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
+                    f"Attention mask should be of size ({bsz}, 1, q_len, {kv_seq_len}), but is {attention_mask.size()}"
                 )
-
         # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
         # Reference: https://github.com/pytorch/pytorch/issues/112577.
         if query_states.device.type == "cuda" and attention_mask is not None:
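
The context lines above lead into the upstream torch 2.1.2 workaround: the transformers SDPA attention classes make query/key/value contiguous on CUDA when a custom mask is passed. A toy sketch of that pattern with invented shapes, not the model's own tensors:

```python
import torch
import torch.nn.functional as F

bsz, num_heads, q_len, head_dim = 2, 4, 3, 8
# transpose() yields non-contiguous views, like the (bsz, seq, heads, dim) -> (bsz, heads, seq, dim) reshape in attention.
query = torch.randn(bsz, q_len, num_heads, head_dim).transpose(1, 2)
key = torch.randn(bsz, q_len, num_heads, head_dim).transpose(1, 2)
value = torch.randn(bsz, q_len, num_heads, head_dim).transpose(1, 2)
attn_mask = torch.zeros(bsz, 1, q_len, q_len)

if query.device.type == "cuda" and attn_mask is not None:
    # Work around https://github.com/pytorch/pytorch/issues/112577 (torch==2.1.2).
    query, key, value = query.contiguous(), key.contiguous(), value.contiguous()

out = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask)
print(out.shape)  # torch.Size([2, 4, 3, 8])
```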
@@ -1327,7 +1284,27 @@ class QuietForCausalLM(QuietPreTrainedModel):
         # Generate the continuation
         continuation_length = self.n_ahead - 2
         new_key_values = past_key_values
-
+
+        if self.n_ahead != 1 or self.n_ahead_talk != 1 or self.comparison_mode:
+            if attention_mask is None:
+                base_attention_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=0).to(input_ids.device)
+                base_attention_mask = base_attention_mask.view(1, 1, seq_len, seq_len)
+                base_attention_mask = base_attention_mask.repeat(input_ids.shape[0], 1, 1, 1)
+                attention_mask = base_attention_mask
+            elif attention_mask.dim() == 2:
+                if seq_len + past_key_values_length != attention_mask.shape[-1]:
+                    attention_mask = torch.cat(
+                        [torch.ones((attention_mask.shape[0], past_key_values_length), dtype=attention_mask.dtype, device=attention_mask.device), attention_mask],
+                        dim=-1
+                    )
+                attention_mask = _prepare_4d_causal_attention_mask(
+                    attention_mask,
+                    (batch_size, seq_len),
+                    inputs_embeds,
+                    past_key_values_length,
+                    sliding_window=self.config.sliding_window,
+                )
+
         start_time = time.time()
         for continuation_idx in range(continuation_length):
             outputs = self.model(
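
When the supplied mask is 2D and shorter than `seq_len + past_key_values_length`, the new code prepends ones for the cached positions before handing it to `_prepare_4d_causal_attention_mask`. A toy illustration of just that padding step, with invented lengths:

```python
import torch

batch_size, seq_len, past_key_values_length = 2, 3, 4

# 2D mask that only covers the current tokens, not the cached ones.
attention_mask = torch.ones(batch_size, seq_len, dtype=torch.long)

if seq_len + past_key_values_length != attention_mask.shape[-1]:
    # Cached positions are assumed attendable, so they are padded with ones on the left.
    attention_mask = torch.cat(
        [
            torch.ones(
                (attention_mask.shape[0], past_key_values_length),
                dtype=attention_mask.dtype,
                device=attention_mask.device,
            ),
            attention_mask,
        ],
        dim=-1,
    )

print(attention_mask.shape)  # torch.Size([2, 7])
```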
@@ -2376,4 +2353,4 @@ class QuietForSequenceClassification(QuietPreTrainedModel):
             past_key_values=transformer_outputs.past_key_values,
             hidden_states=transformer_outputs.hidden_states,
             attentions=transformer_outputs.attentions,
-        )
+        )