shoveling42 commited on
Commit
72959d0
·
verified ·
1 Parent(s): 892ccfc

Edit 536 lines due to type mismatch error when evaluation

Browse files

The specific error message is as follows:
RuntimeError: Index put requires the source and destination dtypes match, got BFloat16 for the destination and Float for the source.

Files changed (1) hide show
  1. modeling_magma.py +2 -1
modeling_magma.py CHANGED
@@ -533,7 +533,8 @@ class MagmaForCausalLM(MagmaPreTrainedModel):
533
  f" the number of image given to the model is {num_images}. "
534
  f"This prevents correct indexing and breaks batch generation."
535
  )
536
- final_embedding[image_to_overwrite] = image_features.contiguous().reshape(-1, embed_dim).to(target_device)
 
537
  final_attention_mask |= image_to_overwrite
538
  position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)
539
 
 
533
  f" the number of image given to the model is {num_images}. "
534
  f"This prevents correct indexing and breaks batch generation."
535
  )
536
+ # final_embedding[image_to_overwrite] = image_features.contiguous().reshape(-1, embed_dim).to(target_device)
537
+ final_embedding[image_to_overwrite] = image_features.contiguous().reshape(-1, embed_dim).to(target_device).to(torch.bfloat16)
538
  final_attention_mask |= image_to_overwrite
539
  position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)
540