Crystalcareai committed on
Commit 7ed349e · verified · 1 Parent(s): d842ce9

Update modeling_quiet.py

Files changed (1)
  1. modeling_quiet.py +53 -76
modeling_quiet.py CHANGED
@@ -23,7 +23,6 @@ import math
 import copy
 import os
 import time
-import pandas as pd
 import seaborn as sns
 import matplotlib.pyplot as plt
 import wandb
@@ -69,73 +68,6 @@ logger = logging.get_logger(__name__)
 
 _CONFIG_FOR_DOC = "QuietConfig"
 
-from reportlab.pdfgen import canvas
-from reportlab.lib.pagesizes import letter
-from reportlab.lib.colors import HexColor
-
-def save_tokens_with_rewards_to_pdf(input_ids, token_rewards, tokenizer, output_file="text.pdf", eps=0.2, eps2=0.5):
-    c = canvas.Canvas(output_file, pagesize=letter)
-    c.setFont("Courier", 8)
-    x, y = 50, 750
-    previous_text = ""
-    current_text = ""
-    for token_idx, reward in enumerate(token_rewards):
-        current_text = tokenizer.decode(input_ids[: token_idx + 1])
-        if current_text != previous_text:
-            diff_text = current_text[len(previous_text) :]
-            if "\n" in diff_text:
-                lines = diff_text.split("\n")
-                for line_idx, line in enumerate(lines):
-                    if line_idx > 0:
-                        x = 50
-                        y -= 12
-                    if abs(reward) < eps:
-                        opacity = 0
-                    elif abs(reward) > eps2:
-                        opacity = 0.8
-                    else:
-                        opacity = 0.8 * (abs(reward) - eps) / (eps2 - eps)
-                    text_width = c.stringWidth(line)
-                    if reward > 0:
-                        highlight_color = HexColor("#4CCD99")
-                    else:
-                        highlight_color = HexColor("#FFC700")
-                    highlight_color.alpha = opacity
-                    c.setFillColor(highlight_color)
-                    c.rect(x, y - 2, text_width, 10, fill=True, stroke=False)
-                    c.setFillColor(HexColor("#000000"))
-                    c.drawString(x, y, line)
-                    x += text_width
-            else:
-                if abs(reward) < eps:
-                    opacity = 0
-                elif abs(reward) > eps2:
-                    opacity = 0.8
-                else:
-                    opacity = 0.8 * (abs(reward) - eps) / (eps2 - eps)
-                text_width = c.stringWidth(diff_text)
-                if reward > 0:
-                    highlight_color = HexColor("#4CCD99")
-                else:
-                    highlight_color = HexColor("#FFC700")
-                highlight_color.alpha = opacity
-                c.setFillColor(highlight_color)
-                c.rect(x, y - 2, text_width, 10, fill=True, stroke=False)
-                c.setFillColor(HexColor("#000000"))
-                c.drawString(x, y, diff_text)
-                x += text_width
-            if x > 550:
-                x = 50
-                y -= 12
-            if y < 50:
-                c.showPage()
-                y = 750
-                x = 50
-            previous_text = current_text
-    c.showPage()
-    c.save()
-
-
 # Copied from transformers.models.llama.modeling_llama._get_unpad_data
 def _get_unpad_data(attention_mask):
     seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
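Note: the deleted helper rendered each decoded token into a PDF with a highlight whose opacity scaled with |reward| between eps and eps2. Purely for illustration, here is a minimal sketch of the same opacity rule using the matplotlib import the module keeps; the `token_strings`/`token_rewards` names are hypothetical stand-ins, not identifiers from modeling_quiet.py.

# Illustrative sketch only, not part of the commit: per-token reward highlights via matplotlib.
import matplotlib.pyplot as plt
import numpy as np

def plot_token_rewards(token_strings, token_rewards, eps=0.2, eps2=0.5):
    rewards = np.asarray(token_rewards, dtype=float)
    # Same rule as the removed helper: 0 below eps, capped at 0.8 above eps2, linear in between.
    opacity = np.clip(0.8 * (np.abs(rewards) - eps) / (eps2 - eps), 0.0, 0.8)
    opacity[np.abs(rewards) < eps] = 0.0
    fig, ax = plt.subplots(figsize=(max(6, len(token_strings) * 0.4), 1.5))
    for idx, (tok, rew, alpha) in enumerate(zip(token_strings, rewards, opacity)):
        color = "#4CCD99" if rew > 0 else "#FFC700"  # same palette as the PDF version
        ax.text(idx, 0.5, tok, ha="center", va="center",
                bbox=dict(facecolor=color, alpha=float(alpha), edgecolor="none"))
    ax.set_xlim(-1, len(token_strings))
    ax.axis("off")
    fig.savefig("token_rewards.png", bbox_inches="tight")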
@@ -325,12 +257,22 @@ class QuietAttention(nn.Module):
         use_cache: bool = False,
         **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+
+        if past_key_value is not None:
+            expected_attention_mask_size = (bsz, 1, q_len, q_len + past_key_value.get_usable_length(q_len, self.layer_idx))
+            if attention_mask.size() != expected_attention_mask_size:
+                # Assuming the attention mask is larger than expected, slice it to match the expected size
+                attention_mask = attention_mask[:, :, :, -expected_attention_mask_size[-1]:]
+
         if "padding_mask" in kwargs:
             warnings.warn(
                 "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
             )
         bsz, q_len, _ = hidden_states.size()
 
+        query_states = query_states.to(attention_mask.dtype)
+        key_states = key_states.to(attention_mask.dtype)
+        value_states = value_states.to(attention_mask.dtype)
         query_states = self.q_proj(hidden_states)
         key_states = self.k_proj(hidden_states)
         value_states = self.v_proj(hidden_states)
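Note: the added guard assumes the mask is longer than expected and keeps only the trailing `q_len + past_kv_len` key/value positions. A self-contained sketch of that slicing on dummy tensors (shapes are illustrative only, not taken from the model):

# Illustrative sketch: trim a 4D additive mask to the expected key/value length.
import torch

bsz, q_len, past_kv_len = 2, 5, 3
expected = (bsz, 1, q_len, q_len + past_kv_len)   # (2, 1, 5, 8)
mask = torch.zeros(bsz, 1, q_len, 16)             # oversized on the last dimension
if mask.size() != expected:
    mask = mask[:, :, :, -expected[-1]:]          # keep only the trailing 8 positions
assert mask.shape == expected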
@@ -368,11 +310,16 @@ class QuietAttention(nn.Module):
             )
 
         if attention_mask is not None:
-            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
+            if attention_mask.dim() == 3:
+                attention_mask = attention_mask.unsqueeze(1)
+            elif attention_mask.dim() == 2:
+                attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
+
+            if attention_mask.size(0) != bsz or attention_mask.size(-1) != kv_seq_len:
                 raise ValueError(
-                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
+                    f"Attention mask should be of size ({bsz}, 1, q_len, {kv_seq_len}), but is {attention_mask.size()}"
                 )
-
+
             attn_weights = attn_weights + attention_mask
 
         # upcast attention to fp32
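Note: the reworked check accepts 2D `(bsz, kv_seq_len)` and 3D `(bsz, q_len, kv_seq_len)` masks and unsqueezes them to the 4D layout that is added to `attn_weights`. A shape-only sketch with dummy sizes, not code from the model:

# Illustrative sketch: normalize 2D/3D masks to the 4D layout used by the attention weights.
import torch

bsz, q_len, kv_seq_len = 2, 4, 6
mask_2d = torch.zeros(bsz, kv_seq_len)
mask_3d = torch.zeros(bsz, q_len, kv_seq_len)

m2 = mask_2d.unsqueeze(1).unsqueeze(2)  # -> (bsz, 1, 1, kv_seq_len), broadcasts over the query dim
m3 = mask_3d.unsqueeze(1)               # -> (bsz, 1, q_len, kv_seq_len)
print(m2.shape, m3.shape)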
@@ -749,11 +696,21 @@ class QuietSdpaAttention(QuietAttention):
         value_states = repeat_kv(value_states, self.num_key_value_groups)
 
         if attention_mask is not None:
-            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
+            if attention_mask.dim() == 3:
+                attention_mask = attention_mask.unsqueeze(1)
+            elif attention_mask.dim() == 2:
+                attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
+
+            if attention_mask is not None:
+                if attention_mask.dim() == 3:
+                    attention_mask = attention_mask.unsqueeze(1)
+                elif attention_mask.dim() == 2:
+                    attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
+
+            if attention_mask.size(0) != bsz or attention_mask.size(-1) != kv_seq_len:
                 raise ValueError(
-                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
+                    f"Attention mask should be of size ({bsz}, 1, q_len, {kv_seq_len}), but is {attention_mask.size()}"
                 )
-
         # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
         # Reference: https://github.com/pytorch/pytorch/issues/112577.
         if query_states.device.type == "cuda" and attention_mask is not None:
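Note: the SDPA path applies the same normalization before handing the mask to `torch.nn.functional.scaled_dot_product_attention`, which accepts an additive float mask broadcastable to `(bsz, num_heads, q_len, kv_seq_len)`. A minimal, self-contained call with dummy shapes (not code from the model):

# Illustrative sketch: SDPA with a broadcastable 4D additive mask.
import torch
import torch.nn.functional as F

bsz, n_heads, q_len, kv_seq_len, head_dim = 2, 4, 5, 5, 8
q = torch.randn(bsz, n_heads, q_len, head_dim)
k = torch.randn(bsz, n_heads, kv_seq_len, head_dim)
v = torch.randn(bsz, n_heads, kv_seq_len, head_dim)
attn_mask = torch.zeros(bsz, 1, q_len, kv_seq_len)  # additive mask: 0 = attend, large negative = block
out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
print(out.shape)  # torch.Size([2, 4, 5, 8])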
@@ -1327,7 +1284,27 @@ class QuietForCausalLM(QuietPreTrainedModel):
         # Generate the continuation
         continuation_length = self.n_ahead - 2
         new_key_values = past_key_values
-
+
+        if self.n_ahead != 1 or self.n_ahead_talk != 1 or self.comparison_mode:
+            if attention_mask is None:
+                base_attention_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=0).to(input_ids.device)
+                base_attention_mask = base_attention_mask.view(1, 1, seq_len, seq_len)
+                base_attention_mask = base_attention_mask.repeat(input_ids.shape[0], 1, 1, 1)
+                attention_mask = base_attention_mask
+            elif attention_mask.dim() == 2:
+                if seq_len + past_key_values_length != attention_mask.shape[-1]:
+                    attention_mask = torch.cat(
+                        [torch.ones((attention_mask.shape[0], past_key_values_length), dtype=attention_mask.dtype, device=attention_mask.device), attention_mask],
+                        dim=-1
+                    )
+                attention_mask = _prepare_4d_causal_attention_mask(
+                    attention_mask,
+                    (batch_size, seq_len),
+                    inputs_embeds,
+                    past_key_values_length,
+                    sliding_window=self.config.sliding_window,
+                )
+
         start_time = time.time()
         for continuation_idx in range(continuation_length):
             outputs = self.model(
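Note: when no mask is supplied, the new branch builds a square base mask with `torch.triu` and repeats it across the batch; a 2D padding mask is instead left-padded for the cache and expanded by `_prepare_4d_causal_attention_mask`. A shape-only sketch of the `triu` construction with dummy sizes:

# Illustrative sketch: build a (bsz, 1, seq_len, seq_len) base mask as in the new branch.
import torch

bsz, seq_len = 2, 6
base = torch.triu(torch.ones(seq_len, seq_len), diagonal=0)  # upper-triangular ones, diagonal included
base = base.view(1, 1, seq_len, seq_len).repeat(bsz, 1, 1, 1)
print(base.shape)  # torch.Size([2, 1, 6, 6])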
@@ -2376,4 +2353,4 @@ class QuietForSequenceClassification(QuietPreTrainedModel):
             past_key_values=transformer_outputs.past_key_values,
             hidden_states=transformer_outputs.hidden_states,
             attentions=transformer_outputs.attentions,
-        )
+        )