Crystalcareai committed on
Commit
7530909
·
verified ·
1 Parent(s): da4fa77

Update modeling_quiet.py

Browse files
Files changed (1) hide show
  1. modeling_quiet.py +2 -1
modeling_quiet.py CHANGED
@@ -136,7 +136,6 @@ def save_tokens_with_rewards_to_pdf(input_ids, token_rewards, tokenizer, output_
136
  c.save()
137
 
138
 
139
- # Copied from transformers.models.llama.modeling_llama._get_unpad_data
140
  def _get_unpad_data(attention_mask):
141
  seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
142
 
@@ -148,6 +147,8 @@ def _get_unpad_data(attention_mask):
148
  else:
149
  max_seqlen_in_batch = seqlens_in_batch.max().item()
150
 
 
 
151
  cu_seqlens = torch.cat([torch.zeros(1, dtype=torch.int32, device=attention_mask.device), seqlens_in_batch.cumsum(dim=0)])
152
 
153
  return indices, cu_seqlens, max_seqlen_in_batch
 
136
  c.save()
137
 
138
 
 
139
  def _get_unpad_data(attention_mask):
140
  seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
141
 
 
147
  else:
148
  max_seqlen_in_batch = seqlens_in_batch.max().item()
149
 
150
+ # Ensure seqlens_in_batch has the correct shape before cumulative sum
151
+ seqlens_in_batch = seqlens_in_batch.view(-1)
152
  cu_seqlens = torch.cat([torch.zeros(1, dtype=torch.int32, device=attention_mask.device), seqlens_in_batch.cumsum(dim=0)])
153
 
154
  return indices, cu_seqlens, max_seqlen_in_batch