"""EVA-CLIP vision tower wrapper.

Wraps the EVA ViT-g encoder together with its BLIP-2 image processor so it
can be used as a vision backbone for a multimodal model.
"""

import torch
import torch.nn as nn

from .processor import Blip2ImageTrainProcessor
from .eva_vit import create_eva_vit_g


class EvaClipVisionTower(nn.Module):
    def __init__(self, vision_tower, args, delay_load=False):
        super().__init__()

        self.is_loaded = False

        self.vision_tower_name = vision_tower
        # self.select_layer = args.mm_vision_select_layer
        self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')
        self.args = args

        if not delay_load:
            self.load_model()
        # self.is_loaded = True

    def load_model(self, device_map=None):
        if self.is_loaded:
            print(f'{self.vision_tower_name} is already loaded; `load_model` called again, skipping.')
            return

        dynamic_resolution = getattr(self.args, 'dynamic_resolution', None)

        # Activation checkpointing is only needed when the encoder is trained;
        # a frozen encoder skips it, and a trainable one must run in fp32.
        if (not hasattr(self.args, 'freeze_vision_encoder')) or self.args.freeze_vision_encoder:
            use_checkpoint = False
        else:
            use_checkpoint = True
            assert self.args.vit_precision == 'fp32', \
                'training the vision encoder requires fp32 precision'

        self.image_processor = Blip2ImageTrainProcessor(
            image_size=self.args.img_size,
            dynamic_resolution=dynamic_resolution,
        )

        self.vision_tower = create_eva_vit_g(
            img_size=self.args.img_size,
            drop_path_rate=self.args.drop_path_rate,
            precision=self.args.vit_precision,
            vit_model_path=self.args.vit_model_path,
            use_checkpoint=use_checkpoint,
        )
        # self.vision_tower.requires_grad_(False)

        self.is_loaded = True

    def feature_select(self, image_features):
        # Token 0 is the CLS token; the rest are patch tokens.
        if self.select_feature == 'patch':
            image_features = image_features[:, 1:]
        elif self.select_feature == 'cls_patch':
            pass  # keep the CLS token along with the patch tokens
        else:
            raise ValueError(f'Unexpected select feature: {self.select_feature}')
        return image_features

    # @torch.no_grad()
    def forward(self, images):
        if isinstance(images, list):
            # Variable-sized inputs: encode one image at a time.
            image_features = []
            for image in images:
                # Cast to the tower's dtype to match the batched branch below.
                image_forward_out = self.vision_tower(image.to(dtype=self.dtype).unsqueeze(0))
                image_features.append(self.feature_select(image_forward_out).to(image.dtype))
        else:
            image_features = self.vision_tower(images.to(dtype=self.dtype))
            image_features = self.feature_select(image_features).to(images.dtype)

        return image_features

    @property
    def dummy_feature(self):
        return torch.zeros(1, self.hidden_size, dtype=torch.float)

    @property
    def hidden_size(self):
        return self.vision_tower.hidden_size

    @property
    def num_patches(self):
        return (self.vision_tower.image_size // self.vision_tower.patch_size) ** 2

    @property
    def num_patches_per_side(self):
        return self.vision_tower.image_size // self.vision_tower.patch_size

    @property
    def dtype(self):
        return self.vision_tower.pos_embed.dtype
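

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module): builds the tower from an
# argparse-style namespace and runs a dummy batch through it. All config
# values below are hypothetical, and `vit_model_path` must point to real EVA
# ViT-g weights for this to run. Because of the relative imports above, run
# this as a module (`python -m <package>.<this_module>`), not as a script.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    from types import SimpleNamespace

    args = SimpleNamespace(
        mm_vision_select_feature='patch',          # drop the CLS token
        img_size=224,                              # hypothetical input resolution
        drop_path_rate=0.0,
        vit_precision='fp16',
        vit_model_path='/path/to/eva_vit_g.pth',   # placeholder checkpoint path
        freeze_vision_encoder=True,                # frozen tower -> no checkpointing
    )

    tower = EvaClipVisionTower('eva-clip-vit-g', args)
    images = torch.randn(2, 3, args.img_size, args.img_size)
    feats = tower(images)  # expected shape: (2, num_patches, hidden_size)
    print(feats.shape, tower.num_patches, tower.hidden_size)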