#    Copyright 2024 Xi Zhang
#
#    Licensed under the Apache License, Version 2.0 (the "License");
#    you may not use this file except in compliance with the License.
#    You may obtain a copy of the License at
#
#        http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS,
#    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#    See the License for the specific language governing permissions and
#    limitations under the License.
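
"""Vision tower wrapping a Hugging Face DINO/DINOv2 backbone.

Given a stacked pair of image batches (current and previous frames), the tower
returns hidden-state features selected by layer (`mm_vision_select_layer`) and
token type (`mm_vision_select_feature`: 'patch' drops the CLS token,
'cls_patch' keeps it).
"""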

import torch
import torch.nn as nn

from transformers import AutoImageProcessor, AutoModel, AutoConfig


class DINOVisionTower(nn.Module):
    def __init__(self, vision_tower, args, delay_load=False):
        super().__init__()

        self.is_loaded = False

        self.vision_tower_name = vision_tower
        # Layer(s) to take features from: "all" or an index such as -2.
        self.select_layer = args.mm_vision_select_layer
        # 'patch' drops the CLS token; 'cls_patch' keeps it.
        self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')

        if not delay_load:
            self.load_model()
        elif getattr(args, 'unfreeze_mm_vision_tower', False):
            self.load_model()
        else:
            self.cfg_only = AutoConfig.from_pretrained(self.vision_tower_name)

    def load_model(self):
        if self.is_loaded:
            print(f'{self.vision_tower_name} is already loaded; `load_model` called again, skipping.')
            return
        
        self.image_processor = AutoImageProcessor.from_pretrained(self.vision_tower_name)
        self.vision_tower = AutoModel.from_pretrained(self.vision_tower_name)
        self.vision_tower.requires_grad_(False)

        self.is_loaded = True
    
    def get_features(self, images):
        # hidden_states is a tuple of (num_hidden_layers + 1) tensors of shape
        # (B, 1 + num_patches, hidden_size); index 0 is the embedding output
        # and token 0 of each tensor is the CLS token.
        outputs = self.vision_tower(images, output_hidden_states=True)
        hidden_states = outputs.hidden_states
        
        if self.select_layer == "all":
            if self.select_feature == "patch":
                # Drop the CLS token from every transformer layer output.
                all_layers_features = [hidden_state[:, 1:, :].contiguous() for hidden_state in hidden_states[1:]]
            elif self.select_feature == "cls_patch":
                all_layers_features = [hidden_state.contiguous() for hidden_state in hidden_states[1:]]
            else:
                raise ValueError(f"Unexpected select feature: {self.select_feature}")

            # (num_layers, B, tokens, hidden_size)
            return torch.stack(all_layers_features)
        else:
            selected_layer_features = hidden_states[int(self.select_layer)]

            if self.select_feature == "patch":
                selected_layer_features = selected_layer_features[:, 1:]
            elif self.select_feature == "cls_patch":
                pass  # keep both the CLS and patch tokens
            else:
                raise ValueError(f"Unexpected select feature: {self.select_feature}")

            # (1, B, tokens, hidden_size), matching the layout of the "all" branch
            return selected_layer_features.unsqueeze(0)
    
    @torch.no_grad()
    def forward(self, images):
        # images: (2, B, C, H, W) -- index 0 holds the current frames,
        # index 1 the previous frames.
        if images.shape[0] != 2:
            raise ValueError(
                f"Expected images.shape[0] == 2, but got {images.shape}. "
                "Ensure the input includes both current and previous images."
            )

        cur_images = images[0]
        prev_images = images[1]

        cur_features = self.get_features(cur_images)
        prev_features = self.get_features(prev_images)

        # (num_layers, B, tokens, D) -> (B, num_layers, tokens, D)
        cur_features = cur_features.permute(1, 0, 2, 3)
        prev_features = prev_features.permute(1, 0, 2, 3)

        # Stack current and previous features along a new leading dimension:
        # (2, B, num_layers, tokens, D)
        images_features = torch.stack([cur_features, prev_features])
        
        return images_features

    @property
    def dummy_feature(self):
        return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)

    @property
    def dtype(self):
        return self.vision_tower.dtype

    @property
    def device(self):
        return self.vision_tower.device 

    @property
    def config(self):
        if self.is_loaded:
            return self.vision_tower.config
        else:
            return self.cfg_only

    @property
    def hidden_size(self):
        return self.config.hidden_size 
    
    @property
    def num_patches(self):
        # Number of patch tokens at the configured image size (CLS excluded).
        return (self.config.image_size // self.config.patch_size) ** 2
    
    @property
    def num_layers(self):
        return self.config.num_hidden_layers
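

if __name__ == "__main__":
    # Minimal smoke-test sketch. The checkpoint name and argument values are
    # illustrative assumptions (any DINOv2 checkpoint served by AutoModel /
    # AutoImageProcessor should work); adjust them to your setup.
    from types import SimpleNamespace

    args = SimpleNamespace(mm_vision_select_layer=-2, mm_vision_select_feature='patch')
    tower = DINOVisionTower('facebook/dinov2-base', args)

    # Two batches (current and previous), batch size 1, 224x224 RGB images.
    images = torch.randn(2, 1, 3, 224, 224, dtype=tower.dtype)
    features = tower(images)
    print(features.shape)  # (2, B, num_layers, tokens, hidden) -> (2, 1, 1, 256, 768)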