Crystalcareai committed on
Commit
b28a110
·
verified ·
1 Parent(s): cd900ce

Update modeling_quiet.py

Browse files
Files changed (1) hide show
  1. modeling_quiet.py +6 -1
modeling_quiet.py CHANGED
@@ -1662,7 +1662,7 @@ class QuietForCausalLM(QuietPreTrainedModel):
1662
  prev_rm_logits = rm_logits # for policy gradient
1663
  prev_rm_tokens = cur_rm_tokens # for policy gradient
1664
 
1665
- hidden_states_lm = hidden_states_lm.to(self.lm_head.weight.dtype)
1666
  logits = self.lm_head(hidden_states_lm)
1667
 
1668
  if ahead_idx == 0:
@@ -1682,6 +1682,10 @@ class QuietForCausalLM(QuietPreTrainedModel):
1682
  assert self.no_residual
1683
  residual_logits = self.lm_head(hidden_states)
1684
  talk_hidden_states = hidden_states
 
 
 
 
1685
  else:
1686
  if ahead_idx > self.n_ahead - 1:
1687
  cur_base_hidden = torch.cat([
@@ -1780,6 +1784,7 @@ class QuietForCausalLM(QuietPreTrainedModel):
1780
  if not self.comparison_mode and not (self.optimize_lm_head_only_at_start and (self.n_ahead + self.n_ahead_talk > 2)) or self.original_mode:
1781
  loss_list.append(loss)
1782
  talk_loss_list.append(nonzero_mean(loss).detach())
 
1783
 
1784
  if not attempted or self.comparison_mode:
1785
  rm_hidden_states = hidden_states
 
1662
  prev_rm_logits = rm_logits # for policy gradient
1663
  prev_rm_tokens = cur_rm_tokens # for policy gradient
1664
 
1665
+ hidden_states_lm = hidden_states
1666
  logits = self.lm_head(hidden_states_lm)
1667
 
1668
  if ahead_idx == 0:
 
1682
  assert self.no_residual
1683
  residual_logits = self.lm_head(hidden_states)
1684
  talk_hidden_states = hidden_states
1685
+ else:
1686
+ if 'hidden_states_lm' not in locals():
1687
+ hidden_states_lm = hidden_states
1688
+ rm_hidden_states = hidden_states
1689
  else:
1690
  if ahead_idx > self.n_ahead - 1:
1691
  cur_base_hidden = torch.cat([
 
1784
  if not self.comparison_mode and not (self.optimize_lm_head_only_at_start and (self.n_ahead + self.n_ahead_talk > 2)) or self.original_mode:
1785
  loss_list.append(loss)
1786
  talk_loss_list.append(nonzero_mean(loss).detach())
1787
+
1788
 
1789
  if not attempted or self.comparison_mode:
1790
  rm_hidden_states = hidden_states