# EarthLoc2/image-matching-models/matching/third_party/accelerated_features/modules/lighterglue.py
from kornia.feature.lightglue import LightGlue | |
from torch import nn | |
import torch | |
import os | |
class LighterGlue(nn.Module):
    """
    Lighter version of LightGlue :)

    Wraps kornia's ``LightGlue`` with a configuration sized for XFeat
    descriptors (see ``default_conf_xfeat``) and loads the
    ``xfeat-lighterglue.pt`` checkpoint. Legacy parameter names in that
    checkpoint are renamed to the module layout used by kornia's LightGlue.
    """

    # LightGlue configuration for XFeat. It is handed to kornia's LightGlue as
    # keyword overrides; it is NOT written onto the LightGlue class, so other
    # LightGlue instances in the same process keep kornia's own defaults.
    default_conf_xfeat = {
        "name": "xfeat",  # just for interfacing
        "input_dim": 64,  # input descriptor dimension (autoselected from weights)
        "descriptor_dim": 96,
        "add_scale_ori": False,
        "add_laf": False,  # for KeyNetAffNetHardNet
        "scale_coef": 1.0,  # to compensate for the SIFT scale bigger than KeyNet
        "n_layers": 6,
        "num_heads": 1,
        "flash": True,  # enable FlashAttention if available.
        "mp": False,  # enable mixed precision
        "depth_confidence": -1,  # early stopping, disable with -1
        "width_confidence": 0.95,  # point pruning, disable with -1
        "filter_threshold": 0.1,  # match threshold
        "weights": None,
    }

    def __init__(self, weights=os.path.join(os.path.abspath(os.path.dirname(__file__)), '..', 'weights', 'xfeat-lighterglue.pt')):
        """
        Build the matcher and load its weights.

        Args:
            weights: path to a local ``xfeat-lighterglue.pt`` state dict. If no
                file exists at this path, the checkpoint is downloaded from the
                verlab/accelerated_features GitHub repository instead.

        The model is placed on CUDA when available, otherwise on CPU; the
        chosen device is stored in ``self.dev``.
        """
        super().__init__()
        # Pass the XFeat config as constructor overrides instead of assigning
        # LightGlue.default_conf: patching the kornia class would leak this
        # configuration into every LightGlue created afterwards.
        # features=None: no kornia feature preset; weights are loaded below.
        self.net = LightGlue(None, **self.default_conf_xfeat)
        self.dev = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        if os.path.exists(weights):
            state_dict = torch.load(weights, map_location=self.dev)
        else:
            # map_location keeps a checkpoint saved from CUDA loadable on
            # CPU-only hosts (mirrors the local torch.load branch above).
            state_dict = torch.hub.load_state_dict_from_url(
                "https://github.com/verlab/accelerated_features/raw/main/weights/xfeat-lighterglue.pt",
                map_location=self.dev,
            )

        # Rename old state dict entries to the current module layout:
        #   "self_attn.{i}"  -> "transformers.{i}.self_attn"
        #   "cross_attn.{i}" -> "transformers.{i}.cross_attn"
        for i in range(self.net.conf.n_layers):
            renames = (
                (f"self_attn.{i}", f"transformers.{i}.self_attn"),
                (f"cross_attn.{i}", f"transformers.{i}.cross_attn"),
            )
            for old, new in renames:
                state_dict = {k.replace(old, new): v for k, v in state_dict.items()}
        # Strip every "matcher." substring from the key names.
        state_dict = {k.replace('matcher.', ''): v for k, v in state_dict.items()}

        # strict=False: keys present on only one side do not raise.
        self.net.load_state_dict(state_dict, strict=False)
        self.net.to(self.dev)

    def forward(self, data, min_conf=0.1):
        """
        Match the keypoints/descriptors of two images.

        Args:
            data: dict with keys 'keypoints0', 'descriptors0', 'image_size0',
                'keypoints1', 'descriptors1', 'image_size1'. The values are
                forwarded unchanged to kornia's LightGlue.
                NOTE(review): presumably batched tensors, e.g. keypoints
                (B, N, 2) and descriptors (B, N, 64) -- confirm against callers.
            min_conf: match filter threshold. It is written to
                ``self.net.conf.filter_threshold`` before matching,
                overwriting whatever value was there.

        Returns:
            The output of the wrapped LightGlue, unmodified.
        """
        self.net.conf.filter_threshold = min_conf
        # Regroup the flat "<field><idx>" keys into the per-image dicts
        # ("image0"/"image1") that LightGlue consumes.
        inputs = {
            f'image{idx}': {
                'keypoints': data[f'keypoints{idx}'],
                'descriptors': data[f'descriptors{idx}'],
                'image_size': data[f'image_size{idx}'],
            }
            for idx in (0, 1)
        }
        return self.net(inputs)