Update modeling_quiet.py
modeling_quiet.py  +9 -11  CHANGED
@@ -1682,18 +1682,16 @@ class QuietForCausalLM(QuietPreTrainedModel):
                         assert self.no_residual
                         residual_logits = self.lm_head(hidden_states)
                         talk_hidden_states = hidden_states
+                        if 'hidden_states_lm' not in locals():
+                            hidden_states_lm = hidden_states
+                        rm_hidden_states = hidden_states
+                        if ahead_idx > self.n_ahead - 1:
+                            cur_base_hidden = torch.cat([
+                                base_hidden_states[..., ahead_idx - self.n_ahead + 1:, :],
+                                base_hidden_states[..., :ahead_idx - self.n_ahead + 1, :]
+                            ], dim=-2)
                         else:
-
-                        hidden_states_lm = hidden_states
-                        rm_hidden_states = hidden_states
-                    else:
-                        if ahead_idx > self.n_ahead - 1:
-                            cur_base_hidden = torch.cat([
-                                base_hidden_states[..., ahead_idx - self.n_ahead + 1:, :],
-                                base_hidden_states[..., :ahead_idx - self.n_ahead + 1, :]
-                            ], dim=-2)
-                        else:
-                            cur_base_hidden = base_hidden_states
+                            cur_base_hidden = base_hidden_states

                         if self.use_concat_talk_head:
                             # concatenate the hidden states with the original hidden states