Spaces:
Running
Running
# -*- coding: utf-8 -*- | |
# @Author : xuelun | |
import os | |
import torch | |
import warnings | |
import numpy as np | |
from tqdm import tqdm | |
from os.path import join | |
from pathlib import Path | |
from argparse import ArgumentParser | |
from hloc import pairs_from_exhaustive | |
from hloc import extract_features, match_features, match_dense, reconstruction | |
from hloc.utils import segment | |
from hloc.utils.io import read_image | |
from hloc.match_dense import ImagePairDataset | |
from networks.lightglue.superpoint import SuperPoint | |
from networks.lightglue.models.matchers.lightglue import LightGlue | |
from networks.mit_semseg.models import ModelBuilder, SegmentationModule | |
def segmentation(images, segment_root, matcher_conf):
    """Segment every file found in ``images`` and cache the masks as ``.npy``.

    For each directory entry a mask is computed with ``segment`` and stored in
    ``segment_root`` under the entry's name with its last four characters
    replaced by ``.npy``. Entries whose mask file already exists are skipped.

    Args:
        images: Directory holding the input images. It is joined with each
            file name via ``images / name``, so a ``pathlib.Path`` is expected.
        segment_root: Directory that receives the ``.npy`` mask files.
        matcher_conf: Matcher configuration; its ``"preprocessing"`` entry is
            handed to ``ImagePairDataset`` and the resulting
            ``conf.grayscale`` flag is forwarded to ``read_image``.
    """
    # Use the GPU when one is available, otherwise fall back to the CPU.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Assemble the segmentation network from its pretrained encoder/decoder.
    encoder = ModelBuilder.build_encoder(
        arch='resnet50dilated',
        fc_dim=2048,
        weights='weights/encoder_epoch_20.pth',
    )
    decoder = ModelBuilder.build_decoder(
        arch='ppm_deepsup',
        fc_dim=2048,
        num_class=150,
        weights='weights/decoder_epoch_20.pth',
        use_softmax=True,
    )
    criterion = torch.nn.NLLLoss(ignore_index=-1)
    segmenter = SegmentationModule(encoder, decoder, criterion).to(device).eval()

    # The dataset object is only used for its resolved preprocessing config.
    reader = ImagePairDataset(None, matcher_conf["preprocessing"], None)

    # Inference only: no gradients are needed while producing the masks.
    with torch.no_grad():
        for name in tqdm(sorted(os.listdir(images))):
            # NOTE(review): name[:-4] presumably strips a dot plus a
            # three-letter extension such as ".jpg" -- confirm for other
            # extensions (e.g. ".jpeg").
            mask_path = join(segment_root, '{}.npy'.format(name[:-4]))
            if os.path.exists(mask_path):
                continue  # mask already cached by an earlier run
            rgb = read_image(images / name, reader.conf.grayscale)
            np.save(mask_path, segment(rgb, 1920, device, segmenter))
# Matcher variants that main() knows how to configure.
_SUPPORTED_VERSIONS = ('gim_dkm', 'gim_lightglue')


def _submodule_state_dict(state_dict, keep_prefix, drop_prefix):
    """Return a new state dict prepared for one sub-network of the checkpoint.

    Keys starting with ``drop_prefix`` are discarded, keys starting with
    ``keep_prefix`` have that prefix removed, and every other key is copied
    unchanged. The input mapping is not modified.
    """
    selected = {}
    for key, value in state_dict.items():
        if key.startswith(drop_prefix):
            continue
        if key.startswith(keep_prefix):
            key = key[len(keep_prefix):]
        selected[key] = value
    return selected


def _load_lightglue_models(checkpoint_path):
    """Build the SuperPoint detector and LightGlue matcher from one checkpoint.

    The checkpoint is read from disk once; the ``superpoint.*`` weights go to
    the detector and the ``model.*`` weights go to the matcher.

    Returns:
        A ``(detector, matcher)`` tuple with weights loaded.
    """
    # NOTE(review): torch.load unpickles the file -- only load trusted
    # checkpoints.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    if 'state_dict' in checkpoint:
        checkpoint = checkpoint['state_dict']

    detector = SuperPoint({
        'max_num_keypoints': 2048,
        'force_num_keypoints': True,
        'detection_threshold': 0.0,
        'nms_radius': 3,
        'trainable': False,
    })
    detector.load_state_dict(
        _submodule_state_dict(checkpoint, 'superpoint.', 'model.'))

    matcher = LightGlue({
        'filter_threshold': 0.1,
        'flash': False,
        'checkpointed': True,
    })
    matcher.load_state_dict(
        _submodule_state_dict(checkpoint, 'model.', 'superpoint.'))

    return detector, matcher


def main(scene_name, version):
    """Run the full reconstruction pipeline for one scene.

    Steps: build exhaustive image pairs, segment all images, extract and
    match features with the selected matcher, then run sparse reconstruction.

    Inputs are read from ``inputs/<scene_name>/images``; results are written
    below ``outputs/<scene_name>/<version>`` and segmentation masks below
    ``outputs/<scene_name>/segment``. The ``GIMRECONSTRUCTION`` environment
    variable is set to the output directory.

    Args:
        scene_name: Name of the scene directory under ``inputs``.
        version: Matcher variant, either ``'gim_dkm'`` or ``'gim_lightglue'``.

    Raises:
        ValueError: If ``version`` is not one of the supported variants.
    """
    # Fail fast, before creating directories or loading any model, instead of
    # crashing later on an unset matcher configuration.
    if version not in _SUPPORTED_VERSIONS:
        raise ValueError(
            'Unsupported version {!r}; expected one of {}'.format(
                version, ', '.join(_SUPPORTED_VERSIONS)))

    # Setup
    images = Path('inputs') / scene_name / 'images'
    outputs = Path('outputs') / scene_name / version
    outputs.mkdir(parents=True, exist_ok=True)
    # Published through the environment; presumably read by other modules of
    # the pipeline (not visible in this file).
    os.environ['GIMRECONSTRUCTION'] = str(outputs)

    segment_root = Path('outputs') / scene_name / 'segment'
    segment_root.mkdir(parents=True, exist_ok=True)

    sfm_dir = outputs / 'sparse'
    image_pairs = outputs / 'pairs-near.txt'

    if version == 'gim_dkm':
        feature_conf = None
        matcher_conf = match_dense.confs[version]
    else:  # 'gim_lightglue'
        feature_conf = extract_features.confs['gim_superpoint']
        matcher_conf = match_features.confs[version]

    # Find image pairs via exhaustive pairing; the result is written to
    # image_pairs, the return value is not needed here.
    pairs_from_exhaustive.main(image_pairs, image_list=images)

    segmentation(images, segment_root, matcher_conf)

    # Extract and match local features
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=UserWarning)
        if version == 'gim_dkm':
            feature_path, match_path = match_dense.main(
                matcher_conf, image_pairs, images, outputs)
        else:  # 'gim_lightglue'
            detector, matcher = _load_lightglue_models(
                join('weights', 'gim_lightglue_100h.ckpt'))
            feature_path = extract_features.main(
                feature_conf, images, outputs, model=detector)
            match_path = match_features.main(
                matcher_conf, image_pairs, feature_conf['output'], outputs,
                model=matcher)

    # sparse reconstruction
    reconstruction.main(sfm_dir, images, image_pairs, feature_path, match_path)
if __name__ == '__main__':
    # Command-line entry point: reconstruct one scene with the chosen matcher.
    parser = ArgumentParser()
    # The scene name is mandatory; without it main() would fail while joining
    # a path with None.
    parser.add_argument('--scene_name', type=str, required=True)
    # A list (rather than a set) keeps the order of the choices stable in the
    # generated help and error messages.
    parser.add_argument('--version', type=str,
                        choices=['gim_dkm', 'gim_lightglue'],
                        default='gim_dkm')
    args = parser.parse_args()
    main(args.scene_name, args.version)