Edit 536 lines due to type mismatch error when evaluation
Browse filesThe specific error message is as follows:
RuntimeError: Index put requires the source and destination dtypes match, got BFloat16 for the destination and Float for the source.
- modeling_magma.py +2 -1
modeling_magma.py
CHANGED
@@ -533,7 +533,8 @@ class MagmaForCausalLM(MagmaPreTrainedModel):
|
|
533 |
f" the number of image given to the model is {num_images}. "
|
534 |
f"This prevents correct indexing and breaks batch generation."
|
535 |
)
|
536 |
-
final_embedding[image_to_overwrite] = image_features.contiguous().reshape(-1, embed_dim).to(target_device)
|
|
|
537 |
final_attention_mask |= image_to_overwrite
|
538 |
position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)
|
539 |
|
|
|
533 |
f" the number of image given to the model is {num_images}. "
|
534 |
f"This prevents correct indexing and breaks batch generation."
|
535 |
)
|
536 |
+
# final_embedding[image_to_overwrite] = image_features.contiguous().reshape(-1, embed_dim).to(target_device)
|
537 |
+
final_embedding[image_to_overwrite] = image_features.contiguous().reshape(-1, embed_dim).to(target_device).to(torch.bfloat16)
|
538 |
final_attention_mask |= image_to_overwrite
|
539 |
position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)
|
540 |
|