# image-matching-webui/imcui/third_party/gim/video_preprocessor.py
# -*- coding: utf-8 -*-
# @Author : xuelun
import os
import cv2
import csv
import math
import torch
import scipy.io
import warnings
import argparse
import numpy as np
from os import mkdir
from tqdm import tqdm
from copy import deepcopy
from os.path import join, exists
from torch.utils.data import DataLoader
from datasets.walk.video_streamer import VideoStreamer
from datasets.walk.video_loader import WALKDataset, collate_fn
from networks.mit_semseg.models import ModelBuilder, SegmentationModule
gray2tensor = lambda x: (torch.from_numpy(x).float() / 255)[None, None]
color2tensor = lambda x: (torch.from_numpy(x).float() / 255).permute(2, 0, 1)[None]
warnings.simplefilter("ignore", category=UserWarning)
methods = {'SIFT', 'GIM_GLUE', 'GIM_LOFTR', 'GIM_DKM'}
PALETTE = scipy.io.loadmat('weights/color150.mat')['colors']
CLS_DICT = {} # {'person': 13, 'sky': 3}
with open('weights/object150_info.csv') as f:
    reader = csv.reader(f)
    next(reader)
    for row in reader:
        name = row[5].split(";")[0]
        if name == 'screen':
            name = '_'.join(row[5].split(";")[:2])
        CLS_DICT[name] = int(row[0]) - 1
exclude = ['person', 'sky', 'car']
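
# A minimal standalone sketch of the exclusion-mask construction that the
# method branches in main() below repeat verbatim; 'build_exclusion_mask' is
# an illustrative name and is not called by the pipeline.
def build_exclusion_mask(seg_mask):
    """seg_mask: (h, w) uint8 ADE20K class-index map -> (h, w) uint8 mask
    that is 1 everywhere except on the excluded classes (person/sky/car)."""
    keep = seg_mask != CLS_DICT[exclude[0]]
    for cls in exclude[1:]:
        keep = keep & (seg_mask != CLS_DICT[cls])
    return keep.astype(np.uint8)
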
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--debug', action='store_true')
    parser.add_argument('--gpu', type=int, default=0,
                        help='-1 for CPU')
    parser.add_argument('--range', type=int, nargs='+', default=None,
                        help='video range in seconds')
    parser.add_argument('--scene_name', type=str, default=None,
                        help='scene (video) name')
    parser.add_argument('--method', type=str, choices=methods, required=True,
                        help='method name')
    parser.add_argument('--resize', action='store_true',
                        help='whether to resize')
    parser.add_argument('--skip', type=int, required=True,
                        help='frame-skip level: 0, 1, or 2 (index into the skip table)')
    parser.add_argument('--watermarker', type=int, nargs='+', default=None,
                        help='watermark rectangle range')
    opt = parser.parse_args()
    data_root = join('data', 'ZeroMatch')
    video_name = opt.scene_name.strip()
    video_path = join(data_root, 'video_1080p', video_name + '.mp4')

    # read the real size and frame rate of the video
    vcap = cv2.VideoCapture(video_path)
    vwidth = vcap.get(cv2.CAP_PROP_FRAME_WIDTH)
    vheight = vcap.get(cv2.CAP_PROP_FRAME_HEIGHT)
    fps = vcap.get(cv2.CAP_PROP_FPS)
    # default end: total duration minus the last 300 seconds
    end_range = math.floor(vcap.get(cv2.CAP_PROP_FRAME_COUNT) / fps - 300)
    vcap.release()

    fps = math.ceil(fps)
    opt.range = [300, end_range] if opt.range is None else opt.range
    opt.range = [0, -1] if video_name == 'Od-rKbC30TM' else opt.range  # for demo

    # map the skip level to an actual frame gap, scaled by frame rate
    if fps <= 30:
        skip = [10, 20, 40][opt.skip]
    else:
        skip = [20, 40, 80][opt.skip]

    dump_dir = join(data_root, 'pseudo',
                    'WALK ' + opt.method +
                    ' [R] ' + '{}'.format('T' if opt.resize else 'F') +
                    ' [S] ' + '{:2}'.format(skip))
    if not exists(dump_dir): mkdir(dump_dir)

    debug_dir = join('dump', video_name + ' ' + opt.method)
    if opt.resize: debug_dir = debug_dir + ' Resize'
    if opt.debug and (not exists(debug_dir)): mkdir(debug_dir)

    # start processing the video
    gap = 10 if fps <= 30 else 20
    vs = VideoStreamer(basedir=video_path, resize=opt.resize, df=8, skip=gap, vrange=opt.range)

    # read the first frame
    rgb = vs[vs.listing[0]]
    width, height = rgb.shape[1], rgb.shape[0]
    # ratio between the original video resolution and the (possibly resized) frames
    vratio = np.array([vwidth / width, vheight / height])[None]

    # set dump name
    scene_name = f'{video_name} '
    scene_name += f'WH {width:4} {height:4} '
    scene_name += f'RG {vs.range[0]:4} {vs.range[1]:4} '
    scene_name += f'SP {skip} '
    scene_name += f'{len(video_name)}'
    save_dir = join(dump_dir, scene_name)

    device = torch.device('cuda:{}'.format(opt.gpu)) if opt.gpu >= 0 else torch.device('cpu')
    # initialize the semantic segmentation model (MIT ADE20K, 150 classes)
    net_encoder = ModelBuilder.build_encoder(
        arch='resnet50dilated',
        fc_dim=2048,
        weights='weights/encoder_epoch_20.pth')
    net_decoder = ModelBuilder.build_decoder(
        arch='ppm_deepsup',
        fc_dim=2048,
        num_class=150,
        weights='weights/decoder_epoch_20.pth',
        use_softmax=True)
    crit = torch.nn.NLLLoss(ignore_index=-1)
    segmentation_module = SegmentationModule(net_encoder, net_decoder, crit).to(device).eval()

    # migrate the segmentation cache directory to the stripped scene name
    old_segment_root = join(data_root, 'segment', opt.scene_name)
    new_segment_root = join(data_root, 'segment', opt.scene_name.strip())
    if not os.path.exists(new_segment_root):
        if os.path.exists(old_segment_root):
            os.rename(old_segment_root, new_segment_root)
        else:
            os.makedirs(new_segment_root, exist_ok=True)
    segment_root = new_segment_root
    model, detectAndCompute = None, None
    if opt.method == 'SIFT':
        model = cv2.SIFT_create(nfeatures=32400, contrastThreshold=1e-5)
        detectAndCompute = model.detectAndCompute
    elif opt.method == 'GIM_DKM':
        from networks.dkm.models.model_zoo.DKMv3 import DKMv3
        model = DKMv3(weights=None, h=672, w=896)
        checkpoints_path = join('weights', 'gim_dkm_100h.ckpt')
        state_dict = torch.load(checkpoints_path, map_location='cpu')
        if 'state_dict' in state_dict.keys(): state_dict = state_dict['state_dict']
        for k in list(state_dict.keys()):
            if k.startswith('model.'):
                state_dict[k.replace('model.', '', 1)] = state_dict.pop(k)
            if 'encoder.net.fc' in k:
                # drop the classification head; pop with a default in case the
                # key was already renamed above
                state_dict.pop(k, None)
        model.load_state_dict(state_dict)
        model = model.eval().to(device)
    elif opt.method == 'GIM_LOFTR':
        from networks.loftr.loftr import LoFTR
        from networks.loftr.misc import lower_config
        from networks.loftr.config import get_cfg_defaults
        cfg = get_cfg_defaults()
        cfg.TEMP_BUG_FIX = True
        cfg.LOFTR.WEIGHT = 'weights/gim_loftr_50h.ckpt'
        cfg.LOFTR.FINE_CONCAT_COARSE_FEAT = False
        cfg = lower_config(cfg)
        model = LoFTR(cfg['loftr'])
        model = model.to(device)
        model = model.eval()
    elif opt.method == 'GIM_GLUE':
        from networks.lightglue.matching import Matching
        model = Matching()
        checkpoints_path = join('weights', 'gim_lightglue_100h.ckpt')
        # load the SuperPoint detector weights
        state_dict = torch.load(checkpoints_path, map_location='cpu')
        if 'state_dict' in state_dict.keys(): state_dict = state_dict['state_dict']
        for k in list(state_dict.keys()):
            if k.startswith('model.'):
                state_dict.pop(k)
            if k.startswith('superpoint.'):
                state_dict[k.replace('superpoint.', '', 1)] = state_dict.pop(k)
        model.detector.load_state_dict(state_dict)
        # load the LightGlue matcher weights
        state_dict = torch.load(checkpoints_path, map_location='cpu')
        if 'state_dict' in state_dict.keys(): state_dict = state_dict['state_dict']
        for k in list(state_dict.keys()):
            if k.startswith('superpoint.'):
                state_dict.pop(k)
            if k.startswith('model.'):
                state_dict[k.replace('model.', '', 1)] = state_dict.pop(k)
        model.model.load_state_dict(state_dict)
        model = model.to(device)
        model = model.eval()
    cache_dir = None
    if opt.resize:
        # reuse the full-resolution GIM_DKM matches as a cropping prior
        cache_dir = join(data_root, 'pseudo',
                         'WALK ' + 'GIM_DKM' +
                         ' [R] F' +
                         ' [S] ' + '{:2}'.format(skip),
                         scene_name)

    # sparse methods keep the original resolution; dense methods match at 1600x900
    # (the original compared against 'GLUE', which can never equal the
    # 'GIM_GLUE' choice, so the check is fixed here)
    _w_ = width if opt.method in ('SIFT', 'GIM_GLUE') else 1600  # TODO: confirm DKM
    _h_ = height if opt.method in ('SIFT', 'GIM_GLUE') else 900  # TODO: confirm DKM

    # pair frame i with frame i + skip
    ids = list(zip(vs.listing[:-skip // gap], vs.listing[skip // gap:]))
    # start matching and make pseudo labels; resume from a checkpoint if the
    # dump directory already holds partial results
    nums = None
    idxs = None
    checkpoint = 0
    if not opt.debug:
        if exists(join(save_dir, 'nums.npy')) and exists(join(save_dir, 'idxs.npy')):
            with open(join(save_dir, 'nums.npy'), 'rb') as f:
                nums = np.load(f)
            with open(join(save_dir, 'idxs.npy'), 'rb') as f:
                idxs = np.load(f)
            # the directory holds one .npy per pair plus nums.npy and idxs.npy
            assert len(nums) == len(idxs) == (len(os.listdir(save_dir)) - 2)
            whole = [str(x) + '.npy' for x in np.array(ids)]
            cache = [str(x) + '.npy' for x in idxs]
            leave = list(set(whole) - set(cache))
            if len(leave):
                # find the first frame index after which all pairs are unprocessed
                leave = list(map(lambda x: int(x.rsplit('[')[-1].strip().split()[0]), leave))
                skip_id = np.array(sorted(leave))
                skip_id = (skip_id[1:] - skip_id[:-1]) // gap
                len_id = len(skip_id)
                if len_id == 0: exit(0)
                skip_id = [i for i in range(len_id) if skip_id[i:].sum() == (len_id - i)]
                if len(skip_id) == 0: exit(0)
                skip_id = skip_id[0]
                checkpoint = np.where(np.array(ids)[:, 0] == sorted(leave)[skip_id])[0][0]
                if len(nums) + skip_id > checkpoint: exit(0)
                assert checkpoint == len(nums) + skip_id
            else:
                exit(0)
        else:
            if not exists(save_dir): mkdir(save_dir)
            nums = np.array([])
            idxs = np.array([])
    datasets = WALKDataset(data_root, vs=vs, ids=ids, checkpoint=checkpoint, opt=opt)
    loader_params = {'batch_size': 1, 'shuffle': False, 'num_workers': 5,
                     'pin_memory': True, 'drop_last': False}
    loader = DataLoader(datasets, collate_fn=collate_fn, **loader_params)

    for i, batch in enumerate(tqdm(loader, ncols=120, bar_format="{l_bar}{bar:3}{r_bar}",
                                   desc='{:11} - [{:5}, {:2}{}]'.format(video_name[:40], opt.method, skip, '*' if opt.resize else ''),
                                   total=len(loader), leave=False)):
        idx = batch['idx'].item()
        assert i == idx
        idx0 = batch['idx0'].item()
        idx1 = batch['idx1'].item()
        assert idx0 == ids[idx + checkpoint][0] and idx1 == ids[idx + checkpoint][1]

        # cache loaded images to disk for faster re-runs
        if not batch['rgb0_is_good'].item():
            img_path0 = batch['img_path0'][0]
            if not os.path.exists(img_path0):
                cv2.imwrite(img_path0, batch['rgb0'].squeeze(0).numpy())
        if not batch['rgb1_is_good'].item():
            img_path1 = batch['img_path1'][0]
            if not os.path.exists(img_path1):
                cv2.imwrite(img_path1, batch['rgb1'].squeeze(0).numpy())

        current_id = np.array([idx0, idx1])
        save_name = '{}.npy'.format(str(current_id))
        save_path = join(save_dir, save_name)
        if exists(save_path) and not opt.debug: continue

        rgb0 = batch['rgb0'].squeeze(0).numpy()
        rgb1 = batch['rgb1'].squeeze(0).numpy()
        _rgb0_, _rgb1_ = deepcopy(rgb0), deepcopy(rgb1)
        # get correspondences in the unresized image (from the cached GIM_DKM pass)
        pt0, pt1 = None, None
        if opt.resize:
            cache_path = join(cache_dir, save_name)
            if not exists(cache_path): continue
            with open(cache_path, 'rb') as f:
                pts = np.load(f)
            pt0, pt1 = pts[:, :2], pts[:, 2:]

        # process the first frame
        xA0, xA1, yA0, yA1, hA, wA, wA_new, hA_new = None, None, None, None, None, None, None, None
        if opt.resize:
            # crop rgb0 to the bounding box of the cached matches
            xA0 = math.floor(pt0[:, 0].min())
            xA1 = math.ceil(pt0[:, 0].max())
            yA0 = math.floor(pt0[:, 1].min())
            yA1 = math.ceil(pt0[:, 1].max())
            rgb0 = rgb0[yA0:yA1, xA0:xA1]
            hA, wA = rgb0.shape[:2]
            wA_new, hA_new = get_resized_wh(wA, hA, [_h_, _w_])
            wA_new, hA_new = get_divisible_wh(wA_new, hA_new, 8)
            rgb0 = cv2.resize(rgb0, (wA_new, hA_new), interpolation=cv2.INTER_AREA)
        gray0 = cv2.cvtColor(rgb0, cv2.COLOR_RGB2GRAY)

        # semantic segmentation of frame 0 (cached on disk per frame index)
        with torch.no_grad():
            seg_path0 = join(segment_root, '{}.npy'.format(idx0))
            if not os.path.exists(seg_path0):
                mask0 = segment(_rgb0_, device, segmentation_module)
                np.save(seg_path0, mask0)
            else:
                mask0 = np.load(seg_path0)

        # process the next frame
        xB0, xB1, yB0, yB1, hB, wB, wB_new, hB_new = None, None, None, None, None, None, None, None
        if opt.resize:
            # crop rgb1 to the bounding box of the cached matches
            xB0 = math.floor(pt1[:, 0].min())
            xB1 = math.ceil(pt1[:, 0].max())
            yB0 = math.floor(pt1[:, 1].min())
            yB1 = math.ceil(pt1[:, 1].max())
            rgb1 = rgb1[yB0:yB1, xB0:xB1]
            hB, wB = rgb1.shape[:2]
            wB_new, hB_new = get_resized_wh(wB, hB, [_h_, _w_])
            wB_new, hB_new = get_divisible_wh(wB_new, hB_new, 8)
            rgb1 = cv2.resize(rgb1, (wB_new, hB_new), interpolation=cv2.INTER_AREA)
        gray1 = cv2.cvtColor(rgb1, cv2.COLOR_RGB2GRAY)

        # semantic segmentation of frame 1 (cached on disk per frame index)
        with torch.no_grad():
            seg_path1 = join(segment_root, '{}.npy'.format(idx1))
            if not os.path.exists(seg_path1):
                mask1 = segment(_rgb1_, device, segmentation_module)
                np.save(seg_path1, mask1)
            else:
                mask1 = np.load(seg_path1)

        # bring masks up to full frame resolution, then crop/resize with the frames
        if mask0.shape[:2] != _rgb0_.shape[:2]:
            mask0 = cv2.resize(mask0, _rgb0_.shape[:2][::-1], interpolation=cv2.INTER_NEAREST)
        if mask1.shape[:2] != _rgb1_.shape[:2]:
            mask1 = cv2.resize(mask1, _rgb1_.shape[:2][::-1], interpolation=cv2.INTER_NEAREST)
        if opt.resize:
            # crop and resize mask0
            mask0 = mask0[yA0:yA1, xA0:xA1]
            mask0 = cv2.resize(mask0, (wA_new, hA_new), interpolation=cv2.INTER_NEAREST)
            # crop and resize mask1
            mask1 = mask1[yB0:yB1, xB0:xB1]
            mask1 = cv2.resize(mask1, (wB_new, hB_new), interpolation=cv2.INTER_NEAREST)
        data = None
        if opt.method == 'SIFT':
            # build the exclusion mask (cf. build_exclusion_mask above)
            mask_0 = mask0 != CLS_DICT[exclude[0]]
            mask_1 = mask1 != CLS_DICT[exclude[0]]
            for cls in exclude[1:]:
                mask_0 = mask_0 & (mask0 != CLS_DICT[cls])
                mask_1 = mask_1 & (mask1 != CLS_DICT[cls])
            mask_0 = mask_0.astype(np.uint8)
            mask_1 = mask_1.astype(np.uint8)
            if mask_0.sum() == 0 or mask_1.sum() == 0: continue

            # keypoint detection and description on frame 0
            kpts0, desc0 = detectAndCompute(rgb0, mask_0)
            if desc0 is None or desc0.shape[0] < 8: continue
            kpts0 = np.array([[kp.pt[0], kp.pt[1]] for kp in kpts0])
            kpts0, desc0 = map(lambda x: torch.from_numpy(x).to(device).float(), [kpts0, desc0])
            desc0 = (desc0 / desc0.sum(dim=1, keepdim=True)).sqrt()  # RootSIFT

            # keypoint detection and description on frame 1
            kpts1, desc1 = detectAndCompute(rgb1, mask_1)
            if desc1 is None or desc1.shape[0] < 8: continue
            kpts1 = np.array([[kp.pt[0], kp.pt[1]] for kp in kpts1])
            kpts1, desc1 = map(lambda x: torch.from_numpy(x).to(device).float(), [kpts1, desc1])
            desc1 = (desc1 / desc1.sum(dim=1, keepdim=True)).sqrt()  # RootSIFT

            # mutual nearest-neighbour matching and ratio filter
            # (see rootsift_mnn below for a standalone sketch)
            matches = desc0 @ desc1.transpose(0, 1)
            mask = (matches == matches.max(dim=1, keepdim=True).values) & \
                   (matches == matches.max(dim=0, keepdim=True).values)
            # noinspection PyUnresolvedReferences
            valid, indices = mask.max(dim=1)
            ratio = torch.topk(matches, k=2, dim=1).values
            ratio = (-2 * ratio + 2).sqrt()  # L2 distance from unit-norm similarity
            # ratio = (ratio[:, 0] / ratio[:, 1]) < opt.mt
            ratio = (ratio[:, 0] / ratio[:, 1]) < 0.8
            valid = valid & ratio

            # gather matched keypoints
            mkpts0 = kpts0[valid]
            mkpts1 = kpts1[indices[valid]]
            b_ids = torch.where(valid[None])[0]
            data = dict(
                m_bids=b_ids,
                mkpts0_f=mkpts0,
                mkpts1_f=mkpts1,
            )
        elif opt.method == 'GIM_DKM':
            mask_0 = mask0 != CLS_DICT[exclude[0]]
            mask_1 = mask1 != CLS_DICT[exclude[0]]
            for cls in exclude[1:]:
                mask_0 = mask_0 & (mask0 != CLS_DICT[cls])
                mask_1 = mask_1 & (mask1 != CLS_DICT[cls])
            mask_0 = mask_0.astype(np.uint8)
            mask_1 = mask_1.astype(np.uint8)
            if mask_0.sum() == 0 or mask_1.sum() == 0: continue

            # zero out excluded pixels before dense matching
            img0 = rgb0 * mask_0[..., None]
            img1 = rgb1 * mask_1[..., None]
            width0, height0 = img0.shape[1], img0.shape[0]
            width1, height1 = img1.shape[1], img1.shape[0]
            with torch.no_grad():
                with warnings.catch_warnings():
                    warnings.simplefilter("ignore")
                    img0 = torch.from_numpy(img0).permute(2, 0, 1).to(device)[None] / 255
                    img1 = torch.from_numpy(img1).permute(2, 0, 1).to(device)[None] / 255
                    dense_matches, dense_certainty = model.match(img0, img1)
                    sparse_matches, mconf = model.sample(dense_matches, dense_certainty, 5000)

            # map matches from [-1, 1] normalised coordinates back to pixels
            # (see denormalize_matches below for a standalone sketch)
            mkpts0 = sparse_matches[:, :2]
            mkpts0 = torch.stack((width0 * (mkpts0[:, 0] + 1) / 2,
                                  height0 * (mkpts0[:, 1] + 1) / 2), dim=-1)
            mkpts1 = sparse_matches[:, 2:]
            mkpts1 = torch.stack((width1 * (mkpts1[:, 0] + 1) / 2,
                                  height1 * (mkpts1[:, 1] + 1) / 2), dim=-1)
            m_bids = torch.zeros(sparse_matches.shape[0], dtype=torch.long, device=device)
            data = dict(
                m_bids=m_bids,
                mkpts0_f=mkpts0,
                mkpts1_f=mkpts1,
            )
        elif opt.method == 'GIM_LOFTR':
            mask_0 = mask0 != CLS_DICT[exclude[0]]
            mask_1 = mask1 != CLS_DICT[exclude[0]]
            for cls in exclude[1:]:
                mask_0 = mask_0 & (mask0 != CLS_DICT[cls])
                mask_1 = mask_1 & (mask1 != CLS_DICT[cls])
            mask_0 = mask_0.astype(np.uint8)
            mask_1 = mask_1.astype(np.uint8)
            if mask_0.sum() == 0 or mask_1.sum() == 0: continue

            # LoFTR consumes masks at 1/8 resolution (its coarse feature stride)
            mask_0 = cv2.resize(mask_0, None, fx=1/8, fy=1/8, interpolation=cv2.INTER_NEAREST)
            mask_1 = cv2.resize(mask_1, None, fx=1/8, fy=1/8, interpolation=cv2.INTER_NEAREST)
            data = dict(
                image0=gray2tensor(gray0),
                image1=gray2tensor(gray1),
                color0=color2tensor(rgb0),
                color1=color2tensor(rgb1),
                mask0=torch.from_numpy(mask_0)[None],
                mask1=torch.from_numpy(mask_1)[None],
            )
            with torch.no_grad():
                data = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v
                        in data.items()}
                model(data)  # writes m_bids / mkpts0_f / mkpts1_f into data in place
        elif opt.method == 'GIM_GLUE':
            mask_0 = mask0 != CLS_DICT[exclude[0]]
            mask_1 = mask1 != CLS_DICT[exclude[0]]
            for cls in exclude[1:]:
                mask_0 = mask_0 & (mask0 != CLS_DICT[cls])
                mask_1 = mask_1 & (mask1 != CLS_DICT[cls])
            mask_0 = mask_0.astype(np.uint8)
            mask_1 = mask_1.astype(np.uint8)
            if mask_0.sum() == 0 or mask_1.sum() == 0: continue

            size0 = torch.tensor(gray0.shape[-2:][::-1])[None]
            size1 = torch.tensor(gray1.shape[-2:][::-1])[None]
            data = dict(
                gray0=gray2tensor(gray0 * mask_0),
                gray1=gray2tensor(gray1 * mask_1),
                size0=size0,
                size1=size1,
            )
            with torch.no_grad():
                data = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v
                        in data.items()}
                pred = model(data)

            kpts0, kpts1 = pred['keypoints0'][0], pred['keypoints1'][0]
            matches = pred['matches'][0]
            if len(matches) == 0: continue
            mkpts0 = kpts0[matches[..., 0]]
            mkpts1 = kpts1[matches[..., 1]]
            m_bids = torch.zeros(matches[..., 0].size(), dtype=torch.long, device=device)
            data = dict(
                m_bids=m_bids,
                mkpts0_f=mkpts0,
                mkpts1_f=mkpts1,
            )
        # automatically drop watermark matches: points that barely move between
        # frames are almost certainly static overlays, not scene geometry
        kpts0 = data['mkpts0_f'].clone()  # (N, 2)
        kpts1 = data['mkpts1_f'].clone()  # (N, 2)
        moved = ~((kpts0 - kpts1).abs() < 1).min(dim=1).values  # (N,)
        data['m_bids'] = data['m_bids'][moved]
        data['mkpts0_f'] = data['mkpts0_f'][moved]
        data['mkpts1_f'] = data['mkpts1_f'][moved]

        # robust epipolar fitting with MAGSAC; keep only geometric inliers
        robust_fitting(data)
        if (data['inliers'] is None) or (sum(data['inliers'][0]) == 0): continue
        inliers = data['inliers'][0]

        if opt.debug:
            data.update(dict(
                # for debug visualization
                mask0=mask0,
                mask1=mask1,
                gray0=gray0,
                gray1=gray1,
                color0=rgb0,
                color1=rgb1,
                hw0_i=rgb0.shape[:2],
                hw1_i=rgb1.shape[:2],
                dataset_name=['WALK'],
                scene_id=[video_name],
                pair_id=[[idx0, idx1]],
                imsize0=[[width, height]],
                imsize1=[[width, height]],
            ))
            out = fast_make_matching_robust_fitting_figure(data)
            cv2.imwrite(join(debug_dir, '{} {:8d} {:8d}.png'.format(scene_name, idx0, idx1)),
                        cv2.cvtColor(out, cv2.COLOR_RGB2BGR))
            continue

        # map inlier matches back to the original video resolution
        if opt.resize:
            mkpts0_f = (data['mkpts0_f'].cpu().numpy()[inliers] * np.array([[wA / wA_new, hA / hA_new]]) + np.array([[xA0, yA0]])) * vratio
            mkpts1_f = (data['mkpts1_f'].cpu().numpy()[inliers] * np.array([[wB / wB_new, hB / hB_new]]) + np.array([[xB0, yB0]])) * vratio
        else:
            mkpts0_f = data['mkpts0_f'].cpu().numpy()[inliers] * vratio
            mkpts1_f = data['mkpts1_f'].cpu().numpy()[inliers] * vratio

        # dump pseudo labels: (N, 4) rows of [x0, y0, x1, y1] per match
        pts = np.concatenate([mkpts0_f, mkpts1_f], axis=1).astype(np.float32)
        nums = np.concatenate([nums, np.array([len(pts)])], axis=0) if len(nums) else np.array([len(pts)])
        idxs = np.concatenate([idxs, current_id[None]], axis=0) if len(idxs) else current_id[None]
        with open(save_path, 'wb') as f:
            np.save(f, pts)
        with open(join(save_dir, 'nums.npy'), 'wb') as f:
            np.save(f, nums)
        with open(join(save_dir, 'idxs.npy'), 'wb') as f:
            np.save(f, idxs)
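
# A minimal, self-contained sketch of the RootSIFT + mutual nearest-neighbour
# matching performed inline in the SIFT branch of main() above;
# 'rootsift_mnn' is an illustrative name, and the 0.8 default mirrors the
# hard-coded ratio threshold above.
def rootsift_mnn(desc0, desc1, ratio_thr=0.8):
    """desc0: (N, 128), desc1: (M, 128) raw SIFT descriptors (torch, M >= 2).
    Returns (valid, indices): valid[i] marks a surviving match i -> indices[i]."""
    # RootSIFT: L1-normalise, then element-wise square root, so that the dot
    # product of two descriptors is a Hellinger-kernel similarity
    desc0 = (desc0 / desc0.sum(dim=1, keepdim=True)).sqrt()
    desc1 = (desc1 / desc1.sum(dim=1, keepdim=True)).sqrt()
    sim = desc0 @ desc1.transpose(0, 1)
    # mutual nearest neighbours: a pair must be both the row and column maximum
    mutual = (sim == sim.max(dim=1, keepdim=True).values) & \
             (sim == sim.max(dim=0, keepdim=True).values)
    valid, indices = mutual.max(dim=1)
    # Lowe ratio test in L2 distance, recovered from the similarity of
    # unit-norm descriptors via d = sqrt(2 - 2 * sim)
    dist = torch.topk(sim, k=2, dim=1).values
    dist = (-2 * dist + 2).clamp(min=0).sqrt()
    valid = valid & ((dist[:, 0] / dist[:, 1]) < ratio_thr)
    return valid, indices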

def robust_fitting(data, b_id=0):
    # fit a fundamental matrix with MAGSAC and store the inlier mask in data
    m_bids = data['m_bids'].cpu().numpy()
    kpts0 = data['mkpts0_f'].cpu().numpy()
    kpts1 = data['mkpts1_f'].cpu().numpy()
    mask = m_bids == b_id
    # noinspection PyBroadException
    try:
        _, mask = cv2.findFundamentalMat(kpts0[mask], kpts1[mask], cv2.USAC_MAGSAC,
                                         ransacReprojThreshold=0.5, confidence=0.999999, maxIters=100000)
        mask = (mask.ravel() > 0)[None]
    except Exception:
        mask = None
    data.update(dict(inliers=mask))
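
# A hedged usage sketch for robust_fitting(): synthetic two-view
# correspondences are produced by projecting random 3D points through two
# pinhole cameras, so MAGSAC should mark essentially all of them as inliers.
# All names and values here are made up for illustration.
def _robust_fitting_example():
    rng = np.random.default_rng(0)
    pts3d = rng.uniform(-1, 1, size=(200, 3)) + np.array([0., 0., 5.])
    K = np.array([[500., 0., 320.], [0., 500., 240.], [0., 0., 1.]])
    p0 = (K @ pts3d.T).T                              # camera 0 at the origin
    p1 = (K @ (pts3d - np.array([.5, 0., 0.])).T).T   # camera 1 shifted along x
    p0 = (p0[:, :2] / p0[:, 2:]).astype(np.float32)
    p1 = (p1[:, :2] / p1[:, 2:]).astype(np.float32)
    data = dict(m_bids=torch.zeros(len(p0), dtype=torch.long),
                mkpts0_f=torch.from_numpy(p0),
                mkpts1_f=torch.from_numpy(p1))
    robust_fitting(data)  # fills data['inliers'] with a (1, N) bool mask, or None
    return data['inliers']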

def get_resized_wh(w, h, resize):
    # fit (w, h) inside resize = [new_h, new_w] while preserving aspect ratio
    nh, nw = resize
    sh, sw = nh / h, nw / w
    scale = min(sh, sw)
    w_new, h_new = int(round(w * scale)), int(round(h * scale))
    return w_new, h_new


def get_divisible_wh(w, h, df=None):
    # round (w, h) down to the nearest multiple of df (but at least df itself)
    if df is not None:
        w_new = max((w // df), 1) * df
        h_new = max((h // df), 1) * df
    else:
        w_new, h_new = w, h
    return w_new, h_new
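
# A worked example of the two resize helpers above: a 1920x1080 frame fitted
# into the dense-method budget of [900, 1600] and snapped to a multiple of 8,
# as the opt.resize path of main() does. (The numbers are illustrative.)
def _resize_helpers_example():
    w, h = get_resized_wh(1920, 1080, [900, 1600])  # -> (1600, 900)
    w, h = get_divisible_wh(w, h, 8)                # -> (1600, 896)
    return w, h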

def read_deeplab_image(img, size=1920):
    # downscale so the longer side is at most `size`, keeping the aspect ratio
    width, height = img.shape[1], img.shape[0]
    if max(width, height) > size:
        if width > height:
            img = cv2.resize(img, (size, int(size * height / width)), interpolation=cv2.INTER_AREA)
        else:
            img = cv2.resize(img, (int(size * width / height), size), interpolation=cv2.INTER_AREA)
    img = (torch.from_numpy(img).float() / 255).permute(2, 0, 1)[None]
    return img


def read_segmentation_image(img):
    # ImageNet mean/std normalisation expected by the segmentation model
    img = read_deeplab_image(img, size=720)[0]
    img = img - torch.tensor([0.485, 0.456, 0.406]).view(-1, 1, 1)
    img = img / torch.tensor([0.229, 0.224, 0.225]).view(-1, 1, 1)
    return img


def segment(rgb, device, segmentation_module):
    img_data = read_segmentation_image(rgb)
    singleton_batch = {'img_data': img_data[None].to(device)}
    output_size = img_data.shape[1:]
    # run the segmentation at the highest resolution
    scores = segmentation_module(singleton_batch, segSize=output_size)
    # take the arg-max class per pixel
    _, pred = torch.max(scores, dim=1)
    return pred.cpu()[0].numpy().astype(np.uint8)
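
# A hedged usage sketch for segment(): one RGB frame in, one (h, w) uint8
# ADE20K class-index map out, from which per-class exclusion masks can be
# read off. The device and segmentation_module are assumed to be built as in
# main(); '_segment_usage_example' is an illustrative name.
def _segment_usage_example(rgb, device, segmentation_module):
    mask = segment(rgb, device, segmentation_module)  # (h, w) uint8 class ids
    return {cls: mask == CLS_DICT[cls] for cls in exclude}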

def getLabel(pair, idxs, nums, h5py_i, h5py_f):
    """
    Args:
        pair: frame-index pair, e.g. [6965 6970]
        idxs: (N, 2) frame-index pairs, one row per processed pair
        nums: (N,) number of matches saved for each pair
        h5py_i: (M, 2) concatenated keypoints in the first frames
        h5py_f: (M, 2) concatenated keypoints in the second frames
    Returns:
        (mkpts0, mkpts1): the (n, 2) keypoint arrays for `pair`, or None
    """
    i, j = np.where(idxs == pair)
    if len(i) == 0: return None
    assert (len(i) == len(j) == 2) and (i[0] == i[1]) and (j[0] == 0) and (j[1] == 1)
    i = i[0]
    nums = nums[:i + 1]
    idx0, idx1 = sum(nums[:-1]), sum(nums)
    mkpts0 = h5py_i[idx0:idx1]
    mkpts1 = h5py_f[idx0:idx1]  # (n, 2)
    return mkpts0, mkpts1
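
# A hedged sketch of how getLabel() is meant to be driven: 'nums.npy' and
# 'idxs.npy' come from main() above, while the concatenated (M, 2) keypoint
# arrays are assumed to be packed from the per-pair dumps (the h5py_* names
# suggest an HDF5 packing step that lives outside this file).
# '_get_label_example' is an illustrative name, not part of the pipeline.
def _get_label_example(save_dir, pair):
    nums = np.load(join(save_dir, 'nums.npy'))
    idxs = np.load(join(save_dir, 'idxs.npy'))
    pts = np.concatenate([np.load(join(save_dir, '{}.npy'.format(str(i))))
                          for i in idxs])  # (M, 4) stacked per-pair dumps
    return getLabel(np.asarray(pair), idxs, nums, pts[:, :2], pts[:, 2:])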

def fast_make_matching_robust_fitting_figure(data, b_id=0):
    b_mask = data['m_bids'] == b_id
    gray0 = data['gray0']
    gray1 = data['gray1']
    kpts0 = data['mkpts0_f'][b_mask].cpu().numpy()
    kpts1 = data['mkpts1_f'][b_mask].cpu().numpy()

    # canvas: two columns (frame 0 / frame 1), four rows
    margin = 2
    (h0, w0), (h1, w1) = data['hw0_i'], data['hw1_i']
    h, w = max(h0, h1), max(w0, w1)
    H, W = margin * 5 + h * 4, margin * 3 + w * 2
    out = 255 * np.ones((H, W), np.uint8)
    wx = [margin, margin + w0, margin + w + margin, margin + w + margin + w1]
    hx = lambda row: margin * row + h * (row - 1)
    out = np.stack([out] * 3, -1)

    # row 1: the color frames
    sh = hx(row=1)
    color0 = data['color0']  # (rH, rW, 3)
    color1 = data['color1']  # (rH, rW, 3)
    out[sh: sh + h0, wx[0]: wx[1]] = color0
    out[sh: sh + h1, wx[2]: wx[3]] = color1

    # row 2: the excluded classes, painted in the ADE20K palette
    sh = hx(row=2)
    img0 = np.stack([gray0] * 3, -1) * 0
    for cls in exclude: img0[data['mask0'] == CLS_DICT[cls]] = PALETTE[CLS_DICT[cls]]
    out[sh: sh + h0, wx[0]: wx[1]] = img0
    img1 = np.stack([gray1] * 3, -1) * 0
    for cls in exclude: img1[data['mask1'] == CLS_DICT[cls]] = PALETTE[CLS_DICT[cls]]
    out[sh: sh + h1, wx[2]: wx[3]] = img1

    # row 3: matches before outlier filtering
    sh = hx(row=3)
    mkpts0, mkpts1 = np.round(kpts0).astype(int), np.round(kpts1).astype(int)
    out[sh: sh + h0, wx[0]: wx[1]] = np.stack([gray0] * 3, -1)
    out[sh: sh + h1, wx[2]: wx[3]] = np.stack([gray1] * 3, -1)
    for (x0, y0), (x1, y1) in zip(mkpts0, mkpts1):
        # display match end-points as circles
        c = (230, 216, 132)
        cv2.circle(out, (x0, y0 + sh), 3, c, -1, lineType=cv2.LINE_AA)
        cv2.circle(out, (x1 + margin + w, y1 + sh), 3, c, -1, lineType=cv2.LINE_AA)

    # row 4: matches after outlier filtering
    if data['inliers'] is not None:
        sh = hx(row=4)
        inliers = data['inliers'][b_id]
        mkpts0, mkpts1 = np.round(kpts0).astype(int)[inliers], np.round(kpts1).astype(int)[inliers]
        out[sh: sh + h0, wx[0]: wx[1]] = np.stack([gray0] * 3, -1)
        out[sh: sh + h1, wx[2]: wx[3]] = np.stack([gray1] * 3, -1)
        for (x0, y0), (x1, y1) in zip(mkpts0, mkpts1):
            # display match end-points as circles
            c = (230, 216, 132)
            cv2.circle(out, (x0, y0 + sh), 3, c, -1, lineType=cv2.LINE_AA)
            cv2.circle(out, (x1 + margin + w, y1 + sh), 3, c, -1, lineType=cv2.LINE_AA)

    # big text: match counts before / after filtering
    text = [
        f' ',
        f'#Matches {len(kpts0)}',
        f'#Matches {sum(data["inliers"][b_id]) if data["inliers"] is not None else 0}',
    ]
    sc = min(H / 640., 1.0)
    Ht = int(30 * sc)  # text height
    txt_color_fg = (255, 255, 255)  # white
    txt_color_bg = (0, 0, 0)  # black
    for i, t in enumerate(text):
        cv2.putText(out, t, (int(8 * sc), Ht * (i + 1)), cv2.FONT_HERSHEY_DUPLEX, 1.0 * sc, txt_color_bg, 2, cv2.LINE_AA)
        cv2.putText(out, t, (int(8 * sc), Ht * (i + 1)), cv2.FONT_HERSHEY_DUPLEX, 1.0 * sc, txt_color_fg, 1, cv2.LINE_AA)

    # small text: pair fingerprint at the bottom of the canvas
    fingerprint = [
        'Dataset: {}'.format(data['dataset_name'][b_id]),
        'Scene ID: {}'.format(data['scene_id'][b_id]),
        'Pair ID: {}'.format(data['pair_id'][b_id]),
        'Image sizes: {} - {}'.format(data['imsize0'][b_id],
                                      data['imsize1'][b_id]),
    ]
    sc = min(H / 640., 1.0)
    Ht = int(18 * sc)  # text height
    txt_color_fg = (255, 255, 255)  # white
    txt_color_bg = (0, 0, 0)  # black
    for i, t in enumerate(reversed(fingerprint)):
        cv2.putText(out, t, (int(8 * sc), int(H - Ht * (i + .6))), cv2.FONT_HERSHEY_SIMPLEX, .5 * sc, txt_color_bg, 2, cv2.LINE_AA)
        cv2.putText(out, t, (int(8 * sc), int(H - Ht * (i + .6))), cv2.FONT_HERSHEY_SIMPLEX, .5 * sc, txt_color_fg, 1, cv2.LINE_AA)

    return out
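
# A standalone sketch of the DKM coordinate denormalisation used inline in
# the GIM_DKM branch of main(): DKMv3 returns matches in [-1, 1] normalised
# coordinates, mapped back to pixels via x_px = w * (x + 1) / 2.
# 'denormalize_matches' is an illustrative name, not part of the pipeline.
def denormalize_matches(sparse_matches, w0, h0, w1, h1):
    """sparse_matches: (K, 4) tensor of [x0, y0, x1, y1] in [-1, 1]."""
    mkpts0 = sparse_matches[:, :2]
    mkpts0 = torch.stack((w0 * (mkpts0[:, 0] + 1) / 2,
                          h0 * (mkpts0[:, 1] + 1) / 2), dim=-1)
    mkpts1 = sparse_matches[:, 2:]
    mkpts1 = torch.stack((w1 * (mkpts1[:, 0] + 1) / 2,
                          h1 * (mkpts1[:, 1] + 1) / 2), dim=-1)
    return mkpts0, mkpts1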

if __name__ == '__main__':
    with torch.no_grad():
        main()