Spaces:
Runtime error
Runtime error
| # Copyright (C) 2024-present Naver Corporation. All rights reserved. | |
| # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). | |
| # | |
| # -------------------------------------------------------- | |
| # utilities needed to load image pairs | |
| # -------------------------------------------------------- | |
| import numpy as np | |
| import torch | |
def _parse_winsize(scene_graph, default=3):
    """Return the window size encoded in a ``'<name>-<winsize>[-...]'`` string.

    Falls back to ``default`` when the second ``-``-separated field is missing
    or is not an integer (e.g. ``'swin'`` or ``'swin-noncyclic'``).
    """
    try:
        return int(scene_graph.split('-')[1])
    except (IndexError, ValueError):
        return default


def make_pairs(imgs, scene_graph='complete', prefilter=None, symmetrize=True):
    """Build the list of image pairs to process according to a scene graph.

    Args:
        imgs: sequence of images (any objects); pairs are tuples of its elements.
        scene_graph: pairing strategy:
            - ``'complete'``: every pair ``(imgs[i], imgs[j])`` with ``j < i``.
            - ``'swin[-W][-noncyclic]'``: pair each image with the next ``W``
              images (default 3). Indices wrap around (loop closure) unless the
              string ends with ``'noncyclic'``.
            - ``'logwin[-W][-noncyclic]'``: pair each image with the images at
              offsets ``±1, ±2, ±4, ..., ±2**(W-1)`` (default W=3), wrapping
              around unless the string ends with ``'noncyclic'``.
            - ``'oneref[-R]'``: pair image ``R`` (default 0) with every other image.
            Any other value produces no pairs.
            An image is never paired with itself, and ``swin``/``logwin`` pairs
            are deduplicated.
        prefilter: optional ``'seqN'`` / ``'cycN'`` string. It drops pairs whose
            ``'idx'`` entries are more than ``N`` apart, measured linearly or
            cyclically. Each image must then support ``img['idx']``.
        symmetrize: if True, append the reversed version of every pair.

    Returns:
        list of 2-tuples of elements of ``imgs``.
    """
    n_imgs = len(imgs)
    if scene_graph == 'complete':  # complete graph
        pairs = [(imgs[i], imgs[j]) for i in range(n_imgs) for j in range(i)]
    elif scene_graph.startswith('swin'):
        iscyclic = not scene_graph.endswith('noncyclic')
        winsize = _parse_winsize(scene_graph)
        pairsid = set()
        for i in range(n_imgs):
            for j in range(1, winsize + 1):
                idx = i + j
                if iscyclic:
                    idx = idx % n_imgs  # explicit loop closure
                # When winsize >= n_imgs the cyclic wrap can land back on i:
                # never pair an image with itself (same guard as 'logwin').
                if idx >= n_imgs or idx == i:
                    continue
                pairsid.add((i, idx) if i < idx else (idx, i))
        pairs = [(imgs[i], imgs[j]) for i, j in pairsid]
    elif scene_graph.startswith('logwin'):
        iscyclic = not scene_graph.endswith('noncyclic')
        winsize = _parse_winsize(scene_graph)
        offsets = [2**k for k in range(winsize)]
        pairsid = set()
        for i in range(n_imgs):
            ixs_l = [i - off for off in offsets]
            ixs_r = [i + off for off in offsets]
            for j in ixs_l + ixs_r:
                if iscyclic:
                    j = j % n_imgs  # explicit loop closure
                if j < 0 or j >= n_imgs or j == i:
                    continue
                pairsid.add((i, j) if i < j else (j, i))
        pairs = [(imgs[i], imgs[j]) for i, j in pairsid]
    elif scene_graph.startswith('oneref'):
        refid = int(scene_graph.split('-')[1]) if '-' in scene_graph else 0
        pairs = [(imgs[refid], imgs[j]) for j in range(n_imgs) if j != refid]
    else:
        # Unrecognized scene graph: keep the historical behavior of producing no pairs.
        pairs = []

    if symmetrize:
        pairs += [(img2, img1) for img1, img2 in pairs]

    # now, remove edges
    if isinstance(prefilter, str):
        if prefilter.startswith('seq'):
            pairs = filter_pairs_seq(pairs, int(prefilter[3:]))
        elif prefilter.startswith('cyc'):
            pairs = filter_pairs_seq(pairs, int(prefilter[3:]), cyclic=True)
    return pairs
def sel(x, kept):
    """Recursively select the entries at indices ``kept`` from a (nested) batch.

    Args:
        x: how ``x`` is handled depends on its type:
            - a dict: recursed into value by value;
            - a ``torch.Tensor`` / ``np.ndarray``: indexed with ``x[kept]``;
            - a tuple/list: rebuilt with the same type from ``x[k]`` for ``k`` in ``kept``;
            - any other value: returned unchanged.
        kept: sequence of integer indices to keep.

    Returns:
        an object with the same structure as ``x`` holding only the kept entries.
    """
    if isinstance(x, dict):
        return {k: sel(v, kept) for k, v in x.items()}
    if isinstance(x, (torch.Tensor, np.ndarray)):
        return x[kept]
    if isinstance(x, (tuple, list)):
        return type(x)([x[k] for k in kept])
    # Non-indexable leaves (scalars, strings, None, ...) are kept as-is rather
    # than silently replaced by None.
    return x
| def _filter_edges_seq(edges, seq_dis_thr, cyclic=False): | |
| # number of images | |
| n = max(max(e) for e in edges) + 1 | |
| kept = [] | |
| for e, (i, j) in enumerate(edges): | |
| dis = abs(i - j) | |
| if cyclic: | |
| dis = min(dis, abs(i + n - j), abs(i - n - j)) | |
| if dis <= seq_dis_thr: | |
| kept.append(e) | |
| return kept | |
def filter_pairs_seq(pairs, seq_dis_thr, cyclic=False):
    """Keep only the image pairs whose ``'idx'`` entries are close in sequence order.

    Args:
        pairs: list of ``(img1, img2)`` tuples; each image must support ``img['idx']``.
        seq_dis_thr: maximum allowed index distance (inclusive).
        cyclic: if True, measure the index distance cyclically.

    Returns:
        the kept pairs, in their original order.
    """
    index_edges = [(first['idx'], second['idx']) for first, second in pairs]
    kept_positions = _filter_edges_seq(index_edges, seq_dis_thr, cyclic=cyclic)
    return [pairs[pos] for pos in kept_positions]
def filter_edges_seq(view1, view2, pred1, pred2, seq_dis_thr, cyclic=False):
    """Drop batch entries whose two views are more than ``seq_dis_thr`` frames apart.

    Args:
        view1, view2: views whose ``'idx'`` entries are iterated in lockstep and
            converted with ``int``, one entry per batch element.
        pred1, pred2: presumably prediction dicts batched like the views — verify
            against callers; they are subset with ``sel`` using the same positions.
        seq_dis_thr: maximum allowed frame distance (inclusive).
        cyclic: if True, measure the frame distance cyclically.

    Returns:
        tuple ``(view1, view2, pred1, pred2)``, each restricted to the kept entries.
        A progress line with the kept/total edge counts is printed first.
    """
    edges = [(int(a), int(b)) for a, b in zip(view1['idx'], view2['idx'])]
    kept = _filter_edges_seq(edges, seq_dis_thr, cyclic=cyclic)
    print(f'>> Filtering edges more than {seq_dis_thr} frames apart: kept {len(kept)}/{len(edges)} edges')
    return tuple(sel(item, kept) for item in (view1, view2, pred1, pred2))