Crystalcareai committed
Commit 275d80c · verified · 1 Parent(s): 40e9ae3

Update modeling_quiet.py

Files changed (1): modeling_quiet.py (+22, -28)
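The commit's only change is deleting ad-hoc NaN probes that printed on every forward pass. If that diagnostic is still wanted, it can be kept behind an opt-in flag instead of being removed outright; the helper below is a sketch, not code from the repository (warn_if_nan and its enabled flag are invented here):

```python
from typing import Optional
import torch

def warn_if_nan(name: str, tensor: Optional[torch.Tensor], enabled: bool = False) -> None:
    # Same check the deleted prints performed, but opt-in so the hot path
    # stays quiet. Integer tensors such as input_ids can never hold NaN,
    # so they are skipped up front.
    if not enabled or tensor is None or not tensor.is_floating_point():
        return
    if torch.isnan(tensor).any().item():
        print(f"{name} contains NaN")

# usage, mirroring the deleted probes (debug_nans is a hypothetical flag):
# warn_if_nan("attention_mask", attention_mask, enabled=debug_nans)
```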
modeling_quiet.py CHANGED
@@ -1616,27 +1616,24 @@ class QuietForCausalLM(QuietPreTrainedModel):
         base_embeddings = self.model.embed_tokens.weight
         if self.train_only_thinking_embedding:
             base_embeddings = base_embeddings.detach()
-        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
-        fwd_iters = 1 if self.original_mode else self.n_ahead + self.n_ahead_talk - 1
-        print("Input IDs contains NaN:", torch.isnan(input_ids).any().item())
-        print("Attention mask contains NaN:", torch.isnan(attention_mask).any().item())
-        print("Labels contains NaN:", torch.isnan(labels).any().item() if labels is not None else False)
-        for ahead_idx in range(fwd_iters):
-            past_key_values_length = 0
-            if past_key_values is not None:
-                use_legacy_cache = not isinstance(past_key_values, Cache)
-                if use_legacy_cache:
-                    past_key_values = DynamicCache.from_legacy_cache(past_key_values)
-                past_key_values_length = past_key_values.get_usable_length(seq_len)
-
-            if position_ids is None:
-                device = input_ids.device if input_ids is not None else inputs_embeds.device
-                position_ids = torch.arange(
-                    past_key_values_length, seq_len + past_key_values_length, dtype=torch.long, device=device
-                )
-                position_ids = position_ids.unsqueeze(0).view(-1, seq_len)
-            else:
-                position_ids = position_ids.view(-1, seq_len).long()
+        # # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+        fwd_iters = 1 if self.original_mode else self.n_ahead + self.n_ahead_talk - 1
+        for ahead_idx in range(fwd_iters):
+            past_key_values_length = 0
+            if past_key_values is not None:
+                use_legacy_cache = not isinstance(past_key_values, Cache)
+                if use_legacy_cache:
+                    past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+                past_key_values_length = past_key_values.get_usable_length(seq_len)
+
+            if position_ids is None:
+                device = input_ids.device if input_ids is not None else inputs_embeds.device
+                position_ids = torch.arange(
+                    past_key_values_length, seq_len + past_key_values_length, dtype=torch.long, device=device
+                )
+                position_ids = position_ids.unsqueeze(0).view(-1, seq_len)
+            else:
+                position_ids = position_ids.view(-1, seq_len).long()
 
         if inputs_embeds is None:
             contains_start = self.use_start_thought_token and (input_ids == self.start_token_id).any()
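In original_mode a single forward pass runs; otherwise there are n_ahead + n_ahead_talk - 1 passes (for example, n_ahead = 8 and n_ahead_talk = 4 give 11). Each pass normalizes a legacy tuple cache into a Cache object and numbers the incoming positions after the ones the cache already covers. A toy illustration of that bookkeeping, assuming a recent transformers release that exports DynamicCache:

```python
import torch
from transformers import DynamicCache  # available in recent transformers releases

# Toy shapes; the real code derives seq_len from input_ids / inputs_embeds.
batch, heads, past_len, head_dim, seq_len = 1, 2, 4, 8, 3

# A legacy cache is a tuple of per-layer (key, value) pairs.
legacy = tuple(
    (torch.zeros(batch, heads, past_len, head_dim),
     torch.zeros(batch, heads, past_len, head_dim))
    for _ in range(2)  # two layers
)

past_key_values = DynamicCache.from_legacy_cache(legacy)
past_key_values_length = past_key_values.get_usable_length(seq_len)

# New tokens are numbered after the positions the cache already covers.
position_ids = torch.arange(
    past_key_values_length, seq_len + past_key_values_length, dtype=torch.long
).unsqueeze(0).view(-1, seq_len)

print(past_key_values_length, position_ids)  # 4 tensor([[4, 5, 6]])
```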
@@ -1697,8 +1694,7 @@ class QuietForCausalLM(QuietPreTrainedModel):
             hidden_states = outputs[0]
             prev_rm_logits = rm_logits  # for policy gradient
             prev_rm_tokens = cur_rm_tokens  # for policy gradient
-            print("Hidden states contains NaN after loop:", torch.isnan(hidden_states).any().item())
-            print("Logits contains NaN after loop:", torch.isnan(logits).any().item())
+
             if ahead_idx == 0:
                 hidden_states_lm = hidden_states
                 logits = self.lm_head(hidden_states_lm)
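Only the first of those passes (ahead_idx == 0) carries the unmodified language-model hidden states, so they are projected through lm_head once and kept as the base logits. A minimal sketch of that pattern, with made-up shapes standing in for the real model:

```python
import torch
import torch.nn as nn

hidden_size, vocab_size, seq_len = 16, 32, 5
lm_head = nn.Linear(hidden_size, vocab_size, bias=False)  # stand-in for self.lm_head

logits = None
for ahead_idx in range(3):  # stands in for `for ahead_idx in range(fwd_iters)`
    hidden_states = torch.randn(1, seq_len, hidden_size)  # stand-in decoder output
    if ahead_idx == 0:
        # only the first pass is the plain LM distribution; keep its projection
        hidden_states_lm = hidden_states
        logits = lm_head(hidden_states_lm)

print(logits.shape)  # torch.Size([1, 5, 32])
```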
@@ -2088,7 +2084,7 @@ class QuietForCausalLM(QuietPreTrainedModel):
                     # This will only happen when we force the next token to be the end of thought token
                     break
                 dqn_loss_list.append(actor_loss.mean())
-        print("Loss list contains NaN before loss computation:", any(torch.isnan(loss).any() for loss in loss_list))
+
         if loss_list:
             if self.first_and_last_mode:
                 loss = sum(
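The probe deleted here scanned loss_list for NaNs on every call. When the goal is to find where a NaN is first produced, PyTorch's built-in anomaly detection is a heavier but far more precise tool than prints; a self-contained illustration:

```python
import torch

# Enable only for debugging runs: anomaly detection slows autograd down a lot,
# but it raises at the exact backward op that produced a NaN.
with torch.autograd.set_detect_anomaly(True):
    x = torch.tensor([-1.0], requires_grad=True)
    y = torch.sqrt(x)       # forward value is NaN
    try:
        y.sum().backward()  # sqrt's backward at -1 is NaN -> RuntimeError
    except RuntimeError as err:
        print("anomaly detected:", err)
```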
@@ -2116,20 +2112,18 @@ class QuietForCausalLM(QuietPreTrainedModel):
                 loss = loss / len(loss_list)
 
             loss = loss * self.base_loss_beta
-            print("Loss contains NaN after loss computation:", torch.isnan(loss).any().item())
-            print("DQN loss list contains NaN before loss computation:", any(torch.isnan(loss).any() for loss in dqn_loss_list))
+
         if dqn_loss_list:
             dqn_loss = sum(dqn_loss_list) / len(dqn_loss_list)
-            print("DQN loss contains NaN after loss computation:", torch.isnan(dqn_loss).any().item())
             if self.include_policy_loss:
                 if loss is not None:
                     loss += dqn_loss * self.policy_loss_beta
                 else:
                     loss = dqn_loss * self.policy_loss_beta
+
         if not return_dict:
             output = (logits,) + outputs[1:]
             return (loss,) + output if loss is not None else output
-        print("DQN loss contains NaN after loss computation:", torch.isnan(dqn_loss).any().item())
 
         base_log_dict = {
             f"loss_{i}": nonzero_mean(loss_list[i]) for i in range(len(loss_list))
 