Realcat
update: major change
499e141
raw
history blame
10.7 kB
import argparse
import numpy as np
from PIL import Image
import torch
import math
from tqdm import tqdm
from os import path
# Kapture is a pivot file format, based on text and binary files, used to describe SfM (Structure From Motion) and more generally sensor-acquired data
# it can be installed with
# pip install kapture
# for more information check out https://github.com/naver/kapture
import kapture
from kapture.io.records import get_image_fullpath
from kapture.io.csv import kapture_from_dir, get_all_tar_handlers
from kapture.io.csv import get_feature_csv_fullpath, keypoints_to_file, descriptors_to_file
from kapture.io.features import get_keypoints_fullpath, keypoints_check_dir, image_keypoints_to_file
from kapture.io.features import get_descriptors_fullpath, descriptors_check_dir, image_descriptors_to_file
from lib.model_test import D2Net
from lib.utils import preprocess_image
from lib.pyramid import process_multiscale
# import imageio
# Compute device: the first CUDA GPU when torch can see one, otherwise the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0") if use_cuda else torch.device("cpu")
# Command-line interface. The parsed namespace is echoed once so that runs are
# reproducible from their logs.
parser = argparse.ArgumentParser(description='Feature extraction script')
parser.add_argument('--kapture-root', type=str, required=True,
                    help='path to kapture root directory')
parser.add_argument('--preprocessing', type=str, default='caffe',
                    help='image preprocessing (caffe or torch)')
parser.add_argument('--model_file', type=str, default='models/d2_tf.pth',
                    help='path to the full model')
# When left as None, both feature type names are later derived from the model filename.
parser.add_argument('--keypoints-type', type=str, default=None,
                    help='keypoint type_name, default is filename of model')
parser.add_argument('--descriptors-type', type=str, default=None,
                    help='descriptors type_name, default is filename of model')
parser.add_argument('--max_edge', type=int, default=1600,
                    help='maximum image size at network input')
parser.add_argument('--max_sum_edges', type=int, default=2800,
                    help='maximum sum of image sizes at network input')
parser.add_argument('--multiscale', dest='multiscale', action='store_true', default=False,
                    help='extract multiscale features')
parser.add_argument('--no-relu', dest='use_relu', action='store_false', default=True,
                    help='remove ReLU after the dense feature extraction module')
# +inf (the default) means "keep every detected keypoint".
parser.add_argument("--max-keypoints", type=int, default=float("+inf"),
                    help='max number of keypoints save to disk')
args = parser.parse_args()
print(args)
# Extract D2-Net keypoints and descriptors for every camera image of the kapture
# at --kapture-root. Images that already have both features of the requested
# types are skipped. Keypoints/descriptors tars are opened in append mode.


def _resized_size(width, height, max_edge, max_sum_edges):
    """Return the (width, height) the image is shrunk to before entering the network.

    First the longest edge is capped at ``max_edge``. Then, if the sum of both
    edges still exceeds ``max_sum_edges``, both edges are scaled down again.
    Each step multiplies both edges by the same factor and floors the result,
    so the aspect ratio is only approximately preserved.
    """
    if max(width, height) > max_edge:
        scale_multiplier = max_edge / max(width, height)
        width = math.floor(width * scale_multiplier)
        height = math.floor(height * scale_multiplier)
    if width + height > max_sum_edges:
        scale_multiplier = max_sum_edges / (width + height)
        width = math.floor(width * scale_multiplier)
        height = math.floor(height * scale_multiplier)
    return width, height


def _top_score_indices(scores, max_keypoints):
    """Return the indices of the ``max_keypoints`` highest scores, in ascending score order.

    The count is clamped to ``[0, len(scores)]``. A plain ``argsort()[-k:]``
    slice would keep *everything* for k == 0 and an arbitrary subset for
    negative k.
    """
    order = scores.argsort()
    n_keep = max(0, min(len(order), max_keypoints))
    return order[len(order) - n_keep:]


with get_all_tar_handlers(args.kapture_root,
                          mode={kapture.Keypoints: 'a',
                                kapture.Descriptors: 'a',
                                kapture.GlobalFeatures: 'r',
                                kapture.Matches: 'r'}) as tar_handlers:
    kdata = kapture_from_dir(args.kapture_root,
                             skip_list=[kapture.GlobalFeatures,
                                        kapture.Matches,
                                        kapture.Points3d,
                                        kapture.Observations],
                             tar_handlers=tar_handlers)
    if kdata.keypoints is None:
        kdata.keypoints = {}
    if kdata.descriptors is None:
        kdata.descriptors = {}

    # Explicit check rather than `assert`, which `python -O` strips.
    if kdata.records_camera is None:
        raise ValueError(f'no camera records found in kapture at {args.kapture_root}')
    image_list = [filename for _, _, filename in kapture.flatten(kdata.records_camera)]

    # Default both feature type names to the model filename without extension.
    default_type_name = path.splitext(path.basename(args.model_file))[0]
    if args.keypoints_type is None:
        args.keypoints_type = default_type_name
        print(f'keypoints_type set to {args.keypoints_type}')
    if args.descriptors_type is None:
        args.descriptors_type = default_type_name
        print(f'descriptors_type set to {args.descriptors_type}')

    # Resume support: skip images that already have both keypoints and descriptors.
    if args.keypoints_type in kdata.keypoints and args.descriptors_type in kdata.descriptors:
        image_list = [name
                      for name in image_list
                      if name not in kdata.keypoints[args.keypoints_type] or
                      name not in kdata.descriptors[args.descriptors_type]]

    if len(image_list) == 0:
        print('All features were already extracted')
        raise SystemExit(0)
    else:
        print(f'Extracting d2net features for {len(image_list)} images')

    # Creating CNN model
    model = D2Net(
        model_file=args.model_file,
        use_relu=args.use_relu,
        use_cuda=use_cuda
    )

    # dtype of an already-declared feature type, or None when it does not exist yet.
    keypoints_dtype = (kdata.keypoints[args.keypoints_type].dtype
                       if args.keypoints_type in kdata.keypoints else None)
    descriptors_dtype = (kdata.descriptors[args.descriptors_type].dtype
                         if args.descriptors_type in kdata.descriptors else None)

    # Single-scale extraction pins scales to [1]. Multiscale uses process_multiscale's default scales.
    scale_kwargs = {} if args.multiscale else {'scales': [1]}

    # Process the files
    for image_name in tqdm(image_list, total=len(image_list)):
        img_path = get_image_fullpath(args.kapture_root, image_name)
        # convert() returns a new image, so the source file can be closed right away.
        with Image.open(img_path) as pil_image:
            image = pil_image.convert('RGB')
        width, height = image.size  # PIL reports (width, height)

        resized_width, resized_height = _resized_size(width, height,
                                                      args.max_edge, args.max_sum_edges)
        if (resized_width, resized_height) != (width, height):
            resized_image = image.resize((resized_width, resized_height))
        else:
            resized_image = image

        # Column 0 of the keypoints is i (the row) and column 1 is j (the column); see the
        # "i, j -> u, v" swap below. So i scales with the height ratio and j with the width
        # ratio. The two ratios differ because of the flooring in _resized_size.
        fact_i = height / resized_height
        fact_j = width / resized_width

        input_image = preprocess_image(
            np.array(resized_image).astype('float'),
            preprocessing=args.preprocessing
        )
        with torch.no_grad():
            keypoints, scores, descriptors = process_multiscale(
                torch.tensor(
                    input_image[np.newaxis, :, :, :].astype(np.float32),
                    device=device
                ),
                model,
                **scale_kwargs
            )

        # Input image coordinates
        keypoints[:, 0] *= fact_i
        keypoints[:, 1] *= fact_j
        # i, j -> u, v (the third column, presumably the scale, is kept as-is)
        keypoints = keypoints[:, [1, 0, 2]]

        if args.max_keypoints != float("+inf"):
            # keep the highest-scoring keypoints, still ordered by ascending score
            idx_keep = _top_score_indices(scores, args.max_keypoints)
            keypoints = keypoints[idx_keep]
            descriptors = descriptors[idx_keep]

        if keypoints_dtype is None or descriptors_dtype is None:
            # First image for a new feature type: declare it from the extracted arrays
            # and write its config files.
            # NOTE(review): if only one of the two types already existed, its previous
            # declaration (and registered image set) is replaced here.
            keypoints_dtype = keypoints.dtype
            descriptors_dtype = descriptors.dtype
            keypoints_dsize = keypoints.shape[1]
            descriptors_dsize = descriptors.shape[1]
            kdata.keypoints[args.keypoints_type] = kapture.Keypoints('d2net', keypoints_dtype,
                                                                     keypoints_dsize)
            kdata.descriptors[args.descriptors_type] = kapture.Descriptors('d2net', descriptors_dtype,
                                                                           descriptors_dsize,
                                                                           args.keypoints_type, 'L2')
            keypoints_config_absolute_path = get_feature_csv_fullpath(kapture.Keypoints,
                                                                      args.keypoints_type,
                                                                      args.kapture_root)
            descriptors_config_absolute_path = get_feature_csv_fullpath(kapture.Descriptors,
                                                                        args.descriptors_type,
                                                                        args.kapture_root)
            keypoints_to_file(keypoints_config_absolute_path, kdata.keypoints[args.keypoints_type])
            descriptors_to_file(descriptors_config_absolute_path, kdata.descriptors[args.descriptors_type])
        else:
            # Appending to existing feature types: the new arrays must match their declaration.
            # Explicit checks rather than `assert`, which `python -O` strips.
            keypoints_config = kdata.keypoints[args.keypoints_type]
            descriptors_config = kdata.descriptors[args.descriptors_type]
            checks = {
                'keypoints dtype': keypoints_config.dtype == keypoints.dtype,
                'descriptors dtype': descriptors_config.dtype == descriptors.dtype,
                'keypoints dsize': keypoints_config.dsize == keypoints.shape[1],
                'descriptors dsize': descriptors_config.dsize == descriptors.shape[1],
                'descriptors keypoints_type': descriptors_config.keypoints_type == args.keypoints_type,
                'descriptors metric_type': descriptors_config.metric_type == 'L2',
            }
            failed = [name for name, ok in checks.items() if not ok]
            if failed:
                raise ValueError(f'features extracted for {image_name} do not match the existing '
                                 f'kapture configuration: {", ".join(failed)}')

        keypoints_fullpath = get_keypoints_fullpath(args.keypoints_type, args.kapture_root,
                                                    image_name, tar_handlers)
        print(f"Saving {keypoints.shape[0]} keypoints to {keypoints_fullpath}")
        image_keypoints_to_file(keypoints_fullpath, keypoints)
        kdata.keypoints[args.keypoints_type].add(image_name)

        descriptors_fullpath = get_descriptors_fullpath(args.descriptors_type, args.kapture_root,
                                                        image_name, tar_handlers)
        print(f"Saving {descriptors.shape[0]} descriptors to {descriptors_fullpath}")
        image_descriptors_to_file(descriptors_fullpath, descriptors)
        kdata.descriptors[args.descriptors_type].add(image_name)

    # Sanity check while the tar handlers are still open: every registered image should have files.
    if not keypoints_check_dir(kdata.keypoints[args.keypoints_type], args.keypoints_type,
                               args.kapture_root, tar_handlers) or \
            not descriptors_check_dir(kdata.descriptors[args.descriptors_type], args.descriptors_type,
                                      args.kapture_root, tar_handlers):
        print('local feature extraction ended successfully but not all files were saved')