Update modeling_quiet.py

modeling_quiet.py  CHANGED  (+110 -194)
@@ -44,7 +44,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
 from transformers.activations import ACT2FN
 from transformers.cache_utils import Cache, DynamicCache
-from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
+from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
 from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
 from transformers.modeling_utils import PreTrainedModel
 from transformers.utils import (
@@ -73,67 +73,63 @@ from reportlab.pdfgen import canvas
 from reportlab.lib.pagesizes import letter
 from reportlab.lib.colors import HexColor
 
-    x = 50
-    previous_text = current_text
-    c.showPage()
-    c.save()
+
+def _prepare_4d_causal_attention_mask_for_sdpa(attention_mask, input_shape, inputs_embeds, past_key_values_length):
+    # Compute the attention mask correctly
+    bsz, tgt_len = input_shape
+
+    # Create a 4D attention mask from a 2D tensor mask.
+    # The shape of the output attention mask is (batch_size, 1, tgt_len, src_len).
+    # The values are either 0 or 1, where 0 means padding and 1 means non-padding.
+    combined_attention_mask = None
+    if attention_mask is not None:
+        # If attention_mask already has a shape of (batch_size, 1, tgt_len, src_len),
+        # it can be used directly.
+        if attention_mask.dim() == 4:
+            combined_attention_mask = attention_mask
+        # If attention_mask has a shape of (batch_size, 1, tgt_len),
+        # expand it to (batch_size, 1, tgt_len, src_len).
+        elif attention_mask.dim() == 3:
+            expanded_attn_mask = attention_mask[:, None, :, :]
+            combined_attention_mask = expanded_attn_mask
+        # If attention_mask has a shape of (batch_size, tgt_len),
+        # expand it to (batch_size, 1, tgt_len, src_len).
+        elif attention_mask.dim() == 2:
+            # Provided a padding mask of dimensions [batch_size, seq_length]
+            # - if the model is a decoder, apply a causal mask in addition to the padding mask
+            # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
+            if past_key_values_length > 0:
+                attention_mask = attention_mask.to(dtype=torch.long)
+                attention_mask = attention_mask[:, past_key_values_length:]
+            expanded_attn_mask = attention_mask[:, None, None, :]
+            combined_attention_mask = expanded_attn_mask
+        else:
+            raise ValueError(
+                "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
+                    input_shape, attention_mask.shape
+                )
+            )
+
+    # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
+    # masked positions, this operation will create a tensor which is 0.0 for
+    # positions we want to attend and -10000.0 for masked positions.
+    # Since we are adding it to the raw scores before the softmax, this is
+    # effectively the same as removing these entirely.
+    if combined_attention_mask is not None:
+        # Ensure the attention mask values are within a reasonable range
+        combined_attention_mask = combined_attention_mask.clamp(min=0, max=1)
+
+        # Convert the attention mask to bfloat16
+        combined_attention_mask = combined_attention_mask.to(torch.bfloat16)
+
+        # Turn the 0/1 keep-mask into an additive bias: 0.0 where attended, -10000.0 where masked
+        combined_attention_mask = (1.0 - combined_attention_mask) * -10000.0
+    else:
+        combined_attention_mask = torch.zeros(
+            (bsz, 1, tgt_len, tgt_len), dtype=torch.bfloat16, device=inputs_embeds.device
+        )
+
+    return combined_attention_mask
 
 
 # Copied from transformers.models.llama.modeling_llama._get_unpad_data
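For context on the hunk above: the new helper converts a keep/ignore attention mask into an additive bias that is summed onto attention scores. Below is a minimal standalone sketch of that conversion for the common 2-D padding-mask case; it is illustrative only and not part of the commit (the -10000.0 bias convention follows the diff, the function name and shapes are assumptions).

import torch

def expand_padding_mask(attention_mask, tgt_len, dtype=torch.float32):
    # attention_mask: (batch_size, src_len) with 1 = attend, 0 = padding
    bsz, src_len = attention_mask.shape
    expanded = attention_mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
    # 0.0 where attended, -10000.0 where masked, matching the additive-bias convention above
    return (1.0 - expanded) * -10000.0

mask = torch.tensor([[1, 1, 1, 0]])
print(expand_padding_mask(mask, tgt_len=4).shape)  # torch.Size([1, 1, 4, 4])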
@@ -1178,6 +1174,9 @@ class QuietForCausalLM(QuietPreTrainedModel):
         self.model = QuietModel(config)
         self.vocab_size = config.vocab_size
         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+        self.router_aux_loss_coef = config.router_aux_loss_coef
+        self.num_experts = config.num_experts
+        self.num_experts_per_tok = config.num_experts_per_tok
         self.max_thoughts = config.max_thoughts
         self.merged_lm_and_talk_heads = config.merged_lm_and_talk_heads
         self.use_concat_talk_head = config.use_concat_talk_head
@@ -1240,10 +1239,7 @@ class QuietForCausalLM(QuietPreTrainedModel):
         self.tokenized_thought_prefix = None
         self.log_dict = defaultdict(int)
         self.eval_log_dict = defaultdict(int)
-        self.print_final_only = True
         self.loss_mean = loss_mean
-        self.all_rewards = []
-        self.all_unreduced_losses = []
 
         self.start_embedding = nn.Parameter(torch.zeros(2, self.model.config.hidden_size))
         self.end_embedding = nn.Parameter(torch.zeros(2, self.model.config.hidden_size))
@@ -1252,6 +1248,9 @@ class QuietForCausalLM(QuietPreTrainedModel):
         self.embedding_scale = 1e2
         self.reinforce_temperature = 3
         self.base_loss_beta = 1
+        self.thinking_usefulness_head = nn.Linear(self.model.config.hidden_size, 1)
+        self.thinking_threshold = 0.5
+        self.thinking_usefulness_loss_weight = 1e-2
 
         # Not used in the paper:
         self.use_thought_prefix = False
@@ -1259,7 +1258,7 @@ class QuietForCausalLM(QuietPreTrainedModel):
         self.use_upper_triangular = False
         self.subtract_mean_reward = False
         self.comparison_mode = False
-        self.gumbel_detach =
+        self.gumbel_detach = False
 
         # For visualization
         self.eval_mode = False
@@ -1358,6 +1357,7 @@ class QuietForCausalLM(QuietPreTrainedModel):
 
         # Apply Gumbel-Softmax to the logits
         next_token_logits = F.gumbel_softmax(logits, tau=self.gumbel_temperature, hard=True, dim=-1)
+        print("Next token logits:", next_token_logits)
         next_token_id = torch.argmax(next_token_logits, dim=-1)
 
         # Append the generated token to the input sequence
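The hunk above samples the next token with a hard Gumbel-Softmax; the added print only inspects the result. As a reminder of what hard=True provides, here is a small self-contained sketch (illustrative, not from the commit): the forward pass returns a one-hot vector, so argmax recovers the sampled id, while gradients still flow through the soft relaxation.

import torch
import torch.nn.functional as F

logits = torch.randn(2, 8, requires_grad=True)            # (batch, vocab)
one_hot = F.gumbel_softmax(logits, tau=1.0, hard=True, dim=-1)
token_id = one_hot.argmax(dim=-1)
print(one_hot.sum(dim=-1))   # each row sums to 1 (one-hot sample)
print(token_id)
one_hot.sum().backward()      # gradients reach `logits` despite the hard sample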
@@ -1436,6 +1436,7 @@ class QuietForCausalLM(QuietPreTrainedModel):
         inputs_embeds: Optional[torch.FloatTensor] = None,
         labels: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
+        # output_router_logits: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
@@ -1459,14 +1460,27 @@ class QuietForCausalLM(QuietPreTrainedModel):
         >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
         "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
         ```"""
-        log_dict = self.log_dict if self.training else self.eval_log_dict
 
         if not self.training:
             n_ahead_talk_to_restore = self.n_ahead_talk
             n_passes_to_restore = self.n_passes
             self.n_ahead_talk = 1
             self.n_passes = 1
-
+
+        # aux_loss = None
+        # output_router_logits = output_router_logits if output_router_logits is not None else self.config.output_router_logits
+        # if output_router_logits:
+        #     router_logits = outputs.router_logits if return_dict else outputs[-1]
+        #     if router_logits is not None:
+        #         aux_loss = load_balancing_loss_func(
+        #             router_logits,
+        #             self.num_experts,
+        #             self.num_experts_per_tok,
+        #             attention_mask,
+        #         )
+        # if labels is not None:
+        #     loss += self.router_aux_loss_coef * aux_loss.to(loss.device)
+
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
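The commented-out block above wires in a Mixtral-style router auxiliary loss via load_balancing_loss_func. As a rough illustration of what such a loss computes, here is a simplified standalone sketch; it is an assumption-laden approximation, not the transformers implementation and not part of the commit. It penalizes routers that concentrate tokens on a few experts.

import torch
import torch.nn.functional as F

def simple_load_balancing_loss(router_logits, num_experts, top_k):
    # router_logits: (num_tokens, num_experts)
    probs = F.softmax(router_logits, dim=-1)
    _, selected = torch.topk(probs, top_k, dim=-1)
    expert_mask = F.one_hot(selected, num_experts).float()   # (tokens, top_k, experts)
    tokens_per_expert = expert_mask.mean(dim=(0, 1))         # fraction of routing slots per expert
    router_prob_per_expert = probs.mean(dim=0)               # mean router probability per expert
    return num_experts * torch.sum(tokens_per_expert * router_prob_per_expert)

logits = torch.randn(16, 4)   # 16 tokens, 4 experts
print(simple_load_balancing_loss(logits, num_experts=4, top_k=2))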
@@ -1547,6 +1561,7 @@ class QuietForCausalLM(QuietPreTrainedModel):
             else:
                 # convert to identity transform
                 def lambda_transform(cur_head):
+                    # pdb.set_trace()
                     if cur_head.weight.data.shape[0] != cur_head.weight.data.shape[1]:
                         return torch.cat([
                             torch.eye(
@@ -1679,6 +1694,7 @@ class QuietForCausalLM(QuietPreTrainedModel):
                 use_cache=use_cache,
                 output_attentions=output_attentions,
                 output_hidden_states=output_hidden_states,
+                # output_router_logits=output_router_logits,
                 return_dict=return_dict,
             )
 
@@ -1793,8 +1809,10 @@ class QuietForCausalLM(QuietPreTrainedModel):
                 shift_labels = labels[..., 1 + shift_amount:].contiguous()
                 # Flatten the tokens
                 loss_fct = CrossEntropyLoss(reduction="none")
+                print("Shift logits before:", shift_logits)
                 shift_logits = shift_logits.view(-1, self.config.vocab_size)
                 shift_labels = shift_labels.view(-1).clone()
+                print("shift logits after:", shift_logits)
                 # Enable model parallelism
                 shift_labels[shift_labels == self.tokenizer.pad_token_id] = -100
                 shift_labels = shift_labels.to(shift_logits.device)
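The two added print calls above are debug output around the usual shifted next-token loss. For reference, a minimal standalone sketch of that pattern (illustrative values, not from the commit): logits at position t score the label at t+1, and padding labels are mapped to -100 so the loss ignores them.

import torch
from torch.nn import CrossEntropyLoss

vocab_size, pad_token_id = 11, 0
logits = torch.randn(2, 5, vocab_size)                   # (batch, seq, vocab)
labels = torch.tensor([[4, 7, 2, pad_token_id, pad_token_id],
                       [5, 5, 9, 3, 1]])

shift_logits = logits[:, :-1, :].reshape(-1, vocab_size)
shift_labels = labels[:, 1:].reshape(-1).clone()
shift_labels[shift_labels == pad_token_id] = -100        # ignored by the loss

loss_fct = CrossEntropyLoss(reduction="none")             # per-token losses, as in the model
per_token_loss = loss_fct(shift_logits, shift_labels)
print(per_token_loss.view(2, -1))                         # padded positions contribute 0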
@@ -1886,6 +1904,22 @@ class QuietForCausalLM(QuietPreTrainedModel):
                 inputs_embeds = cur_thought_embedding.unsqueeze(0).repeat(batch_size, seq_len, 1)
                 inputs_embeds = inputs_embeds.view(probabilities.size(0), probabilities.size(1), -1).to(self.model.embed_tokens.weight.dtype)
                 inputs_embeds = inputs_embeds.view(probabilities.size(0), probabilities.size(1), -1).to(self.model.embed_tokens.weight.dtype)
+
+                # Predict the usefulness of thinking at each token position
+                thinking_usefulness = self.thinking_usefulness_head(hidden_states).squeeze(-1)
+
+                # Apply a threshold to decide where to generate thoughts
+                generate_thought_mask = thinking_usefulness > self.thinking_threshold
+
+                # Compute the regularization loss for thinking usefulness prediction
+                thinking_usefulness_loss = torch.mean(thinking_usefulness * (1 - generate_thought_mask.float()))
+
+                # Add the regularization loss to the total loss
+                if loss is not None:
+                    loss = loss + self.thinking_usefulness_loss_weight * thinking_usefulness_loss
+                else:
+                    loss = self.thinking_usefulness_loss_weight * thinking_usefulness_loss
+
 
             if len(attention_mask.shape) == 2:
                 breakpoint()
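The block added above gates thought generation with a learned per-token "usefulness" score and regularizes scores at positions that fall below the threshold. A self-contained sketch of the same idea, using the head, threshold, and weight introduced earlier in this commit (the shapes and values here are illustrative assumptions, not the model's real dimensions):

import torch
import torch.nn as nn

hidden_size, threshold, weight = 32, 0.5, 1e-2
thinking_usefulness_head = nn.Linear(hidden_size, 1)

hidden_states = torch.randn(2, 6, hidden_size)                 # (batch, seq, hidden)
usefulness = thinking_usefulness_head(hidden_states).squeeze(-1)  # (batch, seq) scores
generate_thought_mask = usefulness > threshold                  # where a thought would be generated
reg_loss = torch.mean(usefulness * (1 - generate_thought_mask.float()))
total_loss = weight * reg_loss
print(generate_thought_mask)
print(total_loss)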
@@ -1933,7 +1967,9 @@ class QuietForCausalLM(QuietPreTrainedModel):
                 # if shift_labels.min() == self.tokenizer.pad_token_id:
                 shift_labels = torch.where(shift_labels == self.tokenizer.pad_token_id, -100, shift_labels)
                 unreduced_loss = loss_fct(shift_logits, shift_labels)
+                # print("Loss:", unreduced_loss.item())  # Print the loss before checking for NaN values
                 if torch.any(unreduced_loss != unreduced_loss):
+                    # pdb.set_trace()
                     raise ValueError("NaN loss")
                 unreduced_loss = unreduced_loss.reshape(logits.shape[0], -1)
                 loss_list.append(unreduced_loss)
@@ -1992,78 +2028,6 @@ class QuietForCausalLM(QuietPreTrainedModel):
                 else:
                     added_reward = original_dqn_reward
                 policy_reward += added_reward
-
-            if self.use_policy_loss and ahead_idx == self.n_ahead + self.n_ahead_talk - 2:
-                # only compute during the thinking phase
-                if self.use_reparam_for_thought_embeddings and (self.use_start_thought_token or self.use_end_thought_token):
-                    # sampled_start, sampled_end
-                    # calculate the log likelihood of the start and end embeddings sampled from a multivariate normal distribution
-                    # with mean start_embedding[0] and standard deviation start_embedding[1]
-                    if self.use_start_thought_token:
-                        exp_start_std = torch.exp(start_embedding[1])
-                        start_loglikelihood = -0.5 * (sampled_start.detach() - start_embedding[0]) ** 2 / exp_start_std ** 2 - start_embedding[1] - 0.5 * math.log(2 * math.pi)
-                        start_loglikelihood = start_loglikelihood.mean(dim=-1)
-                    if self.use_end_thought_token:
-                        exp_end_std = torch.exp(end_embedding[1])
-                        end_loglikelihood = -0.5 * (sampled_end.detach() - end_embedding[0]) ** 2 / exp_end_std ** 2 - end_embedding[1] - 0.5 * math.log(2 * math.pi)
-                        end_loglikelihood = end_loglikelihood.mean(dim=-1)
-                    # we use the mean instead of the sum to prevent dependence on the dimensionality of the embeddings
-                    if self.use_end_thought_token and self.use_policy_loss_for_end_thought:
-                        action_loglikelihoods_list.append(end_loglikelihood)
-                    if self.use_start_thought_token:
-                        action_loglikelihoods_list.append(start_loglikelihood)
-
-                if ahead_idx == self.n_ahead + self.n_ahead_talk - 2 and self.eval_mode:
-                    with torch.no_grad():
-                        # calculate the 0.75 quantile of the rewards
-                        filtered_tokens = input_ids[:, :policy_reward.shape[-1]].cpu().detach().numpy().flatten()
-                        filtered_tokens_mask = filtered_tokens != self.tokenizer.pad_token_id
-                        filtered_tokens = filtered_tokens[filtered_tokens_mask]
-                        filtered_rewards = policy_reward.float().cpu().detach().numpy()[:, :seq_len - self.n_ahead_talk].flatten()
-                        filtered_rewards = filtered_rewards[filtered_tokens_mask]
-
-                        abs_reward_list = np.abs(policy_reward.float().cpu().detach().numpy()[:, :seq_len - self.n_ahead_talk].flatten())
-                        abs_reward_list = abs_reward_list[filtered_tokens_mask]
-                        medium_quantile = np.quantile(abs_reward_list, 0.5)
-                        upper_quantile = np.quantile(abs_reward_list, 0.95)
-
-                        save_tokens_with_rewards_to_pdf(
-                            filtered_tokens,
-                            [0] + filtered_rewards.tolist(),
-                            self.tokenizer,
-                            output_file=f"texts/rewards_talk_{self.n_ahead_talk}_{self.training_steps}.pdf",
-                            eps=medium_quantile,
-                            eps2=upper_quantile,
-                        )
-
-                        def plot_kde(data, losses):
-                            sns.set(style="whitegrid")
-                            # Create the KDE plot
-                            sns.kdeplot(data, fill=True)
-                            # Set the plot title and labels
-                            plt.title("KDE Plot")
-                            plt.xlabel("Value")
-                            plt.ylabel("Density")
-                            # Save the plot
-                            plt.savefig(f"texts/kde_talk_{self.n_ahead_talk}_{self.training_steps}.pdf")
-                            # Close the plot
-                            plt.close()
-
-                            # Step 1: Create a base color palette
-                            base_colors = sns.color_palette("light:#5A9", n_colors=256)  # More colors for a smoother gradient
-                            base_cmap = LinearSegmentedColormap.from_list("log_light", base_colors)
-                            log_norm = LogNorm(vmin=1e-3, vmax=10)
-
-                            sns.kdeplot(x=data, y=losses, fill=True, levels=20, norm=log_norm, cut=0, linewidths=0)
-                            # limit y to 0 to 25 and x to -1 to 1
-                            plt.xlim(-1, 1)
-                            plt.ylim(0, 25)
-                            plt.savefig(f"texts/jointer_talk_{self.n_ahead_talk}_{self.training_steps}.pdf")
-                            plt.close()
-
-                        self.all_rewards.extend(filtered_rewards)
-                        self.all_unreduced_losses.extend(unreduced_loss[:, :-1].flatten()[filtered_tokens_mask].float().flatten().cpu().detach().numpy())
-                        plot_kde(self.all_rewards, self.all_unreduced_losses)
 
             for action_loglikelihoods_2d in action_loglikelihoods_list:
                 train_policy_reward = policy_reward
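The removed block above scored the sampled start/end thought embeddings under a diagonal Gaussian parameterized by a mean row and a log-std row, and also dropped the PDF/KDE visualization helpers. For reference, a standalone sketch of that log-likelihood term (not part of the file; the names and sizes here are illustrative):

import math
import torch

def mean_gaussian_loglikelihood(x, mu, log_std):
    # Per-dimension log N(x; mu, exp(log_std)^2), averaged over the embedding dimension
    var = torch.exp(log_std) ** 2
    log_prob = -0.5 * (x - mu) ** 2 / var - log_std - 0.5 * math.log(2 * math.pi)
    return log_prob.mean(dim=-1)   # mean, not sum, to stay independent of embedding size

mu = torch.zeros(16)
log_std = torch.zeros(16)           # std = 1
sample = mu + torch.exp(log_std) * torch.randn(16)   # reparameterized sample
print(mean_gaussian_loglikelihood(sample, mu, log_std))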
@@ -2112,6 +2076,7 @@ class QuietForCausalLM(QuietPreTrainedModel):
         else:
             loss = cur_loss
         loss = loss / len(loss_list)
+        loss = loss + thinking_usefulness_loss
 
         loss = loss * self.base_loss_beta
 
@@ -2133,64 +2098,15 @@ class QuietForCausalLM(QuietPreTrainedModel):
 
         if loss is not None:
             base_log_dict["loss_train"] = loss.item()
-
-        for loss_key, loss_val in base_log_dict.items():
-            log_dict[loss_key] += loss_val / self.n_tokens_print
-
-        if self.use_policy_loss and policy_reward is not None:
-            log_dict["policy_loss"] += dqn_loss / self.n_tokens_print
-            log_dict["policy_reward"] += policy_reward.mean() / self.n_tokens_print
-
-        if not loss_list:
-            if loss is not None:
-                log_dict["loss_0"] += loss / self.n_tokens_print
-        else:
-            log_dict["loss_final"] += nonzero_mean(loss_list[-1]) / self.n_tokens_print
-            log_dict["loss_talk"] += sum(nonzero_mean(cur_loss_item) for cur_loss_item in loss_list[-self.n_ahead_talk:]) / self.n_ahead_talk / self.n_tokens_print
-
-        # also log relative losses to loss_0
-        if loss_list:
-            for i in range(len(loss_list)):
-                talk_idx = min(max(i - (self.n_ahead - 1), 0), len(talk_loss_list) - 1)
-                if not talk_loss_list:
-                    cur_talk_loss = nonzero_mean(loss_list[0])
-                else:
-                    cur_talk_loss = talk_loss_list[talk_idx]
-                log_dict[f"rel_loss_{i}"] += (nonzero_mean(loss_list[i]) - cur_talk_loss) / self.n_tokens_print
-        if self.training:
-            self.training_steps += 1
-        try:
-            # if self.training_steps % (self.gradient_accumulation_steps * 256) == 0:
-            if self.wandb_enabled:
-                if self.training_steps % (self.n_tokens_print) == 0 or not self.training:  # and "0" in str(loss.device):
-                    if not self.training:
-                        new_log_dict = {}
-                        for key in list(log_dict.keys()):
-                            new_log_dict["eval_" + key] = log_dict[key]
-                        log_dict = new_log_dict
-                    log_dict["training_steps"] = self.training_steps
-                    log_dict["batch_size"] = batch_size
-                    log_dict["example_steps"] = self.training_steps * batch_size * self.gradient_accumulation_steps
-                    if self.n_ahead > 1:
-                        log_dict["compute_steps"] = self.training_steps * batch_size * (self.n_ahead + self.n_ahead_talk - 1) * self.gradient_accumulation_steps
-                    else:  # There's no overhead for talk tokens if there's no thinking
-                        log_dict["compute_steps"] = self.training_steps * batch_size * self.gradient_accumulation_steps
-                    # remove all nans
-                    for key in list(log_dict.keys()):
-                        if log_dict[key] != log_dict[key]:
-                            del log_dict[key]
-                    if self.training:
-                        wandb.log(log_dict)
-                    if self.training:
-                        self.log_dict = defaultdict(int)
-                    else:
-                        self.eval_log_dict = defaultdict(int)
-        except Exception as e:
-            pass
 
         if not self.training:
             self.n_ahead_talk = n_ahead_talk_to_restore
             self.n_passes = n_passes_to_restore
+
+        del start_embedding
+        del end_embedding
+        torch.cuda.empty_cache()
+
         return CausalLMOutputWithPast(
             loss=loss if loss is not None else None,
             logits=(rm_logits if self.n_ahead > 1 else logits) if not self.output_logits_at_the_end else logits,
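The cleanup added just before the return deletes the sampled embedding references and releases cached GPU blocks. A tiny illustration of what torch.cuda.empty_cache() affects (assumes a CUDA device is available; illustrative only, not part of the commit):

import torch

if torch.cuda.is_available():
    x = torch.randn(1024, 1024, device="cuda")
    print(torch.cuda.memory_allocated(), torch.cuda.memory_reserved())
    del x                          # drop the last reference so the block becomes cached
    torch.cuda.empty_cache()       # return cached blocks to the driver
    print(torch.cuda.memory_allocated(), torch.cuda.memory_reserved())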