Realcat
update: major change
499e141
raw
history blame
4.73 kB
import torch
import argparse
from lib.models.builder import build_model
from lib.datasets.utils import correct_intrinsic_scale
from lib.models.MicKey.modules.utils.training_utils import colorize, generate_heat_map
from config.default import cfg
import numpy as np
from pathlib import Path
import cv2
def prepare_score_map(scs, img, temperature=0.5):
    """Build a heat-map image of the scores, ready to be written with OpenCV.

    Args:
        scs: score data forwarded unchanged to ``generate_heat_map``.
        img: image forwarded unchanged to ``generate_heat_map``.
        temperature (float): forwarded unchanged to ``generate_heat_map``.
    Returns:
        NumPy array holding the heat map with its first axis moved last
        and every value multiplied by 255.
    """
    heat = generate_heat_map(scs, img, temperature)
    # Move the leading axis to the end (presumably channels-first ->
    # channels-last, as image writers expect -- TODO confirm).
    channels_last = heat.permute(1, 2, 0)
    return 255 * channels_last.numpy()
def colorize_depth(value, vmin=None, vmax=None, cmap='magma_r', invalid_val=-99, invalid_mask=None, background_color=(0, 0, 0, 255), gamma_corrected=False, value_transform=None):
    """Colour-map a depth map and enlarge it by a fixed factor of 14.

    All arguments are handed positionally, in the same order, to
    ``colorize``; see that helper for their meaning.

    Returns:
        np.ndarray: uint8 image after a ``cv2.COLOR_BGR2RGBA`` conversion,
        14x larger than the ``colorize`` output along both spatial axes.
    """
    colored = colorize(value, vmin, vmax, cmap, invalid_val, invalid_mask,
                       background_color, gamma_corrected, value_transform)

    # Remember the size of the colourised map before any conversion.
    rows, cols = colored.shape[0], colored.shape[1]

    pixels = cv2.cvtColor(np.asarray(colored, np.uint8), cv2.COLOR_BGR2RGBA)

    # NOTE(review): 14 presumably matches the network's feature-map stride,
    # so the output lines up with the input image -- TODO confirm.
    factor = 14
    target_size = (cols * factor, rows * factor)  # cv2.resize takes (w, h)
    return cv2.resize(pixels, target_size, interpolation=cv2.INTER_LINEAR)
def read_color_image(path, resize=(540, 720)):
    """Load an image from disk as a normalized RGB tensor with a batch axis.

    Args:
        path: location of the image file (converted with ``str`` for OpenCV).
        resize (tuple): align image to depthmap, in (w, h). A falsy value
            keeps the resolution of the file.
    Returns:
        image (torch.tensor): (1, 3, h, w), float, scaled by 1 / 255.
    Raises:
        FileNotFoundError: if ``path`` is not an existing file.
        ValueError: if the file exists but OpenCV cannot decode it.
    """
    # Fail early with a clear message: cv2.imread does not raise on a bad
    # path, it returns None and the failure would surface later in cvtColor.
    if not Path(path).is_file():
        raise FileNotFoundError('Image file not found: {}'.format(path))

    # read and resize image
    image = cv2.imread(str(path), cv2.IMREAD_COLOR)
    if image is None:
        raise ValueError('Could not decode image file: {}'.format(path))
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    if resize:
        image = cv2.resize(image, resize)

    # (h, w, 3) -> (3, h, w) and normalized
    image = torch.from_numpy(image).float().permute(2, 0, 1) / 255

    # add the leading batch axis: (3, h, w) -> (1, 3, h, w)
    return image.unsqueeze(0)
def read_intrinsics(path_intrinsics, resize=None):
    """Parse a text file of per-image camera intrinsics into 3x3 matrices.

    Every data line holds seven whitespace-separated fields,
    ``name fx fy cx cy W H``. Lines containing ``#`` and blank lines are
    skipped.

    Args:
        path_intrinsics: path to the intrinsics text file.
        resize (tuple): optional target size in (w, h). When given, each
            matrix is passed through ``correct_intrinsic_scale`` with the
            factors ``w / W`` and ``h / H``.
    Returns:
        dict: image name -> (3, 3) float32 ``np.ndarray``.
    Raises:
        ValueError: if a data line does not hold exactly seven fields or
            one of its numeric fields cannot be parsed as a float.
    """
    Ks = {}
    with Path(path_intrinsics).open('r') as f:
        # Iterate the file object directly instead of loading all lines.
        for line_no, line in enumerate(f, start=1):
            if '#' in line:
                continue

            # split() without an argument tolerates tabs and repeated
            # spaces, and returns an empty list for blank lines.
            fields = line.split()
            if not fields:
                continue
            if len(fields) != 7:
                raise ValueError(
                    'Malformed intrinsics entry at line {} of {}: expected '
                    '"name fx fy cx cy W H", got {!r}'.format(
                        line_no, path_intrinsics, line.strip()))

            img_name = fields[0]
            fx, fy, cx, cy, W, H = map(float, fields[1:])
            K = np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]], dtype=np.float32)
            if resize is not None:
                K = correct_intrinsic_scale(K, resize[0] / W, resize[1] / H)
            Ks[img_name] = K
    return Ks
def _derived_output_path(im_path, tag):
    """Return ``im_path`` with its last extension replaced by ``.<tag>.jpg``."""
    # Only the final suffix is swapped. A plain str.replace on the extension
    # text would also rewrite any directory or stem that contains it.
    return str(Path(im_path).with_suffix('.{}.jpg'.format(tag)))


def _intrinsics_for(Ks, im_path, fallback_key):
    """Return the entry of ``Ks`` keyed by the file name of ``im_path``, else ``Ks[fallback_key]``."""
    key = Path(im_path).name
    if key not in Ks:
        key = fallback_key
    return Ks[key]


def run_demo_inference(args):
    """Run the model on one image pair and save depth and score visualizations.

    Args:
        args: namespace providing ``config``, ``checkpoint``, ``im_path_ref``,
            ``im_path_dst`` and ``intrinsics`` (all paths).

    Side effects:
        Merges ``args.config`` into the global ``cfg`` and writes four JPEG
        files next to the input images: ``<stem>.score.jpg`` and
        ``<stem>.depth.jpg`` for both the reference and destination image.
    """
    # Select device
    use_cuda = torch.cuda.is_available()
    device = torch.device('cuda:0' if use_cuda else 'cpu')

    print('Preparing data...')

    # Prepare config file
    cfg.merge_from_file(args.config)

    # Prepare the model
    # NOTE(review): the model is not explicitly moved to `device` here --
    # confirm that build_model takes care of device placement.
    model = build_model(cfg, checkpoint=args.checkpoint)

    # Load demo images
    im0 = read_color_image(args.im_path_ref).to(device)
    im1 = read_color_image(args.im_path_dst).to(device)

    # Load intrinsics
    K = read_intrinsics(args.intrinsics)

    # Prepare data for MicKey
    # Each image gets the intrinsics stored under its own file name; the
    # toy-example keys remain the fallback when no such entry exists.
    data = {}
    data['image0'] = im0
    data['image1'] = im1
    data['K_color0'] = torch.from_numpy(_intrinsics_for(K, args.im_path_ref, 'im0.jpg')).unsqueeze(0).to(device)
    data['K_color1'] = torch.from_numpy(_intrinsics_for(K, args.im_path_dst, 'im1.jpg')).unsqueeze(0).to(device)

    # Run inference
    print('Running MicKey relative pose estimation...')
    model(data)

    # Pose, inliers and score are stored in:
    # data['R'] = R
    # data['t'] = t
    # data['inliers'] = inliers
    # data['inliers_list'] = inliers_list

    print('Saving depth and score maps in image directory ...')
    # The depth and score entries read below are expected to have been added
    # to `data` by the model call; depths below 0.001 are masked as invalid.
    depth0_map = colorize_depth(data['depth0_map'][0], invalid_mask=(data['depth0_map'][0] < 0.001).cpu()[0])
    depth1_map = colorize_depth(data['depth1_map'][0], invalid_mask=(data['depth1_map'][0] < 0.001).cpu()[0])
    score0_map = prepare_score_map(data['scr0'][0], data['image0'][0], temperature=0.5)
    score1_map = prepare_score_map(data['scr1'][0], data['image1'][0], temperature=0.5)

    cv2.imwrite(_derived_output_path(args.im_path_ref, 'score'), score0_map)
    cv2.imwrite(_derived_output_path(args.im_path_dst, 'score'), score1_map)
    cv2.imwrite(_derived_output_path(args.im_path_ref, 'depth'), depth0_map)
    cv2.imwrite(_derived_output_path(args.im_path_dst, 'depth'), depth1_map)
if __name__ == '__main__':
    # Command-line options as (flag, help text, default), in declaration order.
    cli_options = (
        ('--im_path_ref', 'path to reference image', 'data/toy_example/im0.jpg'),
        ('--im_path_dst', 'path to destination image', 'data/toy_example/im1.jpg'),
        ('--intrinsics', 'path to intrinsics file', 'data/toy_example/intrinsics.txt'),
        ('--config', 'path to config file', 'weights/mickey_weights/config.yaml'),
        ('--checkpoint', 'path to model checkpoint', 'weights/mickey_weights/mickey.ckpt'),
    )

    arg_parser = argparse.ArgumentParser()
    for flag, description, fallback in cli_options:
        arg_parser.add_argument(flag, help=description, default=fallback)

    # Parse the command line and hand the namespace to the demo entry point.
    run_demo_inference(arg_parser.parse_args())