"""Wrapper for performing DINOv2 inference.""" |
|
|
|
import cv2 |
|
import numpy as np |
|
from third_party.dinov2 import dino |
|
from omniglue import utils |
|
import tensorflow as tf |
|
import torch |
|
|
|
|
|
class DINOExtract: |
|
"""Class to initialize DINO model and extract features from an image.""" |

  def __init__(self, cpt_path: str, feature_layer: int = 1):
    self.feature_layer = feature_layer
    self.model = dino.vit_base()
    self.device = "cuda" if torch.cuda.is_available() else "cpu"
    state_dict_raw = torch.load(cpt_path, map_location="cpu")
    self.model.load_state_dict(state_dict_raw)
    self.model.eval().to(self.device)

    # Longer image side is capped at this many pixels before inference.
    self.image_size_max = 630

    # ViT patch size; both input dimensions must be divisible by these.
    self.h_down_rate = self.model.patch_embed.patch_size[0]
    self.w_down_rate = self.model.patch_embed.patch_size[1]

  def __call__(self, image: np.ndarray) -> np.ndarray:
    return self.forward(image)

  def forward(self, image: np.ndarray) -> np.ndarray:
    """Feeds image through DINO ViT model to extract features.

    Args:
      image: (H, W, 3) numpy array, decoded image bytes, value range [0, 255].

    Returns:
      features: (H // 14, W // 14, C) numpy array image features.
    """
    image = self._resize_input_image(image)
    image_processed = self._process_image(image)
    image_processed = image_processed.unsqueeze(0).float()
    features = self.extract_feature(image_processed)
    features = features.squeeze(0).permute(1, 2, 0).cpu().numpy()
    return features

  def _resize_input_image(
      self, image: np.ndarray, interpolation=cv2.INTER_LINEAR
  ):
    """Resizes image such that both dimensions are divisible by down_rate."""
    h_image, w_image = image.shape[:2]
    h_larger_flag = h_image > w_image
    large_side_image = max(h_image, w_image)

    # Resize so that the longer side is at most image_size_max, preserving
    # the aspect ratio.
    if large_side_image > self.image_size_max:
      if h_larger_flag:
        h_image_target = self.image_size_max
        w_image_target = int(self.image_size_max * w_image / h_image)
      else:
        w_image_target = self.image_size_max
        h_image_target = int(self.image_size_max * h_image / w_image)
    else:
      h_image_target = h_image
      w_image_target = w_image

    # Floor each dimension to the nearest multiple of the patch size.
    h, w = (
        h_image_target // self.h_down_rate,
        w_image_target // self.w_down_rate,
    )
    h_resize, w_resize = h * self.h_down_rate, w * self.w_down_rate
    image = cv2.resize(image, (w_resize, h_resize), interpolation=interpolation)
    return image
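
  # Shape example: an (800, 600) input has longer side 800 > 630, so the
  # aspect-preserving target is (630, 472); flooring each side to a multiple
  # of the 14-pixel patch size yields a final resize to (630, 462).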

  def _process_image(self, image: np.ndarray) -> torch.Tensor:
    """Turns image into a pytorch tensor and normalizes it."""
    # ImageNet normalization statistics, matching DINOv2 preprocessing.
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])

    image_processed = image / 255.0
    image_processed = (image_processed - mean) / std
    image_processed = torch.from_numpy(image_processed).permute(2, 0, 1)
    image_processed = image_processed.to(self.device)
    return image_processed

  def extract_feature(self, image):
    """Extracts features from image.

    Args:
      image: (B, 3, H, W) torch tensor, normalized with ImageNet mean/std.

    Returns:
      features: (B, C, H//14, W//14) torch tensor image features.
    """
    b, _, h_origin, w_origin = image.shape
    out = self.model.get_intermediate_layers(image, n=self.feature_layer)[0]
    h = int(h_origin / self.h_down_rate)
    w = int(w_origin / self.w_down_rate)
    dim = out.shape[-1]
    out = out.reshape(b, h, w, dim).permute(0, 3, 1, 2).detach()
    return out
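

# A minimal usage sketch for DINOExtract (the checkpoint and image paths are
# placeholders, not files shipped with this module):
#
#   extractor = DINOExtract("/path/to/dinov2_vitb14.pth", feature_layer=1)
#   image_rgb = cv2.cvtColor(cv2.imread("image.png"), cv2.COLOR_BGR2RGB)
#   features = extractor(image_rgb)  # (H // 14, W // 14, C) numpy array.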


def _preprocess_shape(
    h_image, w_image, image_size_max=630, h_down_rate=14, w_down_rate=14
):
  """TF analogue of DINOExtract._resize_input_image() for shape tensors.

  Computes the patch-aligned (height, width) that an image of the given shape
  is resized to, capping the longer side at image_size_max.
  """
  h_image = tf.squeeze(h_image)
  w_image = tf.squeeze(w_image)

  h_larger_flag = tf.greater(h_image, w_image)
  large_side_image = tf.maximum(h_image, w_image)

  def resize_h_larger():
    h_image_target = image_size_max
    w_image_target = tf.cast(image_size_max * w_image / h_image, tf.int32)
    return h_image_target, w_image_target

  def resize_w_larger_or_equal():
    w_image_target = image_size_max
    h_image_target = tf.cast(image_size_max * h_image / w_image, tf.int32)
    return h_image_target, w_image_target

  def keep_original():
    return h_image, w_image

  h_image_target, w_image_target = tf.cond(
      tf.greater(large_side_image, image_size_max),
      lambda: tf.cond(h_larger_flag, resize_h_larger, resize_w_larger_or_equal),
      keep_original,
  )

  # Floor both dimensions to multiples of the patch size.
  h = h_image_target // h_down_rate
  w = w_image_target // w_down_rate
  h_resize = h * h_down_rate
  w_resize = w * w_down_rate

  # Return shape-(1,) tensors.
  h_resize = tf.expand_dims(h_resize, 0)
  w_resize = tf.expand_dims(w_resize, 0)

  return h_resize, w_resize
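
# For example, tracing _preprocess_shape with h_image=800 and w_image=600
# yields h_resize == [630] and w_resize == [462], matching
# DINOExtract._resize_input_image() on an (800, 600) image.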


def get_dino_descriptors(dino_features, keypoints, height, width, feature_dim):
  """Gets DINO descriptors at SuperPoint keypoint locations.

  Args:
    dino_features: DINO features, flattened to 1-D.
    keypoints: SuperPoint keypoint locations, in format (x, y), in pixels,
      shape (N, 2).
    height: image height, int32 tf.Tensor.
    width: image width, int32 tf.Tensor.
    feature_dim: DINO feature channel size, int32 tf.Tensor.

  Returns:
    Interpolated DINO descriptors, (N, feature_dim) tf.float32 tensor.
  """
  height_1d = tf.reshape(height, [1])
  width_1d = tf.reshape(width, [1])

  height_1d_resized, width_1d_resized = _preprocess_shape(
      height_1d, width_1d, image_size_max=630, h_down_rate=14, w_down_rate=14
  )

  height_feat = height_1d_resized // 14
  width_feat = width_1d_resized // 14
  feature_dim_1d = tf.reshape(feature_dim, [1])

  # Un-flatten the feature map to (H // 14, W // 14, C).
  size_feature = tf.concat([height_feat, width_feat, feature_dim_1d], axis=0)
  dino_features = tf.reshape(dino_features, size_feature)

  # Scale keypoints from image pixel coordinates to feature map coordinates.
  img_size = tf.cast(tf.concat([width_1d, height_1d], axis=0), tf.float32)
  feature_size = tf.cast(
      tf.concat([width_feat, height_feat], axis=0), tf.float32
  )
  keypoints_feature = (
      keypoints
      / tf.expand_dims(img_size, axis=0)
      * tf.expand_dims(feature_size, axis=0)
  )

  # Bilinearly interpolate a descriptor per keypoint. This loop requires eager
  # execution, since the lookup operates on numpy arrays.
  dino_descriptors = []
  for kp in keypoints_feature:
    dino_descriptors.append(
        utils.lookup_descriptor_bilinear(kp.numpy(), dino_features.numpy())
    )
  dino_descriptors = tf.convert_to_tensor(
      np.array(dino_descriptors), dtype=tf.float32
  )
  return dino_descriptors
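
# Hypothetical call, assuming eager execution and a 1-D flattened DINO feature
# map `feats_flat` computed for an image of height H, width W, and C feature
# channels:
#
#   descriptors = get_dino_descriptors(
#       feats_flat,
#       tf.convert_to_tensor(keypoints_xy, dtype=tf.float32),  # (N, 2) pixels.
#       tf.convert_to_tensor(H, dtype=tf.int32),
#       tf.convert_to_tensor(W, dtype=tf.int32),
#       tf.convert_to_tensor(C, dtype=tf.int32),
#   )  # -> (N, C) tf.float32 descriptors.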