# -*- coding: utf-8 -*-
# @Author : xuelun

import os
import cv2
import torch
import random
import numpy as np
import torch.nn.functional as F

from tqdm import tqdm
from os import listdir
from pathlib import Path
from functools import reduce
from datetime import datetime
from argparse import ArgumentParser
from os.path import join, isdir, exists

from datasets.dataset import RGBDDataset
from datasets.walk import cfg
from datasets.walk.utils import covision, intersected, read_images
from datasets.walk.utils import fast_make_matching_robust_fitting_figure

parse_mtd = lambda name: name.parent.stem.split()[1]
parse_skip = lambda name: int(str(name).split(os.sep)[-1].rpartition('SP')[-1].strip().rpartition(' ')[0])
parse_resize = lambda name: str(name).split(os.sep)[-2].rpartition('[R]')[-1].rpartition('[S]')[0].strip()
create_table = lambda x, y, w: dict(zip(np.round(x) + np.round(y) * w, list(range(len(x)))))
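
# The helpers above parse metadata encoded in the pseudo-label folder names:
# parse_mtd takes the method tag from the parent folder name, parse_skip the
# frame skip that follows the 'SP' token in the sequence folder name, and
# parse_resize the resize tag between '[R]' and '[S]' in the parent folder name
# (they fill self.methods / self.skips / self.resizes below).
# create_table hashes rounded (x, y) keypoints into a single integer key
# (x + y * w, with w the image width) so matches that share a keypoint can be
# joined by key lookup in link().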


class WALKDataset(RGBDDataset):
    def __init__(self,
                 root_dir,     # data root dir
                 npz_root,     # data info, e.g. overlap, image_path, depth_path
                 seq_name,     # current sequence
                 mode,         # train, val or test
                 max_resize,   # max edge after resize
                 df,           # image-size divisor; generally 8 for a ResNet using the first 3 stages
                 padding,      # pad images for batch training
                 augment_fn,   # augmentation function
                 max_samples,  # max samples in the current sequence
                 **kwargs):
        super().__init__()

        self.mode = mode
        self.root_dir = root_dir
        self.scene_path = join(root_dir, seq_name)

        pseudo_labels = kwargs.get('PSEUDO_LABELS', None)
        npz_paths = [join(npz_root, x) for x in pseudo_labels]
        npz_paths = [x for x in npz_paths if exists(x)]
        npz_names = [{d[:int(d.split()[-1])]: Path(path, d) for d in listdir(path) if isdir(join(path, d))} for path in npz_paths]
        npz_paths = [name_dict[seq_name] for name_dict in npz_names if seq_name in name_dict.keys()]

        self.propagating = kwargs.get('PROPAGATING', False)
        if self.propagating and len(npz_paths) != 24:
            print(f'{seq_name} has {len(npz_paths)} pseudo labels, but 24 are expected.')
            exit(0)

        self.scale = 1 / df
        self.scene_id = seq_name
        self.skips = sorted(list({parse_skip(name) for name in npz_paths}))
        self.resizes = sorted(list({parse_resize(name) for name in npz_paths}))
        self.methods = sorted(list({parse_mtd(name) for name in npz_paths}))[::-1]

        self.min_final_matches = kwargs.get('MIN_FINAL_MATCHES', None)
        self.min_filter_matches = kwargs.get('MIN_FILTER_MATCHES', None)

        pproot = kwargs.get('PROPAGATE_ROOT', None)
        ppid = ' '.join(self.methods + list(map(str, self.skips)) + self.resizes + [f'FM {self.min_filter_matches}', f'PM {self.min_final_matches}'])
        self.pproot = join(pproot, ppid, seq_name)
        if not self.propagating:
            assert exists(self.pproot)
        elif not exists(self.pproot):
            os.makedirs(self.pproot, exist_ok=True)

        image_root = kwargs.get('VIDEO_IMAGE_ROOT', None)
        self.image_root = join(image_root, seq_name)
        if not exists(self.image_root):
            os.makedirs(self.image_root, exist_ok=True)

        self.step = kwargs.get('STEP', None)
        self.pix_thr = kwargs.get('PIX_THR', None)
        self.fix_matches = kwargs.get('FIX_MATCHES', None)

        source_root = kwargs.get('SOURCE_ROOT', None)
        scap = cv2.VideoCapture(join(source_root, seq_name + '.mp4'))
        self.pseudo_size = [int(scap.get(3)), int(scap.get(4))]
        source_fps = int(scap.get(5))

        video_path = join(root_dir, seq_name + '.mp4')
        vcap = cv2.VideoCapture(video_path)
        self.frame_size = [int(vcap.get(3)), int(vcap.get(4))]

        if self.propagating:
            nums = {skip: [] for skip in self.skips}
            idxs = {skip: [] for skip in self.skips}
            self.path = {skip: [] for skip in self.skips}
            for npz_path in npz_paths:
                skip = parse_skip(npz_path)
                assert exists(npz_path / 'nums.npy')
                with open(npz_path / 'nums.npy', 'rb') as f:
                    npz = np.load(f)
                nums[skip].append(npz)
                assert exists(npz_path / 'idxs.npy')
                with open(npz_path / 'idxs.npy', 'rb') as f:
                    npz = np.load(f)
                idxs[skip].append(npz)
                self.path[skip].append(npz_path)
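
            # Keep only anchor frames whose matches survive every method at the
            # largest skip (ids1, filtered by min_filter_matches) and that can be
            # continued at the two smaller skips (continue1/2/3), so every
            # candidate pair can later be chained down through all skip levels.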
            ids1 = reduce(intersected, [idxs[nums > self.min_filter_matches] for nums, idxs in zip(nums[self.skips[-1]], idxs[self.skips[-1]])])
            continue1 = np.array([x in ids1[:, 0] for x in (ids1[:, 0] + self.skips[-1] * 1)])
            ids2 = reduce(intersected, idxs[self.skips[-2]])
            continue2 = np.array([x in ids2[:, 0] for x in ids1[:, 0]])
            continue2 = continue2 & np.array([x in ids2[:, 0] for x in (ids1[:, 0] + self.skips[-2] * 1)])
            ids3 = reduce(intersected, idxs[self.skips[-3]])
            continue3 = np.array([x in ids3[:, 0] for x in ids1[:, 0]])
            continue3 = continue3 & np.array([x in ids3[:, 0] for x in (ids1[:, 0] + self.skips[-3] * 1)])
            continue3 = continue3 & np.array([x in ids3[:, 0] for x in (ids1[:, 0] + self.skips[-3] * 2)])
            continue3 = continue3 & np.array([x in ids3[:, 0] for x in (ids1[:, 0] + self.skips[-3] * 3)])
            continues = continue1 & continue2 & continue3
            ids = ids1[continues]
            pair_ids = np.array(list(zip(ids[:, 0], np.clip(ids[:, 0] + self.step * self.skips[-1], a_min=ids[0, 0], a_max=ids[-1, 1])))) if self.step > 0 else ids
            pair_ids = pair_ids[(pair_ids[:, 1] - pair_ids[:, 0]) >= self.skips[-1]]
        else:
            pair_ids = np.array([tuple(map(int, x.split('.npy')[0].split('_'))) for x in os.listdir(self.pproot) if x.endswith('.npy')])
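
        # Optionally subsample to at most max_samples pairs; the RNG state is saved
        # and restored so the fixed seed (3407) only affects this selection.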
        if (max_samples > 0) and (len(pair_ids) > max_samples):
            random_state = random.getstate()
            np_random_state = np.random.get_state()
            random.seed(3407)
            np.random.seed(3407)
            pair_ids = pair_ids[sorted(np.random.randint(len(pair_ids), size=max_samples))]
            random.setstate(random_state)
            np.random.set_state(np_random_state)

        # remove invalid pairs listed in self.pproot/bad_pairs.txt
        pair_ids = set(map(tuple, pair_ids.tolist()))
        unvalid_pairs = set()
        if self.propagating:
            assert not exists(join(self.pproot, 'bad_pairs.txt'))
        if exists(join(self.pproot, 'bad_pairs.txt')):
            with open(join(self.pproot, 'bad_pairs.txt'), 'r') as f:
                unvalid_pairs = set([tuple(map(int, line.split())) for line in f.readlines()])
            pair_ids = pair_ids - unvalid_pairs
        self.unvalid_pairs_num = len(unvalid_pairs) if not self.propagating else 'N/A'
        self.valid_pairs_num = len(pair_ids) if not self.propagating else 'N/A'

        self.pair_ids = list(map(list, pair_ids))  # List[List[int, int]]

        # parameters for image resizing, padding and depth-map padding
        if mode == 'train': assert max_resize is not None

        self.df = df
        self.max_resize = max_resize
        self.padding = padding

        # for training LoFTR
        self.augment_fn = augment_fn if mode == 'train' else None

    def __len__(self):
        return len(self.pair_ids)

    def propagate(self, idx0, idx1, skips):
        """
        Propagate pseudo labels from frame idx0 to frame idx1 by chaining matches
        over segments spaced by the largest skip; each segment also gathers
        matches from the smaller skips via recursion.

        Args:
            idx0: (int) index of the first frame
            idx1: (int) index of the second frame
            skips: (List[int]) frame skips in ascending order
        Returns:
            (pseudo_label, id0, id1): pseudo_label is an (N, 4) ndarray of matches
            between frames id0 and id1 (id1 may be smaller than idx1 if the chain
            stops early), or (None, None, None) if propagation fails.
        """
        skip = skips[-1]  # largest skip, e.g. 40
        indices = [skip * (i + 1) + idx0 for i in range((idx1 - idx0) // skip)]
        if (not indices) or (idx0 != indices[0]): indices = [idx0] + indices
        if idx1 != indices[-1]: indices = indices + [idx1]
        indices = list(zip(indices[:-1], indices[1:]))
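        # e.g. idx0=0, idx1=100, skip=40 -> indices=[0, 40, 80, 100]
        #      -> segments [(0, 40), (40, 80), (80, 100)]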
        # [(N', 4), (N'', 4), ...]
        labels = []
        ids = [idx0]
        while indices:
            pair = indices.pop(0)  # (tuple)
            if pair[0] == pair[1]: break
            label = []
            if (pair[-1] - pair[0]) == skip:
                tmp = self.dump(skip, pair)
                if len(tmp) > 0: label.append(tmp)  # (ndarray) (N, 4)
            if skips[:-1]:
                _label_, id0, id1 = self.propagate(pair[0], pair[1], skips[:-1])
                if (id0, id1) == pair: label.append(_label_)  # (ndarray) (M, 4)
            if label:
                label = np.concatenate(label, axis=0)  # (ndarray) (N+M, 4)
                labels.append(label)
                ids += [pair[1]]
            if len(labels) > 1:
                # chain the accumulated labels with the new segment through their shared middle frame
                _labels_ = self.link(labels[0], labels[1])
                if _labels_ is not None:
                    labels = [_labels_]
                    ids = [ids[0], ids[-1]]
                else:
                    # linking failed: drop this segment and retry a shorter one ending skips[0] frames earlier
                    labels.pop(-1)
                    ids.pop(-1)
                    indices = [(pair[0], pair[1] - skips[0])]
        if len(labels) == 1 and len(ids) == 2:
            return labels[0], ids[0], ids[-1]
        else:
            return None, None, None

    def link(self, label0, label1):
        """
        Args:
            label0: (ndarray) N x 4, matches between the left and middle frame
            label1: (ndarray) M x 4, matches between the middle and right frame
        Returns: (ndarray) (N', 4) matches between the left and right frame,
            or None if fewer than min_final_matches survive
        """
        # get keypoints in left, middle and right frame
        left_t0 = label0[:, :2]   # (N, 2)
        mid_t0 = label0[:, 2:]    # (N, 2)
        mid_t1 = label1[:, :2]    # (M, 2)
        right_t1 = label1[:, 2:]  # (M, 2)
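
        # Hash the rounded middle-frame keypoints (x + y * image_width) on both
        # sides; the shared keys identify matches that meet at the same middle
        # pixel and can therefore be chained into left-to-right correspondences.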
        mid0_table = create_table(mid_t0[:, 0], mid_t0[:, 1], self.pseudo_size[0])
        mid1_table = create_table(mid_t1[:, 0], mid_t1[:, 1], self.pseudo_size[0])
        keys = {*mid0_table} & {*mid1_table}
        i = np.array([mid0_table[k] for k in keys])
        j = np.array([mid1_table[k] for k in keys])

        # remove repeat matches
        ij = np.unique(np.vstack((i, j)), axis=1)
        if ij.shape[1] < self.min_final_matches: return None

        # get the new pseudo labels
        pseudo_label = np.concatenate([left_t0[ij[0]], right_t1[ij[1]]], axis=1)  # (N', 4)

        return pseudo_label

    def dump(self, skip, pair):
        """
        Args:
            skip: (int) frame skip whose pseudo-label folders are read
            pair: (tuple) (idx0, idx1) frame indices of the pair
        Returns: pseudo_label (N, 4)
        """
        labels = []
        for path in self.path[skip]:
            # label files are named after the printed index pair, e.g. '[100 140].npy'
            p = path / '{}.npy'.format(str(np.array(pair)))
            if exists(p):
                with open(p, 'rb') as f:
                    labels.append(np.load(f))
        if len(labels) > 0: labels = np.concatenate(labels, axis=0).astype(np.float32)  # (N, 4)
        return labels

    def __getitem__(self, idx):
        idx0, idx1 = self.pair_ids[idx]

        pppath = join(self.pproot, '{}_{}.npy'.format(idx0, idx1))
        if self.propagating and exists(pppath):
            return None

        # check propagation
        if not self.propagating:
            assert exists(pppath), f'{pppath} does not exist'

        if not exists(pppath):
            pseudo_label, idx0, idx1 = self.propagate(idx0, idx1, self.skips)
            if idx1 - idx0 == self.skips[-1]:
                pseudo_label, idx0, idx1 = self.propagate(idx0, idx1, self.skips[:-1])
                if idx1 - idx0 == self.skips[-2]:
                    pseudo_label, idx0, idx1 = self.propagate(idx0, idx1, self.skips[:-2])
            if pseudo_label is None:
                _idx0_, _idx1_ = self.pair_ids[idx]
                with open(join(self.pproot, 'bad_pairs.txt'), 'a') as f:
                    f.write('{} {}\n'.format(_idx0_, _idx1_))
                return None
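
            # geometric verification: robustly fit a fundamental matrix (MAGSAC++)
            # and keep only the inlier correspondences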
            _, mask = cv2.findFundamentalMat(pseudo_label[:, :2], pseudo_label[:, 2:], cv2.USAC_MAGSAC, ransacReprojThreshold=1.0, confidence=0.999999, maxIters=1000)
            mask = mask.ravel() > 0
            pseudo_label = pseudo_label[mask]
            if len(pseudo_label) < 64 or (idx1 - idx0) == self.skips[-3]:
                _idx0_, _idx1_ = self.pair_ids[idx]
                with open(join(self.pproot, 'bad_pairs.txt'), 'a') as f:
                    f.write('{} {}\n'.format(_idx0_, _idx1_))
                return None
            else:
                with open(pppath, 'wb') as f:
                    np.save(f, np.concatenate((np.array([[idx0, idx1, idx0, idx1]]).astype(np.float32), pseudo_label), axis=0))
        else:
            with open(pppath, 'rb') as f:
                pseudo_label = np.load(f)
            idx0, idx1 = pseudo_label[0].astype(np.int64)[:2].tolist()
            pseudo_label = pseudo_label[1:]

        if self.propagating:
            return None
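
        # pseudo labels are stored at the source-video resolution (pseudo_size);
        # rescale them to the resolution of the frames that are actually loaded (frame_size)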
        pseudo_label *= (np.array(self.frame_size * 2) / np.array(self.pseudo_size * 2))[None]

        # get images
        img_path0 = join(self.image_root, '{}.png'.format(idx0))
        color0 = cv2.imread(img_path0)
        img_path1 = join(self.image_root, '{}.png'.format(idx1))
        color1 = cv2.imread(img_path1)

        width0, height0 = self.frame_size
        width1, height1 = self.frame_size

        # bounding corners of the matched keypoints in each image
        left_upper_corner = pseudo_label[:, :2].min(axis=0)
        left_lower_corner = pseudo_label[:, :2].max(axis=0)
        left_corner = np.concatenate([left_upper_corner, left_lower_corner], axis=0)
        right_upper_corner = pseudo_label[:, 2:].min(axis=0)
        right_lower_corner = pseudo_label[:, 2:].max(axis=0)
        right_corner = np.concatenate([right_upper_corner, right_lower_corner], axis=0)

        # read, resize, pad and optionally flip/augment both images
        image0, color0, scale0, rands0, offset0, hflip0, vflip0, resize0, mask0 = read_images(
            None, self.max_resize, self.df, self.padding,
            np.random.choice([self.augment_fn, None], p=[0.5, 0.5]),
            aug_prob=1.0, is_left=True,
            upper_cornor=left_corner,
            read_size=self.frame_size, image=color0)
        image1, color1, scale1, rands1, offset1, hflip1, vflip1, resize1, mask1 = read_images(
            None, self.max_resize, self.df, self.padding,
            np.random.choice([self.augment_fn, None], p=[0.5, 0.5]),
            aug_prob=1.0, is_left=False,
            upper_cornor=right_corner,
            read_size=self.frame_size, image=color1)

        # warp keypoints by scale, offset and flips
        pseudo_label = torch.tensor(pseudo_label, dtype=torch.float)
        left = (pseudo_label[:, :2] / scale0[None] - offset0[None])
        left[:, 0] = resize0[1] - 1 - left[:, 0] if hflip0 else left[:, 0]
        left[:, 1] = resize0[0] - 1 - left[:, 1] if vflip0 else left[:, 1]
        right = (pseudo_label[:, 2:] / scale1[None] - offset1[None])
        right[:, 0] = resize1[1] - 1 - right[:, 0] if hflip1 else right[:, 0]
        right[:, 1] = resize1[0] - 1 - right[:, 1] if vflip1 else right[:, 1]
        # keep only keypoints that fall inside both resized images (at full and 1/df resolution)
        mask = (left[:, 0] >= 0) & (left[:, 0] * self.scale <= (resize0[1] * self.scale - 1)) & \
               (left[:, 1] >= 0) & (left[:, 1] * self.scale <= (resize0[0] * self.scale - 1)) & \
               (right[:, 0] >= 0) & (right[:, 0] * self.scale <= (resize1[1] * self.scale - 1)) & \
               (right[:, 1] >= 0) & (right[:, 1] * self.scale <= (resize1[0] * self.scale - 1))
        left, right = left[mask], right[mask]
        pseudo_label = torch.cat([left, right], dim=1)
        pseudo_label = torch.unique(pseudo_label, dim=0)
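        # pad the matches to a fixed-length (fix_matches, 4) tensor so that pairs
        # with different numbers of matches can be collated into one batch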
        fix_pseudo_label = torch.zeros(self.fix_matches, 4, dtype=pseudo_label.dtype)
        fix_pseudo_label[:len(pseudo_label)] = pseudo_label

        # read image size
        imsize0 = torch.tensor([height0, width0], dtype=torch.long)
        imsize1 = torch.tensor([height1, width1], dtype=torch.long)
        resize0 = torch.tensor(resize0, dtype=torch.long)
        resize1 = torch.tensor(resize1, dtype=torch.long)

        data = {
            # image 0
            'image0': image0,
            'color0': color0,
            'imsize0': imsize0,
            'offset0': offset0,
            'resize0': resize0,
            'depth0': torch.ones((1600, 1600), dtype=torch.float),
            'hflip0': hflip0,
            'vflip0': vflip0,
            # image 1
            'image1': image1,
            'color1': color1,
            'imsize1': imsize1,
            'offset1': offset1,
            'resize1': resize1,
            'depth1': torch.ones((1600, 1600), dtype=torch.float),
            'hflip1': hflip1,
            'vflip1': vflip1,
            # pseudo labels and flags
            'pseudo_labels': fix_pseudo_label,
            'gt': False,
            'zs': True,
            # placeholder geometry (no ground-truth poses or intrinsics for WALK)
            'T_0to1': torch.tensor([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]], dtype=torch.float),
            'T_1to0': torch.tensor([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]], dtype=torch.float),
            'K0': torch.tensor([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=torch.float),
            'K1': torch.tensor([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=torch.float),
            # pair information
            'scale0': scale0 / scale0,
            'scale1': scale1 / scale1,
            'rands0': rands0,
            'rands1': rands1,
            'dataset_name': 'WALK',
            'scene_id': '{:30}'.format(self.scene_id[:min(30, len(self.scene_id) - 1)]),
            'pair_id': f'{idx0}-{idx1}',
            'pair_names': ('{}.png'.format(idx0),
                           '{}.png'.format(idx1)),
            'covisible0': covision(pseudo_label[:, :2], resize0).item(),
            'covisible1': covision(pseudo_label[:, 2:], resize1).item(),
        }
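
        # note: scale0 / scale1 are reported as ones above, presumably because the
        # pseudo labels were already mapped into the resized image coordinates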

        item = super(WALKDataset, self).__getitem__(idx)
        item.update(data)
        data = item

        if mask0 is not None:
            if self.scale:
                # noinspection PyArgumentList
                [ts_mask_0, ts_mask_1] = F.interpolate(torch.stack([mask0, mask1], dim=0)[None].float(),
                                                       scale_factor=self.scale,
                                                       mode='nearest',
                                                       recompute_scale_factor=False)[0].bool()
                data.update({'mask0': ts_mask_0, 'mask1': ts_mask_1})
            data.update({'mask0_i': mask0, 'mask1_i': mask1})

        return data


if __name__ == '__main__':
    parser = ArgumentParser()
    parser.add_argument('seq_names', type=str, nargs='+')
    args = parser.parse_args()
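
    # Each positional argument is a sequence name, or the first argument may be a
    # text file with one sequence name per line. For every sequence a handful of
    # random pairs is drawn and match-fitting figures are saved to dump/walk/<seq_name>.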

    train_cfg = cfg.DATASET.TRAIN

    base_input = {
        'df': 8,
        'mode': 'train',
        'augment_fn': None,
        'max_resize': [1280, 720],
        'padding': cfg.DATASET.TRAIN.PADDING,
        'max_samples': cfg.DATASET.TRAIN.MAX_SAMPLES,
        'min_overlap_score': cfg.DATASET.TRAIN.MIN_OVERLAP_SCORE,
        'max_overlap_score': cfg.DATASET.TRAIN.MAX_OVERLAP_SCORE,
    }
    cfg_input = {
        k: getattr(train_cfg, k)
        for k in [
            'DATA_ROOT', 'NPZ_ROOT', 'STEP', 'PIX_THR', 'FIX_MATCHES', 'SOURCE_ROOT',
            'MAX_CANDIDATE_MATCHES', 'MIN_FINAL_MATCHES', 'MIN_FILTER_MATCHES',
            'VIDEO_IMAGE_ROOT', 'PROPAGATE_ROOT', 'PSEUDO_LABELS',
        ]
    }

    if os.path.isfile(args.seq_names[0]):
        with open(args.seq_names[0], 'r') as f:
            seq_names = [line.strip() for line in f.readlines()]
    else:
        seq_names = args.seq_names

    for seq_name in seq_names:
        input_ = {
            **base_input,
            **cfg_input,
            'root_dir': cfg_input['DATA_ROOT'],
            'npz_root': cfg_input['NPZ_ROOT'],
            'seq_name': seq_name,
        }
        dataset = WALKDataset(**input_)

        random.seed(3407)
        np.random.seed(3407)
        samples = list(range(len(dataset)))
        num = 10
        # guard against sequences with fewer than `num` valid pairs
        samples = random.sample(samples, min(num, len(samples)))
        for idx_ in tqdm(samples[:num], ncols=80, bar_format="{l_bar}{bar:3}{r_bar}", total=num,
                         desc=f'[ {seq_name[:min(10, len(seq_name)-1)]:<10} ] [ {dataset.valid_pairs_num:<5} / {dataset.valid_pairs_num + dataset.unvalid_pairs_num:<5} ]',):
            data_ = dataset[idx_]
            if data_ is None: continue

            pseudo_labels_ = data_['pseudo_labels']
            mask_ = pseudo_labels_.sum(dim=1) > 0
            pseudo_label_ = pseudo_labels_[mask_].cpu().numpy()

            data_['mkpts0_f'] = pseudo_label_[:, :2]
            data_['mkpts1_f'] = pseudo_label_[:, 2:]
            data_['hw0_i'] = data_['image0'].shape[-2:]
            data_['hw1_i'] = data_['image1'].shape[-2:]
            data_['image0'] = data_['image0'][None]
            data_['image1'] = data_['image1'][None]
            data_['color0'] = data_['color0'][None]
            data_['color1'] = data_['color1'][None]

            idx0_, idx1_ = data_['pair_id'].split('-')
            idx0_, idx1_ = map(int, [idx0_, idx1_])

            out = fast_make_matching_robust_fitting_figure(data_, transpose=True)

            save_dir = Path('dump/walk') / seq_name
            if not exists(save_dir): save_dir.mkdir(parents=True, exist_ok=True)

            cv2.imwrite(join(save_dir, '{:8d} [{}] {:8d} {:3d}.png'.format(
                idx0_,
                datetime.utcnow().strftime('%Y-%m-%d %H-%M-%S %f')[:-3],
                idx1_,
                idx1_ - idx0_
            )), cv2.cvtColor(out, cv2.COLOR_RGB2BGR))