Update modeling_quiet.py
Browse files- modeling_quiet.py +7 -7
modeling_quiet.py
CHANGED
@@ -1252,7 +1252,7 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
|
|
1252 |
|
1253 |
# For visualization
|
1254 |
self.eval_mode = False
|
1255 |
-
|
1256 |
num_talk = 1
|
1257 |
talk_input_dim = config.hidden_size if not self.use_concat_talk_head else config.hidden_size * 2
|
1258 |
if self.use_weighted_talk_head:
|
@@ -1273,8 +1273,6 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
|
|
1273 |
self.talk_head = nn.ModuleList([nn.Sequential(
|
1274 |
nn.Linear(talk_input_dim, talk_output_dim, bias=False)
|
1275 |
)])
|
1276 |
-
|
1277 |
-
self.mixing_head = nn.Linear(config.hidden_size * 2, 1)
|
1278 |
|
1279 |
self.apply(self._init_weights)
|
1280 |
|
@@ -1668,10 +1666,12 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
|
|
1668 |
residual_logits = self.talk_head[0](head_input_hidden_states)
|
1669 |
if self.use_shallow_talk:
|
1670 |
residual_logits = apply_head(self.lm_head, residual_logits, detach=self.optimize_lm_head_only_at_start)
|
1671 |
-
|
1672 |
-
|
1673 |
-
|
1674 |
-
|
|
|
|
|
1675 |
assert sum([self.cumulative_residual, self.clever_residual, self.skip_residual, self.no_residual]) == 1
|
1676 |
if self.clever_residual:
|
1677 |
if ahead_idx >= self.n_ahead - 1:
|
|
|
1252 |
|
1253 |
# For visualization
|
1254 |
self.eval_mode = False
|
1255 |
+
|
1256 |
num_talk = 1
|
1257 |
talk_input_dim = config.hidden_size if not self.use_concat_talk_head else config.hidden_size * 2
|
1258 |
if self.use_weighted_talk_head:
|
|
|
1273 |
self.talk_head = nn.ModuleList([nn.Sequential(
|
1274 |
nn.Linear(talk_input_dim, talk_output_dim, bias=False)
|
1275 |
)])
|
|
|
|
|
1276 |
|
1277 |
self.apply(self._init_weights)
|
1278 |
|
|
|
1666 |
residual_logits = self.talk_head[0](head_input_hidden_states)
|
1667 |
if self.use_shallow_talk:
|
1668 |
residual_logits = apply_head(self.lm_head, residual_logits, detach=self.optimize_lm_head_only_at_start)
|
1669 |
+
residual_logits = residual_logits.to(logits.device)
|
1670 |
+
if self.use_weighted_talk_head:
|
1671 |
+
# combine the cur_base_hidden with the talk_hidden_states according to the weighted head
|
1672 |
+
residual_logits = cur_base_hidden * (1 - residual_logits) + talk_hidden_states * residual_logits
|
1673 |
+
residual_logits = apply_head(self.lm_head, residual_logits, detach=self.optimize_lm_head_only_at_start)
|
1674 |
+
|
1675 |
assert sum([self.cumulative_residual, self.clever_residual, self.skip_residual, self.no_residual]) == 1
|
1676 |
if self.clever_residual:
|
1677 |
if ahead_idx >= self.n_ahead - 1:
|