Realcat committed · Commit 89c9b15 · 1 Parent(s): bd20887

add: dad detector with roma matcher

This view is limited to 50 files because it contains too many changes. See the raw diff for the full change set.
Files changed (50)
  1. README.md +1 -0
  2. config/config.yaml +23 -0
  3. imcui/hloc/match_dense.py +48 -5
  4. imcui/hloc/matchers/dad_roma.py +121 -0
  5. imcui/hloc/matchers/roma.py +10 -4
  6. imcui/hloc/matchers/xfeat_dense.py +4 -2
  7. imcui/hloc/matchers/xfeat_lightglue.py +4 -2
  8. imcui/third_party/RoMa/.gitignore +11 -0
  9. imcui/third_party/RoMa/LICENSE +21 -0
  10. imcui/third_party/RoMa/README.md +123 -0
  11. imcui/third_party/RoMa/data/.gitignore +2 -0
  12. imcui/third_party/RoMa/requirements.txt +14 -0
  13. imcui/third_party/RoMa/romatch/models/matcher.py +68 -32
  14. imcui/third_party/RoMa/romatch/models/transformer/layers/attention.py +1 -1
  15. imcui/third_party/RoMa/romatch/models/transformer/layers/block.py +1 -1
  16. imcui/third_party/RoMa/romatch/utils/utils.py +9 -1
  17. imcui/third_party/RoMa/setup.py +1 -1
  18. imcui/third_party/dad/.gitignore +170 -0
  19. imcui/third_party/dad/.python-version +1 -0
  20. imcui/third_party/dad/LICENSE +21 -0
  21. imcui/third_party/dad/README.md +130 -0
  22. imcui/third_party/dad/dad/__init__.py +17 -0
  23. imcui/third_party/dad/dad/augs.py +214 -0
  24. imcui/third_party/dad/dad/benchmarks/__init__.py +21 -0
  25. imcui/third_party/dad/dad/benchmarks/hpatches.py +117 -0
  26. imcui/third_party/dad/dad/benchmarks/megadepth.py +219 -0
  27. imcui/third_party/dad/dad/benchmarks/num_inliers.py +106 -0
  28. imcui/third_party/dad/dad/benchmarks/scannet.py +163 -0
  29. imcui/third_party/dad/dad/checkpoint.py +61 -0
  30. imcui/third_party/dad/dad/datasets/__init__.py +0 -0
  31. imcui/third_party/dad/dad/datasets/megadepth.py +312 -0
  32. imcui/third_party/dad/dad/detectors/__init__.py +50 -0
  33. imcui/third_party/dad/dad/detectors/dedode_detector.py +559 -0
  34. imcui/third_party/dad/dad/detectors/third_party/__init__.py +11 -0
  35. imcui/third_party/dad/dad/detectors/third_party/harrisaff.py +35 -0
  36. imcui/third_party/dad/dad/detectors/third_party/hesaff.py +40 -0
  37. imcui/third_party/dad/dad/detectors/third_party/lightglue/__init__.py +9 -0
  38. imcui/third_party/dad/dad/detectors/third_party/lightglue/aliked.py +770 -0
  39. imcui/third_party/dad/dad/detectors/third_party/lightglue/disk.py +48 -0
  40. imcui/third_party/dad/dad/detectors/third_party/lightglue/dog_hardnet.py +41 -0
  41. imcui/third_party/dad/dad/detectors/third_party/lightglue/lightglue.py +655 -0
  42. imcui/third_party/dad/dad/detectors/third_party/lightglue/sift.py +216 -0
  43. imcui/third_party/dad/dad/detectors/third_party/lightglue/superpoint.py +233 -0
  44. imcui/third_party/dad/dad/detectors/third_party/lightglue/utils.py +158 -0
  45. imcui/third_party/dad/dad/detectors/third_party/lightglue_detector.py +42 -0
  46. imcui/third_party/dad/dad/detectors/third_party/rekd/config.py +206 -0
  47. imcui/third_party/dad/dad/detectors/third_party/rekd/geometry_tools.py +204 -0
  48. imcui/third_party/dad/dad/detectors/third_party/rekd/model/REKD.py +234 -0
  49. imcui/third_party/dad/dad/detectors/third_party/rekd/model/kernels.py +118 -0
  50. imcui/third_party/dad/dad/detectors/third_party/rekd/model/load_models.py +25 -0
README.md CHANGED
@@ -44,6 +44,7 @@ The tool currently supports various popular image matching algorithms, namely:
 
  | Algorithm | Supported | Conference/Journal | Year | GitHub Link |
  |------------------|-----------|--------------------|------|-------------|
+ | DaD | ✅ | ARXIV | 2025 | [Link](https://github.com/Parskatt/dad) |
  | MINIMA | ✅ | ARXIV | 2024 | [Link](https://github.com/LSXI7/MINIMA) |
  | XoFTR | ✅ | CVPR | 2024 | [Link](https://github.com/OnderT/XoFTR) |
  | EfficientLoFTR | ✅ | CVPR | 2024 | [Link](https://github.com/zju3dv/EfficientLoFTR) |
config/config.yaml CHANGED
@@ -43,6 +43,17 @@ matcher_zoo:
     # low, medium, high
     efficiency: low
+  dad(RoMa):
+    matcher: dad_roma
+    skip_ci: true
+    dense: true
+    info:
+      name: Dad(RoMa) # display name
+      source: "ARXIV 2025"
+      github: https://github.com/example/example
+      paper: https://arxiv.org/abs/2503.07347
+      display: true
+    efficiency: low # low, medium, high
   minima(loftr):
     matcher: minima_loftr
     dense: true
@@ -50,6 +61,7 @@ matcher_zoo:
       name: MINIMA(LoFTR) # display name
       source: "ARXIV 2024"
       paper: https://arxiv.org/abs/2412.19412
+      github: https://github.com/LSXI7/MINIMA
       display: true
   minima(RoMa):
     matcher: minima_roma
@@ -59,6 +71,7 @@ matcher_zoo:
       name: MINIMA(RoMa) # display name
       source: "ARXIV 2024"
       paper: https://arxiv.org/abs/2412.19412
+      github: https://github.com/LSXI7/MINIMA
       display: false
     efficiency: low # low, medium, high
   omniglue:
@@ -164,6 +177,16 @@ matcher_zoo:
       paper: https://arxiv.org/pdf/2404.09692
       project: null
       display: true
+  jamma:
+    matcher: jamma
+    dense: true
+    info:
+      name: Jamma # display name
+      source: "CVPR 2024"
+      github: https://github.com/OnderT/XoFTR
+      paper: https://arxiv.org/pdf/2404.09692
+      project: null
+      display: false
   cotr:
     enable: false
     skip_ci: true
imcui/hloc/match_dense.py CHANGED
@@ -102,6 +102,23 @@ confs = {
         "max_error": 1,  # max error for assigned keypoints (in px)
         "cell_size": 1,  # size of quantization patch (max 1 kp/patch)
     },
+    "jamma": {
+        "output": "matches-jamma",
+        "model": {
+            "name": "jamma",
+            "weights": "jamma_weight.ckpt",
+            "max_keypoints": 2000,
+            "match_threshold": 0.3,
+        },
+        "preprocessing": {
+            "grayscale": True,
+            "resize_max": 1024,
+            "dfactor": 16,
+            "width": 832,
+            "height": 832,
+            "force_resize": True,
+        },
+    },
     # "loftr_quadtree": {
     #     "output": "matches-loftr-quadtree",
     #     "model": {
@@ -295,7 +312,25 @@ confs = {
         },
         "preprocessing": {
             "grayscale": False,
-            "force_resize": True,
+            "force_resize": False,
+            "resize_max": 1024,
+            "width": 320,
+            "height": 240,
+            "dfactor": 8,
+        },
+    },
+    "dad_roma": {
+        "output": "matches-dad_roma",
+        "model": {
+            "name": "dad_roma",
+            "weights": "outdoor",
+            "model_name": "roma_outdoor.pth",
+            "max_keypoints": 2000,
+            "match_threshold": 0.2,
+        },
+        "preprocessing": {
+            "grayscale": False,
+            "force_resize": False,
             "resize_max": 1024,
             "width": 320,
             "height": 240,
@@ -1010,9 +1045,17 @@ def match_images(model, image_0, image_1, conf, device="cpu"):
     # Rescale keypoints and move to cpu
     if "keypoints0" in pred.keys() and "keypoints1" in pred.keys():
         kpts0, kpts1 = pred["keypoints0"], pred["keypoints1"]
+        mkpts0, mkpts1 = pred.get("mkeypoints0"), pred.get("mkeypoints1")
+        if mkpts0 is None or mkpts1 is None:
+            mkpts0 = kpts0
+            mkpts1 = kpts1
+
         kpts0_origin = scale_keypoints(kpts0 + 0.5, s0) - 0.5
         kpts1_origin = scale_keypoints(kpts1 + 0.5, s1) - 0.5
 
+        mkpts0_origin = scale_keypoints(mkpts0 + 0.5, s0) - 0.5
+        mkpts1_origin = scale_keypoints(mkpts1 + 0.5, s1) - 0.5
+
         ret = {
             "image0": image0.squeeze().cpu().numpy(),
             "image1": image1.squeeze().cpu().numpy(),
@@ -1022,10 +1065,10 @@ def match_images(model, image_0, image_1, conf, device="cpu"):
             "keypoints1": kpts1.cpu().numpy(),
             "keypoints0_orig": kpts0_origin.cpu().numpy(),
             "keypoints1_orig": kpts1_origin.cpu().numpy(),
-            "mkeypoints0": kpts0.cpu().numpy(),
-            "mkeypoints1": kpts1.cpu().numpy(),
-            "mkeypoints0_orig": kpts0_origin.cpu().numpy(),
-            "mkeypoints1_orig": kpts1_origin.cpu().numpy(),
+            "mkeypoints0": mkpts0.cpu().numpy(),
+            "mkeypoints1": mkpts1.cpu().numpy(),
+            "mkeypoints0_orig": mkpts0_origin.cpu().numpy(),
+            "mkeypoints1_orig": mkpts1_origin.cpu().numpy(),
             "original_size0": np.array(image_0.shape[:2][::-1]),
             "original_size1": np.array(image_1.shape[:2][::-1]),
             "new_size0": np.array(image0.shape[-2:][::-1]),
imcui/hloc/matchers/dad_roma.py ADDED
@@ -0,0 +1,121 @@
+ import sys
+ from pathlib import Path
+ import tempfile
+ import torch
+ from PIL import Image
+
+ from .. import MODEL_REPO_ID, logger
+ from ..utils.base_model import BaseModel
+
+ roma_path = Path(__file__).parent / "../../third_party/RoMa"
+ sys.path.append(str(roma_path))
+ from romatch.models.model_zoo import roma_model
+
+ dad_path = Path(__file__).parent / "../../third_party/dad"
+ sys.path.append(str(dad_path))
+ import dad as dad_detector
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+
+ class Dad(BaseModel):
+     default_conf = {
+         "name": "two_view_pipeline",
+         "model_name": "roma_outdoor.pth",
+         "model_utils_name": "dinov2_vitl14_pretrain.pth",
+         "max_keypoints": 3000,
+         "coarse_res": (560, 560),
+         "upsample_res": (864, 1152),
+     }
+     required_inputs = [
+         "image0",
+         "image1",
+     ]
+
+     # Initialize the DaD detector and the RoMa matcher
+     def _init(self, conf):
+         model_path = self._download_model(
+             repo_id=MODEL_REPO_ID,
+             filename="{}/{}".format("roma", self.conf["model_name"]),
+         )
+
+         dinov2_weights = self._download_model(
+             repo_id=MODEL_REPO_ID,
+             filename="{}/{}".format("roma", self.conf["model_utils_name"]),
+         )
+
+         logger.info("Loading Dad + Roma model")
+         # load the model
+         weights = torch.load(model_path, map_location="cpu")
+         dinov2_weights = torch.load(dinov2_weights, map_location="cpu")
+
+         if str(device) == "cpu":
+             amp_dtype = torch.float32
+         else:
+             amp_dtype = torch.float16
+
+         self.matcher = roma_model(
+             resolution=self.conf["coarse_res"],
+             upsample_preds=True,
+             weights=weights,
+             dinov2_weights=dinov2_weights,
+             device=device,
+             amp_dtype=amp_dtype,
+         )
+         self.matcher.upsample_res = self.conf["upsample_res"]
+         self.matcher.symmetric = False
+
+         self.detector = dad_detector.load_DaD()
+         logger.info("Load Dad + Roma model done.")
+
+     def _forward(self, data):
+         img0 = data["image0"].cpu().numpy().squeeze() * 255
+         img1 = data["image1"].cpu().numpy().squeeze() * 255
+         img0 = img0.transpose(1, 2, 0)
+         img1 = img1.transpose(1, 2, 0)
+         img0 = Image.fromarray(img0.astype("uint8"))
+         img1 = Image.fromarray(img1.astype("uint8"))
+         W_A, H_A = img0.size
+         W_B, H_B = img1.size
+
+         # hack: write the tensors to temporary files so both models can read from disk
+         with (
+             tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp_img0,
+             tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp_img1,
+         ):
+             img0_path = temp_img0.name
+             img1_path = temp_img1.name
+             img0.save(img0_path)
+             img1.save(img1_path)
+
+         # Match
+         warp, certainty = self.matcher.match(img0_path, img1_path, device=device)
+         # Detect
+         keypoints_A = self.detector.detect_from_path(
+             img0_path,
+             num_keypoints=self.conf["max_keypoints"],
+         )["keypoints"][0]
+         keypoints_B = self.detector.detect_from_path(
+             img1_path,
+             num_keypoints=self.conf["max_keypoints"],
+         )["keypoints"][0]
+         matches = self.matcher.match_keypoints(
+             keypoints_A,
+             keypoints_B,
+             warp,
+             certainty,
+             return_tuple=False,
+         )
+
+         # Sample matches for estimation
+         kpts1, kpts2 = self.matcher.to_pixel_coordinates(matches, H_A, W_A, H_B, W_B)
+         offset = self.detector.topleft - 0
+         kpts1, kpts2 = kpts1 - offset, kpts2 - offset
+         pred = {
+             "keypoints0": self.matcher._to_pixel_coordinates(keypoints_A, H_A, W_A),
+             "keypoints1": self.matcher._to_pixel_coordinates(keypoints_B, H_B, W_B),
+             "mkeypoints0": kpts1,
+             "mkeypoints1": kpts2,
+             "mconf": torch.ones_like(kpts1[:, 0]),
+         }
+         return pred
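For orientation, a hedged sketch of the detect-then-match flow this file wraps: DaD proposes keypoints per image, RoMa predicts a dense warp with certainty, and `match_keypoints` snaps the two keypoint sets onto that warp. Image paths are placeholders; the `dad` and `romatch` calls follow the usage shown in the file above and in the bundled READMEs.

```python
import torch
import dad as dad_detector
from romatch import roma_outdoor

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
detector = dad_detector.load_DaD()
matcher = roma_outdoor(device=device)
matcher.symmetric = False  # match_keypoints expects the one-directional warp

img0_path, img1_path = "im_A.png", "im_B.png"  # placeholder paths

# Dense warp + certainty from RoMa, sparse keypoints from DaD.
warp, certainty = matcher.match(img0_path, img1_path, device=device)
kps_A = detector.detect_from_path(img0_path, num_keypoints=2000)["keypoints"][0]
kps_B = detector.detect_from_path(img1_path, num_keypoints=2000)["keypoints"][0]

# Mutual-nearest assignment of the DaD keypoints under the RoMa warp.
matches = matcher.match_keypoints(kps_A, kps_B, warp, certainty, return_tuple=False)
```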
imcui/hloc/matchers/roma.py CHANGED
@@ -20,6 +20,8 @@ class Roma(BaseModel):
         "model_name": "roma_outdoor.pth",
         "model_utils_name": "dinov2_vitl14_pretrain.pth",
         "max_keypoints": 3000,
+        "coarse_res": (560, 560),
+        "upsample_res": (864, 1152),
     }
     required_inputs = [
         "image0",
@@ -43,15 +45,19 @@
         weights = torch.load(model_path, map_location="cpu")
         dinov2_weights = torch.load(dinov2_weights, map_location="cpu")
 
+        if str(device) == "cpu":
+            amp_dtype = torch.float32
+        else:
+            amp_dtype = torch.float16
         self.net = roma_model(
-            resolution=(14 * 8 * 6, 14 * 8 * 6),
-            upsample_preds=False,
+            resolution=self.conf["coarse_res"],
+            upsample_preds=True,
             weights=weights,
             dinov2_weights=dinov2_weights,
             device=device,
-            # temp fix issue: https://github.com/Parskatt/RoMa/issues/26
-            amp_dtype=torch.float32,
+            amp_dtype=amp_dtype,
         )
+        self.net.upsample_res = self.conf["upsample_res"]
         logger.info("Load Roma model done.")
 
     def _forward(self, data):
imcui/hloc/matchers/xfeat_dense.py CHANGED
@@ -47,8 +47,10 @@ class XFeatDense(BaseModel):
         # we use results from one batch
         matches = matches[0]
         pred = {
-            "keypoints0": matches[:, :2],
-            "keypoints1": matches[:, 2:],
+            "keypoints0": out0["keypoints"].squeeze(),
+            "keypoints1": out1["keypoints"].squeeze(),
+            "mkeypoints0": matches[:, :2],
+            "mkeypoints1": matches[:, 2:],
             "mconf": torch.ones_like(matches[:, 0]),
         }
         return pred
imcui/hloc/matchers/xfeat_lightglue.py CHANGED
@@ -41,8 +41,10 @@ class XFeatLightGlue(BaseModel):
         mkpts_0 = torch.from_numpy(mkpts_0)  # n x 2
         mkpts_1 = torch.from_numpy(mkpts_1)  # n x 2
         pred = {
-            "keypoints0": mkpts_0,
-            "keypoints1": mkpts_1,
+            "keypoints0": out0["keypoints"].squeeze(),
+            "keypoints1": out1["keypoints"].squeeze(),
+            "mkeypoints0": mkpts_0,
+            "mkeypoints1": mkpts_1,
             "mconf": torch.ones_like(mkpts_0[:, 0]),
         }
         return pred
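Both XFeat wrappers (and the new `dad_roma` matcher) now distinguish detected keypoints from matched ones. A purely illustrative sketch of the prediction dictionary the updated `match_dense.py` expects matchers to return:

```python
import torch

def example_pred(num_detections: int = 500, num_matches: int = 200) -> dict:
    """Illustrative shapes only; real matchers fill these from their outputs."""
    return {
        "keypoints0": torch.rand(num_detections, 2),  # all detections in image 0
        "keypoints1": torch.rand(num_detections, 2),  # all detections in image 1
        "mkeypoints0": torch.rand(num_matches, 2),    # matched keypoints, image 0
        "mkeypoints1": torch.rand(num_matches, 2),    # matched keypoints, image 1
        "mconf": torch.ones(num_matches),             # per-match confidence
    }
```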
imcui/third_party/RoMa/.gitignore ADDED
@@ -0,0 +1,11 @@
+ *.egg-info*
+ *.vscode*
+ *__pycache__*
+ vis*
+ workspace*
+ .venv
+ .DS_Store
+ jobs/*
+ *ignore_me*
+ *.pth
+ wandb*
imcui/third_party/RoMa/LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2023 Johan Edstedt
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
imcui/third_party/RoMa/README.md ADDED
@@ -0,0 +1,123 @@
1
+ #
2
+ <p align="center">
3
+ <h1 align="center"> <ins>RoMa</ins> 🏛️:<br> Robust Dense Feature Matching <br> ⭐CVPR 2024⭐</h1>
4
+ <p align="center">
5
+ <a href="https://scholar.google.com/citations?user=Ul-vMR0AAAAJ">Johan Edstedt</a>
6
+ ·
7
+ <a href="https://scholar.google.com/citations?user=HS2WuHkAAAAJ">Qiyu Sun</a>
8
+ ·
9
+ <a href="https://scholar.google.com/citations?user=FUE3Wd0AAAAJ">Georg Bökman</a>
10
+ ·
11
+ <a href="https://scholar.google.com/citations?user=6WRQpCQAAAAJ">Mårten Wadenbäck</a>
12
+ ·
13
+ <a href="https://scholar.google.com/citations?user=lkWfR08AAAAJ">Michael Felsberg</a>
14
+ </p>
15
+ <h2 align="center"><p>
16
+ <a href="https://arxiv.org/abs/2305.15404" align="center">Paper</a> |
17
+ <a href="https://parskatt.github.io/RoMa" align="center">Project Page</a>
18
+ </p></h2>
19
+ <div align="center"></div>
20
+ </p>
21
+ <br/>
22
+ <p align="center">
23
+ <img src="https://github.com/Parskatt/RoMa/assets/22053118/15d8fea7-aa6d-479f-8a93-350d950d006b" alt="example" width=80%>
24
+ <br>
25
+ <em>RoMa is the robust dense feature matcher capable of estimating pixel-dense warps and reliable certainties for almost any image pair.</em>
26
+ </p>
27
+
28
+ ## Setup/Install
29
+ In your python environment (tested on Linux python 3.10), run:
30
+ ```bash
31
+ pip install -e .
32
+ ```
33
+ ## Demo / How to Use
34
+ We provide two demos in the [demos folder](demo).
35
+ Here's the gist of it:
36
+ ```python
37
+ from romatch import roma_outdoor
38
+ roma_model = roma_outdoor(device=device)
39
+ # Match
40
+ warp, certainty = roma_model.match(imA_path, imB_path, device=device)
41
+ # Sample matches for estimation
42
+ matches, certainty = roma_model.sample(warp, certainty)
43
+ # Convert to pixel coordinates (RoMa produces matches in [-1,1]x[-1,1])
44
+ kptsA, kptsB = roma_model.to_pixel_coordinates(matches, H_A, W_A, H_B, W_B)
45
+ # Find a fundamental matrix (or anything else of interest)
46
+ F, mask = cv2.findFundamentalMat(
47
+ kptsA.cpu().numpy(), kptsB.cpu().numpy(), ransacReprojThreshold=0.2, method=cv2.USAC_MAGSAC, confidence=0.999999, maxIters=10000
48
+ )
49
+ ```
50
+
51
+ **New**: You can also match arbitrary keypoints with RoMa. See [match_keypoints](romatch/models/matcher.py) in RegressionMatcher.
52
+
53
+ ## Settings
54
+
55
+ ### Resolution
56
+ By default RoMa uses an initial resolution of (560,560) which is then upsampled to (864,864).
57
+ You can change this at construction (see roma_outdoor kwargs).
58
+ You can also change this later, by changing the roma_model.w_resized, roma_model.h_resized, and roma_model.upsample_res.
59
+
60
+ ### Sampling
61
+ roma_model.sample_thresh controls the thresholding used when sampling matches for estimation. In certain cases a lower or higher threshold may improve results.
62
+
63
+
64
+ ## Reproducing Results
65
+ The experiments in the paper are provided in the [experiments folder](experiments).
66
+
67
+ ### Training
68
+ 1. First follow the instructions provided here: https://github.com/Parskatt/DKM for downloading and preprocessing datasets.
69
+ 2. Run the relevant experiment, e.g.,
70
+ ```bash
71
+ torchrun --nproc_per_node=4 --nnodes=1 --rdzv_backend=c10d experiments/roma_outdoor.py
72
+ ```
73
+ ### Testing
74
+ ```bash
75
+ python experiments/roma_outdoor.py --only_test --benchmark mega-1500
76
+ ```
77
+ ## License
78
+ All our code except DINOv2 is MIT license.
79
+ DINOv2 has an Apache 2 license [DINOv2](https://github.com/facebookresearch/dinov2/blob/main/LICENSE).
80
+
81
+ ## Acknowledgement
82
+ Our codebase builds on the code in [DKM](https://github.com/Parskatt/DKM).
83
+
84
+ ## Tiny RoMa
85
+ If you find that RoMa is too heavy, you might want to try Tiny RoMa which is built on top of XFeat.
86
+ ```python
87
+ from romatch import tiny_roma_v1_outdoor
88
+ tiny_roma_model = tiny_roma_v1_outdoor(device=device)
89
+ ```
90
+ Mega1500:
91
+ | | AUC@5 | AUC@10 | AUC@20 |
92
+ |----------|----------|----------|----------|
93
+ | XFeat | 46.4 | 58.9 | 69.2 |
94
+ | XFeat* | 51.9 | 67.2 | 78.9 |
95
+ | Tiny RoMa v1 | 56.4 | 69.5 | 79.5 |
96
+ | RoMa | - | - | - |
97
+
98
+ Mega-8-Scenes (See DKM):
99
+ | | AUC@5 | AUC@10 | AUC@20 |
100
+ |----------|----------|----------|----------|
101
+ | XFeat | - | - | - |
102
+ | XFeat* | 50.1 | 64.4 | 75.2 |
103
+ | Tiny RoMa v1 | 57.7 | 70.5 | 79.6 |
104
+ | RoMa | - | - | - |
105
+
106
+ IMC22 :'):
107
+ | | mAA@10 |
108
+ |----------|----------|
109
+ | XFeat | 42.1 |
110
+ | XFeat* | - |
111
+ | Tiny RoMa v1 | 42.2 |
112
+ | RoMa | - |
113
+
114
+ ## BibTeX
115
+ If you find our models useful, please consider citing our paper!
116
+ ```
117
+ @article{edstedt2024roma,
118
+ title={{RoMa: Robust Dense Feature Matching}},
119
+ author={Edstedt, Johan and Sun, Qiyu and Bökman, Georg and Wadenbäck, Mårten and Felsberg, Michael},
120
+ journal={IEEE Conference on Computer Vision and Pattern Recognition},
121
+ year={2024}
122
+ }
123
+ ```
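The Settings section of the README above mentions that resolutions can be changed after construction. A short sketch of that, assuming weights are available locally; the concrete values are illustrative, not recommendations:

```python
import torch
from romatch import roma_outdoor

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
roma_model = roma_outdoor(device=device)

# Lower the coarse matching resolution and the upsampling resolution
# for faster (but coarser) matching, as described under "Settings".
roma_model.w_resized, roma_model.h_resized = 448, 448
roma_model.upsample_res = (672, 896)
```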
imcui/third_party/RoMa/data/.gitignore ADDED
@@ -0,0 +1,2 @@
+ *
+ !.gitignore
imcui/third_party/RoMa/requirements.txt ADDED
@@ -0,0 +1,14 @@
+ torch
+ einops
+ torchvision
+ opencv-python
+ kornia
+ albumentations
+ loguru
+ tqdm
+ matplotlib
+ h5py
+ wandb
+ timm
+ poselib
+ # xformers  # Optional, used for memory-efficient attention
imcui/third_party/RoMa/romatch/models/matcher.py CHANGED
@@ -11,7 +11,7 @@ from PIL import Image
11
 
12
  from romatch.utils import get_tuple_transform_ops
13
  from romatch.utils.local_correlation import local_correlation
14
- from romatch.utils.utils import cls_to_flow_refine, get_autocast_params
15
  from romatch.utils.kde import kde
16
 
17
  class ConvRefiner(nn.Module):
@@ -573,12 +573,30 @@ class RegressionMatcher(nn.Module):
573
  kpts_B = torch.stack((2/W_B * kpts_B[...,0] - 1, 2/H_B * kpts_B[...,1] - 1),axis=-1)
574
  return kpts_A, kpts_B
575
 
576
- def match_keypoints(self, x_A, x_B, warp, certainty, return_tuple = True, return_inds = False):
577
- x_A_to_B = F.grid_sample(warp[...,-2:].permute(2,0,1)[None], x_A[None,None], align_corners = False, mode = "bilinear")[0,:,0].mT
578
- cert_A_to_B = F.grid_sample(certainty[None,None,...], x_A[None,None], align_corners = False, mode = "bilinear")[0,0,0]
579
  D = torch.cdist(x_A_to_B, x_B)
580
- inds_A, inds_B = torch.nonzero((D == D.min(dim=-1, keepdim = True).values) * (D == D.min(dim=-2, keepdim = True).values) * (cert_A_to_B[:,None] > self.sample_thresh), as_tuple = True)
581
-
582
  if return_tuple:
583
  if return_inds:
584
  return inds_A, inds_B
@@ -586,25 +604,38 @@ class RegressionMatcher(nn.Module):
586
  return x_A[inds_A], x_B[inds_B]
587
  else:
588
  if return_inds:
589
- return torch.cat((inds_A, inds_B),dim=-1)
590
  else:
591
- return torch.cat((x_A[inds_A], x_B[inds_B]),dim=-1)
592
 
593
  @torch.inference_mode()
594
  def match(
595
  self,
596
- im_A_path,
597
- im_B_path,
598
  *args,
599
  batched=False,
600
- device = None,
601
  ):
602
  if device is None:
603
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
604
- if isinstance(im_A_path, (str, os.PathLike)):
605
- im_A, im_B = Image.open(im_A_path).convert("RGB"), Image.open(im_B_path).convert("RGB")
606
  else:
607
- im_A, im_B = im_A_path, im_B_path
 
608
 
609
  symmetric = self.symmetric
610
  self.train(False)
@@ -616,9 +647,9 @@ class RegressionMatcher(nn.Module):
616
  # Get images in good format
617
  ws = self.w_resized
618
  hs = self.h_resized
619
-
620
  test_transform = get_tuple_transform_ops(
621
- resize=(hs, ws), normalize=True, clahe = False
622
  )
623
  im_A, im_B = test_transform((im_A, im_B))
624
  batch = {"im_A": im_A[None].to(device), "im_B": im_B[None].to(device)}
@@ -633,20 +664,20 @@ class RegressionMatcher(nn.Module):
633
  finest_scale = 1
634
  # Run matcher
635
  if symmetric:
636
- corresps = self.forward_symmetric(batch)
637
  else:
638
- corresps = self.forward(batch, batched = True)
639
 
640
  if self.upsample_preds:
641
  hs, ws = self.upsample_res
642
-
643
  if self.attenuate_cert:
644
  low_res_certainty = F.interpolate(
645
- corresps[16]["certainty"], size=(hs, ws), align_corners=False, mode="bilinear"
646
  )
647
  cert_clamp = 0
648
  factor = 0.5
649
- low_res_certainty = factor*low_res_certainty*(low_res_certainty < cert_clamp)
650
 
651
  if self.upsample_preds:
652
  finest_corresps = corresps[finest_scale]
@@ -654,34 +685,39 @@ class RegressionMatcher(nn.Module):
654
  test_transform = get_tuple_transform_ops(
655
  resize=(hs, ws), normalize=True
656
  )
657
- im_A, im_B = test_transform((Image.open(im_A_path).convert('RGB'), Image.open(im_B_path).convert('RGB')))
658
  im_A, im_B = im_A[None].to(device), im_B[None].to(device)
659
  scale_factor = math.sqrt(self.upsample_res[0] * self.upsample_res[1] / (self.w_resized * self.h_resized))
660
  batch = {"im_A": im_A, "im_B": im_B, "corresps": finest_corresps}
661
  if symmetric:
662
- corresps = self.forward_symmetric(batch, upsample = True, batched=True, scale_factor = scale_factor)
663
  else:
664
- corresps = self.forward(batch, batched = True, upsample=True, scale_factor = scale_factor)
665
-
666
- im_A_to_im_B = corresps[finest_scale]["flow"]
667
  certainty = corresps[finest_scale]["certainty"] - (low_res_certainty if self.attenuate_cert else 0)
668
  if finest_scale != 1:
669
  im_A_to_im_B = F.interpolate(
670
- im_A_to_im_B, size=(hs, ws), align_corners=False, mode="bilinear"
671
  )
672
  certainty = F.interpolate(
673
- certainty, size=(hs, ws), align_corners=False, mode="bilinear"
674
  )
675
  im_A_to_im_B = im_A_to_im_B.permute(
676
  0, 2, 3, 1
677
- )
678
  # Create im_A meshgrid
679
  im_A_coords = torch.meshgrid(
680
  (
681
  torch.linspace(-1 + 1 / hs, 1 - 1 / hs, hs, device=device),
682
  torch.linspace(-1 + 1 / ws, 1 - 1 / ws, ws, device=device),
683
  ),
684
- indexing = 'ij'
685
  )
686
  im_A_coords = torch.stack((im_A_coords[1], im_A_coords[0]))
687
  im_A_coords = im_A_coords[None].expand(b, 2, hs, ws)
@@ -689,14 +725,14 @@ class RegressionMatcher(nn.Module):
689
  im_A_coords = im_A_coords.permute(0, 2, 3, 1)
690
  if (im_A_to_im_B.abs() > 1).any() and True:
691
  wrong = (im_A_to_im_B.abs() > 1).sum(dim=-1) > 0
692
- certainty[wrong[:,None]] = 0
693
  im_A_to_im_B = torch.clamp(im_A_to_im_B, -1, 1)
694
  if symmetric:
695
  A_to_B, B_to_A = im_A_to_im_B.chunk(2)
696
  q_warp = torch.cat((im_A_coords, A_to_B), dim=-1)
697
  im_B_coords = im_A_coords
698
  s_warp = torch.cat((B_to_A, im_B_coords), dim=-1)
699
- warp = torch.cat((q_warp, s_warp),dim=2)
700
  certainty = torch.cat(certainty.chunk(2), dim=3)
701
  else:
702
  warp = torch.cat((im_A_coords, im_A_to_im_B), dim=-1)
 
11
 
12
  from romatch.utils import get_tuple_transform_ops
13
  from romatch.utils.local_correlation import local_correlation
14
+ from romatch.utils.utils import check_rgb, cls_to_flow_refine, get_autocast_params, check_not_i16
15
  from romatch.utils.kde import kde
16
 
17
  class ConvRefiner(nn.Module):
 
573
  kpts_B = torch.stack((2/W_B * kpts_B[...,0] - 1, 2/H_B * kpts_B[...,1] - 1),axis=-1)
574
  return kpts_A, kpts_B
575
 
576
+ def match_keypoints(
577
+ self, x_A, x_B, warp, certainty, return_tuple=True, return_inds=False, max_dist = 0.005, cert_th = 0,
578
+ ):
579
+ x_A_to_B = F.grid_sample(
580
+ warp[..., -2:].permute(2, 0, 1)[None],
581
+ x_A[None, None],
582
+ align_corners=False,
583
+ mode="bilinear",
584
+ )[0, :, 0].mT
585
+ cert_A_to_B = F.grid_sample(
586
+ certainty[None, None, ...],
587
+ x_A[None, None],
588
+ align_corners=False,
589
+ mode="bilinear",
590
+ )[0, 0, 0]
591
  D = torch.cdist(x_A_to_B, x_B)
592
+ inds_A, inds_B = torch.nonzero(
593
+ (D == D.min(dim=-1, keepdim=True).values)
594
+ * (D == D.min(dim=-2, keepdim=True).values)
595
+ * (cert_A_to_B[:, None] > cert_th)
596
+ * (D < max_dist),
597
+ as_tuple=True,
598
+ )
599
+
600
  if return_tuple:
601
  if return_inds:
602
  return inds_A, inds_B
 
604
  return x_A[inds_A], x_B[inds_B]
605
  else:
606
  if return_inds:
607
+ return torch.cat((inds_A, inds_B), dim=-1)
608
  else:
609
+ return torch.cat((x_A[inds_A], x_B[inds_B]), dim=-1)
610
 
611
  @torch.inference_mode()
612
  def match(
613
  self,
614
+ im_A_input,
615
+ im_B_input,
616
  *args,
617
  batched=False,
618
+ device=None,
619
  ):
620
  if device is None:
621
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
622
+
623
+ # Check if inputs are file paths or already loaded images
624
+ if isinstance(im_A_input, (str, os.PathLike)):
625
+ im_A = Image.open(im_A_input)
626
+ check_not_i16(im_A)
627
+ im_A = im_A.convert("RGB")
628
+ else:
629
+ check_rgb(im_A_input)
630
+ im_A = im_A_input
631
+
632
+ if isinstance(im_B_input, (str, os.PathLike)):
633
+ im_B = Image.open(im_B_input)
634
+ check_not_i16(im_B)
635
+ im_B = im_B.convert("RGB")
636
  else:
637
+ check_rgb(im_B_input)
638
+ im_B = im_B_input
639
 
640
  symmetric = self.symmetric
641
  self.train(False)
 
647
  # Get images in good format
648
  ws = self.w_resized
649
  hs = self.h_resized
650
+
651
  test_transform = get_tuple_transform_ops(
652
+ resize=(hs, ws), normalize=True, clahe=False
653
  )
654
  im_A, im_B = test_transform((im_A, im_B))
655
  batch = {"im_A": im_A[None].to(device), "im_B": im_B[None].to(device)}
 
664
  finest_scale = 1
665
  # Run matcher
666
  if symmetric:
667
+ corresps = self.forward_symmetric(batch)
668
  else:
669
+ corresps = self.forward(batch, batched=True)
670
 
671
  if self.upsample_preds:
672
  hs, ws = self.upsample_res
673
+
674
  if self.attenuate_cert:
675
  low_res_certainty = F.interpolate(
676
+ corresps[16]["certainty"], size=(hs, ws), align_corners=False, mode="bilinear"
677
  )
678
  cert_clamp = 0
679
  factor = 0.5
680
+ low_res_certainty = factor * low_res_certainty * (low_res_certainty < cert_clamp)
681
 
682
  if self.upsample_preds:
683
  finest_corresps = corresps[finest_scale]
 
685
  test_transform = get_tuple_transform_ops(
686
  resize=(hs, ws), normalize=True
687
  )
688
+ if isinstance(im_A_input, (str, os.PathLike)):
689
+ im_A, im_B = test_transform(
690
+ (Image.open(im_A_input).convert('RGB'), Image.open(im_B_input).convert('RGB')))
691
+ else:
692
+ im_A, im_B = test_transform((im_A_input, im_B_input))
693
+
694
  im_A, im_B = im_A[None].to(device), im_B[None].to(device)
695
  scale_factor = math.sqrt(self.upsample_res[0] * self.upsample_res[1] / (self.w_resized * self.h_resized))
696
  batch = {"im_A": im_A, "im_B": im_B, "corresps": finest_corresps}
697
  if symmetric:
698
+ corresps = self.forward_symmetric(batch, upsample=True, batched=True, scale_factor=scale_factor)
699
  else:
700
+ corresps = self.forward(batch, batched=True, upsample=True, scale_factor=scale_factor)
701
+
702
+ im_A_to_im_B = corresps[finest_scale]["flow"]
703
  certainty = corresps[finest_scale]["certainty"] - (low_res_certainty if self.attenuate_cert else 0)
704
  if finest_scale != 1:
705
  im_A_to_im_B = F.interpolate(
706
+ im_A_to_im_B, size=(hs, ws), align_corners=False, mode="bilinear"
707
  )
708
  certainty = F.interpolate(
709
+ certainty, size=(hs, ws), align_corners=False, mode="bilinear"
710
  )
711
  im_A_to_im_B = im_A_to_im_B.permute(
712
  0, 2, 3, 1
713
+ )
714
  # Create im_A meshgrid
715
  im_A_coords = torch.meshgrid(
716
  (
717
  torch.linspace(-1 + 1 / hs, 1 - 1 / hs, hs, device=device),
718
  torch.linspace(-1 + 1 / ws, 1 - 1 / ws, ws, device=device),
719
  ),
720
+ indexing='ij'
721
  )
722
  im_A_coords = torch.stack((im_A_coords[1], im_A_coords[0]))
723
  im_A_coords = im_A_coords[None].expand(b, 2, hs, ws)
 
725
  im_A_coords = im_A_coords.permute(0, 2, 3, 1)
726
  if (im_A_to_im_B.abs() > 1).any() and True:
727
  wrong = (im_A_to_im_B.abs() > 1).sum(dim=-1) > 0
728
+ certainty[wrong[:, None]] = 0
729
  im_A_to_im_B = torch.clamp(im_A_to_im_B, -1, 1)
730
  if symmetric:
731
  A_to_B, B_to_A = im_A_to_im_B.chunk(2)
732
  q_warp = torch.cat((im_A_coords, A_to_B), dim=-1)
733
  im_B_coords = im_A_coords
734
  s_warp = torch.cat((B_to_A, im_B_coords), dim=-1)
735
+ warp = torch.cat((q_warp, s_warp), dim=2)
736
  certainty = torch.cat(certainty.chunk(2), dim=3)
737
  else:
738
  warp = torch.cat((im_A_coords, im_A_to_im_B), dim=-1)
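The change above lets `RegressionMatcher.match` accept either file paths or already-loaded PIL images (with `check_not_i16`/`check_rgb` guarding the latter). A brief sketch of both call styles, assuming `roma_model` was built as in the RoMa README:

```python
from PIL import Image

# 1) Path inputs, as before.
warp, certainty = roma_model.match("im_A.jpg", "im_B.jpg")

# 2) Pre-loaded PIL images; they must already be RGB (check_rgb is applied).
im_A = Image.open("im_A.jpg").convert("RGB")
im_B = Image.open("im_B.jpg").convert("RGB")
warp, certainty = roma_model.match(im_A, im_B)
```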
imcui/third_party/RoMa/romatch/models/transformer/layers/attention.py CHANGED
@@ -22,7 +22,7 @@ try:
 
     XFORMERS_AVAILABLE = True
 except ImportError:
-    logger.warning("xFormers not available")
+    # logger.warning("xFormers not available")
     XFORMERS_AVAILABLE = False
 
 
imcui/third_party/RoMa/romatch/models/transformer/layers/block.py CHANGED
@@ -29,7 +29,7 @@ try:
 
     XFORMERS_AVAILABLE = True
 except ImportError:
-    logger.warning("xFormers not available")
+    # logger.warning("xFormers not available")
    XFORMERS_AVAILABLE = False
 
 
imcui/third_party/RoMa/romatch/utils/utils.py CHANGED
@@ -651,4 +651,12 @@ def get_autocast_params(device=None, enabled=False, dtype=None):
         enabled = False
         # mps is not supported
         autocast_device = "cpu"
-    return autocast_device, enabled, out_dtype
+    return autocast_device, enabled, out_dtype
+
+def check_not_i16(im):
+    if im.mode == "I;16":
+        raise NotImplementedError("Can't handle 16 bit images")
+
+def check_rgb(im):
+    if im.mode != "RGB":
+        raise NotImplementedError("Can't handle non-RGB images")
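A small usage sketch of the two helpers added above, guarding image loading the same way the updated `match()` does (the file name is a placeholder):

```python
from PIL import Image
from romatch.utils.utils import check_not_i16, check_rgb

im = Image.open("im_A.jpg")
check_not_i16(im)        # rejects 16-bit "I;16" images
im = im.convert("RGB")
check_rgb(im)            # passes once the image is in RGB mode
```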
imcui/third_party/RoMa/setup.py CHANGED
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
 setup(
     name="romatch",
     packages=find_packages(include=("romatch*",)),
-    version="0.0.1",
+    version="0.0.2",
     author="Johan Edstedt",
     install_requires=open("requirements.txt", "r").read().split("\n"),
 )
imcui/third_party/dad/.gitignore ADDED
@@ -0,0 +1,170 @@
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # poetry
98
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102
+ #poetry.lock
103
+
104
+ # pdm
105
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106
+ #pdm.lock
107
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108
+ # in version control.
109
+ # https://pdm.fming.dev/#use-with-ide
110
+ .pdm.toml
111
+
112
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113
+ __pypackages__/
114
+
115
+ # Celery stuff
116
+ celerybeat-schedule
117
+ celerybeat.pid
118
+
119
+ # SageMath parsed files
120
+ *.sage.py
121
+
122
+ # Environments
123
+ .env
124
+ .venv
125
+ env/
126
+ venv/
127
+ ENV/
128
+ env.bak/
129
+ venv.bak/
130
+
131
+ # Spyder project settings
132
+ .spyderproject
133
+ .spyproject
134
+
135
+ # Rope project settings
136
+ .ropeproject
137
+
138
+ # mkdocs documentation
139
+ /site
140
+
141
+ # mypy
142
+ .mypy_cache/
143
+ .dmypy.json
144
+ dmypy.json
145
+
146
+ # Pyre type checker
147
+ .pyre/
148
+
149
+ # pytype static type analyzer
150
+ .pytype/
151
+
152
+ # Cython debug symbols
153
+ cython_debug/
154
+
155
+ # PyCharm
156
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
159
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160
+ #.idea/
161
+
162
+ .vscode*
163
+ *.pth
164
+ wandb
165
+ *.out
166
+ vis/
167
+ workspace/
168
+
169
+ .DS_Store
170
+ *.tar
imcui/third_party/dad/.python-version ADDED
@@ -0,0 +1 @@
+ 3.10
imcui/third_party/dad/LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2025 Johan Edstedt
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
imcui/third_party/dad/README.md ADDED
@@ -0,0 +1,130 @@
1
+ <p align="center">
2
+ <h1 align="center"> <ins>DaD:</ins> Distilled Reinforcement Learning for Diverse Keypoint Detection</h1>
3
+ <p align="center">
4
+ <a href="https://scholar.google.com/citations?user=Ul-vMR0AAAAJ">Johan Edstedt</a>
5
+ ·
6
+ <a href="https://scholar.google.com/citations?user=FUE3Wd0AAAAJ">Georg Bökman</a>
7
+ ·
8
+ <a href="https://scholar.google.com/citations?user=6WRQpCQAAAAJ">Mårten Wadenbäck</a>
9
+ ·
10
+ <a href="https://scholar.google.com/citations?user=lkWfR08AAAAJ">Michael Felsberg</a>
11
+ </p>
12
+ <h2 align="center"><p>
13
+ <a href="https://arxiv.org/abs/2503.07347" align="center">Paper</a>
14
+ </p></h2>
15
+ <p align="center">
16
+ <img src="assets/qualitative.jpg" alt="example" width=80%>
17
+ <br>
18
+ <em>DaD's a pretty good keypoint detector, probably the best.</em>
19
+ </p>
20
+ </p>
21
+ <p align="center">
22
+ </p>
23
+
24
+ ## Run
25
+ ```python
26
+ import dad
27
+ from PIL import Image
28
+ img_path = "assets/0015_A.jpg"
29
+ W, H = Image.open(img_path).size  # your image shape
30
+ detector = dad.load_DaD()
31
+ detections = detector.detect_from_path(
32
+ img_path,
33
+ num_keypoints = 512,
34
+ return_dense_probs=True)
35
+ detections["keypoints"] # 1 x 512 x 2, normalized coordinates of keypoints
36
+ detector.to_pixel_coords(detections["keypoints"], H, W)
37
+ detections["keypoint_probs"] # 1 x 512, probs of sampled keypoints
38
+ detections["dense_probs"] # 1 x H x W, probability map
39
+ ```
40
+
41
+ ## Visualize
42
+ ```python
43
+ import dad
44
+ from dad.utils import visualize_keypoints
45
+ detector = dad.load_DaD()
46
+ img_path = "assets/0015_A.jpg"
47
+ vis_path = "vis/0015_A_dad.jpg"
48
+ visualize_keypoints(img_path, vis_path, detector, num_keypoints = 512)
49
+ ```
50
+
51
+ ## Install
52
+ Get uv
53
+ ```bash
54
+ curl -LsSf https://astral.sh/uv/install.sh | sh
55
+ ```
56
+ ### In an existing env
57
+ Assuming you already have some env active:
58
+ ```bash
59
+ uv pip install dad@git+https://github.com/Parskatt/dad.git
60
+ ```
61
+ ### As a project
62
+ For dev, etc:
63
+ ```bash
64
+ git clone [email protected]:Parskatt/dad.git
65
+ uv sync
66
+ source .venv/bin/activate
67
+ ```
68
+
69
+ ## Evaluation
70
+ To evaluate, e.g., DaD on ScanNet1500 with 512 keypoints, run
71
+ ```bash
72
+ python experiments/benchmark.py --detector DaD --num_keypoints 512 --benchmark ScanNet1500
73
+ ```
74
+ Note: leaving out num_keypoints will run the benchmark for all numbers of keypoints, i.e., [512, 1024, 2048, 4096, 8192].
75
+ ### Third party detectors
76
+ We provide wrappers for a fairly large set of previous detectors; to list them, run
77
+ ```bash
78
+ python experiments/benchmark.py --help
79
+ ```
80
+
81
+ ## Training
82
+ To train our final model from the emergent light and dark detector, run
83
+ ```bash
84
+ python experiments/repro_paper_results/distill.py
85
+ ```
86
+ The emergent models come from running
87
+ ```bash
88
+ python experiments/repro_paper_results/rl.py
89
+ ```
90
+ Note, however, that the type of detector that emerges from this training is stochastic, and you may need several runs to get a detector that matches our results.
91
+
92
+ ## How I run experiments
93
+ (Note: You don't have to do this, it's just how I do it.)
94
+ At the start of a new day I typically run
95
+ ```bash
96
+ python new_day.py
97
+ ```
98
+ This creates a new folder in experiments, e.g., `experiments/w11/monday`.
99
+ I then typically just copy the contents of a previous experiment, e.g.,
100
+ ```bash
101
+ cp experiments/repro_paper_results/rl.py experiments/w11/monday/new-cool-hparams.py
102
+ ```
103
+ Change whatever you want to change in `experiments/w11/monday/new-cool-hparams.py`.
104
+
105
+ Then run it with
106
+ ```bash
107
+ python experiments/w11/monday/new-cool-hparams.py
108
+ ```
109
+ This will be tracked in wandb as `w11-monday-new-cool-hparams` in the `DaD` project.
110
+
111
+ If you don't want to track the run, and perhaps want to display some debug output, you can instead run it as follows, which also won't log to wandb
112
+ ```bash
113
+ DEBUG=1 python experiments/w11/monday/new-cool-hparams.py
114
+ ```
115
+ ## Evaluation Results
116
+ TODO
117
+
118
+ ## Licenses
119
+ DaD is MIT licensed.
120
+
121
+ Third party detectors in [dad/detectors/third_party](dad/detectors/third_party) have their own licenses. If you use them, please refer to their respective licenses [here](licenses). (NOTE: there may be more licenses you need to care about than the ones listed. Before using any third party code, make sure you're following its license.)
122
+
123
+
124
+
125
+
126
+ ## BibTeX
127
+
128
+ ```txt
129
+ TODO
130
+ ```
imcui/third_party/dad/dad/__init__.py ADDED
@@ -0,0 +1,17 @@
+ from .logging import logger as logger
+ from .logging import configure_logger as configure_logger
+ import os
+ from .detectors import load_DaD as load_DaD
+ from .detectors import dedode_detector_S as dedode_detector_S
+ from .detectors import dedode_detector_B as dedode_detector_B
+ from .detectors import dedode_detector_L as dedode_detector_L
+ from .detectors import load_DaDDark as load_DaDDark
+ from .detectors import load_DaDLight as load_DaDLight
+ from .types import Detector as Detector
+ from .types import Matcher as Matcher
+ from .types import Benchmark as Benchmark
+
+ configure_logger()
+ DEBUG_MODE = bool(os.environ.get("DEBUG", False))
+ RANK = 0
+ GLOBAL_STEP = 0
imcui/third_party/dad/dad/augs.py ADDED
@@ -0,0 +1,214 @@
1
+ import random
2
+ import warnings
3
+ import numpy as np
4
+ import torch
5
+ from PIL import Image
6
+ from torchvision import transforms
7
+ from torchvision.transforms.functional import InterpolationMode
8
+ import cv2
9
+
10
+
11
+ # From Patch2Pix https://github.com/GrumpyZhou/patch2pix
12
+ def get_depth_tuple_transform_ops(resize=None, normalize=True, unscale=False):
13
+ ops = []
14
+ if resize:
15
+ ops.append(
16
+ TupleResize(resize, mode=InterpolationMode.BILINEAR, antialias=False)
17
+ )
18
+ return TupleCompose(ops)
19
+
20
+
21
+ def get_tuple_transform_ops(resize=None, normalize=True, unscale=False, clahe=False):
22
+ ops = []
23
+ if resize:
24
+ ops.append(TupleResize(resize, antialias=True))
25
+ if clahe:
26
+ ops.append(TupleClahe())
27
+ if normalize:
28
+ ops.append(TupleToTensorScaled())
29
+ ops.append(
30
+ TupleNormalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
31
+ ) # Imagenet mean/std
32
+ else:
33
+ if unscale:
34
+ ops.append(TupleToTensorUnscaled())
35
+ else:
36
+ ops.append(TupleToTensorScaled())
37
+ return TupleCompose(ops)
38
+
39
+
40
+ class Clahe:
41
+ def __init__(self, cliplimit=2, blocksize=8) -> None:
42
+ self.clahe = cv2.createCLAHE(cliplimit, (blocksize, blocksize))
43
+
44
+ def __call__(self, im):
45
+ im_hsv = cv2.cvtColor(np.array(im), cv2.COLOR_RGB2HSV)
46
+ im_v = self.clahe.apply(im_hsv[:, :, 2])
47
+ im_hsv[..., 2] = im_v
48
+ im_clahe = cv2.cvtColor(im_hsv, cv2.COLOR_HSV2RGB)
49
+ return Image.fromarray(im_clahe)
50
+
51
+
52
+ class TupleClahe:
53
+ def __init__(self, cliplimit=8, blocksize=8) -> None:
54
+ self.clahe = Clahe(cliplimit, blocksize)
55
+
56
+ def __call__(self, ims):
57
+ return [self.clahe(im) for im in ims]
58
+
59
+
60
+ class ToTensorScaled(object):
61
+ """Convert a RGB PIL Image to a CHW ordered Tensor, scale the range to [0, 1]"""
62
+
63
+ def __call__(self, im):
64
+ if not isinstance(im, torch.Tensor):
65
+ im = np.array(im, dtype=np.float32).transpose((2, 0, 1))
66
+ im /= 255.0
67
+ return torch.from_numpy(im)
68
+ else:
69
+ return im
70
+
71
+ def __repr__(self):
72
+ return "ToTensorScaled(./255)"
73
+
74
+
75
+ class TupleToTensorScaled(object):
76
+ def __init__(self):
77
+ self.to_tensor = ToTensorScaled()
78
+
79
+ def __call__(self, im_tuple):
80
+ return [self.to_tensor(im) for im in im_tuple]
81
+
82
+ def __repr__(self):
83
+ return "TupleToTensorScaled(./255)"
84
+
85
+
86
+ class ToTensorUnscaled(object):
87
+ """Convert a RGB PIL Image to a CHW ordered Tensor"""
88
+
89
+ def __call__(self, im):
90
+ return torch.from_numpy(np.array(im, dtype=np.float32).transpose((2, 0, 1)))
91
+
92
+ def __repr__(self):
93
+ return "ToTensorUnscaled()"
94
+
95
+
96
+ class TupleToTensorUnscaled(object):
97
+ """Convert a RGB PIL Image to a CHW ordered Tensor"""
98
+
99
+ def __init__(self):
100
+ self.to_tensor = ToTensorUnscaled()
101
+
102
+ def __call__(self, im_tuple):
103
+ return [self.to_tensor(im) for im in im_tuple]
104
+
105
+ def __repr__(self):
106
+ return "TupleToTensorUnscaled()"
107
+
108
+
109
+ class TupleResize(object):
110
+ def __init__(self, size, mode=InterpolationMode.BICUBIC, antialias=None):
111
+ self.size = size
112
+ self.resize = transforms.Resize(size, mode, antialias=antialias)
113
+
114
+ def __call__(self, im_tuple):
115
+ return [self.resize(im) for im in im_tuple]
116
+
117
+ def __repr__(self):
118
+ return "TupleResize(size={})".format(self.size)
119
+
120
+
121
+ class Normalize:
122
+ def __call__(self, im):
123
+ mean = im.mean(dim=(1, 2), keepdims=True)
124
+ std = im.std(dim=(1, 2), keepdims=True)
125
+ return (im - mean) / std
126
+
127
+
128
+ class TupleNormalize(object):
129
+ def __init__(self, mean, std):
130
+ self.mean = mean
131
+ self.std = std
132
+ self.normalize = transforms.Normalize(mean=mean, std=std)
133
+
134
+ def __call__(self, im_tuple):
135
+ c, h, w = im_tuple[0].shape
136
+ if c > 3:
137
+ warnings.warn(f"Number of channels {c=} > 3, assuming first 3 are rgb")
138
+ return [self.normalize(im[:3]) for im in im_tuple]
139
+
140
+ def __repr__(self):
141
+ return "TupleNormalize(mean={}, std={})".format(self.mean, self.std)
142
+
143
+
144
+ class TupleCompose(object):
145
+ def __init__(self, transforms):
146
+ self.transforms = transforms
147
+
148
+ def __call__(self, im_tuple):
149
+ for t in self.transforms:
150
+ im_tuple = t(im_tuple)
151
+ return im_tuple
152
+
153
+ def __repr__(self):
154
+ format_string = self.__class__.__name__ + "("
155
+ for t in self.transforms:
156
+ format_string += "\n"
157
+ format_string += " {0}".format(t)
158
+ format_string += "\n)"
159
+ return format_string
160
+
161
+
162
+ def pad_kps(kps: torch.Tensor, pad_num_kps: int, value: int = -1):
163
+ assert len(kps.shape) == 2
164
+ N = len(kps)
165
+ padded_kps = value * torch.ones((pad_num_kps, 2)).to(kps)
166
+ padded_kps[:N] = kps
167
+ return padded_kps
168
+
169
+
170
+ def crop(img: Image.Image, x: int, y: int, crop_size: int):
171
+ width, height = img.size
172
+ if width < crop_size or height < crop_size:
173
+ raise ValueError(f"Image dimensions must be at least {crop_size}x{crop_size}")
174
+ cropped_img = img.crop((x, y, x + crop_size, y + crop_size))
175
+ return cropped_img
176
+
177
+
178
+ def random_crop(img: Image.Image, crop_size: int):
179
+ width, height = img.size
180
+
181
+ if width < crop_size or height < crop_size:
182
+ raise ValueError(f"Image dimensions must be at least {crop_size}x{crop_size}")
183
+
184
+ max_x = width - crop_size
185
+ max_y = height - crop_size
186
+
187
+ x = random.randint(0, max_x)
188
+ y = random.randint(0, max_y)
189
+
190
+ cropped_img = img.crop((x, y, x + crop_size, y + crop_size))
191
+ return cropped_img, (x, y)
192
+
193
+
194
+ def luminance_negation(pil_img):
195
+ # Convert PIL RGB to numpy array
196
+ rgb_array = np.array(pil_img)
197
+
198
+ # Convert RGB to BGR (OpenCV format)
199
+ bgr = cv2.cvtColor(rgb_array, cv2.COLOR_RGB2BGR)
200
+
201
+ # Convert BGR to LAB
202
+ lab = cv2.cvtColor(bgr, cv2.COLOR_BGR2LAB)
203
+
204
+ # Negate L channel
205
+ lab[:, :, 0] = 255 - lab[:, :, 0]
206
+
207
+ # Convert back to BGR
208
+ bgr_result = cv2.cvtColor(lab, cv2.COLOR_LAB2BGR)
209
+
210
+ # Convert BGR back to RGB
211
+ rgb_result = cv2.cvtColor(bgr_result, cv2.COLOR_BGR2RGB)
212
+
213
+ # Convert numpy array back to PIL Image
214
+ return Image.fromarray(rgb_result)
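A minimal sketch of the transform pipeline defined in this file: resize an image pair, convert to [0, 1] tensors, and apply ImageNet normalization (the `normalize=True` branch of `get_tuple_transform_ops`). The image path reuses the asset referenced in the DaD README; the target size is illustrative.

```python
from PIL import Image
from dad.augs import get_tuple_transform_ops

ops = get_tuple_transform_ops(resize=(448, 448), normalize=True)
im_A = Image.open("assets/0015_A.jpg").convert("RGB")
im_B = Image.open("assets/0015_A.jpg").convert("RGB")
t_A, t_B = ops((im_A, im_B))
print(t_A.shape)  # torch.Size([3, 448, 448])
```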
imcui/third_party/dad/dad/benchmarks/__init__.py ADDED
@@ -0,0 +1,21 @@
+ # from .benchmark import Benchmark as Benchmark
+ from .num_inliers import NumInliersBenchmark as NumInliersBenchmark
+ from .megadepth import Mega1500 as Mega1500
+ from .megadepth import Mega1500_F as Mega1500_F
+ from .megadepth import MegaIMCPT as MegaIMCPT
+ from .megadepth import MegaIMCPT_F as MegaIMCPT_F
+ from .scannet import ScanNet1500 as ScanNet1500
+ from .scannet import ScanNet1500_F as ScanNet1500_F
+ from .hpatches import HPatchesViewpoint as HPatchesViewpoint
+ from .hpatches import HPatchesIllum as HPatchesIllum
+
+ all_benchmarks = [
+     Mega1500.__name__,
+     Mega1500_F.__name__,
+     MegaIMCPT.__name__,
+     MegaIMCPT_F.__name__,
+     ScanNet1500.__name__,
+     ScanNet1500_F.__name__,
+     HPatchesViewpoint.__name__,
+     HPatchesIllum.__name__,
+ ]
imcui/third_party/dad/dad/benchmarks/hpatches.py ADDED
@@ -0,0 +1,117 @@
1
+ import os
2
+ from typing import Optional
3
+
4
+ import numpy as np
5
+ import poselib
6
+ from PIL import Image
7
+ from tqdm import tqdm
8
+
9
+ from dad.types import Detector, Matcher, Benchmark
10
+
11
+
12
+ class HPatchesBenchmark(Benchmark):
13
+ def __init__(
14
+ self,
15
+ data_root="data/hpatches",
16
+ sample_every=1,
17
+ num_ransac_runs=5,
18
+ num_keypoints: Optional[list[int]] = None,
19
+ ) -> None:
20
+ super().__init__(
21
+ data_root=data_root,
22
+ num_keypoints=num_keypoints,
23
+ sample_every=sample_every,
24
+ num_ransac_runs=num_ransac_runs,
25
+ thresholds=[3, 5, 10],
26
+ )
27
+ seqs_dir = "hpatches-sequences-release"
28
+ self.seqs_path = os.path.join(self.data_root, seqs_dir)
29
+ self.seq_names = sorted(os.listdir(self.seqs_path))
30
+ self.topleft = 0.0
31
+ self._post_init()
32
+ self.skip_seqs: str
33
+ self.scene_names: list[str]
34
+
35
+ def _post_init(self):
36
+ # set self.skip_seqs and self.scene_names here
37
+ raise NotImplementedError()
38
+
39
+ def benchmark(self, detector: Detector, matcher: Matcher):
40
+ homog_dists = []
41
+ for seq_idx, seq_name in enumerate(tqdm(self.seq_names[:: self.sample_every])):
42
+ if self.skip_seqs in seq_name:
43
+ # skip the sequences excluded by this split (illumination or viewpoint)
44
+ continue
45
+ im_A_path = os.path.join(self.seqs_path, seq_name, "1.ppm")
46
+ im_A = Image.open(im_A_path)
47
+ w1, h1 = im_A.size
48
+ for im_idx in list(range(2, 7)):
49
+ im_B_path = os.path.join(self.seqs_path, seq_name, f"{im_idx}.ppm")
50
+ H = np.loadtxt(
51
+ os.path.join(self.seqs_path, seq_name, "H_1_" + str(im_idx))
52
+ )
53
+ warp, certainty = matcher.match(im_A_path, im_B_path)
54
+ for num_kps in self.num_keypoints:
55
+ keypoints_A = detector.detect_from_path(
56
+ im_A_path,
57
+ num_keypoints=num_kps,
58
+ )["keypoints"][0]
59
+ keypoints_B = detector.detect_from_path(
60
+ im_B_path,
61
+ num_keypoints=num_kps,
62
+ )["keypoints"][0]
63
+ matches = matcher.match_keypoints(
64
+ keypoints_A,
65
+ keypoints_B,
66
+ warp,
67
+ certainty,
68
+ return_tuple=False,
69
+ )
70
+ im_A = Image.open(im_A_path)
71
+ w1, h1 = im_A.size
72
+ im_B = Image.open(im_B_path)
73
+ w2, h2 = im_B.size
74
+ kpts1, kpts2 = matcher.to_pixel_coordinates(matches, h1, w1, h2, w2)
75
+ offset = detector.topleft - self.topleft
76
+ kpts1, kpts2 = kpts1 - offset, kpts2 - offset
77
+ for _ in range(self.num_ransac_runs):
78
+ shuffling = np.random.permutation(np.arange(len(kpts1)))
79
+ kpts1 = kpts1[shuffling]
80
+ kpts2 = kpts2[shuffling]
81
+ threshold = 2.0
82
+ H_pred, res = poselib.estimate_homography(
83
+ kpts1.cpu().numpy(),
84
+ kpts2.cpu().numpy(),
85
+ ransac_opt={
86
+ "max_reproj_error": threshold,
87
+ },
88
+ )
89
+ corners = np.array(
90
+ [
91
+ [0, 0, 1],
92
+ [0, h1 - 1, 1],
93
+ [w1 - 1, 0, 1],
94
+ [w1 - 1, h1 - 1, 1],
95
+ ]
96
+ )
97
+ real_warped_corners = np.dot(corners, np.transpose(H))
98
+ real_warped_corners = (
99
+ real_warped_corners[:, :2] / real_warped_corners[:, 2:]
100
+ )
101
+ warped_corners = np.dot(corners, np.transpose(H_pred))
102
+ warped_corners = warped_corners[:, :2] / warped_corners[:, 2:]
103
+ mean_dist = np.mean(
104
+ np.linalg.norm(real_warped_corners - warped_corners, axis=1)
105
+ ) / (min(w2, h2) / 480.0)
106
+ homog_dists.append(mean_dist)
107
+ return self.compute_auc(np.array(homog_dists))
108
+
109
+
110
+ class HPatchesViewpoint(HPatchesBenchmark):
111
+ def _post_init(self):
112
+ self.skip_seqs = "i_"
113
+
114
+
115
+ class HPatchesIllum(HPatchesBenchmark):
116
+ def _post_init(self):
117
+ self.skip_seqs = "v_"
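
The homography error computed above reduces to a mean corner-transfer distance, rescaled so that images with a 480-pixel short side are measured in raw pixels. A standalone sketch of just that metric, with `H_gt` and `H_pred` as 3x3 numpy arrays:

```python
import numpy as np


def corner_transfer_error(H_gt, H_pred, w1, h1, w2, h2):
    # project the four corners of image A with both homographies and compare,
    # scaling by min(w2, h2) / 480 as in the benchmark above
    corners = np.array(
        [[0, 0, 1], [0, h1 - 1, 1], [w1 - 1, 0, 1], [w1 - 1, h1 - 1, 1]]
    )
    gt = corners @ H_gt.T
    gt = gt[:, :2] / gt[:, 2:]
    pred = corners @ H_pred.T
    pred = pred[:, :2] / pred[:, 2:]
    return np.mean(np.linalg.norm(gt - pred, axis=1)) / (min(w2, h2) / 480.0)
```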
imcui/third_party/dad/dad/benchmarks/megadepth.py ADDED
@@ -0,0 +1,219 @@
1
+ from typing import Literal, Optional
2
+
3
+ import numpy as np
4
+ from PIL import Image
5
+ from tqdm import tqdm
6
+
7
+ from dad.types import Detector, Matcher, Benchmark
8
+ from dad.utils import (
9
+ compute_pose_error,
10
+ compute_relative_pose,
11
+ estimate_pose_essential,
12
+ estimate_pose_fundamental,
13
+ )
14
+
15
+
16
+ class MegaDepthPoseEstimationBenchmark(Benchmark):
17
+ def __init__(
18
+ self,
19
+ data_root="data/megadepth",
20
+ sample_every=1,
21
+ num_ransac_runs=5,
22
+ num_keypoints: Optional[list[int]] = None,
23
+ ) -> None:
24
+ super().__init__(
25
+ data_root=data_root,
26
+ num_keypoints=num_keypoints,
27
+ sample_every=sample_every,
28
+ num_ransac_runs=num_ransac_runs,
29
+ thresholds=[5, 10, 20],
30
+ )
31
+ self.sample_every = sample_every
32
+ self.topleft = 0.5
33
+ self._post_init()
34
+ self.model: Literal["fundamental", "essential"]
35
+ self.scene_names: list[str]
36
+ self.benchmark_name: str
37
+
38
+ def _post_init(self):
39
+ raise NotImplementedError(
40
+ "Add scene names and benchmark name in derived class _post_init"
41
+ )
42
+
43
+ def benchmark(
44
+ self,
45
+ detector: Detector,
46
+ matcher: Matcher,
47
+ ):
48
+ self.scenes = [
49
+ np.load(f"{self.data_root}/{scene}", allow_pickle=True)
50
+ for scene in self.scene_names
51
+ ]
52
+
53
+ data_root = self.data_root
54
+ tot_e_pose = []
55
+ n_matches = []
56
+ for scene_ind in range(len(self.scenes)):
57
+ scene = self.scenes[scene_ind]
58
+ pairs = scene["pair_infos"]
59
+ intrinsics = scene["intrinsics"]
60
+ poses = scene["poses"]
61
+ im_paths = scene["image_paths"]
62
+ pair_inds = range(len(pairs))
63
+ for pairind in (
64
+ pbar := tqdm(
65
+ pair_inds[:: self.sample_every],
66
+ desc="Current AUC: ?",
67
+ mininterval=10,
68
+ )
69
+ ):
70
+ idx1, idx2 = pairs[pairind][0]
71
+ K1 = intrinsics[idx1].copy()
72
+ T1 = poses[idx1].copy()
73
+ R1, t1 = T1[:3, :3], T1[:3, 3]
74
+ K2 = intrinsics[idx2].copy()
75
+ T2 = poses[idx2].copy()
76
+ R2, t2 = T2[:3, :3], T2[:3, 3]
77
+ R, t = compute_relative_pose(R1, t1, R2, t2)
78
+ im_A_path = f"{data_root}/{im_paths[idx1]}"
79
+ im_B_path = f"{data_root}/{im_paths[idx2]}"
80
+
81
+ warp, certainty = matcher.match(im_A_path, im_B_path)
82
+ for num_kps in self.num_keypoints:
83
+ keypoints_A = detector.detect_from_path(
84
+ im_A_path,
85
+ num_keypoints=num_kps,
86
+ )["keypoints"][0]
87
+ keypoints_B = detector.detect_from_path(
88
+ im_B_path,
89
+ num_keypoints=num_kps,
90
+ )["keypoints"][0]
91
+ matches = matcher.match_keypoints(
92
+ keypoints_A,
93
+ keypoints_B,
94
+ warp,
95
+ certainty,
96
+ return_tuple=False,
97
+ )
98
+ n_matches.append(matches.shape[0])
99
+ im_A = Image.open(im_A_path)
100
+ w1, h1 = im_A.size
101
+ im_B = Image.open(im_B_path)
102
+ w2, h2 = im_B.size
103
+ kpts1, kpts2 = matcher.to_pixel_coordinates(matches, h1, w1, h2, w2)
104
+ offset = detector.topleft - self.topleft
105
+ kpts1, kpts2 = kpts1 - offset, kpts2 - offset
106
+
107
+ for _ in range(self.num_ransac_runs):
108
+ shuffling = np.random.permutation(np.arange(len(kpts1)))
109
+ kpts1 = kpts1[shuffling]
110
+ kpts2 = kpts2[shuffling]
111
+ threshold = 2.0
112
+ if self.model == "essential":
113
+ R_est, t_est = estimate_pose_essential(
114
+ kpts1.cpu().numpy(),
115
+ kpts2.cpu().numpy(),
116
+ w1,
117
+ h1,
118
+ K1,
119
+ w2,
120
+ h2,
121
+ K2,
122
+ threshold,
123
+ )
124
+ elif self.model == "fundamental":
125
+ R_est, t_est = estimate_pose_fundamental(
126
+ kpts1.cpu().numpy(),
127
+ kpts2.cpu().numpy(),
128
+ w1,
129
+ h1,
130
+ K1,
131
+ w2,
132
+ h2,
133
+ K2,
134
+ threshold,
135
+ )
136
+ T1_to_2_est = np.concatenate((R_est, t_est[:, None]), axis=-1)
137
+ e_t, e_R = compute_pose_error(T1_to_2_est, R, t)
138
+ e_pose = max(e_t, e_R)
139
+ tot_e_pose.append(e_pose)
140
+ pbar.set_description(
141
+ f"Current AUCS: {self.compute_auc(np.array(tot_e_pose))}"
142
+ )
143
+ n_matches = np.array(n_matches)
144
+ print(n_matches.mean(), np.median(n_matches), np.std(n_matches))
145
+ return self.compute_auc(np.array(tot_e_pose))
146
+
147
+
148
+ class Mega1500(MegaDepthPoseEstimationBenchmark):
149
+ def _post_init(self):
150
+ self.scene_names = [
151
+ "0015_0.1_0.3.npz",
152
+ "0015_0.3_0.5.npz",
153
+ "0022_0.1_0.3.npz",
154
+ "0022_0.3_0.5.npz",
155
+ "0022_0.5_0.7.npz",
156
+ ]
157
+ self.benchmark_name = "Mega1500"
158
+ self.model = "essential"
159
+
160
+
161
+ class Mega1500_F(MegaDepthPoseEstimationBenchmark):
162
+ def _post_init(self):
163
+ self.scene_names = [
164
+ "0015_0.1_0.3.npz",
165
+ "0015_0.3_0.5.npz",
166
+ "0022_0.1_0.3.npz",
167
+ "0022_0.3_0.5.npz",
168
+ "0022_0.5_0.7.npz",
169
+ ]
170
+ self.benchmark_name = "Mega1500_F"
171
+ self.model = "fundamental"
172
+
173
+
174
+ class MegaIMCPT(MegaDepthPoseEstimationBenchmark):
175
+ def _post_init(self):
176
+ self.scene_names = [
177
+ "mega_8_scenes_0008_0.1_0.3.npz",
178
+ "mega_8_scenes_0008_0.3_0.5.npz",
179
+ "mega_8_scenes_0019_0.1_0.3.npz",
180
+ "mega_8_scenes_0019_0.3_0.5.npz",
181
+ "mega_8_scenes_0021_0.1_0.3.npz",
182
+ "mega_8_scenes_0021_0.3_0.5.npz",
183
+ "mega_8_scenes_0024_0.1_0.3.npz",
184
+ "mega_8_scenes_0024_0.3_0.5.npz",
185
+ "mega_8_scenes_0025_0.1_0.3.npz",
186
+ "mega_8_scenes_0025_0.3_0.5.npz",
187
+ "mega_8_scenes_0032_0.1_0.3.npz",
188
+ "mega_8_scenes_0032_0.3_0.5.npz",
189
+ "mega_8_scenes_0063_0.1_0.3.npz",
190
+ "mega_8_scenes_0063_0.3_0.5.npz",
191
+ "mega_8_scenes_1589_0.1_0.3.npz",
192
+ "mega_8_scenes_1589_0.3_0.5.npz",
193
+ ]
194
+ self.benchmark_name = "MegaIMCPT"
195
+ self.model = "essential"
196
+
197
+
198
+ class MegaIMCPT_F(MegaDepthPoseEstimationBenchmark):
199
+ def _post_init(self):
200
+ self.scene_names = [
201
+ "mega_8_scenes_0008_0.1_0.3.npz",
202
+ "mega_8_scenes_0008_0.3_0.5.npz",
203
+ "mega_8_scenes_0019_0.1_0.3.npz",
204
+ "mega_8_scenes_0019_0.3_0.5.npz",
205
+ "mega_8_scenes_0021_0.1_0.3.npz",
206
+ "mega_8_scenes_0021_0.3_0.5.npz",
207
+ "mega_8_scenes_0024_0.1_0.3.npz",
208
+ "mega_8_scenes_0024_0.3_0.5.npz",
209
+ "mega_8_scenes_0025_0.1_0.3.npz",
210
+ "mega_8_scenes_0025_0.3_0.5.npz",
211
+ "mega_8_scenes_0032_0.1_0.3.npz",
212
+ "mega_8_scenes_0032_0.3_0.5.npz",
213
+ "mega_8_scenes_0063_0.1_0.3.npz",
214
+ "mega_8_scenes_0063_0.3_0.5.npz",
215
+ "mega_8_scenes_1589_0.1_0.3.npz",
216
+ "mega_8_scenes_1589_0.3_0.5.npz",
217
+ ]
218
+ self.benchmark_name = "MegaIMCPT_F"
219
+ self.model = "fundamental"
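
A hedged sketch of running one of these pose benchmarks end to end. The benchmark and detector APIs are taken from the code above; the matcher is left to the caller, since any `dad.types.Matcher` exposing `match`, `match_keypoints` and `to_pixel_coordinates` (for example a RoMa-based one, as in this commit) will do, and the MegaDepth `.npz` files are assumed to live under `data/megadepth`:

```python
# Sketch only: benchmark/detector API from the classes above, matcher supplied by caller.
from dad.benchmarks import Mega1500
from dad.detectors import load_detector_by_name
from dad.types import Matcher


def run_mega1500(matcher: Matcher, data_root: str = "data/megadepth"):
    detector = load_detector_by_name("DaD", resize=1024)
    bench = Mega1500(data_root=data_root, num_keypoints=[512, 2048])
    # returns AUC at the 5/10/20 degree pose-error thresholds
    return bench.benchmark(detector, matcher)
```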
imcui/third_party/dad/dad/benchmarks/num_inliers.py ADDED
@@ -0,0 +1,106 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from tqdm import tqdm
4
+
5
+ from dad.types import Detector
6
+ from dad.utils import get_gt_warp, to_best_device
7
+
8
+
9
+ class NumInliersBenchmark:
10
+ def __init__(
11
+ self,
12
+ dataset,
13
+ num_samples=1000,
14
+ batch_size=8,
15
+ num_keypoints=512,
16
+ **kwargs,
17
+ ) -> None:
18
+ sampler = torch.utils.data.WeightedRandomSampler(
19
+ torch.ones(len(dataset)), replacement=False, num_samples=num_samples
20
+ )
21
+ dataloader = torch.utils.data.DataLoader(
22
+ dataset, batch_size=batch_size, num_workers=batch_size, sampler=sampler
23
+ )
24
+ self.dataloader = dataloader
25
+ self.tracked_metrics = {}
26
+ self.batch_size = batch_size
27
+ self.N = len(dataloader)
28
+ self.num_keypoints = num_keypoints
29
+
30
+ def compute_batch_metrics(self, outputs, batch):
31
+ kpts_A, kpts_B = outputs["keypoints_A"], outputs["keypoints_B"]
32
+ B, K, H, W = batch["im_A"].shape
33
+ gt_warp_A_to_B, valid_mask_A_to_B = get_gt_warp(
34
+ batch["im_A_depth"],
35
+ batch["im_B_depth"],
36
+ batch["T_1to2"],
37
+ batch["K1"],
38
+ batch["K2"],
39
+ H=H,
40
+ W=W,
41
+ )
42
+ kpts_A_to_B = F.grid_sample(
43
+ gt_warp_A_to_B[..., 2:].float().permute(0, 3, 1, 2),
44
+ kpts_A[..., None, :],
45
+ align_corners=False,
46
+ mode="bilinear",
47
+ )[..., 0].mT
48
+ legit_A_to_B = F.grid_sample(
49
+ valid_mask_A_to_B.reshape(B, 1, H, W),
50
+ kpts_A[..., None, :],
51
+ align_corners=False,
52
+ mode="bilinear",
53
+ )[..., 0, :, 0]
54
+ dists = (
55
+ torch.cdist(kpts_A_to_B, kpts_B).min(dim=-1).values[legit_A_to_B > 0.0]
56
+ ).float()
57
+ if legit_A_to_B.sum() == 0:
58
+ return
59
+ percent_inliers_at_1 = (dists < 0.02).float().mean()
60
+ percent_inliers_at_05 = (dists < 0.01).float().mean()
61
+ percent_inliers_at_025 = (dists < 0.005).float().mean()
62
+ percent_inliers_at_01 = (dists < 0.002).float().mean()
63
+ percent_inliers_at_005 = (dists < 0.001).float().mean()
64
+
65
+ self.tracked_metrics["percent_inliers_at_1"] = (
66
+ self.tracked_metrics.get("percent_inliers_at_1", 0)
67
+ + 1 / self.N * percent_inliers_at_1
68
+ )
69
+ self.tracked_metrics["percent_inliers_at_05"] = (
70
+ self.tracked_metrics.get("percent_inliers_at_05", 0)
71
+ + 1 / self.N * percent_inliers_at_05
72
+ )
73
+ self.tracked_metrics["percent_inliers_at_025"] = (
74
+ self.tracked_metrics.get("percent_inliers_at_025", 0)
75
+ + 1 / self.N * percent_inliers_at_025
76
+ )
77
+ self.tracked_metrics["percent_inliers_at_01"] = (
78
+ self.tracked_metrics.get("percent_inliers_at_01", 0)
79
+ + 1 / self.N * percent_inliers_at_01
80
+ )
81
+ self.tracked_metrics["percent_inliers_at_005"] = (
82
+ self.tracked_metrics.get("percent_inliers_at_005", 0)
83
+ + 1 / self.N * percent_inliers_at_005
84
+ )
85
+
86
+ def benchmark(self, detector: Detector):
87
+ self.tracked_metrics = {}
88
+
89
+ print("Evaluating percent inliers...")
90
+ for idx, batch in enumerate(tqdm(self.dataloader, mininterval=10.0)):
91
+ batch = to_best_device(batch)
92
+ outputs = detector.detect(batch, num_keypoints=self.num_keypoints)
93
+ if isinstance(outputs["keypoints"], (tuple, list)):
+ # detector returned per-image keypoint lists: split in half and stack
+ half = len(outputs["keypoints"]) // 2
+ keypoints_A = torch.stack(outputs["keypoints"][:half])
+ keypoints_B = torch.stack(outputs["keypoints"][half:])
+ else:
+ keypoints_A, keypoints_B = outputs["keypoints"].chunk(2)
99
+ outputs = {"keypoints_A": keypoints_A, "keypoints_B": keypoints_B}
100
+ self.compute_batch_metrics(outputs, batch)
101
+ [
102
+ print(name, metric.item() * self.N / (idx + 1))
103
+ for name, metric in self.tracked_metrics.items()
104
+ if "percent" in name
105
+ ]
106
+ return self.tracked_metrics
imcui/third_party/dad/dad/benchmarks/scannet.py ADDED
@@ -0,0 +1,163 @@
1
+ import os.path as osp
2
+ from typing import Literal, Optional
3
+
4
+ import numpy as np
5
+ import torch
6
+ from PIL import Image
7
+ from tqdm import tqdm
8
+
9
+ from dad.types import Detector, Matcher, Benchmark
10
+ from dad.utils import (
11
+ compute_pose_error,
12
+ estimate_pose_essential,
13
+ estimate_pose_fundamental,
14
+ )
15
+
16
+
17
+ class ScanNetBenchmark(Benchmark):
18
+ def __init__(
19
+ self,
20
+ sample_every: int = 1,
21
+ num_ransac_runs=5,
22
+ data_root: str = "data/scannet",
23
+ num_keypoints: Optional[list[int]] = None,
24
+ ) -> None:
25
+ super().__init__(
26
+ data_root=data_root,
27
+ num_keypoints=num_keypoints,
28
+ sample_every=sample_every,
29
+ num_ransac_runs=num_ransac_runs,
30
+ thresholds=[5, 10, 20],
31
+ )
32
+ self.sample_every = sample_every
33
+ self.topleft = 0.0
34
+ self._post_init()
35
+ self.model: Literal["fundamental", "essential"]
36
+ self.test_pairs: str
37
+ self.benchmark_name: str
38
+
39
+ def _post_init(self):
40
+ # set self.test_pairs, self.benchmark_name and self.model here
41
+ raise NotImplementedError("")
42
+
43
+ @torch.no_grad()
44
+ def benchmark(self, matcher: Matcher, detector: Detector):
45
+ tmp = np.load(self.test_pairs)
46
+ pairs, rel_pose = tmp["name"], tmp["rel_pose"]
47
+ tot_e_pose = []
48
+ # pair_inds = np.random.choice(range(len(pairs)), size=len(pairs), replace=False)
49
+ for pairind in tqdm(
50
+ range(0, len(pairs), self.sample_every), smoothing=0.9, mininterval=10
51
+ ):
52
+ scene = pairs[pairind]
53
+ scene_name = f"scene0{scene[0]}_00"
54
+ im_A_path = osp.join(
55
+ self.data_root,
56
+ "scans_test",
57
+ scene_name,
58
+ "color",
59
+ f"{scene[2]}.jpg",
60
+ )
61
+ im_A = Image.open(im_A_path)
62
+ im_B_path = osp.join(
63
+ self.data_root,
64
+ "scans_test",
65
+ scene_name,
66
+ "color",
67
+ f"{scene[3]}.jpg",
68
+ )
69
+ im_B = Image.open(im_B_path)
70
+ T_gt = rel_pose[pairind].reshape(3, 4)
71
+ R, t = T_gt[:3, :3], T_gt[:3, 3]
72
+ K = np.stack(
73
+ [
74
+ np.array([float(i) for i in r.split()])
75
+ for r in open(
76
+ osp.join(
77
+ self.data_root,
78
+ "scans_test",
79
+ scene_name,
80
+ "intrinsic",
81
+ "intrinsic_color.txt",
82
+ ),
83
+ "r",
84
+ )
85
+ .read()
86
+ .split("\n")
87
+ if r
88
+ ]
89
+ )
90
+ w1, h1 = im_A.size
91
+ w2, h2 = im_B.size
92
+ K1 = K.copy()[:3, :3]
93
+ K2 = K.copy()[:3, :3]
94
+ warp, certainty = matcher.match(im_A_path, im_B_path)
95
+ for num_kps in self.num_keypoints:
96
+ keypoints_A = detector.detect_from_path(
97
+ im_A_path,
98
+ num_keypoints=num_kps,
99
+ )["keypoints"][0]
100
+ keypoints_B = detector.detect_from_path(
101
+ im_B_path,
102
+ num_keypoints=num_kps,
103
+ )["keypoints"][0]
104
+ matches = matcher.match_keypoints(
105
+ keypoints_A,
106
+ keypoints_B,
107
+ warp,
108
+ certainty,
109
+ return_tuple=False,
110
+ )
111
+ kpts1, kpts2 = matcher.to_pixel_coordinates(matches, h1, w1, h2, w2)
112
+
113
+ offset = detector.topleft - self.topleft
114
+ kpts1, kpts2 = kpts1 - offset, kpts2 - offset
115
+
116
+ for _ in range(self.num_ransac_runs):
117
+ shuffling = np.random.permutation(np.arange(len(kpts1)))
118
+ kpts1 = kpts1[shuffling]
119
+ kpts2 = kpts2[shuffling]
120
+ threshold = 2.0
121
+ if self.model == "essential":
122
+ R_est, t_est = estimate_pose_essential(
123
+ kpts1.cpu().numpy(),
124
+ kpts2.cpu().numpy(),
125
+ w1,
126
+ h1,
127
+ K1,
128
+ w2,
129
+ h2,
130
+ K2,
131
+ threshold,
132
+ )
133
+ elif self.model == "fundamental":
134
+ R_est, t_est = estimate_pose_fundamental(
135
+ kpts1.cpu().numpy(),
136
+ kpts2.cpu().numpy(),
137
+ w1,
138
+ h1,
139
+ K1,
140
+ w2,
141
+ h2,
142
+ K2,
143
+ threshold,
144
+ )
145
+ T1_to_2_est = np.concatenate((R_est, t_est[:, None]), axis=-1)
146
+ e_t, e_R = compute_pose_error(T1_to_2_est, R, t)
147
+ e_pose = max(e_t, e_R)
148
+ tot_e_pose.append(e_pose)
149
+ return self.compute_auc(np.array(tot_e_pose))
150
+
151
+
152
+ class ScanNet1500(ScanNetBenchmark):
153
+ def _post_init(self):
154
+ self.test_pairs = osp.join(self.data_root, "test.npz")
155
+ self.benchmark_name = "ScanNet1500"
156
+ self.model = "essential"
157
+
158
+
159
+ class ScanNet1500_F(ScanNetBenchmark):
160
+ def _post_init(self):
161
+ self.test_pairs = osp.join(self.data_root, "test.npz")
162
+ self.benchmark_name = "ScanNet1500_F"
163
+ self.model = "fundamental"
imcui/third_party/dad/dad/checkpoint.py ADDED
@@ -0,0 +1,61 @@
1
+ import torch
2
+ from torch.nn.parallel.data_parallel import DataParallel
3
+ from torch.nn.parallel.distributed import DistributedDataParallel
4
+ import gc
5
+ from pathlib import Path
6
+ import dad
7
+ from dad.types import Detector
8
+
9
+ class CheckPoint:
10
+ def __init__(self, dir):
11
+ self.dir = Path(dir)
12
+ self.dir.mkdir(parents=True, exist_ok=True)
13
+
14
+ def save(
15
+ self,
16
+ model: Detector,
17
+ optimizer,
18
+ lr_scheduler,
19
+ n,
20
+ ):
21
+ assert model is not None
22
+ if isinstance(model, (DataParallel, DistributedDataParallel)):
23
+ model = model.module
24
+ states = {
25
+ "model": model.state_dict(),
26
+ "n": n,
27
+ "optimizer": optimizer.state_dict(),
28
+ "lr_scheduler": lr_scheduler.state_dict(),
29
+ }
30
+ torch.save(states, self.dir / "model_latest.pth")
31
+ dad.logger.info(f"Saved states {list(states.keys())}, at step {n}")
32
+
33
+ def load(
34
+ self,
35
+ model: Detector,
36
+ optimizer,
37
+ lr_scheduler,
38
+ n,
39
+ ):
40
+ if not (self.dir / "model_latest.pth").exists():
41
+ return model, optimizer, lr_scheduler, n
42
+
43
+ states = torch.load(self.dir / "model_latest.pth")
44
+ if "model" in states:
45
+ model.load_state_dict(states["model"])
46
+ if "n" in states:
47
+ n = states["n"] if states["n"] else n
48
+ if "optimizer" in states:
49
+ try:
50
+ optimizer.load_state_dict(states["optimizer"])
51
+ except Exception as e:
52
+ dad.logger.warning(
53
+ f"Failed to load states for optimizer, with error {e}"
54
+ )
55
+ if "lr_scheduler" in states:
56
+ lr_scheduler.load_state_dict(states["lr_scheduler"])
57
+ dad.logger.info(f"Loaded states {list(states.keys())}, at step {n}")
58
+ del states
59
+ gc.collect()
60
+ torch.cuda.empty_cache()
61
+ return model, optimizer, lr_scheduler, n
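
A minimal sketch of how `CheckPoint` is presumably wired into a training loop. The optimizer and scheduler choices here are illustrative assumptions; only the `CheckPoint` API comes from the code above:

```python
# Illustrative only: resume-or-start logic around CheckPoint.
import torch

from dad.checkpoint import CheckPoint
from dad.detectors import load_DaD

model = load_DaD(pretrained=False)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100_000)
checkpointer = CheckPoint("workspace/checkpoints")  # placeholder directory

# resume from model_latest.pth if it exists, otherwise start at step 0
model, optimizer, lr_scheduler, step = checkpointer.load(
    model, optimizer, lr_scheduler, n=0
)
# ... training steps would go here ...
checkpointer.save(model, optimizer, lr_scheduler, n=step)
```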
imcui/third_party/dad/dad/datasets/__init__.py ADDED
File without changes
imcui/third_party/dad/dad/datasets/megadepth.py ADDED
@@ -0,0 +1,312 @@
1
+ import os
2
+ from PIL import Image
3
+ import h5py
4
+ import math
5
+ import numpy as np
6
+ import torch
7
+ import torchvision.transforms.functional as tvf
8
+ from tqdm import tqdm
9
+
10
+ import dad
11
+ from dad.augs import (
12
+ get_tuple_transform_ops,
13
+ get_depth_tuple_transform_ops,
14
+ )
15
+ from torch.utils.data import ConcatDataset
16
+
17
+
18
+ class MegadepthScene:
19
+ def __init__(
20
+ self,
21
+ data_root,
22
+ scene_info,
23
+ scene_name=None,
24
+ min_overlap=0.0,
25
+ max_overlap=1.0,
26
+ image_size=640,
27
+ normalize=True,
28
+ shake_t=32,
29
+ rot_360=False,
30
+ max_num_pairs=100_000,
31
+ ) -> None:
32
+ self.data_root = data_root
33
+ self.scene_name = (
34
+ os.path.splitext(scene_name)[0] + f"_{min_overlap}_{max_overlap}"
35
+ )
36
+ self.image_paths = scene_info["image_paths"]
37
+ self.depth_paths = scene_info["depth_paths"]
38
+ self.intrinsics = scene_info["intrinsics"]
39
+ self.poses = scene_info["poses"]
40
+ self.pairs = scene_info["pairs"]
41
+ self.overlaps = scene_info["overlaps"]
42
+ threshold = (self.overlaps > min_overlap) & (self.overlaps < max_overlap)
43
+ self.pairs = self.pairs[threshold]
44
+ self.overlaps = self.overlaps[threshold]
45
+ if len(self.pairs) > max_num_pairs:
46
+ pairinds = np.random.choice(
47
+ np.arange(0, len(self.pairs)), max_num_pairs, replace=False
48
+ )
49
+ self.pairs = self.pairs[pairinds]
50
+ self.overlaps = self.overlaps[pairinds]
51
+ self.im_transform_ops = get_tuple_transform_ops(
52
+ resize=(image_size, image_size),
53
+ normalize=normalize,
54
+ )
55
+ self.depth_transform_ops = get_depth_tuple_transform_ops(
56
+ resize=(image_size, image_size), normalize=False
57
+ )
58
+ self.image_size = image_size
59
+ self.shake_t = shake_t
60
+ self.rot_360 = rot_360
61
+
62
+ def load_im(self, im_B, crop=None):
63
+ im = Image.open(im_B)
64
+ return im
65
+
66
+ def rot_360_deg(self, im, depth, K, angle):
67
+ C, H, W = im.shape
68
+ im = tvf.rotate(im, angle, expand=True)
69
+ depth = tvf.rotate(depth, angle, expand=True)
70
+ radians = angle * math.pi / 180
71
+ rot_mat = torch.tensor(
72
+ [
73
+ [math.cos(radians), math.sin(radians), 0],
74
+ [-math.sin(radians), math.cos(radians), 0],
75
+ [0, 0, 1.0],
76
+ ]
77
+ ).to(K.device)
78
+ t_mat = torch.tensor([[1, 0, W / 2], [0, 1, H / 2], [0, 0, 1.0]]).to(K.device)
79
+ neg_t_mat = torch.tensor([[1, 0, -W / 2], [0, 1, -H / 2], [0, 0, 1.0]]).to(
80
+ K.device
81
+ )
82
+ transform = t_mat @ rot_mat @ neg_t_mat
83
+ K = transform @ K
84
+ return im, depth, K, transform
85
+
86
+ def load_depth(self, depth_ref, crop=None):
87
+ depth = np.array(h5py.File(depth_ref, "r")["depth"])
88
+ return torch.from_numpy(depth)
89
+
90
+ def __len__(self):
91
+ return len(self.pairs)
92
+
93
+ def scale_intrinsic(self, K, wi, hi):
94
+ sx, sy = self.image_size / wi, self.image_size / hi
95
+ sK = torch.tensor([[sx, 0, 0], [0, sy, 0], [0, 0, 1]])
96
+ return sK @ K
97
+
98
+ def rand_shake(self, *things):
99
+ t = np.random.choice(range(-self.shake_t, self.shake_t + 1), size=(2))
100
+ return [
101
+ tvf.affine(thing, angle=0.0, translate=list(t), scale=1.0, shear=[0.0, 0.0])
102
+ for thing in things
103
+ ], t
104
+
105
+ def __getitem__(self, pair_idx):
106
+ try:
107
+ # read intrinsics of original size
108
+ idx1, idx2 = self.pairs[pair_idx]
109
+ K1 = torch.tensor(self.intrinsics[idx1].copy(), dtype=torch.float).reshape(
110
+ 3, 3
111
+ )
112
+ K2 = torch.tensor(self.intrinsics[idx2].copy(), dtype=torch.float).reshape(
113
+ 3, 3
114
+ )
115
+
116
+ # read and compute relative poses
117
+ T1 = self.poses[idx1]
118
+ T2 = self.poses[idx2]
119
+ T_1to2 = torch.tensor(np.matmul(T2, np.linalg.inv(T1)), dtype=torch.float)[
120
+ :4, :4
121
+ ] # (4, 4)
122
+
123
+ # Load positive pair data
124
+ im_A, im_B = self.image_paths[idx1], self.image_paths[idx2]
125
+ depth1, depth2 = self.depth_paths[idx1], self.depth_paths[idx2]
126
+ im_A_ref = os.path.join(self.data_root, im_A)
127
+ im_B_ref = os.path.join(self.data_root, im_B)
128
+ depth_A_ref = os.path.join(self.data_root, depth1)
129
+ depth_B_ref = os.path.join(self.data_root, depth2)
130
+ im_A: Image.Image = self.load_im(im_A_ref)
131
+ im_B: Image.Image = self.load_im(im_B_ref)
132
+ depth_A = self.load_depth(depth_A_ref)
133
+ depth_B = self.load_depth(depth_B_ref)
134
+
135
+ # Recompute camera intrinsic matrix due to the resize
136
+ W_A, H_A = im_A.width, im_A.height
137
+ W_B, H_B = im_B.width, im_B.height
138
+
139
+ K1 = self.scale_intrinsic(K1, W_A, H_A)
140
+ K2 = self.scale_intrinsic(K2, W_B, H_B)
141
+
142
+ # Process images
143
+ im_A, im_B = self.im_transform_ops((im_A, im_B))
144
+ depth_A, depth_B = self.depth_transform_ops(
145
+ (depth_A[None, None], depth_B[None, None])
146
+ )
147
+ [im_A, depth_A], t_A = self.rand_shake(im_A, depth_A)
148
+ [im_B, depth_B], t_B = self.rand_shake(im_B, depth_B)
149
+
150
+ K1[:2, 2] += t_A
151
+ K2[:2, 2] += t_B
152
+
153
+ if self.rot_360:
154
+ angle_A = np.random.choice([-90, 0, 90, 180])
155
+ angle_B = np.random.choice([-90, 0, 90, 180])
156
+ angle_A, angle_B = int(angle_A), int(angle_B)
157
+ im_A, depth_A, K1, _ = self.rot_360_deg(
158
+ im_A, depth_A, K1, angle=angle_A
159
+ )
160
+ im_B, depth_B, K2, _ = self.rot_360_deg(
161
+ im_B, depth_B, K2, angle=angle_B
162
+ )
163
+ else:
164
+ angle_A = 0
165
+ angle_B = 0
166
+ data_dict = {
167
+ "im_A": im_A,
168
+ "im_A_identifier": self.image_paths[idx1]
169
+ .split("/")[-1]
170
+ .split(".jpg")[0],
171
+ "im_B": im_B,
172
+ "im_B_identifier": self.image_paths[idx2]
173
+ .split("/")[-1]
174
+ .split(".jpg")[0],
175
+ "im_A_depth": depth_A[0, 0],
176
+ "im_B_depth": depth_B[0, 0],
177
+ "pose_A": T1,
178
+ "pose_B": T2,
179
+ "K1": K1,
180
+ "K2": K2,
181
+ "T_1to2": T_1to2,
182
+ "im_A_path": im_A_ref,
183
+ "im_B_path": im_B_ref,
184
+ "angle_A": angle_A,
185
+ "angle_B": angle_B,
186
+ }
187
+ except Exception as e:
188
+ dad.logger.warning(e)
189
+ dad.logger.warning(f"Failed to load image pair {self.pairs[pair_idx]}")
190
+ dad.logger.warning("Loading a random pair in scene instead")
191
+ rand_ind = np.random.choice(range(len(self)))
192
+ return self[rand_ind]
193
+ return data_dict
194
+
195
+
196
+ class MegadepthBuilder:
197
+ def __init__(self, data_root, loftr_ignore=True, imc21_ignore=True) -> None:
198
+ self.data_root = data_root
199
+ self.scene_info_root = os.path.join(data_root, "prep_scene_info")
200
+ self.all_scenes = os.listdir(self.scene_info_root)
201
+ self.test_scenes = ["0017.npy", "0004.npy", "0048.npy", "0013.npy"]
202
+ # LoFTR did the D2-net preprocessing differently than we did and got more ignore scenes, can optionally ignore those
203
+ self.loftr_ignore_scenes = set(
204
+ [
205
+ "0121.npy",
206
+ "0133.npy",
207
+ "0168.npy",
208
+ "0178.npy",
209
+ "0229.npy",
210
+ "0349.npy",
211
+ "0412.npy",
212
+ "0430.npy",
213
+ "0443.npy",
214
+ "1001.npy",
215
+ "5014.npy",
216
+ "5015.npy",
217
+ "5016.npy",
218
+ ]
219
+ )
220
+ self.imc21_scenes = set(
221
+ [
222
+ "0008.npy",
223
+ "0019.npy",
224
+ "0021.npy",
225
+ "0024.npy",
226
+ "0025.npy",
227
+ "0032.npy",
228
+ "0063.npy",
229
+ "1589.npy",
230
+ ]
231
+ )
232
+ self.test_scenes_loftr = ["0015.npy", "0022.npy"]
233
+ self.loftr_ignore = loftr_ignore
234
+ self.imc21_ignore = imc21_ignore
235
+
236
+ def build_scenes(self, split, **kwargs):
237
+ if split == "train":
238
+ scene_names = set(self.all_scenes) - set(self.test_scenes)
239
+ elif split == "train_loftr":
240
+ scene_names = set(self.all_scenes) - set(self.test_scenes_loftr)
241
+ elif split == "test":
242
+ scene_names = self.test_scenes
243
+ elif split == "test_loftr":
244
+ scene_names = self.test_scenes_loftr
245
+ elif split == "all_scenes":
246
+ scene_names = self.all_scenes
247
+ elif split == "custom":
248
+ scene_names = kwargs.pop("scene_names")  # assumption: the custom split receives scene_names via kwargs
249
+ else:
250
+ raise ValueError(f"Split {split} not available")
251
+ scenes = []
252
+ for scene_name in tqdm(scene_names):
253
+ if self.loftr_ignore and scene_name in self.loftr_ignore_scenes:
254
+ continue
255
+ if self.imc21_ignore and scene_name in self.imc21_scenes:
256
+ continue
257
+ if ".npy" not in scene_name:
258
+ continue
259
+ scene_info = np.load(
260
+ os.path.join(self.scene_info_root, scene_name), allow_pickle=True
261
+ ).item()
262
+
263
+ scenes.append(
264
+ MegadepthScene(
265
+ self.data_root,
266
+ scene_info,
267
+ scene_name=scene_name,
268
+ **kwargs,
269
+ )
270
+ )
271
+ return scenes
272
+
273
+ def weight_scenes(self, concat_dataset, alpha=0.5):
274
+ ns = []
275
+ for d in concat_dataset.datasets:
276
+ ns.append(len(d))
277
+ ws = torch.cat([torch.ones(n) / n**alpha for n in ns])
278
+ return ws
279
+
280
+ def dedode_train_split(self, **kwargs):
281
+ megadepth_train1 = self.build_scenes(
282
+ split="train_loftr", min_overlap=0.01, **kwargs
283
+ )
284
+ megadepth_train2 = self.build_scenes(
285
+ split="train_loftr", min_overlap=0.35, **kwargs
286
+ )
287
+
288
+ megadepth_train = ConcatDataset(megadepth_train1 + megadepth_train2)
289
+ return megadepth_train
290
+
291
+ def hard_train_split(self, **kwargs):
292
+ megadepth_train = self.build_scenes(
293
+ split="train_loftr", min_overlap=0.01, **kwargs
294
+ )
295
+ megadepth_train = ConcatDataset(megadepth_train)
296
+ return megadepth_train
297
+
298
+ def easy_train_split(self, **kwargs):
299
+ megadepth_train = self.build_scenes(
300
+ split="train_loftr", min_overlap=0.35, **kwargs
301
+ )
302
+ megadepth_train = ConcatDataset(megadepth_train)
303
+ return megadepth_train
304
+
305
+ def dedode_test_split(self, **kwargs):
306
+ megadepth_test = self.build_scenes(
307
+ split="test_loftr",
308
+ min_overlap=0.01,
309
+ **kwargs,
310
+ )
311
+ megadepth_test = ConcatDataset(megadepth_test)
312
+ return megadepth_test
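
A sketch of how `MegadepthBuilder` might be used to assemble a training set and sample scenes with the weights from `weight_scenes`; the dataloader wiring and the `alpha` value are assumptions, only the builder API is from the code above:

```python
# Sketch, not from the diff: DeDoDe-style MegaDepth train split with scene-weighted sampling.
import torch

from dad.datasets.megadepth import MegadepthBuilder

builder = MegadepthBuilder(data_root="data/megadepth", loftr_ignore=True, imc21_ignore=True)
train_set = builder.dedode_train_split(image_size=640, shake_t=32)
weights = builder.weight_scenes(train_set, alpha=0.75)  # alpha is an assumed value

sampler = torch.utils.data.WeightedRandomSampler(
    weights, num_samples=8 * 1000, replacement=False
)
loader = torch.utils.data.DataLoader(train_set, batch_size=8, sampler=sampler)
```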
imcui/third_party/dad/dad/detectors/__init__.py ADDED
@@ -0,0 +1,50 @@
1
+ from .dedode_detector import load_DaD as load_DaD
2
+ from .dedode_detector import load_DaDDark as load_DaDDark
3
+ from .dedode_detector import load_DaDLight as load_DaDLight
4
+ from .dedode_detector import dedode_detector_S as dedode_detector_S
5
+ from .dedode_detector import dedode_detector_B as dedode_detector_B
6
+ from .dedode_detector import dedode_detector_L as dedode_detector_L
7
+ from .dedode_detector import load_dedode_v2 as load_dedode_v2
8
+
9
+
10
+ lg_detectors = ["ALIKED", "ALIKEDROT", "SIFT", "DISK", "SuperPoint", "ReinforcedFP"]
11
+ other_detectors = ["HesAff", "HarrisAff", "REKD"]
12
+ dedode_detectors = [
13
+ "DeDoDe-v2",
14
+ "DaD",
15
+ "DaDLight",
16
+ "DaDDark",
17
+ ]
18
+ all_detectors = lg_detectors + dedode_detectors + other_detectors
19
+
20
+
21
+ def load_detector_by_name(detector_name, *, resize=1024, weights_path=None):
22
+ if detector_name == "DaD":
23
+ detector = load_DaD(resize=resize, weights_path=weights_path)
24
+ elif detector_name == "DaDLight":
25
+ detector = load_DaDLight(resize=resize, weights_path=weights_path)
26
+ elif detector_name == "DaDDark":
27
+ detector = load_DaDDark(resize=resize, weights_path=weights_path)
28
+ elif detector_name == "DeDoDe-v2":
29
+ detector = load_dedode_v2()
30
+ elif detector_name in lg_detectors:
31
+ from .third_party import lightglue, LightGlueDetector
32
+
33
+ detector = LightGlueDetector(
34
+ getattr(lightglue, detector_name), detection_threshold=0, resize=resize
35
+ )
36
+ elif detector_name == "HesAff":
37
+ from .third_party import HesAff
38
+
39
+ detector = HesAff()
40
+ elif detector_name == "HarrisAff":
41
+ from .third_party import HarrisAff
42
+
43
+ detector = HarrisAff()
44
+ elif detector_name == "REKD":
45
+ from .third_party import load_REKD
46
+
47
+ detector = load_REKD(resize=resize)
48
+ else:
49
+ raise ValueError(f"Couldn't find detector with detector name {detector_name}")
50
+ return detector
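
A quick usage sketch for the factory above. `detect_from_path` comes from the `dad.types.Detector` base class, as used throughout the benchmarks; the image path is a placeholder:

```python
# Example use of the detector factory above; "assets/im_A.jpg" is a placeholder path.
from dad.detectors import load_detector_by_name

detector = load_detector_by_name("DaD", resize=1024)
detections = detector.detect_from_path("assets/im_A.jpg", num_keypoints=512)
keypoints = detections["keypoints"][0]  # normalized coordinates, as consumed by the benchmarks
```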
imcui/third_party/dad/dad/detectors/dedode_detector.py ADDED
@@ -0,0 +1,559 @@
1
+ import numpy as np
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ import torchvision.models as tvm
7
+ import torchvision.transforms as transforms
8
+ from PIL import Image
9
+ from dad.utils import get_best_device, sample_keypoints, check_not_i16
10
+
11
+ from dad.types import Detector
12
+
13
+
14
+ class DeDoDeDetector(Detector):
15
+ def __init__(
16
+ self,
17
+ *args,
18
+ encoder: nn.Module,
19
+ decoder: nn.Module,
20
+ resize: int,
21
+ nms_size: int,
22
+ subpixel: bool,
23
+ subpixel_temp: float,
24
+ keep_aspect_ratio: bool,
25
+ remove_borders: bool,
26
+ increase_coverage: bool,
27
+ coverage_pow: float,
28
+ coverage_size: int,
29
+ **kwargs,
30
+ ) -> None:
31
+ super().__init__(*args, **kwargs)
32
+ self.normalizer = transforms.Normalize(
33
+ mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
34
+ )
35
+ self.encoder = encoder
36
+ self.decoder = decoder
37
+ self.remove_borders = remove_borders
38
+ self.resize = resize
39
+ self.increase_coverage = increase_coverage
40
+ self.coverage_pow = coverage_pow
41
+ self.coverage_size = coverage_size
42
+ self.nms_size = nms_size
43
+ self.keep_aspect_ratio = keep_aspect_ratio
44
+ self.subpixel = subpixel
45
+ self.subpixel_temp = subpixel_temp
46
+
47
+ @property
48
+ def topleft(self):
49
+ return 0.5
50
+
51
+ def forward_impl(
52
+ self,
53
+ images,
54
+ ):
55
+ features, sizes = self.encoder(images)
56
+ logits = 0
57
+ context = None
58
+ scales = ["8", "4", "2", "1"]
59
+ for idx, (feature_map, scale) in enumerate(zip(reversed(features), scales)):
60
+ delta_logits, context = self.decoder(
61
+ feature_map, context=context, scale=scale
62
+ )
63
+ logits = (
64
+ logits + delta_logits.float()
65
+ ) # ensure float (F.interpolate does not support bf16)
66
+ if idx < len(scales) - 1:
67
+ size = sizes[-(idx + 2)]
68
+ logits = F.interpolate(
69
+ logits, size=size, mode="bicubic", align_corners=False
70
+ )
71
+ context = F.interpolate(
72
+ context.float(), size=size, mode="bilinear", align_corners=False
73
+ )
74
+ return logits.float()
75
+
76
+ def forward(self, batch) -> dict[str, torch.Tensor]:
77
+ # wraps internal forward impl to handle
78
+ # different types of batches etc.
79
+ if "im_A" in batch:
80
+ images = torch.cat((batch["im_A"], batch["im_B"]))
81
+ else:
82
+ images = batch["image"]
83
+ scoremap = self.forward_impl(images)
84
+ return {"scoremap": scoremap}
85
+
86
+ @torch.inference_mode()
87
+ def detect(
88
+ self, batch, *, num_keypoints, return_dense_probs=False
89
+ ) -> dict[str, torch.Tensor]:
90
+ self.train(False)
91
+ scoremap = self.forward(batch)["scoremap"]
92
+ B, K, H, W = scoremap.shape
93
+ dense_probs = (
94
+ scoremap.reshape(B, K * H * W)
95
+ .softmax(dim=-1)
96
+ .reshape(B, K, H * W)
97
+ .sum(dim=1)
98
+ )
99
+ dense_probs = dense_probs.reshape(B, H, W)
100
+ keypoints, confidence = sample_keypoints(
101
+ dense_probs,
102
+ use_nms=True,
103
+ nms_size=self.nms_size,
104
+ sample_topk=True,
105
+ num_samples=num_keypoints,
106
+ return_probs=True,
107
+ increase_coverage=self.increase_coverage,
108
+ remove_borders=self.remove_borders,
109
+ coverage_pow=self.coverage_pow,
110
+ coverage_size=self.coverage_size,
111
+ subpixel=self.subpixel,
112
+ subpixel_temp=self.subpixel_temp,
113
+ scoremap=scoremap.reshape(B, H, W),
114
+ )
115
+ result = {"keypoints": keypoints, "keypoint_probs": confidence}
116
+ if return_dense_probs:
117
+ result["dense_probs"] = dense_probs
118
+ return result
119
+
120
+ def load_image(self, im_path, device=get_best_device()) -> dict[str, torch.Tensor]:
121
+ pil_im = Image.open(im_path)
122
+ check_not_i16(pil_im)
123
+ pil_im = pil_im.convert("RGB")
124
+ if self.keep_aspect_ratio:
125
+ W, H = pil_im.size
126
+ scale = self.resize / max(W, H)
127
+ W = int((scale * W) // 8 * 8)
128
+ H = int((scale * H) // 8 * 8)
129
+ else:
130
+ H, W = self.resize, self.resize
131
+ pil_im = pil_im.resize((W, H))
132
+ standard_im = np.array(pil_im) / 255.0
133
+ return {
134
+ "image": self.normalizer(torch.from_numpy(standard_im).permute(2, 0, 1))
135
+ .float()
136
+ .to(device)[None]
137
+ }
138
+
139
+
140
+ class Decoder(nn.Module):
141
+ def __init__(
142
+ self, layers, *args, super_resolution=False, num_prototypes=1, **kwargs
143
+ ) -> None:
144
+ super().__init__(*args, **kwargs)
145
+ self.layers = layers
146
+ self.scales = self.layers.keys()
147
+ self.super_resolution = super_resolution
148
+ self.num_prototypes = num_prototypes
149
+
150
+ def forward(self, features, context=None, scale=None):
151
+ if context is not None:
152
+ features = torch.cat((features, context), dim=1)
153
+ stuff = self.layers[scale](features)
154
+ logits, context = (
155
+ stuff[:, : self.num_prototypes],
156
+ stuff[:, self.num_prototypes :],
157
+ )
158
+ return logits, context
159
+
160
+
161
+ class ConvRefiner(nn.Module):
162
+ def __init__(
163
+ self,
164
+ in_dim=6,
165
+ hidden_dim=16,
166
+ out_dim=2,
167
+ dw=True,
168
+ kernel_size=5,
169
+ hidden_blocks=5,
170
+ amp=True,
171
+ residual=False,
172
+ amp_dtype=torch.float16,
173
+ ):
174
+ super().__init__()
175
+ self.block1 = self.create_block(
176
+ in_dim,
177
+ hidden_dim,
178
+ dw=False,
179
+ kernel_size=1,
180
+ )
181
+ self.hidden_blocks = nn.Sequential(
182
+ *[
183
+ self.create_block(
184
+ hidden_dim,
185
+ hidden_dim,
186
+ dw=dw,
187
+ kernel_size=kernel_size,
188
+ )
189
+ for hb in range(hidden_blocks)
190
+ ]
191
+ )
192
+ self.hidden_blocks = self.hidden_blocks
193
+ self.out_conv = nn.Conv2d(hidden_dim, out_dim, 1, 1, 0)
194
+ self.amp = amp
195
+ self.amp_dtype = amp_dtype
196
+ self.residual = residual
197
+
198
+ def create_block(
199
+ self,
200
+ in_dim,
201
+ out_dim,
202
+ dw=True,
203
+ kernel_size=5,
204
+ bias=True,
205
+ norm_type=nn.BatchNorm2d,
206
+ ):
207
+ num_groups = 1 if not dw else in_dim
208
+ if dw:
209
+ assert out_dim % in_dim == 0, (
210
+ "outdim must be divisible by indim for depthwise"
211
+ )
212
+ conv1 = nn.Conv2d(
213
+ in_dim,
214
+ out_dim,
215
+ kernel_size=kernel_size,
216
+ stride=1,
217
+ padding=kernel_size // 2,
218
+ groups=num_groups,
219
+ bias=bias,
220
+ )
221
+ norm = (
222
+ norm_type(out_dim)
223
+ if norm_type is nn.BatchNorm2d
224
+ else norm_type(num_channels=out_dim)
225
+ )
226
+ relu = nn.ReLU(inplace=True)
227
+ conv2 = nn.Conv2d(out_dim, out_dim, 1, 1, 0)
228
+ return nn.Sequential(conv1, norm, relu, conv2)
229
+
230
+ def forward(self, feats):
231
+ b, c, hs, ws = feats.shape
232
+ with torch.autocast(device_type=feats.device.type, enabled=self.amp, dtype=self.amp_dtype):
233
+ x0 = self.block1(feats)
234
+ x = self.hidden_blocks(x0)
235
+ if self.residual:
236
+ x = (x + x0) / 1.4
237
+ x = self.out_conv(x)
238
+ return x
239
+
240
+
241
+ class VGG19(nn.Module):
242
+ def __init__(self, amp=False, amp_dtype=torch.float16) -> None:
243
+ super().__init__()
244
+ self.layers = nn.ModuleList(tvm.vgg19_bn().features[:40])
245
+ # Maxpool layers: 6, 13, 26, 39
246
+ self.amp = amp
247
+ self.amp_dtype = amp_dtype
248
+
249
+ def forward(self, x, **kwargs):
250
+ with torch.autocast(device_type=x.device.type, enabled=self.amp, dtype=self.amp_dtype):
251
+ feats = []
252
+ sizes = []
253
+ for layer in self.layers:
254
+ if isinstance(layer, nn.MaxPool2d):
255
+ feats.append(x)
256
+ sizes.append(x.shape[-2:])
257
+ x = layer(x)
258
+ return feats, sizes
259
+
260
+
261
+ class VGG(nn.Module):
262
+ def __init__(self, size="19", amp=False, amp_dtype=torch.float16) -> None:
263
+ super().__init__()
264
+ if size == "11":
265
+ self.layers = nn.ModuleList(tvm.vgg11_bn().features[:22])
266
+ elif size == "13":
267
+ self.layers = nn.ModuleList(tvm.vgg13_bn().features[:28])
268
+ elif size == "19":
269
+ self.layers = nn.ModuleList(tvm.vgg19_bn().features[:40])
270
+ # Maxpool layers: 6, 13, 26, 39
271
+ self.amp = amp
272
+ self.amp_dtype = amp_dtype
273
+
274
+ def forward(self, x, **kwargs):
275
+ with torch.autocast(device_type=x.device.type, enabled=self.amp, dtype=self.amp_dtype):
276
+ feats = []
277
+ sizes = []
278
+ for layer in self.layers:
279
+ if isinstance(layer, nn.MaxPool2d):
280
+ feats.append(x)
281
+ sizes.append(x.shape[-2:])
282
+ x = layer(x)
283
+ return feats, sizes
284
+
285
+
286
+ def dedode_detector_S():
287
+ residual = True
288
+ hidden_blocks = 3
289
+ amp_dtype = torch.float16
290
+ amp = True
291
+ NUM_PROTOTYPES = 1
292
+ conv_refiner = nn.ModuleDict(
293
+ {
294
+ "8": ConvRefiner(
295
+ 512,
296
+ 512,
297
+ 256 + NUM_PROTOTYPES,
298
+ hidden_blocks=hidden_blocks,
299
+ residual=residual,
300
+ amp=amp,
301
+ amp_dtype=amp_dtype,
302
+ ),
303
+ "4": ConvRefiner(
304
+ 256 + 256,
305
+ 256,
306
+ 128 + NUM_PROTOTYPES,
307
+ hidden_blocks=hidden_blocks,
308
+ residual=residual,
309
+ amp=amp,
310
+ amp_dtype=amp_dtype,
311
+ ),
312
+ "2": ConvRefiner(
313
+ 128 + 128,
314
+ 64,
315
+ 32 + NUM_PROTOTYPES,
316
+ hidden_blocks=hidden_blocks,
317
+ residual=residual,
318
+ amp=amp,
319
+ amp_dtype=amp_dtype,
320
+ ),
321
+ "1": ConvRefiner(
322
+ 64 + 32,
323
+ 32,
324
+ 1 + NUM_PROTOTYPES,
325
+ hidden_blocks=hidden_blocks,
326
+ residual=residual,
327
+ amp=amp,
328
+ amp_dtype=amp_dtype,
329
+ ),
330
+ }
331
+ )
332
+ encoder = VGG(size="11", amp=amp, amp_dtype=amp_dtype)
333
+ decoder = Decoder(conv_refiner)
334
+ return encoder, decoder
335
+
336
+
337
+ def dedode_detector_B():
338
+ residual = True
339
+ hidden_blocks = 5
340
+ amp_dtype = torch.float16
341
+ amp = True
342
+ NUM_PROTOTYPES = 1
343
+ conv_refiner = nn.ModuleDict(
344
+ {
345
+ "8": ConvRefiner(
346
+ 512,
347
+ 512,
348
+ 256 + NUM_PROTOTYPES,
349
+ hidden_blocks=hidden_blocks,
350
+ residual=residual,
351
+ amp=amp,
352
+ amp_dtype=amp_dtype,
353
+ ),
354
+ "4": ConvRefiner(
355
+ 256 + 256,
356
+ 256,
357
+ 128 + NUM_PROTOTYPES,
358
+ hidden_blocks=hidden_blocks,
359
+ residual=residual,
360
+ amp=amp,
361
+ amp_dtype=amp_dtype,
362
+ ),
363
+ "2": ConvRefiner(
364
+ 128 + 128,
365
+ 64,
366
+ 32 + NUM_PROTOTYPES,
367
+ hidden_blocks=hidden_blocks,
368
+ residual=residual,
369
+ amp=amp,
370
+ amp_dtype=amp_dtype,
371
+ ),
372
+ "1": ConvRefiner(
373
+ 64 + 32,
374
+ 32,
375
+ 1 + NUM_PROTOTYPES,
376
+ hidden_blocks=hidden_blocks,
377
+ residual=residual,
378
+ amp=amp,
379
+ amp_dtype=amp_dtype,
380
+ ),
381
+ }
382
+ )
383
+ encoder = VGG19(amp=amp, amp_dtype=amp_dtype)
384
+ decoder = Decoder(conv_refiner)
385
+ return encoder, decoder
386
+
387
+
388
+ def dedode_detector_L():
389
+ NUM_PROTOTYPES = 1
390
+ residual = True
391
+ hidden_blocks = 8
392
+ amp_dtype = (
393
+ torch.float16
394
+ ) # torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
395
+ amp = True
396
+ conv_refiner = nn.ModuleDict(
397
+ {
398
+ "8": ConvRefiner(
399
+ 512,
400
+ 512,
401
+ 256 + NUM_PROTOTYPES,
402
+ hidden_blocks=hidden_blocks,
403
+ residual=residual,
404
+ amp=amp,
405
+ amp_dtype=amp_dtype,
406
+ ),
407
+ "4": ConvRefiner(
408
+ 256 + 256,
409
+ 256,
410
+ 128 + NUM_PROTOTYPES,
411
+ hidden_blocks=hidden_blocks,
412
+ residual=residual,
413
+ amp=amp,
414
+ amp_dtype=amp_dtype,
415
+ ),
416
+ "2": ConvRefiner(
417
+ 128 + 128,
418
+ 128,
419
+ 64 + NUM_PROTOTYPES,
420
+ hidden_blocks=hidden_blocks,
421
+ residual=residual,
422
+ amp=amp,
423
+ amp_dtype=amp_dtype,
424
+ ),
425
+ "1": ConvRefiner(
426
+ 64 + 64,
427
+ 64,
428
+ 1 + NUM_PROTOTYPES,
429
+ hidden_blocks=hidden_blocks,
430
+ residual=residual,
431
+ amp=amp,
432
+ amp_dtype=amp_dtype,
433
+ ),
434
+ }
435
+ )
436
+ encoder = VGG19(amp=amp, amp_dtype=amp_dtype)
437
+ decoder = Decoder(conv_refiner)
438
+ return encoder, decoder
439
+
440
+
441
+ class DaD(DeDoDeDetector):
442
+ def __init__(
443
+ self,
444
+ encoder: nn.Module,
445
+ decoder: nn.Module,
446
+ *args,
447
+ resize=1024,
448
+ nms_size=3,
449
+ remove_borders=False,
450
+ increase_coverage=False,
451
+ coverage_pow=None,
452
+ coverage_size=None,
453
+ subpixel=True,
454
+ subpixel_temp=0.5,
455
+ keep_aspect_ratio=True,
456
+ **kwargs,
457
+ ) -> None:
458
+ super().__init__(
459
+ *args,
460
+ encoder=encoder,
461
+ decoder=decoder,
462
+ resize=resize,
463
+ nms_size=nms_size,
464
+ remove_borders=remove_borders,
465
+ increase_coverage=increase_coverage,
466
+ coverage_pow=coverage_pow,
467
+ coverage_size=coverage_size,
468
+ subpixel=subpixel,
469
+ keep_aspect_ratio=keep_aspect_ratio,
470
+ subpixel_temp=subpixel_temp,
471
+ **kwargs,
472
+ )
473
+
474
+
475
+ class DeDoDev2(DeDoDeDetector):
476
+ def __init__(
477
+ self,
478
+ encoder: nn.Module,
479
+ decoder: nn.Module,
480
+ *args,
481
+ resize=784,
482
+ nms_size=3,
483
+ remove_borders=False,
484
+ increase_coverage=True,
485
+ coverage_pow=0.5,
486
+ coverage_size=51,
487
+ subpixel=False,
488
+ subpixel_temp=None,
489
+ keep_aspect_ratio=False,
490
+ **kwargs,
491
+ ) -> None:
492
+ super().__init__(
493
+ *args,
494
+ encoder=encoder,
495
+ decoder=decoder,
496
+ resize=resize,
497
+ nms_size=nms_size,
498
+ remove_borders=remove_borders,
499
+ increase_coverage=increase_coverage,
500
+ coverage_pow=coverage_pow,
501
+ coverage_size=coverage_size,
502
+ subpixel=subpixel,
503
+ keep_aspect_ratio=keep_aspect_ratio,
504
+ subpixel_temp=subpixel_temp,
505
+ **kwargs,
506
+ )
507
+
508
+
509
+ def load_DaD(resize=1024, pretrained=True, weights_path=None) -> DaD:
510
+ if weights_path is None:
511
+ weights_path = (
512
+ "https://github.com/Parskatt/dad/releases/download/v0.1.0/dad.pth"
513
+ )
514
+ device = get_best_device()
515
+ encoder, decoder = dedode_detector_S()
516
+ model = DaD(encoder, decoder, resize=resize).to(device)
517
+ if pretrained:
518
+ weights = torch.hub.load_state_dict_from_url(
519
+ weights_path, weights_only=False, map_location=device
520
+ )
521
+ model.load_state_dict(weights)
522
+ return model
523
+
524
+
525
+ def load_DaDLight(resize=1024, weights_path=None) -> DaD:
526
+ if weights_path is None:
527
+ weights_path = (
528
+ "https://github.com/Parskatt/dad/releases/download/v0.1.0/dad_light.pth"
529
+ )
530
+ return load_DaD(
531
+ resize=resize,
532
+ pretrained=True,
533
+ weights_path=weights_path,
534
+ )
535
+
536
+
537
+ def load_DaDDark(resize=1024, weights_path=None) -> DaD:
538
+ if weights_path is None:
539
+ weights_path = (
540
+ "https://github.com/Parskatt/dad/releases/download/v0.1.0/dad_dark.pth"
541
+ )
542
+ return load_DaD(
543
+ resize=resize,
544
+ pretrained=True,
545
+ weights_path=weights_path,
546
+ )
547
+
548
+
549
+ def load_dedode_v2() -> DeDoDev2:
550
+ device = get_best_device()
551
+ weights = torch.hub.load_state_dict_from_url(
552
+ "https://github.com/Parskatt/DeDoDe/releases/download/v2/dedode_detector_L_v2.pth",
553
+ map_location=device,
554
+ )
555
+
556
+ encoder, decoder = dedode_detector_L()
557
+ model = DeDoDev2(encoder, decoder).to(device)
558
+ model.load_state_dict(weights)
559
+ return model
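
A minimal sketch of calling the DaD detector directly through `load_image`/`detect` instead of `detect_from_path`; the image path is a placeholder and the device follows `get_best_device` via `load_image`'s default:

```python
# Minimal direct use of the loader and detector defined above.
import torch

from dad.detectors.dedode_detector import load_DaD

detector = load_DaD(resize=1024)
batch = detector.load_image("assets/im_A.jpg")  # placeholder path
with torch.inference_mode():
    out = detector.detect(batch, num_keypoints=2048, return_dense_probs=True)

kpts = out["keypoints"][0]        # [N, 2] in normalized image coordinates
probs = out["keypoint_probs"][0]  # per-keypoint confidence
heat = out["dense_probs"]         # [1, H, W] keypoint probability map
```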
imcui/third_party/dad/dad/detectors/third_party/__init__.py ADDED
@@ -0,0 +1,11 @@
1
+ from .lightglue_detector import LightGlueDetector as LightGlueDetector
2
+ from .lightglue import SuperPoint as SuperPoint
3
+ from .lightglue import ReinforcedFP as ReinforcedFP
4
+ from .lightglue import DISK as DISK
5
+ from .lightglue import ALIKED as ALIKED
6
+ from .lightglue import ALIKEDROT as ALIKEDROT
7
+ from .lightglue import SIFT as SIFT
8
+ from .lightglue import DoGHardNet as DoGHardNet
9
+ from .hesaff import HesAff as HesAff
10
+ from .harrisaff import HarrisAff as HarrisAff
11
+ from .rekd.rekd import load_REKD as load_REKD
imcui/third_party/dad/dad/detectors/third_party/harrisaff.py ADDED
@@ -0,0 +1,35 @@
1
+ import numpy as np
2
+ import torch
3
+
4
+ from dad.types import Detector
5
+ import cv2
6
+
7
+ from dad.utils import get_best_device
8
+
9
+
10
+ class HarrisAff(Detector):
11
+ def __init__(self):
12
+ super().__init__()
13
+ self.detector = cv2.xfeatures2d.HarrisLaplaceFeatureDetector_create(
14
+ numOctaves=6, corn_thresh=0.0, DOG_thresh=0.0, maxCorners=8192, num_layers=4
15
+ )
16
+
17
+ @property
18
+ def topleft(self):
19
+ return 0.0
20
+
21
+ def load_image(self, im_path):
22
+ return {"image": cv2.imread(im_path, cv2.IMREAD_GRAYSCALE)}
23
+
24
+ @torch.inference_mode()
25
+ def detect(self, batch, *, num_keypoints, return_dense_probs=False) -> dict[str, torch.Tensor]:
26
+ img = batch["image"]
27
+ H, W = img.shape
28
+ # Detect keypoints
29
+ kps = self.detector.detect(img)
30
+ kps = np.array([kp.pt for kp in kps])[:num_keypoints]
31
+ kps_n = self.to_normalized_coords(torch.from_numpy(kps), H, W)[None]
32
+ detections = {"keypoints": kps_n.to(get_best_device()).float(), "keypoint_probs": None}
33
+ if return_dense_probs:
34
+ detections["dense_probs"] = None
35
+ return detections
imcui/third_party/dad/dad/detectors/third_party/hesaff.py ADDED
@@ -0,0 +1,40 @@
1
+ from PIL import Image
2
+
3
+ import torch
4
+
5
+ from dad.utils import get_best_device
6
+ from dad.types import Detector
7
+
8
+
9
+ class HesAff(Detector):
10
+ def __init__(self):
11
+ raise NotImplementedError("Buggy implementation, don't use.")
12
+ super().__init__()
13
+ import pyhesaff
14
+
15
+ self.params = pyhesaff.get_hesaff_default_params()
16
+
17
+ @property
18
+ def topleft(self):
19
+ return 0.0
20
+
21
+ def load_image(self, im_path):
22
+ # pyhesaff doesn't seem to have a decoupled image loading and detection stage
23
+ # so load_image here is just identity
24
+ return {"image": im_path}
25
+
26
+ def detect(self, batch, *, num_keypoints, return_dense_probs=False):
27
+ import pyhesaff
28
+
29
+ im_path = batch["image"]
30
+ W, H = Image.open(im_path).size
31
+ detections = pyhesaff.detect_feats(im_path)[0][:num_keypoints]
32
+ kps = detections[..., :2]
33
+ kps_n = self.to_normalized_coords(torch.from_numpy(kps), H, W)[None]
34
+ result = {
35
+ "keypoints": kps_n.to(get_best_device()).float(),
36
+ "keypoint_probs": None,
37
+ }
38
+ if return_dense_probs:
39
+ result["dense_probs"] = None
40
+ return result
imcui/third_party/dad/dad/detectors/third_party/lightglue/__init__.py ADDED
@@ -0,0 +1,9 @@
1
+ from .aliked import ALIKED # noqa
2
+ from .aliked import ALIKEDROT as ALIKEDROT # noqa
3
+ from .disk import DISK # noqa
4
+ from .dog_hardnet import DoGHardNet # noqa
5
+ from .lightglue import LightGlue # noqa
6
+ from .sift import SIFT # noqa
7
+ from .superpoint import SuperPoint # noqa
8
+ from .superpoint import ReinforcedFP # noqa
9
+ from .utils import match_pair # noqa
imcui/third_party/dad/dad/detectors/third_party/lightglue/aliked.py ADDED
@@ -0,0 +1,770 @@
1
+ # BSD 3-Clause License
2
+
3
+ # Copyright (c) 2022, Zhao Xiaoming
4
+ # All rights reserved.
5
+
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+
31
+ # Authors:
32
+ # Xiaoming Zhao, Xingming Wu, Weihai Chen, Peter C.Y. Chen, Qingsong Xu, and Zhengguo Li
33
+ # Code from https://github.com/Shiaoming/ALIKED
34
+
35
+ from typing import Callable, Optional
36
+
37
+ import torch
38
+ import torch.nn.functional as F
39
+ import torchvision
40
+ from kornia.color import grayscale_to_rgb
41
+ from torch import nn
42
+ from torch.nn.modules.utils import _pair
43
+ from torchvision.models import resnet
44
+
45
+ from .utils import Extractor
46
+
47
+
48
+ def get_patches(
49
+ tensor: torch.Tensor, required_corners: torch.Tensor, ps: int
50
+ ) -> torch.Tensor:
51
+ c, h, w = tensor.shape
52
+ corner = (required_corners - ps / 2 + 1).long()
53
+ corner[:, 0] = corner[:, 0].clamp(min=0, max=w - 1 - ps)
54
+ corner[:, 1] = corner[:, 1].clamp(min=0, max=h - 1 - ps)
55
+ offset = torch.arange(0, ps)
56
+
57
+ kw = {"indexing": "ij"} if torch.__version__ >= "1.10" else {}
58
+ x, y = torch.meshgrid(offset, offset, **kw)
59
+ patches = torch.stack((x, y)).permute(2, 1, 0).unsqueeze(2)
60
+ patches = patches.to(corner) + corner[None, None]
61
+ pts = patches.reshape(-1, 2)
62
+ sampled = tensor.permute(1, 2, 0)[tuple(pts.T)[::-1]]
63
+ sampled = sampled.reshape(ps, ps, -1, c)
64
+ assert sampled.shape[:3] == patches.shape[:3]
65
+ return sampled.permute(2, 3, 0, 1)
66
+
67
+
68
+ def simple_nms(scores: torch.Tensor, nms_radius: int):
69
+ """Fast Non-maximum suppression to remove nearby points"""
70
+
71
+ zeros = torch.zeros_like(scores)
72
+ max_mask = scores == torch.nn.functional.max_pool2d(
73
+ scores, kernel_size=nms_radius * 2 + 1, stride=1, padding=nms_radius
74
+ )
75
+
76
+ for _ in range(2):
77
+ supp_mask = (
78
+ torch.nn.functional.max_pool2d(
79
+ max_mask.float(),
80
+ kernel_size=nms_radius * 2 + 1,
81
+ stride=1,
82
+ padding=nms_radius,
83
+ )
84
+ > 0
85
+ )
86
+ supp_scores = torch.where(supp_mask, zeros, scores)
87
+ new_max_mask = supp_scores == torch.nn.functional.max_pool2d(
88
+ supp_scores, kernel_size=nms_radius * 2 + 1, stride=1, padding=nms_radius
89
+ )
90
+ max_mask = max_mask | (new_max_mask & (~supp_mask))
91
+ return torch.where(max_mask, scores, zeros)
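# Illustration only (not part of the file above): simple_nms keeps a score only
# where it is the maximum of its (2 * nms_radius + 1) window; the rest become 0.
example_scores = torch.rand(1, 1, 32, 32)
example_kept = simple_nms(example_scores, nms_radius=2)
assert ((example_kept == 0) | (example_kept == example_scores)).all()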
92
+
93
+
94
+ class DKD(nn.Module):
95
+ def __init__(
96
+ self,
97
+ radius: int = 2,
98
+ top_k: int = 0,
99
+ scores_th: float = 0.2,
100
+ n_limit: int = 20000,
101
+ ):
102
+ """
103
+ Args:
104
+ radius: soft detection radius, kernel size is (2 * radius + 1)
105
+ top_k: top_k > 0: return top k keypoints
106
+ scores_th: top_k <= 0 threshold mode:
107
+ scores_th > 0: return keypoints with scores>scores_th
108
+ else: return keypoints with scores > scores.mean()
109
+ n_limit: max number of keypoint in threshold mode
110
+ """
111
+ super().__init__()
112
+ self.radius = radius
113
+ self.top_k = top_k
114
+ self.scores_th = scores_th
115
+ self.n_limit = n_limit
116
+ self.kernel_size = 2 * self.radius + 1
117
+ self.temperature = 0.1 # tuned temperature
118
+ self.unfold = nn.Unfold(kernel_size=self.kernel_size, padding=self.radius)
119
+ # local xy grid
120
+ x = torch.linspace(-self.radius, self.radius, self.kernel_size)
121
+ # (kernel_size*kernel_size) x 2 : (w,h)
122
+ kw = {"indexing": "ij"} if torch.__version__ >= "1.10" else {}
123
+ self.hw_grid = (
124
+ torch.stack(torch.meshgrid([x, x], **kw)).view(2, -1).t()[:, [1, 0]]
125
+ )
126
+
127
+ def forward(
128
+ self,
129
+ scores_map: torch.Tensor,
130
+ sub_pixel: bool = True,
131
+ image_size: Optional[torch.Tensor] = None,
132
+ ):
133
+ """
134
+ :param scores_map: Bx1xHxW
135
+ :param descriptor_map: BxCxHxW
136
+ :param sub_pixel: whether to use sub-pixel keypoint detection
137
+ :return: kpts: list[Nx2,...]; kptscores: list[N,....] normalised position: -1~1
138
+ """
139
+ b, c, h, w = scores_map.shape
140
+ scores_nograd = scores_map.detach()
141
+ nms_scores = simple_nms(scores_nograd, self.radius)
142
+
143
+ # remove border
144
+ nms_scores[:, :, : self.radius, :] = 0
145
+ nms_scores[:, :, :, : self.radius] = 0
146
+ if image_size is not None:
147
+ for i in range(scores_map.shape[0]):
148
+ w, h = image_size[i].long()
149
+ nms_scores[i, :, h.item() - self.radius :, :] = 0
150
+ nms_scores[i, :, :, w.item() - self.radius :] = 0
151
+ else:
152
+ nms_scores[:, :, -self.radius :, :] = 0
153
+ nms_scores[:, :, :, -self.radius :] = 0
154
+
155
+ # detect keypoints without grad
156
+ if self.top_k > 0:
157
+ topk = torch.topk(nms_scores.view(b, -1), self.top_k)
158
+ indices_keypoints = [topk.indices[i] for i in range(b)] # B x top_k
159
+ else:
160
+ if self.scores_th > 0:
161
+ masks = nms_scores > self.scores_th
162
+ if masks.sum() == 0:
163
+ th = scores_nograd.reshape(b, -1).mean(dim=1) # th = self.scores_th
164
+ masks = nms_scores > th.reshape(b, 1, 1, 1)
165
+ else:
166
+ th = scores_nograd.reshape(b, -1).mean(dim=1) # th = self.scores_th
167
+ masks = nms_scores > th.reshape(b, 1, 1, 1)
168
+ masks = masks.reshape(b, -1)
169
+
170
+ indices_keypoints = [] # list, B x (any size)
171
+ scores_view = scores_nograd.reshape(b, -1)
172
+ for mask, scores in zip(masks, scores_view):
173
+ indices = mask.nonzero()[:, 0]
174
+ if len(indices) > self.n_limit:
175
+ kpts_sc = scores[indices]
176
+ sort_idx = kpts_sc.sort(descending=True)[1]
177
+ sel_idx = sort_idx[: self.n_limit]
178
+ indices = indices[sel_idx]
179
+ indices_keypoints.append(indices)
180
+ wh = torch.tensor([w - 1, h - 1], device=scores_nograd.device)
181
+
182
+ keypoints = []
183
+ scoredispersitys = []
184
+ kptscores = []
185
+ if sub_pixel:
186
+ # detect soft keypoints with grad backpropagation
187
+ patches = self.unfold(scores_map) # B x (kernel**2) x (H*W)
188
+ # print(patches.shape)
189
+ self.hw_grid = self.hw_grid.to(scores_map) # to device
190
+ for b_idx in range(b):
191
+ patch = patches[b_idx].t() # (H*W) x (kernel**2)
192
+ indices_kpt = indices_keypoints[
193
+ b_idx
194
+ ] # one dimension vector, say its size is M
195
+ patch_scores = patch[indices_kpt] # M x (kernel**2)
196
+ keypoints_xy_nms = torch.stack(
197
+ [indices_kpt % w, torch.div(indices_kpt, w, rounding_mode="trunc")],
198
+ dim=1,
199
+ ) # Mx2
200
+
201
+ # max is detached to prevent undesired backprop loops in the graph
202
+ max_v = patch_scores.max(dim=1).values.detach()[:, None]
203
+ x_exp = (
204
+ (patch_scores - max_v) / self.temperature
205
+ ).exp() # M * (kernel**2), in [0, 1]
206
+
207
+ # \frac{ \sum{(i,j) \times \exp(x/T)} }{ \sum{\exp(x/T)} }
208
+ xy_residual = (
209
+ x_exp @ self.hw_grid / x_exp.sum(dim=1)[:, None]
210
+ ) # Soft-argmax, Mx2
211
+
212
+ hw_grid_dist2 = (
213
+ torch.norm(
214
+ (self.hw_grid[None, :, :] - xy_residual[:, None, :])
215
+ / self.radius,
216
+ dim=-1,
217
+ )
218
+ ** 2
219
+ )
220
+ scoredispersity = (x_exp * hw_grid_dist2).sum(dim=1) / x_exp.sum(dim=1)
221
+
222
+ # compute result keypoints
223
+ keypoints_xy = keypoints_xy_nms + xy_residual
224
+ keypoints_xy = keypoints_xy / wh * 2 - 1 # (w,h) -> (-1~1,-1~1)
225
+
226
+ kptscore = torch.nn.functional.grid_sample(
227
+ scores_map[b_idx].unsqueeze(0),
228
+ keypoints_xy.view(1, 1, -1, 2),
229
+ mode="bilinear",
230
+ align_corners=True,
231
+ )[0, 0, 0, :] # CxN
232
+
233
+ keypoints.append(keypoints_xy)
234
+ scoredispersitys.append(scoredispersity)
235
+ kptscores.append(kptscore)
236
+ else:
237
+ for b_idx in range(b):
238
+ indices_kpt = indices_keypoints[
239
+ b_idx
240
+ ] # one dimension vector, say its size is M
241
+ # To avoid warning: UserWarning: __floordiv__ is deprecated
242
+ keypoints_xy_nms = torch.stack(
243
+ [indices_kpt % w, torch.div(indices_kpt, w, rounding_mode="trunc")],
244
+ dim=1,
245
+ ) # Mx2
246
+ keypoints_xy = keypoints_xy_nms / wh * 2 - 1 # (w,h) -> (-1~1,-1~1)
247
+ kptscore = torch.nn.functional.grid_sample(
248
+ scores_map[b_idx].unsqueeze(0),
249
+ keypoints_xy.view(1, 1, -1, 2),
250
+ mode="bilinear",
251
+ align_corners=True,
252
+ )[0, 0, 0, :] # CxN
253
+ keypoints.append(keypoints_xy)
254
+ scoredispersitys.append(kptscore) # for jit.script compatibility
255
+ kptscores.append(kptscore)
256
+
257
+ return keypoints, scoredispersitys, kptscores
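# Illustration only: DKD converts a dense score map into per-image lists of
# sub-pixel keypoints in normalised coordinates, plus their score dispersity
# and sampled scores. Sketch on a random B x 1 x H x W map:
example_dkd = DKD(radius=2, top_k=0, scores_th=0.2, n_limit=500)
example_map = torch.rand(2, 1, 64, 64)
ex_kpts, ex_disp, ex_scores = example_dkd(example_map, sub_pixel=True)
# len(ex_kpts) == 2; ex_kpts[0] is N x 2 (approximately in [-1, 1]); ex_scores[0] is N.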
258
+
259
+
260
+ class InputPadder(object):
261
+ """Pads images such that dimensions are divisible by 8"""
262
+
263
+ def __init__(self, h: int, w: int, divis_by: int = 8):
264
+ self.ht = h
265
+ self.wd = w
266
+ pad_ht = (((self.ht // divis_by) + 1) * divis_by - self.ht) % divis_by
267
+ pad_wd = (((self.wd // divis_by) + 1) * divis_by - self.wd) % divis_by
268
+ self._pad = [
269
+ pad_wd // 2,
270
+ pad_wd - pad_wd // 2,
271
+ pad_ht // 2,
272
+ pad_ht - pad_ht // 2,
273
+ ]
274
+
275
+ def pad(self, x: torch.Tensor):
276
+ assert x.ndim == 4
277
+ return F.pad(x, self._pad, mode="replicate")
278
+
279
+ def unpad(self, x: torch.Tensor):
280
+ assert x.ndim == 4
281
+ ht = x.shape[-2]
282
+ wd = x.shape[-1]
283
+ c = [self._pad[2], ht - self._pad[3], self._pad[0], wd - self._pad[1]]
284
+ return x[..., c[0] : c[1], c[2] : c[3]]
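# Illustration only: InputPadder pads H and W up to the next multiple of
# `divis_by` (replicate padding) and unpad() recovers the original crop exactly.
example_im = torch.rand(1, 3, 250, 333)
example_padder = InputPadder(250, 333, divis_by=32)
example_padded = example_padder.pad(example_im)
assert example_padded.shape[-2] % 32 == 0 and example_padded.shape[-1] % 32 == 0
assert example_padder.unpad(example_padded).shape == example_im.shape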
285
+
286
+
287
+ class DeformableConv2d(nn.Module):
288
+ def __init__(
289
+ self,
290
+ in_channels,
291
+ out_channels,
292
+ kernel_size=3,
293
+ stride=1,
294
+ padding=1,
295
+ bias=False,
296
+ mask=False,
297
+ ):
298
+ super(DeformableConv2d, self).__init__()
299
+
300
+ self.padding = padding
301
+ self.mask = mask
302
+
303
+ self.channel_num = (
304
+ 3 * kernel_size * kernel_size if mask else 2 * kernel_size * kernel_size
305
+ )
306
+ self.offset_conv = nn.Conv2d(
307
+ in_channels,
308
+ self.channel_num,
309
+ kernel_size=kernel_size,
310
+ stride=stride,
311
+ padding=self.padding,
312
+ bias=True,
313
+ )
314
+
315
+ self.regular_conv = nn.Conv2d(
316
+ in_channels=in_channels,
317
+ out_channels=out_channels,
318
+ kernel_size=kernel_size,
319
+ stride=stride,
320
+ padding=self.padding,
321
+ bias=bias,
322
+ )
323
+
324
+ def forward(self, x):
325
+ h, w = x.shape[2:]
326
+ max_offset = max(h, w) / 4.0
327
+
328
+ out = self.offset_conv(x)
329
+ if self.mask:
330
+ o1, o2, mask = torch.chunk(out, 3, dim=1)
331
+ offset = torch.cat((o1, o2), dim=1)
332
+ mask = torch.sigmoid(mask)
333
+ else:
334
+ offset = out
335
+ mask = None
336
+ offset = offset.clamp(-max_offset, max_offset)
337
+ x = torchvision.ops.deform_conv2d(
338
+ input=x,
339
+ offset=offset,
340
+ weight=self.regular_conv.weight,
341
+ bias=self.regular_conv.bias,
342
+ padding=self.padding,
343
+ mask=mask,
344
+ )
345
+ return x
346
+
347
+
348
+ def get_conv(
349
+ inplanes,
350
+ planes,
351
+ kernel_size=3,
352
+ stride=1,
353
+ padding=1,
354
+ bias=False,
355
+ conv_type="conv",
356
+ mask=False,
357
+ ):
358
+ if conv_type == "conv":
359
+ conv = nn.Conv2d(
360
+ inplanes,
361
+ planes,
362
+ kernel_size=kernel_size,
363
+ stride=stride,
364
+ padding=padding,
365
+ bias=bias,
366
+ )
367
+ elif conv_type == "dcn":
368
+ conv = DeformableConv2d(
369
+ inplanes,
370
+ planes,
371
+ kernel_size=kernel_size,
372
+ stride=stride,
373
+ padding=_pair(padding),
374
+ bias=bias,
375
+ mask=mask,
376
+ )
377
+ else:
378
+ raise TypeError
379
+ return conv
380
+
381
+
382
+ class ConvBlock(nn.Module):
383
+ def __init__(
384
+ self,
385
+ in_channels,
386
+ out_channels,
387
+ gate: Optional[Callable[..., nn.Module]] = None,
388
+ norm_layer: Optional[Callable[..., nn.Module]] = None,
389
+ conv_type: str = "conv",
390
+ mask: bool = False,
391
+ ):
392
+ super().__init__()
393
+ if gate is None:
394
+ self.gate = nn.ReLU(inplace=True)
395
+ else:
396
+ self.gate = gate
397
+ if norm_layer is None:
398
+ norm_layer = nn.BatchNorm2d
399
+ self.conv1 = get_conv(
400
+ in_channels, out_channels, kernel_size=3, conv_type=conv_type, mask=mask
401
+ )
402
+ self.bn1 = norm_layer(out_channels)
403
+ self.conv2 = get_conv(
404
+ out_channels, out_channels, kernel_size=3, conv_type=conv_type, mask=mask
405
+ )
406
+ self.bn2 = norm_layer(out_channels)
407
+
408
+ def forward(self, x):
409
+ x = self.gate(self.bn1(self.conv1(x))) # B x in_channels x H x W
410
+ x = self.gate(self.bn2(self.conv2(x))) # B x out_channels x H x W
411
+ return x
412
+
413
+
414
+ # modified based on torchvision\models\resnet.py#27->BasicBlock
415
+ class ResBlock(nn.Module):
416
+ expansion: int = 1
417
+
418
+ def __init__(
419
+ self,
420
+ inplanes: int,
421
+ planes: int,
422
+ stride: int = 1,
423
+ downsample: Optional[nn.Module] = None,
424
+ groups: int = 1,
425
+ base_width: int = 64,
426
+ dilation: int = 1,
427
+ gate: Optional[Callable[..., nn.Module]] = None,
428
+ norm_layer: Optional[Callable[..., nn.Module]] = None,
429
+ conv_type: str = "conv",
430
+ mask: bool = False,
431
+ ) -> None:
432
+ super(ResBlock, self).__init__()
433
+ if gate is None:
434
+ self.gate = nn.ReLU(inplace=True)
435
+ else:
436
+ self.gate = gate
437
+ if norm_layer is None:
438
+ norm_layer = nn.BatchNorm2d
439
+ if groups != 1 or base_width != 64:
440
+ raise ValueError("ResBlock only supports groups=1 and base_width=64")
441
+ if dilation > 1:
442
+ raise NotImplementedError("Dilation > 1 not supported in ResBlock")
443
+ # Both self.conv1 and self.downsample layers
444
+ # downsample the input when stride != 1
445
+ self.conv1 = get_conv(
446
+ inplanes, planes, kernel_size=3, conv_type=conv_type, mask=mask
447
+ )
448
+ self.bn1 = norm_layer(planes)
449
+ self.conv2 = get_conv(
450
+ planes, planes, kernel_size=3, conv_type=conv_type, mask=mask
451
+ )
452
+ self.bn2 = norm_layer(planes)
453
+ self.downsample = downsample
454
+ self.stride = stride
455
+
456
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
457
+ identity = x
458
+
459
+ out = self.conv1(x)
460
+ out = self.bn1(out)
461
+ out = self.gate(out)
462
+
463
+ out = self.conv2(out)
464
+ out = self.bn2(out)
465
+
466
+ if self.downsample is not None:
467
+ identity = self.downsample(x)
468
+
469
+ out += identity
470
+ out = self.gate(out)
471
+
472
+ return out
473
+
474
+
475
+ class SDDH(nn.Module):
476
+ def __init__(
477
+ self,
478
+ dims: int,
479
+ kernel_size: int = 3,
480
+ n_pos: int = 8,
481
+ gate=nn.ReLU(),
482
+ conv2D=False,
483
+ mask=False,
484
+ ):
485
+ super(SDDH, self).__init__()
486
+ self.kernel_size = kernel_size
487
+ self.n_pos = n_pos
488
+ self.conv2D = conv2D
489
+ self.mask = mask
490
+
491
+ self.get_patches_func = get_patches
492
+
493
+ # estimate offsets
494
+ self.channel_num = 3 * n_pos if mask else 2 * n_pos
495
+ self.offset_conv = nn.Sequential(
496
+ nn.Conv2d(
497
+ dims,
498
+ self.channel_num,
499
+ kernel_size=kernel_size,
500
+ stride=1,
501
+ padding=0,
502
+ bias=True,
503
+ ),
504
+ gate,
505
+ nn.Conv2d(
506
+ self.channel_num,
507
+ self.channel_num,
508
+ kernel_size=1,
509
+ stride=1,
510
+ padding=0,
511
+ bias=True,
512
+ ),
513
+ )
514
+
515
+ # sampled feature conv
516
+ self.sf_conv = nn.Conv2d(
517
+ dims, dims, kernel_size=1, stride=1, padding=0, bias=False
518
+ )
519
+
520
+ # convM
521
+ if not conv2D:
522
+ # deformable desc weights
523
+ agg_weights = torch.nn.Parameter(torch.rand(n_pos, dims, dims))
524
+ self.register_parameter("agg_weights", agg_weights)
525
+ else:
526
+ self.convM = nn.Conv2d(
527
+ dims * n_pos, dims, kernel_size=1, stride=1, padding=0, bias=False
528
+ )
529
+
530
+ def forward(self, x, keypoints):
531
+ # x: [B,C,H,W]
532
+ # keypoints: list, [[N_kpts,2], ...] (w,h)
533
+ b, c, h, w = x.shape
534
+ wh = torch.tensor([[w - 1, h - 1]], device=x.device)
535
+ max_offset = max(h, w) / 4.0
536
+
537
+ offsets = []
538
+ descriptors = []
539
+ # get offsets for each keypoint
540
+ for ib in range(b):
541
+ xi, kptsi = x[ib], keypoints[ib]
542
+ kptsi_wh = (kptsi / 2 + 0.5) * wh
543
+ N_kpts = len(kptsi)
544
+
545
+ if self.kernel_size > 1:
546
+ patch = self.get_patches_func(
547
+ xi, kptsi_wh.long(), self.kernel_size
548
+ ) # [N_kpts, C, K, K]
549
+ else:
550
+ kptsi_wh_long = kptsi_wh.long()
551
+ patch = (
552
+ xi[:, kptsi_wh_long[:, 1], kptsi_wh_long[:, 0]]
553
+ .permute(1, 0)
554
+ .reshape(N_kpts, c, 1, 1)
555
+ )
556
+
557
+ offset = self.offset_conv(patch).clamp(
558
+ -max_offset, max_offset
559
+ ) # [N_kpts, 2*n_pos, 1, 1]
560
+ if self.mask:
561
+ offset = (
562
+ offset[:, :, 0, 0].view(N_kpts, 3, self.n_pos).permute(0, 2, 1)
563
+ ) # [N_kpts, n_pos, 3]
564
+ offset = offset[:, :, :-1] # [N_kpts, n_pos, 2]
565
+ mask_weight = torch.sigmoid(offset[:, :, -1]) # [N_kpts, n_pos]
566
+ else:
567
+ offset = (
568
+ offset[:, :, 0, 0].view(N_kpts, 2, self.n_pos).permute(0, 2, 1)
569
+ ) # [N_kpts, n_pos, 2]
570
+ offsets.append(offset) # for visualization
571
+
572
+ # get sample positions
573
+ pos = kptsi_wh.unsqueeze(1) + offset # [N_kpts, n_pos, 2]
574
+ pos = 2.0 * pos / wh[None] - 1
575
+ pos = pos.reshape(1, N_kpts * self.n_pos, 1, 2)
576
+
577
+ # sample features
578
+ features = F.grid_sample(
579
+ xi.unsqueeze(0), pos, mode="bilinear", align_corners=True
580
+ ) # [1,C,(N_kpts*n_pos),1]
581
+ features = features.reshape(c, N_kpts, self.n_pos, 1).permute(
582
+ 1, 0, 2, 3
583
+ ) # [N_kpts, C, n_pos, 1]
584
+ if self.mask:
585
+ features = torch.einsum("ncpo,np->ncpo", features, mask_weight)
586
+
587
+ features = torch.selu_(self.sf_conv(features)).squeeze(
588
+ -1
589
+ ) # [N_kpts, C, n_pos]
590
+ # convM
591
+ if not self.conv2D:
592
+ descs = torch.einsum(
593
+ "ncp,pcd->nd", features, self.agg_weights
594
+ ) # [N_kpts, C]
595
+ else:
596
+ features = features.reshape(N_kpts, -1)[
597
+ :, :, None, None
598
+ ] # [N_kpts, C*n_pos, 1, 1]
599
+ descs = self.convM(features).squeeze() # [N_kpts, C]
600
+
601
+ # normalize
602
+ descs = F.normalize(descs, p=2.0, dim=1)
603
+ descriptors.append(descs)
604
+
605
+ return descriptors, offsets
606
+
607
+
608
+ class ALIKED(Extractor):
609
+ default_conf = {
610
+ "model_name": "aliked-n16",
611
+ "max_num_keypoints": -1,
612
+ "detection_threshold": 0.2,
613
+ "nms_radius": 2,
614
+ }
615
+
616
+ checkpoint_url = "https://github.com/Shiaoming/ALIKED/raw/main/models/{}.pth"
617
+
618
+ n_limit_max = 20000
619
+
620
+ # c1, c2, c3, c4, dim, K, M
621
+ cfgs = {
622
+ "aliked-t16": [8, 16, 32, 64, 64, 3, 16],
623
+ "aliked-n16": [16, 32, 64, 128, 128, 3, 16],
624
+ "aliked-n16rot": [16, 32, 64, 128, 128, 3, 16],
625
+ "aliked-n32": [16, 32, 64, 128, 128, 3, 32],
626
+ }
627
+ preprocess_conf = {
628
+ "resize": 1024,
629
+ }
630
+
631
+ required_data_keys = ["image"]
632
+
633
+ def __init__(self, **conf):
634
+ super().__init__(**conf) # Update with default configuration.
635
+ conf = self.conf
636
+ c1, c2, c3, c4, dim, K, M = self.cfgs[conf.model_name]
637
+ conv_types = ["conv", "conv", "dcn", "dcn"]
638
+ conv2D = False
639
+ mask = False
640
+
641
+ # build model
642
+ self.pool2 = nn.AvgPool2d(kernel_size=2, stride=2)
643
+ self.pool4 = nn.AvgPool2d(kernel_size=4, stride=4)
644
+ self.norm = nn.BatchNorm2d
645
+ self.gate = nn.SELU(inplace=True)
646
+ self.block1 = ConvBlock(3, c1, self.gate, self.norm, conv_type=conv_types[0])
647
+ self.block2 = self.get_resblock(c1, c2, conv_types[1], mask)
648
+ self.block3 = self.get_resblock(c2, c3, conv_types[2], mask)
649
+ self.block4 = self.get_resblock(c3, c4, conv_types[3], mask)
650
+
651
+ self.conv1 = resnet.conv1x1(c1, dim // 4)
652
+ self.conv2 = resnet.conv1x1(c2, dim // 4)
653
+ self.conv3 = resnet.conv1x1(c3, dim // 4)
654
+ self.conv4 = resnet.conv1x1(dim, dim // 4)
655
+ self.upsample2 = nn.Upsample(
656
+ scale_factor=2, mode="bilinear", align_corners=True
657
+ )
658
+ self.upsample4 = nn.Upsample(
659
+ scale_factor=4, mode="bilinear", align_corners=True
660
+ )
661
+ self.upsample8 = nn.Upsample(
662
+ scale_factor=8, mode="bilinear", align_corners=True
663
+ )
664
+ self.upsample32 = nn.Upsample(
665
+ scale_factor=32, mode="bilinear", align_corners=True
666
+ )
667
+ self.score_head = nn.Sequential(
668
+ resnet.conv1x1(dim, 8),
669
+ self.gate,
670
+ resnet.conv3x3(8, 4),
671
+ self.gate,
672
+ resnet.conv3x3(4, 4),
673
+ self.gate,
674
+ resnet.conv3x3(4, 1),
675
+ )
676
+ self.desc_head = SDDH(dim, K, M, gate=self.gate, conv2D=conv2D, mask=mask)
677
+ self.dkd = DKD(
678
+ radius=conf.nms_radius,
679
+ top_k=-1 if conf.detection_threshold > 0 else conf.max_num_keypoints,
680
+ scores_th=conf.detection_threshold,
681
+ n_limit=conf.max_num_keypoints
682
+ if conf.max_num_keypoints > 0
683
+ else self.n_limit_max,
684
+ )
685
+
686
+ state_dict = torch.hub.load_state_dict_from_url(
687
+ self.checkpoint_url.format(conf.model_name), map_location="cpu"
688
+ )
689
+ self.load_state_dict(state_dict, strict=True)
690
+
691
+ def get_resblock(self, c_in, c_out, conv_type, mask):
692
+ return ResBlock(
693
+ c_in,
694
+ c_out,
695
+ 1,
696
+ nn.Conv2d(c_in, c_out, 1),
697
+ gate=self.gate,
698
+ norm_layer=self.norm,
699
+ conv_type=conv_type,
700
+ mask=mask,
701
+ )
702
+
703
+ def extract_dense_map(self, image):
704
+ # Pad images so that their dimensions are divisible by div_by (2**5 = 32)
705
+ div_by = 2**5
706
+ padder = InputPadder(image.shape[-2], image.shape[-1], div_by)
707
+ image = padder.pad(image)
708
+
709
+ # ================================== feature encoder
710
+ x1 = self.block1(image) # B x c1 x H x W
711
+ x2 = self.pool2(x1)
712
+ x2 = self.block2(x2) # B x c2 x H/2 x W/2
713
+ x3 = self.pool4(x2)
714
+ x3 = self.block3(x3) # B x c3 x H/8 x W/8
715
+ x4 = self.pool4(x3)
716
+ x4 = self.block4(x4) # B x dim x H/32 x W/32
717
+ # ================================== feature aggregation
718
+ x1 = self.gate(self.conv1(x1)) # B x dim//4 x H x W
719
+ x2 = self.gate(self.conv2(x2)) # B x dim//4 x H//2 x W//2
720
+ x3 = self.gate(self.conv3(x3)) # B x dim//4 x H//8 x W//8
721
+ x4 = self.gate(self.conv4(x4)) # B x dim//4 x H//32 x W//32
722
+ x2_up = self.upsample2(x2) # B x dim//4 x H x W
723
+ x3_up = self.upsample8(x3) # B x dim//4 x H x W
724
+ x4_up = self.upsample32(x4) # B x dim//4 x H x W
725
+ x1234 = torch.cat([x1, x2_up, x3_up, x4_up], dim=1)
726
+ # ================================== score head
727
+ score_map = torch.sigmoid(self.score_head(x1234))
728
+ feature_map = torch.nn.functional.normalize(x1234, p=2, dim=1)
729
+
730
+ # Unpads images
731
+ feature_map = padder.unpad(feature_map)
732
+ score_map = padder.unpad(score_map)
733
+
734
+ return feature_map, score_map
735
+
736
+ def forward(self, data: dict) -> dict:
737
+ # need to set here unfortunately
738
+ self.dkd.n_limit = (
739
+ self.conf.max_num_keypoints
740
+ if self.conf.max_num_keypoints > 0
741
+ else self.n_limit_max
742
+ )
743
+ image = data["image"]
744
+ if image.shape[1] == 1:
745
+ image = grayscale_to_rgb(image)
746
+ feature_map, score_map = self.extract_dense_map(image)
747
+ keypoints, kptscores, scoredispersitys = self.dkd(
748
+ score_map, image_size=data.get("image_size")
749
+ )
750
+ # descriptors, offsets = self.desc_head(feature_map, keypoints)
751
+
752
+ _, _, h, w = image.shape
753
+ wh = torch.tensor([w - 1, h - 1], device=image.device)
754
+ # no padding required
755
+ # we can set detection_threshold=-1 and conf.max_num_keypoints > 0
756
+ return {
757
+ "keypoints": wh * (torch.stack(keypoints) + 1) / 2.0, # B x N x 2
758
+ # "descriptors": torch.stack(descriptors), # B x N x D
759
+ "keypoint_scores": torch.stack(kptscores), # B x N
760
+ "scoremap": score_map, # B x 1 x H x W
761
+ }
762
+
763
+
764
+ class ALIKEDROT(ALIKED):
765
+ default_conf = {
766
+ "model_name": "aliked-n16rot",
767
+ "max_num_keypoints": -1,
768
+ "detection_threshold": 0.2,
769
+ "nms_radius": 2,
770
+ }
imcui/third_party/dad/dad/detectors/third_party/lightglue/disk.py ADDED
@@ -0,0 +1,48 @@
1
+ import kornia
2
+ import torch
3
+
4
+ from .utils import Extractor
5
+
6
+
7
+ class DISK(Extractor):
8
+ default_conf = {
9
+ "weights": "depth",
10
+ "max_num_keypoints": None,
11
+ "desc_dim": 128,
12
+ "nms_window_size": 5,
13
+ "detection_threshold": 0.0,
14
+ "pad_if_not_divisible": True,
15
+ }
16
+
17
+ preprocess_conf = {
18
+ "resize": 1024,
19
+ "grayscale": False,
20
+ }
21
+
22
+ required_data_keys = ["image"]
23
+
24
+ def __init__(self, **conf) -> None:
25
+ super().__init__(**conf) # Update with default configuration.
26
+ self.model = kornia.feature.DISK.from_pretrained(self.conf.weights)
27
+
28
+ def forward(self, data: dict) -> dict:
29
+ """Compute keypoints, scores, descriptors for image"""
30
+ for key in self.required_data_keys:
31
+ assert key in data, f"Missing key {key} in data"
32
+ image = data["image"]
33
+ if image.shape[1] == 1:
34
+ image = kornia.color.grayscale_to_rgb(image)
35
+ features = self.model(
36
+ image,
37
+ n=self.conf.max_num_keypoints,
38
+ window_size=self.conf.nms_window_size,
39
+ score_threshold=self.conf.detection_threshold,
40
+ pad_if_not_divisible=self.conf.pad_if_not_divisible,
41
+ )
42
+ keypoints = [f.keypoints for f in features]
43
+
44
+ keypoints = torch.stack(keypoints, 0)
45
+
46
+ return {
47
+ "keypoints": keypoints.to(image).contiguous(),
48
+ }
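Like ALIKED above, this DISK wrapper is stripped down to a pure detector: it returns only the keypoints from kornia's pretrained DISK. A short sketch, assuming kornia is installed and can download the `depth` weights:

import torch
from dad.detectors.third_party.lightglue import DISK

disk = DISK(max_num_keypoints=2048).eval()
with torch.inference_mode():
    out = disk({"image": torch.rand(1, 3, 480, 640)})
# out["keypoints"]: B x N x 2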
imcui/third_party/dad/dad/detectors/third_party/lightglue/dog_hardnet.py ADDED
@@ -0,0 +1,41 @@
1
+ import torch
2
+ from kornia.color import rgb_to_grayscale
3
+ from kornia.feature import HardNet, LAFDescriptor, laf_from_center_scale_ori
4
+
5
+ from .sift import SIFT
6
+
7
+
8
+ class DoGHardNet(SIFT):
9
+ required_data_keys = ["image"]
10
+
11
+ def __init__(self, **conf):
12
+ super().__init__(**conf)
13
+ self.laf_desc = LAFDescriptor(HardNet(True)).eval()
14
+
15
+ def forward(self, data: dict) -> dict:
16
+ image = data["image"]
17
+ if image.shape[1] == 3:
18
+ image = rgb_to_grayscale(image)
19
+ device = image.device
20
+ self.laf_desc = self.laf_desc.to(device)
21
+ self.laf_desc.descriptor = self.laf_desc.descriptor.eval()
22
+ pred = []
23
+ if "image_size" in data.keys():
24
+ im_size = data.get("image_size").long()
25
+ else:
26
+ im_size = None
27
+ for k in range(len(image)):
28
+ img = image[k]
29
+ if im_size is not None:
30
+ w, h = data["image_size"][k]
31
+ img = img[:, : h.to(torch.int32), : w.to(torch.int32)]
32
+ p = self.extract_single_image(img)
33
+ lafs = laf_from_center_scale_ori(
34
+ p["keypoints"].reshape(1, -1, 2),
35
+ 6.0 * p["scales"].reshape(1, -1, 1, 1),
36
+ torch.rad2deg(p["oris"]).reshape(1, -1, 1),
37
+ ).to(device)
38
+ p["descriptors"] = self.laf_desc(img[None], lafs).reshape(-1, 128)
39
+ pred.append(p)
40
+ pred = {k: torch.stack([p[k] for p in pred], 0).to(device) for k in pred[0]}
41
+ return pred
imcui/third_party/dad/dad/detectors/third_party/lightglue/lightglue.py ADDED
@@ -0,0 +1,655 @@
1
+ import warnings
2
+ from pathlib import Path
3
+ from types import SimpleNamespace
4
+ from typing import Callable, List, Optional, Tuple
5
+
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn.functional as F
9
+ from torch import nn
10
+
11
+ try:
12
+ from flash_attn.modules.mha import FlashCrossAttention
13
+ except ModuleNotFoundError:
14
+ FlashCrossAttention = None
15
+
16
+ if FlashCrossAttention or hasattr(F, "scaled_dot_product_attention"):
17
+ FLASH_AVAILABLE = True
18
+ else:
19
+ FLASH_AVAILABLE = False
20
+
21
+ torch.backends.cudnn.deterministic = True
22
+
23
+
24
+ @torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float32)
25
+ def normalize_keypoints(
26
+ kpts: torch.Tensor, size: Optional[torch.Tensor] = None
27
+ ) -> torch.Tensor:
28
+ if size is None:
29
+ size = 1 + kpts.max(-2).values - kpts.min(-2).values
30
+ elif not isinstance(size, torch.Tensor):
31
+ size = torch.tensor(size, device=kpts.device, dtype=kpts.dtype)
32
+ size = size.to(kpts)
33
+ shift = size / 2
34
+ scale = size.max(-1).values / 2
35
+ kpts = (kpts - shift[..., None, :]) / scale[..., None, None]
36
+ return kpts
37
+
38
+
39
+ def pad_to_length(x: torch.Tensor, length: int) -> Tuple[torch.Tensor]:
40
+ if length <= x.shape[-2]:
41
+ return x, torch.ones_like(x[..., :1], dtype=torch.bool)
42
+ pad = torch.ones(
43
+ *x.shape[:-2], length - x.shape[-2], x.shape[-1], device=x.device, dtype=x.dtype
44
+ )
45
+ y = torch.cat([x, pad], dim=-2)
46
+ mask = torch.zeros(*y.shape[:-1], 1, dtype=torch.bool, device=x.device)
47
+ mask[..., : x.shape[-2], :] = True
48
+ return y, mask
49
+
50
+
51
+ def rotate_half(x: torch.Tensor) -> torch.Tensor:
52
+ x = x.unflatten(-1, (-1, 2))
53
+ x1, x2 = x.unbind(dim=-1)
54
+ return torch.stack((-x2, x1), dim=-1).flatten(start_dim=-2)
55
+
56
+
57
+ def apply_cached_rotary_emb(freqs: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
58
+ return (t * freqs[0]) + (rotate_half(t) * freqs[1])
59
+
60
+
61
+ class LearnableFourierPositionalEncoding(nn.Module):
62
+ def __init__(self, M: int, dim: int, F_dim: int = None, gamma: float = 1.0) -> None:
63
+ super().__init__()
64
+ F_dim = F_dim if F_dim is not None else dim
65
+ self.gamma = gamma
66
+ self.Wr = nn.Linear(M, F_dim // 2, bias=False)
67
+ nn.init.normal_(self.Wr.weight.data, mean=0, std=self.gamma**-2)
68
+
69
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
70
+ """encode position vector"""
71
+ projected = self.Wr(x)
72
+ cosines, sines = torch.cos(projected), torch.sin(projected)
73
+ emb = torch.stack([cosines, sines], 0).unsqueeze(-3)
74
+ return emb.repeat_interleave(2, dim=-1)
75
+
76
+
77
+ class TokenConfidence(nn.Module):
78
+ def __init__(self, dim: int) -> None:
79
+ super().__init__()
80
+ self.token = nn.Sequential(nn.Linear(dim, 1), nn.Sigmoid())
81
+
82
+ def forward(self, desc0: torch.Tensor, desc1: torch.Tensor):
83
+ """get confidence tokens"""
84
+ return (
85
+ self.token(desc0.detach()).squeeze(-1),
86
+ self.token(desc1.detach()).squeeze(-1),
87
+ )
88
+
89
+
90
+ class Attention(nn.Module):
91
+ def __init__(self, allow_flash: bool) -> None:
92
+ super().__init__()
93
+ if allow_flash and not FLASH_AVAILABLE:
94
+ warnings.warn(
95
+ "FlashAttention is not available. For optimal speed, "
96
+ "consider installing torch >= 2.0 or flash-attn.",
97
+ stacklevel=2,
98
+ )
99
+ self.enable_flash = allow_flash and FLASH_AVAILABLE
100
+ self.has_sdp = hasattr(F, "scaled_dot_product_attention")
101
+ if allow_flash and FlashCrossAttention:
102
+ self.flash_ = FlashCrossAttention()
103
+ if self.has_sdp:
104
+ torch.backends.cuda.enable_flash_sdp(allow_flash)
105
+
106
+ def forward(self, q, k, v, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
107
+ if q.shape[-2] == 0 or k.shape[-2] == 0:
108
+ return q.new_zeros((*q.shape[:-1], v.shape[-1]))
109
+ if self.enable_flash and q.device.type == "cuda":
110
+ # use torch 2.0 scaled_dot_product_attention with flash
111
+ if self.has_sdp:
112
+ args = [x.half().contiguous() for x in [q, k, v]]
113
+ v = F.scaled_dot_product_attention(*args, attn_mask=mask).to(q.dtype)
114
+ return v if mask is None else v.nan_to_num()
115
+ else:
116
+ assert mask is None
117
+ q, k, v = [x.transpose(-2, -3).contiguous() for x in [q, k, v]]
118
+ m = self.flash_(q.half(), torch.stack([k, v], 2).half())
119
+ return m.transpose(-2, -3).to(q.dtype).clone()
120
+ elif self.has_sdp:
121
+ args = [x.contiguous() for x in [q, k, v]]
122
+ v = F.scaled_dot_product_attention(*args, attn_mask=mask)
123
+ return v if mask is None else v.nan_to_num()
124
+ else:
125
+ s = q.shape[-1] ** -0.5
126
+ sim = torch.einsum("...id,...jd->...ij", q, k) * s
127
+ if mask is not None:
128
+ sim.masked_fill(~mask, -float("inf"))
129
+ attn = F.softmax(sim, -1)
130
+ return torch.einsum("...ij,...jd->...id", attn, v)
131
+
132
+
133
+ class SelfBlock(nn.Module):
134
+ def __init__(
135
+ self, embed_dim: int, num_heads: int, flash: bool = False, bias: bool = True
136
+ ) -> None:
137
+ super().__init__()
138
+ self.embed_dim = embed_dim
139
+ self.num_heads = num_heads
140
+ assert self.embed_dim % num_heads == 0
141
+ self.head_dim = self.embed_dim // num_heads
142
+ self.Wqkv = nn.Linear(embed_dim, 3 * embed_dim, bias=bias)
143
+ self.inner_attn = Attention(flash)
144
+ self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
145
+ self.ffn = nn.Sequential(
146
+ nn.Linear(2 * embed_dim, 2 * embed_dim),
147
+ nn.LayerNorm(2 * embed_dim, elementwise_affine=True),
148
+ nn.GELU(),
149
+ nn.Linear(2 * embed_dim, embed_dim),
150
+ )
151
+
152
+ def forward(
153
+ self,
154
+ x: torch.Tensor,
155
+ encoding: torch.Tensor,
156
+ mask: Optional[torch.Tensor] = None,
157
+ ) -> torch.Tensor:
158
+ qkv = self.Wqkv(x)
159
+ qkv = qkv.unflatten(-1, (self.num_heads, -1, 3)).transpose(1, 2)
160
+ q, k, v = qkv[..., 0], qkv[..., 1], qkv[..., 2]
161
+ q = apply_cached_rotary_emb(encoding, q)
162
+ k = apply_cached_rotary_emb(encoding, k)
163
+ context = self.inner_attn(q, k, v, mask=mask)
164
+ message = self.out_proj(context.transpose(1, 2).flatten(start_dim=-2))
165
+ return x + self.ffn(torch.cat([x, message], -1))
166
+
167
+
168
+ class CrossBlock(nn.Module):
169
+ def __init__(
170
+ self, embed_dim: int, num_heads: int, flash: bool = False, bias: bool = True
171
+ ) -> None:
172
+ super().__init__()
173
+ self.heads = num_heads
174
+ dim_head = embed_dim // num_heads
175
+ self.scale = dim_head**-0.5
176
+ inner_dim = dim_head * num_heads
177
+ self.to_qk = nn.Linear(embed_dim, inner_dim, bias=bias)
178
+ self.to_v = nn.Linear(embed_dim, inner_dim, bias=bias)
179
+ self.to_out = nn.Linear(inner_dim, embed_dim, bias=bias)
180
+ self.ffn = nn.Sequential(
181
+ nn.Linear(2 * embed_dim, 2 * embed_dim),
182
+ nn.LayerNorm(2 * embed_dim, elementwise_affine=True),
183
+ nn.GELU(),
184
+ nn.Linear(2 * embed_dim, embed_dim),
185
+ )
186
+ if flash and FLASH_AVAILABLE:
187
+ self.flash = Attention(True)
188
+ else:
189
+ self.flash = None
190
+
191
+ def map_(self, func: Callable, x0: torch.Tensor, x1: torch.Tensor):
192
+ return func(x0), func(x1)
193
+
194
+ def forward(
195
+ self, x0: torch.Tensor, x1: torch.Tensor, mask: Optional[torch.Tensor] = None
196
+ ) -> List[torch.Tensor]:
197
+ qk0, qk1 = self.map_(self.to_qk, x0, x1)
198
+ v0, v1 = self.map_(self.to_v, x0, x1)
199
+ qk0, qk1, v0, v1 = map(
200
+ lambda t: t.unflatten(-1, (self.heads, -1)).transpose(1, 2),
201
+ (qk0, qk1, v0, v1),
202
+ )
203
+ if self.flash is not None and qk0.device.type == "cuda":
204
+ m0 = self.flash(qk0, qk1, v1, mask)
205
+ m1 = self.flash(
206
+ qk1, qk0, v0, mask.transpose(-1, -2) if mask is not None else None
207
+ )
208
+ else:
209
+ qk0, qk1 = qk0 * self.scale**0.5, qk1 * self.scale**0.5
210
+ sim = torch.einsum("bhid, bhjd -> bhij", qk0, qk1)
211
+ if mask is not None:
212
+ sim = sim.masked_fill(~mask, -float("inf"))
213
+ attn01 = F.softmax(sim, dim=-1)
214
+ attn10 = F.softmax(sim.transpose(-2, -1).contiguous(), dim=-1)
215
+ m0 = torch.einsum("bhij, bhjd -> bhid", attn01, v1)
216
+ m1 = torch.einsum("bhji, bhjd -> bhid", attn10.transpose(-2, -1), v0)
217
+ if mask is not None:
218
+ m0, m1 = m0.nan_to_num(), m1.nan_to_num()
219
+ m0, m1 = self.map_(lambda t: t.transpose(1, 2).flatten(start_dim=-2), m0, m1)
220
+ m0, m1 = self.map_(self.to_out, m0, m1)
221
+ x0 = x0 + self.ffn(torch.cat([x0, m0], -1))
222
+ x1 = x1 + self.ffn(torch.cat([x1, m1], -1))
223
+ return x0, x1
224
+
225
+
226
+ class TransformerLayer(nn.Module):
227
+ def __init__(self, *args, **kwargs):
228
+ super().__init__()
229
+ self.self_attn = SelfBlock(*args, **kwargs)
230
+ self.cross_attn = CrossBlock(*args, **kwargs)
231
+
232
+ def forward(
233
+ self,
234
+ desc0,
235
+ desc1,
236
+ encoding0,
237
+ encoding1,
238
+ mask0: Optional[torch.Tensor] = None,
239
+ mask1: Optional[torch.Tensor] = None,
240
+ ):
241
+ if mask0 is not None and mask1 is not None:
242
+ return self.masked_forward(desc0, desc1, encoding0, encoding1, mask0, mask1)
243
+ else:
244
+ desc0 = self.self_attn(desc0, encoding0)
245
+ desc1 = self.self_attn(desc1, encoding1)
246
+ return self.cross_attn(desc0, desc1)
247
+
248
+ # This part is compiled and allows padding inputs
249
+ def masked_forward(self, desc0, desc1, encoding0, encoding1, mask0, mask1):
250
+ mask = mask0 & mask1.transpose(-1, -2)
251
+ mask0 = mask0 & mask0.transpose(-1, -2)
252
+ mask1 = mask1 & mask1.transpose(-1, -2)
253
+ desc0 = self.self_attn(desc0, encoding0, mask0)
254
+ desc1 = self.self_attn(desc1, encoding1, mask1)
255
+ return self.cross_attn(desc0, desc1, mask)
256
+
257
+
258
+ def sigmoid_log_double_softmax(
259
+ sim: torch.Tensor, z0: torch.Tensor, z1: torch.Tensor
260
+ ) -> torch.Tensor:
261
+ """create the log assignment matrix from logits and similarity"""
262
+ b, m, n = sim.shape
263
+ certainties = F.logsigmoid(z0) + F.logsigmoid(z1).transpose(1, 2)
264
+ scores0 = F.log_softmax(sim, 2)
265
+ scores1 = F.log_softmax(sim.transpose(-1, -2).contiguous(), 2).transpose(-1, -2)
266
+ scores = sim.new_full((b, m + 1, n + 1), 0)
267
+ scores[:, :m, :n] = scores0 + scores1 + certainties
268
+ scores[:, :-1, -1] = F.logsigmoid(-z0.squeeze(-1))
269
+ scores[:, -1, :-1] = F.logsigmoid(-z1.squeeze(-1))
270
+ return scores
271
+
272
+
273
+ class MatchAssignment(nn.Module):
274
+ def __init__(self, dim: int) -> None:
275
+ super().__init__()
276
+ self.dim = dim
277
+ self.matchability = nn.Linear(dim, 1, bias=True)
278
+ self.final_proj = nn.Linear(dim, dim, bias=True)
279
+
280
+ def forward(self, desc0: torch.Tensor, desc1: torch.Tensor):
281
+ """build assignment matrix from descriptors"""
282
+ mdesc0, mdesc1 = self.final_proj(desc0), self.final_proj(desc1)
283
+ _, _, d = mdesc0.shape
284
+ mdesc0, mdesc1 = mdesc0 / d**0.25, mdesc1 / d**0.25
285
+ sim = torch.einsum("bmd,bnd->bmn", mdesc0, mdesc1)
286
+ z0 = self.matchability(desc0)
287
+ z1 = self.matchability(desc1)
288
+ scores = sigmoid_log_double_softmax(sim, z0, z1)
289
+ return scores, sim
290
+
291
+ def get_matchability(self, desc: torch.Tensor):
292
+ return torch.sigmoid(self.matchability(desc)).squeeze(-1)
293
+
294
+
295
+ def filter_matches(scores: torch.Tensor, th: float):
296
+ """obtain matches from a log assignment matrix [Bx M+1 x N+1]"""
297
+ max0, max1 = scores[:, :-1, :-1].max(2), scores[:, :-1, :-1].max(1)
298
+ m0, m1 = max0.indices, max1.indices
299
+ indices0 = torch.arange(m0.shape[1], device=m0.device)[None]
300
+ indices1 = torch.arange(m1.shape[1], device=m1.device)[None]
301
+ mutual0 = indices0 == m1.gather(1, m0)
302
+ mutual1 = indices1 == m0.gather(1, m1)
303
+ max0_exp = max0.values.exp()
304
+ zero = max0_exp.new_tensor(0)
305
+ mscores0 = torch.where(mutual0, max0_exp, zero)
306
+ mscores1 = torch.where(mutual1, mscores0.gather(1, m1), zero)
307
+ valid0 = mutual0 & (mscores0 > th)
308
+ valid1 = mutual1 & valid0.gather(1, m1)
309
+ m0 = torch.where(valid0, m0, -1)
310
+ m1 = torch.where(valid1, m1, -1)
311
+ return m0, m1, mscores0, mscores1
312
+
313
+
314
+ class LightGlue(nn.Module):
315
+ default_conf = {
316
+ "name": "lightglue", # just for interfacing
317
+ "input_dim": 256, # input descriptor dimension (autoselected from weights)
318
+ "descriptor_dim": 256,
319
+ "add_scale_ori": False,
320
+ "n_layers": 9,
321
+ "num_heads": 4,
322
+ "flash": True, # enable FlashAttention if available.
323
+ "mp": False, # enable mixed precision
324
+ "depth_confidence": 0.95, # early stopping, disable with -1
325
+ "width_confidence": 0.99, # point pruning, disable with -1
326
+ "filter_threshold": 0.1, # match threshold
327
+ "weights": None,
328
+ }
329
+
330
+ # Point pruning involves an overhead (gather).
331
+ # Therefore, we only activate it if there are enough keypoints.
332
+ pruning_keypoint_thresholds = {
333
+ "cpu": -1,
334
+ "mps": -1,
335
+ "cuda": 1024,
336
+ "flash": 1536,
337
+ }
338
+
339
+ required_data_keys = ["image0", "image1"]
340
+
341
+ version = "v0.1_arxiv"
342
+ url = "https://github.com/cvg/LightGlue/releases/download/{}/{}_lightglue.pth"
343
+
344
+ features = {
345
+ "superpoint": {
346
+ "weights": "superpoint_lightglue",
347
+ "input_dim": 256,
348
+ },
349
+ "disk": {
350
+ "weights": "disk_lightglue",
351
+ "input_dim": 128,
352
+ },
353
+ "aliked": {
354
+ "weights": "aliked_lightglue",
355
+ "input_dim": 128,
356
+ },
357
+ "sift": {
358
+ "weights": "sift_lightglue",
359
+ "input_dim": 128,
360
+ "add_scale_ori": True,
361
+ },
362
+ "doghardnet": {
363
+ "weights": "doghardnet_lightglue",
364
+ "input_dim": 128,
365
+ "add_scale_ori": True,
366
+ },
367
+ }
368
+
369
+ def __init__(self, features="superpoint", **conf) -> None:
370
+ super().__init__()
371
+ self.conf = conf = SimpleNamespace(**{**self.default_conf, **conf})
372
+ if features is not None:
373
+ if features not in self.features:
374
+ raise ValueError(
375
+ f"Unsupported features: {features} not in "
376
+ f"{{{','.join(self.features)}}}"
377
+ )
378
+ for k, v in self.features[features].items():
379
+ setattr(conf, k, v)
380
+
381
+ if conf.input_dim != conf.descriptor_dim:
382
+ self.input_proj = nn.Linear(conf.input_dim, conf.descriptor_dim, bias=True)
383
+ else:
384
+ self.input_proj = nn.Identity()
385
+
386
+ head_dim = conf.descriptor_dim // conf.num_heads
387
+ self.posenc = LearnableFourierPositionalEncoding(
388
+ 2 + 2 * self.conf.add_scale_ori, head_dim, head_dim
389
+ )
390
+
391
+ h, n, d = conf.num_heads, conf.n_layers, conf.descriptor_dim
392
+
393
+ self.transformers = nn.ModuleList(
394
+ [TransformerLayer(d, h, conf.flash) for _ in range(n)]
395
+ )
396
+
397
+ self.log_assignment = nn.ModuleList([MatchAssignment(d) for _ in range(n)])
398
+ self.token_confidence = nn.ModuleList(
399
+ [TokenConfidence(d) for _ in range(n - 1)]
400
+ )
401
+ self.register_buffer(
402
+ "confidence_thresholds",
403
+ torch.Tensor(
404
+ [self.confidence_threshold(i) for i in range(self.conf.n_layers)]
405
+ ),
406
+ )
407
+
408
+ state_dict = None
409
+ if features is not None:
410
+ fname = f"{conf.weights}_{self.version.replace('.', '-')}.pth"
411
+ state_dict = torch.hub.load_state_dict_from_url(
412
+ self.url.format(self.version, features), file_name=fname
413
+ )
414
+ self.load_state_dict(state_dict, strict=False)
415
+ elif conf.weights is not None:
416
+ path = Path(__file__).parent
417
+ path = path / "weights/{}.pth".format(self.conf.weights)
418
+ state_dict = torch.load(str(path), map_location="cpu")
419
+
420
+ if state_dict:
421
+ # rename old state dict entries
422
+ for i in range(self.conf.n_layers):
423
+ pattern = f"self_attn.{i}", f"transformers.{i}.self_attn"
424
+ state_dict = {k.replace(*pattern): v for k, v in state_dict.items()}
425
+ pattern = f"cross_attn.{i}", f"transformers.{i}.cross_attn"
426
+ state_dict = {k.replace(*pattern): v for k, v in state_dict.items()}
427
+ self.load_state_dict(state_dict, strict=False)
428
+
429
+ # static lengths LightGlue is compiled for (only used with torch.compile)
430
+ self.static_lengths = None
431
+
432
+ def compile(
433
+ self, mode="reduce-overhead", static_lengths=[256, 512, 768, 1024, 1280, 1536]
434
+ ):
435
+ if self.conf.width_confidence != -1:
436
+ warnings.warn(
437
+ "Point pruning is partially disabled for compiled forward.",
438
+ stacklevel=2,
439
+ )
440
+
441
+ torch._inductor.cudagraph_mark_step_begin()
442
+ for i in range(self.conf.n_layers):
443
+ self.transformers[i].masked_forward = torch.compile(
444
+ self.transformers[i].masked_forward, mode=mode, fullgraph=True
445
+ )
446
+
447
+ self.static_lengths = static_lengths
448
+
449
+ def forward(self, data: dict) -> dict:
450
+ """
451
+ Match keypoints and descriptors between two images
452
+
453
+ Input (dict):
454
+ image0: dict
455
+ keypoints: [B x M x 2]
456
+ descriptors: [B x M x D]
457
+ image: [B x C x H x W] or image_size: [B x 2]
458
+ image1: dict
459
+ keypoints: [B x N x 2]
460
+ descriptors: [B x N x D]
461
+ image: [B x C x H x W] or image_size: [B x 2]
462
+ Output (dict):
463
+ matches0: [B x M]
464
+ matching_scores0: [B x M]
465
+ matches1: [B x N]
466
+ matching_scores1: [B x N]
467
+ matches: List[[Si x 2]]
468
+ scores: List[[Si]]
469
+ stop: int
470
+ prune0: [B x M]
471
+ prune1: [B x N]
472
+ """
473
+ with torch.autocast(enabled=self.conf.mp, device_type="cuda"):
474
+ return self._forward(data)
475
+
476
+ def _forward(self, data: dict) -> dict:
477
+ for key in self.required_data_keys:
478
+ assert key in data, f"Missing key {key} in data"
479
+ data0, data1 = data["image0"], data["image1"]
480
+ kpts0, kpts1 = data0["keypoints"], data1["keypoints"]
481
+ b, m, _ = kpts0.shape
482
+ b, n, _ = kpts1.shape
483
+ device = kpts0.device
484
+ size0, size1 = data0.get("image_size"), data1.get("image_size")
485
+ kpts0 = normalize_keypoints(kpts0, size0).clone()
486
+ kpts1 = normalize_keypoints(kpts1, size1).clone()
487
+
488
+ if self.conf.add_scale_ori:
489
+ kpts0 = torch.cat(
490
+ [kpts0] + [data0[k].unsqueeze(-1) for k in ("scales", "oris")], -1
491
+ )
492
+ kpts1 = torch.cat(
493
+ [kpts1] + [data1[k].unsqueeze(-1) for k in ("scales", "oris")], -1
494
+ )
495
+ desc0 = data0["descriptors"].detach().contiguous()
496
+ desc1 = data1["descriptors"].detach().contiguous()
497
+
498
+ assert desc0.shape[-1] == self.conf.input_dim
499
+ assert desc1.shape[-1] == self.conf.input_dim
500
+
501
+ if torch.is_autocast_enabled():
502
+ desc0 = desc0.half()
503
+ desc1 = desc1.half()
504
+
505
+ mask0, mask1 = None, None
506
+ c = max(m, n)
507
+ do_compile = self.static_lengths and c <= max(self.static_lengths)
508
+ if do_compile:
509
+ kn = min([k for k in self.static_lengths if k >= c])
510
+ desc0, mask0 = pad_to_length(desc0, kn)
511
+ desc1, mask1 = pad_to_length(desc1, kn)
512
+ kpts0, _ = pad_to_length(kpts0, kn)
513
+ kpts1, _ = pad_to_length(kpts1, kn)
514
+ desc0 = self.input_proj(desc0)
515
+ desc1 = self.input_proj(desc1)
516
+ # cache positional embeddings
517
+ encoding0 = self.posenc(kpts0)
518
+ encoding1 = self.posenc(kpts1)
519
+
520
+ # GNN + final_proj + assignment
521
+ do_early_stop = self.conf.depth_confidence > 0
522
+ do_point_pruning = self.conf.width_confidence > 0 and not do_compile
523
+ pruning_th = self.pruning_min_kpts(device)
524
+ if do_point_pruning:
525
+ ind0 = torch.arange(0, m, device=device)[None]
526
+ ind1 = torch.arange(0, n, device=device)[None]
527
+ # We store the index of the layer at which pruning is detected.
528
+ prune0 = torch.ones_like(ind0)
529
+ prune1 = torch.ones_like(ind1)
530
+ token0, token1 = None, None
531
+ for i in range(self.conf.n_layers):
532
+ if desc0.shape[1] == 0 or desc1.shape[1] == 0: # no keypoints
533
+ break
534
+ desc0, desc1 = self.transformers[i](
535
+ desc0, desc1, encoding0, encoding1, mask0=mask0, mask1=mask1
536
+ )
537
+ if i == self.conf.n_layers - 1:
538
+ continue # no early stopping or adaptive width at last layer
539
+
540
+ if do_early_stop:
541
+ token0, token1 = self.token_confidence[i](desc0, desc1)
542
+ if self.check_if_stop(token0[..., :m], token1[..., :n], i, m + n):
543
+ break
544
+ if do_point_pruning and desc0.shape[-2] > pruning_th:
545
+ scores0 = self.log_assignment[i].get_matchability(desc0)
546
+ prunemask0 = self.get_pruning_mask(token0, scores0, i)
547
+ keep0 = torch.where(prunemask0)[1]
548
+ ind0 = ind0.index_select(1, keep0)
549
+ desc0 = desc0.index_select(1, keep0)
550
+ encoding0 = encoding0.index_select(-2, keep0)
551
+ prune0[:, ind0] += 1
552
+ if do_point_pruning and desc1.shape[-2] > pruning_th:
553
+ scores1 = self.log_assignment[i].get_matchability(desc1)
554
+ prunemask1 = self.get_pruning_mask(token1, scores1, i)
555
+ keep1 = torch.where(prunemask1)[1]
556
+ ind1 = ind1.index_select(1, keep1)
557
+ desc1 = desc1.index_select(1, keep1)
558
+ encoding1 = encoding1.index_select(-2, keep1)
559
+ prune1[:, ind1] += 1
560
+
561
+ if desc0.shape[1] == 0 or desc1.shape[1] == 0: # no keypoints
562
+ m0 = desc0.new_full((b, m), -1, dtype=torch.long)
563
+ m1 = desc1.new_full((b, n), -1, dtype=torch.long)
564
+ mscores0 = desc0.new_zeros((b, m))
565
+ mscores1 = desc1.new_zeros((b, n))
566
+ matches = desc0.new_empty((b, 0, 2), dtype=torch.long)
567
+ mscores = desc0.new_empty((b, 0))
568
+ if not do_point_pruning:
569
+ prune0 = torch.ones_like(mscores0) * self.conf.n_layers
570
+ prune1 = torch.ones_like(mscores1) * self.conf.n_layers
571
+ return {
572
+ "matches0": m0,
573
+ "matches1": m1,
574
+ "matching_scores0": mscores0,
575
+ "matching_scores1": mscores1,
576
+ "stop": i + 1,
577
+ "matches": matches,
578
+ "scores": mscores,
579
+ "prune0": prune0,
580
+ "prune1": prune1,
581
+ }
582
+
583
+ desc0, desc1 = desc0[..., :m, :], desc1[..., :n, :] # remove padding
584
+ scores, _ = self.log_assignment[i](desc0, desc1)
585
+ m0, m1, mscores0, mscores1 = filter_matches(scores, self.conf.filter_threshold)
586
+ matches, mscores = [], []
587
+ for k in range(b):
588
+ valid = m0[k] > -1
589
+ m_indices_0 = torch.where(valid)[0]
590
+ m_indices_1 = m0[k][valid]
591
+ if do_point_pruning:
592
+ m_indices_0 = ind0[k, m_indices_0]
593
+ m_indices_1 = ind1[k, m_indices_1]
594
+ matches.append(torch.stack([m_indices_0, m_indices_1], -1))
595
+ mscores.append(mscores0[k][valid])
596
+
597
+ # TODO: Remove when hloc switches to the compact format.
598
+ if do_point_pruning:
599
+ m0_ = torch.full((b, m), -1, device=m0.device, dtype=m0.dtype)
600
+ m1_ = torch.full((b, n), -1, device=m1.device, dtype=m1.dtype)
601
+ m0_[:, ind0] = torch.where(m0 == -1, -1, ind1.gather(1, m0.clamp(min=0)))
602
+ m1_[:, ind1] = torch.where(m1 == -1, -1, ind0.gather(1, m1.clamp(min=0)))
603
+ mscores0_ = torch.zeros((b, m), device=mscores0.device)
604
+ mscores1_ = torch.zeros((b, n), device=mscores1.device)
605
+ mscores0_[:, ind0] = mscores0
606
+ mscores1_[:, ind1] = mscores1
607
+ m0, m1, mscores0, mscores1 = m0_, m1_, mscores0_, mscores1_
608
+ else:
609
+ prune0 = torch.ones_like(mscores0) * self.conf.n_layers
610
+ prune1 = torch.ones_like(mscores1) * self.conf.n_layers
611
+
612
+ return {
613
+ "matches0": m0,
614
+ "matches1": m1,
615
+ "matching_scores0": mscores0,
616
+ "matching_scores1": mscores1,
617
+ "stop": i + 1,
618
+ "matches": matches,
619
+ "scores": mscores,
620
+ "prune0": prune0,
621
+ "prune1": prune1,
622
+ }
623
+
624
+ def confidence_threshold(self, layer_index: int) -> float:
625
+ """scaled confidence threshold"""
626
+ threshold = 0.8 + 0.1 * np.exp(-4.0 * layer_index / self.conf.n_layers)
627
+ return np.clip(threshold, 0, 1)
628
+
629
+ def get_pruning_mask(
630
+ self, confidences: torch.Tensor, scores: torch.Tensor, layer_index: int
631
+ ) -> torch.Tensor:
632
+ """mask points which should be removed"""
633
+ keep = scores > (1 - self.conf.width_confidence)
634
+ if confidences is not None: # Low-confidence points are never pruned.
635
+ keep |= confidences <= self.confidence_thresholds[layer_index]
636
+ return keep
637
+
638
+ def check_if_stop(
639
+ self,
640
+ confidences0: torch.Tensor,
641
+ confidences1: torch.Tensor,
642
+ layer_index: int,
643
+ num_points: int,
644
+ ) -> torch.Tensor:
645
+ """evaluate stopping condition"""
646
+ confidences = torch.cat([confidences0, confidences1], -1)
647
+ threshold = self.confidence_thresholds[layer_index]
648
+ ratio_confident = 1.0 - (confidences < threshold).float().sum() / num_points
649
+ return ratio_confident > self.conf.depth_confidence
650
+
651
+ def pruning_min_kpts(self, device: torch.device):
652
+ if self.conf.flash and FLASH_AVAILABLE and device.type == "cuda":
653
+ return self.pruning_keypoint_thresholds["flash"]
654
+ else:
655
+ return self.pruning_keypoint_thresholds[device.type]
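At inference time the matcher takes two feature dicts and returns both the padded per-keypoint format (`matches0`, `matching_scores0`, ...) and the compact pair list (`matches`, `scores`). A hedged sketch with random SuperPoint-style inputs (256-D descriptors); pretrained weights are downloaded in `__init__`:

import torch
from dad.detectors.third_party.lightglue import LightGlue

matcher = LightGlue(features="superpoint").eval()

def fake_feats(n=512):
    return {
        "keypoints": torch.rand(1, n, 2) * torch.tensor([640.0, 480.0]),
        "descriptors": torch.rand(1, n, 256),
        "image_size": torch.tensor([[640.0, 480.0]]),
    }

with torch.inference_mode():
    out = matcher({"image0": fake_feats(), "image1": fake_feats()})
matches = out["matches"][0]  # K x 2 indices into the two keypoint sets
scores = out["scores"][0]    # K matching confidences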
imcui/third_party/dad/dad/detectors/third_party/lightglue/sift.py ADDED
@@ -0,0 +1,216 @@
1
+ import warnings
2
+
3
+ import cv2
4
+ import numpy as np
5
+ import torch
6
+ from kornia.color import rgb_to_grayscale
7
+ from packaging import version
8
+
9
+ try:
10
+ import pycolmap
11
+ except ImportError:
12
+ pycolmap = None
13
+
14
+ from .utils import Extractor
15
+
16
+
17
+ def filter_dog_point(points, scales, angles, image_shape, nms_radius, scores=None):
18
+ h, w = image_shape
19
+ ij = np.round(points - 0.5).astype(int).T[::-1]
20
+
21
+ # Remove duplicate points (identical coordinates).
22
+ # Pick highest scale or score
23
+ s = scales if scores is None else scores
24
+ buffer = np.zeros((h, w))
25
+ np.maximum.at(buffer, tuple(ij), s)
26
+ keep = np.where(buffer[tuple(ij)] == s)[0]
27
+
28
+ # Pick lowest angle (arbitrary).
29
+ ij = ij[:, keep]
30
+ buffer[:] = np.inf
31
+ o_abs = np.abs(angles[keep])
32
+ np.minimum.at(buffer, tuple(ij), o_abs)
33
+ mask = buffer[tuple(ij)] == o_abs
34
+ ij = ij[:, mask]
35
+ keep = keep[mask]
36
+
37
+ if nms_radius > 0:
38
+ # Apply NMS on the remaining points
39
+ buffer[:] = 0
40
+ buffer[tuple(ij)] = s[keep] # scores or scale
41
+
42
+ local_max = torch.nn.functional.max_pool2d(
43
+ torch.from_numpy(buffer).unsqueeze(0),
44
+ kernel_size=nms_radius * 2 + 1,
45
+ stride=1,
46
+ padding=nms_radius,
47
+ ).squeeze(0)
48
+ is_local_max = buffer == local_max.numpy()
49
+ keep = keep[is_local_max[tuple(ij)]]
50
+ return keep
51
+
52
+
53
+ def sift_to_rootsift(x: torch.Tensor, eps=1e-6) -> torch.Tensor:
54
+ x = torch.nn.functional.normalize(x, p=1, dim=-1, eps=eps)
55
+ x.clip_(min=eps).sqrt_()
56
+ return torch.nn.functional.normalize(x, p=2, dim=-1, eps=eps)
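# Illustration only: RootSIFT = L1-normalise, element-wise square root, then
# L2-normalise; dot products then behave like the Hellinger kernel on the
# original descriptor histograms, and the result is always unit-length.
example_desc = torch.rand(4, 128)
example_root = sift_to_rootsift(example_desc)
assert torch.allclose(example_root.norm(dim=-1), torch.ones(4), atol=1e-5)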
57
+
58
+
59
+ def run_opencv_sift(features: cv2.Feature2D, image: np.ndarray) -> np.ndarray:
60
+ """
61
+ Detect keypoints using OpenCV Detector.
62
+ Optionally, perform description.
63
+ Args:
64
+ features: OpenCV based keypoints detector and descriptor
65
+ image: Grayscale image of uint8 data type
66
+ Returns:
67
+ keypoints: 1D array of detected cv2.KeyPoint
68
+ scores: 1D array of responses
69
+ descriptors: 1D array of descriptors
70
+ """
71
+ detections, descriptors = features.detectAndCompute(image, None)
72
+ points = np.array([k.pt for k in detections], dtype=np.float32)
73
+ scores = np.array([k.response for k in detections], dtype=np.float32)
74
+ scales = np.array([k.size for k in detections], dtype=np.float32)
75
+ angles = np.deg2rad(np.array([k.angle for k in detections], dtype=np.float32))
76
+ return points, scores, scales, angles, descriptors
77
+
78
+
79
+ class SIFT(Extractor):
80
+ default_conf = {
81
+ "rootsift": True,
82
+ "nms_radius": 0, # None to disable filtering entirely.
83
+ "max_num_keypoints": 4096,
84
+ "backend": "opencv", # in {opencv, pycolmap, pycolmap_cpu, pycolmap_cuda}
85
+ "detection_threshold": 0.0066667, # from COLMAP
86
+ "edge_threshold": 10,
87
+ "first_octave": -1, # only used by pycolmap, the default of COLMAP
88
+ "num_octaves": 4,
89
+ }
90
+
91
+ preprocess_conf = {
92
+ "resize": 1024,
93
+ }
94
+
95
+ required_data_keys = ["image"]
96
+
97
+ def __init__(self, **conf):
98
+ super().__init__(**conf) # Update with default configuration.
99
+ backend = self.conf.backend
100
+ if backend.startswith("pycolmap"):
101
+ if pycolmap is None:
102
+ raise ImportError(
103
+ "Cannot find module pycolmap: install it with pip"
104
+ "or use backend=opencv."
105
+ )
106
+ options = {
107
+ "peak_threshold": self.conf.detection_threshold,
108
+ "edge_threshold": self.conf.edge_threshold,
109
+ "first_octave": self.conf.first_octave,
110
+ "num_octaves": self.conf.num_octaves,
111
+ "normalization": pycolmap.Normalization.L2, # L1_ROOT is buggy.
112
+ }
113
+ device = (
114
+ "auto" if backend == "pycolmap" else backend.replace("pycolmap_", "")
115
+ )
116
+ if (
117
+ backend == "pycolmap_cpu" or not pycolmap.has_cuda
118
+ ) and pycolmap.__version__ < "0.5.0":
119
+ warnings.warn(
120
+ "The pycolmap CPU SIFT is buggy in version < 0.5.0, "
121
+ "consider upgrading pycolmap or use the CUDA version.",
122
+ stacklevel=1,
123
+ )
124
+ else:
125
+ options["max_num_features"] = self.conf.max_num_keypoints
126
+ self.sift = pycolmap.Sift(options=options, device=device)
127
+ elif backend == "opencv":
128
+ self.sift = cv2.SIFT_create(
129
+ contrastThreshold=self.conf.detection_threshold,
130
+ nfeatures=self.conf.max_num_keypoints,
131
+ edgeThreshold=self.conf.edge_threshold,
132
+ nOctaveLayers=self.conf.num_octaves,
133
+ )
134
+ else:
135
+ backends = {"opencv", "pycolmap", "pycolmap_cpu", "pycolmap_cuda"}
136
+ raise ValueError(
137
+ f"Unknown backend: {backend} not in {{{','.join(backends)}}}."
138
+ )
139
+
140
+ def extract_single_image(self, image: torch.Tensor):
141
+ image_np = image.cpu().numpy().squeeze(0)
142
+
143
+ if self.conf.backend.startswith("pycolmap"):
144
+ if version.parse(pycolmap.__version__) >= version.parse("0.5.0"):
145
+ detections, descriptors = self.sift.extract(image_np)
146
+ scores = None # Scores are not exposed by COLMAP anymore.
147
+ else:
148
+ detections, scores, descriptors = self.sift.extract(image_np)
149
+ keypoints = detections[:, :2] # Keep only (x, y).
150
+ scales, angles = detections[:, -2:].T
151
+ if scores is not None and (
152
+ self.conf.backend == "pycolmap_cpu" or not pycolmap.has_cuda
153
+ ):
154
+ # Set the scores as a combination of abs. response and scale.
155
+ scores = np.abs(scores) * scales
156
+ elif self.conf.backend == "opencv":
157
+ # TODO: Check if opencv keypoints are already in corner convention
158
+ keypoints, scores, scales, angles, descriptors = run_opencv_sift(
159
+ self.sift, (image_np * 255.0).astype(np.uint8)
160
+ )
161
+ pred = {
162
+ "keypoints": keypoints,
163
+ "scales": scales,
164
+ "oris": angles,
165
+ "descriptors": descriptors,
166
+ }
167
+ if scores is not None:
168
+ pred["keypoint_scores"] = scores
169
+
170
+ # sometimes pycolmap returns points outside the image. We remove them
171
+ if self.conf.backend.startswith("pycolmap"):
172
+ is_inside = (
173
+ pred["keypoints"] + 0.5 < np.array([image_np.shape[-2:][::-1]])
174
+ ).all(-1)
175
+ pred = {k: v[is_inside] for k, v in pred.items()}
176
+
177
+ if self.conf.nms_radius is not None:
178
+ keep = filter_dog_point(
179
+ pred["keypoints"],
180
+ pred["scales"],
181
+ pred["oris"],
182
+ image_np.shape,
183
+ self.conf.nms_radius,
184
+ scores=pred.get("keypoint_scores"),
185
+ )
186
+ pred = {k: v[keep] for k, v in pred.items()}
187
+
188
+ pred = {k: torch.from_numpy(v) for k, v in pred.items()}
189
+ if scores is not None:
190
+ # Keep the k keypoints with highest score
191
+ num_points = self.conf.max_num_keypoints
192
+ if num_points is not None and len(pred["keypoints"]) > num_points:
193
+ indices = torch.topk(pred["keypoint_scores"], num_points).indices
194
+ pred = {k: v[indices] for k, v in pred.items()}
195
+
196
+ return pred
197
+
198
+ def forward(self, data: dict) -> dict:
199
+ image = data["image"]
200
+ if image.shape[1] == 3:
201
+ image = rgb_to_grayscale(image)
202
+ device = image.device
203
+ image = image.cpu()
204
+ pred = []
205
+ for k in range(len(image)):
206
+ img = image[k]
207
+ if "image_size" in data.keys():
208
+ # avoid extracting points in padded areas
209
+ w, h = data["image_size"][k]
210
+ img = img[:, :h, :w]
211
+ p = self.extract_single_image(img)
212
+ pred.append(p)
213
+ pred = {k: torch.stack([p[k] for p in pred], 0).to(device) for k in pred[0]}
214
+ if self.conf.rootsift:
215
+ pred["descriptors"] = sift_to_rootsift(pred["descriptors"])
216
+ return pred
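
A minimal usage sketch for the `SIFT` extractor above, driven through `Extractor.extract` (defined in `lightglue/utils.py` later in this diff). The `lightglue` import path is an assumption (upstream LightGlue re-exports `SIFT` from its `__init__`), and the image path is hypothetical.

```python
# Minimal sketch, assuming the bundled lightglue package re-exports SIFT in its
# __init__ (as upstream LightGlue does); the image path is hypothetical.
from lightglue import SIFT
from lightglue.utils import load_image

extractor = SIFT(backend="opencv", max_num_keypoints=2048).eval()
image = load_image("assets/example.jpg")   # CxHxW float tensor in [0, 1]
feats = extractor.extract(image)           # Extractor.extract: resize, forward, rescale keypoints
print(feats["keypoints"].shape, feats["descriptors"].shape)
```
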
imcui/third_party/dad/dad/detectors/third_party/lightglue/superpoint.py ADDED
@@ -0,0 +1,233 @@
1
+ # %BANNER_BEGIN%
2
+ # ---------------------------------------------------------------------
3
+ # %COPYRIGHT_BEGIN%
4
+ #
5
+ # Magic Leap, Inc. ("COMPANY") CONFIDENTIAL
6
+ #
7
+ # Unpublished Copyright (c) 2020
8
+ # Magic Leap, Inc., All Rights Reserved.
9
+ #
10
+ # NOTICE: All information contained herein is, and remains the property
11
+ # of COMPANY. The intellectual and technical concepts contained herein
12
+ # are proprietary to COMPANY and may be covered by U.S. and Foreign
13
+ # Patents, patents in process, and are protected by trade secret or
14
+ # copyright law. Dissemination of this information or reproduction of
15
+ # this material is strictly forbidden unless prior written permission is
16
+ # obtained from COMPANY. Access to the source code contained herein is
17
+ # hereby forbidden to anyone except current COMPANY employees, managers
18
+ # or contractors who have executed Confidentiality and Non-disclosure
19
+ # agreements explicitly covering such access.
20
+ #
21
+ # The copyright notice above does not evidence any actual or intended
22
+ # publication or disclosure of this source code, which includes
23
+ # information that is confidential and/or proprietary, and is a trade
24
+ # secret, of COMPANY. ANY REPRODUCTION, MODIFICATION, DISTRIBUTION,
25
+ # PUBLIC PERFORMANCE, OR PUBLIC DISPLAY OF OR THROUGH USE OF THIS
26
+ # SOURCE CODE WITHOUT THE EXPRESS WRITTEN CONSENT OF COMPANY IS
27
+ # STRICTLY PROHIBITED, AND IN VIOLATION OF APPLICABLE LAWS AND
28
+ # INTERNATIONAL TREATIES. THE RECEIPT OR POSSESSION OF THIS SOURCE
29
+ # CODE AND/OR RELATED INFORMATION DOES NOT CONVEY OR IMPLY ANY RIGHTS
30
+ # TO REPRODUCE, DISCLOSE OR DISTRIBUTE ITS CONTENTS, OR TO MANUFACTURE,
31
+ # USE, OR SELL ANYTHING THAT IT MAY DESCRIBE, IN WHOLE OR IN PART.
32
+ #
33
+ # %COPYRIGHT_END%
34
+ # ----------------------------------------------------------------------
35
+ # %AUTHORS_BEGIN%
36
+ #
37
+ # Originating Authors: Paul-Edouard Sarlin
38
+ #
39
+ # %AUTHORS_END%
40
+ # --------------------------------------------------------------------*/
41
+ # %BANNER_END%
42
+
43
+ # Adapted by Remi Pautrat, Philipp Lindenberger
44
+
45
+ import torch
46
+ from kornia.color import rgb_to_grayscale
47
+ from torch import nn
48
+
49
+ from .utils import Extractor
50
+
51
+
52
+ def simple_nms(scores, nms_radius: int):
53
+ """Fast Non-maximum suppression to remove nearby points"""
54
+ assert nms_radius >= 0
55
+
56
+ def max_pool(x):
57
+ return torch.nn.functional.max_pool2d(
58
+ x, kernel_size=nms_radius * 2 + 1, stride=1, padding=nms_radius
59
+ )
60
+
61
+ zeros = torch.zeros_like(scores)
62
+ max_mask = scores == max_pool(scores)
63
+ for _ in range(2):
64
+ supp_mask = max_pool(max_mask.float()) > 0
65
+ supp_scores = torch.where(supp_mask, zeros, scores)
66
+ new_max_mask = supp_scores == max_pool(supp_scores)
67
+ max_mask = max_mask | (new_max_mask & (~supp_mask))
68
+ return torch.where(max_mask, scores, zeros)
69
+
70
+
71
+ def top_k_keypoints(keypoints, scores, k):
72
+ if k >= len(keypoints):
73
+ return keypoints, scores
74
+ scores, indices = torch.topk(scores, k, dim=0, sorted=True)
75
+ return keypoints[indices], scores
76
+
77
+
78
+ def sample_descriptors(keypoints, descriptors, s: int = 8):
79
+ """Interpolate descriptors at keypoint locations"""
80
+ b, c, h, w = descriptors.shape
81
+ keypoints = keypoints - s / 2 + 0.5
82
+ keypoints /= torch.tensor(
83
+ [(w * s - s / 2 - 0.5), (h * s - s / 2 - 0.5)],
84
+ ).to(keypoints)[None]
85
+ keypoints = keypoints * 2 - 1 # normalize to (-1, 1)
86
+ args = {"align_corners": True} if torch.__version__ >= "1.3" else {}
87
+ descriptors = torch.nn.functional.grid_sample(
88
+ descriptors, keypoints.view(b, 1, -1, 2), mode="bilinear", **args
89
+ )
90
+ descriptors = torch.nn.functional.normalize(
91
+ descriptors.reshape(b, c, -1), p=2, dim=1
92
+ )
93
+ return descriptors
94
+
95
+
96
+ class SuperPoint(Extractor):
97
+ """SuperPoint Convolutional Detector and Descriptor
98
+
99
+ SuperPoint: Self-Supervised Interest Point Detection and
100
+ Description. Daniel DeTone, Tomasz Malisiewicz, and Andrew
101
+ Rabinovich. In CVPRW, 2018. https://arxiv.org/abs/1712.07629
102
+
103
+ """
104
+
105
+ default_conf = {
106
+ "descriptor_dim": 256,
107
+ "nms_radius": 4,
108
+ "max_num_keypoints": None,
109
+ # TODO: detection threshold
110
+ "detection_threshold": 0.0005,
111
+ "remove_borders": 4,
112
+ }
113
+
114
+ preprocess_conf = {
115
+ "resize": 1024,
116
+ }
117
+
118
+ required_data_keys = ["image"]
119
+
120
+ def __init__(self, **conf):
121
+ super().__init__(**conf) # Update with default configuration.
122
+ self.relu = nn.ReLU(inplace=True)
123
+ self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
124
+ c1, c2, c3, c4, c5 = 64, 64, 128, 128, 256
125
+
126
+ self.conv1a = nn.Conv2d(1, c1, kernel_size=3, stride=1, padding=1)
127
+ self.conv1b = nn.Conv2d(c1, c1, kernel_size=3, stride=1, padding=1)
128
+ self.conv2a = nn.Conv2d(c1, c2, kernel_size=3, stride=1, padding=1)
129
+ self.conv2b = nn.Conv2d(c2, c2, kernel_size=3, stride=1, padding=1)
130
+ self.conv3a = nn.Conv2d(c2, c3, kernel_size=3, stride=1, padding=1)
131
+ self.conv3b = nn.Conv2d(c3, c3, kernel_size=3, stride=1, padding=1)
132
+ self.conv4a = nn.Conv2d(c3, c4, kernel_size=3, stride=1, padding=1)
133
+ self.conv4b = nn.Conv2d(c4, c4, kernel_size=3, stride=1, padding=1)
134
+
135
+ self.convPa = nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1)
136
+ self.convPb = nn.Conv2d(c5, 65, kernel_size=1, stride=1, padding=0)
137
+
138
+ self.convDa = nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1)
139
+ self.convDb = nn.Conv2d(
140
+ c5, self.conf.descriptor_dim, kernel_size=1, stride=1, padding=0
141
+ )
142
+
143
+ url = "https://github.com/cvg/LightGlue/releases/download/v0.1_arxiv/superpoint_v1.pth" # noqa
144
+ self.load_state_dict(torch.hub.load_state_dict_from_url(url))
145
+
146
+ if self.conf.max_num_keypoints is not None and self.conf.max_num_keypoints <= 0:
147
+ raise ValueError("max_num_keypoints must be positive or None")
148
+
149
+ def forward(self, data: dict) -> dict:
150
+ """Compute keypoints, scores, descriptors for image"""
151
+ for key in self.required_data_keys:
152
+ assert key in data, f"Missing key {key} in data"
153
+ image = data["image"]
154
+ if image.shape[1] == 3:
155
+ image = rgb_to_grayscale(image)
156
+
157
+ # Shared Encoder
158
+ x = self.relu(self.conv1a(image))
159
+ x = self.relu(self.conv1b(x))
160
+ x = self.pool(x)
161
+ x = self.relu(self.conv2a(x))
162
+ x = self.relu(self.conv2b(x))
163
+ x = self.pool(x)
164
+ x = self.relu(self.conv3a(x))
165
+ x = self.relu(self.conv3b(x))
166
+ x = self.pool(x)
167
+ x = self.relu(self.conv4a(x))
168
+ x = self.relu(self.conv4b(x))
169
+
170
+ # Compute the dense keypoint scores
171
+ cPa = self.relu(self.convPa(x))
172
+ scores = self.convPb(cPa)
173
+ scores = torch.nn.functional.softmax(scores, 1)[:, :-1]
174
+ b, _, h, w = scores.shape
175
+ scores = scores.permute(0, 2, 3, 1).reshape(b, h, w, 8, 8)
176
+ scores = scores.permute(0, 1, 3, 2, 4).reshape(b, h * 8, w * 8)
177
+ scores = simple_nms(scores, self.conf.nms_radius)
178
+
179
+ # Discard keypoints near the image borders
180
+ if self.conf.remove_borders:
181
+ pad = self.conf.remove_borders
182
+ scores[:, :pad] = -1
183
+ scores[:, :, :pad] = -1
184
+ scores[:, -pad:] = -1
185
+ scores[:, :, -pad:] = -1
186
+
187
+ # Extract keypoints
188
+ best_kp = torch.where(scores > self.conf.detection_threshold)
189
+ scores = scores[best_kp]
190
+
191
+ # Separate into batches
192
+ keypoints = [
193
+ torch.stack(best_kp[1:3], dim=-1)[best_kp[0] == i] for i in range(b)
194
+ ]
195
+ scores = [scores[best_kp[0] == i] for i in range(b)]
196
+
197
+ # Keep the k keypoints with highest score
198
+ if self.conf.max_num_keypoints is not None:
199
+ keypoints, scores = list(
200
+ zip(
201
+ *[
202
+ top_k_keypoints(k, s, self.conf.max_num_keypoints)
203
+ for k, s in zip(keypoints, scores)
204
+ ]
205
+ )
206
+ )
207
+
208
+ # Convert (h, w) to (x, y)
209
+ keypoints = [torch.flip(k, [1]).float() for k in keypoints]
210
+
211
+ # Compute the dense descriptors
212
+ cDa = self.relu(self.convDa(x))
213
+ descriptors = self.convDb(cDa)
214
+ descriptors = torch.nn.functional.normalize(descriptors, p=2, dim=1)
215
+
216
+ # Extract descriptors
217
+ descriptors = [
218
+ sample_descriptors(k[None], d[None], 8)[0]
219
+ for k, d in zip(keypoints, descriptors)
220
+ ]
221
+
222
+ return {
223
+ "keypoints": torch.stack(keypoints, 0),
224
+ "keypoint_scores": torch.stack(scores, 0),
225
+ "descriptors": torch.stack(descriptors, 0).transpose(-1, -2).contiguous(),
226
+ }
227
+
228
+
229
+ class ReinforcedFP(SuperPoint):
230
+ def __init__(self, **conf):
231
+ super().__init__(**conf) # Update with default configuration.
232
+ url = "https://github.com/aritrabhowmik/Reinforced-Feature-Points/raw/refs/heads/master/weights/baseline_mixed_loss.pth" # noqa
233
+ self.load_state_dict(torch.hub.load_state_dict_from_url(url))
imcui/third_party/dad/dad/detectors/third_party/lightglue/utils.py ADDED
@@ -0,0 +1,158 @@
1
+ import collections.abc as collections
2
+ from pathlib import Path
3
+ from types import SimpleNamespace
4
+ from typing import Callable, List, Optional, Tuple, Union
5
+
6
+ import cv2
7
+ import kornia
8
+ import numpy as np
9
+ import torch
10
+
11
+
12
+ class ImagePreprocessor:
13
+ default_conf = {
14
+ "resize": None, # target edge length, None for no resizing
15
+ "side": "long",
16
+ "interpolation": "bilinear",
17
+ "align_corners": None,
18
+ "antialias": True,
19
+ }
20
+
21
+ def __init__(self, **conf) -> None:
22
+ super().__init__()
23
+ self.conf = {**self.default_conf, **conf}
24
+ self.conf = SimpleNamespace(**self.conf)
25
+
26
+ def __call__(self, img: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
27
+ """Resize and preprocess an image, return image and resize scale"""
28
+ h, w = img.shape[-2:]
29
+ if self.conf.resize is not None:
30
+ img = kornia.geometry.transform.resize(
31
+ img,
32
+ self.conf.resize,
33
+ side=self.conf.side,
34
+ antialias=self.conf.antialias,
35
+ align_corners=self.conf.align_corners,
36
+ )
37
+ scale = torch.Tensor([img.shape[-1] / w, img.shape[-2] / h]).to(img)
38
+ return img, scale
39
+
40
+
41
+ def map_tensor(input_, func: Callable):
42
+ string_classes = (str, bytes)
43
+ if isinstance(input_, string_classes):
44
+ return input_
45
+ elif isinstance(input_, collections.Mapping):
46
+ return {k: map_tensor(sample, func) for k, sample in input_.items()}
47
+ elif isinstance(input_, collections.Sequence):
48
+ return [map_tensor(sample, func) for sample in input_]
49
+ elif isinstance(input_, torch.Tensor):
50
+ return func(input_)
51
+ else:
52
+ return input_
53
+
54
+
55
+ def batch_to_device(batch: dict, device: str = "cpu", non_blocking: bool = True):
56
+ """Move batch (dict) to device"""
57
+
58
+ def _func(tensor):
59
+ return tensor.to(device=device, non_blocking=non_blocking).detach()
60
+
61
+ return map_tensor(batch, _func)
62
+
63
+
64
+ def rbd(data: dict) -> dict:
65
+ """Remove batch dimension from elements in data"""
66
+ return {
67
+ k: v[0] if isinstance(v, (torch.Tensor, np.ndarray, list)) else v
68
+ for k, v in data.items()
69
+ }
70
+
71
+
72
+ def numpy_image_to_torch(image: np.ndarray) -> torch.Tensor:
73
+ """Normalize the image tensor and reorder the dimensions."""
74
+ if image.ndim == 3:
75
+ image = image.transpose((2, 0, 1)) # HxWxC to CxHxW
76
+ elif image.ndim == 2:
77
+ image = image[None] # add channel axis
78
+ else:
79
+ raise ValueError(f"Not an image: {image.shape}")
80
+ return torch.tensor(image / 255.0, dtype=torch.float)
81
+
82
+
83
+ def resize_image(
84
+ image: np.ndarray,
85
+ size: Union[List[int], int],
86
+ fn: str = "max",
87
+ interp: Optional[str] = "area",
88
+ ) -> np.ndarray:
89
+ """Resize an image to a fixed size, or according to max or min edge."""
90
+ h, w = image.shape[:2]
91
+
92
+ fn = {"max": max, "min": min}[fn]
93
+ if isinstance(size, int):
94
+ scale = size / fn(h, w)
95
+ h_new, w_new = int(round(h * scale)), int(round(w * scale))
96
+ scale = (w_new / w, h_new / h)
97
+ elif isinstance(size, (tuple, list)):
98
+ h_new, w_new = size
99
+ scale = (w_new / w, h_new / h)
100
+ else:
101
+ raise ValueError(f"Incorrect new size: {size}")
102
+ mode = {
103
+ "linear": cv2.INTER_LINEAR,
104
+ "cubic": cv2.INTER_CUBIC,
105
+ "nearest": cv2.INTER_NEAREST,
106
+ "area": cv2.INTER_AREA,
107
+ }[interp]
108
+ return cv2.resize(image, (w_new, h_new), interpolation=mode), scale
109
+
110
+
111
+ def load_image(path: Path, resize: int = None, **kwargs) -> torch.Tensor:
112
+ if not Path(path).exists():
113
+ raise FileNotFoundError(f"No image at path {path}.")
114
+ mode = cv2.IMREAD_COLOR
115
+ image = cv2.imread(str(path), mode)
116
+ if image is None:
117
+ raise IOError(f"Could not read image at {path}.")
118
+ image = image[..., ::-1]
119
+ if resize is not None:
120
+ image, _ = resize_image(image, resize, **kwargs)
121
+ return numpy_image_to_torch(image)
122
+
123
+
124
+ class Extractor(torch.nn.Module):
125
+ def __init__(self, **conf):
126
+ super().__init__()
127
+ self.conf = SimpleNamespace(**{**self.default_conf, **conf})
128
+
129
+ @torch.no_grad()
130
+ def extract(self, img: torch.Tensor, **conf) -> dict:
131
+ """Perform extraction with online resizing"""
132
+ if img.dim() == 3:
133
+ img = img[None] # add batch dim
134
+ assert img.dim() == 4 and img.shape[0] == 1
135
+ shape = img.shape[-2:][::-1]
136
+ img, scales = ImagePreprocessor(**{**self.preprocess_conf, **conf})(img)
137
+ feats = self.forward({"image": img})
138
+ feats["image_size"] = torch.tensor(shape)[None].to(img).float()
139
+ feats["keypoints"] = (feats["keypoints"] + 0.5) / scales[None] - 0.5
140
+ return feats
141
+
142
+
143
+ def match_pair(
144
+ extractor,
145
+ matcher,
146
+ image0: torch.Tensor,
147
+ image1: torch.Tensor,
148
+ device: str = "cpu",
149
+ **preprocess,
150
+ ):
151
+ """Match a pair of images (image0, image1) with an extractor and matcher"""
152
+ feats0 = extractor.extract(image0, **preprocess)
153
+ feats1 = extractor.extract(image1, **preprocess)
154
+ matches01 = matcher({"image0": feats0, "image1": feats1})
155
+ data = [feats0, feats1, matches01]
156
+ # remove batch dim and move to target device
157
+ feats0, feats1, matches01 = [batch_to_device(rbd(x), device) for x in data]
158
+ return feats0, feats1, matches01
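
A hedged end-to-end sketch of `match_pair`, pairing the SuperPoint extractor above with the bundled LightGlue matcher. The `LightGlue(features=...)` constructor and the `matches01["matches"]` key follow the upstream LightGlue API and are assumptions about the bundled copy; the image paths are hypothetical.

```python
# Sketch under the assumption that the bundled lightglue.py follows the upstream
# LightGlue API (LightGlue(features=...), matches01["matches"] of shape (K, 2)).
from lightglue import LightGlue, SuperPoint
from lightglue.utils import load_image, match_pair

extractor = SuperPoint(max_num_keypoints=2048).eval()
matcher = LightGlue(features="superpoint").eval()
feats0, feats1, matches01 = match_pair(
    extractor, matcher, load_image("a.jpg"), load_image("b.jpg")  # hypothetical image paths
)
matched_kpts0 = feats0["keypoints"][matches01["matches"][..., 0]]  # (K, 2) points in image 0
matched_kpts1 = feats1["keypoints"][matches01["matches"][..., 1]]
```
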
imcui/third_party/dad/dad/detectors/third_party/lightglue_detector.py ADDED
@@ -0,0 +1,42 @@
1
+ from pathlib import Path
2
+ from typing import Union
3
+ import torch
4
+ from .lightglue.utils import load_image
5
+ from dad.utils import (
6
+ get_best_device,
7
+ )
8
+ from dad.types import Detector
9
+
10
+
11
+ class LightGlueDetector(Detector):
12
+ def __init__(self, model, resize=None, **kwargs):
13
+ super().__init__()
14
+ self.model = model(**kwargs).eval().to(get_best_device())
15
+ if resize is not None:
16
+ self.model.preprocess_conf["resize"] = resize
17
+
18
+ @property
19
+ def topleft(self):
20
+ return 0.0
21
+
22
+ def load_image(self, im_path: Union[str, Path]):
23
+ return {"image": load_image(im_path).to(get_best_device())}
24
+
25
+ @torch.inference_mode()
26
+ def detect(
27
+ self,
28
+ batch: dict[str, torch.Tensor],
29
+ *,
30
+ num_keypoints: int,
31
+ return_dense_probs: bool = False,
32
+ ):
33
+ image = batch["image"]
34
+ self.model.conf.max_num_keypoints = num_keypoints
35
+ ret = self.model.extract(image)
36
+ kpts = self.to_normalized_coords(
37
+ ret["keypoints"], ret["image_size"][0, 1], ret["image_size"][0, 0]
38
+ )
39
+ result = {"keypoints": kpts, "keypoint_probs": None}
40
+ if return_dense_probs:
41
+ result["dense_probs"] = ret["dense_probs"] if "dense_probs" in ret else None
42
+ return result
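
A short sketch of how `LightGlueDetector` wraps one of the extractors above behind dad's `Detector` interface. The import locations are assumptions based on this file's relative imports, and the image path is hypothetical.

```python
# Sketch assuming the dad package is installed and exposes these modules as laid
# out in this diff; the image path is hypothetical.
from dad.detectors.third_party.lightglue import SuperPoint          # assumed export
from dad.detectors.third_party.lightglue_detector import LightGlueDetector

detector = LightGlueDetector(SuperPoint, resize=1024)
batch = detector.load_image("assets/example.jpg")   # {"image": CxHxW float tensor on the best device}
out = detector.detect(batch, num_keypoints=2048)
print(out["keypoints"].shape)                       # keypoints in normalized (x, y) coordinates
```
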
imcui/third_party/dad/dad/detectors/third_party/rekd/config.py ADDED
@@ -0,0 +1,206 @@
1
+ import argparse
2
+
3
+ ## for fix seed
4
+ import random
5
+ import torch
6
+ import numpy
7
+
8
+
9
+ def get_config(jupyter=False):
10
+ parser = argparse.ArgumentParser(description="Train REKD Architecture")
11
+
12
+ ## basic configuration
13
+ parser.add_argument(
14
+ "--data_dir",
15
+ type=str,
16
+ default="../ImageNet2012/ILSVRC2012_img_val", # default='path-to-ImageNet',
17
+ help="The root path to the data from which the synthetic dataset will be created.",
18
+ )
19
+ parser.add_argument(
20
+ "--synth_dir",
21
+ type=str,
22
+ default="",
23
+ help="The path to save the generated sythetic image pairs.",
24
+ )
25
+ parser.add_argument(
26
+ "--log_dir",
27
+ type=str,
28
+ default="trained_models/weights",
29
+ help="The path to save the REKD weights.",
30
+ )
31
+ parser.add_argument(
32
+ "--load_dir",
33
+ type=str,
34
+ default="",
35
+ help="Set saved model parameters if resume training is desired.",
36
+ )
37
+ parser.add_argument(
38
+ "--exp_name",
39
+ type=str,
40
+ default="REKD",
41
+ help="The Rotaton-equivaraiant Keypoint Detection (REKD) experiment name",
42
+ )
43
+ ## network architecture
44
+ parser.add_argument(
45
+ "--factor_scaling_pyramid",
46
+ type=float,
47
+ default=1.2,
48
+ help="The scale factor between the multi-scale pyramid levels in the architecture.",
49
+ )
50
+ parser.add_argument(
51
+ "--group_size",
52
+ type=int,
53
+ default=36,
54
+ help="The number of groups for the group convolution.",
55
+ )
56
+ parser.add_argument(
57
+ "--dim_first",
58
+ type=int,
59
+ default=2,
60
+ help="The number of channels of the first layer",
61
+ )
62
+ parser.add_argument(
63
+ "--dim_second",
64
+ type=int,
65
+ default=2,
66
+ help="The number of channels of the second layer",
67
+ )
68
+ parser.add_argument(
69
+ "--dim_third",
70
+ type=int,
71
+ default=2,
72
+ help="The number of channels of the thrid layer",
73
+ )
74
+ ## network training
75
+ parser.add_argument(
76
+ "--batch_size", type=int, default=16, help="The batch size for training."
77
+ )
78
+ parser.add_argument(
79
+ "--num_epochs", type=int, default=20, help="Number of epochs for training."
80
+ )
81
+ ## Loss function
82
+ parser.add_argument(
83
+ "--init_initial_learning_rate",
84
+ type=float,
85
+ default=1e-3,
86
+ help="The init initial learning rate value.",
87
+ )
88
+ parser.add_argument(
89
+ "--MSIP_sizes", type=str, default="8,16,24,32,40", help="MSIP sizes."
90
+ )
91
+ parser.add_argument(
92
+ "--MSIP_factor_loss",
93
+ type=str,
94
+ default="256.0,64.0,16.0,4.0,1.0",
95
+ help="MSIP loss balancing parameters.",
96
+ )
97
+ parser.add_argument("--ori_loss_balance", type=float, default=100.0, help="")
98
+ ## Dataset generation
99
+ parser.add_argument(
100
+ "--patch_size",
101
+ type=int,
102
+ default=192,
103
+ help="The patch size of the generated dataset.",
104
+ )
105
+ parser.add_argument(
106
+ "--max_angle",
107
+ type=int,
108
+ default=180,
109
+ help="The max angle value for generating a synthetic view to train REKD.",
110
+ )
111
+ parser.add_argument(
112
+ "--min_scale",
113
+ type=float,
114
+ default=1.0,
115
+ help="The min scale value for generating a synthetic view to train REKD.",
116
+ )
117
+ parser.add_argument(
118
+ "--max_scale",
119
+ type=float,
120
+ default=1.0,
121
+ help="The max scale value for generating a synthetic view to train REKD.",
122
+ )
123
+ parser.add_argument(
124
+ "--max_shearing",
125
+ type=float,
126
+ default=0.0,
127
+ help="The max shearing value for generating a synthetic view to train REKD.",
128
+ )
129
+ parser.add_argument(
130
+ "--num_training_data",
131
+ type=int,
132
+ default=9000,
133
+ help="The number of the generated dataset.",
134
+ )
135
+ parser.add_argument(
136
+ "--is_debugging",
137
+ type=bool,
138
+ default=False,
139
+ help="Set variable to True if you desire to train network on a smaller dataset.",
140
+ )
141
+ ## For eval/inference
142
+ parser.add_argument(
143
+ "--num_points",
144
+ type=int,
145
+ default=1500,
146
+ help="the number of points at evaluation time.",
147
+ )
148
+ parser.add_argument(
149
+ "--pyramid_levels", type=int, default=5, help="downsampling pyramid levels."
150
+ )
151
+ parser.add_argument(
152
+ "--upsampled_levels", type=int, default=2, help="upsampling image levels."
153
+ )
154
+ parser.add_argument(
155
+ "--nms_size",
156
+ type=int,
157
+ default=15,
158
+ help="The NMS size for computing the validation repeatability.",
159
+ )
160
+ parser.add_argument(
161
+ "--border_size",
162
+ type=int,
163
+ default=15,
164
+ help="The number of pixels to remove from the borders to compute the repeatability.",
165
+ )
166
+ ## For HPatches evaluation
167
+ parser.add_argument(
168
+ "--hpatches_path",
169
+ type=str,
170
+ default="./datasets/hpatches-sequences-release",
171
+ help="dataset ",
172
+ )
173
+ parser.add_argument(
174
+ "--eval_split",
175
+ type=str,
176
+ default="debug",
177
+ help="debug, view, illum, full, debug_view, debug_illum ...",
178
+ )
179
+ parser.add_argument(
180
+ "--descriptor", type=str, default="hardnet", help="hardnet, sosnet, hynet"
181
+ )
182
+
183
+ args, weird_args = (
184
+ parser.parse_known_args() if not jupyter else parser.parse_args(args=[])
185
+ )
186
+
187
+ fix_randseed(12345)
188
+
189
+ if args.synth_dir == "":
190
+ args.synth_dir = "datasets/synth_data"
191
+
192
+ args.MSIP_sizes = [int(i) for i in args.MSIP_sizes.split(",")]
193
+ args.MSIP_factor_loss = [float(i) for i in args.MSIP_factor_loss.split(",")]
194
+
195
+ return args
196
+
197
+
198
+ def fix_randseed(randseed):
199
+ r"""Fix random seed"""
200
+ random.seed(randseed)
201
+ numpy.random.seed(randseed)
202
+ torch.manual_seed(randseed)
203
+ torch.cuda.manual_seed(randseed)
204
+ torch.cuda.manual_seed_all(randseed)
205
+ torch.backends.cudnn.benchmark, torch.backends.cudnn.deterministic = False, True
206
+ # torch.backends.cudnn.benchmark, torch.backends.cudnn.deterministic = True, False
imcui/third_party/dad/dad/detectors/third_party/rekd/geometry_tools.py ADDED
@@ -0,0 +1,204 @@
1
+ from cv2 import warpPerspective as applyH
2
+ import numpy as np
3
+ import torch
4
+
5
+
6
+ def apply_nms(score_map, size):
7
+ from scipy.ndimage.filters import maximum_filter
8
+
9
+ score_map = score_map * (
10
+ score_map == maximum_filter(score_map, footprint=np.ones((size, size)))
11
+ )
12
+ return score_map
13
+
14
+
15
+ def remove_borders(images, borders):
16
+ ## input [B,C,H,W]
17
+ shape = images.shape
18
+
19
+ if len(shape) == 4:
20
+ for batch_id in range(shape[0]):
21
+ images[batch_id, :, 0:borders, :] = 0
22
+ images[batch_id, :, :, 0:borders] = 0
23
+ images[batch_id, :, shape[2] - borders : shape[2], :] = 0
24
+ images[batch_id, :, :, shape[3] - borders : shape[3]] = 0
25
+ elif len(shape) == 2:
26
+ images[0:borders, :] = 0
27
+ images[:, 0:borders] = 0
28
+ images[shape[0] - borders : shape[0], :] = 0
29
+ images[:, shape[1] - borders : shape[1]] = 0
30
+ else:
31
+ print("Not implemented")
32
+ exit()
33
+
34
+ return images
35
+
36
+
37
+ def create_common_region_masks(h_dst_2_src, shape_src, shape_dst):
38
+ # Create mask. Only take into account pixels in the two images
39
+ inv_h = np.linalg.inv(h_dst_2_src)
40
+ inv_h = inv_h / inv_h[2, 2]
41
+
42
+ # Apply the mask to the destination. Where there is no 1, we cannot find a corresponding point in the source.
43
+ ones_dst = np.ones((shape_dst[0], shape_dst[1]))
44
+ ones_dst = remove_borders(ones_dst, borders=15)
45
+ mask_src = applyH(ones_dst, h_dst_2_src, (shape_src[1], shape_src[0]))
46
+ mask_src = np.where(mask_src >= 0.75, 1.0, 0.0)
47
+ mask_src = remove_borders(mask_src, borders=15)
48
+
49
+ ones_src = np.ones((shape_src[0], shape_src[1]))
50
+ ones_src = remove_borders(ones_src, borders=15)
51
+ mask_dst = applyH(ones_src, inv_h, (shape_dst[1], shape_dst[0]))
52
+ mask_dst = np.where(mask_dst >= 0.75, 1.0, 0.0)
53
+ mask_dst = remove_borders(mask_dst, borders=15)
54
+
55
+ return mask_src, mask_dst
56
+
57
+
58
+ def prepare_homography(hom):
59
+ if len(hom.shape) == 1:
60
+ h = np.zeros((3, 3))
61
+ for j in range(3):
62
+ for i in range(3):
63
+ if j == 2 and i == 2:
64
+ h[j, i] = 1.0
65
+ else:
66
+ h[j, i] = hom[j * 3 + i]
67
+ elif len(hom.shape) == 2: ## batch
68
+ ones = torch.ones(hom.shape[0]).unsqueeze(1)
69
+ h = torch.cat([hom, ones], dim=1).reshape(-1, 3, 3).type(torch.float32)
70
+
71
+ return h
72
+
73
+
74
+ def getAff(x, y, H):
75
+ h11 = H[0, 0]
76
+ h12 = H[0, 1]
77
+ h13 = H[0, 2]
78
+ h21 = H[1, 0]
79
+ h22 = H[1, 1]
80
+ h23 = H[1, 2]
81
+ h31 = H[2, 0]
82
+ h32 = H[2, 1]
83
+ h33 = H[2, 2]
84
+ fxdx = (
85
+ h11 / (h31 * x + h32 * y + h33)
86
+ - (h11 * x + h12 * y + h13) * h31 / (h31 * x + h32 * y + h33) ** 2
87
+ )
88
+ fxdy = (
89
+ h12 / (h31 * x + h32 * y + h33)
90
+ - (h11 * x + h12 * y + h13) * h32 / (h31 * x + h32 * y + h33) ** 2
91
+ )
92
+
93
+ fydx = (
94
+ h21 / (h31 * x + h32 * y + h33)
95
+ - (h21 * x + h22 * y + h23) * h31 / (h31 * x + h32 * y + h33) ** 2
96
+ )
97
+ fydy = (
98
+ h22 / (h31 * x + h32 * y + h33)
99
+ - (h21 * x + h22 * y + h23) * h32 / (h31 * x + h32 * y + h33) ** 2
100
+ )
101
+
102
+ Aff = [[fxdx, fxdy], [fydx, fydy]]
103
+
104
+ return np.asarray(Aff)
105
+
106
+
107
+ def apply_homography_to_points(points, h):
108
+ new_points = []
109
+
110
+ for point in points:
111
+ new_point = h.dot([point[0], point[1], 1.0])
112
+
113
+ tmp = point[2] ** 2 + np.finfo(np.float32).eps
114
+
115
+ Mi1 = [[1 / tmp, 0], [0, 1 / tmp]]
116
+ Mi1_inv = np.linalg.inv(Mi1)
117
+ Aff = getAff(point[0], point[1], h)
118
+
119
+ BMB = np.linalg.inv(np.dot(Aff, np.dot(Mi1_inv, np.matrix.transpose(Aff))))
120
+
121
+ [e, _] = np.linalg.eig(BMB)
122
+ new_radious = 1 / ((e[0] * e[1]) ** 0.5) ** 0.5
123
+
124
+ new_point = [
125
+ new_point[0] / new_point[2],
126
+ new_point[1] / new_point[2],
127
+ new_radious,
128
+ point[3],
129
+ ]
130
+ new_points.append(new_point)
131
+
132
+ return np.asarray(new_points)
133
+
134
+
135
+ def find_index_higher_scores(map, num_points=1000, threshold=-1):
136
+ # Best n points
137
+ if threshold == -1:
138
+ flatten = map.flatten()
139
+ order_array = np.sort(flatten)
140
+
141
+ order_array = np.flip(order_array, axis=0)
142
+
143
+ if order_array.shape[0] < num_points:
144
+ num_points = order_array.shape[0]
145
+
146
+ threshold = order_array[num_points - 1]
147
+
148
+ if threshold <= 0.0:
149
+ ### This is the problem case, which yields fewer keypoints than the requested "num_points".
150
+ indexes = np.argwhere(order_array > 0.0)
151
+
152
+ if len(indexes) == 0:
153
+ threshold = 0.0
154
+ else:
155
+ threshold = order_array[indexes[len(indexes) - 1]]
156
+
157
+ indexes = np.argwhere(map >= threshold)
158
+
159
+ return indexes[:num_points]
160
+
161
+
162
+ def get_point_coordinates(
163
+ map, scale_value=1.0, num_points=1000, threshold=-1, order_coord="xysr"
164
+ ):
165
+ ## input numpy array score map : [H, W]
166
+ indexes = find_index_higher_scores(map, num_points=num_points, threshold=threshold)
167
+ new_indexes = []
168
+ for ind in indexes:
169
+ scores = map[ind[0], ind[1]]
170
+ if order_coord == "xysr":
171
+ tmp = [ind[1], ind[0], scale_value, scores]
172
+ elif order_coord == "yxsr":
173
+ tmp = [ind[0], ind[1], scale_value, scores]
174
+
175
+ new_indexes.append(tmp)
176
+
177
+ indexes = np.asarray(new_indexes)
178
+
179
+ return np.asarray(indexes)
180
+
181
+
182
+ def get_point_coordinates3D(
183
+ map,
184
+ scale_factor=1.0,
185
+ up_levels=0,
186
+ num_points=1000,
187
+ threshold=-1,
188
+ order_coord="xysr",
189
+ ):
190
+ indexes = find_index_higher_scores(map, num_points=num_points, threshold=threshold)
191
+ new_indexes = []
192
+ for ind in indexes:
193
+ scale_value = scale_factor ** (ind[2] - up_levels)
194
+ scores = map[ind[0], ind[1], ind[2]]
195
+ if order_coord == "xysr":
196
+ tmp = [ind[1], ind[0], scale_value, scores]
197
+ elif order_coord == "yxsr":
198
+ tmp = [ind[0], ind[1], scale_value, scores]
199
+
200
+ new_indexes.append(tmp)
201
+
202
+ indexes = np.asarray(new_indexes)
203
+
204
+ return np.asarray(indexes)
imcui/third_party/dad/dad/detectors/third_party/rekd/model/REKD.py ADDED
@@ -0,0 +1,234 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+
4
+
5
+ from .kernels import gaussian_multiple_channels
6
+
7
+
8
+ class REKD(torch.nn.Module):
9
+ def __init__(self, args, device):
10
+ super(REKD, self).__init__()
11
+ from e2cnn import gspaces
12
+ from e2cnn import nn
13
+
14
+ self.pyramid_levels = 3
15
+ self.factor_scaling = args.factor_scaling_pyramid
16
+
17
+ # Smooth Gausian Filter
18
+ num_channels = 1 ## gray scale image
19
+ self.gaussian_avg = gaussian_multiple_channels(num_channels, 1.5)
20
+
21
+ r2_act = gspaces.Rot2dOnR2(N=args.group_size)
22
+
23
+ self.feat_type_in = nn.FieldType(
24
+ r2_act, num_channels * [r2_act.trivial_repr]
25
+ ) ## input 1 channels (gray scale image)
26
+
27
+ feat_type_out1 = nn.FieldType(r2_act, args.dim_first * [r2_act.regular_repr])
28
+ feat_type_out2 = nn.FieldType(r2_act, args.dim_second * [r2_act.regular_repr])
29
+ feat_type_out3 = nn.FieldType(r2_act, args.dim_third * [r2_act.regular_repr])
30
+
31
+ feat_type_ori_est = nn.FieldType(r2_act, [r2_act.regular_repr])
32
+
33
+ self.block1 = nn.SequentialModule(
34
+ nn.R2Conv(
35
+ self.feat_type_in, feat_type_out1, kernel_size=5, padding=2, bias=False
36
+ ),
37
+ nn.InnerBatchNorm(feat_type_out1),
38
+ nn.ReLU(feat_type_out1, inplace=True),
39
+ )
40
+ self.block2 = nn.SequentialModule(
41
+ nn.R2Conv(
42
+ feat_type_out1, feat_type_out2, kernel_size=5, padding=2, bias=False
43
+ ),
44
+ nn.InnerBatchNorm(feat_type_out2),
45
+ nn.ReLU(feat_type_out2, inplace=True),
46
+ )
47
+ self.block3 = nn.SequentialModule(
48
+ nn.R2Conv(
49
+ feat_type_out2, feat_type_out3, kernel_size=5, padding=2, bias=False
50
+ ),
51
+ nn.InnerBatchNorm(feat_type_out3),
52
+ nn.ReLU(feat_type_out3, inplace=True),
53
+ )
54
+
55
+ self.ori_learner = nn.SequentialModule(
56
+ nn.R2Conv(
57
+ feat_type_out3, feat_type_ori_est, kernel_size=1, padding=0, bias=False
58
+ ) ## Channel pooling by 8*G -> 1*G conv.
59
+ )
60
+ self.softmax = torch.nn.Softmax(dim=1)
61
+
62
+ self.gpool = nn.GroupPooling(feat_type_out3)
63
+ self.last_layer_learner = torch.nn.Sequential(
64
+ torch.nn.BatchNorm2d(num_features=args.dim_third * self.pyramid_levels),
65
+ torch.nn.Conv2d(
66
+ in_channels=args.dim_third * self.pyramid_levels,
67
+ out_channels=1,
68
+ kernel_size=1,
69
+ bias=True,
70
+ ),
71
+ torch.nn.ReLU(inplace=True), ## clamp to make the scores positive values.
72
+ )
73
+
74
+ self.dim_third = args.dim_third
75
+ self.group_size = args.group_size
76
+ self.exported = False
77
+
78
+ def export(self):
79
+ from e2cnn import nn
80
+
81
+ for name, module in dict(self.named_modules()).copy().items():
82
+ if isinstance(module, nn.EquivariantModule):
83
+ # print(name, "--->", module)
84
+ module = module.export()
85
+ setattr(self, name, module)
86
+
87
+ self.exported = True
88
+
89
+ def forward(self, input_data):
90
+ features_key, features_o = self.compute_features(input_data)
91
+
92
+ return features_key, features_o
93
+
94
+ def compute_features(self, input_data):
95
+ B, _, H, W = input_data.shape
96
+
97
+ for idx_level in range(self.pyramid_levels):
98
+ with torch.no_grad():
99
+ input_data_resized = self._resize_input_image(
100
+ input_data, idx_level, H, W
101
+ )
102
+
103
+ if H > 2500 or W > 2500:
104
+ features_t, features_o = self._forwarding_networks_divide_grid(
105
+ input_data_resized
106
+ )
107
+ else:
108
+ features_t, features_o = self._forwarding_networks(input_data_resized)
109
+
110
+ features_t = F.interpolate(
111
+ features_t, size=(H, W), align_corners=True, mode="bilinear"
112
+ )
113
+ features_o = F.interpolate(
114
+ features_o, size=(H, W), align_corners=True, mode="bilinear"
115
+ )
116
+
117
+ if idx_level == 0:
118
+ features_key = features_t
119
+ features_ori = features_o
120
+ else:
121
+ features_key = torch.cat([features_key, features_t], axis=1)
122
+ features_ori = torch.add(features_ori, features_o)
123
+
124
+ features_key = self.last_layer_learner(features_key)
125
+ features_ori = self.softmax(features_ori)
126
+
127
+ return features_key, features_ori
128
+
129
+ def _forwarding_networks(self, input_data_resized):
130
+ from e2cnn import nn
131
+
132
+ # wrap the input tensor in a GeometricTensor (associate it with the input type)
133
+ features_t = (
134
+ nn.GeometricTensor(input_data_resized, self.feat_type_in)
135
+ if not self.exported
136
+ else input_data_resized
137
+ )
138
+
139
+ ## Geometric tensor feed forwarding
140
+ features_t = self.block1(features_t)
141
+ features_t = self.block2(features_t)
142
+ features_t = self.block3(features_t)
143
+
144
+ ## orientation pooling
145
+ features_o = self.ori_learner(features_t) ## self.cpool
146
+ features_o = features_o.tensor if not self.exported else features_o
147
+
148
+ ## keypoint pooling
149
+ features_t = self.gpool(features_t)
150
+ features_t = features_t.tensor if not self.exported else features_t
151
+
152
+ return features_t, features_o
153
+
154
+ def _forwarding_networks_divide_grid(self, input_data_resized):
155
+ ## for inference time high resolution image. # spatial grid 4
156
+ B, _, H_resized, W_resized = input_data_resized.shape
157
+ features_t = torch.zeros(B, self.dim_third, H_resized, W_resized).cuda()
158
+ features_o = torch.zeros(B, self.group_size, H_resized, W_resized).cuda()
159
+ h_divide = 2
160
+ w_divide = 2
161
+ for idx in range(h_divide):
162
+ for jdx in range(w_divide):
163
+ ## compute the start and end spatial index
164
+ h_start = H_resized // h_divide * idx
165
+ w_start = W_resized // w_divide * jdx
166
+ h_end = H_resized // h_divide * (idx + 1)
167
+ w_end = W_resized // w_divide * (jdx + 1)
168
+ ## crop the input image
169
+ input_data_divided = input_data_resized[
170
+ :, :, h_start:h_end, w_start:w_end
171
+ ]
172
+ features_t_temp, features_o_temp = self._forwarding_networks(
173
+ input_data_divided
174
+ )
175
+ ## take into the values.
176
+ features_t[:, :, h_start:h_end, w_start:w_end] = features_t_temp
177
+ features_o[:, :, h_start:h_end, w_start:w_end] = features_o_temp
178
+
179
+ return features_t, features_o
180
+
181
+ def _resize_input_image(self, input_data, idx_level, H, W):
182
+ if idx_level == 0:
183
+ input_data_smooth = input_data
184
+ else:
185
+ ## (7,7) size gaussian kernel.
186
+ input_data_smooth = F.conv2d(
187
+ input_data, self.gaussian_avg.to(input_data.device), padding=[3, 3]
188
+ )
189
+
190
+ target_resize = (
191
+ int(H / (self.factor_scaling**idx_level)),
192
+ int(W / (self.factor_scaling**idx_level)),
193
+ )
194
+
195
+ input_data_resized = F.interpolate(
196
+ input_data_smooth, size=target_resize, align_corners=True, mode="bilinear"
197
+ )
198
+
199
+ input_data_resized = self.local_norm_image(input_data_resized)
200
+
201
+ return input_data_resized
202
+
203
+ def local_norm_image(self, x, k_size=65, eps=1e-10):
204
+ pad = int(k_size / 2)
205
+
206
+ x_pad = F.pad(x, (pad, pad, pad, pad), mode="reflect")
207
+ x_mean = F.avg_pool2d(
208
+ x_pad, kernel_size=[k_size, k_size], stride=[1, 1], padding=0
209
+ ) ## padding='valid'==0
210
+ x2_mean = F.avg_pool2d(
211
+ torch.pow(x_pad, 2.0),
212
+ kernel_size=[k_size, k_size],
213
+ stride=[1, 1],
214
+ padding=0,
215
+ )
216
+
217
+ x_std = torch.sqrt(torch.abs(x2_mean - x_mean * x_mean)) + eps
218
+ x_norm = (x - x_mean) / (1.0 + x_std)
219
+
220
+ return x_norm
221
+
222
+
223
+ def count_model_parameters(model):
224
+ ## Count the number of learnable parameters.
225
+ print("================ List of Learnable model parameters ================ ")
226
+ for n, p in model.named_parameters():
227
+ if p.requires_grad:
228
+ print("{} {}".format(n, p.data.shape))
229
+ else:
230
+ print("\n\n\n None learnable params {} {}".format(n, p.data.shape))
231
+ model_parameters = filter(lambda p: p.requires_grad, model.parameters())
232
+ params = sum([torch.prod(torch.tensor(p.size())) for p in model_parameters])
233
+ print("The number of learnable parameters : {} ".format(params.data))
234
+ print("==================================================================== ")
imcui/third_party/dad/dad/detectors/third_party/rekd/model/kernels.py ADDED
@@ -0,0 +1,118 @@
1
+ import math
2
+ import torch
3
+
4
+
5
+ def gaussian_multiple_channels(num_channels, sigma):
6
+ r = 2 * sigma
7
+ size = 2 * r + 1
8
+ size = int(math.ceil(size))
9
+ x = torch.arange(0, size, 1, dtype=torch.float)
10
+ y = x.unsqueeze(1)
11
+ x0 = y0 = r
12
+
13
+ gaussian = torch.exp(-1 * (((x - x0) ** 2 + (y - y0) ** 2) / (2 * (sigma**2)))) / (
14
+ (2 * math.pi * (sigma**2)) ** 0.5
15
+ )
16
+ gaussian = gaussian.to(dtype=torch.float32)
17
+
18
+ weights = torch.zeros((num_channels, num_channels, size, size), dtype=torch.float32)
19
+ for i in range(num_channels):
20
+ weights[i, i, :, :] = gaussian
21
+
22
+ return weights
23
+
24
+
25
+ def ones_multiple_channels(size, num_channels):
26
+ ones = torch.ones((size, size))
27
+ weights = torch.zeros((num_channels, num_channels, size, size), dtype=torch.float32)
28
+
29
+ for i in range(num_channels):
30
+ weights[i, i, :, :] = ones
31
+
32
+ return weights
33
+
34
+
35
+ def grid_indexes(size):
36
+ weights = torch.zeros((2, 1, size, size), dtype=torch.float32)
37
+
38
+ columns = []
39
+ for idx in range(1, 1 + size):
40
+ columns.append(torch.ones((size)) * idx)
41
+ columns = torch.stack(columns)
42
+
43
+ rows = []
44
+ for idx in range(1, 1 + size):
45
+ rows.append(torch.tensor(range(1, 1 + size)))
46
+ rows = torch.stack(rows)
47
+
48
+ weights[0, 0, :, :] = columns
49
+ weights[1, 0, :, :] = rows
50
+
51
+ return weights
52
+
53
+
54
+ def get_kernel_size(factor):
55
+ """
56
+ Find the kernel size given the desired factor of upsampling.
57
+ """
58
+ return 2 * factor - factor % 2
59
+
60
+
61
+ def linear_upsample_weights(half_factor, number_of_classes):
62
+ """
63
+ Create weights matrix for transposed convolution with linear filter
64
+ initialization.
65
+ """
66
+
67
+ filter_size = get_kernel_size(half_factor)
68
+
69
+ weights = torch.zeros(
70
+ (
71
+ number_of_classes,
72
+ number_of_classes,
73
+ filter_size,
74
+ filter_size,
75
+ ),
76
+ dtype=torch.float32,
77
+ )
78
+
79
+ upsample_kernel = torch.ones((filter_size, filter_size))
80
+ for i in range(number_of_classes):
81
+ weights[i, i, :, :] = upsample_kernel
82
+
83
+ return weights
84
+
85
+
86
+ class Kernels_custom:
87
+ def __init__(self, args, MSIP_sizes=[]):
88
+ self.batch_size = args.batch_size
89
+ # create_kernels
90
+ self.kernels = {}
91
+
92
+ if MSIP_sizes != []:
93
+ self.create_kernels(MSIP_sizes)
94
+
95
+ if 8 not in MSIP_sizes:
96
+ self.create_kernels([8])
97
+
98
+ def create_kernels(self, MSIP_sizes):
99
+ # Grid Indexes for MSIP
100
+ for ksize in MSIP_sizes:
101
+ ones_kernel = ones_multiple_channels(ksize, 1)
102
+ indexes_kernel = grid_indexes(ksize)
103
+ upsample_filter_np = linear_upsample_weights(int(ksize / 2), 1)
104
+
105
+ self.ones_kernel = ones_kernel.requires_grad_(False)
106
+ self.kernels["ones_kernel_" + str(ksize)] = self.ones_kernel
107
+
108
+ self.upsample_filter_np = upsample_filter_np.requires_grad_(False)
109
+ self.kernels["upsample_filter_np_" + str(ksize)] = self.upsample_filter_np
110
+
111
+ self.indexes_kernel = indexes_kernel.requires_grad_(False)
112
+ self.kernels["indexes_kernel_" + str(ksize)] = self.indexes_kernel
113
+
114
+ def get_kernels(self, device):
115
+ kernels = {}
116
+ for k, v in self.kernels.items():
117
+ kernels[k] = v.to(device)
118
+ return kernels
imcui/third_party/dad/dad/detectors/third_party/rekd/model/load_models.py ADDED
@@ -0,0 +1,25 @@
1
+ import torch
2
+ from .REKD import REKD
3
+
4
+
5
+ def load_detector(args, device):
6
+ args.group_size, args.dim_first, args.dim_second, args.dim_third = model_parsing(
7
+ args
8
+ )
9
+ model1 = REKD(args, device)
10
+ model1.load_state_dict(torch.load(args.load_dir, weights_only=True))
11
+ model1.export()
12
+ model1.eval()
13
+ model1.to(device) ## use GPU
14
+
15
+ return model1
16
+
17
+
18
+ ## Load our model
19
+ def model_parsing(args):
20
+ group_size = args.load_dir.split("_group")[1].split("_")[0]
21
+ dim_first = args.load_dir.split("_f")[1].split("_")[0]
22
+ dim_second = args.load_dir.split("_s")[1].split("_")[0]
23
+ dim_third = args.load_dir.split("_t")[1].split(".log")[0]
24
+
25
+ return int(group_size), int(dim_first), int(dim_second), int(dim_third)
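
A hedged sketch of driving `load_detector` above: `model_parsing` recovers the group size and layer widths from the checkpoint filename, so the filename below is hypothetical and merely chosen to match the `"_group<G>_f<A>_s<B>_t<C>.log"` pattern its string splitting expects. REKD additionally requires the e2cnn package and a checkpoint on disk.

```python
# Sketch only: the install path and checkpoint filename are assumptions for illustration.
from types import SimpleNamespace
import torch
from dad.detectors.third_party.rekd.model.load_models import load_detector  # assumed path

args = SimpleNamespace(
    load_dir="trained_models/weights/rekd_group36_f2_s2_t2.log",  # hypothetical checkpoint
    factor_scaling_pyramid=1.2,
    group_size=36, dim_first=2, dim_second=2, dim_third=2,        # overwritten by model_parsing()
)
device = "cuda" if torch.cuda.is_available() else "cpu"
model = load_detector(args, device)   # loads REKD weights, exports e2cnn modules, sets eval mode
```
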