File size: 3,830 Bytes
e6ac593
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import concurrent.futures

import torch


class ConcurrentMatcher:
    """Match descriptors per batch element and run geometric filtering in a thread pool.

    For every batch element ``b`` the descriptors selected by ``selected_mask1[b]`` /
    ``selected_mask2[b]`` are matched with ``matcher``. The matches are then filtered by
    ``robust_estimator``, which is submitted to a ``concurrent.futures.ThreadPoolExecutor``
    so that several batch elements can be estimated concurrently.

    The executor is created in ``__init__``; call :meth:`close` (or use the instance as a
    context manager) to release its worker threads once the matcher is no longer needed.

    Args:
        matcher (callable): Called as ``matcher(desc1, desc2)``; must return
            ``(dists, idx_matches)`` where ``idx_matches[:, 0]`` indexes ``desc1`` and
            ``idx_matches[:, 1]`` indexes ``desc2``.
        robust_estimator (callable): Called as ``robust_estimator(mkpts1, mkpts2, inl_th)``;
            must return ``(model, inliers)`` where ``inliers`` is an inlier mask over the
            matches.
        min_num_matches (int, optional): Minimum number of matches required to perform
            geometric filtering. Defaults to 8.
        max_workers (int, optional): Maximum number of threads in the thread pool executor.
            Defaults to 12.
        min_num_kpts (int, optional): Minimum number of selected keypoints required in
            *both* images before matching is attempted at all. Defaults to 16.
    """

    def __init__(self, matcher, robust_estimator, min_num_matches=8, max_workers=12, min_num_kpts=16):
        self.matcher = matcher
        self.robust_estimator = robust_estimator
        self.min_num_matches = min_num_matches
        self.min_num_kpts = min_num_kpts

        self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=max_workers)

    def close(self, wait=True):
        """Shut down the thread pool.

        After closing, any call that needs to submit estimation work raises ``RuntimeError``
        (``ThreadPoolExecutor.submit`` refuses new work after ``shutdown``).

        Args:
            wait (bool, optional): Block until already-submitted estimations finish.
                Defaults to True.
        """
        self.executor.shutdown(wait=wait)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()
        return False

    def _estimate(self, mkpts1, mkpts2, inl_th):
        """Run ``robust_estimator`` inside a worker thread with autograd disabled.

        Grad mode is thread-local in PyTorch, so the ``@torch.no_grad()`` on ``__call__``
        does not apply to the executor's worker threads; it is re-entered here.
        """
        with torch.no_grad():
            return self.robust_estimator(mkpts1, mkpts2, inl_th)

    @torch.no_grad()
    def __call__(
        self,
        kpts1,
        kpts2,
        pdesc1,
        pdesc2,
        selected_mask1,
        selected_mask2,
        inl_th,
        label=None,
    ):
        """Match every batch element and geometrically filter the matches.

        Args:
            kpts1, kpts2: Per-batch keypoints, indexed as ``kpts1[b][i]`` with ``i`` an
                absolute keypoint index (presumably (B, N, 2) coordinates — confirm with caller).
            pdesc1, pdesc2: Per-batch descriptors; ``pdesc1.shape[0]`` is the batch size ``B``.
            selected_mask1, selected_mask2: Per-batch boolean masks selecting which
                descriptors take part in matching.
            inl_th: Inlier threshold forwarded unchanged to ``robust_estimator``.
            label (optional): Per-batch pair labels. For pairs with ``label[b] == 0``
                (negative pairs) geometric filtering is skipped and every match is kept as an
                inlier, which enforces more discriminative descriptors.

        Returns:
            tuple: Four lists of length ``B``. Every entry is ``None`` for batch elements with
            fewer than ``min_num_kpts`` selected keypoints in either image.

            - ``batch_rel_idx_matches[b]``: matches indexing the *masked* descriptor subsets
              (a copy of the matcher output).
            - ``batch_idx_matches[b]``: the same matches mapped to absolute indices into
              ``pdesc1[b]`` / ``pdesc2[b]`` (and ``kpts1[b]`` / ``kpts2[b]``).
            - ``batch_ransac_inliers[b]``: inlier mask; all False when there are fewer than
              ``min_num_matches`` matches, all True for negative pairs.
            - ``batch_Fm[b]``: model returned by ``robust_estimator``; ``None`` when no
              estimation was run or when it found no inliers.
        """
        dev = pdesc1.device
        B = pdesc1.shape[0]

        batch_rel_idx_matches = [None] * B
        batch_idx_matches = [None] * B
        # Per batch element: None (skipped), a ready (Fm, inliers) tuple, or a pending Future.
        pending = [None] * B

        for b in range(B):
            mask1 = selected_mask1[b]
            mask2 = selected_mask2[b]
            if mask1.sum() < self.min_num_kpts or mask2.sum() < self.min_num_kpts:
                continue

            _, idx_matches = self.matcher(pdesc1[b][mask1], pdesc2[b][mask2])
            batch_rel_idx_matches[b] = idx_matches.clone()

            # Map indices relative to the masked subsets back to ABSOLUTE positions.
            # as_tuple=True yields 1-D index tensors, so a single match stays shape (1,).
            idx_matches[:, 0] = torch.nonzero(mask1, as_tuple=True)[0][idx_matches[:, 0]]
            idx_matches[:, 1] = torch.nonzero(mask2, as_tuple=True)[0][idx_matches[:, 1]]
            batch_idx_matches[b] = idx_matches

            num_matches = idx_matches.shape[0]

            # Not enough matches for geometric filtering: mark everything as outlier.
            if num_matches < self.min_num_matches:
                pending[b] = (None, torch.zeros(num_matches, dtype=torch.bool, device=dev))
                continue

            # Use label information to exclude negative pairs from geometric filtering
            # -> enforces more discriminative descriptors.
            if label is not None and label[b] == 0:
                pending[b] = (None, torch.ones(num_matches, dtype=torch.bool, device=dev))
                continue

            mkpts1 = kpts1[b][idx_matches[:, 0]]
            mkpts2 = kpts2[b][idx_matches[:, 1]]
            pending[b] = self.executor.submit(self._estimate, mkpts1, mkpts2, inl_th)

        batch_ransac_inliers = [None] * B
        batch_Fm = [None] * B

        for b, item in enumerate(pending):
            if item is None:
                continue
            if isinstance(item, tuple):
                Fm, ransac_inliers = item
            else:
                Fm, ransac_inliers = item.result()
                if ransac_inliers.sum() == 0:
                    # kornia.geometry.ransac.RANSAC returns an (N, 1) tensor if there are no
                    # inliers and an (N,) tensor otherwise; normalize to (N,).
                    ransac_inliers = ransac_inliers.squeeze(-1)
                    Fm = None

            batch_ransac_inliers[b] = ransac_inliers
            batch_Fm[b] = Fm

        return batch_rel_idx_matches, batch_idx_matches, batch_ransac_inliers, batch_Fm