Crystalcareai committed
Commit da4fa77 · verified · 1 Parent(s): 5049df3

Update modeling_quiet.py

Files changed (1):
  1. modeling_quiet.py +11 -7
modeling_quiet.py CHANGED
@@ -139,14 +139,18 @@ def save_tokens_with_rewards_to_pdf(input_ids, token_rewards, tokenizer, output_
 # Copied from transformers.models.llama.modeling_llama._get_unpad_data
 def _get_unpad_data(attention_mask):
     seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
+
     indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
-    max_seqlen_in_batch = seqlens_in_batch.max().item()
-    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
-    return (
-        indices,
-        cu_seqlens,
-        max_seqlen_in_batch,
-    )
+
+    # Handle the case when seqlens_in_batch is empty
+    if seqlens_in_batch.numel() == 0:
+        max_seqlen_in_batch = 0
+    else:
+        max_seqlen_in_batch = seqlens_in_batch.max().item()
+
+    cu_seqlens = torch.cat([torch.zeros(1, dtype=torch.int32, device=attention_mask.device), seqlens_in_batch.cumsum(dim=0)])
+
+    return indices, cu_seqlens, max_seqlen_in_batch
 
 
 # Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Quiet
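
For context, a minimal standalone sketch (not part of the repo) of the edge case this commit guards against: calling .max() on an empty tensor raises a RuntimeError, which the new numel() check avoids, while the torch.cat construction of cu_seqlens still yields the leading zero that the downstream flash-attention varlen path expects.

    import torch

    # Standalone check, not repo code: an empty per-sequence length
    # tensor, e.g. from a batch with zero rows.
    seqlens_in_batch = torch.zeros(0, dtype=torch.int32)

    # Pre-patch behavior: .max() on an empty tensor raises.
    try:
        seqlens_in_batch.max().item()
    except RuntimeError as e:
        print(f"unpatched path would fail: {e}")

    # Patched behavior: fall back to 0 when the tensor is empty.
    max_seqlen_in_batch = 0 if seqlens_in_batch.numel() == 0 else seqlens_in_batch.max().item()
    print(max_seqlen_in_batch)  # 0

    # cu_seqlens keeps its required leading zero even in the empty case.
    cu_seqlens = torch.cat([torch.zeros(1, dtype=torch.int32), seqlens_in_batch.cumsum(dim=0)])
    print(cu_seqlens)  # tensor([0], dtype=torch.int32)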