File size: 7,295 Bytes
cff8c58
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0f706be
cff8c58
 
 
 
 
 
 
0f706be
cff8c58
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0f706be
cff8c58
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Wrapper for performing DINOv2 inference."""

import cv2
import numpy as np
from third_party.dinov2 import dino
from omniglue import utils
import tensorflow as tf
import torch


class DINOExtract:
  """Initializes a DINO ViT model and extracts dense patch features from images.

  Attributes:
    feature_layer: forwarded as `n` to `model.get_intermediate_layers`; the
      first element of the returned sequence is used as the feature map.
    model: the `dino.vit_base()` backbone, in eval mode on `device`.
    device: "cuda" when available, otherwise "cpu".
    image_size_max: cap applied to the longer image side before patch
      alignment, to bound ViT cost.
    h_down_rate: model patch size along height.
    w_down_rate: model patch size along width.
  """

  def __init__(self, cpt_path: str, feature_layer: int = 1):
    """Builds the ViT-B backbone and loads weights from `cpt_path`.

    Args:
      cpt_path: path to a checkpoint whose state dict matches
        `dino.vit_base()`.
      feature_layer: forwarded as `n` to `model.get_intermediate_layers`.
    """
    self.feature_layer = feature_layer
    self.model = dino.vit_base()
    self.device = "cuda" if torch.cuda.is_available() else "cpu"
    # NOTE(review): torch.load unpickles the file; only load trusted
    # checkpoints.
    state_dict_raw = torch.load(cpt_path, map_location='cpu')
    self.model.load_state_dict(state_dict_raw)
    self.model.eval().to(self.device)

    self.image_size_max = 630

    self.h_down_rate = self.model.patch_embed.patch_size[0]
    self.w_down_rate = self.model.patch_embed.patch_size[1]

  def __call__(self, image: np.ndarray) -> np.ndarray:
    """Alias for `forward`."""
    return self.forward(image)

  def forward(self, image: np.ndarray) -> np.ndarray:
    """Feeds image through DINO ViT model to extract features.

    Args:
      image: (H, W, 3) numpy array, decoded image bytes, value range [0, 255].

    Returns:
      features: (H' // h_down_rate, W' // w_down_rate, C) numpy array, where
        (H', W') is the image size after `_resize_input_image`.
    """
    image = self._resize_input_image(image)
    image_processed = self._process_image(image)
    image_processed = image_processed.unsqueeze(0).float()
    features = self.extract_feature(image_processed)
    features = features.squeeze(0).permute(1, 2, 0).cpu().numpy()
    return features

  def _resize_input_image(
      self, image: np.ndarray, interpolation=cv2.INTER_LINEAR
  ):
    """Resizes image such that both dimensions are divisible by down_rate.

    If the longer side exceeds `image_size_max`, it is set to
    `image_size_max` and the other side is scaled proportionally (truncated).
    Each side is then rounded down to a multiple of its down-rate.
    """
    h_image, w_image = image.shape[:2]
    h_larger_flag = h_image > w_image
    large_side_image = max(h_image, w_image)

    # resize the image with the largest side length smaller than a threshold
    # to accelerate ViT backbone inference (which has quadratic complexity).
    if large_side_image > self.image_size_max:
      if h_larger_flag:
        h_image_target = self.image_size_max
        w_image_target = int(self.image_size_max * w_image / h_image)
      else:
        w_image_target = self.image_size_max
        h_image_target = int(self.image_size_max * h_image / w_image)
    else:
      h_image_target = h_image
      w_image_target = w_image

    h, w = (
        h_image_target // self.h_down_rate,
        w_image_target // self.w_down_rate,
    )
    h_resize, w_resize = h * self.h_down_rate, w * self.w_down_rate
    image = cv2.resize(image, (w_resize, h_resize), interpolation=interpolation)
    return image

  def _process_image(self, image: np.ndarray) -> torch.Tensor:
    """Turns an (H, W, 3) image into a normalized (3, H, W) tensor on device.

    Pixel values are scaled to [0, 1] and normalized with the ImageNet
    channel mean/std below.
    """
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])

    image_processed = image / 255.0
    image_processed = (image_processed - mean) / std
    image_processed = torch.from_numpy(image_processed).permute(2, 0, 1)
    image_processed = image_processed.to(self.device)
    return image_processed

  def extract_feature(self, image):
    """Extracts features from image.

    Args:
      image: (B, 3, H, W) torch tensor, normalized with ImageNet mean/std.
        H and W must be multiples of the patch size.

    Returns:
      features: (B, C, H // h_down_rate, W // w_down_rate) torch tensor,
        detached from any autograd graph.
    """
    b, _, h_origin, w_origin = image.shape
    # Inference only: disable autograd so no graph / saved activations are
    # kept for the ViT forward pass (the result is detached anyway).
    with torch.no_grad():
      out = self.model.get_intermediate_layers(image, n=self.feature_layer)[0]
    h = h_origin // self.h_down_rate
    w = w_origin // self.w_down_rate
    dim = out.shape[-1]
    # Assumes `out` holds one token per patch (h * w tokens) -- TODO confirm
    # against third_party.dinov2.get_intermediate_layers defaults.
    out = out.reshape(b, h, w, dim).permute(0, 3, 1, 2).detach()
    return out


def _preprocess_shape(
    h_image, w_image, image_size_max=630, h_down_rate=14, w_down_rate=14
):
  """Computes the patch-aligned size an image is resized to, in TensorFlow.

  Applies the same rule as `DINOExtract._resize_input_image`. If the longer
  side exceeds `image_size_max`, that side becomes `image_size_max` and the
  other side is scaled proportionally (truncated to an integer). Each side is
  then rounded down to a multiple of its down-rate.

  Args:
    h_image: image height as a single-element tensor (squeezed to a scalar).
    w_image: image width as a single-element tensor (squeezed to a scalar).
    image_size_max: cap on the longer image side.
    h_down_rate: the resized height is a multiple of this.
    w_down_rate: the resized width is a multiple of this.

  Returns:
    (h_resize, w_resize), each the scalar result expanded along axis 0.
  """
  h_scalar = tf.squeeze(h_image)
  w_scalar = tf.squeeze(w_image)

  def _cap_height():
    # Height is the long side: pin it to the cap, scale width to match.
    w_target = tf.cast(image_size_max * w_scalar / h_scalar, tf.int32)
    return image_size_max, w_target

  def _cap_width():
    # Width is the long side (or a tie): pin it, scale height to match.
    h_target = tf.cast(image_size_max * h_scalar / w_scalar, tf.int32)
    return h_target, image_size_max

  def _downscale():
    return tf.cond(tf.greater(h_scalar, w_scalar), _cap_height, _cap_width)

  def _identity():
    return h_scalar, w_scalar

  exceeds_cap = tf.greater(tf.maximum(h_scalar, w_scalar), image_size_max)
  h_target, w_target = tf.cond(exceeds_cap, _downscale, _identity)

  # Round each side down to a whole number of patches.
  h_resize = (h_target // h_down_rate) * h_down_rate
  w_resize = (w_target // w_down_rate) * w_down_rate
  return tf.expand_dims(h_resize, 0), tf.expand_dims(w_resize, 0)


def get_dino_descriptors(
    dino_features,
    keypoints,
    height,
    width,
    feature_dim,
    patch_size=14,
    image_size_max=630,
):
  """Get DINO descriptors using Superpoint keypoints.

  Steps:
    1. Reshape the flat DINO features to the patch grid that
       `_preprocess_shape` yields for an image of size (height, width).
    2. Rescale keypoints from image pixels to feature-grid coordinates.
    3. Look up one descriptor per keypoint via
       `utils.lookup_descriptor_bilinear`.

  Args:
    dino_features: DINO features in 1-D.
    keypoints: Superpoint keypoint locations, in format (x, y), in pixels, shape
      (N, 2).
    height: image height, type tf.Tensor.int32.
    width: image width, type tf.Tensor.int32.
    feature_dim: DINO feature channel size, type tf.Tensor.int32.
    patch_size: DINO ViT patch size. It is used as the resize down-rate and as
      the feature-grid stride. Defaults to 14.
    image_size_max: cap on the longer image side applied before patch
      alignment. Must match the value used when the features were extracted.
      Defaults to 630.

  Returns:
    Interpolated DINO descriptors as a float32 tf.Tensor, stacked along the
    first axis (one entry per keypoint).
  """
  height_1d = tf.reshape(height, [1])
  width_1d = tf.reshape(width, [1])

  height_1d_resized, width_1d_resized = _preprocess_shape(
      height_1d,
      width_1d,
      image_size_max=image_size_max,
      h_down_rate=patch_size,
      w_down_rate=patch_size,
  )

  # Feature-grid size: one feature vector per patch.
  height_feat = height_1d_resized // patch_size
  width_feat = width_1d_resized // patch_size
  feature_dim_1d = tf.reshape(feature_dim, [1])

  size_feature = tf.concat([height_feat, width_feat, feature_dim_1d], axis=0)
  dino_features = tf.reshape(dino_features, size_feature)

  # Map (x, y) pixel coordinates of the original image onto the feature grid.
  img_size = tf.cast(tf.concat([width_1d, height_1d], axis=0), tf.float32)
  feature_size = tf.cast(
      tf.concat([width_feat, height_feat], axis=0), tf.float32
  )

  keypoints_feature = (
      keypoints
      / tf.expand_dims(img_size, axis=0)
      * tf.expand_dims(feature_size, axis=0)
  )

  # Convert to NumPy once (loop-invariant) rather than once per keypoint.
  dino_features_np = dino_features.numpy()
  dino_descriptors = [
      utils.lookup_descriptor_bilinear(kp, dino_features_np)
      for kp in keypoints_feature.numpy()
  ]
  return tf.convert_to_tensor(np.array(dino_descriptors), dtype=tf.float32)