Update modeling_quiet.py
modeling_quiet.py  CHANGED  +22 -28
@@ -1616,27 +1616,24 @@ class QuietForCausalLM(QuietPreTrainedModel):
         base_embeddings = self.model.embed_tokens.weight
         if self.train_only_thinking_embedding:
             base_embeddings = base_embeddings.detach()
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-                position_ids = position_ids.unsqueeze(0).view(-1, seq_len)
-            else:
-                position_ids = position_ids.view(-1, seq_len).long()
+        # # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+        fwd_iters = 1 if self.original_mode else self.n_ahead + self.n_ahead_talk - 1
+        for ahead_idx in range(fwd_iters):
+            past_key_values_length = 0
+            if past_key_values is not None:
+                use_legacy_cache = not isinstance(past_key_values, Cache)
+                if use_legacy_cache:
+                    past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+                past_key_values_length = past_key_values.get_usable_length(seq_len)
+
+            if position_ids is None:
+                device = input_ids.device if input_ids is not None else inputs_embeds.device
+                position_ids = torch.arange(
+                    past_key_values_length, seq_len + past_key_values_length, dtype=torch.long, device=device
+                )
+                position_ids = position_ids.unsqueeze(0).view(-1, seq_len)
+            else:
+                position_ids = position_ids.view(-1, seq_len).long()
 
             if inputs_embeds is None:
                 contains_start = self.use_start_thought_token and (input_ids == self.start_token_id).any()
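Note: the restored block is the stock Hugging Face pattern for offsetting position ids by the cached sequence length, run once per ahead-iteration (e.g. with n_ahead=8 thought tokens and n_ahead_talk=4 talk tokens, fwd_iters = 8 + 4 - 1 = 11 forward passes). A minimal runnable sketch of the cache/position part, assuming transformers >= 4.36, where Cache and DynamicCache ship in transformers.cache_utils; prepare_positions is an illustrative helper, not a function from this file:

    import torch
    from transformers.cache_utils import Cache, DynamicCache

    def prepare_positions(input_ids, past_key_values=None, position_ids=None):
        # Convert a legacy tuple-of-tuples cache to DynamicCache, then build
        # position ids starting at the number of tokens already cached.
        seq_len = input_ids.shape[1]
        past_key_values_length = 0
        if past_key_values is not None:
            if not isinstance(past_key_values, Cache):  # legacy cache format
                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
            past_key_values_length = past_key_values.get_usable_length(seq_len)
        if position_ids is None:
            position_ids = torch.arange(
                past_key_values_length, seq_len + past_key_values_length,
                dtype=torch.long, device=input_ids.device,
            ).unsqueeze(0).view(-1, seq_len)
        else:
            position_ids = position_ids.view(-1, seq_len).long()
        return position_ids, past_key_values

    # First pass, no cache: positions start at 0.
    pos, cache = prepare_positions(torch.tensor([[5, 6, 7]]))  # pos == [[0, 1, 2]]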
@@ -1697,8 +1694,7 @@ class QuietForCausalLM(QuietPreTrainedModel):
             hidden_states = outputs[0]
             prev_rm_logits = rm_logits  # for policy gradient
             prev_rm_tokens = cur_rm_tokens  # for policy gradient
-
-            print("Logits contains NaN after loop:", torch.isnan(logits).any().item())
+
             if ahead_idx == 0:
                 hidden_states_lm = hidden_states
                 logits = self.lm_head(hidden_states_lm)
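The line dropped here was an unconditional NaN probe on logits, printed on every ahead-iteration. If that check is still useful during development, gating it behind a flag keeps the hot loop quiet in normal runs; DEBUG_NANS and check_nan below are illustrative names, not part of the model:

    import os
    import torch

    DEBUG_NANS = os.environ.get("QUIET_DEBUG_NANS") == "1"  # opt-in debug flag

    def check_nan(name: str, t: torch.Tensor) -> None:
        # Pays the torch.isnan cost only when the flag is set.
        if DEBUG_NANS and torch.isnan(t).any().item():
            raise RuntimeError(f"{name} contains NaN")

    # Inside the loop, replacing the deleted print:
    # check_nan("logits", logits)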
@@ -2088,7 +2084,7 @@ class QuietForCausalLM(QuietPreTrainedModel):
                 # This will only happen when we force the next token to be the end of thought token
                 break
             dqn_loss_list.append(actor_loss.mean())
-
+
         if loss_list:
             if self.first_and_last_mode:
                 loss = sum(
@@ -2116,20 +2112,18 @@ class QuietForCausalLM(QuietPreTrainedModel):
             loss = loss / len(loss_list)
 
         loss = loss * self.base_loss_beta
-
-        print("DQN loss list contains NaN before loss computation:", any(torch.isnan(loss).any() for loss in dqn_loss_list))
+
         if dqn_loss_list:
             dqn_loss = sum(dqn_loss_list) / len(dqn_loss_list)
-            print("DQN loss contains NaN after loss computation:", torch.isnan(dqn_loss).any().item())
             if self.include_policy_loss:
                 if loss is not None:
                     loss += dqn_loss * self.policy_loss_beta
                 else:
                     loss = dqn_loss * self.policy_loss_beta
+
         if not return_dict:
             output = (logits,) + outputs[1:]
             return (loss,) + output if loss is not None else output
-        print("DQN loss contains NaN after loss computation:", torch.isnan(dqn_loss).any().item())
 
         base_log_dict = {
             f"loss_{i}": nonzero_mean(loss_list[i]) for i in range(len(loss_list))
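For reference, the arithmetic these hunks clean up: the averaged language-modeling loss is scaled by base_loss_beta, the averaged DQN/policy loss by policy_loss_beta, and the two are summed. A simplified stand-alone recomputation (combine_losses is an illustrative helper; it ignores the first_and_last_mode branch, and the beta defaults are placeholders, not the model's configured values):

    import torch

    def combine_losses(loss_list, dqn_loss_list, base_loss_beta=1.0,
                       policy_loss_beta=1e6, include_policy_loss=True):
        # Mean over per-step losses, scale each term by its beta, then sum.
        loss = None
        if loss_list:
            loss = sum(loss_list) / len(loss_list) * base_loss_beta
        if dqn_loss_list and include_policy_loss:
            dqn_loss = sum(dqn_loss_list) / len(dqn_loss_list)
            loss = dqn_loss * policy_loss_beta if loss is None else loss + dqn_loss * policy_loss_beta
        return loss

    # Two base losses, one policy loss:
    print(combine_losses([torch.tensor(2.0), torch.tensor(4.0)],
                         [torch.tensor(1e-6)]))  # tensor(4.)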