Crystalcareai commited on
Commit
25accc9
·
verified ·
1 Parent(s): 5e5e800

Update modeling_quiet.py

Browse files
Files changed (1) hide show
  1. modeling_quiet.py +6 -6
modeling_quiet.py CHANGED
@@ -1666,12 +1666,12 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
1666
  head_input_hidden_states = talk_hidden_states
1667
 
1668
  residual_logits = self.talk_head[0](head_input_hidden_states)
1669
- if self.use_shallow_talk:
1670
- residual_logits = apply_head(self.lm_head, residual_logits, detach=self.optimize_lm_head_only_at_start)
1671
- residual_logits = residual_logits.to(logits.device)
1672
- mixing_weights = self.mixing_head(torch.cat([cur_base_hidden, talk_hidden_states], dim=-1))
1673
- mixing_weights = torch.sigmoid(mixing_weights)
1674
- logits = base_logits * (1 - mixing_weights) + residual_logits * mixing_weights
1675
  assert sum([self.cumulative_residual, self.clever_residual, self.skip_residual, self.no_residual]) == 1
1676
  if self.clever_residual:
1677
  if ahead_idx >= self.n_ahead - 1:
 
1666
  head_input_hidden_states = talk_hidden_states
1667
 
1668
  residual_logits = self.talk_head[0](head_input_hidden_states)
1669
+ if self.use_shallow_talk:
1670
+ residual_logits = apply_head(self.lm_head, residual_logits, detach=self.optimize_lm_head_only_at_start)
1671
+ residual_logits = residual_logits.to(logits.device)
1672
+ mixing_weights = self.mixing_head(torch.cat([cur_base_hidden, talk_hidden_states], dim=-1))
1673
+ mixing_weights = torch.sigmoid(mixing_weights)
1674
+ logits = base_logits * (1 - mixing_weights) + residual_logits * mixing_weights
1675
  assert sum([self.cumulative_residual, self.clever_residual, self.skip_residual, self.no_residual]) == 1
1676
  if self.clever_residual:
1677
  if ahead_idx >= self.n_ahead - 1: