# Author: Pawel Piwowarski -- init commit (0a82b18)
from kornia.feature.lightglue import LightGlue
from torch import nn
import torch
import os
class LighterGlue(nn.Module):
    """
    Lighter version of LightGlue :)

    Wraps kornia's ``LightGlue`` matcher, configured with ``default_conf_xfeat``
    and initialised from the ``xfeat-lighterglue.pt`` checkpoint (read from a
    local file when it exists, downloaded otherwise).
    """

    default_conf_xfeat = {
        "name": "xfeat",  # just for interfacing
        "input_dim": 64,  # input descriptor dimension (autoselected from weights)
        "descriptor_dim": 96,
        "add_scale_ori": False,
        "add_laf": False,  # for KeyNetAffNetHardNet
        "scale_coef": 1.0,  # to compensate for the SIFT scale bigger than KeyNet
        "n_layers": 6,
        "num_heads": 1,
        "flash": True,  # enable FlashAttention if available.
        "mp": False,  # enable mixed precision
        "depth_confidence": -1,  # early stopping, disable with -1
        "width_confidence": 0.95,  # point pruning, disable with -1
        "filter_threshold": 0.1,  # match threshold
        "weights": None,
    }

    # Fallback download location used when the local checkpoint is missing.
    _WEIGHTS_URL = "https://github.com/verlab/accelerated_features/raw/main/weights/xfeat-lighterglue.pt"

    def __init__(self, weights = os.path.abspath(os.path.dirname(__file__)) + '/../weights/xfeat-lighterglue.pt'):
        """
        Build the matcher and load its weights.

        Args:
            weights: Path to the checkpoint file. When the path does not
                exist, the checkpoint is fetched from ``_WEIGHTS_URL`` via
                ``torch.hub``.
        """
        super().__init__()

        # Hand the XFeat settings to the constructor as overrides rather than
        # rebinding ``LightGlue.default_conf``: that attribute is shared by
        # every LightGlue in the process, so overwriting it would leak this
        # configuration into unrelated matchers created afterwards.
        self.net = LightGlue(None, **self.default_conf_xfeat)
        self.dev = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        if os.path.exists(weights):
            state_dict = torch.load(weights, map_location=self.dev)
        else:
            # map_location mirrors the local branch so that the downloaded
            # tensors are remapped onto the selected device as well.
            state_dict = torch.hub.load_state_dict_from_url(
                self._WEIGHTS_URL, map_location=self.dev
            )

        state_dict = self._rename_legacy_keys(state_dict, self.net.conf.n_layers)

        # strict=False: keys that do not line up with the module are skipped
        # instead of raising.
        self.net.load_state_dict(state_dict, strict=False)
        self.net.to(self.dev)

    @staticmethod
    def _rename_legacy_keys(state_dict, n_layers):
        """
        Rename old state dict entries.

        For every layer index ``i`` below ``n_layers`` the substrings
        ``self_attn.{i}`` / ``cross_attn.{i}`` become
        ``transformers.{i}.self_attn`` / ``transformers.{i}.cross_attn``;
        afterwards every ``matcher.`` occurrence is removed from the key.

        Args:
            state_dict: Mapping from parameter name to value.
            n_layers: Number of layer indices to rewrite.

        Returns:
            A new dict with renamed keys; values are passed through untouched.
        """
        renamed = {}
        for key, value in state_dict.items():
            for i in range(n_layers):
                key = key.replace(f"self_attn.{i}", f"transformers.{i}.self_attn")
                key = key.replace(f"cross_attn.{i}", f"transformers.{i}.cross_attn")
            renamed[key.replace('matcher.', '')] = value
        return renamed

    @torch.inference_mode()
    def forward(self, data, min_conf = 0.1):
        """
        Match the features of two images.

        Args:
            data: Mapping providing ``keypoints{0,1}``, ``descriptors{0,1}``
                and ``image_size{0,1}``; the values are forwarded unchanged.
            min_conf: Value written to ``self.net.conf.filter_threshold``
                before matching. The assignment is not reverted afterwards.

        Returns:
            Whatever the wrapped network returns for the image pair.
        """
        self.net.conf.filter_threshold = min_conf

        # Regroup the flat "<field><idx>" entries into one dict per image.
        pair = {
            f'image{idx}': {
                'keypoints': data[f'keypoints{idx}'],
                'descriptors': data[f'descriptors{idx}'],
                'image_size': data[f'image_size{idx}'],
            }
            for idx in (0, 1)
        }
        return self.net(pair)