Crystalcareai commited on
Commit
b2672e5
·
verified ·
1 Parent(s): b28a110

Update modeling_quiet.py

Browse files
Files changed (1) hide show
  1. modeling_quiet.py +9 -11
modeling_quiet.py CHANGED
@@ -1682,18 +1682,16 @@ class QuietForCausalLM(QuietPreTrainedModel):
1682
  assert self.no_residual
1683
  residual_logits = self.lm_head(hidden_states)
1684
  talk_hidden_states = hidden_states
 
 
 
 
 
 
 
 
1685
  else:
1686
- if 'hidden_states_lm' not in locals():
1687
- hidden_states_lm = hidden_states
1688
- rm_hidden_states = hidden_states
1689
- else:
1690
- if ahead_idx > self.n_ahead - 1:
1691
- cur_base_hidden = torch.cat([
1692
- base_hidden_states[..., ahead_idx - self.n_ahead + 1:, :],
1693
- base_hidden_states[..., :ahead_idx - self.n_ahead + 1, :]
1694
- ], dim=-2)
1695
- else:
1696
- cur_base_hidden = base_hidden_states
1697
 
1698
  if self.use_concat_talk_head:
1699
  # concatenate the hidden states with the original hidden states
 
1682
  assert self.no_residual
1683
  residual_logits = self.lm_head(hidden_states)
1684
  talk_hidden_states = hidden_states
1685
+ if 'hidden_states_lm' not in locals():
1686
+ hidden_states_lm = hidden_states
1687
+ rm_hidden_states = hidden_states
1688
+ if ahead_idx > self.n_ahead - 1:
1689
+ cur_base_hidden = torch.cat([
1690
+ base_hidden_states[..., ahead_idx - self.n_ahead + 1:, :],
1691
+ base_hidden_states[..., :ahead_idx - self.n_ahead + 1, :]
1692
+ ], dim=-2)
1693
  else:
1694
+ cur_base_hidden = base_hidden_states
 
 
 
 
 
 
 
 
 
 
1695
 
1696
  if self.use_concat_talk_head:
1697
  # concatenate the hidden states with the original hidden states