Realcat's picture
update: major change
499e141
import torch
from .superpoint import SuperPoint
from .models.matchers.lightglue import LightGlue
class Matching(torch.nn.Module):
    """Image matching front end: SuperPoint detector followed by LightGlue.

    Both sub-modules are built with the fixed settings written out in
    ``__init__``; nothing in this class reads ``config``.
    """

    # Keys `forward` reads from its input dictionary.
    _REQUIRED_KEYS = ('gray0', 'size0', 'gray1', 'size1')

    def __init__(self, config=None):
        """Build the detector and the matcher.

        Args:
            config: Accepted so existing call sites keep working, but
                currently unused -- the settings below are not overridden
                by it. Defaults to ``None`` instead of a shared mutable
                ``{}``.
        """
        super().__init__()
        self.detector = SuperPoint({
            'max_num_keypoints': 2048,
            'force_num_keypoints': True,
            'detection_threshold': 0.0,
            'nms_radius': 3,
            'trainable': False,
        })
        self.model = LightGlue({
            'filter_threshold': 0.1,
            'flash': False,
            'checkpointed': True,
        })

    def forward(self, data):
        """Run SuperPoint on both images, then LightGlue on the result.

        The detector is always executed for both images; keypoints that are
        already present in ``data`` are not consulted.

        Args:
            data: dictionary with at least the keys
                ``'gray0'``, ``'gray1'`` (passed to the detector as
                ``'image'``) and ``'size0'``, ``'size1'`` (passed to the
                detector as ``'image_size'`` and to the matcher as
                ``'resize0'`` / ``'resize1'``).

        Returns:
            A dictionary holding every detector output with ``'0'`` or
            ``'1'`` appended to its key, updated with the outputs of the
            matcher.

        Raises:
            KeyError: if any of the required keys is missing from ``data``.
        """
        # Fail before running the detector rather than after the first image
        # has already been processed.
        missing = [key for key in self._REQUIRED_KEYS if key not in data]
        if missing:
            raise KeyError(
                f"Matching.forward: missing required input key(s): {missing}"
            )

        pred = {}
        for idx in ('0', '1'):
            detected = self.detector({
                'image': data['gray' + idx],
                'image_size': data['size' + idx],
            })
            # Suffix each detector output so the two images do not collide.
            pred.update({name + idx: value for name, value in detected.items()})

        # The matcher receives the detector outputs plus the image sizes,
        # the latter under the 'resize*' names.
        pred.update(self.model({
            **pred,
            'resize0': data['size0'],
            'resize1': data['size1'],
        }))
        return pred