Commit 89c9b15 (parent: bd20887)
Realcat committed: "add: dad detector with roma matcher"

This view is limited to 50 files because it contains too many changes. See the raw diff for the full change set.
- README.md +1 -0
- config/config.yaml +23 -0
- imcui/hloc/match_dense.py +48 -5
- imcui/hloc/matchers/dad_roma.py +121 -0
- imcui/hloc/matchers/roma.py +10 -4
- imcui/hloc/matchers/xfeat_dense.py +4 -2
- imcui/hloc/matchers/xfeat_lightglue.py +4 -2
- imcui/third_party/RoMa/.gitignore +11 -0
- imcui/third_party/RoMa/LICENSE +21 -0
- imcui/third_party/RoMa/README.md +123 -0
- imcui/third_party/RoMa/data/.gitignore +2 -0
- imcui/third_party/RoMa/requirements.txt +14 -0
- imcui/third_party/RoMa/romatch/models/matcher.py +68 -32
- imcui/third_party/RoMa/romatch/models/transformer/layers/attention.py +1 -1
- imcui/third_party/RoMa/romatch/models/transformer/layers/block.py +1 -1
- imcui/third_party/RoMa/romatch/utils/utils.py +9 -1
- imcui/third_party/RoMa/setup.py +1 -1
- imcui/third_party/dad/.gitignore +170 -0
- imcui/third_party/dad/.python-version +1 -0
- imcui/third_party/dad/LICENSE +21 -0
- imcui/third_party/dad/README.md +130 -0
- imcui/third_party/dad/dad/__init__.py +17 -0
- imcui/third_party/dad/dad/augs.py +214 -0
- imcui/third_party/dad/dad/benchmarks/__init__.py +21 -0
- imcui/third_party/dad/dad/benchmarks/hpatches.py +117 -0
- imcui/third_party/dad/dad/benchmarks/megadepth.py +219 -0
- imcui/third_party/dad/dad/benchmarks/num_inliers.py +106 -0
- imcui/third_party/dad/dad/benchmarks/scannet.py +163 -0
- imcui/third_party/dad/dad/checkpoint.py +61 -0
- imcui/third_party/dad/dad/datasets/__init__.py +0 -0
- imcui/third_party/dad/dad/datasets/megadepth.py +312 -0
- imcui/third_party/dad/dad/detectors/__init__.py +50 -0
- imcui/third_party/dad/dad/detectors/dedode_detector.py +559 -0
- imcui/third_party/dad/dad/detectors/third_party/__init__.py +11 -0
- imcui/third_party/dad/dad/detectors/third_party/harrisaff.py +35 -0
- imcui/third_party/dad/dad/detectors/third_party/hesaff.py +40 -0
- imcui/third_party/dad/dad/detectors/third_party/lightglue/__init__.py +9 -0
- imcui/third_party/dad/dad/detectors/third_party/lightglue/aliked.py +770 -0
- imcui/third_party/dad/dad/detectors/third_party/lightglue/disk.py +48 -0
- imcui/third_party/dad/dad/detectors/third_party/lightglue/dog_hardnet.py +41 -0
- imcui/third_party/dad/dad/detectors/third_party/lightglue/lightglue.py +655 -0
- imcui/third_party/dad/dad/detectors/third_party/lightglue/sift.py +216 -0
- imcui/third_party/dad/dad/detectors/third_party/lightglue/superpoint.py +233 -0
- imcui/third_party/dad/dad/detectors/third_party/lightglue/utils.py +158 -0
- imcui/third_party/dad/dad/detectors/third_party/lightglue_detector.py +42 -0
- imcui/third_party/dad/dad/detectors/third_party/rekd/config.py +206 -0
- imcui/third_party/dad/dad/detectors/third_party/rekd/geometry_tools.py +204 -0
- imcui/third_party/dad/dad/detectors/third_party/rekd/model/REKD.py +234 -0
- imcui/third_party/dad/dad/detectors/third_party/rekd/model/kernels.py +118 -0
- imcui/third_party/dad/dad/detectors/third_party/rekd/model/load_models.py +25 -0
README.md
CHANGED
@@ -44,6 +44,7 @@ The tool currently supports various popular image matching algorithms, namely:
 | Algorithm        | Supported | Conference/Journal | Year | GitHub Link |
 |------------------|-----------|--------------------|------|-------------|
+| DaD              | ✅        | ARXIV              | 2025 | [Link](https://github.com/Parskatt/dad) |
 | MINIMA           | ✅        | ARXIV              | 2024 | [Link](https://github.com/LSXI7/MINIMA) |
 | XoFTR            | ✅        | CVPR               | 2024 | [Link](https://github.com/OnderT/XoFTR) |
 | EfficientLoFTR   | ✅        | CVPR               | 2024 | [Link](https://github.com/zju3dv/EfficientLoFTR) |
config/config.yaml
CHANGED
@@ -43,6 +43,17 @@ matcher_zoo:
       # low, medium, high
       efficiency: low
 
+  dad(RoMa):
+    matcher: dad_roma
+    skip_ci: true
+    dense: true
+    info:
+      name: Dad(RoMa) # display name
+      source: "ARXIV 2025"
+      github: https://github.com/Parskatt/dad
+      paper: https://arxiv.org/abs/2503.07347
+      display: true
+      efficiency: low # low, medium, high
   minima(loftr):
     matcher: minima_loftr
     dense: true
@@ -50,6 +61,7 @@ matcher_zoo:
       name: MINIMA(LoFTR) # display name
       source: "ARXIV 2024"
      paper: https://arxiv.org/abs/2412.19412
+      github: https://github.com/LSXI7/MINIMA
       display: true
   minima(RoMa):
     matcher: minima_roma
@@ -59,6 +71,7 @@ matcher_zoo:
       name: MINIMA(RoMa) # display name
       source: "ARXIV 2024"
       paper: https://arxiv.org/abs/2412.19412
+      github: https://github.com/LSXI7/MINIMA
       display: false
       efficiency: low # low, medium, high
   omniglue:
@@ -164,6 +177,16 @@ matcher_zoo:
       paper: https://arxiv.org/pdf/2404.09692
       project: null
       display: true
+  jamma:
+    matcher: jamma
+    dense: true
+    info:
+      name: Jamma # display name
+      source: "CVPR 2024"
+      github: https://github.com/OnderT/XoFTR
+      paper: https://arxiv.org/pdf/2404.09692
+      project: null
+      display: false
   cotr:
     enable: false
     skip_ci: true
imcui/hloc/match_dense.py
CHANGED
@@ -102,6 +102,23 @@ confs = {
         "max_error": 1,  # max error for assigned keypoints (in px)
         "cell_size": 1,  # size of quantization patch (max 1 kp/patch)
     },
+    "jamma": {
+        "output": "matches-jamma",
+        "model": {
+            "name": "jamma",
+            "weights": "jamma_weight.ckpt",
+            "max_keypoints": 2000,
+            "match_threshold": 0.3,
+        },
+        "preprocessing": {
+            "grayscale": True,
+            "resize_max": 1024,
+            "dfactor": 16,
+            "width": 832,
+            "height": 832,
+            "force_resize": True,
+        },
+    },
     # "loftr_quadtree": {
     #     "output": "matches-loftr-quadtree",
     #     "model": {
@@ -295,7 +312,25 @@ confs = {
         },
         "preprocessing": {
             "grayscale": False,
-            "force_resize":
+            "force_resize": False,
+            "resize_max": 1024,
+            "width": 320,
+            "height": 240,
+            "dfactor": 8,
+        },
+    },
+    "dad_roma": {
+        "output": "matches-dad_roma",
+        "model": {
+            "name": "dad_roma",
+            "weights": "outdoor",
+            "model_name": "roma_outdoor.pth",
+            "max_keypoints": 2000,
+            "match_threshold": 0.2,
+        },
+        "preprocessing": {
+            "grayscale": False,
+            "force_resize": False,
             "resize_max": 1024,
             "width": 320,
             "height": 240,
@@ -1010,9 +1045,17 @@ def match_images(model, image_0, image_1, conf, device="cpu"):
     # Rescale keypoints and move to cpu
     if "keypoints0" in pred.keys() and "keypoints1" in pred.keys():
         kpts0, kpts1 = pred["keypoints0"], pred["keypoints1"]
+        mkpts0, mkpts1 = pred.get("mkeypoints0"), pred.get("mkeypoints1")
+        if mkpts0 is None or mkpts1 is None:
+            mkpts0 = kpts0
+            mkpts1 = kpts1
+
         kpts0_origin = scale_keypoints(kpts0 + 0.5, s0) - 0.5
         kpts1_origin = scale_keypoints(kpts1 + 0.5, s1) - 0.5
 
+        mkpts0_origin = scale_keypoints(mkpts0 + 0.5, s0) - 0.5
+        mkpts1_origin = scale_keypoints(mkpts1 + 0.5, s1) - 0.5
+
         ret = {
             "image0": image0.squeeze().cpu().numpy(),
             "image1": image1.squeeze().cpu().numpy(),
@@ -1022,10 +1065,10 @@ def match_images(model, image_0, image_1, conf, device="cpu"):
             "keypoints1": kpts1.cpu().numpy(),
             "keypoints0_orig": kpts0_origin.cpu().numpy(),
             "keypoints1_orig": kpts1_origin.cpu().numpy(),
-            "mkeypoints0":
-            "mkeypoints1":
-            "mkeypoints0_orig":
-            "mkeypoints1_orig":
+            "mkeypoints0": mkpts0.cpu().numpy(),
+            "mkeypoints1": mkpts1.cpu().numpy(),
+            "mkeypoints0_orig": mkpts0_origin.cpu().numpy(),
+            "mkeypoints1_orig": mkpts1_origin.cpu().numpy(),
             "original_size0": np.array(image_0.shape[:2][::-1]),
             "original_size1": np.array(image_1.shape[:2][::-1]),
             "new_size0": np.array(image0.shape[-2:][::-1]),
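Note: the `match_images` hunk above adds a fallback so that matchers returning only `keypoints0`/`keypoints1` still populate the new `mkeypoints*` fields. A minimal sketch of that logic in isolation follows; the `pred` dict and `scale_keypoints` here are stand-ins for the real pipeline objects, not hloc's actual implementation.

```python
# Sketch of the mkeypoints fallback from the hunk above (stand-in inputs).
import torch

def scale_keypoints(kpts: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # Assumed behavior of hloc's scale_keypoints: elementwise rescale back
    # to the original image resolution.
    return kpts * scale

pred = {"keypoints0": torch.rand(8, 2), "keypoints1": torch.rand(8, 2)}
s0 = s1 = torch.tensor([2.0, 2.0])  # example rescale factors

kpts0, kpts1 = pred["keypoints0"], pred["keypoints1"]
# Matchers like dad_roma emit matched keypoints separately; others do not.
mkpts0, mkpts1 = pred.get("mkeypoints0"), pred.get("mkeypoints1")
if mkpts0 is None or mkpts1 is None:
    mkpts0, mkpts1 = kpts0, kpts1  # fall back to the detector keypoints

# The +0.5 / -0.5 shift rescales around pixel centers rather than corners.
mkpts0_origin = scale_keypoints(mkpts0 + 0.5, s0) - 0.5
mkpts1_origin = scale_keypoints(mkpts1 + 0.5, s1) - 0.5
```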
imcui/hloc/matchers/dad_roma.py
ADDED
@@ -0,0 +1,121 @@
+import sys
+from pathlib import Path
+import tempfile
+import torch
+from PIL import Image
+
+from .. import MODEL_REPO_ID, logger
+from ..utils.base_model import BaseModel
+
+roma_path = Path(__file__).parent / "../../third_party/RoMa"
+sys.path.append(str(roma_path))
+from romatch.models.model_zoo import roma_model
+
+dad_path = Path(__file__).parent / "../../third_party/dad"
+sys.path.append(str(dad_path))
+import dad as dad_detector
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+
+class Dad(BaseModel):
+    default_conf = {
+        "name": "two_view_pipeline",
+        "model_name": "roma_outdoor.pth",
+        "model_utils_name": "dinov2_vitl14_pretrain.pth",
+        "max_keypoints": 3000,
+        "coarse_res": (560, 560),
+        "upsample_res": (864, 1152),
+    }
+    required_inputs = [
+        "image0",
+        "image1",
+    ]
+
+    # Initialize the line matcher
+    def _init(self, conf):
+        model_path = self._download_model(
+            repo_id=MODEL_REPO_ID,
+            filename="{}/{}".format("roma", self.conf["model_name"]),
+        )
+
+        dinov2_weights = self._download_model(
+            repo_id=MODEL_REPO_ID,
+            filename="{}/{}".format("roma", self.conf["model_utils_name"]),
+        )
+
+        logger.info("Loading Dad + Roma model")
+        # load the model
+        weights = torch.load(model_path, map_location="cpu")
+        dinov2_weights = torch.load(dinov2_weights, map_location="cpu")
+
+        if str(device) == "cpu":
+            amp_dtype = torch.float32
+        else:
+            amp_dtype = torch.float16
+
+        self.matcher = roma_model(
+            resolution=self.conf["coarse_res"],
+            upsample_preds=True,
+            weights=weights,
+            dinov2_weights=dinov2_weights,
+            device=device,
+            amp_dtype=amp_dtype,
+        )
+        self.matcher.upsample_res = self.conf["upsample_res"]
+        self.matcher.symmetric = False
+
+        self.detector = dad_detector.load_DaD()
+        logger.info("Load Dad + Roma model done.")
+
+    def _forward(self, data):
+        img0 = data["image0"].cpu().numpy().squeeze() * 255
+        img1 = data["image1"].cpu().numpy().squeeze() * 255
+        img0 = img0.transpose(1, 2, 0)
+        img1 = img1.transpose(1, 2, 0)
+        img0 = Image.fromarray(img0.astype("uint8"))
+        img1 = Image.fromarray(img1.astype("uint8"))
+        W_A, H_A = img0.size
+        W_B, H_B = img1.size
+
+        # hack: bad way to save then match
+        with (
+            tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp_img0,
+            tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp_img1,
+        ):
+            img0_path = temp_img0.name
+            img1_path = temp_img1.name
+            img0.save(img0_path)
+            img1.save(img1_path)
+
+        # Match
+        warp, certainty = self.matcher.match(img0_path, img1_path, device=device)
+        # Detect
+        keypoints_A = self.detector.detect_from_path(
+            img0_path,
+            num_keypoints=self.conf["max_keypoints"],
+        )["keypoints"][0]
+        keypoints_B = self.detector.detect_from_path(
+            img1_path,
+            num_keypoints=self.conf["max_keypoints"],
+        )["keypoints"][0]
+        matches = self.matcher.match_keypoints(
+            keypoints_A,
+            keypoints_B,
+            warp,
+            certainty,
+            return_tuple=False,
+        )
+
+        # Sample matches for estimation
+        kpts1, kpts2 = self.matcher.to_pixel_coordinates(matches, H_A, W_A, H_B, W_B)
+        offset = self.detector.topleft - 0
+        kpts1, kpts2 = kpts1 - offset, kpts2 - offset
+        pred = {
+            "keypoints0": self.matcher._to_pixel_coordinates(keypoints_A, H_A, W_A),
+            "keypoints1": self.matcher._to_pixel_coordinates(keypoints_B, H_B, W_B),
+            "mkeypoints0": kpts1,
+            "mkeypoints1": kpts2,
+            "mconf": torch.ones_like(kpts1[:, 0]),
+        }
+        return pred
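Note: the new matcher is a detect-then-match pipeline: DaD proposes keypoints on each image, RoMa predicts a dense warp with per-pixel certainty, and the detections are paired through that warp. A standalone sketch of the same flow using the bundled third-party APIs directly follows; the image paths and keypoint budget are placeholders, and it assumes the vendored `dad` and `romatch` packages are importable.

```python
import torch
import dad
from romatch import roma_outdoor

device = "cuda" if torch.cuda.is_available() else "cpu"
detector = dad.load_DaD()               # DaD keypoint detector
matcher = roma_outdoor(device=device)   # RoMa dense matcher

img0_path, img1_path = "imA.png", "imB.png"  # placeholder paths

# Dense warp and certainty between the two images.
warp, certainty = matcher.match(img0_path, img1_path, device=device)

# Sparse keypoints from DaD, in normalized coordinates.
kpts_A = detector.detect_from_path(img0_path, num_keypoints=2000)["keypoints"][0]
kpts_B = detector.detect_from_path(img1_path, num_keypoints=2000)["keypoints"][0]

# Pair the detections through the warp (mutual nearest neighbours).
matches = matcher.match_keypoints(kpts_A, kpts_B, warp, certainty, return_tuple=False)
```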
imcui/hloc/matchers/roma.py
CHANGED
@@ -20,6 +20,8 @@ class Roma(BaseModel):
         "model_name": "roma_outdoor.pth",
         "model_utils_name": "dinov2_vitl14_pretrain.pth",
         "max_keypoints": 3000,
+        "coarse_res": (560, 560),
+        "upsample_res": (864, 1152),
     }
     required_inputs = [
         "image0",
@@ -43,15 +45,19 @@ class Roma(BaseModel):
         weights = torch.load(model_path, map_location="cpu")
         dinov2_weights = torch.load(dinov2_weights, map_location="cpu")
 
+        if str(device) == "cpu":
+            amp_dtype = torch.float32
+        else:
+            amp_dtype = torch.float16
         self.net = roma_model(
-            resolution=
-            upsample_preds=
+            resolution=self.conf["coarse_res"],
+            upsample_preds=True,
             weights=weights,
             dinov2_weights=dinov2_weights,
             device=device,
-
-            amp_dtype=torch.float32,
+            amp_dtype=amp_dtype,
         )
+        self.net.upsample_res = self.conf["upsample_res"]
         logger.info("Load Roma model done.")
 
     def _forward(self, data):
imcui/hloc/matchers/xfeat_dense.py
CHANGED
@@ -47,8 +47,10 @@ class XFeatDense(BaseModel):
         # we use results from one batch
         matches = matches[0]
         pred = {
-            "keypoints0":
-            "keypoints1":
+            "keypoints0": out0["keypoints"].squeeze(),
+            "keypoints1": out1["keypoints"].squeeze(),
+            "mkeypoints0": matches[:, :2],
+            "mkeypoints1": matches[:, 2:],
             "mconf": torch.ones_like(matches[:, 0]),
         }
         return pred
imcui/hloc/matchers/xfeat_lightglue.py
CHANGED
@@ -41,8 +41,10 @@ class XFeatLightGlue(BaseModel):
         mkpts_0 = torch.from_numpy(mkpts_0)  # n x 2
         mkpts_1 = torch.from_numpy(mkpts_1)  # n x 2
         pred = {
-            "keypoints0":
-            "keypoints1":
+            "keypoints0": out0["keypoints"].squeeze(),
+            "keypoints1": out1["keypoints"].squeeze(),
+            "mkeypoints0": mkpts_0,
+            "mkeypoints1": mkpts_1,
             "mconf": torch.ones_like(mkpts_0[:, 0]),
         }
         return pred
imcui/third_party/RoMa/.gitignore
ADDED
@@ -0,0 +1,11 @@
+*.egg-info*
+*.vscode*
+*__pycache__*
+vis*
+workspace*
+.venv
+.DS_Store
+jobs/*
+*ignore_me*
+*.pth
+wandb*
imcui/third_party/RoMa/LICENSE
ADDED
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2023 Johan Edstedt
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
imcui/third_party/RoMa/README.md
ADDED
@@ -0,0 +1,123 @@
+#
+<p align="center">
+  <h1 align="center"> <ins>RoMa</ins> 🏛️:<br> Robust Dense Feature Matching <br> ⭐CVPR 2024⭐</h1>
+  <p align="center">
+    <a href="https://scholar.google.com/citations?user=Ul-vMR0AAAAJ">Johan Edstedt</a>
+    ·
+    <a href="https://scholar.google.com/citations?user=HS2WuHkAAAAJ">Qiyu Sun</a>
+    ·
+    <a href="https://scholar.google.com/citations?user=FUE3Wd0AAAAJ">Georg Bökman</a>
+    ·
+    <a href="https://scholar.google.com/citations?user=6WRQpCQAAAAJ">Mårten Wadenbäck</a>
+    ·
+    <a href="https://scholar.google.com/citations?user=lkWfR08AAAAJ">Michael Felsberg</a>
+  </p>
+  <h2 align="center"><p>
+    <a href="https://arxiv.org/abs/2305.15404" align="center">Paper</a> |
+    <a href="https://parskatt.github.io/RoMa" align="center">Project Page</a>
+  </p></h2>
+  <div align="center"></div>
+</p>
+<br/>
+<p align="center">
+    <img src="https://github.com/Parskatt/RoMa/assets/22053118/15d8fea7-aa6d-479f-8a93-350d950d006b" alt="example" width=80%>
+    <br>
+    <em>RoMa is the robust dense feature matcher capable of estimating pixel-dense warps and reliable certainties for almost any image pair.</em>
+</p>
+
+## Setup/Install
+In your python environment (tested on Linux python 3.10), run:
+```bash
+pip install -e .
+```
+## Demo / How to Use
+We provide two demos in the [demos folder](demo).
+Here's the gist of it:
+```python
+from romatch import roma_outdoor
+roma_model = roma_outdoor(device=device)
+# Match
+warp, certainty = roma_model.match(imA_path, imB_path, device=device)
+# Sample matches for estimation
+matches, certainty = roma_model.sample(warp, certainty)
+# Convert to pixel coordinates (RoMa produces matches in [-1,1]x[-1,1])
+kptsA, kptsB = roma_model.to_pixel_coordinates(matches, H_A, W_A, H_B, W_B)
+# Find a fundamental matrix (or anything else of interest)
+F, mask = cv2.findFundamentalMat(
+    kptsA.cpu().numpy(), kptsB.cpu().numpy(), ransacReprojThreshold=0.2, method=cv2.USAC_MAGSAC, confidence=0.999999, maxIters=10000
+)
+```
+
+**New**: You can also match arbitrary keypoints with RoMa. See [match_keypoints](romatch/models/matcher.py) in RegressionMatcher.
+
+## Settings
+
+### Resolution
+By default RoMa uses an initial resolution of (560,560) which is then upsampled to (864,864).
+You can change this at construction (see roma_outdoor kwargs).
+You can also change this later, by changing the roma_model.w_resized, roma_model.h_resized, and roma_model.upsample_res.
+
+### Sampling
+roma_model.sample_thresh controls the thresholding used when sampling matches for estimation. In certain cases a lower or higher threshold may improve results.
+
+
+## Reproducing Results
+The experiments in the paper are provided in the [experiments folder](experiments).
+
+### Training
+1. First follow the instructions provided here: https://github.com/Parskatt/DKM for downloading and preprocessing datasets.
+2. Run the relevant experiment, e.g.,
+```bash
+torchrun --nproc_per_node=4 --nnodes=1 --rdzv_backend=c10d experiments/roma_outdoor.py
+```
+### Testing
+```bash
+python experiments/roma_outdoor.py --only_test --benchmark mega-1500
+```
+## License
+All our code except DINOv2 is MIT license.
+DINOv2 has an Apache 2 license [DINOv2](https://github.com/facebookresearch/dinov2/blob/main/LICENSE).
+
+## Acknowledgement
+Our codebase builds on the code in [DKM](https://github.com/Parskatt/DKM).
+
+## Tiny RoMa
+If you find that RoMa is too heavy, you might want to try Tiny RoMa which is built on top of XFeat.
+```python
+from romatch import tiny_roma_v1_outdoor
+tiny_roma_model = tiny_roma_v1_outdoor(device=device)
+```
+Mega1500:
+|              | AUC@5    | AUC@10   | AUC@20   |
+|--------------|----------|----------|----------|
+| XFeat        | 46.4     | 58.9     | 69.2     |
+| XFeat*       | 51.9     | 67.2     | 78.9     |
+| Tiny RoMa v1 | 56.4     | 69.5     | 79.5     |
+| RoMa         | -        | -        | -        |
+
+Mega-8-Scenes (See DKM):
+|              | AUC@5    | AUC@10   | AUC@20   |
+|--------------|----------|----------|----------|
+| XFeat        | -        | -        | -        |
+| XFeat*       | 50.1     | 64.4     | 75.2     |
+| Tiny RoMa v1 | 57.7     | 70.5     | 79.6     |
+| RoMa         | -        | -        | -        |
+
+IMC22 :'):
+|              | mAA@10   |
+|--------------|----------|
+| XFeat        | 42.1     |
+| XFeat*       | -        |
+| Tiny RoMa v1 | 42.2     |
+| RoMa         | -        |
+
+## BibTeX
+If you find our models useful, please consider citing our paper!
+```
+@article{edstedt2024roma,
+  title={{RoMa: Robust Dense Feature Matching}},
+  author={Edstedt, Johan and Sun, Qiyu and Bökman, Georg and Wadenbäck, Mårten and Felsberg, Michael},
+  journal={IEEE Conference on Computer Vision and Pattern Recognition},
+  year={2024}
+}
+```
imcui/third_party/RoMa/data/.gitignore
ADDED
@@ -0,0 +1,2 @@
+*
+!.gitignore
imcui/third_party/RoMa/requirements.txt
ADDED
@@ -0,0 +1,14 @@
+torch
+einops
+torchvision
+opencv-python
+kornia
+albumentations
+loguru
+tqdm
+matplotlib
+h5py
+wandb
+timm
+poselib
+#xformers # Optional, used for memory-efficient attention
imcui/third_party/RoMa/romatch/models/matcher.py
CHANGED
@@ -11,7 +11,7 @@ from PIL import Image
 
 from romatch.utils import get_tuple_transform_ops
 from romatch.utils.local_correlation import local_correlation
-from romatch.utils.utils import cls_to_flow_refine, get_autocast_params
+from romatch.utils.utils import check_rgb, cls_to_flow_refine, get_autocast_params, check_not_i16
 from romatch.utils.kde import kde
 
 class ConvRefiner(nn.Module):
@@ -573,12 +573,30 @@ class RegressionMatcher(nn.Module):
         kpts_B = torch.stack((2/W_B * kpts_B[...,0] - 1, 2/H_B * kpts_B[...,1] - 1),axis=-1)
         return kpts_A, kpts_B
 
-    def match_keypoints(
-
-
+    def match_keypoints(
+        self, x_A, x_B, warp, certainty, return_tuple=True, return_inds=False, max_dist = 0.005, cert_th = 0,
+    ):
+        x_A_to_B = F.grid_sample(
+            warp[..., -2:].permute(2, 0, 1)[None],
+            x_A[None, None],
+            align_corners=False,
+            mode="bilinear",
+        )[0, :, 0].mT
+        cert_A_to_B = F.grid_sample(
+            certainty[None, None, ...],
+            x_A[None, None],
+            align_corners=False,
+            mode="bilinear",
+        )[0, 0, 0]
         D = torch.cdist(x_A_to_B, x_B)
-        inds_A, inds_B = torch.nonzero(
-
+        inds_A, inds_B = torch.nonzero(
+            (D == D.min(dim=-1, keepdim=True).values)
+            * (D == D.min(dim=-2, keepdim=True).values)
+            * (cert_A_to_B[:, None] > cert_th)
+            * (D < max_dist),
+            as_tuple=True,
+        )
+
         if return_tuple:
             if return_inds:
                 return inds_A, inds_B
@@ -586,25 +604,38 @@ class RegressionMatcher(nn.Module):
             return x_A[inds_A], x_B[inds_B]
         else:
             if return_inds:
-                return torch.cat((inds_A, inds_B),dim=-1)
+                return torch.cat((inds_A, inds_B), dim=-1)
             else:
-                return torch.cat((x_A[inds_A], x_B[inds_B]),dim=-1)
+                return torch.cat((x_A[inds_A], x_B[inds_B]), dim=-1)
 
     @torch.inference_mode()
     def match(
         self,
-
-
+        im_A_input,
+        im_B_input,
         *args,
         batched=False,
-        device
+        device=None,
     ):
         if device is None:
             device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-
-
+
+        # Check if inputs are file paths or already loaded images
+        if isinstance(im_A_input, (str, os.PathLike)):
+            im_A = Image.open(im_A_input)
+            check_not_i16(im_A)
+            im_A = im_A.convert("RGB")
+        else:
+            check_rgb(im_A_input)
+            im_A = im_A_input
+
+        if isinstance(im_B_input, (str, os.PathLike)):
+            im_B = Image.open(im_B_input)
+            check_not_i16(im_B)
+            im_B = im_B.convert("RGB")
         else:
-
+            check_rgb(im_B_input)
+            im_B = im_B_input
 
         symmetric = self.symmetric
         self.train(False)
@@ -616,9 +647,9 @@ class RegressionMatcher(nn.Module):
         # Get images in good format
         ws = self.w_resized
         hs = self.h_resized
-
+
         test_transform = get_tuple_transform_ops(
-            resize=(hs, ws), normalize=True, clahe
+            resize=(hs, ws), normalize=True, clahe=False
         )
         im_A, im_B = test_transform((im_A, im_B))
         batch = {"im_A": im_A[None].to(device), "im_B": im_B[None].to(device)}
@@ -633,20 +664,20 @@ class RegressionMatcher(nn.Module):
         finest_scale = 1
         # Run matcher
         if symmetric:
-            corresps
+            corresps = self.forward_symmetric(batch)
         else:
-            corresps = self.forward(batch, batched
+            corresps = self.forward(batch, batched=True)
 
         if self.upsample_preds:
             hs, ws = self.upsample_res
-
+
         if self.attenuate_cert:
             low_res_certainty = F.interpolate(
-
+                corresps[16]["certainty"], size=(hs, ws), align_corners=False, mode="bilinear"
             )
             cert_clamp = 0
             factor = 0.5
-            low_res_certainty = factor*low_res_certainty*(low_res_certainty < cert_clamp)
+            low_res_certainty = factor * low_res_certainty * (low_res_certainty < cert_clamp)
 
         if self.upsample_preds:
             finest_corresps = corresps[finest_scale]
@@ -654,34 +685,39 @@ class RegressionMatcher(nn.Module):
             test_transform = get_tuple_transform_ops(
                 resize=(hs, ws), normalize=True
             )
-
+            if isinstance(im_A_input, (str, os.PathLike)):
+                im_A, im_B = test_transform(
+                    (Image.open(im_A_input).convert('RGB'), Image.open(im_B_input).convert('RGB')))
+            else:
+                im_A, im_B = test_transform((im_A_input, im_B_input))
+
             im_A, im_B = im_A[None].to(device), im_B[None].to(device)
             scale_factor = math.sqrt(self.upsample_res[0] * self.upsample_res[1] / (self.w_resized * self.h_resized))
             batch = {"im_A": im_A, "im_B": im_B, "corresps": finest_corresps}
             if symmetric:
-                corresps = self.forward_symmetric(batch, upsample
+                corresps = self.forward_symmetric(batch, upsample=True, batched=True, scale_factor=scale_factor)
             else:
-                corresps = self.forward(batch, batched
-
-        im_A_to_im_B = corresps[finest_scale]["flow"]
+                corresps = self.forward(batch, batched=True, upsample=True, scale_factor=scale_factor)
+
+        im_A_to_im_B = corresps[finest_scale]["flow"]
         certainty = corresps[finest_scale]["certainty"] - (low_res_certainty if self.attenuate_cert else 0)
         if finest_scale != 1:
             im_A_to_im_B = F.interpolate(
-
+                im_A_to_im_B, size=(hs, ws), align_corners=False, mode="bilinear"
             )
             certainty = F.interpolate(
-
+                certainty, size=(hs, ws), align_corners=False, mode="bilinear"
            )
         im_A_to_im_B = im_A_to_im_B.permute(
             0, 2, 3, 1
-
+        )
        # Create im_A meshgrid
        im_A_coords = torch.meshgrid(
            (
                torch.linspace(-1 + 1 / hs, 1 - 1 / hs, hs, device=device),
                torch.linspace(-1 + 1 / ws, 1 - 1 / ws, ws, device=device),
            ),
-            indexing
+            indexing='ij'
        )
        im_A_coords = torch.stack((im_A_coords[1], im_A_coords[0]))
        im_A_coords = im_A_coords[None].expand(b, 2, hs, ws)
@@ -689,14 +725,14 @@ class RegressionMatcher(nn.Module):
        im_A_coords = im_A_coords.permute(0, 2, 3, 1)
        if (im_A_to_im_B.abs() > 1).any() and True:
            wrong = (im_A_to_im_B.abs() > 1).sum(dim=-1) > 0
-            certainty[wrong[:,None]] = 0
+            certainty[wrong[:, None]] = 0
        im_A_to_im_B = torch.clamp(im_A_to_im_B, -1, 1)
        if symmetric:
            A_to_B, B_to_A = im_A_to_im_B.chunk(2)
            q_warp = torch.cat((im_A_coords, A_to_B), dim=-1)
            im_B_coords = im_A_coords
            s_warp = torch.cat((B_to_A, im_B_coords), dim=-1)
-            warp = torch.cat((q_warp, s_warp),dim=2)
+            warp = torch.cat((q_warp, s_warp), dim=2)
            certainty = torch.cat(certainty.chunk(2), dim=3)
        else:
            warp = torch.cat((im_A_coords, im_A_to_im_B), dim=-1)
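Note: the heart of the new `match_keypoints` is a mutual nearest neighbour test in warped coordinates. A pair (i, j) survives only if j is the closest B keypoint to the warped A keypoint i and vice versa, the sampled certainty clears `cert_th`, and the distance is under `max_dist`. A self-contained toy illustration of that selection rule (hand-picked tensors, not a real RoMa warp):

```python
import torch

# Toy A keypoints already warped into B's normalized [-1, 1] frame, and B keypoints.
x_A_to_B = torch.tensor([[0.100, 0.100], [0.500, -0.200], [-0.900, 0.900]])
x_B = torch.tensor([[0.101, 0.100], [0.500, -0.198], [0.000, 0.000]])
cert_A_to_B = torch.tensor([0.9, 0.8, 0.1])  # toy sampled certainties

D = torch.cdist(x_A_to_B, x_B)
max_dist, cert_th = 0.005, 0  # defaults from the diff
inds_A, inds_B = torch.nonzero(
    (D == D.min(dim=-1, keepdim=True).values)    # j is nearest to i
    * (D == D.min(dim=-2, keepdim=True).values)  # i is nearest to j (mutual)
    * (cert_A_to_B[:, None] > cert_th)           # certain enough
    * (D < max_dist),                            # close enough
    as_tuple=True,
)
print(inds_A.tolist(), inds_B.tolist())  # [0, 1] [0, 1]: the third pair is too far
```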
imcui/third_party/RoMa/romatch/models/transformer/layers/attention.py
CHANGED
@@ -22,7 +22,7 @@ try:
 
     XFORMERS_AVAILABLE = True
 except ImportError:
-    logger.warning("xFormers not available")
+    # logger.warning("xFormers not available")
     XFORMERS_AVAILABLE = False
 
 
imcui/third_party/RoMa/romatch/models/transformer/layers/block.py
CHANGED
@@ -29,7 +29,7 @@ try:
 
     XFORMERS_AVAILABLE = True
 except ImportError:
-    logger.warning("xFormers not available")
+    # logger.warning("xFormers not available")
     XFORMERS_AVAILABLE = False
 
 
imcui/third_party/RoMa/romatch/utils/utils.py
CHANGED
@@ -651,4 +651,12 @@ def get_autocast_params(device=None, enabled=False, dtype=None):
         enabled = False
         # mps is not supported
         autocast_device = "cpu"
-    return autocast_device, enabled, out_dtype
+    return autocast_device, enabled, out_dtype
+
+def check_not_i16(im):
+    if im.mode == "I;16":
+        raise NotImplementedError("Can't handle 16 bit images")
+
+def check_rgb(im):
+    if im.mode != "RGB":
+        raise NotImplementedError("Can't handle non-RGB images")
imcui/third_party/RoMa/setup.py
CHANGED
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
 setup(
     name="romatch",
     packages=find_packages(include=("romatch*",)),
-    version="0.0.
+    version="0.0.2",
     author="Johan Edstedt",
     install_requires=open("requirements.txt", "r").read().split("\n"),
 )
imcui/third_party/dad/.gitignore
ADDED
@@ -0,0 +1,170 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+#  Usually these files are written by a python script from a template
+#  before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+#   For a library or package, you might want to ignore these files since the code is
+#   intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+#   However, in case of collaboration, if having platform-specific dependencies or dependencies
+#   having no cross-platform support, pipenv may install dependencies that don't work, or not
+#   install all needed dependencies.
+#Pipfile.lock
+
+# poetry
+#   Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+#   This is especially recommended for binary packages to ensure reproducibility, and is more
+#   commonly ignored for libraries.
+#   https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+#poetry.lock
+
+# pdm
+#   Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+#pdm.lock
+#   pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+#   in version control.
+#   https://pdm.fming.dev/#use-with-ide
+.pdm.toml
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
+
+# PyCharm
+#  JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+#  be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+#  and can be added to the global gitignore or merged into this file.  For a more nuclear
+#  option (not recommended) you can uncomment the following to ignore the entire idea folder.
+#.idea/
+
+.vscode*
+*.pth
+wandb
+*.out
+vis/
+workspace/
+
+.DS_Store
+*.tar
imcui/third_party/dad/.python-version
ADDED
@@ -0,0 +1 @@
+3.10
imcui/third_party/dad/LICENSE
ADDED
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2025 Johan Edstedt
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
imcui/third_party/dad/README.md
ADDED
@@ -0,0 +1,130 @@
+<p align="center">
+  <h1 align="center"> <ins>DaD:</ins> Distilled Reinforcement Learning for Diverse Keypoint Detection</h1>
+  <p align="center">
+    <a href="https://scholar.google.com/citations?user=Ul-vMR0AAAAJ">Johan Edstedt</a>
+    ·
+    <a href="https://scholar.google.com/citations?user=FUE3Wd0AAAAJ">Georg Bökman</a>
+    ·
+    <a href="https://scholar.google.com/citations?user=6WRQpCQAAAAJ">Mårten Wadenbäck</a>
+    ·
+    <a href="https://scholar.google.com/citations?user=lkWfR08AAAAJ">Michael Felsberg</a>
+  </p>
+  <h2 align="center"><p>
+    <a href="https://arxiv.org/abs/2503.07347" align="center">Paper</a>
+  </p></h2>
+  <p align="center">
+    <img src="assets/qualitative.jpg" alt="example" width=80%>
+    <br>
+    <em>DaD's a pretty good keypoint detector, probably the best.</em>
+  </p>
+</p>
+<p align="center">
+</p>
+
+## Run
+```python
+import dad
+from PIL import Image
+img_path = "assets/0015_A.jpg"
+W, H = Image.open(img_path).size  # your image shape
+detector = dad.load_DaD()
+detections = detector.detect_from_path(
+    img_path,
+    num_keypoints = 512,
+    return_dense_probs=True)
+detections["keypoints"] # 1 x 512 x 2, normalized coordinates of keypoints
+detector.to_pixel_coords(detections["keypoints"], H, W)
+detections["keypoint_probs"] # 1 x 512, probs of sampled keypoints
+detections["dense_probs"] # 1 x H x W, probability map
+```
+
+## Visualize
+```python
+import dad
+from dad.utils import visualize_keypoints
+detector = dad.load_DaD()
+img_path = "assets/0015_A.jpg"
+vis_path = "vis/0015_A_dad.jpg"
+visualize_keypoints(img_path, vis_path, detector, num_keypoints = 512)
+```
+
+## Install
+Get uv
+```bash
+curl -LsSf https://astral.sh/uv/install.sh | sh
+```
+### In an existing env
+Assuming you already have some env active:
+```bash
+uv pip install dad@git+https://github.com/Parskatt/dad.git
+```
+### As a project
+For dev, etc:
+```bash
+git clone git@github.com:Parskatt/dad.git
+uv sync
+source .venv/bin/activate
+```
+
+## Evaluation
+To evaluate, e.g., DaD on ScanNet1500 with 512 keypoints, run
+```bash
+python experiments/benchmark.py --detector DaD --num_keypoints 512 --benchmark ScanNet1500
+```
+Note: leaving out num_keypoints will run the benchmark for all numbers of keypoints, i.e., [512, 1024, 2048, 4096, 8192].
+### Third party detectors
+We provide wrappers for a somewhat large set of previous detectors,
+```bash
+python experiments/benchmark.py --help
+```
+
+## Training
+To train our final model from the emergent light and dark detector, run
+```bash
+python experiments/repro_paper_results/distill.py
+```
+The emergent models come from running
+```bash
+python experiments/repro_paper_results/rl.py
+```
+Note however that the types of detectors that come from this type of training are stochastic, and you may need to do several runs to get a detector that matches our results.
+
+## How I run experiments
+(Note: You don't have to do this, it's just how I do it.)
+At the start of a new day I typically run
+```bash
+python new_day.py
+```
+This creates a new folder in experiments, e.g., `experiments/w11/monday`.
+I then typically just copy the contents of a previous experiment, e.g.,
+```bash
+cp experiments/repro_paper_results/rl.py experiments/w11/monday/new-cool-hparams.py
+```
+Change whatever you want to change in `experiments/w11/monday/new-cool-hparams.py`.
+
+Then run it with
+```bash
+python experiments/w11/monday/new-cool-hparams.py
+```
+This will be tracked in wandb as `w11-monday-new-cool-hparams` in the `DaD` project.
+
+If you don't want to track the run, and perhaps want to display some debug output, you can instead run it as follows, which also won't log to wandb:
+```bash
+DEBUG=1 python experiments/w11/monday/new-cool-hparams.py
+```
+## Evaluation Results
+TODO
+
+## Licenses
+DaD is MIT licensed.
+
+Third party detectors in [dad/detectors/third_party](dad/detectors/third_party) have their own licenses. If you use them, please refer to their respective licenses [here](licenses) (NOTE: There may be more licenses you need to care about than the ones listed. Before using any third party code, make sure you're following their respective license).
+
+
+
+
+## BibTeX
+
+```txt
+TODO
+```
imcui/third_party/dad/dad/__init__.py
ADDED
@@ -0,0 +1,17 @@
+from .logging import logger as logger
+from .logging import configure_logger as configure_logger
+import os
+from .detectors import load_DaD as load_DaD
+from .detectors import dedode_detector_S as dedode_detector_S
+from .detectors import dedode_detector_B as dedode_detector_B
+from .detectors import dedode_detector_L as dedode_detector_L
+from .detectors import load_DaDDark as load_DaDDark
+from .detectors import load_DaDLight as load_DaDLight
+from .types import Detector as Detector
+from .types import Matcher as Matcher
+from .types import Benchmark as Benchmark
+
+configure_logger()
+DEBUG_MODE = bool(os.environ.get("DEBUG", False))
+RANK = 0
+GLOBAL_STEP = 0
imcui/third_party/dad/dad/augs.py
ADDED
@@ -0,0 +1,214 @@
import random
import warnings
import numpy as np
import torch
from PIL import Image
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
import cv2


# From Patch2Pix https://github.com/GrumpyZhou/patch2pix
def get_depth_tuple_transform_ops(resize=None, normalize=True, unscale=False):
    ops = []
    if resize:
        ops.append(
            TupleResize(resize, mode=InterpolationMode.BILINEAR, antialias=False)
        )
    return TupleCompose(ops)


def get_tuple_transform_ops(resize=None, normalize=True, unscale=False, clahe=False):
    ops = []
    if resize:
        ops.append(TupleResize(resize, antialias=True))
    if clahe:
        ops.append(TupleClahe())
    if normalize:
        ops.append(TupleToTensorScaled())
        ops.append(
            TupleNormalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        )  # Imagenet mean/std
    else:
        if unscale:
            ops.append(TupleToTensorUnscaled())
        else:
            ops.append(TupleToTensorScaled())
    return TupleCompose(ops)


class Clahe:
    def __init__(self, cliplimit=2, blocksize=8) -> None:
        self.clahe = cv2.createCLAHE(cliplimit, (blocksize, blocksize))

    def __call__(self, im):
        im_hsv = cv2.cvtColor(np.array(im), cv2.COLOR_RGB2HSV)
        im_v = self.clahe.apply(im_hsv[:, :, 2])
        im_hsv[..., 2] = im_v
        im_clahe = cv2.cvtColor(im_hsv, cv2.COLOR_HSV2RGB)
        return Image.fromarray(im_clahe)


class TupleClahe:
    def __init__(self, cliplimit=8, blocksize=8) -> None:
        self.clahe = Clahe(cliplimit, blocksize)

    def __call__(self, ims):
        return [self.clahe(im) for im in ims]


class ToTensorScaled(object):
    """Convert a RGB PIL Image to a CHW ordered Tensor, scale the range to [0, 1]"""

    def __call__(self, im):
        if not isinstance(im, torch.Tensor):
            im = np.array(im, dtype=np.float32).transpose((2, 0, 1))
            im /= 255.0
            return torch.from_numpy(im)
        else:
            return im

    def __repr__(self):
        return "ToTensorScaled(./255)"


class TupleToTensorScaled(object):
    def __init__(self):
        self.to_tensor = ToTensorScaled()

    def __call__(self, im_tuple):
        return [self.to_tensor(im) for im in im_tuple]

    def __repr__(self):
        return "TupleToTensorScaled(./255)"


class ToTensorUnscaled(object):
    """Convert a RGB PIL Image to a CHW ordered Tensor"""

    def __call__(self, im):
        return torch.from_numpy(np.array(im, dtype=np.float32).transpose((2, 0, 1)))

    def __repr__(self):
        return "ToTensorUnscaled()"


class TupleToTensorUnscaled(object):
    """Convert a RGB PIL Image to a CHW ordered Tensor"""

    def __init__(self):
        self.to_tensor = ToTensorUnscaled()

    def __call__(self, im_tuple):
        return [self.to_tensor(im) for im in im_tuple]

    def __repr__(self):
        return "TupleToTensorUnscaled()"


class TupleResize(object):
    def __init__(self, size, mode=InterpolationMode.BICUBIC, antialias=None):
        self.size = size
        self.resize = transforms.Resize(size, mode, antialias=antialias)

    def __call__(self, im_tuple):
        return [self.resize(im) for im in im_tuple]

    def __repr__(self):
        return "TupleResize(size={})".format(self.size)


class Normalize:
    def __call__(self, im):
        mean = im.mean(dim=(1, 2), keepdims=True)
        std = im.std(dim=(1, 2), keepdims=True)
        return (im - mean) / std


class TupleNormalize(object):
    def __init__(self, mean, std):
        self.mean = mean
        self.std = std
        self.normalize = transforms.Normalize(mean=mean, std=std)

    def __call__(self, im_tuple):
        c, h, w = im_tuple[0].shape
        if c > 3:
            warnings.warn(f"Number of channels {c=} > 3, assuming first 3 are rgb")
        return [self.normalize(im[:3]) for im in im_tuple]

    def __repr__(self):
        return "TupleNormalize(mean={}, std={})".format(self.mean, self.std)


class TupleCompose(object):
    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, im_tuple):
        for t in self.transforms:
            im_tuple = t(im_tuple)
        return im_tuple

    def __repr__(self):
        format_string = self.__class__.__name__ + "("
        for t in self.transforms:
            format_string += "\n"
            format_string += " {0}".format(t)
        format_string += "\n)"
        return format_string


def pad_kps(kps: torch.Tensor, pad_num_kps: int, value: int = -1):
    assert len(kps.shape) == 2
    N = len(kps)
    padded_kps = value * torch.ones((pad_num_kps, 2)).to(kps)
    padded_kps[:N] = kps
    return padded_kps


def crop(img: Image.Image, x: int, y: int, crop_size: int):
    width, height = img.size
    if width < crop_size or height < crop_size:
        raise ValueError(f"Image dimensions must be at least {crop_size}x{crop_size}")
    cropped_img = img.crop((x, y, x + crop_size, y + crop_size))
    return cropped_img


def random_crop(img: Image.Image, crop_size: int):
    width, height = img.size

    if width < crop_size or height < crop_size:
        raise ValueError(f"Image dimensions must be at least {crop_size}x{crop_size}")

    max_x = width - crop_size
    max_y = height - crop_size

    x = random.randint(0, max_x)
    y = random.randint(0, max_y)

    cropped_img = img.crop((x, y, x + crop_size, y + crop_size))
    return cropped_img, (x, y)


def luminance_negation(pil_img):
    # Convert PIL RGB to numpy array
    rgb_array = np.array(pil_img)

    # Convert RGB to BGR (OpenCV format)
    bgr = cv2.cvtColor(rgb_array, cv2.COLOR_RGB2BGR)

    # Convert BGR to LAB
    lab = cv2.cvtColor(bgr, cv2.COLOR_BGR2LAB)

    # Negate L channel
    lab[:, :, 0] = 255 - lab[:, :, 0]

    # Convert back to BGR
    bgr_result = cv2.cvtColor(lab, cv2.COLOR_LAB2BGR)

    # Convert BGR back to RGB
    rgb_result = cv2.cvtColor(bgr_result, cv2.COLOR_BGR2RGB)

    # Convert numpy array back to PIL Image
    return Image.fromarray(rgb_result)
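A hedged usage sketch of the tuple transforms above (assumes the dad package is importable; the images are synthetic stand-ins):

from PIL import Image
from dad.augs import get_tuple_transform_ops

ops = get_tuple_transform_ops(resize=(448, 448), normalize=True)
im_A = Image.new("RGB", (640, 480))  # stand-in for a real image
im_B = Image.new("RGB", (800, 600))
t_A, t_B = ops((im_A, im_B))  # resized, scaled to [0, 1], ImageNet-normalized
print(t_A.shape, t_B.shape)  # torch.Size([3, 448, 448]) for both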
imcui/third_party/dad/dad/benchmarks/__init__.py
ADDED
@@ -0,0 +1,21 @@
# from .benchmark import Benchmark as Benchmark
from .num_inliers import NumInliersBenchmark as NumInliersBenchmark
from .megadepth import Mega1500 as Mega1500
from .megadepth import Mega1500_F as Mega1500_F
from .megadepth import MegaIMCPT as MegaIMCPT
from .megadepth import MegaIMCPT_F as MegaIMCPT_F
from .scannet import ScanNet1500 as ScanNet1500
from .scannet import ScanNet1500_F as ScanNet1500_F
from .hpatches import HPatchesViewpoint as HPatchesViewpoint
from .hpatches import HPatchesIllum as HPatchesIllum

all_benchmarks = [
    Mega1500.__name__,
    Mega1500_F.__name__,
    MegaIMCPT.__name__,
    MegaIMCPT_F.__name__,
    ScanNet1500.__name__,
    ScanNet1500_F.__name__,
    HPatchesViewpoint.__name__,
    HPatchesIllum.__name__,
]
imcui/third_party/dad/dad/benchmarks/hpatches.py
ADDED
@@ -0,0 +1,117 @@
import os
from typing import Optional

import numpy as np
import poselib
from PIL import Image
from tqdm import tqdm

from dad.types import Detector, Matcher, Benchmark


class HPatchesBenchmark(Benchmark):
    def __init__(
        self,
        data_root="data/hpatches",
        sample_every=1,
        num_ransac_runs=5,
        num_keypoints: Optional[list[int]] = None,
    ) -> None:
        super().__init__(
            data_root=data_root,
            num_keypoints=num_keypoints,
            sample_every=sample_every,
            num_ransac_runs=num_ransac_runs,
            thresholds=[3, 5, 10],
        )
        seqs_dir = "hpatches-sequences-release"
        self.seqs_path = os.path.join(self.data_root, seqs_dir)
        self.seq_names = sorted(os.listdir(self.seqs_path))
        self.topleft = 0.0
        self._post_init()
        self.skip_seqs: str
        self.scene_names: list[str]

    def _post_init(self):
        # set self.skip_seqs and self.scene_names here
        raise NotImplementedError()

    def benchmark(self, detector: Detector, matcher: Matcher):
        homog_dists = []
        for seq_idx, seq_name in enumerate(tqdm(self.seq_names[:: self.sample_every])):
            if self.skip_seqs in seq_name:
                # skip sequences excluded by this split (illumination or viewpoint)
                continue
            im_A_path = os.path.join(self.seqs_path, seq_name, "1.ppm")
            im_A = Image.open(im_A_path)
            w1, h1 = im_A.size
            for im_idx in list(range(2, 7)):
                im_B_path = os.path.join(self.seqs_path, seq_name, f"{im_idx}.ppm")
                H = np.loadtxt(
                    os.path.join(self.seqs_path, seq_name, "H_1_" + str(im_idx))
                )
                warp, certainty = matcher.match(im_A_path, im_B_path)
                for num_kps in self.num_keypoints:
                    keypoints_A = detector.detect_from_path(
                        im_A_path,
                        num_keypoints=num_kps,
                    )["keypoints"][0]
                    keypoints_B = detector.detect_from_path(
                        im_B_path,
                        num_keypoints=num_kps,
                    )["keypoints"][0]
                    matches = matcher.match_keypoints(
                        keypoints_A,
                        keypoints_B,
                        warp,
                        certainty,
                        return_tuple=False,
                    )
                    im_A = Image.open(im_A_path)
                    w1, h1 = im_A.size
                    im_B = Image.open(im_B_path)
                    w2, h2 = im_B.size
                    kpts1, kpts2 = matcher.to_pixel_coordinates(matches, h1, w1, h2, w2)
                    offset = detector.topleft - self.topleft
                    kpts1, kpts2 = kpts1 - offset, kpts2 - offset
                    for _ in range(self.num_ransac_runs):
                        shuffling = np.random.permutation(np.arange(len(kpts1)))
                        kpts1 = kpts1[shuffling]
                        kpts2 = kpts2[shuffling]
                        threshold = 2.0
                        H_pred, res = poselib.estimate_homography(
                            kpts1.cpu().numpy(),
                            kpts2.cpu().numpy(),
                            ransac_opt={
                                "max_reproj_error": threshold,
                            },
                        )
                        corners = np.array(
                            [
                                [0, 0, 1],
                                [0, h1 - 1, 1],
                                [w1 - 1, 0, 1],
                                [w1 - 1, h1 - 1, 1],
                            ]
                        )
                        real_warped_corners = np.dot(corners, np.transpose(H))
                        real_warped_corners = (
                            real_warped_corners[:, :2] / real_warped_corners[:, 2:]
                        )
                        warped_corners = np.dot(corners, np.transpose(H_pred))
                        warped_corners = warped_corners[:, :2] / warped_corners[:, 2:]
                        mean_dist = np.mean(
                            np.linalg.norm(real_warped_corners - warped_corners, axis=1)
                        ) / (min(w2, h2) / 480.0)
                        homog_dists.append(mean_dist)
        return self.compute_auc(np.array(homog_dists))


class HPatchesViewpoint(HPatchesBenchmark):
    def _post_init(self):
        self.skip_seqs = "i_"


class HPatchesIllum(HPatchesBenchmark):
    def _post_init(self):
        self.skip_seqs = "v_"
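The homography error used above is the mean distance between the four image corners warped by the ground-truth H and by the estimated H_pred, normalized so that a 480-pixel shorter side counts as scale 1. A self-contained sketch of that metric (with hypothetical inputs):

import numpy as np

def mean_corner_error(H_gt, H_pred, w1, h1, w2, h2):
    # four corners of image A in homogeneous coordinates
    corners = np.array(
        [[0, 0, 1], [0, h1 - 1, 1], [w1 - 1, 0, 1], [w1 - 1, h1 - 1, 1]]
    )
    gt = corners @ H_gt.T
    gt = gt[:, :2] / gt[:, 2:]
    pred = corners @ H_pred.T
    pred = pred[:, :2] / pred[:, 2:]
    # normalize so sequences of different resolution are comparable
    return np.linalg.norm(gt - pred, axis=1).mean() / (min(w2, h2) / 480.0)

H = np.eye(3)
print(mean_corner_error(H, H, 640, 480, 640, 480))  # 0.0 for identical homographies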
imcui/third_party/dad/dad/benchmarks/megadepth.py
ADDED
@@ -0,0 +1,219 @@
from typing import Literal, Optional

import numpy as np
from PIL import Image
from tqdm import tqdm

from dad.types import Detector, Matcher, Benchmark
from dad.utils import (
    compute_pose_error,
    compute_relative_pose,
    estimate_pose_essential,
    estimate_pose_fundamental,
)


class MegaDepthPoseEstimationBenchmark(Benchmark):
    def __init__(
        self,
        data_root="data/megadepth",
        sample_every=1,
        num_ransac_runs=5,
        num_keypoints: Optional[list[int]] = None,
    ) -> None:
        super().__init__(
            data_root=data_root,
            num_keypoints=num_keypoints,
            sample_every=sample_every,
            num_ransac_runs=num_ransac_runs,
            thresholds=[5, 10, 20],
        )
        self.sample_every = sample_every
        self.topleft = 0.5
        self._post_init()
        self.model: Literal["fundamental", "essential"]
        self.scene_names: list[str]
        self.benchmark_name: str

    def _post_init(self):
        raise NotImplementedError(
            "Add scene names and benchmark name in derived class _post_init"
        )

    def benchmark(
        self,
        detector: Detector,
        matcher: Matcher,
    ):
        self.scenes = [
            np.load(f"{self.data_root}/{scene}", allow_pickle=True)
            for scene in self.scene_names
        ]

        data_root = self.data_root
        tot_e_pose = []
        n_matches = []
        for scene_ind in range(len(self.scenes)):
            scene = self.scenes[scene_ind]
            pairs = scene["pair_infos"]
            intrinsics = scene["intrinsics"]
            poses = scene["poses"]
            im_paths = scene["image_paths"]
            pair_inds = range(len(pairs))
            for pairind in (
                pbar := tqdm(
                    pair_inds[:: self.sample_every],
                    desc="Current AUC: ?",
                    mininterval=10,
                )
            ):
                idx1, idx2 = pairs[pairind][0]
                K1 = intrinsics[idx1].copy()
                T1 = poses[idx1].copy()
                R1, t1 = T1[:3, :3], T1[:3, 3]
                K2 = intrinsics[idx2].copy()
                T2 = poses[idx2].copy()
                R2, t2 = T2[:3, :3], T2[:3, 3]
                R, t = compute_relative_pose(R1, t1, R2, t2)
                im_A_path = f"{data_root}/{im_paths[idx1]}"
                im_B_path = f"{data_root}/{im_paths[idx2]}"

                warp, certainty = matcher.match(im_A_path, im_B_path)
                for num_kps in self.num_keypoints:
                    keypoints_A = detector.detect_from_path(
                        im_A_path,
                        num_keypoints=num_kps,
                    )["keypoints"][0]
                    keypoints_B = detector.detect_from_path(
                        im_B_path,
                        num_keypoints=num_kps,
                    )["keypoints"][0]
                    matches = matcher.match_keypoints(
                        keypoints_A,
                        keypoints_B,
                        warp,
                        certainty,
                        return_tuple=False,
                    )
                    n_matches.append(matches.shape[0])
                    im_A = Image.open(im_A_path)
                    w1, h1 = im_A.size
                    im_B = Image.open(im_B_path)
                    w2, h2 = im_B.size
                    kpts1, kpts2 = matcher.to_pixel_coordinates(matches, h1, w1, h2, w2)
                    offset = detector.topleft - self.topleft
                    kpts1, kpts2 = kpts1 - offset, kpts2 - offset

                    for _ in range(self.num_ransac_runs):
                        shuffling = np.random.permutation(np.arange(len(kpts1)))
                        kpts1 = kpts1[shuffling]
                        kpts2 = kpts2[shuffling]
                        threshold = 2.0
                        if self.model == "essential":
                            R_est, t_est = estimate_pose_essential(
                                kpts1.cpu().numpy(),
                                kpts2.cpu().numpy(),
                                w1,
                                h1,
                                K1,
                                w2,
                                h2,
                                K2,
                                threshold,
                            )
                        elif self.model == "fundamental":
                            R_est, t_est = estimate_pose_fundamental(
                                kpts1.cpu().numpy(),
                                kpts2.cpu().numpy(),
                                w1,
                                h1,
                                K1,
                                w2,
                                h2,
                                K2,
                                threshold,
                            )
                        T1_to_2_est = np.concatenate((R_est, t_est[:, None]), axis=-1)
                        e_t, e_R = compute_pose_error(T1_to_2_est, R, t)
                        e_pose = max(e_t, e_R)
                        tot_e_pose.append(e_pose)
                pbar.set_description(
                    f"Current AUCS: {self.compute_auc(np.array(tot_e_pose))}"
                )
        n_matches = np.array(n_matches)
        print(n_matches.mean(), np.median(n_matches), np.std(n_matches))
        return self.compute_auc(np.array(tot_e_pose))


class Mega1500(MegaDepthPoseEstimationBenchmark):
    def _post_init(self):
        self.scene_names = [
            "0015_0.1_0.3.npz",
            "0015_0.3_0.5.npz",
            "0022_0.1_0.3.npz",
            "0022_0.3_0.5.npz",
            "0022_0.5_0.7.npz",
        ]
        self.benchmark_name = "Mega1500"
        self.model = "essential"


class Mega1500_F(MegaDepthPoseEstimationBenchmark):
    def _post_init(self):
        self.scene_names = [
            "0015_0.1_0.3.npz",
            "0015_0.3_0.5.npz",
            "0022_0.1_0.3.npz",
            "0022_0.3_0.5.npz",
            "0022_0.5_0.7.npz",
        ]
        self.benchmark_name = "Mega1500_F"
        self.model = "fundamental"


class MegaIMCPT(MegaDepthPoseEstimationBenchmark):
    def _post_init(self):
        self.scene_names = [
            "mega_8_scenes_0008_0.1_0.3.npz",
            "mega_8_scenes_0008_0.3_0.5.npz",
            "mega_8_scenes_0019_0.1_0.3.npz",
            "mega_8_scenes_0019_0.3_0.5.npz",
            "mega_8_scenes_0021_0.1_0.3.npz",
            "mega_8_scenes_0021_0.3_0.5.npz",
            "mega_8_scenes_0024_0.1_0.3.npz",
            "mega_8_scenes_0024_0.3_0.5.npz",
            "mega_8_scenes_0025_0.1_0.3.npz",
            "mega_8_scenes_0025_0.3_0.5.npz",
            "mega_8_scenes_0032_0.1_0.3.npz",
            "mega_8_scenes_0032_0.3_0.5.npz",
            "mega_8_scenes_0063_0.1_0.3.npz",
            "mega_8_scenes_0063_0.3_0.5.npz",
            "mega_8_scenes_1589_0.1_0.3.npz",
            "mega_8_scenes_1589_0.3_0.5.npz",
        ]
        self.benchmark_name = "MegaIMCPT"
        self.model = "essential"


class MegaIMCPT_F(MegaDepthPoseEstimationBenchmark):
    def _post_init(self):
        self.scene_names = [
            "mega_8_scenes_0008_0.1_0.3.npz",
            "mega_8_scenes_0008_0.3_0.5.npz",
            "mega_8_scenes_0019_0.1_0.3.npz",
            "mega_8_scenes_0019_0.3_0.5.npz",
            "mega_8_scenes_0021_0.1_0.3.npz",
            "mega_8_scenes_0021_0.3_0.5.npz",
            "mega_8_scenes_0024_0.1_0.3.npz",
            "mega_8_scenes_0024_0.3_0.5.npz",
            "mega_8_scenes_0025_0.1_0.3.npz",
            "mega_8_scenes_0025_0.3_0.5.npz",
            "mega_8_scenes_0032_0.1_0.3.npz",
            "mega_8_scenes_0032_0.3_0.5.npz",
            "mega_8_scenes_0063_0.1_0.3.npz",
            "mega_8_scenes_0063_0.3_0.5.npz",
            "mega_8_scenes_1589_0.1_0.3.npz",
            "mega_8_scenes_1589_0.3_0.5.npz",
        ]
        self.benchmark_name = "MegaIMCPT_F"
        self.model = "fundamental"
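compute_auc (from the Benchmark base class) turns the pose errors into AUC at the 5/10/20-degree thresholds. A minimal self-contained sketch of the standard pose-AUC computation used by such benchmarks (this mirrors the common definition, not necessarily dad.utils verbatim):

import numpy as np

def pose_auc(errors, thresholds=(5, 10, 20)):
    # area under the recall-vs-error curve, normalized per threshold
    errors = np.sort(np.asarray(errors, dtype=np.float64))
    recall = (np.arange(len(errors)) + 1) / len(errors)
    errors = np.r_[0.0, errors]
    recall = np.r_[0.0, recall]
    aucs = []
    for t in thresholds:
        last = np.searchsorted(errors, t)
        r = np.r_[recall[:last], recall[last - 1]]
        e = np.r_[errors[:last], t]
        aucs.append(np.trapz(r, x=e) / t)
    return aucs

print(pose_auc([1.0, 3.0, 7.0, 25.0]))  # [0.375, 0.5625, 0.65625]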
imcui/third_party/dad/dad/benchmarks/num_inliers.py
ADDED
@@ -0,0 +1,106 @@
import torch
import torch.nn.functional as F
from tqdm import tqdm

from dad.types import Detector
from dad.utils import get_gt_warp, to_best_device


class NumInliersBenchmark:
    def __init__(
        self,
        dataset,
        num_samples=1000,
        batch_size=8,
        num_keypoints=512,
        **kwargs,
    ) -> None:
        sampler = torch.utils.data.WeightedRandomSampler(
            torch.ones(len(dataset)), replacement=False, num_samples=num_samples
        )
        dataloader = torch.utils.data.DataLoader(
            dataset, batch_size=batch_size, num_workers=batch_size, sampler=sampler
        )
        self.dataloader = dataloader
        self.tracked_metrics = {}
        self.batch_size = batch_size
        self.N = len(dataloader)
        self.num_keypoints = num_keypoints

    def compute_batch_metrics(self, outputs, batch):
        kpts_A, kpts_B = outputs["keypoints_A"], outputs["keypoints_B"]
        B, K, H, W = batch["im_A"].shape
        gt_warp_A_to_B, valid_mask_A_to_B = get_gt_warp(
            batch["im_A_depth"],
            batch["im_B_depth"],
            batch["T_1to2"],
            batch["K1"],
            batch["K2"],
            H=H,
            W=W,
        )
        kpts_A_to_B = F.grid_sample(
            gt_warp_A_to_B[..., 2:].float().permute(0, 3, 1, 2),
            kpts_A[..., None, :],
            align_corners=False,
            mode="bilinear",
        )[..., 0].mT
        legit_A_to_B = F.grid_sample(
            valid_mask_A_to_B.reshape(B, 1, H, W),
            kpts_A[..., None, :],
            align_corners=False,
            mode="bilinear",
        )[..., 0, :, 0]
        dists = (
            torch.cdist(kpts_A_to_B, kpts_B).min(dim=-1).values[legit_A_to_B > 0.0]
        ).float()
        if legit_A_to_B.sum() == 0:
            return
        percent_inliers_at_1 = (dists < 0.02).float().mean()
        percent_inliers_at_05 = (dists < 0.01).float().mean()
        percent_inliers_at_025 = (dists < 0.005).float().mean()
        percent_inliers_at_01 = (dists < 0.002).float().mean()
        percent_inliers_at_005 = (dists < 0.001).float().mean()

        self.tracked_metrics["percent_inliers_at_1"] = (
            self.tracked_metrics.get("percent_inliers_at_1", 0)
            + 1 / self.N * percent_inliers_at_1
        )
        self.tracked_metrics["percent_inliers_at_05"] = (
            self.tracked_metrics.get("percent_inliers_at_05", 0)
            + 1 / self.N * percent_inliers_at_05
        )
        self.tracked_metrics["percent_inliers_at_025"] = (
            self.tracked_metrics.get("percent_inliers_at_025", 0)
            + 1 / self.N * percent_inliers_at_025
        )
        self.tracked_metrics["percent_inliers_at_01"] = (
            self.tracked_metrics.get("percent_inliers_at_01", 0)
            + 1 / self.N * percent_inliers_at_01
        )
        self.tracked_metrics["percent_inliers_at_005"] = (
            self.tracked_metrics.get("percent_inliers_at_005", 0)
            + 1 / self.N * percent_inliers_at_005
        )

    def benchmark(self, detector: Detector):
        self.tracked_metrics = {}

        print("Evaluating percent inliers...")
        for idx, batch in enumerate(tqdm(self.dataloader, mininterval=10.0)):
            batch = to_best_device(batch)
            outputs = detector.detect(batch, num_keypoints=self.num_keypoints)
            kps = outputs["keypoints"]
            if isinstance(kps, (tuple, list)):
                # per-image keypoint lists: stack each half into a batched tensor
                half = len(kps) // 2
                keypoints_A = torch.stack(kps[:half])
                keypoints_B = torch.stack(kps[half:])
            else:
                keypoints_A, keypoints_B = kps.chunk(2)
            outputs = {"keypoints_A": keypoints_A, "keypoints_B": keypoints_B}
            self.compute_batch_metrics(outputs, batch)
            for name, metric in self.tracked_metrics.items():
                if "percent" in name:
                    print(name, metric.item() * self.N / (idx + 1))
        return self.tracked_metrics
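The inlier thresholds above are in the normalized [-1, 1] grid coordinates that F.grid_sample uses, so a threshold of 0.02 corresponds to 1% of the image side. A toy illustration of the distance test (hypothetical keypoints):

import torch

kpts_A_to_B = torch.tensor([[0.00, 0.00], [0.50, 0.50]])  # keypoints from A warped into B
kpts_B = torch.tensor([[0.005, 0.000], [0.100, 0.100], [0.520, 0.500]])
dists = torch.cdist(kpts_A_to_B, kpts_B).min(dim=-1).values
print((dists < 0.02).float().mean())  # tensor(0.5000): one of the two is within 0.02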
imcui/third_party/dad/dad/benchmarks/scannet.py
ADDED
@@ -0,0 +1,163 @@
import os.path as osp
from typing import Literal, Optional

import numpy as np
import torch
from PIL import Image
from tqdm import tqdm

from dad.types import Detector, Matcher, Benchmark
from dad.utils import (
    compute_pose_error,
    estimate_pose_essential,
    estimate_pose_fundamental,
)


class ScanNetBenchmark(Benchmark):
    def __init__(
        self,
        sample_every: int = 1,
        num_ransac_runs=5,
        data_root: str = "data/scannet",
        num_keypoints: Optional[list[int]] = None,
    ) -> None:
        super().__init__(
            data_root=data_root,
            num_keypoints=num_keypoints,
            sample_every=sample_every,
            num_ransac_runs=num_ransac_runs,
            thresholds=[5, 10, 20],
        )
        self.sample_every = sample_every
        self.topleft = 0.0
        self._post_init()
        self.model: Literal["fundamental", "essential"]
        self.test_pairs: str
        self.benchmark_name: str

    def _post_init(self):
        # set self.test_pairs, self.benchmark_name and self.model here
        raise NotImplementedError()

    @torch.no_grad()
    def benchmark(self, matcher: Matcher, detector: Detector):
        tmp = np.load(self.test_pairs)
        pairs, rel_pose = tmp["name"], tmp["rel_pose"]
        tot_e_pose = []
        # pair_inds = np.random.choice(range(len(pairs)), size=len(pairs), replace=False)
        for pairind in tqdm(
            range(0, len(pairs), self.sample_every), smoothing=0.9, mininterval=10
        ):
            scene = pairs[pairind]
            scene_name = f"scene0{scene[0]}_00"
            im_A_path = osp.join(
                self.data_root,
                "scans_test",
                scene_name,
                "color",
                f"{scene[2]}.jpg",
            )
            im_A = Image.open(im_A_path)
            im_B_path = osp.join(
                self.data_root,
                "scans_test",
                scene_name,
                "color",
                f"{scene[3]}.jpg",
            )
            im_B = Image.open(im_B_path)
            T_gt = rel_pose[pairind].reshape(3, 4)
            R, t = T_gt[:3, :3], T_gt[:3, 3]
            K = np.stack(
                [
                    np.array([float(i) for i in r.split()])
                    for r in open(
                        osp.join(
                            self.data_root,
                            "scans_test",
                            scene_name,
                            "intrinsic",
                            "intrinsic_color.txt",
                        ),
                        "r",
                    )
                    .read()
                    .split("\n")
                    if r
                ]
            )
            w1, h1 = im_A.size
            w2, h2 = im_B.size
            K1 = K.copy()[:3, :3]
            K2 = K.copy()[:3, :3]
            warp, certainty = matcher.match(im_A_path, im_B_path)
            for num_kps in self.num_keypoints:
                keypoints_A = detector.detect_from_path(
                    im_A_path,
                    num_keypoints=num_kps,
                )["keypoints"][0]
                keypoints_B = detector.detect_from_path(
                    im_B_path,
                    num_keypoints=num_kps,
                )["keypoints"][0]
                matches = matcher.match_keypoints(
                    keypoints_A,
                    keypoints_B,
                    warp,
                    certainty,
                    return_tuple=False,
                )
                kpts1, kpts2 = matcher.to_pixel_coordinates(matches, h1, w1, h2, w2)

                offset = detector.topleft - self.topleft
                kpts1, kpts2 = kpts1 - offset, kpts2 - offset

                for _ in range(self.num_ransac_runs):
                    shuffling = np.random.permutation(np.arange(len(kpts1)))
                    kpts1 = kpts1[shuffling]
                    kpts2 = kpts2[shuffling]
                    threshold = 2.0
                    if self.model == "essential":
                        R_est, t_est = estimate_pose_essential(
                            kpts1.cpu().numpy(),
                            kpts2.cpu().numpy(),
                            w1,
                            h1,
                            K1,
                            w2,
                            h2,
                            K2,
                            threshold,
                        )
                    elif self.model == "fundamental":
                        R_est, t_est = estimate_pose_fundamental(
                            kpts1.cpu().numpy(),
                            kpts2.cpu().numpy(),
                            w1,
                            h1,
                            K1,
                            w2,
                            h2,
                            K2,
                            threshold,
                        )
                    T1_to_2_est = np.concatenate((R_est, t_est[:, None]), axis=-1)
                    e_t, e_R = compute_pose_error(T1_to_2_est, R, t)
                    e_pose = max(e_t, e_R)
                    tot_e_pose.append(e_pose)
        return self.compute_auc(np.array(tot_e_pose))


class ScanNet1500(ScanNetBenchmark):
    def _post_init(self):
        self.test_pairs = osp.join(self.data_root, "test.npz")
        self.benchmark_name = "ScanNet1500"
        self.model = "essential"


class ScanNet1500_F(ScanNetBenchmark):
    def _post_init(self):
        self.test_pairs = osp.join(self.data_root, "test.npz")
        self.benchmark_name = "ScanNet1500_F"
        self.model = "fundamental"
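The K parsing above reads ScanNet's intrinsic_color.txt, a whitespace-separated 4x4 matrix, and keeps only the top-left 3x3 block. A standalone sketch with hypothetical values:

import numpy as np

text = "1169.62 0 646.30 0\n0 1167.11 489.93 0\n0 0 1 0\n0 0 0 1\n"
K = np.stack([np.array([float(v) for v in row.split()]) for row in text.split("\n") if row])
K = K[:3, :3]  # pinhole intrinsics used for essential/fundamental estimation
print(K.shape)  # (3, 3)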
imcui/third_party/dad/dad/checkpoint.py
ADDED
@@ -0,0 +1,61 @@
import gc
from pathlib import Path

import torch
from torch.nn.parallel.data_parallel import DataParallel
from torch.nn.parallel.distributed import DistributedDataParallel

import dad
from dad.types import Detector


class CheckPoint:
    def __init__(self, dir):
        self.dir = Path(dir)
        self.dir.mkdir(parents=True, exist_ok=True)

    def save(
        self,
        model: Detector,
        optimizer,
        lr_scheduler,
        n,
    ):
        assert model is not None
        if isinstance(model, (DataParallel, DistributedDataParallel)):
            model = model.module
        states = {
            "model": model.state_dict(),
            "n": n,
            "optimizer": optimizer.state_dict(),
            "lr_scheduler": lr_scheduler.state_dict(),
        }
        torch.save(states, self.dir / "model_latest.pth")
        dad.logger.info(f"Saved states {list(states.keys())}, at step {n}")

    def load(
        self,
        model: Detector,
        optimizer,
        lr_scheduler,
        n,
    ):
        if not (self.dir / "model_latest.pth").exists():
            return model, optimizer, lr_scheduler, n

        states = torch.load(self.dir / "model_latest.pth")
        if "model" in states:
            model.load_state_dict(states["model"])
        if "n" in states:
            n = states["n"] if states["n"] else n
        if "optimizer" in states:
            try:
                optimizer.load_state_dict(states["optimizer"])
            except Exception as e:
                dad.logger.warning(
                    f"Failed to load states for optimizer, with error {e}"
                )
        if "lr_scheduler" in states:
            lr_scheduler.load_state_dict(states["lr_scheduler"])
        dad.logger.info(f"Loaded states {list(states.keys())}, at step {n}")
        del states
        gc.collect()
        torch.cuda.empty_cache()
        return model, optimizer, lr_scheduler, n
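A hedged usage sketch of CheckPoint in a training loop (assumes the dad package is importable and this module lives at dad.checkpoint; any nn.Module with a state_dict works as the model stand-in):

import torch
from dad.checkpoint import CheckPoint  # import path assumed from the file location

model = torch.nn.Linear(2, 2)  # stand-in for a Detector
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1000)

ckpt = CheckPoint("workspace/checkpoints")
# resumes from model_latest.pth if present, otherwise returns the inputs unchanged
model, optimizer, lr_scheduler, step = ckpt.load(model, optimizer, lr_scheduler, n=0)
ckpt.save(model, optimizer, lr_scheduler, n=step)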
imcui/third_party/dad/dad/datasets/__init__.py
ADDED
File without changes
|
imcui/third_party/dad/dad/datasets/megadepth.py
ADDED
@@ -0,0 +1,312 @@
import math
import os

import h5py
import numpy as np
import torch
import torchvision.transforms.functional as tvf
from PIL import Image
from torch.utils.data import ConcatDataset
from tqdm import tqdm

import dad
from dad.augs import (
    get_tuple_transform_ops,
    get_depth_tuple_transform_ops,
)


class MegadepthScene:
    def __init__(
        self,
        data_root,
        scene_info,
        scene_name=None,
        min_overlap=0.0,
        max_overlap=1.0,
        image_size=640,
        normalize=True,
        shake_t=32,
        rot_360=False,
        max_num_pairs=100_000,
    ) -> None:
        self.data_root = data_root
        self.scene_name = (
            os.path.splitext(scene_name)[0] + f"_{min_overlap}_{max_overlap}"
        )
        self.image_paths = scene_info["image_paths"]
        self.depth_paths = scene_info["depth_paths"]
        self.intrinsics = scene_info["intrinsics"]
        self.poses = scene_info["poses"]
        self.pairs = scene_info["pairs"]
        self.overlaps = scene_info["overlaps"]
        threshold = (self.overlaps > min_overlap) & (self.overlaps < max_overlap)
        self.pairs = self.pairs[threshold]
        self.overlaps = self.overlaps[threshold]
        if len(self.pairs) > max_num_pairs:
            pairinds = np.random.choice(
                np.arange(0, len(self.pairs)), max_num_pairs, replace=False
            )
            self.pairs = self.pairs[pairinds]
            self.overlaps = self.overlaps[pairinds]
        self.im_transform_ops = get_tuple_transform_ops(
            resize=(image_size, image_size),
            normalize=normalize,
        )
        self.depth_transform_ops = get_depth_tuple_transform_ops(
            resize=(image_size, image_size), normalize=False
        )
        self.image_size = image_size
        self.shake_t = shake_t
        self.rot_360 = rot_360

    def load_im(self, im_path, crop=None):
        im = Image.open(im_path)
        return im

    def rot_360_deg(self, im, depth, K, angle):
        C, H, W = im.shape
        im = tvf.rotate(im, angle, expand=True)
        depth = tvf.rotate(depth, angle, expand=True)
        radians = angle * math.pi / 180
        rot_mat = torch.tensor(
            [
                [math.cos(radians), math.sin(radians), 0],
                [-math.sin(radians), math.cos(radians), 0],
                [0, 0, 1.0],
            ]
        ).to(K.device)
        t_mat = torch.tensor([[1, 0, W / 2], [0, 1, H / 2], [0, 0, 1.0]]).to(K.device)
        neg_t_mat = torch.tensor([[1, 0, -W / 2], [0, 1, -H / 2], [0, 0, 1.0]]).to(
            K.device
        )
        transform = t_mat @ rot_mat @ neg_t_mat
        K = transform @ K
        return im, depth, K, transform

    def load_depth(self, depth_ref, crop=None):
        depth = np.array(h5py.File(depth_ref, "r")["depth"])
        return torch.from_numpy(depth)

    def __len__(self):
        return len(self.pairs)

    def scale_intrinsic(self, K, wi, hi):
        sx, sy = self.image_size / wi, self.image_size / hi
        sK = torch.tensor([[sx, 0, 0], [0, sy, 0], [0, 0, 1]])
        return sK @ K

    def rand_shake(self, *things):
        t = np.random.choice(range(-self.shake_t, self.shake_t + 1), size=(2))
        return [
            tvf.affine(thing, angle=0.0, translate=list(t), scale=1.0, shear=[0.0, 0.0])
            for thing in things
        ], t

    def __getitem__(self, pair_idx):
        try:
            # read intrinsics of original size
            idx1, idx2 = self.pairs[pair_idx]
            K1 = torch.tensor(self.intrinsics[idx1].copy(), dtype=torch.float).reshape(
                3, 3
            )
            K2 = torch.tensor(self.intrinsics[idx2].copy(), dtype=torch.float).reshape(
                3, 3
            )

            # read and compute relative poses
            T1 = self.poses[idx1]
            T2 = self.poses[idx2]
            T_1to2 = torch.tensor(np.matmul(T2, np.linalg.inv(T1)), dtype=torch.float)[
                :4, :4
            ]  # (4, 4)

            # Load positive pair data
            im_A, im_B = self.image_paths[idx1], self.image_paths[idx2]
            depth1, depth2 = self.depth_paths[idx1], self.depth_paths[idx2]
            im_A_ref = os.path.join(self.data_root, im_A)
            im_B_ref = os.path.join(self.data_root, im_B)
            depth_A_ref = os.path.join(self.data_root, depth1)
            depth_B_ref = os.path.join(self.data_root, depth2)
            im_A: Image.Image = self.load_im(im_A_ref)
            im_B: Image.Image = self.load_im(im_B_ref)
            depth_A = self.load_depth(depth_A_ref)
            depth_B = self.load_depth(depth_B_ref)

            # Recompute camera intrinsic matrix due to the resize
            W_A, H_A = im_A.width, im_A.height
            W_B, H_B = im_B.width, im_B.height

            K1 = self.scale_intrinsic(K1, W_A, H_A)
            K2 = self.scale_intrinsic(K2, W_B, H_B)

            # Process images
            im_A, im_B = self.im_transform_ops((im_A, im_B))
            depth_A, depth_B = self.depth_transform_ops(
                (depth_A[None, None], depth_B[None, None])
            )
            [im_A, depth_A], t_A = self.rand_shake(im_A, depth_A)
            [im_B, depth_B], t_B = self.rand_shake(im_B, depth_B)

            K1[:2, 2] += t_A
            K2[:2, 2] += t_B

            if self.rot_360:
                angle_A = np.random.choice([-90, 0, 90, 180])
                angle_B = np.random.choice([-90, 0, 90, 180])
                angle_A, angle_B = int(angle_A), int(angle_B)
                im_A, depth_A, K1, _ = self.rot_360_deg(
                    im_A, depth_A, K1, angle=angle_A
                )
                im_B, depth_B, K2, _ = self.rot_360_deg(
                    im_B, depth_B, K2, angle=angle_B
                )
            else:
                angle_A = 0
                angle_B = 0
            data_dict = {
                "im_A": im_A,
                "im_A_identifier": self.image_paths[idx1]
                .split("/")[-1]
                .split(".jpg")[0],
                "im_B": im_B,
                "im_B_identifier": self.image_paths[idx2]
                .split("/")[-1]
                .split(".jpg")[0],
                "im_A_depth": depth_A[0, 0],
                "im_B_depth": depth_B[0, 0],
                "pose_A": T1,
                "pose_B": T2,
                "K1": K1,
                "K2": K2,
                "T_1to2": T_1to2,
                "im_A_path": im_A_ref,
                "im_B_path": im_B_ref,
                "angle_A": angle_A,
                "angle_B": angle_B,
            }
        except Exception as e:
            dad.logger.warning(e)
            dad.logger.warning(f"Failed to load image pair {self.pairs[pair_idx]}")
            dad.logger.warning("Loading a random pair in scene instead")
            rand_ind = np.random.choice(range(len(self)))
            return self[rand_ind]
        return data_dict


class MegadepthBuilder:
    def __init__(self, data_root, loftr_ignore=True, imc21_ignore=True) -> None:
        self.data_root = data_root
        self.scene_info_root = os.path.join(data_root, "prep_scene_info")
        self.all_scenes = os.listdir(self.scene_info_root)
        self.test_scenes = ["0017.npy", "0004.npy", "0048.npy", "0013.npy"]
        # LoFTR did the D2-Net preprocessing differently than we did and got more
        # ignore scenes; these can optionally be ignored as well
        self.loftr_ignore_scenes = set(
            [
                "0121.npy",
                "0133.npy",
                "0168.npy",
                "0178.npy",
                "0229.npy",
                "0349.npy",
                "0412.npy",
                "0430.npy",
                "0443.npy",
                "1001.npy",
                "5014.npy",
                "5015.npy",
                "5016.npy",
            ]
        )
        self.imc21_scenes = set(
            [
                "0008.npy",
                "0019.npy",
                "0021.npy",
                "0024.npy",
                "0025.npy",
                "0032.npy",
                "0063.npy",
                "1589.npy",
            ]
        )
        self.test_scenes_loftr = ["0015.npy", "0022.npy"]
        self.loftr_ignore = loftr_ignore
        self.imc21_ignore = imc21_ignore

    def build_scenes(self, split, **kwargs):
        if split == "train":
            scene_names = set(self.all_scenes) - set(self.test_scenes)
        elif split == "train_loftr":
            scene_names = set(self.all_scenes) - set(self.test_scenes_loftr)
        elif split == "test":
            scene_names = self.test_scenes
        elif split == "test_loftr":
            scene_names = self.test_scenes_loftr
        elif split == "all_scenes":
            scene_names = self.all_scenes
        elif split == "custom":
            # custom scene lists must be supplied by the caller
            scene_names = kwargs.pop("scene_names")
        else:
            raise ValueError(f"Split {split} not available")
        scenes = []
        for scene_name in tqdm(scene_names):
            if self.loftr_ignore and scene_name in self.loftr_ignore_scenes:
                continue
            if self.imc21_ignore and scene_name in self.imc21_scenes:
                continue
            if ".npy" not in scene_name:
                continue
            scene_info = np.load(
                os.path.join(self.scene_info_root, scene_name), allow_pickle=True
            ).item()

            scenes.append(
                MegadepthScene(
                    self.data_root,
                    scene_info,
                    scene_name=scene_name,
                    **kwargs,
                )
            )
        return scenes

    def weight_scenes(self, concat_dataset, alpha=0.5):
        ns = []
        for d in concat_dataset.datasets:
            ns.append(len(d))
        ws = torch.cat([torch.ones(n) / n**alpha for n in ns])
        return ws

    def dedode_train_split(self, **kwargs):
        megadepth_train1 = self.build_scenes(
            split="train_loftr", min_overlap=0.01, **kwargs
        )
        megadepth_train2 = self.build_scenes(
            split="train_loftr", min_overlap=0.35, **kwargs
        )

        megadepth_train = ConcatDataset(megadepth_train1 + megadepth_train2)
        return megadepth_train

    def hard_train_split(self, **kwargs):
        megadepth_train = self.build_scenes(
            split="train_loftr", min_overlap=0.01, **kwargs
        )
        megadepth_train = ConcatDataset(megadepth_train)
        return megadepth_train

    def easy_train_split(self, **kwargs):
        megadepth_train = self.build_scenes(
            split="train_loftr", min_overlap=0.35, **kwargs
        )
        megadepth_train = ConcatDataset(megadepth_train)
        return megadepth_train

    def dedode_test_split(self, **kwargs):
        megadepth_test = self.build_scenes(
            split="test_loftr",
            min_overlap=0.01,
            **kwargs,
        )
        megadepth_test = ConcatDataset(megadepth_test)
        return megadepth_test
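A hedged sketch of building the training split (assumes the MegaDepth images, depths, and prep_scene_info files are laid out under data/megadepth as the builder expects):

from dad.datasets.megadepth import MegadepthBuilder  # import path assumed

mega = MegadepthBuilder(data_root="data/megadepth")
train_set = mega.dedode_train_split(image_size=640, shake_t=32)
print(len(train_set))  # total pairs across all training scenes, both overlap ranges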
imcui/third_party/dad/dad/detectors/__init__.py
ADDED
@@ -0,0 +1,50 @@
from .dedode_detector import load_DaD as load_DaD
from .dedode_detector import load_DaDDark as load_DaDDark
from .dedode_detector import load_DaDLight as load_DaDLight
from .dedode_detector import dedode_detector_S as dedode_detector_S
from .dedode_detector import dedode_detector_B as dedode_detector_B
from .dedode_detector import dedode_detector_L as dedode_detector_L
from .dedode_detector import load_dedode_v2 as load_dedode_v2


lg_detectors = ["ALIKED", "ALIKEDROT", "SIFT", "DISK", "SuperPoint", "ReinforcedFP"]
other_detectors = ["HesAff", "HarrisAff", "REKD"]
dedode_detectors = [
    "DeDoDe-v2",
    "DaD",
    "DaDLight",
    "DaDDark",
]
all_detectors = lg_detectors + dedode_detectors + other_detectors


def load_detector_by_name(detector_name, *, resize=1024, weights_path=None):
    if detector_name == "DaD":
        detector = load_DaD(resize=resize, weights_path=weights_path)
    elif detector_name == "DaDLight":
        detector = load_DaDLight(resize=resize, weights_path=weights_path)
    elif detector_name == "DaDDark":
        detector = load_DaDDark(resize=resize, weights_path=weights_path)
    elif detector_name == "DeDoDe-v2":
        detector = load_dedode_v2()
    elif detector_name in lg_detectors:
        from .third_party import lightglue, LightGlueDetector

        detector = LightGlueDetector(
            getattr(lightglue, detector_name), detection_threshold=0, resize=resize
        )
    elif detector_name == "HesAff":
        from .third_party import HesAff

        detector = HesAff()
    elif detector_name == "HarrisAff":
        from .third_party import HarrisAff

        detector = HarrisAff()
    elif detector_name == "REKD":
        from .third_party import load_REKD

        detector = load_REKD(resize=resize)
    else:
        raise ValueError(f"Couldn't find detector with detector name {detector_name}")
    return detector
ADDED
@@ -0,0 +1,559 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import torch.nn.functional as F
|
6 |
+
import torchvision.models as tvm
|
7 |
+
import torchvision.transforms as transforms
|
8 |
+
from PIL import Image
|
9 |
+
from dad.utils import get_best_device, sample_keypoints, check_not_i16
|
10 |
+
|
11 |
+
from dad.types import Detector
|
12 |
+
|
13 |
+
|
14 |
+
class DeDoDeDetector(Detector):
|
15 |
+
def __init__(
|
16 |
+
self,
|
17 |
+
*args,
|
18 |
+
encoder: nn.Module,
|
19 |
+
decoder: nn.Module,
|
20 |
+
resize: int,
|
21 |
+
nms_size: int,
|
22 |
+
subpixel: bool,
|
23 |
+
subpixel_temp: float,
|
24 |
+
keep_aspect_ratio: bool,
|
25 |
+
remove_borders: bool,
|
26 |
+
increase_coverage: bool,
|
27 |
+
coverage_pow: float,
|
28 |
+
coverage_size: int,
|
29 |
+
**kwargs,
|
30 |
+
) -> None:
|
31 |
+
super().__init__(*args, **kwargs)
|
32 |
+
self.normalizer = transforms.Normalize(
|
33 |
+
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
|
34 |
+
)
|
35 |
+
self.encoder = encoder
|
36 |
+
self.decoder = decoder
|
37 |
+
self.remove_borders = remove_borders
|
38 |
+
self.resize = resize
|
39 |
+
self.increase_coverage = increase_coverage
|
40 |
+
self.coverage_pow = coverage_pow
|
41 |
+
self.coverage_size = coverage_size
|
42 |
+
self.nms_size = nms_size
|
43 |
+
self.keep_aspect_ratio = keep_aspect_ratio
|
44 |
+
self.subpixel = subpixel
|
45 |
+
self.subpixel_temp = subpixel_temp
|
46 |
+
|
47 |
+
@property
|
48 |
+
def topleft(self):
|
49 |
+
return 0.5
|
50 |
+
|
51 |
+
def forward_impl(
|
52 |
+
self,
|
53 |
+
images,
|
54 |
+
):
|
55 |
+
features, sizes = self.encoder(images)
|
56 |
+
logits = 0
|
57 |
+
context = None
|
58 |
+
scales = ["8", "4", "2", "1"]
|
59 |
+
for idx, (feature_map, scale) in enumerate(zip(reversed(features), scales)):
|
60 |
+
delta_logits, context = self.decoder(
|
61 |
+
feature_map, context=context, scale=scale
|
62 |
+
)
|
63 |
+
logits = (
|
64 |
+
logits + delta_logits.float()
|
65 |
+
) # ensure float (need bf16 doesnt have f.interpolate)
|
66 |
+
if idx < len(scales) - 1:
|
67 |
+
size = sizes[-(idx + 2)]
|
68 |
+
logits = F.interpolate(
|
69 |
+
logits, size=size, mode="bicubic", align_corners=False
|
70 |
+
)
|
71 |
+
context = F.interpolate(
|
72 |
+
context.float(), size=size, mode="bilinear", align_corners=False
|
73 |
+
)
|
74 |
+
return logits.float()
|
75 |
+
|
76 |
+
def forward(self, batch) -> dict[str, torch.Tensor]:
|
77 |
+
# wraps internal forward impl to handle
|
78 |
+
# different types of batches etc.
|
79 |
+
if "im_A" in batch:
|
80 |
+
images = torch.cat((batch["im_A"], batch["im_B"]))
|
81 |
+
else:
|
82 |
+
images = batch["image"]
|
83 |
+
scoremap = self.forward_impl(images)
|
84 |
+
return {"scoremap": scoremap}
|
85 |
+
|
86 |
+
@torch.inference_mode()
|
87 |
+
def detect(
|
88 |
+
self, batch, *, num_keypoints, return_dense_probs=False
|
89 |
+
) -> dict[str, torch.Tensor]:
|
90 |
+
self.train(False)
|
91 |
+
scoremap = self.forward(batch)["scoremap"]
|
92 |
+
B, K, H, W = scoremap.shape
|
93 |
+
dense_probs = (
|
94 |
+
scoremap.reshape(B, K * H * W)
|
95 |
+
.softmax(dim=-1)
|
96 |
+
.reshape(B, K, H * W)
|
97 |
+
.sum(dim=1)
|
98 |
+
)
|
99 |
+
dense_probs = dense_probs.reshape(B, H, W)
|
100 |
+
keypoints, confidence = sample_keypoints(
|
101 |
+
dense_probs,
|
102 |
+
use_nms=True,
|
103 |
+
nms_size=self.nms_size,
|
104 |
+
sample_topk=True,
|
105 |
+
num_samples=num_keypoints,
|
106 |
+
return_probs=True,
|
107 |
+
increase_coverage=self.increase_coverage,
|
108 |
+
remove_borders=self.remove_borders,
|
109 |
+
coverage_pow=self.coverage_pow,
|
110 |
+
coverage_size=self.coverage_size,
|
111 |
+
subpixel=self.subpixel,
|
112 |
+
subpixel_temp=self.subpixel_temp,
|
113 |
+
scoremap=scoremap.reshape(B, H, W),
|
114 |
+
)
|
115 |
+
result = {"keypoints": keypoints, "keypoint_probs": confidence}
|
116 |
+
if return_dense_probs:
|
117 |
+
result["dense_probs"] = dense_probs
|
118 |
+
return result
|
119 |
+
|
120 |
+
def load_image(self, im_path, device=get_best_device()) -> dict[str, torch.Tensor]:
|
121 |
+
pil_im = Image.open(im_path)
|
122 |
+
check_not_i16(pil_im)
|
123 |
+
pil_im = pil_im.convert("RGB")
|
124 |
+
if self.keep_aspect_ratio:
|
125 |
+
W, H = pil_im.size
|
126 |
+
scale = self.resize / max(W, H)
|
127 |
+
W = int((scale * W) // 8 * 8)
|
128 |
+
H = int((scale * H) // 8 * 8)
|
129 |
+
else:
|
130 |
+
H, W = self.resize, self.resize
|
131 |
+
pil_im = pil_im.resize((W, H))
|
132 |
+
standard_im = np.array(pil_im) / 255.0
|
133 |
+
return {
|
134 |
+
"image": self.normalizer(torch.from_numpy(standard_im).permute(2, 0, 1))
|
135 |
+
.float()
|
136 |
+
.to(device)[None]
|
137 |
+
}
|
138 |
+
|
139 |
+
|
140 |
+
class Decoder(nn.Module):
|
141 |
+
def __init__(
|
142 |
+
self, layers, *args, super_resolution=False, num_prototypes=1, **kwargs
|
143 |
+
) -> None:
|
144 |
+
super().__init__(*args, **kwargs)
|
145 |
+
self.layers = layers
|
146 |
+
self.scales = self.layers.keys()
|
147 |
+
self.super_resolution = super_resolution
|
148 |
+
self.num_prototypes = num_prototypes
|
149 |
+
|
150 |
+
def forward(self, features, context=None, scale=None):
|
151 |
+
if context is not None:
|
152 |
+
features = torch.cat((features, context), dim=1)
|
153 |
+
stuff = self.layers[scale](features)
|
154 |
+
logits, context = (
|
155 |
+
stuff[:, : self.num_prototypes],
|
156 |
+
stuff[:, self.num_prototypes :],
|
157 |
+
)
|
158 |
+
return logits, context
|
159 |
+
|
160 |
+
|
161 |
+
class ConvRefiner(nn.Module):
|
162 |
+
def __init__(
|
163 |
+
self,
|
164 |
+
in_dim=6,
|
165 |
+
hidden_dim=16,
|
166 |
+
out_dim=2,
|
167 |
+
dw=True,
|
168 |
+
kernel_size=5,
|
169 |
+
hidden_blocks=5,
|
170 |
+
amp=True,
|
171 |
+
residual=False,
|
172 |
+
amp_dtype=torch.float16,
|
173 |
+
):
|
174 |
+
super().__init__()
|
175 |
+
self.block1 = self.create_block(
|
176 |
+
in_dim,
|
177 |
+
hidden_dim,
|
178 |
+
dw=False,
|
179 |
+
kernel_size=1,
|
180 |
+
)
|
181 |
+
self.hidden_blocks = nn.Sequential(
|
182 |
+
*[
|
183 |
+
self.create_block(
|
184 |
+
hidden_dim,
|
185 |
+
hidden_dim,
|
186 |
+
dw=dw,
|
187 |
+
kernel_size=kernel_size,
|
188 |
+
)
|
189 |
+
for hb in range(hidden_blocks)
|
190 |
+
]
|
191 |
+
)
|
192 |
+
self.hidden_blocks = self.hidden_blocks
|
193 |
+
self.out_conv = nn.Conv2d(hidden_dim, out_dim, 1, 1, 0)
|
194 |
+
self.amp = amp
|
195 |
+
self.amp_dtype = amp_dtype
|
196 |
+
self.residual = residual
|
197 |
+
|
198 |
+
def create_block(
|
199 |
+
self,
|
200 |
+
in_dim,
|
201 |
+
out_dim,
|
202 |
+
dw=True,
|
203 |
+
kernel_size=5,
|
204 |
+
bias=True,
|
205 |
+
norm_type=nn.BatchNorm2d,
|
206 |
+
):
|
207 |
+
num_groups = 1 if not dw else in_dim
|
208 |
+
if dw:
|
209 |
+
assert out_dim % in_dim == 0, (
|
210 |
+
"outdim must be divisible by indim for depthwise"
|
211 |
+
)
|
212 |
+
conv1 = nn.Conv2d(
|
213 |
+
in_dim,
|
214 |
+
out_dim,
|
215 |
+
kernel_size=kernel_size,
|
216 |
+
stride=1,
|
217 |
+
padding=kernel_size // 2,
|
218 |
+
groups=num_groups,
|
219 |
+
bias=bias,
|
220 |
+
)
|
221 |
+
norm = (
|
222 |
+
norm_type(out_dim)
|
223 |
+
if norm_type is nn.BatchNorm2d
|
224 |
+
else norm_type(num_channels=out_dim)
|
225 |
+
)
|
226 |
+
relu = nn.ReLU(inplace=True)
|
227 |
+
conv2 = nn.Conv2d(out_dim, out_dim, 1, 1, 0)
|
228 |
+
return nn.Sequential(conv1, norm, relu, conv2)
|
229 |
+
|
230 |
+
def forward(self, feats):
|
231 |
+
b, c, hs, ws = feats.shape
|
232 |
+
with torch.autocast(device_type=feats.device.type, enabled=self.amp, dtype=self.amp_dtype):
|
233 |
+
x0 = self.block1(feats)
|
234 |
+
x = self.hidden_blocks(x0)
|
235 |
+
if self.residual:
|
236 |
+
x = (x + x0) / 1.4
|
237 |
+
x = self.out_conv(x)
|
238 |
+
return x
|
239 |
+
|
240 |
+
|
241 |
+
class VGG19(nn.Module):
    def __init__(self, amp=False, amp_dtype=torch.float16) -> None:
        super().__init__()
        self.layers = nn.ModuleList(tvm.vgg19_bn().features[:40])
        # Maxpool layers: 6, 13, 26, 39
        self.amp = amp
        self.amp_dtype = amp_dtype

    def forward(self, x, **kwargs):
        with torch.autocast(device_type=x.device.type, enabled=self.amp, dtype=self.amp_dtype):
            feats = []
            sizes = []
            for layer in self.layers:
                if isinstance(layer, nn.MaxPool2d):
                    feats.append(x)
                    sizes.append(x.shape[-2:])
                x = layer(x)
            return feats, sizes


class VGG(nn.Module):
    def __init__(self, size="19", amp=False, amp_dtype=torch.float16) -> None:
        super().__init__()
        if size == "11":
            self.layers = nn.ModuleList(tvm.vgg11_bn().features[:22])
        elif size == "13":
            self.layers = nn.ModuleList(tvm.vgg13_bn().features[:28])
        elif size == "19":
            self.layers = nn.ModuleList(tvm.vgg19_bn().features[:40])
        else:
            raise ValueError(f"Unsupported VGG size: {size}")
        # Maxpool layers: 6, 13, 26, 39
        self.amp = amp
        self.amp_dtype = amp_dtype

    def forward(self, x, **kwargs):
        with torch.autocast(device_type=x.device.type, enabled=self.amp, dtype=self.amp_dtype):
            feats = []
            sizes = []
            for layer in self.layers:
                if isinstance(layer, nn.MaxPool2d):
                    feats.append(x)
                    sizes.append(x.shape[-2:])
                x = layer(x)
            return feats, sizes
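
A small shape sketch (randomly initialized torchvision weights) of the pyramid that the "1"/"2"/"4"/"8" refiners consume; features are collected right before each max-pool:

import torch

enc = VGG19(amp=False)
feats, sizes = enc(torch.randn(1, 3, 256, 256))
print([f.shape[1] for f in feats])  # [64, 128, 256, 512] channels
print([tuple(s) for s in sizes])    # [(256, 256), (128, 128), (64, 64), (32, 32)]
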
def dedode_detector_S():
    residual = True
    hidden_blocks = 3
    amp_dtype = torch.float16
    amp = True
    NUM_PROTOTYPES = 1
    conv_refiner = nn.ModuleDict(
        {
            "8": ConvRefiner(
                512,
                512,
                256 + NUM_PROTOTYPES,
                hidden_blocks=hidden_blocks,
                residual=residual,
                amp=amp,
                amp_dtype=amp_dtype,
            ),
            "4": ConvRefiner(
                256 + 256,
                256,
                128 + NUM_PROTOTYPES,
                hidden_blocks=hidden_blocks,
                residual=residual,
                amp=amp,
                amp_dtype=amp_dtype,
            ),
            "2": ConvRefiner(
                128 + 128,
                64,
                32 + NUM_PROTOTYPES,
                hidden_blocks=hidden_blocks,
                residual=residual,
                amp=amp,
                amp_dtype=amp_dtype,
            ),
            "1": ConvRefiner(
                64 + 32,
                32,
                1 + NUM_PROTOTYPES,
                hidden_blocks=hidden_blocks,
                residual=residual,
                amp=amp,
                amp_dtype=amp_dtype,
            ),
        }
    )
    encoder = VGG(size="11", amp=amp, amp_dtype=amp_dtype)
    decoder = Decoder(conv_refiner)
    return encoder, decoder


def dedode_detector_B():
    residual = True
    hidden_blocks = 5
    amp_dtype = torch.float16
    amp = True
    NUM_PROTOTYPES = 1
    conv_refiner = nn.ModuleDict(
        {
            "8": ConvRefiner(
                512,
                512,
                256 + NUM_PROTOTYPES,
                hidden_blocks=hidden_blocks,
                residual=residual,
                amp=amp,
                amp_dtype=amp_dtype,
            ),
            "4": ConvRefiner(
                256 + 256,
                256,
                128 + NUM_PROTOTYPES,
                hidden_blocks=hidden_blocks,
                residual=residual,
                amp=amp,
                amp_dtype=amp_dtype,
            ),
            "2": ConvRefiner(
                128 + 128,
                64,
                32 + NUM_PROTOTYPES,
                hidden_blocks=hidden_blocks,
                residual=residual,
                amp=amp,
                amp_dtype=amp_dtype,
            ),
            "1": ConvRefiner(
                64 + 32,
                32,
                1 + NUM_PROTOTYPES,
                hidden_blocks=hidden_blocks,
                residual=residual,
                amp=amp,
                amp_dtype=amp_dtype,
            ),
        }
    )
    encoder = VGG19(amp=amp, amp_dtype=amp_dtype)
    decoder = Decoder(conv_refiner)
    return encoder, decoder


def dedode_detector_L():
    NUM_PROTOTYPES = 1
    residual = True
    hidden_blocks = 8
    amp_dtype = torch.float16  # torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
    amp = True
    conv_refiner = nn.ModuleDict(
        {
            "8": ConvRefiner(
                512,
                512,
                256 + NUM_PROTOTYPES,
                hidden_blocks=hidden_blocks,
                residual=residual,
                amp=amp,
                amp_dtype=amp_dtype,
            ),
            "4": ConvRefiner(
                256 + 256,
                256,
                128 + NUM_PROTOTYPES,
                hidden_blocks=hidden_blocks,
                residual=residual,
                amp=amp,
                amp_dtype=amp_dtype,
            ),
            "2": ConvRefiner(
                128 + 128,
                128,
                64 + NUM_PROTOTYPES,
                hidden_blocks=hidden_blocks,
                residual=residual,
                amp=amp,
                amp_dtype=amp_dtype,
            ),
            "1": ConvRefiner(
                64 + 64,
                64,
                1 + NUM_PROTOTYPES,
                hidden_blocks=hidden_blocks,
                residual=residual,
                amp=amp,
                amp_dtype=amp_dtype,
            ),
        }
    )
    encoder = VGG19(amp=amp, amp_dtype=amp_dtype)
    decoder = Decoder(conv_refiner)
    return encoder, decoder
class DaD(DeDoDeDetector):
    def __init__(
        self,
        encoder: nn.Module,
        decoder: nn.Module,
        *args,
        resize=1024,
        nms_size=3,
        remove_borders=False,
        increase_coverage=False,
        coverage_pow=None,
        coverage_size=None,
        subpixel=True,
        subpixel_temp=0.5,
        keep_aspect_ratio=True,
        **kwargs,
    ) -> None:
        super().__init__(
            *args,
            encoder=encoder,
            decoder=decoder,
            resize=resize,
            nms_size=nms_size,
            remove_borders=remove_borders,
            increase_coverage=increase_coverage,
            coverage_pow=coverage_pow,
            coverage_size=coverage_size,
            subpixel=subpixel,
            keep_aspect_ratio=keep_aspect_ratio,
            subpixel_temp=subpixel_temp,
            **kwargs,
        )


class DeDoDev2(DeDoDeDetector):
    def __init__(
        self,
        encoder: nn.Module,
        decoder: nn.Module,
        *args,
        resize=784,
        nms_size=3,
        remove_borders=False,
        increase_coverage=True,
        coverage_pow=0.5,
        coverage_size=51,
        subpixel=False,
        subpixel_temp=None,
        keep_aspect_ratio=False,
        **kwargs,
    ) -> None:
        super().__init__(
            *args,
            encoder=encoder,
            decoder=decoder,
            resize=resize,
            nms_size=nms_size,
            remove_borders=remove_borders,
            increase_coverage=increase_coverage,
            coverage_pow=coverage_pow,
            coverage_size=coverage_size,
            subpixel=subpixel,
            keep_aspect_ratio=keep_aspect_ratio,
            subpixel_temp=subpixel_temp,
            **kwargs,
        )
def load_DaD(resize=1024, pretrained=True, weights_path=None) -> DaD:
    if weights_path is None:
        weights_path = (
            "https://github.com/Parskatt/dad/releases/download/v0.1.0/dad.pth"
        )
    device = get_best_device()
    encoder, decoder = dedode_detector_S()
    model = DaD(encoder, decoder, resize=resize).to(device)
    if pretrained:
        weights = torch.hub.load_state_dict_from_url(
            weights_path, weights_only=False, map_location=device
        )
        model.load_state_dict(weights)
    return model


def load_DaDLight(resize=1024, weights_path=None) -> DaD:
    if weights_path is None:
        weights_path = (
            "https://github.com/Parskatt/dad/releases/download/v0.1.0/dad_light.pth"
        )
    return load_DaD(
        resize=resize,
        pretrained=True,
        weights_path=weights_path,
    )


def load_DaDDark(resize=1024, weights_path=None) -> DaD:
    if weights_path is None:
        weights_path = (
            "https://github.com/Parskatt/dad/releases/download/v0.1.0/dad_dark.pth"
        )
    return load_DaD(
        resize=resize,
        pretrained=True,
        weights_path=weights_path,
    )


def load_dedode_v2() -> DeDoDev2:
    device = get_best_device()
    weights = torch.hub.load_state_dict_from_url(
        "https://github.com/Parskatt/DeDoDe/releases/download/v2/dedode_detector_L_v2.pth",
        map_location=device,
    )

    encoder, decoder = dedode_detector_L()
    model = DeDoDev2(encoder, decoder).to(device)
    model.load_state_dict(weights)
    return model
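A minimal usage sketch for the loaders above. The `load_image`/`detect` interface is assumed to match the other `Detector` subclasses in this diff (e.g. `HarrisAff`); keypoints come back in normalized [-1, 1] coordinates:

import torch

detector = load_DaD()  # downloads dad.pth via torch.hub on first use
batch = detector.load_image("example.jpg")  # assumed interface; "example.jpg" is a placeholder
with torch.inference_mode():
    detections = detector.detect(batch, num_keypoints=2048)
keypoints = detections["keypoints"]  # 1 x N x 2, normalized to [-1, 1]
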
imcui/third_party/dad/dad/detectors/third_party/__init__.py
ADDED
@@ -0,0 +1,11 @@
from .lightglue_detector import LightGlueDetector as LightGlueDetector
from .lightglue import SuperPoint as SuperPoint
from .lightglue import ReinforcedFP as ReinforcedFP
from .lightglue import DISK as DISK
from .lightglue import ALIKED as ALIKED
from .lightglue import ALIKEDROT as ALIKEDROT
from .lightglue import SIFT as SIFT
from .lightglue import DoGHardNet as DoGHardNet
from .hesaff import HesAff as HesAff
from .harrisaff import HarrisAff as HarrisAff
from .rekd.rekd import load_REKD as load_REKD
imcui/third_party/dad/dad/detectors/third_party/harrisaff.py
ADDED
@@ -0,0 +1,35 @@
import numpy as np
import torch

from dad.types import Detector
import cv2

from dad.utils import get_best_device


class HarrisAff(Detector):
    def __init__(self):
        super().__init__()
        self.detector = cv2.xfeatures2d.HarrisLaplaceFeatureDetector_create(
            numOctaves=6, corn_thresh=0.0, DOG_thresh=0.0, maxCorners=8192, num_layers=4
        )

    @property
    def topleft(self):
        return 0.0

    def load_image(self, im_path):
        return {"image": cv2.imread(im_path, cv2.IMREAD_GRAYSCALE)}

    @torch.inference_mode()
    def detect(self, batch, *, num_keypoints, return_dense_probs=False) -> dict[str, torch.Tensor]:
        img = batch["image"]
        H, W = img.shape
        # Detect keypoints
        kps = self.detector.detect(img)
        kps = np.array([kp.pt for kp in kps])[:num_keypoints]
        kps_n = self.to_normalized_coords(torch.from_numpy(kps), H, W)[None]
        detections = {"keypoints": kps_n.to(get_best_device()).float(), "keypoint_probs": None}
        if return_dense_probs:
            detections["dense_probs"] = None
        return detections
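For reference, a short sketch of how this detector is driven (requires opencv-contrib-python for cv2.xfeatures2d; "example.jpg" is a placeholder path):

det = HarrisAff()
batch = det.load_image("example.jpg")  # grayscale np.ndarray
dets = det.detect(batch, num_keypoints=4096)
print(dets["keypoints"].shape)         # 1 x N x 2, normalized to [-1, 1]
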
imcui/third_party/dad/dad/detectors/third_party/hesaff.py
ADDED
@@ -0,0 +1,40 @@
from PIL import Image

import torch

from dad.utils import get_best_device
from dad.types import Detector


class HesAff(Detector):
    def __init__(self):
        raise NotImplementedError("Buggy implementation, don't use.")
        super().__init__()
        import pyhesaff

        self.params = pyhesaff.get_hesaff_default_params()

    @property
    def topleft(self):
        return 0.0

    def load_image(self, im_path):
        # pyhesaff doesn't seem to have a decoupled image loading and detection stage,
        # so load_image here is just the identity
        return {"image": im_path}

    def detect(self, batch, *, num_keypoints, return_dense_probs=False):
        import pyhesaff

        im_path = batch["image"]
        W, H = Image.open(im_path).size
        detections = pyhesaff.detect_feats(im_path)[0][:num_keypoints]
        kps = detections[..., :2]
        kps_n = self.to_normalized_coords(torch.from_numpy(kps), H, W)[None]
        result = {
            "keypoints": kps_n.to(get_best_device()).float(),
            "keypoint_probs": None,
        }
        if return_dense_probs:  # truthiness check, matching HarrisAff (was `is not None` on a bool)
            result["dense_probs"] = None
        return result
imcui/third_party/dad/dad/detectors/third_party/lightglue/__init__.py
ADDED
@@ -0,0 +1,9 @@
from .aliked import ALIKED  # noqa
from .aliked import ALIKEDROT as ALIKEDROT  # noqa
from .disk import DISK  # noqa
from .dog_hardnet import DoGHardNet  # noqa
from .lightglue import LightGlue  # noqa
from .sift import SIFT  # noqa
from .superpoint import SuperPoint  # noqa
from .superpoint import ReinforcedFP  # noqa
from .utils import match_pair  # noqa
imcui/third_party/dad/dad/detectors/third_party/lightglue/aliked.py
ADDED
@@ -0,0 +1,770 @@
# BSD 3-Clause License

# Copyright (c) 2022, Zhao Xiaoming
# All rights reserved.

# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:

# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.

# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.

# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.

# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

# Authors:
# Xiaoming Zhao, Xingming Wu, Weihai Chen, Peter C.Y. Chen, Qingsong Xu, and Zhengguo Li
# Code from https://github.com/Shiaoming/ALIKED

from typing import Callable, Optional

import torch
import torch.nn.functional as F
import torchvision
from kornia.color import grayscale_to_rgb
from torch import nn
from torch.nn.modules.utils import _pair
from torchvision.models import resnet

from .utils import Extractor
def get_patches(
    tensor: torch.Tensor, required_corners: torch.Tensor, ps: int
) -> torch.Tensor:
    c, h, w = tensor.shape
    corner = (required_corners - ps / 2 + 1).long()
    corner[:, 0] = corner[:, 0].clamp(min=0, max=w - 1 - ps)
    corner[:, 1] = corner[:, 1].clamp(min=0, max=h - 1 - ps)
    offset = torch.arange(0, ps)

    kw = {"indexing": "ij"} if torch.__version__ >= "1.10" else {}
    x, y = torch.meshgrid(offset, offset, **kw)
    patches = torch.stack((x, y)).permute(2, 1, 0).unsqueeze(2)
    patches = patches.to(corner) + corner[None, None]
    pts = patches.reshape(-1, 2)
    sampled = tensor.permute(1, 2, 0)[tuple(pts.T)[::-1]]
    sampled = sampled.reshape(ps, ps, -1, c)
    assert sampled.shape[:3] == patches.shape[:3]
    return sampled.permute(2, 3, 0, 1)


def simple_nms(scores: torch.Tensor, nms_radius: int):
    """Fast Non-maximum suppression to remove nearby points"""

    zeros = torch.zeros_like(scores)
    max_mask = scores == torch.nn.functional.max_pool2d(
        scores, kernel_size=nms_radius * 2 + 1, stride=1, padding=nms_radius
    )

    for _ in range(2):
        supp_mask = (
            torch.nn.functional.max_pool2d(
                max_mask.float(),
                kernel_size=nms_radius * 2 + 1,
                stride=1,
                padding=nms_radius,
            )
            > 0
        )
        supp_scores = torch.where(supp_mask, zeros, scores)
        new_max_mask = supp_scores == torch.nn.functional.max_pool2d(
            supp_scores, kernel_size=nms_radius * 2 + 1, stride=1, padding=nms_radius
        )
        max_mask = max_mask | (new_max_mask & (~supp_mask))
    return torch.where(max_mask, scores, zeros)
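
A tiny self-contained check of `simple_nms`: the weaker response adjacent to a stronger one is suppressed, while the isolated one survives.

import torch

scores = torch.zeros(1, 1, 8, 8)
scores[0, 0, 2, 2] = 0.9
scores[0, 0, 2, 3] = 0.8  # within nms_radius of the stronger peak -> suppressed
scores[0, 0, 6, 6] = 0.7
kept = simple_nms(scores, nms_radius=2)
print(kept.nonzero()[:, -2:].tolist())  # [[2, 2], [6, 6]]
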
class DKD(nn.Module):
    def __init__(
        self,
        radius: int = 2,
        top_k: int = 0,
        scores_th: float = 0.2,
        n_limit: int = 20000,
    ):
        """
        Args:
            radius: soft detection radius, kernel size is (2 * radius + 1)
            top_k: top_k > 0: return top k keypoints
            scores_th: top_k <= 0 threshold mode:
                scores_th > 0: return keypoints with scores > scores_th
                else: return keypoints with scores > scores.mean()
            n_limit: max number of keypoints in threshold mode
        """
        super().__init__()
        self.radius = radius
        self.top_k = top_k
        self.scores_th = scores_th
        self.n_limit = n_limit
        self.kernel_size = 2 * self.radius + 1
        self.temperature = 0.1  # tuned temperature
        self.unfold = nn.Unfold(kernel_size=self.kernel_size, padding=self.radius)
        # local xy grid
        x = torch.linspace(-self.radius, self.radius, self.kernel_size)
        # (kernel_size*kernel_size) x 2 : (w,h)
        kw = {"indexing": "ij"} if torch.__version__ >= "1.10" else {}
        self.hw_grid = (
            torch.stack(torch.meshgrid([x, x], **kw)).view(2, -1).t()[:, [1, 0]]
        )

    def forward(
        self,
        scores_map: torch.Tensor,
        sub_pixel: bool = True,
        image_size: Optional[torch.Tensor] = None,
    ):
        """
        :param scores_map: Bx1xHxW
        :param sub_pixel: whether to use sub-pixel keypoint detection
        :param image_size: optional Bx2 tensor of valid (w, h) image sizes
        :return: kpts: list[Nx2,...]; kptscores: list[N,...] normalised position: -1~1
        """
        b, c, h, w = scores_map.shape
        scores_nograd = scores_map.detach()
        nms_scores = simple_nms(scores_nograd, self.radius)

        # remove border
        nms_scores[:, :, : self.radius, :] = 0
        nms_scores[:, :, :, : self.radius] = 0
        if image_size is not None:
            for i in range(scores_map.shape[0]):
                w, h = image_size[i].long()
                nms_scores[i, :, h.item() - self.radius :, :] = 0
                nms_scores[i, :, :, w.item() - self.radius :] = 0
        else:
            nms_scores[:, :, -self.radius :, :] = 0
            nms_scores[:, :, :, -self.radius :] = 0

        # detect keypoints without grad
        if self.top_k > 0:
            topk = torch.topk(nms_scores.view(b, -1), self.top_k)
            indices_keypoints = [topk.indices[i] for i in range(b)]  # B x top_k
        else:
            if self.scores_th > 0:
                masks = nms_scores > self.scores_th
                if masks.sum() == 0:
                    th = scores_nograd.reshape(b, -1).mean(dim=1)  # th = self.scores_th
                    masks = nms_scores > th.reshape(b, 1, 1, 1)
            else:
                th = scores_nograd.reshape(b, -1).mean(dim=1)  # th = self.scores_th
                masks = nms_scores > th.reshape(b, 1, 1, 1)
            masks = masks.reshape(b, -1)

            indices_keypoints = []  # list, B x (any size)
            scores_view = scores_nograd.reshape(b, -1)
            for mask, scores in zip(masks, scores_view):
                indices = mask.nonzero()[:, 0]
                if len(indices) > self.n_limit:
                    kpts_sc = scores[indices]
                    sort_idx = kpts_sc.sort(descending=True)[1]
                    sel_idx = sort_idx[: self.n_limit]
                    indices = indices[sel_idx]
                indices_keypoints.append(indices)
        wh = torch.tensor([w - 1, h - 1], device=scores_nograd.device)

        keypoints = []
        scoredispersitys = []
        kptscores = []
        if sub_pixel:
            # detect soft keypoints with grad backpropagation
            patches = self.unfold(scores_map)  # B x (kernel**2) x (H*W)
            self.hw_grid = self.hw_grid.to(scores_map)  # to device
            for b_idx in range(b):
                patch = patches[b_idx].t()  # (H*W) x (kernel**2)
                indices_kpt = indices_keypoints[
                    b_idx
                ]  # one dimension vector, say its size is M
                patch_scores = patch[indices_kpt]  # M x (kernel**2)
                keypoints_xy_nms = torch.stack(
                    [indices_kpt % w, torch.div(indices_kpt, w, rounding_mode="trunc")],
                    dim=1,
                )  # Mx2

                # max is detached to prevent undesired backprop loops in the graph
                max_v = patch_scores.max(dim=1).values.detach()[:, None]
                x_exp = (
                    (patch_scores - max_v) / self.temperature
                ).exp()  # M * (kernel**2), in [0, 1]

                # \frac{ \sum{(i,j) \times \exp(x/T)} }{ \sum{\exp(x/T)} }
                xy_residual = (
                    x_exp @ self.hw_grid / x_exp.sum(dim=1)[:, None]
                )  # Soft-argmax, Mx2

                hw_grid_dist2 = (
                    torch.norm(
                        (self.hw_grid[None, :, :] - xy_residual[:, None, :])
                        / self.radius,
                        dim=-1,
                    )
                    ** 2
                )
                scoredispersity = (x_exp * hw_grid_dist2).sum(dim=1) / x_exp.sum(dim=1)

                # compute result keypoints
                keypoints_xy = keypoints_xy_nms + xy_residual
                keypoints_xy = keypoints_xy / wh * 2 - 1  # (w,h) -> (-1~1,-1~1)

                kptscore = torch.nn.functional.grid_sample(
                    scores_map[b_idx].unsqueeze(0),
                    keypoints_xy.view(1, 1, -1, 2),
                    mode="bilinear",
                    align_corners=True,
                )[0, 0, 0, :]  # CxN

                keypoints.append(keypoints_xy)
                scoredispersitys.append(scoredispersity)
                kptscores.append(kptscore)
        else:
            for b_idx in range(b):
                indices_kpt = indices_keypoints[
                    b_idx
                ]  # one dimension vector, say its size is M
                # To avoid warning: UserWarning: __floordiv__ is deprecated
                keypoints_xy_nms = torch.stack(
                    [indices_kpt % w, torch.div(indices_kpt, w, rounding_mode="trunc")],
                    dim=1,
                )  # Mx2
                keypoints_xy = keypoints_xy_nms / wh * 2 - 1  # (w,h) -> (-1~1,-1~1)
                kptscore = torch.nn.functional.grid_sample(
                    scores_map[b_idx].unsqueeze(0),
                    keypoints_xy.view(1, 1, -1, 2),
                    mode="bilinear",
                    align_corners=True,
                )[0, 0, 0, :]  # CxN
                keypoints.append(keypoints_xy)
                scoredispersitys.append(kptscore)  # for jit.script compatibility
                kptscores.append(kptscore)

        return keypoints, scoredispersitys, kptscores
class InputPadder(object):
    """Pads images such that dimensions are divisible by 8 (or `divis_by`)"""

    def __init__(self, h: int, w: int, divis_by: int = 8):
        self.ht = h
        self.wd = w
        pad_ht = (((self.ht // divis_by) + 1) * divis_by - self.ht) % divis_by
        pad_wd = (((self.wd // divis_by) + 1) * divis_by - self.wd) % divis_by
        self._pad = [
            pad_wd // 2,
            pad_wd - pad_wd // 2,
            pad_ht // 2,
            pad_ht - pad_ht // 2,
        ]

    def pad(self, x: torch.Tensor):
        assert x.ndim == 4
        return F.pad(x, self._pad, mode="replicate")

    def unpad(self, x: torch.Tensor):
        assert x.ndim == 4
        ht = x.shape[-2]
        wd = x.shape[-1]
        c = [self._pad[2], ht - self._pad[3], self._pad[0], wd - self._pad[1]]
        return x[..., c[0] : c[1], c[2] : c[3]]
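
A quick sketch of the pad/unpad round trip used later in `extract_dense_map` (which passes divis_by=32):

import torch

x = torch.randn(1, 3, 250, 333)
padder = InputPadder(250, 333, divis_by=32)
xp = padder.pad(x)
print(xp.shape)                          # torch.Size([1, 3, 256, 352])
assert padder.unpad(xp).shape == x.shape
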
class DeformableConv2d(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=3,
        stride=1,
        padding=1,
        bias=False,
        mask=False,
    ):
        super(DeformableConv2d, self).__init__()

        self.padding = padding
        self.mask = mask

        self.channel_num = (
            3 * kernel_size * kernel_size if mask else 2 * kernel_size * kernel_size
        )
        self.offset_conv = nn.Conv2d(
            in_channels,
            self.channel_num,
            kernel_size=kernel_size,
            stride=stride,
            padding=self.padding,
            bias=True,
        )

        self.regular_conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=self.padding,
            bias=bias,
        )

    def forward(self, x):
        h, w = x.shape[2:]
        max_offset = max(h, w) / 4.0

        out = self.offset_conv(x)
        if self.mask:
            o1, o2, mask = torch.chunk(out, 3, dim=1)
            offset = torch.cat((o1, o2), dim=1)
            mask = torch.sigmoid(mask)
        else:
            offset = out
            mask = None
        offset = offset.clamp(-max_offset, max_offset)
        x = torchvision.ops.deform_conv2d(
            input=x,
            offset=offset,
            weight=self.regular_conv.weight,
            bias=self.regular_conv.bias,
            padding=self.padding,
            mask=mask,
        )
        return x


def get_conv(
    inplanes,
    planes,
    kernel_size=3,
    stride=1,
    padding=1,
    bias=False,
    conv_type="conv",
    mask=False,
):
    if conv_type == "conv":
        conv = nn.Conv2d(
            inplanes,
            planes,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            bias=bias,
        )
    elif conv_type == "dcn":
        conv = DeformableConv2d(
            inplanes,
            planes,
            kernel_size=kernel_size,
            stride=stride,
            padding=_pair(padding),
            bias=bias,
            mask=mask,
        )
    else:
        raise TypeError
    return conv
class ConvBlock(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        gate: Optional[Callable[..., nn.Module]] = None,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
        conv_type: str = "conv",
        mask: bool = False,
    ):
        super().__init__()
        if gate is None:
            self.gate = nn.ReLU(inplace=True)
        else:
            self.gate = gate
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self.conv1 = get_conv(
            in_channels, out_channels, kernel_size=3, conv_type=conv_type, mask=mask
        )
        self.bn1 = norm_layer(out_channels)
        self.conv2 = get_conv(
            out_channels, out_channels, kernel_size=3, conv_type=conv_type, mask=mask
        )
        self.bn2 = norm_layer(out_channels)

    def forward(self, x):
        x = self.gate(self.bn1(self.conv1(x)))  # B x in_channels x H x W
        x = self.gate(self.bn2(self.conv2(x)))  # B x out_channels x H x W
        return x


# modified based on torchvision/models/resnet.py#27->BasicBlock
class ResBlock(nn.Module):
    expansion: int = 1

    def __init__(
        self,
        inplanes: int,
        planes: int,
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
        groups: int = 1,
        base_width: int = 64,
        dilation: int = 1,
        gate: Optional[Callable[..., nn.Module]] = None,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
        conv_type: str = "conv",
        mask: bool = False,
    ) -> None:
        super(ResBlock, self).__init__()
        if gate is None:
            self.gate = nn.ReLU(inplace=True)
        else:
            self.gate = gate
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if groups != 1 or base_width != 64:
            raise ValueError("ResBlock only supports groups=1 and base_width=64")
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in ResBlock")
        # Both self.conv1 and self.downsample layers
        # downsample the input when stride != 1
        self.conv1 = get_conv(
            inplanes, planes, kernel_size=3, conv_type=conv_type, mask=mask
        )
        self.bn1 = norm_layer(planes)
        self.conv2 = get_conv(
            planes, planes, kernel_size=3, conv_type=conv_type, mask=mask
        )
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.gate(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.gate(out)

        return out
class SDDH(nn.Module):
    def __init__(
        self,
        dims: int,
        kernel_size: int = 3,
        n_pos: int = 8,
        gate=nn.ReLU(),
        conv2D=False,
        mask=False,
    ):
        super(SDDH, self).__init__()
        self.kernel_size = kernel_size
        self.n_pos = n_pos
        self.conv2D = conv2D
        self.mask = mask

        self.get_patches_func = get_patches

        # estimate offsets
        self.channel_num = 3 * n_pos if mask else 2 * n_pos
        self.offset_conv = nn.Sequential(
            nn.Conv2d(
                dims,
                self.channel_num,
                kernel_size=kernel_size,
                stride=1,
                padding=0,
                bias=True,
            ),
            gate,
            nn.Conv2d(
                self.channel_num,
                self.channel_num,
                kernel_size=1,
                stride=1,
                padding=0,
                bias=True,
            ),
        )

        # sampled feature conv
        self.sf_conv = nn.Conv2d(
            dims, dims, kernel_size=1, stride=1, padding=0, bias=False
        )

        # convM
        if not conv2D:
            # deformable desc weights
            agg_weights = torch.nn.Parameter(torch.rand(n_pos, dims, dims))
            self.register_parameter("agg_weights", agg_weights)
        else:
            self.convM = nn.Conv2d(
                dims * n_pos, dims, kernel_size=1, stride=1, padding=0, bias=False
            )

    def forward(self, x, keypoints):
        # x: [B,C,H,W]
        # keypoints: list, [[N_kpts,2], ...] (w,h)
        b, c, h, w = x.shape
        wh = torch.tensor([[w - 1, h - 1]], device=x.device)
        max_offset = max(h, w) / 4.0

        offsets = []
        descriptors = []
        # get offsets for each keypoint
        for ib in range(b):
            xi, kptsi = x[ib], keypoints[ib]
            kptsi_wh = (kptsi / 2 + 0.5) * wh
            N_kpts = len(kptsi)

            if self.kernel_size > 1:
                patch = self.get_patches_func(
                    xi, kptsi_wh.long(), self.kernel_size
                )  # [N_kpts, C, K, K]
            else:
                kptsi_wh_long = kptsi_wh.long()
                patch = (
                    xi[:, kptsi_wh_long[:, 1], kptsi_wh_long[:, 0]]
                    .permute(1, 0)
                    .reshape(N_kpts, c, 1, 1)
                )

            offset = self.offset_conv(patch).clamp(
                -max_offset, max_offset
            )  # [N_kpts, 2*n_pos, 1, 1]
            if self.mask:
                offset = (
                    offset[:, :, 0, 0].view(N_kpts, 3, self.n_pos).permute(0, 2, 1)
                )  # [N_kpts, n_pos, 3]
                offset = offset[:, :, :-1]  # [N_kpts, n_pos, 2]
                mask_weight = torch.sigmoid(offset[:, :, -1])  # [N_kpts, n_pos]
            else:
                offset = (
                    offset[:, :, 0, 0].view(N_kpts, 2, self.n_pos).permute(0, 2, 1)
                )  # [N_kpts, n_pos, 2]
            offsets.append(offset)  # for visualization

            # get sample positions
            pos = kptsi_wh.unsqueeze(1) + offset  # [N_kpts, n_pos, 2]
            pos = 2.0 * pos / wh[None] - 1
            pos = pos.reshape(1, N_kpts * self.n_pos, 1, 2)

            # sample features
            features = F.grid_sample(
                xi.unsqueeze(0), pos, mode="bilinear", align_corners=True
            )  # [1,C,(N_kpts*n_pos),1]
            features = features.reshape(c, N_kpts, self.n_pos, 1).permute(
                1, 0, 2, 3
            )  # [N_kpts, C, n_pos, 1]
            if self.mask:
                features = torch.einsum("ncpo,np->ncpo", features, mask_weight)

            features = torch.selu_(self.sf_conv(features)).squeeze(
                -1
            )  # [N_kpts, C, n_pos]
            # convM
            if not self.conv2D:
                descs = torch.einsum(
                    "ncp,pcd->nd", features, self.agg_weights
                )  # [N_kpts, C]
            else:
                features = features.reshape(N_kpts, -1)[
                    :, :, None, None
                ]  # [N_kpts, C*n_pos, 1, 1]
                descs = self.convM(features).squeeze()  # [N_kpts, C]

            # normalize
            descs = F.normalize(descs, p=2.0, dim=1)
            descriptors.append(descs)

        return descriptors, offsets
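
A minimal shape check for the deformable descriptor head (random weights; keypoints are given in the normalized (-1, 1) (x, y) coordinates the head expects):

import torch

sddh = SDDH(dims=128, kernel_size=3, n_pos=16)
x = torch.randn(1, 128, 60, 80)
kpts = [torch.rand(100, 2) * 2 - 1]  # one image, 100 normalized keypoints
descs, offsets = sddh(x, kpts)
print(descs[0].shape)    # torch.Size([100, 128]), L2-normalized descriptors
print(offsets[0].shape)  # torch.Size([100, 16, 2]), sampling offsets per keypoint
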
class ALIKED(Extractor):
    default_conf = {
        "model_name": "aliked-n16",
        "max_num_keypoints": -1,
        "detection_threshold": 0.2,
        "nms_radius": 2,
    }

    checkpoint_url = "https://github.com/Shiaoming/ALIKED/raw/main/models/{}.pth"

    n_limit_max = 20000

    # c1, c2, c3, c4, dim, K, M
    cfgs = {
        "aliked-t16": [8, 16, 32, 64, 64, 3, 16],
        "aliked-n16": [16, 32, 64, 128, 128, 3, 16],
        "aliked-n16rot": [16, 32, 64, 128, 128, 3, 16],
        "aliked-n32": [16, 32, 64, 128, 128, 3, 32],
    }
    preprocess_conf = {
        "resize": 1024,
    }

    required_data_keys = ["image"]

    def __init__(self, **conf):
        super().__init__(**conf)  # Update with default configuration.
        conf = self.conf
        c1, c2, c3, c4, dim, K, M = self.cfgs[conf.model_name]
        conv_types = ["conv", "conv", "dcn", "dcn"]
        conv2D = False
        mask = False

        # build model
        self.pool2 = nn.AvgPool2d(kernel_size=2, stride=2)
        self.pool4 = nn.AvgPool2d(kernel_size=4, stride=4)
        self.norm = nn.BatchNorm2d
        self.gate = nn.SELU(inplace=True)
        self.block1 = ConvBlock(3, c1, self.gate, self.norm, conv_type=conv_types[0])
        self.block2 = self.get_resblock(c1, c2, conv_types[1], mask)
        self.block3 = self.get_resblock(c2, c3, conv_types[2], mask)
        self.block4 = self.get_resblock(c3, c4, conv_types[3], mask)

        self.conv1 = resnet.conv1x1(c1, dim // 4)
        self.conv2 = resnet.conv1x1(c2, dim // 4)
        self.conv3 = resnet.conv1x1(c3, dim // 4)
        self.conv4 = resnet.conv1x1(dim, dim // 4)
        self.upsample2 = nn.Upsample(
            scale_factor=2, mode="bilinear", align_corners=True
        )
        self.upsample4 = nn.Upsample(
            scale_factor=4, mode="bilinear", align_corners=True
        )
        self.upsample8 = nn.Upsample(
            scale_factor=8, mode="bilinear", align_corners=True
        )
        self.upsample32 = nn.Upsample(
            scale_factor=32, mode="bilinear", align_corners=True
        )
        self.score_head = nn.Sequential(
            resnet.conv1x1(dim, 8),
            self.gate,
            resnet.conv3x3(8, 4),
            self.gate,
            resnet.conv3x3(4, 4),
            self.gate,
            resnet.conv3x3(4, 1),
        )
        self.desc_head = SDDH(dim, K, M, gate=self.gate, conv2D=conv2D, mask=mask)
        self.dkd = DKD(
            radius=conf.nms_radius,
            top_k=-1 if conf.detection_threshold > 0 else conf.max_num_keypoints,
            scores_th=conf.detection_threshold,
            n_limit=conf.max_num_keypoints
            if conf.max_num_keypoints > 0
            else self.n_limit_max,
        )

        state_dict = torch.hub.load_state_dict_from_url(
            self.checkpoint_url.format(conf.model_name), map_location="cpu"
        )
        self.load_state_dict(state_dict, strict=True)

    def get_resblock(self, c_in, c_out, conv_type, mask):
        return ResBlock(
            c_in,
            c_out,
            1,
            nn.Conv2d(c_in, c_out, 1),
            gate=self.gate,
            norm_layer=self.norm,
            conv_type=conv_type,
            mask=mask,
        )

    def extract_dense_map(self, image):
        # pad the image so that both dimensions are divisible by 2**5 = 32
        div_by = 2**5
        padder = InputPadder(image.shape[-2], image.shape[-1], div_by)
        image = padder.pad(image)

        # ================================== feature encoder
        x1 = self.block1(image)  # B x c1 x H x W
        x2 = self.pool2(x1)
        x2 = self.block2(x2)  # B x c2 x H/2 x W/2
        x3 = self.pool4(x2)
        x3 = self.block3(x3)  # B x c3 x H/8 x W/8
        x4 = self.pool4(x3)
        x4 = self.block4(x4)  # B x dim x H/32 x W/32
        # ================================== feature aggregation
        x1 = self.gate(self.conv1(x1))  # B x dim//4 x H x W
        x2 = self.gate(self.conv2(x2))  # B x dim//4 x H//2 x W//2
        x3 = self.gate(self.conv3(x3))  # B x dim//4 x H//8 x W//8
        x4 = self.gate(self.conv4(x4))  # B x dim//4 x H//32 x W//32
        x2_up = self.upsample2(x2)  # B x dim//4 x H x W
        x3_up = self.upsample8(x3)  # B x dim//4 x H x W
        x4_up = self.upsample32(x4)  # B x dim//4 x H x W
        x1234 = torch.cat([x1, x2_up, x3_up, x4_up], dim=1)
        # ================================== score head
        score_map = torch.sigmoid(self.score_head(x1234))
        feature_map = torch.nn.functional.normalize(x1234, p=2, dim=1)

        # Unpads images
        feature_map = padder.unpad(feature_map)
        score_map = padder.unpad(score_map)

        return feature_map, score_map

    def forward(self, data: dict) -> dict:
        # need to set here unfortunately
        self.dkd.n_limit = (
            self.conf.max_num_keypoints
            if self.conf.max_num_keypoints > 0
            else self.n_limit_max
        )
        image = data["image"]
        if image.shape[1] == 1:
            image = grayscale_to_rgb(image)
        feature_map, score_map = self.extract_dense_map(image)
        keypoints, kptscores, scoredispersitys = self.dkd(
            score_map, image_size=data.get("image_size")
        )
        # descriptors, offsets = self.desc_head(feature_map, keypoints)

        _, _, h, w = image.shape
        wh = torch.tensor([w - 1, h - 1], device=image.device)
        # no padding required
        # we can set detection_threshold=-1 and conf.max_num_keypoints > 0
        return {
            "keypoints": wh * (torch.stack(keypoints) + 1) / 2.0,  # B x N x 2
            # "descriptors": torch.stack(descriptors),  # B x N x D
            "keypoint_scores": torch.stack(kptscores),  # B x N
            "scoremap": score_map,  # B x 1 x H x W
        }


class ALIKEDROT(ALIKED):
    default_conf = {
        "model_name": "aliked-n16rot",
        "max_num_keypoints": -1,
        "detection_threshold": 0.2,
        "nms_radius": 2,
    }
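A short usage sketch: instantiating ALIKED downloads the checkpoint via torch.hub, and with `detection_threshold=-1` plus a positive `max_num_keypoints` the top-k branch of DKD returns a fixed number of keypoints per image, so the batched stack is well defined:

import torch

extractor = ALIKED(max_num_keypoints=1024, detection_threshold=-1).eval()
img = torch.rand(1, 3, 480, 640)
with torch.no_grad():
    pred = extractor({"image": img})
print(pred["keypoints"].shape)  # torch.Size([1, 1024, 2]), pixel coordinates
print(pred["scoremap"].shape)   # torch.Size([1, 1, 480, 640])
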
imcui/third_party/dad/dad/detectors/third_party/lightglue/disk.py
ADDED
@@ -0,0 +1,48 @@
import kornia
import torch

from .utils import Extractor


class DISK(Extractor):
    default_conf = {
        "weights": "depth",
        "max_num_keypoints": None,
        "desc_dim": 128,
        "nms_window_size": 5,
        "detection_threshold": 0.0,
        "pad_if_not_divisible": True,
    }

    preprocess_conf = {
        "resize": 1024,
        "grayscale": False,
    }

    required_data_keys = ["image"]

    def __init__(self, **conf) -> None:
        super().__init__(**conf)  # Update with default configuration.
        self.model = kornia.feature.DISK.from_pretrained(self.conf.weights)

    def forward(self, data: dict) -> dict:
        """Compute keypoints, scores, descriptors for image"""
        for key in self.required_data_keys:
            assert key in data, f"Missing key {key} in data"
        image = data["image"]
        if image.shape[1] == 1:
            image = kornia.color.grayscale_to_rgb(image)
        features = self.model(
            image,
            n=self.conf.max_num_keypoints,
            window_size=self.conf.nms_window_size,
            score_threshold=self.conf.detection_threshold,
            pad_if_not_divisible=self.conf.pad_if_not_divisible,
        )
        keypoints = [f.keypoints for f in features]

        keypoints = torch.stack(keypoints, 0)

        return {
            "keypoints": keypoints.to(image).contiguous(),
        }
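The kornia-backed DISK wrapper follows the same calling convention (weights are fetched by kornia on first use); note that this vendored copy returns keypoints only, presumably because only detections are consumed here:

import torch

disk = DISK(max_num_keypoints=2048).eval()
img = torch.rand(1, 3, 480, 640)
with torch.no_grad():
    pred = disk({"image": img})
print(pred["keypoints"].shape)  # 1 x N x 2 pixel coordinates (N <= 2048)
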
imcui/third_party/dad/dad/detectors/third_party/lightglue/dog_hardnet.py
ADDED
@@ -0,0 +1,41 @@
import torch
from kornia.color import rgb_to_grayscale
from kornia.feature import HardNet, LAFDescriptor, laf_from_center_scale_ori

from .sift import SIFT


class DoGHardNet(SIFT):
    required_data_keys = ["image"]

    def __init__(self, **conf):
        super().__init__(**conf)
        self.laf_desc = LAFDescriptor(HardNet(True)).eval()

    def forward(self, data: dict) -> dict:
        image = data["image"]
        if image.shape[1] == 3:
            image = rgb_to_grayscale(image)
        device = image.device
        self.laf_desc = self.laf_desc.to(device)
        self.laf_desc.descriptor = self.laf_desc.descriptor.eval()
        pred = []
        if "image_size" in data.keys():
            im_size = data.get("image_size").long()
        else:
            im_size = None
        for k in range(len(image)):
            img = image[k]
            if im_size is not None:
                w, h = data["image_size"][k]
                img = img[:, : h.to(torch.int32), : w.to(torch.int32)]
            p = self.extract_single_image(img)
            lafs = laf_from_center_scale_ori(
                p["keypoints"].reshape(1, -1, 2),
                6.0 * p["scales"].reshape(1, -1, 1, 1),
                torch.rad2deg(p["oris"]).reshape(1, -1, 1),
            ).to(device)
            p["descriptors"] = self.laf_desc(img[None], lafs).reshape(-1, 128)
            pred.append(p)
        pred = {k: torch.stack([p[k] for p in pred], 0).to(device) for k in pred[0]}
        return pred
imcui/third_party/dad/dad/detectors/third_party/lightglue/lightglue.py
ADDED
@@ -0,0 +1,655 @@
import warnings
from pathlib import Path
from types import SimpleNamespace
from typing import Callable, List, Optional, Tuple

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn

try:
    from flash_attn.modules.mha import FlashCrossAttention
except ModuleNotFoundError:
    FlashCrossAttention = None

if FlashCrossAttention or hasattr(F, "scaled_dot_product_attention"):
    FLASH_AVAILABLE = True
else:
    FLASH_AVAILABLE = False

torch.backends.cudnn.deterministic = True


@torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float32)
def normalize_keypoints(
    kpts: torch.Tensor, size: Optional[torch.Tensor] = None
) -> torch.Tensor:
    if size is None:
        size = 1 + kpts.max(-2).values - kpts.min(-2).values
    elif not isinstance(size, torch.Tensor):
        size = torch.tensor(size, device=kpts.device, dtype=kpts.dtype)
    size = size.to(kpts)
    shift = size / 2
    scale = size.max(-1).values / 2
    kpts = (kpts - shift[..., None, :]) / scale[..., None, None]
    return kpts


def pad_to_length(x: torch.Tensor, length: int) -> Tuple[torch.Tensor]:
    if length <= x.shape[-2]:
        return x, torch.ones_like(x[..., :1], dtype=torch.bool)
    pad = torch.ones(
        *x.shape[:-2], length - x.shape[-2], x.shape[-1], device=x.device, dtype=x.dtype
    )
    y = torch.cat([x, pad], dim=-2)
    mask = torch.zeros(*y.shape[:-1], 1, dtype=torch.bool, device=x.device)
    mask[..., : x.shape[-2], :] = True
    return y, mask


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    x = x.unflatten(-1, (-1, 2))
    x1, x2 = x.unbind(dim=-1)
    return torch.stack((-x2, x1), dim=-1).flatten(start_dim=-2)


def apply_cached_rotary_emb(freqs: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    return (t * freqs[0]) + (rotate_half(t) * freqs[1])
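
A worked example of `normalize_keypoints` with an explicit (w, h) size: points are shifted by half the size and divided by half the larger side, so x spans [-1, 1] while y covers a sub-interval for landscape images:

import torch

kpts = torch.tensor([[[0.0, 0.0], [639.0, 479.0]]])  # corners of a 640 x 480 image
size = torch.tensor([[640.0, 480.0]])                # (w, h)
print(normalize_keypoints(kpts, size))
# tensor([[[-1.0000, -0.7500],
#          [ 0.9969,  0.7469]]])
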
class LearnableFourierPositionalEncoding(nn.Module):
    def __init__(self, M: int, dim: int, F_dim: int = None, gamma: float = 1.0) -> None:
        super().__init__()
        F_dim = F_dim if F_dim is not None else dim
        self.gamma = gamma
        self.Wr = nn.Linear(M, F_dim // 2, bias=False)
        nn.init.normal_(self.Wr.weight.data, mean=0, std=self.gamma**-2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """encode position vector"""
        projected = self.Wr(x)
        cosines, sines = torch.cos(projected), torch.sin(projected)
        emb = torch.stack([cosines, sines], 0).unsqueeze(-3)
        return emb.repeat_interleave(2, dim=-1)


class TokenConfidence(nn.Module):
    def __init__(self, dim: int) -> None:
        super().__init__()
        self.token = nn.Sequential(nn.Linear(dim, 1), nn.Sigmoid())

    def forward(self, desc0: torch.Tensor, desc1: torch.Tensor):
        """get confidence tokens"""
        return (
            self.token(desc0.detach()).squeeze(-1),
            self.token(desc1.detach()).squeeze(-1),
        )
class Attention(nn.Module):
    def __init__(self, allow_flash: bool) -> None:
        super().__init__()
        if allow_flash and not FLASH_AVAILABLE:
            warnings.warn(
                "FlashAttention is not available. For optimal speed, "
                "consider installing torch >= 2.0 or flash-attn.",
                stacklevel=2,
            )
        self.enable_flash = allow_flash and FLASH_AVAILABLE
        self.has_sdp = hasattr(F, "scaled_dot_product_attention")
        if allow_flash and FlashCrossAttention:
            self.flash_ = FlashCrossAttention()
        if self.has_sdp:
            torch.backends.cuda.enable_flash_sdp(allow_flash)

    def forward(self, q, k, v, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        if q.shape[-2] == 0 or k.shape[-2] == 0:
            return q.new_zeros((*q.shape[:-1], v.shape[-1]))
        if self.enable_flash and q.device.type == "cuda":
            # use torch 2.0 scaled_dot_product_attention with flash
            if self.has_sdp:
                args = [x.half().contiguous() for x in [q, k, v]]
                v = F.scaled_dot_product_attention(*args, attn_mask=mask).to(q.dtype)
                return v if mask is None else v.nan_to_num()
            else:
                assert mask is None
                q, k, v = [x.transpose(-2, -3).contiguous() for x in [q, k, v]]
                m = self.flash_(q.half(), torch.stack([k, v], 2).half())
                return m.transpose(-2, -3).to(q.dtype).clone()
        elif self.has_sdp:
            args = [x.contiguous() for x in [q, k, v]]
            v = F.scaled_dot_product_attention(*args, attn_mask=mask)
            return v if mask is None else v.nan_to_num()
        else:
            s = q.shape[-1] ** -0.5
            sim = torch.einsum("...id,...jd->...ij", q, k) * s
            if mask is not None:
                # masked_fill is out-of-place: keep the result (the original discarded it)
                sim = sim.masked_fill(~mask, -float("inf"))
            attn = F.softmax(sim, -1)
            return torch.einsum("...ij,...jd->...id", attn, v)
class SelfBlock(nn.Module):
    def __init__(
        self, embed_dim: int, num_heads: int, flash: bool = False, bias: bool = True
    ) -> None:
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        assert self.embed_dim % num_heads == 0
        self.head_dim = self.embed_dim // num_heads
        self.Wqkv = nn.Linear(embed_dim, 3 * embed_dim, bias=bias)
        self.inner_attn = Attention(flash)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.ffn = nn.Sequential(
            nn.Linear(2 * embed_dim, 2 * embed_dim),
            nn.LayerNorm(2 * embed_dim, elementwise_affine=True),
            nn.GELU(),
            nn.Linear(2 * embed_dim, embed_dim),
        )

    def forward(
        self,
        x: torch.Tensor,
        encoding: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        qkv = self.Wqkv(x)
        qkv = qkv.unflatten(-1, (self.num_heads, -1, 3)).transpose(1, 2)
        q, k, v = qkv[..., 0], qkv[..., 1], qkv[..., 2]
        q = apply_cached_rotary_emb(encoding, q)
        k = apply_cached_rotary_emb(encoding, k)
        context = self.inner_attn(q, k, v, mask=mask)
        message = self.out_proj(context.transpose(1, 2).flatten(start_dim=-2))
        return x + self.ffn(torch.cat([x, message], -1))


class CrossBlock(nn.Module):
    def __init__(
        self, embed_dim: int, num_heads: int, flash: bool = False, bias: bool = True
    ) -> None:
        super().__init__()
        self.heads = num_heads
        dim_head = embed_dim // num_heads
        self.scale = dim_head**-0.5
        inner_dim = dim_head * num_heads
        self.to_qk = nn.Linear(embed_dim, inner_dim, bias=bias)
        self.to_v = nn.Linear(embed_dim, inner_dim, bias=bias)
        self.to_out = nn.Linear(inner_dim, embed_dim, bias=bias)
        self.ffn = nn.Sequential(
            nn.Linear(2 * embed_dim, 2 * embed_dim),
            nn.LayerNorm(2 * embed_dim, elementwise_affine=True),
            nn.GELU(),
            nn.Linear(2 * embed_dim, embed_dim),
        )
        if flash and FLASH_AVAILABLE:
            self.flash = Attention(True)
        else:
            self.flash = None

    def map_(self, func: Callable, x0: torch.Tensor, x1: torch.Tensor):
        return func(x0), func(x1)

    def forward(
        self, x0: torch.Tensor, x1: torch.Tensor, mask: Optional[torch.Tensor] = None
    ) -> List[torch.Tensor]:
        qk0, qk1 = self.map_(self.to_qk, x0, x1)
        v0, v1 = self.map_(self.to_v, x0, x1)
        qk0, qk1, v0, v1 = map(
            lambda t: t.unflatten(-1, (self.heads, -1)).transpose(1, 2),
            (qk0, qk1, v0, v1),
        )
        if self.flash is not None and qk0.device.type == "cuda":
            m0 = self.flash(qk0, qk1, v1, mask)
            m1 = self.flash(
                qk1, qk0, v0, mask.transpose(-1, -2) if mask is not None else None
            )
        else:
            qk0, qk1 = qk0 * self.scale**0.5, qk1 * self.scale**0.5
            sim = torch.einsum("bhid, bhjd -> bhij", qk0, qk1)
            if mask is not None:
                sim = sim.masked_fill(~mask, -float("inf"))
            attn01 = F.softmax(sim, dim=-1)
            attn10 = F.softmax(sim.transpose(-2, -1).contiguous(), dim=-1)
            m0 = torch.einsum("bhij, bhjd -> bhid", attn01, v1)
            m1 = torch.einsum("bhji, bhjd -> bhid", attn10.transpose(-2, -1), v0)
            if mask is not None:
                m0, m1 = m0.nan_to_num(), m1.nan_to_num()
        m0, m1 = self.map_(lambda t: t.transpose(1, 2).flatten(start_dim=-2), m0, m1)
        m0, m1 = self.map_(self.to_out, m0, m1)
        x0 = x0 + self.ffn(torch.cat([x0, m0], -1))
        x1 = x1 + self.ffn(torch.cat([x1, m1], -1))
        return x0, x1
class TransformerLayer(nn.Module):
|
227 |
+
def __init__(self, *args, **kwargs):
|
228 |
+
super().__init__()
|
229 |
+
self.self_attn = SelfBlock(*args, **kwargs)
|
230 |
+
self.cross_attn = CrossBlock(*args, **kwargs)
|
231 |
+
|
232 |
+
def forward(
|
233 |
+
self,
|
234 |
+
desc0,
|
235 |
+
desc1,
|
236 |
+
encoding0,
|
237 |
+
encoding1,
|
238 |
+
mask0: Optional[torch.Tensor] = None,
|
239 |
+
mask1: Optional[torch.Tensor] = None,
|
240 |
+
):
|
241 |
+
if mask0 is not None and mask1 is not None:
|
242 |
+
return self.masked_forward(desc0, desc1, encoding0, encoding1, mask0, mask1)
|
243 |
+
else:
|
244 |
+
desc0 = self.self_attn(desc0, encoding0)
|
245 |
+
desc1 = self.self_attn(desc1, encoding1)
|
246 |
+
return self.cross_attn(desc0, desc1)
|
247 |
+
|
248 |
+
# This part is compiled and allows padding inputs
|
249 |
+
def masked_forward(self, desc0, desc1, encoding0, encoding1, mask0, mask1):
|
250 |
+
mask = mask0 & mask1.transpose(-1, -2)
|
251 |
+
mask0 = mask0 & mask0.transpose(-1, -2)
|
252 |
+
mask1 = mask1 & mask1.transpose(-1, -2)
|
253 |
+
desc0 = self.self_attn(desc0, encoding0, mask0)
|
254 |
+
desc1 = self.self_attn(desc1, encoding1, mask1)
|
255 |
+
return self.cross_attn(desc0, desc1, mask)
|
256 |
+
|
257 |
+
|
258 |
+
def sigmoid_log_double_softmax(
|
259 |
+
sim: torch.Tensor, z0: torch.Tensor, z1: torch.Tensor
|
260 |
+
) -> torch.Tensor:
|
261 |
+
"""create the log assignment matrix from logits and similarity"""
|
262 |
+
b, m, n = sim.shape
|
263 |
+
certainties = F.logsigmoid(z0) + F.logsigmoid(z1).transpose(1, 2)
|
264 |
+
scores0 = F.log_softmax(sim, 2)
|
265 |
+
scores1 = F.log_softmax(sim.transpose(-1, -2).contiguous(), 2).transpose(-1, -2)
|
266 |
+
scores = sim.new_full((b, m + 1, n + 1), 0)
|
267 |
+
scores[:, :m, :n] = scores0 + scores1 + certainties
|
268 |
+
scores[:, :-1, -1] = F.logsigmoid(-z0.squeeze(-1))
|
269 |
+
scores[:, -1, :-1] = F.logsigmoid(-z1.squeeze(-1))
|
270 |
+
return scores
|
271 |
+
|
272 |
+
|
273 |
+
class MatchAssignment(nn.Module):
|
274 |
+
def __init__(self, dim: int) -> None:
|
275 |
+
super().__init__()
|
276 |
+
self.dim = dim
|
277 |
+
self.matchability = nn.Linear(dim, 1, bias=True)
|
278 |
+
self.final_proj = nn.Linear(dim, dim, bias=True)
|
279 |
+
|
280 |
+
def forward(self, desc0: torch.Tensor, desc1: torch.Tensor):
|
281 |
+
"""build assignment matrix from descriptors"""
|
282 |
+
mdesc0, mdesc1 = self.final_proj(desc0), self.final_proj(desc1)
|
283 |
+
_, _, d = mdesc0.shape
|
284 |
+
mdesc0, mdesc1 = mdesc0 / d**0.25, mdesc1 / d**0.25
|
285 |
+
sim = torch.einsum("bmd,bnd->bmn", mdesc0, mdesc1)
|
286 |
+
z0 = self.matchability(desc0)
|
287 |
+
z1 = self.matchability(desc1)
|
288 |
+
scores = sigmoid_log_double_softmax(sim, z0, z1)
|
289 |
+
return scores, sim
|
290 |
+
|
291 |
+
def get_matchability(self, desc: torch.Tensor):
|
292 |
+
return torch.sigmoid(self.matchability(desc)).squeeze(-1)
|
293 |
+
|
294 |
+
|
295 |
+
def filter_matches(scores: torch.Tensor, th: float):
|
296 |
+
"""obtain matches from a log assignment matrix [Bx M+1 x N+1]"""
|
297 |
+
max0, max1 = scores[:, :-1, :-1].max(2), scores[:, :-1, :-1].max(1)
|
298 |
+
m0, m1 = max0.indices, max1.indices
|
299 |
+
indices0 = torch.arange(m0.shape[1], device=m0.device)[None]
|
300 |
+
indices1 = torch.arange(m1.shape[1], device=m1.device)[None]
|
301 |
+
mutual0 = indices0 == m1.gather(1, m0)
|
302 |
+
mutual1 = indices1 == m0.gather(1, m1)
|
303 |
+
max0_exp = max0.values.exp()
|
304 |
+
zero = max0_exp.new_tensor(0)
|
305 |
+
mscores0 = torch.where(mutual0, max0_exp, zero)
|
306 |
+
mscores1 = torch.where(mutual1, mscores0.gather(1, m1), zero)
|
307 |
+
valid0 = mutual0 & (mscores0 > th)
|
308 |
+
valid1 = mutual1 & valid0.gather(1, m1)
|
309 |
+
m0 = torch.where(valid0, m0, -1)
|
310 |
+
m1 = torch.where(valid1, m1, -1)
|
311 |
+
return m0, m1, mscores0, mscores1
|
312 |
+
|
313 |
+
|
314 |
+
class LightGlue(nn.Module):
|
315 |
+
default_conf = {
|
316 |
+
"name": "lightglue", # just for interfacing
|
317 |
+
"input_dim": 256, # input descriptor dimension (autoselected from weights)
|
318 |
+
"descriptor_dim": 256,
|
319 |
+
"add_scale_ori": False,
|
320 |
+
"n_layers": 9,
|
321 |
+
"num_heads": 4,
|
322 |
+
"flash": True, # enable FlashAttention if available.
|
323 |
+
"mp": False, # enable mixed precision
|
324 |
+
"depth_confidence": 0.95, # early stopping, disable with -1
|
325 |
+
"width_confidence": 0.99, # point pruning, disable with -1
|
326 |
+
"filter_threshold": 0.1, # match threshold
|
327 |
+
"weights": None,
|
328 |
+
}
|
329 |
+
|
330 |
+
# Point pruning involves an overhead (gather).
|
331 |
+
# Therefore, we only activate it if there are enough keypoints.
|
332 |
+
pruning_keypoint_thresholds = {
|
333 |
+
"cpu": -1,
|
334 |
+
"mps": -1,
|
335 |
+
"cuda": 1024,
|
336 |
+
"flash": 1536,
|
337 |
+
}
|
338 |
+
|
339 |
+
required_data_keys = ["image0", "image1"]
|
340 |
+
|
341 |
+
version = "v0.1_arxiv"
|
342 |
+
url = "https://github.com/cvg/LightGlue/releases/download/{}/{}_lightglue.pth"
|
343 |
+
|
344 |
+
features = {
|
345 |
+
"superpoint": {
|
346 |
+
"weights": "superpoint_lightglue",
|
347 |
+
"input_dim": 256,
|
348 |
+
},
|
349 |
+
"disk": {
|
350 |
+
"weights": "disk_lightglue",
|
351 |
+
"input_dim": 128,
|
352 |
+
},
|
353 |
+
"aliked": {
|
354 |
+
"weights": "aliked_lightglue",
|
355 |
+
"input_dim": 128,
|
356 |
+
},
|
357 |
+
"sift": {
|
358 |
+
"weights": "sift_lightglue",
|
359 |
+
"input_dim": 128,
|
360 |
+
"add_scale_ori": True,
|
361 |
+
},
|
362 |
+
"doghardnet": {
|
363 |
+
"weights": "doghardnet_lightglue",
|
364 |
+
"input_dim": 128,
|
365 |
+
"add_scale_ori": True,
|
366 |
+
},
|
367 |
+
}
|
368 |
+
|
369 |
+
def __init__(self, features="superpoint", **conf) -> None:
|
370 |
+
super().__init__()
|
371 |
+
self.conf = conf = SimpleNamespace(**{**self.default_conf, **conf})
|
372 |
+
if features is not None:
|
373 |
+
if features not in self.features:
|
374 |
+
raise ValueError(
|
375 |
+
f"Unsupported features: {features} not in "
|
376 |
+
f"{{{','.join(self.features)}}}"
|
377 |
+
)
|
378 |
+
for k, v in self.features[features].items():
|
379 |
+
setattr(conf, k, v)
|
380 |
+
|
381 |
+
if conf.input_dim != conf.descriptor_dim:
|
382 |
+
self.input_proj = nn.Linear(conf.input_dim, conf.descriptor_dim, bias=True)
|
383 |
+
else:
|
384 |
+
self.input_proj = nn.Identity()
|
385 |
+
|
386 |
+
head_dim = conf.descriptor_dim // conf.num_heads
|
387 |
+
self.posenc = LearnableFourierPositionalEncoding(
|
388 |
+
2 + 2 * self.conf.add_scale_ori, head_dim, head_dim
|
389 |
+
)
|
390 |
+
|
391 |
+
h, n, d = conf.num_heads, conf.n_layers, conf.descriptor_dim
|
392 |
+
|
393 |
+
self.transformers = nn.ModuleList(
|
394 |
+
[TransformerLayer(d, h, conf.flash) for _ in range(n)]
|
395 |
+
)
|
396 |
+
|
397 |
+
self.log_assignment = nn.ModuleList([MatchAssignment(d) for _ in range(n)])
|
398 |
+
self.token_confidence = nn.ModuleList(
|
399 |
+
[TokenConfidence(d) for _ in range(n - 1)]
|
400 |
+
)
|
401 |
+
self.register_buffer(
|
402 |
+
"confidence_thresholds",
|
403 |
+
torch.Tensor(
|
404 |
+
[self.confidence_threshold(i) for i in range(self.conf.n_layers)]
|
405 |
+
),
|
406 |
+
)
|
407 |
+
|
408 |
+
state_dict = None
|
409 |
+
if features is not None:
|
410 |
+
fname = f"{conf.weights}_{self.version.replace('.', '-')}.pth"
|
411 |
+
state_dict = torch.hub.load_state_dict_from_url(
|
412 |
+
self.url.format(self.version, features), file_name=fname
|
413 |
+
)
|
414 |
+
self.load_state_dict(state_dict, strict=False)
|
415 |
+
elif conf.weights is not None:
|
416 |
+
path = Path(__file__).parent
|
417 |
+
path = path / "weights/{}.pth".format(self.conf.weights)
|
418 |
+
state_dict = torch.load(str(path), map_location="cpu")
|
419 |
+
|
420 |
+
if state_dict:
|
421 |
+
# rename old state dict entries
|
422 |
+
for i in range(self.conf.n_layers):
|
423 |
+
pattern = f"self_attn.{i}", f"transformers.{i}.self_attn"
|
424 |
+
state_dict = {k.replace(*pattern): v for k, v in state_dict.items()}
|
425 |
+
pattern = f"cross_attn.{i}", f"transformers.{i}.cross_attn"
|
426 |
+
state_dict = {k.replace(*pattern): v for k, v in state_dict.items()}
|
427 |
+
self.load_state_dict(state_dict, strict=False)
|
428 |
+
|
429 |
+
# static lengths LightGlue is compiled for (only used with torch.compile)
|
430 |
+
self.static_lengths = None
|
431 |
+
|
432 |
+
def compile(
|
433 |
+
self, mode="reduce-overhead", static_lengths=[256, 512, 768, 1024, 1280, 1536]
|
434 |
+
):
|
435 |
+
if self.conf.width_confidence != -1:
|
436 |
+
warnings.warn(
|
437 |
+
"Point pruning is partially disabled for compiled forward.",
|
438 |
+
stacklevel=2,
|
439 |
+
)
|
440 |
+
|
441 |
+
torch._inductor.cudagraph_mark_step_begin()
|
442 |
+
for i in range(self.conf.n_layers):
|
443 |
+
self.transformers[i].masked_forward = torch.compile(
|
444 |
+
self.transformers[i].masked_forward, mode=mode, fullgraph=True
|
445 |
+
)
|
446 |
+
|
447 |
+
self.static_lengths = static_lengths
|
448 |
+
|
449 |
+
def forward(self, data: dict) -> dict:
|
450 |
+
"""
|
451 |
+
Match keypoints and descriptors between two images
|
452 |
+
|
453 |
+
Input (dict):
|
454 |
+
image0: dict
|
455 |
+
keypoints: [B x M x 2]
|
456 |
+
descriptors: [B x M x D]
|
457 |
+
image: [B x C x H x W] or image_size: [B x 2]
|
458 |
+
image1: dict
|
459 |
+
keypoints: [B x N x 2]
|
460 |
+
descriptors: [B x N x D]
|
461 |
+
image: [B x C x H x W] or image_size: [B x 2]
|
462 |
+
Output (dict):
|
463 |
+
matches0: [B x M]
|
464 |
+
matching_scores0: [B x M]
|
465 |
+
matches1: [B x N]
|
466 |
+
matching_scores1: [B x N]
|
467 |
+
matches: List[[Si x 2]]
|
468 |
+
scores: List[[Si]]
|
469 |
+
stop: int
|
470 |
+
prune0: [B x M]
|
471 |
+
prune1: [B x N]
|
472 |
+
"""
|
473 |
+
with torch.autocast(enabled=self.conf.mp, device_type="cuda"):
|
474 |
+
return self._forward(data)
|
475 |
+
|
476 |
+
def _forward(self, data: dict) -> dict:
|
477 |
+
for key in self.required_data_keys:
|
478 |
+
assert key in data, f"Missing key {key} in data"
|
479 |
+
data0, data1 = data["image0"], data["image1"]
|
480 |
+
kpts0, kpts1 = data0["keypoints"], data1["keypoints"]
|
481 |
+
b, m, _ = kpts0.shape
|
482 |
+
b, n, _ = kpts1.shape
|
483 |
+
device = kpts0.device
|
484 |
+
size0, size1 = data0.get("image_size"), data1.get("image_size")
|
485 |
+
kpts0 = normalize_keypoints(kpts0, size0).clone()
|
486 |
+
kpts1 = normalize_keypoints(kpts1, size1).clone()
|
487 |
+
|
488 |
+
if self.conf.add_scale_ori:
|
489 |
+
kpts0 = torch.cat(
|
490 |
+
[kpts0] + [data0[k].unsqueeze(-1) for k in ("scales", "oris")], -1
|
491 |
+
)
|
492 |
+
kpts1 = torch.cat(
|
493 |
+
[kpts1] + [data1[k].unsqueeze(-1) for k in ("scales", "oris")], -1
|
494 |
+
)
|
495 |
+
desc0 = data0["descriptors"].detach().contiguous()
|
496 |
+
desc1 = data1["descriptors"].detach().contiguous()
|
497 |
+
|
498 |
+
assert desc0.shape[-1] == self.conf.input_dim
|
499 |
+
assert desc1.shape[-1] == self.conf.input_dim
|
500 |
+
|
501 |
+
if torch.is_autocast_enabled():
|
502 |
+
desc0 = desc0.half()
|
503 |
+
desc1 = desc1.half()
|
504 |
+
|
505 |
+
mask0, mask1 = None, None
|
506 |
+
c = max(m, n)
|
507 |
+
do_compile = self.static_lengths and c <= max(self.static_lengths)
|
508 |
+
if do_compile:
|
509 |
+
kn = min([k for k in self.static_lengths if k >= c])
|
510 |
+
desc0, mask0 = pad_to_length(desc0, kn)
|
511 |
+
desc1, mask1 = pad_to_length(desc1, kn)
|
512 |
+
kpts0, _ = pad_to_length(kpts0, kn)
|
513 |
+
kpts1, _ = pad_to_length(kpts1, kn)
|
514 |
+
desc0 = self.input_proj(desc0)
|
515 |
+
desc1 = self.input_proj(desc1)
|
516 |
+
# cache positional embeddings
|
517 |
+
encoding0 = self.posenc(kpts0)
|
518 |
+
encoding1 = self.posenc(kpts1)
|
519 |
+
|
520 |
+
# GNN + final_proj + assignment
|
521 |
+
do_early_stop = self.conf.depth_confidence > 0
|
522 |
+
do_point_pruning = self.conf.width_confidence > 0 and not do_compile
|
523 |
+
pruning_th = self.pruning_min_kpts(device)
|
524 |
+
if do_point_pruning:
|
525 |
+
ind0 = torch.arange(0, m, device=device)[None]
|
526 |
+
ind1 = torch.arange(0, n, device=device)[None]
|
527 |
+
# We store the index of the layer at which pruning is detected.
|
528 |
+
prune0 = torch.ones_like(ind0)
|
529 |
+
prune1 = torch.ones_like(ind1)
|
530 |
+
token0, token1 = None, None
|
531 |
+
for i in range(self.conf.n_layers):
|
532 |
+
if desc0.shape[1] == 0 or desc1.shape[1] == 0: # no keypoints
|
533 |
+
break
|
534 |
+
desc0, desc1 = self.transformers[i](
|
535 |
+
desc0, desc1, encoding0, encoding1, mask0=mask0, mask1=mask1
|
536 |
+
)
|
537 |
+
if i == self.conf.n_layers - 1:
|
538 |
+
continue # no early stopping or adaptive width at last layer
|
539 |
+
|
540 |
+
if do_early_stop:
|
541 |
+
token0, token1 = self.token_confidence[i](desc0, desc1)
|
542 |
+
if self.check_if_stop(token0[..., :m], token1[..., :n], i, m + n):
|
543 |
+
break
|
544 |
+
if do_point_pruning and desc0.shape[-2] > pruning_th:
|
545 |
+
scores0 = self.log_assignment[i].get_matchability(desc0)
|
546 |
+
prunemask0 = self.get_pruning_mask(token0, scores0, i)
|
547 |
+
keep0 = torch.where(prunemask0)[1]
|
548 |
+
ind0 = ind0.index_select(1, keep0)
|
549 |
+
desc0 = desc0.index_select(1, keep0)
|
550 |
+
encoding0 = encoding0.index_select(-2, keep0)
|
551 |
+
prune0[:, ind0] += 1
|
552 |
+
if do_point_pruning and desc1.shape[-2] > pruning_th:
|
553 |
+
scores1 = self.log_assignment[i].get_matchability(desc1)
|
554 |
+
prunemask1 = self.get_pruning_mask(token1, scores1, i)
|
555 |
+
keep1 = torch.where(prunemask1)[1]
|
556 |
+
ind1 = ind1.index_select(1, keep1)
|
557 |
+
desc1 = desc1.index_select(1, keep1)
|
558 |
+
encoding1 = encoding1.index_select(-2, keep1)
|
559 |
+
prune1[:, ind1] += 1
|
560 |
+
|
561 |
+
if desc0.shape[1] == 0 or desc1.shape[1] == 0: # no keypoints
|
562 |
+
m0 = desc0.new_full((b, m), -1, dtype=torch.long)
|
563 |
+
m1 = desc1.new_full((b, n), -1, dtype=torch.long)
|
564 |
+
mscores0 = desc0.new_zeros((b, m))
|
565 |
+
mscores1 = desc1.new_zeros((b, n))
|
566 |
+
matches = desc0.new_empty((b, 0, 2), dtype=torch.long)
|
567 |
+
mscores = desc0.new_empty((b, 0))
|
568 |
+
if not do_point_pruning:
|
569 |
+
prune0 = torch.ones_like(mscores0) * self.conf.n_layers
|
570 |
+
prune1 = torch.ones_like(mscores1) * self.conf.n_layers
|
571 |
+
return {
|
572 |
+
"matches0": m0,
|
573 |
+
"matches1": m1,
|
574 |
+
"matching_scores0": mscores0,
|
575 |
+
"matching_scores1": mscores1,
|
576 |
+
"stop": i + 1,
|
577 |
+
"matches": matches,
|
578 |
+
"scores": mscores,
|
579 |
+
"prune0": prune0,
|
580 |
+
"prune1": prune1,
|
581 |
+
}
|
582 |
+
|
583 |
+
desc0, desc1 = desc0[..., :m, :], desc1[..., :n, :] # remove padding
|
584 |
+
scores, _ = self.log_assignment[i](desc0, desc1)
|
585 |
+
m0, m1, mscores0, mscores1 = filter_matches(scores, self.conf.filter_threshold)
|
586 |
+
matches, mscores = [], []
|
587 |
+
for k in range(b):
|
588 |
+
valid = m0[k] > -1
|
589 |
+
m_indices_0 = torch.where(valid)[0]
|
590 |
+
m_indices_1 = m0[k][valid]
|
591 |
+
if do_point_pruning:
|
592 |
+
m_indices_0 = ind0[k, m_indices_0]
|
593 |
+
m_indices_1 = ind1[k, m_indices_1]
|
594 |
+
matches.append(torch.stack([m_indices_0, m_indices_1], -1))
|
595 |
+
mscores.append(mscores0[k][valid])
|
596 |
+
|
597 |
+
# TODO: Remove when hloc switches to the compact format.
|
598 |
+
if do_point_pruning:
|
599 |
+
m0_ = torch.full((b, m), -1, device=m0.device, dtype=m0.dtype)
|
600 |
+
m1_ = torch.full((b, n), -1, device=m1.device, dtype=m1.dtype)
|
601 |
+
m0_[:, ind0] = torch.where(m0 == -1, -1, ind1.gather(1, m0.clamp(min=0)))
|
602 |
+
m1_[:, ind1] = torch.where(m1 == -1, -1, ind0.gather(1, m1.clamp(min=0)))
|
603 |
+
mscores0_ = torch.zeros((b, m), device=mscores0.device)
|
604 |
+
mscores1_ = torch.zeros((b, n), device=mscores1.device)
|
605 |
+
mscores0_[:, ind0] = mscores0
|
606 |
+
mscores1_[:, ind1] = mscores1
|
607 |
+
m0, m1, mscores0, mscores1 = m0_, m1_, mscores0_, mscores1_
|
608 |
+
else:
|
609 |
+
prune0 = torch.ones_like(mscores0) * self.conf.n_layers
|
610 |
+
prune1 = torch.ones_like(mscores1) * self.conf.n_layers
|
611 |
+
|
612 |
+
return {
|
613 |
+
"matches0": m0,
|
614 |
+
"matches1": m1,
|
615 |
+
"matching_scores0": mscores0,
|
616 |
+
"matching_scores1": mscores1,
|
617 |
+
"stop": i + 1,
|
618 |
+
"matches": matches,
|
619 |
+
"scores": mscores,
|
620 |
+
"prune0": prune0,
|
621 |
+
"prune1": prune1,
|
622 |
+
}
|
623 |
+
|
624 |
+
def confidence_threshold(self, layer_index: int) -> float:
|
625 |
+
"""scaled confidence threshold"""
|
626 |
+
threshold = 0.8 + 0.1 * np.exp(-4.0 * layer_index / self.conf.n_layers)
|
627 |
+
return np.clip(threshold, 0, 1)
|
628 |
+
|
629 |
+
def get_pruning_mask(
|
630 |
+
self, confidences: torch.Tensor, scores: torch.Tensor, layer_index: int
|
631 |
+
) -> torch.Tensor:
|
632 |
+
"""mask points which should be removed"""
|
633 |
+
keep = scores > (1 - self.conf.width_confidence)
|
634 |
+
if confidences is not None: # Low-confidence points are never pruned.
|
635 |
+
keep |= confidences <= self.confidence_thresholds[layer_index]
|
636 |
+
return keep
|
637 |
+
|
638 |
+
def check_if_stop(
|
639 |
+
self,
|
640 |
+
confidences0: torch.Tensor,
|
641 |
+
confidences1: torch.Tensor,
|
642 |
+
layer_index: int,
|
643 |
+
num_points: int,
|
644 |
+
) -> torch.Tensor:
|
645 |
+
"""evaluate stopping condition"""
|
646 |
+
confidences = torch.cat([confidences0, confidences1], -1)
|
647 |
+
threshold = self.confidence_thresholds[layer_index]
|
648 |
+
ratio_confident = 1.0 - (confidences < threshold).float().sum() / num_points
|
649 |
+
return ratio_confident > self.conf.depth_confidence
|
650 |
+
|
651 |
+
def pruning_min_kpts(self, device: torch.device):
|
652 |
+
if self.conf.flash and FLASH_AVAILABLE and device.type == "cuda":
|
653 |
+
return self.pruning_keypoint_thresholds["flash"]
|
654 |
+
else:
|
655 |
+
return self.pruning_keypoint_thresholds[device.type]
|
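For orientation, here is how this `LightGlue` class is typically driven end to end, mirroring the upstream cvg/LightGlue usage. The flat `lightglue` import layout and the image paths are illustrative assumptions; inside this repository the same modules live under `dad.detectors.third_party.lightglue`.

```python
import torch
from lightglue import LightGlue, SuperPoint     # assumed upstream-style layout
from lightglue.utils import load_image, rbd

device = "cuda" if torch.cuda.is_available() else "cpu"
extractor = SuperPoint(max_num_keypoints=2048).eval().to(device)
matcher = LightGlue(features="superpoint").eval().to(device)

# extract() resizes internally and maps keypoints back to original pixels.
feats0 = extractor.extract(load_image("im0.jpg").to(device))  # placeholder path
feats1 = extractor.extract(load_image("im1.jpg").to(device))  # placeholder path
out = matcher({"image0": feats0, "image1": feats1})

pairs = rbd(out)["matches"]                      # [S x 2] keypoint index pairs
kpts0 = rbd(feats0)["keypoints"][pairs[..., 0]]  # matched points in image 0
kpts1 = rbd(feats1)["keypoints"][pairs[..., 1]]  # matched points in image 1
```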
imcui/third_party/dad/dad/detectors/third_party/lightglue/sift.py
ADDED
@@ -0,0 +1,216 @@
import warnings

import cv2
import numpy as np
import torch
from kornia.color import rgb_to_grayscale
from packaging import version

try:
    import pycolmap
except ImportError:
    pycolmap = None

from .utils import Extractor


def filter_dog_point(points, scales, angles, image_shape, nms_radius, scores=None):
    h, w = image_shape
    ij = np.round(points - 0.5).astype(int).T[::-1]

    # Remove duplicate points (identical coordinates).
    # Pick highest scale or score
    s = scales if scores is None else scores
    buffer = np.zeros((h, w))
    np.maximum.at(buffer, tuple(ij), s)
    keep = np.where(buffer[tuple(ij)] == s)[0]

    # Pick lowest angle (arbitrary).
    ij = ij[:, keep]
    buffer[:] = np.inf
    o_abs = np.abs(angles[keep])
    np.minimum.at(buffer, tuple(ij), o_abs)
    mask = buffer[tuple(ij)] == o_abs
    ij = ij[:, mask]
    keep = keep[mask]

    if nms_radius > 0:
        # Apply NMS on the remaining points
        buffer[:] = 0
        buffer[tuple(ij)] = s[keep]  # scores or scale

        local_max = torch.nn.functional.max_pool2d(
            torch.from_numpy(buffer).unsqueeze(0),
            kernel_size=nms_radius * 2 + 1,
            stride=1,
            padding=nms_radius,
        ).squeeze(0)
        is_local_max = buffer == local_max.numpy()
        keep = keep[is_local_max[tuple(ij)]]
    return keep


def sift_to_rootsift(x: torch.Tensor, eps=1e-6) -> torch.Tensor:
    x = torch.nn.functional.normalize(x, p=1, dim=-1, eps=eps)
    x.clip_(min=eps).sqrt_()
    return torch.nn.functional.normalize(x, p=2, dim=-1, eps=eps)


def run_opencv_sift(features: cv2.Feature2D, image: np.ndarray) -> np.ndarray:
    """
    Detect keypoints using OpenCV Detector.
    Optionally, perform description.
    Args:
        features: OpenCV based keypoints detector and descriptor
        image: Grayscale image of uint8 data type
    Returns:
        keypoints: 1D array of detected cv2.KeyPoint
        scores: 1D array of responses
        descriptors: 1D array of descriptors
    """
    detections, descriptors = features.detectAndCompute(image, None)
    points = np.array([k.pt for k in detections], dtype=np.float32)
    scores = np.array([k.response for k in detections], dtype=np.float32)
    scales = np.array([k.size for k in detections], dtype=np.float32)
    angles = np.deg2rad(np.array([k.angle for k in detections], dtype=np.float32))
    return points, scores, scales, angles, descriptors


class SIFT(Extractor):
    default_conf = {
        "rootsift": True,
        "nms_radius": 0,  # None to disable filtering entirely.
        "max_num_keypoints": 4096,
        "backend": "opencv",  # in {opencv, pycolmap, pycolmap_cpu, pycolmap_cuda}
        "detection_threshold": 0.0066667,  # from COLMAP
        "edge_threshold": 10,
        "first_octave": -1,  # only used by pycolmap, the default of COLMAP
        "num_octaves": 4,
    }

    preprocess_conf = {
        "resize": 1024,
    }

    required_data_keys = ["image"]

    def __init__(self, **conf):
        super().__init__(**conf)  # Update with default configuration.
        backend = self.conf.backend
        if backend.startswith("pycolmap"):
            if pycolmap is None:
                raise ImportError(
                    "Cannot find module pycolmap: install it with pip "
                    "or use backend=opencv."
                )
            options = {
                "peak_threshold": self.conf.detection_threshold,
                "edge_threshold": self.conf.edge_threshold,
                "first_octave": self.conf.first_octave,
                "num_octaves": self.conf.num_octaves,
                "normalization": pycolmap.Normalization.L2,  # L1_ROOT is buggy.
            }
            device = (
                "auto" if backend == "pycolmap" else backend.replace("pycolmap_", "")
            )
            if (
                backend == "pycolmap_cpu" or not pycolmap.has_cuda
            ) and pycolmap.__version__ < "0.5.0":
                warnings.warn(
                    "The pycolmap CPU SIFT is buggy in version < 0.5.0, "
                    "consider upgrading pycolmap or use the CUDA version.",
                    stacklevel=1,
                )
            else:
                options["max_num_features"] = self.conf.max_num_keypoints
            self.sift = pycolmap.Sift(options=options, device=device)
        elif backend == "opencv":
            self.sift = cv2.SIFT_create(
                contrastThreshold=self.conf.detection_threshold,
                nfeatures=self.conf.max_num_keypoints,
                edgeThreshold=self.conf.edge_threshold,
                nOctaveLayers=self.conf.num_octaves,
            )
        else:
            backends = {"opencv", "pycolmap", "pycolmap_cpu", "pycolmap_cuda"}
            raise ValueError(
                f"Unknown backend: {backend} not in {{{','.join(backends)}}}."
            )

    def extract_single_image(self, image: torch.Tensor):
        image_np = image.cpu().numpy().squeeze(0)

        if self.conf.backend.startswith("pycolmap"):
            if version.parse(pycolmap.__version__) >= version.parse("0.5.0"):
                detections, descriptors = self.sift.extract(image_np)
                scores = None  # Scores are not exposed by COLMAP anymore.
            else:
                detections, scores, descriptors = self.sift.extract(image_np)
            keypoints = detections[:, :2]  # Keep only (x, y).
            scales, angles = detections[:, -2:].T
            if scores is not None and (
                self.conf.backend == "pycolmap_cpu" or not pycolmap.has_cuda
            ):
                # Set the scores as a combination of abs. response and scale.
                scores = np.abs(scores) * scales
        elif self.conf.backend == "opencv":
            # TODO: Check if opencv keypoints are already in corner convention
            keypoints, scores, scales, angles, descriptors = run_opencv_sift(
                self.sift, (image_np * 255.0).astype(np.uint8)
            )
        pred = {
            "keypoints": keypoints,
            "scales": scales,
            "oris": angles,
            "descriptors": descriptors,
        }
        if scores is not None:
            pred["keypoint_scores"] = scores

        # sometimes pycolmap returns points outside the image. We remove them
        if self.conf.backend.startswith("pycolmap"):
            is_inside = (
                pred["keypoints"] + 0.5 < np.array([image_np.shape[-2:][::-1]])
            ).all(-1)
            pred = {k: v[is_inside] for k, v in pred.items()}

        if self.conf.nms_radius is not None:
            keep = filter_dog_point(
                pred["keypoints"],
                pred["scales"],
                pred["oris"],
                image_np.shape,
                self.conf.nms_radius,
                scores=pred.get("keypoint_scores"),
            )
            pred = {k: v[keep] for k, v in pred.items()}

        pred = {k: torch.from_numpy(v) for k, v in pred.items()}
        if scores is not None:
            # Keep the k keypoints with highest score
            num_points = self.conf.max_num_keypoints
            if num_points is not None and len(pred["keypoints"]) > num_points:
                indices = torch.topk(pred["keypoint_scores"], num_points).indices
                pred = {k: v[indices] for k, v in pred.items()}

        return pred

    def forward(self, data: dict) -> dict:
        image = data["image"]
        if image.shape[1] == 3:
            image = rgb_to_grayscale(image)
        device = image.device
        image = image.cpu()
        pred = []
        for k in range(len(image)):
            img = image[k]
            if "image_size" in data.keys():
                # avoid extracting points in padded areas
                w, h = data["image_size"][k]
                img = img[:, :h, :w]
            p = self.extract_single_image(img)
            pred.append(p)
        pred = {k: torch.stack([p[k] for p in pred], 0).to(device) for k in pred[0]}
        if self.conf.rootsift:
            pred["descriptors"] = sift_to_rootsift(pred["descriptors"])
        return pred
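`sift_to_rootsift` above is the standard RootSIFT transform: L1-normalize, take the element-wise square root, then L2-normalize, so Euclidean distance on the result approximates the Hellinger kernel on SIFT histograms. A self-contained sanity check of the mapping (random stand-in descriptors, not real SIFT output):

```python
import torch

desc = torch.rand(4, 128)                        # stand-in SIFT descriptors
x = torch.nn.functional.normalize(desc, p=1, dim=-1, eps=1e-6)
root = torch.nn.functional.normalize(x.clamp(min=1e-6).sqrt(), p=2, dim=-1)
assert torch.allclose(root.norm(dim=-1), torch.ones(4), atol=1e-4)  # unit L2 norm
```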
imcui/third_party/dad/dad/detectors/third_party/lightglue/superpoint.py
ADDED
@@ -0,0 +1,233 @@
# %BANNER_BEGIN%
# ---------------------------------------------------------------------
# %COPYRIGHT_BEGIN%
#
# Magic Leap, Inc. ("COMPANY") CONFIDENTIAL
#
# Unpublished Copyright (c) 2020
# Magic Leap, Inc., All Rights Reserved.
#
# NOTICE: All information contained herein is, and remains the property
# of COMPANY. The intellectual and technical concepts contained herein
# are proprietary to COMPANY and may be covered by U.S. and Foreign
# Patents, patents in process, and are protected by trade secret or
# copyright law. Dissemination of this information or reproduction of
# this material is strictly forbidden unless prior written permission is
# obtained from COMPANY. Access to the source code contained herein is
# hereby forbidden to anyone except current COMPANY employees, managers
# or contractors who have executed Confidentiality and Non-disclosure
# agreements explicitly covering such access.
#
# The copyright notice above does not evidence any actual or intended
# publication or disclosure of this source code, which includes
# information that is confidential and/or proprietary, and is a trade
# secret, of COMPANY. ANY REPRODUCTION, MODIFICATION, DISTRIBUTION,
# PUBLIC PERFORMANCE, OR PUBLIC DISPLAY OF OR THROUGH USE OF THIS
# SOURCE CODE WITHOUT THE EXPRESS WRITTEN CONSENT OF COMPANY IS
# STRICTLY PROHIBITED, AND IN VIOLATION OF APPLICABLE LAWS AND
# INTERNATIONAL TREATIES. THE RECEIPT OR POSSESSION OF THIS SOURCE
# CODE AND/OR RELATED INFORMATION DOES NOT CONVEY OR IMPLY ANY RIGHTS
# TO REPRODUCE, DISCLOSE OR DISTRIBUTE ITS CONTENTS, OR TO MANUFACTURE,
# USE, OR SELL ANYTHING THAT IT MAY DESCRIBE, IN WHOLE OR IN PART.
#
# %COPYRIGHT_END%
# ----------------------------------------------------------------------
# %AUTHORS_BEGIN%
#
# Originating Authors: Paul-Edouard Sarlin
#
# %AUTHORS_END%
# --------------------------------------------------------------------*/
# %BANNER_END%

# Adapted by Remi Pautrat, Philipp Lindenberger

import torch
from kornia.color import rgb_to_grayscale
from torch import nn

from .utils import Extractor


def simple_nms(scores, nms_radius: int):
    """Fast Non-maximum suppression to remove nearby points"""
    assert nms_radius >= 0

    def max_pool(x):
        return torch.nn.functional.max_pool2d(
            x, kernel_size=nms_radius * 2 + 1, stride=1, padding=nms_radius
        )

    zeros = torch.zeros_like(scores)
    max_mask = scores == max_pool(scores)
    for _ in range(2):
        supp_mask = max_pool(max_mask.float()) > 0
        supp_scores = torch.where(supp_mask, zeros, scores)
        new_max_mask = supp_scores == max_pool(supp_scores)
        max_mask = max_mask | (new_max_mask & (~supp_mask))
    return torch.where(max_mask, scores, zeros)


def top_k_keypoints(keypoints, scores, k):
    if k >= len(keypoints):
        return keypoints, scores
    scores, indices = torch.topk(scores, k, dim=0, sorted=True)
    return keypoints[indices], scores


def sample_descriptors(keypoints, descriptors, s: int = 8):
    """Interpolate descriptors at keypoint locations"""
    b, c, h, w = descriptors.shape
    keypoints = keypoints - s / 2 + 0.5
    keypoints /= torch.tensor(
        [(w * s - s / 2 - 0.5), (h * s - s / 2 - 0.5)],
    ).to(keypoints)[None]
    keypoints = keypoints * 2 - 1  # normalize to (-1, 1)
    args = {"align_corners": True} if torch.__version__ >= "1.3" else {}
    descriptors = torch.nn.functional.grid_sample(
        descriptors, keypoints.view(b, 1, -1, 2), mode="bilinear", **args
    )
    descriptors = torch.nn.functional.normalize(
        descriptors.reshape(b, c, -1), p=2, dim=1
    )
    return descriptors


class SuperPoint(Extractor):
    """SuperPoint Convolutional Detector and Descriptor

    SuperPoint: Self-Supervised Interest Point Detection and
    Description. Daniel DeTone, Tomasz Malisiewicz, and Andrew
    Rabinovich. In CVPRW, 2019. https://arxiv.org/abs/1712.07629

    """

    default_conf = {
        "descriptor_dim": 256,
        "nms_radius": 4,
        "max_num_keypoints": None,
        # TODO: detection threshold
        "detection_threshold": 0.0005,
        "remove_borders": 4,
    }

    preprocess_conf = {
        "resize": 1024,
    }

    required_data_keys = ["image"]

    def __init__(self, **conf):
        super().__init__(**conf)  # Update with default configuration.
        self.relu = nn.ReLU(inplace=True)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        c1, c2, c3, c4, c5 = 64, 64, 128, 128, 256

        self.conv1a = nn.Conv2d(1, c1, kernel_size=3, stride=1, padding=1)
        self.conv1b = nn.Conv2d(c1, c1, kernel_size=3, stride=1, padding=1)
        self.conv2a = nn.Conv2d(c1, c2, kernel_size=3, stride=1, padding=1)
        self.conv2b = nn.Conv2d(c2, c2, kernel_size=3, stride=1, padding=1)
        self.conv3a = nn.Conv2d(c2, c3, kernel_size=3, stride=1, padding=1)
        self.conv3b = nn.Conv2d(c3, c3, kernel_size=3, stride=1, padding=1)
        self.conv4a = nn.Conv2d(c3, c4, kernel_size=3, stride=1, padding=1)
        self.conv4b = nn.Conv2d(c4, c4, kernel_size=3, stride=1, padding=1)

        self.convPa = nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1)
        self.convPb = nn.Conv2d(c5, 65, kernel_size=1, stride=1, padding=0)

        self.convDa = nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1)
        self.convDb = nn.Conv2d(
            c5, self.conf.descriptor_dim, kernel_size=1, stride=1, padding=0
        )

        url = "https://github.com/cvg/LightGlue/releases/download/v0.1_arxiv/superpoint_v1.pth"  # noqa
        self.load_state_dict(torch.hub.load_state_dict_from_url(url))

        if self.conf.max_num_keypoints is not None and self.conf.max_num_keypoints <= 0:
            raise ValueError("max_num_keypoints must be positive or None")

    def forward(self, data: dict) -> dict:
        """Compute keypoints, scores, descriptors for image"""
        for key in self.required_data_keys:
            assert key in data, f"Missing key {key} in data"
        image = data["image"]
        if image.shape[1] == 3:
            image = rgb_to_grayscale(image)

        # Shared Encoder
        x = self.relu(self.conv1a(image))
        x = self.relu(self.conv1b(x))
        x = self.pool(x)
        x = self.relu(self.conv2a(x))
        x = self.relu(self.conv2b(x))
        x = self.pool(x)
        x = self.relu(self.conv3a(x))
        x = self.relu(self.conv3b(x))
        x = self.pool(x)
        x = self.relu(self.conv4a(x))
        x = self.relu(self.conv4b(x))

        # Compute the dense keypoint scores
        cPa = self.relu(self.convPa(x))
        scores = self.convPb(cPa)
        scores = torch.nn.functional.softmax(scores, 1)[:, :-1]
        b, _, h, w = scores.shape
        scores = scores.permute(0, 2, 3, 1).reshape(b, h, w, 8, 8)
        scores = scores.permute(0, 1, 3, 2, 4).reshape(b, h * 8, w * 8)
        scores = simple_nms(scores, self.conf.nms_radius)

        # Discard keypoints near the image borders
        if self.conf.remove_borders:
            pad = self.conf.remove_borders
            scores[:, :pad] = -1
            scores[:, :, :pad] = -1
            scores[:, -pad:] = -1
            scores[:, :, -pad:] = -1

        # Extract keypoints
        best_kp = torch.where(scores > self.conf.detection_threshold)
        scores = scores[best_kp]

        # Separate into batches
        keypoints = [
            torch.stack(best_kp[1:3], dim=-1)[best_kp[0] == i] for i in range(b)
        ]
        scores = [scores[best_kp[0] == i] for i in range(b)]

        # Keep the k keypoints with highest score
        if self.conf.max_num_keypoints is not None:
            keypoints, scores = list(
                zip(
                    *[
                        top_k_keypoints(k, s, self.conf.max_num_keypoints)
                        for k, s in zip(keypoints, scores)
                    ]
                )
            )

        # Convert (h, w) to (x, y)
        keypoints = [torch.flip(k, [1]).float() for k in keypoints]

        # Compute the dense descriptors
        cDa = self.relu(self.convDa(x))
        descriptors = self.convDb(cDa)
        descriptors = torch.nn.functional.normalize(descriptors, p=2, dim=1)

        # Extract descriptors
        descriptors = [
            sample_descriptors(k[None], d[None], 8)[0]
            for k, d in zip(keypoints, descriptors)
        ]

        return {
            "keypoints": torch.stack(keypoints, 0),
            "keypoint_scores": torch.stack(scores, 0),
            "descriptors": torch.stack(descriptors, 0).transpose(-1, -2).contiguous(),
        }


class ReinforcedFP(SuperPoint):
    def __init__(self, **conf):
        super().__init__(**conf)  # Update with default configuration.
        url = "https://github.com/aritrabhowmik/Reinforced-Feature-Points/raw/refs/heads/master/weights/baseline_mixed_loss.pth"  # noqa
        self.load_state_dict(torch.hub.load_state_dict_from_url(url))
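The detector head in `SuperPoint.forward` emits 65 channels per 8x8 cell: 64 pixel bins plus one "dustbin" meaning "no keypoint in this cell". After the softmax the dustbin is dropped and the two permute/reshape steps unfold the bins back to full resolution. A minimal shape check of that decoding, using random logits as a stand-in:

```python
import torch

b, h, w = 1, 30, 40                                  # cell grid of a 240x320 image
logits = torch.randn(b, 65, h, w)                    # fake detector-head output
scores = torch.nn.functional.softmax(logits, 1)[:, :-1]  # drop the dustbin channel
scores = scores.permute(0, 2, 3, 1).reshape(b, h, w, 8, 8)
scores = scores.permute(0, 1, 3, 2, 4).reshape(b, h * 8, w * 8)
assert scores.shape == (1, 240, 320)                 # one score per input pixel
```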
imcui/third_party/dad/dad/detectors/third_party/lightglue/utils.py
ADDED
@@ -0,0 +1,158 @@
import collections.abc as collections
from pathlib import Path
from types import SimpleNamespace
from typing import Callable, List, Optional, Tuple, Union

import cv2
import kornia
import numpy as np
import torch


class ImagePreprocessor:
    default_conf = {
        "resize": None,  # target edge length, None for no resizing
        "side": "long",
        "interpolation": "bilinear",
        "align_corners": None,
        "antialias": True,
    }

    def __init__(self, **conf) -> None:
        super().__init__()
        self.conf = {**self.default_conf, **conf}
        self.conf = SimpleNamespace(**self.conf)

    def __call__(self, img: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Resize and preprocess an image, return image and resize scale"""
        h, w = img.shape[-2:]
        if self.conf.resize is not None:
            img = kornia.geometry.transform.resize(
                img,
                self.conf.resize,
                side=self.conf.side,
                antialias=self.conf.antialias,
                align_corners=self.conf.align_corners,
            )
        scale = torch.Tensor([img.shape[-1] / w, img.shape[-2] / h]).to(img)
        return img, scale


def map_tensor(input_, func: Callable):
    string_classes = (str, bytes)
    if isinstance(input_, string_classes):
        return input_
    elif isinstance(input_, collections.Mapping):
        return {k: map_tensor(sample, func) for k, sample in input_.items()}
    elif isinstance(input_, collections.Sequence):
        return [map_tensor(sample, func) for sample in input_]
    elif isinstance(input_, torch.Tensor):
        return func(input_)
    else:
        return input_


def batch_to_device(batch: dict, device: str = "cpu", non_blocking: bool = True):
    """Move batch (dict) to device"""

    def _func(tensor):
        return tensor.to(device=device, non_blocking=non_blocking).detach()

    return map_tensor(batch, _func)


def rbd(data: dict) -> dict:
    """Remove batch dimension from elements in data"""
    return {
        k: v[0] if isinstance(v, (torch.Tensor, np.ndarray, list)) else v
        for k, v in data.items()
    }


def numpy_image_to_torch(image: np.ndarray) -> torch.Tensor:
    """Normalize the image tensor and reorder the dimensions."""
    if image.ndim == 3:
        image = image.transpose((2, 0, 1))  # HxWxC to CxHxW
    elif image.ndim == 2:
        image = image[None]  # add channel axis
    else:
        raise ValueError(f"Not an image: {image.shape}")
    return torch.tensor(image / 255.0, dtype=torch.float)


def resize_image(
    image: np.ndarray,
    size: Union[List[int], int],
    fn: str = "max",
    interp: Optional[str] = "area",
) -> np.ndarray:
    """Resize an image to a fixed size, or according to max or min edge."""
    h, w = image.shape[:2]

    fn = {"max": max, "min": min}[fn]
    if isinstance(size, int):
        scale = size / fn(h, w)
        h_new, w_new = int(round(h * scale)), int(round(w * scale))
        scale = (w_new / w, h_new / h)
    elif isinstance(size, (tuple, list)):
        h_new, w_new = size
        scale = (w_new / w, h_new / h)
    else:
        raise ValueError(f"Incorrect new size: {size}")
    mode = {
        "linear": cv2.INTER_LINEAR,
        "cubic": cv2.INTER_CUBIC,
        "nearest": cv2.INTER_NEAREST,
        "area": cv2.INTER_AREA,
    }[interp]
    return cv2.resize(image, (w_new, h_new), interpolation=mode), scale


def load_image(path: Path, resize: int = None, **kwargs) -> torch.Tensor:
    if not Path(path).exists():
        raise FileNotFoundError(f"No image at path {path}.")
    mode = cv2.IMREAD_COLOR
    image = cv2.imread(str(path), mode)
    if image is None:
        raise IOError(f"Could not read image at {path}.")
    image = image[..., ::-1]
    if resize is not None:
        image, _ = resize_image(image, resize, **kwargs)
    return numpy_image_to_torch(image)


class Extractor(torch.nn.Module):
    def __init__(self, **conf):
        super().__init__()
        self.conf = SimpleNamespace(**{**self.default_conf, **conf})

    @torch.no_grad()
    def extract(self, img: torch.Tensor, **conf) -> dict:
        """Perform extraction with online resizing"""
        if img.dim() == 3:
            img = img[None]  # add batch dim
        assert img.dim() == 4 and img.shape[0] == 1
        shape = img.shape[-2:][::-1]
        img, scales = ImagePreprocessor(**{**self.preprocess_conf, **conf})(img)
        feats = self.forward({"image": img})
        feats["image_size"] = torch.tensor(shape)[None].to(img).float()
        feats["keypoints"] = (feats["keypoints"] + 0.5) / scales[None] - 0.5
        return feats


def match_pair(
    extractor,
    matcher,
    image0: torch.Tensor,
    image1: torch.Tensor,
    device: str = "cpu",
    **preprocess,
):
    """Match a pair of images (image0, image1) with an extractor and matcher"""
    feats0 = extractor.extract(image0, **preprocess)
    feats1 = extractor.extract(image1, **preprocess)
    matches01 = matcher({"image0": feats0, "image1": feats1})
    data = [feats0, feats1, matches01]
    # remove batch dim and move to target device
    feats0, feats1, matches01 = [batch_to_device(rbd(x), device) for x in data]
    return feats0, feats1, matches01
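`Extractor.extract` maps keypoints found in the resized image back to original pixel coordinates with `(kpts + 0.5) / scale - 0.5`, i.e. the resize is undone around pixel centers rather than pixel corners. A quick numeric check of that convention:

```python
import torch

scale = torch.tensor([0.5, 0.5])             # the image was downscaled by 2x
kpt_resized = torch.tensor([99.5, 49.5])     # keypoint detected in the small image
kpt_orig = (kpt_resized + 0.5) / scale - 0.5
assert torch.allclose(kpt_orig, torch.tensor([199.5, 99.5]))
```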
imcui/third_party/dad/dad/detectors/third_party/lightglue_detector.py
ADDED
@@ -0,0 +1,42 @@
from pathlib import Path
from typing import Union

import torch

from dad.types import Detector
from dad.utils import get_best_device

from .lightglue.utils import load_image


class LightGlueDetector(Detector):
    def __init__(self, model, resize=None, **kwargs):
        super().__init__()
        self.model = model(**kwargs).eval().to(get_best_device())
        if resize is not None:
            self.model.preprocess_conf["resize"] = resize

    @property
    def topleft(self):
        return 0.0

    def load_image(self, im_path: Union[str, Path]):
        return {"image": load_image(im_path).to(get_best_device())}

    @torch.inference_mode()
    def detect(
        self,
        batch: dict[str, torch.Tensor],
        *,
        num_keypoints: int,
        return_dense_probs: bool = False,
    ):
        image = batch["image"]
        self.model.conf.max_num_keypoints = num_keypoints
        ret = self.model.extract(image)
        kpts = self.to_normalized_coords(
            ret["keypoints"], ret["image_size"][0, 1], ret["image_size"][0, 0]
        )
        result = {"keypoints": kpts, "keypoint_probs": None}
        if return_dense_probs:
            result["dense_probs"] = ret["dense_probs"] if "dense_probs" in ret else None
        return result
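A hedged sketch of how this wrapper appears intended to be used; the import paths follow this diff's directory layout and `to_normalized_coords` is inherited from `dad.types.Detector`, so both are assumptions rather than verified API:

```python
# Assumed usage; paths and the SuperPoint re-export are taken from this diff's layout.
from dad.detectors.third_party.lightglue_detector import LightGlueDetector
from dad.detectors.third_party.lightglue import SuperPoint

detector = LightGlueDetector(SuperPoint, resize=1024)
batch = detector.load_image("scene.jpg")             # placeholder image path
out = detector.detect(batch, num_keypoints=2048)
kpts = out["keypoints"]  # in the convention of Detector.to_normalized_coords
```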
imcui/third_party/dad/dad/detectors/third_party/rekd/config.py
ADDED
@@ -0,0 +1,206 @@
1 |
+
import argparse
|
2 |
+
|
3 |
+
## for fix seed
|
4 |
+
import random
|
5 |
+
import torch
|
6 |
+
import numpy
|
7 |
+
|
8 |
+
|
9 |
+
def get_config(jupyter=False):
|
10 |
+
parser = argparse.ArgumentParser(description="Train REKD Architecture")
|
11 |
+
|
12 |
+
## basic configuration
|
13 |
+
parser.add_argument(
|
14 |
+
"--data_dir",
|
15 |
+
type=str,
|
16 |
+
default="../ImageNet2012/ILSVRC2012_img_val", # default='path-to-ImageNet',
|
17 |
+
help="The root path to the data from which the synthetic dataset will be created.",
|
18 |
+
)
|
19 |
+
parser.add_argument(
|
20 |
+
"--synth_dir",
|
21 |
+
type=str,
|
22 |
+
default="",
|
23 |
+
help="The path to save the generated sythetic image pairs.",
|
24 |
+
)
|
25 |
+
parser.add_argument(
|
26 |
+
"--log_dir",
|
27 |
+
type=str,
|
28 |
+
default="trained_models/weights",
|
29 |
+
help="The path to save the REKD weights.",
|
30 |
+
)
|
31 |
+
parser.add_argument(
|
32 |
+
"--load_dir",
|
33 |
+
type=str,
|
34 |
+
default="",
|
35 |
+
help="Set saved model parameters if resume training is desired.",
|
36 |
+
)
|
37 |
+
parser.add_argument(
|
38 |
+
"--exp_name",
|
39 |
+
type=str,
|
40 |
+
default="REKD",
|
41 |
+
help="The Rotaton-equivaraiant Keypoint Detection (REKD) experiment name",
|
42 |
+
)
|
43 |
+
## network architecture
|
44 |
+
parser.add_argument(
|
45 |
+
"--factor_scaling_pyramid",
|
46 |
+
type=float,
|
47 |
+
default=1.2,
|
48 |
+
help="The scale factor between the multi-scale pyramid levels in the architecture.",
|
49 |
+
)
|
50 |
+
parser.add_argument(
|
51 |
+
"--group_size",
|
52 |
+
type=int,
|
53 |
+
default=36,
|
54 |
+
help="The number of groups for the group convolution.",
|
55 |
+
)
|
56 |
+
parser.add_argument(
|
57 |
+
"--dim_first",
|
58 |
+
type=int,
|
59 |
+
default=2,
|
60 |
+
help="The number of channels of the first layer",
|
61 |
+
)
|
62 |
+
parser.add_argument(
|
63 |
+
"--dim_second",
|
64 |
+
type=int,
|
65 |
+
default=2,
|
66 |
+
help="The number of channels of the second layer",
|
67 |
+
)
|
68 |
+
parser.add_argument(
|
69 |
+
"--dim_third",
|
70 |
+
type=int,
|
71 |
+
default=2,
|
72 |
+
help="The number of channels of the thrid layer",
|
73 |
+
)
|
74 |
+
## network training
|
75 |
+
parser.add_argument(
|
76 |
+
"--batch_size", type=int, default=16, help="The batch size for training."
|
77 |
+
)
|
78 |
+
parser.add_argument(
|
79 |
+
"--num_epochs", type=int, default=20, help="Number of epochs for training."
|
80 |
+
)
|
81 |
+
## Loss function
|
82 |
+
parser.add_argument(
|
83 |
+
"--init_initial_learning_rate",
|
84 |
+
type=float,
|
85 |
+
default=1e-3,
|
86 |
+
help="The init initial learning rate value.",
|
87 |
+
)
|
88 |
+
parser.add_argument(
|
89 |
+
"--MSIP_sizes", type=str, default="8,16,24,32,40", help="MSIP sizes."
|
90 |
+
)
|
91 |
+
parser.add_argument(
|
92 |
+
"--MSIP_factor_loss",
|
93 |
+
type=str,
|
94 |
+
default="256.0,64.0,16.0,4.0,1.0",
|
95 |
+
help="MSIP loss balancing parameters.",
|
96 |
+
)
|
97 |
+
parser.add_argument("--ori_loss_balance", type=float, default=100.0, help="")
|
98 |
+
## Dataset generation
|
99 |
+
parser.add_argument(
|
100 |
+
"--patch_size",
|
101 |
+
type=int,
|
102 |
+
default=192,
|
103 |
+
help="The patch size of the generated dataset.",
|
104 |
+
)
|
105 |
+
parser.add_argument(
|
106 |
+
"--max_angle",
|
107 |
+
type=int,
|
108 |
+
default=180,
|
109 |
+
help="The max angle value for generating a synthetic view to train REKD.",
|
110 |
+
)
|
111 |
+
parser.add_argument(
|
112 |
+
"--min_scale",
|
113 |
+
type=float,
|
114 |
+
default=1.0,
|
115 |
+
help="The min scale value for generating a synthetic view to train REKD.",
|
116 |
+
)
|
117 |
+
parser.add_argument(
|
118 |
+
"--max_scale",
|
119 |
+
type=float,
|
120 |
+
default=1.0,
|
121 |
+
help="The max scale value for generating a synthetic view to train REKD.",
|
122 |
+
)
|
123 |
+
parser.add_argument(
|
124 |
+
"--max_shearing",
|
125 |
+
type=float,
|
126 |
+
default=0.0,
|
127 |
+
help="The max shearing value for generating a synthetic view to train REKD.",
|
128 |
+
)
|
129 |
+
parser.add_argument(
|
130 |
+
"--num_training_data",
|
131 |
+
type=int,
|
132 |
+
default=9000,
|
133 |
+
help="The number of the generated dataset.",
|
134 |
+
)
|
135 |
+
parser.add_argument(
|
136 |
+
"--is_debugging",
|
137 |
+
type=bool,
|
138 |
+
default=False,
|
139 |
+
help="Set variable to True if you desire to train network on a smaller dataset.",
|
140 |
+
)
|
141 |
+
## For eval/inference
|
142 |
+
parser.add_argument(
|
143 |
+
"--num_points",
|
144 |
+
type=int,
|
145 |
+
default=1500,
|
146 |
+
help="the number of points at evaluation time.",
|
147 |
+
)
|
148 |
+
parser.add_argument(
|
149 |
+
"--pyramid_levels", type=int, default=5, help="downsampling pyramid levels."
|
150 |
+
)
|
151 |
+
parser.add_argument(
|
152 |
+
"--upsampled_levels", type=int, default=2, help="upsampling image levels."
|
153 |
+
)
|
154 |
+
parser.add_argument(
|
155 |
+
"--nms_size",
|
156 |
+
type=int,
|
157 |
+
default=15,
|
158 |
+
help="The NMS size for computing the validation repeatability.",
|
159 |
+
)
|
160 |
+
parser.add_argument(
|
161 |
+
"--border_size",
|
162 |
+
type=int,
|
163 |
+
default=15,
|
164 |
+
help="The number of pixels to remove from the borders to compute the repeatability.",
|
165 |
+
)
|
166 |
+
## For HPatches evaluation
|
167 |
+
parser.add_argument(
|
168 |
+
"--hpatches_path",
|
169 |
+
type=str,
|
170 |
+
default="./datasets/hpatches-sequences-release",
|
171 |
+
help="dataset ",
|
172 |
+
)
|
173 |
+
parser.add_argument(
|
174 |
+
"--eval_split",
|
175 |
+
type=str,
|
176 |
+
default="debug",
|
177 |
+
help="debug, view, illum, full, debug_view, debug_illum ...",
|
178 |
+
)
|
179 |
+
parser.add_argument(
|
180 |
+
"--descriptor", type=str, default="hardnet", help="hardnet, sosnet, hynet"
|
181 |
+
)
|
182 |
+
|
183 |
+
args, weird_args = (
|
184 |
+
parser.parse_known_args() if not jupyter else parser.parse_args(args=[])
|
185 |
+
)
|
186 |
+
|
187 |
+
fix_randseed(12345)
|
188 |
+
|
189 |
+
if args.synth_dir == "":
|
190 |
+
args.synth_dir = "datasets/synth_data"
|
191 |
+
|
192 |
+
args.MSIP_sizes = [int(i) for i in args.MSIP_sizes.split(",")]
|
193 |
+
args.MSIP_factor_loss = [float(i) for i in args.MSIP_factor_loss.split(",")]
|
194 |
+
|
195 |
+
return args
|
196 |
+
|
197 |
+
|
198 |
+
def fix_randseed(randseed):
|
199 |
+
r"""Fix random seed"""
|
200 |
+
random.seed(randseed)
|
201 |
+
numpy.random.seed(randseed)
|
202 |
+
torch.manual_seed(randseed)
|
203 |
+
torch.cuda.manual_seed(randseed)
|
204 |
+
torch.cuda.manual_seed_all(randseed)
|
205 |
+
torch.backends.cudnn.benchmark, torch.backends.cudnn.deterministic = False, True
|
206 |
+
# torch.backends.cudnn.benchmark, torch.backends.cudnn.deterministic = True, False
|
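Note that `--MSIP_sizes` and `--MSIP_factor_loss` arrive as comma-separated strings and are only split into lists after parsing. A minimal usage sketch of the resulting config, assuming the enclosing parser function (truncated above) is named `get_config` and takes `jupyter` as a parameter; both names are assumptions, not shown in this hunk:

```python
# Hypothetical usage sketch: `get_config` and its `jupyter` parameter are
# assumed names for the enclosing parser function, whose definition is
# truncated above.
args = get_config(jupyter=False)

# The comma-separated CLI strings become Python lists after parsing.
print(args.MSIP_sizes)        # [8, 16, 24, 32, 40]
print(args.MSIP_factor_loss)  # [256.0, 64.0, 16.0, 4.0, 1.0]
```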
imcui/third_party/dad/dad/detectors/third_party/rekd/geometry_tools.py
ADDED
@@ -0,0 +1,204 @@
from cv2 import warpPerspective as applyH
import numpy as np
import torch


def apply_nms(score_map, size):
    # "scipy.ndimage.filters" was removed in SciPy 1.10; import from scipy.ndimage.
    from scipy.ndimage import maximum_filter

    score_map = score_map * (
        score_map == maximum_filter(score_map, footprint=np.ones((size, size)))
    )
    return score_map


def remove_borders(images, borders):
    ## input [B,C,H,W]
    shape = images.shape

    if len(shape) == 4:
        for batch_id in range(shape[0]):
            images[batch_id, :, 0:borders, :] = 0
            images[batch_id, :, :, 0:borders] = 0
            images[batch_id, :, shape[2] - borders : shape[2], :] = 0
            images[batch_id, :, :, shape[3] - borders : shape[3]] = 0
    elif len(shape) == 2:
        images[0:borders, :] = 0
        images[:, 0:borders] = 0
        images[shape[0] - borders : shape[0], :] = 0
        images[:, shape[1] - borders : shape[1]] = 0
    else:
        print("Not implemented")
        exit()

    return images


def create_common_region_masks(h_dst_2_src, shape_src, shape_dst):
    # Create mask. Only take into account pixels present in both images.
    inv_h = np.linalg.inv(h_dst_2_src)
    inv_h = inv_h / inv_h[2, 2]

    # Applies mask to destination. Where there is no 1, we cannot find a point in source.
    ones_dst = np.ones((shape_dst[0], shape_dst[1]))
    ones_dst = remove_borders(ones_dst, borders=15)
    mask_src = applyH(ones_dst, h_dst_2_src, (shape_src[1], shape_src[0]))
    mask_src = np.where(mask_src >= 0.75, 1.0, 0.0)
    mask_src = remove_borders(mask_src, borders=15)

    ones_src = np.ones((shape_src[0], shape_src[1]))
    ones_src = remove_borders(ones_src, borders=15)
    mask_dst = applyH(ones_src, inv_h, (shape_dst[1], shape_dst[0]))
    mask_dst = np.where(mask_dst >= 0.75, 1.0, 0.0)
    mask_dst = remove_borders(mask_dst, borders=15)

    return mask_src, mask_dst


def prepare_homography(hom):
    if len(hom.shape) == 1:
        h = np.zeros((3, 3))
        for j in range(3):
            for i in range(3):
                if j == 2 and i == 2:
                    h[j, i] = 1.0
                else:
                    h[j, i] = hom[j * 3 + i]
    elif len(hom.shape) == 2:  ## batch
        ones = torch.ones(hom.shape[0]).unsqueeze(1)
        h = torch.cat([hom, ones], dim=1).reshape(-1, 3, 3).type(torch.float32)

    return h


def getAff(x, y, H):
    # Jacobian of the homography H evaluated at (x, y).
    h11 = H[0, 0]
    h12 = H[0, 1]
    h13 = H[0, 2]
    h21 = H[1, 0]
    h22 = H[1, 1]
    h23 = H[1, 2]
    h31 = H[2, 0]
    h32 = H[2, 1]
    h33 = H[2, 2]
    fxdx = (
        h11 / (h31 * x + h32 * y + h33)
        - (h11 * x + h12 * y + h13) * h31 / (h31 * x + h32 * y + h33) ** 2
    )
    fxdy = (
        h12 / (h31 * x + h32 * y + h33)
        - (h11 * x + h12 * y + h13) * h32 / (h31 * x + h32 * y + h33) ** 2
    )

    fydx = (
        h21 / (h31 * x + h32 * y + h33)
        - (h21 * x + h22 * y + h23) * h31 / (h31 * x + h32 * y + h33) ** 2
    )
    fydy = (
        h22 / (h31 * x + h32 * y + h33)
        - (h21 * x + h22 * y + h23) * h32 / (h31 * x + h32 * y + h33) ** 2
    )

    Aff = [[fxdx, fxdy], [fydx, fydy]]

    return np.asarray(Aff)


def apply_homography_to_points(points, h):
    new_points = []

    for point in points:
        new_point = h.dot([point[0], point[1], 1.0])

        tmp = point[2] ** 2 + np.finfo(np.float32).eps

        Mi1 = [[1 / tmp, 0], [0, 1 / tmp]]
        Mi1_inv = np.linalg.inv(Mi1)
        Aff = getAff(point[0], point[1], h)

        BMB = np.linalg.inv(np.dot(Aff, np.dot(Mi1_inv, np.matrix.transpose(Aff))))

        [e, _] = np.linalg.eig(BMB)
        new_radious = 1 / ((e[0] * e[1]) ** 0.5) ** 0.5

        new_point = [
            new_point[0] / new_point[2],
            new_point[1] / new_point[2],
            new_radious,
            point[3],
        ]
        new_points.append(new_point)

    return np.asarray(new_points)


def find_index_higher_scores(map, num_points=1000, threshold=-1):
    # Best n points
    if threshold == -1:
        flatten = map.flatten()
        order_array = np.sort(flatten)

        order_array = np.flip(order_array, axis=0)

        if order_array.shape[0] < num_points:
            num_points = order_array.shape[0]

        threshold = order_array[num_points - 1]

        if threshold <= 0.0:
            ### Problem case: this yields fewer keypoints than the requested "num_points".
            indexes = np.argwhere(order_array > 0.0)

            if len(indexes) == 0:
                threshold = 0.0
            else:
                threshold = order_array[indexes[len(indexes) - 1]]

    indexes = np.argwhere(map >= threshold)

    return indexes[:num_points]


def get_point_coordinates(
    map, scale_value=1.0, num_points=1000, threshold=-1, order_coord="xysr"
):
    ## input numpy array score map : [H, W]
    indexes = find_index_higher_scores(map, num_points=num_points, threshold=threshold)
    new_indexes = []
    for ind in indexes:
        scores = map[ind[0], ind[1]]
        if order_coord == "xysr":
            tmp = [ind[1], ind[0], scale_value, scores]
        elif order_coord == "yxsr":
            tmp = [ind[0], ind[1], scale_value, scores]

        new_indexes.append(tmp)

    indexes = np.asarray(new_indexes)

    return np.asarray(indexes)


def get_point_coordinates3D(
    map,
    scale_factor=1.0,
    up_levels=0,
    num_points=1000,
    threshold=-1,
    order_coord="xysr",
):
    indexes = find_index_higher_scores(map, num_points=num_points, threshold=threshold)
    new_indexes = []
    for ind in indexes:
        scale_value = scale_factor ** (ind[2] - up_levels)
        scores = map[ind[0], ind[1], ind[2]]
        if order_coord == "xysr":
            tmp = [ind[1], ind[0], scale_value, scores]
        elif order_coord == "yxsr":
            tmp = [ind[0], ind[1], scale_value, scores]

        new_indexes.append(tmp)

    indexes = np.asarray(new_indexes)

    return np.asarray(indexes)
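Taken together, `apply_nms` and `get_point_coordinates` turn a dense score map into a sparse keypoint list. A minimal, self-contained sketch, assuming the module above is importable as `geometry_tools`:

```python
import numpy as np
from geometry_tools import apply_nms, get_point_coordinates  # assumed import path

# Random dense score map standing in for a detector output.
score_map = np.random.rand(480, 640).astype(np.float32)

# Suppress everything except local maxima in a 15x15 neighbourhood.
nms_map = apply_nms(score_map, size=15)

# Keep the 100 highest-scoring survivors as (x, y, scale, score) rows.
points = get_point_coordinates(nms_map, num_points=100)
print(points.shape)  # (100, 4)
```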
imcui/third_party/dad/dad/detectors/third_party/rekd/model/REKD.py
ADDED
@@ -0,0 +1,234 @@
import torch
import torch.nn.functional as F

from .kernels import gaussian_multiple_channels


class REKD(torch.nn.Module):
    def __init__(self, args, device):
        super(REKD, self).__init__()
        from e2cnn import gspaces
        from e2cnn import nn

        self.pyramid_levels = 3
        self.factor_scaling = args.factor_scaling_pyramid

        # Smoothing Gaussian filter
        num_channels = 1  ## grayscale image
        self.gaussian_avg = gaussian_multiple_channels(num_channels, 1.5)

        r2_act = gspaces.Rot2dOnR2(N=args.group_size)

        self.feat_type_in = nn.FieldType(
            r2_act, num_channels * [r2_act.trivial_repr]
        )  ## input 1 channel (grayscale image)

        feat_type_out1 = nn.FieldType(r2_act, args.dim_first * [r2_act.regular_repr])
        feat_type_out2 = nn.FieldType(r2_act, args.dim_second * [r2_act.regular_repr])
        feat_type_out3 = nn.FieldType(r2_act, args.dim_third * [r2_act.regular_repr])

        feat_type_ori_est = nn.FieldType(r2_act, [r2_act.regular_repr])

        self.block1 = nn.SequentialModule(
            nn.R2Conv(
                self.feat_type_in, feat_type_out1, kernel_size=5, padding=2, bias=False
            ),
            nn.InnerBatchNorm(feat_type_out1),
            nn.ReLU(feat_type_out1, inplace=True),
        )
        self.block2 = nn.SequentialModule(
            nn.R2Conv(
                feat_type_out1, feat_type_out2, kernel_size=5, padding=2, bias=False
            ),
            nn.InnerBatchNorm(feat_type_out2),
            nn.ReLU(feat_type_out2, inplace=True),
        )
        self.block3 = nn.SequentialModule(
            nn.R2Conv(
                feat_type_out2, feat_type_out3, kernel_size=5, padding=2, bias=False
            ),
            nn.InnerBatchNorm(feat_type_out3),
            nn.ReLU(feat_type_out3, inplace=True),
        )

        self.ori_learner = nn.SequentialModule(
            nn.R2Conv(
                feat_type_out3, feat_type_ori_est, kernel_size=1, padding=0, bias=False
            )  ## Channel pooling by 8*G -> 1*G conv.
        )
        self.softmax = torch.nn.Softmax(dim=1)

        self.gpool = nn.GroupPooling(feat_type_out3)
        self.last_layer_learner = torch.nn.Sequential(
            torch.nn.BatchNorm2d(num_features=args.dim_third * self.pyramid_levels),
            torch.nn.Conv2d(
                in_channels=args.dim_third * self.pyramid_levels,
                out_channels=1,
                kernel_size=1,
                bias=True,
            ),
            torch.nn.ReLU(inplace=True),  ## clamp to keep the scores positive.
        )

        self.dim_third = args.dim_third
        self.group_size = args.group_size
        self.exported = False

    def export(self):
        from e2cnn import nn

        for name, module in dict(self.named_modules()).copy().items():
            if isinstance(module, nn.EquivariantModule):
                # print(name, "--->", module)
                module = module.export()
                setattr(self, name, module)

        self.exported = True

    def forward(self, input_data):
        features_key, features_o = self.compute_features(input_data)

        return features_key, features_o

    def compute_features(self, input_data):
        B, _, H, W = input_data.shape

        for idx_level in range(self.pyramid_levels):
            with torch.no_grad():
                input_data_resized = self._resize_input_image(
                    input_data, idx_level, H, W
                )

            if H > 2500 or W > 2500:
                features_t, features_o = self._forwarding_networks_divide_grid(
                    input_data_resized
                )
            else:
                features_t, features_o = self._forwarding_networks(input_data_resized)

            features_t = F.interpolate(
                features_t, size=(H, W), align_corners=True, mode="bilinear"
            )
            features_o = F.interpolate(
                features_o, size=(H, W), align_corners=True, mode="bilinear"
            )

            if idx_level == 0:
                features_key = features_t
                features_ori = features_o
            else:
                features_key = torch.cat([features_key, features_t], dim=1)
                features_ori = torch.add(features_ori, features_o)

        features_key = self.last_layer_learner(features_key)
        features_ori = self.softmax(features_ori)

        return features_key, features_ori

    def _forwarding_networks(self, input_data_resized):
        from e2cnn import nn

        # wrap the input tensor in a GeometricTensor (associate it with the input type)
        features_t = (
            nn.GeometricTensor(input_data_resized, self.feat_type_in)
            if not self.exported
            else input_data_resized
        )

        ## Geometric tensor feed forwarding
        features_t = self.block1(features_t)
        features_t = self.block2(features_t)
        features_t = self.block3(features_t)

        ## orientation pooling
        features_o = self.ori_learner(features_t)  ## self.cpool
        features_o = features_o.tensor if not self.exported else features_o

        ## keypoint pooling
        features_t = self.gpool(features_t)
        features_t = features_t.tensor if not self.exported else features_t

        return features_t, features_o

    def _forwarding_networks_divide_grid(self, input_data_resized):
        ## For high-resolution images at inference time: process a 2x2 spatial grid.
        B, _, H_resized, W_resized = input_data_resized.shape
        # Allocate on the input's device instead of hardcoding .cuda().
        device = input_data_resized.device
        features_t = torch.zeros(B, self.dim_third, H_resized, W_resized, device=device)
        features_o = torch.zeros(B, self.group_size, H_resized, W_resized, device=device)
        h_divide = 2
        w_divide = 2
        for idx in range(h_divide):
            for jdx in range(w_divide):
                ## compute the start and end spatial index
                h_start = H_resized // h_divide * idx
                w_start = W_resized // w_divide * jdx
                h_end = H_resized // h_divide * (idx + 1)
                w_end = W_resized // w_divide * (jdx + 1)
                ## crop the input image
                input_data_divided = input_data_resized[
                    :, :, h_start:h_end, w_start:w_end
                ]
                features_t_temp, features_o_temp = self._forwarding_networks(
                    input_data_divided
                )
                ## write the tile results back into place.
                features_t[:, :, h_start:h_end, w_start:w_end] = features_t_temp
                features_o[:, :, h_start:h_end, w_start:w_end] = features_o_temp

        return features_t, features_o

    def _resize_input_image(self, input_data, idx_level, H, W):
        if idx_level == 0:
            input_data_smooth = input_data
        else:
            ## (7,7) size gaussian kernel.
            input_data_smooth = F.conv2d(
                input_data, self.gaussian_avg.to(input_data.device), padding=[3, 3]
            )

        target_resize = (
            int(H / (self.factor_scaling**idx_level)),
            int(W / (self.factor_scaling**idx_level)),
        )

        input_data_resized = F.interpolate(
            input_data_smooth, size=target_resize, align_corners=True, mode="bilinear"
        )

        input_data_resized = self.local_norm_image(input_data_resized)

        return input_data_resized

    def local_norm_image(self, x, k_size=65, eps=1e-10):
        pad = int(k_size / 2)

        x_pad = F.pad(x, (pad, pad, pad, pad), mode="reflect")
        x_mean = F.avg_pool2d(
            x_pad, kernel_size=[k_size, k_size], stride=[1, 1], padding=0
        )  ## padding='valid'==0
        x2_mean = F.avg_pool2d(
            torch.pow(x_pad, 2.0),
            kernel_size=[k_size, k_size],
            stride=[1, 1],
            padding=0,
        )

        x_std = torch.sqrt(torch.abs(x2_mean - x_mean * x_mean)) + eps
        x_norm = (x - x_mean) / (1.0 + x_std)

        return x_norm


def count_model_parameters(model):
    ## Count the number of learnable parameters.
    print("================ List of Learnable model parameters ================ ")
    for n, p in model.named_parameters():
        if p.requires_grad:
            print("{} {}".format(n, p.data.shape))
        else:
            print("\n\n\n Non-learnable params {} {}".format(n, p.data.shape))
    model_parameters = filter(lambda p: p.requires_grad, model.parameters())
    params = sum([torch.prod(torch.tensor(p.size())) for p in model_parameters])
    print("The number of learnable parameters : {} ".format(params.data))
    print("==================================================================== ")
imcui/third_party/dad/dad/detectors/third_party/rekd/model/kernels.py
ADDED
@@ -0,0 +1,118 @@
import math
import torch


def gaussian_multiple_channels(num_channels, sigma):
    r = 2 * sigma
    size = 2 * r + 1
    size = int(math.ceil(size))
    x = torch.arange(0, size, 1, dtype=torch.float)
    y = x.unsqueeze(1)
    x0 = y0 = r

    gaussian = torch.exp(-1 * (((x - x0) ** 2 + (y - y0) ** 2) / (2 * (sigma**2)))) / (
        (2 * math.pi * (sigma**2)) ** 0.5
    )
    gaussian = gaussian.to(dtype=torch.float32)

    weights = torch.zeros((num_channels, num_channels, size, size), dtype=torch.float32)
    for i in range(num_channels):
        weights[i, i, :, :] = gaussian

    return weights


def ones_multiple_channels(size, num_channels):
    ones = torch.ones((size, size))
    weights = torch.zeros((num_channels, num_channels, size, size), dtype=torch.float32)

    for i in range(num_channels):
        weights[i, i, :, :] = ones

    return weights


def grid_indexes(size):
    weights = torch.zeros((2, 1, size, size), dtype=torch.float32)

    columns = []
    for idx in range(1, 1 + size):
        columns.append(torch.ones((size)) * idx)
    columns = torch.stack(columns)

    rows = []
    for idx in range(1, 1 + size):
        rows.append(torch.tensor(range(1, 1 + size)))
    rows = torch.stack(rows)

    weights[0, 0, :, :] = columns
    weights[1, 0, :, :] = rows

    return weights


def get_kernel_size(factor):
    """
    Find the kernel size given the desired factor of upsampling.
    """
    return 2 * factor - factor % 2


def linear_upsample_weights(half_factor, number_of_classes):
    """
    Create weights matrix for transposed convolution with linear filter
    initialization.
    """

    filter_size = get_kernel_size(half_factor)

    weights = torch.zeros(
        (
            number_of_classes,
            number_of_classes,
            filter_size,
            filter_size,
        ),
        dtype=torch.float32,
    )

    upsample_kernel = torch.ones((filter_size, filter_size))
    for i in range(number_of_classes):
        weights[i, i, :, :] = upsample_kernel

    return weights


class Kernels_custom:
    def __init__(self, args, MSIP_sizes=[]):
        self.batch_size = args.batch_size
        # create_kernels
        self.kernels = {}

        if MSIP_sizes != []:
            self.create_kernels(MSIP_sizes)

        if 8 not in MSIP_sizes:
            self.create_kernels([8])

    def create_kernels(self, MSIP_sizes):
        # Grid Indexes for MSIP
        for ksize in MSIP_sizes:
            ones_kernel = ones_multiple_channels(ksize, 1)
            indexes_kernel = grid_indexes(ksize)
            upsample_filter_np = linear_upsample_weights(int(ksize / 2), 1)

            self.ones_kernel = ones_kernel.requires_grad_(False)
            self.kernels["ones_kernel_" + str(ksize)] = self.ones_kernel

            self.upsample_filter_np = upsample_filter_np.requires_grad_(False)
            self.kernels["upsample_filter_np_" + str(ksize)] = self.upsample_filter_np

            self.indexes_kernel = indexes_kernel.requires_grad_(False)
            self.kernels["indexes_kernel_" + str(ksize)] = self.indexes_kernel

    def get_kernels(self, device):
        kernels = {}
        for k, v in self.kernels.items():
            kernels[k] = v.to(device)
        return kernels
imcui/third_party/dad/dad/detectors/third_party/rekd/model/load_models.py
ADDED
@@ -0,0 +1,25 @@
import torch
from .REKD import REKD


def load_detector(args, device):
    args.group_size, args.dim_first, args.dim_second, args.dim_third = model_parsing(
        args
    )
    model1 = REKD(args, device)
    model1.load_state_dict(torch.load(args.load_dir, weights_only=True))
    model1.export()
    model1.eval()
    model1.to(device)  ## use GPU

    return model1


## Parse the model hyper-parameters from the checkpoint filename.
def model_parsing(args):
    group_size = args.load_dir.split("_group")[1].split("_")[0]
    dim_first = args.load_dir.split("_f")[1].split("_")[0]
    dim_second = args.load_dir.split("_s")[1].split("_")[0]
    dim_third = args.load_dir.split("_t")[1].split(".log")[0]

    return int(group_size), int(dim_first), int(dim_second), int(dim_third)
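`model_parsing` recovers the group size and the three channel widths from the checkpoint filename itself, so the weights file must follow a `*_group<G>_f<F>_s<S>_t<T>.log` naming scheme. A sketch with a hypothetical filename following that convention:

```python
from types import SimpleNamespace
from load_models import model_parsing  # assumed import path

# Hypothetical checkpoint name encoding group size and the three widths.
args = SimpleNamespace(load_dir="trained_models/rekd_group36_f2_s2_t2.log")

print(model_parsing(args))  # (36, 2, 2, 2)
```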