diff --git a/imcui/third_party/MatchAnything/LICENSE b/imcui/third_party/MatchAnything/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..d645695673349e3947e8e5ae42332d0ac3164cd7
--- /dev/null
+++ b/imcui/third_party/MatchAnything/LICENSE
@@ -0,0 +1,202 @@
+
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
diff --git a/imcui/third_party/MatchAnything/README.md b/imcui/third_party/MatchAnything/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..ceb1c2a3152cbf3c8d7e42190e54040efa38d23d
--- /dev/null
+++ b/imcui/third_party/MatchAnything/README.md
@@ -0,0 +1,104 @@
+# MatchAnything: Universal Cross-Modality Image Matching with Large-Scale Pre-Training
+### [Project Page](https://zju3dv.github.io/MatchAnything) | [Paper](??)
+
+> MatchAnything: Universal Cross-Modality Image Matching with Large-Scale Pre-Training\
+> [Xingyi He](https://hxy-123.github.io/),
+[Hao Yu](https://ritianyu.github.io/),
+[Sida Peng](https://pengsida.net),
+[Dongli Tan](https://github.com/Cuistiano),
+[Zehong Shen](https://zehongs.github.io),
+[Xiaowei Zhou](https://xzhou.me/),
+[Hujun Bao](http://www.cad.zju.edu.cn/home/bao/)†\
+> arXiv 2025
+
+
+
+
+
+## TODO List
+- [x] Pre-trained models and inference code
+- [x] Huggingface demo
+- [ ] Data generation and training code
+- [ ] Finetune code to further train on your own data
+- [ ] Incorporate more synthetic modalities and image generation methods
+
+## Quick Start
+
+### [HuggingFace demo for MatchAnything](https://huggingface.co/spaces/LittleFrog/MatchAnything)
+
+## Setup
+Create the python environment by:
+```
+conda env create -f environment.yaml
+conda activate env
+```
+We have tested our code on a device with CUDA 11.7.
+
+Download the pretrained weights from [here](https://drive.google.com/file/d/12L3g9-w8rR9K2L4rYaGaDJ7NqX1D713d/view?usp=sharing) and place the zip file under the repo directory. Then unzip it by running:
+```
+unzip weights.zip
+rm -rf weights.zip
+```
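+
+After unzipping, the evaluation scripts below expect `weights/matchanything_eloftr.ckpt` and `weights/matchanything_roma.ckpt` to exist. A minimal sanity check (illustrative only; the checkpoint paths are taken from the evaluation scripts):
+```python
+from pathlib import Path
+
+for name in ["matchanything_eloftr.ckpt", "matchanything_roma.ckpt"]:
+    print(name, (Path("weights") / name).exists())
+```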
+
+## Test
+We evaluate the models pre-trained by our framework using a single set of network weights across all cross-modality matching and registration tasks.
+
+### Data Preparation
+Download the `test_data` directory from [here](https://drive.google.com/drive/folders/1jpxIOcgnQfl9IEPPifdXQ7S7xuj9K4j7?usp=sharing) and place it under `repo_directory/data`. Then, unzip all datasets by:
+```shell
+cd repo_directory/data/test_data
+
+for file in *.zip; do
+ unzip "$file" && rm "$file"
+done
+```
+
+The data structure should look like:
+```
+repo_directory/data/test_data
+ - Liver_CT-MR
+ - havard_medical_matching
+ - remote_sense_thermal
+ - MTV_cross_modal_data
+ - thermal_visible_ground
+ - visible_sar_dataset
+ - visible_vectorized_map
+```
+
+### Evaluation
+```shell
+# For Tomography datasets:
+sh scripts/evaluate/eval_liver_ct_mr.sh
+sh scripts/evaluate/eval_harvard_brain.sh
+
+
+
+# For visible-thermal datasets:
+sh scripts/evaluate/eval_thermal_remote_sense.sh
+sh scripts/evaluate/eval_thermal_mtv.sh
+sh scripts/evaluate/eval_thermal_ground.sh
+
+# For visible-sar dataset:
+sh scripts/evaluate/eval_visible_sar.sh
+
+# For visible-vectorized map dataset:
+sh scripts/evaluate/eval_visible_vectorized_map.sh
+```
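+
+To run the full benchmark in one go, the scripts above can also be chained from Python (an illustrative helper, not part of the repo):
+```python
+import subprocess
+
+# Paths correspond to the evaluation scripts listed above.
+scripts = [
+    "scripts/evaluate/eval_liver_ct_mr.sh",
+    "scripts/evaluate/eval_harvard_brain.sh",
+    "scripts/evaluate/eval_thermal_remote_sense.sh",
+    "scripts/evaluate/eval_thermal_mtv.sh",
+    "scripts/evaluate/eval_thermal_ground.sh",
+    "scripts/evaluate/eval_visible_sar.sh",
+    "scripts/evaluate/eval_visible_vectorized_map.sh",
+]
+for script in scripts:
+    subprocess.run(["sh", script], check=True)
+```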
+
+## Citation
+
+If you find this code useful for your research, please use the following BibTeX entry.
+
+```
+@inproceedings{he2025matchanything,
+title={MatchAnything: Universal Cross-Modality Image Matching with Large-Scale Pre-Training},
+author={He, Xingyi and Yu, Hao and Peng, Sida and Tan, Dongli and Shen, Zehong and Bao, Hujun and Zhou, Xiaowei},
+booktitle={Arxiv},
+year={2025}
+}
+```
+
+## Acknowledgement
+We thank the authors of
+[ELoFTR](https://github.com/zju3dv/EfficientLoFTR) and
+[ROMA](https://github.com/Parskatt/RoMa) for their great work, without which our project would not be possible.
\ No newline at end of file
diff --git a/imcui/third_party/MatchAnything/configs/models/eloftr_model.py b/imcui/third_party/MatchAnything/configs/models/eloftr_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..abbc030bb4181ea1d33d06f90797880ae03a18da
--- /dev/null
+++ b/imcui/third_party/MatchAnything/configs/models/eloftr_model.py
@@ -0,0 +1,128 @@
+from src.config.default import _CN as cfg
+
+cfg.LOFTR.MATCH_COARSE.MATCH_TYPE = 'dual_softmax'
+cfg.LOFTR.MATCH_COARSE.SPARSE_SPVS = False
+
+cfg.TRAINER.CANONICAL_LR = 8e-3
+cfg.TRAINER.WARMUP_STEP = 1875 # 3 epochs
+cfg.TRAINER.WARMUP_RATIO = 0.1
+
+cfg.TRAINER.MSLR_MILESTONES = [4, 6, 8, 10, 12, 14, 16]
+
+# pose estimation
+cfg.TRAINER.RANSAC_PIXEL_THR = 0.5
+
+cfg.TRAINER.OPTIMIZER = "adamw"
+cfg.TRAINER.ADAMW_DECAY = 0.1
+
+cfg.LOFTR.MATCH_COARSE.TRAIN_COARSE_PERCENT = 0.1
+
+cfg.LOFTR.MATCH_COARSE.MTD_SPVS = True
+cfg.LOFTR.FINE.MTD_SPVS = True
+
+cfg.LOFTR.RESOLUTION = (8, 1) # options: [(8, 2), (16, 4)]
+cfg.LOFTR.FINE_WINDOW_SIZE = 8 # window_size in fine_level
+cfg.LOFTR.MATCH_FINE.THR = 0
+cfg.LOFTR.LOSS.FINE_TYPE = 'l2' # ['l2_with_std', 'l2']
+
+cfg.TRAINER.EPI_ERR_THR = 5e-4 # recommendation: 5e-4 for ScanNet, 1e-4 for MegaDepth (from SuperGlue)
+
+cfg.LOFTR.MATCH_COARSE.SPARSE_SPVS = True
+
+# PAN
+cfg.LOFTR.COARSE.PAN = True
+cfg.LOFTR.COARSE.POOl_SIZE = 4
+cfg.LOFTR.COARSE.BN = False
+cfg.LOFTR.COARSE.XFORMER = True
+cfg.LOFTR.COARSE.ATTENTION = 'full' # options: ['linear', 'full']
+
+cfg.LOFTR.FINE.PAN = False
+cfg.LOFTR.FINE.POOl_SIZE = 4
+cfg.LOFTR.FINE.BN = False
+cfg.LOFTR.FINE.XFORMER = False
+
+# noalign
+cfg.LOFTR.ALIGN_CORNER = False
+
+# fp16
+cfg.DATASET.FP16 = False
+cfg.LOFTR.FP16 = False
+
+# DEBUG
+cfg.LOFTR.FP16LOG = False
+cfg.LOFTR.MATCH_COARSE.FP16LOG = False
+
+# fine skip
+cfg.LOFTR.FINE.SKIP = True
+
+# clip
+cfg.TRAINER.GRADIENT_CLIPPING = 0.5
+
+# backbone
+cfg.LOFTR.BACKBONE_TYPE = 'RepVGG'
+
+# A1
+cfg.LOFTR.RESNETFPN.INITIAL_DIM = 64
+cfg.LOFTR.RESNETFPN.BLOCK_DIMS = [64, 128, 256] # s1, s2, s3
+cfg.LOFTR.COARSE.D_MODEL = 256
+cfg.LOFTR.FINE.D_MODEL = 64
+
+# FPN backbone_inter_feat with coarse_attn.
+cfg.LOFTR.COARSE_FEAT_ONLY = True
+cfg.LOFTR.INTER_FEAT = True
+cfg.LOFTR.RESNETFPN.COARSE_FEAT_ONLY = True
+cfg.LOFTR.RESNETFPN.INTER_FEAT = True
+
+# loop back spv coarse match
+cfg.LOFTR.FORCE_LOOP_BACK = False
+
+# fix norm fine match
+cfg.LOFTR.MATCH_FINE.NORMFINEM = True
+
+# loss cf weight
+cfg.LOFTR.LOSS.COARSE_OVERLAP_WEIGHT = True
+cfg.LOFTR.LOSS.FINE_OVERLAP_WEIGHT = True
+
+# leaky relu
+cfg.LOFTR.RESNETFPN.LEAKY = False
+cfg.LOFTR.COARSE.LEAKY = 0.01
+
+# prevent FP16 OVERFLOW in dirty data
+cfg.LOFTR.NORM_FPNFEAT = True
+cfg.LOFTR.REPLACE_NAN = True
+
+# force mutual nearest
+cfg.LOFTR.MATCH_COARSE.FORCE_NEAREST = True
+cfg.LOFTR.MATCH_COARSE.THR = 0.1
+
+# fix fine matching
+cfg.LOFTR.MATCH_FINE.FIX_FINE_MATCHING = True
+
+# dwconv
+cfg.LOFTR.COARSE.DWCONV = True
+
+# localreg
+cfg.LOFTR.MATCH_FINE.LOCAL_REGRESS = True
+cfg.LOFTR.LOSS.LOCAL_WEIGHT = 0.25
+
+# it5
+cfg.LOFTR.EVAL_TIMES = 1
+
+# rope
+cfg.LOFTR.COARSE.ROPE = True
+
+# local regress temperature
+cfg.LOFTR.MATCH_FINE.LOCAL_REGRESS_TEMPERATURE = 10.0
+
+# SLICE
+cfg.LOFTR.MATCH_FINE.LOCAL_REGRESS_SLICE = True
+cfg.LOFTR.MATCH_FINE.LOCAL_REGRESS_SLICEDIM = 8
+
+# inner with no mask [64,100]
+cfg.LOFTR.MATCH_FINE.LOCAL_REGRESS_INNER = True
+cfg.LOFTR.MATCH_FINE.LOCAL_REGRESS_NOMASK = True
+
+cfg.LOFTR.MATCH_FINE.TOPK = 1
+cfg.LOFTR.MATCH_COARSE.FINE_TOPK = 1
+
+cfg.LOFTR.MATCH_COARSE.FP16MATMUL = False
\ No newline at end of file
diff --git a/imcui/third_party/MatchAnything/configs/models/roma_model.py b/imcui/third_party/MatchAnything/configs/models/roma_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..207ab322242234242b5e558e962f57dce91d0566
--- /dev/null
+++ b/imcui/third_party/MatchAnything/configs/models/roma_model.py
@@ -0,0 +1,27 @@
+from src.config.default import _CN as cfg
+cfg.ROMA.RESIZE_BY_STRETCH = True
+cfg.DATASET.RESIZE_BY_STRETCH = True
+
+cfg.TRAINER.CANONICAL_LR = 8e-3
+cfg.TRAINER.WARMUP_STEP = 1875 # 3 epochs
+cfg.TRAINER.WARMUP_RATIO = 0.1
+
+cfg.TRAINER.MSLR_MILESTONES = [4, 6, 8, 10, 12, 14, 16, 18, 20]
+
+# pose estimation
+cfg.TRAINER.RANSAC_PIXEL_THR = 0.5
+
+cfg.TRAINER.OPTIMIZER = "adamw"
+cfg.TRAINER.ADAMW_DECAY = 0.1
+cfg.TRAINER.OPTIMIZER_EPS = 5e-7
+
+cfg.TRAINER.EPI_ERR_THR = 5e-4
+
+# fp16
+cfg.DATASET.FP16 = False
+cfg.LOFTR.FP16 = True
+
+# clip
+cfg.TRAINER.GRADIENT_CLIPPING = 0.5
+
+cfg.LOFTR.ROMA_LOSS.IGNORE_EMPTY_IN_SPARSE_MATCH_SPV = True
\ No newline at end of file
diff --git a/imcui/third_party/MatchAnything/environment.yaml b/imcui/third_party/MatchAnything/environment.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..0d61b8e99932550f4c9c3f3a9e7e58e0b9bf68b4
--- /dev/null
+++ b/imcui/third_party/MatchAnything/environment.yaml
@@ -0,0 +1,14 @@
+name: env
+channels:
+ - pytorch
+ - nvidia
+ - conda-forge
+ - defaults
+dependencies:
+ - python=3.8
+ - pytorch-cuda=11.7
+ - pytorch=1.12.1
+ - torchvision=0.13.1
+ - pip
+ - pip:
+ - -r requirements.txt
\ No newline at end of file
diff --git a/imcui/third_party/MatchAnything/notebooks/notebooks_utils/__init__.py b/imcui/third_party/MatchAnything/notebooks/notebooks_utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..81512278dfdfa73dd0915defa732b3b0e7db6af6
--- /dev/null
+++ b/imcui/third_party/MatchAnything/notebooks/notebooks_utils/__init__.py
@@ -0,0 +1 @@
+from .plotting import *
\ No newline at end of file
diff --git a/imcui/third_party/MatchAnything/notebooks/notebooks_utils/plotting.py b/imcui/third_party/MatchAnything/notebooks/notebooks_utils/plotting.py
new file mode 100644
index 0000000000000000000000000000000000000000..9993ed5b989d8babc87969e34d153ba3bcc05e1f
--- /dev/null
+++ b/imcui/third_party/MatchAnything/notebooks/notebooks_utils/plotting.py
@@ -0,0 +1,344 @@
+import numpy as np
+import matplotlib.pyplot as plt
+import matplotlib
+from matplotlib.colors import hsv_to_rgb
+import pylab as pl
+import matplotlib.cm as cm
+from PIL import Image
+import cv2
+
+
+def visualize_features(feat, img_h, img_w, save_path=None):
+ from sklearn.decomposition import PCA
+ pca = PCA(n_components=3, svd_solver="arpack")
+ img = pca.fit_transform(feat).reshape(img_h * 2, img_w, 3)
+ img_norm = cv2.normalize(
+ img, None, alpha=0, beta=255, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_8UC3
+ )
+ img_resized = cv2.resize(
+ img_norm, (img_w * 8, img_h * 2 * 8), interpolation=cv2.INTER_LINEAR
+ )
+ img_colormap = img_resized
+ img1, img2 = img_colormap[: img_h * 8, :, :], img_colormap[img_h * 8 :, :, :]
+ img_gapped = np.hstack(
+ (img1, np.ones((img_h * 8, 10, 3), dtype=np.uint8) * 255, img2)
+ )
+ if save_path is not None:
+ cv2.imwrite(save_path, img_gapped)
+
+ fig, axes = plt.subplots(1, 1, dpi=200)
+ axes.imshow(img_gapped)
+ axes.get_yaxis().set_ticks([])
+ axes.get_xaxis().set_ticks([])
+ plt.tight_layout(pad=0.5)
+ return fig
+
+def make_matching_figure(
+ img0,
+ img1,
+ mkpts0,
+ mkpts1,
+ color,
+ kpts0=None,
+ kpts1=None,
+ text=[],
+ path=None,
+ draw_detection=False,
+ draw_match_type='corres', # ['color', 'corres', None]
+ r_normalize_factor=0.4,
+ white_center=True,
+ vertical=False,
+ use_position_color=False,
+ draw_local_window=False,
+ window_size=(9, 9),
+ plot_size_factor=1, # Point size and line width
+ anchor_pts0=None,
+ anchor_pts1=None,
+ rescale_thr=5000,
+):
+ if (max(img0.shape) > rescale_thr) or (max(img1.shape) > rescale_thr):
+ scale_factor = 0.5
+ img0 = np.array(Image.fromarray((img0 * 255).astype(np.uint8)).resize((int(img0.shape[1] * scale_factor), int(img0.shape[0] * scale_factor)))) / 255.
+ img1 = np.array(Image.fromarray((img1 * 255).astype(np.uint8)).resize((int(img1.shape[1] * scale_factor), int(img1.shape[0] * scale_factor)))) / 255.
+ mkpts0, mkpts1 = mkpts0 * scale_factor, mkpts1 * scale_factor
+ if kpts0 is not None:
+ kpts0, kpts1 = kpts0 * scale_factor, kpts1 * scale_factor
+
+ # draw image pair
+ fig, axes = (
+ plt.subplots(2, 1, figsize=(10, 6), dpi=600)
+ if vertical
+ else plt.subplots(1, 2, figsize=(10, 6), dpi=600)
+ )
+ axes[0].imshow(img0, aspect='auto')
+ axes[1].imshow(img1, aspect='auto')
+
+ # axes[0].imshow(img0, aspect='equal')
+ # axes[1].imshow(img1, aspect='equal')
+ for i in range(2): # clear all frames
+ axes[i].get_yaxis().set_ticks([])
+ axes[i].get_xaxis().set_ticks([])
+ for spine in axes[i].spines.values():
+ spine.set_visible(False)
+ plt.tight_layout(pad=1)
+
+ if use_position_color:
+ mean_coord = np.mean(mkpts0, axis=0)
+ x_center, y_center = mean_coord
+ # NOTE: set r_normalize_factor to a smaller number will make plotted figure more contrastive.
+ position_color = matching_coord2color(
+ mkpts0,
+ x_center,
+ y_center,
+ r_normalize_factor=r_normalize_factor,
+ white_center=white_center,
+ )
+ color[:, :3] = position_color
+
+ if draw_detection and kpts0 is not None and kpts1 is not None:
+ # color = 'g'
+ color = 'r'
+ axes[0].scatter(kpts0[:, 0], kpts0[:, 1], c=color, s=1 * plot_size_factor)
+ axes[1].scatter(kpts1[:, 0], kpts1[:, 1], c=color, s=1 * plot_size_factor)
+
+    if draw_match_type == 'corres':
+ # draw matches
+ fig.canvas.draw()
+ plt.pause(2.0)
+ transFigure = fig.transFigure.inverted()
+ fkpts0 = transFigure.transform(axes[0].transData.transform(mkpts0))
+ fkpts1 = transFigure.transform(axes[1].transData.transform(mkpts1))
+ fig.lines = [
+ matplotlib.lines.Line2D(
+ (fkpts0[i, 0], fkpts1[i, 0]),
+ (fkpts0[i, 1], fkpts1[i, 1]),
+ transform=fig.transFigure,
+ c=color[i],
+ linewidth=1* plot_size_factor,
+ )
+ for i in range(len(mkpts0))
+ ]
+
+ axes[0].scatter(mkpts0[:, 0], mkpts0[:, 1], c=color, s=2* plot_size_factor)
+ axes[1].scatter(mkpts1[:, 0], mkpts1[:, 1], c=color, s=2* plot_size_factor)
+    elif draw_match_type == 'color':
+ # x_center = img0.shape[-1] / 2
+ # y_center = img1.shape[-2] / 2
+
+ mean_coord = np.mean(mkpts0, axis=0)
+ x_center, y_center = mean_coord
+ # NOTE: set r_normalize_factor to a smaller number will make plotted figure more contrastive.
+ kpts_color = matching_coord2color(
+ mkpts0,
+ x_center,
+ y_center,
+ r_normalize_factor=r_normalize_factor,
+ white_center=white_center,
+ )
+ axes[0].scatter(mkpts0[:, 0], mkpts0[:, 1], c=kpts_color, s=1 * plot_size_factor)
+ axes[1].scatter(mkpts1[:, 0], mkpts1[:, 1], c=kpts_color, s=1 * plot_size_factor)
+
+ if draw_local_window:
+ anchor_pts0 = mkpts0 if anchor_pts0 is None else anchor_pts0
+ anchor_pts1 = mkpts1 if anchor_pts1 is None else anchor_pts1
+ plot_local_windows(
+ anchor_pts0, color=(1, 0, 0, 0.4), lw=0.2, ax_=0, window_size=window_size
+ )
+ plot_local_windows(
+ anchor_pts1, color=(1, 0, 0, 0.4), lw=0.2, ax_=1, window_size=window_size
+ ) # lw =0.2
+
+ # put txts
+ txt_color = "k" if img0[:100, :200].mean() > 200 else "w"
+ fig.text(
+ 0.01,
+ 0.99,
+ "\n".join(text),
+ transform=fig.axes[0].transAxes,
+ fontsize=15,
+ va="top",
+ ha="left",
+ color=txt_color,
+ )
+ plt.tight_layout(pad=1)
+
+ # save or return figure
+ if path:
+ plt.savefig(str(path), bbox_inches="tight", pad_inches=0)
+ plt.close()
+ else:
+ return fig
+
+def make_triple_matching_figure(
+ img0,
+ img1,
+ img2,
+ mkpts01,
+ mkpts12,
+ color01,
+ color12,
+ text=[],
+ path=None,
+ draw_match=True,
+ r_normalize_factor=0.4,
+ white_center=True,
+ vertical=False,
+ draw_local_window=False,
+ window_size=(9, 9),
+ anchor_pts0=None,
+ anchor_pts1=None,
+):
+ # draw image pair
+ fig, axes = (
+ plt.subplots(3, 1, figsize=(10, 6), dpi=600)
+ if vertical
+ else plt.subplots(1, 3, figsize=(10, 6), dpi=600)
+ )
+ axes[0].imshow(img0)
+ axes[1].imshow(img1)
+ axes[2].imshow(img2)
+ for i in range(3): # clear all frames
+ axes[i].get_yaxis().set_ticks([])
+ axes[i].get_xaxis().set_ticks([])
+ for spine in axes[i].spines.values():
+ spine.set_visible(False)
+ plt.tight_layout(pad=1)
+
+ if draw_match:
+ # draw matches for [0,1]
+ fig.canvas.draw()
+ transFigure = fig.transFigure.inverted()
+ fkpts0 = transFigure.transform(axes[0].transData.transform(mkpts01[0]))
+ fkpts1 = transFigure.transform(axes[1].transData.transform(mkpts01[1]))
+ fig.lines = [
+ matplotlib.lines.Line2D(
+ (fkpts0[i, 0], fkpts1[i, 0]),
+ (fkpts0[i, 1], fkpts1[i, 1]),
+ transform=fig.transFigure,
+ c=color01[i],
+ linewidth=1,
+ )
+ for i in range(len(mkpts01[0]))
+ ]
+
+ axes[0].scatter(mkpts01[0][:, 0], mkpts01[0][:, 1], c=color01[:, :3], s=1)
+ axes[1].scatter(mkpts01[1][:, 0], mkpts01[1][:, 1], c=color01[:, :3], s=1)
+
+ fig.canvas.draw()
+ # draw matches for [1,2]
+ fkpts1_1 = transFigure.transform(axes[1].transData.transform(mkpts12[0]))
+ fkpts2 = transFigure.transform(axes[2].transData.transform(mkpts12[1]))
+ fig.lines += [
+ matplotlib.lines.Line2D(
+ (fkpts1_1[i, 0], fkpts2[i, 0]),
+ (fkpts1_1[i, 1], fkpts2[i, 1]),
+ transform=fig.transFigure,
+ c=color12[i],
+ linewidth=1,
+ )
+ for i in range(len(mkpts12[0]))
+ ]
+
+ axes[1].scatter(mkpts12[0][:, 0], mkpts12[0][:, 1], c=color12[:, :3], s=1)
+ axes[2].scatter(mkpts12[1][:, 0], mkpts12[1][:, 1], c=color12[:, :3], s=1)
+
+ # # put txts
+ # txt_color = "k" if img0[:100, :200].mean() > 200 else "w"
+ # fig.text(
+ # 0.01,
+ # 0.99,
+ # "\n".join(text),
+ # transform=fig.axes[0].transAxes,
+ # fontsize=15,
+ # va="top",
+ # ha="left",
+ # color=txt_color,
+ # )
+ plt.tight_layout(pad=0.1)
+
+ # save or return figure
+ if path:
+ plt.savefig(str(path), bbox_inches="tight", pad_inches=0)
+ plt.close()
+ else:
+ return fig
+
+
+def matching_coord2color(kpts, x_center, y_center, r_normalize_factor=0.4, white_center=True):
+ """
+ r_normalize_factor is used to visualize clearer according to points space distribution
+ r_normalize_factor maxium=1, larger->points darker/brighter
+ """
+ if not white_center:
+ # dark center points
+ V, H = np.mgrid[0:1:10j, 0:1:360j]
+ S = np.ones_like(V)
+ else:
+ # white center points
+ S, H = np.mgrid[0:1:10j, 0:1:360j]
+ V = np.ones_like(S)
+
+ HSV = np.dstack((H, S, V))
+ RGB = hsv_to_rgb(HSV)
+ """
+ # used to visualize hsv
+ pl.imshow(RGB, origin="lower", extent=[0, 360, 0, 1], aspect=150)
+ pl.xlabel("H")
+ pl.ylabel("S")
+ pl.title("$V_{HSV}=1$")
+ pl.show()
+ """
+ kpts = np.copy(kpts)
+ distance = kpts - np.array([x_center, y_center])[None]
+ r_max = np.percentile(np.linalg.norm(distance, axis=1), 85)
+ # r_max = np.sqrt((x_center) ** 2 + (y_center) ** 2)
+ kpts[:, 0] = kpts[:, 0] - x_center # x
+ kpts[:, 1] = kpts[:, 1] - y_center # y
+
+ r = np.sqrt(kpts[:, 0] ** 2 + kpts[:, 1] ** 2) + 1e-6
+ r_normalized = r / (r_max * r_normalize_factor)
+ r_normalized[r_normalized > 1] = 1
+ r_normalized = (r_normalized) * 9
+
+ cos_theta = kpts[:, 0] / r # x / r
+ theta = np.arccos(cos_theta) # from 0 to pi
+ change_angle_mask = kpts[:, 1] < 0
+ theta[change_angle_mask] = 2 * np.pi - theta[change_angle_mask]
+ theta_degree = np.degrees(theta)
+ theta_degree[theta_degree == 360] = 0 # to avoid overflow
+ theta_degree = theta_degree / 360 * 360
+ kpts_color = RGB[r_normalized.astype(int), theta_degree.astype(int)]
+ return kpts_color
+
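+# Illustrative usage sketch (not called anywhere in this module; shapes assumed):
+#   pts = np.random.rand(500, 2) * 640                              # (N, 2) keypoints
+#   colors = matching_coord2color(pts, pts[:, 0].mean(), pts[:, 1].mean())
+#   colors.shape  -> (500, 3), RGB values in [0, 1]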
+
+def show_image_pair(img0, img1, path=None):
+ fig, axes = plt.subplots(1, 2, figsize=(10, 6), dpi=200)
+ axes[0].imshow(img0, cmap="gray")
+ axes[1].imshow(img1, cmap="gray")
+ for i in range(2): # clear all frames
+ axes[i].get_yaxis().set_ticks([])
+ axes[i].get_xaxis().set_ticks([])
+ for spine in axes[i].spines.values():
+ spine.set_visible(False)
+ plt.tight_layout(pad=1)
+ if path:
+ plt.savefig(str(path), bbox_inches="tight", pad_inches=0)
+ return fig
+
+def plot_local_windows(kpts, color="r", lw=1, ax_=0, window_size=(9, 9)):
+ ax = plt.gcf().axes
+ for kpt in kpts:
+ ax[ax_].add_patch(
+ matplotlib.patches.Rectangle(
+ (
+ kpt[0] - (window_size[0] // 2) - 1,
+ kpt[1] - (window_size[1] // 2) - 1,
+ ),
+ window_size[0] + 1,
+ window_size[1] + 1,
+ lw=lw,
+ color=color,
+ fill=False,
+ )
+ )
+
diff --git a/imcui/third_party/MatchAnything/requirements.txt b/imcui/third_party/MatchAnything/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..891cdb8becdb19790c4ac3659d69f8118d14bd2c
--- /dev/null
+++ b/imcui/third_party/MatchAnything/requirements.txt
@@ -0,0 +1,22 @@
+opencv_python==4.4.0.46
+albumentations==0.5.1 --no-binary=imgaug,albumentations
+Pillow==9.5.0
+ray==2.9.3
+einops==0.3.0
+kornia==0.4.1
+loguru==0.5.3
+yacs>=0.1.8
+tqdm
+autopep8
+pylint
+ipython
+jupyterlab
+matplotlib
+h5py==3.1.0
+pytorch-lightning==1.3.5
+torchmetrics==0.6.0 # version problem: https://github.com/NVIDIA/DeepLearningExamples/issues/1113#issuecomment-1102969461
+joblib>=1.0.1
+pynvml
+gpustat
+safetensors
+timm==0.6.7
\ No newline at end of file
diff --git a/imcui/third_party/MatchAnything/scripts/evaluate/eval_harvard_brain.sh b/imcui/third_party/MatchAnything/scripts/evaluate/eval_harvard_brain.sh
new file mode 100644
index 0000000000000000000000000000000000000000..a517ce7855dc9c87798a9a230fa8f21c4c06ca5a
--- /dev/null
+++ b/imcui/third_party/MatchAnything/scripts/evaluate/eval_harvard_brain.sh
@@ -0,0 +1,17 @@
+#!/bin/bash -l
+
+SCRIPTPATH=$(dirname $(readlink -f "$0"))
+PROJECT_DIR="${SCRIPTPATH}/../../"
+
+cd $PROJECT_DIR
+
+DEVICE_ID='0'
+NPZ_ROOT=data/test_data/havard_medical_matching/all_eval
+NPZ_LIST_PATH=data/test_data/havard_medical_matching/all_eval/val_list.txt
+OUTPUT_PATH=results/havard_medical_matching
+
+# ELoFTR pretrained:
+CUDA_VISIBLE_DEVICES=$DEVICE_ID python tools/evaluate_datasets.py configs/models/eloftr_model.py --ckpt_path weights/matchanything_eloftr.ckpt --method matchanything_eloftr@-@ransac_affine --imgresize 832 --thr 0.05 --npe --npz_root $NPZ_ROOT --npz_list_path $NPZ_LIST_PATH --output_path $OUTPUT_PATH
+
+# ROMA pretrained:
+CUDA_VISIBLE_DEVICES=$DEVICE_ID python tools/evaluate_datasets.py configs/models/roma_model.py --ckpt_path weights/matchanything_roma.ckpt --method matchanything_roma@-@ransac_affine --imgresize 832 --npe --npz_root $NPZ_ROOT --npz_list_path $NPZ_LIST_PATH --output_path $OUTPUT_PATH
\ No newline at end of file
diff --git a/imcui/third_party/MatchAnything/scripts/evaluate/eval_liver_ct_mr.sh b/imcui/third_party/MatchAnything/scripts/evaluate/eval_liver_ct_mr.sh
new file mode 100644
index 0000000000000000000000000000000000000000..f58b0b623565a3131f0be290c81e3842c6db4d3e
--- /dev/null
+++ b/imcui/third_party/MatchAnything/scripts/evaluate/eval_liver_ct_mr.sh
@@ -0,0 +1,17 @@
+#!/bin/bash -l
+
+SCRIPTPATH=$(dirname $(readlink -f "$0"))
+PROJECT_DIR="${SCRIPTPATH}/../../"
+
+cd $PROJECT_DIR
+
+DEVICE_ID='0'
+NPZ_ROOT=data/test_data/Liver_CT-MR/eval_indexs
+NPZ_LIST_PATH=data/test_data/Liver_CT-MR/eval_indexs/val_list.txt
+OUTPUT_PATH=results/Liver_CT-MR
+
+# ELoFTR pretrained:
+CUDA_VISIBLE_DEVICES=$DEVICE_ID python tools/evaluate_datasets.py configs/models/eloftr_model.py --ckpt_path weights/matchanything_eloftr.ckpt --method matchanything_eloftr@-@ransac_affine --imgresize 832 --npe --npz_root $NPZ_ROOT --npz_list_path $NPZ_LIST_PATH --output_path $OUTPUT_PATH
+
+# ROMA pretrained:
+CUDA_VISIBLE_DEVICES=$DEVICE_ID python tools/evaluate_datasets.py configs/models/roma_model.py --ckpt_path weights/matchanything_roma.ckpt --method matchanything_roma@-@ransac_affine --imgresize 832 --npe --npz_root $NPZ_ROOT --npz_list_path $NPZ_LIST_PATH --output_path $OUTPUT_PATH
\ No newline at end of file
diff --git a/imcui/third_party/MatchAnything/scripts/evaluate/eval_thermal_ground.sh b/imcui/third_party/MatchAnything/scripts/evaluate/eval_thermal_ground.sh
new file mode 100644
index 0000000000000000000000000000000000000000..2f7d5ea30804e749c5eff6f94d9963073683e441
--- /dev/null
+++ b/imcui/third_party/MatchAnything/scripts/evaluate/eval_thermal_ground.sh
@@ -0,0 +1,17 @@
+#!/bin/bash -l
+
+SCRIPTPATH=$(dirname $(readlink -f "$0"))
+PROJECT_DIR="${SCRIPTPATH}/../../"
+
+cd $PROJECT_DIR
+
+DEVICE_ID='0'
+NPZ_ROOT=data/test_data/thermal_visible_ground/eval_indexs
+NPZ_LIST_PATH=data/test_data/thermal_visible_ground/eval_indexs/val_list.txt
+OUTPUT_PATH=results/thermal_visible_ground
+
+# ELoFTR pretrained:
+CUDA_VISIBLE_DEVICES=$DEVICE_ID python tools/evaluate_datasets.py configs/models/eloftr_model.py --ckpt_path weights/matchanything_eloftr.ckpt --method matchanything_eloftr@-@ransac_affine --imgresize 832 --npe --npz_root $NPZ_ROOT --npz_list_path $NPZ_LIST_PATH --output_path $OUTPUT_PATH
+
+# ROMA pretrained:
+CUDA_VISIBLE_DEVICES=$DEVICE_ID python tools/evaluate_datasets.py configs/models/roma_model.py --ckpt_path weights/matchanything_roma.ckpt --method matchanything_roma@-@ransac_affine --imgresize 832 --npe --npz_root $NPZ_ROOT --npz_list_path $NPZ_LIST_PATH --output_path $OUTPUT_PATH
\ No newline at end of file
diff --git a/imcui/third_party/MatchAnything/scripts/evaluate/eval_thermal_mtv.sh b/imcui/third_party/MatchAnything/scripts/evaluate/eval_thermal_mtv.sh
new file mode 100644
index 0000000000000000000000000000000000000000..faf816d4a377626f35e3549b1de8c65a8b4ae540
--- /dev/null
+++ b/imcui/third_party/MatchAnything/scripts/evaluate/eval_thermal_mtv.sh
@@ -0,0 +1,17 @@
+#!/bin/bash -l
+
+SCRIPTPATH=$(dirname $(readlink -f "$0"))
+PROJECT_DIR="${SCRIPTPATH}/../../"
+
+cd $PROJECT_DIR
+
+DEVICE_ID='0'
+NPZ_ROOT=data/test_data/MTV_cross_modal_data/scene_info/scene_info
+NPZ_LIST_PATH=data/test_data/MTV_cross_modal_data/scene_info/test_list.txt
+OUTPUT_PATH=results/MTV_cross_modal_data
+
+# ELoFTR pretrained:
+CUDA_VISIBLE_DEVICES=$DEVICE_ID python tools/evaluate_datasets.py configs/models/eloftr_model.py --ckpt_path weights/matchanything_eloftr.ckpt --method matchanything_eloftr@-@ransac_affine --imgresize 832 --npe --npz_root $NPZ_ROOT --npz_list_path $NPZ_LIST_PATH --output_path $OUTPUT_PATH
+
+# ROMA pretrained:
+CUDA_VISIBLE_DEVICES=$DEVICE_ID python tools/evaluate_datasets.py configs/models/roma_model.py --ckpt_path weights/matchanything_roma.ckpt --method matchanything_roma@-@ransac_affine --imgresize 832 --npe --npz_root $NPZ_ROOT --npz_list_path $NPZ_LIST_PATH --output_path $OUTPUT_PATH
\ No newline at end of file
diff --git a/imcui/third_party/MatchAnything/scripts/evaluate/eval_thermal_remote_sense.sh b/imcui/third_party/MatchAnything/scripts/evaluate/eval_thermal_remote_sense.sh
new file mode 100644
index 0000000000000000000000000000000000000000..afd0c62edb3ba9e9ffc4a3a4140a38c3fd1d475e
--- /dev/null
+++ b/imcui/third_party/MatchAnything/scripts/evaluate/eval_thermal_remote_sense.sh
@@ -0,0 +1,17 @@
+#!/bin/bash -l
+
+SCRIPTPATH=$(dirname $(readlink -f "$0"))
+PROJECT_DIR="${SCRIPTPATH}/../../"
+
+cd $PROJECT_DIR
+
+DEVICE_ID='0'
+NPZ_ROOT=data/test_data/remote_sense_thermal/eval_Optical-Infrared
+NPZ_LIST_PATH=data/test_data/remote_sense_thermal/eval_Optical-Infrared/val_list.txt
+OUTPUT_PATH=results/remote_sense_thermal
+
+# ELoFTR pretrained:
+CUDA_VISIBLE_DEVICES=$DEVICE_ID python tools/evaluate_datasets.py configs/models/eloftr_model.py --ckpt_path weights/matchanything_eloftr.ckpt --method matchanything_eloftr@-@ransac_affine --imgresize 832 --npe --npz_root $NPZ_ROOT --npz_list_path $NPZ_LIST_PATH --output_path $OUTPUT_PATH
+
+# ROMA pretrained:
+CUDA_VISIBLE_DEVICES=$DEVICE_ID python tools/evaluate_datasets.py configs/models/roma_model.py --ckpt_path weights/matchanything_roma.ckpt --method matchanything_roma@-@ransac_affine --imgresize 832 --npe --npz_root $NPZ_ROOT --npz_list_path $NPZ_LIST_PATH --output_path $OUTPUT_PATH
\ No newline at end of file
diff --git a/imcui/third_party/MatchAnything/scripts/evaluate/eval_visible_sar.sh b/imcui/third_party/MatchAnything/scripts/evaluate/eval_visible_sar.sh
new file mode 100644
index 0000000000000000000000000000000000000000..231f303a642ce890c7e90470120ef8744d89455e
--- /dev/null
+++ b/imcui/third_party/MatchAnything/scripts/evaluate/eval_visible_sar.sh
@@ -0,0 +1,17 @@
+#!/bin/bash -l
+
+SCRIPTPATH=$(dirname $(readlink -f "$0"))
+PROJECT_DIR="${SCRIPTPATH}/../../"
+
+cd $PROJECT_DIR
+
+DEVICE_ID='0'
+NPZ_ROOT=data/test_data/visible_sar_dataset/eval
+NPZ_LIST_PATH=data/test_data/visible_sar_dataset/eval/val_list.txt
+OUTPUT_PATH=results/visible_sar_dataset
+
+# ELoFTR pretrained:
+CUDA_VISIBLE_DEVICES=$DEVICE_ID python tools/evaluate_datasets.py configs/models/eloftr_model.py --ckpt_path weights/matchanything_eloftr.ckpt --method matchanything_eloftr@-@ransac_affine --imgresize 832 --thr 0.05 --npe --npz_root $NPZ_ROOT --npz_list_path $NPZ_LIST_PATH --output_path $OUTPUT_PATH
+
+# ROMA pretrained:
+CUDA_VISIBLE_DEVICES=$DEVICE_ID python tools/evaluate_datasets.py configs/models/roma_model.py --ckpt_path weights/matchanything_roma.ckpt --method matchanything_roma@-@ransac_affine --imgresize 832 --npe --npz_root $NPZ_ROOT --npz_list_path $NPZ_LIST_PATH --output_path $OUTPUT_PATH
\ No newline at end of file
diff --git a/imcui/third_party/MatchAnything/scripts/evaluate/eval_visible_vectorized_map.sh b/imcui/third_party/MatchAnything/scripts/evaluate/eval_visible_vectorized_map.sh
new file mode 100644
index 0000000000000000000000000000000000000000..a8bb56f0a8273409fa51532e61323d423ad04191
--- /dev/null
+++ b/imcui/third_party/MatchAnything/scripts/evaluate/eval_visible_vectorized_map.sh
@@ -0,0 +1,17 @@
+#!/bin/bash -l
+
+SCRIPTPATH=$(dirname $(readlink -f "$0"))
+PROJECT_DIR="${SCRIPTPATH}/../../"
+
+cd $PROJECT_DIR
+
+DEVICE_ID='0'
+NPZ_ROOT=data/test_data/visible_vectorized_map/scene_indices
+NPZ_LIST_PATH=data/test_data/visible_vectorized_map/scene_indices/val_list.txt
+OUTPUT_PATH=results/visible_vectorized_map
+
+# ELoFTR pretrained:
+CUDA_VISIBLE_DEVICES=$DEVICE_ID python tools/evaluate_datasets.py configs/models/eloftr_model.py --ckpt_path weights/matchanything_eloftr.ckpt --method matchanything_eloftr@-@ransac_affine --imgresize 832 --npe --npz_root $NPZ_ROOT --npz_list_path $NPZ_LIST_PATH --output_path $OUTPUT_PATH
+
+# ROMA pretrained:
+CUDA_VISIBLE_DEVICES=$DEVICE_ID python tools/evaluate_datasets.py configs/models/roma_model.py --ckpt_path weights/matchanything_roma.ckpt --method matchanything_roma@-@ransac_affine --imgresize 832 --npe --npz_root $NPZ_ROOT --npz_list_path $NPZ_LIST_PATH --output_path $OUTPUT_PATH
\ No newline at end of file
diff --git a/imcui/third_party/MatchAnything/src/__init__.py b/imcui/third_party/MatchAnything/src/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/imcui/third_party/MatchAnything/src/config/default.py b/imcui/third_party/MatchAnything/src/config/default.py
new file mode 100644
index 0000000000000000000000000000000000000000..0b43845bfec84071b29b96012bf5401a889327ed
--- /dev/null
+++ b/imcui/third_party/MatchAnything/src/config/default.py
@@ -0,0 +1,344 @@
+from yacs.config import CfgNode as CN
+_CN = CN()
+############## ROMA Pipeline #########
+_CN.ROMA = CN()
+_CN.ROMA.MATCH_THRESH = 0.0
+_CN.ROMA.RESIZE_BY_STRETCH = False # Used for test mode
+_CN.ROMA.NORMALIZE_IMG = False # Used for test mode
+
+_CN.ROMA.MODE = "train_framework" # Used in Lightning Train & Val
+_CN.ROMA.MODEL = CN()
+_CN.ROMA.MODEL.COARSE_BACKBONE = 'DINOv2_large'
+_CN.ROMA.MODEL.COARSE_FEAT_DIM = 1024
+_CN.ROMA.MODEL.MEDIUM_FEAT_DIM = 512
+_CN.ROMA.MODEL.COARSE_PATCH_SIZE = 14
+_CN.ROMA.MODEL.AMP = True # FP16 mode
+
+_CN.ROMA.SAMPLE = CN()
+_CN.ROMA.SAMPLE.METHOD = "threshold_balanced"
+_CN.ROMA.SAMPLE.N_SAMPLE = 5000
+_CN.ROMA.SAMPLE.THRESH = 0.05
+
+_CN.ROMA.TEST_TIME = CN()
+_CN.ROMA.TEST_TIME.COARSE_RES = (560, 560) # needs to be divisible by 14 and 8
+_CN.ROMA.TEST_TIME.UPSAMPLE = True
+_CN.ROMA.TEST_TIME.UPSAMPLE_RES = (864, 864) # needs to be divisible by 8
+_CN.ROMA.TEST_TIME.SYMMETRIC = True
+_CN.ROMA.TEST_TIME.ATTENUTATE_CERT = True
+
+############## ↓ LoFTR Pipeline ↓ ##############
+_CN.LOFTR = CN()
+_CN.LOFTR.BACKBONE_TYPE = 'ResNetFPN'
+_CN.LOFTR.ALIGN_CORNER = True
+_CN.LOFTR.RESOLUTION = (8, 2) # options: [(8, 2), (16, 4)]
+_CN.LOFTR.FINE_WINDOW_SIZE = 5 # window_size in fine_level, must be odd
+_CN.LOFTR.FINE_WINDOW_MATCHING_SIZE = 5 # window_size for loftr fine-matching, odd for select and even for average
+_CN.LOFTR.FINE_CONCAT_COARSE_FEAT = True
+_CN.LOFTR.FINE_SAMPLE_COARSE_FEAT = False
+_CN.LOFTR.COARSE_FEAT_ONLY = False # TO BE DONE
+_CN.LOFTR.INTER_FEAT = False # FPN backbone inter feat with coarse_attn.
+_CN.LOFTR.FP16 = False
+_CN.LOFTR.FIX_BIAS = False
+_CN.LOFTR.MATCHABILITY = False
+_CN.LOFTR.FORCE_LOOP_BACK = False
+_CN.LOFTR.NORM_FPNFEAT = False
+_CN.LOFTR.NORM_FPNFEAT2 = False
+_CN.LOFTR.REPLACE_NAN = False
+_CN.LOFTR.PLOT_SCORES = False
+_CN.LOFTR.REP_FPN = False
+_CN.LOFTR.REP_DEPLOY = False
+_CN.LOFTR.EVAL_TIMES = 1
+
+# 1. LoFTR-backbone (local feature CNN) config
+_CN.LOFTR.RESNETFPN = CN()
+_CN.LOFTR.RESNETFPN.INITIAL_DIM = 128
+_CN.LOFTR.RESNETFPN.BLOCK_DIMS = [128, 196, 256] # s1, s2, s3
+_CN.LOFTR.RESNETFPN.SAMPLE_FINE = False
+_CN.LOFTR.RESNETFPN.COARSE_FEAT_ONLY = False # TO BE DONE
+_CN.LOFTR.RESNETFPN.INTER_FEAT = False # FPN backbone inter feat with coarse_attn.
+_CN.LOFTR.RESNETFPN.LEAKY = False
+_CN.LOFTR.RESNETFPN.REPVGGMODEL = None
+
+# 2. LoFTR-coarse module config
+_CN.LOFTR.COARSE = CN()
+_CN.LOFTR.COARSE.D_MODEL = 256
+_CN.LOFTR.COARSE.D_FFN = 256
+_CN.LOFTR.COARSE.NHEAD = 8
+_CN.LOFTR.COARSE.LAYER_NAMES = ['self', 'cross'] * 4
+_CN.LOFTR.COARSE.ATTENTION = 'linear' # options: ['linear', 'full']
+_CN.LOFTR.COARSE.TEMP_BUG_FIX = True
+_CN.LOFTR.COARSE.NPE = False
+_CN.LOFTR.COARSE.PAN = False
+_CN.LOFTR.COARSE.POOl_SIZE = 4
+_CN.LOFTR.COARSE.POOl_SIZE2 = 4
+_CN.LOFTR.COARSE.BN = True
+_CN.LOFTR.COARSE.XFORMER = False
+_CN.LOFTR.COARSE.BIDIRECTION = False
+_CN.LOFTR.COARSE.DEPTH_CONFIDENCE = -1.0
+_CN.LOFTR.COARSE.WIDTH_CONFIDENCE = -1.0
+_CN.LOFTR.COARSE.LEAKY = -1.0
+_CN.LOFTR.COARSE.ASYMMETRIC = False
+_CN.LOFTR.COARSE.ASYMMETRIC_SELF = False
+_CN.LOFTR.COARSE.ROPE = False
+_CN.LOFTR.COARSE.TOKEN_MIXER = None
+_CN.LOFTR.COARSE.SKIP = False
+_CN.LOFTR.COARSE.DWCONV = False
+_CN.LOFTR.COARSE.DWCONV2 = False
+_CN.LOFTR.COARSE.SCATTER = False
+_CN.LOFTR.COARSE.ROPE = False
+_CN.LOFTR.COARSE.NPE = None
+_CN.LOFTR.COARSE.NORM_BEFORE = True
+_CN.LOFTR.COARSE.VIT_NORM = False
+_CN.LOFTR.COARSE.ROPE_DWPROJ = False
+_CN.LOFTR.COARSE.ABSPE = False
+
+
+# 3. Coarse-Matching config
+_CN.LOFTR.MATCH_COARSE = CN()
+_CN.LOFTR.MATCH_COARSE.THR = 0.2
+_CN.LOFTR.MATCH_COARSE.BORDER_RM = 2
+_CN.LOFTR.MATCH_COARSE.MATCH_TYPE = 'dual_softmax' # options: ['dual_softmax', 'sinkhorn']
+_CN.LOFTR.MATCH_COARSE.DSMAX_TEMPERATURE = 0.1
+_CN.LOFTR.MATCH_COARSE.SKH_ITERS = 3
+_CN.LOFTR.MATCH_COARSE.SKH_INIT_BIN_SCORE = 1.0
+_CN.LOFTR.MATCH_COARSE.SKH_PREFILTER = False
+_CN.LOFTR.MATCH_COARSE.TRAIN_COARSE_PERCENT = 0.2 # training tricks: save GPU memory
+_CN.LOFTR.MATCH_COARSE.TRAIN_PAD_NUM_GT_MIN = 200 # training tricks: avoid DDP deadlock
+_CN.LOFTR.MATCH_COARSE.SPARSE_SPVS = True
+_CN.LOFTR.MATCH_COARSE.MTD_SPVS = False
+_CN.LOFTR.MATCH_COARSE.FIX_BIAS = False
+_CN.LOFTR.MATCH_COARSE.BINARY = False
+_CN.LOFTR.MATCH_COARSE.BINARY_SPV = 'l2'
+_CN.LOFTR.MATCH_COARSE.NORMFEAT = False
+_CN.LOFTR.MATCH_COARSE.NORMFEATMUL = False
+_CN.LOFTR.MATCH_COARSE.DIFFSIGN2 = False
+_CN.LOFTR.MATCH_COARSE.DIFFSIGN3 = False
+_CN.LOFTR.MATCH_COARSE.CLASSIFY = False
+_CN.LOFTR.MATCH_COARSE.D_CLASSIFY = 256
+_CN.LOFTR.MATCH_COARSE.SKIP_SOFTMAX = False
+_CN.LOFTR.MATCH_COARSE.FORCE_NEAREST = False # in case binary is True, force nearest neighbor, preventing finding a reasonable threshold
+_CN.LOFTR.MATCH_COARSE.FP16MATMUL = False
+_CN.LOFTR.MATCH_COARSE.SEQSOFTMAX = False
+_CN.LOFTR.MATCH_COARSE.SEQSOFTMAX2 = False
+_CN.LOFTR.MATCH_COARSE.RATIO_TEST = False
+_CN.LOFTR.MATCH_COARSE.RATIO_TEST_VAL = -1.0
+_CN.LOFTR.MATCH_COARSE.USE_GT_COARSE = False
+_CN.LOFTR.MATCH_COARSE.CROSS_SOFTMAX = False
+_CN.LOFTR.MATCH_COARSE.PLOT_ORIGIN_SCORES = False
+_CN.LOFTR.MATCH_COARSE.USE_PERCENT_THR = False
+_CN.LOFTR.MATCH_COARSE.PERCENT_THR = 0.1
+_CN.LOFTR.MATCH_COARSE.ADD_SIGMOID = False
+_CN.LOFTR.MATCH_COARSE.SIGMOID_BIAS = 20.0
+_CN.LOFTR.MATCH_COARSE.SIGMOID_SIGMA = 2.5
+_CN.LOFTR.MATCH_COARSE.CAL_PER_OF_GT = False
+
+# 4. LoFTR-fine module config
+_CN.LOFTR.FINE = CN()
+_CN.LOFTR.FINE.SKIP = False
+_CN.LOFTR.FINE.D_MODEL = 128
+_CN.LOFTR.FINE.D_FFN = 128
+_CN.LOFTR.FINE.NHEAD = 8
+_CN.LOFTR.FINE.LAYER_NAMES = ['self', 'cross'] * 1
+_CN.LOFTR.FINE.ATTENTION = 'linear'
+_CN.LOFTR.FINE.MTD_SPVS = False
+_CN.LOFTR.FINE.PAN = False
+_CN.LOFTR.FINE.POOl_SIZE = 4
+_CN.LOFTR.FINE.BN = True
+_CN.LOFTR.FINE.XFORMER = False
+_CN.LOFTR.FINE.BIDIRECTION = False
+
+
+# Fine-Matching config
+_CN.LOFTR.MATCH_FINE = CN()
+_CN.LOFTR.MATCH_FINE.THR = 0
+_CN.LOFTR.MATCH_FINE.TOPK = 3
+_CN.LOFTR.MATCH_FINE.NORMFINEM = False
+_CN.LOFTR.MATCH_FINE.USE_GT_FINE = False
+_CN.LOFTR.MATCH_COARSE.FINE_TOPK = _CN.LOFTR.MATCH_FINE.TOPK
+_CN.LOFTR.MATCH_FINE.FIX_FINE_MATCHING = False
+_CN.LOFTR.MATCH_FINE.SKIP_FINE_SOFTMAX = False
+_CN.LOFTR.MATCH_FINE.USE_SIGMOID = False
+_CN.LOFTR.MATCH_FINE.SIGMOID_BIAS = 0.0
+_CN.LOFTR.MATCH_FINE.NORMFEAT = False
+_CN.LOFTR.MATCH_FINE.SPARSE_SPVS = True
+_CN.LOFTR.MATCH_FINE.FORCE_NEAREST = False
+_CN.LOFTR.MATCH_FINE.LOCAL_REGRESS = False
+_CN.LOFTR.MATCH_FINE.LOCAL_REGRESS_RMBORDER = False
+_CN.LOFTR.MATCH_FINE.LOCAL_REGRESS_NOMASK = False
+_CN.LOFTR.MATCH_FINE.LOCAL_REGRESS_TEMPERATURE = 1.0
+_CN.LOFTR.MATCH_FINE.LOCAL_REGRESS_PADONE = False
+_CN.LOFTR.MATCH_FINE.LOCAL_REGRESS_SLICE = False
+_CN.LOFTR.MATCH_FINE.LOCAL_REGRESS_SLICEDIM = 8
+_CN.LOFTR.MATCH_FINE.LOCAL_REGRESS_INNER = False
+_CN.LOFTR.MATCH_FINE.MULTI_REGRESS = False
+
+
+
+# 5. LoFTR Losses
+# -- # coarse-level
+_CN.LOFTR.LOSS = CN()
+_CN.LOFTR.LOSS.COARSE_TYPE = 'focal' # ['focal', 'cross_entropy']
+_CN.LOFTR.LOSS.COARSE_WEIGHT = 1.0
+_CN.LOFTR.LOSS.COARSE_SIGMOID_WEIGHT = 1.0
+_CN.LOFTR.LOSS.LOCAL_WEIGHT = 0.5
+_CN.LOFTR.LOSS.COARSE_OVERLAP_WEIGHT = False
+_CN.LOFTR.LOSS.FINE_OVERLAP_WEIGHT = False
+_CN.LOFTR.LOSS.FINE_OVERLAP_WEIGHT2 = False
+# _CN.LOFTR.LOSS.SPARSE_SPVS = False
+# -- - -- # focal loss (coarse)
+_CN.LOFTR.LOSS.FOCAL_ALPHA = 0.25
+_CN.LOFTR.LOSS.FOCAL_GAMMA = 2.0
+_CN.LOFTR.LOSS.POS_WEIGHT = 1.0
+_CN.LOFTR.LOSS.NEG_WEIGHT = 1.0
+_CN.LOFTR.LOSS.CORRECT_NEG_WEIGHT = False
+# _CN.LOFTR.LOSS.DUAL_SOFTMAX = False # whether coarse-level use dual-softmax or not.
+# use `_CN.LOFTR.MATCH_COARSE.MATCH_TYPE`
+
+# -- # fine-level
+_CN.LOFTR.LOSS.FINE_TYPE = 'l2_with_std' # ['l2_with_std', 'l2']
+_CN.LOFTR.LOSS.FINE_WEIGHT = 1.0
+_CN.LOFTR.LOSS.FINE_CORRECT_THR = 1.0 # for filtering valid fine-level gts (some gt matches might fall out of the fine-level window)
+
+# -- # ROMA:
+_CN.LOFTR.ROMA_LOSS = CN()
+_CN.LOFTR.ROMA_LOSS.IGNORE_EMPTY_IN_SPARSE_MATCH_SPV = False
+
+# -- # DKM:
+_CN.LOFTR.DKM_LOSS = CN()
+_CN.LOFTR.DKM_LOSS.IGNORE_EMPTY_IN_SPARSE_MATCH_SPV = False
+
+############## Dataset ##############
+_CN.DATASET = CN()
+# 1. data config
+# training and validating
+_CN.DATASET.TB_LOG_DIR = "logs/tb_logs"
+_CN.DATASET.TRAIN_DATA_SAMPLE_RATIO = [1.0]
+_CN.DATASET.TRAIN_DATA_SOURCE = None # options: ['ScanNet', 'MegaDepth']
+_CN.DATASET.TRAIN_DATA_ROOT = None
+_CN.DATASET.TRAIN_POSE_ROOT = None # (optional directory for poses)
+_CN.DATASET.TRAIN_NPZ_ROOT = None
+_CN.DATASET.TRAIN_LIST_PATH = None
+_CN.DATASET.TRAIN_INTRINSIC_PATH = None
+_CN.DATASET.VAL_DATA_ROOT = None
+_CN.DATASET.VAL_DATA_SOURCE = None # options: ['ScanNet', 'MegaDepth']
+_CN.DATASET.VAL_POSE_ROOT = None # (optional directory for poses)
+_CN.DATASET.VAL_NPZ_ROOT = None
+_CN.DATASET.VAL_LIST_PATH = None # None if val data from all scenes are bundled into a single npz file
+_CN.DATASET.VAL_INTRINSIC_PATH = None
+_CN.DATASET.FP16 = False
+_CN.DATASET.TRAIN_GT_MATCHES_PADDING_N = 8000
+# testing
+_CN.DATASET.TEST_DATA_SOURCE = None
+_CN.DATASET.TEST_DATA_ROOT = None
+_CN.DATASET.TEST_POSE_ROOT = None # (optional directory for poses)
+_CN.DATASET.TEST_NPZ_ROOT = None
+_CN.DATASET.TEST_LIST_PATH = None # None if test data from all scenes are bundled into a single npz file
+_CN.DATASET.TEST_INTRINSIC_PATH = None
+
+# 2. dataset config
+# general options
+_CN.DATASET.MIN_OVERLAP_SCORE_TRAIN = 0.4 # discard data with overlap_score < min_overlap_score
+_CN.DATASET.MIN_OVERLAP_SCORE_TEST = 0.0
+_CN.DATASET.AUGMENTATION_TYPE = None # options: [None, 'dark', 'mobile']
+
+# debug options
+_CN.DATASET.TEST_N_PAIRS = None # Debug first N pairs
+# DEBUG
+_CN.LOFTR.FP16LOG = False
+_CN.LOFTR.MATCH_COARSE.FP16LOG = False
+
+# scanNet options
+_CN.DATASET.SCAN_IMG_RESIZEX = 640 # resize the longer side, zero-pad bottom-right to square.
+_CN.DATASET.SCAN_IMG_RESIZEY = 480 # resize the shorter side, zero-pad bottom-right to square.
+
+# MegaDepth options
+_CN.DATASET.MGDPT_IMG_RESIZE = (640, 640) # resize the longer side, zero-pad bottom-right to square.
+_CN.DATASET.MGDPT_IMG_PAD = True # pad img to square with size = MGDPT_IMG_RESIZE
+_CN.DATASET.MGDPT_DEPTH_PAD = True # pad depthmap to square with size = 2000
+_CN.DATASET.MGDPT_DF = 8
+_CN.DATASET.LOAD_ORIGIN_RGB = False # Only open in test mode, useful for RGB required baselines such as DKM, ROMA.
+_CN.DATASET.READ_GRAY = True
+_CN.DATASET.RESIZE_BY_STRETCH = False
+_CN.DATASET.NORMALIZE_IMG = False # For backbone using pretrained DINO feats, use True may be better.
+_CN.DATASET.HOMO_WARP_USE_MASK = False
+
+_CN.DATASET.NPE_NAME = "megadepth"
+
+############## Trainer ##############
+_CN.TRAINER = CN()
+_CN.TRAINER.WORLD_SIZE = 1
+_CN.TRAINER.CANONICAL_BS = 64
+_CN.TRAINER.CANONICAL_LR = 6e-3
+_CN.TRAINER.SCALING = None # this will be calculated automatically
+_CN.TRAINER.FIND_LR = False # use learning rate finder from pytorch-lightning
+
+# optimizer
+_CN.TRAINER.OPTIMIZER = "adamw" # [adam, adamw]
+_CN.TRAINER.OPTIMIZER_EPS = 1e-8 # Default for optimizers, but set smaller, e.g., 1e-7 for fp16 mix training
+_CN.TRAINER.TRUE_LR = None # this will be calculated automatically at runtime
+_CN.TRAINER.ADAM_DECAY = 0. # ADAM: for adam
+_CN.TRAINER.ADAMW_DECAY = 0.1
+
+# step-based warm-up
+_CN.TRAINER.WARMUP_TYPE = 'linear' # [linear, constant]
+_CN.TRAINER.WARMUP_RATIO = 0.
+_CN.TRAINER.WARMUP_STEP = 4800
+
+# learning rate scheduler
+_CN.TRAINER.SCHEDULER = 'MultiStepLR' # [MultiStepLR, CosineAnnealing, ExponentialLR]
+_CN.TRAINER.SCHEDULER_INTERVAL = 'epoch' # [epoch, step]
+_CN.TRAINER.MSLR_MILESTONES = [3, 6, 9, 12] # MSLR: MultiStepLR
+_CN.TRAINER.MSLR_GAMMA = 0.5
+_CN.TRAINER.COSA_TMAX = 30 # COSA: CosineAnnealing
+_CN.TRAINER.ELR_GAMMA = 0.999992 # ELR: ExponentialLR, this value for 'step' interval
+
+# plotting related
+_CN.TRAINER.ENABLE_PLOTTING = True
+_CN.TRAINER.N_VAL_PAIRS_TO_PLOT = 8 # number of val/test pairs for plotting
+_CN.TRAINER.PLOT_MODE = 'evaluation' # ['evaluation', 'confidence']
+_CN.TRAINER.PLOT_MATCHES_ALPHA = 'dynamic'
+
+# geometric metrics and pose solver
+_CN.TRAINER.EPI_ERR_THR = 5e-4 # recommendation: 5e-4 for ScanNet, 1e-4 for MegaDepth (from SuperGlue)
+_CN.TRAINER.POSE_GEO_MODEL = 'E' # ['E', 'F', 'H']
+_CN.TRAINER.POSE_ESTIMATION_METHOD = 'RANSAC' # [RANSAC, DEGENSAC, MAGSAC]
+_CN.TRAINER.WARP_ESTIMATOR_MODEL = 'affine' # warp model used for registration evaluation
+_CN.TRAINER.RANSAC_PIXEL_THR = 0.5
+_CN.TRAINER.RANSAC_CONF = 0.99999
+_CN.TRAINER.RANSAC_MAX_ITERS = 10000
+_CN.TRAINER.USE_MAGSACPP = False
+_CN.TRAINER.THRESHOLDS = [5, 10, 20]
+
+# data sampler for train_dataloader
+_CN.TRAINER.DATA_SAMPLER = 'scene_balance' # options: ['scene_balance', 'random', 'normal']
+# 'scene_balance' config
+_CN.TRAINER.N_SAMPLES_PER_SUBSET = 200
+_CN.TRAINER.SB_SUBSET_SAMPLE_REPLACEMENT = True # whether sample each scene with replacement or not
+_CN.TRAINER.SB_SUBSET_SHUFFLE = True # after sampling from scenes, whether shuffle within the epoch or not
+_CN.TRAINER.SB_REPEAT = 1 # repeat N times for training the sampled data
+_CN.TRAINER.AUC_METHOD = 'exact_auc'
+# 'random' config
+_CN.TRAINER.RDM_REPLACEMENT = True
+_CN.TRAINER.RDM_NUM_SAMPLES = None
+
+# gradient clipping
+_CN.TRAINER.GRADIENT_CLIPPING = 0.5
+
+# Finetune Mode:
+_CN.FINETUNE = CN()
+_CN.FINETUNE.ENABLE = False
+_CN.FINETUNE.METHOD = "lora" #['lora', 'whole_network']
+
+_CN.FINETUNE.LORA = CN()
+_CN.FINETUNE.LORA.RANK = 2
+_CN.FINETUNE.LORA.MODE = "linear&conv" # ["linear&conv", "linear_only"]
+_CN.FINETUNE.LORA.SCALE = 1.0
+
+_CN.TRAINER.SEED = 66
+
+
+def get_cfg_defaults():
+ """Get a yacs CfgNode object with default values for my_project."""
+ # Return a clone so that the defaults will not be altered
+ # This is for the "local variable" use pattern
+ return _CN.clone()
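+
+
+if __name__ == "__main__":
+    # Illustrative usage sketch (assumption: not used by the repo itself, which
+    # imports and mutates the shared `_CN` from configs/models/*.py instead).
+    # Working on the clone returned by get_cfg_defaults() keeps the defaults intact.
+    cfg = get_cfg_defaults()
+    cfg.LOFTR.MATCH_COARSE.THR = 0.1  # local override; module-level defaults are unaffected
+    cfg.freeze()                      # yacs: make the config node read-only
+    print(cfg.LOFTR.RESOLUTION)       # -> (8, 2)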
diff --git a/imcui/third_party/MatchAnything/src/lightning/lightning_loftr.py b/imcui/third_party/MatchAnything/src/lightning/lightning_loftr.py
new file mode 100644
index 0000000000000000000000000000000000000000..91a5132809b88d14d8f02ddf7012c5abba8c46ec
--- /dev/null
+++ b/imcui/third_party/MatchAnything/src/lightning/lightning_loftr.py
@@ -0,0 +1,343 @@
+
+from collections import defaultdict
+import pprint
+from loguru import logger
+from pathlib import Path
+
+import torch
+import numpy as np
+import pytorch_lightning as pl
+from matplotlib import pyplot as plt
+
+from src.loftr import LoFTR
+from src.loftr.utils.supervision import compute_supervision_coarse, compute_supervision_fine, compute_roma_supervision
+from src.optimizers import build_optimizer, build_scheduler
+from src.utils.metrics import (
+ compute_symmetrical_epipolar_errors,
+ compute_pose_errors,
+ compute_homo_corner_warp_errors,
+ compute_homo_match_warp_errors,
+ compute_warp_control_pts_errors,
+ aggregate_metrics
+)
+from src.utils.plotting import make_matching_figures, make_scores_figures
+from src.utils.comm import gather, all_gather
+from src.utils.misc import lower_config, flattenList
+from src.utils.profiler import PassThroughProfiler
+from third_party.ROMA.roma.matchanything_roma_model import MatchAnything_Model
+
+import pynvml
+
+def reparameter(matcher):
+ module = matcher.backbone.layer0
+ if hasattr(module, 'switch_to_deploy'):
+ module.switch_to_deploy()
+ print('m0 switch to deploy ok')
+ for modules in [matcher.backbone.layer1, matcher.backbone.layer2, matcher.backbone.layer3]:
+ for module in modules:
+ if hasattr(module, 'switch_to_deploy'):
+ module.switch_to_deploy()
+ print('backbone switch to deploy ok')
+ for modules in [matcher.fine_preprocess.layer2_outconv2, matcher.fine_preprocess.layer1_outconv2]:
+ for module in modules:
+ if hasattr(module, 'switch_to_deploy'):
+ module.switch_to_deploy()
+ print('fpn switch to deploy ok')
+ return matcher
+
+class PL_LoFTR(pl.LightningModule):
+ def __init__(self, config, pretrained_ckpt=None, profiler=None, dump_dir=None, test_mode=False, baseline_config=None):
+ """
+ TODO:
+ - use the new version of PL logging API.
+ """
+ super().__init__()
+ # Misc
+ self.config = config # full config
+ _config = lower_config(self.config)
+ self.profiler = profiler or PassThroughProfiler()
+ self.n_vals_plot = max(config.TRAINER.N_VAL_PAIRS_TO_PLOT // config.TRAINER.WORLD_SIZE, 1)
+
+ if config.METHOD == "matchanything_eloftr":
+ self.matcher = LoFTR(config=_config['loftr'], profiler=self.profiler)
+ elif config.METHOD == "matchanything_roma":
+ self.matcher = MatchAnything_Model(config=_config['roma'], test_mode=test_mode)
+ else:
+ raise NotImplementedError
+
+ if config.FINETUNE.ENABLE and test_mode:
+            # At inference time, change the model architecture before loading the pretrained weights:
+ raise NotImplementedError
+
+ # Pretrained weights
+ if pretrained_ckpt:
+ if config.METHOD in ["matchanything_eloftr", "matchanything_roma"]:
+ state_dict = torch.load(pretrained_ckpt, map_location='cpu')['state_dict']
+ logger.info(f"Load model from:{self.matcher.load_state_dict(state_dict, strict=False)}")
+ else:
+ raise NotImplementedError
+
+ if self.config.LOFTR.BACKBONE_TYPE == 'RepVGG' and test_mode and (config.METHOD == 'loftr'):
+ module = self.matcher.backbone.layer0
+ if hasattr(module, 'switch_to_deploy'):
+ module.switch_to_deploy()
+ print('m0 switch to deploy ok')
+ for modules in [self.matcher.backbone.layer1, self.matcher.backbone.layer2, self.matcher.backbone.layer3]:
+ for module in modules:
+ if hasattr(module, 'switch_to_deploy'):
+ module.switch_to_deploy()
+ print('m switch to deploy ok')
+
+ # Testing
+ self.dump_dir = dump_dir
+ self.max_gpu_memory = 0
+ self.GPUID = 0
+ self.warmup = False
+
+ def gpumem(self, des, gpuid=None):
+ NUM_EXPAND = 1024 * 1024 * 1024
+        gpu_id = self.GPUID if self.GPUID is not None else gpuid
+ handle = pynvml.nvmlDeviceGetHandleByIndex(gpu_id)
+ info = pynvml.nvmlDeviceGetMemoryInfo(handle)
+ gpu_Used = info.used
+ logger.info(f"GPU {gpu_id} memory used: {gpu_Used / NUM_EXPAND} GB while {des}")
+ # print(des, gpu_Used / NUM_EXPAND)
+ if gpu_Used / NUM_EXPAND > self.max_gpu_memory:
+ self.max_gpu_memory = gpu_Used / NUM_EXPAND
+ logger.info(f"[MAX]GPU {gpu_id} memory used: {gpu_Used / NUM_EXPAND} GB while {des}")
+ print('max_gpu_memory', self.max_gpu_memory)
+
+ def configure_optimizers(self):
+ optimizer = build_optimizer(self, self.config)
+ scheduler = build_scheduler(self.config, optimizer)
+ return [optimizer], [scheduler]
+
+ def optimizer_step(
+ self, epoch, batch_idx, optimizer, optimizer_idx,
+ optimizer_closure, on_tpu, using_native_amp, using_lbfgs):
+ # learning rate warm up
+ warmup_step = self.config.TRAINER.WARMUP_STEP
+ if self.trainer.global_step < warmup_step:
+ if self.config.TRAINER.WARMUP_TYPE == 'linear':
+ base_lr = self.config.TRAINER.WARMUP_RATIO * self.config.TRAINER.TRUE_LR
+ lr = base_lr + \
+ (self.trainer.global_step / self.config.TRAINER.WARMUP_STEP) * \
+ abs(self.config.TRAINER.TRUE_LR - base_lr)
+ for pg in optimizer.param_groups:
+ pg['lr'] = lr
+ elif self.config.TRAINER.WARMUP_TYPE == 'constant':
+ pass
+ else:
+ raise ValueError(f'Unknown lr warm-up strategy: {self.config.TRAINER.WARMUP_TYPE}')
+
+        # update params (the FP16 and FP32 branches issue the identical step call)
+        optimizer.step(closure=optimizer_closure)
+        optimizer.zero_grad()
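+
+    # Warm-up sketch (illustrative numbers, not defaults from this repo): with
+    # WARMUP_RATIO = 0.1 and TRUE_LR = 1e-3, the linear branch above gives
+    # lr = 1e-4 at global_step 0 and ramps towards 1e-3 as global_step reaches
+    # WARMUP_STEP; after that, the scheduler from configure_optimizers() takes over.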
+
+ def _trainval_inference(self, batch):
+ with self.profiler.profile("Compute coarse supervision"):
+
+ with torch.autocast(enabled=False, device_type='cuda'):
+ if ("roma" in self.config.METHOD) or ('dkm' in self.config.METHOD):
+ pass
+ else:
+ compute_supervision_coarse(batch, self.config)
+
+ with self.profiler.profile("LoFTR"):
+ with torch.autocast(enabled=self.config.LOFTR.FP16, device_type='cuda'):
+ self.matcher(batch)
+
+ with self.profiler.profile("Compute fine supervision"):
+ with torch.autocast(enabled=False, device_type='cuda'):
+ if ("roma" in self.config.METHOD) or ('dkm' in self.config.METHOD):
+ compute_roma_supervision(batch, self.config)
+ else:
+ compute_supervision_fine(batch, self.config, self.logger)
+
+ with self.profiler.profile("Compute losses"):
+ pass
+
+ def _compute_metrics(self, batch):
+ if 'gt_2D_matches' in batch:
+ compute_warp_control_pts_errors(batch, self.config)
+ elif batch['homography'].sum() != 0 and batch['T_0to1'].sum() == 0:
+ compute_homo_match_warp_errors(batch, self.config) # compute warp_errors for each match
+ compute_homo_corner_warp_errors(batch, self.config) # compute mean corner warp error each pair
+ else:
+ compute_symmetrical_epipolar_errors(batch, self.config) # compute epi_errs for each match
+ compute_pose_errors(batch, self.config) # compute R_errs, t_errs, pose_errs for each pair
+
+ rel_pair_names = list(zip(*batch['pair_names']))
+ bs = batch['image0'].size(0)
+ if self.config.LOFTR.FINE.MTD_SPVS:
+ topk = self.config.LOFTR.MATCH_FINE.TOPK
+ metrics = {
+ # to filter duplicate pairs caused by DistributedSampler
+ 'identifiers': ['#'.join(rel_pair_names[b]) for b in range(bs)],
+ 'epi_errs': [(batch['epi_errs'].reshape(-1,topk))[batch['m_bids'] == b].reshape(-1).cpu().numpy() for b in range(bs)],
+ 'R_errs': batch['R_errs'],
+ 't_errs': batch['t_errs'],
+ 'inliers': batch['inliers'],
+ 'num_matches': [batch['mconf'].shape[0]], # batch size = 1 only
+ 'percent_inliers': [ batch['inliers'][0].shape[0] / batch['mconf'].shape[0] if batch['mconf'].shape[0]!=0 else 1], # batch size = 1 only
+ }
+ else:
+ metrics = {
+ # to filter duplicate pairs caused by DistributedSampler
+ 'identifiers': ['#'.join(rel_pair_names[b]) for b in range(bs)],
+ 'epi_errs': [batch['epi_errs'][batch['m_bids'] == b].cpu().numpy() for b in range(bs)],
+ 'R_errs': batch['R_errs'],
+ 't_errs': batch['t_errs'],
+ 'inliers': batch['inliers'],
+ 'num_matches': [batch['mconf'].shape[0]], # batch size = 1 only
+ 'percent_inliers': [ batch['inliers'][0].shape[0] / batch['mconf'].shape[0] if batch['mconf'].shape[0]!=0 else 1], # batch size = 1 only
+ }
+ ret_dict = {'metrics': metrics}
+ return ret_dict, rel_pair_names
+
+ def training_step(self, batch, batch_idx):
+ self._trainval_inference(batch)
+
+ # logging
+ if self.trainer.global_rank == 0 and self.global_step % self.trainer.log_every_n_steps == 0:
+ # scalars
+ for k, v in batch['loss_scalars'].items():
+ self.logger.experiment.add_scalar(f'train/{k}', v, self.global_step)
+
+ # net-params
+ method = 'LOFTR'
+ if self.config[method]['MATCH_COARSE']['MATCH_TYPE'] == 'sinkhorn':
+ self.logger.experiment.add_scalar(
+ f'skh_bin_score', self.matcher.coarse_matching.bin_score.clone().detach().cpu().data, self.global_step)
+
+ figures = {}
+ if self.config.TRAINER.ENABLE_PLOTTING:
+ compute_symmetrical_epipolar_errors(batch, self.config) # compute epi_errs for each match
+ figures = make_matching_figures(batch, self.config, self.config.TRAINER.PLOT_MODE)
+ for k, v in figures.items():
+ self.logger.experiment.add_figure(f'train_match/{k}', v, self.global_step)
+
+ return {'loss': batch['loss']}
+
+ def training_epoch_end(self, outputs):
+ avg_loss = torch.stack([x['loss'] for x in outputs]).mean()
+ if self.trainer.global_rank == 0:
+ self.logger.experiment.add_scalar(
+ 'train/avg_loss_on_epoch', avg_loss,
+ global_step=self.current_epoch)
+
+ def validation_step(self, batch, batch_idx, dataloader_idx=0):
+ self._trainval_inference(batch)
+
+ ret_dict, _ = self._compute_metrics(batch)
+
+ val_plot_interval = max(self.trainer.num_val_batches[0] // self.n_vals_plot, 1)
+ figures = {self.config.TRAINER.PLOT_MODE: []}
+ if batch_idx % val_plot_interval == 0:
+ figures = make_matching_figures(batch, self.config, mode=self.config.TRAINER.PLOT_MODE)
+ if self.config.LOFTR.PLOT_SCORES:
+ figs = make_scores_figures(batch, self.config, self.config.TRAINER.PLOT_MODE)
+ figures[self.config.TRAINER.PLOT_MODE] += figs[self.config.TRAINER.PLOT_MODE]
+ del figs
+
+ return {
+ **ret_dict,
+ 'loss_scalars': batch['loss_scalars'],
+ 'figures': figures,
+ }
+
+ def validation_epoch_end(self, outputs):
+ # handle multiple validation sets
+ multi_outputs = [outputs] if not isinstance(outputs[0], (list, tuple)) else outputs
+ multi_val_metrics = defaultdict(list)
+
+ for valset_idx, outputs in enumerate(multi_outputs):
+            # since pl performs sanity_check at the very beginning of training
+ cur_epoch = self.trainer.current_epoch
+ if not self.trainer.resume_from_checkpoint and self.trainer.running_sanity_check:
+ cur_epoch = -1
+
+ # 1. loss_scalars: dict of list, on cpu
+ _loss_scalars = [o['loss_scalars'] for o in outputs]
+ loss_scalars = {k: flattenList(all_gather([_ls[k] for _ls in _loss_scalars])) for k in _loss_scalars[0]}
+
+ # 2. val metrics: dict of list, numpy
+ _metrics = [o['metrics'] for o in outputs]
+ metrics = {k: flattenList(all_gather(flattenList([_me[k] for _me in _metrics]))) for k in _metrics[0]}
+            # NOTE: all ranks need to call `aggregate_metrics`, but only rank-0 logs
+ val_metrics_4tb = aggregate_metrics(metrics, self.config.TRAINER.EPI_ERR_THR, self.config.LOFTR.EVAL_TIMES)
+ for thr in [5, 10, 20]:
+ multi_val_metrics[f'auc@{thr}'].append(val_metrics_4tb[f'auc@{thr}'])
+
+ # 3. figures
+ _figures = [o['figures'] for o in outputs]
+ figures = {k: flattenList(gather(flattenList([_me[k] for _me in _figures]))) for k in _figures[0]}
+
+ # tensorboard records only on rank 0
+ if self.trainer.global_rank == 0:
+ for k, v in loss_scalars.items():
+ mean_v = torch.stack(v).mean()
+ self.logger.experiment.add_scalar(f'val_{valset_idx}/avg_{k}', mean_v, global_step=cur_epoch)
+
+ for k, v in val_metrics_4tb.items():
+ self.logger.experiment.add_scalar(f"metrics_{valset_idx}/{k}", v, global_step=cur_epoch)
+
+ for k, v in figures.items():
+ if self.trainer.global_rank == 0:
+ for plot_idx, fig in enumerate(v):
+ self.logger.experiment.add_figure(
+ f'val_match_{valset_idx}/{k}/pair-{plot_idx}', fig, cur_epoch, close=True)
+ plt.close('all')
+
+ for thr in [5, 10, 20]:
+ self.log(f'auc@{thr}', torch.tensor(np.mean(multi_val_metrics[f'auc@{thr}']))) # ckpt monitors on this
+
+ def test_step(self, batch, batch_idx):
+ if self.warmup:
+ for i in range(50):
+ self.matcher(batch)
+ self.warmup = False
+
+ with torch.autocast(enabled=self.config.LOFTR.FP16, device_type='cuda'):
+ with self.profiler.profile("LoFTR"):
+ self.matcher(batch)
+
+ ret_dict, rel_pair_names = self._compute_metrics(batch)
+ print(ret_dict['metrics']['num_matches'])
+ self.dump_dir = None
+
+ return ret_dict
+
+ def test_epoch_end(self, outputs):
+ print(self.config)
+ print('max GPU memory: ', self.max_gpu_memory)
+ print(self.profiler.summary())
+ # metrics: dict of list, numpy
+ _metrics = [o['metrics'] for o in outputs]
+ metrics = {k: flattenList(gather(flattenList([_me[k] for _me in _metrics]))) for k in _metrics[0]}
+
+ # [{key: [{...}, *#bs]}, *#batch]
+ if self.dump_dir is not None:
+ Path(self.dump_dir).mkdir(parents=True, exist_ok=True)
+ _dumps = flattenList([o['dumps'] for o in outputs]) # [{...}, #bs*#batch]
+ dumps = flattenList(gather(_dumps)) # [{...}, #proc*#bs*#batch]
+ logger.info(f'Prediction and evaluation results will be saved to: {self.dump_dir}')
+
+ if self.trainer.global_rank == 0:
+ NUM_EXPAND = 1024 * 1024 * 1024
+ gpu_id=self.GPUID
+ handle = pynvml.nvmlDeviceGetHandleByIndex(gpu_id)
+ info = pynvml.nvmlDeviceGetMemoryInfo(handle)
+ gpu_Used = info.used
+ print('pynvml', gpu_Used / NUM_EXPAND)
+ if gpu_Used / NUM_EXPAND > self.max_gpu_memory:
+ self.max_gpu_memory = gpu_Used / NUM_EXPAND
+
+ print(self.profiler.summary())
+ val_metrics_4tb = aggregate_metrics(metrics, self.config.TRAINER.EPI_ERR_THR, self.config.LOFTR.EVAL_TIMES, self.config.TRAINER.THRESHOLDS, method=self.config.TRAINER.AUC_METHOD)
+ logger.info('\n' + pprint.pformat(val_metrics_4tb))
+ if self.dump_dir is not None:
+ np.save(Path(self.dump_dir) / 'LoFTR_pred_eval', dumps)
\ No newline at end of file
diff --git a/imcui/third_party/MatchAnything/src/loftr/__init__.py b/imcui/third_party/MatchAnything/src/loftr/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..82e7da71337eb772257c9a2b6c96b41a562aadea
--- /dev/null
+++ b/imcui/third_party/MatchAnything/src/loftr/__init__.py
@@ -0,0 +1 @@
+from .loftr import LoFTR
\ No newline at end of file
diff --git a/imcui/third_party/MatchAnything/src/loftr/backbone/__init__.py b/imcui/third_party/MatchAnything/src/loftr/backbone/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d0eb682da64b684eeddcc5ea576b6e89137dd40b
--- /dev/null
+++ b/imcui/third_party/MatchAnything/src/loftr/backbone/__init__.py
@@ -0,0 +1,61 @@
+from .resnet_fpn import ResNetFPN_8_2, ResNetFPN_16_4, ResNetFPN_8_1, ResNetFPN_8_2_align, ResNetFPN_8_1_align, ResNetFPN_8_2_fix, ResNet_8_1_align, VGG_8_1_align, RepVGG_8_1_align, \
+ RepVGGnfpn_8_1_align, RepVGG_8_2_fix, s2dnet_8_1_align
+
+def build_backbone(config):
+ if config['backbone_type'] == 'ResNetFPN':
+ if config['align_corner'] is None or config['align_corner'] is True:
+ if config['resolution'] == (8, 2):
+ return ResNetFPN_8_2(config['resnetfpn'])
+ elif config['resolution'] == (16, 4):
+ return ResNetFPN_16_4(config['resnetfpn'])
+ elif config['resolution'] == (8, 1):
+ return ResNetFPN_8_1(config['resnetfpn'])
+ elif config['align_corner'] is False:
+ if config['resolution'] == (8, 2):
+ return ResNetFPN_8_2_align(config['resnetfpn'])
+ elif config['resolution'] == (16, 4):
+ return ResNetFPN_16_4(config['resnetfpn'])
+ elif config['resolution'] == (8, 1):
+ return ResNetFPN_8_1_align(config['resnetfpn'])
+ elif config['backbone_type'] == 'ResNetFPNFIX':
+ if config['align_corner'] is None or config['align_corner'] is True:
+ if config['resolution'] == (8, 2):
+ return ResNetFPN_8_2_fix(config['resnetfpn'])
+ elif config['backbone_type'] == 'ResNet':
+ if config['align_corner'] is None or config['align_corner'] is True:
+ raise ValueError(f"LOFTR.BACKBONE_TYPE {config['backbone_type']} not supported.")
+ elif config['align_corner'] is False:
+ if config['resolution'] == (8, 1):
+ return ResNet_8_1_align(config['resnetfpn'])
+ elif config['backbone_type'] == 'VGG':
+ if config['align_corner'] is None or config['align_corner'] is True:
+ raise ValueError(f"LOFTR.BACKBONE_TYPE {config['backbone_type']} not supported.")
+ elif config['align_corner'] is False:
+ if config['resolution'] == (8, 1):
+ return VGG_8_1_align(config['resnetfpn'])
+ elif config['backbone_type'] == 'RepVGG':
+ if config['align_corner'] is None or config['align_corner'] is True:
+ raise ValueError(f"LOFTR.BACKBONE_TYPE {config['backbone_type']} not supported.")
+ elif config['align_corner'] is False:
+ if config['resolution'] == (8, 1):
+ return RepVGG_8_1_align(config['resnetfpn'])
+ elif config['backbone_type'] == 'RepVGGNFPN':
+ if config['align_corner'] is None or config['align_corner'] is True:
+ raise ValueError(f"LOFTR.BACKBONE_TYPE {config['backbone_type']} not supported.")
+ elif config['align_corner'] is False:
+ if config['resolution'] == (8, 1):
+ return RepVGGnfpn_8_1_align(config['resnetfpn'])
+ elif config['backbone_type'] == 'RepVGGFPNFIX':
+ if config['align_corner'] is None or config['align_corner'] is True:
+ if config['resolution'] == (8, 2):
+ return RepVGG_8_2_fix(config['resnetfpn'])
+ elif config['align_corner'] is False:
+ raise ValueError(f"LOFTR.BACKBONE_TYPE {config['backbone_type']} not supported.")
+ elif config['backbone_type'] == 's2dnet':
+ if config['align_corner'] is None or config['align_corner'] is True:
+ raise ValueError(f"LOFTR.BACKBONE_TYPE {config['backbone_type']} not supported.")
+ elif config['align_corner'] is False:
+ if config['resolution'] == (8, 1):
+ return s2dnet_8_1_align(config['resnetfpn'])
+ else:
+ raise ValueError(f"LOFTR.BACKBONE_TYPE {config['backbone_type']} not supported.")
diff --git a/imcui/third_party/MatchAnything/src/loftr/backbone/repvgg.py b/imcui/third_party/MatchAnything/src/loftr/backbone/repvgg.py
new file mode 100644
index 0000000000000000000000000000000000000000..873a934dc0094fc742076c10efbaafcc78c283a7
--- /dev/null
+++ b/imcui/third_party/MatchAnything/src/loftr/backbone/repvgg.py
@@ -0,0 +1,319 @@
+# --------------------------------------------------------
+# RepVGG: Making VGG-style ConvNets Great Again (https://openaccess.thecvf.com/content/CVPR2021/papers/Ding_RepVGG_Making_VGG-Style_ConvNets_Great_Again_CVPR_2021_paper.pdf)
+# Github source: https://github.com/DingXiaoH/RepVGG
+# Licensed under The MIT License [see LICENSE for details]
+# --------------------------------------------------------
+import torch.nn as nn
+import numpy as np
+import torch
+import copy
+# from se_block import SEBlock
+import torch.utils.checkpoint as checkpoint
+from loguru import logger
+
+def conv_bn(in_channels, out_channels, kernel_size, stride, padding, groups=1):
+ result = nn.Sequential()
+ result.add_module('conv', nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
+ kernel_size=kernel_size, stride=stride, padding=padding, groups=groups, bias=False))
+ result.add_module('bn', nn.BatchNorm2d(num_features=out_channels))
+ return result
+
+class RepVGGBlock(nn.Module):
+
+ def __init__(self, in_channels, out_channels, kernel_size,
+ stride=1, padding=0, dilation=1, groups=1, padding_mode='zeros', deploy=False, use_se=False, leaky=-1.0):
+ super(RepVGGBlock, self).__init__()
+ self.deploy = deploy
+ self.groups = groups
+ self.in_channels = in_channels
+
+ assert kernel_size == 3
+ assert padding == 1
+
+ padding_11 = padding - kernel_size // 2
+
+ if leaky == -2:
+ self.nonlinearity = nn.Identity()
+ logger.info(f"Using Identity nonlinearity in repvgg_block")
+ elif leaky < 0:
+ self.nonlinearity = nn.ReLU()
+ else:
+ self.nonlinearity = nn.LeakyReLU(leaky)
+
+ if use_se:
+            # Note that RepVGG-D2se uses SE before nonlinearity. But RepVGGplus models use SE after nonlinearity.
+ # self.se = SEBlock(out_channels, internal_neurons=out_channels // 16)
+ raise ValueError(f"SEBlock not supported")
+ else:
+ self.se = nn.Identity()
+
+ if deploy:
+ self.rbr_reparam = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride,
+ padding=padding, dilation=dilation, groups=groups, bias=True, padding_mode=padding_mode)
+
+ else:
+ self.rbr_identity = nn.BatchNorm2d(num_features=in_channels) if out_channels == in_channels and stride == 1 else None
+ self.rbr_dense = conv_bn(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding, groups=groups)
+ self.rbr_1x1 = conv_bn(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=stride, padding=padding_11, groups=groups)
+ print('RepVGG Block, identity = ', self.rbr_identity)
+
+
+ def forward(self, inputs):
+ if hasattr(self, 'rbr_reparam'):
+ return self.nonlinearity(self.se(self.rbr_reparam(inputs)))
+
+ if self.rbr_identity is None:
+ id_out = 0
+ else:
+ id_out = self.rbr_identity(inputs)
+
+ return self.nonlinearity(self.se(self.rbr_dense(inputs) + self.rbr_1x1(inputs) + id_out))
+
+
+    # Optional. This may improve the accuracy and facilitate quantization in some cases.
+ # 1. Cancel the original weight decay on rbr_dense.conv.weight and rbr_1x1.conv.weight.
+ # 2. Use like this.
+ # loss = criterion(....)
+ # for every RepVGGBlock blk:
+ # loss += weight_decay_coefficient * 0.5 * blk.get_cust_L2()
+ # optimizer.zero_grad()
+ # loss.backward()
+ def get_custom_L2(self):
+ K3 = self.rbr_dense.conv.weight
+ K1 = self.rbr_1x1.conv.weight
+ t3 = (self.rbr_dense.bn.weight / ((self.rbr_dense.bn.running_var + self.rbr_dense.bn.eps).sqrt())).reshape(-1, 1, 1, 1).detach()
+ t1 = (self.rbr_1x1.bn.weight / ((self.rbr_1x1.bn.running_var + self.rbr_1x1.bn.eps).sqrt())).reshape(-1, 1, 1, 1).detach()
+
+ l2_loss_circle = (K3 ** 2).sum() - (K3[:, :, 1:2, 1:2] ** 2).sum() # The L2 loss of the "circle" of weights in 3x3 kernel. Use regular L2 on them.
+ eq_kernel = K3[:, :, 1:2, 1:2] * t3 + K1 * t1 # The equivalent resultant central point of 3x3 kernel.
+ l2_loss_eq_kernel = (eq_kernel ** 2 / (t3 ** 2 + t1 ** 2)).sum() # Normalize for an L2 coefficient comparable to regular L2.
+ return l2_loss_eq_kernel + l2_loss_circle
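+
+    # Concrete sketch of the recipe above (hedged; the coefficient is illustrative):
+    #   custom_l2 = sum(m.get_custom_L2() for m in model.modules()
+    #                   if isinstance(m, RepVGGBlock) and not m.deploy)
+    #   loss = task_loss + 1e-4 * 0.5 * custom_l2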
+
+
+
+    # This func derives the equivalent kernel and bias in a DIFFERENTIABLE way.
+    # You can get the equivalent kernel and bias at any time and do whatever you want,
+    # for example, apply some penalties or constraints during training, just like you do to the other models.
+    # May be useful for quantization or pruning.
+ def get_equivalent_kernel_bias(self):
+ kernel3x3, bias3x3 = self._fuse_bn_tensor(self.rbr_dense)
+ kernel1x1, bias1x1 = self._fuse_bn_tensor(self.rbr_1x1)
+ kernelid, biasid = self._fuse_bn_tensor(self.rbr_identity)
+ return kernel3x3 + self._pad_1x1_to_3x3_tensor(kernel1x1) + kernelid, bias3x3 + bias1x1 + biasid
+
+ def _pad_1x1_to_3x3_tensor(self, kernel1x1):
+ if kernel1x1 is None:
+ return 0
+ else:
+ return torch.nn.functional.pad(kernel1x1, [1,1,1,1])
+
+ def _fuse_bn_tensor(self, branch):
+ if branch is None:
+ return 0, 0
+ if isinstance(branch, nn.Sequential):
+ kernel = branch.conv.weight
+ running_mean = branch.bn.running_mean
+ running_var = branch.bn.running_var
+ gamma = branch.bn.weight
+ beta = branch.bn.bias
+ eps = branch.bn.eps
+ else:
+ assert isinstance(branch, nn.BatchNorm2d)
+ if not hasattr(self, 'id_tensor'):
+ input_dim = self.in_channels // self.groups
+ kernel_value = np.zeros((self.in_channels, input_dim, 3, 3), dtype=np.float32)
+ for i in range(self.in_channels):
+ kernel_value[i, i % input_dim, 1, 1] = 1
+ self.id_tensor = torch.from_numpy(kernel_value).to(branch.weight.device)
+ kernel = self.id_tensor
+ running_mean = branch.running_mean
+ running_var = branch.running_var
+ gamma = branch.weight
+ beta = branch.bias
+ eps = branch.eps
+ std = (running_var + eps).sqrt()
+ t = (gamma / std).reshape(-1, 1, 1, 1)
+ return kernel * t, beta - running_mean * gamma / std
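+
+    # BN fusion used above (standard identity, restated for reference): for a conv
+    # kernel W followed by BatchNorm(gamma, beta, running_mean, running_var, eps),
+    #   std = sqrt(running_var + eps)
+    #   W_fused = W * (gamma / std)          (per output channel)
+    #   b_fused = beta - running_mean * gamma / std
+    # which matches the `kernel * t` and `beta - running_mean * gamma / std` returned here.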
+
+ def switch_to_deploy(self):
+ if hasattr(self, 'rbr_reparam'):
+ return
+ kernel, bias = self.get_equivalent_kernel_bias()
+ self.rbr_reparam = nn.Conv2d(in_channels=self.rbr_dense.conv.in_channels, out_channels=self.rbr_dense.conv.out_channels,
+ kernel_size=self.rbr_dense.conv.kernel_size, stride=self.rbr_dense.conv.stride,
+ padding=self.rbr_dense.conv.padding, dilation=self.rbr_dense.conv.dilation, groups=self.rbr_dense.conv.groups, bias=True)
+ self.rbr_reparam.weight.data = kernel
+ self.rbr_reparam.bias.data = bias
+ self.__delattr__('rbr_dense')
+ self.__delattr__('rbr_1x1')
+ if hasattr(self, 'rbr_identity'):
+ self.__delattr__('rbr_identity')
+ if hasattr(self, 'id_tensor'):
+ self.__delattr__('id_tensor')
+ self.deploy = True
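+
+    # Sanity-check sketch (illustrative): after switch_to_deploy(), the fused block
+    # should reproduce the training-time output up to numerical precision, e.g.
+    #   block.eval()
+    #   x = torch.randn(1, block.in_channels, 32, 32)
+    #   y_train = block(x)
+    #   block.switch_to_deploy()
+    #   assert torch.allclose(y_train, block(x), atol=1e-4)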
+
+
+
+class RepVGG(nn.Module):
+
+ def __init__(self, num_blocks, num_classes=1000, width_multiplier=None, override_groups_map=None, deploy=False, use_se=False, use_checkpoint=False, leaky=-1.0):
+ super(RepVGG, self).__init__()
+ assert len(width_multiplier) == 4
+ self.deploy = deploy
+ self.override_groups_map = override_groups_map or dict()
+ assert 0 not in self.override_groups_map
+ self.use_se = use_se
+ self.use_checkpoint = use_checkpoint
+
+ self.in_planes = min(64, int(64 * width_multiplier[0]))
+ self.stage0 = RepVGGBlock(in_channels=1, out_channels=self.in_planes, kernel_size=3, stride=2, padding=1, deploy=self.deploy, use_se=self.use_se, leaky=leaky)
+ self.cur_layer_idx = 1
+ self.stage1 = self._make_stage(int(64 * width_multiplier[0]), num_blocks[0], stride=1, leaky=leaky)
+ self.stage2 = self._make_stage(int(128 * width_multiplier[1]), num_blocks[1], stride=2, leaky=leaky)
+ self.stage3 = self._make_stage(int(256 * width_multiplier[2]), num_blocks[2], stride=2, leaky=leaky)
+ # self.stage4 = self._make_stage(int(512 * width_multiplier[3]), num_blocks[3], stride=1)
+ # self.gap = nn.AdaptiveAvgPool2d(output_size=1)
+ # self.linear = nn.Linear(int(512 * width_multiplier[3]), num_classes)
+
+ def _make_stage(self, planes, num_blocks, stride, leaky=-1.0):
+ strides = [stride] + [1]*(num_blocks-1)
+ blocks = []
+ for stride in strides:
+ cur_groups = self.override_groups_map.get(self.cur_layer_idx, 1)
+ blocks.append(RepVGGBlock(in_channels=self.in_planes, out_channels=planes, kernel_size=3,
+ stride=stride, padding=1, groups=cur_groups, deploy=self.deploy, use_se=self.use_se, leaky=leaky))
+ self.in_planes = planes
+ self.cur_layer_idx += 1
+ return nn.ModuleList(blocks)
+
+ def forward(self, x):
+ out = self.stage0(x)
+ for stage in (self.stage1, self.stage2, self.stage3): # , self.stage4):
+ for block in stage:
+ if self.use_checkpoint:
+ out = checkpoint.checkpoint(block, out)
+ else:
+ out = block(out)
+        # The classification head (gap + linear) is commented out in __init__ for this
+        # backbone-only variant, so return the final-stage feature map directly.
+        return out
+
+
+optional_groupwise_layers = [2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26]
+g2_map = {l: 2 for l in optional_groupwise_layers}
+g4_map = {l: 4 for l in optional_groupwise_layers}
+
+def create_RepVGG_A0(deploy=False, use_checkpoint=False):
+ return RepVGG(num_blocks=[2, 4, 14, 1], num_classes=1000,
+ width_multiplier=[0.75, 0.75, 0.75, 2.5], override_groups_map=None, deploy=deploy, use_checkpoint=use_checkpoint)
+
+def create_RepVGG_A1(deploy=False, use_checkpoint=False):
+ return RepVGG(num_blocks=[2, 4, 14, 1], num_classes=1000,
+ width_multiplier=[1, 1, 1, 2.5], override_groups_map=None, deploy=deploy, use_checkpoint=use_checkpoint)
+def create_RepVGG_A15(deploy=False, use_checkpoint=False):
+ return RepVGG(num_blocks=[2, 4, 14, 1], num_classes=1000,
+ width_multiplier=[1.25, 1.25, 1.25, 2.5], override_groups_map=None, deploy=deploy, use_checkpoint=use_checkpoint)
+def create_RepVGG_A1_leaky(deploy=False, use_checkpoint=False):
+ return RepVGG(num_blocks=[2, 4, 14, 1], num_classes=1000,
+ width_multiplier=[1, 1, 1, 2.5], override_groups_map=None, deploy=deploy, use_checkpoint=use_checkpoint, leaky=0.01)
+
+def create_RepVGG_A2(deploy=False, use_checkpoint=False):
+ return RepVGG(num_blocks=[2, 4, 14, 1], num_classes=1000,
+ width_multiplier=[1.5, 1.5, 1.5, 2.75], override_groups_map=None, deploy=deploy, use_checkpoint=use_checkpoint)
+
+def create_RepVGG_B0(deploy=False, use_checkpoint=False):
+ return RepVGG(num_blocks=[4, 6, 16, 1], num_classes=1000,
+ width_multiplier=[1, 1, 1, 2.5], override_groups_map=None, deploy=deploy, use_checkpoint=use_checkpoint)
+
+def create_RepVGG_B1(deploy=False, use_checkpoint=False):
+ return RepVGG(num_blocks=[4, 6, 16, 1], num_classes=1000,
+ width_multiplier=[2, 2, 2, 4], override_groups_map=None, deploy=deploy, use_checkpoint=use_checkpoint)
+
+def create_RepVGG_B1g2(deploy=False, use_checkpoint=False):
+ return RepVGG(num_blocks=[4, 6, 16, 1], num_classes=1000,
+ width_multiplier=[2, 2, 2, 4], override_groups_map=g2_map, deploy=deploy, use_checkpoint=use_checkpoint)
+
+def create_RepVGG_B1g4(deploy=False, use_checkpoint=False):
+ return RepVGG(num_blocks=[4, 6, 16, 1], num_classes=1000,
+ width_multiplier=[2, 2, 2, 4], override_groups_map=g4_map, deploy=deploy, use_checkpoint=use_checkpoint)
+
+
+def create_RepVGG_B2(deploy=False, use_checkpoint=False):
+ return RepVGG(num_blocks=[4, 6, 16, 1], num_classes=1000,
+ width_multiplier=[2.5, 2.5, 2.5, 5], override_groups_map=None, deploy=deploy, use_checkpoint=use_checkpoint)
+
+def create_RepVGG_B2g2(deploy=False, use_checkpoint=False):
+ return RepVGG(num_blocks=[4, 6, 16, 1], num_classes=1000,
+ width_multiplier=[2.5, 2.5, 2.5, 5], override_groups_map=g2_map, deploy=deploy, use_checkpoint=use_checkpoint)
+
+def create_RepVGG_B2g4(deploy=False, use_checkpoint=False):
+ return RepVGG(num_blocks=[4, 6, 16, 1], num_classes=1000,
+ width_multiplier=[2.5, 2.5, 2.5, 5], override_groups_map=g4_map, deploy=deploy, use_checkpoint=use_checkpoint)
+
+
+def create_RepVGG_B3(deploy=False, use_checkpoint=False):
+ return RepVGG(num_blocks=[4, 6, 16, 1], num_classes=1000,
+ width_multiplier=[3, 3, 3, 5], override_groups_map=None, deploy=deploy, use_checkpoint=use_checkpoint)
+
+def create_RepVGG_B3g2(deploy=False, use_checkpoint=False):
+ return RepVGG(num_blocks=[4, 6, 16, 1], num_classes=1000,
+ width_multiplier=[3, 3, 3, 5], override_groups_map=g2_map, deploy=deploy, use_checkpoint=use_checkpoint)
+
+def create_RepVGG_B3g4(deploy=False, use_checkpoint=False):
+ return RepVGG(num_blocks=[4, 6, 16, 1], num_classes=1000,
+ width_multiplier=[3, 3, 3, 5], override_groups_map=g4_map, deploy=deploy, use_checkpoint=use_checkpoint)
+
+def create_RepVGG_D2se(deploy=False, use_checkpoint=False):
+ return RepVGG(num_blocks=[8, 14, 24, 1], num_classes=1000,
+ width_multiplier=[2.5, 2.5, 2.5, 5], override_groups_map=None, deploy=deploy, use_se=True, use_checkpoint=use_checkpoint)
+
+
+func_dict = {
+'RepVGG-A0': create_RepVGG_A0,
+'RepVGG-A1': create_RepVGG_A1,
+'RepVGG-A15': create_RepVGG_A15,
+'RepVGG-A1_leaky': create_RepVGG_A1_leaky,
+'RepVGG-A2': create_RepVGG_A2,
+'RepVGG-B0': create_RepVGG_B0,
+'RepVGG-B1': create_RepVGG_B1,
+'RepVGG-B1g2': create_RepVGG_B1g2,
+'RepVGG-B1g4': create_RepVGG_B1g4,
+'RepVGG-B2': create_RepVGG_B2,
+'RepVGG-B2g2': create_RepVGG_B2g2,
+'RepVGG-B2g4': create_RepVGG_B2g4,
+'RepVGG-B3': create_RepVGG_B3,
+'RepVGG-B3g2': create_RepVGG_B3g2,
+'RepVGG-B3g4': create_RepVGG_B3g4,
+'RepVGG-D2se': create_RepVGG_D2se,  # Updated on April 25, 2021. This is not reported in the CVPR paper.
+}
+def get_RepVGG_func_by_name(name):
+ return func_dict[name]
+
+
+
+# Use this for converting a RepVGG model or a bigger model with RepVGG as its component
+# Use like this
+# model = create_RepVGG_A0(deploy=False)
+# train model or load weights
+# repvgg_model_convert(model, save_path='repvgg_deploy.pth')
+# If you want to preserve the original model, call with do_copy=True
+
+# ====================== for using RepVGG as the backbone of a bigger model, e.g., PSPNet, the pseudo code will be like
+# train_backbone = create_RepVGG_B2(deploy=False)
+# train_backbone.load_state_dict(torch.load('RepVGG-B2-train.pth'))
+# train_pspnet = build_pspnet(backbone=train_backbone)
+# segmentation_train(train_pspnet)
+# deploy_pspnet = repvgg_model_convert(train_pspnet)
+# segmentation_test(deploy_pspnet)
+# ===================== example_pspnet.py shows an example
+
+def repvgg_model_convert(model: torch.nn.Module, save_path=None, do_copy=True):
+ if do_copy:
+ model = copy.deepcopy(model)
+ for module in model.modules():
+ if hasattr(module, 'switch_to_deploy'):
+ module.switch_to_deploy()
+ if save_path is not None:
+ torch.save(model.state_dict(), save_path)
+ return model
diff --git a/imcui/third_party/MatchAnything/src/loftr/backbone/resnet_fpn.py b/imcui/third_party/MatchAnything/src/loftr/backbone/resnet_fpn.py
new file mode 100644
index 0000000000000000000000000000000000000000..3596d7bd7f827197476e3f6ffaa1770a6913a3f8
--- /dev/null
+++ b/imcui/third_party/MatchAnything/src/loftr/backbone/resnet_fpn.py
@@ -0,0 +1,1094 @@
+import torch.nn as nn
+import torch.nn.functional as F
+from .repvgg import get_RepVGG_func_by_name
+from .s2dnet import S2DNet
+
+
+def conv1x1(in_planes, out_planes, stride=1):
+ """1x1 convolution without padding"""
+ return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, padding=0, bias=False)
+
+
+def conv3x3(in_planes, out_planes, stride=1):
+ """3x3 convolution with padding"""
+ return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
+
+
+class BasicBlock(nn.Module):
+ def __init__(self, in_planes, planes, stride=1):
+ super().__init__()
+ self.conv1 = conv3x3(in_planes, planes, stride)
+ self.conv2 = conv3x3(planes, planes)
+ self.bn1 = nn.BatchNorm2d(planes)
+ self.bn2 = nn.BatchNorm2d(planes)
+ self.relu = nn.ReLU(inplace=True)
+
+ if stride == 1:
+ self.downsample = None
+ else:
+ self.downsample = nn.Sequential(
+ conv1x1(in_planes, planes, stride=stride),
+ nn.BatchNorm2d(planes)
+ )
+
+ def forward(self, x):
+ y = x
+ y = self.relu(self.bn1(self.conv1(y)))
+ y = self.bn2(self.conv2(y))
+
+ if self.downsample is not None:
+ x = self.downsample(x)
+
+ return self.relu(x+y)
+
+
+class ResNetFPN_8_2(nn.Module):
+ """
+    ResNet+FPN, output resolutions are 1/8 and 1/2.
+ Each block has 2 layers.
+ """
+
+ def __init__(self, config):
+ super().__init__()
+ # Config
+ block = BasicBlock
+ initial_dim = config['initial_dim']
+ block_dims = config['block_dims']
+
+ # Class Variable
+ self.in_planes = initial_dim
+
+ # Networks
+ self.conv1 = nn.Conv2d(1, initial_dim, kernel_size=7, stride=2, padding=3, bias=False)
+ self.bn1 = nn.BatchNorm2d(initial_dim)
+ self.relu = nn.ReLU(inplace=True)
+
+ self.layer1 = self._make_layer(block, block_dims[0], stride=1) # 1/2
+ self.layer2 = self._make_layer(block, block_dims[1], stride=2) # 1/4
+ self.layer3 = self._make_layer(block, block_dims[2], stride=2) # 1/8
+
+ # 3. FPN upsample
+ self.layer3_outconv = conv1x1(block_dims[2], block_dims[2])
+ self.layer2_outconv = conv1x1(block_dims[1], block_dims[2])
+ self.layer2_outconv2 = nn.Sequential(
+ conv3x3(block_dims[2], block_dims[2]),
+ nn.BatchNorm2d(block_dims[2]),
+ nn.LeakyReLU(),
+ conv3x3(block_dims[2], block_dims[1]),
+ )
+ self.layer1_outconv = conv1x1(block_dims[0], block_dims[1])
+ self.layer1_outconv2 = nn.Sequential(
+ conv3x3(block_dims[1], block_dims[1]),
+ nn.BatchNorm2d(block_dims[1]),
+ nn.LeakyReLU(),
+ conv3x3(block_dims[1], block_dims[0]),
+ )
+
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+ elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
+ nn.init.constant_(m.weight, 1)
+ nn.init.constant_(m.bias, 0)
+
+ def _make_layer(self, block, dim, stride=1):
+ layer1 = block(self.in_planes, dim, stride=stride)
+ layer2 = block(dim, dim, stride=1)
+ layers = (layer1, layer2)
+
+ self.in_planes = dim
+ return nn.Sequential(*layers)
+
+ def forward(self, x):
+ # ResNet Backbone
+ x0 = self.relu(self.bn1(self.conv1(x)))
+ x1 = self.layer1(x0) # 1/2
+ x2 = self.layer2(x1) # 1/4
+ x3 = self.layer3(x2) # 1/8
+
+ # FPN
+ x3_out = self.layer3_outconv(x3)
+
+ x3_out_2x = F.interpolate(x3_out, scale_factor=2., mode='bilinear', align_corners=True)
+ x2_out = self.layer2_outconv(x2)
+ x2_out = self.layer2_outconv2(x2_out+x3_out_2x)
+
+ x2_out_2x = F.interpolate(x2_out, scale_factor=2., mode='bilinear', align_corners=True)
+ x1_out = self.layer1_outconv(x1)
+ x1_out = self.layer1_outconv2(x1_out+x2_out_2x)
+
+ return {'feats_c': x3_out, 'feats_f': x1_out}
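+
+    # Shape sketch (assuming block_dims = [128, 196, 256] as in typical LoFTR configs;
+    # treat these numbers as an assumption): for a 1x1x480x640 grayscale input,
+    # feats_c is 1x256x60x80 (1/8 resolution) and feats_f is 1x128x240x320 (1/2).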
+
+ def pro(self, x, profiler):
+ with profiler.profile('ResNet Backbone'):
+ # ResNet Backbone
+ x0 = self.relu(self.bn1(self.conv1(x)))
+ x1 = self.layer1(x0) # 1/2
+ x2 = self.layer2(x1) # 1/4
+ x3 = self.layer3(x2) # 1/8
+
+ with profiler.profile('ResNet FPN'):
+ # FPN
+ x3_out = self.layer3_outconv(x3)
+
+ x3_out_2x = F.interpolate(x3_out, scale_factor=2., mode='bilinear', align_corners=True)
+ x2_out = self.layer2_outconv(x2)
+ x2_out = self.layer2_outconv2(x2_out+x3_out_2x)
+
+ x2_out_2x = F.interpolate(x2_out, scale_factor=2., mode='bilinear', align_corners=True)
+ x1_out = self.layer1_outconv(x1)
+ x1_out = self.layer1_outconv2(x1_out+x2_out_2x)
+
+ return {'feats_c': x3_out, 'feats_f': x1_out}
+
+class ResNetFPN_8_2_fix(nn.Module):
+ """
+    ResNet+FPN, output resolutions are 1/8 and 1/2.
+ Each block has 2 layers.
+ """
+
+ def __init__(self, config):
+ super().__init__()
+ # Config
+ block = BasicBlock
+ initial_dim = config['initial_dim']
+ block_dims = config['block_dims']
+
+ # Class Variable
+ self.in_planes = initial_dim
+ self.skip_fine_feature = config['coarse_feat_only']
+ self.inter_feat = config['inter_feat']
+
+ # Networks
+ self.conv1 = nn.Conv2d(1, initial_dim, kernel_size=7, stride=2, padding=3, bias=False)
+ self.bn1 = nn.BatchNorm2d(initial_dim)
+ self.relu = nn.ReLU(inplace=True)
+
+ self.layer1 = self._make_layer(block, block_dims[0], stride=1) # 1/2
+ self.layer2 = self._make_layer(block, block_dims[1], stride=2) # 1/4
+ self.layer3 = self._make_layer(block, block_dims[2], stride=2) # 1/8
+
+ # 3. FPN upsample
+ self.layer3_outconv = conv1x1(block_dims[2], block_dims[2])
+ self.layer2_outconv = conv1x1(block_dims[1], block_dims[2])
+ self.layer2_outconv2 = nn.Sequential(
+ conv3x3(block_dims[2], block_dims[2]),
+ nn.BatchNorm2d(block_dims[2]),
+ nn.LeakyReLU(),
+ conv3x3(block_dims[2], block_dims[1]),
+ )
+ self.layer1_outconv = conv1x1(block_dims[0], block_dims[1])
+ self.layer1_outconv2 = nn.Sequential(
+ conv3x3(block_dims[1], block_dims[1]),
+ nn.BatchNorm2d(block_dims[1]),
+ nn.LeakyReLU(),
+ conv3x3(block_dims[1], block_dims[0]),
+ )
+
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+ elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
+ nn.init.constant_(m.weight, 1)
+ nn.init.constant_(m.bias, 0)
+
+ def _make_layer(self, block, dim, stride=1):
+ layer1 = block(self.in_planes, dim, stride=stride)
+ layer2 = block(dim, dim, stride=1)
+ layers = (layer1, layer2)
+
+ self.in_planes = dim
+ return nn.Sequential(*layers)
+
+ def forward(self, x):
+ # ResNet Backbone
+ x0 = self.relu(self.bn1(self.conv1(x)))
+ x1 = self.layer1(x0) # 1/2
+ x2 = self.layer2(x1) # 1/4
+ x3 = self.layer3(x2) # 1/8
+
+ # FPN
+ if self.skip_fine_feature:
+ if self.inter_feat:
+ return {'feats_c': x3, 'feats_f': None, 'feats_x2': x2, 'feats_x1': x1}
+ else:
+ return {'feats_c': x3, 'feats_f': None}
+
+
+ x3_out = self.layer3_outconv(x3) # n+1
+
+ x3_out_2x = F.interpolate(x3_out, size=((x3_out.size(-2)-1)*2+1, (x3_out.size(-1)-1)*2+1), mode='bilinear', align_corners=True) # 2n+1
+ x2_out = self.layer2_outconv(x2)
+ x2_out = self.layer2_outconv2(x2_out+x3_out_2x)
+
+ x2_out_2x = F.interpolate(x2_out, size=((x2_out.size(-2)-1)*2+1, (x2_out.size(-1)-1)*2+1), mode='bilinear', align_corners=True) # 4n+1
+ x1_out = self.layer1_outconv(x1)
+ x1_out = self.layer1_outconv2(x1_out+x2_out_2x)
+
+ return {'feats_c': x3_out, 'feats_f': x1_out}
+
+
+class ResNetFPN_16_4(nn.Module):
+ """
+    ResNet+FPN, output resolutions are 1/16 and 1/4.
+ Each block has 2 layers.
+ """
+
+ def __init__(self, config):
+ super().__init__()
+ # Config
+ block = BasicBlock
+ initial_dim = config['initial_dim']
+ block_dims = config['block_dims']
+
+ # Class Variable
+ self.in_planes = initial_dim
+
+ # Networks
+ self.conv1 = nn.Conv2d(1, initial_dim, kernel_size=7, stride=2, padding=3, bias=False)
+ self.bn1 = nn.BatchNorm2d(initial_dim)
+ self.relu = nn.ReLU(inplace=True)
+
+ self.layer1 = self._make_layer(block, block_dims[0], stride=1) # 1/2
+ self.layer2 = self._make_layer(block, block_dims[1], stride=2) # 1/4
+ self.layer3 = self._make_layer(block, block_dims[2], stride=2) # 1/8
+ self.layer4 = self._make_layer(block, block_dims[3], stride=2) # 1/16
+
+ # 3. FPN upsample
+ self.layer4_outconv = conv1x1(block_dims[3], block_dims[3])
+ self.layer3_outconv = conv1x1(block_dims[2], block_dims[3])
+ self.layer3_outconv2 = nn.Sequential(
+ conv3x3(block_dims[3], block_dims[3]),
+ nn.BatchNorm2d(block_dims[3]),
+ nn.LeakyReLU(),
+ conv3x3(block_dims[3], block_dims[2]),
+ )
+
+ self.layer2_outconv = conv1x1(block_dims[1], block_dims[2])
+ self.layer2_outconv2 = nn.Sequential(
+ conv3x3(block_dims[2], block_dims[2]),
+ nn.BatchNorm2d(block_dims[2]),
+ nn.LeakyReLU(),
+ conv3x3(block_dims[2], block_dims[1]),
+ )
+
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+ elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
+ nn.init.constant_(m.weight, 1)
+ nn.init.constant_(m.bias, 0)
+
+ def _make_layer(self, block, dim, stride=1):
+ layer1 = block(self.in_planes, dim, stride=stride)
+ layer2 = block(dim, dim, stride=1)
+ layers = (layer1, layer2)
+
+ self.in_planes = dim
+ return nn.Sequential(*layers)
+
+ def forward(self, x):
+ # ResNet Backbone
+ x0 = self.relu(self.bn1(self.conv1(x)))
+ x1 = self.layer1(x0) # 1/2
+ x2 = self.layer2(x1) # 1/4
+ x3 = self.layer3(x2) # 1/8
+ x4 = self.layer4(x3) # 1/16
+
+ # FPN
+ x4_out = self.layer4_outconv(x4)
+
+ x4_out_2x = F.interpolate(x4_out, scale_factor=2., mode='bilinear', align_corners=True)
+ x3_out = self.layer3_outconv(x3)
+ x3_out = self.layer3_outconv2(x3_out+x4_out_2x)
+
+ x3_out_2x = F.interpolate(x3_out, scale_factor=2., mode='bilinear', align_corners=True)
+ x2_out = self.layer2_outconv(x2)
+ x2_out = self.layer2_outconv2(x2_out+x3_out_2x)
+
+ return {'feats_c': x4_out, 'feats_f': x2_out}
+
+
+class ResNetFPN_8_1(nn.Module):
+ """
+    ResNet+FPN, output resolutions are 1/8 and 1.
+ Each block has 2 layers.
+ """
+
+ def __init__(self, config):
+ super().__init__()
+ # Config
+ block = BasicBlock
+ initial_dim = config['initial_dim']
+ block_dims = config['block_dims']
+
+ # Class Variable
+ self.in_planes = initial_dim
+ self.skip_fine_feature = config['coarse_feat_only']
+ self.inter_feat = config['inter_feat']
+
+ # Networks
+ self.conv1 = nn.Conv2d(1, initial_dim, kernel_size=7, stride=2, padding=3, bias=False)
+ self.bn1 = nn.BatchNorm2d(initial_dim)
+ self.relu = nn.ReLU(inplace=True)
+
+ self.layer1 = self._make_layer(block, block_dims[0], stride=1) # 1/2
+ self.layer2 = self._make_layer(block, block_dims[1], stride=2) # 1/4
+ self.layer3 = self._make_layer(block, block_dims[2], stride=2) # 1/8
+
+ # 3. FPN upsample
+ if not self.skip_fine_feature:
+ self.layer3_outconv = conv1x1(block_dims[2], block_dims[2])
+ self.layer2_outconv = conv1x1(block_dims[1], block_dims[2])
+ self.layer2_outconv2 = nn.Sequential(
+ conv3x3(block_dims[2], block_dims[2]),
+ nn.BatchNorm2d(block_dims[2]),
+ nn.LeakyReLU(),
+ conv3x3(block_dims[2], block_dims[1]),
+ )
+ self.layer1_outconv = conv1x1(block_dims[0], block_dims[1])
+ self.layer1_outconv2 = nn.Sequential(
+ conv3x3(block_dims[1], block_dims[1]),
+ nn.BatchNorm2d(block_dims[1]),
+ nn.LeakyReLU(),
+ conv3x3(block_dims[1], block_dims[0]),
+ )
+
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+ elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
+ nn.init.constant_(m.weight, 1)
+ nn.init.constant_(m.bias, 0)
+
+ def _make_layer(self, block, dim, stride=1):
+ layer1 = block(self.in_planes, dim, stride=stride)
+ layer2 = block(dim, dim, stride=1)
+ layers = (layer1, layer2)
+
+ self.in_planes = dim
+ return nn.Sequential(*layers)
+
+ def forward(self, x):
+ # ResNet Backbone
+ x0 = self.relu(self.bn1(self.conv1(x)))
+ x1 = self.layer1(x0) # 1/2
+ x2 = self.layer2(x1) # 1/4
+ x3 = self.layer3(x2) # 1/8
+
+ # FPN
+ if self.skip_fine_feature:
+ if self.inter_feat:
+ return {'feats_c': x3, 'feats_f': None, 'feats_x2': x2, 'feats_x1': x1}
+ else:
+ return {'feats_c': x3, 'feats_f': None}
+
+ x3_out = self.layer3_outconv(x3)
+
+ x3_out_2x = F.interpolate(x3_out, scale_factor=2., mode='bilinear', align_corners=True)
+ x2_out = self.layer2_outconv(x2)
+ x2_out = self.layer2_outconv2(x2_out+x3_out_2x)
+
+ x2_out_2x = F.interpolate(x2_out, scale_factor=2., mode='bilinear', align_corners=True)
+ x1_out = self.layer1_outconv(x1)
+ x1_out = self.layer1_outconv2(x1_out+x2_out_2x)
+
+ x0_out = F.interpolate(x1_out, scale_factor=2., mode='bilinear', align_corners=False)
+
+ if not self.inter_feat:
+ return {'feats_c': x3, 'feats_f': x0_out}
+ else:
+ return {'feats_c': x3, 'feats_f': x0_out, 'feats_x2': x2, 'feats_x1': x1}
+
+
+class ResNetFPN_8_1_align(nn.Module):
+ """
+    ResNet+FPN, output resolutions are 1/8 and 1.
+ Each block has 2 layers.
+ """
+
+ def __init__(self, config):
+ super().__init__()
+ # Config
+ block = BasicBlock
+ initial_dim = config['initial_dim']
+ block_dims = config['block_dims']
+
+ # Class Variable
+ self.in_planes = initial_dim
+ self.skip_fine_feature = config['coarse_feat_only']
+ self.inter_feat = config['inter_feat']
+ # Networks
+ self.conv1 = nn.Conv2d(1, initial_dim, kernel_size=7, stride=2, padding=3, bias=False)
+ self.bn1 = nn.BatchNorm2d(initial_dim)
+ self.relu = nn.ReLU(inplace=True)
+
+ self.layer1 = self._make_layer(block, block_dims[0], stride=1) # 1/2
+ self.layer2 = self._make_layer(block, block_dims[1], stride=2) # 1/4
+ self.layer3 = self._make_layer(block, block_dims[2], stride=2) # 1/8
+
+ # 3. FPN upsample
+ if not self.skip_fine_feature:
+ self.layer3_outconv = conv1x1(block_dims[2], block_dims[2])
+ self.layer2_outconv = conv1x1(block_dims[1], block_dims[2])
+ self.layer2_outconv2 = nn.Sequential(
+ conv3x3(block_dims[2], block_dims[2]),
+ nn.BatchNorm2d(block_dims[2]),
+ nn.LeakyReLU(),
+ conv3x3(block_dims[2], block_dims[1]),
+ )
+ self.layer1_outconv = conv1x1(block_dims[0], block_dims[1])
+ self.layer1_outconv2 = nn.Sequential(
+ conv3x3(block_dims[1], block_dims[1]),
+ nn.BatchNorm2d(block_dims[1]),
+ nn.LeakyReLU(),
+ conv3x3(block_dims[1], block_dims[0]),
+ )
+
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+ elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
+ nn.init.constant_(m.weight, 1)
+ nn.init.constant_(m.bias, 0)
+
+ def _make_layer(self, block, dim, stride=1):
+ layer1 = block(self.in_planes, dim, stride=stride)
+ layer2 = block(dim, dim, stride=1)
+ layers = (layer1, layer2)
+
+ self.in_planes = dim
+ return nn.Sequential(*layers)
+
+ def forward(self, x):
+ # ResNet Backbone
+ x0 = self.relu(self.bn1(self.conv1(x)))
+ x1 = self.layer1(x0) # 1/2
+ x2 = self.layer2(x1) # 1/4
+ x3 = self.layer3(x2) # 1/8
+
+ # FPN
+
+ if self.skip_fine_feature:
+ if self.inter_feat:
+ return {'feats_c': x3, 'feats_f': None, 'feats_x2': x2, 'feats_x1': x1}
+ else:
+ return {'feats_c': x3, 'feats_f': None}
+
+ x3_out = self.layer3_outconv(x3)
+ x3_out_2x = F.interpolate(x3_out, scale_factor=2., mode='bilinear', align_corners=False)
+ x2_out = self.layer2_outconv(x2)
+ x2_out = self.layer2_outconv2(x2_out+x3_out_2x)
+
+ x2_out_2x = F.interpolate(x2_out, scale_factor=2., mode='bilinear', align_corners=False)
+ x1_out = self.layer1_outconv(x1)
+ x1_out = self.layer1_outconv2(x1_out+x2_out_2x)
+
+ x0_out = F.interpolate(x1_out, scale_factor=2., mode='bilinear', align_corners=False)
+
+ if not self.inter_feat:
+ return {'feats_c': x3, 'feats_f': x0_out}
+ else:
+ return {'feats_c': x3, 'feats_f': x0_out, 'feats_x2': x2, 'feats_x1': x1}
+
+ def pro(self, x, profiler):
+ with profiler.profile('ResNet Backbone'):
+ # ResNet Backbone
+ x0 = self.relu(self.bn1(self.conv1(x)))
+ x1 = self.layer1(x0) # 1/2
+ x2 = self.layer2(x1) # 1/4
+ x3 = self.layer3(x2) # 1/8
+
+ with profiler.profile('FPN'):
+ # FPN
+ x3_out = self.layer3_outconv(x3)
+
+            if self.skip_fine_feature:
+                # keep the dict interface consistent with forward()
+                return {'feats_c': x3_out, 'feats_f': None}
+ x3_out_2x = F.interpolate(x3_out, scale_factor=2., mode='bilinear', align_corners=False)
+ x2_out = self.layer2_outconv(x2)
+ x2_out = self.layer2_outconv2(x2_out+x3_out_2x)
+
+ x2_out_2x = F.interpolate(x2_out, scale_factor=2., mode='bilinear', align_corners=False)
+ x1_out = self.layer1_outconv(x1)
+ x1_out = self.layer1_outconv2(x1_out+x2_out_2x)
+
+ with profiler.profile('upsample*1'):
+ x0_out = F.interpolate(x1_out, scale_factor=2., mode='bilinear', align_corners=False)
+
+ return {'feats_c': x3_out, 'feats_f': x0_out}
+
+
+class ResNetFPN_8_2_align(nn.Module):
+ """
+    ResNet+FPN, output resolutions are 1/8 and 1/2.
+ Each block has 2 layers.
+ """
+
+ def __init__(self, config):
+ super().__init__()
+ # Config
+ block = BasicBlock
+ initial_dim = config['initial_dim']
+ block_dims = config['block_dims']
+
+ # Class Variable
+ self.in_planes = initial_dim
+ self.skip_fine_feature = config['coarse_feat_only']
+ self.inter_feat = config['inter_feat']
+ # Networks
+ self.conv1 = nn.Conv2d(1, initial_dim, kernel_size=7, stride=2, padding=3, bias=False)
+ self.bn1 = nn.BatchNorm2d(initial_dim)
+ self.relu = nn.ReLU(inplace=True)
+
+ self.layer1 = self._make_layer(block, block_dims[0], stride=1) # 1/2
+ self.layer2 = self._make_layer(block, block_dims[1], stride=2) # 1/4
+ self.layer3 = self._make_layer(block, block_dims[2], stride=2) # 1/8
+
+ # 3. FPN upsample
+ if not self.skip_fine_feature:
+ self.layer3_outconv = conv1x1(block_dims[2], block_dims[2])
+ self.layer2_outconv = conv1x1(block_dims[1], block_dims[2])
+ self.layer2_outconv2 = nn.Sequential(
+ conv3x3(block_dims[2], block_dims[2]),
+ nn.BatchNorm2d(block_dims[2]),
+ nn.LeakyReLU(),
+ conv3x3(block_dims[2], block_dims[1]),
+ )
+ self.layer1_outconv = conv1x1(block_dims[0], block_dims[1])
+ self.layer1_outconv2 = nn.Sequential(
+ conv3x3(block_dims[1], block_dims[1]),
+ nn.BatchNorm2d(block_dims[1]),
+ nn.LeakyReLU(),
+ conv3x3(block_dims[1], block_dims[0]),
+ )
+
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+ elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
+ nn.init.constant_(m.weight, 1)
+ nn.init.constant_(m.bias, 0)
+
+ def _make_layer(self, block, dim, stride=1):
+ layer1 = block(self.in_planes, dim, stride=stride)
+ layer2 = block(dim, dim, stride=1)
+ layers = (layer1, layer2)
+
+ self.in_planes = dim
+ return nn.Sequential(*layers)
+
+ def forward(self, x):
+ # ResNet Backbone
+ x0 = self.relu(self.bn1(self.conv1(x)))
+ x1 = self.layer1(x0) # 1/2
+ x2 = self.layer2(x1) # 1/4
+ x3 = self.layer3(x2) # 1/8
+
+ if self.skip_fine_feature:
+ if self.inter_feat:
+ return {'feats_c': x3, 'feats_f': None, 'feats_x2': x2, 'feats_x1': x1}
+ else:
+ return {'feats_c': x3, 'feats_f': None}
+
+ # FPN
+ x3_out = self.layer3_outconv(x3)
+
+ x3_out_2x = F.interpolate(x3_out, scale_factor=2., mode='bilinear', align_corners=False)
+ x2_out = self.layer2_outconv(x2)
+ x2_out = self.layer2_outconv2(x2_out+x3_out_2x)
+
+ x2_out_2x = F.interpolate(x2_out, scale_factor=2., mode='bilinear', align_corners=False)
+ x1_out = self.layer1_outconv(x1)
+ x1_out = self.layer1_outconv2(x1_out+x2_out_2x)
+
+ if not self.inter_feat:
+ return {'feats_c': x3, 'feats_f': x1_out}
+ else:
+ return {'feats_c': x3, 'feats_f': x1_out, 'feats_x2': x2, 'feats_x1': x1}
+
+
+class ResNet_8_1_align(nn.Module):
+ """
+    ResNet, output resolutions are 1/8 and 1.
+ Each block has 2 layers.
+ """
+
+ def __init__(self, config):
+ super().__init__()
+ # Config
+ block = BasicBlock
+ initial_dim = config['initial_dim']
+ block_dims = config['block_dims']
+
+ # Class Variable
+ self.in_planes = initial_dim
+
+ # Networks
+ self.conv1 = nn.Conv2d(1, initial_dim, kernel_size=7, stride=2, padding=3, bias=False)
+ self.bn1 = nn.BatchNorm2d(initial_dim)
+ self.relu = nn.ReLU(inplace=True)
+
+ self.layer1 = self._make_layer(block, block_dims[0], stride=1) # 1/2
+ self.layer2 = self._make_layer(block, block_dims[1], stride=2) # 1/4
+ self.layer3 = self._make_layer(block, block_dims[2], stride=2) # 1/8
+
+ # 3. FPN upsample
+ # self.layer3_outconv = conv1x1(block_dims[2], block_dims[2])
+ # self.layer2_outconv = conv1x1(block_dims[1], block_dims[2])
+ # self.layer2_outconv2 = nn.Sequential(
+ # conv3x3(block_dims[2], block_dims[2]),
+ # nn.BatchNorm2d(block_dims[2]),
+ # nn.LeakyReLU(),
+ # conv3x3(block_dims[2], block_dims[1]),
+ # )
+ # self.layer1_outconv = conv1x1(block_dims[0], block_dims[1])
+ # self.layer1_outconv2 = nn.Sequential(
+ # conv3x3(block_dims[1], block_dims[1]),
+ # nn.BatchNorm2d(block_dims[1]),
+ # nn.LeakyReLU(),
+ # conv3x3(block_dims[1], block_dims[0]),
+ # )
+ self.layer0_outconv = conv1x1(block_dims[2], block_dims[0])
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+ elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
+ nn.init.constant_(m.weight, 1)
+ nn.init.constant_(m.bias, 0)
+
+ def _make_layer(self, block, dim, stride=1):
+ layer1 = block(self.in_planes, dim, stride=stride)
+ layer2 = block(dim, dim, stride=1)
+ layers = (layer1, layer2)
+
+ self.in_planes = dim
+ return nn.Sequential(*layers)
+
+ def forward(self, x):
+ # ResNet Backbone
+ x0 = self.relu(self.bn1(self.conv1(x)))
+ x1 = self.layer1(x0) # 1/2
+ x2 = self.layer2(x1) # 1/4
+ x3 = self.layer3(x2) # 1/8
+
+ # FPN
+ # x3_out = self.layer3_outconv(x3)
+
+ # x3_out_2x = F.interpolate(x3_out, scale_factor=2., mode='bilinear', align_corners=False)
+ # x2_out = self.layer2_outconv(x2)
+ # x2_out = self.layer2_outconv2(x2_out+x3_out_2x)
+
+ # x2_out_2x = F.interpolate(x2_out, scale_factor=2., mode='bilinear', align_corners=False)
+ # x1_out = self.layer1_outconv(x1)
+ # x1_out = self.layer1_outconv2(x1_out+x2_out_2x)
+
+ x0_out = F.interpolate(x3, scale_factor=8., mode='bilinear', align_corners=False)
+ x0_out = self.layer0_outconv(x0_out)
+
+ return {'feats_c': x3, 'feats_f': x0_out}
+
+class VGG_8_1_align(nn.Module):
+ """
+    VGG-like backbone, output resolutions are 1/8 and 1.
+ Each block has 2 layers.
+ """
+
+ def __init__(self, config):
+ super().__init__()
+ # Config
+ block = BasicBlock
+ initial_dim = config['initial_dim']
+ block_dims = config['block_dims']
+
+ self.relu = nn.ReLU(inplace=True)
+ self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
+ c1, c2, c3, c4, c5 = 64, 64, 128, 128, 256
+
+ self.conv1a = nn.Conv2d(1, c1, kernel_size=3, stride=1, padding=1)
+ self.conv1b = nn.Conv2d(c1, c1, kernel_size=3, stride=1, padding=1)
+ self.conv2a = nn.Conv2d(c1, c2, kernel_size=3, stride=1, padding=1)
+ self.conv2b = nn.Conv2d(c2, c2, kernel_size=3, stride=1, padding=1)
+ self.conv3a = nn.Conv2d(c2, c3, kernel_size=3, stride=1, padding=1)
+ self.conv3b = nn.Conv2d(c3, c3, kernel_size=3, stride=1, padding=1)
+ self.conv4a = nn.Conv2d(c3, c4, kernel_size=3, stride=1, padding=1)
+ self.conv4b = nn.Conv2d(c4, c4, kernel_size=3, stride=1, padding=1)
+
+ # self.convPa = nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1)
+ # self.convPb = nn.Conv2d(c5, 65, kernel_size=1, stride=1, padding=0)
+
+ self.convDa = nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1)
+ self.convDb = nn.Conv2d(
+ c5, 256,
+ kernel_size=1, stride=1, padding=0)
+ self.layer0_outconv = conv1x1(block_dims[2], block_dims[0])
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+ elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
+ nn.init.constant_(m.weight, 1)
+ nn.init.constant_(m.bias, 0)
+
+ def _make_layer(self, block, dim, stride=1):
+ layer1 = block(self.in_planes, dim, stride=stride)
+ layer2 = block(dim, dim, stride=1)
+ layers = (layer1, layer2)
+
+ self.in_planes = dim
+ return nn.Sequential(*layers)
+
+ def forward(self, x):
+ # Shared Encoder
+ x = self.relu(self.conv1a(x))
+ x = self.relu(self.conv1b(x))
+ x = self.pool(x)
+ x = self.relu(self.conv2a(x))
+ x = self.relu(self.conv2b(x))
+ x = self.pool(x)
+ x = self.relu(self.conv3a(x))
+ x = self.relu(self.conv3b(x))
+ x = self.pool(x)
+ x = self.relu(self.conv4a(x))
+ x = self.relu(self.conv4b(x))
+
+ cDa = self.relu(self.convDa(x))
+ descriptors = self.convDb(cDa)
+ x3_out = nn.functional.normalize(descriptors, p=2, dim=1)
+
+ x0_out = F.interpolate(x3_out, scale_factor=8., mode='bilinear', align_corners=False)
+ x0_out = self.layer0_outconv(x0_out)
+ # ResNet Backbone
+ # x0 = self.relu(self.bn1(self.conv1(x)))
+ # x1 = self.layer1(x0) # 1/2
+ # x2 = self.layer2(x1) # 1/4
+ # x3 = self.layer3(x2) # 1/8
+
+ # # FPN
+ # x3_out = self.layer3_outconv(x3)
+
+ # x3_out_2x = F.interpolate(x3_out, scale_factor=2., mode='bilinear', align_corners=False)
+ # x2_out = self.layer2_outconv(x2)
+ # x2_out = self.layer2_outconv2(x2_out+x3_out_2x)
+
+ # x2_out_2x = F.interpolate(x2_out, scale_factor=2., mode='bilinear', align_corners=False)
+ # x1_out = self.layer1_outconv(x1)
+ # x1_out = self.layer1_outconv2(x1_out+x2_out_2x)
+
+ # x0_out = F.interpolate(x1_out, scale_factor=2., mode='bilinear', align_corners=False)
+
+ return {'feats_c': x3_out, 'feats_f': x0_out}
+
+class RepVGG_8_1_align(nn.Module):
+ """
+    RepVGG backbone, output resolutions are 1/8 and 1.
+ Each block has 2 layers.
+ """
+
+ def __init__(self, config):
+ super().__init__()
+ # Config
+ # block = BasicBlock
+ # initial_dim = config['initial_dim']
+ block_dims = config['block_dims']
+ self.skip_fine_feature = config['coarse_feat_only']
+ self.inter_feat = config['inter_feat']
+ self.leaky = config['leaky']
+
+ # backbone_name='RepVGG-B0'
+ if config.get('repvggmodel') is not None:
+ backbone_name=config['repvggmodel']
+ elif self.leaky:
+ backbone_name='RepVGG-A1_leaky'
+ else:
+ backbone_name='RepVGG-A1'
+ repvgg_fn = get_RepVGG_func_by_name(backbone_name)
+ backbone = repvgg_fn(False)
+ self.layer0, self.layer1, self.layer2, self.layer3 = backbone.stage0, backbone.stage1, backbone.stage2, backbone.stage3 #, backbone.stage4
+ # self.layer0, self.layer1, self.layer2, self.layer3, self.layer4 = backbone.stage0, backbone.stage1, backbone.stage2, backbone.stage3, backbone.stage4
+
+ # 3. FPN upsample
+ if not self.skip_fine_feature:
+ self.layer3_outconv = conv1x1(block_dims[2], block_dims[2])
+ self.layer2_outconv = conv1x1(block_dims[1], block_dims[2])
+ self.layer2_outconv2 = nn.Sequential(
+ conv3x3(block_dims[2], block_dims[2]),
+ nn.BatchNorm2d(block_dims[2]),
+ nn.LeakyReLU(),
+ conv3x3(block_dims[2], block_dims[1]),
+ )
+ self.layer1_outconv = conv1x1(block_dims[0], block_dims[1])
+ self.layer1_outconv2 = nn.Sequential(
+ conv3x3(block_dims[1], block_dims[1]),
+ nn.BatchNorm2d(block_dims[1]),
+ nn.LeakyReLU(),
+ conv3x3(block_dims[1], block_dims[0]),
+ )
+
+ # self.layer0_outconv = conv1x1(192, 48)
+
+ for layer in [self.layer0, self.layer1, self.layer2, self.layer3]:
+ for m in layer.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+ elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
+ nn.init.constant_(m.weight, 1)
+ nn.init.constant_(m.bias, 0)
+ # for layer in [self.layer0, self.layer1, self.layer2, self.layer3, self.layer4]:
+ # for m in layer.modules():
+ # if isinstance(m, nn.Conv2d):
+ # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+ # elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
+ # nn.init.constant_(m.weight, 1)
+ # nn.init.constant_(m.bias, 0)
+
+ def forward(self, x):
+
+ out = self.layer0(x) # 1/2
+ for module in self.layer1:
+ out = module(out) # 1/2
+ x1 = out
+ for module in self.layer2:
+ out = module(out) # 1/4
+ x2 = out
+ for module in self.layer3:
+ out = module(out) # 1/8
+ x3 = out
+ # for module in self.layer4:
+ # out = module(out)
+ # x3 = out
+
+ if self.skip_fine_feature:
+ if self.inter_feat:
+ return {'feats_c': x3, 'feats_f': None, 'feats_x2': x2, 'feats_x1': x1}
+ else:
+ return {'feats_c': x3, 'feats_f': None}
+ x3_out = self.layer3_outconv(x3)
+ x3_out_2x = F.interpolate(x3_out, scale_factor=2., mode='bilinear', align_corners=False)
+ x2_out = self.layer2_outconv(x2)
+ x2_out = self.layer2_outconv2(x2_out+x3_out_2x)
+
+ x2_out_2x = F.interpolate(x2_out, scale_factor=2., mode='bilinear', align_corners=False)
+ x1_out = self.layer1_outconv(x1)
+ x1_out = self.layer1_outconv2(x1_out+x2_out_2x)
+
+ x0_out = F.interpolate(x1_out, scale_factor=2., mode='bilinear', align_corners=False)
+
+ # x_f = F.interpolate(x_c, scale_factor=8., mode='bilinear', align_corners=False)
+ # x_f = self.layer0_outconv(x_f)
+ return {'feats_c': x3_out, 'feats_f': x0_out}
+
+
+class RepVGG_8_2_fix(nn.Module):
+ """
+ RepVGG backbone; output resolutions are 1/8 and 1/2.
+ Each block has 2 layers.
+ """
+
+ def __init__(self, config):
+ super().__init__()
+ # Config
+ # block = BasicBlock
+ # initial_dim = config['initial_dim']
+ block_dims = config['block_dims']
+ self.skip_fine_feature = config['coarse_feat_only']
+ self.inter_feat = config['inter_feat']
+
+ # backbone_name='RepVGG-B0'
+ backbone_name='RepVGG-A1'
+ repvgg_fn = get_RepVGG_func_by_name(backbone_name)
+ backbone = repvgg_fn(False)
+ self.layer0, self.layer1, self.layer2, self.layer3 = backbone.stage0, backbone.stage1, backbone.stage2, backbone.stage3 #, backbone.stage4
+
+ # 3. FPN upsample
+ if not self.skip_fine_feature:
+ self.layer3_outconv = conv1x1(block_dims[2], block_dims[2])
+ self.layer2_outconv = conv1x1(block_dims[1], block_dims[2])
+ self.layer2_outconv2 = nn.Sequential(
+ conv3x3(block_dims[2], block_dims[2]),
+ nn.BatchNorm2d(block_dims[2]),
+ nn.LeakyReLU(),
+ conv3x3(block_dims[2], block_dims[1]),
+ )
+ self.layer1_outconv = conv1x1(block_dims[0], block_dims[1])
+ self.layer1_outconv2 = nn.Sequential(
+ conv3x3(block_dims[1], block_dims[1]),
+ nn.BatchNorm2d(block_dims[1]),
+ nn.LeakyReLU(),
+ conv3x3(block_dims[1], block_dims[0]),
+ )
+
+ # self.layer0_outconv = conv1x1(192, 48)
+
+ for layer in [self.layer0, self.layer1, self.layer2, self.layer3]:
+ for m in layer.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+ elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
+ nn.init.constant_(m.weight, 1)
+ nn.init.constant_(m.bias, 0)
+
+ def forward(self, x):
+
+ x0 = self.layer0(x) # 1/2
+ out = x0
+ for module in self.layer1:
+ out = module(out) # 1/2
+ x1 = out
+ for module in self.layer2:
+ out = module(out) # 1/4
+ x2 = out
+ for module in self.layer3:
+ out = module(out) # 1/8
+ x3 = out
+ # for module in self.layer4:
+ # out = module(out)
+
+ if self.skip_fine_feature:
+ if self.inter_feat:
+ return {'feats_c': x3, 'feats_f': None, 'feats_x2': x2, 'feats_x1': x1}
+ else:
+ return {'feats_c': x3, 'feats_f': None}
+ x3_out = self.layer3_outconv(x3)
+ x3_out_2x = F.interpolate(x3_out, size=((x3_out.size(-2)-1)*2+1, (x3_out.size(-1)-1)*2+1), mode='bilinear', align_corners=True)
+ x2_out = self.layer2_outconv(x2)
+ x2_out = self.layer2_outconv2(x2_out+x3_out_2x)
+
+ x2_out_2x = F.interpolate(x2_out, size=((x2_out.size(-2)-1)*2+1, (x2_out.size(-1)-1)*2+1), mode='bilinear', align_corners=True)
+ x1_out = self.layer1_outconv(x1)
+ x1_out = self.layer1_outconv2(x1_out+x2_out_2x)
+
+ # x0_out = F.interpolate(x1_out, scale_factor=2., mode='bilinear', align_corners=False)
+
+ # x_f = F.interpolate(x_c, scale_factor=8., mode='bilinear', align_corners=False)
+ # x_f = self.layer0_outconv(x_f)
+ return {'feats_c': x3_out, 'feats_f': x1_out}
+
+
+class RepVGGnfpn_8_1_align(nn.Module):
+ """
+ RepVGG backbone; output resolutions are 1/8 and 1.
+ Each block has 2 layers.
+ """
+
+ def __init__(self, config):
+ super().__init__()
+ # Config
+ # block = BasicBlock
+ # initial_dim = config['initial_dim']
+ block_dims = config['block_dims']
+ self.skip_fine_feature = config['coarse_feat_only']
+ self.inter_feat = config['inter_feat']
+
+ # backbone_name='RepVGG-B0'
+ backbone_name='RepVGG-A1'
+ repvgg_fn = get_RepVGG_func_by_name(backbone_name)
+ backbone = repvgg_fn(False)
+ self.layer0, self.layer1, self.layer2, self.layer3 = backbone.stage0, backbone.stage1, backbone.stage2, backbone.stage3 #, backbone.stage4
+
+ # 3. FPN upsample
+ if not self.skip_fine_feature:
+ self.layer0_outconv = conv1x1(block_dims[2], block_dims[0])
+ # self.layer0_outconv = conv1x1(192, 48)
+
+ for layer in [self.layer0, self.layer1, self.layer2, self.layer3]:
+ for m in layer.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+ elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
+ nn.init.constant_(m.weight, 1)
+ nn.init.constant_(m.bias, 0)
+
+ def forward(self, x):
+
+ x0 = self.layer0(x) # 1/2
+ out = x0
+ for module in self.layer1:
+ out = module(out) # 1/2
+ x1 = out
+ for module in self.layer2:
+ out = module(out) # 1/4
+ x2 = out
+ for module in self.layer3:
+ out = module(out) # 1/8
+ x3 = out
+ # for module in self.layer4:
+ # out = module(out)
+
+ if self.skip_fine_feature:
+ if self.inter_feat:
+ return {'feats_c': x3, 'feats_f': None, 'feats_x2': x2, 'feats_x1': x1}
+ else:
+ return {'feats_c': x3, 'feats_f': None}
+ # x3_out = self.layer3_outconv(x3)
+ # x3_out_2x = F.interpolate(x3_out, scale_factor=2., mode='bilinear', align_corners=False)
+ # x2_out = self.layer2_outconv(x2)
+ # x2_out = self.layer2_outconv2(x2_out+x3_out_2x)
+
+ # x2_out_2x = F.interpolate(x2_out, scale_factor=2., mode='bilinear', align_corners=False)
+ # x1_out = self.layer1_outconv(x1)
+ # x1_out = self.layer1_outconv2(x1_out+x2_out_2x)
+
+ # x0_out = F.interpolate(x1_out, scale_factor=2., mode='bilinear', align_corners=False)
+
+ x_f = F.interpolate(x3, scale_factor=8., mode='bilinear', align_corners=False)
+ x_f = self.layer0_outconv(x_f)
+ # x_f2 = F.interpolate(x3, scale_factor=8., mode='bilinear', align_corners=False)
+ # x_f2 = self.layer0_outconv(x_f2)
+ return {'feats_c': x3, 'feats_f': x_f}
+
+
+class s2dnet_8_1_align(nn.Module):
+ """
+ S2DNet (VGG-16 based) backbone; coarse output resolution is 1/8.
+ """
+
+ def __init__(self, config):
+ super().__init__()
+ # Config
+ block = BasicBlock
+ initial_dim = config['initial_dim']
+ block_dims = config['block_dims']
+
+ # Class Variable
+ self.in_planes = initial_dim
+ self.skip_fine_feature = config['coarse_feat_only']
+ self.inter_feat = config['inter_feat']
+ # Networks
+ self.backbone = S2DNet(checkpoint_path = '/cephfs-mvs/3dv-research/hexingyi/code_yf/loftrdev/weights/s2dnet/s2dnet_weights.pth')
+ # 3. FPN upsample
+ # self.layer3_outconv = conv1x1(block_dims[2], block_dims[2])
+ # if not self.skip_fine_feature:
+ # self.layer2_outconv = conv1x1(block_dims[1], block_dims[2])
+ # self.layer2_outconv2 = nn.Sequential(
+ # conv3x3(block_dims[2], block_dims[2]),
+ # nn.BatchNorm2d(block_dims[2]),
+ # nn.LeakyReLU(),
+ # conv3x3(block_dims[2], block_dims[1]),
+ # )
+ # self.layer1_outconv = conv1x1(block_dims[0], block_dims[1])
+ # self.layer1_outconv2 = nn.Sequential(
+ # conv3x3(block_dims[1], block_dims[1]),
+ # nn.BatchNorm2d(block_dims[1]),
+ # nn.LeakyReLU(),
+ # conv3x3(block_dims[1], block_dims[0]),
+ # )
+
+ # for m in self.modules():
+ # if isinstance(m, nn.Conv2d):
+ # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+ # elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
+ # nn.init.constant_(m.weight, 1)
+ # nn.init.constant_(m.bias, 0)
+
+ def forward(self, x):
+ ret = self.backbone(x)
+ ret[2] = F.interpolate(ret[2], scale_factor=2., mode='bilinear', align_corners=False)
+ if self.skip_fine_feature:
+ if self.inter_feat:
+ return {'feats_c': ret[2], 'feats_f': None, 'feats_x2': ret[1], 'feats_x1': ret[0]}
+ else:
+ return {'feats_c': ret[2], 'feats_f': None,}
+
+ def pro(self, x, profiler):
+ pass
\ No newline at end of file
diff --git a/imcui/third_party/MatchAnything/src/loftr/backbone/s2dnet.py b/imcui/third_party/MatchAnything/src/loftr/backbone/s2dnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..2e4c2eb8a61a5193405ea86b7e67c11a19fa94f7
--- /dev/null
+++ b/imcui/third_party/MatchAnything/src/loftr/backbone/s2dnet.py
@@ -0,0 +1,131 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from torchvision import models  # needed below: models.vgg16() builds the VGG encoder
+from typing import List, Dict
+
+# VGG-16 Layer Names and Channels
+vgg16_layers = {
+ "conv1_1": 64,
+ "relu1_1": 64,
+ "conv1_2": 64,
+ "relu1_2": 64,
+ "pool1": 64,
+ "conv2_1": 128,
+ "relu2_1": 128,
+ "conv2_2": 128,
+ "relu2_2": 128,
+ "pool2": 128,
+ "conv3_1": 256,
+ "relu3_1": 256,
+ "conv3_2": 256,
+ "relu3_2": 256,
+ "conv3_3": 256,
+ "relu3_3": 256,
+ "pool3": 256,
+ "conv4_1": 512,
+ "relu4_1": 512,
+ "conv4_2": 512,
+ "relu4_2": 512,
+ "conv4_3": 512,
+ "relu4_3": 512,
+ "pool4": 512,
+ "conv5_1": 512,
+ "relu5_1": 512,
+ "conv5_2": 512,
+ "relu5_2": 512,
+ "conv5_3": 512,
+ "relu5_3": 512,
+ "pool5": 512,
+}
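+# Added comment: the insertion order of this dict is assumed to mirror the module order of
+# torchvision's vgg16().features, so enumerating its keys yields each layer's index in the
+# sequential encoder (e.g. "conv1_1" -> 0, "relu1_1" -> 1, ...). S2DNet below relies on this
+# mapping to know where to tap the hypercolumn features.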
+
+class AdapLayers(nn.Module):
+ """Small adaptation layers.
+ """
+
+ def __init__(self, hypercolumn_layers: List[str], output_dim: int = 128):
+ """Initialize one adaptation layer for every extraction point.
+
+ Args:
+ hypercolumn_layers: The list of the hypercolumn layer names.
+ output_dim: The output channel dimension.
+ """
+ super(AdapLayers, self).__init__()
+ self.layers = []
+ channel_sizes = [vgg16_layers[name] for name in hypercolumn_layers]
+ for i, l in enumerate(channel_sizes):
+ layer = nn.Sequential(
+ nn.Conv2d(l, 64, kernel_size=1, stride=1, padding=0),
+ nn.ReLU(),
+ nn.Conv2d(64, output_dim, kernel_size=5, stride=1, padding=2),
+ nn.BatchNorm2d(output_dim),
+ )
+ self.layers.append(layer)
+ self.add_module("adap_layer_{}".format(i), layer)
+
+ def forward(self, features: List[torch.Tensor]):
+ """Apply adaptation layers.
+ """
+ for i, _ in enumerate(features):
+ features[i] = getattr(self, "adap_layer_{}".format(i))(features[i])
+ return features
+
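+# Illustrative usage sketch for AdapLayers (added, not in the original code); sizes are examples:
+#   adap = AdapLayers(["conv1_2", "conv3_3", "conv5_3"], output_dim=128)
+#   maps = adap([torch.rand(1, 64, 128, 128),   # channels must match vgg16_layers[name]
+#                torch.rand(1, 256, 32, 32),
+#                torch.rand(1, 512, 8, 8)])
+#   # every output map now has 128 channels at its input's spatial resolution
+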
+class S2DNet(nn.Module):
+ """The S2DNet model
+ """
+
+ def __init__(
+ self,
+ # hypercolumn_layers: List[str] = ["conv2_2", "conv3_3", "relu4_3"],
+ hypercolumn_layers: List[str] = ["conv1_2", "conv3_3", "conv5_3"],
+ checkpoint_path: str = None,
+ ):
+ """Initialize S2DNet.
+
+ Args:
+ device: The torch device to put the model on
+ hypercolumn_layers: Names of the layers to extract features from
+ checkpoint_path: Path to the pre-trained model.
+ """
+ super(S2DNet, self).__init__()
+ self._checkpoint_path = checkpoint_path
+ self.layer_to_index = dict((k, v) for v, k in enumerate(vgg16_layers.keys()))
+ self._hypercolumn_layers = hypercolumn_layers
+
+ # Initialize architecture
+ vgg16 = models.vgg16(pretrained=False)
+ # layers = list(vgg16.features.children())[:-2]
+ layers = list(vgg16.features.children())[:-1]
+ # layers = list(vgg16.features.children())[:23] # relu4_3
+ self.encoder = nn.Sequential(*layers)
+ self.adaptation_layers = AdapLayers(self._hypercolumn_layers) # .to(self._device)
+ self.eval()
+
+ # Restore params from checkpoint
+ if checkpoint_path:
+ print(">> Loading weights from {}".format(checkpoint_path))
+ self._checkpoint = torch.load(checkpoint_path)
+ self._hypercolumn_layers = self._checkpoint["hypercolumn_layers"]
+ self.load_state_dict(self._checkpoint["state_dict"])
+
+ def forward(self, image_tensor: torch.FloatTensor):
+ """Compute intermediate feature maps at the provided extraction levels.
+
+ Args:
+ image_tensor: The [N x 1 x H x W] input image tensor (repeated to 3 channels internally).
+ Returns:
+ feature_maps: The list of output feature maps.
+ """
+ feature_maps, j = [], 0
+ feature_map = image_tensor.repeat(1,3,1,1)
+ layer_list = list(self.encoder.modules())[0]
+ for i, layer in enumerate(layer_list):
+ feature_map = layer(feature_map)
+ if j < len(self._hypercolumn_layers):
+ next_extraction_index = self.layer_to_index[self._hypercolumn_layers[j]]
+ if i == next_extraction_index:
+ feature_maps.append(feature_map)
+ j += 1
+ feature_maps = self.adaptation_layers(feature_maps)
+ return feature_maps
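+
+# Illustrative usage sketch (added, not part of the original file); the checkpoint path and
+# input size are placeholders:
+#   net = S2DNet(checkpoint_path=None)
+#   feats = net(torch.rand(1, 1, 256, 256))  # grayscale input, repeated to 3 channels inside
+#   # feats is a list with one 128-channel feature map per hypercolumn layer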
diff --git a/imcui/third_party/MatchAnything/src/loftr/loftr.py b/imcui/third_party/MatchAnything/src/loftr/loftr.py
new file mode 100644
index 0000000000000000000000000000000000000000..53b7d4a86ec175b483ead096ff9db1ae5802fa63
--- /dev/null
+++ b/imcui/third_party/MatchAnything/src/loftr/loftr.py
@@ -0,0 +1,273 @@
+import torch
+import torch.nn as nn
+from einops.einops import rearrange
+
+from .backbone import build_backbone
+# from third_party.matchformer.model.backbone import build_backbone as build_backbone_matchformer
+from .utils.position_encoding import PositionEncodingSine
+from .loftr_module import LocalFeatureTransformer, FinePreprocess
+from .utils.coarse_matching import CoarseMatching
+from .utils.fine_matching import FineMatching
+
+from loguru import logger
+
+class LoFTR(nn.Module):
+ def __init__(self, config, profiler=None):
+ super().__init__()
+ # Misc
+ self.config = config
+ self.profiler = profiler
+
+ # Modules
+ self.backbone = build_backbone(config)
+ if not (self.config['coarse']['skip'] or self.config['coarse']['rope'] or self.config['coarse']['pan'] or self.config['coarse']['token_mixer'] is not None):
+ self.pos_encoding = PositionEncodingSine(
+ config['coarse']['d_model'],
+ temp_bug_fix=config['coarse']['temp_bug_fix'],
+ npe=config['coarse']['npe'],
+ )
+ if self.config['coarse']['abspe']:
+ self.pos_encoding = PositionEncodingSine(
+ config['coarse']['d_model'],
+ temp_bug_fix=config['coarse']['temp_bug_fix'],
+ npe=config['coarse']['npe'],
+ )
+
+ if self.config['coarse']['skip'] is False:
+ self.loftr_coarse = LocalFeatureTransformer(config)
+ self.coarse_matching = CoarseMatching(config['match_coarse'])
+ # self.fine_preprocess = FinePreprocess(config).float()
+ self.fine_preprocess = FinePreprocess(config)
+ if self.config['fine']['skip'] is False:
+ self.loftr_fine = LocalFeatureTransformer(config["fine"])
+ self.fine_matching = FineMatching(config)
+
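+ # Minimal input sketch (added comment, not in the original code); the forward pass below
+ # fills this dict in place with the matching results:
+ #   data = {
+ #       'image0': torch.rand(1, 1, 480, 640),  # example size
+ #       'image1': torch.rand(1, 1, 480, 640),
+ #       # optional 'mask0' / 'mask1' of shape (N, H, W), with 0 marking padded positions
+ #   }
+ #   data = LoFTR(config)(data)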
+ def forward(self, data):
+ """
+ Update:
+ data (dict): {
+ 'image0': (torch.Tensor): (N, 1, H, W)
+ 'image1': (torch.Tensor): (N, 1, H, W)
+ 'mask0'(optional) : (torch.Tensor): (N, H, W) '0' indicates a padded position
+ 'mask1'(optional) : (torch.Tensor): (N, H, W)
+ }
+ """
+ # 1. Local Feature CNN
+ data.update({
+ 'bs': data['image0'].size(0),
+ 'hw0_i': data['image0'].shape[2:], 'hw1_i': data['image1'].shape[2:]
+ })
+
+ if data['hw0_i'] == data['hw1_i']: # faster & better BN convergence
+ # feats_c, feats_f = self.backbone(torch.cat([data['image0'], data['image1']], dim=0))
+ ret_dict = self.backbone(torch.cat([data['image0'], data['image1']], dim=0))
+ feats_c, feats_f = ret_dict['feats_c'], ret_dict['feats_f']
+ if self.config['inter_feat']:
+ data.update({
+ 'feats_x2': ret_dict['feats_x2'],
+ 'feats_x1': ret_dict['feats_x1'],
+ })
+ if self.config['coarse_feat_only']:
+ (feat_c0, feat_c1) = feats_c.split(data['bs'])
+ feat_f0, feat_f1 = None, None
+ else:
+ (feat_c0, feat_c1), (feat_f0, feat_f1) = feats_c.split(data['bs']), feats_f.split(data['bs'])
+ else: # handle different input shapes
+ # (feat_c0, feat_f0), (feat_c1, feat_f1) = self.backbone(data['image0']), self.backbone(data['image1'])
+ ret_dict0, ret_dict1 = self.backbone(data['image0']), self.backbone(data['image1'])
+ feat_c0, feat_f0 = ret_dict0['feats_c'], ret_dict0['feats_f']
+ feat_c1, feat_f1 = ret_dict1['feats_c'], ret_dict1['feats_f']
+ if self.config['inter_feat']:
+ data.update({
+ 'feats_x2_0': ret_dict0['feats_x2'],
+ 'feats_x1_0': ret_dict0['feats_x1'],
+ 'feats_x2_1': ret_dict1['feats_x2'],
+ 'feats_x1_1': ret_dict1['feats_x1'],
+ })
+ if self.config['coarse_feat_only']:
+ feat_f0, feat_f1 = None, None
+
+
+ mul = self.config['resolution'][0] // self.config['resolution'][1]
+ # mul = 4
+ if self.config['fix_bias']:
+ data.update({
+ 'hw0_c': feat_c0.shape[2:], 'hw1_c': feat_c1.shape[2:],
+ 'hw0_f': feat_f0.shape[2:] if feat_f0 is not None else [(feat_c0.shape[2]-1) * mul+1, (feat_c0.shape[3]-1) * mul+1] ,
+ 'hw1_f': feat_f1.shape[2:] if feat_f1 is not None else [(feat_c1.shape[2]-1) * mul+1, (feat_c1.shape[3]-1) * mul+1]
+ })
+ else:
+ data.update({
+ 'hw0_c': feat_c0.shape[2:], 'hw1_c': feat_c1.shape[2:],
+ 'hw0_f': feat_f0.shape[2:] if feat_f0 is not None else [feat_c0.shape[2] * mul, feat_c0.shape[3] * mul] ,
+ 'hw1_f': feat_f1.shape[2:] if feat_f1 is not None else [feat_c1.shape[2] * mul, feat_c1.shape[3] * mul]
+ })
+
+ # 2. coarse-level loftr module
+ # add featmap with positional encoding, then flatten it to sequence [N, HW, C]
+ if self.config['coarse']['skip']:
+ mask_c0 = mask_c1 = None # mask is useful in training
+ if 'mask0' in data:
+ mask_c0, mask_c1 = data['mask0'], data['mask1']
+ feat_c0 = rearrange(feat_c0, 'n c h w -> n (h w) c')
+ feat_c1 = rearrange(feat_c1, 'n c h w -> n (h w) c')
+
+ elif self.config['coarse']['pan']:
+ # assert feat_c0.shape[0] == 1, 'batch size must be 1 when using mask Xformer now'
+ if self.config['coarse']['abspe']:
+ feat_c0 = self.pos_encoding(feat_c0)
+ feat_c1 = self.pos_encoding(feat_c1)
+
+ mask_c0 = mask_c1 = None # mask is useful in training
+ if 'mask0' in data:
+ mask_c0, mask_c1 = data['mask0'], data['mask1']
+ if self.config['matchability']: # else match in loftr_coarse
+ feat_c0, feat_c1 = self.loftr_coarse(feat_c0, feat_c1, mask_c0, mask_c1, data=data)
+ else:
+ feat_c0, feat_c1 = self.loftr_coarse(feat_c0, feat_c1, mask_c0, mask_c1)
+
+ feat_c0 = rearrange(feat_c0, 'n c h w -> n (h w) c')
+ feat_c1 = rearrange(feat_c1, 'n c h w -> n (h w) c')
+ else:
+ if not (self.config['coarse']['rope'] or self.config['coarse']['token_mixer'] is not None):
+ feat_c0 = rearrange(self.pos_encoding(feat_c0), 'n c h w -> n (h w) c')
+ feat_c1 = rearrange(self.pos_encoding(feat_c1), 'n c h w -> n (h w) c')
+
+ mask_c0 = mask_c1 = None # mask is useful in training
+ if self.config['coarse']['rope']:
+ if 'mask0' in data:
+ mask_c0, mask_c1 = data['mask0'], data['mask1']
+ else:
+ if 'mask0' in data:
+ mask_c0, mask_c1 = data['mask0'].flatten(-2), data['mask1'].flatten(-2)
+ feat_c0, feat_c1 = self.loftr_coarse(feat_c0, feat_c1, mask_c0, mask_c1)
+ if self.config['coarse']['rope']:
+ feat_c0 = rearrange(feat_c0, 'n c h w -> n (h w) c')
+ feat_c1 = rearrange(feat_c1, 'n c h w -> n (h w) c')
+
+ # detect nan
+ if self.config['replace_nan'] and (torch.any(torch.isnan(feat_c0)) or torch.any(torch.isnan(feat_c1))):
+ logger.info(f'replace nan in coarse attention')
+ logger.info(f"feat_c0_nan_num: {torch.isnan(feat_c0).int().sum()}, feat_c1_nan_num: {torch.isnan(feat_c1).int().sum()}")
+ logger.info(f"feat_c0: {feat_c0}, feat_c1: {feat_c1}")
+ logger.info(f"feat_c0_max: {feat_c0.abs().max()}, feat_c1_max: {feat_c1.abs().max()}")
+ feat_c0[torch.isnan(feat_c0)] = 0
+ feat_c1[torch.isnan(feat_c1)] = 0
+ logger.info(f"feat_c0_nanmax: {feat_c0.abs().max()}, feat_c1_nanmax: {feat_c1.abs().max()}")
+
+ # 3. match coarse-level
+ if not self.config['matchability']: # else match in loftr_coarse
+ self.coarse_matching(feat_c0, feat_c1, data,
+ mask_c0=mask_c0.view(mask_c0.size(0), -1) if mask_c0 is not None else mask_c0,
+ mask_c1=mask_c1.view(mask_c1.size(0), -1) if mask_c1 is not None else mask_c1
+ )
+
+ #return data['conf_matrix'],feat_c0,feat_c1,data['feats_x2'],data['feats_x1']
+
+ # norm FPNfeat
+ if self.config['norm_fpnfeat']:
+ feat_c0, feat_c1 = map(lambda feat: feat / feat.shape[-1]**.5,
+ [feat_c0, feat_c1])
+ if self.config['norm_fpnfeat2']:
+ assert self.config['inter_feat']
+ logger.info(f'before norm_fpnfeat2 max of feat_c0, feat_c1:{feat_c0.abs().max()}, {feat_c1.abs().max()}')
+ if data['hw0_i'] == data['hw1_i']:
+ logger.info(f'before norm_fpnfeat2 max of data[feats_x2], data[feats_x1]:{data["feats_x2"].abs().max()}, {data["feats_x1"].abs().max()}')
+ feat_c0, feat_c1, data['feats_x2'], data['feats_x1'] = map(lambda feat: feat / feat.shape[-1]**.5,
+ [feat_c0, feat_c1, data['feats_x2'], data['feats_x1']])
+ else:
+ feat_c0, feat_c1, data['feats_x2_0'], data['feats_x2_1'], data['feats_x1_0'], data['feats_x1_1'] = map(lambda feat: feat / feat.shape[-1]**.5,
+ [feat_c0, feat_c1, data['feats_x2_0'], data['feats_x2_1'], data['feats_x1_0'], data['feats_x1_1']])
+
+
+ # 4. fine-level refinement
+ with torch.autocast(enabled=False, device_type="cuda"):
+ feat_f0_unfold, feat_f1_unfold = self.fine_preprocess(feat_f0, feat_f1, feat_c0, feat_c1, data)
+
+ # detect nan
+ if self.config['replace_nan'] and (torch.any(torch.isnan(feat_f0_unfold)) or torch.any(torch.isnan(feat_f1_unfold))):
+ logger.info(f'replace nan in fine_preprocess')
+ logger.info(f"feat_f0_unfold_nan_num: {torch.isnan(feat_f0_unfold).int().sum()}, feat_f1_unfold_nan_num: {torch.isnan(feat_f1_unfold).int().sum()}")
+ logger.info(f"feat_f0_unfold: {feat_f0_unfold}, feat_f1_unfold: {feat_f1_unfold}")
+ logger.info(f"feat_f0_unfold_max: {feat_f0_unfold}, feat_f1_unfold_max: {feat_f1_unfold}")
+ feat_f0_unfold[torch.isnan(feat_f0_unfold)] = 0
+ feat_f1_unfold[torch.isnan(feat_f1_unfold)] = 0
+ logger.info(f"feat_f0_unfold_nanmax: {feat_f0_unfold}, feat_f1_unfold_nanmax: {feat_f1_unfold}")
+
+ if self.config['fp16log'] and feat_c0 is not None:
+ logger.info(f"c0: {feat_c0.abs().max()}, c1: {feat_c1.abs().max()}")
+ del feat_c0, feat_c1, mask_c0, mask_c1
+ if feat_f0_unfold.size(0) != 0: # at least one coarse level predicted
+ if self.config['fine']['pan']:
+ m, ww, c = feat_f0_unfold.size() # [m, ww, c]
+ w = self.config['fine_window_size']
+ feat_f0_unfold, feat_f1_unfold = self.loftr_fine(feat_f0_unfold.reshape(m, c, w, w), feat_f1_unfold.reshape(m, c, w, w))
+ feat_f0_unfold = rearrange(feat_f0_unfold, 'm c w h -> m (w h) c')
+ feat_f1_unfold = rearrange(feat_f1_unfold, 'm c w h -> m (w h) c')
+ elif self.config['fine']['skip']:
+ pass
+ else:
+ feat_f0_unfold, feat_f1_unfold = self.loftr_fine(feat_f0_unfold, feat_f1_unfold)
+ # 5. match fine-level
+ # log forward nan
+ if self.config['fp16log']:
+ if feat_f0_unfold.size(0) != 0 and feat_f0 is not None:
+ logger.info(f"f0: {feat_f0.abs().max()}, f1: {feat_f1.abs().max()}, uf0: {feat_f0_unfold.abs().max()}, uf1: {feat_f1_unfold.abs().max()}")
+ elif feat_f0_unfold.size(0) != 0:
+ logger.info(f"uf0: {feat_f0_unfold.abs().max()}, uf1: {feat_f1_unfold.abs().max()}")
+ # elif feat_c0 is not None:
+ # logger.info(f"c0: {feat_c0.abs().max()}, c1: {feat_c1.abs().max()}")
+
+ with torch.autocast(enabled=False, device_type="cuda"):
+ self.fine_matching(feat_f0_unfold, feat_f1_unfold, data)
+
+ return data
+
+ def load_state_dict(self, state_dict, *args, **kwargs):
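+ # Added comment: checkpoints saved from a wrapper module prefix every parameter name with
+ # 'matcher.'; strip that prefix once so the weights load directly into this module, e.g.
+ # (hypothetical key) 'matcher.backbone.conv.weight' -> 'backbone.conv.weight'.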
+ for k in list(state_dict.keys()):
+ if k.startswith('matcher.'):
+ state_dict[k.replace('matcher.', '', 1)] = state_dict.pop(k)
+ return super().load_state_dict(state_dict, *args, **kwargs)
+
+ def refine(self, data):
+ """
+ Update:
+ data (dict): {
+ 'image0': (torch.Tensor): (N, 1, H, W)
+ 'image1': (torch.Tensor): (N, 1, H, W)
+ 'mask0'(optional) : (torch.Tensor): (N, H, W) '0' indicates a padded position
+ 'mask1'(optional) : (torch.Tensor): (N, H, W)
+ }
+ """
+ # 1. Local Feature CNN
+ data.update({
+ 'bs': data['image0'].size(0),
+ 'hw0_i': data['image0'].shape[2:], 'hw1_i': data['image1'].shape[2:]
+ })
+ feat_f0, feat_f1 = None, None
+ feat_c0, feat_c1 = data['feat_c0'], data['feat_c1']
+ # 4. fine-level refinement
+ feat_f0_unfold, feat_f1_unfold = self.fine_preprocess(feat_f0, feat_f1, feat_c0, feat_c1, data)
+ if feat_f0_unfold.size(0) != 0: # at least one coarse level predicted
+ if self.config['fine']['pan']:
+ m, ww, c = feat_f0_unfold.size() # [m, ww, c]
+ w = self.config['fine_window_size']
+ feat_f0_unfold, feat_f1_unfold = self.loftr_fine(feat_f0_unfold.reshape(m, c, w, w), feat_f1_unfold.reshape(m, c, w, w))
+ feat_f0_unfold = rearrange(feat_f0_unfold, 'm c w h -> m (w h) c')
+ feat_f1_unfold = rearrange(feat_f1_unfold, 'm c w h -> m (w h) c')
+ elif self.config['fine']['skip']:
+ pass
+ else:
+ feat_f0_unfold, feat_f1_unfold = self.loftr_fine(feat_f0_unfold, feat_f1_unfold)
+ # 5. match fine-level
+ # log forward nan
+ if self.config['fp16log']:
+ if feat_f0_unfold.size(0) != 0 and feat_f0 is not None and feat_c0 is not None:
+ logger.info(f"c0: {feat_c0.abs().max()}, c1: {feat_c1.abs().max()}, f0: {feat_f0.abs().max()}, f1: {feat_f1.abs().max()}, uf0: {feat_f0_unfold.abs().max()}, uf1: {feat_f1_unfold.abs().max()}")
+ elif feat_f0 is not None and feat_c0 is not None:
+ logger.info(f"c0: {feat_c0.abs().max()}, c1: {feat_c1.abs().max()}, f0: {feat_f0.abs().max()}, f1: {feat_f1.abs().max()}")
+ elif feat_c0 is not None:
+ logger.info(f"c0: {feat_c0.abs().max()}, c1: {feat_c1.abs().max()}")
+
+ self.fine_matching(feat_f0_unfold, feat_f1_unfold, data)
+ return data
\ No newline at end of file
diff --git a/imcui/third_party/MatchAnything/src/loftr/loftr_module/__init__.py b/imcui/third_party/MatchAnything/src/loftr/loftr_module/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ca51db4f50a0c4f3dcd795e74b83e633ab2e990a
--- /dev/null
+++ b/imcui/third_party/MatchAnything/src/loftr/loftr_module/__init__.py
@@ -0,0 +1,2 @@
+from .transformer import LocalFeatureTransformer
+from .fine_preprocess import FinePreprocess
diff --git a/imcui/third_party/MatchAnything/src/loftr/loftr_module/fine_preprocess.py b/imcui/third_party/MatchAnything/src/loftr/loftr_module/fine_preprocess.py
new file mode 100644
index 0000000000000000000000000000000000000000..019c0199e7be5c7fe65669420a98003c51c8bed2
--- /dev/null
+++ b/imcui/third_party/MatchAnything/src/loftr/loftr_module/fine_preprocess.py
@@ -0,0 +1,350 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops.einops import rearrange, repeat
+from ..backbone.repvgg import RepVGGBlock
+
+from loguru import logger
+
+def conv1x1(in_planes, out_planes, stride=1):
+ """1x1 convolution without padding"""
+ return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, padding=0, bias=False)
+
+
+def conv3x3(in_planes, out_planes, stride=1):
+ """3x3 convolution with padding"""
+ return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
+
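+# Added explanatory sketch (not in the original code): FinePreprocess below uses F.unfold to
+# crop a WxW window around every coarse-grid cell; a rough example with hypothetical sizes:
+#   feat = torch.rand(1, 64, 64, 64)                       # [N, C, H, W] fine-level feature map
+#   win = F.unfold(feat, kernel_size=(8, 8), stride=8)     # -> [1, 64*8*8, 64] = [N, C*W*W, L]
+#   win = rearrange(win, 'n (c ww) l -> n l ww c', ww=64)  # -> [1, 64, 64, 64] = [N, L, W*W, C]
+#   # windows are then indexed by data['b_ids'] / data['i_ids'] to keep only predicted matches
+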
+class FinePreprocess(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+
+ self.config = config
+ self.cat_c_feat = config['fine_concat_coarse_feat']
+ self.sample_c_feat = config['fine_sample_coarse_feat']
+ self.fpn_inter_feat = config['inter_feat']
+ self.rep_fpn = config['rep_fpn']
+ self.deploy = config['rep_deploy']
+ self.multi_regress = config['match_fine']['multi_regress']
+ self.local_regress = config['match_fine']['local_regress']
+ self.local_regress_inner = config['match_fine']['local_regress_inner']
+ block_dims = config['resnetfpn']['block_dims']
+
+ self.mtd_spvs = self.config['fine']['mtd_spvs']
+ self.align_corner = self.config['align_corner']
+ self.fix_bias = self.config['fix_bias']
+
+ if self.mtd_spvs:
+ self.W = self.config['fine_window_size']
+ else:
+ # assert False, 'fine_window_matching_size to be revised' # good notification!
+ # self.W = self.config['fine_window_matching_size']
+ self.W = self.config['fine_window_size']
+
+ self.backbone_type = self.config['backbone_type']
+
+ d_model_c = self.config['coarse']['d_model']
+ d_model_f = self.config['fine']['d_model']
+ self.d_model_f = d_model_f
+ if self.fpn_inter_feat:
+ if self.rep_fpn:
+ self.layer3_outconv = conv1x1(block_dims[2], block_dims[2])
+ self.layer2_outconv = conv1x1(block_dims[1], block_dims[2])
+ self.layer2_outconv2 = []
+ self.layer2_outconv2.append(RepVGGBlock(in_channels=block_dims[2], out_channels=block_dims[2], kernel_size=3,
+ stride=1, padding=1, groups=1, deploy=self.deploy, use_se=False, leaky=0.01))
+ self.layer2_outconv2.append(RepVGGBlock(in_channels=block_dims[2], out_channels=block_dims[1], kernel_size=3,
+ stride=1, padding=1, groups=1, deploy=self.deploy, use_se=False, leaky=-2))
+ self.layer2_outconv2 = nn.ModuleList(self.layer2_outconv2)
+ self.layer1_outconv = conv1x1(block_dims[0], block_dims[1])
+ self.layer1_outconv2 = []
+ self.layer1_outconv2.append(RepVGGBlock(in_channels=block_dims[1], out_channels=block_dims[1], kernel_size=3,
+ stride=1, padding=1, groups=1, deploy=self.deploy, use_se=False, leaky=0.01))
+ self.layer1_outconv2.append(RepVGGBlock(in_channels=block_dims[1], out_channels=block_dims[0], kernel_size=3,
+ stride=1, padding=1, groups=1, deploy=self.deploy, use_se=False, leaky=-2))
+ self.layer1_outconv2 = nn.ModuleList(self.layer1_outconv2)
+
+ else:
+ self.layer3_outconv = conv1x1(block_dims[2], block_dims[2])
+ self.layer2_outconv = conv1x1(block_dims[1], block_dims[2])
+ self.layer2_outconv2 = nn.Sequential(
+ conv3x3(block_dims[2], block_dims[2]),
+ nn.BatchNorm2d(block_dims[2]),
+ nn.LeakyReLU(),
+ conv3x3(block_dims[2], block_dims[1]),
+ )
+ self.layer1_outconv = conv1x1(block_dims[0], block_dims[1])
+ self.layer1_outconv2 = nn.Sequential(
+ conv3x3(block_dims[1], block_dims[1]),
+ nn.BatchNorm2d(block_dims[1]),
+ nn.LeakyReLU(),
+ conv3x3(block_dims[1], block_dims[0]),
+ )
+ elif self.cat_c_feat:
+ self.down_proj = nn.Linear(d_model_c, d_model_f, bias=True)
+ self.merge_feat = nn.Linear(2*d_model_f, d_model_f, bias=True)
+ if self.sample_c_feat:
+ self.down_proj = nn.Linear(d_model_c, d_model_f, bias=True)
+
+
+ self._reset_parameters()
+
+ def _reset_parameters(self):
+ for p in self.parameters():
+ if p.dim() > 1:
+ nn.init.kaiming_normal_(p, mode="fan_out", nonlinearity="relu")
+
+ def inter_fpn(self, feat_c, x2, x1, stride):
+ feat_c = self.layer3_outconv(feat_c)
+ feat_c = F.interpolate(feat_c, scale_factor=2., mode='bilinear', align_corners=False)
+ x2 = self.layer2_outconv(x2)
+ if self.rep_fpn:
+ x2 = x2 + feat_c
+ for layer in self.layer2_outconv2:
+ x2 = layer(x2)
+ else:
+ x2 = self.layer2_outconv2(x2+feat_c)
+
+ x2 = F.interpolate(x2, scale_factor=2., mode='bilinear', align_corners=False)
+ x1 = self.layer1_outconv(x1)
+ if self.rep_fpn:
+ x1 = x1 + x2
+ for layer in self.layer1_outconv2:
+ x1 = layer(x1)
+ else:
+ x1 = self.layer1_outconv2(x1+x2)
+
+ if stride == 4:
+ logger.info('stride == 4')
+
+ elif stride == 8:
+ logger.info('stride == 8')
+ x1 = F.interpolate(x1, scale_factor=2., mode='bilinear', align_corners=False)
+ else:
+ logger.info('stride not in {4,8}')
+ assert False
+ return x1
+
+ def forward(self, feat_f0, feat_f1, feat_c0, feat_c1, data):
+ W = self.W
+ if self.fix_bias:
+ stride = 4
+ else:
+ stride = data['hw0_f'][0] // data['hw0_c'][0]
+
+ data.update({'W': W})
+ if data['b_ids'].shape[0] == 0:
+ feat0 = torch.empty(0, self.W**2, self.d_model_f, device=feat_c0.device)
+ feat1 = torch.empty(0, self.W**2, self.d_model_f, device=feat_c0.device)
+ # return feat0, feat1
+ return feat0.float(), feat1.float()
+
+ if self.fpn_inter_feat:
+ if data['hw0_i'] != data['hw1_i']:
+ if self.align_corner is False:
+ assert self.backbone_type != 's2dnet'
+
+ feat_c0 = rearrange(feat_c0, 'b (h w) c -> b c h w', h=data['hw0_c'][0])
+ feat_c1 = rearrange(feat_c1, 'b (h w) c -> b c h w', h=data['hw1_c'][0])
+ x2_0, x1_0 = data['feats_x2_0'], data['feats_x1_0']
+ x2_1, x1_1 = data['feats_x2_1'], data['feats_x1_1']
+ del data['feats_x2_0'], data['feats_x1_0'], data['feats_x2_1'], data['feats_x1_1']
+ feat_f0, feat_f1 = self.inter_fpn(feat_c0, x2_0, x1_0, stride), self.inter_fpn(feat_c1, x2_1, x1_1, stride)
+
+ if self.local_regress_inner:
+ assert W == 8
+ feat_f0 = F.unfold(feat_f0, kernel_size=(W, W), stride=stride, padding=0)
+ feat_f0 = rearrange(feat_f0, 'n (c ww) l -> n l ww c', ww=W**2)
+ feat_f1 = F.unfold(feat_f1, kernel_size=(W+2, W+2), stride=stride, padding=1)
+ feat_f1 = rearrange(feat_f1, 'n (c ww) l -> n l ww c', ww=(W+2)**2)
+ elif W == 10 and self.multi_regress:
+ feat_f0 = F.unfold(feat_f0, kernel_size=(W, W), stride=stride, padding=1)
+ feat_f0 = rearrange(feat_f0, 'n (c ww) l -> n l ww c', ww=W**2)
+ feat_f1 = F.unfold(feat_f1, kernel_size=(W, W), stride=stride, padding=1)
+ feat_f1 = rearrange(feat_f1, 'n (c ww) l -> n l ww c', ww=W**2)
+ elif W == 10:
+ feat_f0 = F.unfold(feat_f0, kernel_size=(W, W), stride=stride, padding=1)
+ feat_f0 = rearrange(feat_f0, 'n (c ww) l -> n l ww c', ww=W**2)
+ feat_f1 = F.unfold(feat_f1, kernel_size=(W, W), stride=stride, padding=1)
+ feat_f1 = rearrange(feat_f1, 'n (c ww) l -> n l ww c', ww=W**2)
+ else:
+ assert not self.multi_regress
+ feat_f0 = F.unfold(feat_f0, kernel_size=(W, W), stride=stride, padding=0)
+ feat_f0 = rearrange(feat_f0, 'n (c ww) l -> n l ww c', ww=W**2)
+ feat_f1 = F.unfold(feat_f1, kernel_size=(W, W), stride=stride, padding=0)
+ feat_f1 = rearrange(feat_f1, 'n (c ww) l -> n l ww c', ww=W**2)
+
+ # 2. select only the predicted matches
+ feat_f0 = feat_f0[data['b_ids'], data['i_ids']] # [n, ww, cf]
+ feat_f1 = feat_f1[data['b_ids'], data['j_ids']]
+
+ return feat_f0, feat_f1
+
+ else:
+ if self.align_corner is False:
+ feat_c = torch.cat([feat_c0, feat_c1], 0)
+ feat_c = rearrange(feat_c, 'b (h w) c -> b c h w', h=data['hw0_c'][0]) # 1/8 256
+ x2 = data['feats_x2'].float() # 1/4 128
+ x1 = data['feats_x1'].float() # 1/2 64
+ del data['feats_x2'], data['feats_x1']
+ assert self.backbone_type != 's2dnet'
+ feat_c = self.layer3_outconv(feat_c)
+ feat_c = F.interpolate(feat_c, scale_factor=2., mode='bilinear', align_corners=False)
+ x2 = self.layer2_outconv(x2)
+ if self.rep_fpn:
+ x2 = x2 + feat_c
+ for layer in self.layer2_outconv2:
+ x2 = layer(x2)
+ else:
+ x2 = self.layer2_outconv2(x2+feat_c)
+
+ x2 = F.interpolate(x2, scale_factor=2., mode='bilinear', align_corners=False)
+ x1 = self.layer1_outconv(x1)
+ if self.rep_fpn:
+ x1 = x1 + x2
+ for layer in self.layer1_outconv2:
+ x1 = layer(x1)
+ else:
+ x1 = self.layer1_outconv2(x1+x2)
+
+ if stride == 4:
+ # logger.info('stride == 4')
+ pass
+ elif stride == 8:
+ # logger.info('stride == 8')
+ x1 = F.interpolate(x1, scale_factor=2., mode='bilinear', align_corners=False)
+ else:
+ # logger.info('stride not in {4,8}')
+ assert False
+
+ feat_f0, feat_f1 = torch.chunk(x1, 2, dim=0)
+
+ # 1. unfold(crop) all local windows
+ if self.local_regress_inner:
+ assert W == 8
+ feat_f0 = F.unfold(feat_f0, kernel_size=(W, W), stride=stride, padding=0)
+ feat_f0 = rearrange(feat_f0, 'n (c ww) l -> n l ww c', ww=W**2)
+ feat_f1 = F.unfold(feat_f1, kernel_size=(W+2, W+2), stride=stride, padding=1)
+ feat_f1 = rearrange(feat_f1, 'n (c ww) l -> n l ww c', ww=(W+2)**2)
+ elif self.multi_regress or (self.local_regress and W == 10):
+ feat_f0 = F.unfold(feat_f0, kernel_size=(W, W), stride=stride, padding=1)
+ feat_f0 = rearrange(feat_f0, 'n (c ww) l -> n l ww c', ww=W**2)
+ feat_f1 = F.unfold(feat_f1, kernel_size=(W, W), stride=stride, padding=1)
+ feat_f1 = rearrange(feat_f1, 'n (c ww) l -> n l ww c', ww=W**2)
+ elif W == 10:
+ feat_f0 = F.unfold(feat_f0, kernel_size=(W, W), stride=stride, padding=1)
+ feat_f0 = rearrange(feat_f0, 'n (c ww) l -> n l ww c', ww=W**2)
+ feat_f1 = F.unfold(feat_f1, kernel_size=(W, W), stride=stride, padding=1)
+ feat_f1 = rearrange(feat_f1, 'n (c ww) l -> n l ww c', ww=W**2)
+
+ else:
+ feat_f0 = F.unfold(feat_f0, kernel_size=(W, W), stride=stride, padding=0)
+ feat_f0 = rearrange(feat_f0, 'n (c ww) l -> n l ww c', ww=W**2)
+ feat_f1 = F.unfold(feat_f1, kernel_size=(W, W), stride=stride, padding=0)
+ feat_f1 = rearrange(feat_f1, 'n (c ww) l -> n l ww c', ww=W**2)
+
+ # 2. select only the predicted matches
+ feat_f0 = feat_f0[data['b_ids'], data['i_ids']] # [n, ww, cf]
+ feat_f1 = feat_f1[data['b_ids'], data['j_ids']]
+
+ return feat_f0, feat_f1
+ elif self.fix_bias:
+ feat_c = torch.cat([feat_c0, feat_c1], 0)
+ feat_c = rearrange(feat_c, 'b (h w) c -> b c h w', h=data['hw0_c'][0])
+ x2 = data['feats_x2'].float()
+ x1 = data['feats_x1'].float()
+ assert self.backbone_type != 's2dnet'
+ x3_out = self.layer3_outconv(feat_c)
+ x3_out_2x = F.interpolate(x3_out, size=((x3_out.size(-2)-1)*2+1, (x3_out.size(-1)-1)*2+1), mode='bilinear', align_corners=False)
+ x2 = self.layer2_outconv(x2)
+ x2 = self.layer2_outconv2(x2+x3_out_2x)
+
+ x2 = F.interpolate(x2, size=((x2.size(-2)-1)*2+1, (x2.size(-1)-1)*2+1), mode='bilinear', align_corners=False)
+ x1_out = self.layer1_outconv(x1)
+ x1_out = self.layer1_outconv2(x1_out+x2)
+ x0_out = x1_out
+
+ feat_f0, feat_f1 = torch.chunk(x0_out, 2, dim=0)
+
+ # 1. unfold(crop) all local windows
+ feat_f0_unfold = F.unfold(feat_f0, kernel_size=(W, W), stride=stride, padding=W//2)
+ feat_f0_unfold = rearrange(feat_f0_unfold, 'n (c ww) l -> n l ww c', ww=W**2)
+ feat_f1_unfold = F.unfold(feat_f1, kernel_size=(W, W), stride=stride, padding=W//2)
+ feat_f1_unfold = rearrange(feat_f1_unfold, 'n (c ww) l -> n l ww c', ww=W**2)
+
+ # 2. select only the predicted matches
+ feat_f0_unfold = feat_f0_unfold[data['b_ids'], data['i_ids']] # [n, ww, cf]
+ feat_f1_unfold = feat_f1_unfold[data['b_ids'], data['j_ids']]
+
+ return feat_f0_unfold, feat_f1_unfold
+
+
+
+ elif self.sample_c_feat:
+ if self.align_corner is False:
+ # easy implemented but memory consuming
+ feat_c = self.down_proj(torch.cat([feat_c0,
+ feat_c1], 0)) # [n, (h w), c] -> [2n, (h w), cf]
+ feat_c = rearrange(feat_c, 'n (h w) c -> n c h w', h=data['hw0_c'][0], w=data['hw0_c'][1])
+ feat_f = F.interpolate(feat_c, scale_factor=8., mode='bilinear', align_corners=False) # [2n, cf, hf, wf]
+ feat_f_unfold = F.unfold(feat_f, kernel_size=(W, W), stride=stride, padding=0)
+ feat_f_unfold = rearrange(feat_f_unfold, 'n (c ww) l -> n l ww c', ww=W**2)
+ feat_f0_unfold, feat_f1_unfold = torch.chunk(feat_f_unfold, 2, dim=0)
+ feat_f0_unfold = feat_f0_unfold[data['b_ids'], data['i_ids']] # [m, ww, cf]
+ feat_f1_unfold = feat_f1_unfold[data['b_ids'], data['j_ids']] # [m, ww, cf]
+ # return feat_f0_unfold, feat_f1_unfold
+ return feat_f0_unfold.float(), feat_f1_unfold.float()
+ else:
+ if self.align_corner is False:
+ # 1. unfold(crop) all local windows
+ assert False, 'maybe exist bugs'
+ feat_f0_unfold = F.unfold(feat_f0, kernel_size=(W, W), stride=stride, padding=0)
+ feat_f0_unfold = rearrange(feat_f0_unfold, 'n (c ww) l -> n l ww c', ww=W**2)
+ feat_f1_unfold = F.unfold(feat_f1, kernel_size=(W, W), stride=stride, padding=0)
+ feat_f1_unfold = rearrange(feat_f1_unfold, 'n (c ww) l -> n l ww c', ww=W**2)
+
+ # 2. select only the predicted matches
+ feat_f0_unfold = feat_f0_unfold[data['b_ids'], data['i_ids']] # [n, ww, cf]
+ feat_f1_unfold = feat_f1_unfold[data['b_ids'], data['j_ids']]
+
+ # option: use coarse-level loftr feature as context: concat and linear
+ if self.cat_c_feat:
+ feat_c_win = self.down_proj(torch.cat([feat_c0[data['b_ids'], data['i_ids']],
+ feat_c1[data['b_ids'], data['j_ids']]], 0)) # [2n, c]
+ feat_cf_win = self.merge_feat(torch.cat([
+ torch.cat([feat_f0_unfold, feat_f1_unfold], 0), # [2n, ww, cf]
+ repeat(feat_c_win, 'n c -> n ww c', ww=W**2), # [2n, ww, cf]
+ ], -1))
+ feat_f0_unfold, feat_f1_unfold = torch.chunk(feat_cf_win, 2, dim=0)
+
+ return feat_f0_unfold, feat_f1_unfold
+
+ else:
+ # 1. unfold(crop) all local windows
+ if self.fix_bias:
+ feat_f0_unfold = F.unfold(feat_f0, kernel_size=(W, W), stride=stride, padding=W//2)
+ feat_f0_unfold = rearrange(feat_f0_unfold, 'n (c ww) l -> n l ww c', ww=W**2)
+ feat_f1_unfold = F.unfold(feat_f1, kernel_size=(W, W), stride=stride, padding=W//2)
+ feat_f1_unfold = rearrange(feat_f1_unfold, 'n (c ww) l -> n l ww c', ww=W**2)
+ else:
+ feat_f0_unfold = F.unfold(feat_f0, kernel_size=(W, W), stride=stride, padding=W//2)
+ feat_f0_unfold = rearrange(feat_f0_unfold, 'n (c ww) l -> n l ww c', ww=W**2)
+ feat_f1_unfold = F.unfold(feat_f1, kernel_size=(W, W), stride=stride, padding=W//2)
+ feat_f1_unfold = rearrange(feat_f1_unfold, 'n (c ww) l -> n l ww c', ww=W**2)
+
+ # 2. select only the predicted matches
+ feat_f0_unfold = feat_f0_unfold[data['b_ids'], data['i_ids']] # [n, ww, cf]
+ feat_f1_unfold = feat_f1_unfold[data['b_ids'], data['j_ids']]
+
+ # option: use coarse-level loftr feature as context: concat and linear
+ if self.cat_c_feat:
+ feat_c_win = self.down_proj(torch.cat([feat_c0[data['b_ids'], data['i_ids']],
+ feat_c1[data['b_ids'], data['j_ids']]], 0)) # [2n, c]
+ feat_cf_win = self.merge_feat(torch.cat([
+ torch.cat([feat_f0_unfold, feat_f1_unfold], 0), # [2n, ww, cf]
+ repeat(feat_c_win, 'n c -> n ww c', ww=W**2), # [2n, ww, cf]
+ ], -1))
+ feat_f0_unfold, feat_f1_unfold = torch.chunk(feat_cf_win, 2, dim=0)
+
+ # return feat_f0_unfold, feat_f1_unfold
+ return feat_f0_unfold.float(), feat_f1_unfold.float()
\ No newline at end of file
diff --git a/imcui/third_party/MatchAnything/src/loftr/loftr_module/linear_attention.py b/imcui/third_party/MatchAnything/src/loftr/loftr_module/linear_attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..1d7f08fcd77195dab126a59d2e832e72bc31012a
--- /dev/null
+++ b/imcui/third_party/MatchAnything/src/loftr/loftr_module/linear_attention.py
@@ -0,0 +1,217 @@
+"""
+Linear Transformer proposed in "Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention"
+Modified from: https://github.com/idiap/fast-transformers/blob/master/fast_transformers/attention/linear_attention.py
+"""
+
+import torch
+from torch.nn import Module, Dropout
+import torch.nn.functional as F
+
+# if hasattr(F, 'scaled_dot_product_attention'):
+# FLASH_AVAILABLE = True
+# else: # v100
+FLASH_AVAILABLE = False
+ # import xformers.ops
+from ..utils.position_encoding import PositionEncodingSine, RoPEPositionEncodingSine
+from einops.einops import rearrange
+from loguru import logger
+
+
+# flash_attn_func_ok = True
+# try:
+# from flash_attn import flash_attn_func
+# except ModuleNotFoundError:
+# flash_attn_func_ok = False
+
+def elu_feature_map(x):
+ return torch.nn.functional.elu(x) + 1
+
+
+class LinearAttention(Module):
+ def __init__(self, eps=1e-6):
+ super().__init__()
+ self.feature_map = elu_feature_map
+ self.eps = eps
+
+ @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
+ def forward(self, queries, keys, values, q_mask=None, kv_mask=None):
+ """ Multi-Head linear attention proposed in "Transformers are RNNs"
+ Args:
+ queries: [N, L, H, D]
+ keys: [N, S, H, D]
+ values: [N, S, H, D]
+ q_mask: [N, L]
+ kv_mask: [N, S]
+ Returns:
+ queried_values: (N, L, H, D)
+ """
+ Q = self.feature_map(queries)
+ K = self.feature_map(keys)
+
+ # set padded position to zero
+ if q_mask is not None:
+ Q = Q * q_mask[:, :, None, None]
+ if kv_mask is not None:
+ K = K * kv_mask[:, :, None, None]
+ values = values * kv_mask[:, :, None, None]
+
+ v_length = values.size(1)
+ values = values / v_length # prevent fp16 overflow
+ KV = torch.einsum("nshd,nshv->nhdv", K, values) # (S,D)' @ S,V
+ Z = 1 / (torch.einsum("nlhd,nhd->nlh", Q, K.sum(dim=1)) + self.eps)
+ # queried_values = torch.einsum("nlhd,nhdv,nlh->nlhv", Q, KV, Z) * v_length
+ queried_values = torch.einsum("nlhd,nhdv,nlh->nlhv", Q, KV, Z) * v_length
+
+ return queried_values.contiguous()
+
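+# Added note (not in the original code): the einsums above implement kernelized attention,
+#   out_l = phi(Q_l) (sum_s phi(K_s) V_s^T) / (phi(Q_l) . sum_s phi(K_s)),
+# with phi(x) = elu(x) + 1 keeping the features positive, so the cost grows linearly with the
+# sequence length instead of quadratically as in softmax attention.
+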
+class RoPELinearAttention(Module):
+ def __init__(self, eps=1e-6):
+ super().__init__()
+ self.feature_map = elu_feature_map
+ self.eps = eps
+ self.RoPE = RoPEPositionEncodingSine(256, max_shape=(256, 256))
+
+ @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
+ def forward(self, queries, keys, values, q_mask=None, kv_mask=None, H=None, W=None):
+ """ Multi-Head linear attention proposed in "Transformers are RNNs"
+ Args:
+ queries: [N, L, H, D]
+ keys: [N, S, H, D]
+ values: [N, S, H, D]
+ q_mask: [N, L]
+ kv_mask: [N, S]
+ Returns:
+ queried_values: (N, L, H, D)
+ """
+ Q = self.feature_map(queries)
+ K = self.feature_map(keys)
+ nhead, d = Q.size(2), Q.size(3)
+ # set padded position to zero
+ if q_mask is not None:
+ Q = Q * q_mask[:, :, None, None]
+ if kv_mask is not None:
+ K = K * kv_mask[:, :, None, None]
+ values = values * kv_mask[:, :, None, None]
+
+ v_length = values.size(1)
+ values = values / v_length # prevent fp16 overflow
+ # Q = Q / Q.size(1)
+ # logger.info(f"Q: {Q.dtype}, K: {K.dtype}, values: {values.dtype}")
+
+ Z = 1 / (torch.einsum("nlhd,nhd->nlh", Q, K.sum(dim=1)) + self.eps)
+ # logger.info(f"Z_max: {Z.abs().max()}")
+ Q = rearrange(Q, 'n (h w) nhead d -> n h w (nhead d)', h=H, w=W)
+ K = rearrange(K, 'n (h w) nhead d -> n h w (nhead d)', h=H, w=W)
+ Q, K = self.RoPE(Q), self.RoPE(K)
+ # logger.info(f"Q_rope: {Q.abs().max()}, K_rope: {K.abs().max()}")
+ Q = rearrange(Q, 'n h w (nhead d) -> n (h w) nhead d', nhead=nhead, d=d)
+ K = rearrange(K, 'n h w (nhead d) -> n (h w) nhead d', nhead=nhead, d=d)
+ KV = torch.einsum("nshd,nshv->nhdv", K, values) # (S,D)' @ S,V
+ del K, values
+ # logger.info(f"KV_max: {KV.abs().max()}")
+ # queried_values = torch.einsum("nlhd,nhdv,nlh->nlhv", Q, KV, Z) * v_length
+ # Q = torch.einsum("nlhd,nlh->nlhd", Q, Z)
+ # logger.info(f"QZ_max: {Q.abs().max()}")
+ # queried_values = torch.einsum("nlhd,nhdv->nlhv", Q, KV) * v_length
+ # logger.info(f"message_max: {queried_values.abs().max()}")
+ queried_values = torch.einsum("nlhd,nhdv,nlh->nlhv", Q, KV, Z) * v_length
+
+ return queried_values.contiguous()
+
+
+class FullAttention(Module):
+ def __init__(self, use_dropout=False, attention_dropout=0.1):
+ super().__init__()
+ self.use_dropout = use_dropout
+ self.dropout = Dropout(attention_dropout)
+
+ # @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
+ def forward(self, queries, keys, values, q_mask=None, kv_mask=None):
+ """ Multi-head scaled dot-product attention, a.k.a full attention.
+ Args:
+ queries: [N, L, H, D]
+ keys: [N, S, H, D]
+ values: [N, S, H, D]
+ q_mask: [N, L]
+ kv_mask: [N, S]
+ Returns:
+ queried_values: (N, L, H, D)
+ """
+ # assert kv_mask is None
+ # mask = torch.zeros(queries.size(0)*queries.size(2), queries.size(1), keys.size(1), device=queries.device)
+ # mask.masked_fill(~(q_mask[:, :, None] * kv_mask[:, None, :]), float('-inf'))
+ # if keys.size(1) % 8 != 0:
+ # mask = torch.cat([mask, torch.zeros(queries.size(0)*queries.size(2), queries.size(1), 8-keys.size(1)%8, device=queries.device)], dim=-1)
+ # out = xformers.ops.memory_efficient_attention(queries, keys, values, attn_bias=mask[...,:keys.size(1)])
+ # return out
+
+ # N = queries.size(0)
+ # list_q = [queries[i, :q_mask[i].sum, ...] for i in N]
+ # list_k = [keys[i, :kv_mask[i].sum, ...] for i in N]
+ # list_v = [values[i, :kv_mask[i].sum, ...] for i in N]
+ # assert N == 1
+ # out = xformers.ops.memory_efficient_attention(queries[:,:q_mask.sum(),...], keys[:,:kv_mask.sum(),...], values[:,:kv_mask.sum(),...])
+ # out = torch.cat([out, torch.zeros(out.size(0), queries.size(1)-q_mask.sum(), queries.size(2), queries.size(3), device=queries.device)], dim=1)
+ # return out
+ # Compute the unnormalized attention and apply the masks
+ QK = torch.einsum("nlhd,nshd->nlsh", queries, keys)
+ if kv_mask is not None:
+ QK.masked_fill_(~(q_mask[:, :, None, None] * kv_mask[:, None, :, None]), -1e5) # float('-inf')
+
+ # Compute the attention and the weighted average
+ softmax_temp = 1. / queries.size(3)**.5 # sqrt(D)
+ A = torch.softmax(softmax_temp * QK, dim=2)
+ if self.use_dropout:
+ A = self.dropout(A)
+
+ queried_values = torch.einsum("nlsh,nshd->nlhd", A, values)
+
+ return queried_values.contiguous()
+
+
+class XAttention(Module):
+ def __init__(self, use_dropout=False, attention_dropout=0.1):
+ super().__init__()
+ self.use_dropout = use_dropout
+ if use_dropout:
+ self.dropout = Dropout(attention_dropout)
+
+ def forward(self, queries, keys, values, q_mask=None, kv_mask=None):
+ """ Multi-head scaled dot-product attention, a.k.a full attention.
+ Args:
+ if FLASH_AVAILABLE: # pytorch scaled_dot_product_attention
+ queries: [N, H, L, D]
+ keys: [N, H, S, D]
+ values: [N, H, S, D]
+ else:
+ queries: [N, L, H, D]
+ keys: [N, S, H, D]
+ values: [N, S, H, D]
+ q_mask: [N, L]
+ kv_mask: [N, S]
+ Returns:
+ queried_values: (N, L, H, D)
+ """
+
+ assert q_mask is None and kv_mask is None, "already been sliced"
+ if FLASH_AVAILABLE:
+ # args = [x.half().contiguous() for x in [queries, keys, values]]
+ # out = F.scaled_dot_product_attention(*args, attn_mask=mask).to(queries.dtype)
+ args = [x.contiguous() for x in [queries, keys, values]]
+ out = F.scaled_dot_product_attention(*args)
+ else:
+ # if flash_attn_func_ok:
+ # out = flash_attn_func(queries, keys, values)
+ # else:
+ QK = torch.einsum("nlhd,nshd->nlsh", queries, keys)
+
+ # Compute the attention and the weighted average
+ softmax_temp = 1. / queries.size(3)**.5 # sqrt(D)
+ A = torch.softmax(softmax_temp * QK, dim=2)
+
+ out = torch.einsum("nlsh,nshd->nlhd", A, values)
+
+ # out = xformers.ops.memory_efficient_attention(queries, keys, values)
+ # out = xformers.ops.memory_efficient_attention(queries[:,:q_mask.sum(),...], keys[:,:kv_mask.sum(),...], values[:,:kv_mask.sum(),...])
+ # out = torch.cat([out, torch.zeros(out.size(0), queries.size(1)-q_mask.sum(), queries.size(2), queries.size(3), device=queries.device)], dim=1)
+ return out
diff --git a/imcui/third_party/MatchAnything/src/loftr/loftr_module/transformer.py b/imcui/third_party/MatchAnything/src/loftr/loftr_module/transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..1ee8ce85912ad44539c27836ddd20f676912df5b
--- /dev/null
+++ b/imcui/third_party/MatchAnything/src/loftr/loftr_module/transformer.py
@@ -0,0 +1,1768 @@
+import copy
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from .linear_attention import LinearAttention, RoPELinearAttention, FullAttention, XAttention
+from einops.einops import rearrange
+from collections import OrderedDict
+from .transformer_utils import TokenConfidence, MatchAssignment, filter_matches
+from ..utils.coarse_matching import CoarseMatching
+from ..utils.position_encoding import RoPEPositionEncodingSine
+import numpy as np
+from loguru import logger
+
+PFLASH_AVAILABLE = False
+
+class PANEncoderLayer(nn.Module):
+ def __init__(self,
+ d_model,
+ nhead,
+ attention='linear',
+ pool_size=4,
+ bn=True,
+ xformer=False,
+ leaky=-1.0,
+ dw_conv=False,
+ scatter=False,
+ ):
+ super(PANEncoderLayer, self).__init__()
+
+ self.pool_size = pool_size
+ self.dw_conv = dw_conv
+ self.scatter = scatter
+ if self.dw_conv:
+ self.aggregate = nn.Conv2d(d_model, d_model, kernel_size=pool_size, padding=0, stride=pool_size, bias=False, groups=d_model)
+
+ assert not self.scatter, 'scatter aggregation is buggy here'
+ self.dim = d_model // nhead
+ self.nhead = nhead
+
+ self.max_pool = torch.nn.MaxPool2d(kernel_size=self.pool_size, stride=self.pool_size)
+ # multi-head attention
+ if bn:
+ method = 'dw_bn'
+ else:
+ method = 'dw'
+ self.q_proj_conv = self._build_projection(d_model, d_model, method=method)
+ self.k_proj_conv = self._build_projection(d_model, d_model, method=method)
+ self.v_proj_conv = self._build_projection(d_model, d_model, method=method)
+
+ # self.q_proj = nn.Linear(d_model, d_model, bias=False)
+ # self.k_proj = nn.Linear(d_model, d_model, bias=False)
+ # self.v_proj = nn.Linear(d_model, d_model, bias=False)
+ if xformer:
+ self.attention = XAttention()
+ else:
+ self.attention = LinearAttention() if attention == 'linear' else FullAttention()
+ self.merge = nn.Linear(d_model, d_model, bias=False)
+
+ # feed-forward network
+ if leaky > 0:
+ self.mlp = nn.Sequential(
+ nn.Linear(d_model*2, d_model*2, bias=False),
+ nn.LeakyReLU(leaky, True),
+ nn.Linear(d_model*2, d_model, bias=False),
+ )
+
+ else:
+ self.mlp = nn.Sequential(
+ nn.Linear(d_model*2, d_model*2, bias=False),
+ nn.ReLU(True),
+ nn.Linear(d_model*2, d_model, bias=False),
+ )
+
+ # norm and dropout
+ self.norm1 = nn.LayerNorm(d_model)
+ self.norm2 = nn.LayerNorm(d_model)
+
+ # self.norm1 = nn.BatchNorm2d(d_model)
+
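+ # Added note (not in the original code): queries, keys and values are downsampled by
+ # pool_size (max-pool, or the depth-wise conv when dw_conv=True) before attention and the
+ # result is upsampled back afterwards. With the default pool_size=4, a 60x80 coarse grid
+ # shrinks to 15x20, i.e. 16x fewer tokens entering the attention step.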
+ def forward(self, x, source, x_mask=None, source_mask=None):
+ """
+ Args:
+ x (torch.Tensor): [N, C, H1, W1]
+ source (torch.Tensor): [N, C, H2, W2]
+ x_mask (torch.Tensor): [N, H1, W1] (optional) (L = H1*W1)
+ source_mask (torch.Tensor): [N, H2, W2] (optional) (S = H2*W2)
+ """
+ bs = x.size(0)
+ H1, W1 = x.size(-2), x.size(-1)
+ H2, W2 = source.size(-2), source.size(-1)
+
+ query, key, value = x, source, source
+
+ if self.dw_conv:
+ query = self.norm1(self.aggregate(query).permute(0,2,3,1)).permute(0,3,1,2)
+ else:
+ query = self.norm1(self.max_pool(query).permute(0,2,3,1)).permute(0,3,1,2)
+ # only need to cal key or value...
+ key = self.norm1(self.max_pool(key).permute(0,2,3,1)).permute(0,3,1,2)
+ value = self.norm1(self.max_pool(value).permute(0,2,3,1)).permute(0,3,1,2)
+
+ # After 0617 bnorm to prevent permute*6
+ # query = self.norm1(self.max_pool(query))
+ # key = self.norm1(self.max_pool(key))
+ # value = self.norm1(self.max_pool(value))
+ # multi-head attention
+ query = self.q_proj_conv(query) # [N, C, H1//pool, W1//pool]
+ key = self.k_proj_conv(key)
+ value = self.v_proj_conv(value)
+
+ C = query.shape[-3]
+
+ ismask = x_mask is not None and source_mask is not None
+ if bs == 1 or not ismask:
+ if ismask:
+ x_mask = self.max_pool(x_mask.float()).bool() # [N, H1//pool, W1//pool]
+ source_mask = self.max_pool(source_mask.float()).bool()
+
+ mask_h0, mask_w0 = x_mask[0].sum(-2)[0], x_mask[0].sum(-1)[0]
+ mask_h1, mask_w1 = source_mask[0].sum(-2)[0], source_mask[0].sum(-1)[0]
+
+ query = query[:, :, :mask_h0, :mask_w0]
+ key = key[:, :, :mask_h1, :mask_w1]
+ value = value[:, :, :mask_h1, :mask_w1]
+
+ else:
+ assert x_mask is None and source_mask is None
+
+ # query = query.reshape(bs, -1, self.nhead, self.dim) # [N, L, H, D]
+ # key = key.reshape(bs, -1, self.nhead, self.dim) # [N, S, H, D]
+ # value = value.reshape(bs, -1, self.nhead, self.dim) # [N, S, H, D]
+ if PFLASH_AVAILABLE: # N H L D
+ query = rearrange(query, 'n (nhead d) h w -> n nhead (h w) d', nhead=self.nhead, d=self.dim)
+ key = rearrange(key, 'n (nhead d) h w -> n nhead (h w) d', nhead=self.nhead, d=self.dim)
+ value = rearrange(value, 'n (nhead d) h w -> n nhead (h w) d', nhead=self.nhead, d=self.dim)
+
+ else: # N L H D
+ query = rearrange(query, 'n (nhead d) h w -> n (h w) nhead d', nhead=self.nhead, d=self.dim) # [N, L, H, D]
+ key = rearrange(key, 'n (nhead d) h w -> n (h w) nhead d', nhead=self.nhead, d=self.dim) # [N, S, H, D]
+ value = rearrange(value, 'n (nhead d) h w -> n (h w) nhead d', nhead=self.nhead, d=self.dim) # [N, S, H, D]
+
+ message = self.attention(query, key, value, q_mask=None, kv_mask=None) # [N, L, H, D] or [N, H, L, D]
+
+ if PFLASH_AVAILABLE: # N H L D
+ message = rearrange(message, 'n nhead L d -> n L nhead d', nhead=self.nhead, d=self.dim)
+
+ if ismask:
+ message = message.view(bs, mask_h0, mask_w0, self.nhead, self.dim)
+ if mask_h0 != x_mask.size(-2):
+ message = torch.cat([message, torch.zeros(message.size(0), x_mask.size(-2)-mask_h0, x_mask.size(-1), self.nhead, self.dim, device=message.device, dtype=message.dtype)], dim=1)
+ elif mask_w0 != x_mask.size(-1):
+ message = torch.cat([message, torch.zeros(message.size(0), x_mask.size(-2), x_mask.size(-1)-mask_w0, self.nhead, self.dim, device=message.device, dtype=message.dtype)], dim=2)
+ # message = message.view(bs, -1, self.nhead*self.dim) # [N, L, C]
+
+ else:
+ assert x_mask is None and source_mask is None
+
+
+ message = self.merge(message.reshape(bs, -1, self.nhead*self.dim)) # [N, L, C]
+ # message = message.reshape(bs, C, H1//self.pool_size, W1//self.pool_size) # [N, C, H, W] bug???
+ message = rearrange(message, 'b (h w) c -> b c h w', h=H1//self.pool_size, w=W1//self.pool_size) # [N, C, H, W]
+
+ if self.scatter:
+ message = torch.repeat_interleave(message, self.pool_size, dim=-2)
+ message = torch.repeat_interleave(message, self.pool_size, dim=-1)
+ # message = self.aggregate(message)
+ message = message * self.aggregate.weight.data.reshape(1, C, self.pool_size, self.pool_size).repeat(1,1,message.shape[-2]//self.pool_size,message.shape[-1]//self.pool_size)
+ else:
+ message = torch.nn.functional.interpolate(message, scale_factor=self.pool_size, mode='bilinear', align_corners=False) # [N, C, H1, W1]
+
+ # message = self.norm1(message)
+
+ # feed-forward network
+ message = self.mlp(torch.cat([x, message], dim=1).permute(0, 2, 3, 1)) # [N, H1, W1, C]
+ message = self.norm2(message).permute(0, 3, 1, 2) # [N, C, H1, W1]
+
+ return x + message
+ else:
+ x_mask = self.max_pool(x_mask.float()).bool()
+ source_mask = self.max_pool(source_mask.float()).bool()
+ m_list = []
+ for i in range(bs):
+ mask_h0, mask_w0 = x_mask[i].sum(-2)[0], x_mask[i].sum(-1)[0]
+ mask_h1, mask_w1 = source_mask[i].sum(-2)[0], source_mask[i].sum(-1)[0]
+
+ q = query[i:i+1, :, :mask_h0, :mask_w0]
+ k = key[i:i+1, :, :mask_h1, :mask_w1]
+ v = value[i:i+1, :, :mask_h1, :mask_w1]
+
+ if PFLASH_AVAILABLE: # N H L D
+ q = rearrange(q, 'n (nhead d) h w -> n nhead (h w) d', nhead=self.nhead, d=self.dim)
+ k = rearrange(k, 'n (nhead d) h w -> n nhead (h w) d', nhead=self.nhead, d=self.dim)
+ v = rearrange(v, 'n (nhead d) h w -> n nhead (h w) d', nhead=self.nhead, d=self.dim)
+
+ else: # N L H D
+
+ q = rearrange(q, 'n (nhead d) h w -> n (h w) nhead d', nhead=self.nhead, d=self.dim) # [N, L, H, D]
+ k = rearrange(k, 'n (nhead d) h w -> n (h w) nhead d', nhead=self.nhead, d=self.dim) # [N, S, H, D]
+ v = rearrange(v, 'n (nhead d) h w -> n (h w) nhead d', nhead=self.nhead, d=self.dim) # [N, S, H, D]
+
+ m = self.attention(q, k, v, q_mask=None, kv_mask=None) # [N, L, H, D]
+
+ if PFLASH_AVAILABLE: # N H L D
+ m = rearrange(m, 'n nhead L d -> n L nhead d', nhead=self.nhead, d=self.dim)
+
+ m = m.view(1, mask_h0, mask_w0, self.nhead, self.dim)
+ if mask_h0 != x_mask.size(-2):
+ m = torch.cat([m, torch.zeros(1, x_mask.size(-2)-mask_h0, x_mask.size(-1), self.nhead, self.dim, device=m.device, dtype=m.dtype)], dim=1)
+ elif mask_w0 != x_mask.size(-1):
+ m = torch.cat([m, torch.zeros(1, x_mask.size(-2), x_mask.size(-1)-mask_w0, self.nhead, self.dim, device=m.device, dtype=m.dtype)], dim=2)
+ m_list.append(m)
+ message = torch.cat(m_list, dim=0)
+
+
+ message = self.merge(message.view(bs, -1, self.nhead*self.dim)) # [N, L, C]
+ # message = message.reshape(bs, C, H1//self.pool_size, W1//self.pool_size) # wrong: message is [N, L, C], so this reshape would scramble channels and positions; the rearrange below is correct
+ message = rearrange(message, 'b (h w) c -> b c h w', h=H1//self.pool_size, w=W1//self.pool_size) # [N, C, H, W]
+
+ if self.scatter:
+ message = torch.repeat_interleave(message, self.pool_size, dim=-2)
+ message = torch.repeat_interleave(message, self.pool_size, dim=-1)
+ # message = self.aggregate(message)
+ # assert False
+ else:
+ message = torch.nn.functional.interpolate(message, scale_factor=self.pool_size, mode='bilinear', align_corners=False) # [N, C, H1, W1]
+
+ # message = self.norm1(message)
+
+ # feed-forward network
+ message = self.mlp(torch.cat([x, message], dim=1).permute(0, 2, 3, 1)) # [N, H1, W1, C]
+ message = self.norm2(message).permute(0, 3, 1, 2) # [N, C, H1, W1]
+
+ return x + message
+
+
+ def pro(self, x, source, x_mask=None, source_mask=None, profiler=None):
+ """
+ Args:
+ x (torch.Tensor): [N, C, H1, W1]
+ source (torch.Tensor): [N, C, H2, W2]
+ x_mask (torch.Tensor): [N, H1, W1] (optional) (L = H1*W1)
+ source_mask (torch.Tensor): [N, H2, W2] (optional) (S = H2*W2)
+ """
+ bs = x.size(0)
+ H1, W1 = x.size(-2), x.size(-1)
+ H2, W2 = source.size(-2), source.size(-1)
+
+ query, key, value = x, source, source
+
+ with profiler.profile("permute*6+norm1*3+max_pool*3"):
+ if self.dw_conv:
+ query = self.norm1(self.aggregate(query).permute(0,2,3,1)).permute(0,3,1,2)
+ else:
+ query = self.norm1(self.max_pool(query).permute(0,2,3,1)).permute(0,3,1,2)
+ # key and value share the same input here, so computing only one of them would suffice
+ key = self.norm1(self.max_pool(key).permute(0,2,3,1)).permute(0,3,1,2)
+ value = self.norm1(self.max_pool(value).permute(0,2,3,1)).permute(0,3,1,2)
+
+ with profiler.profile("permute*6"):
+ query = query.permute(0, 2, 3, 1)
+ key = key.permute(0, 2, 3, 1)
+ value = value.permute(0, 2, 3, 1)
+
+ query = query.permute(0,3,1,2)
+ key = key.permute(0,3,1,2)
+ value = value.permute(0,3,1,2)
+
+ # query = self.bnorm1(self.max_pool(query))
+ # key = self.bnorm1(self.max_pool(key))
+ # value = self.bnorm1(self.max_pool(value))
+ # multi-head attention
+
+ with profiler.profile("q_conv+k_conv+v_conv"):
+ query = self.q_proj_conv(query) # [N, C, H1//pool, W1//pool]
+ key = self.k_proj_conv(key)
+ value = self.v_proj_conv(value)
+
+ C = query.shape[-3]
+ # TODO: make this consistent with the bs == 1 path, where masked regions are excluded from attention entirely
+ if x_mask is not None and source_mask is not None:
+ x_mask = self.max_pool(x_mask.float()).bool() # [N, H1//pool, W1//pool]
+ source_mask = self.max_pool(source_mask.float()).bool()
+
+ mask_h0, mask_w0 = x_mask[0].sum(-2)[0], x_mask[0].sum(-1)[0]
+ mask_h1, mask_w1 = source_mask[0].sum(-2)[0], source_mask[0].sum(-1)[0]
+
+ query = query[:, :, :mask_h0, :mask_w0]
+ key = key[:, :, :mask_h1, :mask_w1]
+ value = value[:, :, :mask_h1, :mask_w1]
+
+ # mask_h0, mask_w0 = data['mask0'][0].sum(-2)[0], data['mask0'][0].sum(-1)[0]
+ # mask_h1, mask_w1 = data['mask1'][0].sum(-2)[0], data['mask1'][0].sum(-1)[0]
+ # C = feat_c0.shape[-3]
+ # feat_c0 = feat_c0[:, :, :mask_h0, :mask_w0]
+ # feat_c1 = feat_c1[:, :, :mask_h1, :mask_w1]
+
+
+ # feat_c0 = feat_c0.reshape(-1, mask_h0, mask_w0, C)
+ # feat_c1 = feat_c1.reshape(-1, mask_h1, mask_w1, C)
+ # if mask_h0 != data['mask0'].size(-2):
+ # feat_c0 = torch.cat([feat_c0, torch.zeros(feat_c0.size(0), data['hw0_c'][0]-mask_h0, data['hw0_c'][1], C, device=feat_c0.device)], dim=1)
+ # elif mask_w0 != data['mask0'].size(-1):
+ # feat_c0 = torch.cat([feat_c0, torch.zeros(feat_c0.size(0), data['hw0_c'][0], data['hw0_c'][1]-mask_w0, C, device=feat_c0.device)], dim=2)
+
+ # if mask_h1 != data['mask1'].size(-2):
+ # feat_c1 = torch.cat([feat_c1, torch.zeros(feat_c1.size(0), data['hw1_c'][0]-mask_h1, data['hw1_c'][1], C, device=feat_c1.device)], dim=1)
+ # elif mask_w1 != data['mask1'].size(-1):
+ # feat_c1 = torch.cat([feat_c1, torch.zeros(feat_c1.size(0), data['hw1_c'][0], data['hw1_c'][1]-mask_w1, C, device=feat_c1.device)], dim=2)
+
+
+ else:
+ assert x_mask is None and source_mask is None
+
+
+
+ # query = query.reshape(bs, -1, self.nhead, self.dim) # [N, L, H, D]
+ # key = key.reshape(bs, -1, self.nhead, self.dim) # [N, S, H, D]
+ # value = value.reshape(bs, -1, self.nhead, self.dim) # [N, S, H, D]
+
+ with profiler.profile("rearrange*3"):
+ query = rearrange(query, 'n (nhead d) h w -> n (h w) nhead d', nhead=self.nhead, d=self.dim) # [N, L, H, D]
+ key = rearrange(key, 'n (nhead d) h w -> n (h w) nhead d', nhead=self.nhead, d=self.dim) # [N, S, H, D]
+ value = rearrange(value, 'n (nhead d) h w -> n (h w) nhead d', nhead=self.nhead, d=self.dim) # [N, S, H, D]
+
+ with profiler.profile("attention"):
+ message = self.attention(query, key, value, q_mask=None, kv_mask=None) # [N, L, H, D]
+
+ if x_mask is not None and source_mask is not None:
+ message = message.view(bs, mask_h0, mask_w0, self.nhead, self.dim)
+ if mask_h0 != x_mask.size(-2):
+ message = torch.cat([message, torch.zeros(message.size(0), x_mask.size(-2)-mask_h0, x_mask.size(-1), self.nhead, self.dim, device=message.device, dtype=message.dtype)], dim=1)
+ elif mask_w0 != x_mask.size(-1):
+ message = torch.cat([message, torch.zeros(message.size(0), x_mask.size(-2), x_mask.size(-1)-mask_w0, self.nhead, self.dim, device=message.device, dtype=message.dtype)], dim=2)
+ # message = message.view(bs, -1, self.nhead*self.dim) # [N, L, C]
+
+ else:
+ assert x_mask is None and source_mask is None
+
+ with profiler.profile("merge"):
+ message = self.merge(message.view(bs, -1, self.nhead*self.dim)) # [N, L, C]
+ # message = message.reshape(bs, C, H1//self.pool_size, W1//self.pool_size) # wrong: message is [N, L, C], so this reshape would scramble channels and positions; the rearrange below is correct
+
+ with profiler.profile("rearrange*1"):
+ message = rearrange(message, 'b (h w) c -> b c h w', h=H1//self.pool_size, w=W1//self.pool_size) # [N, C, H, W]
+
+ with profiler.profile("upsample"):
+ if self.scatter:
+ message = torch.repeat_interleave(message, self.pool_size, dim=-2)
+ message = torch.repeat_interleave(message, self.pool_size, dim=-1)
+ # message = self.aggregate(message)
+ # assert False
+ else:
+ message = torch.nn.functional.interpolate(message, scale_factor=self.pool_size, mode='bilinear', align_corners=False) # [N, C, H1, W1]
+
+ # message = self.norm1(message)
+
+ # feed-forward network
+ with profiler.profile("feed-forward_mlp+permute*2+norm2"):
+ message = self.mlp(torch.cat([x, message], dim=1).permute(0, 2, 3, 1)) # [N, H1, W1, C]
+ message = self.norm2(message).permute(0, 3, 1, 2) # [N, C, H1, W1]
+
+ return x + message
+
+
+ def _build_projection(self,
+ dim_in,
+ dim_out,
+ kernel_size=3,
+ padding=1,
+ stride=1,
+ method='dw_bn',
+ ):
+ if method == 'dw_bn':
+ proj = nn.Sequential(OrderedDict([
+ ('conv', nn.Conv2d(
+ dim_in,
+ dim_in,
+ kernel_size=kernel_size,
+ padding=padding,
+ stride=stride,
+ bias=False,
+ groups=dim_in
+ )),
+ ('bn', nn.BatchNorm2d(dim_in)),
+ # ('rearrage', Rearrange('b c h w -> b (h w) c')),
+ ]))
+ elif method == 'avg':
+ proj = nn.Sequential(OrderedDict([
+ ('avg', nn.AvgPool2d(
+ kernel_size=kernel_size,
+ padding=padding,
+ stride=stride,
+ ceil_mode=True
+ )),
+ # ('rearrage', Rearrange('b c h w -> b (h w) c')),
+ ]))
+ elif method == 'linear':
+ proj = None
+ elif method == 'dw':
+ proj = nn.Sequential(OrderedDict([
+ ('conv', nn.Conv2d(
+ dim_in,
+ dim_in,
+ kernel_size=kernel_size,
+ padding=padding,
+ stride=stride,
+ bias=False,
+ groups=dim_in
+ )),
+ # ('rearrage', Rearrange('b c h w -> b (h w) c')),
+ ]))
+ else:
+ raise ValueError('Unknown method ({})'.format(method))
+
+ return proj
+
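+# Usage sketch for PANEncoderLayer (illustrative only; the positional signature is inferred
+# from the call sites in LocalFeatureTransformer below, and all values are assumptions):
+#   layer = PANEncoderLayer(256, 8, 'linear', 4, True, False)  # d_model, nhead, attention, pool_size, bn, xformer
+#   x      = torch.rand(1, 256, 64, 64)  # [N, C, H1, W1], H1/W1 divisible by pool_size
+#   source = torch.rand(1, 256, 64, 64)  # [N, C, H2, W2]
+#   out = layer(x, source)               # [N, C, H1, W1], residual added to x
+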
+class AG_RoPE_EncoderLayer(nn.Module):
+ def __init__(self,
+ d_model,
+ nhead,
+ attention='linear',
+ pool_size=4,
+ pool_size2=4,
+ xformer=False,
+ leaky=-1.0,
+ dw_conv=False,
+ dw_conv2=False,
+ scatter=False,
+ norm_before=True,
+ rope=False,
+ npe=None,
+ vit_norm=False,
+ dw_proj=False,
+ ):
+ super(AG_RoPE_EncoderLayer, self).__init__()
+
+ self.pool_size = pool_size
+ self.pool_size2 = pool_size2
+ self.dw_conv = dw_conv
+ self.dw_conv2 = dw_conv2
+ self.scatter = scatter
+ self.norm_before = norm_before
+ self.vit_norm = vit_norm
+ self.dw_proj = dw_proj
+ self.rope = rope
+ if self.dw_conv and self.pool_size != 1:
+ self.aggregate = nn.Conv2d(d_model, d_model, kernel_size=pool_size, padding=0, stride=pool_size, bias=False, groups=d_model)
+ if self.dw_conv2 and self.pool_size2 != 1:
+ self.aggregate2 = nn.Conv2d(d_model, d_model, kernel_size=pool_size2, padding=0, stride=pool_size2, bias=False, groups=d_model)
+
+ self.dim = d_model // nhead
+ self.nhead = nhead
+
+ self.max_pool = torch.nn.MaxPool2d(kernel_size=self.pool_size2, stride=self.pool_size2)
+
+ # multi-head attention
+ if self.dw_proj:
+ self.q_proj = nn.Conv2d(d_model, d_model, kernel_size=3, padding=1, stride=1, bias=False, groups=d_model)
+ self.k_proj = nn.Conv2d(d_model, d_model, kernel_size=3, padding=1, stride=1, bias=False, groups=d_model)
+ self.v_proj = nn.Conv2d(d_model, d_model, kernel_size=3, padding=1, stride=1, bias=False, groups=d_model)
+ else:
+ self.q_proj = nn.Linear(d_model, d_model, bias=False)
+ self.k_proj = nn.Linear(d_model, d_model, bias=False)
+ self.v_proj = nn.Linear(d_model, d_model, bias=False)
+
+ if self.rope:
+ self.rope_pos_enc = RoPEPositionEncodingSine(d_model, max_shape=(256, 256), npe=npe, ropefp16=True)
+
+ if xformer:
+ self.attention = XAttention()
+ else:
+ self.attention = LinearAttention() if attention == 'linear' else FullAttention()
+ self.merge = nn.Linear(d_model, d_model, bias=False)
+
+ # feed-forward network
+ if leaky > 0:
+ if self.vit_norm:
+ self.mlp = nn.Sequential(
+ nn.Linear(d_model, d_model*2, bias=False),
+ nn.LeakyReLU(leaky, True),
+ nn.Linear(d_model*2, d_model, bias=False),
+ )
+ else:
+ self.mlp = nn.Sequential(
+ nn.Linear(d_model*2, d_model*2, bias=False),
+ nn.LeakyReLU(leaky, True),
+ nn.Linear(d_model*2, d_model, bias=False),
+ )
+
+ else:
+ if self.vit_norm:
+ self.mlp = nn.Sequential(
+ nn.Linear(d_model, d_model*2, bias=False),
+ nn.ReLU(True),
+ nn.Linear(d_model*2, d_model, bias=False),
+ )
+ else:
+ self.mlp = nn.Sequential(
+ nn.Linear(d_model*2, d_model*2, bias=False),
+ nn.ReLU(True),
+ nn.Linear(d_model*2, d_model, bias=False),
+ )
+
+ # norm and dropout
+ self.norm1 = nn.LayerNorm(d_model)
+ self.norm2 = nn.LayerNorm(d_model)
+
+ # self.norm1 = nn.BatchNorm2d(d_model)
+
+ def forward(self, x, source, x_mask=None, source_mask=None):
+ """
+ Args:
+ x (torch.Tensor): [N, C, H1, W1]
+ source (torch.Tensor): [N, C, H2, W2]
+ x_mask (torch.Tensor): [N, H1, W1] (optional) (L = H1*W1)
+ source_mask (torch.Tensor): [N, H2, W2] (optional) (S = H2*W2)
+ """
+ bs, C, H1, W1 = x.size()
+ H2, W2 = source.size(-2), source.size(-1)
+
+
+ if self.norm_before and not self.vit_norm:
+ if self.pool_size == 1:
+ query = self.norm1(x.permute(0,2,3,1)) # [N, H, W, C]
+ elif self.dw_conv:
+ query = self.norm1(self.aggregate(x).permute(0,2,3,1)) # [N, H, W, C]
+ else:
+ query = self.norm1(self.max_pool(x).permute(0,2,3,1)) # [N, H, W, C]
+ if self.pool_size2 == 1:
+ source = self.norm1(source.permute(0,2,3,1)) # [N, H, W, C]
+ elif self.dw_conv2:
+ source = self.norm1(self.aggregate2(source).permute(0,2,3,1)) # [N, H, W, C]
+ else:
+ source = self.norm1(self.max_pool(source).permute(0,2,3,1)) # [N, H, W, C]
+ elif self.vit_norm:
+ if self.pool_size == 1:
+ query = self.norm1(x.permute(0,2,3,1)) # [N, H, W, C]
+ elif self.dw_conv:
+ query = self.aggregate(self.norm1(x.permute(0,2,3,1)).permute(0,3,1,2)).permute(0,2,3,1) # [N, H, W, C]
+ else:
+ query = self.max_pool(self.norm1(x.permute(0,2,3,1)).permute(0,3,1,2)).permute(0,2,3,1) # [N, H, W, C]
+ if self.pool_size2 == 1:
+ source = self.norm1(source.permute(0,2,3,1)) # [N, H, W, C]
+ elif self.dw_conv2:
+ source = self.aggregate2(self.norm1(source.permute(0,2,3,1)).permute(0,3,1,2)).permute(0,2,3,1) # [N, H, W, C]
+ else:
+ source = self.max_pool(self.norm1(source.permute(0,2,3,1)).permute(0,3,1,2)).permute(0,2,3,1) # [N, H, W, C]
+ else:
+ if self.pool_size == 1:
+ query = x.permute(0,2,3,1) # [N, H, W, C]
+ elif self.dw_conv:
+ query = self.aggregate(x).permute(0,2,3,1) # [N, H, W, C]
+ else:
+ query = self.max_pool(x).permute(0,2,3,1) # [N, H, W, C]
+ if self.pool_size2 == 1:
+ source = source.permute(0,2,3,1) # [N, H, W, C]
+ elif self.dw_conv2:
+ source = self.aggregate2(source).permute(0,2,3,1) # [N, H, W, C]
+ else:
+ source = self.max_pool(source).permute(0,2,3,1) # [N, H, W, C]
+
+ # projection
+ if self.dw_proj:
+ query = self.q_proj(query.permute(0,3,1,2)).permute(0,2,3,1)
+ key = self.k_proj(source.permute(0,3,1,2)).permute(0,2,3,1)
+ value = self.v_proj(source.permute(0,3,1,2)).permute(0,2,3,1)
+ else:
+ query, key, value = self.q_proj(query), self.k_proj(source), self.v_proj(source)
+
+ # RoPE
+ if self.rope:
+ query = self.rope_pos_enc(query)
+ if self.pool_size == 1 and self.pool_size2 == 4:
+ key = self.rope_pos_enc(key, 4)
+ else:
+ key = self.rope_pos_enc(key)
+
+ use_mask = x_mask is not None and source_mask is not None
+ if bs == 1 or not use_mask:
+ if use_mask:
+ # downsample mask
+ if self.pool_size ==1:
+ pass
+ else:
+ x_mask = self.max_pool(x_mask.float()).bool() # [N, H1//pool, W1//pool]
+
+ if self.pool_size2 ==1:
+ pass
+ else:
+ source_mask = self.max_pool(source_mask.float()).bool()
+
+ mask_h0, mask_w0 = x_mask[0].sum(-2)[0], x_mask[0].sum(-1)[0]
+ mask_h1, mask_w1 = source_mask[0].sum(-2)[0], source_mask[0].sum(-1)[0]
+
+ query = query[:, :mask_h0, :mask_w0, :]
+ key = key[:, :mask_h1, :mask_w1, :]
+ value = value[:, :mask_h1, :mask_w1, :]
+ else:
+ assert x_mask is None and source_mask is None
+
+ if PFLASH_AVAILABLE: # [N, H, W, C] -> [N, h, L, D]
+ query = rearrange(query, 'n h w (nhead d) -> n nhead (h w) d', nhead=self.nhead, d=self.dim)
+ key = rearrange(key, 'n h w (nhead d) -> n nhead (h w) d', nhead=self.nhead, d=self.dim)
+ value = rearrange(value, 'n h w (nhead d) -> n nhead (h w) d', nhead=self.nhead, d=self.dim)
+ else: # N L H D
+ query = rearrange(query, 'n h w (nhead d) -> n (h w) nhead d', nhead=self.nhead, d=self.dim)
+ key = rearrange(key, 'n h w (nhead d) -> n (h w) nhead d', nhead=self.nhead, d=self.dim)
+ value = rearrange(value, 'n h w (nhead d) -> n (h w) nhead d', nhead=self.nhead, d=self.dim)
+
+ message = self.attention(query, key, value, q_mask=None, kv_mask=None) # [N, L, h, D] or [N, h, L, D]
+
+ if PFLASH_AVAILABLE: # [N, h, L, D] -> [N, L, h, D]
+ message = rearrange(message, 'n nhead L d -> n L nhead d', nhead=self.nhead, d=self.dim)
+
+ if use_mask: # padding zero
+ message = message.view(bs, mask_h0, mask_w0, self.nhead, self.dim) # [N L h D]
+ if mask_h0 != x_mask.size(-2):
+ message = torch.cat([message, torch.zeros(message.size(0), x_mask.size(-2)-mask_h0, x_mask.size(-1), self.nhead, self.dim, device=message.device, dtype=message.dtype)], dim=1)
+ elif mask_w0 != x_mask.size(-1):
+ message = torch.cat([message, torch.zeros(message.size(0), x_mask.size(-2), x_mask.size(-1)-mask_w0, self.nhead, self.dim, device=message.device, dtype=message.dtype)], dim=2)
+ else:
+ assert x_mask is None and source_mask is None
+
+ message = self.merge(message.reshape(bs, -1, self.nhead*self.dim)) # [N, L, C]
+ message = rearrange(message, 'b (h w) c -> b c h w', h=H1//self.pool_size, w=W1//self.pool_size) # [N, C, H, W]
+
+ if self.pool_size == 1:
+ pass
+ else:
+ if self.scatter:
+ message = torch.repeat_interleave(message, self.pool_size, dim=-2)
+ message = torch.repeat_interleave(message, self.pool_size, dim=-1)
+ message = message * self.aggregate.weight.data.reshape(1, C, self.pool_size, self.pool_size).repeat(1,1,message.shape[-2]//self.pool_size,message.shape[-1]//self.pool_size)
+ else:
+ message = torch.nn.functional.interpolate(message, scale_factor=self.pool_size, mode='bilinear', align_corners=False) # [N, C, H1, W1]
+
+ if not self.norm_before and not self.vit_norm:
+ message = self.norm1(message.permute(0,2,3,1)).permute(0,3,1,2) # [N, C, H, W]
+
+ # feed-forward network
+ if self.vit_norm:
+ message_inter = (x + message)
+ del x
+ message = self.norm2(message_inter.permute(0, 2, 3, 1))
+ message = self.mlp(message).permute(0, 3, 1, 2) # [N, C, H1, W1]
+ return message_inter + message
+ else:
+ message = self.mlp(torch.cat([x, message], dim=1).permute(0, 2, 3, 1)) # [N, H1, W1, C]
+ message = self.norm2(message).permute(0, 3, 1, 2) # [N, C, H1, W1]
+
+ return x + message
+ else: # mask with bs > 1
+ if self.pool_size ==1:
+ pass
+ else:
+ x_mask = self.max_pool(x_mask.float()).bool()
+
+ if self.pool_size2 ==1:
+ pass
+ else:
+ source_mask = self.max_pool(source_mask.float()).bool()
+ m_list = []
+ for i in range(bs):
+ mask_h0, mask_w0 = x_mask[i].sum(-2)[0], x_mask[i].sum(-1)[0]
+ mask_h1, mask_w1 = source_mask[i].sum(-2)[0], source_mask[i].sum(-1)[0]
+
+ q = query[i:i+1, :mask_h0, :mask_w0, :]
+ k = key[i:i+1, :mask_h1, :mask_w1, :]
+ v = value[i:i+1, :mask_h1, :mask_w1, :]
+
+ if PFLASH_AVAILABLE: # [N, H, W, C] -> [N, h, L, D]
+ q = rearrange(q, 'n h w (nhead d) -> n nhead (h w) d', nhead=self.nhead, d=self.dim)
+ k = rearrange(k, 'n h w (nhead d) -> n nhead (h w) d', nhead=self.nhead, d=self.dim)
+ v = rearrange(v, 'n h w (nhead d) -> n nhead (h w) d', nhead=self.nhead, d=self.dim)
+ else: # N L H D
+ q = rearrange(q, 'n h w (nhead d) -> n (h w) nhead d', nhead=self.nhead, d=self.dim)
+ k = rearrange(k, 'n h w (nhead d) -> n (h w) nhead d', nhead=self.nhead, d=self.dim)
+ v = rearrange(v, 'n h w (nhead d) -> n (h w) nhead d', nhead=self.nhead, d=self.dim)
+
+ m = self.attention(q, k, v, q_mask=None, kv_mask=None) # [N, L, h, D] or [N, h, L, D]
+
+ if PFLASH_AVAILABLE: # [N, h, L, D] -> [N, L, h, D]
+ m = rearrange(m, 'n nhead L d -> n L nhead d', nhead=self.nhead, d=self.dim)
+
+ m = m.view(1, mask_h0, mask_w0, self.nhead, self.dim)
+ if mask_h0 != x_mask.size(-2):
+ m = torch.cat([m, torch.zeros(1, x_mask.size(-2)-mask_h0, x_mask.size(-1), self.nhead, self.dim, device=m.device, dtype=m.dtype)], dim=1)
+ elif mask_w0 != x_mask.size(-1):
+ m = torch.cat([m, torch.zeros(1, x_mask.size(-2), x_mask.size(-1)-mask_w0, self.nhead, self.dim, device=m.device, dtype=m.dtype)], dim=2)
+ m_list.append(m)
+ m = torch.cat(m_list, dim=0)
+
+ m = self.merge(m.reshape(bs, -1, self.nhead*self.dim)) # [N, L, C]
+ # m = m.reshape(bs, C, H1//self.pool_size, W1//self.pool_size) # wrong layout: m is [N, L, C]; the rearrange below is the correct op
+ m = rearrange(m, 'b (h w) c -> b c h w', h=H1//self.pool_size, w=W1//self.pool_size) # [N, C, H, W]
+
+ if self.pool_size == 1:
+ pass
+ else:
+ if self.scatter:
+ m = torch.repeat_interleave(m, self.pool_size, dim=-2)
+ m = torch.repeat_interleave(m, self.pool_size, dim=-1)
+ m = m * self.aggregate.weight.data.reshape(1, C, self.pool_size, self.pool_size).repeat(1,1,m.shape[-2]//self.pool_size,m.shape[-1]//self.pool_size)
+ else:
+ m = torch.nn.functional.interpolate(m, scale_factor=self.pool_size, mode='bilinear', align_corners=False) # [N, C, H1, W1]
+
+
+ if not self.norm_before and not self.vit_norm:
+ m = self.norm1(m.permute(0,2,3,1)).permute(0,3,1,2) # [N, C, H, W]
+
+ # feed-forward network
+ if self.vit_norm:
+ m_inter = (x + m)
+ del x
+ m = self.norm2(m_inter.permute(0, 2, 3, 1))
+ m = self.mlp(m).permute(0, 3, 1, 2) # [N, C, H1, W1]
+ return m_inter + m
+ else:
+ m = self.mlp(torch.cat([x, m], dim=1).permute(0, 2, 3, 1)) # [N, H1, W1, C]
+ m = self.norm2(m).permute(0, 3, 1, 2) # [N, C, H1, W1]
+
+ return x + m
+
+
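+# Usage sketch for AG_RoPE_EncoderLayer (illustrative; shapes and hyper-parameters are
+# assumptions, not values from a shipped config; rope=True would additionally need npe):
+#   layer = AG_RoPE_EncoderLayer(d_model=256, nhead=8, attention='linear', pool_size=4, pool_size2=4)
+#   x      = torch.rand(1, 256, 64, 64)  # [N, C, H1, W1], divisible by pool_size
+#   source = torch.rand(1, 256, 64, 64)  # [N, C, H2, W2], divisible by pool_size2
+#   out = layer(x, source)               # [N, C, H1, W1] (residual connection to x)
+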
+class AG_Conv_EncoderLayer(nn.Module):
+ def __init__(self,
+ d_model,
+ nhead,
+ attention='linear',
+ pool_size=4,
+ bn=True,
+ xformer=False,
+ leaky=-1.0,
+ dw_conv=False,
+ dw_conv2=False,
+ scatter=False,
+ norm_before=True,
+ ):
+ super(AG_Conv_EncoderLayer, self).__init__()
+
+ self.pool_size = pool_size
+ self.dw_conv = dw_conv
+ self.dw_conv2 = dw_conv2
+ self.scatter = scatter
+ self.norm_before = norm_before
+ if self.dw_conv:
+ self.aggregate = nn.Conv2d(d_model, d_model, kernel_size=pool_size, padding=0, stride=pool_size, bias=False, groups=d_model)
+ if self.dw_conv2:
+ self.aggregate2 = nn.Conv2d(d_model, d_model, kernel_size=pool_size, padding=0, stride=pool_size, bias=False, groups=d_model)
+ self.dim = d_model // nhead
+ self.nhead = nhead
+
+ self.max_pool = torch.nn.MaxPool2d(kernel_size=self.pool_size, stride=self.pool_size)
+
+ # multi-head attention
+ if bn:
+ method = 'dw_bn'
+ else:
+ method = 'dw'
+ self.q_proj_conv = self._build_projection(d_model, d_model, method=method)
+ self.k_proj_conv = self._build_projection(d_model, d_model, method=method)
+ self.v_proj_conv = self._build_projection(d_model, d_model, method=method)
+
+ if xformer:
+ self.attention = XAttention()
+ else:
+ self.attention = LinearAttention() if attention == 'linear' else FullAttention()
+ self.merge = nn.Linear(d_model, d_model, bias=False)
+
+ # feed-forward network
+ if leaky > 0:
+ self.mlp = nn.Sequential(
+ nn.Linear(d_model*2, d_model*2, bias=False),
+ nn.LeakyReLU(leaky, True),
+ nn.Linear(d_model*2, d_model, bias=False),
+ )
+
+ else:
+ self.mlp = nn.Sequential(
+ nn.Linear(d_model*2, d_model*2, bias=False),
+ nn.ReLU(True),
+ nn.Linear(d_model*2, d_model, bias=False),
+ )
+
+ # norm and dropout
+ self.norm1 = nn.LayerNorm(d_model)
+ self.norm2 = nn.LayerNorm(d_model)
+
+ def forward(self, x, source, x_mask=None, source_mask=None):
+ """
+ Args:
+ x (torch.Tensor): [N, C, H1, W1]
+ source (torch.Tensor): [N, C, H2, W2]
+ x_mask (torch.Tensor): [N, H1, W1] (optional) (L = H1*W1)
+ source_mask (torch.Tensor): [N, H2, W2] (optional) (S = H2*W2)
+ """
+ bs = x.size(0)
+ H1, W1 = x.size(-2), x.size(-1)
+ H2, W2 = source.size(-2), source.size(-1)
+ C = x.shape[-3]
+
+ if self.norm_before:
+ if self.dw_conv:
+ query = self.norm1(self.aggregate(x).permute(0,2,3,1)).permute(0,3,1,2)
+ else:
+ query = self.norm1(self.max_pool(x).permute(0,2,3,1)).permute(0,3,1,2)
+ if self.dw_conv2:
+ source = self.norm1(self.aggregate2(source).permute(0,2,3,1)).permute(0,3,1,2)
+ else:
+ source = self.norm1(self.max_pool(source).permute(0,2,3,1)).permute(0,3,1,2)
+ else:
+ if self.dw_conv:
+ query = self.aggregate(x)
+ else:
+ query = self.max_pool(x)
+ if self.dw_conv2:
+ source = self.aggregate2(source)
+ else:
+ source = self.max_pool(source)
+
+ key, value = source, source
+
+ query = self.q_proj_conv(query) # [N, C, H1//pool, W1//pool]
+ key = self.k_proj_conv(key)
+ value = self.v_proj_conv(value)
+
+ use_mask = x_mask is not None and source_mask is not None
+ if bs == 1 or not use_mask:
+ if use_mask:
+ x_mask = self.max_pool(x_mask.float()).bool() # [N, H1//pool, W1//pool]
+ source_mask = self.max_pool(source_mask.float()).bool()
+
+ mask_h0, mask_w0 = x_mask[0].sum(-2)[0], x_mask[0].sum(-1)[0]
+ mask_h1, mask_w1 = source_mask[0].sum(-2)[0], source_mask[0].sum(-1)[0]
+
+ query = query[:, :, :mask_h0, :mask_w0]
+ key = key[:, :, :mask_h1, :mask_w1]
+ value = value[:, :, :mask_h1, :mask_w1]
+
+ else:
+ assert x_mask is None and source_mask is None
+
+ if PFLASH_AVAILABLE: # N H L D
+ query = rearrange(query, 'n (nhead d) h w -> n nhead (h w) d', nhead=self.nhead, d=self.dim)
+ key = rearrange(key, 'n (nhead d) h w -> n nhead (h w) d', nhead=self.nhead, d=self.dim)
+ value = rearrange(value, 'n (nhead d) h w -> n nhead (h w) d', nhead=self.nhead, d=self.dim)
+
+ else: # N L H D
+ query = rearrange(query, 'n (nhead d) h w -> n (h w) nhead d', nhead=self.nhead, d=self.dim) # [N, L, H, D]
+ key = rearrange(key, 'n (nhead d) h w -> n (h w) nhead d', nhead=self.nhead, d=self.dim) # [N, S, H, D]
+ value = rearrange(value, 'n (nhead d) h w -> n (h w) nhead d', nhead=self.nhead, d=self.dim) # [N, S, H, D]
+
+ message = self.attention(query, key, value, q_mask=None, kv_mask=None) # [N, L, H, D] or [N, H, L, D]
+
+ if PFLASH_AVAILABLE: # N H L D
+ message = rearrange(message, 'n nhead L d -> n L nhead d', nhead=self.nhead, d=self.dim)
+
+ if use_mask: # padding zero
+ message = message.view(bs, mask_h0, mask_w0, self.nhead, self.dim) # [N L H D]
+ if mask_h0 != x_mask.size(-2):
+ message = torch.cat([message, torch.zeros(message.size(0), x_mask.size(-2)-mask_h0, x_mask.size(-1), self.nhead, self.dim, device=message.device, dtype=message.dtype)], dim=1)
+ elif mask_w0 != x_mask.size(-1):
+ message = torch.cat([message, torch.zeros(message.size(0), x_mask.size(-2), x_mask.size(-1)-mask_w0, self.nhead, self.dim, device=message.device, dtype=message.dtype)], dim=2)
+ else:
+ assert x_mask is None and source_mask is None
+
+ message = self.merge(message.view(bs, -1, self.nhead*self.dim)) # [N, L, C]
+ message = rearrange(message, 'b (h w) c -> b c h w', h=H1//self.pool_size, w=W1//self.pool_size) # [N, C, H, W]
+
+ if self.scatter:
+ message = torch.repeat_interleave(message, self.pool_size, dim=-2)
+ message = torch.repeat_interleave(message, self.pool_size, dim=-1)
+ message = message * self.aggregate.weight.data.reshape(1, C, self.pool_size, self.pool_size).repeat(1,1,message.shape[-2]//self.pool_size,message.shape[-1]//self.pool_size)
+ else:
+ message = torch.nn.functional.interpolate(message, scale_factor=self.pool_size, mode='bilinear', align_corners=False) # [N, C, H1, W1]
+
+ if not self.norm_before:
+ message = self.norm1(message.permute(0,2,3,1)).permute(0,3,1,2) # [N, C, H, W]
+
+ # feed-forward network
+ message = self.mlp(torch.cat([x, message], dim=1).permute(0, 2, 3, 1)) # [N, H1, W1, C]
+ message = self.norm2(message).permute(0, 3, 1, 2) # [N, C, H1, W1]
+
+ return x + message
+ else: # mask with bs > 1
+ x_mask = self.max_pool(x_mask.float()).bool()
+ source_mask = self.max_pool(source_mask.float()).bool()
+ m_list = []
+ for i in range(bs):
+ mask_h0, mask_w0 = x_mask[i].sum(-2)[0], x_mask[i].sum(-1)[0]
+ mask_h1, mask_w1 = source_mask[i].sum(-2)[0], source_mask[i].sum(-1)[0]
+
+ q = query[i:i+1, :, :mask_h0, :mask_w0]
+ k = key[i:i+1, :, :mask_h1, :mask_w1]
+ v = value[i:i+1, :, :mask_h1, :mask_w1]
+
+ if PFLASH_AVAILABLE: # N H L D
+ q = rearrange(q, 'n (nhead d) h w -> n nhead (h w) d', nhead=self.nhead, d=self.dim)
+ k = rearrange(k, 'n (nhead d) h w -> n nhead (h w) d', nhead=self.nhead, d=self.dim)
+ v = rearrange(v, 'n (nhead d) h w -> n nhead (h w) d', nhead=self.nhead, d=self.dim)
+
+ else: # N L H D
+ q = rearrange(q, 'n (nhead d) h w -> n (h w) nhead d', nhead=self.nhead, d=self.dim) # [N, L, H, D]
+ k = rearrange(k, 'n (nhead d) h w -> n (h w) nhead d', nhead=self.nhead, d=self.dim) # [N, S, H, D]
+ v = rearrange(v, 'n (nhead d) h w -> n (h w) nhead d', nhead=self.nhead, d=self.dim) # [N, S, H, D]
+
+ m = self.attention(q, k, v, q_mask=None, kv_mask=None) # [N, L, H, D]
+
+ if PFLASH_AVAILABLE: # N H L D
+ m = rearrange(m, 'n nhead L d -> n L nhead d', nhead=self.nhead, d=self.dim)
+
+ m = m.view(1, mask_h0, mask_w0, self.nhead, self.dim)
+ if mask_h0 != x_mask.size(-2):
+ m = torch.cat([m, torch.zeros(1, x_mask.size(-2)-mask_h0, x_mask.size(-1), self.nhead, self.dim, device=m.device, dtype=m.dtype)], dim=1)
+ elif mask_w0 != x_mask.size(-1):
+ m = torch.cat([m, torch.zeros(1, x_mask.size(-2), x_mask.size(-1)-mask_w0, self.nhead, self.dim, device=m.device, dtype=m.dtype)], dim=2)
+ m_list.append(m)
+ m = torch.cat(m_list, dim=0)
+
+ m = self.merge(m.view(bs, -1, self.nhead*self.dim)) # [N, L, C]
+
+ # m = m.reshape(bs, C, H1//self.pool_size, W1//self.pool_size) # wrong layout: m is [N, L, C]; the rearrange below is the correct op
+ m = rearrange(m, 'b (h w) c -> b c h w', h=H1//self.pool_size, w=W1//self.pool_size) # [N, C, H, W]
+
+ if self.scatter:
+ m = torch.repeat_interleave(m, self.pool_size, dim=-2)
+ m = torch.repeat_interleave(m, self.pool_size, dim=-1)
+ m = m * self.aggregate.weight.data.reshape(1, C, self.pool_size, self.pool_size).repeat(1,1,m.shape[-2]//self.pool_size,m.shape[-1]//self.pool_size)
+ else:
+ m = torch.nn.functional.interpolate(m, scale_factor=self.pool_size, mode='bilinear', align_corners=False) # [N, C, H1, W1]
+
+ if not self.norm_before:
+ m = self.norm1(m.permute(0,2,3,1)).permute(0,3,1,2) # [N, C, H, W]
+
+ # feed-forward network
+ m = self.mlp(torch.cat([x, m], dim=1).permute(0, 2, 3, 1)) # [N, H1, W1, C]
+ m = self.norm2(m).permute(0, 3, 1, 2) # [N, C, H1, W1]
+
+ return x + m
+
+ def _build_projection(self,
+ dim_in,
+ dim_out,
+ kernel_size=3,
+ padding=1,
+ stride=1,
+ method='dw_bn',
+ ):
+ if method == 'dw_bn':
+ proj = nn.Sequential(OrderedDict([
+ ('conv', nn.Conv2d(
+ dim_in,
+ dim_in,
+ kernel_size=kernel_size,
+ padding=padding,
+ stride=stride,
+ bias=False,
+ groups=dim_in
+ )),
+ ('bn', nn.BatchNorm2d(dim_in)),
+ # ('rearrage', Rearrange('b c h w -> b (h w) c')),
+ ]))
+ elif method == 'avg':
+ proj = nn.Sequential(OrderedDict([
+ ('avg', nn.AvgPool2d(
+ kernel_size=kernel_size,
+ padding=padding,
+ stride=stride,
+ ceil_mode=True
+ )),
+ # ('rearrage', Rearrange('b c h w -> b (h w) c')),
+ ]))
+ elif method == 'linear':
+ proj = None
+ elif method == 'dw':
+ proj = nn.Sequential(OrderedDict([
+ ('conv', nn.Conv2d(
+ dim_in,
+ dim_in,
+ kernel_size=kernel_size,
+ padding=padding,
+ stride=stride,
+ bias=False,
+ groups=dim_in
+ )),
+ # ('rearrage', Rearrange('b c h w -> b (h w) c')),
+ ]))
+ else:
+ raise ValueError('Unknown method ({})'.format(method))
+
+ return proj
+
+
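+# Usage sketch for AG_Conv_EncoderLayer (illustrative values; with the defaults the input is
+# max-pooled before attention, projected with depthwise conv + BN, and bilinearly upsampled back):
+#   layer = AG_Conv_EncoderLayer(d_model=256, nhead=8, attention='linear', pool_size=4)
+#   x, source = torch.rand(1, 256, 64, 64), torch.rand(1, 256, 64, 64)
+#   out = layer(x, source)   # [N, C, H1, W1]
+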
+class RoPELoFTREncoderLayer(nn.Module):
+ def __init__(self,
+ d_model,
+ nhead,
+ attention='linear',
+ rope=False,
+ token_mixer=None,
+ ):
+ super(RoPELoFTREncoderLayer, self).__init__()
+
+ self.dim = d_model // nhead
+ self.nhead = nhead
+
+ # multi-head attention
+ if token_mixer is None:
+ self.q_proj = nn.Linear(d_model, d_model, bias=False)
+ self.k_proj = nn.Linear(d_model, d_model, bias=False)
+ self.v_proj = nn.Linear(d_model, d_model, bias=False)
+
+ self.rope = rope
+ self.token_mixer = None
+ if token_mixer is not None:
+ self.token_mixer = token_mixer
+ if token_mixer == 'dwcn':
+ self.attention = nn.Sequential(OrderedDict([
+ ('conv', nn.Conv2d(
+ d_model,
+ d_model,
+ kernel_size=3,
+ padding=1,
+ stride=1,
+ bias=False,
+ groups=d_model
+ )),
+ ]))
+ elif self.rope:
+ assert attention == 'linear'
+ self.attention = RoPELinearAttention()
+
+ if token_mixer is None:
+ self.merge = nn.Linear(d_model, d_model, bias=False)
+
+ # feed-forward network
+ if token_mixer is None:
+ self.mlp = nn.Sequential(
+ nn.Linear(d_model*2, d_model*2, bias=False),
+ nn.ReLU(True),
+ nn.Linear(d_model*2, d_model, bias=False),
+ )
+ else:
+ self.mlp = nn.Sequential(
+ nn.Linear(d_model, d_model, bias=False),
+ nn.ReLU(True),
+ nn.Linear(d_model, d_model, bias=False),
+ )
+ # norm and dropout
+ self.norm1 = nn.LayerNorm(d_model)
+ self.norm2 = nn.LayerNorm(d_model)
+
+ def forward(self, x, source, x_mask=None, source_mask=None, H=None, W=None):
+ """
+ Args:
+ x (torch.Tensor): [N, L, C]
+ source (torch.Tensor): [N, L, C]
+ x_mask (torch.Tensor): [N, L] (optional)
+ source_mask (torch.Tensor): [N, S] (optional)
+ """
+ bs = x.size(0)
+ assert H*W == x.size(-2)
+
+ # x = rearrange(x, 'n c h w -> n (h w) c')
+ # source = rearrange(source, 'n c h w -> n (h w) c')
+ query, key, value = x, source, source
+
+ if self.token_mixer is not None:
+ # multi-head attention
+ m = self.norm1(x)
+ m = rearrange(m, 'n (h w) c -> n c h w', h=H, w=W)
+ m = self.attention(m)
+ m = rearrange(m, 'n c h w -> n (h w) c')
+
+ x = x + m
+ x = x + self.mlp(self.norm2(x))
+ return x
+ else:
+ # multi-head attention
+ query = self.q_proj(query).view(bs, -1, self.nhead, self.dim) # [N, L, (H, D)]
+ key = self.k_proj(key).view(bs, -1, self.nhead, self.dim) # [N, S, (H, D)]
+ value = self.v_proj(value).view(bs, -1, self.nhead, self.dim)
+ message = self.attention(query, key, value, q_mask=x_mask, kv_mask=source_mask, H=H, W=W) # [N, L, (H, D)]
+ message = self.merge(message.view(bs, -1, self.nhead*self.dim)) # [N, L, C]
+ message = self.norm1(message)
+
+ # feed-forward network
+ message = self.mlp(torch.cat([x, message], dim=2))
+ message = self.norm2(message)
+
+ return x + message
+
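+# Usage sketch for RoPELoFTREncoderLayer (illustrative; note that with the defaults
+# rope=False and token_mixer=None no attention module is created, so pass rope=True
+# together with attention='linear', or token_mixer='dwcn'):
+#   layer = RoPELoFTREncoderLayer(d_model=256, nhead=8, attention='linear', rope=True)
+#   x = torch.rand(1, 60 * 80, 256)      # [N, L, C] with L = H * W
+#   out = layer(x, x, H=60, W=80)        # self-attention; returns [N, L, C]
+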
+class LoFTREncoderLayer(nn.Module):
+ def __init__(self,
+ d_model,
+ nhead,
+ attention='linear',
+ xformer=False,
+ ):
+ super(LoFTREncoderLayer, self).__init__()
+
+ self.dim = d_model // nhead
+ self.nhead = nhead
+
+ # multi-head attention
+ self.q_proj = nn.Linear(d_model, d_model, bias=False)
+ self.k_proj = nn.Linear(d_model, d_model, bias=False)
+ self.v_proj = nn.Linear(d_model, d_model, bias=False)
+
+ if xformer:
+ self.attention = XAttention()
+ else:
+ self.attention = LinearAttention() if attention == 'linear' else FullAttention()
+ self.merge = nn.Linear(d_model, d_model, bias=False)
+
+ # feed-forward network
+ self.mlp = nn.Sequential(
+ nn.Linear(d_model*2, d_model*2, bias=False),
+ nn.ReLU(True),
+ nn.Linear(d_model*2, d_model, bias=False),
+ )
+
+ # norm and dropout
+ self.norm1 = nn.LayerNorm(d_model)
+ self.norm2 = nn.LayerNorm(d_model)
+
+ def forward(self, x, source, x_mask=None, source_mask=None):
+ """
+ Args:
+ x (torch.Tensor): [N, L, C]
+ source (torch.Tensor): [N, S, C]
+ x_mask (torch.Tensor): [N, L] (optional)
+ source_mask (torch.Tensor): [N, S] (optional)
+ """
+ bs = x.size(0)
+ query, key, value = x, source, source
+
+ # multi-head attention
+ query = self.q_proj(query).view(bs, -1, self.nhead, self.dim) # [N, L, (H, D)]
+ key = self.k_proj(key).view(bs, -1, self.nhead, self.dim) # [N, S, (H, D)]
+ value = self.v_proj(value).view(bs, -1, self.nhead, self.dim)
+ message = self.attention(query, key, value, q_mask=x_mask, kv_mask=source_mask) # [N, L, (H, D)]
+ message = self.merge(message.view(bs, -1, self.nhead*self.dim)) # [N, L, C]
+ message = self.norm1(message)
+
+ # feed-forward network
+ message = self.mlp(torch.cat([x, message], dim=2))
+ message = self.norm2(message)
+
+ return x + message
+
+ def pro(self, x, source, x_mask=None, source_mask=None, profiler=None):
+ """
+ Args:
+ x (torch.Tensor): [N, L, C]
+ source (torch.Tensor): [N, S, C]
+ x_mask (torch.Tensor): [N, L] (optional)
+ source_mask (torch.Tensor): [N, S] (optional)
+ """
+ bs = x.size(0)
+ query, key, value = x, source, source
+
+ # multi-head attention
+ with profiler.profile("proj*3"):
+ query = self.q_proj(query).view(bs, -1, self.nhead, self.dim) # [N, L, (H, D)]
+ key = self.k_proj(key).view(bs, -1, self.nhead, self.dim) # [N, S, (H, D)]
+ value = self.v_proj(value).view(bs, -1, self.nhead, self.dim)
+ with profiler.profile("attention"):
+ message = self.attention(query, key, value, q_mask=x_mask, kv_mask=source_mask) # [N, L, (H, D)]
+ with profiler.profile("merge"):
+ message = self.merge(message.view(bs, -1, self.nhead*self.dim)) # [N, L, C]
+ with profiler.profile("norm1"):
+ message = self.norm1(message)
+
+ # feed-forward network
+ with profiler.profile("mlp"):
+ message = self.mlp(torch.cat([x, message], dim=2))
+ with profiler.profile("norm2"):
+ message = self.norm2(message)
+
+ return x + message
+
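+# Usage sketch for LoFTREncoderLayer (sequence layout, illustrative values):
+#   layer = LoFTREncoderLayer(d_model=256, nhead=8, attention='linear')
+#   feat0 = torch.rand(1, 4800, 256)     # [N, L, C]
+#   feat1 = torch.rand(1, 4800, 256)     # [N, S, C]
+#   feat0 = layer(feat0, feat1)          # feat0 attends to feat1 (cross); pass feat0 twice for self-attention
+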
+class PANEncoderLayer_cross(nn.Module):
+ def __init__(self,
+ d_model,
+ nhead,
+ attention='linear',
+ pool_size=4,
+ bn=True,
+ ):
+ super(PANEncoderLayer_cross, self).__init__()
+
+ self.pool_size = pool_size
+
+ self.dim = d_model // nhead
+ self.nhead = nhead
+
+ self.max_pool = torch.nn.MaxPool2d(kernel_size=self.pool_size, stride=self.pool_size)
+ # multi-head attention
+ if bn:
+ method = 'dw_bn'
+ else:
+ method = 'dw'
+ self.qk_proj_conv = self._build_projection(d_model, d_model, method=method)
+ self.v_proj_conv = self._build_projection(d_model, d_model, method=method)
+
+ # self.q_proj = nn.Linear(d_model, d_model, bias=False)
+ # self.k_proj = nn.Linear(d_model, d_model, bias=False)
+ # self.v_proj = nn.Linear(d_model, d_model, bias=False)
+ self.attention = FullAttention()
+ self.merge = nn.Linear(d_model, d_model, bias=False)
+
+ # feed-forward network
+ self.mlp = nn.Sequential(
+ nn.Linear(d_model*2, d_model*2, bias=False),
+ nn.ReLU(True),
+ nn.Linear(d_model*2, d_model, bias=False),
+ )
+
+ # norm and dropout
+ self.norm1 = nn.LayerNorm(d_model)
+ self.norm2 = nn.LayerNorm(d_model)
+
+ # self.norm1 = nn.BatchNorm2d(d_model)
+
+ def forward(self, x1, x2, x1_mask=None, x2_mask=None):
+ """
+ Args:
+ x1 (torch.Tensor): [N, C, H1, W1]
+ x2 (torch.Tensor): [N, C, H2, W2]
+ x1_mask (torch.Tensor): [N, H1, W1] (optional) (L = H1*W1)
+ x2_mask (torch.Tensor): [N, H2, W2] (optional) (S = H2*W2)
+ """
+ bs = x1.size(0)
+ H1, W1 = x1.size(-2) // self.pool_size, x1.size(-1) // self.pool_size
+ H2, W2 = x2.size(-2) // self.pool_size, x2.size(-1) // self.pool_size
+
+ query = self.norm1(self.max_pool(x1).permute(0,2,3,1)).permute(0,3,1,2)
+ key = self.norm1(self.max_pool(x2).permute(0,2,3,1)).permute(0,3,1,2)
+ v2 = self.norm1(self.max_pool(x2).permute(0,2,3,1)).permute(0,3,1,2)
+ v1 = self.norm1(self.max_pool(x1).permute(0,2,3,1)).permute(0,3,1,2)
+
+ # multi-head attention
+ query = self.qk_proj_conv(query) # [N, C, H1//pool, W1//pool]
+ key = self.qk_proj_conv(key)
+ v2 = self.v_proj_conv(v2)
+ v1 = self.v_proj_conv(v1)
+
+ C = query.shape[-3]
+ if x1_mask is not None and x2_mask is not None:
+ x1_mask = self.max_pool(x1_mask.float()).bool() # [N, H1//pool, W1//pool]
+ x2_mask = self.max_pool(x2_mask.float()).bool()
+
+ mask_h1, mask_w1 = x1_mask[0].sum(-2)[0], x1_mask[0].sum(-1)[0]
+ mask_h2, mask_w2 = x2_mask[0].sum(-2)[0], x2_mask[0].sum(-1)[0]
+
+ query = query[:, :, :mask_h1, :mask_w1]
+ key = key[:, :, :mask_h2, :mask_w2]
+ v1 = v1[:, :, :mask_h1, :mask_w1]
+ v2 = v2[:, :, :mask_h2, :mask_w2]
+ x1_mask = x1_mask[:, :mask_h1, :mask_w1]
+ x2_mask = x2_mask[:, :mask_h2, :mask_w2]
+
+ else:
+ assert x1_mask is None and x2_mask is None
+
+ query = rearrange(query, 'n (nhead d) h w -> n (h w) nhead d', nhead=self.nhead, d=self.dim) # [N, L, H, D]
+ key = rearrange(key, 'n (nhead d) h w -> n (h w) nhead d', nhead=self.nhead, d=self.dim) # [N, S, H, D]
+ v2 = rearrange(v2, 'n (nhead d) h w -> n (h w) nhead d', nhead=self.nhead, d=self.dim) # [N, S, H, D]
+ v1 = rearrange(v1, 'n (nhead d) h w -> n (h w) nhead d', nhead=self.nhead, d=self.dim) # [N, S, H, D]
+ if x2_mask is not None or x1_mask is not None:
+ x1_mask = x1_mask.flatten(-2)
+ x2_mask = x2_mask.flatten(-2)
+
+
+ QK = torch.einsum("nlhd,nshd->nlsh", query, key)
+ with torch.autocast(enabled=False, device_type='cuda'):
+ if x2_mask is not None or x1_mask is not None:
+ # S1 = S2.transpose(-2,-3).masked_fill(~(x_mask[:, None, :, None] * source_mask[:, :, None, None]), -1e9) # float('-inf')
+ QK = QK.float().masked_fill_(~(x1_mask[:, :, None, None] * x2_mask[:, None, :, None]), -1e9) # float('-inf')
+
+
+ # Compute the attention and the weighted average
+ softmax_temp = 1. / query.size(3)**.5 # 1 / sqrt(D)
+ S1 = torch.softmax(softmax_temp * QK, dim=2) # normalize over S: weights for messages to x1
+ S2 = torch.softmax(softmax_temp * QK, dim=1) # normalize over L: weights for messages to x2 (QK is [N, L, S, H])
+
+ m1 = torch.einsum("nlsh,nshd->nlhd", S1, v2)
+ m2 = torch.einsum("nlsh,nlhd->nshd", S2, v1)
+
+ if x1_mask is not None and x2_mask is not None:
+ m1 = m1.view(bs, mask_h1, mask_w1, self.nhead, self.dim)
+ if mask_h1 != H1:
+ m1 = torch.cat([m1, torch.zeros(m1.size(0), H1-mask_h1, W1, self.nhead, self.dim, device=m1.device, dtype=m1.dtype)], dim=1)
+ elif mask_w1 != W1:
+ m1 = torch.cat([m1, torch.zeros(m1.size(0), H1, W1-mask_w1, self.nhead, self.dim, device=m1.device, dtype=m1.dtype)], dim=2)
+ else:
+ assert x1_mask is None and x2_mask is None
+
+ m1 = self.merge(m1.reshape(bs, -1, self.nhead*self.dim)) # [N, L, C]
+ m1 = rearrange(m1, 'b (h w) c -> b c h w', h=H1, w=W1) # [N, C, H, W]
+ m1 = torch.nn.functional.interpolate(m1, scale_factor=self.pool_size, mode='bilinear', align_corners=False) # [N, C, H1, W1]
+ # feed-forward network
+ m1 = self.mlp(torch.cat([x1, m1], dim=1).permute(0, 2, 3, 1)) # [N, H1, W1, C]
+ m1 = self.norm2(m1).permute(0, 3, 1, 2) # [N, C, H1, W1]
+
+ if x1_mask is not None and x2_mask is not None:
+ m2 = m2.view(bs, mask_h2, mask_w2, self.nhead, self.dim)
+ if mask_h2 != H2:
+ m2 = torch.cat([m2, torch.zeros(m2.size(0), H2-mask_h2, W2, self.nhead, self.dim, device=m2.device, dtype=m2.dtype)], dim=1)
+ elif mask_w2 != W2:
+ m2 = torch.cat([m2, torch.zeros(m2.size(0), H2, W2-mask_w2, self.nhead, self.dim, device=m2.device, dtype=m2.dtype)], dim=2)
+ else:
+ assert x1_mask is None and x2_mask is None
+
+ m2 = self.merge(m2.reshape(bs, -1, self.nhead*self.dim)) # [N, L, C]
+ m2 = rearrange(m2, 'b (h w) c -> b c h w', h=H2, w=W2) # [N, C, H, W]
+ m2 = torch.nn.functional.interpolate(m2, scale_factor=self.pool_size, mode='bilinear', align_corners=False) # upsample back to x2's resolution
+ # feed-forward network
+ m2 = self.mlp(torch.cat([x2, m2], dim=1).permute(0, 2, 3, 1)) # [N, H2*pool, W2*pool, C]
+ m2 = self.norm2(m2).permute(0, 3, 1, 2) # [N, C, H2*pool, W2*pool]
+
+ return x1 + m1, x2 + m2
+
+ def _build_projection(self,
+ dim_in,
+ dim_out,
+ kernel_size=3,
+ padding=1,
+ stride=1,
+ method='dw_bn',
+ ):
+ if method == 'dw_bn':
+ proj = nn.Sequential(OrderedDict([
+ ('conv', nn.Conv2d(
+ dim_in,
+ dim_in,
+ kernel_size=kernel_size,
+ padding=padding,
+ stride=stride,
+ bias=False,
+ groups=dim_in
+ )),
+ ('bn', nn.BatchNorm2d(dim_in)),
+ # ('rearrage', Rearrange('b c h w -> b (h w) c')),
+ ]))
+ elif method == 'avg':
+ proj = nn.Sequential(OrderedDict([
+ ('avg', nn.AvgPool2d(
+ kernel_size=kernel_size,
+ padding=padding,
+ stride=stride,
+ ceil_mode=True
+ )),
+ # ('rearrage', Rearrange('b c h w -> b (h w) c')),
+ ]))
+ elif method == 'linear':
+ proj = None
+ elif method == 'dw':
+ proj = nn.Sequential(OrderedDict([
+ ('conv', nn.Conv2d(
+ dim_in,
+ dim_in,
+ kernel_size=kernel_size,
+ padding=padding,
+ stride=stride,
+ bias=False,
+ groups=dim_in
+ )),
+ # ('rearrage', Rearrange('b c h w -> b (h w) c')),
+ ]))
+ else:
+ raise ValueError('Unknown method ({})'.format(method))
+
+ return proj
+
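+# Note on PANEncoderLayer_cross: a single similarity tensor QK of shape [N, L, S, H] is
+# normalized twice, over the source dim (S) for the messages sent to x1 and over the query
+# dim (L) for the messages sent to x2, so one layer updates both feature maps. Usage sketch
+# (illustrative values; the `attention` argument is accepted but unused, the scores are
+# computed inline):
+#   layer = PANEncoderLayer_cross(d_model=256, nhead=8, pool_size=4)
+#   x1, x2 = torch.rand(1, 256, 64, 64), torch.rand(1, 256, 64, 64)
+#   x1, x2 = layer(x1, x2)   # both stay [N, C, H, W]
+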
+class LocalFeatureTransformer(nn.Module):
+ """A Local Feature Transformer (LoFTR) module."""
+
+ def __init__(self, config):
+ super(LocalFeatureTransformer, self).__init__()
+
+ self.full_config = config
+ self.fine = False
+ if 'coarse' not in config:
+ self.fine = True # fine attention
+ else:
+ config = config['coarse']
+ self.d_model = config['d_model']
+ self.nhead = config['nhead']
+ self.layer_names = config['layer_names']
+ self.pan = config['pan']
+ self.bidirect = config['bidirection']
+ # prune
+ self.pool_size = config['pool_size']
+ self.matchability = False
+ self.depth_confidence = -1.0
+ self.width_confidence = -1.0
+ # self.depth_confidence = config['depth_confidence']
+ # self.width_confidence = config['width_confidence']
+ # self.matchability = self.depth_confidence > 0 or self.width_confidence > 0
+ # self.thr = self.full_config['match_coarse']['thr']
+ if not self.fine:
+ # asymmetric attention options
+ self.asymmetric = config['asymmetric']
+ self.asymmetric_self = config['asymmetric_self']
+ # aggregate
+ self.aggregate = config['dwconv']
+ # RoPE
+ self.rope = config['rope']
+ # absPE
+ self.abspe = config['abspe']
+
+ else:
+ self.rope, self.asymmetric, self.asymmetric_self, self.aggregate = False, False, False, False
+ if self.matchability:
+ self.n_layers = len(self.layer_names) // 2
+ assert self.n_layers == 4
+ self.log_assignment = nn.ModuleList(
+ [MatchAssignment(self.d_model) for _ in range(self.n_layers)])
+ self.token_confidence = nn.ModuleList([
+ TokenConfidence(self.d_model) for _ in range(self.n_layers-1)])
+
+ self.CoarseMatching = CoarseMatching(self.full_config['match_coarse'])
+
+ # self only
+ # if self.rope:
+ # self_layer = RoPELoFTREncoderLayer(config['d_model'], config['nhead'], config['attention'], config['rope'], config['token_mixer'])
+ # self.layers = nn.ModuleList([copy.deepcopy(self_layer) for _ in range(len(self.layer_names))])
+
+ if self.bidirect:
+ assert config['xformer'] is False and config['pan'] is True
+ self_layer = PANEncoderLayer(config['d_model'], config['nhead'], config['attention'], config['pool_size'], config['bn'], config['xformer'])
+ cross_layer = PANEncoderLayer_cross(config['d_model'], config['nhead'], config['attention'], config['pool_size'], config['bn'])
+ self.layers = nn.ModuleList([copy.deepcopy(self_layer) if _ == 'self' else copy.deepcopy(cross_layer) for _ in self.layer_names])
+ else:
+ if self.aggregate:
+ if self.rope:
+ # assert config['npe'][0] == 832 and config['npe'][1] == 832 and config['npe'][2] == 832 and config['npe'][3] == 832
+ logger.info(f'npe trainH,trainW,testH,testW: {config["npe"][0]}, {config["npe"][1]}, {config["npe"][2]}, {config["npe"][3]}')
+ self_layer = AG_RoPE_EncoderLayer(config['d_model'], config['nhead'], config['attention'], config['pool_size'], config['pool_size2'],
+ config['xformer'], config['leaky'], config['dwconv'], config['dwconv2'], config['scatter'],
+ config['norm_before'], config['rope'], config['npe'], config['vit_norm'], config['rope_dwproj'])
+ cross_layer = AG_RoPE_EncoderLayer(config['d_model'], config['nhead'], config['attention'], config['pool_size'], config['pool_size2'],
+ config['xformer'], config['leaky'], config['dwconv'], config['dwconv2'], config['scatter'],
+ config['norm_before'], False, config['npe'], config['vit_norm'], config['rope_dwproj'])
+ self.layers = nn.ModuleList([copy.deepcopy(self_layer) if _ == 'self' else copy.deepcopy(cross_layer) for _ in self.layer_names])
+ elif self.abspe:
+ logger.info(f'npe trainH,trainW,testH,testW: {config["npe"][0]}, {config["npe"][1]}, {config["npe"][2]}, {config["npe"][3]}')
+ self_layer = AG_RoPE_EncoderLayer(config['d_model'], config['nhead'], config['attention'], config['pool_size'], config['pool_size2'],
+ config['xformer'], config['leaky'], config['dwconv'], config['dwconv2'], config['scatter'],
+ config['norm_before'], False, config['npe'], config['vit_norm'], config['rope_dwproj'])
+ cross_layer = AG_RoPE_EncoderLayer(config['d_model'], config['nhead'], config['attention'], config['pool_size'], config['pool_size2'],
+ config['xformer'], config['leaky'], config['dwconv'], config['dwconv2'], config['scatter'],
+ config['norm_before'], False, config['npe'], config['vit_norm'], config['rope_dwproj'])
+ self.layers = nn.ModuleList([copy.deepcopy(self_layer) if _ == 'self' else copy.deepcopy(cross_layer) for _ in self.layer_names])
+
+ else:
+ encoder_layer = AG_Conv_EncoderLayer(config['d_model'], config['nhead'], config['attention'], config['pool_size'], config['bn'],
+ config['xformer'], config['leaky'], config['dwconv'], config['scatter'],
+ config['norm_before'])
+ self.layers = nn.ModuleList([copy.deepcopy(encoder_layer) for _ in range(len(self.layer_names))])
+ else:
+ encoder_layer = PANEncoderLayer(config['d_model'], config['nhead'], config['attention'], config['pool_size'],
+ config['bn'], config['xformer'], config['leaky'], config['dwconv'], config['scatter']) \
+ if config['pan'] else LoFTREncoderLayer(config['d_model'], config['nhead'],
+ config['attention'], config['xformer'])
+ self.layers = nn.ModuleList([copy.deepcopy(encoder_layer) for _ in range(len(self.layer_names))])
+ self._reset_parameters()
+
+ def _reset_parameters(self):
+ for p in self.parameters():
+ if p.dim() > 1:
+ nn.init.xavier_uniform_(p)
+
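+    # Minimal 'coarse' config sketch (keys read by __init__ above; the values are placeholders
+    # and assumptions, not shipped defaults; 'match_coarse' is consumed by CoarseMatching):
+    #   config = {'coarse': dict(d_model=256, nhead=8, layer_names=['self', 'cross'] * 4,
+    #                            attention='linear', pan=True, bidirection=False, pool_size=4,
+    #                            pool_size2=4, bn=True, xformer=False, leaky=-1.0, dwconv=True,
+    #                            dwconv2=False, scatter=False, norm_before=True, rope=True,
+    #                            abspe=False, npe=[832, 832, 832, 832], vit_norm=False,
+    #                            rope_dwproj=False, asymmetric=False, asymmetric_self=False),
+    #             'match_coarse': {...}}
+    #   matcher = LocalFeatureTransformer(config)
+    #   feat0, feat1 = matcher(feat0, feat1)   # NCHW coarse features, H/W divisible by pool_size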
+ def forward(self, feat0, feat1, mask0=None, mask1=None, data=None):
+ """
+ Args:
+ feat0 (torch.Tensor): [N, C, H, W]
+ feat1 (torch.Tensor): [N, C, H, W]
+ mask0 (torch.Tensor): [N, L] (optional)
+ mask1 (torch.Tensor): [N, S] (optional)
+ """
+ # nchw for pan and n(hw)c for loftr
+ assert self.d_model == feat0.size(1) or self.d_model == feat0.size(-1), "the feature dimension of the inputs must match the transformer d_model"
+ H0, W0, H1, W1 = feat0.size(-2), feat0.size(-1), feat1.size(-2), feat1.size(-1)
+ bs = feat0.shape[0]
+ padding = False
+ if bs == 1 and mask0 is not None and mask1 is not None and self.pan: # NCHW for pan
+ mask_H0, mask_W0 = mask0.size(-2), mask0.size(-1)
+ mask_H1, mask_W1 = mask1.size(-2), mask1.size(-1)
+ mask_h0, mask_w0 = mask0[0].sum(-2)[0], mask0[0].sum(-1)[0]
+ mask_h1, mask_w1 = mask1[0].sum(-2)[0], mask1[0].sum(-1)[0]
+
+ # round the valid mask sizes down to a multiple of self.pool_size
+ if self.pan:
+ mask_h0, mask_w0, mask_h1, mask_w1 = mask_h0//self.pool_size*self.pool_size, mask_w0//self.pool_size*self.pool_size, mask_h1//self.pool_size*self.pool_size, mask_w1//self.pool_size*self.pool_size
+
+ feat0 = feat0[:, :, :mask_h0, :mask_w0]
+ feat1 = feat1[:, :, :mask_h1, :mask_w1]
+
+ padding = True
+
+ # rope self only
+ # if self.rope:
+ # feat0, feat1 = rearrange(feat0, 'b c h w -> b (h w) c'), rearrange(feat1, 'b c h w -> b (h w) c')
+ # prune
+ if padding:
+ l0, l1 = mask_h0 * mask_w0, mask_h1 * mask_w1
+ else:
+ l0, l1 = H0 * W0, H1 * W1
+ do_early_stop = self.depth_confidence > 0
+ do_point_pruning = self.width_confidence > 0
+ if do_point_pruning:
+ ind0 = torch.arange(0, l0, device=feat0.device)[None]
+ ind1 = torch.arange(0, l1, device=feat0.device)[None]
+ # We store the index of the layer at which pruning is detected.
+ prune0 = torch.ones_like(ind0)
+ prune1 = torch.ones_like(ind1)
+ if do_early_stop:
+ token0, token1 = None, None
+
+ for i, (layer, name) in enumerate(zip(self.layers, self.layer_names)):
+ if padding:
+ mask0, mask1 = None, None
+ if name == 'self':
+ # if self.rope:
+ # feat0 = layer(feat0, feat0, mask0, mask1, H0, W0)
+ # feat1 = layer(feat1, feat1, mask0, mask1, H1, W1)
+ if self.asymmetric:
+ assert False, 'asymmetric self-attention is not supported (did not work)'
+ # feat0 = layer(feat0, feat0, mask0, mask1)
+ feat1 = layer(feat1, feat1, mask1, mask1)
+ else:
+ feat0 = layer(feat0, feat0, mask0, mask0)
+ feat1 = layer(feat1, feat1, mask1, mask1)
+ elif name == 'cross':
+ if self.bidirect:
+ feat0, feat1 = layer(feat0, feat1, mask0, mask1)
+ else:
+ if self.asymmetric or self.asymmetric_self:
+ assert False, 'asymmetric cross-attention is not supported (did not work)'
+ feat0 = layer(feat0, feat1, mask0, mask1)
+ else:
+ feat0 = layer(feat0, feat1, mask0, mask1)
+ feat1 = layer(feat1, feat0, mask1, mask0)
+
+ if i == len(self.layer_names) - 1 and not self.training:
+ continue
+ if self.matchability:
+ desc0, desc1 = rearrange(feat0, 'b c h w -> b (h w) c'), rearrange(feat1, 'b c h w -> b (h w) c')
+ if do_early_stop:
+ token0, token1 = self.token_confidence[i//2](desc0, desc1)
+ if self.check_if_stop(token0, token1, i, l0+l1) and not self.training:
+ break
+ if do_point_pruning:
+ scores0, scores1 = self.log_assignment[i//2].scores(desc0, desc1)
+ mask0 = self.get_pruning_mask(token0, scores0, i)
+ mask1 = self.get_pruning_mask(token1, scores1, i)
+ ind0, ind1 = ind0[mask0][None], ind1[mask1][None]
+ feat0, feat1 = desc0[mask0][None], desc1[mask1][None]
+ if feat0.shape[-2] == 0 or feat1.shape[-2] == 0:  # stop if either pruned feature set is empty
+ break
+ prune0[:, ind0] += 1
+ prune1[:, ind1] += 1
+ if self.training and self.matchability:
+ scores, _, matchability0, matchability1 = self.log_assignment[i//2](desc0, desc1)
+ m0_full = torch.zeros((bs, mask_h0 * mask_w0), device=matchability0.device, dtype=matchability0.dtype)
+ m0_full.scatter_(1, ind0, matchability0.squeeze(-1))  # in-place; a plain .scatter() would discard its result
+ if padding and self.d_model == feat0.size(1):
+ m0_full = m0_full.reshape(bs, mask_h0, mask_w0)
+ bs, c, mask_h0, mask_w0 = feat0.size()
+ if mask_h0 != mask_H0:
+ m0_full = torch.cat([m0_full, torch.zeros(bs, mask_H0-mask_h0, mask_w0, device=m0_full.device, dtype=m0_full.dtype)], dim=1)
+ elif mask_w0 != mask_W0:
+ m0_full = torch.cat([m0_full, torch.zeros(bs, mask_h0, mask_W0-mask_w0, device=m0_full.device, dtype=m0_full.dtype)], dim=2)
+ m0_full = m0_full.reshape(bs, mask_H0*mask_W0)
+ m1_full = torch.zeros((bs, mask_h1 * mask_w1), device=matchability0.device, dtype=matchability0.dtype)
+ m1_full.scatter_(1, ind1, matchability1.squeeze(-1))
+ if padding and self.d_model == feat1.size(1):
+ m1_full = m1_full.reshape(bs, mask_h1, mask_w1)
+ bs, c, mask_h1, mask_w1 = feat1.size()
+ if mask_h1 != mask_H1:
+ m1_full = torch.cat([m1_full, torch.zeros(bs, mask_H1-mask_h1, mask_w1, device=m1_full.device, dtype=m1_full.dtype)], dim=1)
+ elif mask_w1 != mask_W1:
+ m1_full = torch.cat([m1_full, torch.zeros(bs, mask_h1, mask_W1-mask_w1, device=m1_full.device, dtype=m1_full.dtype)], dim=2)
+ m1_full = m1_full.reshape(bs, mask_H1*mask_W1)
+ data.update({'matchability0_'+str(i//2): m0_full, 'matchability1_'+str(i//2): m1_full})
+ m0, m1, mscores0, mscores1 = filter_matches(
+ scores, self.thr)
+ if do_point_pruning:
+ m0_ = torch.full((bs, l0), -1, device=m0.device, dtype=m0.dtype)
+ m1_ = torch.full((bs, l1), -1, device=m1.device, dtype=m1.dtype)
+ m0_[:, ind0] = torch.where(
+ m0 == -1, -1, ind1.gather(1, m0.clamp(min=0)))
+ m1_[:, ind1] = torch.where(
+ m1 == -1, -1, ind0.gather(1, m1.clamp(min=0)))
+ mscores0_ = torch.zeros((bs, l0), device=mscores0.device)
+ mscores1_ = torch.zeros((bs, l1), device=mscores1.device)
+ mscores0_[:, ind0] = mscores0
+ mscores1_[:, ind1] = mscores1
+ m0, m1, mscores0, mscores1 = m0_, m1_, mscores0_, mscores1_
+ if padding and self.d_model == feat0.size(1):
+ m0 = m0.reshape(bs, mask_h0, mask_w0)
+ bs, c, mask_h0, mask_w0 = feat0.size()
+ if mask_h0 != mask_H0:
+ m0 = torch.cat([m0, -torch.ones(bs, mask_H0-mask_h0, mask_w0, device=m0.device, dtype=m0.dtype)], dim=1)
+ elif mask_w0 != mask_W0:
+ m0 = torch.cat([m0, -torch.ones(bs, mask_h0, mask_W0-mask_w0, device=m0.device, dtype=m0.dtype)], dim=2)
+ m0 = m0.reshape(bs, mask_H0*mask_W0)
+ if padding and self.d_model == feat1.size(1):
+ m1 = m1.reshape(bs, mask_h1, mask_w1)
+ bs, c, mask_h1, mask_w1 = feat1.size()
+ if mask_h1 != mask_H1:
+ m1 = torch.cat([m1, -torch.ones(bs, mask_H1-mask_h1, mask_w1, device=m1.device, dtype=m1.dtype)], dim=1)
+ elif mask_w1 != mask_W1:
+ m1 = torch.cat([m1, -torch.ones(bs, mask_h1, mask_W1-mask_w1, device=m1.device, dtype=m1.dtype)], dim=2)
+ m1 = m1.reshape(bs, mask_H1*mask_W1)
+ data.update({'matches0_'+str(i//2): m0, 'matches1_'+str(i//2): m1})
+ conf = torch.zeros((bs, l0 * l1), device=scores.device, dtype=scores.dtype)
+ ind = ind0[...,None] * l1 + ind1[:,None,:]
+ # conf[ind.reshape(bs, -1)] = scores.reshape(bs, -1).exp()
+ conf.scatter_(1, ind.reshape(bs, -1), scores.reshape(bs, -1).exp())
+ if padding and self.d_model == feat0.size(1):
+ conf = conf.reshape(bs, mask_h0, mask_w0, mask_h1, mask_w1)
+ bs, c, mask_h0, mask_w0 = feat0.size()
+ if mask_h0 != mask_H0:
+ conf = torch.cat([conf, torch.zeros(bs, mask_H0-mask_h0, mask_w0, mask_h1, mask_w1, device=conf.device, dtype=conf.dtype)], dim=1)
+ elif mask_w0 != mask_W0:
+ conf = torch.cat([conf, torch.zeros(bs, mask_h0, mask_W0-mask_w0, mask_h1, mask_w1, device=conf.device, dtype=conf.dtype)], dim=2)
+ bs, c, mask_h1, mask_w1 = feat1.size()
+ if mask_h1 != mask_H1:
+ conf = torch.cat([conf, torch.zeros(bs, mask_H0, mask_W0, mask_H1-mask_h1, mask_W1, device=conf.device, dtype=conf.dtype)], dim=3)
+ elif mask_w1 != mask_W1:
+ conf = torch.cat([conf, torch.zeros(bs, mask_H0, mask_W0, mask_H1, mask_W1-mask_w1, device=conf.device, dtype=conf.dtype)], dim=4)
+ conf = conf.reshape(bs, mask_H0*mask_W0, mask_H1*mask_W1)
+ data.update({'conf_matrix_'+str(i//2): conf})
+
+
+
+ else:
+ raise KeyError
+
+ if self.matchability and not self.training:
+ scores, _, matchability0, matchability1 = self.log_assignment[i//2](desc0, desc1)
+ conf = torch.zeros((bs, l0 * l1), device=scores.device, dtype=scores.dtype)
+ ind = ind0[...,None] * l1 + ind1[:,None,:]
+ # conf[ind.reshape(bs, -1)] = scores.reshape(bs, -1).exp()
+ conf.scatter_(1, ind.reshape(bs, -1), scores.reshape(bs, -1).exp())
+ if padding and self.d_model == feat0.size(1):
+ conf = conf.reshape(bs, mask_h0, mask_w0, mask_h1, mask_w1)
+ bs, c, mask_h0, mask_w0 = feat0.size()
+ if mask_h0 != mask_H0:
+ conf = torch.cat([conf, torch.zeros(bs, mask_H0-mask_h0, mask_w0, mask_h1, mask_w1, device=conf.device, dtype=conf.dtype)], dim=1)
+ elif mask_w0 != mask_W0:
+ conf = torch.cat([conf, torch.zeros(bs, mask_h0, mask_W0-mask_w0, mask_h1, mask_w1, device=conf.device, dtype=conf.dtype)], dim=2)
+ bs, c, mask_h1, mask_w1 = feat1.size()
+ if mask_h1 != mask_H1:
+ conf = torch.cat([conf, torch.zeros(bs, mask_H0, mask_W0, mask_H1-mask_h1, mask_W1, device=conf.device, dtype=conf.dtype)], dim=3)
+ elif mask_w1 != mask_W1:
+ conf = torch.cat([conf, torch.zeros(bs, mask_H0, mask_W0, mask_H1, mask_W1-mask_w1, device=conf.device, dtype=conf.dtype)], dim=4)
+ conf = conf.reshape(bs, mask_H0*mask_W0, mask_H1*mask_W1)
+ data.update({'conf_matrix': conf})
+ data.update(**self.CoarseMatching.get_coarse_match(conf, data))
+ # m0, m1, mscores0, mscores1 = filter_matches(
+ # scores, self.conf.filter_threshold)
+
+ # matches, mscores = [], []
+ # for k in range(b):
+ # valid = m0[k] > -1
+ # m_indices_0 = torch.where(valid)[0]
+ # m_indices_1 = m0[k][valid]
+ # if do_point_pruning:
+ # m_indices_0 = ind0[k, m_indices_0]
+ # m_indices_1 = ind1[k, m_indices_1]
+ # matches.append(torch.stack([m_indices_0, m_indices_1], -1))
+ # mscores.append(mscores0[k][valid])
+
+ # # TODO: Remove when hloc switches to the compact format.
+ # if do_point_pruning:
+ # m0_ = torch.full((b, m), -1, device=m0.device, dtype=m0.dtype)
+ # m1_ = torch.full((b, n), -1, device=m1.device, dtype=m1.dtype)
+ # m0_[:, ind0] = torch.where(
+ # m0 == -1, -1, ind1.gather(1, m0.clamp(min=0)))
+ # m1_[:, ind1] = torch.where(
+ # m1 == -1, -1, ind0.gather(1, m1.clamp(min=0)))
+ # mscores0_ = torch.zeros((b, m), device=mscores0.device)
+ # mscores1_ = torch.zeros((b, n), device=mscores1.device)
+ # mscores0_[:, ind0] = mscores0
+ # mscores1_[:, ind1] = mscores1
+ # m0, m1, mscores0, mscores1 = m0_, m1_, mscores0_, mscores1_
+
+ # pred = {
+ # 'matches0': m0,
+ # 'matches1': m1,
+ # 'matching_scores0': mscores0,
+ # 'matching_scores1': mscores1,
+ # 'stop': i+1,
+ # 'matches': matches,
+ # 'scores': mscores,
+ # }
+
+ # if do_point_pruning:
+ # pred.update(dict(prune0=prune0, prune1=prune1))
+ # return pred
+
+
+ if padding and self.d_model == feat0.size(1):
+ bs, c, mask_h0, mask_w0 = feat0.size()
+ if mask_h0 != mask_H0:
+ feat0 = torch.cat([feat0, torch.zeros(bs, c, mask_H0-mask_h0, mask_W0, device=feat0.device, dtype=feat0.dtype)], dim=-2)
+ elif mask_w0 != mask_W0:
+ feat0 = torch.cat([feat0, torch.zeros(bs, c, mask_H0, mask_W0-mask_w0, device=feat0.device, dtype=feat0.dtype)], dim=-1)
+ bs, c, mask_h1, mask_w1 = feat1.size()
+ if mask_h1 != mask_H1:
+ feat1 = torch.cat([feat1, torch.zeros(bs, c, mask_H1-mask_h1, mask_W1, device=feat1.device, dtype=feat1.dtype)], dim=-2)
+ elif mask_w1 != mask_W1:
+ feat1 = torch.cat([feat1, torch.zeros(bs, c, mask_H1, mask_W1-mask_w1, device=feat1.device, dtype=feat1.dtype)], dim=-1)
+
+ return feat0, feat1
+
+ def pro(self, feat0, feat1, mask0=None, mask1=None, profiler=None):
+ """
+ Args:
+ feat0 (torch.Tensor): [N, C, H, W]
+ feat1 (torch.Tensor): [N, C, H, W]
+ mask0 (torch.Tensor): [N, L] (optional)
+ mask1 (torch.Tensor): [N, S] (optional)
+ """
+
+ assert self.d_model == feat0.size(1) or self.d_model == feat0.size(-1), "the feature dimension of the input and the transformer must be equal"
+ with profiler.profile("LoFTR_transformer_attention"):
+ for layer, name in zip(self.layers, self.layer_names):
+ if name == 'self':
+ feat0 = layer.pro(feat0, feat0, mask0, mask0, profiler=profiler)
+ feat1 = layer.pro(feat1, feat1, mask1, mask1, profiler=profiler)
+ elif name == 'cross':
+ feat0 = layer.pro(feat0, feat1, mask0, mask1, profiler=profiler)
+ feat1 = layer.pro(feat1, feat0, mask1, mask0, profiler=profiler)
+ else:
+ raise KeyError
+
+ return feat0, feat1
+
+ def confidence_threshold(self, layer_index: int) -> float:
+ """ scaled confidence threshold """
+ threshold = 0.8 + 0.1 * np.exp(-4.0 * layer_index / self.n_layers)
+ return np.clip(threshold, 0, 1)
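+ # Worked example (illustrative note, assuming n_layers == 4): the formula above gives
+ # 0.8 + 0.1 * exp(0) = 0.9 at layer_index 0 and decays towards 0.8 + 0.1 * exp(-4) ~= 0.802
+ # at the last layer, so early layers demand higher token confidence before pruning or stopping.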
+
+ def get_pruning_mask(self, confidences: torch.Tensor, scores: torch.Tensor,
+ layer_index: int) -> torch.Tensor:
+ """ mask points which should be removed """
+ threshold = self.confidence_threshold(layer_index)
+ if confidences is not None:
+ scores = torch.where(
+ confidences > threshold, scores, scores.new_tensor(1.0))
+ return scores > (1 - self.width_confidence)
+
+ def check_if_stop(self,
+ confidences0: torch.Tensor,
+ confidences1: torch.Tensor,
+ layer_index: int, num_points: int) -> torch.Tensor:
+ """ evaluate stopping condition"""
+ confidences = torch.cat([confidences0, confidences1], -1)
+ threshold = self.confidence_threshold(layer_index)
+ pos = 1.0 - (confidences < threshold).float().sum() / num_points
+ return pos > self.depth_confidence
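+ # Reading of the criterion above (added note): `pos` is the fraction of tokens from both
+ # images whose confidence exceeds the layer-dependent threshold; iteration stops once this
+ # fraction is larger than `depth_confidence` (e.g. 0.95 in LightGlue-style settings --
+ # an illustrative value, not taken from this file).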
diff --git a/imcui/third_party/MatchAnything/src/loftr/loftr_module/transformer_utils.py b/imcui/third_party/MatchAnything/src/loftr/loftr_module/transformer_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..25c261c973e8eeb6803ba6d21b5eb86992c3d857
--- /dev/null
+++ b/imcui/third_party/MatchAnything/src/loftr/loftr_module/transformer_utils.py
@@ -0,0 +1,76 @@
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+class TokenConfidence(nn.Module):
+ def __init__(self, dim: int) -> None:
+ super().__init__()
+ self.token = nn.Sequential(
+ nn.Linear(dim, 1),
+ nn.Sigmoid()
+ )
+
+ def forward(self, desc0: torch.Tensor, desc1: torch.Tensor):
+ """ get confidence tokens """
+ return (
+ self.token(desc0.detach().float()).squeeze(-1),
+ self.token(desc1.detach().float()).squeeze(-1))
+
+def sigmoid_log_double_softmax(
+ sim: torch.Tensor, z0: torch.Tensor, z1: torch.Tensor) -> torch.Tensor:
+ """ create the log assignment matrix from logits and similarity"""
+ b, m, n = sim.shape
+ m0, m1 = torch.sigmoid(z0), torch.sigmoid(z1)
+ certainties = torch.log(m0) + torch.log(m1).transpose(1, 2)
+ scores0 = F.log_softmax(sim, 2)
+ scores1 = F.log_softmax(
+ sim.transpose(-1, -2).contiguous(), 2).transpose(-1, -2)
+ scores = scores0 + scores1 + certainties
+ # scores[:, :-1, -1] = F.logsigmoid(-z0.squeeze(-1))
+ # scores[:, -1, :-1] = F.logsigmoid(-z1.squeeze(-1))
+ return scores, m0, m1
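+# Reading of the expression above (added note): exp(scores[b, i, j]) equals
+# softmax_row(sim)[i, j] * softmax_col(sim)[i, j] * sigmoid(z0[i]) * sigmoid(z1[j]),
+# i.e. a dual-softmax assignment probability gated by both matchability logits.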
+
+class MatchAssignment(nn.Module):
+ def __init__(self, dim: int) -> None:
+ super().__init__()
+ self.dim = dim
+ self.matchability = nn.Linear(dim, 1, bias=True)
+ self.final_proj = nn.Linear(dim, dim, bias=True)
+
+ @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
+ def forward(self, desc0: torch.Tensor, desc1: torch.Tensor):
+ """ build assignment matrix from descriptors """
+ mdesc0, mdesc1 = self.final_proj(desc0), self.final_proj(desc1)
+ _, _, d = mdesc0.shape
+ mdesc0, mdesc1 = mdesc0 / d**.25, mdesc1 / d**.25
+ sim = torch.einsum('bmd,bnd->bmn', mdesc0, mdesc1)
+ z0 = self.matchability(desc0)
+ z1 = self.matchability(desc1)
+ scores, m0, m1 = sigmoid_log_double_softmax(sim, z0, z1)
+ return scores, sim, m0, m1
+
+ def scores(self, desc0: torch.Tensor, desc1: torch.Tensor):
+ m0 = torch.sigmoid(self.matchability(desc0)).squeeze(-1)
+ m1 = torch.sigmoid(self.matchability(desc1)).squeeze(-1)
+ return m0, m1
+
+def filter_matches(scores: torch.Tensor, th: float):
+ """ obtain matches from a log assignment matrix [Bx M+1 x N+1]"""
+ max0, max1 = scores.max(2), scores.max(1)
+ m0, m1 = max0.indices, max1.indices
+ indices0 = torch.arange(m0.shape[1], device=m0.device)[None]
+ indices1 = torch.arange(m1.shape[1], device=m1.device)[None]
+ mutual0 = indices0 == m1.gather(1, m0)
+ mutual1 = indices1 == m0.gather(1, m1)
+ max0_exp = max0.values.exp()
+ zero = max0_exp.new_tensor(0)
+ mscores0 = torch.where(mutual0, max0_exp, zero)
+ mscores1 = torch.where(mutual1, mscores0.gather(1, m1), zero)
+ if th is not None:
+ valid0 = mutual0 & (mscores0 > th)
+ else:
+ valid0 = mutual0
+ valid1 = mutual1 & valid0.gather(1, m1)
+ m0 = torch.where(valid0, m0, -1)
+ m1 = torch.where(valid1, m1, -1)
+ return m0, m1, mscores0, mscores1
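+
+# --- Illustrative usage sketch (added for clarity; names and shapes are assumptions) ---
+# Shows how the outputs of `filter_matches` are typically turned into explicit index pairs.
+def _example_dense_to_pairs(scores: torch.Tensor, th: float = 0.1):
+ """scores: [B, M, N] log-assignment matrix -> per-batch [K_b, 2] index pairs and confidences."""
+ m0, m1, mscores0, mscores1 = filter_matches(scores, th)
+ pairs, confidences = [], []
+ for b in range(scores.shape[0]):
+ valid = m0[b] > -1  # keypoints in image 0 with a mutual match
+ idx0 = torch.where(valid)[0]
+ idx1 = m0[b][valid]
+ pairs.append(torch.stack([idx0, idx1], -1))
+ confidences.append(mscores0[b][valid])
+ return pairs, confidences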
diff --git a/imcui/third_party/MatchAnything/src/loftr/utils/coarse_matching.py b/imcui/third_party/MatchAnything/src/loftr/utils/coarse_matching.py
new file mode 100644
index 0000000000000000000000000000000000000000..cd8dfca8227423ed699ea736e23b516bed68c19d
--- /dev/null
+++ b/imcui/third_party/MatchAnything/src/loftr/utils/coarse_matching.py
@@ -0,0 +1,266 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops.einops import rearrange, repeat
+
+from loguru import logger
+
+INF = 1e9
+
+def mask_border(m, b: int, v):
+ """ Mask borders with value
+ Args:
+ m (torch.Tensor): [N, H0, W0, H1, W1]
+ b (int)
+ v (m.dtype)
+ """
+ if b <= 0:
+ return
+
+ m[:, :b] = v
+ m[:, :, :b] = v
+ m[:, :, :, :b] = v
+ m[:, :, :, :, :b] = v
+ m[:, -b:] = v
+ m[:, :, -b:] = v
+ m[:, :, :, -b:] = v
+ m[:, :, :, :, -b:] = v
+
+
+def mask_border_with_padding(m, bd, v, p_m0, p_m1):
+ if bd <= 0:
+ return
+
+ m[:, :bd] = v
+ m[:, :, :bd] = v
+ m[:, :, :, :bd] = v
+ m[:, :, :, :, :bd] = v
+
+ h0s, w0s = p_m0.sum(1).max(-1)[0].int(), p_m0.sum(-1).max(-1)[0].int()
+ h1s, w1s = p_m1.sum(1).max(-1)[0].int(), p_m1.sum(-1).max(-1)[0].int()
+ for b_idx, (h0, w0, h1, w1) in enumerate(zip(h0s, w0s, h1s, w1s)):
+ m[b_idx, h0 - bd:] = v
+ m[b_idx, :, w0 - bd:] = v
+ m[b_idx, :, :, h1 - bd:] = v
+ m[b_idx, :, :, :, w1 - bd:] = v
+
+
+def compute_max_candidates(p_m0, p_m1):
+ """Compute the max candidates of all pairs within a batch
+
+ Args:
+ p_m0, p_m1 (torch.Tensor): padded masks
+ """
+ h0s, w0s = p_m0.sum(1).max(-1)[0], p_m0.sum(-1).max(-1)[0]
+ h1s, w1s = p_m1.sum(1).max(-1)[0], p_m1.sum(-1).max(-1)[0]
+ max_cand = torch.sum(
+ torch.min(torch.stack([h0s * w0s, h1s * w1s], -1), -1)[0])
+ return max_cand
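+# Worked example (illustrative numbers): if pair 0 has valid coarse grids of 40x40 (image 0)
+# and 32x48 (image 1), its contribution is min(1600, 1536) = 1536 candidates; the per-pair
+# minima are then summed over the batch.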
+
+
+class CoarseMatching(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ # general config
+ self.thr = config['thr']
+ self.border_rm = config['border_rm']
+ # -- # for training fine-level LoFTR
+ self.train_coarse_percent = config['train_coarse_percent']
+ self.train_pad_num_gt_min = config['train_pad_num_gt_min']
+
+ # we provide 2 options for differentiable matching
+ self.match_type = config['match_type']
+ if self.match_type == 'dual_softmax':
+ self.temperature = config['dsmax_temperature']
+ elif self.match_type == 'sinkhorn':
+ try:
+ from .superglue import log_optimal_transport
+ except ImportError:
+ raise ImportError("download superglue.py first!")
+ self.log_optimal_transport = log_optimal_transport
+ self.bin_score = nn.Parameter(
+ torch.tensor(config['skh_init_bin_score'], requires_grad=True))
+ self.skh_iters = config['skh_iters']
+ self.skh_prefilter = config['skh_prefilter']
+ else:
+ raise NotImplementedError()
+
+ self.mtd = config['mtd_spvs']
+ self.fix_bias = config['fix_bias']
+
+ def forward(self, feat_c0, feat_c1, data, mask_c0=None, mask_c1=None):
+ """
+ Args:
+ feat_c0 (torch.Tensor): [N, L, C]
+ feat_c1 (torch.Tensor): [N, S, C]
+ data (dict)
+ mask_c0 (torch.Tensor): [N, L] (optional)
+ mask_c1 (torch.Tensor): [N, S] (optional)
+ Update:
+ data (dict): {
+ 'b_ids' (torch.Tensor): [M'],
+ 'i_ids' (torch.Tensor): [M'],
+ 'j_ids' (torch.Tensor): [M'],
+ 'gt_mask' (torch.Tensor): [M'],
+ 'mkpts0_c' (torch.Tensor): [M, 2],
+ 'mkpts1_c' (torch.Tensor): [M, 2],
+ 'mconf' (torch.Tensor): [M]}
+ NOTE: M' != M during training.
+ """
+ N, L, S, C = feat_c0.size(0), feat_c0.size(1), feat_c1.size(1), feat_c0.size(2)
+
+ # normalize
+ feat_c0, feat_c1 = map(lambda feat: feat / feat.shape[-1]**.5,
+ [feat_c0, feat_c1])
+
+ if self.match_type == 'dual_softmax':
+ with torch.autocast(enabled=False, device_type='cuda'):
+ sim_matrix = torch.einsum("nlc,nsc->nls", feat_c0,
+ feat_c1) / self.temperature
+ if mask_c0 is not None:
+ sim_matrix = sim_matrix.float().masked_fill_(
+ ~(mask_c0[..., None] * mask_c1[:, None]).bool(),
+ -INF
+ # float("-inf") if sim_matrix.dtype == torch.float16 else -INF
+ )
+ if self.config['fp16log']:
+ t1 = F.softmax(sim_matrix, 1)
+ t2 = F.softmax(sim_matrix, 2)
+ conf_matrix = t1*t2
+ logger.info(f'feat_c0absmax: {feat_c0.abs().max()}')
+ logger.info(f'feat_c1absmax: {feat_c1.abs().max()}')
+ logger.info(f'sim_matrix: {sim_matrix.dtype}')
+ logger.info(f'sim_matrixabsmax: {sim_matrix.abs().max()}')
+ logger.info(f't1: {t1.dtype}, t2: {t2.dtype}, conf_matrix: {conf_matrix.dtype}')
+ logger.info(f't1absmax: {t1.abs().max()}, t2absmax: {t2.abs().max()}, conf_matrixabsmax: {conf_matrix.abs().max()}')
+ else:
+ conf_matrix = F.softmax(sim_matrix, 1) * F.softmax(sim_matrix, 2)
+
+ data.update({'conf_matrix': conf_matrix})
+
+ # predict coarse matches from conf_matrix
+ data.update(**self.get_coarse_match(conf_matrix, data))
+
+ @torch.no_grad()
+ def get_coarse_match(self, conf_matrix, data):
+ """
+ Args:
+ conf_matrix (torch.Tensor): [N, L, S]
+ data (dict): with keys ['hw0_i', 'hw1_i', 'hw0_c', 'hw1_c']
+ Returns:
+ coarse_matches (dict): {
+ 'b_ids' (torch.Tensor): [M'],
+ 'i_ids' (torch.Tensor): [M'],
+ 'j_ids' (torch.Tensor): [M'],
+ 'gt_mask' (torch.Tensor): [M'],
+ 'm_bids' (torch.Tensor): [M],
+ 'mkpts0_c' (torch.Tensor): [M, 2],
+ 'mkpts1_c' (torch.Tensor): [M, 2],
+ 'mconf' (torch.Tensor): [M]}
+ """
+ axes_lengths = {
+ 'h0c': data['hw0_c'][0],
+ 'w0c': data['hw0_c'][1],
+ 'h1c': data['hw1_c'][0],
+ 'w1c': data['hw1_c'][1]
+ }
+ _device = conf_matrix.device
+ # 1. confidence thresholding
+ mask = conf_matrix > self.thr
+ mask = rearrange(mask, 'b (h0c w0c) (h1c w1c) -> b h0c w0c h1c w1c',
+ **axes_lengths)
+ if 'mask0' not in data:
+ mask_border(mask, self.border_rm, False)
+ else:
+ mask_border_with_padding(mask, self.border_rm, False,
+ data['mask0'], data['mask1'])
+ mask = rearrange(mask, 'b h0c w0c h1c w1c -> b (h0c w0c) (h1c w1c)',
+ **axes_lengths)
+
+ # 2. mutual nearest
+ if self.mtd:
+ b_ids, i_ids, j_ids = torch.where(mask)
+ mconf = conf_matrix[b_ids, i_ids, j_ids]
+ else:
+ mask = mask \
+ * (conf_matrix == conf_matrix.max(dim=2, keepdim=True)[0]) \
+ * (conf_matrix == conf_matrix.max(dim=1, keepdim=True)[0])
+
+ # 3. find all valid coarse matches
+ # this only works when there is at most one `True` in each row
+ mask_v, all_j_ids = mask.max(dim=2)
+ b_ids, i_ids = torch.where(mask_v)
+ j_ids = all_j_ids[b_ids, i_ids]
+ mconf = conf_matrix[b_ids, i_ids, j_ids]
+
+ # 4. Random sampling of training samples for fine-level LoFTR
+ # (optional) pad samples with gt coarse-level matches
+ if self.training:
+ # NOTE:
+ # The sampling is performed across all pairs in a batch without manual balancing,
+ # so the number of fine-level samples increases w.r.t. the batch size.
+ if 'mask0' not in data:
+ num_candidates_max = mask.size(0) * max(
+ mask.size(1), mask.size(2))
+ else:
+ num_candidates_max = compute_max_candidates(
+ data['mask0'], data['mask1'])
+ num_matches_train = int(num_candidates_max *
+ self.train_coarse_percent)
+ num_matches_pred = len(b_ids)
+ assert self.train_pad_num_gt_min < num_matches_train, "min-num-gt-pad should be less than num-train-matches"
+
+ # pred_indices is to select from prediction
+ if num_matches_pred <= num_matches_train - self.train_pad_num_gt_min:
+ pred_indices = torch.arange(num_matches_pred, device=_device)
+ else:
+ pred_indices = torch.randint(
+ num_matches_pred,
+ (num_matches_train - self.train_pad_num_gt_min, ),
+ device=_device)
+
+ # gt_pad_indices is to select from gt padding. e.g. max(3787-4800, 200)
+ gt_pad_indices = torch.randint(
+ len(data['spv_b_ids']),
+ (max(num_matches_train - num_matches_pred,
+ self.train_pad_num_gt_min), ),
+ device=_device)
+ mconf_gt = torch.zeros(len(data['spv_b_ids']), device=_device) # set conf of gt paddings to all zero
+
+ b_ids, i_ids, j_ids, mconf = map(
+ lambda x, y: torch.cat([x[pred_indices], y[gt_pad_indices]],
+ dim=0),
+ *zip([b_ids, data['spv_b_ids']], [i_ids, data['spv_i_ids']],
+ [j_ids, data['spv_j_ids']], [mconf, mconf_gt]))
+
+ # These matches select patches that feed into fine-level network
+ coarse_matches = {'b_ids': b_ids, 'i_ids': i_ids, 'j_ids': j_ids}
+
+ # 5. Update with matches in the original image resolution
+ if self.fix_bias:
+ scale = 8
+ else:
+ scale = data['hw0_i'][0] / data['hw0_c'][0]
+ scale0 = scale * data['scale0'][b_ids] if 'scale0' in data else scale
+ scale1 = scale * data['scale1'][b_ids] if 'scale1' in data else scale
+ mkpts0_c = torch.stack(
+ [i_ids % data['hw0_c'][1], i_ids // data['hw0_c'][1]],
+ dim=1) * scale0
+ mkpts1_c = torch.stack(
+ [j_ids % data['hw1_c'][1], j_ids // data['hw1_c'][1]],
+ dim=1) * scale1
+
+ m_bids = b_ids[mconf != 0]
+
+ m_bids_f = repeat(m_bids, 'b -> b k', k = 3).reshape(-1)
+ coarse_matches.update({
+ 'gt_mask': mconf == 0, # mconf == 0 => gt-padded matches
+ 'm_bids': m_bids,
+ 'm_bids_f': m_bids_f,
+ 'mkpts0_c': mkpts0_c[mconf != 0],
+ 'mkpts1_c': mkpts1_c[mconf != 0],
+ 'mconf': mconf[mconf != 0]
+ })
+
+ return coarse_matches
\ No newline at end of file
diff --git a/imcui/third_party/MatchAnything/src/loftr/utils/fine_matching.py b/imcui/third_party/MatchAnything/src/loftr/utils/fine_matching.py
new file mode 100644
index 0000000000000000000000000000000000000000..7be172a9bf9d45e3cbcf33a4926abddfca877629
--- /dev/null
+++ b/imcui/third_party/MatchAnything/src/loftr/utils/fine_matching.py
@@ -0,0 +1,493 @@
+import math
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from kornia.geometry.subpix import dsnt
+from kornia.utils.grid import create_meshgrid
+
+from loguru import logger
+
+class FineMatching(nn.Module):
+ """FineMatching with s2d paradigm"""
+
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.topk = config['match_fine']['topk']
+ self.mtd_spvs = config['fine']['mtd_spvs']
+ self.align_corner = config['align_corner']
+ self.fix_bias = config['fix_bias']
+ self.normfinem = config['match_fine']['normfinem']
+ self.fix_fine_matching = config['match_fine']['fix_fine_matching']
+ self.mutual_nearest = config['match_fine']['force_nearest']
+ self.skip_fine_softmax = config['match_fine']['skip_fine_softmax']
+ self.normfeat = config['match_fine']['normfeat']
+ self.use_sigmoid = config['match_fine']['use_sigmoid']
+ self.local_regress = config['match_fine']['local_regress']
+ self.local_regress_rmborder = config['match_fine']['local_regress_rmborder']
+ self.local_regress_nomask = config['match_fine']['local_regress_nomask']
+ self.local_regress_temperature = config['match_fine']['local_regress_temperature']
+ self.local_regress_padone = config['match_fine']['local_regress_padone']
+ self.local_regress_slice = config['match_fine']['local_regress_slice']
+ self.local_regress_slicedim = config['match_fine']['local_regress_slicedim']
+ self.local_regress_inner = config['match_fine']['local_regress_inner']
+ self.multi_regress = config['match_fine']['multi_regress']
+
+ def forward(self, feat_0, feat_1, data):
+ """
+ Args:
+ feat_0 (torch.Tensor): [M, WW, C]
+ feat_1 (torch.Tensor): [M, WW, C]
+ data (dict)
+ Update:
+ data (dict):{
+ 'expec_f' (torch.Tensor): [M, 3],
+ 'mkpts0_f' (torch.Tensor): [M, 2],
+ 'mkpts1_f' (torch.Tensor): [M, 2]}
+ """
+ M, WW, C = feat_0.shape
+ W = int(math.sqrt(WW))
+ if self.fix_bias:
+ scale = 2
+ else:
+ scale = data['hw0_i'][0] / data['hw0_f'][0]
+ self.M, self.W, self.WW, self.C, self.scale = M, W, WW, C, scale
+
+ # corner case: if no coarse matches found
+ if M == 0:
+ assert self.training == False, "M is always > 0 during training, see coarse_matching.py"
+ # logger.warning('No matches found in coarse-level.')
+ if self.mtd_spvs:
+ data.update({
+ 'conf_matrix_f': torch.empty(0, WW, WW, device=feat_0.device),
+ 'mkpts0_f': data['mkpts0_c'],
+ 'mkpts1_f': data['mkpts1_c'],
+ })
+ # if self.local_regress:
+ # data.update({
+ # 'sim_matrix_f': torch.empty(0, WW, WW, device=feat_0.device),
+ # })
+ return
+ else:
+ data.update({
+ 'expec_f': torch.empty(0, 3, device=feat_0.device),
+ 'mkpts0_f': data['mkpts0_c'],
+ 'mkpts1_f': data['mkpts1_c'],
+ })
+ return
+
+ if self.mtd_spvs:
+ with torch.autocast(enabled=False, device_type='cuda'):
+ # feat_0 = feat_0 / feat_0.size(-2)
+ if self.local_regress_slice:
+ feat_ff0, feat_ff1 = feat_0[...,-self.local_regress_slicedim:], feat_1[...,-self.local_regress_slicedim:]
+ feat_f0, feat_f1 = feat_0[...,:-self.local_regress_slicedim], feat_1[...,:-self.local_regress_slicedim]
+ conf_matrix_ff = torch.einsum('mlc,mrc->mlr', feat_ff0, feat_ff1 / (self.local_regress_slicedim)**.5)
+ else:
+ feat_f0, feat_f1 = feat_0, feat_1
+ if self.normfinem:
+ feat_f0 = feat_f0 / C**.5
+ feat_f1 = feat_f1 / C**.5
+ conf_matrix_f = torch.einsum('mlc,mrc->mlr', feat_f0, feat_f1)
+ else:
+ if self.local_regress_slice:
+ conf_matrix_f = torch.einsum('mlc,mrc->mlr', feat_f0, feat_f1 / (C - self.local_regress_slicedim)**.5)
+ else:
+ conf_matrix_f = torch.einsum('mlc,mrc->mlr', feat_f0, feat_f1 / C**.5)
+
+ if self.normfeat:
+ feat_f0, feat_f1 = torch.nn.functional.normalize(feat_f0.float(), p=2, dim=-1), torch.nn.functional.normalize(feat_f1.float(), p=2, dim=-1)
+
+ if self.config['fp16log']:
+ logger.info(f'sim_matrix: {conf_matrix_f.abs().max()}')
+ # sim_matrix *= 1. / C**.5 # normalize
+
+ if self.multi_regress:
+ assert not self.local_regress
+ assert not self.normfinem and not self.normfeat
+ heatmap = F.softmax(conf_matrix_f, 2).view(M, WW, W, W) # [M, WW, W, W]
+
+ assert (W - 2) == (self.config['resolution'][0] // self.config['resolution'][1]) # c8
+ windows_scale = (W - 1) / (self.config['resolution'][0] // self.config['resolution'][1])
+
+ coords_normalized = dsnt.spatial_expectation2d(heatmap, True) * windows_scale # [M, WW, 2]
+ grid_normalized = create_meshgrid(W, W, True, heatmap.device).reshape(1, -1, 2)[:,None,:,:] * windows_scale # [1, 1, WW, 2]
+
+ # compute std over the spatial grid
+ var = torch.sum(grid_normalized**2 * heatmap.view(M, WW, WW, 1), dim=-2) - coords_normalized**2 # ([1,1,WW,2] * [M,WW,WW,1])->[M,WW,2]
+ std = torch.sum(torch.sqrt(torch.clamp(var, min=1e-10)), -1) # [M,WW] clamp needed for numerical stability
+
+ # for fine-level supervision
+ data.update({'expec_f': torch.cat([coords_normalized, std.unsqueeze(-1)], -1)}) # [M, WW, 2]
+
+ # get the least uncertain matches
+ val, idx = torch.topk(std, self.topk, dim=-1, largest=False) # [M,topk]
+ coords_normalized = coords_normalized[torch.arange(M, device=conf_matrix_f.device, dtype=torch.long)[:,None], idx] # [M,topk]
+
+ grid = create_meshgrid(W, W, False, idx.device) - W // 2 + 0.5 # [1, W, W, 2]
+ grid = grid.reshape(1, -1, 2).expand(M, -1, -1) # [M, WW, 2]
+ delta_l = torch.gather(grid, 1, idx.unsqueeze(-1).expand(-1, -1, 2)) # [M, topk, 2] in (x, y)
+
+ # compute absolute kpt coords
+ self.get_multi_fine_match_align(delta_l, coords_normalized, data)
+
+
+ else:
+
+ if self.skip_fine_softmax:
+ pass
+ elif self.use_sigmoid:
+ conf_matrix_f = torch.sigmoid(conf_matrix_f)
+ else:
+ if self.local_regress:
+ del feat_f0, feat_f1
+ softmax_matrix_f = F.softmax(conf_matrix_f, 1) * F.softmax(conf_matrix_f, 2)
+ # softmax_matrix_f = conf_matrix_f
+ if self.local_regress_inner:
+ softmax_matrix_f = softmax_matrix_f.reshape(M, self.WW, self.W+2, self.W+2)
+ softmax_matrix_f = softmax_matrix_f[...,1:-1,1:-1].reshape(M, self.WW, self.WW)
+ # if self.training:
+ # for fine-level supervision
+ data.update({'conf_matrix_f': softmax_matrix_f})
+ if self.local_regress_slice:
+ data.update({'sim_matrix_ff': conf_matrix_ff})
+ else:
+ data.update({'sim_matrix_f': conf_matrix_f})
+
+ else:
+ conf_matrix_f = F.softmax(conf_matrix_f, 1) * F.softmax(conf_matrix_f, 2)
+
+ # for fine-level supervision
+ data.update({'conf_matrix_f': conf_matrix_f})
+
+ # compute absolute kpt coords
+ if self.local_regress:
+ self.get_fine_ds_match(softmax_matrix_f, data)
+ del softmax_matrix_f
+ idx_l, idx_r = data['idx_l'], data['idx_r']
+ del data['idx_l'], data['idx_r']
+ m_ids = torch.arange(M, device=idx_l.device, dtype=torch.long).unsqueeze(-1).expand(-1, self.topk)
+ # if self.training:
+ m_ids = m_ids[:len(data['mconf']) // self.topk]
+ idx_r_iids, idx_r_jids = idx_r // W, idx_r % W
+
+ # remove border matches
+ if self.local_regress_nomask:
+ # log for inner percent
+ # mask = (idx_r_iids >= 1) & (idx_r_iids <= W-2) & (idx_r_jids >= 1) & (idx_r_jids <= W-2)
+ # mask_sum = mask.sum()
+ # logger.info(f'total fine match: {mask.numel()}; regressed fine match: {mask_sum}, per: {mask_sum / mask.numel()}')
+ mask = None
+ m_ids, idx_l, idx_r_iids, idx_r_jids = m_ids.reshape(-1), idx_l.reshape(-1), idx_r_iids.reshape(-1), idx_r_jids.reshape(-1)
+ if self.local_regress_inner: # been sliced before
+ delta = create_meshgrid(3, 3, True, conf_matrix_f.device).to(torch.long) # [1, 3, 3, 2]
+ else:
+ # no mask + 1 for padding
+ delta = create_meshgrid(3, 3, True, conf_matrix_f.device).to(torch.long) + torch.tensor([1], dtype=torch.long, device=conf_matrix_f.device) # [1, 3, 3, 2]
+
+ m_ids = m_ids[...,None,None].expand(-1, 3, 3)
+ idx_l = idx_l[...,None,None].expand(-1, 3, 3) # [m, k, 3, 3]
+
+ idx_r_iids = idx_r_iids[...,None,None].expand(-1, 3, 3) + delta[None, ..., 1]
+ idx_r_jids = idx_r_jids[...,None,None].expand(-1, 3, 3) + delta[None, ..., 0]
+
+ if idx_l.numel() == 0:
+ data.update({
+ 'mkpts0_f': data['mkpts0_c'],
+ 'mkpts1_f': data['mkpts1_c'],
+ })
+ return
+
+ if self.local_regress_slice:
+ conf_matrix_f = conf_matrix_ff
+ if self.local_regress_inner:
+ conf_matrix_f = conf_matrix_f.reshape(M, self.WW, self.W+2, self.W+2)
+ else:
+ conf_matrix_f = conf_matrix_f.reshape(M, self.WW, self.W, self.W)
+ conf_matrix_f = F.pad(conf_matrix_f, (1,1,1,1))
+ else:
+ mask = (idx_r_iids >= 1) & (idx_r_iids <= W-2) & (idx_r_jids >= 1) & (idx_r_jids <= W-2)
+ if W == 10:
+ idx_l_iids, idx_l_jids = idx_l // W, idx_l % W
+ mask = mask & (idx_l_iids >= 1) & (idx_l_iids <= W-2) & (idx_l_jids >= 1) & (idx_l_jids <= W-2)
+
+ m_ids = m_ids[mask].to(torch.long)
+ idx_l, idx_r_iids, idx_r_jids = idx_l[mask].to(torch.long), idx_r_iids[mask].to(torch.long), idx_r_jids[mask].to(torch.long)
+
+ m_ids, idx_l, idx_r_iids, idx_r_jids = m_ids.reshape(-1), idx_l.reshape(-1), idx_r_iids.reshape(-1), idx_r_jids.reshape(-1)
+ mask = mask.reshape(-1)
+
+ delta = create_meshgrid(3, 3, True, conf_matrix_f.device).to(torch.long) # [1, 3, 3, 2]
+
+ m_ids = m_ids[:,None,None].expand(-1, 3, 3)
+ idx_l = idx_l[:,None,None].expand(-1, 3, 3) # [m, 3, 3]
+ # note: the meshgrid's last dim is (x, y), so row offsets use delta[..., 1] and column offsets use delta[..., 0]; the commented lines below had them swapped
+ # idx_r_iids = idx_r_iids[...,None,None].expand(-1, 3, 3) + delta[None, ..., 0]
+ # idx_r_jids = idx_r_jids[...,None,None].expand(-1, 3, 3) + delta[None, ..., 1]
+ idx_r_iids = idx_r_iids[:,None,None].expand(-1, 3, 3) + delta[..., 1]
+ idx_r_jids = idx_r_jids[:,None,None].expand(-1, 3, 3) + delta[..., 0]
+
+ if idx_l.numel() == 0:
+ data.update({
+ 'mkpts0_f': data['mkpts0_c'],
+ 'mkpts1_f': data['mkpts1_c'],
+ })
+ return
+ if not self.local_regress_slice:
+ conf_matrix_f = conf_matrix_f.reshape(M, self.WW, self.W, self.W)
+ else:
+ conf_matrix_f = conf_matrix_ff.reshape(M, self.WW, self.W, self.W)
+
+ conf_matrix_f = conf_matrix_f[m_ids, idx_l, idx_r_iids, idx_r_jids]
+ conf_matrix_f = conf_matrix_f.reshape(-1, 9)
+ if self.local_regress_padone: # follow training: detach the gradient of the center element
+ conf_matrix_f[:,4] = -1e4
+ heatmap = F.softmax(conf_matrix_f / self.local_regress_temperature, -1)
+ logger.info(f'maxmax&maxmean of heatmap: {heatmap.view(-1).max()}, {heatmap.view(-1).min(), heatmap.max(-1)[0].mean()}')
+ heatmap[:,4] = 1.0 # no need gradient calculation in inference
+ logger.info(f'min of heatmap: {heatmap.view(-1).min()}')
+ heatmap = heatmap.reshape(-1, 3, 3)
+ # heatmap = torch.ones_like(softmax) # ones_like for detach the gradient of center
+ # heatmap[:,:4], heatmap[:,5:] = softmax[:,:4], softmax[:,5:]
+ # heatmap = heatmap.reshape(-1, 3, 3)
+ else:
+ conf_matrix_f = F.softmax(conf_matrix_f / self.local_regress_temperature, -1)
+ # logger.info(f'max&min&mean of heatmap: {conf_matrix_f.view(-1).max()}, {conf_matrix_f.view(-1).min(), conf_matrix_f.max(-1)[0].mean()}')
+ heatmap = conf_matrix_f.reshape(-1, 3, 3)
+
+ # compute coordinates from heatmap
+ coords_normalized = dsnt.spatial_expectation2d(heatmap[None], True)[0]
+
+ # coords_normalized_l2 = coords_normalized.norm(p=2, dim=-1)
+ # logger.info(f'mean&max&min abs of local: {coords_normalized_l2.mean(), coords_normalized_l2.max(), coords_normalized_l2.min()}')
+
+ # compute absolute kpt coords
+
+ if data['bs'] == 1:
+ scale1 = scale * data['scale1'] if 'scale0' in data else scale
+ else:
+ if mask is not None:
+ scale1 = scale * data['scale1'][data['b_ids']][:len(data['mconf']) // self.topk,...][:,None,:].expand(-1, self.topk, 2).reshape(-1, 2)[mask] if 'scale0' in data else scale
+ else:
+ scale1 = scale * data['scale1'][data['b_ids']][:len(data['mconf']) // self.topk,...][:,None,:].expand(-1, self.topk, 2).reshape(-1, 2) if 'scale0' in data else scale
+
+ self.get_fine_match_local(coords_normalized, data, scale1, mask, True)
+
+ else:
+ self.get_fine_ds_match(conf_matrix_f, data)
+
+
+ else:
+ if self.align_corner is True:
+ feat_f0, feat_f1 = feat_0, feat_1
+ feat_f0_picked = feat_f0[:, WW//2, :]
+ sim_matrix = torch.einsum('mc,mrc->mr', feat_f0_picked, feat_f1)
+ softmax_temp = 1. / C**.5
+ heatmap = torch.softmax(softmax_temp * sim_matrix, dim=1).view(-1, W, W)
+
+ # compute coordinates from heatmap
+ coords_normalized = dsnt.spatial_expectation2d(heatmap[None], True)[0] # [M, 2]
+ grid_normalized = create_meshgrid(W, W, True, heatmap.device).reshape(1, -1, 2) # [1, WW, 2]
+
+ # compute std over the spatial grid
+ var = torch.sum(grid_normalized**2 * heatmap.view(-1, WW, 1), dim=1) - coords_normalized**2 # [M, 2]
+ std = torch.sum(torch.sqrt(torch.clamp(var, min=1e-10)), -1) # [M] clamp needed for numerical stability
+
+ # for fine-level supervision
+ data.update({'expec_f': torch.cat([coords_normalized, std.unsqueeze(1)], -1)})
+
+ # compute absolute kpt coords
+ self.get_fine_match(coords_normalized, data)
+ else:
+ feat_f0, feat_f1 = feat_0, feat_1
+ # even-sized matching windows while the coarse grid is not aligned to the fine grid
+ # assert W == 5, "others size not checked"
+ if self.fix_bias:
+ assert W % 2 == 1, "W must be odd when selecting the center feature"
+ feat_f0_picked = feat_f0[:, WW//2]
+
+ else:
+ # assert W == 6, "others size not checked"
+ assert W % 2 == 0, "W must be even when the coarse grid is not aligned to the fine grid (average)"
+ feat_f0_picked = (feat_f0[:, WW//2 - W//2 - 1] + feat_f0[:, WW//2 - W//2] + feat_f0[:, WW//2 + W//2] + feat_f0[:, WW//2 + W//2 - 1]) / 4
+ sim_matrix = torch.einsum('mc,mrc->mr', feat_f0_picked, feat_f1)
+ softmax_temp = 1. / C**.5
+ heatmap = torch.softmax(softmax_temp * sim_matrix, dim=1).view(-1, W, W)
+
+ # compute coordinates from heatmap
+ windows_scale = (W - 1) / (self.config['resolution'][0] // self.config['resolution'][1])
+
+ coords_normalized = dsnt.spatial_expectation2d(heatmap[None], True)[0] * windows_scale # [M, 2]
+ grid_normalized = create_meshgrid(W, W, True, heatmap.device).reshape(1, -1, 2) * windows_scale # [1, WW, 2]
+
+ # compute std over the spatial grid
+ var = torch.sum(grid_normalized**2 * heatmap.view(-1, WW, 1), dim=1) - coords_normalized**2 # [M, 2]
+ std = torch.sum(torch.sqrt(torch.clamp(var, min=1e-10)), -1) # [M] clamp needed for numerical stability
+
+ # for fine-level supervision
+ data.update({'expec_f': torch.cat([coords_normalized, std.unsqueeze(1)], -1)})
+
+ # compute absolute kpt coords
+ self.get_fine_match_align(coords_normalized, data)
+
+
+ @torch.no_grad()
+ def get_fine_match(self, coords_normed, data):
+ W, WW, C, scale = self.W, self.WW, self.C, self.scale
+
+ # mkpts0_f and mkpts1_f
+ mkpts0_f = data['mkpts0_c']
+ scale1 = scale * data['scale1'][data['b_ids']] if 'scale0' in data else scale
+ mkpts1_f = data['mkpts1_c'] + (coords_normed * (W // 2) * scale1)[:len(data['mconf'])]
+
+ data.update({
+ "mkpts0_f": mkpts0_f,
+ "mkpts1_f": mkpts1_f
+ })
+
+ def get_fine_match_local(self, coords_normed, data, scale1, mask, reserve_border=True):
+ W, WW, C, scale = self.W, self.WW, self.C, self.scale
+
+ if mask is None:
+ mkpts0_c, mkpts1_c = data['mkpts0_c'], data['mkpts1_c']
+ else:
+ data['mkpts0_c'], data['mkpts1_c'] = data['mkpts0_c'].reshape(-1, 2), data['mkpts1_c'].reshape(-1, 2)
+ mkpts0_c, mkpts1_c = data['mkpts0_c'][mask], data['mkpts1_c'][mask]
+ mask_sum = mask.sum()
+ logger.info(f'total fine match: {mask.numel()}; regressed fine match: {mask_sum}, per: {mask_sum / mask.numel()}')
+ # print(mkpts0_c.shape, mkpts1_c.shape, coords_normed.shape, scale1.shape)
+ # print(data['mkpts0_c'].shape, data['mkpts1_c'].shape)
+ # mkpts0_f and mkpts1_f
+ mkpts0_f = mkpts0_c
+ mkpts1_f = mkpts1_c + (coords_normed * (3 // 2) * scale1)
+
+ if reserve_border and mask is not None:
+ mkpts0_f, mkpts1_f = torch.cat([mkpts0_f, data['mkpts0_c'][~mask].reshape(-1, 2)]), torch.cat([mkpts1_f, data['mkpts1_c'][~mask].reshape(-1, 2)])
+ else:
+ pass
+
+ del data['mkpts0_c'], data['mkpts1_c']
+ data.update({
+ "mkpts0_f": mkpts0_f,
+ "mkpts1_f": mkpts1_f
+ })
+
+ # can be used for both aligned and not aligned
+ @torch.no_grad()
+ def get_fine_match_align(self, coord_normed, data):
+ W, WW, C, scale = self.W, self.WW, self.C, self.scale
+ c2f = self.config['resolution'][0] // self.config['resolution'][1]
+ # mkpts0_f and mkpts1_f
+ mkpts0_f = data['mkpts0_c']
+ scale1 = scale * data['scale1'][data['b_ids']] if 'scale0' in data else scale
+ mkpts1_f = data['mkpts1_c'] + (coord_normed * (c2f // 2) * scale1)[:len(data['mconf'])]
+
+ data.update({
+ "mkpts0_f": mkpts0_f,
+ "mkpts1_f": mkpts1_f
+ })
+
+ @torch.no_grad()
+ def get_multi_fine_match_align(self, delta_l, coord_normed, data):
+ W, WW, C, scale = self.W, self.WW, self.C, self.scale
+ c2f = self.config['resolution'][0] // self.config['resolution'][1]
+ # mkpts0_f and mkpts1_f
+ scale0 = scale * data['scale0'][data['b_ids']] if 'scale0' in data else torch.tensor([[scale, scale]], device=delta_l.device)
+ scale1 = scale * data['scale1'][data['b_ids']] if 'scale0' in data else torch.tensor([[scale, scale]], device=delta_l.device)
+ mkpts0_f = (data['mkpts0_c'][:,None,:] + (delta_l * scale0[:,None,:])[:len(data['mconf']),...]).reshape(-1, 2)
+ mkpts1_f = (data['mkpts1_c'][:,None,:] + (coord_normed * (c2f // 2) * scale1[:,None,:])[:len(data['mconf'])]).reshape(-1, 2)
+
+ data.update({
+ "mkpts0_f": mkpts0_f,
+ "mkpts1_f": mkpts1_f,
+ "mconf": data['mconf'][:,None].expand(-1, self.topk).reshape(-1)
+ })
+
+ @torch.no_grad()
+ def get_fine_ds_match(self, conf_matrix, data):
+ W, WW, C, scale = self.W, self.WW, self.C, self.scale
+
+ # select topk matches
+ m, _, _ = conf_matrix.shape
+
+
+ if self.mutual_nearest:
+ pass
+
+
+ elif not self.fix_fine_matching: # each position in window0 keeps only its single best match in window1 (positions in window1 may be matched more than once)
+
+ val, idx_r = conf_matrix.max(-1) # (m, WW), (m, WW)
+ val, idx_l = torch.topk(val, self.topk, dim = -1) # (m, topk), (m, topk)
+ idx_r = torch.gather(idx_r, 1, idx_l) # (m, topk)
+
+ # mkpts0_c uses xy coordinates, so no conversion to hw coordinates is needed
+ # grid = create_meshgrid(W, W, False, conf_matrix.device).transpose(-3,-2) - W // 2 + 0.5 # (1, W, W, 2)
+ grid = create_meshgrid(W, W, False, conf_matrix.device) - W // 2 + 0.5 # (1, W, W, 2)
+ grid = grid.reshape(1, -1, 2).expand(m, -1, -1) # (m, WW, 2)
+ delta_l = torch.gather(grid, 1, idx_l.unsqueeze(-1).expand(-1, -1, 2)) # (m, topk, 2)
+ delta_r = torch.gather(grid, 1, idx_r.unsqueeze(-1).expand(-1, -1, 2)) # (m, topk, 2)
+
+ # mkpts0_f and mkpts1_f
+ scale0 = scale * data['scale0'][data['b_ids']] if 'scale0' in data else scale
+ scale1 = scale * data['scale1'][data['b_ids']] if 'scale0' in data else scale
+
+ if torch.is_tensor(scale0) and scale0.numel() > 1: # num of scale0 > 1
+ mkpts0_f = (data['mkpts0_c'][:,None,:] + (delta_l * scale0[:,None,:])[:len(data['mconf']),...]).reshape(-1, 2)
+ mkpts1_f = (data['mkpts1_c'][:,None,:] + (delta_r * scale1[:,None,:])[:len(data['mconf']),...]).reshape(-1, 2)
+ else: # scale0 is a float
+ mkpts0_f = (data['mkpts0_c'][:,None,:] + (delta_l * scale0)[:len(data['mconf']),...]).reshape(-1, 2)
+ mkpts1_f = (data['mkpts1_c'][:,None,:] + (delta_r * scale1)[:len(data['mconf']),...]).reshape(-1, 2)
+
+ else: # allow one-to-many, many-to-one and many-to-many matches
+ conf_matrix = conf_matrix.reshape(m, -1)
+ if self.local_regress: # for compatibility with former configs
+ conf_matrix = conf_matrix[:len(data['mconf']),...]
+ val, idx = torch.topk(conf_matrix, self.topk, dim = -1)
+ idx_l = idx // WW
+ idx_r = idx % WW
+
+ if self.local_regress:
+ data.update({'idx_l': idx_l, 'idx_r': idx_r})
+
+ # mkpts0_c uses xy coordinates, so no conversion to hw coordinates is needed
+ # grid = create_meshgrid(W, W, False, conf_matrix.device).transpose(-3,-2) - W // 2 + 0.5 # (1, W, W, 2)
+ grid = create_meshgrid(W, W, False, conf_matrix.device) - W // 2 + 0.5
+ grid = grid.reshape(1, -1, 2).expand(m, -1, -1)
+ delta_l = torch.gather(grid, 1, idx_l.unsqueeze(-1).expand(-1, -1, 2))
+ delta_r = torch.gather(grid, 1, idx_r.unsqueeze(-1).expand(-1, -1, 2))
+
+ # mkpts0_f and mkpts1_f
+ scale0 = scale * data['scale0'][data['b_ids']] if 'scale0' in data else scale
+ scale1 = scale * data['scale1'][data['b_ids']] if 'scale0' in data else scale
+
+ if self.local_regress:
+ if torch.is_tensor(scale0) and scale0.numel() > 1: # num of scale0 > 1
+ mkpts0_f = (data['mkpts0_c'][:,None,:] + (delta_l * scale0[:len(data['mconf']),...][:,None,:])).reshape(-1, 2)
+ mkpts1_f = (data['mkpts1_c'][:,None,:] + (delta_r * scale1[:len(data['mconf']),...][:,None,:])).reshape(-1, 2)
+ else: # scale0 is a float
+ mkpts0_f = (data['mkpts0_c'][:,None,:] + (delta_l * scale0)).reshape(-1, 2)
+ mkpts1_f = (data['mkpts1_c'][:,None,:] + (delta_r * scale1)).reshape(-1, 2)
+
+ else:
+ if torch.is_tensor(scale0) and scale0.numel() > 1: # num of scale0 > 1
+ mkpts0_f = (data['mkpts0_c'][:,None,:] + (delta_l * scale0[:,None,:])[:len(data['mconf']),...]).reshape(-1, 2)
+ mkpts1_f = (data['mkpts1_c'][:,None,:] + (delta_r * scale1[:,None,:])[:len(data['mconf']),...]).reshape(-1, 2)
+ else: # scale0 is a float
+ mkpts0_f = (data['mkpts0_c'][:,None,:] + (delta_l * scale0)[:len(data['mconf']),...]).reshape(-1, 2)
+ mkpts1_f = (data['mkpts1_c'][:,None,:] + (delta_r * scale1)[:len(data['mconf']),...]).reshape(-1, 2)
+ del data['mkpts0_c'], data['mkpts1_c']
+ data['mconf'] = data['mconf'].reshape(-1, 1).expand(-1, self.topk).reshape(-1)
+ # data['mconf'] = val.reshape(-1)[:len(data['mconf'])]*0.1 + data['mconf']
+
+ if self.local_regress:
+ data.update({
+ "mkpts0_c": mkpts0_f,
+ "mkpts1_c": mkpts1_f
+ })
+ else:
+ data.update({
+ "mkpts0_f": mkpts0_f,
+ "mkpts1_f": mkpts1_f
+ })
+
diff --git a/imcui/third_party/MatchAnything/src/loftr/utils/geometry.py b/imcui/third_party/MatchAnything/src/loftr/utils/geometry.py
new file mode 100644
index 0000000000000000000000000000000000000000..47de76bd8d8928b123bc7357349b1e7ae4ee90ac
--- /dev/null
+++ b/imcui/third_party/MatchAnything/src/loftr/utils/geometry.py
@@ -0,0 +1,298 @@
+import torch
+from src.utils.homography_utils import warp_points_torch
+
+def get_unique_indices(input_tensor):
+ if input_tensor.shape[0] > 1:
+ unique, inverse = torch.unique(input_tensor, sorted=True, return_inverse=True, dim=0)
+ perm = torch.arange(inverse.size(0), dtype=inverse.dtype, device=inverse.device)
+ inverse, perm = inverse.flip([0]), perm.flip([0])
+ perm = inverse.new_empty(unique.size(0)).scatter_(0, inverse, perm)
+ else:
+ perm = torch.zeros((input_tensor.shape[0],), dtype=torch.long, device=input_tensor.device)
+ return perm
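+# Behaviour note (added): `get_unique_indices` returns, for each unique row of the input,
+# the index of its *first* occurrence. For example, input [[0, 2], [0, 1], [0, 2]] yields
+# unique rows [[0, 1], [0, 2]] and perm = [1, 0]; it is used below to keep only one match
+# per (batch, keypoint) pair.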
+
+
+@torch.no_grad()
+def warp_kpts(kpts0, depth0, depth1, T_0to1, K0, K1, consistency_thr=0.2, cycle_proj_distance_thr=3.0):
+ """ Warp kpts0 from I0 to I1 with depth, K and Rt
+ Also check covisibility and depth consistency.
+ Depth is consistent if the relative error is below `consistency_thr` (default 0.2).
+
+ Args:
+ kpts0 (torch.Tensor): [N, L, 2] - <x, y>,
+ depth0 (torch.Tensor): [N, H, W],
+ depth1 (torch.Tensor): [N, H, W],
+ T_0to1 (torch.Tensor): [N, 3, 4],
+ K0 (torch.Tensor): [N, 3, 3],
+ K1 (torch.Tensor): [N, 3, 3],
+ Returns:
+ calculable_mask (torch.Tensor): [N, L]
+ warped_keypoints0 (torch.Tensor): [N, L, 2]
+ """
+ kpts0_long = kpts0.round().long()
+
+ # Sample depth, get calculable_mask on depth != 0
+ kpts0_depth = torch.stack(
+ [depth0[i, kpts0_long[i, :, 1], kpts0_long[i, :, 0]] for i in range(kpts0.shape[0])], dim=0
+ ) # (N, L)
+ nonzero_mask = kpts0_depth != 0
+
+ # Unproject
+ kpts0_h = torch.cat([kpts0, torch.ones_like(kpts0[:, :, [0]])], dim=-1) * kpts0_depth[..., None] # (N, L, 3)
+ kpts0_cam = K0.inverse() @ kpts0_h.transpose(2, 1) # (N, 3, L)
+
+ # Rigid Transform
+ w_kpts0_cam = T_0to1[:, :3, :3] @ kpts0_cam + T_0to1[:, :3, [3]] # (N, 3, L)
+ w_kpts0_depth_computed = w_kpts0_cam[:, 2, :]
+
+ # Project
+ w_kpts0_h = (K1 @ w_kpts0_cam).transpose(2, 1) # (N, L, 3)
+ w_kpts0 = w_kpts0_h[:, :, :2] / (w_kpts0_h[:, :, [2]] + 1e-4) # (N, L, 2), +1e-4 to avoid zero depth
+
+ # Covisible Check
+ h, w = depth1.shape[1:3]
+ covisible_mask = (w_kpts0[:, :, 0] > 0) * (w_kpts0[:, :, 0] < w-1) * \
+ (w_kpts0[:, :, 1] > 0) * (w_kpts0[:, :, 1] < h-1)
+ w_kpts0_long = w_kpts0.long()
+ w_kpts0_long[~covisible_mask, :] = 0
+
+ w_kpts0_depth = torch.stack(
+ [depth1[i, w_kpts0_long[i, :, 1], w_kpts0_long[i, :, 0]] for i in range(w_kpts0_long.shape[0])], dim=0
+ ) # (N, L)
+ consistent_mask = ((w_kpts0_depth - w_kpts0_depth_computed) / w_kpts0_depth).abs() < consistency_thr
+
+ # Cycle Consistency Check
+ dst_pts_h = torch.cat([w_kpts0, torch.ones_like(w_kpts0[..., [0]], device=w_kpts0.device)], dim=-1) * w_kpts0_depth[..., None] # B * N_dst * N_pts * 3
+ dst_pts_cam = K1.inverse() @ dst_pts_h.transpose(2, 1) # (N, 3, L)
+ dst_pose = T_0to1.inverse()
+ world_points_cycle_back = dst_pose[:, :3, :3] @ dst_pts_cam + dst_pose[:, :3, [3]]
+ src_warp_back_h = (K0 @ world_points_cycle_back).transpose(2, 1) # (N, L, 3)
+ src_back_proj_pts = src_warp_back_h[..., :2] / (src_warp_back_h[..., [2]] + 1e-4)
+ cycle_reproj_distance_mask = torch.linalg.norm(src_back_proj_pts - kpts0[:, None], dim=-1) < cycle_proj_distance_thr
+
+ valid_mask = nonzero_mask * covisible_mask * consistent_mask * cycle_reproj_distance_mask
+
+ return valid_mask, w_kpts0
+
+@torch.no_grad()
+def warp_kpts_by_sparse_gt_matches_batches(kpts0, gt_matches, dist_thr):
+ B, n_pts = kpts0.shape[0], kpts0.shape[1]
+ if n_pts > 20 * 10000:
+ all_kpts_valid_mask, all_kpts_warpped = [], []
+ for b_id in range(B):
+ kpts_valid_mask, kpts_warpped = warp_kpts_by_sparse_gt_matches(kpts0[[b_id]], gt_matches[[b_id]], dist_thr[[b_id]])
+ all_kpts_valid_mask.append(kpts_valid_mask)
+ all_kpts_warpped.append(kpts_warpped)
+ return torch.cat(all_kpts_valid_mask, dim=0), torch.cat(all_kpts_warpped, dim=0)
+ else:
+ return warp_kpts_by_sparse_gt_matches(kpts0, gt_matches, dist_thr)
+
+@torch.no_grad()
+def warp_kpts_by_sparse_gt_matches(kpts0, gt_matches, dist_thr):
+ kpts_warpped = torch.zeros_like(kpts0)
+ kpts_valid_mask = torch.zeros_like(kpts0[..., 0], dtype=torch.bool)
+ gt_matches_non_padding_mask = gt_matches.sum(-1) > 0
+
+ dist_matrix = torch.cdist(kpts0, gt_matches[..., :2]) # B * N * M
+ if dist_thr is not None:
+ mask = dist_matrix < dist_thr[:, None, None]
+ else:
+ mask = torch.ones_like(dist_matrix, dtype=torch.bool)
+ # Mutual-Nearest check:
+ mask = mask \
+ * (dist_matrix == dist_matrix.min(dim=2, keepdim=True)[0]) \
+ * (dist_matrix == dist_matrix.min(dim=1, keepdim=True)[0])
+
+ mask_v, all_j_ids = mask.max(dim=2)
+ b_ids, i_ids = torch.where(mask_v)
+ j_ids = all_j_ids[b_ids, i_ids]
+
+ j_uq_indices = get_unique_indices(torch.stack([b_ids, j_ids], dim=-1))
+ b_ids, i_ids, j_ids = map(lambda x: x[j_uq_indices], [b_ids, i_ids, j_ids])
+
+ i_uq_indices = get_unique_indices(torch.stack([b_ids, i_ids], dim=-1))
+ b_ids, i_ids, j_ids = map(lambda x: x[i_uq_indices], [b_ids, i_ids, j_ids])
+
+ kpts_valid_mask[b_ids, i_ids] = gt_matches_non_padding_mask[b_ids, j_ids]
+ kpts_warpped[b_ids, i_ids] = gt_matches[..., 2:][b_ids, j_ids]
+
+ return kpts_valid_mask, kpts_warpped
+
+@torch.no_grad()
+def warp_kpts_by_sparse_gt_matches_fine_chunks(kpts0, gt_matches, dist_thr):
+ B, n_pts = kpts0.shape[0], kpts0.shape[1]
+ chunk_n = 500
+ all_kpts_valid_mask, all_kpts_warpped = [], []
+ for b_id in range(0, B, chunk_n):
+ kpts_valid_mask, kpts_warpped = warp_kpts_by_sparse_gt_matches_fine(kpts0[b_id : b_id+chunk_n], gt_matches, dist_thr)
+ all_kpts_valid_mask.append(kpts_valid_mask)
+ all_kpts_warpped.append(kpts_warpped)
+ return torch.cat(all_kpts_valid_mask, dim=0), torch.cat(all_kpts_warpped, dim=0)
+
+@torch.no_grad()
+def warp_kpts_by_sparse_gt_matches_fine(kpts0, gt_matches, dist_thr):
+ """
+ Only supports a single batch.
+ Input:
+ kpts0: N * ww * 2
+ gt_matches: 1 * M * 4, each row is (x0, y0, x1, y1)
+ """
+ B = kpts0.shape[0] # B is the fine matches in a single pair
+ assert gt_matches.shape[0] == 1
+ kpts_warpped = torch.zeros_like(kpts0)
+ kpts_valid_mask = torch.zeros_like(kpts0[..., 0], dtype=torch.bool)
+ gt_matches_non_padding_mask = gt_matches.sum(-1) > 0
+
+ dist_matrix = torch.cdist(kpts0, gt_matches[..., :2]) # B * N * M
+ if dist_thr is not None:
+ mask = dist_matrix < dist_thr[:, None, None]
+ else:
+ mask = torch.ones_like(dist_matrix, dtype=torch.bool)
+ # Mutual-Nearest check:
+ mask = mask \
+ * (dist_matrix == dist_matrix.min(dim=2, keepdim=True)[0]) \
+ * (dist_matrix == dist_matrix.min(dim=1, keepdim=True)[0])
+
+ mask_v, all_j_ids = mask.max(dim=2)
+ b_ids, i_ids = torch.where(mask_v)
+ j_ids = all_j_ids[b_ids, i_ids]
+
+ j_uq_indices = get_unique_indices(torch.stack([b_ids, j_ids], dim=-1))
+ b_ids, i_ids, j_ids = map(lambda x: x[j_uq_indices], [b_ids, i_ids, j_ids])
+
+ i_uq_indices = get_unique_indices(torch.stack([b_ids, i_ids], dim=-1))
+ b_ids, i_ids, j_ids = map(lambda x: x[i_uq_indices], [b_ids, i_ids, j_ids])
+
+ kpts_valid_mask[b_ids, i_ids] = gt_matches_non_padding_mask[0, j_ids]
+ kpts_warpped[b_ids, i_ids] = gt_matches[..., 2:][0, j_ids]
+
+ return kpts_valid_mask, kpts_warpped
+
+@torch.no_grad()
+def warp_kpts_by_sparse_gt_matches_fast(kpts0, gt_matches, scale0, current_h, current_w):
+ B, n_gt_pts = gt_matches.shape[0], gt_matches.shape[1]
+ kpts_warpped = torch.zeros_like(kpts0)
+ kpts_valid_mask = torch.zeros_like(kpts0[..., 0], dtype=torch.bool)
+ gt_matches_non_padding_mask = gt_matches.sum(-1) > 0
+
+ all_j_idxs = torch.arange(gt_matches.shape[-2], device=gt_matches.device, dtype=torch.long)[None].expand(B, n_gt_pts)
+ all_b_idxs = torch.arange(B, device=gt_matches.device, dtype=torch.long)[:, None].expand(B, n_gt_pts)
+ gt_matches_rescale = gt_matches[..., :2] / scale0 # From original img scale to resized scale
+ in_boundary_mask = (gt_matches_rescale[..., 0] <= current_w-1) & (gt_matches_rescale[..., 0] >= 0) & (gt_matches_rescale[..., 1] <= current_h -1) & (gt_matches_rescale[..., 1] >= 0)
+
+ gt_matches_rescale = gt_matches_rescale.round().to(torch.long)
+ all_i_idxs = gt_matches_rescale[..., 1] * current_w + gt_matches_rescale[..., 0] # idx = y * w + x
+
+ # Filter:
+ b_ids, i_ids, j_ids = map(lambda x: x[gt_matches_non_padding_mask & in_boundary_mask], [all_b_idxs, all_i_idxs, all_j_idxs])
+
+ j_uq_indices = get_unique_indices(torch.stack([b_ids, j_ids], dim=-1))
+ b_ids, i_ids, j_ids = map(lambda x: x[j_uq_indices], [b_ids, i_ids, j_ids])
+
+ i_uq_indices = get_unique_indices(torch.stack([b_ids, i_ids], dim=-1))
+ b_ids, i_ids, j_ids = map(lambda x: x[i_uq_indices], [b_ids, i_ids, j_ids])
+
+ kpts_valid_mask[b_ids, i_ids] = gt_matches_non_padding_mask[b_ids, j_ids]
+ kpts_warpped[b_ids, i_ids] = gt_matches[..., 2:][b_ids, j_ids]
+
+ return kpts_valid_mask, kpts_warpped
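+# Indexing note (added): `idx = y * w + x` above maps a rounded keypoint in the resized image
+# onto the flattened coarse grid; e.g. with current_w = 80, a ground-truth match at
+# (x=12, y=3) lands at flat index 3 * 80 + 12 = 252 (illustrative numbers only).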
+
+
+@torch.no_grad()
+def homo_warp_kpts(kpts0, norm_pixel_mat, homo_sample_normed, original_size0=None, original_size1=None):
+ """
+ original_size1: N * 2, (h, w)
+ """
+ normed_kpts0_h = norm_pixel_mat @ torch.cat([kpts0, torch.ones_like(kpts0[:, :, [0]])], dim=-1).transpose(2, 1) # (N * 3 * L)
+ kpts_warpped_h = (torch.linalg.inv(norm_pixel_mat) @ homo_sample_normed @ normed_kpts0_h).transpose(2, 1) # (N * L * 3)
+ kpts_warpped = kpts_warpped_h[..., :2] / kpts_warpped_h[..., [2]] # N * L * 2
+ valid_mask = (kpts_warpped[..., 0] > 0) & (kpts_warpped[..., 0] < original_size1[:, [1]]) & (kpts_warpped[..., 1] > 0) \
+ & (kpts_warpped[..., 1] < original_size1[:, [0]]) # N * L
+ if original_size0 is not None:
+ valid_mask *= (kpts0[..., 0] > 0) & (kpts0[..., 0] < original_size0[:, [1]]) & (kpts0[..., 1] > 0) \
+ & (kpts0[..., 1] < original_size0[:, [0]]) # N * L
+
+ return valid_mask, kpts_warpped
+
+@torch.no_grad()
+# if using mask in homo warp (for coarse supervision)
+def homo_warp_kpts_with_mask(kpts0, scale, depth_mask, norm_pixel_mat, homo_sample_normed, original_size0=None, original_size1=None):
+ """
+ original_size1: N * 2, (h, w)
+ """
+ normed_kpts0_h = norm_pixel_mat @ torch.cat([kpts0, torch.ones_like(kpts0[:, :, [0]])], dim=-1).transpose(2, 1) # (N * 3 * L)
+ kpts_warpped_h = (torch.linalg.inv(norm_pixel_mat) @ homo_sample_normed @ normed_kpts0_h).transpose(2, 1) # (N * L * 3)
+ kpts_warpped = kpts_warpped_h[..., :2] / kpts_warpped_h[..., [2]] # N * L * 2
+ # get coarse-level depth_mask
+ depth_mask_coarse = depth_mask[:, :, ::scale, ::scale]
+ depth_mask_coarse = depth_mask_coarse.reshape(depth_mask.shape[0], -1)
+
+ valid_mask = (kpts_warpped[..., 0] > 0) & (kpts_warpped[..., 0] < original_size1[:, [1]]) & (kpts_warpped[..., 1] > 0) \
+ & (kpts_warpped[..., 1] < original_size1[:, [0]]) & (depth_mask_coarse != 0) # N * L
+ if original_size0 is not None:
+ valid_mask *= (kpts0[..., 0] > 0) & (kpts0[..., 0] < original_size0[:, [1]]) & (kpts0[..., 1] > 0) \
+ & (kpts0[..., 1] < original_size0[:, [0]]) & (depth_mask_coarse != 0) # N * L
+
+ return valid_mask, kpts_warpped
+
+@torch.no_grad()
+# if using mask in homo warp (for fine supervision)
+def homo_warp_kpts_with_mask_f(kpts0, depth_mask, norm_pixel_mat, homo_sample_normed, original_size0=None, original_size1=None):
+ """
+ original_size1: N * 2, (h, w)
+ """
+ normed_kpts0_h = norm_pixel_mat @ torch.cat([kpts0, torch.ones_like(kpts0[:, :, [0]])], dim=-1).transpose(2, 1) # (N * 3 * L)
+ kpts_warpped_h = (torch.linalg.inv(norm_pixel_mat) @ homo_sample_normed @ normed_kpts0_h).transpose(2, 1) # (N * L * 3)
+ kpts_warpped = kpts_warpped_h[..., :2] / kpts_warpped_h[..., [2]] # N * L * 2
+ valid_mask = (kpts_warpped[..., 0] > 0) & (kpts_warpped[..., 0] < original_size1[:, [1]]) & (kpts_warpped[..., 1] > 0) \
+ & (kpts_warpped[..., 1] < original_size1[:, [0]]) & (depth_mask != 0) # N * L
+ if original_size0 is not None:
+ valid_mask *= (kpts0[..., 0] > 0) & (kpts0[..., 0] < original_size0[:, [1]]) & (kpts0[..., 1] > 0) \
+ & (kpts0[..., 1] < original_size0[:, [0]]) & (depth_mask != 0) # N * L
+
+ return valid_mask, kpts_warpped
+
+@torch.no_grad()
+def homo_warp_kpts_glue(kpts0, homo, original_size0=None, original_size1=None):
+ """
+ original_size1: N * 2, (h, w)
+ """
+ kpts_warpped = warp_points_torch(kpts0, homo, inverse=False)
+ valid_mask = (kpts_warpped[..., 0] > 0) & (kpts_warpped[..., 0] < original_size1[:, [1]]) & (kpts_warpped[..., 1] > 0) \
+ & (kpts_warpped[..., 1] < original_size1[:, [0]]) # N * L
+ if original_size0 is not None:
+ valid_mask *= (kpts0[..., 0] > 0) & (kpts0[..., 0] < original_size0[:, [1]]) & (kpts0[..., 1] > 0) \
+ & (kpts0[..., 1] < original_size0[:, [0]]) # N * L
+ return valid_mask, kpts_warpped
+
+@torch.no_grad()
+# if using mask in homo warp (for coarse supervision)
+def homo_warp_kpts_glue_with_mask(kpts0, scale, depth_mask, homo, original_size0=None, original_size1=None):
+ """
+ original_size1: N * 2, (h, w)
+ """
+ kpts_warpped = warp_points_torch(kpts0, homo, inverse=False)
+ # get coarse-level depth_mask
+ depth_mask_coarse = depth_mask[:, :, ::scale, ::scale]
+ depth_mask_coarse = depth_mask_coarse.reshape(depth_mask.shape[0], -1)
+
+ valid_mask = (kpts_warpped[..., 0] > 0) & (kpts_warpped[..., 0] < original_size1[:, [1]]) & (kpts_warpped[..., 1] > 0) \
+ & (kpts_warpped[..., 1] < original_size1[:, [0]]) & (depth_mask_coarse != 0) # N * L
+ if original_size0 is not None:
+ valid_mask *= (kpts0[..., 0] > 0) & (kpts0[..., 0] < original_size0[:, [1]]) & (kpts0[..., 1] > 0) \
+ & (kpts0[..., 1] < original_size0[:, [0]]) & (depth_mask_coarse != 0) # N * L
+ return valid_mask, kpts_warpped
+
+@torch.no_grad()
+# if using mask in homo warp (for fine supervision)
+def homo_warp_kpts_glue_with_mask_f(kpts0, depth_mask, homo, original_size0=None, original_size1=None):
+ """
+ original_size1: N * 2, (h, w)
+ """
+ kpts_warpped = warp_points_torch(kpts0, homo, inverse=False)
+ valid_mask = (kpts_warpped[..., 0] > 0) & (kpts_warpped[..., 0] < original_size1[:, [1]]) & (kpts_warpped[..., 1] > 0) \
+ & (kpts_warpped[..., 1] < original_size1[:, [0]]) & (depth_mask != 0) # N * L
+ if original_size0 is not None:
+ valid_mask *= (kpts0[..., 0] > 0) & (kpts0[..., 0] < original_size0[:, [1]]) & (kpts0[..., 1] > 0) \
+ & (kpts0[..., 1] < original_size0[:, [0]]) & (depth_mask != 0) # N * L
+ return valid_mask, kpts_warpped
\ No newline at end of file
diff --git a/imcui/third_party/MatchAnything/src/loftr/utils/position_encoding.py b/imcui/third_party/MatchAnything/src/loftr/utils/position_encoding.py
new file mode 100644
index 0000000000000000000000000000000000000000..e4a4b4780943617588cb193efb51a261ebc17cda
--- /dev/null
+++ b/imcui/third_party/MatchAnything/src/loftr/utils/position_encoding.py
@@ -0,0 +1,131 @@
+import math
+import torch
+from torch import nn
+
+
+class PositionEncodingSine(nn.Module):
+ """
+ This is a sinusoidal position encoding that is generalized to 2-dimensional images
+ """
+
+ def __init__(self, d_model, max_shape=(256, 256), temp_bug_fix=True, npe=False):
+ """
+ Args:
+ max_shape (tuple): for 1/8 featmap, the max length of 256 corresponds to 2048 pixels
+ temp_bug_fix (bool): As noted in this [issue](https://github.com/zju3dv/LoFTR/issues/41),
+ the original implementation of LoFTR includes a bug in the pos-enc impl, which has little impact
+ on the final performance. For now, we keep both impls for backward compatibility.
+ We will remove the buggy impl after re-training all variants of our released models.
+ """
+ super().__init__()
+
+ pe = torch.zeros((d_model, *max_shape))
+ y_position = torch.ones(max_shape).cumsum(0).float().unsqueeze(0)
+ x_position = torch.ones(max_shape).cumsum(1).float().unsqueeze(0)
+
+ assert npe is not None
+ if npe is not None:
+ if isinstance(npe, bool):
+ train_res_H, train_res_W, test_res_H, test_res_W = 832, 832, 832, 832
+ print('loftr no npe!!!!', npe)
+ else:
+ print('absnpe!!!!', npe)
+ train_res_H, train_res_W, test_res_H, test_res_W = npe[0], npe[1], npe[2], npe[3] # train_res_H, train_res_W, test_res_H, test_res_W
+ y_position, x_position = y_position * train_res_H / test_res_H, x_position * train_res_W / test_res_W
+
+ if temp_bug_fix:
+ div_term = torch.exp(torch.arange(0, d_model//2, 2).float() * (-math.log(10000.0) / (d_model//2)))
+ else: # a buggy implementation (for backward compatability only)
+ div_term = torch.exp(torch.arange(0, d_model//2, 2).float() * (-math.log(10000.0) / d_model//2))
+ div_term = div_term[:, None, None] # [C//4, 1, 1]
+ pe[0::4, :, :] = torch.sin(x_position * div_term)
+ pe[1::4, :, :] = torch.cos(x_position * div_term)
+ pe[2::4, :, :] = torch.sin(y_position * div_term)
+ pe[3::4, :, :] = torch.cos(y_position * div_term)
+
+ self.register_buffer('pe', pe.unsqueeze(0), persistent=False) # [1, C, H, W]
+
+ def forward(self, x):
+ """
+ Args:
+ x: [N, C, H, W]
+ """
+ return x + self.pe[:, :, :x.size(2), :x.size(3)]
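+ # Usage sketch (illustrative values, not from the original file): with d_model = 256 and the
+ # default max_shape, `PositionEncodingSine(256, npe=[832, 832, 832, 832])` precomputes a
+ # [1, 256, 256, 256] table once; forward then just slices it to [1, C, H, W] and adds it to
+ # the coarse feature map, e.g. pe(torch.randn(1, 256, 60, 80)).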
+
+class RoPEPositionEncodingSine(nn.Module):
+ """
+ This is a 2-dimensional rotary position encoding (RoPE) for image feature maps
+ """
+
+ def __init__(self, d_model, max_shape=(256, 256), npe=None, ropefp16=True):
+ """
+ Args:
+ max_shape (tuple): for 1/8 featmap, the max length of 256 corresponds to 2048 pixels
+ npe (list): [train_res_H, train_res_W, test_res_H, test_res_W]; positional indices are
+ rescaled when the test resolution differs from the training resolution
+ ropefp16 (bool): store the precomputed sin/cos tables in fp16
+ """
+ super().__init__()
+
+ # pe = torch.zeros((d_model, *max_shape))
+ # y_position = torch.ones(max_shape).cumsum(0).float().unsqueeze(-1)
+ # x_position = torch.ones(max_shape).cumsum(1).float().unsqueeze(-1)
+ i_position = torch.ones(max_shape).cumsum(0).float().unsqueeze(-1) # [H, 1]
+ j_position = torch.ones(max_shape).cumsum(1).float().unsqueeze(-1) # [W, 1]
+
+ assert npe is not None
+ if npe is not None:
+ train_res_H, train_res_W, test_res_H, test_res_W = npe[0], npe[1], npe[2], npe[3] # train_res_H, train_res_W, test_res_H, test_res_W
+ i_position, j_position = i_position * train_res_H / test_res_H, j_position * train_res_W / test_res_W
+
+ div_term = torch.exp(torch.arange(0, d_model//4, 1).float() * (-math.log(10000.0) / (d_model//4)))
+ div_term = div_term[None, None, :] # [1, 1, C//4]
+ # pe[0::4, :, :] = torch.sin(x_position * div_term)
+ # pe[1::4, :, :] = torch.cos(x_position * div_term)
+ # pe[2::4, :, :] = torch.sin(y_position * div_term)
+ # pe[3::4, :, :] = torch.cos(y_position * div_term)
+ sin = torch.zeros(*max_shape, d_model//2, dtype=torch.float16 if ropefp16 else torch.float32)
+ cos = torch.zeros(*max_shape, d_model//2, dtype=torch.float16 if ropefp16 else torch.float32)
+ sin[:, :, 0::2] = torch.sin(i_position * div_term).half() if ropefp16 else torch.sin(i_position * div_term)
+ sin[:, :, 1::2] = torch.sin(j_position * div_term).half() if ropefp16 else torch.sin(j_position * div_term)
+ cos[:, :, 0::2] = torch.cos(i_position * div_term).half() if ropefp16 else torch.cos(i_position * div_term)
+ cos[:, :, 1::2] = torch.cos(j_position * div_term).half() if ropefp16 else torch.cos(j_position * div_term)
+
+ sin = sin.repeat_interleave(2, dim=-1)
+ cos = cos.repeat_interleave(2, dim=-1)
+ # self.register_buffer('pe', pe.unsqueeze(0), persistent=False) # [1, H, W, C]
+ self.register_buffer('sin', sin.unsqueeze(0), persistent=False) # [1, H, W, C//2]
+ self.register_buffer('cos', cos.unsqueeze(0), persistent=False) # [1, H, W, C//2]
+
+ i_position4 = i_position.reshape(64,4,64,4,1)[...,0,:]
+ i_position4 = i_position4.mean(-3)
+ j_position4 = j_position.reshape(64,4,64,4,1)[:,0,...]
+ j_position4 = j_position4.mean(-2)
+ sin4 = torch.zeros(max_shape[0]//4, max_shape[1]//4, d_model//2, dtype=torch.float16 if ropefp16 else torch.float32)
+ cos4 = torch.zeros(max_shape[0]//4, max_shape[1]//4, d_model//2, dtype=torch.float16 if ropefp16 else torch.float32)
+ sin4[:, :, 0::2] = torch.sin(i_position4 * div_term).half() if ropefp16 else torch.sin(i_position4 * div_term)
+ sin4[:, :, 1::2] = torch.sin(j_position4 * div_term).half() if ropefp16 else torch.sin(j_position4 * div_term)
+ cos4[:, :, 0::2] = torch.cos(i_position4 * div_term).half() if ropefp16 else torch.cos(i_position4 * div_term)
+ cos4[:, :, 1::2] = torch.cos(j_position4 * div_term).half() if ropefp16 else torch.cos(j_position4 * div_term)
+ sin4 = sin4.repeat_interleave(2, dim=-1)
+ cos4 = cos4.repeat_interleave(2, dim=-1)
+ self.register_buffer('sin4', sin4.unsqueeze(0), persistent=False) # [1, H, W, C//2]
+ self.register_buffer('cos4', cos4.unsqueeze(0), persistent=False) # [1, H, W, C//2]
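+        # sin4/cos4 are the same rotary tables downsampled to 1/4 of max_shape (the reshape
+        # above assumes max_shape=(256, 256)); forward() selects them when ratio == 4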
+
+
+
+ def forward(self, x, ratio=1):
+ """
+ Args:
+ x: [N, H, W, C]
+ """
+ if ratio == 4:
+ return (x * self.cos4[:, :x.size(1), :x.size(2), :]) + (self.rotate_half(x) * self.sin4[:, :x.size(1), :x.size(2), :])
+ else:
+ return (x * self.cos[:, :x.size(1), :x.size(2), :]) + (self.rotate_half(x) * self.sin[:, :x.size(1), :x.size(2), :])
+
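+    # rotate_half pairs neighbouring channels (x1, x2) and maps them to (-x2, x1), so that
+    # x * cos + rotate_half(x) * sin applies the 2D rotation
+    #   [x1', x2'] = [x1*cos - x2*sin, x2*cos + x1*sin]
+    # to every channel pair, i.e. the standard RoPE rotation.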
+ def rotate_half(self, x):
+ x = x.unflatten(-1, (-1, 2))
+ x1, x2 = x.unbind(dim=-1)
+ return torch.stack((-x2, x1), dim=-1).flatten(start_dim=-2)
\ No newline at end of file
diff --git a/imcui/third_party/MatchAnything/src/loftr/utils/supervision.py b/imcui/third_party/MatchAnything/src/loftr/utils/supervision.py
new file mode 100644
index 0000000000000000000000000000000000000000..f57caa3a4b1498e31b5daca0289dd9381489f2b9
--- /dev/null
+++ b/imcui/third_party/MatchAnything/src/loftr/utils/supervision.py
@@ -0,0 +1,475 @@
+from math import log
+from loguru import logger as loguru_logger
+
+import torch
+import torch.nn.functional as F
+from einops import rearrange, repeat
+from kornia.utils import create_meshgrid
+
+from .geometry import warp_kpts, homo_warp_kpts, homo_warp_kpts_glue, homo_warp_kpts_with_mask, homo_warp_kpts_with_mask_f, homo_warp_kpts_glue_with_mask, homo_warp_kpts_glue_with_mask_f, warp_kpts_by_sparse_gt_matches_fast, warp_kpts_by_sparse_gt_matches_fine_chunks
+
+from kornia.geometry.subpix import dsnt
+from kornia.utils.grid import create_meshgrid
+
+def static_vars(**kwargs):
+ def decorate(func):
+ for k in kwargs:
+ setattr(func, k, kwargs[k])
+ return func
+ return decorate
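+# e.g. the @static_vars(counter=0) decoration below attaches a mutable attribute to the
+# decorated function, so spvs_fine.counter == 0 right after decoration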
+
+############## ↓ Coarse-Level supervision ↓ ##############
+
+@torch.no_grad()
+def mask_pts_at_padded_regions(grid_pt, mask):
+ """For megadepth dataset, zero-padding exists in images"""
+ mask = repeat(mask, 'n h w -> n (h w) c', c=2)
+ grid_pt[~mask.bool()] = 0
+ return grid_pt
+
+
+@torch.no_grad()
+def spvs_coarse(data, config):
+ """
+ Update:
+ data (dict): {
+ "conf_matrix_gt": [N, hw0, hw1],
+ 'spv_b_ids': [M]
+ 'spv_i_ids': [M]
+ 'spv_j_ids': [M]
+ 'spv_w_pt0_i': [N, hw0, 2], in original image resolution
+ 'spv_pt1_i': [N, hw1, 2], in original image resolution
+ }
+
+ NOTE:
+        - for the scannet dataset, there are 3 kinds of resolution {i, c, f}
+        - for the megadepth dataset, there are 4 kinds of resolution {i, i_resize, c, f}
+ """
+ # 1. misc
+ device = data['image0'].device
+ N, _, H0, W0 = data['image0'].shape
+ _, _, H1, W1 = data['image1'].shape
+
+ if 'loftr' in config.METHOD:
+ scale = config['LOFTR']['RESOLUTION'][0]
+
+ scale0 = scale * data['scale0'][:, None] if 'scale0' in data else scale
+ scale1 = scale * data['scale1'][:, None] if 'scale0' in data else scale
+ h0, w0, h1, w1 = map(lambda x: x // scale, [H0, W0, H1, W1])
+
+ if config['LOFTR']['MATCH_COARSE']['MTD_SPVS'] and not config['LOFTR']['FORCE_LOOP_BACK']:
+ # 2. warp grids
+ # create kpts in meshgrid and resize them to image resolution
+ grid_pt0_c = create_meshgrid(h0, w0, False, device).reshape(1, h0*w0, 2).repeat(N, 1, 1) # [N, hw, 2]
+ grid_pt0_i = scale0 * grid_pt0_c
+ grid_pt1_c = create_meshgrid(h1, w1, False, device).reshape(1, h1*w1, 2).repeat(N, 1, 1)
+ grid_pt1_i = scale1 * grid_pt1_c
+
+ correct_0to1 = torch.zeros((grid_pt0_i.shape[0], grid_pt0_i.shape[1]), dtype=torch.bool, device=grid_pt0_i.device)
+ w_pt0_i = torch.zeros_like(grid_pt0_i)
+
+ valid_dpt_b_mask = data['T_0to1'].sum(dim=-1).sum(dim=-1) != 0
+ valid_homo_warp_mask = (data['homography'].sum(dim=-1).sum(dim=-1) != 0) | (data['homo_sample_normed'].sum(dim=-1).sum(dim=-1) != 0)
+ valid_gt_match_warp_mask = (data['gt_matches_mask'][:, 0] != 0) # N
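+        # each batch element is supervised by whichever ground truth is available for it:
+        # a (normalized) homography, sparse gt matches, or depth maps + relative pose T_0to1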
+
+ if valid_homo_warp_mask.sum() != 0:
+ if data['homography'].sum()==0:
+                if 'homo_mask0' in data and (data['homo_mask0'].sum()!=0): # the key 'depth_mask' only exists when using the dataset "CommonDataSetHomoWarp"
+ correct_0to1_homo, w_pt0_i_homo = homo_warp_kpts_with_mask(grid_pt0_i[valid_homo_warp_mask], scale, data['homo_mask0'][valid_homo_warp_mask], data['norm_pixel_mat'][valid_homo_warp_mask], data['homo_sample_normed'][valid_homo_warp_mask], original_size1=data['origin_img_size1'][valid_homo_warp_mask])
+ else:
+ correct_0to1_homo, w_pt0_i_homo = homo_warp_kpts(grid_pt0_i[valid_homo_warp_mask], data['norm_pixel_mat'][valid_homo_warp_mask], \
+ data['homo_sample_normed'][valid_homo_warp_mask], original_size1=data['origin_img_size1'][valid_homo_warp_mask])
+ else:
+ if 'homo_mask0' in data and (data['homo_mask0']==0).sum()!=0:
+ correct_0to1_homo, w_pt0_i_homo = homo_warp_kpts_glue_with_mask(grid_pt0_i[valid_homo_warp_mask], scale, data['homo_mask0'][valid_homo_warp_mask], data['homography'][valid_homo_warp_mask], original_size1=data['origin_img_size1'][valid_homo_warp_mask])
+ else:
+ correct_0to1_homo, w_pt0_i_homo = homo_warp_kpts_glue(grid_pt0_i[valid_homo_warp_mask], data['homography'][valid_homo_warp_mask], \
+ original_size1=data['origin_img_size1'][valid_homo_warp_mask])
+ correct_0to1[valid_homo_warp_mask] = correct_0to1_homo
+ w_pt0_i[valid_homo_warp_mask] = w_pt0_i_homo
+
+ if valid_gt_match_warp_mask.sum() != 0:
+ correct_0to1_dpt, w_pt0_i_dpt = warp_kpts_by_sparse_gt_matches_fast(grid_pt0_i[valid_gt_match_warp_mask], data['gt_matches'][valid_gt_match_warp_mask], scale0=scale0[valid_gt_match_warp_mask], current_h=h0, current_w=w0)
+ correct_0to1[valid_gt_match_warp_mask] = correct_0to1_dpt
+ w_pt0_i[valid_gt_match_warp_mask] = w_pt0_i_dpt
+
+ if valid_dpt_b_mask.sum() != 0:
+ correct_0to1_dpt, w_pt0_i_dpt = warp_kpts(grid_pt0_i[valid_dpt_b_mask], data['depth0'][valid_dpt_b_mask], data['depth1'][valid_dpt_b_mask], data['T_0to1'][valid_dpt_b_mask], data['K0'][valid_dpt_b_mask], data['K1'][valid_dpt_b_mask], consistency_thr=0.05)
+ correct_0to1[valid_dpt_b_mask] = correct_0to1_dpt
+ w_pt0_i[valid_dpt_b_mask] = w_pt0_i_dpt
+
+ w_pt0_c = w_pt0_i / scale1
+
+ # 3. check if mutual nearest neighbor
+ w_pt0_c_round = w_pt0_c[:, :, :].round() # [N, hw, 2]
+ if config.LOFTR.LOSS.COARSE_OVERLAP_WEIGHT:
+ w_pt0_c_error = (1.0 - 2*torch.abs(w_pt0_c - w_pt0_c_round)).prod(-1)
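+            # per-point weight in [0, 1]: 1 when the warped point lands exactly on an integer
+            # coarse-grid location, 0 when either coordinate is halfway between two cells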
+ w_pt0_c_round = w_pt0_c_round.long() # [N, hw, 2]
+ nearest_index1 = w_pt0_c_round[..., 0] + w_pt0_c_round[..., 1] * w1 # [N, hw]
+
+ # corner case: out of boundary
+ def out_bound_mask(pt, w, h):
+ return (pt[..., 0] < 0) + (pt[..., 0] >= w) + (pt[..., 1] < 0) + (pt[..., 1] >= h)
+ nearest_index1[out_bound_mask(w_pt0_c_round, w1, h1)] = -1
+
+ correct_0to1[:, 0] = False # ignore the top-left corner
+
+ # 4. construct a gt conf_matrix
+ mask1 = torch.stack([data['mask1'].reshape(-1, h1*w1)[_b, _i] for _b, _i in enumerate(nearest_index1)], dim=0)
+ correct_0to1 = correct_0to1 * data['mask0'].reshape(-1, h0*w0) * mask1
+
+ conf_matrix_gt = torch.zeros(N, h0*w0, h1*w1, device=device, dtype=torch.bool)
+ b_ids, i_ids = torch.where(correct_0to1 != 0)
+ j_ids = nearest_index1[b_ids, i_ids]
+ valid_j_ids = j_ids != -1
+ b_ids, i_ids, j_ids = map(lambda x: x[valid_j_ids], [b_ids, i_ids, j_ids])
+
+ conf_matrix_gt[b_ids, i_ids, j_ids] = 1
+
+ # overlap weight
+ if config.LOFTR.LOSS.COARSE_OVERLAP_WEIGHT:
+ conf_matrix_error_gt = w_pt0_c_error[b_ids, i_ids]
+ assert torch.all(conf_matrix_error_gt >= -0.001)
+ assert torch.all(conf_matrix_error_gt <= 1.001)
+ data.update({'conf_matrix_error_gt': conf_matrix_error_gt})
+ data.update({'conf_matrix_gt': conf_matrix_gt})
+
+ # 5. save coarse matches(gt) for training fine level
+ if len(b_ids) == 0:
+ loguru_logger.warning(f"No groundtruth coarse match found for: {data['pair_names']}")
+ # this won't affect fine-level loss calculation
+ b_ids = torch.tensor([0], device=device)
+ i_ids = torch.tensor([0], device=device)
+ j_ids = torch.tensor([0], device=device)
+
+ data.update({
+ 'spv_b_ids': b_ids,
+ 'spv_i_ids': i_ids,
+ 'spv_j_ids': j_ids
+ })
+
+ data.update({'mkpts0_c_gt_b_ids': b_ids})
+ data.update({'mkpts0_c_gt': torch.stack([i_ids % w0, i_ids // w0], dim=-1) * scale0[b_ids, 0]})
+ data.update({'mkpts1_c_gt': torch.stack([j_ids % w1, j_ids // w1], dim=-1) * scale1[b_ids, 0]})
+
+ # 6. save intermediate results (for fast fine-level computation)
+ data.update({
+ 'spv_w_pt0_i': w_pt0_i,
+ 'spv_pt1_i': grid_pt1_i,
+ # 'correct_0to1_c': correct_0to1
+ })
+ else:
+ raise NotImplementedError
+
+def compute_supervision_coarse(data, config):
+ spvs_coarse(data, config)
+
+@torch.no_grad()
+def get_gt_flow(data, h, w):
+ device = data['image0'].device
+ B, _, H0, W0 = data['image0'].shape
+ scale = H0 / h
+
+ scale0 = scale * data['scale0'][:, None] if 'scale0' in data else scale
+ scale1 = scale * data['scale1'][:, None] if 'scale0' in data else scale
+
+ x1_n = torch.meshgrid(
+ *[
+ torch.linspace(
+ -1 + 1 / n, 1 - 1 / n, n, device=device
+ )
+ for n in (B, h, w)
+ ]
+ )
+ grid_coord = torch.stack((x1_n[2], x1_n[1]), dim=-1).reshape(B, h*w, 2) # normalized
+ grid_coord = torch.stack(
+ (w * (grid_coord[..., 0] + 1) / 2, h * (grid_coord[..., 1] + 1) / 2), dim=-1
+ ) # [-1+1/h, 1-1/h] -> [0.5, h-0.5]
+ grid_coord_in_origin = grid_coord * scale0
+
+ correct_0to1 = torch.zeros((grid_coord_in_origin.shape[0], grid_coord_in_origin.shape[1]), dtype=torch.bool, device=device)
+ w_pt0_i = torch.zeros_like(grid_coord_in_origin)
+
+ valid_dpt_b_mask = data['T_0to1'].sum(dim=-1).sum(dim=-1) != 0
+ valid_homo_warp_mask = (data['homography'].sum(dim=-1).sum(dim=-1) != 0) | (data['homo_sample_normed'].sum(dim=-1).sum(dim=-1) != 0)
+ valid_gt_match_warp_mask = (data['gt_matches_mask'] != 0)[:, 0]
+
+ if valid_homo_warp_mask.sum() != 0:
+ if data['homography'].sum()==0:
+ if 'homo_mask0' in data and (data['homo_mask0'].sum()!=0):
+ # data['load_mask'] = True or False, data['depth_mask'] = depth_mask or None
+ correct_0to1_homo, w_pt0_i_homo = homo_warp_kpts_with_mask(grid_coord_in_origin[valid_homo_warp_mask], int(scale), data['homo_mask0'][valid_homo_warp_mask], data['norm_pixel_mat'][valid_homo_warp_mask], \
+ data['homo_sample_normed'][valid_homo_warp_mask], original_size1=data['origin_img_size1'][valid_homo_warp_mask])
+ else:
+ correct_0to1_homo, w_pt0_i_homo = homo_warp_kpts(grid_coord_in_origin[valid_homo_warp_mask], data['norm_pixel_mat'][valid_homo_warp_mask], data['homo_sample_normed'][valid_homo_warp_mask], \
+ original_size1=data['origin_img_size1'][valid_homo_warp_mask])
+ else:
+ if 'homo_mask0' in data and (data['homo_mask0']==0).sum()!=0:
+ correct_0to1_homo, w_pt0_i_homo = homo_warp_kpts_glue_with_mask(grid_coord_in_origin[valid_homo_warp_mask], int(scale), data['homo_mask0'][valid_homo_warp_mask], data['homography'][valid_homo_warp_mask], \
+ original_size1=data['origin_img_size1'][valid_homo_warp_mask])
+ else:
+ correct_0to1_homo, w_pt0_i_homo = homo_warp_kpts_glue(grid_coord_in_origin[valid_homo_warp_mask], data['homography'][valid_homo_warp_mask], \
+ original_size1=data['origin_img_size1'][valid_homo_warp_mask])
+ correct_0to1[valid_homo_warp_mask] = correct_0to1_homo
+ w_pt0_i[valid_homo_warp_mask] = w_pt0_i_homo
+
+ if valid_gt_match_warp_mask.sum() != 0:
+ correct_0to1_dpt, w_pt0_i_dpt = warp_kpts_by_sparse_gt_matches_fast(grid_coord_in_origin[valid_gt_match_warp_mask], data['gt_matches'][valid_gt_match_warp_mask], scale0=scale0[valid_gt_match_warp_mask], current_h=h, current_w=w)
+ correct_0to1[valid_gt_match_warp_mask] = correct_0to1_dpt
+ w_pt0_i[valid_gt_match_warp_mask] = w_pt0_i_dpt
+ if valid_dpt_b_mask.sum() != 0:
+ correct_0to1_dpt, w_pt0_i_dpt = warp_kpts(grid_coord_in_origin[valid_dpt_b_mask], data['depth0'][valid_dpt_b_mask], data['depth1'][valid_dpt_b_mask], data['T_0to1'][valid_dpt_b_mask], data['K0'][valid_dpt_b_mask], data['K1'][valid_dpt_b_mask], consistency_thr=0.05)
+ correct_0to1[valid_dpt_b_mask] = correct_0to1_dpt
+ w_pt0_i[valid_dpt_b_mask] = w_pt0_i_dpt
+
+ w_pt0_c = w_pt0_i / scale1
+
+ def out_bound_mask(pt, w, h):
+ return (pt[..., 0] < 0) + (pt[..., 0] >= w) + (pt[..., 1] < 0) + (pt[..., 1] >= h)
+ correct_0to1[out_bound_mask(w_pt0_c, w, h)] = 0
+
+ w_pt0_n = torch.stack(
+ (2 * w_pt0_c[..., 0] / w - 1, 2 * w_pt0_c[..., 1] / h - 1), dim=-1
+ ) # from [0.5,h-0.5] -> [-1+1/h, 1-1/h]
+ # w_pt1_c = w_pt1_i / scale0
+
+ if scale > 8:
+ data.update({'mkpts0_c_gt': grid_coord_in_origin[correct_0to1]})
+ data.update({'mkpts1_c_gt': w_pt0_i[correct_0to1]})
+
+ return w_pt0_n.reshape(B, h, w, 2), correct_0to1.float().reshape(B, h, w)
+
+@torch.no_grad()
+def compute_roma_supervision(data, config):
+ gt_flow = {}
+ for scale in list(data["corresps"]):
+ scale_corresps = data["corresps"][scale]
+ flow_pre_delta = rearrange(scale_corresps['flow'] if 'flow'in scale_corresps else scale_corresps['dense_flow'], "b d h w -> b h w d")
+ b, h, w, d = flow_pre_delta.shape
+ gt_warp, gt_prob = get_gt_flow(data, h, w)
+ gt_flow[scale] = {'gt_warp': gt_warp, "gt_prob": gt_prob}
+
+ data.update({"gt": gt_flow})
+
+############## ↓ Fine-Level supervision ↓ ##############
+
+@static_vars(counter = 0)
+@torch.no_grad()
+def spvs_fine(data, config, logger = None):
+ """
+ Update:
+ data (dict):{
+ "expec_f_gt": [M, 2]}
+ """
+ # 1. misc
+ # w_pt0_i, pt1_i = data.pop('spv_w_pt0_i'), data.pop('spv_pt1_i')
+ if config.LOFTR.FINE.MTD_SPVS:
+ pt1_i = data['spv_pt1_i']
+ else:
+ spv_w_pt0_i, pt1_i = data['spv_w_pt0_i'], data['spv_pt1_i']
+ if 'loftr' in config.METHOD:
+ scale = config['LOFTR']['RESOLUTION'][1]
+ scale_c = config['LOFTR']['RESOLUTION'][0]
+ radius = config['LOFTR']['FINE_WINDOW_SIZE'] // 2
+
+ # 2. get coarse prediction
+ b_ids, i_ids, j_ids = data['b_ids'], data['i_ids'], data['j_ids']
+
+ # 3. compute gt
+ scalei0 = scale * data['scale0'][b_ids] if 'scale0' in data else scale
+ scale0 = scale * data['scale0'] if 'scale0' in data else scale
+ scalei1 = scale * data['scale1'][b_ids] if 'scale0' in data else scale
+
+ if config.LOFTR.FINE.MTD_SPVS:
+ W = config['LOFTR']['FINE_WINDOW_SIZE']
+ WW = W*W
+ device = data['image0'].device
+
+ N, _, H0, W0 = data['image0'].shape
+ _, _, H1, W1 = data['image1'].shape
+
+ if config.LOFTR.ALIGN_CORNER is False:
+ hf0, wf0, hf1, wf1 = data['hw0_f'][0], data['hw0_f'][1], data['hw1_f'][0], data['hw1_f'][1]
+ hc0, wc0, hc1, wc1 = data['hw0_c'][0], data['hw0_c'][1], data['hw1_c'][0], data['hw1_c'][1]
+ # loguru_logger.info('hf0, wf0, hf1, wf1', hf0, wf0, hf1, wf1)
+ else:
+ hf0, wf0, hf1, wf1 = map(lambda x: x // scale, [H0, W0, H1, W1])
+ hc0, wc0, hc1, wc1 = map(lambda x: x // scale_c, [H0, W0, H1, W1])
+
+ m = b_ids.shape[0]
+ if m == 0:
+ conf_matrix_f_gt = torch.zeros(m, WW, WW, device=device)
+
+ data.update({'conf_matrix_f_gt': conf_matrix_f_gt})
+ if config.LOFTR.LOSS.FINE_OVERLAP_WEIGHT:
+ conf_matrix_f_error_gt = torch.zeros(1, device=device)
+ data.update({'conf_matrix_f_error_gt': conf_matrix_f_error_gt})
+ if config.LOFTR.MATCH_FINE.MULTI_REGRESS:
+ data.update({'expec_f': torch.zeros(1, 3, device=device)})
+ data.update({'expec_f_gt': torch.zeros(1, 2, device=device)})
+
+ if config.LOFTR.MATCH_FINE.LOCAL_REGRESS:
+ data.update({'expec_f': torch.zeros(1, 2, device=device)})
+ data.update({'expec_f_gt': torch.zeros(1, 2, device=device)})
+ else:
+ grid_pt0_f = create_meshgrid(hf0, wf0, False, device) - W // 2 + 0.5 # [1, hf0, wf0, 2] # use fine coordinates
+ # grid_pt0_f = create_meshgrid(hf0, wf0, False, device) + 0.5 # [1, hf0, wf0, 2] # use fine coordinates
+ grid_pt0_f = rearrange(grid_pt0_f, 'n h w c -> n c h w')
+ # 1. unfold(crop) all local windows
+ if config.LOFTR.ALIGN_CORNER is False: # even windows
+ if config.LOFTR.MATCH_FINE.MULTI_REGRESS or (config.LOFTR.MATCH_FINE.LOCAL_REGRESS and W == 10):
+ grid_pt0_f_unfold = F.unfold(grid_pt0_f, kernel_size=(W, W), stride=W-2, padding=1) # overlap windows W-2 padding=1
+ else:
+ grid_pt0_f_unfold = F.unfold(grid_pt0_f, kernel_size=(W, W), stride=W, padding=0)
+ else:
+ grid_pt0_f_unfold = F.unfold(grid_pt0_f[..., :-1, :-1], kernel_size=(W, W), stride=W, padding=W//2)
+ grid_pt0_f_unfold = rearrange(grid_pt0_f_unfold, 'n (c ww) l -> n l ww c', ww=W**2) # [1, hc0*wc0, W*W, 2]
+ grid_pt0_f_unfold = repeat(grid_pt0_f_unfold[0], 'l ww c -> N l ww c', N=N)
+
+ # 2. select only the predicted matches
+ grid_pt0_f_unfold = grid_pt0_f_unfold[data['b_ids'], data['i_ids']] # [m, ww, 2]
+ grid_pt0_f_unfold = scalei0[:,None,:] * grid_pt0_f_unfold # [m, ww, 2]
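+            # grid_pt0_f_unfold now holds, for every predicted coarse match, the W*W fine-grid
+            # point coordinates of its local window, expressed in original image-0 pixels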
+
+ # use depth mask
+ if 'homo_mask0' in data and (data['homo_mask0'].sum()!=0):
+ # depth_mask --> (n, 1, hf, wf)
+ homo_mask0 = data['homo_mask0']
+ homo_mask0 = F.unfold(homo_mask0[..., :-1, :-1], kernel_size=(W, W), stride=W, padding=W//2)
+ homo_mask0 = rearrange(homo_mask0, 'n (c ww) l -> n l ww c', ww=W**2) # [1, hc0*wc0, W*W, 1]
+ homo_mask0 = repeat(homo_mask0[0], 'l ww c -> N l ww c', N=N)
+ # select only the predicted matches
+ homo_mask0 = homo_mask0[data['b_ids'], data['i_ids']]
+
+ correct_0to1_f_list, w_pt0_i_list = [], []
+
+ correct_0to1_f = torch.zeros(m, WW, device=device, dtype=torch.bool)
+ w_pt0_i = torch.zeros(m, WW, 2, device=device, dtype=torch.float32)
+ for b in range(N):
+ mask = b_ids == b
+
+ match = int(mask.sum())
+ skip_reshape = False
+ if match == 0:
+ print(f"no pred fine matches, skip!")
+ continue
+ if (data['homography'][b].sum() != 0) | (data['homo_sample_normed'][b].sum() != 0):
+ if data['homography'][b].sum()==0:
+ if 'homo_mask0' in data and (data['homo_mask0'].sum()!=0):
+ correct_0to1_f_mask, w_pt0_i_mask = homo_warp_kpts_with_mask_f(grid_pt0_f_unfold[mask].reshape(1,-1,2), homo_mask0[mask].reshape(1,-1), data['norm_pixel_mat'][[b]], \
+ data['homo_sample_normed'][[b]], data['origin_img_size0'][[b]], data['origin_img_size1'][[b]])
+ else:
+ correct_0to1_f_mask, w_pt0_i_mask = homo_warp_kpts(grid_pt0_f_unfold[mask].reshape(1,-1,2), data['norm_pixel_mat'][[b]], \
+ data['homo_sample_normed'][[b]], data['origin_img_size0'][[b]], data['origin_img_size1'][[b]])
+ else:
+ if 'homo_mask0' in data and (data['homo_mask0'].sum()!=0):
+ correct_0to1_f_mask, w_pt0_i_mask = homo_warp_kpts_glue_with_mask_f(grid_pt0_f_unfold[mask].reshape(1,-1,2), homo_mask0[mask].reshape(1,-1), data['homography'][[b]], \
+ data['origin_img_size0'][[b]], data['origin_img_size1'][[b]])
+ else:
+ correct_0to1_f_mask, w_pt0_i_mask = homo_warp_kpts_glue(grid_pt0_f_unfold[mask].reshape(1,-1,2), data['homography'][[b]], \
+ data['origin_img_size0'][[b]], data['origin_img_size1'][[b]])
+ elif data['T_0to1'][b].sum() != 0:
+ correct_0to1_f_mask, w_pt0_i_mask = warp_kpts(grid_pt0_f_unfold[mask].reshape(1,-1,2), data['depth0'][[b],...],
+ data['depth1'][[b],...], data['T_0to1'][[b],...],
+ data['K0'][[b],...], data['K1'][[b],...]) # [k, WW], [k, WW, 2]
+ elif data['gt_matches_mask'][b].sum() != 0:
+ correct_0to1_f_mask, w_pt0_i_mask = warp_kpts_by_sparse_gt_matches_fine_chunks(grid_pt0_f_unfold[mask], gt_matches=data['gt_matches'][[b]], dist_thr=scale0[[b]].max(dim=-1)[0])
+ skip_reshape = True
+ correct_0to1_f[mask] = correct_0to1_f_mask.reshape(match, WW) if not skip_reshape else correct_0to1_f_mask
+ w_pt0_i[mask] = w_pt0_i_mask.reshape(match, WW, 2) if not skip_reshape else w_pt0_i_mask
+
+ delta_w_pt0_i = w_pt0_i - pt1_i[b_ids, j_ids][:,None,:] # [m, WW, 2]
+ delta_w_pt0_f = delta_w_pt0_i / scalei1[:,None,:] + W // 2 - 0.5
+ delta_w_pt0_f_round = delta_w_pt0_f[:, :, :].round()
+ if config.LOFTR.LOSS.FINE_OVERLAP_WEIGHT and config.LOFTR.LOSS.FINE_OVERLAP_WEIGHT2:
+ w_pt0_f_error = (1.0 - torch.abs(delta_w_pt0_f - delta_w_pt0_f_round)).prod(-1) # [0.25, 1]
+ elif config.LOFTR.LOSS.FINE_OVERLAP_WEIGHT:
+ w_pt0_f_error = (1.0 - 2*torch.abs(delta_w_pt0_f - delta_w_pt0_f_round)).prod(-1) # [0, 1]
+ delta_w_pt0_f_round = delta_w_pt0_f_round.long()
+
+
+ nearest_index1 = delta_w_pt0_f_round[..., 0] + delta_w_pt0_f_round[..., 1] * W # [m, WW]
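+            # flatten the rounded local offset (dx, dy) into an index in [0, W*W) within the
+            # fine window of image 1; out-of-window points are masked out just below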
+
+ def out_bound_mask(pt, w, h):
+ return (pt[..., 0] < 0) + (pt[..., 0] >= w) + (pt[..., 1] < 0) + (pt[..., 1] >= h)
+ ob_mask = out_bound_mask(delta_w_pt0_f_round, W, W)
+ nearest_index1[ob_mask] = 0
+
+ correct_0to1_f[ob_mask] = 0
+ m_ids_d, i_ids_d = torch.where(correct_0to1_f != 0)
+
+ j_ids_d = nearest_index1[m_ids_d, i_ids_d]
+
+ # For plotting:
+ mkpts0_f_gt = grid_pt0_f_unfold[m_ids_d, i_ids_d] # [m, 2]
+ mkpts1_f_gt = w_pt0_i[m_ids_d, i_ids_d] # [m, 2]
+ data.update({'mkpts0_f_gt_b_ids': m_ids_d})
+ data.update({'mkpts0_f_gt': mkpts0_f_gt})
+ data.update({'mkpts1_f_gt': mkpts1_f_gt})
+
+ if config.LOFTR.MATCH_FINE.MULTI_REGRESS:
+ assert not config.LOFTR.MATCH_FINE.LOCAL_REGRESS
+            expec_f_gt = delta_w_pt0_f - W // 2 + 0.5 # use delta (e.g. [-3.5, 3.5]) in regression rather than [0, W] (e.g. [0, 7])
+            expec_f_gt = expec_f_gt[m_ids_d, i_ids_d] / (W // 2 - 1) # specific radius for overlapped even windows & align_corner=False
+ data.update({'expec_f_gt': expec_f_gt})
+ data.update({'m_ids_d': m_ids_d, 'i_ids_d': i_ids_d})
+ else: # spv fine dual softmax
+ if config.LOFTR.MATCH_FINE.LOCAL_REGRESS:
+ expec_f_gt = delta_w_pt0_f - delta_w_pt0_f_round
+
+                # mask fine window borders
+ j_ids_d_il, j_ids_d_jl = j_ids_d // W, j_ids_d % W
+ if config.LOFTR.MATCH_FINE.LOCAL_REGRESS_NOMASK:
+ mask = None
+ m_ids_dl, i_ids_dl, j_ids_d_il, j_ids_d_jl = m_ids_d.to(torch.long), i_ids_d.to(torch.long), j_ids_d_il.to(torch.long), j_ids_d_jl.to(torch.long)
+ else:
+ mask = (j_ids_d_il >= 1) & (j_ids_d_il < W-1) & (j_ids_d_jl >= 1) & (j_ids_d_jl < W-1)
+ if W == 10:
+ i_ids_d_il, i_ids_d_jl = i_ids_d // W, i_ids_d % W
+ mask = mask & (i_ids_d_il >= 1) & (i_ids_d_il <= W-2) & (i_ids_d_jl >= 1) & (i_ids_d_jl <= W-2)
+
+ m_ids_dl, i_ids_dl, j_ids_d_il, j_ids_d_jl = m_ids_d[mask].to(torch.long), i_ids_d[mask].to(torch.long), j_ids_d_il[mask].to(torch.long), j_ids_d_jl[mask].to(torch.long)
+ if mask is not None:
+ loguru_logger.info(f'percent of gt mask.sum / mask.numel: {mask.sum().float()/mask.numel():.2f}')
+ if m_ids_dl.numel() == 0:
+ loguru_logger.warning(f"No groundtruth fine match found for local regress: {data['pair_names']}")
+ data.update({'expec_f_gt': torch.zeros(1, 2, device=device)})
+ data.update({'expec_f': torch.zeros(1, 2, device=device)})
+ else:
+ expec_f_gt = expec_f_gt[m_ids_dl, i_ids_dl]
+ data.update({"expec_f_gt": expec_f_gt})
+
+ data.update({"m_ids_dl": m_ids_dl,
+ "i_ids_dl": i_ids_dl,
+ "j_ids_d_il": j_ids_d_il,
+ "j_ids_d_jl": j_ids_d_jl
+ })
+ else: # no fine regress
+ pass
+
+ # spv fine dual softmax
+ conf_matrix_f_gt = torch.zeros(m, WW, WW, device=device, dtype=torch.bool)
+ conf_matrix_f_gt[m_ids_d, i_ids_d, j_ids_d] = 1
+ data.update({'conf_matrix_f_gt': conf_matrix_f_gt})
+ if config.LOFTR.LOSS.FINE_OVERLAP_WEIGHT:
+ w_pt0_f_error = w_pt0_f_error[m_ids_d, i_ids_d]
+ assert torch.all(w_pt0_f_error >= -0.001)
+ assert torch.all(w_pt0_f_error <= 1.001)
+ data.update({'conf_matrix_f_error_gt': w_pt0_f_error})
+
+ conf_matrix_f_gt_sum = conf_matrix_f_gt.sum()
+ if conf_matrix_f_gt_sum != 0:
+ pass
+ else:
+ loguru_logger.info(f'[no gt plot]no fine matches to supervise')
+
+ else:
+ expec_f_gt = (spv_w_pt0_i[b_ids, i_ids] - pt1_i[b_ids, j_ids]) / scalei1 / 4 # [M, 2]
+ data.update({"expec_f_gt": expec_f_gt})
+
+
+def compute_supervision_fine(data, config, logger=None):
+ data_source = data['dataset_name'][0]
+ if data_source.lower() in ['scannet', 'megadepth']:
+ spvs_fine(data, config, logger)
+ else:
+ raise NotImplementedError
\ No newline at end of file
diff --git a/imcui/third_party/MatchAnything/src/optimizers/__init__.py b/imcui/third_party/MatchAnything/src/optimizers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2c946086518db86e6775774eca36d67188c8b657
--- /dev/null
+++ b/imcui/third_party/MatchAnything/src/optimizers/__init__.py
@@ -0,0 +1,50 @@
+import torch
+from torch.optim.lr_scheduler import MultiStepLR, CosineAnnealingLR, ExponentialLR
+
+
+def build_optimizer(model, config):
+ name = config.TRAINER.OPTIMIZER
+ lr = config.TRAINER.TRUE_LR
+
+ if name == "adam":
+ return torch.optim.Adam(model.parameters(), lr=lr, weight_decay=config.TRAINER.ADAM_DECAY, eps=config.TRAINER.OPTIMIZER_EPS)
+ elif name == "adamw":
+ if ("ROMA" in config.METHOD) or ("DKM" in config.METHOD):
+            # Split parameters into backbone and non-backbone groups:
+ keyword = 'model.encoder'
+ backbone_params = [param for name, param in list(filter(lambda kv: keyword in kv[0], model.named_parameters()))]
+ base_params = [param for name, param in list(filter(lambda kv: keyword not in kv[0], model.named_parameters()))]
+ params = [{'params': backbone_params, 'lr': lr * 0.05}, {'params': base_params}]
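+            # NOTE: the `params` groups (with the reduced backbone learning rate) are constructed
+            # but not passed to AdamW below, so all parameters currently use the base learning rate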
+ return torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=config.TRAINER.ADAMW_DECAY, eps=config.TRAINER.OPTIMIZER_EPS)
+ else:
+ return torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=config.TRAINER.ADAMW_DECAY, eps=config.TRAINER.OPTIMIZER_EPS)
+ else:
+ raise ValueError(f"TRAINER.OPTIMIZER = {name} is not a valid optimizer!")
+
+
+def build_scheduler(config, optimizer):
+ """
+ Returns:
+ scheduler (dict):{
+ 'scheduler': lr_scheduler,
+ 'interval': 'step', # or 'epoch'
+ 'monitor': 'val_f1', (optional)
+ 'frequency': x, (optional)
+ }
+ """
+ scheduler = {'interval': config.TRAINER.SCHEDULER_INTERVAL}
+ name = config.TRAINER.SCHEDULER
+
+ if name == 'MultiStepLR':
+ scheduler.update(
+ {'scheduler': MultiStepLR(optimizer, config.TRAINER.MSLR_MILESTONES, gamma=config.TRAINER.MSLR_GAMMA)})
+ elif name == 'CosineAnnealing':
+ scheduler.update(
+ {'scheduler': CosineAnnealingLR(optimizer, config.TRAINER.COSA_TMAX)})
+ elif name == 'ExponentialLR':
+ scheduler.update(
+ {'scheduler': ExponentialLR(optimizer, config.TRAINER.ELR_GAMMA)})
+ else:
+ raise NotImplementedError()
+
+ return scheduler
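+
+# e.g. with PyTorch Lightning, the returned objects can be used directly (a usage sketch,
+# assuming the training LightningModule calls these helpers in configure_optimizers):
+#   optimizer = build_optimizer(model, config)
+#   scheduler = build_scheduler(config, optimizer)
+#   return [optimizer], [scheduler]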
diff --git a/imcui/third_party/MatchAnything/src/utils/__init__.py b/imcui/third_party/MatchAnything/src/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/imcui/third_party/MatchAnything/src/utils/augment.py b/imcui/third_party/MatchAnything/src/utils/augment.py
new file mode 100644
index 0000000000000000000000000000000000000000..d7c5d3e11b6fe083aaeff7555bb7ce3a4bfb755d
--- /dev/null
+++ b/imcui/third_party/MatchAnything/src/utils/augment.py
@@ -0,0 +1,55 @@
+import albumentations as A
+
+
+class DarkAug(object):
+ """
+    Extreme dark augmentation aimed at Aachen Day-Night
+ """
+
+ def __init__(self) -> None:
+ self.augmentor = A.Compose([
+ A.RandomBrightnessContrast(p=0.75, brightness_limit=(-0.6, 0.0), contrast_limit=(-0.5, 0.3)),
+ A.Blur(p=0.1, blur_limit=(3, 9)),
+ A.MotionBlur(p=0.2, blur_limit=(3, 25)),
+ A.RandomGamma(p=0.1, gamma_limit=(15, 65)),
+ A.HueSaturationValue(p=0.1, val_shift_limit=(-100, -40))
+ ], p=0.75)
+
+ def __call__(self, x):
+ return self.augmentor(image=x)['image']
+
+
+class MobileAug(object):
+ """
+    Random augmentations aimed at images from mobile/handheld devices.
+ """
+
+ def __init__(self):
+ self.augmentor = A.Compose([
+ A.MotionBlur(p=0.25),
+ A.ColorJitter(p=0.5),
+ A.RandomRain(p=0.1), # random occlusion
+ A.RandomSunFlare(p=0.1),
+ A.JpegCompression(p=0.25),
+ A.ISONoise(p=0.25)
+ ], p=1.0)
+
+ def __call__(self, x):
+ return self.augmentor(image=x)['image']
+
+
+def build_augmentor(method=None, **kwargs):
+ if method is not None:
+        raise NotImplementedError('Use of augmentation functions is not supported yet!')
+ if method == 'dark':
+ return DarkAug()
+ elif method == 'mobile':
+ return MobileAug()
+ elif method is None:
+ return None
+ else:
+ raise ValueError(f'Invalid augmentation method: {method}')
+
+
+if __name__ == '__main__':
+ augmentor = build_augmentor('FDA')
diff --git a/imcui/third_party/MatchAnything/src/utils/colmap.py b/imcui/third_party/MatchAnything/src/utils/colmap.py
new file mode 100644
index 0000000000000000000000000000000000000000..deefe92a3a8e132e3d5d51b8eaf08b1050e22aac
--- /dev/null
+++ b/imcui/third_party/MatchAnything/src/utils/colmap.py
@@ -0,0 +1,530 @@
+# Copyright (c) 2022, ETH Zurich and UNC Chapel Hill.
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+#
+# * Redistributions in binary form must reproduce the above copyright
+# notice, this list of conditions and the following disclaimer in the
+# documentation and/or other materials provided with the distribution.
+#
+# * Neither the name of ETH Zurich and UNC Chapel Hill nor the names of
+# its contributors may be used to endorse or promote products derived
+# from this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE
+# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+# POSSIBILITY OF SUCH DAMAGE.
+#
+# Author: Johannes L. Schoenberger (jsch-at-demuc-dot-de)
+
+from typing import List, Tuple, Dict
+import os
+import collections
+import numpy as np
+import struct
+import argparse
+
+
+CameraModel = collections.namedtuple(
+ "CameraModel", ["model_id", "model_name", "num_params"])
+BaseCamera = collections.namedtuple(
+ "Camera", ["id", "model", "width", "height", "params"])
+BaseImage = collections.namedtuple(
+ "Image", ["id", "qvec", "tvec", "camera_id", "name", "xys", "point3D_ids"])
+Point3D = collections.namedtuple(
+ "Point3D", ["id", "xyz", "rgb", "error", "image_ids", "point2D_idxs"])
+
+
+class Image(BaseImage):
+ def qvec2rotmat(self):
+ return qvec2rotmat(self.qvec)
+
+ @property
+ def world_to_camera(self) -> np.ndarray:
+ R = qvec2rotmat(self.qvec)
+ t = self.tvec
+ world2cam = np.eye(4)
+ world2cam[:3, :3] = R
+ world2cam[:3, 3] = t
+ return world2cam
+
+
+class Camera(BaseCamera):
+ @property
+ def K(self):
+ K = np.eye(3)
+ if self.model == "SIMPLE_PINHOLE" or self.model == "SIMPLE_RADIAL" or self.model == "RADIAL" or self.model == "SIMPLE_RADIAL_FISHEYE" or self.model == "RADIAL_FISHEYE":
+ K[0, 0] = self.params[0]
+ K[1, 1] = self.params[0]
+ K[0, 2] = self.params[1]
+ K[1, 2] = self.params[2]
+ elif self.model == "PINHOLE" or self.model == "OPENCV" or self.model == "OPENCV_FISHEYE" or self.model == "FULL_OPENCV" or self.model == "FOV" or self.model == "THIN_PRISM_FISHEYE":
+ K[0, 0] = self.params[0]
+ K[1, 1] = self.params[1]
+ K[0, 2] = self.params[2]
+ K[1, 2] = self.params[3]
+ else:
+ raise NotImplementedError
+ return K
+
+
+CAMERA_MODELS = {
+ CameraModel(model_id=0, model_name="SIMPLE_PINHOLE", num_params=3),
+ CameraModel(model_id=1, model_name="PINHOLE", num_params=4),
+ CameraModel(model_id=2, model_name="SIMPLE_RADIAL", num_params=4),
+ CameraModel(model_id=3, model_name="RADIAL", num_params=5),
+ CameraModel(model_id=4, model_name="OPENCV", num_params=8),
+ CameraModel(model_id=5, model_name="OPENCV_FISHEYE", num_params=8),
+ CameraModel(model_id=6, model_name="FULL_OPENCV", num_params=12),
+ CameraModel(model_id=7, model_name="FOV", num_params=5),
+ CameraModel(model_id=8, model_name="SIMPLE_RADIAL_FISHEYE", num_params=4),
+ CameraModel(model_id=9, model_name="RADIAL_FISHEYE", num_params=5),
+ CameraModel(model_id=10, model_name="THIN_PRISM_FISHEYE", num_params=12)
+}
+CAMERA_MODEL_IDS = dict([(camera_model.model_id, camera_model)
+ for camera_model in CAMERA_MODELS])
+CAMERA_MODEL_NAMES = dict([(camera_model.model_name, camera_model)
+ for camera_model in CAMERA_MODELS])
+
+
+def read_next_bytes(fid, num_bytes, format_char_sequence, endian_character="<"):
+ """Read and unpack the next bytes from a binary file.
+ :param fid:
+ :param num_bytes: Sum of combination of {2, 4, 8}, e.g. 2, 6, 16, 30, etc.
+ :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}.
+ :param endian_character: Any of {@, =, <, >, !}
+ :return: Tuple of read and unpacked values.
+ """
+ data = fid.read(num_bytes)
+ return struct.unpack(endian_character + format_char_sequence, data)
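+    # e.g. read_next_bytes(fid, 24, "iiQQ") -> (camera_id, model_id, width, height) for one
+    # camera record in cameras.bin (see read_cameras_binary below)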
+
+
+def write_next_bytes(fid, data, format_char_sequence, endian_character="<"):
+ """pack and write to a binary file.
+ :param fid:
+ :param data: data to send, if multiple elements are sent at the same time,
+        they should be encapsulated either in a list or a tuple
+ :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}.
+ should be the same length as the data list or tuple
+ :param endian_character: Any of {@, =, <, >, !}
+ """
+ if isinstance(data, (list, tuple)):
+ bytes = struct.pack(endian_character + format_char_sequence, *data)
+ else:
+ bytes = struct.pack(endian_character + format_char_sequence, data)
+ fid.write(bytes)
+
+
+def read_cameras_text(path):
+ """
+ see: src/base/reconstruction.cc
+ void Reconstruction::WriteCamerasText(const std::string& path)
+ void Reconstruction::ReadCamerasText(const std::string& path)
+ """
+ cameras = {}
+ with open(path, "r") as fid:
+ while True:
+ line = fid.readline()
+ if not line:
+ break
+ line = line.strip()
+ if len(line) > 0 and line[0] != "#":
+ elems = line.split()
+ camera_id = int(elems[0])
+ model = elems[1]
+ width = int(elems[2])
+ height = int(elems[3])
+ params = np.array(tuple(map(float, elems[4:])))
+ cameras[camera_id] = Camera(id=camera_id, model=model,
+ width=width, height=height,
+ params=params)
+ return cameras
+
+
+def read_cameras_binary(path_to_model_file):
+ """
+ see: src/base/reconstruction.cc
+ void Reconstruction::WriteCamerasBinary(const std::string& path)
+ void Reconstruction::ReadCamerasBinary(const std::string& path)
+ """
+ cameras = {}
+ with open(path_to_model_file, "rb") as fid:
+ num_cameras = read_next_bytes(fid, 8, "Q")[0]
+ for _ in range(num_cameras):
+ camera_properties = read_next_bytes(
+ fid, num_bytes=24, format_char_sequence="iiQQ")
+ camera_id = camera_properties[0]
+ model_id = camera_properties[1]
+ model_name = CAMERA_MODEL_IDS[camera_properties[1]].model_name
+ width = camera_properties[2]
+ height = camera_properties[3]
+ num_params = CAMERA_MODEL_IDS[model_id].num_params
+ params = read_next_bytes(fid, num_bytes=8*num_params,
+ format_char_sequence="d"*num_params)
+ cameras[camera_id] = Camera(id=camera_id,
+ model=model_name,
+ width=width,
+ height=height,
+ params=np.array(params))
+ assert len(cameras) == num_cameras
+ return cameras
+
+
+def write_cameras_text(cameras, path):
+ """
+ see: src/base/reconstruction.cc
+ void Reconstruction::WriteCamerasText(const std::string& path)
+ void Reconstruction::ReadCamerasText(const std::string& path)
+ """
+ HEADER = "# Camera list with one line of data per camera:\n" + \
+ "# CAMERA_ID, MODEL, WIDTH, HEIGHT, PARAMS[]\n" + \
+ "# Number of cameras: {}\n".format(len(cameras))
+ with open(path, "w") as fid:
+ fid.write(HEADER)
+ for _, cam in cameras.items():
+ to_write = [cam.id, cam.model, cam.width, cam.height, *cam.params]
+ line = " ".join([str(elem) for elem in to_write])
+ fid.write(line + "\n")
+
+
+def write_cameras_binary(cameras, path_to_model_file):
+ """
+ see: src/base/reconstruction.cc
+ void Reconstruction::WriteCamerasBinary(const std::string& path)
+ void Reconstruction::ReadCamerasBinary(const std::string& path)
+ """
+ with open(path_to_model_file, "wb") as fid:
+ write_next_bytes(fid, len(cameras), "Q")
+ for _, cam in cameras.items():
+ model_id = CAMERA_MODEL_NAMES[cam.model].model_id
+ camera_properties = [cam.id,
+ model_id,
+ cam.width,
+ cam.height]
+ write_next_bytes(fid, camera_properties, "iiQQ")
+ for p in cam.params:
+ write_next_bytes(fid, float(p), "d")
+ return cameras
+
+
+def read_images_text(path):
+ """
+ see: src/base/reconstruction.cc
+ void Reconstruction::ReadImagesText(const std::string& path)
+ void Reconstruction::WriteImagesText(const std::string& path)
+ """
+ images = {}
+ with open(path, "r") as fid:
+ while True:
+ line = fid.readline()
+ if not line:
+ break
+ line = line.strip()
+ if len(line) > 0 and line[0] != "#":
+ elems = line.split()
+ image_id = int(elems[0])
+ qvec = np.array(tuple(map(float, elems[1:5])))
+ tvec = np.array(tuple(map(float, elems[5:8])))
+ camera_id = int(elems[8])
+ image_name = elems[9]
+ elems = fid.readline().split()
+ xys = np.column_stack([tuple(map(float, elems[0::3])),
+ tuple(map(float, elems[1::3]))])
+ point3D_ids = np.array(tuple(map(int, elems[2::3])))
+ images[image_id] = Image(
+ id=image_id, qvec=qvec, tvec=tvec,
+ camera_id=camera_id, name=image_name,
+ xys=xys, point3D_ids=point3D_ids)
+ return images
+
+
+def read_images_binary(path_to_model_file):
+ """
+ see: src/base/reconstruction.cc
+ void Reconstruction::ReadImagesBinary(const std::string& path)
+ void Reconstruction::WriteImagesBinary(const std::string& path)
+ """
+ images = {}
+ with open(path_to_model_file, "rb") as fid:
+ num_reg_images = read_next_bytes(fid, 8, "Q")[0]
+ for _ in range(num_reg_images):
+ binary_image_properties = read_next_bytes(
+ fid, num_bytes=64, format_char_sequence="idddddddi")
+ image_id = binary_image_properties[0]
+ qvec = np.array(binary_image_properties[1:5])
+ tvec = np.array(binary_image_properties[5:8])
+ camera_id = binary_image_properties[8]
+ image_name = ""
+ current_char = read_next_bytes(fid, 1, "c")[0]
+ while current_char != b"\x00": # look for the ASCII 0 entry
+ image_name += current_char.decode("utf-8")
+ current_char = read_next_bytes(fid, 1, "c")[0]
+ num_points2D = read_next_bytes(fid, num_bytes=8,
+ format_char_sequence="Q")[0]
+ x_y_id_s = read_next_bytes(fid, num_bytes=24*num_points2D,
+ format_char_sequence="ddq"*num_points2D)
+ xys = np.column_stack([tuple(map(float, x_y_id_s[0::3])),
+ tuple(map(float, x_y_id_s[1::3]))])
+ point3D_ids = np.array(tuple(map(int, x_y_id_s[2::3])))
+ images[image_id] = Image(
+ id=image_id, qvec=qvec, tvec=tvec,
+ camera_id=camera_id, name=image_name,
+ xys=xys, point3D_ids=point3D_ids)
+ return images
+
+
+def write_images_text(images, path):
+ """
+ see: src/base/reconstruction.cc
+ void Reconstruction::ReadImagesText(const std::string& path)
+ void Reconstruction::WriteImagesText(const std::string& path)
+ """
+ if len(images) == 0:
+ mean_observations = 0
+ else:
+ mean_observations = sum((len(img.point3D_ids) for _, img in images.items()))/len(images)
+ HEADER = "# Image list with two lines of data per image:\n" + \
+ "# IMAGE_ID, QW, QX, QY, QZ, TX, TY, TZ, CAMERA_ID, NAME\n" + \
+ "# POINTS2D[] as (X, Y, POINT3D_ID)\n" + \
+ "# Number of images: {}, mean observations per image: {}\n".format(len(images), mean_observations)
+
+ with open(path, "w") as fid:
+ fid.write(HEADER)
+ for _, img in images.items():
+ image_header = [img.id, *img.qvec, *img.tvec, img.camera_id, img.name]
+ first_line = " ".join(map(str, image_header))
+ fid.write(first_line + "\n")
+
+ points_strings = []
+ for xy, point3D_id in zip(img.xys, img.point3D_ids):
+ points_strings.append(" ".join(map(str, [*xy, point3D_id])))
+ fid.write(" ".join(points_strings) + "\n")
+
+
+def write_images_binary(images, path_to_model_file):
+ """
+ see: src/base/reconstruction.cc
+ void Reconstruction::ReadImagesBinary(const std::string& path)
+ void Reconstruction::WriteImagesBinary(const std::string& path)
+ """
+ with open(path_to_model_file, "wb") as fid:
+ write_next_bytes(fid, len(images), "Q")
+ for _, img in images.items():
+ write_next_bytes(fid, img.id, "i")
+ write_next_bytes(fid, img.qvec.tolist(), "dddd")
+ write_next_bytes(fid, img.tvec.tolist(), "ddd")
+ write_next_bytes(fid, img.camera_id, "i")
+ for char in img.name:
+ write_next_bytes(fid, char.encode("utf-8"), "c")
+ write_next_bytes(fid, b"\x00", "c")
+ write_next_bytes(fid, len(img.point3D_ids), "Q")
+ for xy, p3d_id in zip(img.xys, img.point3D_ids):
+ write_next_bytes(fid, [*xy, p3d_id], "ddq")
+
+
+def read_points3D_text(path):
+ """
+ see: src/base/reconstruction.cc
+ void Reconstruction::ReadPoints3DText(const std::string& path)
+ void Reconstruction::WritePoints3DText(const std::string& path)
+ """
+ points3D = {}
+ with open(path, "r") as fid:
+ while True:
+ line = fid.readline()
+ if not line:
+ break
+ line = line.strip()
+ if len(line) > 0 and line[0] != "#":
+ elems = line.split()
+ point3D_id = int(elems[0])
+ xyz = np.array(tuple(map(float, elems[1:4])))
+ rgb = np.array(tuple(map(int, elems[4:7])))
+ error = float(elems[7])
+ image_ids = np.array(tuple(map(int, elems[8::2])))
+ point2D_idxs = np.array(tuple(map(int, elems[9::2])))
+ points3D[point3D_id] = Point3D(id=point3D_id, xyz=xyz, rgb=rgb,
+ error=error, image_ids=image_ids,
+ point2D_idxs=point2D_idxs)
+ return points3D
+
+
+def read_points3D_binary(path_to_model_file):
+ """
+ see: src/base/reconstruction.cc
+ void Reconstruction::ReadPoints3DBinary(const std::string& path)
+ void Reconstruction::WritePoints3DBinary(const std::string& path)
+ """
+ points3D = {}
+ with open(path_to_model_file, "rb") as fid:
+ num_points = read_next_bytes(fid, 8, "Q")[0]
+ for _ in range(num_points):
+ binary_point_line_properties = read_next_bytes(
+ fid, num_bytes=43, format_char_sequence="QdddBBBd")
+ point3D_id = binary_point_line_properties[0]
+ xyz = np.array(binary_point_line_properties[1:4])
+ rgb = np.array(binary_point_line_properties[4:7])
+ error = np.array(binary_point_line_properties[7])
+ track_length = read_next_bytes(
+ fid, num_bytes=8, format_char_sequence="Q")[0]
+ track_elems = read_next_bytes(
+ fid, num_bytes=8*track_length,
+ format_char_sequence="ii"*track_length)
+ image_ids = np.array(tuple(map(int, track_elems[0::2])))
+ point2D_idxs = np.array(tuple(map(int, track_elems[1::2])))
+ points3D[point3D_id] = Point3D(
+ id=point3D_id, xyz=xyz, rgb=rgb,
+ error=error, image_ids=image_ids,
+ point2D_idxs=point2D_idxs)
+ return points3D
+
+
+def write_points3D_text(points3D, path):
+ """
+ see: src/base/reconstruction.cc
+ void Reconstruction::ReadPoints3DText(const std::string& path)
+ void Reconstruction::WritePoints3DText(const std::string& path)
+ """
+ if len(points3D) == 0:
+ mean_track_length = 0
+ else:
+ mean_track_length = sum((len(pt.image_ids) for _, pt in points3D.items()))/len(points3D)
+ HEADER = "# 3D point list with one line of data per point:\n" + \
+ "# POINT3D_ID, X, Y, Z, R, G, B, ERROR, TRACK[] as (IMAGE_ID, POINT2D_IDX)\n" + \
+ "# Number of points: {}, mean track length: {}\n".format(len(points3D), mean_track_length)
+
+ with open(path, "w") as fid:
+ fid.write(HEADER)
+ for _, pt in points3D.items():
+ point_header = [pt.id, *pt.xyz, *pt.rgb, pt.error]
+ fid.write(" ".join(map(str, point_header)) + " ")
+ track_strings = []
+ for image_id, point2D in zip(pt.image_ids, pt.point2D_idxs):
+ track_strings.append(" ".join(map(str, [image_id, point2D])))
+ fid.write(" ".join(track_strings) + "\n")
+
+
+def write_points3D_binary(points3D, path_to_model_file):
+ """
+ see: src/base/reconstruction.cc
+ void Reconstruction::ReadPoints3DBinary(const std::string& path)
+ void Reconstruction::WritePoints3DBinary(const std::string& path)
+ """
+ with open(path_to_model_file, "wb") as fid:
+ write_next_bytes(fid, len(points3D), "Q")
+ for _, pt in points3D.items():
+ write_next_bytes(fid, pt.id, "Q")
+ write_next_bytes(fid, pt.xyz.tolist(), "ddd")
+ write_next_bytes(fid, pt.rgb.tolist(), "BBB")
+ write_next_bytes(fid, pt.error, "d")
+ track_length = pt.image_ids.shape[0]
+ write_next_bytes(fid, track_length, "Q")
+ for image_id, point2D_id in zip(pt.image_ids, pt.point2D_idxs):
+ write_next_bytes(fid, [image_id, point2D_id], "ii")
+
+
+def detect_model_format(path, ext):
+ if os.path.isfile(os.path.join(path, "cameras" + ext)) and \
+ os.path.isfile(os.path.join(path, "images" + ext)) and \
+ os.path.isfile(os.path.join(path, "points3D" + ext)):
+ print("Detected model format: '" + ext + "'")
+ return True
+
+ return False
+
+
+def read_model(path, ext="") -> Tuple[Dict[int, Camera], Dict[int, Image], Dict[int, Point3D]]:
+ # try to detect the extension automatically
+ if ext == "":
+ if detect_model_format(path, ".bin"):
+ ext = ".bin"
+ elif detect_model_format(path, ".txt"):
+ ext = ".txt"
+ else:
+ raise ValueError("Provide model format: '.bin' or '.txt'")
+
+ if ext == ".txt":
+ cameras = read_cameras_text(os.path.join(path, "cameras" + ext))
+ images = read_images_text(os.path.join(path, "images" + ext))
+ points3D = read_points3D_text(os.path.join(path, "points3D") + ext)
+ else:
+ cameras = read_cameras_binary(os.path.join(path, "cameras" + ext))
+ images = read_images_binary(os.path.join(path, "images" + ext))
+ points3D = read_points3D_binary(os.path.join(path, "points3D") + ext)
+ return cameras, images, points3D
+
+
+def write_model(cameras, images, points3D, path, ext=".bin"):
+ if ext == ".txt":
+ write_cameras_text(cameras, os.path.join(path, "cameras" + ext))
+ write_images_text(images, os.path.join(path, "images" + ext))
+ write_points3D_text(points3D, os.path.join(path, "points3D") + ext)
+ else:
+ write_cameras_binary(cameras, os.path.join(path, "cameras" + ext))
+ write_images_binary(images, os.path.join(path, "images" + ext))
+ write_points3D_binary(points3D, os.path.join(path, "points3D") + ext)
+ return cameras, images, points3D
+
+
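+# COLMAP stores quaternions as (qw, qx, qy, qz); qvec2rotmat/rotmat2qvec below convert between
+# this convention and 3x3 rotation matrices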
+def qvec2rotmat(qvec):
+ return np.array([
+ [1 - 2 * qvec[2]**2 - 2 * qvec[3]**2,
+ 2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3],
+ 2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2]],
+ [2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3],
+ 1 - 2 * qvec[1]**2 - 2 * qvec[3]**2,
+ 2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1]],
+ [2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2],
+ 2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1],
+ 1 - 2 * qvec[1]**2 - 2 * qvec[2]**2]])
+
+
+def rotmat2qvec(R):
+ Rxx, Ryx, Rzx, Rxy, Ryy, Rzy, Rxz, Ryz, Rzz = R.flat
+ K = np.array([
+ [Rxx - Ryy - Rzz, 0, 0, 0],
+ [Ryx + Rxy, Ryy - Rxx - Rzz, 0, 0],
+ [Rzx + Rxz, Rzy + Ryz, Rzz - Rxx - Ryy, 0],
+ [Ryz - Rzy, Rzx - Rxz, Rxy - Ryx, Rxx + Ryy + Rzz]]) / 3.0
+ eigvals, eigvecs = np.linalg.eigh(K)
+ qvec = eigvecs[[3, 0, 1, 2], np.argmax(eigvals)]
+ if qvec[0] < 0:
+ qvec *= -1
+ return qvec
+
+
+def main():
+ parser = argparse.ArgumentParser(description="Read and write COLMAP binary and text models")
+ parser.add_argument("--input_model", help="path to input model folder")
+ parser.add_argument("--input_format", choices=[".bin", ".txt"],
+ help="input model format", default="")
+ parser.add_argument("--output_model",
+ help="path to output model folder")
+ parser.add_argument("--output_format", choices=[".bin", ".txt"],
+                        help="output model format", default=".txt")
+ args = parser.parse_args()
+
+ cameras, images, points3D = read_model(path=args.input_model, ext=args.input_format)
+
+ print("num_cameras:", len(cameras))
+ print("num_images:", len(images))
+ print("num_points3D:", len(points3D))
+
+ if args.output_model is not None:
+ write_model(cameras, images, points3D, path=args.output_model, ext=args.output_format)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/imcui/third_party/MatchAnything/src/utils/colmap/__init__.py b/imcui/third_party/MatchAnything/src/utils/colmap/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/imcui/third_party/MatchAnything/src/utils/colmap/database.py b/imcui/third_party/MatchAnything/src/utils/colmap/database.py
new file mode 100644
index 0000000000000000000000000000000000000000..81a9c47bad6c522597dfaaf2a85ae0d252b5ab10
--- /dev/null
+++ b/imcui/third_party/MatchAnything/src/utils/colmap/database.py
@@ -0,0 +1,417 @@
+# Copyright (c) 2018, ETH Zurich and UNC Chapel Hill.
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+#
+# * Redistributions in binary form must reproduce the above copyright
+# notice, this list of conditions and the following disclaimer in the
+# documentation and/or other materials provided with the distribution.
+#
+# * Neither the name of ETH Zurich and UNC Chapel Hill nor the names of
+# its contributors may be used to endorse or promote products derived
+# from this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE
+# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+# POSSIBILITY OF SUCH DAMAGE.
+#
+# Author: Johannes L. Schoenberger (jsch-at-demuc-dot-de)
+
+# This script is based on an original implementation by True Price.
+
+import sys
+import sqlite3
+import numpy as np
+from loguru import logger
+
+
+IS_PYTHON3 = sys.version_info[0] >= 3
+
+MAX_IMAGE_ID = 2**31 - 1
+
+CREATE_CAMERAS_TABLE = """CREATE TABLE IF NOT EXISTS cameras (
+ camera_id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL,
+ model INTEGER NOT NULL,
+ width INTEGER NOT NULL,
+ height INTEGER NOT NULL,
+ params BLOB,
+ prior_focal_length INTEGER NOT NULL)"""
+
+CREATE_DESCRIPTORS_TABLE = """CREATE TABLE IF NOT EXISTS descriptors (
+ image_id INTEGER PRIMARY KEY NOT NULL,
+ rows INTEGER NOT NULL,
+ cols INTEGER NOT NULL,
+ data BLOB,
+ FOREIGN KEY(image_id) REFERENCES images(image_id) ON DELETE CASCADE)"""
+
+CREATE_IMAGES_TABLE = """CREATE TABLE IF NOT EXISTS images (
+ image_id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL,
+ name TEXT NOT NULL UNIQUE,
+ camera_id INTEGER NOT NULL,
+ prior_qw REAL,
+ prior_qx REAL,
+ prior_qy REAL,
+ prior_qz REAL,
+ prior_tx REAL,
+ prior_ty REAL,
+ prior_tz REAL,
+ CONSTRAINT image_id_check CHECK(image_id >= 0 and image_id < {}),
+ FOREIGN KEY(camera_id) REFERENCES cameras(camera_id))
+""".format(MAX_IMAGE_ID)
+
+CREATE_TWO_VIEW_GEOMETRIES_TABLE = """
+CREATE TABLE IF NOT EXISTS two_view_geometries (
+ pair_id INTEGER PRIMARY KEY NOT NULL,
+ rows INTEGER NOT NULL,
+ cols INTEGER NOT NULL,
+ data BLOB,
+ config INTEGER NOT NULL,
+ F BLOB,
+ E BLOB,
+ H BLOB,
+ qvec BLOB,
+ tvec BLOB)
+"""
+
+CREATE_KEYPOINTS_TABLE = """CREATE TABLE IF NOT EXISTS keypoints (
+ image_id INTEGER PRIMARY KEY NOT NULL,
+ rows INTEGER NOT NULL,
+ cols INTEGER NOT NULL,
+ data BLOB,
+ FOREIGN KEY(image_id) REFERENCES images(image_id) ON DELETE CASCADE)
+"""
+
+CREATE_MATCHES_TABLE = """CREATE TABLE IF NOT EXISTS matches (
+ pair_id INTEGER PRIMARY KEY NOT NULL,
+ rows INTEGER NOT NULL,
+ cols INTEGER NOT NULL,
+ data BLOB)"""
+
+CREATE_NAME_INDEX = \
+ "CREATE UNIQUE INDEX IF NOT EXISTS index_name ON images(name)"
+
+CREATE_ALL = "; ".join([
+ CREATE_CAMERAS_TABLE,
+ CREATE_IMAGES_TABLE,
+ CREATE_KEYPOINTS_TABLE,
+ CREATE_DESCRIPTORS_TABLE,
+ CREATE_MATCHES_TABLE,
+ CREATE_TWO_VIEW_GEOMETRIES_TABLE,
+ CREATE_NAME_INDEX
+])
+
+
+def image_ids_to_pair_id(image_id1, image_id2):
+ if image_id1 > image_id2:
+ image_id1, image_id2 = image_id2, image_id1
+ return image_id1 * MAX_IMAGE_ID + image_id2
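+    # e.g. image_ids_to_pair_id(2, 1) == 1 * (2**31 - 1) + 2 == 2147483649; swapping the ids
+    # first makes the pair id independent of argument order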
+
+
+def pair_id_to_image_ids(pair_id):
+ image_id2 = pair_id % MAX_IMAGE_ID
+ image_id1 = (pair_id - image_id2) / MAX_IMAGE_ID
+ return image_id1, image_id2
+
+
+def array_to_blob(array):
+ if IS_PYTHON3:
+ return array.tobytes()
+ else:
+ return np.getbuffer(array)
+
+
+def blob_to_array(blob, dtype, shape=(-1,)):
+ if IS_PYTHON3:
+ return np.fromstring(blob, dtype=dtype).reshape(*shape)
+ else:
+ return np.frombuffer(blob, dtype=dtype).reshape(*shape)
+
+
+class COLMAPDatabase(sqlite3.Connection):
+
+ @staticmethod
+ def connect(database_path):
+ return sqlite3.connect(str(database_path), factory=COLMAPDatabase)
+
+
+ def __init__(self, *args, **kwargs):
+ super(COLMAPDatabase, self).__init__(*args, **kwargs)
+
+ self.create_tables = lambda: self.executescript(CREATE_ALL)
+ self.create_cameras_table = \
+ lambda: self.executescript(CREATE_CAMERAS_TABLE)
+ self.create_descriptors_table = \
+ lambda: self.executescript(CREATE_DESCRIPTORS_TABLE)
+ self.create_images_table = \
+ lambda: self.executescript(CREATE_IMAGES_TABLE)
+ self.create_two_view_geometries_table = \
+ lambda: self.executescript(CREATE_TWO_VIEW_GEOMETRIES_TABLE)
+ self.create_keypoints_table = \
+ lambda: self.executescript(CREATE_KEYPOINTS_TABLE)
+ self.create_matches_table = \
+ lambda: self.executescript(CREATE_MATCHES_TABLE)
+ self.create_name_index = lambda: self.executescript(CREATE_NAME_INDEX)
+
+ def add_camera(self, model, width, height, params,
+ prior_focal_length=False, camera_id=None):
+ params = np.asarray(params, np.float64)
+ cursor = self.execute(
+ "INSERT INTO cameras VALUES (?, ?, ?, ?, ?, ?)",
+ (camera_id, model, width, height, array_to_blob(params),
+ prior_focal_length))
+ return cursor.lastrowid
+
+ def add_image(self, name, camera_id,
+ prior_q=np.full(4, np.NaN), prior_t=np.full(3, np.NaN),
+ image_id=None):
+ cursor = self.execute(
+ "INSERT INTO images VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
+ (image_id, name, camera_id, prior_q[0], prior_q[1], prior_q[2],
+ prior_q[3], prior_t[0], prior_t[1], prior_t[2]))
+ return cursor.lastrowid
+
+ def add_keypoints(self, image_id, keypoints):
+ assert(len(keypoints.shape) == 2)
+ assert(keypoints.shape[1] in [2, 4, 6])
+
+ keypoints = np.asarray(keypoints, np.float32)
+ self.execute(
+ "INSERT INTO keypoints VALUES (?, ?, ?, ?)",
+ (image_id,) + keypoints.shape + (array_to_blob(keypoints),))
+
+ def add_descriptors(self, image_id, descriptors):
+ descriptors = np.ascontiguousarray(descriptors, np.uint8)
+ self.execute(
+ "INSERT INTO descriptors VALUES (?, ?, ?, ?)",
+ (image_id,) + descriptors.shape + (array_to_blob(descriptors),))
+
+ def add_matches(self, image_id1, image_id2, matches):
+ assert(len(matches.shape) == 2)
+ assert(matches.shape[1] == 2)
+
+ if image_id1 > image_id2:
+ matches = matches[:,::-1]
+
+ pair_id = image_ids_to_pair_id(image_id1, image_id2)
+ matches = np.asarray(matches, np.uint32)
+ self.execute(
+ "INSERT INTO matches VALUES (?, ?, ?, ?)",
+ (pair_id,) + matches.shape + (array_to_blob(matches),))
+
+ def add_two_view_geometry(self, image_id1, image_id2, matches,
+ F=np.eye(3), E=np.eye(3), H=np.eye(3),
+ qvec=np.array([1.0, 0.0, 0.0, 0.0]),
+ tvec=np.zeros(3), config=2):
+ assert(len(matches.shape) == 2)
+ assert(matches.shape[1] == 2)
+
+ if image_id1 > image_id2:
+ matches = matches[:,::-1]
+
+ pair_id = image_ids_to_pair_id(image_id1, image_id2)
+ matches = np.asarray(matches, np.uint32)
+ F = np.asarray(F, dtype=np.float64)
+ E = np.asarray(E, dtype=np.float64)
+ H = np.asarray(H, dtype=np.float64)
+ qvec = np.asarray(qvec, dtype=np.float64)
+ tvec = np.asarray(tvec, dtype=np.float64)
+ self.execute(
+ "INSERT INTO two_view_geometries VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
+ (pair_id,) + matches.shape + (array_to_blob(matches), config,
+ array_to_blob(F), array_to_blob(E), array_to_blob(H),
+ array_to_blob(qvec), array_to_blob(tvec)))
+
+ def update_two_view_geometry(self, image_id1, image_id2, matches,
+ F=np.eye(3), E=np.eye(3), H=np.eye(3), config=2):
+ assert(len(matches.shape) == 2)
+ assert(matches.shape[1] == 2)
+
+ if image_id1 > image_id2:
+ matches = matches[:,::-1]
+
+ pair_id = image_ids_to_pair_id(image_id1, image_id2)
+ matches = np.asarray(matches, np.uint32)
+ F = np.asarray(F, dtype=np.float64)
+ E = np.asarray(E, dtype=np.float64)
+ H = np.asarray(H, dtype=np.float64)
+
+ # Check whether a geometry entry already exists for this pair:
+ row = self.execute("SELECT * FROM two_view_geometries WHERE pair_id = ?", (pair_id,))
+ data = list(next(row))
+ try:
+ matches_old = blob_to_array(data[3], np.uint32, (-1, 2))
+ except Exception:
+ matches_old = None
+
+ if matches_old is not None:
+ for match in matches:
+ img0_id, img1_id = match
+
+ # Find duplicated pts
+ img0_dup_idxs = np.where(matches_old[:, 0] == img0_id)
+ img1_dup_idxs = np.where(matches_old[:, 1] == img1_id)
+
+ if len(img0_dup_idxs[0]) == 0 and len(img1_dup_idxs[0]) == 0:
+ # No duplicated matches:
+ matches_old = np.concatenate([matches_old, match[None]], axis=0)
+ elif len(img0_dup_idxs[0]) == 1 and len(img1_dup_idxs[0]) == 0:
+ # Index the row directly; chained fancy indexing would write to a copy
+ matches_old[img0_dup_idxs[0][0], 1] = img1_id
+ elif len(img0_dup_idxs[0]) == 0 and len(img1_dup_idxs[0]) == 1:
+ matches_old[img1_dup_idxs[0][0], 0] = img0_id
+ elif len(img0_dup_idxs[0]) == 1 and len(img1_dup_idxs[0]) == 1:
+ if img0_dup_idxs[0][0] != img1_dup_idxs[0][0]:
+ # logger.warning("Duplicated matches exist!")
+ matches_old[img0_dup_idxs[0][0], 1] = img1_id
+ matches_old[img1_dup_idxs[0][0], 0] = img0_id
+ else:
+ raise NotImplementedError
+
+ # matches = np.concatenate([matches_old, matches], axis=0) # N * 2
+ matches = matches_old
+ self.execute("DELETE FROM two_view_geometries WHERE pair_id = ?", (pair_id,))
+
+ data[1:4] = matches.shape + (array_to_blob(np.asarray(matches, np.uint32)),)
+ self.execute("INSERT INTO two_view_geometries VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", tuple(data))
+ else:
+ raise NotImplementedError
+
+ # self.add_two_view_geometry(image_id1, image_id2, matches)
+
+
+def example_usage():
+ import os
+ import argparse
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--database_path", default="database.db")
+ args = parser.parse_args()
+
+ if os.path.exists(args.database_path):
+ print("ERROR: database path already exists -- will not modify it.")
+ return
+
+ # Open the database.
+
+ db = COLMAPDatabase.connect(args.database_path)
+
+ # For convenience, try creating all the tables upfront.
+
+ db.create_tables()
+
+ # Create dummy cameras.
+
+ model1, width1, height1, params1 = \
+ 0, 1024, 768, np.array((1024., 512., 384.))
+ model2, width2, height2, params2 = \
+ 2, 1024, 768, np.array((1024., 512., 384., 0.1))
+
+ camera_id1 = db.add_camera(model1, width1, height1, params1)
+ camera_id2 = db.add_camera(model2, width2, height2, params2)
+
+ # Create dummy images.
+
+ image_id1 = db.add_image("image1.png", camera_id1)
+ image_id2 = db.add_image("image2.png", camera_id1)
+ image_id3 = db.add_image("image3.png", camera_id2)
+ image_id4 = db.add_image("image4.png", camera_id2)
+
+ # Create dummy keypoints.
+ #
+ # Note that COLMAP supports:
+ # - 2D keypoints: (x, y)
+ # - 4D keypoints: (x, y, theta, scale)
+ # - 6D affine keypoints: (x, y, a_11, a_12, a_21, a_22)
+
+ num_keypoints = 1000
+ keypoints1 = np.random.rand(num_keypoints, 2) * (width1, height1)
+ keypoints2 = np.random.rand(num_keypoints, 2) * (width1, height1)
+ keypoints3 = np.random.rand(num_keypoints, 2) * (width2, height2)
+ keypoints4 = np.random.rand(num_keypoints, 2) * (width2, height2)
+
+ db.add_keypoints(image_id1, keypoints1)
+ db.add_keypoints(image_id2, keypoints2)
+ db.add_keypoints(image_id3, keypoints3)
+ db.add_keypoints(image_id4, keypoints4)
+
+ # Create dummy matches.
+
+ M = 50
+ matches12 = np.random.randint(num_keypoints, size=(M, 2))
+ matches23 = np.random.randint(num_keypoints, size=(M, 2))
+ matches34 = np.random.randint(num_keypoints, size=(M, 2))
+
+ db.add_matches(image_id1, image_id2, matches12)
+ db.add_matches(image_id2, image_id3, matches23)
+ db.add_matches(image_id3, image_id4, matches34)
+
+ # Commit the data to the file.
+
+ db.commit()
+
+ # Read and check cameras.
+
+ rows = db.execute("SELECT * FROM cameras")
+
+ camera_id, model, width, height, params, prior = next(rows)
+ params = blob_to_array(params, np.float64)
+ assert camera_id == camera_id1
+ assert model == model1 and width == width1 and height == height1
+ assert np.allclose(params, params1)
+
+ camera_id, model, width, height, params, prior = next(rows)
+ params = blob_to_array(params, np.float64)
+ assert camera_id == camera_id2
+ assert model == model2 and width == width2 and height == height2
+ assert np.allclose(params, params2)
+
+ # Read and check keypoints.
+
+ keypoints = dict(
+ (image_id, blob_to_array(data, np.float32, (-1, 2)))
+ for image_id, data in db.execute(
+ "SELECT image_id, data FROM keypoints"))
+
+ assert np.allclose(keypoints[image_id1], keypoints1)
+ assert np.allclose(keypoints[image_id2], keypoints2)
+ assert np.allclose(keypoints[image_id3], keypoints3)
+ assert np.allclose(keypoints[image_id4], keypoints4)
+
+ # Read and check matches.
+
+ pair_ids = [image_ids_to_pair_id(*pair) for pair in
+ ((image_id1, image_id2),
+ (image_id2, image_id3),
+ (image_id3, image_id4))]
+
+ matches = dict(
+ (pair_id_to_image_ids(pair_id),
+ blob_to_array(data, np.uint32, (-1, 2)))
+ for pair_id, data in db.execute("SELECT pair_id, data FROM matches")
+ )
+
+ assert np.all(matches[(image_id1, image_id2)] == matches12)
+ assert np.all(matches[(image_id2, image_id3)] == matches23)
+ assert np.all(matches[(image_id3, image_id4)] == matches34)
+
+ # Clean up.
+
+ db.close()
+
+ if os.path.exists(args.database_path):
+ os.remove(args.database_path)
+
+
+if __name__ == "__main__":
+ example_usage()
diff --git a/imcui/third_party/MatchAnything/src/utils/colmap/eval_helper.py b/imcui/third_party/MatchAnything/src/utils/colmap/eval_helper.py
new file mode 100644
index 0000000000000000000000000000000000000000..2335a1abbc6321f7e3a4b40123d8386d4900a9d2
--- /dev/null
+++ b/imcui/third_party/MatchAnything/src/utils/colmap/eval_helper.py
@@ -0,0 +1,232 @@
+import math
+import cv2
+import os
+import numpy as np
+from .read_write_model import read_images_binary
+
+
+def align_model(model, rot, trans, scale):
+ return (np.matmul(rot, model) + trans) * scale
+
+
+def align(model, data):
+ '''
+ Source: https://vision.in.tum.de/data/datasets/rgbd-dataset/tools
+ #absolute_trajectory_error_ate
+ Align two trajectories using the method of Horn (closed-form).
+
+ Input:
+ model -- first trajectory (3xn)
+ data -- second trajectory (3xn)
+
+ Output:
+ rot -- rotation matrix (3x3)
+ trans -- translation vector (3x1)
+ trans_error -- translational error per point (1xn)
+
+ '''
+
+ if model.shape[1] < 3:
+ print('Need at least 3 points for ATE: {}'.format(model))
+ return np.identity(3), np.zeros((3, 1)), 1
+
+ # Get zero centered point cloud
+ model_zerocentered = model - model.mean(1, keepdims=True)
+ data_zerocentered = data - data.mean(1, keepdims=True)
+
+ # constructed covariance matrix
+ W = np.zeros((3, 3))
+ for column in range(model.shape[1]):
+ W += np.outer(model_zerocentered[:, column],
+ data_zerocentered[:, column])
+
+ # SVD
+ U, d, Vh = np.linalg.svd(W.transpose())
+ S = np.identity(3)
+ if (np.linalg.det(U) * np.linalg.det(Vh) < 0):
+ S[2, 2] = -1
+ rot = np.matmul(np.matmul(U, S), Vh)
+ trans = data.mean(1, keepdims=True) - np.matmul(
+ rot, model.mean(1, keepdims=True))
+
+ # apply rot and trans to point cloud
+ model_aligned = align_model(model, rot, trans, 1.0)
+ model_aligned_zerocentered = model_aligned - model_aligned.mean(
+ 1, keepdims=True)
+
+ # calc scale based on distance to point cloud center
+ data_dist = np.sqrt((data_zerocentered * data_zerocentered).sum(axis=0))
+ model_aligned_dist = np.sqrt(
+ (model_aligned_zerocentered * model_aligned_zerocentered).sum(axis=0))
+ scale_array = data_dist / model_aligned_dist
+ scale = np.median(scale_array)
+
+ return rot, trans, scale
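+
+# Typical use: estimate the similarity aligning two 3xN camera-center trajectories,
+# then apply it, e.g.
+#   rot, trans, scale = align(model, data)
+#   model_aligned = align_model(model, rot, trans, scale)
+# after which model_aligned is expressed in the coordinate frame of `data`.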
+
+
+_EPS = np.finfo(float).eps * 4.0  # tolerance used below for near-zero quaternion norms
+
+
+def quaternion_matrix(quaternion):
+ '''Return homogeneous rotation matrix from quaternion.
+
+ >>> M = quaternion_matrix([0.99810947, 0.06146124, 0, 0])
+ >>> numpy.allclose(M, rotation_matrix(0.123, [1, 0, 0]))
+ True
+ >>> M = quaternion_matrix([1, 0, 0, 0])
+ >>> numpy.allclose(M, numpy.identity(4))
+ True
+ >>> M = quaternion_matrix([0, 1, 0, 0])
+ >>> numpy.allclose(M, numpy.diag([1, -1, -1, 1]))
+ True
+ '''
+
+ q = np.array(quaternion, dtype=np.float64, copy=True)
+ n = np.dot(q, q)
+ if n < _EPS:
+ return np.identity(4)
+
+ q *= math.sqrt(2.0 / n)
+ q = np.outer(q, q)
+
+ return np.array(
+ [[1.0 - q[2, 2] - q[3, 3], q[1, 2] - q[3, 0], q[1, 3] + q[2, 0], 0.0],
+ [q[1, 2] + q[3, 0], 1.0 - q[1, 1] - q[3, 3], q[2, 3] - q[1, 0], 0.0],
+ [q[1, 3] - q[2, 0], q[2, 3] + q[1, 0], 1.0 - q[1, 1] - q[2, 2], 0.0],
+ [0.0, 0.0, 0.0, 1.0]])
+
+
+def quaternion_from_matrix(matrix, isprecise=False):
+ '''Return quaternion from rotation matrix.
+
+ If isprecise is True, the input matrix is assumed to be a precise rotation
+ matrix and a faster algorithm is used.
+
+ >>> q = quaternion_from_matrix(numpy.identity(4), True)
+ >>> numpy.allclose(q, [1, 0, 0, 0])
+ True
+ >>> q = quaternion_from_matrix(numpy.diag([1, -1, -1, 1]))
+ >>> numpy.allclose(q, [0, 1, 0, 0]) or numpy.allclose(q, [0, -1, 0, 0])
+ True
+ >>> R = rotation_matrix(0.123, (1, 2, 3))
+ >>> q = quaternion_from_matrix(R, True)
+ >>> numpy.allclose(q, [0.9981095, 0.0164262, 0.0328524, 0.0492786])
+ True
+ >>> R = [[-0.545, 0.797, 0.260, 0], [0.733, 0.603, -0.313, 0],
+ ... [-0.407, 0.021, -0.913, 0], [0, 0, 0, 1]]
+ >>> q = quaternion_from_matrix(R)
+ >>> numpy.allclose(q, [0.19069, 0.43736, 0.87485, -0.083611])
+ True
+ >>> R = [[0.395, 0.362, 0.843, 0], [-0.626, 0.796, -0.056, 0],
+ ... [-0.677, -0.498, 0.529, 0], [0, 0, 0, 1]]
+ >>> q = quaternion_from_matrix(R)
+ >>> numpy.allclose(q, [0.82336615, -0.13610694, 0.46344705, -0.29792603])
+ True
+ >>> R = random_rotation_matrix()
+ >>> q = quaternion_from_matrix(R)
+ >>> is_same_transform(R, quaternion_matrix(q))
+ True
+ >>> R = euler_matrix(0.0, 0.0, numpy.pi/2.0)
+ >>> numpy.allclose(quaternion_from_matrix(R, isprecise=False),
+ ... quaternion_from_matrix(R, isprecise=True))
+ True
+
+ '''
+
+ M = np.array(matrix, dtype=np.float64, copy=False)[:4, :4]
+ if isprecise:
+ q = np.empty((4, ))
+ t = np.trace(M)
+ if t > M[3, 3]:
+ q[0] = t
+ q[3] = M[1, 0] - M[0, 1]
+ q[2] = M[0, 2] - M[2, 0]
+ q[1] = M[2, 1] - M[1, 2]
+ else:
+ i, j, k = 1, 2, 3
+ if M[1, 1] > M[0, 0]:
+ i, j, k = 2, 3, 1
+ if M[2, 2] > M[i, i]:
+ i, j, k = 3, 1, 2
+ t = M[i, i] - (M[j, j] + M[k, k]) + M[3, 3]
+ q[i] = t
+ q[j] = M[i, j] + M[j, i]
+ q[k] = M[k, i] + M[i, k]
+ q[3] = M[k, j] - M[j, k]
+ q *= 0.5 / math.sqrt(t * M[3, 3])
+ else:
+ m00 = M[0, 0]
+ m01 = M[0, 1]
+ m02 = M[0, 2]
+ m10 = M[1, 0]
+ m11 = M[1, 1]
+ m12 = M[1, 2]
+ m20 = M[2, 0]
+ m21 = M[2, 1]
+ m22 = M[2, 2]
+
+ # symmetric matrix K
+ K = np.array([[m00 - m11 - m22, 0.0, 0.0, 0.0],
+ [m01 + m10, m11 - m00 - m22, 0.0, 0.0],
+ [m02 + m20, m12 + m21, m22 - m00 - m11, 0.0],
+ [m21 - m12, m02 - m20, m10 - m01, m00 + m11 + m22]])
+ K /= 3.0
+
+ # quaternion is eigenvector of K that corresponds to largest eigenvalue
+ w, V = np.linalg.eigh(K)
+ q = V[[3, 0, 1, 2], np.argmax(w)]
+
+ if q[0] < 0.0:
+ np.negative(q, q)
+
+ return q
+
+def is_colmap_img_valid(colmap_img_file):
+ '''Return validity of a colmap reconstruction'''
+
+ images_bin = read_images_binary(colmap_img_file)
+ # Check if everything is finite for this subset
+ for key in images_bin.keys():
+ q = np.asarray(images_bin[key].qvec).flatten()
+ t = np.asarray(images_bin[key].tvec).flatten()
+
+ is_cur_valid = True
+ is_cur_valid = is_cur_valid and q.shape == (4, )
+ is_cur_valid = is_cur_valid and t.shape == (3, )
+ is_cur_valid = is_cur_valid and np.all(np.isfinite(q))
+ is_cur_valid = is_cur_valid and np.all(np.isfinite(t))
+
+ # If any is invalid, immediately return
+ if not is_cur_valid:
+ return False
+
+ return True
+
+def get_best_colmap_index(colmap_output_path):
+ '''
+ Determines the colmap model with the most images if there is more than one.
+ '''
+
+ # First find the colmap reconstruction with the largest number of images.
+ best_index, best_num_images = -1, 0
+
+ # Check all valid sub reconstructions.
+ if os.path.exists(colmap_output_path):
+ idx_list = [
+ _d for _d in os.listdir(colmap_output_path)
+ if os.path.isdir(os.path.join(colmap_output_path, _d))
+ ]
+ else:
+ idx_list = []
+
+ for cur_index in idx_list:
+ cur_output_path = os.path.join(colmap_output_path, cur_index)
+ if os.path.isdir(cur_output_path):
+ colmap_img_file = os.path.join(cur_output_path, 'images.bin')
+ images_bin = read_images_binary(colmap_img_file)
+ # Check validity
+ if not is_colmap_img_valid(colmap_img_file):
+ continue
+ # Keep the reconstruction with the largest number of images
+ if len(images_bin) > best_num_images:
+ best_index = int(cur_index)
+ best_num_images = len(images_bin)
+
+ return str(best_index)
\ No newline at end of file
diff --git a/imcui/third_party/MatchAnything/src/utils/colmap/read_write_model.py b/imcui/third_party/MatchAnything/src/utils/colmap/read_write_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..eeb03c3bee0f1d6ffd5285835c83920e11de5b51
--- /dev/null
+++ b/imcui/third_party/MatchAnything/src/utils/colmap/read_write_model.py
@@ -0,0 +1,509 @@
+# Copyright (c) 2018, ETH Zurich and UNC Chapel Hill.
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+#
+# * Redistributions in binary form must reproduce the above copyright
+# notice, this list of conditions and the following disclaimer in the
+# documentation and/or other materials provided with the distribution.
+#
+# * Neither the name of ETH Zurich and UNC Chapel Hill nor the names of
+# its contributors may be used to endorse or promote products derived
+# from this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE
+# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+# POSSIBILITY OF SUCH DAMAGE.
+#
+# Author: Johannes L. Schoenberger (jsch-at-demuc-dot-de)
+
+import os
+import sys
+import collections
+import numpy as np
+import struct
+import argparse
+
+
+CameraModel = collections.namedtuple(
+ "CameraModel", ["model_id", "model_name", "num_params"])
+Camera = collections.namedtuple(
+ "Camera", ["id", "model", "width", "height", "params"])
+BaseImage = collections.namedtuple(
+ "Image", ["id", "qvec", "tvec", "camera_id", "name", "xys", "point3D_ids"])
+Point3D = collections.namedtuple(
+ "Point3D", ["id", "xyz", "rgb", "error", "image_ids", "point2D_idxs"])
+
+
+class Image(BaseImage):
+ def qvec2rotmat(self):
+ return qvec2rotmat(self.qvec)
+
+
+CAMERA_MODELS = {
+ CameraModel(model_id=0, model_name="SIMPLE_PINHOLE", num_params=3),
+ CameraModel(model_id=1, model_name="PINHOLE", num_params=4),
+ CameraModel(model_id=2, model_name="SIMPLE_RADIAL", num_params=4),
+ CameraModel(model_id=3, model_name="RADIAL", num_params=5),
+ CameraModel(model_id=4, model_name="OPENCV", num_params=8),
+ CameraModel(model_id=5, model_name="OPENCV_FISHEYE", num_params=8),
+ CameraModel(model_id=6, model_name="FULL_OPENCV", num_params=12),
+ CameraModel(model_id=7, model_name="FOV", num_params=5),
+ CameraModel(model_id=8, model_name="SIMPLE_RADIAL_FISHEYE", num_params=4),
+ CameraModel(model_id=9, model_name="RADIAL_FISHEYE", num_params=5),
+ CameraModel(model_id=10, model_name="THIN_PRISM_FISHEYE", num_params=12)
+}
+CAMERA_MODEL_IDS = dict([(camera_model.model_id, camera_model)
+ for camera_model in CAMERA_MODELS])
+CAMERA_MODEL_NAMES = dict([(camera_model.model_name, camera_model)
+ for camera_model in CAMERA_MODELS])
+
+
+def read_next_bytes(fid, num_bytes, format_char_sequence, endian_character="<"):
+ """Read and unpack the next bytes from a binary file.
+ :param fid:
+ :param num_bytes: Sum of combination of {2, 4, 8}, e.g. 2, 6, 16, 30, etc.
+ :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}.
+ :param endian_character: Any of {@, =, <, >, !}
+ :return: Tuple of read and unpacked values.
+ """
+ data = fid.read(num_bytes)
+ return struct.unpack(endian_character + format_char_sequence, data)
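+
+# Note: num_bytes must match the struct format string, e.g. format_char_sequence
+# "idddddddi" (one int, seven doubles, one int) corresponds to num_bytes=64, as used
+# when reading the per-image header in read_images_binary() below.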
+
+
+def write_next_bytes(fid, data, format_char_sequence, endian_character="<"):
+ """pack and write to a binary file.
+ :param fid:
+ :param data: data to send, if multiple elements are sent at the same time,
+ they should be encapsulated either in a list or a tuple
+ :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}.
+ should be the same length as the data list or tuple
+ :param endian_character: Any of {@, =, <, >, !}
+ """
+ if isinstance(data, (list, tuple)):
+ bytes = struct.pack(endian_character + format_char_sequence, *data)
+ else:
+ bytes = struct.pack(endian_character + format_char_sequence, data)
+ fid.write(bytes)
+
+
+def read_cameras_text(path):
+ """
+ see: src/base/reconstruction.cc
+ void Reconstruction::WriteCamerasText(const std::string& path)
+ void Reconstruction::ReadCamerasText(const std::string& path)
+ """
+ cameras = {}
+ with open(path, "r") as fid:
+ while True:
+ line = fid.readline()
+ if not line:
+ break
+ line = line.strip()
+ if len(line) > 0 and line[0] != "#":
+ elems = line.split()
+ camera_id = int(elems[0])
+ model = elems[1]
+ width = int(elems[2])
+ height = int(elems[3])
+ params = np.array(tuple(map(float, elems[4:])))
+ cameras[camera_id] = Camera(id=camera_id, model=model,
+ width=width, height=height,
+ params=params)
+ return cameras
+
+
+def read_cameras_binary(path_to_model_file):
+ """
+ see: src/base/reconstruction.cc
+ void Reconstruction::WriteCamerasBinary(const std::string& path)
+ void Reconstruction::ReadCamerasBinary(const std::string& path)
+ """
+ cameras = {}
+ with open(path_to_model_file, "rb") as fid:
+ num_cameras = read_next_bytes(fid, 8, "Q")[0]
+ for _ in range(num_cameras):
+ camera_properties = read_next_bytes(
+ fid, num_bytes=24, format_char_sequence="iiQQ")
+ camera_id = camera_properties[0]
+ model_id = camera_properties[1]
+ model_name = CAMERA_MODEL_IDS[camera_properties[1]].model_name
+ width = camera_properties[2]
+ height = camera_properties[3]
+ num_params = CAMERA_MODEL_IDS[model_id].num_params
+ params = read_next_bytes(fid, num_bytes=8*num_params,
+ format_char_sequence="d"*num_params)
+ cameras[camera_id] = Camera(id=camera_id,
+ model=model_name,
+ width=width,
+ height=height,
+ params=np.array(params))
+ assert len(cameras) == num_cameras
+ return cameras
+
+
+def write_cameras_text(cameras, path):
+ """
+ see: src/base/reconstruction.cc
+ void Reconstruction::WriteCamerasText(const std::string& path)
+ void Reconstruction::ReadCamerasText(const std::string& path)
+ """
+ HEADER = "# Camera list with one line of data per camera:\n" + \
+ "# CAMERA_ID, MODEL, WIDTH, HEIGHT, PARAMS[]\n" + \
+ "# Number of cameras: {}\n".format(len(cameras))
+ with open(path, "w") as fid:
+ fid.write(HEADER)
+ for _, cam in cameras.items():
+ to_write = [cam.id, cam.model, cam.width, cam.height, *cam.params]
+ line = " ".join([str(elem) for elem in to_write])
+ fid.write(line + "\n")
+
+
+def write_cameras_binary(cameras, path_to_model_file):
+ """
+ see: src/base/reconstruction.cc
+ void Reconstruction::WriteCamerasBinary(const std::string& path)
+ void Reconstruction::ReadCamerasBinary(const std::string& path)
+ """
+ with open(path_to_model_file, "wb") as fid:
+ write_next_bytes(fid, len(cameras), "Q")
+ for _, cam in cameras.items():
+ model_id = CAMERA_MODEL_NAMES[cam.model].model_id
+ camera_properties = [cam.id,
+ model_id,
+ cam.width,
+ cam.height]
+ write_next_bytes(fid, camera_properties, "iiQQ")
+ for p in cam.params:
+ write_next_bytes(fid, float(p), "d")
+ return cameras
+
+
+def read_images_text(path):
+ """
+ see: src/base/reconstruction.cc
+ void Reconstruction::ReadImagesText(const std::string& path)
+ void Reconstruction::WriteImagesText(const std::string& path)
+ """
+ images = {}
+ with open(path, "r") as fid:
+ while True:
+ line = fid.readline()
+ if not line:
+ break
+ line = line.strip()
+ if len(line) > 0 and line[0] != "#":
+ elems = line.split()
+ image_id = int(elems[0])
+ qvec = np.array(tuple(map(float, elems[1:5])))
+ tvec = np.array(tuple(map(float, elems[5:8])))
+ camera_id = int(elems[8])
+ image_name = elems[9]
+ elems = fid.readline().split()
+ xys = np.column_stack([tuple(map(float, elems[0::3])),
+ tuple(map(float, elems[1::3]))])
+ point3D_ids = np.array(tuple(map(int, elems[2::3])))
+ images[image_id] = Image(
+ id=image_id, qvec=qvec, tvec=tvec,
+ camera_id=camera_id, name=image_name,
+ xys=xys, point3D_ids=point3D_ids)
+ return images
+
+
+def read_images_binary(path_to_model_file):
+ """
+ see: src/base/reconstruction.cc
+ void Reconstruction::ReadImagesBinary(const std::string& path)
+ void Reconstruction::WriteImagesBinary(const std::string& path)
+ """
+ images = {}
+ with open(path_to_model_file, "rb") as fid:
+ num_reg_images = read_next_bytes(fid, 8, "Q")[0]
+ for _ in range(num_reg_images):
+ binary_image_properties = read_next_bytes(
+ fid, num_bytes=64, format_char_sequence="idddddddi")
+ image_id = binary_image_properties[0]
+ qvec = np.array(binary_image_properties[1:5])
+ tvec = np.array(binary_image_properties[5:8])
+ camera_id = binary_image_properties[8]
+ image_name = ""
+ current_char = read_next_bytes(fid, 1, "c")[0]
+ while current_char != b"\x00": # look for the ASCII 0 entry
+ image_name += current_char.decode("utf-8")
+ current_char = read_next_bytes(fid, 1, "c")[0]
+ num_points2D = read_next_bytes(fid, num_bytes=8,
+ format_char_sequence="Q")[0]
+ x_y_id_s = read_next_bytes(fid, num_bytes=24*num_points2D,
+ format_char_sequence="ddq"*num_points2D)
+ xys = np.column_stack([tuple(map(float, x_y_id_s[0::3])),
+ tuple(map(float, x_y_id_s[1::3]))])
+ point3D_ids = np.array(tuple(map(int, x_y_id_s[2::3])))
+ images[image_id] = Image(
+ id=image_id, qvec=qvec, tvec=tvec,
+ camera_id=camera_id, name=image_name,
+ xys=xys, point3D_ids=point3D_ids)
+ return images
+
+
+def write_images_text(images, path):
+ """
+ see: src/base/reconstruction.cc
+ void Reconstruction::ReadImagesText(const std::string& path)
+ void Reconstruction::WriteImagesText(const std::string& path)
+ """
+ if len(images) == 0:
+ mean_observations = 0
+ else:
+ mean_observations = sum((len(img.point3D_ids) for _, img in images.items()))/len(images)
+ HEADER = "# Image list with two lines of data per image:\n" + \
+ "# IMAGE_ID, QW, QX, QY, QZ, TX, TY, TZ, CAMERA_ID, NAME\n" + \
+ "# POINTS2D[] as (X, Y, POINT3D_ID)\n" + \
+ "# Number of images: {}, mean observations per image: {}\n".format(len(images), mean_observations)
+
+ with open(path, "w") as fid:
+ fid.write(HEADER)
+ for _, img in images.items():
+ image_header = [img.id, *img.qvec, *img.tvec, img.camera_id, img.name]
+ first_line = " ".join(map(str, image_header))
+ fid.write(first_line + "\n")
+
+ points_strings = []
+ for xy, point3D_id in zip(img.xys, img.point3D_ids):
+ points_strings.append(" ".join(map(str, [*xy, point3D_id])))
+ fid.write(" ".join(points_strings) + "\n")
+
+
+def write_images_binary(images, path_to_model_file):
+ """
+ see: src/base/reconstruction.cc
+ void Reconstruction::ReadImagesBinary(const std::string& path)
+ void Reconstruction::WriteImagesBinary(const std::string& path)
+ """
+ with open(path_to_model_file, "wb") as fid:
+ write_next_bytes(fid, len(images), "Q")
+ for _, img in images.items():
+ write_next_bytes(fid, img.id, "i")
+ write_next_bytes(fid, img.qvec.tolist(), "dddd")
+ write_next_bytes(fid, img.tvec.tolist(), "ddd")
+ write_next_bytes(fid, img.camera_id, "i")
+ for char in img.name:
+ write_next_bytes(fid, char.encode("utf-8"), "c")
+ write_next_bytes(fid, b"\x00", "c")
+ write_next_bytes(fid, len(img.point3D_ids), "Q")
+ for xy, p3d_id in zip(img.xys, img.point3D_ids):
+ write_next_bytes(fid, [*xy, p3d_id], "ddq")
+
+
+def read_points3D_text(path):
+ """
+ see: src/base/reconstruction.cc
+ void Reconstruction::ReadPoints3DText(const std::string& path)
+ void Reconstruction::WritePoints3DText(const std::string& path)
+ """
+ points3D = {}
+ with open(path, "r") as fid:
+ while True:
+ line = fid.readline()
+ if not line:
+ break
+ line = line.strip()
+ if len(line) > 0 and line[0] != "#":
+ elems = line.split()
+ point3D_id = int(elems[0])
+ xyz = np.array(tuple(map(float, elems[1:4])))
+ rgb = np.array(tuple(map(int, elems[4:7])))
+ error = float(elems[7])
+ image_ids = np.array(tuple(map(int, elems[8::2])))
+ point2D_idxs = np.array(tuple(map(int, elems[9::2])))
+ points3D[point3D_id] = Point3D(id=point3D_id, xyz=xyz, rgb=rgb,
+ error=error, image_ids=image_ids,
+ point2D_idxs=point2D_idxs)
+ return points3D
+
+
+def read_points3d_binary(path_to_model_file):
+ """
+ see: src/base/reconstruction.cc
+ void Reconstruction::ReadPoints3DBinary(const std::string& path)
+ void Reconstruction::WritePoints3DBinary(const std::string& path)
+ """
+ points3D = {}
+ with open(path_to_model_file, "rb") as fid:
+ num_points = read_next_bytes(fid, 8, "Q")[0]
+ for _ in range(num_points):
+ binary_point_line_properties = read_next_bytes(
+ fid, num_bytes=43, format_char_sequence="QdddBBBd")
+ point3D_id = binary_point_line_properties[0]
+ xyz = np.array(binary_point_line_properties[1:4])
+ rgb = np.array(binary_point_line_properties[4:7])
+ error = np.array(binary_point_line_properties[7])
+ track_length = read_next_bytes(
+ fid, num_bytes=8, format_char_sequence="Q")[0]
+ track_elems = read_next_bytes(
+ fid, num_bytes=8*track_length,
+ format_char_sequence="ii"*track_length)
+ image_ids = np.array(tuple(map(int, track_elems[0::2])))
+ point2D_idxs = np.array(tuple(map(int, track_elems[1::2])))
+ points3D[point3D_id] = Point3D(
+ id=point3D_id, xyz=xyz, rgb=rgb,
+ error=error, image_ids=image_ids,
+ point2D_idxs=point2D_idxs)
+ return points3D
+
+def write_points3D_text(points3D, path):
+ """
+ see: src/base/reconstruction.cc
+ void Reconstruction::ReadPoints3DText(const std::string& path)
+ void Reconstruction::WritePoints3DText(const std::string& path)
+ """
+ if len(points3D) == 0:
+ mean_track_length = 0
+ else:
+ mean_track_length = sum((len(pt.image_ids) for _, pt in points3D.items()))/len(points3D)
+ HEADER = "# 3D point list with one line of data per point:\n" + \
+ "# POINT3D_ID, X, Y, Z, R, G, B, ERROR, TRACK[] as (IMAGE_ID, POINT2D_IDX)\n" + \
+ "# Number of points: {}, mean track length: {}\n".format(len(points3D), mean_track_length)
+
+ with open(path, "w") as fid:
+ fid.write(HEADER)
+ for _, pt in points3D.items():
+ point_header = [pt.id, *pt.xyz, *pt.rgb, pt.error]
+ fid.write(" ".join(map(str, point_header)) + " ")
+ track_strings = []
+ for image_id, point2D in zip(pt.image_ids, pt.point2D_idxs):
+ track_strings.append(" ".join(map(str, [image_id, point2D])))
+ fid.write(" ".join(track_strings) + "\n")
+
+
+def write_points3d_binary(points3D, path_to_model_file):
+ """
+ see: src/base/reconstruction.cc
+ void Reconstruction::ReadPoints3DBinary(const std::string& path)
+ void Reconstruction::WritePoints3DBinary(const std::string& path)
+ """
+ with open(path_to_model_file, "wb") as fid:
+ write_next_bytes(fid, len(points3D), "Q")
+ for _, pt in points3D.items():
+ write_next_bytes(fid, pt.id, "Q")
+ write_next_bytes(fid, pt.xyz.tolist(), "ddd")
+ write_next_bytes(fid, pt.rgb.tolist(), "BBB")
+ write_next_bytes(fid, pt.error, "d")
+ track_length = pt.image_ids.shape[0]
+ write_next_bytes(fid, track_length, "Q")
+ for image_id, point2D_id in zip(pt.image_ids, pt.point2D_idxs):
+ write_next_bytes(fid, [image_id, point2D_id], "ii")
+
+
+def detect_model_format(path, ext):
+ if os.path.isfile(os.path.join(path, "cameras" + ext)) and \
+ os.path.isfile(os.path.join(path, "images" + ext)) and \
+ os.path.isfile(os.path.join(path, "points3D" + ext)):
+ print("Detected model format: '" + ext + "'")
+ return True
+
+ return False
+
+
+def read_model(path, ext=""):
+ # try to detect the extension automatically
+ if ext == "":
+ if detect_model_format(path, ".bin"):
+ ext = ".bin"
+ elif detect_model_format(path, ".txt"):
+ ext = ".txt"
+ else:
+ print("Provide model format: '.bin' or '.txt'")
+ return
+
+ if ext == ".txt":
+ cameras = read_cameras_text(os.path.join(path, "cameras" + ext))
+ images = read_images_text(os.path.join(path, "images" + ext))
+ points3D = read_points3D_text(os.path.join(path, "points3D") + ext)
+ else:
+ cameras = read_cameras_binary(os.path.join(path, "cameras" + ext))
+ images = read_images_binary(os.path.join(path, "images" + ext))
+ points3D = read_points3d_binary(os.path.join(path, "points3D") + ext)
+ return cameras, images, points3D
+
+
+def write_model(cameras, images, points3D, path, ext=".bin"):
+ if ext == ".txt":
+ write_cameras_text(cameras, os.path.join(path, "cameras" + ext))
+ write_images_text(images, os.path.join(path, "images" + ext))
+ write_points3D_text(points3D, os.path.join(path, "points3D") + ext)
+ else:
+ write_cameras_binary(cameras, os.path.join(path, "cameras" + ext))
+ write_images_binary(images, os.path.join(path, "images" + ext))
+ write_points3d_binary(points3D, os.path.join(path, "points3D") + ext)
+ return cameras, images, points3D
+
+
+def qvec2rotmat(qvec):
+ return np.array([
+ [1 - 2 * qvec[2]**2 - 2 * qvec[3]**2,
+ 2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3],
+ 2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2]],
+ [2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3],
+ 1 - 2 * qvec[1]**2 - 2 * qvec[3]**2,
+ 2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1]],
+ [2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2],
+ 2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1],
+ 1 - 2 * qvec[1]**2 - 2 * qvec[2]**2]])
+
+
+def rotmat2qvec(R):
+ Rxx, Ryx, Rzx, Rxy, Ryy, Rzy, Rxz, Ryz, Rzz = R.flat
+ K = np.array([
+ [Rxx - Ryy - Rzz, 0, 0, 0],
+ [Ryx + Rxy, Ryy - Rxx - Rzz, 0, 0],
+ [Rzx + Rxz, Rzy + Ryz, Rzz - Rxx - Ryy, 0],
+ [Ryz - Rzy, Rzx - Rxz, Rxy - Ryx, Rxx + Ryy + Rzz]]) / 3.0
+ eigvals, eigvecs = np.linalg.eigh(K)
+ qvec = eigvecs[[3, 0, 1, 2], np.argmax(eigvals)]
+ if qvec[0] < 0:
+ qvec *= -1
+ return qvec
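+
+# qvec2rotmat() and rotmat2qvec() are inverses up to quaternion sign: for a unit
+# quaternion q with q[0] >= 0, rotmat2qvec(qvec2rotmat(q)) recovers q, since eigh()
+# returns unit eigenvectors and the sign is fixed by the qvec[0] < 0 check above.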
+
+
+def main():
+ parser = argparse.ArgumentParser(description="Read and write COLMAP binary and text models")
+ parser.add_argument("--input_model", help="path to input model folder")
+ parser.add_argument("--input_format", choices=[".bin", ".txt"],
+ help="input model format", default="")
+ parser.add_argument("--output_model",
+ help="path to output model folder")
+ parser.add_argument("--output_format", choices=[".bin", ".txt"],
+ help="output model format", default=".txt")
+ args = parser.parse_args()
+
+ cameras, images, points3D = read_model(path=args.input_model, ext=args.input_format)
+
+ # FIXME: for debug only
+ # images_ = images[1]
+ # tvec, qvec = images_.tvec, images_.qvec
+ # rotation = qvec2rotmat(qvec).reshape(3, 3)
+ # pose = np.concatenate([rotation, tvec.reshape(3, 1)], axis=1)
+ # import ipdb; ipdb.set_trace()
+
+ print("num_cameras:", len(cameras))
+ print("num_images:", len(images))
+ print("num_points3D:", len(points3D))
+
+ if args.output_model is not None:
+ write_model(cameras, images, points3D, path=args.output_model, ext=args.output_format)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/imcui/third_party/MatchAnything/src/utils/comm.py b/imcui/third_party/MatchAnything/src/utils/comm.py
new file mode 100644
index 0000000000000000000000000000000000000000..26ec9517cc47e224430106d8ae9aa99a3fe49167
--- /dev/null
+++ b/imcui/third_party/MatchAnything/src/utils/comm.py
@@ -0,0 +1,265 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+"""
+[Copied from detectron2]
+This file contains primitives for multi-gpu communication.
+This is useful when doing distributed training.
+"""
+
+import functools
+import logging
+import numpy as np
+import pickle
+import torch
+import torch.distributed as dist
+
+_LOCAL_PROCESS_GROUP = None
+"""
+A torch process group which only includes processes that are on the same machine as the current process.
+This variable is set when processes are spawned by `launch()` in "engine/launch.py".
+"""
+
+
+def get_world_size() -> int:
+ if not dist.is_available():
+ return 1
+ if not dist.is_initialized():
+ return 1
+ return dist.get_world_size()
+
+
+def get_rank() -> int:
+ if not dist.is_available():
+ return 0
+ if not dist.is_initialized():
+ return 0
+ return dist.get_rank()
+
+
+def get_local_rank() -> int:
+ """
+ Returns:
+ The rank of the current process within the local (per-machine) process group.
+ """
+ if not dist.is_available():
+ return 0
+ if not dist.is_initialized():
+ return 0
+ assert _LOCAL_PROCESS_GROUP is not None
+ return dist.get_rank(group=_LOCAL_PROCESS_GROUP)
+
+
+def get_local_size() -> int:
+ """
+ Returns:
+ The size of the per-machine process group,
+ i.e. the number of processes per machine.
+ """
+ if not dist.is_available():
+ return 1
+ if not dist.is_initialized():
+ return 1
+ return dist.get_world_size(group=_LOCAL_PROCESS_GROUP)
+
+
+def is_main_process() -> bool:
+ return get_rank() == 0
+
+
+def synchronize():
+ """
+ Helper function to synchronize (barrier) among all processes when
+ using distributed training
+ """
+ if not dist.is_available():
+ return
+ if not dist.is_initialized():
+ return
+ world_size = dist.get_world_size()
+ if world_size == 1:
+ return
+ dist.barrier()
+
+
+@functools.lru_cache()
+def _get_global_gloo_group():
+ """
+ Return a process group based on gloo backend, containing all the ranks
+ The result is cached.
+ """
+ if dist.get_backend() == "nccl":
+ return dist.new_group(backend="gloo")
+ else:
+ return dist.group.WORLD
+
+
+def _serialize_to_tensor(data, group):
+ backend = dist.get_backend(group)
+ assert backend in ["gloo", "nccl"]
+ device = torch.device("cpu" if backend == "gloo" else "cuda")
+
+ buffer = pickle.dumps(data)
+ if len(buffer) > 1024 ** 3:
+ logger = logging.getLogger(__name__)
+ logger.warning(
+ "Rank {} trying to all-gather {:.2f} GB of data on device {}".format(
+ get_rank(), len(buffer) / (1024 ** 3), device
+ )
+ )
+ storage = torch.ByteStorage.from_buffer(buffer)
+ tensor = torch.ByteTensor(storage).to(device=device)
+ return tensor
+
+
+def _pad_to_largest_tensor(tensor, group):
+ """
+ Returns:
+ list[int]: size of the tensor, on each rank
+ Tensor: padded tensor that has the max size
+ """
+ world_size = dist.get_world_size(group=group)
+ assert (
+ world_size >= 1
+ ), "comm.gather/all_gather must be called from ranks within the given group!"
+ local_size = torch.tensor([tensor.numel()], dtype=torch.int64, device=tensor.device)
+ size_list = [
+ torch.zeros([1], dtype=torch.int64, device=tensor.device) for _ in range(world_size)
+ ]
+ dist.all_gather(size_list, local_size, group=group)
+
+ size_list = [int(size.item()) for size in size_list]
+
+ max_size = max(size_list)
+
+ # we pad the tensor because torch all_gather does not support
+ # gathering tensors of different shapes
+ if local_size != max_size:
+ padding = torch.zeros((max_size - local_size,), dtype=torch.uint8, device=tensor.device)
+ tensor = torch.cat((tensor, padding), dim=0)
+ return size_list, tensor
+
+
+def all_gather(data, group=None):
+ """
+ Run all_gather on arbitrary picklable data (not necessarily tensors).
+
+ Args:
+ data: any picklable object
+ group: a torch process group. By default, will use a group which
+ contains all ranks on gloo backend.
+
+ Returns:
+ list[data]: list of data gathered from each rank
+ """
+ if get_world_size() == 1:
+ return [data]
+ if group is None:
+ group = _get_global_gloo_group()
+ if dist.get_world_size(group) == 1:
+ return [data]
+
+ tensor = _serialize_to_tensor(data, group)
+
+ size_list, tensor = _pad_to_largest_tensor(tensor, group)
+ max_size = max(size_list)
+
+ # receiving Tensor from all ranks
+ tensor_list = [
+ torch.empty((max_size,), dtype=torch.uint8, device=tensor.device) for _ in size_list
+ ]
+ dist.all_gather(tensor_list, tensor, group=group)
+
+ data_list = []
+ for size, tensor in zip(size_list, tensor_list):
+ buffer = tensor.cpu().numpy().tobytes()[:size]
+ data_list.append(pickle.loads(buffer))
+
+ return data_list
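+
+# Usage sketch: every rank calls all_gather() with its own (picklable) object and
+# receives the full list, e.g.
+#   metrics = {"loss": 0.5, "n": 128}      # hypothetical per-rank dict
+#   all_metrics = all_gather(metrics)      # list with world_size entries
+# The call is collective: all ranks in the group must invoke it, or it will hang.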
+
+
+def gather(data, dst=0, group=None):
+ """
+ Run gather on arbitrary picklable data (not necessarily tensors).
+
+ Args:
+ data: any picklable object
+ dst (int): destination rank
+ group: a torch process group. By default, will use a group which
+ contains all ranks on gloo backend.
+
+ Returns:
+ list[data]: on dst, a list of data gathered from each rank. Otherwise,
+ an empty list.
+ """
+ if get_world_size() == 1:
+ return [data]
+ if group is None:
+ group = _get_global_gloo_group()
+ if dist.get_world_size(group=group) == 1:
+ return [data]
+ rank = dist.get_rank(group=group)
+
+ tensor = _serialize_to_tensor(data, group)
+ size_list, tensor = _pad_to_largest_tensor(tensor, group)
+
+ # receiving Tensor from all ranks
+ if rank == dst:
+ max_size = max(size_list)
+ tensor_list = [
+ torch.empty((max_size,), dtype=torch.uint8, device=tensor.device) for _ in size_list
+ ]
+ dist.gather(tensor, tensor_list, dst=dst, group=group)
+
+ data_list = []
+ for size, tensor in zip(size_list, tensor_list):
+ buffer = tensor.cpu().numpy().tobytes()[:size]
+ data_list.append(pickle.loads(buffer))
+ return data_list
+ else:
+ dist.gather(tensor, [], dst=dst, group=group)
+ return []
+
+
+def shared_random_seed():
+ """
+ Returns:
+ int: a random number that is the same across all workers.
+ If workers need a shared RNG, they can use this shared seed to
+ create one.
+
+ All workers must call this function, otherwise it will deadlock.
+ """
+ ints = np.random.randint(2 ** 31)
+ all_ints = all_gather(ints)
+ return all_ints[0]
+
+
+def reduce_dict(input_dict, average=True):
+ """
+ Reduce the values in the dictionary from all processes so that process with rank
+ 0 has the reduced results.
+
+ Args:
+ input_dict (dict): inputs to be reduced. All the values must be scalar CUDA Tensor.
+ average (bool): whether to do average or sum
+
+ Returns:
+ a dict with the same keys as input_dict, after reduction.
+ """
+ world_size = get_world_size()
+ if world_size < 2:
+ return input_dict
+ with torch.no_grad():
+ names = []
+ values = []
+ # sort the keys so that they are consistent across processes
+ for k in sorted(input_dict.keys()):
+ names.append(k)
+ values.append(input_dict[k])
+ values = torch.stack(values, dim=0)
+ dist.reduce(values, dst=0)
+ if dist.get_rank() == 0 and average:
+ # only main process gets accumulated, so only divide by
+ # world_size in this case
+ values /= world_size
+ reduced_dict = {k: v for k, v in zip(names, values)}
+ return reduced_dict
diff --git a/imcui/third_party/MatchAnything/src/utils/dataloader.py b/imcui/third_party/MatchAnything/src/utils/dataloader.py
new file mode 100644
index 0000000000000000000000000000000000000000..6da37b880a290c2bb3ebb028d0c8dab592acc5c1
--- /dev/null
+++ b/imcui/third_party/MatchAnything/src/utils/dataloader.py
@@ -0,0 +1,23 @@
+import numpy as np
+
+
+# --- PL-DATAMODULE ---
+
+def get_local_split(items: list, world_size: int, rank: int, seed: int):
+ """ The local rank only loads a split of the dataset. """
+ n_items = len(items)
+ items_permute = np.random.RandomState(seed).permutation(items)
+ if n_items % world_size == 0:
+ padded_items = items_permute
+ else:
+ padding = np.random.RandomState(seed).choice(
+ items,
+ world_size - (n_items % world_size),
+ replace=True)
+ padded_items = np.concatenate([items_permute, padding])
+ assert len(padded_items) % world_size == 0, \
+ f'len(padded_items): {len(padded_items)}; world_size: {world_size}; len(padding): {len(padding)}'
+ n_per_rank = len(padded_items) // world_size
+ local_items = padded_items[n_per_rank * rank: n_per_rank * (rank+1)]
+
+ return local_items
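+
+# Example: with 10 items, world_size=4 and a fixed seed, the permuted list is padded
+# to 12 items (2 resampled with replacement) and each rank receives a disjoint slice
+# of 3 items, so every rank ends up with a local split of identical size.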
diff --git a/imcui/third_party/MatchAnything/src/utils/dataset.py b/imcui/third_party/MatchAnything/src/utils/dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..73a0a96db5ef2c08c99394a25e2db306bdb47b6a
--- /dev/null
+++ b/imcui/third_party/MatchAnything/src/utils/dataset.py
@@ -0,0 +1,518 @@
+import io
+from loguru import logger
+
+import cv2
+import numpy as np
+from pathlib import Path
+import h5py
+import torch
+import re
+from PIL import Image
+from numpy.linalg import inv
+from torchvision.transforms import Normalize
+from .sample_homo import sample_homography_sap
+from kornia.geometry import homography_warp, normalize_homography, normal_transform_pixel
+OSS_FOLDER_PATH = '???'
+PCACHE_FOLDER_PATH = '???'
+
+import fsspec
+from PIL import Image
+
+# Initialize pcache
+try:
+ PCACHE_HOST = "???"
+ PCACHE_PORT = 00000
+ pcache_kwargs = {"host": PCACHE_HOST, "port": PCACHE_PORT}
+ pcache_fs = fsspec.filesystem("pcache", pcache_kwargs=pcache_kwargs)
+ root_dir='???'
+except Exception as e:
+ logger.error(f"Error captured:{e}")
+
+try:
+ # for internal use only
+ from pcache_fileio import fileio
+except Exception:
+ MEGADEPTH_CLIENT = SCANNET_CLIENT = None
+
+# --- DATA IO ---
+
+def load_pfm(pfm_path):
+ with open(pfm_path, 'rb') as fin:
+ color = None
+ width = None
+ height = None
+ scale = None
+ data_type = None
+ header = str(fin.readline().decode('UTF-8')).rstrip()
+
+ if header == 'PF':
+ color = True
+ elif header == 'Pf':
+ color = False
+ else:
+ raise Exception('Not a PFM file.')
+
+ dim_match = re.match(r'^(\d+)\s(\d+)\s$', fin.readline().decode('UTF-8'))
+ if dim_match:
+ width, height = map(int, dim_match.groups())
+ else:
+ raise Exception('Malformed PFM header.')
+ scale = float((fin.readline().decode('UTF-8')).rstrip())
+ if scale < 0: # little-endian
+ data_type = '<f'
+ else: # big-endian
+ data_type = '>f'
+ data_string = fin.read()
+ data = np.frombuffer(data_string, data_type)
+ shape = (height, width, 3) if color else (height, width)
+ data = np.reshape(data, shape)
+ data = np.flipud(data) # PFM stores rows bottom-to-top
+ return data
+
+
+def load_array_from_pcache(path, cv_type, use_h5py=False):
+ # Reads an image (decoded with OpenCV via `cv_type`) or a depth map (the '/depth'
+ # dataset of an HDF5 file) from pcache storage, retrying on transient failures.
+ # The image branch below is a sketch that mirrors the retry loop of the h5py branch.
+ filename = path.split(root_dir)[1]
+ pcache_path = Path(root_dir) / filename
+ try:
+ if not use_h5py:
+ load_failed = True
+ failed_num = 0
+ while load_failed:
+ try:
+ with pcache_fs.open(str(pcache_path), 'rb') as f:
+ data = cv2.imdecode(np.frombuffer(f.read(), np.uint8), cv_type)
+ load_failed = False
+ except:
+ failed_num += 1
+ if failed_num > 5000:
+ logger.error(f"Try to load: {pcache_path}, but failed {failed_num} times")
+ continue
+ else:
+ load_failed = True
+ failed_num = 0
+ while load_failed:
+ try:
+ with pcache_fs.open(str(pcache_path), 'rb') as f:
+ data = np.array(h5py.File(io.BytesIO(f.read()), 'r')['/depth'])
+ load_failed = False
+ except:
+ failed_num += 1
+ if failed_num > 5000:
+ logger.error(f"Try to load: {pcache_path}, but failed {failed_num} times")
+ continue
+
+ except Exception as ex:
+ print(f"==> Data loading failure: {path}")
+ raise ex
+
+ assert data is not None
+ return data
+
+
+def imread_gray(path, augment_fn=None, cv_type=None):
+ if path.startswith('oss://'):
+ path = path.replace(OSS_FOLDER_PATH, PCACHE_FOLDER_PATH)
+ if path.startswith('pcache://'):
+ path = path[:9] + path[9:].replace('////', '/').replace('///', '/').replace('//', '/') # remove all continuous '/'
+
+ if cv_type is None:
+ cv_type = cv2.IMREAD_GRAYSCALE if augment_fn is None \
+ else cv2.IMREAD_COLOR
+ if str(path).startswith('oss://') or str(path).startswith('pcache://'):
+ image = load_array_from_pcache(str(path), cv_type)
+ else:
+ image = cv2.imread(str(path), cv_type)
+
+ if augment_fn is not None:
+ image = cv2.imread(str(path), cv2.IMREAD_COLOR)
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+ image = augment_fn(image)
+ image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
+ return image # (h, w)
+
+def imread_color(path, augment_fn=None):
+ if path.startswith('oss://'):
+ path = path.replace(OSS_FOLDER_PATH, PCACHE_FOLDER_PATH)
+ if path.startswith('pcache://'):
+ path = path[:9] + path[9:].replace('////', '/').replace('///', '/').replace('//', '/') # remove all continuous '/'
+
+ if str(path).startswith('oss://') or str(path).startswith('pcache://'):
+ filename = path.split(root_dir)[1]
+ pcache_path = Path(root_dir) / filename
+ load_failed = True
+ failed_num = 0
+ while load_failed:
+ try:
+ with pcache_fs.open(str(pcache_path), 'rb') as f:
+ pil_image = Image.open(f).convert("RGB")
+ load_failed = False
+ except:
+ failed_num += 1
+ if failed_num > 5000:
+ logger.error(f"Try to load: {pcache_path}, but failed {failed_num} times")
+ continue
+ else:
+ pil_image = Image.open(str(path)).convert("RGB")
+ image = np.array(pil_image)
+
+ if augment_fn is not None:
+ image = augment_fn(image)
+ return image # (h, w)
+
+
+def get_resized_wh(w, h, resize=None):
+ if resize is not None: # resize the longer edge
+ scale = resize / max(h, w)
+ w_new, h_new = int(round(w*scale)), int(round(h*scale))
+ else:
+ w_new, h_new = w, h
+ return w_new, h_new
+
+
+def get_divisible_wh(w, h, df=None):
+ if df is not None:
+ w_new, h_new = map(lambda x: int(x // df * df), [w, h])
+ else:
+ w_new, h_new = w, h
+ return w_new, h_new
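+
+# Example: a 1500x2000 (h x w) image with resize=832 and df=8 first becomes 624x832
+# via get_resized_wh (longer edge scaled to 832), and get_divisible_wh keeps it at
+# 624x832 since both sides are already multiples of 8.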
+
+
+def pad_bottom_right(inp, pad_size, ret_mask=False):
+ assert isinstance(pad_size, int) and pad_size >= max(inp.shape[-2:]), f"{pad_size} < {max(inp.shape[-2:])}"
+ mask = None
+ if inp.ndim == 2:
+ padded = np.zeros((pad_size, pad_size), dtype=inp.dtype)
+ padded[:inp.shape[0], :inp.shape[1]] = inp
+ if ret_mask:
+ mask = np.zeros((pad_size, pad_size), dtype=bool)
+ mask[:inp.shape[0], :inp.shape[1]] = True
+ elif inp.ndim == 3:
+ padded = np.zeros((inp.shape[0], pad_size, pad_size), dtype=inp.dtype)
+ padded[:, :inp.shape[1], :inp.shape[2]] = inp
+ if ret_mask:
+ mask = np.zeros((inp.shape[0], pad_size, pad_size), dtype=bool)
+ mask[:, :inp.shape[1], :inp.shape[2]] = True
+ mask = mask[0]
+ else:
+ raise NotImplementedError()
+ return padded, mask
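+
+# Example: pad_bottom_right(img, 640, ret_mask=True) on a (480, 640) grayscale image
+# returns a (640, 640) array with the original content in the top-left corner and a
+# boolean mask that is True exactly over the original 480x640 region.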
+
+
+# --- MEGADEPTH ---
+
+def read_megadepth_gray(path, resize=None, df=None, padding=False, augment_fn=None, read_gray=True, normalize_img=False, resize_by_stretch=False):
+ """
+ Args:
+ resize (int, optional): the longer edge of resized images. None for no resize.
+ padding (bool): If set to 'True', zero-pad resized images to squared size.
+ augment_fn (callable, optional): augments images with pre-defined visual effects
+ Returns:
+ image (torch.tensor): (1, h, w)
+ mask (torch.tensor): (h, w)
+ scale (torch.tensor): [w/w_new, h/h_new]
+ """
+ # read image
+ if read_gray:
+ image = imread_gray(path, augment_fn)
+ else:
+ image = imread_color(path, augment_fn)
+
+ # resize image
+ try:
+ w, h = image.shape[1], image.shape[0]
+ except Exception:
+ logger.error(f"{path} does not exist or the image could not be read!")
+ raise
+ if resize_by_stretch:
+ w_new, h_new = (resize, resize) if isinstance(resize, int) else (resize[1], resize[0])
+ else:
+ if resize:
+ if not isinstance(resize, int):
+ assert resize[0] == resize[1]
+ resize = resize[0]
+ w_new, h_new = get_resized_wh(w, h, resize)
+ w_new, h_new = get_divisible_wh(w_new, h_new, df)
+ else:
+ w_new, h_new = w, h
+
+ image = cv2.resize(image, (w_new, h_new))
+ scale = torch.tensor([w/w_new, h/h_new], dtype=torch.float)
+ origin_img_size = torch.tensor([h, w], dtype=torch.float)
+
+ if not read_gray:
+ image = image.transpose(2,0,1)
+
+ if padding: # padding
+ pad_to = max(h_new, w_new)
+ image, mask = pad_bottom_right(image, pad_to, ret_mask=True)
+ else:
+ mask = None
+
+ if len(image.shape) == 2:
+ image = torch.from_numpy(image).float()[None] / 255 # (h, w) -> (1, h, w) and normalized
+ else:
+ image = torch.from_numpy(image).float() / 255 # (h, w) -> (1, h, w) and normalized
+ if mask is not None:
+ mask = torch.from_numpy(mask)
+
+ if image.shape[0] == 3 and normalize_img:
+ # Normalize image:
+ image = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(image) # Input: 3*H*W
+
+ return image, mask, scale, origin_img_size
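+
+# Usage sketch (path is a placeholder): for a local landscape image,
+#   image, mask, scale, size = read_megadepth_gray('img.jpg', resize=832, df=8, padding=True)
+# yields image of shape (1, 832, 832), mask of shape (832, 832), and
+# scale = [w / w_new, h / h_new] for mapping keypoints back to the original resolution.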
+
+def read_megadepth_gray_sample_homowarp(path, resize=None, df=None, padding=False, augment_fn=None, read_gray=True, normalize_img=False, resize_by_stretch=False):
+ """
+ Args:
+ resize (int, optional): the longer edge of resized images. None for no resize.
+ padding (bool): If set to 'True', zero-pad resized images to squared size.
+ augment_fn (callable, optional): augments images with pre-defined visual effects
+ Returns:
+ image (torch.tensor): (1, h, w)
+ mask (torch.tensor): (h, w)
+ scale (torch.tensor): [w/w_new, h/h_new]
+ """
+ # read image
+ if read_gray:
+ image = imread_gray(path, augment_fn)
+ else:
+ image = imread_color(path, augment_fn)
+
+ # resize image
+ w, h = image.shape[1], image.shape[0]
+ if resize_by_stretch:
+ w_new, h_new = (resize, resize) if isinstance(resize, int) else (resize[1], resize[0])
+ else:
+ if not isinstance(resize, int):
+ assert resize[0] == resize[1]
+ resize = resize[0]
+ w_new, h_new = get_resized_wh(w, h, resize)
+ w_new, h_new = get_divisible_wh(w_new, h_new, df)
+
+ w_new, h_new = get_divisible_wh(w_new, h_new, df)
+
+ origin_img_size = torch.tensor([h, w], dtype=torch.float)
+
+ # Sample homography and warp:
+ homo_sampled = sample_homography_sap(h, w) # 3*3
+ homo_sampled_normed = normalize_homography(
+ torch.from_numpy(homo_sampled[None]).to(torch.float32),
+ (h, w),
+ (h, w),
+ )
+
+ if len(image.shape) == 2:
+ image = torch.from_numpy(image).float()[None, None] / 255 # B * C * H * W
+ else:
+ image = torch.from_numpy(image).float().permute(2,0,1)[None] / 255
+
+ homo_warpped_image = homography_warp(
+ image, # 1 * C * H * W
+ torch.linalg.inv(homo_sampled_normed),
+ (h, w),
+ )
+ image = (homo_warpped_image[0].permute(1,2,0).numpy() * 255).astype(np.uint8)
+ norm_pixel_mat = normal_transform_pixel(h, w) # 1 * 3 * 3
+
+ image = cv2.resize(image, (w_new, h_new))
+ scale = torch.tensor([w/w_new, h/h_new], dtype=torch.float)
+
+ if not read_gray:
+ image = image.transpose(2,0,1)
+
+ if padding: # padding
+ pad_to = max(h_new, w_new)
+ image, mask = pad_bottom_right(image, pad_to, ret_mask=True)
+ else:
+ mask = None
+
+ if len(image.shape) == 2:
+ image = torch.from_numpy(image).float()[None] / 255 # (h, w) -> (1, h, w) and normalized
+ else:
+ image = torch.from_numpy(image).float() / 255 # (h, w) -> (1, h, w) and normalized
+ if mask is not None:
+ mask = torch.from_numpy(mask)
+
+ if image.shape[0] == 3 and normalize_img:
+ # Normalize image:
+ image = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(image) # Input: 3*H*W
+
+ return image, mask, scale, origin_img_size, norm_pixel_mat[0], homo_sampled_normed[0]
+
+
+def read_megadepth_depth_gray(path, resize=None, df=None, padding=False, augment_fn=None, read_gray=True, normalize_img=False, resize_by_stretch=False):
+ """
+ Args:
+ resize (int, optional): the longer edge of resized images. None for no resize.
+ padding (bool): If set to 'True', zero-pad resized images to squared size.
+ augment_fn (callable, optional): augments images with pre-defined visual effects
+ Returns:
+ image (torch.tensor): (1, h, w)
+ mask (torch.tensor): (h, w)
+ scale (torch.tensor): [w/w_new, h/h_new]
+ """
+ depth = read_megadepth_depth(path, return_tensor=False)
+
+ # following controlnet 1-depth
+ depth = depth.astype(np.float64)
+ depth_non_zero = depth[depth!=0]
+ vmin = np.percentile(depth_non_zero, 2)
+ vmax = np.percentile(depth_non_zero, 85)
+ depth -= vmin
+ depth /= (vmax - vmin + 1e-4)
+ depth = 1.0 - depth
+ image = (depth * 255.0).clip(0, 255).astype(np.uint8)
+
+ # resize image
+ w, h = image.shape[1], image.shape[0]
+ if resize_by_stretch:
+ w_new, h_new = (resize, resize) if isinstance(resize, int) else (resize[1], resize[0])
+ else:
+ if not isinstance(resize, int):
+ assert resize[0] == resize[1]
+ resize = resize[0]
+ w_new, h_new = get_resized_wh(w, h, resize)
+ w_new, h_new = get_divisible_wh(w_new, h_new, df)
+ w_new, h_new = get_divisible_wh(w_new, h_new, df)
+ origin_img_size = torch.tensor([h, w], dtype=torch.float)
+
+ image = cv2.resize(image, (w_new, h_new))
+ scale = torch.tensor([w/w_new, h/h_new], dtype=torch.float)
+
+ if padding: # padding
+ pad_to = max(h_new, w_new)
+ image, mask = pad_bottom_right(image, pad_to, ret_mask=True)
+ else:
+ mask = None
+
+ if read_gray:
+ image = torch.from_numpy(image).float()[None] / 255 # (h, w) -> (1, h, w) and normalized
+ else:
+ image = np.stack([image]*3) # 3 * H * W
+ image = torch.from_numpy(image).float() / 255 # (h, w) -> (1, h, w) and normalized
+ if mask is not None:
+ mask = torch.from_numpy(mask)
+
+ if image.shape[0] == 3 and normalize_img:
+ # Normalize image:
+ image = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(image) # Input: 3*H*W
+
+ return image, mask, scale, origin_img_size
+
+def read_megadepth_depth(path, pad_to=None, return_tensor=True):
+ if path.startswith('oss://'):
+ path = path.replace(OSS_FOLDER_PATH, PCACHE_FOLDER_PATH)
+ if path.startswith('pcache://'):
+ path = path[:9] + path[9:].replace('////', '/').replace('///', '/').replace('//', '/') # remove all continuous '/'
+
+ load_failed = True
+ failed_num = 0
+ while load_failed:
+ try:
+ if '.png' in path:
+ if 'scannet_plus' in path:
+ depth = imread_gray(path, cv_type=cv2.IMREAD_UNCHANGED).astype(np.float32)
+
+ with open(path, 'rb') as f:
+ # CO3D
+ depth = np.asarray(Image.open(f)).astype(np.float32)
+ depth = depth / 1000
+ elif '.pfm' in path:
+ # For BlendedMVS dataset (not support pcache):
+ depth = load_pfm(path).copy()
+ else:
+ # For MegaDepth
+ if str(path).startswith('oss://') or str(path).startswith('pcache://'):
+ depth = load_array_from_pcache(path, None, use_h5py=True)
+ else:
+ depth = np.array(h5py.File(path, 'r')['depth'])
+ load_failed = False
+ except:
+ failed_num += 1
+ if failed_num > 5000:
+ logger.error(f"Try to load: {path}, but failed {failed_num} times")
+ continue
+
+ if pad_to is not None:
+ depth, _ = pad_bottom_right(depth, pad_to, ret_mask=False)
+ if return_tensor:
+ depth = torch.from_numpy(depth).float() # (h, w)
+ return depth
+
+
+# --- ScanNet ---
+
+def read_scannet_gray(path, resize=(640, 480), augment_fn=None):
+ """
+ Args:
+ resize (tuple): align image to depthmap, in (w, h).
+ augment_fn (callable, optional): augments images with pre-defined visual effects
+ Returns:
+ image (torch.tensor): (1, h, w)
+ mask (torch.tensor): (h, w)
+ scale (torch.tensor): [w/w_new, h/h_new]
+ """
+ # read and resize image
+ image = imread_gray(path, augment_fn)
+ image = cv2.resize(image, resize)
+
+ # (h, w) -> (1, h, w) and normalized
+ image = torch.from_numpy(image).float()[None] / 255
+ return image
+
+
+def read_scannet_depth(path):
+ if str(path).startswith('s3://'):
+ depth = load_array_from_s3(str(path), SCANNET_CLIENT, cv2.IMREAD_UNCHANGED)
+ else:
+ depth = cv2.imread(str(path), cv2.IMREAD_UNCHANGED)
+ depth = depth / 1000
+ depth = torch.from_numpy(depth).float() # (h, w)
+ return depth
+
+
+def read_scannet_pose(path):
+ """ Read ScanNet's Camera2World pose and transform it to World2Camera.
+
+ Returns:
+ pose_w2c (np.ndarray): (4, 4)
+ """
+ cam2world = np.loadtxt(path, delimiter=' ')
+ world2cam = inv(cam2world)
+ return world2cam
+
+
+def read_scannet_intrinsic(path):
+ """ Read ScanNet's intrinsic matrix and return the 3x3 matrix.
+ """
+ intrinsic = np.loadtxt(path, delimiter=' ')
+ return intrinsic[:-1, :-1]
+
+def dict_to_cuda(data_dict):
+ data_dict_cuda = {}
+ for k, v in data_dict.items():
+ if isinstance(v, torch.Tensor):
+ data_dict_cuda[k] = v.cuda()
+ elif isinstance(v, dict):
+ data_dict_cuda[k] = dict_to_cuda(v)
+ elif isinstance(v, list):
+ data_dict_cuda[k] = list_to_cuda(v)
+ else:
+ data_dict_cuda[k] = v
+ return data_dict_cuda
+
+def list_to_cuda(data_list):
+ data_list_cuda = []
+ for obj in data_list:
+ if isinstance(obj, torch.Tensor):
+ data_list_cuda.append(obj.cuda())
+ elif isinstance(obj, dict):
+ data_list_cuda.append(dict_to_cuda(obj))
+ elif isinstance(obj, list):
+ data_list_cuda.append(list_to_cuda(obj))
+ else:
+ data_list_cuda.append(obj)
+ return data_list_cuda
\ No newline at end of file
diff --git a/imcui/third_party/MatchAnything/src/utils/easydict.py b/imcui/third_party/MatchAnything/src/utils/easydict.py
new file mode 100755
index 0000000000000000000000000000000000000000..e4af7a343311581cd56b486fa4d1cd0f60d1ad86
--- /dev/null
+++ b/imcui/third_party/MatchAnything/src/utils/easydict.py
@@ -0,0 +1,148 @@
+class EasyDict(dict):
+ """
+ Get attributes
+
+ >>> d = EasyDict({'foo':3})
+ >>> d['foo']
+ 3
+ >>> d.foo
+ 3
+ >>> d.bar
+ Traceback (most recent call last):
+ ...
+ AttributeError: 'EasyDict' object has no attribute 'bar'
+
+ Works recursively
+
+ >>> d = EasyDict({'foo':3, 'bar':{'x':1, 'y':2}})
+ >>> isinstance(d.bar, dict)
+ True
+ >>> d.bar.x
+ 1
+
+ Bullet-proof
+
+ >>> EasyDict({})
+ {}
+ >>> EasyDict(d={})
+ {}
+ >>> EasyDict(None)
+ {}
+ >>> d = {'a': 1}
+ >>> EasyDict(**d)
+ {'a': 1}
+
+ Set attributes
+
+ >>> d = EasyDict()
+ >>> d.foo = 3
+ >>> d.foo
+ 3
+ >>> d.bar = {'prop': 'value'}
+ >>> d.bar.prop
+ 'value'
+ >>> d
+ {'foo': 3, 'bar': {'prop': 'value'}}
+ >>> d.bar.prop = 'newer'
+ >>> d.bar.prop
+ 'newer'
+
+
+ Values extraction
+
+ >>> d = EasyDict({'foo':0, 'bar':[{'x':1, 'y':2}, {'x':3, 'y':4}]})
+ >>> isinstance(d.bar, list)
+ True
+ >>> from operator import attrgetter
+ >>> map(attrgetter('x'), d.bar)
+ [1, 3]
+ >>> map(attrgetter('y'), d.bar)
+ [2, 4]
+ >>> d = EasyDict()
+ >>> d.keys()
+ []
+ >>> d = EasyDict(foo=3, bar=dict(x=1, y=2))
+ >>> d.foo
+ 3
+ >>> d.bar.x
+ 1
+
+ Still like a dict though
+
+ >>> o = EasyDict({'clean':True})
+ >>> o.items()
+ [('clean', True)]
+
+ And like a class
+
+ >>> class Flower(EasyDict):
+ ... power = 1
+ ...
+ >>> f = Flower()
+ >>> f.power
+ 1
+ >>> f = Flower({'height': 12})
+ >>> f.height
+ 12
+ >>> f['power']
+ 1
+ >>> sorted(f.keys())
+ ['height', 'power']
+
+ update and pop items
+ >>> d = EasyDict(a=1, b='2')
+ >>> e = EasyDict(c=3.0, a=9.0)
+ >>> d.update(e)
+ >>> d.c
+ 3.0
+ >>> d['c']
+ 3.0
+ >>> d.get('c')
+ 3.0
+ >>> d.update(a=4, b=4)
+ >>> d.b
+ 4
+ >>> d.pop('a')
+ 4
+ >>> d.a
+ Traceback (most recent call last):
+ ...
+ AttributeError: 'EasyDict' object has no attribute 'a'
+ """
+
+ def __init__(self, d=None, **kwargs):
+ if d is None:
+ d = {}
+ if kwargs:
+ d.update(**kwargs)
+ for k, v in d.items():
+ setattr(self, k, v)
+ # Class attributes
+ for k in self.__class__.__dict__.keys():
+ if not (k.startswith("__") and k.endswith("__")) and not k in ("update", "pop"):
+ setattr(self, k, getattr(self, k))
+
+ def __setattr__(self, name, value):
+ if isinstance(value, (list, tuple)):
+ value = [self.__class__(x) if isinstance(x, dict) else x for x in value]
+ elif isinstance(value, dict) and not isinstance(value, self.__class__):
+ value = self.__class__(value)
+ super(EasyDict, self).__setattr__(name, value)
+ super(EasyDict, self).__setitem__(name, value)
+
+ __setitem__ = __setattr__
+
+ def update(self, e=None, **f):
+ d = e or dict()
+ d.update(f)
+ for k in d:
+ setattr(self, k, d[k])
+
+ def pop(self, k, d=None):
+ if hasattr(self, k):
+ delattr(self, k)
+ return super(EasyDict, self).pop(k, d)
+
+
+if __name__ == "__main__":
+    import doctest
+    doctest.testmod()
\ No newline at end of file
diff --git a/imcui/third_party/MatchAnything/src/utils/geometry.py b/imcui/third_party/MatchAnything/src/utils/geometry.py
new file mode 100644
index 0000000000000000000000000000000000000000..d6470ca309655e4e81f58bfe515a428b6e8b3623
--- /dev/null
+++ b/imcui/third_party/MatchAnything/src/utils/geometry.py
@@ -0,0 +1,366 @@
+from __future__ import division
+import torch
+import torch.nn.functional as F
+import numpy as np
+# from numba import jit
+
+pixel_coords = None
+
+def set_id_grid(depth):
+ b, h, w = depth.size()
+ i_range = torch.arange(0, h).view(1, h, 1).expand(1,h,w).type_as(depth) # [1, H, W]
+ j_range = torch.arange(0, w).view(1, 1, w).expand(1,h,w).type_as(depth) # [1, H, W]
+ ones = torch.ones(1,h,w).type_as(depth)
+
+ pixel_coords = torch.stack((j_range, i_range, ones), dim=1) # [1, 3, H, W]
+ return pixel_coords
+
+def check_sizes(input, input_name, expected):
+ condition = [input.ndimension() == len(expected)]
+ for i,size in enumerate(expected):
+ if size.isdigit():
+ condition.append(input.size(i) == int(size))
+ assert(all(condition)), "wrong size for {}, expected {}, got {}".format(input_name, 'x'.join(expected), list(input.size()))
+
+
+def pixel2cam(depth, intrinsics_inv):
+ """Transform coordinates in the pixel frame to the camera frame.
+ Args:
+ depth: depth maps -- [B, H, W]
+ intrinsics_inv: intrinsics_inv matrix for each element of batch -- [B, 3, 3]
+ Returns:
+ array of (u,v,1) cam coordinates -- [B, 3, H, W]
+ """
+ b, h, w = depth.size()
+ pixel_coords = set_id_grid(depth)
+ current_pixel_coords = pixel_coords[:,:,:h,:w].expand(b,3,h,w).reshape(b, 3, -1) # [B, 3, H*W]
+ cam_coords = (intrinsics_inv.float() @ current_pixel_coords.float()).reshape(b, 3, h, w)
+ return cam_coords * depth.unsqueeze(1)
+
+def cam2pixel_depth(cam_coords, proj_c2p_rot, proj_c2p_tr):
+ """Transform coordinates in the camera frame to the pixel frame and get depth map.
+ Args:
+ cam_coords: pixel coordinates defined in the first camera coordinates system -- [B, 3, H, W]
+ proj_c2p_rot: rotation matrix of cameras -- [B, 3, 4]
+ proj_c2p_tr: translation vectors of cameras -- [B, 3, 1]
+ Returns:
+ tensor of [-1,1] coordinates -- [B, 2, H, W]
+ depth map -- [B, H, W]
+ """
+ b, _, h, w = cam_coords.size()
+ cam_coords_flat = cam_coords.reshape(b, 3, -1) # [B, 3, H*W]
+ if proj_c2p_rot is not None:
+ pcoords = proj_c2p_rot @ cam_coords_flat
+ else:
+ pcoords = cam_coords_flat
+
+ if proj_c2p_tr is not None:
+ pcoords = pcoords + proj_c2p_tr # [B, 3, H*W]
+ X = pcoords[:, 0]
+ Y = pcoords[:, 1]
+ Z = pcoords[:, 2].clamp(min=1e-3) # [B, H*W] min_depth = 1 mm
+
+ X_norm = 2*(X / Z)/(w-1) - 1 # Normalized, -1 if on extreme left, 1 if on extreme right (x = w-1) [B, H*W]
+ Y_norm = 2*(Y / Z)/(h-1) - 1 # Idem [B, H*W]
+
+ pixel_coords = torch.stack([X_norm, Y_norm], dim=2) # [B, H*W, 2]
+ return pixel_coords.reshape(b,h,w,2), Z.reshape(b, h, w)
+
+
+def cam2pixel(cam_coords, proj_c2p_rot, proj_c2p_tr):
+ """Transform coordinates in the camera frame to the pixel frame.
+ Args:
+ cam_coords: pixel coordinates defined in the first camera coordinates system -- [B, 3, H, W]
+ proj_c2p_rot: rotation matrix of cameras -- [B, 3, 4]
+ proj_c2p_tr: translation vectors of cameras -- [B, 3, 1]
+ Returns:
+ array of [-1,1] coordinates -- [B, 2, H, W]
+ """
+ b, _, h, w = cam_coords.size()
+ cam_coords_flat = cam_coords.reshape(b, 3, -1) # [B, 3, H*W]
+ if proj_c2p_rot is not None:
+ pcoords = proj_c2p_rot @ cam_coords_flat
+ else:
+ pcoords = cam_coords_flat
+
+ if proj_c2p_tr is not None:
+ pcoords = pcoords + proj_c2p_tr # [B, 3, H*W]
+ X = pcoords[:, 0]
+ Y = pcoords[:, 1]
+ Z = pcoords[:, 2].clamp(min=1e-3) # [B, H*W] min_depth = 1 mm
+
+ X_norm = 2*(X / Z)/(w-1) - 1 # Normalized, -1 if on extreme left, 1 if on extreme right (x = w-1) [B, H*W]
+ Y_norm = 2*(Y / Z)/(h-1) - 1 # Idem [B, H*W]
+
+ pixel_coords = torch.stack([X_norm, Y_norm], dim=2) # [B, H*W, 2]
+ return pixel_coords.reshape(b,h,w,2)
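+# Usage sketch (made-up intrinsics, identity pose): unprojecting a constant-depth map
+# and re-projecting it should reproduce the normalized pixel grid in [-1, 1], e.g.
+#   depth = torch.ones(1, 480, 640)
+#   K = torch.tensor([[[500., 0., 320.], [0., 500., 240.], [0., 0., 1.]]])
+#   cam = pixel2cam(depth, K.inverse())   # [1, 3, H, W]
+#   grid = cam2pixel(cam, K, None)        # [1, H, W, 2], values in [-1, 1]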
+
+
+def reproject_kpts(dim0_idxs, kpts, depth, rel_pose, K0, K1):
+ """ Reproject keypoints with depth, relative pose and camera intrinsics
+ Args:
+        dim0_idxs (torch.LongTensor): (B*max_kpts, )
+        kpts (torch.LongTensor): (B, max_kpts, 2)
+        depth (torch.Tensor): (B, H, W)
+        rel_pose (torch.Tensor): (B, 3, 4) relative transformation from target to source (T_0to1)
+        K0 (torch.Tensor): (N, 3, 3) - (K_0)
+        K1 (torch.Tensor): (N, 3, 3) - (K_1)
+ Returns:
+ (torch.Tensor): (B, max_kpts, 2) the reprojected kpts
+ """
+ # pixel to camera
+ device = kpts.device
+ B, max_kpts, _ = kpts.shape
+
+ kpts = kpts.reshape(-1, 2) # (B*K, 2)
+ kpts_depth = depth[dim0_idxs, kpts[:, 1], kpts[:, 0]] # (B*K, )
+ kpts = torch.cat([kpts.float(),
+ torch.ones((kpts.shape[0], 1), dtype=torch.float32, device=device)], -1) # (B*K, 3)
+ pixel_coords = (kpts * kpts_depth[:, None]).reshape(B, max_kpts, 3).permute(0, 2, 1) # (B, 3, K)
+
+ cam_coords = K0.inverse() @ pixel_coords # (N, 3, max_kpts)
+ # camera1 to camera 2
+ rel_pose_R = rel_pose[:, :, :-1] # (B, 3, 3)
+ rel_pose_t = rel_pose[:, :, -1][..., None] # (B, 3, 1)
+ cam2_coords = rel_pose_R @ cam_coords + rel_pose_t # (B, 3, max_kpts)
+ # projection
+ pixel2_coords = K1 @ cam2_coords # (B, 3, max_kpts)
+ reproj_kpts = pixel2_coords[:, :-1, :] / pixel2_coords[:, -1, :][:, None].expand(-1, 2, -1)
+ return reproj_kpts.permute(0, 2, 1)
+
+
+def check_depth_consistency(b_idxs, kpts0, depth0, kpts1, depth1, T_0to1, K0, K1,
+ atol=0.1, rtol=0.0):
+ """
+ Args:
+ b_idxs (torch.LongTensor): (n_kpts, ) the batch indices which each keypoints pairs belong to
+ kpts0 (torch.LongTensor): (n_kpts, 2) -
+ depth0 (torch.Tensor): (B, H, W)
+ kpts1 (torch.LongTensor): (n_kpts, 2)
+ depth1 (torch.Tensor): (B, H, W)
+ T_0to1 (torch.Tensor): (B, 3, 4)
+        K0 (torch.Tensor): (N, 3, 3) - (K_0)
+        K1 (torch.Tensor): (N, 3, 3) - (K_1)
+ atol (float): the absolute tolerance for depth consistency check
+ rtol (float): the relative tolerance for depth consistency check
+ Returns:
+ valid_mask (torch.Tensor): (n_kpts, )
+ Notes:
+ The two corresponding keypoints are depth consistent if the following equation is held:
+ abs(kpt_0to1_depth - kpt1_depth) <= (atol + rtol * abs(kpt1_depth))
+        * The initial reimplementation used `atol=0.1, rtol=0`; results are better with
+          `atol=1.0, rtol=0` (which nearly ignores the depth consistency check).
+ * However, the author suggests using `atol=0.0, rtol=0.1` as in https://github.com/magicleap/SuperGluePretrainedNetwork/issues/31#issuecomment-681866054
+ """
+ device = kpts0.device
+ n_kpts = kpts0.shape[0]
+
+ kpts0_depth = depth0[b_idxs, kpts0[:, 1], kpts0[:, 0]] # (n_kpts, )
+ kpts1_depth = depth1[b_idxs, kpts1[:, 1], kpts1[:, 0]] # (n_kpts, )
+ kpts0 = torch.cat([kpts0.float(),
+ torch.ones((n_kpts, 1), dtype=torch.float32, device=device)], -1) # (n_kpts, 3)
+ pixel_coords = (kpts0 * kpts0_depth[:, None])[..., None] # (n_kpts, 3, 1)
+
+ # indexing from T_0to1 and K - treat all kpts as a batch
+ K0 = K0[b_idxs, :, :] # (n_kpts, 3, 3)
+ T_0to1 = T_0to1[b_idxs, :, :] # (n_kpts, 3, 4)
+ cam_coords = K0.inverse() @ pixel_coords # (n_kpts, 3, 1)
+
+ # camera1 to camera2
+ R_0to1 = T_0to1[:, :, :-1] # (n_kpts, 3, 3)
+ t_0to1 = T_0to1[:, :, -1][..., None] # (n_kpts, 3, 1)
+ cam1_coords = R_0to1 @ cam_coords + t_0to1 # (n_kpts, 3, 1)
+ K1 = K1[b_idxs, :, :] # (n_kpts, 3, 3)
+ pixel1_coords = K1 @ cam1_coords # (n_kpts, 3, 1)
+ kpts_0to1_depth = pixel1_coords[:, -1, 0] # (n_kpts, )
+ return (kpts_0to1_depth - kpts1_depth).abs() <= atol + rtol * kpts1_depth.abs()
+
+
+def inverse_warp(img, depth, pose, intrinsics, mode='bilinear', padding_mode='zeros'):
+ """
+ Inverse warp a source image to the target image plane.
+
+ Args:
+ img: the source image (where to sample pixels) -- [B, 3, H, W]
+ depth: depth map of the target image -- [B, H, W]
+        pose: relative transformation from target to source -- [B, 3, 4]
+ intrinsics: camera intrinsic matrix -- [B, 3, 3]
+ Returns:
+ projected_img: Source image warped to the target image plane
+ valid_points: Boolean array indicating point validity
+ """
+ # check_sizes(img, 'img', 'B3HW')
+ check_sizes(depth, 'depth', 'BHW')
+# check_sizes(pose, 'pose', 'B6')
+ check_sizes(intrinsics, 'intrinsics', 'B33')
+
+ batch_size, _, img_height, img_width = img.size()
+
+ cam_coords = pixel2cam(depth, intrinsics.inverse()) # [B,3,H,W]
+
+ pose_mat = pose # (B, 3, 4)
+
+ # Get projection matrix for target camera frame to source pixel frame
+ proj_cam_to_src_pixel = intrinsics @ pose_mat # [B, 3, 4]
+
+ rot, tr = proj_cam_to_src_pixel[:,:,:3], proj_cam_to_src_pixel[:,:,-1:]
+ src_pixel_coords = cam2pixel(cam_coords, rot, tr) # [B,H,W,2]
+ projected_img = F.grid_sample(img, src_pixel_coords, mode=mode,
+ padding_mode=padding_mode, align_corners=True)
+
+ valid_points = src_pixel_coords.abs().max(dim=-1)[0] <= 1
+
+ return projected_img, valid_points
+
+def depth_inverse_warp(depth_source, depth, pose, intrinsic_source, intrinsic, mode='nearest', padding_mode='zeros'):
+ """
+ 1. Inversely warp a source depth map to the target image plane (warped depth map still in source frame)
+ 2. Transform the target depth map to the source image frame
+ Args:
+        depth_source: the source depth map (where to sample) -- [B, H, W]
+        depth: depth map of the target image -- [B, H, W]
+        pose: relative transformation from target to source -- [B, 3, 4]
+        intrinsic_source, intrinsic: camera intrinsic matrices -- [B, 3, 3]
+ Returns:
+ warped_depth: Source depth warped to the target image plane -- [B, H, W]
+ projected_depth: Target depth projected to the source image frame -- [B, H, W]
+ valid_points: Boolean array indicating point validity -- [B, H, W]
+ """
+ check_sizes(depth_source, 'depth', 'BHW')
+ check_sizes(depth, 'depth', 'BHW')
+ check_sizes(intrinsic_source, 'intrinsics', 'B33')
+
+ b, h, w = depth.size()
+
+ cam_coords = pixel2cam(depth, intrinsic.inverse()) # [B,3,H,W]
+
+ pose_mat = pose # (B, 3, 4)
+
+ # Get projection matrix from target camera frame to source pixel frame
+ proj_cam_to_src_pixel = intrinsic_source @ pose_mat # [B, 3, 4]
+
+ rot, tr = proj_cam_to_src_pixel[:,:,:3], proj_cam_to_src_pixel[:,:,-1:]
+ src_pixel_coords, depth_target2src = cam2pixel_depth(cam_coords, rot, tr) # [B,H,W,2]
+ warped_depth = F.grid_sample(depth_source[:, None], src_pixel_coords, mode=mode,
+ padding_mode=padding_mode, align_corners=True) # [B, 1, H, W]
+
+ valid_points = (src_pixel_coords.abs().max(dim=-1)[0] <= 1) &\
+ (depth > 0.0) & (warped_depth[:, 0] > 0.0) # [B, H, W]
+ return warped_depth[:, 0], depth_target2src, valid_points
+
+def to_skew(t):
+ """ Transform the translation vector t to skew-symmetric matrix.
+ Args:
+ t (torch.Tensor): (B, 3)
+ """
+    t_skew = t.new_zeros((t.shape[0], 3, 3))  # skew-symmetric matrix has zeros on its diagonal
+ t_skew[:, 0, 1] = -t[:, 2]
+ t_skew[:, 1, 0] = t[:, 2]
+ t_skew[:, 0, 2] = t[:, 1]
+ t_skew[:, 2, 0] = -t[:, 1]
+ t_skew[:, 1, 2] = -t[:, 0]
+ t_skew[:, 2, 1] = t[:, 0]
+ return t_skew # (B, 3, 3)
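+# Usage sketch: the skew matrix is typically combined with a relative rotation to
+# build an essential matrix for the epipolar distances below, e.g. given a relative
+# pose T_0to1 of shape (B, 3, 4):
+#   E = to_skew(T_0to1[:, :, 3]) @ T_0to1[:, :, :3]   # E = [t]_x R, shape (B, 3, 3)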
+
+
+def to_homogeneous(pts):
+ """
+ Args:
+ pts (torch.Tensor): (B, K, 2)
+ """
+ return torch.cat([pts, torch.ones_like(pts[..., :1])], -1) # (B, K, 3)
+
+
+def pix2img(pts, K):
+ """
+ Args:
+ pts (torch.Tensor): (B, K, 2)
+ K (torch.Tensor): (B, 3, 3)
+ """
+ return (pts - K[:, [0, 1], [2, 2]][:, None]) / K[:, [0, 1], [0, 1]][:, None]
+
+
+def weighted_blind_sed(kpts0, kpts1, weights, E, K0, K1):
+ """ Calculate the squared weighted blind symmetric epipolar distance, which is the sed between
+ all possible keypoints pairs.
+ Args:
+ kpts0 (torch.Tensor): (B, K0, 2)
+        kpts1 (torch.Tensor): (B, K1, 2)
+ weights (torch.Tensor): (B, K0, K1)
+ E (torch.Tensor): (B, 3, 3) - the essential matrix
+ K0 (torch.Tensor): (B, 3, 3)
+ K1 (torch.Tensor): (B, 3, 3)
+ Returns:
+ w_sed (torch.Tensor): (B, K0, K1)
+ """
+ M, N = kpts0.shape[1], kpts1.shape[1]
+
+ kpts0 = to_homogeneous(pix2img(kpts0, K0))
+ kpts1 = to_homogeneous(pix2img(kpts1, K1)) # (B, K1, 3)
+
+ R = kpts0 @ E.transpose(1, 2) @ kpts1.transpose(1, 2) # (B, K0, K1)
+ # w_R = weights * R # (B, K0, K1)
+
+ Ep0 = kpts0 @ E.transpose(1, 2) # (B, K0, 3)
+ Etp1 = kpts1 @ E # (B, K1, 3)
+ d = R**2 * (1.0 / (Ep0[..., 0]**2 + Ep0[..., 1]**2)[..., None].expand(-1, -1, N)
+ + 1.0 / (Etp1[..., 0]**2 + Etp1[..., 1]**2)[:, None].expand(-1, M, -1)) * weights # (B, K0, K1)
+ return d
+
+def weighted_blind_sampson(kpts0, kpts1, weights, E, K0, K1):
+ """ Calculate the squared weighted blind sampson distance, which is the sampson distance between
+ all possible keypoints pairs weighted by the given weights.
+ """
+ M, N = kpts0.shape[1], kpts1.shape[1]
+
+ kpts0 = to_homogeneous(pix2img(kpts0, K0))
+ kpts1 = to_homogeneous(pix2img(kpts1, K1)) # (B, K1, 3)
+
+ R = kpts0 @ E.transpose(1, 2) @ kpts1.transpose(1, 2) # (B, K0, K1)
+ # w_R = weights * R # (B, K0, K1)
+
+ Ep0 = kpts0 @ E.transpose(1, 2) # (B, K0, 3)
+ Etp1 = kpts1 @ E # (B, K1, 3)
+ d = R**2 * (1.0 / ((Ep0[..., 0]**2 + Ep0[..., 1]**2)[..., None].expand(-1, -1, N)
+ + (Etp1[..., 0]**2 + Etp1[..., 1]**2)[:, None].expand(-1, M, -1))) * weights # (B, K0, K1)
+ return d
+
+
+def angular_rel_rot(T_0to1):
+ """
+ Args:
+        T_0to1 (np.ndarray): (4, 4)
+ """
+ cos = (np.trace(T_0to1[:-1, :-1]) - 1) / 2
+ if cos < -1:
+ cos = -1.0
+ if cos > 1:
+ cos = 1.0
+ angle_error_rot = np.rad2deg(np.abs(np.arccos(cos)))
+
+ return angle_error_rot
+
+def angular_rel_pose(T0, T1):
+ """
+ Args:
+ T0 (np.ndarray): (4, 4)
+ T1 (np.ndarray): (4, 4)
+
+ """
+ cos = (np.trace(T0[:-1, :-1].T @ T1[:-1, :-1]) - 1) / 2
+ if cos < -1:
+ cos = -1.0
+ if cos > 1:
+ cos = 1.0
+ angle_error_rot = np.rad2deg(np.abs(np.arccos(cos)))
+
+ # calculate angular translation error
+ n = np.linalg.norm(T0[:-1, -1]) * np.linalg.norm(T1[:-1, -1])
+ cos = np.dot(T0[:-1, -1], T1[:-1, -1]) / n
+ if cos < -1:
+ cos = -1.0
+ if cos > 1:
+ cos = 1.0
+ angle_error_trans = np.rad2deg(np.arccos(cos))
+
+ return angle_error_rot, angle_error_trans
\ No newline at end of file
diff --git a/imcui/third_party/MatchAnything/src/utils/homography_utils.py b/imcui/third_party/MatchAnything/src/utils/homography_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..5d6ac5adb97c9963c216a97761cca9dfccacd91f
--- /dev/null
+++ b/imcui/third_party/MatchAnything/src/utils/homography_utils.py
@@ -0,0 +1,366 @@
+import math
+from typing import Tuple
+
+import numpy as np
+import torch
+
+def to_homogeneous(points):
+ """Convert N-dimensional points to homogeneous coordinates.
+ Args:
+ points: torch.Tensor or numpy.ndarray with size (..., N).
+ Returns:
+ A torch.Tensor or numpy.ndarray with size (..., N+1).
+ """
+ if isinstance(points, torch.Tensor):
+ pad = points.new_ones(points.shape[:-1] + (1,))
+ return torch.cat([points, pad], dim=-1)
+ elif isinstance(points, np.ndarray):
+ pad = np.ones((points.shape[:-1] + (1,)), dtype=points.dtype)
+ return np.concatenate([points, pad], axis=-1)
+ else:
+ raise ValueError
+
+
+def from_homogeneous(points, eps=0.0):
+ """Remove the homogeneous dimension of N-dimensional points.
+ Args:
+ points: torch.Tensor or numpy.ndarray with size (..., N+1).
+ eps: Epsilon value to prevent zero division.
+ Returns:
+ A torch.Tensor or numpy ndarray with size (..., N).
+ """
+ return points[..., :-1] / (points[..., -1:] + eps)
+
+
+def flat2mat(H):
+ return np.reshape(np.concatenate([H, np.ones_like(H[:, :1])], axis=1), [3, 3])
+
+
+# Homography creation
+
+
+def create_center_patch(shape, patch_shape=None):
+ if patch_shape is None:
+ patch_shape = shape
+ width, height = shape
+ pwidth, pheight = patch_shape
+ left = int((width - pwidth) / 2)
+ bottom = int((height - pheight) / 2)
+ right = int((width + pwidth) / 2)
+ top = int((height + pheight) / 2)
+ return np.array([[left, bottom], [left, top], [right, top], [right, bottom]])
+
+
+def check_convex(patch, min_convexity=0.05):
+ """Checks if given polygon vertices [N,2] form a convex shape"""
+ for i in range(patch.shape[0]):
+ x1, y1 = patch[(i - 1) % patch.shape[0]]
+ x2, y2 = patch[i]
+ x3, y3 = patch[(i + 1) % patch.shape[0]]
+ if (x2 - x1) * (y3 - y2) - (x3 - x2) * (y2 - y1) > -min_convexity:
+ return False
+ return True
+
+
+def sample_homography_corners(
+ shape,
+ patch_shape,
+ difficulty=1.0,
+ translation=0.4,
+ n_angles=10,
+ max_angle=90,
+ min_convexity=0.05,
+ rng=np.random,
+):
+ max_angle = max_angle / 180.0 * math.pi
+ width, height = shape
+ pwidth, pheight = width * (1 - difficulty), height * (1 - difficulty)
+ min_pts1 = create_center_patch(shape, (pwidth, pheight))
+ full = create_center_patch(shape)
+ pts2 = create_center_patch(patch_shape)
+ scale = min_pts1 - full
+ found_valid = False
+ cnt = -1
+ while not found_valid:
+ offsets = rng.uniform(0.0, 1.0, size=(4, 2)) * scale
+ pts1 = full + offsets
+ found_valid = check_convex(pts1 / np.array(shape), min_convexity)
+ cnt += 1
+
+ # re-center
+ pts1 = pts1 - np.mean(pts1, axis=0, keepdims=True)
+ pts1 = pts1 + np.mean(min_pts1, axis=0, keepdims=True)
+
+ # Rotation
+ if n_angles > 0 and difficulty > 0:
+ angles = np.linspace(-max_angle * difficulty, max_angle * difficulty, n_angles)
+ rng.shuffle(angles)
+ rng.shuffle(angles)
+ angles = np.concatenate([[0.0], angles], axis=0)
+
+ center = np.mean(pts1, axis=0, keepdims=True)
+ rot_mat = np.reshape(
+ np.stack(
+ [np.cos(angles), -np.sin(angles), np.sin(angles), np.cos(angles)],
+ axis=1,
+ ),
+ [-1, 2, 2],
+ )
+ rotated = (
+ np.matmul(
+ np.tile(np.expand_dims(pts1 - center, axis=0), [n_angles + 1, 1, 1]),
+ rot_mat,
+ )
+ + center
+ )
+
+ for idx in range(1, n_angles):
+ warped_points = rotated[idx] / np.array(shape)
+ if np.all((warped_points >= 0.0) & (warped_points < 1.0)):
+ pts1 = rotated[idx]
+ break
+
+ # Translation
+ if translation > 0:
+ min_trans = -np.min(pts1, axis=0)
+ max_trans = shape - np.max(pts1, axis=0)
+ trans = rng.uniform(min_trans, max_trans)[None]
+ pts1 += trans * translation * difficulty
+
+ H = compute_homography(pts1, pts2, [1.0, 1.0])
+ warped = warp_points(full, H, inverse=False)
+ return H, full, warped, patch_shape
+
+
+def compute_homography(pts1_, pts2_, shape):
+ """Compute the homography matrix from 4 point correspondences"""
+ # Rescale to actual size
+ shape = np.array(shape[::-1], dtype=np.float32) # different convention [y, x]
+ pts1 = pts1_ * np.expand_dims(shape, axis=0)
+ pts2 = pts2_ * np.expand_dims(shape, axis=0)
+
+ def ax(p, q):
+ return [p[0], p[1], 1, 0, 0, 0, -p[0] * q[0], -p[1] * q[0]]
+
+ def ay(p, q):
+ return [0, 0, 0, p[0], p[1], 1, -p[0] * q[1], -p[1] * q[1]]
+
+ a_mat = np.stack([f(pts1[i], pts2[i]) for i in range(4) for f in (ax, ay)], axis=0)
+ p_mat = np.transpose(
+ np.stack([[pts2[i][j] for i in range(4) for j in range(2)]], axis=0)
+ )
+ homography = np.transpose(np.linalg.solve(a_mat, p_mat))
+ return flat2mat(homography)
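+# Illustrative check with toy points (a pure scale-plus-translation case): with four
+# exact correspondences the solved homography should map pts1 onto pts2, e.g.
+#   pts1 = np.array([[0., 0.], [0., 1.], [1., 1.], [1., 0.]])
+#   pts2 = pts1 * 2.0 + 1.0
+#   H = compute_homography(pts1, pts2, [1.0, 1.0])
+#   np.allclose(warp_points(pts1, H, inverse=False), pts2)   # expected True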
+
+
+# Point warping utils
+
+
+def warp_points(points, homography, inverse=True):
+ """
+ Warp a list of points with the INVERSE of the given homography.
+ The inverse is used to be coherent with tf.contrib.image.transform
+ Arguments:
+ points: list of N points, shape (N, 2).
+ homography: batched or not (shapes (B, 3, 3) and (3, 3) respectively).
+ Returns: a Tensor of shape (N, 2) or (B, N, 2) (depending on whether the homography
+ is batched) containing the new coordinates of the warped points.
+ """
+ H = homography[None] if len(homography.shape) == 2 else homography
+
+ # Get the points to the homogeneous format
+ num_points = points.shape[0]
+ points = np.concatenate([points, np.ones([num_points, 1], dtype=np.float32)], -1)
+
+ H_inv = np.transpose(np.linalg.inv(H) if inverse else H)
+ warped_points = np.tensordot(points, H_inv, axes=[[1], [0]])
+
+ warped_points = np.transpose(warped_points, [2, 0, 1])
+ warped_points[np.abs(warped_points[:, :, 2]) < 1e-8, 2] = 1e-8
+ warped_points = warped_points[:, :, :2] / warped_points[:, :, 2:]
+
+ return warped_points[0] if len(homography.shape) == 2 else warped_points
+
+
+def warp_points_torch(points, H, inverse=True):
+ """
+ Warp a list of points with the INVERSE of the given homography.
+ The inverse is used to be coherent with tf.contrib.image.transform
+ Arguments:
+ points: batched list of N points, shape (B, N, 2).
+ H: batched or not (shapes (B, 3, 3) and (3, 3) respectively).
+ inverse: Whether to multiply the points by H or the inverse of H
+    Returns: a Tensor of shape (B, N, 2) containing the new coordinates of the warped points.
+ """
+
+ # Get the points to the homogeneous format
+ points = to_homogeneous(points)
+
+ # Apply the homography
+ H_mat = (torch.inverse(H) if inverse else H).transpose(-2, -1)
+ warped_points = torch.einsum("...nj,...ji->...ni", points, H_mat)
+
+ warped_points = from_homogeneous(warped_points, eps=1e-5)
+ return warped_points
+
+
+# Line warping utils
+
+
+def seg_equation(segs):
+ # calculate list of start, end and midpoints points from both lists
+ start_points, end_points = to_homogeneous(segs[..., 0, :]), to_homogeneous(
+ segs[..., 1, :]
+ )
+ # Compute the line equations as ax + by + c = 0 , where x^2 + y^2 = 1
+ lines = torch.cross(start_points, end_points, dim=-1)
+ lines_norm = torch.sqrt(lines[..., 0] ** 2 + lines[..., 1] ** 2)[..., None]
+ assert torch.all(
+ lines_norm > 0
+ ), "Error: trying to compute the equation of a line with a single point"
+ lines = lines / lines_norm
+ return lines
+
+
+def is_inside_img(pts: torch.Tensor, img_shape: Tuple[int, int]):
+ h, w = img_shape
+ return (
+ (pts >= 0).all(dim=-1)
+ & (pts[..., 0] < w)
+ & (pts[..., 1] < h)
+ & (~torch.isinf(pts).any(dim=-1))
+ )
+
+
+def shrink_segs_to_img(segs: torch.Tensor, img_shape: Tuple[int, int]) -> torch.Tensor:
+ """
+ Shrink an array of segments to fit inside the image.
+ :param segs: The tensor of segments with shape (N, 2, 2)
+ :param img_shape: The image shape in format (H, W)
+ """
+ EPS = 1e-4
+ device = segs.device
+ w, h = img_shape[1], img_shape[0]
+ # Project the segments to the reference image
+ segs = segs.clone()
+ eqs = seg_equation(segs)
+ x0, y0 = torch.tensor([1.0, 0, 0.0], device=device), torch.tensor(
+ [0.0, 1, 0], device=device
+ )
+ x0 = x0.repeat(eqs.shape[:-1] + (1,))
+ y0 = y0.repeat(eqs.shape[:-1] + (1,))
+ pt_x0s = torch.cross(eqs, x0, dim=-1)
+ pt_x0s = pt_x0s[..., :-1] / pt_x0s[..., None, -1]
+ pt_x0s_valid = is_inside_img(pt_x0s, img_shape)
+ pt_y0s = torch.cross(eqs, y0, dim=-1)
+ pt_y0s = pt_y0s[..., :-1] / pt_y0s[..., None, -1]
+ pt_y0s_valid = is_inside_img(pt_y0s, img_shape)
+
+ xW = torch.tensor([1.0, 0, EPS - w], device=device)
+ yH = torch.tensor([0.0, 1, EPS - h], device=device)
+ xW = xW.repeat(eqs.shape[:-1] + (1,))
+ yH = yH.repeat(eqs.shape[:-1] + (1,))
+ pt_xWs = torch.cross(eqs, xW, dim=-1)
+ pt_xWs = pt_xWs[..., :-1] / pt_xWs[..., None, -1]
+ pt_xWs_valid = is_inside_img(pt_xWs, img_shape)
+ pt_yHs = torch.cross(eqs, yH, dim=-1)
+ pt_yHs = pt_yHs[..., :-1] / pt_yHs[..., None, -1]
+ pt_yHs_valid = is_inside_img(pt_yHs, img_shape)
+
+ # If the X coordinate of the first endpoint is out
+ mask = (segs[..., 0, 0] < 0) & pt_x0s_valid
+ segs[mask, 0, :] = pt_x0s[mask]
+ mask = (segs[..., 0, 0] > (w - 1)) & pt_xWs_valid
+ segs[mask, 0, :] = pt_xWs[mask]
+ # If the X coordinate of the second endpoint is out
+ mask = (segs[..., 1, 0] < 0) & pt_x0s_valid
+ segs[mask, 1, :] = pt_x0s[mask]
+ mask = (segs[:, 1, 0] > (w - 1)) & pt_xWs_valid
+ segs[mask, 1, :] = pt_xWs[mask]
+ # If the Y coordinate of the first endpoint is out
+ mask = (segs[..., 0, 1] < 0) & pt_y0s_valid
+ segs[mask, 0, :] = pt_y0s[mask]
+ mask = (segs[..., 0, 1] > (h - 1)) & pt_yHs_valid
+ segs[mask, 0, :] = pt_yHs[mask]
+ # If the Y coordinate of the second endpoint is out
+ mask = (segs[..., 1, 1] < 0) & pt_y0s_valid
+ segs[mask, 1, :] = pt_y0s[mask]
+ mask = (segs[..., 1, 1] > (h - 1)) & pt_yHs_valid
+ segs[mask, 1, :] = pt_yHs[mask]
+
+ assert (
+ torch.all(segs >= 0)
+ and torch.all(segs[..., 0] < w)
+ and torch.all(segs[..., 1] < h)
+ )
+ return segs
+
+
+def warp_lines_torch(
+ lines, H, inverse=True, dst_shape: Tuple[int, int] = None
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ :param lines: A tensor of shape (B, N, 2, 2)
+ where B is the batch size, N the number of lines.
+ :param H: The homography used to convert the lines.
+ batched or not (shapes (B, 3, 3) and (3, 3) respectively).
+ :param inverse: Whether to apply H or the inverse of H
+    :param dst_shape: If provided, lines are trimmed to be inside the image
+ """
+ device = lines.device
+ batch_size = len(lines)
+ lines = warp_points_torch(lines.reshape(batch_size, -1, 2), H, inverse).reshape(
+ lines.shape
+ )
+
+ if dst_shape is None:
+ return lines, torch.ones(lines.shape[:-2], dtype=torch.bool, device=device)
+
+ out_img = torch.any(
+ (lines < 0) | (lines >= torch.tensor(dst_shape[::-1], device=device)), -1
+ )
+ valid = ~out_img.all(-1)
+ any_out_of_img = out_img.any(-1)
+ lines_to_trim = valid & any_out_of_img
+
+ for b in range(batch_size):
+ lines_to_trim_mask_b = lines_to_trim[b]
+ lines_to_trim_b = lines[b][lines_to_trim_mask_b]
+ corrected_lines = shrink_segs_to_img(lines_to_trim_b, dst_shape)
+ lines[b][lines_to_trim_mask_b] = corrected_lines
+
+ return lines, valid
+
+
+# Homography evaluation utils
+
+
+def sym_homography_error(kpts0, kpts1, T_0to1):
+ kpts0_1 = from_homogeneous(to_homogeneous(kpts0) @ T_0to1.transpose(-1, -2))
+ dist0_1 = ((kpts0_1 - kpts1) ** 2).sum(-1).sqrt()
+
+ kpts1_0 = from_homogeneous(
+ to_homogeneous(kpts1) @ torch.pinverse(T_0to1.transpose(-1, -2))
+ )
+ dist1_0 = ((kpts1_0 - kpts0) ** 2).sum(-1).sqrt()
+
+ return (dist0_1 + dist1_0) / 2.0
+
+
+def sym_homography_error_all(kpts0, kpts1, H):
+ kp0_1 = warp_points_torch(kpts0, H, inverse=False)
+ kp1_0 = warp_points_torch(kpts1, H, inverse=True)
+
+ # build a distance matrix of size [... x M x N]
+ dist0 = torch.sum((kp0_1.unsqueeze(-2) - kpts1.unsqueeze(-3)) ** 2, -1).sqrt()
+ dist1 = torch.sum((kpts0.unsqueeze(-2) - kp1_0.unsqueeze(-3)) ** 2, -1).sqrt()
+ return (dist0 + dist1) / 2.0
+
+
+def homography_corner_error(T, T_gt, image_size):
+ W, H = image_size[..., 0], image_size[..., 1]
+ corners0 = torch.Tensor([[0, 0], [W, 0], [W, H], [0, H]]).float().to(T)
+ corners1_gt = from_homogeneous(to_homogeneous(corners0) @ T_gt.transpose(-1, -2))
+ corners1 = from_homogeneous(to_homogeneous(corners0) @ T.transpose(-1, -2))
+ d = torch.sqrt(((corners1 - corners1_gt) ** 2).sum(-1))
+ return d.mean(-1)
diff --git a/imcui/third_party/MatchAnything/src/utils/metrics.py b/imcui/third_party/MatchAnything/src/utils/metrics.py
new file mode 100644
index 0000000000000000000000000000000000000000..32703f64de82aa45996011195a59b5f493a82bf1
--- /dev/null
+++ b/imcui/third_party/MatchAnything/src/utils/metrics.py
@@ -0,0 +1,445 @@
+import torch
+import cv2
+import numpy as np
+from collections import OrderedDict
+from loguru import logger
+from .homography_utils import warp_points, warp_points_torch
+from kornia.geometry.epipolar import numeric
+from kornia.geometry.conversions import convert_points_to_homogeneous
+import pprint
+
+
+# --- METRICS ---
+
+def relative_pose_error(T_0to1, R, t, ignore_gt_t_thr=0.0):
+ # angle error between 2 vectors
+ t_gt = T_0to1[:3, 3]
+ n = np.linalg.norm(t) * np.linalg.norm(t_gt)
+ t_err = np.rad2deg(np.arccos(np.clip(np.dot(t, t_gt) / n, -1.0, 1.0)))
+ t_err = np.minimum(t_err, 180 - t_err) # handle E ambiguity
+ if np.linalg.norm(t_gt) < ignore_gt_t_thr: # pure rotation is challenging
+ t_err = 0
+
+ # angle error between 2 rotation matrices
+ R_gt = T_0to1[:3, :3]
+ cos = (np.trace(np.dot(R.T, R_gt)) - 1) / 2
+    cos = np.clip(cos, -1., 1.)  # handle numerical errors
+ R_err = np.rad2deg(np.abs(np.arccos(cos)))
+
+ return t_err, R_err
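+# Toy check: a perfectly recovered pose yields zero errors, e.g.
+#   T = np.eye(4); T[:3, 3] = [1., 0., 0.]
+#   t_err, R_err = relative_pose_error(T, R=np.eye(3), t=np.array([1., 0., 0.]))
+#   # -> t_err == 0.0, R_err == 0.0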
+
+def warp_pts_error(H_est, pts_coord, H_gt=None, pts_gt=None):
+ """
+    pts_coord: N*2 points warped by both the estimated and the ground-truth transform
+ """
+ if H_gt is not None:
+ est_warp = warp_points(pts_coord, H_est, False)
+ est_gt = warp_points(pts_coord, H_gt, False)
+ diff = est_warp - est_gt
+ elif pts_gt is not None:
+ est_warp = warp_points(pts_coord, H_est, False)
+ diff = est_warp - pts_gt
+
+ return np.mean(np.linalg.norm(diff, axis=1))
+
+def homo_warp_match_distance(H_gt, kpts0, kpts1, hw):
+ """
+    kpts0, kpts1: N*2 matched keypoints; hw: (h, w) of image0, used to normalize the distance
+ """
+ if isinstance(H_gt, np.ndarray):
+ kpts_warped = warp_points(kpts0, H_gt)
+ normalized_distance = np.linalg.norm((kpts_warped - kpts1) / hw[None, [1,0]], axis=1)
+ else:
+ kpts_warped = warp_points_torch(kpts0, H_gt)
+ normalized_distance = torch.linalg.norm((kpts_warped - kpts1) / hw[None, [1,0]], axis=1)
+ return normalized_distance
+
+def symmetric_epipolar_distance(pts0, pts1, E, K0, K1):
+ """Squared symmetric epipolar distance.
+ This can be seen as a biased estimation of the reprojection error.
+ Args:
+ pts0 (torch.Tensor): [N, 2]
+ E (torch.Tensor): [3, 3]
+ """
+ pts0 = (pts0 - K0[[0, 1], [2, 2]][None]) / K0[[0, 1], [0, 1]][None]
+ pts1 = (pts1 - K1[[0, 1], [2, 2]][None]) / K1[[0, 1], [0, 1]][None]
+ pts0 = convert_points_to_homogeneous(pts0)
+ pts1 = convert_points_to_homogeneous(pts1)
+
+ Ep0 = pts0 @ E.T # [N, 3]
+ p1Ep0 = torch.sum(pts1 * Ep0, -1) # [N,]
+ Etp1 = pts1 @ E # [N, 3]
+
+ d = p1Ep0**2 * (1.0 / (Ep0[:, 0]**2 + Ep0[:, 1]**2) + 1.0 / (Etp1[:, 0]**2 + Etp1[:, 1]**2)) # N
+ return d
+
+
+def compute_symmetrical_epipolar_errors(data, config):
+ """
+ Update:
+ data (dict):{"epi_errs": [M]}
+ """
+ Tx = numeric.cross_product_matrix(data['T_0to1'][:, :3, 3])
+ E_mat = Tx @ data['T_0to1'][:, :3, :3]
+
+ m_bids = data['m_bids']
+ pts0 = data['mkpts0_f']
+ pts1 = data['mkpts1_f'].clone().detach()
+
+ if config.LOFTR.FINE.MTD_SPVS:
+ m_bids = data['m_bids_f'] if 'm_bids_f' in data else data['m_bids']
+ epi_errs = []
+ for bs in range(Tx.size(0)):
+ mask = m_bids == bs
+ epi_errs.append(
+ symmetric_epipolar_distance(pts0[mask], pts1[mask], E_mat[bs], data['K0'][bs], data['K1'][bs]))
+ epi_errs = torch.cat(epi_errs, dim=0)
+
+ data.update({'epi_errs': epi_errs})
+
+def compute_homo_match_warp_errors(data, config):
+ """
+ Update:
+ data (dict):{"epi_errs": [M]}
+ """
+
+ homography_gt = data['homography']
+ m_bids = data['m_bids']
+ pts0 = data['mkpts0_f']
+ pts1 = data['mkpts1_f']
+ origin_img0_size = data['origin_img_size0']
+
+ if config.LOFTR.FINE.MTD_SPVS:
+ m_bids = data['m_bids_f'] if 'm_bids_f' in data else data['m_bids']
+ epi_errs = []
+ for bs in range(homography_gt.shape[0]):
+ mask = m_bids == bs
+ epi_errs.append(
+ homo_warp_match_distance(homography_gt[bs], pts0[mask], pts1[mask], origin_img0_size[bs]))
+ epi_errs = torch.cat(epi_errs, dim=0)
+
+ data.update({'epi_errs': epi_errs})
+
+
+def compute_symmetrical_epipolar_errors_gt(data, config):
+ """
+ Update:
+ data (dict):{"epi_errs": [M]}
+ """
+ Tx = numeric.cross_product_matrix(data['T_0to1'][:, :3, 3])
+ E_mat = Tx @ data['T_0to1'][:, :3, :3]
+
+ m_bids = data['m_bids']
+ pts0 = data['mkpts0_f_gt']
+ pts1 = data['mkpts1_f_gt']
+
+ epi_errs = []
+ for bs in range(Tx.size(0)):
+ # mask = m_bids == bs
+ assert bs == 0
+ mask = torch.tensor([True]*pts0.shape[0], device = pts0.device)
+ if config.LOFTR.FINE.MTD_SPVS:
+ epi_errs.append(
+ symmetric_epipolar_distance(pts0[mask], pts1[mask], E_mat[bs], data['K0'][bs], data['K1'][bs]))
+ else:
+ epi_errs.append(
+ symmetric_epipolar_distance(pts0[mask], pts1[mask], E_mat[bs], data['K0'][bs], data['K1'][bs]))
+ epi_errs = torch.cat(epi_errs, dim=0)
+
+ data.update({'epi_errs': epi_errs})
+
+
+def estimate_pose(kpts0, kpts1, K0, K1, thresh, conf=0.99999):
+ if len(kpts0) < 5:
+ return None
+ # normalize keypoints
+ kpts0 = (kpts0 - K0[[0, 1], [2, 2]][None]) / K0[[0, 1], [0, 1]][None]
+ kpts1 = (kpts1 - K1[[0, 1], [2, 2]][None]) / K1[[0, 1], [0, 1]][None]
+
+ # normalize ransac threshold
+ ransac_thr = thresh / np.mean([K0[0, 0], K1[1, 1], K0[0, 0], K1[1, 1]])
+
+ # compute pose with cv2
+ E, mask = cv2.findEssentialMat(
+ kpts0, kpts1, np.eye(3), threshold=ransac_thr, prob=conf, method=cv2.RANSAC)
+ if E is None:
+ print("\nE is None while trying to recover pose.\n")
+ return None
+
+ # recover pose from E
+ best_num_inliers = 0
+ ret = None
+    for _E in np.split(E, len(E) // 3):
+ n, R, t, _ = cv2.recoverPose(_E, kpts0, kpts1, np.eye(3), 1e9, mask=mask)
+ if n > best_num_inliers:
+ ret = (R, t[:, 0], mask.ravel() > 0)
+ best_num_inliers = n
+
+ return ret
+
+def estimate_homo(kpts0, kpts1, thresh, conf=0.99999, mode='affine'):
+ if mode == 'affine':
+ H_est, inliers = cv2.estimateAffine2D(kpts0, kpts1, ransacReprojThreshold=thresh, confidence=conf, method=cv2.RANSAC)
+ if H_est is None:
+ return np.eye(3) * 0, np.empty((0))
+ H_est = np.concatenate([H_est, np.array([[0, 0, 1]])], axis=0) # 3 * 3
+ elif mode == 'homo':
+ H_est, inliers = cv2.findHomography(kpts0, kpts1, method=cv2.LMEDS, ransacReprojThreshold=thresh)
+ if H_est is None:
+ return np.eye(3) * 0, np.empty((0))
+
+ return H_est, inliers
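+# Usage sketch (mkpts0/mkpts1 stand for matched keypoint arrays of shape (N, 2),
+# corner_coord and H_gt are the caller's own evaluation data): both modes return a
+# 3x3 matrix plus an inlier mask, so affine and homography estimates can be evaluated
+# uniformly, e.g.
+#   H_est, inliers = estimate_homo(mkpts0, mkpts1, thresh=3.0, mode='affine')
+#   err = warp_pts_error(H_est, corner_coord, H_gt=H_gt)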
+
+def compute_homo_corner_warp_errors(data, config):
+ """
+ Update:
+ data (dict):{
+ "R_errs" List[float]: [N] # Actually warp error
+ "t_errs" List[float]: [N] # Zero, place holder
+ "inliers" List[np.ndarray]: [N]
+ }
+ """
+ pixel_thr = config.TRAINER.RANSAC_PIXEL_THR # 0.5
+ conf = config.TRAINER.RANSAC_CONF # 0.99999
+ data.update({'R_errs': [], 't_errs': [], 'inliers': []})
+
+ if config.LOFTR.FINE.MTD_SPVS:
+ m_bids = data['m_bids_f'].cpu().numpy() if 'm_bids_f' in data else data['m_bids'].cpu().numpy()
+
+ else:
+ m_bids = data['m_bids'].cpu().numpy()
+ pts0 = data['mkpts0_f'].cpu().numpy()
+ pts1 = data['mkpts1_f'].cpu().numpy()
+ homography_gt = data['homography'].cpu().numpy()
+ origin_size_0 = data['origin_img_size0'].cpu().numpy()
+
+ for bs in range(homography_gt.shape[0]):
+ mask = m_bids == bs
+ ret = estimate_homo(pts0[mask], pts1[mask], pixel_thr, conf=conf)
+
+ if ret is None:
+ data['R_errs'].append(np.inf)
+ data['t_errs'].append(np.inf)
+ data['inliers'].append(np.array([]).astype(bool))
+ else:
+ H_est, inliers = ret
+ corner_coord = np.array([[0, 0], [0, origin_size_0[bs][0]], [origin_size_0[bs][1], 0], [origin_size_0[bs][1], origin_size_0[bs][0]]])
+ corner_warp_distance = warp_pts_error(H_est, corner_coord, H_gt=homography_gt[bs])
+ data['R_errs'].append(corner_warp_distance)
+ data['t_errs'].append(0)
+ data['inliers'].append(inliers)
+
+def compute_warp_control_pts_errors(data, config):
+ """
+ Update:
+ data (dict):{
+ "R_errs" List[float]: [N] # Actually warp error
+ "t_errs" List[float]: [N] # Zero, place holder
+ "inliers" List[np.ndarray]: [N]
+ }
+ """
+ pixel_thr = config.TRAINER.RANSAC_PIXEL_THR # 0.5
+ conf = config.TRAINER.RANSAC_CONF # 0.99999
+ data.update({'R_errs': [], 't_errs': [], 'inliers': []})
+
+ if config.LOFTR.FINE.MTD_SPVS:
+ m_bids = data['m_bids_f'].cpu().numpy() if 'm_bids_f' in data else data['m_bids'].cpu().numpy()
+
+ else:
+ m_bids = data['m_bids'].cpu().numpy()
+ pts0 = data['mkpts0_f'].cpu().numpy()
+ pts1 = data['mkpts1_f'].cpu().numpy()
+ gt_2D_matches = data["gt_2D_matches"].cpu().numpy()
+
+ data.update({'epi_errs': torch.zeros(m_bids.shape[0])})
+ for bs in range(gt_2D_matches.shape[0]):
+ mask = m_bids == bs
+ ret = estimate_homo(pts0[mask], pts1[mask], pixel_thr, conf=conf, mode=config.TRAINER.WARP_ESTIMATOR_MODEL)
+
+ if ret is None:
+ data['R_errs'].append(np.inf)
+ data['t_errs'].append(np.inf)
+ data['inliers'].append(np.array([]).astype(bool))
+ else:
+ H_est, inliers = ret
+ img0_pts, img1_pts = gt_2D_matches[bs][:, :2], gt_2D_matches[bs][:, 2:]
+ pts_warp_distance = warp_pts_error(H_est, img0_pts, pts_gt=img1_pts)
+ print(pts_warp_distance)
+ data['R_errs'].append(pts_warp_distance)
+ data['t_errs'].append(0)
+ data['inliers'].append(inliers)
+
+def compute_pose_errors(data, config):
+ """
+ Update:
+ data (dict):{
+ "R_errs" List[float]: [N]
+ "t_errs" List[float]: [N]
+ "inliers" List[np.ndarray]: [N]
+ }
+ """
+ pixel_thr = config.TRAINER.RANSAC_PIXEL_THR # 0.5
+ conf = config.TRAINER.RANSAC_CONF # 0.99999
+ data.update({'R_errs': [], 't_errs': [], 'inliers': []})
+
+ if config.LOFTR.FINE.MTD_SPVS:
+ m_bids = data['m_bids_f'].cpu().numpy() if 'm_bids_f' in data else data['m_bids'].cpu().numpy()
+
+ else:
+ m_bids = data['m_bids'].cpu().numpy()
+ pts0 = data['mkpts0_f'].cpu().numpy()
+ pts1 = data['mkpts1_f'].cpu().numpy()
+ K0 = data['K0'].cpu().numpy()
+ K1 = data['K1'].cpu().numpy()
+ T_0to1 = data['T_0to1'].cpu().numpy()
+
+ for bs in range(K0.shape[0]):
+ mask = m_bids == bs
+ if config.LOFTR.EVAL_TIMES >= 1:
+ bpts0, bpts1 = pts0[mask], pts1[mask]
+ R_list, T_list, inliers_list = [], [], []
+ for _ in range(5):
+ shuffling = np.random.permutation(np.arange(len(bpts0)))
+ if _ >= config.LOFTR.EVAL_TIMES:
+ continue
+ bpts0 = bpts0[shuffling]
+ bpts1 = bpts1[shuffling]
+
+ ret = estimate_pose(bpts0, bpts1, K0[bs], K1[bs], pixel_thr, conf=conf)
+ if ret is None:
+ R_list.append(np.inf)
+ T_list.append(np.inf)
+ inliers_list.append(np.array([]).astype(bool))
+ print('Pose error: inf')
+ else:
+ R, t, inliers = ret
+ t_err, R_err = relative_pose_error(T_0to1[bs], R, t, ignore_gt_t_thr=0.0)
+ R_list.append(R_err)
+ T_list.append(t_err)
+ inliers_list.append(inliers)
+ print(f'Pose error: {max(R_err, t_err)}')
+ R_err_mean = np.array(R_list).mean()
+ T_err_mean = np.array(T_list).mean()
+ # inliers_mean = np.array(inliers_list).mean()
+
+ data['R_errs'].append(R_list)
+ data['t_errs'].append(T_list)
+ data['inliers'].append(inliers_list[0])
+
+ else:
+ ret = estimate_pose(pts0[mask], pts1[mask], K0[bs], K1[bs], pixel_thr, conf=conf)
+
+ if ret is None:
+ data['R_errs'].append(np.inf)
+ data['t_errs'].append(np.inf)
+ data['inliers'].append(np.array([]).astype(bool))
+ print('Pose error: inf')
+ else:
+ R, t, inliers = ret
+ t_err, R_err = relative_pose_error(T_0to1[bs], R, t, ignore_gt_t_thr=0.0)
+ data['R_errs'].append(R_err)
+ data['t_errs'].append(t_err)
+ data['inliers'].append(inliers)
+ print(f'Pose error: {max(R_err, t_err)}')
+
+
+# --- METRIC AGGREGATION ---
+def error_rmse(error):
+    squared_errors = np.square(error) # N * 2
+    mse = np.mean(np.sum(squared_errors, axis=1))
+ rmse = np.sqrt(mse)
+ return rmse
+
+def error_mae(error):
+ abs_diff = np.abs(error) # N * 2
+ absolute_errors = np.sum(abs_diff, axis=1)
+
+    # Note: despite the name, this returns the maximum (not mean) per-point absolute error
+ mae = np.max(absolute_errors)
+ return mae
+
+def error_auc(errors, thresholds, method='exact_auc'):
+ """
+ Args:
+ errors (list): [N,]
+ thresholds (list)
+ """
+ if method == 'exact_auc':
+ errors = [0] + sorted(list(errors))
+ recall = list(np.linspace(0, 1, len(errors)))
+
+ aucs = []
+ for thr in thresholds:
+ last_index = np.searchsorted(errors, thr)
+ y = recall[:last_index] + [recall[last_index-1]]
+ x = errors[:last_index] + [thr]
+ aucs.append(np.trapz(y, x) / thr)
+ return {f'auc@{t}': auc for t, auc in zip(thresholds, aucs)}
+ elif method == 'fire_paper':
+ aucs = []
+ for threshold in thresholds:
+ accum_error = 0
+ percent_error_below = np.zeros(threshold + 1)
+ for i in range(1, threshold + 1):
+ percent_error_below[i] = np.sum(errors < i) * 100 / len(errors)
+ accum_error += percent_error_below[i]
+
+ aucs.append(accum_error / (threshold * 100))
+
+ return {f'auc@{t}': auc for t, auc in zip(thresholds, aucs)}
+ elif method == 'success_rate':
+ aucs = []
+ for threshold in thresholds:
+ aucs.append((errors < threshold).astype(float).mean())
+ return {f'SR@{t}': auc for t, auc in zip(thresholds, aucs)}
+ else:
+ raise NotImplementedError
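+# Toy example: for errors [1, 3, 7] and thresholds [5, 10],
+#   error_auc(np.array([1., 3., 7.]), [5, 10], method='success_rate')
+# returns {'SR@5': ~0.67, 'SR@10': 1.0}; the 'exact_auc' variant instead integrates
+# the recall-vs-error curve up to each threshold.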
+
+
+def epidist_prec(errors, thresholds, ret_dict=False):
+ precs = []
+ for thr in thresholds:
+ prec_ = []
+ for errs in errors:
+ correct_mask = errs < thr
+ prec_.append(np.mean(correct_mask) if len(correct_mask) > 0 else 0)
+ precs.append(np.mean(prec_) if len(prec_) > 0 else 0)
+ if ret_dict:
+ return {f'prec@{t:.0e}': prec for t, prec in zip(thresholds, precs)}
+ else:
+ return precs
+
+
+def aggregate_metrics(metrics, epi_err_thr=5e-4, eval_n_time=1, threshold=[5, 10, 20], method='exact_auc'):
+ """ Aggregate metrics for the whole dataset:
+ (This method should be called once per dataset)
+ 1. AUC of the pose error (angular) at the threshold [5, 10, 20]
+    2. Mean matching precision at the threshold 5e-4 (ScanNet), 1e-4 (MegaDepth)
+ """
+ # filter duplicates
+ unq_ids = OrderedDict((iden, id) for id, iden in enumerate(metrics['identifiers']))
+ unq_ids = list(unq_ids.values())
+ logger.info(f'Aggregating metrics over {len(unq_ids)} unique items...')
+
+ # pose auc
+ angular_thresholds = threshold
+ if eval_n_time >= 1:
+ pose_errors = np.max(np.stack([metrics['R_errs'], metrics['t_errs']]), axis=0).reshape(-1, eval_n_time)[unq_ids].reshape(-1)
+ else:
+ pose_errors = np.max(np.stack([metrics['R_errs'], metrics['t_errs']]), axis=0)[unq_ids]
+ logger.info('num of pose_errors: {}'.format(pose_errors.shape))
+ aucs = error_auc(pose_errors, angular_thresholds, method=method) # (auc@5, auc@10, auc@20)
+
+ if eval_n_time >= 1:
+ for i in range(eval_n_time):
+ aucs_i = error_auc(pose_errors.reshape(-1, eval_n_time)[:,i], angular_thresholds, method=method)
+ logger.info('\n' + f'results of {i}-th RANSAC' + pprint.pformat(aucs_i))
+ # matching precision
+ dist_thresholds = [epi_err_thr]
+ precs = epidist_prec(np.array(metrics['epi_errs'], dtype=object)[unq_ids], dist_thresholds, True) # (prec@err_thr)
+
+    u_num_matches = np.array(metrics['num_matches'], dtype=object)[unq_ids]
+    u_percent_inliers = np.array(metrics['percent_inliers'], dtype=object)[unq_ids]
+    num_matches = {'num_matches': u_num_matches.mean()}
+    percent_inliers = {'percent_inliers': u_percent_inliers.mean()}
+ return {**aucs, **precs, **num_matches, **percent_inliers}
diff --git a/imcui/third_party/MatchAnything/src/utils/misc.py b/imcui/third_party/MatchAnything/src/utils/misc.py
new file mode 100644
index 0000000000000000000000000000000000000000..9c8db04666519753ea2df43903ab6c47ec00a9a1
--- /dev/null
+++ b/imcui/third_party/MatchAnything/src/utils/misc.py
@@ -0,0 +1,101 @@
+import os
+import contextlib
+import joblib
+from typing import Union
+from loguru import _Logger, logger
+from itertools import chain
+
+import torch
+from yacs.config import CfgNode as CN
+from pytorch_lightning.utilities import rank_zero_only
+
+
+def lower_config(yacs_cfg):
+ if not isinstance(yacs_cfg, CN):
+ return yacs_cfg
+ return {k.lower(): lower_config(v) for k, v in yacs_cfg.items()}
+
+
+def upper_config(dict_cfg):
+ if not isinstance(dict_cfg, dict):
+ return dict_cfg
+ return {k.upper(): upper_config(v) for k, v in dict_cfg.items()}
+
+
+def log_on(condition, message, level):
+ if condition:
+ assert level in ['INFO', 'DEBUG', 'WARNING', 'ERROR', 'CRITICAL']
+ logger.log(level, message)
+
+
+def get_rank_zero_only_logger(logger: _Logger):
+ if rank_zero_only.rank == 0:
+ return logger
+ else:
+ for _level in logger._core.levels.keys():
+ level = _level.lower()
+ setattr(logger, level,
+ lambda x: None)
+ logger._log = lambda x: None
+ return logger
+
+
+def setup_gpus(gpus: Union[str, int]) -> int:
+ """ A temporary fix for pytorch-lighting 1.3.x """
+ gpus = str(gpus)
+ gpu_ids = []
+
+ if ',' not in gpus:
+ n_gpus = int(gpus)
+ return n_gpus if n_gpus != -1 else torch.cuda.device_count()
+ else:
+ gpu_ids = [i.strip() for i in gpus.split(',') if i != '']
+
+ # setup environment variables
+ visible_devices = os.getenv('CUDA_VISIBLE_DEVICES')
+ if visible_devices is None:
+ os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
+ os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(str(i) for i in gpu_ids)
+ visible_devices = os.getenv('CUDA_VISIBLE_DEVICES')
+ logger.warning(f'[Temporary Fix] manually set CUDA_VISIBLE_DEVICES when specifying gpus to use: {visible_devices}')
+ else:
+ logger.warning('[Temporary Fix] CUDA_VISIBLE_DEVICES already set by user or the main process.')
+ return len(gpu_ids)
+
+
+def flattenList(x):
+ return list(chain(*x))
+
+
+@contextlib.contextmanager
+def tqdm_joblib(tqdm_object):
+ """Context manager to patch joblib to report into tqdm progress bar given as argument
+
+ Usage:
+ with tqdm_joblib(tqdm(desc="My calculation", total=10)) as progress_bar:
+ Parallel(n_jobs=16)(delayed(sqrt)(i**2) for i in range(10))
+
+    When iterating over a generator, using tqdm directly is also a solution (but it monitors task queuing instead of completion)
+ ret_vals = Parallel(n_jobs=args.world_size)(
+ delayed(lambda x: _compute_cov_score(pid, *x))(param)
+ for param in tqdm(combinations(image_ids, 2),
+ desc=f'Computing cov_score of [{pid}]',
+ total=len(image_ids)*(len(image_ids)-1)/2))
+ Src: https://stackoverflow.com/a/58936697
+ """
+ class TqdmBatchCompletionCallback(joblib.parallel.BatchCompletionCallBack):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ def __call__(self, *args, **kwargs):
+ tqdm_object.update(n=self.batch_size)
+ return super().__call__(*args, **kwargs)
+
+ old_batch_callback = joblib.parallel.BatchCompletionCallBack
+ joblib.parallel.BatchCompletionCallBack = TqdmBatchCompletionCallback
+ try:
+ yield tqdm_object
+ finally:
+ joblib.parallel.BatchCompletionCallBack = old_batch_callback
+ tqdm_object.close()
+
diff --git a/imcui/third_party/MatchAnything/src/utils/plotting.py b/imcui/third_party/MatchAnything/src/utils/plotting.py
new file mode 100644
index 0000000000000000000000000000000000000000..87d39733f57a09e4db61b61e53f82ad80e4a839b
--- /dev/null
+++ b/imcui/third_party/MatchAnything/src/utils/plotting.py
@@ -0,0 +1,248 @@
+import bisect
+import numpy as np
+import matplotlib.pyplot as plt
+import matplotlib
+
+import torch
+
+def _compute_conf_thresh(data):
+ dataset_name = data['dataset_name'][0].lower()
+ if dataset_name == 'scannet':
+ thr = 5e-4
+ elif dataset_name == 'megadepth':
+ thr = 1e-4
+ else:
+ raise ValueError(f'Unknown dataset: {dataset_name}')
+ return thr
+
+
+# --- VISUALIZATION --- #
+def make_matching_figure(
+ img0, img1, mkpts0, mkpts1, color,
+ kpts0=None, kpts1=None, text=[], dpi=75, path=None):
+ # draw image pair
+ assert mkpts0.shape[0] == mkpts1.shape[0], f'mkpts0: {mkpts0.shape[0]} v.s. mkpts1: {mkpts1.shape[0]}'
+ fig, axes = plt.subplots(1, 2, figsize=(10, 6), dpi=dpi)
+ axes[0].imshow(img0, cmap='gray')
+ axes[1].imshow(img1, cmap='gray')
+ for i in range(2): # clear all frames
+ axes[i].get_yaxis().set_ticks([])
+ axes[i].get_xaxis().set_ticks([])
+ for spine in axes[i].spines.values():
+ spine.set_visible(False)
+ plt.tight_layout(pad=1)
+
+ if kpts0 is not None:
+ assert kpts1 is not None
+ axes[0].scatter(kpts0[:, 0], kpts0[:, 1], c='w', s=2)
+ axes[1].scatter(kpts1[:, 0], kpts1[:, 1], c='w', s=2)
+
+ # draw matches
+ if mkpts0.shape[0] != 0 and mkpts1.shape[0] != 0:
+ fig.canvas.draw()
+ transFigure = fig.transFigure.inverted()
+ fkpts0 = transFigure.transform(axes[0].transData.transform(mkpts0))
+ fkpts1 = transFigure.transform(axes[1].transData.transform(mkpts1))
+ fig.lines = [matplotlib.lines.Line2D((fkpts0[i, 0], fkpts1[i, 0]),
+ (fkpts0[i, 1], fkpts1[i, 1]),
+ transform=fig.transFigure, c=color[i], linewidth=1)
+ for i in range(len(mkpts0))]
+
+ axes[0].scatter(mkpts0[:, 0], mkpts0[:, 1], c=color, s=4)
+ axes[1].scatter(mkpts1[:, 0], mkpts1[:, 1], c=color, s=4)
+
+ # put txts
+ txt_color = 'k' if img0[:100, :200].mean() > 200 else 'w'
+ fig.text(
+ 0.01, 0.99, '\n'.join(text), transform=fig.axes[0].transAxes,
+ fontsize=15, va='top', ha='left', color=txt_color)
+
+ # save or return figure
+ if path:
+ plt.savefig(str(path), bbox_inches='tight', pad_inches=0)
+ plt.close()
+ else:
+ return fig
+
+
+def _make_evaluation_figure(data, b_id, alpha='dynamic', use_m_bids_f=False):
+ if use_m_bids_f:
+ b_mask = (data['m_bids_f'] == b_id) if 'm_bids_f' in data else (data['m_bids'] == b_id)
+ else:
+ b_mask = data['m_bids'] == b_id
+ conf_thr = _compute_conf_thresh(data)
+
+ img0 = (data['image0'][b_id][0].cpu().numpy() * 255).round().astype(np.int32)
+ img1 = (data['image1'][b_id][0].cpu().numpy() * 255).round().astype(np.int32)
+ kpts0 = data['mkpts0_f'][b_mask].cpu().numpy()
+ kpts1 = data['mkpts1_f'][b_mask].clone().detach().cpu().numpy()
+
+ # for megadepth, we visualize matches on the resized image
+ if 'scale0' in data:
+ kpts0 = kpts0 / data['scale0'][b_id].cpu().numpy()
+ kpts1 = kpts1 / data['scale1'][b_id].cpu().numpy()
+
+ epi_errs = data['epi_errs'][b_mask].cpu().numpy()
+ correct_mask = epi_errs < conf_thr
+ precision = np.mean(correct_mask) if len(correct_mask) > 0 else 0
+ n_correct = np.sum(correct_mask)
+ n_gt_matches = int(data['conf_matrix_gt'][b_id].sum().cpu()) if 'conf_matrix_gt' in data else data['gt'][1]['gt_prob'].sum()
+ recall = 0 if n_gt_matches == 0 else n_correct / (n_gt_matches)
+ # recall might be larger than 1, since the calculation of conf_matrix_gt
+ # uses groundtruth depths and camera poses, but epipolar distance is used here.
+
+ # matching info
+ if alpha == 'dynamic':
+ alpha = dynamic_alpha(len(correct_mask))
+ color = error_colormap(epi_errs, conf_thr, alpha=alpha)
+
+ text = [
+ f'#Matches {len(kpts0)}',
+ f'Precision({conf_thr:.2e}) ({100 * precision:.1f}%): {n_correct}/{len(kpts0)}',
+ f'Recall({conf_thr:.2e}) ({100 * recall:.1f}%): {n_correct}/{n_gt_matches}'
+ ]
+
+ # make the figure
+ figure = make_matching_figure(img0, img1, kpts0, kpts1,
+ color, text=text)
+ return figure
+
+def _make_confidence_figure(data, b_id):
+ raise NotImplementedError()
+
+def _make_gt_figure(data, b_id, alpha='dynamic', use_m_bids_f=False, mode='gt_fine'):
+ if 'fine' in mode:
+ mkpts0_key, mkpts1_key = 'mkpts0_f_gt', 'mkpts1_f_gt'
+ else:
+ mkpts0_key, mkpts1_key = 'mkpts0_c_gt', 'mkpts1_c_gt'
+
+ if data['image0'].shape[0] == 1:
+ b_mask = torch.tensor([True]*data[mkpts0_key].shape[0], device = data[mkpts0_key].device)
+ else:
+ raise NotImplementedError
+
+ conf_thr = _compute_conf_thresh(data)
+
+ img0 = (data['image0'][b_id][0].cpu().numpy() * 255).round().astype(np.int32)
+ img1 = (data['image1'][b_id][0].cpu().numpy() * 255).round().astype(np.int32)
+ try:
+ kpts0 = data[mkpts0_key][b_mask].cpu().numpy()
+ kpts1 = data[mkpts1_key][b_mask].cpu().numpy()
+ except:
+ kpts0, kpts1 = np.ones((0, 2)), np.ones((0, 2))
+
+ # for megadepth, we visualize matches on the resized image
+ if 'scale0' in data:
+ kpts0 = kpts0 / data['scale0'][b_id].cpu().numpy()
+ kpts1 = kpts1 / data['scale1'][b_id].cpu().numpy()
+
+ # matching info
+ if alpha == 'dynamic':
+ alpha = dynamic_alpha(len(kpts0))
+ color = error_colormap(np.full((kpts0.shape[0]), conf_thr), conf_thr, alpha=0.1)
+
+ text = [
+ f'#Matches {len(kpts0)}',
+ ]
+
+ # make the figure
+ figure = make_matching_figure(img0, img1, kpts0, kpts1,
+ color, text=text)
+ return figure
+
+def make_matching_figures(data, config, mode='evaluation'):
+ """ Make matching figures for a batch.
+
+ Args:
+ data (Dict): a batch updated by PL_LoFTR.
+ config (Dict): matcher config
+ Returns:
+        figures (Dict[str, List[plt.figure]])
+ """
+ figures = {mode: []}
+ for b_id in range(data['image0'].size(0)):
+ if mode == 'evaluation':
+ fig = _make_evaluation_figure(
+ data, b_id,
+ alpha=config.TRAINER.PLOT_MATCHES_ALPHA,
+ use_m_bids_f=config.LOFTR.FINE.MTD_SPVS)
+ elif mode == 'confidence':
+ fig = _make_confidence_figure(data, b_id)
+ elif 'gt' in mode:
+ fig = _make_gt_figure(data, b_id, use_m_bids_f=config.LOFTR.FINE.MTD_SPVS, mode=mode)
+ else:
+ raise ValueError(f'Unknown plot mode: {mode}')
+ figures[mode].append(fig)
+ return figures
+
+def make_scores_figures(data, config, mode='evaluation'):
+ """ Make matching figures for a batch.
+
+ Args:
+ data (Dict): a batch updated by PL_LoFTR.
+ config (Dict): matcher config
+ Returns:
+        figures (Dict[str, List[plt.figure]])
+ """
+ assert mode in ['evaluation', 'confidence', 'gt'] # 'confidence'
+ figures = {mode: []}
+ for b_id in range(data['image0'].size(0)):
+ if mode == 'evaluation':
+ if config.LOFTR.MATCH_COARSE.SKIP_SOFTMAX and config.LOFTR.MATCH_COARSE.PLOT_ORIGIN_SCORES:
+ plots = [data['histc_skipmn_in_softmax'][b_id].reshape(-1)] # [-30, 70] scores
+ if 'histc_skipmn_in_softmax_gt' in data:
+ plots.append(data['histc_skipmn_in_softmax_gt'][b_id].reshape(-1))
+ elif config.LOFTR.MATCH_COARSE.PLOT_ORIGIN_SCORES:
+ pass
+ else:
+ pass
+ print(plots[0], plots[-1])
+ group = len(plots)
+ start, end = 0, 100
+ bins=100
+ width = (end//bins-1)/group
+ fig, ax = plt.subplots()
+ for i, hist in enumerate(plots):
+ ax.set_yscale('log')
+ x = range(start, end, end//bins)
+ x = [t + i*width for t in x]
+ ax.bar(x, hist.cpu(), align='edge', width=width)
+
+ elif mode == 'confidence':
+ raise NotImplementedError()
+ elif mode == 'gt':
+ raise NotImplementedError()
+ else:
+ raise ValueError(f'Unknown plot mode: {mode}')
+ figures[mode].append(fig)
+ return figures
+
+
+def dynamic_alpha(n_matches,
+ milestones=[0, 300, 1000, 2000],
+ alphas=[1.0, 0.8, 0.4, 0.2]):
+ if n_matches == 0:
+ return 1.0
+ ranges = list(zip(alphas, alphas[1:] + [None]))
+ loc = bisect.bisect_right(milestones, n_matches) - 1
+ _range = ranges[loc]
+ if _range[1] is None:
+ return _range[0]
+ return _range[1] + (milestones[loc + 1] - n_matches) / (
+ milestones[loc + 1] - milestones[loc]) * (_range[0] - _range[1])
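+# Example values: alpha decays piecewise-linearly with the number of matches,
+#   dynamic_alpha(0) -> 1.0, dynamic_alpha(300) -> 0.8,
+#   dynamic_alpha(650) -> 0.6, dynamic_alpha(5000) -> 0.2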
+
+
+def error_colormap(err, thr, alpha=1.0):
+ assert alpha <= 1.0 and alpha > 0, f"Invaid alpha value: {alpha}"
+ if thr is not None:
+ large_error_mask = err > (thr * 2)
+ x = np.clip((err - thr) / (thr * 2), 0, 1)
+ else:
+ large_error_mask = np.zeros_like(err, dtype=bool)
+ x = np.clip(err, 0.1, 1)
+
+ cm_ = matplotlib.colormaps['RdYlGn_r']
+ color = cm_(x, bytes=False)
+ color[:, 3] = alpha
+ color[:, 3][large_error_mask] = alpha * 0.6
+ return color
\ No newline at end of file
diff --git a/imcui/third_party/MatchAnything/src/utils/profiler.py b/imcui/third_party/MatchAnything/src/utils/profiler.py
new file mode 100644
index 0000000000000000000000000000000000000000..6d21ed79fb506ef09c75483355402c48a195aaa9
--- /dev/null
+++ b/imcui/third_party/MatchAnything/src/utils/profiler.py
@@ -0,0 +1,39 @@
+import torch
+from pytorch_lightning.profiler import SimpleProfiler, PassThroughProfiler
+from contextlib import contextmanager
+from pytorch_lightning.utilities import rank_zero_only
+
+
+class InferenceProfiler(SimpleProfiler):
+ """
+ This profiler records duration of actions with cuda.synchronize()
+ Use this in test time.
+ """
+
+ def __init__(self):
+ super().__init__()
+ self.start = rank_zero_only(self.start)
+ self.stop = rank_zero_only(self.stop)
+ self.summary = rank_zero_only(self.summary)
+
+ @contextmanager
+ def profile(self, action_name: str) -> None:
+ try:
+ torch.cuda.synchronize()
+ self.start(action_name)
+ yield action_name
+ finally:
+ torch.cuda.synchronize()
+ self.stop(action_name)
+
+
+def build_profiler(name):
+ if name == 'inference':
+ return InferenceProfiler()
+ elif name == 'pytorch':
+ from pytorch_lightning.profiler import PyTorchProfiler
+ return PyTorchProfiler(use_cuda=True, profile_memory=True, row_limit=100)
+ elif name is None:
+ return PassThroughProfiler()
+ else:
+ raise ValueError(f'Invalid profiler: {name}')
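+
+
+if __name__ == '__main__':
+    # Minimal usage sketch (illustrative only; assumes a CUDA device is available):
+    # profile a dummy GPU operation with the InferenceProfiler and print the summary.
+    profiler = build_profiler('inference')
+    a = torch.randn(1024, 1024, device='cuda')
+    with profiler.profile('matmul'):
+        _ = a @ a
+    print(profiler.summary())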
diff --git a/imcui/third_party/MatchAnything/src/utils/ray_utils.py b/imcui/third_party/MatchAnything/src/utils/ray_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d6c7bf3172f4d53513338b1b615efcfded1c4c9
--- /dev/null
+++ b/imcui/third_party/MatchAnything/src/utils/ray_utils.py
@@ -0,0 +1,134 @@
+from asyncio import Event
+from typing import Tuple
+import numpy as np
+import random
+
+import ray
+from ray.actor import ActorHandle
+from tqdm import tqdm
+
+
+@ray.remote
+class ProgressBarActor:
+ counter: int
+ delta: int
+ event: Event
+
+ def __init__(self) -> None:
+ self.counter = 0
+ self.delta = 0
+ self.event = Event()
+
+ def update(self, num_items_completed: int) -> None:
+ """Updates the ProgressBar with the incremental
+ number of items that were just completed.
+ """
+ self.counter += num_items_completed
+ self.delta += num_items_completed
+ self.event.set()
+
+ async def wait_for_update(self) -> Tuple[int, int]:
+ """Blocking call.
+
+ Waits until somebody calls `update`, then returns a tuple of
+ the number of updates since the last call to
+ `wait_for_update`, and the total number of completed items.
+ """
+ await self.event.wait()
+ self.event.clear()
+ saved_delta = self.delta
+ self.delta = 0
+ return saved_delta, self.counter
+
+ def get_counter(self) -> int:
+ """
+ Returns the total number of complete items.
+ """
+ return self.counter
+
+
+class ProgressBar:
+ progress_actor: ActorHandle
+ total: int
+ description: str
+ pbar: tqdm
+
+ def __init__(self, total: int, description: str = ""):
+ # Ray actors don't seem to play nice with mypy, generating
+ # a spurious warning for the following line,
+ # which we need to suppress. The code is fine.
+ self.progress_actor = ProgressBarActor.remote() # type: ignore
+ self.total = total
+ self.description = description
+
+ @property
+ def actor(self) -> ActorHandle:
+ """Returns a reference to the remote `ProgressBarActor`.
+
+ When you complete tasks, call `update` on the actor.
+ """
+ return self.progress_actor
+
+ def print_until_done(self) -> None:
+ """Blocking call.
+
+ Do this after starting a series of remote Ray tasks, to which you've
+ passed the actor handle. Each of them calls `update` on the actor.
+ When the progress meter reaches 100%, this method returns.
+ """
+ pbar = tqdm(desc=self.description, total=self.total)
+ while True:
+ delta, counter = ray.get(self.actor.wait_for_update.remote())
+ pbar.update(delta)
+ if counter >= self.total:
+ pbar.close()
+ return
+
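+# Illustrative usage sketch (assumes `ray.init()` has already been called); `work_item`
+# and `items` are hypothetical names used only for this example:
+#
+#   @ray.remote
+#   def work_item(x, pba: ActorHandle):
+#       result = x * x          # the actual work goes here
+#       pba.update.remote(1)    # report one completed item
+#       return result
+#
+#   pb = ProgressBar(total=len(items), description='processing')
+#   futures = [work_item.remote(x, pb.actor) for x in items]
+#   pb.print_until_done()
+#   results = ray.get(futures)
+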
+# Ray data utils
+def chunks(lst, n, length=None):
+ """Yield successive n-sized chunks from lst."""
+ try:
+ _len = len(lst)
+ except TypeError as _:
+ assert length is not None
+ _len = length
+
+ for i in range(0, _len, n):
+ yield lst[i : i + n]
+
+def chunks_balance(lst, n_split):
+ if n_split == 0:
+ # 0 is not allowed
+ n_split = 1
+ splited_list = [[] for i in range(n_split)]
+ for id, obj in enumerate(lst):
+ assign_id = id % n_split
+ splited_list[assign_id].append(obj)
+ return splited_list
+
+
+def chunk_index(total_len, sub_len, shuffle=True):
+ index_array = np.arange(total_len)
+ if shuffle:
+ random.shuffle(index_array)
+
+ index_list = []
+ for i in range(0, total_len, sub_len):
+ index_list.append(list(index_array[i : i + sub_len]))
+
+ return index_list
+
+def chunk_index_balance(total_len, n_split, shuffle=True):
+ index_array = np.arange(total_len)
+ if shuffle:
+ random.shuffle(index_array)
+
+ splited_list = [[] for i in range(n_split)]
+ for id, obj in enumerate(index_array):
+ assign_id = id % n_split
+ splited_list[assign_id].append(obj)
+ return splited_list
+
+def split_dict(_dict, n):
+ for _items in chunks(list(_dict.items()), n):
+ yield dict(_items)
\ No newline at end of file
diff --git a/imcui/third_party/MatchAnything/src/utils/sample_homo.py b/imcui/third_party/MatchAnything/src/utils/sample_homo.py
new file mode 100644
index 0000000000000000000000000000000000000000..d4fdac234e1058b5a1106a98b5c3f77426af4b91
--- /dev/null
+++ b/imcui/third_party/MatchAnything/src/utils/sample_homo.py
@@ -0,0 +1,58 @@
+import numpy as np
+
+# ----- Similarity-Affinity-Perspective (SAP) impl ----- #
+
+def similarity_mat(angle, tx, ty, s):
+ theta = np.deg2rad(angle)
+ return np.array([[s*np.cos(theta), -s*np.sin(theta), tx], [s*np.sin(theta), s*np.cos(theta), ty], [0, 0, 1]])
+
+
+def affinity_mat(k0, k1):
+ return np.array([[k0, k1, 0], [0, 1/k0, 0], [0, 0, 1]])
+
+
+def perspective_mat(v0, v1):
+ return np.array([[1, 0, 0], [0, 1, 0], [v0, v1, 1]])
+
+
+def compute_homography_sap(h, w, angle=0, tx=0, ty=0, scale=1, k0=1, k1=0, v0=0, v1=0):
+ """
+ Args:
+        h, w: image height and width
+        angle: rotation in degrees, clockwise in the image coordinate system
+        tx, ty: translation (applied in normalized image coordinates)
+        scale: zoom-in factor, 1 by default
+        k0: anisotropic squeeze factor (values > 1 stretch x and squeeze y), typical range [0.5, 1.5]
+        k1: anisotropic skew factor (values > 0 skew up-to-left / down-to-right), typical range [-0.5, 0.5]
+        v0: left-right perspective factor (values > 0 move the content left), typical range [-1, 1]
+        v1: up-down perspective factor (values > 0 move the content up), typical range [-1, 1]
+ """
+ # move image to its center
+ max_size = max(w/2, h/2)
+ M_norm = similarity_mat(0, 0, 0, 1/max_size).dot(similarity_mat(0, -w/2, -h/2, 1))
+ M_denorm = similarity_mat(0, w/2, h/2, 1).dot(similarity_mat(0, 0, 0, max_size))
+
+ # compute HS, HA and HP accordingly
+ HS = similarity_mat(angle, tx, ty, scale)
+ HA = affinity_mat(k0, k1)
+ HP = perspective_mat(v0, v1)
+
+ # final H
+ H = M_denorm.dot(HS).dot(HA).dot(HP).dot(M_norm)
+ return H
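+
+# Illustrative example (not used elsewhere in this file): a 30-degree clockwise rotation
+# with a slight zoom about the image center of a 480x640 image; homogeneous points can
+# then be warped by multiplying with H:
+#
+#   H = compute_homography_sap(480, 640, angle=30, scale=1.2)
+#   corner = np.array([0., 0., 1.])        # top-left corner in pixel coordinates
+#   warped = H @ corner
+#   warped = warped[:2] / warped[2]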
+
+
+def sample_homography_sap(h, w, angle=180, tx=0.25, ty=0.25, scale=2.0, k1=0.1, v0=0.5, v1=0.5):
+ angle = np.random.uniform(-1 * angle, angle)
+ tx = np.random.uniform(-1 * tx, tx)
+ ty = np.random.uniform(-1 * ty, ty)
+ scale = np.random.uniform(1/scale, scale)
+
+ k0 = 1 # similar effects as the ratio of xy-focal lengths
+ k1 = np.random.uniform(-1 * k1, k1)
+
+ v0 = np.random.uniform(-1 * v0, v0)
+ v1 = np.random.uniform(-1 * v1, v1)
+
+ H = compute_homography_sap(h, w, angle, tx, ty, scale, k0, k1, v0, v1)
+ return H
\ No newline at end of file
diff --git a/imcui/third_party/MatchAnything/src/utils/utils.py b/imcui/third_party/MatchAnything/src/utils/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..6f2f3d6f96fbb800d290c83a174f93ead187dcf3
--- /dev/null
+++ b/imcui/third_party/MatchAnything/src/utils/utils.py
@@ -0,0 +1,600 @@
+from pathlib import Path
+import time
+from collections import OrderedDict
+from threading import Thread
+from loguru import logger
+from PIL import Image
+
+import numpy as np
+import cv2
+import torch
+import matplotlib.pyplot as plt
+import matplotlib
+matplotlib.use('Agg')
+
+class AverageTimer:
+ """ Class to help manage printing simple timing of code execution. """
+
+ def __init__(self, smoothing=0.3, newline=False):
+ self.smoothing = smoothing
+ self.newline = newline
+ self.times = OrderedDict()
+ self.will_print = OrderedDict()
+ self.reset()
+
+ def reset(self):
+ now = time.time()
+ self.start = now
+ self.last_time = now
+ for name in self.will_print:
+ self.will_print[name] = False
+
+ def update(self, name='default'):
+ now = time.time()
+ dt = now - self.last_time
+ if name in self.times:
+ dt = self.smoothing * dt + (1 - self.smoothing) * self.times[name]
+ self.times[name] = dt
+ self.will_print[name] = True
+ self.last_time = now
+
+ def print(self, text='Timer'):
+ total = 0.
+ print('[{}]'.format(text), end=' ')
+ for key in self.times:
+ val = self.times[key]
+ if self.will_print[key]:
+ print('%s=%.3f' % (key, val), end=' ')
+ total += val
+ print('total=%.3f sec {%.1f FPS}' % (total, 1./total), end=' ')
+ if self.newline:
+ print(flush=True)
+ else:
+ print(end='\r', flush=True)
+ self.reset()
+
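+# Illustrative usage sketch: time two stages of a processing loop and print a smoothed
+# per-iteration summary (`load_frame` and `match_frame` are placeholder names):
+#
+#   timer = AverageTimer(newline=True)
+#   for frame_id in range(100):
+#       frame = load_frame(frame_id)
+#       timer.update('load')
+#       match_frame(frame)
+#       timer.update('match')
+#       timer.print('demo')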
+
+class VideoStreamer:
+ """ Class to help process image streams. Four types of possible inputs:"
+ 1.) USB Webcam.
+ 2.) An IP camera
+ 3.) A directory of images (files in directory matching 'image_glob').
+ 4.) A video file, such as an .mp4 or .avi file.
+ """
+
+ def __init__(self, basedir, resize, skip, image_glob, max_length=1000000):
+ self._ip_grabbed = False
+ self._ip_running = False
+ self._ip_camera = False
+ self._ip_image = None
+ self._ip_index = 0
+ self.cap = []
+ self.camera = True
+ self.video_file = False
+ self.listing = []
+ self.resize = resize
+ self.interp = cv2.INTER_AREA
+ self.i = 0
+ self.skip = skip
+ self.max_length = max_length
+ if isinstance(basedir, int) or basedir.isdigit():
+ print('==> Processing USB webcam input: {}'.format(basedir))
+ self.cap = cv2.VideoCapture(int(basedir))
+ self.listing = range(0, self.max_length)
+ elif basedir.startswith(('http', 'rtsp')):
+ print('==> Processing IP camera input: {}'.format(basedir))
+ self.cap = cv2.VideoCapture(basedir)
+ self.start_ip_camera_thread()
+ self._ip_camera = True
+ self.listing = range(0, self.max_length)
+ elif Path(basedir).is_dir():
+ print('==> Processing image directory input: {}'.format(basedir))
+ self.listing = list(Path(basedir).glob(image_glob[0]))
+ for j in range(1, len(image_glob)):
+ image_path = list(Path(basedir).glob(image_glob[j]))
+ self.listing = self.listing + image_path
+ self.listing.sort()
+ self.listing = self.listing[::self.skip]
+ self.max_length = np.min([self.max_length, len(self.listing)])
+ if self.max_length == 0:
+ raise IOError('No images found (maybe bad \'image_glob\' ?)')
+ self.listing = self.listing[:self.max_length]
+ self.camera = False
+ elif Path(basedir).exists():
+ print('==> Processing video input: {}'.format(basedir))
+ self.cap = cv2.VideoCapture(basedir)
+ self.cap.set(cv2.CAP_PROP_BUFFERSIZE, 1)
+ num_frames = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT))
+ self.listing = range(0, num_frames)
+ self.listing = self.listing[::self.skip]
+ self.video_file = True
+ self.max_length = np.min([self.max_length, len(self.listing)])
+ self.listing = self.listing[:self.max_length]
+ else:
+ raise ValueError('VideoStreamer input \"{}\" not recognized.'.format(basedir))
+ if self.camera and not self.cap.isOpened():
+ raise IOError('Could not read camera')
+
+ def load_image(self, impath):
+ """ Read image as grayscale and resize to img_size.
+ Inputs
+ impath: Path to input image.
+ Returns
+ grayim: uint8 numpy array sized H x W.
+ """
+ grayim = cv2.imread(impath, 0)
+ if grayim is None:
+ raise Exception('Error reading image %s' % impath)
+ w, h = grayim.shape[1], grayim.shape[0]
+ w_new, h_new = process_resize(w, h, self.resize)
+ grayim = cv2.resize(
+ grayim, (w_new, h_new), interpolation=self.interp)
+ return grayim
+
+ def next_frame(self):
+ """ Return the next frame, and increment internal counter.
+ Returns
+ image: Next H x W image.
+ status: True or False depending whether image was loaded.
+ """
+
+ if self.i == self.max_length:
+ return (None, False)
+ if self.camera:
+
+ if self._ip_camera:
+ # Wait for first image, making sure we haven't exited
+ while self._ip_grabbed is False and self._ip_exited is False:
+ time.sleep(.001)
+
+ ret, image = self._ip_grabbed, self._ip_image.copy()
+ if ret is False:
+ self._ip_running = False
+ else:
+ ret, image = self.cap.read()
+ if ret is False:
+ print('VideoStreamer: Cannot get image from camera')
+ return (None, False)
+ w, h = image.shape[1], image.shape[0]
+ if self.video_file:
+ self.cap.set(cv2.CAP_PROP_POS_FRAMES, self.listing[self.i])
+
+ w_new, h_new = process_resize(w, h, self.resize)
+ image = cv2.resize(image, (w_new, h_new),
+ interpolation=self.interp)
+ image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
+ else:
+ image_file = str(self.listing[self.i])
+ image = self.load_image(image_file)
+ self.i = self.i + 1
+ return (image, True)
+
+ def start_ip_camera_thread(self):
+ self._ip_thread = Thread(target=self.update_ip_camera, args=())
+        self._ip_running = True
+        self._ip_exited = False
+        self._ip_thread.start()
+ return self
+
+ def update_ip_camera(self):
+ while self._ip_running:
+ ret, img = self.cap.read()
+ if ret is False:
+ self._ip_running = False
+ self._ip_exited = True
+ self._ip_grabbed = False
+ return
+
+ self._ip_image = img
+ self._ip_grabbed = ret
+ self._ip_index += 1
+ #print('IPCAMERA THREAD got frame {}'.format(self._ip_index))
+
+ def cleanup(self):
+ self._ip_running = False
+
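+# Illustrative usage sketch: stream grayscale frames from an image directory
+# (the path and globs below are placeholders):
+#
+#   vs = VideoStreamer('assets/frames', resize=[640], skip=1, image_glob=['*.png', '*.jpg'])
+#   while True:
+#       frame, ok = vs.next_frame()
+#       if not ok:
+#           break
+#       # ... process `frame` (H x W uint8) ...
+#   vs.cleanup()
+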
+# --- PREPROCESSING ---
+
+
+def process_resize(w, h, resize):
+ assert(len(resize) > 0 and len(resize) <= 2)
+ if len(resize) == 1 and resize[0] > -1:
+ scale = resize[0] / max(h, w)
+ w_new, h_new = int(round(w*scale)), int(round(h*scale))
+ elif len(resize) == 1 and resize[0] == -1:
+ w_new, h_new = w, h
+ else: # len(resize) == 2:
+ w_new, h_new = resize[0], resize[1]
+
+ # Issue warning if resolution is too small or too large.
+ if max(w_new, h_new) < 160:
+ print('Warning: input resolution is very small, results may vary')
+ elif max(w_new, h_new) > 2000:
+ print('Warning: input resolution is very large, results may vary')
+
+ return w_new, h_new
+
+
+def frame2tensor(frame, device):
+ """ Depth image to tensor
+ """
+ return torch.from_numpy(frame/255.).float()[None, None].to(device)
+
+
+def read_image(path, device, resize, rotation, resize_float):
+ image = cv2.imread(str(path), cv2.IMREAD_GRAYSCALE)
+ if image is None:
+ return None, None, None
+ w, h = image.shape[1], image.shape[0]
+ w_new, h_new = process_resize(w, h, resize)
+ scales = (float(w) / float(w_new), float(h) / float(h_new))
+
+ if resize_float:
+ image = cv2.resize(image.astype('float32'), (w_new, h_new))
+ else:
+ image = cv2.resize(image, (w_new, h_new)).astype('float32')
+
+ if rotation != 0:
+ image = np.rot90(image, k=rotation)
+ if rotation % 2:
+ scales = scales[::-1]
+
+ inp = frame2tensor(image, device)
+ return image, inp, scales
+
+
+# --- GEOMETRY ---
+def estimate_pose(kpts0, kpts1, K0, K1, thresh, conf=0.99999):
+ if len(kpts0) < 5:
+ return None
+
+    f_mean = np.mean([K0[0, 0], K0[1, 1], K1[0, 0], K1[1, 1]])
+ norm_thresh = thresh / f_mean
+
+ kpts0 = (kpts0 - K0[[0, 1], [2, 2]][None]) / K0[[0, 1], [0, 1]][None]
+ kpts1 = (kpts1 - K1[[0, 1], [2, 2]][None]) / K1[[0, 1], [0, 1]][None]
+
+ E, mask = cv2.findEssentialMat(
+ kpts0, kpts1, np.eye(3), threshold=norm_thresh, prob=conf,
+ method=cv2.RANSAC)
+
+ # assert E is not None # might cause unexpected exception in validation step
+ if E is None:
+ print("\nE is None while trying to recover pose.\n")
+ return None
+
+ best_num_inliers = 0
+ ret = None
+    for _E in np.split(E, len(E) // 3):
+ n, R, t, _ = cv2.recoverPose(
+ _E, kpts0, kpts1, np.eye(3), 1e9, mask=mask)
+ if n > best_num_inliers:
+ best_num_inliers = n
+ ret = (R, t[:, 0], mask.ravel() > 0)
+ return ret
+
+
+def estimate_pose_degensac(kpts0, kpts1, K0, K1, thresh, conf=0.9999, max_iters=1000, min_candidates=10):
+ import pydegensac
+    # TODO: Try different `min_candidates`?
+ if len(kpts0) < min_candidates:
+ return None
+
+ F, mask = pydegensac.findFundamentalMatrix(kpts0,
+ kpts1,
+ px_th=thresh,
+ conf=conf,
+ max_iters=max_iters)
+ mask = mask.astype(np.uint8)
+ E = (K1.T @ F @ K0).astype(np.float64)
+
+ kpts0 = (kpts0 - K0[[0, 1], [2, 2]][None]) / K0[[0, 1], [0, 1]][None]
+ kpts1 = (kpts1 - K1[[0, 1], [2, 2]][None]) / K1[[0, 1], [0, 1]][None]
+
+    # This might be optional (since DEGENSAC may handle it internally?)
+ best_num_inliers = 0
+ ret = None
+    for _E in np.split(E, len(E) // 3):
+ n, R, t, _ = cv2.recoverPose(
+ _E, kpts0, kpts1, np.eye(3), 1e9, mask=mask)
+ if n > best_num_inliers:
+ best_num_inliers = n
+ ret = (R, t[:, 0], mask.ravel() > 0)
+
+ return ret
+
+def rotate_intrinsics(K, image_shape, rot):
+ """image_shape is the shape of the image after rotation"""
+ assert rot <= 3
+ h, w = image_shape[:2][::-1 if (rot % 2) else 1]
+ fx, fy, cx, cy = K[0, 0], K[1, 1], K[0, 2], K[1, 2]
+ rot = rot % 4
+ if rot == 1:
+ return np.array([[fy, 0., cy],
+ [0., fx, w-1-cx],
+ [0., 0., 1.]], dtype=K.dtype)
+ elif rot == 2:
+ return np.array([[fx, 0., w-1-cx],
+ [0., fy, h-1-cy],
+ [0., 0., 1.]], dtype=K.dtype)
+ else: # if rot == 3:
+ return np.array([[fy, 0., h-1-cy],
+ [0., fx, cx],
+ [0., 0., 1.]], dtype=K.dtype)
+
+
+def rotate_pose_inplane(i_T_w, rot):
+ rotation_matrices = [
+ np.array([[np.cos(r), -np.sin(r), 0., 0.],
+ [np.sin(r), np.cos(r), 0., 0.],
+ [0., 0., 1., 0.],
+ [0., 0., 0., 1.]], dtype=np.float32)
+ for r in [np.deg2rad(d) for d in (0, 270, 180, 90)]
+ ]
+ return np.dot(rotation_matrices[rot], i_T_w)
+
+
+def scale_intrinsics(K, scales):
+ scales = np.diag([1./scales[0], 1./scales[1], 1.])
+ return np.dot(scales, K)
+
+
+def to_homogeneous(points):
+ return np.concatenate([points, np.ones_like(points[:, :1])], axis=-1)
+
+
+def compute_epipolar_error(kpts0, kpts1, T_0to1, K0, K1, enable_MEinPC=False):
+ """ Comupute the squared symmetric epipolar distance (SED^2).
+ The essential matrix is calculated with the relative pose T_0to1.
+ SED can be seen as a biased estimation of the reprojection error.
+ Args:
+ enable_MEinPC: Mean Error in Pixel Coordinate
+ """
+ kpts0 = (kpts0 - K0[[0, 1], [2, 2]][None]) / K0[[0, 1], [0, 1]][None]
+ kpts1 = (kpts1 - K1[[0, 1], [2, 2]][None]) / K1[[0, 1], [0, 1]][None]
+ kpts0 = to_homogeneous(kpts0)
+ kpts1 = to_homogeneous(kpts1)
+
+ t0, t1, t2 = T_0to1[:3, 3]
+ t_skew = np.array([
+ [0, -t2, t1],
+ [t2, 0, -t0],
+ [-t1, t0, 0]
+ ])
+ E = t_skew @ T_0to1[:3, :3]
+
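+    # Squared symmetric epipolar distance in normalized camera coordinates:
+    #   SED^2 = (x1^T E x0)^2 * (1 / ((E x0)_1^2 + (E x0)_2^2) + 1 / ((E^T x1)_1^2 + (E^T x1)_2^2))
+    # (the MEinPC variant below rescales the point-to-line distances back to pixel units).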
+ Ep0 = kpts0 @ E.T # N x 3
+ p1Ep0 = np.sum(kpts1 * Ep0, -1) # N
+ Etp1 = kpts1 @ E # N x 3
+ if enable_MEinPC:
+ d = 0.5 * np.abs(p1Ep0) * (np.linalg.norm([K1[0, 0], K1[1, 1]]) / np.linalg.norm([Ep0[:, 0], Ep0[:, 1]], axis=0)
+ + np.linalg.norm([K0[0, 0], K0[1, 1]]) / np.linalg.norm([Etp1[:, 0], Etp1[:, 1]], axis=0)) # N
+ else:
+ d = p1Ep0**2 * (1.0 / (Ep0[:, 0]**2 + Ep0[:, 1]**2)
+ + 1.0 / (Etp1[:, 0]**2 + Etp1[:, 1]**2)) # N
+ return d
+
+
+def compute_homogeneous_error(kpts0, kpts1, H):
+ """ warp kpts0 to img1, compute error with kpts1
+ """
+ kpts0 = to_homogeneous(kpts0)
+
+ w_kpts0 = kpts0 @ H.T # N x 3
+ w_kpts0 = w_kpts0[:, :2] / w_kpts0[:, [2]]
+
+ d = np.linalg.norm(w_kpts0 - kpts1, axis=1)
+ return d
+
+
+def angle_error_mat(R1, R2):
+ cos = (np.trace(np.dot(R1.T, R2)) - 1) / 2
+    cos = np.clip(cos, -1., 1.)  # numerical errors can make it out of bounds
+ return np.rad2deg(np.abs(np.arccos(cos)))
+
+
+def angle_error_vec(v1, v2):
+ n = np.linalg.norm(v1) * np.linalg.norm(v2)
+ return np.rad2deg(np.arccos(np.clip(np.dot(v1, v2) / n, -1.0, 1.0)))
+
+
+def compute_pose_error(T_0to1, R, t, ignore_gt_t_thr=0.0):
+ R_gt = T_0to1[:3, :3]
+ t_gt = T_0to1[:3, 3]
+ error_t = angle_error_vec(t, t_gt)
+ error_t = np.minimum(error_t, 180 - error_t) # ambiguity of E estimation
+ error_R = angle_error_mat(R, R_gt)
+ if np.linalg.norm(t_gt) < ignore_gt_t_thr: # NOTE: as a close-to-zero translation is not good for angle_error calculation
+ error_t = 0
+ return error_t, error_R
+
+def convert_gt_T(T_0to1):
+    gt_R_degree = angle_error_mat(T_0to1[:3, :3], np.eye(3))
+    gt_t_dist = np.linalg.norm(T_0to1[:3, 3])
+ return gt_t_dist, gt_R_degree
+
+def pose_auc(errors, thresholds, ret_dict=False):
+ sort_idx = np.argsort(errors)
+ errors = np.array(errors.copy())[sort_idx]
+ recall = (np.arange(len(errors)) + 1) / len(errors)
+ errors = np.r_[0., errors]
+ recall = np.r_[0., recall]
+ aucs = []
+ for t in thresholds:
+ last_index = np.searchsorted(errors, t)
+ r = np.r_[recall[:last_index], recall[last_index-1]]
+ e = np.r_[errors[:last_index], t]
+ aucs.append(np.trapz(r, x=e)/t)
+ if ret_dict:
+ return {f'auc@{t}': auc for t, auc in zip(thresholds, aucs)}
+ else:
+ return aucs
+
+
+def epidist_prec(errors, thresholds, ret_dict=False):
+ precs = []
+ for thr in thresholds:
+ prec_ = []
+ for errs in errors:
+ correct_mask = errs < thr
+ prec_.append(np.mean(correct_mask) if len(correct_mask) > 0 else 0)
+ precs.append(np.mean(prec_) if len(prec_) > 0 else 0)
+ if ret_dict:
+ return {f'prec@{t:.0e}': prec for t, prec in zip(thresholds, precs)}
+ else:
+ return precs
+
+# --- VISUALIZATION ---
+def plot_image_pair(imgs, dpi=100, size=6, pad=.5):
+ n = len(imgs)
+ assert n == 2, 'number of images must be two'
+ figsize = (size*n, size*3/4) if size is not None else None
+ _, ax = plt.subplots(1, n, figsize=figsize, dpi=dpi)
+ for i in range(n):
+ ax[i].imshow(imgs[i], cmap=plt.get_cmap('gray'), vmin=0, vmax=255)
+ ax[i].get_yaxis().set_ticks([])
+ ax[i].get_xaxis().set_ticks([])
+ for spine in ax[i].spines.values(): # remove frame
+ spine.set_visible(False)
+ plt.tight_layout(pad=pad)
+
+
+def plot_keypoints(kpts0, kpts1, color='w', ps=2):
+ ax = plt.gcf().axes
+ ax[0].scatter(kpts0[:, 0], kpts0[:, 1], c=color, s=ps)
+ ax[1].scatter(kpts1[:, 0], kpts1[:, 1], c=color, s=ps)
+
+
+def plot_matches(kpts0, kpts1, color, lw=1.5, ps=4):
+ fig = plt.gcf()
+ ax = fig.axes
+ fig.canvas.draw()
+
+ transFigure = fig.transFigure.inverted()
+ fkpts0 = transFigure.transform(ax[0].transData.transform(kpts0))
+ fkpts1 = transFigure.transform(ax[1].transData.transform(kpts1))
+
+ fig.lines = [matplotlib.lines.Line2D(
+ (fkpts0[i, 0], fkpts1[i, 0]), (fkpts0[i, 1], fkpts1[i, 1]), zorder=1,
+ transform=fig.transFigure, c=color[i], linewidth=lw)
+ for i in range(len(kpts0))]
+ ax[0].scatter(kpts0[:, 0], kpts0[:, 1], c=color, s=ps)
+ ax[1].scatter(kpts1[:, 0], kpts1[:, 1], c=color, s=ps)
+
+
+def make_matching_plot(image0, image1, kpts0, kpts1, mkpts0, mkpts1,
+ color, text, path=None, show_keypoints=False,
+ fast_viz=False, opencv_display=False,
+ opencv_title='matches', small_text=[]):
+
+ if fast_viz:
+ make_matching_plot_fast(image0, image1, kpts0, kpts1, mkpts0, mkpts1,
+ color, text, path, show_keypoints, 10,
+ opencv_display, opencv_title, small_text)
+ return
+
+ plot_image_pair([image0, image1]) # will create a new figure
+ if show_keypoints:
+ plot_keypoints(kpts0, kpts1, color='k', ps=4)
+ plot_keypoints(kpts0, kpts1, color='w', ps=2)
+ plot_matches(mkpts0, mkpts1, color)
+
+ fig = plt.gcf()
+ txt_color = 'k' if image0[:100, :150].mean() > 200 else 'w'
+ fig.text(
+ 0.01, 0.99, '\n'.join(text), transform=fig.axes[0].transAxes,
+ fontsize=15, va='top', ha='left', color=txt_color)
+
+ txt_color = 'k' if image0[-100:, :150].mean() > 200 else 'w'
+ fig.text(
+ 0.01, 0.01, '\n'.join(small_text), transform=fig.axes[0].transAxes,
+ fontsize=5, va='bottom', ha='left', color=txt_color)
+ if path:
+ plt.savefig(str(path), bbox_inches='tight', pad_inches=0)
+ plt.close()
+ else:
+        # TODO: Check whether returning the figure without closing it causes any issues.
+ return fig
+
+
+def make_matching_plot_fast(image0, image1, kpts0, kpts1, mkpts0,
+ mkpts1, color, text, path=None,
+ show_keypoints=False, margin=10,
+ opencv_display=False, opencv_title='',
+ small_text=[]):
+ H0, W0 = image0.shape
+ H1, W1 = image1.shape
+ H, W = max(H0, H1), W0 + W1 + margin
+
+ out = 255*np.ones((H, W), np.uint8)
+ out[:H0, :W0] = image0
+ out[:H1, W0+margin:] = image1
+ out = np.stack([out]*3, -1)
+
+ if show_keypoints:
+ kpts0, kpts1 = np.round(kpts0).astype(int), np.round(kpts1).astype(int)
+ white = (255, 255, 255)
+ black = (0, 0, 0)
+ for x, y in kpts0:
+ cv2.circle(out, (x, y), 2, black, -1, lineType=cv2.LINE_AA)
+ cv2.circle(out, (x, y), 1, white, -1, lineType=cv2.LINE_AA)
+ for x, y in kpts1:
+ cv2.circle(out, (x + margin + W0, y), 2, black, -1,
+ lineType=cv2.LINE_AA)
+ cv2.circle(out, (x + margin + W0, y), 1, white, -1,
+ lineType=cv2.LINE_AA)
+
+ mkpts0, mkpts1 = np.round(mkpts0).astype(int), np.round(mkpts1).astype(int)
+ color = (np.array(color[:, :3])*255).astype(int)[:, ::-1]
+ for (x0, y0), (x1, y1), c in zip(mkpts0, mkpts1, color):
+ c = c.tolist()
+ cv2.line(out, (x0, y0), (x1 + margin + W0, y1),
+ color=c, thickness=1, lineType=cv2.LINE_AA)
+ # display line end-points as circles
+ cv2.circle(out, (x0, y0), 2, c, -1, lineType=cv2.LINE_AA)
+ cv2.circle(out, (x1 + margin + W0, y1), 2, c, -1,
+ lineType=cv2.LINE_AA)
+
+ # Scale factor for consistent visualization across scales.
+ sc = min(H / 640., 2.0)
+
+ # Big text.
+ Ht = int(30 * sc) # text height
+ txt_color_fg = (255, 255, 255)
+ txt_color_bg = (0, 0, 0)
+ for i, t in enumerate(text):
+ cv2.putText(out, t, (int(8*sc), Ht*(i+1)), cv2.FONT_HERSHEY_DUPLEX,
+ 1.0*sc, txt_color_bg, 2, cv2.LINE_AA)
+ cv2.putText(out, t, (int(8*sc), Ht*(i+1)), cv2.FONT_HERSHEY_DUPLEX,
+ 1.0*sc, txt_color_fg, 1, cv2.LINE_AA)
+
+ # Small text.
+ Ht = int(18 * sc) # text height
+ for i, t in enumerate(reversed(small_text)):
+ cv2.putText(out, t, (int(8*sc), int(H-Ht*(i+.6))), cv2.FONT_HERSHEY_DUPLEX,
+ 0.5*sc, txt_color_bg, 2, cv2.LINE_AA)
+ cv2.putText(out, t, (int(8*sc), int(H-Ht*(i+.6))), cv2.FONT_HERSHEY_DUPLEX,
+ 0.5*sc, txt_color_fg, 1, cv2.LINE_AA)
+
+ if path is not None:
+ cv2.imwrite(str(path), out)
+
+ if opencv_display:
+ cv2.imshow(opencv_title, out)
+ cv2.waitKey(1)
+
+ return out
+
+
+def error_colormap(x, alpha=1.0):
+    assert 0 < alpha <= 1.0, f"Invalid alpha value: {alpha}"
+ return np.clip(
+ np.stack([2-x*2, x*2, np.zeros_like(x), np.ones_like(x)*alpha], -1), 0, 1)
+
+def check_img_ok(img_path):
+ img_ok = True
+ try:
+ Image.open(str(img_path)).convert('RGB')
+    except Exception:
+ img_ok = False
+ return img_ok
\ No newline at end of file
diff --git a/imcui/third_party/MatchAnything/third_party/ROMA/.gitignore b/imcui/third_party/MatchAnything/third_party/ROMA/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..ff4633046059da69ab4b6e222909614ccda82ac4
--- /dev/null
+++ b/imcui/third_party/MatchAnything/third_party/ROMA/.gitignore
@@ -0,0 +1,5 @@
+*.egg-info*
+*.vscode*
+*__pycache__*
+vis*
+workspace*
\ No newline at end of file
diff --git a/imcui/third_party/MatchAnything/third_party/ROMA/LICENSE b/imcui/third_party/MatchAnything/third_party/ROMA/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..ca95157052a76debc473afb395bffae0c1329e63
--- /dev/null
+++ b/imcui/third_party/MatchAnything/third_party/ROMA/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2023 Johan Edstedt
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/imcui/third_party/MatchAnything/third_party/ROMA/README.md b/imcui/third_party/MatchAnything/third_party/ROMA/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..284d8f0bea84d7f67a416bc933067a3acfe23740
--- /dev/null
+++ b/imcui/third_party/MatchAnything/third_party/ROMA/README.md
@@ -0,0 +1,82 @@
+#
+