Update modeling_quiet.py
modeling_quiet.py  (CHANGED, +11 -7)
@@ -139,14 +139,18 @@ def save_tokens_with_rewards_to_pdf(input_ids, token_rewards, tokenizer, output_
 # Copied from transformers.models.llama.modeling_llama._get_unpad_data
 def _get_unpad_data(attention_mask):
     seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
+
     indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
-    max_seqlen_in_batch = seqlens_in_batch.max().item()
-    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
-    return (
-        indices,
-        cu_seqlens,
-        max_seqlen_in_batch,
-    )
+
+    # Handle the case when seqlens_in_batch is empty
+    if seqlens_in_batch.numel() == 0:
+        max_seqlen_in_batch = 0
+    else:
+        max_seqlen_in_batch = seqlens_in_batch.max().item()
+
+    cu_seqlens = torch.cat([torch.zeros(1, dtype=torch.int32, device=attention_mask.device), seqlens_in_batch.cumsum(dim=0)])
+
+    return indices, cu_seqlens, max_seqlen_in_batch
 
 
 # Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Quiet
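For context, a minimal standalone sketch of the patched helper and what it returns. It assumes only PyTorch; the function body simply mirrors the new code in the diff above, and the example masks and printed values are illustrative, not taken from the repository.

import torch

# Standalone copy of the patched helper (mirrors the diff above), for illustration only.
def _get_unpad_data(attention_mask):
    # Number of real (non-padding) tokens in each sequence of the batch.
    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)

    # Flat positions of all real tokens across the batch.
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()

    # Handle the case when seqlens_in_batch is empty
    if seqlens_in_batch.numel() == 0:
        max_seqlen_in_batch = 0
    else:
        max_seqlen_in_batch = seqlens_in_batch.max().item()

    # Cumulative sequence lengths, prefixed with 0.
    cu_seqlens = torch.cat([torch.zeros(1, dtype=torch.int32, device=attention_mask.device), seqlens_in_batch.cumsum(dim=0)])

    return indices, cu_seqlens, max_seqlen_in_batch

# Two sequences with 3 and 1 real tokens.
mask = torch.tensor([[1, 1, 1, 0], [1, 0, 0, 0]])
indices, cu_seqlens, max_len = _get_unpad_data(mask)
print(indices)     # positions [0, 1, 2, 4]
print(cu_seqlens)  # cumulative lengths [0, 3, 4]
print(max_len)     # 3

# A batch with no sequences used to fail on .max(); the new branch returns 0 instead.
empty_mask = torch.zeros((0, 4), dtype=torch.long)
print(_get_unpad_data(empty_mask)[2])  # 0

The numel() guard is what avoids the RuntimeError that torch.Tensor.max() raises on an empty tensor, which is the empty-seqlens case the added comment describes.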