update
modeling_magma.py: +4 -4
@@ -680,7 +680,7 @@ class MagmaForForCausalLM(MagmaPreTrainedModel):
     pixel_values_for_image = pixel_values_for_image.permute(2, 0, 3, 1, 4).flatten(3, 4).flatten(1, 2).unsqueeze(0)
     image_features = self.vision_tower(pixel_values_for_image)
     selected_image_feature = image_features[vision_feature_layer][0].permute(1, 2, 0)
-    selected_image_feature = self.multi_modal_projector(
+    selected_image_feature = self.multi_modal_projector(selected_image_feature)
     selected_image_feature = torch.cat((selected_image_feature, self.multi_modal_projector.row_seperator.repeat(selected_image_feature.shape[0],1,1)), dim=1)
     selected_image_features.append(selected_image_feature.flatten(0, 1))
 elif self.config.vision_config['img_anyres_strategy'] == "crop":
@@ -690,7 +690,7 @@ class MagmaForForCausalLM(MagmaPreTrainedModel):
 _pixel_values_list_temp = sum(_pixel_values_list, ())
 _pixel_values_list_temp = torch.cat(_pixel_values_list_temp, dim=0)
 image_features = self.vision_tower(_pixel_values_list_temp)[vision_feature_layer].permute(0, 2, 3, 1)
-image_features = self.multi_modal_projector(
+image_features = self.multi_modal_projector(image_features)

 num_crops_list = [_image_size[0]*_image_size[1] for _image_size in _image_sizes_list_temp]
 image_features_split = torch.split(image_features, num_crops_list, dim=0)
@@ -1281,12 +1281,12 @@ class MagmaForConditionalGeneration(MagmaPreTrainedModel):
     pixel_values_for_image = pixel_values_for_image.permute(2, 0, 3, 1, 4).flatten(3, 4).flatten(1, 2).unsqueeze(0)
     image_features = self.vision_tower(pixel_values_for_image)
     selected_image_feature = image_features[vision_feature_layer][0].permute(1, 2, 0)
-    selected_image_feature = self.multi_modal_projector(
+    selected_image_feature = self.multi_modal_projector(selected_image_feature)
     selected_image_feature = torch.cat((selected_image_feature, self.multi_modal_projector.row_seperator.repeat(selected_image_feature.shape[0],1,1)), dim=1)
     selected_image_features.append(selected_image_feature)
 elif self.config.vision_config['img_anyres_strategy'] == "crop":
     image_features = self.vision_tower(pixel_values)[vision_feature_layer].permute(0, 2, 3, 1)
-    image_features = self.multi_modal_projector(
+    image_features = self.multi_modal_projector(image_features)
     num_patches_for_images = [(imsize[0]*imsize[1]).item() for imsize in image_sizes]
     image_features_split = torch.split(image_features, num_patches_for_images, dim=0)
     selected_image_features = []
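In the anyres branches of both classes, the updated lines call `self.multi_modal_projector` with `selected_image_feature` passed explicitly before the `row_seperator` concatenation. A minimal runnable sketch of that pattern, with made-up dimensions and a plain `nn.Linear` standing in for Magma's actual projector module:

```python
import torch
import torch.nn as nn

# Hypothetical sizes; Magma's real multi_modal_projector also owns the learned
# row_seperator parameter (spelling as in the source), shaped (1, 1, out_dim).
H, W, C, D = 3, 5, 8, 4
projector = nn.Linear(C, D)                 # stand-in for multi_modal_projector
row_seperator = nn.Parameter(torch.zeros(1, 1, D))

feat = torch.randn(H, W, C)                 # vision-tower map after permute(1, 2, 0)
feat = projector(feat)                      # the updated call: input passed explicitly
# append one separator token at the end of each of the H rows
feat = torch.cat((feat, row_seperator.repeat(feat.shape[0], 1, 1)), dim=1)  # (H, W+1, D)
tokens = feat.flatten(0, 1)                 # (H*(W+1), D) token sequence
print(tokens.shape)                         # torch.Size([18, 4])
```

The separator marks row boundaries, so the flattened token sequence keeps the 2D grid layout recoverable downstream.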
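The crop branches get the analogous fix: the stacked per-crop feature maps are projected first, then split back per image by crop count. Again a sketch with hypothetical shapes and an `nn.Linear` stand-in:

```python
import torch
import torch.nn as nn

# Six crops belonging to two images (2 + 4), stacked on dim 0 as the vision
# tower returns them after permute(0, 2, 3, 1); channel sizes are made up.
C, D = 8, 4
projector = nn.Linear(C, D)                  # stand-in for multi_modal_projector
image_features = torch.randn(6, 2, 2, C)     # (num_crops, h, w, C)
image_features = projector(image_features)   # the updated call: input passed explicitly
num_crops_list = [2, 4]                      # crops per image, e.g. rows*cols
image_features_split = torch.split(image_features, num_crops_list, dim=0)
print([t.shape for t in image_features_split])
# [torch.Size([2, 2, 2, 4]), torch.Size([4, 2, 2, 4])]
```

With the linear stand-in, projecting before the split is equivalent to projecting each image's crops separately, since the projection acts on the channel dimension only and leaves the per-image grouping on dim 0 untouched.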