jw2yang committed
Commit 0583f08 · 1 Parent(s): 3fe01ed
Files changed (1):
  modeling_magma.py +4 -4
modeling_magma.py CHANGED

@@ -680,7 +680,7 @@ class MagmaForForCausalLM(MagmaPreTrainedModel):
                     pixel_values_for_image = pixel_values_for_image.permute(2, 0, 3, 1, 4).flatten(3, 4).flatten(1, 2).unsqueeze(0)
                     image_features = self.vision_tower(pixel_values_for_image)
                     selected_image_feature = image_features[vision_feature_layer][0].permute(1, 2, 0)
-                    selected_image_feature = self.multi_modal_projector((selected_image_feature, None))
+                    selected_image_feature = self.multi_modal_projector(selected_image_feature)
                     selected_image_feature = torch.cat((selected_image_feature, self.multi_modal_projector.row_seperator.repeat(selected_image_feature.shape[0],1,1)), dim=1)
                     selected_image_features.append(selected_image_feature.flatten(0, 1))
                 elif self.config.vision_config['img_anyres_strategy'] == "crop":
@@ -690,7 +690,7 @@ class MagmaForForCausalLM(MagmaPreTrainedModel):
                     _pixel_values_list_temp = sum(_pixel_values_list, ())
                     _pixel_values_list_temp = torch.cat(_pixel_values_list_temp, dim=0)
                     image_features = self.vision_tower(_pixel_values_list_temp)[vision_feature_layer].permute(0, 2, 3, 1)
-                    image_features = self.multi_modal_projector((image_features, None))
+                    image_features = self.multi_modal_projector(image_features)
 
                     num_crops_list = [_image_size[0]*_image_size[1] for _image_size in _image_sizes_list_temp]
                     image_features_split = torch.split(image_features, num_crops_list, dim=0)
@@ -1281,12 +1281,12 @@ class MagmaForConditionalGeneration(MagmaPreTrainedModel):
                     pixel_values_for_image = pixel_values_for_image.permute(2, 0, 3, 1, 4).flatten(3, 4).flatten(1, 2).unsqueeze(0)
                     image_features = self.vision_tower(pixel_values_for_image)
                     selected_image_feature = image_features[vision_feature_layer][0].permute(1, 2, 0)
-                    selected_image_feature = self.multi_modal_projector((selected_image_feature, None))
+                    selected_image_feature = self.multi_modal_projector(selected_image_feature)
                     selected_image_feature = torch.cat((selected_image_feature, self.multi_modal_projector.row_seperator.repeat(selected_image_feature.shape[0],1,1)), dim=1)
                     selected_image_features.append(selected_image_feature)
                 elif self.config.vision_config['img_anyres_strategy'] == "crop":
                     image_features = self.vision_tower(pixel_values)[vision_feature_layer].permute(0, 2, 3, 1)
-                    image_features = self.multi_modal_projector((image_features, None))
+                    image_features = self.multi_modal_projector(image_features)
                     num_patches_for_images = [(imsize[0]*imsize[1]).item() for imsize in image_sizes]
                     image_features_split = torch.split(image_features, num_patches_for_images, dim=0)
                     selected_image_features = []
 
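All four changed lines make the same fix: `multi_modal_projector` is now called with the feature tensor itself rather than a `(features, None)` tuple, in both `MagmaForForCausalLM` and `MagmaForConditionalGeneration`. The projector class is not part of this diff, so the following is only a minimal sketch assuming a projector whose `forward` takes a plain tensor; the `ToyProjector` name, its `nn.Linear` body, and all dimensions are illustrative assumptions, not Magma's actual implementation.

```python
import torch
import torch.nn as nn

class ToyProjector(nn.Module):
    """Stand-in for the (unshown) multi_modal_projector; purely illustrative."""
    def __init__(self, in_dim: int, out_dim: int):
        super().__init__()
        self.linear = nn.Linear(in_dim, out_dim)
        # The diff's call sites reference a learned `row_seperator` (sic) that
        # gets appended along dim=1 after projection.
        self.row_seperator = nn.Parameter(torch.zeros(1, 1, out_dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Expects a plain tensor. The old call sites passed (x, None), which
        # a tensor-only forward like this one would reject.
        return self.linear(x)

proj = ToyProjector(16, 32)
feats = torch.randn(4, 7, 16)          # (rows, cols, channels); shapes assumed

# Old pattern from the '-' lines (fails against a tensor-only forward):
# proj((feats, None))                  # TypeError inside nn.Linear

# New pattern from the '+' lines:
out = proj(feats)                      # -> (4, 7, 32)
out = torch.cat((out, proj.row_seperator.repeat(out.shape[0], 1, 1)), dim=1)
print(out.shape)                       # torch.Size([4, 8, 32])
```

If Magma's real projector still unpacks a tuple internally, the net effect of the commit is simply to align the call sites with whichever signature the projector currently exposes; the sketch above assumes the tensor-only variant.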