Spaces:
Running
Running
import torch | |
import torch.nn as nn | |
from einops.einops import rearrange | |
from .backbone import build_backbone | |
from .utils.position_encoding import PositionEncodingSine | |
from .submodules import LocalFeatureTransformer, FinePreprocess | |
import warnings | |
from .utils.coarse_matching import CoarseMatching | |
warnings.simplefilter("ignore", UserWarning) | |
from .utils.fine_matching import FineMatching | |
class LoFTR(nn.Module):
    """Coarse-to-fine feature matcher built from the project's sub-modules.

    Pipeline run by ``forward``: CNN backbone -> positional encoding ->
    coarse ``LocalFeatureTransformer`` -> ``CoarseMatching`` ->
    ``FinePreprocess`` -> fine ``LocalFeatureTransformer`` -> ``FineMatching``.
    Results are written into the ``data`` dict passed to ``forward``; nothing
    is returned.
    """

    def __init__(self, config):
        """Build all sub-modules and optionally load pretrained weights.

        Args:
            config: mapping read with the keys ``'coarse'`` (which must contain
                ``'d_model'``), ``'match_coarse'`` and ``'fine'``. The whole
                ``config`` is also passed to ``build_backbone`` and
                ``FinePreprocess``, which may read further keys. The optional
                ``'weight'`` entry is a checkpoint path; when it is missing or
                ``None`` the freshly initialised weights are kept.
        """
        super().__init__()
        # Misc
        self.config = config

        # Modules
        self.backbone = build_backbone(config)
        # NOTE(review): temp_bug_fix=False presumably matches how the
        # pretrained weights were produced -- confirm before changing.
        self.pos_encoding = PositionEncodingSine(
            config['coarse']['d_model'],
            temp_bug_fix=False)
        self.loftr_coarse = LocalFeatureTransformer(config['coarse'])
        self.coarse_matching = CoarseMatching(config['match_coarse'])
        self.fine_preprocess = FinePreprocess(config)
        self.loftr_fine = LocalFeatureTransformer(config["fine"])
        self.fine_matching = FineMatching()

        # Checkpoint layout recorded by the original author:
        #   outdoor_ds.ckpt: {OrderedDict: 211}
        #     backbone: {OrderedDict: 107}
        #     loftr_coarse: {OrderedDict: 80}
        #     loftr_fine: {OrderedDict: 20}
        #     fine_preprocess: {OrderedDict: 4}
        try:
            weight_path = config['weight']
        except KeyError:
            # 'weight' is optional: a missing entry behaves like None.
            weight_path = None
        if weight_path is not None:
            # NOTE: torch.load unpickles the file; only load trusted checkpoints.
            weights = torch.load(weight_path, map_location='cpu')
            self.load_state_dict(weights)

    def forward(self, data):
        """Run the full matching pipeline, updating ``data`` in place.

        Args:
            data (dict): {
                'image0': (torch.Tensor): (N, 1, H, W) -- only its batch size
                          and spatial shape are read here
                'image1': (torch.Tensor): (N, 1, H, W)
                'color0': (torch.Tensor): the tensor actually fed to the
                          backbone; assumed to share N, H, W with 'image0'
                          (TODO confirm channel count)
                'color1': (torch.Tensor): same, paired with 'image1'
                'mask0'(optional) : (torch.Tensor): (N, H, W) '0' indicates a
                          padded position. It is flattened and used as the
                          coarse-level mask, so presumably it must be given at
                          coarse resolution -- confirm.
                'mask1'(optional) : (torch.Tensor): (N, H, W); required
                          whenever 'mask0' is present
            }

        Side effects:
            Adds 'bs', 'hw0_i', 'hw1_i', 'hw0_c', 'hw1_c', 'hw0_f', 'hw1_f' to
            ``data``. ``data`` is also passed to the matching and preprocess
            sub-modules, which presumably add their own results.

        Returns:
            None.
        """
        # 1. Local Feature CNN
        data.update({
            'bs': data['image0'].size(0),
            'hw0_i': data['image0'].shape[2:], 'hw1_i': data['image1'].shape[2:]
        })

        if data['hw0_i'] == data['hw1_i']:  # faster & better BN convergence
            # One batched backbone pass over both images (h == h0 == h1, w == w0 == w1).
            # feats_c: (bs*2, 256, h//8, w//8), feats_f: (bs*2, 128, h//2, w//2)
            feats_c, feats_f = self.backbone(
                torch.cat([data['color0'], data['color1']], dim=0))
            # feat_c0, feat_c1: (bs, 256, h//8, w//8); feat_f0, feat_f1: (bs, 128, h//2, w//2)
            (feat_c0, feat_c1), (feat_f0, feat_f1) = (
                feats_c.split(data['bs']), feats_f.split(data['bs']))
        else:  # handle different input shapes
            (feat_c0, feat_f0), (feat_c1, feat_f1) = (
                self.backbone(data['color0']), self.backbone(data['color1']))

        data.update({
            'hw0_c': feat_c0.shape[2:], 'hw1_c': feat_c1.shape[2:],
            'hw0_f': feat_f0.shape[2:], 'hw1_f': feat_f1.shape[2:]
        })

        # 2. coarse-level loftr module
        # add featmap with positional encoding, then flatten it to sequence [N, HW, C]
        feat_c0 = rearrange(self.pos_encoding(feat_c0), 'n c h w -> n (h w) c')
        feat_c1 = rearrange(self.pos_encoding(feat_c1), 'n c h w -> n (h w) c')

        mask_c0 = mask_c1 = None  # mask is useful in training
        if 'mask0' in data:
            mask_c0, mask_c1 = data['mask0'].flatten(-2), data['mask1'].flatten(-2)
        feat_c0, feat_c1 = self.loftr_coarse(feat_c0, feat_c1, mask_c0, mask_c1)

        # 3. match coarse-level
        self.coarse_matching(feat_c0, feat_c1, data, mask_c0=mask_c0, mask_c1=mask_c1)

        # 4. fine-level refinement
        feat_f0_unfold, feat_f1_unfold = self.fine_preprocess(
            feat_f0, feat_f1, feat_c0, feat_c1, data)
        # Skip the fine transformer on an empty batch (no coarse-level predictions).
        if feat_f0_unfold.size(0) != 0:
            feat_f0_unfold, feat_f1_unfold = self.loftr_fine(feat_f0_unfold, feat_f1_unfold)

        # 5. match fine-level
        self.fine_matching(feat_f0_unfold, feat_f1_unfold, data)

    def load_state_dict(self, state_dict, *args, **kwargs):
        """Load weights after stripping a leading 'model.' or 'matcher.' prefix.

        At most one such prefix is removed from each original key. When a
        stripped key collides with an unprefixed key, the stripped entry wins.
        The renaming is done on a copy, so the caller's mapping is left
        untouched. The copy keeps torch's ``_metadata`` attribute.

        Args:
            state_dict: mapping of parameter/buffer names to tensors.
            *args, **kwargs: forwarded unchanged to
                ``nn.Module.load_state_dict`` (e.g. ``strict``).

        Returns:
            Whatever ``nn.Module.load_state_dict`` returns.
        """
        metadata = getattr(state_dict, '_metadata', None)
        state_dict = state_dict.copy()  # never mutate the caller's dict
        if metadata is not None:
            state_dict._metadata = metadata

        for key in list(state_dict):
            for prefix in ('model.', 'matcher.'):
                if key.startswith(prefix):
                    state_dict[key[len(prefix):]] = state_dict.pop(key)
                    break
        return super().load_state_dict(state_dict, *args, **kwargs)