""" "LiftFeat: 3D Geometry-Aware Local Feature Matching" """ import numpy as np import os import torch from torch import nn import torch.nn.functional as F import tqdm import math import cv2 import sys sys.path.append('/home/yepeng_liu/code_python/laiwenpeng/LiftFeat') from utils.featurebooster import FeatureBooster from utils.config import featureboost_config # from models.model_dfb import LiftFeatModel # from models.interpolator import InterpolateSparse2d # from third_party.config import featureboost_config """ foundational functions """ def simple_nms(scores, radius): """Perform non maximum suppression on the heatmap using max-pooling. This method does not suppress contiguous points that have the same score. Args: scores: the score heatmap of size `(B, H, W)`. radius: an integer scalar, the radius of the NMS window. """ def max_pool(x): return torch.nn.functional.max_pool2d( x, kernel_size=radius * 2 + 1, stride=1, padding=radius ) zeros = torch.zeros_like(scores) max_mask = scores == max_pool(scores) for _ in range(2): supp_mask = max_pool(max_mask.float()) > 0 supp_scores = torch.where(supp_mask, zeros, scores) new_max_mask = supp_scores == max_pool(supp_scores) max_mask = max_mask | (new_max_mask & (~supp_mask)) return torch.where(max_mask, scores, zeros) def top_k_keypoints(keypoints, scores, k): if k >= len(keypoints): return keypoints, scores scores, indices = torch.topk(scores, k, dim=0, sorted=True) return keypoints[indices], scores def sample_k_keypoints(keypoints, scores, k): if k >= len(keypoints): return keypoints, scores indices = torch.multinomial(scores, k, replacement=False) return keypoints[indices], scores[indices] def soft_argmax_refinement(keypoints, scores, radius: int): width = 2 * radius + 1 sum_ = torch.nn.functional.avg_pool2d( scores[:, None], width, 1, radius, divisor_override=1 ) ar = torch.arange(-radius, radius + 1).to(scores) kernel_x = ar[None].expand(width, -1)[None, None] dx = torch.nn.functional.conv2d(scores[:, None], 
kernel_x, padding=radius) dy = torch.nn.functional.conv2d( scores[:, None], kernel_x.transpose(2, 3), padding=radius ) dydx = torch.stack([dy[:, 0], dx[:, 0]], -1) / sum_[:, 0, :, :, None] refined_keypoints = [] for i, kpts in enumerate(keypoints): delta = dydx[i][tuple(kpts.t())] refined_keypoints.append(kpts.float() + delta) return refined_keypoints # Legacy (broken) sampling of the descriptors def sample_descriptors(keypoints, descriptors, s): b, c, h, w = descriptors.shape keypoints = keypoints - s / 2 + 0.5 keypoints /= torch.tensor( [(w * s - s / 2 - 0.5), (h * s - s / 2 - 0.5)], ).to( keypoints )[None] keypoints = keypoints * 2 - 1 # normalize to (-1, 1) args = {"align_corners": True} if torch.__version__ >= "1.3" else {} descriptors = torch.nn.functional.grid_sample( descriptors, keypoints.view(b, 1, -1, 2), mode="bilinear", **args ) descriptors = torch.nn.functional.normalize( descriptors.reshape(b, c, -1), p=2, dim=1 ) return descriptors # The original keypoint sampling is incorrect. We patch it here but # keep the original one above for legacy. 
def sample_descriptors_fix_sampling(keypoints, descriptors, s: int = 8):
    """Interpolate descriptors at keypoint locations.

    Args:
        keypoints: pixel coordinates viewable as (B, 1, N, 2). The last axis is
            ordered (x, y): x is divided by ``w * s`` and y by ``h * s``.
        descriptors: dense descriptor map of shape (B, C, h, w).
        s: stride between image pixels and descriptor cells.

    Returns:
        L2-normalised descriptors of shape (B, C, N).
    """
    b, c, h, w = descriptors.shape
    keypoints = keypoints / (keypoints.new_tensor([w, h]) * s)
    keypoints = keypoints * 2 - 1  # normalize to (-1, 1)
    descriptors = torch.nn.functional.grid_sample(
        descriptors, keypoints.view(b, 1, -1, 2), mode="bilinear", align_corners=False
    )
    descriptors = torch.nn.functional.normalize(
        descriptors.reshape(b, c, -1), p=2, dim=1
    )
    return descriptors


class UpsampleLayer(nn.Module):
    """2x bilinear upsampling, then conv-BN-LeakyReLU that halves the channels."""

    def __init__(self, in_channels):
        super().__init__()
        # Feature extraction layer: halves the channel count while adding
        # feature-extraction capacity.
        self.conv = nn.Conv2d(in_channels, in_channels // 2, kernel_size=3, stride=1, padding=1)
        # Batch normalisation.
        self.bn = nn.BatchNorm2d(in_channels // 2)
        # LeakyReLU activation.
        self.leaky_relu = nn.LeakyReLU(0.1)

    def forward(self, x):
        x = F.interpolate(x, scale_factor=2, mode="bilinear", align_corners=False)
        x = self.leaky_relu(self.bn(self.conv(x)))
        return x


class KeypointHead(nn.Module):
    """Keypoint logit head: five BaseLayers, then a 3x3 conv and BatchNorm.

    The spatial size is preserved. The output has ``out_channels`` channels.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.layer1 = BaseLayer(in_channels, 32)
        self.layer2 = BaseLayer(32, 32)
        self.layer3 = BaseLayer(32, 64)
        self.layer4 = BaseLayer(64, 64)
        self.layer5 = BaseLayer(64, 128)
        self.conv = nn.Conv2d(128, out_channels, kernel_size=3, stride=1, padding=1)
        # The BatchNorm width must match the conv output. It was previously
        # hard-coded to 65, which only worked for out_channels == 65.
        self.bn = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.layer5(x)
        x = self.bn(self.conv(x))
        return x


class DescriptorHead(nn.Module):
    """Dense descriptor head: four BaseLayers.

    Only the first BaseLayer has a ReLU. The spatial size is preserved.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.layer = nn.Sequential(
            BaseLayer(in_channels, 32),
            BaseLayer(32, 32, activation=False),
            BaseLayer(32, 64, activation=False),
            BaseLayer(64, out_channels, activation=False),
        )

    def forward(self, x):
        x = self.layer(x)
        return x


class HeatmapHead(nn.Module):
    """Two conv-BN-LeakyReLU stages followed by a sigmoid (values in (0, 1))."""

    def __init__(self, in_channels, mid_channels, out_channels):
        super().__init__()
        self.convHa = nn.Conv2d(in_channels, mid_channels, kernel_size=3, stride=1, padding=1)
        self.bnHa = nn.BatchNorm2d(mid_channels)
        self.convHb = nn.Conv2d(mid_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.bnHb = nn.BatchNorm2d(out_channels)
        self.leaky_relu = nn.LeakyReLU(0.1)

    def forward(self, x):
        x = self.leaky_relu(self.bnHa(self.convHa(x)))
        x = self.leaky_relu(self.bnHb(self.convHb(x)))
        x = torch.sigmoid(x)
        return x


class DepthHead(nn.Module):
    """Upsample features 8x and predict a 3-channel, L2-normalised map.

    Each of the three stages concatenates a plain bilinear 2x upsample with a
    learned UpsampleLayer, then fuses them with conv-BN-LeakyReLU.
    """

    def __init__(self, in_channels):
        super().__init__()
        self.upsampleDa = UpsampleLayer(in_channels)
        self.upsampleDb = UpsampleLayer(in_channels // 2)
        self.upsampleDc = UpsampleLayer(in_channels // 4)

        self.convDepa = nn.Conv2d(in_channels // 2 + in_channels, in_channels // 2, kernel_size=3, stride=1, padding=1)
        self.bnDepa = nn.BatchNorm2d(in_channels // 2)
        self.convDepb = nn.Conv2d(in_channels // 4 + in_channels // 2, in_channels // 4, kernel_size=3, stride=1, padding=1)
        self.bnDepb = nn.BatchNorm2d(in_channels // 4)
        self.convDepc = nn.Conv2d(in_channels // 8 + in_channels // 4, 3, kernel_size=3, stride=1, padding=1)
        self.bnDepc = nn.BatchNorm2d(3)

        self.leaky_relu = nn.LeakyReLU(0.1)

    def forward(self, x):
        x0 = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
        x1 = self.upsampleDa(x)
        x1 = torch.cat([x0, x1], dim=1)
        x1 = self.leaky_relu(self.bnDepa(self.convDepa(x1)))

        x1_0 = F.interpolate(x1, scale_factor=2, mode='bilinear', align_corners=False)
        x2 = self.upsampleDb(x1)
        x2 = torch.cat([x1_0, x2], dim=1)
        x2 = self.leaky_relu(self.bnDepb(self.convDepb(x2)))

        x2_0 = F.interpolate(x2, scale_factor=2, mode='bilinear', align_corners=False)
        x3 = self.upsampleDc(x2)
        x3 = torch.cat([x2_0, x3], dim=1)
        x = self.leaky_relu(self.bnDepc(self.convDepc(x3)))

        # Unit-length 3-vector per pixel.
        x = F.normalize(x, p=2, dim=1)
        return x


class BaseLayer(nn.Module):
    """Conv2d + non-affine BatchNorm2d, optionally followed by ReLU."""

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False, activation=True):
        super().__init__()
        if activation:
            self.layer = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=bias),
                nn.BatchNorm2d(out_channels, affine=False),
                nn.ReLU(inplace=True),
            )
        else:
            self.layer = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=bias),
                nn.BatchNorm2d(out_channels, affine=False),
            )

    def forward(self, x):
        return self.layer(x)


class LiftFeatSPModel(nn.Module):
    """Backbone with keypoint, descriptor and 3-channel "depth" heads plus a FeatureBooster."""

    # NOTE(review): this dictionary is not read anywhere in this class.
    default_conf = {
        "has_detector": True,
        "has_descriptor": True,
        "descriptor_dim": 64,
        # Inference
        "sparse_outputs": True,
        "dense_outputs": False,
        "nms_radius": 4,
        "refinement_radius": 0,
        "detection_threshold": 0.005,
        "max_num_keypoints": -1,
        "max_num_keypoints_val": None,
        "force_num_keypoints": False,
        "randomize_keypoints_training": False,
        "remove_borders": 4,
        "legacy_sampling": True,  # True to use the old broken sampling
    }

    def __init__(self, featureboost_config, use_kenc=False, use_normal=True, use_cross=True):
        super().__init__()

        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.descriptor_dim = 64

        self.norm = nn.InstanceNorm2d(1)
        self.relu = nn.ReLU(inplace=True)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        c1, c2, c3, c4, c5 = 24, 24, 64, 64, 128

        self.conv1a = nn.Conv2d(1, c1, kernel_size=3, stride=1, padding=1)
        self.conv1b = nn.Conv2d(c1, c1, kernel_size=3, stride=1, padding=1)
        self.conv2a = nn.Conv2d(c1, c2, kernel_size=3, stride=1, padding=1)
        self.conv2b = nn.Conv2d(c2, c2, kernel_size=3, stride=1, padding=1)
        self.conv3a = nn.Conv2d(c2, c3, kernel_size=3, stride=1, padding=1)
        self.conv3b = nn.Conv2d(c3, c3, kernel_size=3, stride=1, padding=1)
        self.conv4a = nn.Conv2d(c3, c4, kernel_size=3, stride=1, padding=1)
        self.conv4b = nn.Conv2d(c4, c4, kernel_size=3, stride=1, padding=1)
        self.conv5a = nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1)
        self.conv5b = nn.Conv2d(c5, c5, kernel_size=3, stride=1, padding=1)

        self.upsample4 = UpsampleLayer(c4)
        self.upsample5 = UpsampleLayer(c5)
        self.conv_fusion45 = nn.Conv2d(c5 // 2 + c4, c4, kernel_size=3, stride=1, padding=1)
        self.conv_fusion34 = nn.Conv2d(c4 // 2 + c3, c3, kernel_size=3, stride=1, padding=1)

        # detector
        self.keypoint_head = KeypointHead(in_channels=c3, out_channels=65)
        # descriptor
        self.descriptor_head = DescriptorHead(in_channels=c3, out_channels=self.descriptor_dim)
        # # heatmap
        # self.heatmap_head = HeatmapHead(in_channels=c3,mid_channels=c3,out_channels=1)
        # depth
        self.depth_head = DepthHead(c3)

        # NOTE(review): not used by any forward method in this class.
        self.fine_matcher = nn.Sequential(
            nn.Linear(128, 512),
            nn.BatchNorm1d(512, affine=False),
            nn.ReLU(inplace=True),
            nn.Linear(512, 512),
            nn.BatchNorm1d(512, affine=False),
            nn.ReLU(inplace=True),
            nn.Linear(512, 512),
            nn.BatchNorm1d(512, affine=False),
            nn.ReLU(inplace=True),
            nn.Linear(512, 512),
            nn.BatchNorm1d(512, affine=False),
            nn.ReLU(inplace=True),
            nn.Linear(512, 64),
        )

        # feature_booster
        self.feature_boost = FeatureBooster(featureboost_config, use_kenc=use_kenc, use_cross=use_cross, use_normal=use_normal)

    def feature_extract(self, x):
        """Run the five-stage conv/max-pool backbone.

        Returns the stage-3, stage-4 and stage-5 features at 1/8, 1/16 and 1/32
        of the input resolution, with 64, 64 and 128 channels respectively.
        """
        x1 = self.relu(self.conv1a(x))
        x1 = self.relu(self.conv1b(x1))
        x1 = self.pool(x1)
        x2 = self.relu(self.conv2a(x1))
        x2 = self.relu(self.conv2b(x2))
        x2 = self.pool(x2)
        x3 = self.relu(self.conv3a(x2))
        x3 = self.relu(self.conv3b(x3))
        x3 = self.pool(x3)
        x4 = self.relu(self.conv4a(x3))
        x4 = self.relu(self.conv4b(x4))
        x4 = self.pool(x4)
        x5 = self.relu(self.conv5a(x4))
        x5 = self.relu(self.conv5b(x5))
        x5 = self.pool(x5)
        return x3, x4, x5

    def fuse_multi_features(self, x3, x4, x5):
        """Fuse the three backbone stages top-down into one 1/8-resolution map."""
        # upsample x5 feature
        x5 = self.upsample5(x5)
        x4 = torch.cat([x4, x5], dim=1)
        x4 = self.conv_fusion45(x4)

        # upsample x4 feature
        x4 = self.upsample4(x4)
        x3 = torch.cat([x3, x4], dim=1)
        x = self.conv_fusion34(x3)
        return x

    def _unfold2d(self, x, ws=2):
        """Unfold (B, C, H, W) into non-overlapping ws x ws windows.

        The windows are concatenated along the channel axis, giving
        (B, C * ws**2, H // ws, W // ws).
        """
        B, C, H, W = x.shape
        x = x.unfold(2, ws, ws).unfold(3, ws, ws).reshape(B, C, H // ws, W // ws, ws ** 2)
        return x.permute(0, 1, 4, 2, 3).reshape(B, -1, H // ws, W // ws)

    def forward1(self, x):
        """Compute the dense network outputs.

        input:
            x -> torch.Tensor(B, C, H, W) grayscale or rgb images. The channels
                are averaged to a single grayscale channel. H and W appear to
                need to be multiples of 32 so the fusion concatenations line
                up (the demo uses 800x608).
        return:
            des_map -> torch.Tensor(B, 64, H/8, W/8) dense local descriptors
            keypoint_map -> torch.Tensor(B, 65, H/8, W/8) keypoint logit map
            d_feats -> torch.Tensor(B, 3, H, W) per-pixel unit-length 3-vectors
                from the depth head
        """
        # Input normalisation is kept out of the autograd graph.
        with torch.no_grad():
            x = x.mean(dim=1, keepdim=True)
            x = self.norm(x)

        x3, x4, x5 = self.feature_extract(x)

        # features fusion
        x = self.fuse_multi_features(x3, x4, x5)

        # keypoint
        keypoint_map = self.keypoint_head(x)
        # descriptor
        des_map = self.descriptor_head(x)
        # depth
        d_feats = self.depth_head(x)

        return des_map, keypoint_map, d_feats

    def forward2(self, descs, kpts, normals):
        """Refine descriptors with the FeatureBooster.

        ``normals`` is unfolded into 8x8 windows so it matches the 1/8 grid of
        ``descs`` and ``kpts``. All three are then flattened to
        (H/8 * W/8, channels) rows.

        NOTE(review): ``squeeze(0)`` followed by a 3-axis ``permute`` only
        works for batch size 1.
        """
        normals_feat = self._unfold2d(normals, ws=8)
        normals_v = normals_feat.squeeze(0).permute(1, 2, 0).reshape(-1, normals_feat.shape[1])
        descs_v = descs.squeeze(0).permute(1, 2, 0).reshape(-1, descs.shape[1])
        kpts_v = kpts.squeeze(0).permute(1, 2, 0).reshape(-1, kpts.shape[1])
        descs_refine = self.feature_boost(descs_v, kpts_v, normals_v)
        return descs_refine

    def forward(self, x):
        """Run forward1 then forward2. Return (refined descs, des_map, keypoint_map, d_feats)."""
        M1, K1, D1 = self.forward1(x)
        descs_refine = self.forward2(M1, K1, D1)
        return descs_refine, M1, K1, D1


if __name__ == "__main__":
    img_path = os.path.join(os.path.dirname(__file__), '../assert/ref.jpg')
    img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
    # cv2.imread returns None (rather than raising) when the file is unreadable.
    if img is None:
        raise FileNotFoundError(f"could not read image: {img_path}")
    img = cv2.resize(img, (800, 608))
    # The leftover interactive `pdb.set_trace()` breakpoint was removed here.
    img = torch.from_numpy(img).unsqueeze(0).unsqueeze(0).float() / 255.0
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    img = img.to(device)

    liftfeat_sp = LiftFeatSPModel(featureboost_config).to(device)
    des_map, keypoint_map, d_feats = liftfeat_sp.forward1(img)
    des_fine = liftfeat_sp.forward2(des_map, keypoint_map, d_feats)
    print(des_map.shape)
    print(des_fine.shape)