import collections.abc as collections
from pathlib import Path
from types import SimpleNamespace
from typing import Callable, List, Optional, Tuple, Union

import cv2
import kornia
import numpy as np
import torch

class ImagePreprocessor:
    default_conf = {
        "resize": None,  # target edge length, None for no resizing
        "side": "long",
        "interpolation": "bilinear",
        "align_corners": None,
        "antialias": True,
    }

    def __init__(self, **conf) -> None:
        super().__init__()
        self.conf = {**self.default_conf, **conf}
        self.conf = SimpleNamespace(**self.conf)

    def __call__(self, img: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Resize and preprocess an image, return image and resize scale"""
        h, w = img.shape[-2:]
        if self.conf.resize is not None:
            img = kornia.geometry.transform.resize(
                img,
                self.conf.resize,
                side=self.conf.side,
                antialias=self.conf.antialias,
                align_corners=self.conf.align_corners,
            )
        scale = torch.Tensor([img.shape[-1] / w, img.shape[-2] / h]).to(img)
        return img, scale

def map_tensor(input_, func: Callable):
    """Recursively apply func to every tensor nested in dicts and sequences."""
    # Note: `collections` is aliased to collections.abc at import time, so
    # collections.Mapping / collections.Sequence resolve to the abc classes.
    string_classes = (str, bytes)
    if isinstance(input_, string_classes):
        return input_
    elif isinstance(input_, collections.Mapping):
        return {k: map_tensor(sample, func) for k, sample in input_.items()}
    elif isinstance(input_, collections.Sequence):
        return [map_tensor(sample, func) for sample in input_]
    elif isinstance(input_, torch.Tensor):
        return func(input_)
    else:
        return input_
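
# Illustrative sketch (not part of the original module; values are made up):
# map_tensor applies `func` to every tensor nested inside dicts/lists and
# passes everything else through unchanged, e.g.
#   batch = {"image": torch.rand(1, 3, 8, 8), "name": "img0"}
#   half = map_tensor(batch, lambda t: t.half())  # "name" survives as-is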

def batch_to_device(batch: dict, device: str = "cpu", non_blocking: bool = True):
    """Move batch (dict) to device"""

    def _func(tensor):
        return tensor.to(device=device, non_blocking=non_blocking).detach()

    return map_tensor(batch, _func)

def rbd(data: dict) -> dict:
    """Remove batch dimension from elements in data"""
    return {
        k: v[0] if isinstance(v, (torch.Tensor, np.ndarray, list)) else v
        for k, v in data.items()
    }
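
# Illustrative sketch (hypothetical shapes): rbd drops the leading batch axis
# from each array-like entry and leaves scalars/strings untouched, e.g.
#   out = rbd({"keypoints": torch.zeros(1, 512, 2), "name": "img0"})
#   out["keypoints"].shape  # -> torch.Size([512, 2])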

def read_image(path: Path, grayscale: bool = False) -> np.ndarray:
    """Read an image from path as RGB or grayscale"""
    if not Path(path).exists():
        raise FileNotFoundError(f"No image at path {path}.")
    mode = cv2.IMREAD_GRAYSCALE if grayscale else cv2.IMREAD_COLOR
    image = cv2.imread(str(path), mode)
    if image is None:
        raise IOError(f"Could not read image at {path}.")
    if not grayscale:
        image = image[..., ::-1]  # OpenCV reads BGR; flip channels to RGB
    return image

def numpy_image_to_torch(image: np.ndarray) -> torch.Tensor:
    """Normalize the image tensor and reorder the dimensions."""
    if image.ndim == 3:
        image = image.transpose((2, 0, 1))  # HxWxC to CxHxW
    elif image.ndim == 2:
        image = image[None]  # add channel axis
    else:
        raise ValueError(f"Not an image: {image.shape}")
    return torch.tensor(image / 255.0, dtype=torch.float)

def resize_image(
    image: np.ndarray,
    size: Union[List[int], int],
    fn: str = "max",
    interp: Optional[str] = "area",
) -> Tuple[np.ndarray, Tuple[float, float]]:
    """Resize an image to a fixed size, or according to max or min edge.

    Returns the resized image and the (x, y) scale factors.
    """
    h, w = image.shape[:2]

    fn = {"max": max, "min": min}[fn]
    if isinstance(size, int):
        scale = size / fn(h, w)
        h_new, w_new = int(round(h * scale)), int(round(w * scale))
        scale = (w_new / w, h_new / h)
    elif isinstance(size, (tuple, list)):
        h_new, w_new = size
        scale = (w_new / w, h_new / h)
    else:
        raise ValueError(f"Incorrect new size: {size}")
    mode = {
        "linear": cv2.INTER_LINEAR,
        "cubic": cv2.INTER_CUBIC,
        "nearest": cv2.INTER_NEAREST,
        "area": cv2.INTER_AREA,
    }[interp]
    return cv2.resize(image, (w_new, h_new), interpolation=mode), scale

def load_image(path: Path, resize: Optional[int] = None, **kwargs) -> torch.Tensor:
    """Read an image from disk as a normalized CxHxW float tensor, optionally resized."""
    image = read_image(path)
    if resize is not None:
        image, _ = resize_image(image, resize, **kwargs)
    return numpy_image_to_torch(image)
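
# Illustrative sketch (hypothetical path): load an image with its longest edge
# capped at 1024 px; extra kwargs are forwarded to resize_image.
#   img = load_image("assets/example.jpg", resize=1024, fn="max")
#   img.shape  # -> torch.Size([3, H, W]), values in [0, 1]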

class Extractor(torch.nn.Module):
    def __init__(self, **conf):
        super().__init__()
        self.conf = SimpleNamespace(**{**self.default_conf, **conf})

    def extract(self, img: torch.Tensor, **conf) -> dict:
        """Perform extraction with online resizing"""
        if img.dim() == 3:
            img = img[None]  # add batch dim
        assert img.dim() == 4 and img.shape[0] == 1
        shape = img.shape[-2:][::-1]
        img, scales = ImagePreprocessor(**{**self.preprocess_conf, **conf})(img)
        feats = self.forward({"image": img})
        feats["image_size"] = torch.tensor(shape)[None].to(img).float()
        # Map keypoints from resized back to original image coordinates;
        # the +/- 0.5 terms account for the pixel-center convention.
        feats["keypoints"] = (feats["keypoints"] + 0.5) / scales[None] - 0.5
        return feats
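
# Illustrative sketch (names are assumptions, not part of this module): the
# Extractor base class relies on attributes its subclasses must define, namely
# `default_conf`, `preprocess_conf`, and a `forward` returning "keypoints":
#
#   class MyExtractor(Extractor):
#       default_conf = {"max_num_keypoints": 2048}
#       preprocess_conf = {"resize": 1024, "side": "long"}
#
#       def forward(self, data: dict) -> dict:
#           ...  # detect features on data["image"], return a dict with "keypoints"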

def match_pair(
    extractor,
    matcher,
    image0: torch.Tensor,
    image1: torch.Tensor,
    device: str = "cpu",
    **preprocess,
):
    """Match a pair of images (image0, image1) with an extractor and matcher"""
    feats0 = extractor.extract(image0, **preprocess)
    feats1 = extractor.extract(image1, **preprocess)
    matches01 = matcher({"image0": feats0, "image1": feats1})
    data = [feats0, feats1, matches01]
    # remove batch dim and move to target device
    feats0, feats1, matches01 = [batch_to_device(rbd(x), device) for x in data]
    return feats0, feats1, matches01
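
# End-to-end usage sketch (hedged example: the extractor/matcher classes and
# image paths below are assumptions, not defined in this module):
#   extractor = SuperPoint(max_num_keypoints=2048).eval()
#   matcher = LightGlue(features="superpoint").eval()
#   image0 = load_image("assets/img0.jpg")
#   image1 = load_image("assets/img1.jpg")
#   feats0, feats1, matches01 = match_pair(extractor, matcher, image0, image1)
#   matches = matches01["matches"]  # (K, 2) index pairs into the two keypoint sets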