# Semi-dense image matching, keypoint aggregation and match assignment
# (results are stored in hloc-style HDF5 files).
import os | |
import shutil | |
from tqdm import tqdm | |
import numpy as np | |
import h5py | |
import torch | |
from pathlib import Path | |
from typing import Dict, Iterable, Optional, List, Tuple, Union, Set | |
import pprint | |
import argparse | |
import torchvision.transforms.functional as F | |
from types import SimpleNamespace | |
from collections import defaultdict | |
from scipy.spatial import KDTree | |
from collections import Counter | |
from itertools import chain | |
from . import matchers, logger | |
from .utils.base_model import dynamic_load | |
from .utils.parsers import parse_retrieval, names_to_pair | |
from .match_features import find_unique_new_pairs | |
from .extract_features import read_image, resize_image | |
from .utils.io import list_h5_names | |
# Matcher configurations, keyed by the name accepted by the `--conf` option.
confs = dict(
    gim_dkm=dict(
        # Used to build the default names of the output files.
        output='matches-gim',
        # Passed to the matcher loaded via `dynamic_load`.
        model=dict(name='dkm', weights='gim_dkm_100h.ckpt'),
        # Overrides for `ImagePairDataset.default_conf`.
        preprocessing=dict(grayscale=False, resize_max=None, dfactor=1),
        max_error=2,  # max error for assigned keypoints (in px)
        cell_size=8,  # size of quantization patch (max 1 kp/patch)
    ),
)
def to_cpts(kpts, ps):
    """Turn an array of keypoints into a list of hashable coordinate tuples.

    When ``ps`` is positive the coordinates are quantized first: they are
    shifted by +0.5, snapped to the nearest multiple of ``ps``, shifted back
    by -0.5 and rounded to two decimals. With ``ps <= 0`` the coordinates
    are returned unchanged.

    Args:
        kpts: array of keypoints, one row per keypoint.
        ps: quantization step (patch size).

    Returns:
        A list with one tuple per row of ``kpts``.
    """
    if ps > 0.0:
        cell_index = np.round((kpts + 0.5) / ps)
        kpts = np.round(cell_index * ps - 0.5, 2)
    return list(map(tuple, kpts))
def assign_keypoints(kpts: np.ndarray,
                     other_cpts: Union[List[Tuple], np.ndarray],
                     max_error: float,
                     update: bool = False,
                     ref_bins: Optional[List[Counter]] = None,
                     scores: Optional[np.ndarray] = None,
                     cell_size: Optional[int] = None):
    """Associate each keypoint with the id of a reference keypoint.

    Without ``update`` this is a nearest-neighbour search of ``kpts`` in
    ``other_cpts``; keypoints whose nearest reference is farther than
    ``max_error`` get the id -1.

    With ``update`` the keypoints are quantized with ``to_cpts`` using a
    patch size of ``max(cell_size or max_error, max_error)``. Quantized
    points that are not yet in ``other_cpts`` are appended to it (the list
    is modified in place) and receive a new id. If ``ref_bins`` is given,
    it is extended in lock-step with ``other_cpts`` and every keypoint
    votes for its finer position (quantized with ``int(max_error)``) in the
    Counter of its cell, weighted by ``scores[i]`` (or 1 without scores).

    Args:
        kpts: array of keypoints, one row per keypoint.
        other_cpts: reference keypoints; must be a list when ``update``.
        max_error: maximum accepted distance (no update) / fine bin size.
        update: whether to extend ``other_cpts`` with unseen keypoints.
        ref_bins: optional per-reference-keypoint vote counters.
        scores: optional per-keypoint vote weights.
        cell_size: optional quantization patch size.

    Returns:
        An integer array with one reference id per row of ``kpts``.
    """
    if not update:
        # Nothing to search (or nothing to search for): every keypoint is
        # unassigned. The result must still hold one entry per keypoint so
        # that it can be combined element-wise with other id arrays.
        if len(other_cpts) == 0 or len(kpts) == 0:
            return np.full(len(kpts), -1, dtype=np.int64)
        # Without update this is just a NN search
        dist, kpt_ids = KDTree(np.array(other_cpts)).query(kpts)
        valid = (dist <= max_error)
        kpt_ids[~valid] = -1
        return kpt_ids

    ps = cell_size if cell_size is not None else max_error
    ps = max(ps, max_error)
    # With update we quantize and bin (optionally)
    assert isinstance(other_cpts, list)
    kpt_ids = []
    cpts = to_cpts(kpts, ps)
    bpts = to_cpts(kpts, int(max_error))
    cp_to_id = {val: i for i, val in enumerate(other_cpts)}
    for i, (cpt, bpt) in enumerate(zip(cpts, bpts)):
        try:
            kid = cp_to_id[cpt]
        except KeyError:
            # First time this cell is seen: register it as a new keypoint.
            kid = len(cp_to_id)
            cp_to_id[cpt] = kid
            other_cpts.append(cpt)
            if ref_bins is not None:
                ref_bins.append(Counter())
        if ref_bins is not None:
            score = scores[i] if scores is not None else 1
            ref_bins[kid][bpt] += score
        kpt_ids.append(kid)
    # Explicit dtype so that an empty input still yields integer ids.
    return np.array(kpt_ids, dtype=np.int64)
def get_grouped_ids(array):
    """Group the indices of ``array`` by the value found at each index.

    Returns a list of index arrays, one per distinct value (in ascending
    value order); indices of duplicated values end up in the same group.
    """
    order = np.argsort(array)
    # Position, within the sorted array, of the first occurrence of each
    # distinct value: these are the boundaries between groups.
    _, first_pos = np.unique(array[order], return_index=True)
    return np.split(order, first_pos[1:])
def get_unique_matches(match_ids, scores):
    """Keep only the matches that are the best-scoring one on both sides.

    For every distinct id in column 0, and every distinct id in column 1,
    the row with the highest score is selected; a row survives only if it
    was selected for both of its ids.

    Args:
        match_ids: array of id pairs, one row per match.
        scores: array with one score per match.

    Returns:
        ``(match_ids[keep], scores[keep])``. For a 1-D ``match_ids`` the
        list ``[0]`` is returned instead.
    """
    if match_ids.ndim == 1:
        return [0]

    def best_rows(column):
        # Row index of the top-scoring match for each distinct id.
        winners = []
        for members in get_grouped_ids(column):
            if len(members) > 0:
                winners.append(members[scores[members].argmax()])
        return winners

    uid1s = best_rows(match_ids[:, 0])
    uid2s = best_rows(match_ids[:, 1])
    keep = list(set(uid1s).intersection(uid2s))
    return match_ids[keep], scores[keep]
def matches_to_matches0(matches, scores):
    """Convert a list of index pairs into dense per-keypoint arrays.

    Args:
        matches: array of ``(id0, id1)`` rows.
        scores: one score per row of ``matches``.

    Returns:
        ``(matches0, scores0)``: ``matches0[id0]`` is the matched ``id1``
        (int32, -1 where unmatched) and ``scores0[id0]`` the corresponding
        score (float16, 0 where unmatched). Both have length
        ``max(id0) + 1``, or 0 when there are no matches.
    """
    if len(matches) == 0:
        return np.zeros(0, dtype=np.int32), np.zeros(0, dtype=np.float16)
    ids0, ids1 = matches[:, 0], matches[:, 1]
    n_kps0 = np.max(ids0) + 1
    matches0 = np.full((n_kps0,), -1.0)
    scores0 = np.zeros((n_kps0,))
    matches0[ids0] = ids1
    scores0[ids0] = scores
    return matches0.astype(np.int32), scores0.astype(np.float16)
def kpids_to_matches0(kpt_ids0, kpt_ids1, scores):
    """Build dense match arrays from two per-match keypoint id arrays.

    Entries where either id is -1 are dropped, the remaining id pairs are
    reduced with ``get_unique_matches`` and then converted with
    ``matches_to_matches0``.

    Args:
        kpt_ids0: keypoint id in the first image for each match.
        kpt_ids1: keypoint id in the second image for each match.
        scores: score of each match.

    Returns:
        The ``(matches0, scores0)`` pair produced by ``matches_to_matches0``.
    """
    assigned = ~((kpt_ids0 == -1) | (kpt_ids1 == -1))
    id_pairs = np.stack([kpt_ids0[assigned], kpt_ids1[assigned]], axis=-1)
    kept_scores = scores[assigned]
    # Remove n-to-1 matches
    id_pairs, kept_scores = get_unique_matches(id_pairs, kept_scores)
    return matches_to_matches0(id_pairs, kept_scores)
def scale_keypoints(kpts, scale):
    """Multiply keypoints by ``scale`` in place and return them.

    ``kpts`` is left untouched when every entry of ``scale`` equals 1.
    The scale is converted with ``kpts.new_tensor`` before multiplying.
    """
    if not np.any(scale != 1.0):
        return kpts
    kpts *= kpts.new_tensor(scale)
    return kpts
class ImagePairDataset(torch.utils.data.Dataset):
    """Dataset yielding preprocessed image pairs together with their scales.

    Each item is ``(image0, image1, scale0, scale1, name0, name1)``.
    """

    # Defaults; entries of the `conf` passed to the constructor override them.
    default_conf = {
        'grayscale': True,
        'resize_max': 1024,
        'dfactor': 8,
        'cache_images': False,
    }

    def __init__(self, image_dir, conf, pairs):
        """Store the settings and, if requested, preload every image.

        Args:
            image_dir: directory that the image names are relative to.
            conf: overrides for ``default_conf``.
            pairs: sequence of ``(name0, name1)`` tuples; stored sorted.
        """
        self.image_dir = image_dir
        merged = dict(self.default_conf)
        merged.update(conf)
        self.conf = SimpleNamespace(**merged)
        self.pairs = sorted(pairs) if pairs else pairs
        if not self.conf.cache_images:
            return

        image_names = set(sum(pairs, ()))  # unique image names in pairs
        logger.info(
            f'Loading and caching {len(image_names)} unique images.')
        self.images = {}
        self.scales = {}
        for name in tqdm(image_names):
            raw = read_image(self.image_dir / name, self.conf.grayscale)
            self.images[name], self.scales[name] = self.preprocess(raw)

    def preprocess(self, image: np.ndarray):
        """Resize and convert an image to a float tensor in [0, 1].

        Returns:
            ``(image, scale)`` where ``image`` is a CxHxW tensor and
            ``scale`` is the original ``(width, height)`` divided by the
            final ``(width, height)``.
        """
        image = image.astype(np.float32, copy=False)
        height, width = image.shape[:2]
        size = (width, height)

        # Shrink so that the longest side does not exceed resize_max.
        if self.conf.resize_max:
            ratio = self.conf.resize_max / max(size)
            if ratio < 1.0:
                size_new = tuple(int(round(x * ratio)) for x in size)
                image = resize_image(image, size_new, 'cv2_area')

        if self.conf.grayscale:
            assert image.ndim == 2, image.shape
            image = image[None]
        else:
            image = image.transpose((2, 0, 1))  # HxWxC to CxHxW
        image = torch.from_numpy(image / 255.0).float()

        # assure that the size is divisible by dfactor
        dfactor = self.conf.dfactor
        size_new = tuple(
            int(dim // dfactor * dfactor) for dim in image.shape[-2:])
        image = F.resize(image, size=size_new)
        # size is (W, H) while size_new is (H, W), hence the flip.
        scale = np.array(size) / np.array(size_new)[::-1]
        return image, scale

    def __len__(self):
        """Return the number of pairs."""
        return len(self.pairs)

    def __getitem__(self, idx):
        """Return the preprocessed images, scales and names of pair ``idx``."""
        name0, name1 = self.pairs[idx]
        if self.conf.cache_images:
            image0, scale0 = self.images[name0], self.scales[name0]
            image1, scale1 = self.images[name1], self.scales[name1]
        else:
            # Read both images before preprocessing either of them.
            raw = [read_image(self.image_dir / name, self.conf.grayscale)
                   for name in (name0, name1)]
            (image0, scale0), (image1, scale1) = map(self.preprocess, raw)
        return image0, image1, scale0, scale1, name0, name1
def match_dense(conf: Dict,
                pairs: List[Tuple[str, str]],
                image_dir: Path,
                match_path: Path,  # out
                existing_refs: Optional[Iterable[str]] = None):
    """Run the dense matcher on every image pair and store the raw output.

    For each pair a group named ``names_to_pair(name0, name1)`` is
    (re)created in the HDF5 file at ``match_path`` with the datasets
    ``keypoints0``, ``keypoints1`` and ``scores``. Keypoints are rescaled
    with the scales returned by ``ImagePairDataset``.

    Args:
        conf: configuration; ``conf['model']`` selects and configures the
            matcher, ``conf['preprocessing']`` configures the dataset.
        pairs: image name pairs to match.
        image_dir: directory containing the images.
        match_path: HDF5 file the matches are written to (opened in
            append mode).
        existing_refs: names of images that already have keypoints. When
            ``name0`` of a pair is among them, the two images are swapped
            for the model call and the predicted keypoints swapped back.
    """
    # Build the lookup once: avoids a shared mutable default and makes the
    # per-pair membership test independent of the container passed in.
    existing_refs = set(existing_refs) if existing_refs is not None else set()

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    Model = dynamic_load(matchers, conf['model']['name'])
    model = Model(conf['model']).eval().to(device)

    dataset = ImagePairDataset(image_dir, conf["preprocessing"], pairs)
    loader = torch.utils.data.DataLoader(
        dataset, num_workers=16, batch_size=1, shuffle=False)

    logger.info("Performing dense matching...")
    with h5py.File(str(match_path), 'a') as fd:
        for data in tqdm(loader, smoothing=.1):
            # load image-pair data (batch size is 1, so unpack the names)
            image0, image1, scale0, scale1, (name0,), (name1,) = data
            scale0, scale1 = scale0[0].numpy(), scale1[0].numpy()
            image0, image1 = image0.to(device), image1.to(device)

            # match semi-dense
            # for consistency with pairs_from_*: refine kpts of image0
            if name0 in existing_refs:
                # special case: flip to enable refinement in query image
                pred = model({'image0': image1, 'image1': image0,
                              'name0': name1, 'name1': name0})
                pred = {**pred,
                        'keypoints0': pred['keypoints1'],
                        'keypoints1': pred['keypoints0']}
            else:
                # usual case: keypoints are grid-sampled on image1 and the
                # refined keypoints are predicted on image0
                pred = model({'image0': image0, 'image1': image1,
                              'name0': name0, 'name1': name1})

            # Rescale keypoints and move to cpu
            kpts0, kpts1 = pred['keypoints0'], pred['keypoints1']
            kpts0 = scale_keypoints(kpts0 + 0.5, scale0) - 0.5
            kpts1 = scale_keypoints(kpts1 + 0.5, scale1) - 0.5
            kpts0 = kpts0.cpu().numpy()
            kpts1 = kpts1.cpu().numpy()
            scores = pred['scores'].cpu().numpy()

            # Write matches and matching scores in hloc format
            pair = names_to_pair(name0, name1)
            if pair in fd:
                del fd[pair]
            grp = fd.create_group(pair)

            # Write dense matching output
            grp.create_dataset('keypoints0', data=kpts0)
            grp.create_dataset('keypoints1', data=kpts1)
            grp.create_dataset('scores', data=scores)
    del model, loader
# default: quantize all!
def load_keypoints(conf: Dict,
                   feature_paths_refs: List[Path],
                   quantize: Optional[set] = None):
    """Load the keypoints stored in a list of feature files.

    Args:
        conf: configuration; ``max_error`` and ``cell_size`` are read only
            for images that are quantized.
        feature_paths_refs: HDF5 feature files to read.
        quantize: names of the images whose keypoints are quantized and
            binned with ``assign_keypoints``; all images when ``None``.

    Returns:
        ``(cpdict, bindict)``. For images that are not quantized,
        ``cpdict[name]`` is the keypoint array as stored in the file. For
        quantized images, ``cpdict[name]`` and ``bindict[name]`` are the
        lists filled by ``assign_keypoints``.
    """
    # Map every image name to the index of the file that contains it.
    name2ref = {}
    for file_idx, path in enumerate(feature_paths_refs):
        for image_name in list_h5_names(path):
            name2ref[image_name] = file_idx

    existing_refs = set(name2ref.keys())
    if quantize is None:
        quantize = existing_refs  # quantize all
    if len(existing_refs) > 0:
        logger.info(f'Loading keypoints from {len(existing_refs)} images.')

    # Load query keypoints
    cpdict = defaultdict(list)
    bindict = defaultdict(list)
    for name in existing_refs:
        with h5py.File(str(feature_paths_refs[name2ref[name]]), 'r') as fd:
            entry = fd[name]
            kps = entry['keypoints'].__array__()
            if name not in quantize:
                cpdict[name] = kps
                continue
            if 'scores' in entry.keys():
                kp_scores = entry['scores'].__array__()
            else:
                # we set the score to 1.0 if not provided
                # increase for more weight on reference keypoints for
                # stronger anchoring
                kp_scores = [1.0] * kps.shape[0]
            # bin existing keypoints of reference images for association
            assign_keypoints(
                kps, cpdict[name], conf['max_error'], True, bindict[name],
                kp_scores, conf['cell_size'])
    return cpdict, bindict
def aggregate_matches(
        conf: Dict,
        pairs: List[Tuple[str, str]],
        match_path: Path,
        feature_path: Path,
        required_queries: Optional[Set[str]] = None,
        max_kps: Optional[int] = None,
        cpdict: Optional[Dict[str, Iterable]] = None,
        bindict: Optional[Dict[str, List[Counter]]] = None):
    """Aggregate dense matches into per-image keypoints and sparse matches.

    For every pair the raw ``keypoints0``/``keypoints1``/``scores`` stored
    in ``match_path`` are assigned to keypoint ids with
    ``assign_keypoints``, and the resulting ``matches0`` and
    ``matching_scores0`` datasets are written to the pair's group. Once
    all pairs of a required image have been processed, the most voted
    position of each of its bins becomes a keypoint, optionally only the
    ``max_kps`` best-scored ones are kept, and the keypoints are written
    to ``feature_path`` (datasets ``keypoints`` and ``score``).

    Args:
        conf: configuration providing ``max_error`` and ``cell_size``.
        pairs: image name pairs; each must have a group in ``match_path``.
        match_path: HDF5 file with the dense matches (updated in place).
        feature_path: HDF5 file the aggregated keypoints are written to.
        required_queries: images whose keypoints are to be extracted.
            Defaults to all images in ``pairs`` that are not yet in
            ``feature_path``. The passed set is not modified.
        max_kps: optional maximum number of keypoints kept per image.
        cpdict: keypoints per image. Entries given as ``np.ndarray`` are
            treated as fixed. Updated in place; a new ``defaultdict(list)``
            is created when omitted.
        bindict: vote counters per image, updated in place; a new
            ``defaultdict(list)`` is created when omitted.

    Returns:
        ``cpdict``, where the entries of the processed required images have
        been replaced by float32 keypoint arrays.
    """
    # Fresh containers per call: defaults shared between calls would carry
    # keypoints and bins over from one invocation to the next.
    if cpdict is None:
        cpdict = defaultdict(list)
    if bindict is None:
        bindict = defaultdict(list)

    if required_queries is None:
        required_queries = set(sum(pairs, ()))
        # default: do not overwrite existing features in feature_path!
        required_queries -= set(list_h5_names(feature_path))
    else:
        # Copy so that the subtraction below does not alter the caller's set.
        required_queries = set(required_queries)

    # if an entry in cpdict is provided as np.ndarray we assume it is fixed
    required_queries -= {
        k for k, v in cpdict.items() if isinstance(v, np.ndarray)}

    # sort pairs for reduced RAM
    pairs_per_q = Counter(list(chain(*pairs)))
    pairs_score = [min(pairs_per_q[i], pairs_per_q[j]) for i, j in pairs]
    pairs = [p for _, p in sorted(zip(pairs_score, pairs))]

    if len(required_queries) > 0:
        logger.info(
            f'Aggregating keypoints for {len(required_queries)} images.')
    n_kps = 0
    with h5py.File(str(match_path), 'a') as fd:
        for name0, name1 in tqdm(pairs, smoothing=.1):
            pair = names_to_pair(name0, name1)
            grp = fd[pair]
            kpts0 = grp['keypoints0'].__array__()
            kpts1 = grp['keypoints1'].__array__()
            scores = grp['scores'].__array__()

            # Aggregate local features
            update0 = name0 in required_queries
            update1 = name1 in required_queries

            # in localization we do not want to bin the query kp
            # assumes that the query is name0!
            if update0 and not update1 and max_kps is None:
                max_error0 = cell_size0 = 0.0
            else:
                max_error0 = conf['max_error']
                cell_size0 = conf['cell_size']

            # Get match ids and extend query keypoints (cpdict)
            mkp_ids0 = assign_keypoints(kpts0, cpdict[name0], max_error0,
                                        update0, bindict[name0], scores,
                                        cell_size0)
            mkp_ids1 = assign_keypoints(kpts1, cpdict[name1],
                                        conf['max_error'], update1,
                                        bindict[name1], scores,
                                        conf['cell_size'])

            # Build matches from assignments
            matches0, scores0 = kpids_to_matches0(mkp_ids0, mkp_ids1, scores)
            assert kpts0.shape[0] == scores.shape[0]

            # Replace results of a previous run instead of failing on them.
            for key in ('matches0', 'matching_scores0'):
                if key in grp:
                    del grp[key]
            grp.create_dataset('matches0', data=matches0)
            grp.create_dataset('matching_scores0', data=scores0)

            # Convert bins to kps if finished, and store them
            for name in (name0, name1):
                pairs_per_q[name] -= 1
                if pairs_per_q[name] > 0 or name not in required_queries:
                    continue
                # Each bin contributes its most voted position and the
                # accumulated score of that position.
                kp_score = [c.most_common(1)[0][1] for c in bindict[name]]
                cpdict[name] = [c.most_common(1)[0][0] for c in bindict[name]]
                cpdict[name] = np.array(cpdict[name], dtype=np.float32)

                # Select top-k query kps by score (reassign matches later)
                if max_kps:
                    top_k = min(max_kps, cpdict[name].shape[0])
                    top_k = np.argsort(kp_score)[::-1][:top_k]
                    cpdict[name] = cpdict[name][top_k]
                    kp_score = np.array(kp_score)[top_k]

                # Write query keypoints
                with h5py.File(feature_path, 'a') as kfd:
                    if name in kfd:
                        del kfd[name]
                    kgrp = kfd.create_group(name)
                    kgrp.create_dataset('keypoints', data=cpdict[name])
                    kgrp.create_dataset('score', data=kp_score)
                    n_kps += cpdict[name].shape[0]
                # The bins of a finished image are no longer needed.
                del bindict[name]

    if len(required_queries) > 0:
        avg_kp_per_image = round(n_kps / len(required_queries), 1)
        logger.info(f'Finished assignment, found {avg_kp_per_image} '
                    f'keypoints/image (avg.), total {n_kps}.')
    return cpdict
def assign_matches(
        pairs: List[Tuple[str, str]],
        match_path: Path,
        keypoints: Union[List[Path], Dict[str, np.array]],
        max_error: float):
    """Recompute the sparse matches of every pair against fixed keypoints.

    The dense keypoints stored for each pair are assigned to the nearest
    keypoint of the corresponding image (see ``assign_keypoints`` without
    update) and ``matches0`` / ``matching_scores0`` are rewritten in the
    pair's group of ``match_path``.

    Args:
        pairs: image name pairs; each must have a group in ``match_path``.
        match_path: HDF5 file with the dense matches (updated in place).
        keypoints: either a mapping from image name to keypoints, or a
            list of feature files from which the keypoints are loaded.
        max_error: maximum distance for a dense keypoint to be assigned.
    """
    if isinstance(keypoints, list):
        # load_keypoints returns (cpdict, bindict); only the keypoints are
        # needed here and nothing is quantized.
        keypoints, _ = load_keypoints({}, keypoints, quantize=set())
    missing = set(sum(pairs, ())) - set(keypoints.keys())
    assert len(missing) == 0, f'No keypoints for images: {sorted(missing)}'
    with h5py.File(str(match_path), 'a') as fd:
        for name0, name1 in tqdm(pairs):
            pair = names_to_pair(name0, name1)
            grp = fd[pair]
            kpts0 = grp['keypoints0'].__array__()
            kpts1 = grp['keypoints1'].__array__()
            scores = grp['scores'].__array__()

            # NN search across cell boundaries
            mkp_ids0 = assign_keypoints(kpts0, keypoints[name0], max_error)
            mkp_ids1 = assign_keypoints(kpts1, keypoints[name1], max_error)

            matches0, scores0 = kpids_to_matches0(mkp_ids0, mkp_ids1,
                                                  scores)

            # overwrite matches0 and matching_scores0 (if already present)
            for key in ('matches0', 'matching_scores0'):
                if key in grp:
                    del grp[key]
            grp.create_dataset('matches0', data=matches0)
            grp.create_dataset('matching_scores0', data=scores0)
def match_and_assign(conf: Dict,
                     pairs_path: Path,
                     image_dir: Path,
                     match_path: Path,  # out
                     feature_path_q: Path,  # out
                     feature_paths_refs: Optional[List[Path]] = None,
                     max_kps: Optional[int] = 8192,
                     overwrite: bool = False) -> None:
    """Match all pairs densely, then aggregate keypoints and matches.

    Steps: read the pairs from ``pairs_path``, run ``match_dense`` (or
    restore its output from ``<match_path parent>/cache/<match_path name>``
    if that file exists), load the existing keypoints, aggregate the dense
    matches with ``aggregate_matches`` and, when ``max_kps`` is set,
    reassign the matches to the selected keypoints with ``assign_matches``.

    Args:
        conf: matcher configuration (see ``confs``).
        pairs_path: file listing the image pairs.
        image_dir: directory containing the images.
        match_path: output HDF5 file for the matches.
        feature_path_q: output HDF5 file for the aggregated keypoints.
        feature_paths_refs: existing feature files whose keypoints are
            reused. The passed list is not modified.
        max_kps: maximum number of keypoints kept per image.
        overwrite: recompute pairs and keypoints that already exist.

    Raises:
        FileNotFoundError: if one of ``feature_paths_refs`` does not exist.
    """
    # Private copy: the list is extended below, which must neither persist
    # in a shared default nor leak into the caller's list.
    feature_paths_refs = list(feature_paths_refs) if feature_paths_refs else []
    for path in feature_paths_refs:
        if not path.exists():
            raise FileNotFoundError(f'Reference feature file {path}.')

    pairs = parse_retrieval(pairs_path)
    pairs = [(q, r) for q, rs in pairs.items() for r in rs]
    pairs = find_unique_new_pairs(pairs, None if overwrite else match_path)
    required_queries = set(sum(pairs, ()))

    name2ref = {n: i for i, p in enumerate(feature_paths_refs)
                for n in list_h5_names(p)}
    existing_refs = required_queries.intersection(set(name2ref.keys()))

    # images which require feature extraction
    required_queries = required_queries - existing_refs

    if feature_path_q.exists():
        existing_queries = set(list_h5_names(feature_path_q))
        feature_paths_refs.append(feature_path_q)
        existing_refs = set.union(existing_refs, existing_queries)
        if not overwrite:
            required_queries = required_queries - existing_queries

    if len(pairs) == 0 and len(required_queries) == 0:
        logger.info("All pairs exist. Skipping dense matching.")
        return

    # extract semi-dense matches, going through a copy kept in a `cache`
    # directory next to the match file
    match_cache_dir = match_path.parent / 'cache'
    match_cache_path = match_cache_dir / match_path.name
    if not match_cache_path.exists():
        match_dense(conf, pairs, image_dir, match_path,
                    existing_refs=existing_refs)
        match_cache_dir.mkdir(parents=True, exist_ok=True)
        shutil.copy(str(match_path), str(match_cache_path))
    else:
        shutil.copy(str(match_cache_path), str(match_path))

    logger.info("Assigning matches...")

    # Pre-load existing keypoints
    cpdict, bindict = load_keypoints(
        conf, feature_paths_refs,
        quantize=required_queries)

    # Reassign matches by aggregation
    cpdict = aggregate_matches(
        conf, pairs, match_path, feature_path=feature_path_q,
        required_queries=required_queries, max_kps=max_kps, cpdict=cpdict,
        bindict=bindict)

    # Invalidate matches that are far from selected bin by reassignment
    if max_kps is not None:
        logger.info(f'Reassign matches with max_error={conf["max_error"]}.')
        assign_matches(pairs, match_path, cpdict,
                       max_error=conf['max_error'])
def main(conf: Dict,
         pairs: Path,
         image_dir: Path,
         export_dir: Optional[Path] = None,
         matches: Optional[Path] = None,  # out
         features: Optional[Path] = None,  # out
         features_ref: Optional[Path] = None,
         max_kps: Optional[int] = 8192,
         overwrite: bool = False) -> Path:
    """Resolve the output paths and run ``match_and_assign``.

    Args:
        conf: matcher configuration (see ``confs``).
        pairs: file listing the image pairs.
        image_dir: directory containing the images.
        export_dir: directory for the outputs; required unless
            ``features`` is a ``Path``.
        matches: match file; derived from ``export_dir``, the
            configuration output name and the stem of ``pairs`` when
            omitted and ``features`` is not a ``Path``.
        features: either the feature file as a ``Path``, or a name prefix
            (default ``'feats_'``) for a file created in ``export_dir``.
        features_ref: a reference feature file, a list of them, or None.
        max_kps: maximum number of keypoints kept per image.
        overwrite: recompute existing pairs and keypoints.

    Returns:
        The tuple ``(features_q, matches)`` of resolved paths.

    Raises:
        ValueError: if the combination of ``features``, ``matches`` and
            ``export_dir`` does not allow resolving the output paths.
        TypeError: if ``features_ref`` has an unsupported type.
    """
    logger.info('Extracting semi-dense features with configuration:'
                f'\n{pprint.pformat(conf)}')

    if features is None:
        features = 'feats_'

    if not isinstance(features, Path):
        # `features` is a name prefix: everything goes into export_dir.
        if export_dir is None:
            raise ValueError('Provide an export_dir if features and matches'
                             f' are not file paths: {features}, {matches}.')
        features_q = Path(export_dir,
                          f'{features}{conf["output"]}.h5')
        if matches is None:
            matches = Path(
                export_dir, f'{conf["output"]}_{pairs.stem}.h5')
    else:
        features_q = features
        if matches is None:
            raise ValueError('Either provide both features and matches as Path'
                             ' or both as names.')

    # Normalise the reference features to a list of paths.
    if features_ref is None:
        ref_paths = []
    elif isinstance(features_ref, list):
        ref_paths = list(features_ref)
    elif isinstance(features_ref, Path):
        ref_paths = [features_ref]
    else:
        raise TypeError(str(features_ref))

    match_and_assign(conf, pairs, image_dir, matches,
                     features_q, ref_paths,
                     max_kps, overwrite)
    return features_q, matches
if __name__ == '__main__':
    # Command-line entry point. The defaults are derived from 'gim_dkm',
    # the configuration defined in `confs`; looking up a name that is not
    # in `confs` would raise a KeyError before any argument is parsed.
    default_conf = 'gim_dkm'
    parser = argparse.ArgumentParser()
    parser.add_argument('--pairs', type=Path, required=True)
    parser.add_argument('--image_dir', type=Path, required=True)
    parser.add_argument('--export_dir', type=Path, required=True)
    parser.add_argument('--matches', type=Path,
                        default=confs[default_conf]['output'])
    parser.add_argument('--features', type=str,
                        default='feats_' + confs[default_conf]['output'])
    parser.add_argument('--conf', type=str, default=default_conf,
                        choices=list(confs.keys()))
    args = parser.parse_args()
    main(confs[args.conf], args.pairs, args.image_dir, args.export_dir,
         args.matches, args.features)