Update modeling_quiet.py
Browse files: modeling_quiet.py (+6 -1)
modeling_quiet.py
CHANGED
@@ -1662,7 +1662,7 @@ class QuietForCausalLM(QuietPreTrainedModel):
|
|
1662 |
prev_rm_logits = rm_logits # for policy gradient
|
1663 |
prev_rm_tokens = cur_rm_tokens # for policy gradient
|
1664 |
|
1665 |
-
hidden_states_lm =
|
1666 |
logits = self.lm_head(hidden_states_lm)
|
1667 |
|
1668 |
if ahead_idx == 0:
|
@@ -1682,6 +1682,10 @@ class QuietForCausalLM(QuietPreTrainedModel):
|
|
1682 |
assert self.no_residual
|
1683 |
residual_logits = self.lm_head(hidden_states)
|
1684 |
talk_hidden_states = hidden_states
|
|
|
|
|
|
|
|
|
1685 |
else:
|
1686 |
if ahead_idx > self.n_ahead - 1:
|
1687 |
cur_base_hidden = torch.cat([
|
@@ -1780,6 +1784,7 @@ class QuietForCausalLM(QuietPreTrainedModel):
|
|
1780 |
if not self.comparison_mode and not (self.optimize_lm_head_only_at_start and (self.n_ahead + self.n_ahead_talk > 2)) or self.original_mode:
|
1781 |
loss_list.append(loss)
|
1782 |
talk_loss_list.append(nonzero_mean(loss).detach())
|
|
|
1783 |
|
1784 |
if not attempted or self.comparison_mode:
|
1785 |
rm_hidden_states = hidden_states
|
|
|
1662 |
prev_rm_logits = rm_logits # for policy gradient
|
1663 |
prev_rm_tokens = cur_rm_tokens # for policy gradient
|
1664 |
|
1665 |
+
hidden_states_lm = hidden_states
|
1666 |
logits = self.lm_head(hidden_states_lm)
|
1667 |
|
1668 |
if ahead_idx == 0:
|
|
|
1682 |
assert self.no_residual
|
1683 |
residual_logits = self.lm_head(hidden_states)
|
1684 |
talk_hidden_states = hidden_states
|
1685 |
+
else:
|
1686 |
+
if 'hidden_states_lm' not in locals():
|
1687 |
+
hidden_states_lm = hidden_states
|
1688 |
+
rm_hidden_states = hidden_states
|
1689 |
else:
|
1690 |
if ahead_idx > self.n_ahead - 1:
|
1691 |
cur_base_hidden = torch.cat([
|
|
|
1784 |
if not self.comparison_mode and not (self.optimize_lm_head_only_at_start and (self.n_ahead + self.n_ahead_talk > 2)) or self.original_mode:
|
1785 |
loss_list.append(loss)
|
1786 |
talk_loss_list.append(nonzero_mean(loss).detach())
|
1787 |
+
|
1788 |
|
1789 |
if not attempted or self.comparison_mode:
|
1790 |
rm_hidden_states = hidden_states
|