Update modeling_quiet.py
Browse files- modeling_quiet.py +110 -194
modeling_quiet.py
CHANGED
|
@@ -44,7 +44,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
|
| 44 |
|
| 45 |
from transformers.activations import ACT2FN
|
| 46 |
from transformers.cache_utils import Cache, DynamicCache
|
| 47 |
-
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
|
| 48 |
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
|
| 49 |
from transformers.modeling_utils import PreTrainedModel
|
| 50 |
from transformers.utils import (
|
|
@@ -73,67 +73,63 @@ from reportlab.pdfgen import canvas
|
|
| 73 |
from reportlab.lib.pagesizes import letter
|
| 74 |
from reportlab.lib.colors import HexColor
|
| 75 |
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
x = 50
|
| 134 |
-
previous_text = current_text
|
| 135 |
-
c.showPage()
|
| 136 |
-
c.save()
|
| 137 |
|
| 138 |
|
| 139 |
# Copied from transformers.models.llama.modeling_llama._get_unpad_data
|
|
@@ -1178,6 +1174,9 @@ class QuietForCausalLM(QuietPreTrainedModel):
|
|
| 1178 |
self.model = QuietModel(config)
|
| 1179 |
self.vocab_size = config.vocab_size
|
| 1180 |
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
|
|
|
|
|
|
|
|
|
| 1181 |
self.max_thoughts = config.max_thoughts
|
| 1182 |
self.merged_lm_and_talk_heads = config.merged_lm_and_talk_heads
|
| 1183 |
self.use_concat_talk_head = config.use_concat_talk_head
|
|
@@ -1240,10 +1239,7 @@ class QuietForCausalLM(QuietPreTrainedModel):
|
|
| 1240 |
self.tokenized_thought_prefix = None
|
| 1241 |
self.log_dict = defaultdict(int)
|
| 1242 |
self.eval_log_dict = defaultdict(int)
|
| 1243 |
-
self.print_final_only = True
|
| 1244 |
self.loss_mean = loss_mean
|
| 1245 |
-
self.all_rewards = []
|
| 1246 |
-
self.all_unreduced_losses = []
|
| 1247 |
|
| 1248 |
self.start_embedding = nn.Parameter(torch.zeros(2, self.model.config.hidden_size))
|
| 1249 |
self.end_embedding = nn.Parameter(torch.zeros(2, self.model.config.hidden_size))
|
|
@@ -1252,6 +1248,9 @@ class QuietForCausalLM(QuietPreTrainedModel):
|
|
| 1252 |
self.embedding_scale = 1e2
|
| 1253 |
self.reinforce_temperature = 3
|
| 1254 |
self.base_loss_beta = 1
|
|
|
|
|
|
|
|
|
|
| 1255 |
|
| 1256 |
# Not used in the paper:
|
| 1257 |
self.use_thought_prefix = False
|
|
@@ -1259,7 +1258,7 @@ class QuietForCausalLM(QuietPreTrainedModel):
|
|
| 1259 |
self.use_upper_triangular = False
|
| 1260 |
self.subtract_mean_reward = False
|
| 1261 |
self.comparison_mode = False
|
| 1262 |
-
self.gumbel_detach =
|
| 1263 |
|
| 1264 |
# For visualization
|
| 1265 |
self.eval_mode = False
|
|
@@ -1358,6 +1357,7 @@ class QuietForCausalLM(QuietPreTrainedModel):
|
|
| 1358 |
|
| 1359 |
# Apply Gumbel-Softmax to the logits
|
| 1360 |
next_token_logits = F.gumbel_softmax(logits, tau=self.gumbel_temperature, hard=True, dim=-1)
|
|
|
|
| 1361 |
next_token_id = torch.argmax(next_token_logits, dim=-1)
|
| 1362 |
|
| 1363 |
# Append the generated token to the input sequence
|
|
@@ -1436,6 +1436,7 @@ class QuietForCausalLM(QuietPreTrainedModel):
|
|
| 1436 |
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 1437 |
labels: Optional[torch.LongTensor] = None,
|
| 1438 |
use_cache: Optional[bool] = None,
|
|
|
|
| 1439 |
output_attentions: Optional[bool] = None,
|
| 1440 |
output_hidden_states: Optional[bool] = None,
|
| 1441 |
return_dict: Optional[bool] = None,
|
|
@@ -1459,14 +1460,27 @@ class QuietForCausalLM(QuietPreTrainedModel):
|
|
| 1459 |
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
| 1460 |
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
|
| 1461 |
```"""
|
| 1462 |
-
log_dict = self.log_dict if self.training else self.eval_log_dict
|
| 1463 |
|
| 1464 |
if not self.training:
|
| 1465 |
n_ahead_talk_to_restore = self.n_ahead_talk
|
| 1466 |
n_passes_to_restore = self.n_passes
|
| 1467 |
self.n_ahead_talk = 1
|
| 1468 |
self.n_passes = 1
|
| 1469 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1470 |
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 1471 |
output_hidden_states = (
|
| 1472 |
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
|
@@ -1547,6 +1561,7 @@ class QuietForCausalLM(QuietPreTrainedModel):
|
|
| 1547 |
else:
|
| 1548 |
# convert to identity transform
|
| 1549 |
def lambda_transform(cur_head):
|
|
|
|
| 1550 |
if cur_head.weight.data.shape[0] != cur_head.weight.data.shape[1]:
|
| 1551 |
return torch.cat([
|
| 1552 |
torch.eye(
|
|
@@ -1679,6 +1694,7 @@ class QuietForCausalLM(QuietPreTrainedModel):
|
|
| 1679 |
use_cache=use_cache,
|
| 1680 |
output_attentions=output_attentions,
|
| 1681 |
output_hidden_states=output_hidden_states,
|
|
|
|
| 1682 |
return_dict=return_dict,
|
| 1683 |
)
|
| 1684 |
|
|
@@ -1793,8 +1809,10 @@ class QuietForCausalLM(QuietPreTrainedModel):
|
|
| 1793 |
shift_labels = labels[..., 1 + shift_amount:].contiguous()
|
| 1794 |
# Flatten the tokens
|
| 1795 |
loss_fct = CrossEntropyLoss(reduction="none")
|
|
|
|
| 1796 |
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
| 1797 |
shift_labels = shift_labels.view(-1).clone()
|
|
|
|
| 1798 |
# Enable model parallelism
|
| 1799 |
shift_labels[shift_labels == self.tokenizer.pad_token_id] = -100
|
| 1800 |
shift_labels = shift_labels.to(shift_logits.device)
|
|
@@ -1886,6 +1904,22 @@ class QuietForCausalLM(QuietPreTrainedModel):
|
|
| 1886 |
inputs_embeds = cur_thought_embedding.unsqueeze(0).repeat(batch_size, seq_len, 1)
|
| 1887 |
inputs_embeds = inputs_embeds.view(probabilities.size(0), probabilities.size(1), -1).to(self.model.embed_tokens.weight.dtype)
|
| 1888 |
inputs_embeds = inputs_embeds.view(probabilities.size(0), probabilities.size(1), -1).to(self.model.embed_tokens.weight.dtype)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1889 |
|
| 1890 |
if len(attention_mask.shape) == 2:
|
| 1891 |
breakpoint()
|
|
@@ -1933,7 +1967,9 @@ class QuietForCausalLM(QuietPreTrainedModel):
|
|
| 1933 |
# if shift_labels.min() == self.tokenizer.pad_token_id:
|
| 1934 |
shift_labels = torch.where(shift_labels == self.tokenizer.pad_token_id, -100, shift_labels)
|
| 1935 |
unreduced_loss = loss_fct(shift_logits, shift_labels)
|
|
|
|
| 1936 |
if torch.any(unreduced_loss != unreduced_loss):
|
|
|
|
| 1937 |
raise ValueError("NaN loss")
|
| 1938 |
unreduced_loss = unreduced_loss.reshape(logits.shape[0], -1)
|
| 1939 |
loss_list.append(unreduced_loss)
|
|
@@ -1992,78 +2028,6 @@ class QuietForCausalLM(QuietPreTrainedModel):
|
|
| 1992 |
else:
|
| 1993 |
added_reward = original_dqn_reward
|
| 1994 |
policy_reward += added_reward
|
| 1995 |
-
|
| 1996 |
-
if self.use_policy_loss and ahead_idx == self.n_ahead + self.n_ahead_talk - 2:
|
| 1997 |
-
# only compute during the thinking phase
|
| 1998 |
-
if self.use_reparam_for_thought_embeddings and (self.use_start_thought_token or self.use_end_thought_token):
|
| 1999 |
-
# sampled_start, sampled_end
|
| 2000 |
-
# calculate the log likelihood of the start and end embeddings sampled from a multivariate normal distribution
|
| 2001 |
-
# with mean start_embedding[0] and standard deviation start_embedding[1]
|
| 2002 |
-
if self.use_start_thought_token:
|
| 2003 |
-
exp_start_std = torch.exp(start_embedding[1])
|
| 2004 |
-
start_loglikelihood = -0.5 * (sampled_start.detach() - start_embedding[0]) ** 2 / exp_start_std ** 2 - start_embedding[1] - 0.5 * math.log(2 * math.pi)
|
| 2005 |
-
start_loglikelihood = start_loglikelihood.mean(dim=-1)
|
| 2006 |
-
if self.use_end_thought_token:
|
| 2007 |
-
exp_end_std = torch.exp(end_embedding[1])
|
| 2008 |
-
end_loglikelihood = -0.5 * (sampled_end.detach() - end_embedding[0]) ** 2 / exp_end_std ** 2 - end_embedding[1] - 0.5 * math.log(2 * math.pi)
|
| 2009 |
-
end_loglikelihood = end_loglikelihood.mean(dim=-1)
|
| 2010 |
-
# we use the mean instead of the sum to prevent dependence on the dimensionality of the embeddings
|
| 2011 |
-
if self.use_end_thought_token and self.use_policy_loss_for_end_thought:
|
| 2012 |
-
action_loglikelihoods_list.append(end_loglikelihood)
|
| 2013 |
-
if self.use_start_thought_token:
|
| 2014 |
-
action_loglikelihoods_list.append(start_loglikelihood)
|
| 2015 |
-
|
| 2016 |
-
if ahead_idx == self.n_ahead + self.n_ahead_talk - 2 and self.eval_mode:
|
| 2017 |
-
with torch.no_grad():
|
| 2018 |
-
# calculate the 0.75 quantile of the rewards
|
| 2019 |
-
filtered_tokens = input_ids[:, :policy_reward.shape[-1]].cpu().detach().numpy().flatten()
|
| 2020 |
-
filtered_tokens_mask = filtered_tokens != self.tokenizer.pad_token_id
|
| 2021 |
-
filtered_tokens = filtered_tokens[filtered_tokens_mask]
|
| 2022 |
-
filtered_rewards = policy_reward.float().cpu().detach().numpy()[:, :seq_len - self.n_ahead_talk].flatten()
|
| 2023 |
-
filtered_rewards = filtered_rewards[filtered_tokens_mask]
|
| 2024 |
-
|
| 2025 |
-
abs_reward_list = np.abs(policy_reward.float().cpu().detach().numpy()[:, :seq_len - self.n_ahead_talk].flatten())
|
| 2026 |
-
abs_reward_list = abs_reward_list[filtered_tokens_mask]
|
| 2027 |
-
medium_quantile = np.quantile(abs_reward_list, 0.5)
|
| 2028 |
-
upper_quantile = np.quantile(abs_reward_list, 0.95)
|
| 2029 |
-
|
| 2030 |
-
save_tokens_with_rewards_to_pdf(
|
| 2031 |
-
filtered_tokens,
|
| 2032 |
-
[0] + filtered_rewards.tolist(),
|
| 2033 |
-
self.tokenizer,
|
| 2034 |
-
output_file=f"texts/rewards_talk_{self.n_ahead_talk}_{self.training_steps}.pdf",
|
| 2035 |
-
eps=medium_quantile,
|
| 2036 |
-
eps2=upper_quantile,
|
| 2037 |
-
)
|
| 2038 |
-
|
| 2039 |
-
def plot_kde(data, losses):
|
| 2040 |
-
sns.set(style="whitegrid")
|
| 2041 |
-
# Create the KDE plot
|
| 2042 |
-
sns.kdeplot(data, fill=True)
|
| 2043 |
-
# Set the plot title and labels
|
| 2044 |
-
plt.title("KDE Plot")
|
| 2045 |
-
plt.xlabel("Value")
|
| 2046 |
-
plt.ylabel("Density")
|
| 2047 |
-
# Save the plot
|
| 2048 |
-
plt.savefig(f"texts/kde_talk_{self.n_ahead_talk}_{self.training_steps}.pdf")
|
| 2049 |
-
# Close the plot
|
| 2050 |
-
plt.close()
|
| 2051 |
-
|
| 2052 |
-
# Step 1: Create a base color palette
|
| 2053 |
-
base_colors = sns.color_palette("light:#5A9", n_colors=256) # More colors for a smoother gradient
|
| 2054 |
-
base_cmap = LinearSegmentedColormap.from_list("log_light", base_colors)
|
| 2055 |
-
log_norm = LogNorm(vmin=1e-3, vmax=10)
|
| 2056 |
-
|
| 2057 |
-
sns.kdeplot(x=data, y=losses, fill=True, levels=20, norm=log_norm, cut=0, linewidths=0)
|
| 2058 |
-
# limit y to 0 to 25 and x to -1 to 1
|
| 2059 |
-
plt.xlim(-1, 1)
|
| 2060 |
-
plt.ylim(0, 25)
|
| 2061 |
-
plt.savefig(f"texts/jointer_talk_{self.n_ahead_talk}_{self.training_steps}.pdf")
|
| 2062 |
-
plt.close()
|
| 2063 |
-
|
| 2064 |
-
self.all_rewards.extend(filtered_rewards)
|
| 2065 |
-
self.all_unreduced_losses.extend(unreduced_loss[:, :-1].flatten()[filtered_tokens_mask].float().flatten().cpu().detach().numpy())
|
| 2066 |
-
plot_kde(self.all_rewards, self.all_unreduced_losses)
|
| 2067 |
|
| 2068 |
for action_loglikelihoods_2d in action_loglikelihoods_list:
|
| 2069 |
train_policy_reward = policy_reward
|
|
@@ -2112,6 +2076,7 @@ class QuietForCausalLM(QuietPreTrainedModel):
|
|
| 2112 |
else:
|
| 2113 |
loss = cur_loss
|
| 2114 |
loss = loss / len(loss_list)
|
|
|
|
| 2115 |
|
| 2116 |
loss = loss * self.base_loss_beta
|
| 2117 |
|
|
@@ -2133,64 +2098,15 @@ class QuietForCausalLM(QuietPreTrainedModel):
|
|
| 2133 |
|
| 2134 |
if loss is not None:
|
| 2135 |
base_log_dict["loss_train"] = loss.item()
|
| 2136 |
-
|
| 2137 |
-
for loss_key, loss_val in base_log_dict.items():
|
| 2138 |
-
log_dict[loss_key] += loss_val / self.n_tokens_print
|
| 2139 |
-
|
| 2140 |
-
if self.use_policy_loss and policy_reward is not None:
|
| 2141 |
-
log_dict["policy_loss"] += dqn_loss / self.n_tokens_print
|
| 2142 |
-
log_dict["policy_reward"] += policy_reward.mean() / self.n_tokens_print
|
| 2143 |
-
|
| 2144 |
-
if not loss_list:
|
| 2145 |
-
if loss is not None:
|
| 2146 |
-
log_dict["loss_0"] += loss / self.n_tokens_print
|
| 2147 |
-
else:
|
| 2148 |
-
log_dict["loss_final"] += nonzero_mean(loss_list[-1]) / self.n_tokens_print
|
| 2149 |
-
log_dict["loss_talk"] += sum(nonzero_mean(cur_loss_item) for cur_loss_item in loss_list[-self.n_ahead_talk:]) / self.n_ahead_talk / self.n_tokens_print
|
| 2150 |
-
|
| 2151 |
-
# also log relative losses to loss_0
|
| 2152 |
-
if loss_list:
|
| 2153 |
-
for i in range(len(loss_list)):
|
| 2154 |
-
talk_idx = min(max(i - (self.n_ahead - 1), 0), len(talk_loss_list) - 1)
|
| 2155 |
-
if not talk_loss_list:
|
| 2156 |
-
cur_talk_loss = nonzero_mean(loss_list[0])
|
| 2157 |
-
else:
|
| 2158 |
-
cur_talk_loss = talk_loss_list[talk_idx]
|
| 2159 |
-
log_dict[f"rel_loss_{i}"] += (nonzero_mean(loss_list[i]) - cur_talk_loss) / self.n_tokens_print
|
| 2160 |
-
if self.training:
|
| 2161 |
-
self.training_steps += 1
|
| 2162 |
-
try:
|
| 2163 |
-
# if self.training_steps % (self.gradient_accumulation_steps * 256) == 0:
|
| 2164 |
-
if self.wandb_enabled:
|
| 2165 |
-
if self.training_steps % (self.n_tokens_print) == 0 or not self.training:# and "0" in str(loss.device):
|
| 2166 |
-
if not self.training:
|
| 2167 |
-
new_log_dict = {}
|
| 2168 |
-
for key in list(log_dict.keys()):
|
| 2169 |
-
new_log_dict["eval_" + key] = log_dict[key]
|
| 2170 |
-
log_dict = new_log_dict
|
| 2171 |
-
log_dict["training_steps"] = self.training_steps
|
| 2172 |
-
log_dict["batch_size"] = batch_size
|
| 2173 |
-
log_dict["example_steps"] = self.training_steps * batch_size * self.gradient_accumulation_steps
|
| 2174 |
-
if self.n_ahead > 1:
|
| 2175 |
-
log_dict["compute_steps"] = self.training_steps * batch_size * (self.n_ahead + self.n_ahead_talk - 1) * self.gradient_accumulation_steps
|
| 2176 |
-
else: # There's no overhead for talk tokens if there's no thinking
|
| 2177 |
-
log_dict["compute_steps"] = self.training_steps * batch_size * self.gradient_accumulation_steps
|
| 2178 |
-
# remove all nans
|
| 2179 |
-
for key in list(log_dict.keys()):
|
| 2180 |
-
if log_dict[key] != log_dict[key]:
|
| 2181 |
-
del log_dict[key]
|
| 2182 |
-
if self.training:
|
| 2183 |
-
wandb.log(log_dict)
|
| 2184 |
-
if self.training:
|
| 2185 |
-
self.log_dict = defaultdict(int)
|
| 2186 |
-
else:
|
| 2187 |
-
self.eval_log_dict = defaultdict(int)
|
| 2188 |
-
except Exception as e:
|
| 2189 |
-
pass
|
| 2190 |
|
| 2191 |
if not self.training:
|
| 2192 |
self.n_ahead_talk = n_ahead_talk_to_restore
|
| 2193 |
self.n_passes = n_passes_to_restore
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2194 |
return CausalLMOutputWithPast(
|
| 2195 |
loss=loss if loss is not None else None,
|
| 2196 |
logits=(rm_logits if self.n_ahead > 1 else logits) if not self.output_logits_at_the_end else logits,
|
|
|
|
| 44 |
|
| 45 |
from transformers.activations import ACT2FN
|
| 46 |
from transformers.cache_utils import Cache, DynamicCache
|
| 47 |
+
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
|
| 48 |
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
|
| 49 |
from transformers.modeling_utils import PreTrainedModel
|
| 50 |
from transformers.utils import (
|
|
|
|
| 73 |
from reportlab.lib.pagesizes import letter
|
| 74 |
from reportlab.lib.colors import HexColor
|
| 75 |
|
| 76 |
+
|
| 77 |
+
def _prepare_4d_causal_attention_mask_for_sdpa(attention_mask, input_shape, inputs_embeds, past_key_values_length):
    """Convert a keep-mask into an additive 4D attention bias.

    The incoming mask marks positions to attend to with 1 and masked
    positions with 0. The result holds 0.0 where attention is allowed and
    -10000.0 where it is masked, so it can be added to the raw attention
    scores before the softmax.

    Args:
        attention_mask: ``None``, or a tensor of rank 2, 3 or 4.
            Rank 4 is used as-is, rank 3 gets a singleton dimension inserted
            at position 1, and rank 2 is expanded to
            ``(batch, 1, 1, src_len)``.
        input_shape: ``(batch_size, tgt_len)``. Only used to size the
            all-zeros bias returned when ``attention_mask`` is ``None`` and
            in the error message.
        inputs_embeds: Tensor whose floating-point dtype the bias follows and
            whose device is used when ``attention_mask`` is ``None``. If it
            is ``None`` or not floating point, the bias falls back to
            ``torch.bfloat16``.
        past_key_values_length: If positive, the first
            ``past_key_values_length`` columns of a rank-2 mask are dropped
            before expansion.

    Returns:
        The additive bias tensor. With ``attention_mask=None`` this is an
        all-zeros tensor of shape ``(batch_size, 1, tgt_len, tgt_len)``.

    Raises:
        ValueError: If ``attention_mask`` has a rank other than 2, 3 or 4.
    """
    bsz, tgt_len = input_shape

    # A float attention bias has to share the dtype of the attention inputs,
    # so follow the embeddings instead of hard-coding bfloat16. bfloat16 stays
    # as the fallback so callers that pass no embeddings keep working.
    if inputs_embeds is not None and inputs_embeds.is_floating_point():
        mask_dtype = inputs_embeds.dtype
    else:
        mask_dtype = torch.bfloat16

    if attention_mask is None:
        # No mask supplied: nothing is masked out.
        return torch.zeros((bsz, 1, tgt_len, tgt_len), dtype=mask_dtype, device=inputs_embeds.device)

    mask_rank = attention_mask.dim()
    if mask_rank == 4:
        # NOTE(review): a rank-4 mask is treated as a 0/1 keep-mask like the
        # others; an already-additive mask (0 / large negative) would be
        # inverted by the conversion below - confirm what callers pass.
        keep_mask = attention_mask
    elif mask_rank == 3:
        keep_mask = attention_mask[:, None, :, :]
    elif mask_rank == 2:
        if past_key_values_length > 0:
            # NOTE(review): this discards the mask columns of the cached
            # positions, so the key dimension no longer covers them - verify
            # this is what the attention layers expect.
            attention_mask = attention_mask[:, past_key_values_length:]
        # NOTE(review): only padding is encoded here; despite the function
        # name, no causal (lower-triangular) structure is added - presumably
        # handled by the caller; confirm.
        keep_mask = attention_mask[:, None, None, :]
    else:
        raise ValueError(
            f"Wrong shape for input_ids (shape {input_shape}) or attention_mask (shape {attention_mask.shape})"
        )

    # Cast before clamping so that boolean masks are accepted as well, then
    # force the values into [0, 1].
    keep_mask = keep_mask.to(mask_dtype).clamp(min=0, max=1)

    # 1 (attend) -> 0.0, 0 (masked) -> -10000.0.
    return (1.0 - keep_mask) * -10000.0
|
|
|
|
|
|
|
|
|
|
|
|
|
| 133 |
|
| 134 |
|
| 135 |
# Copied from transformers.models.llama.modeling_llama._get_unpad_data
|
|
|
|
| 1174 |
self.model = QuietModel(config)
|
| 1175 |
self.vocab_size = config.vocab_size
|
| 1176 |
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
| 1177 |
+
self.router_aux_loss_coef = config.router_aux_loss_coef
|
| 1178 |
+
self.num_experts = config.num_experts
|
| 1179 |
+
self.num_experts_per_tok = config.num_experts_per_tok
|
| 1180 |
self.max_thoughts = config.max_thoughts
|
| 1181 |
self.merged_lm_and_talk_heads = config.merged_lm_and_talk_heads
|
| 1182 |
self.use_concat_talk_head = config.use_concat_talk_head
|
|
|
|
| 1239 |
self.tokenized_thought_prefix = None
|
| 1240 |
self.log_dict = defaultdict(int)
|
| 1241 |
self.eval_log_dict = defaultdict(int)
|
|
|
|
| 1242 |
self.loss_mean = loss_mean
|
|
|
|
|
|
|
| 1243 |
|
| 1244 |
self.start_embedding = nn.Parameter(torch.zeros(2, self.model.config.hidden_size))
|
| 1245 |
self.end_embedding = nn.Parameter(torch.zeros(2, self.model.config.hidden_size))
|
|
|
|
| 1248 |
self.embedding_scale = 1e2
|
| 1249 |
self.reinforce_temperature = 3
|
| 1250 |
self.base_loss_beta = 1
|
| 1251 |
+
self.thinking_usefulness_head = nn.Linear(self.model.config.hidden_size, 1)
|
| 1252 |
+
self.thinking_threshold = 0.5
|
| 1253 |
+
self.thinking_usefulness_loss_weight = 1e-2
|
| 1254 |
|
| 1255 |
# Not used in the paper:
|
| 1256 |
self.use_thought_prefix = False
|
|
|
|
| 1258 |
self.use_upper_triangular = False
|
| 1259 |
self.subtract_mean_reward = False
|
| 1260 |
self.comparison_mode = False
|
| 1261 |
+
self.gumbel_detach = False
|
| 1262 |
|
| 1263 |
# For visualization
|
| 1264 |
self.eval_mode = False
|
|
|
|
| 1357 |
|
| 1358 |
# Apply Gumbel-Softmax to the logits
|
| 1359 |
next_token_logits = F.gumbel_softmax(logits, tau=self.gumbel_temperature, hard=True, dim=-1)
|
| 1360 |
+
print("Next token logits:", next_token_logits)
|
| 1361 |
next_token_id = torch.argmax(next_token_logits, dim=-1)
|
| 1362 |
|
| 1363 |
# Append the generated token to the input sequence
|
|
|
|
| 1436 |
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 1437 |
labels: Optional[torch.LongTensor] = None,
|
| 1438 |
use_cache: Optional[bool] = None,
|
| 1439 |
+
# output_router_logits: Optional[bool] = None,
|
| 1440 |
output_attentions: Optional[bool] = None,
|
| 1441 |
output_hidden_states: Optional[bool] = None,
|
| 1442 |
return_dict: Optional[bool] = None,
|
|
|
|
| 1460 |
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
| 1461 |
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
|
| 1462 |
```"""
|
|
|
|
| 1463 |
|
| 1464 |
if not self.training:
|
| 1465 |
n_ahead_talk_to_restore = self.n_ahead_talk
|
| 1466 |
n_passes_to_restore = self.n_passes
|
| 1467 |
self.n_ahead_talk = 1
|
| 1468 |
self.n_passes = 1
|
| 1469 |
+
|
| 1470 |
+
# aux_loss = None
|
| 1471 |
+
# output_router_logits = output_router_logits if output_router_logits is not None else self.config.output_router_logits
|
| 1472 |
+
# if output_router_logits:
|
| 1473 |
+
# router_logits = outputs.router_logits if return_dict else outputs[-1]
|
| 1474 |
+
# if router_logits is not None:
|
| 1475 |
+
# aux_loss = load_balancing_loss_func(
|
| 1476 |
+
# router_logits,
|
| 1477 |
+
# self.num_experts,
|
| 1478 |
+
# self.num_experts_per_tok,
|
| 1479 |
+
# attention_mask,
|
| 1480 |
+
# )
|
| 1481 |
+
# if labels is not None:
|
| 1482 |
+
# loss += self.router_aux_loss_coef * aux_loss.to(loss.device)
|
| 1483 |
+
|
| 1484 |
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 1485 |
output_hidden_states = (
|
| 1486 |
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
|
|
|
| 1561 |
else:
|
| 1562 |
# convert to identity transform
|
| 1563 |
def lambda_transform(cur_head):
|
| 1564 |
+
# pdb.set_trace()
|
| 1565 |
if cur_head.weight.data.shape[0] != cur_head.weight.data.shape[1]:
|
| 1566 |
return torch.cat([
|
| 1567 |
torch.eye(
|
|
|
|
| 1694 |
use_cache=use_cache,
|
| 1695 |
output_attentions=output_attentions,
|
| 1696 |
output_hidden_states=output_hidden_states,
|
| 1697 |
+
# output_router_logits=output_router_logits,
|
| 1698 |
return_dict=return_dict,
|
| 1699 |
)
|
| 1700 |
|
|
|
|
| 1809 |
shift_labels = labels[..., 1 + shift_amount:].contiguous()
|
| 1810 |
# Flatten the tokens
|
| 1811 |
loss_fct = CrossEntropyLoss(reduction="none")
|
| 1812 |
+
print("Shift logits before:", shift_logits)
|
| 1813 |
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
| 1814 |
shift_labels = shift_labels.view(-1).clone()
|
| 1815 |
+
print("shift logits after:", shift_logits)
|
| 1816 |
# Enable model parallelism
|
| 1817 |
shift_labels[shift_labels == self.tokenizer.pad_token_id] = -100
|
| 1818 |
shift_labels = shift_labels.to(shift_logits.device)
|
|
|
|
| 1904 |
inputs_embeds = cur_thought_embedding.unsqueeze(0).repeat(batch_size, seq_len, 1)
|
| 1905 |
inputs_embeds = inputs_embeds.view(probabilities.size(0), probabilities.size(1), -1).to(self.model.embed_tokens.weight.dtype)
|
| 1906 |
inputs_embeds = inputs_embeds.view(probabilities.size(0), probabilities.size(1), -1).to(self.model.embed_tokens.weight.dtype)
|
| 1907 |
+
|
| 1908 |
+
# Predict the usefulness of thinking at each token position
|
| 1909 |
+
thinking_usefulness = self.thinking_usefulness_head(hidden_states).squeeze(-1)
|
| 1910 |
+
|
| 1911 |
+
# Apply a threshold to decide where to generate thoughts
|
| 1912 |
+
generate_thought_mask = thinking_usefulness > self.thinking_threshold
|
| 1913 |
+
|
| 1914 |
+
# Compute the regularization loss for thinking usefulness prediction
|
| 1915 |
+
thinking_usefulness_loss = torch.mean(thinking_usefulness * (1 - generate_thought_mask.float()))
|
| 1916 |
+
|
| 1917 |
+
# Add the regularization loss to the total loss
|
| 1918 |
+
if loss is not None:
|
| 1919 |
+
loss = loss + self.thinking_usefulness_loss_weight * thinking_usefulness_loss
|
| 1920 |
+
else:
|
| 1921 |
+
loss = self.thinking_usefulness_loss_weight * thinking_usefulness_loss
|
| 1922 |
+
|
| 1923 |
|
| 1924 |
if len(attention_mask.shape) == 2:
|
| 1925 |
breakpoint()
|
|
|
|
| 1967 |
# if shift_labels.min() == self.tokenizer.pad_token_id:
|
| 1968 |
shift_labels = torch.where(shift_labels == self.tokenizer.pad_token_id, -100, shift_labels)
|
| 1969 |
unreduced_loss = loss_fct(shift_logits, shift_labels)
|
| 1970 |
+
# print("Loss:", unreduced_loss.item()) # Print the loss before checking for NaN values
|
| 1971 |
if torch.any(unreduced_loss != unreduced_loss):
|
| 1972 |
+
# pdb.set_trace()
|
| 1973 |
raise ValueError("NaN loss")
|
| 1974 |
unreduced_loss = unreduced_loss.reshape(logits.shape[0], -1)
|
| 1975 |
loss_list.append(unreduced_loss)
|
|
|
|
| 2028 |
else:
|
| 2029 |
added_reward = original_dqn_reward
|
| 2030 |
policy_reward += added_reward
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2031 |
|
| 2032 |
for action_loglikelihoods_2d in action_loglikelihoods_list:
|
| 2033 |
train_policy_reward = policy_reward
|
|
|
|
| 2076 |
else:
|
| 2077 |
loss = cur_loss
|
| 2078 |
loss = loss / len(loss_list)
|
| 2079 |
+
loss = loss + thinking_usefulness_loss
|
| 2080 |
|
| 2081 |
loss = loss * self.base_loss_beta
|
| 2082 |
|
|
|
|
| 2098 |
|
| 2099 |
if loss is not None:
|
| 2100 |
base_log_dict["loss_train"] = loss.item()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2101 |
|
| 2102 |
if not self.training:
|
| 2103 |
self.n_ahead_talk = n_ahead_talk_to_restore
|
| 2104 |
self.n_passes = n_passes_to_restore
|
| 2105 |
+
|
| 2106 |
+
del start_embedding
|
| 2107 |
+
del end_embedding
|
| 2108 |
+
torch.cuda.empty_cache()
|
| 2109 |
+
|
| 2110 |
return CausalLMOutputWithPast(
|
| 2111 |
loss=loss if loss is not None else None,
|
| 2112 |
logits=(rm_logits if self.n_ahead > 1 else logits) if not self.output_logits_at_the_end else logits,
|