Crystalcareai commited on
Commit
bd80d20
·
verified ·
1 Parent(s): cbafcfb

Update modeling_quiet.py

Browse files
Files changed (1) hide show
  1. modeling_quiet.py +2 -2
modeling_quiet.py CHANGED
@@ -1110,6 +1110,7 @@ class QuietModel(QuietPreTrainedModel):
1110
  next_decoder_cache = None
1111
 
1112
  for decoder_layer in self.layers:
 
1113
  if output_hidden_states:
1114
  all_hidden_states += (hidden_states,)
1115
 
@@ -1167,6 +1168,7 @@ def nonzero_mean(x, axis=None):
1167
 
1168
  def loss_mean(x):
1169
  return x.sum() / (x != 0).sum()
 
1170
 
1171
  class QuietForCausalLM(QuietPreTrainedModel):
1172
  _tied_weights_keys = ["lm_head.weight"]
@@ -1915,8 +1917,6 @@ class QuietForCausalLM(QuietPreTrainedModel):
1915
  else:
1916
  loss_logits = logits
1917
  shift_idx = 1 + max(0, ahead_idx - (self.n_ahead - 1))
1918
- import pdb; pdb.set_trace()
1919
-
1920
  # print("initial_loss_logits contains NaN:", torch.isnan(initial_loss_logits).any().item())
1921
  # print("logits contains NaN:", torch.isnan(logits).any().item())
1922
  # print("loss_logits contains NaN:", torch.isnan(loss_logits).any().item())
 
1110
  next_decoder_cache = None
1111
 
1112
  for decoder_layer in self.layers:
1113
+ print(f"Hidden states contains NaN before layer {i}:", torch.isnan(hidden_states).any().item())
1114
  if output_hidden_states:
1115
  all_hidden_states += (hidden_states,)
1116
 
 
1168
 
1169
  def loss_mean(x):
1170
  return x.sum() / (x != 0).sum()
1171
+ print(f"Hidden states contains NaN after layer {i}:", torch.isnan(hidden_states).any().item())
1172
 
1173
  class QuietForCausalLM(QuietPreTrainedModel):
1174
  _tied_weights_keys = ["lm_head.weight"]
 
1917
  else:
1918
  loss_logits = logits
1919
  shift_idx = 1 + max(0, ahead_idx - (self.n_ahead - 1))
 
 
1920
  # print("initial_loss_logits contains NaN:", torch.isnan(initial_loss_logits).any().item())
1921
  # print("logits contains NaN:", torch.isnan(logits).any().item())
1922
  # print("loss_logits contains NaN:", torch.isnan(loss_logits).any().item())