Spaces:
Running
on
Zero
Running
on
Zero
| import torch | |
| from kornia.color import rgb_to_grayscale | |
| from kornia.feature import HardNet, LAFDescriptor, laf_from_center_scale_ori | |
| from .sift import SIFT | |
class DoGHardNet(SIFT):
    """SIFT-derived extractor whose descriptors are computed by HardNet.

    Keypoints, scales and orientations come from the parent class's
    ``extract_single_image``; the descriptors are then produced by kornia's
    ``LAFDescriptor(HardNet)`` on local affine frames (LAFs) built from
    those detections.
    """

    required_data_keys = ["image"]

    def __init__(self, **conf):
        """Build the parent extractor and the HardNet patch descriptor.

        Args:
            **conf: Forwarded unchanged to the parent constructor.
        """
        super().__init__(**conf)
        # ``True`` is HardNet's first positional argument (``pretrained`` in
        # kornia's API -- confirm against the installed kornia version).
        self.laf_desc = LAFDescriptor(HardNet(True)).eval()

    def forward(self, data: dict) -> dict:
        """Detect keypoints per image and describe them with HardNet.

        Args:
            data: Must contain ``"image"``, a batched tensor indexed as
                ``(batch, channel, height, width)``; 3-channel input is
                converted to grayscale. May contain ``"image_size"``, a
                tensor holding one ``(w, h)`` pair per image; when present
                (and not ``None``) each image is cropped to ``[:h, :w]``
                before extraction.

        Returns:
            A dict with the keys produced by ``extract_single_image`` plus
            ``"descriptors"`` (rows of 128 values), each stacked along a new
            leading batch dimension and moved to the input image's device.
            ``torch.stack`` requires every image in the batch to yield
            equally shaped entries.

        Raises:
            IndexError: If the batch is empty (no per-image result to take
                the output keys from).
        """
        image = data["image"]
        if image.shape[1] == 3:
            image = rgb_to_grayscale(image)
        device = image.device

        # ``Module.to`` / ``Module.eval`` act in place, so no reassignment is
        # needed. The inner descriptor is forced back to eval mode on every
        # call, as in the original code.
        self.laf_desc.to(device)
        self.laf_desc.descriptor.eval()

        # Convert the crop sizes to plain ints once, instead of re-indexing
        # ``data`` and casting on every loop iteration. A missing key and an
        # explicit ``None`` are both treated as "no cropping".
        image_size = data.get("image_size")
        crop_sizes = None if image_size is None else image_size.long().tolist()

        per_image = []
        for idx, img in enumerate(image):
            if crop_sizes is not None:
                w, h = crop_sizes[idx]
                img = img[:, :h, :w]
            p = self.extract_single_image(img)
            # Build one LAF per keypoint. ``oris`` is converted from radians
            # to degrees for kornia. The 6.0 scale multiplier is kept from
            # the original; its rationale is not visible in this file.
            lafs = laf_from_center_scale_ori(
                p["keypoints"].reshape(1, -1, 2),
                6.0 * p["scales"].reshape(1, -1, 1, 1),
                torch.rad2deg(p["oris"]).reshape(1, -1, 1),
            ).to(device)
            p["descriptors"] = self.laf_desc(img[None], lafs).reshape(-1, 128)
            per_image.append(p)

        return {
            key: torch.stack([p[key] for p in per_image], 0).to(device)
            for key in per_image[0]
        }