Crystalcareai committed on
Commit
8d44852
·
verified ·
1 Parent(s): 275d80c

Update modeling_quiet.py

Browse files
Files changed (1) hide show
  1. modeling_quiet.py +13 -0
modeling_quiet.py CHANGED
@@ -1351,6 +1351,8 @@ class QuietForCausalLM(QuietPreTrainedModel):
1351
  return_dict=return_dict,
1352
  )
1353
  new_key_values = outputs.past_key_values
 
 
1354
  hidden_states = outputs[0]
1355
  logits = self.lm_head(hidden_states)
1356
  logits = logits[:, -1, :] # Only consider the last token
@@ -1691,6 +1693,8 @@ class QuietForCausalLM(QuietPreTrainedModel):
1691
  )
1692
 
1693
  prev_hidden_states = hidden_states
 
 
1694
  hidden_states = outputs[0]
1695
  prev_rm_logits = rm_logits # for policy gradient
1696
  prev_rm_tokens = cur_rm_tokens # for policy gradient
@@ -1814,7 +1818,12 @@ class QuietForCausalLM(QuietPreTrainedModel):
1814
  if not attempted or self.comparison_mode:
1815
  rm_hidden_states = hidden_states
1816
  # print("Magnitude of RM hidden states before RM head", rm_hidden_states.norm())
 
 
1817
  rm_logits = apply_head(self.lm_head, rm_hidden_states, detach=self.optimize_lm_head_only_at_start)
 
 
 
1818
 
1819
  # don't allow it to predict the thinking token
1820
  if self.tokenizer_has_start_thought_token:
@@ -1876,7 +1885,11 @@ class QuietForCausalLM(QuietPreTrainedModel):
1876
 
1877
  if not contains_thought:
1878
  with torch.set_grad_enabled(not self.train_only_thinking_embedding):
 
 
1879
  inputs_embeds = probabilities_2d @ (self.model.embed_tokens.weight.to(probabilities.device).to(probabilities.dtype))
 
 
1880
  else:
1881
  thought_id = self.start_token_id if contains_start else self.end_token_id
1882
  cur_thought_embedding = start_embedding if contains_start else end_embedding
 
1351
  return_dict=return_dict,
1352
  )
1353
  new_key_values = outputs.past_key_values
1354
+ print(f"Hidden states contains NaN: {torch.isnan(hidden_states).any().item()}")
1355
+
1356
  hidden_states = outputs[0]
1357
  logits = self.lm_head(hidden_states)
1358
  logits = logits[:, -1, :] # Only consider the last token
 
1693
  )
1694
 
1695
  prev_hidden_states = hidden_states
1696
+ print(f"1696 Hidden states contains NaN: {torch.isnan(hidden_states).any().item()}")
1697
+
1698
  hidden_states = outputs[0]
1699
  prev_rm_logits = rm_logits # for policy gradient
1700
  prev_rm_tokens = cur_rm_tokens # for policy gradient
 
1818
  if not attempted or self.comparison_mode:
1819
  rm_hidden_states = hidden_states
1820
  # print("Magnitude of RM hidden states before RM head", rm_hidden_states.norm())
1821
+ print(f"RM hidden states contains NaN: {torch.isnan(rm_hidden_states).any().item()}")
1822
+
1823
  rm_logits = apply_head(self.lm_head, rm_hidden_states, detach=self.optimize_lm_head_only_at_start)
1824
+
1825
+ print(f"RM logits contains NaN: {torch.isnan(rm_logits).any().item()}")
1826
+
1827
 
1828
  # don't allow it to predict the thinking token
1829
  if self.tokenizer_has_start_thought_token:
 
1885
 
1886
  if not contains_thought:
1887
  with torch.set_grad_enabled(not self.train_only_thinking_embedding):
1888
+ print(f"Probabilities_2d contains NaN: {torch.isnan(probabilities_2d).any().item()}")
1889
+
1890
  inputs_embeds = probabilities_2d @ (self.model.embed_tokens.weight.to(probabilities.device).to(probabilities.dtype))
1891
+ print(f"Inputs_embeds contains NaN: {torch.isnan(inputs_embeds).any().item()}")
1892
+
1893
  else:
1894
  thought_id = self.start_token_id if contains_start else self.end_token_id
1895
  cur_thought_embedding = start_embedding if contains_start else end_embedding