Spaces:
Running
Running
File size: 1,443 Bytes
499e141 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 |
import torch
from .superpoint import SuperPoint
from .models.matchers.lightglue import LightGlue
class Matching(torch.nn.Module):
    """Image matching frontend: SuperPoint keypoint detector + LightGlue matcher.

    Both sub-models are constructed with fixed, hard-coded configurations.
    """

    def __init__(self, config=None):
        """Build the detector (``self.detector``) and matcher (``self.model``).

        Args:
            config: accepted for interface compatibility but currently ignored.
                Defaults to ``None`` rather than a shared mutable ``{}``.
        """
        super().__init__()
        # NOTE(review): option semantics below are inferred from their names
        # (fixed keypoint budget, frozen weights) — confirm against SuperPoint.
        self.detector = SuperPoint({
            'max_num_keypoints': 2048,
            'force_num_keypoints': True,
            'detection_threshold': 0.0,
            'nms_radius': 3,
            'trainable': False,
        })
        # NOTE(review): 'checkpointed' presumably enables gradient
        # checkpointing inside LightGlue — confirm.
        self.model = LightGlue({
            'filter_threshold': 0.1,
            'flash': False,
            'checkpointed': True,
        })

    def forward(self, data):
        """Run SuperPoint on both images, then match the features with LightGlue.

        The detector always runs. Precomputed keypoints in ``data`` are not
        reused.

        Args:
            data: dict with keys
                'gray0', 'gray1': passed to the detector as ``"image"``;
                'size0', 'size1': passed to the detector as ``"image_size"``
                and to the matcher as ``'resize0'`` / ``'resize1'``.
                NOTE(review): presumably grayscale image tensors and their
                sizes — confirm against callers.

        Returns:
            dict: every detector output key suffixed with '0' or '1' for the
            respective image, then updated with the matcher's outputs. Matcher
            keys overwrite detector keys on collision.
        """
        pred = {}
        # Same order as before: image 0 first, then image 1.
        for idx in ('0', '1'):
            features = self.detector({
                "image": data["gray" + idx],
                "image_size": data["size" + idx],
            })
            # Tag each feature with its image index, e.g. 'x' -> 'x0'.
            pred.update({key + idx: value for key, value in features.items()})
        pred.update(self.model({
            **pred,
            'resize0': data['size0'],
            'resize1': data['size1'],
        }))
        return pred
|