Crystalcareai committed on
Commit
08961d8
·
verified ·
1 Parent(s): 25accc9

Update modeling_quiet.py

Browse files
Files changed (1) hide show
  1. modeling_quiet.py +7 -7
modeling_quiet.py CHANGED
@@ -1252,7 +1252,7 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
1252
 
1253
  # For visualization
1254
  self.eval_mode = False
1255
-
1256
  num_talk = 1
1257
  talk_input_dim = config.hidden_size if not self.use_concat_talk_head else config.hidden_size * 2
1258
  if self.use_weighted_talk_head:
@@ -1273,8 +1273,6 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
1273
  self.talk_head = nn.ModuleList([nn.Sequential(
1274
  nn.Linear(talk_input_dim, talk_output_dim, bias=False)
1275
  )])
1276
-
1277
- self.mixing_head = nn.Linear(config.hidden_size * 2, 1)
1278
 
1279
  self.apply(self._init_weights)
1280
 
@@ -1668,10 +1666,12 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
1668
  residual_logits = self.talk_head[0](head_input_hidden_states)
1669
  if self.use_shallow_talk:
1670
  residual_logits = apply_head(self.lm_head, residual_logits, detach=self.optimize_lm_head_only_at_start)
1671
- residual_logits = residual_logits.to(logits.device)
1672
- mixing_weights = self.mixing_head(torch.cat([cur_base_hidden, talk_hidden_states], dim=-1))
1673
- mixing_weights = torch.sigmoid(mixing_weights)
1674
- logits = base_logits * (1 - mixing_weights) + residual_logits * mixing_weights
 
 
1675
  assert sum([self.cumulative_residual, self.clever_residual, self.skip_residual, self.no_residual]) == 1
1676
  if self.clever_residual:
1677
  if ahead_idx >= self.n_ahead - 1:
 
1252
 
1253
  # For visualization
1254
  self.eval_mode = False
1255
+
1256
  num_talk = 1
1257
  talk_input_dim = config.hidden_size if not self.use_concat_talk_head else config.hidden_size * 2
1258
  if self.use_weighted_talk_head:
 
1273
  self.talk_head = nn.ModuleList([nn.Sequential(
1274
  nn.Linear(talk_input_dim, talk_output_dim, bias=False)
1275
  )])
 
 
1276
 
1277
  self.apply(self._init_weights)
1278
 
 
1666
  residual_logits = self.talk_head[0](head_input_hidden_states)
1667
  if self.use_shallow_talk:
1668
  residual_logits = apply_head(self.lm_head, residual_logits, detach=self.optimize_lm_head_only_at_start)
1669
+ residual_logits = residual_logits.to(logits.device)
1670
+ if self.use_weighted_talk_head:
1671
+ # combine the cur_base_hidden with the talk_hidden_states according to the weighted head
1672
+ residual_logits = cur_base_hidden * (1 - residual_logits) + talk_hidden_states * residual_logits
1673
+ residual_logits = apply_head(self.lm_head, residual_logits, detach=self.optimize_lm_head_only_at_start)
1674
+
1675
  assert sum([self.cumulative_residual, self.clever_residual, self.skip_residual, self.no_residual]) == 1
1676
  if self.clever_residual:
1677
  if ahead_idx >= self.n_ahead - 1: