Update modeling_quiet.py
Browse files
- modeling_quiet.py +2 -1
modeling_quiet.py
CHANGED
@@ -136,7 +136,6 @@ def save_tokens_with_rewards_to_pdf(input_ids, token_rewards, tokenizer, output_
|
|
136 |
c.save()
|
137 |
|
138 |
|
139 |
-
# Copied from transformers.models.llama.modeling_llama._get_unpad_data
|
140 |
def _get_unpad_data(attention_mask):
|
141 |
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
|
142 |
|
@@ -148,6 +147,8 @@ def _get_unpad_data(attention_mask):
|
|
148 |
else:
|
149 |
max_seqlen_in_batch = seqlens_in_batch.max().item()
|
150 |
|
|
|
|
|
151 |
cu_seqlens = torch.cat([torch.zeros(1, dtype=torch.int32, device=attention_mask.device), seqlens_in_batch.cumsum(dim=0)])
|
152 |
|
153 |
return indices, cu_seqlens, max_seqlen_in_batch
|
|
|
136 |
c.save()
|
137 |
|
138 |
|
|
|
139 |
def _get_unpad_data(attention_mask):
|
140 |
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
|
141 |
|
|
|
147 |
else:
|
148 |
max_seqlen_in_batch = seqlens_in_batch.max().item()
|
149 |
|
150 |
+
# Ensure seqlens_in_batch has the correct shape before cumulative sum
|
151 |
+
seqlens_in_batch = seqlens_in_batch.view(-1)
|
152 |
cu_seqlens = torch.cat([torch.zeros(1, dtype=torch.int32, device=attention_mask.device), seqlens_in_batch.cumsum(dim=0)])
|
153 |
|
154 |
return indices, cu_seqlens, max_seqlen_in_batch
|