"""Gradio demo: DINOv2 patch features.

Shows a 3-component PCA visualisation of the patch tokens of two images and
draws nearest-neighbour patch matches between them.
"""
import os

# gradio for the visual demo
import gradio as gr

# Install runtime dependencies before they are imported below.
os.system("pip install scipy")
os.system("pip install torch")
os.system("pip install scikit-learn")
os.system("pip install torchvision")

import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import ImageDraw, ImageColor, Image
from typing import Tuple
from scipy.ndimage import binary_closing, binary_opening
from sklearn.decomposition import PCA
from sklearn.neighbors import NearestNeighbors
from random import randint

### Models

# Projection used by make_foreground_mask to score patch tokens as
# foreground/background; loaded from a local file.
standard_array = np.load('standard.npy')

# DINOv2 ViT-B/14 backbone (via torch.hub) used as the patch-feature extractor.
pca_model = torch.hub.load("facebookresearch/dinov2", "dinov2_vitb14")
pca_model.eval()

### Parameters
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
smaller_edge_size = 448
interpolation_mode = transforms.InterpolationMode.BICUBIC
patch_size = pca_model.patch_size
background_threshold = 0.05
apply_opening = True
apply_closing = True
device = 'cpu'


def make_transform() -> transforms.Compose:
    """Build the resize / to-tensor / ImageNet-normalize preprocessing pipeline."""
    return transforms.Compose([
        transforms.Resize(size=smaller_edge_size, interpolation=interpolation_mode, antialias=True),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
    ])


def prepare_image(image: Image.Image) -> Tuple[torch.Tensor, Tuple[int, int], float]:
    """Preprocess *image* for the backbone.

    Returns:
        image_tensor: C x H x W tensor, cropped so H and W are multiples of
            ``patch_size``.
        grid_size: (rows, cols) of the resulting patch grid.
        resize_scale: original-image pixels per resized-tensor pixel,
            measured along the width.
    """
    # Normalize uses 3-channel mean/std, so RGBA or greyscale inputs must be
    # converted first.
    if image.mode != "RGB":
        image = image.convert("RGB")
    transform = make_transform()
    image_tensor = transform(image)
    resize_scale = image.width / image_tensor.shape[2]

    # Crop image to dimensions that are a multiple of the patch size.
    height, width = image_tensor.shape[1:]  # C x H x W
    cropped_height = height - height % patch_size
    cropped_width = width - width % patch_size
    image_tensor = image_tensor[:, :cropped_height, :cropped_width]

    grid_size = (cropped_height // patch_size, cropped_width // patch_size)  # rows x cols
    return image_tensor, grid_size, resize_scale


def make_foreground_mask(tokens, grid_size: Tuple[int, int]):
    """Return a flat boolean mask with one entry per patch token.

    A patch is foreground when its projection onto ``standard_array`` is at
    least ``background_threshold``.  The mask is then cleaned on the 2-D patch
    grid with optional binary opening and closing.
    """
    projection = np.asarray(tokens) @ standard_array
    mask = projection >= background_threshold
    mask = mask.reshape(*grid_size)
    if apply_opening:
        mask = binary_opening(mask)
    if apply_closing:
        mask = binary_closing(mask)
    return mask.flatten()


def render_patch_pca(image: Image.Image, mask, filter_background, tokens, grid_size) -> Image.Image:
    """Render the first 3 PCA components of *tokens* as an RGB image.

    Each patch becomes one pixel of a ``grid_size`` image, which is then
    resized (nearest neighbour) to the size of *image*.  With
    *filter_background*, PCA is fitted on the foreground tokens only and
    background patches are painted black.
    """
    tokens = np.asarray(tokens)
    mask = np.asarray(mask, dtype=bool)

    # PCA(n_components=3) needs at least 3 samples; fall back to fitting on
    # every token when the foreground mask is (almost) empty.
    fit_on_foreground = filter_background and np.count_nonzero(mask) >= 3
    pca = PCA(n_components=3)
    pca.fit(tokens[mask] if fit_on_foreground else tokens)
    projected = pca.transform(tokens)

    # Min-max normalize each component to [0, 1]; a constant component would
    # otherwise divide 0 by 0.
    t_min = projected.min(axis=0, keepdims=True)
    t_range = projected.max(axis=0, keepdims=True) - t_min
    t_range[t_range == 0] = 1
    normalized = (projected - t_min) / t_range

    array = (normalized * 255).astype(np.uint8)
    if filter_background:
        array[~mask] = 0
    array = array.reshape(*grid_size, 3)
    # Resample filter 0 is nearest neighbour: keeps patches as hard blocks.
    return Image.fromarray(array).resize((image.width, image.height), 0)


def extract_features(img, image_tensor, filter_background, grid_size):
    """Run the backbone and return (tokens, foreground mask, PCA image).

    ``tokens`` is a NumPy array with one row per patch.
    """
    with torch.inference_mode():
        image_batch = image_tensor.unsqueeze(0).to(device)
        # First element of the returned layers, then drop the batch axis.
        # Indexing (instead of squeeze()) keeps a 2-D result even for a
        # single-patch grid.
        tokens = pca_model.get_intermediate_layers(image_batch)[0][0].cpu().numpy()
    mask = make_foreground_mask(tokens, grid_size)
    img_pca = render_patch_pca(img, mask, filter_background, tokens, grid_size)
    return tokens, mask, img_pca


def compute_features(img, filter_background):
    """Preprocess *img* and return (features, mask, grid_size, resize_scale, pca_image)."""
    image_tensor, grid_size, resize_scale = prepare_image(img)
    features, mask, img_pca = extract_features(img, image_tensor, filter_background, grid_size)
    return features, mask, grid_size, resize_scale, img_pca


def idx_to_source_position(idx, grid_size, resize_scale):
    """Map a flat patch index to the (row, col) of its centre in the source image.

    The whole position, including the half-patch centring offset, is scaled
    by *resize_scale*.  The offset lives in resized-tensor pixels just like
    the patch origin.
    """
    stride = patch_size * resize_scale
    row = (idx // grid_size[1] + 0.5) * stride
    col = (idx % grid_size[1] + 0.5) * stride
    return row, col


def compute_nn(features1, features2):
    """For every row of *features2*, find its nearest row in *features1*.

    Returns the ``(distances, indices)`` arrays from
    ``NearestNeighbors.kneighbors`` (both shaped ``(len(features2), 1)``).
    """
    knn = NearestNeighbors(n_neighbors=1)
    knn.fit(features1)
    distances, match2to1 = knn.kneighbors(features2)
    match2to1 = np.array(match2to1)
    return distances, match2to1


def compute_matches(img1, img2, lr_check, filter_background, display_matches_threshold):
    """Gradio callback: match patches of *img2* to *img1* and visualise them.

    Args:
        img1, img2: PIL images from the two ``gr.Image`` inputs.
        lr_check: keep only mutual (symmetric) nearest-neighbour matches.
        filter_background: ignore matches touching a background patch.
        display_matches_threshold: percentage (0-100) of matches drawn as
            lines in the merged view.

    Returns:
        ``[[img1, pca1], [img2, pca2], merged]``.  ``img1`` and ``img2`` are
        RGB copies of the inputs with match points drawn on them; ``merged``
        shows both images side by side with match lines.
    """
    # Work on RGB copies so the caller's images are never drawn on, and so
    # RGBA or greyscale uploads can be pasted onto the RGB canvas.
    img1 = img1.convert("RGB")
    img2 = img2.convert("RGB")

    # compute features
    features1, mask1, grid_size1, resize_scale1, img_pca1 = compute_features(img1, filter_background)
    features2, mask2, grid_size2, resize_scale2, img_pca2 = compute_features(img2, filter_background)

    # match features; kneighbors returns (n, 1) index arrays -> flatten so
    # each entry is a plain integer index.
    _, match2to1 = compute_nn(features1, features2)
    _, match1to2 = compute_nn(features2, features1)
    match2to1 = match2to1[:, 0]
    match1to2 = match1to2[:, 0]

    # Side-by-side view: both images are shown at the size of the shorter
    # one (img1's size on a tie).  Points are computed in each image's own
    # coordinates, so use separate x/y factors per image to map them into
    # this display size.  The aspect ratio may change.
    display_size = img2.size if img1.height > img2.height else img1.size
    sx1, sy1 = display_size[0] / img1.width, display_size[1] / img1.height
    sx2, sy2 = display_size[0] / img2.width, display_size[1] / img2.height
    offset = display_size[0]

    merged_image = Image.new('RGB', (2 * offset, display_size[1]), (250, 250, 250))
    merged_image.paste(img1.resize(display_size), (0, 0))
    merged_image.paste(img2.resize(display_size), (offset, 0))

    draw = ImageDraw.Draw(merged_image)
    draw1 = ImageDraw.Draw(img1)
    draw2 = ImageDraw.Draw(img2)
    for idx2, idx1 in enumerate(match2to1):
        if lr_check and match1to2[idx1] != idx2:
            continue
        if filter_background and not (mask1[idx1] and mask2[idx2]):
            continue
        row1, col1 = idx_to_source_position(idx1, grid_size1, resize_scale1)
        row2, col2 = idx_to_source_position(idx2, grid_size2, resize_scale2)

        r = randint(0, 255)
        g = randint(0, 255)
        color = (r, g, 255 - r)
        draw1.point((col1, row1), color)
        draw2.point((col2, row2), color)

        # Draw roughly display_matches_threshold percent of the lines.
        if 100 * np.random.rand() > display_matches_threshold:
            continue
        draw.line((col1 * sx1, row1 * sy1, col2 * sx2 + offset, row2 * sy2), fill=color)

    return [[img1, img_pca1], [img2, img_pca2], merged_image]


iface = gr.Interface(fn=compute_matches,
                     inputs=[
                         gr.Image(type="pil"),
                         gr.Image(type="pil"),
                         gr.Checkbox(label="Keep only symmetric matches",),
                         gr.Checkbox(label="Mask background"),
                         gr.Slider(0, 100, step=5, value=5, label="Display matches ratio", info="Choose between 0 and 100%"),
                     ],
                     outputs=[
                         gr.Gallery(
                             label="Image 1", show_label=False, elem_id="gallery",
                             columns=[2], rows=[1], object_fit="contain", height="auto"),
                         gr.Gallery(
                             label="Image 2", show_label=False, elem_id="gallery",
                             columns=[2], rows=[1], object_fit="contain", height="auto"),
                         gr.Image(type="pil")
                     ])
iface.launch(debug=True)