import argparse
import numpy as np
import imageio
import torch
from tqdm import tqdm
import scipy
import scipy.io
import scipy.misc
from lib.model_test import D2Net
from lib.utils import preprocess_image
from lib.pyramid import process_multiscale
# CUDA
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if use_cuda else "cpu")
# Argument parsing
parser = argparse.ArgumentParser(description='Feature extraction script')
parser.add_argument(
    '--image_list_file', type=str, required=True,
    help='path to a file containing a list of images to process'
)
parser.add_argument(
    '--preprocessing', type=str, default='caffe',
    help='image preprocessing (caffe or torch)'
)
parser.add_argument(
    '--model_file', type=str, default='models/d2_tf.pth',
    help='path to the full model'
)
parser.add_argument(
    '--max_edge', type=int, default=1600,
    help='maximum image size at network input'
)
parser.add_argument(
    '--max_sum_edges', type=int, default=2800,
    help='maximum sum of image sizes at network input'
)
parser.add_argument(
    '--output_extension', type=str, default='.d2-net',
    help='extension for the output'
)
parser.add_argument(
    '--output_type', type=str, default='npz',
    help='output file type (npz or mat)'
)
parser.add_argument(
    '--multiscale', dest='multiscale', action='store_true',
    help='extract multiscale features'
)
parser.set_defaults(multiscale=False)
parser.add_argument(
    '--no-relu', dest='use_relu', action='store_false',
    help='remove ReLU after the dense feature extraction module'
)
parser.set_defaults(use_relu=True)
args = parser.parse_args()
print(args)
# Creating CNN model
model = D2Net(
    model_file=args.model_file,
    use_relu=args.use_relu,
    use_cuda=use_cuda
)
# Process the file
with open(args.image_list_file, 'r') as f:
    lines = f.readlines()

for line in tqdm(lines, total=len(lines)):
    path = line.strip()

    image = imageio.imread(path)
    if len(image.shape) == 2:
        # Grayscale input: replicate the single channel to get 3 channels.
        image = image[:, :, np.newaxis]
        image = np.repeat(image, 3, -1)
    # TODO: switch to PIL.Image due to deprecation of scipy.misc.imresize.
    resized_image = image
    if max(resized_image.shape) > args.max_edge:
        resized_image = scipy.misc.imresize(
            resized_image,
            args.max_edge / max(resized_image.shape)
        ).astype('float')
    if sum(resized_image.shape[: 2]) > args.max_sum_edges:
        resized_image = scipy.misc.imresize(
            resized_image,
            args.max_sum_edges / sum(resized_image.shape[: 2])
        ).astype('float')
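    # scipy.misc.imresize was removed in SciPy 1.3, so the two resize calls
    # above require an older SciPy. A rough PIL-based replacement (a sketch
    # assuming Pillow is available; `pil_resize` is a hypothetical helper,
    # not part of the original script) could look like:
    #
    #   from PIL import Image
    #
    #   def pil_resize(image, scale):
    #       h, w = image.shape[:2]
    #       resized = Image.fromarray(image.astype('uint8')).resize(
    #           (int(round(w * scale)), int(round(h * scale))),
    #           resample=Image.BILINEAR
    #       )
    #       return np.array(resized).astype('float')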
    # Scale factors that map resized-image coordinates back to the original image.
    fact_i = image.shape[0] / resized_image.shape[0]
    fact_j = image.shape[1] / resized_image.shape[1]

    input_image = preprocess_image(
        resized_image,
        preprocessing=args.preprocessing
    )
    with torch.no_grad():
        if args.multiscale:
            keypoints, scores, descriptors = process_multiscale(
                torch.tensor(
                    input_image[np.newaxis, :, :, :].astype(np.float32),
                    device=device
                ),
                model
            )
        else:
            keypoints, scores, descriptors = process_multiscale(
                torch.tensor(
                    input_image[np.newaxis, :, :, :].astype(np.float32),
                    device=device
                ),
                model,
                scales=[1]
            )
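    # process_multiscale is assumed to return an N x 3 keypoint array of
    # (i, j, scale) in resized-image coordinates, an N-vector of detection
    # scores, and an N x D descriptor matrix; the keypoints are mapped back
    # to original-image coordinates below.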
    # Input image coordinates
    keypoints[:, 0] *= fact_i
    keypoints[:, 1] *= fact_j
    # i, j -> u, v
    keypoints = keypoints[:, [1, 0, 2]]
    if args.output_type == 'npz':
        with open(path + args.output_extension, 'wb') as output_file:
            np.savez(
                output_file,
                keypoints=keypoints,
                scores=scores,
                descriptors=descriptors
            )
    elif args.output_type == 'mat':
        with open(path + args.output_extension, 'wb') as output_file:
            scipy.io.savemat(
                output_file,
                {
                    'keypoints': keypoints,
                    'scores': scores,
                    'descriptors': descriptors
                }
            )
    else:
        raise ValueError('Unknown output type.')
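# Example invocation (a sketch; the script and list file names are assumptions):
#
#   python extract_features.py --image_list_file image_list.txt --multiscale
#
# With the default npz output type, each listed image gets a sibling
# `<image><output_extension>` file (e.g. `img.jpg.d2-net`) that can be read
# back with numpy:
#
#   feats = np.load('img.jpg.d2-net')
#   feats['keypoints']    # N x 3: u, v, scale in original image coordinates
#   feats['scores']       # N detection scores
#   feats['descriptors']  # N x D feature descriptors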