Spaces:
Running
on
Zero
Running
on
Zero
| # -*- coding: utf-8 -*- | |
| # @Author : xuelun | |
| import os | |
| import torch | |
| import warnings | |
| import numpy as np | |
| from tqdm import tqdm | |
| from os.path import join | |
| from pathlib import Path | |
| from argparse import ArgumentParser | |
| from hloc import pairs_from_exhaustive | |
| from hloc import extract_features, match_features, match_dense, reconstruction | |
| from hloc.utils import segment | |
| from hloc.utils.io import read_image | |
| from hloc.match_dense import ImagePairDataset | |
| from networks.lightglue.superpoint import SuperPoint | |
| from networks.lightglue.models.matchers.lightglue import LightGlue | |
| from networks.mit_semseg.models import ModelBuilder, SegmentationModule | |
def segmentation(images, segment_root, matcher_conf):
    """Compute and cache a segmentation mask for every file in ``images``.

    For each entry of ``images`` a mask is stored as ``<name>.npy`` inside
    ``segment_root``, where ``<name>`` is the file name with its last four
    characters removed. Entries whose mask file already exists are skipped,
    and the segmentation networks are only built when at least one mask is
    missing.

    Args:
        images: ``pathlib.Path`` of the directory holding the input images
            (it is combined with file names via the ``/`` operator).
        segment_root: directory that receives the ``.npy`` mask files.
        matcher_conf: matcher configuration; its ``"preprocessing"`` entry is
            forwarded to ``ImagePairDataset`` to decide how images are read.

    Returns:
        None. The masks are written to disk as a side effect.
    """
    # Work out which masks are missing *before* building the networks, so a
    # fully cached run returns without loading any weights.
    # NOTE: the `[:-4]` naming is kept unchanged on purpose; presumably other
    # pipeline stages look masks up by the same name -- confirm before
    # changing it.
    pending = {}
    for img in sorted(os.listdir(images)):
        segment_path = join(segment_root, '{}.npy'.format(img[:-4]))
        if not os.path.exists(segment_path):
            # Two files sharing a stem map to one mask: the first one wins.
            pending.setdefault(segment_path, img)
    if not pending:
        return

    # initial device
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # initial segmentation model
    net_encoder = ModelBuilder.build_encoder(
        arch='resnet50dilated',
        fc_dim=2048,
        weights='weights/encoder_epoch_20.pth')
    net_decoder = ModelBuilder.build_decoder(
        arch='ppm_deepsup',
        fc_dim=2048,
        num_class=150,
        weights='weights/decoder_epoch_20.pth',
        use_softmax=True)
    crit = torch.nn.NLLLoss(ignore_index=-1)
    segmentation_module = SegmentationModule(net_encoder, net_decoder, crit)
    segmentation_module = segmentation_module.to(device).eval()

    # initial data reader (only its configuration is consulted below)
    dataset = ImagePairDataset(None, matcher_conf["preprocessing"], None)

    # Segment images
    with torch.no_grad():
        for segment_path, img in tqdm(pending.items()):
            rgb = read_image(images / img, dataset.conf.grayscale)
            # 1920 is passed through to `segment` unchanged; presumably a
            # size limit for the network input -- TODO confirm.
            mask = segment(rgb, 1920, device, segmentation_module)
            np.save(segment_path, mask)
def _select_state_dict(state_dict, keep_prefix, drop_prefix):
    """Return the weights for one sub-model of a combined checkpoint.

    Keys starting with ``drop_prefix`` are discarded, keys starting with
    ``keep_prefix`` have that prefix stripped, and all other keys are kept
    unchanged. The input mapping is not modified.
    """
    selected = {
        key: value for key, value in state_dict.items()
        if not key.startswith((keep_prefix, drop_prefix))
    }
    for key, value in state_dict.items():
        if key.startswith(keep_prefix):
            selected[key[len(keep_prefix):]] = value
    return selected


def main(scene_name, version):
    """Run the sparse reconstruction pipeline for one scene.

    Images are read from ``inputs/<scene_name>/images``; results are written
    below ``outputs/<scene_name>/<version>`` and segmentation masks below
    ``outputs/<scene_name>/segment``. The environment variable
    ``GIMRECONSTRUCTION`` is set to the output directory.

    Args:
        scene_name: name of the scene folder below ``inputs``.
        version: matcher to use, ``'gim_dkm'`` or ``'gim_lightglue'``.

    Raises:
        ValueError: if ``version`` is not one of the supported matchers.
    """
    # Fail fast: otherwise an unknown version only surfaces later as an
    # opaque error, after output directories have already been created.
    supported_versions = ('gim_dkm', 'gim_lightglue')
    if version not in supported_versions:
        raise ValueError('Unsupported version {!r}; expected one of: {}'.format(
            version, ', '.join(supported_versions)))

    # Setup
    images = Path('inputs') / scene_name / 'images'

    outputs = Path('outputs') / scene_name / version
    outputs.mkdir(parents=True, exist_ok=True)
    os.environ['GIMRECONSTRUCTION'] = str(outputs)

    segment_root = Path('outputs') / scene_name / 'segment'
    segment_root.mkdir(parents=True, exist_ok=True)

    sfm_dir = outputs / 'sparse'
    image_pairs = outputs / 'pairs-near.txt'

    if version == 'gim_dkm':
        feature_conf = None
        matcher_conf = match_dense.confs[version]
    else:  # 'gim_lightglue'
        feature_conf = extract_features.confs['gim_superpoint']
        matcher_conf = match_features.confs[version]

    # Find image pairs via pair-wise image
    pairs_from_exhaustive.main(image_pairs, image_list=images)

    segmentation(images, segment_root, matcher_conf)

    # Extract and match local features
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=UserWarning)
        if version == 'gim_dkm':
            feature_path, match_path = match_dense.main(matcher_conf, image_pairs,
                                                        images, outputs)
        else:  # 'gim_lightglue'
            # One checkpoint file carries both networks: detector weights
            # under 'superpoint.' and matcher weights under 'model.'.
            # Read it once and split it instead of loading it twice.
            checkpoints_path = join('weights', 'gim_lightglue_100h.ckpt')
            checkpoint = torch.load(checkpoints_path, map_location='cpu')
            if 'state_dict' in checkpoint:
                checkpoint = checkpoint['state_dict']

            detector = SuperPoint({
                'max_num_keypoints': 2048,
                'force_num_keypoints': True,
                'detection_threshold': 0.0,
                'nms_radius': 3,
                'trainable': False,
            })
            detector.load_state_dict(
                _select_state_dict(checkpoint, 'superpoint.', 'model.'))

            model = LightGlue({
                'filter_threshold': 0.1,
                'flash': False,
                'checkpointed': True,
            })
            model.load_state_dict(
                _select_state_dict(checkpoint, 'model.', 'superpoint.'))

            feature_path = extract_features.main(feature_conf, images, outputs,
                                                 model=detector)
            match_path = match_features.main(matcher_conf, image_pairs,
                                             feature_conf['output'], outputs,
                                             model=model)

    # sparse reconstruction
    reconstruction.main(sfm_dir, images, image_pairs, feature_path, match_path)
if __name__ == '__main__':
    # Command-line entry point: parse the scene and matcher, then run the
    # reconstruction pipeline.
    parser = ArgumentParser()
    # Required: the scene name is used to build the input/output paths, so a
    # missing value would otherwise only fail later with an unclear error.
    parser.add_argument('--scene_name', type=str, required=True,
                        help='scene folder name below inputs/')
    # A tuple (not a set) keeps the order shown in help/error text stable.
    parser.add_argument('--version', type=str,
                        choices=('gim_dkm', 'gim_lightglue'),
                        default='gim_dkm',
                        help='matcher used for feature matching')
    args = parser.parse_args()
    main(args.scene_name, args.version)