# -*- coding: utf-8 -*-
# @Author : xuelun

import os
import cv2
import csv
import math
import torch
import scipy.io
import warnings
import argparse

import numpy as np

from os import mkdir
from tqdm import tqdm
from copy import deepcopy
from os.path import join, exists

from torch.utils.data import DataLoader

from datasets.walk.video_streamer import VideoStreamer
from datasets.walk.video_loader import WALKDataset, collate_fn
from networks.mit_semseg.models import ModelBuilder, SegmentationModule

gray2tensor = lambda x: (torch.from_numpy(x).float() / 255)[None, None]
color2tensor = lambda x: (torch.from_numpy(x).float() / 255).permute(2, 0, 1)[None]
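
# Both lambdas produce batched float tensors in [0, 1]; the shapes below are
# inferred from how they are consumed by the matchers further down:
#   gray2tensor(np.zeros((480, 640), np.uint8)).shape     -> (1, 1, 480, 640)
#   color2tensor(np.zeros((480, 640, 3), np.uint8)).shape -> (1, 3, 480, 640)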

warnings.simplefilter("ignore", category=UserWarning)

methods = {'SIFT', 'GIM_GLUE', 'GIM_LOFTR', 'GIM_DKM'}

PALETTE = scipy.io.loadmat('weights/color150.mat')['colors']

# build {class name: class id} from the ADE20K object-info csv
CLS_DICT = {}  # {'person': 13, 'sky': 3}
with open('weights/object150_info.csv') as f:
    reader = csv.reader(f)
    next(reader)  # skip the header row
    for row in reader:
        name = row[5].split(";")[0]
        if name == 'screen':  # several classes are named 'screen'; disambiguate
            name = '_'.join(row[5].split(";")[:2])
        CLS_DICT[name] = int(row[0]) - 1  # csv ids are 1-based, predictions 0-based

# semantic classes whose pixels are excluded from matching
exclude = ['person', 'sky', 'car']
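
# CLS_DICT and exclude are combined below to mask out dynamic or texture-less
# regions before matching. A minimal sketch of that pattern, assuming `mask` is
# the HxW uint8 class-id map returned by segment():
#   valid = np.ones_like(mask, dtype=bool)
#   for cls in exclude:
#       valid &= (mask != CLS_DICT[cls])  # drop pixels labeled person/sky/car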


def main():
    parser = argparse.ArgumentParser()

    parser.add_argument('--debug', action='store_true')
    parser.add_argument("--gpu", type=int,
                        default=0, help='-1 for CPU')
    parser.add_argument("--range", type=int, nargs='+',
                        default=None,
                        help='video range in seconds')
    parser.add_argument('--scene_name', type=str,
                        default=None,
                        help='scene (video) name')
    parser.add_argument('--method', type=str, choices=methods,
                        required=True,
                        help='method name')
    parser.add_argument('--resize', action='store_true',
                        help='whether to resize frames')
    parser.add_argument('--skip', type=int,
                        required=True,
                        help='frame-skip level, used as a list index: 0, 1 or 2')
    parser.add_argument("--watermarker", type=int, nargs='+',
                        default=None,
                        help='watermark rectangle range')

    opt = parser.parse_args()

    data_root = join('data', 'ZeroMatch')
    video_name = opt.scene_name.strip()
    video_path = join(data_root, 'video_1080p', video_name + '.mp4')

    # get the real size of the video
    vcap = cv2.VideoCapture(video_path)
    vwidth = vcap.get(cv2.CAP_PROP_FRAME_WIDTH)    # float `width`
    vheight = vcap.get(cv2.CAP_PROP_FRAME_HEIGHT)  # float `height`
    fps = vcap.get(cv2.CAP_PROP_FPS)               # float `fps`
    end_range = math.floor(vcap.get(cv2.CAP_PROP_FRAME_COUNT) / fps - 300)
    vcap.release()
    fps = math.ceil(fps)

    # by default skip the first and last 300 seconds of the video
    opt.range = [300, end_range] if opt.range is None else opt.range
    opt.range = [0, -1] if video_name == 'Od-rKbC30TM' else opt.range  # for demo

    # opt.skip selects a frame gap appropriate to the video's frame rate
    if fps <= 30:
        skip = [10, 20, 40][opt.skip]
    else:
        skip = [20, 40, 80][opt.skip]

    dump_dir = join(data_root, 'pseudo',
                    'WALK ' + opt.method +
                    ' [R] ' + '{}'.format('T' if opt.resize else 'F') +
                    ' [S] ' + '{:2}'.format(skip))
    if not exists(dump_dir): mkdir(dump_dir)

    debug_dir = join('dump', video_name + ' ' + opt.method)
    if opt.resize: debug_dir = debug_dir + ' Resize'
    if opt.debug and (not exists(debug_dir)): mkdir(debug_dir)

    # start processing the video
    gap = 10 if fps <= 30 else 20
    vs = VideoStreamer(basedir=video_path, resize=opt.resize, df=8, skip=gap, vrange=opt.range)

    # read the first frame
    rgb = vs[vs.listing[0]]
    width, height = rgb.shape[1], rgb.shape[0]
    # ratio between the raw 1080p video and the (possibly downscaled) streamed frames
    vratio = np.array([vwidth / width, vheight / height])[None]

    # set the dump name
    scene_name = f'{video_name} '
    scene_name += f'WH {width:4} {height:4} '
    scene_name += f'RG {vs.range[0]:4} {vs.range[1]:4} '
    scene_name += f'SP {skip} '
    scene_name += f'{len(video_name)}'
    save_dir = join(dump_dir, scene_name)

    device = torch.device('cuda:{}'.format(opt.gpu)) if opt.gpu >= 0 else torch.device('cpu')

    # initialize the segmentation model
    net_encoder = ModelBuilder.build_encoder(
        arch='resnet50dilated',
        fc_dim=2048,
        weights='weights/encoder_epoch_20.pth')
    net_decoder = ModelBuilder.build_decoder(
        arch='ppm_deepsup',
        fc_dim=2048,
        num_class=150,
        weights='weights/decoder_epoch_20.pth',
        use_softmax=True)
    crit = torch.nn.NLLLoss(ignore_index=-1)
    segmentation_module = SegmentationModule(net_encoder, net_decoder, crit).to(device).eval()

    # migrate legacy segmentation caches saved under un-stripped scene names
    old_segment_root = join(data_root, 'segment', opt.scene_name)
    new_segment_root = join(data_root, 'segment', opt.scene_name.strip())
    if not os.path.exists(new_segment_root):
        if os.path.exists(old_segment_root):
            os.rename(old_segment_root, new_segment_root)
        else:
            os.makedirs(new_segment_root, exist_ok=True)
    segment_root = new_segment_root

    model, detectAndCompute = None, None
    if opt.method == 'SIFT':
        model = cv2.SIFT_create(nfeatures=32400, contrastThreshold=1e-5)
        detectAndCompute = model.detectAndCompute
    elif opt.method == 'GIM_DKM':
        from networks.dkm.models.model_zoo.DKMv3 import DKMv3
        model = DKMv3(weights=None, h=672, w=896)
        checkpoints_path = join('weights', 'gim_dkm_100h.ckpt')
        state_dict = torch.load(checkpoints_path, map_location='cpu')
        if 'state_dict' in state_dict.keys(): state_dict = state_dict['state_dict']
        for k in list(state_dict.keys()):
            if k.startswith('model.'):
                state_dict[k.replace('model.', '', 1)] = state_dict.pop(k)
            if 'encoder.net.fc' in k:
                state_dict.pop(k)
        model.load_state_dict(state_dict)
        model = model.eval().to(device)
    elif opt.method == 'GIM_LOFTR':
        from networks.loftr.loftr import LoFTR
        from networks.loftr.misc import lower_config
        from networks.loftr.config import get_cfg_defaults
        cfg = get_cfg_defaults()
        cfg.TEMP_BUG_FIX = True
        cfg.LOFTR.WEIGHT = 'weights/gim_loftr_50h.ckpt'
        cfg.LOFTR.FINE_CONCAT_COARSE_FEAT = False
        cfg = lower_config(cfg)
        model = LoFTR(cfg['loftr'])
        model = model.to(device)
        model = model.eval()
    elif opt.method == 'GIM_GLUE':
        from networks.lightglue.matching import Matching
        model = Matching()
        checkpoints_path = join('weights', 'gim_lightglue_100h.ckpt')
        # load the detector weights (checkpoint keys prefixed 'superpoint.')
        state_dict = torch.load(checkpoints_path, map_location='cpu')
        if 'state_dict' in state_dict.keys(): state_dict = state_dict['state_dict']
        for k in list(state_dict.keys()):
            if k.startswith('model.'):
                state_dict.pop(k)
            if k.startswith('superpoint.'):
                state_dict[k.replace('superpoint.', '', 1)] = state_dict.pop(k)
        model.detector.load_state_dict(state_dict)
        # load the matcher weights (checkpoint keys prefixed 'model.')
        state_dict = torch.load(checkpoints_path, map_location='cpu')
        if 'state_dict' in state_dict.keys(): state_dict = state_dict['state_dict']
        for k in list(state_dict.keys()):
            if k.startswith('superpoint.'):
                state_dict.pop(k)
            if k.startswith('model.'):
                state_dict[k.replace('model.', '', 1)] = state_dict.pop(k)
        model.model.load_state_dict(state_dict)
        model = model.to(device)
        model = model.eval()

    cache_dir = None
    if opt.resize:
        # reuse the unresized GIM_DKM pseudo labels as coarse correspondences
        cache_dir = join(data_root, 'pseudo',
                         'WALK ' + 'GIM_DKM' +
                         ' [R] F' +
                         ' [S] ' + '{:2}'.format(skip),
                         scene_name)

    # matching resolution: SIFT and GIM_GLUE run at the streamed frame size
    _w_ = width if opt.method in ('SIFT', 'GIM_GLUE') else 1600  # TODO: confirm DKM
    _h_ = height if opt.method in ('SIFT', 'GIM_GLUE') else 900  # TODO: confirm DKM

    # pair each frame with the frame `skip` raw frames later (listing stride is `gap`)
    ids = list(zip(vs.listing[:-skip // gap], vs.listing[skip // gap:]))
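
    # A concrete example of the pairing above, assuming fps <= 30 (gap = 10),
    # opt.skip = 1 (skip = 20) and vs.listing = [0, 10, 20, 30, ...]:
    #   skip // gap = 2, so ids = [(0, 20), (10, 30), (20, 40), ...]
    # i.e. each streamed frame is paired with the frame `skip` raw frames later.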

    # start matching and making pseudo labels
    nums = None
    idxs = None
    checkpoint = 0
    if not opt.debug:
        if exists(join(save_dir, 'nums.npy')) and exists(join(save_dir, 'idxs.npy')):
            # resume: work out how far the previous run got
            with open(join(save_dir, 'nums.npy'), 'rb') as f:
                nums = np.load(f)
            with open(join(save_dir, 'idxs.npy'), 'rb') as f:
                idxs = np.load(f)
            # the "- 2" accounts for nums.npy and idxs.npy themselves
            assert len(nums) == len(idxs) == (len(os.listdir(save_dir)) - 2)
            whole = [str(x) + '.npy' for x in np.array(ids)]
            cache = [str(x) + '.npy' for x in idxs]
            leave = list(set(whole) - set(cache))  # pairs not processed yet
            if len(leave):
                leave = list(map(lambda x: int(x.rsplit('[')[-1].strip().split()[0]), leave))
                skip_id = np.array(sorted(leave))
                skip_id = (skip_id[1:] - skip_id[:-1]) // gap
                len_id = len(skip_id)
                if len_id == 0: exit(0)
                # first index from which the remaining pairs form a contiguous run
                skip_id = [i for i in range(len_id) if skip_id[i:].sum() == (len_id - i)]
                if len(skip_id) == 0: exit(0)
                skip_id = skip_id[0]
                checkpoint = np.where(np.array(ids)[:, 0] == sorted(leave)[skip_id])[0][0]
                if len(nums) + skip_id > checkpoint: exit(0)
                assert checkpoint == len(nums) + skip_id
            else:
                exit(0)  # everything has been processed already
        else:
            if not exists(save_dir): mkdir(save_dir)
            nums = np.array([])
            idxs = np.array([])

    datasets = WALKDataset(data_root, vs=vs, ids=ids, checkpoint=checkpoint, opt=opt)
    loader_params = {'batch_size': 1, 'shuffle': False, 'num_workers': 5,
                     'pin_memory': True, 'drop_last': False}
    loader = DataLoader(datasets, collate_fn=collate_fn, **loader_params)

    for i, batch in enumerate(tqdm(loader, ncols=120, bar_format="{l_bar}{bar:3}{r_bar}",
                                   desc='{:11} - [{:5}, {:2}{}]'.format(video_name[:40], opt.method, skip, '*' if opt.resize else ''),
                                   total=len(loader), leave=False)):
        idx = batch['idx'].item()
        assert i == idx
        idx0 = batch['idx0'].item()
        idx1 = batch['idx1'].item()
        assert idx0 == ids[idx + checkpoint][0] and idx1 == ids[idx + checkpoint][1]

        # cache decoded frames to disk so later runs can skip video decoding
        if not batch['rgb0_is_good'].item():
            img_path0 = batch['img_path0'][0]
            if not os.path.exists(img_path0):
                cv2.imwrite(img_path0, batch['rgb0'].squeeze(0).numpy())
        if not batch['rgb1_is_good'].item():
            img_path1 = batch['img_path1'][0]
            if not os.path.exists(img_path1):
                cv2.imwrite(img_path1, batch['rgb1'].squeeze(0).numpy())

        current_id = np.array([idx0, idx1])
        save_name = '{}.npy'.format(str(current_id))
        save_path = join(save_dir, save_name)
        if exists(save_path) and not opt.debug: continue

        rgb0 = batch['rgb0'].squeeze(0).numpy()
        rgb1 = batch['rgb1'].squeeze(0).numpy()
        _rgb0_, _rgb1_ = deepcopy(rgb0), deepcopy(rgb1)

        # get correspondences in the un-resized image
        pt0, pt1 = None, None
        if opt.resize:
            cache_path = join(cache_dir, save_name)
            if not exists(cache_path): continue
            with open(cache_path, 'rb') as f:
                pts = np.load(f)
            pt0, pt1 = pts[:, :2], pts[:, 2:]

        # process the first frame
        xA0, xA1, yA0, yA1, hA, wA, wA_new, hA_new = None, None, None, None, None, None, None, None
        if opt.resize:
            # crop rgb0 to the bounding box of the cached correspondences
            xA0 = math.floor(pt0[:, 0].min())
            xA1 = math.ceil(pt0[:, 0].max())
            yA0 = math.floor(pt0[:, 1].min())
            yA1 = math.ceil(pt0[:, 1].max())
            rgb0 = rgb0[yA0:yA1, xA0:xA1]
            hA, wA = rgb0.shape[:2]
            wA_new, hA_new = get_resized_wh(wA, hA, [_h_, _w_])
            wA_new, hA_new = get_divisible_wh(wA_new, hA_new, 8)
            rgb0 = cv2.resize(rgb0, (wA_new, hA_new), interpolation=cv2.INTER_AREA)
        # go on
        gray0 = cv2.cvtColor(rgb0, cv2.COLOR_RGB2GRAY)

        # semantic segmentation (cached per frame on disk)
        with torch.no_grad():
            seg_path0 = join(segment_root, '{}.npy'.format(idx0))
            if not os.path.exists(seg_path0):
                mask0 = segment(_rgb0_, device, segmentation_module)
                np.save(seg_path0, mask0)
            else:
                mask0 = np.load(seg_path0)

        # process the next frame
        xB0, xB1, yB0, yB1, hB, wB, wB_new, hB_new = None, None, None, None, None, None, None, None
        if opt.resize:
            # crop rgb1 to the bounding box of the cached correspondences
            xB0 = math.floor(pt1[:, 0].min())
            xB1 = math.ceil(pt1[:, 0].max())
            yB0 = math.floor(pt1[:, 1].min())
            yB1 = math.ceil(pt1[:, 1].max())
            rgb1 = rgb1[yB0:yB1, xB0:xB1]
            hB, wB = rgb1.shape[:2]
            wB_new, hB_new = get_resized_wh(wB, hB, [_h_, _w_])
            wB_new, hB_new = get_divisible_wh(wB_new, hB_new, 8)
            rgb1 = cv2.resize(rgb1, (wB_new, hB_new), interpolation=cv2.INTER_AREA)
        # go on
        gray1 = cv2.cvtColor(rgb1, cv2.COLOR_RGB2GRAY)

        # semantic segmentation (cached per frame on disk)
        with torch.no_grad():
            seg_path1 = join(segment_root, '{}.npy'.format(idx1))
            if not os.path.exists(seg_path1):
                mask1 = segment(_rgb1_, device, segmentation_module)
                np.save(seg_path1, mask1)
            else:
                mask1 = np.load(seg_path1)

        # the masks were predicted at reduced resolution; upsample to frame size
        if mask0.shape[:2] != _rgb0_.shape[:2]:
            mask0 = cv2.resize(mask0, _rgb0_.shape[:2][::-1], interpolation=cv2.INTER_NEAREST)
        if mask1.shape[:2] != _rgb1_.shape[:2]:
            mask1 = cv2.resize(mask1, _rgb1_.shape[:2][::-1], interpolation=cv2.INTER_NEAREST)
        if opt.resize:
            # crop and resize mask0 to match rgb0
            mask0 = mask0[yA0:yA1, xA0:xA1]
            mask0 = cv2.resize(mask0, (wA_new, hA_new), interpolation=cv2.INTER_NEAREST)
            # crop and resize mask1 to match rgb1
            mask1 = mask1[yB0:yB1, xB0:xB1]
            mask1 = cv2.resize(mask1, (wB_new, hB_new), interpolation=cv2.INTER_NEAREST)

        data = None
        if opt.method == 'SIFT':
            # keep only pixels not labeled as an excluded class
            mask_0 = mask0 != CLS_DICT[exclude[0]]
            mask_1 = mask1 != CLS_DICT[exclude[0]]
            for cls in exclude[1:]:
                mask_0 = mask_0 & (mask0 != CLS_DICT[cls])
                mask_1 = mask_1 & (mask1 != CLS_DICT[cls])
            mask_0 = mask_0.astype(np.uint8)
            mask_1 = mask_1.astype(np.uint8)
            if mask_0.sum() == 0 or mask_1.sum() == 0: continue

            # keypoint detection and description
            kpts0, desc0 = detectAndCompute(rgb0, mask_0)
            if desc0 is None or desc0.shape[0] < 8: continue
            kpts0 = np.array([[kp.pt[0], kp.pt[1]] for kp in kpts0])
            kpts0, desc0 = map(lambda x: torch.from_numpy(x).to(device).float(), [kpts0, desc0])
            desc0 = (desc0 / desc0.sum(dim=1, keepdim=True)).sqrt()  # RootSIFT: L1-normalize, then sqrt

            # keypoint detection and description
            kpts1, desc1 = detectAndCompute(rgb1, mask_1)
            if desc1 is None or desc1.shape[0] < 8: continue
            kpts1 = np.array([[kp.pt[0], kp.pt[1]] for kp in kpts1])
            kpts1, desc1 = map(lambda x: torch.from_numpy(x).to(device).float(), [kpts1, desc1])
            desc1 = (desc1 / desc1.sum(dim=1, keepdim=True)).sqrt()  # RootSIFT: L1-normalize, then sqrt

            # mutual-nearest-neighbour matching and ratio filter
            matches = desc0 @ desc1.transpose(0, 1)
            mask = (matches == matches.max(dim=1, keepdim=True).values) & \
                   (matches == matches.max(dim=0, keepdim=True).values)
            # noinspection PyUnresolvedReferences
            valid, indices = mask.max(dim=1)
            ratio = torch.topk(matches, k=2, dim=1).values
            ratio = (-2 * ratio + 2).sqrt()  # cosine similarity -> L2 distance
            # ratio = (ratio[:, 0] / ratio[:, 1]) < opt.mt
            ratio = (ratio[:, 0] / ratio[:, 1]) < 0.8
            valid = valid & ratio

            # get matched keypoints
            mkpts0 = kpts0[valid]
            mkpts1 = kpts1[indices[valid]]
            b_ids = torch.where(valid[None])[0]

            data = dict(
                m_bids=b_ids,
                mkpts0_f=mkpts0,
                mkpts1_f=mkpts1,
            )
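
        # The ratio test above relies on RootSIFT descriptors being unit-L2:
        # after L1-normalization, ||sqrt(d)||^2 = sum(d) = 1, and for unit
        # vectors a, b we have ||a - b||^2 = 2 - 2 * a.b, so
        # `(-2 * ratio + 2).sqrt()` turns the top-2 cosine similarities into
        # L2 distances and `dist1 / dist2 < 0.8` is Lowe's ratio test.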
        elif opt.method == 'GIM_DKM':
            mask_0 = mask0 != CLS_DICT[exclude[0]]
            mask_1 = mask1 != CLS_DICT[exclude[0]]
            for cls in exclude[1:]:
                mask_0 = mask_0 & (mask0 != CLS_DICT[cls])
                mask_1 = mask_1 & (mask1 != CLS_DICT[cls])
            mask_0 = mask_0.astype(np.uint8)
            mask_1 = mask_1.astype(np.uint8)
            if mask_0.sum() == 0 or mask_1.sum() == 0: continue

            # zero out excluded regions before dense matching
            img0 = rgb0 * mask_0[..., None]
            img1 = rgb1 * mask_1[..., None]
            width0, height0 = img0.shape[1], img0.shape[0]
            width1, height1 = img1.shape[1], img1.shape[0]

            with torch.no_grad():
                with warnings.catch_warnings():
                    warnings.simplefilter("ignore")
                    img0 = torch.from_numpy(img0).permute(2, 0, 1).to(device)[None] / 255
                    img1 = torch.from_numpy(img1).permute(2, 0, 1).to(device)[None] / 255
                    dense_matches, dense_certainty = model.match(img0, img1)
                    sparse_matches, mconf = model.sample(dense_matches, dense_certainty, 5000)

            mkpts0 = sparse_matches[:, :2]
            mkpts0 = torch.stack((width0 * (mkpts0[:, 0] + 1) / 2,
                                  height0 * (mkpts0[:, 1] + 1) / 2), dim=-1)
            mkpts1 = sparse_matches[:, 2:]
            mkpts1 = torch.stack((width1 * (mkpts1[:, 0] + 1) / 2,
                                  height1 * (mkpts1[:, 1] + 1) / 2), dim=-1)
            m_bids = torch.zeros(sparse_matches.shape[0], dtype=torch.long, device=device)

            data = dict(
                m_bids=m_bids,
                mkpts0_f=mkpts0,
                mkpts1_f=mkpts1,
            )
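
        # DKM returns matches as rows (x0, y0, x1, y1) in normalized [-1, 1]
        # coordinates, as the conversion above implies; the stack() calls map
        # them back to pixels via x_px = w * (x_norm + 1) / 2, so -1 -> 0 and
        # +1 -> w.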
        elif opt.method == 'GIM_LOFTR':
            mask_0 = mask0 != CLS_DICT[exclude[0]]
            mask_1 = mask1 != CLS_DICT[exclude[0]]
            for cls in exclude[1:]:
                mask_0 = mask_0 & (mask0 != CLS_DICT[cls])
                mask_1 = mask_1 & (mask1 != CLS_DICT[cls])
            mask_0 = mask_0.astype(np.uint8)
            mask_1 = mask_1.astype(np.uint8)
            if mask_0.sum() == 0 or mask_1.sum() == 0: continue

            # LoFTR consumes masks at 1/8 resolution, matching its coarse stride
            mask_0 = cv2.resize(mask_0, None, fx=1/8, fy=1/8, interpolation=cv2.INTER_NEAREST)
            mask_1 = cv2.resize(mask_1, None, fx=1/8, fy=1/8, interpolation=cv2.INTER_NEAREST)

            data = dict(
                image0=gray2tensor(gray0),
                image1=gray2tensor(gray1),
                color0=color2tensor(rgb0),
                color1=color2tensor(rgb1),
                mask0=torch.from_numpy(mask_0)[None],
                mask1=torch.from_numpy(mask_1)[None],
            )
            with torch.no_grad():
                data = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v
                        in data.items()}
                model(data)  # writes m_bids / mkpts0_f / mkpts1_f into `data` in place
        elif opt.method == 'GIM_GLUE':
            mask_0 = mask0 != CLS_DICT[exclude[0]]
            mask_1 = mask1 != CLS_DICT[exclude[0]]
            for cls in exclude[1:]:
                mask_0 = mask_0 & (mask0 != CLS_DICT[cls])
                mask_1 = mask_1 & (mask1 != CLS_DICT[cls])
            mask_0 = mask_0.astype(np.uint8)
            mask_1 = mask_1.astype(np.uint8)
            if mask_0.sum() == 0 or mask_1.sum() == 0: continue

            size0 = torch.tensor(gray0.shape[-2:][::-1])[None]
            size1 = torch.tensor(gray1.shape[-2:][::-1])[None]
            data = dict(
                gray0=gray2tensor(gray0 * mask_0),
                gray1=gray2tensor(gray1 * mask_1),
                size0=size0,
                size1=size1,
            )
            with torch.no_grad():
                data = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v
                        in data.items()}
                pred = model(data)

            kpts0, kpts1 = pred['keypoints0'][0], pred['keypoints1'][0]
            matches = pred['matches'][0]  # (M, 2) indices into kpts0 / kpts1
            if len(matches) == 0: continue
            mkpts0 = kpts0[matches[..., 0]]
            mkpts1 = kpts1[matches[..., 1]]
            m_bids = torch.zeros(matches[..., 0].size(), dtype=torch.long, device=device)

            data = dict(
                m_bids=m_bids,
                mkpts0_f=mkpts0,
                mkpts1_f=mkpts1,
            )

        # automatically remove watermark matches
        kpts0 = data['mkpts0_f'].clone()  # (N, 2)
        kpts1 = data['mkpts1_f'].clone()  # (N, 2)
        moved = ~((kpts0 - kpts1).abs() < 1).min(dim=1).values  # (N,)
        data['m_bids'] = data['m_bids'][moved]
        data['mkpts0_f'] = data['mkpts0_f'][moved]
        data['mkpts1_f'] = data['mkpts1_f'][moved]
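
        # Static overlays (watermarks, logos, UI elements) occupy the same
        # pixel position in both frames, so their matches satisfy |dx| < 1 and
        # |dy| < 1. `min(dim=1)` on the boolean tensor acts as a logical AND
        # over x and y, and the negation keeps only matches that moved by at
        # least one pixel along some axis.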

        robust_fitting(data)
        if (data['inliers'] is None) or (sum(data['inliers'][0]) == 0): continue
        inliers = data['inliers'][0]

        if opt.debug:
            data.update(dict(
                # for debug visualization
                mask0=mask0,
                mask1=mask1,
                gray0=gray0,
                gray1=gray1,
                color0=rgb0,
                color1=rgb1,
                hw0_i=rgb0.shape[:2],
                hw1_i=rgb1.shape[:2],
                dataset_name=['WALK'],
                scene_id=[video_name],
                pair_id=[[idx0, idx1]],
                imsize0=[[width, height]],
                imsize1=[[width, height]],
            ))
            out = fast_make_matching_robust_fitting_figure(data)
            cv2.imwrite(join(debug_dir, '{} {:8d} {:8d}.png'.format(scene_name, idx0, idx1)),
                        cv2.cvtColor(out, cv2.COLOR_RGB2BGR))
            continue

        # map inlier matches back to raw 1080p video coordinates
        if opt.resize:
            mkpts0_f = (data['mkpts0_f'].cpu().numpy()[inliers] * np.array([[wA/wA_new, hA/hA_new]]) + np.array([[xA0, yA0]])) * vratio
            mkpts1_f = (data['mkpts1_f'].cpu().numpy()[inliers] * np.array([[wB/wB_new, hB/hB_new]]) + np.array([[xB0, yB0]])) * vratio
        else:
            mkpts0_f = data['mkpts0_f'].cpu().numpy()[inliers] * vratio
            mkpts1_f = data['mkpts1_f'].cpu().numpy()[inliers] * vratio
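
        # The resize branch inverts, in order, every transform applied earlier:
        # scaling by (w/w_new, h/h_new) undoes cv2.resize, adding (x0, y0)
        # undoes the bounding-box crop, and multiplying by vratio maps
        # streamed-frame pixels back to raw 1080p video pixels.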

        pts = np.concatenate([mkpts0_f, mkpts1_f], axis=1).astype(np.float32)
        nums = np.concatenate([nums, np.array([len(pts)])], axis=0) if len(nums) else np.array([len(pts)])
        idxs = np.concatenate([idxs, current_id[None]], axis=0) if len(idxs) else current_id[None]
        with open(save_path, 'wb') as f:
            np.save(f, pts)
        with open(join(save_dir, 'nums.npy'), 'wb') as f:
            np.save(f, nums)
        with open(join(save_dir, 'idxs.npy'), 'wb') as f:
            np.save(f, idxs)


def robust_fitting(data, b_id=0):
    m_bids = data['m_bids'].cpu().numpy()
    kpts0 = data['mkpts0_f'].cpu().numpy()
    kpts1 = data['mkpts1_f'].cpu().numpy()
    mask = m_bids == b_id
    # noinspection PyBroadException
    try:
        _, mask = cv2.findFundamentalMat(kpts0[mask], kpts1[mask], cv2.USAC_MAGSAC,
                                         ransacReprojThreshold=0.5, confidence=0.999999, maxIters=100000)
        mask = (mask.ravel() > 0)[None]
    except Exception:
        mask = None
    data.update(dict(inliers=mask))
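
# Geometric verification above fits a fundamental matrix with MAGSAC++
# (cv2.USAC_MAGSAC); matches farther than 0.5 px from their epipolar line are
# flagged as outliers. OpenCV returns the inlier mask shaped (N, 1), hence the
# ravel(); the try/except guards degenerate inputs (too few or ill-conditioned
# correspondences), where findFundamentalMat raises or returns mask = None.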


def get_resized_wh(w, h, resize):
    # scale (w, h) to fit inside resize = [nh, nw] while keeping the aspect ratio
    nh, nw = resize
    sh, sw = nh / h, nw / w
    scale = min(sh, sw)
    w_new, h_new = int(round(w*scale)), int(round(h*scale))
    return w_new, h_new


def get_divisible_wh(w, h, df=None):
    # round (w, h) down to the nearest multiple of df (at least df itself)
    if df is not None:
        w_new = max((w // df), 1) * df
        h_new = max((h // df), 1) * df
    else:
        w_new, h_new = w, h
    return w_new, h_new
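
# A worked example of the two helpers, assuming the 1600x900 target used for
# the dense matchers:
#   get_resized_wh(1920, 1080, [900, 1600])  # scale = min(900/1080, 1600/1920)
#                                            # -> (1600, 900)
#   get_divisible_wh(1600, 900, 8)           # -> (1600, 896), both divisible
#                                            # by the matchers' stride of 8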


def read_deeplab_image(img, size=1920):
    # downscale so the longer side is at most `size`, then convert to (1, 3, H, W) in [0, 1]
    width, height = img.shape[1], img.shape[0]
    if max(width, height) > size:
        if width > height:
            img = cv2.resize(img, (size, int(size * height / width)), interpolation=cv2.INTER_AREA)
        else:
            img = cv2.resize(img, (int(size * width / height), size), interpolation=cv2.INTER_AREA)
    img = (torch.from_numpy(img).float() / 255).permute(2, 0, 1)[None]
    return img


def read_segmentation_image(img):
    # ImageNet mean/std normalization expected by the MIT semseg models
    img = read_deeplab_image(img, size=720)[0]
    img = img - torch.tensor([0.485, 0.456, 0.406]).view(-1, 1, 1)
    img = img / torch.tensor([0.229, 0.224, 0.225]).view(-1, 1, 1)
    return img


def segment(rgb, device, segmentation_module):
    img_data = read_segmentation_image(rgb)
    singleton_batch = {'img_data': img_data[None].to(device)}
    output_size = img_data.shape[1:]
    # run the segmentation at the highest resolution
    scores = segmentation_module(singleton_batch, segSize=output_size)
    # get the predicted class id for each pixel
    _, pred = torch.max(scores, dim=1)
    return pred.cpu()[0].numpy().astype(np.uint8)


def getLabel(pair, idxs, nums, h5py_i, h5py_f):
    """
    Args:
        pair: frame-id pair to look up, e.g. [6965 6970]
        idxs: (N, 2) frame-id pairs
        nums: (N,) number of matches per pair
        h5py_i: (M, 2) all mkpts0, concatenated over pairs
        h5py_f: (M, 2) all mkpts1, concatenated over pairs
    Returns: (mkpts0, mkpts1), each (nums[i], 2), or None if `pair` is missing
    """
    i, j = np.where(idxs == pair)
    if len(i) == 0: return None
    assert (len(i) == len(j) == 2) and (i[0] == i[1]) and (j[0] == 0) and (j[1] == 1)
    i = i[0]
    nums = nums[:i+1]
    idx0, idx1 = sum(nums[:-1]), sum(nums)
    mkpts0 = h5py_i[idx0:idx1]
    mkpts1 = h5py_f[idx0:idx1]  # (nums[i], 2)
    return mkpts0, mkpts1
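
# getLabel locates a pair's matches in the flat arrays via a prefix sum over
# `nums`. For example, with nums = [3, 5, 4] and the queried pair at row i = 1:
#   idx0 = sum([3]) = 3, idx1 = sum([3, 5]) = 8
# so the pair's matches occupy rows 3..7 of h5py_i / h5py_f.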


def fast_make_matching_robust_fitting_figure(data, b_id=0):
    b_mask = data['m_bids'] == b_id
    gray0 = data['gray0']
    gray1 = data['gray1']
    kpts0 = data['mkpts0_f'][b_mask].cpu().numpy()
    kpts1 = data['mkpts1_f'][b_mask].cpu().numpy()

    margin = 2
    (h0, w0), (h1, w1) = data['hw0_i'], data['hw1_i']
    h, w = max(h0, h1), max(w0, w1)
    H, W = margin * 5 + h * 4, margin * 3 + w * 2

    # canvas: four rows (color, segmentation, raw matches, inlier matches), two columns
    out = 255 * np.ones((H, W), np.uint8)
    wx = [margin, margin + w0, margin + w + margin, margin + w + margin + w1]
    hx = lambda row: margin * row + h * (row-1)
    out = np.stack([out] * 3, -1)

    # row 1: color images
    sh = hx(row=1)
    color0 = data['color0']  # (rH, rW, 3)
    color1 = data['color1']  # (rH, rW, 3)
    out[sh: sh + h0, wx[0]: wx[1]] = color0
    out[sh: sh + h1, wx[2]: wx[3]] = color1

    # row 2: excluded segmentation classes, painted with the ADE20K palette
    sh = hx(row=2)
    img0 = np.stack([gray0] * 3, -1) * 0
    for cls in exclude: img0[data['mask0'] == CLS_DICT[cls]] = PALETTE[CLS_DICT[cls]]
    out[sh: sh + h0, wx[0]: wx[1]] = img0
    img1 = np.stack([gray1] * 3, -1) * 0
    for cls in exclude: img1[data['mask1'] == CLS_DICT[cls]] = PALETTE[CLS_DICT[cls]]
    out[sh: sh + h1, wx[2]: wx[3]] = img1

    # row 3: matches before outlier filtering
    sh = hx(row=3)
    mkpts0, mkpts1 = np.round(kpts0).astype(int), np.round(kpts1).astype(int)
    out[sh: sh + h0, wx[0]: wx[1]] = np.stack([gray0] * 3, -1)
    out[sh: sh + h1, wx[2]: wx[3]] = np.stack([gray1] * 3, -1)
    for (x0, y0), (x1, y1) in zip(mkpts0, mkpts1):
        # display match end-points as circles
        c = (230, 216, 132)
        cv2.circle(out, (x0, y0+sh), 3, c, -1, lineType=cv2.LINE_AA)
        cv2.circle(out, (x1 + margin + w, y1+sh), 3, c, -1, lineType=cv2.LINE_AA)

    # row 4: matches after outlier filtering
    if data['inliers'] is not None:
        sh = hx(row=4)
        inliers = data['inliers'][b_id]
        mkpts0, mkpts1 = np.round(kpts0).astype(int)[inliers], np.round(kpts1).astype(int)[inliers]
        out[sh: sh + h0, wx[0]: wx[1]] = np.stack([gray0] * 3, -1)
        out[sh: sh + h1, wx[2]: wx[3]] = np.stack([gray1] * 3, -1)
        for (x0, y0), (x1, y1) in zip(mkpts0, mkpts1):
            # display match end-points as circles
            c = (230, 216, 132)
            cv2.circle(out, (x0, y0+sh), 3, c, -1, lineType=cv2.LINE_AA)
            cv2.circle(out, (x1 + margin + w, y1+sh), 3, c, -1, lineType=cv2.LINE_AA)

    # big text: match counts before and after filtering
    text = [
        f' ',
        f'#Matches {len(kpts0)}',
        f'#Matches {sum(data["inliers"][b_id]) if data["inliers"] is not None else 0}',
    ]
    sc = min(H / 640., 1.0)
    Ht = int(30 * sc)  # text height
    txt_color_fg = (255, 255, 255)  # white
    txt_color_bg = (0, 0, 0)  # black
    for i, t in enumerate(text):
        cv2.putText(out, t, (int(8 * sc), Ht * (i + 1)), cv2.FONT_HERSHEY_DUPLEX, 1.0 * sc, txt_color_bg, 2, cv2.LINE_AA)
        cv2.putText(out, t, (int(8 * sc), Ht * (i + 1)), cv2.FONT_HERSHEY_DUPLEX, 1.0 * sc, txt_color_fg, 1, cv2.LINE_AA)

    # fingerprint in the bottom-left corner
    fingerprint = [
        'Dataset: {}'.format(data['dataset_name'][b_id]),
        'Scene ID: {}'.format(data['scene_id'][b_id]),
        'Pair ID: {}'.format(data['pair_id'][b_id]),
        'Image sizes: {} - {}'.format(data['imsize0'][b_id],
                                      data['imsize1'][b_id]),
    ]
    sc = min(H / 640., 1.0)
    Ht = int(18 * sc)  # text height
    txt_color_fg = (255, 255, 255)  # white
    txt_color_bg = (0, 0, 0)  # black
    for i, t in enumerate(reversed(fingerprint)):
        cv2.putText(out, t, (int(8 * sc), int(H - Ht * (i + .6))), cv2.FONT_HERSHEY_SIMPLEX, .5 * sc, txt_color_bg, 2, cv2.LINE_AA)
        cv2.putText(out, t, (int(8 * sc), int(H - Ht * (i + .6))), cv2.FONT_HERSHEY_SIMPLEX, .5 * sc, txt_color_fg, 1, cv2.LINE_AA)

    return out


if __name__ == '__main__':
    with torch.no_grad():
        main()