from PIL import Image
import torch
import torch.nn.functional as F
import cv2
import numpy as np

from matching import BaseMatcher
from matching.utils import add_to_path
from matching import WEIGHTS_DIR, THIRD_PARTY_DIR

add_to_path(THIRD_PARTY_DIR.joinpath("ALIKED"))
add_to_path(THIRD_PARTY_DIR.joinpath("vggt"))

from nets.aliked import ALIKED
from vggt.models.vggt import VGGT


def torch_to_cv2(tensor):
    """Convert a CxHxW [0, 1] tensor to an OpenCV-style HxWxC uint8 array (BGR for 3-channel input)."""
    tensor = tensor.clone().mul(255).permute(1, 2, 0)
    numpy_img = tensor.byte().cpu().numpy()
    if numpy_img.shape[2] == 3:
        numpy_img = cv2.cvtColor(numpy_img, cv2.COLOR_RGB2BGR)
    return numpy_img


class VGGTMatcher(BaseMatcher):
    def __init__(self, device="cpu", max_num_keypoints=2048, *args, **kwargs):
        super().__init__(device, **kwargs)
        self.model = VGGT.from_pretrained("facebook/VGGT-1B").to(device)
        self.query_key_point_finder = ALIKED(
            model_name="aliked-n16rot",
            device=device,
            top_k=-1,
            scores_th=0.8,
            n_limit=max_num_keypoints,
        )
        self.target_size = 518
        self.patch_size = 14
        self.device = device

    def preprocess(self, img, mode="crop"):
        """
        Preprocess a single image tensor for model input.

        Returns:
            (batched tensor of shape (1, 3, H, W), (original_height, original_width))
        """
        if not isinstance(img, torch.Tensor):
            raise TypeError("Input must be a torch.Tensor")
        if img.dim() != 3 or img.shape[0] != 3:
            raise ValueError("Image must have shape (3, H, W)")
        if mode not in ["crop", "pad"]:
            raise ValueError("Mode must be either 'crop' or 'pad'")

        _, height, width = img.shape
        orig_shape = (height, width)

        if mode == "pad":
            # Resize so the longer side matches target_size; round the other side to a multiple of patch_size
            if width >= height:
                new_width = self.target_size
                new_height = round(height * (new_width / width) / self.patch_size) * self.patch_size
            else:
                new_height = self.target_size
                new_width = round(width * (new_height / height) / self.patch_size) * self.patch_size
        else:  # mode == "crop"
            # Resize so the width matches target_size; round the height to a multiple of patch_size
            new_width = self.target_size
            new_height = round(height * (new_width / width) / self.patch_size) * self.patch_size

        img = F.interpolate(
            img.unsqueeze(0),
            size=(new_height, new_width),
            mode="bicubic",
            align_corners=False,
        ).squeeze(0)

        if mode == "crop" and new_height > self.target_size:
            # Center-crop the height down to target_size
            start_y = (new_height - self.target_size) // 2
            img = img[:, start_y : start_y + self.target_size, :]

        if mode == "pad":
            # Pad symmetrically up to target_size x target_size
            h_padding = self.target_size - img.shape[1]
            w_padding = self.target_size - img.shape[2]
            if h_padding > 0 or w_padding > 0:
                pad_top = h_padding // 2
                pad_bottom = h_padding - pad_top
                pad_left = w_padding // 2
                pad_right = w_padding - pad_left
                img = F.pad(
                    img,
                    (pad_left, pad_right, pad_top, pad_bottom),
                    mode="constant",
                    value=1.0,
                )

        return img.unsqueeze(0), orig_shape

    def _forward(self, img0, img1):
        # Preprocess both images to the model input size
        query_image_tensor, img0_orig_shape = self.preprocess(img0)
        reference_image_tensor, img1_orig_shape = self.preprocess(img1)

        # Convert the query image to OpenCV format for ALIKED
        query_image_cv2 = torch_to_cv2(query_image_tensor.squeeze(0))

        # Run ALIKED on the preprocessed query image
        pred = self.query_key_point_finder.run(query_image_cv2)
        mkpts0 = torch.tensor(pred["keypoints"], dtype=torch.float32, device=self.device)

        # Get the model input sizes
        H0, W0 = query_image_tensor.shape[-2:]
        H1, W1 = reference_image_tensor.shape[-2:]

        # Rescale mkpts0 from the ALIKED input size (query_image_tensor) to the reference input size (reference_image_tensor)
        mkpts0_for_model = torch.tensor(
            self.rescale_coords(mkpts0, H1, W1, H0, W0),
            dtype=torch.float32,
            device=self.device,
        )

        # Forward pass through VGGT with the rescaled query points
        pred = self.model(reference_image_tensor, query_points=mkpts0_for_model)
        mkpts1 = pred["track"].squeeze()

        # Rescale mkpts1 from the reference input size (VGGT input) back to the original reference image size
        mkpts1 = self.rescale_coords(mkpts1, *img1_orig_shape, H1, W1)

        # Rescale mkpts0 from the query input size (VGGT input) back to the original query image size
        mkpts0 = self.rescale_coords(mkpts0, *img0_orig_shape, H0, W0)

        return mkpts0, mkpts1, None, None, None, None
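
if __name__ == "__main__":
    # Minimal smoke test (illustrative sketch only, not part of the matcher API):
    # checks that torch_to_cv2 turns a random CxHxW RGB tensor in [0, 1] into an
    # HxWx3 uint8 BGR array. Instantiating VGGTMatcher is deliberately skipped
    # here because it would download the facebook/VGGT-1B weights.
    dummy = torch.rand(3, 64, 80)
    out = torch_to_cv2(dummy)
    assert out.shape == (64, 80, 3) and out.dtype == np.uint8
    print("torch_to_cv2 output:", out.shape, out.dtype)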