Update modeling_quiet.py
modeling_quiet.py  CHANGED  (+13 -0)
@@ -1351,6 +1351,8 @@ class QuietForCausalLM(QuietPreTrainedModel):
                 return_dict=return_dict,
             )
             new_key_values = outputs.past_key_values
+            print(f"Hidden states contains NaN: {torch.isnan(hidden_states).any().item()}")
+
             hidden_states = outputs[0]
             logits = self.lm_head(hidden_states)
             logits = logits[:, -1, :]  # Only consider the last token
@@ -1691,6 +1693,8 @@ class QuietForCausalLM(QuietPreTrainedModel):
             )

             prev_hidden_states = hidden_states
+            print(f"1696 Hidden states contains NaN: {torch.isnan(hidden_states).any().item()}")
+
             hidden_states = outputs[0]
             prev_rm_logits = rm_logits  # for policy gradient
             prev_rm_tokens = cur_rm_tokens  # for policy gradient
@@ -1814,7 +1818,12 @@ class QuietForCausalLM(QuietPreTrainedModel):
             if not attempted or self.comparison_mode:
                 rm_hidden_states = hidden_states
                 # print("Magnitude of RM hidden states before RM head", rm_hidden_states.norm())
+                print(f"RM hidden states contains NaN: {torch.isnan(rm_hidden_states).any().item()}")
+
                 rm_logits = apply_head(self.lm_head, rm_hidden_states, detach=self.optimize_lm_head_only_at_start)
+
+                print(f"RM logits contains NaN: {torch.isnan(rm_logits).any().item()}")
+

                 # don't allow it to predict the thinking token
                 if self.tokenizer_has_start_thought_token:
@@ -1876,7 +1885,11 @@ class QuietForCausalLM(QuietPreTrainedModel):

             if not contains_thought:
                 with torch.set_grad_enabled(not self.train_only_thinking_embedding):
+                    print(f"Probabilities_2d contains NaN: {torch.isnan(probabilities_2d).any().item()}")
+
                     inputs_embeds = probabilities_2d @ (self.model.embed_tokens.weight.to(probabilities.device).to(probabilities.dtype))
+                    print(f"Inputs_embeds contains NaN: {torch.isnan(inputs_embeds).any().item()}")
+
             else:
                 thought_id = self.start_token_id if contains_start else self.end_token_id
                 cur_thought_embedding = start_embedding if contains_start else end_embedding
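Note that the check added at new line 1354 runs before hidden_states is reassigned from outputs[0], so it reports on the tensor carried over from a previous step rather than the fresh model output (and would raise a NameError if hidden_states were not yet bound on the first pass); the check at 1696 likewise inspects the value from the prior iteration.

The same isnan-and-print pattern repeats at every probe point. As a minimal sketch (not part of this commit; the helper name is invented, and it additionally uses torch.isfinite to catch the infinities that often precede NaNs), the probes could be folded into one function:

import torch

def report_non_finite(name: str, tensor: torch.Tensor) -> bool:
    # Hypothetical debugging helper, mirroring the prints added above:
    # reports NaN/inf counts for a tensor and returns whether any were found.
    # torch.isfinite is False for both NaN and +/-inf entries.
    if not torch.isfinite(tensor).all().item():
        nan_count = torch.isnan(tensor).sum().item()
        inf_count = torch.isinf(tensor).sum().item()
        print(f"{name}: {nan_count} NaN / {inf_count} inf of {tensor.numel()} elements")
        return True
    return False

# Example probe at one of the sites patched above:
# report_non_finite("rm_logits", rm_logits)

For pinpointing the operation that first produces a NaN gradient, PyTorch's built-in torch.autograd.set_detect_anomaly(True) is an alternative to manual prints: the backward pass then errors out at the offending op, at the cost of noticeably slower training.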