diff --git a/imcui/third_party/MatchAnything/LICENSE b/imcui/third_party/MatchAnything/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..d645695673349e3947e8e5ae42332d0ac3164cd7 --- /dev/null +++ b/imcui/third_party/MatchAnything/LICENSE @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). 
+ + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. 
Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative 
Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. 
Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. 
+ + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/imcui/third_party/MatchAnything/README.md b/imcui/third_party/MatchAnything/README.md new file mode 100644 index 0000000000000000000000000000000000000000..ceb1c2a3152cbf3c8d7e42190e54040efa38d23d --- /dev/null +++ b/imcui/third_party/MatchAnything/README.md @@ -0,0 +1,104 @@ +# MatchAnything: Universal Cross-Modality Image Matching with Large-Scale Pre-Training +### [Project Page](https://zju3dv.github.io/MatchAnything) | [Paper](??) 
+ +> MatchAnything: Universal Cross-Modality Image Matching with Large-Scale Pre-Training\ +> [Xingyi He](https://hxy-123.github.io/), +[Hao Yu](https://ritianyu.github.io/), +[Sida Peng](https://pengsida.net), +[Dongli Tan](https://github.com/Cuistiano), +[Zehong Shen](https://zehongs.github.io), +[Xiaowei Zhou](https://xzhou.me/), +[Hujun Bao](http://www.cad.zju.edu.cn/home/bao/)\ +> Arxiv 2025 + +

+ animated +

+ +## TODO List +- [x] Pre-trained models and inference code +- [x] Huggingface demo +- [ ] Data generation and training code +- [ ] Finetune code to further train on your own data +- [ ] Incorporate more synthetic modalities and image generation methods + +## Quick Start + +### [ HuggingFace demo for MatchAnything](https://huggingface.co/spaces/LittleFrog/MatchAnything) + +## Setup +Create the python environment by: +``` +conda env create -f environment.yaml +conda activate env +``` +We have tested our code on a device with CUDA 11.7. + +Download pretrained weights from [here](https://drive.google.com/file/d/12L3g9-w8rR9K2L4rYaGaDJ7NqX1D713d/view?usp=sharing) and place it under the repo directory. Then unzip it by running the following command: +``` +unzip weights.zip +rm -rf weights.zip +``` + +## Test: +We evaluate the models pretrained by our framework using a single network weight on all cross-modality matching and registration tasks. + +### Data Preparing +Download the `test_data` directory from [here](https://drive.google.com/drive/folders/1jpxIOcgnQfl9IEPPifdXQ7S7xuj9K4j7?usp=sharing) and place it under `repo_directory/data`. 
Then, unzip all datasets by: +```shell +cd repo_directory/data/test_data + +for file in *.zip; do + unzip "$file" && rm "$file" +done +``` + +The data structure should look like: +``` +repo_directory/data/test_data + - Liver_CT-MR + - havard_medical_matching + - remote_sense_thermal + - MTV_cross_modal_data + - thermal_visible_ground + - visible_sar_dataset + - visible_vectorized_map +``` + +### Evaluation +```shell +# For Tomography datasets: +sh scripts/evaluate/eval_liver_ct_mr.sh +sh scripts/evaluate/eval_harvard_brain.sh + + + +# For visible-thermal datasets: +sh scripts/evaluate/eval_thermal_remote_sense.sh +sh scripts/evaluate/eval_thermal_mtv.sh +sh scripts/evaluate/eval_thermal_ground.sh + +# For visible-sar dataset: +sh scripts/evaluate/eval_visible_sar.sh + +# For visible-vectorized map dataset: +sh scripts/evaluate/eval_visible_vectorized_map.sh +``` + +# Citation + +If you find this code useful for your research, please use the following BibTeX entry. + +``` +@inproceedings{he2025matchanything, +title={MatchAnything: Universal Cross-Modality Image Matching with Large-Scale Pre-Training}, +author={He, Xingyi and Yu, Hao and Peng, Sida and Tan, Dongli and Shen, Zehong and Bao, Hujun and Zhou, Xiaowei}, +booktitle={Arxiv}, +year={2025} +} +``` + +# Acknowledgement +We thank the authors of +[ELoFTR](https://github.com/zju3dv/EfficientLoFTR), +[ROMA](https://github.com/Parskatt/RoMa) for their great work, without which our project/code would not be possible. 
\ No newline at end of file diff --git a/imcui/third_party/MatchAnything/configs/models/eloftr_model.py b/imcui/third_party/MatchAnything/configs/models/eloftr_model.py new file mode 100644 index 0000000000000000000000000000000000000000..abbc030bb4181ea1d33d06f90797880ae03a18da --- /dev/null +++ b/imcui/third_party/MatchAnything/configs/models/eloftr_model.py @@ -0,0 +1,128 @@ +from src.config.default import _CN as cfg + +cfg.LOFTR.MATCH_COARSE.MATCH_TYPE = 'dual_softmax' +cfg.LOFTR.MATCH_COARSE.SPARSE_SPVS = False + +cfg.TRAINER.CANONICAL_LR = 8e-3 +cfg.TRAINER.WARMUP_STEP = 1875 # 3 epochs +cfg.TRAINER.WARMUP_RATIO = 0.1 + +cfg.TRAINER.MSLR_MILESTONES = [4, 6, 8, 10, 12, 14, 16] + +# pose estimation +cfg.TRAINER.RANSAC_PIXEL_THR = 0.5 + +cfg.TRAINER.OPTIMIZER = "adamw" +cfg.TRAINER.ADAMW_DECAY = 0.1 + +cfg.LOFTR.MATCH_COARSE.TRAIN_COARSE_PERCENT = 0.1 + +cfg.LOFTR.MATCH_COARSE.MTD_SPVS = True +cfg.LOFTR.FINE.MTD_SPVS = True + +cfg.LOFTR.RESOLUTION = (8, 1) # options: [(8, 2), (16, 4)] +cfg.LOFTR.FINE_WINDOW_SIZE = 8 # window_size in fine_level, must be odd +cfg.LOFTR.MATCH_FINE.THR = 0 +cfg.LOFTR.LOSS.FINE_TYPE = 'l2' # ['l2_with_std', 'l2'] + +cfg.TRAINER.EPI_ERR_THR = 5e-4 # recommendation: 5e-4 for ScanNet, 1e-4 for MegaDepth (from SuperGlue) + +cfg.LOFTR.MATCH_COARSE.SPARSE_SPVS = True + +# PAN +cfg.LOFTR.COARSE.PAN = True +cfg.LOFTR.COARSE.POOl_SIZE = 4 +cfg.LOFTR.COARSE.BN = False +cfg.LOFTR.COARSE.XFORMER = True +cfg.LOFTR.COARSE.ATTENTION = 'full' # options: ['linear', 'full'] + +cfg.LOFTR.FINE.PAN = False +cfg.LOFTR.FINE.POOl_SIZE = 4 +cfg.LOFTR.FINE.BN = False +cfg.LOFTR.FINE.XFORMER = False + +# noalign +cfg.LOFTR.ALIGN_CORNER = False + +# fp16 +cfg.DATASET.FP16 = False +cfg.LOFTR.FP16 = False + +# DEBUG +cfg.LOFTR.FP16LOG = False +cfg.LOFTR.MATCH_COARSE.FP16LOG = False + +# fine skip +cfg.LOFTR.FINE.SKIP = True + +# clip +cfg.TRAINER.GRADIENT_CLIPPING = 0.5 + +# backbone +cfg.LOFTR.BACKBONE_TYPE = 'RepVGG' + +# A1 +cfg.LOFTR.RESNETFPN.INITIAL_DIM = 64 
+cfg.LOFTR.RESNETFPN.BLOCK_DIMS = [64, 128, 256] # s1, s2, s3 +cfg.LOFTR.COARSE.D_MODEL = 256 +cfg.LOFTR.FINE.D_MODEL = 64 + +# FPN backbone_inter_feat with coarse_attn. +cfg.LOFTR.COARSE_FEAT_ONLY = True +cfg.LOFTR.INTER_FEAT = True +cfg.LOFTR.RESNETFPN.COARSE_FEAT_ONLY = True +cfg.LOFTR.RESNETFPN.INTER_FEAT = True + +# loop back spv coarse match +cfg.LOFTR.FORCE_LOOP_BACK = False + +# fix norm fine match +cfg.LOFTR.MATCH_FINE.NORMFINEM = True + +# loss cf weight +cfg.LOFTR.LOSS.COARSE_OVERLAP_WEIGHT = True +cfg.LOFTR.LOSS.FINE_OVERLAP_WEIGHT = True + +# leaky relu +cfg.LOFTR.RESNETFPN.LEAKY = False +cfg.LOFTR.COARSE.LEAKY = 0.01 + +# prevent FP16 OVERFLOW in dirty data +cfg.LOFTR.NORM_FPNFEAT = True +cfg.LOFTR.REPLACE_NAN = True + +# force mutual nearest +cfg.LOFTR.MATCH_COARSE.FORCE_NEAREST = True +cfg.LOFTR.MATCH_COARSE.THR = 0.1 + +# fix fine matching +cfg.LOFTR.MATCH_FINE.FIX_FINE_MATCHING = True + +# dwconv +cfg.LOFTR.COARSE.DWCONV = True + +# localreg +cfg.LOFTR.MATCH_FINE.LOCAL_REGRESS = True +cfg.LOFTR.LOSS.LOCAL_WEIGHT = 0.25 + +# it5 +cfg.LOFTR.EVAL_TIMES = 1 + +# rope +cfg.LOFTR.COARSE.ROPE = True + +# local regress temperature +cfg.LOFTR.MATCH_FINE.LOCAL_REGRESS_TEMPERATURE = 10.0 + +# SLICE +cfg.LOFTR.MATCH_FINE.LOCAL_REGRESS_SLICE = True +cfg.LOFTR.MATCH_FINE.LOCAL_REGRESS_SLICEDIM = 8 + +# inner with no mask [64,100] +cfg.LOFTR.MATCH_FINE.LOCAL_REGRESS_INNER = True +cfg.LOFTR.MATCH_FINE.LOCAL_REGRESS_NOMASK = True + +cfg.LOFTR.MATCH_FINE.TOPK = 1 +cfg.LOFTR.MATCH_COARSE.FINE_TOPK = 1 + +cfg.LOFTR.MATCH_COARSE.FP16MATMUL = False \ No newline at end of file diff --git a/imcui/third_party/MatchAnything/configs/models/roma_model.py b/imcui/third_party/MatchAnything/configs/models/roma_model.py new file mode 100644 index 0000000000000000000000000000000000000000..207ab322242234242b5e558e962f57dce91d0566 --- /dev/null +++ b/imcui/third_party/MatchAnything/configs/models/roma_model.py @@ -0,0 +1,27 @@ +from src.config.default import _CN as cfg 
+cfg.ROMA.RESIZE_BY_STRETCH = True +cfg.DATASET.RESIZE_BY_STRETCH = True + +cfg.TRAINER.CANONICAL_LR = 8e-3 +cfg.TRAINER.WARMUP_STEP = 1875 # 3 epochs +cfg.TRAINER.WARMUP_RATIO = 0.1 + +cfg.TRAINER.MSLR_MILESTONES = [4, 6, 8, 10, 12, 14, 16, 18, 20] + +# pose estimation +cfg.TRAINER.RANSAC_PIXEL_THR = 0.5 + +cfg.TRAINER.OPTIMIZER = "adamw" +cfg.TRAINER.ADAMW_DECAY = 0.1 +cfg.TRAINER.OPTIMIZER_EPS = 5e-7 + +cfg.TRAINER.EPI_ERR_THR = 5e-4 + +# fp16 +cfg.DATASET.FP16 = False +cfg.LOFTR.FP16 = True + +# clip +cfg.TRAINER.GRADIENT_CLIPPING = 0.5 + +cfg.LOFTR.ROMA_LOSS.IGNORE_EMPTY_IN_SPARSE_MATCH_SPV = True \ No newline at end of file diff --git a/imcui/third_party/MatchAnything/environment.yaml b/imcui/third_party/MatchAnything/environment.yaml new file mode 100644 index 0000000000000000000000000000000000000000..0d61b8e99932550f4c9c3f3a9e7e58e0b9bf68b4 --- /dev/null +++ b/imcui/third_party/MatchAnything/environment.yaml @@ -0,0 +1,14 @@ +name: env +channels: + - pytorch + - nvidia + - conda-forge + - defaults +dependencies: + - python=3.8 + - pytorch-cuda=11.7 + - pytorch=1.12.1 + - torchvision=0.13.1 + - pip + - pip: + - -r requirements.txt \ No newline at end of file diff --git a/imcui/third_party/MatchAnything/notebooks/notebooks_utils/__init__.py b/imcui/third_party/MatchAnything/notebooks/notebooks_utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..81512278dfdfa73dd0915defa732b3b0e7db6af6 --- /dev/null +++ b/imcui/third_party/MatchAnything/notebooks/notebooks_utils/__init__.py @@ -0,0 +1 @@ +from .plotting import * \ No newline at end of file diff --git a/imcui/third_party/MatchAnything/notebooks/notebooks_utils/plotting.py b/imcui/third_party/MatchAnything/notebooks/notebooks_utils/plotting.py new file mode 100644 index 0000000000000000000000000000000000000000..9993ed5b989d8babc87969e34d153ba3bcc05e1f --- /dev/null +++ b/imcui/third_party/MatchAnything/notebooks/notebooks_utils/plotting.py @@ -0,0 +1,344 @@ +import numpy as np 
+import matplotlib.pyplot as plt +import matplotlib +from matplotlib.colors import hsv_to_rgb +import pylab as pl +import matplotlib.cm as cm +from PIL import Image +import cv2 + + +def visualize_features(feat, img_h, img_w, save_path=None): + from sklearn.decomposition import PCA + pca = PCA(n_components=3, svd_solver="arpack") + img = pca.fit_transform(feat).reshape(img_h * 2, img_w, 3) + img_norm = cv2.normalize( + img, None, alpha=0, beta=255, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_8UC3 + ) + img_resized = cv2.resize( + img_norm, (img_w * 8, img_h * 2 * 8), interpolation=cv2.INTER_LINEAR + ) + img_colormap = img_resized + img1, img2 = img_colormap[: img_h * 8, :, :], img_colormap[img_h * 8 :, :, :] + img_gapped = np.hstack( + (img1, np.ones((img_h * 8, 10, 3), dtype=np.uint8) * 255, img2) + ) + if save_path is not None: + cv2.imwrite(save_path, img_gapped) + + fig, axes = plt.subplots(1, 1, dpi=200) + axes.imshow(img_gapped) + axes.get_yaxis().set_ticks([]) + axes.get_xaxis().set_ticks([]) + plt.tight_layout(pad=0.5) + return fig + +def make_matching_figure( + img0, + img1, + mkpts0, + mkpts1, + color, + kpts0=None, + kpts1=None, + text=[], + path=None, + draw_detection=False, + draw_match_type='corres', # ['color', 'corres', None] + r_normalize_factor=0.4, + white_center=True, + vertical=False, + use_position_color=False, + draw_local_window=False, + window_size=(9, 9), + plot_size_factor=1, # Point size and line width + anchor_pts0=None, + anchor_pts1=None, + rescale_thr=5000, +): + if (max(img0.shape) > rescale_thr) or (max(img1.shape) > rescale_thr): + scale_factor = 0.5 + img0 = np.array(Image.fromarray((img0 * 255).astype(np.uint8)).resize((int(img0.shape[1] * scale_factor), int(img0.shape[0] * scale_factor)))) / 255. + img1 = np.array(Image.fromarray((img1 * 255).astype(np.uint8)).resize((int(img1.shape[1] * scale_factor), int(img1.shape[0] * scale_factor)))) / 255. 
+ mkpts0, mkpts1 = mkpts0 * scale_factor, mkpts1 * scale_factor + if kpts0 is not None: + kpts0, kpts1 = kpts0 * scale_factor, kpts1 * scale_factor + + # draw image pair + fig, axes = ( + plt.subplots(2, 1, figsize=(10, 6), dpi=600) + if vertical + else plt.subplots(1, 2, figsize=(10, 6), dpi=600) + ) + axes[0].imshow(img0, aspect='auto') + axes[1].imshow(img1, aspect='auto') + + # axes[0].imshow(img0, aspect='equal') + # axes[1].imshow(img1, aspect='equal') + for i in range(2): # clear all frames + axes[i].get_yaxis().set_ticks([]) + axes[i].get_xaxis().set_ticks([]) + for spine in axes[i].spines.values(): + spine.set_visible(False) + plt.tight_layout(pad=1) + + if use_position_color: + mean_coord = np.mean(mkpts0, axis=0) + x_center, y_center = mean_coord + # NOTE: set r_normalize_factor to a smaller number will make plotted figure more contrastive. + position_color = matching_coord2color( + mkpts0, + x_center, + y_center, + r_normalize_factor=r_normalize_factor, + white_center=white_center, + ) + color[:, :3] = position_color + + if draw_detection and kpts0 is not None and kpts1 is not None: + # color = 'g' + color = 'r' + axes[0].scatter(kpts0[:, 0], kpts0[:, 1], c=color, s=1 * plot_size_factor) + axes[1].scatter(kpts1[:, 0], kpts1[:, 1], c=color, s=1 * plot_size_factor) + + if draw_match_type is 'corres': + # draw matches + fig.canvas.draw() + plt.pause(2.0) + transFigure = fig.transFigure.inverted() + fkpts0 = transFigure.transform(axes[0].transData.transform(mkpts0)) + fkpts1 = transFigure.transform(axes[1].transData.transform(mkpts1)) + fig.lines = [ + matplotlib.lines.Line2D( + (fkpts0[i, 0], fkpts1[i, 0]), + (fkpts0[i, 1], fkpts1[i, 1]), + transform=fig.transFigure, + c=color[i], + linewidth=1* plot_size_factor, + ) + for i in range(len(mkpts0)) + ] + + axes[0].scatter(mkpts0[:, 0], mkpts0[:, 1], c=color, s=2* plot_size_factor) + axes[1].scatter(mkpts1[:, 0], mkpts1[:, 1], c=color, s=2* plot_size_factor) + elif draw_match_type is 'color': + # x_center = 
img0.shape[-1] / 2 + # y_center = img1.shape[-2] / 2 + + mean_coord = np.mean(mkpts0, axis=0) + x_center, y_center = mean_coord + # NOTE: set r_normalize_factor to a smaller number will make plotted figure more contrastive. + kpts_color = matching_coord2color( + mkpts0, + x_center, + y_center, + r_normalize_factor=r_normalize_factor, + white_center=white_center, + ) + axes[0].scatter(mkpts0[:, 0], mkpts0[:, 1], c=kpts_color, s=1 * plot_size_factor) + axes[1].scatter(mkpts1[:, 0], mkpts1[:, 1], c=kpts_color, s=1 * plot_size_factor) + + if draw_local_window: + anchor_pts0 = mkpts0 if anchor_pts0 is None else anchor_pts0 + anchor_pts1 = mkpts1 if anchor_pts1 is None else anchor_pts1 + plot_local_windows( + anchor_pts0, color=(1, 0, 0, 0.4), lw=0.2, ax_=0, window_size=window_size + ) + plot_local_windows( + anchor_pts1, color=(1, 0, 0, 0.4), lw=0.2, ax_=1, window_size=window_size + ) # lw =0.2 + + # put txts + txt_color = "k" if img0[:100, :200].mean() > 200 else "w" + fig.text( + 0.01, + 0.99, + "\n".join(text), + transform=fig.axes[0].transAxes, + fontsize=15, + va="top", + ha="left", + color=txt_color, + ) + plt.tight_layout(pad=1) + + # save or return figure + if path: + plt.savefig(str(path), bbox_inches="tight", pad_inches=0) + plt.close() + else: + return fig + +def make_triple_matching_figure( + img0, + img1, + img2, + mkpts01, + mkpts12, + color01, + color12, + text=[], + path=None, + draw_match=True, + r_normalize_factor=0.4, + white_center=True, + vertical=False, + draw_local_window=False, + window_size=(9, 9), + anchor_pts0=None, + anchor_pts1=None, +): + # draw image pair + fig, axes = ( + plt.subplots(3, 1, figsize=(10, 6), dpi=600) + if vertical + else plt.subplots(1, 3, figsize=(10, 6), dpi=600) + ) + axes[0].imshow(img0) + axes[1].imshow(img1) + axes[2].imshow(img2) + for i in range(3): # clear all frames + axes[i].get_yaxis().set_ticks([]) + axes[i].get_xaxis().set_ticks([]) + for spine in axes[i].spines.values(): + spine.set_visible(False) + 
plt.tight_layout(pad=1) + + if draw_match: + # draw matches for [0,1] + fig.canvas.draw() + transFigure = fig.transFigure.inverted() + fkpts0 = transFigure.transform(axes[0].transData.transform(mkpts01[0])) + fkpts1 = transFigure.transform(axes[1].transData.transform(mkpts01[1])) + fig.lines = [ + matplotlib.lines.Line2D( + (fkpts0[i, 0], fkpts1[i, 0]), + (fkpts0[i, 1], fkpts1[i, 1]), + transform=fig.transFigure, + c=color01[i], + linewidth=1, + ) + for i in range(len(mkpts01[0])) + ] + + axes[0].scatter(mkpts01[0][:, 0], mkpts01[0][:, 1], c=color01[:, :3], s=1) + axes[1].scatter(mkpts01[1][:, 0], mkpts01[1][:, 1], c=color01[:, :3], s=1) + + fig.canvas.draw() + # draw matches for [1,2] + fkpts1_1 = transFigure.transform(axes[1].transData.transform(mkpts12[0])) + fkpts2 = transFigure.transform(axes[2].transData.transform(mkpts12[1])) + fig.lines += [ + matplotlib.lines.Line2D( + (fkpts1_1[i, 0], fkpts2[i, 0]), + (fkpts1_1[i, 1], fkpts2[i, 1]), + transform=fig.transFigure, + c=color12[i], + linewidth=1, + ) + for i in range(len(mkpts12[0])) + ] + + axes[1].scatter(mkpts12[0][:, 0], mkpts12[0][:, 1], c=color12[:, :3], s=1) + axes[2].scatter(mkpts12[1][:, 0], mkpts12[1][:, 1], c=color12[:, :3], s=1) + + # # put txts + # txt_color = "k" if img0[:100, :200].mean() > 200 else "w" + # fig.text( + # 0.01, + # 0.99, + # "\n".join(text), + # transform=fig.axes[0].transAxes, + # fontsize=15, + # va="top", + # ha="left", + # color=txt_color, + # ) + plt.tight_layout(pad=0.1) + + # save or return figure + if path: + plt.savefig(str(path), bbox_inches="tight", pad_inches=0) + plt.close() + else: + return fig + + +def matching_coord2color(kpts, x_center, y_center, r_normalize_factor=0.4, white_center=True): + """ + r_normalize_factor is used to visualize clearer according to points space distribution + r_normalize_factor maxium=1, larger->points darker/brighter + """ + if not white_center: + # dark center points + V, H = np.mgrid[0:1:10j, 0:1:360j] + S = np.ones_like(V) + else: + 
# white center points + S, H = np.mgrid[0:1:10j, 0:1:360j] + V = np.ones_like(S) + + HSV = np.dstack((H, S, V)) + RGB = hsv_to_rgb(HSV) + """ + # used to visualize hsv + pl.imshow(RGB, origin="lower", extent=[0, 360, 0, 1], aspect=150) + pl.xlabel("H") + pl.ylabel("S") + pl.title("$V_{HSV}=1$") + pl.show() + """ + kpts = np.copy(kpts) + distance = kpts - np.array([x_center, y_center])[None] + r_max = np.percentile(np.linalg.norm(distance, axis=1), 85) + # r_max = np.sqrt((x_center) ** 2 + (y_center) ** 2) + kpts[:, 0] = kpts[:, 0] - x_center # x + kpts[:, 1] = kpts[:, 1] - y_center # y + + r = np.sqrt(kpts[:, 0] ** 2 + kpts[:, 1] ** 2) + 1e-6 + r_normalized = r / (r_max * r_normalize_factor) + r_normalized[r_normalized > 1] = 1 + r_normalized = (r_normalized) * 9 + + cos_theta = kpts[:, 0] / r # x / r + theta = np.arccos(cos_theta) # from 0 to pi + change_angle_mask = kpts[:, 1] < 0 + theta[change_angle_mask] = 2 * np.pi - theta[change_angle_mask] + theta_degree = np.degrees(theta) + theta_degree[theta_degree == 360] = 0 # to avoid overflow + theta_degree = theta_degree / 360 * 360 + kpts_color = RGB[r_normalized.astype(int), theta_degree.astype(int)] + return kpts_color + + +def show_image_pair(img0, img1, path=None): + fig, axes = plt.subplots(1, 2, figsize=(10, 6), dpi=200) + axes[0].imshow(img0, cmap="gray") + axes[1].imshow(img1, cmap="gray") + for i in range(2): # clear all frames + axes[i].get_yaxis().set_ticks([]) + axes[i].get_xaxis().set_ticks([]) + for spine in axes[i].spines.values(): + spine.set_visible(False) + plt.tight_layout(pad=1) + if path: + plt.savefig(str(path), bbox_inches="tight", pad_inches=0) + return fig + +def plot_local_windows(kpts, color="r", lw=1, ax_=0, window_size=(9, 9)): + ax = plt.gcf().axes + for kpt in kpts: + ax[ax_].add_patch( + matplotlib.patches.Rectangle( + ( + kpt[0] - (window_size[0] // 2) - 1, + kpt[1] - (window_size[1] // 2) - 1, + ), + window_size[0] + 1, + window_size[1] + 1, + lw=lw, + color=color, + fill=False, + 
) + ) + diff --git a/imcui/third_party/MatchAnything/requirements.txt b/imcui/third_party/MatchAnything/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..891cdb8becdb19790c4ac3659d69f8118d14bd2c --- /dev/null +++ b/imcui/third_party/MatchAnything/requirements.txt @@ -0,0 +1,22 @@ +opencv_python==4.4.0.46 +albumentations==0.5.1 --no-binary=imgaug,albumentations +Pillow==9.5.0 +ray==2.9.3 +einops==0.3.0 +kornia==0.4.1 +loguru==0.5.3 +yacs>=0.1.8 +tqdm +autopep8 +pylint +ipython +jupyterlab +matplotlib +h5py==3.1.0 +pytorch-lightning==1.3.5 +torchmetrics==0.6.0 # version problem: https://github.com/NVIDIA/DeepLearningExamples/issues/1113#issuecomment-1102969461 +joblib>=1.0.1 +pynvml +gpustat +safetensors +timm==0.6.7 \ No newline at end of file diff --git a/imcui/third_party/MatchAnything/scripts/evaluate/eval_harvard_brain.sh b/imcui/third_party/MatchAnything/scripts/evaluate/eval_harvard_brain.sh new file mode 100644 index 0000000000000000000000000000000000000000..a517ce7855dc9c87798a9a230fa8f21c4c06ca5a --- /dev/null +++ b/imcui/third_party/MatchAnything/scripts/evaluate/eval_harvard_brain.sh @@ -0,0 +1,17 @@ +#!/bin/bash -l + +SCRIPTPATH=$(dirname $(readlink -f "$0")) +PROJECT_DIR="${SCRIPTPATH}/../../" + +cd $PROJECT_DIR + +DEVICE_ID='0' +NPZ_ROOT=data/test_data/havard_medical_matching/all_eval +NPZ_LIST_PATH=data/test_data/havard_medical_matching/all_eval/val_list.txt +OUTPUT_PATH=results/havard_medical_matching + +# ELoFTR pretrained: +CUDA_VISIBLE_DEVICES=$DEVICE_ID python tools/evaluate_datasets.py configs/models/eloftr_model.py --ckpt_path weights/matchanything_eloftr.ckpt --method matchanything_eloftr@-@ransac_affine --imgresize 832 --thr 0.05 --npe --npz_root $NPZ_ROOT --npz_list_path $NPZ_LIST_PATH --output_path $OUTPUT_PATH + +# ROMA pretrained: +CUDA_VISIBLE_DEVICES=$DEVICE_ID python tools/evaluate_datasets.py configs/models/roma_model.py --ckpt_path weights/matchanything_roma.ckpt --method 
matchanything_roma@-@ransac_affine --imgresize 832 --npe --npz_root $NPZ_ROOT --npz_list_path $NPZ_LIST_PATH --output_path $OUTPUT_PATH \ No newline at end of file diff --git a/imcui/third_party/MatchAnything/scripts/evaluate/eval_liver_ct_mr.sh b/imcui/third_party/MatchAnything/scripts/evaluate/eval_liver_ct_mr.sh new file mode 100644 index 0000000000000000000000000000000000000000..f58b0b623565a3131f0be290c81e3842c6db4d3e --- /dev/null +++ b/imcui/third_party/MatchAnything/scripts/evaluate/eval_liver_ct_mr.sh @@ -0,0 +1,17 @@ +#!/bin/bash -l + +SCRIPTPATH=$(dirname $(readlink -f "$0")) +PROJECT_DIR="${SCRIPTPATH}/../../" + +cd $PROJECT_DIR + +DEVICE_ID='0' +NPZ_ROOT=data/test_data/Liver_CT-MR/eval_indexs +NPZ_LIST_PATH=data/test_data/Liver_CT-MR/eval_indexs/val_list.txt +OUTPUT_PATH=results/Liver_CT-MR + +# ELoFTR pretrained: +CUDA_VISIBLE_DEVICES=$DEVICE_ID python tools/evaluate_datasets.py configs/models/eloftr_model.py --ckpt_path weights/matchanything_eloftr.ckpt --method matchanything_eloftr@-@ransac_affine --imgresize 832 --npe --npz_root $NPZ_ROOT --npz_list_path $NPZ_LIST_PATH --output_path $OUTPUT_PATH + +# ROMA pretrained: +CUDA_VISIBLE_DEVICES=$DEVICE_ID python tools/evaluate_datasets.py configs/models/roma_model.py --ckpt_path weights/matchanything_roma.ckpt --method matchanything_roma@-@ransac_affine --imgresize 832 --npe --npz_root $NPZ_ROOT --npz_list_path $NPZ_LIST_PATH --output_path $OUTPUT_PATH \ No newline at end of file diff --git a/imcui/third_party/MatchAnything/scripts/evaluate/eval_thermal_ground.sh b/imcui/third_party/MatchAnything/scripts/evaluate/eval_thermal_ground.sh new file mode 100644 index 0000000000000000000000000000000000000000..2f7d5ea30804e749c5eff6f94d9963073683e441 --- /dev/null +++ b/imcui/third_party/MatchAnything/scripts/evaluate/eval_thermal_ground.sh @@ -0,0 +1,17 @@ +#!/bin/bash -l + +SCRIPTPATH=$(dirname $(readlink -f "$0")) +PROJECT_DIR="${SCRIPTPATH}/../../" + +cd $PROJECT_DIR + +DEVICE_ID='0' 
+NPZ_ROOT=data/test_data/thermal_visible_ground/eval_indexs +NPZ_LIST_PATH=data/test_data/thermal_visible_ground/eval_indexs/val_list.txt +OUTPUT_PATH=results/thermal_visible_ground + +# ELoFTR pretrained: +CUDA_VISIBLE_DEVICES=$DEVICE_ID python tools/evaluate_datasets.py configs/models/eloftr_model.py --ckpt_path weights/matchanything_eloftr.ckpt --method matchanything_eloftr@-@ransac_affine --imgresize 832 --npe --npz_root $NPZ_ROOT --npz_list_path $NPZ_LIST_PATH --output_path $OUTPUT_PATH + +# ROMA pretrained: +CUDA_VISIBLE_DEVICES=$DEVICE_ID python tools/evaluate_datasets.py configs/models/roma_model.py --ckpt_path weights/matchanything_roma.ckpt --method matchanything_roma@-@ransac_affine --imgresize 832 --npe --npz_root $NPZ_ROOT --npz_list_path $NPZ_LIST_PATH --output_path $OUTPUT_PATH \ No newline at end of file diff --git a/imcui/third_party/MatchAnything/scripts/evaluate/eval_thermal_mtv.sh b/imcui/third_party/MatchAnything/scripts/evaluate/eval_thermal_mtv.sh new file mode 100644 index 0000000000000000000000000000000000000000..faf816d4a377626f35e3549b1de8c65a8b4ae540 --- /dev/null +++ b/imcui/third_party/MatchAnything/scripts/evaluate/eval_thermal_mtv.sh @@ -0,0 +1,17 @@ +#!/bin/bash -l + +SCRIPTPATH=$(dirname $(readlink -f "$0")) +PROJECT_DIR="${SCRIPTPATH}/../../" + +cd $PROJECT_DIR + +DEVICE_ID='0' +NPZ_ROOT=data/test_data/MTV_cross_modal_data/scene_info/scene_info +NPZ_LIST_PATH=data/test_data/MTV_cross_modal_data/scene_info/test_list.txt +OUTPUT_PATH=results/MTV_cross_modal_data + +# ELoFTR pretrained: +CUDA_VISIBLE_DEVICES=$DEVICE_ID python tools/evaluate_datasets.py configs/models/eloftr_model.py --ckpt_path weights/matchanything_eloftr.ckpt --method matchanything_eloftr@-@ransac_affine --imgresize 832 --npe --npz_root $NPZ_ROOT --npz_list_path $NPZ_LIST_PATH --output_path $OUTPUT_PATH + +# ROMA pretrained: +CUDA_VISIBLE_DEVICES=$DEVICE_ID python tools/evaluate_datasets.py configs/models/roma_model.py --ckpt_path weights/matchanything_roma.ckpt 
--method matchanything_roma@-@ransac_affine --imgresize 832 --npe --npz_root $NPZ_ROOT --npz_list_path $NPZ_LIST_PATH --output_path $OUTPUT_PATH \ No newline at end of file diff --git a/imcui/third_party/MatchAnything/scripts/evaluate/eval_thermal_remote_sense.sh b/imcui/third_party/MatchAnything/scripts/evaluate/eval_thermal_remote_sense.sh new file mode 100644 index 0000000000000000000000000000000000000000..afd0c62edb3ba9e9ffc4a3a4140a38c3fd1d475e --- /dev/null +++ b/imcui/third_party/MatchAnything/scripts/evaluate/eval_thermal_remote_sense.sh @@ -0,0 +1,17 @@ +#!/bin/bash -l + +SCRIPTPATH=$(dirname $(readlink -f "$0")) +PROJECT_DIR="${SCRIPTPATH}/../../" + +cd $PROJECT_DIR + +DEVICE_ID='0' +NPZ_ROOT=data/test_data/remote_sense_thermal/eval_Optical-Infrared +NPZ_LIST_PATH=data/test_data/remote_sense_thermal/eval_Optical-Infrared/val_list.txt +OUTPUT_PATH=results/remote_sense_thermal + +# ELoFTR pretrained: +CUDA_VISIBLE_DEVICES=$DEVICE_ID python tools/evaluate_datasets.py configs/models/eloftr_model.py --ckpt_path weights/matchanything_eloftr.ckpt --method matchanything_eloftr@-@ransac_affine --imgresize 832 --npe --npz_root $NPZ_ROOT --npz_list_path $NPZ_LIST_PATH --output_path $OUTPUT_PATH + +# ROMA pretrained: +CUDA_VISIBLE_DEVICES=$DEVICE_ID python tools/evaluate_datasets.py configs/models/roma_model.py --ckpt_path weights/matchanything_roma.ckpt --method matchanything_roma@-@ransac_affine --imgresize 832 --npe --npz_root $NPZ_ROOT --npz_list_path $NPZ_LIST_PATH --output_path $OUTPUT_PATH \ No newline at end of file diff --git a/imcui/third_party/MatchAnything/scripts/evaluate/eval_visible_sar.sh b/imcui/third_party/MatchAnything/scripts/evaluate/eval_visible_sar.sh new file mode 100644 index 0000000000000000000000000000000000000000..231f303a642ce890c7e90470120ef8744d89455e --- /dev/null +++ b/imcui/third_party/MatchAnything/scripts/evaluate/eval_visible_sar.sh @@ -0,0 +1,17 @@ +#!/bin/bash -l + +SCRIPTPATH=$(dirname $(readlink -f "$0")) 
+PROJECT_DIR="${SCRIPTPATH}/../../" + +cd $PROJECT_DIR + +DEVICE_ID='0' +NPZ_ROOT=data/test_data/visible_sar_dataset/eval +NPZ_LIST_PATH=data/test_data/visible_sar_dataset/eval/val_list.txt +OUTPUT_PATH=results/visible_sar_dataset + +# ELoFTR pretrained: +CUDA_VISIBLE_DEVICES=$DEVICE_ID python tools/evaluate_datasets.py configs/models/eloftr_model.py --ckpt_path weights/matchanything_eloftr.ckpt --method matchanything_eloftr@-@ransac_affine --imgresize 832 --thr 0.05 --npe --npz_root $NPZ_ROOT --npz_list_path $NPZ_LIST_PATH --output_path $OUTPUT_PATH + +# ROMA pretrained: +CUDA_VISIBLE_DEVICES=$DEVICE_ID python tools/evaluate_datasets.py configs/models/roma_model.py --ckpt_path weights/matchanything_roma.ckpt --method matchanything_roma@-@ransac_affine --imgresize 832 --npe --npz_root $NPZ_ROOT --npz_list_path $NPZ_LIST_PATH --output_path $OUTPUT_PATH \ No newline at end of file diff --git a/imcui/third_party/MatchAnything/scripts/evaluate/eval_visible_vectorized_map.sh b/imcui/third_party/MatchAnything/scripts/evaluate/eval_visible_vectorized_map.sh new file mode 100644 index 0000000000000000000000000000000000000000..a8bb56f0a8273409fa51532e61323d423ad04191 --- /dev/null +++ b/imcui/third_party/MatchAnything/scripts/evaluate/eval_visible_vectorized_map.sh @@ -0,0 +1,17 @@ +#!/bin/bash -l + +SCRIPTPATH=$(dirname $(readlink -f "$0")) +PROJECT_DIR="${SCRIPTPATH}/../../" + +cd $PROJECT_DIR + +DEVICE_ID='0' +NPZ_ROOT=data/test_data/visible_vectorized_map/scene_indices +NPZ_LIST_PATH=data/test_data/visible_vectorized_map/scene_indices/val_list.txt +OUTPUT_PATH=results/visible_vectorized_map + +# ELoFTR pretrained: +CUDA_VISIBLE_DEVICES=$DEVICE_ID python tools/evaluate_datasets.py configs/models/eloftr_model.py --ckpt_path weights/matchanything_eloftr.ckpt --method matchanything_eloftr@-@ransac_affine --imgresize 832 --npe --npz_root $NPZ_ROOT --npz_list_path $NPZ_LIST_PATH --output_path $OUTPUT_PATH + +# ROMA pretrained: +CUDA_VISIBLE_DEVICES=$DEVICE_ID python 
tools/evaluate_datasets.py configs/models/roma_model.py --ckpt_path weights/matchanything_roma.ckpt --method matchanything_roma@-@ransac_affine --imgresize 832 --npe --npz_root $NPZ_ROOT --npz_list_path $NPZ_LIST_PATH --output_path $OUTPUT_PATH \ No newline at end of file diff --git a/imcui/third_party/MatchAnything/src/__init__.py b/imcui/third_party/MatchAnything/src/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/imcui/third_party/MatchAnything/src/config/default.py b/imcui/third_party/MatchAnything/src/config/default.py new file mode 100644 index 0000000000000000000000000000000000000000..0b43845bfec84071b29b96012bf5401a889327ed --- /dev/null +++ b/imcui/third_party/MatchAnything/src/config/default.py @@ -0,0 +1,344 @@ +from yacs.config import CfgNode as CN +_CN = CN() +############## ROMA Pipeline ######### +_CN.ROMA = CN() +_CN.ROMA.MATCH_THRESH = 0.0 +_CN.ROMA.RESIZE_BY_STRETCH = False # Used for test mode +_CN.ROMA.NORMALIZE_IMG = False # Used for test mode + +_CN.ROMA.MODE = "train_framework" # Used in Lightning Train & Val +_CN.ROMA.MODEL = CN() +_CN.ROMA.MODEL.COARSE_BACKBONE = 'DINOv2_large' +_CN.ROMA.MODEL.COARSE_FEAT_DIM = 1024 +_CN.ROMA.MODEL.MEDIUM_FEAT_DIM = 512 +_CN.ROMA.MODEL.COARSE_PATCH_SIZE = 14 +_CN.ROMA.MODEL.AMP = True # FP16 mode + +_CN.ROMA.SAMPLE = CN() +_CN.ROMA.SAMPLE.METHOD = "threshold_balanced" +_CN.ROMA.SAMPLE.N_SAMPLE = 5000 +_CN.ROMA.SAMPLE.THRESH = 0.05 + +_CN.ROMA.TEST_TIME = CN() +_CN.ROMA.TEST_TIME.COARSE_RES = (560, 560) # need to divisable by 14 & 8 +_CN.ROMA.TEST_TIME.UPSAMPLE = True +_CN.ROMA.TEST_TIME.UPSAMPLE_RES = (864, 864) # need to divisable by 8 +_CN.ROMA.TEST_TIME.SYMMETRIC = True +_CN.ROMA.TEST_TIME.ATTENUTATE_CERT = True + +############## ↓ LoFTR Pipeline ↓ ############## +_CN.LOFTR = CN() +_CN.LOFTR.BACKBONE_TYPE = 'ResNetFPN' +_CN.LOFTR.ALIGN_CORNER = True +_CN.LOFTR.RESOLUTION = (8, 2) # options: [(8, 2), (16, 4)] 
+_CN.LOFTR.FINE_WINDOW_SIZE = 5 # window_size in fine_level, must be odd +_CN.LOFTR.FINE_WINDOW_MATCHING_SIZE = 5 # window_size for loftr fine-matching, odd for select and even for average +_CN.LOFTR.FINE_CONCAT_COARSE_FEAT = True +_CN.LOFTR.FINE_SAMPLE_COARSE_FEAT = False +_CN.LOFTR.COARSE_FEAT_ONLY = False # TO BE DONE +_CN.LOFTR.INTER_FEAT = False # FPN backbone inter feat with coarse_attn. +_CN.LOFTR.FP16 = False +_CN.LOFTR.FIX_BIAS = False +_CN.LOFTR.MATCHABILITY = False +_CN.LOFTR.FORCE_LOOP_BACK = False +_CN.LOFTR.NORM_FPNFEAT = False +_CN.LOFTR.NORM_FPNFEAT2 = False +_CN.LOFTR.REPLACE_NAN = False +_CN.LOFTR.PLOT_SCORES = False +_CN.LOFTR.REP_FPN = False +_CN.LOFTR.REP_DEPLOY = False +_CN.LOFTR.EVAL_TIMES = 1 + +# 1. LoFTR-backbone (local feature CNN) config +_CN.LOFTR.RESNETFPN = CN() +_CN.LOFTR.RESNETFPN.INITIAL_DIM = 128 +_CN.LOFTR.RESNETFPN.BLOCK_DIMS = [128, 196, 256] # s1, s2, s3 +_CN.LOFTR.RESNETFPN.SAMPLE_FINE = False +_CN.LOFTR.RESNETFPN.COARSE_FEAT_ONLY = False # TO BE DONE +_CN.LOFTR.RESNETFPN.INTER_FEAT = False # FPN backbone inter feat with coarse_attn. +_CN.LOFTR.RESNETFPN.LEAKY = False +_CN.LOFTR.RESNETFPN.REPVGGMODEL = None + +# 2. 
LoFTR-coarse module config +_CN.LOFTR.COARSE = CN() +_CN.LOFTR.COARSE.D_MODEL = 256 +_CN.LOFTR.COARSE.D_FFN = 256 +_CN.LOFTR.COARSE.NHEAD = 8 +_CN.LOFTR.COARSE.LAYER_NAMES = ['self', 'cross'] * 4 +_CN.LOFTR.COARSE.ATTENTION = 'linear' # options: ['linear', 'full'] +_CN.LOFTR.COARSE.TEMP_BUG_FIX = True +_CN.LOFTR.COARSE.NPE = False +_CN.LOFTR.COARSE.PAN = False +_CN.LOFTR.COARSE.POOl_SIZE = 4 +_CN.LOFTR.COARSE.POOl_SIZE2 = 4 +_CN.LOFTR.COARSE.BN = True +_CN.LOFTR.COARSE.XFORMER = False +_CN.LOFTR.COARSE.BIDIRECTION = False +_CN.LOFTR.COARSE.DEPTH_CONFIDENCE = -1.0 +_CN.LOFTR.COARSE.WIDTH_CONFIDENCE = -1.0 +_CN.LOFTR.COARSE.LEAKY = -1.0 +_CN.LOFTR.COARSE.ASYMMETRIC = False +_CN.LOFTR.COARSE.ASYMMETRIC_SELF = False +_CN.LOFTR.COARSE.ROPE = False +_CN.LOFTR.COARSE.TOKEN_MIXER = None +_CN.LOFTR.COARSE.SKIP = False +_CN.LOFTR.COARSE.DWCONV = False +_CN.LOFTR.COARSE.DWCONV2 = False +_CN.LOFTR.COARSE.SCATTER = False +_CN.LOFTR.COARSE.ROPE = False +_CN.LOFTR.COARSE.NPE = None +_CN.LOFTR.COARSE.NORM_BEFORE = True +_CN.LOFTR.COARSE.VIT_NORM = False +_CN.LOFTR.COARSE.ROPE_DWPROJ = False +_CN.LOFTR.COARSE.ABSPE = False + + +# 3. 
Coarse-Matching config +_CN.LOFTR.MATCH_COARSE = CN() +_CN.LOFTR.MATCH_COARSE.THR = 0.2 +_CN.LOFTR.MATCH_COARSE.BORDER_RM = 2 +_CN.LOFTR.MATCH_COARSE.MATCH_TYPE = 'dual_softmax' # options: ['dual_softmax, 'sinkhorn'] +_CN.LOFTR.MATCH_COARSE.DSMAX_TEMPERATURE = 0.1 +_CN.LOFTR.MATCH_COARSE.SKH_ITERS = 3 +_CN.LOFTR.MATCH_COARSE.SKH_INIT_BIN_SCORE = 1.0 +_CN.LOFTR.MATCH_COARSE.SKH_PREFILTER = False +_CN.LOFTR.MATCH_COARSE.TRAIN_COARSE_PERCENT = 0.2 # training tricks: save GPU memory +_CN.LOFTR.MATCH_COARSE.TRAIN_PAD_NUM_GT_MIN = 200 # training tricks: avoid DDP deadlock +_CN.LOFTR.MATCH_COARSE.SPARSE_SPVS = True +_CN.LOFTR.MATCH_COARSE.MTD_SPVS = False +_CN.LOFTR.MATCH_COARSE.FIX_BIAS = False +_CN.LOFTR.MATCH_COARSE.BINARY = False +_CN.LOFTR.MATCH_COARSE.BINARY_SPV = 'l2' +_CN.LOFTR.MATCH_COARSE.NORMFEAT = False +_CN.LOFTR.MATCH_COARSE.NORMFEATMUL = False +_CN.LOFTR.MATCH_COARSE.DIFFSIGN2 = False +_CN.LOFTR.MATCH_COARSE.DIFFSIGN3 = False +_CN.LOFTR.MATCH_COARSE.CLASSIFY = False +_CN.LOFTR.MATCH_COARSE.D_CLASSIFY = 256 +_CN.LOFTR.MATCH_COARSE.SKIP_SOFTMAX = False +_CN.LOFTR.MATCH_COARSE.FORCE_NEAREST = False # in case binary is True, force nearest neighbor, preventing finding a reasonable threshold +_CN.LOFTR.MATCH_COARSE.FP16MATMUL = False +_CN.LOFTR.MATCH_COARSE.SEQSOFTMAX = False +_CN.LOFTR.MATCH_COARSE.SEQSOFTMAX2 = False +_CN.LOFTR.MATCH_COARSE.RATIO_TEST = False +_CN.LOFTR.MATCH_COARSE.RATIO_TEST_VAL = -1.0 +_CN.LOFTR.MATCH_COARSE.USE_GT_COARSE = False +_CN.LOFTR.MATCH_COARSE.CROSS_SOFTMAX = False +_CN.LOFTR.MATCH_COARSE.PLOT_ORIGIN_SCORES = False +_CN.LOFTR.MATCH_COARSE.USE_PERCENT_THR = False +_CN.LOFTR.MATCH_COARSE.PERCENT_THR = 0.1 +_CN.LOFTR.MATCH_COARSE.ADD_SIGMOID = False +_CN.LOFTR.MATCH_COARSE.SIGMOID_BIAS = 20.0 +_CN.LOFTR.MATCH_COARSE.SIGMOID_SIGMA = 2.5 +_CN.LOFTR.MATCH_COARSE.CAL_PER_OF_GT = False + +# 4. 
LoFTR-fine module config +_CN.LOFTR.FINE = CN() +_CN.LOFTR.FINE.SKIP = False +_CN.LOFTR.FINE.D_MODEL = 128 +_CN.LOFTR.FINE.D_FFN = 128 +_CN.LOFTR.FINE.NHEAD = 8 +_CN.LOFTR.FINE.LAYER_NAMES = ['self', 'cross'] * 1 +_CN.LOFTR.FINE.ATTENTION = 'linear' +_CN.LOFTR.FINE.MTD_SPVS = False +_CN.LOFTR.FINE.PAN = False +_CN.LOFTR.FINE.POOl_SIZE = 4 +_CN.LOFTR.FINE.BN = True +_CN.LOFTR.FINE.XFORMER = False +_CN.LOFTR.FINE.BIDIRECTION = False + + +# Fine-Matching config +_CN.LOFTR.MATCH_FINE = CN() +_CN.LOFTR.MATCH_FINE.THR = 0 +_CN.LOFTR.MATCH_FINE.TOPK = 3 +_CN.LOFTR.MATCH_FINE.NORMFINEM = False +_CN.LOFTR.MATCH_FINE.USE_GT_FINE = False +_CN.LOFTR.MATCH_COARSE.FINE_TOPK = _CN.LOFTR.MATCH_FINE.TOPK +_CN.LOFTR.MATCH_FINE.FIX_FINE_MATCHING = False +_CN.LOFTR.MATCH_FINE.SKIP_FINE_SOFTMAX = False +_CN.LOFTR.MATCH_FINE.USE_SIGMOID = False +_CN.LOFTR.MATCH_FINE.SIGMOID_BIAS = 0.0 +_CN.LOFTR.MATCH_FINE.NORMFEAT = False +_CN.LOFTR.MATCH_FINE.SPARSE_SPVS = True +_CN.LOFTR.MATCH_FINE.FORCE_NEAREST = False +_CN.LOFTR.MATCH_FINE.LOCAL_REGRESS = False +_CN.LOFTR.MATCH_FINE.LOCAL_REGRESS_RMBORDER = False +_CN.LOFTR.MATCH_FINE.LOCAL_REGRESS_NOMASK = False +_CN.LOFTR.MATCH_FINE.LOCAL_REGRESS_TEMPERATURE = 1.0 +_CN.LOFTR.MATCH_FINE.LOCAL_REGRESS_PADONE = False +_CN.LOFTR.MATCH_FINE.LOCAL_REGRESS_SLICE = False +_CN.LOFTR.MATCH_FINE.LOCAL_REGRESS_SLICEDIM = 8 +_CN.LOFTR.MATCH_FINE.LOCAL_REGRESS_INNER = False +_CN.LOFTR.MATCH_FINE.MULTI_REGRESS = False + + + +# 5. 
LoFTR Losses +# -- # coarse-level +_CN.LOFTR.LOSS = CN() +_CN.LOFTR.LOSS.COARSE_TYPE = 'focal' # ['focal', 'cross_entropy'] +_CN.LOFTR.LOSS.COARSE_WEIGHT = 1.0 +_CN.LOFTR.LOSS.COARSE_SIGMOID_WEIGHT = 1.0 +_CN.LOFTR.LOSS.LOCAL_WEIGHT = 0.5 +_CN.LOFTR.LOSS.COARSE_OVERLAP_WEIGHT = False +_CN.LOFTR.LOSS.FINE_OVERLAP_WEIGHT = False +_CN.LOFTR.LOSS.FINE_OVERLAP_WEIGHT2 = False +# _CN.LOFTR.LOSS.SPARSE_SPVS = False +# -- - -- # focal loss (coarse) +_CN.LOFTR.LOSS.FOCAL_ALPHA = 0.25 +_CN.LOFTR.LOSS.FOCAL_GAMMA = 2.0 +_CN.LOFTR.LOSS.POS_WEIGHT = 1.0 +_CN.LOFTR.LOSS.NEG_WEIGHT = 1.0 +_CN.LOFTR.LOSS.CORRECT_NEG_WEIGHT = False +# _CN.LOFTR.LOSS.DUAL_SOFTMAX = False # whether coarse-level use dual-softmax or not. +# use `_CN.LOFTR.MATCH_COARSE.MATCH_TYPE` + +# -- # fine-level +_CN.LOFTR.LOSS.FINE_TYPE = 'l2_with_std' # ['l2_with_std', 'l2'] +_CN.LOFTR.LOSS.FINE_WEIGHT = 1.0 +_CN.LOFTR.LOSS.FINE_CORRECT_THR = 1.0 # for filtering valid fine-level gts (some gt matches might fall out of the fine-level window) + +# -- # ROMA: +_CN.LOFTR.ROMA_LOSS = CN() +_CN.LOFTR.ROMA_LOSS.IGNORE_EMPTY_IN_SPARSE_MATCH_SPV = False # ['l2_with_std', 'l2'] + +# -- # DKM: +_CN.LOFTR.DKM_LOSS = CN() +_CN.LOFTR.DKM_LOSS.IGNORE_EMPTY_IN_SPARSE_MATCH_SPV = False # ['l2_with_std', 'l2'] + +############## Dataset ############## +_CN.DATASET = CN() +# 1. 
data config +# training and validating +_CN.DATASET.TB_LOG_DIR= "logs/tb_logs" # options: ['ScanNet', 'MegaDepth'] +_CN.DATASET.TRAIN_DATA_SAMPLE_RATIO = [1.0] # options: ['ScanNet', 'MegaDepth'] +_CN.DATASET.TRAIN_DATA_SOURCE = None # options: ['ScanNet', 'MegaDepth'] +_CN.DATASET.TRAIN_DATA_ROOT = None +_CN.DATASET.TRAIN_POSE_ROOT = None # (optional directory for poses) +_CN.DATASET.TRAIN_NPZ_ROOT = None +_CN.DATASET.TRAIN_LIST_PATH = None +_CN.DATASET.TRAIN_INTRINSIC_PATH = None +_CN.DATASET.VAL_DATA_ROOT = None +_CN.DATASET.VAL_DATA_SOURCE = None # options: ['ScanNet', 'MegaDepth'] +_CN.DATASET.VAL_POSE_ROOT = None # (optional directory for poses) +_CN.DATASET.VAL_NPZ_ROOT = None +_CN.DATASET.VAL_LIST_PATH = None # None if val data from all scenes are bundled into a single npz file +_CN.DATASET.VAL_INTRINSIC_PATH = None +_CN.DATASET.FP16 = False +_CN.DATASET.TRAIN_GT_MATCHES_PADDING_N = 8000 +# testing +_CN.DATASET.TEST_DATA_SOURCE = None +_CN.DATASET.TEST_DATA_ROOT = None +_CN.DATASET.TEST_POSE_ROOT = None # (optional directory for poses) +_CN.DATASET.TEST_NPZ_ROOT = None +_CN.DATASET.TEST_LIST_PATH = None # None if test data from all scenes are bundled into a single npz file +_CN.DATASET.TEST_INTRINSIC_PATH = None + +# 2. dataset config +# general options +_CN.DATASET.MIN_OVERLAP_SCORE_TRAIN = 0.4 # discard data with overlap_score < min_overlap_score +_CN.DATASET.MIN_OVERLAP_SCORE_TEST = 0.0 +_CN.DATASET.AUGMENTATION_TYPE = None # options: [None, 'dark', 'mobile'] + +# debug options +_CN.DATASET.TEST_N_PAIRS = None # Debug first N pairs +# DEBUG +_CN.LOFTR.FP16LOG = False +_CN.LOFTR.MATCH_COARSE.FP16LOG = False + +# scanNet options +_CN.DATASET.SCAN_IMG_RESIZEX = 640 # resize the longer side, zero-pad bottom-right to square. +_CN.DATASET.SCAN_IMG_RESIZEY = 480 # resize the shorter side, zero-pad bottom-right to square. + +# MegaDepth options +_CN.DATASET.MGDPT_IMG_RESIZE = (640, 640) # resize the longer side, zero-pad bottom-right to square. 
+_CN.DATASET.MGDPT_IMG_PAD = True # pad img to square with size = MGDPT_IMG_RESIZE +_CN.DATASET.MGDPT_DEPTH_PAD = True # pad depthmap to square with size = 2000 +_CN.DATASET.MGDPT_DF = 8 +_CN.DATASET.LOAD_ORIGIN_RGB = False # Only open in test mode, useful for RGB required baselines such as DKM, ROMA. +_CN.DATASET.READ_GRAY = True +_CN.DATASET.RESIZE_BY_STRETCH = False +_CN.DATASET.NORMALIZE_IMG = False # For backbone using pretrained DINO feats, use True may be better. +_CN.DATASET.HOMO_WARP_USE_MASK = False + +_CN.DATASET.NPE_NAME = "megadepth" + +############## Trainer ############## +_CN.TRAINER = CN() +_CN.TRAINER.WORLD_SIZE = 1 +_CN.TRAINER.CANONICAL_BS = 64 +_CN.TRAINER.CANONICAL_LR = 6e-3 +_CN.TRAINER.SCALING = None # this will be calculated automatically +_CN.TRAINER.FIND_LR = False # use learning rate finder from pytorch-lightning + +# optimizer +_CN.TRAINER.OPTIMIZER = "adamw" # [adam, adamw] +_CN.TRAINER.OPTIMIZER_EPS = 1e-8 # Default for optimizers, but set smaller, e.g., 1e-7 for fp16 mix training +_CN.TRAINER.TRUE_LR = None # this will be calculated automatically at runtime +_CN.TRAINER.ADAM_DECAY = 0. # ADAM: for adam +_CN.TRAINER.ADAMW_DECAY = 0.1 + +# step-based warm-up +_CN.TRAINER.WARMUP_TYPE = 'linear' # [linear, constant] +_CN.TRAINER.WARMUP_RATIO = 0. 
+_CN.TRAINER.WARMUP_STEP = 4800 + +# learning rate scheduler +_CN.TRAINER.SCHEDULER = 'MultiStepLR' # [MultiStepLR, CosineAnnealing, ExponentialLR] +_CN.TRAINER.SCHEDULER_INTERVAL = 'epoch' # [epoch, step] +_CN.TRAINER.MSLR_MILESTONES = [3, 6, 9, 12] # MSLR: MultiStepLR +_CN.TRAINER.MSLR_GAMMA = 0.5 +_CN.TRAINER.COSA_TMAX = 30 # COSA: CosineAnnealing +_CN.TRAINER.ELR_GAMMA = 0.999992 # ELR: ExponentialLR, this value for 'step' interval + +# plotting related +_CN.TRAINER.ENABLE_PLOTTING = True +_CN.TRAINER.N_VAL_PAIRS_TO_PLOT = 8 # number of val/test paris for plotting +_CN.TRAINER.PLOT_MODE = 'evaluation' # ['evaluation', 'confidence'] +_CN.TRAINER.PLOT_MATCHES_ALPHA = 'dynamic' + +# geometric metrics and pose solver +_CN.TRAINER.EPI_ERR_THR = 5e-4 # recommendation: 5e-4 for ScanNet, 1e-4 for MegaDepth (from SuperGlue) +_CN.TRAINER.POSE_GEO_MODEL = 'E' # ['E', 'F', 'H'] +_CN.TRAINER.POSE_ESTIMATION_METHOD = 'RANSAC' # [RANSAC, DEGENSAC, MAGSAC] +_CN.TRAINER.WARP_ESTIMATOR_MODEL = 'affine' # [RANSAC, DEGENSAC, MAGSAC] +_CN.TRAINER.RANSAC_PIXEL_THR = 0.5 +_CN.TRAINER.RANSAC_CONF = 0.99999 +_CN.TRAINER.RANSAC_MAX_ITERS = 10000 +_CN.TRAINER.USE_MAGSACPP = False +_CN.TRAINER.THRESHOLDS = [5, 10, 20] + +# data sampler for train_dataloader +_CN.TRAINER.DATA_SAMPLER = 'scene_balance' # options: ['scene_balance', 'random', 'normal'] +# 'scene_balance' config +_CN.TRAINER.N_SAMPLES_PER_SUBSET = 200 +_CN.TRAINER.SB_SUBSET_SAMPLE_REPLACEMENT = True # whether sample each scene with replacement or not +_CN.TRAINER.SB_SUBSET_SHUFFLE = True # after sampling from scenes, whether shuffle within the epoch or not +_CN.TRAINER.SB_REPEAT = 1 # repeat N times for training the sampled data +_CN.TRAINER.AUC_METHOD = 'exact_auc' +# 'random' config +_CN.TRAINER.RDM_REPLACEMENT = True +_CN.TRAINER.RDM_NUM_SAMPLES = None + +# gradient clipping +_CN.TRAINER.GRADIENT_CLIPPING = 0.5 + +# Finetune Mode: +_CN.FINETUNE = CN() +_CN.FINETUNE.ENABLE = False +_CN.FINETUNE.METHOD = "lora" #['lora', 
'whole_network'] + +_CN.FINETUNE.LORA = CN() +_CN.FINETUNE.LORA.RANK = 2 +_CN.FINETUNE.LORA.MODE = "linear&conv" # ["linear&conv", "linear_only"] +_CN.FINETUNE.LORA.SCALE = 1.0 + +_CN.TRAINER.SEED = 66 + + +def get_cfg_defaults(): + """Get a yacs CfgNode object with default values for my_project.""" + # Return a clone so that the defaults will not be altered + # This is for the "local variable" use pattern + return _CN.clone() diff --git a/imcui/third_party/MatchAnything/src/lightning/lightning_loftr.py b/imcui/third_party/MatchAnything/src/lightning/lightning_loftr.py new file mode 100644 index 0000000000000000000000000000000000000000..91a5132809b88d14d8f02ddf7012c5abba8c46ec --- /dev/null +++ b/imcui/third_party/MatchAnything/src/lightning/lightning_loftr.py @@ -0,0 +1,343 @@ + +from collections import defaultdict +import pprint +from loguru import logger +from pathlib import Path + +import torch +import numpy as np +import pytorch_lightning as pl +from matplotlib import pyplot as plt + +from src.loftr import LoFTR +from src.loftr.utils.supervision import compute_supervision_coarse, compute_supervision_fine, compute_roma_supervision +from src.optimizers import build_optimizer, build_scheduler +from src.utils.metrics import ( + compute_symmetrical_epipolar_errors, + compute_pose_errors, + compute_homo_corner_warp_errors, + compute_homo_match_warp_errors, + compute_warp_control_pts_errors, + aggregate_metrics +) +from src.utils.plotting import make_matching_figures, make_scores_figures +from src.utils.comm import gather, all_gather +from src.utils.misc import lower_config, flattenList +from src.utils.profiler import PassThroughProfiler +from third_party.ROMA.roma.matchanything_roma_model import MatchAnything_Model + +import pynvml + +def reparameter(matcher): + module = matcher.backbone.layer0 + if hasattr(module, 'switch_to_deploy'): + module.switch_to_deploy() + print('m0 switch to deploy ok') + for modules in [matcher.backbone.layer1, matcher.backbone.layer2, 
matcher.backbone.layer3]: + for module in modules: + if hasattr(module, 'switch_to_deploy'): + module.switch_to_deploy() + print('backbone switch to deploy ok') + for modules in [matcher.fine_preprocess.layer2_outconv2, matcher.fine_preprocess.layer1_outconv2]: + for module in modules: + if hasattr(module, 'switch_to_deploy'): + module.switch_to_deploy() + print('fpn switch to deploy ok') + return matcher + +class PL_LoFTR(pl.LightningModule): + def __init__(self, config, pretrained_ckpt=None, profiler=None, dump_dir=None, test_mode=False, baseline_config=None): + """ + TODO: + - use the new version of PL logging API. + """ + super().__init__() + # Misc + self.config = config # full config + _config = lower_config(self.config) + self.profiler = profiler or PassThroughProfiler() + self.n_vals_plot = max(config.TRAINER.N_VAL_PAIRS_TO_PLOT // config.TRAINER.WORLD_SIZE, 1) + + if config.METHOD == "matchanything_eloftr": + self.matcher = LoFTR(config=_config['loftr'], profiler=self.profiler) + elif config.METHOD == "matchanything_roma": + self.matcher = MatchAnything_Model(config=_config['roma'], test_mode=test_mode) + else: + raise NotImplementedError + + if config.FINETUNE.ENABLE and test_mode: + # Inference time change model architecture before load pretrained model: + raise NotImplementedError + + # Pretrained weights + if pretrained_ckpt: + if config.METHOD in ["matchanything_eloftr", "matchanything_roma"]: + state_dict = torch.load(pretrained_ckpt, map_location='cpu')['state_dict'] + logger.info(f"Load model from:{self.matcher.load_state_dict(state_dict, strict=False)}") + else: + raise NotImplementedError + + if self.config.LOFTR.BACKBONE_TYPE == 'RepVGG' and test_mode and (config.METHOD == 'loftr'): + module = self.matcher.backbone.layer0 + if hasattr(module, 'switch_to_deploy'): + module.switch_to_deploy() + print('m0 switch to deploy ok') + for modules in [self.matcher.backbone.layer1, self.matcher.backbone.layer2, self.matcher.backbone.layer3]: + for module 
in modules: + if hasattr(module, 'switch_to_deploy'): + module.switch_to_deploy() + print('m switch to deploy ok') + + # Testing + self.dump_dir = dump_dir + self.max_gpu_memory = 0 + self.GPUID = 0 + self.warmup = False + + def gpumem(self, des, gpuid=None): + NUM_EXPAND = 1024 * 1024 * 1024 + gpu_id= self.GPUID if self.GPUID is not None else gpuid + handle = pynvml.nvmlDeviceGetHandleByIndex(gpu_id) + info = pynvml.nvmlDeviceGetMemoryInfo(handle) + gpu_Used = info.used + logger.info(f"GPU {gpu_id} memory used: {gpu_Used / NUM_EXPAND} GB while {des}") + # print(des, gpu_Used / NUM_EXPAND) + if gpu_Used / NUM_EXPAND > self.max_gpu_memory: + self.max_gpu_memory = gpu_Used / NUM_EXPAND + logger.info(f"[MAX]GPU {gpu_id} memory used: {gpu_Used / NUM_EXPAND} GB while {des}") + print('max_gpu_memory', self.max_gpu_memory) + + def configure_optimizers(self): + optimizer = build_optimizer(self, self.config) + scheduler = build_scheduler(self.config, optimizer) + return [optimizer], [scheduler] + + def optimizer_step( + self, epoch, batch_idx, optimizer, optimizer_idx, + optimizer_closure, on_tpu, using_native_amp, using_lbfgs): + # learning rate warm up + warmup_step = self.config.TRAINER.WARMUP_STEP + if self.trainer.global_step < warmup_step: + if self.config.TRAINER.WARMUP_TYPE == 'linear': + base_lr = self.config.TRAINER.WARMUP_RATIO * self.config.TRAINER.TRUE_LR + lr = base_lr + \ + (self.trainer.global_step / self.config.TRAINER.WARMUP_STEP) * \ + abs(self.config.TRAINER.TRUE_LR - base_lr) + for pg in optimizer.param_groups: + pg['lr'] = lr + elif self.config.TRAINER.WARMUP_TYPE == 'constant': + pass + else: + raise ValueError(f'Unknown lr warm-up strategy: {self.config.TRAINER.WARMUP_TYPE}') + + # update params + if self.config.LOFTR.FP16: + optimizer.step(closure=optimizer_closure) + else: + optimizer.step(closure=optimizer_closure) + optimizer.zero_grad() + + def _trainval_inference(self, batch): + with self.profiler.profile("Compute coarse supervision"): + + with 
torch.autocast(enabled=False, device_type='cuda'): + if ("roma" in self.config.METHOD) or ('dkm' in self.config.METHOD): + pass + else: + compute_supervision_coarse(batch, self.config) + + with self.profiler.profile("LoFTR"): + with torch.autocast(enabled=self.config.LOFTR.FP16, device_type='cuda'): + self.matcher(batch) + + with self.profiler.profile("Compute fine supervision"): + with torch.autocast(enabled=False, device_type='cuda'): + if ("roma" in self.config.METHOD) or ('dkm' in self.config.METHOD): + compute_roma_supervision(batch, self.config) + else: + compute_supervision_fine(batch, self.config, self.logger) + + with self.profiler.profile("Compute losses"): + pass + + def _compute_metrics(self, batch): + if 'gt_2D_matches' in batch: + compute_warp_control_pts_errors(batch, self.config) + elif batch['homography'].sum() != 0 and batch['T_0to1'].sum() == 0: + compute_homo_match_warp_errors(batch, self.config) # compute warp_errors for each match + compute_homo_corner_warp_errors(batch, self.config) # compute mean corner warp error each pair + else: + compute_symmetrical_epipolar_errors(batch, self.config) # compute epi_errs for each match + compute_pose_errors(batch, self.config) # compute R_errs, t_errs, pose_errs for each pair + + rel_pair_names = list(zip(*batch['pair_names'])) + bs = batch['image0'].size(0) + if self.config.LOFTR.FINE.MTD_SPVS: + topk = self.config.LOFTR.MATCH_FINE.TOPK + metrics = { + # to filter duplicate pairs caused by DistributedSampler + 'identifiers': ['#'.join(rel_pair_names[b]) for b in range(bs)], + 'epi_errs': [(batch['epi_errs'].reshape(-1,topk))[batch['m_bids'] == b].reshape(-1).cpu().numpy() for b in range(bs)], + 'R_errs': batch['R_errs'], + 't_errs': batch['t_errs'], + 'inliers': batch['inliers'], + 'num_matches': [batch['mconf'].shape[0]], # batch size = 1 only + 'percent_inliers': [ batch['inliers'][0].shape[0] / batch['mconf'].shape[0] if batch['mconf'].shape[0]!=0 else 1], # batch size = 1 only + } + else: + metrics = 
{ + # to filter duplicate pairs caused by DistributedSampler + 'identifiers': ['#'.join(rel_pair_names[b]) for b in range(bs)], + 'epi_errs': [batch['epi_errs'][batch['m_bids'] == b].cpu().numpy() for b in range(bs)], + 'R_errs': batch['R_errs'], + 't_errs': batch['t_errs'], + 'inliers': batch['inliers'], + 'num_matches': [batch['mconf'].shape[0]], # batch size = 1 only + 'percent_inliers': [ batch['inliers'][0].shape[0] / batch['mconf'].shape[0] if batch['mconf'].shape[0]!=0 else 1], # batch size = 1 only + } + ret_dict = {'metrics': metrics} + return ret_dict, rel_pair_names + + def training_step(self, batch, batch_idx): + self._trainval_inference(batch) + + # logging + if self.trainer.global_rank == 0 and self.global_step % self.trainer.log_every_n_steps == 0: + # scalars + for k, v in batch['loss_scalars'].items(): + self.logger.experiment.add_scalar(f'train/{k}', v, self.global_step) + + # net-params + method = 'LOFTR' + if self.config[method]['MATCH_COARSE']['MATCH_TYPE'] == 'sinkhorn': + self.logger.experiment.add_scalar( + f'skh_bin_score', self.matcher.coarse_matching.bin_score.clone().detach().cpu().data, self.global_step) + + figures = {} + if self.config.TRAINER.ENABLE_PLOTTING: + compute_symmetrical_epipolar_errors(batch, self.config) # compute epi_errs for each match + figures = make_matching_figures(batch, self.config, self.config.TRAINER.PLOT_MODE) + for k, v in figures.items(): + self.logger.experiment.add_figure(f'train_match/{k}', v, self.global_step) + + return {'loss': batch['loss']} + + def training_epoch_end(self, outputs): + avg_loss = torch.stack([x['loss'] for x in outputs]).mean() + if self.trainer.global_rank == 0: + self.logger.experiment.add_scalar( + 'train/avg_loss_on_epoch', avg_loss, + global_step=self.current_epoch) + + def validation_step(self, batch, batch_idx, dataloader_idx=0): + self._trainval_inference(batch) + + ret_dict, _ = self._compute_metrics(batch) + + val_plot_interval = max(self.trainer.num_val_batches[0] // 
def validation_epoch_end(self, outputs):
    """Aggregate validation outputs, write them to the experiment logger, and log AUC.

    ``outputs`` is either the list of ``validation_step`` results of a single
    validation set or a list of such lists (one per validation set); the first
    element's type decides which.  The gather helpers are called on every rank,
    while scalars and figures are only written when ``global_rank == 0``.
    """
    # handle multiple validation sets
    if isinstance(outputs[0], (list, tuple)):
        per_valset_outputs = outputs
    else:
        per_valset_outputs = [outputs]
    auc_per_valset = defaultdict(list)
    auc_thresholds = [5, 10, 20]

    for valset_idx, step_outputs in enumerate(per_valset_outputs):
        # pl runs a sanity check at the very beginning of training; log it as epoch -1
        cur_epoch = self.trainer.current_epoch
        if not self.trainer.resume_from_checkpoint and self.trainer.running_sanity_check:
            cur_epoch = -1

        # 1. loss_scalars: dict of list, on cpu
        step_losses = [o['loss_scalars'] for o in step_outputs]
        loss_scalars = {
            k: flattenList(all_gather([entry[k] for entry in step_losses]))
            for k in step_losses[0]
        }

        # 2. val metrics: dict of list, numpy
        step_metrics = [o['metrics'] for o in step_outputs]
        metrics = {
            k: flattenList(all_gather(flattenList([entry[k] for entry in step_metrics])))
            for k in step_metrics[0]
        }
        # NOTE: all ranks need to run `aggregate_metrics`, but only rank 0 logs
        val_metrics_4tb = aggregate_metrics(
            metrics, self.config.TRAINER.EPI_ERR_THR, self.config.LOFTR.EVAL_TIMES)
        for thr in auc_thresholds:
            auc_per_valset[f'auc@{thr}'].append(val_metrics_4tb[f'auc@{thr}'])

        # 3. figures
        step_figures = [o['figures'] for o in step_outputs]
        figures = {
            k: flattenList(gather(flattenList([entry[k] for entry in step_figures])))
            for k in step_figures[0]
        }

        # tensorboard records only on rank 0
        if self.trainer.global_rank == 0:
            experiment = self.logger.experiment
            for k, v in loss_scalars.items():
                experiment.add_scalar(
                    f'val_{valset_idx}/avg_{k}', torch.stack(v).mean(), global_step=cur_epoch)

            for k, v in val_metrics_4tb.items():
                experiment.add_scalar(f"metrics_{valset_idx}/{k}", v, global_step=cur_epoch)

            for k, figs in figures.items():
                for plot_idx, fig in enumerate(figs):
                    experiment.add_figure(
                        f'val_match_{valset_idx}/{k}/pair-{plot_idx}', fig, cur_epoch, close=True)
        plt.close('all')

    for thr in auc_thresholds:
        # ckpt monitors on this
        self.log(f'auc@{thr}', torch.tensor(np.mean(auc_per_valset[f'auc@{thr}'])))
def build_backbone(config):
    """Instantiate the backbone selected by ``config``.

    Keys read from ``config``:
        ``backbone_type``: name of the backbone family.
        ``align_corner``: ``None`` and ``True`` select the default variants,
            ``False`` selects the ``*_align`` variants.
        ``resolution``: ``(coarse, fine)`` tuple, compared with ``==``; only read
            when the family has a variant for the requested ``align_corner``.
        ``resnetfpn``: sub-config handed to the backbone constructor.

    Returns:
        The constructed backbone module.

    Raises:
        ValueError: if the backbone type is unknown, or if the requested
            (``align_corner``, ``resolution``) combination has no implementation.
            Previously such combinations fell through and returned ``None``,
            which only surfaced later as an unrelated error.
    """
    backbone_type = config['backbone_type']

    # (backbone_type, align_corner is None/True, resolution, lazy class lookup).
    # The lambdas defer name resolution until a row is actually selected.
    variants = (
        ('ResNetFPN', True, (8, 2), lambda: ResNetFPN_8_2),
        ('ResNetFPN', True, (16, 4), lambda: ResNetFPN_16_4),
        ('ResNetFPN', True, (8, 1), lambda: ResNetFPN_8_1),
        ('ResNetFPN', False, (8, 2), lambda: ResNetFPN_8_2_align),
        # The (16, 4) layout uses the same class for both align_corner settings.
        ('ResNetFPN', False, (16, 4), lambda: ResNetFPN_16_4),
        ('ResNetFPN', False, (8, 1), lambda: ResNetFPN_8_1_align),
        ('ResNetFPNFIX', True, (8, 2), lambda: ResNetFPN_8_2_fix),
        ('ResNet', False, (8, 1), lambda: ResNet_8_1_align),
        ('VGG', False, (8, 1), lambda: VGG_8_1_align),
        ('RepVGG', False, (8, 1), lambda: RepVGG_8_1_align),
        ('RepVGGNFPN', False, (8, 1), lambda: RepVGGnfpn_8_1_align),
        ('RepVGGFPNFIX', True, (8, 2), lambda: RepVGG_8_2_fix),
        ('s2dnet', False, (8, 1), lambda: s2dnet_8_1_align),
    )

    if not any(backbone_type == row[0] for row in variants):
        raise ValueError(f"LOFTR.BACKBONE_TYPE {backbone_type} not supported.")

    align_corner = config['align_corner']
    if align_corner is None or align_corner is True:
        use_default_variant = True
    elif align_corner is False:
        use_default_variant = False
    else:
        use_default_variant = None  # matches no row -> rejected below

    for row_type, row_default, row_resolution, lookup in variants:
        if (backbone_type == row_type and use_default_variant is row_default
                and config['resolution'] == row_resolution):
            return lookup()(config['resnetfpn'])

    raise ValueError(
        f"LOFTR.BACKBONE_TYPE {backbone_type} not supported with "
        f"align_corner={align_corner!r} and the configured resolution.")
class RepVGGBlock(nn.Module):
    """Re-parameterizable 3x3 convolution block (RepVGG).

    In training form the block sums three branches -- a 3x3 conv+BN, a 1x1
    conv+BN and, when input and output channels match and ``stride == 1``, a
    plain BatchNorm identity -- before the non-linearity.  In deploy form it
    holds a single biased 3x3 convolution (``rbr_reparam``).
    ``switch_to_deploy`` folds the training branches into that convolution.

    Args:
        in_channels, out_channels: channel counts of the block.
        kernel_size: must be 3.
        stride: stride of both convolution branches.
        padding: must be 1.
        dilation, padding_mode: only forwarded to the deploy-form convolution.
        groups: convolution groups of both branches.
        deploy: build the single-convolution form directly.
        use_se: not supported; raises ``ValueError`` when true.
        leaky: ``-2`` selects ``nn.Identity``, any other negative value
            ``nn.ReLU``, otherwise ``nn.LeakyReLU(leaky)``.
    """

    def __init__(self, in_channels, out_channels, kernel_size,
                 stride=1, padding=0, dilation=1, groups=1, padding_mode='zeros', deploy=False, use_se=False, leaky=-1.0):
        super(RepVGGBlock, self).__init__()
        self.deploy = deploy
        self.groups = groups
        self.in_channels = in_channels

        assert kernel_size == 3
        assert padding == 1

        # Padding that keeps the 1x1 branch spatially aligned with the 3x3 branch.
        padding_11 = padding - kernel_size // 2

        if leaky == -2:
            self.nonlinearity = nn.Identity()
            logger.info(f"Using Identity nonlinearity in repvgg_block")
        elif leaky < 0:
            self.nonlinearity = nn.ReLU()
        else:
            self.nonlinearity = nn.LeakyReLU(leaky)

        if use_se:
            # Note that RepVGG-D2se uses SE before nonlinearity. But RepVGGplus models uses SE after nonlinearity.
            raise ValueError(f"SEBlock not supported")
        else:
            self.se = nn.Identity()

        if deploy:
            self.rbr_reparam = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride,
                                         padding=padding, dilation=dilation, groups=groups, bias=True, padding_mode=padding_mode)
        else:
            self.rbr_identity = nn.BatchNorm2d(num_features=in_channels) if out_channels == in_channels and stride == 1 else None
            self.rbr_dense = conv_bn(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding, groups=groups)
            self.rbr_1x1 = conv_bn(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=stride, padding=padding_11, groups=groups)
            print('RepVGG Block, identity = ', self.rbr_identity)

    def forward(self, inputs):
        """Apply the block; uses ``rbr_reparam`` when present, else the summed branches."""
        if hasattr(self, 'rbr_reparam'):
            return self.nonlinearity(self.se(self.rbr_reparam(inputs)))

        if self.rbr_identity is None:
            id_out = 0
        else:
            id_out = self.rbr_identity(inputs)

        return self.nonlinearity(self.se(self.rbr_dense(inputs) + self.rbr_1x1(inputs) + id_out))

    # Optional. This may improve the accuracy and facilitates quantization in some cases.
    # 1.  Cancel the original weight decay on rbr_dense.conv.weight and rbr_1x1.conv.weight.
    # 2.  Use like this.
    #       loss = criterion(....)
    #       for every RepVGGBlock blk:
    #           loss += weight_decay_coefficient * 0.5 * blk.get_custom_L2()
    #       optimizer.zero_grad()
    #       loss.backward()
    def get_custom_L2(self):
        """Return the custom L2 penalty of the training-form branches (scalar tensor)."""
        K3 = self.rbr_dense.conv.weight
        K1 = self.rbr_1x1.conv.weight
        # Per-output-channel BN scale factors; detached so they act as constants here.
        t3 = (self.rbr_dense.bn.weight / ((self.rbr_dense.bn.running_var + self.rbr_dense.bn.eps).sqrt())).reshape(-1, 1, 1, 1).detach()
        t1 = (self.rbr_1x1.bn.weight / ((self.rbr_1x1.bn.running_var + self.rbr_1x1.bn.eps).sqrt())).reshape(-1, 1, 1, 1).detach()

        # The L2 loss of the "circle" of weights in the 3x3 kernel. Use regular L2 on them.
        l2_loss_circle = (K3 ** 2).sum() - (K3[:, :, 1:2, 1:2] ** 2).sum()
        # The equivalent resultant central point of the 3x3 kernel.
        eq_kernel = K3[:, :, 1:2, 1:2] * t3 + K1 * t1
        # Normalize for an L2 coefficient comparable to regular L2.
        l2_loss_eq_kernel = (eq_kernel ** 2 / (t3 ** 2 + t1 ** 2)).sum()
        return l2_loss_eq_kernel + l2_loss_circle

    # This func derives the equivalent kernel and bias in a DIFFERENTIABLE way.
    # You can get the equivalent kernel and bias at any time and do whatever you want,
    # for example, apply some penalties or constraints during training, just like you do to the other models.
    # May be useful for quantization or pruning.
    def get_equivalent_kernel_bias(self):
        """Return ``(kernel, bias)`` of the single 3x3 conv equivalent to all branches."""
        kernel3x3, bias3x3 = self._fuse_bn_tensor(self.rbr_dense)
        kernel1x1, bias1x1 = self._fuse_bn_tensor(self.rbr_1x1)
        kernelid, biasid = self._fuse_bn_tensor(self.rbr_identity)
        return kernel3x3 + self._pad_1x1_to_3x3_tensor(kernel1x1) + kernelid, bias3x3 + bias1x1 + biasid

    def _pad_1x1_to_3x3_tensor(self, kernel1x1):
        """Zero-pad a 1x1 kernel to 3x3; ``None`` maps to ``0``."""
        if kernel1x1 is None:
            return 0
        return torch.nn.functional.pad(kernel1x1, [1, 1, 1, 1])

    def _fuse_bn_tensor(self, branch):
        """Fold the BatchNorm of ``branch`` into an equivalent ``(kernel, bias)`` pair.

        ``branch`` may be ``None`` (returns ``(0, 0)``), a conv+BN ``nn.Sequential``
        with ``conv`` / ``bn`` children, or a bare ``nn.BatchNorm2d`` (identity
        branch, expressed as a 3x3 kernel with a single 1 at the centre).
        """
        if branch is None:
            return 0, 0
        if isinstance(branch, nn.Sequential):
            kernel = branch.conv.weight
            bn = branch.bn
        else:
            assert isinstance(branch, nn.BatchNorm2d)
            bn = branch
            device = bn.weight.device
            cached = getattr(self, 'id_tensor', None)
            # Rebuild when missing or when the module moved to another device since
            # the kernel was cached; a stale device would break the product below.
            if cached is None or cached.device != device:
                input_dim = self.in_channels // self.groups
                identity = torch.zeros(self.in_channels, input_dim, 3, 3, dtype=torch.float32, device=device)
                channel = torch.arange(self.in_channels, device=device)
                identity[channel, channel % input_dim, 1, 1] = 1
                self.id_tensor = identity
            kernel = self.id_tensor
        std = (bn.running_var + bn.eps).sqrt()
        t = (bn.weight / std).reshape(-1, 1, 1, 1)
        return kernel * t, bn.bias - bn.running_mean * bn.weight / std

    def switch_to_deploy(self):
        """Replace the training branches by the fused ``rbr_reparam`` conv (no-op if already done)."""
        if hasattr(self, 'rbr_reparam'):
            return
        kernel, bias = self.get_equivalent_kernel_bias()
        dense = self.rbr_dense.conv
        self.rbr_reparam = nn.Conv2d(in_channels=dense.in_channels, out_channels=dense.out_channels,
                                     kernel_size=dense.kernel_size, stride=dense.stride,
                                     padding=dense.padding, dilation=dense.dilation, groups=dense.groups, bias=True)
        self.rbr_reparam.weight.data = kernel
        self.rbr_reparam.bias.data = bias
        self.__delattr__('rbr_dense')
        self.__delattr__('rbr_1x1')
        if hasattr(self, 'rbr_identity'):
            self.__delattr__('rbr_identity')
        if hasattr(self, 'id_tensor'):
            self.__delattr__('id_tensor')
        self.deploy = True
class RepVGG(nn.Module):
    """RepVGG trunk made of ``stage0`` plus three stages of ``RepVGGBlock``s.

    Only ``stage0`` .. ``stage3`` are built; the classification head of the
    reference network (``stage4`` / ``gap`` / ``linear``) is left commented out,
    so ``num_classes`` and the fourth entries of ``num_blocks`` and
    ``width_multiplier`` are accepted but unused.  The stem takes 1-channel input.

    Args:
        num_blocks: number of blocks per stage (indices 0..2 are used).
        num_classes: kept for signature compatibility; unused.
        width_multiplier: four channel multipliers; must have length 4.
        override_groups_map: optional ``{layer_index: groups}``; layer 0 must not appear.
        deploy: build every block in its single-convolution form.
        use_se: forwarded to every block.
        use_checkpoint: wrap each block call in ``torch.utils.checkpoint``.
        leaky: non-linearity selector forwarded to every block.
    """

    def __init__(self, num_blocks, num_classes=1000, width_multiplier=None, override_groups_map=None, deploy=False, use_se=False, use_checkpoint=False, leaky=-1.0):
        super(RepVGG, self).__init__()
        assert len(width_multiplier) == 4
        self.deploy = deploy
        self.override_groups_map = override_groups_map or dict()
        assert 0 not in self.override_groups_map
        self.use_se = use_se
        self.use_checkpoint = use_checkpoint

        self.in_planes = min(64, int(64 * width_multiplier[0]))
        self.stage0 = RepVGGBlock(in_channels=1, out_channels=self.in_planes, kernel_size=3, stride=2, padding=1, deploy=self.deploy, use_se=self.use_se, leaky=leaky)
        self.cur_layer_idx = 1
        self.stage1 = self._make_stage(int(64 * width_multiplier[0]), num_blocks[0], stride=1, leaky=leaky)
        self.stage2 = self._make_stage(int(128 * width_multiplier[1]), num_blocks[1], stride=2, leaky=leaky)
        self.stage3 = self._make_stage(int(256 * width_multiplier[2]), num_blocks[2], stride=2, leaky=leaky)
        # self.stage4 = self._make_stage(int(512 * width_multiplier[3]), num_blocks[3], stride=1)
        # self.gap = nn.AdaptiveAvgPool2d(output_size=1)
        # self.linear = nn.Linear(int(512 * width_multiplier[3]), num_classes)

    def _make_stage(self, planes, num_blocks, stride, leaky=-1.0):
        """Build one stage: the first block uses ``stride``, the remaining ones stride 1."""
        blocks = []
        for block_stride in [stride] + [1] * (num_blocks - 1):
            cur_groups = self.override_groups_map.get(self.cur_layer_idx, 1)
            blocks.append(RepVGGBlock(in_channels=self.in_planes, out_channels=planes, kernel_size=3,
                                      stride=block_stride, padding=1, groups=cur_groups, deploy=self.deploy, use_se=self.use_se, leaky=leaky))
            self.in_planes = planes
            self.cur_layer_idx += 1
        return nn.ModuleList(blocks)

    def forward(self, x):
        """Run ``stage0`` .. ``stage3`` and return the result.

        ``__init__`` does not create ``gap`` / ``linear``, so unconditionally
        calling them (as before) always raised ``AttributeError``.  The stage-3
        feature map is returned instead; the pooled/linear head is applied only
        if both attributes have been attached to the instance.
        """
        out = self.stage0(x)
        for stage in (self.stage1, self.stage2, self.stage3):  # , self.stage4):
            for block in stage:
                if self.use_checkpoint:
                    out = checkpoint.checkpoint(block, out)
                else:
                    out = block(out)

        gap = getattr(self, 'gap', None)
        linear = getattr(self, 'linear', None)
        if gap is None or linear is None:
            return out
        out = gap(out)
        out = out.view(out.size(0), -1)
        return linear(out)
def _build_repvgg_b(width_multiplier, groups_map, deploy, use_checkpoint):
    """Instantiate a B-series RepVGG (block counts 4/6/16/1) with the given widths and group map."""
    return RepVGG(
        num_blocks=[4, 6, 16, 1],
        num_classes=1000,
        width_multiplier=list(width_multiplier),
        override_groups_map=groups_map,
        deploy=deploy,
        use_checkpoint=use_checkpoint,
    )


def create_RepVGG_B1g2(deploy=False, use_checkpoint=False):
    """RepVGG-B1 widths with ``g2_map`` as the group override."""
    return _build_repvgg_b((2, 2, 2, 4), g2_map, deploy, use_checkpoint)


def create_RepVGG_B1g4(deploy=False, use_checkpoint=False):
    """RepVGG-B1 widths with ``g4_map`` as the group override."""
    return _build_repvgg_b((2, 2, 2, 4), g4_map, deploy, use_checkpoint)


def create_RepVGG_B2(deploy=False, use_checkpoint=False):
    """RepVGG-B2 widths without group override."""
    return _build_repvgg_b((2.5, 2.5, 2.5, 5), None, deploy, use_checkpoint)


def create_RepVGG_B2g2(deploy=False, use_checkpoint=False):
    """RepVGG-B2 widths with ``g2_map`` as the group override."""
    return _build_repvgg_b((2.5, 2.5, 2.5, 5), g2_map, deploy, use_checkpoint)


def create_RepVGG_B2g4(deploy=False, use_checkpoint=False):
    """RepVGG-B2 widths with ``g4_map`` as the group override."""
    return _build_repvgg_b((2.5, 2.5, 2.5, 5), g4_map, deploy, use_checkpoint)


def create_RepVGG_B3(deploy=False, use_checkpoint=False):
    """RepVGG-B3 widths without group override."""
    return _build_repvgg_b((3, 3, 3, 5), None, deploy, use_checkpoint)


def create_RepVGG_B3g2(deploy=False, use_checkpoint=False):
    """RepVGG-B3 widths with ``g2_map`` as the group override."""
    return _build_repvgg_b((3, 3, 3, 5), g2_map, deploy, use_checkpoint)


def create_RepVGG_B3g4(deploy=False, use_checkpoint=False):
    """RepVGG-B3 widths with ``g4_map`` as the group override."""
    return _build_repvgg_b((3, 3, 3, 5), g4_map, deploy, use_checkpoint)
# Maps a variant name to the factory that builds it.
func_dict = {
    'RepVGG-A0': create_RepVGG_A0,
    'RepVGG-A1': create_RepVGG_A1,
    'RepVGG-A15': create_RepVGG_A15,
    'RepVGG-A1_leaky': create_RepVGG_A1_leaky,
    'RepVGG-A2': create_RepVGG_A2,
    'RepVGG-B0': create_RepVGG_B0,
    'RepVGG-B1': create_RepVGG_B1,
    'RepVGG-B1g2': create_RepVGG_B1g2,
    'RepVGG-B1g4': create_RepVGG_B1g4,
    'RepVGG-B2': create_RepVGG_B2,
    'RepVGG-B2g2': create_RepVGG_B2g2,
    'RepVGG-B2g4': create_RepVGG_B2g4,
    'RepVGG-B3': create_RepVGG_B3,
    'RepVGG-B3g2': create_RepVGG_B3g2,
    'RepVGG-B3g4': create_RepVGG_B3g4,
    'RepVGG-D2se': create_RepVGG_D2se,  # Updated at April 25, 2021. This is not reported in the CVPR paper.
}


def get_RepVGG_func_by_name(name):
    """Return the factory function registered in ``func_dict`` under ``name``.

    Raises:
        KeyError: if ``name`` is not a registered variant; the message lists the
            registered names.
    """
    try:
        return func_dict[name]
    except KeyError:
        raise KeyError(
            f"Unknown RepVGG variant {name!r}; available: {sorted(func_dict)}") from None
def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution without padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, padding=0, bias=False)


def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)


class BasicBlock(nn.Module):
    """Two-conv residual block.

    The shortcut is the identity when the block changes neither resolution nor
    channel count, and a strided 1x1 conv + BatchNorm projection otherwise.
    Previously the projection was only created for ``stride != 1``, so a block
    with ``stride == 1`` and ``in_planes != planes`` failed in ``forward`` when
    adding tensors with different channel counts.
    """

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        self.conv1 = conv3x3(in_planes, planes, stride)
        self.conv2 = conv3x3(planes, planes)
        self.bn1 = nn.BatchNorm2d(planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)

        if stride == 1 and in_planes == planes:
            self.downsample = None
        else:
            self.downsample = nn.Sequential(
                conv1x1(in_planes, planes, stride=stride),
                nn.BatchNorm2d(planes)
            )

    def forward(self, x):
        """Return ``relu(shortcut(x) + residual(x))``."""
        y = x
        y = self.relu(self.bn1(self.conv1(y)))
        y = self.bn2(self.conv2(y))

        if self.downsample is not None:
            x = self.downsample(x)

        return self.relu(x+y)


class ResNetFPN_8_2(nn.Module):
    """
    ResNet+FPN, output resolution are 1/8 and 1/2.
    Each block has 2 layers.

    ``config`` must provide ``initial_dim`` (stem width) and ``block_dims``
    (widths of the three stages).  The stem expects 1-channel input.
    """

    def __init__(self, config):
        super().__init__()
        # Config
        block = BasicBlock
        initial_dim = config['initial_dim']
        block_dims = config['block_dims']

        # Class Variable
        self.in_planes = initial_dim

        # Networks
        self.conv1 = nn.Conv2d(1, initial_dim, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(initial_dim)
        self.relu = nn.ReLU(inplace=True)

        self.layer1 = self._make_layer(block, block_dims[0], stride=1)  # 1/2
        self.layer2 = self._make_layer(block, block_dims[1], stride=2)  # 1/4
        self.layer3 = self._make_layer(block, block_dims[2], stride=2)  # 1/8

        # 3. FPN upsample
        self.layer3_outconv = conv1x1(block_dims[2], block_dims[2])
        self.layer2_outconv = conv1x1(block_dims[1], block_dims[2])
        self.layer2_outconv2 = nn.Sequential(
            conv3x3(block_dims[2], block_dims[2]),
            nn.BatchNorm2d(block_dims[2]),
            nn.LeakyReLU(),
            conv3x3(block_dims[2], block_dims[1]),
        )
        self.layer1_outconv = conv1x1(block_dims[0], block_dims[1])
        self.layer1_outconv2 = nn.Sequential(
            conv3x3(block_dims[1], block_dims[1]),
            nn.BatchNorm2d(block_dims[1]),
            nn.LeakyReLU(),
            conv3x3(block_dims[1], block_dims[0]),
        )

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _make_layer(self, block, dim, stride=1):
        """Stack two ``block`` instances; only the first one applies ``stride``."""
        layer1 = block(self.in_planes, dim, stride=stride)
        layer2 = block(dim, dim, stride=1)
        layers = (layer1, layer2)

        self.in_planes = dim
        return nn.Sequential(*layers)

    def _backbone_features(self, x):
        """Run the ResNet trunk and return the 1/2, 1/4 and 1/8 feature maps."""
        x0 = self.relu(self.bn1(self.conv1(x)))
        x1 = self.layer1(x0)  # 1/2
        x2 = self.layer2(x1)  # 1/4
        x3 = self.layer3(x2)  # 1/8
        return x1, x2, x3

    def _fpn(self, x1, x2, x3):
        """Fuse the three trunk levels top-down and return the coarse/fine feature dict."""
        x3_out = self.layer3_outconv(x3)

        x3_out_2x = F.interpolate(x3_out, scale_factor=2., mode='bilinear', align_corners=True)
        x2_out = self.layer2_outconv(x2)
        x2_out = self.layer2_outconv2(x2_out+x3_out_2x)

        x2_out_2x = F.interpolate(x2_out, scale_factor=2., mode='bilinear', align_corners=True)
        x1_out = self.layer1_outconv(x1)
        x1_out = self.layer1_outconv2(x1_out+x2_out_2x)

        return {'feats_c': x3_out, 'feats_f': x1_out}

    def forward(self, x):
        """Return ``{'feats_c': 1/8 features, 'feats_f': 1/2 features}`` for input ``x``."""
        return self._fpn(*self._backbone_features(x))

    def pro(self, x, profiler):
        """Same computation as ``forward``, with trunk and FPN timed separately via ``profiler.profile``."""
        with profiler.profile('ResNet Backbone'):
            x1, x2, x3 = self._backbone_features(x)

        with profiler.profile('ResNet FPN'):
            feats = self._fpn(x1, x2, x3)

        return feats
class ResNetFPN_16_4(nn.Module):
    """
    ResNet+FPN, output resolution are 1/16 and 1/4.
    Each block has 2 layers.

    ``config`` must provide ``initial_dim`` (stem width) and ``block_dims``
    (widths of the four stages).  The stem expects 1-channel input.
    """

    def __init__(self, config):
        super().__init__()
        stem_dim = config['initial_dim']
        dims = config['block_dims']
        dim1 = dims[0]
        dim2 = dims[1]
        dim3 = dims[2]
        dim4 = dims[3]

        # Running input width consumed and updated by ``_make_layer``.
        self.in_planes = stem_dim

        # Stem
        self.conv1 = nn.Conv2d(1, stem_dim, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(stem_dim)
        self.relu = nn.ReLU(inplace=True)

        # Trunk
        self.layer1 = self._make_layer(BasicBlock, dim1, stride=1)  # 1/2
        self.layer2 = self._make_layer(BasicBlock, dim2, stride=2)  # 1/4
        self.layer3 = self._make_layer(BasicBlock, dim3, stride=2)  # 1/8
        self.layer4 = self._make_layer(BasicBlock, dim4, stride=2)  # 1/16

        # FPN lateral / fusion convolutions
        self.layer4_outconv = conv1x1(dim4, dim4)
        self.layer3_outconv = conv1x1(dim3, dim4)
        self.layer3_outconv2 = nn.Sequential(
            conv3x3(dim4, dim4),
            nn.BatchNorm2d(dim4),
            nn.LeakyReLU(),
            conv3x3(dim4, dim3),
        )

        self.layer2_outconv = conv1x1(dim2, dim3)
        self.layer2_outconv2 = nn.Sequential(
            conv3x3(dim3, dim3),
            nn.BatchNorm2d(dim3),
            nn.LeakyReLU(),
            conv3x3(dim3, dim2),
        )

        # Kaiming init for convolutions, unit scale / zero shift for norm layers.
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
                continue
            if isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)

    def _make_layer(self, block, dim, stride=1):
        """Stack two ``block`` instances; only the first one applies ``stride``."""
        first = block(self.in_planes, dim, stride=stride)
        second = block(dim, dim, stride=1)
        self.in_planes = dim
        return nn.Sequential(first, second)

    def forward(self, x):
        """Return ``{'feats_c': 1/16 features, 'feats_f': 1/4 features}`` for input ``x``."""
        # ResNet trunk
        stem = self.relu(self.bn1(self.conv1(x)))
        c1 = self.layer1(stem)  # 1/2
        c2 = self.layer2(c1)    # 1/4
        c3 = self.layer3(c2)    # 1/8
        c4 = self.layer4(c3)    # 1/16

        # Top-down FPN: upsample the coarser level, add the lateral, then refine.
        p4 = self.layer4_outconv(c4)

        p4_up = F.interpolate(p4, scale_factor=2., mode='bilinear', align_corners=True)
        p3 = self.layer3_outconv2(self.layer3_outconv(c3) + p4_up)

        p3_up = F.interpolate(p3, scale_factor=2., mode='bilinear', align_corners=True)
        p2 = self.layer2_outconv2(self.layer2_outconv(c2) + p3_up)

        return {'feats_c': p4, 'feats_f': p2}
FPN upsample + if not self.skip_fine_feature: + self.layer3_outconv = conv1x1(block_dims[2], block_dims[2]) + self.layer2_outconv = conv1x1(block_dims[1], block_dims[2]) + self.layer2_outconv2 = nn.Sequential( + conv3x3(block_dims[2], block_dims[2]), + nn.BatchNorm2d(block_dims[2]), + nn.LeakyReLU(), + conv3x3(block_dims[2], block_dims[1]), + ) + self.layer1_outconv = conv1x1(block_dims[0], block_dims[1]) + self.layer1_outconv2 = nn.Sequential( + conv3x3(block_dims[1], block_dims[1]), + nn.BatchNorm2d(block_dims[1]), + nn.LeakyReLU(), + conv3x3(block_dims[1], block_dims[0]), + ) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + def _make_layer(self, block, dim, stride=1): + layer1 = block(self.in_planes, dim, stride=stride) + layer2 = block(dim, dim, stride=1) + layers = (layer1, layer2) + + self.in_planes = dim + return nn.Sequential(*layers) + + def forward(self, x): + # ResNet Backbone + x0 = self.relu(self.bn1(self.conv1(x))) + x1 = self.layer1(x0) # 1/2 + x2 = self.layer2(x1) # 1/4 + x3 = self.layer3(x2) # 1/8 + + # FPN + if self.skip_fine_feature: + if self.inter_feat: + return {'feats_c': x3, 'feats_f': None, 'feats_x2': x2, 'feats_x1': x1} + else: + return {'feats_c': x3, 'feats_f': None} + + x3_out = self.layer3_outconv(x3) + + x3_out_2x = F.interpolate(x3_out, scale_factor=2., mode='bilinear', align_corners=True) + x2_out = self.layer2_outconv(x2) + x2_out = self.layer2_outconv2(x2_out+x3_out_2x) + + x2_out_2x = F.interpolate(x2_out, scale_factor=2., mode='bilinear', align_corners=True) + x1_out = self.layer1_outconv(x1) + x1_out = self.layer1_outconv2(x1_out+x2_out_2x) + + x0_out = F.interpolate(x1_out, scale_factor=2., mode='bilinear', align_corners=False) + + if not self.inter_feat: + return {'feats_c': x3, 'feats_f': x0_out} + else: + return 
{'feats_c': x3, 'feats_f': x0_out, 'feats_x2': x2, 'feats_x1': x1} + + +class ResNetFPN_8_1_align(nn.Module): + """ + ResNet+FPN, output resolution are 1/8 and 1. + Each block has 2 layers. + """ + + def __init__(self, config): + super().__init__() + # Config + block = BasicBlock + initial_dim = config['initial_dim'] + block_dims = config['block_dims'] + + # Class Variable + self.in_planes = initial_dim + self.skip_fine_feature = config['coarse_feat_only'] + self.inter_feat = config['inter_feat'] + # Networks + self.conv1 = nn.Conv2d(1, initial_dim, kernel_size=7, stride=2, padding=3, bias=False) + self.bn1 = nn.BatchNorm2d(initial_dim) + self.relu = nn.ReLU(inplace=True) + + self.layer1 = self._make_layer(block, block_dims[0], stride=1) # 1/2 + self.layer2 = self._make_layer(block, block_dims[1], stride=2) # 1/4 + self.layer3 = self._make_layer(block, block_dims[2], stride=2) # 1/8 + + # 3. FPN upsample + if not self.skip_fine_feature: + self.layer3_outconv = conv1x1(block_dims[2], block_dims[2]) + self.layer2_outconv = conv1x1(block_dims[1], block_dims[2]) + self.layer2_outconv2 = nn.Sequential( + conv3x3(block_dims[2], block_dims[2]), + nn.BatchNorm2d(block_dims[2]), + nn.LeakyReLU(), + conv3x3(block_dims[2], block_dims[1]), + ) + self.layer1_outconv = conv1x1(block_dims[0], block_dims[1]) + self.layer1_outconv2 = nn.Sequential( + conv3x3(block_dims[1], block_dims[1]), + nn.BatchNorm2d(block_dims[1]), + nn.LeakyReLU(), + conv3x3(block_dims[1], block_dims[0]), + ) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + def _make_layer(self, block, dim, stride=1): + layer1 = block(self.in_planes, dim, stride=stride) + layer2 = block(dim, dim, stride=1) + layers = (layer1, layer2) + + self.in_planes = dim + return nn.Sequential(*layers) + + def forward(self, x): + # 
ResNet Backbone + x0 = self.relu(self.bn1(self.conv1(x))) + x1 = self.layer1(x0) # 1/2 + x2 = self.layer2(x1) # 1/4 + x3 = self.layer3(x2) # 1/8 + + # FPN + + if self.skip_fine_feature: + if self.inter_feat: + return {'feats_c': x3, 'feats_f': None, 'feats_x2': x2, 'feats_x1': x1} + else: + return {'feats_c': x3, 'feats_f': None} + + x3_out = self.layer3_outconv(x3) + x3_out_2x = F.interpolate(x3_out, scale_factor=2., mode='bilinear', align_corners=False) + x2_out = self.layer2_outconv(x2) + x2_out = self.layer2_outconv2(x2_out+x3_out_2x) + + x2_out_2x = F.interpolate(x2_out, scale_factor=2., mode='bilinear', align_corners=False) + x1_out = self.layer1_outconv(x1) + x1_out = self.layer1_outconv2(x1_out+x2_out_2x) + + x0_out = F.interpolate(x1_out, scale_factor=2., mode='bilinear', align_corners=False) + + if not self.inter_feat: + return {'feats_c': x3, 'feats_f': x0_out} + else: + return {'feats_c': x3, 'feats_f': x0_out, 'feats_x2': x2, 'feats_x1': x1} + + def pro(self, x, profiler): + with profiler.profile('ResNet Backbone'): + # ResNet Backbone + x0 = self.relu(self.bn1(self.conv1(x))) + x1 = self.layer1(x0) # 1/2 + x2 = self.layer2(x1) # 1/4 + x3 = self.layer3(x2) # 1/8 + + with profiler.profile('FPN'): + # FPN + x3_out = self.layer3_outconv(x3) + + if self.skip_fine: + return [x3_out, None] + x3_out_2x = F.interpolate(x3_out, scale_factor=2., mode='bilinear', align_corners=False) + x2_out = self.layer2_outconv(x2) + x2_out = self.layer2_outconv2(x2_out+x3_out_2x) + + x2_out_2x = F.interpolate(x2_out, scale_factor=2., mode='bilinear', align_corners=False) + x1_out = self.layer1_outconv(x1) + x1_out = self.layer1_outconv2(x1_out+x2_out_2x) + + with profiler.profile('upsample*1'): + x0_out = F.interpolate(x1_out, scale_factor=2., mode='bilinear', align_corners=False) + + return {'feats_c': x3_out, 'feats_f': x0_out} + + +class ResNetFPN_8_2_align(nn.Module): + """ + ResNet+FPN, output resolution are 1/8 and 1/2. + Each block has 2 layers. 
+ """ + + def __init__(self, config): + super().__init__() + # Config + block = BasicBlock + initial_dim = config['initial_dim'] + block_dims = config['block_dims'] + + # Class Variable + self.in_planes = initial_dim + self.skip_fine_feature = config['coarse_feat_only'] + self.inter_feat = config['inter_feat'] + # Networks + self.conv1 = nn.Conv2d(1, initial_dim, kernel_size=7, stride=2, padding=3, bias=False) + self.bn1 = nn.BatchNorm2d(initial_dim) + self.relu = nn.ReLU(inplace=True) + + self.layer1 = self._make_layer(block, block_dims[0], stride=1) # 1/2 + self.layer2 = self._make_layer(block, block_dims[1], stride=2) # 1/4 + self.layer3 = self._make_layer(block, block_dims[2], stride=2) # 1/8 + + # 3. FPN upsample + if not self.skip_fine_feature: + self.layer3_outconv = conv1x1(block_dims[2], block_dims[2]) + self.layer2_outconv = conv1x1(block_dims[1], block_dims[2]) + self.layer2_outconv2 = nn.Sequential( + conv3x3(block_dims[2], block_dims[2]), + nn.BatchNorm2d(block_dims[2]), + nn.LeakyReLU(), + conv3x3(block_dims[2], block_dims[1]), + ) + self.layer1_outconv = conv1x1(block_dims[0], block_dims[1]) + self.layer1_outconv2 = nn.Sequential( + conv3x3(block_dims[1], block_dims[1]), + nn.BatchNorm2d(block_dims[1]), + nn.LeakyReLU(), + conv3x3(block_dims[1], block_dims[0]), + ) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + def _make_layer(self, block, dim, stride=1): + layer1 = block(self.in_planes, dim, stride=stride) + layer2 = block(dim, dim, stride=1) + layers = (layer1, layer2) + + self.in_planes = dim + return nn.Sequential(*layers) + + def forward(self, x): + # ResNet Backbone + x0 = self.relu(self.bn1(self.conv1(x))) + x1 = self.layer1(x0) # 1/2 + x2 = self.layer2(x1) # 1/4 + x3 = self.layer3(x2) # 1/8 + + if self.skip_fine_feature: + if 
self.inter_feat: + return {'feats_c': x3, 'feats_f': None, 'feats_x2': x2, 'feats_x1': x1} + else: + return {'feats_c': x3, 'feats_f': None} + + # FPN + x3_out = self.layer3_outconv(x3) + + x3_out_2x = F.interpolate(x3_out, scale_factor=2., mode='bilinear', align_corners=False) + x2_out = self.layer2_outconv(x2) + x2_out = self.layer2_outconv2(x2_out+x3_out_2x) + + x2_out_2x = F.interpolate(x2_out, scale_factor=2., mode='bilinear', align_corners=False) + x1_out = self.layer1_outconv(x1) + x1_out = self.layer1_outconv2(x1_out+x2_out_2x) + + if not self.inter_feat: + return {'feats_c': x3, 'feats_f': x1_out} + else: + return {'feats_c': x3, 'feats_f': x1_out, 'feats_x2': x2, 'feats_x1': x1} + + +class ResNet_8_1_align(nn.Module): + """ + ResNet, output resolution are 1/8 and 1. + Each block has 2 layers. + """ + + def __init__(self, config): + super().__init__() + # Config + block = BasicBlock + initial_dim = config['initial_dim'] + block_dims = config['block_dims'] + + # Class Variable + self.in_planes = initial_dim + + # Networks + self.conv1 = nn.Conv2d(1, initial_dim, kernel_size=7, stride=2, padding=3, bias=False) + self.bn1 = nn.BatchNorm2d(initial_dim) + self.relu = nn.ReLU(inplace=True) + + self.layer1 = self._make_layer(block, block_dims[0], stride=1) # 1/2 + self.layer2 = self._make_layer(block, block_dims[1], stride=2) # 1/4 + self.layer3 = self._make_layer(block, block_dims[2], stride=2) # 1/8 + + # 3. 
FPN upsample + # self.layer3_outconv = conv1x1(block_dims[2], block_dims[2]) + # self.layer2_outconv = conv1x1(block_dims[1], block_dims[2]) + # self.layer2_outconv2 = nn.Sequential( + # conv3x3(block_dims[2], block_dims[2]), + # nn.BatchNorm2d(block_dims[2]), + # nn.LeakyReLU(), + # conv3x3(block_dims[2], block_dims[1]), + # ) + # self.layer1_outconv = conv1x1(block_dims[0], block_dims[1]) + # self.layer1_outconv2 = nn.Sequential( + # conv3x3(block_dims[1], block_dims[1]), + # nn.BatchNorm2d(block_dims[1]), + # nn.LeakyReLU(), + # conv3x3(block_dims[1], block_dims[0]), + # ) + self.layer0_outconv = conv1x1(block_dims[2], block_dims[0]) + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + def _make_layer(self, block, dim, stride=1): + layer1 = block(self.in_planes, dim, stride=stride) + layer2 = block(dim, dim, stride=1) + layers = (layer1, layer2) + + self.in_planes = dim + return nn.Sequential(*layers) + + def forward(self, x): + # ResNet Backbone + x0 = self.relu(self.bn1(self.conv1(x))) + x1 = self.layer1(x0) # 1/2 + x2 = self.layer2(x1) # 1/4 + x3 = self.layer3(x2) # 1/8 + + # FPN + # x3_out = self.layer3_outconv(x3) + + # x3_out_2x = F.interpolate(x3_out, scale_factor=2., mode='bilinear', align_corners=False) + # x2_out = self.layer2_outconv(x2) + # x2_out = self.layer2_outconv2(x2_out+x3_out_2x) + + # x2_out_2x = F.interpolate(x2_out, scale_factor=2., mode='bilinear', align_corners=False) + # x1_out = self.layer1_outconv(x1) + # x1_out = self.layer1_outconv2(x1_out+x2_out_2x) + + x0_out = F.interpolate(x3, scale_factor=8., mode='bilinear', align_corners=False) + x0_out = self.layer0_outconv(x0_out) + + return {'feats_c': x3, 'feats_f': x0_out} + +class VGG_8_1_align(nn.Module): + """ + VGG-like backbone, output resolution are 1/8 and 1. 
+ Each block has 2 layers. + """ + + def __init__(self, config): + super().__init__() + # Config + block = BasicBlock + initial_dim = config['initial_dim'] + block_dims = config['block_dims'] + + self.relu = nn.ReLU(inplace=True) + self.pool = nn.MaxPool2d(kernel_size=2, stride=2) + c1, c2, c3, c4, c5 = 64, 64, 128, 128, 256 + + self.conv1a = nn.Conv2d(1, c1, kernel_size=3, stride=1, padding=1) + self.conv1b = nn.Conv2d(c1, c1, kernel_size=3, stride=1, padding=1) + self.conv2a = nn.Conv2d(c1, c2, kernel_size=3, stride=1, padding=1) + self.conv2b = nn.Conv2d(c2, c2, kernel_size=3, stride=1, padding=1) + self.conv3a = nn.Conv2d(c2, c3, kernel_size=3, stride=1, padding=1) + self.conv3b = nn.Conv2d(c3, c3, kernel_size=3, stride=1, padding=1) + self.conv4a = nn.Conv2d(c3, c4, kernel_size=3, stride=1, padding=1) + self.conv4b = nn.Conv2d(c4, c4, kernel_size=3, stride=1, padding=1) + + # self.convPa = nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1) + # self.convPb = nn.Conv2d(c5, 65, kernel_size=1, stride=1, padding=0) + + self.convDa = nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1) + self.convDb = nn.Conv2d( + c5, 256, + kernel_size=1, stride=1, padding=0) + self.layer0_outconv = conv1x1(block_dims[2], block_dims[0]) + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + def _make_layer(self, block, dim, stride=1): + layer1 = block(self.in_planes, dim, stride=stride) + layer2 = block(dim, dim, stride=1) + layers = (layer1, layer2) + + self.in_planes = dim + return nn.Sequential(*layers) + + def forward(self, x): + # Shared Encoder + x = self.relu(self.conv1a(x)) + x = self.relu(self.conv1b(x)) + x = self.pool(x) + x = self.relu(self.conv2a(x)) + x = self.relu(self.conv2b(x)) + x = self.pool(x) + x = self.relu(self.conv3a(x)) + x = self.relu(self.conv3b(x)) + x = 
self.pool(x) + x = self.relu(self.conv4a(x)) + x = self.relu(self.conv4b(x)) + + cDa = self.relu(self.convDa(x)) + descriptors = self.convDb(cDa) + x3_out = nn.functional.normalize(descriptors, p=2, dim=1) + + x0_out = F.interpolate(x3_out, scale_factor=8., mode='bilinear', align_corners=False) + x0_out = self.layer0_outconv(x0_out) + # ResNet Backbone + # x0 = self.relu(self.bn1(self.conv1(x))) + # x1 = self.layer1(x0) # 1/2 + # x2 = self.layer2(x1) # 1/4 + # x3 = self.layer3(x2) # 1/8 + + # # FPN + # x3_out = self.layer3_outconv(x3) + + # x3_out_2x = F.interpolate(x3_out, scale_factor=2., mode='bilinear', align_corners=False) + # x2_out = self.layer2_outconv(x2) + # x2_out = self.layer2_outconv2(x2_out+x3_out_2x) + + # x2_out_2x = F.interpolate(x2_out, scale_factor=2., mode='bilinear', align_corners=False) + # x1_out = self.layer1_outconv(x1) + # x1_out = self.layer1_outconv2(x1_out+x2_out_2x) + + # x0_out = F.interpolate(x1_out, scale_factor=2., mode='bilinear', align_corners=False) + + return {'feats_c': x3_out, 'feats_f': x0_out} + +class RepVGG_8_1_align(nn.Module): + """ + RepVGG backbone, output resolution are 1/8 and 1. + Each block has 2 layers. 
+ """ + + def __init__(self, config): + super().__init__() + # Config + # block = BasicBlock + # initial_dim = config['initial_dim'] + block_dims = config['block_dims'] + self.skip_fine_feature = config['coarse_feat_only'] + self.inter_feat = config['inter_feat'] + self.leaky = config['leaky'] + + # backbone_name='RepVGG-B0' + if config.get('repvggmodel') is not None: + backbone_name=config['repvggmodel'] + elif self.leaky: + backbone_name='RepVGG-A1_leaky' + else: + backbone_name='RepVGG-A1' + repvgg_fn = get_RepVGG_func_by_name(backbone_name) + backbone = repvgg_fn(False) + self.layer0, self.layer1, self.layer2, self.layer3 = backbone.stage0, backbone.stage1, backbone.stage2, backbone.stage3 #, backbone.stage4 + # self.layer0, self.layer1, self.layer2, self.layer3, self.layer4 = backbone.stage0, backbone.stage1, backbone.stage2, backbone.stage3, backbone.stage4 + + # 3. FPN upsample + if not self.skip_fine_feature: + self.layer3_outconv = conv1x1(block_dims[2], block_dims[2]) + self.layer2_outconv = conv1x1(block_dims[1], block_dims[2]) + self.layer2_outconv2 = nn.Sequential( + conv3x3(block_dims[2], block_dims[2]), + nn.BatchNorm2d(block_dims[2]), + nn.LeakyReLU(), + conv3x3(block_dims[2], block_dims[1]), + ) + self.layer1_outconv = conv1x1(block_dims[0], block_dims[1]) + self.layer1_outconv2 = nn.Sequential( + conv3x3(block_dims[1], block_dims[1]), + nn.BatchNorm2d(block_dims[1]), + nn.LeakyReLU(), + conv3x3(block_dims[1], block_dims[0]), + ) + + # self.layer0_outconv = conv1x1(192, 48) + + for layer in [self.layer0, self.layer1, self.layer2, self.layer3]: + for m in layer.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + # for layer in [self.layer0, self.layer1, self.layer2, self.layer3, self.layer4]: + # for m in layer.modules(): + # if isinstance(m, nn.Conv2d): + # 
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + # elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + # nn.init.constant_(m.weight, 1) + # nn.init.constant_(m.bias, 0) + + def forward(self, x): + + out = self.layer0(x) # 1/2 + for module in self.layer1: + out = module(out) # 1/2 + x1 = out + for module in self.layer2: + out = module(out) # 1/4 + x2 = out + for module in self.layer3: + out = module(out) # 1/8 + x3 = out + # for module in self.layer4: + # out = module(out) + # x3 = out + + if self.skip_fine_feature: + if self.inter_feat: + return {'feats_c': x3, 'feats_f': None, 'feats_x2': x2, 'feats_x1': x1} + else: + return {'feats_c': x3, 'feats_f': None} + x3_out = self.layer3_outconv(x3) + x3_out_2x = F.interpolate(x3_out, scale_factor=2., mode='bilinear', align_corners=False) + x2_out = self.layer2_outconv(x2) + x2_out = self.layer2_outconv2(x2_out+x3_out_2x) + + x2_out_2x = F.interpolate(x2_out, scale_factor=2., mode='bilinear', align_corners=False) + x1_out = self.layer1_outconv(x1) + x1_out = self.layer1_outconv2(x1_out+x2_out_2x) + + x0_out = F.interpolate(x1_out, scale_factor=2., mode='bilinear', align_corners=False) + + # x_f = F.interpolate(x_c, scale_factor=8., mode='bilinear', align_corners=False) + # x_f = self.layer0_outconv(x_f) + return {'feats_c': x3_out, 'feats_f': x0_out} + + +class RepVGG_8_2_fix(nn.Module): + """ + RepVGG backbone, output resolution are 1/8 and 1. + Each block has 2 layers. 
+ """ + + def __init__(self, config): + super().__init__() + # Config + # block = BasicBlock + # initial_dim = config['initial_dim'] + block_dims = config['block_dims'] + self.skip_fine_feature = config['coarse_feat_only'] + self.inter_feat = config['inter_feat'] + + # backbone_name='RepVGG-B0' + backbone_name='RepVGG-A1' + repvgg_fn = get_RepVGG_func_by_name(backbone_name) + backbone = repvgg_fn(False) + self.layer0, self.layer1, self.layer2, self.layer3 = backbone.stage0, backbone.stage1, backbone.stage2, backbone.stage3 #, backbone.stage4 + + # 3. FPN upsample + if not self.skip_fine_feature: + self.layer3_outconv = conv1x1(block_dims[2], block_dims[2]) + self.layer2_outconv = conv1x1(block_dims[1], block_dims[2]) + self.layer2_outconv2 = nn.Sequential( + conv3x3(block_dims[2], block_dims[2]), + nn.BatchNorm2d(block_dims[2]), + nn.LeakyReLU(), + conv3x3(block_dims[2], block_dims[1]), + ) + self.layer1_outconv = conv1x1(block_dims[0], block_dims[1]) + self.layer1_outconv2 = nn.Sequential( + conv3x3(block_dims[1], block_dims[1]), + nn.BatchNorm2d(block_dims[1]), + nn.LeakyReLU(), + conv3x3(block_dims[1], block_dims[0]), + ) + + # self.layer0_outconv = conv1x1(192, 48) + + for layer in [self.layer0, self.layer1, self.layer2, self.layer3]: + for m in layer.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + def forward(self, x): + + x0 = self.layer0(x) # 1/2 + out = x0 + for module in self.layer1: + out = module(out) # 1/2 + x1 = out + for module in self.layer2: + out = module(out) # 1/4 + x2 = out + for module in self.layer3: + out = module(out) # 1/8 + x3 = out + # for module in self.layer4: + # out = module(out) + + if self.skip_fine_feature: + if self.inter_feat: + return {'feats_c': x3, 'feats_f': None, 'feats_x2': x2, 'feats_x1': x1} + else: + return {'feats_c': x3, 'feats_f': 
None} + x3_out = self.layer3_outconv(x3) + x3_out_2x = F.interpolate(x3_out, size=((x3_out.size(-2)-1)*2+1, (x3_out.size(-1)-1)*2+1), mode='bilinear', align_corners=True) + x2_out = self.layer2_outconv(x2) + x2_out = self.layer2_outconv2(x2_out+x3_out_2x) + + x2_out_2x = F.interpolate(x2_out, size=((x2_out.size(-2)-1)*2+1, (x2_out.size(-1)-1)*2+1), mode='bilinear', align_corners=True) + x1_out = self.layer1_outconv(x1) + x1_out = self.layer1_outconv2(x1_out+x2_out_2x) + + # x0_out = F.interpolate(x1_out, scale_factor=2., mode='bilinear', align_corners=False) + + # x_f = F.interpolate(x_c, scale_factor=8., mode='bilinear', align_corners=False) + # x_f = self.layer0_outconv(x_f) + return {'feats_c': x3_out, 'feats_f': x1_out} + + +class RepVGGnfpn_8_1_align(nn.Module): + """ + RepVGG backbone, output resolution are 1/8 and 1. + Each block has 2 layers. + """ + + def __init__(self, config): + super().__init__() + # Config + # block = BasicBlock + # initial_dim = config['initial_dim'] + block_dims = config['block_dims'] + self.skip_fine_feature = config['coarse_feat_only'] + self.inter_feat = config['inter_feat'] + + # backbone_name='RepVGG-B0' + backbone_name='RepVGG-A1' + repvgg_fn = get_RepVGG_func_by_name(backbone_name) + backbone = repvgg_fn(False) + self.layer0, self.layer1, self.layer2, self.layer3 = backbone.stage0, backbone.stage1, backbone.stage2, backbone.stage3 #, backbone.stage4 + + # 3. 
FPN upsample + if not self.skip_fine_feature: + self.layer0_outconv = conv1x1(block_dims[2], block_dims[0]) + # self.layer0_outconv = conv1x1(192, 48) + + for layer in [self.layer0, self.layer1, self.layer2, self.layer3]: + for m in layer.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + def forward(self, x): + + x0 = self.layer0(x) # 1/2 + out = x0 + for module in self.layer1: + out = module(out) # 1/2 + x1 = out + for module in self.layer2: + out = module(out) # 1/4 + x2 = out + for module in self.layer3: + out = module(out) # 1/8 + x3 = out + # for module in self.layer4: + # out = module(out) + + if self.skip_fine_feature: + if self.inter_feat: + return {'feats_c': x3, 'feats_f': None, 'feats_x2': x2, 'feats_x1': x1} + else: + return {'feats_c': x3, 'feats_f': None} + # x3_out = self.layer3_outconv(x3) + # x3_out_2x = F.interpolate(x3_out, scale_factor=2., mode='bilinear', align_corners=False) + # x2_out = self.layer2_outconv(x2) + # x2_out = self.layer2_outconv2(x2_out+x3_out_2x) + + # x2_out_2x = F.interpolate(x2_out, scale_factor=2., mode='bilinear', align_corners=False) + # x1_out = self.layer1_outconv(x1) + # x1_out = self.layer1_outconv2(x1_out+x2_out_2x) + + # x0_out = F.interpolate(x1_out, scale_factor=2., mode='bilinear', align_corners=False) + + x_f = F.interpolate(x3, scale_factor=8., mode='bilinear', align_corners=False) + x_f = self.layer0_outconv(x_f) + # x_f2 = F.interpolate(x3, scale_factor=8., mode='bilinear', align_corners=False) + # x_f2 = self.layer0_outconv(x_f2) + return {'feats_c': x3, 'feats_f': x_f} + + +class s2dnet_8_1_align(nn.Module): + """ + ResNet+FPN, output resolution are 1/8 and 1. + Each block has 2 layers. 
+ """ + + def __init__(self, config): + super().__init__() + # Config + block = BasicBlock + initial_dim = config['initial_dim'] + block_dims = config['block_dims'] + + # Class Variable + self.in_planes = initial_dim + self.skip_fine_feature = config['coarse_feat_only'] + self.inter_feat = config['inter_feat'] + # Networks + self.backbone = S2DNet(checkpoint_path = '/cephfs-mvs/3dv-research/hexingyi/code_yf/loftrdev/weights/s2dnet/s2dnet_weights.pth') + # 3. FPN upsample + # self.layer3_outconv = conv1x1(block_dims[2], block_dims[2]) + # if not self.skip_fine_feature: + # self.layer2_outconv = conv1x1(block_dims[1], block_dims[2]) + # self.layer2_outconv2 = nn.Sequential( + # conv3x3(block_dims[2], block_dims[2]), + # nn.BatchNorm2d(block_dims[2]), + # nn.LeakyReLU(), + # conv3x3(block_dims[2], block_dims[1]), + # ) + # self.layer1_outconv = conv1x1(block_dims[0], block_dims[1]) + # self.layer1_outconv2 = nn.Sequential( + # conv3x3(block_dims[1], block_dims[1]), + # nn.BatchNorm2d(block_dims[1]), + # nn.LeakyReLU(), + # conv3x3(block_dims[1], block_dims[0]), + # ) + + # for m in self.modules(): + # if isinstance(m, nn.Conv2d): + # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + # elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + # nn.init.constant_(m.weight, 1) + # nn.init.constant_(m.bias, 0) + + def forward(self, x): + ret = self.backbone(x) + ret[2] = F.interpolate(ret[2], scale_factor=2., mode='bilinear', align_corners=False) + if self.skip_fine_feature: + if self.inter_feat: + return {'feats_c': ret[2], 'feats_f': None, 'feats_x2': ret[1], 'feats_x1': ret[0]} + else: + return {'feats_c': ret[2], 'feats_f': None,} + + def pro(self, x, profiler): + pass \ No newline at end of file diff --git a/imcui/third_party/MatchAnything/src/loftr/backbone/s2dnet.py b/imcui/third_party/MatchAnything/src/loftr/backbone/s2dnet.py new file mode 100644 index 0000000000000000000000000000000000000000..2e4c2eb8a61a5193405ea86b7e67c11a19fa94f7 --- 
/dev/null +++ b/imcui/third_party/MatchAnything/src/loftr/backbone/s2dnet.py @@ -0,0 +1,131 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +# from torchvision import models +from typing import List, Dict + +# VGG-16 Layer Names and Channels +vgg16_layers = { + "conv1_1": 64, + "relu1_1": 64, + "conv1_2": 64, + "relu1_2": 64, + "pool1": 64, + "conv2_1": 128, + "relu2_1": 128, + "conv2_2": 128, + "relu2_2": 128, + "pool2": 128, + "conv3_1": 256, + "relu3_1": 256, + "conv3_2": 256, + "relu3_2": 256, + "conv3_3": 256, + "relu3_3": 256, + "pool3": 256, + "conv4_1": 512, + "relu4_1": 512, + "conv4_2": 512, + "relu4_2": 512, + "conv4_3": 512, + "relu4_3": 512, + "pool4": 512, + "conv5_1": 512, + "relu5_1": 512, + "conv5_2": 512, + "relu5_2": 512, + "conv5_3": 512, + "relu5_3": 512, + "pool5": 512, +} + +class AdapLayers(nn.Module): + """Small adaptation layers. + """ + + def __init__(self, hypercolumn_layers: List[str], output_dim: int = 128): + """Initialize one adaptation layer for every extraction point. + + Args: + hypercolumn_layers: The list of the hypercolumn layer names. + output_dim: The output channel dimension. + """ + super(AdapLayers, self).__init__() + self.layers = [] + channel_sizes = [vgg16_layers[name] for name in hypercolumn_layers] + for i, l in enumerate(channel_sizes): + layer = nn.Sequential( + nn.Conv2d(l, 64, kernel_size=1, stride=1, padding=0), + nn.ReLU(), + nn.Conv2d(64, output_dim, kernel_size=5, stride=1, padding=2), + nn.BatchNorm2d(output_dim), + ) + self.layers.append(layer) + self.add_module("adap_layer_{}".format(i), layer) + + def forward(self, features: List[torch.tensor]): + """Apply adaptation layers. 
+ """ + for i, _ in enumerate(features): + features[i] = getattr(self, "adap_layer_{}".format(i))(features[i]) + return features + +class S2DNet(nn.Module): + """The S2DNet model + """ + + def __init__( + self, + # hypercolumn_layers: List[str] = ["conv2_2", "conv3_3", "relu4_3"], + hypercolumn_layers: List[str] = ["conv1_2", "conv3_3", "conv5_3"], + checkpoint_path: str = None, + ): + """Initialize S2DNet. + + Args: + device: The torch device to put the model on + hypercolumn_layers: Names of the layers to extract features from + checkpoint_path: Path to the pre-trained model. + """ + super(S2DNet, self).__init__() + self._checkpoint_path = checkpoint_path + self.layer_to_index = dict((k, v) for v, k in enumerate(vgg16_layers.keys())) + self._hypercolumn_layers = hypercolumn_layers + + # Initialize architecture + vgg16 = models.vgg16(pretrained=False) + # layers = list(vgg16.features.children())[:-2] + layers = list(vgg16.features.children())[:-1] + # layers = list(vgg16.features.children())[:23] # relu4_3 + self.encoder = nn.Sequential(*layers) + self.adaptation_layers = AdapLayers(self._hypercolumn_layers) # .to(self._device) + self.eval() + + # Restore params from checkpoint + if checkpoint_path: + print(">> Loading weights from {}".format(checkpoint_path)) + self._checkpoint = torch.load(checkpoint_path) + self._hypercolumn_layers = self._checkpoint["hypercolumn_layers"] + self.load_state_dict(self._checkpoint["state_dict"]) + + def forward(self, image_tensor: torch.FloatTensor): + """Compute intermediate feature maps at the provided extraction levels. + + Args: + image_tensor: The [N x 3 x H x Ws] input image tensor. + Returns: + feature_maps: The list of output feature maps. 
class LoFTR(nn.Module):
    """Coarse-to-fine matcher: backbone, coarse transformer + matching, fine refinement.

    Which sub-modules exist depends on the switches read from ``config`` in
    ``__init__`` (``coarse.skip``, ``coarse.rope``, ``coarse.pan``,
    ``coarse.abspe``, ``coarse.token_mixer``, ``fine.skip``).
    """

    def __init__(self, config, profiler=None):
        """Build all sub-modules from ``config``.

        Args:
            config: nested mapping of model options (indexed with string keys).
            profiler: stored on the instance; not used inside this class.
        """
        super().__init__()
        # Misc
        self.config = config
        self.profiler = profiler

        # Modules -- creation order is kept so parameter / state-dict order
        # stays exactly as before.
        self.backbone = build_backbone(config)
        coarse_cfg = self.config['coarse']
        # Sinusoidal encoding is needed either when none of the alternative
        # schemes is active, or when it is explicitly requested via 'abspe'.
        uses_alternative = (
            coarse_cfg['skip'] or coarse_cfg['rope'] or coarse_cfg['pan']
            or coarse_cfg['token_mixer'] is not None
        )
        if coarse_cfg['abspe'] or not uses_alternative:
            self.pos_encoding = PositionEncodingSine(
                config['coarse']['d_model'],
                temp_bug_fix=config['coarse']['temp_bug_fix'],
                npe=config['coarse']['npe'],
            )

        if self.config['coarse']['skip'] is False:
            self.loftr_coarse = LocalFeatureTransformer(config)
        self.coarse_matching = CoarseMatching(config['match_coarse'])
        self.fine_preprocess = FinePreprocess(config)
        if self.config['fine']['skip'] is False:
            self.loftr_fine = LocalFeatureTransformer(config["fine"])
        self.fine_matching = FineMatching(config)

    @staticmethod
    def _replace_nan(name0, feat0, name1, feat1, stage):
        """Log NaN statistics for two feature tensors, then zero the NaNs in place."""
        logger.info(f'replace nan in {stage}')
        logger.info(f"{name0}_nan_num: {torch.isnan(feat0).int().sum()}, {name1}_nan_num: {torch.isnan(feat1).int().sum()}")
        logger.info(f"{name0}: {feat0}, {name1}: {feat1}")
        logger.info(f"{name0}_max: {feat0.abs().max()}, {name1}_max: {feat1.abs().max()}")
        feat0[torch.isnan(feat0)] = 0
        feat1[torch.isnan(feat1)] = 0
        logger.info(f"{name0}_nanmax: {feat0.abs().max()}, {name1}_nanmax: {feat1.abs().max()}")

    def _run_fine_transformer(self, feat_f0_unfold, feat_f1_unfold):
        """Apply the fine-level transformer to the unfolded windows (shared by forward/refine)."""
        if feat_f0_unfold.size(0) == 0:  # no coarse match was predicted
            return feat_f0_unfold, feat_f1_unfold
        if self.config['fine']['pan']:
            m, ww, c = feat_f0_unfold.size()  # [m, ww, c]
            w = self.config['fine_window_size']
            # NOTE(review): reshape (not permute) from [m, ww, c] to [m, c, w, w]
            # is kept exactly as in the original -- confirm it is intended.
            feat_f0_unfold, feat_f1_unfold = self.loftr_fine(
                feat_f0_unfold.reshape(m, c, w, w), feat_f1_unfold.reshape(m, c, w, w))
            feat_f0_unfold = rearrange(feat_f0_unfold, 'm c w h -> m (w h) c')
            feat_f1_unfold = rearrange(feat_f1_unfold, 'm c w h -> m (w h) c')
        elif not self.config['fine']['skip']:
            feat_f0_unfold, feat_f1_unfold = self.loftr_fine(feat_f0_unfold, feat_f1_unfold)
        return feat_f0_unfold, feat_f1_unfold

    def forward(self, data):
        """Run the full matching pipeline and write all results into ``data``.

        Update:
            data (dict): {
                'image0': (torch.Tensor): (N, 1, H, W)
                'image1': (torch.Tensor): (N, 1, H, W)
                'mask0'(optional) : (torch.Tensor): (N, H, W) '0' indicates a padded position
                'mask1'(optional) : (torch.Tensor): (N, H, W)
            }

        Returns:
            The same ``data`` dict, updated in place by this method and by the
            matching sub-modules it calls.
        """
        cfg = self.config
        coarse_cfg = cfg['coarse']

        # 1. Local Feature CNN
        data.update({
            'bs': data['image0'].size(0),
            'hw0_i': data['image0'].shape[2:], 'hw1_i': data['image1'].shape[2:]
        })

        if data['hw0_i'] == data['hw1_i']:  # faster & better BN convergence
            ret_dict = self.backbone(torch.cat([data['image0'], data['image1']], dim=0))
            feats_c, feats_f = ret_dict['feats_c'], ret_dict['feats_f']
            if cfg['inter_feat']:
                data.update({
                    'feats_x2': ret_dict['feats_x2'],
                    'feats_x1': ret_dict['feats_x1'],
                })
            feat_c0, feat_c1 = feats_c.split(data['bs'])
            if cfg['coarse_feat_only']:
                feat_f0, feat_f1 = None, None
            else:
                feat_f0, feat_f1 = feats_f.split(data['bs'])
        else:  # handle different input shapes
            ret_dict0, ret_dict1 = self.backbone(data['image0']), self.backbone(data['image1'])
            feat_c0, feat_f0 = ret_dict0['feats_c'], ret_dict0['feats_f']
            feat_c1, feat_f1 = ret_dict1['feats_c'], ret_dict1['feats_f']
            if cfg['inter_feat']:
                data.update({
                    'feats_x2_0': ret_dict0['feats_x2'],
                    'feats_x1_0': ret_dict0['feats_x1'],
                    'feats_x2_1': ret_dict1['feats_x2'],
                    'feats_x1_1': ret_dict1['feats_x1'],
                })
            if cfg['coarse_feat_only']:
                feat_f0, feat_f1 = None, None

        mul = cfg['resolution'][0] // cfg['resolution'][1]

        def fine_hw(feat_f, feat_c):
            # Use the real fine map when available, otherwise derive its size
            # from the coarse map and the resolution ratio.
            if feat_f is not None:
                return feat_f.shape[2:]
            if cfg['fix_bias']:
                return [(feat_c.shape[2] - 1) * mul + 1, (feat_c.shape[3] - 1) * mul + 1]
            return [feat_c.shape[2] * mul, feat_c.shape[3] * mul]

        data.update({
            'hw0_c': feat_c0.shape[2:], 'hw1_c': feat_c1.shape[2:],
            'hw0_f': fine_hw(feat_f0, feat_c0),
            'hw1_f': fine_hw(feat_f1, feat_c1),
        })

        # 2. coarse-level loftr module
        # add featmap with positional encoding, then flatten it to sequence [N, HW, C]
        mask_c0 = mask_c1 = None  # mask is useful in training
        if coarse_cfg['skip']:
            if 'mask0' in data:
                mask_c0, mask_c1 = data['mask0'], data['mask1']
            feat_c0 = rearrange(feat_c0, 'n c h w -> n (h w) c')
            feat_c1 = rearrange(feat_c1, 'n c h w -> n (h w) c')
        elif coarse_cfg['pan']:
            if coarse_cfg['abspe']:
                feat_c0 = self.pos_encoding(feat_c0)
                feat_c1 = self.pos_encoding(feat_c1)
            if 'mask0' in data:
                mask_c0, mask_c1 = data['mask0'], data['mask1']
            if cfg['matchability']:  # matching then happens inside loftr_coarse
                feat_c0, feat_c1 = self.loftr_coarse(feat_c0, feat_c1, mask_c0, mask_c1, data=data)
            else:
                feat_c0, feat_c1 = self.loftr_coarse(feat_c0, feat_c1, mask_c0, mask_c1)
            feat_c0 = rearrange(feat_c0, 'n c h w -> n (h w) c')
            feat_c1 = rearrange(feat_c1, 'n c h w -> n (h w) c')
        else:
            if not (coarse_cfg['rope'] or coarse_cfg['token_mixer'] is not None):
                feat_c0 = rearrange(self.pos_encoding(feat_c0), 'n c h w -> n (h w) c')
                feat_c1 = rearrange(self.pos_encoding(feat_c1), 'n c h w -> n (h w) c')
            if 'mask0' in data:
                if coarse_cfg['rope']:
                    mask_c0, mask_c1 = data['mask0'], data['mask1']
                else:
                    mask_c0, mask_c1 = data['mask0'].flatten(-2), data['mask1'].flatten(-2)
            feat_c0, feat_c1 = self.loftr_coarse(feat_c0, feat_c1, mask_c0, mask_c1)
            if coarse_cfg['rope']:
                feat_c0 = rearrange(feat_c0, 'n c h w -> n (h w) c')
                feat_c1 = rearrange(feat_c1, 'n c h w -> n (h w) c')

        # detect nan
        if cfg['replace_nan'] and (torch.any(torch.isnan(feat_c0)) or torch.any(torch.isnan(feat_c1))):
            self._replace_nan('feat_c0', feat_c0, 'feat_c1', feat_c1, 'coarse attention')

        # 3. match coarse-level
        if not cfg['matchability']:  # else match in loftr_coarse
            self.coarse_matching(
                feat_c0, feat_c1, data,
                mask_c0=mask_c0.view(mask_c0.size(0), -1) if mask_c0 is not None else mask_c0,
                mask_c1=mask_c1.view(mask_c1.size(0), -1) if mask_c1 is not None else mask_c1,
            )

        # norm FPNfeat
        def scaled(feat):
            return feat / feat.shape[-1]**.5

        if cfg['norm_fpnfeat']:
            feat_c0, feat_c1 = scaled(feat_c0), scaled(feat_c1)
        if cfg['norm_fpnfeat2']:
            assert cfg['inter_feat']
            logger.info(f'before norm_fpnfeat2 max of feat_c0, feat_c1:{feat_c0.abs().max()}, {feat_c1.abs().max()}')
            if data['hw0_i'] == data['hw1_i']:
                logger.info(f'before norm_fpnfeat2 max of data[feats_x2], data[feats_x1]:{data["feats_x2"].abs().max()}, {data["feats_x1"].abs().max()}')
                inter_keys = ('feats_x2', 'feats_x1')
            else:
                inter_keys = ('feats_x2_0', 'feats_x2_1', 'feats_x1_0', 'feats_x1_1')
            feat_c0, feat_c1 = scaled(feat_c0), scaled(feat_c1)
            for key in inter_keys:
                data[key] = scaled(data[key])

        # 4. fine-level refinement (autocast explicitly disabled for this step)
        with torch.autocast(enabled=False, device_type="cuda"):
            feat_f0_unfold, feat_f1_unfold = self.fine_preprocess(feat_f0, feat_f1, feat_c0, feat_c1, data)

        # detect nan
        if cfg['replace_nan'] and (torch.any(torch.isnan(feat_f0_unfold)) or torch.any(torch.isnan(feat_f1_unfold))):
            self._replace_nan('feat_f0_unfold', feat_f0_unfold, 'feat_f1_unfold', feat_f1_unfold, 'fine_preprocess')

        if cfg['fp16log'] and feat_c0 is not None:
            logger.info(f"c0: {feat_c0.abs().max()}, c1: {feat_c1.abs().max()}")
        del feat_c0, feat_c1, mask_c0, mask_c1
        feat_f0_unfold, feat_f1_unfold = self._run_fine_transformer(feat_f0_unfold, feat_f1_unfold)

        # 5. match fine-level
        # log forward nan
        if cfg['fp16log'] and feat_f0_unfold.size(0) != 0:
            if feat_f0 is not None:
                logger.info(f"f0: {feat_f0.abs().max()}, f1: {feat_f1.abs().max()}, uf0: {feat_f0_unfold.abs().max()}, uf1: {feat_f1_unfold.abs().max()}")
            else:
                logger.info(f"uf0: {feat_f0_unfold.abs().max()}, uf1: {feat_f1_unfold.abs().max()}")

        with torch.autocast(enabled=False, device_type="cuda"):
            self.fine_matching(feat_f0_unfold, feat_f1_unfold, data)

        return data

    def load_state_dict(self, state_dict, *args, **kwargs):
        """Load weights, accepting checkpoints whose keys carry a ``matcher.`` prefix.

        The first ``matcher.`` prefix of each key is stripped. The renaming is
        done on a copy, so the caller's ``state_dict`` is left untouched. When a
        prefixed and an unprefixed key collide, the prefixed entry wins (as
        before).
        """
        from collections import OrderedDict

        prefix = 'matcher.'
        remapped = OrderedDict(
            (k, v) for k, v in state_dict.items() if not k.startswith(prefix)
        )
        for k, v in state_dict.items():
            if k.startswith(prefix):
                remapped[k[len(prefix):]] = v
        # nn.Module.load_state_dict reads this optional attribute; carry it over.
        metadata = getattr(state_dict, '_metadata', None)
        if metadata is not None:
            remapped._metadata = metadata
        return super().load_state_dict(remapped, *args, **kwargs)

    def refine(self, data):
        """Run only the fine-level stage on coarse features already stored in ``data``.

        Update:
            data (dict): {
                'image0': (torch.Tensor): (N, 1, H, W)
                'image1': (torch.Tensor): (N, 1, H, W)
                'mask0'(optional) : (torch.Tensor): (N, H, W) '0' indicates a padded position
                'mask1'(optional) : (torch.Tensor): (N, H, W)
            }
        Reads ``data['feat_c0']`` and ``data['feat_c1']`` as the coarse features.

        Returns:
            The same ``data`` dict, updated in place.
        """
        data.update({
            'bs': data['image0'].size(0),
            'hw0_i': data['image0'].shape[2:], 'hw1_i': data['image1'].shape[2:]
        })
        # No backbone pass here, so there are no fine backbone features.
        feat_f0, feat_f1 = None, None
        feat_c0, feat_c1 = data['feat_c0'], data['feat_c1']

        # 4. fine-level refinement
        feat_f0_unfold, feat_f1_unfold = self.fine_preprocess(feat_f0, feat_f1, feat_c0, feat_c1, data)
        feat_f0_unfold, feat_f1_unfold = self._run_fine_transformer(feat_f0_unfold, feat_f1_unfold)

        # 5. match fine-level
        # feat_f0 is always None in this method, so only the coarse maxima can be logged.
        if self.config['fp16log'] and feat_c0 is not None:
            logger.info(f"c0: {feat_c0.abs().max()}, c1: {feat_c1.abs().max()}")

        self.fine_matching(feat_f0_unfold, feat_f1_unfold, data)
        return data
def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution without padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, padding=0, bias=False)


def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)


class FinePreprocess(nn.Module):
    """Build per-match fine-level windows for both images.

    Depending on ``config`` the fine features come from an FPN-style top-down
    fusion of coarse and intermediate backbone features (``inter_feat``), from
    up-sampled coarse features (``fine_sample_coarse_feat``), or from the fine
    backbone features passed to ``forward``.
    """

    def __init__(self, config):
        super().__init__()

        self.config = config
        self.cat_c_feat = config['fine_concat_coarse_feat']
        self.sample_c_feat = config['fine_sample_coarse_feat']
        self.fpn_inter_feat = config['inter_feat']
        self.rep_fpn = config['rep_fpn']
        self.deploy = config['rep_deploy']
        self.multi_regress = config['match_fine']['multi_regress']
        self.local_regress = config['match_fine']['local_regress']
        self.local_regress_inner = config['match_fine']['local_regress_inner']
        block_dims = config['resnetfpn']['block_dims']

        self.mtd_spvs = self.config['fine']['mtd_spvs']
        self.align_corner = self.config['align_corner']
        self.fix_bias = self.config['fix_bias']

        # Both settings of 'mtd_spvs' used the same window size in the original.
        self.W = self.config['fine_window_size']

        self.backbone_type = self.config['backbone_type']

        d_model_c = self.config['coarse']['d_model']
        d_model_f = self.config['fine']['d_model']
        self.d_model_f = d_model_f
        if self.fpn_inter_feat:
            # Layers are created in the original order so parameter order and
            # RNG consumption are unchanged.
            self.layer3_outconv = conv1x1(block_dims[2], block_dims[2])
            self.layer2_outconv = conv1x1(block_dims[1], block_dims[2])
            self.layer2_outconv2 = self._make_outconv2(block_dims[2], block_dims[1])
            self.layer1_outconv = conv1x1(block_dims[0], block_dims[1])
            self.layer1_outconv2 = self._make_outconv2(block_dims[1], block_dims[0])
        elif self.cat_c_feat:
            self.down_proj = nn.Linear(d_model_c, d_model_f, bias=True)
            self.merge_feat = nn.Linear(2*d_model_f, d_model_f, bias=True)
        if self.sample_c_feat:
            self.down_proj = nn.Linear(d_model_c, d_model_f, bias=True)

        self._reset_parameters()

    def _make_outconv2(self, dim_in, dim_out):
        """Two-stage 3x3 smoothing head: RepVGG blocks if ``rep_fpn`` else conv-BN-LeakyReLU-conv."""
        if self.rep_fpn:
            return nn.ModuleList([
                RepVGGBlock(in_channels=dim_in, out_channels=dim_in, kernel_size=3,
                            stride=1, padding=1, groups=1, deploy=self.deploy, use_se=False, leaky=0.01),
                RepVGGBlock(in_channels=dim_in, out_channels=dim_out, kernel_size=3,
                            stride=1, padding=1, groups=1, deploy=self.deploy, use_se=False, leaky=-2),
            ])
        return nn.Sequential(
            conv3x3(dim_in, dim_in),
            nn.BatchNorm2d(dim_in),
            nn.LeakyReLU(),
            conv3x3(dim_in, dim_out),
        )

    def _reset_parameters(self):
        """Kaiming-normal initialise every parameter with more than one dimension."""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.kaiming_normal_(p, mode="fan_out", nonlinearity="relu")

    def _apply_outconv2(self, head, x):
        """Run a head built by ``_make_outconv2`` (ModuleList is applied layer by layer)."""
        if self.rep_fpn:
            for layer in head:
                x = layer(x)
            return x
        return head(x)

    def inter_fpn(self, feat_c, x2, x1, stride):
        """Fuse coarse and intermediate features top-down into a fine feature map.

        Args:
            feat_c: coarse feature map, [n, c, h, w].
            x2, x1: intermediate backbone maps at 2x and 4x the coarse
                resolution (required by the additions below).
            stride: fine-to-coarse ratio; 4 keeps the fused map, 8 up-samples
                it once more.

        Raises:
            AssertionError: if ``stride`` is neither 4 nor 8. Raised explicitly
                so the guard survives ``python -O``.
        """
        feat_c = self.layer3_outconv(feat_c)
        feat_c = F.interpolate(feat_c, scale_factor=2., mode='bilinear', align_corners=False)
        x2 = self.layer2_outconv(x2)
        x2 = self._apply_outconv2(self.layer2_outconv2, x2 + feat_c)

        x2 = F.interpolate(x2, scale_factor=2., mode='bilinear', align_corners=False)
        x1 = self.layer1_outconv(x1)
        x1 = self._apply_outconv2(self.layer1_outconv2, x1 + x2)

        if stride == 8:
            x1 = F.interpolate(x1, scale_factor=2., mode='bilinear', align_corners=False)
        elif stride != 4:
            raise AssertionError(f'stride not in {{4,8}}: {stride}')
        return x1

    @staticmethod
    def _unfold_windows(feat, window, stride, padding):
        """Crop all ``window`` x ``window`` patches: [n, c, h, w] -> [n, l, window**2, c]."""
        patches = F.unfold(feat, kernel_size=(window, window), stride=stride, padding=padding)
        return rearrange(patches, 'n (c ww) l -> n l ww c', ww=window**2)

    @staticmethod
    def _select_matches(feat0, feat1, data):
        """Keep only the windows belonging to the predicted coarse matches."""
        return feat0[data['b_ids'], data['i_ids']], feat1[data['b_ids'], data['j_ids']]

    def _crop_and_select(self, feat_f0, feat_f1, stride, data, same_shape):
        """Unfold both fine maps with the window layout chosen by the config, then select matches."""
        W = self.W
        win1 = W
        if self.local_regress_inner:
            assert W == 8
            pad0, win1, pad1 = 0, W + 2, 1  # image1 windows get a one-pixel border
        elif same_shape:
            pad0 = pad1 = 1 if (self.multi_regress or W == 10) else 0
        elif W == 10:
            pad0 = pad1 = 1
        else:
            assert not self.multi_regress
            pad0 = pad1 = 0
        feat_f0 = self._unfold_windows(feat_f0, W, stride, pad0)
        feat_f1 = self._unfold_windows(feat_f1, win1, stride, pad1)
        return self._select_matches(feat_f0, feat_f1, data)  # [n, ww, cf]

    def forward(self, feat_f0, feat_f1, feat_c0, feat_c1, data):
        """Return the fine windows of image0 and image1 for every coarse match.

        Args:
            feat_f0, feat_f1: fine backbone feature maps (only used when
                neither ``inter_feat`` nor ``fine_sample_coarse_feat`` is set).
            feat_c0, feat_c1: coarse features flattened to [n, h*w, c].
            data: shared dict; reads 'b_ids', 'i_ids', 'j_ids', 'hw0_c',
                'hw1_c', 'hw0_f', 'hw0_i', 'hw1_i' and the intermediate
                'feats_x*' entries; writes 'W'.

        Raises:
            AssertionError: for combinations the original code guarded with asserts.
            NotImplementedError: for config combinations with no implemented
                path (these previously fell through and returned ``None``).
        """
        W = self.W
        stride = 4 if self.fix_bias else data['hw0_f'][0] // data['hw0_c'][0]

        data.update({'W': W})
        if data['b_ids'].shape[0] == 0:
            # No coarse match: return empty window tensors of the usual layout.
            feat0 = torch.empty(0, self.W**2, self.d_model_f, device=feat_c0.device)
            feat1 = torch.empty(0, self.W**2, self.d_model_f, device=feat_c0.device)
            return feat0.float(), feat1.float()

        if self.fpn_inter_feat:
            if data['hw0_i'] != data['hw1_i']:
                # Differently sized images: run the FPN once per image.
                if self.align_corner is False:
                    assert self.backbone_type != 's2dnet'

                    feat_c0 = rearrange(feat_c0, 'b (h w) c -> b c h w', h=data['hw0_c'][0])
                    feat_c1 = rearrange(feat_c1, 'b (h w) c -> b c h w', h=data['hw1_c'][0])
                    x2_0, x1_0 = data['feats_x2_0'], data['feats_x1_0']
                    x2_1, x1_1 = data['feats_x2_1'], data['feats_x1_1']
                    del data['feats_x2_0'], data['feats_x1_0'], data['feats_x2_1'], data['feats_x1_1']
                    feat_f0 = self.inter_fpn(feat_c0, x2_0, x1_0, stride)
                    feat_f1 = self.inter_fpn(feat_c1, x2_1, x1_1, stride)
                    return self._crop_and_select(feat_f0, feat_f1, stride, data, same_shape=False)
            elif self.align_corner is False:
                # Same-size images: both images go through the FPN as one batch.
                feat_c = torch.cat([feat_c0, feat_c1], 0)
                feat_c = rearrange(feat_c, 'b (h w) c -> b c h w', h=data['hw0_c'][0])  # 1/8 256
                x2 = data['feats_x2'].float()  # 1/4 128
                x1 = data['feats_x1'].float()  # 1/2 64
                del data['feats_x2'], data['feats_x1']
                assert self.backbone_type != 's2dnet'
                fused = self.inter_fpn(feat_c, x2, x1, stride)
                feat_f0, feat_f1 = torch.chunk(fused, 2, dim=0)
                return self._crop_and_select(feat_f0, feat_f1, stride, data, same_shape=True)
            elif self.fix_bias:
                feat_c = torch.cat([feat_c0, feat_c1], 0)
                feat_c = rearrange(feat_c, 'b (h w) c -> b c h w', h=data['hw0_c'][0])
                x2 = data['feats_x2'].float()
                x1 = data['feats_x1'].float()
                assert self.backbone_type != 's2dnet'
                x3_out = self.layer3_outconv(feat_c)
                # Up-sample to (2*(s-1)+1) per side instead of a plain 2x factor.
                x3_out_2x = F.interpolate(x3_out, size=((x3_out.size(-2)-1)*2+1, (x3_out.size(-1)-1)*2+1), mode='bilinear', align_corners=False)
                x2 = self.layer2_outconv(x2)
                x2 = self.layer2_outconv2(x2+x3_out_2x)

                x2 = F.interpolate(x2, size=((x2.size(-2)-1)*2+1, (x2.size(-1)-1)*2+1), mode='bilinear', align_corners=False)
                x1_out = self.layer1_outconv(x1)
                x1_out = self.layer1_outconv2(x1_out+x2)

                feat_f0, feat_f1 = torch.chunk(x1_out, 2, dim=0)
                feat_f0_unfold = self._unfold_windows(feat_f0, W, stride, W//2)
                feat_f1_unfold = self._unfold_windows(feat_f1, W, stride, W//2)
                return self._select_matches(feat_f0_unfold, feat_f1_unfold, data)  # [n, ww, cf]

        elif self.sample_c_feat:
            if self.align_corner is False:
                # easy implemented but memory consuming
                feat_c = self.down_proj(torch.cat([feat_c0,
                                                   feat_c1], 0))  # [n, (h w), c] -> [2n, (h w), cf]
                feat_c = rearrange(feat_c, 'n (h w) c -> n c h w', h=data['hw0_c'][0], w=data['hw0_c'][1])
                feat_f = F.interpolate(feat_c, scale_factor=8., mode='bilinear', align_corners=False)  # [2n, cf, hf, wf]
                feat_f_unfold = self._unfold_windows(feat_f, W, stride, 0)
                feat_f0_unfold, feat_f1_unfold = torch.chunk(feat_f_unfold, 2, dim=0)
                feat_f0_unfold, feat_f1_unfold = self._select_matches(feat_f0_unfold, feat_f1_unfold, data)  # [m, ww, cf]
                return feat_f0_unfold.float(), feat_f1_unfold.float()

        elif self.align_corner is False:
            # The original guarded this path with `assert False`; raise explicitly
            # so the guard is not stripped by `python -O`.
            raise AssertionError('maybe exist bugs')

        else:
            # 1. unfold(crop) all local windows (same padding with or without fix_bias)
            feat_f0_unfold = self._unfold_windows(feat_f0, W, stride, W//2)
            feat_f1_unfold = self._unfold_windows(feat_f1, W, stride, W//2)

            # 2. select only the predicted matches
            feat_f0_unfold, feat_f1_unfold = self._select_matches(feat_f0_unfold, feat_f1_unfold, data)  # [n, ww, cf]

            # option: use coarse-level loftr feature as context: concat and linear
            if self.cat_c_feat:
                feat_c_win = self.down_proj(torch.cat([feat_c0[data['b_ids'], data['i_ids']],
                                                       feat_c1[data['b_ids'], data['j_ids']]], 0))  # [2n, c]
                feat_cf_win = self.merge_feat(torch.cat([
                    torch.cat([feat_f0_unfold, feat_f1_unfold], 0),  # [2n, ww, cf]
                    repeat(feat_c_win, 'n c -> n ww c', ww=W**2),  # [2n, ww, cf]
                ], -1))
                feat_f0_unfold, feat_f1_unfold = torch.chunk(feat_cf_win, 2, dim=0)

            return feat_f0_unfold.float(), feat_f1_unfold.float()

        raise NotImplementedError(
            'FinePreprocess: no fine-window path implemented for this combination of '
            'inter_feat / fine_sample_coarse_feat / align_corner / fix_bias'
        )
torch.nn.functional.elu(x) + 1 + + +class LinearAttention(Module): + def __init__(self, eps=1e-6): + super().__init__() + self.feature_map = elu_feature_map + self.eps = eps + + @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32) + def forward(self, queries, keys, values, q_mask=None, kv_mask=None): + """ Multi-Head linear attention proposed in "Transformers are RNNs" + Args: + queries: [N, L, H, D] + keys: [N, S, H, D] + values: [N, S, H, D] + q_mask: [N, L] + kv_mask: [N, S] + Returns: + queried_values: (N, L, H, D) + """ + Q = self.feature_map(queries) + K = self.feature_map(keys) + + # set padded position to zero + if q_mask is not None: + Q = Q * q_mask[:, :, None, None] + if kv_mask is not None: + K = K * kv_mask[:, :, None, None] + values = values * kv_mask[:, :, None, None] + + v_length = values.size(1) + values = values / v_length # prevent fp16 overflow + KV = torch.einsum("nshd,nshv->nhdv", K, values) # (S,D)' @ S,V + Z = 1 / (torch.einsum("nlhd,nhd->nlh", Q, K.sum(dim=1)) + self.eps) + # queried_values = torch.einsum("nlhd,nhdv,nlh->nlhv", Q, KV, Z) * v_length + queried_values = torch.einsum("nlhd,nhdv,nlh->nlhv", Q, KV, Z) * v_length + + return queried_values.contiguous() + +class RoPELinearAttention(Module): + def __init__(self, eps=1e-6): + super().__init__() + self.feature_map = elu_feature_map + self.eps = eps + self.RoPE = RoPEPositionEncodingSine(256, max_shape=(256, 256)) + + @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32) + def forward(self, queries, keys, values, q_mask=None, kv_mask=None, H=None, W=None): + """ Multi-Head linear attention proposed in "Transformers are RNNs" + Args: + queries: [N, L, H, D] + keys: [N, S, H, D] + values: [N, S, H, D] + q_mask: [N, L] + kv_mask: [N, S] + Returns: + queried_values: (N, L, H, D) + """ + Q = self.feature_map(queries) + K = self.feature_map(keys) + nhead, d = Q.size(2), Q.size(3) + # set padded position to zero + if q_mask is not None: + Q = Q * q_mask[:, :, None, None] + if kv_mask is not 
None: + K = K * kv_mask[:, :, None, None] + values = values * kv_mask[:, :, None, None] + + v_length = values.size(1) + values = values / v_length # prevent fp16 overflow + # Q = Q / Q.size(1) + # logger.info(f"Q: {Q.dtype}, K: {K.dtype}, values: {values.dtype}") + + Z = 1 / (torch.einsum("nlhd,nhd->nlh", Q, K.sum(dim=1)) + self.eps) + # logger.info(f"Z_max: {Z.abs().max()}") + Q = rearrange(Q, 'n (h w) nhead d -> n h w (nhead d)', h=H, w=W) + K = rearrange(K, 'n (h w) nhead d -> n h w (nhead d)', h=H, w=W) + Q, K = self.RoPE(Q), self.RoPE(K) + # logger.info(f"Q_rope: {Q.abs().max()}, K_rope: {K.abs().max()}") + Q = rearrange(Q, 'n h w (nhead d) -> n (h w) nhead d', nhead=nhead, d=d) + K = rearrange(K, 'n h w (nhead d) -> n (h w) nhead d', nhead=nhead, d=d) + KV = torch.einsum("nshd,nshv->nhdv", K, values) # (S,D)' @ S,V + del K, values + # logger.info(f"KV_max: {KV.abs().max()}") + # queried_values = torch.einsum("nlhd,nhdv,nlh->nlhv", Q, KV, Z) * v_length + # Q = torch.einsum("nlhd,nlh->nlhd", Q, Z) + # logger.info(f"QZ_max: {Q.abs().max()}") + # queried_values = torch.einsum("nlhd,nhdv->nlhv", Q, KV) * v_length + # logger.info(f"message_max: {queried_values.abs().max()}") + queried_values = torch.einsum("nlhd,nhdv,nlh->nlhv", Q, KV, Z) * v_length + + return queried_values.contiguous() + + +class FullAttention(Module): + def __init__(self, use_dropout=False, attention_dropout=0.1): + super().__init__() + self.use_dropout = use_dropout + self.dropout = Dropout(attention_dropout) + + # @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32) + def forward(self, queries, keys, values, q_mask=None, kv_mask=None): + """ Multi-head scaled dot-product attention, a.k.a full attention. 
+ Args: + queries: [N, L, H, D] + keys: [N, S, H, D] + values: [N, S, H, D] + q_mask: [N, L] + kv_mask: [N, S] + Returns: + queried_values: (N, L, H, D) + """ + # assert kv_mask is None + # mask = torch.zeros(queries.size(0)*queries.size(2), queries.size(1), keys.size(1), device=queries.device) + # mask.masked_fill(~(q_mask[:, :, None] * kv_mask[:, None, :]), float('-inf')) + # if keys.size(1) % 8 != 0: + # mask = torch.cat([mask, torch.zeros(queries.size(0)*queries.size(2), queries.size(1), 8-keys.size(1)%8, device=queries.device)], dim=-1) + # out = xformers.ops.memory_efficient_attention(queries, keys, values, attn_bias=mask[...,:keys.size(1)]) + # return out + + # N = queries.size(0) + # list_q = [queries[i, :q_mask[i].sum, ...] for i in N] + # list_k = [keys[i, :kv_mask[i].sum, ...] for i in N] + # list_v = [values[i, :kv_mask[i].sum, ...] for i in N] + # assert N == 1 + # out = xformers.ops.memory_efficient_attention(queries[:,:q_mask.sum(),...], keys[:,:kv_mask.sum(),...], values[:,:kv_mask.sum(),...]) + # out = torch.cat([out, torch.zeros(out.size(0), queries.size(1)-q_mask.sum(), queries.size(2), queries.size(3), device=queries.device)], dim=1) + # return out + # Compute the unnormalized attention and apply the masks + QK = torch.einsum("nlhd,nshd->nlsh", queries, keys) + if kv_mask is not None: + QK.masked_fill_(~(q_mask[:, :, None, None] * kv_mask[:, None, :, None]), -1e5) # float('-inf') + + # Compute the attention and the weighted average + softmax_temp = 1. 
/ queries.size(3)**.5 # sqrt(D) + A = torch.softmax(softmax_temp * QK, dim=2) + if self.use_dropout: + A = self.dropout(A) + + queried_values = torch.einsum("nlsh,nshd->nlhd", A, values) + + return queried_values.contiguous() + + +class XAttention(Module): + def __init__(self, use_dropout=False, attention_dropout=0.1): + super().__init__() + self.use_dropout = use_dropout + if use_dropout: + self.dropout = Dropout(attention_dropout) + + def forward(self, queries, keys, values, q_mask=None, kv_mask=None): + """ Multi-head scaled dot-product attention, a.k.a full attention. + Args: + if FLASH_AVAILABLE: # pytorch scaled_dot_product_attention + queries: [N, H, L, D] + keys: [N, H, S, D] + values: [N, H, S, D] + else: + queries: [N, L, H, D] + keys: [N, S, H, D] + values: [N, S, H, D] + q_mask: [N, L] + kv_mask: [N, S] + Returns: + queried_values: (N, L, H, D) + """ + + assert q_mask is None and kv_mask is None, "already been sliced" + if FLASH_AVAILABLE: + # args = [x.half().contiguous() for x in [queries, keys, values]] + # out = F.scaled_dot_product_attention(*args, attn_mask=mask).to(queries.dtype) + args = [x.contiguous() for x in [queries, keys, values]] + out = F.scaled_dot_product_attention(*args) + else: + # if flash_attn_func_ok: + # out = flash_attn_func(queries, keys, values) + # else: + QK = torch.einsum("nlhd,nshd->nlsh", queries, keys) + + # Compute the attention and the weighted average + softmax_temp = 1. 
class PANEncoderLayer(nn.Module):
    """Attention layer that exchanges information on pooled feature maps.

    ``x`` and ``source`` are down-sampled by ``pool_size``, layer-normalised,
    projected with depth-wise convolutions and passed to the configured
    attention module.  The attended message is up-sampled back to the
    resolution of ``x``, fused with ``x`` by an MLP and added as a residual.
    """

    def __init__(self,
                 d_model,
                 nhead,
                 attention='linear',
                 pool_size=4,
                 bn=True,
                 xformer=False,
                 leaky=-1.0,
                 dw_conv=False,
                 scatter=False,
                 ):
        """
        Args:
            d_model: number of feature channels.
            nhead: number of attention heads (``d_model // nhead`` channels each).
            attention: ``'linear'`` selects LinearAttention, anything else
                FullAttention (ignored when ``xformer`` is true).
            pool_size: down-sampling factor applied before attention.
            bn: add BatchNorm after the depth-wise q/k/v convolutions.
            xformer: use XAttention.
            leaky: negative slope of the MLP's LeakyReLU; ReLU when ``<= 0``.
            dw_conv: pool the query with a strided depth-wise conv instead of
                max pooling.
            scatter: not supported, must be false.
        """
        super(PANEncoderLayer, self).__init__()

        self.pool_size = pool_size
        self.dw_conv = dw_conv
        self.scatter = scatter
        if self.dw_conv:
            self.aggregate = nn.Conv2d(d_model, d_model, kernel_size=pool_size, padding=0,
                                       stride=pool_size, bias=False, groups=d_model)

        assert not self.scatter, 'buggy implemented here'
        self.dim = d_model // nhead
        self.nhead = nhead

        self.max_pool = torch.nn.MaxPool2d(kernel_size=self.pool_size, stride=self.pool_size)

        # multi-head attention
        method = 'dw_bn' if bn else 'dw'
        self.q_proj_conv = self._build_projection(d_model, d_model, method=method)
        self.k_proj_conv = self._build_projection(d_model, d_model, method=method)
        self.v_proj_conv = self._build_projection(d_model, d_model, method=method)

        if xformer:
            self.attention = XAttention()
        else:
            self.attention = LinearAttention() if attention == 'linear' else FullAttention()
        self.merge = nn.Linear(d_model, d_model, bias=False)

        # feed-forward network, consumes cat([x, message]) -> 2 * d_model channels
        activation = nn.LeakyReLU(leaky, True) if leaky > 0 else nn.ReLU(True)
        self.mlp = nn.Sequential(
            nn.Linear(d_model*2, d_model*2, bias=False),
            activation,
            nn.Linear(d_model*2, d_model, bias=False),
        )

        # norm
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    # ------------------------------------------------------------------ helpers

    def _pool_and_norm(self, feat, pool):
        """Pool a [N, C, H, W] map and apply ``norm1`` over the channel axis."""
        pooled = pool(feat).permute(0, 2, 3, 1)           # [N, h, w, C]
        return self.norm1(pooled).permute(0, 3, 1, 2)     # [N, C, h, w]

    def _attend(self, query, key, value):
        """Run attention on [n, C, h, w] maps and return [n, L, nhead, dim]."""
        if PFLASH_AVAILABLE:  # N H L D
            pattern = 'n (nhead d) h w -> n nhead (h w) d'
        else:  # N L H D
            pattern = 'n (nhead d) h w -> n (h w) nhead d'
        query, key, value = (rearrange(t, pattern, nhead=self.nhead, d=self.dim)
                             for t in (query, key, value))

        message = self.attention(query, key, value, q_mask=None, kv_mask=None)

        if PFLASH_AVAILABLE:  # N H L D -> N L H D
            message = rearrange(message, 'n nhead L d -> n L nhead d', nhead=self.nhead, d=self.dim)
        return message

    @staticmethod
    def _valid_extent(mask):
        """Return (height, width) of the valid region of one [h, w] mask.

        The extent is read from the first column and first row, i.e. the valid
        area is treated as a rectangle anchored at the top-left corner.
        """
        return int(mask.sum(-2)[0]), int(mask.sum(-1)[0])

    def _pad_message(self, message, full_h, full_w):
        """Zero-pad a [n, h, w, nhead, dim] message to [n, full_h, full_w, nhead, dim].

        Both spatial axes are padded in one call; the previous ``if/elif``
        concatenation failed when height and width were cropped together.
        """
        h, w = message.shape[1], message.shape[2]
        # F.pad lists pads from the last axis backwards: dim, nhead, w, h.
        return F.pad(message, (0, 0, 0, 0, 0, full_w - w, 0, full_h - h))

    def _masked_attend(self, query, key, value, x_mask, source_mask):
        """Per-sample attention restricted to the valid regions of both maps.

        Returns a zero-padded [N, h, w, nhead, dim] message where (h, w) is the
        spatial size of the (already pooled) ``x_mask``.
        """
        full_h, full_w = x_mask.size(-2), x_mask.size(-1)
        messages = []
        for i in range(query.size(0)):
            h0, w0 = self._valid_extent(x_mask[i])
            h1, w1 = self._valid_extent(source_mask[i])

            m = self._attend(query[i:i+1, :, :h0, :w0],
                             key[i:i+1, :, :h1, :w1],
                             value[i:i+1, :, :h1, :w1])       # [1, L, nhead, dim]
            m = m.reshape(1, h0, w0, self.nhead, self.dim)
            messages.append(self._pad_message(m, full_h, full_w))
        return torch.cat(messages, dim=0)

    def _upsample(self, message, weight_by_aggregate):
        """Bring a pooled [N, C, h, w] message back to the resolution of ``x``.

        ``weight_by_aggregate`` only matters for the ``scatter`` mode, which
        ``__init__`` rejects; the flag preserves the per-path behaviour of the
        original implementation.
        """
        if not self.scatter:
            return F.interpolate(message, scale_factor=self.pool_size,
                                 mode='bilinear', align_corners=False)  # [N, C, H1, W1]

        message = torch.repeat_interleave(message, self.pool_size, dim=-2)
        message = torch.repeat_interleave(message, self.pool_size, dim=-1)
        if weight_by_aggregate:
            channels = message.shape[1]
            tile = self.aggregate.weight.data.reshape(1, channels, self.pool_size, self.pool_size)
            message = message * tile.repeat(1, 1,
                                            message.shape[-2]//self.pool_size,
                                            message.shape[-1]//self.pool_size)
        return message

    # ------------------------------------------------------------------ forward

    def forward(self, x, source, x_mask=None, source_mask=None):
        """
        Args:
            x (torch.Tensor): [N, C, H1, W1]
            source (torch.Tensor): [N, C, H2, W2]
            x_mask (torch.Tensor): [N, H1, W1] (optional) (L = H1*W1)
            source_mask (torch.Tensor): [N, H2, W2] (optional) (S = H2*W2)
        Returns:
            torch.Tensor: [N, C, H1, W1], ``x`` plus the attended message.

        Either both masks or neither must be given.
        """
        bs = x.size(0)
        H1, W1 = x.size(-2), x.size(-1)

        query = self._pool_and_norm(x, self.aggregate if self.dw_conv else self.max_pool)
        # key and value start from the same pooled + normalised source map,
        # so it is computed once and shared.
        pooled_source = self._pool_and_norm(source, self.max_pool)

        # multi-head attention
        query = self.q_proj_conv(query)  # [N, C, H1//pool, W1//pool]
        key = self.k_proj_conv(pooled_source)
        value = self.v_proj_conv(pooled_source)

        use_mask = x_mask is not None and source_mask is not None
        if use_mask:
            x_mask = self.max_pool(x_mask.float()).bool()  # [N, H1//pool, W1//pool]
            source_mask = self.max_pool(source_mask.float()).bool()
            message = self._masked_attend(query, key, value, x_mask, source_mask)
        else:
            assert x_mask is None and source_mask is None
            message = self._attend(query, key, value)

        message = self.merge(message.reshape(bs, -1, self.nhead*self.dim))  # [N, L, C]
        message = rearrange(message, 'b (h w) c -> b c h w',
                            h=H1//self.pool_size, w=W1//self.pool_size)  # [N, C, H, W]
        message = self._upsample(message, weight_by_aggregate=(bs == 1 or not use_mask))

        # feed-forward network
        message = self.mlp(torch.cat([x, message], dim=1).permute(0, 2, 3, 1))  # [N, H1, W1, C]
        message = self.norm2(message).permute(0, 3, 1, 2)  # [N, C, H1, W1]

        return x + message

    def pro(self, x, source, x_mask=None, source_mask=None, profiler=None):
        """Profiling variant of ``forward``.

        Args:
            x (torch.Tensor): [N, C, H1, W1]
            source (torch.Tensor): [N, C, H2, W2]
            x_mask (torch.Tensor): [N, H1, W1] (optional) (L = H1*W1)
            source_mask (torch.Tensor): [N, H2, W2] (optional) (S = H2*W2)
            profiler: object whose ``profile(label)`` returns a context
                manager; when None the sections run without timing.

        Note: with masks, the valid extent of sample 0 is applied to the whole
        batch (the per-sample handling of ``forward`` is not replicated here).
        """
        from contextlib import nullcontext

        def section(label):
            return nullcontext() if profiler is None else profiler.profile(label)

        bs = x.size(0)
        H1, W1 = x.size(-2), x.size(-1)

        query, key, value = x, source, source

        # The three maps are computed separately on purpose so the section
        # label keeps matching the measured work.
        with section("permute*6+norm1*3+max_pool*3"):
            query = self._pool_and_norm(query, self.aggregate if self.dw_conv else self.max_pool)
            key = self._pool_and_norm(key, self.max_pool)
            value = self._pool_and_norm(value, self.max_pool)

        # Round trip that leaves the tensors unchanged; it exists only to time
        # the cost of the permutes.
        with section("permute*6"):
            query = query.permute(0, 2, 3, 1)
            key = key.permute(0, 2, 3, 1)
            value = value.permute(0, 2, 3, 1)

            query = query.permute(0, 3, 1, 2)
            key = key.permute(0, 3, 1, 2)
            value = value.permute(0, 3, 1, 2)

        with section("q_conv+k_conv+v_conv"):
            query = self.q_proj_conv(query)  # [N, C, H1//pool, W1//pool]
            key = self.k_proj_conv(key)
            value = self.v_proj_conv(value)

        use_mask = x_mask is not None and source_mask is not None
        if use_mask:
            x_mask = self.max_pool(x_mask.float()).bool()  # [N, H1//pool, W1//pool]
            source_mask = self.max_pool(source_mask.float()).bool()

            mask_h0, mask_w0 = self._valid_extent(x_mask[0])
            mask_h1, mask_w1 = self._valid_extent(source_mask[0])

            query = query[:, :, :mask_h0, :mask_w0]
            key = key[:, :, :mask_h1, :mask_w1]
            value = value[:, :, :mask_h1, :mask_w1]
        else:
            assert x_mask is None and source_mask is None

        with section("rearrange*3"):
            pattern = 'n (nhead d) h w -> n (h w) nhead d'
            query = rearrange(query, pattern, nhead=self.nhead, d=self.dim)  # [N, L, H, D]
            key = rearrange(key, pattern, nhead=self.nhead, d=self.dim)      # [N, S, H, D]
            value = rearrange(value, pattern, nhead=self.nhead, d=self.dim)  # [N, S, H, D]

        with section("attention"):
            message = self.attention(query, key, value, q_mask=None, kv_mask=None)  # [N, L, H, D]

        if use_mask:
            message = message.reshape(bs, mask_h0, mask_w0, self.nhead, self.dim)
            message = self._pad_message(message, x_mask.size(-2), x_mask.size(-1))

        with section("merge"):
            message = self.merge(message.reshape(bs, -1, self.nhead*self.dim))  # [N, L, C]

        with section("rearrange*1"):
            message = rearrange(message, 'b (h w) c -> b c h w',
                                h=H1//self.pool_size, w=W1//self.pool_size)  # [N, C, H, W]

        with section("upsample"):
            message = self._upsample(message, weight_by_aggregate=False)

        # feed-forward network
        with section("feed-forward_mlp+permute*2+norm2"):
            message = self.mlp(torch.cat([x, message], dim=1).permute(0, 2, 3, 1))  # [N, H1, W1, C]
            message = self.norm2(message).permute(0, 3, 1, 2)  # [N, C, H1, W1]

        return x + message

    def _build_projection(self,
                          dim_in,
                          dim_out,
                          kernel_size=3,
                          padding=1,
                          stride=1,
                          method='dw_bn',
                          ):
        """Build the q/k/v projection module.

        Args:
            dim_in: number of input (and output) channels.
            dim_out: unused; kept for signature compatibility.
            kernel_size, padding, stride: passed to the conv / pooling layer.
            method: ``'dw'`` (depth-wise conv), ``'dw_bn'`` (depth-wise conv +
                BatchNorm), ``'avg'`` (average pooling) or ``'linear'`` (None).
        Raises:
            ValueError: for any other ``method``.
        """
        if method == 'linear':
            return None

        if method == 'avg':
            return nn.Sequential(OrderedDict([
                ('avg', nn.AvgPool2d(
                    kernel_size=kernel_size,
                    padding=padding,
                    stride=stride,
                    ceil_mode=True
                )),
            ]))

        if method in ('dw', 'dw_bn'):
            layers = [
                ('conv', nn.Conv2d(
                    dim_in,
                    dim_in,
                    kernel_size=kernel_size,
                    padding=padding,
                    stride=stride,
                    bias=False,
                    groups=dim_in
                )),
            ]
            if method == 'dw_bn':
                layers.append(('bn', nn.BatchNorm2d(dim_in)))
            return nn.Sequential(OrderedDict(layers))

        raise ValueError('Unknown method ({})'.format(method))
class AG_RoPE_EncoderLayer(nn.Module):
    """Aggregated attention layer with optional rotary position encoding.

    The query map ``x`` is down-sampled by ``pool_size`` and the ``source``
    map by ``pool_size2`` (strided depth-wise conv or max pooling), both are
    projected to q/k/v, optionally RoPE-encoded and passed to the configured
    attention module.  The message is up-sampled back to the resolution of
    ``x`` and fused with ``x`` through an MLP.
    """

    def __init__(self,
                 d_model,
                 nhead,
                 attention='linear',
                 pool_size=4,
                 pool_size2=4,
                 xformer=False,
                 leaky=-1.0,
                 dw_conv=False,
                 dw_conv2=False,
                 scatter=False,
                 norm_before=True,
                 rope=False,
                 npe=None,
                 vit_norm=False,
                 dw_proj=False,
                 ):
        """
        Args:
            d_model: number of feature channels.
            nhead: number of attention heads (``d_model // nhead`` channels each).
            attention: ``'linear'`` selects LinearAttention, anything else
                FullAttention (ignored when ``xformer`` is true).
            pool_size: down-sampling factor of the query map (1 disables it).
            pool_size2: down-sampling factor of the source map (1 disables it).
            xformer: use XAttention.
            leaky: negative slope of the MLP's LeakyReLU; ReLU when ``<= 0``.
            dw_conv / dw_conv2: pool query / source with a strided depth-wise
                conv instead of max pooling.
            scatter: up-sample by repetition weighted with ``aggregate``'s
                kernel instead of bilinear interpolation.
            norm_before: apply ``norm1`` after pooling (before attention)
                rather than on the up-sampled message.
            rope: apply RoPEPositionEncodingSine to query and key.
            npe: forwarded to RoPEPositionEncodingSine.
            vit_norm: apply ``norm1`` before pooling and use a residual
                norm -> MLP feed-forward instead of the concat variant.
            dw_proj: use depth-wise 3x3 convs for q/k/v instead of Linear.
        """
        super(AG_RoPE_EncoderLayer, self).__init__()

        self.pool_size = pool_size
        self.pool_size2 = pool_size2
        self.dw_conv = dw_conv
        self.dw_conv2 = dw_conv2
        self.scatter = scatter
        self.norm_before = norm_before
        self.vit_norm = vit_norm
        self.dw_proj = dw_proj
        self.rope = rope
        if self.dw_conv and self.pool_size != 1:
            self.aggregate = nn.Conv2d(d_model, d_model, kernel_size=pool_size, padding=0,
                                       stride=pool_size, bias=False, groups=d_model)
        if self.dw_conv2 and self.pool_size2 != 1:
            self.aggregate2 = nn.Conv2d(d_model, d_model, kernel_size=pool_size2, padding=0,
                                        stride=pool_size2, bias=False, groups=d_model)

        self.dim = d_model // nhead
        self.nhead = nhead

        # Source-side pooling (factor pool_size2).  Query-side max pooling is
        # done functionally with pool_size, see _pool_query.
        self.max_pool = torch.nn.MaxPool2d(kernel_size=self.pool_size2, stride=self.pool_size2)

        # multi-head attention
        if self.dw_proj:
            self.q_proj = nn.Conv2d(d_model, d_model, kernel_size=3, padding=1, stride=1, bias=False, groups=d_model)
            self.k_proj = nn.Conv2d(d_model, d_model, kernel_size=3, padding=1, stride=1, bias=False, groups=d_model)
            self.v_proj = nn.Conv2d(d_model, d_model, kernel_size=3, padding=1, stride=1, bias=False, groups=d_model)
        else:
            self.q_proj = nn.Linear(d_model, d_model, bias=False)
            self.k_proj = nn.Linear(d_model, d_model, bias=False)
            self.v_proj = nn.Linear(d_model, d_model, bias=False)

        if self.rope:
            self.rope_pos_enc = RoPEPositionEncodingSine(d_model, max_shape=(256, 256), npe=npe, ropefp16=True)

        if xformer:
            self.attention = XAttention()
        else:
            self.attention = LinearAttention() if attention == 'linear' else FullAttention()
        self.merge = nn.Linear(d_model, d_model, bias=False)

        # feed-forward network: vit_norm feeds (x + message), otherwise cat([x, message])
        mlp_in = d_model if self.vit_norm else d_model*2
        activation = nn.LeakyReLU(leaky, True) if leaky > 0 else nn.ReLU(True)
        self.mlp = nn.Sequential(
            nn.Linear(mlp_in, d_model*2, bias=False),
            activation,
            nn.Linear(d_model*2, d_model, bias=False),
        )

        # norm
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    # ------------------------------------------------------------------ helpers

    def _pool_query(self, feat):
        """Down-sample a query-side [N, C, H, W] map (or [N, H, W] mask) by ``pool_size``."""
        if self.pool_size == 1:
            return feat
        if self.dw_conv:
            return self.aggregate(feat)
        # Was self.max_pool, whose kernel is pool_size2; that only matched the
        # query resolution when pool_size == pool_size2.
        return F.max_pool2d(feat, kernel_size=self.pool_size, stride=self.pool_size)

    def _pool_source(self, feat):
        """Down-sample a source-side [N, C, H, W] map by ``pool_size2``."""
        if self.pool_size2 == 1:
            return feat
        if self.dw_conv2:
            return self.aggregate2(feat)
        return self.max_pool(feat)

    def _pool_and_norm(self, feat, pool):
        """Pool a [N, C, H, W] map and return it channels-last, [N, h, w, C].

        ``vit_norm`` normalises before pooling, ``norm_before`` after pooling,
        and with neither flag no normalisation is applied here.
        """
        if self.vit_norm:
            normed = self.norm1(feat.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
            return pool(normed).permute(0, 2, 3, 1)
        pooled = pool(feat).permute(0, 2, 3, 1)
        return self.norm1(pooled) if self.norm_before else pooled

    def _attend(self, query, key, value):
        """Run attention on [n, h, w, C] maps and return [n, L, nhead, dim]."""
        if PFLASH_AVAILABLE:  # [N, H, W, C] -> [N, h, L, D]
            pattern = 'n h w (nhead d) -> n nhead (h w) d'
        else:  # [N, H, W, C] -> [N, L, h, D]
            pattern = 'n h w (nhead d) -> n (h w) nhead d'
        query, key, value = (rearrange(t, pattern, nhead=self.nhead, d=self.dim)
                             for t in (query, key, value))

        message = self.attention(query, key, value, q_mask=None, kv_mask=None)

        if PFLASH_AVAILABLE:  # [N, h, L, D] -> [N, L, h, D]
            message = rearrange(message, 'n nhead L d -> n L nhead d', nhead=self.nhead, d=self.dim)
        return message

    @staticmethod
    def _valid_extent(mask):
        """Return (height, width) of the valid region of one [h, w] mask.

        The extent is read from the first column and first row, i.e. the valid
        area is treated as a rectangle anchored at the top-left corner.
        """
        return int(mask.sum(-2)[0]), int(mask.sum(-1)[0])

    def _masked_attend(self, query, key, value, x_mask, source_mask):
        """Per-sample attention restricted to the valid regions of both maps.

        Returns a zero-padded [N, h, w, nhead, dim] message where (h, w) is the
        spatial size of the (already pooled) ``x_mask``.
        """
        full_h, full_w = x_mask.size(-2), x_mask.size(-1)
        messages = []
        for i in range(query.size(0)):
            h0, w0 = self._valid_extent(x_mask[i])
            h1, w1 = self._valid_extent(source_mask[i])

            m = self._attend(query[i:i+1, :h0, :w0, :],
                             key[i:i+1, :h1, :w1, :],
                             value[i:i+1, :h1, :w1, :])       # [1, L, nhead, dim]
            m = m.reshape(1, h0, w0, self.nhead, self.dim)
            # Pad height and width together (pads are listed from the last
            # axis backwards: dim, nhead, w, h).  The previous if/elif
            # concatenation failed when both axes were cropped.
            m = F.pad(m, (0, 0, 0, 0, 0, full_w - w0, 0, full_h - h0))
            messages.append(m)
        return torch.cat(messages, dim=0)

    def _upsample(self, message):
        """Bring a pooled [N, C, h, w] message back to the resolution of ``x``."""
        if self.pool_size == 1:
            return message
        if not self.scatter:
            return F.interpolate(message, scale_factor=self.pool_size,
                                 mode='bilinear', align_corners=False)  # [N, C, H1, W1]

        message = torch.repeat_interleave(message, self.pool_size, dim=-2)
        message = torch.repeat_interleave(message, self.pool_size, dim=-1)
        channels = message.shape[1]
        tile = self.aggregate.weight.data.reshape(1, channels, self.pool_size, self.pool_size)
        return message * tile.repeat(1, 1,
                                     message.shape[-2]//self.pool_size,
                                     message.shape[-1]//self.pool_size)

    # ------------------------------------------------------------------ forward

    def forward(self, x, source, x_mask=None, source_mask=None):
        """
        Args:
            x (torch.Tensor): [N, C, H1, W1]
            source (torch.Tensor): [N, C, H2, W2]
            x_mask (torch.Tensor): [N, H1, W1] (optional) (L = H1*W1)
            source_mask (torch.Tensor): [N, H2, W2] (optional) (S = H2*W2)
        Returns:
            torch.Tensor: [N, C, H1, W1]

        Either both masks or neither must be given.
        """
        bs, C, H1, W1 = x.size()

        query = self._pool_and_norm(x, self._pool_query)         # [N, h1, w1, C]
        source = self._pool_and_norm(source, self._pool_source)  # [N, h2, w2, C]

        # projection
        if self.dw_proj:
            source_cf = source.permute(0, 3, 1, 2)
            query = self.q_proj(query.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
            key = self.k_proj(source_cf).permute(0, 2, 3, 1)
            value = self.v_proj(source_cf).permute(0, 2, 3, 1)
        else:
            query, key, value = self.q_proj(query), self.k_proj(source), self.v_proj(source)

        # RoPE
        if self.rope:
            query = self.rope_pos_enc(query)
            if self.pool_size == 1 and self.pool_size2 == 4:
                key = self.rope_pos_enc(key, 4)
            else:
                key = self.rope_pos_enc(key)

        if x_mask is not None and source_mask is not None:
            # downsample masks to the resolution of query / key
            if self.pool_size != 1:
                x_mask = F.max_pool2d(x_mask.float(), kernel_size=self.pool_size,
                                      stride=self.pool_size).bool()  # [N, H1//pool, W1//pool]
            if self.pool_size2 != 1:
                source_mask = self.max_pool(source_mask.float()).bool()
            message = self._masked_attend(query, key, value, x_mask, source_mask)
        else:
            assert x_mask is None and source_mask is None
            message = self._attend(query, key, value)

        message = self.merge(message.reshape(bs, -1, self.nhead*self.dim))  # [N, L, C]
        message = rearrange(message, 'b (h w) c -> b c h w',
                            h=H1//self.pool_size, w=W1//self.pool_size)  # [N, C, H, W]
        message = self._upsample(message)

        if not self.norm_before and not self.vit_norm:
            message = self.norm1(message.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)  # [N, C, H, W]

        # feed-forward network
        if self.vit_norm:
            message_inter = (x + message)
            del x  # drop the local reference early
            message = self.norm2(message_inter.permute(0, 2, 3, 1))
            message = self.mlp(message).permute(0, 3, 1, 2)  # [N, C, H1, W1]
            return message_inter + message

        message = self.mlp(torch.cat([x, message], dim=1).permute(0, 2, 3, 1))  # [N, H1, W1, C]
        message = self.norm2(message).permute(0, 3, 1, 2)  # [N, C, H1, W1]
        return x + message
attention='linear', + pool_size=4, + bn=True, + xformer=False, + leaky=-1.0, + dw_conv=False, + dw_conv2=False, + scatter=False, + norm_before=True, + ): + super(AG_Conv_EncoderLayer, self).__init__() + + self.pool_size = pool_size + self.dw_conv = dw_conv + self.dw_conv2 = dw_conv2 + self.scatter = scatter + self.norm_before = norm_before + if self.dw_conv: + self.aggregate = nn.Conv2d(d_model, d_model, kernel_size=pool_size, padding=0, stride=pool_size, bias=False, groups=d_model) + if self.dw_conv2: + self.aggregate2 = nn.Conv2d(d_model, d_model, kernel_size=pool_size, padding=0, stride=pool_size, bias=False, groups=d_model) + self.dim = d_model // nhead + self.nhead = nhead + + self.max_pool = torch.nn.MaxPool2d(kernel_size=self.pool_size, stride=self.pool_size) + + # multi-head attention + if bn: + method = 'dw_bn' + else: + method = 'dw' + self.q_proj_conv = self._build_projection(d_model, d_model, method=method) + self.k_proj_conv = self._build_projection(d_model, d_model, method=method) + self.v_proj_conv = self._build_projection(d_model, d_model, method=method) + + if xformer: + self.attention = XAttention() + else: + self.attention = LinearAttention() if attention == 'linear' else FullAttention() + self.merge = nn.Linear(d_model, d_model, bias=False) + + # feed-forward network + if leaky > 0: + self.mlp = nn.Sequential( + nn.Linear(d_model*2, d_model*2, bias=False), + nn.LeakyReLU(leaky, True), + nn.Linear(d_model*2, d_model, bias=False), + ) + + else: + self.mlp = nn.Sequential( + nn.Linear(d_model*2, d_model*2, bias=False), + nn.ReLU(True), + nn.Linear(d_model*2, d_model, bias=False), + ) + + # norm and dropout + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + + def forward(self, x, source, x_mask=None, source_mask=None): + """ + Args: + x (torch.Tensor): [N, C, H1, W1] + source (torch.Tensor): [N, C, H2, W2] + x_mask (torch.Tensor): [N, H1, W1] (optional) (L = H1*W1) + source_mask (torch.Tensor): [N, H2, W2] (optional) (S = 
H2*W2) + """ + bs = x.size(0) + H1, W1 = x.size(-2), x.size(-1) + H2, W2 = source.size(-2), source.size(-1) + C = x.shape[-3] + + if self.norm_before: + if self.dw_conv: + query = self.norm1(self.aggregate(x).permute(0,2,3,1)).permute(0,3,1,2) + else: + query = self.norm1(self.max_pool(x).permute(0,2,3,1)).permute(0,3,1,2) + if self.dw_conv2: + source = self.norm1(self.aggregate2(source).permute(0,2,3,1)).permute(0,3,1,2) + else: + source = self.norm1(self.max_pool(source).permute(0,2,3,1)).permute(0,3,1,2) + else: + if self.dw_conv: + query = self.aggregate(x) + else: + query = self.max_pool(x) + if self.dw_conv2: + source = self.aggregate2(source) + else: + source = self.max_pool(source) + + key, value = source, source + + query = self.q_proj_conv(query) # [N, C, H1//pool, W1//pool] + key = self.k_proj_conv(key) + value = self.v_proj_conv(value) + + use_mask = x_mask is not None and source_mask is not None + if bs == 1 or not use_mask: + if use_mask: + x_mask = self.max_pool(x_mask.float()).bool() # [N, H1//pool, W1//pool] + source_mask = self.max_pool(source_mask.float()).bool() + + mask_h0, mask_w0 = x_mask[0].sum(-2)[0], x_mask[0].sum(-1)[0] + mask_h1, mask_w1 = source_mask[0].sum(-2)[0], source_mask[0].sum(-1)[0] + + query = query[:, :, :mask_h0, :mask_w0] + key = key[:, :, :mask_h1, :mask_w1] + value = value[:, :, :mask_h1, :mask_w1] + + else: + assert x_mask is None and source_mask is None + + if PFLASH_AVAILABLE: # N H L D + query = rearrange(query, 'n (nhead d) h w -> n nhead (h w) d', nhead=self.nhead, d=self.dim) + key = rearrange(key, 'n (nhead d) h w -> n nhead (h w) d', nhead=self.nhead, d=self.dim) + value = rearrange(value, 'n (nhead d) h w -> n nhead (h w) d', nhead=self.nhead, d=self.dim) + + else: # N L H D + query = rearrange(query, 'n (nhead d) h w -> n (h w) nhead d', nhead=self.nhead, d=self.dim) # [N, L, H, D] + key = rearrange(key, 'n (nhead d) h w -> n (h w) nhead d', nhead=self.nhead, d=self.dim) # [N, S, H, D] + value = rearrange(value, 
'n (nhead d) h w -> n (h w) nhead d', nhead=self.nhead, d=self.dim) # [N, S, H, D] + + message = self.attention(query, key, value, q_mask=None, kv_mask=None) # [N, L, H, D] or [N, H, L, D] + + if PFLASH_AVAILABLE: # N H L D + message = rearrange(message, 'n nhead L d -> n L nhead d', nhead=self.nhead, d=self.dim) + + if use_mask: # padding zero + message = message.view(bs, mask_h0, mask_w0, self.nhead, self.dim) # [N L H D] + if mask_h0 != x_mask.size(-2): + message = torch.cat([message, torch.zeros(message.size(0), x_mask.size(-2)-mask_h0, x_mask.size(-1), self.nhead, self.dim, device=message.device, dtype=message.dtype)], dim=1) + elif mask_w0 != x_mask.size(-1): + message = torch.cat([message, torch.zeros(message.size(0), x_mask.size(-2), x_mask.size(-1)-mask_w0, self.nhead, self.dim, device=message.device, dtype=message.dtype)], dim=2) + else: + assert x_mask is None and source_mask is None + + message = self.merge(message.view(bs, -1, self.nhead*self.dim)) # [N, L, C] + message = rearrange(message, 'b (h w) c -> b c h w', h=H1//self.pool_size, w=W1//self.pool_size) # [N, C, H, W] + + if self.scatter: + message = torch.repeat_interleave(message, self.pool_size, dim=-2) + message = torch.repeat_interleave(message, self.pool_size, dim=-1) + message = message * self.aggregate.weight.data.reshape(1, C, self.pool_size, self.pool_size).repeat(1,1,message.shape[-2]//self.pool_size,message.shape[-1]//self.pool_size) + else: + message = torch.nn.functional.interpolate(message, scale_factor=self.pool_size, mode='bilinear', align_corners=False) # [N, C, H1, W1] + + if not self.norm_before: + message = self.norm1(message.permute(0,2,3,1)).permute(0,3,1,2) # [N, C, H, W] + + # feed-forward network + message = self.mlp(torch.cat([x, message], dim=1).permute(0, 2, 3, 1)) # [N, H1, W1, C] + message = self.norm2(message).permute(0, 3, 1, 2) # [N, C, H1, W1] + + return x + message + else: # mask with bs > 1 + x_mask = self.max_pool(x_mask.float()).bool() + source_mask = 
self.max_pool(source_mask.float()).bool() + m_list = [] + for i in range(bs): + mask_h0, mask_w0 = x_mask[i].sum(-2)[0], x_mask[i].sum(-1)[0] + mask_h1, mask_w1 = source_mask[i].sum(-2)[0], source_mask[i].sum(-1)[0] + + q = query[i:i+1, :, :mask_h0, :mask_w0] + k = key[i:i+1, :, :mask_h1, :mask_w1] + v = value[i:i+1, :, :mask_h1, :mask_w1] + + if PFLASH_AVAILABLE: # N H L D + q = rearrange(q, 'n (nhead d) h w -> n nhead (h w) d', nhead=self.nhead, d=self.dim) + k = rearrange(k, 'n (nhead d) h w -> n nhead (h w) d', nhead=self.nhead, d=self.dim) + v = rearrange(v, 'n (nhead d) h w -> n nhead (h w) d', nhead=self.nhead, d=self.dim) + + else: # N L H D + q = rearrange(q, 'n (nhead d) h w -> n (h w) nhead d', nhead=self.nhead, d=self.dim) # [N, L, H, D] + k = rearrange(k, 'n (nhead d) h w -> n (h w) nhead d', nhead=self.nhead, d=self.dim) # [N, S, H, D] + v = rearrange(v, 'n (nhead d) h w -> n (h w) nhead d', nhead=self.nhead, d=self.dim) # [N, S, H, D] + + m = self.attention(q, k, v, q_mask=None, kv_mask=None) # [N, L, H, D] + + if PFLASH_AVAILABLE: # N H L D + m = rearrange(m, 'n nhead L d -> n L nhead d', nhead=self.nhead, d=self.dim) + + m = m.view(1, mask_h0, mask_w0, self.nhead, self.dim) + if mask_h0 != x_mask.size(-2): + m = torch.cat([m, torch.zeros(1, x_mask.size(-2)-mask_h0, x_mask.size(-1), self.nhead, self.dim, device=m.device, dtype=m.dtype)], dim=1) + elif mask_w0 != x_mask.size(-1): + m = torch.cat([m, torch.zeros(1, x_mask.size(-2), x_mask.size(-1)-mask_w0, self.nhead, self.dim, device=m.device, dtype=m.dtype)], dim=2) + m_list.append(m) + m = torch.cat(m_list, dim=0) + + m = self.merge(m.view(bs, -1, self.nhead*self.dim)) # [N, L, C] + + # m = m.reshape(bs, C, H1//self.pool_size, W1//self.pool_size) # [N, C, H, W] why this bug worked + m = rearrange(m, 'b (h w) c -> b c h w', h=H1//self.pool_size, w=W1//self.pool_size) # [N, C, H, W] + + if self.scatter: + m = torch.repeat_interleave(m, self.pool_size, dim=-2) + m = torch.repeat_interleave(m, 
class RoPELoFTREncoderLayer(nn.Module):
    """Encoder layer that mixes tokens with RoPE linear attention or a depth-wise conv.

    Exactly two configurations produce a usable layer:

    * ``token_mixer='dwcn'`` -- tokens are mixed by a 3x3 depth-wise convolution
      applied on the ``H x W`` grid; no q/k/v/merge projections are created.
    * ``token_mixer=None`` and ``rope=True`` -- multi-head attention delegated
      to ``RoPELinearAttention`` (``attention`` must be ``'linear'``).

    Any other combination used to construct a layer without ``self.attention``
    and crash on the first forward pass; it is now rejected in ``__init__``.
    """

    def __init__(self,
                 d_model,
                 nhead,
                 attention='linear',
                 rope=False,
                 token_mixer=None,
                 ):
        """
        Args:
            d_model (int): channel width of the tokens.
            nhead (int): number of attention heads (``d_model // nhead`` per head).
            attention (str): attention type; only ``'linear'`` is accepted on the RoPE path.
            rope (bool): enable the RoPE linear-attention path (used when ``token_mixer`` is None).
            token_mixer (str | None): ``'dwcn'`` for the depth-wise conv mixer, ``None`` for attention.

        Raises:
            ValueError: if the configuration would leave the layer without an attention module.
        """
        super(RoPELoFTREncoderLayer, self).__init__()

        # Fail fast: these configurations never assign ``self.attention``.
        if token_mixer is not None and token_mixer != 'dwcn':
            raise ValueError(
                "Unsupported token_mixer ({!r}); expected 'dwcn' or None".format(token_mixer))
        if token_mixer is None and not rope:
            raise ValueError(
                "RoPELoFTREncoderLayer needs rope=True when token_mixer is None")

        self.dim = d_model // nhead
        self.nhead = nhead
        self.rope = rope
        self.token_mixer = token_mixer

        # NOTE: sub-modules are created in the same order as before so that
        # parameter registration order and RNG consumption stay unchanged.
        if token_mixer is None:
            # multi-head attention
            self.q_proj = nn.Linear(d_model, d_model, bias=False)
            self.k_proj = nn.Linear(d_model, d_model, bias=False)
            self.v_proj = nn.Linear(d_model, d_model, bias=False)

            assert attention == 'linear'
            self.attention = RoPELinearAttention()
            self.merge = nn.Linear(d_model, d_model, bias=False)

            # feed-forward network on the concatenation [x, message]
            self.mlp = nn.Sequential(
                nn.Linear(d_model*2, d_model*2, bias=False),
                nn.ReLU(True),
                nn.Linear(d_model*2, d_model, bias=False),
            )
        else:
            # depth-wise 3x3 convolution used as the token mixer
            self.attention = nn.Sequential(OrderedDict([
                ('conv', nn.Conv2d(
                    d_model,
                    d_model,
                    kernel_size=3,
                    padding=1,
                    stride=1,
                    bias=False,
                    groups=d_model
                )),
            ]))

            # feed-forward network on x alone
            self.mlp = nn.Sequential(
                nn.Linear(d_model, d_model, bias=False),
                nn.ReLU(True),
                nn.Linear(d_model, d_model, bias=False),
            )

        # norm and dropout
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x, source, x_mask=None, source_mask=None, H=None, W=None):
        """
        Args:
            x (torch.Tensor): [N, L, C]
            source (torch.Tensor): [N, L, C]
            x_mask (torch.Tensor): [N, L] (optional)
            source_mask (torch.Tensor): [N, S] (optional)
            H (int): grid height of ``x``; required, ``H * W`` must equal L
            W (int): grid width of ``x``; required

        Returns:
            torch.Tensor: updated ``x`` with the same shape as the input.

        Raises:
            ValueError: if ``H`` or ``W`` is not provided.
        """
        if H is None or W is None:
            # Previously this surfaced as "unsupported operand type(s) for *".
            raise ValueError("H and W must be provided so that H * W equals the token count of x")
        assert H*W == x.size(-2)

        if self.token_mixer is not None:
            # pre-norm conv mixer with residual connections; source and masks are not used here
            m = self.norm1(x)
            m = rearrange(m, 'n (h w) c -> n c h w', h=H, w=W)
            m = self.attention(m)
            m = rearrange(m, 'n c h w -> n (h w) c')

            x = x + m
            x = x + self.mlp(self.norm2(x))
            return x

        bs = x.size(0)

        # multi-head attention
        query = self.q_proj(x).view(bs, -1, self.nhead, self.dim)  # [N, L, (H, D)]
        key = self.k_proj(source).view(bs, -1, self.nhead, self.dim)  # [N, S, (H, D)]
        value = self.v_proj(source).view(bs, -1, self.nhead, self.dim)
        message = self.attention(query, key, value, q_mask=x_mask, kv_mask=source_mask, H=H, W=W)  # [N, L, (H, D)]
        message = self.merge(message.view(bs, -1, self.nhead*self.dim))  # [N, L, C]
        message = self.norm1(message)

        # feed-forward network
        message = self.mlp(torch.cat([x, message], dim=2))
        message = self.norm2(message)

        return x + message
class PANEncoderLayer_cross(nn.Module):
    """Bidirectional cross-attention between two max-pooled feature maps.

    Both inputs are max-pooled by ``pool_size``, layer-normalised and projected
    with depth-wise convolutions (one projection shared for queries/keys, one
    for values).  A single similarity tensor is used to build the message for
    each direction; each message is merged, bilinearly upsampled back to the
    input resolution, passed through the feed-forward network and added to the
    corresponding input.
    """

    def __init__(self,
                 d_model,
                 nhead,
                 attention='linear',
                 pool_size=4,
                 bn=True,
                 ):
        """
        Args:
            d_model (int): channel width of the feature maps.
            nhead (int): number of attention heads.
            attention (str): kept for signature compatibility; not read here.
            pool_size (int): kernel and stride of the max-pooling applied before attention.
            bn (bool): use depth-wise conv + BatchNorm (``'dw_bn'``) instead of plain
                depth-wise conv (``'dw'``) for the projections.
        """
        super(PANEncoderLayer_cross, self).__init__()

        self.pool_size = pool_size

        self.dim = d_model // nhead
        self.nhead = nhead

        self.max_pool = torch.nn.MaxPool2d(kernel_size=self.pool_size, stride=self.pool_size)

        # multi-head attention
        method = 'dw_bn' if bn else 'dw'
        self.qk_proj_conv = self._build_projection(d_model, d_model, method=method)
        self.v_proj_conv = self._build_projection(d_model, d_model, method=method)

        # forward() computes attention inline and never calls this module; it is
        # kept so the set of registered sub-modules is unchanged.
        self.attention = FullAttention()
        self.merge = nn.Linear(d_model, d_model, bias=False)

        # feed-forward network
        self.mlp = nn.Sequential(
            nn.Linear(d_model*2, d_model*2, bias=False),
            nn.ReLU(True),
            nn.Linear(d_model*2, d_model, bias=False),
        )

        # norm and dropout
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    @staticmethod
    def _pad_message(message, height, width):
        """Zero-pad a [N, h, w, nhead, dim] message at the bottom/right to [N, height, width, nhead, dim]."""
        pad_h = int(height) - message.size(1)
        pad_w = int(width) - message.size(2)
        if pad_h == 0 and pad_w == 0:
            return message
        # F.pad lists (before, after) pairs starting from the LAST dimension:
        # dim, nhead, w, h.  Padding both axes in one call also covers the case
        # where height and width were cropped, which the former if/elif
        # torch.cat logic could not handle (shape mismatch in torch.cat).
        return torch.nn.functional.pad(message, (0, 0, 0, 0, 0, pad_w, 0, pad_h))

    def _pool_and_norm(self, feat):
        """Max-pool a [N, C, H, W] map and apply ``norm1`` over the channel axis."""
        pooled = self.max_pool(feat).permute(0, 2, 3, 1)
        return self.norm1(pooled).permute(0, 3, 1, 2)

    def _merge_and_update(self, message, feat, height, width, valid_hw=None):
        """Turn a flat attention message into the residual update for ``feat``.

        ``valid_hw`` is the (height, width) of the unmasked region the message
        was computed on, or ``None`` when no mask was used.
        """
        bs = feat.size(0)
        if valid_hw is not None:
            message = message.view(bs, valid_hw[0], valid_hw[1], self.nhead, self.dim)
            message = self._pad_message(message, height, width)

        message = self.merge(message.reshape(bs, -1, self.nhead*self.dim))  # [N, L, C]
        message = rearrange(message, 'b (h w) c -> b c h w', h=height, w=width)  # [N, C, H, W]
        message = torch.nn.functional.interpolate(
            message, scale_factor=self.pool_size, mode='bilinear', align_corners=False)  # [N, C, H*pool, W*pool]
        # feed-forward network
        message = self.mlp(torch.cat([feat, message], dim=1).permute(0, 2, 3, 1))  # [N, H, W, C]
        return self.norm2(message).permute(0, 3, 1, 2)  # [N, C, H, W]

    def forward(self, x1, x2, x1_mask=None, x2_mask=None):
        """
        Args:
            x1 (torch.Tensor): [N, C, H1, W1]
            x2 (torch.Tensor): [N, C, H2, W2]
            x1_mask (torch.Tensor): [N, H1, W1] (optional)
            x2_mask (torch.Tensor): [N, H2, W2] (optional)

        Masks must be given for both inputs or for neither.  The valid extent is
        read from the first sample of each mask and applied to the whole batch.

        Returns:
            tuple[torch.Tensor, torch.Tensor]: updated ``x1`` and ``x2``.
        """
        H1, W1 = x1.size(-2) // self.pool_size, x1.size(-1) // self.pool_size
        H2, W2 = x2.size(-2) // self.pool_size, x2.size(-1) // self.pool_size

        # Pool + normalise each input once; the result feeds both projections.
        pooled1 = self._pool_and_norm(x1)
        pooled2 = self._pool_and_norm(x2)

        # multi-head attention (projection call order kept: qk, qk, v, v)
        query = self.qk_proj_conv(pooled1)  # [N, C, H1//pool, W1//pool]
        key = self.qk_proj_conv(pooled2)
        v2 = self.v_proj_conv(pooled2)
        v1 = self.v_proj_conv(pooled1)

        use_mask = x1_mask is not None and x2_mask is not None
        valid1 = valid2 = None
        if use_mask:
            x1_mask = self.max_pool(x1_mask.float()).bool()  # [N, H1//pool, W1//pool]
            x2_mask = self.max_pool(x2_mask.float()).bool()

            # extent of the valid region, taken from the first sample
            mask_h1, mask_w1 = x1_mask[0].sum(-2)[0], x1_mask[0].sum(-1)[0]
            mask_h2, mask_w2 = x2_mask[0].sum(-2)[0], x2_mask[0].sum(-1)[0]

            query = query[:, :, :mask_h1, :mask_w1]
            key = key[:, :, :mask_h2, :mask_w2]
            v1 = v1[:, :, :mask_h1, :mask_w1]
            v2 = v2[:, :, :mask_h2, :mask_w2]
            x1_mask = x1_mask[:, :mask_h1, :mask_w1].flatten(-2)
            x2_mask = x2_mask[:, :mask_h2, :mask_w2].flatten(-2)
            valid1, valid2 = (mask_h1, mask_w1), (mask_h2, mask_w2)
        else:
            assert x1_mask is None and x2_mask is None

        query = rearrange(query, 'n (nhead d) h w -> n (h w) nhead d', nhead=self.nhead, d=self.dim)  # [N, L, H, D]
        key = rearrange(key, 'n (nhead d) h w -> n (h w) nhead d', nhead=self.nhead, d=self.dim)  # [N, S, H, D]
        v2 = rearrange(v2, 'n (nhead d) h w -> n (h w) nhead d', nhead=self.nhead, d=self.dim)  # [N, S, H, D]
        v1 = rearrange(v1, 'n (nhead d) h w -> n (h w) nhead d', nhead=self.nhead, d=self.dim)  # [N, L, H, D]

        QK = torch.einsum("nlhd,nshd->nlsh", query, key)
        with torch.autocast(enabled=False, device_type='cuda'):
            if use_mask:
                QK = QK.float().masked_fill_(~(x1_mask[:, :, None, None] * x2_mask[:, None, :, None]), -1e9)  # float('-inf')

        # Compute the attention and the weighted average
        softmax_temp = 1. / query.size(3)**.5  # sqrt(D)
        S1 = torch.softmax(softmax_temp * QK, dim=2)
        # NOTE(review): S2 is normalised over dim=3, which is the head axis of
        # QK ([N, L, S, H]); the reverse direction would usually normalise over
        # L (dim=1).  Left unchanged because it affects trained weights -- confirm.
        S2 = torch.softmax(softmax_temp * QK, dim=3)

        m1 = torch.einsum("nlsh,nshd->nlhd", S1, v2)
        m2 = torch.einsum("nlsh,nlhd->nshd", S2, v1)

        m1 = self._merge_and_update(m1, x1, H1, W1, valid1)
        m2 = self._merge_and_update(m2, x2, H2, W2, valid2)

        return x1 + m1, x2 + m2

    def _build_projection(self,
                          dim_in,
                          dim_out,
                          kernel_size=3,
                          padding=1,
                          stride=1,
                          method='dw_bn',
                          ):
        """Build the projection module for ``method``.

        ``'dw'`` is a depth-wise convolution, ``'dw_bn'`` adds BatchNorm,
        ``'avg'`` is average pooling and ``'linear'`` yields ``None``.
        ``dim_out`` is accepted but not used.

        Raises:
            ValueError: for an unknown ``method``.
        """
        if method in ('dw_bn', 'dw'):
            layers = [
                ('conv', nn.Conv2d(
                    dim_in,
                    dim_in,
                    kernel_size=kernel_size,
                    padding=padding,
                    stride=stride,
                    bias=False,
                    groups=dim_in
                )),
            ]
            if method == 'dw_bn':
                layers.append(('bn', nn.BatchNorm2d(dim_in)))
            return nn.Sequential(OrderedDict(layers))
        if method == 'avg':
            return nn.Sequential(OrderedDict([
                ('avg', nn.AvgPool2d(
                    kernel_size=kernel_size,
                    padding=padding,
                    stride=stride,
                    ceil_mode=True
                )),
            ]))
        if method == 'linear':
            return None
        raise ValueError('Unknown method ({})'.format(method))
self_layer = PANEncoderLayer(config['d_model'], config['nhead'], config['attention'], config['pool_size'], config['bn'], config['xformer']) + cross_layer = PANEncoderLayer_cross(config['d_model'], config['nhead'], config['attention'], config['pool_size'], config['bn']) + self.layers = nn.ModuleList([copy.deepcopy(self_layer) if _ == 'self' else copy.deepcopy(cross_layer) for _ in self.layer_names]) + else: + if self.aggregate: + if self.rope: + # assert config['npe'][0] == 832 and config['npe'][1] == 832 and config['npe'][2] == 832 and config['npe'][3] == 832 + logger.info(f'npe trainH,trainW,testH,testW: {config["npe"][0]}, {config["npe"][1]}, {config["npe"][2]}, {config["npe"][3]}') + self_layer = AG_RoPE_EncoderLayer(config['d_model'], config['nhead'], config['attention'], config['pool_size'], config['pool_size2'], + config['xformer'], config['leaky'], config['dwconv'], config['dwconv2'], config['scatter'], + config['norm_before'], config['rope'], config['npe'], config['vit_norm'], config['rope_dwproj']) + cross_layer = AG_RoPE_EncoderLayer(config['d_model'], config['nhead'], config['attention'], config['pool_size'], config['pool_size2'], + config['xformer'], config['leaky'], config['dwconv'], config['dwconv2'], config['scatter'], + config['norm_before'], False, config['npe'], config['vit_norm'], config['rope_dwproj']) + self.layers = nn.ModuleList([copy.deepcopy(self_layer) if _ == 'self' else copy.deepcopy(cross_layer) for _ in self.layer_names]) + elif self.abspe: + logger.info(f'npe trainH,trainW,testH,testW: {config["npe"][0]}, {config["npe"][1]}, {config["npe"][2]}, {config["npe"][3]}') + self_layer = AG_RoPE_EncoderLayer(config['d_model'], config['nhead'], config['attention'], config['pool_size'], config['pool_size2'], + config['xformer'], config['leaky'], config['dwconv'], config['dwconv2'], config['scatter'], + config['norm_before'], False, config['npe'], config['vit_norm'], config['rope_dwproj']) + cross_layer = AG_RoPE_EncoderLayer(config['d_model'], 
config['nhead'], config['attention'], config['pool_size'], config['pool_size2'], + config['xformer'], config['leaky'], config['dwconv'], config['dwconv2'], config['scatter'], + config['norm_before'], False, config['npe'], config['vit_norm'], config['rope_dwproj']) + self.layers = nn.ModuleList([copy.deepcopy(self_layer) if _ == 'self' else copy.deepcopy(cross_layer) for _ in self.layer_names]) + + else: + encoder_layer = AG_Conv_EncoderLayer(config['d_model'], config['nhead'], config['attention'], config['pool_size'], config['bn'], + config['xformer'], config['leaky'], config['dwconv'], config['scatter'], + config['norm_before']) + self.layers = nn.ModuleList([copy.deepcopy(encoder_layer) for _ in range(len(self.layer_names))]) + else: + encoder_layer = PANEncoderLayer(config['d_model'], config['nhead'], config['attention'], config['pool_size'], + config['bn'], config['xformer'], config['leaky'], config['dwconv'], config['scatter']) \ + if config['pan'] else LoFTREncoderLayer(config['d_model'], config['nhead'], + config['attention'], config['xformer']) + self.layers = nn.ModuleList([copy.deepcopy(encoder_layer) for _ in range(len(self.layer_names))]) + self._reset_parameters() + + def _reset_parameters(self): + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + def forward(self, feat0, feat1, mask0=None, mask1=None, data=None): + """ + Args: + feat0 (torch.Tensor): [N, C, H, W] + feat1 (torch.Tensor): [N, C, H, W] + mask0 (torch.Tensor): [N, L] (optional) + mask1 (torch.Tensor): [N, S] (optional) + """ + # nchw for pan and n(hw)c for loftr + assert self.d_model == feat0.size(1) or self.d_model == feat0.size(-1), "the feature number of src and transformer must be equal" + H0, W0, H1, W1 = feat0.size(-2), feat0.size(-1), feat1.size(-2), feat1.size(-1) + bs = feat0.shape[0] + padding = False + if bs == 1 and mask0 is not None and mask1 is not None and self.pan: # NCHW for pan + mask_H0, mask_W0 = mask0.size(-2), mask0.size(-1) + mask_H1, 
mask_W1 = mask1.size(-2), mask1.size(-1) + mask_h0, mask_w0 = mask0[0].sum(-2)[0], mask0[0].sum(-1)[0] + mask_h1, mask_w1 = mask1[0].sum(-2)[0], mask1[0].sum(-1)[0] + + #round to self.pool_size + if self.pan: + mask_h0, mask_w0, mask_h1, mask_w1 = mask_h0//self.pool_size*self.pool_size, mask_w0//self.pool_size*self.pool_size, mask_h1//self.pool_size*self.pool_size, mask_w1//self.pool_size*self.pool_size + + feat0 = feat0[:, :, :mask_h0, :mask_w0] + feat1 = feat1[:, :, :mask_h1, :mask_w1] + + padding = True + + # rope self only + # if self.rope: + # feat0, feat1 = rearrange(feat0, 'b c h w -> b (h w) c'), rearrange(feat1, 'b c h w -> b (h w) c') + # prune + if padding: + l0, l1 = mask_h0 * mask_w0, mask_h1 * mask_w1 + else: + l0, l1 = H0 * W0, H1 * W1 + do_early_stop = self.depth_confidence > 0 + do_point_pruning = self.width_confidence > 0 + if do_point_pruning: + ind0 = torch.arange(0, l0, device=feat0.device)[None] + ind1 = torch.arange(0, l1, device=feat0.device)[None] + # We store the index of the layer at which pruning is detected. 
+ prune0 = torch.ones_like(ind0) + prune1 = torch.ones_like(ind1) + if do_early_stop: + token0, token1 = None, None + + for i, (layer, name) in enumerate(zip(self.layers, self.layer_names)): + if padding: + mask0, mask1 = None, None + if name == 'self': + # if self.rope: + # feat0 = layer(feat0, feat0, mask0, mask1, H0, W0) + # feat1 = layer(feat1, feat1, mask0, mask1, H1, W1) + if self.asymmetric: + assert False, 'not worked' + # feat0 = layer(feat0, feat0, mask0, mask1) + feat1 = layer(feat1, feat1, mask1, mask1) + else: + feat0 = layer(feat0, feat0, mask0, mask0) + feat1 = layer(feat1, feat1, mask1, mask1) + elif name == 'cross': + if self.bidirect: + feat0, feat1 = layer(feat0, feat1, mask0, mask1) + else: + if self.asymmetric or self.asymmetric_self: + assert False, 'not worked' + feat0 = layer(feat0, feat1, mask0, mask1) + else: + feat0 = layer(feat0, feat1, mask0, mask1) + feat1 = layer(feat1, feat0, mask1, mask0) + + if i == len(self.layer_names) - 1 and not self.training: + continue + if self.matchability: + desc0, desc1 = rearrange(feat0, 'b c h w -> b (h w) c'), rearrange(feat1, 'b c h w -> b (h w) c') + if do_early_stop: + token0, token1 = self.token_confidence[i//2](desc0, desc1) + if self.check_if_stop(token0, token1, i, l0+l1) and not self.training: + break + if do_point_pruning: + scores0, scores1 = self.log_assignment[i//2].scores(desc0, desc1) + mask0 = self.get_pruning_mask(token0, scores0, i) + mask1 = self.get_pruning_mask(token1, scores1, i) + ind0, ind1 = ind0[mask0][None], ind1[mask1][None] + feat0, feat1 = desc0[mask0][None], desc1[mask1][None] + if feat0.shape[-2] == 0 or desc1.shape[-2] == 0: + break + prune0[:, ind0] += 1 + prune1[:, ind1] += 1 + if self.training and self.matchability: + scores, _, matchability0, matchability1 = self.log_assignment[i//2](desc0, desc1) + m0_full = torch.zeros((bs, mask_h0 * mask_w0), device=matchability0.device, dtype=matchability0.dtype) + m0_full.scatter(1, ind0, matchability0.squeeze(-1)) + if padding 
and self.d_model == feat0.size(1): + m0_full = m0_full.reshape(bs, mask_h0, mask_w0) + bs, c, mask_h0, mask_w0 = feat0.size() + if mask_h0 != mask_H0: + m0_full = torch.cat([m0_full, torch.zeros(bs, mask_H0-mask_h0, mask_w0, device=m0_full.device, dtype=m0_full.dtype)], dim=1) + elif mask_w0 != mask_W0: + m0_full = torch.cat([m0_full, torch.zeros(bs, mask_h0, mask_W0-mask_w0, device=m0_full.device, dtype=m0_full.dtype)], dim=2) + m0_full = m0_full.reshape(bs, mask_H0*mask_W0) + m1_full = torch.zeros((bs, mask_h1 * mask_w1), device=matchability0.device, dtype=matchability0.dtype) + m1_full.scatter(1, ind1, matchability1.squeeze(-1)) + if padding and self.d_model == feat1.size(1): + m1_full = m1_full.reshape(bs, mask_h1, mask_w1) + bs, c, mask_h1, mask_w1 = feat1.size() + if mask_h1 != mask_H1: + m1_full = torch.cat([m1_full, torch.zeros(bs, mask_H1-mask_h1, mask_w1, device=m1_full.device, dtype=m1_full.dtype)], dim=1) + elif mask_w1 != mask_W1: + m1_full = torch.cat([m1_full, torch.zeros(bs, mask_h1, mask_W1-mask_w1, device=m1_full.device, dtype=m1_full.dtype)], dim=2) + m1_full = m1_full.reshape(bs, mask_H1*mask_W1) + data.update({'matchability0_'+str(i//2): m0_full, 'matchability1_'+str(i//2): m1_full}) + m0, m1, mscores0, mscores1 = filter_matches( + scores, self.thr) + if do_point_pruning: + m0_ = torch.full((bs, l0), -1, device=m0.device, dtype=m0.dtype) + m1_ = torch.full((bs, l1), -1, device=m1.device, dtype=m1.dtype) + m0_[:, ind0] = torch.where( + m0 == -1, -1, ind1.gather(1, m0.clamp(min=0))) + m1_[:, ind1] = torch.where( + m1 == -1, -1, ind0.gather(1, m1.clamp(min=0))) + mscores0_ = torch.zeros((bs, l0), device=mscores0.device) + mscores1_ = torch.zeros((bs, l1), device=mscores1.device) + mscores0_[:, ind0] = mscores0 + mscores1_[:, ind1] = mscores1 + m0, m1, mscores0, mscores1 = m0_, m1_, mscores0_, mscores1_ + if padding and self.d_model == feat0.size(1): + m0 = m0.reshape(bs, mask_h0, mask_w0) + bs, c, mask_h0, mask_w0 = feat0.size() + if mask_h0 != 
mask_H0: + m0 = torch.cat([m0, -torch.ones(bs, mask_H0-mask_h0, mask_w0, device=m0.device, dtype=m0.dtype)], dim=1) + elif mask_w0 != mask_W0: + m0 = torch.cat([m0, -torch.ones(bs, mask_h0, mask_W0-mask_w0, device=m0.device, dtype=m0.dtype)], dim=2) + m0 = m0.reshape(bs, mask_H0*mask_W0) + if padding and self.d_model == feat1.size(1): + m1 = m1.reshape(bs, mask_h1, mask_w1) + bs, c, mask_h1, mask_w1 = feat1.size() + if mask_h1 != mask_H1: + m1 = torch.cat([m1, -torch.ones(bs, mask_H1-mask_h1, mask_w1, device=m1.device, dtype=m1.dtype)], dim=1) + elif mask_w1 != mask_W1: + m1 = torch.cat([m1, -torch.ones(bs, mask_h1, mask_W1-mask_w1, device=m1.device, dtype=m1.dtype)], dim=2) + m1 = m1.reshape(bs, mask_H1*mask_W1) + data.update({'matches0_'+str(i//2): m0, 'matches1_'+str(i//2): m1}) + conf = torch.zeros((bs, l0 * l1), device=scores.device, dtype=scores.dtype) + ind = ind0[...,None] * l1 + ind1[:,None,:] + # conf[ind.reshape(bs, -1)] = scores.reshape(bs, -1).exp() + conf.scatter(1, ind.reshape(bs, -1), scores.reshape(bs, -1).exp()) + if padding and self.d_model == feat0.size(1): + conf = conf.reshape(bs, mask_h0, mask_w0, mask_h1, mask_w1) + bs, c, mask_h0, mask_w0 = feat0.size() + if mask_h0 != mask_H0: + conf = torch.cat([conf, torch.zeros(bs, mask_H0-mask_h0, mask_w0, mask_h1, mask_w1, device=conf.device, dtype=conf.dtype)], dim=1) + elif mask_w0 != mask_W0: + conf = torch.cat([conf, torch.zeros(bs, mask_h0, mask_W0-mask_w0, mask_h1, mask_w1, device=conf.device, dtype=conf.dtype)], dim=2) + bs, c, mask_h1, mask_w1 = feat1.size() + if mask_h1 != mask_H1: + conf = torch.cat([conf, torch.zeros(bs, mask_H0, mask_W0, mask_H1-mask_h1, mask_W1, device=conf.device, dtype=conf.dtype)], dim=3) + elif mask_w1 != mask_W1: + conf = torch.cat([conf, torch.zeros(bs, mask_H0, mask_W0, mask_H1, mask_W1-mask_w1, device=conf.device, dtype=conf.dtype)], dim=4) + conf = conf.reshape(bs, mask_H0*mask_W0, mask_H1*mask_W1) + data.update({'conf_matrix_'+str(i//2): conf}) + + + + else: + 
raise KeyError + + if self.matchability and not self.training: + scores, _, matchability0, matchability1 = self.log_assignment[i//2](desc0, desc1) + conf = torch.zeros((bs, l0 * l1), device=scores.device, dtype=scores.dtype) + ind = ind0[...,None] * l1 + ind1[:,None,:] + # conf[ind.reshape(bs, -1)] = scores.reshape(bs, -1).exp() + conf.scatter(1, ind.reshape(bs, -1), scores.reshape(bs, -1).exp()) + if padding and self.d_model == feat0.size(1): + conf = conf.reshape(bs, mask_h0, mask_w0, mask_h1, mask_w1) + bs, c, mask_h0, mask_w0 = feat0.size() + if mask_h0 != mask_H0: + conf = torch.cat([conf, torch.zeros(bs, mask_H0-mask_h0, mask_w0, mask_h1, mask_w1, device=conf.device, dtype=conf.dtype)], dim=1) + elif mask_w0 != mask_W0: + conf = torch.cat([conf, torch.zeros(bs, mask_h0, mask_W0-mask_w0, mask_h1, mask_w1, device=conf.device, dtype=conf.dtype)], dim=2) + bs, c, mask_h1, mask_w1 = feat1.size() + if mask_h1 != mask_H1: + conf = torch.cat([conf, torch.zeros(bs, mask_H0, mask_W0, mask_H1-mask_h1, mask_W1, device=conf.device, dtype=conf.dtype)], dim=3) + elif mask_w1 != mask_W1: + conf = torch.cat([conf, torch.zeros(bs, mask_H0, mask_W0, mask_H1, mask_W1-mask_w1, device=conf.device, dtype=conf.dtype)], dim=4) + conf = conf.reshape(bs, mask_H0*mask_W0, mask_H1*mask_W1) + data.update({'conf_matrix': conf}) + data.update(**self.CoarseMatching.get_coarse_match(conf, data)) + # m0, m1, mscores0, mscores1 = filter_matches( + # scores, self.conf.filter_threshold) + + # matches, mscores = [], [] + # for k in range(b): + # valid = m0[k] > -1 + # m_indices_0 = torch.where(valid)[0] + # m_indices_1 = m0[k][valid] + # if do_point_pruning: + # m_indices_0 = ind0[k, m_indices_0] + # m_indices_1 = ind1[k, m_indices_1] + # matches.append(torch.stack([m_indices_0, m_indices_1], -1)) + # mscores.append(mscores0[k][valid]) + + # # TODO: Remove when hloc switches to the compact format. 
def pro(self, feat0, feat1, mask0=None, mask1=None, profiler=None):
    """Run every attention layer on both feature sets, optionally under a profiler.

    Args:
        feat0 (torch.Tensor): [N, C, H, W]
        feat1 (torch.Tensor): [N, C, H, W]
        mask0 (torch.Tensor): [N, L] (optional)
        mask1 (torch.Tensor): [N, S] (optional)
        profiler: optional object exposing ``profile(name)`` that returns a
            context manager. When ``None`` no profiling scope is opened.

    Returns:
        tuple: the updated ``(feat0, feat1)``.

    Raises:
        AssertionError: if neither dim 1 nor the last dim of ``feat0``
            equals ``self.d_model``.
        KeyError: if a layer name is neither ``'self'`` nor ``'cross'``.
    """
    # Local import keeps this method self-contained with respect to the
    # module's import block.
    from contextlib import nullcontext

    assert self.d_model == feat0.size(1) or self.d_model == feat0.size(-1), "the feature number of src and transformer must be equal"

    # `profiler` defaults to None, so fall back to a no-op context manager
    # instead of dereferencing None.
    if profiler is not None:
        scope = profiler.profile("LoFTR_transformer_attention")
    else:
        scope = nullcontext()

    with scope:
        for layer, name in zip(self.layers, self.layer_names):
            if name == 'self':
                feat0 = layer.pro(feat0, feat0, mask0, mask0, profiler=profiler)
                feat1 = layer.pro(feat1, feat1, mask1, mask1, profiler=profiler)
            elif name == 'cross':
                feat0 = layer.pro(feat0, feat1, mask0, mask1, profiler=profiler)
                # feat1 attends to the feat0 that was just updated above.
                feat1 = layer.pro(feat1, feat0, mask1, mask0, profiler=profiler)
            else:
                # Same exception type as before, now naming the offender.
                raise KeyError(name)

    # NOTE(review): the layers still receive the caller's `profiler` value
    # unchanged; whether `layer.pro` tolerates None is not visible here.
    return feat0, feat1
def sigmoid_log_double_softmax(
        sim: torch.Tensor, z0: torch.Tensor, z1: torch.Tensor
) -> "tuple[torch.Tensor, torch.Tensor, torch.Tensor]":
    """Create the log assignment matrix from logits and similarity.

    Args:
        sim: pairwise similarity; indexed here as a 3-D tensor whose last
            two dims are the two point sets (assumes [B, M, N] — TODO confirm).
        z0: matchability logits for the first set; must broadcast against
            ``sim`` (assumes [B, M, 1] — TODO confirm).
        z1: matchability logits for the second set; its dims 1 and 2 are
            swapped before being added (assumes [B, N, 1] — TODO confirm).

    Returns:
        ``(scores, m0, m1)`` where ``scores`` is the sum of the row-wise
        log-softmax, the column-wise log-softmax and the log-matchability
        of both sides, and ``m0`` / ``m1`` are ``sigmoid(z0)`` / ``sigmoid(z1)``.
    """
    m0, m1 = torch.sigmoid(z0), torch.sigmoid(z1)
    # logsigmoid is the numerically stable form of log(sigmoid(z)): the
    # naive composition underflows to log(0) = -inf for very negative logits.
    certainties = F.logsigmoid(z0) + F.logsigmoid(z1).transpose(1, 2)
    scores0 = F.log_softmax(sim, 2)
    scores1 = F.log_softmax(
        sim.transpose(-1, -2).contiguous(), 2).transpose(-1, -2)
    scores = scores0 + scores1 + certainties
    return scores, m0, m1
def mask_border(m, b: int, v):
    """Overwrite a border of width ``b`` with ``v`` on dims 1..4 of ``m``, in place.

    Args:
        m (torch.Tensor): [N, H0, W0, H1, W1]
        b (int): border width; nothing is written when ``b <= 0``
        v (m.dtype): value stored in the border cells

    Returns:
        None. ``m`` is modified in place.
    """
    if b > 0:
        for axis in range(1, 5):
            # Full slices on every dim in front of `axis`.
            lead = (slice(None),) * axis
            m[lead + (slice(None, b),)] = v   # first b entries along `axis`
            m[lead + (slice(-b, None),)] = v  # last b entries along `axis`
def compute_max_candidates(p_m0, p_m1):
    """Compute the max candidates of all pairs within a batch.

    For each mask, the largest sum along dim 1 is multiplied by the largest
    sum along the last dim. The smaller of the two products is taken per
    pair, and those minima are summed over the batch.

    Args:
        p_m0, p_m1 (torch.Tensor): padded masks
            (assumes [N, H, W] — TODO confirm against caller)

    Returns:
        torch.Tensor: scalar tensor holding the summed per-pair minimum.
    """
    def _extent_product(p_m):
        along_dim1 = p_m.sum(1).max(-1)[0]
        along_last = p_m.sum(-1).max(-1)[0]
        return along_dim1 * along_last

    per_pair = torch.stack(
        [_extent_product(p_m0), _extent_product(p_m1)], -1)
    return per_pair.min(-1)[0].sum()
@torch.no_grad()
def get_coarse_match(self, conf_matrix, data):
    """Select coarse matches from a confidence matrix and lift them to image coordinates.

    Args:
        conf_matrix (torch.Tensor): [N, L, S]
        data (dict): with keys ['hw0_i', 'hw1_i', 'hw0_c', 'hw1_c'].
            Also read when present: 'mask0' / 'mask1' (padding masks) and
            'scale0' / 'scale1'. While ``self.training`` is set, the keys
            'spv_b_ids', 'spv_i_ids' and 'spv_j_ids' are required.
    Returns:
        coarse_matches (dict): {
            'b_ids' (torch.Tensor): [M'],
            'i_ids' (torch.Tensor): [M'],
            'j_ids' (torch.Tensor): [M'],
            'gt_mask' (torch.Tensor): [M'],
            'm_bids' (torch.Tensor): [M],
            'm_bids_f' (torch.Tensor): [3 * M], each entry of m_bids repeated 3 times,
            'mkpts0_c' (torch.Tensor): [M, 2],
            'mkpts1_c' (torch.Tensor): [M, 2],
            'mconf' (torch.Tensor): [M]}
    """
    axes_lengths = {
        'h0c': data['hw0_c'][0],
        'w0c': data['hw0_c'][1],
        'h1c': data['hw1_c'][0],
        'w1c': data['hw1_c'][1]
    }
    _device = conf_matrix.device
    # 1. confidence thresholding
    mask = conf_matrix > self.thr
    # Unflatten to 5-D so the border helpers can address each spatial axis.
    mask = rearrange(mask, 'b (h0c w0c) (h1c w1c) -> b h0c w0c h1c w1c',
                     **axes_lengths)
    if 'mask0' not in data:
        mask_border(mask, self.border_rm, False)
    else:
        mask_border_with_padding(mask, self.border_rm, False,
                                 data['mask0'], data['mask1'])
    mask = rearrange(mask, 'b h0c w0c h1c w1c -> b (h0c w0c) (h1c w1c)',
                     **axes_lengths)

    # 2. mutual nearest
    if self.mtd:
        # With `self.mtd` set, the mutual-nearest filter is skipped and every
        # above-threshold (i, j) pair is kept.
        b_ids, i_ids, j_ids = torch.where(mask)
        mconf = conf_matrix[b_ids, i_ids, j_ids]
    else:
        # Keep an entry only if it is the maximum of both its row and its column.
        mask = mask \
            * (conf_matrix == conf_matrix.max(dim=2, keepdim=True)[0]) \
            * (conf_matrix == conf_matrix.max(dim=1, keepdim=True)[0])

        # 3. find all valid coarse matches
        # this only works when at most one `True` in each row
        mask_v, all_j_ids = mask.max(dim=2)
        b_ids, i_ids = torch.where(mask_v)
        j_ids = all_j_ids[b_ids, i_ids]
        mconf = conf_matrix[b_ids, i_ids, j_ids]

    # 4. Random sampling of training samples for fine-level LoFTR
    # (optional) pad samples with gt coarse-level matches
    if self.training:
        # NOTE:
        # The sampling is performed across all pairs in a batch without manually balancing
        # #samples for fine-level increases w.r.t. batch_size
        if 'mask0' not in data:
            num_candidates_max = mask.size(0) * max(
                mask.size(1), mask.size(2))
        else:
            num_candidates_max = compute_max_candidates(
                data['mask0'], data['mask1'])
        num_matches_train = int(num_candidates_max *
                                self.train_coarse_percent)
        num_matches_pred = len(b_ids)
        assert self.train_pad_num_gt_min < num_matches_train, "min-num-gt-pad should be less than num-train-matches"

        # pred_indices is to select from prediction
        if num_matches_pred <= num_matches_train - self.train_pad_num_gt_min:
            # Few enough predictions: keep all of them, in order.
            pred_indices = torch.arange(num_matches_pred, device=_device)
        else:
            # Too many predictions: draw indices at random (torch.randint, so
            # the same index can be drawn more than once).
            pred_indices = torch.randint(
                num_matches_pred,
                (num_matches_train - self.train_pad_num_gt_min, ),
                device=_device)

        # gt_pad_indices is to select from gt padding. e.g. max(3787-4800, 200)
        gt_pad_indices = torch.randint(
            len(data['spv_b_ids']),
            (max(num_matches_train - num_matches_pred,
                 self.train_pad_num_gt_min), ),
            device=_device)
        mconf_gt = torch.zeros(len(data['spv_b_ids']), device=_device)  # set conf of gt paddings to all zero

        # Concatenate the selected predictions with the selected gt paddings,
        # field by field (b_ids, i_ids, j_ids, mconf).
        b_ids, i_ids, j_ids, mconf = map(
            lambda x, y: torch.cat([x[pred_indices], y[gt_pad_indices]],
                                   dim=0),
            *zip([b_ids, data['spv_b_ids']], [i_ids, data['spv_i_ids']],
                 [j_ids, data['spv_j_ids']], [mconf, mconf_gt]))

    # These matches select patches that feed into fine-level network
    coarse_matches = {'b_ids': b_ids, 'i_ids': i_ids, 'j_ids': j_ids}

    # 5. Update with matches in original image resolution
    if self.fix_bias:
        scale = 8
    else:
        scale = data['hw0_i'][0] / data['hw0_c'][0]
    scale0 = scale * data['scale0'][b_ids] if 'scale0' in data else scale
    scale1 = scale * data['scale1'][b_ids] if 'scale1' in data else scale
    # Flat index -> (x, y): column is `idx % width`, row is `idx // width`.
    mkpts0_c = torch.stack(
        [i_ids % data['hw0_c'][1], i_ids // data['hw0_c'][1]],
        dim=1) * scale0
    mkpts1_c = torch.stack(
        [j_ids % data['hw1_c'][1], j_ids // data['hw1_c'][1]],
        dim=1) * scale1

    # Entries with zero confidence are treated as gt paddings (see `mconf_gt`
    # above) and are filtered out of the predicted-match outputs below.
    m_bids = b_ids[mconf != 0]

    m_bids_f = repeat(m_bids, 'b -> b k', k = 3).reshape(-1)
    coarse_matches.update({
        'gt_mask': mconf == 0,
        'm_bids': m_bids,  # mconf == 0 => gt matches
        'm_bids_f': m_bids_f,
        'mkpts0_c': mkpts0_c[mconf != 0],
        'mkpts1_c': mkpts1_c[mconf != 0],
        'mconf': mconf[mconf != 0]
    })

    return coarse_matches
self.normfeat = config['match_fine']['normfeat'] + self.use_sigmoid = config['match_fine']['use_sigmoid'] + self.local_regress = config['match_fine']['local_regress'] + self.local_regress_rmborder = config['match_fine']['local_regress_rmborder'] + self.local_regress_nomask = config['match_fine']['local_regress_nomask'] + self.local_regress_temperature = config['match_fine']['local_regress_temperature'] + self.local_regress_padone = config['match_fine']['local_regress_padone'] + self.local_regress_slice = config['match_fine']['local_regress_slice'] + self.local_regress_slicedim = config['match_fine']['local_regress_slicedim'] + self.local_regress_inner = config['match_fine']['local_regress_inner'] + self.multi_regress = config['match_fine']['multi_regress'] + def forward(self, feat_0, feat_1, data): + """ + Args: + feat0 (torch.Tensor): [M, WW, C] + feat1 (torch.Tensor): [M, WW, C] + data (dict) + Update: + data (dict):{ + 'expec_f' (torch.Tensor): [M, 3], + 'mkpts0_f' (torch.Tensor): [M, 2], + 'mkpts1_f' (torch.Tensor): [M, 2]} + """ + M, WW, C = feat_0.shape + W = int(math.sqrt(WW)) + if self.fix_bias: + scale = 2 + else: + scale = data['hw0_i'][0] / data['hw0_f'][0] + self.M, self.W, self.WW, self.C, self.scale = M, W, WW, C, scale + + # corner case: if no coarse matches found + if M == 0: + assert self.training == False, "M is always >0, when training, see coarse_matching.py" + # logger.warning('No matches found in coarse-level.') + if self.mtd_spvs: + data.update({ + 'conf_matrix_f': torch.empty(0, WW, WW, device=feat_0.device), + 'mkpts0_f': data['mkpts0_c'], + 'mkpts1_f': data['mkpts1_c'], + }) + # if self.local_regress: + # data.update({ + # 'sim_matrix_f': torch.empty(0, WW, WW, device=feat_0.device), + # }) + return + else: + data.update({ + 'expec_f': torch.empty(0, 3, device=feat_0.device), + 'mkpts0_f': data['mkpts0_c'], + 'mkpts1_f': data['mkpts1_c'], + }) + return + + if self.mtd_spvs: + with torch.autocast(enabled=False, device_type='cuda'): + # 
feat_0 = feat_0 / feat_0.size(-2) + if self.local_regress_slice: + feat_ff0, feat_ff1 = feat_0[...,-self.local_regress_slicedim:], feat_1[...,-self.local_regress_slicedim:] + feat_f0, feat_f1 = feat_0[...,:-self.local_regress_slicedim], feat_1[...,:-self.local_regress_slicedim] + conf_matrix_ff = torch.einsum('mlc,mrc->mlr', feat_ff0, feat_ff1 / (self.local_regress_slicedim)**.5) + else: + feat_f0, feat_f1 = feat_0, feat_1 + if self.normfinem: + feat_f0 = feat_f0 / C**.5 + feat_f1 = feat_f1 / C**.5 + conf_matrix_f = torch.einsum('mlc,mrc->mlr', feat_f0, feat_f1) + else: + if self.local_regress_slice: + conf_matrix_f = torch.einsum('mlc,mrc->mlr', feat_f0, feat_f1 / (C - self.local_regress_slicedim)**.5) + else: + conf_matrix_f = torch.einsum('mlc,mrc->mlr', feat_f0, feat_f1 / C**.5) + + if self.normfeat: + feat_f0, feat_f1 = torch.nn.functional.normalize(feat_f0.float(), p=2, dim=-1), torch.nn.functional.normalize(feat_f1.float(), p=2, dim=-1) + + if self.config['fp16log']: + logger.info(f'sim_matrix: {conf_matrix_f.abs().max()}') + # sim_matrix *= 1. 
/ C**.5 # normalize + + if self.multi_regress: + assert not self.local_regress + assert not self.normfinem and not self.normfeat + heatmap = F.softmax(conf_matrix_f, 2).view(M, WW, W, W) # [M, WW, W, W] + + assert (W - 2) == (self.config['resolution'][0] // self.config['resolution'][1]) # c8 + windows_scale = (W - 1) / (self.config['resolution'][0] // self.config['resolution'][1]) + + coords_normalized = dsnt.spatial_expectation2d(heatmap, True) * windows_scale # [M, WW, 2] + grid_normalized = create_meshgrid(W, W, True, heatmap.device).reshape(1, -1, 2)[:,None,:,:] * windows_scale # [1, 1, WW, 2] + + # compute std over + var = torch.sum(grid_normalized**2 * heatmap.view(M, WW, WW, 1), dim=-2) - coords_normalized**2 # ([1,1,WW,2] * [M,WW,WW,1])->[M,WW,2] + std = torch.sum(torch.sqrt(torch.clamp(var, min=1e-10)), -1) # [M,WW] clamp needed for numerical stability + + # for fine-level supervision + data.update({'expec_f': torch.cat([coords_normalized, std.unsqueeze(-1)], -1)}) # [M, WW, 2] + + # get the least uncertain matches + val, idx = torch.topk(std, self.topk, dim=-1, largest=False) # [M,topk] + coords_normalized = coords_normalized[torch.arange(M, device=conf_matrix_f.device, dtype=torch.long)[:,None], idx] # [M,topk] + + grid = create_meshgrid(W, W, False, idx.device) - W // 2 + 0.5 # [1, W, W, 2] + grid = grid.reshape(1, -1, 2).expand(M, -1, -1) # [M, WW, 2] + delta_l = torch.gather(grid, 1, idx.unsqueeze(-1).expand(-1, -1, 2)) # [M, topk, 2] in (x, y) + + # compute absolute kpt coords + self.get_multi_fine_match_align(delta_l, coords_normalized, data) + + + else: + + if self.skip_fine_softmax: + pass + elif self.use_sigmoid: + conf_matrix_f = torch.sigmoid(conf_matrix_f) + else: + if self.local_regress: + del feat_f0, feat_f1 + softmax_matrix_f = F.softmax(conf_matrix_f, 1) * F.softmax(conf_matrix_f, 2) + # softmax_matrix_f = conf_matrix_f + if self.local_regress_inner: + softmax_matrix_f = softmax_matrix_f.reshape(M, self.WW, self.W+2, self.W+2) + 
softmax_matrix_f = softmax_matrix_f[...,1:-1,1:-1].reshape(M, self.WW, self.WW) + # if self.training: + # for fine-level supervision + data.update({'conf_matrix_f': softmax_matrix_f}) + if self.local_regress_slice: + data.update({'sim_matrix_ff': conf_matrix_ff}) + else: + data.update({'sim_matrix_f': conf_matrix_f}) + + else: + conf_matrix_f = F.softmax(conf_matrix_f, 1) * F.softmax(conf_matrix_f, 2) + + # for fine-level supervision + data.update({'conf_matrix_f': conf_matrix_f}) + + # compute absolute kpt coords + if self.local_regress: + self.get_fine_ds_match(softmax_matrix_f, data) + del softmax_matrix_f + idx_l, idx_r = data['idx_l'], data['idx_r'] + del data['idx_l'], data['idx_r'] + m_ids = torch.arange(M, device=idx_l.device, dtype=torch.long).unsqueeze(-1).expand(-1, self.topk) + # if self.training: + m_ids = m_ids[:len(data['mconf']) // self.topk] + idx_r_iids, idx_r_jids = idx_r // W, idx_r % W + + # remove boarder + if self.local_regress_nomask: + # log for inner precent + # mask = (idx_r_iids >= 1) & (idx_r_iids <= W-2) & (idx_r_jids >= 1) & (idx_r_jids <= W-2) + # mask_sum = mask.sum() + # logger.info(f'total fine match: {mask.numel()}; regressed fine match: {mask_sum}, per: {mask_sum / mask.numel()}') + mask = None + m_ids, idx_l, idx_r_iids, idx_r_jids = m_ids.reshape(-1), idx_l.reshape(-1), idx_r_iids.reshape(-1), idx_r_jids.reshape(-1) + if self.local_regress_inner: # been sliced before + delta = create_meshgrid(3, 3, True, conf_matrix_f.device).to(torch.long) # [1, 3, 3, 2] + else: + # no mask + 1 for padding + delta = create_meshgrid(3, 3, True, conf_matrix_f.device).to(torch.long) + torch.tensor([1], dtype=torch.long, device=conf_matrix_f.device) # [1, 3, 3, 2] + + m_ids = m_ids[...,None,None].expand(-1, 3, 3) + idx_l = idx_l[...,None,None].expand(-1, 3, 3) # [m, k, 3, 3] + + idx_r_iids = idx_r_iids[...,None,None].expand(-1, 3, 3) + delta[None, ..., 1] + idx_r_jids = idx_r_jids[...,None,None].expand(-1, 3, 3) + delta[None, ..., 0] + + if 
idx_l.numel() == 0: + data.update({ + 'mkpts0_f': data['mkpts0_c'], + 'mkpts1_f': data['mkpts1_c'], + }) + return + + if self.local_regress_slice: + conf_matrix_f = conf_matrix_ff + if self.local_regress_inner: + conf_matrix_f = conf_matrix_f.reshape(M, self.WW, self.W+2, self.W+2) + else: + conf_matrix_f = conf_matrix_f.reshape(M, self.WW, self.W, self.W) + conf_matrix_f = F.pad(conf_matrix_f, (1,1,1,1)) + else: + mask = (idx_r_iids >= 1) & (idx_r_iids <= W-2) & (idx_r_jids >= 1) & (idx_r_jids <= W-2) + if W == 10: + idx_l_iids, idx_l_jids = idx_l // W, idx_l % W + mask = mask & (idx_l_iids >= 1) & (idx_l_iids <= W-2) & (idx_l_jids >= 1) & (idx_l_jids <= W-2) + + m_ids = m_ids[mask].to(torch.long) + idx_l, idx_r_iids, idx_r_jids = idx_l[mask].to(torch.long), idx_r_iids[mask].to(torch.long), idx_r_jids[mask].to(torch.long) + + m_ids, idx_l, idx_r_iids, idx_r_jids = m_ids.reshape(-1), idx_l.reshape(-1), idx_r_iids.reshape(-1), idx_r_jids.reshape(-1) + mask = mask.reshape(-1) + + delta = create_meshgrid(3, 3, True, conf_matrix_f.device).to(torch.long) # [1, 3, 3, 2] + + m_ids = m_ids[:,None,None].expand(-1, 3, 3) + idx_l = idx_l[:,None,None].expand(-1, 3, 3) # [m, 3, 3] + # bug !!!!!!!!! 
1,0 rather 0,1 + # idx_r_iids = idx_r_iids[...,None,None].expand(-1, 3, 3) + delta[None, ..., 0] + # idx_r_jids = idx_r_jids[...,None,None].expand(-1, 3, 3) + delta[None, ..., 1] + idx_r_iids = idx_r_iids[:,None,None].expand(-1, 3, 3) + delta[..., 1] + idx_r_jids = idx_r_jids[:,None,None].expand(-1, 3, 3) + delta[..., 0] + + if idx_l.numel() == 0: + data.update({ + 'mkpts0_f': data['mkpts0_c'], + 'mkpts1_f': data['mkpts1_c'], + }) + return + if not self.local_regress_slice: + conf_matrix_f = conf_matrix_f.reshape(M, self.WW, self.W, self.W) + else: + conf_matrix_f = conf_matrix_ff.reshape(M, self.WW, self.W, self.W) + + conf_matrix_f = conf_matrix_f[m_ids, idx_l, idx_r_iids, idx_r_jids] + conf_matrix_f = conf_matrix_f.reshape(-1, 9) + if self.local_regress_padone: # follow the training detach the gradient of center + conf_matrix_f[:,4] = -1e4 + heatmap = F.softmax(conf_matrix_f / self.local_regress_temperature, -1) + logger.info(f'maxmax&maxmean of heatmap: {heatmap.view(-1).max()}, {heatmap.view(-1).min(), heatmap.max(-1)[0].mean()}') + heatmap[:,4] = 1.0 # no need gradient calculation in inference + logger.info(f'min of heatmap: {heatmap.view(-1).min()}') + heatmap = heatmap.reshape(-1, 3, 3) + # heatmap = torch.ones_like(softmax) # ones_like for detach the gradient of center + # heatmap[:,:4], heatmap[:,5:] = softmax[:,:4], softmax[:,5:] + # heatmap = heatmap.reshape(-1, 3, 3) + else: + conf_matrix_f = F.softmax(conf_matrix_f / self.local_regress_temperature, -1) + # logger.info(f'max&min&mean of heatmap: {conf_matrix_f.view(-1).max()}, {conf_matrix_f.view(-1).min(), conf_matrix_f.max(-1)[0].mean()}') + heatmap = conf_matrix_f.reshape(-1, 3, 3) + + # compute coordinates from heatmap + coords_normalized = dsnt.spatial_expectation2d(heatmap[None], True)[0] + + # coords_normalized_l2 = coords_normalized.norm(p=2, dim=-1) + # logger.info(f'mean&max&min abs of local: {coords_normalized_l2.mean(), coords_normalized_l2.max(), coords_normalized_l2.min()}') + + # compute 
absolute kpt coords + + if data['bs'] == 1: + scale1 = scale * data['scale1'] if 'scale0' in data else scale + else: + if mask is not None: + scale1 = scale * data['scale1'][data['b_ids']][:len(data['mconf']) // self.topk,...][:,None,:].expand(-1, self.topk, 2).reshape(-1, 2)[mask] if 'scale0' in data else scale + else: + scale1 = scale * data['scale1'][data['b_ids']][:len(data['mconf']) // self.topk,...][:,None,:].expand(-1, self.topk, 2).reshape(-1, 2) if 'scale0' in data else scale + + self.get_fine_match_local(coords_normalized, data, scale1, mask, True) + + else: + self.get_fine_ds_match(conf_matrix_f, data) + + + else: + if self.align_corner is True: + feat_f0, feat_f1 = feat_0, feat_1 + feat_f0_picked = feat_f0_picked = feat_f0[:, WW//2, :] + sim_matrix = torch.einsum('mc,mrc->mr', feat_f0_picked, feat_f1) + softmax_temp = 1. / C**.5 + heatmap = torch.softmax(softmax_temp * sim_matrix, dim=1).view(-1, W, W) + + # compute coordinates from heatmap + coords_normalized = dsnt.spatial_expectation2d(heatmap[None], True)[0] # [M, 2] + grid_normalized = create_meshgrid(W, W, True, heatmap.device).reshape(1, -1, 2) # [1, WW, 2] + + # compute std over + var = torch.sum(grid_normalized**2 * heatmap.view(-1, WW, 1), dim=1) - coords_normalized**2 # [M, 2] + std = torch.sum(torch.sqrt(torch.clamp(var, min=1e-10)), -1) # [M] clamp needed for numerical stability + + # for fine-level supervision + data.update({'expec_f': torch.cat([coords_normalized, std.unsqueeze(1)], -1)}) + + # compute absolute kpt coords + self.get_fine_match(coords_normalized, data) + else: + feat_f0, feat_f1 = feat_0, feat_1 + # even matching windows while coarse grid not aligned to fine grid!!! 
@torch.no_grad()
def get_fine_match(self, coords_normed, data):
    """Turn normalized fine offsets into fine-level keypoints and store them in ``data``.

    Image-0 keypoints are the coarse keypoints unchanged. Image-1 keypoints
    are the coarse keypoints plus ``coords_normed * (W // 2) * scale1``.

    Args:
        coords_normed (torch.Tensor): normalized offsets, one row per match
            (assumes [M, 2] — TODO confirm against caller).
        data (dict): reads 'mkpts0_c', 'mkpts1_c' and 'mconf'; when 'scale1'
            is present, also 'b_ids'.

    Update:
        data (dict): {'mkpts0_f', 'mkpts1_f'}
    """
    half_window = self.W // 2
    scale = self.scale

    # Guard on the key that is actually read. The previous check looked for
    # 'scale0', which raises KeyError when only 'scale0' is supplied and
    # silently skips the rescale when only 'scale1' is supplied.
    if 'scale1' in data:
        scale1 = scale * data['scale1'][data['b_ids']]
    else:
        scale1 = scale

    # Only the first len(mconf) offsets are used — presumably the remaining
    # rows belong to padded training samples; verify against the coarse stage.
    num_pred = len(data['mconf'])
    mkpts0_f = data['mkpts0_c']
    mkpts1_f = data['mkpts1_c'] + (coords_normed * half_window * scale1)[:num_pred]

    data.update({
        "mkpts0_f": mkpts0_f,
        "mkpts1_f": mkpts1_f
    })
# can be used for both aligned and not aligned
@torch.no_grad()
def get_fine_match_align(self, coord_normed, data):
    """Turn normalized fine offsets into fine-level keypoints using the coarse-to-fine ratio.

    Image-0 keypoints are the coarse keypoints unchanged. Image-1 keypoints
    are the coarse keypoints plus ``coord_normed * (c2f // 2) * scale1``,
    where ``c2f = config['resolution'][0] // config['resolution'][1]``.

    Args:
        coord_normed (torch.Tensor): normalized offsets, one row per match
            (assumes [M, 2] — TODO confirm against caller).
        data (dict): reads 'mkpts0_c', 'mkpts1_c' and 'mconf'; when 'scale1'
            is present, also 'b_ids'.

    Update:
        data (dict): {'mkpts0_f', 'mkpts1_f'}
    """
    scale = self.scale
    resolution = self.config['resolution']
    c2f = resolution[0] // resolution[1]

    # Guard on the key that is actually read. The previous check looked for
    # 'scale0', which raises KeyError when only 'scale0' is supplied and
    # silently skips the rescale when only 'scale1' is supplied.
    if 'scale1' in data:
        scale1 = scale * data['scale1'][data['b_ids']]
    else:
        scale1 = scale

    # Only the first len(mconf) offsets are used — presumably the remaining
    # rows belong to padded training samples; verify against the coarse stage.
    num_pred = len(data['mconf'])
    mkpts0_f = data['mkpts0_c']
    mkpts1_f = data['mkpts1_c'] + (coord_normed * (c2f // 2) * scale1)[:num_pred]

    data.update({
        "mkpts0_f": mkpts0_f,
        "mkpts1_f": mkpts1_f
    })
(data['mkpts0_c'][:,None,:] + (delta_l * scale0[:,None,:])[:len(data['mconf']),...]).reshape(-1, 2) + mkpts1_f = (data['mkpts1_c'][:,None,:] + (coord_normed * (c2f // 2) * scale1[:,None,:])[:len(data['mconf'])]).reshape(-1, 2) + + data.update({ + "mkpts0_f": mkpts0_f, + "mkpts1_f": mkpts1_f, + "mconf": data['mconf'][:,None].expand(-1, self.topk).reshape(-1) + }) + + @torch.no_grad() + def get_fine_ds_match(self, conf_matrix, data): + W, WW, C, scale = self.W, self.WW, self.C, self.scale + + # select topk matches + m, _, _ = conf_matrix.shape + + + if self.mutual_nearest: + pass + + + elif not self.fix_fine_matching: # only allow one2mul but mul2one + + val, idx_r = conf_matrix.max(-1) # (m, WW), (m, WW) + val, idx_l = torch.topk(val, self.topk, dim = -1) # (m, topk), (m, topk) + idx_r = torch.gather(idx_r, 1, idx_l) # (m, topk) + + # mkpts0_c use xy coordinate, so we don't need to convert it to hw coordinate + # grid = create_meshgrid(W, W, False, conf_matrix.device).transpose(-3,-2) - W // 2 + 0.5 # (1, W, W, 2) + grid = create_meshgrid(W, W, False, conf_matrix.device) - W // 2 + 0.5 # (1, W, W, 2) + grid = grid.reshape(1, -1, 2).expand(m, -1, -1) # (m, WW, 2) + delta_l = torch.gather(grid, 1, idx_l.unsqueeze(-1).expand(-1, -1, 2)) # (m, topk, 2) + delta_r = torch.gather(grid, 1, idx_r.unsqueeze(-1).expand(-1, -1, 2)) # (m, topk, 2) + + # mkpts0_f and mkpts1_f + scale0 = scale * data['scale0'][data['b_ids']] if 'scale0' in data else scale + scale1 = scale * data['scale1'][data['b_ids']] if 'scale0' in data else scale + + if torch.is_tensor(scale0) and scale0.numel() > 1: # num of scale0 > 1 + mkpts0_f = (data['mkpts0_c'][:,None,:] + (delta_l * scale0[:,None,:])[:len(data['mconf']),...]).reshape(-1, 2) + mkpts1_f = (data['mkpts1_c'][:,None,:] + (delta_r * scale1[:,None,:])[:len(data['mconf']),...]).reshape(-1, 2) + else: # scale0 is a float + mkpts0_f = (data['mkpts0_c'][:,None,:] + (delta_l * scale0)[:len(data['mconf']),...]).reshape(-1, 2) + mkpts1_f = 
(data['mkpts1_c'][:,None,:] + (delta_r * scale1)[:len(data['mconf']),...]).reshape(-1, 2) + + else: # allow one2mul mul2one and mul2mul + conf_matrix = conf_matrix.reshape(m, -1) + if self.local_regress: # for the compatibility of former config + conf_matrix = conf_matrix[:len(data['mconf']),...] + val, idx = torch.topk(conf_matrix, self.topk, dim = -1) + idx_l = idx // WW + idx_r = idx % WW + + if self.local_regress: + data.update({'idx_l': idx_l, 'idx_r': idx_r}) + + # mkpts0_c use xy coordinate, so we don't need to convert it to hw coordinate + # grid = create_meshgrid(W, W, False, conf_matrix.device).transpose(-3,-2) - W // 2 + 0.5 # (1, W, W, 2) + grid = create_meshgrid(W, W, False, conf_matrix.device) - W // 2 + 0.5 + grid = grid.reshape(1, -1, 2).expand(m, -1, -1) + delta_l = torch.gather(grid, 1, idx_l.unsqueeze(-1).expand(-1, -1, 2)) + delta_r = torch.gather(grid, 1, idx_r.unsqueeze(-1).expand(-1, -1, 2)) + + # mkpts0_f and mkpts1_f + scale0 = scale * data['scale0'][data['b_ids']] if 'scale0' in data else scale + scale1 = scale * data['scale1'][data['b_ids']] if 'scale0' in data else scale + + if self.local_regress: + if torch.is_tensor(scale0) and scale0.numel() > 1: # num of scale0 > 1 + mkpts0_f = (data['mkpts0_c'][:,None,:] + (delta_l * scale0[:len(data['mconf']),...][:,None,:])).reshape(-1, 2) + mkpts1_f = (data['mkpts1_c'][:,None,:] + (delta_r * scale1[:len(data['mconf']),...][:,None,:])).reshape(-1, 2) + else: # scale0 is a float + mkpts0_f = (data['mkpts0_c'][:,None,:] + (delta_l * scale0)).reshape(-1, 2) + mkpts1_f = (data['mkpts1_c'][:,None,:] + (delta_r * scale1)).reshape(-1, 2) + + else: + if torch.is_tensor(scale0) and scale0.numel() > 1: # num of scale0 > 1 + mkpts0_f = (data['mkpts0_c'][:,None,:] + (delta_l * scale0[:,None,:])[:len(data['mconf']),...]).reshape(-1, 2) + mkpts1_f = (data['mkpts1_c'][:,None,:] + (delta_r * scale1[:,None,:])[:len(data['mconf']),...]).reshape(-1, 2) + else: # scale0 is a float + mkpts0_f = 
(data['mkpts0_c'][:,None,:] + (delta_l * scale0)[:len(data['mconf']),...]).reshape(-1, 2) + mkpts1_f = (data['mkpts1_c'][:,None,:] + (delta_r * scale1)[:len(data['mconf']),...]).reshape(-1, 2) + del data['mkpts0_c'], data['mkpts1_c'] + data['mconf'] = data['mconf'].reshape(-1, 1).expand(-1, self.topk).reshape(-1) + # data['mconf'] = val.reshape(-1)[:len(data['mconf'])]*0.1 + data['mconf'] + + if self.local_regress: + data.update({ + "mkpts0_c": mkpts0_f, + "mkpts1_c": mkpts1_f + }) + else: + data.update({ + "mkpts0_f": mkpts0_f, + "mkpts1_f": mkpts1_f + }) + diff --git a/imcui/third_party/MatchAnything/src/loftr/utils/geometry.py b/imcui/third_party/MatchAnything/src/loftr/utils/geometry.py new file mode 100644 index 0000000000000000000000000000000000000000..47de76bd8d8928b123bc7357349b1e7ae4ee90ac --- /dev/null +++ b/imcui/third_party/MatchAnything/src/loftr/utils/geometry.py @@ -0,0 +1,298 @@ +import torch +from src.utils.homography_utils import warp_points_torch + +def get_unique_indices(input_tensor): + if input_tensor.shape[0] > 1: + unique, inverse = torch.unique(input_tensor, sorted=True, return_inverse=True, dim=0) + perm = torch.arange(inverse.size(0), dtype=inverse.dtype, device=inverse.device) + inverse, perm = inverse.flip([0]), perm.flip([0]) + perm = inverse.new_empty(unique.size(0)).scatter_(0, inverse, perm) + else: + perm = torch.zeros((input_tensor.shape[0],), dtype=torch.long, device=input_tensor.device) + return perm + + +@torch.no_grad() +def warp_kpts(kpts0, depth0, depth1, T_0to1, K0, K1, consistency_thr=0.2, cycle_proj_distance_thr=3.0): + """ Warp kpts0 from I0 to I1 with depth, K and Rt + Also check covisibility and depth consistency. + Depth is consistent if relative error < 0.2 (hard-coded). 
+ + Args: + kpts0 (torch.Tensor): [N, L, 2] - , + depth0 (torch.Tensor): [N, H, W], + depth1 (torch.Tensor): [N, H, W], + T_0to1 (torch.Tensor): [N, 3, 4], + K0 (torch.Tensor): [N, 3, 3], + K1 (torch.Tensor): [N, 3, 3], + Returns: + calculable_mask (torch.Tensor): [N, L] + warped_keypoints0 (torch.Tensor): [N, L, 2] + """ + kpts0_long = kpts0.round().long() + + # Sample depth, get calculable_mask on depth != 0 + kpts0_depth = torch.stack( + [depth0[i, kpts0_long[i, :, 1], kpts0_long[i, :, 0]] for i in range(kpts0.shape[0])], dim=0 + ) # (N, L) + nonzero_mask = kpts0_depth != 0 + + # Unproject + kpts0_h = torch.cat([kpts0, torch.ones_like(kpts0[:, :, [0]])], dim=-1) * kpts0_depth[..., None] # (N, L, 3) + kpts0_cam = K0.inverse() @ kpts0_h.transpose(2, 1) # (N, 3, L) + + # Rigid Transform + w_kpts0_cam = T_0to1[:, :3, :3] @ kpts0_cam + T_0to1[:, :3, [3]] # (N, 3, L) + w_kpts0_depth_computed = w_kpts0_cam[:, 2, :] + + # Project + w_kpts0_h = (K1 @ w_kpts0_cam).transpose(2, 1) # (N, L, 3) + w_kpts0 = w_kpts0_h[:, :, :2] / (w_kpts0_h[:, :, [2]] + 1e-4) # (N, L, 2), +1e-4 to avoid zero depth + + # Covisible Check + h, w = depth1.shape[1:3] + covisible_mask = (w_kpts0[:, :, 0] > 0) * (w_kpts0[:, :, 0] < w-1) * \ + (w_kpts0[:, :, 1] > 0) * (w_kpts0[:, :, 1] < h-1) + w_kpts0_long = w_kpts0.long() + w_kpts0_long[~covisible_mask, :] = 0 + + w_kpts0_depth = torch.stack( + [depth1[i, w_kpts0_long[i, :, 1], w_kpts0_long[i, :, 0]] for i in range(w_kpts0_long.shape[0])], dim=0 + ) # (N, L) + consistent_mask = ((w_kpts0_depth - w_kpts0_depth_computed) / w_kpts0_depth).abs() < consistency_thr + + # Cycle Consistency Check + dst_pts_h = torch.cat([w_kpts0, torch.ones_like(w_kpts0[..., [0]], device=w_kpts0.device)], dim=-1) * w_kpts0_depth[..., None] # B * N_dst * N_pts * 3 + dst_pts_cam = K1.inverse() @ dst_pts_h.transpose(2, 1) # (N, 3, L) + dst_pose = T_0to1.inverse() + world_points_cycle_back = dst_pose[:, :3, :3] @ dst_pts_cam + dst_pose[:, :3, [3]] + src_warp_back_h = (K0 @ 
world_points_cycle_back).transpose(2, 1) # (N, L, 3) + src_back_proj_pts = src_warp_back_h[..., :2] / (src_warp_back_h[..., [2]] + 1e-4) + cycle_reproj_distance_mask = torch.linalg.norm(src_back_proj_pts - kpts0[:, None], dim=-1) < cycle_proj_distance_thr + + valid_mask = nonzero_mask * covisible_mask * consistent_mask * cycle_reproj_distance_mask + + return valid_mask, w_kpts0 + +@torch.no_grad() +def warp_kpts_by_sparse_gt_matches_batches(kpts0, gt_matches, dist_thr): + B, n_pts = kpts0.shape[0], kpts0.shape[1] + if n_pts > 20 * 10000: + all_kpts_valid_mask, all_kpts_warpped = [], [] + for b_id in range(B): + kpts_valid_mask, kpts_warpped = warp_kpts_by_sparse_gt_matches(kpts0[[b_id]], gt_matches[[b_id]], dist_thr[[b_id]]) + all_kpts_valid_mask.append(kpts_valid_mask) + all_kpts_warpped.append(kpts_warpped) + return torch.cat(all_kpts_valid_mask, dim=0), torch.cat(all_kpts_warpped, dim=0) + else: + return warp_kpts_by_sparse_gt_matches(kpts0, gt_matches, dist_thr) + +@torch.no_grad() +def warp_kpts_by_sparse_gt_matches(kpts0, gt_matches, dist_thr): + kpts_warpped = torch.zeros_like(kpts0) + kpts_valid_mask = torch.zeros_like(kpts0[..., 0], dtype=torch.bool) + gt_matches_non_padding_mask = gt_matches.sum(-1) > 0 + + dist_matrix = torch.cdist(kpts0, gt_matches[..., :2]) # B * N * M + if dist_thr is not None: + mask = dist_matrix < dist_thr[:, None, None] + else: + mask = torch.ones_like(dist_matrix, dtype=torch.bool) + # Mutual-Nearest check: + mask = mask \ + * (dist_matrix == dist_matrix.min(dim=2, keepdim=True)[0]) \ + * (dist_matrix == dist_matrix.min(dim=1, keepdim=True)[0]) + + mask_v, all_j_ids = mask.max(dim=2) + b_ids, i_ids = torch.where(mask_v) + j_ids = all_j_ids[b_ids, i_ids] + + j_uq_indices = get_unique_indices(torch.stack([b_ids, j_ids], dim=-1)) + b_ids, i_ids, j_ids = map(lambda x: x[j_uq_indices], [b_ids, i_ids, j_ids]) + + i_uq_indices = get_unique_indices(torch.stack([b_ids, i_ids], dim=-1)) + b_ids, i_ids, j_ids = map(lambda x: 
x[i_uq_indices], [b_ids, i_ids, j_ids]) + + kpts_valid_mask[b_ids, i_ids] = gt_matches_non_padding_mask[b_ids, j_ids] + kpts_warpped[b_ids, i_ids] = gt_matches[..., 2:][b_ids, j_ids] + + return kpts_valid_mask, kpts_warpped + +@torch.no_grad() +def warp_kpts_by_sparse_gt_matches_fine_chunks(kpts0, gt_matches, dist_thr): + B, n_pts = kpts0.shape[0], kpts0.shape[1] + chunk_n = 500 + all_kpts_valid_mask, all_kpts_warpped = [], [] + for b_id in range(0, B, chunk_n): + kpts_valid_mask, kpts_warpped = warp_kpts_by_sparse_gt_matches_fine(kpts0[b_id : b_id+chunk_n], gt_matches, dist_thr) + all_kpts_valid_mask.append(kpts_valid_mask) + all_kpts_warpped.append(kpts_warpped) + return torch.cat(all_kpts_valid_mask, dim=0), torch.cat(all_kpts_warpped, dim=0) + +@torch.no_grad() +def warp_kpts_by_sparse_gt_matches_fine(kpts0, gt_matches, dist_thr): + """ + Only support single batch + Input: + kpts0: N * ww * 2 + gt_matches: M * 2 + """ + B = kpts0.shape[0] # B is the fine matches in a single pair + assert gt_matches.shape[0] == 1 + kpts_warpped = torch.zeros_like(kpts0) + kpts_valid_mask = torch.zeros_like(kpts0[..., 0], dtype=torch.bool) + gt_matches_non_padding_mask = gt_matches.sum(-1) > 0 + + dist_matrix = torch.cdist(kpts0, gt_matches[..., :2]) # B * N * M + if dist_thr is not None: + mask = dist_matrix < dist_thr[:, None, None] + else: + mask = torch.ones_like(dist_matrix, dtype=torch.bool) + # Mutual-Nearest check: + mask = mask \ + * (dist_matrix == dist_matrix.min(dim=2, keepdim=True)[0]) \ + * (dist_matrix == dist_matrix.min(dim=1, keepdim=True)[0]) + + mask_v, all_j_ids = mask.max(dim=2) + b_ids, i_ids = torch.where(mask_v) + j_ids = all_j_ids[b_ids, i_ids] + + j_uq_indices = get_unique_indices(torch.stack([b_ids, j_ids], dim=-1)) + b_ids, i_ids, j_ids = map(lambda x: x[j_uq_indices], [b_ids, i_ids, j_ids]) + + i_uq_indices = get_unique_indices(torch.stack([b_ids, i_ids], dim=-1)) + b_ids, i_ids, j_ids = map(lambda x: x[i_uq_indices], [b_ids, i_ids, j_ids]) + + 
kpts_valid_mask[b_ids, i_ids] = gt_matches_non_padding_mask[0, j_ids] + kpts_warpped[b_ids, i_ids] = gt_matches[..., 2:][0, j_ids] + + return kpts_valid_mask, kpts_warpped + +@torch.no_grad() +def warp_kpts_by_sparse_gt_matches_fast(kpts0, gt_matches, scale0, current_h, current_w): + B, n_gt_pts = gt_matches.shape[0], gt_matches.shape[1] + kpts_warpped = torch.zeros_like(kpts0) + kpts_valid_mask = torch.zeros_like(kpts0[..., 0], dtype=torch.bool) + gt_matches_non_padding_mask = gt_matches.sum(-1) > 0 + + all_j_idxs = torch.arange(gt_matches.shape[-2], device=gt_matches.device, dtype=torch.long)[None].expand(B, n_gt_pts) + all_b_idxs = torch.arange(B, device=gt_matches.device, dtype=torch.long)[:, None].expand(B, n_gt_pts) + gt_matches_rescale = gt_matches[..., :2] / scale0 # From original img scale to resized scale + in_boundary_mask = (gt_matches_rescale[..., 0] <= current_w-1) & (gt_matches_rescale[..., 0] >= 0) & (gt_matches_rescale[..., 1] <= current_h -1) & (gt_matches_rescale[..., 1] >= 0) + + gt_matches_rescale = gt_matches_rescale.round().to(torch.long) + all_i_idxs = gt_matches_rescale[..., 1] * current_w + gt_matches_rescale[..., 0] # idx = y * w + x + + # Filter: + b_ids, i_ids, j_ids = map(lambda x: x[gt_matches_non_padding_mask & in_boundary_mask], [all_b_idxs, all_i_idxs, all_j_idxs]) + + j_uq_indices = get_unique_indices(torch.stack([b_ids, j_ids], dim=-1)) + b_ids, i_ids, j_ids = map(lambda x: x[j_uq_indices], [b_ids, i_ids, j_ids]) + + i_uq_indices = get_unique_indices(torch.stack([b_ids, i_ids], dim=-1)) + b_ids, i_ids, j_ids = map(lambda x: x[i_uq_indices], [b_ids, i_ids, j_ids]) + + kpts_valid_mask[b_ids, i_ids] = gt_matches_non_padding_mask[b_ids, j_ids] + kpts_warpped[b_ids, i_ids] = gt_matches[..., 2:][b_ids, j_ids] + + return kpts_valid_mask, kpts_warpped + + +@torch.no_grad() +def homo_warp_kpts(kpts0, norm_pixel_mat, homo_sample_normed, original_size0=None, original_size1=None): + """ + original_size1: N * 2, (h, w) + """ + normed_kpts0_h 
= norm_pixel_mat @ torch.cat([kpts0, torch.ones_like(kpts0[:, :, [0]])], dim=-1).transpose(2, 1) # (N * 3 * L) + kpts_warpped_h = (torch.linalg.inv(norm_pixel_mat) @ homo_sample_normed @ normed_kpts0_h).transpose(2, 1) # (N * L * 3) + kpts_warpped = kpts_warpped_h[..., :2] / kpts_warpped_h[..., [2]] # N * L * 2 + valid_mask = (kpts_warpped[..., 0] > 0) & (kpts_warpped[..., 0] < original_size1[:, [1]]) & (kpts_warpped[..., 1] > 0) \ + & (kpts_warpped[..., 1] < original_size1[:, [0]]) # N * L + if original_size0 is not None: + valid_mask *= (kpts0[..., 0] > 0) & (kpts0[..., 0] < original_size0[:, [1]]) & (kpts0[..., 1] > 0) \ + & (kpts0[..., 1] < original_size0[:, [0]]) # N * L + + return valid_mask, kpts_warpped + +@torch.no_grad() +# if using mask in homo warp(for coarse supervision) +def homo_warp_kpts_with_mask(kpts0, scale, depth_mask, norm_pixel_mat, homo_sample_normed, original_size0=None, original_size1=None): + """ + original_size1: N * 2, (h, w) + """ + normed_kpts0_h = norm_pixel_mat @ torch.cat([kpts0, torch.ones_like(kpts0[:, :, [0]])], dim=-1).transpose(2, 1) # (N * 3 * L) + kpts_warpped_h = (torch.linalg.inv(norm_pixel_mat) @ homo_sample_normed @ normed_kpts0_h).transpose(2, 1) # (N * L * 3) + kpts_warpped = kpts_warpped_h[..., :2] / kpts_warpped_h[..., [2]] # N * L * 2 + # get coarse-level depth_mask + depth_mask_coarse = depth_mask[:, :, ::scale, ::scale] + depth_mask_coarse = depth_mask_coarse.reshape(depth_mask.shape[0], -1) + + valid_mask = (kpts_warpped[..., 0] > 0) & (kpts_warpped[..., 0] < original_size1[:, [1]]) & (kpts_warpped[..., 1] > 0) \ + & (kpts_warpped[..., 1] < original_size1[:, [0]]) & (depth_mask_coarse != 0) # N * L + if original_size0 is not None: + valid_mask *= (kpts0[..., 0] > 0) & (kpts0[..., 0] < original_size0[:, [1]]) & (kpts0[..., 1] > 0) \ + & (kpts0[..., 1] < original_size0[:, [0]]) & (depth_mask_coarse != 0) # N * L + + return valid_mask, kpts_warpped + +@torch.no_grad() +# if using mask in homo warp(for fine 
supervision) +def homo_warp_kpts_with_mask_f(kpts0, depth_mask, norm_pixel_mat, homo_sample_normed, original_size0=None, original_size1=None): + """ + original_size1: N * 2, (h, w) + """ + normed_kpts0_h = norm_pixel_mat @ torch.cat([kpts0, torch.ones_like(kpts0[:, :, [0]])], dim=-1).transpose(2, 1) # (N * 3 * L) + kpts_warpped_h = (torch.linalg.inv(norm_pixel_mat) @ homo_sample_normed @ normed_kpts0_h).transpose(2, 1) # (N * L * 3) + kpts_warpped = kpts_warpped_h[..., :2] / kpts_warpped_h[..., [2]] # N * L * 2 + valid_mask = (kpts_warpped[..., 0] > 0) & (kpts_warpped[..., 0] < original_size1[:, [1]]) & (kpts_warpped[..., 1] > 0) \ + & (kpts_warpped[..., 1] < original_size1[:, [0]]) & (depth_mask != 0) # N * L + if original_size0 is not None: + valid_mask *= (kpts0[..., 0] > 0) & (kpts0[..., 0] < original_size0[:, [1]]) & (kpts0[..., 1] > 0) \ + & (kpts0[..., 1] < original_size0[:, [0]]) & (depth_mask != 0) # N * L + + return valid_mask, kpts_warpped + +@torch.no_grad() +def homo_warp_kpts_glue(kpts0, homo, original_size0=None, original_size1=None): + """ + original_size1: N * 2, (h, w) + """ + kpts_warpped = warp_points_torch(kpts0, homo, inverse=False) + valid_mask = (kpts_warpped[..., 0] > 0) & (kpts_warpped[..., 0] < original_size1[:, [1]]) & (kpts_warpped[..., 1] > 0) \ + & (kpts_warpped[..., 1] < original_size1[:, [0]]) # N * L + if original_size0 is not None: + valid_mask *= (kpts0[..., 0] > 0) & (kpts0[..., 0] < original_size0[:, [1]]) & (kpts0[..., 1] > 0) \ + & (kpts0[..., 1] < original_size0[:, [0]]) # N * L + return valid_mask, kpts_warpped + +@torch.no_grad() +# if using mask in homo warp(for coarse supervision) +def homo_warp_kpts_glue_with_mask(kpts0, scale, depth_mask, homo, original_size0=None, original_size1=None): + """ + original_size1: N * 2, (h, w) + """ + kpts_warpped = warp_points_torch(kpts0, homo, inverse=False) + # get coarse-level depth_mask + depth_mask_coarse = depth_mask[:, :, ::scale, ::scale] + depth_mask_coarse = 
depth_mask_coarse.reshape(depth_mask.shape[0], -1) + + valid_mask = (kpts_warpped[..., 0] > 0) & (kpts_warpped[..., 0] < original_size1[:, [1]]) & (kpts_warpped[..., 1] > 0) \ + & (kpts_warpped[..., 1] < original_size1[:, [0]]) & (depth_mask_coarse != 0) # N * L + if original_size0 is not None: + valid_mask *= (kpts0[..., 0] > 0) & (kpts0[..., 0] < original_size0[:, [1]]) & (kpts0[..., 1] > 0) \ + & (kpts0[..., 1] < original_size0[:, [0]]) & (depth_mask_coarse != 0) # N * L + return valid_mask, kpts_warpped + +@torch.no_grad() +# if using mask in homo warp(for fine supervision) +def homo_warp_kpts_glue_with_mask_f(kpts0, depth_mask, homo, original_size0=None, original_size1=None): + """ + original_size1: N * 2, (h, w) + """ + kpts_warpped = warp_points_torch(kpts0, homo, inverse=False) + valid_mask = (kpts_warpped[..., 0] > 0) & (kpts_warpped[..., 0] < original_size1[:, [1]]) & (kpts_warpped[..., 1] > 0) \ + & (kpts_warpped[..., 1] < original_size1[:, [0]]) & (depth_mask != 0) # N * L + if original_size0 is not None: + valid_mask *= (kpts0[..., 0] > 0) & (kpts0[..., 0] < original_size0[:, [1]]) & (kpts0[..., 1] > 0) \ + & (kpts0[..., 1] < original_size0[:, [0]]) & (depth_mask != 0) # N * L + return valid_mask, kpts_warpped \ No newline at end of file diff --git a/imcui/third_party/MatchAnything/src/loftr/utils/position_encoding.py b/imcui/third_party/MatchAnything/src/loftr/utils/position_encoding.py new file mode 100644 index 0000000000000000000000000000000000000000..e4a4b4780943617588cb193efb51a261ebc17cda --- /dev/null +++ b/imcui/third_party/MatchAnything/src/loftr/utils/position_encoding.py @@ -0,0 +1,131 @@ +import math +import torch +from torch import nn + + +class PositionEncodingSine(nn.Module): + """ + This is a sinusoidal position encoding that generalized to 2-dimensional images + """ + + def __init__(self, d_model, max_shape=(256, 256), temp_bug_fix=True, npe=False): + """ + Args: + max_shape (tuple): for 1/8 featmap, the max length of 256 corresponds 
to 2048 pixels + temp_bug_fix (bool): As noted in this [issue](https://github.com/zju3dv/LoFTR/issues/41), + the original implementation of LoFTR includes a bug in the pos-enc impl, which has little impact + on the final performance. For now, we keep both impls for backward compatability. + We will remove the buggy impl after re-training all variants of our released models. + """ + super().__init__() + + pe = torch.zeros((d_model, *max_shape)) + y_position = torch.ones(max_shape).cumsum(0).float().unsqueeze(0) + x_position = torch.ones(max_shape).cumsum(1).float().unsqueeze(0) + + assert npe is not None + if npe is not None: + if isinstance(npe, bool): + train_res_H, train_res_W, test_res_H, test_res_W = 832, 832, 832, 832 + print('loftr no npe!!!!', npe) + else: + print('absnpe!!!!', npe) + train_res_H, train_res_W, test_res_H, test_res_W = npe[0], npe[1], npe[2], npe[3] # train_res_H, train_res_W, test_res_H, test_res_W + y_position, x_position = y_position * train_res_H / test_res_H, x_position * train_res_W / test_res_W + + if temp_bug_fix: + div_term = torch.exp(torch.arange(0, d_model//2, 2).float() * (-math.log(10000.0) / (d_model//2))) + else: # a buggy implementation (for backward compatability only) + div_term = torch.exp(torch.arange(0, d_model//2, 2).float() * (-math.log(10000.0) / d_model//2)) + div_term = div_term[:, None, None] # [C//4, 1, 1] + pe[0::4, :, :] = torch.sin(x_position * div_term) + pe[1::4, :, :] = torch.cos(x_position * div_term) + pe[2::4, :, :] = torch.sin(y_position * div_term) + pe[3::4, :, :] = torch.cos(y_position * div_term) + + self.register_buffer('pe', pe.unsqueeze(0), persistent=False) # [1, C, H, W] + + def forward(self, x): + """ + Args: + x: [N, C, H, W] + """ + return x + self.pe[:, :, :x.size(2), :x.size(3)] + +class RoPEPositionEncodingSine(nn.Module): + """ + This is a sinusoidal position encoding that generalized to 2-dimensional images + """ + + def __init__(self, d_model, max_shape=(256, 256), npe=None, 
ropefp16=True): + """ + Args: + max_shape (tuple): for 1/8 featmap, the max length of 256 corresponds to 2048 pixels + temp_bug_fix (bool): As noted in this [issue](https://github.com/zju3dv/LoFTR/issues/41), + the original implementation of LoFTR includes a bug in the pos-enc impl, which has little impact + on the final performance. For now, we keep both impls for backward compatability. + We will remove the buggy impl after re-training all variants of our released models. + """ + super().__init__() + + # pe = torch.zeros((d_model, *max_shape)) + # y_position = torch.ones(max_shape).cumsum(0).float().unsqueeze(-1) + # x_position = torch.ones(max_shape).cumsum(1).float().unsqueeze(-1) + i_position = torch.ones(max_shape).cumsum(0).float().unsqueeze(-1) # [H, 1] + j_position = torch.ones(max_shape).cumsum(1).float().unsqueeze(-1) # [W, 1] + + assert npe is not None + if npe is not None: + train_res_H, train_res_W, test_res_H, test_res_W = npe[0], npe[1], npe[2], npe[3] # train_res_H, train_res_W, test_res_H, test_res_W + i_position, j_position = i_position * train_res_H / test_res_H, j_position * train_res_W / test_res_W + + div_term = torch.exp(torch.arange(0, d_model//4, 1).float() * (-math.log(10000.0) / (d_model//4))) + div_term = div_term[None, None, :] # [1, 1, C//4] + # pe[0::4, :, :] = torch.sin(x_position * div_term) + # pe[1::4, :, :] = torch.cos(x_position * div_term) + # pe[2::4, :, :] = torch.sin(y_position * div_term) + # pe[3::4, :, :] = torch.cos(y_position * div_term) + sin = torch.zeros(*max_shape, d_model//2, dtype=torch.float16 if ropefp16 else torch.float32) + cos = torch.zeros(*max_shape, d_model//2, dtype=torch.float16 if ropefp16 else torch.float32) + sin[:, :, 0::2] = torch.sin(i_position * div_term).half() if ropefp16 else torch.sin(i_position * div_term) + sin[:, :, 1::2] = torch.sin(j_position * div_term).half() if ropefp16 else torch.sin(j_position * div_term) + cos[:, :, 0::2] = torch.cos(i_position * div_term).half() if ropefp16 else 
torch.cos(i_position * div_term) + cos[:, :, 1::2] = torch.cos(j_position * div_term).half() if ropefp16 else torch.cos(j_position * div_term) + + sin = sin.repeat_interleave(2, dim=-1) + cos = cos.repeat_interleave(2, dim=-1) + # self.register_buffer('pe', pe.unsqueeze(0), persistent=False) # [1, H, W, C] + self.register_buffer('sin', sin.unsqueeze(0), persistent=False) # [1, H, W, C//2] + self.register_buffer('cos', cos.unsqueeze(0), persistent=False) # [1, H, W, C//2] + + i_position4 = i_position.reshape(64,4,64,4,1)[...,0,:] + i_position4 = i_position4.mean(-3) + j_position4 = j_position.reshape(64,4,64,4,1)[:,0,...] + j_position4 = j_position4.mean(-2) + sin4 = torch.zeros(max_shape[0]//4, max_shape[1]//4, d_model//2, dtype=torch.float16 if ropefp16 else torch.float32) + cos4 = torch.zeros(max_shape[0]//4, max_shape[1]//4, d_model//2, dtype=torch.float16 if ropefp16 else torch.float32) + sin4[:, :, 0::2] = torch.sin(i_position4 * div_term).half() if ropefp16 else torch.sin(i_position4 * div_term) + sin4[:, :, 1::2] = torch.sin(j_position4 * div_term).half() if ropefp16 else torch.sin(j_position4 * div_term) + cos4[:, :, 0::2] = torch.cos(i_position4 * div_term).half() if ropefp16 else torch.cos(i_position4 * div_term) + cos4[:, :, 1::2] = torch.cos(j_position4 * div_term).half() if ropefp16 else torch.cos(j_position4 * div_term) + sin4 = sin4.repeat_interleave(2, dim=-1) + cos4 = cos4.repeat_interleave(2, dim=-1) + self.register_buffer('sin4', sin4.unsqueeze(0), persistent=False) # [1, H, W, C//2] + self.register_buffer('cos4', cos4.unsqueeze(0), persistent=False) # [1, H, W, C//2] + + + + def forward(self, x, ratio=1): + """ + Args: + x: [N, H, W, C] + """ + if ratio == 4: + return (x * self.cos4[:, :x.size(1), :x.size(2), :]) + (self.rotate_half(x) * self.sin4[:, :x.size(1), :x.size(2), :]) + else: + return (x * self.cos[:, :x.size(1), :x.size(2), :]) + (self.rotate_half(x) * self.sin[:, :x.size(1), :x.size(2), :]) + + def rotate_half(self, x): + x = 
x.unflatten(-1, (-1, 2)) + x1, x2 = x.unbind(dim=-1) + return torch.stack((-x2, x1), dim=-1).flatten(start_dim=-2) \ No newline at end of file diff --git a/imcui/third_party/MatchAnything/src/loftr/utils/supervision.py b/imcui/third_party/MatchAnything/src/loftr/utils/supervision.py new file mode 100644 index 0000000000000000000000000000000000000000..f57caa3a4b1498e31b5daca0289dd9381489f2b9 --- /dev/null +++ b/imcui/third_party/MatchAnything/src/loftr/utils/supervision.py @@ -0,0 +1,475 @@ +from math import log +from loguru import logger as loguru_logger + +import torch +import torch.nn.functional as F +from einops import rearrange, repeat +from kornia.utils import create_meshgrid + +from .geometry import warp_kpts, homo_warp_kpts, homo_warp_kpts_glue, homo_warp_kpts_with_mask, homo_warp_kpts_with_mask_f, homo_warp_kpts_glue_with_mask, homo_warp_kpts_glue_with_mask_f, warp_kpts_by_sparse_gt_matches_fast, warp_kpts_by_sparse_gt_matches_fine_chunks + +from kornia.geometry.subpix import dsnt +from kornia.utils.grid import create_meshgrid + +def static_vars(**kwargs): + def decorate(func): + for k in kwargs: + setattr(func, k, kwargs[k]) + return func + return decorate + +############## ↓ Coarse-Level supervision ↓ ############## + +@torch.no_grad() +def mask_pts_at_padded_regions(grid_pt, mask): + """For megadepth dataset, zero-padding exists in images""" + mask = repeat(mask, 'n h w -> n (h w) c', c=2) + grid_pt[~mask.bool()] = 0 + return grid_pt + + +@torch.no_grad() +def spvs_coarse(data, config): + """ + Update: + data (dict): { + "conf_matrix_gt": [N, hw0, hw1], + 'spv_b_ids': [M] + 'spv_i_ids': [M] + 'spv_j_ids': [M] + 'spv_w_pt0_i': [N, hw0, 2], in original image resolution + 'spv_pt1_i': [N, hw1, 2], in original image resolution + } + + NOTE: + - for scannet dataset, there're 3 kinds of resolution {i, c, f} + - for megadepth dataset, there're 4 kinds of resolution {i, i_resize, c, f} + """ + # 1. 
misc + device = data['image0'].device + N, _, H0, W0 = data['image0'].shape + _, _, H1, W1 = data['image1'].shape + + if 'loftr' in config.METHOD: + scale = config['LOFTR']['RESOLUTION'][0] + + scale0 = scale * data['scale0'][:, None] if 'scale0' in data else scale + scale1 = scale * data['scale1'][:, None] if 'scale0' in data else scale + h0, w0, h1, w1 = map(lambda x: x // scale, [H0, W0, H1, W1]) + + if config['LOFTR']['MATCH_COARSE']['MTD_SPVS'] and not config['LOFTR']['FORCE_LOOP_BACK']: + # 2. warp grids + # create kpts in meshgrid and resize them to image resolution + grid_pt0_c = create_meshgrid(h0, w0, False, device).reshape(1, h0*w0, 2).repeat(N, 1, 1) # [N, hw, 2] + grid_pt0_i = scale0 * grid_pt0_c + grid_pt1_c = create_meshgrid(h1, w1, False, device).reshape(1, h1*w1, 2).repeat(N, 1, 1) + grid_pt1_i = scale1 * grid_pt1_c + + correct_0to1 = torch.zeros((grid_pt0_i.shape[0], grid_pt0_i.shape[1]), dtype=torch.bool, device=grid_pt0_i.device) + w_pt0_i = torch.zeros_like(grid_pt0_i) + + valid_dpt_b_mask = data['T_0to1'].sum(dim=-1).sum(dim=-1) != 0 + valid_homo_warp_mask = (data['homography'].sum(dim=-1).sum(dim=-1) != 0) | (data['homo_sample_normed'].sum(dim=-1).sum(dim=-1) != 0) + valid_gt_match_warp_mask = (data['gt_matches_mask'][:, 0] != 0) # N + + if valid_homo_warp_mask.sum() != 0: + if data['homography'].sum()==0: + if 'homo_mask0' in data and (data['homo_mask0'].sum()!=0): # the key 'depth_mask' only exits when using the dataste "CommonDataSetHomoWarp" + correct_0to1_homo, w_pt0_i_homo = homo_warp_kpts_with_mask(grid_pt0_i[valid_homo_warp_mask], scale, data['homo_mask0'][valid_homo_warp_mask], data['norm_pixel_mat'][valid_homo_warp_mask], data['homo_sample_normed'][valid_homo_warp_mask], original_size1=data['origin_img_size1'][valid_homo_warp_mask]) + else: + correct_0to1_homo, w_pt0_i_homo = homo_warp_kpts(grid_pt0_i[valid_homo_warp_mask], data['norm_pixel_mat'][valid_homo_warp_mask], \ + data['homo_sample_normed'][valid_homo_warp_mask], 
original_size1=data['origin_img_size1'][valid_homo_warp_mask]) + else: + if 'homo_mask0' in data and (data['homo_mask0']==0).sum()!=0: + correct_0to1_homo, w_pt0_i_homo = homo_warp_kpts_glue_with_mask(grid_pt0_i[valid_homo_warp_mask], scale, data['homo_mask0'][valid_homo_warp_mask], data['homography'][valid_homo_warp_mask], original_size1=data['origin_img_size1'][valid_homo_warp_mask]) + else: + correct_0to1_homo, w_pt0_i_homo = homo_warp_kpts_glue(grid_pt0_i[valid_homo_warp_mask], data['homography'][valid_homo_warp_mask], \ + original_size1=data['origin_img_size1'][valid_homo_warp_mask]) + correct_0to1[valid_homo_warp_mask] = correct_0to1_homo + w_pt0_i[valid_homo_warp_mask] = w_pt0_i_homo + + if valid_gt_match_warp_mask.sum() != 0: + correct_0to1_dpt, w_pt0_i_dpt = warp_kpts_by_sparse_gt_matches_fast(grid_pt0_i[valid_gt_match_warp_mask], data['gt_matches'][valid_gt_match_warp_mask], scale0=scale0[valid_gt_match_warp_mask], current_h=h0, current_w=w0) + correct_0to1[valid_gt_match_warp_mask] = correct_0to1_dpt + w_pt0_i[valid_gt_match_warp_mask] = w_pt0_i_dpt + + if valid_dpt_b_mask.sum() != 0: + correct_0to1_dpt, w_pt0_i_dpt = warp_kpts(grid_pt0_i[valid_dpt_b_mask], data['depth0'][valid_dpt_b_mask], data['depth1'][valid_dpt_b_mask], data['T_0to1'][valid_dpt_b_mask], data['K0'][valid_dpt_b_mask], data['K1'][valid_dpt_b_mask], consistency_thr=0.05) + correct_0to1[valid_dpt_b_mask] = correct_0to1_dpt + w_pt0_i[valid_dpt_b_mask] = w_pt0_i_dpt + + w_pt0_c = w_pt0_i / scale1 + + # 3. 
check if mutual nearest neighbor + w_pt0_c_round = w_pt0_c[:, :, :].round() # [N, hw, 2] + if config.LOFTR.LOSS.COARSE_OVERLAP_WEIGHT: + w_pt0_c_error = (1.0 - 2*torch.abs(w_pt0_c - w_pt0_c_round)).prod(-1) + w_pt0_c_round = w_pt0_c_round.long() # [N, hw, 2] + nearest_index1 = w_pt0_c_round[..., 0] + w_pt0_c_round[..., 1] * w1 # [N, hw] + + # corner case: out of boundary + def out_bound_mask(pt, w, h): + return (pt[..., 0] < 0) + (pt[..., 0] >= w) + (pt[..., 1] < 0) + (pt[..., 1] >= h) + nearest_index1[out_bound_mask(w_pt0_c_round, w1, h1)] = -1 + + correct_0to1[:, 0] = False # ignore the top-left corner + + # 4. construct a gt conf_matrix + mask1 = torch.stack([data['mask1'].reshape(-1, h1*w1)[_b, _i] for _b, _i in enumerate(nearest_index1)], dim=0) + correct_0to1 = correct_0to1 * data['mask0'].reshape(-1, h0*w0) * mask1 + + conf_matrix_gt = torch.zeros(N, h0*w0, h1*w1, device=device, dtype=torch.bool) + b_ids, i_ids = torch.where(correct_0to1 != 0) + j_ids = nearest_index1[b_ids, i_ids] + valid_j_ids = j_ids != -1 + b_ids, i_ids, j_ids = map(lambda x: x[valid_j_ids], [b_ids, i_ids, j_ids]) + + conf_matrix_gt[b_ids, i_ids, j_ids] = 1 + + # overlap weight + if config.LOFTR.LOSS.COARSE_OVERLAP_WEIGHT: + conf_matrix_error_gt = w_pt0_c_error[b_ids, i_ids] + assert torch.all(conf_matrix_error_gt >= -0.001) + assert torch.all(conf_matrix_error_gt <= 1.001) + data.update({'conf_matrix_error_gt': conf_matrix_error_gt}) + data.update({'conf_matrix_gt': conf_matrix_gt}) + + # 5. 
@torch.no_grad()
def get_gt_flow(data, h, w):
    """Build the ground-truth warp and validity map for an ``h`` x ``w`` grid on image0.

    Grid-cell centres of image0 are scaled to original-image pixels, warped
    into image1 by whichever supervision source is non-zero for each batch
    item (homography, sparse GT matches, or depth + pose), and finally
    expressed in normalized coordinates.

    Args:
        data (dict): batch dictionary. Always reads 'image0', 'T_0to1',
            'homography', 'homo_sample_normed' and 'gt_matches_mask'; the
            individual warp branches additionally read 'homo_mask0',
            'norm_pixel_mat', 'origin_img_size1', 'gt_matches', 'depth0',
            'depth1', 'K0' and 'K1'. 'scale0' / 'scale1' are optional.
        h (int): grid height.
        w (int): grid width.

    Returns:
        tuple: ``(warp, valid)`` where ``warp`` is reshaped to [B, h, w, 2]
        (computed as ``2 * x / w - 1`` and ``2 * y / h - 1``) and ``valid``
        is a float tensor reshaped to [B, h, w].

    Side effects:
        When ``H0 / h > 8`` the valid grid points and their warped positions
        are stored in ``data`` as 'mkpts0_c_gt' and 'mkpts1_c_gt'.
    """
    device = data['image0'].device
    B, _, H0, W0 = data['image0'].shape
    scale = H0 / h

    # Each per-image scale is looked up under its own key. The original checked
    # 'scale0' for both, so a batch carrying only one of the two keys either
    # raised KeyError or silently ignored 'scale1'.
    scale0 = scale * data['scale0'][:, None] if 'scale0' in data else scale
    scale1 = scale * data['scale1'][:, None] if 'scale1' in data else scale

    x1_n = torch.meshgrid(
        *[
            torch.linspace(
                -1 + 1 / n, 1 - 1 / n, n, device=device
            )
            for n in (B, h, w)
        ]
    )
    grid_coord = torch.stack((x1_n[2], x1_n[1]), dim=-1).reshape(B, h*w, 2)  # normalized
    grid_coord = torch.stack(
        (w * (grid_coord[..., 0] + 1) / 2, h * (grid_coord[..., 1] + 1) / 2), dim=-1
    )  # [-1+1/h, 1-1/h] -> [0.5, h-0.5]
    grid_coord_in_origin = grid_coord * scale0

    correct_0to1 = torch.zeros((grid_coord_in_origin.shape[0], grid_coord_in_origin.shape[1]), dtype=torch.bool, device=device)
    w_pt0_i = torch.zeros_like(grid_coord_in_origin)

    # A batch item uses a supervision source only when the matching tensor is non-zero.
    valid_dpt_b_mask = data['T_0to1'].sum(dim=-1).sum(dim=-1) != 0
    valid_homo_warp_mask = (data['homography'].sum(dim=-1).sum(dim=-1) != 0) | (data['homo_sample_normed'].sum(dim=-1).sum(dim=-1) != 0)
    valid_gt_match_warp_mask = (data['gt_matches_mask'] != 0)[:, 0]

    if valid_homo_warp_mask.sum() != 0:
        if data['homography'].sum()==0:
            if 'homo_mask0' in data and (data['homo_mask0'].sum()!=0):
                # data['load_mask'] = True or False, data['depth_mask'] = depth_mask or None
                correct_0to1_homo, w_pt0_i_homo = homo_warp_kpts_with_mask(grid_coord_in_origin[valid_homo_warp_mask], int(scale), data['homo_mask0'][valid_homo_warp_mask], data['norm_pixel_mat'][valid_homo_warp_mask], \
                    data['homo_sample_normed'][valid_homo_warp_mask], original_size1=data['origin_img_size1'][valid_homo_warp_mask])
            else:
                correct_0to1_homo, w_pt0_i_homo = homo_warp_kpts(grid_coord_in_origin[valid_homo_warp_mask], data['norm_pixel_mat'][valid_homo_warp_mask], data['homo_sample_normed'][valid_homo_warp_mask], \
                    original_size1=data['origin_img_size1'][valid_homo_warp_mask])
        else:
            if 'homo_mask0' in data and (data['homo_mask0']==0).sum()!=0:
                correct_0to1_homo, w_pt0_i_homo = homo_warp_kpts_glue_with_mask(grid_coord_in_origin[valid_homo_warp_mask], int(scale), data['homo_mask0'][valid_homo_warp_mask], data['homography'][valid_homo_warp_mask], \
                    original_size1=data['origin_img_size1'][valid_homo_warp_mask])
            else:
                correct_0to1_homo, w_pt0_i_homo = homo_warp_kpts_glue(grid_coord_in_origin[valid_homo_warp_mask], data['homography'][valid_homo_warp_mask], \
                    original_size1=data['origin_img_size1'][valid_homo_warp_mask])
        correct_0to1[valid_homo_warp_mask] = correct_0to1_homo
        w_pt0_i[valid_homo_warp_mask] = w_pt0_i_homo

    if valid_gt_match_warp_mask.sum() != 0:
        # NOTE(review): scale0 is indexed here, so this branch presumably
        # requires 'scale0' to be present in data - confirm with the dataloader.
        correct_0to1_dpt, w_pt0_i_dpt = warp_kpts_by_sparse_gt_matches_fast(grid_coord_in_origin[valid_gt_match_warp_mask], data['gt_matches'][valid_gt_match_warp_mask], scale0=scale0[valid_gt_match_warp_mask], current_h=h, current_w=w)
        correct_0to1[valid_gt_match_warp_mask] = correct_0to1_dpt
        w_pt0_i[valid_gt_match_warp_mask] = w_pt0_i_dpt
    if valid_dpt_b_mask.sum() != 0:
        # Runs last, so it overwrites earlier results for items that also have pose/depth.
        correct_0to1_dpt, w_pt0_i_dpt = warp_kpts(grid_coord_in_origin[valid_dpt_b_mask], data['depth0'][valid_dpt_b_mask], data['depth1'][valid_dpt_b_mask], data['T_0to1'][valid_dpt_b_mask], data['K0'][valid_dpt_b_mask], data['K1'][valid_dpt_b_mask], consistency_thr=0.05)
        correct_0to1[valid_dpt_b_mask] = correct_0to1_dpt
        w_pt0_i[valid_dpt_b_mask] = w_pt0_i_dpt

    w_pt0_c = w_pt0_i / scale1

    def out_bound_mask(pt, w, h):
        # `+` on boolean tensors acts as a logical OR.
        return (pt[..., 0] < 0) + (pt[..., 0] >= w) + (pt[..., 1] < 0) + (pt[..., 1] >= h)
    correct_0to1[out_bound_mask(w_pt0_c, w, h)] = 0

    w_pt0_n = torch.stack(
        (2 * w_pt0_c[..., 0] / w - 1, 2 * w_pt0_c[..., 1] / h - 1), dim=-1
    )  # from [0.5,h-0.5] -> [-1+1/h, 1-1/h]
    # w_pt1_c = w_pt1_i / scale0

    if scale > 8:
        data.update({'mkpts0_c_gt': grid_coord_in_origin[correct_0to1]})
        data.update({'mkpts1_c_gt': w_pt0_i[correct_0to1]})

    return w_pt0_n.reshape(B, h, w, 2), correct_0to1.float().reshape(B, h, w)
misc + # w_pt0_i, pt1_i = data.pop('spv_w_pt0_i'), data.pop('spv_pt1_i') + if config.LOFTR.FINE.MTD_SPVS: + pt1_i = data['spv_pt1_i'] + else: + spv_w_pt0_i, pt1_i = data['spv_w_pt0_i'], data['spv_pt1_i'] + if 'loftr' in config.METHOD: + scale = config['LOFTR']['RESOLUTION'][1] + scale_c = config['LOFTR']['RESOLUTION'][0] + radius = config['LOFTR']['FINE_WINDOW_SIZE'] // 2 + + # 2. get coarse prediction + b_ids, i_ids, j_ids = data['b_ids'], data['i_ids'], data['j_ids'] + + # 3. compute gt + scalei0 = scale * data['scale0'][b_ids] if 'scale0' in data else scale + scale0 = scale * data['scale0'] if 'scale0' in data else scale + scalei1 = scale * data['scale1'][b_ids] if 'scale0' in data else scale + + if config.LOFTR.FINE.MTD_SPVS: + W = config['LOFTR']['FINE_WINDOW_SIZE'] + WW = W*W + device = data['image0'].device + + N, _, H0, W0 = data['image0'].shape + _, _, H1, W1 = data['image1'].shape + + if config.LOFTR.ALIGN_CORNER is False: + hf0, wf0, hf1, wf1 = data['hw0_f'][0], data['hw0_f'][1], data['hw1_f'][0], data['hw1_f'][1] + hc0, wc0, hc1, wc1 = data['hw0_c'][0], data['hw0_c'][1], data['hw1_c'][0], data['hw1_c'][1] + # loguru_logger.info('hf0, wf0, hf1, wf1', hf0, wf0, hf1, wf1) + else: + hf0, wf0, hf1, wf1 = map(lambda x: x // scale, [H0, W0, H1, W1]) + hc0, wc0, hc1, wc1 = map(lambda x: x // scale_c, [H0, W0, H1, W1]) + + m = b_ids.shape[0] + if m == 0: + conf_matrix_f_gt = torch.zeros(m, WW, WW, device=device) + + data.update({'conf_matrix_f_gt': conf_matrix_f_gt}) + if config.LOFTR.LOSS.FINE_OVERLAP_WEIGHT: + conf_matrix_f_error_gt = torch.zeros(1, device=device) + data.update({'conf_matrix_f_error_gt': conf_matrix_f_error_gt}) + if config.LOFTR.MATCH_FINE.MULTI_REGRESS: + data.update({'expec_f': torch.zeros(1, 3, device=device)}) + data.update({'expec_f_gt': torch.zeros(1, 2, device=device)}) + + if config.LOFTR.MATCH_FINE.LOCAL_REGRESS: + data.update({'expec_f': torch.zeros(1, 2, device=device)}) + data.update({'expec_f_gt': torch.zeros(1, 2, 
device=device)}) + else: + grid_pt0_f = create_meshgrid(hf0, wf0, False, device) - W // 2 + 0.5 # [1, hf0, wf0, 2] # use fine coordinates + # grid_pt0_f = create_meshgrid(hf0, wf0, False, device) + 0.5 # [1, hf0, wf0, 2] # use fine coordinates + grid_pt0_f = rearrange(grid_pt0_f, 'n h w c -> n c h w') + # 1. unfold(crop) all local windows + if config.LOFTR.ALIGN_CORNER is False: # even windows + if config.LOFTR.MATCH_FINE.MULTI_REGRESS or (config.LOFTR.MATCH_FINE.LOCAL_REGRESS and W == 10): + grid_pt0_f_unfold = F.unfold(grid_pt0_f, kernel_size=(W, W), stride=W-2, padding=1) # overlap windows W-2 padding=1 + else: + grid_pt0_f_unfold = F.unfold(grid_pt0_f, kernel_size=(W, W), stride=W, padding=0) + else: + grid_pt0_f_unfold = F.unfold(grid_pt0_f[..., :-1, :-1], kernel_size=(W, W), stride=W, padding=W//2) + grid_pt0_f_unfold = rearrange(grid_pt0_f_unfold, 'n (c ww) l -> n l ww c', ww=W**2) # [1, hc0*wc0, W*W, 2] + grid_pt0_f_unfold = repeat(grid_pt0_f_unfold[0], 'l ww c -> N l ww c', N=N) + + # 2. 
select only the predicted matches + grid_pt0_f_unfold = grid_pt0_f_unfold[data['b_ids'], data['i_ids']] # [m, ww, 2] + grid_pt0_f_unfold = scalei0[:,None,:] * grid_pt0_f_unfold # [m, ww, 2] + + # use depth mask + if 'homo_mask0' in data and (data['homo_mask0'].sum()!=0): + # depth_mask --> (n, 1, hf, wf) + homo_mask0 = data['homo_mask0'] + homo_mask0 = F.unfold(homo_mask0[..., :-1, :-1], kernel_size=(W, W), stride=W, padding=W//2) + homo_mask0 = rearrange(homo_mask0, 'n (c ww) l -> n l ww c', ww=W**2) # [1, hc0*wc0, W*W, 1] + homo_mask0 = repeat(homo_mask0[0], 'l ww c -> N l ww c', N=N) + # select only the predicted matches + homo_mask0 = homo_mask0[data['b_ids'], data['i_ids']] + + correct_0to1_f_list, w_pt0_i_list = [], [] + + correct_0to1_f = torch.zeros(m, WW, device=device, dtype=torch.bool) + w_pt0_i = torch.zeros(m, WW, 2, device=device, dtype=torch.float32) + for b in range(N): + mask = b_ids == b + + match = int(mask.sum()) + skip_reshape = False + if match == 0: + print(f"no pred fine matches, skip!") + continue + if (data['homography'][b].sum() != 0) | (data['homo_sample_normed'][b].sum() != 0): + if data['homography'][b].sum()==0: + if 'homo_mask0' in data and (data['homo_mask0'].sum()!=0): + correct_0to1_f_mask, w_pt0_i_mask = homo_warp_kpts_with_mask_f(grid_pt0_f_unfold[mask].reshape(1,-1,2), homo_mask0[mask].reshape(1,-1), data['norm_pixel_mat'][[b]], \ + data['homo_sample_normed'][[b]], data['origin_img_size0'][[b]], data['origin_img_size1'][[b]]) + else: + correct_0to1_f_mask, w_pt0_i_mask = homo_warp_kpts(grid_pt0_f_unfold[mask].reshape(1,-1,2), data['norm_pixel_mat'][[b]], \ + data['homo_sample_normed'][[b]], data['origin_img_size0'][[b]], data['origin_img_size1'][[b]]) + else: + if 'homo_mask0' in data and (data['homo_mask0'].sum()!=0): + correct_0to1_f_mask, w_pt0_i_mask = homo_warp_kpts_glue_with_mask_f(grid_pt0_f_unfold[mask].reshape(1,-1,2), homo_mask0[mask].reshape(1,-1), data['homography'][[b]], \ + data['origin_img_size0'][[b]], 
data['origin_img_size1'][[b]]) + else: + correct_0to1_f_mask, w_pt0_i_mask = homo_warp_kpts_glue(grid_pt0_f_unfold[mask].reshape(1,-1,2), data['homography'][[b]], \ + data['origin_img_size0'][[b]], data['origin_img_size1'][[b]]) + elif data['T_0to1'][b].sum() != 0: + correct_0to1_f_mask, w_pt0_i_mask = warp_kpts(grid_pt0_f_unfold[mask].reshape(1,-1,2), data['depth0'][[b],...], + data['depth1'][[b],...], data['T_0to1'][[b],...], + data['K0'][[b],...], data['K1'][[b],...]) # [k, WW], [k, WW, 2] + elif data['gt_matches_mask'][b].sum() != 0: + correct_0to1_f_mask, w_pt0_i_mask = warp_kpts_by_sparse_gt_matches_fine_chunks(grid_pt0_f_unfold[mask], gt_matches=data['gt_matches'][[b]], dist_thr=scale0[[b]].max(dim=-1)[0]) + skip_reshape = True + correct_0to1_f[mask] = correct_0to1_f_mask.reshape(match, WW) if not skip_reshape else correct_0to1_f_mask + w_pt0_i[mask] = w_pt0_i_mask.reshape(match, WW, 2) if not skip_reshape else w_pt0_i_mask + + delta_w_pt0_i = w_pt0_i - pt1_i[b_ids, j_ids][:,None,:] # [m, WW, 2] + delta_w_pt0_f = delta_w_pt0_i / scalei1[:,None,:] + W // 2 - 0.5 + delta_w_pt0_f_round = delta_w_pt0_f[:, :, :].round() + if config.LOFTR.LOSS.FINE_OVERLAP_WEIGHT and config.LOFTR.LOSS.FINE_OVERLAP_WEIGHT2: + w_pt0_f_error = (1.0 - torch.abs(delta_w_pt0_f - delta_w_pt0_f_round)).prod(-1) # [0.25, 1] + elif config.LOFTR.LOSS.FINE_OVERLAP_WEIGHT: + w_pt0_f_error = (1.0 - 2*torch.abs(delta_w_pt0_f - delta_w_pt0_f_round)).prod(-1) # [0, 1] + delta_w_pt0_f_round = delta_w_pt0_f_round.long() + + + nearest_index1 = delta_w_pt0_f_round[..., 0] + delta_w_pt0_f_round[..., 1] * W # [m, WW] + + def out_bound_mask(pt, w, h): + return (pt[..., 0] < 0) + (pt[..., 0] >= w) + (pt[..., 1] < 0) + (pt[..., 1] >= h) + ob_mask = out_bound_mask(delta_w_pt0_f_round, W, W) + nearest_index1[ob_mask] = 0 + + correct_0to1_f[ob_mask] = 0 + m_ids_d, i_ids_d = torch.where(correct_0to1_f != 0) + + j_ids_d = nearest_index1[m_ids_d, i_ids_d] + + # For plotting: + mkpts0_f_gt = 
grid_pt0_f_unfold[m_ids_d, i_ids_d] # [m, 2] + mkpts1_f_gt = w_pt0_i[m_ids_d, i_ids_d] # [m, 2] + data.update({'mkpts0_f_gt_b_ids': m_ids_d}) + data.update({'mkpts0_f_gt': mkpts0_f_gt}) + data.update({'mkpts1_f_gt': mkpts1_f_gt}) + + if config.LOFTR.MATCH_FINE.MULTI_REGRESS: + assert not config.LOFTR.MATCH_FINE.LOCAL_REGRESS + expec_f_gt = delta_w_pt0_f - W // 2 + 0.5 # use delta(e.g. [-3.5,3.5]) in regression rather than [0,W] (e.g. [0,7]) + expec_f_gt = expec_f_gt[m_ids_d, i_ids_d] / (W // 2 - 1) # specific radius for overlaped even windows & align_corner=False + data.update({'expec_f_gt': expec_f_gt}) + data.update({'m_ids_d': m_ids_d, 'i_ids_d': i_ids_d}) + else: # spv fine dual softmax + if config.LOFTR.MATCH_FINE.LOCAL_REGRESS: + expec_f_gt = delta_w_pt0_f - delta_w_pt0_f_round + + # mask fine windows boarder + j_ids_d_il, j_ids_d_jl = j_ids_d // W, j_ids_d % W + if config.LOFTR.MATCH_FINE.LOCAL_REGRESS_NOMASK: + mask = None + m_ids_dl, i_ids_dl, j_ids_d_il, j_ids_d_jl = m_ids_d.to(torch.long), i_ids_d.to(torch.long), j_ids_d_il.to(torch.long), j_ids_d_jl.to(torch.long) + else: + mask = (j_ids_d_il >= 1) & (j_ids_d_il < W-1) & (j_ids_d_jl >= 1) & (j_ids_d_jl < W-1) + if W == 10: + i_ids_d_il, i_ids_d_jl = i_ids_d // W, i_ids_d % W + mask = mask & (i_ids_d_il >= 1) & (i_ids_d_il <= W-2) & (i_ids_d_jl >= 1) & (i_ids_d_jl <= W-2) + + m_ids_dl, i_ids_dl, j_ids_d_il, j_ids_d_jl = m_ids_d[mask].to(torch.long), i_ids_d[mask].to(torch.long), j_ids_d_il[mask].to(torch.long), j_ids_d_jl[mask].to(torch.long) + if mask is not None: + loguru_logger.info(f'percent of gt mask.sum / mask.numel: {mask.sum().float()/mask.numel():.2f}') + if m_ids_dl.numel() == 0: + loguru_logger.warning(f"No groundtruth fine match found for local regress: {data['pair_names']}") + data.update({'expec_f_gt': torch.zeros(1, 2, device=device)}) + data.update({'expec_f': torch.zeros(1, 2, device=device)}) + else: + expec_f_gt = expec_f_gt[m_ids_dl, i_ids_dl] + data.update({"expec_f_gt": 
def build_optimizer(model, config):
    """Create the optimizer selected by ``config.TRAINER.OPTIMIZER``.

    Args:
        model: a ``torch.nn.Module`` whose parameters are optimized.
        config: config node providing ``TRAINER.OPTIMIZER`` ("adam" or
            "adamw"), ``TRAINER.TRUE_LR``, ``TRAINER.OPTIMIZER_EPS``, the
            matching ``TRAINER.ADAM_DECAY`` / ``TRAINER.ADAMW_DECAY``, and
            ``METHOD``.

    Returns:
        A ``torch.optim.Adam`` or ``torch.optim.AdamW`` instance. For "adamw"
        with "ROMA" or "DKM" in ``config.METHOD``, parameters whose name
        contains 'model.encoder' form their own group trained at 0.05 x the
        base learning rate.

    Raises:
        ValueError: if the optimizer name is not recognised.
    """
    name = config.TRAINER.OPTIMIZER
    lr = config.TRAINER.TRUE_LR

    if name == "adam":
        return torch.optim.Adam(model.parameters(), lr=lr, weight_decay=config.TRAINER.ADAM_DECAY, eps=config.TRAINER.OPTIMIZER_EPS)

    if name == "adamw":
        params = model.parameters()
        if ("ROMA" in config.METHOD) or ("DKM" in config.METHOD):
            # Split backbone parameters from the rest so the backbone can use a
            # smaller learning rate. The original built these groups but then
            # passed model.parameters() to AdamW, so the split had no effect.
            keyword = 'model.encoder'
            backbone_params, base_params = [], []
            for param_name, param in model.named_parameters():
                if keyword in param_name:
                    backbone_params.append(param)
                else:
                    base_params.append(param)
            groups = [{'params': backbone_params, 'lr': lr * 0.05}, {'params': base_params}]
            # Drop empty groups so a model without a matching encoder still
            # yields a single ordinary parameter group.
            params = [group for group in groups if group['params']]
        return torch.optim.AdamW(params, lr=lr, weight_decay=config.TRAINER.ADAMW_DECAY, eps=config.TRAINER.OPTIMIZER_EPS)

    raise ValueError(f"TRAINER.OPTIMIZER = {name} is not a valid optimizer!")
def build_augmentor(method=None, **kwargs):
    """Return the augmentor for ``method``; augmentation is currently disabled.

    Args:
        method: augmentation name, or ``None`` for no augmentation.
        **kwargs: accepted for interface compatibility; not used.

    Returns:
        ``None`` when ``method`` is ``None``.

    Raises:
        NotImplementedError: for every non-``None`` ``method``.
    """
    if method is None:
        return None
    # The original kept a 'dark' / 'mobile' dispatch and a ValueError fallback
    # after this raise, but none of it could ever be reached.
    raise NotImplementedError('Using of augmentation functions are not supported yet!')
+# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# +# * Neither the name of ETH Zurich and UNC Chapel Hill nor the names of +# its contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE +# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +# POSSIBILITY OF SUCH DAMAGE. +# +# Author: Johannes L. 
class Camera(BaseCamera):
    """Camera record that can expose its parameters as a 3x3 intrinsic matrix."""

    @property
    def K(self):
        """Return the 3x3 intrinsic matrix built from ``self.params``.

        Models with a single focal length use ``params[0]`` for both diagonal
        entries and ``params[1:3]`` as the principal point; the remaining
        supported models use ``params[0:2]`` as the focal lengths and
        ``params[2:4]`` as the principal point.

        Raises:
            NotImplementedError: if ``self.model`` is not one of the handled names.
        """
        shared_focal_models = (
            "SIMPLE_PINHOLE",
            "SIMPLE_RADIAL",
            "RADIAL",
            "SIMPLE_RADIAL_FISHEYE",
            "RADIAL_FISHEYE",
        )
        separate_focal_models = (
            "PINHOLE",
            "OPENCV",
            "OPENCV_FISHEYE",
            "FULL_OPENCV",
            "FOV",
            "THIN_PRISM_FISHEYE",
        )

        intrinsics = np.eye(3)
        if self.model in shared_focal_models:
            fx_idx, fy_idx, cx_idx, cy_idx = 0, 0, 1, 2
        elif self.model in separate_focal_models:
            fx_idx, fy_idx, cx_idx, cy_idx = 0, 1, 2, 3
        else:
            raise NotImplementedError

        intrinsics[0, 0] = self.params[fx_idx]
        intrinsics[1, 1] = self.params[fy_idx]
        intrinsics[0, 2] = self.params[cx_idx]
        intrinsics[1, 2] = self.params[cy_idx]
        return intrinsics
def read_next_bytes(fid, num_bytes, format_char_sequence, endian_character="<"):
    """Read ``num_bytes`` from ``fid`` and unpack them with ``struct``.

    :param fid: binary file object to read from.
    :param num_bytes: number of bytes to read; must match the size implied by
        ``format_char_sequence``.
    :param format_char_sequence: ``struct`` format characters, without the
        byte-order prefix.
    :param endian_character: byte-order prefix, any of {@, =, <, >, !}.
    :return: tuple of the unpacked values.
    """
    raw = fid.read(num_bytes)
    layout = endian_character + format_char_sequence
    return struct.unpack(layout, raw)
def read_cameras_text(path):
    """Parse a text camera list into a dict of ``Camera`` keyed by camera id.

    Blank lines and lines starting with "#" are skipped. Every other line is
    split on whitespace into: camera id, model name, width, height, followed by
    the model parameters.

    see: src/base/reconstruction.cc
        void Reconstruction::WriteCamerasText(const std::string& path)
        void Reconstruction::ReadCamerasText(const std::string& path)
    """
    cameras = {}
    with open(path, "r") as fid:
        for raw_line in fid:
            line = raw_line.strip()
            if not line or line[0] == "#":
                continue
            fields = line.split()
            camera_id = int(fields[0])
            cameras[camera_id] = Camera(
                id=camera_id,
                model=fields[1],
                width=int(fields[2]),
                height=int(fields[3]),
                params=np.array(tuple(map(float, fields[4:]))),
            )
    return cameras
def write_cameras_binary(cameras, path_to_model_file):
    """Serialize ``cameras`` to a binary camera file and return ``cameras``.

    The file starts with the camera count ("Q"). Each camera follows as an
    "iiQQ" record (id, model id, width, height) and then one "d" per parameter.

    see: src/base/reconstruction.cc
        void Reconstruction::WriteCamerasBinary(const std::string& path)
        void Reconstruction::ReadCamerasBinary(const std::string& path)
    """
    with open(path_to_model_file, "wb") as fid:
        write_next_bytes(fid, len(cameras), "Q")
        for cam in cameras.values():
            record = [
                cam.id,
                CAMERA_MODEL_NAMES[cam.model].model_id,
                cam.width,
                cam.height,
            ]
            write_next_bytes(fid, record, "iiQQ")
            for value in cam.params:
                write_next_bytes(fid, float(value), "d")
    return cameras
def read_images_binary(path_to_model_file):
    """Parse a binary image file into a dict of ``Image`` keyed by image id.

    Each record is an "idddddddi" header (image id, 4 quaternion values,
    3 translation values, camera id), a NUL-terminated name, a 2D point
    count ("Q"), and that many "ddq" triples (x, y, 3D point id).

    see: src/base/reconstruction.cc
        void Reconstruction::ReadImagesBinary(const std::string& path)
        void Reconstruction::WriteImagesBinary(const std::string& path)
    """
    images = {}
    with open(path_to_model_file, "rb") as fid:
        num_reg_images = read_next_bytes(fid, 8, "Q")[0]
        for _ in range(num_reg_images):
            binary_image_properties = read_next_bytes(
                fid, num_bytes=64, format_char_sequence="idddddddi")
            image_id = binary_image_properties[0]
            qvec = np.array(binary_image_properties[1:5])
            tvec = np.array(binary_image_properties[5:8])
            camera_id = binary_image_properties[8]
            # Collect the raw bytes and decode once at the end: decoding byte by
            # byte raises UnicodeDecodeError on any multi-byte UTF-8 character.
            name_bytes = bytearray()
            current_char = read_next_bytes(fid, 1, "c")[0]
            while current_char != b"\x00":  # look for the ASCII 0 entry
                name_bytes += current_char
                current_char = read_next_bytes(fid, 1, "c")[0]
            image_name = name_bytes.decode("utf-8")
            num_points2D = read_next_bytes(fid, num_bytes=8,
                                           format_char_sequence="Q")[0]
            x_y_id_s = read_next_bytes(fid, num_bytes=24*num_points2D,
                                       format_char_sequence="ddq"*num_points2D)
            xys = np.column_stack([tuple(map(float, x_y_id_s[0::3])),
                                   tuple(map(float, x_y_id_s[1::3]))])
            point3D_ids = np.array(tuple(map(int, x_y_id_s[2::3])))
            images[image_id] = Image(
                id=image_id, qvec=qvec, tvec=tvec,
                camera_id=camera_id, name=image_name,
                xys=xys, point3D_ids=point3D_ids)
    return images
def write_images_binary(images, path_to_model_file):
    """Serialize ``images`` to a binary image file.

    The file starts with the image count ("Q"). Each image is written as its
    id ("i"), qvec ("dddd"), tvec ("ddd"), camera id ("i"), the UTF-8 encoded
    name followed by a NUL byte, the 2D point count ("Q"), and one "ddq"
    triple (x, y, 3D point id) per point.

    see: src/base/reconstruction.cc
        void Reconstruction::ReadImagesBinary(const std::string& path)
        void Reconstruction::WriteImagesBinary(const std::string& path)
    """
    with open(path_to_model_file, "wb") as fid:
        write_next_bytes(fid, len(images), "Q")
        for _, img in images.items():
            write_next_bytes(fid, img.id, "i")
            write_next_bytes(fid, img.qvec.tolist(), "dddd")
            write_next_bytes(fid, img.tvec.tolist(), "ddd")
            write_next_bytes(fid, img.camera_id, "i")
            # Emit one encoded *byte* at a time: the "c" format needs exactly one
            # byte, but a non-ASCII character encodes to several, which made the
            # per-character version raise struct.error.
            for byte_value in img.name.encode("utf-8"):
                write_next_bytes(fid, bytes([byte_value]), "c")
            write_next_bytes(fid, b"\x00", "c")
            write_next_bytes(fid, len(img.point3D_ids), "Q")
            for xy, p3d_id in zip(img.xys, img.point3D_ids):
                write_next_bytes(fid, [*xy, p3d_id], "ddq")
elems[9::2]))) + points3D[point3D_id] = Point3D(id=point3D_id, xyz=xyz, rgb=rgb, + error=error, image_ids=image_ids, + point2D_idxs=point2D_idxs) + return points3D + + +def read_points3D_binary(path_to_model_file): + """ + see: src/base/reconstruction.cc + void Reconstruction::ReadPoints3DBinary(const std::string& path) + void Reconstruction::WritePoints3DBinary(const std::string& path) + """ + points3D = {} + with open(path_to_model_file, "rb") as fid: + num_points = read_next_bytes(fid, 8, "Q")[0] + for _ in range(num_points): + binary_point_line_properties = read_next_bytes( + fid, num_bytes=43, format_char_sequence="QdddBBBd") + point3D_id = binary_point_line_properties[0] + xyz = np.array(binary_point_line_properties[1:4]) + rgb = np.array(binary_point_line_properties[4:7]) + error = np.array(binary_point_line_properties[7]) + track_length = read_next_bytes( + fid, num_bytes=8, format_char_sequence="Q")[0] + track_elems = read_next_bytes( + fid, num_bytes=8*track_length, + format_char_sequence="ii"*track_length) + image_ids = np.array(tuple(map(int, track_elems[0::2]))) + point2D_idxs = np.array(tuple(map(int, track_elems[1::2]))) + points3D[point3D_id] = Point3D( + id=point3D_id, xyz=xyz, rgb=rgb, + error=error, image_ids=image_ids, + point2D_idxs=point2D_idxs) + return points3D + + +def write_points3D_text(points3D, path): + """ + see: src/base/reconstruction.cc + void Reconstruction::ReadPoints3DText(const std::string& path) + void Reconstruction::WritePoints3DText(const std::string& path) + """ + if len(points3D) == 0: + mean_track_length = 0 + else: + mean_track_length = sum((len(pt.image_ids) for _, pt in points3D.items()))/len(points3D) + HEADER = "# 3D point list with one line of data per point:\n" + \ + "# POINT3D_ID, X, Y, Z, R, G, B, ERROR, TRACK[] as (IMAGE_ID, POINT2D_IDX)\n" + \ + "# Number of points: {}, mean track length: {}\n".format(len(points3D), mean_track_length) + + with open(path, "w") as fid: + fid.write(HEADER) + for _, pt in 
points3D.items(): + point_header = [pt.id, *pt.xyz, *pt.rgb, pt.error] + fid.write(" ".join(map(str, point_header)) + " ") + track_strings = [] + for image_id, point2D in zip(pt.image_ids, pt.point2D_idxs): + track_strings.append(" ".join(map(str, [image_id, point2D]))) + fid.write(" ".join(track_strings) + "\n") + + +def write_points3D_binary(points3D, path_to_model_file): + """ + see: src/base/reconstruction.cc + void Reconstruction::ReadPoints3DBinary(const std::string& path) + void Reconstruction::WritePoints3DBinary(const std::string& path) + """ + with open(path_to_model_file, "wb") as fid: + write_next_bytes(fid, len(points3D), "Q") + for _, pt in points3D.items(): + write_next_bytes(fid, pt.id, "Q") + write_next_bytes(fid, pt.xyz.tolist(), "ddd") + write_next_bytes(fid, pt.rgb.tolist(), "BBB") + write_next_bytes(fid, pt.error, "d") + track_length = pt.image_ids.shape[0] + write_next_bytes(fid, track_length, "Q") + for image_id, point2D_id in zip(pt.image_ids, pt.point2D_idxs): + write_next_bytes(fid, [image_id, point2D_id], "ii") + + +def detect_model_format(path, ext): + if os.path.isfile(os.path.join(path, "cameras" + ext)) and \ + os.path.isfile(os.path.join(path, "images" + ext)) and \ + os.path.isfile(os.path.join(path, "points3D" + ext)): + print("Detected model format: '" + ext + "'") + return True + + return False + + +def read_model(path, ext="") -> Tuple[Dict[int, Camera], Dict[int, Image], Dict[int, Point3D]]: + # try to detect the extension automatically + if ext == "": + if detect_model_format(path, ".bin"): + ext = ".bin" + elif detect_model_format(path, ".txt"): + ext = ".txt" + else: + raise ValueError("Provide model format: '.bin' or '.txt'") + + if ext == ".txt": + cameras = read_cameras_text(os.path.join(path, "cameras" + ext)) + images = read_images_text(os.path.join(path, "images" + ext)) + points3D = read_points3D_text(os.path.join(path, "points3D") + ext) + else: + cameras = read_cameras_binary(os.path.join(path, "cameras" + ext)) + 
images = read_images_binary(os.path.join(path, "images" + ext)) + points3D = read_points3D_binary(os.path.join(path, "points3D") + ext) + return cameras, images, points3D + + +def write_model(cameras, images, points3D, path, ext=".bin"): + if ext == ".txt": + write_cameras_text(cameras, os.path.join(path, "cameras" + ext)) + write_images_text(images, os.path.join(path, "images" + ext)) + write_points3D_text(points3D, os.path.join(path, "points3D") + ext) + else: + write_cameras_binary(cameras, os.path.join(path, "cameras" + ext)) + write_images_binary(images, os.path.join(path, "images" + ext)) + write_points3D_binary(points3D, os.path.join(path, "points3D") + ext) + return cameras, images, points3D + + +def qvec2rotmat(qvec): + return np.array([ + [1 - 2 * qvec[2]**2 - 2 * qvec[3]**2, + 2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3], + 2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2]], + [2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3], + 1 - 2 * qvec[1]**2 - 2 * qvec[3]**2, + 2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1]], + [2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2], + 2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1], + 1 - 2 * qvec[1]**2 - 2 * qvec[2]**2]]) + + +def rotmat2qvec(R): + Rxx, Ryx, Rzx, Rxy, Ryy, Rzy, Rxz, Ryz, Rzz = R.flat + K = np.array([ + [Rxx - Ryy - Rzz, 0, 0, 0], + [Ryx + Rxy, Ryy - Rxx - Rzz, 0, 0], + [Rzx + Rxz, Rzy + Ryz, Rzz - Rxx - Ryy, 0], + [Ryz - Rzy, Rzx - Rxz, Rxy - Ryx, Rxx + Ryy + Rzz]]) / 3.0 + eigvals, eigvecs = np.linalg.eigh(K) + qvec = eigvecs[[3, 0, 1, 2], np.argmax(eigvals)] + if qvec[0] < 0: + qvec *= -1 + return qvec + + +def main(): + parser = argparse.ArgumentParser(description="Read and write COLMAP binary and text models") + parser.add_argument("--input_model", help="path to input model folder") + parser.add_argument("--input_format", choices=[".bin", ".txt"], + help="input model format", default="") + parser.add_argument("--output_model", + help="path to output model folder") + parser.add_argument("--output_format", 
choices=[".bin", ".txt"], + help="outut model format", default=".txt") + args = parser.parse_args() + + cameras, images, points3D = read_model(path=args.input_model, ext=args.input_format) + + print("num_cameras:", len(cameras)) + print("num_images:", len(images)) + print("num_points3D:", len(points3D)) + + if args.output_model is not None: + write_model(cameras, images, points3D, path=args.output_model, ext=args.output_format) + + +if __name__ == "__main__": + main() diff --git a/imcui/third_party/MatchAnything/src/utils/colmap/__init__.py b/imcui/third_party/MatchAnything/src/utils/colmap/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/imcui/third_party/MatchAnything/src/utils/colmap/database.py b/imcui/third_party/MatchAnything/src/utils/colmap/database.py new file mode 100644 index 0000000000000000000000000000000000000000..81a9c47bad6c522597dfaaf2a85ae0d252b5ab10 --- /dev/null +++ b/imcui/third_party/MatchAnything/src/utils/colmap/database.py @@ -0,0 +1,417 @@ +# Copyright (c) 2018, ETH Zurich and UNC Chapel Hill. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# +# * Neither the name of ETH Zurich and UNC Chapel Hill nor the names of +# its contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. 
+# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE +# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +# POSSIBILITY OF SUCH DAMAGE. +# +# Author: Johannes L. Schoenberger (jsch-at-demuc-dot-de) + +# This script is based on an original implementation by True Price. + +import sys +import sqlite3 +import numpy as np +from loguru import logger + + +IS_PYTHON3 = sys.version_info[0] >= 3 + +MAX_IMAGE_ID = 2**31 - 1 + +CREATE_CAMERAS_TABLE = """CREATE TABLE IF NOT EXISTS cameras ( + camera_id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, + model INTEGER NOT NULL, + width INTEGER NOT NULL, + height INTEGER NOT NULL, + params BLOB, + prior_focal_length INTEGER NOT NULL)""" + +CREATE_DESCRIPTORS_TABLE = """CREATE TABLE IF NOT EXISTS descriptors ( + image_id INTEGER PRIMARY KEY NOT NULL, + rows INTEGER NOT NULL, + cols INTEGER NOT NULL, + data BLOB, + FOREIGN KEY(image_id) REFERENCES images(image_id) ON DELETE CASCADE)""" + +CREATE_IMAGES_TABLE = """CREATE TABLE IF NOT EXISTS images ( + image_id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, + name TEXT NOT NULL UNIQUE, + camera_id INTEGER NOT NULL, + prior_qw REAL, + prior_qx REAL, + prior_qy REAL, + prior_qz REAL, + prior_tx REAL, + prior_ty REAL, + prior_tz REAL, + CONSTRAINT image_id_check CHECK(image_id >= 0 and image_id < {}), + FOREIGN KEY(camera_id) REFERENCES 
cameras(camera_id)) +""".format(MAX_IMAGE_ID) + +CREATE_TWO_VIEW_GEOMETRIES_TABLE = """ +CREATE TABLE IF NOT EXISTS two_view_geometries ( + pair_id INTEGER PRIMARY KEY NOT NULL, + rows INTEGER NOT NULL, + cols INTEGER NOT NULL, + data BLOB, + config INTEGER NOT NULL, + F BLOB, + E BLOB, + H BLOB, + qvec BLOB, + tvec BLOB) +""" + +CREATE_KEYPOINTS_TABLE = """CREATE TABLE IF NOT EXISTS keypoints ( + image_id INTEGER PRIMARY KEY NOT NULL, + rows INTEGER NOT NULL, + cols INTEGER NOT NULL, + data BLOB, + FOREIGN KEY(image_id) REFERENCES images(image_id) ON DELETE CASCADE) +""" + +CREATE_MATCHES_TABLE = """CREATE TABLE IF NOT EXISTS matches ( + pair_id INTEGER PRIMARY KEY NOT NULL, + rows INTEGER NOT NULL, + cols INTEGER NOT NULL, + data BLOB)""" + +CREATE_NAME_INDEX = \ + "CREATE UNIQUE INDEX IF NOT EXISTS index_name ON images(name)" + +CREATE_ALL = "; ".join([ + CREATE_CAMERAS_TABLE, + CREATE_IMAGES_TABLE, + CREATE_KEYPOINTS_TABLE, + CREATE_DESCRIPTORS_TABLE, + CREATE_MATCHES_TABLE, + CREATE_TWO_VIEW_GEOMETRIES_TABLE, + CREATE_NAME_INDEX +]) + + +def image_ids_to_pair_id(image_id1, image_id2): + if image_id1 > image_id2: + image_id1, image_id2 = image_id2, image_id1 + return image_id1 * MAX_IMAGE_ID + image_id2 + + +def pair_id_to_image_ids(pair_id): + image_id2 = pair_id % MAX_IMAGE_ID + image_id1 = (pair_id - image_id2) / MAX_IMAGE_ID + return image_id1, image_id2 + + +def array_to_blob(array): + if IS_PYTHON3: + return array.tobytes() + else: + return np.getbuffer(array) + + +def blob_to_array(blob, dtype, shape=(-1,)): + if IS_PYTHON3: + return np.fromstring(blob, dtype=dtype).reshape(*shape) + else: + return np.frombuffer(blob, dtype=dtype).reshape(*shape) + + +class COLMAPDatabase(sqlite3.Connection): + + @staticmethod + def connect(database_path): + return sqlite3.connect(str(database_path), factory=COLMAPDatabase) + + + def __init__(self, *args, **kwargs): + super(COLMAPDatabase, self).__init__(*args, **kwargs) + + self.create_tables = lambda: 
self.executescript(CREATE_ALL) + self.create_cameras_table = \ + lambda: self.executescript(CREATE_CAMERAS_TABLE) + self.create_descriptors_table = \ + lambda: self.executescript(CREATE_DESCRIPTORS_TABLE) + self.create_images_table = \ + lambda: self.executescript(CREATE_IMAGES_TABLE) + self.create_two_view_geometries_table = \ + lambda: self.executescript(CREATE_TWO_VIEW_GEOMETRIES_TABLE) + self.create_keypoints_table = \ + lambda: self.executescript(CREATE_KEYPOINTS_TABLE) + self.create_matches_table = \ + lambda: self.executescript(CREATE_MATCHES_TABLE) + self.create_name_index = lambda: self.executescript(CREATE_NAME_INDEX) + + def add_camera(self, model, width, height, params, + prior_focal_length=False, camera_id=None): + params = np.asarray(params, np.float64) + cursor = self.execute( + "INSERT INTO cameras VALUES (?, ?, ?, ?, ?, ?)", + (camera_id, model, width, height, array_to_blob(params), + prior_focal_length)) + return cursor.lastrowid + + def add_image(self, name, camera_id, + prior_q=np.full(4, np.NaN), prior_t=np.full(3, np.NaN), + image_id=None): + cursor = self.execute( + "INSERT INTO images VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", + (image_id, name, camera_id, prior_q[0], prior_q[1], prior_q[2], + prior_q[3], prior_t[0], prior_t[1], prior_t[2])) + return cursor.lastrowid + + def add_keypoints(self, image_id, keypoints): + assert(len(keypoints.shape) == 2) + assert(keypoints.shape[1] in [2, 4, 6]) + + keypoints = np.asarray(keypoints, np.float32) + self.execute( + "INSERT INTO keypoints VALUES (?, ?, ?, ?)", + (image_id,) + keypoints.shape + (array_to_blob(keypoints),)) + + def add_descriptors(self, image_id, descriptors): + descriptors = np.ascontiguousarray(descriptors, np.uint8) + self.execute( + "INSERT INTO descriptors VALUES (?, ?, ?, ?)", + (image_id,) + descriptors.shape + (array_to_blob(descriptors),)) + + def add_matches(self, image_id1, image_id2, matches): + assert(len(matches.shape) == 2) + assert(matches.shape[1] == 2) + + if image_id1 
> image_id2: + matches = matches[:,::-1] + + pair_id = image_ids_to_pair_id(image_id1, image_id2) + matches = np.asarray(matches, np.uint32) + self.execute( + "INSERT INTO matches VALUES (?, ?, ?, ?)", + (pair_id,) + matches.shape + (array_to_blob(matches),)) + + def add_two_view_geometry(self, image_id1, image_id2, matches, + F=np.eye(3), E=np.eye(3), H=np.eye(3), + qvec=np.array([1.0, 0.0, 0.0, 0.0]), + tvec=np.zeros(3), config=2): + assert(len(matches.shape) == 2) + assert(matches.shape[1] == 2) + + if image_id1 > image_id2: + matches = matches[:,::-1] + + pair_id = image_ids_to_pair_id(image_id1, image_id2) + matches = np.asarray(matches, np.uint32) + F = np.asarray(F, dtype=np.float64) + E = np.asarray(E, dtype=np.float64) + H = np.asarray(H, dtype=np.float64) + qvec = np.asarray(qvec, dtype=np.float64) + tvec = np.asarray(tvec, dtype=np.float64) + self.execute( + "INSERT INTO two_view_geometries VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", + (pair_id,) + matches.shape + (array_to_blob(matches), config, + array_to_blob(F), array_to_blob(E), array_to_blob(H), + array_to_blob(qvec), array_to_blob(tvec))) + + def update_two_view_geometry(self, image_id1, image_id2, matches, + F=np.eye(3), E=np.eye(3), H=np.eye(3), config=2): + assert(len(matches.shape) == 2) + assert(matches.shape[1] == 2) + + if image_id1 > image_id2: + matches = matches[:,::-1] + + pair_id = image_ids_to_pair_id(image_id1, image_id2) + matches = np.asarray(matches, np.uint32) + F = np.asarray(F, dtype=np.float64) + E = np.asarray(E, dtype=np.float64) + H = np.asarray(H, dtype=np.float64) + + # Find whether exists: + row = self.execute(f"SELECT * FROM two_view_geometries WHERE pair_id = {pair_id} ") + data = list(next(row)) + try: + matches_old = blob_to_array(data[3], np.uint32, (-1, 2)) + except: + matches_old = None + + if matches_old is not None: + for match in matches: + img0_id, img1_id = match + + # Find duplicated pts + img0_dup_idxs = np.where(matches_old[:, 0] == img0_id) + img1_dup_idxs = 
np.where(matches_old[:, 1] == img1_id) + + if len(img0_dup_idxs[0]) == 0 and len(img1_dup_idxs[0]) == 0: + # No duplicated matches: + matches_old = np.concatenate([matches_old, match[None]], axis=0) + elif len(img0_dup_idxs[0]) == 1 and len(img1_dup_idxs[0]) == 0: + matches_old[img0_dup_idxs[0]][0,1] = img1_id + elif len(img0_dup_idxs[0]) == 0 and len(img1_dup_idxs[0]) == 1: + matches_old[img1_dup_idxs[0]][0,0] = img0_id + elif len(img0_dup_idxs[0]) == 1 and len(img1_dup_idxs[0]) == 1: + if img0_dup_idxs[0] != img1_dup_idxs[0]: + # logger.warning(f"Duplicated matches exists!") + matches_old[img0_dup_idxs[0]][0,1] = img1_id + matches_old[img1_dup_idxs[0]][0,0] = img0_id + else: + raise NotImplementedError + + # matches = np.concatenate([matches_old, matches], axis=0) # N * 2 + matches = matches_old + self.execute(f"DELETE FROM two_view_geometries WHERE pair_id = {pair_id}") + + data[1:4] = matches.shape + (array_to_blob(np.asarray(matches, np.uint32)),) + self.execute("INSERT INTO two_view_geometries VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", tuple(data)) + else: + raise NotImplementedError + + # self.add_two_view_geometry(image_id1, image_id2, matches) + + +def example_usage(): + import os + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument("--database_path", default="database.db") + args = parser.parse_args() + + if os.path.exists(args.database_path): + print("ERROR: database path already exists -- will not modify it.") + return + + # Open the database. + + db = COLMAPDatabase.connect(args.database_path) + + # For convenience, try creating all the tables upfront. + + db.create_tables() + + # Create dummy cameras. + + model1, width1, height1, params1 = \ + 0, 1024, 768, np.array((1024., 512., 384.)) + model2, width2, height2, params2 = \ + 2, 1024, 768, np.array((1024., 512., 384., 0.1)) + + camera_id1 = db.add_camera(model1, width1, height1, params1) + camera_id2 = db.add_camera(model2, width2, height2, params2) + + # Create dummy images. 
+ + image_id1 = db.add_image("image1.png", camera_id1) + image_id2 = db.add_image("image2.png", camera_id1) + image_id3 = db.add_image("image3.png", camera_id2) + image_id4 = db.add_image("image4.png", camera_id2) + + # Create dummy keypoints. + # + # Note that COLMAP supports: + # - 2D keypoints: (x, y) + # - 4D keypoints: (x, y, theta, scale) + # - 6D affine keypoints: (x, y, a_11, a_12, a_21, a_22) + + num_keypoints = 1000 + keypoints1 = np.random.rand(num_keypoints, 2) * (width1, height1) + keypoints2 = np.random.rand(num_keypoints, 2) * (width1, height1) + keypoints3 = np.random.rand(num_keypoints, 2) * (width2, height2) + keypoints4 = np.random.rand(num_keypoints, 2) * (width2, height2) + + db.add_keypoints(image_id1, keypoints1) + db.add_keypoints(image_id2, keypoints2) + db.add_keypoints(image_id3, keypoints3) + db.add_keypoints(image_id4, keypoints4) + + # Create dummy matches. + + M = 50 + matches12 = np.random.randint(num_keypoints, size=(M, 2)) + matches23 = np.random.randint(num_keypoints, size=(M, 2)) + matches34 = np.random.randint(num_keypoints, size=(M, 2)) + + db.add_matches(image_id1, image_id2, matches12) + db.add_matches(image_id2, image_id3, matches23) + db.add_matches(image_id3, image_id4, matches34) + + # Commit the data to the file. + + db.commit() + + # Read and check cameras. + + rows = db.execute("SELECT * FROM cameras") + + camera_id, model, width, height, params, prior = next(rows) + params = blob_to_array(params, np.float64) + assert camera_id == camera_id1 + assert model == model1 and width == width1 and height == height1 + assert np.allclose(params, params1) + + camera_id, model, width, height, params, prior = next(rows) + params = blob_to_array(params, np.float64) + assert camera_id == camera_id2 + assert model == model2 and width == width2 and height == height2 + assert np.allclose(params, params2) + + # Read and check keypoints. 
+ + keypoints = dict( + (image_id, blob_to_array(data, np.float32, (-1, 2))) + for image_id, data in db.execute( + "SELECT image_id, data FROM keypoints")) + + assert np.allclose(keypoints[image_id1], keypoints1) + assert np.allclose(keypoints[image_id2], keypoints2) + assert np.allclose(keypoints[image_id3], keypoints3) + assert np.allclose(keypoints[image_id4], keypoints4) + + # Read and check matches. + + pair_ids = [image_ids_to_pair_id(*pair) for pair in + ((image_id1, image_id2), + (image_id2, image_id3), + (image_id3, image_id4))] + + matches = dict( + (pair_id_to_image_ids(pair_id), + blob_to_array(data, np.uint32, (-1, 2))) + for pair_id, data in db.execute("SELECT pair_id, data FROM matches") + ) + + assert np.all(matches[(image_id1, image_id2)] == matches12) + assert np.all(matches[(image_id2, image_id3)] == matches23) + assert np.all(matches[(image_id3, image_id4)] == matches34) + + # Clean up. + + db.close() + + if os.path.exists(args.database_path): + os.remove(args.database_path) + + +if __name__ == "__main__": + example_usage() diff --git a/imcui/third_party/MatchAnything/src/utils/colmap/eval_helper.py b/imcui/third_party/MatchAnything/src/utils/colmap/eval_helper.py new file mode 100644 index 0000000000000000000000000000000000000000..2335a1abbc6321f7e3a4b40123d8386d4900a9d2 --- /dev/null +++ b/imcui/third_party/MatchAnything/src/utils/colmap/eval_helper.py @@ -0,0 +1,232 @@ +import math +import cv2 +import os +import numpy as np +from .read_write_model import read_images_binary + + +def align_model(model, rot, trans, scale): + return (np.matmul(rot, model) + trans) * scale + + +def align(model, data): + ''' + Source: https://vision.in.tum.de/data/datasets/rgbd-dataset/tools + #absolute_trajectory_error_ate + Align two trajectories using the method of Horn (closed-form). 
+ + Input: + model -- first trajectory (3xn) + data -- second trajectory (3xn) + + Output: + rot -- rotation matrix (3x3) + trans -- translation vector (3x1) + trans_error -- translational error per point (1xn) + + ''' + + if model.shape[1] < 3: + print('Need at least 3 points for ATE: {}'.format(model)) + return np.identity(3), np.zeros((3, 1)), 1 + + # Get zero centered point cloud + model_zerocentered = model - model.mean(1, keepdims=True) + data_zerocentered = data - data.mean(1, keepdims=True) + + # constructed covariance matrix + W = np.zeros((3, 3)) + for column in range(model.shape[1]): + W += np.outer(model_zerocentered[:, column], + data_zerocentered[:, column]) + + # SVD + U, d, Vh = np.linalg.linalg.svd(W.transpose()) + S = np.identity(3) + if (np.linalg.det(U) * np.linalg.det(Vh) < 0): + S[2, 2] = -1 + rot = np.matmul(np.matmul(U, S), Vh) + trans = data.mean(1, keepdims=True) - np.matmul( + rot, model.mean(1, keepdims=True)) + + # apply rot and trans to point cloud + model_aligned = align_model(model, rot, trans, 1.0) + model_aligned_zerocentered = model_aligned - model_aligned.mean( + 1, keepdims=True) + + # calc scale based on distance to point cloud center + data_dist = np.sqrt((data_zerocentered * data_zerocentered).sum(axis=0)) + model_aligned_dist = np.sqrt( + (model_aligned_zerocentered * model_aligned_zerocentered).sum(axis=0)) + scale_array = data_dist / model_aligned_dist + scale = np.median(scale_array) + + return rot, trans, scale + + +def quaternion_matrix(quaternion): + '''Return homogeneous rotation matrix from quaternion. 
+ + >>> M = quaternion_matrix([0.99810947, 0.06146124, 0, 0]) + >>> numpy.allclose(M, rotation_matrix(0.123, [1, 0, 0])) + True + >>> M = quaternion_matrix([1, 0, 0, 0]) + >>> numpy.allclose(M, numpy.identity(4)) + True + >>> M = quaternion_matrix([0, 1, 0, 0]) + >>> numpy.allclose(M, numpy.diag([1, -1, -1, 1])) + True + ''' + + q = np.array(quaternion, dtype=np.float64, copy=True) + n = np.dot(q, q) + if n < _EPS: + return np.identity(4) + + q *= math.sqrt(2.0 / n) + q = np.outer(q, q) + + return np.array( + [[1.0 - q[2, 2] - q[3, 3], q[1, 2] - q[3, 0], q[1, 3] + q[2, 0], 0.0], + [q[1, 2] + q[3, 0], 1.0 - q[1, 1] - q[3, 3], q[2, 3] - q[1, 0], 0.0], + [q[1, 3] - q[2, 0], q[2, 3] + q[1, 0], 1.0 - q[1, 1] - q[2, 2], 0.0], + [0.0, 0.0, 0.0, 1.0]]) + + +def quaternion_from_matrix(matrix, isprecise=False): + '''Return quaternion from rotation matrix. + + If isprecise is True, the input matrix is assumed to be a precise rotation + matrix and a faster algorithm is used. + + >>> q = quaternion_from_matrix(numpy.identity(4), True) + >>> numpy.allclose(q, [1, 0, 0, 0]) + True + >>> q = quaternion_from_matrix(numpy.diag([1, -1, -1, 1])) + >>> numpy.allclose(q, [0, 1, 0, 0]) or numpy.allclose(q, [0, -1, 0, 0]) + True + >>> R = rotation_matrix(0.123, (1, 2, 3)) + >>> q = quaternion_from_matrix(R, True) + >>> numpy.allclose(q, [0.9981095, 0.0164262, 0.0328524, 0.0492786]) + True + >>> R = [[-0.545, 0.797, 0.260, 0], [0.733, 0.603, -0.313, 0], + ... [-0.407, 0.021, -0.913, 0], [0, 0, 0, 1]] + >>> q = quaternion_from_matrix(R) + >>> numpy.allclose(q, [0.19069, 0.43736, 0.87485, -0.083611]) + True + >>> R = [[0.395, 0.362, 0.843, 0], [-0.626, 0.796, -0.056, 0], + ... 
[-0.677, -0.498, 0.529, 0], [0, 0, 0, 1]] + >>> q = quaternion_from_matrix(R) + >>> numpy.allclose(q, [0.82336615, -0.13610694, 0.46344705, -0.29792603]) + True + >>> R = random_rotation_matrix() + >>> q = quaternion_from_matrix(R) + >>> is_same_transform(R, quaternion_matrix(q)) + True + >>> R = euler_matrix(0.0, 0.0, numpy.pi/2.0) + >>> numpy.allclose(quaternion_from_matrix(R, isprecise=False), + ... quaternion_from_matrix(R, isprecise=True)) + True + + ''' + + M = np.array(matrix, dtype=np.float64, copy=False)[:4, :4] + if isprecise: + q = np.empty((4, )) + t = np.trace(M) + if t > M[3, 3]: + q[0] = t + q[3] = M[1, 0] - M[0, 1] + q[2] = M[0, 2] - M[2, 0] + q[1] = M[2, 1] - M[1, 2] + else: + i, j, k = 1, 2, 3 + if M[1, 1] > M[0, 0]: + i, j, k = 2, 3, 1 + if M[2, 2] > M[i, i]: + i, j, k = 3, 1, 2 + t = M[i, i] - (M[j, j] + M[k, k]) + M[3, 3] + q[i] = t + q[j] = M[i, j] + M[j, i] + q[k] = M[k, i] + M[i, k] + q[3] = M[k, j] - M[j, k] + q *= 0.5 / math.sqrt(t * M[3, 3]) + else: + m00 = M[0, 0] + m01 = M[0, 1] + m02 = M[0, 2] + m10 = M[1, 0] + m11 = M[1, 1] + m12 = M[1, 2] + m20 = M[2, 0] + m21 = M[2, 1] + m22 = M[2, 2] + + # symmetric matrix K + K = np.array([[m00 - m11 - m22, 0.0, 0.0, 0.0], + [m01 + m10, m11 - m00 - m22, 0.0, 0.0], + [m02 + m20, m12 + m21, m22 - m00 - m11, 0.0], + [m21 - m12, m02 - m20, m10 - m01, m00 + m11 + m22]]) + K /= 3.0 + + # quaternion is eigenvector of K that corresponds to largest eigenvalue + w, V = np.linalg.eigh(K) + q = V[[3, 0, 1, 2], np.argmax(w)] + + if q[0] < 0.0: + np.negative(q, q) + + return q + +def is_colmap_img_valid(colmap_img_file): + '''Return validity of a colmap reconstruction''' + + images_bin = read_images_binary(colmap_img_file) + # Check if everything is finite for this subset + for key in images_bin.keys(): + q = np.asarray(images_bin[key].qvec).flatten() + t = np.asarray(images_bin[key].tvec).flatten() + + is_cur_valid = True + is_cur_valid = is_cur_valid and q.shape == (4, ) + is_cur_valid = is_cur_valid and 
t.shape == (3, ) + is_cur_valid = is_cur_valid and np.all(np.isfinite(q)) + is_cur_valid = is_cur_valid and np.all(np.isfinite(t)) + + # If any is invalid, immediately return + if not is_cur_valid: + return False + + return True + +def get_best_colmap_index(colmap_output_path): + ''' + Determines the colmap model with the most images if there is more than one. + ''' + + # First find the colmap reconstruction with the most number of images. + best_index, best_num_images = -1, 0 + + # Check all valid sub reconstructions. + if os.path.exists(colmap_output_path): + idx_list = [ + _d for _d in os.listdir(colmap_output_path) + if os.path.isdir(os.path.join(colmap_output_path, _d)) + ] + else: + idx_list = [] + + for cur_index in idx_list: + cur_output_path = os.path.join(colmap_output_path, cur_index) + if os.path.isdir(cur_output_path): + colmap_img_file = os.path.join(cur_output_path, 'images.bin') + images_bin = read_images_binary(colmap_img_file) + # Check validity + if not is_colmap_img_valid(colmap_img_file): + continue + # Find the reconstruction with most number of images + if len(images_bin) > best_num_images: + best_index = int(cur_index) + best_num_images = len(images_bin) + + return str(best_index) \ No newline at end of file diff --git a/imcui/third_party/MatchAnything/src/utils/colmap/read_write_model.py b/imcui/third_party/MatchAnything/src/utils/colmap/read_write_model.py new file mode 100644 index 0000000000000000000000000000000000000000..eeb03c3bee0f1d6ffd5285835c83920e11de5b51 --- /dev/null +++ b/imcui/third_party/MatchAnything/src/utils/colmap/read_write_model.py @@ -0,0 +1,509 @@ +# Copyright (c) 2018, ETH Zurich and UNC Chapel Hill. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. 
+# +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# +# * Neither the name of ETH Zurich and UNC Chapel Hill nor the names of +# its contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE +# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +# POSSIBILITY OF SUCH DAMAGE. +# +# Author: Johannes L. 
Schoenberger (jsch-at-demuc-dot-de) + +import os +import sys +import collections +import numpy as np +import struct +import argparse + + +CameraModel = collections.namedtuple( + "CameraModel", ["model_id", "model_name", "num_params"]) +Camera = collections.namedtuple( + "Camera", ["id", "model", "width", "height", "params"]) +BaseImage = collections.namedtuple( + "Image", ["id", "qvec", "tvec", "camera_id", "name", "xys", "point3D_ids"]) +Point3D = collections.namedtuple( + "Point3D", ["id", "xyz", "rgb", "error", "image_ids", "point2D_idxs"]) + + +class Image(BaseImage): + def qvec2rotmat(self): + return qvec2rotmat(self.qvec) + + +CAMERA_MODELS = { + CameraModel(model_id=0, model_name="SIMPLE_PINHOLE", num_params=3), + CameraModel(model_id=1, model_name="PINHOLE", num_params=4), + CameraModel(model_id=2, model_name="SIMPLE_RADIAL", num_params=4), + CameraModel(model_id=3, model_name="RADIAL", num_params=5), + CameraModel(model_id=4, model_name="OPENCV", num_params=8), + CameraModel(model_id=5, model_name="OPENCV_FISHEYE", num_params=8), + CameraModel(model_id=6, model_name="FULL_OPENCV", num_params=12), + CameraModel(model_id=7, model_name="FOV", num_params=5), + CameraModel(model_id=8, model_name="SIMPLE_RADIAL_FISHEYE", num_params=4), + CameraModel(model_id=9, model_name="RADIAL_FISHEYE", num_params=5), + CameraModel(model_id=10, model_name="THIN_PRISM_FISHEYE", num_params=12) +} +CAMERA_MODEL_IDS = dict([(camera_model.model_id, camera_model) + for camera_model in CAMERA_MODELS]) +CAMERA_MODEL_NAMES = dict([(camera_model.model_name, camera_model) + for camera_model in CAMERA_MODELS]) + + +def read_next_bytes(fid, num_bytes, format_char_sequence, endian_character="<"): + """Read and unpack the next bytes from a binary file. + :param fid: + :param num_bytes: Sum of combination of {2, 4, 8}, e.g. 2, 6, 16, 30, etc. + :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}. 
+ :param endian_character: Any of {@, =, <, >, !} + :return: Tuple of read and unpacked values. + """ + data = fid.read(num_bytes) + return struct.unpack(endian_character + format_char_sequence, data) + + +def write_next_bytes(fid, data, format_char_sequence, endian_character="<"): + """pack and write to a binary file. + :param fid: + :param data: data to send, if multiple elements are sent at the same time, + they should be encapsuled either in a list or a tuple + :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}. + should be the same length as the data list or tuple + :param endian_character: Any of {@, =, <, >, !} + """ + if isinstance(data, (list, tuple)): + bytes = struct.pack(endian_character + format_char_sequence, *data) + else: + bytes = struct.pack(endian_character + format_char_sequence, data) + fid.write(bytes) + + +def read_cameras_text(path): + """ + see: src/base/reconstruction.cc + void Reconstruction::WriteCamerasText(const std::string& path) + void Reconstruction::ReadCamerasText(const std::string& path) + """ + cameras = {} + with open(path, "r") as fid: + while True: + line = fid.readline() + if not line: + break + line = line.strip() + if len(line) > 0 and line[0] != "#": + elems = line.split() + camera_id = int(elems[0]) + model = elems[1] + width = int(elems[2]) + height = int(elems[3]) + params = np.array(tuple(map(float, elems[4:]))) + cameras[camera_id] = Camera(id=camera_id, model=model, + width=width, height=height, + params=params) + return cameras + + +def read_cameras_binary(path_to_model_file): + """ + see: src/base/reconstruction.cc + void Reconstruction::WriteCamerasBinary(const std::string& path) + void Reconstruction::ReadCamerasBinary(const std::string& path) + """ + cameras = {} + with open(path_to_model_file, "rb") as fid: + num_cameras = read_next_bytes(fid, 8, "Q")[0] + for _ in range(num_cameras): + camera_properties = read_next_bytes( + fid, num_bytes=24, format_char_sequence="iiQQ") + camera_id = 
def write_cameras_text(cameras, path):
    """
    Write cameras to a text file, one camera per line after a "#" header.

    see: src/base/reconstruction.cc
        void Reconstruction::WriteCamerasText(const std::string& path)
        void Reconstruction::ReadCamerasText(const std::string& path)

    :param cameras: Mapping whose values expose ``id``, ``model``, ``width``,
        ``height`` and an iterable ``params``.
    :param path: Destination file; it is overwritten.
    """
    # The header must be ONE expression. Written as three separate statements,
    # only the first literal is bound and the column legend and the camera
    # count are silently dropped from the file.
    header = (
        "# Camera list with one line of data per camera:\n"
        "# CAMERA_ID, MODEL, WIDTH, HEIGHT, PARAMS[]\n"
        "# Number of cameras: {}\n"
    ).format(len(cameras))
    with open(path, "w") as fid:
        fid.write(header)
        for cam in cameras.values():
            fields = [cam.id, cam.model, cam.width, cam.height, *cam.params]
            fid.write(" ".join(str(field) for field in fields) + "\n")


def write_cameras_binary(cameras, path_to_model_file):
    """
    Write cameras to a binary file and return the ``cameras`` mapping unchanged.

    see: src/base/reconstruction.cc
        void Reconstruction::WriteCamerasBinary(const std::string& path)
        void Reconstruction::ReadCamerasBinary(const std::string& path)
    """
    with open(path_to_model_file, "wb") as fid:
        write_next_bytes(fid, len(cameras), "Q")
        for cam in cameras.values():
            # The binary format stores the numeric model id, not its name.
            model_id = CAMERA_MODEL_NAMES[cam.model].model_id
            write_next_bytes(
                fid, [cam.id, model_id, cam.width, cam.height], "iiQQ")
            for p in cam.params:
                write_next_bytes(fid, float(p), "d")
    return cameras


def read_images_text(path):
    """
    Parse an images text file into a dict of Image records keyed by image id.

    Each image takes two lines: the pose/camera/name line, then the flat
    ``X Y POINT3D_ID`` triplets of its 2D points.

    see: src/base/reconstruction.cc
        void Reconstruction::ReadImagesText(const std::string& path)
        void Reconstruction::WriteImagesText(const std::string& path)
    """
    images = {}
    with open(path, "r") as fid:
        # iter(readline, "") stops at EOF and still lets the loop body pull
        # the second line of each record with another readline().
        for line in iter(fid.readline, ""):
            line = line.strip()
            if not line or line.startswith("#"):
                continue
            elems = line.split()
            image_id = int(elems[0])
            qvec = np.array(tuple(map(float, elems[1:5])))
            tvec = np.array(tuple(map(float, elems[5:8])))
            camera_id = int(elems[8])
            image_name = elems[9]
            # Second line of the record: the 2D observations.
            elems = fid.readline().split()
            xys = np.column_stack([tuple(map(float, elems[0::3])),
                                   tuple(map(float, elems[1::3]))])
            point3D_ids = np.array(tuple(map(int, elems[2::3])))
            images[image_id] = Image(
                id=image_id, qvec=qvec, tvec=tvec,
                camera_id=camera_id, name=image_name,
                xys=xys, point3D_ids=point3D_ids)
    return images
def read_images_binary(path_to_model_file):
    """
    Parse an images binary file into a dict of Image records keyed by image id.

    see: src/base/reconstruction.cc
        void Reconstruction::ReadImagesBinary(const std::string& path)
        void Reconstruction::WriteImagesBinary(const std::string& path)
    """
    images = {}
    with open(path_to_model_file, "rb") as fid:
        num_reg_images = read_next_bytes(fid, 8, "Q")[0]
        for _ in range(num_reg_images):
            # Fixed-size part: image id, 4 quaternion + 3 translation doubles,
            # camera id.
            binary_image_properties = read_next_bytes(
                fid, num_bytes=64, format_char_sequence="idddddddi")
            image_id = binary_image_properties[0]
            qvec = np.array(binary_image_properties[1:5])
            tvec = np.array(binary_image_properties[5:8])
            camera_id = binary_image_properties[8]

            # The name is NUL-terminated. Collect the raw bytes and decode
            # once: decoding byte by byte fails on any multi-byte UTF-8
            # character.
            name_bytes = bytearray()
            current_char = read_next_bytes(fid, 1, "c")[0]
            while current_char != b"\x00":
                name_bytes += current_char
                current_char = read_next_bytes(fid, 1, "c")[0]
            image_name = name_bytes.decode("utf-8")

            num_points2D = read_next_bytes(fid, num_bytes=8,
                                           format_char_sequence="Q")[0]
            # Flat (x, y, point3D_id) triplets.
            x_y_id_s = read_next_bytes(fid, num_bytes=24 * num_points2D,
                                       format_char_sequence="ddq" * num_points2D)
            xys = np.column_stack([tuple(map(float, x_y_id_s[0::3])),
                                   tuple(map(float, x_y_id_s[1::3]))])
            point3D_ids = np.array(tuple(map(int, x_y_id_s[2::3])))
            images[image_id] = Image(
                id=image_id, qvec=qvec, tvec=tvec,
                camera_id=camera_id, name=image_name,
                xys=xys, point3D_ids=point3D_ids)
    return images
def write_images_text(images, path):
    """
    Write images to a text file, two lines per image after a "#" header.

    see: src/base/reconstruction.cc
        void Reconstruction::ReadImagesText(const std::string& path)
        void Reconstruction::WriteImagesText(const std::string& path)

    :param images: Mapping whose values expose ``id``, ``qvec``, ``tvec``,
        ``camera_id``, ``name``, ``xys`` and ``point3D_ids``.
    :param path: Destination file; it is overwritten.
    """
    if len(images) == 0:
        mean_observations = 0
    else:
        mean_observations = sum(
            len(img.point3D_ids) for img in images.values()) / len(images)
    # The header must be ONE expression. Written as separate statements, only
    # the first literal is bound and the legend and statistics never reach
    # the file.
    header = (
        "# Image list with two lines of data per image:\n"
        "# IMAGE_ID, QW, QX, QY, QZ, TX, TY, TZ, CAMERA_ID, NAME\n"
        "# POINTS2D[] as (X, Y, POINT3D_ID)\n"
        "# Number of images: {}, mean observations per image: {}\n"
    ).format(len(images), mean_observations)

    with open(path, "w") as fid:
        fid.write(header)
        for img in images.values():
            image_header = [img.id, *img.qvec, *img.tvec, img.camera_id, img.name]
            fid.write(" ".join(map(str, image_header)) + "\n")

            points_strings = [
                " ".join(map(str, [*xy, point3D_id]))
                for xy, point3D_id in zip(img.xys, img.point3D_ids)
            ]
            fid.write(" ".join(points_strings) + "\n")


def write_images_binary(images, path_to_model_file):
    """
    Write images to a binary file.

    see: src/base/reconstruction.cc
        void Reconstruction::ReadImagesBinary(const std::string& path)
        void Reconstruction::WriteImagesBinary(const std::string& path)
    """
    with open(path_to_model_file, "wb") as fid:
        write_next_bytes(fid, len(images), "Q")
        for img in images.values():
            write_next_bytes(fid, img.id, "i")
            write_next_bytes(fid, img.qvec.tolist(), "dddd")
            write_next_bytes(fid, img.tvec.tolist(), "ddd")
            write_next_bytes(fid, img.camera_id, "i")
            # Encode the whole name once and append the NUL terminator.
            # Packing each character with "c" fails for characters whose
            # UTF-8 encoding is longer than one byte.
            fid.write(img.name.encode("utf-8") + b"\x00")
            write_next_bytes(fid, len(img.point3D_ids), "Q")
            for xy, p3d_id in zip(img.xys, img.point3D_ids):
                write_next_bytes(fid, [*xy, p3d_id], "ddq")


def read_points3D_text(path):
    """
    Parse a points3D text file into a dict of Point3D records keyed by point id.

    see: src/base/reconstruction.cc
        void Reconstruction::ReadPoints3DText(const std::string& path)
        void Reconstruction::WritePoints3DText(const std::string& path)
    """
    points3D = {}
    with open(path, "r") as fid:
        for line in fid:
            line = line.strip()
            # Skip blank lines and "#" comment/header lines.
            if not line or line.startswith("#"):
                continue
            elems = line.split()
            point3D_id = int(elems[0])
            xyz = np.array(tuple(map(float, elems[1:4])))
            rgb = np.array(tuple(map(int, elems[4:7])))
            error = float(elems[7])
            # The track is stored as alternating image_id / point2D_idx values.
            image_ids = np.array(tuple(map(int, elems[8::2])))
            point2D_idxs = np.array(tuple(map(int, elems[9::2])))
            points3D[point3D_id] = Point3D(id=point3D_id, xyz=xyz, rgb=rgb,
                                           error=error, image_ids=image_ids,
                                           point2D_idxs=point2D_idxs)
    return points3D
def read_points3d_binary(path_to_model_file):
    """
    Parse a points3D binary file into a dict of Point3D records keyed by point id.

    see: src/base/reconstruction.cc
        void Reconstruction::ReadPoints3DBinary(const std::string& path)
        void Reconstruction::WritePoints3DBinary(const std::string& path)
    """
    points3D = {}
    with open(path_to_model_file, "rb") as fid:
        num_points = read_next_bytes(fid, 8, "Q")[0]
        for _ in range(num_points):
            # Fixed-size part: id, xyz doubles, rgb bytes, error double.
            binary_point_line_properties = read_next_bytes(
                fid, num_bytes=43, format_char_sequence="QdddBBBd")
            point3D_id = binary_point_line_properties[0]
            xyz = np.array(binary_point_line_properties[1:4])
            rgb = np.array(binary_point_line_properties[4:7])
            # Kept as a 0-d array, as before; the text reader yields a float.
            error = np.array(binary_point_line_properties[7])
            track_length = read_next_bytes(
                fid, num_bytes=8, format_char_sequence="Q")[0]
            # Alternating image_id / point2D_idx values.
            track_elems = read_next_bytes(
                fid, num_bytes=8 * track_length,
                format_char_sequence="ii" * track_length)
            image_ids = np.array(tuple(map(int, track_elems[0::2])))
            point2D_idxs = np.array(tuple(map(int, track_elems[1::2])))
            points3D[point3D_id] = Point3D(
                id=point3D_id, xyz=xyz, rgb=rgb,
                error=error, image_ids=image_ids,
                point2D_idxs=point2D_idxs)
    return points3D


def write_points3D_text(points3D, path):
    """
    Write 3D points to a text file, one point per line after a "#" header.

    see: src/base/reconstruction.cc
        void Reconstruction::ReadPoints3DText(const std::string& path)
        void Reconstruction::WritePoints3DText(const std::string& path)

    :param points3D: Mapping whose values expose ``id``, ``xyz``, ``rgb``,
        ``error``, ``image_ids`` and ``point2D_idxs``.
    :param path: Destination file; it is overwritten.
    """
    if len(points3D) == 0:
        mean_track_length = 0
    else:
        mean_track_length = sum(
            len(pt.image_ids) for pt in points3D.values()) / len(points3D)
    # The header must be ONE expression. Written as separate statements, only
    # the first literal is bound and the legend and statistics never reach
    # the file.
    header = (
        "# 3D point list with one line of data per point:\n"
        "# POINT3D_ID, X, Y, Z, R, G, B, ERROR, TRACK[] as (IMAGE_ID, POINT2D_IDX)\n"
        "# Number of points: {}, mean track length: {}\n"
    ).format(len(points3D), mean_track_length)

    with open(path, "w") as fid:
        fid.write(header)
        for pt in points3D.values():
            point_header = [pt.id, *pt.xyz, *pt.rgb, pt.error]
            fid.write(" ".join(map(str, point_header)) + " ")
            track_strings = [
                " ".join(map(str, [image_id, point2D]))
                for image_id, point2D in zip(pt.image_ids, pt.point2D_idxs)
            ]
            fid.write(" ".join(track_strings) + "\n")
def write_points3d_binary(points3D, path_to_model_file):
    """
    Write 3D points to a binary file.

    see: src/base/reconstruction.cc
        void Reconstruction::ReadPoints3DBinary(const std::string& path)
        void Reconstruction::WritePoints3DBinary(const std::string& path)
    """
    with open(path_to_model_file, "wb") as fid:
        write_next_bytes(fid, len(points3D), "Q")
        for pt in points3D.values():
            write_next_bytes(fid, pt.id, "Q")
            write_next_bytes(fid, pt.xyz.tolist(), "ddd")
            write_next_bytes(fid, pt.rgb.tolist(), "BBB")
            write_next_bytes(fid, pt.error, "d")
            track_length = pt.image_ids.shape[0]
            write_next_bytes(fid, track_length, "Q")
            for image_id, point2D_id in zip(pt.image_ids, pt.point2D_idxs):
                write_next_bytes(fid, [image_id, point2D_id], "ii")


def detect_model_format(path, ext):
    """Return True (and print the format) if all three model files with
    extension ``ext`` exist in ``path``; otherwise return False."""
    required = ("cameras", "images", "points3D")
    if all(os.path.isfile(os.path.join(path, name + ext)) for name in required):
        print("Detected model format: '" + ext + "'")
        return True

    return False


def read_model(path, ext=""):
    """Read cameras, images and 3D points from the model folder ``path``.

    :param path: Folder holding the cameras/images/points3D files.
    :param ext: ".txt" for the text readers; any other non-empty value uses
        the binary readers. "" auto-detects, trying ".bin" before ".txt".
    :return: ``(cameras, images, points3D)``, or ``None`` (after printing a
        hint) when ``ext`` is "" and no complete model is found.
    """
    # try to detect the extension automatically
    if ext == "":
        if detect_model_format(path, ".bin"):
            ext = ".bin"
        elif detect_model_format(path, ".txt"):
            ext = ".txt"
        else:
            print("Provide model format: '.bin' or '.txt'")
            return None

    if ext == ".txt":
        cameras = read_cameras_text(os.path.join(path, "cameras" + ext))
        images = read_images_text(os.path.join(path, "images" + ext))
        points3D = read_points3D_text(os.path.join(path, "points3D" + ext))
    else:
        cameras = read_cameras_binary(os.path.join(path, "cameras" + ext))
        images = read_images_binary(os.path.join(path, "images" + ext))
        points3D = read_points3d_binary(os.path.join(path, "points3D" + ext))
    return cameras, images, points3D


def write_model(cameras, images, points3D, path, ext=".bin"):
    """Write the three model files into ``path`` and return the inputs.

    ``ext == ".txt"`` selects the text writers; any other value selects the
    binary writers.
    """
    if ext == ".txt":
        write_cameras_text(cameras, os.path.join(path, "cameras" + ext))
        write_images_text(images, os.path.join(path, "images" + ext))
        write_points3D_text(points3D, os.path.join(path, "points3D" + ext))
    else:
        write_cameras_binary(cameras, os.path.join(path, "cameras" + ext))
        write_images_binary(images, os.path.join(path, "images" + ext))
        write_points3d_binary(points3D, os.path.join(path, "points3D" + ext))
    return cameras, images, points3D


def qvec2rotmat(qvec):
    """Convert a quaternion, indexed in the (QW, QX, QY, QZ) order used by the
    image records, into a 3x3 rotation matrix."""
    return np.array([
        [1 - 2 * qvec[2]**2 - 2 * qvec[3]**2,
         2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3],
         2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2]],
        [2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3],
         1 - 2 * qvec[1]**2 - 2 * qvec[3]**2,
         2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1]],
        [2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2],
         2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1],
         1 - 2 * qvec[1]**2 - 2 * qvec[2]**2]])


def rotmat2qvec(R):
    """Convert a 3x3 rotation matrix into a quaternion whose first component
    is non-negative (inverse of qvec2rotmat up to sign)."""
    Rxx, Ryx, Rzx, Rxy, Ryy, Rzy, Rxz, Ryz, Rzz = R.flat
    # Only the lower triangle is filled in: np.linalg.eigh reads the lower
    # triangle by default and treats the matrix as symmetric.
    K = np.array([
        [Rxx - Ryy - Rzz, 0, 0, 0],
        [Ryx + Rxy, Ryy - Rxx - Rzz, 0, 0],
        [Rzx + Rxz, Rzy + Ryz, Rzz - Rxx - Ryy, 0],
        [Ryz - Rzy, Rzx - Rxz, Rxy - Ryx, Rxx + Ryy + Rzz]]) / 3.0
    eigvals, eigvecs = np.linalg.eigh(K)
    # Take the eigenvector of the largest eigenvalue, reordered so that the
    # last component comes first.
    qvec = eigvecs[[3, 0, 1, 2], np.argmax(eigvals)]
    if qvec[0] < 0:
        qvec *= -1
    return qvec


def main():
    """Command-line entry point: read a model, print its element counts and
    optionally write it out again in the requested format."""
    parser = argparse.ArgumentParser(description="Read and write COLMAP binary and text models")
    # Required: without a folder the readers would fail inside os.path.join.
    parser.add_argument("--input_model", required=True,
                        help="path to input model folder")
    parser.add_argument("--input_format", choices=[".bin", ".txt"],
                        help="input model format", default="")
    parser.add_argument("--output_model",
                        help="path to output model folder")
    parser.add_argument("--output_format", choices=[".bin", ".txt"],
                        help="output model format", default=".txt")
    args = parser.parse_args()

    model = read_model(path=args.input_model, ext=args.input_format)
    if model is None:
        # read_model() already printed its hint; exit cleanly instead of
        # failing on tuple-unpacking None.
        parser.error("no complete model found in '{}'".format(args.input_model))
    cameras, images, points3D = model

    print("num_cameras:", len(cameras))
    print("num_images:", len(images))
    print("num_points3D:", len(points3D))

    if args.output_model is not None:
        write_model(cameras, images, points3D, path=args.output_model, ext=args.output_format)


if __name__ == "__main__":
    main()
+This variable is set when processes are spawned by `launch()` in "engine/launch.py". +""" + + +def get_world_size() -> int: + if not dist.is_available(): + return 1 + if not dist.is_initialized(): + return 1 + return dist.get_world_size() + + +def get_rank() -> int: + if not dist.is_available(): + return 0 + if not dist.is_initialized(): + return 0 + return dist.get_rank() + + +def get_local_rank() -> int: + """ + Returns: + The rank of the current process within the local (per-machine) process group. + """ + if not dist.is_available(): + return 0 + if not dist.is_initialized(): + return 0 + assert _LOCAL_PROCESS_GROUP is not None + return dist.get_rank(group=_LOCAL_PROCESS_GROUP) + + +def get_local_size() -> int: + """ + Returns: + The size of the per-machine process group, + i.e. the number of processes per machine. + """ + if not dist.is_available(): + return 1 + if not dist.is_initialized(): + return 1 + return dist.get_world_size(group=_LOCAL_PROCESS_GROUP) + + +def is_main_process() -> bool: + return get_rank() == 0 + + +def synchronize(): + """ + Helper function to synchronize (barrier) among all processes when + using distributed training + """ + if not dist.is_available(): + return + if not dist.is_initialized(): + return + world_size = dist.get_world_size() + if world_size == 1: + return + dist.barrier() + + +@functools.lru_cache() +def _get_global_gloo_group(): + """ + Return a process group based on gloo backend, containing all the ranks + The result is cached. 
+ """ + if dist.get_backend() == "nccl": + return dist.new_group(backend="gloo") + else: + return dist.group.WORLD + + +def _serialize_to_tensor(data, group): + backend = dist.get_backend(group) + assert backend in ["gloo", "nccl"] + device = torch.device("cpu" if backend == "gloo" else "cuda") + + buffer = pickle.dumps(data) + if len(buffer) > 1024 ** 3: + logger = logging.getLogger(__name__) + logger.warning( + "Rank {} trying to all-gather {:.2f} GB of data on device {}".format( + get_rank(), len(buffer) / (1024 ** 3), device + ) + ) + storage = torch.ByteStorage.from_buffer(buffer) + tensor = torch.ByteTensor(storage).to(device=device) + return tensor + + +def _pad_to_largest_tensor(tensor, group): + """ + Returns: + list[int]: size of the tensor, on each rank + Tensor: padded tensor that has the max size + """ + world_size = dist.get_world_size(group=group) + assert ( + world_size >= 1 + ), "comm.gather/all_gather must be called from ranks within the given group!" + local_size = torch.tensor([tensor.numel()], dtype=torch.int64, device=tensor.device) + size_list = [ + torch.zeros([1], dtype=torch.int64, device=tensor.device) for _ in range(world_size) + ] + dist.all_gather(size_list, local_size, group=group) + + size_list = [int(size.item()) for size in size_list] + + max_size = max(size_list) + + # we pad the tensor because torch all_gather does not support + # gathering tensors of different shapes + if local_size != max_size: + padding = torch.zeros((max_size - local_size,), dtype=torch.uint8, device=tensor.device) + tensor = torch.cat((tensor, padding), dim=0) + return size_list, tensor + + +def all_gather(data, group=None): + """ + Run all_gather on arbitrary picklable data (not necessarily tensors). + + Args: + data: any picklable object + group: a torch process group. By default, will use a group which + contains all ranks on gloo backend. 
def all_gather(data, group=None):
    """
    Run all_gather on arbitrary picklable data (not necessarily tensors).

    Args:
        data: any picklable object
        group: a torch process group. By default, will use a group which
            contains all ranks on gloo backend.

    Returns:
        list[data]: list of data gathered from each rank
    """
    if get_world_size() == 1:
        return [data]
    if group is None:
        group = _get_global_gloo_group()
    if dist.get_world_size(group) == 1:
        return [data]

    serialized = _serialize_to_tensor(data, group)
    size_list, serialized = _pad_to_largest_tensor(serialized, group)
    max_size = max(size_list)

    # receiving Tensor from all ranks
    received = [
        torch.empty((max_size,), dtype=torch.uint8, device=serialized.device)
        for _ in size_list
    ]
    dist.all_gather(received, serialized, group=group)

    # Drop each rank's padding before unpickling.
    return [
        pickle.loads(chunk.cpu().numpy().tobytes()[:size])
        for size, chunk in zip(size_list, received)
    ]


def gather(data, dst=0, group=None):
    """
    Run gather on arbitrary picklable data (not necessarily tensors).

    Args:
        data: any picklable object
        dst (int): destination rank
        group: a torch process group. By default, will use a group which
            contains all ranks on gloo backend.

    Returns:
        list[data]: on dst, a list of data gathered from each rank. Otherwise,
            an empty list.
    """
    if get_world_size() == 1:
        return [data]
    if group is None:
        group = _get_global_gloo_group()
    if dist.get_world_size(group=group) == 1:
        return [data]
    rank = dist.get_rank(group=group)

    serialized = _serialize_to_tensor(data, group)
    size_list, serialized = _pad_to_largest_tensor(serialized, group)

    if rank != dst:
        # Non-destination ranks only send.
        dist.gather(serialized, [], dst=dst, group=group)
        return []

    # receiving Tensor from all ranks
    max_size = max(size_list)
    received = [
        torch.empty((max_size,), dtype=torch.uint8, device=serialized.device)
        for _ in size_list
    ]
    dist.gather(serialized, received, dst=dst, group=group)

    return [
        pickle.loads(chunk.cpu().numpy().tobytes()[:size])
        for size, chunk in zip(size_list, received)
    ]


def shared_random_seed():
    """
    Returns:
        int: a random number that is the same across all workers.
            If workers need a shared RNG, they can use this shared seed to
            create one.

    All workers must call this function, otherwise it will deadlock.
    """
    local_seed = np.random.randint(2 ** 31)
    # The first entry of the gathered list is returned on every worker.
    return all_gather(local_seed)[0]
def reduce_dict(input_dict, average=True):
    """
    Reduce the values in the dictionary from all processes so that process with rank
    0 has the reduced results.

    Args:
        input_dict (dict): inputs to be reduced. All the values must be scalar CUDA Tensor.
        average (bool): whether to do average or sum

    Returns:
        a dict with the same keys as input_dict, after reduction.
    """
    world_size = get_world_size()
    if world_size < 2:
        return input_dict
    with torch.no_grad():
        # sort the keys so that they are consistent across processes
        names = sorted(input_dict.keys())
        values = torch.stack([input_dict[name] for name in names], dim=0)
        dist.reduce(values, dst=0)
        if dist.get_rank() == 0 and average:
            # only main process gets accumulated, so only divide by
            # world_size in this case
            values /= world_size
        reduced_dict = dict(zip(names, values))
    return reduced_dict


# --- PL-DATAMODULE ---

def get_local_split(items: list, world_size: int, rank: int, seed: int):
    """ The local rank only loads a split of the dataset. """
    remainder = len(items) % world_size
    padded_items = np.random.RandomState(seed).permutation(items)
    if remainder != 0:
        # Top up with randomly re-drawn items so every rank gets an
        # equally sized slice.
        padding = np.random.RandomState(seed).choice(
            items,
            world_size - remainder,
            replace=True)
        padded_items = np.concatenate([padded_items, padding])
    assert len(padded_items) % world_size == 0, \
        f'len(padded_items): {len(padded_items)}; world_size: {world_size}; len(padding): {len(padding)}'
    n_per_rank = len(padded_items) // world_size
    first = n_per_rank * rank
    return padded_items[first: first + n_per_rank]
+except Exception as e: + logger.error(f"Error captured:{e}") + +try: + # for internel use only + from pcache_fileio import fileio +except Exception: + MEGADEPTH_CLIENT = SCANNET_CLIENT = None + +# --- DATA IO --- + +def load_pfm(pfm_path): + with open(pfm_path, 'rb') as fin: + color = None + width = None + height = None + scale = None + data_type = None + header = str(fin.readline().decode('UTF-8')).rstrip() + + if header == 'PF': + color = True + elif header == 'Pf': + color = False + else: + raise Exception('Not a PFM file.') + + dim_match = re.match(r'^(\d+)\s(\d+)\s$', fin.readline().decode('UTF-8')) + if dim_match: + width, height = map(int, dim_match.groups()) + else: + raise Exception('Malformed PFM header.') + scale = float((fin.readline().decode('UTF-8')).rstrip()) + if scale < 0: # little-endian + data_type = ' 5000: + logger.error(f"Try to load: {pcache_path}, but failed {failed_num} times") + continue + else: + load_failed = True + failed_num = 0 + while load_failed: + try: + with pcache_fs.open(str(pcache_path), 'rb') as f: + data = np.array(h5py.File(io.BytesIO(f.read()), 'r')['/depth']) + load_failed = False + except: + failed_num += 1 + if failed_num > 5000: + logger.error(f"Try to load: {pcache_path}, but failed {failed_num} times") + continue + + except Exception as ex: + print(f"==> Data loading failure: {path}") + raise ex + + assert data is not None + return data + + +def imread_gray(path, augment_fn=None, cv_type=None): + if path.startswith('oss://'): + path = path.replace(OSS_FOLDER_PATH, PCACHE_FOLDER_PATH) + if path.startswith('pcache://'): + path = path[:9] + path[9:].replace('////', '/').replace('///', '/').replace('//', '/') # remove all continuous '/' + + if cv_type is None: + cv_type = cv2.IMREAD_GRAYSCALE if augment_fn is None \ + else cv2.IMREAD_COLOR + if str(path).startswith('oss://') or str(path).startswith('pcache://'): + image = load_array_from_pcache(str(path), cv_type) + else: + image = cv2.imread(str(path), cv_type) + + if 
def imread_color(path, augment_fn=None):
    """Load an image as an RGB array, from pcache storage or a local path.

    Args:
        path (str): local path, or an 'oss://' / 'pcache://' location.
        augment_fn (callable, optional): applied to the RGB array before it
            is returned.
    Returns:
        np.ndarray built from the PIL image converted to "RGB"
        (after ``augment_fn`` if one is given).
    """
    if path.startswith('oss://'):
        path = path.replace(OSS_FOLDER_PATH, PCACHE_FOLDER_PATH)
    if path.startswith('pcache://'):
        path = path[:9] + path[9:].replace('////', '/').replace('///', '/').replace('//', '/')  # remove all continuous '/'

    if str(path).startswith('oss://') or str(path).startswith('pcache://'):
        filename = path.split(root_dir)[1]
        pcache_path = Path(root_dir) / filename
        load_failed = True
        failed_num = 0
        # Remote reads are retried until one succeeds.
        while load_failed:
            try:
                with pcache_fs.open(str(pcache_path), 'rb') as f:
                    pil_image = Image.open(f).convert("RGB")
                load_failed = False
            except Exception:
                # Not a bare `except:` — that would also swallow
                # KeyboardInterrupt/SystemExit and make this unbounded
                # retry loop impossible to interrupt.
                failed_num += 1
                if failed_num > 5000:
                    logger.error(f"Try to load: {pcache_path}, but failed {failed_num} times")
                continue
    else:
        pil_image = Image.open(str(path)).convert("RGB")
    image = np.array(pil_image)

    if augment_fn is not None:
        image = augment_fn(image)
    return image  # RGB array, (h, w, 3) before augmentation


def get_resized_wh(w, h, resize=None):
    """Return (w, h) scaled so the longer edge equals ``resize``; unchanged
    when ``resize`` is None."""
    if resize is None:
        return w, h
    # resize the longer edge
    scale = resize / max(h, w)
    return int(round(w * scale)), int(round(h * scale))


def get_divisible_wh(w, h, df=None):
    """Return (w, h) rounded down to multiples of ``df``; unchanged when
    ``df`` is None."""
    if df is None:
        return w, h
    return int(w // df * df), int(h // df * df)


def pad_bottom_right(inp, pad_size, ret_mask=False):
    """Zero-pad the last two axes of ``inp`` to ``pad_size`` x ``pad_size``,
    keeping the data in the top-left corner.

    Args:
        inp (np.ndarray): 2-D (h, w) or 3-D (c, h, w) array.
        pad_size (int): target edge length, at least max(h, w).
        ret_mask (bool): also return a boolean (pad_size, pad_size) mask
            that is True where original data was copied.
    Returns:
        (padded, mask); ``mask`` is None unless ``ret_mask`` is set.
    Raises:
        NotImplementedError: for inputs that are neither 2-D nor 3-D.
    """
    assert isinstance(pad_size, int) and pad_size >= max(inp.shape[-2:]), f"{pad_size} < {max(inp.shape[-2:])}"
    mask = None
    if inp.ndim == 2:
        rows, cols = inp.shape
        padded = np.zeros((pad_size, pad_size), dtype=inp.dtype)
        padded[:rows, :cols] = inp
        if ret_mask:
            mask = np.zeros((pad_size, pad_size), dtype=bool)
            mask[:rows, :cols] = True
    elif inp.ndim == 3:
        channels, rows, cols = inp.shape
        padded = np.zeros((channels, pad_size, pad_size), dtype=inp.dtype)
        padded[:, :rows, :cols] = inp
        if ret_mask:
            mask = np.zeros((channels, pad_size, pad_size), dtype=bool)
            mask[:, :rows, :cols] = True
            # Every channel shares the same valid region; keep one plane.
            mask = mask[0]
    else:
        raise NotImplementedError()
    return padded, mask
inp.ndim == 3: + padded = np.zeros((inp.shape[0], pad_size, pad_size), dtype=inp.dtype) + padded[:, :inp.shape[1], :inp.shape[2]] = inp + if ret_mask: + mask = np.zeros((inp.shape[0], pad_size, pad_size), dtype=bool) + mask[:, :inp.shape[1], :inp.shape[2]] = True + mask = mask[0] + else: + raise NotImplementedError() + return padded, mask + + +# --- MEGADEPTH --- + +def read_megadepth_gray(path, resize=None, df=None, padding=False, augment_fn=None, read_gray=True, normalize_img=False, resize_by_stretch=False): + """ + Args: + resize (int, optional): the longer edge of resized images. None for no resize. + padding (bool): If set to 'True', zero-pad resized images to squared size. + augment_fn (callable, optional): augments images with pre-defined visual effects + Returns: + image (torch.tensor): (1, h, w) + mask (torch.tensor): (h, w) + scale (torch.tensor): [w/w_new, h/h_new] + """ + # read image + if read_gray: + image = imread_gray(path, augment_fn) + else: + image = imread_color(path, augment_fn) + + # resize image + try: + w, h = image.shape[1], image.shape[0] + except: + logger.error(f"{path} not exist or read image error!") + if resize_by_stretch: + w_new, h_new = (resize, resize) if isinstance(resize, int) else (resize[1], resize[0]) + else: + if resize: + if not isinstance(resize, int): + assert resize[0] == resize[1] + resize = resize[0] + w_new, h_new = get_resized_wh(w, h, resize) + w_new, h_new = get_divisible_wh(w_new, h_new, df) + else: + w_new, h_new = w, h + + image = cv2.resize(image, (w_new, h_new)) + scale = torch.tensor([w/w_new, h/h_new], dtype=torch.float) + origin_img_size = torch.tensor([h, w], dtype=torch.float) + + if not read_gray: + image = image.transpose(2,0,1) + + if padding: # padding + pad_to = max(h_new, w_new) + image, mask = pad_bottom_right(image, pad_to, ret_mask=True) + else: + mask = None + + if len(image.shape) == 2: + image = torch.from_numpy(image).float()[None] / 255 # (h, w) -> (1, h, w) and normalized + else: + image = 
torch.from_numpy(image).float() / 255 # (h, w) -> (1, h, w) and normalized + if mask is not None: + mask = torch.from_numpy(mask) + + if image.shape[0] == 3 and normalize_img: + # Normalize image: + image = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(image) # Input: 3*H*W + + return image, mask, scale, origin_img_size + +def read_megadepth_gray_sample_homowarp(path, resize=None, df=None, padding=False, augment_fn=None, read_gray=True, normalize_img=False, resize_by_stretch=False): + """ + Args: + resize (int, optional): the longer edge of resized images. None for no resize. + padding (bool): If set to 'True', zero-pad resized images to squared size. + augment_fn (callable, optional): augments images with pre-defined visual effects + Returns: + image (torch.tensor): (1, h, w) + mask (torch.tensor): (h, w) + scale (torch.tensor): [w/w_new, h/h_new] + """ + # read image + if read_gray: + image = imread_gray(path, augment_fn) + else: + image = imread_color(path, augment_fn) + + # resize image + w, h = image.shape[1], image.shape[0] + if resize_by_stretch: + w_new, h_new = (resize, resize) if isinstance(resize, int) else (resize[1], resize[0]) + else: + if not isinstance(resize, int): + assert resize[0] == resize[1] + resize = resize[0] + w_new, h_new = get_resized_wh(w, h, resize) + w_new, h_new = get_divisible_wh(w_new, h_new, df) + + w_new, h_new = get_divisible_wh(w_new, h_new, df) + + origin_img_size = torch.tensor([h, w], dtype=torch.float) + + # Sample homography and warp: + homo_sampled = sample_homography_sap(h, w) # 3*3 + homo_sampled_normed = normalize_homography( + torch.from_numpy(homo_sampled[None]).to(torch.float32), + (h, w), + (h, w), + ) + + if len(image.shape) == 2: + image = torch.from_numpy(image).float()[None, None] / 255 # B * C * H * W + else: + image = torch.from_numpy(image).float().permute(2,0,1)[None] / 255 + + homo_warpped_image = homography_warp( + image, # 1 * C * H * W + torch.linalg.inv(homo_sampled_normed), + (h, 
w), + ) + image = (homo_warpped_image[0].permute(1,2,0).numpy() * 255).astype(np.uint8) + norm_pixel_mat = normal_transform_pixel(h, w) # 1 * 3 * 3 + + image = cv2.resize(image, (w_new, h_new)) + scale = torch.tensor([w/w_new, h/h_new], dtype=torch.float) + + if not read_gray: + image = image.transpose(2,0,1) + + if padding: # padding + pad_to = max(h_new, w_new) + image, mask = pad_bottom_right(image, pad_to, ret_mask=True) + else: + mask = None + + if len(image.shape) == 2: + image = torch.from_numpy(image).float()[None] / 255 # (h, w) -> (1, h, w) and normalized + else: + image = torch.from_numpy(image).float() / 255 # (h, w) -> (1, h, w) and normalized + if mask is not None: + mask = torch.from_numpy(mask) + + if image.shape[0] == 3 and normalize_img: + # Normalize image: + image = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(image) # Input: 3*H*W + + return image, mask, scale, origin_img_size, norm_pixel_mat[0], homo_sampled_normed[0] + + +def read_megadepth_depth_gray(path, resize=None, df=None, padding=False, augment_fn=None, read_gray=True, normalize_img=False, resize_by_stretch=False): + """ + Args: + resize (int, optional): the longer edge of resized images. None for no resize. + padding (bool): If set to 'True', zero-pad resized images to squared size. 
+ augment_fn (callable, optional): augments images with pre-defined visual effects + Returns: + image (torch.tensor): (1, h, w) + mask (torch.tensor): (h, w) + scale (torch.tensor): [w/w_new, h/h_new] + """ + depth = read_megadepth_depth(path, return_tensor=False) + + # following controlnet 1-depth + depth = depth.astype(np.float64) + depth_non_zero = depth[depth!=0] + vmin = np.percentile(depth_non_zero, 2) + vmax = np.percentile(depth_non_zero, 85) + depth -= vmin + depth /= (vmax - vmin + 1e-4) + depth = 1.0 - depth + image = (depth * 255.0).clip(0, 255).astype(np.uint8) + + # resize image + w, h = image.shape[1], image.shape[0] + if resize_by_stretch: + w_new, h_new = (resize, resize) if isinstance(resize, int) else (resize[1], resize[0]) + else: + if not isinstance(resize, int): + assert resize[0] == resize[1] + resize = resize[0] + w_new, h_new = get_resized_wh(w, h, resize) + w_new, h_new = get_divisible_wh(w_new, h_new, df) + w_new, h_new = get_divisible_wh(w_new, h_new, df) + origin_img_size = torch.tensor([h, w], dtype=torch.float) + + image = cv2.resize(image, (w_new, h_new)) + scale = torch.tensor([w/w_new, h/h_new], dtype=torch.float) + + if padding: # padding + pad_to = max(h_new, w_new) + image, mask = pad_bottom_right(image, pad_to, ret_mask=True) + else: + mask = None + + if read_gray: + image = torch.from_numpy(image).float()[None] / 255 # (h, w) -> (1, h, w) and normalized + else: + image = np.stack([image]*3) # 3 * H * W + image = torch.from_numpy(image).float() / 255 # (h, w) -> (1, h, w) and normalized + if mask is not None: + mask = torch.from_numpy(mask) + + if image.shape[0] == 3 and normalize_img: + # Normalize image: + image = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(image) # Input: 3*H*W + + return image, mask, scale, origin_img_size + +def read_megadepth_depth(path, pad_to=None, return_tensor=True): + if path.startswith('oss://'): + path = path.replace(OSS_FOLDER_PATH, PCACHE_FOLDER_PATH) + if 
path.startswith('pcache://'): + path = path[:9] + path[9:].replace('////', '/').replace('///', '/').replace('//', '/') # remove all continuous '/' + + load_failed = True + failed_num = 0 + while load_failed: + try: + if '.png' in path: + if 'scannet_plus' in path: + depth = imread_gray(path, cv_type=cv2.IMREAD_UNCHANGED).astype(np.float32) + + with open(path, 'rb') as f: + # CO3D + depth = np.asarray(Image.open(f)).astype(np.float32) + depth = depth / 1000 + elif '.pfm' in path: + # For BlendedMVS dataset (not support pcache): + depth = load_pfm(path).copy() + else: + # For MegaDepth + if str(path).startswith('oss://') or str(path).startswith('pcache://'): + depth = load_array_from_pcache(path, None, use_h5py=True) + else: + depth = np.array(h5py.File(path, 'r')['depth']) + load_failed = False + except: + failed_num += 1 + if failed_num > 5000: + logger.error(f"Try to load: {path}, but failed {failed_num} times") + continue + + if pad_to is not None: + depth, _ = pad_bottom_right(depth, pad_to, ret_mask=False) + if return_tensor: + depth = torch.from_numpy(depth).float() # (h, w) + return depth + + +# --- ScanNet --- + +def read_scannet_gray(path, resize=(640, 480), augment_fn=None): + """ + Args: + resize (tuple): align image to depthmap, in (w, h). 
+ augment_fn (callable, optional): augments images with pre-defined visual effects + Returns: + image (torch.tensor): (1, h, w) + mask (torch.tensor): (h, w) + scale (torch.tensor): [w/w_new, h/h_new] + """ + # read and resize image + image = imread_gray(path, augment_fn) + image = cv2.resize(image, resize) + + # (h, w) -> (1, h, w) and normalized + image = torch.from_numpy(image).float()[None] / 255 + return image + + +def read_scannet_depth(path): + if str(path).startswith('s3://'): + depth = load_array_from_s3(str(path), SCANNET_CLIENT, cv2.IMREAD_UNCHANGED) + else: + depth = cv2.imread(str(path), cv2.IMREAD_UNCHANGED) + depth = depth / 1000 + depth = torch.from_numpy(depth).float() # (h, w) + return depth + + +def read_scannet_pose(path): + """ Read ScanNet's Camera2World pose and transform it to World2Camera. + + Returns: + pose_w2c (np.ndarray): (4, 4) + """ + cam2world = np.loadtxt(path, delimiter=' ') + world2cam = inv(cam2world) + return world2cam + + +def read_scannet_intrinsic(path): + """ Read ScanNet's intrinsic matrix and return the 3x3 matrix. 
class EasyDict(dict):
    """Dictionary whose items can also be read and written as attributes.

    Nested dicts (including dicts inside lists/tuples) are converted to the
    same class on assignment, so attribute access works recursively. Tuples
    and lists are both stored as lists.

    Get attributes

    >>> d = EasyDict({'foo':3})
    >>> d['foo']
    3
    >>> d.foo
    3
    >>> d.bar
    Traceback (most recent call last):
    ...
    AttributeError: 'EasyDict' object has no attribute 'bar'

    Works recursively

    >>> d = EasyDict({'foo':3, 'bar':{'x':1, 'y':2}})
    >>> isinstance(d.bar, dict)
    True
    >>> d.bar.x
    1

    Bullet-proof

    >>> EasyDict({})
    {}
    >>> EasyDict(d={})
    {}
    >>> EasyDict(None)
    {}
    >>> d = {'a': 1}
    >>> EasyDict(**d)
    {'a': 1}

    The source mapping is never modified

    >>> src = {'a': 1}
    >>> EasyDict(src, b=2)
    {'a': 1, 'b': 2}
    >>> src
    {'a': 1}

    Set attributes

    >>> d = EasyDict()
    >>> d.foo = 3
    >>> d.foo
    3
    >>> d.bar = {'prop': 'value'}
    >>> d.bar.prop
    'value'
    >>> d
    {'foo': 3, 'bar': {'prop': 'value'}}
    >>> d.bar.prop = 'newer'
    >>> d.bar.prop
    'newer'

    Values extraction

    >>> d = EasyDict({'foo':0, 'bar':[{'x':1, 'y':2}, {'x':3, 'y':4}]})
    >>> isinstance(d.bar, list)
    True
    >>> from operator import attrgetter
    >>> list(map(attrgetter('x'), d.bar))
    [1, 3]
    >>> list(map(attrgetter('y'), d.bar))
    [2, 4]
    >>> d = EasyDict()
    >>> list(d.keys())
    []
    >>> d = EasyDict(foo=3, bar=dict(x=1, y=2))
    >>> d.foo
    3
    >>> d.bar.x
    1

    Still like a dict though

    >>> o = EasyDict({'clean':True})
    >>> list(o.items())
    [('clean', True)]

    And like a class

    >>> class Flower(EasyDict):
    ...     power = 1
    ...
    >>> f = Flower()
    >>> f.power
    1
    >>> f = Flower({'height': 12})
    >>> f.height
    12
    >>> f['power']
    1
    >>> sorted(f.keys())
    ['height', 'power']

    update and pop items

    >>> d = EasyDict(a=1, b='2')
    >>> e = EasyDict(c=3.0, a=9.0)
    >>> d.update(e)
    >>> d.c
    3.0
    >>> d['c']
    3.0
    >>> d.get('c')
    3.0
    >>> d.update(a=4, b=4)
    >>> d.b
    4
    >>> d.pop('a')
    4
    >>> d.a
    Traceback (most recent call last):
    ...
    AttributeError: 'EasyDict' object has no attribute 'a'
    """

    def __init__(self, d=None, **kwargs):
        """Populate the instance from mapping ``d`` and/or keyword arguments.

        Args:
            d (mapping, optional): initial items; ``None`` means empty.
            **kwargs: additional items, overriding entries of ``d``.
        """
        # Copy first: merging kwargs into ``d`` itself would silently modify
        # the caller's dictionary.
        items = {} if d is None else dict(d)
        items.update(kwargs)
        for k, v in items.items():
            setattr(self, k, v)
        # Class attributes: mirror non-dunder class-level names (except the
        # overridden dict methods) into the instance so they are also items.
        for k in self.__class__.__dict__.keys():
            if not (k.startswith("__") and k.endswith("__")) and k not in ("update", "pop"):
                setattr(self, k, getattr(self, k))

    def __setattr__(self, name, value):
        """Store ``value`` both as an attribute and as a dict item.

        Dicts are wrapped in this class; lists/tuples are stored as lists
        whose dict elements are wrapped.
        """
        if isinstance(value, (list, tuple)):
            value = [self.__class__(x) if isinstance(x, dict) else x for x in value]
        elif isinstance(value, dict) and not isinstance(value, self.__class__):
            value = self.__class__(value)
        super(EasyDict, self).__setattr__(name, value)
        super(EasyDict, self).__setitem__(name, value)

    # Item assignment goes through the same conversion path.
    __setitem__ = __setattr__

    def update(self, e=None, **f):
        """Merge mapping ``e`` and keyword items ``f`` into this instance."""
        # Copy so that the keyword items are not written into ``e``.
        merged = dict(e or {})
        merged.update(f)
        for k, v in merged.items():
            setattr(self, k, v)

    def pop(self, k, d=None):
        """Remove key ``k`` and return its value, or ``d`` if it is absent."""
        # Only drop the attribute actually stored on the instance; a plain
        # delattr() raises for names that resolve to dict methods such as
        # 'keys' when no such item exists.
        self.__dict__.pop(k, None)
        return super(EasyDict, self).pop(k, d)
def cam2pixel_depth(cam_coords, proj_c2p_rot, proj_c2p_tr):
    """Project camera-frame points to a normalized sampling grid and return their depth.

    Args:
        cam_coords: points in the first camera's coordinate system -- [B, 3, H, W]
        proj_c2p_rot: matrix left-multiplied onto the flattened points, or None to skip
        proj_c2p_tr: offset added after the multiplication, or None to skip
    Returns:
        grid of normalized (x, y) coordinates -- [B, H, W, 2]
        depth (third projected coordinate, clamped to >= 1e-3) -- [B, H, W]
    """
    batch, _, height, width = cam_coords.size()

    projected = cam_coords.reshape(batch, 3, -1)  # [B, 3, H*W]
    if proj_c2p_rot is not None:
        projected = proj_c2p_rot @ projected
    if proj_c2p_tr is not None:
        projected = projected + proj_c2p_tr  # [B, 3, H*W]

    x = projected[:, 0]
    y = projected[:, 1]
    z = projected[:, 2].clamp(min=1e-3)  # avoid dividing by (near-)zero depth

    # Map pixel coordinates so that 0 -> -1 and (size - 1) -> +1.
    u = 2*(x / z)/(width-1) - 1
    v = 2*(y / z)/(height-1) - 1

    grid = torch.stack([u, v], dim=2)  # [B, H*W, 2]
    return grid.reshape(batch, height, width, 2), z.reshape(batch, height, width)
+ Args: + cam_coords: pixel coordinates defined in the first camera coordinates system -- [B, 3, H, W] + proj_c2p_rot: rotation matrix of cameras -- [B, 3, 4] + proj_c2p_tr: translation vectors of cameras -- [B, 3, 1] + Returns: + array of [-1,1] coordinates -- [B, 2, H, W] + """ + b, _, h, w = cam_coords.size() + cam_coords_flat = cam_coords.reshape(b, 3, -1) # [B, 3, H*W] + if proj_c2p_rot is not None: + pcoords = proj_c2p_rot @ cam_coords_flat + else: + pcoords = cam_coords_flat + + if proj_c2p_tr is not None: + pcoords = pcoords + proj_c2p_tr # [B, 3, H*W] + X = pcoords[:, 0] + Y = pcoords[:, 1] + Z = pcoords[:, 2].clamp(min=1e-3) # [B, H*W] min_depth = 1 mm + + X_norm = 2*(X / Z)/(w-1) - 1 # Normalized, -1 if on extreme left, 1 if on extreme right (x = w-1) [B, H*W] + Y_norm = 2*(Y / Z)/(h-1) - 1 # Idem [B, H*W] + + pixel_coords = torch.stack([X_norm, Y_norm], dim=2) # [B, H*W, 2] + return pixel_coords.reshape(b,h,w,2) + + +def reproject_kpts(dim0_idxs, kpts, depth, rel_pose, K0, K1): + """ Reproject keypoints with depth, relative pose and camera intrinsics + Args: + dim0_idxs (torch.LoneTensor): (B*max_kpts, ) + kpts (torch.LongTensor): (B, max_kpts, 2) - + depth (torch.Tensor): (B, H, W) + rel_pose (torch.Tensor): (B, 3, 4) relative transfomation from target to source (T_0to1) -- + K0: (torch.Tensor): (N, 3, 3) - (K_0) + K1: (torch.Tensor): (N, 3, 3) - (K_1) + Returns: + (torch.Tensor): (B, max_kpts, 2) the reprojected kpts + """ + # pixel to camera + device = kpts.device + B, max_kpts, _ = kpts.shape + + kpts = kpts.reshape(-1, 2) # (B*K, 2) + kpts_depth = depth[dim0_idxs, kpts[:, 1], kpts[:, 0]] # (B*K, ) + kpts = torch.cat([kpts.float(), + torch.ones((kpts.shape[0], 1), dtype=torch.float32, device=device)], -1) # (B*K, 3) + pixel_coords = (kpts * kpts_depth[:, None]).reshape(B, max_kpts, 3).permute(0, 2, 1) # (B, 3, K) + + cam_coords = K0.inverse() @ pixel_coords # (N, 3, max_kpts) + # camera1 to camera 2 + rel_pose_R = rel_pose[:, :, :-1] # (B, 3, 3) + 
def check_depth_consistency(b_idxs, kpts0, depth0, kpts1, depth1, T_0to1, K0, K1,
                            atol=0.1, rtol=0.0):
    """Check whether matched keypoints agree in depth after reprojection.

    Each keypoint of image 0 is lifted with its depth, moved into camera 1 with
    ``T_0to1`` and its resulting depth is compared with the depth sampled at the
    paired keypoint in image 1.

    Args:
        b_idxs (torch.LongTensor): (n_kpts, ) batch index of each keypoint pair
        kpts0 (torch.LongTensor): (n_kpts, 2) keypoints in image 0, (x, y) order
        depth0 (torch.Tensor): (B, H, W)
        kpts1 (torch.LongTensor): (n_kpts, 2) keypoints in image 1, (x, y) order
        depth1 (torch.Tensor): (B, H, W)
        T_0to1 (torch.Tensor): (B, 3, 4)
        K0 (torch.Tensor): (N, 3, 3)
        K1 (torch.Tensor): (N, 3, 3)
        atol (float): absolute tolerance of the check
        rtol (float): tolerance relative to the depth in image 1
    Returns:
        valid_mask (torch.Tensor): (n_kpts, ) True where
            abs(kpt_0to1_depth - kpt1_depth) <= atol + rtol * abs(kpt1_depth)
    """
    # Depth lookup: keypoints are (x, y), depth maps are indexed [batch, y, x].
    depth_at_0 = depth0[b_idxs, kpts0[:, 1], kpts0[:, 0]]  # (n_kpts, )
    depth_at_1 = depth1[b_idxs, kpts1[:, 1], kpts1[:, 0]]  # (n_kpts, )

    # Homogeneous pixel coordinates scaled by depth.
    ones = torch.ones((kpts0.shape[0], 1), dtype=torch.float32, device=kpts0.device)
    homog0 = torch.cat([kpts0.float(), ones], -1)  # (n_kpts, 3)
    scaled_pix = (homog0 * depth_at_0[:, None])[..., None]  # (n_kpts, 3, 1)

    # Gather per-keypoint intrinsics and pose so every keypoint is its own batch item.
    K0_per_kpt = K0[b_idxs, :, :]  # (n_kpts, 3, 3)
    pose = T_0to1[b_idxs, :, :]  # (n_kpts, 3, 4)
    pts_cam0 = K0_per_kpt.inverse() @ scaled_pix  # (n_kpts, 3, 1)

    # Camera 0 -> camera 1, then through the second intrinsics.
    pts_cam1 = pose[:, :, :-1] @ pts_cam0 + pose[:, :, -1][..., None]  # (n_kpts, 3, 1)
    K1_per_kpt = K1[b_idxs, :, :]  # (n_kpts, 3, 3)
    projected_depth = (K1_per_kpt @ pts_cam1)[:, -1, 0]  # (n_kpts, )

    tolerance = atol + rtol * depth_at_1.abs()
    return (projected_depth - depth_at_1).abs() <= tolerance
+ + Args: + img: the source image (where to sample pixels) -- [B, 3, H, W] + depth: depth map of the target image -- [B, H, W] + pose: relative transfomation from target to source -- [B, 3, 4] + intrinsics: camera intrinsic matrix -- [B, 3, 3] + Returns: + projected_img: Source image warped to the target image plane + valid_points: Boolean array indicating point validity + """ + # check_sizes(img, 'img', 'B3HW') + check_sizes(depth, 'depth', 'BHW') +# check_sizes(pose, 'pose', 'B6') + check_sizes(intrinsics, 'intrinsics', 'B33') + + batch_size, _, img_height, img_width = img.size() + + cam_coords = pixel2cam(depth, intrinsics.inverse()) # [B,3,H,W] + + pose_mat = pose # (B, 3, 4) + + # Get projection matrix for target camera frame to source pixel frame + proj_cam_to_src_pixel = intrinsics @ pose_mat # [B, 3, 4] + + rot, tr = proj_cam_to_src_pixel[:,:,:3], proj_cam_to_src_pixel[:,:,-1:] + src_pixel_coords = cam2pixel(cam_coords, rot, tr) # [B,H,W,2] + projected_img = F.grid_sample(img, src_pixel_coords, mode=mode, + padding_mode=padding_mode, align_corners=True) + + valid_points = src_pixel_coords.abs().max(dim=-1)[0] <= 1 + + return projected_img, valid_points + +def depth_inverse_warp(depth_source, depth, pose, intrinsic_source, intrinsic, mode='nearest', padding_mode='zeros'): + """ + 1. Inversely warp a source depth map to the target image plane (warped depth map still in source frame) + 2. 
def to_skew(t):
    """Build the skew-symmetric (cross-product) matrix of each vector in ``t``.

    For a vector t = (t0, t1, t2) the result is

        [[  0, -t2,  t1],
         [ t2,   0, -t0],
         [-t1,  t0,   0]]

    so that ``to_skew(t) @ v`` equals the cross product ``t x v``.

    Args:
        t (torch.Tensor): (B, 3)
    Returns:
        torch.Tensor: (B, 3, 3), same dtype and device as ``t``
    """
    # Start from zeros: a skew-symmetric matrix has a zero diagonal. Starting
    # from ones would leave 1s on the diagonal and break t_skew @ v == t x v.
    t_skew = t.new_zeros((t.shape[0], 3, 3))
    t_skew[:, 0, 1] = -t[:, 2]
    t_skew[:, 1, 0] = t[:, 2]
    t_skew[:, 0, 2] = t[:, 1]
    t_skew[:, 2, 0] = -t[:, 1]
    t_skew[:, 1, 2] = -t[:, 0]
    t_skew[:, 2, 1] = t[:, 0]
    return t_skew  # (B, 3, 3)
def angular_rel_pose(T0, T1):
    """Angular difference in rotation and translation direction between two poses.

    Args:
        T0 (np.ndarray): (4, 4)
        T1 (np.ndarray): (4, 4)
    Returns:
        angle_error_rot: angle in degrees of the rotation R0^T @ R1
        angle_error_trans: angle in degrees between the two translation vectors
    """
    R0, R1 = T0[:-1, :-1], T1[:-1, :-1]
    t0, t1 = T0[:-1, -1], T1[:-1, -1]

    # Rotation angle from the trace; clip to keep arccos in its domain.
    cos_rot = np.clip((np.trace(R0.T @ R1) - 1) / 2, -1.0, 1.0)
    angle_error_rot = np.rad2deg(np.abs(np.arccos(cos_rot)))

    # Angle between translation directions.
    norm_product = np.linalg.norm(t0) * np.linalg.norm(t1)
    cos_trans = np.clip(np.dot(t0, t1) / norm_product, -1.0, 1.0)
    angle_error_trans = np.rad2deg(np.arccos(cos_trans))

    return angle_error_rot, angle_error_trans
def check_convex(patch, min_convexity=0.05):
    """Checks if given polygon vertices [N,2] form a convex shape.

    For every vertex the 2D cross product of the incoming and outgoing edge is
    computed; the polygon is accepted only if none of them exceeds
    ``-min_convexity``.
    """
    prev_pts = np.roll(patch, 1, axis=0)    # vertex i-1 for every i (wraps around)
    next_pts = np.roll(patch, -1, axis=0)   # vertex i+1 for every i (wraps around)
    edge_in = patch - prev_pts
    edge_out = next_pts - patch
    turn = edge_in[:, 0] * edge_out[:, 1] - edge_out[:, 0] * edge_in[:, 1]
    return not np.any(turn > -min_convexity)
create_center_patch(shape) + pts2 = create_center_patch(patch_shape) + scale = min_pts1 - full + found_valid = False + cnt = -1 + while not found_valid: + offsets = rng.uniform(0.0, 1.0, size=(4, 2)) * scale + pts1 = full + offsets + found_valid = check_convex(pts1 / np.array(shape), min_convexity) + cnt += 1 + + # re-center + pts1 = pts1 - np.mean(pts1, axis=0, keepdims=True) + pts1 = pts1 + np.mean(min_pts1, axis=0, keepdims=True) + + # Rotation + if n_angles > 0 and difficulty > 0: + angles = np.linspace(-max_angle * difficulty, max_angle * difficulty, n_angles) + rng.shuffle(angles) + rng.shuffle(angles) + angles = np.concatenate([[0.0], angles], axis=0) + + center = np.mean(pts1, axis=0, keepdims=True) + rot_mat = np.reshape( + np.stack( + [np.cos(angles), -np.sin(angles), np.sin(angles), np.cos(angles)], + axis=1, + ), + [-1, 2, 2], + ) + rotated = ( + np.matmul( + np.tile(np.expand_dims(pts1 - center, axis=0), [n_angles + 1, 1, 1]), + rot_mat, + ) + + center + ) + + for idx in range(1, n_angles): + warped_points = rotated[idx] / np.array(shape) + if np.all((warped_points >= 0.0) & (warped_points < 1.0)): + pts1 = rotated[idx] + break + + # Translation + if translation > 0: + min_trans = -np.min(pts1, axis=0) + max_trans = shape - np.max(pts1, axis=0) + trans = rng.uniform(min_trans, max_trans)[None] + pts1 += trans * translation * difficulty + + H = compute_homography(pts1, pts2, [1.0, 1.0]) + warped = warp_points(full, H, inverse=False) + return H, full, warped, patch_shape + + +def compute_homography(pts1_, pts2_, shape): + """Compute the homography matrix from 4 point correspondences""" + # Rescale to actual size + shape = np.array(shape[::-1], dtype=np.float32) # different convention [y, x] + pts1 = pts1_ * np.expand_dims(shape, axis=0) + pts2 = pts2_ * np.expand_dims(shape, axis=0) + + def ax(p, q): + return [p[0], p[1], 1, 0, 0, 0, -p[0] * q[0], -p[1] * q[0]] + + def ay(p, q): + return [0, 0, 0, p[0], p[1], 1, -p[0] * q[1], -p[1] * q[1]] + + a_mat = 
def warp_points(points, homography, inverse=True):
    """
    Warp a list of points with a homography or, by default, with its inverse.
    Arguments:
        points: list of N points, shape (N, 2).
        homography: batched or not (shapes (B, 3, 3) and (3, 3) respectively).
        inverse: if True the points are multiplied by the inverse of the
            homography, otherwise by the homography itself.
    Returns: an array of shape (N, 2) or (B, N, 2) (depending on whether the homography
        is batched) containing the new coordinates of the warped points.
    """
    is_single = len(homography.shape) == 2
    batched_h = homography[None] if is_single else homography

    # Append the homogeneous coordinate.
    ones = np.ones([points.shape[0], 1], dtype=np.float32)
    homog_pts = np.concatenate([points, ones], -1)

    if inverse:
        batched_h = np.linalg.inv(batched_h)
    # Full axis reversal: (B, 3, 3) -> (3, 3, B), so contracting the points'
    # last axis with axis 0 applies every matrix in the batch.
    h_t = np.transpose(batched_h)
    projected = np.transpose(
        np.tensordot(homog_pts, h_t, axes=[[1], [0]]), [2, 0, 1]
    )  # (B, N, 3)

    # Guard the perspective division against a (near-)zero last coordinate.
    near_zero = np.abs(projected[:, :, 2]) < 1e-8
    projected[near_zero, 2] = 1e-8
    result = projected[:, :, :2] / projected[:, :, 2:]

    return result[0] if is_single else result
+ """ + + # Get the points to the homogeneous format + points = to_homogeneous(points) + + # Apply the homography + H_mat = (torch.inverse(H) if inverse else H).transpose(-2, -1) + warped_points = torch.einsum("...nj,...ji->...ni", points, H_mat) + + warped_points = from_homogeneous(warped_points, eps=1e-5) + return warped_points + + +# Line warping utils + + +def seg_equation(segs): + # calculate list of start, end and midpoints points from both lists + start_points, end_points = to_homogeneous(segs[..., 0, :]), to_homogeneous( + segs[..., 1, :] + ) + # Compute the line equations as ax + by + c = 0 , where x^2 + y^2 = 1 + lines = torch.cross(start_points, end_points, dim=-1) + lines_norm = torch.sqrt(lines[..., 0] ** 2 + lines[..., 1] ** 2)[..., None] + assert torch.all( + lines_norm > 0 + ), "Error: trying to compute the equation of a line with a single point" + lines = lines / lines_norm + return lines + + +def is_inside_img(pts: torch.Tensor, img_shape: Tuple[int, int]): + h, w = img_shape + return ( + (pts >= 0).all(dim=-1) + & (pts[..., 0] < w) + & (pts[..., 1] < h) + & (~torch.isinf(pts).any(dim=-1)) + ) + + +def shrink_segs_to_img(segs: torch.Tensor, img_shape: Tuple[int, int]) -> torch.Tensor: + """ + Shrink an array of segments to fit inside the image. 
+ :param segs: The tensor of segments with shape (N, 2, 2) + :param img_shape: The image shape in format (H, W) + """ + EPS = 1e-4 + device = segs.device + w, h = img_shape[1], img_shape[0] + # Project the segments to the reference image + segs = segs.clone() + eqs = seg_equation(segs) + x0, y0 = torch.tensor([1.0, 0, 0.0], device=device), torch.tensor( + [0.0, 1, 0], device=device + ) + x0 = x0.repeat(eqs.shape[:-1] + (1,)) + y0 = y0.repeat(eqs.shape[:-1] + (1,)) + pt_x0s = torch.cross(eqs, x0, dim=-1) + pt_x0s = pt_x0s[..., :-1] / pt_x0s[..., None, -1] + pt_x0s_valid = is_inside_img(pt_x0s, img_shape) + pt_y0s = torch.cross(eqs, y0, dim=-1) + pt_y0s = pt_y0s[..., :-1] / pt_y0s[..., None, -1] + pt_y0s_valid = is_inside_img(pt_y0s, img_shape) + + xW = torch.tensor([1.0, 0, EPS - w], device=device) + yH = torch.tensor([0.0, 1, EPS - h], device=device) + xW = xW.repeat(eqs.shape[:-1] + (1,)) + yH = yH.repeat(eqs.shape[:-1] + (1,)) + pt_xWs = torch.cross(eqs, xW, dim=-1) + pt_xWs = pt_xWs[..., :-1] / pt_xWs[..., None, -1] + pt_xWs_valid = is_inside_img(pt_xWs, img_shape) + pt_yHs = torch.cross(eqs, yH, dim=-1) + pt_yHs = pt_yHs[..., :-1] / pt_yHs[..., None, -1] + pt_yHs_valid = is_inside_img(pt_yHs, img_shape) + + # If the X coordinate of the first endpoint is out + mask = (segs[..., 0, 0] < 0) & pt_x0s_valid + segs[mask, 0, :] = pt_x0s[mask] + mask = (segs[..., 0, 0] > (w - 1)) & pt_xWs_valid + segs[mask, 0, :] = pt_xWs[mask] + # If the X coordinate of the second endpoint is out + mask = (segs[..., 1, 0] < 0) & pt_x0s_valid + segs[mask, 1, :] = pt_x0s[mask] + mask = (segs[:, 1, 0] > (w - 1)) & pt_xWs_valid + segs[mask, 1, :] = pt_xWs[mask] + # If the Y coordinate of the first endpoint is out + mask = (segs[..., 0, 1] < 0) & pt_y0s_valid + segs[mask, 0, :] = pt_y0s[mask] + mask = (segs[..., 0, 1] > (h - 1)) & pt_yHs_valid + segs[mask, 0, :] = pt_yHs[mask] + # If the Y coordinate of the second endpoint is out + mask = (segs[..., 1, 1] < 0) & pt_y0s_valid + segs[mask, 
1, :] = pt_y0s[mask] + mask = (segs[..., 1, 1] > (h - 1)) & pt_yHs_valid + segs[mask, 1, :] = pt_yHs[mask] + + assert ( + torch.all(segs >= 0) + and torch.all(segs[..., 0] < w) + and torch.all(segs[..., 1] < h) + ) + return segs + + +def warp_lines_torch( + lines, H, inverse=True, dst_shape: Tuple[int, int] = None +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + :param lines: A tensor of shape (B, N, 2, 2) + where B is the batch size, N the number of lines. + :param H: The homography used to convert the lines. + batched or not (shapes (B, 3, 3) and (3, 3) respectively). + :param inverse: Whether to apply H or the inverse of H + :param dst_shape:If provided, lines are trimmed to be inside the image + """ + device = lines.device + batch_size = len(lines) + lines = warp_points_torch(lines.reshape(batch_size, -1, 2), H, inverse).reshape( + lines.shape + ) + + if dst_shape is None: + return lines, torch.ones(lines.shape[:-2], dtype=torch.bool, device=device) + + out_img = torch.any( + (lines < 0) | (lines >= torch.tensor(dst_shape[::-1], device=device)), -1 + ) + valid = ~out_img.all(-1) + any_out_of_img = out_img.any(-1) + lines_to_trim = valid & any_out_of_img + + for b in range(batch_size): + lines_to_trim_mask_b = lines_to_trim[b] + lines_to_trim_b = lines[b][lines_to_trim_mask_b] + corrected_lines = shrink_segs_to_img(lines_to_trim_b, dst_shape) + lines[b][lines_to_trim_mask_b] = corrected_lines + + return lines, valid + + +# Homography evaluation utils + + +def sym_homography_error(kpts0, kpts1, T_0to1): + kpts0_1 = from_homogeneous(to_homogeneous(kpts0) @ T_0to1.transpose(-1, -2)) + dist0_1 = ((kpts0_1 - kpts1) ** 2).sum(-1).sqrt() + + kpts1_0 = from_homogeneous( + to_homogeneous(kpts1) @ torch.pinverse(T_0to1.transpose(-1, -2)) + ) + dist1_0 = ((kpts1_0 - kpts0) ** 2).sum(-1).sqrt() + + return (dist0_1 + dist1_0) / 2.0 + + +def sym_homography_error_all(kpts0, kpts1, H): + kp0_1 = warp_points_torch(kpts0, H, inverse=False) + kp1_0 = warp_points_torch(kpts1, H, 
inverse=True) + + # build a distance matrix of size [... x M x N] + dist0 = torch.sum((kp0_1.unsqueeze(-2) - kpts1.unsqueeze(-3)) ** 2, -1).sqrt() + dist1 = torch.sum((kpts0.unsqueeze(-2) - kp1_0.unsqueeze(-3)) ** 2, -1).sqrt() + return (dist0 + dist1) / 2.0 + + +def homography_corner_error(T, T_gt, image_size): + W, H = image_size[..., 0], image_size[..., 1] + corners0 = torch.Tensor([[0, 0], [W, 0], [W, H], [0, H]]).float().to(T) + corners1_gt = from_homogeneous(to_homogeneous(corners0) @ T_gt.transpose(-1, -2)) + corners1 = from_homogeneous(to_homogeneous(corners0) @ T.transpose(-1, -2)) + d = torch.sqrt(((corners1 - corners1_gt) ** 2).sum(-1)) + return d.mean(-1) diff --git a/imcui/third_party/MatchAnything/src/utils/metrics.py b/imcui/third_party/MatchAnything/src/utils/metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..32703f64de82aa45996011195a59b5f493a82bf1 --- /dev/null +++ b/imcui/third_party/MatchAnything/src/utils/metrics.py @@ -0,0 +1,445 @@ +import torch +import cv2 +import numpy as np +from collections import OrderedDict +from loguru import logger +from .homography_utils import warp_points, warp_points_torch +from kornia.geometry.epipolar import numeric +from kornia.geometry.conversions import convert_points_to_homogeneous +import pprint + + +# --- METRICS --- + +def relative_pose_error(T_0to1, R, t, ignore_gt_t_thr=0.0): + # angle error between 2 vectors + t_gt = T_0to1[:3, 3] + n = np.linalg.norm(t) * np.linalg.norm(t_gt) + t_err = np.rad2deg(np.arccos(np.clip(np.dot(t, t_gt) / n, -1.0, 1.0))) + t_err = np.minimum(t_err, 180 - t_err) # handle E ambiguity + if np.linalg.norm(t_gt) < ignore_gt_t_thr: # pure rotation is challenging + t_err = 0 + + # angle error between 2 rotation matrices + R_gt = T_0to1[:3, :3] + cos = (np.trace(np.dot(R.T, R_gt)) - 1) / 2 + cos = np.clip(cos, -1., 1.) 
def warp_pts_error(H_est, pts_coord, H_gt=None, pts_gt=None):
    """Return the mean Euclidean error of points warped by an estimated homography.

    The reference positions come either from warping the same points with
    ``H_gt`` or, when ``H_gt`` is not given, directly from ``pts_gt``.
    ``H_gt`` takes precedence when both are supplied.

    Args:
        H_est: estimated transform passed to ``warp_points``.
        pts_coord: points to warp; the norm is taken over axis 1, so a 2-D
            ``N x 2`` array is expected.
        H_gt: optional reference transform.
        pts_gt: optional reference point positions, same layout as the warped
            points.

    Returns:
        Mean over all points of the distance between estimated and reference
        positions.

    Raises:
        ValueError: if neither ``H_gt`` nor ``pts_gt`` is provided. Previously
            this case failed with an ``UnboundLocalError`` on ``diff``.
    """
    if H_gt is None and pts_gt is None:
        raise ValueError("warp_pts_error requires either `H_gt` or `pts_gt`.")

    # Third positional argument is forwarded unchanged from the original call
    # (presumably the `inverse` flag of warp_points -- confirm in homography_utils).
    est_warp = warp_points(pts_coord, H_est, False)
    if H_gt is not None:
        reference = warp_points(pts_coord, H_gt, False)
    else:
        reference = pts_gt

    diff = est_warp - reference
    return np.mean(np.linalg.norm(diff, axis=1))
def compute_homo_match_warp_errors(data, config):
    """Compute per-match warp errors against the ground-truth homography.

    For every batch element the matches belonging to it are selected and
    passed to ``homo_warp_match_distance`` together with that element's
    homography and original image-0 size. The per-element results are
    concatenated and stored in ``data``.

    Update:
        data (dict):{"epi_errs": [M]}
    """
    gt_homographies = data['homography']
    batch_ids = data['m_bids']
    matched0 = data['mkpts0_f']
    matched1 = data['mkpts1_f']
    img0_sizes = data['origin_img_size0']

    # With fine-level supervision enabled, prefer the fine-level batch ids
    # when the batch provides them.
    if config.LOFTR.FINE.MTD_SPVS and 'm_bids_f' in data:
        batch_ids = data['m_bids_f']

    per_image_errors = []
    for batch_index in range(gt_homographies.shape[0]):
        selected = batch_ids == batch_index
        distance = homo_warp_match_distance(
            gt_homographies[batch_index],
            matched0[selected],
            matched1[selected],
            img0_sizes[batch_index],
        )
        per_image_errors.append(distance)

    data.update({'epi_errs': torch.cat(per_image_errors, dim=0)})
def estimate_pose(kpts0, kpts1, K0, K1, thresh, conf=0.99999):
    """Estimate the relative pose between two views from matched keypoints.

    Keypoints are normalised with the camera intrinsics, an essential matrix
    is fitted with RANSAC, and the pose with the most inliers is recovered.

    Args:
        kpts0, kpts1: matched pixel coordinates, one row per match.
        K0, K1: 3x3 intrinsic matrices of the two cameras.
        thresh: RANSAC threshold in pixels (divided by a focal length below).
        conf: RANSAC confidence passed to OpenCV as ``prob``.

    Returns:
        ``None`` when fewer than 5 matches are given, when no essential matrix
        is found, or when no candidate yields any inlier; otherwise a tuple
        ``(R, t, inlier_mask)`` with ``t`` flattened to 1-D and
        ``inlier_mask`` a boolean array.
    """
    if len(kpts0) < 5:
        return None

    # Normalise keypoints: subtract the principal point, divide by the focals.
    center0, focal0 = K0[[0, 1], [2, 2]], K0[[0, 1], [0, 1]]
    center1, focal1 = K1[[0, 1], [2, 2]], K1[[0, 1], [0, 1]]
    kpts0 = (kpts0 - center0[None]) / focal0[None]
    kpts1 = (kpts1 - center1[None]) / focal1[None]

    # Express the pixel threshold in normalised coordinates.
    # NOTE(review): only K0[0, 0] and K1[1, 1] enter this mean (each twice).
    # This looks like a convention inherited from earlier evaluation code;
    # confirm before changing, since it affects reported metrics.
    focal_mean = np.mean([K0[0, 0], K1[1, 1], K0[0, 0], K1[1, 1]])
    ransac_thr = thresh / focal_mean

    E, mask = cv2.findEssentialMat(
        kpts0, kpts1, np.eye(3), threshold=ransac_thr, prob=conf, method=cv2.RANSAC)
    if E is None:
        print("\nE is None while trying to recover pose.\n")
        return None

    # E may stack several 3x3 candidates vertically; keep the best one.
    best = None
    best_count = 0
    for candidate in np.split(E, len(E) / 3):
        count, R, t, _ = cv2.recoverPose(candidate, kpts0, kpts1, np.eye(3), 1e9, mask=mask)
        if count > best_count:
            best_count = count
            best = (R, t[:, 0], mask.ravel() > 0)

    return best
def compute_homo_corner_warp_errors(data, config):
    """Estimate a transform per batch element and score it by corner warp error.

    Update:
        data (dict):{
            "R_errs" List[float]: [N]  # actually the corner warp error
            "t_errs" List[float]: [N]  # zero, placeholder
            "inliers" List[np.ndarray]: [N]
        }
    """
    pixel_thr = config.TRAINER.RANSAC_PIXEL_THR
    conf = config.TRAINER.RANSAC_CONF
    data.update({'R_errs': [], 't_errs': [], 'inliers': []})

    # Fine-level batch ids are used only when fine supervision is on AND the
    # batch actually carries them; otherwise fall back to 'm_bids'.
    use_fine_ids = config.LOFTR.FINE.MTD_SPVS and 'm_bids_f' in data
    batch_ids = data['m_bids_f' if use_fine_ids else 'm_bids'].cpu().numpy()

    kpts0 = data['mkpts0_f'].cpu().numpy()
    kpts1 = data['mkpts1_f'].cpu().numpy()
    gt_homographies = data['homography'].cpu().numpy()
    img0_sizes = data['origin_img_size0'].cpu().numpy()

    for batch_index in range(gt_homographies.shape[0]):
        selected = batch_ids == batch_index
        result = estimate_homo(kpts0[selected], kpts1[selected], pixel_thr, conf=conf)

        if result is None:
            data['R_errs'].append(np.inf)
            data['t_errs'].append(np.inf)
            data['inliers'].append(np.array([]).astype(bool))
            continue

        H_est, inliers = result
        # Corners are built as (size[1], size[0]) pairs, i.e. presumably the
        # stored size is (height, width) and corners are (x, y) -- confirm.
        size_first, size_second = img0_sizes[batch_index][0], img0_sizes[batch_index][1]
        corners = np.array([
            [0, 0],
            [0, size_first],
            [size_second, 0],
            [size_second, size_first],
        ])
        corner_error = warp_pts_error(H_est, corners, H_gt=gt_homographies[batch_index])
        data['R_errs'].append(corner_error)
        data['t_errs'].append(0)
        data['inliers'].append(inliers)
data else data['m_bids'].cpu().numpy() + + else: + m_bids = data['m_bids'].cpu().numpy() + pts0 = data['mkpts0_f'].cpu().numpy() + pts1 = data['mkpts1_f'].cpu().numpy() + gt_2D_matches = data["gt_2D_matches"].cpu().numpy() + + data.update({'epi_errs': torch.zeros(m_bids.shape[0])}) + for bs in range(gt_2D_matches.shape[0]): + mask = m_bids == bs + ret = estimate_homo(pts0[mask], pts1[mask], pixel_thr, conf=conf, mode=config.TRAINER.WARP_ESTIMATOR_MODEL) + + if ret is None: + data['R_errs'].append(np.inf) + data['t_errs'].append(np.inf) + data['inliers'].append(np.array([]).astype(bool)) + else: + H_est, inliers = ret + img0_pts, img1_pts = gt_2D_matches[bs][:, :2], gt_2D_matches[bs][:, 2:] + pts_warp_distance = warp_pts_error(H_est, img0_pts, pts_gt=img1_pts) + print(pts_warp_distance) + data['R_errs'].append(pts_warp_distance) + data['t_errs'].append(0) + data['inliers'].append(inliers) + +def compute_pose_errors(data, config): + """ + Update: + data (dict):{ + "R_errs" List[float]: [N] + "t_errs" List[float]: [N] + "inliers" List[np.ndarray]: [N] + } + """ + pixel_thr = config.TRAINER.RANSAC_PIXEL_THR # 0.5 + conf = config.TRAINER.RANSAC_CONF # 0.99999 + data.update({'R_errs': [], 't_errs': [], 'inliers': []}) + + if config.LOFTR.FINE.MTD_SPVS: + m_bids = data['m_bids_f'].cpu().numpy() if 'm_bids_f' in data else data['m_bids'].cpu().numpy() + + else: + m_bids = data['m_bids'].cpu().numpy() + pts0 = data['mkpts0_f'].cpu().numpy() + pts1 = data['mkpts1_f'].cpu().numpy() + K0 = data['K0'].cpu().numpy() + K1 = data['K1'].cpu().numpy() + T_0to1 = data['T_0to1'].cpu().numpy() + + for bs in range(K0.shape[0]): + mask = m_bids == bs + if config.LOFTR.EVAL_TIMES >= 1: + bpts0, bpts1 = pts0[mask], pts1[mask] + R_list, T_list, inliers_list = [], [], [] + for _ in range(5): + shuffling = np.random.permutation(np.arange(len(bpts0))) + if _ >= config.LOFTR.EVAL_TIMES: + continue + bpts0 = bpts0[shuffling] + bpts1 = bpts1[shuffling] + + ret = estimate_pose(bpts0, bpts1, K0[bs], 
def error_auc(errors, thresholds, method='exact_auc'):
    """Aggregate a collection of errors into AUC / success-rate scores.

    Args:
        errors (list | np.ndarray): [N,] error values.
        thresholds (list): thresholds to evaluate. For ``'fire_paper'`` they
            must be integers because they are used as loop bounds.
        method (str): one of
            ``'exact_auc'``    - area under the recall-vs-error curve up to each
                                 threshold, normalised by the threshold;
            ``'fire_paper'``   - mean over integer cut-offs ``1..threshold`` of
                                 the fraction of errors strictly below the cut-off;
            ``'success_rate'`` - fraction of errors strictly below the threshold.

    Returns:
        dict mapping ``'auc@<thr>'`` (or ``'SR@<thr>'`` for ``'success_rate'``)
        to the score.

    Raises:
        NotImplementedError: for an unknown ``method``.
    """
    if method == 'exact_auc':
        # np.trapz is deprecated in NumPy 2.x in favour of np.trapezoid;
        # pick whichever this NumPy version provides.
        trapezoid = getattr(np, 'trapezoid', None) or np.trapz

        errors = [0] + sorted(errors)
        recall = list(np.linspace(0, 1, len(errors)))

        aucs = []
        for thr in thresholds:
            last_index = np.searchsorted(errors, thr)
            # Extend the curve horizontally from the last error below `thr`
            # up to `thr` itself.
            y = recall[:last_index] + [recall[last_index - 1]]
            x = errors[:last_index] + [thr]
            aucs.append(trapezoid(y, x) / thr)
        return {f'auc@{t}': auc for t, auc in zip(thresholds, aucs)}

    if method == 'fire_paper':
        # Element-wise comparison below needs an array; plain lists would
        # raise TypeError on `errors < i`.
        errors = np.asarray(errors)
        aucs = []
        for threshold in thresholds:
            accum_error = 0
            for cutoff in range(1, threshold + 1):
                accum_error += np.sum(errors < cutoff) * 100 / len(errors)
            aucs.append(accum_error / (threshold * 100))
        return {f'auc@{t}': auc for t, auc in zip(thresholds, aucs)}

    if method == 'success_rate':
        errors = np.asarray(errors)
        aucs = [(errors < threshold).astype(float).mean() for threshold in thresholds]
        return {f'SR@{t}': auc for t, auc in zip(thresholds, aucs)}

    raise NotImplementedError
Mean matching precision at the threshold 5e-4(ScanNet), 1e-4(MegaDepth) + """ + # filter duplicates + unq_ids = OrderedDict((iden, id) for id, iden in enumerate(metrics['identifiers'])) + unq_ids = list(unq_ids.values()) + logger.info(f'Aggregating metrics over {len(unq_ids)} unique items...') + + # pose auc + angular_thresholds = threshold + if eval_n_time >= 1: + pose_errors = np.max(np.stack([metrics['R_errs'], metrics['t_errs']]), axis=0).reshape(-1, eval_n_time)[unq_ids].reshape(-1) + else: + pose_errors = np.max(np.stack([metrics['R_errs'], metrics['t_errs']]), axis=0)[unq_ids] + logger.info('num of pose_errors: {}'.format(pose_errors.shape)) + aucs = error_auc(pose_errors, angular_thresholds, method=method) # (auc@5, auc@10, auc@20) + + if eval_n_time >= 1: + for i in range(eval_n_time): + aucs_i = error_auc(pose_errors.reshape(-1, eval_n_time)[:,i], angular_thresholds, method=method) + logger.info('\n' + f'results of {i}-th RANSAC' + pprint.pformat(aucs_i)) + # matching precision + dist_thresholds = [epi_err_thr] + precs = epidist_prec(np.array(metrics['epi_errs'], dtype=object)[unq_ids], dist_thresholds, True) # (prec@err_thr) + + u_num_mathces = np.array(metrics['num_matches'], dtype=object)[unq_ids] + u_percent_inliers = np.array(metrics['percent_inliers'], dtype=object)[unq_ids] + num_matches = {f'num_matches': u_num_mathces.mean() } + percent_inliers = {f'percent_inliers': u_percent_inliers.mean()} + return {**aucs, **precs, **num_matches, **percent_inliers} diff --git a/imcui/third_party/MatchAnything/src/utils/misc.py b/imcui/third_party/MatchAnything/src/utils/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..9c8db04666519753ea2df43903ab6c47ec00a9a1 --- /dev/null +++ b/imcui/third_party/MatchAnything/src/utils/misc.py @@ -0,0 +1,101 @@ +import os +import contextlib +import joblib +from typing import Union +from loguru import _Logger, logger +from itertools import chain + +import torch +from yacs.config import CfgNode as CN 
def upper_config(dict_cfg):
    """Return a copy of a nested dict with every key upper-cased.

    Values that are themselves dicts are converted recursively; anything that
    is not a dict is returned unchanged.
    """
    if isinstance(dict_cfg, dict):
        converted = {}
        for key, value in dict_cfg.items():
            converted[key.upper()] = upper_config(value)
        return converted
    return dict_cfg
iterating over a generator, directly use of tqdm is also a solutin (but monitor the task queuing, instead of finishing) + ret_vals = Parallel(n_jobs=args.world_size)( + delayed(lambda x: _compute_cov_score(pid, *x))(param) + for param in tqdm(combinations(image_ids, 2), + desc=f'Computing cov_score of [{pid}]', + total=len(image_ids)*(len(image_ids)-1)/2)) + Src: https://stackoverflow.com/a/58936697 + """ + class TqdmBatchCompletionCallback(joblib.parallel.BatchCompletionCallBack): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def __call__(self, *args, **kwargs): + tqdm_object.update(n=self.batch_size) + return super().__call__(*args, **kwargs) + + old_batch_callback = joblib.parallel.BatchCompletionCallBack + joblib.parallel.BatchCompletionCallBack = TqdmBatchCompletionCallback + try: + yield tqdm_object + finally: + joblib.parallel.BatchCompletionCallBack = old_batch_callback + tqdm_object.close() + diff --git a/imcui/third_party/MatchAnything/src/utils/plotting.py b/imcui/third_party/MatchAnything/src/utils/plotting.py new file mode 100644 index 0000000000000000000000000000000000000000..87d39733f57a09e4db61b61e53f82ad80e4a839b --- /dev/null +++ b/imcui/third_party/MatchAnything/src/utils/plotting.py @@ -0,0 +1,248 @@ +import bisect +import numpy as np +import matplotlib.pyplot as plt +import matplotlib + +import torch + +def _compute_conf_thresh(data): + dataset_name = data['dataset_name'][0].lower() + if dataset_name == 'scannet': + thr = 5e-4 + elif dataset_name == 'megadepth': + thr = 1e-4 + else: + raise ValueError(f'Unknown dataset: {dataset_name}') + return thr + + +# --- VISUALIZATION --- # +def make_matching_figure( + img0, img1, mkpts0, mkpts1, color, + kpts0=None, kpts1=None, text=[], dpi=75, path=None): + # draw image pair + assert mkpts0.shape[0] == mkpts1.shape[0], f'mkpts0: {mkpts0.shape[0]} v.s. 
mkpts1: {mkpts1.shape[0]}' + fig, axes = plt.subplots(1, 2, figsize=(10, 6), dpi=dpi) + axes[0].imshow(img0, cmap='gray') + axes[1].imshow(img1, cmap='gray') + for i in range(2): # clear all frames + axes[i].get_yaxis().set_ticks([]) + axes[i].get_xaxis().set_ticks([]) + for spine in axes[i].spines.values(): + spine.set_visible(False) + plt.tight_layout(pad=1) + + if kpts0 is not None: + assert kpts1 is not None + axes[0].scatter(kpts0[:, 0], kpts0[:, 1], c='w', s=2) + axes[1].scatter(kpts1[:, 0], kpts1[:, 1], c='w', s=2) + + # draw matches + if mkpts0.shape[0] != 0 and mkpts1.shape[0] != 0: + fig.canvas.draw() + transFigure = fig.transFigure.inverted() + fkpts0 = transFigure.transform(axes[0].transData.transform(mkpts0)) + fkpts1 = transFigure.transform(axes[1].transData.transform(mkpts1)) + fig.lines = [matplotlib.lines.Line2D((fkpts0[i, 0], fkpts1[i, 0]), + (fkpts0[i, 1], fkpts1[i, 1]), + transform=fig.transFigure, c=color[i], linewidth=1) + for i in range(len(mkpts0))] + + axes[0].scatter(mkpts0[:, 0], mkpts0[:, 1], c=color, s=4) + axes[1].scatter(mkpts1[:, 0], mkpts1[:, 1], c=color, s=4) + + # put txts + txt_color = 'k' if img0[:100, :200].mean() > 200 else 'w' + fig.text( + 0.01, 0.99, '\n'.join(text), transform=fig.axes[0].transAxes, + fontsize=15, va='top', ha='left', color=txt_color) + + # save or return figure + if path: + plt.savefig(str(path), bbox_inches='tight', pad_inches=0) + plt.close() + else: + return fig + + +def _make_evaluation_figure(data, b_id, alpha='dynamic', use_m_bids_f=False): + if use_m_bids_f: + b_mask = (data['m_bids_f'] == b_id) if 'm_bids_f' in data else (data['m_bids'] == b_id) + else: + b_mask = data['m_bids'] == b_id + conf_thr = _compute_conf_thresh(data) + + img0 = (data['image0'][b_id][0].cpu().numpy() * 255).round().astype(np.int32) + img1 = (data['image1'][b_id][0].cpu().numpy() * 255).round().astype(np.int32) + kpts0 = data['mkpts0_f'][b_mask].cpu().numpy() + kpts1 = data['mkpts1_f'][b_mask].clone().detach().cpu().numpy() 
+ + # for megadepth, we visualize matches on the resized image + if 'scale0' in data: + kpts0 = kpts0 / data['scale0'][b_id].cpu().numpy() + kpts1 = kpts1 / data['scale1'][b_id].cpu().numpy() + + epi_errs = data['epi_errs'][b_mask].cpu().numpy() + correct_mask = epi_errs < conf_thr + precision = np.mean(correct_mask) if len(correct_mask) > 0 else 0 + n_correct = np.sum(correct_mask) + n_gt_matches = int(data['conf_matrix_gt'][b_id].sum().cpu()) if 'conf_matrix_gt' in data else data['gt'][1]['gt_prob'].sum() + recall = 0 if n_gt_matches == 0 else n_correct / (n_gt_matches) + # recall might be larger than 1, since the calculation of conf_matrix_gt + # uses groundtruth depths and camera poses, but epipolar distance is used here. + + # matching info + if alpha == 'dynamic': + alpha = dynamic_alpha(len(correct_mask)) + color = error_colormap(epi_errs, conf_thr, alpha=alpha) + + text = [ + f'#Matches {len(kpts0)}', + f'Precision({conf_thr:.2e}) ({100 * precision:.1f}%): {n_correct}/{len(kpts0)}', + f'Recall({conf_thr:.2e}) ({100 * recall:.1f}%): {n_correct}/{n_gt_matches}' + ] + + # make the figure + figure = make_matching_figure(img0, img1, kpts0, kpts1, + color, text=text) + return figure + +def _make_confidence_figure(data, b_id): + raise NotImplementedError() + +def _make_gt_figure(data, b_id, alpha='dynamic', use_m_bids_f=False, mode='gt_fine'): + if 'fine' in mode: + mkpts0_key, mkpts1_key = 'mkpts0_f_gt', 'mkpts1_f_gt' + else: + mkpts0_key, mkpts1_key = 'mkpts0_c_gt', 'mkpts1_c_gt' + + if data['image0'].shape[0] == 1: + b_mask = torch.tensor([True]*data[mkpts0_key].shape[0], device = data[mkpts0_key].device) + else: + raise NotImplementedError + + conf_thr = _compute_conf_thresh(data) + + img0 = (data['image0'][b_id][0].cpu().numpy() * 255).round().astype(np.int32) + img1 = (data['image1'][b_id][0].cpu().numpy() * 255).round().astype(np.int32) + try: + kpts0 = data[mkpts0_key][b_mask].cpu().numpy() + kpts1 = data[mkpts1_key][b_mask].cpu().numpy() + except: + 
def make_matching_figures(data, config, mode='evaluation'):
    """ Make matching figures for a batch.

    Args:
        data (Dict): a batch; ``data['image0'].size(0)`` gives the batch size.
        config (Dict): matcher config (read only in 'evaluation' and 'gt' modes).
        mode (str): 'evaluation', 'confidence', or any string containing 'gt'.
    Returns:
        figures (Dict[str, List[plt.figure]]): one figure per batch element,
            stored under the key ``mode``.
    Raises:
        ValueError: for an unrecognised ``mode`` (only when the batch is non-empty).
    """
    def _render(batch_index):
        # Dispatch on the plotting mode; checks are ordered as exact names
        # first, then the looser "contains 'gt'" test.
        if mode == 'evaluation':
            return _make_evaluation_figure(
                data, batch_index,
                alpha=config.TRAINER.PLOT_MATCHES_ALPHA,
                use_m_bids_f=config.LOFTR.FINE.MTD_SPVS)
        if mode == 'confidence':
            return _make_confidence_figure(data, batch_index)
        if 'gt' in mode:
            return _make_gt_figure(
                data, batch_index,
                use_m_bids_f=config.LOFTR.FINE.MTD_SPVS,
                mode=mode)
        raise ValueError(f'Unknown plot mode: {mode}')

    batch_size = data['image0'].size(0)
    return {mode: [_render(batch_index) for batch_index in range(batch_size)]}
def dynamic_alpha(n_matches,
                  milestones=[0, 300, 1000, 2000],
                  alphas=[1.0, 0.8, 0.4, 0.2]):
    """Pick a drawing opacity that decreases as the number of matches grows.

    ``milestones`` and ``alphas`` define a piecewise-linear curve: between two
    consecutive milestones the alpha is interpolated linearly; at or beyond
    the last milestone the last alpha is returned. Zero matches yield 1.0.
    """
    if n_matches == 0:
        return 1.0

    # Alpha at the end of each segment; None marks the open-ended last segment.
    next_alphas = alphas[1:] + [None]
    segment = bisect.bisect_right(milestones, n_matches) - 1

    start_alpha = alphas[segment]
    end_alpha = next_alphas[segment]
    if end_alpha is None:
        return start_alpha

    segment_end = milestones[segment + 1]
    segment_start = milestones[segment]
    return end_alpha + (segment_end - n_matches) / (
        segment_end - segment_start) * (start_alpha - end_alpha)
def build_profiler(name):
    """Create the profiler selected by ``name``.

    Args:
        name: ``'inference'``, ``'pytorch'`` or ``None``.
    Returns:
        ``InferenceProfiler`` for ``'inference'``, a ``PyTorchProfiler`` for
        ``'pytorch'``, and ``PassThroughProfiler`` for ``None``.
    Raises:
        ValueError: for any other value.
    """
    if name is None:
        return PassThroughProfiler()
    if name == 'inference':
        return InferenceProfiler()
    if name == 'pytorch':
        # Imported lazily, as in the original, so it is only needed when used.
        from pytorch_lightning.profiler import PyTorchProfiler
        options = dict(use_cuda=True, profile_memory=True, row_limit=100)
        return PyTorchProfiler(**options)
    raise ValueError(f'Invalid profiler: {name}')
ProgressBarActor: + counter: int + delta: int + event: Event + + def __init__(self) -> None: + self.counter = 0 + self.delta = 0 + self.event = Event() + + def update(self, num_items_completed: int) -> None: + """Updates the ProgressBar with the incremental + number of items that were just completed. + """ + self.counter += num_items_completed + self.delta += num_items_completed + self.event.set() + + async def wait_for_update(self) -> Tuple[int, int]: + """Blocking call. + + Waits until somebody calls `update`, then returns a tuple of + the number of updates since the last call to + `wait_for_update`, and the total number of completed items. + """ + await self.event.wait() + self.event.clear() + saved_delta = self.delta + self.delta = 0 + return saved_delta, self.counter + + def get_counter(self) -> int: + """ + Returns the total number of complete items. + """ + return self.counter + + +class ProgressBar: + progress_actor: ActorHandle + total: int + description: str + pbar: tqdm + + def __init__(self, total: int, description: str = ""): + # Ray actors don't seem to play nice with mypy, generating + # a spurious warning for the following line, + # which we need to suppress. The code is fine. + self.progress_actor = ProgressBarActor.remote() # type: ignore + self.total = total + self.description = description + + @property + def actor(self) -> ActorHandle: + """Returns a reference to the remote `ProgressBarActor`. + + When you complete tasks, call `update` on the actor. + """ + return self.progress_actor + + def print_until_done(self) -> None: + """Blocking call. + + Do this after starting a series of remote Ray tasks, to which you've + passed the actor handle. Each of them calls `update` on the actor. + When the progress meter reaches 100%, this method returns. 
def chunk_index_balance(total_len, n_split, shuffle=True):
    """Distribute the indices ``0..total_len-1`` round-robin over ``n_split`` lists.

    Args:
        total_len: number of indices to distribute.
        n_split: number of output lists. ``0`` is treated as ``1``, matching
            ``chunks_balance`` (previously this raised ``ZeroDivisionError``).
        shuffle: shuffle the indices (with the ``random`` module) before
            distributing them.

    Returns:
        A list of ``n_split`` lists whose lengths differ by at most one.
    """
    if n_split == 0:
        # 0 is not allowed
        n_split = 1

    index_array = np.arange(total_len)
    if shuffle:
        random.shuffle(index_array)

    splited_list = [[] for _ in range(n_split)]
    for position, index in enumerate(index_array):
        splited_list[position % n_split].append(index)
    return splited_list
def compute_homography_sap(h, w, angle=0, tx=0, ty=0, scale=1, k0=1, k1=0, v0=0, v1=0):
    """Compose a homography from similarity, affinity and perspective parts.

    The transform is built in a normalised frame: the image is first shifted
    by ``(-w/2, -h/2)`` and scaled by ``1 / max(w/2, h/2)``, the three
    component matrices are applied, and the normalisation is undone.

    Args:
        h, w: image height and width.
        angle: rotation in degrees, passed to ``similarity_mat``.
        tx, ty: displacement (applied in the normalised frame).
        scale: zoom factor, by default 1.
        k0: non-isotropic squeeze factor, passed to ``affinity_mat``.
        k1: non-isotropic skew factor, passed to ``affinity_mat``.
        v0, v1: perspective factors, passed to ``perspective_mat``.

    Returns:
        The 3x3 matrix ``M_denorm @ HS @ HA @ HP @ M_norm``.
    """
    half_w, half_h = w/2, h/2
    max_size = max(half_w, half_h)

    # Normalisation (centre, then scale) and its inverse.
    M_norm = similarity_mat(0, 0, 0, 1/max_size).dot(similarity_mat(0, -half_w, -half_h, 1))
    M_denorm = similarity_mat(0, half_w, half_h, 1).dot(similarity_mat(0, 0, 0, max_size))

    # Component transforms, in the order they are multiplied after M_denorm.
    components = (
        similarity_mat(angle, tx, ty, scale),
        affinity_mat(k0, k1),
        perspective_mat(v0, v1),
        M_norm,
    )

    H = M_denorm
    for component in components:
        H = H.dot(component)
    return H
class AverageTimer:
    """ Class to help manage printing simple timing of code execution.

    Each named stage keeps an exponentially smoothed duration: ``update(name)``
    records the time elapsed since the previous ``update``/``reset`` call and
    ``print`` reports every stage updated since the last report.
    """

    def __init__(self, smoothing=0.3, newline=False):
        """
        Args:
            smoothing: weight given to the newest measurement when blending
                it with the stored value of the same stage.
            newline: end each report with a newline instead of a carriage
                return (which overwrites the line in place).
        """
        self.smoothing = smoothing
        self.newline = newline
        self.times = OrderedDict()
        self.will_print = OrderedDict()
        self.reset()

    def reset(self):
        """Restart the clock and mark every stage as not-yet-updated."""
        now = time.time()
        self.start = now
        self.last_time = now
        for name in self.will_print:
            self.will_print[name] = False

    def update(self, name='default'):
        """Record the time elapsed since the last checkpoint under ``name``."""
        now = time.time()
        dt = now - self.last_time
        if name in self.times:
            dt = self.smoothing * dt + (1 - self.smoothing) * self.times[name]
        self.times[name] = dt
        self.will_print[name] = True
        self.last_time = now

    def print(self, text='Timer'):
        """Print the stages updated since the last report, then reset."""
        total = 0.
        print('[{}]'.format(text), end=' ')
        for key in self.times:
            val = self.times[key]
            if self.will_print[key]:
                print('%s=%.3f' % (key, val), end=' ')
                total += val
        # Nothing timed yet (or all durations are zero): avoid dividing by 0.
        fps = 1. / total if total > 0 else float('inf')
        print('total=%.3f sec {%.1f FPS}' % (total, fps), end=' ')
        if self.newline:
            print(flush=True)
        else:
            print(end='\r', flush=True)
        self.reset()
class VideoStreamer:
    """ Class to help process image streams. Four types of possible inputs:
        1.) USB Webcam.
        2.) An IP camera
        3.) A directory of images (files in directory matching 'image_glob').
        4.) A video file, such as an .mp4 or .avi file.

    Frames are returned by ``next_frame`` as resized grayscale images.
    """

    def __init__(self, basedir, resize, skip, image_glob, max_length=1000000):
        """
        Args:
            basedir: camera index (int or digit string), ``http``/``rtsp``
                URL, image directory, or video file path.
            resize: target size specification passed to ``process_resize``.
            skip: keep every ``skip``-th image / frame index.
            image_glob: list of glob patterns used for directory input.
            max_length: maximum number of frames to serve.

        Raises:
            IOError: no image matched, or the camera could not be opened.
            ValueError: ``basedir`` is none of the supported input kinds.
        """
        self._ip_grabbed = False
        self._ip_running = False
        self._ip_camera = False
        self._ip_image = None
        self._ip_index = 0
        # Must exist before the grabber thread starts (see
        # start_ip_camera_thread), otherwise the thread's value can be lost.
        self._ip_exited = False
        self.cap = []
        self.camera = True
        self.video_file = False
        self.listing = []
        self.resize = resize
        self.interp = cv2.INTER_AREA
        self.i = 0
        self.skip = skip
        self.max_length = max_length
        if isinstance(basedir, int) or basedir.isdigit():
            print('==> Processing USB webcam input: {}'.format(basedir))
            self.cap = cv2.VideoCapture(int(basedir))
            self.listing = range(0, self.max_length)
        elif basedir.startswith(('http', 'rtsp')):
            print('==> Processing IP camera input: {}'.format(basedir))
            self.cap = cv2.VideoCapture(basedir)
            self.start_ip_camera_thread()
            self._ip_camera = True
            self.listing = range(0, self.max_length)
        elif Path(basedir).is_dir():
            print('==> Processing image directory input: {}'.format(basedir))
            self.listing = list(Path(basedir).glob(image_glob[0]))
            for j in range(1, len(image_glob)):
                image_path = list(Path(basedir).glob(image_glob[j]))
                self.listing = self.listing + image_path
            self.listing.sort()
            self.listing = self.listing[::self.skip]
            self.max_length = np.min([self.max_length, len(self.listing)])
            if self.max_length == 0:
                raise IOError('No images found (maybe bad \'image_glob\' ?)')
            self.listing = self.listing[:self.max_length]
            self.camera = False
        elif Path(basedir).exists():
            print('==> Processing video input: {}'.format(basedir))
            self.cap = cv2.VideoCapture(basedir)
            self.cap.set(cv2.CAP_PROP_BUFFERSIZE, 1)
            num_frames = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT))
            self.listing = range(0, num_frames)
            self.listing = self.listing[::self.skip]
            self.video_file = True
            self.max_length = np.min([self.max_length, len(self.listing)])
            self.listing = self.listing[:self.max_length]
        else:
            raise ValueError('VideoStreamer input \"{}\" not recognized.'.format(basedir))
        if self.camera and not self.cap.isOpened():
            raise IOError('Could not read camera')

    def load_image(self, impath):
        """ Read image as grayscale and resize to img_size.
        Inputs
            impath: Path to input image.
        Returns
            grayim: uint8 numpy array sized H x W.
        """
        grayim = cv2.imread(impath, 0)
        if grayim is None:
            raise Exception('Error reading image %s' % impath)
        w, h = grayim.shape[1], grayim.shape[0]
        w_new, h_new = process_resize(w, h, self.resize)
        grayim = cv2.resize(
            grayim, (w_new, h_new), interpolation=self.interp)
        return grayim

    def next_frame(self):
        """ Return the next frame, and increment internal counter.
        Returns
             image: Next H x W image.
             status: True or False depending whether image was loaded.
        """

        if self.i == self.max_length:
            return (None, False)
        if self.camera:

            if self._ip_camera:
                # Wait for first image, making sure we haven't exited
                while self._ip_grabbed is False and self._ip_exited is False:
                    time.sleep(.001)

                # The grabber may have exited before delivering any frame, in
                # which case there is nothing to copy.
                image = self._ip_image
                ret = bool(self._ip_grabbed) and image is not None
                if ret:
                    image = image.copy()
                else:
                    self._ip_running = False
            else:
                ret, image = self.cap.read()
            if not ret:
                print('VideoStreamer: Cannot get image from camera')
                return (None, False)
            w, h = image.shape[1], image.shape[0]
            if self.video_file:
                self.cap.set(cv2.CAP_PROP_POS_FRAMES, self.listing[self.i])

            w_new, h_new = process_resize(w, h, self.resize)
            image = cv2.resize(image, (w_new, h_new),
                               interpolation=self.interp)
            # NOTE(review): OpenCV captures are normally BGR, so RGB2GRAY swaps
            # the red/blue weights. Left unchanged to keep output identical.
            image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
        else:
            image_file = str(self.listing[self.i])
            image = self.load_image(image_file)
        self.i = self.i + 1
        return (image, True)

    def start_ip_camera_thread(self):
        """Start the background thread that keeps grabbing IP-camera frames."""
        self._ip_thread = Thread(target=self.update_ip_camera, args=())
        self._ip_running = True
        # Reset *before* starting: assigning after start() could overwrite the
        # True set by a thread that failed immediately, making next_frame()
        # wait forever.
        self._ip_exited = False
        self._ip_thread.start()
        return self

    def update_ip_camera(self):
        """Thread body: read frames until stopped or the stream fails."""
        while self._ip_running:
            ret, img = self.cap.read()
            if ret is False:
                self._ip_running = False
                self._ip_exited = True
                self._ip_grabbed = False
                return

            self._ip_image = img
            self._ip_grabbed = ret
            self._ip_index += 1
            #print('IPCAMERA THREAD got frame {}'.format(self._ip_index))

    def cleanup(self):
        """Ask the IP-camera thread to stop after its current read."""
        self._ip_running = False
# --- PREPROCESSING ---


def process_resize(w, h, resize):
    """Compute the target ``(width, height)`` for an image of size ``w`` x ``h``.

    Args:
        w, h: original width and height.
        resize: sequence of one or two numbers:
            ``[-1]``      keep the original size,
            ``[s]``       (s > 0) scale so that the longer side equals ``s``,
            ``[w2, h2]``  use exactly this size.

    Returns:
        ``(w_new, h_new)``.

    Raises:
        ValueError: a single-element ``resize`` that is neither -1 nor
            positive (previously this produced an IndexError or a 0x0 size).
    """
    assert(len(resize) > 0 and len(resize) <= 2)
    if len(resize) == 1:
        if resize[0] == -1:
            w_new, h_new = w, h
        elif resize[0] > 0:
            scale = resize[0] / max(h, w)
            w_new, h_new = int(round(w*scale)), int(round(h*scale))
        else:
            raise ValueError(
                'resize must be -1 or a positive size, got {}'.format(resize[0]))
    else:  # len(resize) == 2:
        w_new, h_new = resize[0], resize[1]

    # Issue warning if resolution is too small or too large.
    if max(w_new, h_new) < 160:
        print('Warning: input resolution is very small, results may vary')
    elif max(w_new, h_new) > 2000:
        print('Warning: input resolution is very large, results may vary')

    return w_new, h_new


def frame2tensor(frame, device):
    """ Convert an image array to a float tensor.

    The array is divided by 255, cast to float32, given two leading
    singleton dimensions and moved to ``device``.
    """
    return torch.from_numpy(frame/255.).float()[None, None].to(device)


def read_image(path, device, resize, rotation, resize_float):
    """Load a grayscale image, resize/rotate it and build its tensor.

    Args:
        path: image file path.
        device: torch device for the returned tensor.
        resize: size specification, see ``process_resize``.
        rotation: number of 90-degree rotations applied with ``np.rot90``.
        resize_float: convert to float32 before resizing instead of after.

    Returns:
        ``(image, inp, scales)`` where ``image`` is the float32 array, ``inp``
        the tensor from ``frame2tensor`` and ``scales`` the
        ``(original / resized)`` ratios for width and height (swapped for odd
        rotations); ``(None, None, None)`` if the file cannot be read.
    """
    image = cv2.imread(str(path), cv2.IMREAD_GRAYSCALE)
    if image is None:
        return None, None, None
    w, h = image.shape[1], image.shape[0]
    w_new, h_new = process_resize(w, h, resize)
    scales = (float(w) / float(w_new), float(h) / float(h_new))

    if resize_float:
        image = cv2.resize(image.astype('float32'), (w_new, h_new))
    else:
        image = cv2.resize(image, (w_new, h_new)).astype('float32')

    if rotation != 0:
        image = np.rot90(image, k=rotation)
        if rotation % 2:
            # Width and height trade places after an odd number of turns.
            scales = scales[::-1]

    inp = frame2tensor(image, device)
    return image, inp, scales
# --- GEOMETRY ---
def estimate_pose(kpts0, kpts1, K0, K1, thresh, conf=0.99999):
    """Estimate the relative pose from matched keypoints with RANSAC.

    Args:
        kpts0, kpts1: N x 2 arrays of matched pixel coordinates.
        K0, K1: 3x3 camera intrinsics of the two images.
        thresh: inlier threshold in pixels; it is divided by the mean focal
            length because estimation runs on normalized coordinates.
        conf: RANSAC confidence.

    Returns:
        ``(R, t, inlier_mask)`` for the candidate essential matrix with the
        most ``recoverPose`` inliers, or ``None`` when fewer than 5 matches
        are given or no pose could be recovered.
    """
    if len(kpts0) < 5:
        return None

    # Mean of all four focal lengths (fx, fy of both cameras). The previous
    # code averaged K0[0, 0] and K1[1, 1] twice and ignored the other two.
    f_mean = np.mean([K0[0, 0], K0[1, 1], K1[0, 0], K1[1, 1]])
    norm_thresh = thresh / f_mean

    # Normalize: subtract the principal point, divide by the focal lengths.
    kpts0 = (kpts0 - K0[[0, 1], [2, 2]][None]) / K0[[0, 1], [0, 1]][None]
    kpts1 = (kpts1 - K1[[0, 1], [2, 2]][None]) / K1[[0, 1], [0, 1]][None]

    E, mask = cv2.findEssentialMat(
        kpts0, kpts1, np.eye(3), threshold=norm_thresh, prob=conf,
        method=cv2.RANSAC)

    # assert E is not None  # might cause unexpected exception in validation step
    if E is None:
        print("\nE is None while trying to recover pose.\n")
        return None

    # findEssentialMat may return several stacked 3x3 candidates.
    best_num_inliers = 0
    ret = None
    for _E in np.split(E, len(E) // 3):
        n, R, t, _ = cv2.recoverPose(
            _E, kpts0, kpts1, np.eye(3), 1e9, mask=mask)
        if n > best_num_inliers:
            best_num_inliers = n
            ret = (R, t[:, 0], mask.ravel() > 0)
    return ret


def estimate_pose_degensac(kpts0, kpts1, K0, K1, thresh, conf=0.9999, max_iters=1000, min_candidates=10):
    """Estimate the relative pose from a DEGENSAC fundamental matrix.

    Args:
        kpts0, kpts1: N x 2 arrays of matched pixel coordinates.
        K0, K1: 3x3 camera intrinsics of the two images.
        thresh: inlier threshold in pixels (``px_th``).
        conf: RANSAC confidence.
        max_iters: maximum number of iterations.
        min_candidates: minimum number of matches required to attempt
            estimation.

    Returns:
        ``(R, t, inlier_mask)`` or ``None`` when there are too few matches or
        no pose could be recovered.
    """
    # TODO: Try different `min_candidatas`?
    if len(kpts0) < min_candidates:
        return None

    # Optional dependency: only needed once estimation is actually attempted.
    import pydegensac

    F, mask = pydegensac.findFundamentalMatrix(kpts0,
                                               kpts1,
                                               px_th=thresh,
                                               conf=conf,
                                               max_iters=max_iters)
    mask = mask.astype(np.uint8)
    E = (K1.T @ F @ K0).astype(np.float64)

    kpts0 = (kpts0 - K0[[0, 1], [2, 2]][None]) / K0[[0, 1], [0, 1]][None]
    kpts1 = (kpts1 - K1[[0, 1], [2, 2]][None]) / K1[[0, 1], [0, 1]][None]

    # This might be optional (since DEGENSAC handle it internally ?)
    best_num_inliers = 0
    ret = None
    for _E in np.split(E, len(E) // 3):
        n, R, t, _ = cv2.recoverPose(
            _E, kpts0, kpts1, np.eye(3), 1e9, mask=mask)
        if n > best_num_inliers:
            best_num_inliers = n
            ret = (R, t[:, 0], mask.ravel() > 0)

    return ret
def rotate_intrinsics(K, image_shape, rot):
    """Return the intrinsics of a camera whose image was rotated in-plane.

    image_shape is the shape of the image after rotation.

    Args:
        K: 3x3 intrinsics of the un-rotated image.
        image_shape: ``(height, width, ...)`` of the rotated image.
        rot: number of 90-degree steps, at most 3 (taken modulo 4).

    Returns:
        New 3x3 intrinsics with the dtype of ``K``. ``rot == 0`` returns a
        copy of ``K`` (it used to fall through to the ``rot == 3`` formula).
    """
    assert rot <= 3
    rot = rot % 4
    if rot == 0:
        return np.array(K, dtype=K.dtype)
    # For odd rotations the pre-rotation height/width are swapped.
    h, w = image_shape[:2][::-1 if (rot % 2) else 1]
    fx, fy, cx, cy = K[0, 0], K[1, 1], K[0, 2], K[1, 2]
    if rot == 1:
        return np.array([[fy, 0., cy],
                         [0., fx, w-1-cx],
                         [0., 0., 1.]], dtype=K.dtype)
    elif rot == 2:
        return np.array([[fx, 0., w-1-cx],
                         [0., fy, h-1-cy],
                         [0., 0., 1.]], dtype=K.dtype)
    else:  # if rot == 3:
        return np.array([[fy, 0., h-1-cy],
                         [0., fx, cx],
                         [0., 0., 1.]], dtype=K.dtype)


def rotate_pose_inplane(i_T_w, rot):
    """Left-multiply the 4x4 pose ``i_T_w`` by a rotation about the z axis.

    ``rot`` indexes the angles ``(0, 270, 180, 90)`` degrees.
    """
    rotation_matrices = [
        np.array([[np.cos(r), -np.sin(r), 0., 0.],
                  [np.sin(r), np.cos(r), 0., 0.],
                  [0., 0., 1., 0.],
                  [0., 0., 0., 1.]], dtype=np.float32)
        for r in [np.deg2rad(d) for d in (0, 270, 180, 90)]
    ]
    return np.dot(rotation_matrices[rot], i_T_w)


def scale_intrinsics(K, scales):
    """Divide the first row of ``K`` by ``scales[0]`` and the second by ``scales[1]``."""
    scales = np.diag([1./scales[0], 1./scales[1], 1.])
    return np.dot(scales, K)


def to_homogeneous(points):
    """Append a column of ones to an N x D array of points."""
    return np.concatenate([points, np.ones_like(points[:, :1])], axis=-1)
def compute_epipolar_error(kpts0, kpts1, T_0to1, K0, K1, enable_MEinPC=False):
    """ Compute the squared symmetric epipolar distance (SED^2).
    The essential matrix is built from the relative pose T_0to1 and applied
    to the keypoints after normalizing them with the intrinsics.
    SED can be seen as a biased estimation of the reprojection error.
    Args:
        kpts0, kpts1: N x 2 matched keypoints in pixel coordinates.
        T_0to1: relative pose; rotation in [:3, :3], translation in [:3, 3].
        K0, K1: 3x3 intrinsics of the two images.
        enable_MEinPC: Mean Error in Pixel Coordinate
    Returns:
        Array with one error per match.
    """
    def _normalized_homogeneous(pts, K):
        # Remove the principal point, divide by the focal lengths, append 1.
        pts = (pts - K[[0, 1], [2, 2]][None]) / K[[0, 1], [0, 1]][None]
        return np.concatenate([pts, np.ones_like(pts[:, :1])], axis=-1)

    pts0 = _normalized_homogeneous(kpts0, K0)
    pts1 = _normalized_homogeneous(kpts1, K1)

    # E = [t]_x R
    t0, t1, t2 = T_0to1[:3, 3]
    cross_t = np.array([
        [0, -t2, t1],
        [t2, 0, -t0],
        [-t1, t0, 0]
    ])
    essential = cross_t @ T_0to1[:3, :3]

    lines_in_1 = pts0 @ essential.T  # N x 3
    residual = np.sum(pts1 * lines_in_1, -1)  # N
    lines_in_0 = pts1 @ essential  # N x 3

    if not enable_MEinPC:
        return residual**2 * (1.0 / (lines_in_1[:, 0]**2 + lines_in_1[:, 1]**2)
                              + 1.0 / (lines_in_0[:, 0]**2 + lines_in_0[:, 1]**2))  # N

    return 0.5 * np.abs(residual) * (
        np.linalg.norm([K1[0, 0], K1[1, 1]]) / np.linalg.norm([lines_in_1[:, 0], lines_in_1[:, 1]], axis=0)
        + np.linalg.norm([K0[0, 0], K0[1, 1]]) / np.linalg.norm([lines_in_0[:, 0], lines_in_0[:, 1]], axis=0))  # N


def compute_homogeneous_error(kpts0, kpts1, H):
    """ Warp kpts0 into image 1 with the 3x3 matrix H and return the
    Euclidean distance of every warped point to its match in kpts1.
    """
    lifted = np.concatenate([kpts0, np.ones_like(kpts0[:, :1])], axis=-1)

    projected = lifted @ H.T  # N x 3
    warped = projected[:, :2] / projected[:, [2]]

    return np.linalg.norm(warped - kpts1, axis=1)
def angle_error_mat(R1, R2):
    """Return the angle in degrees of the relative rotation ``R1.T @ R2``."""
    cos = (np.trace(np.dot(R1.T, R2)) - 1) / 2
    cos = np.clip(cos, -1., 1.)  # numerical errors can make it out of bounds
    return np.rad2deg(np.abs(np.arccos(cos)))


def angle_error_vec(v1, v2):
    """Return the angle in degrees between the vectors ``v1`` and ``v2``."""
    n = np.linalg.norm(v1) * np.linalg.norm(v2)
    return np.rad2deg(np.arccos(np.clip(np.dot(v1, v2) / n, -1.0, 1.0)))


def compute_pose_error(T_0to1, R, t, ignore_gt_t_thr=0.0):
    """Compare an estimated pose ``(R, t)`` with the ground truth ``T_0to1``.

    Returns:
        ``(error_t, error_R)`` in degrees. The translation error is the angle
        between the directions, folded into ``[0, 90]`` because the sign of
        ``t`` is ambiguous. It is set to 0 when the ground-truth translation
        norm is below ``ignore_gt_t_thr``.
    """
    R_gt = T_0to1[:3, :3]
    t_gt = T_0to1[:3, 3]
    error_t = angle_error_vec(t, t_gt)
    error_t = np.minimum(error_t, 180 - error_t)  # ambiguity of E estimation
    error_R = angle_error_mat(R, R_gt)
    if np.linalg.norm(t_gt) < ignore_gt_t_thr:  # NOTE: as a close-to-zero translation is not good for angle_error calculation
        error_t = 0
    return error_t, error_R


def convert_gt_T(T_0to1):
    """Summarize a pose as ``(translation norm, rotation angle in degrees)``.

    Accepts both 3x4 and 4x4 matrices: only the top three rows are read. The
    previous ``[:, :3]`` / ``[:, 3]`` slicing failed on 4x4 input.
    """
    gt_R_degree = angle_error_mat(T_0to1[:3, :3], np.eye(3))
    gt_t_dist = np.linalg.norm(T_0to1[:3, 3])
    return gt_t_dist, gt_R_degree


def pose_auc(errors, thresholds, ret_dict=False):
    """Area under the recall-vs-error curve, normalized by each threshold.

    Args:
        errors: sequence of per-sample errors.
        thresholds: error thresholds at which the curve is truncated.
        ret_dict: return ``{'auc@<t>': value}`` instead of a list.

    Returns:
        List (or dict) with one AUC in ``[0, 1]`` per threshold.
    """
    # np.trapz was renamed np.trapezoid in NumPy 2.0.
    integrate = getattr(np, 'trapezoid', None)
    if integrate is None:
        integrate = np.trapz

    errors = np.asarray(errors)
    errors = errors[np.argsort(errors)]  # fancy indexing returns a copy
    recall = (np.arange(len(errors)) + 1) / len(errors)
    errors = np.r_[0., errors]
    recall = np.r_[0., recall]
    aucs = []
    for t in thresholds:
        last_index = np.searchsorted(errors, t)
        # Extend the curve horizontally up to the threshold.
        r = np.r_[recall[:last_index], recall[last_index-1]]
        e = np.r_[errors[:last_index], t]
        aucs.append(integrate(r, x=e)/t)
    if ret_dict:
        return {f'auc@{t}': auc for t, auc in zip(thresholds, aucs)}
    else:
        return aucs


def epidist_prec(errors, thresholds, ret_dict=False):
    """Mean fraction of errors below each threshold.

    Args:
        errors: sequence of arrays, one array of errors per sample.
        thresholds: thresholds to evaluate.
        ret_dict: return ``{'prec@<t>': value}`` instead of a list.

    Returns:
        For each threshold, the per-sample precision averaged over samples
        (an empty sample counts as 0).
    """
    precs = []
    for thr in thresholds:
        prec_ = []
        for errs in errors:
            correct_mask = errs < thr
            prec_.append(np.mean(correct_mask) if len(correct_mask) > 0 else 0)
        precs.append(np.mean(prec_) if len(prec_) > 0 else 0)
    if ret_dict:
        return {f'prec@{t:.0e}': prec for t, prec in zip(thresholds, precs)}
    else:
        return precs
# --- VISUALIZATION ---
def plot_image_pair(imgs, dpi=100, size=6, pad=.5):
    """Create a new figure showing exactly two images side by side.

    Images are drawn with the gray colormap and a fixed 0..255 value range;
    ticks and axis frames are hidden. The figure is left as the current
    pyplot figure (nothing is returned).
    """
    n = len(imgs)
    assert n == 2, 'number of images must be two'
    figsize = (size*n, size*3/4) if size is not None else None
    _, ax = plt.subplots(1, n, figsize=figsize, dpi=dpi)
    for i in range(n):
        ax[i].imshow(imgs[i], cmap=plt.get_cmap('gray'), vmin=0, vmax=255)
        ax[i].get_yaxis().set_ticks([])
        ax[i].get_xaxis().set_ticks([])
        for spine in ax[i].spines.values():  # remove frame
            spine.set_visible(False)
    plt.tight_layout(pad=pad)


def plot_keypoints(kpts0, kpts1, color='w', ps=2):
    """Scatter ``kpts0`` / ``kpts1`` (N x 2, x then y) on the first / second
    axes of the current figure."""
    ax = plt.gcf().axes
    ax[0].scatter(kpts0[:, 0], kpts0[:, 1], c=color, s=ps)
    ax[1].scatter(kpts1[:, 0], kpts1[:, 1], c=color, s=ps)


def plot_matches(kpts0, kpts1, color, lw=1.5, ps=4):
    """Draw one line per match across the two axes of the current figure.

    ``color[i]`` is the color of match ``i``. The figure's existing line list
    is replaced, and the end points are scattered on both axes.
    """
    fig = plt.gcf()
    ax = fig.axes
    # The canvas is drawn before the axes transforms are used below.
    fig.canvas.draw()

    # Convert data coordinates of each axes into figure coordinates so a
    # single line can span both subplots.
    transFigure = fig.transFigure.inverted()
    fkpts0 = transFigure.transform(ax[0].transData.transform(kpts0))
    fkpts1 = transFigure.transform(ax[1].transData.transform(kpts1))

    fig.lines = [matplotlib.lines.Line2D(
        (fkpts0[i, 0], fkpts1[i, 0]), (fkpts0[i, 1], fkpts1[i, 1]), zorder=1,
        transform=fig.transFigure, c=color[i], linewidth=lw)
                 for i in range(len(kpts0))]
    ax[0].scatter(kpts0[:, 0], kpts0[:, 1], c=color, s=ps)
    ax[1].scatter(kpts1[:, 0], kpts1[:, 1], c=color, s=ps)


def make_matching_plot(image0, image1, kpts0, kpts1, mkpts0, mkpts1,
                       color, text, path=None, show_keypoints=False,
                       fast_viz=False, opencv_display=False,
                       opencv_title='matches', small_text=[]):
    """Render a match visualization for an image pair.

    Args:
        image0, image1: the two images to show.
        kpts0, kpts1: all keypoints (drawn only if ``show_keypoints``).
        mkpts0, mkpts1: matched keypoints, row ``i`` of each forms a match.
        color: per-match colors.
        text: lines written in the top-left corner.
        path: if truthy, the figure is saved there and closed.
        show_keypoints: also draw ``kpts0`` / ``kpts1``.
        fast_viz: delegate to ``make_matching_plot_fast`` (OpenCV rendering).
        opencv_display, opencv_title: forwarded to the fast renderer.
        small_text: lines written in the bottom-left corner.

    Returns:
        The matplotlib figure when ``path`` is not given; ``None`` when the
        figure was saved or when ``fast_viz`` is used.
    """

    if fast_viz:
        # The image returned by the fast renderer is discarded here.
        make_matching_plot_fast(image0, image1, kpts0, kpts1, mkpts0, mkpts1,
                                color, text, path, show_keypoints, 10,
                                opencv_display, opencv_title, small_text)
        return

    plot_image_pair([image0, image1])  # will create a new figure
    if show_keypoints:
        # Black dots under smaller white dots give each keypoint an outline.
        plot_keypoints(kpts0, kpts1, color='k', ps=4)
        plot_keypoints(kpts0, kpts1, color='w', ps=2)
    plot_matches(mkpts0, mkpts1, color)

    fig = plt.gcf()
    # Pick black text on bright corners, white text otherwise.
    txt_color = 'k' if image0[:100, :150].mean() > 200 else 'w'
    fig.text(
        0.01, 0.99, '\n'.join(text), transform=fig.axes[0].transAxes,
        fontsize=15, va='top', ha='left', color=txt_color)

    txt_color = 'k' if image0[-100:, :150].mean() > 200 else 'w'
    fig.text(
        0.01, 0.01, '\n'.join(small_text), transform=fig.axes[0].transAxes,
        fontsize=5, va='bottom', ha='left', color=txt_color)
    if path:
        plt.savefig(str(path), bbox_inches='tight', pad_inches=0)
        plt.close()
    else:
        # TODO: Would it lead to any issue if the figure is left open?
        return fig
def make_matching_plot_fast(image0, image1, kpts0, kpts1, mkpts0,
                            mkpts1, color, text, path=None,
                            show_keypoints=False, margin=10,
                            opencv_display=False, opencv_title='',
                            small_text=[]):
    """Render a match visualization with OpenCV drawing primitives.

    The two single-channel images are pasted side by side on a white canvas,
    separated by ``margin`` pixels, and converted to three channels.

    Args:
        image0, image1: 2-D (H x W) images; their shapes are unpacked as
            exactly two values.
        kpts0, kpts1: all keypoints (drawn only if ``show_keypoints``).
        mkpts0, mkpts1: matched keypoints, row ``i`` of each forms a match.
        color: N x >=3 array of per-match colors in 0..1; the first three
            channels are scaled to 0..255 and their order is reversed.
        text: lines drawn at the top-left.
        path: if not None, the canvas is written there with ``cv2.imwrite``.
        show_keypoints: also draw ``kpts0`` / ``kpts1``.
        margin: gap in pixels between the two images.
        opencv_display: show the canvas in an OpenCV window.
        opencv_title: title of that window.
        small_text: lines drawn at the bottom-left.

    Returns:
        The rendered H x W x 3 uint8 canvas.
    """
    H0, W0 = image0.shape
    H1, W1 = image1.shape
    H, W = max(H0, H1), W0 + W1 + margin

    out = 255*np.ones((H, W), np.uint8)
    out[:H0, :W0] = image0
    out[:H1, W0+margin:] = image1
    out = np.stack([out]*3, -1)

    if show_keypoints:
        kpts0, kpts1 = np.round(kpts0).astype(int), np.round(kpts1).astype(int)
        white = (255, 255, 255)
        black = (0, 0, 0)
        # Black disc under a smaller white disc gives an outlined dot.
        for x, y in kpts0:
            cv2.circle(out, (x, y), 2, black, -1, lineType=cv2.LINE_AA)
            cv2.circle(out, (x, y), 1, white, -1, lineType=cv2.LINE_AA)
        for x, y in kpts1:
            # Points of the second image are shifted past image0 and the gap.
            cv2.circle(out, (x + margin + W0, y), 2, black, -1,
                       lineType=cv2.LINE_AA)
            cv2.circle(out, (x + margin + W0, y), 1, white, -1,
                       lineType=cv2.LINE_AA)

    mkpts0, mkpts1 = np.round(mkpts0).astype(int), np.round(mkpts1).astype(int)
    color = (np.array(color[:, :3])*255).astype(int)[:, ::-1]
    for (x0, y0), (x1, y1), c in zip(mkpts0, mkpts1, color):
        c = c.tolist()
        cv2.line(out, (x0, y0), (x1 + margin + W0, y1),
                 color=c, thickness=1, lineType=cv2.LINE_AA)
        # display line end-points as circles
        cv2.circle(out, (x0, y0), 2, c, -1, lineType=cv2.LINE_AA)
        cv2.circle(out, (x1 + margin + W0, y1), 2, c, -1,
                   lineType=cv2.LINE_AA)

    # Scale factor for consistent visualization across scales.
    sc = min(H / 640., 2.0)

    # Big text.
    Ht = int(30 * sc)  # text height
    txt_color_fg = (255, 255, 255)
    txt_color_bg = (0, 0, 0)
    # Each line is drawn twice: thick black first, thin white on top.
    for i, t in enumerate(text):
        cv2.putText(out, t, (int(8*sc), Ht*(i+1)), cv2.FONT_HERSHEY_DUPLEX,
                    1.0*sc, txt_color_bg, 2, cv2.LINE_AA)
        cv2.putText(out, t, (int(8*sc), Ht*(i+1)), cv2.FONT_HERSHEY_DUPLEX,
                    1.0*sc, txt_color_fg, 1, cv2.LINE_AA)

    # Small text.
    Ht = int(18 * sc)  # text height
    # Reversed so the last entry sits closest to the bottom edge.
    for i, t in enumerate(reversed(small_text)):
        cv2.putText(out, t, (int(8*sc), int(H-Ht*(i+.6))), cv2.FONT_HERSHEY_DUPLEX,
                    0.5*sc, txt_color_bg, 2, cv2.LINE_AA)
        cv2.putText(out, t, (int(8*sc), int(H-Ht*(i+.6))), cv2.FONT_HERSHEY_DUPLEX,
                    0.5*sc, txt_color_fg, 1, cv2.LINE_AA)

    if path is not None:
        cv2.imwrite(str(path), out)

    if opencv_display:
        cv2.imshow(opencv_title, out)
        cv2.waitKey(1)

    return out
def error_colormap(x, alpha=1.0):
    """Map values in ``[0, 1]`` to RGBA colors going from red to green.

    ``x = 0`` gives red, ``x = 0.5`` yellow and ``x = 1`` green; every channel
    is clipped to ``[0, 1]``.

    Args:
        x: array of values.
        alpha: opacity in ``(0, 1]`` stored in the fourth channel.

    Returns:
        Array of shape ``x.shape + (4,)``.
    """
    assert alpha <= 1.0 and alpha > 0, f"Invalid alpha value: {alpha}"
    return np.clip(
        np.stack([2-x*2, x*2, np.zeros_like(x), np.ones_like(x)*alpha], -1), 0, 1)


def check_img_ok(img_path):
    """Return True if the image at ``img_path`` can be opened and converted to RGB.

    Any ordinary error while opening or decoding yields False. Unlike a bare
    ``except:``, KeyboardInterrupt and SystemExit are not swallowed, and the
    file handle is closed on exit.
    """
    try:
        with Image.open(str(img_path)) as img:
            img.convert('RGB')
    except Exception:
        return False
    return True
+ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/imcui/third_party/MatchAnything/third_party/ROMA/README.md b/imcui/third_party/MatchAnything/third_party/ROMA/README.md new file mode 100644 index 0000000000000000000000000000000000000000..284d8f0bea84d7f67a416bc933067a3acfe23740 --- /dev/null +++ b/imcui/third_party/MatchAnything/third_party/ROMA/README.md @@ -0,0 +1,82 @@ +# +

+

RoMa 🏛️:
Robust Dense Feature Matching

+

+ Johan Edstedt + · + Qiyu Sun + · + Georg Bökman + · + Mårten Wadenbäck + · + Michael Felsberg +

+

+ Paper | + Project Page +

+
+

+
+

+ example +
+ RoMa is a robust dense feature matcher capable of estimating pixel-dense warps and reliable certainties for almost any image pair. +

+ +## Setup/Install +In your python environment (tested on Linux python 3.10), run: +```bash +pip install -e . +``` +## Demo / How to Use +We provide two demos in the [demos folder](demo). +Here's the gist of it: +```python +from roma import roma_outdoor +roma_model = roma_outdoor(device=device) +# Match +warp, certainty = roma_model.match(imA_path, imB_path, device=device) +# Sample matches for estimation +matches, certainty = roma_model.sample(warp, certainty) +# Convert to pixel coordinates (RoMa produces matches in [-1,1]x[-1,1]) +kptsA, kptsB = roma_model.to_pixel_coordinates(matches, H_A, W_A, H_B, W_B) +# Find a fundamental matrix (or anything else of interest) +F, mask = cv2.findFundamentalMat( + kptsA.cpu().numpy(), kptsB.cpu().numpy(), ransacReprojThreshold=0.2, method=cv2.USAC_MAGSAC, confidence=0.999999, maxIters=10000 +) +``` + +**New**: You can also match arbitrary keypoints with RoMa. A demo for this will be added soon. + +## Reproducing Results +The experiments in the paper are provided in the [experiments folder](experiments). + +### Training +1. First follow the instructions provided here: https://github.com/Parskatt/DKM for downloading and preprocessing datasets. +2. Run the relevant experiment, e.g., +```bash +torchrun --nproc_per_node=4 --nnodes=1 --rdzv_backend=c10d experiments/roma_outdoor.py +``` +### Testing +```bash +python experiments/roma_outdoor.py --only_test --benchmark mega-1500 +``` +## License +All our code except DINOv2 is MIT license. +DINOv2 has an Apache 2 license [DINOv2](https://github.com/facebookresearch/dinov2/blob/main/LICENSE). + +## Acknowledgement +Our codebase builds on the code in [DKM](https://github.com/Parskatt/DKM). + +## BibTeX +If you find our models useful, please consider citing our paper! 
+``` +@article{edstedt2023roma, +title={{RoMa: Robust Dense Feature Matching}}, +author={Edstedt, Johan and Sun, Qiyu and Bökman, Georg and Wadenbäck, Mårten and Felsberg, Michael}, +journal={arXiv preprint arXiv:2305.15404}, +year={2023} +} +``` diff --git a/imcui/third_party/MatchAnything/third_party/ROMA/data/.gitignore b/imcui/third_party/MatchAnything/third_party/ROMA/data/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..c96a04f008ee21e260b28f7701595ed59e2839e3 --- /dev/null +++ b/imcui/third_party/MatchAnything/third_party/ROMA/data/.gitignore @@ -0,0 +1,2 @@ +* +!.gitignore \ No newline at end of file diff --git a/imcui/third_party/MatchAnything/third_party/ROMA/demo/demo_3D_effect.py b/imcui/third_party/MatchAnything/third_party/ROMA/demo/demo_3D_effect.py new file mode 100644 index 0000000000000000000000000000000000000000..5afd6e5ce0fdd32788160e8c24df0b26a27f34dd --- /dev/null +++ b/imcui/third_party/MatchAnything/third_party/ROMA/demo/demo_3D_effect.py @@ -0,0 +1,46 @@ +from PIL import Image +import torch +import torch.nn.functional as F +import numpy as np +from roma.utils.utils import tensor_to_pil + +from roma import roma_outdoor + +device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + + +if __name__ == "__main__": + from argparse import ArgumentParser + parser = ArgumentParser() + parser.add_argument("--im_A_path", default="assets/toronto_A.jpg", type=str) + parser.add_argument("--im_B_path", default="assets/toronto_B.jpg", type=str) + parser.add_argument("--save_path", default="demo/gif/roma_warp_toronto", type=str) + + args, _ = parser.parse_known_args() + im1_path = args.im_A_path + im2_path = args.im_B_path + save_path = args.save_path + + # Create model + roma_model = roma_outdoor(device=device, coarse_res=560, upsample_res=(864, 1152)) + roma_model.symmetric = False + + H, W = roma_model.get_output_resolution() + + im1 = Image.open(im1_path).resize((W, H)) + im2 = Image.open(im2_path).resize((W, 
# Demo: match an image pair with RoMa and fit a fundamental matrix to the
# sampled matches.
from PIL import Image
import torch
import cv2
from roma import roma_outdoor

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


if __name__ == "__main__":
    from argparse import ArgumentParser
    parser = ArgumentParser()
    parser.add_argument("--im_A_path", default="assets/sacre_coeur_A.jpg", type=str)
    parser.add_argument("--im_B_path", default="assets/sacre_coeur_B.jpg", type=str)

    # parse_known_args: unrecognized command-line options are ignored.
    args, _ = parser.parse_known_args()
    im1_path = args.im_A_path
    im2_path = args.im_B_path
    # Create model
    roma_model = roma_outdoor(device=device)


    # PIL reports size as (width, height).
    W_A, H_A = Image.open(im1_path).size
    W_B, H_B = Image.open(im2_path).size

    # Match
    warp, certainty = roma_model.match(im1_path, im2_path, device=device)
    # Sample matches for estimation
    matches, certainty = roma_model.sample(warp, certainty)
    kpts1, kpts2 = roma_model.to_pixel_coordinates(matches, H_A, W_A, H_B, W_B)
    # The result (F, mask) is computed but not saved or printed by this demo.
    F, mask = cv2.findFundamentalMat(
        kpts1.cpu().numpy(), kpts2.cpu().numpy(), ransacReprojThreshold=0.2, method=cv2.USAC_MAGSAC, confidence=0.999999, maxIters=10000
    )
# Demo: same as the fundamental-matrix demo, but going through the
# ROMA_Model adapter, which takes image paths in a dict and returns matched
# keypoints under 'mkpts0_f' / 'mkpts1_f'.
from PIL import Image
import torch
import cv2
import sys
from pathlib import Path
# Make the repository root (parent of this demo folder) importable so that
# `roma` resolves when it is not installed.
sys.path.append(str(Path(__file__).parent.parent.resolve()))
from roma.roma_adpat_model import ROMA_Model

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


if __name__ == "__main__":
    from argparse import ArgumentParser
    parser = ArgumentParser()
    parser.add_argument("--im_A_path", default="assets/sacre_coeur_A.jpg", type=str)
    parser.add_argument("--im_B_path", default="assets/sacre_coeur_B.jpg", type=str)

    # parse_known_args: unrecognized command-line options are ignored.
    args, _ = parser.parse_known_args()
    im1_path = args.im_A_path
    im2_path = args.im_B_path
    # Create model
    model = ROMA_Model({"n_sample": 5000})


    # NOTE: the image sizes are read but not used below.
    W_A, H_A = Image.open(im1_path).size
    W_B, H_B = Image.open(im2_path).size

    # Match
    match_results = model({"image0_path": im1_path, "image1_path": im2_path})
    kpts1, kpts2 = match_results['mkpts0_f'], match_results['mkpts1_f']
    # Fit a fundamental matrix to the matches; the result is not saved.
    F, mask = cv2.findFundamentalMat(
        kpts1.cpu().numpy(), kpts2.cpu().numpy(), ransacReprojThreshold=0.2, method=cv2.USAC_MAGSAC, confidence=0.999999, maxIters=10000
    )
# Demo: compute the dense RoMa warp between two images and save a
# certainty-weighted visualization of each image resampled with it.
from PIL import Image
import torch
import torch.nn.functional as F
import numpy as np
import sys
from pathlib import Path
# Make the repository root (parent of this demo folder) importable so that
# `roma` resolves when it is not installed.
sys.path.append(str(Path(__file__).parent.parent.resolve()))
from roma.utils.utils import tensor_to_pil

from roma import roma_outdoor

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


if __name__ == "__main__":
    from argparse import ArgumentParser
    parser = ArgumentParser()
    parser.add_argument("--im_A_path", default="assets/toronto_A.jpg", type=str)
    parser.add_argument("--im_B_path", default="assets/toronto_B.jpg", type=str)
    parser.add_argument("--save_path", default="demo/roma_warp_toronto.jpg", type=str)

    # parse_known_args: unrecognized command-line options are ignored.
    args, _ = parser.parse_known_args()
    im1_path = args.im_A_path
    im2_path = args.im_B_path
    save_path = args.save_path

    # Create model
    roma_model = roma_outdoor(device=device, coarse_res=560, upsample_res=(864, 1152))

    H, W = roma_model.get_output_resolution()

    # Resize both inputs to the model's output resolution (PIL takes (W, H)).
    im1 = Image.open(im1_path).resize((W, H))
    im2 = Image.open(im2_path).resize((W, H))

    # Match
    warp, certainty = roma_model.match(im1_path, im2_path, device=device)
    # Sampling not needed, but can be done with model.sample(warp, certainty)
    # HWC uint8 -> CHW float in [0, 1]
    x1 = (torch.tensor(np.array(im1)) / 255).to(device).permute(2, 0, 1)
    x2 = (torch.tensor(np.array(im2)) / 255).to(device).permute(2, 0, 1)

    # The first W columns of the warp sample image B with the last two
    # channels; the remaining columns sample image A with the first two.
    im2_transfer_rgb = F.grid_sample(
    x2[None], warp[:,:W, 2:][None], mode="bilinear", align_corners=False
    )[0]
    im1_transfer_rgb = F.grid_sample(
    x1[None], warp[:, W:, :2][None], mode="bilinear", align_corners=False
    )[0]
    warp_im = torch.cat((im2_transfer_rgb,im1_transfer_rgb),dim=2)
    white_im = torch.ones((H,2*W),device=device)
    # Blend towards white where the match certainty is low.
    vis_im = certainty * warp_im + (1 - certainty) * white_im
    tensor_to_pil(vis_im, unnormalize=False).save(save_path)
# Demo: SIFT + brute-force matching baseline, saved as an image of the
# matches that pass Lowe's ratio test.
from PIL import Image
import numpy as np

import numpy as np
import cv2 as cv
import matplotlib.pyplot as plt



if __name__ == "__main__":
    from argparse import ArgumentParser
    parser = ArgumentParser()
    parser.add_argument("--im_A_path", default="assets/toronto_A.jpg", type=str)
    parser.add_argument("--im_B_path", default="assets/toronto_B.jpg", type=str)
    # The option used to be parsed but ignored (output was hard-coded); the
    # default now equals the former hard-coded location.
    parser.add_argument("--save_path", default="demo/sift_matches.png", type=str)

    # parse_known_args: unrecognized command-line options are ignored.
    args, _ = parser.parse_known_args()
    im1_path = args.im_A_path
    im2_path = args.im_B_path
    save_path = args.save_path

    img1 = cv.imread(im1_path,cv.IMREAD_GRAYSCALE)          # queryImage
    img2 = cv.imread(im2_path,cv.IMREAD_GRAYSCALE)          # trainImage
    # cv.imread returns None instead of raising when a file cannot be read.
    for img_path, img in ((im1_path, img1), (im2_path, img2)):
        if img is None:
            raise FileNotFoundError(f"Could not read image: {img_path}")
    # Initiate SIFT detector
    sift = cv.SIFT_create()
    # find the keypoints and descriptors with SIFT
    kp1, des1 = sift.detectAndCompute(img1,None)
    kp2, des2 = sift.detectAndCompute(img2,None)
    # BFMatcher with default params
    bf = cv.BFMatcher()
    matches = bf.knnMatch(des1,des2,k=2)
    # Apply ratio test
    # cv.drawMatchesKnn expects list of lists as matches.
    good = [[m] for m, n in matches if m.distance < 0.75*n.distance]
    draw_params = dict(matchColor = (255,0,0), # draw matches in red color
                       singlePointColor = None,
                       flags = 2)

    img3 = cv.drawMatchesKnn(img1,kp1,img2,kp2,good,None,**draw_params)
    Image.fromarray(img3).save(save_path)
def _normalize_keypoints(kpts, K):
    """Convert pixel coordinates ``kpts`` (N, 2) to normalized camera coordinates.

    Subtracts the principal point ``(K[0, 2], K[1, 2])`` and divides by the
    focal lengths ``(K[0, 0], K[1, 1])``.
    """
    return (kpts - K[[0, 1], [2, 2]][None]) / K[[0, 1], [0, 1]][None]


def extract_geo_model_inliers(mkpts0, mkpts1, mconfs,
                              geo_model, ransac_method, pixel_thr, max_iters, conf_thr,
                              K0=None, K1=None):
    """Robustly fit an epipolar model to the matches and return the inlier mask.

    Args:
        mkpts0, mkpts1: (N, 2) arrays of matched keypoints in image 0 / image 1.
        mconfs: (N,) array of match confidences; only its shape is used, for
            the all-inlier fallback mask.
        geo_model: ``'E'`` (essential matrix, needs ``K0``/``K1``) or ``'F'``
            (fundamental matrix).
        ransac_method: ``'RANSAC'``, ``'DEGENSAC'`` (``'F'`` only) or ``'MAGSAC'``.
        pixel_thr: inlier threshold in pixels. For ``'E'`` it is divided by the
            mean focal length because the points are normalized first.
        max_iters: iteration budget passed to the estimator (not used by
            ``cv2.findEssentialMat``).
        conf_thr: confidence passed to the estimator.
        K0, K1: 3x3 intrinsics, required when ``geo_model == 'E'``.

    Returns:
        Boolean array of length N. If the estimator returns no mask, every
        match is kept.

    Raises:
        ValueError: on an unsupported ``ransac_method`` / ``geo_model``
            combination or missing intrinsics.
    """
    # Validate up front so that a bad configuration fails with a clear message
    # instead of an unbound `mask` further down.
    if ransac_method not in ('RANSAC', 'DEGENSAC', 'MAGSAC'):
        raise ValueError(f"Unknown ransac_method: {ransac_method!r}")
    if ransac_method == 'RANSAC' and geo_model not in ('E', 'F'):
        raise ValueError(f"RANSAC supports geo_model 'E' or 'F', got {geo_model!r}")
    if ransac_method == 'DEGENSAC' and geo_model != 'F':
        raise ValueError(f"DEGENSAC supports only geo_model 'F', got {geo_model!r}")

    if geo_model == 'E':
        if K0 is None or K1 is None:
            raise ValueError("geo_model 'E' requires both K0 and K1")
        # Mean of the four focal lengths (fx, fy of both cameras).
        f_mean = np.mean([K0[0, 0], K0[1, 1], K1[0, 0], K1[1, 1]])
        pixel_thr = pixel_thr / f_mean
        mkpts0 = _normalize_keypoints(mkpts0, K0)
        mkpts1 = _normalize_keypoints(mkpts1, K1)

    if ransac_method == 'RANSAC':
        if geo_model == 'E':
            _, mask = cv2.findEssentialMat(mkpts0,
                                           mkpts1,
                                           np.eye(3),
                                           threshold=pixel_thr,
                                           prob=conf_thr,
                                           method=cv2.RANSAC)
        else:
            _, mask = cv2.findFundamentalMat(mkpts0,
                                             mkpts1,
                                             method=cv2.FM_RANSAC,
                                             ransacReprojThreshold=pixel_thr,
                                             confidence=conf_thr,
                                             maxIters=max_iters)
    elif ransac_method == 'DEGENSAC':
        _, mask = pydegensac.findFundamentalMatrix(mkpts0,
                                                   mkpts1,
                                                   px_th=pixel_thr,
                                                   conf=conf_thr,
                                                   max_iters=max_iters)
    else:  # 'MAGSAC'
        _, mask = cv2.findFundamentalMat(mkpts0,
                                         mkpts1,
                                         method=cv2.USAC_MAGSAC,
                                         ransacReprojThreshold=pixel_thr,
                                         confidence=conf_thr,
                                         maxIters=max_iters)

    if mask is None:
        # Estimation failed: keep all matches. The builtin `bool` is used
        # because the `np.bool` alias no longer exists in current NumPy.
        return np.full_like(mconfs, True, dtype=bool)
    return mask.astype(bool).flatten()


def extract_inliers(data, args):
    """Extract inlier matches from a matcher output; assumes batch size 1.

    Reads ``mkpts0_f``, ``mkpts1_f`` and ``mconf`` from ``data`` (the same keys
    the ``__main__`` section of this file reads from the matcher output) and,
    when at least 8 matches are available, filters them with
    ``extract_geo_model_inliers`` using the settings on ``args``
    (``geo_model``, ``ransac_method``, ``pixel_thr``, ``max_iters``,
    ``conf_thr``).

    NOTE: If the estimator returns no mask, all matches are kept.

    Returns:
        Tuple ``(mkpts0, mkpts1, mconfs)`` of NumPy arrays.
    """
    mkpts0 = data['mkpts0_f'].cpu().numpy()
    mkpts1 = data['mkpts1_f'].cpu().numpy()
    mconfs = data['mconf'].cpu().numpy()
    # Intrinsics are only needed (and only read) for the essential-matrix model.
    K0 = data['K0'][0].cpu().numpy() if args.geo_model == 'E' else None
    K1 = data['K1'][0].cpu().numpy() if args.geo_model == 'E' else None
    if len(mkpts0) >= 8:
        inliers = extract_geo_model_inliers(mkpts0, mkpts1, mconfs,
                                            args.geo_model, args.ransac_method, args.pixel_thr,
                                            args.max_iters, args.conf_thr,
                                            K0=K0, K1=K1)
        mkpts0, mkpts1, mconfs = mkpts0[inliers], mkpts1[inliers], mconfs[inliers]
    return mkpts0, mkpts1, mconfs

# The default config uses dual-softmax.
# The outdoor and indoor models share the same config.
# You can change the default values like thr and coarse_match_type.
def _read_rgb(*candidate_paths):
    """Return the first readable image among ``candidate_paths`` as an RGB array.

    ``cv2.imread`` returns ``None`` for a missing or unreadable file, so the
    candidates are tried in order (e.g. an edited picture first, the raw
    picture as fallback).

    Raises:
        FileNotFoundError: if none of the candidates can be read.
    """
    for path in candidate_paths:
        image = cv2.imread(path)
        if image is not None:
            return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    raise FileNotFoundError(f"None of the images could be read: {candidate_paths}")


if __name__ == "__main__":
    # Only the ROMA matcher is importable from this file; the classes used by
    # the former SuperPoint+SuperGlue / LoFTR branches are not imported here.
    matching_method = 'ROMA'
    enable_geometric_verify = True

    if matching_method == 'ROMA':
        matcher = ROMA_Model({"n_sample": 5000, "load_img_in_model": True})
    else:
        raise NotImplementedError(f"Unsupported matching_method: {matching_method!r}")

    # Scene preset. The "_" variants are the pictures used for drawing; the
    # plain ones are what the matcher sees (and the drawing fallback).
    scene_name = 'sar'
    img0_pth = "/data/hexingyi/code/LoFTR/assets/rgb_pair_24_1.jpg"
    img0_pth_ = "/data/hexingyi/code/LoFTR/assets/rgb_pair_24_1_edited.jpg"
    img1_pth = "/data/hexingyi/code/LoFTR/assets/sar_pair24_2.jpg"
    img1_pth_ = "/data/hexingyi/code/LoFTR/assets/sar_pair24_2_edited.jpg"
    pixel_thr = 4.0

    img0_origin = _read_rgb(img0_pth_, img0_pth)
    img1_origin = _read_rgb(img1_pth_, img1_pth)

    # Run the matcher and pull the predictions back to NumPy.
    with torch.no_grad():
        batch = matcher({"image0_path": img0_pth, "image1_path": img1_pth})
        mkpts0 = batch['mkpts0_f'].cpu().numpy()
        mkpts1 = batch['mkpts1_f'].cpu().numpy()
        mconf = batch['mconf'].cpu().numpy()

        # Detector keypoints are optional in the matcher output.
        kpts0 = batch['keypoints0'][0].cpu().numpy() if "keypoints0" in batch else None
        kpts1 = batch['keypoints1'][0].cpu().numpy() if "keypoints1" in batch else None

    # Fundamental-matrix estimation needs at least 8 correspondences.
    if enable_geometric_verify and mkpts0.shape[0] >= 8:
        t0 = time()
        inliers = extract_geo_model_inliers(mkpts0, mkpts1, mconf,
                                            geo_model="F", ransac_method='DEGENSAC', pixel_thr=pixel_thr,
                                            max_iters=10000, conf_thr=0.99999,
                                            K0=None, K1=None)
        t1 = time()
        mkpts0, mkpts1, mconf = mkpts0[inliers], mkpts1[inliers], mconf[inliers]
        print(f"Ransac takes:{t1-t0}, num inlier:{mkpts0.shape[0]}")
    else:
        logger.info("Geometry Verify is not Performed.")

    # Draw
    alpha = 0.15
    plot_size_factor = 1
    color = cm.jet(mconf, alpha=alpha)

    if kpts0 is not None and kpts1 is not None:
        # Extra figure showing the detections when the matcher provides them.
        fig = make_matching_figure(
            img0_origin, img1_origin, mkpts0, mkpts1, color,
            kpts0=kpts0, kpts1=kpts1, text=[],
            draw_detection=True, draw_match_type=None,
            path=f"matching_horizontal_{matching_method}_{scene_name}_detection.jpg",
            vertical=True, plot_size_factor=plot_size_factor,
        )

    draw_match_type = "corres"
    fig = make_matching_figure(
        img0_origin, img1_origin, mkpts0, mkpts1, color, text=[],
        path=f"{scene_name}_{matching_method}_matching{'_ransac' if enable_geometric_verify else ''}_{draw_match_type}.jpg",
        vertical=True, plot_size_factor=plot_size_factor,
        draw_match_type=draw_match_type, r_normalize_factor=0.4,
    )
torch.distributed as dist +from torch.nn.parallel import DistributedDataParallel as DDP + +import json +import wandb +from tqdm import tqdm + +from roma.benchmarks import MegadepthDenseBenchmark +from roma.datasets.megadepth import MegadepthBuilder +from roma.datasets.scannet import ScanNetBuilder +from roma.losses.robust_loss import RobustLosses +from roma.benchmarks import MegadepthDenseBenchmark, ScanNetBenchmark +from roma.train.train import train_k_steps +from roma.models.matcher import * +from roma.models.transformer import Block, TransformerDecoder, MemEffAttention +from roma.models.encoders import * +from roma.checkpointing import CheckPoint + +resolutions = {"low":(448, 448), "medium":(14*8*5, 14*8*5), "high":(14*8*6, 14*8*6)} + +def get_model(pretrained_backbone=True, resolution = "medium", **kwargs): + gp_dim = 512 + feat_dim = 512 + decoder_dim = gp_dim + feat_dim + cls_to_coord_res = 64 + coordinate_decoder = TransformerDecoder( + nn.Sequential(*[Block(decoder_dim, 8, attn_class=MemEffAttention) for _ in range(5)]), + decoder_dim, + cls_to_coord_res**2 + 1, + is_classifier=True, + amp = True, + pos_enc = False,) + dw = True + hidden_blocks = 8 + kernel_size = 5 + displacement_emb = "linear" + disable_local_corr_grad = True + + conv_refiner = nn.ModuleDict( + { + "16": ConvRefiner( + 2 * 512+128+(2*7+1)**2, + 2 * 512+128+(2*7+1)**2, + 2 + 1, + kernel_size=kernel_size, + dw=dw, + hidden_blocks=hidden_blocks, + displacement_emb=displacement_emb, + displacement_emb_dim=128, + local_corr_radius = 7, + corr_in_other = True, + amp = True, + disable_local_corr_grad = disable_local_corr_grad, + bn_momentum = 0.01, + ), + "8": ConvRefiner( + 2 * 512+64+(2*3+1)**2, + 2 * 512+64+(2*3+1)**2, + 2 + 1, + kernel_size=kernel_size, + dw=dw, + hidden_blocks=hidden_blocks, + displacement_emb=displacement_emb, + displacement_emb_dim=64, + local_corr_radius = 3, + corr_in_other = True, + amp = True, + disable_local_corr_grad = disable_local_corr_grad, + bn_momentum = 0.01, 
+ ), + "4": ConvRefiner( + 2 * 256+32+(2*2+1)**2, + 2 * 256+32+(2*2+1)**2, + 2 + 1, + kernel_size=kernel_size, + dw=dw, + hidden_blocks=hidden_blocks, + displacement_emb=displacement_emb, + displacement_emb_dim=32, + local_corr_radius = 2, + corr_in_other = True, + amp = True, + disable_local_corr_grad = disable_local_corr_grad, + bn_momentum = 0.01, + ), + "2": ConvRefiner( + 2 * 64+16, + 128+16, + 2 + 1, + kernel_size=kernel_size, + dw=dw, + hidden_blocks=hidden_blocks, + displacement_emb=displacement_emb, + displacement_emb_dim=16, + amp = True, + disable_local_corr_grad = disable_local_corr_grad, + bn_momentum = 0.01, + ), + "1": ConvRefiner( + 2 * 9 + 6, + 24, + 2 + 1, + kernel_size=kernel_size, + dw=dw, + hidden_blocks = hidden_blocks, + displacement_emb = displacement_emb, + displacement_emb_dim = 6, + amp = True, + disable_local_corr_grad = disable_local_corr_grad, + bn_momentum = 0.01, + ), + } + ) + kernel_temperature = 0.2 + learn_temperature = False + no_cov = True + kernel = CosKernel + only_attention = False + basis = "fourier" + gp16 = GP( + kernel, + T=kernel_temperature, + learn_temperature=learn_temperature, + only_attention=only_attention, + gp_dim=gp_dim, + basis=basis, + no_cov=no_cov, + ) + gps = nn.ModuleDict({"16": gp16}) + proj16 = nn.Sequential(nn.Conv2d(1024, 512, 1, 1), nn.BatchNorm2d(512)) + proj8 = nn.Sequential(nn.Conv2d(512, 512, 1, 1), nn.BatchNorm2d(512)) + proj4 = nn.Sequential(nn.Conv2d(256, 256, 1, 1), nn.BatchNorm2d(256)) + proj2 = nn.Sequential(nn.Conv2d(128, 64, 1, 1), nn.BatchNorm2d(64)) + proj1 = nn.Sequential(nn.Conv2d(64, 9, 1, 1), nn.BatchNorm2d(9)) + proj = nn.ModuleDict({ + "16": proj16, + "8": proj8, + "4": proj4, + "2": proj2, + "1": proj1, + }) + displacement_dropout_p = 0.0 + gm_warp_dropout_p = 0.0 + decoder = Decoder(coordinate_decoder, + gps, + proj, + conv_refiner, + detach=True, + scales=["16", "8", "4", "2", "1"], + displacement_dropout_p = displacement_dropout_p, + gm_warp_dropout_p = gm_warp_dropout_p) + 
def train(args):
    """Train the indoor RoMa model with DDP on MegaDepth and ScanNet.

    Reads ``train_resolution``, ``dont_log_wandb``, ``gpu_batch_size`` and
    ``wandb_entity`` from ``args`` and ``WORLD_SIZE`` from the environment.
    Each outer iteration draws ``batch_size * k`` samples from both datasets,
    alternates single optimisation steps between them, then saves a checkpoint
    and logs the MegaDepth dense benchmark.
    """
    # Imported here as well so the function does not depend on the
    # `if __name__ == "__main__"` section having bound `roma` as a global.
    import roma

    dist.init_process_group('nccl')
    #torch._dynamo.config.verbose=True
    gpus = int(os.environ['WORLD_SIZE'])
    # create model and move it to GPU with id rank
    rank = dist.get_rank()
    print(f"Start running DDP on rank {rank}")
    device_id = rank % torch.cuda.device_count()
    roma.LOCAL_RANK = device_id
    torch.cuda.set_device(device_id)

    resolution = args.train_resolution
    wandb_log = not args.dont_log_wandb
    experiment_name = os.path.splitext(os.path.basename(__file__))[0]
    # NOTE: the trailing `and False` forces the mode to "disabled" whatever the
    # flags say; kept as found.
    wandb_mode = "online" if wandb_log and rank == 0 and False else "disabled"
    wandb.init(project="roma", entity=args.wandb_entity, name=experiment_name, reinit=False, mode = wandb_mode)
    checkpoint_dir = "workspace/checkpoints/"
    h,w = resolutions[resolution]
    model = get_model(pretrained_backbone=True, resolution=resolution, attenuate_cert = False).to(device_id)
    # Num steps
    global_step = 0
    batch_size = args.gpu_batch_size
    step_size = gpus*batch_size
    roma.STEP_SIZE = step_size

    N = (32 * 250000) # 250k steps of batch size 32
    # checkpoint every
    k = 25000 // roma.STEP_SIZE

    # Data
    mega = MegadepthBuilder(data_root="data/megadepth", loftr_ignore=True, imc21_ignore = True)
    use_horizontal_flip_aug = True
    rot_prob = 0
    depth_interpolation_mode = "bilinear"
    # The same split is built twice with different minimum overlaps and the
    # resulting scene lists are concatenated.
    megadepth_train1 = mega.build_scenes(
        split="train_loftr", min_overlap=0.01, shake_t=32, use_horizontal_flip_aug = use_horizontal_flip_aug, rot_prob = rot_prob,
        ht=h,wt=w,
    )
    megadepth_train2 = mega.build_scenes(
        split="train_loftr", min_overlap=0.35, shake_t=32, use_horizontal_flip_aug = use_horizontal_flip_aug, rot_prob = rot_prob,
        ht=h,wt=w,
    )
    megadepth_train = ConcatDataset(megadepth_train1 + megadepth_train2)
    mega_ws = mega.weight_scenes(megadepth_train, alpha=0.75)

    scannet = ScanNetBuilder(data_root="data/scannet")
    scannet_train = scannet.build_scenes(split="train", ht=h, wt=w, use_horizontal_flip_aug = use_horizontal_flip_aug)
    scannet_train = ConcatDataset(scannet_train)
    scannet_ws = scannet.weight_scenes(scannet_train, alpha=0.75)

    # Loss and optimizer (the ScanNet loss uses ce_weight=0.0, the MegaDepth
    # loss ce_weight=0.01; all other settings are identical).
    depth_loss_scannet = RobustLosses(
        ce_weight=0.0,
        local_dist={1:4, 2:4, 4:8, 8:8},
        local_largest_scale=8,
        depth_interpolation_mode=depth_interpolation_mode,
        alpha = 0.5,
        c = 1e-4,)
    depth_loss_mega = RobustLosses(
        ce_weight=0.01,
        local_dist={1:4, 2:4, 4:8, 8:8},
        local_largest_scale=8,
        depth_interpolation_mode=depth_interpolation_mode,
        alpha = 0.5,
        c = 1e-4,)
    # Learning rates are scaled linearly with the global batch size
    # (reference batch size 8); the encoder uses a smaller rate.
    parameters = [
        {"params": model.encoder.parameters(), "lr": roma.STEP_SIZE * 5e-6 / 8},
        {"params": model.decoder.parameters(), "lr": roma.STEP_SIZE * 1e-4 / 8},
    ]
    optimizer = torch.optim.AdamW(parameters, weight_decay=0.01)
    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=[(9*N/roma.STEP_SIZE)//10])
    megadense_benchmark = MegadepthDenseBenchmark("data/megadepth", num_samples = 1000, h=h,w=w)
    checkpointer = CheckPoint(checkpoint_dir, experiment_name)
    model, optimizer, lr_scheduler, global_step = checkpointer.load(model, optimizer, lr_scheduler, global_step)
    roma.GLOBAL_STEP = global_step
    ddp_model = DDP(model, device_ids=[device_id], find_unused_parameters = False, gradient_as_bucket_view=True)
    grad_scaler = torch.cuda.amp.GradScaler(growth_interval=1_000_000)
    grad_clip_norm = 0.01
    for n in range(roma.GLOBAL_STEP, N, k * roma.STEP_SIZE):
        mega_sampler = torch.utils.data.WeightedRandomSampler(
            mega_ws, num_samples = batch_size * k, replacement=False
        )
        mega_dataloader = iter(
            torch.utils.data.DataLoader(
                megadepth_train,
                batch_size = batch_size,
                sampler = mega_sampler,
                num_workers = 8,
            )
        )
        scannet_ws_sampler = torch.utils.data.WeightedRandomSampler(
            scannet_ws, num_samples=batch_size * k, replacement=False
        )
        scannet_dataloader = iter(
            torch.utils.data.DataLoader(
                scannet_train,
                batch_size=batch_size,
                sampler=scannet_ws_sampler,
                num_workers=gpus * 8,
            )
        )
        # Alternate one MegaDepth step and one ScanNet step.
        for n_k in tqdm(range(n, n + 2 * k, 2),disable = roma.RANK > 0):
            train_k_steps(
                n_k, 1, mega_dataloader, ddp_model, depth_loss_mega, optimizer, lr_scheduler, grad_scaler, grad_clip_norm = grad_clip_norm, progress_bar=False
            )
            train_k_steps(
                n_k + 1, 1, scannet_dataloader, ddp_model, depth_loss_scannet, optimizer, lr_scheduler, grad_scaler, grad_clip_norm = grad_clip_norm, progress_bar=False
            )
        checkpointer.save(model, optimizer, lr_scheduler, roma.GLOBAL_STEP)
        wandb.log(megadense_benchmark.benchmark(model), step = roma.GLOBAL_STEP)


def test_scannet(model, name, resolution, sample_mode):
    """Run the ScanNet benchmark on ``model`` and dump the results as JSON.

    The results are written to ``results/scannet_{name}.json``.
    ``resolution`` and ``sample_mode`` are accepted but not used here.
    """
    scannet_benchmark = ScanNetBenchmark("data/scannet")
    scannet_results = scannet_benchmark.benchmark(model)
    # Create the output directory if needed and close the file deterministically.
    os.makedirs("results", exist_ok=True)
    with open(f"results/scannet_{name}.json", "w") as f:
        json.dump(scannet_results, f)
os.path.splitext(os.path.basename(__file__))[0] + checkpoint_dir = "workspace/" + checkpoint_name = checkpoint_dir + experiment_name + ".pth" + test_resolution = "medium" + sample_mode = "threshold_balanced" + symmetric = True + upsample_preds = False + attenuate_cert = True + + model = get_model(pretrained_backbone=False, resolution = test_resolution, sample_mode = sample_mode, upsample_preds = upsample_preds, symmetric=symmetric, name=experiment_name, attenuate_cert = attenuate_cert) + model = model.cuda() + states = torch.load(checkpoint_name) + model.load_state_dict(states["model"]) + test_scannet(model, experiment_name, resolution = test_resolution, sample_mode = sample_mode) diff --git a/imcui/third_party/MatchAnything/third_party/ROMA/experiments/roma_outdoor.py b/imcui/third_party/MatchAnything/third_party/ROMA/experiments/roma_outdoor.py new file mode 100644 index 0000000000000000000000000000000000000000..2d58b3d8c3c5d8c13228bf3463885eae80990934 --- /dev/null +++ b/imcui/third_party/MatchAnything/third_party/ROMA/experiments/roma_outdoor.py @@ -0,0 +1,327 @@ +import os +import torch +from argparse import ArgumentParser + +from torch import nn +from torch.utils.data import ConcatDataset +import torch.distributed as dist +from torch.nn.parallel import DistributedDataParallel as DDP +import json +# import wandb + +from roma.benchmarks import MegadepthDenseBenchmark +from roma.datasets.megadepth import MegadepthBuilder +from roma.losses.robust_loss import RobustLosses +from roma.benchmarks import MegaDepthPoseEstimationBenchmark, MegadepthDenseBenchmark, HpatchesHomogBenchmark + +from roma.train.train import train_k_steps +from roma.models.matcher import * +from roma.models.transformer import Block, TransformerDecoder, MemEffAttention +from roma.models.encoders import * +from roma.checkpointing import CheckPoint + +resolutions = {"low":(448, 448), "medium":(14*8*5, 14*8*5), "high":(14*8*6, 14*8*6)} + +def get_model(pretrained_backbone=True, amp=True, 
coarse_resolution = (560, 560), coarse_backbone_type='DINOv2', coarse_feat_dim=1024, medium_feat_dim=512, coarse_patch_size=14, upsample_preds = False, symmetric=False, attenuate_cert=False, **kwargs): + import warnings + warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated') + gp_dim = medium_feat_dim + feat_dim = medium_feat_dim + decoder_dim = gp_dim + feat_dim + cls_to_coord_res = 64 + coordinate_decoder = TransformerDecoder( + nn.Sequential(*[Block(decoder_dim, 8, attn_class=MemEffAttention) for _ in range(5)]), + decoder_dim, + cls_to_coord_res**2 + 1, + is_classifier=True, + amp = amp, + pos_enc = False,) + dw = True + hidden_blocks = 8 + kernel_size = 5 + displacement_emb = "linear" + disable_local_corr_grad = True + + conv_refiner = nn.ModuleDict( + { + "16": ConvRefiner( + 2 * medium_feat_dim+128+(2*7+1)**2, + 2 * medium_feat_dim+128+(2*7+1)**2, + 2 + 1, + kernel_size=kernel_size, + dw=dw, + hidden_blocks=hidden_blocks, + displacement_emb=displacement_emb, + displacement_emb_dim=128, + local_corr_radius = 7, + corr_in_other = True, + amp = amp, + disable_local_corr_grad = disable_local_corr_grad, + bn_momentum = 0.01, + ), + "8": ConvRefiner( + 2 * medium_feat_dim+64+(2*3+1)**2, + 2 * medium_feat_dim+64+(2*3+1)**2, + 2 + 1, + kernel_size=kernel_size, + dw=dw, + hidden_blocks=hidden_blocks, + displacement_emb=displacement_emb, + displacement_emb_dim=64, + local_corr_radius = 3, + corr_in_other = True, + amp = amp, + disable_local_corr_grad = disable_local_corr_grad, + bn_momentum = 0.01, + ), + "4": ConvRefiner( + 2 * int(medium_feat_dim/2)+32+(2*2+1)**2, + 2 * int(medium_feat_dim/2)+32+(2*2+1)**2, + 2 + 1, + kernel_size=kernel_size, + dw=dw, + hidden_blocks=hidden_blocks, + displacement_emb=displacement_emb, + displacement_emb_dim=32, + local_corr_radius = 2, + corr_in_other = True, + amp = amp, + disable_local_corr_grad = disable_local_corr_grad, + bn_momentum = 0.01, + ), + "2": ConvRefiner( + 2 * 64+16, + 128+16, 
+ 2 + 1, + kernel_size=kernel_size, + dw=dw, + hidden_blocks=hidden_blocks, + displacement_emb=displacement_emb, + displacement_emb_dim=16, + amp = amp, + disable_local_corr_grad = disable_local_corr_grad, + bn_momentum = 0.01, + ), + "1": ConvRefiner( + 2 * 9 + 6, + 24, + 2 + 1, + kernel_size=kernel_size, + dw=dw, + hidden_blocks = hidden_blocks, + displacement_emb = displacement_emb, + displacement_emb_dim = 6, + amp = amp, + disable_local_corr_grad = disable_local_corr_grad, + bn_momentum = 0.01, + ), + } + ) + kernel_temperature = 0.2 + learn_temperature = False + no_cov = True + kernel = CosKernel + only_attention = False + basis = "fourier" + gp16 = GP( + kernel, + T=kernel_temperature, + learn_temperature=learn_temperature, + only_attention=only_attention, + gp_dim=gp_dim, + basis=basis, + no_cov=no_cov, + ) + gps = nn.ModuleDict({"16": gp16}) + proj16 = nn.Sequential(nn.Conv2d(coarse_feat_dim, medium_feat_dim, 1, 1), nn.BatchNorm2d(medium_feat_dim)) + proj8 = nn.Sequential(nn.Conv2d(512, medium_feat_dim, 1, 1), nn.BatchNorm2d(medium_feat_dim)) + proj4 = nn.Sequential(nn.Conv2d(256, int(medium_feat_dim/2), 1, 1), nn.BatchNorm2d(int(medium_feat_dim/2))) + proj2 = nn.Sequential(nn.Conv2d(128, 64, 1, 1), nn.BatchNorm2d(64)) + proj1 = nn.Sequential(nn.Conv2d(64, 9, 1, 1), nn.BatchNorm2d(9)) + proj = nn.ModuleDict({ + "16": proj16, + "8": proj8, + "4": proj4, + "2": proj2, + "1": proj1, + }) + displacement_dropout_p = 0.0 + gm_warp_dropout_p = 0.0 + decoder = Decoder(coordinate_decoder, + gps, + proj, + conv_refiner, + amp = amp, + detach=True, + scales=["16", "8", "4", "2", "1"], + displacement_dropout_p = displacement_dropout_p, + gm_warp_dropout_p = gm_warp_dropout_p) + h,w = coarse_resolution + encoder = CNNandDinov2( + cnn_kwargs = dict( + pretrained=pretrained_backbone, + amp = amp), + amp = amp, + use_vgg = True, + coarse_backbone=coarse_backbone_type, + coarse_patch_size=coarse_patch_size, + coarse_feat_dim=coarse_feat_dim, + ) + matcher = 
def train(args):
    """Train the outdoor RoMa model with DDP on MegaDepth.

    Reads ``train_resolution``, ``dont_log_wandb``, ``gpu_batch_size`` and
    ``wandb_entity`` from ``args`` and ``WORLD_SIZE`` from the environment.
    Each outer iteration trains ``k`` steps, saves a checkpoint and runs the
    MegaDepth dense benchmark (logged to wandb on rank 0 unless disabled).
    """
    # Imported here as well so the function does not depend on the
    # `if __name__ == "__main__"` section having bound `roma` as a global.
    import roma

    dist.init_process_group('nccl')
    #torch._dynamo.config.verbose=True
    gpus = int(os.environ['WORLD_SIZE'])
    # create model and move it to GPU with id rank
    rank = dist.get_rank()
    print(f"Start running DDP on rank {rank}")
    device_id = rank % torch.cuda.device_count()
    roma.LOCAL_RANK = device_id
    torch.cuda.set_device(device_id)

    resolution = args.train_resolution
    wandb_log = not args.dont_log_wandb
    experiment_name = os.path.splitext(os.path.basename(__file__))[0]
    wandb_mode = "online" if wandb_log and rank == 0 else "disabled"
    # The module-level `import wandb` is commented out in this file, so wandb
    # is imported lazily and only when logging is actually requested.
    wandb = None
    if wandb_mode == "online":
        import wandb
        wandb.init(project="roma", entity=args.wandb_entity, name=experiment_name, reinit=False, mode = wandb_mode)
    checkpoint_dir = "workspace/checkpoints/"
    h,w = resolutions[resolution]
    # `get_model` in this file takes `coarse_resolution`, not `resolution`;
    # the old keyword fell through **kwargs to RegressionMatcher.
    # NOTE(review): confirm RegressionMatcher does not expect `resolution`.
    model = get_model(pretrained_backbone=True, coarse_resolution=(h, w), attenuate_cert = False).to(device_id)
    # Num steps
    global_step = 0
    batch_size = args.gpu_batch_size
    step_size = gpus*batch_size
    roma.STEP_SIZE = step_size

    N = (32 * 250000) # 250k steps of batch size 32
    # checkpoint every
    k = 25000 // roma.STEP_SIZE

    # Data
    mega = MegadepthBuilder(data_root="data/megadepth", loftr_ignore=True, imc21_ignore = True)
    use_horizontal_flip_aug = True
    rot_prob = 0
    depth_interpolation_mode = "bilinear"
    # The same split is built twice with different minimum overlaps and the
    # resulting scene lists are concatenated.
    megadepth_train1 = mega.build_scenes(
        split="train_loftr", min_overlap=0.01, shake_t=32, use_horizontal_flip_aug = use_horizontal_flip_aug, rot_prob = rot_prob,
        ht=h,wt=w,
    )
    megadepth_train2 = mega.build_scenes(
        split="train_loftr", min_overlap=0.35, shake_t=32, use_horizontal_flip_aug = use_horizontal_flip_aug, rot_prob = rot_prob,
        ht=h,wt=w,
    )
    megadepth_train = ConcatDataset(megadepth_train1 + megadepth_train2)
    mega_ws = mega.weight_scenes(megadepth_train, alpha=0.75)
    # Loss and optimizer
    depth_loss = RobustLosses(
        ce_weight=0.01,
        local_dist={1:4, 2:4, 4:8, 8:8},
        local_largest_scale=8,
        depth_interpolation_mode=depth_interpolation_mode,
        alpha = 0.5,
        c = 1e-4,)
    # Learning rates are scaled linearly with the global batch size
    # (reference batch size 8); the encoder uses a smaller rate.
    parameters = [
        {"params": model.encoder.parameters(), "lr": roma.STEP_SIZE * 5e-6 / 8},
        {"params": model.decoder.parameters(), "lr": roma.STEP_SIZE * 1e-4 / 8},
    ]
    optimizer = torch.optim.AdamW(parameters, weight_decay=0.01)
    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=[(9*N/roma.STEP_SIZE)//10])
    megadense_benchmark = MegadepthDenseBenchmark("data/megadepth", num_samples = 1000, h=h,w=w)
    checkpointer = CheckPoint(checkpoint_dir, experiment_name)
    model, optimizer, lr_scheduler, global_step = checkpointer.load(model, optimizer, lr_scheduler, global_step)
    roma.GLOBAL_STEP = global_step
    ddp_model = DDP(model, device_ids=[device_id], find_unused_parameters = False, gradient_as_bucket_view=True)
    grad_scaler = torch.cuda.amp.GradScaler(growth_interval=1_000_000)
    grad_clip_norm = 0.01
    for n in range(roma.GLOBAL_STEP, N, k * roma.STEP_SIZE):
        mega_sampler = torch.utils.data.WeightedRandomSampler(
            mega_ws, num_samples = batch_size * k, replacement=False
        )
        mega_dataloader = iter(
            torch.utils.data.DataLoader(
                megadepth_train,
                batch_size = batch_size,
                sampler = mega_sampler,
                num_workers = 8,
            )
        )
        train_k_steps(
            n, k, mega_dataloader, ddp_model, depth_loss, optimizer, lr_scheduler, grad_scaler, grad_clip_norm = grad_clip_norm,
        )
        checkpointer.save(model, optimizer, lr_scheduler, roma.GLOBAL_STEP)
        # The benchmark still runs on every rank, as before; only the logging
        # call is conditional.
        benchmark_results = megadense_benchmark.benchmark(model)
        if wandb is not None:
            wandb.log(benchmark_results, step = roma.GLOBAL_STEP)


def _dump_results(results, path):
    """Write ``results`` as JSON to ``path``, creating the parent directory."""
    os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
    with open(path, "w") as f:
        json.dump(results, f)


def test_mega_8_scenes(model, name, resolution, sample_mode):
    """Run the MegaDepth pose benchmark on the 16 ``mega_8_scenes`` files.

    Prints the results and writes them to
    ``results/mega_8_scenes_{name}.json``. ``resolution`` and ``sample_mode``
    are accepted but not used here.
    """
    mega_8_scenes_benchmark = MegaDepthPoseEstimationBenchmark("data/megadepth",
                                                scene_names=['mega_8_scenes_0019_0.1_0.3.npz',
                                                    'mega_8_scenes_0025_0.1_0.3.npz',
                                                    'mega_8_scenes_0021_0.1_0.3.npz',
                                                    'mega_8_scenes_0008_0.1_0.3.npz',
                                                    'mega_8_scenes_0032_0.1_0.3.npz',
                                                    'mega_8_scenes_1589_0.1_0.3.npz',
                                                    'mega_8_scenes_0063_0.1_0.3.npz',
                                                    'mega_8_scenes_0024_0.1_0.3.npz',
                                                    'mega_8_scenes_0019_0.3_0.5.npz',
                                                    'mega_8_scenes_0025_0.3_0.5.npz',
                                                    'mega_8_scenes_0021_0.3_0.5.npz',
                                                    'mega_8_scenes_0008_0.3_0.5.npz',
                                                    'mega_8_scenes_0032_0.3_0.5.npz',
                                                    'mega_8_scenes_1589_0.3_0.5.npz',
                                                    'mega_8_scenes_0063_0.3_0.5.npz',
                                                    'mega_8_scenes_0024_0.3_0.5.npz'])
    mega_8_scenes_results = mega_8_scenes_benchmark.benchmark(model, model_name=name, scale_intrinsics = False)
    print(mega_8_scenes_results)
    _dump_results(mega_8_scenes_results, f"results/mega_8_scenes_{name}.json")


def test_mega1500(model, name, resolution, sample_mode):
    """Run the MegaDepth pose benchmark; write ``results/mega1500_{name}.json``.

    ``resolution`` and ``sample_mode`` are accepted but not used here.
    """
    mega1500_benchmark = MegaDepthPoseEstimationBenchmark("data/megadepth")
    mega1500_results = mega1500_benchmark.benchmark(model, model_name=name)
    _dump_results(mega1500_results, f"results/mega1500_{name}.json")


def test_mega_dense(model, name, resolution, sample_mode):
    """Run the MegaDepth dense benchmark; write ``results/mega_dense_{name}.json``.

    ``resolution`` and ``sample_mode`` are accepted but not used here.
    """
    megadense_benchmark = MegadepthDenseBenchmark("data/megadepth", num_samples = 1000)
    megadense_results = megadense_benchmark.benchmark(model)
    _dump_results(megadense_results, f"results/mega_dense_{name}.json")


def test_hpatches(model, name, resolution, sample_mode):
    """Run the HPatches homography benchmark; write ``results/hpatches_{name}.json``.

    ``resolution`` and ``sample_mode`` are accepted but not used here.
    """
    hpatches_benchmark = HpatchesHomogBenchmark("data/hpatches")
    hpatches_results = hpatches_benchmark.benchmark(model)
    _dump_results(hpatches_results, f"results/hpatches_{name}.json")
parser.add_argument("--gpu_batch_size", default=4, type=int) + parser.add_argument("--wandb_entity", required = False) + + args, _ = parser.parse_known_args() + roma.DEBUG_MODE = args.debug_mode + if not args.only_test: + train(args) + experiment_name = os.path.splitext(os.path.basename(__file__))[0] + checkpoint_dir = "workspace/checkpoints/" + checkpoint_name = checkpoint_dir + experiment_name + ".pth" + + test_resolution = "high" + sample_mode = "threshold_balanced" + symmetric = True + upsample_preds = upsample_preds + attenuate_cert = True + + model = get_model(pretrained_backbone=False, resolution = test_resolution, sample_mode = sample_mode, upsample_preds = upsample_preds, symmetric=symmetric, name=experiment_name, attenuate_cert = attenuate_cert) + model = model.cuda() + weights = torch.load(checkpoint_name) + model.load_state_dict(weights) + test_mega1500(model, experiment_name, resolution = test_resolution, sample_mode = sample_mode) diff --git a/imcui/third_party/MatchAnything/third_party/ROMA/notebooks/demo_single_pair.ipynb b/imcui/third_party/MatchAnything/third_party/ROMA/notebooks/demo_single_pair.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..0b67c712a2edc2a8ab509711c246b4a4a474c7ad --- /dev/null +++ b/imcui/third_party/MatchAnything/third_party/ROMA/notebooks/demo_single_pair.ipynb @@ -0,0 +1,247 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "source": [ + "# Demo LoFTR-DS on a single pair of images\n", + "\n", + "This notebook shows how to use the loftr matcher with default config(dual-softmax) and the pretrained weights." 
+ ], + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": 20, + "source": [ + "import os\n", + "os.chdir(\"..\")\n", + "from copy import deepcopy\n", + "\n", + "import torch\n", + "import cv2\n", + "import numpy as np\n", + "import matplotlib.cm as cm\n", + "from src.utils.plotting import make_matching_figure" + ], + "outputs": [], + "metadata": {} + }, + { + "cell_type": "markdown", + "source": [ + "## Indoor Example" + ], + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": 9, + "source": [ + "from src.loftr import LoFTR, default_cfg\n", + "\n", + "# The default config uses dual-softmax.\n", + "# The outdoor and indoor models share the same config.\n", + "# You can change the default values like thr and coarse_match_type.\n", + "_default_cfg = deepcopy(default_cfg)\n", + "_default_cfg['coarse']['temp_bug_fix'] = True # set to False when using the old ckpt\n", + "matcher = LoFTR(config=_default_cfg)\n", + "matcher.load_state_dict(torch.load(\"weights/indoor_ds_new.ckpt\")['state_dict'])\n", + "matcher = matcher.eval().cuda()" + ], + "outputs": [], + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": 10, + "source": [ + "# Load example images\n", + "img0_pth = \"assets/scannet_sample_images/scene0711_00_frame-001680.jpg\"\n", + "img1_pth = \"assets/scannet_sample_images/scene0711_00_frame-001995.jpg\"\n", + "img0_raw = cv2.imread(img0_pth, cv2.IMREAD_GRAYSCALE)\n", + "img1_raw = cv2.imread(img1_pth, cv2.IMREAD_GRAYSCALE)\n", + "img0_raw = cv2.resize(img0_raw, (640, 480))\n", + "img1_raw = cv2.resize(img1_raw, (640, 480))\n", + "\n", + "img0 = torch.from_numpy(img0_raw)[None][None].cuda() / 255.\n", + "img1 = torch.from_numpy(img1_raw)[None][None].cuda() / 255.\n", + "batch = {'image0': img0, 'image1': img1}\n", + "\n", + "# Inference with LoFTR and get prediction\n", + "with torch.no_grad():\n", + " matcher(batch)\n", + " mkpts0 = batch['mkpts0_f'].cpu().numpy()\n", + " mkpts1 = batch['mkpts1_f'].cpu().numpy()\n", 
+ " mconf = batch['mconf'].cpu().numpy()" + ], + "outputs": [], + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": 11, + "source": [ + "# Draw\n", + "color = cm.jet(mconf)\n", + "text = [\n", + " 'LoFTR',\n", + " 'Matches: {}'.format(len(mkpts0)),\n", + "]\n", + "fig = make_matching_figure(img0_raw, img1_raw, mkpts0, mkpts1, color, text=text)" + ], + "outputs": [ + { + "output_type": "display_data", + "data": { + "text/plain": [ + "
" + ], + "image/svg+xml": "\n\n\n\n \n \n \n \n 2021-08-18T00:38:02.543658\n image/svg+xml\n \n \n Matplotlib v3.3.4, https://matplotlib.org/\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n 
\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n 
\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n 
\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n 
\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\n", + "image/png": "" + }, + "metadata": {} + } + ], + "metadata": {} + }, + { + "cell_type": "markdown", + "source": [ + "## Outdoor Example" + ], + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": 16, + "source": [ + "from src.loftr import LoFTR, default_cfg\n", + "\n", + "# The default config uses dual-softmax.\n", + "# The outdoor and indoor models share the same config.\n", + "# You can change the default 
values like thr and coarse_match_type.\n", + "matcher = LoFTR(config=default_cfg)\n", + "matcher.load_state_dict(torch.load(\"weights/outdoor_ds.ckpt\")['state_dict'])\n", + "matcher = matcher.eval().cuda()" + ], + "outputs": [], + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": 19, + "source": [ + "default_cfg['coarse']" + ], + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "{'d_model': 256,\n", + " 'd_ffn': 256,\n", + " 'nhead': 8,\n", + " 'layer_names': ['self',\n", + " 'cross',\n", + " 'self',\n", + " 'cross',\n", + " 'self',\n", + " 'cross',\n", + " 'self',\n", + " 'cross'],\n", + " 'attention': 'linear',\n", + " 'temp_bug_fix': True}" + ] + }, + "metadata": {}, + "execution_count": 19 + } + ], + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": 17, + "source": [ + "# Load example images\n", + "img0_pth = \"assets/phototourism_sample_images/united_states_capitol_26757027_6717084061.jpg\"\n", + "img1_pth = \"assets/phototourism_sample_images/united_states_capitol_98169888_3347710852.jpg\"\n", + "img0_raw = cv2.imread(img0_pth, cv2.IMREAD_GRAYSCALE)\n", + "img1_raw = cv2.imread(img1_pth, cv2.IMREAD_GRAYSCALE)\n", + "img0_raw = cv2.resize(img0_raw, (img0_raw.shape[1]//8*8, img0_raw.shape[0]//8*8)) # input size shuold be divisible by 8\n", + "img1_raw = cv2.resize(img1_raw, (img1_raw.shape[1]//8*8, img1_raw.shape[0]//8*8))\n", + "\n", + "img0 = torch.from_numpy(img0_raw)[None][None].cuda() / 255.\n", + "img1 = torch.from_numpy(img1_raw)[None][None].cuda() / 255.\n", + "batch = {'image0': img0, 'image1': img1}\n", + "\n", + "# Inference with LoFTR and get prediction\n", + "with torch.no_grad():\n", + " matcher(batch)\n", + " mkpts0 = batch['mkpts0_f'].cpu().numpy()\n", + " mkpts1 = batch['mkpts1_f'].cpu().numpy()\n", + " mconf = batch['mconf'].cpu().numpy()" + ], + "outputs": [], + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": 18, + "source": [ + "# Draw\n", + "color = 
cm.jet(mconf)\n", + "text = [\n", + " 'LoFTR',\n", + " 'Matches: {}'.format(len(mkpts0)),\n", + "]\n", + "fig = make_matching_figure(img0_raw, img1_raw, mkpts0, mkpts1, color, text=text)" + ], + "outputs": [ + { + "output_type": "display_data", + "data": { + "text/plain": [ + "
" + ], + "image/svg+xml": "\n\n\n\n \n \n \n \n 2021-08-18T00:41:19.149192\n image/svg+xml\n \n \n Matplotlib v3.3.4, https://matplotlib.org/\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n 
\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n 
\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n 
\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\n", + "image/png": "" + }, + "metadata": {} + } + ], + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": null, + "source": [], + "outputs": [], + "metadata": {} + } + ], + "metadata": { + "kernelspec": { + "name": "python3", + "display_name": "Python 3.8.8 64-bit ('svcnn': conda)" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.8" + }, + "orig_nbformat": 2, + "interpreter": { + "hash": "5b8911f875a754a9ad2a8804064d078bf6a1985972bb0389b9d67771213c8e20" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} \ No newline at end of file diff --git a/imcui/third_party/MatchAnything/third_party/ROMA/notebooks/notebooks_utils/__init__.py b/imcui/third_party/MatchAnything/third_party/ROMA/notebooks/notebooks_utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..81512278dfdfa73dd0915defa732b3b0e7db6af6 --- /dev/null +++ 
b/imcui/third_party/MatchAnything/third_party/ROMA/notebooks/notebooks_utils/__init__.py @@ -0,0 +1 @@ +from .plotting import * \ No newline at end of file diff --git a/imcui/third_party/MatchAnything/third_party/ROMA/notebooks/notebooks_utils/plotting.py b/imcui/third_party/MatchAnything/third_party/ROMA/notebooks/notebooks_utils/plotting.py new file mode 100644 index 0000000000000000000000000000000000000000..fb577a8012a30b1cdbf3145dcb3210986e04b2c0 --- /dev/null +++ b/imcui/third_party/MatchAnything/third_party/ROMA/notebooks/notebooks_utils/plotting.py @@ -0,0 +1,331 @@ +import numpy as np +import matplotlib.pyplot as plt +import matplotlib +from matplotlib.colors import hsv_to_rgb +import pylab as pl +import matplotlib.cm as cm +import cv2 + + +def visualize_features(feat, img_h, img_w, save_path=None): + from sklearn.decomposition import PCA + pca = PCA(n_components=3, svd_solver="arpack") + img = pca.fit_transform(feat).reshape(img_h * 2, img_w, 3) + img_norm = cv2.normalize( + img, None, alpha=0, beta=255, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_8UC3 + ) + img_resized = cv2.resize( + img_norm, (img_w * 8, img_h * 2 * 8), interpolation=cv2.INTER_LINEAR + ) + img_colormap = img_resized + img1, img2 = img_colormap[: img_h * 8, :, :], img_colormap[img_h * 8 :, :, :] + img_gapped = np.hstack( + (img1, np.ones((img_h * 8, 10, 3), dtype=np.uint8) * 255, img2) + ) + if save_path is not None: + cv2.imwrite(save_path, img_gapped) + + fig, axes = plt.subplots(1, 1, dpi=200) + axes.imshow(img_gapped) + axes.get_yaxis().set_ticks([]) + axes.get_xaxis().set_ticks([]) + plt.tight_layout(pad=0.5) + return fig + + +def make_matching_figure( + img0, + img1, + mkpts0, + mkpts1, + color, + kpts0=None, + kpts1=None, + text=[], + path=None, + draw_detection=False, + draw_match_type='corres', # ['color', 'corres', None] + r_normalize_factor=0.4, + white_center=True, + vertical=False, + use_position_color=False, + draw_local_window=False, + window_size=(9, 9), + plot_size_factor=1, 
# Point size and line width + anchor_pts0=None, + anchor_pts1=None, +): + # draw image pair + fig, axes = ( + plt.subplots(2, 1, figsize=(10, 6), dpi=600) + if vertical + else plt.subplots(1, 2, figsize=(10, 6), dpi=600) + ) + axes[0].imshow(img0) + axes[1].imshow(img1) + for i in range(2): # clear all frames + axes[i].get_yaxis().set_ticks([]) + axes[i].get_xaxis().set_ticks([]) + for spine in axes[i].spines.values(): + spine.set_visible(False) + plt.tight_layout(pad=1) + + if use_position_color: + mean_coord = np.mean(mkpts0, axis=0) + x_center, y_center = mean_coord + # NOTE: set r_normalize_factor to a smaller number will make plotted figure more contrastive. + position_color = matching_coord2color( + mkpts0, + x_center, + y_center, + r_normalize_factor=r_normalize_factor, + white_center=white_center, + ) + color[:, :3] = position_color + + if draw_detection and kpts0 is not None and kpts1 is not None: + # color = 'g' + color = 'r' + axes[0].scatter(kpts0[:, 0], kpts0[:, 1], c=color, s=1 * plot_size_factor) + axes[1].scatter(kpts1[:, 0], kpts1[:, 1], c=color, s=1 * plot_size_factor) + + if draw_match_type is 'corres': + # draw matches + fig.canvas.draw() + transFigure = fig.transFigure.inverted() + fkpts0 = transFigure.transform(axes[0].transData.transform(mkpts0)) + fkpts1 = transFigure.transform(axes[1].transData.transform(mkpts1)) + fig.lines = [ + matplotlib.lines.Line2D( + (fkpts0[i, 0], fkpts1[i, 0]), + (fkpts0[i, 1], fkpts1[i, 1]), + transform=fig.transFigure, + c=color[i], + linewidth=1* plot_size_factor, + ) + for i in range(len(mkpts0)) + ] + + axes[0].scatter(mkpts0[:, 0], mkpts0[:, 1], c=color[:, :3], s=2* plot_size_factor) + axes[1].scatter(mkpts1[:, 0], mkpts1[:, 1], c=color[:, :3], s=2* plot_size_factor) + elif draw_match_type is 'color': + # x_center = img0.shape[-1] / 2 + # y_center = img1.shape[-2] / 2 + + mean_coord = np.mean(mkpts0, axis=0) + x_center, y_center = mean_coord + # NOTE: set r_normalize_factor to a smaller number will make 
def make_triple_matching_figure(
    img0,
    img1,
    img2,
    mkpts01,
    mkpts12,
    color01,
    color12,
    text=[],
    path=None,
    draw_match=True,
    r_normalize_factor=0.4,
    white_center=True,
    vertical=False,
    draw_local_window=False,
    window_size=(9, 9),
    anchor_pts0=None,
    anchor_pts1=None,
):
    """Plot three images and the matches between consecutive ones.

    Args:
        img0, img1, img2: Images handed to ``imshow``.
        mkpts01: Pair ``(points_in_img0, points_in_img1)``.
        mkpts12: Pair ``(points_in_img1, points_in_img2)``.
        color01, color12: Per-match colours, indexed as ``color[i]`` for the
            connecting line and ``color[:, :3]`` for the scatter points.
        path: If truthy, the figure is saved there and closed, and ``None``
            is returned. Otherwise the figure is returned.
        draw_match: Whether to draw the matches at all.
        vertical: Stack the images vertically instead of side by side.
        text, r_normalize_factor, white_center, draw_local_window,
        window_size, anchor_pts0, anchor_pts1: Accepted but not used by this
            function.

    Returns:
        The matplotlib figure, or ``None`` when ``path`` was given.
    """
    # draw image triple
    layout = (3, 1) if vertical else (1, 3)
    fig, axes = plt.subplots(layout[0], layout[1], figsize=(10, 6), dpi=600)
    for axis, image in zip(axes, (img0, img1, img2)):
        axis.imshow(image)
    for axis in axes[:3]:  # clear all frames
        axis.get_yaxis().set_ticks([])
        axis.get_xaxis().set_ticks([])
        for spine in axis.spines.values():
            spine.set_visible(False)
    plt.tight_layout(pad=1)

    if draw_match:
        fig.canvas.draw()
        # Computed once; both pairs are mapped with this same inverse transform.
        to_figure = fig.transFigure.inverted()

        def build_links(ax_src, ax_dst, pts_src, pts_dst, colors):
            """Return one figure-space Line2D per match between two axes."""
            src_xy = to_figure.transform(ax_src.transData.transform(pts_src))
            dst_xy = to_figure.transform(ax_dst.transData.transform(pts_dst))
            return [
                matplotlib.lines.Line2D(
                    (src_xy[k, 0], dst_xy[k, 0]),
                    (src_xy[k, 1], dst_xy[k, 1]),
                    transform=fig.transFigure,
                    c=colors[k],
                    linewidth=1,
                )
                for k in range(len(pts_src))
            ]

        # draw matches for [0,1]
        pts0, pts1 = mkpts01[0], mkpts01[1]
        fig.lines = build_links(axes[0], axes[1], pts0, pts1, color01)
        axes[0].scatter(pts0[:, 0], pts0[:, 1], c=color01[:, :3], s=1)
        axes[1].scatter(pts1[:, 0], pts1[:, 1], c=color01[:, :3], s=1)

        fig.canvas.draw()
        # draw matches for [1,2]
        pts1b, pts2 = mkpts12[0], mkpts12[1]
        fig.lines += build_links(axes[1], axes[2], pts1b, pts2, color12)
        axes[1].scatter(pts1b[:, 0], pts1b[:, 1], c=color12[:, :3], s=1)
        axes[2].scatter(pts2[:, 0], pts2[:, 1], c=color12[:, :3], s=1)

    plt.tight_layout(pad=0.1)

    # save or return figure
    if not path:
        return fig
    plt.savefig(str(path), bbox_inches="tight", pad_inches=0)
    plt.close()
def matching_coord2color(kpts, x_center, y_center, r_normalize_factor=0.4, white_center=True):
    """Colour keypoints by their polar position around a centre.

    A 10 x 360 HSV table is built: the column is the hue (one per degree of
    angle around the centre), the row is a radial bin. With
    ``white_center=True`` the row drives saturation (value fixed at 1);
    otherwise it drives value (saturation fixed at 1).

    r_normalize_factor is used to make the visualisation clearer for a given
    spatial distribution of points: radii are divided by
    ``r_max * r_normalize_factor`` (``r_max`` being the 85th percentile of the
    distances to the centre) and clipped to 1. Its maximum is 1; a larger
    value gives darker/brighter points.

    Args:
        kpts: Array indexed as ``[:, 0]`` (x) and ``[:, 1]`` (y); not modified.
        x_center, y_center: Centre of the polar coordinate system.
        r_normalize_factor: Radial normalisation factor, see above.
        white_center: Selects which HSV channel the radius controls.

    Returns:
        One RGB triple per keypoint, looked up in the table.
    """
    # Lookup table: ``ramp`` varies along the rows, ``H`` along the columns.
    ramp, H = np.mgrid[0:1:10j, 0:1:360j]
    constant = np.ones_like(ramp)
    if white_center:
        # white center points
        S, V = ramp, constant
    else:
        # dark center points
        S, V = constant, ramp
    RGB = hsv_to_rgb(np.dstack((H, S, V)))

    # To inspect the table:
    #   pl.imshow(RGB, origin="lower", extent=[0, 360, 0, 1], aspect=150)
    #   pl.xlabel("H"); pl.ylabel("S"); pl.title("$V_{HSV}=1$"); pl.show()

    kpts = np.copy(kpts)
    center_row = np.array([x_center, y_center])[None]
    r_max = np.percentile(np.linalg.norm(kpts - center_row, axis=1), 85)

    # Shift into centre-relative coordinates (written back into the copy).
    kpts[:, 0] = kpts[:, 0] - x_center  # x
    kpts[:, 1] = kpts[:, 1] - y_center  # y

    # The epsilon keeps the division below defined for a point at the centre.
    r = np.sqrt(kpts[:, 0] ** 2 + kpts[:, 1] ** 2) + 1e-6
    radial = r / (r_max * r_normalize_factor)
    radial[radial > 1] = 1
    radial = radial * 9  # scale to the row range 0..9

    theta = np.arccos(kpts[:, 0] / r)  # from 0 to pi
    lower_half = kpts[:, 1] < 0
    theta[lower_half] = 2 * np.pi - theta[lower_half]
    angle_deg = np.degrees(theta)
    angle_deg[angle_deg == 360] = 0  # to avoid overflow
    angle_deg = angle_deg / 360 * 360

    return RGB[radial.astype(int), angle_deg.astype(int)]
def show_image_pair(img0, img1, path=None):
    """Show two images side by side in grayscale, without ticks or frames.

    Args:
        img0, img1: Images handed to ``imshow`` with ``cmap="gray"``.
        path: If truthy, the figure is additionally saved there.

    Returns:
        The matplotlib figure (it is not closed).
    """
    fig, axes = plt.subplots(1, 2, figsize=(10, 6), dpi=200)
    for axis, image in zip(axes, (img0, img1)):
        axis.imshow(image, cmap="gray")
    for axis in axes[:2]:  # clear all frames
        axis.get_yaxis().set_ticks([])
        axis.get_xaxis().set_ticks([])
        for spine in axis.spines.values():
            spine.set_visible(False)
    plt.tight_layout(pad=1)
    if path:
        plt.savefig(str(path), bbox_inches="tight", pad_inches=0)
    return fig


def plot_local_windows(kpts, color="r", lw=1, ax_=0, window_size=(9, 9)):
    """Outline a rectangle around every keypoint on an axes of the current figure.

    Args:
        kpts: Iterable of points indexed as ``[0]`` (x) and ``[1]`` (y).
        color: Edge colour of the rectangles.
        lw: Line width of the rectangles.
        ax_: Index into ``plt.gcf().axes`` selecting the target axes.
        window_size: ``(width, height)``; each rectangle is one unit larger
            than this in both directions.
    """
    figure_axes = plt.gcf().axes
    for point in kpts:
        corner = (
            point[0] - (window_size[0] // 2) - 1,
            point[1] - (window_size[1] // 2) - 1,
        )
        outline = matplotlib.patches.Rectangle(
            corner,
            window_size[0] + 1,
            window_size[1] + 1,
            lw=lw,
            color=color,
            fill=False,
        )
        figure_axes[ax_].add_patch(outline)
b/imcui/third_party/MatchAnything/third_party/ROMA/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..f0dbab3d4cb35a5f00e3dbc8e3f8b00a3e578428 --- /dev/null +++ b/imcui/third_party/MatchAnything/third_party/ROMA/requirements.txt @@ -0,0 +1,13 @@ +torch +einops +torchvision +opencv-python +kornia +albumentations +loguru +tqdm +matplotlib +h5py +wandb +timm +#xformers # Optional, used for memefficient attention \ No newline at end of file diff --git a/imcui/third_party/MatchAnything/third_party/ROMA/roma/__init__.py b/imcui/third_party/MatchAnything/third_party/ROMA/roma/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a7c96481e0a808b68c7b3054a3e34fa0b5c45ab9 --- /dev/null +++ b/imcui/third_party/MatchAnything/third_party/ROMA/roma/__init__.py @@ -0,0 +1,8 @@ +import os +from .models import roma_outdoor, roma_indoor + +DEBUG_MODE = False +RANK = int(os.environ.get('RANK', default = 0)) +GLOBAL_STEP = 0 +STEP_SIZE = 1 +LOCAL_RANK = -1 \ No newline at end of file diff --git a/imcui/third_party/MatchAnything/third_party/ROMA/roma/benchmarks/__init__.py b/imcui/third_party/MatchAnything/third_party/ROMA/roma/benchmarks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..de7b841a5a6ab2ba91297a181a79dfaa91c9e104 --- /dev/null +++ b/imcui/third_party/MatchAnything/third_party/ROMA/roma/benchmarks/__init__.py @@ -0,0 +1,4 @@ +from .hpatches_sequences_homog_benchmark import HpatchesHomogBenchmark +from .scannet_benchmark import ScanNetBenchmark +from .megadepth_pose_estimation_benchmark import MegaDepthPoseEstimationBenchmark +from .megadepth_dense_benchmark import MegadepthDenseBenchmark diff --git a/imcui/third_party/MatchAnything/third_party/ROMA/roma/benchmarks/hpatches_sequences_homog_benchmark.py b/imcui/third_party/MatchAnything/third_party/ROMA/roma/benchmarks/hpatches_sequences_homog_benchmark.py new file mode 100644 index 
class HpatchesHomogBenchmark:
    """Homography-estimation benchmark over the HPatches sequences.

    Hpatches grid goes from [0,n-1] instead of [0.5,n-0.5]
    """

    def __init__(self, dataset_path) -> None:
        """Index the sequence folders below ``<dataset_path>/hpatches-sequences-release``."""
        seqs_dir = "hpatches-sequences-release"
        self.seqs_path = os.path.join(dataset_path, seqs_dir)
        self.seq_names = sorted(os.listdir(self.seqs_path))
        # Ignore seqs is same as LoFTR.
        # NOTE(review): this set is stored but benchmark() below does not
        # consult it -- confirm whether these sequences should be skipped.
        self.ignore_seqs = {
            "i_contruction",
            "i_crownnight",
            "i_dc",
            "i_pencils",
            "i_whitebuilding",
            "v_artisans",
            "v_astronautis",
            "v_talent",
        }

    @staticmethod
    def _image_size(path):
        """Return ``(width, height)`` of the image at *path* and close the file."""
        with Image.open(path) as im:
            return im.size

    def convert_coordinates(self, im_A_coords, im_A_to_im_B, wq, hq, wsup, hsup):
        """Map coordinates from the [-1, 1] range to HPatches pixel coordinates.

        Each coordinate ``c`` becomes ``size * (c + 1) / 2 - 0.5``, using
        ``(wq, hq)`` for ``im_A_coords`` and ``(wsup, hsup)`` for
        ``im_A_to_im_B``.

        Returns:
            Tuple ``(im_A_coords, im_A_to_im_B)`` in pixel units.
        """
        offset = 0.5  # Hpatches assumes that the center of the top-left pixel is at [0,0] (I think)
        im_A_coords = (
            np.stack(
                (
                    wq * (im_A_coords[..., 0] + 1) / 2,
                    hq * (im_A_coords[..., 1] + 1) / 2,
                ),
                axis=-1,
            )
            - offset
        )
        im_A_to_im_B = (
            np.stack(
                (
                    wsup * (im_A_to_im_B[..., 0] + 1) / 2,
                    hsup * (im_A_to_im_B[..., 1] + 1) / 2,
                ),
                axis=-1,
            )
            - offset
        )
        return im_A_coords, im_A_to_im_B

    def benchmark(self, model, model_name=None):
        """Estimate homographies with *model* and report corner-error AUCs.

        For every sequence, image ``1.ppm`` is matched against ``2.ppm`` ..
        ``6.ppm``. A homography is fitted with RANSAC on 5000 sampled matches
        and compared with the ground-truth ``H_1_<k>`` through the mean
        distance of the four warped image corners, normalised by
        ``min(w2, h2) / 480``.

        Args:
            model: Object providing ``match(path_a, path_b)`` and
                ``sample(matches, certainty, n)``.
            model_name: Unused.

        Returns:
            Dict with the AUC at thresholds 3, 5 and 10.
        """
        homog_dists = []
        for seq_name in tqdm(self.seq_names, total=len(self.seq_names)):
            im_A_path = os.path.join(self.seqs_path, seq_name, "1.ppm")
            w1, h1 = self._image_size(im_A_path)
            for im_idx in range(2, 7):
                im_B_path = os.path.join(self.seqs_path, seq_name, f"{im_idx}.ppm")
                w2, h2 = self._image_size(im_B_path)
                H = np.loadtxt(
                    os.path.join(self.seqs_path, seq_name, "H_1_" + str(im_idx))
                )
                dense_matches, dense_certainty = model.match(im_A_path, im_B_path)
                good_matches, _ = model.sample(dense_matches, dense_certainty, 5000)
                pos_a, pos_b = self.convert_coordinates(
                    good_matches[:, :2], good_matches[:, 2:], w1, h1, w2, h2
                )
                try:
                    H_pred, _ = cv2.findHomography(
                        pos_a,
                        pos_b,
                        method=cv2.RANSAC,
                        confidence=0.99999,
                        ransacReprojThreshold=3 * min(w2, h2) / 480,
                    )
                except Exception:
                    # Best effort: a failed fit is scored with the fallback
                    # below. ``Exception`` (not a bare except) keeps
                    # Ctrl-C / SystemExit working.
                    H_pred = None
                if H_pred is None:
                    # Degenerate fallback so the pair still contributes an error.
                    H_pred = np.zeros((3, 3))
                    H_pred[2, 2] = 1.0
                corners = np.array(
                    [[0, 0, 1], [0, h1 - 1, 1], [w1 - 1, 0, 1], [w1 - 1, h1 - 1, 1]]
                )
                real_warped_corners = np.dot(corners, np.transpose(H))
                real_warped_corners = (
                    real_warped_corners[:, :2] / real_warped_corners[:, 2:]
                )
                warped_corners = np.dot(corners, np.transpose(H_pred))
                warped_corners = warped_corners[:, :2] / warped_corners[:, 2:]
                mean_dist = np.mean(
                    np.linalg.norm(real_warped_corners - warped_corners, axis=1)
                ) / (min(w2, h2) / 480.0)
                homog_dists.append(mean_dist)

        thresholds = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
        auc = pose_auc(np.array(homog_dists), thresholds)
        return {
            "hpatches_homog_auc_3": auc[2],
            "hpatches_homog_auc_5": auc[4],
            "hpatches_homog_auc_10": auc[9],
        }
class MegadepthDenseBenchmark:
    """Dense-warp accuracy benchmark on the MegaDepth ``test_loftr`` split."""

    def __init__(self, data_root="data/megadepth", h = 384, w = 512, num_samples = 2000) -> None:
        """Build the test dataset at resolution ``(h, w)``.

        Args:
            data_root: Root folder passed to ``MegadepthBuilder``.
            h, w: Target height and width passed to ``build_scenes``.
            num_samples: Number of pairs drawn (without replacement) per run.
        """
        mega = MegadepthBuilder(data_root=data_root)
        self.dataset = ConcatDataset(
            mega.build_scenes(split="test_loftr", ht=h, wt=w)
        )  # fixed resolution of 384,512
        self.num_samples = num_samples

    def geometric_dist(self, depth1, depth2, T_1to2, K1, K2, dense_matches):
        """Compare a predicted dense warp with the depth-derived ground truth.

        ``dense_matches`` is unpacked as ``(b, h1, w1, d)``. Its first two
        channels are the source coordinates and the rest the predicted
        targets, both in the [-1, 1] range. Source points are warped with
        ``warp_kpts``; only locations whose returned mask equals 1 are scored.

        Returns:
            ``(gd, pck_1, pck_3, pck_5, prob)``: per-location pixel errors,
            the fraction below 1, 3 and 5 pixels, and the validity mask
            reshaped to ``(b, h1, w1)``.
        """
        b, h1, w1, d = dense_matches.shape
        with torch.no_grad():
            x1 = dense_matches[..., :2].reshape(b, h1 * w1, 2)
            mask, x2 = warp_kpts(
                x1.double(),
                depth1.double(),
                depth2.double(),
                T_1to2.double(),
                K1.double(),
                K2.double(),
            )
            # [-1, 1] -> pixel units
            x2 = torch.stack(
                (w1 * (x2[..., 0] + 1) / 2, h1 * (x2[..., 1] + 1) / 2), dim=-1
            )
            prob = mask.float().reshape(b, h1, w1)
        x2_hat = dense_matches[..., 2:]
        x2_hat = torch.stack(
            (w1 * (x2_hat[..., 0] + 1) / 2, h1 * (x2_hat[..., 1] + 1) / 2), dim=-1
        )
        gd = (x2_hat - x2.reshape(b, h1, w1, 2)).norm(dim=-1)
        gd = gd[prob == 1]
        pck_1 = (gd < 1.0).float().mean()
        pck_3 = (gd < 3.0).float().mean()
        pck_5 = (gd < 5.0).float().mean()
        return gd, pck_1, pck_3, pck_5, prob

    def _dump_debug_images(self, model, idx, im_A, im_B, matches, certainty):
        """Save the certainty-blended warp and both inputs for one batch under ``vis/``."""
        import os
        import torch.nn.functional as F
        from roma.utils.utils import tensor_to_pil

        path = "vis"
        H, W = model.get_output_resolution()
        # Size everything by the batch actually received: the last batch of
        # an epoch can be smaller than the configured batch size.
        cur_b = certainty.shape[0]
        white_im = torch.ones((cur_b, 1, H, W), device="cuda")
        warp_im = F.grid_sample(
            im_B.cuda(), matches[:, :, :W, 2:], mode="bilinear", align_corners=False
        )
        c_b = certainty[:, None]
        vis_im = c_b * warp_im + (1 - c_b) * white_im
        for b in range(cur_b):
            out_dir = f"{path}/{model.name}/{idx}_{b}_{H}_{W}"
            os.makedirs(out_dir, exist_ok=True)
            tensor_to_pil(vis_im[b], unnormalize=True).save(f"{out_dir}/warp.jpg")
            tensor_to_pil(im_A[b].cuda(), unnormalize=True).save(f"{out_dir}/im_A.jpg")
            tensor_to_pil(im_B[b].cuda(), unnormalize=True).save(f"{out_dir}/im_B.jpg")

    def benchmark(self, model, batch_size=8):
        """Run *model* on ``num_samples`` random pairs and average the errors.

        Args:
            model: Object providing ``train``, ``match(im_A, im_B, batched=True)``
                and, in debug mode, ``get_output_resolution`` and ``name``.
            batch_size: Batch size; also used as the number of loader workers.

        Returns:
            Dict with the mean end-point error (``epe``) and the PCK at 1, 3
            and 5 pixels, each averaged over the batches.
        """
        model.train(False)
        with torch.no_grad():
            gd_tot = 0.0
            pck_1_tot = 0.0
            pck_3_tot = 0.0
            pck_5_tot = 0.0
            sampler = torch.utils.data.WeightedRandomSampler(
                torch.ones(len(self.dataset)), replacement=False, num_samples=self.num_samples
            )
            dataloader = torch.utils.data.DataLoader(
                self.dataset, batch_size=batch_size, num_workers=batch_size, sampler=sampler
            )
            for idx, data in tqdm.tqdm(enumerate(dataloader), disable = roma.RANK > 0):
                im_A, im_B = data["im_A"], data["im_B"]
                depth1 = data["im_A_depth"].cuda()
                depth2 = data["im_B_depth"].cuda()
                T_1to2 = data["T_1to2"].cuda()
                K1 = data["K1"].cuda()
                K2 = data["K2"].cuda()
                matches, certainty = model.match(im_A, im_B, batched=True)
                gd, pck_1, pck_3, pck_5, prob = self.geometric_dist(
                    depth1, depth2, T_1to2, K1, K2, matches
                )
                if roma.DEBUG_MODE:
                    self._dump_debug_images(model, idx, im_A, im_B, matches, certainty)

                gd_tot = gd_tot + gd.mean()
                pck_1_tot = pck_1_tot + pck_1
                pck_3_tot = pck_3_tot + pck_3
                pck_5_tot = pck_5_tot + pck_5
        return {
            "epe": gd_tot.item() / len(dataloader),
            "mega_pck_1": pck_1_tot.item() / len(dataloader),
            "mega_pck_3": pck_3_tot.item() / len(dataloader),
            "mega_pck_5": pck_5_tot.item() / len(dataloader),
        }
class MegaDepthPoseEstimationBenchmark:
    """Relative-pose benchmark on pre-computed MegaDepth pair lists."""

    def __init__(self, data_root="data/megadepth", scene_names = None) -> None:
        """Load the pair-list archives.

        Args:
            data_root: Folder containing the ``.npz`` scene files and the
                images they refer to.
            scene_names: File names to load; defaults to the five
                ``0015`` / ``0022`` archives listed below.
        """
        if scene_names is None:
            self.scene_names = [
                "0015_0.1_0.3.npz",
                "0015_0.3_0.5.npz",
                "0022_0.1_0.3.npz",
                "0022_0.3_0.5.npz",
                "0022_0.5_0.7.npz",
            ]
        else:
            self.scene_names = scene_names
        # NOTE: allow_pickle=True unpickles the archive contents; only point
        # data_root at trusted files.
        self.scenes = [
            np.load(f"{data_root}/{scene}", allow_pickle=True)
            for scene in self.scene_names
        ]
        self.data_root = data_root

    @staticmethod
    def _image_size(path):
        """Return ``(width, height)`` of the image at *path* and close the file."""
        with Image.open(path) as im:
            return im.size

    def benchmark(self, model, model_name = None, resolution = None, scale_intrinsics = True, calibrated = True):
        """Estimate relative poses with *model* and report AUC / mAP.

        Every pair is matched once; 5000 sampled matches are converted from
        the [-1, 1] range to pixels, and the pose is estimated five times on
        differently shuffled matches. A failed estimate counts as a 90 degree
        error.

        Args:
            model: Object providing ``get_output_resolution``,
                ``match(path_a, path_b, K1, K2, T)`` and
                ``sample(matches, certainty, n)``.
            model_name: Only used in the printed summary.
            resolution: Unused.
            scale_intrinsics: Rescale image sizes and intrinsics so the longer
                image side is 1200 pixels.
            calibrated: Must be true; see Raises.

        Returns:
            Dict with ``auc_5/10/20`` and ``map_5/10/20``.

        Raises:
            NotImplementedError: If ``calibrated`` is false. That path has no
                estimator here, and without this check it would quietly score
                every pair as a 90 degree failure.
        """
        if not calibrated:
            raise NotImplementedError(
                "Uncalibrated pose estimation is not implemented for this benchmark"
            )
        # Result unused; the call is kept because its side effects (if any)
        # are not visible from here.
        model.get_output_resolution()
        with torch.no_grad():
            data_root = self.data_root
            tot_e_t, tot_e_R, tot_e_pose = [], [], []
            thresholds = [5, 10, 20]
            for scene in self.scenes:
                pairs = scene["pair_infos"]
                intrinsics = scene["intrinsics"]
                poses = scene["poses"]
                im_paths = scene["image_paths"]
                for pairind in tqdm(range(len(pairs))):
                    idx1, idx2 = pairs[pairind][0]
                    K1 = intrinsics[idx1].copy()
                    T1 = poses[idx1].copy()
                    R1, t1 = T1[:3, :3], T1[:3, 3]
                    K2 = intrinsics[idx2].copy()
                    T2 = poses[idx2].copy()
                    R2, t2 = T2[:3, :3], T2[:3, 3]
                    R, t = compute_relative_pose(R1, t1, R2, t2)
                    T1_to_2 = np.concatenate((R, t[:, None]), axis=-1)
                    im_A_path = f"{data_root}/{im_paths[idx1]}"
                    im_B_path = f"{data_root}/{im_paths[idx2]}"
                    dense_matches, dense_certainty = model.match(
                        im_A_path, im_B_path, K1.copy(), K2.copy(), T1_to_2.copy()
                    )
                    sparse_matches, _ = model.sample(
                        dense_matches, dense_certainty, 5000
                    )

                    w1, h1 = self._image_size(im_A_path)
                    w2, h2 = self._image_size(im_B_path)

                    if scale_intrinsics:
                        scale1 = 1200 / max(w1, h1)
                        scale2 = 1200 / max(w2, h2)
                        w1, h1 = scale1 * w1, scale1 * h1
                        w2, h2 = scale2 * w2, scale2 * h2
                        K1, K2 = K1.copy(), K2.copy()
                        K1[:2] = K1[:2] * scale1
                        K2[:2] = K2[:2] * scale2

                    # [-1, 1] -> pixel coordinates
                    kpts1 = sparse_matches[:, :2]
                    kpts1 = np.stack(
                        (
                            w1 * (kpts1[:, 0] + 1) / 2,
                            h1 * (kpts1[:, 1] + 1) / 2,
                        ),
                        axis=-1,
                    )
                    kpts2 = sparse_matches[:, 2:]
                    kpts2 = np.stack(
                        (
                            w2 * (kpts2[:, 0] + 1) / 2,
                            h2 * (kpts2[:, 1] + 1) / 2,
                        ),
                        axis=-1,
                    )

                    for _ in range(5):
                        shuffling = np.random.permutation(np.arange(len(kpts1)))
                        kpts1 = kpts1[shuffling]
                        kpts2 = kpts2[shuffling]
                        try:
                            threshold = 0.5
                            norm_threshold = threshold / (
                                np.mean(np.abs(K1[:2, :2])) + np.mean(np.abs(K2[:2, :2]))
                            )
                            R_est, t_est, mask = estimate_pose(
                                kpts1,
                                kpts2,
                                K1,
                                K2,
                                norm_threshold,
                                conf=0.99999,
                            )
                            T1_to_2_est = np.concatenate((R_est, t_est), axis=-1)
                            e_t, e_R = compute_pose_error(T1_to_2_est, R, t)
                            e_pose = max(e_t, e_R)
                        except Exception as e:
                            # Best effort: a failed estimate is scored as the
                            # worst case and does not abort the run.
                            print(repr(e))
                            e_t, e_R = 90, 90
                            e_pose = max(e_t, e_R)
                        tot_e_t.append(e_t)
                        tot_e_R.append(e_R)
                        tot_e_pose.append(e_pose)
            tot_e_pose = np.array(tot_e_pose)
            auc = pose_auc(tot_e_pose, thresholds)
            acc_5 = (tot_e_pose < 5).mean()
            acc_10 = (tot_e_pose < 10).mean()
            acc_15 = (tot_e_pose < 15).mean()
            acc_20 = (tot_e_pose < 20).mean()
            map_5 = acc_5
            map_10 = np.mean([acc_5, acc_10])
            map_20 = np.mean([acc_5, acc_10, acc_15, acc_20])
            print(f"{model_name} auc: {auc}")
            return {
                "auc_5": auc[0],
                "auc_10": auc[1],
                "auc_20": auc[2],
                "map_5": map_5,
                "map_10": map_10,
                "map_20": map_20,
            }
class ScanNetBenchmark:
    """Relative-pose benchmark on the ScanNet test pairs."""

    def __init__(self, data_root="data/scannet") -> None:
        """Remember the dataset root (expects ``test.npz`` and ``scans_test/`` below it)."""
        self.data_root = data_root

    @staticmethod
    def _image_size(path):
        """Return ``(width, height)`` of the image at *path* and close the file."""
        with Image.open(path) as im:
            return im.size

    def _load_intrinsics(self, scene_name):
        """Parse ``intrinsic/intrinsic_color.txt`` of a scene into a float matrix."""
        intrinsic_path = osp.join(
            self.data_root,
            "scans_test",
            scene_name,
            "intrinsic",
            "intrinsic_color.txt",
        )
        with open(intrinsic_path, "r") as f:
            rows = [row for row in f.read().split("\n") if row]
        return np.stack([np.array([float(v) for v in row.split()]) for row in rows])

    def benchmark(self, model, model_name = None):
        """Estimate relative poses with *model* and report AUC / mAP.

        Pairs from ``test.npz`` are visited in random order. For each pair
        5000 sampled matches are converted from the [-1, 1] range to pixels
        (images rescaled so the shorter side is 480), and the pose is
        estimated five times on differently shuffled matches. A failed
        estimate counts as a 90 degree error.

        Args:
            model: Object providing ``train``, ``match(path_a, path_b)`` and
                ``sample(matches, certainty, n)``.
            model_name: Unused.

        Returns:
            Dict with ``auc_5/10/20`` and ``map_5/10/20``.
        """
        model.train(False)
        with torch.no_grad():
            data_root = self.data_root
            tmp = np.load(osp.join(data_root, "test.npz"))
            pairs, rel_pose = tmp["name"], tmp["rel_pose"]
            tot_e_t, tot_e_R, tot_e_pose = [], [], []
            pair_inds = np.random.choice(
                range(len(pairs)), size=len(pairs), replace=False
            )
            for pairind in tqdm(pair_inds, smoothing=0.9):
                scene = pairs[pairind]
                scene_name = f"scene0{scene[0]}_00"
                color_dir = osp.join(self.data_root, "scans_test", scene_name, "color")
                im_A_path = osp.join(color_dir, f"{scene[2]}.jpg")
                im_B_path = osp.join(color_dir, f"{scene[3]}.jpg")
                T_gt = rel_pose[pairind].reshape(3, 4)
                R, t = T_gt[:3, :3], T_gt[:3, 3]
                K = self._load_intrinsics(scene_name)
                w1, h1 = self._image_size(im_A_path)
                w2, h2 = self._image_size(im_B_path)
                K1 = K.copy()
                K2 = K.copy()
                dense_matches, dense_certainty = model.match(im_A_path, im_B_path)
                sparse_matches, sparse_certainty = model.sample(
                    dense_matches, dense_certainty, 5000
                )
                scale1 = 480 / min(w1, h1)
                scale2 = 480 / min(w2, h2)
                w1, h1 = scale1 * w1, scale1 * h1
                w2, h2 = scale2 * w2, scale2 * h2
                K1 = K1 * scale1
                K2 = K2 * scale2

                # [-1, 1] -> pixel coordinates, shifted by half a pixel
                offset = 0.5
                kpts1 = sparse_matches[:, :2]
                kpts1 = np.stack(
                    (
                        w1 * (kpts1[:, 0] + 1) / 2 - offset,
                        h1 * (kpts1[:, 1] + 1) / 2 - offset,
                    ),
                    axis=-1,
                )
                kpts2 = sparse_matches[:, 2:]
                kpts2 = np.stack(
                    (
                        w2 * (kpts2[:, 0] + 1) / 2 - offset,
                        h2 * (kpts2[:, 1] + 1) / 2 - offset,
                    ),
                    axis=-1,
                )
                for _ in range(5):
                    shuffling = np.random.permutation(np.arange(len(kpts1)))
                    kpts1 = kpts1[shuffling]
                    kpts2 = kpts2[shuffling]
                    try:
                        norm_threshold = 0.5 / (
                            np.mean(np.abs(K1[:2, :2])) + np.mean(np.abs(K2[:2, :2])))
                        R_est, t_est, mask = estimate_pose(
                            kpts1,
                            kpts2,
                            K1,
                            K2,
                            norm_threshold,
                            conf=0.99999,
                        )
                        T1_to_2_est = np.concatenate((R_est, t_est), axis=-1)
                        e_t, e_R = compute_pose_error(T1_to_2_est, R, t)
                        e_pose = max(e_t, e_R)
                    except Exception as e:
                        # Best effort: a failed estimate is scored as the
                        # worst case and does not abort the run.
                        print(repr(e))
                        e_t, e_R = 90, 90
                        e_pose = max(e_t, e_R)
                    # Each trial is recorded exactly once so that no result
                    # is weighted more than another in the AUC.
                    tot_e_t.append(e_t)
                    tot_e_R.append(e_R)
                    tot_e_pose.append(e_pose)
            tot_e_pose = np.array(tot_e_pose)
            thresholds = [5, 10, 20]
            auc = pose_auc(tot_e_pose, thresholds)
            acc_5 = (tot_e_pose < 5).mean()
            acc_10 = (tot_e_pose < 10).mean()
            acc_15 = (tot_e_pose < 15).mean()
            acc_20 = (tot_e_pose < 20).mean()
            map_5 = acc_5
            map_10 = np.mean([acc_5, acc_10])
            map_20 = np.mean([acc_5, acc_10, acc_15, acc_20])
            return {
                "auc_5": auc[0],
                "auc_10": auc[1],
                "auc_20": auc[2],
                "map_5": map_5,
                "map_10": map_10,
                "map_20": map_20,
            }
class CheckPoint:
    """Save and restore the latest training state in ``<dir>/<name>_latest.pth``."""

    def __init__(self, dir=None, name="tmp"):
        """Create the checkpoint directory if needed.

        Args:
            dir: Target directory; ``None`` means the current directory.
            name: Prefix of the checkpoint file name.
        """
        self.name = name
        # os.makedirs(None) raises TypeError, so give the default a usable meaning.
        self.dir = os.curdir if dir is None else dir
        os.makedirs(self.dir, exist_ok=True)

    def _latest_path(self):
        """Return the checkpoint file path inside ``self.dir``.

        ``os.path.join`` keeps the file inside the directory that
        ``__init__`` created even when ``dir`` has no trailing separator.
        """
        return os.path.join(self.dir, self.name + "_latest.pth")

    def save(
        self,
        model,
        optimizer,
        lr_scheduler,
        n,
    ):
        """Write model, optimizer, scheduler state and step ``n`` (rank 0 only).

        A ``DataParallel`` / ``DistributedDataParallel`` wrapper is unwrapped
        first, so the stored keys belong to the underlying module.
        """
        if roma.RANK == 0:
            assert model is not None
            if isinstance(model, (DataParallel, DistributedDataParallel)):
                model = model.module
            states = {
                "model": model.state_dict(),
                "n": n,
                "optimizer": optimizer.state_dict(),
                "lr_scheduler": lr_scheduler.state_dict(),
            }
            torch.save(states, self._latest_path())
            logger.info(f"Saved states {list(states.keys())}, at step {n}")

    def load(
        self,
        model,
        optimizer,
        lr_scheduler,
        n,
    ):
        """Restore whatever the latest checkpoint contains (rank 0 only).

        Missing entries are skipped. A failure to restore the optimizer is
        reported and ignored. If no checkpoint exists the arguments are
        returned unchanged.

        Returns:
            ``(model, optimizer, lr_scheduler, n)``. ``n`` is the stored step
            when it is present and truthy, otherwise the value passed in.
        """
        path = self._latest_path()
        if os.path.exists(path) and roma.RANK == 0:
            # NOTE: torch.load unpickles the file; only load trusted checkpoints.
            states = torch.load(path)
            if "model" in states:
                model.load_state_dict(states["model"])
            if "n" in states:
                n = states["n"] if states["n"] else n
            if "optimizer" in states:
                try:
                    optimizer.load_state_dict(states["optimizer"])
                except Exception as e:
                    print(f"Failed to load states for optimizer, with error {e}")
            if "lr_scheduler" in states:
                lr_scheduler.load_state_dict(states["lr_scheduler"])
            print(f"Loaded states {list(states.keys())}, at step {n}")
            # Free the loaded tensors before training continues.
            del states
            gc.collect()
            torch.cuda.empty_cache()
        return model, optimizer, lr_scheduler, n
class MegadepthScene:
    """Image pairs of one MegaDepth scene with depth, intrinsics and relative pose."""

    def __init__(
        self,
        data_root,
        scene_info,
        ht=384,
        wt=512,
        min_overlap=0.0,
        max_overlap=1.0,
        shake_t=0,
        rot_prob=0.0,
        normalize=True,
        max_num_pairs = 100_000,
        scene_name = None,
        use_horizontal_flip_aug = False,
        use_single_horizontal_flip_aug = False,
        colorjiggle_params = None,
        random_eraser = None,
        use_randaug = False,
        randaug_params = None,
        randomize_size = False,
    ) -> None:
        """Select the usable pairs of a scene and set up the transforms.

        Args:
            data_root: Folder that image and depth paths are relative to.
            scene_info: Mapping with ``image_paths``, ``depth_paths``,
                ``intrinsics``, ``poses``, ``pairs`` and ``overlaps``.
            ht, wt: Output height and width.
            min_overlap, max_overlap: Pairs are kept when their overlap lies
                strictly between these bounds.
            shake_t: Maximum random translation (pixels) applied to both images.
            rot_prob: Accepted but not used by this class.
            normalize: Forwarded to the image transform.
            max_num_pairs: If more pairs remain, a random subset of this size
                is kept.
            scene_name: File name of the scene; its stem plus the overlap
                bounds becomes ``self.scene_name``.
            use_horizontal_flip_aug: Randomly flip both images together.
            use_single_horizontal_flip_aug: Randomly flip only image B.
            colorjiggle_params: Forwarded to the image transform.
            random_eraser: Optional callable ``(image, depth) -> (image, depth)``.
            use_randaug: Enables a call to ``self.rand_augment`` (see the
                note in ``__getitem__``).
            randaug_params: Accepted but not used by this class.
            randomize_size: Pick ``(ht, wt)``, a square of similar area or
                ``(wt, ht)`` depending on ``roma.RANK % 3``.

        Raises:
            ValueError: If both flip augmentations are requested.
        """
        self.data_root = data_root
        self.scene_name = os.path.splitext(scene_name)[0]+f"_{min_overlap}_{max_overlap}"
        self.image_paths = scene_info["image_paths"]
        self.depth_paths = scene_info["depth_paths"]
        self.intrinsics = scene_info["intrinsics"]
        self.poses = scene_info["poses"]
        self.pairs = scene_info["pairs"]
        self.overlaps = scene_info["overlaps"]
        threshold = (self.overlaps > min_overlap) & (self.overlaps < max_overlap)
        self.pairs = self.pairs[threshold]
        self.overlaps = self.overlaps[threshold]
        if len(self.pairs) > max_num_pairs:
            pairinds = np.random.choice(
                np.arange(0, len(self.pairs)), max_num_pairs, replace=False
            )
            self.pairs = self.pairs[pairinds]
            self.overlaps = self.overlaps[pairinds]
        if randomize_size:
            area = ht * wt
            # Side of a square with about the same area, rounded down to a multiple of 16.
            s = int(16 * (math.sqrt(area)//16))
            sizes = ((ht,wt), (s,s), (wt,ht))
            choice = roma.RANK % 3
            ht, wt = sizes[choice]
        self.im_transform_ops = get_tuple_transform_ops(
            resize=(ht, wt), normalize=normalize, colorjiggle_params = colorjiggle_params,
        )
        self.depth_transform_ops = get_depth_tuple_transform_ops(
            resize=(ht, wt)
        )
        self.wt, self.ht = wt, ht
        self.shake_t = shake_t
        self.random_eraser = random_eraser
        if use_horizontal_flip_aug and use_single_horizontal_flip_aug:
            raise ValueError("Can't both flip both images and only flip one")
        self.use_horizontal_flip_aug = use_horizontal_flip_aug
        self.use_single_horizontal_flip_aug = use_single_horizontal_flip_aug
        self.use_randaug = use_randaug

    def load_im(self, im_path):
        """Open the image at *im_path* with PIL and return it."""
        im = Image.open(im_path)
        return im

    def _flip_matrix(self, device):
        """Return the 3x3 matrix mapping x to ``wt - x`` in pixel coordinates."""
        return torch.tensor([[-1, 0, self.wt],[0,1,0],[0,0,1.]]).to(device)

    def horizontal_flip(self, im_A, im_B, depth_A, depth_B, K_A, K_B):
        """Mirror both images and depths along the last axis and adapt both intrinsics."""
        im_A = im_A.flip(-1)
        im_B = im_B.flip(-1)
        depth_A, depth_B = depth_A.flip(-1), depth_B.flip(-1)
        flip_mat = self._flip_matrix(K_A.device)
        K_A = flip_mat@K_A
        K_B = flip_mat@K_B

        return im_A, im_B, depth_A, depth_B, K_A, K_B

    def single_horizontal_flip(self, im_B, depth_B, K_B):
        """Mirror only image B and its depth, and adapt its intrinsics accordingly."""
        im_B = im_B.flip(-1)
        depth_B = depth_B.flip(-1)
        K_B = self._flip_matrix(K_B.device)@K_B
        return im_B, depth_B, K_B

    def load_depth(self, depth_ref, crop=None):
        """Read the ``depth`` dataset of an HDF5 file into a tensor.

        The file is closed before returning; ``crop`` is unused.
        """
        with h5py.File(depth_ref, "r") as depth_file:
            depth = np.array(depth_file["depth"])
        return torch.from_numpy(depth)

    def __len__(self):
        """Return the number of selected pairs."""
        return len(self.pairs)

    def scale_intrinsic(self, K, wi, hi):
        """Rescale intrinsics from an image of size ``(wi, hi)`` to ``(wt, ht)``."""
        sx, sy = self.wt / wi, self.ht / hi
        sK = torch.tensor([[sx, 0, 0], [0, sy, 0], [0, 0, 1]])
        return sK @ K

    def rand_shake(self, *things):
        """Translate all inputs by one random offset within ``[-shake_t, shake_t]``.

        Returns:
            ``(shifted_things, t)`` where ``t`` is the applied (x, y) offset.
        """
        t = np.random.choice(range(-self.shake_t, self.shake_t + 1), size=2)
        return [
            tvf.affine(thing, angle=0.0, translate=list(t), scale=1.0, shear=[0.0, 0.0])
            for thing in things
        ], t

    def __getitem__(self, pair_idx):
        """Load, transform and augment the pair at *pair_idx*.

        Returns:
            Dict with the two images and their identifiers, the depth maps,
            both intrinsics, the relative pose ``T_1to2`` and the image paths.
        """
        # read intrinsics of original size
        idx1, idx2 = self.pairs[pair_idx]
        K1 = torch.tensor(self.intrinsics[idx1].copy(), dtype=torch.float).reshape(3, 3)
        K2 = torch.tensor(self.intrinsics[idx2].copy(), dtype=torch.float).reshape(3, 3)

        # read and compute relative poses
        T1 = self.poses[idx1]
        T2 = self.poses[idx2]
        T_1to2 = torch.tensor(np.matmul(T2, np.linalg.inv(T1)), dtype=torch.float)[
            :4, :4
        ]  # (4, 4)

        # Load positive pair data
        im_A, im_B = self.image_paths[idx1], self.image_paths[idx2]
        depth1, depth2 = self.depth_paths[idx1], self.depth_paths[idx2]
        im_A_ref = os.path.join(self.data_root, im_A)
        im_B_ref = os.path.join(self.data_root, im_B)
        depth_A_ref = os.path.join(self.data_root, depth1)
        depth_B_ref = os.path.join(self.data_root, depth2)
        im_A = self.load_im(im_A_ref)
        im_B = self.load_im(im_B_ref)
        K1 = self.scale_intrinsic(K1, im_A.width, im_A.height)
        K2 = self.scale_intrinsic(K2, im_B.width, im_B.height)

        if self.use_randaug:
            # NOTE(review): rand_augment is not defined in the visible part of
            # this class; unless a subclass provides it, enabling use_randaug
            # raises AttributeError here.
            im_A, im_B = self.rand_augment(im_A, im_B)

        depth_A = self.load_depth(depth_A_ref)
        depth_B = self.load_depth(depth_B_ref)
        # Process images
        im_A, im_B = self.im_transform_ops((im_A, im_B))
        depth_A, depth_B = self.depth_transform_ops(
            (depth_A[None, None], depth_B[None, None])
        )

        [im_A, im_B, depth_A, depth_B], t = self.rand_shake(im_A, im_B, depth_A, depth_B)
        # The principal point moves with the image translation.
        K1[:2, 2] += t
        K2[:2, 2] += t

        im_A, im_B = im_A[None], im_B[None]
        if self.random_eraser is not None:
            im_A, depth_A = self.random_eraser(im_A, depth_A)
            im_B, depth_B = self.random_eraser(im_B, depth_B)

        if self.use_horizontal_flip_aug:
            if np.random.rand() > 0.5:
                im_A, im_B, depth_A, depth_B, K1, K2 = self.horizontal_flip(im_A, im_B, depth_A, depth_B, K1, K2)
        if self.use_single_horizontal_flip_aug:
            if np.random.rand() > 0.5:
                im_B, depth_B, K2 = self.single_horizontal_flip(im_B, depth_B, K2)

        if roma.DEBUG_MODE:
            tensor_to_pil(im_A[0], unnormalize=True).save(
                            f"vis/im_A.jpg")
            tensor_to_pil(im_B[0], unnormalize=True).save(
                            f"vis/im_B.jpg")

        data_dict = {
            "im_A": im_A[0],
            "im_A_identifier": self.image_paths[idx1].split("/")[-1].split(".jpg")[0],
            "im_B": im_B[0],
            "im_B_identifier": self.image_paths[idx2].split("/")[-1].split(".jpg")[0],
            "im_A_depth": depth_A[0, 0],
            "im_B_depth": depth_B[0, 0],
            "K1": K1,
            "K2": K2,
            "T_1to2": T_1to2,
            "im_A_path": im_A_ref,
            "im_B_path": im_B_ref,
        }
        return data_dict
class MegadepthBuilder:
    """Factory that turns prepared MegaDepth scene files into ``MegadepthScene`` objects."""

    def __init__(self, data_root="data/megadepth", loftr_ignore=True, imc21_ignore = True) -> None:
        """List the scene files under ``<data_root>/prep_scene_info``.

        Args:
            data_root: MegaDepth root folder.
            loftr_ignore: Skip the scenes in ``loftr_ignore_scenes``.
            imc21_ignore: Skip the scenes in ``imc21_scenes``.
        """
        self.data_root = data_root
        self.loftr_ignore = loftr_ignore
        self.imc21_ignore = imc21_ignore
        self.scene_info_root = os.path.join(data_root, "prep_scene_info")
        self.all_scenes = os.listdir(self.scene_info_root)
        self.test_scenes = ["0017.npy", "0004.npy", "0048.npy", "0013.npy"]
        self.test_scenes_loftr = ["0015.npy", "0022.npy"]
        # LoFTR did the D2-net preprocessing differently than we did and got
        # more ignore scenes; those can optionally be ignored.
        self.loftr_ignore_scenes = set(['0121.npy', '0133.npy', '0168.npy', '0178.npy', '0229.npy', '0349.npy', '0412.npy', '0430.npy', '0443.npy', '1001.npy', '5014.npy', '5015.npy', '5016.npy'])
        self.imc21_scenes = set(['0008.npy', '0019.npy', '0021.npy', '0024.npy', '0025.npy', '0032.npy', '0063.npy', '1589.npy'])

    def build_scenes(self, split="train", min_overlap=0.0, scene_names = None, **kwargs):
        """Instantiate one ``MegadepthScene`` per scene of the requested split.

        Args:
            split: ``"train"``, ``"train_loftr"``, ``"test"``,
                ``"test_loftr"`` or ``"custom"``.
            min_overlap: Forwarded to every scene.
            scene_names: Scene file names; only used when ``split`` is
                ``"custom"``.
            **kwargs: Forwarded to ``MegadepthScene``.

        Returns:
            List of scenes, without those excluded by the ignore flags.

        Raises:
            ValueError: If ``split`` is not one of the names above.
        """
        if split == "custom":
            selected = scene_names
        elif split == "train":
            selected = set(self.all_scenes) - set(self.test_scenes)
        elif split == "train_loftr":
            selected = set(self.all_scenes) - set(self.test_scenes_loftr)
        elif split == "test":
            selected = self.test_scenes
        elif split == "test_loftr":
            selected = self.test_scenes_loftr
        else:
            raise ValueError(f"Split {split} not available")

        built = []
        for scene_name in selected:
            skip_for_loftr = self.loftr_ignore and scene_name in self.loftr_ignore_scenes
            if skip_for_loftr:
                continue
            skip_for_imc21 = self.imc21_ignore and scene_name in self.imc21_scenes
            if skip_for_imc21:
                continue
            info_path = os.path.join(self.scene_info_root, scene_name)
            scene_info = np.load(info_path, allow_pickle=True).item()
            scene = MegadepthScene(
                self.data_root, scene_info, min_overlap=min_overlap,scene_name = scene_name, **kwargs
            )
            built.append(scene)
        return built

    def weight_scenes(self, concat_dataset, alpha=0.5):
        """Return one sampling weight per item: ``1 / n**alpha`` for a scene of ``n`` items.

        Args:
            concat_dataset: Object whose ``datasets`` attribute lists the scenes.
            alpha: Exponent; larger values down-weight large scenes more.

        Returns:
            1-D tensor with the weights of all scenes concatenated in order.
        """
        per_scene = [
            torch.ones(len(scene)) / len(scene) ** alpha
            for scene in concat_dataset.datasets
        ]
        return torch.cat(per_scene)
def scale_intrinsic(self, K, wi, hi):
    """Rescale intrinsics ``K`` from a ``wi`` x ``hi`` image to ``self.wt`` x ``self.ht``.

    Left-multiplies ``K`` by a diagonal matrix holding the horizontal and
    vertical resize ratios and returns the product.
    """
    width_ratio = self.wt / wi
    height_ratio = self.ht / hi
    resize = torch.tensor([
        [width_ratio, 0, 0],
        [0, height_ratio, 0],
        [0, 0, 1],
    ])
    return torch.matmul(resize, K)
+ """ + intrinsic = np.loadtxt(path, delimiter=' ') + return torch.tensor(intrinsic[:-1, :-1], dtype = torch.float) + + def __getitem__(self, pair_idx): + # read intrinsics of original size + data_name = self.data_names[pair_idx] + scene_name, scene_sub_name, stem_name_1, stem_name_2 = data_name + scene_name = f'scene{scene_name:04d}_{scene_sub_name:02d}' + + # read the intrinsic of depthmap + K1 = K2 = self.read_scannet_intrinsic(osp.join(self.scene_root, + scene_name, + 'intrinsic', 'intrinsic_color.txt'))#the depth K is not the same, but doesnt really matter + # read and compute relative poses + T1 = self.read_scannet_pose(osp.join(self.scene_root, + scene_name, + 'pose', f'{stem_name_1}.txt')) + T2 = self.read_scannet_pose(osp.join(self.scene_root, + scene_name, + 'pose', f'{stem_name_2}.txt')) + T_1to2 = torch.tensor(np.matmul(T2, np.linalg.inv(T1)), dtype=torch.float)[:4, :4] # (4, 4) + + # Load positive pair data + im_A_ref = os.path.join(self.scene_root, scene_name, 'color', f'{stem_name_1}.jpg') + im_B_ref = os.path.join(self.scene_root, scene_name, 'color', f'{stem_name_2}.jpg') + depth_A_ref = os.path.join(self.scene_root, scene_name, 'depth', f'{stem_name_1}.png') + depth_B_ref = os.path.join(self.scene_root, scene_name, 'depth', f'{stem_name_2}.png') + + im_A = self.load_im(im_A_ref) + im_B = self.load_im(im_B_ref) + depth_A = self.load_depth(depth_A_ref) + depth_B = self.load_depth(depth_B_ref) + + # Recompute camera intrinsic matrix due to the resize + K1 = self.scale_intrinsic(K1, im_A.width, im_A.height) + K2 = self.scale_intrinsic(K2, im_B.width, im_B.height) + # Process images + im_A, im_B = self.im_transform_ops((im_A, im_B)) + depth_A, depth_B = self.depth_transform_ops((depth_A[None,None], depth_B[None,None])) + if self.use_horizontal_flip_aug: + if np.random.rand() > 0.5: + im_A, im_B, depth_A, depth_B, K1, K2 = self.horizontal_flip(im_A, im_B, depth_A, depth_B, K1, K2) + + data_dict = {'im_A': im_A, + 'im_B': im_B, + 'im_A_depth': 
class ScanNetBuilder:
    """Build one ``ScanNetScene`` per index file in ``<data_root>/scannet_indices``."""

    def __init__(self, data_root = 'data/scannet') -> None:
        self.data_root = data_root
        self.scene_info_root = os.path.join(data_root, 'scannet_indices')
        self.all_scenes = os.listdir(self.scene_info_root)

    def build_scenes(self, split = 'train', min_overlap=0., **kwargs):
        """Load every index file and wrap it in a ``ScanNetScene``.

        ``split`` is accepted but not used: the same scenes are built for
        every split. ``min_overlap`` and ``**kwargs`` are forwarded to
        ``ScanNetScene``. The progress bar is disabled when ``roma.RANK > 0``.
        """
        progress = tqdm(self.all_scenes, disable = roma.RANK > 0)
        return [
            ScanNetScene(
                self.data_root,
                np.load(os.path.join(self.scene_info_root, index_file), allow_pickle=True),
                min_overlap=min_overlap,
                **kwargs,
            )
            for index_file in progress
        ]

    def weight_scenes(self, concat_dataset, alpha=.5):
        """Return per-sample weights ``1 / n**alpha``, concatenated over all datasets."""
        per_scene = []
        for dataset in concat_dataset.datasets:
            size = len(dataset)
            per_scene.append(torch.ones(size) / size**alpha)
        return torch.cat(per_scene)
class RobustLosses(nn.Module):
    """Multi-scale supervision for predicted warps and certainties.

    For every scale in ``corresps`` the loss adds a certainty term (binary
    cross-entropy against the ground-truth validity mask, weighted by
    ``ce_weight``) to either a classification loss over a regular grid of
    anchor coordinates or a robust regression loss on the end-point error.
    """

    def __init__(
        self,
        robust=False,
        center_coords=False,
        scale_normalize=False,
        ce_weight=0.01,
        local_loss=True,
        local_dist=4.0,
        local_largest_scale=8,
        smooth_mask = False,
        depth_interpolation_mode = "bilinear",
        mask_depth_loss = False,
        relative_depth_error_threshold = 0.05,
        alpha = 1.,
        c = 1e-3,
    ):
        super().__init__()
        self.robust = robust # measured in pixels
        self.center_coords = center_coords
        self.scale_normalize = scale_normalize
        self.ce_weight = ce_weight
        self.local_loss = local_loss
        # Either one number used for every scale or a mapping scale -> number.
        self.local_dist = local_dist
        self.local_largest_scale = local_largest_scale
        self.smooth_mask = smooth_mask
        self.depth_interpolation_mode = depth_interpolation_mode
        self.mask_depth_loss = mask_depth_loss
        self.relative_depth_error_threshold = relative_depth_error_threshold
        self.avg_overlap = dict()
        self.alpha = alpha
        self.c = c

    def _log(self, metrics):
        """Send ``metrics`` to wandb if this module has it imported.

        The file-level ``import wandb`` is commented out, so the name is
        looked up at call time: without it the losses are still computed
        instead of raising NameError, and logging resumes as soon as the
        import is re-enabled.
        """
        tracker = globals().get("wandb")
        if tracker is not None:
            tracker.log(metrics, step = roma.GLOBAL_STEP)

    def _local_dist_for(self, scale):
        """Return the local distance for ``scale``.

        ``local_dist`` may be indexable by scale or a plain number (the
        constructor default), in which case it applies to every scale.
        """
        try:
            return self.local_dist[scale]
        except TypeError:
            return self.local_dist

    @staticmethod
    def _class_centers(num_classes, device):
        """Return the ``(num_classes, 2)`` anchor coordinates of the classification grid.

        The grid has ``round(sqrt(num_classes))`` cells per axis with centers
        evenly spaced inside ``[-1, 1]``; the last dimension is ordered
        (second meshgrid axis, first meshgrid axis).
        """
        cls_res = round(math.sqrt(num_classes))
        axis = torch.linspace(-1+1/cls_res, 1 - 1/cls_res, steps = cls_res, device = device)
        grid = torch.meshgrid(axis, axis, indexing="ij")
        return torch.stack((grid[1], grid[0]), dim = -1).reshape(num_classes, 2)

    def gm_cls_loss(self, x2, prob, scale_gm_cls, gm_certainty, scale):
        """Classification loss of the global matcher plus its certainty loss.

        The target class of each position is the anchor closest to ``x2``.
        Only positions with ``prob > 0.99`` contribute to the class loss.

        Returns:
            dict with keys ``gm_certainty_loss_{scale}`` and ``gm_cls_loss_{scale}``.
        """
        with torch.no_grad():
            B, C, H, W = scale_gm_cls.shape
            G = self._class_centers(C, x2.device)
            GT = (G[None,:,None,None,:]-x2[:,None]).norm(dim=-1).min(dim=1).indices
        cls_loss = F.cross_entropy(scale_gm_cls, GT, reduction = 'none')[prob > 0.99]
        # Computed before the fallback below, which needs it.
        certainty_loss = F.binary_cross_entropy_with_logits(gm_certainty[:,0], prob)
        if not torch.any(cls_loss):
            cls_loss = (certainty_loss * 0.0) # Prevent issues where prob is 0 everywhere
        losses = {
            f"gm_certainty_loss_{scale}": certainty_loss.mean(),
            f"gm_cls_loss_{scale}": cls_loss.mean(),
        }
        self._log(losses)
        return losses

    def delta_cls_loss(self, x2, prob, flow_pre_delta, delta_cls, certainty, scale, offset_scale):
        """Classification loss of a refinement step plus its certainty loss.

        Anchors are the grid scaled by ``offset_scale`` and shifted by
        ``flow_pre_delta``; the target class is the anchor closest to ``x2``.
        Only positions with ``prob > 0.99`` contribute to the class loss.

        Returns:
            dict with keys ``delta_certainty_loss_{scale}`` and ``delta_cls_loss_{scale}``.
        """
        with torch.no_grad():
            B, C, H, W = delta_cls.shape
            G = self._class_centers(C, x2.device) * offset_scale
            GT = (G[None,:,None,None,:] + flow_pre_delta[:,None] - x2[:,None]).norm(dim=-1).min(dim=1).indices
        cls_loss = F.cross_entropy(delta_cls, GT, reduction = 'none')[prob > 0.99]
        # Computed before the fallback below, which needs it.
        certainty_loss = F.binary_cross_entropy_with_logits(certainty[:,0], prob)
        if not torch.any(cls_loss):
            cls_loss = (certainty_loss * 0.0) # Prevent issues where prob is 0 everywhere
        losses = {
            f"delta_certainty_loss_{scale}": certainty_loss.mean(),
            f"delta_cls_loss_{scale}": cls_loss.mean(),
        }
        self._log(losses)
        return losses

    def regression_loss(self, x2, prob, flow, certainty, scale, eps=1e-8, mode = "delta"):
        """Robust regression loss on the end-point error plus the certainty loss.

        Only positions with ``prob > 0.99`` contribute to the regression
        term. ``eps`` is accepted for interface compatibility and not used.

        Returns:
            dict with keys ``{mode}_certainty_loss_{scale}`` and
            ``{mode}_regression_loss_{scale}``.
        """
        epe = (flow.permute(0,2,3,1) - x2).norm(dim=-1)
        if scale == 1:
            pck_05 = (epe[prob > 0.99] < 0.5 * (2/512)).float().mean()
            self._log({"train_pck_05": pck_05})

        ce_loss = F.binary_cross_entropy_with_logits(certainty[:, 0], prob)
        a = self.alpha
        cs = self.c * scale
        x = epe[prob > 0.99]
        reg_loss = cs**a * ((x/(cs))**2 + 1**2)**(a/2)
        if not torch.any(reg_loss):
            reg_loss = (ce_loss * 0.0) # Prevent issues where prob is 0 everywhere
        losses = {
            f"{mode}_certainty_loss_{scale}": ce_loss.mean(),
            f"{mode}_regression_loss_{scale}": reg_loss.mean(),
        }
        self._log(losses)
        return losses

    def forward(self, corresps, batch):
        """Sum the losses of every scale in ``corresps``, in its iteration order.

        Args:
            corresps: mapping scale -> dict of predictions. "certainty",
                "flow_pre_delta" and "flow" are required; "delta_cls",
                "offset_scale", "gm_cls", "gm_certainty" and "gm_flow" are
                optional.
            batch: dict providing "im_A_depth", "im_B_depth", "T_1to2", "K1"
                and "K2", which are passed to ``get_gt_warp``.

        Returns:
            The accumulated total loss.
        """
        scales = list(corresps.keys())
        tot_loss = 0.0
        # scale_weights due to differences in scale for regression gradients and classification gradients
        scale_weights = {1:1, 2:1, 4:1, 8:1, 16:1}
        # End-point error of the previously processed scale; None until one exists.
        prev_epe = None
        for scale in scales:
            scale_corresps = corresps[scale]
            scale_certainty = scale_corresps["certainty"]
            flow_pre_delta = scale_corresps["flow_pre_delta"]
            delta_cls = scale_corresps.get("delta_cls")
            offset_scale = scale_corresps.get("offset_scale")
            scale_gm_cls = scale_corresps.get("gm_cls")
            scale_gm_certainty = scale_corresps.get("gm_certainty")
            flow = scale_corresps["flow"]
            scale_gm_flow = scale_corresps.get("gm_flow")

            flow_pre_delta = rearrange(flow_pre_delta, "b d h w -> b h w d")
            _, h, w, _ = flow_pre_delta.shape
            gt_warp, gt_prob = get_gt_warp(
                batch["im_A_depth"],
                batch["im_B_depth"],
                batch["T_1to2"],
                batch["K1"],
                batch["K2"],
                H=h,
                W=w,
            )
            x2 = gt_warp.float()
            prob = gt_prob

            # Restrict supervision to positions where the previous scale was
            # already close; skipped when there is no previous scale yet.
            if prev_epe is not None and self.local_largest_scale >= scale:
                prob = prob * (
                    F.interpolate(prev_epe[:, None], size=(h, w), mode="nearest-exact")[:, 0]
                    < (2 / 512) * (self._local_dist_for(scale) * scale))

            if scale_gm_cls is not None:
                gm_cls_losses = self.gm_cls_loss(x2, prob, scale_gm_cls, scale_gm_certainty, scale)
                gm_loss = self.ce_weight * gm_cls_losses[f"gm_certainty_loss_{scale}"] + gm_cls_losses[f"gm_cls_loss_{scale}"]
                tot_loss = tot_loss + scale_weights[scale] * gm_loss
            elif scale_gm_flow is not None:
                gm_flow_losses = self.regression_loss(x2, prob, scale_gm_flow, scale_gm_certainty, scale, mode = "gm")
                gm_loss = self.ce_weight * gm_flow_losses[f"gm_certainty_loss_{scale}"] + gm_flow_losses[f"gm_regression_loss_{scale}"]
                tot_loss = tot_loss + scale_weights[scale] * gm_loss

            if delta_cls is not None:
                delta_cls_losses = self.delta_cls_loss(x2, prob, flow_pre_delta, delta_cls, scale_certainty, scale, offset_scale)
                delta_cls_loss = self.ce_weight * delta_cls_losses[f"delta_certainty_loss_{scale}"] + delta_cls_losses[f"delta_cls_loss_{scale}"]
                tot_loss = tot_loss + scale_weights[scale] * delta_cls_loss
            else:
                delta_regression_losses = self.regression_loss(x2, prob, flow, scale_certainty, scale)
                reg_loss = self.ce_weight * delta_regression_losses[f"delta_certainty_loss_{scale}"] + delta_regression_losses[f"delta_regression_loss_{scale}"]
                tot_loss = tot_loss + scale_weights[scale] * reg_loss
            prev_epe = (flow.permute(0,2,3,1) - x2).norm(dim=-1).detach()
        return tot_loss
def forward_train_framework(self, data):
    """Run the matcher on an already resized/padded batch and store the results in ``data``.

    Reads ``data['image0']`` / ``data['image1']`` and the per-sample
    ``data['origin_img_size0']`` / ``data['origin_img_size1']``. Writes
    ``corresps`` (raw model output, kept for supervision) and the sampled
    matches ``m_bids``, ``mkpts0_f``, ``mkpts1_f`` and ``mconf``,
    concatenated over the batch.

    Returns:
        The updated ``data`` dict, matching what ``forward_inference`` returns.
    """
    # Get already resize & padded images by dataloader
    img0, img1 = data['image0'], data['image1'] # B * C * H * W
    corresps = self.model.forward({"im_A": img0, "im_B": img1}, batched=True)

    data.update({"corresps":corresps}) # for supervision

    warp, certainity = self.model.self_train_time_match(data, corresps) # batched and padded

    m_bids = []
    mkpts0_f = []
    mkpts1_f = []
    m_conf = []
    for b_id in range(warp.shape[0]):
        size0 = data["origin_img_size0"][b_id]
        size1 = data["origin_img_size1"][b_id]
        if self.resize_by_stretch:
            H_A, W_A = size0[0], size0[1]
            H_B, W_B = size1[0], size1[1]
        else:
            # By padding: both sides are set to the larger original edge.
            H_A = W_A = size0.max()
            H_B = W_B = size1.max()
        # Sample matches for estimation
        matches, certainity_ = self.model.sample(warp[b_id], certainity[b_id], num=self.config['sample']['n_sample'])
        kpts0, kpts1 = self.model.to_pixel_coordinates(matches, H_A, W_A, H_B, W_B)
        # Batch index of every sampled match.
        m_bids.append(torch.ones((kpts0.shape[0],), device=matches.device, dtype=torch.long) * b_id)
        mkpts0_f.append(kpts0)
        mkpts1_f.append(kpts1)
        m_conf.append(certainity_)
    data.update({'m_bids': torch.cat(m_bids), "mkpts0_f": torch.cat(mkpts0_f), "mkpts1_f": torch.cat(mkpts1_f), "mconf": torch.cat(m_conf)})
    return data
norm_img=self.norm_image) + + if self.resize_by_stretch: + H_A, W_A = img0.shape[-2], img0.shape[-1] + H_B, W_B = img1.shape[-2], img1.shape[-1] + else: + A_max_edge = max(img0.shape[-2:]) + H_A, W_A = A_max_edge, A_max_edge + B_max_edge = max(img1.shape[-2:]) + H_B, W_B = B_max_edge, B_max_edge + + # Sample matches for estimation + matches, certainity = self.model.sample(warp, dense_certainity, num=self.config['sample']['n_sample']) + kpts0, kpts1 = self.model.to_pixel_coordinates(matches, H_A, W_A, H_B, W_B) + + mask = certainity > self.config['match_thresh'] + # Mask borders: + mask *= (kpts0[:, 0] <= img0.shape[-1]-1) * (kpts0[:, 1] <= img0.shape[-2]-1) * (kpts1[:, 0] <= img1.shape[-1]-1) * (kpts1[:, 1] <= img1.shape[-2]-1) + data.update({'m_bids': torch.zeros_like(kpts0[:, 0])[mask], "mkpts0_f": kpts0[mask], "mkpts1_f": kpts1[mask], "mconf": certainity[mask]}) + + # Warp query points: + if 'query_points' in data: + detector_kpts0 = data['query_points'].to(torch.float32) # B * N * 2 + within_mask = (detector_kpts0[..., 0] >= 0) & (detector_kpts0[..., 0] <= (W_A - 1)) & (detector_kpts0[..., 1] >= 0) & (detector_kpts0[..., 1] <= (H_A - 1)) + internal_detector_kpts0 = detector_kpts0[within_mask] + warped_detector_kpts0, cert_A_to_B = self.model.warp_keypoints(internal_detector_kpts0, warp, dense_certainity, H_A, W_A, H_B, W_B) + data.update({"query_points_warpped": warped_detector_kpts0}) + return data + + def load_state_dict(self, state_dict, *args, **kwargs): + for k in list(state_dict.keys()): + if k.startswith('matcher.'): + state_dict[k.replace('matcher.', '', 1)] = state_dict.pop(k) + return super().load_state_dict(state_dict, *args, **kwargs) \ No newline at end of file diff --git a/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/__init__.py b/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5f20461e2f3a1722e558cefab94c5164be8842c3 --- /dev/null +++ 
b/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/__init__.py @@ -0,0 +1 @@ +from .model_zoo import roma_outdoor, roma_indoor \ No newline at end of file diff --git a/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/croco/blocks.py b/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/croco/blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..18133524f0ae265b0bd8d062d7c9eeaa63858a9b --- /dev/null +++ b/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/croco/blocks.py @@ -0,0 +1,241 @@ +# Copyright (C) 2022-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). + + +# -------------------------------------------------------- +# Main encoder/decoder blocks +# -------------------------------------------------------- +# References: +# timm +# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/helpers.py +# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py +# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/mlp.py +# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/patch_embed.py + + +import torch +import torch.nn as nn + +from itertools import repeat +import collections.abc + + +def _ntuple(n): + def parse(x): + if isinstance(x, collections.abc.Iterable) and not isinstance(x, str): + return x + return tuple(repeat(x, n)) + return parse +to_2tuple = _ntuple(2) + +def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + """ + if drop_prob == 0. 
class DropPath(nn.Module):
    """Module wrapper around ``drop_path`` (stochastic depth per sample).

    Dropping is only active while the module is in training mode, because
    ``self.training`` is forwarded to ``drop_path``.
    """

    def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True):
        super().__init__()
        self.scale_by_keep = scale_by_keep
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(
            x,
            drop_prob=self.drop_prob,
            training=self.training,
            scale_by_keep=self.scale_by_keep,
        )

    def extra_repr(self):
        # Shown inside the module's repr.
        return f'drop_prob={round(self.drop_prob,3):0.3f}'
def forward(self, x, xpos):
    """Multi-head self-attention over ``x`` of shape (B, N, C).

    ``xpos`` is only passed to ``self.rope`` (applied to queries and keys)
    when a rope module is set. Returns a tensor of shape (B, N, C).
    """
    batch, tokens, channels = x.shape
    head_dim = channels // self.num_heads

    packed = self.qkv(x).reshape(batch, tokens, 3, self.num_heads, head_dim)
    # (B, N, 3, heads, head_dim) -> (B, heads, 3, N, head_dim)
    packed = packed.transpose(1, 3)
    q, k, v = packed.unbind(dim=2)

    if self.rope is not None:
        q = self.rope(q, xpos)
        k = self.rope(k, xpos)

    scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
    weights = self.attn_drop(scores.softmax(dim=-1))

    merged = torch.matmul(weights, v).transpose(1, 2).reshape(batch, tokens, channels)
    return self.proj_drop(self.proj(merged))
class CrossAttention(nn.Module):
    """Multi-head cross-attention with separate query, key and value projections.

    An optional ``rope`` module is applied to the projected queries and keys
    together with their positions.
    """

    def __init__(self, dim, rope=None, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        self.scale = (dim // num_heads) ** -0.5

        self.projq = nn.Linear(dim, dim, bias=qkv_bias)
        self.projk = nn.Linear(dim, dim, bias=qkv_bias)
        self.projv = nn.Linear(dim, dim, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        self.rope = rope

    def forward(self, query, key, value, qpos, kpos):
        """Attend from ``query`` (B, Nq, C) to ``key`` / ``value``; returns (B, Nq, C)."""
        batch, n_query, channels = query.shape
        n_key = key.shape[1]
        n_value = value.shape[1]
        head_dim = channels // self.num_heads

        # Project, then move the head axis in front of the token axis.
        q = self.projq(query).reshape(batch, n_query, self.num_heads, head_dim).permute(0, 2, 1, 3)
        k = self.projk(key).reshape(batch, n_key, self.num_heads, head_dim).permute(0, 2, 1, 3)
        v = self.projv(value).reshape(batch, n_value, self.num_heads, head_dim).permute(0, 2, 1, 3)

        if self.rope is not None:
            q = self.rope(q, qpos)
            k = self.rope(k, kpos)

        scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        weights = self.attn_drop(scores.softmax(dim=-1))

        merged = torch.matmul(weights, v).transpose(1, 2).reshape(batch, n_query, channels)
        return self.proj_drop(self.proj(merged))
class PositionGetter(object):
    """ return positions of patches """

    def __init__(self):
        # Maps (h, w, str(device)) -> tensor of shape (h*w, 2) holding (y, x) pairs.
        self.cache_positions = {}

    def __call__(self, b, h, w, device):
        """Return the (y, x) index of every cell of an ``h`` x ``w`` grid.

        The result has shape (b, h*w, 2), lives on ``device`` and is a fresh
        copy, so callers may modify it without affecting the cache.
        """
        # The device is part of the key so that a grid cached on one device
        # is never handed out for another one.
        key = (h, w, str(device))
        if key not in self.cache_positions:
            x = torch.arange(w, device=device)
            y = torch.arange(h, device=device)
            self.cache_positions[key] = torch.cartesian_prod(y, x) # (h*w, 2)
        pos = self.cache_positions[key].view(1, h*w, 2).expand(b, -1, 2).clone()
        return pos
f"Input image height ({H}) doesn't match model ({self.img_size[0]}).") + torch._assert(W == self.img_size[1], f"Input image width ({W}) doesn't match model ({self.img_size[1]}).") + x = self.proj(x) + pos = self.position_getter(B, x.size(2), x.size(3), x.device) + if self.flatten: + x = x.flatten(2).transpose(1, 2) # BCHW -> BNC + x = self.norm(x) + return x, pos + + def _init_weights(self): + w = self.proj.weight.data + torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + diff --git a/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/croco/criterion.py b/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/croco/criterion.py new file mode 100644 index 0000000000000000000000000000000000000000..11696c40865344490f23796ea45e8fbd5e654731 --- /dev/null +++ b/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/croco/criterion.py @@ -0,0 +1,37 @@ +# Copyright (C) 2022-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 
class MaskedMSE(torch.nn.Module):
    """Mean squared error per patch, optionally averaged over masked patches only."""

    def __init__(self, norm_pix_loss=False, masked=True):
        """
        norm_pix_loss: normalize each target patch by its own mean and variance
        masked: average the loss over the masked patches only
        """
        super().__init__()
        self.masked = masked
        self.norm_pix_loss = norm_pix_loss

    def forward(self, pred, mask, target):
        """Return the scalar loss between ``pred`` and ``target``.

        ``mask`` weights the per-patch losses and is only used when
        ``self.masked`` is set.
        """
        if self.norm_pix_loss:
            patch_mean = target.mean(dim=-1, keepdim=True)
            patch_var = target.var(dim=-1, keepdim=True)
            target = (target - patch_mean) / (patch_var + 1.e-6)**.5

        # Last dimension is averaged away: one loss value per patch.
        per_patch = ((pred - target) ** 2).mean(dim=-1)
        if not self.masked:
            return per_patch.mean()  # mean loss over all patches
        return (per_patch * mask).sum() / mask.sum()  # mean loss on masked patches
class CroCoNet(nn.Module):
    """CroCo model used during pretraining.

    A first image is patch-embedded, randomly masked and encoded; a second
    image is encoded without masking. A decoder then processes the first
    image's tokens (with mask tokens re-inserted) together with the second
    image's tokens, and a linear head predicts the pixels of every patch of
    the first image.

    The ``_set_*`` methods are separate hooks so that subclasses can override
    them to drop parts of the network.
    """

    def __init__(self,
                 img_size=224,           # input image size
                 patch_size=16,          # patch_size
                 mask_ratio=0.9,         # ratios of masked tokens
                 enc_embed_dim=768,      # encoder feature dimension
                 enc_depth=12,           # encoder depth
                 enc_num_heads=12,       # encoder number of heads in the transformer block
                 dec_embed_dim=512,      # decoder feature dimension
                 dec_depth=8,            # decoder depth
                 dec_num_heads=16,       # decoder number of heads in the transformer block
                 mlp_ratio=4,
                 norm_layer=partial(nn.LayerNorm, eps=1e-6),
                 norm_im2_in_dec=True,   # whether to apply normalization of the 'memory' = (second image) in the decoder
                 pos_embed='cosine',     # positional embedding (either cosine or RoPE100)
                 ):
        super().__init__()

        # patch embeddings (with initialization done as in MAE)
        self._set_patch_embed(img_size, patch_size, enc_embed_dim)

        # mask generations
        self._set_mask_generator(self.patch_embed.num_patches, mask_ratio)

        self.pos_embed = pos_embed
        if pos_embed == 'cosine':
            grid_size = int(self.patch_embed.num_patches ** .5)
            # positional embedding of the encoder
            enc_pos_embed = get_2d_sincos_pos_embed(enc_embed_dim, grid_size, n_cls_token=0)
            self.register_buffer('enc_pos_embed', torch.from_numpy(enc_pos_embed).float())
            # positional embedding of the decoder
            dec_pos_embed = get_2d_sincos_pos_embed(dec_embed_dim, grid_size, n_cls_token=0)
            self.register_buffer('dec_pos_embed', torch.from_numpy(dec_pos_embed).float())
            # pos embedding in each block
            self.rope = None  # nothing for cosine
        elif pos_embed.startswith('RoPE'):  # eg RoPE100
            self.enc_pos_embed = None  # nothing to add in the encoder with RoPE
            self.dec_pos_embed = None  # nothing to add in the decoder with RoPE
            if RoPE2D is None:
                raise ImportError("Cannot find cuRoPE2D, please install it following the README instructions")
            # the suffix after 'RoPE' is the frequency base, eg 'RoPE100' -> 100.0
            freq = float(pos_embed[len('RoPE'):])
            self.rope = RoPE2D(freq=freq)
        else:
            raise NotImplementedError('Unknown pos_embed '+pos_embed)

        # transformer for the encoder
        self.enc_depth = enc_depth
        self.enc_embed_dim = enc_embed_dim
        self.enc_blocks = nn.ModuleList([
            Block(enc_embed_dim, enc_num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer, rope=self.rope)
            for i in range(enc_depth)])
        self.enc_norm = norm_layer(enc_embed_dim)

        # masked tokens
        self._set_mask_token(dec_embed_dim)

        # decoder
        self._set_decoder(enc_embed_dim, dec_embed_dim, dec_num_heads, dec_depth, mlp_ratio, norm_layer, norm_im2_in_dec)

        # prediction head
        self._set_prediction_head(dec_embed_dim, patch_size)

        # initializer weights
        self.initialize_weights()

    @property
    def device(self):
        # NOTE(review): relies on the patch-embedding module exposing a
        # `device` attribute -- confirm against PatchEmbed.
        return self.patch_embed.device

    def _set_patch_embed(self, img_size=224, patch_size=16, enc_embed_dim=768):
        """Create the patch-embedding layer (3 input channels)."""
        self.patch_embed = PatchEmbed(img_size, patch_size, 3, enc_embed_dim)

    def _set_mask_generator(self, num_patches, mask_ratio):
        """Create the random mask generator applied to the first image."""
        self.mask_generator = RandomMask(num_patches, mask_ratio)

    def _set_mask_token(self, dec_embed_dim):
        """Create the learnable token that stands in for masked patches."""
        self.mask_token = nn.Parameter(torch.zeros(1, 1, dec_embed_dim))

    def _set_decoder(self, enc_embed_dim, dec_embed_dim, dec_num_heads, dec_depth, mlp_ratio, norm_layer, norm_im2_in_dec):
        """Create the encoder-to-decoder projection, decoder blocks and final norm."""
        self.dec_depth = dec_depth
        self.dec_embed_dim = dec_embed_dim
        # transfer from encoder to decoder
        self.decoder_embed = nn.Linear(enc_embed_dim, dec_embed_dim, bias=True)
        # transformer for the decoder
        self.dec_blocks = nn.ModuleList([
            DecoderBlock(dec_embed_dim, dec_num_heads, mlp_ratio=mlp_ratio, qkv_bias=True, norm_layer=norm_layer, norm_mem=norm_im2_in_dec, rope=self.rope)
            for i in range(dec_depth)])
        # final norm layer
        self.dec_norm = norm_layer(dec_embed_dim)

    def _set_prediction_head(self, dec_embed_dim, patch_size):
        """Create the linear head predicting 3 * patch_size**2 values per token."""
        self.prediction_head = nn.Linear(dec_embed_dim, patch_size**2 * 3, bias=True)

    def initialize_weights(self):
        """Initialize the patch embedding, the mask token, then all linears and layer norms."""
        # patch embed
        self.patch_embed._init_weights()
        # mask tokens (subclasses may set mask_token to None)
        if self.mask_token is not None:
            torch.nn.init.normal_(self.mask_token, std=.02)
        # linears and layer norms
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Per-module initializer used with ``self.apply``."""
        if isinstance(m, nn.Linear):
            # we use xavier_uniform following official JAX ViT:
            torch.nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def _encode_image(self, image, do_mask=False, return_all_blocks=False):
        """
        image has B x 3 x img_size x img_size
        do_mask: whether to perform masking or not
        return_all_blocks: if True, return the features at the end of every block
                           instead of just the features from the last block (eg for some prediction heads)

        Returns (features, pos, masks) where pos holds the positions of *all*
        patches and masks is a B x Npatches boolean tensor (all False when
        do_mask is False).
        """
        # embed the image into patches  (x has size B x Npatches x C)
        # and get position of each returned patch (pos has size B x Npatches x 2)
        x, pos = self.patch_embed(image)
        # add positional embedding without cls token
        if self.enc_pos_embed is not None:
            x = x + self.enc_pos_embed[None, ...]
        # apply masking
        B, N, C = x.size()
        if do_mask:
            masks = self.mask_generator(x)
            # keep only the visible tokens and their positions
            x = x[~masks].view(B, -1, C)
            posvis = pos[~masks].view(B, -1, 2)
        else:
            # allocate the (empty) mask next to the tokens so callers can index
            # device tensors with it without a host/device mismatch
            masks = torch.zeros((B, N), dtype=torch.bool, device=x.device)
            posvis = pos
        # now apply the transformer encoder and normalization
        if return_all_blocks:
            out = []
            for blk in self.enc_blocks:
                x = blk(x, posvis)
                out.append(x)
            # only the output of the last block is normalized
            out[-1] = self.enc_norm(out[-1])
            return out, pos, masks
        for blk in self.enc_blocks:
            x = blk(x, posvis)
        x = self.enc_norm(x)
        return x, pos, masks

    def _decoder(self, feat1, pos1, masks1, feat2, pos2, return_all_blocks=False):
        """
        return_all_blocks: if True, return the features at the end of every block
                           instead of just the features from the last block (eg for some prediction heads)

        masks1 can be None => assume image1 fully visible
        """
        # encoder to decoder layer
        visf1 = self.decoder_embed(feat1)
        f2 = self.decoder_embed(feat2)
        # append masked tokens to the sequence
        B, Nenc, C = visf1.size()
        if masks1 is None:  # downstreams
            f1_ = visf1
        else:  # pretraining
            Ntotal = masks1.size(1)
            # start from mask tokens everywhere, then scatter the visible tokens back
            f1_ = self.mask_token.repeat(B, Ntotal, 1).to(dtype=visf1.dtype)
            f1_[~masks1] = visf1.view(B * Nenc, C)
        # add positional embedding
        if self.dec_pos_embed is not None:
            f1_ = f1_ + self.dec_pos_embed
            f2 = f2 + self.dec_pos_embed
        # apply Transformer blocks
        out = f1_
        out2 = f2
        if return_all_blocks:
            _out, out = out, []
            for blk in self.dec_blocks:
                _out, out2 = blk(_out, out2, pos1, pos2)
                out.append(_out)
            # only the output of the last block is normalized
            out[-1] = self.dec_norm(out[-1])
        else:
            for blk in self.dec_blocks:
                out, out2 = blk(out, out2, pos1, pos2)
            out = self.dec_norm(out)
        return out

    def patchify(self, imgs):
        """
        imgs: (B, 3, H, W)
        x: (B, L, patch_size**2 *3)
        """
        p = self.patch_embed.patch_size[0]
        assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0

        h = w = imgs.shape[2] // p
        x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p))
        # channel-last inside each patch: (row in patch, col in patch, channel)
        x = torch.einsum('nchpwq->nhwpqc', x)
        x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 3))

        return x

    def unpatchify(self, x, channels=3):
        """
        x: (N, L, patch_size**2 *channels)
        imgs: (N, channels, H, W), with H == W since L must be a perfect square
        """
        patch_size = self.patch_embed.patch_size[0]
        h = w = int(x.shape[1]**.5)
        assert h * w == x.shape[1]
        x = x.reshape(shape=(x.shape[0], h, w, patch_size, patch_size, channels))
        x = torch.einsum('nhwpqc->nchpwq', x)
        imgs = x.reshape(shape=(x.shape[0], channels, h * patch_size, w * patch_size))
        return imgs

    def forward(self, img1, img2):
        """
        img1: tensor of size B x 3 x img_size x img_size
        img2: tensor of size B x 3 x img_size x img_size

        out will be    B x N x (3*patch_size*patch_size)
        masks are also returned as B x N just in case
        """
        # encoder of the masked first image
        feat1, pos1, mask1 = self._encode_image(img1, do_mask=True)
        # encoder of the second image
        feat2, pos2, _ = self._encode_image(img2, do_mask=False)
        # decoder
        decfeat = self._decoder(feat1, pos1, mask1, feat2, pos2)
        # prediction head
        out = self.prediction_head(decfeat)
        # get target
        target = self.patchify(img1)
        return out, mask1, target
def _parse_croconet_kwargs(model_str):
    """Parse ``"CroCoNet(key=literal, ...)"`` into a dict without executing it."""
    import ast

    call = ast.parse(model_str.strip(), mode='eval').body
    if not isinstance(call, ast.Call) or call.args:
        raise ValueError('Expected a call with keyword arguments only, got: ' + model_str)
    kwargs = {}
    for kw in call.keywords:
        if kw.arg is None:  # a `**something` expansion cannot be resolved safely
            raise ValueError('Unsupported ** expansion in: ' + model_str)
        # literal_eval accepts numbers, strings, tuples, lists, dicts, booleans and None
        kwargs[kw.arg] = ast.literal_eval(kw.value)
    return kwargs


def croco_args_from_ckpt(ckpt):
    """Recover the CroCoNet constructor keyword arguments stored in a checkpoint.

    Three checkpoint layouts are recognised:
      * a 'croco_kwargs' entry (CroCo v2 released models), returned as is;
      * an 'args' object with a ``model`` string such as
        "CroCoNet(enc_embed_dim=1024, enc_num_heads=16, enc_depth=24)"
        (pretrained using the official code release);
      * anything else (CroCo v1 released models), giving an empty dict.

    SECURITY NOTE: the ``model`` string used to be turned into a dict with
    ``eval``, which executes whatever code a checkpoint carries in that field.
    It is now parsed as a syntax tree and each value is read with
    ``ast.literal_eval``, so only literal values are accepted; a non-literal
    value raises ``ValueError``.
    """
    if 'croco_kwargs' in ckpt:  # CroCo v2 released models
        return ckpt['croco_kwargs']
    if 'args' in ckpt and hasattr(ckpt['args'], 'model'):  # pretrained using the official code release
        s = ckpt['args'].model  # eg "CroCoNet(enc_embed_dim=1024, enc_num_heads=16, enc_depth=24)"
        assert s.startswith('CroCoNet(')
        return _parse_croconet_kwargs(s)
    # CroCo v1 released models
    return dict()


class CroCoDownstreamMonocularEncoder(CroCoNet):
    """CroCo encoder followed by a task head, for single-image downstream tasks."""

    def __init__(self,
                 head,
                 **kwargs):
        """ Build network for monocular downstream task, only using the encoder.
        It takes an extra argument head, that is called with the features
          and a dictionary img_info containing 'width' and 'height' keys
        The head is setup with the croconet arguments in this init function
        NOTE: It works by *calling super().__init__() but with redefined setters
        """
        super().__init__(**kwargs)
        head.setup(self)
        self.head = head

    def _set_mask_generator(self, *args, **kwargs):
        """ No mask generator """
        return

    def _set_mask_token(self, *args, **kwargs):
        """ No mask token """
        # kept as an attribute because initialize_weights() tests it against None
        self.mask_token = None
        return

    def _set_decoder(self, *args, **kwargs):
        """ No decoder """
        return

    def _set_prediction_head(self, *args, **kwargs):
        """ No 'prediction head' for downstream tasks."""
        return

    def forward(self, img):
        """
        img is of size batch_size x 3 x h x w
        """
        B, C, H, W = img.size()
        img_info = {'height': H, 'width': W}
        # some heads need the output of every encoder block, not just the last one
        need_all_layers = hasattr(self.head, 'return_all_blocks') and self.head.return_all_blocks
        out, _, _ = self._encode_image(img, do_mask=False, return_all_blocks=need_all_layers)
        return self.head(out, img_info)


class CroCoDownstreamBinocular(CroCoNet):
    """CroCo encoder and decoder followed by a task head, for image-pair downstream tasks."""

    def __init__(self,
                 head,
                 **kwargs):
        """ Build network for binocular downstream task
        It takes an extra argument head, that is called with the features
          and a dictionary img_info containing 'width' and 'height' keys
        The head is setup with the croconet arguments in this init function
        """
        super().__init__(**kwargs)
        head.setup(self)
        self.head = head

    def _set_mask_generator(self, *args, **kwargs):
        """ No mask generator """
        return

    def _set_mask_token(self, *args, **kwargs):
        """ No mask token """
        # kept as an attribute because initialize_weights() tests it against None
        self.mask_token = None
        return

    def _set_prediction_head(self, *args, **kwargs):
        """ No prediction head for downstream tasks, define your own head """
        return

    def encode_image_pairs(self, img1, img2, return_all_blocks=False):
        """ run encoder for a pair of images
            it is actually ~5% faster to concatenate the images along the batch dimension
             than to encode them separately

        Returns (out, out2, pos, pos2). With return_all_blocks, out is the list
        of per-block features of img1 while out2 is only the last-block
        features of img2.
        """
        # encode both images in a single pass by stacking them along the batch dimension
        out, pos, _ = self._encode_image(torch.cat((img1, img2), dim=0),
                                         do_mask=False, return_all_blocks=return_all_blocks)
        if return_all_blocks:
            # split every block output back into its img1 / img2 halves
            out, out2 = list(map(list, zip(*[o.chunk(2, dim=0) for o in out])))
            out2 = out2[-1]
        else:
            out, out2 = out.chunk(2, dim=0)
        pos, pos2 = pos.chunk(2, dim=0)
        return out, out2, pos, pos2

    def forward(self, img1, img2):
        """Encode the pair, decode img1 against img2 and apply the head."""
        B, C, H, W = img1.size()
        img_info = {'height': H, 'width': W}
        return_all_blocks = hasattr(self.head, 'return_all_blocks') and self.head.return_all_blocks
        out, out2, pos, pos2 = self.encode_image_pairs(img1, img2, return_all_blocks=return_all_blocks)
        if return_all_blocks:
            decout = self._decoder(out[-1], pos, None, out2, pos2, return_all_blocks=return_all_blocks)
            # list concatenation: encoder block outputs followed by decoder block outputs
            decout = out + decout
        else:
            decout = self._decoder(out, pos, None, out2, pos2, return_all_blocks=return_all_blocks)
        return self.head(decout, img_info)
+*/ + +#include + +// forward declaration +void rope_2d_cuda( torch::Tensor tokens, const torch::Tensor pos, const float base, const float fwd ); + +void rope_2d_cpu( torch::Tensor tokens, const torch::Tensor positions, const float base, const float fwd ) +{ + const int B = tokens.size(0); + const int N = tokens.size(1); + const int H = tokens.size(2); + const int D = tokens.size(3) / 4; + + auto tok = tokens.accessor(); + auto pos = positions.accessor(); + + for (int b = 0; b < B; b++) { + for (int x = 0; x < 2; x++) { // y and then x (2d) + for (int n = 0; n < N; n++) { + + // grab the token position + const int p = pos[b][n][x]; + + for (int h = 0; h < H; h++) { + for (int d = 0; d < D; d++) { + // grab the two values + float u = tok[b][n][h][d+0+x*2*D]; + float v = tok[b][n][h][d+D+x*2*D]; + + // grab the cos,sin + const float inv_freq = fwd * p / powf(base, d/float(D)); + float c = cosf(inv_freq); + float s = sinf(inv_freq); + + // write the result + tok[b][n][h][d+0+x*2*D] = u*c - v*s; + tok[b][n][h][d+D+x*2*D] = v*c + u*s; + } + } + } + } + } +} + +void rope_2d( torch::Tensor tokens, // B,N,H,D + const torch::Tensor positions, // B,N,2 + const float base, + const float fwd ) +{ + TORCH_CHECK(tokens.dim() == 4, "tokens must have 4 dimensions"); + TORCH_CHECK(positions.dim() == 3, "positions must have 3 dimensions"); + TORCH_CHECK(tokens.size(0) == positions.size(0), "batch size differs between tokens & positions"); + TORCH_CHECK(tokens.size(1) == positions.size(1), "seq_length differs between tokens & positions"); + TORCH_CHECK(positions.size(2) == 2, "positions.shape[2] must be equal to 2"); + TORCH_CHECK(tokens.is_cuda() == positions.is_cuda(), "tokens and positions are not on the same device" ); + + if (tokens.is_cuda()) + rope_2d_cuda( tokens, positions, base, fwd ); + else + rope_2d_cpu( tokens, positions, base, fwd ); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("rope_2d", &rope_2d, "RoPE 2d forward/backward"); +} diff --git 
class cuRoPE2D_func(torch.autograd.Function):
    """Autograd wrapper around the compiled in-place ``rope_2d`` kernel."""

    @staticmethod
    def forward(ctx, tokens, positions, base, F0=1):
        """Rotate ``tokens`` in place and return the same tensor.

        ``positions``, ``base`` and ``F0`` are stored on ``ctx`` because the
        backward pass calls the same kernel with the sign of ``F0`` flipped.
        """
        ctx.save_for_backward(positions)
        ctx.saved_base = base
        ctx.saved_F0 = F0
        # NOTE: the kernel works in place; clone `tokens` first if in-place
        # modification turns out to be a problem for a caller.
        _kernels.rope_2d(tokens, positions, base, F0)
        # tell autograd that an input tensor was modified in place
        ctx.mark_dirty(tokens)
        return tokens

    @staticmethod
    def backward(ctx, grad_res):
        """Apply the kernel to the incoming gradient with ``-F0``, in place."""
        positions, base, F0 = ctx.saved_tensors[0], ctx.saved_base, ctx.saved_F0
        _kernels.rope_2d(grad_res, positions, base, -F0)
        # mark_dirty() is only meaningful inside forward(), so it is not called here.
        # one gradient per forward input: tokens, positions, base, F0
        return grad_res, None, None, None


class cuRoPE2D(torch.nn.Module):
    """Module applying the compiled 2D rotary embedding to a token tensor in place."""

    def __init__(self, freq=100.0, F0=1.0):
        super().__init__()
        self.base = freq
        self.F0 = F0

    def forward(self, tokens, positions):
        """Rotate ``tokens`` in place according to ``positions`` and return it.

        The kernel receives ``tokens.transpose(1, 2)``; since that is a view,
        the in-place update is visible through ``tokens`` itself.
        """
        cuRoPE2D_func.apply(tokens.transpose(1, 2), positions, self.base, self.F0)
        return tokens
// NOTE(review): header names, template arguments and the kernel launch
// configuration were stripped from the text under review. They are
// reconstructed from the surrounding code (4-d accessor indexing, an int64
// position pointer, and the N_BLOCKS / THREADS_PER_BLOCK / SHARED_MEM constants
// computed just before the launch) -- confirm against the upstream file.
#include <torch/extension.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <vector>

#define CHECK_CUDA(tensor) {\
    TORCH_CHECK((tensor).is_cuda(), #tensor " is not in cuda memory"); \
    TORCH_CHECK((tensor).is_contiguous(), #tensor " is not contiguous"); }
void CHECK_KERNEL() {auto error = cudaGetLastError(); TORCH_CHECK( error == cudaSuccess, cudaGetErrorString(error));}


// In-place 2D rotary embedding.
// Launch layout: one block per (batch, token) pair, one thread per channel of
// a head; every block loops over the H heads of its token.
// Dynamic shared memory layout: [ D floats: current head | D/4 floats: inv_freq ].
template < typename scalar_t >
__global__ void rope_2d_cuda_kernel(
        torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> tokens,
        const int64_t* __restrict__ pos,
        const float base,
        const float fwd )
{
    // tokens shape = (B, N, H, D)
    const int N = tokens.size(1);
    const int H = tokens.size(2);
    const int D = tokens.size(3);

    // each block update a single token, for all heads
    // each thread takes care of a single output
    extern __shared__ float shared[];
    float* shared_inv_freq = shared + D;

    const int b = blockIdx.x / N;
    const int n = blockIdx.x % N;

    const int Q = D / 4;
    // one token = [0..Q : Q..2Q : 2Q..3Q : 3Q..D]
    //              u_Y     v_Y     u_X      v_X

    // shared memory: first, compute inv_freq
    if (threadIdx.x < Q)
        shared_inv_freq[threadIdx.x] = fwd / powf(base, threadIdx.x/float(Q));
    __syncthreads();

    // start of X or Y part
    const int X = threadIdx.x < D/2 ? 0 : 1;
    const int m = (X*D/2) + (threadIdx.x % Q); // index of u_Y or u_X

    // grab the cos,sin appropriate for me
    const float freq = pos[blockIdx.x*2+X] * shared_inv_freq[threadIdx.x % Q];
    const float cos = cosf(freq);
    const float sin = sinf(freq);

    for (int h = 0; h < H; h++)
    {
        // then, load all the token for this head in shared memory
        shared[threadIdx.x] = tokens[b][n][h][threadIdx.x];
        __syncthreads();

        const float u = shared[m];
        const float v = shared[m+Q];

        // write output
        if ((threadIdx.x % (D/2)) < Q)
            tokens[b][n][h][threadIdx.x] = u*cos - v*sin;
        else
            tokens[b][n][h][threadIdx.x] = v*cos + u*sin;

        // Every thread must be done reading this head's values before any
        // thread overwrites the shared buffer with the next head; without this
        // barrier a fast thread can clobber a value a slower one still needs.
        __syncthreads();
    }
}

// Host-side launcher: validates the layout and runs the kernel on `tokens` in place.
void rope_2d_cuda( torch::Tensor tokens, const torch::Tensor pos, const float base, const float fwd )
{
    const int B = tokens.size(0); // batch size
    const int N = tokens.size(1); // sequence length
    const int H = tokens.size(2); // number of heads
    const int D = tokens.size(3); // dimension per head

    TORCH_CHECK(tokens.stride(3) == 1 && tokens.stride(2) == D, "tokens are not contiguous");
    TORCH_CHECK(pos.is_contiguous(), "positions are not contiguous");
    TORCH_CHECK(pos.size(0) == B && pos.size(1) == N && pos.size(2) == 2, "bad pos.shape");
    TORCH_CHECK(D % 4 == 0, "token dim must be multiple of 4");
    // one thread per channel: a block cannot hold more than 1024 threads
    TORCH_CHECK(D <= 1024, "token dim must not exceed 1024");

    // one block per token, one thread per channel of a head
    const int THREADS_PER_BLOCK = D;
    const int N_BLOCKS = B * N; // each block takes care of H*D values
    const int SHARED_MEM = sizeof(float) * (D + D/4);

    // scalar_type() replaces the deprecated tokens.type() in the dispatch macro
    AT_DISPATCH_FLOATING_TYPES_AND_HALF(tokens.scalar_type(), "rope_2d_cuda", ([&] {
        rope_2d_cuda_kernel<scalar_t> <<<N_BLOCKS, THREADS_PER_BLOCK, SHARED_MEM>>> (
            tokens.packed_accessor32<scalar_t,4,torch::RestrictPtrTraits>(),
            pos.data_ptr<int64_t>(),
            base, fwd);
    }));
}
from setuptools import setup
from torch import cuda
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

# compile for all possible CUDA architectures
all_cuda_archs = cuda.get_gencode_flags().replace('compute=','arch=').split()
# alternatively, you can list cuda archs that you want, eg:
# all_cuda_archs = [
    # '-gencode', 'arch=compute_70,code=sm_70',
    # '-gencode', 'arch=compute_75,code=sm_75',
    # '-gencode', 'arch=compute_80,code=sm_80',
    # '-gencode', 'arch=compute_86,code=sm_86'
# ]

# flags handed to nvcc for the .cu source; the architecture flags come last
nvcc_flags = ['-O3', '--ptxas-options=-v', "--use_fast_math"] + all_cuda_archs
# flags handed to the host C++ compiler
cxx_flags = ['-O3']

# the single extension module, built from the C++ binding and the CUDA kernels
curope_extension = CUDAExtension(
    name='curope',
    sources=[
        "curope.cpp",
        "kernels.cu",
    ],
    extra_compile_args={'nvcc': nvcc_flags, 'cxx': cxx_flags},
)

setup(
    name='curope',
    ext_modules=[curope_extension],
    cmdclass={'build_ext': BuildExtension},
)
def pair(t):
    """Return ``t`` unchanged if it is a tuple, otherwise the 2-tuple ``(t, t)``."""
    return t if isinstance(t, tuple) else (t, t)


def make_scratch(in_shape, out_shape, groups=1, expand=False):
    """Build the four 3x3 convolutions projecting each hooked layer to the fusion width.

    in_shape: sequence of 4 input channel counts, one per hooked layer
    out_shape: base number of output channels
    groups: ``groups`` argument of every convolution
    expand: if set, output widths are out_shape * (1, 2, 4, 8) instead of a constant out_shape

    Returns a bare ``nn.Module`` exposing ``layer1_rn`` .. ``layer4_rn`` as
    well as ``layer_rn``, a ModuleList holding the same four convolutions.
    """
    scratch = nn.Module()

    if expand:
        out_shapes = [out_shape, out_shape * 2, out_shape * 4, out_shape * 8]
    else:
        out_shapes = [out_shape] * 4

    def _projection(in_channels, out_channels):
        return nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=False,
            groups=groups,
        )

    scratch.layer1_rn = _projection(in_shape[0], out_shapes[0])
    scratch.layer2_rn = _projection(in_shape[1], out_shapes[1])
    scratch.layer3_rn = _projection(in_shape[2], out_shapes[2])
    scratch.layer4_rn = _projection(in_shape[3], out_shapes[3])

    scratch.layer_rn = nn.ModuleList([
        scratch.layer1_rn,
        scratch.layer2_rn,
        scratch.layer3_rn,
        scratch.layer4_rn,
    ])

    return scratch


class ResidualConvUnit_custom(nn.Module):
    """Residual convolution module: x + conv2(act(conv1(act(x)))), with optional batch norm."""

    def __init__(self, features, activation, bn):
        """Init.
        Args:
            features (int): number of features
            activation (callable): activation applied before each convolution
            bn (bool): whether each convolution is followed by batch norm
        """
        super().__init__()

        self.bn = bn

        self.groups = 1

        # the convolution bias is redundant when a batch norm follows
        self.conv1 = nn.Conv2d(
            features,
            features,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=not self.bn,
            groups=self.groups,
        )

        self.conv2 = nn.Conv2d(
            features,
            features,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=not self.bn,
            groups=self.groups,
        )

        if self.bn:
            self.bn1 = nn.BatchNorm2d(features)
            self.bn2 = nn.BatchNorm2d(features)

        self.activation = activation

        self.skip_add = nn.quantized.FloatFunctional()

    def forward(self, x):
        """Forward pass.
        Args:
            x (tensor): input
        Returns:
            tensor: output
        """
        out = self.activation(x)
        out = self.conv1(out)
        if self.bn:
            out = self.bn1(out)

        out = self.activation(out)
        out = self.conv2(out)
        if self.bn:
            out = self.bn2(out)

        # (a former `groups > 1` branch was removed: groups is fixed to 1 and
        #  the layer that branch called was never defined)
        return self.skip_add.add(out, x)


class FeatureFusionBlock_custom(nn.Module):
    """Feature fusion block: merges an optional skip input, refines and upsamples by 2."""

    def __init__(
        self,
        features,
        activation,
        deconv=False,
        bn=False,
        expand=False,
        align_corners=True,
        width_ratio=1,
    ):
        """Init.
        Args:
            features (int): number of features
            activation (callable): activation used by the residual units
            deconv (bool): stored but not used by this block
            bn (bool): whether the residual units use batch norm
            expand (bool): if set, the output convolution halves the channel count
            align_corners (bool): used by the upsampling when width_ratio == 1
            width_ratio (float): when != 1, output sizes are computed explicitly
                (see forward) instead of using a scale factor of 2
        """
        super(FeatureFusionBlock_custom, self).__init__()
        self.width_ratio = width_ratio

        self.deconv = deconv
        self.align_corners = align_corners

        self.groups = 1

        self.expand = expand
        out_features = features
        if self.expand:
            out_features = features // 2

        self.out_conv = nn.Conv2d(
            features,
            out_features,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=True,
            groups=1,
        )

        self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn)
        self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn)

        self.skip_add = nn.quantized.FloatFunctional()

    def forward(self, *xs):
        """Forward pass.
        Args:
            xs: one tensor, or two tensors where the second is refined by the
                first residual unit and added to the first
        Returns:
            tensor: output
        """
        output = xs[0]

        if len(xs) == 2:
            res = self.resConfUnit1(xs[1])
            if self.width_ratio != 1:
                # bring the skip input to the spatial size of the main path
                res = F.interpolate(res, size=(output.shape[2], output.shape[3]), mode='bilinear')

            output = self.skip_add.add(output, res)

        output = self.resConfUnit2(output)

        if self.width_ratio != 1:
            # height is doubled; the target width depends on the current aspect ratio
            if (output.shape[3] / output.shape[2]) < (2 / 3) * self.width_ratio:
                shape = 3 * output.shape[3]
            else:
                shape = int(self.width_ratio * 2 * output.shape[2])
            output = F.interpolate(output, size=(2 * output.shape[2], shape), mode='bilinear')
        else:
            output = nn.functional.interpolate(output, scale_factor=2,
                                               mode="bilinear", align_corners=self.align_corners)
        output = self.out_conv(output)
        return output


def make_fusion_block(features, use_bn, width_ratio=1):
    """Build a FeatureFusionBlock_custom with ReLU activation and no channel expansion."""
    return FeatureFusionBlock_custom(
        features,
        nn.ReLU(False),
        deconv=False,
        bn=use_bn,
        expand=False,
        align_corners=True,
        width_ratio=width_ratio,
    )


class Interpolate(nn.Module):
    """Interpolation module."""

    def __init__(self, scale_factor, mode, align_corners=False):
        """Init.
        Args:
            scale_factor (float): scaling
            mode (str): interpolation mode
            align_corners (bool): forwarded to the interpolation function
        """
        super(Interpolate, self).__init__()

        self.interp = nn.functional.interpolate
        self.scale_factor = scale_factor
        self.mode = mode
        self.align_corners = align_corners

    def forward(self, x):
        """Forward pass.
        Args:
            x (tensor): input
        Returns:
            tensor: interpolated data
        """
        x = self.interp(
            x,
            scale_factor=self.scale_factor,
            mode=self.mode,
            align_corners=self.align_corners,
        )

        return x


class DPTOutputAdapter(nn.Module):
    """DPT output adapter.

    :param num_channels: Number of output channels
    :param stride_level: Stride level compared to the full-sized image.
        E.g. 4 for 1/4th the size of the image.
    :param patch_size: Int or tuple of the patch size over the full image size.
        Patch size for smaller inputs will be computed accordingly.
    :param main_tasks: Task names; only their count is used here, as a
        multiplier of the encoder token dimension
    :param hooks: Index of intermediate layers
    :param layer_dims: Dimension of intermediate layers
    :param feature_dim: Feature dimension
    :param last_dim: out_channels/in_channels for the last two Conv2d when head_type == regression
    :param use_bn: If set to True, activates batch norm
    :param dim_tokens_enc:  Dimension of tokens coming from encoder
    :param head_type: Either 'regression' or 'semseg'
    :param output_width_ratio: Width ratio forwarded to the fusion blocks
    """

    def __init__(self,
                 num_channels: int = 1,
                 stride_level: int = 1,
                 patch_size: Union[int, Tuple[int, int]] = 16,
                 main_tasks: Iterable[str] = ('rgb',),
                 hooks: List[int] = [2, 5, 8, 11],
                 layer_dims: List[int] = [96, 192, 384, 768],
                 feature_dim: int = 256,
                 last_dim: int = 32,
                 use_bn: bool = False,
                 dim_tokens_enc: Optional[int] = None,
                 head_type: str = 'regression',
                 output_width_ratio=1,
                 **kwargs):
        super().__init__()
        self.num_channels = num_channels
        self.stride_level = stride_level
        self.patch_size = pair(patch_size)
        self.main_tasks = main_tasks
        # store copies so that the shared default lists (and the caller's lists)
        # cannot be modified through this instance
        self.hooks = list(hooks)
        self.layer_dims = list(layer_dims)
        self.feature_dim = feature_dim
        # filled in by init(); forward() refuses to run while it is None
        self.dim_tokens_enc = None
        self.head_type = head_type

        # Actual patch height and width, taking into account stride of input
        self.P_H = max(1, self.patch_size[0] // stride_level)
        self.P_W = max(1, self.patch_size[1] // stride_level)

        self.scratch = make_scratch(self.layer_dims, feature_dim, groups=1, expand=False)

        self.scratch.refinenet1 = make_fusion_block(feature_dim, use_bn, output_width_ratio)
        self.scratch.refinenet2 = make_fusion_block(feature_dim, use_bn, output_width_ratio)
        self.scratch.refinenet3 = make_fusion_block(feature_dim, use_bn, output_width_ratio)
        self.scratch.refinenet4 = make_fusion_block(feature_dim, use_bn, output_width_ratio)

        if self.head_type == 'regression':
            # The "DPTDepthModel" head
            self.head = nn.Sequential(
                nn.Conv2d(feature_dim, feature_dim // 2, kernel_size=3, stride=1, padding=1),
                Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
                nn.Conv2d(feature_dim // 2, last_dim, kernel_size=3, stride=1, padding=1),
                nn.ReLU(True),
                nn.Conv2d(last_dim, self.num_channels, kernel_size=1, stride=1, padding=0)
            )
        elif self.head_type == 'semseg':
            # The "DPTSegmentationModel" head
            self.head = nn.Sequential(
                nn.Conv2d(feature_dim, feature_dim, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(feature_dim) if use_bn else nn.Identity(),
                nn.ReLU(True),
                nn.Dropout(0.1, False),
                nn.Conv2d(feature_dim, self.num_channels, kernel_size=1),
                Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
            )
        else:
            raise ValueError('DPT head_type must be "regression" or "semseg".')

        if dim_tokens_enc is not None:
            self.init(dim_tokens_enc=dim_tokens_enc)

    def init(self, dim_tokens_enc=768):
        """
        Initialize parts of decoder that are dependent on dimension of encoder tokens.
        Should be called when setting up MultiMAE.

        :param dim_tokens_enc: Dimension of tokens coming from encoder; either
            one int shared by the 4 hooked layers or a sequence of 4 ints
        """
        # Set up activation postprocessing layers
        if isinstance(dim_tokens_enc, int):
            dim_tokens_enc = 4 * [dim_tokens_enc]

        self.dim_tokens_enc = [dt * len(self.main_tasks) for dt in dim_tokens_enc]

        # layer 1: project, then upsample x4
        self.act_1_postprocess = nn.Sequential(
            nn.Conv2d(
                in_channels=self.dim_tokens_enc[0],
                out_channels=self.layer_dims[0],
                kernel_size=1, stride=1, padding=0,
            ),
            nn.ConvTranspose2d(
                in_channels=self.layer_dims[0],
                out_channels=self.layer_dims[0],
                kernel_size=4, stride=4, padding=0,
                bias=True, dilation=1, groups=1,
            )
        )

        # layer 2: project, then upsample x2
        self.act_2_postprocess = nn.Sequential(
            nn.Conv2d(
                in_channels=self.dim_tokens_enc[1],
                out_channels=self.layer_dims[1],
                kernel_size=1, stride=1, padding=0,
            ),
            nn.ConvTranspose2d(
                in_channels=self.layer_dims[1],
                out_channels=self.layer_dims[1],
                kernel_size=2, stride=2, padding=0,
                bias=True, dilation=1, groups=1,
            )
        )

        # layer 3: project only (resolution unchanged)
        self.act_3_postprocess = nn.Sequential(
            nn.Conv2d(
                in_channels=self.dim_tokens_enc[2],
                out_channels=self.layer_dims[2],
                kernel_size=1, stride=1, padding=0,
            )
        )

        # layer 4: project, then downsample x2
        self.act_4_postprocess = nn.Sequential(
            nn.Conv2d(
                in_channels=self.dim_tokens_enc[3],
                out_channels=self.layer_dims[3],
                kernel_size=1, stride=1, padding=0,
            ),
            nn.Conv2d(
                in_channels=self.layer_dims[3],
                out_channels=self.layer_dims[3],
                kernel_size=3, stride=2, padding=1,
            )
        )

        self.act_postprocess = nn.ModuleList([
            self.act_1_postprocess,
            self.act_2_postprocess,
            self.act_3_postprocess,
            self.act_4_postprocess
        ])

    def adapt_tokens(self, encoder_tokens):
        """Return the tokens to decode; here all tokens are kept unchanged."""
        x = []
        x.append(encoder_tokens[:, :])
        x = torch.cat(x, dim=-1)
        return x

    def forward(self, encoder_tokens: List[torch.Tensor], image_size):
        """Decode the hooked token layers into a dense prediction.

        :param encoder_tokens: Per-layer token tensors, indexed by ``self.hooks``;
            each selected tensor is reshaped from (b, nh * nw, c) to (b, c, nh, nw)
        :param image_size: (H, W) of the input image
        """
        assert self.dim_tokens_enc is not None, 'Need to call init(dim_tokens_enc) function first'
        H, W = image_size

        # Number of patches in height and width
        N_H = H // (self.stride_level * self.P_H)
        N_W = W // (self.stride_level * self.P_W)

        # Hook decoder onto 4 layers from specified ViT layers
        layers = [encoder_tokens[hook] for hook in self.hooks]

        # Extract only task-relevant tokens and ignore global tokens.
        layers = [self.adapt_tokens(l) for l in layers]

        # Reshape tokens to spatial representation
        layers = [rearrange(l, 'b (nh nw) c -> b c nh nw', nh=N_H, nw=N_W) for l in layers]

        layers = [self.act_postprocess[idx](l) for idx, l in enumerate(layers)]
        # Project layers to chosen feature dim
        layers = [self.scratch.layer_rn[idx](l) for idx, l in enumerate(layers)]

        # Fuse layers using refinement stages, from the coarsest to the finest
        path_4 = self.scratch.refinenet4(layers[3])
        path_3 = self.scratch.refinenet3(path_4, layers[2])
        path_2 = self.scratch.refinenet2(path_3, layers[1])
        path_1 = self.scratch.refinenet1(path_2, layers[0])

        # Output head
        out = self.head(path_1)

        return out
+The forward takes the features as well as a dictionary img_info containing the keys 'width' and 'height' +""" + +import torch +import torch.nn as nn +from .dpt_block import DPTOutputAdapter + + +class PixelwiseTaskWithDPT(nn.Module): + """ DPT module for CroCo. + by default, hooks_idx will be equal to: + * for encoder-only: 4 equally spread layers + * for encoder+decoder: last encoder + 3 equally spread layers of the decoder + """ + + def __init__(self, *, hooks_idx=None, layer_dims=[96,192,384,768], + output_width_ratio=1, num_channels=1, postprocess=None, **kwargs): + super(PixelwiseTaskWithDPT, self).__init__() + self.return_all_blocks = True # backbone needs to return all layers + self.postprocess = postprocess + self.output_width_ratio = output_width_ratio + self.num_channels = num_channels + self.hooks_idx = hooks_idx + self.layer_dims = layer_dims + + def setup(self, croconet): + dpt_args = {'output_width_ratio': self.output_width_ratio, 'num_channels': self.num_channels} + if self.hooks_idx is None: + if hasattr(croconet, 'dec_blocks'): # encoder + decoder + step = {8: 3, 12: 4, 24: 8}[croconet.dec_depth] + hooks_idx = [croconet.dec_depth+croconet.enc_depth-1-i*step for i in range(3,-1,-1)] + else: # encoder only + step = croconet.enc_depth//4 + hooks_idx = [croconet.enc_depth-1-i*step for i in range(3,-1,-1)] + self.hooks_idx = hooks_idx + print(f' PixelwiseTaskWithDPT: automatically setting hook_idxs={self.hooks_idx}') + dpt_args['hooks'] = self.hooks_idx + dpt_args['layer_dims'] = self.layer_dims + self.dpt = DPTOutputAdapter(**dpt_args) + dim_tokens = [croconet.enc_embed_dim if hook0: + pos_embed = np.concatenate([np.zeros([n_cls_token, embed_dim]), pos_embed], axis=0) + return pos_embed + + +def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): + assert embed_dim % 2 == 0 + + # use half of dimensions to encode grid_h + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) + emb_w = 
get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) + + emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) + return emb + + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + """ + embed_dim: output dimension for each position + pos: a list of positions to be encoded: size (M,) + out: (M, D) + """ + assert embed_dim % 2 == 0 + omega = np.arange(embed_dim // 2, dtype=float) + omega /= embed_dim / 2. + omega = 1. / 10000**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product + + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + + emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + return emb + + +# -------------------------------------------------------- +# Interpolate position embeddings for high-resolution +# References: +# MAE: https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py +# DeiT: https://github.com/facebookresearch/deit +# -------------------------------------------------------- +def interpolate_pos_embed(model, checkpoint_model): + if 'pos_embed' in checkpoint_model: + pos_embed_checkpoint = checkpoint_model['pos_embed'] + embedding_size = pos_embed_checkpoint.shape[-1] + num_patches = model.patch_embed.num_patches + num_extra_tokens = model.pos_embed.shape[-2] - num_patches + # height (== width) for the checkpoint position embedding + orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) + # height (== width) for the new position embedding + new_size = int(num_patches ** 0.5) + # class_token and dist_token are kept unchanged + if orig_size != new_size: + print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size)) + extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] + # only the position tokens are interpolated + pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] + pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 
1, 2) + pos_tokens = torch.nn.functional.interpolate( + pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) + pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) + new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) + checkpoint_model['pos_embed'] = new_pos_embed + + +#---------------------------------------------------------- +# RoPE2D: RoPE implementation in 2D +#---------------------------------------------------------- + +try: + from models.curope import cuRoPE2D + RoPE2D = cuRoPE2D +except ImportError: + print('Warning, cannot find cuda-compiled version of RoPE2D, using a slow pytorch version instead') + + class RoPE2D(torch.nn.Module): + + def __init__(self, freq=100.0, F0=1.0): + super().__init__() + self.base = freq + self.F0 = F0 + self.cache = {} + + def get_cos_sin(self, D, seq_len, device, dtype): + if (D,seq_len,device,dtype) not in self.cache: + inv_freq = 1.0 / (self.base ** (torch.arange(0, D, 2).float().to(device) / D)) + t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) + freqs = torch.einsum("i,j->ij", t, inv_freq).to(dtype) + freqs = torch.cat((freqs, freqs), dim=-1) + cos = freqs.cos() # (Seq, Dim) + sin = freqs.sin() + self.cache[D,seq_len,device,dtype] = (cos,sin) + return self.cache[D,seq_len,device,dtype] + + @staticmethod + def rotate_half(x): + x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + def apply_rope1d(self, tokens, pos1d, cos, sin): + assert pos1d.ndim==2 + cos = torch.nn.functional.embedding(pos1d, cos)[:, None, :, :] + sin = torch.nn.functional.embedding(pos1d, sin)[:, None, :, :] + return (tokens * cos) + (self.rotate_half(tokens) * sin) + + def forward(self, tokens, positions): + """ + input: + * tokens: batch_size x nheads x ntokens x dim + * positions: batch_size x ntokens x 2 (y and x position of each token) + output: + * tokens after appplying RoPE2D (batch_size x nheads x ntokens x dim) + """ + assert 
tokens.size(3)%2==0, "number of dimensions should be a multiple of two" + D = tokens.size(3) // 2 + assert positions.ndim==3 and positions.shape[-1] == 2 # Batch, Seq, 2 + cos, sin = self.get_cos_sin(D, int(positions.max())+1, tokens.device, tokens.dtype) + # split features into two along the feature dimension, and apply rope1d on each half + y, x = tokens.chunk(2, dim=-1) + y = self.apply_rope1d(y, positions[:,:,0], cos, sin) + x = self.apply_rope1d(x, positions[:,:,1], cos, sin) + tokens = torch.cat((y, x), dim=-1) + return tokens \ No newline at end of file diff --git a/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/dust3r/__init__.py b/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/dust3r/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a32692113d830ddc4af4e6ed608f222fbe062e6e --- /dev/null +++ b/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/dust3r/__init__.py @@ -0,0 +1,2 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). diff --git a/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/dust3r/cloud_opt/__init__.py b/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/dust3r/cloud_opt/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..fc597c702861154bbe7a08f23b089474e926bb35 --- /dev/null +++ b/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/dust3r/cloud_opt/__init__.py @@ -0,0 +1,29 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 
#
# --------------------------------------------------------
# global alignment optimization wrapper function
# --------------------------------------------------------
from enum import Enum

from .optimizer import PointCloudOptimizer
from .pair_viewer import PairViewer


class GlobalAlignerMode(Enum):
    """Names of the scene-alignment back-ends that `global_aligner` can build."""
    PointCloudOptimizer = "PointCloudOptimizer"
    PairViewer = "PairViewer"


def global_aligner(dust3r_output, device, mode=GlobalAlignerMode.PointCloudOptimizer, **optim_kw):
    """Build the alignment module selected by `mode` and move it to `device`.

    :param dust3r_output: mapping providing the keys 'view1', 'view2', 'pred1' and 'pred2'
    :param device: target passed to the module's `.to(...)`
    :param mode: a `GlobalAlignerMode` member choosing which class is instantiated
    :param optim_kw: extra keyword arguments forwarded to the chosen class
    :raises NotImplementedError: if `mode` is not one of the two known members
    """
    # pull the four inputs out of the network output (a missing key raises KeyError here)
    inputs = [dust3r_output[key] for key in 'view1 view2 pred1 pred2'.split()]

    # pick the implementation first, instantiate once afterwards
    if mode == GlobalAlignerMode.PointCloudOptimizer:
        aligner_cls = PointCloudOptimizer
    elif mode == GlobalAlignerMode.PairViewer:
        aligner_cls = PairViewer
    else:
        raise NotImplementedError(f'Unknown mode {mode}')

    return aligner_cls(*inputs, **optim_kw).to(device)


# diff --git a/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/dust3r/cloud_opt/base_opt.py b/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/dust3r/cloud_opt/base_opt.py
# new file mode 100644
# index 0000000000000000000000000000000000000000..0869abd75b19ad441e5c32a27f34973edf1f9795
# --- /dev/null
# +++ b/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/dust3r/cloud_opt/base_opt.py
# @@ -0,0 +1,375 @@
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# Base class for the global alignment procedure
# --------------------------------------------------------
from copy import deepcopy

import numpy as np
import torch
import torch.nn as nn
import roma
from copy import deepcopy
import tqdm

from dust3r.utils.geometry import inv, geotrf
from dust3r.utils.device import to_numpy
from dust3r.utils.image import rgb
from dust3r.viz import SceneViz, segment_sky, auto_cam_size
from dust3r.optim_factory import adjust_learning_rate_by_lr

from dust3r.cloud_opt.commons import (edge_str, ALL_DISTS, NoGradParamDict, get_imshapes, signed_expm1, signed_log1p,
                                      cosine_schedule, linear_schedule, get_conf_trf)
import dust3r.cloud_opt.init_im_poses as init_fun


class BasePCOptimizer (nn.Module):
    """ Optimize a global scene, given a list of pairwise observations.
    Graph node: images
    Graph edges: observations = (pred1, pred2)

    Subclasses must provide the methods that raise NotImplementedError here
    (depth_to_pts3d, get_im_poses, get_focals, ...).
    """

    def __init__(self, *args, **kwargs):
        """ Either copy another optimizer (single positional argument, no keywords)
        or build from views/predictions (see `_init_from_views`).
        """
        if len(args) == 1 and len(kwargs) == 0:
            # copy-constructor path
            # nn.Module must be initialised before parameters/sub-modules can be attached
            super().__init__()
            other = deepcopy(args[0])
            # conf_trf and pw_break are included because forward()/get_conf()/get_adaptors() read them
            attrs = '''edges is_symmetrized dist n_imgs pred_i pred_j imshapes
                        min_conf_thr conf_thr conf_trf conf_i conf_j im_conf
                        base_scale norm_pw_scale pw_break POSE_DIM pw_poses
                        pw_adaptors has_im_poses rand_pose imgs'''.split()
            for name in attrs:
                # an nn.Module is not subscriptable and keeps parameters outside __dict__,
                # so go through getattr/setattr; attributes `other` never defined are skipped
                if hasattr(other, name):
                    setattr(self, name, getattr(other, name))
        else:
            self._init_from_views(*args, **kwargs)

    def _init_from_views(self, view1, view2, pred1, pred2,
                         dist='l1',
                         conf='log',
                         min_conf_thr=3,
                         base_scale=0.5,
                         allow_pw_adaptors=False,
                         pw_break=20,
                         rand_pose=torch.randn,
                         iterationsCount=None):
        """ Build the optimizer state from pairwise views and predictions.

        view1/view2 must provide 'idx' (and optionally 'img'); pred1 must provide 'pts3d' and
        'conf'; pred2 must provide 'pts3d_in_other_view' and 'conf'.
        `dist` is a key of ALL_DISTS, `conf` a mode accepted by get_conf_trf.
        `iterationsCount` is accepted but not used in this method.
        Note: view1['idx'] / view2['idx'] are converted to lists in place.
        """
        super().__init__()
        if not isinstance(view1['idx'], list):
            view1['idx'] = view1['idx'].tolist()
        if not isinstance(view2['idx'], list):
            view2['idx'] = view2['idx'].tolist()
        self.edges = [(int(i), int(j)) for i, j in zip(view1['idx'], view2['idx'])]
        self.is_symmetrized = set(self.edges) == {(j, i) for i, j in self.edges}
        self.dist = ALL_DISTS[dist]

        self.n_imgs = self._check_edges()

        # input data
        pred1_pts = pred1['pts3d']
        pred2_pts = pred2['pts3d_in_other_view']
        self.pred_i = NoGradParamDict({ij: pred1_pts[n] for n, ij in enumerate(self.str_edges)})
        self.pred_j = NoGradParamDict({ij: pred2_pts[n] for n, ij in enumerate(self.str_edges)})
        self.imshapes = get_imshapes(self.edges, pred1_pts, pred2_pts)

        # work in log-scale with conf
        pred1_conf = pred1['conf']
        pred2_conf = pred2['conf']
        self.min_conf_thr = min_conf_thr
        self.conf_trf = get_conf_trf(conf)

        self.conf_i = NoGradParamDict({ij: pred1_conf[n] for n, ij in enumerate(self.str_edges)})
        self.conf_j = NoGradParamDict({ij: pred2_conf[n] for n, ij in enumerate(self.str_edges)})
        self.im_conf = self._compute_img_conf(pred1_conf, pred2_conf)

        # pairwise pose parameters
        self.base_scale = base_scale
        self.norm_pw_scale = True
        self.pw_break = pw_break
        self.POSE_DIM = 7
        self.pw_poses = nn.Parameter(rand_pose((self.n_edges, 1+self.POSE_DIM)))  # pairwise poses
        self.pw_adaptors = nn.Parameter(torch.zeros((self.n_edges, 2)))  # slight xy/z adaptation
        self.pw_adaptors.requires_grad_(allow_pw_adaptors)
        self.has_im_poses = False
        self.rand_pose = rand_pose

        # possibly store images for show_pointcloud
        self.imgs = None
        if 'img' in view1 and 'img' in view2:
            imgs = [torch.zeros((3,)+hw) for hw in self.imshapes]
            for v, (i, j) in enumerate(self.edges):
                imgs[i] = view1['img'][v]
                imgs[j] = view2['img'][v]
            self.imgs = rgb(imgs)

    @property
    def n_edges(self):
        """Number of pairwise observations."""
        return len(self.edges)

    @property
    def str_edges(self):
        """Edge keys (as produced by edge_str) in the same order as `self.edges`."""
        return [edge_str(i, j) for i, j in self.edges]

    @property
    def imsizes(self):
        """Per-image (width, height), i.e. `imshapes` with both entries swapped."""
        return [(w, h) for h, w in self.imshapes]

    @property
    def device(self):
        """Device of the first registered parameter."""
        return next(iter(self.parameters())).device

    def state_dict(self, trainable=True):
        """ Return a filtered state dict.

        With trainable=True, entries whose key starts with '_', 'pred_i.', 'pred_j.', 'conf_i.'
        or 'conf_j.' are dropped; with trainable=False only those entries are kept.
        """
        all_params = super().state_dict()
        return {k: v for k, v in all_params.items() if k.startswith(('_', 'pred_i.', 'pred_j.', 'conf_i.', 'conf_j.')) != trainable}

    def load_state_dict(self, data):
        """Load `data` on top of the current non-trainable entries (see `state_dict`)."""
        return super().load_state_dict(self.state_dict(trainable=False) | data)

    def _check_edges(self):
        """Check that image indices are exactly 0..N-1 and return N."""
        indices = sorted({i for edge in self.edges for i in edge})
        assert indices == list(range(len(indices))), 'bad pair indices: missing values '
        return len(indices)

    @torch.no_grad()
    def _compute_img_conf(self, pred1_conf, pred2_conf):
        """Per-image confidence = element-wise max over all edges in which the image appears."""
        im_conf = nn.ParameterList([torch.zeros(hw, device=self.device) for hw in self.imshapes])
        for e, (i, j) in enumerate(self.edges):
            im_conf[i] = torch.maximum(im_conf[i], pred1_conf[e])
            im_conf[j] = torch.maximum(im_conf[j], pred2_conf[e])
        return im_conf

    def get_adaptors(self):
        """Return per-edge 3-component scaling factors derived from `pw_adaptors`."""
        adapt = self.pw_adaptors
        adapt = torch.cat((adapt[:, 0:1], adapt), dim=-1)  # (scale_xy, scale_xy, scale_z)
        if self.norm_pw_scale:  # normalize so that the product == 1
            adapt = adapt - adapt.mean(dim=1, keepdim=True)
        return (adapt / self.pw_break).exp()

    def _get_poses(self, poses):
        """Convert raw pose rows (columns 0:4 quaternion, 4:7 translation) to homogeneous matrices."""
        # normalize rotation
        Q = poses[:, :4]
        T = signed_expm1(poses[:, 4:7])
        RT = roma.RigidUnitQuat(Q, T).normalize().to_homogeneous()
        return RT

    def _set_pose(self, poses, idx, R, T=None, scale=None, force=False):
        """ Write rotation / translation / scale into `poses[idx]` in place and return that row.

        Nothing is written unless the row requires grad or `force` is set.
        `R` may be a 3x3 rotation, a 4x4 matrix (then `T` must be None), or None.
        """
        # all poses == cam-to-world
        pose = poses[idx]
        if not (pose.requires_grad or force):
            return pose

        # guard against R=None before looking at its shape
        if R is not None and R.shape == (4, 4):
            assert T is None
            T = R[:3, 3]
            R = R[:3, :3]

        if R is not None:
            pose.data[0:4] = roma.rotmat_to_unitquat(R)
        if T is not None:
            pose.data[4:7] = signed_log1p(T / (scale or 1))  # translation is function of scale

        if scale is not None:
            assert poses.shape[-1] in (8, 13)
            pose.data[-1] = np.log(float(scale))
        return pose

    def get_pw_norm_scale_factor(self):
        """Factor that brings the mean pairwise log-scale to log(base_scale); 1 when disabled."""
        if self.norm_pw_scale:
            # normalize scales so that things cannot go south
            # we want that exp(scale) ~= self.base_scale
            return (np.log(self.base_scale) - self.pw_poses[:, -1].mean()).exp()
        else:
            return 1  # don't norm scale for known poses

    def get_pw_scale(self):
        """Per-edge scale: exp of the last pose column times the normalisation factor."""
        scale = self.pw_poses[:, -1].exp()  # (n_edges,)
        scale = scale * self.get_pw_norm_scale_factor()
        return scale

    def get_pw_poses(self):  # cam to world
        """Pairwise poses as homogeneous matrices with the first three rows multiplied by the scale."""
        RT = self._get_poses(self.pw_poses)
        scaled_RT = RT.clone()
        scaled_RT[:, :3] *= self.get_pw_scale().view(-1, 1, 1)  # scale the rotation AND translation
        return scaled_RT

    def get_masks(self):
        """Per-image boolean masks of entries whose confidence exceeds `min_conf_thr`."""
        return [(conf > self.min_conf_thr) for conf in self.im_conf]

    def depth_to_pts3d(self):
        raise NotImplementedError()

    def get_pts3d(self, raw=False):
        """Return `depth_to_pts3d()`; unless `raw`, each entry is cropped to h*w rows and viewed as (h, w, 3)."""
        res = self.depth_to_pts3d()
        if not raw:
            res = [dm[:h*w].view(h, w, 3) for dm, (h, w) in zip(res, self.imshapes)]
        return res

    def _set_focal(self, idx, focal, force=False):
        raise NotImplementedError()

    def get_focals(self):
        raise NotImplementedError()

    def get_known_focal_mask(self):
        raise NotImplementedError()

    def get_principal_points(self):
        raise NotImplementedError()

    def get_conf(self, mode=None):
        """Per-image confidences passed through `self.conf_trf`, or through get_conf_trf(mode) if given."""
        trf = self.conf_trf if mode is None else get_conf_trf(mode)
        return [trf(c) for c in self.im_conf]

    def get_im_poses(self):
        raise NotImplementedError()

    def _set_depthmap(self, idx, depth, force=False):
        raise NotImplementedError()

    def get_depthmaps(self, raw=False):
        raise NotImplementedError()

    @torch.no_grad()
    def clean_pointcloud(self, tol=0.001, max_bad_conf=0):
        """ Method:
            1) express all 3d points in each camera coordinate frame
            2) if they're in front of a depthmap --> then lower their confidence

        Returns a deep copy of self with the adjusted `im_conf`; self is left untouched.
        Relies on `self.get_intrinsics()`, which is not defined in this base class.
        """
        assert 0 <= tol < 1
        cams = inv(self.get_im_poses())
        K = self.get_intrinsics()
        depthmaps = self.get_depthmaps()
        res = deepcopy(self)

        for i, pts3d in enumerate(self.depth_to_pts3d()):
            for j in range(self.n_imgs):
                if i == j:
                    continue

                # project 3dpts in other view
                Hi, Wi = self.imshapes[i]
                Hj, Wj = self.imshapes[j]
                proj = geotrf(cams[j], pts3d[:Hi*Wi]).reshape(Hi, Wi, 3)
                proj_depth = proj[:, :, 2]
                u, v = geotrf(K[j], proj, norm=1, ncol=2).round().long().unbind(-1)

                # check which points are actually in the visible cone
                msk_i = (proj_depth > 0) & (0 <= u) & (u < Wj) & (0 <= v) & (v < Hj)
                msk_j = v[msk_i], u[msk_i]

                # find bad points = those in front but less confident
                bad_points = (proj_depth[msk_i] < (1-tol) * depthmaps[j][msk_j]
                              ) & (res.im_conf[i][msk_i] < res.im_conf[j][msk_j])

                bad_msk_i = msk_i.clone()
                bad_msk_i[msk_i] = bad_points
                res.im_conf[i][bad_msk_i] = res.im_conf[i][bad_msk_i].clip_(max=max_bad_conf)

        return res

    def forward(self, ret_details=False):
        """ Alignment loss averaged over edges.

        With ret_details=True also returns an (n_imgs, n_imgs) tensor holding the per-edge loss
        at [i, j] and -1 where there is no edge.
        """
        pw_poses = self.get_pw_poses()  # cam-to-world
        pw_adapt = self.get_adaptors()
        proj_pts3d = self.get_pts3d()
        # pre-compute pixel weights
        weight_i = {i_j: self.conf_trf(c) for i_j, c in self.conf_i.items()}
        weight_j = {i_j: self.conf_trf(c) for i_j, c in self.conf_j.items()}

        loss = 0
        if ret_details:
            details = -torch.ones((self.n_imgs, self.n_imgs))

        for e, (i, j) in enumerate(self.edges):
            i_j = edge_str(i, j)
            # distance in image i and j
            aligned_pred_i = geotrf(pw_poses[e], pw_adapt[e] * self.pred_i[i_j])
            aligned_pred_j = geotrf(pw_poses[e], pw_adapt[e] * self.pred_j[i_j])
            li = self.dist(proj_pts3d[i], aligned_pred_i, weight=weight_i[i_j]).mean()
            lj = self.dist(proj_pts3d[j], aligned_pred_j, weight=weight_j[i_j]).mean()
            loss = loss + li + lj

            if ret_details:
                details[i, j] = li + lj
        loss /= self.n_edges  # average over all pairs

        if ret_details:
            return loss, details
        return loss

    def compute_global_alignment(self, init=None, niter_PnP=10, **kw):
        """ Optionally initialise poses, then run `global_alignment_loop(self, **kw)`.

        `init` is None (no initialisation), 'msp'/'mst' (minimum spanning tree) or 'known_poses'.
        """
        if init is None:
            pass
        elif init == 'msp' or init == 'mst':
            init_fun.init_minimum_spanning_tree(self, niter_PnP=niter_PnP)
        elif init == 'known_poses':
            init_fun.init_from_known_poses(self, min_conf_thr=self.min_conf_thr, niter_PnP=niter_PnP)
        else:
            raise ValueError(f'bad value for {init=}')

        global_alignment_loop(self, **kw)

    @torch.no_grad()
    def mask_sky(self):
        """Return a deep copy whose confidence is zeroed where segment_sky flags the stored images."""
        res = deepcopy(self)
        for i in range(self.n_imgs):
            sky = segment_sky(self.imgs[i])
            res.im_conf[i][sky] = 0
        return res

    def show(self, show_pw_cams=False, show_pw_pts3d=False, cam_size=None, **kw):
        """ Display point clouds and cameras in a SceneViz and return it.

        `show_pw_cams` / `show_pw_pts3d` additionally draw the pairwise cameras / point clouds.
        Extra keyword arguments are forwarded to `viz.show`.
        """
        viz = SceneViz()
        if self.imgs is None:
            colors = np.random.randint(0, 256, size=(self.n_imgs, 3))
            colors = list(map(tuple, colors.tolist()))
            # computed once instead of once per image
            pts3d = self.get_pts3d()
            masks = self.get_masks()
            for n in range(self.n_imgs):
                viz.add_pointcloud(pts3d[n], colors[n], masks[n])
        else:
            viz.add_pointcloud(self.get_pts3d(), self.imgs, self.get_masks())
            colors = np.random.randint(256, size=(self.n_imgs, 3))

        # camera poses
        im_poses = to_numpy(self.get_im_poses())
        if cam_size is None:
            cam_size = auto_cam_size(im_poses)
        viz.add_cameras(im_poses, self.get_focals(), colors=colors,
                        images=self.imgs, imsizes=self.imsizes, cam_size=cam_size)

        # both optional overlays need the pairwise poses, so fetch them if either is requested
        if show_pw_cams or show_pw_pts3d:
            pw_poses = self.get_pw_poses()

        if show_pw_cams:
            viz.add_cameras(pw_poses, color=(192, 0, 192), cam_size=cam_size)

        if show_pw_pts3d:
            pts = [geotrf(pw_poses[e], self.pred_i[edge_str(i, j)]) for e, (i, j) in enumerate(self.edges)]
            viz.add_pointcloud(pts, (128, 0, 128))

        viz.show(**kw)
        return viz


def global_alignment_loop(net, lr=0.01, niter=300, schedule='cosine', lr_min=1e-6, verbose=False):
    """ Run `niter` Adam steps on the trainable parameters of `net`, minimising `net()`.

    The learning rate goes from `lr` to `lr_min` following `schedule` ('cosine' or 'linear').
    Returns `net` immediately when nothing is trainable; otherwise returns None.
    """
    params = [p for p in net.parameters() if p.requires_grad]
    if not params:
        return net

    if verbose:
        print([name for name, value in net.named_parameters() if value.requires_grad])

    lr_base = lr
    optimizer = torch.optim.Adam(params, lr=lr, betas=(0.9, 0.9))

    with tqdm.tqdm(total=niter) as bar:
        while bar.n < bar.total:
            t = bar.n / bar.total

            if schedule == 'cosine':
                lr = cosine_schedule(t, lr_base, lr_min)
            elif schedule == 'linear':
                lr = linear_schedule(t, lr_base, lr_min)
            else:
                raise ValueError(f'bad lr {schedule=}')
            adjust_learning_rate_by_lr(optimizer, lr)

            optimizer.zero_grad()
            loss = net()
            loss.backward()
            optimizer.step()
            loss = float(loss)
            bar.set_postfix_str(f'{lr=:g} loss={loss:g}')
            bar.update()


# diff --git
# a/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/dust3r/cloud_opt/commons.py b/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/dust3r/cloud_opt/commons.py
# new file mode 100644
# index 0000000000000000000000000000000000000000..3be9f855a69ea18c82dcc8e5769e0149a59649bd
# --- /dev/null
# +++ b/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/dust3r/cloud_opt/commons.py
# @@ -0,0 +1,90 @@
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# utility functions for global alignment
# --------------------------------------------------------
import torch
import torch.nn as nn
import numpy as np


def edge_str(i, j):
    """Key used to store per-edge data: the two indices joined by an underscore."""
    return f'{i}_{j}'


def i_j_ij(ij):
    """Map an (i, j) pair to (edge key, original pair)."""
    key = edge_str(*ij)
    return key, ij


def edge_conf(conf_i, conf_j, edge):
    """Score of one edge: product of the mean confidences stored under `edge`."""
    mean_i = conf_i[edge].mean()
    mean_j = conf_j[edge].mean()
    return float(mean_i * mean_j)


def compute_edge_scores(edges, conf_i, conf_j):
    """Return {(i, j): score} for an iterable of (edge key, (i, j)) items."""
    scores = {}
    for key, (i, j) in edges:
        scores[(i, j)] = edge_conf(conf_i, conf_j, key)
    return scores


def NoGradParamDict(x):
    """Wrap a dict into an nn.ParameterDict with gradients disabled."""
    assert isinstance(x, dict)
    frozen = nn.ParameterDict(x)
    return frozen.requires_grad_(False)


def get_imshapes(edges, pred_i, pred_j):
    """ Collect the first two dimensions of each image's prediction.

    The result has one entry per image index up to the largest index found in `edges`.
    An image whose shape was already recorded must show up with the same shape again.
    """
    n_imgs = 1 + max(max(pair) for pair in edges)
    imshapes = [None for _ in range(n_imgs)]
    for e, (i, j) in enumerate(edges):
        shape_i = tuple(pred_i[e].shape[0:2])
        shape_j = tuple(pred_j[e].shape[0:2])
        # both consistency checks happen before either entry is written
        for img, shape in ((i, shape_i), (j, shape_j)):
            known = imshapes[img]
            if known:
                assert known == shape, f'incorrect shape for image {img}'
        imshapes[i], imshapes[j] = shape_i, shape_j
    return imshapes


def get_conf_trf(mode):
    """ Return the confidence transform named by `mode`.

    'log' -> x.log(), 'sqrt' -> x.sqrt(), 'm1' -> x-1, 'id' / 'none' -> x unchanged.
    Any other mode raises ValueError.
    """
    def _log(x):
        return x.log()

    def _sqrt(x):
        return x.sqrt()

    def _minus_one(x):
        return x-1

    def _identity(x):
        return x

    table = (
        (('log',), _log),
        (('sqrt',), _sqrt),
        (('m1',), _minus_one),
        (('id', 'none'), _identity),
    )
    for names, conf_trf in table:
        if mode in names:
            return conf_trf
    raise ValueError(f'bad mode for {mode=}')


def l2_dist(a, b, weight):
    """Weighted squared Euclidean distance along the last dimension."""
    diff = a - b
    sq_norm = diff.square().sum(dim=-1)
    return sq_norm * weight


def l1_dist(a, b, weight):
    """Weighted norm of (a - b) along the last dimension (not squared)."""
    diff = a - b
    return diff.norm(dim=-1) * weight


ALL_DISTS = dict(l1=l1_dist, l2=l2_dist)


def signed_log1p(x):
    """sign(x) * log(1 + |x|)."""
    return torch.sign(x) * torch.log1p(torch.abs(x))


def signed_expm1(x):
    """sign(x) * (exp(|x|) - 1), the inverse of signed_log1p."""
    return torch.sign(x) * torch.expm1(torch.abs(x))


def cosine_schedule(t, lr_start, lr_end):
    """Half-cosine interpolation from lr_start (t=0) to lr_end (t=1)."""
    assert 0 <= t <= 1
    span = lr_start - lr_end
    return lr_end + span * (1+np.cos(t * np.pi))/2


def linear_schedule(t, lr_start, lr_end):
    """Linear interpolation from lr_start (t=0) to lr_end (t=1)."""
    assert 0 <= t <= 1
    delta = lr_end - lr_start
    return lr_start + delta * t


# diff --git a/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/dust3r/cloud_opt/init_im_poses.py b/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/dust3r/cloud_opt/init_im_poses.py
# new file mode 100644
# index 0000000000000000000000000000000000000000..6ed6116be3b81ed5b483fa87dfb013e1e5f1d29a
# --- /dev/null
# +++ b/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/dust3r/cloud_opt/init_im_poses.py
# @@ -0,0 +1,312 @@
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# Initialization functions for global alignment
# --------------------------------------------------------
from functools import cache

import numpy as np
import scipy.sparse as sp
import torch
import cv2
import roma
from tqdm import tqdm

from dust3r.utils.geometry import geotrf, inv, get_med_dist_between_poses
from dust3r.post_process import estimate_focal_knowing_depth
from dust3r.viz import to_numpy

from dust3r.cloud_opt.commons import edge_str, i_j_ij, compute_edge_scores


@torch.no_grad()
def init_from_known_poses(self, niter_PnP=10, min_conf_thr=3):
    """ Initialise pairwise poses and depthmaps of optimizer `self` when every image
    pose and focal is already known (asserted below).
    """
    device = self.device

    # indices of known poses
    nkp, known_poses_msk, known_poses = get_known_poses(self)
    assert nkp == self.n_imgs, 'not all poses are known'

    # get all focals
    nkf, _, im_focals = get_known_focals(self)
    assert nkf == self.n_imgs
    im_pp = self.get_principal_points()

    best_depthmaps = {}
    # init all pairwise poses
    for e, (i, j) in enumerate(tqdm(self.edges)):
        i_j = edge_str(i, j)

        # find relative pose for this pair
        P1 = torch.eye(4, device=device)
        msk = self.conf_i[i_j] > min(min_conf_thr, self.conf_i[i_j].min() - 0.1)
        # NOTE(review): fast_pnp returns None when PnP fails; the unpacking below would then raise
        _, P2 = fast_pnp(self.pred_j[i_j], float(im_focals[i].mean()),
                         pp=im_pp[i], msk=msk, device=device, niter_PnP=niter_PnP)

        # align the two predicted camera with the two gt cameras
        s, R, T = align_multiple_poses(torch.stack((P1, P2)), known_poses[[i, j]])
        # normally we have known_poses[i] ~= sRT_to_4x4(s,R,T,device) @ P1
        # and geotrf(sRT_to_4x4(1,R,T,device), s*P2[:3,3])
        self._set_pose(self.pw_poses, e, R, T, scale=s)

        # remember if this is a good depthmap
        score = float(self.conf_i[i_j].mean())
        if score > best_depthmaps.get(i, (0,))[0]:
            best_depthmaps[i] = score, i_j, s

    # init all image poses
    for n in range(self.n_imgs):
        assert known_poses_msk[n]
        _, i_j, scale = best_depthmaps[n]
        depth = self.pred_i[i_j][:, :, 2]
        self._set_depthmap(n, depth * scale)


@torch.no_grad()
def init_minimum_spanning_tree(self, **kw):
    """ Init all camera poses (image-wise and pairwise poses) given
        an initial set of pairwise estimations.
    """
    device = self.device
    pts3d, _, im_focals, im_poses = minimum_spanning_tree(self.imshapes, self.edges,
                                                          self.pred_i, self.pred_j, self.conf_i, self.conf_j, self.im_conf, self.min_conf_thr,
                                                          device, has_im_poses=self.has_im_poses, **kw)

    return init_from_pts3d(self, pts3d, im_focals, im_poses)


def init_from_pts3d(self, pts3d, im_focals, im_poses):
    """ Initialise optimizer `self` from per-image world points, focals and poses.

    `pts3d` entries are modified in place. Prints the resulting loss.
    """
    # init poses
    nkp, known_poses_msk, known_poses = get_known_poses(self)
    if nkp == 1:
        raise NotImplementedError("Would be simpler to just align everything afterwards on the single known pose")
    elif nkp > 1:
        # global rigid SE3 alignment
        s, R, T = align_multiple_poses(im_poses[known_poses_msk], known_poses[known_poses_msk])
        trf = sRT_to_4x4(s, R, T, device=known_poses.device)

        # rotate everything
        im_poses = trf @ im_poses
        im_poses[:, :3, :3] /= s  # undo scaling on the rotation part
        for img_pts3d in pts3d:
            img_pts3d[:] = geotrf(trf, img_pts3d)

    # set all pairwise poses
    for e, (i, j) in enumerate(self.edges):
        i_j = edge_str(i, j)
        # compute transform that goes from cam to world
        s, R, T = rigid_points_registration(self.pred_i[i_j], pts3d[i], conf=self.conf_i[i_j])
        self._set_pose(self.pw_poses, e, R, T, scale=s)

    # take into account the scale normalization
    s_factor = self.get_pw_norm_scale_factor()
    im_poses[:, :3, 3] *= s_factor  # apply downscaling factor
    for img_pts3d in pts3d:
        img_pts3d *= s_factor

    # init all image poses
    if self.has_im_poses:
        for i in range(self.n_imgs):
            cam2world = im_poses[i]
            depth = geotrf(inv(cam2world), pts3d[i])[..., 2]
            self._set_depthmap(i, depth)
            self._set_pose(self.im_poses, i, cam2world)
            if im_focals[i] is not None:
                self._set_focal(i, im_focals[i])

    print(' init loss =', float(self()))


def minimum_spanning_tree(imshapes, edges, pred_i, pred_j, conf_i, conf_j, im_conf, min_conf_thr,
                          device, has_im_poses=True, niter_PnP=10):
    """ Chain pairwise predictions along a spanning tree of the edge-score graph.

    Returns (pts3d, msp_edges, im_focals, im_poses); the last two are None when
    `has_im_poses` is False.
    """
    n_imgs = len(imshapes)
    sparse_graph = -dict_to_sparse_graph(compute_edge_scores(map(i_j_ij, edges), conf_i, conf_j))
    msp = sp.csgraph.minimum_spanning_tree(sparse_graph).tocoo()

    # temp variable to store 3d points
    pts3d = [None] * len(imshapes)

    todo = sorted(zip(-msp.data, msp.row, msp.col))  # sorted edges
    im_poses = [None] * n_imgs
    im_focals = [None] * n_imgs

    # init with strongest edge
    score, i, j = todo.pop()
    print(f' init edge ({i}*,{j}*) {score=}')
    i_j = edge_str(i, j)
    pts3d[i] = pred_i[i_j].clone()
    pts3d[j] = pred_j[i_j].clone()
    done = {i, j}
    if has_im_poses:
        im_poses[i] = torch.eye(4, device=device)
        im_focals[i] = estimate_focal(pred_i[i_j])

    # set initial pointcloud based on pairwise graph
    msp_edges = [(i, j)]
    # NOTE(review): an edge touching no finished image is re-queued; if some edge can never be
    # connected to the finished set, this loop would not terminate
    while todo:
        # each time, predict the next one
        score, i, j = todo.pop()
        # refresh the key for THIS edge before any use (it previously lagged one iteration behind)
        i_j = edge_str(i, j)

        if im_focals[i] is None:
            im_focals[i] = estimate_focal(pred_i[i_j])

        if i in done:
            print(f' init edge ({i},{j}*) {score=}')
            assert j not in done
            # align pred[i] with pts3d[i], and then set j accordingly
            s, R, T = rigid_points_registration(pred_i[i_j], pts3d[i], conf=conf_i[i_j])
            trf = sRT_to_4x4(s, R, T, device)
            pts3d[j] = geotrf(trf, pred_j[i_j])
            done.add(j)
            msp_edges.append((i, j))

            if has_im_poses and im_poses[i] is None:
                im_poses[i] = sRT_to_4x4(1, R, T, device)

        elif j in done:
            print(f' init edge ({i}*,{j}) {score=}')
            assert i not in done
            s, R, T = rigid_points_registration(pred_j[i_j], pts3d[j], conf=conf_j[i_j])
            trf = sRT_to_4x4(s, R, T, device)
            pts3d[i] = geotrf(trf, pred_i[i_j])
            done.add(i)
            msp_edges.append((i, j))

            if has_im_poses and im_poses[i] is None:
                im_poses[i] = sRT_to_4x4(1, R, T, device)
        else:
            # let's try again later
            todo.insert(0, (score, i, j))

    if has_im_poses:
        # complete all missing information
        pair_scores = list(sparse_graph.values())  # already negative scores: less is best
        edges_from_best_to_worse = np.array(list(sparse_graph.keys()))[np.argsort(pair_scores)]
        for i, j in edges_from_best_to_worse.tolist():
            if im_focals[i] is None:
                im_focals[i] = estimate_focal(pred_i[edge_str(i, j)])

        for i in range(n_imgs):
            if im_poses[i] is None:
                msk = im_conf[i] > min_conf_thr
                res = fast_pnp(pts3d[i], im_focals[i], msk=msk, device=device, niter_PnP=niter_PnP)
                if res:
                    im_focals[i], im_poses[i] = res
            if im_poses[i] is None:
                im_poses[i] = torch.eye(4, device=device)
        im_poses = torch.stack(im_poses)
    else:
        im_poses = im_focals = None

    return pts3d, msp_edges, im_focals, im_poses


def dict_to_sparse_graph(dic):
    """Store {(i, j): value} into a square scipy dok_array sized by the largest index + 1."""
    n_imgs = max(max(e) for e in dic) + 1
    res = sp.dok_array((n_imgs, n_imgs))
    for edge, value in dic.items():
        res[edge] = value
    return res


def rigid_points_registration(pts1, pts2, conf):
    """Weighted registration of pts1 onto pts2 with scaling; returns (scale, R, T)."""
    R, T, s = roma.rigid_points_registration(
        pts1.reshape(-1, 3), pts2.reshape(-1, 3), weights=conf.ravel(), compute_scaling=True)
    return s, R, T  # return un-scaled (R, T)


def sRT_to_4x4(scale, R, T, device):
    """Assemble a 4x4 matrix whose upper-left block is R * scale and whose last column is T."""
    trf = torch.eye(4, device=device)
    trf[:3, :3] = R * scale
    trf[:3, 3] = T.ravel()  # doesn't need scaling
    return trf


def estimate_focal(pts3d_i, pp=None):
    """ Estimate a focal (returned as float) from an (H, W, 3) pointmap.

    The principal point defaults to the image centre (W/2, H/2).
    """
    if pp is None:
        H, W, THREE = pts3d_i.shape
        assert THREE == 3
        pp = torch.tensor((W/2, H/2), device=pts3d_i.device)
    focal = estimate_focal_knowing_depth(pts3d_i.unsqueeze(0), pp.unsqueeze(
        0), focal_mode='weiszfeld', min_focal=0.5, max_focal=3.5).ravel()
    return float(focal)


@cache
def pixel_grid(H, W):
    """Cached float32 grid built from np.mgrid[:W, :H], transposed."""
    return np.mgrid[:W, :H].T.astype(np.float32)


def fast_pnp(pts3d, focal, msk, device, pp=None, niter_PnP=10):
    """ Estimate a camera pose (and focal, if not given) with RANSAC-PnP.

    Returns (focal, cam-to-world 4x4 matrix), or None when fewer than 4 points are
    selected by `msk` or no candidate focal yields a solution.
    """
    # extract camera poses and focals with RANSAC-PnP
    if msk.sum() < 4:
        return None  # we need at least 4 points for PnP
    pts3d, msk = map(to_numpy, (pts3d, msk))

    H, W, THREE = pts3d.shape
    assert THREE == 3
    pixels = pixel_grid(H, W)

    if focal is None:
        S = max(W, H)
        tentative_focals = np.geomspace(S/2, S*3, 21)
    else:
        tentative_focals = [focal]

    if pp is None:
        pp = (W/2, H/2)
    else:
        pp = to_numpy(pp)

    best = 0,
    for focal in tentative_focals:
        K = np.float32([(focal, 0, pp[0]), (0, focal, pp[1]), (0, 0, 1)])

        success, R, T, inliers = cv2.solvePnPRansac(pts3d[msk], pixels[msk], K, None,
                                                    iterationsCount=niter_PnP, reprojectionError=5, flags=cv2.SOLVEPNP_SQPNP)
        if not success:
            continue

        # keep the candidate with the most inliers
        score = len(inliers)
        if score > best[0]:
            best = score, R, T, focal

    if not best[0]:
        return None

    _, R, T, best_focal = best
    R = cv2.Rodrigues(R)[0]  # world to cam
    R, T = map(torch.from_numpy, (R, T))
    return best_focal, inv(sRT_to_4x4(1, R, T, device))  # cam to world


def get_known_poses(self):
    """Return (count, mask, poses) where the mask flags image poses that do not require grad."""
    if self.has_im_poses:
        known_poses_msk = torch.tensor([not (p.requires_grad) for p in self.im_poses])
        known_poses = self.get_im_poses()
        return known_poses_msk.sum(), known_poses_msk, known_poses
    else:
        return 0, None, None


def get_known_focals(self):
    """Return (count, mask, focals) using the optimizer's own known-focal mask."""
    if self.has_im_poses:
        known_focal_msk = self.get_known_focal_mask()
        known_focals = self.get_focals()
        return known_focal_msk.sum(), known_focal_msk, known_focals
    else:
        return 0, None, None


def align_multiple_poses(src_poses, target_poses):
    """ Similarity transform (scale, R, T) aligning two equally-sized stacks of 4x4 poses.

    Each pose contributes its centre plus a second point slightly offset along its z axis.
    """
    N = len(src_poses)
    assert src_poses.shape == target_poses.shape == (N, 4, 4)

    def center_and_z(poses):
        eps = get_med_dist_between_poses(poses) / 100
        return torch.cat((poses[:, :3, 3], poses[:, :3, 3] + eps*poses[:, :3, 2]))
    R, T, s = roma.rigid_points_registration(center_and_z(src_poses), center_and_z(target_poses), compute_scaling=True)
    return s, R, T


# diff --git a/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/dust3r/cloud_opt/optimizer.py b/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/dust3r/cloud_opt/optimizer.py
# new file mode 100644
# index
0000000000000000000000000000000000000000..e53636dac67739e6e92affae811855bd1e42ac96 --- /dev/null +++ b/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/dust3r/cloud_opt/optimizer.py @@ -0,0 +1,230 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# Main class for the implementation of the global alignment +# -------------------------------------------------------- +import numpy as np +import torch +import torch.nn as nn + +from dust3r.cloud_opt.base_opt import BasePCOptimizer +from dust3r.utils.geometry import xy_grid, geotrf +from dust3r.utils.device import to_cpu, to_numpy + + +class PointCloudOptimizer(BasePCOptimizer): + """ Optimize a global scene, given a list of pairwise observations. + Graph node: images + Graph edges: observations = (pred1, pred2) + """ + + def __init__(self, *args, optimize_pp=False, focal_break=20, **kwargs): + super().__init__(*args, **kwargs) + + self.has_im_poses = True # by definition of this class + self.focal_break = focal_break + + # adding thing to optimize + self.im_depthmaps = nn.ParameterList(torch.randn(H, W)/10-3 for H, W in self.imshapes) # log(depth) + self.im_poses = nn.ParameterList(self.rand_pose(self.POSE_DIM) for _ in range(self.n_imgs)) # camera poses + self.im_focals = nn.ParameterList(torch.FloatTensor( + [self.focal_break*np.log(max(H, W))]) for H, W in self.imshapes) # camera intrinsics + self.im_pp = nn.ParameterList(torch.zeros((2,)) for _ in range(self.n_imgs)) # camera intrinsics + self.im_pp.requires_grad_(optimize_pp) + + self.imshape = self.imshapes[0] + im_areas = [h*w for h, w in self.imshapes] + self.max_area = max(im_areas) + + # adding thing to optimize + self.im_depthmaps = ParameterStack(self.im_depthmaps, is_param=True, fill=self.max_area) + self.im_poses = ParameterStack(self.im_poses, is_param=True) + self.im_focals = 
ParameterStack(self.im_focals, is_param=True) + self.im_pp = ParameterStack(self.im_pp, is_param=True) + self.register_buffer('_pp', torch.tensor([(w/2, h/2) for h, w in self.imshapes])) + self.register_buffer('_grid', ParameterStack( + [xy_grid(W, H, device=self.device) for H, W in self.imshapes], fill=self.max_area)) + + # pre-compute pixel weights + self.register_buffer('_weight_i', ParameterStack( + [self.conf_trf(self.conf_i[i_j]) for i_j in self.str_edges], fill=self.max_area)) + self.register_buffer('_weight_j', ParameterStack( + [self.conf_trf(self.conf_j[i_j]) for i_j in self.str_edges], fill=self.max_area)) + + # precompute aa + self.register_buffer('_stacked_pred_i', ParameterStack(self.pred_i, self.str_edges, fill=self.max_area)) + self.register_buffer('_stacked_pred_j', ParameterStack(self.pred_j, self.str_edges, fill=self.max_area)) + self.register_buffer('_ei', torch.tensor([i for i, j in self.edges])) + self.register_buffer('_ej', torch.tensor([j for i, j in self.edges])) + self.total_area_i = sum([im_areas[i] for i, j in self.edges]) + self.total_area_j = sum([im_areas[j] for i, j in self.edges]) + + def _check_all_imgs_are_selected(self, msk): + assert np.all(self._get_msk_indices(msk) == np.arange(self.n_imgs)), 'incomplete mask!' 
+ + def preset_pose(self, known_poses, pose_msk=None): # cam-to-world + self._check_all_imgs_are_selected(pose_msk) + + if isinstance(known_poses, torch.Tensor) and known_poses.ndim == 2: + known_poses = [known_poses] + for idx, pose in zip(self._get_msk_indices(pose_msk), known_poses): + print(f' (setting pose #{idx} = {pose[:3,3]})') + self._no_grad(self._set_pose(self.im_poses, idx, torch.tensor(pose))) + + # normalize scale if there's less than 1 known pose + n_known_poses = sum((p.requires_grad is False) for p in self.im_poses) + self.norm_pw_scale = (n_known_poses <= 1) + + self.im_poses.requires_grad_(False) + self.norm_pw_scale = False + + def preset_focal(self, known_focals, msk=None): + self._check_all_imgs_are_selected(msk) + + for idx, focal in zip(self._get_msk_indices(msk), known_focals): + print(f' (setting focal #{idx} = {focal})') + self._no_grad(self._set_focal(idx, focal)) + + self.im_focals.requires_grad_(False) + + def preset_principal_point(self, known_pp, msk=None): + self._check_all_imgs_are_selected(msk) + + for idx, pp in zip(self._get_msk_indices(msk), known_pp): + print(f' (setting principal point #{idx} = {pp})') + self._no_grad(self._set_principal_point(idx, pp)) + + self.im_pp.requires_grad_(False) + + def _no_grad(self, tensor): + assert tensor.requires_grad, 'it must be True at this point, otherwise no modification occurs' + + def _set_focal(self, idx, focal, force=False): + param = self.im_focals[idx] + if param.requires_grad or force: # can only init a parameter not already initialized + param.data[:] = self.focal_break * np.log(focal) + return param + + def get_focals(self): + log_focals = torch.stack(list(self.im_focals), dim=0) + return (log_focals / self.focal_break).exp() + + def get_known_focal_mask(self): + return torch.tensor([not (p.requires_grad) for p in self.im_focals]) + + def _set_principal_point(self, idx, pp, force=False): + param = self.im_pp[idx] + H, W = self.imshapes[idx] + if param.requires_grad or force: # 
can only init a parameter not already initialized + param.data[:] = to_cpu(to_numpy(pp) - (W/2, H/2)) / 10 + return param + + def get_principal_points(self): + return self._pp + 10 * self.im_pp + + def get_intrinsics(self): + K = torch.zeros((self.n_imgs, 3, 3), device=self.device) + focals = self.get_focals().flatten() + K[:, 0, 0] = K[:, 1, 1] = focals + K[:, :2, 2] = self.get_principal_points() + K[:, 2, 2] = 1 + return K + + def get_im_poses(self): # cam to world + cam2world = self._get_poses(self.im_poses) + return cam2world + + def _set_depthmap(self, idx, depth, force=False): + depth = _ravel_hw(depth, self.max_area) + + param = self.im_depthmaps[idx] + if param.requires_grad or force: # can only init a parameter not already initialized + param.data[:] = depth.log().nan_to_num(neginf=0) + return param + + def get_depthmaps(self, raw=False): + res = self.im_depthmaps.exp() + if not raw: + res = [dm[:h*w].view(h, w) for dm, (h, w) in zip(res, self.imshapes)] + return res + + def depth_to_pts3d(self): + # Get depths and projection params if not provided + focals = self.get_focals() + pp = self.get_principal_points() + im_poses = self.get_im_poses() + depth = self.get_depthmaps(raw=True) + + # get pointmaps in camera frame + rel_ptmaps = _fast_depthmap_to_pts3d(depth, self._grid, focals, pp=pp) + # project to world frame + return geotrf(im_poses, rel_ptmaps) + + def get_pts3d(self, raw=False): + res = self.depth_to_pts3d() + if not raw: + res = [dm[:h*w].view(h, w, 3) for dm, (h, w) in zip(res, self.imshapes)] + return res + + def forward(self): + pw_poses = self.get_pw_poses() # cam-to-world + pw_adapt = self.get_adaptors().unsqueeze(1) + proj_pts3d = self.get_pts3d(raw=True) + + # rotate pairwise prediction according to pw_poses + aligned_pred_i = geotrf(pw_poses, pw_adapt * self._stacked_pred_i) + aligned_pred_j = geotrf(pw_poses, pw_adapt * self._stacked_pred_j) + + # compute the less + li = self.dist(proj_pts3d[self._ei], aligned_pred_i, 
weight=self._weight_i).sum() / self.total_area_i + lj = self.dist(proj_pts3d[self._ej], aligned_pred_j, weight=self._weight_j).sum() / self.total_area_j + + return li + lj + + +def _fast_depthmap_to_pts3d(depth, pixel_grid, focal, pp): + pp = pp.unsqueeze(1) + focal = focal.unsqueeze(1) + assert focal.shape == (len(depth), 1, 1) + assert pp.shape == (len(depth), 1, 2) + assert pixel_grid.shape == depth.shape + (2,) + depth = depth.unsqueeze(-1) + return torch.cat((depth * (pixel_grid - pp) / focal, depth), dim=-1) + + +def ParameterStack(params, keys=None, is_param=None, fill=0): + if keys is not None: + params = [params[k] for k in keys] + + if fill > 0: + params = [_ravel_hw(p, fill) for p in params] + + requires_grad = params[0].requires_grad + assert all(p.requires_grad == requires_grad for p in params) + + params = torch.stack(list(params)).float().detach() + if is_param or requires_grad: + params = nn.Parameter(params) + params.requires_grad_(requires_grad) + return params + + +def _ravel_hw(tensor, fill=0): + # ravel H,W + tensor = tensor.view((tensor.shape[0] * tensor.shape[1],) + tensor.shape[2:]) + + if len(tensor) < fill: + tensor = torch.cat((tensor, tensor.new_zeros((fill - len(tensor),)+tensor.shape[1:]))) + return tensor + + +def acceptable_focal_range(H, W, minf=0.5, maxf=3.5): + focal_base = max(H, W) / (2 * np.tan(np.deg2rad(60) / 2)) # size / 1.1547005383792515 + return minf*focal_base, maxf*focal_base + + +def apply_mask(img, msk): + img = img.copy() + img[msk] = 0 + return img diff --git a/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/dust3r/cloud_opt/pair_viewer.py b/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/dust3r/cloud_opt/pair_viewer.py new file mode 100644 index 0000000000000000000000000000000000000000..a49e9a17df9ddc489b8fe3dddc027636c0c5973d --- /dev/null +++ b/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/dust3r/cloud_opt/pair_viewer.py @@ -0,0 +1,125 @@ +# Copyright (C) 2024-present 
Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# Dummy optimizer for visualizing pairs +# -------------------------------------------------------- +import numpy as np +import torch +import torch.nn as nn +import cv2 + +from dust3r.cloud_opt.base_opt import BasePCOptimizer +from dust3r.utils.geometry import inv, geotrf, depthmap_to_absolute_camera_coordinates +from dust3r.cloud_opt.commons import edge_str +from dust3r.post_process import estimate_focal_knowing_depth + + +class PairViewer (BasePCOptimizer): + """ + This a Dummy Optimizer. + To use only when the goal is to visualize the results for a pair of images (with is_symmetrized) + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + assert self.is_symmetrized and self.n_edges == 2 + self.has_im_poses = True + + # compute all parameters directly from raw input + self.focals = [] + self.pp = [] + rel_poses = [] + confs = [] + for i in range(self.n_imgs): + conf = float(self.conf_i[edge_str(i, 1-i)].mean() * self.conf_j[edge_str(i, 1-i)].mean()) + print(f' - {conf=:.3} for edge {i}-{1-i}') + confs.append(conf) + + H, W = self.imshapes[i] + pts3d = self.pred_i[edge_str(i, 1-i)] + pp = torch.tensor((W/2, H/2)) + focal = float(estimate_focal_knowing_depth(pts3d[None], pp, focal_mode='weiszfeld')) + self.focals.append(focal) + self.pp.append(pp) + + # estimate the pose of pts1 in image 2 + pixels = np.mgrid[:W, :H].T.astype(np.float32) + pts3d = self.pred_j[edge_str(1-i, i)].numpy() + assert pts3d.shape[:2] == (H, W) + msk = self.get_masks()[i].numpy() + K = np.float32([(focal, 0, pp[0]), (0, focal, pp[1]), (0, 0, 1)]) + + try: + res = cv2.solvePnPRansac(pts3d[msk], pixels[msk], K, None, + iterationsCount=100, reprojectionError=5, flags=cv2.SOLVEPNP_SQPNP) + success, R, T, inliers = res + assert success + + R = cv2.Rodrigues(R)[0] # world to cam + pose = 
inv(np.r_[np.c_[R, T], [(0, 0, 0, 1)]]) # cam to world + except: + pose = np.eye(4) + rel_poses.append(torch.from_numpy(pose.astype(np.float32))) + + # let's use the pair with the most confidence + if confs[0] > confs[1]: + # ptcloud is expressed in camera1 + self.im_poses = [torch.eye(4), rel_poses[1]] # I, cam2-to-cam1 + self.depth = [self.pred_i['0_1'][..., 2], geotrf(inv(rel_poses[1]), self.pred_j['0_1'])[..., 2]] + else: + # ptcloud is expressed in camera2 + self.im_poses = [rel_poses[0], torch.eye(4)] # I, cam1-to-cam2 + self.depth = [geotrf(inv(rel_poses[0]), self.pred_j['1_0'])[..., 2], self.pred_i['1_0'][..., 2]] + + self.im_poses = nn.Parameter(torch.stack(self.im_poses, dim=0), requires_grad=False) + self.focals = nn.Parameter(torch.tensor(self.focals), requires_grad=False) + self.pp = nn.Parameter(torch.stack(self.pp, dim=0), requires_grad=False) + self.depth = nn.ParameterList(self.depth) + for p in self.parameters(): + p.requires_grad = False + + def _set_depthmap(self, idx, depth, force=False): + print('_set_depthmap is ignored in PairViewer') + return + + def get_depthmaps(self, raw=False): + depth = [d.to(self.device) for d in self.depth] + return depth + + def _set_focal(self, idx, focal, force=False): + self.focals[idx] = focal + + def get_focals(self): + return self.focals + + def get_known_focal_mask(self): + return torch.tensor([not (p.requires_grad) for p in self.focals]) + + def get_principal_points(self): + return self.pp + + def get_intrinsics(self): + focals = self.get_focals() + pps = self.get_principal_points() + K = torch.zeros((len(focals), 3, 3), device=self.device) + for i in range(len(focals)): + K[i, 0, 0] = K[i, 1, 1] = focals[i] + K[i, :2, 2] = pps[i] + K[i, 2, 2] = 1 + return K + + def get_im_poses(self): + return self.im_poses + + def depth_to_pts3d(self): + pts3d = [] + for d, intrinsics, im_pose in zip(self.depth, self.get_intrinsics(), self.get_im_poses()): + pts, _ = depthmap_to_absolute_camera_coordinates(d.cpu().numpy(), 
+ intrinsics.cpu().numpy(), + im_pose.cpu().numpy()) + pts3d.append(torch.from_numpy(pts).to(device=self.device)) + return pts3d + + def forward(self): + return float('nan') diff --git a/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/dust3r/datasets/__init__.py b/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/dust3r/datasets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..cc5e79718e4a3eb2e31c60c8a390e61a19ec5432 --- /dev/null +++ b/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/dust3r/datasets/__init__.py @@ -0,0 +1,42 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +from .utils.transforms import * +from .base.batched_sampler import BatchedRandomSampler # noqa: F401 +from .co3d import Co3d # noqa: F401 + + +def get_data_loader(dataset, batch_size, num_workers=8, shuffle=True, drop_last=True, pin_mem=True): + import torch + from croco.utils.misc import get_world_size, get_rank + + # pytorch dataset + if isinstance(dataset, str): + dataset = eval(dataset) + + world_size = get_world_size() + rank = get_rank() + + try: + sampler = dataset.make_sampler(batch_size, shuffle=shuffle, world_size=world_size, + rank=rank, drop_last=drop_last) + except (AttributeError, NotImplementedError): + # not avail for this dataset + if torch.distributed.is_initialized(): + sampler = torch.utils.data.DistributedSampler( + dataset, num_replicas=world_size, rank=rank, shuffle=shuffle, drop_last=drop_last + ) + elif shuffle: + sampler = torch.utils.data.RandomSampler(dataset) + else: + sampler = torch.utils.data.SequentialSampler(dataset) + + data_loader = torch.utils.data.DataLoader( + dataset, + sampler=sampler, + batch_size=batch_size, + num_workers=num_workers, + pin_memory=pin_mem, + drop_last=drop_last, + ) + + return data_loader diff --git 
a/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/dust3r/datasets/base/__init__.py b/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/dust3r/datasets/base/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a32692113d830ddc4af4e6ed608f222fbe062e6e --- /dev/null +++ b/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/dust3r/datasets/base/__init__.py @@ -0,0 +1,2 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). diff --git a/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/dust3r/datasets/base/base_stereo_view_dataset.py b/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/dust3r/datasets/base/base_stereo_view_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..17390ca29d4437fc41f3c946b235888af9e4c888 --- /dev/null +++ b/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/dust3r/datasets/base/base_stereo_view_dataset.py @@ -0,0 +1,220 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# base class for implementing datasets +# -------------------------------------------------------- +import PIL +import numpy as np +import torch + +from dust3r.datasets.base.easy_dataset import EasyDataset +from dust3r.datasets.utils.transforms import ImgNorm +from dust3r.utils.geometry import depthmap_to_absolute_camera_coordinates +import dust3r.datasets.utils.cropping as cropping + + +class BaseStereoViewDataset (EasyDataset): + """ Define all basic options. 
+ + Usage: + class MyDataset (BaseStereoViewDataset): + def _get_views(self, idx, rng): + # overload here + views = [] + views.append(dict(img=, ...)) + return views + """ + + def __init__(self, *, # only keyword arguments + split=None, + resolution=None, # square_size or (width, height) or list of [(width,height), ...] + transform=ImgNorm, + aug_crop=False, + seed=None): + self.num_views = 2 + self.split = split + self._set_resolutions(resolution) + + self.transform = transform + if isinstance(transform, str): + transform = eval(transform) + + self.aug_crop = aug_crop + self.seed = seed + + def __len__(self): + return len(self.scenes) + + def get_stats(self): + return f"{len(self)} pairs" + + def __repr__(self): + resolutions_str = '['+';'.join(f'{w}x{h}' for w, h in self._resolutions)+']' + return f"""{type(self).__name__}({self.get_stats()}, + {self.split=}, + {self.seed=}, + resolutions={resolutions_str}, + {self.transform=})""".replace('self.', '').replace('\n', '').replace(' ', '') + + def _get_views(self, idx, resolution, rng): + raise NotImplementedError() + + def __getitem__(self, idx): + if isinstance(idx, tuple): + # the idx is specifying the aspect-ratio + idx, ar_idx = idx + else: + assert len(self._resolutions) == 1 + ar_idx = 0 + + # set-up the rng + if self.seed: # reseed for each __getitem__ + self._rng = np.random.default_rng(seed=self.seed + idx) + elif not hasattr(self, '_rng'): + seed = torch.initial_seed() # this is different for each dataloader process + self._rng = np.random.default_rng(seed=seed) + + # over-loaded code + resolution = self._resolutions[ar_idx] # DO NOT CHANGE THIS (compatible with BatchedRandomSampler) + views = self._get_views(idx, resolution, self._rng) + assert len(views) == self.num_views + + # check data-types + for v, view in enumerate(views): + assert 'pts3d' not in view, f"pts3d should not be there, they will be computed afterwards based on intrinsics+depthmap for view {view_name(view)}" + view['idx'] = (idx, ar_idx, 
v) + + # encode the image + width, height = view['img'].size + view['true_shape'] = np.int32((height, width)) + view['img'] = self.transform(view['img']) + + assert 'camera_intrinsics' in view + if 'camera_pose' not in view: + view['camera_pose'] = np.full((4, 4), np.nan, dtype=np.float32) + else: + assert np.isfinite(view['camera_pose']).all(), f'NaN in camera pose for view {view_name(view)}' + assert 'pts3d' not in view + assert 'valid_mask' not in view + assert np.isfinite(view['depthmap']).all(), f'NaN in depthmap for view {view_name(view)}' + pts3d, valid_mask = depthmap_to_absolute_camera_coordinates(**view) + + view['pts3d'] = pts3d + view['valid_mask'] = valid_mask & np.isfinite(pts3d).all(axis=-1) + + # check all datatypes + for key, val in view.items(): + res, err_msg = is_good_type(key, val) + assert res, f"{err_msg} with {key}={val} for view {view_name(view)}" + K = view['camera_intrinsics'] + + # last thing done! + for view in views: + # transpose to make sure all views are the same size + transpose_to_landscape(view) + # this allows to check whether the RNG is is the same state each time + view['rng'] = int.from_bytes(self._rng.bytes(4), 'big') + return views + + def _set_resolutions(self, resolutions): + assert resolutions is not None, 'undefined resolution' + + if not isinstance(resolutions, list): + resolutions = [resolutions] + + self._resolutions = [] + for resolution in resolutions: + if isinstance(resolution, int): + width = height = resolution + else: + width, height = resolution + assert isinstance(width, int), f'Bad type for {width=} {type(width)=}, should be int' + assert isinstance(height, int), f'Bad type for {height=} {type(height)=}, should be int' + assert width >= height + self._resolutions.append((width, height)) + + def _crop_resize_if_necessary(self, image, depthmap, intrinsics, resolution, rng=None, info=None): + """ This function: + - first downsizes the image with LANCZOS inteprolation, + which is better than bilinear 
interpolation in + """ + if not isinstance(image, PIL.Image.Image): + image = PIL.Image.fromarray(image) + + # downscale with lanczos interpolation so that image.size == resolution + # cropping centered on the principal point + W, H = image.size + cx, cy = intrinsics[:2, 2].round().astype(int) + min_margin_x = min(cx, W-cx) + min_margin_y = min(cy, H-cy) + assert min_margin_x > W/5, f'Bad principal point in view={info}' + assert min_margin_y > H/5, f'Bad principal point in view={info}' + # the new window will be a rectangle of size (2*min_margin_x, 2*min_margin_y) centered on (cx,cy) + l, t = cx - min_margin_x, cy - min_margin_y + r, b = cx + min_margin_x, cy + min_margin_y + crop_bbox = (l, t, r, b) + image, depthmap, intrinsics = cropping.crop_image_depthmap(image, depthmap, intrinsics, crop_bbox) + + # transpose the resolution if necessary + W, H = image.size # new size + assert resolution[0] >= resolution[1] + if H > 1.1*W: + # image is portrait mode + resolution = resolution[::-1] + elif 0.9 < H/W < 1.1 and resolution[0] != resolution[1]: + # image is square, so we chose (portrait, landscape) randomly + if rng.integers(2): + resolution = resolution[::-1] + + # high-quality Lanczos down-scaling + target_resolution = np.array(resolution) + if self.aug_crop > 1: + target_resolution += rng.integers(0, self.aug_crop) + image, depthmap, intrinsics = cropping.rescale_image_depthmap(image, depthmap, intrinsics, target_resolution) + + # actual cropping (if necessary) with bilinear interpolation + intrinsics2 = cropping.camera_matrix_of_crop(intrinsics, image.size, resolution, offset_factor=0.5) + crop_bbox = cropping.bbox_from_intrinsics_in_out(intrinsics, intrinsics2, resolution) + image, depthmap, intrinsics2 = cropping.crop_image_depthmap(image, depthmap, intrinsics, crop_bbox) + + return image, depthmap, intrinsics2 + + +def is_good_type(key, v): + """ returns (is_good, err_msg) + """ + if isinstance(v, (str, int, tuple)): + return True, None + if v.dtype not in 
(np.float32, torch.float32, bool, np.int32, np.int64, np.uint8): + return False, f"bad {v.dtype=}" + return True, None + + +def view_name(view, batch_index=None): + def sel(x): return x[batch_index] if batch_index not in (None, slice(None)) else x + db = sel(view['dataset']) + label = sel(view['label']) + instance = sel(view['instance']) + return f"{db}/{label}/{instance}" + + +def transpose_to_landscape(view): + height, width = view['true_shape'] + + if width < height: + # rectify portrait to landscape + assert view['img'].shape == (3, height, width) + view['img'] = view['img'].swapaxes(1, 2) + + assert view['valid_mask'].shape == (height, width) + view['valid_mask'] = view['valid_mask'].swapaxes(0, 1) + + assert view['depthmap'].shape == (height, width) + view['depthmap'] = view['depthmap'].swapaxes(0, 1) + + assert view['pts3d'].shape == (height, width, 3) + view['pts3d'] = view['pts3d'].swapaxes(0, 1) + + # transpose x and y pixels + view['camera_intrinsics'] = view['camera_intrinsics'][[1, 0, 2]] diff --git a/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/dust3r/datasets/base/batched_sampler.py b/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/dust3r/datasets/base/batched_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..85f58a65d41bb8101159e032d5b0aac26a7cf1a1 --- /dev/null +++ b/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/dust3r/datasets/base/batched_sampler.py @@ -0,0 +1,74 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 
+# +# -------------------------------------------------------- +# Random sampling under a constraint +# -------------------------------------------------------- +import numpy as np +import torch + + +class BatchedRandomSampler: + """ Random sampling under a constraint: each sample in the batch has the same feature, + which is chosen randomly from a known pool of 'features' for each batch. + + For instance, the 'feature' could be the image aspect-ratio. + + The index returned is a tuple (sample_idx, feat_idx). + This sampler ensures that each series of `batch_size` indices has the same `feat_idx`. + """ + + def __init__(self, dataset, batch_size, pool_size, world_size=1, rank=0, drop_last=True): + self.batch_size = batch_size + self.pool_size = pool_size + + self.len_dataset = N = len(dataset) + self.total_size = round_by(N, batch_size*world_size) if drop_last else N + assert world_size == 1 or drop_last, 'must drop the last batch in distributed mode' + + # distributed sampler + self.world_size = world_size + self.rank = rank + self.epoch = None + + def __len__(self): + return self.total_size // self.world_size + + def set_epoch(self, epoch): + self.epoch = epoch + + def __iter__(self): + # prepare RNG + if self.epoch is None: + assert self.world_size == 1 and self.rank == 0, 'use set_epoch() if distributed mode is used' + seed = int(torch.empty((), dtype=torch.int64).random_().item()) + else: + seed = self.epoch + 777 + rng = np.random.default_rng(seed=seed) + + # random indices (will restart from 0 if not drop_last) + sample_idxs = np.arange(self.total_size) + rng.shuffle(sample_idxs) + + # random feat_idxs (same across each batch) + n_batches = (self.total_size+self.batch_size-1) // self.batch_size + feat_idxs = rng.integers(self.pool_size, size=n_batches) + feat_idxs = np.broadcast_to(feat_idxs[:, None], (n_batches, self.batch_size)) + feat_idxs = feat_idxs.ravel()[:self.total_size] + + # put them together + idxs = np.c_[sample_idxs, feat_idxs] # shape = 
(total_size, 2) + + # Distributed sampler: we select a subset of batches + # make sure the slice for each node is aligned with batch_size + size_per_proc = self.batch_size * ((self.total_size + self.world_size * + self.batch_size-1) // (self.world_size * self.batch_size)) + idxs = idxs[self.rank*size_per_proc: (self.rank+1)*size_per_proc] + + yield from (tuple(idx) for idx in idxs) + + +def round_by(total, multiple, up=False): + if up: + total = total + multiple-1 + return (total//multiple) * multiple diff --git a/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/dust3r/datasets/base/easy_dataset.py b/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/dust3r/datasets/base/easy_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..4939a88f02715a1f80be943ddb6d808e1be84db7 --- /dev/null +++ b/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/dust3r/datasets/base/easy_dataset.py @@ -0,0 +1,157 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# A dataset base class that you can easily resize and combine. +# -------------------------------------------------------- +import numpy as np +from dust3r.datasets.base.batched_sampler import BatchedRandomSampler + + +class EasyDataset: + """ a dataset that you can easily resize and combine. 
+ Examples: + --------- + 2 * dataset ==> duplicate each element 2x + + 10 @ dataset ==> set the size to 10 (random sampling, duplicates if necessary) + + dataset1 + dataset2 ==> concatenate datasets + """ + + def __add__(self, other): + return CatDataset([self, other]) + + def __rmul__(self, factor): + return MulDataset(factor, self) + + def __rmatmul__(self, factor): + return ResizedDataset(factor, self) + + def set_epoch(self, epoch): + pass # nothing to do by default + + def make_sampler(self, batch_size, shuffle=True, world_size=1, rank=0, drop_last=True): + if not (shuffle): + raise NotImplementedError() # cannot deal yet + num_of_aspect_ratios = len(self._resolutions) + return BatchedRandomSampler(self, batch_size, num_of_aspect_ratios, world_size=world_size, rank=rank, drop_last=drop_last) + + +class MulDataset (EasyDataset): + """ Artifically augmenting the size of a dataset. + """ + multiplicator: int + + def __init__(self, multiplicator, dataset): + assert isinstance(multiplicator, int) and multiplicator > 0 + self.multiplicator = multiplicator + self.dataset = dataset + + def __len__(self): + return self.multiplicator * len(self.dataset) + + def __repr__(self): + return f'{self.multiplicator}*{repr(self.dataset)}' + + def __getitem__(self, idx): + if isinstance(idx, tuple): + idx, other = idx + return self.dataset[idx // self.multiplicator, other] + else: + return self.dataset[idx // self.multiplicator] + + @property + def _resolutions(self): + return self.dataset._resolutions + + +class ResizedDataset (EasyDataset): + """ Artifically changing the size of a dataset. 
+ """ + new_size: int + + def __init__(self, new_size, dataset): + assert isinstance(new_size, int) and new_size > 0 + self.new_size = new_size + self.dataset = dataset + + def __len__(self): + return self.new_size + + def __repr__(self): + size_str = str(self.new_size) + for i in range((len(size_str)-1) // 3): + sep = -4*i-3 + size_str = size_str[:sep] + '_' + size_str[sep:] + return f'{size_str} @ {repr(self.dataset)}' + + def set_epoch(self, epoch): + # this random shuffle only depends on the epoch + rng = np.random.default_rng(seed=epoch+777) + + # shuffle all indices + perm = rng.permutation(len(self.dataset)) + + # rotary extension until target size is met + shuffled_idxs = np.concatenate([perm] * (1 + (len(self)-1) // len(self.dataset))) + self._idxs_mapping = shuffled_idxs[:self.new_size] + + assert len(self._idxs_mapping) == self.new_size + + def __getitem__(self, idx): + assert hasattr(self, '_idxs_mapping'), 'You need to call dataset.set_epoch() to use ResizedDataset.__getitem__()' + if isinstance(idx, tuple): + idx, other = idx + return self.dataset[self._idxs_mapping[idx], other] + else: + return self.dataset[self._idxs_mapping[idx]] + + @property + def _resolutions(self): + return self.dataset._resolutions + + +class CatDataset (EasyDataset): + """ Concatenation of several datasets + """ + + def __init__(self, datasets): + for dataset in datasets: + assert isinstance(dataset, EasyDataset) + self.datasets = datasets + self._cum_sizes = np.cumsum([len(dataset) for dataset in datasets]) + + def __len__(self): + return self._cum_sizes[-1] + + def __repr__(self): + # remove uselessly long transform + return ' + '.join(repr(dataset).replace(',transform=Compose( ToTensor() Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)))', '') for dataset in self.datasets) + + def set_epoch(self, epoch): + for dataset in self.datasets: + dataset.set_epoch(epoch) + + def __getitem__(self, idx): + other = None + if isinstance(idx, tuple): + idx, other = idx + + if not (0 
<= idx < len(self)): + raise IndexError() + + db_idx = np.searchsorted(self._cum_sizes, idx, 'right') + dataset = self.datasets[db_idx] + new_idx = idx - (self._cum_sizes[db_idx - 1] if db_idx > 0 else 0) + + if other is not None: + new_idx = (new_idx, other) + return dataset[new_idx] + + @property + def _resolutions(self): + resolutions = self.datasets[0]._resolutions + for dataset in self.datasets[1:]: + assert tuple(dataset._resolutions) == tuple(resolutions) + return resolutions diff --git a/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/dust3r/datasets/co3d.py b/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/dust3r/datasets/co3d.py new file mode 100644 index 0000000000000000000000000000000000000000..9fc94f9420d86372e643c00e7cddf85b3d1982c6 --- /dev/null +++ b/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/dust3r/datasets/co3d.py @@ -0,0 +1,146 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 
#
# --------------------------------------------------------
# Dataloader for preprocessed Co3d_v2
# dataset at https://github.com/facebookresearch/co3d - Creative Commons Attribution-NonCommercial 4.0 International
# See datasets_preprocess/preprocess_co3d.py
# --------------------------------------------------------
import os.path as osp
import json
import itertools
from collections import deque

import cv2
import numpy as np

from dust3r.datasets.base.base_stereo_view_dataset import BaseStereoViewDataset
from dust3r.utils.image import imread_cv2


class Co3d(BaseStereoViewDataset):
    """ Pairs of views sampled from preprocessed Co3d_v2 sequences.

    The sequences are read from `selected_seqs_{split}.json` under ROOT. A dataset
    index selects one (object, instance) sequence and one pair of frame positions in it.
    """

    def __init__(self, mask_bg=True, *args, ROOT, **kwargs):
        """
        mask_bg: True to always zero the depth outside the object mask,
                 False to never do it, 'rand' to decide randomly for each sample.
        ROOT: directory holding the preprocessed sequences.
        Remaining arguments are forwarded to BaseStereoViewDataset.
        """
        self.ROOT = ROOT
        super().__init__(*args, **kwargs)
        assert mask_bg in (True, False, 'rand')
        self.mask_bg = mask_bg

        # load all scenes
        with open(osp.join(self.ROOT, f'selected_seqs_{self.split}.json'), 'r') as f:
            scenes = json.load(f)
        # skip empty objects, then flatten to {(object, instance): list of frame ids}
        self.scenes = {(obj, instance): frames
                       for obj, instances in scenes.items() if len(instances) > 0
                       for instance, frames in instances.items()}
        self.scene_list = list(self.scenes.keys())

        # for each scene, we have 100 images ==> 360 degrees (so 25 frames ~= 90 degrees)
        # we keep all position pairs whose gap is a multiple of 5, up to 30 positions
        # NOTE(review): the previous comment announced gaps "up to 90 degrees", which does not
        # match the `<= 30` bound below (30 positions ~= 108 degrees) - confirm the intent
        self.combinations = [(i, j)
                             for i, j in itertools.combinations(range(100), 2)
                             if 0 < abs(i-j) <= 30 and abs(i-j) % 5 == 0]

        # per scene and per resolution: flags of the frames found to have no valid depth
        self.invalidate = {scene: {} for scene in self.scene_list}

    def __len__(self):
        return len(self.scene_list) * len(self.combinations)

    def _get_views(self, idx, resolution, rng):
        """ Load the two views of sample `idx` at the given resolution.

        Returns a list of dicts (img, depthmap, camera_pose, camera_intrinsics,
        dataset, label, instance). Frames whose depthmap has no positive value after
        cropping are flagged in `self.invalidate` and replaced by a neighbouring frame.
        """
        # choose a scene
        obj, instance = self.scene_list[idx // len(self.combinations)]
        image_pool = self.scenes[obj, instance]
        im1_idx, im2_idx = self.combinations[idx % len(self.combinations)]

        # add a bit of randomness
        last = len(image_pool)-1

        if resolution not in self.invalidate[obj, instance]:  # flag invalid images
            self.invalidate[obj, instance][resolution] = [False for _ in range(len(image_pool))]
        invalid = self.invalidate[obj, instance][resolution]

        # decide now if we mask the bg
        mask_bg = (self.mask_bg == True) or (self.mask_bg == 'rand' and rng.choice(2))

        seq_dir = osp.join(self.ROOT, obj, instance)

        views = []
        imgs_idxs = [max(0, min(im_idx + rng.integers(-4, 5), last)) for im_idx in [im2_idx, im1_idx]]
        imgs_idxs = deque(imgs_idxs)
        while len(imgs_idxs) > 0:  # some images (few) have zero depth
            im_idx = imgs_idxs.pop()

            if invalid[im_idx]:
                # search for a valid image
                # NOTE(review): if every frame of the sequence is flagged, no replacement is
                # found and the flagged frame itself is loaded again
                random_direction = 2 * rng.choice(2) - 1
                for offset in range(1, len(image_pool)):
                    tentative_im_idx = (im_idx + (random_direction * offset)) % len(image_pool)
                    if not invalid[tentative_im_idx]:
                        im_idx = tentative_im_idx
                        break

            view_idx = image_pool[im_idx]
            frame_name = f'frame{view_idx:06n}'

            # Build every path explicitly. The previous str.replace('jpg', ...) and
            # str.replace('images', ...) rewrote *all* occurrences in the path, so a ROOT
            # (or object name) containing 'jpg' or 'images' produced wrong paths.
            impath = osp.join(seq_dir, 'images', frame_name + '.jpg')
            metadata_path = osp.join(seq_dir, 'images', frame_name + '.npz')
            depthpath = osp.join(seq_dir, 'depths', frame_name + '.jpg.geometric.png')

            # load camera params (the context manager closes the .npz file handle)
            with np.load(metadata_path) as input_metadata:
                camera_pose = input_metadata['camera_pose'].astype(np.float32)
                intrinsics = input_metadata['camera_intrinsics'].astype(np.float32)
                maximum_depth = np.nan_to_num(input_metadata['maximum_depth'])

            # load image and depth
            rgb_image = imread_cv2(impath)
            depthmap = imread_cv2(depthpath, cv2.IMREAD_UNCHANGED)
            # the png stores the depth as a fraction of maximum_depth on a 0..65535 scale
            depthmap = (depthmap.astype(np.float32) / 65535) * maximum_depth

            if mask_bg:
                # load object mask
                maskpath = osp.join(seq_dir, 'masks', frame_name + '.png')
                maskmap = imread_cv2(maskpath, cv2.IMREAD_UNCHANGED).astype(np.float32)
                maskmap = (maskmap / 255.0) > 0.1

                # update the depthmap with mask
                depthmap *= maskmap

            rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary(
                rgb_image, depthmap, intrinsics, resolution, rng=rng, info=impath)

            num_valid = (depthmap > 0.0).sum()
            if num_valid == 0:
                # problem, invalidate image and retry
                invalid[im_idx] = True
                imgs_idxs.append(im_idx)
                continue

            views.append(dict(
                img=rgb_image,
                depthmap=depthmap,
                camera_pose=camera_pose,
                camera_intrinsics=intrinsics,
                dataset='Co3d_v2',
                label=osp.join(obj, instance),
                instance=osp.split(impath)[1],
            ))
        return views


if __name__ == "__main__":
    from dust3r.datasets.base.base_stereo_view_dataset import view_name
    from dust3r.viz import SceneViz, auto_cam_size
    from dust3r.utils.image import rgb

    dataset = Co3d(split='train', ROOT="data/co3d_subset_processed", resolution=224, aug_crop=16)

    for idx in np.random.permutation(len(dataset)):
        views = dataset[idx]
        assert len(views) == 2
        print(view_name(views[0]), view_name(views[1]))
        viz = SceneViz()
        poses = [views[view_idx]['camera_pose'] for view_idx in [0, 1]]
        cam_size = max(auto_cam_size(poses), 0.001)
        for view_idx in [0, 1]:
            pts3d = views[view_idx]['pts3d']
            valid_mask = views[view_idx]['valid_mask']
            colors = rgb(views[view_idx]['img'])
            viz.add_pointcloud(pts3d, colors, valid_mask)
            # one colour per view: the previous code used the dataset index `idx` here,
            # which gives out-of-range colour values for any idx other than 0 or 1
            viz.add_camera(pose_c2w=views[view_idx]['camera_pose'],
                           focal=views[view_idx]['camera_intrinsics'][0, 0],
                           color=(view_idx*255, (1 - view_idx)*255, 0),
                           image=colors,
                           cam_size=cam_size)
        viz.show()

# diff --git a/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/dust3r/datasets/utils/__init__.py b/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/dust3r/datasets/utils/__init__.py
# new file mode 100644
# index 0000000000000000000000000000000000000000..a32692113d830ddc4af4e6ed608f222fbe062e6e
# --- /dev/null
# +++ b/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/dust3r/datasets/utils/__init__.py
# @@ -0,0 +1,2 @@
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
# diff --git a/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/dust3r/datasets/utils/cropping.py b/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/dust3r/datasets/utils/cropping.py
# new file mode 100644
# index 0000000000000000000000000000000000000000..02b1915676f3deea24f57032f7588ff34cbfaeb9
# --- /dev/null
# +++ b/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/dust3r/datasets/utils/cropping.py
# @@ -0,0 +1,119 @@
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# cropping utilities
# --------------------------------------------------------
import PIL.Image
import os
os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
import cv2  # noqa
import numpy as np  # noqa
from dust3r.utils.geometry import colmap_to_opencv_intrinsics, opencv_to_colmap_intrinsics  # noqa
try:
    lanczos = PIL.Image.Resampling.LANCZOS
except AttributeError:
    # PIL versions without the Resampling namespace
    lanczos = PIL.Image.LANCZOS


class ImageList:
    """ Convenience class to apply the same operation to a whole set of images.
    """

    def __init__(self, images):
        # a single image is accepted and wrapped; anything that is not a PIL image
        # is converted with PIL.Image.fromarray
        if not isinstance(images, (tuple, list, set)):
            images = [images]
        self.images = []
        for image in images:
            if not isinstance(image, PIL.Image.Image):
                image = PIL.Image.fromarray(image)
            self.images.append(image)

    def __len__(self):
        return len(self.images)

    def to_pil(self):
        """ Return the single image, or a tuple of images when there are several. """
        return tuple(self.images) if len(self.images) > 1 else self.images[0]

    @property
    def size(self):
        """ Common (width, height) of the images; they must all share the same size. """
        sizes = [im.size for im in self.images]
        assert all(sizes[0] == s for s in sizes)
        return sizes[0]

    def resize(self, *args, **kwargs):
        return ImageList(self._dispatch('resize', *args, **kwargs))

    def crop(self, *args, **kwargs):
        return ImageList(self._dispatch('crop', *args, **kwargs))

    def _dispatch(self, func, *args, **kwargs):
        # call the PIL method named `func` on every image
        return [getattr(im, func)(*args, **kwargs) for im in self.images]


def rescale_image_depthmap(image, depthmap, camera_intrinsics, output_resolution):
    """ Jointly rescale a (image, depthmap)
        so that (out_width, out_height) >= output_res

    depthmap may be None (image only). Returns (image, depthmap, camera_intrinsics)
    with the intrinsics updated for the rescaling.
    """
    image = ImageList(image)
    input_resolution = np.array(image.size)  # (W,H)
    output_resolution = np.array(output_resolution)
    if depthmap is not None:
        # can also use this with masks instead of depthmaps
        assert tuple(depthmap.shape[:2]) == image.size[::-1]
    assert output_resolution.shape == (2,)
    # define output resolution
    scale_final = max(output_resolution / image.size) + 1e-8
    output_resolution = np.floor(input_resolution * scale_final).astype(int)
    # PIL and OpenCV expect the target size as a plain (W, H) tuple of ints,
    # not as a numpy array
    out_size = tuple(int(v) for v in output_resolution)

    # first rescale the image so that it contains the crop
    image = image.resize(out_size, resample=lanczos)
    if depthmap is not None:
        depthmap = cv2.resize(depthmap, out_size, fx=scale_final,
                              fy=scale_final, interpolation=cv2.INTER_NEAREST)

    # no offset here; simple rescaling
    camera_intrinsics = camera_matrix_of_crop(
        camera_intrinsics, input_resolution, output_resolution, scaling=scale_final)

    return image.to_pil(), depthmap, camera_intrinsics


def camera_matrix_of_crop(input_camera_matrix, input_resolution, output_resolution, scaling=1, offset_factor=0.5, offset=None):
    """ Intrinsics of a view after scaling by `scaling` and cropping to output_resolution.

    The crop origin is `offset` when given, otherwise `offset_factor` times the margins
    left between the scaled input and the output.
    """
    # Margins to offset the origin
    margins = np.asarray(input_resolution) * scaling - output_resolution
    assert np.all(margins >= 0.0)
    if offset is None:
        offset = offset_factor * margins

    # Generate new camera parameters
    output_camera_matrix_colmap = opencv_to_colmap_intrinsics(input_camera_matrix)
    output_camera_matrix_colmap[:2, :] *= scaling
    output_camera_matrix_colmap[:2, 2] -= offset
    output_camera_matrix = colmap_to_opencv_intrinsics(output_camera_matrix_colmap)

    return output_camera_matrix


def crop_image_depthmap(image, depthmap, camera_intrinsics, crop_bbox):
    """
    Return a crop of the input view.

    crop_bbox is (left, top, right, bottom). depthmap may be None, as in
    rescale_image_depthmap; it is then returned unchanged.
    """
    image = ImageList(image)
    left, top, right, bottom = crop_bbox

    image = image.crop((left, top, right, bottom))
    if depthmap is not None:
        depthmap = depthmap[top:bottom, left:right]

    # shift the principal point by the crop origin, on a copy of the caller's matrix
    camera_intrinsics = camera_intrinsics.copy()
    camera_intrinsics[0, 2] -= left
    camera_intrinsics[1, 2] -= top

    return image.to_pil(), depthmap, camera_intrinsics


def bbox_from_intrinsics_in_out(input_camera_matrix, output_camera_matrix, output_resolution):
    """ Crop box (left, top, right, bottom) given by the principal-point shift between the two matrices. """
    out_width, out_height = output_resolution
    left, top = np.int32(np.round(input_camera_matrix[:2, 2] - output_camera_matrix[:2, 2]))
    crop_bbox = (left, top, left+out_width, top+out_height)
    return crop_bbox

# diff --git a/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/dust3r/datasets/utils/transforms.py b/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/dust3r/datasets/utils/transforms.py
# new file mode 100644
# index 0000000000000000000000000000000000000000..eb34f2f01d3f8f829ba71a7e03e181bf18f72c25
# --- /dev/null
# +++ b/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/dust3r/datasets/utils/transforms.py
# @@ -0,0 +1,11 @@
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# DUST3R default transforms
# --------------------------------------------------------
import torchvision.transforms as tvf
from dust3r.utils.image import ImgNorm

# define the standard image transforms
ColorJitter = tvf.Compose([tvf.ColorJitter(0.5, 0.5, 0.5, 0.1), ImgNorm])

# diff --git a/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/dust3r/heads/__init__.py b/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/dust3r/heads/__init__.py
# new file mode 100644
# index 0000000000000000000000000000000000000000..53d0aa5610cae95f34f96bdb3ff9e835a2d6208e
# --- /dev/null
# +++ b/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/dust3r/heads/__init__.py
# @@ -0,0 +1,19 @@
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# head factory
# --------------------------------------------------------
from .linear_head import LinearPts3d
from .dpt_head import create_dpt_head


def head_factory(head_type, output_mode, net, has_conf=False):
    """ build a prediction head for the decoder

    Supported: head_type 'linear' or 'dpt', both with output_mode 'pts3d'.
    Any other combination raises NotImplementedError.
    """
    if head_type == 'linear' and output_mode == 'pts3d':
        return LinearPts3d(net, has_conf)
    elif head_type == 'dpt' and output_mode == 'pts3d':
        return create_dpt_head(net, has_conf=has_conf)
    else:
        raise NotImplementedError(f"unexpected {head_type=} and {output_mode=}")

# diff --git a/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/dust3r/heads/dpt_head.py b/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/dust3r/heads/dpt_head.py
# new file mode 100644
# index 0000000000000000000000000000000000000000..3470ac507a776e4af32f39c317c77e9351b96c4b
# --- /dev/null
# +++ b/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/dust3r/heads/dpt_head.py
# @@ -0,0 +1,114 @@
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# dpt head implementation for DUST3R
# Downstream heads assume inputs of size B x N x C (where N is the number of tokens) ;
# or if it takes as input the output at every layer, the attribute return_all_layers should be set to True
# the forward function also takes as input a dictionary img_info with key "height" and "width"
# for PixelwiseTask, the output will be of dimension B x num_channels x H x W
# --------------------------------------------------------
from einops import rearrange
from typing import List
import torch
import torch.nn as nn
from .postprocess import postprocess
from ...croco.dpt_block import DPTOutputAdapter  # noqa


class DPTOutputAdapter_fix(DPTOutputAdapter):
    """
    Adapt croco's DPTOutputAdapter implementation for dust3r:
    remove duplicated weights, and fix forward for dust3r
    """

    def init(self, dim_tokens_enc=768):
        super().init(dim_tokens_enc)
        # these are duplicated weights
        del self.act_1_postprocess
        del self.act_2_postprocess
        del self.act_3_postprocess
        del self.act_4_postprocess

    def forward(self, encoder_tokens: List[torch.Tensor], image_size=None):
        """ Run the DPT head on the hooked layers of `encoder_tokens`.

        image_size is (H, W); self.image_size is used when it is None.
        """
        assert self.dim_tokens_enc is not None, 'Need to call init(dim_tokens_enc) function first'
        # H, W = input_info['image_size']
        image_size = self.image_size if image_size is None else image_size
        H, W = image_size
        # Number of patches in height and width
        N_H = H // (self.stride_level * self.P_H)
        N_W = W // (self.stride_level * self.P_W)

        # Hook decoder onto 4 layers from specified ViT layers
        layers = [encoder_tokens[hook] for hook in self.hooks]

        # Extract only task-relevant tokens and ignore global tokens.
        layers = [self.adapt_tokens(l) for l in layers]

        # Reshape tokens to spatial representation
        layers = [rearrange(l, 'b (nh nw) c -> b c nh nw', nh=N_H, nw=N_W) for l in layers]

        layers = [self.act_postprocess[idx](l) for idx, l in enumerate(layers)]
        # Project layers to chosen feature dim
        layers = [self.scratch.layer_rn[idx](l) for idx, l in enumerate(layers)]

        # Fuse layers using refinement stages
        # (the coarsest output is cut to the spatial size of layers[2] before fusing)
        path_4 = self.scratch.refinenet4(layers[3])[:, :, :layers[2].shape[2], :layers[2].shape[3]]
        path_3 = self.scratch.refinenet3(path_4, layers[2])
        path_2 = self.scratch.refinenet2(path_3, layers[1])
        path_1 = self.scratch.refinenet1(path_2, layers[0])

        # Output head
        out = self.head(path_1)

        return out


class PixelwiseTaskWithDPT(nn.Module):
    """ DPT module for dust3r, can return 3D points + confidence for all pixels"""

    def __init__(self, *, n_cls_token=0, hooks_idx=None, dim_tokens=None,
                 output_width_ratio=1, num_channels=1, postprocess=None, depth_mode=None, conf_mode=None, **kwargs):
        super(PixelwiseTaskWithDPT, self).__init__()
        self.return_all_layers = True  # backbone needs to return all layers
        self.postprocess = postprocess
        self.depth_mode = depth_mode
        self.conf_mode = conf_mode

        assert n_cls_token == 0, "Not implemented"
        # extra keyword arguments are passed to the DPT adapter as they are
        dpt_args = dict(output_width_ratio=output_width_ratio,
                        num_channels=num_channels,
                        **kwargs)
        if hooks_idx is not None:
            dpt_args.update(hooks=hooks_idx)
        self.dpt = DPTOutputAdapter_fix(**dpt_args)
        dpt_init_args = {} if dim_tokens is None else {'dim_tokens_enc': dim_tokens}
        self.dpt.init(**dpt_init_args)

    def forward(self, x, img_info):
        # img_info is indexed as (height, width)
        out = self.dpt(x, image_size=(img_info[0], img_info[1]))
        if self.postprocess:
            out = self.postprocess(out, self.depth_mode, self.conf_mode)
        return out


def create_dpt_head(net, has_conf=False):
    """
    return PixelwiseTaskWithDPT for given net params
    """
    assert net.dec_depth > 9
    l2 = net.dec_depth
    feature_dim = 256
    last_dim = feature_dim//2
    out_nchan = 3
    ed = net.enc_embed_dim
    dd = net.dec_embed_dim
    # hooks: layer 0, then the layers at 2/4, 3/4 and 4/4 of the decoder depth
    return PixelwiseTaskWithDPT(num_channels=out_nchan + has_conf,
                                feature_dim=feature_dim,
                                last_dim=last_dim,
                                hooks_idx=[0, l2*2//4, l2*3//4, l2],
                                dim_tokens=[ed, dd, dd, dd],
                                postprocess=postprocess,
                                depth_mode=net.depth_mode,
                                conf_mode=net.conf_mode,
                                head_type='regression')

# diff --git a/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/dust3r/heads/linear_head.py b/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/dust3r/heads/linear_head.py
# new file mode 100644
# index 0000000000000000000000000000000000000000..27c5678d551033cc576798626b7ba59b1e7b20cc
# --- /dev/null
# +++ b/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/dust3r/heads/linear_head.py
# @@ -0,0 +1,41 @@
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# linear head implementation for DUST3R
# --------------------------------------------------------
import torch.nn as nn
import torch.nn.functional as F
from .postprocess import postprocess


class LinearPts3d (nn.Module):
    """
    Linear head for dust3r
    Each token outputs: - 16x16 3D points (+ confidence)
    """

    def __init__(self, net, has_conf=False):
        super().__init__()
        self.patch_size = net.patch_embed.patch_size[0]
        self.depth_mode = net.depth_mode
        self.conf_mode = net.conf_mode
        self.has_conf = has_conf

        # one linear layer: 3 (+1 if has_conf) values for each pixel of a patch
        self.proj = nn.Linear(net.dec_embed_dim, (3 + has_conf)*self.patch_size**2)

    def setup(self, croconet):
        pass

    def forward(self, decout, img_shape):
        H, W = img_shape
        tokens = decout[-1]  # only the last decoder output is used
        B, S, D = tokens.shape

        # extract 3D points
        feat = self.proj(tokens)  # B,S,D
        feat = feat.transpose(-1, -2).view(B, -1, H//self.patch_size, W//self.patch_size)
        feat = F.pixel_shuffle(feat, self.patch_size)  # B,3,H,W

        # permute + norm depth
        return postprocess(feat, self.depth_mode, self.conf_mode)

# diff --git a/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/dust3r/heads/postprocess.py b/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/dust3r/heads/postprocess.py
# new file mode 100644
# index 0000000000000000000000000000000000000000..cd68a90d89b8dcd7d8a4b4ea06ef8b17eb5da093
# --- /dev/null
# +++ b/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/dust3r/heads/postprocess.py
# @@ -0,0 +1,58 @@
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# post process function for all heads: extract 3D points/confidence from output
# --------------------------------------------------------
import torch


def postprocess(out, depth_mode, conf_mode):
    """
    extract 3D points/confidence from prediction head output

    Returns a dict with 'pts3d' (channels 0..2) and, when conf_mode is not None,
    'conf' (channel 3).
    """
    fmap = out.permute(0, 2, 3, 1)  # B,H,W,3
    res = dict(pts3d=reg_dense_depth(fmap[:, :, :, 0:3], mode=depth_mode))

    if conf_mode is not None:
        res['conf'] = reg_dense_conf(fmap[:, :, :, 3], mode=conf_mode)
    return res


def reg_dense_depth(xyz, mode):
    """
    extract 3D points from prediction head output

    mode is (name, vmin, vmax) with name in 'linear', 'square', 'exp'.
    Only unbounded ranges (vmin=-inf, vmax=+inf) pass the assertion below.
    """
    mode, vmin, vmax = mode

    no_bounds = (vmin == -float('inf')) and (vmax == float('inf'))
    assert no_bounds

    if mode == 'linear':
        if no_bounds:
            return xyz  # [-inf, +inf]
        # only reachable when assertions are disabled (python -O)
        return xyz.clip(min=vmin, max=vmax)

    # distance to origin
    d = xyz.norm(dim=-1, keepdim=True)
    xyz = xyz / d.clip(min=1e-8)

    if mode == 'square':
        return xyz * d.square()

    if mode == 'exp':
        return xyz * torch.expm1(d)

    raise ValueError(f'bad {mode=}')


def reg_dense_conf(x, mode):
    """
    extract confidence from prediction head output

    mode is (name, vmin, vmax) with name in 'exp', 'sigmoid'.
    """
    mode, vmin, vmax = mode
    if mode == 'exp':
        return vmin + x.exp().clip(max=vmax-vmin)
    if mode == 'sigmoid':
        return (vmax - vmin) * torch.sigmoid(x) + vmin
    raise ValueError(f'bad {mode=}')

# diff --git a/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/dust3r/image_pairs.py b/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/dust3r/image_pairs.py
# new file mode 100644
# index 0000000000000000000000000000000000000000..9251dc822b6b4b11bb9149dfd256ee1e66947562
# --- /dev/null
# +++ b/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/dust3r/image_pairs.py
# @@ -0,0 +1,83 @@
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# utilities needed to load image pairs
# --------------------------------------------------------
import numpy as np
import torch


def make_pairs(imgs, scene_graph='complete', prefilter=None, symmetrize=True):
    """ Build the list of image pairs described by `scene_graph`.

    scene_graph: 'complete', 'swin' or 'swin-<window>' (default 3),
                 'oneref' or 'oneref-<ref index>' (default 0), or 'pairs'
                 (imgs taken two by two). Anything else raises ValueError.
    prefilter:   None, 'seq<N>' or 'cyc<N>' to drop pairs whose 'idx' are more
                 than N apart (cyclic distance for 'cyc').
    symmetrize:  also add every pair in the reverse order.
    """
    pairs = []

    if scene_graph == 'complete':  # complete graph
        for i in range(len(imgs)):
            for j in range(i):
                pairs.append((imgs[i], imgs[j]))

    elif scene_graph.startswith('swin'):
        winsize = int(scene_graph.split('-')[1]) if '-' in scene_graph else 3
        # NOTE(review): j starts at 0, so every image is also paired with itself -
        # confirm whether this is wanted
        for i in range(len(imgs)):
            for j in range(winsize):
                idx = (i + j) % len(imgs)  # explicit loop closure
                pairs.append((imgs[i], imgs[idx]))

    elif scene_graph.startswith('oneref'):
        refid = int(scene_graph.split('-')[1]) if '-' in scene_graph else 0
        for j in range(len(imgs)):
            if j != refid:
                pairs.append((imgs[refid], imgs[j]))

    elif scene_graph == 'pairs':
        assert len(imgs) % 2 == 0
        for i in range(0, len(imgs), 2):
            pairs.append((imgs[i], imgs[i+1]))

    else:
        # an unknown graph used to produce an empty list without any warning
        raise ValueError(f'bad {scene_graph=}')

    if symmetrize:
        pairs += [(img2, img1) for img1, img2 in pairs]

    # now, remove edges
    if isinstance(prefilter, str) and prefilter.startswith('seq'):
        pairs = filter_pairs_seq(pairs, int(prefilter[3:]))

    if isinstance(prefilter, str) and prefilter.startswith('cyc'):
        pairs = filter_pairs_seq(pairs, int(prefilter[3:]), cyclic=True)

    return pairs


def sel(x, kept):
    """ Keep the entries of x listed in `kept`, recursing into dict values.

    Tensors and arrays are indexed with `kept`; tuples and lists are rebuilt with the
    same type. Any other type yields None.
    """
    if isinstance(x, dict):
        return {k: sel(v, kept) for k, v in x.items()}
    if isinstance(x, (torch.Tensor, np.ndarray)):
        return x[kept]
    if isinstance(x, (tuple, list)):
        return type(x)([x[k] for k in kept])


def _filter_edges_seq(edges, seq_dis_thr, cyclic=False):
    """ Positions of the edges (i, j) with |i-j| <= seq_dis_thr (cyclic distance if cyclic). """
    if not edges:
        # nothing to filter; max() below would fail on an empty list
        return []

    # number of images
    n = max(max(e) for e in edges)+1

    kept = []
    for e, (i, j) in enumerate(edges):
        dis = abs(i-j)
        if cyclic:
            dis = min(dis, abs(i+n-j), abs(i-n-j))
        if dis <= seq_dis_thr:
            kept.append(e)
    return kept


def filter_pairs_seq(pairs, seq_dis_thr, cyclic=False):
    edges = [(img1['idx'], img2['idx']) for img1, img2 in pairs]
    kept = _filter_edges_seq(edges, seq_dis_thr, cyclic=cyclic)
    return [pairs[i] for i in kept]


def filter_edges_seq(view1, view2, pred1, pred2, seq_dis_thr, cyclic=False):
    edges = [(int(i), int(j)) for i, j in zip(view1['idx'], view2['idx'])]
    kept = _filter_edges_seq(edges, seq_dis_thr, cyclic=cyclic)
    print(f'>> Filtering edges more than {seq_dis_thr} frames apart: kept {len(kept)}/{len(edges)} edges')
    return sel(view1, kept), sel(view2, kept), sel(pred1, kept), sel(pred2, kept)

# diff --git a/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/dust3r/inference.py b/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/dust3r/inference.py
# new file mode 100644
# index 0000000000000000000000000000000000000000..708bd46e7d67448bcc05cb7a6d717e3dbffe81a3
# --- /dev/null
# +++ b/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/dust3r/inference.py
# @@ -0,0 +1,165 @@
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+# +# -------------------------------------------------------- +# utilities needed for the inference +# -------------------------------------------------------- +import tqdm +import torch +from .utils.device import to_cpu, collate_with_cat +from .model import AsymmetricCroCo3DStereo, inf # noqa: F401, needed when loading the model +from .utils.misc import invalid_to_nans +from .utils.geometry import depthmap_to_pts3d, geotrf + + +def load_model(model_path, device): + print('... loading model from', model_path) + ckpt = torch.load(model_path, map_location='cpu') + args = ckpt['args'].model.replace("ManyAR_PatchEmbed", "PatchEmbedDust3R") + if 'landscape_only' not in args: + args = args[:-1] + ', landscape_only=False)' + else: + args = args.replace(" ", "").replace('landscape_only=True', 'landscape_only=False') + assert "landscape_only=False" in args + print(f"instantiating : {args}") + net = eval(args) + print(net.load_state_dict(ckpt['model'], strict=False)) + return net.to(device) + + +def _interleave_imgs(img1, img2): + res = {} + for key, value1 in img1.items(): + value2 = img2[key] + if isinstance(value1, torch.Tensor): + value = torch.stack((value1, value2), dim=1).flatten(0, 1) + else: + value = [x for pair in zip(value1, value2) for x in pair] + res[key] = value + return res + + +def make_batch_symmetric(batch): + view1, view2 = batch + view1, view2 = (_interleave_imgs(view1, view2), _interleave_imgs(view2, view1)) + return view1, view2 + + +def loss_of_one_batch(batch, model, criterion, device, symmetrize_batch=False, use_amp=False, ret=None): + view1, view2 = batch + for view in batch: + for name in 'img pts3d valid_mask camera_pose camera_intrinsics F_matrix corres'.split(): # pseudo_focal + if name not in view: + continue + view[name] = view[name].to(device, non_blocking=True) + + if symmetrize_batch: + view1, view2 = make_batch_symmetric(batch) + + with torch.cuda.amp.autocast(enabled=bool(use_amp)): + pred1, pred2 = model(view1, view2) + + # loss is 
supposed to be symmetric + with torch.cuda.amp.autocast(enabled=False): + loss = criterion(view1, view2, pred1, pred2) if criterion is not None else None + + result = dict(view1=view1, view2=view2, pred1=pred1, pred2=pred2, loss=loss) + return result[ret] if ret else result + + +@torch.no_grad() +def inference(pairs, model, device, batch_size=8): + print(f'>> Inference with model on {len(pairs)} image pairs') + result = [] + + # first, check if all images have the same size + multiple_shapes = not (check_if_same_size(pairs)) + if multiple_shapes: # force bs=1 + batch_size = 1 + + for i in tqdm.trange(0, len(pairs), batch_size): + res = loss_of_one_batch(collate_with_cat(pairs[i:i+batch_size]), model, None, device) + result.append(to_cpu(res)) + + result = collate_with_cat(result, lists=multiple_shapes) + + torch.cuda.empty_cache() + return result + + +def check_if_same_size(pairs): + shapes1 = [img1['img'].shape[-2:] for img1, img2 in pairs] + shapes2 = [img2['img'].shape[-2:] for img1, img2 in pairs] + return all(shapes1[0] == s for s in shapes1) and all(shapes2[0] == s for s in shapes2) + + +def get_pred_pts3d(gt, pred, use_pose=False): + if 'depth' in pred and 'pseudo_focal' in pred: + try: + pp = gt['camera_intrinsics'][..., :2, 2] + except KeyError: + pp = None + pts3d = depthmap_to_pts3d(**pred, pp=pp) + + elif 'pts3d' in pred: + # pts3d from my camera + pts3d = pred['pts3d'] + + elif 'pts3d_in_other_view' in pred: + # pts3d from the other camera, already transformed + assert use_pose is True + return pred['pts3d_in_other_view'] # return! 
+ + if use_pose: + camera_pose = pred.get('camera_pose') + assert camera_pose is not None + pts3d = geotrf(camera_pose, pts3d) + + return pts3d + + +def find_opt_scaling(gt_pts1, gt_pts2, pr_pts1, pr_pts2=None, fit_mode='weiszfeld_stop_grad', valid1=None, valid2=None): + assert gt_pts1.ndim == pr_pts1.ndim == 4 + assert gt_pts1.shape == pr_pts1.shape + if gt_pts2 is not None: + assert gt_pts2.ndim == pr_pts2.ndim == 4 + assert gt_pts2.shape == pr_pts2.shape + + # concat the pointcloud + nan_gt_pts1 = invalid_to_nans(gt_pts1, valid1).flatten(1, 2) + nan_gt_pts2 = invalid_to_nans(gt_pts2, valid2).flatten(1, 2) if gt_pts2 is not None else None + + pr_pts1 = invalid_to_nans(pr_pts1, valid1).flatten(1, 2) + pr_pts2 = invalid_to_nans(pr_pts2, valid2).flatten(1, 2) if pr_pts2 is not None else None + + all_gt = torch.cat((nan_gt_pts1, nan_gt_pts2), dim=1) if gt_pts2 is not None else nan_gt_pts1 + all_pr = torch.cat((pr_pts1, pr_pts2), dim=1) if pr_pts2 is not None else pr_pts1 + + dot_gt_pr = (all_pr * all_gt).sum(dim=-1) + dot_gt_gt = all_gt.square().sum(dim=-1) + + if fit_mode.startswith('avg'): + # scaling = (all_pr / all_gt).view(B, -1).mean(dim=1) + scaling = dot_gt_pr.nanmean(dim=1) / dot_gt_gt.nanmean(dim=1) + elif fit_mode.startswith('median'): + scaling = (dot_gt_pr / dot_gt_gt).nanmedian(dim=1).values + elif fit_mode.startswith('weiszfeld'): + # init scaling with l2 closed form + scaling = dot_gt_pr.nanmean(dim=1) / dot_gt_gt.nanmean(dim=1) + # iterative re-weighted least-squares + for iter in range(10): + # re-weighting by inverse of distance + dis = (all_pr - scaling.view(-1, 1, 1) * all_gt).norm(dim=-1) + # print(dis.nanmean(-1)) + w = dis.clip_(min=1e-8).reciprocal() + # update the scaling with the new weights + scaling = (w * dot_gt_pr).nanmean(dim=1) / (w * dot_gt_gt).nanmean(dim=1) + else: + raise ValueError(f'bad {fit_mode=}') + + if fit_mode.endswith('stop_grad'): + scaling = scaling.detach() + + scaling = scaling.clip(min=1e-3) + # assert 
def Sum(*losses_and_masks):
    """Combine ``(loss, mask)`` pairs.

    If the first loss is not a scalar, the losses are per-pixel and the pairs
    are returned untouched; otherwise the scalar losses are added together.
    """
    head_loss, _ = losses_and_masks[0]
    if head_loss.ndim > 0:
        # per-pixel losses: hand the pairs back unchanged
        return losses_and_masks

    # global losses: accumulate them
    total = head_loss
    for extra_loss, _ in losses_and_masks[1:]:
        total = total + extra_loss
    return total


class LLoss (nn.Module):
    """ Base class for L-norm losses; subclasses implement ``distance``.
    """

    def __init__(self, reduction='mean'):
        super().__init__()
        self.reduction = reduction

    def forward(self, a, b):
        """Reduce ``self.distance(a, b)`` according to ``self.reduction``."""
        shapes_ok = a.shape == b.shape and a.ndim >= 2 and 1 <= a.shape[-1] <= 3
        assert shapes_ok, f'Bad shape = {a.shape}'
        per_point = self.distance(a, b)
        assert per_point.ndim == a.ndim-1  # the last dimension is consumed
        if self.reduction == 'none':
            return per_point
        elif self.reduction == 'sum':
            return per_point.sum()
        elif self.reduction == 'mean':
            if per_point.numel() > 0:
                return per_point.mean()
            # mean of nothing would be nan: return a scalar zero instead
            return per_point.new_zeros(())
        raise ValueError(f'bad {self.reduction=} mode')

    def distance(self, a, b):
        """Distance over the last dimension; to be provided by subclasses."""
        raise NotImplementedError()
class Criterion (nn.Module):
    """Wrap an ``LLoss`` so that its reduction mode can be switched later."""

    def __init__(self, criterion=None):
        super().__init__()
        # the message used to end with `+bb()`, an undefined name that turned
        # the intended AssertionError into a NameError
        assert isinstance(criterion, LLoss), f'{criterion} is not a proper criterion!'
        self.criterion = copy(criterion)

    def get_name(self):
        return f'{type(self).__name__}({self.criterion})'

    def with_reduction(self, mode='none'):
        """Return a deep copy whose whole loss chain uses reduction ``mode``.

        ``mode`` used to be ignored (the reduction was always set to 'none');
        it is now honoured and defaults to 'none'.
        """
        res = loss = deepcopy(self)
        while loss is not None:
            assert isinstance(loss, Criterion)
            loss.criterion.reduction = mode
            loss = loss._loss2  # we assume loss is a Multiloss
        return res


class MultiLoss (nn.Module):
    """ Easily combinable losses (also keep track of individual loss values):
        loss = MyLoss1() + 0.1*MyLoss2()
    Usage:
        Inherit from this class and override get_name() and compute_loss()
    """

    def __init__(self):
        super().__init__()
        self._alpha = 1
        self._loss2 = None

    def compute_loss(self, *args, **kwargs):
        raise NotImplementedError()

    def get_name(self):
        raise NotImplementedError()

    def _clone(self):
        """Shallow copy that owns its sub-module registry.

        ``copy()`` of an ``nn.Module`` shares the ``_modules`` dict with the
        source, so attaching ``_loss2`` to the copy would also register it on
        the original.
        """
        res = copy(self)
        res.__dict__['_modules'] = res._modules.copy()
        return res

    def __mul__(self, alpha):
        """Return a copy of this loss weighted by ``alpha``."""
        assert isinstance(alpha, (int, float))
        res = self._clone()
        res._alpha = alpha
        return res
    __rmul__ = __mul__  # same

    def __add__(self, loss2):
        """Return a new chain ``self + loss2`` without modifying ``self``."""
        assert isinstance(loss2, MultiLoss)
        res = cur = self._clone()
        # walk to the end of the chain, cloning every link on the way so that
        # the losses already chained to `self` are left untouched
        while cur._loss2 is not None:
            nxt = cur._loss2._clone()
            cur._loss2 = nxt
            cur = nxt
        cur._loss2 = loss2
        return res

    def __repr__(self):
        name = self.get_name()
        if self._alpha != 1:
            name = f'{self._alpha:g}*{name}'
        if self._loss2:
            name = f'{name} + {self._loss2}'
        return name

    def forward(self, *args, **kwargs):
        """Return ``(weighted loss of the whole chain, dict of details)``."""
        loss = self.compute_loss(*args, **kwargs)
        if isinstance(loss, tuple):
            loss, details = loss
        elif loss.ndim == 0:
            details = {self.get_name(): float(loss)}
        else:
            details = {}
        loss = loss * self._alpha

        if self._loss2:
            loss2, details2 = self._loss2(*args, **kwargs)
            loss = loss + loss2
            details |= details2

        return loss, details
class ConfLoss (MultiLoss):
    """ Weighted regression by learned confidence.
        Assuming the input pixel_loss is a pixel-level regression loss.

    Principle:
        conf_loss = pixel_loss * conf - alpha * log(conf), e.g.
        conf = 0.1 ==> conf_loss = x / 10 + alpha*log(10)
        conf = 10  ==> conf_loss = x * 10 - alpha*log(10)

        alpha: hyperparameter
    """

    def __init__(self, pixel_loss, alpha=1):
        super().__init__()
        assert alpha > 0
        self.alpha = alpha
        self.pixel_loss = pixel_loss.with_reduction('none')

    def get_name(self):
        return f'ConfLoss({self.pixel_loss})'

    def get_conf_log(self, x):
        """Return the confidence and its logarithm."""
        return x, torch.log(x)

    @staticmethod
    def _announce(message):
        """Print ``message``, with ``force=True`` when the active print() accepts it.

        ``force`` is not a keyword of the builtin print(); it only works when
        print has been replaced by a wrapper accepting it (presumably a
        distributed-logging helper -- TODO confirm). Falling back avoids a
        TypeError in plain environments.
        """
        try:
            print(message, force=True)
        except TypeError:
            print(message)

    def compute_loss(self, gt1, gt2, pred1, pred2, **kw):
        """Return ``(confidence-weighted loss, details dict)`` for both views."""
        # compute per-pixel loss
        ((loss1, msk1), (loss2, msk2)), details = self.pixel_loss(gt1, gt2, pred1, pred2, **kw)
        if loss1.numel() == 0:
            self._announce('NO VALID POINTS in img1')
        if loss2.numel() == 0:
            self._announce('NO VALID POINTS in img2')

        # weight by confidence
        conf1, log_conf1 = self.get_conf_log(pred1['conf'][msk1])
        conf2, log_conf2 = self.get_conf_log(pred2['conf'][msk2])
        conf_loss1 = loss1 * conf1 - self.alpha * log_conf1
        conf_loss2 = loss2 * conf2 - self.alpha * log_conf2

        # average + nan protection (in case of no valid pixels at all)
        conf_loss1 = conf_loss1.mean() if conf_loss1.numel() > 0 else 0
        conf_loss2 = conf_loss2.mean() if conf_loss2.numel() > 0 else 0

        # NOTE(review): the two detail keys are spelled inconsistently
        # ('conf_loss_1' vs 'conf_loss2'); kept as-is because consumers may rely on them.
        return conf_loss1 + conf_loss2, dict(conf_loss_1=float(conf_loss1), conf_loss2=float(conf_loss2), **details)


class Regr3D_ShiftInv (Regr3D):
    """ Same as Regr3D but invariant to depth shift.
    """

    def get_all_pts3d(self, gt1, gt2, pred1, pred2, **kw):
        """Return the parent's points after removing a joint depth offset.

        Extra keyword arguments (e.g. ``dist_clip``) are forwarded to the parent
        implementation; they used to be rejected with a TypeError.
        """
        # compute unnormalized points
        gt_pts1, gt_pts2, pred_pts1, pred_pts2, mask1, mask2, monitoring = \
            super().get_all_pts3d(gt1, gt2, pred1, pred2, **kw)

        # z channels are views: the in-place subtractions below edit the point tensors
        gt_z1, gt_z2 = gt_pts1[..., 2], gt_pts2[..., 2]
        pred_z1, pred_z2 = pred_pts1[..., 2], pred_pts2[..., 2]
        # joint depth statistic of both views (the median, per the original comments)
        gt_shift_z = get_joint_pointcloud_depth(gt_z1, gt_z2, mask1, mask2)[:, None, None]
        pred_shift_z = get_joint_pointcloud_depth(pred_z1, pred_z2, mask1, mask2)[:, None, None]

        # subtract the depth offset
        gt_z1 -= gt_shift_z
        gt_z2 -= gt_shift_z
        pred_z1 -= pred_shift_z
        pred_z2 -= pred_shift_z

        # monitoring = dict(monitoring, gt_shift_z=gt_shift_z.mean().detach(), pred_shift_z=pred_shift_z.mean().detach())
        return gt_pts1, gt_pts2, pred_pts1, pred_pts2, mask1, mask2, monitoring


class Regr3D_ScaleInv (Regr3D):
    """ Same as Regr3D but invariant to scene scale.
        if gt_scale == True: enforce the prediction to take the same scale as GT
    """

    def get_all_pts3d(self, gt1, gt2, pred1, pred2, **kw):
        """Return the parent's points after rescaling them (in place).

        Extra keyword arguments (e.g. ``dist_clip``) are forwarded to the parent
        implementation; they used to be rejected with a TypeError.
        """
        # compute depth-normalized points
        gt_pts1, gt_pts2, pred_pts1, pred_pts2, mask1, mask2, monitoring = \
            super().get_all_pts3d(gt1, gt2, pred1, pred2, **kw)

        # measure scene scale
        _, gt_scale = get_joint_pointcloud_center_scale(gt_pts1, gt_pts2, mask1, mask2)
        _, pred_scale = get_joint_pointcloud_center_scale(pred_pts1, pred_pts2, mask1, mask2)

        # prevent predictions to be in a ridiculous range
        pred_scale = pred_scale.clip(min=1e-3, max=1e3)

        # apply the scale
        if self.gt_scale:
            # bring the prediction to the ground-truth scale
            pred_pts1 *= gt_scale / pred_scale
            pred_pts2 *= gt_scale / pred_scale
            # monitoring = dict(monitoring, pred_scale=(pred_scale/gt_scale).mean())
        else:
            # normalize ground truth and prediction by their own scale
            gt_pts1 /= gt_scale
            gt_pts2 /= gt_scale
            pred_pts1 /= pred_scale
            pred_pts2 /= pred_scale
            # monitoring = dict(monitoring, gt_scale=gt_scale.mean(), pred_scale=pred_scale.mean().detach())

        return gt_pts1, gt_pts2, pred_pts1, pred_pts2, mask1, mask2, monitoring
class AsymmetricCroCo3DStereo(CroCoNet):
    """ Two siamese encoders, followed by two decoders.
    The goal is to output 3d points directly, both images in view1's frame
    (hence the asymmetry).
    """

    def __init__(self,
                 output_mode='pts3d',
                 head_type='linear',
                 depth_mode=('exp', -inf, inf),
                 conf_mode=('exp', 1, inf),
                 freeze='none',
                 landscape_only=True,
                 patch_embed_cls='PatchEmbedDust3R',  # PatchEmbedDust3R or ManyAR_PatchEmbed
                 **croco_kwargs):
        # must be set before the parent constructor runs: _set_patch_embed() reads it
        # (presumably called from CroCoNet.__init__ -- TODO confirm)
        self.patch_embed_cls = patch_embed_cls
        self.croco_args = fill_default_args(croco_kwargs, super().__init__)
        super().__init__(**croco_kwargs)

        # dust3r specific initialization
        # second decoder starts as an independent copy of the first one
        self.dec_blocks2 = deepcopy(self.dec_blocks)
        self.set_downstream_head(output_mode, head_type, landscape_only, depth_mode, conf_mode, **croco_kwargs)
        self.set_freeze(freeze)

    def _set_patch_embed(self, img_size=224, patch_size=16, enc_embed_dim=768):
        """ Build the patch embedding selected by ``self.patch_embed_cls`` """
        self.patch_embed = get_patch_embed(self.patch_embed_cls, img_size, patch_size, enc_embed_dim)

    def load_state_dict(self, ckpt, **kw):
        """ Load ``ckpt``; when it has no 'dec_blocks2' entry, the 'dec_blocks'
        weights are duplicated under the 'dec_blocks2' names first.
        """
        # duplicate all weights for the second decoder if not present
        new_ckpt = dict(ckpt)
        if not any(k.startswith('dec_blocks2') for k in ckpt):
            for key, value in ckpt.items():
                if key.startswith('dec_blocks'):
                    new_ckpt[key.replace('dec_blocks', 'dec_blocks2')] = value
        return super().load_state_dict(new_ckpt, **kw)

    def device(self):
        """ Device of the first parameter of the second decoder (a method, not a property) """
        return next(self.dec_blocks2.parameters()).device

    def set_freeze(self, freeze):  # this is for use by downstream models
        """ Freeze the parameter group named ``freeze``: 'none', 'mask' or 'encoder'.
        Any other value raises KeyError.
        """
        self.freeze = freeze
        to_be_frozen = {
            'none': [],
            'mask': [self.mask_token],
            'encoder': [self.mask_token, self.patch_embed, self.enc_blocks],
        }
        freeze_all_params(to_be_frozen[freeze])

    def _set_prediction_head(self, *args, **kwargs):
        """ No prediction head """
        return

    def set_downstream_head(self, output_mode, head_type, landscape_only, depth_mode, conf_mode, patch_size, img_size,
                            **kw):
        """ Create the two output heads (one per view) and their landscape wrappers.
        ``img_size`` is indexed as a pair and must be a multiple of ``patch_size``.
        """
        assert img_size[0] % patch_size == 0 and img_size[1] % patch_size == 0, \
            f'{img_size=} must be multiple of {patch_size=}'
        self.output_mode = output_mode
        self.head_type = head_type
        self.depth_mode = depth_mode
        self.conf_mode = conf_mode
        # allocate heads
        self.downstream_head1 = head_factory(head_type, output_mode, self, has_conf=bool(conf_mode))
        self.downstream_head2 = head_factory(head_type, output_mode, self, has_conf=bool(conf_mode))
        # magic wrapper
        self.head1 = transpose_to_landscape(self.downstream_head1, activate=landscape_only)
        self.head2 = transpose_to_landscape(self.downstream_head2, activate=landscape_only)

    def _encode_image(self, image, true_shape=None):
        """ Encode one batch of images; returns ``(tokens, positions, None)`` """
        # embed the image into patches  (x has size B x Npatches x C)
        x, pos = self.patch_embed(image, true_shape=true_shape)

        # add positional embedding without cls token
        assert self.enc_pos_embed is None

        # now apply the transformer encoder and normalization
        for blk in self.enc_blocks:
            x = blk(x, pos)

        x = self.enc_norm(x)
        return x, pos, None

    def _encode_image_pairs(self, img1, img2, true_shape1, true_shape2):
        """ Encode both images, in a single pass when they share the same (H, W) """
        if img1.shape[-2:] == img2.shape[-2:]:
            # same size: concatenate along the batch dimension, then split the result back
            out, pos, _ = self._encode_image(torch.cat((img1, img2), dim=0),
                                             torch.cat((true_shape1, true_shape2), dim=0))
            out, out2 = out.chunk(2, dim=0)
            pos, pos2 = pos.chunk(2, dim=0)
        else:
            out, pos, _ = self._encode_image(img1, true_shape1)
            out2, pos2, _ = self._encode_image(img2, true_shape2)
        return out, out2, pos, pos2

    def _encode_symmetrized(self, view1, view2):
        """ Encode the 'img' of both views.
        Returns ``(shape1, shape2), (feat1, feat2), (pos1, pos2)``.
        """
        img1 = view1['img']
        img2 = view2['img']
        B = img1.shape[0]
        # Recover true_shape when available, otherwise assume that the img shape is the true one
        shape1 = view1.get('true_shape', torch.tensor(img1.shape[-2:])[None].repeat(B, 1))
        shape2 = view2.get('true_shape', torch.tensor(img2.shape[-2:])[None].repeat(B, 1))
        # warning! maybe the images have different portrait/landscape orientations

        if is_symmetrized(view1, view2):
            # computing half of forward pass!
            # only every other sample is encoded; interleave() rebuilds the full batch
            feat1, feat2, pos1, pos2 = self._encode_image_pairs(img1[::2], img2[::2], shape1[::2], shape2[::2])
            feat1, feat2 = interleave(feat1, feat2)
            pos1, pos2 = interleave(pos1, pos2)
        else:
            feat1, feat2, pos1, pos2 = self._encode_image_pairs(img1, img2, shape1, shape2)

        return (shape1, shape2), (feat1, feat2), (pos1, pos2)

    def _decoder(self, f1, pos1, f2, pos2):
        """ Run both decoders; returns two sequences (one per view) holding the
        encoder features followed by the output of every decoder block.
        """
        final_output = [(f1, f2)]  # before projection

        # project to decoder dim
        f1 = self.decoder_embed(f1)
        f2 = self.decoder_embed(f2)

        final_output.append((f1, f2))
        for blk1, blk2 in zip(self.dec_blocks, self.dec_blocks2):
            # img1 side: receives (f1, f2, pos1, pos2) from the previous step
            f1, _ = blk1(*final_output[-1][::+1], pos1, pos2)
            # img2 side: same inputs with the two views swapped
            f2, _ = blk2(*final_output[-1][::-1], pos2, pos1)
            # store the result
            final_output.append((f1, f2))

        # normalize last output
        del final_output[1]  # duplicate with final_output[0]
        final_output[-1] = tuple(map(self.dec_norm, final_output[-1]))
        # transpose the list of (f1, f2) pairs into (all f1, all f2)
        return zip(*final_output)

    def _downstream_head(self, head_num, decout, img_shape):
        """ Apply ``self.head{head_num}`` to the decoder outputs """
        # unused afterwards, but unpacking fails unless the last output has 3 dimensions
        B, S, D = decout[-1].shape
        # img_shape = tuple(map(int, img_shape))
        head = getattr(self, f'head{head_num}')
        return head(decout, img_shape)

    def forward(self, view1, view2):
        """ Return ``(res1, res2)``; in ``res2`` the 'pts3d' entry is renamed
        'pts3d_in_other_view'.
        """
        # encode the two images --> B,S,D
        (shape1, shape2), (feat1, feat2), (pos1, pos2) = self._encode_symmetrized(view1, view2)

        # combine all ref images into object-centric representation
        dec1, dec2 = self._decoder(feat1, pos1, feat2, pos2)

        # the heads run with autocast disabled, on tokens cast to float
        with torch.cuda.amp.autocast(enabled=False):
            res1 = self._downstream_head(1, [tok.float() for tok in dec1], shape1)
            res2 = self._downstream_head(2, [tok.float() for tok in dec2], shape2)

        res2['pts3d_in_other_view'] = res2.pop('pts3d')  # predict view2's pts3d in view1's frame
        return res1, res2
def adjust_learning_rate_by_lr(optimizer, lr):
    """Set the learning rate of every parameter group of ``optimizer``.

    A group carrying an ``"lr_scale"`` entry gets ``lr * lr_scale``; every
    other group gets ``lr`` itself.
    """
    for group in optimizer.param_groups:
        group["lr"] = lr * group["lr_scale"] if "lr_scale" in group else lr
def get_patch_embed(patch_embed_cls, img_size, patch_size, enc_embed_dim):
    """Instantiate the patch-embedding class named ``patch_embed_cls``.

    Args:
        patch_embed_cls: ``'PatchEmbedDust3R'`` or ``'ManyAR_PatchEmbed'``.
        img_size, patch_size, enc_embed_dim: forwarded to the class constructor
            (with 3 input channels).

    Raises:
        ValueError: for any other name. The name used to be validated by an
            ``assert`` (stripped under ``-O``) and then passed to ``eval``.
    """
    available = {
        'PatchEmbedDust3R': PatchEmbedDust3R,
        'ManyAR_PatchEmbed': ManyAR_PatchEmbed,
    }
    try:
        embed_cls = available[patch_embed_cls]
    except (KeyError, TypeError):
        raise ValueError(
            f'unknown patch_embed_cls={patch_embed_cls!r}, expected one of {sorted(available)}') from None
    return embed_cls(img_size, patch_size, 3, enc_embed_dim)


class PatchEmbedDust3R(PatchEmbed):
    """Patch embedding returning both the tokens and their positions."""

    def forward(self, x, **kw):
        """Embed images ``x`` (B, C, H, W); returns ``(tokens, positions)``.

        Extra keyword arguments (e.g. ``true_shape``) are accepted and ignored.
        """
        B, C, H, W = x.shape
        assert H % self.patch_size[0] == 0, f"Input image height ({H}) is not a multiple of patch size ({self.patch_size[0]})."
        assert W % self.patch_size[1] == 0, f"Input image width ({W}) is not a multiple of patch size ({self.patch_size[1]})."
        x = self.proj(x)
        pos = self.position_getter(B, x.size(2), x.size(3), x.device)
        if self.flatten:
            x = x.flatten(2).transpose(1, 2)  # BCHW -> BNC
        x = self.norm(x)
        return x, pos


class ManyAR_PatchEmbed (PatchEmbed):
    """ Handle images with non-square aspect ratio.
        All images in the same batch have the same aspect ratio.
        true_shape = [(height, width) ...] indicates the actual shape of each image.
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True):
        # stored here because forward() needs it to allocate the output
        self.embed_dim = embed_dim
        super().__init__(img_size, patch_size, in_chans, embed_dim, norm_layer, flatten)

    def forward(self, img, true_shape):
        """Embed a batch of landscape-stored images.

        Args:
            img: (B, C, H, W) tensor with W >= H.
            true_shape: (B, 2) tensor of (height, width); samples whose true
                height exceeds their width are transposed before projection.

        Returns:
            ``(x, pos)``: tokens (B, n_tokens, embed_dim) and their positions
            (B, n_tokens, 2).
        """
        B, C, H, W = img.shape
        assert W >= H, f'img should be in landscape mode, but got {W=} {H=}'
        assert H % self.patch_size[0] == 0, f"Input image height ({H}) is not a multiple of patch size ({self.patch_size[0]})."
        assert W % self.patch_size[1] == 0, f"Input image width ({W}) is not a multiple of patch size ({self.patch_size[1]})."
        assert true_shape.shape == (B, 2), f"true_shape has the wrong shape={true_shape.shape}"

        # size expressed in tokens
        # (indices were swapped: H goes with patch_size[0] and W with patch_size[1], as in the asserts above)
        H //= self.patch_size[0]
        W //= self.patch_size[1]
        n_tokens = H * W

        height, width = true_shape.T
        is_landscape = (width >= height)
        is_portrait = ~is_landscape

        # allocate result
        x = img.new_zeros((B, n_tokens, self.embed_dim))
        pos = img.new_zeros((B, n_tokens, 2), dtype=torch.int64)

        # linear projection, transposed if necessary
        x[is_landscape] = self.proj(img[is_landscape]).permute(0, 2, 3, 1).flatten(1, 2).float()
        x[is_portrait] = self.proj(img[is_portrait].swapaxes(-1, -2)).permute(0, 2, 3, 1).flatten(1, 2).float()

        pos[is_landscape] = self.position_getter(1, H, W, pos.device)
        pos[is_portrait] = self.position_getter(1, W, H, pos.device)

        x = self.norm(x)
        return x, pos
+# +# -------------------------------------------------------- +# utilities for interpreting the DUST3R output +# -------------------------------------------------------- +import numpy as np +import torch +from dust3r.utils.geometry import xy_grid + + +def estimate_focal_knowing_depth(pts3d, pp, focal_mode='median', min_focal=0.5, max_focal=3.5): + """ Reprojection method, for when the absolute depth is known: + 1) estimate the camera focal using a robust estimator + 2) reproject points onto true rays, minimizing a certain error + """ + B, H, W, THREE = pts3d.shape + assert THREE == 3 + + # centered pixel grid + pixels = xy_grid(W, H, device=pts3d.device).view(1, -1, 2) - pp.view(-1, 1, 2) # B,HW,2 + pts3d = pts3d.flatten(1, 2) # (B, HW, 3) + + if focal_mode == 'median': + with torch.no_grad(): + # direct estimation of focal + u, v = pixels.unbind(dim=-1) + x, y, z = pts3d.unbind(dim=-1) + fx_votes = (u * z) / x + fy_votes = (v * z) / y + + # assume square pixels, hence same focal for X and Y + f_votes = torch.cat((fx_votes.view(B, -1), fy_votes.view(B, -1)), dim=-1) + focal = torch.nanmedian(f_votes, dim=-1).values + + elif focal_mode == 'weiszfeld': + # init focal with l2 closed form + # we try to find focal = argmin Sum | pixel - focal * (x,y)/z| + xy_over_z = (pts3d[..., :2] / pts3d[..., 2:3]).nan_to_num(posinf=0, neginf=0) # homogeneous (x,y,1) + + dot_xy_px = (xy_over_z * pixels).sum(dim=-1) + dot_xy_xy = xy_over_z.square().sum(dim=-1) + + focal = dot_xy_px.mean(dim=1) / dot_xy_xy.mean(dim=1) + + # iterative re-weighted least-squares + for iter in range(10): + # re-weighting by inverse of distance + dis = (pixels - focal.view(-1, 1, 1) * xy_over_z).norm(dim=-1) + # print(dis.nanmean(-1)) + w = dis.clip(min=1e-8).reciprocal() + # update the scaling with the new weights + focal = (w * dot_xy_px).mean(dim=1) / (w * dot_xy_xy).mean(dim=1) + else: + raise ValueError(f'bad {focal_mode=}') + + focal_base = max(H, W) / (2 * np.tan(np.deg2rad(60) / 2)) # size / 
def todevice(batch, device, callback=None, non_blocking=False):
    ''' Transfer some variables to another device (i.e. GPU, CPU:torch, CPU:numpy).

    batch: list, tuple, dict of tensors or other things
    device: pytorch device or 'numpy'
    callback: function that would be called on every sub-elements.
    non_blocking: forwarded to ``Tensor.to`` for every tensor.

    ``callback`` and ``non_blocking`` are propagated through nested containers
    (they used to be applied to the top-level object only).
    '''
    if callback:
        batch = callback(batch)

    def recurse(item):
        return todevice(item, device, callback=callback, non_blocking=non_blocking)

    if isinstance(batch, dict):
        return {k: recurse(v) for k, v in batch.items()}

    if isinstance(batch, (tuple, list)):
        if hasattr(batch, '_fields'):
            # namedtuples take their fields as separate arguments
            return type(batch)(*(recurse(x) for x in batch))
        return type(batch)(recurse(x) for x in batch)

    x = batch
    if device == 'numpy':
        if isinstance(x, torch.Tensor):
            x = x.detach().cpu().numpy()
    elif x is not None:
        if isinstance(x, np.ndarray):
            x = torch.from_numpy(x)
        if torch.is_tensor(x):
            x = x.to(device, non_blocking=non_blocking)
    return x


to_device = todevice  # alias


def to_numpy(x):
    """Shortcut for ``todevice(x, 'numpy')``."""
    return todevice(x, 'numpy')


def to_cpu(x):
    """Shortcut for ``todevice(x, 'cpu')``."""
    return todevice(x, 'cpu')


def to_cuda(x):
    """Shortcut for ``todevice(x, 'cuda')``."""
    return todevice(x, 'cuda')


def collate_with_cat(whatever, lists=False):
    """Merge a list of per-batch results into one batched result.

    Dicts are merged key by key and tuples position by position. Tensors and
    numpy arrays are concatenated along dim 0 (or flattened into a plain list
    of samples when ``lists`` is True). Lists of scalars/strings are returned
    unchanged, a leading None yields None, and any other sequences are chained.
    Inputs that are neither dict, tuple nor list yield None.
    """
    if isinstance(whatever, dict):
        return {k: collate_with_cat(vals, lists=lists) for k, vals in whatever.items()}

    elif isinstance(whatever, (tuple, list)):
        if len(whatever) == 0:
            return whatever
        # the first element decides how the whole sequence is merged
        elem = whatever[0]
        T = type(whatever)

        if elem is None:
            return None
        if isinstance(elem, (bool, float, int, str)):
            return whatever
        if isinstance(elem, tuple):
            return T(collate_with_cat(x, lists=lists) for x in zip(*whatever))
        if isinstance(elem, dict):
            return {k: collate_with_cat([e[k] for e in whatever], lists=lists) for k in elem}

        if isinstance(elem, torch.Tensor):
            return listify(whatever) if lists else torch.cat(whatever)
        if isinstance(elem, np.ndarray):
            return listify(whatever) if lists else torch.cat([torch.from_numpy(x) for x in whatever])

        # otherwise, we just chain lists
        return sum(whatever, T())


def listify(elems):
    """Flatten one level: return the items of every element of ``elems`` as a list."""
    return [x for e in elems for x in e]
def xy_grid(W, H, device=None, origin=(0, 0), unsqueeze=None, cat_dim=-1, homogeneous=False, **arange_kw):
    """ Output a (H,W,2) array of pixel coordinates
        with output[j,i,0] = i + origin[0]
             output[j,i,1] = j + origin[1]

    device: None builds numpy arrays, anything else builds torch tensors on that device.
    unsqueeze: if not None, a new axis is inserted at this position in every plane.
    cat_dim: axis along which the planes are stacked; None returns the tuple of planes.
    homogeneous: append a third plane filled with ones.
    arange_kw: forwarded to arange (e.g. dtype); the integer dtype is the
        library default unless specified there.
    """
    use_numpy = device is None
    if use_numpy:
        arange, meshgrid, stack, ones = np.arange, np.meshgrid, np.stack, np.ones
    else:
        def arange(*a, **kw):
            return torch.arange(*a, device=device, **kw)

        def ones(*a):
            return torch.ones(*a, device=device)
        meshgrid, stack = torch.meshgrid, torch.stack

    tw, th = [arange(o, o+s, **arange_kw) for s, o in zip((W, H), origin)]
    # np.meshgrid may return a list: normalize to a tuple so that planes can be appended
    grid = tuple(meshgrid(tw, th, indexing='xy'))
    if homogeneous:
        grid = grid + (ones((H, W)),)
    if unsqueeze is not None:
        # applied to every plane (the homogeneous one used to be dropped here),
        # and numpy arrays have no .unsqueeze()
        if use_numpy:
            grid = tuple(np.expand_dims(g, unsqueeze) for g in grid)
        else:
            grid = tuple(g.unsqueeze(unsqueeze) for g in grid)
    if cat_dim is not None:
        grid = stack(grid, cat_dim)
    return grid


def geotrf(Trf, pts, ncol=None, norm=False):
    """ Apply a geometric transformation to a set of points.

    Trf: transformation matrix (numpy or torch), optionally batched. Its last
        dimension is either the point dimension d, or d+1 (the last column is
        then used as a translation).
    pts: numpy/torch/tuple of coordinates, converted to the array type of Trf.

    ncol: int. number of columns of the result (defaults to pts.shape[-1])
    norm: float. if != 0, the result is divided by its last coordinate and
        multiplied by norm (i.e. projected on the z=norm plane).

    Returns the transformed points, with the same leading shape as pts.
    """
    assert Trf.ndim >= 2
    if isinstance(Trf, np.ndarray):
        pts = np.asarray(pts)
    elif isinstance(Trf, torch.Tensor):
        pts = torch.as_tensor(pts, dtype=Trf.dtype)

    # adapt shape if necessary
    output_reshape = pts.shape[:-1]
    ncol = ncol or pts.shape[-1]

    # optimized code
    if (isinstance(Trf, torch.Tensor) and isinstance(pts, torch.Tensor) and
            Trf.ndim == 3 and pts.ndim == 4):
        d = pts.shape[3]
        if Trf.shape[-1] == d:
            pts = torch.einsum("bij, bhwj -> bhwi", Trf, pts)
        elif Trf.shape[-1] == d+1:
            pts = torch.einsum("bij, bhwj -> bhwi", Trf[:, :d, :d], pts) + Trf[:, None, None, :d, d]
        else:
            raise ValueError(f'bad shape, not ending with 3 or 4, for {pts.shape=}')
    else:
        if Trf.ndim >= 3:
            # collapse all batch dimensions of Trf into one
            n = Trf.ndim-2
            assert Trf.shape[:n] == pts.shape[:n], 'batch size does not match'
            Trf = Trf.reshape(-1, Trf.shape[-2], Trf.shape[-1])

            if pts.ndim > Trf.ndim:
                # Trf == (B,d,d) & pts == (B,H,W,d) --> (B, H*W, d)
                pts = pts.reshape(Trf.shape[0], -1, pts.shape[-1])
            elif pts.ndim == 2:
                # Trf == (B,d,d) & pts == (B,d) --> (B, 1, d)
                pts = pts[:, None, :]

        if pts.shape[-1]+1 == Trf.shape[-1]:
            Trf = Trf.swapaxes(-1, -2)  # transpose Trf
            pts = pts @ Trf[..., :-1, :] + Trf[..., -1:, :]
        elif pts.shape[-1] == Trf.shape[-1]:
            Trf = Trf.swapaxes(-1, -2)  # transpose Trf
            pts = pts @ Trf
        else:
            pts = Trf @ pts.T
            if pts.ndim >= 2:
                pts = pts.swapaxes(-1, -2)

    if norm:
        pts = pts / pts[..., -1:]  # DONT DO /= BECAUSE OF WEIRD PYTORCH BUG
        if norm != 1:
            pts *= norm

    res = pts[..., :ncol].reshape(*output_reshape, ncol)
    return res


def inv(mat):
    """ Invert a torch or numpy matrix

    Raises ValueError for any other type.
    """
    if isinstance(mat, torch.Tensor):
        return torch.linalg.inv(mat)
    if isinstance(mat, np.ndarray):
        return np.linalg.inv(mat)
    raise ValueError(f'bad matrix type = {type(mat)}')
def depthmap_to_camera_coordinates(depthmap, camera_intrinsics, pseudo_focal=None):
    """
    Back-project a depth map into the camera frame.

    Args:
        - depthmap (HxW array):
        - camera_intrinsics: a 3x3 matrix (without skew)
        - pseudo_focal: optional HxW array of focals replacing the matrix focals
    Returns:
        pointmap of camera coordinates (HxWx3 float32 array), and a mask of
        the pixels with a strictly positive depth.
    """
    intrinsics = np.float32(camera_intrinsics)
    H, W = depthmap.shape

    # Strong assumption: there are no skew terms
    assert intrinsics[0, 1] == 0.0
    assert intrinsics[1, 0] == 0.0

    # focal lengths: one pair from the matrix, or a per-pixel map shared by both axes
    if pseudo_focal is not None:
        assert pseudo_focal.shape == (H, W)
        fu = fv = pseudo_focal
    else:
        fu, fv = intrinsics[0, 0], intrinsics[1, 1]
    cu, cv = intrinsics[0, 2], intrinsics[1, 2]

    # pixel coordinates: v indexes rows, u indexes columns
    v, u = np.indices((H, W))
    rays = ((u - cu) * depthmap / fu, (v - cv) * depthmap / fv, depthmap)
    X_cam = np.stack(rays, axis=-1).astype(np.float32)

    return X_cam, depthmap > 0.0


def depthmap_to_absolute_camera_coordinates(depthmap, camera_intrinsics, camera_pose, **kw):
    """
    Back-project a depth map and express the points with ``camera_pose``.

    Args:
        - depthmap (HxW array):
        - camera_intrinsics: a 3x3 matrix
        - camera_pose: a 4x3 or 4x4 cam2world matrix
    Extra keyword arguments are accepted and ignored.
    Returns:
        pointmap of absolute coordinates (HxWx3 array), and a mask specifying valid pixels."""
    pts_cam, valid = depthmap_to_camera_coordinates(depthmap, camera_intrinsics)

    rotation = camera_pose[:3, :3]
    translation = camera_pose[:3, 3]

    # rotate every point, then translate (applied to invalid pixels as well)
    pts_world = np.einsum("ik, vuk -> vui", rotation, pts_cam) + translation[None, None, :]
    return pts_world, valid


def colmap_to_opencv_intrinsics(K):
    """
    Modify camera intrinsics to follow a different convention.
    Coordinates of the center of the top-left pixels are by default:
    - (0.5, 0.5) in Colmap
    - (0,0) in OpenCV

    Returns a copy of K with its principal point shifted by -0.5.
    """
    shifted = K.copy()
    for row in (0, 1):
        shifted[row, 2] -= 0.5
    return shifted
+ Coordinates of the center of the top-left pixels are by default: + - (0.5, 0.5) in Colmap + - (0,0) in OpenCV + """ + K = K.copy() + K[0, 2] += 0.5 + K[1, 2] += 0.5 + return K + + +def normalize_pointcloud(pts1, pts2, norm_mode='avg_dis', valid1=None, valid2=None): + """ renorm pointmaps pts1, pts2 with norm_mode + """ + assert pts1.ndim >= 3 and pts1.shape[-1] == 3 + assert pts2 is None or (pts2.ndim >= 3 and pts2.shape[-1] == 3) + norm_mode, dis_mode = norm_mode.split('_') + + if norm_mode == 'avg': + # gather all points together (joint normalization) + nan_pts1, nnz1 = invalid_to_zeros(pts1, valid1, ndim=3) + nan_pts2, nnz2 = invalid_to_zeros(pts2, valid2, ndim=3) if pts2 is not None else (None, 0) + all_pts = torch.cat((nan_pts1, nan_pts2), dim=1) if pts2 is not None else nan_pts1 + + # compute distance to origin + all_dis = all_pts.norm(dim=-1) + if dis_mode == 'dis': + pass # do nothing + elif dis_mode == 'log1p': + all_dis = torch.log1p(all_dis) + elif dis_mode == 'warp-log1p': + # actually warp input points before normalizing them + log_dis = torch.log1p(all_dis) + warp_factor = log_dis / all_dis.clip(min=1e-8) + H1, W1 = pts1.shape[1:-1] + pts1 = pts1 * warp_factor[:, :W1*H1].view(-1, H1, W1, 1) + if pts2 is not None: + H2, W2 = pts2.shape[1:-1] + pts2 = pts2 * warp_factor[:, W1*H1:].view(-1, H2, W2, 1) + all_dis = log_dis # this is their true distance afterwards + else: + raise ValueError(f'bad {dis_mode=}') + + norm_factor = all_dis.sum(dim=1) / (nnz1 + nnz2 + 1e-8) + else: + # gather all points together (joint normalization) + nan_pts1 = invalid_to_nans(pts1, valid1, ndim=3) + nan_pts2 = invalid_to_nans(pts2, valid2, ndim=3) if pts2 is not None else None + all_pts = torch.cat((nan_pts1, nan_pts2), dim=1) if pts2 is not None else nan_pts1 + + # compute distance to origin + all_dis = all_pts.norm(dim=-1) + + if norm_mode == 'avg': + norm_factor = all_dis.nanmean(dim=1) + elif norm_mode == 'median': + norm_factor = 
all_dis.nanmedian(dim=1).values.detach() + elif norm_mode == 'sqrt': + norm_factor = all_dis.sqrt().nanmean(dim=1)**2 + else: + raise ValueError(f'bad {norm_mode=}') + + norm_factor = norm_factor.clip(min=1e-8) + while norm_factor.ndim < pts1.ndim: + norm_factor.unsqueeze_(-1) + + res = pts1 / norm_factor + if pts2 is not None: + res = (res, pts2 / norm_factor) + return res + + +@torch.no_grad() +def get_joint_pointcloud_depth(z1, z2, valid_mask1, valid_mask2=None, quantile=0.5): + # set invalid points to NaN + _z1 = invalid_to_nans(z1, valid_mask1).reshape(len(z1), -1) + _z2 = invalid_to_nans(z2, valid_mask2).reshape(len(z2), -1) if z2 is not None else None + _z = torch.cat((_z1, _z2), dim=-1) if z2 is not None else _z1 + + # compute median depth overall (ignoring nans) + if quantile == 0.5: + shift_z = torch.nanmedian(_z, dim=-1).values + else: + shift_z = torch.nanquantile(_z, quantile, dim=-1) + return shift_z # (B,) + + +@torch.no_grad() +def get_joint_pointcloud_center_scale(pts1, pts2, valid_mask1=None, valid_mask2=None, z_only=False, center=True): + # set invalid points to NaN + _pts1 = invalid_to_nans(pts1, valid_mask1).reshape(len(pts1), -1, 3) + _pts2 = invalid_to_nans(pts2, valid_mask2).reshape(len(pts2), -1, 3) if pts2 is not None else None + _pts = torch.cat((_pts1, _pts2), dim=1) if pts2 is not None else _pts1 + + # compute median center + _center = torch.nanmedian(_pts, dim=1, keepdim=True).values # (B,1,3) + if z_only: + _center[..., :2] = 0 # do not center X and Y + + # compute median norm + _norm = ((_pts - _center) if center else _pts).norm(dim=-1) + scale = torch.nanmedian(_norm, dim=1).values + return _center[:, None, :, :], scale[:, None, None, None] + + +def find_reciprocal_matches(P1, P2): + """ + returns 3 values: + 1 - reciprocal_in_P2: a boolean array of size P2.shape[0], a "True" value indicates a match + 2 - nn2_in_P1: a int array of size P2.shape[0], it contains the indexes of the closest points in P1 + 3 - reciprocal_in_P2.sum(): the 
number of matches + """ + tree1 = KDTree(P1) + tree2 = KDTree(P2) + + _, nn1_in_P2 = tree2.query(P1, workers=8) + _, nn2_in_P1 = tree1.query(P2, workers=8) + + reciprocal_in_P1 = (nn2_in_P1[nn1_in_P2] == np.arange(len(nn1_in_P2))) + reciprocal_in_P2 = (nn1_in_P2[nn2_in_P1] == np.arange(len(nn2_in_P1))) + assert reciprocal_in_P1.sum() == reciprocal_in_P2.sum() + return reciprocal_in_P2, nn2_in_P1, reciprocal_in_P2.sum() + + +def get_med_dist_between_poses(poses): + from scipy.spatial.distance import pdist + return np.median(pdist([to_numpy(p[:3, 3]) for p in poses])) diff --git a/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/dust3r/utils/image.py b/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/dust3r/utils/image.py new file mode 100644 index 0000000000000000000000000000000000000000..51a3e4391a3b620c13f3c514051ce000e9406a57 --- /dev/null +++ b/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/dust3r/utils/image.py @@ -0,0 +1,104 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# utilitary functions about images (loading/converting...) +# -------------------------------------------------------- +import os +import torch +import numpy as np +import PIL.Image +import torchvision.transforms as tvf +os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1" +import cv2 # noqa + +ImgNorm = tvf.Compose([tvf.ToTensor(), tvf.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) + + +def imread_cv2(path, options=cv2.IMREAD_COLOR): + """ Open an image or a depthmap with opencv-python. 
+ """ + if path.endswith(('.exr', 'EXR')): + options = cv2.IMREAD_ANYDEPTH + img = cv2.imread(path, options) + if img is None: + raise IOError(f'Could not load image={path} with {options=}') + if img.ndim == 3: + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + return img + + +def rgb(ftensor, true_shape=None): + if isinstance(ftensor, list): + return [rgb(x, true_shape=true_shape) for x in ftensor] + if isinstance(ftensor, torch.Tensor): + ftensor = ftensor.detach().cpu().numpy() # H,W,3 + if ftensor.ndim == 3 and ftensor.shape[0] == 3: + ftensor = ftensor.transpose(1, 2, 0) + elif ftensor.ndim == 4 and ftensor.shape[1] == 3: + ftensor = ftensor.transpose(0, 2, 3, 1) + if true_shape is not None: + H, W = true_shape + ftensor = ftensor[:H, :W] + if ftensor.dtype == np.uint8: + img = np.float32(ftensor) / 255 + else: + img = (ftensor * 0.5) + 0.5 + return img.clip(min=0, max=1) + + +def _resize_pil_image(img, long_edge_size): + S = max(img.size) + if S > long_edge_size: + interp = PIL.Image.LANCZOS + elif S <= long_edge_size: + interp = PIL.Image.BICUBIC + new_size = tuple(int(round(x*long_edge_size/S)) for x in img.size) + return img.resize(new_size, interp) + + +def load_images(folder_or_list, size, square_ok=False): + """ open and convert all images in a list or folder to proper input format for DUSt3R + """ + if isinstance(folder_or_list, str): + print(f'>> Loading images from {folder_or_list}') + root, folder_content = folder_or_list, sorted(os.listdir(folder_or_list)) + + elif isinstance(folder_or_list, list): + print(f'>> Loading a list of {len(folder_or_list)} images') + root, folder_content = '', folder_or_list + + else: + raise ValueError(f'bad {folder_or_list=} ({type(folder_or_list)})') + + imgs = [] + for path in folder_content: + if not path.endswith(('.jpg', '.jpeg', '.png', '.JPG')): + continue + img = PIL.Image.open(os.path.join(root, path)).convert('RGB') + W1, H1 = img.size + if size == 224: + # resize short side to 224 (then crop) + img = 
_resize_pil_image(img, round(size * max(W1/H1, H1/W1))) + else: + # resize long side to 512 + img = _resize_pil_image(img, size) + W, H = img.size + cx, cy = W//2, H//2 + if size == 224: + half = min(cx, cy) + img = img.crop((cx-half, cy-half, cx+half, cy+half)) + else: + halfw, halfh = ((2*cx)//16)*8, ((2*cy)//16)*8 + if not (square_ok) and W == H: + halfh = 3*halfw/4 + img = img.crop((cx-halfw, cy-halfh, cx+halfw, cy+halfh)) + + W2, H2 = img.size + print(f' - adding {path} with resolution {W1}x{H1} --> {W2}x{H2}') + imgs.append(dict(img=ImgNorm(img)[None], true_shape=np.int32( + [img.size[::-1]]), idx=len(imgs), instance=str(len(imgs)))) + + assert imgs, 'no images foud at '+root + print(f' (Found {len(imgs)} images)') + return imgs diff --git a/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/dust3r/utils/misc.py b/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/dust3r/utils/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..ab9fd06a063c3eafbfafddc011064ebb8a3232a8 --- /dev/null +++ b/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/dust3r/utils/misc.py @@ -0,0 +1,121 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 
+# +# -------------------------------------------------------- +# utilitary functions for DUSt3R +# -------------------------------------------------------- +import torch + + +def fill_default_args(kwargs, func): + import inspect # a bit hacky but it works reliably + signature = inspect.signature(func) + + for k, v in signature.parameters.items(): + if v.default is inspect.Parameter.empty: + continue + kwargs.setdefault(k, v.default) + + return kwargs + + +def freeze_all_params(modules): + for module in modules: + try: + for n, param in module.named_parameters(): + param.requires_grad = False + except AttributeError: + # module is directly a parameter + module.requires_grad = False + + +def is_symmetrized(gt1, gt2): + x = gt1['instance'] + y = gt2['instance'] + if len(x) == len(y) and len(x) == 1: + return False # special case of batchsize 1 + ok = True + for i in range(0, len(x), 2): + ok = ok and (x[i] == y[i+1]) and (x[i+1] == y[i]) + return ok + + +def flip(tensor): + """ flip so that tensor[0::2] <=> tensor[1::2] """ + return torch.stack((tensor[1::2], tensor[0::2]), dim=1).flatten(0, 1) + + +def interleave(tensor1, tensor2): + res1 = torch.stack((tensor1, tensor2), dim=1).flatten(0, 1) + res2 = torch.stack((tensor2, tensor1), dim=1).flatten(0, 1) + return res1, res2 + + +def transpose_to_landscape(head, activate=True): + """ Predict in the correct aspect-ratio, + then transpose the result in landscape + and stack everything back together. 
+ """ + def wrapper_no(decout, true_shape): + B = len(true_shape) + assert true_shape[0:1].allclose(true_shape), 'true_shape must be all identical' + H, W = true_shape[0].cpu().tolist() + res = head(decout, (H, W)) + return res + + def wrapper_yes(decout, true_shape): + B = len(true_shape) + # by definition, the batch is in landscape mode so W >= H + H, W = int(true_shape.min()), int(true_shape.max()) + + height, width = true_shape.T + is_landscape = (width >= height) + is_portrait = ~is_landscape + + # true_shape = true_shape.cpu() + if is_landscape.all(): + return head(decout, (H, W)) + if is_portrait.all(): + return transposed(head(decout, (W, H))) + + # batch is a mix of both portraint & landscape + def selout(ar): return [d[ar] for d in decout] + l_result = head(selout(is_landscape), (H, W)) + p_result = transposed(head(selout(is_portrait), (W, H))) + + # allocate full result + result = {} + for k in l_result | p_result: + x = l_result[k].new(B, *l_result[k].shape[1:]) + x[is_landscape] = l_result[k] + x[is_portrait] = p_result[k] + result[k] = x + + return result + + return wrapper_yes if activate else wrapper_no + + +def transposed(dic): + return {k: v.swapaxes(1, 2) for k, v in dic.items()} + + +def invalid_to_nans(arr, valid_mask, ndim=999): + if valid_mask is not None: + arr = arr.clone() + arr[~valid_mask] = float('nan') + if arr.ndim > ndim: + arr = arr.flatten(-2 - (arr.ndim - ndim), -2) + return arr + + +def invalid_to_zeros(arr, valid_mask, ndim=999): + if valid_mask is not None: + arr = arr.clone() + arr[~valid_mask] = 0 + nnz = valid_mask.view(len(valid_mask), -1).sum(1) + else: + nnz = arr.numel() // len(arr) if len(arr) else 0 # number of point per image + if arr.ndim > ndim: + arr = arr.flatten(-2 - (arr.ndim - ndim), -2) + return arr, nnz diff --git a/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/dust3r/utils/path_to_croco.py b/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/dust3r/utils/path_to_croco.py new file 
mode 100644 index 0000000000000000000000000000000000000000..39226ce6bc0e1993ba98a22096de32cb6fa916b4 --- /dev/null +++ b/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/dust3r/utils/path_to_croco.py @@ -0,0 +1,19 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# CroCo submodule import +# -------------------------------------------------------- + +import sys +import os.path as path +HERE_PATH = path.normpath(path.dirname(__file__)) +CROCO_REPO_PATH = path.normpath(path.join(HERE_PATH, '../../croco')) +CROCO_MODELS_PATH = path.join(CROCO_REPO_PATH, 'models') +# check the presence of models directory in repo to be sure its cloned +if path.isdir(CROCO_MODELS_PATH): + # workaround for sibling import + sys.path.insert(0, CROCO_REPO_PATH) +else: + raise ImportError(f"croco is not initialized, could not find: {CROCO_MODELS_PATH}.\n " + "Did you forget to run 'git submodule update --init --recursive' ?") diff --git a/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/dust3r/viz.py b/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/dust3r/viz.py new file mode 100644 index 0000000000000000000000000000000000000000..a21f399accf6710816cc4a858d60849ccaad31e1 --- /dev/null +++ b/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/dust3r/viz.py @@ -0,0 +1,320 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 
+# +# -------------------------------------------------------- +# Visualization utilities using trimesh +# -------------------------------------------------------- +import PIL.Image +import numpy as np +from scipy.spatial.transform import Rotation +import torch + +from dust3r.utils.geometry import geotrf, get_med_dist_between_poses +from dust3r.utils.device import to_numpy +from dust3r.utils.image import rgb + +try: + import trimesh +except ImportError: + print('/!\\ module trimesh is not installed, cannot visualize results /!\\') + + +def cat_3d(vecs): + if isinstance(vecs, (np.ndarray, torch.Tensor)): + vecs = [vecs] + return np.concatenate([p.reshape(-1, 3) for p in to_numpy(vecs)]) + + +def show_raw_pointcloud(pts3d, colors, point_size=2): + scene = trimesh.Scene() + + pct = trimesh.PointCloud(cat_3d(pts3d), colors=cat_3d(colors)) + scene.add_geometry(pct) + + scene.show(line_settings={'point_size': point_size}) + + +def pts3d_to_trimesh(img, pts3d, valid=None): + H, W, THREE = img.shape + assert THREE == 3 + assert img.shape == pts3d.shape + + vertices = pts3d.reshape(-1, 3) + + # make squares: each pixel == 2 triangles + idx = np.arange(len(vertices)).reshape(H, W) + idx1 = idx[:-1, :-1].ravel() # top-left corner + idx2 = idx[:-1, +1:].ravel() # right-left corner + idx3 = idx[+1:, :-1].ravel() # bottom-left corner + idx4 = idx[+1:, +1:].ravel() # bottom-right corner + faces = np.concatenate(( + np.c_[idx1, idx2, idx3], + np.c_[idx3, idx2, idx1], # same triangle, but backward (cheap solution to cancel face culling) + np.c_[idx2, idx3, idx4], + np.c_[idx4, idx3, idx2], # same triangle, but backward (cheap solution to cancel face culling) + ), axis=0) + + # prepare triangle colors + face_colors = np.concatenate(( + img[:-1, :-1].reshape(-1, 3), + img[:-1, :-1].reshape(-1, 3), + img[+1:, +1:].reshape(-1, 3), + img[+1:, +1:].reshape(-1, 3) + ), axis=0) + + # remove invalid faces + if valid is not None: + assert valid.shape == (H, W) + valid_idxs = valid.ravel() + 
valid_faces = valid_idxs[faces].all(axis=-1) + faces = faces[valid_faces] + face_colors = face_colors[valid_faces] + + assert len(faces) == len(face_colors) + return dict(vertices=vertices, face_colors=face_colors, faces=faces) + + +def cat_meshes(meshes): + vertices, faces, colors = zip(*[(m['vertices'], m['faces'], m['face_colors']) for m in meshes]) + n_vertices = np.cumsum([0]+[len(v) for v in vertices]) + for i in range(len(faces)): + faces[i][:] += n_vertices[i] + + vertices = np.concatenate(vertices) + colors = np.concatenate(colors) + faces = np.concatenate(faces) + return dict(vertices=vertices, face_colors=colors, faces=faces) + + +def show_duster_pairs(view1, view2, pred1, pred2): + import matplotlib.pyplot as pl + pl.ion() + + for e in range(len(view1['instance'])): + i = view1['idx'][e] + j = view2['idx'][e] + img1 = rgb(view1['img'][e]) + img2 = rgb(view2['img'][e]) + conf1 = pred1['conf'][e].squeeze() + conf2 = pred2['conf'][e].squeeze() + score = conf1.mean()*conf2.mean() + print(f">> Showing pair #{e} {i}-{j} {score=:g}") + pl.clf() + pl.subplot(221).imshow(img1) + pl.subplot(223).imshow(img2) + pl.subplot(222).imshow(conf1, vmin=1, vmax=30) + pl.subplot(224).imshow(conf2, vmin=1, vmax=30) + pts1 = pred1['pts3d'][e] + pts2 = pred2['pts3d_in_other_view'][e] + pl.subplots_adjust(0, 0, 1, 1, 0, 0) + if input('show pointcloud? 
(y/n) ') == 'y': + show_raw_pointcloud(cat(pts1, pts2), cat(img1, img2), point_size=5) + + +def auto_cam_size(im_poses): + return 0.1 * get_med_dist_between_poses(im_poses) + + +class SceneViz: + def __init__(self): + self.scene = trimesh.Scene() + + def add_pointcloud(self, pts3d, color, mask=None): + pts3d = to_numpy(pts3d) + mask = to_numpy(mask) + if mask is None: + mask = [slice(None)] * len(pts3d) + pts = np.concatenate([p[m] for p, m in zip(pts3d, mask)]) + pct = trimesh.PointCloud(pts.reshape(-1, 3)) + + if isinstance(color, (list, np.ndarray, torch.Tensor)): + color = to_numpy(color) + col = np.concatenate([p[m] for p, m in zip(color, mask)]) + assert col.shape == pts.shape + pct.visual.vertex_colors = uint8(col.reshape(-1, 3)) + else: + assert len(color) == 3 + pct.visual.vertex_colors = np.broadcast_to(uint8(color), pts.shape) + + self.scene.add_geometry(pct) + return self + + def add_camera(self, pose_c2w, focal=None, color=(0, 0, 0), image=None, imsize=None, cam_size=0.03): + pose_c2w, focal, color, image = to_numpy((pose_c2w, focal, color, image)) + add_scene_cam(self.scene, pose_c2w, color, image, focal, screen_width=cam_size) + return self + + def add_cameras(self, poses, focals=None, images=None, imsizes=None, colors=None, **kw): + def get(arr, idx): return None if arr is None else arr[idx] + for i, pose_c2w in enumerate(poses): + self.add_camera(pose_c2w, get(focals, i), image=get(images, i), + color=get(colors, i), imsize=get(imsizes, i), **kw) + return self + + def show(self, point_size=2): + self.scene.show(line_settings={'point_size': point_size}) + + +def show_raw_pointcloud_with_cams(imgs, pts3d, mask, focals, cams2world, + point_size=2, cam_size=0.05, cam_color=None): + """ Visualization of a pointcloud with cameras + imgs = (N, H, W, 3) or N-size list of [(H,W,3), ...] + pts3d = (N, H, W, 3) or N-size list of [(H,W,3), ...] + focals = (N,) or N-size list of [focal, ...] + cams2world = (N,4,4) or N-size list of [(4,4), ...] 
+ """ + assert len(pts3d) == len(mask) <= len(imgs) <= len(cams2world) == len(focals) + pts3d = to_numpy(pts3d) + imgs = to_numpy(imgs) + focals = to_numpy(focals) + cams2world = to_numpy(cams2world) + + scene = trimesh.Scene() + + # full pointcloud + pts = np.concatenate([p[m] for p, m in zip(pts3d, mask)]) + col = np.concatenate([p[m] for p, m in zip(imgs, mask)]) + pct = trimesh.PointCloud(pts.reshape(-1, 3), colors=col.reshape(-1, 3)) + scene.add_geometry(pct) + + # add each camera + for i, pose_c2w in enumerate(cams2world): + if isinstance(cam_color, list): + camera_edge_color = cam_color[i] + else: + camera_edge_color = cam_color or CAM_COLORS[i % len(CAM_COLORS)] + add_scene_cam(scene, pose_c2w, camera_edge_color, + imgs[i] if i < len(imgs) else None, focals[i], screen_width=cam_size) + + scene.show(line_settings={'point_size': point_size}) + + +def add_scene_cam(scene, pose_c2w, edge_color, image=None, focal=None, imsize=None, screen_width=0.03): + + if image is not None: + H, W, THREE = image.shape + assert THREE == 3 + if image.dtype != np.uint8: + image = np.uint8(255*image) + elif imsize is not None: + W, H = imsize + elif focal is not None: + H = W = focal / 1.1 + else: + H = W = 1 + + if focal is None: + focal = min(H, W) * 1.1 # default value + elif isinstance(focal, np.ndarray): + focal = focal[0] + + # create fake camera + height = focal * screen_width / H + width = screen_width * 0.5**0.5 + rot45 = np.eye(4) + rot45[:3, :3] = Rotation.from_euler('z', np.deg2rad(45)).as_matrix() + rot45[2, 3] = -height # set the tip of the cone = optical center + aspect_ratio = np.eye(4) + aspect_ratio[0, 0] = W/H + transform = pose_c2w @ OPENGL @ aspect_ratio @ rot45 + cam = trimesh.creation.cone(width, height, sections=4) # , transform=transform) + + # this is the image + if image is not None: + vertices = geotrf(transform, cam.vertices[[4, 5, 1, 3]]) + faces = np.array([[0, 1, 2], [0, 2, 3], [2, 1, 0], [3, 2, 0]]) + img = trimesh.Trimesh(vertices=vertices, 
faces=faces) + uv_coords = np.float32([[0, 0], [1, 0], [1, 1], [0, 1]]) + img.visual = trimesh.visual.TextureVisuals(uv_coords, image=PIL.Image.fromarray(image)) + scene.add_geometry(img) + + # this is the camera mesh + rot2 = np.eye(4) + rot2[:3, :3] = Rotation.from_euler('z', np.deg2rad(2)).as_matrix() + vertices = np.r_[cam.vertices, 0.95*cam.vertices, geotrf(rot2, cam.vertices)] + vertices = geotrf(transform, vertices) + faces = [] + for face in cam.faces: + if 0 in face: + continue + a, b, c = face + a2, b2, c2 = face + len(cam.vertices) + a3, b3, c3 = face + 2*len(cam.vertices) + + # add 3 pseudo-edges + faces.append((a, b, b2)) + faces.append((a, a2, c)) + faces.append((c2, b, c)) + + faces.append((a, b, b3)) + faces.append((a, a3, c)) + faces.append((c3, b, c)) + + # no culling + faces += [(c, b, a) for a, b, c in faces] + + cam = trimesh.Trimesh(vertices=vertices, faces=faces) + cam.visual.face_colors[:, :3] = edge_color + scene.add_geometry(cam) + + +def cat(a, b): + return np.concatenate((a.reshape(-1, 3), b.reshape(-1, 3))) + + +OPENGL = np.array([[1, 0, 0, 0], + [0, -1, 0, 0], + [0, 0, -1, 0], + [0, 0, 0, 1]]) + + +CAM_COLORS = [(255, 0, 0), (0, 0, 255), (0, 255, 0), (255, 0, 255), (255, 204, 0), (0, 204, 204), + (128, 255, 255), (255, 128, 255), (255, 255, 128), (0, 0, 0), (128, 128, 128)] + + +def uint8(colors): + if not isinstance(colors, np.ndarray): + colors = np.array(colors) + if np.issubdtype(colors.dtype, np.floating): + colors *= 255 + assert 0 <= colors.min() and colors.max() < 256 + return np.uint8(colors) + + +def segment_sky(image): + import cv2 + from scipy import ndimage + + # Convert to HSV + image = to_numpy(image) + if np.issubdtype(image.dtype, np.floating): + image = np.uint8(255*image.clip(min=0, max=1)) + hsv = cv2.cvtColor(image, cv2.COLOR_BGR2HSV) + + # Define range for blue color and create mask + lower_blue = np.array([0, 0, 100]) + upper_blue = np.array([30, 255, 255]) + mask = cv2.inRange(hsv, lower_blue, 
upper_blue).view(bool) + + # add luminous gray + mask |= (hsv[:, :, 1] < 10) & (hsv[:, :, 2] > 150) + mask |= (hsv[:, :, 1] < 30) & (hsv[:, :, 2] > 180) + mask |= (hsv[:, :, 1] < 50) & (hsv[:, :, 2] > 220) + + # Morphological operations + kernel = np.ones((5, 5), np.uint8) + mask2 = ndimage.binary_opening(mask, structure=kernel) + + # keep only largest CC + _, labels, stats, _ = cv2.connectedComponentsWithStats(mask2.view(np.uint8), connectivity=8) + cc_sizes = stats[1:, cv2.CC_STAT_AREA] + order = cc_sizes.argsort()[::-1] # bigger first + i = 0 + selection = [] + while i < len(order) and cc_sizes[order[i]] > cc_sizes[order[0]] / 2: + selection.append(1 + order[i]) + i += 1 + mask3 = np.in1d(labels, selection).reshape(labels.shape) + + # Apply mask + return torch.from_numpy(mask3) diff --git a/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/encoders.py b/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/encoders.py new file mode 100644 index 0000000000000000000000000000000000000000..3de24c4224fc817cd3c91ee7e8ec87f175af4562 --- /dev/null +++ b/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/encoders.py @@ -0,0 +1,137 @@ +from typing import Optional, Union +import torch +from torch import device +import torch.nn as nn +import torch.nn.functional as F +import torchvision.models as tvm +import gc + + +class ResNet50(nn.Module): + def __init__(self, pretrained=False, high_res = False, weights = None, + dilation = None, freeze_bn = True, anti_aliased = False, early_exit = False, amp = False, amp_dtype = torch.float16) -> None: + super().__init__() + if dilation is None: + dilation = [False,False,False] + if anti_aliased: + pass + else: + if weights is not None: + self.net = tvm.resnet50(weights = weights,replace_stride_with_dilation=dilation) + else: + net = tvm.resnet50(pretrained=pretrained,replace_stride_with_dilation=dilation) + self.net = nn.Sequential(net.conv1, net.bn1, net.relu, net.maxpool, net.layer1, net.layer2, net.layer3) 
+ + self.high_res = high_res + self.freeze_bn = freeze_bn + self.early_exit = early_exit + self.amp = amp + self.amp_dtype = amp_dtype + + def forward(self, x, **kwargs): + with torch.autocast("cuda", enabled=self.amp, dtype = self.amp_dtype): + # net = self.net + # feats = {1:x} + # x = net.conv1(x) + # x = net.bn1(x) + # x = net.relu(x) + # feats[2] = x + # x = net.maxpool(x) + # x = net.layer1(x) + # feats[4] = x + # x = net.layer2(x) + # feats[8] = x + # if self.early_exit: + # return feats + # x = net.layer3(x) + # feats[16] = x + # x = net.layer4(x) + # feats[32] = x + return self.net(x) + + def train(self, mode=True): + super().train(mode) + if self.freeze_bn: + for m in self.modules(): + if isinstance(m, nn.BatchNorm2d): + m.eval() + pass + +class VGG19(nn.Module): + def __init__(self, pretrained=False, amp = False, amp_dtype = torch.float16) -> None: + super().__init__() + self.layers = nn.ModuleList(tvm.vgg19_bn(pretrained=pretrained).features[:40]) + self.amp = amp + self.amp_dtype = amp_dtype + + def forward(self, x, **kwargs): + with torch.autocast("cuda", enabled=self.amp, dtype = self.amp_dtype): + feats = {} + scale = 1 + for layer in self.layers: + if isinstance(layer, nn.MaxPool2d): + feats[scale] = x + scale = scale*2 + x = layer(x) + return feats + +class CNNandDinov2(nn.Module): + def __init__(self, cnn_kwargs = None, amp = False, use_vgg = False, coarse_backbone='DINOv2_large', coarse_patch_size=14, coarse_feat_dim=1024, dinov2_weights = None, amp_dtype = torch.float16): + super().__init__() + self.amp = amp + self.amp_dtype = amp_dtype + self.coarse_backbone = coarse_backbone + self.coarse_patch_size = coarse_patch_size + self.coarse_feat_dim = coarse_feat_dim + if 'DINOv2' in coarse_backbone: + if 'large' in coarse_backbone: + if dinov2_weights is None: + dinov2_weights = torch.hub.load_state_dict_from_url("https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_pretrain.pth", map_location="cpu") + from .transformer import 
vit_large as vit_model + vit_kwargs = dict(img_size= 518, + patch_size= coarse_patch_size, + init_values = 1.0, + ffn_layer = "mlp", + block_chunks = 0, + ) + else: + raise NotImplementedError + + dinov2_vitl14 = vit_model(**vit_kwargs).eval() + dinov2_vitl14.load_state_dict(dinov2_weights) + + if self.amp: + dinov2_vitl14 = dinov2_vitl14.to(self.amp_dtype) + self.dinov2_vitl14 = [dinov2_vitl14] # ugly hack to not show parameters to DDP + elif coarse_backbone == 'ResNet50': + self.backbone_model = ResNet50(pretrained=True, amp=self.amp) + else: + raise NotImplementedError + + cnn_kwargs = cnn_kwargs if cnn_kwargs is not None else {} + if not use_vgg: + self.cnn = ResNet50(**cnn_kwargs) + else: + self.cnn = VGG19(**cnn_kwargs) + + def train(self, mode: bool = True): + return self.cnn.train(mode) + + def forward(self, x, upsample = False): + B,C,H,W = x.shape + feature_pyramid = self.cnn(x) + + if not upsample: + with torch.no_grad(): + if 'DINOv2' in self.coarse_backbone: + if self.dinov2_vitl14[0].device != x.device: + self.dinov2_vitl14[0] = self.dinov2_vitl14[0].to(x.device) + dinov2_features_16 = self.dinov2_vitl14[0].forward_features(x.to(self.amp_dtype) if self.amp else x) + features_16 = dinov2_features_16['x_norm_patchtokens'].permute(0,2,1).reshape(B,self.coarse_feat_dim,H//self.coarse_patch_size, W//self.coarse_patch_size) + del dinov2_features_16 + else: + raise NotImplementedError + if self.coarse_backbone == 'ResNet50': + features_16 = self.backbone_model(x.to(self.amp_dtype) if self.amp else x) + feature_pyramid[16] = features_16 + return feature_pyramid \ No newline at end of file diff --git a/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/matcher.py b/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/matcher.py new file mode 100644 index 0000000000000000000000000000000000000000..7a63ebc391954396056c4fbd91857fe61dfec35b --- /dev/null +++ b/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/matcher.py @@ -0,0 +1,937 
@@ +import os +import math +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from torchvision.transforms import Normalize +from einops import rearrange +import warnings +from warnings import warn +from PIL import Image + +import roma +from roma.utils import get_tuple_transform_ops, resize_by_longest_edge_and_padding, resize_by_longest_edge_and_stretch +from roma.utils.local_correlation import local_correlation +from roma.utils.utils import cls_to_flow_refine +from roma.utils.kde import kde + +class ConvRefiner(nn.Module): + def __init__( + self, + in_dim=6, + hidden_dim=16, + out_dim=2, + dw=False, + kernel_size=5, + hidden_blocks=3, + displacement_emb = None, + displacement_emb_dim = None, + local_corr_radius = None, + corr_in_other = None, + no_im_B_fm = False, + amp = False, + concat_logits = False, + use_bias_block_1 = True, + use_cosine_corr = False, + disable_local_corr_grad = False, + is_classifier = False, + sample_mode = "bilinear", + norm_type = nn.BatchNorm2d, + bn_momentum = 0.1, + amp_dtype = torch.float16, + ): + super().__init__() + self.bn_momentum = bn_momentum + self.block1 = self.create_block( + in_dim, hidden_dim, dw=dw, kernel_size=kernel_size, bias = use_bias_block_1, + ) + self.hidden_blocks = nn.Sequential( + *[ + self.create_block( + hidden_dim, + hidden_dim, + dw=dw, + kernel_size=kernel_size, + norm_type=norm_type, + ) + for hb in range(hidden_blocks) + ] + ) + self.hidden_blocks = self.hidden_blocks + self.out_conv = nn.Conv2d(hidden_dim, out_dim, 1, 1, 0) + if displacement_emb: + self.has_displacement_emb = True + self.disp_emb = nn.Conv2d(2,displacement_emb_dim,1,1,0) + else: + self.has_displacement_emb = False + self.local_corr_radius = local_corr_radius + self.corr_in_other = corr_in_other + self.no_im_B_fm = no_im_B_fm + self.amp = amp + self.concat_logits = concat_logits + self.use_cosine_corr = use_cosine_corr + self.disable_local_corr_grad = disable_local_corr_grad + self.is_classifier = 
is_classifier + self.sample_mode = sample_mode + self.amp_dtype = amp_dtype + + def create_block( + self, + in_dim, + out_dim, + dw=False, + kernel_size=5, + bias = True, + norm_type = nn.BatchNorm2d, + ): + num_groups = 1 if not dw else in_dim + if dw: + assert ( + out_dim % in_dim == 0 + ), "outdim must be divisible by indim for depthwise" + conv1 = nn.Conv2d( + in_dim, + out_dim, + kernel_size=kernel_size, + stride=1, + padding=kernel_size // 2, + groups=num_groups, + bias=bias, + ) + norm = norm_type(out_dim, momentum = self.bn_momentum) if norm_type is nn.BatchNorm2d else norm_type(num_channels = out_dim) + relu = nn.ReLU(inplace=True) + conv2 = nn.Conv2d(out_dim, out_dim, 1, 1, 0) + return nn.Sequential(conv1, norm, relu, conv2) + + def forward(self, x, y, flow, scale_factor = 1, logits = None): + b,c,hs,ws = x.shape + with torch.autocast("cuda", enabled=self.amp, dtype = self.amp_dtype): + with torch.no_grad(): + x_hat = F.grid_sample(y, flow.permute(0, 2, 3, 1), align_corners=False, mode = self.sample_mode) + if self.has_displacement_emb: + im_A_coords = torch.meshgrid( + ( + torch.linspace(-1 + 1 / hs, 1 - 1 / hs, hs, device=x.device), + torch.linspace(-1 + 1 / ws, 1 - 1 / ws, ws, device=x.device), + ) + ) + im_A_coords = torch.stack((im_A_coords[1], im_A_coords[0])) + im_A_coords = im_A_coords[None].expand(b, 2, hs, ws) + in_displacement = flow-im_A_coords + emb_in_displacement = self.disp_emb(40/32 * scale_factor * in_displacement) + if self.local_corr_radius: + if self.corr_in_other: + # Corr in other means take a kxk grid around the predicted coordinate in other image + local_corr = local_correlation(x,y,local_radius=self.local_corr_radius,flow = flow, + sample_mode = self.sample_mode) + else: + raise NotImplementedError("Local corr in own frame should not be used.") + if self.no_im_B_fm: + x_hat = torch.zeros_like(x) + d = torch.cat((x, x_hat, emb_in_displacement, local_corr), dim=1) + else: + d = torch.cat((x, x_hat, emb_in_displacement), dim=1) + 
else: + if self.no_im_B_fm: + x_hat = torch.zeros_like(x) + d = torch.cat((x, x_hat), dim=1) + if self.concat_logits: + d = torch.cat((d, logits), dim=1) + d = self.block1(d) + d = self.hidden_blocks(d) + d = self.out_conv(d.float()) + displacement, certainty = d[:, :-1], d[:, -1:] + return displacement, certainty + +class CosKernel(nn.Module): # similar to softmax kernel + def __init__(self, T, learn_temperature=False): + super().__init__() + self.learn_temperature = learn_temperature + if self.learn_temperature: + self.T = nn.Parameter(torch.tensor(T)) + else: + self.T = T + + def __call__(self, x, y, eps=1e-6): + c = torch.einsum("bnd,bmd->bnm", x, y) / ( + x.norm(dim=-1)[..., None] * y.norm(dim=-1)[:, None] + eps + ) + if self.learn_temperature: + T = self.T.abs() + 0.01 + else: + T = torch.tensor(self.T, device=c.device) + K = ((c - 1.0) / T).exp() + return K + +class GP(nn.Module): + def __init__( + self, + kernel, + T=1, + learn_temperature=False, + only_attention=False, + gp_dim=64, + basis="fourier", + covar_size=5, + only_nearest_neighbour=False, + sigma_noise=0.1, + no_cov=False, + predict_features = False, + ): + super().__init__() + self.K = kernel(T=T, learn_temperature=learn_temperature) + self.sigma_noise = sigma_noise + self.covar_size = covar_size + self.pos_conv = torch.nn.Conv2d(2, gp_dim, 1, 1) + self.only_attention = only_attention + self.only_nearest_neighbour = only_nearest_neighbour + self.basis = basis + self.no_cov = no_cov + self.dim = gp_dim + self.predict_features = predict_features + + def get_local_cov(self, cov): + K = self.covar_size + b, h, w, h, w = cov.shape + hw = h * w + cov = F.pad(cov, 4 * (K // 2,)) # pad v_q + delta = torch.stack( + torch.meshgrid( + torch.arange(-(K // 2), K // 2 + 1), torch.arange(-(K // 2), K // 2 + 1) + ), + dim=-1, + ) + positions = torch.stack( + torch.meshgrid( + torch.arange(K // 2, h + K // 2), torch.arange(K // 2, w + K // 2) + ), + dim=-1, + ) + neighbours = positions[:, :, None, None, :] + 
delta[None, :, :] + points = torch.arange(hw)[:, None].expand(hw, K**2) + local_cov = cov.reshape(b, hw, h + K - 1, w + K - 1)[ + :, + points.flatten(), + neighbours[..., 0].flatten(), + neighbours[..., 1].flatten(), + ].reshape(b, h, w, K**2) + return local_cov + + def reshape(self, x): + return rearrange(x, "b d h w -> b (h w) d") + + def project_to_basis(self, x): + if self.basis == "fourier": + return torch.cos(8 * math.pi * self.pos_conv(x)) + elif self.basis == "linear": + return self.pos_conv(x) + else: + raise ValueError( + "No other bases other than fourier and linear currently im_Bed in public release" + ) + + def get_pos_enc(self, y): + b, c, h, w = y.shape + coarse_coords = torch.meshgrid( + ( + torch.linspace(-1 + 1 / h, 1 - 1 / h, h, device=y.device), + torch.linspace(-1 + 1 / w, 1 - 1 / w, w, device=y.device), + ) + ) + + coarse_coords = torch.stack((coarse_coords[1], coarse_coords[0]), dim=-1)[ + None + ].expand(b, h, w, 2) + coarse_coords = rearrange(coarse_coords, "b h w d -> b d h w") + coarse_embedded_coords = self.project_to_basis(coarse_coords) + return coarse_embedded_coords + + def forward(self, x, y, **kwargs): + b, c, h1, w1 = x.shape + b, c, h2, w2 = y.shape + f = self.get_pos_enc(y) + b, d, h2, w2 = f.shape + x, y, f = self.reshape(x.float()), self.reshape(y.float()), self.reshape(f) + K_xx = self.K(x, x) + K_yy = self.K(y, y) + K_xy = self.K(x, y) + K_yx = K_xy.permute(0, 2, 1) + sigma_noise = self.sigma_noise * torch.eye(h2 * w2, device=x.device)[None, :, :] + with warnings.catch_warnings(): + K_yy = K_yy + sigma_noise # To increase stability in inverse + with torch.no_grad(): + K_yy_dig_zeromask = ((K_yy[torch.eye(h2 * w2, device=x.device, dtype=torch.bool).repeat(b, 1, 1)] == 0).reshape(b, -1)) + K_yy = K_yy + self.sigma_noise * K_yy_dig_zeromask[..., None] * torch.eye(h2 * w2, device=x.device)[None, :, :] + K_yy_inv = torch.linalg.inv(K_yy) + + mu_x = K_xy.matmul(K_yy_inv.matmul(f)) + mu_x = rearrange(mu_x, "b (h w) d -> b d h w", 
h=h1, w=w1) + if not self.no_cov: + cov_x = K_xx - K_xy.matmul(K_yy_inv.matmul(K_yx)) + cov_x = rearrange(cov_x, "b (h w) (r c) -> b h w r c", h=h1, w=w1, r=h1, c=w1) + local_cov_x = self.get_local_cov(cov_x) + local_cov_x = rearrange(local_cov_x, "b h w K -> b K h w") + gp_feats = torch.cat((mu_x, local_cov_x), dim=1) + else: + gp_feats = mu_x + return gp_feats + +class Decoder(nn.Module): + def __init__( + self, embedding_decoder, gps, proj, conv_refiner, amp, detach=False, scales="all", pos_embeddings = None, + num_refinement_steps_per_scale = 1, warp_noise_std = 0.0, displacement_dropout_p = 0.0, gm_warp_dropout_p = 0.0, + flow_upsample_mode = "bilinear", amp_dtype = torch.float16, + ): + super().__init__() + self.embedding_decoder = embedding_decoder + self.num_refinement_steps_per_scale = num_refinement_steps_per_scale + self.gps = gps + self.proj = proj + self.amp = amp + self.conv_refiner = conv_refiner + self.detach = detach + if pos_embeddings is None: + self.pos_embeddings = {} + else: + self.pos_embeddings = pos_embeddings + if scales == "all": + self.scales = ["32", "16", "8", "4", "2", "1"] + else: + self.scales = scales + self.warp_noise_std = warp_noise_std + self.refine_init = 4 + self.displacement_dropout_p = displacement_dropout_p + self.gm_warp_dropout_p = gm_warp_dropout_p + self.flow_upsample_mode = flow_upsample_mode + self.amp_dtype = amp_dtype + + def get_placeholder_flow(self, b, h, w, device): + coarse_coords = torch.meshgrid( + ( + torch.linspace(-1 + 1 / h, 1 - 1 / h, h, device=device), + torch.linspace(-1 + 1 / w, 1 - 1 / w, w, device=device), + ) + ) + coarse_coords = torch.stack((coarse_coords[1], coarse_coords[0]), dim=-1)[ + None + ].expand(b, h, w, 2) + coarse_coords = rearrange(coarse_coords, "b h w d -> b d h w") + return coarse_coords + + def get_positional_embedding(self, b, h ,w, device): + coarse_coords = torch.meshgrid( + ( + torch.linspace(-1 + 1 / h, 1 - 1 / h, h, device=device), + torch.linspace(-1 + 1 / w, 1 - 1 / w, w, 
device=device), + ) + ) + + coarse_coords = torch.stack((coarse_coords[1], coarse_coords[0]), dim=-1)[ + None + ].expand(b, h, w, 2) + coarse_coords = rearrange(coarse_coords, "b h w d -> b d h w") + coarse_embedded_coords = self.pos_embedding(coarse_coords) + return coarse_embedded_coords + + def forward(self, f1, f2, gt_warp = None, gt_prob = None, upsample = False, flow = None, certainty = None, scale_factor = 1): + coarse_scales = self.embedding_decoder.scales() + all_scales = self.scales if not upsample else ["8", "4", "2", "1"] + sizes = {scale: f1[scale].shape[-2:] for scale in f1} + h, w = sizes[1] + b = f1[1].shape[0] + device = f1[1].device + coarsest_scale = int(all_scales[0]) + old_stuff = torch.zeros( + b, self.embedding_decoder.hidden_dim, *sizes[coarsest_scale], device=f1[coarsest_scale].device + ) + corresps = {} + if not upsample: + flow = self.get_placeholder_flow(b, *sizes[coarsest_scale], device) + certainty = 0.0 + else: + flow = F.interpolate( + flow, + size=sizes[coarsest_scale], + align_corners=False, + mode="bilinear", + ) + certainty = F.interpolate( + certainty, + size=sizes[coarsest_scale], + align_corners=False, + mode="bilinear", + ) + displacement = 0.0 + for new_scale in all_scales: + ins = int(new_scale) + corresps[ins] = {} + f1_s, f2_s = f1[ins], f2[ins] + if new_scale in self.proj: + with torch.autocast("cuda", enabled=self.amp, dtype = self.amp_dtype): + f1_s, f2_s = self.proj[new_scale](f1_s), self.proj[new_scale](f2_s) + + if ins in coarse_scales: + old_stuff = F.interpolate( + old_stuff, size=sizes[ins], mode="bilinear", align_corners=False + ) + gp_posterior = self.gps[new_scale](f1_s, f2_s) + gm_warp_or_cls, certainty, old_stuff = self.embedding_decoder( + gp_posterior, f1_s, old_stuff, new_scale + ) + + if self.embedding_decoder.is_classifier: + flow = cls_to_flow_refine( + gm_warp_or_cls, + ).permute(0,3,1,2) + corresps[ins].update({"gm_cls": gm_warp_or_cls,"gm_certainty": certainty,}) + else: + 
corresps[ins].update({"gm_flow": gm_warp_or_cls,"gm_certainty": certainty,}) + flow = gm_warp_or_cls.detach() + + if new_scale in self.conv_refiner: + corresps[ins].update({"flow_pre_delta": flow}) if self.training else None + delta_flow, delta_certainty = self.conv_refiner[new_scale]( + f1_s, f2_s, flow, scale_factor = scale_factor, logits = certainty, + ) + corresps[ins].update({"delta_flow": delta_flow,}) if self.training else None + displacement = ins*torch.stack((delta_flow[:, 0].float() / (self.refine_init * w), + delta_flow[:, 1].float() / (self.refine_init * h),),dim=1,) + flow = flow + displacement + certainty = ( + certainty + delta_certainty + ) # predict both certainty and displacement + corresps[ins].update({ + "certainty": certainty, + "flow": flow, + }) + if new_scale != "1": + flow = F.interpolate( + flow, + size=sizes[ins // 2], + mode=self.flow_upsample_mode, + ) + certainty = F.interpolate( + certainty, + size=sizes[ins // 2], + mode=self.flow_upsample_mode, + ) + if self.detach: + flow = flow.detach() + certainty = certainty.detach() + #torch.cuda.empty_cache() + return corresps + + +class RegressionMatcher(nn.Module): + def __init__( + self, + encoder, + decoder, + h=448, + w=448, + sample_mode = "threshold", + upsample_preds = False, + symmetric = False, + name = None, + attenuate_cert = None, + recrop_upsample = False, + ): + super().__init__() + self.attenuate_cert = attenuate_cert + self.encoder = encoder + self.decoder = decoder + self.name = name + self.w_resized = w + self.h_resized = h + self.og_transforms = get_tuple_transform_ops(resize=None, normalize=True) + self.sample_mode = sample_mode + self.upsample_preds = upsample_preds + self.upsample_res = (14*16*6, 14*16*6) + self.symmetric = symmetric + self.sample_thresh = 0.05 + self.recrop_upsample = recrop_upsample + + def get_output_resolution(self): + if not self.upsample_preds: + return self.h_resized, self.w_resized + else: + return self.upsample_res + + def 
extract_backbone_features(self, batch, batched = True, upsample = False): + x_q = batch["im_A"] + x_s = batch["im_B"] + if batched: + X = torch.cat((x_q, x_s), dim = 0) + feature_pyramid = self.encoder(X, upsample = upsample) + else: + feature_pyramid = self.encoder(x_q, upsample = upsample), self.encoder(x_s, upsample = upsample) + return feature_pyramid + + def sample( + self, + matches, + certainty, + num=10000, + ): + if "threshold" in self.sample_mode: + upper_thresh = self.sample_thresh + certainty = certainty.clone() + certainty[certainty > upper_thresh] = 1 + matches, certainty = ( + matches.reshape(-1, 4), + certainty.reshape(-1), + ) + expansion_factor = 4 if "balanced" in self.sample_mode else 1 + + if certainty.sum() == 0: + certainty[0] = 1 # Corner case, to avoid following multinormal error + try: + good_samples = torch.multinomial(certainty, + num_samples = min(expansion_factor*num, len(certainty)), + replacement=False) + except: + return matches[[0]], certainty[[0]] + good_matches, good_certainty = matches[good_samples], certainty[good_samples] + if "balanced" not in self.sample_mode: + return good_matches, good_certainty + density = kde(good_matches, std=0.1) + p = 1 / (density+1) + p[density < 10] = 1e-7 # Basically should have at least 10 perfect neighbours, or around 100 ok ones + balanced_samples = torch.multinomial(p, + num_samples = min(num,len(good_certainty)), + replacement=False) + return good_matches[balanced_samples], good_certainty[balanced_samples] + + def forward(self, batch, batched = True, upsample = False, scale_factor = 1): + feature_pyramid = self.extract_backbone_features(batch, batched=batched, upsample = upsample) + if batched: + f_q_pyramid = { + scale: f_scale.chunk(2)[0] for scale, f_scale in feature_pyramid.items() + } + f_s_pyramid = { + scale: f_scale.chunk(2)[1] for scale, f_scale in feature_pyramid.items() + } + else: + f_q_pyramid, f_s_pyramid = feature_pyramid + corresps = self.decoder(f_q_pyramid, + f_s_pyramid, + 
upsample = upsample, + **(batch["corresps"] if "corresps" in batch else {}), + scale_factor=scale_factor) + + return corresps + + def forward_symmetric(self, batch, batched = True, upsample = False, scale_factor = 1): + feature_pyramid = self.extract_backbone_features(batch, batched = batched, upsample = upsample) + f_q_pyramid = feature_pyramid + f_s_pyramid = { + scale: torch.cat((f_scale.chunk(2)[1], f_scale.chunk(2)[0]), dim = 0) + for scale, f_scale in feature_pyramid.items() + } + corresps = self.decoder(f_q_pyramid, + f_s_pyramid, + upsample = upsample, + **(batch["corresps"] if "corresps" in batch else {}), + scale_factor=scale_factor) + return corresps + + def to_pixel_coordinates(self, coords, H_A, W_A, H_B, W_B): + if isinstance(coords, (list, tuple)): + kpts_A, kpts_B = coords[0], coords[1] + else: + kpts_A, kpts_B = coords[...,:2], coords[...,2:] + kpts_A = torch.stack((W_A/2 * (kpts_A[...,0]+1), H_A/2 * (kpts_A[...,1]+1)),axis=-1) + kpts_B = torch.stack((W_B/2 * (kpts_B[...,0]+1), H_B/2 * (kpts_B[...,1]+1)),axis=-1) + return kpts_A, kpts_B + + def to_normalized_coordinates(self, coords, H_A, W_A, H_B, W_B): + if isinstance(coords, (list, tuple)): + kpts_A, kpts_B = coords[0], coords[1] + else: + kpts_A, kpts_B = coords[...,:2], coords[...,2:] + kpts_A = torch.stack((2/W_A * kpts_A[...,0] - 1, 2/H_A * kpts_A[...,1] - 1),axis=-1) + kpts_B = torch.stack((2/W_B * kpts_B[...,0] - 1, 2/H_B * kpts_B[...,1] - 1),axis=-1) + return kpts_A, kpts_B + + def match_keypoints(self, x_A, x_B, warp, certainty, return_tuple = True, return_inds = False): + x_A_to_B = F.grid_sample(warp[...,-2:].permute(2,0,1)[None], x_A[None,None], align_corners = False, mode = "bilinear")[0,:,0].mT + cert_A_to_B = F.grid_sample(certainty[None,None,...], x_A[None,None], align_corners = False, mode = "bilinear")[0,0,0] + D = torch.cdist(x_A_to_B, x_B) + inds_A, inds_B = torch.nonzero((D == D.min(dim=-1, keepdim = True).values) * (D == D.min(dim=-2, keepdim = True).values) * 
(cert_A_to_B[:,None] > self.sample_thresh), as_tuple = True) + + if return_tuple: + if return_inds: + return inds_A, inds_B + else: + return x_A[inds_A], x_B[inds_B] + else: + if return_inds: + return torch.cat((inds_A, inds_B),dim=-1) + else: + return torch.cat((x_A[inds_A], x_B[inds_B]),dim=-1) + + def warp_keypoints(self, x_A, warp, certainty, H_A, W_A, H_B, W_B): + H,W2,_ = warp.shape + W = W2//2 if self.symmetric else W2 + # To normalized coords: + x_A_norm = torch.stack((2/W_A * x_A[...,0] - 1, 2/H_A * x_A[...,1] - 1),axis=-1) + x_A_to_B = F.grid_sample(warp[:,:W, 2:].permute(2,0,1)[None], x_A_norm[None,None], align_corners = False, mode = "bilinear")[0,:,0].mT + cert_A_to_B = F.grid_sample(certainty[None,None,:,:W], x_A_norm[None,None], align_corners = False, mode = "bilinear")[0,0,0] + + # To origin coords: + x_A_to_B = torch.stack((W_B/2 * (x_A_to_B[...,0]+1), H_B/2 * (x_A_to_B[...,1]+1)),axis=-1) + return x_A_to_B, cert_A_to_B + + def get_roi(self, certainty, W, H, thr = 0.025): + raise NotImplementedError("WIP, disable for now") + hs,ws = certainty.shape + certainty = certainty/certainty.sum(dim=(-1,-2)) + cum_certainty_w = certainty.cumsum(dim=-1).sum(dim=-2) + cum_certainty_h = certainty.cumsum(dim=-2).sum(dim=-1) + print(cum_certainty_w) + print(torch.min(torch.nonzero(cum_certainty_w > thr))) + print(torch.min(torch.nonzero(cum_certainty_w < thr))) + left = int(W/ws * torch.min(torch.nonzero(cum_certainty_w > thr))) + right = int(W/ws * torch.max(torch.nonzero(cum_certainty_w < 1 - thr))) + top = int(H/hs * torch.min(torch.nonzero(cum_certainty_h > thr))) + bottom = int(H/hs * torch.max(torch.nonzero(cum_certainty_h < 1 - thr))) + print(left, right, top, bottom) + return left, top, right, bottom + + def recrop(self, certainty, image_path): + roi = self.get_roi(certainty, *Image.open(image_path).size) + return Image.open(image_path).crop(roi) + + @torch.no_grad() + def self_train_time_match( + self, + data, + corresps, + finest_scale=1, + ): + B, C, 
hs, ws = data['image0'].shape + device = data['image0'].device + im_A_to_im_B = corresps[finest_scale]["flow"] + certainty = corresps[finest_scale]["certainty"] + if finest_scale != 1: + im_A_to_im_B = F.interpolate( + im_A_to_im_B, size=(hs, ws), align_corners=False, mode="bilinear" + ) + certainty = F.interpolate( + certainty, size=(hs, ws), align_corners=False, mode="bilinear" + ) + im_A_to_im_B = im_A_to_im_B.permute( + 0, 2, 3, 1 + ) + # Create im_A meshgrid + im_A_coords = torch.meshgrid( + ( + torch.linspace(-1 + 1 / hs, 1 - 1 / hs, hs, device=device), + torch.linspace(-1 + 1 / ws, 1 - 1 / ws, ws, device=device), + ) + ) + im_A_coords = torch.stack((im_A_coords[1], im_A_coords[0])) + im_A_coords = im_A_coords[None].expand(B, 2, hs, ws) + certainty = certainty.sigmoid() # logits -> probs + im_A_coords = im_A_coords.permute(0, 2, 3, 1) + if (im_A_to_im_B.abs() > 1).any() and True: + wrong = (im_A_to_im_B.abs() > 1).sum(dim=-1) > 0 + certainty[wrong[:,None]] = 0 + im_A_to_im_B = torch.clamp(im_A_to_im_B, -1, 1) + warp = torch.cat((im_A_coords, im_A_to_im_B), dim=-1) + return ( + warp, + certainty[:, 0] + ) + + @torch.no_grad() + def self_inference_time_match( + self, + im_A_org, + im_B_org, + device = None, + resize_by_stretch=False, + norm_img=False, + ): + if isinstance(im_A_org, (str, os.PathLike)): + im_A_org = torch.from_numpy(np.array(Image.open(im_A_org).convert("RGB"))).permute(2,0,1) / 255. + im_B_org = torch.from_numpy(np.array(Image.open(im_B_org).convert("RGB"))).permute(2,0,1) / 255. 
+ + symmetric = self.symmetric + self.train(False) + with torch.no_grad(): + b = 1 + # Get images in good format + assert self.w_resized == self.h_resized + hs, ws = self.h_resized, self.w_resized + if resize_by_stretch: + im_A = resize_by_longest_edge_and_stretch(im_A_org, hs) + im_B = resize_by_longest_edge_and_stretch(im_B_org, hs) + else: + im_A = resize_by_longest_edge_and_padding(im_A_org, hs) + im_B = resize_by_longest_edge_and_padding(im_B_org, hs) + + if norm_img: + im_A = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(im_A) # Input: 3*H*W + im_B = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(im_B) # Input: 3*H*W + + if device is None: + batch = {"im_A": im_A[None], "im_B": im_B[None]} + else: + batch = {"im_A": im_A[None].to(device), "im_B": im_B[None].to(device)} + + finest_scale = 1 + # Run matcher + if symmetric: + corresps = self.forward_symmetric(batch) + else: + corresps = self.forward(batch, batched = True) + + if self.upsample_preds: + hs, ws = self.upsample_res + + if self.attenuate_cert: + low_res_certainty = F.interpolate( + corresps[16]["certainty"], size=(hs, ws), align_corners=False, mode="bilinear" + ) + cert_clamp = 0 + factor = 0.5 + low_res_certainty = factor*low_res_certainty*(low_res_certainty < cert_clamp) + + if self.upsample_preds: + finest_corresps = corresps[finest_scale] + + assert hs == ws + if resize_by_stretch: + im_A = resize_by_longest_edge_and_stretch(im_A_org, hs) + im_B = resize_by_longest_edge_and_stretch(im_B_org, hs) + else: + im_A = resize_by_longest_edge_and_padding(im_A_org, hs) + im_B = resize_by_longest_edge_and_padding(im_B_org, hs) + + if norm_img: + im_A = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(im_A) # Input: 3*H*W + im_B = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(im_B) # Input: 3*H*W + + if device is None: + im_A, im_B = im_A[None], im_B[None] + else: + im_A, im_B = im_A[None].to(device), im_B[None].to(device) + 
scale_factor = math.sqrt(self.upsample_res[0] * self.upsample_res[1] / (self.w_resized * self.h_resized)) + batch = {"im_A": im_A, "im_B": im_B, "corresps": finest_corresps} + if symmetric: + corresps = self.forward_symmetric(batch, upsample = True, batched=True, scale_factor = scale_factor) + else: + corresps = self.forward(batch, batched = True, upsample=True, scale_factor = scale_factor) + + im_A_to_im_B = corresps[finest_scale]["flow"] + certainty = corresps[finest_scale]["certainty"] - (low_res_certainty if self.attenuate_cert else 0) + if finest_scale != 1: + im_A_to_im_B = F.interpolate( + im_A_to_im_B, size=(hs, ws), align_corners=False, mode="bilinear" + ) + certainty = F.interpolate( + certainty, size=(hs, ws), align_corners=False, mode="bilinear" + ) + im_A_to_im_B = im_A_to_im_B.permute( + 0, 2, 3, 1 + ) + # Create im_A meshgrid + im_A_coords = torch.meshgrid( + ( + torch.linspace(-1 + 1 / hs, 1 - 1 / hs, hs, device=im_A.device), + torch.linspace(-1 + 1 / ws, 1 - 1 / ws, ws, device=im_A.device), + ) + ) + im_A_coords = torch.stack((im_A_coords[1], im_A_coords[0])) + im_A_coords = im_A_coords[None].expand(b, 2, hs, ws) + certainty = certainty.sigmoid() # logits -> probs + im_A_coords = im_A_coords.permute(0, 2, 3, 1) + if (im_A_to_im_B.abs() > 1).any() and True: + wrong = (im_A_to_im_B.abs() > 1).sum(dim=-1) > 0 + certainty[wrong[:,None]] = 0 + im_A_to_im_B = torch.clamp(im_A_to_im_B, -1, 1) + if symmetric: + A_to_B, B_to_A = im_A_to_im_B.chunk(2) + q_warp = torch.cat((im_A_coords, A_to_B), dim=-1) + im_B_coords = im_A_coords + s_warp = torch.cat((B_to_A, im_B_coords), dim=-1) + warp = torch.cat((q_warp, s_warp),dim=2) + certainty = torch.cat(certainty.chunk(2), dim=3) + else: + warp = torch.cat((im_A_coords, im_A_to_im_B), dim=-1) + return ( + warp[0], + certainty[0, 0], + ) + + @torch.inference_mode() + def match( + self, + im_A_path, + im_B_path, + *args, + batched=False, + device = None, + ): + if isinstance(im_A_path, (str, os.PathLike)): + im_A, 
im_B = Image.open(im_A_path).convert("RGB"), Image.open(im_B_path).convert("RGB") + else: + # Assume its not a path + im_A, im_B = im_A_path, im_B_path + symmetric = self.symmetric + self.train(False) + with torch.no_grad(): + if not batched: + b = 1 + if isinstance(im_A, torch.Tensor): + h, w = im_A.shape[-2:] + h2, w2 = im_B.shape[-2:] + else: + w, h = im_A.size + w2, h2 = im_B.size + # Get images in good format + ws = self.w_resized + hs = self.h_resized + + test_transform = get_tuple_transform_ops( + resize=(hs, ws), normalize=True, clahe = False + ) + im_A, im_B = test_transform((im_A, im_B)) + if device is None: + batch = {"im_A": im_A[None], "im_B": im_B[None]} + else: + batch = {"im_A": im_A[None].to(device), "im_B": im_B[None].to(device)} + else: + b, c, h, w = im_A.shape + b, c, h2, w2 = im_B.shape + assert w == w2 and h == h2, "For batched images we assume same size" + if device is None: + batch = {"im_A": im_A, "im_B": im_B} + else: + batch = {"im_A": im_A.to(device), "im_B": im_B.to(device)} + if h != self.h_resized or self.w_resized != w: + warn("Model resolution and batch resolution differ, may produce unexpected results") + hs, ws = h, w + finest_scale = 1 + # Run matcher + if symmetric: + corresps = self.forward_symmetric(batch) + else: + corresps = self.forward(batch, batched = True) + + if self.upsample_preds: + hs, ws = self.upsample_res + + if self.attenuate_cert: + low_res_certainty = F.interpolate( + corresps[16]["certainty"], size=(hs, ws), align_corners=False, mode="bilinear" + ) + cert_clamp = 0 + factor = 0.5 + low_res_certainty = factor*low_res_certainty*(low_res_certainty < cert_clamp) + + if self.upsample_preds: + finest_corresps = corresps[finest_scale] + test_transform = get_tuple_transform_ops( + resize=(hs, ws), normalize=True + ) + if self.recrop_upsample: + certainty = corresps[finest_scale]["certainty"] + print(certainty.shape) + im_A = self.recrop(certainty[0,0], im_A_path) + im_B = self.recrop(certainty[1,0], im_B_path) + 
#TODO: need to adjust corresps when doing this + else: + if isinstance(im_A_path, (str, os.PathLike)): + im_A, im_B = Image.open(im_A_path).convert("RGB"), Image.open(im_B_path).convert("RGB") + else: + # Assume its not a path + im_A, im_B = im_A_path, im_B_path + + im_A, im_B = test_transform((im_A, im_B)) + if device is None: + im_A, im_B = im_A[None], im_B[None] + else: + im_A, im_B = im_A[None].to(device), im_B[None].to(device) + scale_factor = math.sqrt(self.upsample_res[0] * self.upsample_res[1] / (self.w_resized * self.h_resized)) + batch = {"im_A": im_A, "im_B": im_B, "corresps": finest_corresps} + if symmetric: + corresps = self.forward_symmetric(batch, upsample = True, batched=True, scale_factor = scale_factor) + else: + corresps = self.forward(batch, batched = True, upsample=True, scale_factor = scale_factor) + + im_A_to_im_B = corresps[finest_scale]["flow"] + certainty = corresps[finest_scale]["certainty"] - (low_res_certainty if self.attenuate_cert else 0) + if finest_scale != 1: + im_A_to_im_B = F.interpolate( + im_A_to_im_B, size=(hs, ws), align_corners=False, mode="bilinear" + ) + certainty = F.interpolate( + certainty, size=(hs, ws), align_corners=False, mode="bilinear" + ) + im_A_to_im_B = im_A_to_im_B.permute( + 0, 2, 3, 1 + ) + # Create im_A meshgrid + im_A_coords = torch.meshgrid( + ( + torch.linspace(-1 + 1 / hs, 1 - 1 / hs, hs, device=im_A.device), + torch.linspace(-1 + 1 / ws, 1 - 1 / ws, ws, device=im_A.device), + ) + ) + im_A_coords = torch.stack((im_A_coords[1], im_A_coords[0])) + im_A_coords = im_A_coords[None].expand(b, 2, hs, ws) + certainty = certainty.sigmoid() # logits -> probs + im_A_coords = im_A_coords.permute(0, 2, 3, 1) + if (im_A_to_im_B.abs() > 1).any() and True: + wrong = (im_A_to_im_B.abs() > 1).sum(dim=-1) > 0 + certainty[wrong[:,None]] = 0 + im_A_to_im_B = torch.clamp(im_A_to_im_B, -1, 1) + if symmetric: + A_to_B, B_to_A = im_A_to_im_B.chunk(2) + q_warp = torch.cat((im_A_coords, A_to_B), dim=-1) + im_B_coords = 
im_A_coords + s_warp = torch.cat((B_to_A, im_B_coords), dim=-1) + warp = torch.cat((q_warp, s_warp),dim=2) + certainty = torch.cat(certainty.chunk(2), dim=3) + else: + warp = torch.cat((im_A_coords, im_A_to_im_B), dim=-1) + if batched: + return ( + warp, + certainty[:, 0] + ) + else: + return ( + warp[0], + certainty[0, 0], + ) + + def visualize_warp(self, warp, certainty, im_A = None, im_B = None, im_A_path = None, im_B_path = None, device = "cuda", symmetric = True, save_path = None): + assert symmetric == True, "Currently assuming bidirectional warp, might update this if someone complains ;)" + H,W2,_ = warp.shape + W = W2//2 if symmetric else W2 + if im_A is None: + from PIL import Image + im_A, im_B = Image.open(im_A_path), Image.open(im_B_path) + im_A = im_A.resize((W,H)) + im_B = im_B.resize((W,H)) + + x_A = (torch.tensor(np.array(im_A)) / 255).to(device).permute(2, 0, 1) + x_B = (torch.tensor(np.array(im_B)) / 255).to(device).permute(2, 0, 1) + + im_A_transfer_rgb = F.grid_sample( + x_B[None], warp[:,:W, 2:][None], mode="bilinear", align_corners=False + )[0] + im_B_transfer_rgb = F.grid_sample( + x_A[None], warp[:, W:, :2][None], mode="bilinear", align_corners=False + )[0] + warp_im = torch.cat((im_A_transfer_rgb,im_B_transfer_rgb),dim=2) + white_im = torch.ones((H,2*W),device=device) + vis_im = certainty * warp_im + (1 - certainty) * white_im + if save_path is not None: + from roma.utils import tensor_to_pil + tensor_to_pil(vis_im, unnormalize=False).save(save_path) + return vis_im \ No newline at end of file diff --git a/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/model_zoo/__init__.py b/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/model_zoo/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..20e8da481880da376c6d4653770bd4ca1e814034 --- /dev/null +++ b/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/model_zoo/__init__.py @@ -0,0 +1,53 @@ +from typing import Union +import torch 
+from .roma_models import roma_model + +weight_urls = { + "roma": { + "outdoor": "https://github.com/Parskatt/storage/releases/download/roma/roma_outdoor.pth", + "indoor": "https://github.com/Parskatt/storage/releases/download/roma/roma_indoor.pth", + }, + "dinov2": "https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_pretrain.pth", #hopefully this doesnt change :D +} + +def roma_outdoor(device, weights=None, dinov2_weights=None, coarse_res = 560, upsample_res = 864, upsample_preds = True, symmetric=True, attenuate_cert=True): + if isinstance(coarse_res, int): + coarse_res = (coarse_res, coarse_res) + if isinstance(upsample_res, int): + upsample_res = (upsample_res, upsample_res) + + assert coarse_res[0] % 14 == 0, "Needs to be multiple of 14 for backbone" + assert coarse_res[1] % 14 == 0, "Needs to be multiple of 14 for backbone" + + if weights is None: + weights = torch.hub.load_state_dict_from_url(weight_urls["roma"]["outdoor"], + map_location=device) + if dinov2_weights is None: + dinov2_weights = torch.hub.load_state_dict_from_url(weight_urls["dinov2"], + map_location=device) + model = roma_model(resolution=coarse_res, upsample_preds=upsample_preds, + weights=weights,dinov2_weights = dinov2_weights,device=device, symmetric=symmetric, attenuate_cert=attenuate_cert) + model.upsample_res = upsample_res + print(f"Using coarse resolution {coarse_res}, and upsample res {model.upsample_res}") + return model + +def roma_indoor(device, weights=None, dinov2_weights=None, coarse_res = 560, upsample_res = 864): + if isinstance(coarse_res, int): + coarse_res = (coarse_res, coarse_res) + if isinstance(upsample_res, int): + upsample_res = (upsample_res, upsample_res) + + assert coarse_res[0] % 14 == 0, "Needs to be multiple of 14 for backbone" + assert coarse_res[1] % 14 == 0, "Needs to be multiple of 14 for backbone" + + if weights is None: + weights = torch.hub.load_state_dict_from_url(weight_urls["roma"]["indoor"], + map_location=device) + if dinov2_weights 
is None: + dinov2_weights = torch.hub.load_state_dict_from_url(weight_urls["dinov2"], + map_location=device) + model = roma_model(resolution=coarse_res, upsample_preds=True, + weights=weights,dinov2_weights = dinov2_weights,device=device) + model.upsample_res = upsample_res + print(f"Using coarse resolution {coarse_res}, and upsample res {model.upsample_res}") + return model diff --git a/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/model_zoo/roma_models.py b/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/model_zoo/roma_models.py new file mode 100644 index 0000000000000000000000000000000000000000..cb4e66952d58b78cba7c58ab84f91a3d0bd02c33 --- /dev/null +++ b/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/model_zoo/roma_models.py @@ -0,0 +1,162 @@ +import warnings +import torch.nn as nn + +from pathlib import Path +import sys +sys.path.append(str(Path(__file__).parent.parent.parent.parent.resolve())) +from roma.models.matcher import * +from roma.models.transformer import Block, TransformerDecoder, MemEffAttention +from roma.models.encoders import * + +def roma_model(resolution, upsample_preds, device = None, weights=None, dinov2_weights=None, symmetric=True, attenuate_cert=True, **kwargs): + # roma weights and dinov2 weights are loaded seperately, as dinov2 weights are not parameters + torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul + torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn + warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated') + gp_dim = 512 + feat_dim = 512 + decoder_dim = gp_dim + feat_dim + cls_to_coord_res = 64 + coordinate_decoder = TransformerDecoder( + nn.Sequential(*[Block(decoder_dim, 8, attn_class=MemEffAttention) for _ in range(5)]), + decoder_dim, + cls_to_coord_res**2 + 1, + is_classifier=True, + amp = True, + pos_enc = False,) + dw = True + hidden_blocks = 8 + kernel_size = 5 + displacement_emb = "linear" + 
disable_local_corr_grad = True + + conv_refiner = nn.ModuleDict( + { + "16": ConvRefiner( + 2 * 512+128+(2*7+1)**2, + 2 * 512+128+(2*7+1)**2, + 2 + 1, + kernel_size=kernel_size, + dw=dw, + hidden_blocks=hidden_blocks, + displacement_emb=displacement_emb, + displacement_emb_dim=128, + local_corr_radius = 7, + corr_in_other = True, + amp = True, + disable_local_corr_grad = disable_local_corr_grad, + bn_momentum = 0.01, + ), + "8": ConvRefiner( + 2 * 512+64+(2*3+1)**2, + 2 * 512+64+(2*3+1)**2, + 2 + 1, + kernel_size=kernel_size, + dw=dw, + hidden_blocks=hidden_blocks, + displacement_emb=displacement_emb, + displacement_emb_dim=64, + local_corr_radius = 3, + corr_in_other = True, + amp = True, + disable_local_corr_grad = disable_local_corr_grad, + bn_momentum = 0.01, + ), + "4": ConvRefiner( + 2 * 256+32+(2*2+1)**2, + 2 * 256+32+(2*2+1)**2, + 2 + 1, + kernel_size=kernel_size, + dw=dw, + hidden_blocks=hidden_blocks, + displacement_emb=displacement_emb, + displacement_emb_dim=32, + local_corr_radius = 2, + corr_in_other = True, + amp = True, + disable_local_corr_grad = disable_local_corr_grad, + bn_momentum = 0.01, + ), + "2": ConvRefiner( + 2 * 64+16, + 128+16, + 2 + 1, + kernel_size=kernel_size, + dw=dw, + hidden_blocks=hidden_blocks, + displacement_emb=displacement_emb, + displacement_emb_dim=16, + amp = True, + disable_local_corr_grad = disable_local_corr_grad, + bn_momentum = 0.01, + ), + "1": ConvRefiner( + 2 * 9 + 6, + 24, + 2 + 1, + kernel_size=kernel_size, + dw=dw, + hidden_blocks = hidden_blocks, + displacement_emb = displacement_emb, + displacement_emb_dim = 6, + amp = True, + disable_local_corr_grad = disable_local_corr_grad, + bn_momentum = 0.01, + ), + } + ) + kernel_temperature = 0.2 + learn_temperature = False + no_cov = True + kernel = CosKernel + only_attention = False + basis = "fourier" + gp16 = GP( + kernel, + T=kernel_temperature, + learn_temperature=learn_temperature, + only_attention=only_attention, + gp_dim=gp_dim, + basis=basis, + no_cov=no_cov, 
+ ) + gps = nn.ModuleDict({"16": gp16}) + proj16 = nn.Sequential(nn.Conv2d(1024, 512, 1, 1), nn.BatchNorm2d(512)) + proj8 = nn.Sequential(nn.Conv2d(512, 512, 1, 1), nn.BatchNorm2d(512)) + proj4 = nn.Sequential(nn.Conv2d(256, 256, 1, 1), nn.BatchNorm2d(256)) + proj2 = nn.Sequential(nn.Conv2d(128, 64, 1, 1), nn.BatchNorm2d(64)) + proj1 = nn.Sequential(nn.Conv2d(64, 9, 1, 1), nn.BatchNorm2d(9)) + proj = nn.ModuleDict({ + "16": proj16, + "8": proj8, + "4": proj4, + "2": proj2, + "1": proj1, + }) + displacement_dropout_p = 0.0 + gm_warp_dropout_p = 0.0 + decoder = Decoder(coordinate_decoder, + gps, + proj, + conv_refiner, + amp=True, + detach=True, + scales=["16", "8", "4", "2", "1"], + displacement_dropout_p = displacement_dropout_p, + gm_warp_dropout_p = gm_warp_dropout_p) + + encoder = CNNandDinov2( + cnn_kwargs = dict( + pretrained=False, + amp = True), + amp = True, + use_vgg = True, + dinov2_weights = dinov2_weights + ) + h,w = resolution + symmetric = symmetric + attenuate_cert = attenuate_cert + matcher = RegressionMatcher(encoder, decoder, h=h, w=w, upsample_preds=upsample_preds, + symmetric = symmetric, attenuate_cert=attenuate_cert, **kwargs).to(device) + matcher.load_state_dict(weights) + return matcher diff --git a/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/transformer/__init__.py b/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/transformer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3f7a3ccf92e70755817c98a3282f6d8769e32e63 --- /dev/null +++ b/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/transformer/__init__.py @@ -0,0 +1,47 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from roma.utils.utils import get_grid +from .layers.block import Block +from .layers.attention import MemEffAttention +from .dinov2 import vit_large, vit_base, vit_small + +class TransformerDecoder(nn.Module): + def __init__(self, blocks, hidden_dim, out_dim, is_classifier = 
False, *args, + amp = False, pos_enc = True, learned_embeddings = False, embedding_dim = None, amp_dtype = torch.float16, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.blocks = blocks + self.to_out = nn.Linear(hidden_dim, out_dim) + self.hidden_dim = hidden_dim + self.out_dim = out_dim + self._scales = [16] + self.is_classifier = is_classifier + self.amp = amp + self.amp_dtype = amp_dtype + self.pos_enc = pos_enc + self.learned_embeddings = learned_embeddings + if self.learned_embeddings: + self.learned_pos_embeddings = nn.Parameter(nn.init.kaiming_normal_(torch.empty((1, hidden_dim, embedding_dim, embedding_dim)))) + + def scales(self): + return self._scales.copy() + + def forward(self, gp_posterior, features, old_stuff, new_scale): + with torch.autocast("cuda", dtype=self.amp_dtype, enabled=self.amp): + B,C,H,W = gp_posterior.shape + x = torch.cat((gp_posterior, features), dim = 1) + B,C,H,W = x.shape + grid = get_grid(B, H, W, x.device).reshape(B,H*W,2) + if self.learned_embeddings: + pos_enc = F.interpolate(self.learned_pos_embeddings, size = (H,W), mode = 'bilinear', align_corners = False).permute(0,2,3,1).reshape(1,H*W,C) + else: + pos_enc = 0 + tokens = x.reshape(B,C,H*W).permute(0,2,1) + pos_enc + z = self.blocks(tokens) + out = self.to_out(z) + out = out.permute(0,2,1).reshape(B, self.out_dim, H, W) + warp, certainty = out[:, :-1], out[:, -1:] + return warp, certainty, None + + diff --git a/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/transformer/dinov2.py b/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/transformer/dinov2.py new file mode 100644 index 0000000000000000000000000000000000000000..b556c63096d17239c8603d5fe626c331963099fd --- /dev/null +++ b/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/transformer/dinov2.py @@ -0,0 +1,359 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/main/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py + +from functools import partial +import math +import logging +from typing import Sequence, Tuple, Union, Callable + +import torch +import torch.nn as nn +import torch.utils.checkpoint +from torch.nn.init import trunc_normal_ + +from .layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block + + + +def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module: + if not depth_first and include_root: + fn(module=module, name=name) + for child_name, child_module in module.named_children(): + child_name = ".".join((name, child_name)) if name else child_name + named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True) + if depth_first and include_root: + fn(module=module, name=name) + return module + + +class BlockChunk(nn.ModuleList): + def forward(self, x): + for b in self: + x = b(x) + return x + + +class DinoVisionTransformer(nn.Module): + def __init__( + self, + img_size=224, + patch_size=16, + in_chans=3, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4.0, + qkv_bias=True, + ffn_bias=True, + proj_bias=True, + drop_path_rate=0.0, + drop_path_uniform=False, + init_values=None, # for layerscale: None or 0 => no layerscale + embed_layer=PatchEmbed, + act_layer=nn.GELU, + block_fn=Block, + ffn_layer="mlp", + block_chunks=1, + ): + """ + Args: + img_size (int, tuple): input image size + patch_size (int, tuple): patch size + in_chans (int): number of input channels + embed_dim (int): embedding dimension + depth (int): depth of transformer + num_heads (int): number of attention heads + mlp_ratio (int): ratio of mlp hidden dim to 
embedding dim + qkv_bias (bool): enable bias for qkv if True + proj_bias (bool): enable bias for proj in attn if True + ffn_bias (bool): enable bias for ffn if True + drop_path_rate (float): stochastic depth rate + drop_path_uniform (bool): apply uniform drop rate across blocks + weight_init (str): weight init scheme + init_values (float): layer-scale init values + embed_layer (nn.Module): patch embedding layer + act_layer (nn.Module): MLP activation layer + block_fn (nn.Module): transformer block class + ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity" + block_chunks: (int) split block sequence into block_chunks units for FSDP wrap + """ + super().__init__() + norm_layer = partial(nn.LayerNorm, eps=1e-6) + + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + self.num_tokens = 1 + self.n_blocks = depth + self.num_heads = num_heads + self.patch_size = patch_size + + self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) + num_patches = self.patch_embed.num_patches + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim)) + + if drop_path_uniform is True: + dpr = [drop_path_rate] * depth + else: + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + + if ffn_layer == "mlp": + ffn_layer = Mlp + elif ffn_layer == "swiglufused" or ffn_layer == "swiglu": + ffn_layer = SwiGLUFFNFused + elif ffn_layer == "identity": + + def f(*args, **kwargs): + return nn.Identity() + + ffn_layer = f + else: + raise NotImplementedError + + blocks_list = [ + block_fn( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + ffn_bias=ffn_bias, + drop_path=dpr[i], + norm_layer=norm_layer, + act_layer=act_layer, + ffn_layer=ffn_layer, + init_values=init_values, + ) + for i in 
range(depth) + ] + if block_chunks > 0: + self.chunked_blocks = True + chunked_blocks = [] + chunksize = depth // block_chunks + for i in range(0, depth, chunksize): + # this is to keep the block index consistent if we chunk the block list + chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize]) + self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks]) + else: + self.chunked_blocks = False + self.blocks = nn.ModuleList(blocks_list) + + self.norm = norm_layer(embed_dim) + self.head = nn.Identity() + + self.mask_token = nn.Parameter(torch.zeros(1, embed_dim)) + + self.init_weights() + for param in self.parameters(): + param.requires_grad = False + + @property + def device(self): + return self.cls_token.device + + def init_weights(self): + trunc_normal_(self.pos_embed, std=0.02) + nn.init.normal_(self.cls_token, std=1e-6) + named_apply(init_weights_vit_timm, self) + + def interpolate_pos_encoding(self, x, w, h): + previous_dtype = x.dtype + npatch = x.shape[1] - 1 + N = self.pos_embed.shape[1] - 1 + if npatch == N and w == h: + return self.pos_embed + pos_embed = self.pos_embed.float() + class_pos_embed = pos_embed[:, 0] + patch_pos_embed = pos_embed[:, 1:] + dim = x.shape[-1] + w0 = w // self.patch_size + h0 = h // self.patch_size + # we add a small number to avoid floating point error in the interpolation + # see discussion at https://github.com/facebookresearch/dino/issues/8 + w0, h0 = w0 + 0.1, h0 + 0.1 + + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2), + scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)), + mode="bicubic", + ) + + assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1] + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype) + + def prepare_tokens_with_masks(self, x, masks=None): + B, 
nc, w, h = x.shape + x = self.patch_embed(x) + if masks is not None: + x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x) + + x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) + x = x + self.interpolate_pos_encoding(x, w, h) + + return x + + def forward_features_list(self, x_list, masks_list): + x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)] + for blk in self.blocks: + x = blk(x) + + all_x = x + output = [] + for x, masks in zip(all_x, masks_list): + x_norm = self.norm(x) + output.append( + { + "x_norm_clstoken": x_norm[:, 0], + "x_norm_patchtokens": x_norm[:, 1:], + "x_prenorm": x, + "masks": masks, + } + ) + return output + + def forward_features(self, x, masks=None): + if isinstance(x, list): + return self.forward_features_list(x, masks) + + x = self.prepare_tokens_with_masks(x, masks) + + for blk in self.blocks: + x = blk(x) + + x_norm = self.norm(x) + return { + "x_norm_clstoken": x_norm[:, 0], + "x_norm_patchtokens": x_norm[:, 1:], + "x_prenorm": x, + "masks": masks, + } + + def _get_intermediate_layers_not_chunked(self, x, n=1): + x = self.prepare_tokens_with_masks(x) + # If n is an int, take the n last blocks. If it's a list, take them + output, total_block_len = [], len(self.blocks) + blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n + for i, blk in enumerate(self.blocks): + x = blk(x) + if i in blocks_to_take: + output.append(x) + assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found" + return output + + def _get_intermediate_layers_chunked(self, x, n=1): + x = self.prepare_tokens_with_masks(x) + output, i, total_block_len = [], 0, len(self.blocks[-1]) + # If n is an int, take the n last blocks. 
If it's a list, take them + blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n + for block_chunk in self.blocks: + for blk in block_chunk[i:]: # Passing the nn.Identity() + x = blk(x) + if i in blocks_to_take: + output.append(x) + i += 1 + assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found" + return output + + def get_intermediate_layers( + self, + x: torch.Tensor, + n: Union[int, Sequence] = 1, # Layers or n last layers to take + reshape: bool = False, + return_class_token: bool = False, + norm=True, + ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]: + if self.chunked_blocks: + outputs = self._get_intermediate_layers_chunked(x, n) + else: + outputs = self._get_intermediate_layers_not_chunked(x, n) + if norm: + outputs = [self.norm(out) for out in outputs] + class_tokens = [out[:, 0] for out in outputs] + outputs = [out[:, 1:] for out in outputs] + if reshape: + B, _, w, h = x.shape + outputs = [ + out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous() + for out in outputs + ] + if return_class_token: + return tuple(zip(outputs, class_tokens)) + return tuple(outputs) + + def forward(self, *args, is_training=False, **kwargs): + ret = self.forward_features(*args, **kwargs) + if is_training: + return ret + else: + return self.head(ret["x_norm_clstoken"]) + + +def init_weights_vit_timm(module: nn.Module, name: str = ""): + """ViT weight initialization, original timm impl (for reproducibility)""" + if isinstance(module, nn.Linear): + trunc_normal_(module.weight, std=0.02) + if module.bias is not None: + nn.init.zeros_(module.bias) + + +def vit_small(patch_size=16, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=384, + depth=12, + num_heads=6, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + **kwargs, + ) + return model + + +def vit_base(patch_size=16, **kwargs): + model = 
DinoVisionTransformer( + patch_size=patch_size, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + **kwargs, + ) + return model + + +def vit_large(patch_size=16, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=1024, + depth=24, + num_heads=16, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + **kwargs, + ) + return model + + +def vit_giant2(patch_size=16, **kwargs): + """ + Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64 + """ + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=1536, + depth=40, + num_heads=24, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + **kwargs, + ) + return model \ No newline at end of file diff --git a/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/transformer/layers/__init__.py b/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/transformer/layers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..31f196aacac5be8a7c537a3dfa8f97084671b466 --- /dev/null +++ b/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/transformer/layers/__init__.py @@ -0,0 +1,12 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+ +from .dino_head import DINOHead +from .mlp import Mlp +from .patch_embed import PatchEmbed +from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused +from .block import NestedTensorBlock +from .attention import MemEffAttention diff --git a/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/transformer/layers/attention.py b/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/transformer/layers/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..1f9b0c94b40967dfdff4f261c127cbd21328c905 --- /dev/null +++ b/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/transformer/layers/attention.py @@ -0,0 +1,81 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py + +import logging + +from torch import Tensor +from torch import nn + + +logger = logging.getLogger("dinov2") + + +try: + from xformers.ops import memory_efficient_attention, unbind, fmha + + XFORMERS_AVAILABLE = True +except ImportError: + logger.warning("xFormers not available") + XFORMERS_AVAILABLE = False + + +class Attention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + proj_bias: bool = True, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + ) -> None: + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim, bias=proj_bias) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x: Tensor) -> Tensor: + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // 
self.num_heads).permute(2, 0, 3, 1, 4) + + q, k, v = qkv[0] * self.scale, qkv[1], qkv[2] + attn = q @ k.transpose(-2, -1) + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class MemEffAttention(Attention): + def forward(self, x: Tensor, attn_bias=None) -> Tensor: + if not XFORMERS_AVAILABLE: + assert attn_bias is None, "xFormers is required for nested tensors usage" + return super().forward(x) + + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads) + + q, k, v = unbind(qkv, 2) + + x = memory_efficient_attention(q, k, v, attn_bias=attn_bias) + x = x.reshape([B, N, C]) + + x = self.proj(x) + x = self.proj_drop(x) + return x diff --git a/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/transformer/layers/block.py b/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/transformer/layers/block.py new file mode 100644 index 0000000000000000000000000000000000000000..25488f57cc0ad3c692f86b62555f6668e2a66db1 --- /dev/null +++ b/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/transformer/layers/block.py @@ -0,0 +1,252 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
# References:
#   https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
#   https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py

import logging
from typing import Callable, List, Any, Tuple, Dict

import torch
from torch import nn, Tensor

from .attention import Attention, MemEffAttention
from .drop_path import DropPath
from .layer_scale import LayerScale
from .mlp import Mlp


logger = logging.getLogger("dinov2")


try:
    from xformers.ops import fmha
    from xformers.ops import scaled_index_add, index_select_cat

    XFORMERS_AVAILABLE = True
except ImportError:
    logger.warning("xFormers not available")
    XFORMERS_AVAILABLE = False


class Block(nn.Module):
    """Pre-norm transformer block with an attention branch and an FFN branch.

    Each branch computes ``layer_scale(op(norm(x)))`` and is added back to its
    input. During training the branches may be dropped stochastically
    (stochastic depth) according to ``drop_path``.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        ffn_bias: bool = True,
        drop: float = 0.0,
        attn_drop: float = 0.0,
        init_values=None,
        drop_path: float = 0.0,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
        attn_class: Callable[..., nn.Module] = Attention,
        ffn_layer: Callable[..., nn.Module] = Mlp,
    ) -> None:
        super().__init__()
        # Attention branch.
        self.norm1 = norm_layer(dim)
        self.attn = attn_class(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            proj_bias=proj_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        # Layer scale is only instantiated when an initial value is given.
        self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        # Feed-forward branch.
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = ffn_layer(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
            bias=ffn_bias,
        )
        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.sample_drop_ratio = drop_path

    def forward(self, x: Tensor) -> Tensor:
        """Apply the attention and FFN residual branches to ``x``."""

        def attn_residual_func(x: Tensor) -> Tensor:
            return self.ls1(self.attn(self.norm1(x)))

        def ffn_residual_func(x: Tensor) -> Tensor:
            return self.ls2(self.mlp(self.norm2(x)))

        if self.training and self.sample_drop_ratio > 0.1:
            # the overhead is compensated only for a drop path rate larger than 0.1
            x = drop_add_residual_stochastic_depth(
                x,
                residual_func=attn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
            )
            x = drop_add_residual_stochastic_depth(
                x,
                residual_func=ffn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
            )
        elif self.training and self.sample_drop_ratio > 0.0:
            x = x + self.drop_path1(attn_residual_func(x))
            # Each branch uses its own drop-path module (was drop_path1 twice).
            x = x + self.drop_path2(ffn_residual_func(x))
        else:
            x = x + attn_residual_func(x)
            x = x + ffn_residual_func(x)
        return x


def drop_add_residual_stochastic_depth(
    x: Tensor,
    residual_func: Callable[[Tensor], Tensor],
    sample_drop_ratio: float = 0.0,
) -> Tensor:
    """Run ``residual_func`` on a random subset of the batch and add it back.

    The residual is scaled by ``batch / subset_size`` to compensate for the
    samples that were skipped.
    """
    # 1) extract subset using permutation
    b, n, d = x.shape
    sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
    brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
    x_subset = x[brange]

    # 2) apply residual_func to get residual
    residual = residual_func(x_subset)

    x_flat = x.flatten(1)
    residual = residual.flatten(1)

    residual_scale_factor = b / sample_subset_size

    # 3) add the residual
    x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
    return x_plus_residual.view_as(x)


def get_branges_scales(x, sample_drop_ratio=0.0):
    """Return random kept-sample indices and the matching residual scale factor."""
    b, n, d = x.shape
    sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
    brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
    residual_scale_factor = b / sample_subset_size
    return brange, residual_scale_factor


def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
    """Add ``residual`` into rows ``brange`` of ``x``, optionally layer-scaled.

    Without ``scaling_vector`` the result is flattened to ``(batch, -1)``; with
    it, the xFormers ``scaled_index_add`` kernel is used.
    """
    if scaling_vector is None:
        x_flat = x.flatten(1)
        residual = residual.flatten(1)
        x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
    else:
        x_plus_residual = scaled_index_add(
            x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor
        )
    return x_plus_residual


# Cache of block-diagonal attention masks keyed by ((batch, seq_len), ...).
attn_bias_cache: Dict[Tuple, Any] = {}


def get_attn_bias_and_cat(x_list, branges=None):
    """
    this will perform the index select, cat the tensors, and provide the attn_bias from cache
    """
    batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list]
    all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
    if all_shapes not in attn_bias_cache:
        # One sequence-length entry per (kept) sample of every tensor.
        seqlens = [x.shape[1] for b, x in zip(batch_sizes, x_list) for _ in range(b)]
        attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
        attn_bias._batch_sizes = batch_sizes
        attn_bias_cache[all_shapes] = attn_bias

    if branges is not None:
        cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1])
    else:
        tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
        cat_tensors = torch.cat(tensors_bs1, dim=1)

    return attn_bias_cache[all_shapes], cat_tensors


def drop_add_residual_stochastic_depth_list(
    x_list: List[Tensor],
    residual_func: Callable[[Tensor, Any], Tensor],
    sample_drop_ratio: float = 0.0,
    scaling_vector=None,
) -> List[Tensor]:
    """List (nested-tensor) variant of ``drop_add_residual_stochastic_depth``."""
    # 1) generate random set of indices for dropping samples in the batch
    branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]
    branges = [s[0] for s in branges_scales]
    residual_scale_factors = [s[1] for s in branges_scales]

    # 2) get attention bias and index+concat the tensors
    attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)

    # 3) apply residual_func to get residual, and split the result
    residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias))  # type: ignore

    outputs = []
    for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors):
        outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x))
    return outputs


class NestedTensorBlock(Block):
    """``Block`` that also accepts a list of tensors processed as one packed batch."""

    def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
        """
        x_list contains a list of tensors to nest together and run
        """
        assert isinstance(self.attn, MemEffAttention)

        if self.training and self.sample_drop_ratio > 0.0:

            def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.attn(self.norm1(x), attn_bias=attn_bias)

            def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.mlp(self.norm2(x))

            # Layer scale is applied inside add_residual via scaling_vector here.
            x_list = drop_add_residual_stochastic_depth_list(
                x_list,
                residual_func=attn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
                scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None,
            )
            x_list = drop_add_residual_stochastic_depth_list(
                x_list,
                residual_func=ffn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
                # Check ls2 itself (the original tested ls1 here).
                scaling_vector=self.ls2.gamma if isinstance(self.ls2, LayerScale) else None,
            )
            return x_list
        else:

            def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))

            def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.ls2(self.mlp(self.norm2(x)))

            attn_bias, x = get_attn_bias_and_cat(x_list)
            x = x + attn_residual_func(x, attn_bias=attn_bias)
            x = x + ffn_residual_func(x)
            return attn_bias.split(x)

    def forward(self, x_or_x_list):
        """Dispatch to the dense path for a Tensor or the nested path for a list."""
        if isinstance(x_or_x_list, Tensor):
            return super().forward(x_or_x_list)
        elif isinstance(x_or_x_list, list):
            assert XFORMERS_AVAILABLE, "Please install xFormers for nested tensors usage"
            return self.forward_nested(x_or_x_list)
        else:
            raise AssertionError(f"Expected a Tensor or a list of Tensors, got {type(x_or_x_list).__name__}")


# ---- file: imcui/third_party/MatchAnything/third_party/ROMA/roma/models/transformer/layers/dino_head.py ----
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn
from torch.nn.init import trunc_normal_
from torch.nn.utils import weight_norm


class DINOHead(nn.Module):
    """MLP projection head followed by L2 normalisation and a weight-normed linear layer."""

    def __init__(
        self,
        in_dim,
        out_dim,
        use_bn=False,
        nlayers=3,
        hidden_dim=2048,
        bottleneck_dim=256,
        mlp_bias=True,
    ):
        super().__init__()
        nlayers = max(nlayers, 1)
        self.mlp = _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=hidden_dim, use_bn=use_bn, bias=mlp_bias)
        # Initialise the MLP before last_layer exists, so last_layer keeps its default init.
        self.apply(self._init_weights)
        self.last_layer = weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
        self.last_layer.weight_g.data.fill_(1)

    def _init_weights(self, m):
        """Truncated-normal weights and zero biases for linear layers."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x = self.mlp(x)
        # Larger epsilon for fp16 inputs.
        eps = 1e-6 if x.dtype == torch.float16 else 1e-12
        x = nn.functional.normalize(x, dim=-1, p=2, eps=eps)
        x = self.last_layer(x)
        return x


def _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=None, use_bn=False, bias=True):
    """Build an ``nlayers``-deep MLP (a single Linear when ``nlayers == 1``)."""
    if nlayers == 1:
        return nn.Linear(in_dim, bottleneck_dim, bias=bias)
    else:
        layers = [nn.Linear(in_dim, hidden_dim, bias=bias)]
        if use_bn:
            layers.append(nn.BatchNorm1d(hidden_dim))
        layers.append(nn.GELU())
        for _ in range(nlayers - 2):
            layers.append(nn.Linear(hidden_dim, hidden_dim, bias=bias))
            if use_bn:
                layers.append(nn.BatchNorm1d(hidden_dim))
            layers.append(nn.GELU())
        layers.append(nn.Linear(hidden_dim, bottleneck_dim, bias=bias))
        return nn.Sequential(*layers)


# ---- file: imcui/third_party/MatchAnything/third_party/ROMA/roma/models/transformer/layers/drop_path.py ----
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# References:
#   https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
#   https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py


from torch import nn


def drop_path(x, drop_prob: float = 0.0, training: bool = False):
    """Zero whole samples of ``x`` with probability ``drop_prob`` and rescale the rest.

    Returns ``x`` unchanged outside training or when ``drop_prob`` is 0/None.
    """
    # `not drop_prob` also covers None, which is DropPath's default.
    if not drop_prob or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
    if keep_prob > 0.0:
        random_tensor.div_(keep_prob)
    output = x * random_tensor
    return output


class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""

    def __init__(self, drop_prob=None):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)
# ---- file: imcui/third_party/MatchAnything/third_party/ROMA/roma/models/transformer/layers/layer_scale.py ----
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110

from typing import Union

import torch
from torch import Tensor
from torch import nn


class LayerScale(nn.Module):
    """Multiply the input by a learnable vector ``gamma`` of length ``dim``.

    ``gamma`` starts at ``init_values`` in every entry. With ``inplace=True``
    the input tensor itself is modified and returned.
    """

    def __init__(
        self,
        dim: int,
        init_values: Union[float, Tensor] = 1e-5,
        inplace: bool = False,
    ) -> None:
        super().__init__()
        self.inplace = inplace
        initial_gamma = init_values * torch.ones(dim)
        self.gamma = nn.Parameter(initial_gamma)

    def forward(self, x: Tensor) -> Tensor:
        if self.inplace:
            return x.mul_(self.gamma)
        return x * self.gamma


# ---- file: imcui/third_party/MatchAnything/third_party/ROMA/roma/models/transformer/layers/mlp.py ----
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# References:
#   https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
#   https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py


from typing import Callable, Optional

from torch import Tensor, nn


class Mlp(nn.Module):
    """Two-layer perceptron: ``fc1 -> act -> drop -> fc2 -> drop``.

    ``hidden_features`` and ``out_features`` fall back to ``in_features`` when
    they are not given (or are falsy).
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        drop: float = 0.0,
        bias: bool = True,
    ) -> None:
        super().__init__()
        hidden_width = hidden_features if hidden_features else in_features
        output_width = out_features if out_features else in_features
        # Sub-modules are registered in the same order as before.
        self.fc1 = nn.Linear(in_features, hidden_width, bias=bias)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_width, output_width, bias=bias)
        self.drop = nn.Dropout(drop)

    def forward(self, x: Tensor) -> Tensor:
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))


# ---- file: imcui/third_party/MatchAnything/third_party/ROMA/roma/models/transformer/layers/patch_embed.py ----
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# References:
#   https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
#   https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py

from typing import Callable, Optional, Tuple, Union

from torch import Tensor
import torch.nn as nn


def make_2tuple(x):
    """Return ``x`` if it is a 2-tuple, or ``(x, x)`` if it is an int."""
    if isinstance(x, tuple):
        assert len(x) == 2
        return x

    assert isinstance(x, int)
    return (x, x)


class PatchEmbed(nn.Module):
    """
    2D image to patch embedding: (B,C,H,W) -> (B,N,D)

    Args:
        img_size: Image size.
        patch_size: Patch token size.
        in_chans: Number of input image channels.
        embed_dim: Number of linear projection output channels.
        norm_layer: Normalization layer.
        flatten_embedding: If False, the output is reshaped to (B, H', W', D).
    """

    def __init__(
        self,
        img_size: Union[int, Tuple[int, int]] = 224,
        patch_size: Union[int, Tuple[int, int]] = 16,
        in_chans: int = 3,
        embed_dim: int = 768,
        norm_layer: Optional[Callable] = None,
        flatten_embedding: bool = True,
    ) -> None:
        super().__init__()

        image_HW = make_2tuple(img_size)
        patch_HW = make_2tuple(patch_size)
        patch_grid_size = (
            image_HW[0] // patch_HW[0],
            image_HW[1] // patch_HW[1],
        )

        self.img_size = image_HW
        self.patch_size = patch_HW
        self.patches_resolution = patch_grid_size
        self.num_patches = patch_grid_size[0] * patch_grid_size[1]

        self.in_chans = in_chans
        self.embed_dim = embed_dim

        self.flatten_embedding = flatten_embedding

        # Non-overlapping patches: kernel size equals stride.
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x: Tensor) -> Tensor:
        _, _, H, W = x.shape
        patch_H, patch_W = self.patch_size

        assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
        assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}"

        x = self.proj(x)  # B C H W
        H, W = x.size(2), x.size(3)
        x = x.flatten(2).transpose(1, 2)  # B HW C
        x = self.norm(x)
        if not self.flatten_embedding:
            x = x.reshape(-1, H, W, self.embed_dim)  # B H W C
        return x

    def flops(self) -> float:
        """Estimate multiply-accumulate count of the projection (plus norm, if any)."""
        Ho, Wo = self.patches_resolution
        flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
        # self.norm is never None (nn.Identity is used when no norm_layer is
        # given), so test for a real normalisation layer instead.
        if not isinstance(self.norm, nn.Identity):
            flops += Ho * Wo * self.embed_dim
        return flops


# ---- file: imcui/third_party/MatchAnything/third_party/ROMA/roma/models/transformer/layers/swiglu_ffn.py ----
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from typing import Callable, Optional

from torch import Tensor, nn
import torch.nn.functional as F


class SwiGLUFFN(nn.Module):
    """SwiGLU feed-forward layer: ``w3(silu(x1) * x2)`` with ``x1, x2 = w12(x)``.

    ``act_layer`` and ``drop`` are accepted for signature compatibility and unused.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = None,
        drop: float = 0.0,
        bias: bool = True,
    ) -> None:
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
        self.w3 = nn.Linear(hidden_features, out_features, bias=bias)

    def forward(self, x: Tensor) -> Tensor:
        x12 = self.w12(x)
        x1, x2 = x12.chunk(2, dim=-1)
        hidden = F.silu(x1) * x2
        return self.w3(hidden)


try:
    from xformers.ops import SwiGLU

    XFORMERS_AVAILABLE = True
except ImportError:
    # Fall back to the pure-PyTorch implementation above.
    SwiGLU = SwiGLUFFN
    XFORMERS_AVAILABLE = False


class SwiGLUFFNFused(SwiGLU):
    """SwiGLU layer whose hidden width is 2/3 of the request, rounded up to a multiple of 8."""

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = None,
        drop: float = 0.0,
        bias: bool = True,
    ) -> None:
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
        super().__init__(
            in_features=in_features,
            hidden_features=hidden_features,
            out_features=out_features,
            bias=bias,
        )


# ---- file: imcui/third_party/MatchAnything/third_party/ROMA/roma/roma_adpat_model.py ----
import torch
import torch.nn as nn
from PIL import Image
import numpy as np
from pathlib import Path
import sys
sys.path.append(str(Path(__file__).parent.parent.resolve()))

from .models import roma_outdoor


class ROMA_Model(nn.Module):
    """Wrapper that runs the RoMa outdoor matcher on one image pair."""

    def __init__(self, MAX_MATCHES=5000, SAMPLE_THRESH=0.8, MATCH_THRESH=0.3) -> None:
        super().__init__()
        self.model = roma_outdoor(device=torch.device("cpu"))
        self.MAX_MATCHES = MAX_MATCHES
        self.MATCH_THRESH = MATCH_THRESH
        self.model.sample_thresh = SAMPLE_THRESH  # Inner matcher

    def forward(self, data):
        """Match ``data['image0_rgb'][0]`` against ``data['image1_rgb'][0]``.

        Adds ``m_bids``, ``mkpts0_f``, ``mkpts1_f`` and ``mconf`` to ``data``
        (only matches with certainty above ``MATCH_THRESH``) and returns it.
        """
        img0, img1 = data['image0_rgb'][0], data['image1_rgb'][0]  # unbatch, 3 * H * W

        H_A, W_A = img0.shape[-2:]
        H_B, W_B = img1.shape[-2:]
        warp, certainty = self.model.match(img0, img1)  # 3 * H * W
        # Sample matches for estimation
        matches, certainty = self.model.sample(warp, certainty, num=self.MAX_MATCHES)

        mask = certainty > self.MATCH_THRESH
        kpts0, kpts1 = self.model.to_pixel_coordinates(matches, H_A, W_A, H_B, W_B)
        kpts0, kpts1, certainty = kpts0[mask], kpts1[mask], certainty[mask]
        data.update({'m_bids': torch.zeros_like(kpts0[:, 0]), "mkpts0_f": kpts0, "mkpts1_f": kpts1, "mconf": certainty})
        return data


# ---- file: imcui/third_party/MatchAnything/third_party/ROMA/roma/train/__init__.py ----
from .train import train_k_epochs


# ---- file: imcui/third_party/MatchAnything/third_party/ROMA/roma/train/train.py ----
from tqdm import tqdm
from roma.utils.utils import to_cuda
import roma
import torch

# wandb is optional: the import used to be commented out while wandb.log was
# still called, which raised NameError on the first training step.
try:
    import wandb
except ImportError:
    wandb = None


def _wandb_log(payload, **kwargs):
    """Forward ``payload`` to ``wandb.log`` only when a wandb run is active."""
    if wandb is not None and getattr(wandb, "run", None) is not None:
        wandb.log(payload, **kwargs)


def log_param_statistics(named_parameters, norm_type=2):
    """Log global gradient/parameter norms and report parameters with NaN/Inf grads."""
    with_grad = [(n, p) for n, p in named_parameters if p.grad is not None]
    if not with_grad:
        # Nothing has a gradient yet; torch.stack([]) would raise.
        return
    names = [n for n, _ in with_grad]
    grads = [p.grad for _, p in with_grad]
    weight_norms = [p.norm(p=norm_type) for _, p in with_grad]
    param_norm = torch.stack(weight_norms).norm(p=norm_type)
    device = grads[0].device
    grad_norms = torch.stack([torch.norm(g.detach(), norm_type).to(device) for g in grads])
    nans_or_infs = torch.isinf(grad_norms) | torch.isnan(grad_norms)
    nan_inf_names = [name for name, naninf in zip(names, nans_or_infs) if naninf]
    total_grad_norm = torch.norm(grad_norms, norm_type)
    if torch.any(nans_or_infs):
        print(f"These params have nan or inf grads: {nan_inf_names}")
    _wandb_log({"grad_norm": total_grad_norm.item()}, step=roma.GLOBAL_STEP)
    _wandb_log({"param_norm": param_norm.item()}, step=roma.GLOBAL_STEP)


def train_step(train_batch, model, objective, optimizer, grad_scaler, grad_clip_norm=1., **kwargs):
    """Run one optimisation step and advance ``roma.GLOBAL_STEP``.

    Extra keyword arguments are accepted and ignored.
    Returns a dict with the model output and the scalar loss.
    """
    optimizer.zero_grad()
    out = model(train_batch)
    l = objective(out, train_batch)
    grad_scaler.scale(l).backward()
    grad_scaler.unscale_(optimizer)
    log_param_statistics(model.named_parameters())
    torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip_norm)  # what should max norm be?
    grad_scaler.step(optimizer)
    grad_scaler.update()
    _wandb_log({"grad_scale": grad_scaler.get_scale()}, step=roma.GLOBAL_STEP)
    # Never let the loss scale fall below 1; a disabled scaler has no _scale tensor.
    current_scale = getattr(grad_scaler, "_scale", None)
    if current_scale is not None and current_scale < 1.:
        grad_scaler._scale = torch.tensor(1.).to(current_scale)
    roma.GLOBAL_STEP = roma.GLOBAL_STEP + roma.STEP_SIZE  # increment global step
    return {"train_out": out, "train_loss": l.item()}


def train_k_steps(
    n_0, k, dataloader, model, objective, optimizer, lr_scheduler, grad_scaler, progress_bar=True, grad_clip_norm=1., warmup=None, ema_model=None,
):
    """Train for ``k`` steps, pulling batches from the iterator ``dataloader``."""
    for n in tqdm(range(n_0, n_0 + k), disable=(not progress_bar) or roma.RANK > 0):
        batch = next(dataloader)
        model.train(True)
        batch = to_cuda(batch)
        train_step(
            train_batch=batch,
            model=model,
            objective=objective,
            optimizer=optimizer,
            lr_scheduler=lr_scheduler,
            grad_scaler=grad_scaler,
            n=n,
            grad_clip_norm=grad_clip_norm,
        )
        if ema_model is not None:
            ema_model.update()
        if warmup is not None:
            with warmup.dampening():
                lr_scheduler.step()
        else:
            lr_scheduler.step()
        for grp, lr in enumerate(lr_scheduler.get_last_lr()):
            _wandb_log({f"lr_group_{grp}": lr})


def train_epoch(
    dataloader=None,
    model=None,
    objective=None,
    optimizer=None,
    lr_scheduler=None,
    epoch=None,
    grad_scaler=None,
):
    """Train for one pass over ``dataloader`` and step the LR scheduler once.

    ``grad_scaler`` is optional; when omitted a disabled ``GradScaler`` is used
    so that ``train_step`` (which requires one) runs without loss scaling.
    """
    if grad_scaler is None:
        grad_scaler = torch.cuda.amp.GradScaler(enabled=False)
    model.train(True)
    print(f"At epoch {epoch}")
    for batch in tqdm(dataloader, mininterval=5.0):
        batch = to_cuda(batch)
        train_step(
            train_batch=batch,
            model=model,
            objective=objective,
            optimizer=optimizer,
            grad_scaler=grad_scaler,
        )
    lr_scheduler.step()
    return {
        "model": model,
        "optimizer": optimizer,
        "lr_scheduler": lr_scheduler,
        "epoch": epoch,
    }


def train_k_epochs(
    start_epoch, end_epoch, dataloader, model, objective, optimizer, lr_scheduler, grad_scaler=None
):
    """Run ``train_epoch`` for every epoch in ``[start_epoch, end_epoch]`` (inclusive)."""
    for epoch in range(start_epoch, end_epoch + 1):
        train_epoch(
            dataloader=dataloader,
            model=model,
            objective=objective,
            optimizer=optimizer,
            lr_scheduler=lr_scheduler,
            epoch=epoch,
            grad_scaler=grad_scaler,
        )


# ---- file: imcui/third_party/MatchAnything/third_party/ROMA/roma/utils/__init__.py ----
from .utils import (
    pose_auc,
    get_pose,
    compute_relative_pose,
    compute_pose_error,
    estimate_pose,
    estimate_pose_uncalibrated,
    rotate_intrinsic,
    get_tuple_transform_ops,
    get_depth_tuple_transform_ops,
    warp_kpts,
    numpy_to_pil,
    tensor_to_pil,
    recover_pose,
    signed_left_to_right_epipolar_distance,
    resize_by_longest_edge_and_padding,
    resize_by_longest_edge_and_stretch
)


# ---- file: imcui/third_party/MatchAnything/third_party/ROMA/roma/utils/kde.py ----
import torch


def kde(x, std=0.1):
    """Gaussian-kernel density estimate of every row of ``x`` against all rows."""
    # use a gaussian kernel to estimate density
    # x = x.half() # Do it in half precision TODO: remove hardcoding
    scores = (-torch.cdist(x, x) ** 2 / (2 * std ** 2)).exp()
    density = scores.sum(dim=-1)
    return density
# ---- file: imcui/third_party/MatchAnything/third_party/ROMA/roma/utils/local_correlation.py ----
import torch
import torch.nn.functional as F


def local_correlation(
    feature0,
    feature1,
    local_radius,
    padding_mode="zeros",
    flow=None,
    sample_mode="bilinear",
):
    """Correlate every pixel of ``feature0`` with a local window of ``feature1``.

    Args:
        feature0: Tensor of size (B, c, h, w).
        feature1: Tensor sampled with ``F.grid_sample``; indexed per batch item.
        local_radius: Window radius ``r``; the window has ``(2r+1)**2`` taps.
        padding_mode: Passed to ``F.grid_sample``.
        flow: Optional (B, 2, h, w) sampling centres in normalised [-1, 1]
            coordinates. When None, the identity grid of pixel centres is used.
        sample_mode: Interpolation mode passed to ``F.grid_sample``.

    Returns:
        Tensor of shape (B, (2r+1)**2, h, w) with correlations scaled by
        ``1 / sqrt(c)``.
    """
    r = local_radius
    side = 2 * r + 1
    K = side ** 2
    B, c, h, w = feature0.size()
    device = feature0.device
    corr = torch.empty((B, K, h, w), device=device, dtype=feature0.dtype)
    if flow is None:
        # If flow is None, assume feature0 and feature1 are aligned
        grid_y, grid_x = torch.meshgrid(
            torch.linspace(-1 + 1 / h, 1 - 1 / h, h, device=device),
            torch.linspace(-1 + 1 / w, 1 - 1 / w, w, device=device),
            indexing="ij",  # explicit: the implicit default is deprecated
        )
        coords = torch.stack((grid_x, grid_y), dim=-1)[None].expand(B, h, w, 2)
    else:
        coords = flow.permute(0, 2, 3, 1)  # If using flow, sample around flow target.
    # Offsets of one pixel per tap, expressed in normalised coordinates.
    win_y, win_x = torch.meshgrid(
        torch.linspace(-2 * r / h, 2 * r / h, side, device=device),
        torch.linspace(-2 * r / w, 2 * r / w, side, device=device),
        indexing="ij",
    )
    local_window = torch.stack((win_x, win_y), dim=-1).reshape(1, K, 2)
    for b in range(B):
        # The sampled window is treated as a constant (no gradient through feature1).
        with torch.no_grad():
            local_window_coords = (coords[b, :, :, None] + local_window[:, None, None]).reshape(1, h, w * K, 2)
            window_feature = F.grid_sample(
                feature1[b:b + 1], local_window_coords, padding_mode=padding_mode, align_corners=False, mode=sample_mode,
            )
            window_feature = window_feature.reshape(c, h, w, K)
        corr[b] = (feature0[b, ..., None] / (c ** .5) * window_feature).sum(dim=0).permute(2, 0, 1)
    return corr
# ---- file: imcui/third_party/MatchAnything/third_party/ROMA/roma/utils/transforms.py ----
from typing import Dict
import numpy as np
import torch
import kornia.augmentation as K
from kornia.geometry.transform import warp_perspective


# Adapted from Kornia
class GeometricSequential:
    """Compose random geometric transforms into one perspective warp."""

    def __init__(self, *transforms, align_corners=True) -> None:
        self.transforms = transforms
        self.align_corners = align_corners

    def __call__(self, x, mode="bilinear"):
        """Return ``(warped_x, M)`` where ``M`` is the composed (b, 3, 3) matrix."""
        b, c, h, w = x.shape
        M = torch.eye(3, device=x.device)[None].expand(b, 3, 3)
        for t in self.transforms:
            # Each transform is applied with its own probability t.p.
            if np.random.rand() < t.p:
                M = M.matmul(
                    t.compute_transformation(x, t.generate_parameters((b, c, h, w)), None)
                )
        return (
            warp_perspective(
                x, M, dsize=(h, w), mode=mode, align_corners=self.align_corners
            ),
            M,
        )

    def apply_transform(self, x, M, mode="bilinear"):
        """Warp ``x`` with an already-computed matrix ``M``."""
        b, c, h, w = x.shape
        return warp_perspective(
            x, M, dsize=(h, w), align_corners=self.align_corners, mode=mode
        )


class RandomPerspective(K.RandomPerspective):
    """Kornia ``RandomPerspective`` with a local parameter generator."""

    def generate_parameters(self, batch_shape: torch.Size) -> Dict[str, torch.Tensor]:
        distortion_scale = torch.as_tensor(
            self.distortion_scale, device=self._device, dtype=self._dtype
        )
        return self.random_perspective_generator(
            batch_shape[0],
            batch_shape[-2],
            batch_shape[-1],
            distortion_scale,
            self.same_on_batch,
            self.device,
            self.dtype,
        )

    def random_perspective_generator(
        self,
        batch_size: int,
        height: int,
        width: int,
        distortion_scale: torch.Tensor,
        same_on_batch: bool = False,
        device: torch.device = torch.device("cpu"),
        dtype: torch.dtype = torch.float32,
    ) -> Dict[str, torch.Tensor]:
        r"""Get parameters for ``perspective`` for a random perspective transform.

        Args:
            batch_size (int): the tensor batch size.
            height (int) : height of the image.
            width (int): width of the image.
            distortion_scale (torch.Tensor): it controls the degree of distortion and ranges from 0 to 1.
            same_on_batch (bool): apply the same transformation across the batch. Default: False.
            device (torch.device): the device on which the random numbers will be generated. Default: cpu.
            dtype (torch.dtype): the data type of the generated random numbers. Default: float32.

        Returns:
            params Dict[str, torch.Tensor]: parameters to be passed for transformation.
                - start_points (torch.Tensor): element-wise perspective source areas with a shape of (B, 4, 2).
                - end_points (torch.Tensor): element-wise perspective target areas with a shape of (B, 4, 2).

        Note:
            The generated random numbers are not reproducible across different devices and dtypes.
        """
        if not (distortion_scale.dim() == 0 and 0 <= distortion_scale <= 1):
            raise AssertionError(
                f"'distortion_scale' must be a scalar within [0, 1]. Got {distortion_scale}."
            )
        if not (
            type(height) is int and height > 0 and type(width) is int and width > 0
        ):
            raise AssertionError(
                f"'height' and 'width' must be integers. Got {height}, {width}."
            )

        start_points: torch.Tensor = torch.tensor(
            [[[0.0, 0], [width - 1, 0], [width - 1, height - 1], [0, height - 1]]],
            device=distortion_scale.device,
            dtype=distortion_scale.dtype,
        ).expand(batch_size, -1, -1)

        # generate random offset not larger than half of the image
        fx = distortion_scale * width / 2
        fy = distortion_scale * height / 2

        factor = torch.stack([fx, fy], dim=0).view(-1, 1, 2)
        offset = (torch.rand_like(start_points) - 0.5) * 2
        end_points = start_points + factor * offset

        return dict(start_points=start_points, end_points=end_points)


class RandomErasing:
    """Erase the same random region from an image and its depth map."""

    def __init__(self, p=0., scale=0.) -> None:
        self.p = p
        self.scale = scale
        self.random_eraser = K.RandomErasing(scale=(0.02, scale), p=p)

    def __call__(self, image, depth):
        if self.p > 0:
            image = self.random_eraser(image)
            # Reuse the parameters sampled for the image so both are erased identically.
            depth = self.random_eraser(depth, params=self.random_eraser._params)
        return image, depth


# ---- file: imcui/third_party/MatchAnything/third_party/ROMA/roma/utils/utils.py ----
import warnings
import numpy as np
import cv2
import math
import torch
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
import torch.nn.functional as F
from PIL import Image


def resize_by_longest_edge_and_stretch(img, aim_long_edge_size):
    """
    img: C * H * W, torch.tensor
    aim_long_edge_size: int

    Stretches the image to a square of side ``aim_long_edge_size`` (no padding).
    """
    hs, ws = aim_long_edge_size, aim_long_edge_size
    return resize_and_padding(img, (hs, ws), padding=False)


def resize_by_longest_edge_and_padding(img, aim_long_edge_size):
    """
    img: C * H * W, torch.tensor
    aim_long_edge_size: int

    Keeps the aspect ratio, then zero-pads to a square.
    """
    c, h, w = img.shape

    scale = aim_long_edge_size / max(h, w)
    hs, ws = round(h * scale), round(w * scale)
    return resize_and_padding(img, (hs, ws))


def resize_and_padding(img, resize, padding=True):
    """
    img: C * H * W, torch.tensor
    resize: aim (h, w)

    Bicubic resize; with ``padding`` the result is placed in the top-left of a
    zero square of side ``max(resize)``.
    """
    c, h_org, w_org = img.shape
    # img_resized = transforms.Resize(resize, InterpolationMode.BILINEAR)(img)
    img_resized = transforms.Resize(resize, InterpolationMode.BICUBIC)(img)

    if padding:
        img_padded = torch.zeros((c, max(resize), max(resize)), device=img.device)
        img_padded[:, :resize[0], :resize[1]] = img_resized
        return img_padded
    else:
        return img_resized


def recover_pose(E, kpts0, kpts1, K0, K1, mask):
    """Pick the pose with the most inliers among the stacked essential matrices.

    Returns ``(R, t, inlier_mask)`` or ``None`` when no candidate has inliers.
    """
    # Initialised so that "no candidate has inliers" yields None instead of
    # an UnboundLocalError.
    ret = None
    best_num_inliers = 0
    K0inv = np.linalg.inv(K0[:2, :2])
    K1inv = np.linalg.inv(K1[:2, :2])

    # Normalise pixel coordinates with the intrinsics.
    kpts0_n = (K0inv @ (kpts0 - K0[None, :2, 2]).T).T
    kpts1_n = (K1inv @ (kpts1 - K1[None, :2, 2]).T).T

    for _E in np.split(E, len(E) / 3):
        n, R, t, _ = cv2.recoverPose(_E, kpts0_n, kpts1_n, np.eye(3), 1e9, mask=mask)
        if n > best_num_inliers:
            best_num_inliers = n
            ret = (R, t, mask.ravel() > 0)
    return ret


# Code taken from https://github.com/PruneTruong/DenseMatching/blob/40c29a6b5c35e86b9509e65ab0cd12553d998e5f/validation/utils_pose_estimation.py
# --- GEOMETRY ---
def estimate_pose(kpts0, kpts1, K0, K1, norm_thresh, conf=0.99999):
    """Estimate relative pose from matches via the essential matrix.

    Returns ``(R, t, inlier_mask)`` or ``None`` (fewer than 5 matches, or no
    valid solution).
    """
    if len(kpts0) < 5:
        return None
    K0inv = np.linalg.inv(K0[:2, :2])
    K1inv = np.linalg.inv(K1[:2, :2])

    kpts0 = (K0inv @ (kpts0 - K0[None, :2, 2]).T).T
    kpts1 = (K1inv @ (kpts1 - K1[None, :2, 2]).T).T
    E, mask = cv2.findEssentialMat(
        kpts0, kpts1, np.eye(3), threshold=norm_thresh, prob=conf
    )

    ret = None
    if E is not None:
        best_num_inliers = 0

        for _E in np.split(E, len(E) / 3):
            n, R, t, _ = cv2.recoverPose(_E, kpts0, kpts1, np.eye(3), 1e9, mask=mask)
            if n > best_num_inliers:
                best_num_inliers = n
                ret = (R, t, mask.ravel() > 0)
    return ret


def estimate_pose_uncalibrated(kpts0, kpts1, K0, K1, norm_thresh, conf=0.99999):
    """Estimate relative pose via a fundamental matrix fitted in pixel space.

    Returns ``(R, t, inlier_mask)`` or ``None`` (fewer than 5 matches, no
    fundamental matrix found, or no valid solution).
    """
    if len(kpts0) < 5:
        return None
    method = cv2.USAC_ACCURATE
    F_mat, mask = cv2.findFundamentalMat(
        kpts0, kpts1, ransacReprojThreshold=norm_thresh, confidence=conf, method=method, maxIters=10000
    )
    # The None check must precede the matrix product below, which would
    # otherwise raise on a failed estimation.
    if F_mat is None:
        return None
    E = K1.T @ F_mat @ K0

    ret = None
    best_num_inliers = 0
    K0inv = np.linalg.inv(K0[:2, :2])
    K1inv = np.linalg.inv(K1[:2, :2])

    kpts0_n = (K0inv @ (kpts0 - K0[None, :2, 2]).T).T
    kpts1_n = (K1inv @ (kpts1 - K1[None, :2, 2]).T).T

    for _E in np.split(E, len(E) / 3):
        n, R, t, _ = cv2.recoverPose(_E, kpts0_n, kpts1_n, np.eye(3), 1e9, mask=mask)
        if n > best_num_inliers:
            best_num_inliers = n
            ret = (R, t, mask.ravel() > 0)
    return ret
def unnormalize_coords(x_n, h, w):
    """Map normalized coordinates to pixel coordinates.

    Args:
        x_n: tensor whose last dimension stores ``(x, y)``; only entries 0 and 1
            of that dimension are read.
        h: factor applied to the y component.
        w: factor applied to the x component.

    Returns:
        Tensor with a last dimension of size 2 holding
        ``(w * (x + 1) / 2, h * (y + 1) / 2)``.
    """
    x = torch.stack(
        (w * (x_n[..., 0] + 1) / 2, h * (x_n[..., 1] + 1) / 2), dim=-1
    )  # [-1+1/h, 1-1/h] -> [0.5, h-0.5]
    return x


def rotate_intrinsic(K, n):
    """Left-multiply ``K`` by the ``n``-th power of a quarter-turn about z."""
    base_rot = np.array([[0, 1, 0], [-1, 0, 0], [0, 0, 1]])
    rot = np.linalg.matrix_power(base_rot, n)
    return rot @ K


def rotate_pose_inplane(i_T_w, rot):
    """Left-multiply a 4x4 pose by an in-plane (about z) rotation.

    Args:
        i_T_w: array that a 4x4 matrix can be multiplied with via ``np.dot``.
        rot: index into the angle table ``(0, 270, 180, 90)`` degrees.

    Returns:
        ``np.dot(rotation, i_T_w)`` where ``rotation`` is the float32 4x4
        homogeneous rotation for the selected angle.
    """
    # Only the requested rotation is built; the angle table keeps the original
    # index -> angle mapping.
    r = np.deg2rad((0, 270, 180, 90)[rot])
    rotation = np.array(
        [
            [np.cos(r), -np.sin(r), 0.0, 0.0],
            [np.sin(r), np.cos(r), 0.0, 0.0],
            [0.0, 0.0, 1.0, 0.0],
            [0.0, 0.0, 0.0, 1.0],
        ],
        dtype=np.float32,
    )
    return np.dot(rotation, i_T_w)


def scale_intrinsics(K, scales):
    """Return ``diag(1/scales[0], 1/scales[1], 1) @ K``."""
    scaling = np.diag([1.0 / scales[0], 1.0 / scales[1], 1.0])
    return np.dot(scaling, K)


def to_homogeneous(points):
    """Append a column of ones to a 2-D array of points (one point per row)."""
    return np.concatenate([points, np.ones_like(points[:, :1])], axis=-1)


def angle_error_mat(R1, R2):
    """Return the rotation angle, in degrees, of ``R1.T @ R2``."""
    cos = (np.trace(np.dot(R1.T, R2)) - 1) / 2
    cos = np.clip(cos, -1.0, 1.0)  # numerical errors can make it out of bounds
    return np.rad2deg(np.abs(np.arccos(cos)))


def angle_error_vec(v1, v2):
    """Return the angle between ``v1`` and ``v2`` in degrees.

    The cosine is divided by the product of the two norms without a guard, so
    a zero-length vector leads to a division by zero (NaN result).
    """
    n = np.linalg.norm(v1) * np.linalg.norm(v2)
    return np.rad2deg(np.arccos(np.clip(np.dot(v1, v2) / n, -1.0, 1.0)))


def compute_pose_error(T_0to1, R, t):
    """Compare an estimated pose with a ground-truth transform.

    Args:
        T_0to1: ground-truth transform; ``[:3, :3]`` is read as the rotation
            and ``[:3, 3]`` as the translation.
        R: estimated rotation matrix.
        t: estimated translation; it is squeezed before use.

    Returns:
        ``(error_t, error_R)``, both in degrees.
    """
    R_gt = T_0to1[:3, :3]
    t_gt = T_0to1[:3, 3]
    error_t = angle_error_vec(t.squeeze(), t_gt)
    error_t = np.minimum(error_t, 180 - error_t)  # ambiguity of E estimation
    error_R = angle_error_mat(R, R_gt)
    return error_t, error_R


def pose_auc(errors, thresholds):
    """Compute the normalized area under the recall-vs-error curve.

    Args:
        errors: any sequence of per-sample errors (list, tuple or array); it is
            not modified.
        thresholds: iterable of non-zero error thresholds.

    Returns:
        A list with one value per threshold: the area under the cumulative
        recall curve up to that threshold, divided by the threshold.
    """
    # ``np.trapz`` is deprecated since NumPy 2.0 in favour of ``np.trapezoid``;
    # prefer the new name and fall back on older NumPy releases.
    integrate = getattr(np, "trapezoid", None) or np.trapz
    # ``np.sort`` returns a sorted copy, so the caller's data stays untouched
    # and inputs without a ``.copy()`` method (e.g. tuples) are accepted.
    errors = np.sort(np.asarray(errors, dtype=float))
    recall = (np.arange(len(errors)) + 1) / len(errors)
    errors = np.r_[0.0, errors]
    recall = np.r_[0.0, recall]
    aucs = []
    for t in thresholds:
        last_index = np.searchsorted(errors, t)
        # Extend the curve horizontally from the last error below ``t`` to ``t``.
        r = np.r_[recall[:last_index], recall[last_index - 1]]
        e = np.r_[errors[:last_index], t]
        aucs.append(integrate(r, x=e) / t)
    return aucs


# From Patch2Pix
https://github.com/GrumpyZhou/patch2pix +def get_depth_tuple_transform_ops_nearest_exact(resize=None): + ops = [] + if resize: + ops.append(TupleResizeNearestExact(resize)) + return TupleCompose(ops) + +def get_depth_tuple_transform_ops(resize=None, normalize=True, unscale=False): + ops = [] + if resize: + ops.append(TupleResize(resize, mode=InterpolationMode.BILINEAR)) + return TupleCompose(ops) + + +def get_tuple_transform_ops(resize=None, normalize=True, unscale=False, clahe = False, colorjiggle_params = None): + ops = [] + if resize: + ops.append(TupleResize(resize)) + ops.append(TupleToTensorScaled()) + if normalize: + ops.append( + TupleNormalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + ) # Imagenet mean/std + return TupleCompose(ops) + +class ToTensorScaled(object): + """Convert a RGB PIL Image to a CHW ordered Tensor, scale the range to [0, 1]""" + + def __call__(self, im): + if not isinstance(im, torch.Tensor): + im = np.array(im, dtype=np.float32).transpose((2, 0, 1)) + im /= 255.0 + return torch.from_numpy(im) + else: + return im + + def __repr__(self): + return "ToTensorScaled(./255)" + + +class TupleToTensorScaled(object): + def __init__(self): + self.to_tensor = ToTensorScaled() + + def __call__(self, im_tuple): + return [self.to_tensor(im) for im in im_tuple] + + def __repr__(self): + return "TupleToTensorScaled(./255)" + + +class ToTensorUnscaled(object): + """Convert a RGB PIL Image to a CHW ordered Tensor""" + + def __call__(self, im): + return torch.from_numpy(np.array(im, dtype=np.float32).transpose((2, 0, 1))) + + def __repr__(self): + return "ToTensorUnscaled()" + + +class TupleToTensorUnscaled(object): + """Convert a RGB PIL Image to a CHW ordered Tensor""" + + def __init__(self): + self.to_tensor = ToTensorUnscaled() + + def __call__(self, im_tuple): + return [self.to_tensor(im) for im in im_tuple] + + def __repr__(self): + return "TupleToTensorUnscaled()" + +class TupleResizeNearestExact: + def __init__(self, size): + 
self.size = size + def __call__(self, im_tuple): + return [F.interpolate(im, size = self.size, mode = 'nearest-exact') for im in im_tuple] + + def __repr__(self): + return "TupleResizeNearestExact(size={})".format(self.size) + + +class TupleResize(object): + def __init__(self, size, mode=InterpolationMode.BICUBIC): + self.size = size + self.resize = transforms.Resize(size, mode) + def __call__(self, im_tuple): + return [self.resize(im) for im in im_tuple] + + def __repr__(self): + return "TupleResize(size={})".format(self.size) + +class Normalize: + def __call__(self,im): + mean = im.mean(dim=(1,2), keepdims=True) + std = im.std(dim=(1,2), keepdims=True) + return (im-mean)/std + + +class TupleNormalize(object): + def __init__(self, mean, std): + self.mean = mean + self.std = std + self.normalize = transforms.Normalize(mean=mean, std=std) + + def __call__(self, im_tuple): + c,h,w = im_tuple[0].shape + if c > 3: + warnings.warn(f"Number of channels c={c} > 3, assuming first 3 are rgb") + return [self.normalize(im[:3]) for im in im_tuple] + + def __repr__(self): + return "TupleNormalize(mean={}, std={})".format(self.mean, self.std) + + +class TupleCompose(object): + def __init__(self, transforms): + self.transforms = transforms + + def __call__(self, im_tuple): + for t in self.transforms: + im_tuple = t(im_tuple) + return im_tuple + + def __repr__(self): + format_string = self.__class__.__name__ + "(" + for t in self.transforms: + format_string += "\n" + format_string += " {0}".format(t) + format_string += "\n)" + return format_string + +@torch.no_grad() +def cls_to_flow(cls, deterministic_sampling = True): + B,C,H,W = cls.shape + device = cls.device + res = round(math.sqrt(C)) + G = torch.meshgrid(*[torch.linspace(-1+1/res, 1-1/res, steps = res, device = device) for _ in range(2)]) + G = torch.stack([G[1],G[0]],dim=-1).reshape(C,2) + if deterministic_sampling: + sampled_cls = cls.max(dim=1).indices + else: + sampled_cls = 
torch.multinomial(cls.permute(0,2,3,1).reshape(B*H*W,C).softmax(dim=-1), 1).reshape(B,H,W) + flow = G[sampled_cls] + return flow + +@torch.no_grad() +def cls_to_flow_refine(cls): + B,C,H,W = cls.shape + device = cls.device + res = round(math.sqrt(C)) + G = torch.meshgrid(*[torch.linspace(-1+1/res, 1-1/res, steps = res, device = device) for _ in range(2)]) + G = torch.stack([G[1],G[0]],dim=-1).reshape(C,2) + cls = cls.softmax(dim=1) + mode = cls.max(dim=1).indices + + index = torch.stack((mode-1, mode, mode+1, mode - res, mode + res), dim = 1).clamp(0,C - 1).long() + neighbours = torch.gather(cls, dim = 1, index = index)[...,None] + flow = neighbours[:,0] * G[index[:,0]] + neighbours[:,1] * G[index[:,1]] + neighbours[:,2] * G[index[:,2]] + neighbours[:,3] * G[index[:,3]] + neighbours[:,4] * G[index[:,4]] + tot_prob = neighbours.sum(dim=1) + flow = flow / tot_prob + return flow + + +def get_gt_warp(depth1, depth2, T_1to2, K1, K2, depth_interpolation_mode = 'bilinear', relative_depth_error_threshold = 0.05, H = None, W = None): + + if H is None: + B,H,W = depth1.shape + else: + B = depth1.shape[0] + with torch.no_grad(): + x1_n = torch.meshgrid( + *[ + torch.linspace( + -1 + 1 / n, 1 - 1 / n, n, device=depth1.device + ) + for n in (B, H, W) + ] + ) + x1_n = torch.stack((x1_n[2], x1_n[1]), dim=-1).reshape(B, H * W, 2) + mask, x2 = warp_kpts( + x1_n.double(), + depth1.double(), + depth2.double(), + T_1to2.double(), + K1.double(), + K2.double(), + depth_interpolation_mode = depth_interpolation_mode, + relative_depth_error_threshold = relative_depth_error_threshold, + ) + prob = mask.float().reshape(B, H, W) + x2 = x2.reshape(B, H, W, 2) + return x2, prob + +@torch.no_grad() +def warp_kpts(kpts0, depth0, depth1, T_0to1, K0, K1, smooth_mask = False, return_relative_depth_error = False, depth_interpolation_mode = "bilinear", relative_depth_error_threshold = 0.05): + """Warp kpts0 from I0 to I1 with depth, K and Rt + Also check covisibility and depth consistency. 
+ Depth is consistent if relative error < 0.2 (hard-coded). + # https://github.com/zju3dv/LoFTR/blob/94e98b695be18acb43d5d3250f52226a8e36f839/src/loftr/utils/geometry.py adapted from here + Args: + kpts0 (torch.Tensor): [N, L, 2] - , should be normalized in (-1,1) + depth0 (torch.Tensor): [N, H, W], + depth1 (torch.Tensor): [N, H, W], + T_0to1 (torch.Tensor): [N, 3, 4], + K0 (torch.Tensor): [N, 3, 3], + K1 (torch.Tensor): [N, 3, 3], + Returns: + calculable_mask (torch.Tensor): [N, L] + warped_keypoints0 (torch.Tensor): [N, L, 2] + """ + ( + n, + h, + w, + ) = depth0.shape + if depth_interpolation_mode == "combined": + # Inspired by approach in inloc, try to fill holes from bilinear interpolation by nearest neighbour interpolation + if smooth_mask: + raise NotImplementedError("Combined bilinear and NN warp not implemented") + valid_bilinear, warp_bilinear = warp_kpts(kpts0, depth0, depth1, T_0to1, K0, K1, + smooth_mask = smooth_mask, + return_relative_depth_error = return_relative_depth_error, + depth_interpolation_mode = "bilinear", + relative_depth_error_threshold = relative_depth_error_threshold) + valid_nearest, warp_nearest = warp_kpts(kpts0, depth0, depth1, T_0to1, K0, K1, + smooth_mask = smooth_mask, + return_relative_depth_error = return_relative_depth_error, + depth_interpolation_mode = "nearest-exact", + relative_depth_error_threshold = relative_depth_error_threshold) + nearest_valid_bilinear_invalid = (~valid_bilinear).logical_and(valid_nearest) + warp = warp_bilinear.clone() + warp[nearest_valid_bilinear_invalid] = warp_nearest[nearest_valid_bilinear_invalid] + valid = valid_bilinear | valid_nearest + return valid, warp + + + kpts0_depth = F.grid_sample(depth0[:, None], kpts0[:, :, None], mode = depth_interpolation_mode, align_corners=False)[ + :, 0, :, 0 + ] + kpts0 = torch.stack( + (w * (kpts0[..., 0] + 1) / 2, h * (kpts0[..., 1] + 1) / 2), dim=-1 + ) # [-1+1/h, 1-1/h] -> [0.5, h-0.5] + # Sample depth, get calculable_mask on depth != 0 + nonzero_mask = 
kpts0_depth != 0 + + # Unproject + kpts0_h = ( + torch.cat([kpts0, torch.ones_like(kpts0[:, :, [0]])], dim=-1) + * kpts0_depth[..., None] + ) # (N, L, 3) + kpts0_n = K0.inverse() @ kpts0_h.transpose(2, 1) # (N, 3, L) + kpts0_cam = kpts0_n + + # Rigid Transform + w_kpts0_cam = T_0to1[:, :3, :3] @ kpts0_cam + T_0to1[:, :3, [3]] # (N, 3, L) + w_kpts0_depth_computed = w_kpts0_cam[:, 2, :] + + # Project + w_kpts0_h = (K1 @ w_kpts0_cam).transpose(2, 1) # (N, L, 3) + w_kpts0 = w_kpts0_h[:, :, :2] / ( + w_kpts0_h[:, :, [2]] + 1e-4 + ) # (N, L, 2), +1e-4 to avoid zero depth + + # Covisible Check + h, w = depth1.shape[1:3] + covisible_mask = ( + (w_kpts0[:, :, 0] > 0) + * (w_kpts0[:, :, 0] < w - 1) + * (w_kpts0[:, :, 1] > 0) + * (w_kpts0[:, :, 1] < h - 1) + ) + w_kpts0 = torch.stack( + (2 * w_kpts0[..., 0] / w - 1, 2 * w_kpts0[..., 1] / h - 1), dim=-1 + ) # from [0.5,h-0.5] -> [-1+1/h, 1-1/h] + # w_kpts0[~covisible_mask, :] = -5 # xd + + w_kpts0_depth = F.grid_sample( + depth1[:, None], w_kpts0[:, :, None], mode=depth_interpolation_mode, align_corners=False + )[:, 0, :, 0] + + relative_depth_error = ( + (w_kpts0_depth - w_kpts0_depth_computed) / w_kpts0_depth + ).abs() + if not smooth_mask: + consistent_mask = relative_depth_error < relative_depth_error_threshold + else: + consistent_mask = (-relative_depth_error/smooth_mask).exp() + valid_mask = nonzero_mask * covisible_mask * consistent_mask + if return_relative_depth_error: + return relative_depth_error, w_kpts0 + else: + return valid_mask, w_kpts0 + +imagenet_mean = torch.tensor([0.485, 0.456, 0.406]) +imagenet_std = torch.tensor([0.229, 0.224, 0.225]) + + +def numpy_to_pil(x: np.ndarray): + """ + Args: + x: Assumed to be of shape (h,w,c) + """ + if isinstance(x, torch.Tensor): + x = x.detach().cpu().numpy() + if x.max() <= 1.01: + x *= 255 + x = x.astype(np.uint8) + return Image.fromarray(x) + + +def tensor_to_pil(x, unnormalize=False): + if unnormalize: + x = x * (imagenet_std[:, None, None].to(x.device)) + 
(imagenet_mean[:, None, None].to(x.device)) + x = x.detach().permute(1, 2, 0).cpu().numpy() + x = np.clip(x, 0.0, 1.0) + return numpy_to_pil(x) + + +def to_cuda(batch): + for key, value in batch.items(): + if isinstance(value, torch.Tensor): + batch[key] = value.cuda() + return batch + + +def to_cpu(batch): + for key, value in batch.items(): + if isinstance(value, torch.Tensor): + batch[key] = value.cpu() + return batch + + +def get_pose(calib): + w, h = np.array(calib["imsize"])[0] + return np.array(calib["K"]), np.array(calib["R"]), np.array(calib["T"]).T, h, w + + +def compute_relative_pose(R1, t1, R2, t2): + rots = R2 @ (R1.T) + trans = -rots @ t1 + t2 + return rots, trans + +@torch.no_grad() +def reset_opt(opt): + for group in opt.param_groups: + for p in group['params']: + if p.requires_grad: + state = opt.state[p] + # State initialization + + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like(p) + # Exponential moving average of squared gradient values + state['exp_avg_sq'] = torch.zeros_like(p) + # Exponential moving average of gradient difference + state['exp_avg_diff'] = torch.zeros_like(p) + + +def flow_to_pixel_coords(flow, h1, w1): + flow = ( + torch.stack( + ( + w1 * (flow[..., 0] + 1) / 2, + h1 * (flow[..., 1] + 1) / 2, + ), + axis=-1, + ) + ) + return flow + +to_pixel_coords = flow_to_pixel_coords # just an alias + +def flow_to_normalized_coords(flow, h1, w1): + flow = ( + torch.stack( + ( + 2 * (flow[..., 0]) / w1 - 1, + 2 * (flow[..., 1]) / h1 - 1, + ), + axis=-1, + ) + ) + return flow + +to_normalized_coords = flow_to_normalized_coords # just an alias + +def warp_to_pixel_coords(warp, h1, w1, h2, w2): + warp1 = warp[..., :2] + warp1 = ( + torch.stack( + ( + w1 * (warp1[..., 0] + 1) / 2, + h1 * (warp1[..., 1] + 1) / 2, + ), + axis=-1, + ) + ) + warp2 = warp[..., 2:] + warp2 = ( + torch.stack( + ( + w2 * (warp2[..., 0] + 1) / 2, + h2 * (warp2[..., 1] + 1) / 2, + ), + axis=-1, + ) + ) + return 
torch.cat((warp1,warp2), dim=-1) + + + +def signed_point_line_distance(point, line, eps: float = 1e-9): + r"""Return the distance from points to lines. + + Args: + point: (possibly homogeneous) points :math:`(*, N, 2 or 3)`. + line: lines coefficients :math:`(a, b, c)` with shape :math:`(*, N, 3)`, where :math:`ax + by + c = 0`. + eps: Small constant for safe sqrt. + + Returns: + the computed distance with shape :math:`(*, N)`. + """ + + if not point.shape[-1] in (2, 3): + raise ValueError(f"pts must be a (*, 2 or 3) tensor. Got {point.shape}") + + if not line.shape[-1] == 3: + raise ValueError(f"lines must be a (*, 3) tensor. Got {line.shape}") + + numerator = (line[..., 0] * point[..., 0] + line[..., 1] * point[..., 1] + line[..., 2]) + denominator = line[..., :2].norm(dim=-1) + + return numerator / (denominator + eps) + + +def signed_left_to_right_epipolar_distance(pts1, pts2, Fm): + r"""Return one-sided epipolar distance for correspondences given the fundamental matrix. + + This method measures the distance from points in the right images to the epilines + of the corresponding points in the left images as they reflect in the right images. + + Args: + pts1: correspondences from the left images with shape + :math:`(*, N, 2 or 3)`. If they are not homogeneous, converted automatically. + pts2: correspondences from the right images with shape + :math:`(*, N, 2 or 3)`. If they are not homogeneous, converted automatically. + Fm: Fundamental matrices with shape :math:`(*, 3, 3)`. Called Fm to + avoid ambiguity with torch.nn.functional. + + Returns: + the computed Symmetrical distance with shape :math:`(*, N)`. + """ + import kornia + if (len(Fm.shape) < 3) or not Fm.shape[-2:] == (3, 3): + raise ValueError(f"Fm must be a (*, 3, 3) tensor. 
Got {Fm.shape}") + + if pts1.shape[-1] == 2: + pts1 = kornia.geometry.convert_points_to_homogeneous(pts1) + + F_t = Fm.transpose(dim0=-2, dim1=-1) + line1_in_2 = pts1 @ F_t + + return signed_point_line_distance(pts2, line1_in_2) + +def get_grid(b, h, w, device): + grid = torch.meshgrid( + *[ + torch.linspace(-1 + 1 / n, 1 - 1 / n, n, device=device) + for n in (b, h, w) + ] + ) + grid = torch.stack((grid[2], grid[1]), dim=-1).reshape(b, h, w, 2) + return grid diff --git a/imcui/third_party/MatchAnything/third_party/ROMA/setup.py b/imcui/third_party/MatchAnything/third_party/ROMA/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..ae777c0e5a41f0e4b03a838d19bc9a2bb04d4617 --- /dev/null +++ b/imcui/third_party/MatchAnything/third_party/ROMA/setup.py @@ -0,0 +1,9 @@ +from setuptools import setup + +setup( + name="roma", + packages=["roma"], + version="0.0.1", + author="Johan Edstedt", + install_requires=open("requirements.txt", "r").read().split("\n"), +) diff --git a/imcui/third_party/MatchAnything/tools/__init__.py b/imcui/third_party/MatchAnything/tools/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/imcui/third_party/MatchAnything/tools/evaluate_datasets.py b/imcui/third_party/MatchAnything/tools/evaluate_datasets.py new file mode 100644 index 0000000000000000000000000000000000000000..9ebbf2aa096367611491f3846734ff7cd9320928 --- /dev/null +++ b/imcui/third_party/MatchAnything/tools/evaluate_datasets.py @@ -0,0 +1,240 @@ +import argparse +import pytorch_lightning as pl +from tqdm import tqdm +import os.path as osp +import numpy as np +from loguru import logger +from PIL import Image +Image.MAX_IMAGE_PIXELS = None +import torch + +from torch.utils.data import ( + DataLoader, + ConcatDataset) + +from pathlib import Path +import sys +sys.path.append(str(Path(__file__).parent.parent.resolve())) + +from src.lightning.lightning_loftr import PL_LoFTR +from 
src.config.default import get_cfg_defaults +from src.utils.dataset import dict_to_cuda +from src.utils.metrics import estimate_homo, estimate_pose, relative_pose_error +from src.utils.homography_utils import warp_points + +from src.datasets.common_data_pair import CommonDataset +from src.utils.metrics import error_auc +from tools_utils.plot import plot_matches, warp_img_and_blend, epipolar_error + +from pairs_match_and_propogation.utils.data_io import save_h5 + +def parse_args(): + parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser.add_argument( + 'main_cfg_path', type=str, help='main config path') + parser.add_argument( + '--ckpt_path', type=str, default="", help='path to the checkpoint') + parser.add_argument( + '--thr', type=float, default=0.1, help='modify the coarse-level matching threshold.') + parser.add_argument( + '--method', type=str, default='loftr@-@ransac_affine', help='choose method') + parser.add_argument( + '--imgresize', type=int, default=None) + parser.add_argument( + '--npe', action='store_true', default=False, help='') + parser.add_argument( + '--npe2', action='store_true', default=False, help='') + parser.add_argument( + '--ckpt32', action='store_true', default=False, help='') + parser.add_argument( + '--fp32', action='store_true', default=False, help='') + + # Input: + parser.add_argument( + '--data_root', type=str, default="data/test_data") + + parser.add_argument( + '--npz_root', type=str, default="") + + parser.add_argument( + '--npz_list_path', type=str, default="") + + parser.add_argument( + '--plot_matches', action='store_true') + + parser.add_argument( + '--plot_matches_alpha', type=float, default=0.2) + + parser.add_argument( + '--plot_matches_color', type=str, default='error', choices=['green', 'error', 'conf']) + + parser.add_argument( + '--plot_align', action='store_true') + parser.add_argument( + '--plot_refinement', action='store_true') + parser.add_argument( + '--output_path', 
type=str, default="") + + parser.add_argument( + '--rigid_ransac_thr', type=float, default=3.0) + parser.add_argument( + '--elastix_ransac_thr', type=float, default=40.0) + parser.add_argument( + '--comment', type=str, default="") + + return parser.parse_args() + +def array_rgb2gray(img): + return (img * np.array([0.2989, 0.5870, 0.1140])[None, None]).sum(axis=-1) + +if __name__ == '__main__': + args = parse_args() + + # Load data: + datasets = [] + sub_dataset_name = Path(args.npz_list_path).parent.name + with open(args.npz_list_path, 'r') as f: + npz_names = [name.split()[0] for name in f.readlines()] + npz_names = [f'{n}.npz' for n in npz_names] + data_root = args.data_root + + vis_output_path = args.output_path + Path(vis_output_path).mkdir(parents=True, exist_ok=True) + + ########################## + config = get_cfg_defaults() + method, estimator = (args.method).split('@-@')[0], (args.method).split('@-@')[1] + if method != 'None': + config.merge_from_file(args.main_cfg_path) + + pl.seed_everything(config.TRAINER.SEED) + config.METHOD = method + # Config overwrite: + if config.LOFTR.COARSE.ROPE: + assert config.DATASET.NPE_NAME is not None + if config.DATASET.NPE_NAME is not None: + config.LOFTR.COARSE.NPE = [832, 832, args.imgresize, args.imgresize] + + if "visible_sar" in args.npz_list_path: + config.DATASET.RESIZE_BY_STRETCH = True + + if args.thr is not None: + config.LOFTR.MATCH_COARSE.THR = args.thr + + matcher = PL_LoFTR(config, pretrained_ckpt=args.ckpt_path, test_mode=True).matcher + matcher.eval().cuda() + else: + matcher = None + + for npz_name in tqdm(npz_names): + npz_path = osp.join(args.npz_root, npz_name) + try: + np.load(npz_path, allow_pickle=True) + except: + logger.info(f"{npz_path} cannot be opened!") + continue + + datasets.append( + CommonDataset(data_root, npz_path, mode='test', min_overlap_score=-1, img_resize=args.imgresize, df=None, img_padding=False, depth_padding=True, testNpairs=None, fp16=False, load_origin_rgb=True, 
read_gray=True, normalize_img=False, resize_by_stretch=config.DATASET.RESIZE_BY_STRETCH, gt_matches_padding_n=100)) + + concat_dataset = ConcatDataset(datasets) + + dataloader = DataLoader(concat_dataset, num_workers=4, pin_memory=True, batch_size=1, drop_last=False) + errors = [] # distance + result_dict = {} + pose_error = [] + + eval_mode = 'gt_homo' + for id, data in enumerate(tqdm(dataloader)): + img0, img1 = (data['image0_rgb_origin'] * 255.)[0].permute(1,2,0).numpy().round().squeeze(), (data['image1_rgb_origin'] * 255.)[0].permute(1,2,0).numpy().round().squeeze() + img_1_h, img_1_w = img1.shape[:2] + pair_name = '@-@'.join([data['pair_names'][0][0].split('/', 1)[1], data['pair_names'][1][0].split('/', 1)[1]]).replace('/', '_') + homography_gt = data['homography'][0].numpy() + if 'gt_2D_matches' in data and data["gt_2D_matches"].shape[-1] == 4: + gt_2D_matches = data["gt_2D_matches"][0].numpy() # N * 4 + eval_coord = gt_2D_matches[:, :2] + gt_points = gt_2D_matches[:, 2:] + eval_mode = 'gt_match' + ransac_mode = 'homo' if 'FIRE' in args.npz_list_path else 'affine' + elif homography_gt.sum() != 0: + h, w = img0.shape[0], img0.shape[1] + eval_coord = np.array([[0, 0], [0, h], [w, 0], [w, h]]) + ransac_mode = 'affine' + assert homography_gt.sum() != 0, f"Evaluation should either using gt match, or using gt homography warp." 
+ else: + eval_mode = 'pose_error' + K0 = data['K0'].cpu().numpy()[0] + K1 = data['K1'].cpu().numpy()[0] + T_0to1 = data['T_0to1'].cpu().numpy()[0] + estimator = 'pose' + + # Perform matching + if matcher is not None: + if eval_mode in ['gt_match']: + data.update({'query_points': torch.from_numpy(eval_coord)[None]}) + batch = dict_to_cuda(data) + + with torch.no_grad(): + with torch.autocast(enabled=config.LOFTR.FP16, device_type='cuda'): + matcher(batch) + + mkpts0 = batch['mkpts0_f'].cpu().numpy() + mkpts1 = batch['mkpts1_f'].cpu().numpy() + mconf = batch['mconf'].cpu().numpy() + + # Get warpped points by homography: + if estimator == "ransac_affine": + H_est, _ = estimate_homo(mkpts0, mkpts1, thresh=args.rigid_ransac_thr, mode=ransac_mode) + # Warp points for eval: + eval_points_warpped = warp_points(eval_coord, H_est, inverse=False) + + # Warp images and blend: + if args.plot_align: + warp_img_and_blend(img0, img1, H_est, save_path=Path(vis_output_path)/'aligned'/f"{pair_name}_{args.method}.png", alpha=0.5, inverse=True) + elif estimator == 'pose': + pose = estimate_pose(mkpts0, mkpts1, K0, K1, args.rigid_ransac_thr, conf=0.99999) + else: + raise NotImplementedError + else: + raise NotImplementedError + + if eval_mode == 'pose_error': + if pose is None: + t_err, R_err = np.inf, np.inf + else: + R, t, inliers = pose + t_err, R_err = relative_pose_error(T_0to1, R, t, ignore_gt_t_thr=0.0) + error = max(t_err, R_err) + errors.append(error) + match_error = epipolar_error(mkpts0, mkpts1, T_0to1, K0, K1) + plot_text = f"R_err_{R_err:.2}_t_err_{t_err:.2}" + thr = 3e-3 + print(f"max_error:{error}") + else: + if eval_mode == 'gt_homo': + gt_points = warp_points(eval_coord, homography_gt, inverse=False) + match_error = np.linalg.norm(warp_points(mkpts0, homography_gt, inverse=False) - mkpts1, axis=-1) + else: + match_error = None + + thr = 5 # Pix + error = np.mean(np.linalg.norm(eval_points_warpped - gt_points, axis=1)) + print(f"error: {error}") + errors.append(error) + 
+ result_dict['@-@'.join([data['pair_names'][0][0].split('/', 1)[1], data['pair_names'][1][0].split('/', 1)[1]])] = error + + if args.plot_matches and matcher is not None: + draw_match_type='corres' + color_type=args.plot_matches_color + plot_matches(img0, img1, mkpts0, mkpts1, mconf, vertical=False, draw_match_type=draw_match_type, alpha=args.plot_matches_alpha, save_path=Path(vis_output_path)/'demo_matches'/f"{pair_name}_{draw_match_type}.pdf", inverse=False, match_error=match_error if color_type == 'error' else None, error_thr=thr, color_type=color_type) + + # Success Rate Metric: + metric = error_auc(np.array(errors), thresholds=[5,10,20], method="success_rate") + print(metric) + + # AUC Metric: + metric = error_auc(np.array(errors), thresholds=[5,10,20], method='fire_paper' if 'FIRE' in args.npz_list_path else 'exact_auc') + print(metric) + + save_h5(result_dict, (Path(args.output_path) / f'eval_{sub_dataset_name}_{args.method}_{args.comment}_error.h5')) \ No newline at end of file diff --git a/imcui/third_party/MatchAnything/tools/tools_utils/plot.py b/imcui/third_party/MatchAnything/tools/tools_utils/plot.py new file mode 100644 index 0000000000000000000000000000000000000000..a7eefeb312b3954a8d05a67a264efc0e5e7713ad --- /dev/null +++ b/imcui/third_party/MatchAnything/tools/tools_utils/plot.py @@ -0,0 +1,77 @@ +import matplotlib +matplotlib.use("agg") +import matplotlib.cm as cm +import numpy as np +from PIL import Image +import cv2 +from kornia.geometry.epipolar import numeric +import torch + +from pathlib import Path +import sys +sys.path.append(str(Path(__file__).parent.parent.parent.resolve())) + +from src.utils.plotting import error_colormap, dynamic_alpha +from src.utils.metrics import symmetric_epipolar_distance +from notebooks.notebooks_utils import make_matching_figure + +def plot_matches(img0_origin, img1_origin, mkpts0, mkpts1, mconf, vertical, draw_match_type, alpha, save_path, inverse=False, match_error=None, error_thr=5e-3, color_type='error'): 
+ if inverse: + img0_origin, img1_origin, mkpts0, mkpts1 = img1_origin, img0_origin, mkpts1, mkpts0 + img0_origin = np.copy(img0_origin) / 255.0 + img1_origin = np.copy(img1_origin) / 255.0 + # Draw + alpha =dynamic_alpha(len(mkpts0), milestones=[0, 200, 500, 1000, 2000, 4000], alphas=[1.0, 0.5, 0.3, 0.2, 0.15, 0.09]) + if color_type == 'conf': + color = error_colormap(mconf, thr=None, alpha=alpha) + elif color_type == 'green': + mconf = np.ones_like(mconf) * 0.15 + color = error_colormap(mconf, thr=None, alpha=alpha) + else: + color = error_colormap(np.zeros((len(mconf),)) if match_error is None else match_error, error_thr, alpha=alpha) + + text = [ + '' + ] + + Path(save_path).parent.mkdir(parents=True, exist_ok=True) + fig = make_matching_figure(img0_origin, img1_origin, mkpts0, mkpts1, color, text=text, path=save_path, vertical=vertical, plot_size_factor=1, draw_match_type=draw_match_type, r_normalize_factor=0.4) + +def blend_img(img0, img1, alpha=0.4, save_path=None, blend_method='weighted_sum'): + img0, img1 = Image.fromarray(np.array(img0)), Image.fromarray(np.array(img1)) + Path(save_path).parent.mkdir(parents=True, exist_ok=True) + + # Blend: + if blend_method == 'weighted_sum': + blended_img = Image.blend(img0, img1, alpha=alpha) + else: + raise NotImplementedError + + blended_img.save(save_path) + +def warp_img(img0, img1, H, fill_white=False): + img0 = np.copy(img0).astype(np.uint8) + img1 = np.copy(img1).astype(np.uint8) + if fill_white: + img0_warped = cv2.warpAffine(np.array(img0), H[:2, :], [img1.shape[1], img1.shape[0]], flags=cv2.INTER_LINEAR, borderMode=cv2.BORDER_CONSTANT, borderValue=[255, 255, 255]) + else: + img0_warped = cv2.warpAffine(np.array(img0), H[:2, :], [img1.shape[1], img1.shape[0]], flags=cv2.INTER_LINEAR) + return img0_warped + +def warp_img_and_blend(img0_origin, img1_origin, H, save_path, alpha=0.4, inverse=False): + if inverse: + img0_origin, img1_origin = img1_origin, img0_origin + H = np.linalg.inv(H) + img0_origin = 
np.copy(img0_origin).astype(np.uint8) + img1_origin = np.copy(img1_origin).astype(np.uint8) + + # Warp + img0_warpped = Image.fromarray(warp_img(img0_origin, img1_origin, H, fill_white=False)) + + # Blend and save: + blend_img(img0_warpped, Image.fromarray(img1_origin), alpha=alpha, save_path=save_path) + +def epipolar_error(mkpts0, mkpts1, T_0to1, K0, K1): + Tx = numeric.cross_product_matrix(torch.from_numpy(T_0to1)[:3, 3]) + E_mat = Tx @ T_0to1[:3, :3] + return symmetric_epipolar_distance(torch.from_numpy(mkpts0), torch.from_numpy(mkpts1), E_mat, torch.from_numpy(K0), torch.from_numpy(K1)).numpy() \ No newline at end of file diff --git a/imcui/ui/app_class.py b/imcui/ui/app_class.py index f98f98575c4437b1d6cbff2d8eb165cfb7061f08..21c12f79bb545ca606015774aa68f2a661b8037a 100644 --- a/imcui/ui/app_class.py +++ b/imcui/ui/app_class.py @@ -20,11 +20,13 @@ from .utils import ( send_to_match, ) import os -GITHUB_TOKEN = os.environ.get('GITHUB_TOKEN') -GOOGLE_TOKEN = os.environ.get('GOOGLE_TOKEN') +# GITHUB_TOKEN = os.environ.get('GITHUB_TOKEN') +# GOOGLE_TOKEN = os.environ.get('GOOGLE_TOKEN') +GOOGLE_TOKEN = "12L3g9-w8rR9K2L4rYaGaDJ7NqX1D713d" if not (Path(__file__).parent / "../third_party/MatchAnything").exists(): print("**********************************") - os.system(f"cd {str(Path(__file__).parent / '../third_party')} && git clone https://{GITHUB_TOKEN}@github.com/hxy-123/MatchAnything_HF.git && mv MatchAnything_HF MatchAnything && cd MatchAnything && gdown {GOOGLE_TOKEN} && unzip weights.zip") + # os.system(f"cd {str(Path(__file__).parent / '../third_party')} && git clone https://{GITHUB_TOKEN}@github.com/hxy-123/MatchAnything_HF.git && mv MatchAnything_HF MatchAnything && cd MatchAnything && gdown {GOOGLE_TOKEN} && unzip weights.zip") + os.system(f"cd {str(Path(__file__).parent / '../third_party')} && cd MatchAnything && gdown {GOOGLE_TOKEN} && unzip weights.zip") DESCRIPTION = '''
MatchAnything