""" "XFeat: Accelerated Features for Lightweight Image Matching, CVPR 2024." https://www.verlab.dcc.ufmg.br/descriptors/xfeat_cvpr24/ Camera pose metrics adapted from LoFTR https://github.com/zju3dv/LoFTR/blob/master/src/utils/metrics.py The main difference is the use of poselib instead of OpenCV's vanilla RANSAC for E_mat, which is more stable and MUCH and faster. """ import argparse import numpy as np import os import cv2 from tqdm import tqdm import json import multiprocessing as mp # Disable scientific notation np.set_printoptions(suppress=True) def intrinsics_to_camera(K): px, py = K[0, 2], K[1, 2] fx, fy = K[0, 0], K[1, 1] return { "model": "PINHOLE", "width": int(2 * px), "height": int(2 * py), "params": [fx, fy, px, py], } def angle_error_vec(v1, v2): n = np.linalg.norm(v1) * np.linalg.norm(v2) return np.rad2deg(np.arccos(np.clip(np.dot(v1, v2) / n, -1.0, 1.0))) def angle_error_mat(R1, R2): cos = (np.trace(np.dot(R1.T, R2)) - 1) / 2 cos = np.clip(cos, -1., 1.) # numercial errors can make it out of bounds return np.rad2deg(np.abs(np.arccos(cos))) def compute_pose_error(T_0to1, R, t): R_gt = T_0to1[:3, :3] t_gt = T_0to1[:3, 3] error_t = angle_error_vec(t, t_gt) error_t = np.minimum(error_t, 180 - error_t) # ambiguity of E estimation error_R = angle_error_mat(R, R_gt) return error_t, error_R def estimate_pose(kpts0, kpts1, K0, K1, thresh, conf=0.99999, type='poselib'): if len(kpts0) < 5: return None if type == 'poselib': import poselib (pose,details) = poselib.estimate_relative_pose( kpts0.tolist(), kpts1.tolist(), intrinsics_to_camera(K0), intrinsics_to_camera(K1), ransac_opt={ 'max_iterations': 10000, # default 100000 'success_prob': conf, # default 0.99999 'max_epipolar_error': thresh, # default 1.0 }, bundle_opt={ # all defaults }, ) ret = (pose.R, pose.t, details['inliers']) elif type == 'opencv': f_mean = np.mean([K0[0, 0], K1[1, 1], K0[0, 0], K1[1, 1]]) norm_thresh = thresh / f_mean kpts0 = (kpts0 - K0[[0, 1], [2, 2]][None]) / K0[[0, 1], [0, 1]][None] kpts1 = (kpts1 - K1[[0, 1], [2, 2]][None]) / K1[[0, 1], [0, 1]][None] E, mask = cv2.findEssentialMat( kpts0, kpts1, np.eye(3), threshold=norm_thresh, prob=conf, method=cv2.RANSAC) assert E is not None best_num_inliers = 0 ret = None for _E in np.split(E, len(E) / 3): n, R, t, _ = cv2.recoverPose( _E, kpts0, kpts1, np.eye(3), 1e9, mask=mask) if n > best_num_inliers: best_num_inliers = n ret = (R, t[:, 0], mask.ravel() > 0) else: raise NotImplementedError return ret def estimate_pose_parallel(args): return estimate_pose(*args) def pose_auc(errors, thresholds): sort_idx = np.argsort(errors) errors = np.array(errors.copy())[sort_idx] recall = (np.arange(len(errors)) + 1) / len(errors) errors = np.r_[0., errors] recall = np.r_[0., recall] aucs = [] for t in thresholds: last_index = np.searchsorted(errors, t) r = np.r_[recall[:last_index], recall[last_index-1]] e = np.r_[errors[:last_index], t] aucs.append(np.trapz(r, x=e)/t) return aucs def pose_accuracy(errors, thresholds): return [np.mean(errors < t) * 100 for t in thresholds] def get_relative_transform(pose0, pose1): R0 = pose0[..., :3, :3] # Bx3x3 t0 = pose0[..., :3, [3]] # Bx3x1 R1 = pose1[..., :3, :3] # Bx3x3 t1 = pose1[..., :3, [3]] # Bx3x1 R_0to1 = R1.transpose(-1, -2) @ R0 # Bx3x3 t_0to1 = R1.transpose(-1, -2) @ (t0 - t1) # Bx3x1 T_0to1 = np.concatenate([R_0to1, t_0to1], axis=-1) # Bx3x4 return T_0to1 class Scannet1500: default_config = { 'scannet_path': os.path.abspath(os.path.join(os.path.dirname(__file__), '../../data/ScanNet/scannet_test_1500')), 'gt_path': 


class Scannet1500:
    default_config = {
        'scannet_path': os.path.abspath(os.path.join(
            os.path.dirname(__file__),
            '../../data/ScanNet/scannet_test_1500')),
        'gt_path': os.path.abspath(os.path.join(
            os.path.dirname(__file__), '../../data/ScanNet/test.npz')),
        'pose_estimator': 'poselib',  # poselib, opencv
        'cache_images': True,
        'ransac_thresholds': [0.5, 1.0, 1.5, 2.0, 2.5, 3.0,
                              3.5, 4.0, 4.5, 5.0, 5.5, 6.0],
        'pose_thresholds': [5, 10, 20],
        'max_pairs': -1,
        'output': './output/scannet/',
        'n_workers': 8,
    }

    def __init__(self, config=None) -> None:
        self.config = {**self.default_config, **(config or {})}

        if not os.path.exists(self.config['scannet_path']):
            raise RuntimeError(
                f"Dataset {self.config['scannet_path']} does not exist!\n"
                "> If you didn't download the dataset, use the downloader "
                "tool: python3 -m modules.dataset.download -h")

        self.pairs = self.read_gt()
        if self.config['max_pairs'] > 0:
            # limit the number of evaluated pairs (default -1 keeps all)
            self.pairs = self.pairs[:self.config['max_pairs']]
        os.makedirs(self.config['output'], exist_ok=True)

        if self.config['n_workers'] == -1:
            self.config['n_workers'] = mp.cpu_count()

        self.image_cache = {}
        if self.config['cache_images']:
            self.load_images()

    def load_images(self):
        for pair in tqdm(self.pairs, desc='Caching images'):
            if pair['image0'] not in self.image_cache:
                self.image_cache[pair['image0']] = cv2.imread(pair['image0'])
            if pair['image1'] not in self.image_cache:
                self.image_cache[pair['image1']] = cv2.imread(pair['image1'])

    def read_image(self, path):
        if self.config['cache_images']:
            return self.image_cache[path]
        else:
            return cv2.imread(path)

    def read_gt(self):
        pairs = []
        gt_poses = np.load(self.config['gt_path'])
        names = gt_poses['name']
        for i in range(len(names)):
            scene_id = names[i, 0]
            scene_idx = names[i, 1]
            scene = f'scene{scene_id:04d}_{scene_idx:02d}'
            image0 = str(int(names[i, 2]))
            image1 = str(int(names[i, 3]))
            K0 = np.loadtxt(os.path.join(
                self.config['scannet_path'], 'scannet_test_1500', scene,
                'intrinsic/intrinsic_color.txt'))
            K1 = K0
            pose_0 = np.loadtxt(os.path.join(
                self.config['scannet_path'], 'scannet_test_1500', scene,
                'pose', image0 + '.txt'))
            pose_1 = np.loadtxt(os.path.join(
                self.config['scannet_path'], 'scannet_test_1500', scene,
                'pose', image1 + '.txt'))
            T_0to1 = get_relative_transform(pose_0, pose_1)
            pairs.append({
                'image0': os.path.join(
                    self.config['scannet_path'], 'scannet_test_1500', scene,
                    'color', image0 + '.jpg'),
                'image1': os.path.join(
                    self.config['scannet_path'], 'scannet_test_1500', scene,
                    'color', image1 + '.jpg'),
                'K0': K0,
                'K1': K1,
                'T_0to1': T_0to1,
            })
        return pairs

    def extract_and_save_matches(self, matcher_fn, name='', force=False):
        all_matches = []
        if name == '':
            name = matcher_fn.__name__
        fname = os.path.join(self.config['output'], f'{name}_matches.npz')
        if not force and os.path.exists(fname):
            return np.load(fname, allow_pickle=True)['all_matches']
        for pair in tqdm(self.pairs, desc='Extracting matches'):
            image0 = self.read_image(pair['image0'])
            image1 = self.read_image(pair['image1'])
            mkpts0, mkpts1 = matcher_fn(image0, image1)
            all_matches.append({
                'image0': pair['image0'],
                'image1': pair['image1'],
                'mkpts0': mkpts0,
                'mkpts1': mkpts1,
            })
        np.savez(fname, all_matches=all_matches)
        return all_matches
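
    # run_benchmark caches per-pair errors in
    # '{name}_{pose_estimator}_{ransac_thresh}.txt' (one "err_t err_R" line
    # per pair) and only recomputes a threshold when its cache file is
    # missing, incomplete, or force=True.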
    def run_benchmark(self, matcher_fn, name='', force=False):
        if name == '':
            name = matcher_fn.__name__
        all_matches = self.extract_and_save_matches(
            matcher_fn, name=name, force=force)

        aucs_by_thresh = {}
        accs_by_thresh = {}
        for ransac_thresh in self.config['ransac_thresholds']:
            fname = os.path.join(
                self.config['output'],
                f'{name}_{self.config["pose_estimator"]}_{ransac_thresh}.txt')

            # check if the cache exists and has the right number of lines
            if not force and os.path.exists(fname) and \
                    len(open(fname, 'r').readlines()) == len(self.pairs):
                errors = []
                with open(fname, 'r') as f:
                    for line in f.readlines():
                        err_t, err_R = line.strip().split(' ')
                        errors.append([float(err_t), float(err_R)])
            # redo the benchmark
            else:
                errors = []
                pairs = self.pairs
                errors_file = open(fname, 'w')

                # run the benchmark in parallel
                if self.config['n_workers'] != 1:
                    pool_args = [
                        (all_matches[pair_idx]['mkpts0'],
                         all_matches[pair_idx]['mkpts1'],
                         pair['K0'], pair['K1'], ransac_thresh,
                         0.99999, self.config['pose_estimator'])
                        for pair_idx, pair in enumerate(pairs)
                    ]
                    with mp.Pool(self.config['n_workers']) as pool:
                        results = list(tqdm(
                            pool.imap(estimate_pose_parallel, pool_args),
                            total=len(pool_args),
                            desc=f'Running benchmark for th={ransac_thresh}',
                            leave=False))

                    for pair_idx, ret in enumerate(results):
                        if ret is None:
                            err_t, err_R = np.inf, np.inf
                        else:
                            R, t, inliers = ret
                            pair = pairs[pair_idx]
                            err_t, err_R = compute_pose_error(
                                pair['T_0to1'], R, t)
                        errors_file.write(f'{err_t} {err_R}\n')
                        errors.append([err_t, err_R])
                # run the benchmark serially
                else:
                    for pair_idx, pair in tqdm(
                            enumerate(pairs),
                            desc=f'Running benchmark for th={ransac_thresh}',
                            leave=False, total=len(pairs)):
                        mkpts0 = all_matches[pair_idx]['mkpts0']
                        mkpts1 = all_matches[pair_idx]['mkpts1']
                        ret = estimate_pose(
                            mkpts0, mkpts1, pair['K0'], pair['K1'],
                            ransac_thresh,
                            estimator=self.config['pose_estimator'])
                        if ret is None:
                            err_t, err_R = np.inf, np.inf
                        else:
                            R, t, inliers = ret
                            err_t, err_R = compute_pose_error(
                                pair['T_0to1'], R, t)
                        errors_file.write(f'{err_t} {err_R}\n')
                        errors_file.flush()
                        errors.append([err_t, err_R])

                errors_file.close()

            # compute AUCs/accuracies over the worst of the two errors
            errors = np.array(errors)
            errors = errors.max(axis=1)
            aucs = pose_auc(errors, self.config['pose_thresholds'])
            accs = pose_accuracy(errors, self.config['pose_thresholds'])
            aucs = {k: v * 100 for k, v in
                    zip(self.config['pose_thresholds'], aucs)}
            accs = {k: v for k, v in
                    zip(self.config['pose_thresholds'], accs)}
            aucs_by_thresh[ransac_thresh] = aucs
            accs_by_thresh[ransac_thresh] = accs

        # dump summary for this method
        summary = {
            'name': name,
            'aucs_by_thresh': aucs_by_thresh,
            'accs_by_thresh': accs_by_thresh,
        }
        with open(os.path.join(
                self.config['output'],
                f'{name}_{self.config["pose_estimator"]}_summary.json'),
                'w') as f:
            json.dump(summary, f, indent=2)

        return aucs_by_thresh


def get_xfeat():
    from modules.xfeat import XFeat
    xfeat = XFeat()
    return xfeat.match_xfeat


def get_xfeat_star():
    from modules.xfeat import XFeat
    xfeat = XFeat(top_k=10_000)
    return xfeat.match_xfeat_star


def get_alike():
    from third_party import alike_wrapper as alike
    return alike.match_alike


def print_fancy(d):
    print(json.dumps(d, indent=2))


def parse():
    parser = argparse.ArgumentParser()
    parser.add_argument("--scannet_path", type=str, required=True,
                        help="Path to the ScanNet 1500 dataset")
    parser.add_argument("--output", type=str, default="./output/scannet/",
                        help="Path to the output directory")
    parser.add_argument("--max_pairs", type=int, default=-1,
                        help="Maximum number of pairs to run the benchmark on")
    parser.add_argument("--force", action='store_true',
                        help="Force running the benchmark again")
    parser.add_argument("--pose_estimator", type=str, default='poselib',
                        help="Which pose estimator to use: poselib, opencv",
                        choices=['poselib', 'opencv'])
    parser.add_argument("--show", action='store_true',
                        help="Show the results")
    parser.add_argument("--accuracy", action='store_true',
                        help="Show the accuracy instead of AUC")
    parser.add_argument("--filter", type=str, nargs='+',
                        help="Filter the results by the given names")
    return parser.parse_args()
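
# Illustrative sketch of a custom matcher that can plug into run_benchmark:
# any callable taking two BGR images and returning two Nx2 arrays of matched
# pixel coordinates (mkpts0, mkpts1) works. The ORB pipeline below is an
# assumed example for demonstration only, not one of the benchmarked methods.
def get_orb_example():
    orb = cv2.ORB_create(nfeatures=4096)
    bf = cv2.BFMatcher(cv2.NORM_HAMMING, crossCheck=True)

    def match_orb(image0, image1):
        gray0 = cv2.cvtColor(image0, cv2.COLOR_BGR2GRAY)
        gray1 = cv2.cvtColor(image1, cv2.COLOR_BGR2GRAY)
        kp0, des0 = orb.detectAndCompute(gray0, None)
        kp1, des1 = orb.detectAndCompute(gray1, None)
        if des0 is None or des1 is None:
            return np.zeros((0, 2)), np.zeros((0, 2))
        matches = bf.match(des0, des1)
        if len(matches) == 0:
            return np.zeros((0, 2)), np.zeros((0, 2))
        mkpts0 = np.float32([kp0[m.queryIdx].pt for m in matches])
        mkpts1 = np.float32([kp1[m.trainIdx].pt for m in matches])
        return mkpts0, mkpts1

    return match_orb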
if __name__ == "__main__":
    args = parse()

    if not args.show:
        scannet = Scannet1500({
            'scannet_path': args.scannet_path,
            'gt_path': args.scannet_path + "/test.npz",
            'cache_images': False,
            'output': args.output,
            'max_pairs': args.max_pairs,
            'pose_estimator': args.pose_estimator,
            'ransac_thresholds': [0.5, 1.0, 1.5, 2.0, 2.5, 3.0,
                                  3.5, 4.0, 4.5, 5.0, 5.5, 6.0],
            'n_workers': 8,
        })

        functions = {
            'xfeat': get_xfeat(),
            'xfeat_star': get_xfeat_star(),
            'alike': get_alike(),
        }

        # save all results to a file
        all_results = {}
        for name, fn in functions.items():
            print(name)
            result = scannet.run_benchmark(matcher_fn=fn, name=name,
                                           force=args.force)
            all_results[name] = result

        with open(os.path.join(args.output, 'summary.json'), 'w') as f:
            json.dump(all_results, f, indent=2)

    if args.show:
        import glob

        import pandas as pd

        dataset_name = 'scannet'
        all_summary_files = glob.glob(
            os.path.join(args.output, "**_summary.json"), recursive=True)
        if args.filter:
            all_summary_files = [
                f for f in all_summary_files
                if any(fil in f for fil in args.filter)]

        dfs = []
        names = []
        estimators = []

        metric_key = 'aucs_by_thresh'
        if args.accuracy:
            metric_key = 'accs_by_thresh'

        for summary_path in all_summary_files:
            summary_data = json.load(open(summary_path, 'r'))
            if metric_key not in summary_data:
                continue
            aucs_by_thresh = summary_data[metric_key]
            estimator = 'poselib'
            if 'opencv' in summary_path:
                estimator = 'opencv'

            # make sure everything is float
            for thresh in aucs_by_thresh:
                for k in aucs_by_thresh[thresh]:
                    if isinstance(aucs_by_thresh[thresh][k], str):
                        aucs_by_thresh[thresh][k] = float(
                            aucs_by_thresh[thresh][k].replace(' ', ''))

            # one row per RANSAC threshold, one column per pose threshold
            df = pd.DataFrame(aucs_by_thresh).T.astype(float)
            df['mean'] = df.mean(axis=1)
            cols = df.columns.tolist()
            dfs.append(df)
            names.append(summary_data['name'])
            estimators.append(estimator)

        # pick the best RANSAC threshold per method based on the mean over
        # the 5/10/20 degree pose thresholds
        col = 'mean'
        final_df = pd.DataFrame()
        final_df['name'] = names
        final_df['best_thresh'] = ''
        final_df['estimator'] = estimators
        final_df[cols] = -1.0

        for df, name, estimator in zip(dfs, names, estimators):
            best_thresh = df[col].idxmax()
            best_results = df.loc[best_thresh]
            mask = (final_df['name'] == name) & \
                   (final_df['estimator'] == estimator)
            final_df.loc[mask, 'best_thresh'] = best_thresh
            for _col in cols:
                final_df.loc[mask, _col] = best_results[_col]

        # sort by mean, reset index, drop the estimator column, round to 1
        # decimal place
        final_df = final_df.sort_values(by=['mean'])
        final_df = final_df.reset_index(drop=True)
        final_df = final_df.drop(columns=['estimator'])
        final_df = final_df.round(1)

        print(f"Dataset: {dataset_name}")
        print(f"Sorting by {col}")
        print(final_df)
        print()
        final_df.to_csv(os.path.join(
            args.output, f"{dataset_name}_{col}.csv"), index=False)
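
# Example invocations (the script filename and dataset path are illustrative
# and depend on where the dataset was downloaded):
#
#   python3 scannet_benchmark.py --scannet_path ./data/ScanNet/scannet_test_1500
#   python3 scannet_benchmark.py --scannet_path ./data/ScanNet/scannet_test_1500 \
#       --show --accuracy --filter xfeat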