Update modeling_quiet.py
Browse files
- modeling_quiet.py +6 -6
modeling_quiet.py
CHANGED
@@ -1666,12 +1666,12 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
|
|
1666 |
head_input_hidden_states = talk_hidden_states
|
1667 |
|
1668 |
residual_logits = self.talk_head[0](head_input_hidden_states)
|
1669 |
-
|
1670 |
-
|
1671 |
-
|
1672 |
-
|
1673 |
-
|
1674 |
-
|
1675 |
assert sum([self.cumulative_residual, self.clever_residual, self.skip_residual, self.no_residual]) == 1
|
1676 |
if self.clever_residual:
|
1677 |
if ahead_idx >= self.n_ahead - 1:
|
|
|
1666 |
head_input_hidden_states = talk_hidden_states
|
1667 |
|
1668 |
residual_logits = self.talk_head[0](head_input_hidden_states)
|
1669 |
+
if self.use_shallow_talk:
|
1670 |
+
residual_logits = apply_head(self.lm_head, residual_logits, detach=self.optimize_lm_head_only_at_start)
|
1671 |
+
residual_logits = residual_logits.to(logits.device)
|
1672 |
+
mixing_weights = self.mixing_head(torch.cat([cur_base_hidden, talk_hidden_states], dim=-1))
|
1673 |
+
mixing_weights = torch.sigmoid(mixing_weights)
|
1674 |
+
logits = base_logits * (1 - mixing_weights) + residual_logits * mixing_weights
|
1675 |
assert sum([self.cumulative_residual, self.clever_residual, self.skip_residual, self.no_residual]) == 1
|
1676 |
if self.clever_residual:
|
1677 |
if ahead_idx >= self.n_ahead - 1:
|