Realcat committed on
Commit 45fa4b6 · 1 Parent(s): 09c4be0

update: api

imcui/api/__init__.py CHANGED
@@ -1,47 +1,47 @@
+ import base64
+ import io
+ from typing import List
+
+ import numpy as np
+ from fastapi.exceptions import HTTPException
+ from PIL import Image
+ from pydantic import BaseModel
+
+ from ..hloc import logger
+ from .core import ImageMatchingAPI
+
+
+ class ImagesInput(BaseModel):
+     data: List[str] = []
+     max_keypoints: List[int] = []
+     timestamps: List[str] = []
+     grayscale: bool = False
+     image_hw: List[List[int]] = [[], []]
+     feature_type: int = 0
+     rotates: List[float] = []
+     scales: List[float] = []
+     reference_points: List[List[float]] = []
+     binarize: bool = False
+
+
+ def decode_base64_to_image(encoding):
+     if encoding.startswith("data:image/"):
+         encoding = encoding.split(";")[1].split(",")[1]
+     try:
+         image = Image.open(io.BytesIO(base64.b64decode(encoding)))
+         return image
+     except Exception as e:
+         logger.warning(f"API cannot decode image: {e}")
+         raise HTTPException(status_code=500, detail="Invalid encoded image") from e
+
+
+ def to_base64_nparray(encoding: str) -> np.ndarray:
+     return np.array(decode_base64_to_image(encoding)).astype("uint8")
+
+
+ __all__ = [
+     "ImageMatchingAPI",
+     "ImagesInput",
+     "decode_base64_to_image",
+     "to_base64_nparray",
+ ]
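For reference, a minimal round-trip of the helpers above (a sketch, not part of this commit; it assumes the package is importable as `imcui.api` and generates a throwaway image rather than reading one from disk):

import base64
import io

import numpy as np
from PIL import Image

from imcui.api import decode_base64_to_image, to_base64_nparray

# Build a small grayscale test image and base64-encode it as PNG,
# mirroring what a client would place in ImagesInput.data.
img = Image.fromarray(np.random.randint(0, 255, (32, 32), dtype=np.uint8))
buf = io.BytesIO()
img.save(buf, format="PNG")
b64 = base64.b64encode(buf.getvalue()).decode("utf-8")

decoded = decode_base64_to_image(b64)  # PIL.Image
array = to_base64_nparray(b64)         # np.ndarray of dtype uint8
assert array.shape == (32, 32)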
imcui/api/client.py CHANGED
@@ -1,232 +1,232 @@
1
+ import argparse
2
+ import base64
3
+ import os
4
+ import pickle
5
+ import time
6
+ from typing import Dict, List
7
+
8
+ import cv2
9
+ import numpy as np
10
+ import requests
11
+
12
+ ENDPOINT = "http://127.0.0.1:8001"
13
+ if "REMOTE_URL_RAILWAY" in os.environ:
14
+ ENDPOINT = os.environ["REMOTE_URL_RAILWAY"]
15
+
16
+ print(f"API ENDPOINT: {ENDPOINT}")
17
+
18
+ API_VERSION = f"{ENDPOINT}/version"
19
+ API_URL_MATCH = f"{ENDPOINT}/v1/match"
20
+ API_URL_EXTRACT = f"{ENDPOINT}/v1/extract"
21
+
22
+
23
+ def read_image(path: str) -> str:
24
+ """
25
+ Read an image from a file, encode it as a JPEG and then as a base64 string.
26
+
27
+ Args:
28
+ path (str): The path to the image to read.
29
+
30
+ Returns:
31
+ str: The base64 encoded image.
32
+ """
33
+ # Read the image from the file
34
+ img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
35
+
36
+ # Encode the image as a png, NO COMPRESSION!!!
37
+ retval, buffer = cv2.imencode(".png", img)
38
+
39
+ # Encode the JPEG as a base64 string
40
+ b64img = base64.b64encode(buffer).decode("utf-8")
41
+
42
+ return b64img
43
+
44
+
45
+ def do_api_requests(url=API_URL_EXTRACT, **kwargs):
46
+ """
47
+ Helper function to send an API request to the image matching service.
48
+
49
+ Args:
50
+ url (str): The URL of the API endpoint to use. Defaults to the
51
+ feature extraction endpoint.
52
+ **kwargs: Additional keyword arguments to pass to the API.
53
+
54
+ Returns:
55
+ List[Dict[str, np.ndarray]]: A list of dictionaries containing the
56
+ extracted features. The keys are "keypoints", "descriptors", and
57
+ "scores", and the values are ndarrays of shape (N, 2), (N, ?),
58
+ and (N,), respectively.
59
+ """
60
+ # Set up the request body
61
+ reqbody = {
62
+ # List of image data base64 encoded
63
+ "data": [],
64
+ # List of maximum number of keypoints to extract from each image
65
+ "max_keypoints": [100, 100],
66
+ # List of timestamps for each image (not used?)
67
+ "timestamps": ["0", "1"],
68
+ # Whether to convert the images to grayscale
69
+ "grayscale": 0,
70
+ # List of image height and width
71
+ "image_hw": [[640, 480], [320, 240]],
72
+ # Type of feature to extract
73
+ "feature_type": 0,
74
+ # List of rotation angles for each image
75
+ "rotates": [0.0, 0.0],
76
+ # List of scale factors for each image
77
+ "scales": [1.0, 1.0],
78
+ # List of reference points for each image (not used)
79
+ "reference_points": [[640, 480], [320, 240]],
80
+ # Whether to binarize the descriptors
81
+ "binarize": True,
82
+ }
83
+ # Update the request body with the additional keyword arguments
84
+ reqbody.update(kwargs)
85
+ try:
86
+ # Send the request
87
+ r = requests.post(url, json=reqbody)
88
+ if r.status_code == 200:
89
+ # Return the response
90
+ return r.json()
91
+ else:
92
+ # Print an error message if the response code is not 200
93
+ print(f"Error: Response code {r.status_code} - {r.text}")
94
+ except Exception as e:
95
+ # Print an error message if an exception occurs
96
+ print(f"An error occurred: {e}")
97
+
98
+
99
+ def send_request_match(path0: str, path1: str) -> Dict[str, np.ndarray]:
100
+ """
101
+ Send a request to the API to generate a match between two images.
102
+
103
+ Args:
104
+ path0 (str): The path to the first image.
105
+ path1 (str): The path to the second image.
106
+
107
+ Returns:
108
+ Dict[str, np.ndarray]: A dictionary containing the generated matches.
109
+ The keys are "keypoints0", "keypoints1", "matches0", and "matches1",
110
+ and the values are ndarrays of shape (N, 2), (N, 2), (N, 2), and
111
+ (N, 2), respectively.
112
+ """
113
+ files = {"image0": open(path0, "rb"), "image1": open(path1, "rb")}
114
+ try:
115
+ # TODO: replace files with post json
116
+ response = requests.post(API_URL_MATCH, files=files)
117
+ pred = {}
118
+ if response.status_code == 200:
119
+ pred = response.json()
120
+ for key in list(pred.keys()):
121
+ pred[key] = np.array(pred[key])
122
+ else:
123
+ print(f"Error: Response code {response.status_code} - {response.text}")
124
+ finally:
125
+ files["image0"].close()
126
+ files["image1"].close()
127
+ return pred
128
+
129
+
130
+ def send_request_extract(
131
+ input_images: str, viz: bool = False
132
+ ) -> List[Dict[str, np.ndarray]]:
133
+ """
134
+ Send a request to the API to extract features from an image.
135
+
136
+ Args:
137
+ input_images (str): The path to the image.
138
+
139
+ Returns:
140
+ List[Dict[str, np.ndarray]]: A list of dictionaries containing the
141
+ extracted features. The keys are "keypoints", "descriptors", and
142
+ "scores", and the values are ndarrays of shape (N, 2), (N, 128),
143
+ and (N,), respectively.
144
+ """
145
+ image_data = read_image(input_images)
146
+ inputs = {
147
+ "data": [image_data],
148
+ }
149
+ response = do_api_requests(
150
+ url=API_URL_EXTRACT,
151
+ **inputs,
152
+ )
153
+ # breakpoint()
154
+ # print("Keypoints detected: {}".format(len(response[0]["keypoints"])))
155
+
156
+ # draw matching, debug only
157
+ if viz:
158
+ from hloc.utils.viz import plot_keypoints
159
+ from ui.viz import fig2im, plot_images
160
+
161
+ kpts = np.array(response[0]["keypoints_orig"])
162
+ if "image_orig" in response[0].keys():
163
+ img_orig = np.array(["image_orig"])
164
+
165
+ output_keypoints = plot_images([img_orig], titles="titles", dpi=300)
166
+ plot_keypoints([kpts])
167
+ output_keypoints = fig2im(output_keypoints)
168
+ cv2.imwrite(
169
+ "demo_match.jpg",
170
+ output_keypoints[:, :, ::-1].copy(), # RGB -> BGR
171
+ )
172
+ return response
173
+
174
+
175
+ def get_api_version():
176
+ try:
177
+ response = requests.get(API_VERSION).json()
178
+ print("API VERSION: {}".format(response["version"]))
179
+ except Exception as e:
180
+ print(f"An error occurred: {e}")
181
+
182
+
183
+ if __name__ == "__main__":
184
+ from pathlib import Path
185
+
186
+ parser = argparse.ArgumentParser(
187
+ description="Send text to stable audio server and receive generated audio."
188
+ )
189
+ parser.add_argument(
190
+ "--image0",
191
+ required=False,
192
+ help="Path for the file's melody",
193
+ default=str(
194
+ Path(__file__).parents[1]
195
+ / "datasets/sacre_coeur/mapping_rot/02928139_3448003521_rot45.jpg"
196
+ ),
197
+ )
198
+ parser.add_argument(
199
+ "--image1",
200
+ required=False,
201
+ help="Path for the file's melody",
202
+ default=str(
203
+ Path(__file__).parents[1]
204
+ / "datasets/sacre_coeur/mapping_rot/02928139_3448003521_rot90.jpg"
205
+ ),
206
+ )
207
+ args = parser.parse_args()
208
+
209
+ # get api version
210
+ get_api_version()
211
+
212
+ # request match
213
+ # for i in range(10):
214
+ # t1 = time.time()
215
+ # preds = send_request_match(args.image0, args.image1)
216
+ # t2 = time.time()
217
+ # print(
218
+ # "Time cost1: {} seconds, matched: {}".format(
219
+ # (t2 - t1), len(preds["mmkeypoints0_orig"])
220
+ # )
221
+ # )
222
+
223
+ # request extract
224
+ for i in range(1000):
225
+ t1 = time.time()
226
+ preds = send_request_extract(args.image0)
227
+ t2 = time.time()
228
+ print(f"Time cost2: {(t2 - t1)} seconds")
229
+
230
+ # dump preds
231
+ with open("preds.pkl", "wb") as f:
232
+ pickle.dump(preds, f)
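The extraction endpoint returns plain JSON, so arrays come back as nested lists. A small client-side helper (a sketch under that assumption, not part of this commit) shows how to rebuild numpy arrays from a `send_request_extract` response; the key names follow the docstrings above:

import numpy as np

def response_to_arrays(pred: dict) -> dict:
    # Convert JSON lists (e.g. "keypoints", "keypoints_orig", "descriptors", "scores")
    # back into numpy arrays; non-list values are passed through unchanged.
    return {k: np.array(v) if isinstance(v, list) else v for k, v in pred.items()}

# preds = send_request_extract("path/to/image.jpg")
# kpts = response_to_arrays(preds[0])["keypoints"]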
imcui/api/config/api.yaml CHANGED
@@ -1,51 +1,35 @@
- # This file was generated using the `serve build` command on Ray v2.38.0.
-
- proxy_location: EveryNode
- http_options:
-   host: 0.0.0.0
-   port: 8001
-
- grpc_options:
-   port: 9000
-   grpc_servicer_functions: []
-
- logging_config:
-   encoding: TEXT
-   log_level: INFO
-   logs_dir: null
-   enable_access_log: true
-
- applications:
- - name: app1
-   route_prefix: /
-   import_path: api.server:service
-   runtime_env: {}
-   deployments:
-   - name: ImageMatchingService
-     num_replicas: 4
-     ray_actor_options:
-       num_cpus: 2.0
-       num_gpus: 1.0
-
- api:
-   feature:
-     output: feats-superpoint-n4096-rmax1600
-     model:
-       name: superpoint
-       nms_radius: 3
-       max_keypoints: 4096
-       keypoint_threshold: 0.005
-     preprocessing:
-       grayscale: True
-       force_resize: True
-       resize_max: 1600
-       width: 640
-       height: 480
-       dfactor: 8
-   matcher:
-     output: matches-NN-mutual
-     model:
-       name: nearest_neighbor
-       do_mutual_check: True
-       match_threshold: 0.2
-   dense: False

+ service:
+   num_replicas: 4
+   ray_actor_options:
+     num_cpus: 2.0
+     num_gpus: 1.0
+   host: &default_host
+     "0.0.0.0"
+   http_options:
+     host: *default_host
+     port: 8001
+   route_prefix: "/"
+   dashboard_port: 8265
+
+ api:
+   feature:
+     output: feats-superpoint-n4096-rmax1600
+     model:
+       name: superpoint
+       nms_radius: 3
+       max_keypoints: 4096
+       keypoint_threshold: 0.005
+     preprocessing:
+       grayscale: True
+       force_resize: True
+       resize_max: 1600
+       width: 640
+       height: 480
+       dfactor: 8
+   matcher:
+     output: matches-NN-mutual
+     model:
+       name: nearest_neighbor
+       do_mutual_check: True
+       match_threshold: 0.2
+   dense: False
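The new `service:` block is consumed at startup by `server.py` (shown below). A minimal sketch of that consumption, assuming `read_yaml` is a thin wrapper around `yaml.safe_load`:

from pathlib import Path

import yaml

def read_yaml(path: Path) -> dict:
    # Stand-in for imcui.hloc.utils.io.read_yaml used by server.py.
    with open(path, "r") as f:
        return yaml.safe_load(f)

config = read_yaml(Path("imcui/api/config/api.yaml"))
service_conf = config["service"]
num_replicas = service_conf.get("num_replicas", 4)
ray_actor_options = service_conf.get("ray_actor_options", {})
http_options = service_conf.get("http_options", {"host": "0.0.0.0", "port": 8001})
dashboard_port = service_conf.get("dashboard_port", 8265)
api_conf = config["api"]  # feature/matcher settings handed to ImageMatchingAPI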
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
imcui/api/core.py CHANGED
@@ -1,308 +1,308 @@
1
+ # api.py
2
+ import warnings
3
+ from pathlib import Path
4
+ from typing import Any, Dict, Optional
5
+
6
+ import cv2
7
+ import matplotlib.pyplot as plt
8
+ import numpy as np
9
+ import torch
10
+
11
+ from ..hloc import extract_features, logger, match_dense, match_features
12
+ from ..hloc.utils.viz import add_text, plot_keypoints
13
+ from ..ui.utils import filter_matches, get_feature_model, get_model
14
+ from ..ui.viz import display_matches, fig2im, plot_images
15
+
16
+ warnings.simplefilter("ignore")
17
+
18
+
19
+ class ImageMatchingAPI(torch.nn.Module):
20
+ default_conf = {
21
+ "ransac": {
22
+ "enable": True,
23
+ "estimator": "poselib",
24
+ "geometry": "homography",
25
+ "method": "RANSAC",
26
+ "reproj_threshold": 3,
27
+ "confidence": 0.9999,
28
+ "max_iter": 10000,
29
+ },
30
+ }
31
+
32
+ def __init__(
33
+ self,
34
+ conf: dict = {},
35
+ device: str = "cpu",
36
+ detect_threshold: float = 0.015,
37
+ max_keypoints: int = 1024,
38
+ match_threshold: float = 0.2,
39
+ ) -> None:
40
+ """
41
+ Initializes an instance of the ImageMatchingAPI class.
42
+
43
+ Args:
44
+ conf (dict): A dictionary containing the configuration parameters.
45
+ device (str, optional): The device to use for computation. Defaults to "cpu".
46
+ detect_threshold (float, optional): The threshold for detecting keypoints. Defaults to 0.015.
47
+ max_keypoints (int, optional): The maximum number of keypoints to extract. Defaults to 1024.
48
+ match_threshold (float, optional): The threshold for matching keypoints. Defaults to 0.2.
49
+
50
+ Returns:
51
+ None
52
+ """
53
+ super().__init__()
54
+ self.device = device
55
+ self.conf = {**self.default_conf, **conf}
56
+ self._updata_config(detect_threshold, max_keypoints, match_threshold)
57
+ self._init_models()
58
+ if device == "cuda":
59
+ memory_allocated = torch.cuda.memory_allocated(device)
60
+ memory_reserved = torch.cuda.memory_reserved(device)
61
+ logger.info(f"GPU memory allocated: {memory_allocated / 1024**2:.3f} MB")
62
+ logger.info(f"GPU memory reserved: {memory_reserved / 1024**2:.3f} MB")
63
+ self.pred = None
64
+
65
+ def parse_match_config(self, conf):
66
+ if conf["dense"]:
67
+ return {
68
+ **conf,
69
+ "matcher": match_dense.confs.get(conf["matcher"]["model"]["name"]),
70
+ "dense": True,
71
+ }
72
+ else:
73
+ return {
74
+ **conf,
75
+ "feature": extract_features.confs.get(conf["feature"]["model"]["name"]),
76
+ "matcher": match_features.confs.get(conf["matcher"]["model"]["name"]),
77
+ "dense": False,
78
+ }
79
+
80
+ def _updata_config(
81
+ self,
82
+ detect_threshold: float = 0.015,
83
+ max_keypoints: int = 1024,
84
+ match_threshold: float = 0.2,
85
+ ):
86
+ self.dense = self.conf["dense"]
87
+ if self.conf["dense"]:
88
+ try:
89
+ self.conf["matcher"]["model"]["match_threshold"] = match_threshold
90
+ except TypeError as e:
91
+ logger.error(e)
92
+ else:
93
+ self.conf["feature"]["model"]["max_keypoints"] = max_keypoints
94
+ self.conf["feature"]["model"]["keypoint_threshold"] = detect_threshold
95
+ self.extract_conf = self.conf["feature"]
96
+
97
+ self.match_conf = self.conf["matcher"]
98
+
99
+ def _init_models(self):
100
+ # initialize matcher
101
+ self.matcher = get_model(self.match_conf)
102
+ # initialize extractor
103
+ if self.dense:
104
+ self.extractor = None
105
+ else:
106
+ self.extractor = get_feature_model(self.conf["feature"])
107
+
108
+ def _forward(self, img0, img1):
109
+ if self.dense:
110
+ pred = match_dense.match_images(
111
+ self.matcher,
112
+ img0,
113
+ img1,
114
+ self.match_conf["preprocessing"],
115
+ device=self.device,
116
+ )
117
+ last_fixed = "{}".format( # noqa: F841
118
+ self.match_conf["model"]["name"]
119
+ )
120
+ else:
121
+ pred0 = extract_features.extract(
122
+ self.extractor, img0, self.extract_conf["preprocessing"]
123
+ )
124
+ pred1 = extract_features.extract(
125
+ self.extractor, img1, self.extract_conf["preprocessing"]
126
+ )
127
+ pred = match_features.match_images(self.matcher, pred0, pred1)
128
+ return pred
129
+
130
+ def _convert_pred(self, pred):
131
+ ret = {
132
+ k: v.cpu().detach()[0].numpy() if isinstance(v, torch.Tensor) else v
133
+ for k, v in pred.items()
134
+ }
135
+ ret = {
136
+ k: v[0].cpu().detach().numpy() if isinstance(v, list) else v
137
+ for k, v in ret.items()
138
+ }
139
+ return ret
140
+
141
+ @torch.inference_mode()
142
+ def extract(self, img0: np.ndarray, **kwargs) -> Dict[str, np.ndarray]:
143
+ """Extract features from a single image.
144
+
145
+ Args:
146
+ img0 (np.ndarray): image
147
+
148
+ Returns:
149
+ Dict[str, np.ndarray]: feature dict
150
+ """
151
+
152
+ # setting prams
153
+ self.extractor.conf["max_keypoints"] = kwargs.get("max_keypoints", 512)
154
+ self.extractor.conf["keypoint_threshold"] = kwargs.get(
155
+ "keypoint_threshold", 0.0
156
+ )
157
+
158
+ pred = extract_features.extract(
159
+ self.extractor, img0, self.extract_conf["preprocessing"]
160
+ )
161
+ pred = self._convert_pred(pred)
162
+ # back to origin scale
163
+ s0 = pred["original_size"] / pred["size"]
164
+ pred["keypoints_orig"] = (
165
+ match_features.scale_keypoints(pred["keypoints"] + 0.5, s0) - 0.5
166
+ )
167
+ # TODO: rotate back
168
+ binarize = kwargs.get("binarize", False)
169
+ if binarize:
170
+ assert "descriptors" in pred
171
+ pred["descriptors"] = (pred["descriptors"] > 0).astype(np.uint8)
172
+ pred["descriptors"] = pred["descriptors"].T # N x DIM
173
+ return pred
174
+
175
+ @torch.inference_mode()
176
+ def forward(
177
+ self,
178
+ img0: np.ndarray,
179
+ img1: np.ndarray,
180
+ ) -> Dict[str, np.ndarray]:
181
+ """
182
+ Forward pass of the image matching API.
183
+
184
+ Args:
185
+ img0: A 3D NumPy array of shape (H, W, C) representing the first image.
186
+ Values are in the range [0, 1] and are in RGB mode.
187
+ img1: A 3D NumPy array of shape (H, W, C) representing the second image.
188
+ Values are in the range [0, 1] and are in RGB mode.
189
+
190
+ Returns:
191
+ A dictionary containing the following keys:
192
+ - image0_orig: The original image 0.
193
+ - image1_orig: The original image 1.
194
+ - keypoints0_orig: The keypoints detected in image 0.
195
+ - keypoints1_orig: The keypoints detected in image 1.
196
+ - mkeypoints0_orig: The raw matches between image 0 and image 1.
197
+ - mkeypoints1_orig: The raw matches between image 1 and image 0.
198
+ - mmkeypoints0_orig: The RANSAC inliers in image 0.
199
+ - mmkeypoints1_orig: The RANSAC inliers in image 1.
200
+ - mconf: The confidence scores for the raw matches.
201
+ - mmconf: The confidence scores for the RANSAC inliers.
202
+ """
203
+ # Take as input a pair of images (not a batch)
204
+ assert isinstance(img0, np.ndarray)
205
+ assert isinstance(img1, np.ndarray)
206
+ self.pred = self._forward(img0, img1)
207
+ if self.conf["ransac"]["enable"]:
208
+ self.pred = self._geometry_check(self.pred)
209
+ return self.pred
210
+
211
+ def _geometry_check(
212
+ self,
213
+ pred: Dict[str, Any],
214
+ ) -> Dict[str, Any]:
215
+ """
216
+ Filter matches using RANSAC. If keypoints are available, filter by keypoints.
217
+ If lines are available, filter by lines. If both keypoints and lines are
218
+ available, filter by keypoints.
219
+
220
+ Args:
221
+ pred (Dict[str, Any]): dict of matches, including original keypoints.
222
+ See :func:`filter_matches` for the expected keys.
223
+
224
+ Returns:
225
+ Dict[str, Any]: filtered matches
226
+ """
227
+ pred = filter_matches(
228
+ pred,
229
+ ransac_method=self.conf["ransac"]["method"],
230
+ ransac_reproj_threshold=self.conf["ransac"]["reproj_threshold"],
231
+ ransac_confidence=self.conf["ransac"]["confidence"],
232
+ ransac_max_iter=self.conf["ransac"]["max_iter"],
233
+ )
234
+ return pred
235
+
236
+ def visualize(
237
+ self,
238
+ log_path: Optional[Path] = None,
239
+ ) -> None:
240
+ """
241
+ Visualize the matches.
242
+
243
+ Args:
244
+ log_path (Path, optional): The directory to save the images. Defaults to None.
245
+
246
+ Returns:
247
+ None
248
+ """
249
+ if self.conf["dense"]:
250
+ postfix = str(self.conf["matcher"]["model"]["name"])
251
+ else:
252
+ postfix = "{}_{}".format(
253
+ str(self.conf["feature"]["model"]["name"]),
254
+ str(self.conf["matcher"]["model"]["name"]),
255
+ )
256
+ titles = [
257
+ "Image 0 - Keypoints",
258
+ "Image 1 - Keypoints",
259
+ ]
260
+ pred: Dict[str, Any] = self.pred
261
+ image0: np.ndarray = pred["image0_orig"]
262
+ image1: np.ndarray = pred["image1_orig"]
263
+ output_keypoints: np.ndarray = plot_images(
264
+ [image0, image1], titles=titles, dpi=300
265
+ )
266
+ if "keypoints0_orig" in pred.keys() and "keypoints1_orig" in pred.keys():
267
+ plot_keypoints([pred["keypoints0_orig"], pred["keypoints1_orig"]])
268
+ text: str = (
269
+ f"# keypoints0: {len(pred['keypoints0_orig'])} \n"
270
+ + f"# keypoints1: {len(pred['keypoints1_orig'])}"
271
+ )
272
+ add_text(0, text, fs=15)
273
+ output_keypoints = fig2im(output_keypoints)
274
+ # plot images with raw matches
275
+ titles = [
276
+ "Image 0 - Raw matched keypoints",
277
+ "Image 1 - Raw matched keypoints",
278
+ ]
279
+ output_matches_raw, num_matches_raw = display_matches(
280
+ pred, titles=titles, tag="KPTS_RAW"
281
+ )
282
+ # plot images with ransac matches
283
+ titles = [
284
+ "Image 0 - Ransac matched keypoints",
285
+ "Image 1 - Ransac matched keypoints",
286
+ ]
287
+ output_matches_ransac, num_matches_ransac = display_matches(
288
+ pred, titles=titles, tag="KPTS_RANSAC"
289
+ )
290
+ if log_path is not None:
291
+ img_keypoints_path: Path = log_path / f"img_keypoints_{postfix}.png"
292
+ img_matches_raw_path: Path = log_path / f"img_matches_raw_{postfix}.png"
293
+ img_matches_ransac_path: Path = (
294
+ log_path / f"img_matches_ransac_{postfix}.png"
295
+ )
296
+ cv2.imwrite(
297
+ str(img_keypoints_path),
298
+ output_keypoints[:, :, ::-1].copy(), # RGB -> BGR
299
+ )
300
+ cv2.imwrite(
301
+ str(img_matches_raw_path),
302
+ output_matches_raw[:, :, ::-1].copy(), # RGB -> BGR
303
+ )
304
+ cv2.imwrite(
305
+ str(img_matches_ransac_path),
306
+ output_matches_ransac[:, :, ::-1].copy(), # RGB -> BGR
307
+ )
308
+ plt.close("all")
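A usage sketch of `ImageMatchingAPI` with a sparse extractor plus matcher (the `conf` layout mirrors the `api:` section of `config/api.yaml`; the image paths are placeholders, and images are passed as RGB uint8 arrays, the same form the server feeds in):

import numpy as np
from PIL import Image

from imcui.api import ImageMatchingAPI

conf = {
    "dense": False,
    "feature": {
        "output": "feats-superpoint-n4096-rmax1600",
        "model": {
            "name": "superpoint",
            "nms_radius": 3,
            "max_keypoints": 4096,
            "keypoint_threshold": 0.005,
        },
        "preprocessing": {
            "grayscale": True,
            "force_resize": True,
            "resize_max": 1600,
            "width": 640,
            "height": 480,
            "dfactor": 8,
        },
    },
    "matcher": {
        "output": "matches-NN-mutual",
        "model": {
            "name": "nearest_neighbor",
            "do_mutual_check": True,
            "match_threshold": 0.2,
        },
    },
}

api = ImageMatchingAPI(conf=conf, device="cpu")
img0 = np.array(Image.open("image0.jpg"))  # placeholder path
img1 = np.array(Image.open("image1.jpg"))  # placeholder path
pred = api(img0, img1)
print(len(pred["mmkeypoints0_orig"]), "RANSAC inliers")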
imcui/api/server.py CHANGED
@@ -1,170 +1,186 @@
- # server.py
- import warnings
- from pathlib import Path
- from typing import Union
-
- import numpy as np
- import ray
- import torch
- import yaml
- from fastapi import FastAPI, File, UploadFile
- from fastapi.responses import JSONResponse
- from PIL import Image
- from ray import serve
-
- from . import ImagesInput, to_base64_nparray
- from .core import ImageMatchingAPI
- from ..hloc import DEVICE
- from ..ui import get_version
-
- warnings.simplefilter("ignore")
- app = FastAPI()
- if ray.is_initialized():
-     ray.shutdown()
- ray.init(
-     dashboard_port=8265,
-     ignore_reinit_error=True,
- )
- serve.start(
-     http_options={"host": "0.0.0.0", "port": 8001},
- )
-
- num_gpus = 1 if torch.cuda.is_available() else 0
-
-
- @serve.deployment(
-     num_replicas=4, ray_actor_options={"num_cpus": 2, "num_gpus": num_gpus}
- )
- @serve.ingress(app)
- class ImageMatchingService:
-     def __init__(self, conf: dict, device: str):
-         self.conf = conf
-         self.api = ImageMatchingAPI(conf=conf, device=device)
-
-     @app.get("/")
-     def root(self):
-         return "Hello, world!"
-
-     @app.get("/version")
-     async def version(self):
-         return {"version": get_version()}
-
-     @app.post("/v1/match")
-     async def match(
-         self, image0: UploadFile = File(...), image1: UploadFile = File(...)
-     ):
-         """
-         Handle the image matching request and return the processed result.
-
-         Args:
-             image0 (UploadFile): The first image file for matching.
-             image1 (UploadFile): The second image file for matching.
-
-         Returns:
-             JSONResponse: A JSON response containing the filtered match results
-             or an error message in case of failure.
-         """
-         try:
-             # Load the images from the uploaded files
-             image0_array = self.load_image(image0)
-             image1_array = self.load_image(image1)
-
-             # Perform image matching using the API
-             output = self.api(image0_array, image1_array)
-
-             # Keys to skip in the output
-             skip_keys = ["image0_orig", "image1_orig"]
-
-             # Postprocess the output to filter unwanted data
-             pred = self.postprocess(output, skip_keys)
-
-             # Return the filtered prediction as a JSON response
-             return JSONResponse(content=pred)
-         except Exception as e:
-             # Return an error message with status code 500 in case of exception
-             return JSONResponse(content={"error": str(e)}, status_code=500)
-
-     @app.post("/v1/extract")
-     async def extract(self, input_info: ImagesInput):
-         """
-         Extract keypoints and descriptors from images.
-
-         Args:
-             input_info: An object containing the image data and options.
-
-         Returns:
-             A list of dictionaries containing the keypoints and descriptors.
-         """
-         try:
-             preds = []
-             for i, input_image in enumerate(input_info.data):
-                 # Load the image from the input data
-                 image_array = to_base64_nparray(input_image)
-                 # Extract keypoints and descriptors
-                 output = self.api.extract(
-                     image_array,
-                     max_keypoints=input_info.max_keypoints[i],
-                     binarize=input_info.binarize,
-                 )
-                 # Do not return the original image and image_orig
-                 # skip_keys = ["image", "image_orig"]
-                 skip_keys = []
-
-                 # Postprocess the output
-                 pred = self.postprocess(output, skip_keys)
-                 preds.append(pred)
-             # Return the list of extracted features
-             return JSONResponse(content=preds)
-         except Exception as e:
-             # Return an error message if an exception occurs
-             return JSONResponse(content={"error": str(e)}, status_code=500)
-
-     def load_image(self, file_path: Union[str, UploadFile]) -> np.ndarray:
-         """
-         Reads an image from a file path or an UploadFile object.
-
-         Args:
-             file_path: A file path or an UploadFile object.
-
-         Returns:
-             A numpy array representing the image.
-         """
-         if isinstance(file_path, str):
-             file_path = Path(file_path).resolve(strict=False)
-         else:
-             file_path = file_path.file
-         with Image.open(file_path) as img:
-             image_array = np.array(img)
-         return image_array
-
-     def postprocess(self, output: dict, skip_keys: list, binarize: bool = True) -> dict:
-         pred = {}
-         for key, value in output.items():
-             if key in skip_keys:
-                 continue
-             if isinstance(value, np.ndarray):
-                 pred[key] = value.tolist()
-         return pred
-
-     def run(self, host: str = "0.0.0.0", port: int = 8001):
-         import uvicorn
-
-         uvicorn.run(app, host=host, port=port)
-
-
- def read_config(config_path: Path) -> dict:
-     with open(config_path, "r") as f:
-         conf = yaml.safe_load(f)
-     return conf
-
-
- # api server
- conf = read_config(Path(__file__).parent / "config/api.yaml")
- service = ImageMatchingService.bind(conf=conf["api"], device=DEVICE)
- handle = serve.run(service, route_prefix="/")
-
- # serve run api.server_ray:service
-
- # build to generate config file
- # serve build api.server_ray:service -o api/config/ray.yaml
- # serve run api/config/ray.yaml

+ # server.py
+ import warnings
+ from pathlib import Path
+ from typing import Union
+
+ import numpy as np
+ import ray
+ import torch
+ from fastapi import FastAPI, File, UploadFile
+ from fastapi.responses import JSONResponse
+ from PIL import Image
+ from ray import serve
+ import argparse
+
+ from . import ImagesInput, to_base64_nparray
+ from .core import ImageMatchingAPI
+ from ..hloc import DEVICE
+ from ..hloc.utils.io import read_yaml
+ from ..ui import get_version
+
+ warnings.simplefilter("ignore")
+ app = FastAPI()
+ if ray.is_initialized():
+     ray.shutdown()
+
+
+ # read some configs
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+     "--config",
+     type=Path,
+     required=False,
+     default=Path(__file__).parent / "config/api.yaml",
+ )
+ args = parser.parse_args()
+ config_path = args.config
+ config = read_yaml(config_path)
+ num_gpus = 1 if torch.cuda.is_available() else 0
+ ray_actor_options = config["service"].get("ray_actor_options", {})
+ ray_actor_options.update({"num_gpus": num_gpus})
+ dashboard_port = config["service"].get("dashboard_port", 8265)
+ http_options = config["service"].get(
+     "http_options",
+     {
+         "host": "0.0.0.0",
+         "port": 8001,
+     },
+ )
+ num_replicas = config["service"].get("num_replicas", 4)
+ ray.init(
+     dashboard_port=dashboard_port,
+     ignore_reinit_error=True,
+ )
+ serve.start(http_options=http_options)
+
+
+ @serve.deployment(
+     num_replicas=num_replicas,
+     ray_actor_options=ray_actor_options,
+ )
+ @serve.ingress(app)
+ class ImageMatchingService:
+     def __init__(self, conf: dict, device: str, **kwargs):
+         self.conf = conf
+         self.api = ImageMatchingAPI(conf=conf, device=device)
+
+     @app.get("/")
+     def root(self):
+         return "Hello, world!"
+
+     @app.get("/version")
+     async def version(self):
+         return {"version": get_version()}
+
+     @app.post("/v1/match")
+     async def match(
+         self, image0: UploadFile = File(...), image1: UploadFile = File(...)
+     ):
+         """
+         Handle the image matching request and return the processed result.
+
+         Args:
+             image0 (UploadFile): The first image file for matching.
+             image1 (UploadFile): The second image file for matching.
+
+         Returns:
+             JSONResponse: A JSON response containing the filtered match results
+             or an error message in case of failure.
+         """
+         try:
+             # Load the images from the uploaded files
+             image0_array = self.load_image(image0)
+             image1_array = self.load_image(image1)
+
+             # Perform image matching using the API
+             output = self.api(image0_array, image1_array)
+
+             # Keys to skip in the output
+             skip_keys = ["image0_orig", "image1_orig"]
+
+             # Postprocess the output to filter unwanted data
+             pred = self.postprocess(output, skip_keys)
+
+             # Return the filtered prediction as a JSON response
+             return JSONResponse(content=pred)
+         except Exception as e:
+             # Return an error message with status code 500 in case of exception
+             return JSONResponse(content={"error": str(e)}, status_code=500)
+
+     @app.post("/v1/extract")
+     async def extract(self, input_info: ImagesInput):
+         """
+         Extract keypoints and descriptors from images.
+
+         Args:
+             input_info: An object containing the image data and options.
+
+         Returns:
+             A list of dictionaries containing the keypoints and descriptors.
+         """
+         try:
+             preds = []
+             for i, input_image in enumerate(input_info.data):
+                 # Load the image from the input data
+                 image_array = to_base64_nparray(input_image)
+                 # Extract keypoints and descriptors
+                 output = self.api.extract(
+                     image_array,
+                     max_keypoints=input_info.max_keypoints[i],
+                     binarize=input_info.binarize,
+                 )
+                 # Do not return the original image and image_orig
+                 # skip_keys = ["image", "image_orig"]
+                 skip_keys = []
+
+                 # Postprocess the output
+                 pred = self.postprocess(output, skip_keys)
+                 preds.append(pred)
+             # Return the list of extracted features
+             return JSONResponse(content=preds)
+         except Exception as e:
+             # Return an error message if an exception occurs
+             return JSONResponse(content={"error": str(e)}, status_code=500)
+
+     def load_image(self, file_path: Union[str, UploadFile]) -> np.ndarray:
+         """
+         Reads an image from a file path or an UploadFile object.
+
+         Args:
+             file_path: A file path or an UploadFile object.
+
+         Returns:
+             A numpy array representing the image.
+         """
+         if isinstance(file_path, str):
+             file_path = Path(file_path).resolve(strict=False)
+         else:
+             file_path = file_path.file
+         with Image.open(file_path) as img:
+             image_array = np.array(img)
+         return image_array
+
+     def postprocess(self, output: dict, skip_keys: list, **kwargs) -> dict:
+         pred = {}
+         for key, value in output.items():
+             if key in skip_keys:
+                 continue
+             if isinstance(value, np.ndarray):
+                 pred[key] = value.tolist()
+         return pred
+
+     def run(self, host: str = "0.0.0.0", port: int = 8001):
+         import uvicorn
+
+         uvicorn.run(app, host=host, port=port)
+
+
+ if __name__ == "__main__":
+     # api server
+     service = ImageMatchingService.bind(conf=config["api"], device=DEVICE)
+     handle = serve.run(service, route_prefix="/", blocking=False)
+
+     # serve run api.server_ray:service
+     # build to generate config file
+     # serve build api.server_ray:service -o api/config/ray.yaml
+     # serve run api/config/ray.yaml
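A quick smoke test against a running instance (a sketch; it assumes the server above is up on its default port 8001 and that `image.jpg` is a local placeholder image):

import base64

import requests

ENDPOINT = "http://127.0.0.1:8001"

print(requests.get(f"{ENDPOINT}/version").json())

with open("image.jpg", "rb") as f:  # placeholder image path
    b64 = base64.b64encode(f.read()).decode("utf-8")

payload = {
    "data": [b64],
    "max_keypoints": [512],
    "binarize": False,
}
r = requests.post(f"{ENDPOINT}/v1/extract", json=payload)
r.raise_for_status()
print("keypoints:", len(r.json()[0]["keypoints"]))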
imcui/api/test/build_and_run.sh CHANGED
@@ -1,16 +1,16 @@
+ # g++ main.cpp -I/usr/include/opencv4 -lcurl -ljsoncpp -lb64 -lopencv_core -lopencv_imgcodecs -o main
+ # sudo apt-get update
+ # sudo apt-get install libboost-all-dev -y
+ # sudo apt-get install libcurl4-openssl-dev libjsoncpp-dev libb64-dev libopencv-dev -y
+
+ cd build
+ cmake ..
+ make -j12
+
+ echo " ======== RUN DEMO ========"
+
+ ./client
+
+ echo " ======== END DEMO ========"
+
+ cd ..
imcui/api/test/client.cpp CHANGED
@@ -1,81 +1,81 @@
1
+ #include <curl/curl.h>
2
+ #include <opencv2/opencv.hpp>
3
+ #include "helper.h"
4
+
5
+ int main() {
6
+ std::string img_path =
7
+ "../../../datasets/sacre_coeur/mapping_rot/02928139_3448003521_rot45.jpg";
8
+ cv::Mat original_img = cv::imread(img_path, cv::IMREAD_GRAYSCALE);
9
+
10
+ if (original_img.empty()) {
11
+ throw std::runtime_error("Failed to decode image");
12
+ }
13
+
14
+ // Convert the image to Base64
15
+ std::string base64_img = image_to_base64(original_img);
16
+
17
+ // Convert the Base64 back to an image
18
+ cv::Mat decoded_img = base64_to_image(base64_img);
19
+ cv::imwrite("decoded_image.jpg", decoded_img);
20
+ cv::imwrite("original_img.jpg", original_img);
21
+
22
+ // The images should be identical
23
+ if (cv::countNonZero(original_img != decoded_img) != 0) {
24
+ std::cerr << "The images are not identical" << std::endl;
25
+ return -1;
26
+ } else {
27
+ std::cout << "The images are identical!" << std::endl;
28
+ }
29
+
30
+ // construct params
31
+ APIParams params{.data = {base64_img},
32
+ .max_keypoints = {100, 100},
33
+ .timestamps = {"0", "1"},
34
+ .grayscale = {0},
35
+ .image_hw = {{480, 640}, {240, 320}},
36
+ .feature_type = 0,
37
+ .rotates = {0.0f, 0.0f},
38
+ .scales = {1.0f, 1.0f},
39
+ .reference_points = {{1.23e+2f, 1.2e+1f},
40
+ {5.0e-1f, 3.0e-1f},
41
+ {2.3e+2f, 2.2e+1f},
42
+ {6.0e-1f, 4.0e-1f}},
43
+ .binarize = {1}};
44
+
45
+ KeyPointResults kpts_results;
46
+
47
+ // Convert the parameters to JSON
48
+ Json::Value jsonData = paramsToJson(params);
49
+ std::string url = "http://127.0.0.1:8001/v1/extract";
50
+ Json::StreamWriterBuilder writer;
51
+ std::string output = Json::writeString(writer, jsonData);
52
+
53
+ CURL* curl;
54
+ CURLcode res;
55
+ std::string readBuffer;
56
+
57
+ curl_global_init(CURL_GLOBAL_DEFAULT);
58
+ curl = curl_easy_init();
59
+ if (curl) {
60
+ struct curl_slist* hs = NULL;
61
+ hs = curl_slist_append(hs, "Content-Type: application/json");
62
+ curl_easy_setopt(curl, CURLOPT_HTTPHEADER, hs);
63
+ curl_easy_setopt(curl, CURLOPT_URL, url.c_str());
64
+ curl_easy_setopt(curl, CURLOPT_POSTFIELDS, output.c_str());
65
+ curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, WriteCallback);
66
+ curl_easy_setopt(curl, CURLOPT_WRITEDATA, &readBuffer);
67
+ res = curl_easy_perform(curl);
68
+
69
+ if (res != CURLE_OK)
70
+ fprintf(
71
+ stderr, "curl_easy_perform() failed: %s\n", curl_easy_strerror(res));
72
+ else {
73
+ // std::cout << "Response from server: " << readBuffer << std::endl;
74
+ kpts_results = decode_response(readBuffer);
75
+ }
76
+ curl_easy_cleanup(curl);
77
+ }
78
+ curl_global_cleanup();
79
+
80
+ return 0;
81
+ }
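For comparison, the JSON body assembled by the C++ test client above corresponds to the following Python request (a sketch; the base64 image string is a placeholder):

import requests

params = {
    "data": ["<base64-encoded image>"],  # placeholder
    "max_keypoints": [100, 100],
    "timestamps": ["0", "1"],
    "grayscale": 0,
    "image_hw": [[480, 640], [240, 320]],
    "feature_type": 0,
    "rotates": [0.0, 0.0],
    "scales": [1.0, 1.0],
    "reference_points": [[123.0, 12.0], [0.5, 0.3], [230.0, 22.0], [0.6, 0.4]],
    "binarize": 1,
}
response = requests.post("http://127.0.0.1:8001/v1/extract", json=params)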
imcui/api/test/helper.h CHANGED
@@ -1,405 +1,405 @@
1
-
2
- #include <b64/encode.h>
3
- #include <fstream>
4
- #include <jsoncpp/json/json.h>
5
- #include <opencv2/opencv.hpp>
6
- #include <sstream>
7
- #include <vector>
8
-
9
- // base64 to image
10
- #include <boost/archive/iterators/base64_from_binary.hpp>
11
- #include <boost/archive/iterators/binary_from_base64.hpp>
12
- #include <boost/archive/iterators/transform_width.hpp>
13
-
14
- /// Parameters used in the API
15
- struct APIParams {
16
- /// A list of images, base64 encoded
17
- std::vector<std::string> data;
18
-
19
- /// The maximum number of keypoints to detect for each image
20
- std::vector<int> max_keypoints;
21
-
22
- /// The timestamps of the images
23
- std::vector<std::string> timestamps;
24
-
25
- /// Whether to convert the images to grayscale
26
- bool grayscale;
27
-
28
- /// The height and width of each image
29
- std::vector<std::vector<int>> image_hw;
30
-
31
- /// The type of feature detector to use
32
- int feature_type;
33
-
34
- /// The rotations of the images
35
- std::vector<double> rotates;
36
-
37
- /// The scales of the images
38
- std::vector<double> scales;
39
-
40
- /// The reference points of the images
41
- std::vector<std::vector<float>> reference_points;
42
-
43
- /// Whether to binarize the descriptors
44
- bool binarize;
45
- };
46
-
47
- /**
48
- * @brief Contains the results of a keypoint detector.
49
- *
50
- * @details Stores the keypoints and descriptors for each image.
51
- */
52
- class KeyPointResults {
53
- public:
54
- KeyPointResults() {
55
- }
56
-
57
- /**
58
- * @brief Constructor.
59
- *
60
- * @param kp The keypoints for each image.
61
- */
62
- KeyPointResults(const std::vector<std::vector<cv::KeyPoint>>& kp,
63
- const std::vector<cv::Mat>& desc)
64
- : keypoints(kp), descriptors(desc) {
65
- }
66
-
67
- /**
68
- * @brief Append keypoints to the result.
69
- *
70
- * @param kpts The keypoints to append.
71
- */
72
- inline void append_keypoints(std::vector<cv::KeyPoint>& kpts) {
73
- keypoints.emplace_back(kpts);
74
- }
75
-
76
- /**
77
- * @brief Append descriptors to the result.
78
- *
79
- * @param desc The descriptors to append.
80
- */
81
- inline void append_descriptors(cv::Mat& desc) {
82
- descriptors.emplace_back(desc);
83
- }
84
-
85
- /**
86
- * @brief Get the keypoints.
87
- *
88
- * @return The keypoints.
89
- */
90
- inline std::vector<std::vector<cv::KeyPoint>> get_keypoints() {
91
- return keypoints;
92
- }
93
-
94
- /**
95
- * @brief Get the descriptors.
96
- *
97
- * @return The descriptors.
98
- */
99
- inline std::vector<cv::Mat> get_descriptors() {
100
- return descriptors;
101
- }
102
-
103
- private:
104
- std::vector<std::vector<cv::KeyPoint>> keypoints;
105
- std::vector<cv::Mat> descriptors;
106
- std::vector<std::vector<float>> scores;
107
- };
108
-
109
- /**
110
- * @brief Decodes a base64 encoded string.
111
- *
112
- * @param base64 The base64 encoded string to decode.
113
- * @return The decoded string.
114
- */
115
- std::string base64_decode(const std::string& base64) {
116
- using namespace boost::archive::iterators;
117
- using It = transform_width<binary_from_base64<std::string::const_iterator>, 8, 6>;
118
-
119
- // Find the position of the last non-whitespace character
120
- auto end = base64.find_last_not_of(" \t\n\r");
121
- if (end != std::string::npos) {
122
- // Move one past the last non-whitespace character
123
- end += 1;
124
- }
125
-
126
- // Decode the base64 string and return the result
127
- return std::string(It(base64.begin()), It(base64.begin() + end));
128
- }
129
-
130
- /**
131
- * @brief Decodes a base64 string into an OpenCV image
132
- *
133
- * @param base64 The base64 encoded string
134
- * @return The decoded OpenCV image
135
- */
136
- cv::Mat base64_to_image(const std::string& base64) {
137
- // Decode the base64 string
138
- std::string decodedStr = base64_decode(base64);
139
-
140
- // Decode the image
141
- std::vector<uchar> data(decodedStr.begin(), decodedStr.end());
142
- cv::Mat img = cv::imdecode(data, cv::IMREAD_GRAYSCALE);
143
-
144
- // Check for errors
145
- if (img.empty()) {
146
- throw std::runtime_error("Failed to decode image");
147
- }
148
-
149
- return img;
150
- }
151
-
152
- /**
153
- * @brief Encodes an OpenCV image into a base64 string
154
- *
155
- * This function takes an OpenCV image and encodes it into a base64 string.
156
- * The image is first encoded as a PNG image, and then the resulting
157
- * bytes are encoded as a base64 string.
158
- *
159
- * @param img The OpenCV image
160
- * @return The base64 encoded string
161
- *
162
- * @throws std::runtime_error if the image is empty or encoding fails
163
- */
164
- std::string image_to_base64(cv::Mat& img) {
165
- if (img.empty()) {
166
- throw std::runtime_error("Failed to read image");
167
- }
168
-
169
- // Encode the image as a PNG
170
- std::vector<uchar> buf;
171
- if (!cv::imencode(".png", img, buf)) {
172
- throw std::runtime_error("Failed to encode image");
173
- }
174
-
175
- // Encode the bytes as a base64 string
176
- using namespace boost::archive::iterators;
177
- using It =
178
- base64_from_binary<transform_width<std::vector<uchar>::const_iterator, 6, 8>>;
179
- std::string base64(It(buf.begin()), It(buf.end()));
180
-
181
- // Pad the string with '=' characters to a multiple of 4 bytes
182
- base64.append((3 - buf.size() % 3) % 3, '=');
183
-
184
- return base64;
185
- }
186
-
187
- /**
188
- * @brief Callback function for libcurl to write data to a string
189
- *
190
- * This function is used as a callback for libcurl to write data to a string.
191
- * It takes the contents, size, and nmemb as parameters, and writes the data to
192
- * the string.
193
- *
194
- * @param contents The data to write
195
- * @param size The size of the data
196
- * @param nmemb The number of members in the data
197
- * @param s The string to write the data to
198
- * @return The number of bytes written
199
- */
200
- size_t WriteCallback(void* contents, size_t size, size_t nmemb, std::string* s) {
201
- size_t newLength = size * nmemb;
202
- try {
203
- // Resize the string to fit the new data
204
- s->resize(s->size() + newLength);
205
- } catch (std::bad_alloc& e) {
206
- // If there's an error allocating memory, return 0
207
- return 0;
208
- }
209
-
210
- // Copy the data to the string
211
- std::copy(static_cast<const char*>(contents),
212
- static_cast<const char*>(contents) + newLength,
213
- s->begin() + s->size() - newLength);
214
- return newLength;
215
- }
216
-
217
- // Helper functions
218
-
219
- /**
220
- * @brief Helper function to convert a type to a Json::Value
221
- *
222
- * This function takes a value of type T and converts it to a Json::Value.
223
- * It is used to simplify the process of converting a type to a Json::Value.
224
- *
225
- * @param val The value to convert
226
- * @return The converted Json::Value
227
- */
228
- template <typename T> Json::Value toJson(const T& val) {
229
- return Json::Value(val);
230
- }
231
-
232
- /**
233
- * @brief Converts a vector to a Json::Value
234
- *
235
- * This function takes a vector of type T and converts it to a Json::Value.
236
- * Each element in the vector is appended to the Json::Value array.
237
- *
238
- * @param vec The vector to convert to Json::Value
239
- * @return The Json::Value representing the vector
240
- */
241
- template <typename T> Json::Value vectorToJson(const std::vector<T>& vec) {
242
- Json::Value json(Json::arrayValue);
243
- for (const auto& item : vec) {
244
- json.append(item);
245
- }
246
- return json;
247
- }
248
-
249
- /**
250
- * @brief Converts a nested vector to a Json::Value
251
- *
252
- * This function takes a nested vector of type T and converts it to a
253
- * Json::Value. Each sub-vector is converted to a Json::Value array and appended
254
- * to the main Json::Value array.
255
- *
256
- * @param vec The nested vector to convert to Json::Value
257
- * @return The Json::Value representing the nested vector
258
- */
259
- template <typename T>
260
- Json::Value nestedVectorToJson(const std::vector<std::vector<T>>& vec) {
261
- Json::Value json(Json::arrayValue);
262
- for (const auto& subVec : vec) {
263
- json.append(vectorToJson(subVec));
264
- }
265
- return json;
266
- }
267
-
268
- /**
269
- * @brief Converts the APIParams struct to a Json::Value
270
- *
271
- * This function takes an APIParams struct and converts it to a Json::Value.
272
- * The Json::Value is a JSON object with the following fields:
273
- * - data: a JSON array of base64 encoded images
274
- * - max_keypoints: a JSON array of integers, max number of keypoints for each
275
- * image
276
- * - timestamps: a JSON array of timestamps, one for each image
277
- * - grayscale: a JSON boolean, whether to convert images to grayscale
278
- * - image_hw: a nested JSON array, each sub-array contains the height and width
279
- * of an image
280
- * - feature_type: a JSON integer, the type of feature detector to use
281
- * - rotates: a JSON array of doubles, the rotation of each image
282
- * - scales: a JSON array of doubles, the scale of each image
283
- * - reference_points: a nested JSON array, each sub-array contains the
284
- * reference points of an image
285
- * - binarize: a JSON boolean, whether to binarize the descriptors
286
- *
287
- * @param params The APIParams struct to convert
288
- * @return The Json::Value representing the APIParams struct
289
- */
290
- Json::Value paramsToJson(const APIParams& params) {
291
- Json::Value json;
292
- json["data"] = vectorToJson(params.data);
293
- json["max_keypoints"] = vectorToJson(params.max_keypoints);
294
- json["timestamps"] = vectorToJson(params.timestamps);
295
- json["grayscale"] = toJson(params.grayscale);
296
- json["image_hw"] = nestedVectorToJson(params.image_hw);
297
- json["feature_type"] = toJson(params.feature_type);
298
- json["rotates"] = vectorToJson(params.rotates);
299
- json["scales"] = vectorToJson(params.scales);
300
- json["reference_points"] = nestedVectorToJson(params.reference_points);
301
- json["binarize"] = toJson(params.binarize);
302
- return json;
303
- }
304
-
305
- template <typename T> cv::Mat jsonToMat(Json::Value json) {
306
- int rows = json.size();
307
- int cols = json[0].size();
308
-
309
- // Create a single array to hold all the data.
310
- std::vector<T> data;
311
- data.reserve(rows * cols);
312
-
313
- for (int i = 0; i < rows; i++) {
314
- for (int j = 0; j < cols; j++) {
315
- data.push_back(static_cast<T>(json[i][j].asInt()));
316
- }
317
- }
318
-
319
- // Create a cv::Mat object that points to the data.
320
- cv::Mat mat(rows, cols, CV_8UC1,
321
- data.data()); // Change the type if necessary.
322
- // cv::Mat mat(cols, rows,CV_8UC1, data.data()); // Change the type if
323
- // necessary.
324
-
325
- return mat;
326
- }
327
-
328
- /**
329
- * @brief Decodes the response of the server and prints the keypoints
330
- *
331
- * This function takes the response of the server, a JSON string, and decodes
332
- * it. It then prints the keypoints and draws them on the original image.
333
- *
334
- * @param response The response of the server
335
- * @return The keypoints and descriptors
336
- */
337
- KeyPointResults decode_response(const std::string& response, bool viz = true) {
338
- Json::CharReaderBuilder builder;
339
- Json::CharReader* reader = builder.newCharReader();
340
-
341
- Json::Value jsonData;
342
- std::string errors;
343
-
344
- // Parse the JSON response
345
- bool parsingSuccessful = reader->parse(
346
- response.c_str(), response.c_str() + response.size(), &jsonData, &errors);
347
- delete reader;
348
-
349
- if (!parsingSuccessful) {
350
- // Handle error
351
- std::cout << "Failed to parse the JSON, errors:" << std::endl;
352
- std::cout << errors << std::endl;
353
- return KeyPointResults();
354
- }
355
-
356
- KeyPointResults kpts_results;
357
-
358
- // Iterate over the images
359
- for (const auto& jsonItem : jsonData) {
360
- auto jkeypoints = jsonItem["keypoints"];
361
- auto jkeypoints_orig = jsonItem["keypoints_orig"];
362
- auto jdescriptors = jsonItem["descriptors"];
363
- auto jscores = jsonItem["scores"];
364
- auto jimageSize = jsonItem["image_size"];
365
- auto joriginalSize = jsonItem["original_size"];
366
- auto jsize = jsonItem["size"];
367
-
368
- std::vector<cv::KeyPoint> vkeypoints;
369
- std::vector<float> vscores;
370
-
371
- // Iterate over the keypoints
372
- int counter = 0;
373
- for (const auto& keypoint : jkeypoints_orig) {
374
- if (counter < 10) {
375
- // Print the first 10 keypoints
376
- std::cout << keypoint[0].asFloat() << ", " << keypoint[1].asFloat()
377
- << std::endl;
378
- }
379
- counter++;
380
- // Convert the Json::Value to a cv::KeyPoint
381
- vkeypoints.emplace_back(
382
- cv::KeyPoint(keypoint[0].asFloat(), keypoint[1].asFloat(), 0.0));
383
- }
384
-
385
- if (viz && jsonItem.isMember("image_orig")) {
386
- auto jimg_orig = jsonItem["image_orig"];
387
- cv::Mat img = jsonToMat<uchar>(jimg_orig);
388
- cv::imwrite("viz_image_orig.jpg", img);
389
-
390
- // Draw keypoints on the image
391
- cv::Mat imgWithKeypoints;
392
- cv::drawKeypoints(img, vkeypoints, imgWithKeypoints, cv::Scalar(0, 0, 255));
393
-
394
- // Write the image with keypoints
395
- std::string filename = "viz_image_orig_keypoints.jpg";
396
- cv::imwrite(filename, imgWithKeypoints);
397
- }
398
-
399
- // Iterate over the descriptors
400
- cv::Mat descriptors = jsonToMat<uchar>(jdescriptors);
401
- kpts_results.append_keypoints(vkeypoints);
402
- kpts_results.append_descriptors(descriptors);
403
- }
404
- return kpts_results;
405
- }
 
1
+
2
+ #include <b64/encode.h>
3
+ #include <fstream>
4
+ #include <jsoncpp/json/json.h>
5
+ #include <opencv2/opencv.hpp>
6
+ #include <sstream>
7
+ #include <vector>
8
+
9
+ // base64 to image
10
+ #include <boost/archive/iterators/base64_from_binary.hpp>
11
+ #include <boost/archive/iterators/binary_from_base64.hpp>
12
+ #include <boost/archive/iterators/transform_width.hpp>
13
+
14
+ /// Parameters used in the API
15
+ struct APIParams {
16
+ /// A list of images, base64 encoded
17
+ std::vector<std::string> data;
18
+
19
+ /// The maximum number of keypoints to detect for each image
20
+ std::vector<int> max_keypoints;
21
+
22
+ /// The timestamps of the images
23
+ std::vector<std::string> timestamps;
24
+
25
+ /// Whether to convert the images to grayscale
26
+ bool grayscale;
27
+
28
+ /// The height and width of each image
29
+ std::vector<std::vector<int>> image_hw;
30
+
31
+ /// The type of feature detector to use
32
+ int feature_type;
33
+
34
+ /// The rotations of the images
35
+ std::vector<double> rotates;
36
+
37
+ /// The scales of the images
38
+ std::vector<double> scales;
39
+
40
+ /// The reference points of the images
41
+ std::vector<std::vector<float>> reference_points;
42
+
43
+ /// Whether to binarize the descriptors
44
+ bool binarize;
45
+ };
46
+
47
+ /**
48
+ * @brief Contains the results of a keypoint detector.
49
+ *
50
+ * @details Stores the keypoints and descriptors for each image.
51
+ */
52
+ class KeyPointResults {
53
+ public:
54
+ KeyPointResults() {
55
+ }
56
+
57
+ /**
58
+ * @brief Constructor.
59
+ *
60
+ * @param kp The keypoints for each image.
61
+ */
62
+ KeyPointResults(const std::vector<std::vector<cv::KeyPoint>>& kp,
63
+ const std::vector<cv::Mat>& desc)
64
+ : keypoints(kp), descriptors(desc) {
65
+ }
66
+
67
+ /**
68
+ * @brief Append keypoints to the result.
69
+ *
70
+ * @param kpts The keypoints to append.
71
+ */
72
+ inline void append_keypoints(std::vector<cv::KeyPoint>& kpts) {
73
+ keypoints.emplace_back(kpts);
74
+ }
75
+
76
+ /**
77
+ * @brief Append descriptors to the result.
78
+ *
79
+ * @param desc The descriptors to append.
80
+ */
81
+ inline void append_descriptors(cv::Mat& desc) {
82
+ descriptors.emplace_back(desc);
83
+ }
84
+
85
+ /**
86
+ * @brief Get the keypoints.
87
+ *
88
+ * @return The keypoints.
89
+ */
90
+ inline std::vector<std::vector<cv::KeyPoint>> get_keypoints() {
91
+ return keypoints;
92
+ }
93
+
94
+ /**
95
+ * @brief Get the descriptors.
96
+ *
97
+ * @return The descriptors.
98
+ */
99
+ inline std::vector<cv::Mat> get_descriptors() {
100
+ return descriptors;
101
+ }
102
+
103
+ private:
104
+ std::vector<std::vector<cv::KeyPoint>> keypoints;
105
+ std::vector<cv::Mat> descriptors;
106
+ std::vector<std::vector<float>> scores;
107
+ };
108
+
109
+ /**
110
+ * @brief Decodes a base64 encoded string.
111
+ *
112
+ * @param base64 The base64 encoded string to decode.
113
+ * @return The decoded string.
114
+ */
115
+ std::string base64_decode(const std::string& base64) {
116
+ using namespace boost::archive::iterators;
117
+ using It = transform_width<binary_from_base64<std::string::const_iterator>, 8, 6>;
118
+
119
+ // Find the position of the last non-whitespace character
120
+ auto end = base64.find_last_not_of(" \t\n\r");
121
+ if (end != std::string::npos) {
122
+ // Move one past the last non-whitespace character
123
+ end += 1;
124
+ }
125
+
126
+ // Decode the base64 string and return the result
127
+ return std::string(It(base64.begin()), It(base64.begin() + end));
128
+ }
129
+
130
+ /**
131
+ * @brief Decodes a base64 string into an OpenCV image
132
+ *
133
+ * @param base64 The base64 encoded string
134
+ * @return The decoded OpenCV image
135
+ */
136
+ cv::Mat base64_to_image(const std::string& base64) {
137
+ // Decode the base64 string
138
+ std::string decodedStr = base64_decode(base64);
139
+
140
+ // Decode the image
141
+ std::vector<uchar> data(decodedStr.begin(), decodedStr.end());
142
+ cv::Mat img = cv::imdecode(data, cv::IMREAD_GRAYSCALE);
143
+
144
+ // Check for errors
145
+ if (img.empty()) {
146
+ throw std::runtime_error("Failed to decode image");
147
+ }
148
+
149
+ return img;
150
+ }
151
+
152
+ /**
153
+ * @brief Encodes an OpenCV image into a base64 string
154
+ *
155
+ * This function takes an OpenCV image and encodes it into a base64 string.
156
+ * The image is first encoded as a PNG image, and then the resulting
157
+ * bytes are encoded as a base64 string.
158
+ *
159
+ * @param img The OpenCV image
160
+ * @return The base64 encoded string
161
+ *
162
+ * @throws std::runtime_error if the image is empty or encoding fails
163
+ */
164
+ std::string image_to_base64(cv::Mat& img) {
165
+ if (img.empty()) {
166
+ throw std::runtime_error("Failed to read image");
167
+ }
168
+
169
+ // Encode the image as a PNG
170
+ std::vector<uchar> buf;
171
+ if (!cv::imencode(".png", img, buf)) {
172
+ throw std::runtime_error("Failed to encode image");
173
+ }
174
+
175
+ // Encode the bytes as a base64 string
176
+ using namespace boost::archive::iterators;
177
+ using It =
178
+ base64_from_binary<transform_width<std::vector<uchar>::const_iterator, 6, 8>>;
179
+ std::string base64(It(buf.begin()), It(buf.end()));
180
+
181
+ // Pad the string with '=' characters to a multiple of 4 bytes
182
+ base64.append((3 - buf.size() % 3) % 3, '=');
183
+
184
+ return base64;
185
+ }
186
+
187
+ /**
188
+ * @brief Callback function for libcurl to write data to a string
189
+ *
190
+ * This function is used as a callback for libcurl to write data to a string.
191
+ * It takes the contents, size, and nmemb as parameters, and writes the data to
192
+ * the string.
193
+ *
194
+ * @param contents The data to write
195
+ * @param size The size of the data
196
+ * @param nmemb The number of members in the data
197
+ * @param s The string to write the data to
198
+ * @return The number of bytes written
199
+ */
200
+ size_t WriteCallback(void* contents, size_t size, size_t nmemb, std::string* s) {
201
+ size_t newLength = size * nmemb;
202
+ try {
203
+ // Resize the string to fit the new data
204
+ s->resize(s->size() + newLength);
205
+ } catch (std::bad_alloc& e) {
206
+ // If there's an error allocating memory, return 0
207
+ return 0;
208
+ }
209
+
210
+ // Copy the data to the string
211
+ std::copy(static_cast<const char*>(contents),
212
+ static_cast<const char*>(contents) + newLength,
213
+ s->begin() + s->size() - newLength);
214
+ return newLength;
215
+ }
216
+
217
+ // Helper functions
218
+
219
+ /**
220
+ * @brief Helper function to convert a type to a Json::Value
221
+ *
222
+ * This function takes a value of type T and converts it to a Json::Value.
223
+ * It is used to simplify the process of converting a type to a Json::Value.
224
+ *
225
+ * @param val The value to convert
226
+ * @return The converted Json::Value
227
+ */
228
+ template <typename T> Json::Value toJson(const T& val) {
229
+ return Json::Value(val);
230
+ }
231
+
232
+ /**
233
+ * @brief Converts a vector to a Json::Value
234
+ *
235
+ * This function takes a vector of type T and converts it to a Json::Value.
236
+ * Each element in the vector is appended to the Json::Value array.
237
+ *
238
+ * @param vec The vector to convert to Json::Value
239
+ * @return The Json::Value representing the vector
240
+ */
241
+ template <typename T> Json::Value vectorToJson(const std::vector<T>& vec) {
242
+ Json::Value json(Json::arrayValue);
243
+ for (const auto& item : vec) {
244
+ json.append(item);
245
+ }
246
+ return json;
247
+ }
248
+
249
+ /**
250
+ * @brief Converts a nested vector to a Json::Value
251
+ *
252
+ * This function takes a nested vector of type T and converts it to a
253
+ * Json::Value. Each sub-vector is converted to a Json::Value array and appended
254
+ * to the main Json::Value array.
255
+ *
256
+ * @param vec The nested vector to convert to Json::Value
257
+ * @return The Json::Value representing the nested vector
258
+ */
259
+ template <typename T>
260
+ Json::Value nestedVectorToJson(const std::vector<std::vector<T>>& vec) {
261
+ Json::Value json(Json::arrayValue);
262
+ for (const auto& subVec : vec) {
263
+ json.append(vectorToJson(subVec));
264
+ }
265
+ return json;
266
+ }
267
+
268
+ /**
269
+ * @brief Converts the APIParams struct to a Json::Value
270
+ *
271
+ * This function takes an APIParams struct and converts it to a Json::Value.
272
+ * The Json::Value is a JSON object with the following fields:
273
+ * - data: a JSON array of base64 encoded images
274
+ * - max_keypoints: a JSON array of integers, max number of keypoints for each
275
+ * image
276
+ * - timestamps: a JSON array of timestamps, one for each image
277
+ * - grayscale: a JSON boolean, whether to convert images to grayscale
278
+ * - image_hw: a nested JSON array, each sub-array contains the height and width
279
+ * of an image
280
+ * - feature_type: a JSON integer, the type of feature detector to use
281
+ * - rotates: a JSON array of doubles, the rotation of each image
282
+ * - scales: a JSON array of doubles, the scale of each image
283
+ * - reference_points: a nested JSON array, each sub-array contains the
284
+ * reference points of an image
285
+ * - binarize: a JSON boolean, whether to binarize the descriptors
286
+ *
287
+ * @param params The APIParams struct to convert
288
+ * @return The Json::Value representing the APIParams struct
289
+ */
290
+ Json::Value paramsToJson(const APIParams& params) {
291
+ Json::Value json;
292
+ json["data"] = vectorToJson(params.data);
293
+ json["max_keypoints"] = vectorToJson(params.max_keypoints);
294
+ json["timestamps"] = vectorToJson(params.timestamps);
295
+ json["grayscale"] = toJson(params.grayscale);
296
+ json["image_hw"] = nestedVectorToJson(params.image_hw);
297
+ json["feature_type"] = toJson(params.feature_type);
298
+ json["rotates"] = vectorToJson(params.rotates);
299
+ json["scales"] = vectorToJson(params.scales);
300
+ json["reference_points"] = nestedVectorToJson(params.reference_points);
301
+ json["binarize"] = toJson(params.binarize);
302
+ return json;
303
+ }
304
+
305
+ template <typename T> cv::Mat jsonToMat(Json::Value json) {
306
+ int rows = json.size();
307
+ int cols = json[0].size();
308
+
309
+ // Create a single array to hold all the data.
310
+ std::vector<T> data;
311
+ data.reserve(rows * cols);
312
+
313
+ for (int i = 0; i < rows; i++) {
314
+ for (int j = 0; j < cols; j++) {
315
+ data.push_back(static_cast<T>(json[i][j].asInt()));
316
+ }
317
+ }
318
+
319
+ // Create a cv::Mat object that points to the data.
320
+ cv::Mat mat(rows, cols, CV_8UC1,
321
+ data.data()); // Change the type if necessary.
322
+ // cv::Mat mat(cols, rows,CV_8UC1, data.data()); // Change the type if
323
+ // necessary.
324
+
325
+ return mat.clone();  // deep-copy: 'data' is a local vector, so the wrapped buffer is freed on return
326
+ }
327
+
328
+ /**
329
+ * @brief Decodes the response of the server and prints the keypoints
330
+ *
331
+ * This function takes the response of the server, a JSON string, and decodes
332
+ * it. It then prints the keypoints and draws them on the original image.
333
+ *
334
+ * @param response The response of the server
335
+ * @return The keypoints and descriptors
336
+ */
337
+ KeyPointResults decode_response(const std::string& response, bool viz = true) {
338
+ Json::CharReaderBuilder builder;
339
+ Json::CharReader* reader = builder.newCharReader();
340
+
341
+ Json::Value jsonData;
342
+ std::string errors;
343
+
344
+ // Parse the JSON response
345
+ bool parsingSuccessful = reader->parse(
346
+ response.c_str(), response.c_str() + response.size(), &jsonData, &errors);
347
+ delete reader;
348
+
349
+ if (!parsingSuccessful) {
350
+ // Handle error
351
+ std::cout << "Failed to parse the JSON, errors:" << std::endl;
352
+ std::cout << errors << std::endl;
353
+ return KeyPointResults();
354
+ }
355
+
356
+ KeyPointResults kpts_results;
357
+
358
+ // Iterate over the images
359
+ for (const auto& jsonItem : jsonData) {
360
+ auto jkeypoints = jsonItem["keypoints"];
361
+ auto jkeypoints_orig = jsonItem["keypoints_orig"];
362
+ auto jdescriptors = jsonItem["descriptors"];
363
+ auto jscores = jsonItem["scores"];
364
+ auto jimageSize = jsonItem["image_size"];
365
+ auto joriginalSize = jsonItem["original_size"];
366
+ auto jsize = jsonItem["size"];
367
+
368
+ std::vector<cv::KeyPoint> vkeypoints;
369
+ std::vector<float> vscores;
370
+
371
+ // Iterate over the keypoints
372
+ int counter = 0;
373
+ for (const auto& keypoint : jkeypoints_orig) {
374
+ if (counter < 10) {
375
+ // Print the first 10 keypoints
376
+ std::cout << keypoint[0].asFloat() << ", " << keypoint[1].asFloat()
377
+ << std::endl;
378
+ }
379
+ counter++;
380
+ // Convert the Json::Value to a cv::KeyPoint
381
+ vkeypoints.emplace_back(
382
+ cv::KeyPoint(keypoint[0].asFloat(), keypoint[1].asFloat(), 0.0));
383
+ }
384
+
385
+ if (viz && jsonItem.isMember("image_orig")) {
386
+ auto jimg_orig = jsonItem["image_orig"];
387
+ cv::Mat img = jsonToMat<uchar>(jimg_orig);
388
+ cv::imwrite("viz_image_orig.jpg", img);
389
+
390
+ // Draw keypoints on the image
391
+ cv::Mat imgWithKeypoints;
392
+ cv::drawKeypoints(img, vkeypoints, imgWithKeypoints, cv::Scalar(0, 0, 255));
393
+
394
+ // Write the image with keypoints
395
+ std::string filename = "viz_image_orig_keypoints.jpg";
396
+ cv::imwrite(filename, imgWithKeypoints);
397
+ }
398
+
399
+ // Iterate over the descriptors
400
+ cv::Mat descriptors = jsonToMat<uchar>(jdescriptors);
401
+ kpts_results.append_keypoints(vkeypoints);
402
+ kpts_results.append_descriptors(descriptors);
403
+ }
404
+ return kpts_results;
405
+ }
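
As a cross-check on the Boost-based helpers above, here is a short Python sketch of the same PNG-encode → base64 → decode round-trip, using OpenCV and the standard library. It mirrors `image_to_base64`/`base64_to_image` from helper.h but is only an illustration of the expected lossless round-trip, not part of the repository's test suite.

```python
import base64

import cv2
import numpy as np


def image_to_base64(img: np.ndarray) -> str:
    """PNG-encode an image and return it as a base64 string."""
    ok, buf = cv2.imencode(".png", img)
    if not ok:
        raise RuntimeError("Failed to encode image")
    return base64.b64encode(buf.tobytes()).decode("utf-8")


def base64_to_image(encoding: str) -> np.ndarray:
    """Decode a base64 string back into a grayscale image."""
    data = np.frombuffer(base64.b64decode(encoding), dtype=np.uint8)
    img = cv2.imdecode(data, cv2.IMREAD_GRAYSCALE)
    if img is None:
        raise RuntimeError("Failed to decode image")
    return img


if __name__ == "__main__":
    original = np.random.randint(0, 256, size=(240, 320), dtype=np.uint8)
    restored = base64_to_image(image_to_base64(original))
    # PNG is lossless, so the round-trip must reproduce the array exactly.
    assert np.array_equal(original, restored), "round-trip mismatch"
    print("round-trip OK")
```

Because PNG is lossless, the decoded array must match the original byte-for-byte, which is the same invariant the C++ test checks with `cv::countNonZero`.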
imcui/ui/__init__.py CHANGED
@@ -1,5 +1,5 @@
1
- __version__ = "1.0.1"
2
-
3
-
4
- def get_version():
5
- return __version__
 
1
+ __version__ = "1.3.0"
2
+
3
+
4
+ def get_version():
5
+ return __version__
imcui/ui/app_class.py CHANGED
@@ -1,820 +1,816 @@
1
- from pathlib import Path
2
- from typing import Any, Dict, Optional, Tuple
3
-
4
- import gradio as gr
5
- import numpy as np
6
- from easydict import EasyDict as edict
7
- from omegaconf import OmegaConf
8
-
9
- from .sfm import SfmEngine
10
- from .utils import (
11
- GRADIO_VERSION,
12
- gen_examples,
13
- generate_warp_images,
14
- get_matcher_zoo,
15
- load_config,
16
- ransac_zoo,
17
- run_matching,
18
- run_ransac,
19
- send_to_match,
20
- )
21
-
22
- DESCRIPTION = """
23
- # Image Matching WebUI
24
- This Space demonstrates [Image Matching WebUI](https://github.com/Vincentqyw/image-matching-webui) by vincent qin. Feel free to play with it, or duplicate to run image matching without a queue!
25
- <br/>
26
- 🔎 For more details about supported local features and matchers, please refer to https://github.com/Vincentqyw/image-matching-webui
27
-
28
- 🚀 All algorithms run on CPU for inference, causing slow speeds and high latency. For faster inference, please download the [source code](https://github.com/Vincentqyw/image-matching-webui) for local deployment.
29
-
30
- 🐛 Your feedback is valuable to me. Please do not hesitate to report any bugs [here](https://github.com/Vincentqyw/image-matching-webui/issues).
31
- """
32
-
33
- CSS = """
34
- #warning {background-color: #FFCCCB}
35
- .logs_class textarea {font-size: 12px !important}
36
- """
37
-
38
-
39
- class ImageMatchingApp:
40
- def __init__(self, server_name="0.0.0.0", server_port=7860, **kwargs):
41
- self.server_name = server_name
42
- self.server_port = server_port
43
- self.config_path = kwargs.get("config", Path(__file__).parent / "config.yaml")
44
- self.cfg = load_config(self.config_path)
45
- self.matcher_zoo = get_matcher_zoo(self.cfg["matcher_zoo"])
46
- self.app = None
47
- self.example_data_root = kwargs.get(
48
- "example_data_root", Path(__file__).parents[1] / "datasets"
49
- )
50
- # final step
51
- self.init_interface()
52
-
53
- def init_matcher_dropdown(self):
54
- algos = []
55
- for k, v in self.cfg["matcher_zoo"].items():
56
- if v.get("enable", True):
57
- algos.append(k)
58
- return algos
59
-
60
- def init_interface(self):
61
- with gr.Blocks(css=CSS) as self.app:
62
- with gr.Tab("Image Matching"):
63
- with gr.Row():
64
- with gr.Column(scale=1):
65
- gr.Image(
66
- str(Path(__file__).parent.parent / "assets/logo.webp"),
67
- elem_id="logo-img",
68
- show_label=False,
69
- show_share_button=False,
70
- show_download_button=False,
71
- )
72
- with gr.Column(scale=3):
73
- gr.Markdown(DESCRIPTION)
74
- with gr.Row(equal_height=False):
75
- with gr.Column():
76
- with gr.Row():
77
- matcher_list = gr.Dropdown(
78
- choices=self.init_matcher_dropdown(),
79
- value="disk+lightglue",
80
- label="Matching Model",
81
- interactive=True,
82
- )
83
- match_image_src = gr.Radio(
84
- (
85
- ["upload", "webcam", "clipboard"]
86
- if GRADIO_VERSION > "3"
87
- else ["upload", "webcam", "canvas"]
88
- ),
89
- label="Image Source",
90
- value="upload",
91
- )
92
- with gr.Row():
93
- input_image0 = gr.Image(
94
- label="Image 0",
95
- type="numpy",
96
- image_mode="RGB",
97
- height=300 if GRADIO_VERSION > "3" else None,
98
- interactive=True,
99
- )
100
- input_image1 = gr.Image(
101
- label="Image 1",
102
- type="numpy",
103
- image_mode="RGB",
104
- height=300 if GRADIO_VERSION > "3" else None,
105
- interactive=True,
106
- )
107
-
108
- with gr.Row():
109
- button_reset = gr.Button(value="Reset")
110
- button_run = gr.Button(value="Run Match", variant="primary")
111
- with gr.Row():
112
- button_stop = gr.Button(value="Force Stop", variant="stop")
113
-
114
- with gr.Accordion("Advanced Setting", open=False):
115
- with gr.Accordion("Image Setting", open=True):
116
- with gr.Row():
117
- image_force_resize_cb = gr.Checkbox(
118
- label="Force Resize",
119
- value=False,
120
- interactive=True,
121
- )
122
- image_setting_height = gr.Slider(
123
- minimum=48,
124
- maximum=2048,
125
- step=16,
126
- label="Image Height",
127
- value=480,
128
- visible=False,
129
- )
130
- image_setting_width = gr.Slider(
131
- minimum=64,
132
- maximum=2048,
133
- step=16,
134
- label="Image Width",
135
- value=640,
136
- visible=False,
137
- )
138
- with gr.Accordion("Matching Setting", open=True):
139
- with gr.Row():
140
- match_setting_threshold = gr.Slider(
141
- minimum=0.0,
142
- maximum=1,
143
- step=0.001,
144
- label="Match threshold",
145
- value=0.1,
146
- )
147
- match_setting_max_keypoints = gr.Slider(
148
- minimum=10,
149
- maximum=10000,
150
- step=10,
151
- label="Max features",
152
- value=1000,
153
- )
154
- # TODO: add line settings
155
- with gr.Row():
156
- detect_keypoints_threshold = gr.Slider(
157
- minimum=0,
158
- maximum=1,
159
- step=0.001,
160
- label="Keypoint threshold",
161
- value=0.015,
162
- )
163
- detect_line_threshold = ( # noqa: F841
164
- gr.Slider(
165
- minimum=0.1,
166
- maximum=1,
167
- step=0.01,
168
- label="Line threshold",
169
- value=0.2,
170
- )
171
- )
172
- # matcher_lists = gr.Radio(
173
- # ["NN-mutual", "Dual-Softmax"],
174
- # label="Matcher mode",
175
- # value="NN-mutual",
176
- # )
177
- with gr.Accordion("RANSAC Setting", open=True):
178
- with gr.Row(equal_height=False):
179
- ransac_method = gr.Dropdown(
180
- choices=ransac_zoo.keys(),
181
- value=self.cfg["defaults"]["ransac_method"],
182
- label="RANSAC Method",
183
- interactive=True,
184
- )
185
- ransac_reproj_threshold = gr.Slider(
186
- minimum=0.0,
187
- maximum=12,
188
- step=0.01,
189
- label="Ransac Reproj threshold",
190
- value=8.0,
191
- )
192
- ransac_confidence = gr.Slider(
193
- minimum=0.0,
194
- maximum=1,
195
- step=0.00001,
196
- label="Ransac Confidence",
197
- value=self.cfg["defaults"]["ransac_confidence"],
198
- )
199
- ransac_max_iter = gr.Slider(
200
- minimum=0.0,
201
- maximum=100000,
202
- step=100,
203
- label="Ransac Iterations",
204
- value=self.cfg["defaults"]["ransac_max_iter"],
205
- )
206
- button_ransac = gr.Button(
207
- value="Rerun RANSAC", variant="primary"
208
- )
209
- with gr.Accordion("Geometry Setting", open=False):
210
- with gr.Row(equal_height=False):
211
- choice_geometry_type = gr.Radio(
212
- ["Fundamental", "Homography"],
213
- label="Reconstruct Geometry",
214
- value=self.cfg["defaults"]["setting_geometry"],
215
- )
216
- # image resize
217
- image_force_resize_cb.select(
218
- fn=self._on_select_force_resize,
219
- inputs=image_force_resize_cb,
220
- outputs=[image_setting_width, image_setting_height],
221
- )
222
- # collect inputs
223
- state_cache = gr.State({})
224
- inputs = [
225
- input_image0,
226
- input_image1,
227
- match_setting_threshold,
228
- match_setting_max_keypoints,
229
- detect_keypoints_threshold,
230
- matcher_list,
231
- ransac_method,
232
- ransac_reproj_threshold,
233
- ransac_confidence,
234
- ransac_max_iter,
235
- choice_geometry_type,
236
- gr.State(self.matcher_zoo),
237
- image_force_resize_cb,
238
- image_setting_width,
239
- image_setting_height,
240
- ]
241
-
242
- # Add some examples
243
- with gr.Row():
244
- # Example inputs
245
- with gr.Accordion("Open for More: Examples", open=True):
246
- gr.Examples(
247
- examples=gen_examples(self.example_data_root),
248
- inputs=inputs,
249
- outputs=[],
250
- fn=run_matching,
251
- cache_examples=False,
252
- label=(
253
- "Examples (click one of the images below to Run"
254
- " Match). Thx: WxBS"
255
- ),
256
- )
257
- with gr.Accordion("Supported Algorithms", open=False):
258
- # add a table of supported algorithms
259
- self.display_supported_algorithms()
260
-
261
- with gr.Column():
262
- with gr.Accordion("Open for More: Keypoints", open=True):
263
- output_keypoints = gr.Image(label="Keypoints", type="numpy")
264
- with gr.Accordion(
265
- (
266
- "Open for More: Raw Matches"
267
- " (Green for good matches, Red for bad)"
268
- ),
269
- open=False,
270
- ):
271
- output_matches_raw = gr.Image(
272
- label="Raw Matches",
273
- type="numpy",
274
- )
275
- with gr.Accordion(
276
- (
277
- "Open for More: Ransac Matches"
278
- " (Green for good matches, Red for bad)"
279
- ),
280
- open=True,
281
- ):
282
- output_matches_ransac = gr.Image(
283
- label="Ransac Matches", type="numpy"
284
- )
285
- with gr.Accordion(
286
- "Open for More: Matches Statistics", open=False
287
- ):
288
- output_pred = gr.File(label="Outputs", elem_id="download")
289
- matches_result_info = gr.JSON(label="Matches Statistics")
290
- matcher_info = gr.JSON(label="Match info")
291
-
292
- with gr.Accordion("Open for More: Warped Image", open=True):
293
- output_wrapped = gr.Image(
294
- label="Wrapped Pair", type="numpy"
295
- )
296
- # send to input
297
- button_rerun = gr.Button(
298
- value="Send to Input Match Pair",
299
- variant="primary",
300
- )
301
- with gr.Accordion(
302
- "Open for More: Geometry info", open=False
303
- ):
304
- geometry_result = gr.JSON(
305
- label="Reconstructed Geometry"
306
- )
307
-
308
- # callbacks
309
- match_image_src.change(
310
- fn=self.ui_change_imagebox,
311
- inputs=match_image_src,
312
- outputs=input_image0,
313
- )
314
- match_image_src.change(
315
- fn=self.ui_change_imagebox,
316
- inputs=match_image_src,
317
- outputs=input_image1,
318
- )
319
- # collect outputs
320
- outputs = [
321
- output_keypoints,
322
- output_matches_raw,
323
- output_matches_ransac,
324
- matches_result_info,
325
- matcher_info,
326
- geometry_result,
327
- output_wrapped,
328
- state_cache,
329
- output_pred,
330
- ]
331
- # button callbacks
332
- click_event = button_run.click(
333
- fn=run_matching, inputs=inputs, outputs=outputs
334
- )
335
- # stop button
336
- button_stop.click(
337
- fn=None, inputs=None, outputs=None, cancels=[click_event]
338
- )
339
-
340
- # Reset images
341
- reset_outputs = [
342
- input_image0,
343
- input_image1,
344
- match_setting_threshold,
345
- match_setting_max_keypoints,
346
- detect_keypoints_threshold,
347
- matcher_list,
348
- input_image0,
349
- input_image1,
350
- match_image_src,
351
- output_keypoints,
352
- output_matches_raw,
353
- output_matches_ransac,
354
- matches_result_info,
355
- matcher_info,
356
- output_wrapped,
357
- geometry_result,
358
- ransac_method,
359
- ransac_reproj_threshold,
360
- ransac_confidence,
361
- ransac_max_iter,
362
- choice_geometry_type,
363
- output_pred,
364
- image_force_resize_cb,
365
- ]
366
- button_reset.click(
367
- fn=self.ui_reset_state,
368
- inputs=None,
369
- outputs=reset_outputs,
370
- )
371
-
372
- # run ransac button action
373
- button_ransac.click(
374
- fn=run_ransac,
375
- inputs=[
376
- state_cache,
377
- choice_geometry_type,
378
- ransac_method,
379
- ransac_reproj_threshold,
380
- ransac_confidence,
381
- ransac_max_iter,
382
- ],
383
- outputs=[
384
- output_matches_ransac,
385
- matches_result_info,
386
- output_wrapped,
387
- output_pred,
388
- ],
389
- )
390
-
391
- # send warped image to match
392
- button_rerun.click(
393
- fn=send_to_match,
394
- inputs=[state_cache],
395
- outputs=[input_image0, input_image1],
396
- )
397
-
398
- # estimate geo
399
- choice_geometry_type.change(
400
- fn=generate_warp_images,
401
- inputs=[
402
- input_image0,
403
- input_image1,
404
- geometry_result,
405
- choice_geometry_type,
406
- ],
407
- outputs=[output_wrapped, geometry_result],
408
- )
409
- with gr.Tab("Structure from Motion(under-dev)"):
410
- sfm_ui = AppSfmUI( # noqa: F841
411
- {
412
- **self.cfg,
413
- "matcher_zoo": self.matcher_zoo,
414
- "outputs": "experiments/sfm",
415
- }
416
- )
417
- sfm_ui.call_empty()
418
-
419
- def run(self):
420
- self.app.queue().launch(
421
- server_name=self.server_name,
422
- server_port=self.server_port,
423
- share=False,
424
- allowed_paths=[
425
- str(Path(__file__).parents[0]),
426
- str(Path(__file__).parents[1]),
427
- ],
428
- )
429
-
430
- def ui_change_imagebox(self, choice):
431
- """
432
- Updates the image box with the given choice.
433
-
434
- Args:
435
- choice (list): The list of image sources to be displayed in the image box.
436
-
437
- Returns:
438
- dict: A dictionary containing the updated value, sources, and type for the image box.
439
- """
440
- ret_dict = {
441
- "value": None, # The updated value of the image box
442
- "__type__": "update", # The type of update for the image box
443
- }
444
- if GRADIO_VERSION > "3":
445
- return {
446
- **ret_dict,
447
- "sources": choice, # The list of image sources to be displayed
448
- }
449
- else:
450
- return {
451
- **ret_dict,
452
- "source": choice, # The list of image sources to be displayed
453
- }
454
-
455
- def _on_select_force_resize(self, visible: bool = False):
456
- return gr.update(visible=visible), gr.update(visible=visible)
457
-
458
- def ui_reset_state(
459
- self,
460
- *args: Any,
461
- ) -> Tuple[
462
- Optional[np.ndarray],
463
- Optional[np.ndarray],
464
- float,
465
- int,
466
- float,
467
- str,
468
- Dict[str, Any],
469
- Dict[str, Any],
470
- str,
471
- Optional[np.ndarray],
472
- Optional[np.ndarray],
473
- Optional[np.ndarray],
474
- Dict[str, Any],
475
- Dict[str, Any],
476
- Optional[np.ndarray],
477
- Dict[str, Any],
478
- str,
479
- int,
480
- float,
481
- int,
482
- bool,
483
- ]:
484
- """
485
- Reset the state of the UI.
486
-
487
- Returns:
488
- tuple: A tuple containing the initial values for the UI state.
489
- """
490
- key: str = list(self.matcher_zoo.keys())[
491
- 0
492
- ] # Get the first key from matcher_zoo
493
- # flush_logs()
494
- return (
495
- None, # image0: Optional[np.ndarray]
496
- None, # image1: Optional[np.ndarray]
497
- self.cfg["defaults"]["match_threshold"], # matching_threshold: float
498
- self.cfg["defaults"]["max_keypoints"], # max_keypoints: int
499
- self.cfg["defaults"]["keypoint_threshold"], # keypoint_threshold: float
500
- key, # matcher: str
501
- self.ui_change_imagebox("upload"), # input image0: Dict[str, Any]
502
- self.ui_change_imagebox("upload"), # input image1: Dict[str, Any]
503
- "upload", # match_image_src: str
504
- None, # keypoints: Optional[np.ndarray]
505
- None, # raw matches: Optional[np.ndarray]
506
- None, # ransac matches: Optional[np.ndarray]
507
- {}, # matches result info: Dict[str, Any]
508
- {}, # matcher config: Dict[str, Any]
509
- None, # warped image: Optional[np.ndarray]
510
- {}, # geometry result: Dict[str, Any]
511
- self.cfg["defaults"]["ransac_method"], # ransac_method: str
512
- self.cfg["defaults"][
513
- "ransac_reproj_threshold"
514
- ], # ransac_reproj_threshold: float
515
- self.cfg["defaults"]["ransac_confidence"], # ransac_confidence: float
516
- self.cfg["defaults"]["ransac_max_iter"], # ransac_max_iter: int
517
- self.cfg["defaults"]["setting_geometry"], # geometry: str
518
- None, # predictions
519
- False,
520
- )
521
-
522
- def display_supported_algorithms(self, style="tab"):
523
- def get_link(link, tag="Link"):
524
- return "[{}]({})".format(tag, link) if link is not None else "None"
525
-
526
- data = []
527
- cfg = self.cfg["matcher_zoo"]
528
- if style == "md":
529
- markdown_table = "| Algo. | Conference | Code | Project | Paper |\n"
530
- markdown_table += "| ----- | ---------- | ---- | ------- | ----- |\n"
531
-
532
- for _, v in cfg.items():
533
- if not v["info"].get("display", True):
534
- continue
535
- github_link = get_link(v["info"].get("github", ""))
536
- project_link = get_link(v["info"].get("project", ""))
537
- paper_link = get_link(
538
- v["info"]["paper"],
539
- (
540
- Path(v["info"]["paper"]).name[-10:]
541
- if v["info"]["paper"] is not None
542
- else "Link"
543
- ),
544
- )
545
-
546
- markdown_table += "{}|{}|{}|{}|{}\n".format(
547
- v["info"].get("name", ""),
548
- v["info"].get("source", ""),
549
- github_link,
550
- project_link,
551
- paper_link,
552
- )
553
- return gr.Markdown(markdown_table)
554
- elif style == "tab":
555
- for k, v in cfg.items():
556
- if not v["info"].get("display", True):
557
- continue
558
- data.append(
559
- [
560
- v["info"].get("name", ""),
561
- v["info"].get("source", ""),
562
- v["info"].get("github", ""),
563
- v["info"].get("paper", ""),
564
- v["info"].get("project", ""),
565
- ]
566
- )
567
- tab = gr.Dataframe(
568
- headers=["Algo.", "Conference", "Code", "Paper", "Project"],
569
- datatype=["str", "str", "str", "str", "str"],
570
- col_count=(5, "fixed"),
571
- value=data,
572
- # wrap=True,
573
- # min_width = 1000,
574
- # height=1000,
575
- )
576
- return tab
577
-
578
-
579
- class AppBaseUI:
580
- def __init__(self, cfg: Dict[str, Any] = {}):
581
- self.cfg = OmegaConf.create(cfg)
582
- self.inputs = edict({})
583
- self.outputs = edict({})
584
- self.ui = edict({})
585
-
586
- def _init_ui(self):
587
- NotImplemented
588
-
589
- def call(self, **kwargs):
590
- NotImplemented
591
-
592
- def info(self):
593
- gr.Info("SFM is under construction.")
594
-
595
-
596
- class AppSfmUI(AppBaseUI):
597
- def __init__(self, cfg: Dict[str, Any] = None):
598
- super().__init__(cfg)
599
- assert "matcher_zoo" in self.cfg
600
- self.matcher_zoo = self.cfg["matcher_zoo"]
601
- self.sfm_engine = SfmEngine(cfg)
602
- self._init_ui()
603
-
604
- def init_retrieval_dropdown(self):
605
- algos = []
606
- for k, v in self.cfg["retrieval_zoo"].items():
607
- if v.get("enable", True):
608
- algos.append(k)
609
- return algos
610
-
611
- def _update_options(self, option):
612
- if option == "sparse":
613
- return gr.Textbox("sparse", visible=True)
614
- elif option == "dense":
615
- return gr.Textbox("dense", visible=True)
616
- else:
617
- return gr.Textbox("not set", visible=True)
618
-
619
- def _on_select_custom_params(self, value: bool = False):
620
- return gr.update(visible=value)
621
-
622
- def _init_ui(self):
623
- with gr.Row():
624
- # data settting and camera settings
625
- with gr.Column():
626
- self.inputs.input_images = gr.File(
627
- label="SfM",
628
- interactive=True,
629
- file_count="multiple",
630
- min_width=300,
631
- )
632
- # camera setting
633
- with gr.Accordion("Camera Settings", open=True):
634
- with gr.Column():
635
- with gr.Row():
636
- with gr.Column():
637
- self.inputs.camera_model = gr.Dropdown(
638
- choices=[
639
- "PINHOLE",
640
- "SIMPLE_RADIAL",
641
- "OPENCV",
642
- ],
643
- value="PINHOLE",
644
- label="Camera Model",
645
- interactive=True,
646
- )
647
- with gr.Column():
648
- gr.Checkbox(
649
- label="Shared Params",
650
- value=True,
651
- interactive=True,
652
- )
653
- camera_custom_params_cb = gr.Checkbox(
654
- label="Custom Params",
655
- value=False,
656
- interactive=True,
657
- )
658
- with gr.Row():
659
- self.inputs.camera_params = gr.Textbox(
660
- label="Camera Params",
661
- value="0,0,0,0",
662
- interactive=False,
663
- visible=False,
664
- )
665
- camera_custom_params_cb.select(
666
- fn=self._on_select_custom_params,
667
- inputs=camera_custom_params_cb,
668
- outputs=self.inputs.camera_params,
669
- )
670
-
671
- with gr.Accordion("Matching Settings", open=True):
672
- # feature extraction and matching setting
673
- with gr.Row():
674
- # matcher setting
675
- self.inputs.matcher_key = gr.Dropdown(
676
- choices=self.matcher_zoo.keys(),
677
- value="disk+lightglue",
678
- label="Matching Model",
679
- interactive=True,
680
- )
681
- with gr.Row():
682
- with gr.Accordion("Advanced Settings", open=False):
683
- with gr.Column():
684
- with gr.Row():
685
- # matching setting
686
- self.inputs.max_keypoints = gr.Slider(
687
- label="Max Keypoints",
688
- minimum=100,
689
- maximum=10000,
690
- value=1000,
691
- interactive=True,
692
- )
693
- self.inputs.keypoint_threshold = gr.Slider(
694
- label="Keypoint Threshold",
695
- minimum=0,
696
- maximum=1,
697
- value=0.01,
698
- )
699
- with gr.Row():
700
- self.inputs.match_threshold = gr.Slider(
701
- label="Match Threshold",
702
- minimum=0.01,
703
- maximum=12.0,
704
- value=0.2,
705
- )
706
- self.inputs.ransac_threshold = gr.Slider(
707
- label="Ransac Threshold",
708
- minimum=0.01,
709
- maximum=12.0,
710
- value=4.0,
711
- step=0.01,
712
- interactive=True,
713
- )
714
-
715
- with gr.Row():
716
- self.inputs.ransac_confidence = gr.Slider(
717
- label="Ransac Confidence",
718
- minimum=0.01,
719
- maximum=1.0,
720
- value=0.9999,
721
- step=0.0001,
722
- interactive=True,
723
- )
724
- self.inputs.ransac_max_iter = gr.Slider(
725
- label="Ransac Max Iter",
726
- minimum=1,
727
- maximum=100,
728
- value=100,
729
- step=1,
730
- interactive=True,
731
- )
732
- with gr.Accordion("Scene Graph Settings", open=True):
733
- # mapping setting
734
- self.inputs.scene_graph = gr.Dropdown(
735
- choices=["all", "swin", "oneref"],
736
- value="all",
737
- label="Scene Graph",
738
- interactive=True,
739
- )
740
-
741
- # global feature setting
742
- self.inputs.global_feature = gr.Dropdown(
743
- choices=self.init_retrieval_dropdown(),
744
- value="netvlad",
745
- label="Global features",
746
- interactive=True,
747
- )
748
- self.inputs.top_k = gr.Slider(
749
- label="Number of Images per Image to Match",
750
- minimum=1,
751
- maximum=100,
752
- value=10,
753
- step=1,
754
- )
755
- # button_match = gr.Button("Run Matching", variant="primary")
756
-
757
- # mapping setting
758
- with gr.Column():
759
- with gr.Accordion("Mapping Settings", open=True):
760
- with gr.Row():
761
- with gr.Accordion("Buddle Settings", open=True):
762
- with gr.Row():
763
- self.inputs.mapper_refine_focal_length = gr.Checkbox(
764
- label="Refine Focal Length",
765
- value=False,
766
- interactive=True,
767
- )
768
- self.inputs.mapper_refine_principle_points = (
769
- gr.Checkbox(
770
- label="Refine Principle Points",
771
- value=False,
772
- interactive=True,
773
- )
774
- )
775
- self.inputs.mapper_refine_extra_params = gr.Checkbox(
776
- label="Refine Extra Params",
777
- value=False,
778
- interactive=True,
779
- )
780
- with gr.Accordion("Retriangluation Settings", open=True):
781
- gr.Textbox(
782
- label="Retriangluation Details",
783
- )
784
- self.ui.button_sfm = gr.Button("Run SFM", variant="primary")
785
- self.outputs.model_3d = gr.Model3D(
786
- interactive=True,
787
- )
788
- self.outputs.output_image = gr.Image(
789
- label="SFM Visualize",
790
- type="numpy",
791
- image_mode="RGB",
792
- interactive=False,
793
- )
794
-
795
- def call_empty(self):
796
- self.ui.button_sfm.click(fn=self.info, inputs=[], outputs=[])
797
-
798
- def call(self):
799
- self.ui.button_sfm.click(
800
- fn=self.sfm_engine.call,
801
- inputs=[
802
- self.inputs.matcher_key,
803
- self.inputs.input_images, # images
804
- self.inputs.camera_model,
805
- self.inputs.camera_params,
806
- self.inputs.max_keypoints,
807
- self.inputs.keypoint_threshold,
808
- self.inputs.match_threshold,
809
- self.inputs.ransac_threshold,
810
- self.inputs.ransac_confidence,
811
- self.inputs.ransac_max_iter,
812
- self.inputs.scene_graph,
813
- self.inputs.global_feature,
814
- self.inputs.top_k,
815
- self.inputs.mapper_refine_focal_length,
816
- self.inputs.mapper_refine_principle_points,
817
- self.inputs.mapper_refine_extra_params,
818
- ],
819
- outputs=[self.outputs.model_3d, self.outputs.output_image],
820
- )
 
1
+ from pathlib import Path
2
+ from typing import Any, Dict, Optional, Tuple
3
+
4
+ import gradio as gr
5
+ import numpy as np
6
+ from easydict import EasyDict as edict
7
+ from omegaconf import OmegaConf
8
+
9
+ from .sfm import SfmEngine
10
+ from .utils import (
11
+ GRADIO_VERSION,
12
+ gen_examples,
13
+ generate_warp_images,
14
+ get_matcher_zoo,
15
+ load_config,
16
+ ransac_zoo,
17
+ run_matching,
18
+ run_ransac,
19
+ send_to_match,
20
+ )
21
+
22
+ DESCRIPTION = """
23
+ # Image Matching WebUI
24
+ This Space demonstrates [Image Matching WebUI](https://github.com/Vincentqyw/image-matching-webui) by vincent qin. Feel free to play with it, or duplicate to run image matching without a queue!
25
+ <br/>
26
+ 🔎 For more details about supported local features and matchers, please refer to https://github.com/Vincentqyw/image-matching-webui
27
+
28
+ 🚀 All algorithms run on CPU for inference, causing slow speeds and high latency. For faster inference, please download the [source code](https://github.com/Vincentqyw/image-matching-webui) for local deployment.
29
+
30
+ 🐛 Your feedback is valuable to me. Please do not hesitate to report any bugs [here](https://github.com/Vincentqyw/image-matching-webui/issues).
31
+ """
32
+
33
+ CSS = """
34
+ #warning {background-color: #FFCCCB}
35
+ .logs_class textarea {font-size: 12px !important}
36
+ """
37
+
38
+
39
+ class ImageMatchingApp:
40
+ def __init__(self, server_name="0.0.0.0", server_port=7860, **kwargs):
41
+ self.server_name = server_name
42
+ self.server_port = server_port
43
+ self.config_path = kwargs.get("config", Path(__file__).parent / "config.yaml")
44
+ self.cfg = load_config(self.config_path)
45
+ self.matcher_zoo = get_matcher_zoo(self.cfg["matcher_zoo"])
46
+ self.app = None
47
+ self.example_data_root = kwargs.get(
48
+ "example_data_root", Path(__file__).parents[1] / "datasets"
49
+ )
50
+ # final step
51
+ self.init_interface()
52
+
53
+ def init_matcher_dropdown(self):
54
+ algos = []
55
+ for k, v in self.cfg["matcher_zoo"].items():
56
+ if v.get("enable", True):
57
+ algos.append(k)
58
+ return algos
59
+
60
+ def init_interface(self):
61
+ with gr.Blocks(css=CSS) as self.app:
62
+ with gr.Tab("Image Matching"):
63
+ with gr.Row():
64
+ with gr.Column(scale=1):
65
+ gr.Image(
66
+ str(Path(__file__).parent.parent / "assets/logo.webp"),
67
+ elem_id="logo-img",
68
+ show_label=False,
69
+ show_share_button=False,
70
+ show_download_button=False,
71
+ )
72
+ with gr.Column(scale=3):
73
+ gr.Markdown(DESCRIPTION)
74
+ with gr.Row(equal_height=False):
75
+ with gr.Column():
76
+ with gr.Row():
77
+ matcher_list = gr.Dropdown(
78
+ choices=self.init_matcher_dropdown(),
79
+ value="disk+lightglue",
80
+ label="Matching Model",
81
+ interactive=True,
82
+ )
83
+ match_image_src = gr.Radio(
84
+ (
85
+ ["upload", "webcam", "clipboard"]
86
+ if GRADIO_VERSION > "3"
87
+ else ["upload", "webcam", "canvas"]
88
+ ),
89
+ label="Image Source",
90
+ value="upload",
91
+ )
92
+ with gr.Row():
93
+ input_image0 = gr.Image(
94
+ label="Image 0",
95
+ type="numpy",
96
+ image_mode="RGB",
97
+ height=300 if GRADIO_VERSION > "3" else None,
98
+ interactive=True,
99
+ )
100
+ input_image1 = gr.Image(
101
+ label="Image 1",
102
+ type="numpy",
103
+ image_mode="RGB",
104
+ height=300 if GRADIO_VERSION > "3" else None,
105
+ interactive=True,
106
+ )
107
+
108
+ with gr.Row():
109
+ button_reset = gr.Button(value="Reset")
110
+ button_run = gr.Button(value="Run Match", variant="primary")
111
+ with gr.Row():
112
+ button_stop = gr.Button(value="Force Stop", variant="stop")
113
+
114
+ with gr.Accordion("Advanced Setting", open=False):
115
+ with gr.Accordion("Image Setting", open=True):
116
+ with gr.Row():
117
+ image_force_resize_cb = gr.Checkbox(
118
+ label="Force Resize",
119
+ value=False,
120
+ interactive=True,
121
+ )
122
+ image_setting_height = gr.Slider(
123
+ minimum=48,
124
+ maximum=2048,
125
+ step=16,
126
+ label="Image Height",
127
+ value=480,
128
+ visible=False,
129
+ )
130
+ image_setting_width = gr.Slider(
131
+ minimum=64,
132
+ maximum=2048,
133
+ step=16,
134
+ label="Image Width",
135
+ value=640,
136
+ visible=False,
137
+ )
138
+ with gr.Accordion("Matching Setting", open=True):
139
+ with gr.Row():
140
+ match_setting_threshold = gr.Slider(
141
+ minimum=0.0,
142
+ maximum=1,
143
+ step=0.001,
144
+ label="Match threshold",
145
+ value=0.1,
146
+ )
147
+ match_setting_max_keypoints = gr.Slider(
148
+ minimum=10,
149
+ maximum=10000,
150
+ step=10,
151
+ label="Max features",
152
+ value=1000,
153
+ )
154
+ # TODO: add line settings
155
+ with gr.Row():
156
+ detect_keypoints_threshold = gr.Slider(
157
+ minimum=0,
158
+ maximum=1,
159
+ step=0.001,
160
+ label="Keypoint threshold",
161
+ value=0.015,
162
+ )
163
+ detect_line_threshold = ( # noqa: F841
164
+ gr.Slider(
165
+ minimum=0.1,
166
+ maximum=1,
167
+ step=0.01,
168
+ label="Line threshold",
169
+ value=0.2,
170
+ )
171
+ )
172
+
173
+ with gr.Accordion("RANSAC Setting", open=True):
174
+ with gr.Row(equal_height=False):
175
+ ransac_method = gr.Dropdown(
176
+ choices=ransac_zoo.keys(),
177
+ value=self.cfg["defaults"]["ransac_method"],
178
+ label="RANSAC Method",
179
+ interactive=True,
180
+ )
181
+ ransac_reproj_threshold = gr.Slider(
182
+ minimum=0.0,
183
+ maximum=12,
184
+ step=0.01,
185
+ label="Ransac Reproj threshold",
186
+ value=8.0,
187
+ )
188
+ ransac_confidence = gr.Slider(
189
+ minimum=0.0,
190
+ maximum=1,
191
+ step=0.00001,
192
+ label="Ransac Confidence",
193
+ value=self.cfg["defaults"]["ransac_confidence"],
194
+ )
195
+ ransac_max_iter = gr.Slider(
196
+ minimum=0.0,
197
+ maximum=100000,
198
+ step=100,
199
+ label="Ransac Iterations",
200
+ value=self.cfg["defaults"]["ransac_max_iter"],
201
+ )
202
+ button_ransac = gr.Button(
203
+ value="Rerun RANSAC", variant="primary"
204
+ )
205
+ with gr.Accordion("Geometry Setting", open=False):
206
+ with gr.Row(equal_height=False):
207
+ choice_geometry_type = gr.Radio(
208
+ ["Fundamental", "Homography"],
209
+ label="Reconstruct Geometry",
210
+ value=self.cfg["defaults"]["setting_geometry"],
211
+ )
212
+ # image resize
213
+ image_force_resize_cb.select(
214
+ fn=self._on_select_force_resize,
215
+ inputs=image_force_resize_cb,
216
+ outputs=[image_setting_width, image_setting_height],
217
+ )
218
+ # collect inputs
219
+ state_cache = gr.State({})
220
+ inputs = [
221
+ input_image0,
222
+ input_image1,
223
+ match_setting_threshold,
224
+ match_setting_max_keypoints,
225
+ detect_keypoints_threshold,
226
+ matcher_list,
227
+ ransac_method,
228
+ ransac_reproj_threshold,
229
+ ransac_confidence,
230
+ ransac_max_iter,
231
+ choice_geometry_type,
232
+ gr.State(self.matcher_zoo),
233
+ image_force_resize_cb,
234
+ image_setting_width,
235
+ image_setting_height,
236
+ ]
237
+
238
+ # Add some examples
239
+ with gr.Row():
240
+ # Example inputs
241
+ with gr.Accordion("Open for More: Examples", open=True):
242
+ gr.Examples(
243
+ examples=gen_examples(self.example_data_root),
244
+ inputs=inputs,
245
+ outputs=[],
246
+ fn=run_matching,
247
+ cache_examples=False,
248
+ label=(
249
+ "Examples (click one of the images below to Run"
250
+ " Match). Thx: WxBS"
251
+ ),
252
+ )
253
+ with gr.Accordion("Supported Algorithms", open=False):
254
+ # add a table of supported algorithms
255
+ self.display_supported_algorithms()
256
+
257
+ with gr.Column():
258
+ with gr.Accordion("Open for More: Keypoints", open=True):
259
+ output_keypoints = gr.Image(label="Keypoints", type="numpy")
260
+ with gr.Accordion(
261
+ (
262
+ "Open for More: Raw Matches"
263
+ " (Green for good matches, Red for bad)"
264
+ ),
265
+ open=False,
266
+ ):
267
+ output_matches_raw = gr.Image(
268
+ label="Raw Matches",
269
+ type="numpy",
270
+ )
271
+ with gr.Accordion(
272
+ (
273
+ "Open for More: Ransac Matches"
274
+ " (Green for good matches, Red for bad)"
275
+ ),
276
+ open=True,
277
+ ):
278
+ output_matches_ransac = gr.Image(
279
+ label="Ransac Matches", type="numpy"
280
+ )
281
+ with gr.Accordion(
282
+ "Open for More: Matches Statistics", open=False
283
+ ):
284
+ output_pred = gr.File(label="Outputs", elem_id="download")
285
+ matches_result_info = gr.JSON(label="Matches Statistics")
286
+ matcher_info = gr.JSON(label="Match info")
287
+
288
+ with gr.Accordion("Open for More: Warped Image", open=True):
289
+ output_wrapped = gr.Image(
290
+ label="Warped Pair", type="numpy"
291
+ )
292
+ # send to input
293
+ button_rerun = gr.Button(
294
+ value="Send to Input Match Pair",
295
+ variant="primary",
296
+ )
297
+ with gr.Accordion(
298
+ "Open for More: Geometry info", open=False
299
+ ):
300
+ geometry_result = gr.JSON(
301
+ label="Reconstructed Geometry"
302
+ )
303
+
304
+ # callbacks
305
+ match_image_src.change(
306
+ fn=self.ui_change_imagebox,
307
+ inputs=match_image_src,
308
+ outputs=input_image0,
309
+ )
310
+ match_image_src.change(
311
+ fn=self.ui_change_imagebox,
312
+ inputs=match_image_src,
313
+ outputs=input_image1,
314
+ )
315
+ # collect outputs
316
+ outputs = [
317
+ output_keypoints,
318
+ output_matches_raw,
319
+ output_matches_ransac,
320
+ matches_result_info,
321
+ matcher_info,
322
+ geometry_result,
323
+ output_wrapped,
324
+ state_cache,
325
+ output_pred,
326
+ ]
327
+ # button callbacks
328
+ click_event = button_run.click(
329
+ fn=run_matching, inputs=inputs, outputs=outputs
330
+ )
331
+ # stop button
332
+ button_stop.click(
333
+ fn=None, inputs=None, outputs=None, cancels=[click_event]
334
+ )
335
+
336
+ # Reset images
337
+ reset_outputs = [
338
+ input_image0,
339
+ input_image1,
340
+ match_setting_threshold,
341
+ match_setting_max_keypoints,
342
+ detect_keypoints_threshold,
343
+ matcher_list,
344
+ input_image0,
345
+ input_image1,
346
+ match_image_src,
347
+ output_keypoints,
348
+ output_matches_raw,
349
+ output_matches_ransac,
350
+ matches_result_info,
351
+ matcher_info,
352
+ output_wrapped,
353
+ geometry_result,
354
+ ransac_method,
355
+ ransac_reproj_threshold,
356
+ ransac_confidence,
357
+ ransac_max_iter,
358
+ choice_geometry_type,
359
+ output_pred,
360
+ image_force_resize_cb,
361
+ ]
362
+ button_reset.click(
363
+ fn=self.ui_reset_state,
364
+ inputs=None,
365
+ outputs=reset_outputs,
366
+ )
367
+
368
+ # run ransac button action
369
+ button_ransac.click(
370
+ fn=run_ransac,
371
+ inputs=[
372
+ state_cache,
373
+ choice_geometry_type,
374
+ ransac_method,
375
+ ransac_reproj_threshold,
376
+ ransac_confidence,
377
+ ransac_max_iter,
378
+ ],
379
+ outputs=[
380
+ output_matches_ransac,
381
+ matches_result_info,
382
+ output_wrapped,
383
+ output_pred,
384
+ ],
385
+ )
386
+
387
+ # send warped image to match
388
+ button_rerun.click(
389
+ fn=send_to_match,
390
+ inputs=[state_cache],
391
+ outputs=[input_image0, input_image1],
392
+ )
393
+
394
+ # estimate geo
395
+ choice_geometry_type.change(
396
+ fn=generate_warp_images,
397
+ inputs=[
398
+ input_image0,
399
+ input_image1,
400
+ geometry_result,
401
+ choice_geometry_type,
402
+ ],
403
+ outputs=[output_wrapped, geometry_result],
404
+ )
405
+ with gr.Tab("Structure from Motion (under-dev)"):
406
+ sfm_ui = AppSfmUI( # noqa: F841
407
+ {
408
+ **self.cfg,
409
+ "matcher_zoo": self.matcher_zoo,
410
+ "outputs": "experiments/sfm",
411
+ }
412
+ )
413
+ sfm_ui.call_empty()
414
+
415
+ def run(self):
416
+ self.app.queue().launch(
417
+ server_name=self.server_name,
418
+ server_port=self.server_port,
419
+ share=False,
420
+ allowed_paths=[
421
+ str(Path(__file__).parents[0]),
422
+ str(Path(__file__).parents[1]),
423
+ ],
424
+ )
425
+
426
+ def ui_change_imagebox(self, choice):
427
+ """
428
+ Updates the image box with the given choice.
429
+
430
+ Args:
431
+ choice (list): The list of image sources to be displayed in the image box.
432
+
433
+ Returns:
434
+ dict: A dictionary containing the updated value, sources, and type for the image box.
435
+ """
436
+ ret_dict = {
437
+ "value": None, # The updated value of the image box
438
+ "__type__": "update", # The type of update for the image box
439
+ }
440
+ if GRADIO_VERSION > "3":
441
+ return {
442
+ **ret_dict,
443
+ "sources": choice, # The list of image sources to be displayed
444
+ }
445
+ else:
446
+ return {
447
+ **ret_dict,
448
+ "source": choice, # The list of image sources to be displayed
449
+ }
450
+
451
+ def _on_select_force_resize(self, visible: bool = False):
452
+ return gr.update(visible=visible), gr.update(visible=visible)
453
+
454
+ def ui_reset_state(
455
+ self,
456
+ *args: Any,
457
+ ) -> Tuple[
458
+ Optional[np.ndarray],
459
+ Optional[np.ndarray],
460
+ float,
461
+ int,
462
+ float,
463
+ str,
464
+ Dict[str, Any],
465
+ Dict[str, Any],
466
+ str,
467
+ Optional[np.ndarray],
468
+ Optional[np.ndarray],
469
+ Optional[np.ndarray],
470
+ Dict[str, Any],
471
+ Dict[str, Any],
472
+ Optional[np.ndarray],
473
+ Dict[str, Any],
474
+ str,
475
+ int,
476
+ float,
477
+ int,
478
+ bool,
479
+ ]:
480
+ """
481
+ Reset the state of the UI.
482
+
483
+ Returns:
484
+ tuple: A tuple containing the initial values for the UI state.
485
+ """
486
+ key: str = list(self.matcher_zoo.keys())[
487
+ 0
488
+ ] # Get the first key from matcher_zoo
489
+ # flush_logs()
490
+ return (
491
+ None, # image0: Optional[np.ndarray]
492
+ None, # image1: Optional[np.ndarray]
493
+ self.cfg["defaults"]["match_threshold"], # matching_threshold: float
494
+ self.cfg["defaults"]["max_keypoints"], # max_keypoints: int
495
+ self.cfg["defaults"]["keypoint_threshold"], # keypoint_threshold: float
496
+ key, # matcher: str
497
+ self.ui_change_imagebox("upload"), # input image0: Dict[str, Any]
498
+ self.ui_change_imagebox("upload"), # input image1: Dict[str, Any]
499
+ "upload", # match_image_src: str
500
+ None, # keypoints: Optional[np.ndarray]
501
+ None, # raw matches: Optional[np.ndarray]
502
+ None, # ransac matches: Optional[np.ndarray]
503
+ {}, # matches result info: Dict[str, Any]
504
+ {}, # matcher config: Dict[str, Any]
505
+ None, # warped image: Optional[np.ndarray]
506
+ {}, # geometry result: Dict[str, Any]
507
+ self.cfg["defaults"]["ransac_method"], # ransac_method: str
508
+ self.cfg["defaults"][
509
+ "ransac_reproj_threshold"
510
+ ], # ransac_reproj_threshold: float
511
+ self.cfg["defaults"]["ransac_confidence"], # ransac_confidence: float
512
+ self.cfg["defaults"]["ransac_max_iter"], # ransac_max_iter: int
513
+ self.cfg["defaults"]["setting_geometry"], # geometry: str
514
+ None, # predictions
515
+ False,
516
+ )
517
+
518
+ def display_supported_algorithms(self, style="tab"):
519
+ def get_link(link, tag="Link"):
520
+ return "[{}]({})".format(tag, link) if link is not None else "None"
521
+
522
+ data = []
523
+ cfg = self.cfg["matcher_zoo"]
524
+ if style == "md":
525
+ markdown_table = "| Algo. | Conference | Code | Project | Paper |\n"
526
+ markdown_table += "| ----- | ---------- | ---- | ------- | ----- |\n"
527
+
528
+ for _, v in cfg.items():
529
+ if not v["info"].get("display", True):
530
+ continue
531
+ github_link = get_link(v["info"].get("github", ""))
532
+ project_link = get_link(v["info"].get("project", ""))
533
+ paper_link = get_link(
534
+ v["info"]["paper"],
535
+ (
536
+ Path(v["info"]["paper"]).name[-10:]
537
+ if v["info"]["paper"] is not None
538
+ else "Link"
539
+ ),
540
+ )
541
+
542
+ markdown_table += "{}|{}|{}|{}|{}\n".format(
543
+ v["info"].get("name", ""),
544
+ v["info"].get("source", ""),
545
+ github_link,
546
+ project_link,
547
+ paper_link,
548
+ )
549
+ return gr.Markdown(markdown_table)
550
+ elif style == "tab":
551
+ for k, v in cfg.items():
552
+ if not v["info"].get("display", True):
553
+ continue
554
+ data.append(
555
+ [
556
+ v["info"].get("name", ""),
557
+ v["info"].get("source", ""),
558
+ v["info"].get("github", ""),
559
+ v["info"].get("paper", ""),
560
+ v["info"].get("project", ""),
561
+ ]
562
+ )
563
+ tab = gr.Dataframe(
564
+ headers=["Algo.", "Conference", "Code", "Paper", "Project"],
565
+ datatype=["str", "str", "str", "str", "str"],
566
+ col_count=(5, "fixed"),
567
+ value=data,
568
+ # wrap=True,
569
+ # min_width = 1000,
570
+ # height=1000,
571
+ )
572
+ return tab
573
+
574
+
575
+ class AppBaseUI:
576
+ def __init__(self, cfg: Dict[str, Any] = {}):
577
+ self.cfg = OmegaConf.create(cfg)
578
+ self.inputs = edict({})
579
+ self.outputs = edict({})
580
+ self.ui = edict({})
581
+
582
+ def _init_ui(self):
583
+ raise NotImplementedError
584
+
585
+ def call(self, **kwargs):
586
+ raise NotImplementedError
587
+
588
+ def info(self):
589
+ gr.Info("SFM is under construction.")
590
+
591
+
592
+ class AppSfmUI(AppBaseUI):
593
+ def __init__(self, cfg: Dict[str, Any] = None):
594
+ super().__init__(cfg)
595
+ assert "matcher_zoo" in self.cfg
596
+ self.matcher_zoo = self.cfg["matcher_zoo"]
597
+ self.sfm_engine = SfmEngine(cfg)
598
+ self._init_ui()
599
+
600
+ def init_retrieval_dropdown(self):
601
+ algos = []
602
+ for k, v in self.cfg["retrieval_zoo"].items():
603
+ if v.get("enable", True):
604
+ algos.append(k)
605
+ return algos
606
+
607
+ def _update_options(self, option):
608
+ if option == "sparse":
609
+ return gr.Textbox("sparse", visible=True)
610
+ elif option == "dense":
611
+ return gr.Textbox("dense", visible=True)
612
+ else:
613
+ return gr.Textbox("not set", visible=True)
614
+
615
+ def _on_select_custom_params(self, value: bool = False):
616
+ return gr.update(visible=value)
617
+
618
+ def _init_ui(self):
619
+ with gr.Row():
620
+ # data setting and camera settings
621
+ with gr.Column():
622
+ self.inputs.input_images = gr.File(
623
+ label="SfM",
624
+ interactive=True,
625
+ file_count="multiple",
626
+ min_width=300,
627
+ )
628
+ # camera setting
629
+ with gr.Accordion("Camera Settings", open=True):
630
+ with gr.Column():
631
+ with gr.Row():
632
+ with gr.Column():
633
+ self.inputs.camera_model = gr.Dropdown(
634
+ choices=[
635
+ "PINHOLE",
636
+ "SIMPLE_RADIAL",
637
+ "OPENCV",
638
+ ],
639
+ value="PINHOLE",
640
+ label="Camera Model",
641
+ interactive=True,
642
+ )
643
+ with gr.Column():
644
+ gr.Checkbox(
645
+ label="Shared Params",
646
+ value=True,
647
+ interactive=True,
648
+ )
649
+ camera_custom_params_cb = gr.Checkbox(
650
+ label="Custom Params",
651
+ value=False,
652
+ interactive=True,
653
+ )
654
+ with gr.Row():
655
+ self.inputs.camera_params = gr.Textbox(
656
+ label="Camera Params",
657
+ value="0,0,0,0",
658
+ interactive=False,
659
+ visible=False,
660
+ )
661
+ camera_custom_params_cb.select(
662
+ fn=self._on_select_custom_params,
663
+ inputs=camera_custom_params_cb,
664
+ outputs=self.inputs.camera_params,
665
+ )
666
+
667
+ with gr.Accordion("Matching Settings", open=True):
668
+ # feature extraction and matching setting
669
+ with gr.Row():
670
+ # matcher setting
671
+ self.inputs.matcher_key = gr.Dropdown(
672
+ choices=self.matcher_zoo.keys(),
673
+ value="disk+lightglue",
674
+ label="Matching Model",
675
+ interactive=True,
676
+ )
677
+ with gr.Row():
678
+ with gr.Accordion("Advanced Settings", open=False):
679
+ with gr.Column():
680
+ with gr.Row():
681
+ # matching setting
682
+ self.inputs.max_keypoints = gr.Slider(
683
+ label="Max Keypoints",
684
+ minimum=100,
685
+ maximum=10000,
686
+ value=1000,
687
+ interactive=True,
688
+ )
689
+ self.inputs.keypoint_threshold = gr.Slider(
690
+ label="Keypoint Threshold",
691
+ minimum=0,
692
+ maximum=1,
693
+ value=0.01,
694
+ )
695
+ with gr.Row():
696
+ self.inputs.match_threshold = gr.Slider(
697
+ label="Match Threshold",
698
+ minimum=0.01,
699
+ maximum=12.0,
700
+ value=0.2,
701
+ )
702
+ self.inputs.ransac_threshold = gr.Slider(
703
+ label="Ransac Threshold",
704
+ minimum=0.01,
705
+ maximum=12.0,
706
+ value=4.0,
707
+ step=0.01,
708
+ interactive=True,
709
+ )
710
+
711
+ with gr.Row():
712
+ self.inputs.ransac_confidence = gr.Slider(
713
+ label="Ransac Confidence",
714
+ minimum=0.01,
715
+ maximum=1.0,
716
+ value=0.9999,
717
+ step=0.0001,
718
+ interactive=True,
719
+ )
720
+ self.inputs.ransac_max_iter = gr.Slider(
721
+ label="Ransac Max Iter",
722
+ minimum=1,
723
+ maximum=100,
724
+ value=100,
725
+ step=1,
726
+ interactive=True,
727
+ )
728
+ with gr.Accordion("Scene Graph Settings", open=True):
729
+ # mapping setting
730
+ self.inputs.scene_graph = gr.Dropdown(
731
+ choices=["all", "swin", "oneref"],
732
+ value="all",
733
+ label="Scene Graph",
734
+ interactive=True,
735
+ )
736
+
737
+ # global feature setting
738
+ self.inputs.global_feature = gr.Dropdown(
739
+ choices=self.init_retrieval_dropdown(),
740
+ value="netvlad",
741
+ label="Global features",
742
+ interactive=True,
743
+ )
744
+ self.inputs.top_k = gr.Slider(
745
+ label="Number of Images per Image to Match",
746
+ minimum=1,
747
+ maximum=100,
748
+ value=10,
749
+ step=1,
750
+ )
751
+ # button_match = gr.Button("Run Matching", variant="primary")
752
+
753
+ # mapping setting
754
+ with gr.Column():
755
+ with gr.Accordion("Mapping Settings", open=True):
756
+ with gr.Row():
757
+ with gr.Accordion("Bundle Settings", open=True):
758
+ with gr.Row():
759
+ self.inputs.mapper_refine_focal_length = gr.Checkbox(
760
+ label="Refine Focal Length",
761
+ value=False,
762
+ interactive=True,
763
+ )
764
+ self.inputs.mapper_refine_principle_points = (
765
+ gr.Checkbox(
766
+ label="Refine Principal Points",
767
+ value=False,
768
+ interactive=True,
769
+ )
770
+ )
771
+ self.inputs.mapper_refine_extra_params = gr.Checkbox(
772
+ label="Refine Extra Params",
773
+ value=False,
774
+ interactive=True,
775
+ )
776
+ with gr.Accordion("Retriangulation Settings", open=True):
777
+ gr.Textbox(
778
+ label="Retriangulation Details",
779
+ )
780
+ self.ui.button_sfm = gr.Button("Run SFM", variant="primary")
781
+ self.outputs.model_3d = gr.Model3D(
782
+ interactive=True,
783
+ )
784
+ self.outputs.output_image = gr.Image(
785
+ label="SFM Visualize",
786
+ type="numpy",
787
+ image_mode="RGB",
788
+ interactive=False,
789
+ )
790
+
791
+ def call_empty(self):
792
+ self.ui.button_sfm.click(fn=self.info, inputs=[], outputs=[])
793
+
794
+ def call(self):
795
+ self.ui.button_sfm.click(
796
+ fn=self.sfm_engine.call,
797
+ inputs=[
798
+ self.inputs.matcher_key,
799
+ self.inputs.input_images, # images
800
+ self.inputs.camera_model,
801
+ self.inputs.camera_params,
802
+ self.inputs.max_keypoints,
803
+ self.inputs.keypoint_threshold,
804
+ self.inputs.match_threshold,
805
+ self.inputs.ransac_threshold,
806
+ self.inputs.ransac_confidence,
807
+ self.inputs.ransac_max_iter,
808
+ self.inputs.scene_graph,
809
+ self.inputs.global_feature,
810
+ self.inputs.top_k,
811
+ self.inputs.mapper_refine_focal_length,
812
+ self.inputs.mapper_refine_principle_points,
813
+ self.inputs.mapper_refine_extra_params,
814
+ ],
815
+ outputs=[self.outputs.model_3d, self.outputs.output_image],
816
+ )
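
The Gradio wiring above follows a single pattern: build components, collect them into inputs/outputs lists, and bind buttons with .click(), keeping the returned event so a stop button can cancel it. Below is a minimal, self-contained sketch of that pattern (not part of this diff; the function and component names are made up), assuming only that gradio is installed:

import gradio as gr

def run_match(threshold):
    # placeholder for the real matching pipeline
    return f"matched with threshold={threshold}"

with gr.Blocks() as demo:
    threshold = gr.Slider(0.0, 1.0, value=0.1, label="Match threshold")
    result = gr.Textbox(label="Result")
    with gr.Row():
        button_reset = gr.Button("Reset")
        button_run = gr.Button("Run Match", variant="primary")
        button_stop = gr.Button("Force Stop", variant="stop")

    # keep the click event so the stop button can cancel a running job
    click_event = button_run.click(fn=run_match, inputs=threshold, outputs=result)
    button_stop.click(fn=None, inputs=None, outputs=None, cancels=[click_event])
    # reset restores default values for every component listed in outputs
    button_reset.click(fn=lambda: (0.1, ""), inputs=None, outputs=[threshold, result])

if __name__ == "__main__":
    demo.queue().launch()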
 
 
 
 
imcui/ui/modelcache.py ADDED
@@ -0,0 +1,371 @@
1
+ import hashlib
2
+ import json
3
+ import time
4
+ import threading
5
+ from collections import OrderedDict
6
+ import torch
7
+ from ..hloc import logger
8
+
9
+
10
+ class ARCSizeAwareModelCache:
11
+ def __init__(
12
+ self,
13
+ max_gpu_mem: float = 8e9,
14
+ max_cpu_mem: float = 12e9,
15
+ device_priority: list = ["cuda", "cpu"],
16
+ auto_empty_cache: bool = True,
17
+ ):
18
+ """
19
+ Initialize the model cache.
20
+
21
+ Args:
22
+ max_gpu_mem: Maximum GPU memory allowed in bytes.
23
+ max_cpu_mem: Maximum CPU memory allowed in bytes.
24
+ device_priority: List of devices to prioritize when evicting models.
25
+ auto_empty_cache: Whether to call torch.cuda.empty_cache() when out of memory.
26
+ """
27
+
28
+ self.t1 = OrderedDict()
29
+ self.t2 = OrderedDict()
30
+ self.b1 = OrderedDict()
31
+ self.b2 = OrderedDict()
32
+
33
+ self.max_gpu = max_gpu_mem
34
+ self.max_cpu = max_cpu_mem
35
+ self.current_gpu = 0
36
+ self.current_cpu = 0
37
+
38
+ self.p = 0
39
+ self.adaptive_factor = 0.5
40
+
41
+ self.device_priority = device_priority
42
+ self.lock = threading.Lock()
43
+ self.auto_empty_cache = auto_empty_cache
44
+
45
+ logger.info("ARCSizeAwareModelCache initialized.")
46
+
47
+ def _release_model(self, model_entry):
48
+ """
49
+ Release a model from memory.
50
+
51
+ Args:
52
+ model_entry: A dictionary containing the model, device and other information.
53
+
54
+ Notes:
55
+ If the device is CUDA and auto_empty_cache is True, torch.cuda.empty_cache() is called after releasing the model.
56
+ """
57
+ model = model_entry["model"]
58
+ device = model_entry["device"]
59
+
60
+ del model
61
+ if device == "cuda":
62
+ torch.cuda.synchronize()
63
+ if self.auto_empty_cache:
64
+ torch.cuda.empty_cache()
65
+
66
+ def generate_key(self, model_key, model_conf: dict) -> str:
67
+ loader_identifier = f"{model_key}"
68
+ unique_str = f"{loader_identifier}-{json.dumps(model_conf, sort_keys=True)}"
69
+ return hashlib.sha256(unique_str.encode()).hexdigest()
70
+
71
+ def _get_device(self, model_size: int) -> str:
72
+ for device in self.device_priority:
73
+ if device == "cuda" and torch.cuda.is_available():
74
+ if self.current_gpu + model_size <= self.max_gpu:
75
+ return "cuda"
76
+ elif device == "cpu":
77
+ if self.current_cpu + model_size <= self.max_cpu:
78
+ return "cpu"
79
+ return "cpu"
80
+
81
+ def _calculate_model_size(self, model):
82
+ return sum(p.numel() * p.element_size() for p in model.parameters()) + sum(
83
+ b.numel() * b.element_size() for b in model.buffers()
84
+ )
85
+
86
+ def _update_access(self, key: str, size: int, device: str):
87
+ if key in self.b1:
88
+ self.p = min(
89
+ self.p + max(1, len(self.b2) // len(self.b1)),
90
+ len(self.t1) + len(self.t2),
91
+ )
92
+ self.b1.pop(key)
93
+ self._replace(False)
94
+ elif key in self.b2:
95
+ self.p = max(self.p - max(1, len(self.b1) // len(self.b2)), 0)
96
+ self.b2.pop(key)
97
+ self._replace(True)
98
+
99
+ if key in self.t1:
100
+ self.t1.pop(key)
101
+ self.t2[key] = {
102
+ "size": size,
103
+ "device": device,
104
+ "access_count": 1,
105
+ "last_accessed": time.time(),
106
+ }
107
+
108
+ def _replace(self, in_t2: bool):
109
+ if len(self.t1) > 0 and (
110
+ (len(self.t1) > self.p) or (in_t2 and len(self.t1) == self.p)
111
+ ):
112
+ k, v = self.t1.popitem(last=False)
113
+ self.b1[k] = v
114
+ else:
115
+ k, v = self.t2.popitem(last=False)
116
+ self.b2[k] = v
117
+
118
+ def _calculate_weight(self, entry) -> float:
119
+ return entry["access_count"] / entry["size"]
120
+
121
+ def _evict_models(self, required_size: int, target_device: str) -> bool:
122
+ candidates = []
123
+ for k, v in list(self.t1.items()) + list(self.t2.items()):
124
+ if v["device"] == target_device:
125
+ candidates.append((k, v))
126
+
127
+ candidates.sort(key=lambda x: self._calculate_weight(x[1]))
128
+
129
+ freed = 0
130
+ for k, v in candidates:
131
+ self._release_model(v)
132
+ freed += v["size"]
133
+ if k in self.t1:
134
+ self.t1.pop(k)
135
+ if k in self.t2:
136
+ self.t2.pop(k)
137
+
138
+ if v["device"] == "cuda":
139
+ self.current_gpu -= v["size"]
140
+ else:
141
+ self.current_cpu -= v["size"]
142
+
143
+ if freed >= required_size:
144
+ return True
145
+
146
+ if target_device == "cuda":
147
+ return self._cross_device_evict(required_size, "cuda")
148
+ return False
149
+
150
+ def _cross_device_evict(self, required_size: int, target_device: str) -> bool:
151
+ all_entries = []
152
+ for k, v in list(self.t1.items()) + list(self.t2.items()):
153
+ all_entries.append((k, v))
154
+
155
+ all_entries.sort(
156
+ key=lambda x: self._calculate_weight(x[1])
157
+ + (0.5 if x[1]["device"] == target_device else 0)
158
+ )
159
+
160
+ freed = 0
161
+ for k, v in all_entries:
162
+ freed += v["size"]
163
+ if k in self.t1:
164
+ self.t1.pop(k)
165
+ if k in self.t2:
166
+ self.t2.pop(k)
167
+
168
+ if v["device"] == "cuda":
169
+ self.current_gpu -= v["size"]
170
+ else:
171
+ self.current_cpu -= v["size"]
172
+
173
+ if freed >= required_size:
174
+ return True
175
+ return False
176
+
177
+ def load_model(self, model_key, model_loader_func, model_conf: dict):
178
+ key = self.generate_key(model_key, model_conf)
179
+
180
+ with self.lock:
181
+ if key in self.t1 or key in self.t2:
182
+ entry = self.t1.pop(key, None) or self.t2.pop(key)
183
+ entry["access_count"] += 1
184
+ self.t2[key] = entry
185
+ return entry["model"]
186
+
187
+ raw_model = model_loader_func(model_conf)
188
+ model_size = self._calculate_model_size(raw_model)
189
+ device = self._get_device(model_size)
190
+
191
+ if device == "cuda" and self.auto_empty_cache:
192
+ torch.cuda.empty_cache()
193
+ torch.cuda.synchronize()
194
+
195
+ while True:
196
+ current_mem = self.current_gpu if device == "cuda" else self.current_cpu
197
+ max_mem = self.max_gpu if device == "cuda" else self.max_cpu
198
+
199
+ if current_mem + model_size <= max_mem:
200
+ break
201
+
202
+ if not self._evict_models(model_size, device):
203
+ if device == "cuda":
204
+ device = "cpu"
205
+ else:
206
+ raise RuntimeError("Out of memory")
207
+
208
+ try:
209
+ model = raw_model.to(device)
210
+ except RuntimeError as e:
211
+ if "CUDA out of memory" in str(e):
212
+ torch.cuda.empty_cache()
213
+ model = raw_model.to(device)
+ else:
+ # not an OOM error: re-raise instead of leaving `model` unbound
+ raise
214
+
215
+ new_entry = {
216
+ "model": model,
217
+ "size": model_size,
218
+ "device": device,
219
+ "access_count": 1,
220
+ "last_accessed": time.time(),
221
+ }
222
+
223
+ if key in self.b1 or key in self.b2:
224
+ self.t2[key] = new_entry
225
+ self._replace(True)
226
+ else:
227
+ self.t1[key] = new_entry
228
+ self._replace(False)
229
+
230
+ if device == "cuda":
231
+ self.current_gpu += model_size
232
+ else:
233
+ self.current_cpu += model_size
234
+
235
+ return model
236
+
237
+ def clear_device_cache(self, device: str):
238
+ with self.lock:
239
+ for cache in [self.t1, self.t2, self.b1, self.b2]:
240
+ for k in list(cache.keys()):
241
+ if cache[k]["device"] == device:
242
+ cache.pop(k)
243
+
244
+
245
+ class LRUModelCache:
246
+ def __init__(
247
+ self,
248
+ max_gpu_mem: float = 8e9,
249
+ max_cpu_mem: float = 12e9,
250
+ device_priority: list = ["cuda", "cpu"],
251
+ ):
252
+ self.cache = OrderedDict()
253
+ self.max_gpu = max_gpu_mem
254
+ self.max_cpu = max_cpu_mem
255
+ self.current_gpu = 0
256
+ self.current_cpu = 0
257
+ self.lock = threading.Lock()
258
+ self.device_priority = device_priority
259
+
260
+ def generate_key(self, model_key, model_conf: dict) -> str:
261
+ loader_identifier = f"{model_key}"
262
+ unique_str = f"{loader_identifier}-{json.dumps(model_conf, sort_keys=True)}"
263
+ return hashlib.sha256(unique_str.encode()).hexdigest()
264
+
265
+ def get_device(self) -> str:
266
+ for device in self.device_priority:
267
+ if device == "cuda" and torch.cuda.is_available():
268
+ if self.current_gpu < self.max_gpu:
269
+ return device
270
+ elif device == "cpu":
271
+ if self.current_cpu < self.max_cpu:
272
+ return device
273
+ return "cpu"
274
+
275
+ def _calculate_model_size(self, model):
276
+ param_size = sum(p.numel() * p.element_size() for p in model.parameters())
277
+ buffer_size = sum(b.numel() * b.element_size() for b in model.buffers())
278
+ return param_size + buffer_size
279
+
280
+ def load_model(self, model_key, model_loader_func, model_conf: dict):
281
+ key = self.generate_key(model_key, model_conf)
282
+
283
+ with self.lock:
284
+ if key in self.cache:
285
+ self.cache.move_to_end(key) # update LRU
286
+ return self.cache[key]["model"]
287
+
288
+ device = self.get_device()
289
+ if device == "cuda":
290
+ torch.cuda.empty_cache()
291
+
292
+ try:
293
+ raw_model = model_loader_func(model_conf)
294
+ except Exception as e:
295
+ raise RuntimeError(f"Model loading failed: {str(e)}")
296
+
297
+ try:
298
+ model = raw_model.to(device)
299
+ except RuntimeError as e:
300
+ if "CUDA out of memory" in str(e):
301
+ return self._handle_oom(model_key, model_loader_func, model_conf)
302
+ raise
303
+
304
+ model_size = self._calculate_model_size(model)
305
+
306
+ while (
307
+ device == "cuda" and (self.current_gpu + model_size > self.max_gpu)
308
+ ) or (device == "cpu" and (self.current_cpu + model_size > self.max_cpu)):
309
+ if not self._free_space(model_size, device):
310
+ raise RuntimeError("Insufficient memory even after cache cleanup")
311
+
312
+ if device == "cuda":
313
+ self.current_gpu += model_size
314
+ else:
315
+ self.current_cpu += model_size
316
+
317
+ self.cache[key] = {
318
+ "model": model,
319
+ "size": model_size,
320
+ "device": device,
321
+ "timestamp": time.time(),
322
+ }
323
+
324
+ return model
325
+
326
+ def _free_space(self, required_size: int, device: str) -> bool:
327
+ for key in list(self.cache.keys()):
328
+ if (device == "cuda" and self.cache[key]["device"] == "cuda") or (
329
+ device == "cpu" and self.cache[key]["device"] == "cpu"
330
+ ):
331
+ self.current_gpu -= (
332
+ self.cache[key]["size"]
333
+ if self.cache[key]["device"] == "cuda"
334
+ else 0
335
+ )
336
+ self.current_cpu -= (
337
+ self.cache[key]["size"] if self.cache[key]["device"] == "cpu" else 0
338
+ )
339
+ del self.cache[key]
340
+
341
+ if (
342
+ device == "cuda"
343
+ and self.current_gpu + required_size <= self.max_gpu
344
+ ) or (
345
+ device == "cpu" and self.current_cpu + required_size <= self.max_cpu
346
+ ):
347
+ return True
348
+ return False
349
+
350
+ def _handle_oom(self, model_key, model_loader_func, model_conf: dict):
351
+ with self.lock:
352
+ self.clear_device_cache("cuda")
353
+ torch.cuda.empty_cache()
354
+
355
+ try:
356
+ return self.load_model(model_key, model_loader_func, model_conf)
357
+ except RuntimeError:
358
+ original_priority = self.device_priority
359
+ self.device_priority = ["cpu"]
360
+ try:
361
+ return self.load_model(model_key, model_loader_func, model_conf)
362
+ finally:
363
+ self.device_priority = original_priority
364
+
365
+ def clear_device_cache(self, device: str):
366
+ with self.lock:
367
+ keys_to_remove = [k for k, v in self.cache.items() if v["device"] == device]
368
+ for k in keys_to_remove:
369
+ self.current_gpu -= self.cache[k]["size"] if device == "cuda" else 0
370
+ self.current_cpu -= self.cache[k]["size"] if device == "cpu" else 0
371
+ del self.cache[k]
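
Both caches added here expose the same entry point, load_model(model_key, model_loader_func, model_conf): the key and config are hashed into a cache key, the loader is only invoked on a miss, and the returned module is moved to a device within the configured memory budgets. A hypothetical usage sketch (not part of this diff; the loader callable and config below are invented, and only the simpler LRUModelCache is exercised):

import torch

from imcui.ui.modelcache import LRUModelCache

def build_tiny_model(conf: dict) -> torch.nn.Module:
    # stand-in loader; a real loader would build a feature extractor or matcher
    return torch.nn.Linear(conf["in_features"], conf["out_features"])

cache = LRUModelCache(max_gpu_mem=4e9, max_cpu_mem=8e9)
conf = {"in_features": 64, "out_features": 32}

model_a = cache.load_model("tiny_linear", build_tiny_model, conf)
model_b = cache.load_model("tiny_linear", build_tiny_model, conf)  # second call hits the cache
assert model_a is model_b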
imcui/ui/sfm.py CHANGED
@@ -1,164 +1,164 @@
1
- import shutil
2
- import tempfile
3
- from pathlib import Path
4
- from typing import Any, Dict, List
5
-
6
-
7
- from ..hloc import (
8
- extract_features,
9
- logger,
10
- match_features,
11
- pairs_from_retrieval,
12
- reconstruction,
13
- visualization,
14
- )
15
-
16
- try:
17
- import pycolmap
18
- except ImportError:
19
- logger.warning("pycolmap not installed, some features may not work")
20
-
21
- from .viz import fig2im
22
-
23
-
24
- class SfmEngine:
25
- def __init__(self, cfg: Dict[str, Any] = None):
26
- self.cfg = cfg
27
- if "outputs" in cfg and Path(cfg["outputs"]):
28
- outputs = Path(cfg["outputs"])
29
- outputs.mkdir(parents=True, exist_ok=True)
30
- else:
31
- outputs = tempfile.mkdtemp()
32
- self.outputs = Path(outputs)
33
-
34
- def call(
35
- self,
36
- key: str,
37
- images: Path,
38
- camera_model: str,
39
- camera_params: List[float],
40
- max_keypoints: int,
41
- keypoint_threshold: float,
42
- match_threshold: float,
43
- ransac_threshold: int,
44
- ransac_confidence: float,
45
- ransac_max_iter: int,
46
- scene_graph: bool,
47
- global_feature: str,
48
- top_k: int = 10,
49
- mapper_refine_focal_length: bool = False,
50
- mapper_refine_principle_points: bool = False,
51
- mapper_refine_extra_params: bool = False,
52
- ):
53
- """
54
- Call a list of functions to perform feature extraction, matching, and reconstruction.
55
-
56
- Args:
57
- key (str): The key to retrieve the matcher and feature models.
58
- images (Path): The directory containing the images.
59
- outputs (Path): The directory to store the outputs.
60
- camera_model (str): The camera model.
61
- camera_params (List[float]): The camera parameters.
62
- max_keypoints (int): The maximum number of features.
63
- match_threshold (float): The match threshold.
64
- ransac_threshold (int): The RANSAC threshold.
65
- ransac_confidence (float): The RANSAC confidence.
66
- ransac_max_iter (int): The maximum number of RANSAC iterations.
67
- scene_graph (bool): Whether to compute the scene graph.
68
- global_feature (str): Whether to compute the global feature.
69
- top_k (int): The number of image-pair to use.
70
- mapper_refine_focal_length (bool): Whether to refine the focal length.
71
- mapper_refine_principle_points (bool): Whether to refine the principle points.
72
- mapper_refine_extra_params (bool): Whether to refine the extra parameters.
73
-
74
- Returns:
75
- Path: The directory containing the SfM results.
76
- """
77
- if len(images) == 0:
78
- logger.error(f"{images} does not exist.")
79
-
80
- temp_images = Path(tempfile.mkdtemp())
81
- # copy images
82
- logger.info(f"Copying images to {temp_images}.")
83
- for image in images:
84
- shutil.copy(image, temp_images)
85
-
86
- matcher_zoo = self.cfg["matcher_zoo"]
87
- model = matcher_zoo[key]
88
- match_conf = model["matcher"]
89
- match_conf["model"]["max_keypoints"] = max_keypoints
90
- match_conf["model"]["match_threshold"] = match_threshold
91
-
92
- feature_conf = model["feature"]
93
- feature_conf["model"]["max_keypoints"] = max_keypoints
94
- feature_conf["model"]["keypoint_threshold"] = keypoint_threshold
95
-
96
- # retrieval
97
- retrieval_name = self.cfg.get("retrieval_name", "netvlad")
98
- retrieval_conf = extract_features.confs[retrieval_name]
99
-
100
- mapper_options = {
101
- "ba_refine_extra_params": mapper_refine_extra_params,
102
- "ba_refine_focal_length": mapper_refine_focal_length,
103
- "ba_refine_principal_point": mapper_refine_principle_points,
104
- "ba_local_max_num_iterations": 40,
105
- "ba_local_max_refinements": 3,
106
- "ba_global_max_num_iterations": 100,
107
- # below 3 options are for individual/video data, for internet photos, they should be left
108
- # default
109
- "min_focal_length_ratio": 0.1,
110
- "max_focal_length_ratio": 10,
111
- "max_extra_param": 1e15,
112
- }
113
-
114
- sfm_dir = self.outputs / "sfm_{}".format(key)
115
- sfm_pairs = self.outputs / "pairs-sfm.txt"
116
- sfm_dir.mkdir(exist_ok=True, parents=True)
117
-
118
- # extract features
119
- retrieval_path = extract_features.main(
120
- retrieval_conf, temp_images, self.outputs
121
- )
122
- pairs_from_retrieval.main(retrieval_path, sfm_pairs, num_matched=top_k)
123
-
124
- feature_path = extract_features.main(feature_conf, temp_images, self.outputs)
125
- # match features
126
- match_path = match_features.main(
127
- match_conf, sfm_pairs, feature_conf["output"], self.outputs
128
- )
129
- # reconstruction
130
- already_sfm = False
131
- if sfm_dir.exists():
132
- try:
133
- model = pycolmap.Reconstruction(str(sfm_dir))
134
- already_sfm = True
135
- except ValueError:
136
- logger.info(f"sfm_dir exists but contains no valid reconstruction: {sfm_dir}")
137
- if not already_sfm:
138
- model = reconstruction.main(
139
- sfm_dir,
140
- temp_images,
141
- sfm_pairs,
142
- feature_path,
143
- match_path,
144
- mapper_options=mapper_options,
145
- )
146
-
147
- vertices = []
148
- for point3D_id, point3D in model.points3D.items():
149
- vertices.append([point3D.xyz, point3D.color])
150
-
151
- model_3d = sfm_dir / "points3D.obj"
152
- with open(model_3d, "w") as f:
153
- for p, c in vertices:
154
- # Write vertex position
155
- f.write("v {} {} {}\n".format(p[0], p[1], p[2]))
156
- # Write vertex normal (color)
157
- f.write(
158
- "vn {} {} {}\n".format(c[0] / 255.0, c[1] / 255.0, c[2] / 255.0)
159
- )
160
- viz_2d = visualization.visualize_sfm_2d(
161
- model, temp_images, color_by="visibility", n=2, dpi=300
162
- )
163
-
164
- return model_3d, fig2im(viz_2d) / 255.0
 
1
+ import shutil
2
+ import tempfile
3
+ from pathlib import Path
4
+ from typing import Any, Dict, List
5
+
6
+
7
+ from ..hloc import (
8
+ extract_features,
9
+ logger,
10
+ match_features,
11
+ pairs_from_retrieval,
12
+ reconstruction,
13
+ visualization,
14
+ )
15
+
16
+ try:
17
+ import pycolmap
18
+ except ImportError:
19
+ logger.warning("pycolmap not installed, some features may not work")
20
+
21
+ from .viz import fig2im
22
+
23
+
24
+ class SfmEngine:
25
+ def __init__(self, cfg: Dict[str, Any] = None):
26
+ self.cfg = cfg
27
+ if "outputs" in cfg and Path(cfg["outputs"]):
28
+ outputs = Path(cfg["outputs"])
29
+ outputs.mkdir(parents=True, exist_ok=True)
30
+ else:
31
+ outputs = tempfile.mkdtemp()
32
+ self.outputs = Path(outputs)
33
+
34
+ def call(
35
+ self,
36
+ key: str,
37
+ images: Path,
38
+ camera_model: str,
39
+ camera_params: List[float],
40
+ max_keypoints: int,
41
+ keypoint_threshold: float,
42
+ match_threshold: float,
43
+ ransac_threshold: int,
44
+ ransac_confidence: float,
45
+ ransac_max_iter: int,
46
+ scene_graph: bool,
47
+ global_feature: str,
48
+ top_k: int = 10,
49
+ mapper_refine_focal_length: bool = False,
50
+ mapper_refine_principle_points: bool = False,
51
+ mapper_refine_extra_params: bool = False,
52
+ ):
53
+ """
54
+ Call a list of functions to perform feature extraction, matching, and reconstruction.
55
+
56
+ Args:
57
+ key (str): The key to retrieve the matcher and feature models.
58
+ images (Path): The directory containing the images.
59
+ outputs (Path): The directory to store the outputs.
60
+ camera_model (str): The camera model.
61
+ camera_params (List[float]): The camera parameters.
62
+ max_keypoints (int): The maximum number of features.
63
+ match_threshold (float): The match threshold.
64
+ ransac_threshold (int): The RANSAC threshold.
65
+ ransac_confidence (float): The RANSAC confidence.
66
+ ransac_max_iter (int): The maximum number of RANSAC iterations.
67
+ scene_graph (bool): Whether to compute the scene graph.
68
+ global_feature (str): Whether to compute the global feature.
69
+ top_k (int): The number of image-pair to use.
70
+ mapper_refine_focal_length (bool): Whether to refine the focal length.
71
+ mapper_refine_principle_points (bool): Whether to refine the principle points.
72
+ mapper_refine_extra_params (bool): Whether to refine the extra parameters.
73
+
74
+ Returns:
75
+ Path: The directory containing the SfM results.
76
+ """
77
+ if len(images) == 0:
78
+ logger.error(f"{images} does not exist.")
79
+
80
+ temp_images = Path(tempfile.mkdtemp())
81
+ # copy images
82
+ logger.info(f"Copying images to {temp_images}.")
83
+ for image in images:
84
+ shutil.copy(image, temp_images)
85
+
86
+ matcher_zoo = self.cfg["matcher_zoo"]
87
+ model = matcher_zoo[key]
88
+ match_conf = model["matcher"]
89
+ match_conf["model"]["max_keypoints"] = max_keypoints
90
+ match_conf["model"]["match_threshold"] = match_threshold
91
+
92
+ feature_conf = model["feature"]
93
+ feature_conf["model"]["max_keypoints"] = max_keypoints
94
+ feature_conf["model"]["keypoint_threshold"] = keypoint_threshold
95
+
96
+ # retrieval
97
+ retrieval_name = self.cfg.get("retrieval_name", "netvlad")
98
+ retrieval_conf = extract_features.confs[retrieval_name]
99
+
100
+ mapper_options = {
101
+ "ba_refine_extra_params": mapper_refine_extra_params,
102
+ "ba_refine_focal_length": mapper_refine_focal_length,
103
+ "ba_refine_principal_point": mapper_refine_principle_points,
104
+ "ba_local_max_num_iterations": 40,
105
+ "ba_local_max_refinements": 3,
106
+ "ba_global_max_num_iterations": 100,
107
+ # below 3 options are for individual/video data, for internet photos, they should be left
108
+ # default
109
+ "min_focal_length_ratio": 0.1,
110
+ "max_focal_length_ratio": 10,
111
+ "max_extra_param": 1e15,
112
+ }
113
+
114
+ sfm_dir = self.outputs / "sfm_{}".format(key)
115
+ sfm_pairs = self.outputs / "pairs-sfm.txt"
116
+ sfm_dir.mkdir(exist_ok=True, parents=True)
117
+
118
+ # extract features
119
+ retrieval_path = extract_features.main(
120
+ retrieval_conf, temp_images, self.outputs
121
+ )
122
+ pairs_from_retrieval.main(retrieval_path, sfm_pairs, num_matched=top_k)
123
+
124
+ feature_path = extract_features.main(feature_conf, temp_images, self.outputs)
125
+ # match features
126
+ match_path = match_features.main(
127
+ match_conf, sfm_pairs, feature_conf["output"], self.outputs
128
+ )
129
+ # reconstruction
130
+ already_sfm = False
131
+ if sfm_dir.exists():
132
+ try:
133
+ model = pycolmap.Reconstruction(str(sfm_dir))
134
+ already_sfm = True
135
+ except ValueError:
136
+ logger.info(f"sfm_dir exists but contains no valid reconstruction: {sfm_dir}")
137
+ if not already_sfm:
138
+ model = reconstruction.main(
139
+ sfm_dir,
140
+ temp_images,
141
+ sfm_pairs,
142
+ feature_path,
143
+ match_path,
144
+ mapper_options=mapper_options,
145
+ )
146
+
147
+ vertices = []
148
+ for point3D_id, point3D in model.points3D.items():
149
+ vertices.append([point3D.xyz, point3D.color])
150
+
151
+ model_3d = sfm_dir / "points3D.obj"
152
+ with open(model_3d, "w") as f:
153
+ for p, c in vertices:
154
+ # Write vertex position
155
+ f.write("v {} {} {}\n".format(p[0], p[1], p[2]))
156
+ # Write vertex normal (color)
157
+ f.write(
158
+ "vn {} {} {}\n".format(c[0] / 255.0, c[1] / 255.0, c[2] / 255.0)
159
+ )
160
+ viz_2d = visualization.visualize_sfm_2d(
161
+ model, temp_images, color_by="visibility", n=2, dpi=300
162
+ )
163
+
164
+ return model_3d, fig2im(viz_2d) / 255.0
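
The export at the end of SfmEngine.call writes the sparse point cloud as a Wavefront OBJ, with positions in "v" records and the per-point RGB colour (scaled to [0, 1]) in "vn" records so the file can be displayed by the gr.Model3D component. A standalone sketch of that export with made-up points (not part of this diff):

import numpy as np

points = np.random.rand(5, 3)                # xyz positions
colors = np.random.randint(0, 256, (5, 3))   # uint8 RGB per point

with open("points3D.obj", "w") as f:
    for p, c in zip(points, colors):
        f.write("v {} {} {}\n".format(p[0], p[1], p[2]))
        f.write("vn {} {} {}\n".format(c[0] / 255.0, c[1] / 255.0, c[2] / 255.0))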
imcui/ui/utils.py CHANGED
@@ -1,1164 +1,1108 @@
1
- import os
2
- import pickle
3
- import random
4
- import shutil
5
- import time
6
- import warnings
7
- from itertools import combinations
8
- from pathlib import Path
9
- from typing import Any, Callable, Dict, List, Optional, Tuple, Union
10
- from datasets import load_dataset
11
-
12
- import cv2
13
- import gradio as gr
14
- import matplotlib.pyplot as plt
15
- import numpy as np
16
- import poselib
17
- import psutil
18
- from PIL import Image
19
-
20
- from ..hloc import (
21
- DEVICE,
22
- extract_features,
23
- extractors,
24
- logger,
25
- match_dense,
26
- match_features,
27
- matchers,
28
- DATASETS_REPO_ID,
29
- )
30
- from ..hloc.utils.base_model import dynamic_load
31
- from .viz import display_keypoints, display_matches, fig2im, plot_images
32
-
33
- warnings.simplefilter("ignore")
34
-
35
- ROOT = Path(__file__).parents[1]
36
- # some default values
37
- DEFAULT_SETTING_THRESHOLD = 0.1
38
- DEFAULT_SETTING_MAX_FEATURES = 2000
39
- DEFAULT_DEFAULT_KEYPOINT_THRESHOLD = 0.01
40
- DEFAULT_ENABLE_RANSAC = True
41
- DEFAULT_RANSAC_METHOD = "CV2_USAC_MAGSAC"
42
- DEFAULT_RANSAC_REPROJ_THRESHOLD = 8
43
- DEFAULT_RANSAC_CONFIDENCE = 0.9999
44
- DEFAULT_RANSAC_MAX_ITER = 10000
45
- DEFAULT_MIN_NUM_MATCHES = 4
46
- DEFAULT_MATCHING_THRESHOLD = 0.2
47
- DEFAULT_SETTING_GEOMETRY = "Homography"
48
- GRADIO_VERSION = gr.__version__.split(".")[0]
49
- MATCHER_ZOO = None
50
-
51
-
52
- class ModelCache:
53
- def __init__(self, max_memory_size: int = 8):
54
- self.max_memory_size = max_memory_size
55
- self.current_memory_size = 0
56
- self.model_dict = {}
57
- self.model_timestamps = []
58
-
59
- def cache_model(self, model_key, model_loader_func, model_conf):
60
- if model_key in self.model_dict:
61
- self.model_timestamps.remove(model_key)
62
- self.model_timestamps.append(model_key)
63
- logger.info(f"Load cached {model_key}")
64
- return self.model_dict[model_key]
65
-
66
- model = self._load_model_from_disk(model_loader_func, model_conf)
67
- while self._calculate_model_memory() > self.max_memory_size:
68
- if len(self.model_timestamps) == 0:
69
- logger.warn(
70
- "RAM: {}GB, MAX RAM: {}GB".format(
71
- self._calculate_model_memory(), self.max_memory_size
72
- )
73
- )
74
- break
75
- oldest_model_key = self.model_timestamps.pop(0)
76
- self.current_memory_size = self._calculate_model_memory()
77
- logger.info(f"Del cached {oldest_model_key}")
78
- del self.model_dict[oldest_model_key]
79
-
80
- self.model_dict[model_key] = model
81
- self.model_timestamps.append(model_key)
82
-
83
- self.print_memory_usage()
84
- logger.info(f"Total cached {list(self.model_dict.keys())}")
85
-
86
- return model
87
-
88
- def _load_model_from_disk(self, model_loader_func, model_conf):
89
- return model_loader_func(model_conf)
90
-
91
- def _calculate_model_memory(self, verbose=False):
92
- host_colocation = int(os.environ.get("HOST_COLOCATION", "1"))
93
- vm = psutil.virtual_memory()
94
- du = shutil.disk_usage(".")
95
- if verbose:
96
- logger.info(
97
- f"RAM: {vm.used / 1e9:.1f}/{vm.total / host_colocation / 1e9:.1f}GB"
98
- )
99
- logger.info(
100
- f"DISK: {du.used / 1e9:.1f}/{du.total / host_colocation / 1e9:.1f}GB"
101
- )
102
- return vm.used / 1e9
103
-
104
- def print_memory_usage(self):
105
- self._calculate_model_memory(verbose=True)
106
-
107
-
108
- model_cache = ModelCache()
109
-
110
-
111
- def load_config(config_name: str) -> Dict[str, Any]:
112
- """
113
- Load a YAML configuration file.
114
-
115
- Args:
116
- config_name: The path to the YAML configuration file.
117
-
118
- Returns:
119
- The configuration dictionary, with string keys and arbitrary values.
120
- """
121
- import yaml
122
-
123
- with open(config_name, "r") as stream:
124
- try:
125
- config: Dict[str, Any] = yaml.safe_load(stream)
126
- except yaml.YAMLError as exc:
127
- logger.error(exc)
128
- return config
129
-
130
-
131
- def get_matcher_zoo(
132
- matcher_zoo: Dict[str, Dict[str, Union[str, bool]]],
133
- ) -> Dict[str, Dict[str, Union[Callable, bool]]]:
134
- """
135
- Restore matcher configurations from a dictionary.
136
-
137
- Args:
138
- matcher_zoo: A dictionary with the matcher configurations,
139
- where the configuration is a dictionary as loaded from a YAML file.
140
-
141
- Returns:
142
- A dictionary with the matcher configurations, where the configuration is
143
- a function or a function instead of a string.
144
- """
145
- matcher_zoo_restored = {}
146
- for k, v in matcher_zoo.items():
147
- matcher_zoo_restored[k] = parse_match_config(v)
148
- return matcher_zoo_restored
149
-
150
-
151
- def parse_match_config(conf):
152
- if conf["dense"]:
153
- return {
154
- "matcher": match_dense.confs.get(conf["matcher"]),
155
- "dense": True,
156
- "info": conf.get("info", {}),
157
- }
158
- else:
159
- return {
160
- "feature": extract_features.confs.get(conf["feature"]),
161
- "matcher": match_features.confs.get(conf["matcher"]),
162
- "dense": False,
163
- "info": conf.get("info", {}),
164
- }
165
-
166
-
167
- def get_model(match_conf: Dict[str, Any]):
168
- """
169
- Load a matcher model from the provided configuration.
170
-
171
- Args:
172
- match_conf: A dictionary containing the model configuration.
173
-
174
- Returns:
175
- A matcher model instance.
176
- """
177
- Model = dynamic_load(matchers, match_conf["model"]["name"])
178
- model = Model(match_conf["model"]).eval().to(DEVICE)
179
- return model
180
-
181
-
182
- def get_feature_model(conf: Dict[str, Dict[str, Any]]):
183
- """
184
- Load a feature extraction model from the provided configuration.
185
-
186
- Args:
187
- conf: A dictionary containing the model configuration.
188
-
189
- Returns:
190
- A feature extraction model instance.
191
- """
192
- Model = dynamic_load(extractors, conf["model"]["name"])
193
- model = Model(conf["model"]).eval().to(DEVICE)
194
- return model
195
-
196
-
197
- def download_example_images(repo_id, output_dir):
198
- logger.info(f"Download example dataset from huggingface: {repo_id}")
199
- dataset = load_dataset(repo_id)
200
- Path(output_dir).mkdir(parents=True, exist_ok=True)
201
- for example in dataset["train"]: # Assuming the dataset is in the "train" split
202
- file_path = example["path"]
203
- image = example["image"] # Access the PIL.Image object directly
204
- full_path = os.path.join(output_dir, file_path)
205
- Path(os.path.dirname(full_path)).mkdir(parents=True, exist_ok=True)
206
- image.save(full_path)
207
- logger.info(f"Images saved to {output_dir} successfully.")
208
- return Path(output_dir)
209
-
210
-
211
- def gen_examples(data_root: Path):
212
- random.seed(1)
213
- example_matchers = [
214
- "disk+lightglue",
215
- "xfeat(sparse)",
216
- "dedode",
217
- "loftr",
218
- "disk",
219
- "RoMa",
220
- "d2net",
221
- "aspanformer",
222
- "topicfm",
223
- "superpoint+superglue",
224
- "superpoint+lightglue",
225
- "superpoint+mnn",
226
- "disk",
227
- ]
228
- data_root = Path(data_root)
229
- if not Path(data_root).exists():
230
- try:
231
- download_example_images(DATASETS_REPO_ID, data_root)
232
- except Exception as e:
233
- logger.error(f"download_example_images error : {e}")
234
- data_root = ROOT / "datasets"
235
- if not Path(data_root / "sacre_coeur/mapping").exists():
236
- download_example_images(DATASETS_REPO_ID, data_root)
237
-
238
- def distribute_elements(A, B):
239
- new_B = np.array(B, copy=True).flatten()
240
- np.random.shuffle(new_B)
241
- new_B = np.resize(new_B, len(A))
242
- np.random.shuffle(new_B)
243
- return new_B.tolist()
244
-
245
- # normal examples
246
- def gen_images_pairs(count: int = 5):
247
- path = str(data_root / "sacre_coeur/mapping")
248
- imgs_list = [
249
- os.path.join(path, file)
250
- for file in os.listdir(path)
251
- if file.lower().endswith((".jpg", ".jpeg", ".png"))
252
- ]
253
- pairs = list(combinations(imgs_list, 2))
254
- if len(pairs) < count:
255
- count = len(pairs)
256
- selected = random.sample(range(len(pairs)), count)
257
- return [pairs[i] for i in selected]
258
-
259
- # rotated examples
260
- def gen_rot_image_pairs(count: int = 5):
261
- path = data_root / "sacre_coeur/mapping"
262
- path_rot = data_root / "sacre_coeur/mapping_rot"
263
- rot_list = [45, 180, 90, 225, 270]
264
- pairs = []
265
- for file in os.listdir(path):
266
- if file.lower().endswith((".jpg", ".jpeg", ".png")):
267
- for rot in rot_list:
268
- file_rot = "{}_rot{}.jpg".format(Path(file).stem, rot)
269
- if (path_rot / file_rot).exists():
270
- pairs.append(
271
- [
272
- path / file,
273
- path_rot / file_rot,
274
- ]
275
- )
276
- if len(pairs) < count:
277
- count = len(pairs)
278
- selected = random.sample(range(len(pairs)), count)
279
- return [pairs[i] for i in selected]
280
-
281
- def gen_scale_image_pairs(count: int = 5):
282
- path = data_root / "sacre_coeur/mapping"
283
- path_scale = data_root / "sacre_coeur/mapping_scale"
284
- scale_list = [0.3, 0.5]
285
- pairs = []
286
- for file in os.listdir(path):
287
- if file.lower().endswith((".jpg", ".jpeg", ".png")):
288
- for scale in scale_list:
289
- file_scale = "{}_scale{}.jpg".format(Path(file).stem, scale)
290
- if (path_scale / file_scale).exists():
291
- pairs.append(
292
- [
293
- path / file,
294
- path_scale / file_scale,
295
- ]
296
- )
297
- if len(pairs) < count:
298
- count = len(pairs)
299
- selected = random.sample(range(len(pairs)), count)
300
- return [pairs[i] for i in selected]
301
-
302
- # extramely hard examples
303
- def gen_image_pairs_wxbs(count: int = None):
304
- prefix = "wxbs_benchmark/.WxBS/v1.1"
305
- wxbs_path = data_root / prefix
306
- pairs = []
307
- for catg in os.listdir(wxbs_path):
308
- catg_path = wxbs_path / catg
309
- if not catg_path.is_dir():
310
- continue
311
- for scene in os.listdir(catg_path):
312
- scene_path = catg_path / scene
313
- if not scene_path.is_dir():
314
- continue
315
- img1_path = scene_path / "01.png"
316
- img2_path = scene_path / "02.png"
317
- if img1_path.exists() and img2_path.exists():
318
- pairs.append([str(img1_path), str(img2_path)])
319
- return pairs
320
-
321
- # image pair path
322
- pairs = gen_images_pairs()
323
- pairs += gen_rot_image_pairs()
324
- pairs += gen_scale_image_pairs()
325
- pairs += gen_image_pairs_wxbs()
326
-
327
- match_setting_threshold = DEFAULT_SETTING_THRESHOLD
328
- match_setting_max_features = DEFAULT_SETTING_MAX_FEATURES
329
- detect_keypoints_threshold = DEFAULT_DEFAULT_KEYPOINT_THRESHOLD
330
- ransac_method = DEFAULT_RANSAC_METHOD
331
- ransac_reproj_threshold = DEFAULT_RANSAC_REPROJ_THRESHOLD
332
- ransac_confidence = DEFAULT_RANSAC_CONFIDENCE
333
- ransac_max_iter = DEFAULT_RANSAC_MAX_ITER
334
- input_lists = []
335
- dist_examples = distribute_elements(pairs, example_matchers)
336
- for pair, mt in zip(pairs, dist_examples):
337
- input_lists.append(
338
- [
339
- pair[0],
340
- pair[1],
341
- match_setting_threshold,
342
- match_setting_max_features,
343
- detect_keypoints_threshold,
344
- mt,
345
- # enable_ransac,
346
- ransac_method,
347
- ransac_reproj_threshold,
348
- ransac_confidence,
349
- ransac_max_iter,
350
- ]
351
- )
352
- return input_lists
353
-
354
-
355
- def set_null_pred(feature_type: str, pred: dict):
356
- if feature_type == "KEYPOINT":
357
- pred["mmkeypoints0_orig"] = np.array([])
358
- pred["mmkeypoints1_orig"] = np.array([])
359
- pred["mmconf"] = np.array([])
360
- elif feature_type == "LINE":
361
- pred["mline_keypoints0_orig"] = np.array([])
362
- pred["mline_keypoints1_orig"] = np.array([])
363
- pred["H"] = None
364
- pred["geom_info"] = {}
365
- return pred
366
-
367
-
368
- def _filter_matches_opencv(
369
- kp0: np.ndarray,
370
- kp1: np.ndarray,
371
- method: int = cv2.RANSAC,
372
- reproj_threshold: float = 3.0,
373
- confidence: float = 0.99,
374
- max_iter: int = 2000,
375
- geometry_type: str = "Homography",
376
- ) -> Tuple[np.ndarray, np.ndarray]:
377
- """
378
- Filters matches between two sets of keypoints using OpenCV's findHomography.
379
-
380
- Args:
381
- kp0 (np.ndarray): Array of keypoints from the first image.
382
- kp1 (np.ndarray): Array of keypoints from the second image.
383
- method (int, optional): RANSAC method. Defaults to "cv2.RANSAC".
384
- reproj_threshold (float, optional): RANSAC reprojection threshold. Defaults to 3.0.
385
- confidence (float, optional): RANSAC confidence. Defaults to 0.99.
386
- max_iter (int, optional): RANSAC maximum iterations. Defaults to 2000.
387
- geometry_type (str, optional): Type of geometry. Defaults to "Homography".
388
-
389
- Returns:
390
- Tuple[np.ndarray, np.ndarray]: Homography matrix and mask.
391
- """
392
- if geometry_type == "Homography":
393
- try:
394
- M, mask = cv2.findHomography(
395
- kp0,
396
- kp1,
397
- method=method,
398
- ransacReprojThreshold=reproj_threshold,
399
- confidence=confidence,
400
- maxIters=max_iter,
401
- )
402
- except cv2.error:
403
- logger.error("compute findHomography error, len(kp0): {}".format(len(kp0)))
404
- return None, None
405
- elif geometry_type == "Fundamental":
406
- try:
407
- M, mask = cv2.findFundamentalMat(
408
- kp0,
409
- kp1,
410
- method=method,
411
- ransacReprojThreshold=reproj_threshold,
412
- confidence=confidence,
413
- maxIters=max_iter,
414
- )
415
- except cv2.error:
416
- logger.error(
417
- "compute findFundamentalMat error, len(kp0): {}".format(len(kp0))
418
- )
419
- return None, None
420
- mask = np.array(mask.ravel().astype("bool"), dtype="bool")
421
- return M, mask
422
-
423
-
424
- def _filter_matches_poselib(
425
- kp0: np.ndarray,
426
- kp1: np.ndarray,
427
- method: int = None, # not used
428
- reproj_threshold: float = 3,
429
- confidence: float = 0.99,
430
- max_iter: int = 2000,
431
- geometry_type: str = "Homography",
432
- ) -> dict:
433
- """
434
- Filters matches between two sets of keypoints using the poselib library.
435
-
436
- Args:
437
- kp0 (np.ndarray): Array of keypoints from the first image.
438
- kp1 (np.ndarray): Array of keypoints from the second image.
439
- method (str, optional): RANSAC method. Defaults to "RANSAC".
440
- reproj_threshold (float, optional): RANSAC reprojection threshold. Defaults to 3.
441
- confidence (float, optional): RANSAC confidence. Defaults to 0.99.
442
- max_iter (int, optional): RANSAC maximum iterations. Defaults to 2000.
443
- geometry_type (str, optional): Type of geometry. Defaults to "Homography".
444
-
445
- Returns:
446
- dict: Information about the homography estimation.
447
- """
448
- ransac_options = {
449
- "max_iterations": max_iter,
450
- # "min_iterations": min_iter,
451
- "success_prob": confidence,
452
- "max_reproj_error": reproj_threshold,
453
- # "progressive_sampling": args.sampler.lower() == 'prosac'
454
- }
455
-
456
- if geometry_type == "Homography":
457
- M, info = poselib.estimate_homography(kp0, kp1, ransac_options)
458
- elif geometry_type == "Fundamental":
459
- M, info = poselib.estimate_fundamental(kp0, kp1, ransac_options)
460
- else:
461
- raise NotImplementedError
462
-
463
- return M, np.array(info["inliers"])
464
-
465
-
466
- def proc_ransac_matches(
467
- mkpts0: np.ndarray,
468
- mkpts1: np.ndarray,
469
- ransac_method: str = DEFAULT_RANSAC_METHOD,
470
- ransac_reproj_threshold: float = 3.0,
471
- ransac_confidence: float = 0.99,
472
- ransac_max_iter: int = 2000,
473
- geometry_type: str = "Homography",
474
- ):
475
- if ransac_method.startswith("CV2"):
476
- logger.info(f"ransac_method: {ransac_method}, geometry_type: {geometry_type}")
477
- return _filter_matches_opencv(
478
- mkpts0,
479
- mkpts1,
480
- ransac_zoo[ransac_method],
481
- ransac_reproj_threshold,
482
- ransac_confidence,
483
- ransac_max_iter,
484
- geometry_type,
485
- )
486
- elif ransac_method.startswith("POSELIB"):
487
- logger.info(f"ransac_method: {ransac_method}, geometry_type: {geometry_type}")
488
- return _filter_matches_poselib(
489
- mkpts0,
490
- mkpts1,
491
- None,
492
- ransac_reproj_threshold,
493
- ransac_confidence,
494
- ransac_max_iter,
495
- geometry_type,
496
- )
497
- else:
498
- raise NotImplementedError
499
-
500
-
501
- def filter_matches(
502
- pred: Dict[str, Any],
503
- ransac_method: str = DEFAULT_RANSAC_METHOD,
504
- ransac_reproj_threshold: float = DEFAULT_RANSAC_REPROJ_THRESHOLD,
505
- ransac_confidence: float = DEFAULT_RANSAC_CONFIDENCE,
506
- ransac_max_iter: int = DEFAULT_RANSAC_MAX_ITER,
507
- ransac_estimator: str = None,
508
- ):
509
- """
510
- Filter matches using RANSAC. If keypoints are available, filter by keypoints.
511
- If lines are available, filter by lines. If both keypoints and lines are
512
- available, filter by keypoints.
513
-
514
- Args:
515
- pred (Dict[str, Any]): dict of matches, including original keypoints.
516
- ransac_method (str, optional): RANSAC method. Defaults to DEFAULT_RANSAC_METHOD.
517
- ransac_reproj_threshold (float, optional): RANSAC reprojection threshold. Defaults to DEFAULT_RANSAC_REPROJ_THRESHOLD.
518
- ransac_confidence (float, optional): RANSAC confidence. Defaults to DEFAULT_RANSAC_CONFIDENCE.
519
- ransac_max_iter (int, optional): RANSAC maximum iterations. Defaults to DEFAULT_RANSAC_MAX_ITER.
520
-
521
- Returns:
522
- Dict[str, Any]: filtered matches.
523
- """
524
- mkpts0: Optional[np.ndarray] = None
525
- mkpts1: Optional[np.ndarray] = None
526
- feature_type: Optional[str] = None
527
- if "mkeypoints0_orig" in pred.keys() and "mkeypoints1_orig" in pred.keys():
528
- mkpts0 = pred["mkeypoints0_orig"]
529
- mkpts1 = pred["mkeypoints1_orig"]
530
- feature_type = "KEYPOINT"
531
- elif (
532
- "line_keypoints0_orig" in pred.keys() and "line_keypoints1_orig" in pred.keys()
533
- ):
534
- mkpts0 = pred["line_keypoints0_orig"]
535
- mkpts1 = pred["line_keypoints1_orig"]
536
- feature_type = "LINE"
537
- else:
538
- return set_null_pred(feature_type, pred)
539
- if mkpts0 is None or mkpts0 is None:
540
- return set_null_pred(feature_type, pred)
541
- if ransac_method not in ransac_zoo.keys():
542
- ransac_method = DEFAULT_RANSAC_METHOD
543
-
544
- if len(mkpts0) < DEFAULT_MIN_NUM_MATCHES:
545
- return set_null_pred(feature_type, pred)
546
-
547
- geom_info = compute_geometry(
548
- pred,
549
- ransac_method=ransac_method,
550
- ransac_reproj_threshold=ransac_reproj_threshold,
551
- ransac_confidence=ransac_confidence,
552
- ransac_max_iter=ransac_max_iter,
553
- )
554
-
555
- if "Homography" in geom_info.keys():
556
- mask = geom_info["mask_h"]
557
- if feature_type == "KEYPOINT":
558
- pred["mmkeypoints0_orig"] = mkpts0[mask]
559
- pred["mmkeypoints1_orig"] = mkpts1[mask]
560
- pred["mmconf"] = pred["mconf"][mask]
561
- elif feature_type == "LINE":
562
- pred["mline_keypoints0_orig"] = mkpts0[mask]
563
- pred["mline_keypoints1_orig"] = mkpts1[mask]
564
- pred["H"] = np.array(geom_info["Homography"])
565
- else:
566
- set_null_pred(feature_type, pred)
567
- # do not show mask
568
- geom_info.pop("mask_h", None)
569
- geom_info.pop("mask_f", None)
570
- pred["geom_info"] = geom_info
571
- return pred
572
-
573
-
574
- def compute_geometry(
575
- pred: Dict[str, Any],
576
- ransac_method: str = DEFAULT_RANSAC_METHOD,
577
- ransac_reproj_threshold: float = DEFAULT_RANSAC_REPROJ_THRESHOLD,
578
- ransac_confidence: float = DEFAULT_RANSAC_CONFIDENCE,
579
- ransac_max_iter: int = DEFAULT_RANSAC_MAX_ITER,
580
- ) -> Dict[str, List[float]]:
581
- """
582
- Compute geometric information of matches, including Fundamental matrix,
583
- Homography matrix, and rectification matrices (if available).
584
-
585
- Args:
586
- pred (Dict[str, Any]): dict of matches, including original keypoints.
587
- ransac_method (str, optional): RANSAC method. Defaults to DEFAULT_RANSAC_METHOD.
588
- ransac_reproj_threshold (float, optional): RANSAC reprojection threshold. Defaults to DEFAULT_RANSAC_REPROJ_THRESHOLD.
589
- ransac_confidence (float, optional): RANSAC confidence. Defaults to DEFAULT_RANSAC_CONFIDENCE.
590
- ransac_max_iter (int, optional): RANSAC maximum iterations. Defaults to DEFAULT_RANSAC_MAX_ITER.
591
-
592
- Returns:
593
- Dict[str, List[float]]: geometric information in form of a dict.
594
- """
595
- mkpts0: Optional[np.ndarray] = None
596
- mkpts1: Optional[np.ndarray] = None
597
-
598
- if "mkeypoints0_orig" in pred.keys() and "mkeypoints1_orig" in pred.keys():
599
- mkpts0 = pred["mkeypoints0_orig"]
600
- mkpts1 = pred["mkeypoints1_orig"]
601
- elif (
602
- "line_keypoints0_orig" in pred.keys() and "line_keypoints1_orig" in pred.keys()
603
- ):
604
- mkpts0 = pred["line_keypoints0_orig"]
605
- mkpts1 = pred["line_keypoints1_orig"]
606
-
607
- if mkpts0 is not None and mkpts1 is not None:
608
- if len(mkpts0) < 2 * DEFAULT_MIN_NUM_MATCHES:
609
- return {}
610
- geo_info: Dict[str, List[float]] = {}
611
-
612
- F, mask_f = proc_ransac_matches(
613
- mkpts0,
614
- mkpts1,
615
- ransac_method,
616
- ransac_reproj_threshold,
617
- ransac_confidence,
618
- ransac_max_iter,
619
- geometry_type="Fundamental",
620
- )
621
-
622
- if F is not None:
623
- geo_info["Fundamental"] = F.tolist()
624
- geo_info["mask_f"] = mask_f
625
- H, mask_h = proc_ransac_matches(
626
- mkpts1,
627
- mkpts0,
628
- ransac_method,
629
- ransac_reproj_threshold,
630
- ransac_confidence,
631
- ransac_max_iter,
632
- geometry_type="Homography",
633
- )
634
-
635
- h0, w0, _ = pred["image0_orig"].shape
636
- if H is not None:
637
- geo_info["Homography"] = H.tolist()
638
- geo_info["mask_h"] = mask_h
639
- try:
640
- _, H1, H2 = cv2.stereoRectifyUncalibrated(
641
- mkpts0.reshape(-1, 2),
642
- mkpts1.reshape(-1, 2),
643
- F,
644
- imgSize=(w0, h0),
645
- )
646
- geo_info["H1"] = H1.tolist()
647
- geo_info["H2"] = H2.tolist()
648
- except cv2.error as e:
649
- logger.error(f"StereoRectifyUncalibrated failed, skip! error: {e}")
650
- return geo_info
651
- else:
652
- return {}
653
-
654
-
655
- def wrap_images(
656
- img0: np.ndarray,
657
- img1: np.ndarray,
658
- geo_info: Optional[Dict[str, List[float]]],
659
- geom_type: str,
660
- ) -> Tuple[Optional[str], Optional[Dict[str, List[float]]]]:
661
- """
662
- Wraps the images based on the geometric transformation used to align them.
663
-
664
- Args:
665
- img0: numpy array representing the first image.
666
- img1: numpy array representing the second image.
667
- geo_info: dictionary containing the geometric transformation information.
668
- geom_type: type of geometric transformation used to align the images.
669
-
670
- Returns:
671
- A tuple containing a base64 encoded image string and a dictionary with the transformation matrix.
672
- """
673
- h0, w0, _ = img0.shape
674
- h1, w1, _ = img1.shape
675
- if geo_info is not None and len(geo_info) != 0:
676
- rectified_image0 = img0
677
- rectified_image1 = None
678
- if "Homography" not in geo_info:
679
- logger.warning(f"{geom_type} not exist, maybe too less matches")
680
- return None, None
681
-
682
- H = np.array(geo_info["Homography"])
683
-
684
- title: List[str] = []
685
- if geom_type == "Homography":
686
- rectified_image1 = cv2.warpPerspective(img1, H, (w0, h0))
687
- title = ["Image 0", "Image 1 - warped"]
688
- elif geom_type == "Fundamental":
689
- if geom_type not in geo_info:
690
- logger.warning(f"{geom_type} not exist, maybe too less matches")
691
- return None, None
692
- else:
693
- H1, H2 = np.array(geo_info["H1"]), np.array(geo_info["H2"])
694
- rectified_image0 = cv2.warpPerspective(img0, H1, (w0, h0))
695
- rectified_image1 = cv2.warpPerspective(img1, H2, (w1, h1))
696
- title = ["Image 0 - warped", "Image 1 - warped"]
697
- else:
698
- print("Error: Unknown geometry type")
699
- fig = plot_images(
700
- [rectified_image0.squeeze(), rectified_image1.squeeze()],
701
- title,
702
- dpi=300,
703
- )
704
- return fig2im(fig), rectified_image1
705
- else:
706
- return None, None
707
-
708
-
709
- def generate_warp_images(
710
- input_image0: np.ndarray,
711
- input_image1: np.ndarray,
712
- matches_info: Dict[str, Any],
713
- choice: str,
714
- ) -> Tuple[Optional[np.ndarray], Optional[np.ndarray]]:
715
- """
716
- Changes the estimate of the geometric transformation used to align the images.
717
-
718
- Args:
719
- input_image0: First input image.
720
- input_image1: Second input image.
721
- matches_info: Dictionary containing information about the matches.
722
- choice: Type of geometric transformation to use ('Homography' or 'Fundamental') or 'No' to disable.
723
-
724
- Returns:
725
- A tuple containing the updated images and the warpped images.
726
- """
727
- if (
728
- matches_info is None
729
- or len(matches_info) < 1
730
- or "geom_info" not in matches_info.keys()
731
- ):
732
- return None, None
733
- geom_info = matches_info["geom_info"]
734
- warped_image = None
735
- if choice != "No":
736
- wrapped_image_pair, warped_image = wrap_images(
737
- input_image0, input_image1, geom_info, choice
738
- )
739
- return wrapped_image_pair, warped_image
740
- else:
741
- return None, None
742
-
743
-
744
- def send_to_match(state_cache: Dict[str, Any]):
745
- """
746
- Send the state cache to the match function.
747
-
748
- Args:
749
- state_cache (Dict[str, Any]): Current state of the app.
750
-
751
- Returns:
752
- None
753
- """
754
- if state_cache:
755
- return (
756
- state_cache["image0_orig"],
757
- state_cache["wrapped_image"],
758
- )
759
- else:
760
- return None, None
761
-
762
-
763
- def run_ransac(
764
- state_cache: Dict[str, Any],
765
- choice_geometry_type: str,
766
- ransac_method: str = DEFAULT_RANSAC_METHOD,
767
- ransac_reproj_threshold: int = DEFAULT_RANSAC_REPROJ_THRESHOLD,
768
- ransac_confidence: float = DEFAULT_RANSAC_CONFIDENCE,
769
- ransac_max_iter: int = DEFAULT_RANSAC_MAX_ITER,
770
- ) -> Tuple[Optional[np.ndarray], Optional[Dict[str, int]]]:
771
- """
772
- Run RANSAC matches and return the output images and the number of matches.
773
-
774
- Args:
775
- state_cache (Dict[str, Any]): Current state of the app, including the matches.
776
- ransac_method (str, optional): RANSAC method. Defaults to DEFAULT_RANSAC_METHOD.
777
- ransac_reproj_threshold (int, optional): RANSAC reprojection threshold. Defaults to DEFAULT_RANSAC_REPROJ_THRESHOLD.
778
- ransac_confidence (float, optional): RANSAC confidence. Defaults to DEFAULT_RANSAC_CONFIDENCE.
779
- ransac_max_iter (int, optional): RANSAC maximum iterations. Defaults to DEFAULT_RANSAC_MAX_ITER.
780
-
781
- Returns:
782
- Tuple[Optional[np.ndarray], Optional[Dict[str, int]]]: Tuple containing the output images and the number of matches.
783
- """
784
- if not state_cache:
785
- logger.info("Run Match first before Rerun RANSAC")
786
- gr.Warning("Run Match first before Rerun RANSAC")
787
- return None, None
788
- t1 = time.time()
789
- logger.info(
790
- f"Run RANSAC matches using: {ransac_method} with threshold: {ransac_reproj_threshold}"
791
- )
792
- logger.info(
793
- f"Run RANSAC matches using: {ransac_confidence} with iter: {ransac_max_iter}"
794
- )
795
- # if enable_ransac:
796
- filter_matches(
797
- state_cache,
798
- ransac_method=ransac_method,
799
- ransac_reproj_threshold=ransac_reproj_threshold,
800
- ransac_confidence=ransac_confidence,
801
- ransac_max_iter=ransac_max_iter,
802
- )
803
- logger.info(f"RANSAC matches done using: {time.time()-t1:.3f}s")
804
- t1 = time.time()
805
-
806
- # plot images with ransac matches
807
- titles = [
808
- "Image 0 - Ransac matched keypoints",
809
- "Image 1 - Ransac matched keypoints",
810
- ]
811
- output_matches_ransac, num_matches_ransac = display_matches(
812
- state_cache, titles=titles, tag="KPTS_RANSAC"
813
- )
814
- logger.info(f"Display matches done using: {time.time()-t1:.3f}s")
815
- t1 = time.time()
816
-
817
- # compute warp images
818
- output_wrapped, warped_image = generate_warp_images(
819
- state_cache["image0_orig"],
820
- state_cache["image1_orig"],
821
- state_cache,
822
- choice_geometry_type,
823
- )
824
- plt.close("all")
825
-
826
- num_matches_raw = state_cache["num_matches_raw"]
827
- state_cache["wrapped_image"] = warped_image
828
-
829
- # tmp_state_cache = tempfile.NamedTemporaryFile(suffix='.pkl', delete=False)
830
- tmp_state_cache = "output.pkl"
831
- with open(tmp_state_cache, "wb") as f:
832
- pickle.dump(state_cache, f)
833
-
834
- logger.info("Dump results done!")
835
-
836
- return (
837
- output_matches_ransac,
838
- {
839
- "num_matches_raw": num_matches_raw,
840
- "num_matches_ransac": num_matches_ransac,
841
- },
842
- output_wrapped,
843
- tmp_state_cache,
844
- )
845
-
846
-
847
- def generate_fake_outputs(
848
- output_keypoints,
849
- output_matches_raw,
850
- output_matches_ransac,
851
- match_conf,
852
- extract_conf,
853
- pred,
854
- ):
855
- return (
856
- output_keypoints,
857
- output_matches_raw,
858
- output_matches_ransac,
859
- {},
860
- {
861
- "match_conf": match_conf,
862
- "extractor_conf": extract_conf,
863
- },
864
- {
865
- "geom_info": pred.get("geom_info", {}),
866
- },
867
- None,
868
- None,
869
- None,
870
- )
871
-
872
-
873
- def run_matching(
874
- image0: np.ndarray,
875
- image1: np.ndarray,
876
- match_threshold: float,
877
- extract_max_keypoints: int,
878
- keypoint_threshold: float,
879
- key: str,
880
- ransac_method: str = DEFAULT_RANSAC_METHOD,
881
- ransac_reproj_threshold: int = DEFAULT_RANSAC_REPROJ_THRESHOLD,
882
- ransac_confidence: float = DEFAULT_RANSAC_CONFIDENCE,
883
- ransac_max_iter: int = DEFAULT_RANSAC_MAX_ITER,
884
- choice_geometry_type: str = DEFAULT_SETTING_GEOMETRY,
885
- matcher_zoo: Dict[str, Any] = None,
886
- force_resize: bool = False,
887
- image_width: int = 640,
888
- image_height: int = 480,
889
- use_cached_model: bool = False,
890
- ) -> Tuple[
891
- np.ndarray,
892
- np.ndarray,
893
- np.ndarray,
894
- Dict[str, int],
895
- Dict[str, Dict[str, Any]],
896
- Dict[str, Dict[str, float]],
897
- np.ndarray,
898
- ]:
899
- """Match two images using the given parameters.
900
-
901
- Args:
902
- image0 (np.ndarray): RGB image 0.
903
- image1 (np.ndarray): RGB image 1.
904
- match_threshold (float): match threshold.
905
- extract_max_keypoints (int): number of keypoints to extract.
906
- keypoint_threshold (float): keypoint threshold.
907
- key (str): key of the model to use.
908
- ransac_method (str, optional): RANSAC method to use.
909
- ransac_reproj_threshold (int, optional): RANSAC reprojection threshold.
910
- ransac_confidence (float, optional): RANSAC confidence level.
911
- ransac_max_iter (int, optional): RANSAC maximum number of iterations.
912
- choice_geometry_type (str, optional): setting of geometry estimation.
913
- matcher_zoo (Dict[str, Any], optional): matcher zoo. Defaults to None.
914
- force_resize (bool, optional): force resize. Defaults to False.
915
- image_width (int, optional): image width. Defaults to 640.
916
- image_height (int, optional): image height. Defaults to 480.
917
- use_cached_model (bool, optional): use cached model. Defaults to False.
918
-
919
- Returns:
920
- tuple:
921
- - output_keypoints (np.ndarray): image with keypoints.
922
- - output_matches_raw (np.ndarray): image with raw matches.
923
- - output_matches_ransac (np.ndarray): image with RANSAC matches.
924
- - num_matches (Dict[str, int]): number of raw and RANSAC matches.
925
- - configs (Dict[str, Dict[str, Any]]): match and feature extraction configs.
926
- - geom_info (Dict[str, Dict[str, float]]): geometry information.
927
- - output_wrapped (np.ndarray): wrapped images.
928
- """
929
- # image0 and image1 is RGB mode
930
- if image0 is None or image1 is None:
931
- logger.error(
932
- "Error: No images found! Please upload two images or select an example."
933
- )
934
- raise gr.Error(
935
- "Error: No images found! Please upload two images or select an example."
936
- )
937
- # init output
938
- output_keypoints = None
939
- output_matches_raw = None
940
- output_matches_ransac = None
941
-
942
- t0 = time.time()
943
- model = matcher_zoo[key]
944
- match_conf = model["matcher"]
945
- # update match config
946
- match_conf["model"]["match_threshold"] = match_threshold
947
- match_conf["model"]["max_keypoints"] = extract_max_keypoints
948
- cache_key = "{}_{}".format(key, match_conf["model"]["name"])
949
-
950
- efficiency = model["info"].get("efficiency", "high")
951
- if efficiency == "low":
952
- gr.Warning(
953
- "Matcher {} is time-consuming, please wait for a while".format(
954
- model["info"].get("name", "unknown")
955
- )
956
- )
957
-
958
- if use_cached_model:
959
- # because of the model cache, we need to update the config
960
- matcher = model_cache.cache_model(cache_key, get_model, match_conf)
961
- matcher.conf["max_keypoints"] = extract_max_keypoints
962
- matcher.conf["match_threshold"] = match_threshold
963
- logger.info(f"Loaded cached model {cache_key}")
964
- else:
965
- matcher = get_model(match_conf)
966
- logger.info(f"Loading model using: {time.time()-t0:.3f}s")
967
- t1 = time.time()
968
- yield generate_fake_outputs(
969
- output_keypoints, output_matches_raw, output_matches_ransac, match_conf, {}, {}
970
- )
971
-
972
- if model["dense"]:
973
- if not match_conf["preprocessing"].get("force_resize", False):
974
- match_conf["preprocessing"]["force_resize"] = force_resize
975
- else:
976
- logger.info("preprocessing is already resized")
977
- if force_resize:
978
- match_conf["preprocessing"]["height"] = image_height
979
- match_conf["preprocessing"]["width"] = image_width
980
- logger.info(f"Force resize to {image_width}x{image_height}")
981
-
982
- pred = match_dense.match_images(
983
- matcher, image0, image1, match_conf["preprocessing"], device=DEVICE
984
- )
985
- del matcher
986
- extract_conf = None
987
- else:
988
- extract_conf = model["feature"]
989
- # update extract config
990
- extract_conf["model"]["max_keypoints"] = extract_max_keypoints
991
- extract_conf["model"]["keypoint_threshold"] = keypoint_threshold
992
- cache_key = "{}_{}".format(key, extract_conf["model"]["name"])
993
-
994
- if use_cached_model:
995
- extractor = model_cache.cache_model(
996
- cache_key, get_feature_model, extract_conf
997
- )
998
- # because of the model cache, we need to update the config
999
- extractor.conf["max_keypoints"] = extract_max_keypoints
1000
- extractor.conf["keypoint_threshold"] = keypoint_threshold
1001
- logger.info(f"Loaded cached model {cache_key}")
1002
- else:
1003
- extractor = get_feature_model(extract_conf)
1004
-
1005
- if not extract_conf["preprocessing"].get("force_resize", False):
1006
- extract_conf["preprocessing"]["force_resize"] = force_resize
1007
- else:
1008
- logger.info("preprocessing is already resized")
1009
- if force_resize:
1010
- extract_conf["preprocessing"]["height"] = image_height
1011
- extract_conf["preprocessing"]["width"] = image_width
1012
- logger.info(f"Force resize to {image_width}x{image_height}")
1013
-
1014
- pred0 = extract_features.extract(
1015
- extractor, image0, extract_conf["preprocessing"]
1016
- )
1017
- pred1 = extract_features.extract(
1018
- extractor, image1, extract_conf["preprocessing"]
1019
- )
1020
- pred = match_features.match_images(matcher, pred0, pred1)
1021
- del extractor
1022
- # gr.Info(
1023
- # f"Matching images done using: {time.time()-t1:.3f}s",
1024
- # )
1025
- logger.info(f"Matching images done using: {time.time()-t1:.3f}s")
1026
- t1 = time.time()
1027
-
1028
- # plot images with keypoints
1029
- titles = [
1030
- "Image 0 - Keypoints",
1031
- "Image 1 - Keypoints",
1032
- ]
1033
- output_keypoints = display_keypoints(pred, titles=titles)
1034
- yield generate_fake_outputs(
1035
- output_keypoints,
1036
- output_matches_raw,
1037
- output_matches_ransac,
1038
- match_conf,
1039
- extract_conf,
1040
- pred,
1041
- )
1042
-
1043
- # plot images with raw matches
1044
- titles = [
1045
- "Image 0 - Raw matched keypoints",
1046
- "Image 1 - Raw matched keypoints",
1047
- ]
1048
- output_matches_raw, num_matches_raw = display_matches(pred, titles=titles)
1049
- yield generate_fake_outputs(
1050
- output_keypoints,
1051
- output_matches_raw,
1052
- output_matches_ransac,
1053
- match_conf,
1054
- extract_conf,
1055
- pred,
1056
- )
1057
-
1058
- # if enable_ransac:
1059
- filter_matches(
1060
- pred,
1061
- ransac_method=ransac_method,
1062
- ransac_reproj_threshold=ransac_reproj_threshold,
1063
- ransac_confidence=ransac_confidence,
1064
- ransac_max_iter=ransac_max_iter,
1065
- )
1066
-
1067
- # gr.Info(f"RANSAC matches done using: {time.time()-t1:.3f}s")
1068
- logger.info(f"RANSAC matches done using: {time.time()-t1:.3f}s")
1069
- t1 = time.time()
1070
-
1071
- # plot images with ransac matches
1072
- titles = [
1073
- "Image 0 - Ransac matched keypoints",
1074
- "Image 1 - Ransac matched keypoints",
1075
- ]
1076
- output_matches_ransac, num_matches_ransac = display_matches(
1077
- pred, titles=titles, tag="KPTS_RANSAC"
1078
- )
1079
- yield generate_fake_outputs(
1080
- output_keypoints,
1081
- output_matches_raw,
1082
- output_matches_ransac,
1083
- match_conf,
1084
- extract_conf,
1085
- pred,
1086
- )
1087
-
1088
- # gr.Info(f"Display matches done using: {time.time()-t1:.3f}s")
1089
- logger.info(f"Display matches done using: {time.time()-t1:.3f}s")
1090
- t1 = time.time()
1091
- # plot wrapped images
1092
- output_wrapped, warped_image = generate_warp_images(
1093
- pred["image0_orig"],
1094
- pred["image1_orig"],
1095
- pred,
1096
- choice_geometry_type,
1097
- )
1098
- plt.close("all")
1099
- # gr.Info(f"In summary, total time: {time.time()-t0:.3f}s")
1100
- logger.info(f"TOTAL time: {time.time()-t0:.3f}s")
1101
-
1102
- state_cache = pred
1103
- state_cache["num_matches_raw"] = num_matches_raw
1104
- state_cache["num_matches_ransac"] = num_matches_ransac
1105
- state_cache["wrapped_image"] = warped_image
1106
-
1107
- # tmp_state_cache = tempfile.NamedTemporaryFile(suffix='.pkl', delete=False)
1108
- tmp_state_cache = "output.pkl"
1109
- with open(tmp_state_cache, "wb") as f:
1110
- pickle.dump(state_cache, f)
1111
- logger.info("Dump results done!")
1112
-
1113
- yield (
1114
- output_keypoints,
1115
- output_matches_raw,
1116
- output_matches_ransac,
1117
- {
1118
- "num_raw_matches": num_matches_raw,
1119
- "num_ransac_matches": num_matches_ransac,
1120
- },
1121
- {
1122
- "match_conf": match_conf,
1123
- "extractor_conf": extract_conf,
1124
- },
1125
- {
1126
- "geom_info": pred.get("geom_info", {}),
1127
- },
1128
- output_wrapped,
1129
- state_cache,
1130
- tmp_state_cache,
1131
- )
1132
-
1133
-
1134
- # @ref: https://docs.opencv.org/4.x/d0/d74/md__build_4_x-contrib_docs-lin64_opencv_doc_tutorials_calib3d_usac.html
1135
- # AND: https://opencv.org/blog/2021/06/09/evaluating-opencvs-new-ransacs
1136
- ransac_zoo = {
1137
- "POSELIB": "LO-RANSAC",
1138
- "CV2_RANSAC": cv2.RANSAC,
1139
- "CV2_USAC_MAGSAC": cv2.USAC_MAGSAC,
1140
- "CV2_USAC_DEFAULT": cv2.USAC_DEFAULT,
1141
- "CV2_USAC_FM_8PTS": cv2.USAC_FM_8PTS,
1142
- "CV2_USAC_PROSAC": cv2.USAC_PROSAC,
1143
- "CV2_USAC_FAST": cv2.USAC_FAST,
1144
- "CV2_USAC_ACCURATE": cv2.USAC_ACCURATE,
1145
- "CV2_USAC_PARALLEL": cv2.USAC_PARALLEL,
1146
- }
1147
-
1148
-
1149
- def rotate_image(input_path, degrees, output_path):
1150
- img = Image.open(input_path)
1151
- img_rotated = img.rotate(-degrees)
1152
- img_rotated.save(output_path)
1153
-
1154
-
1155
- def scale_image(input_path, scale_factor, output_path):
1156
- img = Image.open(input_path)
1157
- width, height = img.size
1158
- new_width = int(width * scale_factor)
1159
- new_height = int(height * scale_factor)
1160
- new_img = Image.new("RGB", (width, height), (0, 0, 0))
1161
- img_resized = img.resize((new_width, new_height))
1162
- position = ((width - new_width) // 2, (height - new_height) // 2)
1163
- new_img.paste(img_resized, position)
1164
- new_img.save(output_path)
 
1
+ import os
2
+ import pickle
3
+ import random
4
+ import time
5
+ import warnings
6
+ from itertools import combinations
7
+ from pathlib import Path
8
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
9
+ from datasets import load_dataset
10
+
11
+ import cv2
12
+ import gradio as gr
13
+ import matplotlib.pyplot as plt
14
+ import numpy as np
15
+ import poselib
16
+ from PIL import Image
17
+
18
+ from ..hloc import (
19
+ DEVICE,
20
+ extract_features,
21
+ extractors,
22
+ logger,
23
+ match_dense,
24
+ match_features,
25
+ matchers,
26
+ DATASETS_REPO_ID,
27
+ )
28
+ from ..hloc.utils.base_model import dynamic_load
29
+ from .viz import display_keypoints, display_matches, fig2im, plot_images
30
+ from .modelcache import ARCSizeAwareModelCache as ModelCache
31
+
32
+ warnings.simplefilter("ignore")
33
+
34
+ ROOT = Path(__file__).parents[1]
35
+ # some default values
36
+ DEFAULT_SETTING_THRESHOLD = 0.1
37
+ DEFAULT_SETTING_MAX_FEATURES = 2000
38
+ DEFAULT_DEFAULT_KEYPOINT_THRESHOLD = 0.01
39
+ DEFAULT_ENABLE_RANSAC = True
40
+ DEFAULT_RANSAC_METHOD = "CV2_USAC_MAGSAC"
41
+ DEFAULT_RANSAC_REPROJ_THRESHOLD = 8
42
+ DEFAULT_RANSAC_CONFIDENCE = 0.9999
43
+ DEFAULT_RANSAC_MAX_ITER = 10000
44
+ DEFAULT_MIN_NUM_MATCHES = 4
45
+ DEFAULT_MATCHING_THRESHOLD = 0.2
46
+ DEFAULT_SETTING_GEOMETRY = "Homography"
47
+ GRADIO_VERSION = gr.__version__.split(".")[0]
48
+ MATCHER_ZOO = None
49
+
50
+
51
+ model_cache = ModelCache()
52
+
53
+
54
+ def load_config(config_name: str) -> Dict[str, Any]:
55
+ """
56
+ Load a YAML configuration file.
57
+
58
+ Args:
59
+ config_name: The path to the YAML configuration file.
60
+
61
+ Returns:
62
+ The configuration dictionary, with string keys and arbitrary values.
63
+ """
64
+ import yaml
65
+
66
+ with open(config_name, "r") as stream:
67
+ try:
68
+ config: Dict[str, Any] = yaml.safe_load(stream)
69
+ except yaml.YAMLError as exc:
70
+ logger.error(exc)
+ config = {}
71
+ return config
72
+
73
+
74
+ def get_matcher_zoo(
75
+ matcher_zoo: Dict[str, Dict[str, Union[str, bool]]],
76
+ ) -> Dict[str, Dict[str, Union[Callable, bool]]]:
77
+ """
78
+ Restore matcher configurations from a dictionary.
79
+
80
+ Args:
81
+ matcher_zoo: A dictionary with the matcher configurations,
82
+ where the configuration is a dictionary as loaded from a YAML file.
83
+
84
+ Returns:
85
+ A dictionary with the matcher configurations, where the configuration is
86
+ a function or a function instead of a string.
87
+ """
88
+ matcher_zoo_restored = {}
89
+ for k, v in matcher_zoo.items():
90
+ matcher_zoo_restored[k] = parse_match_config(v)
91
+ return matcher_zoo_restored
92
+
93
+
94
+ def parse_match_config(conf):
95
+ if conf["dense"]:
96
+ return {
97
+ "matcher": match_dense.confs.get(conf["matcher"]),
98
+ "dense": True,
99
+ "info": conf.get("info", {}),
100
+ }
101
+ else:
102
+ return {
103
+ "feature": extract_features.confs.get(conf["feature"]),
104
+ "matcher": match_features.confs.get(conf["matcher"]),
105
+ "dense": False,
106
+ "info": conf.get("info", {}),
107
+ }
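# A minimal sketch of a matcher_zoo entry as it would look after load_config();
# the entry name and the "superpoint_aachen" / "superpoint-lightglue" config
# keys are assumptions for illustration, not taken from the shipped YAML.
example_zoo = {
    "superpoint+lightglue": {
        "dense": False,
        "feature": "superpoint_aachen",
        "matcher": "superpoint-lightglue",
        "info": {"name": "SuperPoint+LightGlue", "efficiency": "high"},
    }
}
restored = get_matcher_zoo(example_zoo)
# restored["superpoint+lightglue"]["feature"] / ["matcher"] now hold the config
# dicts looked up in extract_features.confs / match_features.confs (or None
# when a key is unknown).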
108
+
109
+
110
+ def get_model(match_conf: Dict[str, Any]):
111
+ """
112
+ Load a matcher model from the provided configuration.
113
+
114
+ Args:
115
+ match_conf: A dictionary containing the model configuration.
116
+
117
+ Returns:
118
+ A matcher model instance.
119
+ """
120
+ Model = dynamic_load(matchers, match_conf["model"]["name"])
121
+ model = Model(match_conf["model"]).eval().to(DEVICE)
122
+ return model
123
+
124
+
125
+ def get_feature_model(conf: Dict[str, Dict[str, Any]]):
126
+ """
127
+ Load a feature extraction model from the provided configuration.
128
+
129
+ Args:
130
+ conf: A dictionary containing the model configuration.
131
+
132
+ Returns:
133
+ A feature extraction model instance.
134
+ """
135
+ Model = dynamic_load(extractors, conf["model"]["name"])
136
+ model = Model(conf["model"]).eval().to(DEVICE)
137
+ return model
138
+
139
+
140
+ def download_example_images(repo_id, output_dir):
141
+ logger.info(f"Download example dataset from huggingface: {repo_id}")
142
+ dataset = load_dataset(repo_id)
143
+ Path(output_dir).mkdir(parents=True, exist_ok=True)
144
+ for example in dataset["train"]: # Assuming the dataset is in the "train" split
145
+ file_path = example["path"]
146
+ image = example["image"] # Access the PIL.Image object directly
147
+ full_path = os.path.join(output_dir, file_path)
148
+ Path(os.path.dirname(full_path)).mkdir(parents=True, exist_ok=True)
149
+ image.save(full_path)
150
+ logger.info(f"Images saved to {output_dir} successfully.")
151
+ return Path(output_dir)
152
+
153
+
154
+ def gen_examples(data_root: Path):
155
+ random.seed(1)
156
+ example_matchers = [
157
+ "disk+lightglue",
158
+ "xfeat(sparse)",
159
+ "dedode",
160
+ "loftr",
161
+ "disk",
162
+ "RoMa",
163
+ "d2net",
164
+ "aspanformer",
165
+ "topicfm",
166
+ "superpoint+superglue",
167
+ "superpoint+lightglue",
168
+ "superpoint+mnn",
169
+ "disk",
170
+ ]
171
+ data_root = Path(data_root)
172
+ if not Path(data_root).exists():
173
+ try:
174
+ download_example_images(DATASETS_REPO_ID, data_root)
175
+ except Exception as e:
176
+ logger.error(f"download_example_images error : {e}")
177
+ data_root = ROOT / "datasets"
178
+ if not Path(data_root / "sacre_coeur/mapping").exists():
179
+ download_example_images(DATASETS_REPO_ID, data_root)
180
+
181
+ def distribute_elements(A, B):
182
+ new_B = np.array(B, copy=True).flatten()
183
+ np.random.shuffle(new_B)
184
+ new_B = np.resize(new_B, len(A))
185
+ np.random.shuffle(new_B)
186
+ return new_B.tolist()
187
+
188
+ # normal examples
189
+ def gen_images_pairs(count: int = 5):
190
+ path = str(data_root / "sacre_coeur/mapping")
191
+ imgs_list = [
192
+ os.path.join(path, file)
193
+ for file in os.listdir(path)
194
+ if file.lower().endswith((".jpg", ".jpeg", ".png"))
195
+ ]
196
+ pairs = list(combinations(imgs_list, 2))
197
+ if len(pairs) < count:
198
+ count = len(pairs)
199
+ selected = random.sample(range(len(pairs)), count)
200
+ return [pairs[i] for i in selected]
201
+
202
+ # rotated examples
203
+ def gen_rot_image_pairs(count: int = 5):
204
+ path = data_root / "sacre_coeur/mapping"
205
+ path_rot = data_root / "sacre_coeur/mapping_rot"
206
+ rot_list = [45, 180, 90, 225, 270]
207
+ pairs = []
208
+ for file in os.listdir(path):
209
+ if file.lower().endswith((".jpg", ".jpeg", ".png")):
210
+ for rot in rot_list:
211
+ file_rot = "{}_rot{}.jpg".format(Path(file).stem, rot)
212
+ if (path_rot / file_rot).exists():
213
+ pairs.append(
214
+ [
215
+ path / file,
216
+ path_rot / file_rot,
217
+ ]
218
+ )
219
+ if len(pairs) < count:
220
+ count = len(pairs)
221
+ selected = random.sample(range(len(pairs)), count)
222
+ return [pairs[i] for i in selected]
223
+
224
+ def gen_scale_image_pairs(count: int = 5):
225
+ path = data_root / "sacre_coeur/mapping"
226
+ path_scale = data_root / "sacre_coeur/mapping_scale"
227
+ scale_list = [0.3, 0.5]
228
+ pairs = []
229
+ for file in os.listdir(path):
230
+ if file.lower().endswith((".jpg", ".jpeg", ".png")):
231
+ for scale in scale_list:
232
+ file_scale = "{}_scale{}.jpg".format(Path(file).stem, scale)
233
+ if (path_scale / file_scale).exists():
234
+ pairs.append(
235
+ [
236
+ path / file,
237
+ path_scale / file_scale,
238
+ ]
239
+ )
240
+ if len(pairs) < count:
241
+ count = len(pairs)
242
+ selected = random.sample(range(len(pairs)), count)
243
+ return [pairs[i] for i in selected]
244
+
245
+ # extremely hard examples
246
+ def gen_image_pairs_wxbs(count: int = None):
247
+ prefix = "wxbs_benchmark/.WxBS/v1.1"
248
+ wxbs_path = data_root / prefix
249
+ pairs = []
250
+ for catg in os.listdir(wxbs_path):
251
+ catg_path = wxbs_path / catg
252
+ if not catg_path.is_dir():
253
+ continue
254
+ for scene in os.listdir(catg_path):
255
+ scene_path = catg_path / scene
256
+ if not scene_path.is_dir():
257
+ continue
258
+ img1_path = scene_path / "01.png"
259
+ img2_path = scene_path / "02.png"
260
+ if img1_path.exists() and img2_path.exists():
261
+ pairs.append([str(img1_path), str(img2_path)])
262
+ return pairs
263
+
264
+ # image pair path
265
+ pairs = gen_images_pairs()
266
+ pairs += gen_rot_image_pairs()
267
+ pairs += gen_scale_image_pairs()
268
+ pairs += gen_image_pairs_wxbs()
269
+
270
+ match_setting_threshold = DEFAULT_SETTING_THRESHOLD
271
+ match_setting_max_features = DEFAULT_SETTING_MAX_FEATURES
272
+ detect_keypoints_threshold = DEFAULT_DEFAULT_KEYPOINT_THRESHOLD
273
+ ransac_method = DEFAULT_RANSAC_METHOD
274
+ ransac_reproj_threshold = DEFAULT_RANSAC_REPROJ_THRESHOLD
275
+ ransac_confidence = DEFAULT_RANSAC_CONFIDENCE
276
+ ransac_max_iter = DEFAULT_RANSAC_MAX_ITER
277
+ input_lists = []
278
+ dist_examples = distribute_elements(pairs, example_matchers)
279
+ for pair, mt in zip(pairs, dist_examples):
280
+ input_lists.append(
281
+ [
282
+ pair[0],
283
+ pair[1],
284
+ match_setting_threshold,
285
+ match_setting_max_features,
286
+ detect_keypoints_threshold,
287
+ mt,
288
+ # enable_ransac,
289
+ ransac_method,
290
+ ransac_reproj_threshold,
291
+ ransac_confidence,
292
+ ransac_max_iter,
293
+ ]
294
+ )
295
+ return input_lists
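# Usage sketch: each row returned by gen_examples() is one Gradio example,
# ordered to match the UI inputs:
# [img0_path, img1_path, match_threshold, max_features, keypoint_threshold,
#  matcher_name, ransac_method, reproj_threshold, confidence, max_iter]
example_rows = gen_examples(ROOT / "datasets")  # may download the dataset
assert len(example_rows[0]) == 10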
296
+
297
+
298
+ def set_null_pred(feature_type: str, pred: dict):
299
+ if feature_type == "KEYPOINT":
300
+ pred["mmkeypoints0_orig"] = np.array([])
301
+ pred["mmkeypoints1_orig"] = np.array([])
302
+ pred["mmconf"] = np.array([])
303
+ elif feature_type == "LINE":
304
+ pred["mline_keypoints0_orig"] = np.array([])
305
+ pred["mline_keypoints1_orig"] = np.array([])
306
+ pred["H"] = None
307
+ pred["geom_info"] = {}
308
+ return pred
309
+
310
+
311
+ def _filter_matches_opencv(
312
+ kp0: np.ndarray,
313
+ kp1: np.ndarray,
314
+ method: int = cv2.RANSAC,
315
+ reproj_threshold: float = 3.0,
316
+ confidence: float = 0.99,
317
+ max_iter: int = 2000,
318
+ geometry_type: str = "Homography",
319
+ ) -> Tuple[np.ndarray, np.ndarray]:
320
+ """
321
+ Filters matches between two sets of keypoints using OpenCV's findHomography or findFundamentalMat, depending on geometry_type.
322
+
323
+ Args:
324
+ kp0 (np.ndarray): Array of keypoints from the first image.
325
+ kp1 (np.ndarray): Array of keypoints from the second image.
326
+ method (int, optional): OpenCV robust estimation method. Defaults to cv2.RANSAC.
327
+ reproj_threshold (float, optional): RANSAC reprojection threshold. Defaults to 3.0.
328
+ confidence (float, optional): RANSAC confidence. Defaults to 0.99.
329
+ max_iter (int, optional): RANSAC maximum iterations. Defaults to 2000.
330
+ geometry_type (str, optional): Type of geometry. Defaults to "Homography".
331
+
332
+ Returns:
333
+ Tuple[np.ndarray, np.ndarray]: estimated matrix (homography or fundamental) and boolean inlier mask.
334
+ """
335
+ if geometry_type == "Homography":
336
+ try:
337
+ M, mask = cv2.findHomography(
338
+ kp0,
339
+ kp1,
340
+ method=method,
341
+ ransacReprojThreshold=reproj_threshold,
342
+ confidence=confidence,
343
+ maxIters=max_iter,
344
+ )
345
+ except cv2.error:
346
+ logger.error("compute findHomography error, len(kp0): {}".format(len(kp0)))
347
+ return None, None
348
+ elif geometry_type == "Fundamental":
349
+ try:
350
+ M, mask = cv2.findFundamentalMat(
351
+ kp0,
352
+ kp1,
353
+ method=method,
354
+ ransacReprojThreshold=reproj_threshold,
355
+ confidence=confidence,
356
+ maxIters=max_iter,
357
+ )
358
+ except cv2.error:
359
+ logger.error(
360
+ "compute findFundamentalMat error, len(kp0): {}".format(len(kp0))
361
+ )
362
+ return None, None
363
+ mask = np.array(mask.ravel().astype("bool"), dtype="bool")
364
+ return M, mask
365
+
366
+
367
+ def _filter_matches_poselib(
368
+ kp0: np.ndarray,
369
+ kp1: np.ndarray,
370
+ method: int = None, # not used
371
+ reproj_threshold: float = 3,
372
+ confidence: float = 0.99,
373
+ max_iter: int = 2000,
374
+ geometry_type: str = "Homography",
375
+ ) -> Tuple[np.ndarray, np.ndarray]:
376
+ """
377
+ Filters matches between two sets of keypoints using the poselib library.
378
+
379
+ Args:
380
+ kp0 (np.ndarray): Array of keypoints from the first image.
381
+ kp1 (np.ndarray): Array of keypoints from the second image.
382
+ method (int, optional): unused; kept for signature compatibility with the OpenCV variant.
383
+ reproj_threshold (float, optional): RANSAC reprojection threshold. Defaults to 3.
384
+ confidence (float, optional): RANSAC confidence. Defaults to 0.99.
385
+ max_iter (int, optional): RANSAC maximum iterations. Defaults to 2000.
386
+ geometry_type (str, optional): Type of geometry. Defaults to "Homography".
387
+
388
+ Returns:
389
+ Tuple[np.ndarray, np.ndarray]: estimated model matrix and boolean inlier mask.
390
+ """
391
+ ransac_options = {
392
+ "max_iterations": max_iter,
393
+ # "min_iterations": min_iter,
394
+ "success_prob": confidence,
395
+ "max_reproj_error": reproj_threshold,
396
+ # "progressive_sampling": args.sampler.lower() == 'prosac'
397
+ }
398
+
399
+ if geometry_type == "Homography":
400
+ M, info = poselib.estimate_homography(kp0, kp1, ransac_options)
401
+ elif geometry_type == "Fundamental":
402
+ M, info = poselib.estimate_fundamental(kp0, kp1, ransac_options)
403
+ else:
404
+ raise NotImplementedError
405
+
406
+ return M, np.array(info["inliers"])
407
+
408
+
409
+ def proc_ransac_matches(
410
+ mkpts0: np.ndarray,
411
+ mkpts1: np.ndarray,
412
+ ransac_method: str = DEFAULT_RANSAC_METHOD,
413
+ ransac_reproj_threshold: float = 3.0,
414
+ ransac_confidence: float = 0.99,
415
+ ransac_max_iter: int = 2000,
416
+ geometry_type: str = "Homography",
417
+ ):
418
+ if ransac_method.startswith("CV2"):
419
+ logger.info(f"ransac_method: {ransac_method}, geometry_type: {geometry_type}")
420
+ return _filter_matches_opencv(
421
+ mkpts0,
422
+ mkpts1,
423
+ ransac_zoo[ransac_method],
424
+ ransac_reproj_threshold,
425
+ ransac_confidence,
426
+ ransac_max_iter,
427
+ geometry_type,
428
+ )
429
+ elif ransac_method.startswith("POSELIB"):
430
+ logger.info(f"ransac_method: {ransac_method}, geometry_type: {geometry_type}")
431
+ return _filter_matches_poselib(
432
+ mkpts0,
433
+ mkpts1,
434
+ None,
435
+ ransac_reproj_threshold,
436
+ ransac_confidence,
437
+ ransac_max_iter,
438
+ geometry_type,
439
+ )
440
+ else:
441
+ raise NotImplementedError
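# A minimal sketch of the backend dispatch in proc_ransac_matches(); the
# correspondences below are synthetic and exactly homography-related, so both
# the OpenCV and the poselib estimator should keep nearly all of them as
# inliers (meant to run after the module is fully loaded, since ransac_zoo is
# defined further down).
rng = np.random.default_rng(0)
pts0 = rng.uniform(0, 480, size=(40, 2))
H_true = np.array([[1.0, 0.02, 5.0], [0.01, 1.0, -3.0], [0.0, 0.0, 1.0]])
pts0_h = np.concatenate([pts0, np.ones((40, 1))], axis=1) @ H_true.T
pts1 = pts0_h[:, :2] / pts0_h[:, 2:]

H_cv, inl_cv = proc_ransac_matches(pts0, pts1, "CV2_USAC_MAGSAC", 8, 0.9999, 10000, "Homography")
H_pl, inl_pl = proc_ransac_matches(pts0, pts1, "POSELIB", 8, 0.9999, 10000, "Homography")
print(int(inl_cv.sum()), "vs", int(inl_pl.sum()), "inliers")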
442
+
443
+
444
+ def filter_matches(
445
+ pred: Dict[str, Any],
446
+ ransac_method: str = DEFAULT_RANSAC_METHOD,
447
+ ransac_reproj_threshold: float = DEFAULT_RANSAC_REPROJ_THRESHOLD,
448
+ ransac_confidence: float = DEFAULT_RANSAC_CONFIDENCE,
449
+ ransac_max_iter: int = DEFAULT_RANSAC_MAX_ITER,
450
+ ransac_estimator: str = None,
451
+ ):
452
+ """
453
+ Filter matches using RANSAC. If keypoints are available, filter by keypoints.
454
+ If lines are available, filter by lines. If both keypoints and lines are
455
+ available, filter by keypoints.
456
+
457
+ Args:
458
+ pred (Dict[str, Any]): dict of matches, including original keypoints.
459
+ ransac_method (str, optional): RANSAC method. Defaults to DEFAULT_RANSAC_METHOD.
460
+ ransac_reproj_threshold (float, optional): RANSAC reprojection threshold. Defaults to DEFAULT_RANSAC_REPROJ_THRESHOLD.
461
+ ransac_confidence (float, optional): RANSAC confidence. Defaults to DEFAULT_RANSAC_CONFIDENCE.
462
+ ransac_max_iter (int, optional): RANSAC maximum iterations. Defaults to DEFAULT_RANSAC_MAX_ITER.
463
+
464
+ Returns:
465
+ Dict[str, Any]: filtered matches.
466
+ """
467
+ mkpts0: Optional[np.ndarray] = None
468
+ mkpts1: Optional[np.ndarray] = None
469
+ feature_type: Optional[str] = None
470
+ if "mkeypoints0_orig" in pred.keys() and "mkeypoints1_orig" in pred.keys():
471
+ mkpts0 = pred["mkeypoints0_orig"]
472
+ mkpts1 = pred["mkeypoints1_orig"]
473
+ feature_type = "KEYPOINT"
474
+ elif (
475
+ "line_keypoints0_orig" in pred.keys() and "line_keypoints1_orig" in pred.keys()
476
+ ):
477
+ mkpts0 = pred["line_keypoints0_orig"]
478
+ mkpts1 = pred["line_keypoints1_orig"]
479
+ feature_type = "LINE"
480
+ else:
481
+ return set_null_pred(feature_type, pred)
482
+ if mkpts0 is None or mkpts1 is None:
483
+ return set_null_pred(feature_type, pred)
484
+ if ransac_method not in ransac_zoo.keys():
485
+ ransac_method = DEFAULT_RANSAC_METHOD
486
+
487
+ if len(mkpts0) < DEFAULT_MIN_NUM_MATCHES:
488
+ return set_null_pred(feature_type, pred)
489
+
490
+ geom_info = compute_geometry(
491
+ pred,
492
+ ransac_method=ransac_method,
493
+ ransac_reproj_threshold=ransac_reproj_threshold,
494
+ ransac_confidence=ransac_confidence,
495
+ ransac_max_iter=ransac_max_iter,
496
+ )
497
+
498
+ if "Homography" in geom_info.keys():
499
+ mask = geom_info["mask_h"]
500
+ if feature_type == "KEYPOINT":
501
+ pred["mmkeypoints0_orig"] = mkpts0[mask]
502
+ pred["mmkeypoints1_orig"] = mkpts1[mask]
503
+ pred["mmconf"] = pred["mconf"][mask]
504
+ elif feature_type == "LINE":
505
+ pred["mline_keypoints0_orig"] = mkpts0[mask]
506
+ pred["mline_keypoints1_orig"] = mkpts1[mask]
507
+ pred["H"] = np.array(geom_info["Homography"])
508
+ else:
509
+ set_null_pred(feature_type, pred)
510
+ # do not show mask
511
+ geom_info.pop("mask_h", None)
512
+ geom_info.pop("mask_f", None)
513
+ pred["geom_info"] = geom_info
514
+ return pred
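# Sketch of the dict contract filter_matches() relies on for the keypoint
# path: the "*_orig" keys hold pixel coordinates in the original images and
# "mconf" holds per-match confidences. All values are synthetic placeholders
# (pts0 / pts1 from the sketch above), not real matcher output.
pred_demo = {
    "image0_orig": np.zeros((480, 640, 3), dtype=np.uint8),
    "image1_orig": np.zeros((480, 640, 3), dtype=np.uint8),
    "mkeypoints0_orig": pts0,
    "mkeypoints1_orig": pts1,
    "mconf": np.ones(len(pts0)),
}
pred_demo = filter_matches(pred_demo, ransac_method="CV2_USAC_MAGSAC")
# RANSAC-surviving matches now live under "mmkeypoints*_orig" and the
# estimated models under pred_demo["geom_info"].
print(len(pred_demo["mmkeypoints0_orig"]), sorted(pred_demo["geom_info"].keys()))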
515
+
516
+
517
+ def compute_geometry(
518
+ pred: Dict[str, Any],
519
+ ransac_method: str = DEFAULT_RANSAC_METHOD,
520
+ ransac_reproj_threshold: float = DEFAULT_RANSAC_REPROJ_THRESHOLD,
521
+ ransac_confidence: float = DEFAULT_RANSAC_CONFIDENCE,
522
+ ransac_max_iter: int = DEFAULT_RANSAC_MAX_ITER,
523
+ ) -> Dict[str, List[float]]:
524
+ """
525
+ Compute geometric information of matches, including Fundamental matrix,
526
+ Homography matrix, and rectification matrices (if available).
527
+
528
+ Args:
529
+ pred (Dict[str, Any]): dict of matches, including original keypoints.
530
+ ransac_method (str, optional): RANSAC method. Defaults to DEFAULT_RANSAC_METHOD.
531
+ ransac_reproj_threshold (float, optional): RANSAC reprojection threshold. Defaults to DEFAULT_RANSAC_REPROJ_THRESHOLD.
532
+ ransac_confidence (float, optional): RANSAC confidence. Defaults to DEFAULT_RANSAC_CONFIDENCE.
533
+ ransac_max_iter (int, optional): RANSAC maximum iterations. Defaults to DEFAULT_RANSAC_MAX_ITER.
534
+
535
+ Returns:
536
+ Dict[str, List[float]]: geometric information in the form of a dict.
537
+ """
538
+ mkpts0: Optional[np.ndarray] = None
539
+ mkpts1: Optional[np.ndarray] = None
540
+
541
+ if "mkeypoints0_orig" in pred.keys() and "mkeypoints1_orig" in pred.keys():
542
+ mkpts0 = pred["mkeypoints0_orig"]
543
+ mkpts1 = pred["mkeypoints1_orig"]
544
+ elif (
545
+ "line_keypoints0_orig" in pred.keys() and "line_keypoints1_orig" in pred.keys()
546
+ ):
547
+ mkpts0 = pred["line_keypoints0_orig"]
548
+ mkpts1 = pred["line_keypoints1_orig"]
549
+
550
+ if mkpts0 is not None and mkpts1 is not None:
551
+ if len(mkpts0) < 2 * DEFAULT_MIN_NUM_MATCHES:
552
+ return {}
553
+ geo_info: Dict[str, List[float]] = {}
554
+
555
+ F, mask_f = proc_ransac_matches(
556
+ mkpts0,
557
+ mkpts1,
558
+ ransac_method,
559
+ ransac_reproj_threshold,
560
+ ransac_confidence,
561
+ ransac_max_iter,
562
+ geometry_type="Fundamental",
563
+ )
564
+
565
+ if F is not None:
566
+ geo_info["Fundamental"] = F.tolist()
567
+ geo_info["mask_f"] = mask_f
568
+ H, mask_h = proc_ransac_matches(
569
+ mkpts0,
570
+ mkpts1,
571
+ ransac_method,
572
+ ransac_reproj_threshold,
573
+ ransac_confidence,
574
+ ransac_max_iter,
575
+ geometry_type="Homography",
576
+ )
577
+
578
+ h0, w0, _ = pred["image0_orig"].shape
579
+ if H is not None:
580
+ geo_info["Homography"] = H.tolist()
581
+ geo_info["mask_h"] = mask_h
582
+ try:
583
+ _, H1, H2 = cv2.stereoRectifyUncalibrated(
584
+ mkpts0.reshape(-1, 2),
585
+ mkpts1.reshape(-1, 2),
586
+ F,
587
+ imgSize=(w0, h0),
588
+ )
589
+ geo_info["H1"] = H1.tolist()
590
+ geo_info["H2"] = H2.tolist()
591
+ except cv2.error as e:
592
+ logger.error(f"StereoRectifyUncalibrated failed, skip! error: {e}")
593
+ return geo_info
594
+ else:
595
+ return {}
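# Direction convention, as a small worked example: the "Homography" stored by
# compute_geometry() is estimated from mkpts0 -> mkpts1, so it maps image0
# pixels into image1; warping image1 back onto image0 (as wrap_images() does)
# therefore uses its inverse. pred_demo comes from the sketch above.
geo_demo = compute_geometry(pred_demo)
if "Homography" in geo_demo:
    H_demo = np.array(geo_demo["Homography"])
    p0 = np.array([100.0, 200.0, 1.0])   # a pixel of image0, homogeneous
    p1 = H_demo @ p0
    p1 = p1[:2] / p1[2]                  # the corresponding pixel of image1
    img1_on_img0 = cv2.warpPerspective(
        pred_demo["image1_orig"], np.linalg.inv(H_demo), (640, 480)
    )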
596
+
597
+
598
+ def wrap_images(
599
+ img0: np.ndarray,
600
+ img1: np.ndarray,
601
+ geo_info: Optional[Dict[str, List[float]]],
602
+ geom_type: str,
603
+ ) -> Tuple[Optional[np.ndarray], Optional[np.ndarray]]:
604
+ """
605
+ Warps the images based on the geometric transformation used to align them.
606
+
607
+ Args:
608
+ img0: numpy array representing the first image.
609
+ img1: numpy array representing the second image.
610
+ geo_info: dictionary containing the geometric transformation information.
611
+ geom_type: type of geometric transformation used to align the images.
612
+
613
+ Returns:
614
+ A tuple containing the rendered side-by-side visualization and the warped second image, or (None, None) on failure.
615
+ """
616
+ h0, w0, _ = img0.shape
617
+ h1, w1, _ = img1.shape
618
+ if geo_info is not None and len(geo_info) != 0:
619
+ rectified_image0 = img0
620
+ rectified_image1 = None
621
+ if "Homography" not in geo_info:
622
+ logger.warning(f"{geom_type} not exist, maybe too less matches")
623
+ return None, None
624
+
625
+ H = np.array(geo_info["Homography"])
626
+
627
+ title: List[str] = []
628
+ if geom_type == "Homography":
629
+ H_inv = np.linalg.inv(H)
630
+ rectified_image1 = cv2.warpPerspective(img1, H_inv, (w0, h0))
631
+ title = ["Image 0", "Image 1 - warped"]
632
+ elif geom_type == "Fundamental":
633
+ if geom_type not in geo_info:
634
+ logger.warning(f"{geom_type} not exist, maybe too less matches")
635
+ return None, None
636
+ else:
637
+ H1, H2 = np.array(geo_info["H1"]), np.array(geo_info["H2"])
638
+ rectified_image0 = cv2.warpPerspective(img0, H1, (w0, h0))
639
+ rectified_image1 = cv2.warpPerspective(img1, H2, (w1, h1))
640
+ title = ["Image 0 - warped", "Image 1 - warped"]
641
+ else:
642
+ print("Error: Unknown geometry type")
643
+ fig = plot_images(
644
+ [rectified_image0.squeeze(), rectified_image1.squeeze()],
645
+ title,
646
+ dpi=300,
647
+ )
648
+ return fig2im(fig), rectified_image1
649
+ else:
650
+ return None, None
651
+
652
+
653
+ def generate_warp_images(
654
+ input_image0: np.ndarray,
655
+ input_image1: np.ndarray,
656
+ matches_info: Dict[str, Any],
657
+ choice: str,
658
+ ) -> Tuple[Optional[np.ndarray], Optional[np.ndarray]]:
659
+ """
660
+ Generates the warped image visualization according to the chosen geometric transformation.
661
+
662
+ Args:
663
+ input_image0: First input image.
664
+ input_image1: Second input image.
665
+ matches_info: Dictionary containing information about the matches.
666
+ choice: Type of geometric transformation to use ('Homography' or 'Fundamental') or 'No' to disable.
667
+
668
+ Returns:
669
+ A tuple containing the visualization of the aligned image pair and the warped image.
670
+ """
671
+ if (
672
+ matches_info is None
673
+ or len(matches_info) < 1
674
+ or "geom_info" not in matches_info.keys()
675
+ ):
676
+ return None, None
677
+ geom_info = matches_info["geom_info"]
678
+ warped_image = None
679
+ if choice != "No":
680
+ wrapped_image_pair, warped_image = wrap_images(
681
+ input_image0, input_image1, geom_info, choice
682
+ )
683
+ return wrapped_image_pair, warped_image
684
+ else:
685
+ return None, None
686
+
687
+
688
+ def send_to_match(state_cache: Dict[str, Any]):
689
+ """
690
+ Send the state cache to the match function.
691
+
692
+ Args:
693
+ state_cache (Dict[str, Any]): Current state of the app.
694
+
695
+ Returns:
696
+ A tuple (image0_orig, wrapped_image) read from the cache, or (None, None) if the cache is empty.
697
+ """
698
+ if state_cache:
699
+ return (
700
+ state_cache["image0_orig"],
701
+ state_cache["wrapped_image"],
702
+ )
703
+ else:
704
+ return None, None
705
+
706
+
707
+ def run_ransac(
708
+ state_cache: Dict[str, Any],
709
+ choice_geometry_type: str,
710
+ ransac_method: str = DEFAULT_RANSAC_METHOD,
711
+ ransac_reproj_threshold: int = DEFAULT_RANSAC_REPROJ_THRESHOLD,
712
+ ransac_confidence: float = DEFAULT_RANSAC_CONFIDENCE,
713
+ ransac_max_iter: int = DEFAULT_RANSAC_MAX_ITER,
714
+ ) -> Tuple[Optional[np.ndarray], Optional[Dict[str, int]]]:
715
+ """
716
+ Run RANSAC matches and return the output images and the number of matches.
717
+
718
+ Args:
719
+ state_cache (Dict[str, Any]): Current state of the app, including the matches.
720
+ ransac_method (str, optional): RANSAC method. Defaults to DEFAULT_RANSAC_METHOD.
721
+ ransac_reproj_threshold (int, optional): RANSAC reprojection threshold. Defaults to DEFAULT_RANSAC_REPROJ_THRESHOLD.
722
+ ransac_confidence (float, optional): RANSAC confidence. Defaults to DEFAULT_RANSAC_CONFIDENCE.
723
+ ransac_max_iter (int, optional): RANSAC maximum iterations. Defaults to DEFAULT_RANSAC_MAX_ITER.
724
+
725
+ Returns:
726
+ Tuple containing the RANSAC match visualization, the raw/RANSAC match counts, the warped image visualization, and the path of the dumped state cache.
727
+ """
728
+ if not state_cache:
729
+ logger.info("Run Match first before Rerun RANSAC")
730
+ gr.Warning("Run Match first before Rerun RANSAC")
731
+ return None, None
732
+ t1 = time.time()
733
+ logger.info(
734
+ f"Run RANSAC matches using: {ransac_method} with threshold: {ransac_reproj_threshold}"
735
+ )
736
+ logger.info(
737
+ f"Run RANSAC matches using: {ransac_confidence} with iter: {ransac_max_iter}"
738
+ )
739
+ # if enable_ransac:
740
+ filter_matches(
741
+ state_cache,
742
+ ransac_method=ransac_method,
743
+ ransac_reproj_threshold=ransac_reproj_threshold,
744
+ ransac_confidence=ransac_confidence,
745
+ ransac_max_iter=ransac_max_iter,
746
+ )
747
+ logger.info(f"RANSAC matches done using: {time.time()-t1:.3f}s")
748
+ t1 = time.time()
749
+
750
+ # plot images with ransac matches
751
+ titles = [
752
+ "Image 0 - Ransac matched keypoints",
753
+ "Image 1 - Ransac matched keypoints",
754
+ ]
755
+ output_matches_ransac, num_matches_ransac = display_matches(
756
+ state_cache, titles=titles, tag="KPTS_RANSAC"
757
+ )
758
+ logger.info(f"Display matches done using: {time.time()-t1:.3f}s")
759
+ t1 = time.time()
760
+
761
+ # compute warp images
762
+ output_wrapped, warped_image = generate_warp_images(
763
+ state_cache["image0_orig"],
764
+ state_cache["image1_orig"],
765
+ state_cache,
766
+ choice_geometry_type,
767
+ )
768
+ plt.close("all")
769
+
770
+ num_matches_raw = state_cache["num_matches_raw"]
771
+ state_cache["wrapped_image"] = warped_image
772
+
773
+ # tmp_state_cache = tempfile.NamedTemporaryFile(suffix='.pkl', delete=False)
774
+ tmp_state_cache = "output.pkl"
775
+ with open(tmp_state_cache, "wb") as f:
776
+ pickle.dump(state_cache, f)
777
+
778
+ logger.info("Dump results done!")
779
+
780
+ return (
781
+ output_matches_ransac,
782
+ {
783
+ "num_matches_raw": num_matches_raw,
784
+ "num_matches_ransac": num_matches_ransac,
785
+ },
786
+ output_wrapped,
787
+ tmp_state_cache,
788
+ )
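# Sketch of re-running RANSAC offline from a previously dumped state cache;
# "output.pkl" is the file written by run_matching() / run_ransac() above.
with open("output.pkl", "rb") as f:
    cached_state = pickle.load(f)
ransac_vis, counts, warp_vis, cache_path = run_ransac(
    cached_state,
    choice_geometry_type="Homography",
    ransac_method="POSELIB",
    ransac_reproj_threshold=4,
)
print(counts)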
789
+
790
+
791
+ def generate_fake_outputs(
792
+ output_keypoints,
793
+ output_matches_raw,
794
+ output_matches_ransac,
795
+ match_conf,
796
+ extract_conf,
797
+ pred,
798
+ ):
799
+ return (
800
+ output_keypoints,
801
+ output_matches_raw,
802
+ output_matches_ransac,
803
+ {},
804
+ {
805
+ "match_conf": match_conf,
806
+ "extractor_conf": extract_conf,
807
+ },
808
+ {
809
+ "geom_info": pred.get("geom_info", {}),
810
+ },
811
+ None,
812
+ None,
813
+ None,
814
+ )
815
+
816
+
817
+ def run_matching(
818
+ image0: np.ndarray,
819
+ image1: np.ndarray,
820
+ match_threshold: float,
821
+ extract_max_keypoints: int,
822
+ keypoint_threshold: float,
823
+ key: str,
824
+ ransac_method: str = DEFAULT_RANSAC_METHOD,
825
+ ransac_reproj_threshold: int = DEFAULT_RANSAC_REPROJ_THRESHOLD,
826
+ ransac_confidence: float = DEFAULT_RANSAC_CONFIDENCE,
827
+ ransac_max_iter: int = DEFAULT_RANSAC_MAX_ITER,
828
+ choice_geometry_type: str = DEFAULT_SETTING_GEOMETRY,
829
+ matcher_zoo: Dict[str, Any] = None,
830
+ force_resize: bool = False,
831
+ image_width: int = 640,
832
+ image_height: int = 480,
833
+ use_cached_model: bool = True,
834
+ ) -> Tuple[
835
+ np.ndarray,
836
+ np.ndarray,
837
+ np.ndarray,
838
+ Dict[str, int],
839
+ Dict[str, Dict[str, Any]],
840
+ Dict[str, Dict[str, float]],
841
+ np.ndarray,
842
+ ]:
843
+ """Match two images using the given parameters.
844
+
845
+ Args:
846
+ image0 (np.ndarray): RGB image 0.
847
+ image1 (np.ndarray): RGB image 1.
848
+ match_threshold (float): match threshold.
849
+ extract_max_keypoints (int): number of keypoints to extract.
850
+ keypoint_threshold (float): keypoint threshold.
851
+ key (str): key of the model to use.
852
+ ransac_method (str, optional): RANSAC method to use.
853
+ ransac_reproj_threshold (int, optional): RANSAC reprojection threshold.
854
+ ransac_confidence (float, optional): RANSAC confidence level.
855
+ ransac_max_iter (int, optional): RANSAC maximum number of iterations.
856
+ choice_geometry_type (str, optional): setting of geometry estimation.
857
+ matcher_zoo (Dict[str, Any], optional): matcher zoo. Defaults to None.
858
+ force_resize (bool, optional): force resize. Defaults to False.
859
+ image_width (int, optional): image width. Defaults to 640.
860
+ image_height (int, optional): image height. Defaults to 480.
861
+ use_cached_model (bool, optional): use cached model. Defaults to True.
862
+
863
+ Returns:
864
+ tuple:
865
+ - output_keypoints (np.ndarray): image with keypoints.
866
+ - output_matches_raw (np.ndarray): image with raw matches.
867
+ - output_matches_ransac (np.ndarray): image with RANSAC matches.
868
+ - num_matches (Dict[str, int]): number of raw and RANSAC matches.
869
+ - configs (Dict[str, Dict[str, Any]]): match and feature extraction configs.
870
+ - geom_info (Dict[str, Dict[str, float]]): geometry information.
871
+ - output_wrapped (np.ndarray): wrapped images.
872
+ """
873
+ # image0 and image1 are in RGB mode
874
+ if image0 is None or image1 is None:
875
+ logger.error(
876
+ "Error: No images found! Please upload two images or select an example."
877
+ )
878
+ raise gr.Error(
879
+ "Error: No images found! Please upload two images or select an example."
880
+ )
881
+ # init output
882
+ output_keypoints = None
883
+ output_matches_raw = None
884
+ output_matches_ransac = None
885
+
886
+ t0 = time.time()
887
+ model = matcher_zoo[key]
888
+ match_conf = model["matcher"]
889
+ # update match config
890
+ match_conf["model"]["match_threshold"] = match_threshold
891
+ match_conf["model"]["max_keypoints"] = extract_max_keypoints
892
+ cache_key = "{}_{}".format(key, match_conf["model"]["name"])
893
+
894
+ efficiency = model["info"].get("efficiency", "high")
895
+ if efficiency == "low":
896
+ gr.Warning(
897
+ "Matcher {} is time-consuming, please wait for a while".format(
898
+ model["info"].get("name", "unknown")
899
+ )
900
+ )
901
+
902
+ if use_cached_model:
903
+ # because of the model cache, we need to update the config
904
+ matcher = model_cache.load_model(cache_key, get_model, match_conf)
905
+ matcher.conf["max_keypoints"] = extract_max_keypoints
906
+ matcher.conf["match_threshold"] = match_threshold
907
+ logger.info(f"Loaded cached model {cache_key}")
908
+ else:
909
+ matcher = get_model(match_conf)
910
+ logger.info(f"Loading model using: {time.time()-t0:.3f}s")
911
+ t1 = time.time()
912
+ yield generate_fake_outputs(
913
+ output_keypoints, output_matches_raw, output_matches_ransac, match_conf, {}, {}
914
+ )
915
+
916
+ if model["dense"]:
917
+ if not match_conf["preprocessing"].get("force_resize", False):
918
+ match_conf["preprocessing"]["force_resize"] = force_resize
919
+ else:
920
+ logger.info("preprocessing is already resized")
921
+ if force_resize:
922
+ match_conf["preprocessing"]["height"] = image_height
923
+ match_conf["preprocessing"]["width"] = image_width
924
+ logger.info(f"Force resize to {image_width}x{image_height}")
925
+
926
+ pred = match_dense.match_images(
927
+ matcher, image0, image1, match_conf["preprocessing"], device=DEVICE
928
+ )
929
+ del matcher
930
+ extract_conf = None
931
+ else:
932
+ extract_conf = model["feature"]
933
+ # update extract config
934
+ extract_conf["model"]["max_keypoints"] = extract_max_keypoints
935
+ extract_conf["model"]["keypoint_threshold"] = keypoint_threshold
936
+ cache_key = "{}_{}".format(key, extract_conf["model"]["name"])
937
+
938
+ if use_cached_model:
939
+ extractor = model_cache.load_model(
940
+ cache_key, get_feature_model, extract_conf
941
+ )
942
+ # because of the model cache, we need to update the config
943
+ extractor.conf["max_keypoints"] = extract_max_keypoints
944
+ extractor.conf["keypoint_threshold"] = keypoint_threshold
945
+ logger.info(f"Loaded cached model {cache_key}")
946
+ else:
947
+ extractor = get_feature_model(extract_conf)
948
+
949
+ if not extract_conf["preprocessing"].get("force_resize", False):
950
+ extract_conf["preprocessing"]["force_resize"] = force_resize
951
+ else:
952
+ logger.info("preprocessing is already resized")
953
+ if force_resize:
954
+ extract_conf["preprocessing"]["height"] = image_height
955
+ extract_conf["preprocessing"]["width"] = image_width
956
+ logger.info(f"Force resize to {image_width}x{image_height}")
957
+
958
+ pred0 = extract_features.extract(
959
+ extractor, image0, extract_conf["preprocessing"]
960
+ )
961
+ pred1 = extract_features.extract(
962
+ extractor, image1, extract_conf["preprocessing"]
963
+ )
964
+ pred = match_features.match_images(matcher, pred0, pred1)
965
+ del extractor
966
+ # gr.Info(
967
+ # f"Matching images done using: {time.time()-t1:.3f}s",
968
+ # )
969
+ logger.info(f"Matching images done using: {time.time()-t1:.3f}s")
970
+ t1 = time.time()
971
+
972
+ # plot images with keypoints
973
+ titles = [
974
+ "Image 0 - Keypoints",
975
+ "Image 1 - Keypoints",
976
+ ]
977
+ output_keypoints = display_keypoints(pred, titles=titles)
978
+ yield generate_fake_outputs(
979
+ output_keypoints,
980
+ output_matches_raw,
981
+ output_matches_ransac,
982
+ match_conf,
983
+ extract_conf,
984
+ pred,
985
+ )
986
+
987
+ # plot images with raw matches
988
+ titles = [
989
+ "Image 0 - Raw matched keypoints",
990
+ "Image 1 - Raw matched keypoints",
991
+ ]
992
+ output_matches_raw, num_matches_raw = display_matches(pred, titles=titles)
993
+ yield generate_fake_outputs(
994
+ output_keypoints,
995
+ output_matches_raw,
996
+ output_matches_ransac,
997
+ match_conf,
998
+ extract_conf,
999
+ pred,
1000
+ )
1001
+
1002
+ # if enable_ransac:
1003
+ filter_matches(
1004
+ pred,
1005
+ ransac_method=ransac_method,
1006
+ ransac_reproj_threshold=ransac_reproj_threshold,
1007
+ ransac_confidence=ransac_confidence,
1008
+ ransac_max_iter=ransac_max_iter,
1009
+ )
1010
+
1011
+ # gr.Info(f"RANSAC matches done using: {time.time()-t1:.3f}s")
1012
+ logger.info(f"RANSAC matches done using: {time.time()-t1:.3f}s")
1013
+ t1 = time.time()
1014
+
1015
+ # plot images with ransac matches
1016
+ titles = [
1017
+ "Image 0 - Ransac matched keypoints",
1018
+ "Image 1 - Ransac matched keypoints",
1019
+ ]
1020
+ output_matches_ransac, num_matches_ransac = display_matches(
1021
+ pred, titles=titles, tag="KPTS_RANSAC"
1022
+ )
1023
+ yield generate_fake_outputs(
1024
+ output_keypoints,
1025
+ output_matches_raw,
1026
+ output_matches_ransac,
1027
+ match_conf,
1028
+ extract_conf,
1029
+ pred,
1030
+ )
1031
+
1032
+ # gr.Info(f"Display matches done using: {time.time()-t1:.3f}s")
1033
+ logger.info(f"Display matches done using: {time.time()-t1:.3f}s")
1034
+ t1 = time.time()
1035
+ # plot wrapped images
1036
+ output_wrapped, warped_image = generate_warp_images(
1037
+ pred["image0_orig"],
1038
+ pred["image1_orig"],
1039
+ pred,
1040
+ choice_geometry_type,
1041
+ )
1042
+ plt.close("all")
1043
+ # gr.Info(f"In summary, total time: {time.time()-t0:.3f}s")
1044
+ logger.info(f"TOTAL time: {time.time()-t0:.3f}s")
1045
+
1046
+ state_cache = pred
1047
+ state_cache["num_matches_raw"] = num_matches_raw
1048
+ state_cache["num_matches_ransac"] = num_matches_ransac
1049
+ state_cache["wrapped_image"] = warped_image
1050
+
1051
+ # tmp_state_cache = tempfile.NamedTemporaryFile(suffix='.pkl', delete=False)
1052
+ tmp_state_cache = "output.pkl"
1053
+ with open(tmp_state_cache, "wb") as f:
1054
+ pickle.dump(state_cache, f)
1055
+ logger.info("Dump results done!")
1056
+
1057
+ yield (
1058
+ output_keypoints,
1059
+ output_matches_raw,
1060
+ output_matches_ransac,
1061
+ {
1062
+ "num_raw_matches": num_matches_raw,
1063
+ "num_ransac_matches": num_matches_ransac,
1064
+ },
1065
+ {
1066
+ "match_conf": match_conf,
1067
+ "extractor_conf": extract_conf,
1068
+ },
1069
+ {
1070
+ "geom_info": pred.get("geom_info", {}),
1071
+ },
1072
+ output_wrapped,
1073
+ state_cache,
1074
+ tmp_state_cache,
1075
+ )
1076
+
1077
+
1078
+ # @ref: https://docs.opencv.org/4.x/d0/d74/md__build_4_x-contrib_docs-lin64_opencv_doc_tutorials_calib3d_usac.html
1079
+ # AND: https://opencv.org/blog/2021/06/09/evaluating-opencvs-new-ransacs
1080
+ ransac_zoo = {
1081
+ "POSELIB": "LO-RANSAC",
1082
+ "CV2_RANSAC": cv2.RANSAC,
1083
+ "CV2_USAC_MAGSAC": cv2.USAC_MAGSAC,
1084
+ "CV2_USAC_DEFAULT": cv2.USAC_DEFAULT,
1085
+ "CV2_USAC_FM_8PTS": cv2.USAC_FM_8PTS,
1086
+ "CV2_USAC_PROSAC": cv2.USAC_PROSAC,
1087
+ "CV2_USAC_FAST": cv2.USAC_FAST,
1088
+ "CV2_USAC_ACCURATE": cv2.USAC_ACCURATE,
1089
+ "CV2_USAC_PARALLEL": cv2.USAC_PARALLEL,
1090
+ }
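
The `CV2_*` values above are plain OpenCV flags, so they can be handed directly to OpenCV's robust estimators, while `"POSELIB"` maps to a string and is assumed to be dispatched to a separate poselib backend inside `filter_matches`. A minimal usage sketch, where `mkpts0`/`mkpts1` are hypothetical Nx2 arrays of matched keypoints:

import cv2
import numpy as np

def verify_homography(mkpts0: np.ndarray, mkpts1: np.ndarray,
                      ransac_method: str = "CV2_USAC_MAGSAC",
                      reproj_threshold: float = 8.0):
    # Look up the OpenCV flag; reject non-OpenCV backends such as "POSELIB".
    method = ransac_zoo[ransac_method]
    if not isinstance(method, int):
        raise ValueError(f"{ransac_method} is not an OpenCV RANSAC variant")
    H, mask = cv2.findHomography(mkpts0, mkpts1, method, reproj_threshold)
    inliers = mask.ravel().astype(bool) if mask is not None else None
    return H, inliers  # estimated homography and per-match inlier mask
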
1091
+
1092
+
1093
+ def rotate_image(input_path, degrees, output_path):
1094
+ img = Image.open(input_path)
1095
+ img_rotated = img.rotate(-degrees)
1096
+ img_rotated.save(output_path)
1097
+
1098
+
1099
+ def scale_image(input_path, scale_factor, output_path):
1100
+ img = Image.open(input_path)
1101
+ width, height = img.size
1102
+ new_width = int(width * scale_factor)
1103
+ new_height = int(height * scale_factor)
1104
+ new_img = Image.new("RGB", (width, height), (0, 0, 0))
1105
+ img_resized = img.resize((new_width, new_height))
1106
+ position = ((width - new_width) // 2, (height - new_height) // 2)
1107
+ new_img.paste(img_resized, position)
1108
+ new_img.save(output_path)
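
Both helpers above read an image from disk and write the transformed copy back out; a minimal, hedged usage example with hypothetical file names:

rotate_image("query.jpg", 90, "query_rot90.jpg")   # rotated 90 degrees clockwise (PIL rotates CCW, hence the negated angle above)
scale_image("query.jpg", 0.5, "query_half.jpg")    # content resized to 50% and centered on a black canvas of the original size
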
 
imcui/ui/viz.py CHANGED
@@ -1,481 +1,481 @@
1
- import typing
2
- from pathlib import Path
3
- from typing import Dict, List, Optional, Tuple, Union
4
-
5
- import cv2
6
- import matplotlib
7
- import matplotlib.pyplot as plt
8
- import numpy as np
9
- import seaborn as sns
10
-
11
- from ..hloc.utils.viz import add_text, plot_keypoints
12
-
13
- np.random.seed(1995)
14
- color_map = np.arange(100)
15
- np.random.shuffle(color_map)
16
-
17
-
18
- def plot_images(
19
- imgs: List[np.ndarray],
20
- titles: Optional[List[str]] = None,
21
- cmaps: Union[str, List[str]] = "gray",
22
- dpi: int = 100,
23
- size: Optional[int] = 5,
24
- pad: float = 0.5,
25
- ) -> plt.Figure:
26
- """Plot a set of images horizontally.
27
- Args:
28
- imgs: a list of NumPy or PyTorch images, RGB (H, W, 3) or mono (H, W).
29
- titles: a list of strings, as titles for each image.
30
- cmaps: colormaps for monochrome images. If a single string is given,
31
- it is used for all images.
32
- dpi: DPI of the figure.
33
- size: figure size in inches (width). If not provided, the figure
34
- size is determined automatically.
35
- pad: padding between subplots, in inches.
36
- Returns:
37
- The created figure.
38
- """
39
- n = len(imgs)
40
- if not isinstance(cmaps, list):
41
- cmaps = [cmaps] * n
42
- figsize = (size * n, size * 6 / 5) if size is not None else None
43
- fig, ax = plt.subplots(1, n, figsize=figsize, dpi=dpi)
44
-
45
- if n == 1:
46
- ax = [ax]
47
- for i in range(n):
48
- ax[i].imshow(imgs[i], cmap=plt.get_cmap(cmaps[i]))
49
- ax[i].get_yaxis().set_ticks([])
50
- ax[i].get_xaxis().set_ticks([])
51
- ax[i].set_axis_off()
52
- for spine in ax[i].spines.values(): # remove frame
53
- spine.set_visible(False)
54
- if titles:
55
- ax[i].set_title(titles[i])
56
- fig.tight_layout(pad=pad)
57
- return fig
58
-
59
-
60
- def plot_color_line_matches(
61
- lines: List[np.ndarray],
62
- correct_matches: Optional[np.ndarray] = None,
63
- lw: float = 2.0,
64
- indices: Tuple[int, int] = (0, 1),
65
- ) -> matplotlib.figure.Figure:
66
- """Plot line matches for existing images with multiple colors.
67
-
68
- Args:
69
- lines: List of ndarrays of size (N, 2, 2) representing line segments.
70
- correct_matches: Optional bool array of size (N,) indicating correct
71
- matches. If not None, display wrong matches with a low alpha.
72
- lw: Line width as float pixels.
73
- indices: Indices of the images to draw the matches on.
74
-
75
- Returns:
76
- The modified matplotlib figure.
77
- """
78
- n_lines = lines[0].shape[0]
79
- colors = sns.color_palette("husl", n_colors=n_lines)
80
- np.random.shuffle(colors)
81
- alphas = np.ones(n_lines)
82
- if correct_matches is not None:
83
- alphas[~np.array(correct_matches)] = 0.2
84
-
85
- fig = plt.gcf()
86
- ax = typing.cast(List[matplotlib.axes.Axes], fig.axes)
87
- assert len(ax) > max(indices)
88
- axes = [ax[i] for i in indices]
89
- fig.canvas.draw()
90
-
91
- # Plot the lines
92
- for a, l in zip(axes, lines): # noqa: E741
93
- # Transform the points into the figure coordinate system
94
- transFigure = fig.transFigure.inverted()
95
- endpoint0 = transFigure.transform(a.transData.transform(l[:, 0]))
96
- endpoint1 = transFigure.transform(a.transData.transform(l[:, 1]))
97
- fig.lines += [
98
- matplotlib.lines.Line2D(
99
- (endpoint0[i, 0], endpoint1[i, 0]),
100
- (endpoint0[i, 1], endpoint1[i, 1]),
101
- zorder=1,
102
- transform=fig.transFigure,
103
- c=colors[i],
104
- alpha=alphas[i],
105
- linewidth=lw,
106
- )
107
- for i in range(n_lines)
108
- ]
109
-
110
- return fig
111
-
112
-
113
- def make_matching_figure(
114
- img0: np.ndarray,
115
- img1: np.ndarray,
116
- mkpts0: np.ndarray,
117
- mkpts1: np.ndarray,
118
- color: np.ndarray,
119
- titles: Optional[List[str]] = None,
120
- kpts0: Optional[np.ndarray] = None,
121
- kpts1: Optional[np.ndarray] = None,
122
- text: List[str] = [],
123
- dpi: int = 75,
124
- path: Optional[Path] = None,
125
- pad: float = 0.0,
126
- ) -> Optional[plt.Figure]:
127
- """Draw image pair with matches.
128
-
129
- Args:
130
- img0: image0 as HxWx3 numpy array.
131
- img1: image1 as HxWx3 numpy array.
132
- mkpts0: matched points in image0 as Nx2 numpy array.
133
- mkpts1: matched points in image1 as Nx2 numpy array.
134
- color: colors for the matches as Nx4 numpy array.
135
- titles: titles for the two subplots.
136
- kpts0: keypoints in image0 as Kx2 numpy array.
137
- kpts1: keypoints in image1 as Kx2 numpy array.
138
- text: list of strings to display in the top-left corner of the image.
139
- dpi: dots per inch of the saved figure.
140
- path: if not None, save the figure to this path.
141
- pad: padding around the image as a fraction of the image size.
142
-
143
- Returns:
144
- The matplotlib Figure object if path is None.
145
- """
146
- # draw image pair
147
- fig, axes = plt.subplots(1, 2, figsize=(10, 6), dpi=dpi)
148
- axes[0].imshow(img0) # , cmap='gray')
149
- axes[1].imshow(img1) # , cmap='gray')
150
- for i in range(2): # clear all frames
151
- axes[i].get_yaxis().set_ticks([])
152
- axes[i].get_xaxis().set_ticks([])
153
- for spine in axes[i].spines.values():
154
- spine.set_visible(False)
155
- if titles is not None:
156
- axes[i].set_title(titles[i])
157
-
158
- plt.tight_layout(pad=pad)
159
-
160
- if kpts0 is not None:
161
- assert kpts1 is not None
162
- axes[0].scatter(kpts0[:, 0], kpts0[:, 1], c="w", s=5)
163
- axes[1].scatter(kpts1[:, 0], kpts1[:, 1], c="w", s=5)
164
-
165
- # draw matches
166
- if mkpts0.shape[0] != 0 and mkpts1.shape[0] != 0 and mkpts0.shape == mkpts1.shape:
167
- fig.canvas.draw()
168
- transFigure = fig.transFigure.inverted()
169
- fkpts0 = transFigure.transform(axes[0].transData.transform(mkpts0))
170
- fkpts1 = transFigure.transform(axes[1].transData.transform(mkpts1))
171
- fig.lines = [
172
- matplotlib.lines.Line2D(
173
- (fkpts0[i, 0], fkpts1[i, 0]),
174
- (fkpts0[i, 1], fkpts1[i, 1]),
175
- transform=fig.transFigure,
176
- c=color[i],
177
- linewidth=2,
178
- )
179
- for i in range(len(mkpts0))
180
- ]
181
-
182
- # freeze the axes to prevent the transform to change
183
- axes[0].autoscale(enable=False)
184
- axes[1].autoscale(enable=False)
185
-
186
- axes[0].scatter(mkpts0[:, 0], mkpts0[:, 1], c=color[..., :3], s=4)
187
- axes[1].scatter(mkpts1[:, 0], mkpts1[:, 1], c=color[..., :3], s=4)
188
-
189
- # put txts
190
- txt_color = "k" if img0[:100, :200].mean() > 200 else "w"
191
- fig.text(
192
- 0.01,
193
- 0.99,
194
- "\n".join(text),
195
- transform=fig.axes[0].transAxes,
196
- fontsize=15,
197
- va="top",
198
- ha="left",
199
- color=txt_color,
200
- )
201
-
202
- # save or return figure
203
- if path:
204
- plt.savefig(str(path), bbox_inches="tight", pad_inches=0)
205
- plt.close()
206
- else:
207
- return fig
208
-
209
-
210
- def error_colormap(err: np.ndarray, thr: float, alpha: float = 1.0) -> np.ndarray:
211
- """
212
- Create a colormap based on the error values.
213
-
214
- Args:
215
- err: Error values as a numpy array of shape (N,).
216
- thr: Threshold value for the error.
217
- alpha: Alpha value for the colormap, between 0 and 1.
218
-
219
- Returns:
220
- Colormap as a numpy array of shape (N, 4) with values in [0, 1].
221
- """
222
- assert alpha <= 1.0 and alpha > 0, f"Invaid alpha value: {alpha}"
223
- x = 1 - np.clip(err / (thr * 2), 0, 1)
224
- return np.clip(
225
- np.stack([2 - x * 2, x * 2, np.zeros_like(x), np.ones_like(x) * alpha], -1),
226
- 0,
227
- 1,
228
- )
229
-
230
-
231
- def fig2im(fig: matplotlib.figure.Figure) -> np.ndarray:
232
- """
233
- Convert a matplotlib figure to a numpy array with RGB values.
234
-
235
- Args:
236
- fig: A matplotlib figure.
237
-
238
- Returns:
239
- A numpy array with shape (height, width, 3) and dtype uint8 containing
240
- the RGB values of the figure.
241
- """
242
- fig.canvas.draw()
243
- (width, height) = fig.canvas.get_width_height()
244
- buf_ndarray = np.frombuffer(fig.canvas.tostring_rgb(), dtype="u1")
245
- return buf_ndarray.reshape(height, width, 3)
246
-
247
-
248
- def draw_matches_core(
249
- mkpts0: List[np.ndarray],
250
- mkpts1: List[np.ndarray],
251
- img0: np.ndarray,
252
- img1: np.ndarray,
253
- conf: np.ndarray,
254
- titles: Optional[List[str]] = None,
255
- texts: Optional[List[str]] = None,
256
- dpi: int = 150,
257
- path: Optional[str] = None,
258
- pad: float = 0.5,
259
- ) -> np.ndarray:
260
- """
261
- Draw matches between two images.
262
-
263
- Args:
264
- mkpts0: List of matches from the first image, with shape (N, 2)
265
- mkpts1: List of matches from the second image, with shape (N, 2)
266
- img0: First image, with shape (H, W, 3)
267
- img1: Second image, with shape (H, W, 3)
268
- conf: Confidence values for the matches, with shape (N,)
269
- titles: Optional list of title strings for the plot
270
- dpi: DPI for the saved image
271
- path: Optional path to save the image to. If None, the image is not saved.
272
- pad: Padding between subplots
273
-
274
- Returns:
275
- The figure as a numpy array with shape (height, width, 3) and dtype uint8
276
- containing the RGB values of the figure.
277
- """
278
- thr = 0.5
279
- color = error_colormap(1 - conf, thr, alpha=0.1)
280
- text = [
281
- # "image name",
282
- f"#Matches: {len(mkpts0)}",
283
- ]
284
- if path:
285
- fig2im(
286
- make_matching_figure(
287
- img0,
288
- img1,
289
- mkpts0,
290
- mkpts1,
291
- color,
292
- titles=titles,
293
- text=text,
294
- path=path,
295
- dpi=dpi,
296
- pad=pad,
297
- )
298
- )
299
- else:
300
- return fig2im(
301
- make_matching_figure(
302
- img0,
303
- img1,
304
- mkpts0,
305
- mkpts1,
306
- color,
307
- titles=titles,
308
- text=text,
309
- pad=pad,
310
- dpi=dpi,
311
- )
312
- )
313
-
314
-
315
- def draw_image_pairs(
316
- img0: np.ndarray,
317
- img1: np.ndarray,
318
- text: List[str] = [],
319
- dpi: int = 75,
320
- path: Optional[str] = None,
321
- pad: float = 0.5,
322
- ) -> np.ndarray:
323
- """Draw image pair horizontally.
324
-
325
- Args:
326
- img0: First image, with shape (H, W, 3)
327
- img1: Second image, with shape (H, W, 3)
328
- text: List of strings to print. Each string is a new line.
329
- dpi: DPI of the figure.
330
- path: Path to save the image to. If None, the image is not saved and
331
- the function returns the figure as a numpy array with shape
332
- (height, width, 3) and dtype uint8 containing the RGB values of the
333
- figure.
334
- pad: Padding between subplots
335
-
336
- Returns:
337
- The figure as a numpy array with shape (height, width, 3) and dtype uint8
338
- containing the RGB values of the figure, or None if path is not None.
339
- """
340
- # draw image pair
341
- fig, axes = plt.subplots(1, 2, figsize=(10, 6), dpi=dpi)
342
- axes[0].imshow(img0) # , cmap='gray')
343
- axes[1].imshow(img1) # , cmap='gray')
344
- for i in range(2): # clear all frames
345
- axes[i].get_yaxis().set_ticks([])
346
- axes[i].get_xaxis().set_ticks([])
347
- for spine in axes[i].spines.values():
348
- spine.set_visible(False)
349
- plt.tight_layout(pad=pad)
350
-
351
- # put txts
352
- txt_color = "k" if img0[:100, :200].mean() > 200 else "w"
353
- fig.text(
354
- 0.01,
355
- 0.99,
356
- "\n".join(text),
357
- transform=fig.axes[0].transAxes,
358
- fontsize=15,
359
- va="top",
360
- ha="left",
361
- color=txt_color,
362
- )
363
-
364
- # save or return figure
365
- if path:
366
- plt.savefig(str(path), bbox_inches="tight", pad_inches=0)
367
- plt.close()
368
- else:
369
- return fig2im(fig)
370
-
371
-
372
- def display_keypoints(pred: dict, titles: List[str] = []):
373
- img0 = pred["image0_orig"]
374
- img1 = pred["image1_orig"]
375
- output_keypoints = plot_images([img0, img1], titles=titles, dpi=300)
376
- if "keypoints0_orig" in pred.keys() and "keypoints1_orig" in pred.keys():
377
- plot_keypoints([pred["keypoints0_orig"], pred["keypoints1_orig"]])
378
- text = (
379
- f"# keypoints0: {len(pred['keypoints0_orig'])} \n"
380
- + f"# keypoints1: {len(pred['keypoints1_orig'])}"
381
- )
382
- add_text(0, text, fs=15)
383
- output_keypoints = fig2im(output_keypoints)
384
- return output_keypoints
385
-
386
-
387
- def display_matches(
388
- pred: Dict[str, np.ndarray],
389
- titles: List[str] = [],
390
- texts: List[str] = [],
391
- dpi: int = 300,
392
- tag: str = "KPTS_RAW", # KPTS_RAW, KPTS_RANSAC, LINES_RAW, LINES_RANSAC,
393
- ) -> Tuple[np.ndarray, int]:
394
- """
395
- Displays the matches between two images.
396
-
397
- Args:
398
- pred: Dictionary containing the original images and the matches.
399
- titles: Optional titles for the plot.
400
- dpi: Resolution of the plot.
401
-
402
- Returns:
403
- The resulting concatenated plot and the number of inliers.
404
- """
405
- img0 = pred["image0_orig"]
406
- img1 = pred["image1_orig"]
407
- num_inliers = 0
408
- KPTS0_KEY = None
409
- KPTS1_KEY = None
410
- confid = None
411
- if tag == "KPTS_RAW":
412
- KPTS0_KEY = "mkeypoints0_orig"
413
- KPTS1_KEY = "mkeypoints1_orig"
414
- if "mconf" in pred:
415
- confid = pred["mconf"]
416
- elif tag == "KPTS_RANSAC":
417
- KPTS0_KEY = "mmkeypoints0_orig"
418
- KPTS1_KEY = "mmkeypoints1_orig"
419
- if "mmconf" in pred:
420
- confid = pred["mmconf"]
421
- else:
422
- # TODO: LINES_RAW, LINES_RANSAC
423
- raise ValueError(f"Unknown tag: {tag}")
424
- # draw raw matches
425
- if (
426
- KPTS0_KEY in pred
427
- and KPTS1_KEY in pred
428
- and pred[KPTS0_KEY] is not None
429
- and pred[KPTS1_KEY] is not None
430
- ): # draw ransac matches
431
- mkpts0 = pred[KPTS0_KEY]
432
- mkpts1 = pred[KPTS1_KEY]
433
- num_inliers = len(mkpts0)
434
- if confid is None:
435
- confid = np.ones(len(mkpts0))
436
- fig_mkpts = draw_matches_core(
437
- mkpts0,
438
- mkpts1,
439
- img0,
440
- img1,
441
- confid,
442
- dpi=dpi,
443
- titles=titles,
444
- texts=texts,
445
- )
446
- fig = fig_mkpts
447
- elif (
448
- "line0_orig" in pred
449
- and "line1_orig" in pred
450
- and pred["line0_orig"] is not None
451
- and pred["line1_orig"] is not None
452
- # and (tag == "LINES_RAW" or tag == "LINES_RANSAC")
453
- ):
454
- # lines
455
- mtlines0 = pred["line0_orig"]
456
- mtlines1 = pred["line1_orig"]
457
- num_inliers = len(mtlines0)
458
- fig_lines = plot_images(
459
- [img0.squeeze(), img1.squeeze()],
460
- ["Image 0 - matched lines", "Image 1 - matched lines"],
461
- dpi=300,
462
- )
463
- fig_lines = plot_color_line_matches([mtlines0, mtlines1], lw=2)
464
- fig_lines = fig2im(fig_lines)
465
-
466
- # keypoints
467
- mkpts0 = pred.get("line_keypoints0_orig")
468
- mkpts1 = pred.get("line_keypoints1_orig")
469
- fig = None
470
- if mkpts0 is not None and mkpts1 is not None:
471
- num_inliers = len(mkpts0)
472
- if "mconf" in pred:
473
- mconf = pred["mconf"]
474
- else:
475
- mconf = np.ones(len(mkpts0))
476
- fig_mkpts = draw_matches_core(mkpts0, mkpts1, img0, img1, mconf, dpi=300)
477
- fig_lines = cv2.resize(fig_lines, (fig_mkpts.shape[1], fig_mkpts.shape[0]))
478
- fig = np.concatenate([fig_mkpts, fig_lines], axis=0)
479
- else:
480
- fig = fig_lines
481
- return fig, num_inliers
 
1
+ import typing
2
+ from pathlib import Path
3
+ from typing import Dict, List, Optional, Tuple, Union
4
+
5
+ import cv2
6
+ import matplotlib
7
+ import matplotlib.pyplot as plt
8
+ import numpy as np
9
+ import seaborn as sns
10
+
11
+ from ..hloc.utils.viz import add_text, plot_keypoints
12
+
13
+ np.random.seed(1995)
14
+ color_map = np.arange(100)
15
+ np.random.shuffle(color_map)
16
+
17
+
18
+ def plot_images(
19
+ imgs: List[np.ndarray],
20
+ titles: Optional[List[str]] = None,
21
+ cmaps: Union[str, List[str]] = "gray",
22
+ dpi: int = 100,
23
+ size: Optional[int] = 5,
24
+ pad: float = 0.5,
25
+ ) -> plt.Figure:
26
+ """Plot a set of images horizontally.
27
+ Args:
28
+ imgs: a list of NumPy or PyTorch images, RGB (H, W, 3) or mono (H, W).
29
+ titles: a list of strings, as titles for each image.
30
+ cmaps: colormaps for monochrome images. If a single string is given,
31
+ it is used for all images.
32
+ dpi: DPI of the figure.
33
+ size: figure size in inches (width). If not provided, the figure
34
+ size is determined automatically.
35
+ pad: padding between subplots, in inches.
36
+ Returns:
37
+ The created figure.
38
+ """
39
+ n = len(imgs)
40
+ if not isinstance(cmaps, list):
41
+ cmaps = [cmaps] * n
42
+ figsize = (size * n, size * 6 / 5) if size is not None else None
43
+ fig, ax = plt.subplots(1, n, figsize=figsize, dpi=dpi)
44
+
45
+ if n == 1:
46
+ ax = [ax]
47
+ for i in range(n):
48
+ ax[i].imshow(imgs[i], cmap=plt.get_cmap(cmaps[i]))
49
+ ax[i].get_yaxis().set_ticks([])
50
+ ax[i].get_xaxis().set_ticks([])
51
+ ax[i].set_axis_off()
52
+ for spine in ax[i].spines.values(): # remove frame
53
+ spine.set_visible(False)
54
+ if titles:
55
+ ax[i].set_title(titles[i])
56
+ fig.tight_layout(pad=pad)
57
+ return fig
58
+
59
+
60
+ def plot_color_line_matches(
61
+ lines: List[np.ndarray],
62
+ correct_matches: Optional[np.ndarray] = None,
63
+ lw: float = 2.0,
64
+ indices: Tuple[int, int] = (0, 1),
65
+ ) -> matplotlib.figure.Figure:
66
+ """Plot line matches for existing images with multiple colors.
67
+
68
+ Args:
69
+ lines: List of ndarrays of size (N, 2, 2) representing line segments.
70
+ correct_matches: Optional bool array of size (N,) indicating correct
71
+ matches. If not None, display wrong matches with a low alpha.
72
+ lw: Line width as float pixels.
73
+ indices: Indices of the images to draw the matches on.
74
+
75
+ Returns:
76
+ The modified matplotlib figure.
77
+ """
78
+ n_lines = lines[0].shape[0]
79
+ colors = sns.color_palette("husl", n_colors=n_lines)
80
+ np.random.shuffle(colors)
81
+ alphas = np.ones(n_lines)
82
+ if correct_matches is not None:
83
+ alphas[~np.array(correct_matches)] = 0.2
84
+
85
+ fig = plt.gcf()
86
+ ax = typing.cast(List[matplotlib.axes.Axes], fig.axes)
87
+ assert len(ax) > max(indices)
88
+ axes = [ax[i] for i in indices]
89
+ fig.canvas.draw()
90
+
91
+ # Plot the lines
92
+ for a, l in zip(axes, lines): # noqa: E741
93
+ # Transform the points into the figure coordinate system
94
+ transFigure = fig.transFigure.inverted()
95
+ endpoint0 = transFigure.transform(a.transData.transform(l[:, 0]))
96
+ endpoint1 = transFigure.transform(a.transData.transform(l[:, 1]))
97
+ fig.lines += [
98
+ matplotlib.lines.Line2D(
99
+ (endpoint0[i, 0], endpoint1[i, 0]),
100
+ (endpoint0[i, 1], endpoint1[i, 1]),
101
+ zorder=1,
102
+ transform=fig.transFigure,
103
+ c=colors[i],
104
+ alpha=alphas[i],
105
+ linewidth=lw,
106
+ )
107
+ for i in range(n_lines)
108
+ ]
109
+
110
+ return fig
111
+
112
+
113
+ def make_matching_figure(
114
+ img0: np.ndarray,
115
+ img1: np.ndarray,
116
+ mkpts0: np.ndarray,
117
+ mkpts1: np.ndarray,
118
+ color: np.ndarray,
119
+ titles: Optional[List[str]] = None,
120
+ kpts0: Optional[np.ndarray] = None,
121
+ kpts1: Optional[np.ndarray] = None,
122
+ text: List[str] = [],
123
+ dpi: int = 75,
124
+ path: Optional[Path] = None,
125
+ pad: float = 0.0,
126
+ ) -> Optional[plt.Figure]:
127
+ """Draw image pair with matches.
128
+
129
+ Args:
130
+ img0: image0 as HxWx3 numpy array.
131
+ img1: image1 as HxWx3 numpy array.
132
+ mkpts0: matched points in image0 as Nx2 numpy array.
133
+ mkpts1: matched points in image1 as Nx2 numpy array.
134
+ color: colors for the matches as Nx4 numpy array.
135
+ titles: titles for the two subplots.
136
+ kpts0: keypoints in image0 as Kx2 numpy array.
137
+ kpts1: keypoints in image1 as Kx2 numpy array.
138
+ text: list of strings to display in the top-left corner of the image.
139
+ dpi: dots per inch of the saved figure.
140
+ path: if not None, save the figure to this path.
141
+ pad: padding around the image as a fraction of the image size.
142
+
143
+ Returns:
144
+ The matplotlib Figure object if path is None.
145
+ """
146
+ # draw image pair
147
+ fig, axes = plt.subplots(1, 2, figsize=(10, 6), dpi=dpi)
148
+ axes[0].imshow(img0) # , cmap='gray')
149
+ axes[1].imshow(img1) # , cmap='gray')
150
+ for i in range(2): # clear all frames
151
+ axes[i].get_yaxis().set_ticks([])
152
+ axes[i].get_xaxis().set_ticks([])
153
+ for spine in axes[i].spines.values():
154
+ spine.set_visible(False)
155
+ if titles is not None:
156
+ axes[i].set_title(titles[i])
157
+
158
+ plt.tight_layout(pad=pad)
159
+
160
+ if kpts0 is not None:
161
+ assert kpts1 is not None
162
+ axes[0].scatter(kpts0[:, 0], kpts0[:, 1], c="w", s=5)
163
+ axes[1].scatter(kpts1[:, 0], kpts1[:, 1], c="w", s=5)
164
+
165
+ # draw matches
166
+ if mkpts0.shape[0] != 0 and mkpts1.shape[0] != 0 and mkpts0.shape == mkpts1.shape:
167
+ fig.canvas.draw()
168
+ transFigure = fig.transFigure.inverted()
169
+ fkpts0 = transFigure.transform(axes[0].transData.transform(mkpts0))
170
+ fkpts1 = transFigure.transform(axes[1].transData.transform(mkpts1))
171
+ fig.lines = [
172
+ matplotlib.lines.Line2D(
173
+ (fkpts0[i, 0], fkpts1[i, 0]),
174
+ (fkpts0[i, 1], fkpts1[i, 1]),
175
+ transform=fig.transFigure,
176
+ c=color[i],
177
+ linewidth=2,
178
+ )
179
+ for i in range(len(mkpts0))
180
+ ]
181
+
182
+ # freeze the axes to prevent the transform from changing
183
+ axes[0].autoscale(enable=False)
184
+ axes[1].autoscale(enable=False)
185
+
186
+ axes[0].scatter(mkpts0[:, 0], mkpts0[:, 1], c=color[..., :3], s=4)
187
+ axes[1].scatter(mkpts1[:, 0], mkpts1[:, 1], c=color[..., :3], s=4)
188
+
189
+ # put txts
190
+ txt_color = "k" if img0[:100, :200].mean() > 200 else "w"
191
+ fig.text(
192
+ 0.01,
193
+ 0.99,
194
+ "\n".join(text),
195
+ transform=fig.axes[0].transAxes,
196
+ fontsize=15,
197
+ va="top",
198
+ ha="left",
199
+ color=txt_color,
200
+ )
201
+
202
+ # save or return figure
203
+ if path:
204
+ plt.savefig(str(path), bbox_inches="tight", pad_inches=0)
205
+ plt.close()
206
+ else:
207
+ return fig
208
+
209
+
210
+ def error_colormap(err: np.ndarray, thr: float, alpha: float = 1.0) -> np.ndarray:
211
+ """
212
+ Create a colormap based on the error values.
213
+
214
+ Args:
215
+ err: Error values as a numpy array of shape (N,).
216
+ thr: Threshold value for the error.
217
+ alpha: Alpha value for the colormap, between 0 and 1.
218
+
219
+ Returns:
220
+ Colormap as a numpy array of shape (N, 4) with values in [0, 1].
221
+ """
222
+ assert alpha <= 1.0 and alpha > 0, f"Invalid alpha value: {alpha}"
223
+ x = 1 - np.clip(err / (thr * 2), 0, 1)
224
+ return np.clip(
225
+ np.stack([2 - x * 2, x * 2, np.zeros_like(x), np.ones_like(x) * alpha], -1),
226
+ 0,
227
+ 1,
228
+ )
229
+
230
+
231
+ def fig2im(fig: matplotlib.figure.Figure) -> np.ndarray:
232
+ """
233
+ Convert a matplotlib figure to a numpy array with RGB values.
234
+
235
+ Args:
236
+ fig: A matplotlib figure.
237
+
238
+ Returns:
239
+ A numpy array with shape (height, width, 3) and dtype uint8 containing
240
+ the RGB values of the figure.
241
+ """
242
+ fig.canvas.draw()
243
+ (width, height) = fig.canvas.get_width_height()
244
+ buf_ndarray = np.frombuffer(fig.canvas.tostring_rgb(), dtype="u1")
245
+ return buf_ndarray.reshape(height, width, 3)
246
+
247
+
248
+ def draw_matches_core(
249
+ mkpts0: List[np.ndarray],
250
+ mkpts1: List[np.ndarray],
251
+ img0: np.ndarray,
252
+ img1: np.ndarray,
253
+ conf: np.ndarray,
254
+ titles: Optional[List[str]] = None,
255
+ texts: Optional[List[str]] = None,
256
+ dpi: int = 150,
257
+ path: Optional[str] = None,
258
+ pad: float = 0.5,
259
+ ) -> np.ndarray:
260
+ """
261
+ Draw matches between two images.
262
+
263
+ Args:
264
+ mkpts0: List of matches from the first image, with shape (N, 2)
265
+ mkpts1: List of matches from the second image, with shape (N, 2)
266
+ img0: First image, with shape (H, W, 3)
267
+ img1: Second image, with shape (H, W, 3)
268
+ conf: Confidence values for the matches, with shape (N,)
269
+ titles: Optional list of title strings for the plot
270
+ dpi: DPI for the saved image
271
+ path: Optional path to save the image to. If None, the image is not saved.
272
+ pad: Padding between subplots
273
+
274
+ Returns:
275
+ The figure as a numpy array with shape (height, width, 3) and dtype uint8
276
+ containing the RGB values of the figure.
277
+ """
278
+ thr = 0.5
279
+ color = error_colormap(1 - conf, thr, alpha=0.1)
280
+ text = [
281
+ # "image name",
282
+ f"#Matches: {len(mkpts0)}",
283
+ ]
284
+ if path:
285
+ fig2im(
286
+ make_matching_figure(
287
+ img0,
288
+ img1,
289
+ mkpts0,
290
+ mkpts1,
291
+ color,
292
+ titles=titles,
293
+ text=text,
294
+ path=path,
295
+ dpi=dpi,
296
+ pad=pad,
297
+ )
298
+ )
299
+ else:
300
+ return fig2im(
301
+ make_matching_figure(
302
+ img0,
303
+ img1,
304
+ mkpts0,
305
+ mkpts1,
306
+ color,
307
+ titles=titles,
308
+ text=text,
309
+ pad=pad,
310
+ dpi=dpi,
311
+ )
312
+ )
313
+
314
+
315
+ def draw_image_pairs(
316
+ img0: np.ndarray,
317
+ img1: np.ndarray,
318
+ text: List[str] = [],
319
+ dpi: int = 75,
320
+ path: Optional[str] = None,
321
+ pad: float = 0.5,
322
+ ) -> np.ndarray:
323
+ """Draw image pair horizontally.
324
+
325
+ Args:
326
+ img0: First image, with shape (H, W, 3)
327
+ img1: Second image, with shape (H, W, 3)
328
+ text: List of strings to print. Each string is a new line.
329
+ dpi: DPI of the figure.
330
+ path: Path to save the image to. If None, the image is not saved and
331
+ the function returns the figure as a numpy array with shape
332
+ (height, width, 3) and dtype uint8 containing the RGB values of the
333
+ figure.
334
+ pad: Padding between subplots
335
+
336
+ Returns:
337
+ The figure as a numpy array with shape (height, width, 3) and dtype uint8
338
+ containing the RGB values of the figure, or None if path is not None.
339
+ """
340
+ # draw image pair
341
+ fig, axes = plt.subplots(1, 2, figsize=(10, 6), dpi=dpi)
342
+ axes[0].imshow(img0) # , cmap='gray')
343
+ axes[1].imshow(img1) # , cmap='gray')
344
+ for i in range(2): # clear all frames
345
+ axes[i].get_yaxis().set_ticks([])
346
+ axes[i].get_xaxis().set_ticks([])
347
+ for spine in axes[i].spines.values():
348
+ spine.set_visible(False)
349
+ plt.tight_layout(pad=pad)
350
+
351
+ # put txts
352
+ txt_color = "k" if img0[:100, :200].mean() > 200 else "w"
353
+ fig.text(
354
+ 0.01,
355
+ 0.99,
356
+ "\n".join(text),
357
+ transform=fig.axes[0].transAxes,
358
+ fontsize=15,
359
+ va="top",
360
+ ha="left",
361
+ color=txt_color,
362
+ )
363
+
364
+ # save or return figure
365
+ if path:
366
+ plt.savefig(str(path), bbox_inches="tight", pad_inches=0)
367
+ plt.close()
368
+ else:
369
+ return fig2im(fig)
370
+
371
+
372
+ def display_keypoints(pred: dict, titles: List[str] = []):
373
+ img0 = pred["image0_orig"]
374
+ img1 = pred["image1_orig"]
375
+ output_keypoints = plot_images([img0, img1], titles=titles, dpi=300)
376
+ if "keypoints0_orig" in pred.keys() and "keypoints1_orig" in pred.keys():
377
+ plot_keypoints([pred["keypoints0_orig"], pred["keypoints1_orig"]])
378
+ text = (
379
+ f"# keypoints0: {len(pred['keypoints0_orig'])} \n"
380
+ + f"# keypoints1: {len(pred['keypoints1_orig'])}"
381
+ )
382
+ add_text(0, text, fs=15)
383
+ output_keypoints = fig2im(output_keypoints)
384
+ return output_keypoints
385
+
386
+
387
+ def display_matches(
388
+ pred: Dict[str, np.ndarray],
389
+ titles: List[str] = [],
390
+ texts: List[str] = [],
391
+ dpi: int = 300,
392
+ tag: str = "KPTS_RAW", # KPTS_RAW, KPTS_RANSAC, LINES_RAW, LINES_RANSAC,
393
+ ) -> Tuple[np.ndarray, int]:
394
+ """
395
+ Displays the matches between two images.
396
+
397
+ Args:
398
+ pred: Dictionary containing the original images and the matches.
399
+ titles: Optional titles for the plot.
400
+ dpi: Resolution of the plot.
401
+
402
+ Returns:
403
+ The resulting concatenated plot and the number of inliers.
404
+ """
405
+ img0 = pred["image0_orig"]
406
+ img1 = pred["image1_orig"]
407
+ num_inliers = 0
408
+ KPTS0_KEY = None
409
+ KPTS1_KEY = None
410
+ confid = None
411
+ if tag == "KPTS_RAW":
412
+ KPTS0_KEY = "mkeypoints0_orig"
413
+ KPTS1_KEY = "mkeypoints1_orig"
414
+ if "mconf" in pred:
415
+ confid = pred["mconf"]
416
+ elif tag == "KPTS_RANSAC":
417
+ KPTS0_KEY = "mmkeypoints0_orig"
418
+ KPTS1_KEY = "mmkeypoints1_orig"
419
+ if "mmconf" in pred:
420
+ confid = pred["mmconf"]
421
+ else:
422
+ # TODO: LINES_RAW, LINES_RANSAC
423
+ raise ValueError(f"Unknown tag: {tag}")
424
+ # draw raw matches
425
+ if (
426
+ KPTS0_KEY in pred
427
+ and KPTS1_KEY in pred
428
+ and pred[KPTS0_KEY] is not None
429
+ and pred[KPTS1_KEY] is not None
430
+ ): # draw ransac matches
431
+ mkpts0 = pred[KPTS0_KEY]
432
+ mkpts1 = pred[KPTS1_KEY]
433
+ num_inliers = len(mkpts0)
434
+ if confid is None:
435
+ confid = np.ones(len(mkpts0))
436
+ fig_mkpts = draw_matches_core(
437
+ mkpts0,
438
+ mkpts1,
439
+ img0,
440
+ img1,
441
+ confid,
442
+ dpi=dpi,
443
+ titles=titles,
444
+ texts=texts,
445
+ )
446
+ fig = fig_mkpts
447
+ elif (
448
+ "line0_orig" in pred
449
+ and "line1_orig" in pred
450
+ and pred["line0_orig"] is not None
451
+ and pred["line1_orig"] is not None
452
+ # and (tag == "LINES_RAW" or tag == "LINES_RANSAC")
453
+ ):
454
+ # lines
455
+ mtlines0 = pred["line0_orig"]
456
+ mtlines1 = pred["line1_orig"]
457
+ num_inliers = len(mtlines0)
458
+ fig_lines = plot_images(
459
+ [img0.squeeze(), img1.squeeze()],
460
+ ["Image 0 - matched lines", "Image 1 - matched lines"],
461
+ dpi=300,
462
+ )
463
+ fig_lines = plot_color_line_matches([mtlines0, mtlines1], lw=2)
464
+ fig_lines = fig2im(fig_lines)
465
+
466
+ # keypoints
467
+ mkpts0 = pred.get("line_keypoints0_orig")
468
+ mkpts1 = pred.get("line_keypoints1_orig")
469
+ fig = None
470
+ if mkpts0 is not None and mkpts1 is not None:
471
+ num_inliers = len(mkpts0)
472
+ if "mconf" in pred:
473
+ mconf = pred["mconf"]
474
+ else:
475
+ mconf = np.ones(len(mkpts0))
476
+ fig_mkpts = draw_matches_core(mkpts0, mkpts1, img0, img1, mconf, dpi=300)
477
+ fig_lines = cv2.resize(fig_lines, (fig_mkpts.shape[1], fig_mkpts.shape[0]))
478
+ fig = np.concatenate([fig_mkpts, fig_lines], axis=0)
479
+ else:
480
+ fig = fig_lines
481
+ return fig, num_inliers
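
A minimal usage sketch of the plotting helpers defined above, assuming they are imported from this module; the image arrays and confidences are synthetic stand-ins:

import numpy as np

img0 = np.random.randint(0, 256, (480, 640, 3), dtype=np.uint8)  # placeholder RGB images
img1 = np.random.randint(0, 256, (480, 640, 3), dtype=np.uint8)
fig = plot_images([img0, img1], titles=["Image 0", "Image 1"], dpi=100)
canvas = fig2im(fig)                      # (H, W, 3) uint8 rasterization of the figure
conf = np.random.rand(200)                # hypothetical per-match confidences in [0, 1]
colors = error_colormap(1 - conf, 0.5, alpha=0.5)  # (200, 4) RGBA, greener = more confident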