Crystalcareai committed on
Commit 7fb20bf · verified · 1 Parent(s): f5c1913

Update modeling_quiet.py

Files changed (1):
  1. modeling_quiet.py +110 -194
modeling_quiet.py CHANGED
@@ -44,7 +44,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
 from transformers.activations import ACT2FN
 from transformers.cache_utils import Cache, DynamicCache
-from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa
+from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
 from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
 from transformers.modeling_utils import PreTrainedModel
 from transformers.utils import (
@@ -73,67 +73,63 @@ from reportlab.pdfgen import canvas
 from reportlab.lib.pagesizes import letter
 from reportlab.lib.colors import HexColor
 
-def save_tokens_with_rewards_to_pdf(input_ids, token_rewards, tokenizer, output_file="text.pdf", eps=0.2, eps2=0.5):
-    c = canvas.Canvas(output_file, pagesize=letter)
-    c.setFont("Courier", 8)
-    x, y = 50, 750
-    previous_text = ""
-    current_text = ""
-    for token_idx, reward in enumerate(token_rewards):
-        current_text = tokenizer.decode(input_ids[: token_idx + 1])
-        if current_text != previous_text:
-            diff_text = current_text[len(previous_text):]
-            if "\n" in diff_text:
-                lines = diff_text.split("\n")
-                for line_idx, line in enumerate(lines):
-                    if line_idx > 0:
-                        x = 50
-                        y -= 12
-                    if abs(reward) < eps:
-                        opacity = 0
-                    elif abs(reward) > eps2:
-                        opacity = 0.8
-                    else:
-                        opacity = 0.8 * (abs(reward) - eps) / (eps2 - eps)
-                    text_width = c.stringWidth(line)
-                    if reward > 0:
-                        highlight_color = HexColor("#4CCD99")
-                    else:
-                        highlight_color = HexColor("#FFC700")
-                    highlight_color.alpha = opacity
-                    c.setFillColor(highlight_color)
-                    c.rect(x, y - 2, text_width, 10, fill=True, stroke=False)
-                    c.setFillColor(HexColor("#000000"))
-                    c.drawString(x, y, line)
-                    x += text_width
-            else:
-                if abs(reward) < eps:
-                    opacity = 0
-                elif abs(reward) > eps2:
-                    opacity = 0.8
-                else:
-                    opacity = 0.8 * (abs(reward) - eps) / (eps2 - eps)
-                text_width = c.stringWidth(diff_text)
-                if reward > 0:
-                    highlight_color = HexColor("#4CCD99")
-                else:
-                    highlight_color = HexColor("#FFC700")
-                highlight_color.alpha = opacity
-                c.setFillColor(highlight_color)
-                c.rect(x, y - 2, text_width, 10, fill=True, stroke=False)
-                c.setFillColor(HexColor("#000000"))
-                c.drawString(x, y, diff_text)
-                x += text_width
-        if x > 550:
-            x = 50
-            y -= 12
-            if y < 50:
-                c.showPage()
-                y = 750
-                x = 50
-        previous_text = current_text
-    c.showPage()
-    c.save()
+
+def _prepare_4d_causal_attention_mask_for_sdpa(attention_mask, input_shape, inputs_embeds, past_key_values_length):
+    # Compute the attention mask correctly
+    bsz, tgt_len = input_shape
+
+    # Create a 4D attention mask from a lower-dimensional tensor mask.
+    # The shape of the output attention mask is (batch_size, 1, tgt_len, src_len)
+    # The values are either 0 or 1, where 0 means padding and 1 means non-padding.
+    combined_attention_mask = None
+    if attention_mask is not None:
+        # If attention_mask already has a shape of (batch_size, 1, tgt_len, src_len),
+        # we can just use it directly.
+        if attention_mask.dim() == 4:
+            combined_attention_mask = attention_mask
+        # If attention_mask has a shape of (batch_size, tgt_len, src_len),
+        # we need to expand it to (batch_size, 1, tgt_len, src_len).
+        elif attention_mask.dim() == 3:
+            expanded_attn_mask = attention_mask[:, None, :, :]
+            combined_attention_mask = expanded_attn_mask
+        # If attention_mask has a shape of (batch_size, tgt_len),
+        # we need to expand it so it broadcasts to (batch_size, 1, tgt_len, src_len).
+        elif attention_mask.dim() == 2:
+            # Provided a padding mask of dimensions [batch_size, seq_length]
+            # - if the model is a decoder, apply a causal mask in addition to the padding mask
+            # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
+            if past_key_values_length > 0:
+                attention_mask = attention_mask.to(dtype=torch.long)
+                attention_mask = attention_mask[:, past_key_values_length:]
+            expanded_attn_mask = attention_mask[:, None, None, :]
+            combined_attention_mask = expanded_attn_mask
+        else:
+            raise ValueError(
+                "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
+                    input_shape, attention_mask.shape
+                )
+            )
+
+    # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
+    # masked positions, this operation will create a tensor which is 0.0 for
+    # positions we want to attend and -10000.0 for masked positions.
+    # Since we are adding it to the raw scores before the softmax, this is
+    # effectively the same as removing these entirely.
+    if combined_attention_mask is not None:
+        # Ensure the attention mask values are within a reasonable range
+        combined_attention_mask = combined_attention_mask.clamp(min=0, max=1)
+
+        # Convert the attention mask to bfloat16
+        combined_attention_mask = combined_attention_mask.to(torch.bfloat16)
+
+        # Turn the 0/1 mask into an additive bias: 0.0 where attended, -10000.0 where masked
+        combined_attention_mask = (1.0 - combined_attention_mask) * -10000.0
+    else:
+        combined_attention_mask = torch.zeros(
+            (bsz, 1, tgt_len, tgt_len), dtype=torch.bfloat16, device=inputs_embeds.device
+        )
+
+    return combined_attention_mask
 
 
 # Copied from transformers.models.llama.modeling_llama._get_unpad_data
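
Note on the replacement above: the commit swaps the stock transformers SDPA helper for a local version that only expands and rescales the padding mask; unlike the helper it shadows, it never adds a causal (lower-triangular) component. A minimal sketch of how it behaves on a 2D padding mask (the tensors and sizes below are illustrative, not from the commit):

    import torch

    # Batch of 2, target length 4, no KV cache; 0 marks padding.
    attention_mask = torch.tensor([[1, 1, 1, 0],
                                   [1, 1, 0, 0]])
    inputs_embeds = torch.zeros(2, 4, 8)  # only dtype/device are read from this
    bias = _prepare_4d_causal_attention_mask_for_sdpa(
        attention_mask, (2, 4), inputs_embeds, past_key_values_length=0
    )
    print(bias.shape)     # torch.Size([2, 1, 1, 4]) -- broadcasts over query positions
    print(bias[0, 0, 0])  # 0.0 where attended, -10000.0 at the padded position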
@@ -1178,6 +1174,9 @@ class QuietForCausalLM(QuietPreTrainedModel):
         self.model = QuietModel(config)
         self.vocab_size = config.vocab_size
         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+        self.router_aux_loss_coef = config.router_aux_loss_coef
+        self.num_experts = config.num_experts
+        self.num_experts_per_tok = config.num_experts_per_tok
         self.max_thoughts = config.max_thoughts
         self.merged_lm_and_talk_heads = config.merged_lm_and_talk_heads
         self.use_concat_talk_head = config.use_concat_talk_head
@@ -1240,10 +1239,7 @@ class QuietForCausalLM(QuietPreTrainedModel):
         self.tokenized_thought_prefix = None
         self.log_dict = defaultdict(int)
         self.eval_log_dict = defaultdict(int)
-        self.print_final_only = True
         self.loss_mean = loss_mean
-        self.all_rewards = []
-        self.all_unreduced_losses = []
 
         self.start_embedding = nn.Parameter(torch.zeros(2, self.model.config.hidden_size))
         self.end_embedding = nn.Parameter(torch.zeros(2, self.model.config.hidden_size))
@@ -1252,6 +1248,9 @@ class QuietForCausalLM(QuietPreTrainedModel):
         self.embedding_scale = 1e2
         self.reinforce_temperature = 3
         self.base_loss_beta = 1
+        self.thinking_usefulness_head = nn.Linear(self.model.config.hidden_size, 1)
+        self.thinking_threshold = 0.5
+        self.thinking_usefulness_loss_weight = 1e-2
 
         # Not used in the paper:
         self.use_thought_prefix = False
@@ -1259,7 +1258,7 @@ class QuietForCausalLM(QuietPreTrainedModel):
         self.use_upper_triangular = False
         self.subtract_mean_reward = False
         self.comparison_mode = False
-        self.gumbel_detach = True
+        self.gumbel_detach = False
 
         # For visualization
         self.eval_mode = False
@@ -1358,6 +1357,7 @@ class QuietForCausalLM(QuietPreTrainedModel):
 
         # Apply Gumbel-Softmax to the logits
         next_token_logits = F.gumbel_softmax(logits, tau=self.gumbel_temperature, hard=True, dim=-1)
+        print("Next token logits:", next_token_logits)
        next_token_id = torch.argmax(next_token_logits, dim=-1)
 
         # Append the generated token to the input sequence
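
For background on the sampling step in this hunk: `F.gumbel_softmax(..., hard=True)` draws a one-hot sample in the forward pass while keeping the soft distribution's gradient in the backward pass (straight-through). A self-contained sketch with made-up sizes:

    import torch
    import torch.nn.functional as F

    logits = torch.randn(2, 32000)  # (batch, vocab); sizes are illustrative
    one_hot = F.gumbel_softmax(logits, tau=1.0, hard=True, dim=-1)
    token_id = torch.argmax(one_hot, dim=-1)     # recover the sampled token index
    assert torch.all(one_hot.sum(dim=-1) == 1)   # exactly one-hot per row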
@@ -1436,6 +1436,7 @@ class QuietForCausalLM(QuietPreTrainedModel):
         inputs_embeds: Optional[torch.FloatTensor] = None,
         labels: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
+        # output_router_logits: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
@@ -1459,14 +1460,27 @@ class QuietForCausalLM(QuietPreTrainedModel):
         >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
         "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
         ```"""
-        log_dict = self.log_dict if self.training else self.eval_log_dict
 
         if not self.training:
             n_ahead_talk_to_restore = self.n_ahead_talk
             n_passes_to_restore = self.n_passes
             self.n_ahead_talk = 1
             self.n_passes = 1
-
+
+        # aux_loss = None
+        # output_router_logits = output_router_logits if output_router_logits is not None else self.config.output_router_logits
+        # if output_router_logits:
+        # router_logits = outputs.router_logits if return_dict else outputs[-1]
+        # if router_logits is not None:
+        # aux_loss = load_balancing_loss_func(
+        # router_logits,
+        # self.num_experts,
+        # self.num_experts_per_tok,
+        # attention_mask,
+        # )
+        # if labels is not None:
+        # loss += self.router_aux_loss_coef * aux_loss.to(loss.device)
+
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
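
The commented-out block references `load_balancing_loss_func`, the Mixtral-style router auxiliary loss from transformers. As a rough, simplified illustration of what such a loss computes (unmasked form, written for this note; not the exact transformers implementation):

    import torch
    import torch.nn.functional as F

    def simple_load_balancing_loss(router_logits, num_experts, top_k):
        # router_logits: (num_tokens, num_experts)
        probs = F.softmax(router_logits, dim=-1)
        _, selected = torch.topk(probs, top_k, dim=-1)
        expert_mask = F.one_hot(selected, num_experts).float()  # (tokens, top_k, experts)
        tokens_per_expert = expert_mask.mean(dim=(0, 1))   # fraction routed to each expert
        router_prob_per_expert = probs.mean(dim=0)         # mean router probability
        # Minimized when routing is uniform across experts.
        return num_experts * torch.sum(tokens_per_expert * router_prob_per_expert)

Note that the block stays commented out, and `outputs` does not appear to be defined at this point in `forward` (the backbone is called later), so enabling it as-is would fail; it reads as staged work toward a mixture-of-experts variant.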
@@ -1547,6 +1561,7 @@ class QuietForCausalLM(QuietPreTrainedModel):
         else:
             # convert to identity transform
             def lambda_transform(cur_head):
+                # pdb.set_trace()
                 if cur_head.weight.data.shape[0] != cur_head.weight.data.shape[1]:
                     return torch.cat([
                         torch.eye(
@@ -1679,6 +1694,7 @@ class QuietForCausalLM(QuietPreTrainedModel):
                 use_cache=use_cache,
                 output_attentions=output_attentions,
                 output_hidden_states=output_hidden_states,
+                # output_router_logits=output_router_logits,
                 return_dict=return_dict,
             )
 
@@ -1793,8 +1809,10 @@ class QuietForCausalLM(QuietPreTrainedModel):
                 shift_labels = labels[..., 1 + shift_amount:].contiguous()
                 # Flatten the tokens
                 loss_fct = CrossEntropyLoss(reduction="none")
+                print("Shift logits before:", shift_logits)
                 shift_logits = shift_logits.view(-1, self.config.vocab_size)
                 shift_labels = shift_labels.view(-1).clone()
+                print("shift logits after:", shift_logits)
                 # Enable model parallelism
                 shift_labels[shift_labels == self.tokenizer.pad_token_id] = -100
                 shift_labels = shift_labels.to(shift_logits.device)
@@ -1886,6 +1904,22 @@ class QuietForCausalLM(QuietPreTrainedModel):
                     inputs_embeds = cur_thought_embedding.unsqueeze(0).repeat(batch_size, seq_len, 1)
                     inputs_embeds = inputs_embeds.view(probabilities.size(0), probabilities.size(1), -1).to(self.model.embed_tokens.weight.dtype)
                     inputs_embeds = inputs_embeds.view(probabilities.size(0), probabilities.size(1), -1).to(self.model.embed_tokens.weight.dtype)
+
+        # Predict the usefulness of thinking at each token position
+        thinking_usefulness = self.thinking_usefulness_head(hidden_states).squeeze(-1)
+
+        # Apply a threshold to decide where to generate thoughts
+        generate_thought_mask = thinking_usefulness > self.thinking_threshold
+
+        # Compute the regularization loss for thinking usefulness prediction
+        thinking_usefulness_loss = torch.mean(thinking_usefulness * (1 - generate_thought_mask.float()))
+
+        # Add the regularization loss to the total loss
+        if loss is not None:
+            loss = loss + self.thinking_usefulness_loss_weight * thinking_usefulness_loss
+        else:
+            loss = self.thinking_usefulness_loss_weight * thinking_usefulness_loss
+
 
         if len(attention_mask.shape) == 2:
             breakpoint()
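
Taken on its own, the added gating head amounts to the following sketch (sizes and values illustrative; in the commit, `hidden_states` comes from the backbone and the weight is `self.thinking_usefulness_loss_weight`):

    import torch
    import torch.nn as nn

    hidden_size, threshold, weight = 64, 0.5, 1e-2
    usefulness_head = nn.Linear(hidden_size, 1)

    hidden_states = torch.randn(2, 10, hidden_size)          # (batch, seq, hidden)
    usefulness = usefulness_head(hidden_states).squeeze(-1)  # (batch, seq)
    generate_thought_mask = usefulness > threshold           # where thoughts would be generated
    # Penalize predicted usefulness at positions below the threshold,
    # nudging sub-threshold scores toward zero.
    reg_loss = torch.mean(usefulness * (1 - generate_thought_mask.float()))
    weighted_term = weight * reg_loss

One caveat visible in the diff: the weighted term is added to `loss` here, and a later hunk also adds the unweighted `thinking_usefulness_loss` after the per-pass losses are averaged, which looks like a duplicated regularizer.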
@@ -1933,7 +1967,9 @@ class QuietForCausalLM(QuietPreTrainedModel):
                 # if shift_labels.min() == self.tokenizer.pad_token_id:
                 shift_labels = torch.where(shift_labels == self.tokenizer.pad_token_id, -100, shift_labels)
                 unreduced_loss = loss_fct(shift_logits, shift_labels)
+                # print("Loss:", unreduced_loss.item())  # Print the loss before checking for NaN values
                 if torch.any(unreduced_loss != unreduced_loss):
+                    # pdb.set_trace()
                     raise ValueError("NaN loss")
                 unreduced_loss = unreduced_loss.reshape(logits.shape[0], -1)
                 loss_list.append(unreduced_loss)
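
The guard above relies on the IEEE-754 rule that NaN is the only value unequal to itself; `torch.isnan` expresses the same test directly:

    import torch

    t = torch.tensor([1.0, float("nan"), 3.0])
    print(torch.any(t != t))           # tensor(True)
    print(torch.any(torch.isnan(t)))   # tensor(True) -- equivalent, more explicit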
@@ -1992,78 +2028,6 @@ class QuietForCausalLM(QuietPreTrainedModel):
                     else:
                         added_reward = original_dqn_reward
                     policy_reward += added_reward
-
-            if self.use_policy_loss and ahead_idx == self.n_ahead + self.n_ahead_talk - 2:
-                # only compute during the thinking phase
-                if self.use_reparam_for_thought_embeddings and (self.use_start_thought_token or self.use_end_thought_token):
-                    # sampled_start, sampled_end
-                    # calculate the log likelihood of the start and end embeddings sampled from a multivariate normal distribution
-                    # with mean start_embedding[0] and standard deviation start_embedding[1]
-                    if self.use_start_thought_token:
-                        exp_start_std = torch.exp(start_embedding[1])
-                        start_loglikelihood = -0.5 * (sampled_start.detach() - start_embedding[0]) ** 2 / exp_start_std ** 2 - start_embedding[1] - 0.5 * math.log(2 * math.pi)
-                        start_loglikelihood = start_loglikelihood.mean(dim=-1)
-                    if self.use_end_thought_token:
-                        exp_end_std = torch.exp(end_embedding[1])
-                        end_loglikelihood = -0.5 * (sampled_end.detach() - end_embedding[0]) ** 2 / exp_end_std ** 2 - end_embedding[1] - 0.5 * math.log(2 * math.pi)
-                        end_loglikelihood = end_loglikelihood.mean(dim=-1)
-                    # we use the mean instead of the sum to prevent dependence on the dimensionality of the embeddings
-                    if self.use_end_thought_token and self.use_policy_loss_for_end_thought:
-                        action_loglikelihoods_list.append(end_loglikelihood)
-                    if self.use_start_thought_token:
-                        action_loglikelihoods_list.append(start_loglikelihood)
-
-                if ahead_idx == self.n_ahead + self.n_ahead_talk - 2 and self.eval_mode:
-                    with torch.no_grad():
-                        # calculate the 0.75 quantile of the rewards
-                        filtered_tokens = input_ids[:, :policy_reward.shape[-1]].cpu().detach().numpy().flatten()
-                        filtered_tokens_mask = filtered_tokens != self.tokenizer.pad_token_id
-                        filtered_tokens = filtered_tokens[filtered_tokens_mask]
-                        filtered_rewards = policy_reward.float().cpu().detach().numpy()[:, :seq_len - self.n_ahead_talk].flatten()
-                        filtered_rewards = filtered_rewards[filtered_tokens_mask]
-
-                        abs_reward_list = np.abs(policy_reward.float().cpu().detach().numpy()[:, :seq_len - self.n_ahead_talk].flatten())
-                        abs_reward_list = abs_reward_list[filtered_tokens_mask]
-                        medium_quantile = np.quantile(abs_reward_list, 0.5)
-                        upper_quantile = np.quantile(abs_reward_list, 0.95)
-
-                        save_tokens_with_rewards_to_pdf(
-                            filtered_tokens,
-                            [0] + filtered_rewards.tolist(),
-                            self.tokenizer,
-                            output_file=f"texts/rewards_talk_{self.n_ahead_talk}_{self.training_steps}.pdf",
-                            eps=medium_quantile,
-                            eps2=upper_quantile,
-                        )
-
-                        def plot_kde(data, losses):
-                            sns.set(style="whitegrid")
-                            # Create the KDE plot
-                            sns.kdeplot(data, fill=True)
-                            # Set the plot title and labels
-                            plt.title("KDE Plot")
-                            plt.xlabel("Value")
-                            plt.ylabel("Density")
-                            # Save the plot
-                            plt.savefig(f"texts/kde_talk_{self.n_ahead_talk}_{self.training_steps}.pdf")
-                            # Close the plot
-                            plt.close()
-
-                            # Step 1: Create a base color palette
-                            base_colors = sns.color_palette("light:#5A9", n_colors=256)  # More colors for a smoother gradient
-                            base_cmap = LinearSegmentedColormap.from_list("log_light", base_colors)
-                            log_norm = LogNorm(vmin=1e-3, vmax=10)
-
-                            sns.kdeplot(x=data, y=losses, fill=True, levels=20, norm=log_norm, cut=0, linewidths=0)
-                            # limit y to 0 to 25 and x to -1 to 1
-                            plt.xlim(-1, 1)
-                            plt.ylim(0, 25)
-                            plt.savefig(f"texts/jointer_talk_{self.n_ahead_talk}_{self.training_steps}.pdf")
-                            plt.close()
-
-                        self.all_rewards.extend(filtered_rewards)
-                        self.all_unreduced_losses.extend(unreduced_loss[:, :-1].flatten()[filtered_tokens_mask].float().flatten().cpu().detach().numpy())
-                        plot_kde(self.all_rewards, self.all_unreduced_losses)
 
             for action_loglikelihoods_2d in action_loglikelihoods_list:
                 train_policy_reward = policy_reward
@@ -2112,6 +2076,7 @@ class QuietForCausalLM(QuietPreTrainedModel):
             else:
                 loss = cur_loss
             loss = loss / len(loss_list)
+            loss = loss + thinking_usefulness_loss
 
             loss = loss * self.base_loss_beta
 
@@ -2133,64 +2098,15 @@ class QuietForCausalLM(QuietPreTrainedModel):
 
         if loss is not None:
             base_log_dict["loss_train"] = loss.item()
-
-        for loss_key, loss_val in base_log_dict.items():
-            log_dict[loss_key] += loss_val / self.n_tokens_print
-
-        if self.use_policy_loss and policy_reward is not None:
-            log_dict["policy_loss"] += dqn_loss / self.n_tokens_print
-            log_dict["policy_reward"] += policy_reward.mean() / self.n_tokens_print
-
-        if not loss_list:
-            if loss is not None:
-                log_dict["loss_0"] += loss / self.n_tokens_print
-        else:
-            log_dict["loss_final"] += nonzero_mean(loss_list[-1]) / self.n_tokens_print
-            log_dict["loss_talk"] += sum(nonzero_mean(cur_loss_item) for cur_loss_item in loss_list[-self.n_ahead_talk:]) / self.n_ahead_talk / self.n_tokens_print
-
-        # also log relative losses to loss_0
-        if loss_list:
-            for i in range(len(loss_list)):
-                talk_idx = min(max(i - (self.n_ahead - 1), 0), len(talk_loss_list) - 1)
-                if not talk_loss_list:
-                    cur_talk_loss = nonzero_mean(loss_list[0])
-                else:
-                    cur_talk_loss = talk_loss_list[talk_idx]
-                log_dict[f"rel_loss_{i}"] += (nonzero_mean(loss_list[i]) - cur_talk_loss) / self.n_tokens_print
-        if self.training:
-            self.training_steps += 1
-        try:
-            # if self.training_steps % (self.gradient_accumulation_steps * 256) == 0:
-            if self.wandb_enabled:
-                if self.training_steps % (self.n_tokens_print) == 0 or not self.training:  # and "0" in str(loss.device):
-                    if not self.training:
-                        new_log_dict = {}
-                        for key in list(log_dict.keys()):
-                            new_log_dict["eval_" + key] = log_dict[key]
-                        log_dict = new_log_dict
-                    log_dict["training_steps"] = self.training_steps
-                    log_dict["batch_size"] = batch_size
-                    log_dict["example_steps"] = self.training_steps * batch_size * self.gradient_accumulation_steps
-                    if self.n_ahead > 1:
-                        log_dict["compute_steps"] = self.training_steps * batch_size * (self.n_ahead + self.n_ahead_talk - 1) * self.gradient_accumulation_steps
-                    else:  # There's no overhead for talk tokens if there's no thinking
-                        log_dict["compute_steps"] = self.training_steps * batch_size * self.gradient_accumulation_steps
-                    # remove all nans
-                    for key in list(log_dict.keys()):
-                        if log_dict[key] != log_dict[key]:
-                            del log_dict[key]
-                    if self.training:
-                        wandb.log(log_dict)
-            if self.training:
-                self.log_dict = defaultdict(int)
-            else:
-                self.eval_log_dict = defaultdict(int)
-        except Exception as e:
-            pass
 
         if not self.training:
             self.n_ahead_talk = n_ahead_talk_to_restore
             self.n_passes = n_passes_to_restore
+
+        del start_embedding
+        del end_embedding
+        torch.cuda.empty_cache()
+
         return CausalLMOutputWithPast(
             loss=loss if loss is not None else None,
             logits=(rm_logits if self.n_ahead > 1 else logits) if not self.output_logits_at_the_end else logits,
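
On the new cleanup step: `torch.cuda.empty_cache()` only returns the allocator's cached, unused blocks to the driver; it cannot free memory still referenced by live tensors, which is why the `del` statements come first. A minimal sketch of the pattern (hypothetical buffer; the call is a no-op without CUDA):

    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"
    scratch = torch.zeros(1024, 1024, device=device)  # temporary working buffer
    result = scratch.sum()
    del scratch                   # drop the last reference first
    torch.cuda.empty_cache()      # then release cached blocks back to the driver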
 