import os
import cv2
import torch
import warnings
import numpy as np
from os.path import join
from pathlib import Path

from tools import get_padding_size
from hloc.utils import CLS_DICT, exclude
from ..utils.base_model import BaseModel
from networks.dkm.models.model_zoo.DKMv3 import DKMv3


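# Note: the class is named LoFTR, presumably so it slots into the existing
# hloc matcher interface, but the underlying network here is DKMv3.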
class LoFTR(BaseModel):
    default_conf = {
        'max_num_matches': None,
        # 'weights' (a checkpoint filename under ./weights) has no default
        # and must be supplied in the configuration; see _init below.
    }
    required_inputs = [
        'image0',
        'image1'
    ]

    def _init(self, conf):
        self.h = 672
        self.w = 896
        model = DKMv3(None, self.h, self.w, upsample_preds=True)

        checkpoints_path = join('weights', conf['weights'])
        state_dict = torch.load(checkpoints_path, map_location='cpu')
        if 'state_dict' in state_dict:
            state_dict = state_dict['state_dict']
        for k in list(state_dict.keys()):
            # Strip the 'model.' prefix from checkpoint keys and drop the
            # encoder's unused classification head.
            if k.startswith('model.'):
                state_dict[k.replace('model.', '', 1)] = state_dict.pop(k)
            if 'encoder.net.fc' in k:
                state_dict.pop(k)
        model.load_state_dict(state_dict)

        self.net = model

    def _forward(self, data):
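        # Segmentation maps (per-pixel class-id .npy arrays) are expected in a
        # 'segment' directory next to the $GIMRECONSTRUCTION output directory.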
        outputs = Path(os.environ['GIMRECONSTRUCTION'])
        segment_root = outputs / '..' / 'segment'

        # For consistency with hloc pairs, we refine kpts in image0!
        rename = {
            'keypoints0': 'keypoints1',
            'keypoints1': 'keypoints0',
            'image0': 'image1',
            'image1': 'image0',
            'mask0': 'mask1',
            'mask1': 'mask0',
            'name0': 'name1',
            'name1': 'name0',
        }
        data_ = {rename[k]: v for k, v in data.items()}
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")

            image0, image1 = data_['image0'], data_['image1']
            img0, img1 = data_['name0'], data_['name1']
            
            # Mask out excluded semantic classes in image0.
            seg_path0 = join(segment_root, '{}.npy'.format(img0[:-4]))
            mask0 = np.load(seg_path0)
            if mask0.shape[:2] != image0.shape[-2:]:
                mask0 = cv2.resize(mask0, image0.shape[-2:][::-1],
                                   interpolation=cv2.INTER_NEAREST)
            keep0 = mask0 != CLS_DICT[exclude[0]]
            for cls in exclude[1:]:
                keep0 = keep0 & (mask0 != CLS_DICT[cls])
            keep0 = torch.from_numpy(keep0.astype(np.uint8)).to(image0.device)
            image0 = image0 * keep0.float()[None, None]
            # Mask out excluded semantic classes in image1.
            seg_path1 = join(segment_root, '{}.npy'.format(img1[:-4]))
            mask1 = np.load(seg_path1)
            if mask1.shape[:2] != image1.shape[-2:]:
                mask1 = cv2.resize(mask1, image1.shape[-2:][::-1],
                                   interpolation=cv2.INTER_NEAREST)
            keep1 = mask1 != CLS_DICT[exclude[0]]
            for cls in exclude[1:]:
                keep1 = keep1 & (mask1 != CLS_DICT[cls])
            keep1 = torch.from_numpy(keep1.astype(np.uint8)).to(image1.device)
            image1 = image1 * keep1.float()[None, None]

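            # Pad both images up to the fixed network resolution (self.h x
            # self.w); the padding offsets are subtracted from the keypoints
            # further below.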
            orig_width0, orig_height0, pad_left0, pad_right0, pad_top0, pad_bottom0 = get_padding_size(image0, self.h, self.w)
            orig_width1, orig_height1, pad_left1, pad_right1, pad_top1, pad_bottom1 = get_padding_size(image1, self.h, self.w)
            image0 = torch.nn.functional.pad(image0, (pad_left0, pad_right0, pad_top0, pad_bottom0))
            image1 = torch.nn.functional.pad(image1, (pad_left1, pad_right1, pad_top1, pad_bottom1))

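            # Dense matching, then sample up to 8192 sparse correspondences
            # from the dense warp according to the predicted certainty.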
            dense_matches, dense_certainty = self.net.match(image0, image1)
            sparse_matches, mconf = self.net.sample(dense_matches, dense_certainty, 8192)

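            # Discard zero-confidence samples.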
            m = mconf > 0
            mconf = mconf[m]
            sparse_matches = sparse_matches[m]

            height0, width0 = image0.shape[-2:]
            height1, width1 = image1.shape[-2:]

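            # Sampled matches are in normalized [-1, 1] coordinates; convert
            # them to pixel coordinates in the padded images.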
            kpts0 = sparse_matches[:, :2]
            kpts0 = torch.stack((width0 * (kpts0[:, 0] + 1) / 2,
                                 height0 * (kpts0[:, 1] + 1) / 2), dim=-1, )
            kpts1 = sparse_matches[:, 2:]
            kpts1 = torch.stack((width1 * (kpts1[:, 0] + 1) / 2,
                                 height1 * (kpts1[:, 1] + 1) / 2), dim=-1, )
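            # Batch size is 1 here, so b_ids is all zeros and i_ids simply
            # enumerates the matches.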
            b_ids, i_ids = torch.where(mconf[None])

            # Undo the padding offsets, then drop matches that fall outside
            # the original (unpadded) image bounds.
            kpts0 -= kpts0.new_tensor((pad_left0, pad_top0))[None]
            kpts1 -= kpts1.new_tensor((pad_left1, pad_top1))[None]
            mask = (kpts0[:, 0] > 0) & \
                   (kpts0[:, 1] > 0) & \
                   (kpts1[:, 0] > 0) & \
                   (kpts1[:, 1] > 0)
            mask = mask & \
                   (kpts0[:, 0] <= (orig_width0 - 1)) & \
                   (kpts1[:, 0] <= (orig_width1 - 1)) & \
                   (kpts0[:, 1] <= (orig_height0 - 1)) & \
                   (kpts1[:, 1] <= (orig_height1 - 1))

            pred = {
                'keypoints0': kpts0[i_ids],
                'keypoints1': kpts1[i_ids],
                'confidence': mconf[i_ids],
                'batch_indexes': b_ids,
            }

            # Apply the in-bounds mask to every prediction field.
            scores, b_ids = pred['confidence'], pred['batch_indexes']
            kpts0, kpts1 = pred['keypoints0'], pred['keypoints1']
            pred['confidence'], pred['batch_indexes'] = scores[mask], b_ids[mask]
            pred['keypoints0'], pred['keypoints1'] = kpts0[mask], kpts1[mask]

        scores = pred['confidence']

        top_k = self.conf['max_num_matches']
        if top_k is not None and len(scores) > top_k:
            # Keep only the top-k matches by confidence.
            keep = torch.argsort(scores, descending=True)[:top_k]
            pred['keypoints0'], pred['keypoints1'] = \
                pred['keypoints0'][keep], pred['keypoints1'][keep]
            pred['batch_indexes'] = pred['batch_indexes'][keep]
            scores = scores[keep]

        # Undo the image0/image1 swap so the outputs match the input pair order.
        pred = {(rename[k] if k in rename else k): v for k, v in pred.items()}
        pred['scores'] = scores
        del pred['confidence']
        return pred