# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Wrapper for performing SuperPoint inference."""

import math
from typing import Optional, Tuple

import cv2
import numpy as np
from omniglue import utils
import tensorflow.compat.v1 as tf1


class SuperPointExtract:
  """Class to initialize SuperPoint model and extract features from an image.

  To stay consistent with SuperPoint training and eval configurations, resize
  images to (320x240) or (640x480).

  Attributes:
    model_path: string, filepath to saved SuperPoint TF1 model weights.
  """

  def __init__(self, model_path: str):
    self.model_path = model_path
    # Load the SavedModel into its own graph and session rather than the
    # default graph.
    self._graph = tf1.Graph()
    self._sess = tf1.Session(graph=self._graph)
    tf1.saved_model.loader.load(
        self._sess, [tf1.saved_model.tag_constants.SERVING], model_path
    )

  def __call__(
      self,
      image,
      segmentation_mask=None,
      num_features=1024,
      pad_random_features=False,
  ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Alias for `compute`; see that method for argument documentation."""
    return self.compute(
        image,
        segmentation_mask=segmentation_mask,
        num_features=num_features,
        pad_random_features=pad_random_features,
    )

  def compute(
      self,
      image: np.ndarray,
      segmentation_mask: Optional[np.ndarray] = None,
      num_features: int = 1024,
      pad_random_features: bool = False,
  ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Feeds image through SuperPoint model to extract keypoints and features.

    Args:
      image: (H, W, 3) numpy array, decoded image bytes. Converted to grayscale
        with cv2.COLOR_RGB2GRAY, so RGB channel order is expected.
      segmentation_mask: (H, W) binary numpy array or None. If not None,
        extracted keypoints are restricted to being within the mask.
      num_features: max number of features to extract (or 0 to indicate
        keeping all extracted features).
      pad_random_features: if True, adds randomly sampled keypoints to the
        output such that there are exactly 'num_features' keypoints.
        Descriptors for these sampled keypoints are taken from the network's
        descriptor map output, and scores are set to 0. No action taken if
        num_features = 0.

    Returns:
      keypoints: (N, 2) numpy array, (x, y) coordinates of keypoints in the
        original (un-resized) input image, as floats.
      descriptors: (N, 256) numpy array, descriptors for keypoints as floats.
      scores: (N,) numpy array, confidence values for keypoints as floats.

    Raises:
      ValueError: if either spatial dimension of `image` is smaller than 2.
    """
    # Resize image so both dimensions are divisible by 8.
    image, keypoint_scale_factors = self._resize_input_image(image)
    if segmentation_mask is not None:
      # Nearest-neighbour interpolation keeps the mask values unblended.
      segmentation_mask, _ = self._resize_input_image(
          segmentation_mask, interpolation=cv2.INTER_NEAREST
      )
    assert (
        segmentation_mask is None
        or image.shape[:2] == segmentation_mask.shape[:2]
    ), 'segmentation_mask must have the same height and width as image.'

    # Preprocess and feed-forward image.
    image_preprocessed = self._preprocess_image(image)
    input_image_tensor = self._graph.get_tensor_by_name('superpoint/image:0')
    output_prob_nms_tensor = self._graph.get_tensor_by_name(
        'superpoint/prob_nms:0'
    )
    output_desc_tensors = self._graph.get_tensor_by_name(
        'superpoint/descriptors:0'
    )
    out = self._sess.run(
        [output_prob_nms_tensor, output_desc_tensors],
        feed_dict={input_image_tensor: np.expand_dims(image_preprocessed, 0)},
    )

    # Format output from network: squeeze removes size-1 axes, including the
    # batch axis added by np.expand_dims above.
    keypoint_map = np.squeeze(out[0])
    descriptor_map = np.squeeze(out[1])
    if segmentation_mask is not None:
      keypoint_map = np.where(segmentation_mask, keypoint_map, 0.0)
    keypoints, descriptors, scores = self._extract_superpoint_output(
        keypoint_map, descriptor_map, num_features, pad_random_features
    )

    # Rescale keypoint locations to match original input image size, and
    # return.
    keypoints = keypoints / keypoint_scale_factors
    return (keypoints, descriptors, scores)

  def _resize_input_image(self, image, interpolation=cv2.INTER_LINEAR):
    """Resizes image such that both dimensions are divisible by 8.

    Each spatial dimension is rounded to the nearest multiple of 8 (a
    remainder of 4 or more rounds up), with a minimum of 8 so that very small
    inputs still get a valid, non-empty resize target.

    Args:
      image: numpy array whose first two axes are (rows, cols).
      interpolation: OpenCV interpolation flag forwarded to cv2.resize.

    Returns:
      image: the resized image.
      keypoint_scale_factors: [x_scale, y_scale], each (new - 1) / (old - 1),
        i.e. the factor mapping original pixel coordinates onto the resized
        image (first and last pixel centres aligned).

    Raises:
      ValueError: if either spatial dimension is smaller than 2 pixels (the
        corner-aligned scale factor is undefined for a 1-pixel dimension).
    """
    # Calculate new image dimensions and per-dimension resizing scale factor.
    new_dim = [-1, -1]
    keypoint_scale_factors = [1.0, 1.0]
    for i in range(2):
      dim_size = image.shape[i]
      if dim_size < 2:
        raise ValueError(
            f'Image dimensions must be at least 2 pixels, got {image.shape[:2]}.'
        )
      mod_eight = dim_size % 8
      if mod_eight < 4:
        # Round down to nearest multiple of 8.
        rounded = dim_size - mod_eight
      else:
        # Round up to nearest multiple of 8.
        rounded = dim_size + (8 - mod_eight)
      # Sizes 2-3 would otherwise round down to 0, an invalid resize target.
      new_dim[i] = max(8, rounded)
      keypoint_scale_factors[i] = (new_dim[i] - 1) / (dim_size - 1)

    # Resize and return image + scale factors.
    new_dim = new_dim[::-1]  # Convert from (row, col) to (x, y).
    keypoint_scale_factors = keypoint_scale_factors[::-1]
    image = cv2.resize(image, tuple(new_dim), interpolation=interpolation)
    return image, keypoint_scale_factors

  def _preprocess_image(self, image):
    """Converts image to grayscale and normalizes values for model input.

    Returns a float32 (H, W, 1) array with values divided by 255.
    """
    image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
    image = np.expand_dims(image, 2)
    image = image.astype(np.float32)
    image = image / 255.0
    return image

  def _extract_superpoint_output(
      self,
      keypoint_map,
      descriptor_map,
      keep_k_points=512,
      pad_random_features=False,
  ):
    """Converts from raw SuperPoint output (feature maps) into numpy arrays.

    If keep_k_points is 0, then keep all detected keypoints. Otherwise, sort by
    confidence and keep only the top k confidence keypoints.

    Args:
      keypoint_map: (H, W) numpy array, raw output confidence values from
        SuperPoint model; every strictly positive entry is a detection.
      descriptor_map: (H, W, 256) numpy array, raw output descriptors from
        SuperPoint model.
      keep_k_points: int, number of keypoints to keep (or 0 to indicate
        keeping all detected keypoints).
      pad_random_features: if True, adds randomly sampled keypoints to the
        output such that there are exactly 'keep_k_points' keypoints.
        Descriptors for these sampled keypoints are taken from the network's
        descriptor map output, and scores are set to 0. No action taken if
        keep_k_points = 0.

    Returns:
      keypoints: (N, 2) numpy array, image coordinates (x, y) of keypoints as
        floats.
      descriptors: (N, 256) numpy array, descriptors for keypoints as floats.
        When N == 0 this is an empty (0, descriptor_map.shape[-1]) array.
      scores: (N,) numpy array, confidence values for keypoints as floats.
    """

    def _select_k_best(points, k):
      # Sort rows by ascending confidence; the k best rows are at the end.
      sorted_points = points[points[:, 2].argsort(), :]
      start = points.shape[0] - min(k, points.shape[0])
      return sorted_points[start:, :2], sorted_points[start:, 2]

    nonzero = np.where(keypoint_map > 0)
    prob = keypoint_map[nonzero[0], nonzero[1]]
    keypoints = np.stack([nonzero[0], nonzero[1], prob], axis=-1)

    # Keep only top k points, or all points if keep_k_points param is 0.
    if keep_k_points == 0:
      keep_k_points = keypoints.shape[0]
    keypoints, scores = _select_k_best(keypoints, keep_k_points)

    # Optionally, pad with random features (and confidence scores of 0).
    if pad_random_features and (keep_k_points > keypoints.shape[0]):
      num_pad = keep_k_points - keypoints.shape[0]
      image_shape = np.array(keypoint_map.shape[:2])
      # Uniform (row, col) samples in [0, H-1) x [0, W-1) drawn from NumPy's
      # global random state.
      keypoints_pad = (image_shape - 1) * np.random.uniform(size=(num_pad, 2))
      keypoints = np.concatenate((keypoints, keypoints_pad))
      scores = np.concatenate((scores, np.zeros(num_pad)))

    # Lookup descriptors via bilinear interpolation.
    # TODO: batch descriptor lookup with bilinear interpolation.
    keypoints[:, [0, 1]] = keypoints[:, [1, 0]]  # Swap (row,col) to (x,y).
    descriptors = [
        utils.lookup_descriptor_bilinear(kp, descriptor_map) for kp in keypoints
    ]
    if descriptors:
      descriptors = np.array(descriptors)
    else:
      # np.array([]) would have shape (0,); keep the result 2-D so callers can
      # rely on descriptors.shape[1]. Assumes lookup_descriptor_bilinear
      # returns vectors of length descriptor_map.shape[-1] — TODO confirm.
      descriptors = np.zeros((0, descriptor_map.shape[-1]))
    return keypoints, descriptors, scores