diff --git a/hloc/extractors/sfd2.py b/hloc/extractors/sfd2.py
index 9fb76eddd29347be56be162afc346b0ab9bb934a..1bd6188faa8ac8bfa647e6d5bcb3a9dfc07a2f30 100644
--- a/hloc/extractors/sfd2.py
+++ b/hloc/extractors/sfd2.py
@@ -1,4 +1,3 @@
-# -*- coding: UTF-8 -*-
import sys
from pathlib import Path
@@ -7,10 +6,9 @@ import torchvision.transforms as tvf
from .. import logger
from ..utils.base_model import BaseModel
-pram_path = Path(__file__).parent / "../../third_party/pram"
-sys.path.append(str(pram_path))
-
-from nets.sfd2 import load_sfd2
+tp_path = Path(__file__).parent / "../../third_party"
+sys.path.append(str(tp_path))
+from pram.nets.sfd2 import load_sfd2
class SFD2(BaseModel):
@@ -26,8 +24,8 @@ class SFD2(BaseModel):
self.norm_rgb = tvf.Normalize(
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
)
- model_fn = pram_path / "weights" / self.conf["model_name"]
- self.net = load_sfd2(weight_path=model_fn).eval()
+ model_path = tp_path / "pram" / "weights" / self.conf["model_name"]
+ self.net = load_sfd2(weight_path=model_path).eval()
logger.info("Load SFD2 model done.")
diff --git a/hloc/matchers/eloftr.py b/hloc/matchers/eloftr.py
index 2c1e6245eb720c5b3545f9e2f5d2a6a5a93cb95b..d22906de8bf7cc912745c21b950458829dee5d19 100644
--- a/hloc/matchers/eloftr.py
+++ b/hloc/matchers/eloftr.py
@@ -5,18 +5,22 @@ from pathlib import Path
import torch
-eloftr_path = Path(__file__).parent / "../../third_party/EfficientLoFTR"
-sys.path.append(str(eloftr_path))
+tp_path = Path(__file__).parent / "../../third_party"
+sys.path.append(str(tp_path))
-from src.loftr import LoFTR as ELoFTR_
-from src.loftr import full_default_cfg, opt_default_cfg, reparameter
+from EfficientLoFTR.src.loftr import LoFTR as ELoFTR_
+from EfficientLoFTR.src.loftr import (
+ full_default_cfg,
+ opt_default_cfg,
+ reparameter,
+)
from hloc import logger
from ..utils.base_model import BaseModel
-class LoFTR(BaseModel):
+class ELoFTR(BaseModel):
default_conf = {
"weights": "weights/eloftr_outdoor.ckpt",
"match_threshold": 0.2,
@@ -40,7 +44,7 @@ class LoFTR(BaseModel):
_default_cfg["mp"] = True
elif self.conf["precision"] == "fp16":
_default_cfg["half"] = True
- model_path = eloftr_path / self.conf["weights"]
+ model_path = tp_path / "EfficientLoFTR" / self.conf["weights"]
cfg = _default_cfg
cfg["match_coarse"]["thr"] = conf["match_threshold"]
# cfg["match_coarse"]["skh_iters"] = conf["sinkhorn_iterations"]
diff --git a/hloc/matchers/imp.py b/hloc/matchers/imp.py
index ca64980ef70c52672806476fdc65bb4d39479f10..05c3cb96b05410985ca97f89d8fe55a4d71be501 100644
--- a/hloc/matchers/imp.py
+++ b/hloc/matchers/imp.py
@@ -1,4 +1,3 @@
-# -*- coding: UTF-8 -*-
import sys
from pathlib import Path
@@ -7,10 +6,9 @@ import torch
from .. import DEVICE, logger
from ..utils.base_model import BaseModel
-pram_path = Path(__file__).parent / "../../third_party/pram"
-sys.path.append(str(pram_path))
-
-from nets.gml import GML
+tp_path = Path(__file__).parent / "../../third_party"
+sys.path.append(str(tp_path))
+from pram.nets.gml import GML
class IMP(BaseModel):
@@ -33,7 +31,8 @@ class IMP(BaseModel):
def _init(self, conf):
self.conf = {**self.default_conf, **conf}
- weight_path = pram_path / "weights" / self.conf["model_name"]
+ weight_path = tp_path / "pram" / "weights" / self.conf["model_name"]
+ # self.net = nets.gml(self.conf).eval().to(DEVICE)
self.net = GML(self.conf).eval().to(DEVICE)
self.net.load_state_dict(
torch.load(weight_path, map_location="cpu")["model"], strict=True
diff --git a/third_party/pram/.gitignore b/third_party/pram/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..e76db3ee25df1858b0cec129d3e7c0eb84637c09
--- /dev/null
+++ b/third_party/pram/.gitignore
@@ -0,0 +1,13 @@
+.idea
+__pycache__
+weights/12scenes*
+weights/7scenes*
+weights/aachen*
+weights/cambridgelandmarks*
+weights/imp_adagml.80.pth
+landmarks
+3D-models
+log_*
+*.log
+.nfs*
+Pangolin
diff --git a/third_party/pram/LICENSE b/third_party/pram/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..0bde2a83689b0ae97269181bc848fd581d23e828
--- /dev/null
+++ b/third_party/pram/LICENSE
@@ -0,0 +1,2 @@
+This work is licensed under the Creative Commons Attribution-NonCommercial 4.0 International License.
+To view a copy of this license, visit http://creativecommons.org/licenses/by-nc/4.0/.
diff --git a/third_party/pram/README.md b/third_party/pram/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..b8ceb745c82fd44f1ef2c1808ab3993fb4d3890c
--- /dev/null
+++ b/third_party/pram/README.md
@@ -0,0 +1,207 @@
+## PRAM: Place Recognition Anywhere Model for Efficient Visual Localization
+
+
+
+
+
+Humans localize themselves efficiently in known environments by first recognizing landmarks defined on certain objects
+and their spatial relationships, and then verifying the location by aligning detailed structures of recognized objects
+with those in memory. Inspired by this, we propose the place recognition anywhere model (PRAM) to perform visual
+localization as efficiently as humans do. PRAM consists of two main components: recognition and registration. First, a
+self-supervised, map-centric landmark definition strategy is adopted, making places in either indoor or outdoor scenes
+act as unique landmarks. Then, sparse keypoints extracted from images are used as input to a transformer-based deep
+neural network for landmark recognition; these keypoints enable PRAM to recognize hundreds of landmarks with high time
+and memory efficiency. Keypoints along with their recognized landmark labels are further used for registration between
+query images and the 3D landmark map. Unlike previous hierarchical methods, PRAM discards global and local descriptors
+and reduces storage by over 90%. Since PRAM replaces global reference search and exhaustive matching with recognition
+and landmark-wise verification respectively, it runs 2.4 times faster than prior state-of-the-art approaches. Moreover,
+PRAM opens new directions for visual localization, including multi-modality localization, map-centric feature learning,
+and hierarchical scene coordinate regression.
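+
+To make the two-stage design concrete, the minimal Python sketch below illustrates the flow described above. It is a
+conceptual illustration only: `recognition_net`, `landmark_map`, and its `match()` method are hypothetical placeholders,
+not the released PRAM API.
+
+```
+# Illustrative only: recognition-then-registration without global descriptors.
+import numpy as np
+
+
+def localize_query(keypoints, descriptors, recognition_net, landmark_map):
+    """keypoints: (N, 2) SFD2 keypoints; descriptors: (N, D) SFD2 descriptors."""
+    # Coarse stage: predict a landmark label per keypoint instead of retrieving
+    # reference images with a global descriptor.
+    labels = recognition_net(keypoints, descriptors)  # (N,) landmark ids
+
+    # Fine stage: match only against the sparsified 3D points of the recognized
+    # landmarks (landmark-wise matching instead of exhaustive matching).
+    p2d, p3d = [], []
+    for lid in np.unique(labels):
+        sel = labels == lid
+        m2d, m3d = landmark_map[lid].match(keypoints[sel], descriptors[sel])
+        p2d.append(m2d)
+        p3d.append(m3d)
+
+    # Registration: the 2D-3D correspondences go to PnP + RANSAC (e.g. pycolmap)
+    # to recover the 6-DoF camera pose.
+    return np.concatenate(p2d), np.concatenate(p3d)
+```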
+
+* Full paper
+ PDF: [Place Recognition Anywhere Model for Efficient Visual Localization](https://arxiv.org/pdf/2404.07785.pdf).
+
+* Authors: *Fei Xue, Ignas Budvytis, Roberto Cipolla*
+
+* Website: [PRAM](https://feixue94.github.io/pram-project) for videos, slides, recent updates, and datasets.
+
+## Key Features
+
+### 1. Self-supervised landmark definition in 3D space
+
+- No need for segmentation on images
+- No inconsistent semantic results across multi-view images
+- Not limited to the labels of known objects
+- Works in any place, with known or unknown objects
+- Landmark-wise 3D map sparsification (a minimal clustering sketch follows below)
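+
+As a rough illustration of the idea (an assumption-level sketch, not the code in `recognition/recmap.py`), defining
+landmarks reduces to clustering the SfM point cloud directly in 3D space; BIRCH is the clustering method named in the
+training configs (`cluster_method: 'birch'`):
+
+```
+# Illustrative only: landmarks as 3D clusters of the SfM point cloud.
+import numpy as np
+from sklearn.cluster import Birch
+
+
+def define_landmarks(points3d: np.ndarray, n_landmarks: int = 128) -> np.ndarray:
+    """points3d: (M, 3) SfM point cloud; returns one landmark id per 3D point."""
+    # Clustering happens purely in 3D, so no image segmentation or known object
+    # classes are needed; each 2D observation inherits the label of its 3D point.
+    return Birch(n_clusters=n_landmarks).fit_predict(points3d)
+```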
+
+
+
+
+
+### 2. Efficient landmark-wise coarse and fine localization
+
+- Recognize landmarks instead of performing global retrieval
+- Local landmark-wise matching instead of exhaustive matching
+- No global descriptors (e.g. NetVLAD)
+- No reference images with their heavy, repetitive 2D keypoints and descriptors
+- Automatic inlier/outlier identification
+
+
+
+
+
+### 3. Sparse recognition
+
+- Sparse SFD2 keypoints as tokens
+- No uncertainty from points at object boundaries
+- Flexible enough to accept multi-modality inputs (see the sketch below)
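+
+The minimal PyTorch sketch below shows how sparse keypoints can serve as transformer tokens for per-keypoint landmark
+classification. It is an illustrative assumption, not the actual SegNetViT architecture, although `feat_dim: 128` and
+`n_class: 193` follow the training configs:
+
+```
+# Illustrative only: a per-keypoint landmark classifier with keypoints as tokens.
+import torch
+import torch.nn as nn
+
+
+class SparseRecognizer(nn.Module):
+    def __init__(self, feat_dim: int = 128, d_model: int = 256, n_class: int = 193):
+        super().__init__()
+        self.embed = nn.Linear(feat_dim + 2, d_model)  # descriptor + (x, y) position
+        layer = nn.TransformerEncoderLayer(d_model, nhead=8, batch_first=True)
+        self.encoder = nn.TransformerEncoder(layer, num_layers=4)
+        self.head = nn.Linear(d_model, n_class)  # one landmark label per keypoint
+
+    def forward(self, kpts: torch.Tensor, desc: torch.Tensor) -> torch.Tensor:
+        # kpts: (B, N, 2) normalized keypoints; desc: (B, N, feat_dim) descriptors
+        tokens = self.embed(torch.cat([desc, kpts], dim=-1))
+        return self.head(self.encoder(tokens))  # (B, N, n_class) landmark logits
+```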
+
+### 4. Relocalization and temporal localization
+
+- Per-frame relocalization from scratch
+- Tracking previous frames for higher efficiency
+
+### 5. One model one dataset
+
+- All 7 subscenes in the 7Scenes dataset share one model
+- All 12 subscenes in the 12Scenes dataset share one model
+- All 5 subscenes in CambridgeLandmarks share one model
+
+### 6. Robust to long-term changes
+
+
+
+
+
+## Open problems
+
+- Adaptive determination of the number of landmarks
+- Using SAM + open vocabulary to generate semantic maps
+- Multi-modality localization with other tokenized signals (e.g. text, language, GPS, magnetometer)
+- More effective solutions for 3D map sparsification
+
+## Preparation
+
+1. Download the 7Scenes, 12Scenes, CambridgeLandmarks, and Aachen datasets (remove redundant depth images, otherwise they
+   will be included in the SfM process)
+2. Environments
+
+2.1 Create a virtual environment
+
+```
+conda env create -f environment.yml
+# do not activate the pram environment before Pangolin is installed
+```
+
+2.2 Compile Pangolin for the installed python
+
+```
+git clone --recursive https://github.com/stevenlovegrove/Pangolin.git
+cd Pangolin
+git checkout v0.8
+
+# Install dependencies
+./scripts/install_prerequisites.sh recommended
+
+# Compile with your python
+cmake -DPython_EXECUTABLE=/your path to/anaconda3/envs/pram/bin/python3 -B build
+cmake --build build -t pypangolin_pip_install
+
+conda activate pram
+```
+
+## Run the localization with online visualization
+
+1. Download the [3D-models](https://drive.google.com/drive/folders/1DUB073KxAjsc8lxhMpFuxPRf0ZBQS6NS?usp=drive_link),
+   pretrained [models](https://drive.google.com/drive/folders/1E2QvujCevqnyg_CM9FGAa0AxKkt4KbLD?usp=drive_link),
+   and [landmarks](https://drive.google.com/drive/folders/1r9src9bz7k3WYGfaPmKJ9gqxuvdfxZU0?usp=sharing)
+2. Put the pretrained models in the ```weights``` directory
+3. Run the demo (e.g. 7Scenes)
+
+```
+python3 inference.py --config configs/config_train_7scenes_sfd2.yaml --rec_weight_path weights/7scenes_nc113_birch_segnetvit.199.pth --landmark_path /your path to/landmarks --online
+```
+
+## Train the recognition model (e.g. for 7Scenes)
+
+### 1. Do SfM with SFD2 including feature extraction (modify the dataset_dir, ref_sfm_dir, output_dir)
+
+```
+./sfm_scripts/reconstruct_7scenes.sh
+```
+
+This step will produce the SfM results together with the extracted keypoints
+
+### 2. Generate 3D landmarks
+
+```
+python3 -m recognition.recmap --dataset 7Scenes --dataset_dir /your path to/7Scenes --sfm_dir /sfm_path/7Scenes --save_dir /save_path/landmarks
+```
+
+This step will generate 3D landmarks, create virtual reference frames (VRFs), and sparsify the 3D points of each
+landmark for all scenes in 7Scenes
+
+### 3. Train the sparse recognition model (one model one dataset)
+
+```
+python3 train.py --config configs/config_train_7scenes_sfd2.yaml
+```
+
+Remember to modify the paths in 'config_train_7scenes_sfd2.yaml'
+
+## Your own dataset
+
+1. Run colmap or hloc to obtain the SfM results
+2. Redo the reconstruction with SFD2 keypoints, using the SfM from step 1 as the reference SfM
+3. Do 3D landmark generation, VRF creation, map sparsification, etc. (add DatasetName.yaml to configs/datasets)
+4. Train the recognition model
+5. Do evaluation (a rough end-to-end sketch of steps 2-4 follows below)
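+
+A rough end-to-end sketch of steps 2-4 is given below. It only chains the commands already shown in this README; every
+path, the dataset name, and the config name are placeholders to adapt:
+
+```
+# Illustrative only: all paths and names below are placeholders.
+import subprocess
+
+dataset = "MyDataset"  # also add a hypothetical configs/datasets/MyDataset.yaml
+
+# Step 2: SfM with SFD2 keypoints (adapt sfm_scripts/reconstruct_7scenes.sh to your data)
+subprocess.run(["bash", "sfm_scripts/reconstruct_7scenes.sh"], check=True)
+
+# Step 3: 3D landmark generation, VRF creation, and map sparsification
+subprocess.run([
+    "python3", "-m", "recognition.recmap",
+    "--dataset", dataset,
+    "--dataset_dir", f"/your_path/{dataset}",
+    "--sfm_dir", f"/sfm_path/{dataset}",
+    "--save_dir", "/save_path/landmarks",
+], check=True)
+
+# Step 4: train the recognition model with a dataset-specific config (hypothetical name)
+subprocess.run([
+    "python3", "train.py",
+    "--config", "configs/config_train_mydataset_sfd2.yaml",
+], check=True)
+```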
+
+## Previous works can be found here
+
+1. [Efficient large-scale localization by landmark recognition, CVPR 2022](https://github.com/feixue94/lbr)
+2. [IMP: Iterative Matching and Pose Estimation with Adaptive Pooling, CVPR 2023](https://github.com/feixue94/imp-release)
+3. [SFD2: Semantic-guided Feature Detection and Description, CVPR 2023](https://github.com/feixue94/sfd2)
+4. [VRS-NeRF: Visual Relocalization with Sparse Neural Radiance Field, under review](https://github.com/feixue94/vrs-nerf)
+
+## BibTeX Citation
+
+If you use any ideas from the paper or code in this repo, please consider citing:
+
+```
+@article{xue2024pram,
+ author = {Fei Xue and Ignas Budvytis and Roberto Cipolla},
+ title = {PRAM: Place Recognition Anywhere Model for Efficient Visual Localization},
+ journal = {arXiv preprint arXiv:2404.07785},
+ year = {2024}
+}
+
+@inproceedings{xue2023sfd2,
+ author = {Fei Xue and Ignas Budvytis and Roberto Cipolla},
+ title = {SFD2: Semantic-guided Feature Detection and Description},
+ booktitle = {CVPR},
+ year = {2023}
+}
+
+@inproceedings{xue2022imp,
+ author = {Fei Xue and Ignas Budvytis and Roberto Cipolla},
+ title = {IMP: Iterative Matching and Pose Estimation with Adaptive Pooling},
+ booktitle = {CVPR},
+ year = {2023}
+}
+
+@inproceedings{xue2022efficient,
+ author = {Fei Xue and Ignas Budvytis and Daniel Olmeda Reino and Roberto Cipolla},
+ title = {Efficient Large-scale Localization by Global Instance Recognition},
+ booktitle = {CVPR},
+ year = {2022}
+}
+```
+
+## Acknowledgements
+
+Part of the code is adapted from previous excellent works,
+including [SuperGlue](https://github.com/magicleap/SuperGluePretrainedNetwork)
+and [hloc](https://github.com/cvg/Hierarchical-Localization). Please refer to their released
+repositories for more details.
\ No newline at end of file
diff --git a/third_party/pram/assets/map_sparsification.gif b/third_party/pram/assets/map_sparsification.gif
new file mode 100644
index 0000000000000000000000000000000000000000..63133a4b49805d0311aec8572fc10482f21d97f1
--- /dev/null
+++ b/third_party/pram/assets/map_sparsification.gif
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fd7bbe3b0bad7c6ae330eaa702b2839533a6f27ad5a0b104c4a37597c0c37aad
+size 493481
diff --git a/third_party/pram/assets/multi_recognition.png b/third_party/pram/assets/multi_recognition.png
new file mode 100644
index 0000000000000000000000000000000000000000..7b12f484fb23daccd0bc83509db99fdf200fe79b
--- /dev/null
+++ b/third_party/pram/assets/multi_recognition.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c84e81cb990adedc25ef612b31d1ec53f7cb9f2168ef2246f2f03ca479cca9cf
+size 2460085
diff --git a/third_party/pram/assets/overview.png b/third_party/pram/assets/overview.png
new file mode 100644
index 0000000000000000000000000000000000000000..e5cc9c60f72a7590dace5db4e29eb848f0676b40
--- /dev/null
+++ b/third_party/pram/assets/overview.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:466b1f2b6a38cb956a389c1fc69c213c1655579c0c944174b6e95e247209eedc
+size 662283
diff --git a/third_party/pram/assets/pipeline1.png b/third_party/pram/assets/pipeline1.png
new file mode 100644
index 0000000000000000000000000000000000000000..780d9639033cb33aa765b571b486be9b96a44b9b
--- /dev/null
+++ b/third_party/pram/assets/pipeline1.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0bd0545bc3f4814d4b9f18893965529a08a73263e80a3978755162935e05d2b3
+size 3990973
diff --git a/third_party/pram/assets/pram_demo.gif b/third_party/pram/assets/pram_demo.gif
new file mode 100644
index 0000000000000000000000000000000000000000..5200c873d71e32a1013a9213e5406a194e0462c8
--- /dev/null
+++ b/third_party/pram/assets/pram_demo.gif
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:95e56e33824789b650f4760b4246eca89c9cd1a8c138afc2d2ab5e24ec665fac
+size 14654499
diff --git a/third_party/pram/assets/sam_openvoc.png b/third_party/pram/assets/sam_openvoc.png
new file mode 100644
index 0000000000000000000000000000000000000000..aabb6e166dce60f09acbb2578e526eb573f7a1e4
--- /dev/null
+++ b/third_party/pram/assets/sam_openvoc.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b3e0b06b6917402ed010cd4054e2efcf75c04ede84be53f17d147e2dd388d15a
+size 1148808
diff --git a/third_party/pram/colmap_utils/camera_intrinsics.py b/third_party/pram/colmap_utils/camera_intrinsics.py
new file mode 100644
index 0000000000000000000000000000000000000000..41bdc5055dfb451fa1f4dac3f27931675b68333f
--- /dev/null
+++ b/third_party/pram/colmap_utils/camera_intrinsics.py
@@ -0,0 +1,30 @@
+# -*- coding: UTF-8 -*-
+'''=================================================
+@Project -> File localizer -> camera_intrinsics
+@IDE PyCharm
+@Author fx221@cam.ac.uk
+@Date 15/08/2023 12:33
+=================================================='''
+import numpy as np
+
+
+def intrinsics_from_camera(camera_model, params):
+ if camera_model in ("SIMPLE_PINHOLE", "SIMPLE_RADIAL", "RADIAL"):
+ fx = fy = params[0]
+ cx = params[1]
+ cy = params[2]
+ elif camera_model in ("PINHOLE", "OPENCV", "OPENCV_FISHEYE", "FULL_OPENCV"):
+ fx = params[0]
+ fy = params[1]
+ cx = params[2]
+ cy = params[3]
+ else:
+ raise Exception("Camera model not supported")
+
+ # intrinsics
+ K = np.identity(3)
+ K[0, 0] = fx
+ K[1, 1] = fy
+ K[0, 2] = cx
+ K[1, 2] = cy
+ return K
diff --git a/third_party/pram/colmap_utils/database.py b/third_party/pram/colmap_utils/database.py
new file mode 100644
index 0000000000000000000000000000000000000000..37638347834f4b0b1432846adf9a83693b509a7f
--- /dev/null
+++ b/third_party/pram/colmap_utils/database.py
@@ -0,0 +1,352 @@
+# Copyright (c) 2018, ETH Zurich and UNC Chapel Hill.
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+#
+# * Redistributions in binary form must reproduce the above copyright
+# notice, this list of conditions and the following disclaimer in the
+# documentation and/or other materials provided with the distribution.
+#
+# * Neither the name of ETH Zurich and UNC Chapel Hill nor the names of
+# its contributors may be used to endorse or promote products derived
+# from this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE
+# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+# POSSIBILITY OF SUCH DAMAGE.
+#
+# Author: Johannes L. Schoenberger (jsch-at-demuc-dot-de)
+
+# This script is based on an original implementation by True Price.
+
+import sys
+import sqlite3
+import numpy as np
+
+
+IS_PYTHON3 = sys.version_info[0] >= 3
+
+MAX_IMAGE_ID = 2**31 - 1
+
+CREATE_CAMERAS_TABLE = """CREATE TABLE IF NOT EXISTS cameras (
+ camera_id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL,
+ model INTEGER NOT NULL,
+ width INTEGER NOT NULL,
+ height INTEGER NOT NULL,
+ params BLOB,
+ prior_focal_length INTEGER NOT NULL)"""
+
+CREATE_DESCRIPTORS_TABLE = """CREATE TABLE IF NOT EXISTS descriptors (
+ image_id INTEGER PRIMARY KEY NOT NULL,
+ rows INTEGER NOT NULL,
+ cols INTEGER NOT NULL,
+ data BLOB,
+ FOREIGN KEY(image_id) REFERENCES images(image_id) ON DELETE CASCADE)"""
+
+CREATE_IMAGES_TABLE = """CREATE TABLE IF NOT EXISTS images (
+ image_id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL,
+ name TEXT NOT NULL UNIQUE,
+ camera_id INTEGER NOT NULL,
+ prior_qw REAL,
+ prior_qx REAL,
+ prior_qy REAL,
+ prior_qz REAL,
+ prior_tx REAL,
+ prior_ty REAL,
+ prior_tz REAL,
+ CONSTRAINT image_id_check CHECK(image_id >= 0 and image_id < {}),
+ FOREIGN KEY(camera_id) REFERENCES cameras(camera_id))
+""".format(MAX_IMAGE_ID)
+
+CREATE_TWO_VIEW_GEOMETRIES_TABLE = """
+CREATE TABLE IF NOT EXISTS two_view_geometries (
+ pair_id INTEGER PRIMARY KEY NOT NULL,
+ rows INTEGER NOT NULL,
+ cols INTEGER NOT NULL,
+ data BLOB,
+ config INTEGER NOT NULL,
+ F BLOB,
+ E BLOB,
+ H BLOB)
+"""
+
+CREATE_KEYPOINTS_TABLE = """CREATE TABLE IF NOT EXISTS keypoints (
+ image_id INTEGER PRIMARY KEY NOT NULL,
+ rows INTEGER NOT NULL,
+ cols INTEGER NOT NULL,
+ data BLOB,
+ FOREIGN KEY(image_id) REFERENCES images(image_id) ON DELETE CASCADE)
+"""
+
+CREATE_MATCHES_TABLE = """CREATE TABLE IF NOT EXISTS matches (
+ pair_id INTEGER PRIMARY KEY NOT NULL,
+ rows INTEGER NOT NULL,
+ cols INTEGER NOT NULL,
+ data BLOB)"""
+
+CREATE_NAME_INDEX = \
+ "CREATE UNIQUE INDEX IF NOT EXISTS index_name ON images(name)"
+
+CREATE_ALL = "; ".join([
+ CREATE_CAMERAS_TABLE,
+ CREATE_IMAGES_TABLE,
+ CREATE_KEYPOINTS_TABLE,
+ CREATE_DESCRIPTORS_TABLE,
+ CREATE_MATCHES_TABLE,
+ CREATE_TWO_VIEW_GEOMETRIES_TABLE,
+ CREATE_NAME_INDEX
+])
+
+
+def image_ids_to_pair_id(image_id1, image_id2):
+ if image_id1 > image_id2:
+ image_id1, image_id2 = image_id2, image_id1
+ return image_id1 * MAX_IMAGE_ID + image_id2
+
+
+def pair_id_to_image_ids(pair_id):
+ image_id2 = pair_id % MAX_IMAGE_ID
+ image_id1 = (pair_id - image_id2) / MAX_IMAGE_ID
+ return image_id1, image_id2
+
+
+def array_to_blob(array):
+    if IS_PYTHON3:
+        return array.tobytes()
+    else:
+        return np.getbuffer(array)
+
+
+def blob_to_array(blob, dtype, shape=(-1,)):
+    if IS_PYTHON3:
+        return np.frombuffer(blob, dtype=dtype).reshape(*shape)
+    else:
+        return np.frombuffer(blob, dtype=dtype).reshape(*shape)
+
+
+class COLMAPDatabase(sqlite3.Connection):
+
+ @staticmethod
+ def connect(database_path):
+ return sqlite3.connect(str(database_path), factory=COLMAPDatabase)
+
+
+ def __init__(self, *args, **kwargs):
+ super(COLMAPDatabase, self).__init__(*args, **kwargs)
+
+ self.create_tables = lambda: self.executescript(CREATE_ALL)
+ self.create_cameras_table = \
+ lambda: self.executescript(CREATE_CAMERAS_TABLE)
+ self.create_descriptors_table = \
+ lambda: self.executescript(CREATE_DESCRIPTORS_TABLE)
+ self.create_images_table = \
+ lambda: self.executescript(CREATE_IMAGES_TABLE)
+ self.create_two_view_geometries_table = \
+ lambda: self.executescript(CREATE_TWO_VIEW_GEOMETRIES_TABLE)
+ self.create_keypoints_table = \
+ lambda: self.executescript(CREATE_KEYPOINTS_TABLE)
+ self.create_matches_table = \
+ lambda: self.executescript(CREATE_MATCHES_TABLE)
+ self.create_name_index = lambda: self.executescript(CREATE_NAME_INDEX)
+
+ def add_camera(self, model, width, height, params,
+ prior_focal_length=False, camera_id=None):
+ params = np.asarray(params, np.float64)
+ cursor = self.execute(
+ "INSERT INTO cameras VALUES (?, ?, ?, ?, ?, ?)",
+ (camera_id, model, width, height, array_to_blob(params),
+ prior_focal_length))
+ return cursor.lastrowid
+
+ def add_image(self, name, camera_id,
+ prior_q=np.zeros(4), prior_t=np.zeros(3), image_id=None):
+ cursor = self.execute(
+ "INSERT INTO images VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
+ (image_id, name, camera_id, prior_q[0], prior_q[1], prior_q[2],
+ prior_q[3], prior_t[0], prior_t[1], prior_t[2]))
+ return cursor.lastrowid
+
+ def add_keypoints(self, image_id, keypoints):
+ assert(len(keypoints.shape) == 2)
+ assert(keypoints.shape[1] in [2, 4, 6])
+
+ keypoints = np.asarray(keypoints, np.float32)
+ self.execute(
+ "INSERT INTO keypoints VALUES (?, ?, ?, ?)",
+ (image_id,) + keypoints.shape + (array_to_blob(keypoints),))
+
+ def add_descriptors(self, image_id, descriptors):
+ descriptors = np.ascontiguousarray(descriptors, np.uint8)
+ self.execute(
+ "INSERT INTO descriptors VALUES (?, ?, ?, ?)",
+ (image_id,) + descriptors.shape + (array_to_blob(descriptors),))
+
+ def add_matches(self, image_id1, image_id2, matches):
+ assert(len(matches.shape) == 2)
+ assert(matches.shape[1] == 2)
+
+ if image_id1 > image_id2:
+ matches = matches[:,::-1]
+
+ pair_id = image_ids_to_pair_id(image_id1, image_id2)
+ matches = np.asarray(matches, np.uint32)
+ self.execute(
+ "INSERT INTO matches VALUES (?, ?, ?, ?)",
+ (pair_id,) + matches.shape + (array_to_blob(matches),))
+
+ def add_two_view_geometry(self, image_id1, image_id2, matches,
+ F=np.eye(3), E=np.eye(3), H=np.eye(3), config=2):
+ assert(len(matches.shape) == 2)
+ assert(matches.shape[1] == 2)
+
+ if image_id1 > image_id2:
+ matches = matches[:,::-1]
+
+ pair_id = image_ids_to_pair_id(image_id1, image_id2)
+ matches = np.asarray(matches, np.uint32)
+ F = np.asarray(F, dtype=np.float64)
+ E = np.asarray(E, dtype=np.float64)
+ H = np.asarray(H, dtype=np.float64)
+ self.execute(
+ "INSERT INTO two_view_geometries VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
+ (pair_id,) + matches.shape + (array_to_blob(matches), config,
+ array_to_blob(F), array_to_blob(E), array_to_blob(H)))
+
+
+def example_usage():
+ import os
+ import argparse
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--database_path", default="database.db")
+ args = parser.parse_args()
+
+ if os.path.exists(args.database_path):
+ print("ERROR: database path already exists -- will not modify it.")
+ return
+
+ # Open the database.
+
+ db = COLMAPDatabase.connect(args.database_path)
+
+ # For convenience, try creating all the tables upfront.
+
+ db.create_tables()
+
+ # Create dummy cameras.
+
+ model1, width1, height1, params1 = \
+ 0, 1024, 768, np.array((1024., 512., 384.))
+ model2, width2, height2, params2 = \
+ 2, 1024, 768, np.array((1024., 512., 384., 0.1))
+
+ camera_id1 = db.add_camera(model1, width1, height1, params1)
+ camera_id2 = db.add_camera(model2, width2, height2, params2)
+
+ # Create dummy images.
+
+ image_id1 = db.add_image("image1.png", camera_id1)
+ image_id2 = db.add_image("image2.png", camera_id1)
+ image_id3 = db.add_image("image3.png", camera_id2)
+ image_id4 = db.add_image("image4.png", camera_id2)
+
+ # Create dummy keypoints.
+ #
+ # Note that COLMAP supports:
+ # - 2D keypoints: (x, y)
+ # - 4D keypoints: (x, y, theta, scale)
+ # - 6D affine keypoints: (x, y, a_11, a_12, a_21, a_22)
+
+ num_keypoints = 1000
+ keypoints1 = np.random.rand(num_keypoints, 2) * (width1, height1)
+ keypoints2 = np.random.rand(num_keypoints, 2) * (width1, height1)
+ keypoints3 = np.random.rand(num_keypoints, 2) * (width2, height2)
+ keypoints4 = np.random.rand(num_keypoints, 2) * (width2, height2)
+
+ db.add_keypoints(image_id1, keypoints1)
+ db.add_keypoints(image_id2, keypoints2)
+ db.add_keypoints(image_id3, keypoints3)
+ db.add_keypoints(image_id4, keypoints4)
+
+ # Create dummy matches.
+
+ M = 50
+ matches12 = np.random.randint(num_keypoints, size=(M, 2))
+ matches23 = np.random.randint(num_keypoints, size=(M, 2))
+ matches34 = np.random.randint(num_keypoints, size=(M, 2))
+
+ db.add_matches(image_id1, image_id2, matches12)
+ db.add_matches(image_id2, image_id3, matches23)
+ db.add_matches(image_id3, image_id4, matches34)
+
+ # Commit the data to the file.
+
+ db.commit()
+
+ # Read and check cameras.
+
+ rows = db.execute("SELECT * FROM cameras")
+
+ camera_id, model, width, height, params, prior = next(rows)
+ params = blob_to_array(params, np.float64)
+ assert camera_id == camera_id1
+ assert model == model1 and width == width1 and height == height1
+ assert np.allclose(params, params1)
+
+ camera_id, model, width, height, params, prior = next(rows)
+ params = blob_to_array(params, np.float64)
+ assert camera_id == camera_id2
+ assert model == model2 and width == width2 and height == height2
+ assert np.allclose(params, params2)
+
+ # Read and check keypoints.
+
+ keypoints = dict(
+ (image_id, blob_to_array(data, np.float32, (-1, 2)))
+ for image_id, data in db.execute(
+ "SELECT image_id, data FROM keypoints"))
+
+ assert np.allclose(keypoints[image_id1], keypoints1)
+ assert np.allclose(keypoints[image_id2], keypoints2)
+ assert np.allclose(keypoints[image_id3], keypoints3)
+ assert np.allclose(keypoints[image_id4], keypoints4)
+
+ # Read and check matches.
+
+ pair_ids = [image_ids_to_pair_id(*pair) for pair in
+ ((image_id1, image_id2),
+ (image_id2, image_id3),
+ (image_id3, image_id4))]
+
+ matches = dict(
+ (pair_id_to_image_ids(pair_id),
+ blob_to_array(data, np.uint32, (-1, 2)))
+ for pair_id, data in db.execute("SELECT pair_id, data FROM matches")
+ )
+
+ assert np.all(matches[(image_id1, image_id2)] == matches12)
+ assert np.all(matches[(image_id2, image_id3)] == matches23)
+ assert np.all(matches[(image_id3, image_id4)] == matches34)
+
+ # Clean up.
+
+ db.close()
+
+ if os.path.exists(args.database_path):
+ os.remove(args.database_path)
+
+
+if __name__ == "__main__":
+ example_usage()
\ No newline at end of file
diff --git a/third_party/pram/colmap_utils/geometry.py b/third_party/pram/colmap_utils/geometry.py
new file mode 100644
index 0000000000000000000000000000000000000000..0d48f0a9545f04300f0f914515e650bb60957296
--- /dev/null
+++ b/third_party/pram/colmap_utils/geometry.py
@@ -0,0 +1,17 @@
+# -*- coding: UTF-8 -*-
+import numpy as np
+import pycolmap
+
+
+def to_homogeneous(p):
+ return np.pad(p, ((0, 0),) * (p.ndim - 1) + ((0, 1),), constant_values=1)
+
+
+def compute_epipolar_errors(j_from_i: pycolmap.Rigid3d, p2d_i, p2d_j):
+ j_E_i = j_from_i.essential_matrix()
+ l2d_j = to_homogeneous(p2d_i) @ j_E_i.T
+ l2d_i = to_homogeneous(p2d_j) @ j_E_i
+ dist = np.abs(np.sum(to_homogeneous(p2d_i) * l2d_i, axis=1))
+ errors_i = dist / np.linalg.norm(l2d_i[:, :2], axis=1)
+ errors_j = dist / np.linalg.norm(l2d_j[:, :2], axis=1)
+ return errors_i, errors_j
diff --git a/third_party/pram/colmap_utils/io.py b/third_party/pram/colmap_utils/io.py
new file mode 100644
index 0000000000000000000000000000000000000000..1ad46c685ca2a2fbb166d22884948f3fd6547368
--- /dev/null
+++ b/third_party/pram/colmap_utils/io.py
@@ -0,0 +1,78 @@
+# -*- coding: UTF-8 -*-
+from pathlib import Path
+from typing import Tuple
+
+import cv2
+import h5py
+import numpy as np
+
+from .parsers import names_to_pair, names_to_pair_old
+
+
+def read_image(path, grayscale=False):
+ if grayscale:
+ mode = cv2.IMREAD_GRAYSCALE
+ else:
+ mode = cv2.IMREAD_COLOR
+ image = cv2.imread(str(path), mode)
+ if image is None:
+ raise ValueError(f"Cannot read image {path}.")
+ if not grayscale and len(image.shape) == 3:
+ image = image[:, :, ::-1] # BGR to RGB
+ return image
+
+
+def list_h5_names(path):
+ names = []
+ with h5py.File(str(path), "r", libver="latest") as fd:
+ def visit_fn(_, obj):
+ if isinstance(obj, h5py.Dataset):
+ names.append(obj.parent.name.strip("/"))
+
+ fd.visititems(visit_fn)
+ return list(set(names))
+
+
+def get_keypoints(
+ path: Path, name: str, return_uncertainty: bool = False
+) -> np.ndarray:
+ with h5py.File(str(path), "r", libver="latest") as hfile:
+ dset = hfile[name]["keypoints"]
+ p = dset.__array__()
+ uncertainty = dset.attrs.get("uncertainty")
+ if return_uncertainty:
+ return p, uncertainty
+ return p
+
+
+def find_pair(hfile: h5py.File, name0: str, name1: str):
+ pair = names_to_pair(name0, name1)
+ if pair in hfile:
+ return pair, False
+ pair = names_to_pair(name1, name0)
+ if pair in hfile:
+ return pair, True
+ # older, less efficient format
+ pair = names_to_pair_old(name0, name1)
+ if pair in hfile:
+ return pair, False
+ pair = names_to_pair_old(name1, name0)
+ if pair in hfile:
+ return pair, True
+ raise ValueError(
+ f"Could not find pair {(name0, name1)}... "
+ "Maybe you matched with a different list of pairs? "
+ )
+
+
+def get_matches(path: Path, name0: str, name1: str) -> Tuple[np.ndarray]:
+ with h5py.File(str(path), "r", libver="latest") as hfile:
+ pair, reverse = find_pair(hfile, name0, name1)
+ matches = hfile[pair]["matches0"].__array__()
+ scores = hfile[pair]["matching_scores0"].__array__()
+ idx = np.where(matches != -1)[0]
+ matches = np.stack([idx, matches[idx]], -1)
+ if reverse:
+ matches = np.flip(matches, -1)
+ scores = scores[idx]
+ return matches, scores
diff --git a/third_party/pram/colmap_utils/parsers.py b/third_party/pram/colmap_utils/parsers.py
new file mode 100644
index 0000000000000000000000000000000000000000..3e9087d78cc8cf7f1e81ab8359862227c3882786
--- /dev/null
+++ b/third_party/pram/colmap_utils/parsers.py
@@ -0,0 +1,73 @@
+# -*- coding: UTF-8 -*-
+
+from pathlib import Path
+import logging
+import numpy as np
+from collections import defaultdict
+
+
+def parse_image_lists_with_intrinsics(paths):
+ results = []
+ files = list(Path(paths.parent).glob(paths.name))
+ assert len(files) > 0
+
+ for lfile in files:
+ with open(lfile, 'r') as f:
+ raw_data = f.readlines()
+
+ logging.info(f'Importing {len(raw_data)} queries in {lfile.name}')
+ for data in raw_data:
+ data = data.strip('\n').split(' ')
+ name, camera_model, width, height = data[:4]
+ params = np.array(data[4:], float)
+ info = (camera_model, int(width), int(height), params)
+ results.append((name, info))
+
+ assert len(results) > 0
+ return results
+
+
+def parse_img_lists_for_extended_cmu_seaons(paths):
+ Ks = {
+ "c0": "OPENCV 1024 768 868.993378 866.063001 525.942323 420.042529 -0.399431 0.188924 0.000153 0.000571",
+ "c1": "OPENCV 1024 768 868.993378 866.063001 525.942323 420.042529 -0.399431 0.188924 0.000153 0.000571"
+ }
+
+ results = []
+ files = list(Path(paths.parent).glob(paths.name))
+ assert len(files) > 0
+
+ for lfile in files:
+ with open(lfile, 'r') as f:
+ raw_data = f.readlines()
+
+ logging.info(f'Importing {len(raw_data)} queries in {lfile.name}')
+ for name in raw_data:
+ name = name.strip('\n')
+ camera = name.split('_')[2]
+ K = Ks[camera].split(' ')
+ camera_model, width, height = K[:3]
+ params = np.array(K[3:], float)
+ # print("camera: ", camera_model, width, height, params)
+ info = (camera_model, int(width), int(height), params)
+ results.append((name, info))
+
+ assert len(results) > 0
+ return results
+
+
+def parse_retrieval(path):
+ retrieval = defaultdict(list)
+ with open(path, 'r') as f:
+ for p in f.read().rstrip('\n').split('\n'):
+ q, r = p.split(' ')
+ retrieval[q].append(r)
+ return dict(retrieval)
+
+
+def names_to_pair_old(name0, name1):
+ return '_'.join((name0.replace('/', '-'), name1.replace('/', '-')))
+
+
+def names_to_pair(name0, name1, separator="/"):
+ return separator.join((name0.replace("/", "-"), name1.replace("/", "-")))
diff --git a/third_party/pram/colmap_utils/read_write_model.py b/third_party/pram/colmap_utils/read_write_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..eddbeb7edd364c27c54029fa81077ea4f75d2700
--- /dev/null
+++ b/third_party/pram/colmap_utils/read_write_model.py
@@ -0,0 +1,627 @@
+# Copyright (c) 2018, ETH Zurich and UNC Chapel Hill.
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+#
+# * Redistributions in binary form must reproduce the above copyright
+# notice, this list of conditions and the following disclaimer in the
+# documentation and/or other materials provided with the distribution.
+#
+# * Neither the name of ETH Zurich and UNC Chapel Hill nor the names of
+# its contributors may be used to endorse or promote products derived
+# from this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE
+# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+# POSSIBILITY OF SUCH DAMAGE.
+#
+# Author: Johannes L. Schoenberger (jsch-at-demuc-dot-de)
+
+import os
+import sys
+import collections
+import numpy as np
+import struct
+import argparse
+
+CameraModel = collections.namedtuple(
+ "CameraModel", ["model_id", "model_name", "num_params"])
+Camera = collections.namedtuple(
+ "Camera", ["id", "model", "width", "height", "params"])
+BaseImage = collections.namedtuple(
+ "Image", ["id", "qvec", "tvec", "camera_id", "name", "xys", "point3D_ids"])
+Point3D = collections.namedtuple(
+ "Point3D", ["id", "xyz", "rgb", "error", "image_ids", "point2D_idxs"])
+
+
+class Image(BaseImage):
+ def qvec2rotmat(self):
+ return qvec2rotmat(self.qvec)
+
+
+CAMERA_MODELS = {
+ CameraModel(model_id=0, model_name="SIMPLE_PINHOLE", num_params=3),
+ CameraModel(model_id=1, model_name="PINHOLE", num_params=4),
+ CameraModel(model_id=2, model_name="SIMPLE_RADIAL", num_params=4),
+ CameraModel(model_id=3, model_name="RADIAL", num_params=5),
+ CameraModel(model_id=4, model_name="OPENCV", num_params=8),
+ CameraModel(model_id=5, model_name="OPENCV_FISHEYE", num_params=8),
+ CameraModel(model_id=6, model_name="FULL_OPENCV", num_params=12),
+ CameraModel(model_id=7, model_name="FOV", num_params=5),
+ CameraModel(model_id=8, model_name="SIMPLE_RADIAL_FISHEYE", num_params=4),
+ CameraModel(model_id=9, model_name="RADIAL_FISHEYE", num_params=5),
+ CameraModel(model_id=10, model_name="THIN_PRISM_FISHEYE", num_params=12)
+}
+CAMERA_MODEL_IDS = dict([(camera_model.model_id, camera_model)
+ for camera_model in CAMERA_MODELS])
+CAMERA_MODEL_NAMES = dict([(camera_model.model_name, camera_model)
+ for camera_model in CAMERA_MODELS])
+
+
+def read_next_bytes(fid, num_bytes, format_char_sequence, endian_character="<"):
+ """Read and unpack the next bytes from a binary file.
+ :param fid:
+ :param num_bytes: Sum of combination of {2, 4, 8}, e.g. 2, 6, 16, 30, etc.
+ :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}.
+ :param endian_character: Any of {@, =, <, >, !}
+ :return: Tuple of read and unpacked values.
+ """
+ data = fid.read(num_bytes)
+ return struct.unpack(endian_character + format_char_sequence, data)
+
+
+def write_next_bytes(fid, data, format_char_sequence, endian_character="<"):
+ """pack and write to a binary file.
+ :param fid:
+ :param data: data to send, if multiple elements are sent at the same time,
+ they should be encapsuled either in a list or a tuple
+ :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}.
+ should be the same length as the data list or tuple
+ :param endian_character: Any of {@, =, <, >, !}
+ """
+ if isinstance(data, (list, tuple)):
+ bytes = struct.pack(endian_character + format_char_sequence, *data)
+ else:
+ bytes = struct.pack(endian_character + format_char_sequence, data)
+ fid.write(bytes)
+
+
+def read_cameras_text(path):
+ """
+ see: src/base/reconstruction.cc
+ void Reconstruction::WriteCamerasText(const std::string& path)
+ void Reconstruction::ReadCamerasText(const std::string& path)
+ """
+ cameras = {}
+ with open(path, "r") as fid:
+ while True:
+ line = fid.readline()
+ if not line:
+ break
+ line = line.strip()
+ if len(line) > 0 and line[0] != "#":
+ elems = line.split()
+ camera_id = int(elems[0])
+ model = elems[1]
+ width = int(elems[2])
+ height = int(elems[3])
+ params = np.array(tuple(map(float, elems[4:])))
+ cameras[camera_id] = Camera(id=camera_id, model=model,
+ width=width, height=height,
+ params=params)
+ return cameras
+
+
+def read_cameras_binary(path_to_model_file):
+ """
+ see: src/base/reconstruction.cc
+ void Reconstruction::WriteCamerasBinary(const std::string& path)
+ void Reconstruction::ReadCamerasBinary(const std::string& path)
+ """
+ cameras = {}
+ with open(path_to_model_file, "rb") as fid:
+ num_cameras = read_next_bytes(fid, 8, "Q")[0]
+ for camera_line_index in range(num_cameras):
+ camera_properties = read_next_bytes(
+ fid, num_bytes=24, format_char_sequence="iiQQ")
+ camera_id = camera_properties[0]
+ model_id = camera_properties[1]
+ model_name = CAMERA_MODEL_IDS[camera_properties[1]].model_name
+ width = camera_properties[2]
+ height = camera_properties[3]
+ num_params = CAMERA_MODEL_IDS[model_id].num_params
+ params = read_next_bytes(fid, num_bytes=8 * num_params,
+ format_char_sequence="d" * num_params)
+ cameras[camera_id] = Camera(id=camera_id,
+ model=model_name,
+ width=width,
+ height=height,
+ params=np.array(params))
+ assert len(cameras) == num_cameras
+ return cameras
+
+
+def write_cameras_text(cameras, path):
+ """
+ see: src/base/reconstruction.cc
+ void Reconstruction::WriteCamerasText(const std::string& path)
+ void Reconstruction::ReadCamerasText(const std::string& path)
+ """
+ HEADER = '# Camera list with one line of data per camera:\n'
+ '# CAMERA_ID, MODEL, WIDTH, HEIGHT, PARAMS[]\n'
+ '# Number of cameras: {}\n'.format(len(cameras))
+ with open(path, "w") as fid:
+ fid.write(HEADER)
+ for _, cam in cameras.items():
+ to_write = [cam.id, cam.model, cam.width, cam.height, *cam.params]
+ line = " ".join([str(elem) for elem in to_write])
+ fid.write(line + "\n")
+
+
+def write_cameras_binary(cameras, path_to_model_file):
+ """
+ see: src/base/reconstruction.cc
+ void Reconstruction::WriteCamerasBinary(const std::string& path)
+ void Reconstruction::ReadCamerasBinary(const std::string& path)
+ """
+ with open(path_to_model_file, "wb") as fid:
+ write_next_bytes(fid, len(cameras), "Q")
+ for _, cam in cameras.items():
+ model_id = CAMERA_MODEL_NAMES[cam.model].model_id
+ camera_properties = [cam.id,
+ model_id,
+ cam.width,
+ cam.height]
+ write_next_bytes(fid, camera_properties, "iiQQ")
+ for p in cam.params:
+ write_next_bytes(fid, float(p), "d")
+ return cameras
+
+
+def read_images_text(path):
+ """
+ see: src/base/reconstruction.cc
+ void Reconstruction::ReadImagesText(const std::string& path)
+ void Reconstruction::WriteImagesText(const std::string& path)
+ """
+ images = {}
+ with open(path, "r") as fid:
+ while True:
+ line = fid.readline()
+ if not line:
+ break
+ line = line.strip()
+ if len(line) > 0 and line[0] != "#":
+ elems = line.split()
+ image_id = int(elems[0])
+ qvec = np.array(tuple(map(float, elems[1:5])))
+ tvec = np.array(tuple(map(float, elems[5:8])))
+ camera_id = int(elems[8])
+ image_name = elems[9]
+ elems = fid.readline().split()
+ xys = np.column_stack([tuple(map(float, elems[0::3])),
+ tuple(map(float, elems[1::3]))])
+ point3D_ids = np.array(tuple(map(int, elems[2::3])))
+ images[image_id] = Image(
+ id=image_id, qvec=qvec, tvec=tvec,
+ camera_id=camera_id, name=image_name,
+ xys=xys, point3D_ids=point3D_ids)
+ return images
+
+
+def read_images_binary(path_to_model_file):
+ """
+ see: src/base/reconstruction.cc
+ void Reconstruction::ReadImagesBinary(const std::string& path)
+ void Reconstruction::WriteImagesBinary(const std::string& path)
+ """
+ images = {}
+ with open(path_to_model_file, "rb") as fid:
+ num_reg_images = read_next_bytes(fid, 8, "Q")[0]
+ for image_index in range(num_reg_images):
+ binary_image_properties = read_next_bytes(
+ fid, num_bytes=64, format_char_sequence="idddddddi")
+ image_id = binary_image_properties[0]
+ qvec = np.array(binary_image_properties[1:5])
+ tvec = np.array(binary_image_properties[5:8])
+ camera_id = binary_image_properties[8]
+ image_name = ""
+ current_char = read_next_bytes(fid, 1, "c")[0]
+ while current_char != b"\x00": # look for the ASCII 0 entry
+ image_name += current_char.decode("utf-8")
+ current_char = read_next_bytes(fid, 1, "c")[0]
+ num_points2D = read_next_bytes(fid, num_bytes=8,
+ format_char_sequence="Q")[0]
+ x_y_id_s = read_next_bytes(fid, num_bytes=24 * num_points2D,
+ format_char_sequence="ddq" * num_points2D)
+ xys = np.column_stack([tuple(map(float, x_y_id_s[0::3])),
+ tuple(map(float, x_y_id_s[1::3]))])
+ point3D_ids = np.array(tuple(map(int, x_y_id_s[2::3])))
+ images[image_id] = Image(
+ id=image_id, qvec=qvec, tvec=tvec,
+ camera_id=camera_id, name=image_name,
+ xys=xys, point3D_ids=point3D_ids)
+ return images
+
+
+def write_images_text(images, path):
+ """
+ see: src/base/reconstruction.cc
+ void Reconstruction::ReadImagesText(const std::string& path)
+ void Reconstruction::WriteImagesText(const std::string& path)
+ """
+ if len(images) == 0:
+ mean_observations = 0
+ else:
+ mean_observations = sum((len(img.point3D_ids) for _, img in images.items())) / len(images)
+ HEADER = '# Image list with two lines of data per image:\n'
+ '# IMAGE_ID, QW, QX, QY, QZ, TX, TY, TZ, CAMERA_ID, NAME\n'
+ '# POINTS2D[] as (X, Y, POINT3D_ID)\n'
+ '# Number of images: {}, mean observations per image: {}\n'.format(len(images), mean_observations)
+
+ with open(path, "w") as fid:
+ fid.write(HEADER)
+ for _, img in images.items():
+ image_header = [img.id, *img.qvec, *img.tvec, img.camera_id, img.name]
+ first_line = " ".join(map(str, image_header))
+ fid.write(first_line + "\n")
+
+ points_strings = []
+ for xy, point3D_id in zip(img.xys, img.point3D_ids):
+ points_strings.append(" ".join(map(str, [*xy, point3D_id])))
+ fid.write(" ".join(points_strings) + "\n")
+
+
+def write_images_binary(images, path_to_model_file):
+ """
+ see: src/base/reconstruction.cc
+ void Reconstruction::ReadImagesBinary(const std::string& path)
+ void Reconstruction::WriteImagesBinary(const std::string& path)
+ """
+ with open(path_to_model_file, "wb") as fid:
+ write_next_bytes(fid, len(images), "Q")
+ for _, img in images.items():
+ write_next_bytes(fid, img.id, "i")
+ write_next_bytes(fid, img.qvec.tolist(), "dddd")
+ write_next_bytes(fid, img.tvec.tolist(), "ddd")
+ write_next_bytes(fid, img.camera_id, "i")
+ for char in img.name:
+ write_next_bytes(fid, char.encode("utf-8"), "c")
+ write_next_bytes(fid, b"\x00", "c")
+ write_next_bytes(fid, len(img.point3D_ids), "Q")
+ for xy, p3d_id in zip(img.xys, img.point3D_ids):
+ write_next_bytes(fid, [*xy, p3d_id], "ddq")
+
+
+def read_points3D_text(path):
+ """
+ see: src/base/reconstruction.cc
+ void Reconstruction::ReadPoints3DText(const std::string& path)
+ void Reconstruction::WritePoints3DText(const std::string& path)
+ """
+ points3D = {}
+ with open(path, "r") as fid:
+ while True:
+ line = fid.readline()
+ if not line:
+ break
+ line = line.strip()
+ if len(line) > 0 and line[0] != "#":
+ elems = line.split()
+ point3D_id = int(elems[0])
+ xyz = np.array(tuple(map(float, elems[1:4])))
+ rgb = np.array(tuple(map(int, elems[4:7])))
+ error = float(elems[7])
+ image_ids = np.array(tuple(map(int, elems[8::2])))
+ point2D_idxs = np.array(tuple(map(int, elems[9::2])))
+ points3D[point3D_id] = Point3D(id=point3D_id, xyz=xyz, rgb=rgb,
+ error=error, image_ids=image_ids,
+ point2D_idxs=point2D_idxs)
+ return points3D
+
+
+def read_points3d_binary(path_to_model_file):
+ """
+ see: src/base/reconstruction.cc
+ void Reconstruction::ReadPoints3DBinary(const std::string& path)
+ void Reconstruction::WritePoints3DBinary(const std::string& path)
+ """
+ points3D = {}
+ with open(path_to_model_file, "rb") as fid:
+ num_points = read_next_bytes(fid, 8, "Q")[0]
+ for point_line_index in range(num_points):
+ binary_point_line_properties = read_next_bytes(
+ fid, num_bytes=43, format_char_sequence="QdddBBBd")
+ point3D_id = binary_point_line_properties[0]
+ xyz = np.array(binary_point_line_properties[1:4])
+ rgb = np.array(binary_point_line_properties[4:7])
+ error = np.array(binary_point_line_properties[7])
+ track_length = read_next_bytes(
+ fid, num_bytes=8, format_char_sequence="Q")[0]
+ track_elems = read_next_bytes(
+ fid, num_bytes=8 * track_length,
+ format_char_sequence="ii" * track_length)
+ image_ids = np.array(tuple(map(int, track_elems[0::2])))
+ point2D_idxs = np.array(tuple(map(int, track_elems[1::2])))
+ points3D[point3D_id] = Point3D(
+ id=point3D_id, xyz=xyz, rgb=rgb,
+ error=error, image_ids=image_ids,
+ point2D_idxs=point2D_idxs)
+ return points3D
+
+
+def write_points3D_text(points3D, path):
+ """
+ see: src/base/reconstruction.cc
+ void Reconstruction::ReadPoints3DText(const std::string& path)
+ void Reconstruction::WritePoints3DText(const std::string& path)
+ """
+ if len(points3D) == 0:
+ mean_track_length = 0
+ else:
+ mean_track_length = sum((len(pt.image_ids) for _, pt in points3D.items())) / len(points3D)
+ HEADER = '# 3D point list with one line of data per point:\n'
+ '# POINT3D_ID, X, Y, Z, R, G, B, ERROR, TRACK[] as (IMAGE_ID, POINT2D_IDX)\n'
+ '# Number of points: {}, mean track length: {}\n'.format(len(points3D), mean_track_length)
+
+ with open(path, "w") as fid:
+ fid.write(HEADER)
+ for _, pt in points3D.items():
+ point_header = [pt.id, *pt.xyz, *pt.rgb, pt.error]
+ fid.write(" ".join(map(str, point_header)) + " ")
+ track_strings = []
+ for image_id, point2D in zip(pt.image_ids, pt.point2D_idxs):
+ track_strings.append(" ".join(map(str, [image_id, point2D])))
+ fid.write(" ".join(track_strings) + "\n")
+
+
+def write_points3d_binary(points3D, path_to_model_file):
+ """
+ see: src/base/reconstruction.cc
+ void Reconstruction::ReadPoints3DBinary(const std::string& path)
+ void Reconstruction::WritePoints3DBinary(const std::string& path)
+ """
+ with open(path_to_model_file, "wb") as fid:
+ write_next_bytes(fid, len(points3D), "Q")
+ for _, pt in points3D.items():
+ write_next_bytes(fid, pt.id, "Q")
+ write_next_bytes(fid, pt.xyz.tolist(), "ddd")
+ write_next_bytes(fid, pt.rgb.tolist(), "BBB")
+ write_next_bytes(fid, pt.error, "d")
+ track_length = pt.image_ids.shape[0]
+ write_next_bytes(fid, track_length, "Q")
+ for image_id, point2D_id in zip(pt.image_ids, pt.point2D_idxs):
+ write_next_bytes(fid, [image_id, point2D_id], "ii")
+
+
+def read_model(path, ext):
+ if ext == ".txt":
+ cameras = read_cameras_text(os.path.join(path, "cameras" + ext))
+ images = read_images_text(os.path.join(path, "images" + ext))
+ points3D = read_points3D_text(os.path.join(path, "points3D") + ext)
+ else:
+ cameras = read_cameras_binary(os.path.join(path, "cameras" + ext))
+ images = read_images_binary(os.path.join(path, "images" + ext))
+ points3D = read_points3d_binary(os.path.join(path, "points3D") + ext)
+ return cameras, images, points3D
+
+
+def write_model(cameras, images, points3D, path, ext):
+ if ext == ".txt":
+ write_cameras_text(cameras, os.path.join(path, "cameras" + ext))
+ write_images_text(images, os.path.join(path, "images" + ext))
+ write_points3D_text(points3D, os.path.join(path, "points3D") + ext)
+ else:
+ write_cameras_binary(cameras, os.path.join(path, "cameras" + ext))
+ write_images_binary(images, os.path.join(path, "images" + ext))
+ write_points3d_binary(points3D, os.path.join(path, "points3D") + ext)
+ return cameras, images, points3D
+
+
+def read_compressed_images_binary(path_to_model_file):
+ """
+ see: src/base/reconstruction.cc
+ void Reconstruction::ReadImagesBinary(const std::string& path)
+ void Reconstruction::WriteImagesBinary(const std::string& path)
+ """
+ images = {}
+ with open(path_to_model_file, "rb") as fid:
+ num_reg_images = read_next_bytes(fid, 8, "Q")[0]
+ for image_index in range(num_reg_images):
+ binary_image_properties = read_next_bytes(
+ fid, num_bytes=64, format_char_sequence="idddddddi")
+ image_id = binary_image_properties[0]
+ qvec = np.array(binary_image_properties[1:5])
+ tvec = np.array(binary_image_properties[5:8])
+ camera_id = binary_image_properties[8]
+ image_name = ""
+ current_char = read_next_bytes(fid, 1, "c")[0]
+ while current_char != b"\x00": # look for the ASCII 0 entry
+ image_name += current_char.decode("utf-8")
+ current_char = read_next_bytes(fid, 1, "c")[0]
+ num_points2D = read_next_bytes(fid, num_bytes=8,
+ format_char_sequence="Q")[0]
+ # x_y_id_s = read_next_bytes(fid, num_bytes=24 * num_points2D,
+ # format_char_sequence="ddq" * num_points2D)
+ # xys = np.column_stack([tuple(map(float, x_y_id_s[0::3])),
+ # tuple(map(float, x_y_id_s[1::3]))])
+ x_y_id_s = read_next_bytes(fid, num_bytes=8 * num_points2D,
+ format_char_sequence="q" * num_points2D)
+ point3D_ids = np.array(x_y_id_s)
+ images[image_id] = Image(
+ id=image_id, qvec=qvec, tvec=tvec,
+ camera_id=camera_id, name=image_name,
+ xys=np.array([]), point3D_ids=point3D_ids)
+ return images
+
+
+def write_compressed_images_binary(images, path_to_model_file):
+ """
+ see: src/base/reconstruction.cc
+ void Reconstruction::ReadImagesBinary(const std::string& path)
+ void Reconstruction::WriteImagesBinary(const std::string& path)
+ """
+ with open(path_to_model_file, "wb") as fid:
+ write_next_bytes(fid, len(images), "Q")
+ for _, img in images.items():
+ write_next_bytes(fid, img.id, "i")
+ write_next_bytes(fid, img.qvec.tolist(), "dddd")
+ write_next_bytes(fid, img.tvec.tolist(), "ddd")
+ write_next_bytes(fid, img.camera_id, "i")
+ for char in img.name:
+ write_next_bytes(fid, char.encode("utf-8"), "c")
+ write_next_bytes(fid, b"\x00", "c")
+ write_next_bytes(fid, len(img.point3D_ids), "Q")
+ for p3d_id in img.point3D_ids:
+ write_next_bytes(fid, p3d_id, "q")
+ # for xy, p3d_id in zip(img.xys, img.point3D_ids):
+ # write_next_bytes(fid, [*xy, p3d_id], "ddq")
+
+
+def read_compressed_points3d_binary(path_to_model_file):
+ """
+ see: src/base/reconstruction.cc
+ void Reconstruction::ReadPoints3DBinary(const std::string& path)
+ void Reconstruction::WritePoints3DBinary(const std::string& path)
+ """
+ points3D = {}
+ with open(path_to_model_file, "rb") as fid:
+ num_points = read_next_bytes(fid, 8, "Q")[0]
+ for point_line_index in range(num_points):
+ binary_point_line_properties = read_next_bytes(
+ fid, num_bytes=43, format_char_sequence="QdddBBBd")
+ point3D_id = binary_point_line_properties[0]
+ xyz = np.array(binary_point_line_properties[1:4])
+ rgb = np.array(binary_point_line_properties[4:7])
+ error = np.array(binary_point_line_properties[7])
+ track_length = read_next_bytes(
+ fid, num_bytes=8, format_char_sequence="Q")[0]
+ track_elems = read_next_bytes(
+ fid, num_bytes=4 * track_length,
+ format_char_sequence="i" * track_length)
+ image_ids = np.array(track_elems)
+ # point2D_idxs = np.array(tuple(map(int, track_elems[1::2])))
+ points3D[point3D_id] = Point3D(
+ id=point3D_id, xyz=xyz, rgb=rgb,
+ error=error, image_ids=image_ids,
+ point2D_idxs=np.array([]))
+ return points3D
+
+
+def write_compressed_points3d_binary(points3D, path_to_model_file):
+ """
+ see: src/base/reconstruction.cc
+ void Reconstruction::ReadPoints3DBinary(const std::string& path)
+ void Reconstruction::WritePoints3DBinary(const std::string& path)
+ """
+ with open(path_to_model_file, "wb") as fid:
+ write_next_bytes(fid, len(points3D), "Q")
+ for _, pt in points3D.items():
+ write_next_bytes(fid, pt.id, "Q")
+ write_next_bytes(fid, pt.xyz.tolist(), "ddd")
+ write_next_bytes(fid, pt.rgb.tolist(), "BBB")
+ write_next_bytes(fid, pt.error, "d")
+ track_length = pt.image_ids.shape[0]
+ write_next_bytes(fid, track_length, "Q")
+ # for image_id, point2D_id in zip(pt.image_ids, pt.point2D_idxs):
+ # write_next_bytes(fid, [image_id, point2D_id], "ii")
+ for image_id in pt.image_ids:
+ write_next_bytes(fid, image_id, "i")
+
+
+def read_compressed_model(path, ext):
+ if ext == ".txt":
+ cameras = read_cameras_text(os.path.join(path, "cameras" + ext))
+ images = read_images_text(os.path.join(path, "images" + ext))
+ points3D = read_points3D_text(os.path.join(path, "points3D") + ext)
+ else:
+ cameras = read_cameras_binary(os.path.join(path, "cameras" + ext))
+ images = read_compressed_images_binary(os.path.join(path, "images" + ext))
+ points3D = read_compressed_points3d_binary(os.path.join(path, "points3D") + ext)
+ return cameras, images, points3D
+
+
+def qvec2rotmat(qvec):
+ return np.array([
+ [1 - 2 * qvec[2] ** 2 - 2 * qvec[3] ** 2,
+ 2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3],
+ 2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2]],
+ [2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3],
+ 1 - 2 * qvec[1] ** 2 - 2 * qvec[3] ** 2,
+ 2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1]],
+ [2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2],
+ 2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1],
+ 1 - 2 * qvec[1] ** 2 - 2 * qvec[2] ** 2]])
+
+
+def rotmat2qvec(R):
+ Rxx, Ryx, Rzx, Rxy, Ryy, Rzy, Rxz, Ryz, Rzz = R.flat
+ K = np.array([
+ [Rxx - Ryy - Rzz, 0, 0, 0],
+ [Ryx + Rxy, Ryy - Rxx - Rzz, 0, 0],
+ [Rzx + Rxz, Rzy + Ryz, Rzz - Rxx - Ryy, 0],
+ [Ryz - Rzy, Rzx - Rxz, Rxy - Ryx, Rxx + Ryy + Rzz]]) / 3.0
+ eigvals, eigvecs = np.linalg.eigh(K)
+ qvec = eigvecs[[3, 0, 1, 2], np.argmax(eigvals)]
+ if qvec[0] < 0:
+ qvec *= -1
+ return qvec
+
+
+def intrinsics_from_camera(camera_model, params):
+ if camera_model in ("SIMPLE_PINHOLE", "SIMPLE_RADIAL", "RADIAL"):
+ fx = fy = params[0]
+ cx = params[1]
+ cy = params[2]
+ elif camera_model in ("PINHOLE", "OPENCV", "OPENCV_FISHEYE", "FULL_OPENCV"):
+ fx = params[0]
+ fy = params[1]
+ cx = params[2]
+ cy = params[3]
+ else:
+ raise Exception("Camera model not supported")
+
+ # intrinsics
+ K = np.identity(3)
+ K[0, 0] = fx
+ K[1, 1] = fy
+ K[0, 2] = cx
+ K[1, 2] = cy
+ return K
+
+
+def main():
+ parser = argparse.ArgumentParser(description='Read and write COLMAP binary and text models')
+ parser.add_argument('input_model', help='path to input model folder')
+ parser.add_argument('input_format', choices=['.bin', '.txt'],
+ help='input model format')
+ parser.add_argument('--output_model', metavar='PATH',
+ help='path to output model folder')
+ parser.add_argument('--output_format', choices=['.bin', '.txt'],
+                        help='output model format', default='.txt')
+ args = parser.parse_args()
+
+ cameras, images, points3D = read_model(path=args.input_model, ext=args.input_format)
+
+ print("num_cameras:", len(cameras))
+ print("num_images:", len(images))
+ print("num_points3D:", len(points3D))
+
+ if args.output_model is not None:
+ write_model(cameras, images, points3D, path=args.output_model, ext=args.output_format)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/third_party/pram/colmap_utils/utils.py b/third_party/pram/colmap_utils/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..8d98fed2dfc5789b650144caa3a4bac8cfe6a2fb
--- /dev/null
+++ b/third_party/pram/colmap_utils/utils.py
@@ -0,0 +1 @@
+# -*- coding: UTF-8 -*-
diff --git a/third_party/pram/configs/config_train_12scenes_sfd2.yaml b/third_party/pram/configs/config_train_12scenes_sfd2.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..1e6e7fb7c851edb8bd6e26e8d4806cadeb5977d5
--- /dev/null
+++ b/third_party/pram/configs/config_train_12scenes_sfd2.yaml
@@ -0,0 +1,102 @@
+dataset: [ '12Scenes' ]
+
+network_1: "segnet"
+network: "segnetvit"
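+# note: keys with a trailing underscore or numeric suffix (e.g. network_1,
+# weight_path_1, matching_method_) appear to be inactive alternatives kept for
+# reference; only the exact key names (network, weight_path, ...) are read.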
+
+local_rank: 0
+gpu: [ 0 ]
+
+feature: "sfd2"
+save_path: '/scratches/flyer_2/fx221/exp/pram'
+landmark_path: "/scratches/flyer_3/fx221/exp/pram/landmarks/sfd2-gml"
+dataset_path: "/scratches/flyer_3/fx221/dataset"
+config_path: 'configs/datasets'
+
+image_dim: 3
+feat_dim: 128
+min_inliers: 32
+max_inliers: 512
+random_inliers: true
+max_keypoints: 512
+ignore_index: -1
+output_dim: 1024
+output_dim_: 2048
+jitter_params:
+ brightness: 0.5
+ contrast: 0.5
+ saturation: 0.25
+ hue: 0.15
+ blur: 0
+
+scale_params: [ 0.5, 1.0 ]
+pre_load: false
+train: true
+inlier_th: 0.5
+lr: 0.0001
+min_lr: 0.00001
+optimizer: "adamw"
+seg_loss: "cew"
+seg_loss_nx: "cei"
+cls_loss: "ce"
+cls_loss_: "bce"
+ac_fn: "relu"
+norm_fn: "bn"
+workers: 8
+layers: 15
+log_intervals: 50
+eval_n_epoch: 10
+do_eval: false
+
+use_mid_feature: true
+norm_desc: false
+with_score: false
+with_aug: true
+with_dist: true
+
+batch_size: 32
+its_per_epoch: 1000
+decay_rate: 0.999992
+decay_iter: 60000
+epochs: 500
+
+cluster_method: 'birch'
+
+weight_path: null
+weight_path_1: '20230719_220620_segnet_L15_T_resnet4x_B32_K1024_relu_bn_od1024_nc193_adamw_cew_md_A_birch/segnet.499.pth'
+weight_path_2: '20240202_145337_segnetvit_L15_T_resnet4x_B32_K512_relu_bn_od1024_nc193_adam_cew_md_A_birch/segnetvit.499.pth'
+
+resume_path: null
+
+n_class: 193
+
+eval_max_keypoints: 1024
+
+localization:
+ loc_scene_name: [ 'apt1/kitchen' ]
+ save_path: '/scratches/flyer_2/fx221/exp/localizer/loc_results'
+ seg_k: 20
+ threshold: 8
+ min_kpts: 128
+ min_matches: 4
+ min_inliers: 64
+ matching_method_: "mnn"
+ matching_method_1: "spg"
+ matching_method_2: "gm"
+ matching_method: "gml"
+ matching_method_5: "adagml"
+ save: false
+ show: true
+ show_time: 1
+ max_vrf: 1
+ with_original: true
+ with_extra: false
+ with_compress: true
+ semantic_matching: true
+ do_refinement: true
+ refinement_method_: 'matching'
+ refinement_method: 'projection'
+ pre_filtering_th: 0.95
+ covisibility_frame: 20
+ refinement_radius: 20
+ refinement_nn_ratio: 0.9
+ refinement_max_matches: 0
diff --git a/third_party/pram/configs/config_train_7scenes_sfd2.yaml b/third_party/pram/configs/config_train_7scenes_sfd2.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..19b0635c9ad4ebcf0a085a759640e4a149a75009
--- /dev/null
+++ b/third_party/pram/configs/config_train_7scenes_sfd2.yaml
@@ -0,0 +1,104 @@
+dataset: [ '7Scenes' ]
+
+network: "segnetvit"
+
+local_rank: 0
+gpu: [ 0 ]
+# when using ddp, set gpu: [0,1,2,3]
+with_dist: true
+
+feature: "sfd2"
+save_path_: '/scratches/flyer_2/fx221/exp/pram'
+save_path: '/scratches/flyer_2/fx221/publications/test_pram/exp'
+landmark_path_: "/scratches/flyer_3/fx221/exp/pram/landmarks/sfd2-gml"
+landmark_path: "/scratches/flyer_2/fx221/publications/test_pram/landmakrs/sfd2-gml"
+dataset_path: "/scratches/flyer_3/fx221/dataset"
+config_path: 'configs/datasets'
+
+image_dim: 3
+feat_dim: 128
+
+min_inliers: 32
+max_inliers: 256
+random_inliers: 1
+max_keypoints: 512
+ignore_index: -1
+output_dim: 1024
+output_dim_: 2048
+jitter_params:
+ brightness: 0.5
+ contrast: 0.5
+ saturation: 0.25
+ hue: 0.15
+ blur: 0
+
+scale_params: [ 0.5, 1.0 ]
+pre_load: false
+train: true
+inlier_th: 0.5
+lr: 0.0001
+min_lr: 0.00001
+cls_loss: "ce"
+ac_fn: "relu"
+norm_fn: "bn"
+workers: 8
+layers: 15
+log_intervals: 50
+eval_n_epoch: 10
+do_eval: false
+
+use_mid_feature: true
+norm_desc: false
+with_cls: false
+with_score: false
+with_aug: true
+
+batch_size: 32
+its_per_epoch: 1000
+decay_rate: 0.999992
+decay_iter: 80000
+epochs: 200
+
+cluster_method: 'birch'
+
+weight_path: null
+weight_path_1: '20230724_203230_segnet_L15_S_resnet4x_B32_K1024_relu_bn_od1024_nc113_adam_cew_md_A_birch/segnet.180.pth'
+weight_path_2: '20240202_152519_segnetvit_L15_S_resnet4x_B32_K512_relu_bn_od1024_nc113_adamw_cew_md_A_birch/segnetvit.199.pth'
+
+# used for resuming training
+resume_path: null
+
+# used for localization
+n_class: 113
+
+eval_max_keypoints: 1024
+
+localization:
+ loc_scene_name: [ 'chess' ]
+ save_path: '/scratches/flyer_2/fx221/exp/localizer/loc_results'
+
+ seg_k: 20
+ threshold: 8
+ min_kpts: 128
+ min_matches: 16
+ min_inliers: 32
+ matching_method_: "mnn"
+ matching_method_1: "spg"
+ matching_method_2: "gm"
+ matching_method: "gml"
+ matching_method_4: "adagml"
+ save: false
+ show: true
+ show_time: 1
+ with_original: true
+ max_vrf: 1
+ with_compress: true
+ semantic_matching: true
+ do_refinement: true
+ pre_filtering_th: 0.95
+ refinement_method_: 'matching'
+ refinement_method: 'projection'
+ covisibility_frame: 20
+ refinement_radius: 20
+ refinement_nn_ratio: 0.9
+ refinement_max_matches: 0
diff --git a/third_party/pram/configs/config_train_aachen_sfd2.yaml b/third_party/pram/configs/config_train_aachen_sfd2.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..9e2111377ed9d6cff38efd69bc397487ecfb33fb
--- /dev/null
+++ b/third_party/pram/configs/config_train_aachen_sfd2.yaml
@@ -0,0 +1,104 @@
+dataset: [ 'Aachen' ]
+
+network_: "segnet"
+network: "segnetvit"
+local_rank: 0
+gpu: [ 0 ]
+
+feature: "sfd2"
+save_path: '/scratches/flyer_2/fx221/exp/pram'
+landmark_path: "/scratches/flyer_3/fx221/exp/pram/landmarks/sfd2-gml"
+dataset_path: "/scratches/flyer_3/fx221/dataset"
+
+config_path: 'configs/datasets'
+
+image_dim: 3
+feat_dim: 128
+
+min_inliers: 32
+max_inliers: 512
+random_inliers: true
+max_keypoints: 1024
+ignore_index: -1
+output_dim: 1024
+output_dim_: 2048
+jitter_params:
+ brightness: 0.5
+ contrast: 0.5
+ saturation: 0.25
+ hue: 0.15
+ blur: 0
+
+scale_params: [ 0.5, 1.0 ]
+pre_load: false
+do_eval: true
+train: true
+inlier_th: 0.5
+lr: 0.0001
+min_lr: 0.00001
+optimizer: "adam"
+seg_loss: "cew"
+seg_loss_nx: "cei"
+cls_loss: "ce"
+cls_loss_: "bce"
+ac_fn: "relu"
+norm_fn: "bn"
+workers: 8
+layers: 15
+log_intervals: 50
+eval_n_epoch: 10
+
+use_mid_feature: true
+norm_desc: false
+with_sc: false
+with_cls: true
+with_score: false
+with_aug: true
+with_dist: true
+
+batch_size: 32
+its_per_epoch: 1000
+decay_rate: 0.999992
+decay_iter: 80000
+epochs: 800
+
+cluster_method: 'birch'
+
+weight_path: null
+weight_path_1: '20230719_221442_segnet_L15_A_resnet4x_B32_K1024_relu_bn_od1024_nc513_adamw_cew_md_A_birch/segnet.899.pth'
+weight_path_2: '20240211_142623_segnetvit_L15_A_resnet4x_B32_K1024_relu_bn_od1024_nc513_adam_cew_md_A_birch/segnetvit.799.pth'
+resume_path: null
+
+n_class: 513
+
+eval_max_keypoints: 4096
+
+localization:
+ loc_scene_name: [ ]
+ save_path: '/scratches/flyer_2/fx221/exp/localizer/loc_results'
+ seg_k: 10
+ threshold: 12
+ min_kpts: 256
+ min_matches: 8
+ min_inliers: 128
+ matching_method_: "mnn"
+ matching_method_1: "spg"
+ matching_method_2: "gm"
+ matching_method: "gml"
+ matching_method_4: "adagml"
+ save: false
+ show: true
+ show_time: 1
+ with_original: true
+ with_extra: false
+ max_vrf: 1
+ with_compress: true
+ semantic_matching: true
+ refinement_method_: 'matching'
+ refinement_method: 'projection'
+ pre_filtering_th: 0.95
+ do_refinement: true
+ covisibility_frame: 50
+ refinement_radius: 30
+ refinement_nn_ratio: 0.9
+ refinement_max_matches: 0
diff --git a/third_party/pram/configs/config_train_cambridge_sfd2.yaml b/third_party/pram/configs/config_train_cambridge_sfd2.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..8cc843ee963dc5c0041954790d7e622e24aefe16
--- /dev/null
+++ b/third_party/pram/configs/config_train_cambridge_sfd2.yaml
@@ -0,0 +1,103 @@
+dataset: [ 'CambridgeLandmarks' ]
+
+network_: "segnet"
+network: "segnetvit"
+
+local_rank: 0
+gpu: [ 0 ]
+
+feature: "sfd2"
+save_path: '/scratches/flyer_2/fx221/exp/pram'
+landmark_path: "/scratches/flyer_3/fx221/exp/pram/landmarks/sfd2-gml"
+dataset_path: "/scratches/flyer_3/fx221/dataset"
+config_path: 'configs/datasets'
+
+image_dim: 3
+feat_dim: 128
+
+min_inliers: 32
+max_inliers: 512
+random_inliers: 1
+max_keypoints: 1024
+ignore_index: -1
+output_dim: 1024
+output_dim_: 2048
+jitter_params:
+ brightness: 0.5
+ contrast: 0.5
+ saturation: 0.25
+ hue: 0.15
+ blur: 0
+
+scale_params: [ 0.5, 1.0 ]
+pre_load: false
+do_eval: false
+train: true
+inlier_th: 0.5
+lr: 0.0001
+min_lr: 0.00001
+epochs: 300
+seg_loss: "cew"
+ac_fn: "relu"
+norm_fn: "bn"
+workers: 8
+layers: 15
+log_intervals: 50
+eval_n_epoch: 10
+
+use_mid_feature: true
+norm_desc: false
+with_score: false
+with_aug: true
+with_dist: true
+
+batch_size: 32
+its_per_epoch: 1000
+decay_rate: 0.999992
+decay_iter: 60000
+
+cluster_method: 'birch'
+
+weight_path: null
+weight_path_1: '20230725_144044_segnet_L15_C_resnet4x_B32_K1024_relu_bn_od1024_nc161_adam_cew_md_A_birch/segnet.260.pth'
+weight_path_2: '20240204_130323_segnetvit_L15_C_resnet4x_B32_K1024_relu_bn_od1024_nc161_adamw_cew_md_A_birch/segnetvit.399.pth'
+
+resume_path: null
+
+n_class: 161
+
+eval_max_keypoints: 2048
+
+localization:
+ loc_scene_name_1: [ 'GreatCourt' ]
+ loc_scene_name_2: [ 'KingsCollege' ]
+ loc_scene_name: [ 'StMarysChurch' ]
+ loc_scene_name_4: [ 'OldHospital' ]
+ save_path: '/scratches/flyer_2/fx221/exp/localizer/loc_results'
+ seg_k: 30
+ threshold: 12
+ min_kpts: 256
+ min_matches: 16
+ min_inliers_gm: 128
+ min_inliers: 128
+ matching_method_: "mnn"
+ matching_method_1: "spg"
+ matching_method_2: "gm"
+ matching_method: "gml"
+ matching_method_4: "adagml"
+ show: true
+ show_time: 1
+ save: false
+ with_original: true
+ max_vrf: 1
+ with_extra: false
+ with_compress: true
+ semantic_matching: true
+ do_refinement: true
+ pre_filtering_th: 0.95
+ refinement_method_: 'matching'
+ refinement_method: 'projection'
+ covisibility_frame: 20
+ refinement_radius: 20
+ refinement_nn_ratio: 0.9
+ refinement_max_matches: 0
diff --git a/third_party/pram/configs/config_train_multiset_sfd2.yaml b/third_party/pram/configs/config_train_multiset_sfd2.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..90618e0812c2321ba05fbe3ab9a12d52ec447e99
--- /dev/null
+++ b/third_party/pram/configs/config_train_multiset_sfd2.yaml
@@ -0,0 +1,100 @@
+dataset: [ 'S', 'T', 'C', 'A' ]
+
+network: "segnet"
+network_: "gsegnet3"
+
+local_rank: 0
+gpu: [ 4 ]
+
+feature: "resnet4x"
+save_path: '/scratches/flyer_2/fx221/exp/localizer'
+landmark_path: "/scratches/flyer_3/fx221/exp/localizer/resnet4x-20230511-210205-pho-0005-gm"
+dataset_path: "/scratches/flyer_3/fx221/dataset"
+config_path: 'configs/datasets'
+
+image_dim: 3
+min_inliers: 32
+max_inliers: 512
+random_inliers: 1
+max_keypoints: 1024
+ignore_index: -1
+output_dim: 1024
+output_dim_: 2048
+jitter_params:
+ brightness: 0.5
+ contrast: 0.5
+ saturation: 0.25
+ hue: 0.15
+ blur: 0
+
+scale_params: [ 0.5, 1.0 ]
+pre_load: false
+do_eval: true
+train: true
+inlier_th: 0.5
+lr: 0.0001
+min_lr: 0.00001
+optimizer: "adam"
+seg_loss: "cew"
+seg_loss_nx: "cei"
+cls_loss: "ce"
+cls_loss_: "bce"
+sc_loss: 'l1g'
+ac_fn: "relu"
+norm_fn: "bn"
+workers: 8
+layers: 15
+log_intervals: 50
+eval_n_epoch: 10
+
+use_mid_feature: true
+norm_desc: false
+with_sc: false
+with_cls: true
+with_score: false
+with_aug: true
+with_dist: true
+
+batch_size: 32
+its_per_epoch: 1000
+decay_rate: 0.999992
+decay_iter: 150000
+epochs: 1500
+
+cluster_method_: 'kmeans'
+cluster_method: 'birch'
+
+weight_path_: null
+weight_path: '20230805_132653_segnet_L15_STCA_resnet4x_B32_K1024_relu_bn_od1024_nc977_adam_cew_md_A_birch/segnet.485.pth'
+resume_path: null
+
+eval: false
+#loc: false
+loc: true
+#n_class: 977
+online: false
+
+eval_max_keypoints: 4096
+
+localization:
+ loc_scene_name: [ ]
+ save_path: '/scratches/flyer_2/fx221/exp/localizer/loc_results'
+ dataset: [ 'T' ]
+ seg_k: 50
+ threshold: 8 # 8 for indoor, 12 for outdoor
+ min_kpts: 256
+ min_matches: 4
+ min_inliers: 64
+ matching_method_: "mnn"
+ matching_method_1: "spg"
+ matching_method: "gm"
+ save: false
+ show: true
+ show_time: 1
+ do_refinement: true
+ with_original: true
+ with_extra: false
+ max_vrf: 1
+ with_compress: false
+ covisibility_frame: 20
+ observation_threshold: 3
diff --git a/third_party/pram/configs/datasets/12Scenes.yaml b/third_party/pram/configs/datasets/12Scenes.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..e950aca2ff25526af622fec779e9bb6a07eaea6b
--- /dev/null
+++ b/third_party/pram/configs/datasets/12Scenes.yaml
@@ -0,0 +1,166 @@
+dataset: '12Scenes'
+scenes: [ 'apt1/kitchen',
+ 'apt1/living',
+ 'apt2/bed',
+ 'apt2/kitchen',
+ 'apt2/living',
+ 'apt2/luke',
+ 'office1/gates362',
+ 'office1/gates381',
+ 'office1/lounge',
+ 'office1/manolis',
+ 'office2/5a',
+ 'office2/5b'
+]
+
+apt1/kitchen:
+ n_cluster: 16
+ cluster_mode: 'xy'
+ cluster_method: 'birch'
+
+ training_sample_ratio: 1
+ eval_sample_ratio: 5
+ query_path: 'queries_with_intrinsics.txt'
+ gt_pose_path: 'queries_poses.txt'
+ image_path_prefix: ''
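+  # note (inferred from dataset/get_dataset.py): n_cluster is the number of 3D-point
+  # segments for this scene, cluster_mode the point coordinates used for clustering
+  # ('xy' or 'xz'), and *_sample_ratio keeps every N-th frame where the loader supports it.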
+
+
+apt1/living:
+ n_cluster: 16
+ cluster_mode: 'xy'
+ cluster_method: 'birch'
+
+ training_sample_ratio: 1
+ eval_sample_ratio: 5
+ image_path_prefix: ''
+ query_path: 'queries_with_intrinsics.txt'
+ gt_pose_path: 'queries_poses.txt'
+
+apt2/bed:
+ n_cluster: 16
+ cluster_mode: 'xy'
+ cluster_method: 'birch'
+
+ training_sample_ratio: 1
+ eval_sample_ratio: 5
+ image_path_prefix: ''
+
+ query_path: 'queries_with_intrinsics.txt'
+ gt_pose_path: 'queries_poses.txt'
+
+
+apt2/kitchen:
+ n_cluster: 16
+ cluster_mode: 'xy'
+ cluster_method: 'birch'
+
+ training_sample_ratio: 1
+ eval_sample_ratio: 5
+ image_path_prefix: ''
+
+ query_path: 'queries_with_intrinsics.txt'
+ gt_pose_path: 'queries_poses.txt'
+
+
+apt2/living:
+ n_cluster: 16
+ cluster_mode: 'xy'
+ cluster_method: 'birch'
+
+ training_sample_ratio: 1
+ eval_sample_ratio: 5
+ image_path_prefix: ''
+
+ query_path: 'queries_with_intrinsics.txt'
+ gt_pose_path: 'queries_poses.txt'
+
+
+apt2/luke:
+ n_cluster: 16
+ cluster_mode: 'xy'
+ cluster_method: 'birch'
+
+ training_sample_ratio: 1
+ eval_sample_ratio: 5
+ image_path_prefix: ''
+
+ query_path: 'queries_with_intrinsics.txt'
+ gt_pose_path: 'queries_poses.txt'
+
+
+office1/gates362:
+ n_cluster: 16
+ cluster_mode: 'xy'
+ cluster_method: 'birch'
+
+ training_sample_ratio: 3
+ eval_sample_ratio: 5
+ image_path_prefix: ''
+
+ query_path: 'queries_with_intrinsics.txt'
+ gt_pose_path: 'queries_poses.txt'
+
+
+office1/gates381:
+ n_cluster: 16
+ cluster_mode: 'xy'
+ cluster_method: 'birch'
+
+ training_sample_ratio: 3
+ eval_sample_ratio: 5
+ image_path_prefix: ''
+
+ query_path: 'queries_with_intrinsics.txt'
+ gt_pose_path: 'queries_poses.txt'
+
+
+office1/lounge:
+ n_cluster: 16
+ cluster_mode: 'xy'
+ cluster_method: 'birch'
+
+ training_sample_ratio: 1
+ eval_sample_ratio: 5
+ image_path_prefix: ''
+
+ query_path: 'queries_with_intrinsics.txt'
+ gt_pose_path: 'queries_poses.txt'
+
+
+office1/manolis:
+ n_cluster: 16
+ cluster_mode: 'xy'
+ cluster_method: 'birch'
+
+ training_sample_ratio: 1
+ eval_sample_ratio: 5
+ image_path_prefix: ''
+
+ query_path: 'queries_with_intrinsics.txt'
+ gt_pose_path: 'queries_poses.txt'
+
+
+office2/5a:
+ n_cluster: 16
+ cluster_mode: 'xy'
+ cluster_method: 'birch'
+
+ training_sample_ratio: 1
+ eval_sample_ratio: 5
+ image_path_prefix: ''
+
+ query_path: 'queries_with_intrinsics.txt'
+ gt_pose_path: 'queries_poses.txt'
+
+
+office2/5b:
+ n_cluster: 16
+ cluster_mode: 'xy'
+ cluster_method: 'birch'
+
+ training_sample_ratio: 1
+ eval_sample_ratio: 5
+ image_path_prefix: ''
+
+ query_path: 'queries_with_intrinsics.txt'
+ gt_pose_path: 'queries_poses.txt'
diff --git a/third_party/pram/configs/datasets/7Scenes.yaml b/third_party/pram/configs/datasets/7Scenes.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..fd68181fbc0ed96ccb3e464d94a5346183c1dfe3
--- /dev/null
+++ b/third_party/pram/configs/datasets/7Scenes.yaml
@@ -0,0 +1,96 @@
+dataset: '7Scenes'
+scenes: [ 'chess', 'heads', 'office', 'fire', 'stairs', 'redkitchen', 'pumpkin' ]
+
+
+chess:
+ n_cluster: 16
+ cluster_mode: 'xz'
+ cluster_method_: 'kmeans'
+ cluster_method: 'birch'
+
+ training_sample_ratio: 2
+ eval_sample_ratio: 10
+ gt_pose_path: 'queries_poses.txt'
+ query_path: 'queries_with_intrinsics.txt'
+ image_path_prefix: ''
+
+
+
+heads:
+ n_cluster: 16
+ cluster_mode: 'xz'
+ cluster_method_: 'kmeans'
+ cluster_method: 'birch'
+
+ training_sample_ratio: 1
+ eval_sample_ratio: 2
+ gt_pose_path: 'queries_poses.txt'
+ query_path: 'queries_with_intrinsics.txt'
+ image_path_prefix: ''
+
+
+office:
+ n_cluster: 16
+ cluster_mode: 'xz'
+ cluster_method_: 'kmeans'
+ cluster_method: 'birch'
+
+ training_sample_ratio: 3
+ eval_sample_ratio: 10
+ gt_pose_path: 'queries_poses.txt'
+ query_path: 'queries_with_intrinsics.txt'
+ image_path_prefix: ''
+
+fire:
+ n_cluster: 16
+ cluster_mode: 'xz'
+ cluster_method_: 'kmeans'
+ cluster_method: 'birch'
+
+ training_sample_ratio: 2
+ eval_sample_ratio: 5
+ gt_pose_path: 'queries_poses.txt'
+ query_path: 'queries_with_intrinsics.txt'
+ image_path_prefix: ''
+
+
+stairs:
+ n_cluster: 16
+ cluster_mode: 'xz'
+ cluster_method_: 'kmeans'
+ cluster_method: 'birch'
+
+ training_sample_ratio: 1
+ eval_sample_ratio: 10
+ gt_pose_path: 'queries_poses.txt'
+ query_path: 'queries_with_intrinsics.txt'
+ image_path_prefix: ''
+
+
+redkitchen:
+ n_cluster: 16
+ cluster_mode: 'xz'
+ cluster_method_: 'kmeans'
+ cluster_method: 'birch'
+
+ training_sample_ratio: 3
+ eval_sample_ratio: 10
+ gt_pose_path: 'queries_poses.txt'
+ query_path: 'queries_with_intrinsics.txt'
+ image_path_prefix: ''
+
+
+
+
+pumpkin:
+ n_cluster: 16
+ cluster_mode: 'xz'
+ cluster_method_: 'kmeans'
+ cluster_method: 'birch'
+
+ training_sample_ratio: 2
+ eval_sample_ratio: 10
+ gt_pose_path: 'queries_poses.txt'
+ query_path: 'queries_with_intrinsics.txt'
+ image_path_prefix: ''
+
diff --git a/third_party/pram/configs/datasets/Aachen.yaml b/third_party/pram/configs/datasets/Aachen.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..49477afbe569cb0fc4317b6c1a98c30f261ee7e0
--- /dev/null
+++ b/third_party/pram/configs/datasets/Aachen.yaml
@@ -0,0 +1,15 @@
+dataset: 'Aachen'
+
+scenes: [ 'Aachenv11' ]
+
+Aachenv11:
+ n_cluster: 512
+ cluster_mode: 'xz'
+ cluster_method_: 'kmeans'
+ cluster_method: 'birch'
+ training_sample_ratio: 1
+ eval_sample_ratio: 1
+ image_path_prefix: 'images/images_upright'
+ query_path_: 'queries_with_intrinsics.txt'
+ query_path: 'queries_with_intrinsics_demo.txt'
+ gt_pose_path: 'queries_pose_spp_spg.txt'
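+# note: the training configs use n_class = n_cluster + 1 (class 0 is reserved for
+# unsegmented points), so n_cluster: 512 here pairs with n_class: 513 there.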
diff --git a/third_party/pram/configs/datasets/CambridgeLandmarks.yaml b/third_party/pram/configs/datasets/CambridgeLandmarks.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..c3a757898db1e772b593059d2c21ef1eaaa825ea
--- /dev/null
+++ b/third_party/pram/configs/datasets/CambridgeLandmarks.yaml
@@ -0,0 +1,67 @@
+dataset: 'CambridgeLandmarks'
+scenes: [ 'GreatCourt', 'KingsCollege', 'OldHospital', 'ShopFacade', 'StMarysChurch' ]
+
+GreatCourt:
+ n_cluster: 32
+ cluster_mode: 'xy'
+ cluster_method: 'birch'
+
+ training_sample_ratio: 1
+ eval_sample_ratio: 1
+ image_path_prefix: ''
+ query_path: 'queries_with_intrinsics.txt'
+ gt_pose_path: 'queries_poses.txt'
+
+
+KingsCollege:
+ n_cluster: 32
+ cluster_mode: 'xy'
+ cluster_method: 'birch'
+
+ training_sample_ratio: 1
+ eval_sample_ratio: 1
+ image_path_prefix: ''
+
+ query_path: 'queries_with_intrinsics.txt'
+ gt_pose_path: 'queries_poses.txt'
+
+
+OldHospital:
+ n_cluster: 32
+ cluster_mode: 'xz'
+ cluster_method: 'birch'
+
+ training_sample_ratio: 1
+ eval_sample_ratio: 1
+ image_path_prefix: ''
+ query_path: 'queries_with_intrinsics.txt'
+ gt_pose_path: 'queries_poses.txt'
+
+
+ShopFacade:
+ n_cluster: 32
+ cluster_mode: 'xy'
+ cluster_method: 'birch'
+
+ training_sample_ratio: 1
+ eval_sample_ratio: 1
+ image_path_prefix: ''
+
+ query_path: 'queries_with_intrinsics.txt'
+ gt_pose_path: 'queries_poses.txt'
+
+
+StMarysChurch:
+ n_cluster: 32
+ cluster_mode: 'xz'
+ cluster_method: 'birch'
+
+ training_sample_ratio: 1
+ eval_sample_ratio: 1
+ image_path_prefix: ''
+
+ query_path: 'queries_with_intrinsics.txt'
+ gt_pose_path: 'queries_poses.txt'
+
+
+
diff --git a/third_party/pram/dataset/aachen.py b/third_party/pram/dataset/aachen.py
new file mode 100644
index 0000000000000000000000000000000000000000..d57efd8e4460f943d66b2d8b92e57d7cd7f7f75a
--- /dev/null
+++ b/third_party/pram/dataset/aachen.py
@@ -0,0 +1,119 @@
+# -*- coding: UTF-8 -*-
+'''=================================================
+@Project -> File pram -> aachen
+@IDE PyCharm
+@Author fx221@cam.ac.uk
+@Date 29/01/2024 14:33
+=================================================='''
+import os.path as osp
+import numpy as np
+import cv2
+from colmap_utils.read_write_model import read_model
+import torchvision.transforms as tvt
+from dataset.basicdataset import BasicDataset
+
+
+class Aachen(BasicDataset):
+ def __init__(self, landmark_path, scene, dataset_path, n_class, seg_mode, seg_method, dataset='Aachen',
+ nfeatures=1024,
+ query_p3d_fn=None,
+ train=True,
+ with_aug=False,
+ min_inliers=0,
+ max_inliers=4096,
+ random_inliers=False,
+ jitter_params=None,
+ scale_params=None,
+ image_dim=3,
+ query_info_path=None,
+ sample_ratio=1, ):
+ self.landmark_path = osp.join(landmark_path, scene)
+ self.dataset_path = osp.join(dataset_path, scene)
+ self.n_class = n_class
+ self.dataset = dataset + '/' + scene
+ self.nfeatures = nfeatures
+ self.with_aug = with_aug
+ self.jitter_params = jitter_params
+ self.scale_params = scale_params
+ self.image_dim = image_dim
+ self.train = train
+ self.min_inliers = min_inliers
+ self.max_inliers = max_inliers if max_inliers < nfeatures else nfeatures
+ self.random_inliers = random_inliers
+ self.image_prefix = 'images/images_upright'
+
+ train_transforms = []
+ if self.with_aug:
+ train_transforms.append(tvt.ColorJitter(
+ brightness=jitter_params['brightness'],
+ contrast=jitter_params['contrast'],
+ saturation=jitter_params['saturation'],
+ hue=jitter_params['hue']))
+ if jitter_params['blur'] > 0:
+ train_transforms.append(tvt.GaussianBlur(kernel_size=int(jitter_params['blur'])))
+ self.train_transforms = tvt.Compose(train_transforms)
+
+ if train:
+ self.cameras, self.images, point3Ds = read_model(path=osp.join(self.landmark_path, '3D-models'), ext='.bin')
+ self.name_to_id = {image.name: i for i, image in self.images.items() if len(self.images[i].point3D_ids) > 0}
+
+ # only for testing of query images
+ if not self.train:
+ data = np.load(query_p3d_fn, allow_pickle=True)[()]
+ self.img_p3d = data
+ else:
+ self.img_p3d = {}
+
+ self.img_fns = []
+ if train:
+ with open(osp.join(self.dataset_path, 'aachen_db_imglist.txt'), 'r') as f:
+ lines = f.readlines()
+ for l in lines:
+ l = l.strip()
+ if l not in self.name_to_id.keys():
+ continue
+ self.img_fns.append(l)
+ else:
+ with open(osp.join(self.dataset_path, 'queries', 'day_time_queries_with_intrinsics.txt'), 'r') as f:
+ lines = f.readlines()
+ for l in lines:
+ l = l.strip().split()[0]
+ if l not in self.img_p3d.keys():
+ continue
+ self.img_fns.append(l)
+ with open(osp.join(self.dataset_path, 'queries', 'night_time_queries_with_intrinsics.txt'), 'r') as f:
+ lines = f.readlines()
+ for l in lines:
+ l = l.strip().split()[0]
+ if l not in self.img_p3d.keys():
+ continue
+ self.img_fns.append(l)
+
+ print(
+ 'Load {} images from {} for {}...'.format(len(self.img_fns), self.dataset, 'training' if train else 'eval'))
+
+ data = np.load(osp.join(self.landmark_path,
+ 'point3D_cluster_n{:d}_{:s}_{:s}.npy'.format(n_class - 1, seg_mode, seg_method)),
+ allow_pickle=True)[()]
+ p3d_id = data['id']
+ seg_id = data['label']
+ self.p3d_seg = {p3d_id[i]: seg_id[i] for i in range(p3d_id.shape[0])}
+ xyzs = data['xyz']
+ self.p3d_xyzs = {p3d_id[i]: xyzs[i] for i in range(p3d_id.shape[0])}
+
+ with open(osp.join(self.landmark_path, 'sc_mean_scale.txt'), 'r') as f:
+ lines = f.readlines()
+ for l in lines:
+ l = l.strip().split()
+ self.mean_xyz = np.array([float(v) for v in l[:3]])
+ self.scale_xyz = np.array([float(v) for v in l[3:]])
+
+ if not train:
+ self.query_info = self.read_query_info(path=query_info_path)
+
+ self.nfeatures = nfeatures
+ self.feature_dir = osp.join(self.landmark_path, 'feats')
+ self.feats = {}
+
+ def read_image(self, image_name):
+ return cv2.imread(osp.join(self.dataset_path, 'images/images_upright/', image_name))
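+
+
+# Expected on-disk layout (inferred from the paths used above; only the files this
+# class actually reads are listed):
+#   <landmark_path>/<scene>/{3D-models/, feats/, point3D_cluster_*.npy, sc_mean_scale.txt}
+#   <dataset_path>/<scene>/{aachen_db_imglist.txt, queries/*.txt, images/images_upright/}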
diff --git a/third_party/pram/dataset/basicdataset.py b/third_party/pram/dataset/basicdataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..c77c32ca010e99d14ddd8643c2ff07789bd75851
--- /dev/null
+++ b/third_party/pram/dataset/basicdataset.py
@@ -0,0 +1,477 @@
+# -*- coding: UTF-8 -*-
+'''=================================================
+@Project -> File pram -> basicdataset
+@IDE PyCharm
+@Author fx221@cam.ac.uk
+@Date 29/01/2024 14:27
+=================================================='''
+import torchvision.transforms.functional as tvf
+import torchvision.transforms as tvt
+import os.path as osp
+import numpy as np
+import cv2
+from colmap_utils.read_write_model import qvec2rotmat, read_model
+from dataset.utils import normalize_size
+
+
+class BasicDataset:
+ def __init__(self,
+ img_list_fn,
+ feature_dir,
+ sfm_path,
+ seg_fn,
+ dataset_path,
+ n_class,
+ dataset,
+ nfeatures=1024,
+ query_p3d_fn=None,
+ train=True,
+ with_aug=False,
+ min_inliers=0,
+ max_inliers=4096,
+ random_inliers=False,
+ jitter_params=None,
+ scale_params=None,
+ image_dim=1,
+ pre_load=False,
+ query_info_path=None,
+ sc_mean_scale_fn=None,
+ ):
+ self.n_class = n_class
+ self.train = train
+ self.min_inliers = min_inliers
+ self.max_inliers = max_inliers if max_inliers < nfeatures else nfeatures
+ self.random_inliers = random_inliers
+ self.dataset_path = dataset_path
+ self.with_aug = with_aug
+ self.dataset = dataset
+ self.jitter_params = jitter_params
+ self.scale_params = scale_params
+ self.image_dim = image_dim
+ self.image_prefix = ''
+
+ train_transforms = []
+ if self.with_aug:
+ train_transforms.append(tvt.ColorJitter(
+ brightness=jitter_params['brightness'],
+ contrast=jitter_params['contrast'],
+ saturation=jitter_params['saturation'],
+ hue=jitter_params['hue']))
+ if jitter_params['blur'] > 0:
+ train_transforms.append(tvt.GaussianBlur(kernel_size=int(jitter_params['blur'])))
+ self.train_transforms = tvt.Compose(train_transforms)
+
+ # only for testing of query images
+ if not self.train:
+ data = np.load(query_p3d_fn, allow_pickle=True)[()]
+ self.img_p3d = data
+ else:
+ self.img_p3d = {}
+
+ self.img_fns = []
+ with open(img_list_fn, 'r') as f:
+ lines = f.readlines()
+ for l in lines:
+ l = l.strip()
+ self.img_fns.append(l)
+ print('Load {} images from {} for {}...'.format(len(self.img_fns), dataset, 'training' if train else 'eval'))
+ self.feats = {}
+ if train:
+ self.cameras, self.images, point3Ds = read_model(path=sfm_path, ext='.bin')
+ self.name_to_id = {image.name: i for i, image in self.images.items()}
+
+ data = np.load(seg_fn, allow_pickle=True)[()]
+ p3d_id = data['id']
+ seg_id = data['label']
+ self.p3d_seg = {p3d_id[i]: seg_id[i] for i in range(p3d_id.shape[0])}
+ self.p3d_xyzs = {}
+
+ for pid in self.p3d_seg.keys():
+ p3d = point3Ds[pid]
+ self.p3d_xyzs[pid] = p3d.xyz
+
+ with open(sc_mean_scale_fn, 'r') as f:
+ lines = f.readlines()
+ for l in lines:
+ l = l.strip().split()
+ self.mean_xyz = np.array([float(v) for v in l[:3]])
+ self.scale_xyz = np.array([float(v) for v in l[3:]])
+
+ if not train:
+ self.query_info = self.read_query_info(path=query_info_path)
+
+ self.nfeatures = nfeatures
+ self.feature_dir = feature_dir
+        print('Preloaded {} feats, mean xyz {}, scale xyz {}'.format(
+            len(self.feats.keys()), self.mean_xyz, self.scale_xyz))
+
+ def normalize_p3ds(self, p3ds):
+ mean_p3ds = np.ceil(np.mean(p3ds, axis=0))
+ p3ds_ = p3ds - mean_p3ds
+ dx = np.max(abs(p3ds_[:, 0]))
+ dy = np.max(abs(p3ds_[:, 1]))
+ dz = np.max(abs(p3ds_[:, 2]))
+ scale_p3ds = np.ceil(np.array([dx, dy, dz], dtype=float).reshape(3, ))
+ scale_p3ds[scale_p3ds < 1] = 1
+ scale_p3ds[scale_p3ds == 0] = 1
+ return mean_p3ds, scale_p3ds
+
+ def read_query_info(self, path):
+ query_info = {}
+ with open(path, 'r') as f:
+ lines = f.readlines()
+ for l in lines:
+ l = l.strip().split()
+ image_name = l[0]
+ cam_model = l[1]
+ h, w = int(l[2]), int(l[3])
+ params = np.array([float(v) for v in l[4:]])
+ query_info[image_name] = {
+ 'width': w,
+ 'height': h,
+ 'model': cam_model,
+ 'params': params,
+ }
+ return query_info
+
+ def extract_intrinsic_extrinsic_params(self, image_id):
+ cam = self.cameras[self.images[image_id].camera_id]
+ params = cam.params
+ model = cam.model
+ if model in ("SIMPLE_PINHOLE", "SIMPLE_RADIAL", "RADIAL"):
+ fx = fy = params[0]
+ cx = params[1]
+ cy = params[2]
+ elif model in ("PINHOLE", "OPENCV", "OPENCV_FISHEYE", "FULL_OPENCV"):
+ fx = params[0]
+ fy = params[1]
+ cx = params[2]
+ cy = params[3]
+ else:
+ raise Exception("Camera model not supported")
+ K = np.eye(3, dtype=float)
+ K[0, 0] = fx
+ K[1, 1] = fy
+ K[0, 2] = cx
+ K[1, 2] = cy
+
+ qvec = self.images[image_id].qvec
+ tvec = self.images[image_id].tvec
+ R = qvec2rotmat(qvec=qvec)
+ P = np.eye(4, dtype=float)
+ P[:3, :3] = R
+ P[:3, 3] = tvec.reshape(3, )
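+        # P is the world-to-camera pose [R | t] following COLMAP's qvec/tvec convention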
+
+ return {'K': K, 'P': P}
+
+ def get_item_train(self, idx):
+ img_name = self.img_fns[idx]
+ if img_name in self.feats.keys():
+ feat_data = self.feats[img_name]
+ else:
+ feat_data = np.load(osp.join(self.feature_dir, img_name.replace('/', '+') + '.npy'), allow_pickle=True)[()]
+ # descs = feat_data['descriptors'] # [N, D]
+ scores = feat_data['scores'] # [N, 1]
+ kpts = feat_data['keypoints'] # [N, 2]
+ image_size = feat_data['image_size']
+
+ nfeat = kpts.shape[0]
+
+ # print(img_name, self.name_to_id[img_name])
+ p3d_ids = self.images[self.name_to_id[img_name]].point3D_ids
+ p3d_xyzs = np.zeros(shape=(nfeat, 3), dtype=float)
+
+ seg_ids = np.zeros(shape=(nfeat,), dtype=int) # + self.n_class - 1
+ for i in range(nfeat):
+ p3d = p3d_ids[i]
+ if p3d in self.p3d_seg.keys():
+ seg_ids[i] = self.p3d_seg[p3d] + 1 # 0 for invalid
+ if seg_ids[i] == -1:
+ seg_ids[i] = 0
+
+ if p3d in self.p3d_xyzs.keys():
+ p3d_xyzs[i] = self.p3d_xyzs[p3d]
+
+ seg_ids = np.array(seg_ids).reshape(-1, )
+
+ n_inliers = np.sum(seg_ids > 0)
+ n_outliers = np.sum(seg_ids == 0)
+ inlier_ids = np.where(seg_ids > 0)[0]
+ outlier_ids = np.where(seg_ids == 0)[0]
+
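+        # sampling strategy: if there are too few inliers, keep them all and pad with
+        # randomly shuffled outliers; otherwise draw a random number of inliers in
+        # [min_inliers, max_inliers) and fill the remainder with outliers, targeting
+        # self.nfeatures keypoints per training sample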
+ if n_inliers <= self.min_inliers:
+ sel_inliers = n_inliers
+ sel_outliers = self.nfeatures - sel_inliers
+
+ out_ids = np.arange(n_outliers)
+ np.random.shuffle(out_ids)
+ sel_ids = np.hstack([inlier_ids, outlier_ids[out_ids[:self.nfeatures - n_inliers]]])
+ else:
+ sel_inliers = np.random.randint(self.min_inliers, self.max_inliers)
+ if sel_inliers > n_inliers:
+ sel_inliers = n_inliers
+
+ if sel_inliers + n_outliers < self.nfeatures:
+ sel_inliers = self.nfeatures - n_outliers
+
+ sel_outliers = self.nfeatures - sel_inliers
+
+ in_ids = np.arange(n_inliers)
+ np.random.shuffle(in_ids)
+ sel_inlier_ids = inlier_ids[in_ids[:sel_inliers]]
+
+ out_ids = np.arange(n_outliers)
+ np.random.shuffle(out_ids)
+ sel_outlier_ids = outlier_ids[out_ids[:sel_outliers]]
+
+ sel_ids = np.hstack([sel_inlier_ids, sel_outlier_ids])
+
+ # sel_descs = descs[sel_ids]
+ sel_scores = scores[sel_ids]
+ sel_kpts = kpts[sel_ids]
+ sel_seg_ids = seg_ids[sel_ids]
+ sel_xyzs = p3d_xyzs[sel_ids]
+
+ shuffle_ids = np.arange(sel_ids.shape[0])
+ np.random.shuffle(shuffle_ids)
+ # sel_descs = sel_descs[shuffle_ids]
+ sel_scores = sel_scores[shuffle_ids]
+ sel_kpts = sel_kpts[shuffle_ids]
+ sel_seg_ids = sel_seg_ids[shuffle_ids]
+ sel_xyzs = sel_xyzs[shuffle_ids]
+
+ if sel_kpts.shape[0] < self.nfeatures:
+ # print(sel_descs.shape, sel_kpts.shape, sel_scores.shape, sel_seg_ids.shape, sel_xyzs.shape)
+ valid_sel_ids = np.array([v for v in range(sel_kpts.shape[0]) if sel_seg_ids[v] > 0], dtype=int)
+ # ref_sel_id = np.random.choice(valid_sel_ids, size=1)[0]
+ if valid_sel_ids.shape[0] == 0:
+ valid_sel_ids = np.array([v for v in range(sel_kpts.shape[0])], dtype=int)
+ random_n = self.nfeatures - sel_kpts.shape[0]
+ random_scores = np.random.random((random_n,))
+ random_kpts, random_seg_ids, random_xyzs = self.random_points_from_reference(
+ n=random_n,
+ ref_kpts=sel_kpts[valid_sel_ids],
+ ref_segs=sel_seg_ids[valid_sel_ids],
+ ref_xyzs=sel_xyzs[valid_sel_ids],
+ radius=5,
+ )
+ # sel_descs = np.vstack([sel_descs, random_descs])
+ sel_scores = np.hstack([sel_scores, random_scores])
+ sel_kpts = np.vstack([sel_kpts, random_kpts])
+ sel_seg_ids = np.hstack([sel_seg_ids, random_seg_ids])
+ sel_xyzs = np.vstack([sel_xyzs, random_xyzs])
+
+ gt_n_seg = np.zeros(shape=(self.n_class,), dtype=int)
+ gt_cls = np.zeros(shape=(self.n_class,), dtype=int)
+ gt_cls_dist = np.zeros(shape=(self.n_class,), dtype=float)
+ uids = np.unique(sel_seg_ids).tolist()
+ for uid in uids:
+ if uid == 0:
+ continue
+ gt_cls[uid] = 1
+ gt_n_seg[uid] = np.sum(sel_seg_ids == uid)
+ gt_cls_dist[uid] = np.sum(seg_ids == uid) / np.sum(seg_ids > 0) # [valid_id / total_valid_id]
+
+ param_out = self.extract_intrinsic_extrinsic_params(image_id=self.name_to_id[img_name])
+
+ img = self.read_image(image_name=img_name)
+ image_size = img.shape[:2]
+ if self.image_dim == 1:
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
+ else:
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+ if self.with_aug:
+ nh = img.shape[0]
+ nw = img.shape[1]
+ if self.scale_params is not None:
+ do_scale = np.random.random()
+ if do_scale <= 0.25:
+ p = np.random.randint(0, 11)
+ s = self.scale_params[0] + (self.scale_params[1] - self.scale_params[0]) / 10 * p
+ nh = int(img.shape[0] * s)
+ nw = int(img.shape[1] * s)
+ sh = nh / img.shape[0]
+ sw = nw / img.shape[1]
+ sel_kpts[:, 0] = sel_kpts[:, 0] * sw
+ sel_kpts[:, 1] = sel_kpts[:, 1] * sh
+ img = cv2.resize(img, dsize=(nw, nh))
+
+ brightness = np.random.uniform(-self.jitter_params['brightness'], self.jitter_params['brightness']) * 255
+ contrast = 1 + np.random.uniform(-self.jitter_params['contrast'], self.jitter_params['contrast'])
+ img = cv2.addWeighted(img, contrast, img, 0, brightness)
+ img = np.clip(img, a_min=0, a_max=255)
+ if self.image_dim == 1:
+ img = img[..., None]
+ img = img.astype(float) / 255.
+ image_size = np.array([nh, nw], dtype=int)
+ else:
+ if self.image_dim == 1:
+ img = img[..., None].astype(float) / 255.
+
+ output = {
+ # 'descriptors': sel_descs, # may not be used
+ 'scores': sel_scores,
+ 'keypoints': sel_kpts,
+ 'norm_keypoints': normalize_size(x=sel_kpts, size=image_size),
+ 'image': [img],
+ 'gt_seg': sel_seg_ids,
+ 'gt_cls': gt_cls,
+ 'gt_cls_dist': gt_cls_dist,
+ 'gt_n_seg': gt_n_seg,
+ 'file_name': img_name,
+ 'prefix_name': self.image_prefix,
+ # 'mean_xyz': self.mean_xyz,
+ # 'scale_xyz': self.scale_xyz,
+ # 'gt_sc': sel_xyzs,
+ # 'gt_norm_sc': (sel_xyzs - self.mean_xyz) / self.scale_xyz,
+ 'K': param_out['K'],
+ 'gt_P': param_out['P']
+ }
+ return output
+
+ def get_item_test(self, idx):
+
+ # evaluation of recognition only
+ img_name = self.img_fns[idx]
+ feat_data = np.load(osp.join(self.feature_dir, img_name.replace('/', '+') + '.npy'), allow_pickle=True)[()]
+ descs = feat_data['descriptors'] # [N, D]
+ scores = feat_data['scores'] # [N, 1]
+ kpts = feat_data['keypoints'] # [N, 2]
+ image_size = feat_data['image_size']
+
+ nfeat = descs.shape[0]
+
+ if img_name in self.img_p3d.keys():
+ p3d_ids = self.img_p3d[img_name]
+ p3d_xyzs = np.zeros(shape=(nfeat, 3), dtype=float)
+ seg_ids = np.zeros(shape=(nfeat,), dtype=int) # attention! by default invalid!!!
+ for i in range(nfeat):
+ p3d = p3d_ids[i]
+ if p3d in self.p3d_seg.keys():
+ seg_ids[i] = self.p3d_seg[p3d] + 1
+ if seg_ids[i] == -1:
+                        seg_ids[i] = 0  # 0 for invalid
+
+ if p3d in self.p3d_xyzs.keys():
+ p3d_xyzs[i] = self.p3d_xyzs[p3d]
+
+ seg_ids = np.array(seg_ids).reshape(-1, )
+
+ if self.nfeatures > 0:
+ sorted_ids = np.argsort(scores)[::-1][:self.nfeatures] # large to small
+ descs = descs[sorted_ids]
+ scores = scores[sorted_ids]
+ kpts = kpts[sorted_ids]
+ p3d_xyzs = p3d_xyzs[sorted_ids]
+
+ seg_ids = seg_ids[sorted_ids]
+
+ gt_n_seg = np.zeros(shape=(self.n_class,), dtype=int)
+ gt_cls = np.zeros(shape=(self.n_class,), dtype=int)
+ gt_cls_dist = np.zeros(shape=(self.n_class,), dtype=float)
+ uids = np.unique(seg_ids).tolist()
+ for uid in uids:
+ if uid == 0:
+ continue
+ gt_cls[uid] = 1
+ gt_n_seg[uid] = np.sum(seg_ids == uid)
+ gt_cls_dist[uid] = np.sum(seg_ids == uid) / np.sum(
+ seg_ids < self.n_class - 1) # [valid_id / total_valid_id]
+
+ gt_cls[0] = 0
+
+ img = self.read_image(image_name=img_name)
+ if self.image_dim == 1:
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
+ img = img[..., None].astype(float) / 255.
+ else:
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB).astype(float) / 255.
+ return {
+ 'descriptors': descs,
+ 'scores': scores,
+ 'keypoints': kpts,
+ 'image_size': image_size,
+ 'norm_keypoints': normalize_size(x=kpts, size=image_size),
+ 'gt_seg': seg_ids,
+ 'gt_cls': gt_cls,
+ 'gt_cls_dist': gt_cls_dist,
+ 'gt_n_seg': gt_n_seg,
+ 'file_name': img_name,
+ 'prefix_name': self.image_prefix,
+ 'image': [img],
+
+ 'mean_xyz': self.mean_xyz,
+ 'scale_xyz': self.scale_xyz,
+ 'gt_sc': p3d_xyzs,
+ 'gt_norm_sc': (p3d_xyzs - self.mean_xyz) / self.scale_xyz
+ }
+
+ def __getitem__(self, idx):
+ if self.train:
+ return self.get_item_train(idx=idx)
+ else:
+ return self.get_item_test(idx=idx)
+
+ def __len__(self):
+ return len(self.img_fns)
+
+ def read_image(self, image_name):
+ return cv2.imread(osp.join(self.dataset_path, image_name))
+
+ def jitter_augmentation(self, img, params):
+ brightness, contrast, saturation, hue = params
+ p = np.random.randint(0, 20) / 20
+ b = brightness[0] + (brightness[1] - brightness[0]) / 20 * p
+ img = tvf.adjust_brightness(img=img, brightness_factor=b)
+
+ p = np.random.randint(0, 20) / 20
+ c = contrast[0] + (contrast[1] - contrast[0]) / 20 * p
+ img = tvf.adjust_contrast(img=img, contrast_factor=c)
+
+ p = np.random.randint(0, 20) / 20
+ s = saturation[0] + (saturation[1] - saturation[0]) / 20 * p
+ img = tvf.adjust_saturation(img=img, saturation_factor=s)
+
+ p = np.random.randint(0, 20) / 20
+ h = hue[0] + (hue[1] - hue[0]) / 20 * p
+ img = tvf.adjust_hue(img=img, hue_factor=h)
+
+ return img
+
+ def random_points(self, n, d, h, w):
+ desc = np.random.random((n, d))
+ desc = desc / np.linalg.norm(desc, ord=2, axis=1)[..., None]
+ xs = np.random.randint(0, w - 1, size=(n, 1))
+ ys = np.random.randint(0, h - 1, size=(n, 1))
+ kpts = np.hstack([xs, ys])
+ return desc, kpts
+
+ def random_points_from_reference(self, n, ref_kpts, ref_segs, ref_xyzs, radius=5):
+ n_ref = ref_kpts.shape[0]
+ if n_ref < n:
+ ref_ids = np.random.choice([i for i in range(n_ref)], size=n).tolist()
+ else:
+ ref_ids = [i for i in range(n)]
+
+ new_xs = []
+ new_ys = []
+ # new_descs = []
+ new_segs = []
+ new_xyzs = []
+ for i in ref_ids:
+ nx = np.random.randint(-radius, radius) + ref_kpts[i, 0]
+ ny = np.random.randint(-radius, radius) + ref_kpts[i, 1]
+
+ new_xs.append(nx)
+ new_ys.append(ny)
+ # new_descs.append(ref_descs[i])
+ new_segs.append(ref_segs[i])
+ new_xyzs.append(ref_xyzs[i])
+
+ new_xs = np.array(new_xs).reshape(n, 1)
+ new_ys = np.array(new_ys).reshape(n, 1)
+ new_segs = np.array(new_segs).reshape(n, )
+ new_kpts = np.hstack([new_xs, new_ys])
+ # new_descs = np.array(new_descs).reshape(n, -1)
+ new_xyzs = np.array(new_xyzs)
+ return new_kpts, new_segs, new_xyzs
diff --git a/third_party/pram/dataset/cambridge_landmarks.py b/third_party/pram/dataset/cambridge_landmarks.py
new file mode 100644
index 0000000000000000000000000000000000000000..03f30f367f4ded9ce1d7c2efbaa407ed26725a69
--- /dev/null
+++ b/third_party/pram/dataset/cambridge_landmarks.py
@@ -0,0 +1,101 @@
+# -*- coding: UTF-8 -*-
+'''=================================================
+@Project -> File pram -> cambridge_landmarks
+@IDE PyCharm
+@Author fx221@cam.ac.uk
+@Date 29/01/2024 14:41
+=================================================='''
+import os.path as osp
+import numpy as np
+from colmap_utils.read_write_model import read_model
+import torchvision.transforms as tvt
+from dataset.basicdataset import BasicDataset
+
+
+class CambridgeLandmarks(BasicDataset):
+ def __init__(self, landmark_path, scene, dataset_path, n_class, seg_mode, seg_method, dataset='CambridgeLandmarks',
+ nfeatures=1024,
+ query_p3d_fn=None,
+ train=True,
+ with_aug=False,
+ min_inliers=0,
+ max_inliers=4096,
+ random_inliers=False,
+ jitter_params=None,
+ scale_params=None,
+ image_dim=3,
+ query_info_path=None,
+ sample_ratio=1,
+ ):
+ self.landmark_path = osp.join(landmark_path, scene)
+ self.dataset_path = osp.join(dataset_path, scene)
+ self.n_class = n_class
+ self.dataset = dataset + '/' + scene
+ self.nfeatures = nfeatures
+ self.with_aug = with_aug
+ self.jitter_params = jitter_params
+ self.scale_params = scale_params
+ self.image_dim = image_dim
+ self.train = train
+ self.min_inliers = min_inliers
+ self.max_inliers = max_inliers if max_inliers < nfeatures else nfeatures
+ self.random_inliers = random_inliers
+ self.image_prefix = ''
+ train_transforms = []
+ if self.with_aug:
+ train_transforms.append(tvt.ColorJitter(
+ brightness=jitter_params['brightness'],
+ contrast=jitter_params['contrast'],
+ saturation=jitter_params['saturation'],
+ hue=jitter_params['hue']))
+ if jitter_params['blur'] > 0:
+ train_transforms.append(tvt.GaussianBlur(kernel_size=int(jitter_params['blur'])))
+ self.train_transforms = tvt.Compose(train_transforms)
+
+ if train:
+ self.cameras, self.images, point3Ds = read_model(path=osp.join(self.landmark_path, '3D-models'), ext='.bin')
+ self.name_to_id = {image.name: i for i, image in self.images.items() if len(self.images[i].point3D_ids) > 0}
+
+ # only for testing of query images
+ if not self.train:
+ data = np.load(query_p3d_fn, allow_pickle=True)[()]
+ self.img_p3d = data
+ else:
+ self.img_p3d = {}
+
+ self.img_fns = []
+ with open(osp.join(self.dataset_path, 'dataset_train.txt' if train else 'dataset_test.txt'), 'r') as f:
+ lines = f.readlines()[3:] # ignore the first 3 lines
+ for l in lines:
+ l = l.strip().split()[0]
+ if train and l not in self.name_to_id.keys():
+ continue
+ if not train and l not in self.img_p3d.keys():
+ continue
+ self.img_fns.append(l)
+
+ print('Load {} images from {} for {}...'.format(len(self.img_fns),
+ self.dataset, 'training' if train else 'eval'))
+
+ data = np.load(osp.join(self.landmark_path,
+ 'point3D_cluster_n{:d}_{:s}_{:s}.npy'.format(n_class - 1, seg_mode, seg_method)),
+ allow_pickle=True)[()]
+ p3d_id = data['id']
+ seg_id = data['label']
+ self.p3d_seg = {p3d_id[i]: seg_id[i] for i in range(p3d_id.shape[0])}
+ xyzs = data['xyz']
+ self.p3d_xyzs = {p3d_id[i]: xyzs[i] for i in range(p3d_id.shape[0])}
+
+ # with open(osp.join(self.landmark_path, 'sc_mean_scale.txt'), 'r') as f:
+ # lines = f.readlines()
+ # for l in lines:
+ # l = l.strip().split()
+ # self.mean_xyz = np.array([float(v) for v in l[:3]])
+ # self.scale_xyz = np.array([float(v) for v in l[3:]])
+
+ if not train:
+ self.query_info = self.read_query_info(path=query_info_path)
+
+ self.nfeatures = nfeatures
+ self.feature_dir = osp.join(self.landmark_path, 'feats')
+ self.feats = {}
diff --git a/third_party/pram/dataset/customdataset.py b/third_party/pram/dataset/customdataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..41ec99ec1540868f3dfbafe00b5585398062e3f8
--- /dev/null
+++ b/third_party/pram/dataset/customdataset.py
@@ -0,0 +1,93 @@
+# -*- coding: UTF-8 -*-
+'''=================================================
+@Project -> File pram -> customdataset.py
+@IDE PyCharm
+@Author fx221@cam.ac.uk
+@Date 29/01/2024 14:38
+=================================================='''
+import os.path as osp
+import numpy as np
+from colmap_utils.read_write_model import read_model
+import torchvision.transforms as tvt
+from dataset.basicdataset import BasicDataset
+
+
+class CustomDataset(BasicDataset):
+ def __init__(self, landmark_path, scene, dataset_path, n_class, seg_mode, seg_method, dataset,
+ nfeatures=1024,
+ query_p3d_fn=None,
+ train=True,
+ with_aug=False,
+ min_inliers=0,
+ max_inliers=4096,
+ random_inliers=False,
+ jitter_params=None,
+ scale_params=None,
+ image_dim=3,
+ query_info_path=None,
+ sample_ratio=1,
+ ):
+ self.landmark_path = osp.join(landmark_path, scene)
+ self.dataset_path = osp.join(dataset_path, scene)
+ self.n_class = n_class
+ self.dataset = dataset + '/' + scene
+ self.nfeatures = nfeatures
+ self.with_aug = with_aug
+ self.jitter_params = jitter_params
+ self.scale_params = scale_params
+ self.image_dim = image_dim
+ self.train = train
+ self.min_inliers = min_inliers
+ self.max_inliers = max_inliers if max_inliers < nfeatures else nfeatures
+ self.random_inliers = random_inliers
+ self.image_prefix = ''
+
+ train_transforms = []
+ if self.with_aug:
+ train_transforms.append(tvt.ColorJitter(
+ brightness=jitter_params['brightness'],
+ contrast=jitter_params['contrast'],
+ saturation=jitter_params['saturation'],
+ hue=jitter_params['hue']))
+ if jitter_params['blur'] > 0:
+ train_transforms.append(tvt.GaussianBlur(kernel_size=int(jitter_params['blur'])))
+ self.train_transforms = tvt.Compose(train_transforms)
+
+ if train:
+ self.cameras, self.images, point3Ds = read_model(path=osp.join(self.landmark_path, '3D-models'), ext='.bin')
+ self.name_to_id = {image.name: i for i, image in self.images.items() if len(self.images[i].point3D_ids) > 0}
+
+ # only for testing of query images
+ if not self.train:
+ data = np.load(query_p3d_fn, allow_pickle=True)[()]
+ self.img_p3d = data
+ else:
+ self.img_p3d = {}
+
+ if train:
+ self.img_fns = [self.images[v].name for v in self.images.keys() if
+ self.images[v].name in self.name_to_id.keys()]
+ else:
+ self.img_fns = []
+ with open(osp.join(self.dataset_path, 'queries_with_intrinsics.txt'), 'r') as f:
+ lines = f.readlines()
+ for l in lines:
+ self.img_fns.append(l.strip().split()[0])
+ print('Load {} images from {} for {}...'.format(len(self.img_fns),
+ self.dataset, 'training' if train else 'eval'))
+
+ data = np.load(osp.join(self.landmark_path,
+ 'point3D_cluster_n{:d}_{:s}_{:s}.npy'.format(n_class - 1, seg_mode, seg_method)),
+ allow_pickle=True)[()]
+ p3d_id = data['id']
+ seg_id = data['label']
+ self.p3d_seg = {p3d_id[i]: seg_id[i] for i in range(p3d_id.shape[0])}
+ xyzs = data['xyz']
+ self.p3d_xyzs = {p3d_id[i]: xyzs[i] for i in range(p3d_id.shape[0])}
+
+ if not train:
+ self.query_info = self.read_query_info(path=query_info_path)
+
+ self.nfeatures = nfeatures
+ self.feature_dir = osp.join(self.landmark_path, 'feats')
+ self.feats = {}
diff --git a/third_party/pram/dataset/get_dataset.py b/third_party/pram/dataset/get_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..4fe28eaa6238b480aae4c64cd08ffe6cd2379c90
--- /dev/null
+++ b/third_party/pram/dataset/get_dataset.py
@@ -0,0 +1,89 @@
+# -*- coding: UTF-8 -*-
+'''=================================================
+@Project -> File pram -> get_dataset
+@IDE PyCharm
+@Author fx221@cam.ac.uk
+@Date 29/01/2024 14:40
+=================================================='''
+import os.path as osp
+import yaml
+from dataset.aachen import Aachen
+from dataset.twelve_scenes import TwelveScenes
+from dataset.seven_scenes import SevenScenes
+from dataset.cambridge_landmarks import CambridgeLandmarks
+from dataset.customdataset import CustomDataset
+from dataset.recdataset import RecDataset
+
+
+def get_dataset(dataset):
+ if dataset in ['7Scenes', 'S']:
+ return SevenScenes
+ elif dataset in ['12Scenes', 'T']:
+ return TwelveScenes
+ elif dataset in ['Aachen', 'A']:
+ return Aachen
+ elif dataset in ['CambridgeLandmarks', 'C']:
+ return CambridgeLandmarks
+ else:
+ return CustomDataset
+
+
+def compose_datasets(datasets, config, train=True, sample_ratio=None):
+ sub_sets = []
+ for name in datasets:
+ if name == 'S':
+ ds_name = '7Scenes'
+ elif name == 'T':
+ ds_name = '12Scenes'
+ elif name == 'A':
+ ds_name = 'Aachen'
+ elif name == 'R':
+ ds_name = 'RobotCar-Seasons'
+ elif name == 'C':
+ ds_name = 'CambridgeLandmarks'
+ else:
+ ds_name = name
+ # raise '{} dataset does not exist'.format(name)
+ landmark_path = osp.join(config['landmark_path'], ds_name)
+ dataset_path = osp.join(config['dataset_path'], ds_name)
+ scene_config_path = 'configs/datasets/{:s}.yaml'.format(ds_name)
+
+ with open(scene_config_path, 'r') as f:
+ scene_config = yaml.load(f, Loader=yaml.Loader)
+ DSet = get_dataset(dataset=ds_name)
+
+ for scene in scene_config['scenes']:
+ if sample_ratio is None:
+ scene_sample_ratio = scene_config[scene]['training_sample_ratio'] if train else scene_config[scene][
+ 'eval_sample_ratio']
+ else:
+ scene_sample_ratio = sample_ratio
+ scene_set = DSet(landmark_path=landmark_path,
+ dataset_path=dataset_path,
+ scene=scene,
+ seg_mode=scene_config[scene]['cluster_mode'],
+ seg_method=scene_config[scene]['cluster_method'],
+ n_class=scene_config[scene]['n_cluster'] + 1, # including invalid - 0
+ dataset=ds_name,
+ train=train,
+ nfeatures=config['max_keypoints'] if train else config['eval_max_keypoints'],
+ min_inliers=config['min_inliers'],
+ max_inliers=config['max_inliers'],
+ random_inliers=config['random_inliers'],
+ with_aug=config['with_aug'],
+ jitter_params=config['jitter_params'],
+ scale_params=config['scale_params'],
+ image_dim=config['image_dim'],
+ query_p3d_fn=osp.join(config['landmark_path'], ds_name, scene,
+ 'point3D_query_n{:d}_{:s}_{:s}.npy'.format(
+ scene_config[scene]['n_cluster'],
+ scene_config[scene]['cluster_mode'],
+ scene_config[scene]['cluster_method'])),
+ query_info_path=osp.join(config['dataset_path'], ds_name, scene,
+ 'queries_with_intrinsics.txt'),
+ sample_ratio=scene_sample_ratio,
+ )
+
+ sub_sets.append(scene_set)
+
+ return RecDataset(sub_sets=sub_sets)
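+
+
+# Illustrative usage (a sketch; config keys as in configs/config_train_*_sfd2.yaml,
+# paths are placeholders):
+#   import yaml
+#   with open('configs/config_train_7scenes_sfd2.yaml') as f:
+#       config = yaml.load(f, Loader=yaml.Loader)
+#   train_set = compose_datasets(config['dataset'], config, train=True)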
diff --git a/third_party/pram/dataset/recdataset.py b/third_party/pram/dataset/recdataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..9eebd473018ad269eaa6cd8f1ffaab3f5f316ec6
--- /dev/null
+++ b/third_party/pram/dataset/recdataset.py
@@ -0,0 +1,95 @@
+# -*- coding: UTF-8 -*-
+'''=================================================
+@Project -> File pram -> recdataset
+@IDE PyCharm
+@Author fx221@cam.ac.uk
+@Date 29/01/2024 14:42
+=================================================='''
+import numpy as np
+from torch.utils.data import Dataset
+
+
+class RecDataset(Dataset):
+ def __init__(self, sub_sets=[]):
+ assert len(sub_sets) >= 1
+
+ self.sub_sets = sub_sets
+ self.names = []
+
+ self.sub_set_index = []
+ self.seg_offsets = []
+ self.sub_set_item_index = []
+ self.dataset_names = []
+ self.scene_names = []
+ start_index_valid_seg = 1 # start from 1, 0 is for invalid
+
+ total_subset = 0
+ for scene_set in sub_sets: # [0, n_class]
+ name = scene_set.dataset
+ self.names.append(name)
+ n_samples = len(scene_set)
+
+ n_class = scene_set.n_class
+ self.seg_offsets = self.seg_offsets + [start_index_valid_seg for v in range(len(scene_set))]
+ start_index_valid_seg = start_index_valid_seg + n_class - 1
+
+ self.sub_set_index = self.sub_set_index + [total_subset for k in range(n_samples)]
+ self.sub_set_item_index = self.sub_set_item_index + [k for k in range(n_samples)]
+
+ # self.dataset_names = self.dataset_names + [name for k in range(n_samples)]
+ self.scene_names = self.scene_names + [name for k in range(n_samples)]
+ total_subset += 1
+
+ self.n_class = start_index_valid_seg
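+        # total classes over all sub-sets: one shared class 0 for invalid/unsegmented
+        # points plus (n_class - 1) valid segments contributed by each sub-set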
+
+        print('Load {} images and {} segment classes from {} subsets: {}'.format(
+            len(self.sub_set_item_index), self.n_class, len(sub_sets), self.names))
+
+ def __len__(self):
+ return len(self.sub_set_item_index)
+
+ def __getitem__(self, idx):
+ subset_idx = self.sub_set_index[idx]
+ item_idx = self.sub_set_item_index[idx]
+ scene_name = self.scene_names[idx]
+
+ out = self.sub_sets[subset_idx][item_idx]
+
+ org_gt_seg = out['gt_seg']
+ org_gt_cls = out['gt_cls']
+ org_gt_cls_dist = out['gt_cls_dist']
+ org_gt_n_seg = out['gt_n_seg']
+ offset = self.seg_offsets[idx]
+ org_n_class = self.sub_sets[subset_idx].n_class
+
+ gt_seg = np.zeros(shape=(org_gt_seg.shape[0],), dtype=int) # [0, ..., n_features]
+ gt_n_seg = np.zeros(shape=(self.n_class,), dtype=int)
+ gt_cls = np.zeros(shape=(self.n_class,), dtype=int)
+ gt_cls_dist = np.zeros(shape=(self.n_class,), dtype=float)
+
+ # copy invalid segments
+ gt_n_seg[0] = org_gt_n_seg[0]
+ gt_cls[0] = org_gt_cls[0]
+ gt_cls_dist[0] = org_gt_cls_dist[0]
+ # print('org: ', org_n_class, org_gt_seg.shape, org_gt_n_seg.shape, org_gt_seg)
+
+ # copy valid segments
+ gt_seg[org_gt_seg > 0] = org_gt_seg[org_gt_seg > 0] + offset - 1 # [0, ..., 1023]
+ gt_n_seg[offset:offset + org_n_class - 1] = org_gt_n_seg[1:] # [0...,n_seg]
+ gt_cls[offset:offset + org_n_class - 1] = org_gt_cls[1:] # [0, ..., n_seg]
+ gt_cls_dist[offset:offset + org_n_class - 1] = org_gt_cls_dist[1:] # [0, ..., n_seg]
+
+ out['gt_seg'] = gt_seg
+ out['gt_cls'] = gt_cls
+ out['gt_cls_dist'] = gt_cls_dist
+ out['gt_n_seg'] = gt_n_seg
+
+ # print('gt: ', org_n_class, gt_seg.shape, gt_n_seg.shape, gt_seg)
+ out['scene_name'] = scene_name
+
+ # out['org_gt_seg'] = org_gt_seg
+ # out['org_gt_n_seg'] = org_gt_n_seg
+ # out['org_gt_cls'] = org_gt_cls
+ # out['org_gt_cls_dist'] = org_gt_cls_dist
+
+ return out
diff --git a/third_party/pram/dataset/seven_scenes.py b/third_party/pram/dataset/seven_scenes.py
new file mode 100644
index 0000000000000000000000000000000000000000..fbc29b29d3b935e45129a35b502117067816433a
--- /dev/null
+++ b/third_party/pram/dataset/seven_scenes.py
@@ -0,0 +1,115 @@
+# -*- coding: UTF-8 -*-
+'''=================================================
+@Project -> File pram -> seven_scenes
+@IDE PyCharm
+@Author fx221@cam.ac.uk
+@Date 29/01/2024 14:36
+=================================================='''
+import os
+import os.path as osp
+import numpy as np
+from colmap_utils.read_write_model import read_model
+import torchvision.transforms as tvt
+from dataset.basicdataset import BasicDataset
+
+
+class SevenScenes(BasicDataset):
+ def __init__(self, landmark_path, scene, dataset_path, n_class, seg_mode, seg_method, dataset='7Scenes',
+ nfeatures=1024,
+ query_p3d_fn=None,
+ train=True,
+ with_aug=False,
+ min_inliers=0,
+ max_inliers=4096,
+ random_inliers=False,
+ jitter_params=None,
+ scale_params=None,
+ image_dim=3,
+ query_info_path=None,
+ sample_ratio=1,
+ ):
+ self.landmark_path = osp.join(landmark_path, scene)
+ self.dataset_path = osp.join(dataset_path, scene)
+ self.n_class = n_class
+ self.dataset = dataset + '/' + scene
+ self.nfeatures = nfeatures
+ self.with_aug = with_aug
+ self.jitter_params = jitter_params
+ self.scale_params = scale_params
+ self.image_dim = image_dim
+ self.train = train
+ self.min_inliers = min_inliers
+ self.max_inliers = max_inliers if max_inliers < nfeatures else nfeatures
+ self.random_inliers = random_inliers
+ self.image_prefix = ''
+
+ train_transforms = []
+ if self.with_aug:
+ train_transforms.append(tvt.ColorJitter(
+ brightness=jitter_params['brightness'],
+ contrast=jitter_params['contrast'],
+ saturation=jitter_params['saturation'],
+ hue=jitter_params['hue']))
+ if jitter_params['blur'] > 0:
+ train_transforms.append(tvt.GaussianBlur(kernel_size=int(jitter_params['blur'])))
+ self.train_transforms = tvt.Compose(train_transforms)
+
+ if train:
+ self.cameras, self.images, point3Ds = read_model(path=osp.join(self.landmark_path, '3D-models'), ext='.bin')
+ self.name_to_id = {image.name: i for i, image in self.images.items() if len(self.images[i].point3D_ids) > 0}
+
+ # only for testing of query images
+ if not self.train:
+ data = np.load(query_p3d_fn, allow_pickle=True)[()]
+ self.img_p3d = data
+ else:
+ self.img_p3d = {}
+
+ if self.train:
+ split_fn = osp.join(self.dataset_path, 'TrainSplit.txt')
+ else:
+ split_fn = osp.join(self.dataset_path, 'TestSplit.txt')
+
+ self.img_fns = []
+ with open(split_fn, 'r') as f:
+ lines = f.readlines()
+ for l in lines:
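+                # each line is expected to look like 'sequenceN' (e.g. 'sequence1');
+                # the 8-character 'sequence' prefix is stripped to recover the index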
+ seq = int(l.strip()[8:])
+                fns = os.listdir(osp.join(self.dataset_path, 'seq-{:02d}'.format(seq)))
+ fns = sorted(fns)
+ nf = 0
+ for fn in fns:
+ if fn.find('png') >= 0:
+ if train and 'seq-{:02d}'.format(seq) + '/' + fn not in self.name_to_id.keys():
+ continue
+ if not train and 'seq-{:02d}'.format(seq) + '/' + fn not in self.img_p3d.keys():
+ continue
+ if nf % sample_ratio == 0:
+ self.img_fns.append('seq-{:02d}'.format(seq) + '/' + fn)
+ nf += 1
+
+ print('Load {} images from {} for {}...'.format(len(self.img_fns),
+ self.dataset, 'training' if train else 'eval'))
+
+ data = np.load(osp.join(self.landmark_path,
+ 'point3D_cluster_n{:d}_{:s}_{:s}.npy'.format(n_class - 1, seg_mode, seg_method)),
+ allow_pickle=True)[()]
+ p3d_id = data['id']
+ seg_id = data['label']
+ self.p3d_seg = {p3d_id[i]: seg_id[i] for i in range(p3d_id.shape[0])}
+ xyzs = data['xyz']
+ self.p3d_xyzs = {p3d_id[i]: xyzs[i] for i in range(p3d_id.shape[0])}
+
+ # with open(osp.join(self.landmark_path, 'sc_mean_scale.txt'), 'r') as f:
+ # lines = f.readlines()
+ # for l in lines:
+ # l = l.strip().split()
+ # self.mean_xyz = np.array([float(v) for v in l[:3]])
+ # self.scale_xyz = np.array([float(v) for v in l[3:]])
+
+ if not train:
+ self.query_info = self.read_query_info(path=query_info_path)
+
+ self.nfeatures = nfeatures
+ self.feature_dir = osp.join(self.landmark_path, 'feats')
+ self.feats = {}
diff --git a/third_party/pram/dataset/twelve_scenes.py b/third_party/pram/dataset/twelve_scenes.py
new file mode 100644
index 0000000000000000000000000000000000000000..34fcc7f46b6d4315d9ebca69043a262310adc453
--- /dev/null
+++ b/third_party/pram/dataset/twelve_scenes.py
@@ -0,0 +1,121 @@
+# -*- coding: UTF-8 -*-
+'''=================================================
+@Project -> File pram -> twelve_scenes
+@IDE PyCharm
+@Author fx221@cam.ac.uk
+@Date 29/01/2024 14:37
+=================================================='''
+import os
+import os.path as osp
+import numpy as np
+from colmap_utils.read_write_model import read_model
+import torchvision.transforms as tvt
+from dataset.basicdataset import BasicDataset
+
+
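+# 12Scenes counterpart of SevenScenes: the train/test boundary is parsed from the
+# single line in `split.txt` (the first sequence is held out for testing) and frames
+# are read from the `data/` folder (`frame-XXXXXX.color.jpg`).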
+class TwelveScenes(BasicDataset):
+ def __init__(self, landmark_path, scene, dataset_path, n_class, seg_mode, seg_method, dataset='12Scenes',
+ nfeatures=1024,
+ query_p3d_fn=None,
+ train=True,
+ with_aug=False,
+ min_inliers=0,
+ max_inliers=4096,
+ random_inliers=False,
+ jitter_params=None,
+ scale_params=None,
+ image_dim=3,
+ query_info_path=None,
+ sample_ratio=1,
+ ):
+ self.landmark_path = osp.join(landmark_path, scene)
+ self.dataset_path = osp.join(dataset_path, scene)
+ self.n_class = n_class
+ self.dataset = dataset + '/' + scene
+ self.nfeatures = nfeatures
+ self.with_aug = with_aug
+ self.jitter_params = jitter_params
+ self.scale_params = scale_params
+ self.image_dim = image_dim
+ self.train = train
+ self.min_inliers = min_inliers
+ self.max_inliers = max_inliers if max_inliers < nfeatures else nfeatures
+ self.random_inliers = random_inliers
+ self.image_prefix = ''
+
+ train_transforms = []
+ if self.with_aug:
+ train_transforms.append(tvt.ColorJitter(
+ brightness=jitter_params['brightness'],
+ contrast=jitter_params['contrast'],
+ saturation=jitter_params['saturation'],
+ hue=jitter_params['hue']))
+ if jitter_params['blur'] > 0:
+ train_transforms.append(tvt.GaussianBlur(kernel_size=int(jitter_params['blur'])))
+ self.train_transforms = tvt.Compose(train_transforms)
+
+ if train:
+ self.cameras, self.images, point3Ds = read_model(path=osp.join(self.landmark_path, '3D-models'), ext='.bin')
+ self.name_to_id = {image.name: i for i, image in self.images.items() if len(self.images[i].point3D_ids) > 0}
+
+        # only needed at test time (for query images)
+ if not self.train:
+ data = np.load(query_p3d_fn, allow_pickle=True)[()]
+ self.img_p3d = data
+ else:
+ self.img_p3d = {}
+
+ with open(osp.join(self.dataset_path, 'split.txt'), 'r') as f:
+ l = f.readline()
+ l = l.strip().split(' ') # sequence0 [frames=357] [start=0 ; end=356], first sequence for testing
+ start_img_id = l[-3].split('=')[-1]
+ end_img_id = l[-1].split('=')[-1][:-1]
+ test_start_img_id = int(start_img_id)
+ test_end_img_id = int(end_img_id)
+
+ self.img_fns = []
+ fns = os.listdir(osp.join(self.dataset_path, 'data'))
+ fns = sorted(fns)
+ nf = 0
+ for fn in fns:
+ if fn.find('jpg') >= 0: # frame-001098.color.jpg
+ frame_id = int(fn.split('.')[0].split('-')[-1])
+ if not train and frame_id > test_end_img_id:
+ continue
+ if train and frame_id <= test_end_img_id:
+ continue
+
+ if train and 'data' + '/' + fn not in self.name_to_id.keys():
+ continue
+
+ if not train and 'data' + '/' + fn not in self.img_p3d.keys():
+ continue
+ if nf % sample_ratio == 0:
+ self.img_fns.append('data' + '/' + fn)
+ nf += 1
+
+ print('Load {} images from {} for {}...'.format(len(self.img_fns),
+ self.dataset, 'training' if train else 'eval'))
+
+ data = np.load(osp.join(self.landmark_path,
+ 'point3D_cluster_n{:d}_{:s}_{:s}.npy'.format(n_class - 1, seg_mode, seg_method)),
+ allow_pickle=True)[()]
+ p3d_id = data['id']
+ seg_id = data['label']
+ self.p3d_seg = {p3d_id[i]: seg_id[i] for i in range(p3d_id.shape[0])}
+ xyzs = data['xyz']
+ self.p3d_xyzs = {p3d_id[i]: xyzs[i] for i in range(p3d_id.shape[0])}
+
+ # with open(osp.join(self.landmark_path, 'sc_mean_scale.txt'), 'r') as f:
+ # lines = f.readlines()
+ # for l in lines:
+ # l = l.strip().split()
+ # self.mean_xyz = np.array([float(v) for v in l[:3]])
+ # self.scale_xyz = np.array([float(v) for v in l[3:]])
+
+ if not train:
+ self.query_info = self.read_query_info(path=query_info_path)
+
+ self.nfeatures = nfeatures
+ self.feature_dir = osp.join(self.landmark_path, 'feats')
+ self.feats = {}
diff --git a/third_party/pram/dataset/utils.py b/third_party/pram/dataset/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..eb8132662c540ae28de32494a5abff6e679064f5
--- /dev/null
+++ b/third_party/pram/dataset/utils.py
@@ -0,0 +1,31 @@
+# -*- coding: UTF-8 -*-
+'''=================================================
+@Project -> File pram -> utils
+@IDE PyCharm
+@Author fx221@cam.ac.uk
+@Date 29/01/2024 14:31
+=================================================='''
+import torch
+
+
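+# Map keypoint coordinates to a range roughly centered on the image center and
+# scaled by the longer image side (presumably used to normalize keypoints before
+# they are fed to the segmentation network).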
+def normalize_size(x, size, scale=0.7):
+ size = size.reshape([1, 2])
+ norm_fac = size.max() + 0.5
+ return (x - size / 2) / (norm_fac * scale)
+
+
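+# Custom collate function: string/list fields are gathered into Python lists, all
+# other (numpy) fields are stacked into a single batched tensor.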
+def collect_batch(batch):
+ out = {}
+ # if len(batch) == 0:
+ # return batch
+ # else:
+ for k in batch[0].keys():
+ tmp = []
+ for v in batch:
+ tmp.append(v[k])
+ if isinstance(batch[0][k], str) or isinstance(batch[0][k], list):
+ out[k] = tmp
+ else:
+ out[k] = torch.cat([torch.from_numpy(i)[None] for i in tmp], dim=0)
+
+ return out
diff --git a/third_party/pram/environment.yml b/third_party/pram/environment.yml
new file mode 100644
index 0000000000000000000000000000000000000000..bf1c2111660046500e25c9ff28e66d470c7f68a9
--- /dev/null
+++ b/third_party/pram/environment.yml
@@ -0,0 +1,173 @@
+name: pram
+channels:
+ - conda-forge
+ - defaults
+dependencies:
+ - _libgcc_mutex=0.1=conda_forge
+ - _openmp_mutex=4.5=2_gnu
+ - binutils_impl_linux-64=2.38=h2a08ee3_1
+ - bzip2=1.0.8=h5eee18b_5
+ - ca-certificates=2024.3.11=h06a4308_0
+ - gcc=12.1.0=h9ea6d83_10
+ - gcc_impl_linux-64=12.1.0=hea43390_17
+ - kernel-headers_linux-64=2.6.32=he073ed8_17
+ - ld_impl_linux-64=2.38=h1181459_1
+ - libffi=3.4.4=h6a678d5_0
+ - libgcc-devel_linux-64=12.1.0=h1ec3361_17
+ - libgcc-ng=13.2.0=h807b86a_5
+ - libgomp=13.2.0=h807b86a_5
+ - libsanitizer=12.1.0=ha89aaad_17
+ - libstdcxx-ng=13.2.0=h7e041cc_5
+ - libuuid=1.41.5=h5eee18b_0
+ - ncurses=6.4=h6a678d5_0
+ - openssl=3.2.1=hd590300_1
+ - pip=23.3.1=py310h06a4308_0
+ - python=3.10.14=h955ad1f_0
+ - readline=8.2=h5eee18b_0
+ - setuptools=68.2.2=py310h06a4308_0
+ - sqlite=3.41.2=h5eee18b_0
+ - sysroot_linux-64=2.12=he073ed8_17
+ - tk=8.6.12=h1ccaba5_0
+ - wheel=0.41.2=py310h06a4308_0
+ - xz=5.4.6=h5eee18b_0
+ - zlib=1.2.13=h5eee18b_0
+ - pip:
+ - addict==2.4.0
+ - aiofiles==23.2.1
+ - aiohttp==3.9.3
+ - aioopenssl==0.6.0
+ - aiosasl==0.5.0
+ - aiosignal==1.3.1
+ - aioxmpp==0.13.3
+ - asttokens==2.4.1
+ - async-timeout==4.0.3
+ - attrs==23.2.0
+ - babel==2.14.0
+ - benbotasync==3.0.2
+ - blinker==1.7.0
+ - certifi==2024.2.2
+ - cffi==1.16.0
+ - charset-normalizer==3.3.2
+ - click==8.1.7
+ - colorama==0.4.6
+ - comm==0.2.2
+ - configargparse==1.7
+ - contourpy==1.2.1
+ - crayons==0.4.0
+ - cryptography==42.0.5
+ - cycler==0.12.1
+ - dash==2.16.1
+ - dash-core-components==2.0.0
+ - dash-html-components==2.0.0
+ - dash-table==5.0.0
+ - decorator==5.1.1
+ - dnspython==2.6.1
+ - einops==0.7.0
+ - exceptiongroup==1.2.0
+ - executing==2.0.1
+ - fastjsonschema==2.19.1
+ - filelock==3.13.3
+ - flask==3.0.2
+ - fonttools==4.50.0
+ - fortniteapiasync==0.1.7
+ - fortnitepy==3.6.9
+ - frozenlist==1.4.1
+ - fsspec==2024.3.1
+ - h5py==3.10.0
+ - html5tagger==1.3.0
+ - httptools==0.6.1
+ - idna==3.6
+ - importlib-metadata==7.1.0
+ - ipython==8.23.0
+ - ipywidgets==8.1.2
+ - itsdangerous==2.1.2
+ - jedi==0.19.1
+ - jinja2==3.1.3
+ - joblib==1.3.2
+ - jsonschema==4.21.1
+ - jsonschema-specifications==2023.12.1
+ - jupyter-core==5.7.2
+ - jupyterlab-widgets==3.0.10
+ - kiwisolver==1.4.5
+ - lxml==4.9.4
+ - markupsafe==2.1.5
+ - matplotlib==3.8.4
+ - matplotlib-inline==0.1.6
+ - mpmath==1.3.0
+ - multidict==6.0.5
+ - nbformat==5.10.4
+ - nest-asyncio==1.6.0
+ - networkx==3.2.1
+ - numpy==1.26.4
+ - nvidia-cublas-cu12==12.1.3.1
+ - nvidia-cuda-cupti-cu12==12.1.105
+ - nvidia-cuda-nvrtc-cu12==12.1.105
+ - nvidia-cuda-runtime-cu12==12.1.105
+ - nvidia-cudnn-cu12==8.9.2.26
+ - nvidia-cufft-cu12==11.0.2.54
+ - nvidia-curand-cu12==10.3.2.106
+ - nvidia-cusolver-cu12==11.4.5.107
+ - nvidia-cusparse-cu12==12.1.0.106
+ - nvidia-nccl-cu12==2.19.3
+ - nvidia-nvjitlink-cu12==12.4.127
+ - nvidia-nvtx-cu12==12.1.105
+ - open3d==0.18.0
+ - opencv-contrib-python==4.5.5.64
+ - packaging==24.0
+ - pandas==2.2.1
+ - parso==0.8.3
+ - pexpect==4.9.0
+ - pillow==10.3.0
+ - platformdirs==4.2.0
+ - plotly==5.20.0
+ - prompt-toolkit==3.0.43
+ - ptyprocess==0.7.0
+ - pure-eval==0.2.2
+ - pyasn1==0.6.0
+ - pyasn1-modules==0.4.0
+ - pybind11==2.12.0
+ - pycolmap==0.6.1
+ - pycparser==2.22
+ - pygments==2.17.2
+ - pyopengl==3.1.7
+ - pyopengl-accelerate==3.1.7
+ - pyopenssl==24.1.0
+ - pyparsing==3.1.2
+ - pyquaternion==0.9.9
+ - python-dateutil==2.9.0.post0
+ - pytz==2024.1
+ - pyyaml==6.0.1
+ - referencing==0.34.0
+ - requests==2.31.0
+ - retrying==1.3.4
+ - rpds-py==0.18.0
+ - sanic==23.12.1
+ - sanic-routing==23.12.0
+ - scikit-learn==1.4.1.post1
+ - scipy==1.13.0
+ - six==1.16.0
+ - sortedcollections==2.1.0
+ - sortedcontainers==2.4.0
+ - stack-data==0.6.3
+ - sympy==1.12
+ - tenacity==8.2.3
+ - threadpoolctl==3.4.0
+ - torch==2.2.2
+ - torchvision==0.17.2
+ - tqdm==4.66.2
+ - tracerite==1.1.1
+ - traitlets==5.14.2
+ - triton==2.2.0
+ - typing-extensions==4.10.0
+ - tzdata==2024.1
+ - tzlocal==5.2
+ - ujson==5.9.0
+ - urllib3==2.2.1
+ - uvloop==0.15.2
+ - wcwidth==0.2.13
+ - websockets==12.0
+ - werkzeug==3.0.2
+ - widgetsnbextension==4.0.10
+ - yaml2==0.0.1
+ - yarl==1.9.4
+ - zipp==3.18.1
diff --git a/third_party/pram/inference.py b/third_party/pram/inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..29ccd76911f0b2ff8dc82fc28c712cf1d19d40be
--- /dev/null
+++ b/third_party/pram/inference.py
@@ -0,0 +1,62 @@
+# -*- coding: UTF-8 -*-
+'''=================================================
+@Project -> File pram -> inference
+@IDE PyCharm
+@Author fx221@cam.ac.uk
+@Date 03/04/2024 16:06
+=================================================='''
+import argparse
+import torch
+import torchvision.transforms.transforms as tvt
+import yaml
+from nets.load_segnet import load_segnet
+from nets.sfd2 import load_sfd2
+from dataset.get_dataset import compose_datasets
+
+parser = argparse.ArgumentParser(description='PRAM', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+parser.add_argument('--config', type=str, required=True, help='config of specifications')
+parser.add_argument('--landmark_path', type=str, required=True, help='path of landmarks')
+parser.add_argument('--feat_weight_path', type=str, default='weights/sfd2_20230511_210205_resnet4x.79.pth')
+parser.add_argument('--rec_weight_path', type=str, required=True, help='recognition weight')
+parser.add_argument('--online', action='store_true', help='online visualization with pangolin')
+
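+# Entry point for inference: load the SFD2 local-feature extractor and the
+# recognition (segmentation) network from the given weights, then either evaluate
+# localization offline (loc_by_rec_eval) or run the online visualization loop
+# (loc_by_rec_online).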
+if __name__ == '__main__':
+ args = parser.parse_args()
+ with open(args.config, 'rt') as f:
+ config = yaml.load(f, Loader=yaml.Loader)
+ config['landmark_path'] = args.landmark_path
+
+ feat_model = load_sfd2(weight_path=args.feat_weight_path).cuda().eval()
+ print('Load SFD2 weight from {:s}'.format(args.feat_weight_path))
+
+ # rec_model = get_model(config=config)
+ rec_model = load_segnet(network=config['network'],
+ n_class=config['n_class'],
+ desc_dim=256 if config['use_mid_feature'] else 128,
+ n_layers=config['layers'],
+ output_dim=config['output_dim'])
+ state_dict = torch.load(args.rec_weight_path, map_location='cpu')['model']
+ rec_model.load_state_dict(state_dict, strict=True)
+ print('Load recognition weight from {:s}'.format(args.rec_weight_path))
+
+ img_transforms = []
+ img_transforms.append(tvt.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]))
+ img_transforms = tvt.Compose(img_transforms)
+
+ dataset = config['dataset']
+ if not args.online:
+ from localization.loc_by_rec_eval import loc_by_rec_eval
+
+ test_set = compose_datasets(datasets=dataset, config=config, train=False, sample_ratio=1)
+ config['n_class'] = test_set.n_class
+
+ loc_by_rec_eval(rec_model=rec_model.cuda().eval(),
+ loader=test_set,
+ local_feat=feat_model.cuda().eval(),
+ config=config, img_transforms=img_transforms)
+ else:
+ from localization.loc_by_rec_online import loc_by_rec_online
+
+ loc_by_rec_online(rec_model=rec_model.cuda().eval(),
+ local_feat=feat_model.cuda().eval(),
+ config=config, img_transforms=img_transforms)
diff --git a/third_party/pram/localization/base_model.py b/third_party/pram/localization/base_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..432f49c325d39aa44efb0c3106abf7e376c8244e
--- /dev/null
+++ b/third_party/pram/localization/base_model.py
@@ -0,0 +1,45 @@
+from abc import ABCMeta, abstractmethod
+from torch import nn
+from copy import copy
+import inspect
+
+
+class BaseModel(nn.Module, metaclass=ABCMeta):
+ default_conf = {}
+ required_data_keys = []
+
+ def __init__(self, conf):
+ """Perform some logic and call the _init method of the child model."""
+ super().__init__()
+ self.conf = conf = {**self.default_conf, **conf}
+ self.required_data_keys = copy(self.required_data_keys)
+ self._init(conf)
+
+ def forward(self, data):
+ """Check the data and call the _forward method of the child model."""
+ for key in self.required_data_keys:
+ assert key in data, 'Missing key {} in data'.format(key)
+ return self._forward(data)
+
+ @abstractmethod
+ def _init(self, conf):
+ """To be implemented by the child class."""
+ raise NotImplementedError
+
+ @abstractmethod
+ def _forward(self, data):
+ """To be implemented by the child class."""
+ raise NotImplementedError
+
+
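+# Import `root.model` and return the single class defined in that module that
+# inherits from BaseModel (same mechanism as hloc's dynamic loader).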
+def dynamic_load(root, model):
+ module_path = f'{root.__name__}.{model}'
+ module = __import__(module_path, fromlist=[''])
+ classes = inspect.getmembers(module, inspect.isclass)
+ # Filter classes defined in the module
+ classes = [c for c in classes if c[1].__module__ == module_path]
+ # Filter classes inherited from BaseModel
+ classes = [c for c in classes if issubclass(c[1], BaseModel)]
+ assert len(classes) == 1, classes
+ return classes[0][1]
+ # return getattr(module, 'Model')
diff --git a/third_party/pram/localization/camera.py b/third_party/pram/localization/camera.py
new file mode 100644
index 0000000000000000000000000000000000000000..a1d77af63bcac68b87acd6f5ddc19d92c7d99d07
--- /dev/null
+++ b/third_party/pram/localization/camera.py
@@ -0,0 +1,11 @@
+# -*- coding: UTF-8 -*-
+'''=================================================
+@Project -> File pram -> camera
+@IDE PyCharm
+@Author fx221@cam.ac.uk
+@Date 04/03/2024 11:27
+=================================================='''
+import collections
+
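+# Minimal COLMAP-style camera record (same fields as the Camera tuple in
+# colmap_utils.read_write_model).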
+Camera = collections.namedtuple(
+ "Camera", ["id", "model", "width", "height", "params"])
diff --git a/third_party/pram/localization/extract_features.py b/third_party/pram/localization/extract_features.py
new file mode 100644
index 0000000000000000000000000000000000000000..cd3f85c53dafd33fe737fdb9e79eeee1bd1c600b
--- /dev/null
+++ b/third_party/pram/localization/extract_features.py
@@ -0,0 +1,256 @@
+# -*- coding: UTF-8 -*-
+'''=================================================
+@Project -> File pram -> extract_features.py
+@IDE PyCharm
+@Author fx221@cam.ac.uk
+@Date 07/02/2024 14:49
+=================================================='''
+import os
+import os.path as osp
+import h5py
+import numpy as np
+import progressbar
+import yaml
+import torch
+import cv2
+import torch.utils.data as Data
+from tqdm import tqdm
+from types import SimpleNamespace
+import logging
+import pprint
+from pathlib import Path
+import argparse
+from nets.sfd2 import ResNet4x, extract_sfd2_return
+from nets.superpoint import SuperPoint, extract_sp_return
+
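+# Extraction configurations: output file name, detector/descriptor model with its
+# weights, and image preprocessing options for each supported local feature.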
+confs = {
+ 'superpoint-n4096': {
+ 'output': 'feats-superpoint-n4096',
+ 'model': {
+ 'name': 'superpoint',
+ 'outdim': 256,
+ 'use_stability': False,
+ 'nms_radius': 3,
+ 'max_keypoints': 4096,
+ 'conf_th': 0.005,
+ 'multiscale': False,
+ 'scales': [1.0],
+ 'model_fn': osp.join(os.getcwd(),
+ "weights/superpoint_v1.pth"),
+ },
+ 'preprocessing': {
+ 'grayscale': True,
+ 'resize_max': False,
+ },
+ },
+
+ 'resnet4x-20230511-210205-pho-0005': {
+ 'output': 'feats-resnet4x-20230511-210205-pho-0005',
+ 'model': {
+ 'outdim': 128,
+ 'name': 'resnet4x',
+ 'use_stability': False,
+ 'max_keypoints': 4096,
+ 'conf_th': 0.005,
+ 'multiscale': False,
+ 'scales': [1.0],
+ 'model_fn': osp.join(os.getcwd(),
+ "weights/sfd2_20230511_210205_resnet4x.79.pth"),
+ },
+ 'preprocessing': {
+ 'grayscale': False,
+ 'resize_max': False,
+ },
+ 'mask': False,
+ },
+
+ 'sfd2': {
+ 'output': 'feats-sfd2',
+ 'model': {
+ 'outdim': 128,
+ 'name': 'resnet4x',
+ 'use_stability': False,
+ 'max_keypoints': 4096,
+ 'conf_th': 0.005,
+ 'multiscale': False,
+ 'scales': [1.0],
+ 'model_fn': osp.join(os.getcwd(),
+ "weights/sfd2_20230511_210205_resnet4x.79.pth"),
+ },
+ 'preprocessing': {
+ 'grayscale': False,
+ 'resize_max': False,
+ },
+ 'mask': False,
+ },
+}
+
+
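+# Image loader for feature extraction: images are discovered by globbing `root` or
+# read from an explicit image list, optionally converted to grayscale and resized to
+# `resize_max`, and returned as float CHW arrays in [0, 1] together with an optional
+# segmentation mask.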
+class ImageDataset(Data.Dataset):
+ default_conf = {
+ 'globs': ['*.jpg', '*.png', '*.jpeg', '*.JPG', '*.PNG'],
+ 'grayscale': False,
+ 'resize_max': None,
+ 'resize_force': False,
+ }
+
+ def __init__(self, root, conf, image_list=None,
+ mask_root=None):
+ self.conf = conf = SimpleNamespace(**{**self.default_conf, **conf})
+ self.root = root
+
+ self.paths = []
+ if image_list is None:
+ for g in conf.globs:
+ self.paths += list(Path(root).glob('**/' + g))
+ if len(self.paths) == 0:
+ raise ValueError(f'Could not find any image in root: {root}.')
+ self.paths = [i.relative_to(root) for i in self.paths]
+ else:
+ with open(image_list, "r") as f:
+ lines = f.readlines()
+ for l in lines:
+ l = l.strip()
+ self.paths.append(Path(l))
+
+ logging.info(f'Found {len(self.paths)} images in root {root}.')
+
+ if mask_root is not None:
+ self.mask_root = mask_root
+ else:
+ self.mask_root = None
+
+ def __getitem__(self, idx):
+ path = self.paths[idx]
+ if self.conf.grayscale:
+ mode = cv2.IMREAD_GRAYSCALE
+ else:
+ mode = cv2.IMREAD_COLOR
+        image = cv2.imread(str(self.root / path), mode)
+        if image is None:
+            raise ValueError(f'Cannot read image {str(path)}.')
+        if not self.conf.grayscale:
+            image = image[:, :, ::-1]  # BGR to RGB
+ image = image.astype(np.float32)
+ size = image.shape[:2][::-1]
+ w, h = size
+
+ if self.conf.resize_max and (self.conf.resize_force
+ or max(w, h) > self.conf.resize_max):
+ scale = self.conf.resize_max / max(h, w)
+ h_new, w_new = int(round(h * scale)), int(round(w * scale))
+ image = cv2.resize(
+ image, (w_new, h_new), interpolation=cv2.INTER_CUBIC)
+
+ if self.conf.grayscale:
+ image = image[None]
+ else:
+ image = image.transpose((2, 0, 1)) # HxWxC to CxHxW
+ image = image / 255.
+
+ data = {
+ 'name': str(path),
+ 'image': image,
+ 'original_size': np.array(size),
+ }
+
+ if self.mask_root is not None:
+ mask_path = Path(str(path).replace("jpg", "png"))
+            if osp.exists(str(self.mask_root / mask_path)):
+ mask = cv2.imread(str(self.mask_root / mask_path))
+ mask = cv2.resize(mask, dsize=(image.shape[2], image.shape[1]), interpolation=cv2.INTER_NEAREST)
+ else:
+ mask = np.zeros(shape=(image.shape[1], image.shape[2], 3), dtype=np.uint8)
+
+ data['mask'] = mask
+
+ return data
+
+ def __len__(self):
+ return len(self.paths)
+
+
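+# Build the local-feature extractor: SuperPoint or the SFD2 ResNet4x backbone, each
+# paired with its corresponding extraction function.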
+def get_model(model_name, weight_path, outdim=128, **kwargs):
+ if model_name == 'superpoint':
+ model = SuperPoint(config={
+ 'descriptor_dim': 256,
+ 'nms_radius': 4,
+ 'keypoint_threshold': 0.005,
+ 'max_keypoints': -1,
+ 'remove_borders': 4,
+ 'weight_path': weight_path,
+ }).eval()
+
+ extractor = extract_sp_return
+    elif model_name == 'resnet4x':
+        model = ResNet4x(outdim=outdim).eval()
+        model.load_state_dict(torch.load(weight_path)['state_dict'], strict=True)
+        extractor = extract_sfd2_return
+    else:
+        raise ValueError('Unknown local feature model: {:s}'.format(model_name))
+
+    return model, extractor
+
+
+@torch.no_grad()
+def main(conf, image_dir, export_dir, image_list=None):
+ logging.info('Extracting local features with configuration:'
+ f'\n{pprint.pformat(conf)}')
+ model, extractor = get_model(model_name=conf['model']['name'], weight_path=conf["model"]["model_fn"],
+ use_stability=conf['model']['use_stability'], outdim=conf['model']['outdim'])
+ model = model.cuda()
+ loader = ImageDataset(image_dir,
+ conf['preprocessing'],
+                          image_list=image_list,
+ mask_root=None)
+ loader = torch.utils.data.DataLoader(loader, num_workers=4)
+
+ os.makedirs(export_dir, exist_ok=True)
+ feature_path = Path(export_dir, conf['output'] + '.h5')
+ feature_path.parent.mkdir(exist_ok=True, parents=True)
+ feature_file = h5py.File(str(feature_path), 'a')
+
+ with tqdm(total=len(loader)) as t:
+ for idx, data in enumerate(loader):
+ t.update()
+ pred = extractor(model, img=data["image"],
+ topK=conf["model"]["max_keypoints"],
+ mask=None,
+ conf_th=conf["model"]["conf_th"],
+ scales=conf["model"]["scales"],
+ )
+
+ # pred = {k: v[0].cpu().numpy() for k, v in pred.items()}
+ pred['descriptors'] = pred['descriptors'].transpose()
+
+ t.set_postfix(npoints=pred['keypoints'].shape[0])
+ # print(pred['keypoints'].shape)
+
+ pred['image_size'] = original_size = data['original_size'][0].numpy()
+ # pred['descriptors'] = pred['descriptors'].T
+ if 'keypoints' in pred.keys():
+ size = np.array(data['image'].shape[-2:][::-1])
+ scales = (original_size / size).astype(np.float32)
+ pred['keypoints'] = (pred['keypoints'] + .5) * scales[None] - .5
+
+ grp = feature_file.create_group(data['name'][0])
+ for k, v in pred.items():
+ # print(k, v.shape)
+ grp.create_dataset(k, data=v)
+
+ del pred
+
+ feature_file.close()
+ logging.info('Finished exporting features.')
+
+ return feature_path
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--image_dir', type=Path, required=True)
+ parser.add_argument('--image_list', type=str, default=None)
+ parser.add_argument('--mask_dir', type=Path, default=None)
+ parser.add_argument('--export_dir', type=Path, required=True)
+ parser.add_argument('--conf', type=str, required=True, choices=list(confs.keys()))
+ args = parser.parse_args()
+    main(confs[args.conf], args.image_dir, args.export_dir, image_list=args.image_list)
diff --git a/third_party/pram/localization/frame.py b/third_party/pram/localization/frame.py
new file mode 100644
index 0000000000000000000000000000000000000000..467a0f31a9c62a19b4435c71add6d08e34b051f3
--- /dev/null
+++ b/third_party/pram/localization/frame.py
@@ -0,0 +1,195 @@
+# -*- coding: UTF-8 -*-
+'''=================================================
+@Project -> File pram -> frame
+@IDE PyCharm
+@Author fx221@cam.ac.uk
+@Date 01/03/2024 10:08
+=================================================='''
+from collections import defaultdict
+
+import numpy as np
+import torch
+import pycolmap
+
+from localization.camera import Camera
+from localization.utils import compute_pose_error
+
+
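+# Per-query-frame container: holds the image and camera, extracted keypoints and
+# descriptors, predicted landmark segmentations, the 2D-3D matches produced by
+# localization, timing statistics and, when available, ground-truth poses for
+# evaluation.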
+class Frame:
+ def __init__(self, image: np.ndarray, camera: pycolmap.Camera, id: int, name: str = None, qvec=None, tvec=None,
+ scene_name=None,
+ reference_frame_id=None):
+ self.image = image
+ self.camera = camera
+ self.id = id
+ self.name = name
+ self.image_size = np.array([camera.height, camera.width])
+ self.qvec = qvec
+ self.tvec = tvec
+ self.scene_name = scene_name
+ self.reference_frame_id = reference_frame_id
+
+ self.keypoints = None # [N, 3]
+ self.descriptors = None # [N, D]
+ self.segmentations = None # [N C]
+ self.seg_scores = None # [N C]
+ self.seg_ids = None # [N, 1]
+ self.point3D_ids = None # [N, 1]
+ self.xyzs = None
+
+ self.gt_qvec = None
+ self.gt_tvec = None
+
+ self.matched_scene_name = None
+ self.matched_keypoints = None
+ self.matched_keypoint_ids = None
+ self.matched_xyzs = None
+ self.matched_point3D_ids = None
+ self.matched_inliers = None
+ self.matched_sids = None
+ self.matched_order = None
+
+ self.refinement_reference_frame_ids = None
+ self.image_rec = None
+ self.image_matching = None
+ self.image_inlier = None
+ self.reference_frame_name = None
+ self.image_matching_tmp = None
+ self.image_inlier_tmp = None
+ self.reference_frame_name_tmp = None
+
+ self.tracking_status = None
+
+ self.time_feat = 0
+ self.time_rec = 0
+ self.time_loc = 0
+ self.time_ref = 0
+
+ def update_point3ds_old(self):
+ pt = torch.from_numpy(self.keypoints[:, :2]).unsqueeze(-1) # [M 2 1]
+ mpt = torch.from_numpy(self.matched_keypoints[:, :2].transpose()).unsqueeze(0) # [1 2 N]
+ dist = torch.sqrt(torch.sum((pt - mpt) ** 2, dim=1))
+ values, ids = torch.topk(dist, dim=1, k=1, largest=False)
+ values = values[:, 0].numpy()
+ ids = ids[:, 0].numpy()
+ mask = (values < 1) # 1 pixel error
+ self.point3D_ids = np.zeros(shape=(self.keypoints.shape[0],), dtype=int) - 1
+ self.point3D_ids[mask] = self.matched_point3D_ids[ids[mask]]
+
+ # self.xyzs = np.zeros(shape=(self.keypoints.shape[0], 3), dtype=float)
+ inlier_mask = self.matched_inliers
+ self.xyzs[mask] = self.matched_xyzs[ids[mask]]
+ self.seg_ids[mask] = self.matched_sids[ids[mask]]
+
+ def update_point3ds(self):
+ # print('Frame: update_point3ds: ', self.matched_keypoint_ids.shape, self.matched_xyzs.shape,
+ # self.matched_sids.shape, self.matched_point3D_ids.shape)
+ self.xyzs[self.matched_keypoint_ids] = self.matched_xyzs
+ self.seg_ids[self.matched_keypoint_ids] = self.matched_sids
+ self.point3D_ids[self.matched_keypoint_ids] = self.matched_point3D_ids
+
+ def add_keypoints(self, keypoints: np.ndarray, descriptors: np.ndarray):
+ self.keypoints = keypoints
+ self.descriptors = descriptors
+ self.initialize_localization_variables()
+
+ def add_segmentations(self, segmentations: torch.Tensor, filtering_threshold: float):
+ '''
+ :param segmentations: [number_points number_labels]
+ :return:
+ '''
+ seg_scores = torch.softmax(segmentations, dim=-1)
+ if filtering_threshold > 0:
+ scores_background = seg_scores[:, 0]
+ non_bg_mask = (scores_background < filtering_threshold)
+ print('pre filtering before: ', self.keypoints.shape)
+ if torch.sum(non_bg_mask) >= 0.4 * seg_scores.shape[0]:
+ self.keypoints = self.keypoints[non_bg_mask.cpu().numpy()]
+ self.descriptors = self.descriptors[non_bg_mask.cpu().numpy()]
+ # print('pre filtering after: ', self.keypoints.shape)
+
+ # update localization variables
+ self.initialize_localization_variables()
+
+ segmentations = segmentations[non_bg_mask]
+ seg_scores = seg_scores[non_bg_mask]
+ print('pre filtering after: ', self.keypoints.shape)
+
+ # extract initial segmentation info
+ self.segmentations = segmentations.cpu().numpy()
+ self.seg_scores = seg_scores.cpu().numpy()
+ self.seg_ids = segmentations.max(dim=-1)[1].cpu().numpy() - 1 # should start from 0
+
+ def filter_keypoints(self, seg_scores: np.ndarray, filtering_threshold: float):
+ scores_background = seg_scores[:, 0]
+ non_bg_mask = (scores_background < filtering_threshold)
+ print('pre filtering before: ', self.keypoints.shape)
+ if np.sum(non_bg_mask) >= 0.4 * seg_scores.shape[0]:
+ self.keypoints = self.keypoints[non_bg_mask]
+ self.descriptors = self.descriptors[non_bg_mask]
+ print('pre filtering after: ', self.keypoints.shape)
+
+ # update localization variables
+ self.initialize_localization_variables()
+ return non_bg_mask
+ else:
+ print('pre filtering after: ', self.keypoints.shape)
+ return None
+
+ def compute_pose_error(self, pred_qvec=None, pred_tvec=None):
+ if pred_qvec is not None and pred_tvec is not None:
+ if self.gt_qvec is not None and self.gt_tvec is not None:
+ return compute_pose_error(pred_qcw=pred_qvec, pred_tcw=pred_tvec,
+ gt_qcw=self.gt_qvec, gt_tcw=self.gt_tvec)
+ else:
+ return 100, 100
+
+ if self.qvec is None or self.tvec is None or self.gt_qvec is None or self.gt_tvec is None:
+ return 100, 100
+ else:
+ err_q, err_t = compute_pose_error(pred_qcw=self.qvec, pred_tcw=self.tvec,
+ gt_qcw=self.gt_qvec, gt_tcw=self.gt_tvec)
+ return err_q, err_t
+
+ def get_intrinsics(self) -> np.ndarray:
+ camera_model = self.camera.model.name
+ params = self.camera.params
+ if camera_model in ("SIMPLE_PINHOLE", "SIMPLE_RADIAL", "RADIAL"):
+ fx = fy = params[0]
+ cx = params[1]
+ cy = params[2]
+ elif camera_model in ("PINHOLE", "OPENCV", "OPENCV_FISHEYE", "FULL_OPENCV"):
+ fx = params[0]
+ fy = params[1]
+ cx = params[2]
+ cy = params[3]
+ else:
+ raise Exception("Camera model not supported")
+
+ # intrinsics
+ K = np.identity(3)
+ K[0, 0] = fx
+ K[1, 1] = fy
+ K[0, 2] = cx
+ K[1, 2] = cy
+ return K
+
+ def get_dominate_seg_id(self):
+ counts = np.bincount(self.seg_ids[self.seg_ids > 0])
+ return np.argmax(counts)
+
+ def clear_localization_track(self):
+ self.matched_scene_name = None
+ self.matched_keypoints = None
+ self.matched_xyzs = None
+ self.matched_point3D_ids = None
+ self.matched_inliers = None
+ self.matched_sids = None
+
+ self.refinement_reference_frame_ids = None
+
+ def initialize_localization_variables(self):
+ nkpt = self.keypoints.shape[0]
+ self.seg_ids = np.zeros(shape=(nkpt,), dtype=int) - 1
+ self.point3D_ids = np.zeros(shape=(nkpt,), dtype=int) - 1
+ self.xyzs = np.zeros(shape=(nkpt, 3), dtype=float)
diff --git a/third_party/pram/localization/loc_by_rec_eval.py b/third_party/pram/localization/loc_by_rec_eval.py
new file mode 100644
index 0000000000000000000000000000000000000000..f69b4ac3fde0547947abe983b1f5a4a4af55f974
--- /dev/null
+++ b/third_party/pram/localization/loc_by_rec_eval.py
@@ -0,0 +1,299 @@
+# -*- coding: UTF-8 -*-
+'''=================================================
+@Project -> File pram -> loc_by_rec
+@IDE PyCharm
+@Author fx221@cam.ac.uk
+@Date 08/02/2024 15:26
+=================================================='''
+import torch
+from torch.autograd import Variable
+from localization.multimap3d import MultiMap3D
+from localization.frame import Frame
+import yaml, cv2, time
+import numpy as np
+import os.path as osp
+import threading
+import os
+from tqdm import tqdm
+from recognition.vis_seg import vis_seg_point, generate_color_dic
+from tools.metrics import compute_iou, compute_precision
+from localization.tracker import Tracker
+from localization.utils import read_query_info
+from localization.camera import Camera
+
+
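+# Offline evaluation loop: for each query sample, extract SFD2 features, predict
+# per-keypoint landmark segmentations, localize against the MultiMap3D (optionally
+# through the Tracker), and accumulate segmentation (IoU / precision), pose-error
+# and timing statistics.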
+def loc_by_rec_eval(rec_model, loader, config, local_feat, img_transforms=None):
+ n_epoch = int(config['weight_path'].split('.')[1])
+ save_fn = osp.join(config['localization']['save_path'],
+ config['weight_path'].split('/')[0] + '_{:d}'.format(n_epoch) + '_{:d}'.format(
+ config['feat_dim']))
+ tag = 'k{:d}_th{:d}_mm{:d}_mi{:d}'.format(config['localization']['seg_k'], config['localization']['threshold'],
+ config['localization']['min_matches'],
+ config['localization']['min_inliers'])
+ if config['localization']['do_refinement']:
+ tag += '_op{:d}'.format(config['localization']['covisibility_frame'])
+ if config['localization']['with_compress']:
+ tag += '_comp'
+
+ save_fn = save_fn + '_' + tag
+
+    save = config['localization']['save']
+ if save:
+ save_dir = save_fn
+ os.makedirs(save_dir, exist_ok=True)
+ else:
+ save_dir = None
+
+ seg_color = generate_color_dic(n_seg=2000)
+ dataset_path = config['dataset_path']
+ show = config['localization']['show']
+ if show:
+ cv2.namedWindow('img', cv2.WINDOW_NORMAL)
+
+ locMap = MultiMap3D(config=config, save_dir=None)
+ # start tracker
+ mTracker = Tracker(locMap=locMap, matcher=locMap.matcher, config=config)
+
+ dataset_name = config['dataset'][0]
+ all_scene_query_info = {}
+ with open(osp.join(config['config_path'], '{:s}.yaml'.format(dataset_name)), 'r') as f:
+ scene_config = yaml.load(f, Loader=yaml.Loader)
+ scenes = scene_config['scenes']
+ for scene in scenes:
+ query_path = osp.join(config['dataset_path'], dataset_name, scene, scene_config[scene]['query_path'])
+ query_info = read_query_info(query_fn=query_path)
+ all_scene_query_info[dataset_name + '/' + scene] = query_info
+ # print(scene, query_info.keys())
+
+ tracking = False
+
+ full_log = ''
+ failed_cases = []
+ success_cases = []
+ poses = {}
+ err_ths_cnt = [0, 0, 0, 0]
+
+ seg_results = {}
+ time_results = {
+ 'feat': [],
+ 'rec': [],
+ 'loc': [],
+ 'ref': [],
+ 'total': [],
+ }
+ n_total = 0
+
+ loc_scene_names = config['localization']['loc_scene_name']
+ # loader = loader[8990:]
+ for bid, pred in tqdm(enumerate(loader), total=len(loader)):
+ pred = loader[bid]
+ image_name = pred['file_name'] # [0]
+ scene_name = pred['scene_name'] # [0] # dataset_scene
+ if len(loc_scene_names) > 0:
+ skip = True
+ for loc_scene in loc_scene_names:
+ if scene_name.find(loc_scene) > 0:
+ skip = False
+ break
+ if skip:
+ continue
+ with torch.no_grad():
+ for k in pred:
+ if k.find('name') >= 0:
+ continue
+ if k != 'image0' and k != 'image1' and k != 'depth0' and k != 'depth1':
+ if type(pred[k]) == np.ndarray:
+ pred[k] = Variable(torch.from_numpy(pred[k]).float().cuda())[None]
+ elif type(pred[k]) == torch.Tensor:
+ pred[k] = Variable(pred[k].float().cuda())
+ elif type(pred[k]) == list:
+ continue
+ else:
+ pred[k] = Variable(torch.stack(pred[k]).float().cuda())
+ print('scene: ', scene_name, image_name)
+
+ n_total += 1
+ with torch.no_grad():
+ img = pred['image']
+ while isinstance(img, list):
+ img = img[0]
+
+ new_im = torch.from_numpy(img).permute(2, 0, 1).cuda().float()
+ if img_transforms is not None:
+ new_im = img_transforms(new_im)[None]
+ else:
+ new_im = new_im[None]
+ img = (img * 255).astype(np.uint8)
+
+ fn = image_name
+ camera_model, width, height, params = all_scene_query_info[scene_name][fn]
+ camera = Camera(id=-1, model=camera_model, width=width, height=height, params=params)
+ curr_frame = Frame(image=img, camera=camera, id=0, name=fn, scene_name=scene_name)
+ gt_sub_map = locMap.sub_maps[curr_frame.scene_name]
+ if gt_sub_map.gt_poses is not None and curr_frame.name in gt_sub_map.gt_poses.keys():
+ curr_frame.gt_qvec = gt_sub_map.gt_poses[curr_frame.name]['qvec']
+ curr_frame.gt_tvec = gt_sub_map.gt_poses[curr_frame.name]['tvec']
+
+ t_start = time.time()
+ encoder_out = local_feat.extract_local_global(data={'image': new_im},
+ config=
+ {
+ # 'min_keypoints': 128,
+ 'max_keypoints': config['eval_max_keypoints'],
+ }
+ )
+ t_feat = time.time() - t_start
+ # global_descriptors_cuda = encoder_out['global_descriptors']
+ # scores_cuda = encoder_out['scores'][0][None]
+ # kpts_cuda = encoder_out['keypoints'][0][None]
+ # descriptors_cuda = encoder_out['descriptors'][0][None].permute(0, 2, 1)
+
+ sparse_scores = pred['scores']
+ sparse_descs = pred['descriptors']
+ sparse_kpts = pred['keypoints']
+ gt_seg = pred['gt_seg']
+
+ curr_frame.add_keypoints(keypoints=np.hstack([sparse_kpts[0].cpu().numpy(),
+ sparse_scores[0].cpu().numpy().reshape(-1, 1)]),
+ descriptors=sparse_descs[0].cpu().numpy())
+ curr_frame.time_feat = t_feat
+
+ t_start = time.time()
+ _, seg_descriptors = local_feat.sample(score_map=encoder_out['score_map'],
+ semi_descs=encoder_out['mid_features'],
+ # kpts=kpts_cuda[0],
+ kpts=sparse_kpts[0],
+ norm_desc=config['norm_desc'])
+ rec_out = rec_model({'scores': sparse_scores,
+ 'seg_descriptors': seg_descriptors[None].permute(0, 2, 1),
+ 'keypoints': sparse_kpts,
+ 'image': new_im})
+ t_rec = time.time() - t_start
+ curr_frame.time_rec = t_rec
+
+ pred = {
+ # 'scores': scores_cuda,
+ # 'keypoints': kpts_cuda,
+ # 'descriptors': descriptors_cuda,
+ # 'global_descriptors': global_descriptors_cuda,
+ 'image_size': np.array([img.shape[1], img.shape[0]])[None],
+ }
+
+ pred = {**pred, **rec_out}
+ pred_seg = torch.max(pred['prediction'], dim=2)[1] # [B, N, C]
+
+ pred_seg = pred_seg[0].cpu().numpy()
+ kpts = sparse_kpts[0].cpu().numpy()
+ img_pred_seg = vis_seg_point(img=img, kpts=kpts, segs=pred_seg, seg_color=seg_color, radius=9)
+ show_text = 'kpts: {:d}'.format(kpts.shape[0])
+ img_pred_seg = cv2.putText(img=img_pred_seg, text=show_text,
+ org=(50, 30),
+ fontFace=cv2.FONT_HERSHEY_SIMPLEX,
+ fontScale=1, color=(0, 0, 255),
+ thickness=2, lineType=cv2.LINE_AA)
+ curr_frame.image_rec = img_pred_seg
+
+ if show:
+ cv2.imshow('img', img)
+ key = cv2.waitKey(1)
+ if key == ord('q'):
+ exit(0)
+ elif key == ord('s'):
+ show_time = -1
+ elif key == ord('c'):
+ show_time = 1
+
+ segmentations = pred['prediction'][0] # .cpu().numpy() # [N, C]
+ curr_frame.add_segmentations(segmentations=segmentations,
+ filtering_threshold=config['localization']['pre_filtering_th'])
+
+ # Step1: do tracker first
+ success = not mTracker.lost and tracking
+ if success:
+ success = mTracker.run(frame=curr_frame)
+ if not success:
+ success = locMap.run(q_frame=curr_frame)
+ if success:
+ curr_frame.update_point3ds()
+ if tracking:
+ mTracker.lost = False
+ mTracker.last_frame = curr_frame
+ # '''
+ pred_seg = torch.max(pred['prediction'], dim=-1)[1] # [B, N, C]
+ pred_seg = pred_seg[0].cpu().numpy()
+ gt_seg = gt_seg[0].cpu().numpy()
+ iou = compute_iou(pred=pred_seg, target=gt_seg, n_class=pred_seg.shape[0],
+ ignored_ids=[0]) # 0 - background
+ prec = compute_precision(pred=pred_seg, target=gt_seg, ignored_ids=[0])
+
+ kpts = sparse_kpts[0].cpu().numpy()
+ if scene not in seg_results.keys():
+ seg_results[scene] = {
+ 'day': {
+ 'prec': [],
+ 'iou': [],
+ 'kpts': [],
+ },
+ 'night': {
+ 'prec': [],
+ 'iou': [],
+ 'kpts': [],
+
+ }
+ }
+ if fn.find('night') >= 0:
+ seg_results[scene]['night']['prec'].append(prec)
+ seg_results[scene]['night']['iou'].append(iou)
+ seg_results[scene]['night']['kpts'].append(kpts.shape[0])
+ else:
+ seg_results[scene]['day']['prec'].append(prec)
+ seg_results[scene]['day']['iou'].append(iou)
+ seg_results[scene]['day']['kpts'].append(kpts.shape[0])
+
+ print_text = 'name: {:s}, kpts: {:d}, iou: {:.3f}, prec: {:.3f}'.format(fn, kpts.shape[0], iou,
+ prec)
+ print(print_text)
+ # '''
+
+ t_feat = curr_frame.time_feat
+ t_rec = curr_frame.time_rec
+ t_loc = curr_frame.time_loc
+ t_ref = curr_frame.time_ref
+ t_total = t_feat + t_rec + t_loc + t_ref
+ time_results['feat'].append(t_feat)
+ time_results['rec'].append(t_rec)
+ time_results['loc'].append(t_loc)
+ time_results['ref'].append(t_ref)
+ time_results['total'].append(t_total)
+
+ poses[scene + '/' + fn] = (curr_frame.qvec, curr_frame.tvec)
+ q_err, t_err = curr_frame.compute_pose_error()
+ if q_err <= 5 and t_err <= 0.05:
+ err_ths_cnt[0] = err_ths_cnt[0] + 1
+ if q_err <= 2 and t_err <= 0.25:
+ err_ths_cnt[1] = err_ths_cnt[1] + 1
+ if q_err <= 5 and t_err <= 0.5:
+ err_ths_cnt[2] = err_ths_cnt[2] + 1
+ if q_err <= 10 and t_err <= 5:
+ err_ths_cnt[3] = err_ths_cnt[3] + 1
+
+ if success:
+ success_cases.append(scene + '/' + fn)
+ print_text = 'qname: {:s} localization success {:d}/{:d}, q_err: {:.2f}, t_err: {:.2f}, {:d}/{:d}/{:d}/{:d}/{:d}, time: {:.2f}/{:.2f}/{:.2f}/{:.2f}/{:.2f}'.format(
+ scene + '/' + fn, len(success_cases), n_total, q_err, t_err, err_ths_cnt[0],
+ err_ths_cnt[1],
+ err_ths_cnt[2],
+ err_ths_cnt[3],
+ n_total,
+ t_feat, t_rec, t_loc, t_ref, t_total
+ )
+ else:
+ failed_cases.append(scene + '/' + fn)
+ print_text = 'qname: {:s} localization fail {:d}/{:d}, q_err: {:.2f}, t_err: {:.2f}, {:d}/{:d}/{:d}/{:d}/{:d}, time: {:.2f}/{:.2f}/{:.2f}/{:.2f}/{:.2f}'.format(
+ scene + '/' + fn, len(failed_cases), n_total, q_err, t_err, err_ths_cnt[0],
+ err_ths_cnt[1],
+ err_ths_cnt[2],
+ err_ths_cnt[3],
+ n_total, t_feat, t_rec, t_loc, t_ref, t_total)
+ print(print_text)
diff --git a/third_party/pram/localization/loc_by_rec_online.py b/third_party/pram/localization/loc_by_rec_online.py
new file mode 100644
index 0000000000000000000000000000000000000000..58afed6eb439b23b4a0bc7daf45d50098bcc4fc2
--- /dev/null
+++ b/third_party/pram/localization/loc_by_rec_online.py
@@ -0,0 +1,225 @@
+# -*- coding: UTF-8 -*-
+'''=================================================
+@Project -> File pram -> loc_by_rec
+@IDE PyCharm
+@Author fx221@cam.ac.uk
+@Date 08/02/2024 15:26
+=================================================='''
+import torch
+import pycolmap
+from localization.multimap3d import MultiMap3D
+from localization.frame import Frame
+import yaml, cv2, time
+import numpy as np
+import os.path as osp
+import threading
+from recognition.vis_seg import vis_seg_point, generate_color_dic
+from tools.common import resize_img
+from localization.viewer import Viewer
+from localization.tracker import Tracker
+from localization.utils import read_query_info
+from tools.common import puttext_with_background
+
+
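+# Online demo loop: the same recognition-then-localization pipeline as the offline
+# evaluation, but reading query images from disk, driving the Pangolin-based Viewer
+# in a background thread plus an optional Tracker, with keyboard control
+# (c: continue, s: pause, q: quit).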
+def loc_by_rec_online(rec_model, config, local_feat, img_transforms=None):
+ seg_color = generate_color_dic(n_seg=2000)
+ dataset_path = config['dataset_path']
+ show = config['localization']['show']
+ if show:
+ cv2.namedWindow('img', cv2.WINDOW_NORMAL)
+
+ locMap = MultiMap3D(config=config, save_dir=None)
+ if config['dataset'][0] in ['Aachen']:
+ viewer_config = {'scene': 'outdoor',
+ 'image_size_indoor': 4,
+ 'image_line_width_indoor': 8, }
+ elif config['dataset'][0] in ['C']:
+ viewer_config = {'scene': 'outdoor'}
+ elif config['dataset'][0] in ['12Scenes', '7Scenes']:
+ viewer_config = {'scene': 'indoor', }
+ else:
+ viewer_config = {'scene': 'outdoor',
+ 'image_size_indoor': 0.4,
+ 'image_line_width_indoor': 2, }
+ # start viewer
+ mViewer = Viewer(locMap=locMap, seg_color=seg_color, config=viewer_config)
+ mViewer.refinement = locMap.do_refinement
+ # locMap.viewer = mViewer
+ viewer_thread = threading.Thread(target=mViewer.run)
+ viewer_thread.start()
+
+ # start tracker
+ mTracker = Tracker(locMap=locMap, matcher=locMap.matcher, config=config)
+
+ dataset_name = config['dataset'][0]
+ all_scene_query_info = {}
+ with open(osp.join(config['config_path'], '{:s}.yaml'.format(dataset_name)), 'r') as f:
+ scene_config = yaml.load(f, Loader=yaml.Loader)
+
+ # multiple scenes in a single dataset
+ err_ths_cnt = [0, 0, 0, 0]
+
+ show_time = -1
+ scenes = scene_config['scenes']
+ n_total = 0
+ for scene in scenes:
+ if len(config['localization']['loc_scene_name']) > 0:
+ if scene not in config['localization']['loc_scene_name']:
+ continue
+
+ query_path = osp.join(config['dataset_path'], dataset_name, scene, scene_config[scene]['query_path'])
+ query_info = read_query_info(query_fn=query_path)
+ all_scene_query_info[dataset_name + '/' + scene] = query_info
+ image_path = osp.join(dataset_path, dataset_name, scene)
+ for fn in sorted(query_info.keys()):
+ # for fn in sorted(query_info.keys())[880:][::5]: # darwinRGB-loc-outdoor-aligned
+ # for fn in sorted(query_info.keys())[3161:][::5]: # darwinRGB-loc-indoor-aligned
+ # for fn in sorted(query_info.keys())[2840:][::5]: # darwinRGB-loc-indoor-aligned
+
+ # for fn in sorted(query_info.keys())[2100:][::5]: # darwinRGB-loc-outdoor
+ # for fn in sorted(query_info.keys())[4360:][::5]: # darwinRGB-loc-indoor
+ # for fn in sorted(query_info.keys())[1380:]: # Cam-Church
+ # for fn in sorted(query_info.keys())[::5]: #ACUED-test2
+ # for fn in sorted(query_info.keys())[1260:]: # jesus aligned
+ # for fn in sorted(query_info.keys())[1260:]: # jesus aligned
+ # for fn in sorted(query_info.keys())[4850:]:
+ img = cv2.imread(osp.join(image_path, fn)) # BGR
+
+ camera_model, width, height, params = all_scene_query_info[dataset_name + '/' + scene][fn]
+ # camera = Camera(id=-1, model=camera_model, width=width, height=height, params=params)
+ camera = pycolmap.Camera(model=camera_model, width=int(width), height=int(height), params=params)
+ curr_frame = Frame(image=img, camera=camera, id=0, name=fn, scene_name=dataset_name + '/' + scene)
+ gt_sub_map = locMap.sub_maps[curr_frame.scene_name]
+ if gt_sub_map.gt_poses is not None and curr_frame.name in gt_sub_map.gt_poses.keys():
+ curr_frame.gt_qvec = gt_sub_map.gt_poses[curr_frame.name]['qvec']
+ curr_frame.gt_tvec = gt_sub_map.gt_poses[curr_frame.name]['tvec']
+
+ with torch.no_grad():
+ if config['image_dim'] == 1:
+ img_gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
+ img_cuda = torch.from_numpy(img_gray / 255)[None].cuda().float()
+ else:
+ img_cuda = torch.from_numpy(img / 255).permute(2, 0, 1).cuda().float()
+ if img_transforms is not None:
+ img_cuda = img_transforms(img_cuda)[None]
+ else:
+ img_cuda = img_cuda[None]
+
+ t_start = time.time()
+ encoder_out = local_feat.extract_local_global(data={'image': img_cuda},
+ config={'min_keypoints': 128,
+ 'max_keypoints': config['eval_max_keypoints'],
+ }
+ )
+ t_feat = time.time() - t_start
+ # global_descriptors_cuda = encoder_out['global_descriptors']
+ scores_cuda = encoder_out['scores'][0][None]
+ kpts_cuda = encoder_out['keypoints'][0][None]
+ descriptors_cuda = encoder_out['descriptors'][0][None].permute(0, 2, 1)
+
+ curr_frame.add_keypoints(keypoints=np.hstack([kpts_cuda[0].cpu().numpy(),
+ scores_cuda[0].cpu().numpy().reshape(-1, 1)]),
+ descriptors=descriptors_cuda[0].cpu().numpy())
+ curr_frame.time_feat = t_feat
+
+ t_start = time.time()
+ _, seg_descriptors = local_feat.sample(score_map=encoder_out['score_map'],
+ semi_descs=encoder_out['mid_features'],
+ kpts=kpts_cuda[0],
+ norm_desc=config['norm_desc'])
+ rec_out = rec_model({'scores': scores_cuda,
+ 'seg_descriptors': seg_descriptors[None].permute(0, 2, 1),
+ 'keypoints': kpts_cuda,
+ 'image': img_cuda})
+ t_rec = time.time() - t_start
+ curr_frame.time_rec = t_rec
+
+ pred = {
+ 'scores': scores_cuda,
+ 'keypoints': kpts_cuda,
+ 'descriptors': descriptors_cuda,
+ # 'global_descriptors': global_descriptors_cuda,
+ 'image_size': np.array([img.shape[1], img.shape[0]])[None],
+ }
+
+ pred = {**pred, **rec_out}
+ pred_seg = torch.max(pred['prediction'], dim=2)[1] # [B, N, C]
+
+ pred_seg = pred_seg[0].cpu().numpy()
+ kpts = kpts_cuda[0].cpu().numpy()
+ segmentations = pred['prediction'][0] # .cpu().numpy() # [N, C]
+ curr_frame.add_segmentations(segmentations=segmentations,
+ filtering_threshold=config['localization']['pre_filtering_th'])
+
+ img_pred_seg = vis_seg_point(img=img, kpts=curr_frame.keypoints,
+ segs=curr_frame.seg_ids + 1, seg_color=seg_color, radius=9)
+ show_text = 'kpts: {:d}'.format(kpts.shape[0])
+ img_pred_seg = cv2.putText(img=img_pred_seg,
+ text=show_text,
+ org=(50, 30),
+ fontFace=cv2.FONT_HERSHEY_SIMPLEX,
+ fontScale=1, color=(0, 0, 255),
+ thickness=2, lineType=cv2.LINE_AA)
+ curr_frame.image_rec = img_pred_seg
+
+ if show:
+ img_text = puttext_with_background(image=img, text='Press C - continue | S - pause | Q - exit',
+ org=(30, 50),
+ bg_color=(255, 255, 255),
+ text_color=(0, 0, 255),
+ fontScale=1, thickness=2)
+ cv2.imshow('img', img_text)
+ key = cv2.waitKey(show_time)
+ if key == ord('q'):
+ exit(0)
+ elif key == ord('s'):
+ show_time = -1
+ elif key == ord('c'):
+ show_time = 1
+
+ # Step1: do tracker first
+ success = not mTracker.lost and mViewer.tracking
+ if success:
+ success = mTracker.run(frame=curr_frame)
+ if success:
+ mViewer.update(curr_frame=curr_frame)
+
+ if not success:
+ # success = locMap.run(q_frame=curr_frame, q_segs=segmentations)
+ success = locMap.run(q_frame=curr_frame)
+ if success:
+ mViewer.update(curr_frame=curr_frame)
+
+ if success:
+ curr_frame.update_point3ds()
+ if mViewer.tracking:
+ mTracker.lost = False
+ mTracker.last_frame = curr_frame
+
+ time.sleep(50 / 1000)
+ locMap.do_refinement = mViewer.refinement
+
+ n_total = n_total + 1
+ q_err, t_err = curr_frame.compute_pose_error()
+ if q_err <= 5 and t_err <= 0.05:
+ err_ths_cnt[0] = err_ths_cnt[0] + 1
+ if q_err <= 2 and t_err <= 0.25:
+ err_ths_cnt[1] = err_ths_cnt[1] + 1
+ if q_err <= 5 and t_err <= 0.5:
+ err_ths_cnt[2] = err_ths_cnt[2] + 1
+ if q_err <= 10 and t_err <= 5:
+ err_ths_cnt[3] = err_ths_cnt[3] + 1
+ time_total = curr_frame.time_feat + curr_frame.time_rec + curr_frame.time_loc + curr_frame.time_ref
+ print_text = 'qname: {:s} localization {:b}, q_err: {:.2f}, t_err: {:.2f}, {:d}/{:d}/{:d}/{:d}/{:d}, time: {:.2f}/{:.2f}/{:.2f}/{:.2f}/{:.2f}'.format(
+ scene + '/' + fn, success, q_err, t_err,
+ err_ths_cnt[0],
+ err_ths_cnt[1],
+ err_ths_cnt[2],
+ err_ths_cnt[3],
+ n_total,
+ curr_frame.time_feat, curr_frame.time_rec, curr_frame.time_loc, curr_frame.time_ref, time_total
+ )
+ print(print_text)
+
+ mViewer.terminate()
+ viewer_thread.join()
diff --git a/third_party/pram/localization/localizer.py b/third_party/pram/localization/localizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..0777b9cc6d7f70aa8c3699f360684cd24054a488
--- /dev/null
+++ b/third_party/pram/localization/localizer.py
@@ -0,0 +1,217 @@
+# -*- coding: UTF-8 -*-
+'''=================================================
+@Project -> File pram -> hloc
+@IDE PyCharm
+@Author fx221@cam.ac.uk
+@Date 07/02/2024 16:45
+=================================================='''
+
+import os
+import os.path as osp
+from tqdm import tqdm
+import argparse
+import time
+import logging
+import h5py
+import numpy as np
+from pathlib import Path
+from colmap_utils.read_write_model import read_model
+from colmap_utils.parsers import parse_image_lists_with_intrinsics
+# localization
+from localization.match_features_batch import confs
+from localization.base_model import dynamic_load
+from localization import matchers
+from localization.utils import compute_pose_error, read_gt_pose, read_retrieval_results
+from localization.pose_estimator import pose_estimator_hloc, pose_estimator_iterative
+
+
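+# Hierarchical-localization baseline: for each query image, take its retrieved
+# database candidates, match local features, and estimate the pose either with the
+# plain hloc estimator or the iterative estimator with covisibility-based refinement;
+# poses, failure cases and logs are written to `save_root`.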
+def run(args):
+ if args.gt_pose_fn is not None:
+ gt_poses = read_gt_pose(path=args.gt_pose_fn)
+ else:
+ gt_poses = {}
+ retrievals = read_retrieval_results(args.retrieval)
+
+ save_root = args.save_root # path to save
+ os.makedirs(save_root, exist_ok=True)
+ matcher_name = args.matcher_method # matching method
+ print('matcher: ', confs[args.matcher_method]['model']['name'])
+ Model = dynamic_load(matchers, confs[args.matcher_method]['model']['name'])
+ matcher = Model(confs[args.matcher_method]['model']).eval().cuda()
+
+ local_feat_name = args.features.as_posix().split("/")[-1].split(".")[0] # name of local features
+ save_fn = '{:s}_{:s}'.format(local_feat_name, matcher_name)
+ if args.use_hloc:
+ save_fn = 'hloc_' + save_fn
+ save_fn = osp.join(save_root, save_fn)
+
+ queries = parse_image_lists_with_intrinsics(args.queries)
+ _, db_images, points3D = read_model(str(args.reference_sfm), '.bin')
+ db_name_to_id = {image.name: i for i, image in db_images.items()}
+ feature_file = h5py.File(args.features, 'r')
+
+ tag = ''
+ if args.do_covisible_opt:
+ tag = tag + "_o" + str(int(args.obs_thresh)) + 'op' + str(int(args.covisibility_frame))
+ tag = tag + "th" + str(int(args.opt_thresh))
+ if args.iters > 0:
+ tag = tag + "i" + str(int(args.iters))
+
+ log_fn = save_fn + tag
+ vis_dir = save_fn + tag
+ results = save_fn + tag
+
+ full_log_fn = log_fn + '_full.log'
+ loc_log_fn = log_fn + '_loc.npy'
+ results = Path(results + '.txt')
+ vis_dir = Path(vis_dir)
+ if vis_dir is not None:
+ Path(vis_dir).mkdir(exist_ok=True)
+ print("save_fn: ", log_fn)
+
+ logging.info('Starting localization...')
+ poses = {}
+ failed_cases = []
+ n_total = 0
+ n_failed = 0
+ full_log_info = ''
+ loc_results = {}
+
+ error_ths = ((0.25, 2), (0.5, 5), (5, 10))
+ success = [0, 0, 0]
+ total_loc_time = []
+
+ for qname, qinfo in tqdm(queries):
+ kpq = feature_file[qname]['keypoints'].__array__()
+ n_total += 1
+ time_start = time.time()
+
+ if qname in retrievals.keys():
+ cans = retrievals[qname]
+ db_ids = [db_name_to_id[v] for v in cans]
+ else:
+ cans = []
+ db_ids = []
+ time_coarse = time.time()
+
+ if args.use_hloc:
+ output = pose_estimator_hloc(qname=qname, qinfo=qinfo, db_ids=db_ids, db_images=db_images,
+ points3D=points3D,
+ feature_file=feature_file,
+ thresh=args.ransac_thresh,
+ image_dir=args.image_dir,
+ matcher=matcher,
+ log_info='',
+ query_img_prefix='',
+ db_img_prefix='')
+ else: # should be faster and more accurate than hloc
+ t_start = time.time()
+ output = pose_estimator_iterative(qname=qname,
+ qinfo=qinfo,
+ matcher=matcher,
+ db_ids=db_ids,
+ db_images=db_images,
+ points3D=points3D,
+ feature_file=feature_file,
+ thresh=args.ransac_thresh,
+ image_dir=args.image_dir,
+ do_covisibility_opt=args.do_covisible_opt,
+ covisibility_frame=args.covisibility_frame,
+ log_info='',
+ inlier_th=args.inlier_thresh,
+ obs_th=args.obs_thresh,
+ opt_th=args.opt_thresh,
+ gt_qvec=gt_poses[qname]['qvec'] if qname in gt_poses.keys() else None,
+ gt_tvec=gt_poses[qname]['tvec'] if qname in gt_poses.keys() else None,
+ query_img_prefix='',
+ db_img_prefix='database',
+ )
+ time_full = time.time()
+
+ qvec = output['qvec']
+ tvec = output['tvec']
+ loc_time = time_full - time_start
+ total_loc_time.append(loc_time)
+
+ poses[qname] = (qvec, tvec)
+ print_text = "All {:d}/{:d} failed cases, time[cs/fn]: {:.2f}/{:.2f}".format(
+ n_failed, n_total,
+ time_coarse - time_start,
+ time_full - time_coarse,
+ )
+
+ if qname in gt_poses.keys():
+ gt_qvec = gt_poses[qname]['qvec']
+ gt_tvec = gt_poses[qname]['tvec']
+
+ q_error, t_error = compute_pose_error(pred_qcw=qvec, pred_tcw=tvec, gt_qcw=gt_qvec, gt_tcw=gt_tvec)
+
+ for error_idx, th in enumerate(error_ths):
+ if t_error <= th[0] and q_error <= th[1]:
+ success[error_idx] += 1
+ print_text += (
+ ', q_error:{:.2f} t_error:{:.2f} {:d}/{:d}/{:d}/{:d}, time: {:.2f}, {:d}pts'.format(q_error, t_error,
+ success[0],
+ success[1],
+ success[2], n_total,
+ loc_time,
+ kpq.shape[0]))
+        if output['num_inliers'] == 0:
+            n_failed += 1
+            failed_cases.append(qname)
+
+ loc_results[qname] = {
+ 'keypoints_query': output['keypoints_query'],
+ 'points3D_ids': output['points3D_ids'],
+ }
+ full_log_info = full_log_info + output['log_info']
+ full_log_info += (print_text + "\n")
+ print(print_text)
+
+ logs_path = f'{results}.failed'
+ with open(logs_path, 'w') as f:
+ for v in failed_cases:
+ print(v)
+ f.write(v + "\n")
+
+ logging.info(f'Localized {len(poses)} / {len(queries)} images.')
+ logging.info(f'Writing poses to {results}...')
+ # logging.info(f'Mean loc time: {np.mean(total_loc_time)}...')
+ print('Mean loc time: {:.2f}...'.format(np.mean(total_loc_time)))
+ with open(results, 'w') as f:
+ for q in poses:
+ qvec, tvec = poses[q]
+ qvec = ' '.join(map(str, qvec))
+ tvec = ' '.join(map(str, tvec))
+ name = q
+ f.write(f'{name} {qvec} {tvec}\n')
+
+ with open(full_log_fn, 'w') as f:
+ f.write(full_log_info)
+
+ np.save(loc_log_fn, loc_results)
+ print('Save logs to ', loc_log_fn)
+ logging.info('Done!')
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--image_dir', type=str, required=True)
+ parser.add_argument('--dataset', type=str, required=True)
+ parser.add_argument('--reference_sfm', type=Path, required=True)
+ parser.add_argument('--queries', type=Path, required=True)
+ parser.add_argument('--features', type=Path, required=True)
+ parser.add_argument('--ransac_thresh', type=float, default=12)
+ parser.add_argument('--covisibility_frame', type=int, default=50)
+ parser.add_argument('--do_covisible_opt', action='store_true')
+ parser.add_argument('--use_hloc', action='store_true')
+ parser.add_argument('--matcher_method', type=str, default="NNM")
+ parser.add_argument('--inlier_thresh', type=int, default=50)
+ parser.add_argument('--obs_thresh', type=float, default=3)
+ parser.add_argument('--opt_thresh', type=float, default=12)
+ parser.add_argument('--save_root', type=str, required=True)
+ parser.add_argument('--retrieval', type=Path, default=None)
+ parser.add_argument('--gt_pose_fn', type=str, default=None)
+
+ args = parser.parse_args()
+ os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
+ run(args=args)
diff --git a/third_party/pram/localization/match_features.py b/third_party/pram/localization/match_features.py
new file mode 100644
index 0000000000000000000000000000000000000000..ef1b4edccff67db24d97fadb47024eb09c026ce8
--- /dev/null
+++ b/third_party/pram/localization/match_features.py
@@ -0,0 +1,156 @@
+import argparse
+import torch
+from pathlib import Path
+import h5py
+import logging
+from tqdm import tqdm
+import pprint
+
+import localization.matchers as matchers
+from localization.base_model import dynamic_load
+from colmap_utils.parsers import names_to_pair
+
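+# Matcher configurations: IMP variants (gm / gml / adagml) and SuperGlue with their
+# weight paths, plus mutual nearest-neighbor matching (NNM).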
+confs = {
+ 'gm': {
+ 'output': 'gm',
+ 'model': {
+ 'name': 'gm',
+ 'weight_path': 'weights/imp_gm.900.pth',
+ 'sinkhorn_iterations': 20,
+ },
+ },
+ 'gml': {
+ 'output': 'gml',
+ 'model': {
+ 'name': 'gml',
+ 'weight_path': 'weights/imp_gml.920.pth',
+ 'sinkhorn_iterations': 20,
+ },
+ },
+
+ 'adagml': {
+ 'output': 'adagml',
+ 'model': {
+ 'name': 'adagml',
+ 'weight_path': 'weights/imp_adagml.80.pth',
+ 'sinkhorn_iterations': 20,
+ },
+ },
+
+ 'superglue': {
+ 'output': 'superglue',
+ 'model': {
+ 'name': 'superglue',
+ 'weights': 'outdoor',
+ 'sinkhorn_iterations': 20,
+ 'weight_path': 'weights/superglue_outdoor.pth',
+ },
+ },
+ 'NNM': {
+ 'output': 'NNM',
+ 'model': {
+ 'name': 'nearest_neighbor',
+ 'do_mutual_check': True,
+ 'distance_threshold': None,
+ },
+ },
+}
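+# The configurations above follow the hloc-style matcher configs: 'output' is the
+# tag used in the match file name and 'model' is passed to the matcher constructor.
+# The 'weight_path' entries are loaded as given (no path resolution is visible
+# here), so they are assumed to be relative to the pram repository root.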
+
+
+@torch.no_grad()
+def main(conf, pairs, features, export_dir, exhaustive=False):
+ logging.info('Matching local features with configuration:'
+ f'\n{pprint.pformat(conf)}')
+
+ feature_path = Path(export_dir, features + '.h5')
+ assert feature_path.exists(), feature_path
+ feature_file = h5py.File(str(feature_path), 'r')
+ pairs_name = pairs.stem
+ if not exhaustive:
+ assert pairs.exists(), pairs
+ with open(pairs, 'r') as f:
+ pair_list = f.read().rstrip('\n').split('\n')
+ elif exhaustive:
+ logging.info(f'Writing exhaustive match pairs to {pairs}.')
+ assert not pairs.exists(), pairs
+
+ # get the list of images from the feature file
+ images = []
+ feature_file.visititems(
+ lambda name, obj: images.append(obj.parent.name.strip('/'))
+ if isinstance(obj, h5py.Dataset) else None)
+ images = list(set(images))
+
+ pair_list = [' '.join((images[i], images[j]))
+ for i in range(len(images)) for j in range(i)]
+ with open(str(pairs), 'w') as f:
+ f.write('\n'.join(pair_list))
+
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
+ Model = dynamic_load(matchers, conf['model']['name'])
+ model = Model(conf['model']).eval().to(device)
+
+ match_name = f'{features}-{conf["output"]}-{pairs_name}'
+ match_path = Path(export_dir, match_name + '.h5')
+
+ match_file = h5py.File(str(match_path), 'a')
+
+ matched = set()
+ for pair in tqdm(pair_list, smoothing=.1):
+ name0, name1 = pair.split(' ')
+ pair = names_to_pair(name0, name1)
+
+        # Avoid recomputing duplicate pairs to save time
+ if len({(name0, name1), (name1, name0)} & matched) \
+ or pair in match_file:
+ continue
+
+ data = {}
+ feats0, feats1 = feature_file[name0], feature_file[name1]
+        for k in feats0.keys():
+ # data[k + '0'] = feats0[k].__array__()
+ if k == 'descriptors':
+ data[k + '0'] = feats0[k][()].transpose() # [N D]
+ else:
+ data[k + '0'] = feats0[k][()]
+ for k in feats1.keys():
+ # data[k + '1'] = feats1[k].__array__()
+ # data[k + '1'] = feats1[k][()].transpose() # [N D]
+ if k == 'descriptors':
+ data[k + '1'] = feats1[k][()].transpose() # [N D]
+ else:
+ data[k + '1'] = feats1[k][()]
+ data = {k: torch.from_numpy(v)[None].float().to(device)
+ for k, v in data.items()}
+
+ # some matchers might expect an image but only use its size
+ data['image0'] = torch.empty((1, 1,) + tuple(feats0['image_size'])[::-1])
+ data['image1'] = torch.empty((1, 1,) + tuple(feats1['image_size'])[::-1])
+
+ pred = model(data)
+ grp = match_file.create_group(pair)
+ matches = pred['matches0'][0].cpu().short().numpy()
+ grp.create_dataset('matches0', data=matches)
+
+ if 'matching_scores0' in pred:
+ scores = pred['matching_scores0'][0].cpu().half().numpy()
+ grp.create_dataset('matching_scores0', data=scores)
+
+ matched |= {(name0, name1), (name1, name0)}
+
+ match_file.close()
+ logging.info('Finished exporting matches.')
+
+ return match_path
+
+
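+# Example (hypothetical names): calling main() directly instead of the CLI below,
+#   main(confs['NNM'], Path('pairs-db-covis20.txt'), 'feats-sfd2', Path('outputs/'))
+# reads 'outputs/feats-sfd2.h5' and writes 'outputs/feats-sfd2-NNM-pairs-db-covis20.h5'.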
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--export_dir', type=Path, required=True)
+ parser.add_argument('--features', type=str, required=True)
+ parser.add_argument('--pairs', type=Path, required=True)
+ parser.add_argument('--conf', type=str, required=True, choices=list(confs.keys()))
+ parser.add_argument('--exhaustive', action='store_true')
+ args = parser.parse_args()
+ main(confs[args.conf], args.pairs, args.features, args.export_dir,
+ exhaustive=args.exhaustive)
diff --git a/third_party/pram/localization/match_features_batch.py b/third_party/pram/localization/match_features_batch.py
new file mode 100644
index 0000000000000000000000000000000000000000..1c0dc9d4a1e4288892c365616e45304a19e93c3e
--- /dev/null
+++ b/third_party/pram/localization/match_features_batch.py
@@ -0,0 +1,242 @@
+import argparse
+import torch
+from pathlib import Path
+import h5py
+import logging
+from tqdm import tqdm
+import pprint
+from queue import Queue
+from threading import Thread
+from functools import partial
+from typing import Dict, List, Optional, Tuple, Union
+
+import localization.matchers as matchers
+from localization.base_model import dynamic_load
+from colmap_utils.parsers import names_to_pair, names_to_pair_old, parse_retrieval
+
+confs = {
+ 'gm': {
+ 'output': 'gm',
+ 'model': {
+ 'name': 'gm',
+ 'weight_path': 'weights/imp_gm.900.pth',
+ 'sinkhorn_iterations': 20,
+ },
+ },
+ 'gml': {
+ 'output': 'gml',
+ 'model': {
+ 'name': 'gml',
+ 'weight_path': 'weights/imp_gml.920.pth',
+ 'sinkhorn_iterations': 20,
+ },
+ },
+
+ 'adagml': {
+ 'output': 'adagml',
+ 'model': {
+ 'name': 'adagml',
+ 'weight_path': 'weights/imp_adagml.80.pth',
+ 'sinkhorn_iterations': 20,
+ },
+ },
+
+ 'superglue': {
+ 'output': 'superglue',
+ 'model': {
+ 'name': 'superglue',
+ 'weights': 'outdoor',
+ 'sinkhorn_iterations': 20,
+ 'weight_path': 'weights/superglue_outdoor.pth',
+ },
+ },
+ 'NNM': {
+ 'output': 'NNM',
+ 'model': {
+ 'name': 'nearest_neighbor',
+ 'do_mutual_check': True,
+ 'distance_threshold': None,
+ },
+ },
+}
+
+
+class WorkQueue:
+ def __init__(self, work_fn, num_threads=1):
+ self.queue = Queue(num_threads)
+ self.threads = [
+ Thread(target=self.thread_fn, args=(work_fn,)) for _ in range(num_threads)
+ ]
+ for thread in self.threads:
+ thread.start()
+
+ def join(self):
+ for thread in self.threads:
+ self.queue.put(None)
+ for thread in self.threads:
+ thread.join()
+
+ def thread_fn(self, work_fn):
+ item = self.queue.get()
+ while item is not None:
+ work_fn(item)
+ item = self.queue.get()
+
+ def put(self, data):
+ self.queue.put(data)
+
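+# WorkQueue runs `work_fn` on a small pool of threads fed through a bounded queue;
+# join() enqueues one None sentinel per thread so every worker drains the queue
+# and exits before the main thread continues.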
+
+class FeaturePairsDataset(torch.utils.data.Dataset):
+ def __init__(self, pairs, feature_path_q, feature_path_r):
+ self.pairs = pairs
+ self.feature_path_q = feature_path_q
+ self.feature_path_r = feature_path_r
+
+ def __getitem__(self, idx):
+ name0, name1 = self.pairs[idx]
+ data = {}
+ with h5py.File(self.feature_path_q, "r") as fd:
+ grp = fd[name0]
+ for k, v in grp.items():
+ data[k + "0"] = torch.from_numpy(v.__array__()).float()
+ if k == 'descriptors':
+ data[k + '0'] = data[k + '0'].t()
+ # some matchers might expect an image but only use its size
+ data["image0"] = torch.empty((1,) + tuple(grp["image_size"])[::-1])
+ with h5py.File(self.feature_path_r, "r") as fd:
+ grp = fd[name1]
+ for k, v in grp.items():
+ data[k + "1"] = torch.from_numpy(v.__array__()).float()
+ if k == 'descriptors':
+ data[k + '1'] = data[k + '1'].t()
+ data["image1"] = torch.empty((1,) + tuple(grp["image_size"])[::-1])
+ return data
+
+ def __len__(self):
+ return len(self.pairs)
+
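+# Each item is one dict per pair whose keys mirror the feature groups
+# ('keypoints0/1', 'descriptors0/1', ...); descriptors are transposed so the
+# matcher sees [N, D] as in match_features.py, and 'image0/1' are dummy tensors
+# that only carry the image size. The DataLoader adds the batch dimension.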
+
+def writer_fn(inp, match_path):
+ pair, pred = inp
+ with h5py.File(str(match_path), "a", libver="latest") as fd:
+ if pair in fd:
+ del fd[pair]
+ grp = fd.create_group(pair)
+ matches = pred["matches0"][0].cpu().short().numpy()
+ grp.create_dataset("matches0", data=matches)
+ if "matching_scores0" in pred:
+ scores = pred["matching_scores0"][0].cpu().half().numpy()
+ grp.create_dataset("matching_scores0", data=scores)
+
+
+def main(
+ conf: Dict,
+ pairs: Path,
+ features: Union[Path, str],
+ export_dir: Optional[Path] = None,
+ matches: Optional[Path] = None,
+ features_ref: Optional[Path] = None,
+ overwrite: bool = False,
+) -> Path:
+ if isinstance(features, Path) or Path(features).exists():
+ features_q = features
+ if matches is None:
+ raise ValueError(
+ "Either provide both features and matches as Path" " or both as names."
+ )
+ else:
+ if export_dir is None:
+ raise ValueError(
+ "Provide an export_dir if features is not" f" a file path: {features}."
+ )
+ features_q = Path(export_dir, features + ".h5")
+ if matches is None:
+ matches = Path(export_dir, f'{features}-{conf["output"]}-{pairs.stem}.h5')
+
+ if features_ref is None:
+ features_ref = features_q
+ match_from_paths(conf, pairs, matches, features_q, features_ref, overwrite)
+
+ return matches
+
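+# Example (hypothetical names), resolving both outputs from an export directory:
+#   main(confs['adagml'], Path('pairs-query-netvlad50.txt'), 'feats-sfd2',
+#        export_dir=Path('outputs/'))
+# writes 'outputs/feats-sfd2-adagml-pairs-query-netvlad50.h5'.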
+
+def find_unique_new_pairs(pairs_all: List[Tuple[str, str]], match_path: Path = None):
+ """Avoid to recompute duplicates to save time."""
+ pairs = set()
+ for i, j in pairs_all:
+ if (j, i) not in pairs:
+ pairs.add((i, j))
+ pairs = list(pairs)
+ if match_path is not None and match_path.exists():
+ with h5py.File(str(match_path), "r", libver="latest") as fd:
+ pairs_filtered = []
+ for i, j in pairs:
+ if (
+ names_to_pair(i, j) in fd
+ or names_to_pair(j, i) in fd
+ or names_to_pair_old(i, j) in fd
+ or names_to_pair_old(j, i) in fd
+ ):
+ continue
+ pairs_filtered.append((i, j))
+ return pairs_filtered
+ return pairs
+
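+# find_unique_new_pairs drops (j, i) duplicates and, when a match file already
+# exists, skips pairs stored under either the new or the old hloc pair naming,
+# so repeated runs only match pairs that are actually new.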
+
+@torch.no_grad()
+def match_from_paths(
+ conf: Dict,
+ pairs_path: Path,
+ match_path: Path,
+ feature_path_q: Path,
+ feature_path_ref: Path,
+ overwrite: bool = False,
+) -> Path:
+ logging.info(
+ "Matching local features with configuration:" f"\n{pprint.pformat(conf)}"
+ )
+
+ if not feature_path_q.exists():
+ raise FileNotFoundError(f"Query feature file {feature_path_q}.")
+ if not feature_path_ref.exists():
+ raise FileNotFoundError(f"Reference feature file {feature_path_ref}.")
+ match_path.parent.mkdir(exist_ok=True, parents=True)
+
+ assert pairs_path.exists(), pairs_path
+ pairs = parse_retrieval(pairs_path)
+ pairs = [(q, r) for q, rs in pairs.items() for r in rs]
+ pairs = find_unique_new_pairs(pairs, None if overwrite else match_path)
+ if len(pairs) == 0:
+ logging.info("Skipping the matching.")
+ return
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ Model = dynamic_load(matchers, conf["model"]["name"])
+ model = Model(conf["model"]).eval().to(device)
+
+ dataset = FeaturePairsDataset(pairs, feature_path_q, feature_path_ref)
+ loader = torch.utils.data.DataLoader(
+ dataset, num_workers=4, batch_size=1, shuffle=False, pin_memory=True
+ )
+ writer_queue = WorkQueue(partial(writer_fn, match_path=match_path), 5)
+
+ for idx, data in enumerate(tqdm(loader, smoothing=0.1)):
+ data = {
+ k: v if k.startswith("image") else v.to(device, non_blocking=True)
+ for k, v in data.items()
+ }
+ pred = model(data)
+ pair = names_to_pair(*pairs[idx])
+ writer_queue.put((pair, pred))
+ writer_queue.join()
+ logging.info("Finished exporting matches.")
+
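+# match_from_paths streams FeaturePairsDataset through the matcher on the GPU when
+# available, while a WorkQueue of writer threads serializes predictions to the h5
+# file so disk I/O overlaps with matching.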
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--export_dir', type=Path, required=True)
+ parser.add_argument('--features', type=str, required=True)
+ parser.add_argument('--pairs', type=Path, required=True)
+ parser.add_argument('--conf', type=str, required=True, choices=list(confs.keys()))
+ args = parser.parse_args()
+ main(confs[args.conf], args.pairs, args.features, args.export_dir)
diff --git a/third_party/pram/localization/matchers/__init__.py b/third_party/pram/localization/matchers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..7edac76f912b1e5ebb0401b6cc7a5d3c64ce963a
--- /dev/null
+++ b/third_party/pram/localization/matchers/__init__.py
@@ -0,0 +1,3 @@
+def get_matcher(matcher):
+ mod = __import__(f'{__name__}.{matcher}', fromlist=[''])
+ return getattr(mod, 'Model')
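+# Note: get_matcher expects a matcher module to expose a `Model` attribute. The
+# scripts in this package instead resolve matcher classes with `dynamic_load`
+# from localization.base_model, passing conf['model']['name'] (e.g. 'adagml').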
diff --git a/third_party/pram/localization/matchers/adagml.py b/third_party/pram/localization/matchers/adagml.py
new file mode 100644
index 0000000000000000000000000000000000000000..31a4bd2aa74bef934543b79567f148f5b8b7b092
--- /dev/null
+++ b/third_party/pram/localization/matchers/adagml.py
@@ -0,0 +1,41 @@
+# -*- coding: UTF-8 -*-
+'''=================================================
+@Project -> File pram -> adagml
+@IDE PyCharm
+@Author fx221@cam.ac.uk
+@Date 11/02/2024 14:34
+=================================================='''
+import torch
+from localization.base_model import BaseModel
+from nets.adagml import AdaGML as GMatcher
+
+
+class AdaGML(BaseModel):
+ default_config = {
+ 'descriptor_dim': 128,
+ 'hidden_dim': 256,
+ 'weights': 'indoor',
+ 'keypoint_encoder': [32, 64, 128, 256],
+        'GNN_layers': ['self', 'cross'] * 9,  # 9 alternating self/cross pairs (18 layers in total)
+ 'sinkhorn_iterations': 20,
+ 'match_threshold': 0.2,
+ 'with_pose': False,
+ 'n_layers': 9,
+ 'n_min_tokens': 256,
+ 'with_sinkhorn': True,
+ 'weight_path': None,
+ }
+
+ required_inputs = [
+ 'image0', 'keypoints0', 'scores0', 'descriptors0',
+ 'image1', 'keypoints1', 'scores1', 'descriptors1',
+ ]
+
+ def _init(self, conf):
+ self.net = GMatcher(config=conf).eval()
+ state_dict = torch.load(conf['weight_path'], map_location='cpu')['model']
+ self.net.load_state_dict(state_dict, strict=True)
+
+ def _forward(self, data):
+ with torch.no_grad():
+ return self.net(data)
diff --git a/third_party/pram/localization/matchers/gm.py b/third_party/pram/localization/matchers/gm.py
new file mode 100644
index 0000000000000000000000000000000000000000..2484cdb521d28a8cc0b5be7148919cd46bc67b32
--- /dev/null
+++ b/third_party/pram/localization/matchers/gm.py
@@ -0,0 +1,44 @@
+# -*- coding: UTF-8 -*-
+'''=================================================
+@Project -> File r2d2 -> gm
+@IDE PyCharm
+@Author fx221@cam.ac.uk
+@Date 25/05/2023 10:09
+=================================================='''
+import torch
+from localization.base_model import BaseModel
+from nets.gm import GM as GMatcher
+
+
+class GM(BaseModel):
+ default_config = {
+ 'descriptor_dim': 128,
+ 'hidden_dim': 256,
+ 'weights': 'indoor',
+ 'keypoint_encoder': [32, 64, 128, 256],
+        'GNN_layers': ['self', 'cross'] * 9,  # 9 alternating self/cross pairs (18 layers in total)
+ 'sinkhorn_iterations': 20,
+ 'match_threshold': 0.2,
+ 'with_pose': False,
+ 'n_layers': 9,
+ 'n_min_tokens': 256,
+ 'with_sinkhorn': True,
+
+ 'ac_fn': 'relu',
+ 'norm_fn': 'bn',
+ 'weight_path': None,
+ }
+
+ required_inputs = [
+ 'image0', 'keypoints0', 'scores0', 'descriptors0',
+ 'image1', 'keypoints1', 'scores1', 'descriptors1',
+ ]
+
+ def _init(self, conf):
+ self.net = GMatcher(config=conf).eval()
+ state_dict = torch.load(conf['weight_path'], map_location='cpu')['model']
+ self.net.load_state_dict(state_dict, strict=True)
+
+ def _forward(self, data):
+ with torch.no_grad():
+ return self.net(data)
diff --git a/third_party/pram/localization/matchers/gml.py b/third_party/pram/localization/matchers/gml.py
new file mode 100644
index 0000000000000000000000000000000000000000..2f9acdeaf3c7bd9670c1f7c49e2bbf709f1e8b4a
--- /dev/null
+++ b/third_party/pram/localization/matchers/gml.py
@@ -0,0 +1,45 @@
+# -*- coding: UTF-8 -*-
+'''=================================================
+@Project -> File localizer -> gml
+@IDE PyCharm
+@Author fx221@cam.ac.uk
+@Date 15/01/2024 11:01
+=================================================='''
+import torch
+from localization.base_model import BaseModel
+from nets.gml import GML as GMatcher
+
+
+class GML(BaseModel):
+ default_config = {
+ 'descriptor_dim': 128,
+ 'hidden_dim': 256,
+ 'weights': 'indoor',
+ 'keypoint_encoder': [32, 64, 128, 256],
+        'GNN_layers': ['self', 'cross'] * 9,  # 9 alternating self/cross pairs (18 layers in total)
+ 'sinkhorn_iterations': 20,
+ 'match_threshold': 0.2,
+ 'with_pose': False,
+ 'n_layers': 9,
+ 'n_min_tokens': 256,
+ 'with_sinkhorn': True,
+
+ 'ac_fn': 'relu',
+ 'norm_fn': 'bn',
+ 'weight_path': None,
+ }
+
+ required_inputs = [
+ 'image0', 'keypoints0', 'scores0', 'descriptors0',
+ 'image1', 'keypoints1', 'scores1', 'descriptors1',
+ ]
+
+ def _init(self, conf):
+ self.net = GMatcher(config=conf).eval()
+ state_dict = torch.load(conf['weight_path'], map_location='cpu')['model']
+ self.net.load_state_dict(state_dict, strict=True)
+
+ def _forward(self, data):
+ with torch.no_grad():
+ # print(data['keypoints0'].shape, data['descriptors0'].shape, data['image0'].shape)
+ return self.net(data)
diff --git a/third_party/pram/localization/matchers/nearest_neighbor.py b/third_party/pram/localization/matchers/nearest_neighbor.py
new file mode 100644
index 0000000000000000000000000000000000000000..42b8078747535a269dab6131b4f20c0857c36c03
--- /dev/null
+++ b/third_party/pram/localization/matchers/nearest_neighbor.py
@@ -0,0 +1,56 @@
+import torch
+from localization.base_model import BaseModel
+
+
+def find_nn(sim, ratio_thresh, distance_thresh):
+ sim_nn, ind_nn = sim.topk(2 if ratio_thresh else 1, dim=-1, largest=True)
+ dist_nn = 2 * (1 - sim_nn)
+ mask = torch.ones(ind_nn.shape[:-1], dtype=torch.bool, device=sim.device)
+ if ratio_thresh:
+ mask = mask & (dist_nn[..., 0] <= (ratio_thresh ** 2) * dist_nn[..., 1])
+ if distance_thresh:
+ mask = mask & (dist_nn[..., 0] <= distance_thresh ** 2)
+ matches = torch.where(mask, ind_nn[..., 0], ind_nn.new_tensor(-1))
+ scores = torch.where(mask, (sim_nn[..., 0] + 1) / 2, sim_nn.new_tensor(0))
+ return matches, scores
+
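+# find_nn assumes L2-normalized descriptors: squared Euclidean distance equals
+# 2 * (1 - cosine similarity), so the ratio and distance thresholds are compared
+# in squared-distance space (hence the ** 2 above).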
+
+def mutual_check(m0, m1):
+ inds0 = torch.arange(m0.shape[-1], device=m0.device)
+ loop = torch.gather(m1, -1, torch.where(m0 > -1, m0, m0.new_tensor(0)))
+ ok = (m0 > -1) & (inds0 == loop)
+ m0_new = torch.where(ok, m0, m0.new_tensor(-1))
+ return m0_new
+
+
+class NearestNeighbor(BaseModel):
+ default_conf = {
+ 'ratio_threshold': None,
+ 'distance_threshold': None,
+ 'do_mutual_check': True,
+ }
+ required_inputs = ['descriptors0', 'descriptors1']
+
+ def _init(self, conf):
+ pass
+
+ def _forward(self, data):
+ sim = torch.einsum(
+ 'bdn,bdm->bnm', data['descriptors0'], data['descriptors1'])
+ matches0, scores0 = find_nn(
+ sim, self.conf['ratio_threshold'], self.conf['distance_threshold'])
+ # matches1, scores1 = find_nn(
+ # sim.transpose(1, 2), self.conf['ratio_threshold'],
+ # self.conf['distance_threshold'])
+ if self.conf['do_mutual_check']:
+ # print("with mutual check")
+ matches1, scores1 = find_nn(
+ sim.transpose(1, 2), self.conf['ratio_threshold'],
+ self.conf['distance_threshold'])
+ matches0 = mutual_check(matches0, matches1)
+ # else:
+ # print("no mutual check")
+ return {
+ 'matches0': matches0,
+ 'matching_scores0': scores0,
+ }
diff --git a/third_party/pram/localization/multimap3d.py b/third_party/pram/localization/multimap3d.py
new file mode 100644
index 0000000000000000000000000000000000000000..6100b4f4bfeb1d3f8bc94598723979e830bf4172
--- /dev/null
+++ b/third_party/pram/localization/multimap3d.py
@@ -0,0 +1,379 @@
+# -*- coding: UTF-8 -*-
+'''=================================================
+@Project -> File pram -> multimap3d
+@IDE PyCharm
+@Author fx221@cam.ac.uk
+@Date 04/03/2024 13:47
+=================================================='''
+import numpy as np
+import os
+import os.path as osp
+import time
+import cv2
+import torch
+import yaml
+from copy import deepcopy
+from recognition.vis_seg import vis_seg_point, generate_color_dic, vis_inlier, plot_matches
+from localization.base_model import dynamic_load
+import localization.matchers as matchers
+from localization.match_features_batch import confs as matcher_confs
+from nets.gm import GM
+from tools.common import resize_img
+from localization.singlemap3d import SingleMap3D
+from localization.frame import Frame
+
+
+class MultiMap3D:
+ def __init__(self, config, viewer=None, save_dir=None):
+ self.config = config
+ self.save_dir = save_dir
+
+ self.scenes = []
+ self.sid_scene_name = []
+ self.sub_maps = {}
+ self.scene_name_start_sid = {}
+
+ self.loc_config = config['localization']
+ self.save_dir = save_dir
+ if self.save_dir is not None:
+ os.makedirs(self.save_dir, exist_ok=True)
+
+ self.matching_method = config['localization']['matching_method']
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
+ Model = dynamic_load(matchers, self.matching_method)
+ self.matcher = Model(matcher_confs[self.matching_method]['model']).eval().to(device)
+
+ self.initialize_map(config=config)
+ self.loc_config = config['localization']
+
+ self.viewer = viewer
+
+ # options
+ self.do_refinement = self.loc_config['do_refinement']
+ self.refinement_method = self.loc_config['refinement_method']
+ self.semantic_matching = self.loc_config['semantic_matching']
+ self.do_pre_filtering = self.loc_config['pre_filtering_th'] > 0
+ self.pre_filtering_th = self.loc_config['pre_filtering_th']
+
+ def initialize_map(self, config):
+ n_class = 0
+ datasets = config['dataset']
+
+ for name in datasets:
+ config_path = osp.join(config['config_path'], '{:s}.yaml'.format(name))
+ dataset_name = name
+
+ with open(config_path, 'r') as f:
+ scene_config = yaml.load(f, Loader=yaml.Loader)
+
+ scenes = scene_config['scenes']
+ for sid, scene in enumerate(scenes):
+ self.scenes.append(name + '/' + scene)
+
+ new_config = deepcopy(config)
+ new_config['dataset_path'] = osp.join(config['dataset_path'], dataset_name, scene)
+ new_config['landmark_path'] = osp.join(config['landmark_path'], dataset_name, scene)
+ new_config['n_cluster'] = scene_config[scene]['n_cluster']
+ new_config['cluster_mode'] = scene_config[scene]['cluster_mode']
+ new_config['cluster_method'] = scene_config[scene]['cluster_method']
+ new_config['gt_pose_path'] = scene_config[scene]['gt_pose_path']
+ new_config['image_path_prefix'] = scene_config[scene]['image_path_prefix']
+ sub_map = SingleMap3D(config=new_config,
+ matcher=self.matcher,
+ with_compress=config['localization']['with_compress'],
+ start_sid=n_class)
+ self.sub_maps[dataset_name + '/' + scene] = sub_map
+
+ n_scene_class = scene_config[scene]['n_cluster']
+ self.sid_scene_name = self.sid_scene_name + [dataset_name + '/' + scene for ni in range(n_scene_class)]
+ self.scene_name_start_sid[dataset_name + '/' + scene] = n_class
+ n_class = n_class + n_scene_class
+
+ # break
+ print('Load {} sub_maps from {} datasets'.format(len(self.sub_maps), len(datasets)))
+
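+    # Segment ids from all scenes are packed into one global label space:
+    # scene_name_start_sid[scene] stores the offset added to that scene's local
+    # segment ids, and sid_scene_name maps a global sid back to its scene.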
+ def run(self, q_frame: Frame):
+ show = self.loc_config['show']
+ seg_color = generate_color_dic(n_seg=2000)
+ if show:
+ cv2.namedWindow('loc', cv2.WINDOW_NORMAL)
+
+ q_loc_segs = self.process_segmentations(segs=torch.from_numpy(q_frame.segmentations),
+ topk=self.loc_config['seg_k'])
+ q_pred_segs_top1 = q_frame.seg_ids # initial results
+
+ q_scene_name = q_frame.scene_name
+ q_name = q_frame.name
+ q_full_name = osp.join(q_scene_name, q_name)
+
+ q_loc_sids = {}
+ for v in q_loc_segs:
+ q_loc_sids[v[0]] = (v[1], v[2])
+ query_sids = list(q_loc_sids.keys())
+
+ for i, sid in enumerate(query_sids):
+ t_start = time.time()
+ q_kpt_ids = q_loc_sids[sid][0]
+ print(q_scene_name, q_name, sid)
+
+            sid = sid - 1  # predicted segment ids start from 1 (0 is background); shift to 0-based
+
+ pred_scene_name = self.sid_scene_name[sid]
+ start_seg_id = self.scene_name_start_sid[pred_scene_name]
+ pred_sid_in_sub_scene = sid - self.scene_name_start_sid[pred_scene_name]
+ pred_sub_map = self.sub_maps[pred_scene_name]
+ pred_image_path_prefix = pred_sub_map.image_path_prefix
+
+ print('pred/gt scene: {:s}, {:s}, sid: {:d}'.format(pred_scene_name, q_scene_name, pred_sid_in_sub_scene))
+ print('{:s}/{:s}, pred: {:s}, sid: {:d}, order: {:d}'.format(q_scene_name, q_name, pred_scene_name, sid,
+ i))
+
+ if (q_kpt_ids.shape[0] >= self.loc_config['min_kpts']
+ and self.semantic_matching
+ and pred_sub_map.check_semantic_consistency(q_frame=q_frame,
+ sid=pred_sid_in_sub_scene,
+ overlap_ratio=0.5)):
+ semantic_matching = True
+ else:
+ q_kpt_ids = np.arange(q_frame.keypoints.shape[0])
+ semantic_matching = False
+ print_text = f'Semantic matching - {semantic_matching}! Query kpts {q_kpt_ids.shape[0]} for {i}th seg {sid}'
+ print(print_text)
+ ret = pred_sub_map.localize_with_ref_frame(q_frame=q_frame,
+ q_kpt_ids=q_kpt_ids,
+ sid=pred_sid_in_sub_scene,
+ semantic_matching=semantic_matching)
+
+ q_frame.time_loc = q_frame.time_loc + time.time() - t_start # accumulate tracking time
+
+ if show:
+ reference_frame = pred_sub_map.reference_frames[ret['reference_frame_id']]
+ ref_img = cv2.imread(osp.join(self.config['dataset_path'], pred_scene_name, pred_image_path_prefix,
+ reference_frame.name))
+ q_img_seg = vis_seg_point(img=q_frame.image, kpts=q_frame.keypoints[q_kpt_ids, :2],
+ segs=q_frame.seg_ids[q_kpt_ids] + 1,
+ seg_color=seg_color)
+ matched_points3D_ids = ret['matched_point3D_ids']
+ ref_sids = np.array([pred_sub_map.point3Ds[v].seg_id for v in matched_points3D_ids]) + \
+ self.scene_name_start_sid[pred_scene_name] + 1 # start from 1 as bg is 0
+ ref_img_seg = vis_seg_point(img=ref_img, kpts=ret['matched_ref_keypoints'], segs=ref_sids,
+ seg_color=seg_color)
+ q_matched_kpts = ret['matched_keypoints']
+ ref_matched_kpts = ret['matched_ref_keypoints']
+ img_loc_matching = plot_matches(img1=q_img_seg, img2=ref_img_seg,
+ pts1=q_matched_kpts, pts2=ref_matched_kpts,
+ inliers=np.array([True for i in range(q_matched_kpts.shape[0])]),
+ radius=9, line_thickness=3
+ )
+
+ q_frame.image_matching_tmp = img_loc_matching
+ q_frame.reference_frame_name_tmp = osp.join(self.config['dataset_path'],
+ pred_scene_name,
+ pred_image_path_prefix,
+ reference_frame.name)
+ # ret['image_matching'] = img_loc_matching
+ # ret['reference_frame_name'] = osp.join(self.config['dataset_path'],
+ # pred_scene_name,
+ # pred_image_path_prefix,
+ # reference_frame.name)
+ q_ref_img_matching = np.hstack([resize_img(q_img_seg, nh=512),
+ resize_img(ref_img_seg, nh=512),
+ resize_img(img_loc_matching, nh=512)])
+
+ ret['order'] = i
+ ret['matched_scene_name'] = pred_scene_name
+ if not ret['success']:
+ num_matches = ret['matched_keypoints'].shape[0]
+ num_inliers = ret['num_inliers']
+ print_text = f'Localization failed with {num_matches}/{q_kpt_ids.shape[0]} matches and {num_inliers} inliers, order {i}'
+ print(print_text)
+
+ if show:
+ show_text = 'FAIL! order: {:d}/{:d}-{:d}/{:d}'.format(i, len(q_loc_segs),
+ num_matches,
+ q_kpt_ids.shape[0])
+ q_img_inlier = vis_inlier(img=q_img_seg, kpts=ret['matched_keypoints'], inliers=ret['inliers'],
+ radius=9 + 2, thickness=2)
+ q_img_inlier = cv2.putText(img=q_img_inlier, text=show_text, org=(30, 30),
+ fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=1, color=(0, 0, 255),
+ thickness=2, lineType=cv2.LINE_AA)
+ q_frame.image_inlier_tmp = q_img_inlier
+ q_img_loc = np.hstack([resize_img(q_ref_img_matching, nh=512), resize_img(q_img_inlier, nh=512)])
+ cv2.imshow('loc', q_img_loc)
+ key = cv2.waitKey(self.loc_config['show_time'])
+ if key == ord('q'):
+ cv2.destroyAllWindows()
+ exit(0)
+ continue
+
+ if show:
+ q_err, t_err = q_frame.compute_pose_error()
+ num_matches = ret['matched_keypoints'].shape[0]
+ num_inliers = ret['num_inliers']
+ show_text = 'order: {:d}/{:d}, k/m/i: {:d}/{:d}/{:d}'.format(
+ i, len(q_loc_segs), q_kpt_ids.shape[0], num_matches, num_inliers)
+ q_img_inlier = vis_inlier(img=q_img_seg, kpts=ret['matched_keypoints'], inliers=ret['inliers'],
+ radius=9 + 2, thickness=2)
+ q_img_inlier = cv2.putText(img=q_img_inlier, text=show_text, org=(30, 30),
+ fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=1, color=(0, 0, 255),
+ thickness=2, lineType=cv2.LINE_AA)
+ show_text = 'r_err:{:.2f}, t_err:{:.2f}'.format(q_err, t_err)
+ q_img_inlier = cv2.putText(img=q_img_inlier, text=show_text, org=(30, 80),
+ fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=1, color=(0, 0, 255),
+ thickness=2, lineType=cv2.LINE_AA)
+ q_frame.image_inlier_tmp = q_img_inlier
+
+ q_img_loc = np.hstack([resize_img(q_ref_img_matching, nh=512), resize_img(q_img_inlier, nh=512)])
+
+ cv2.imshow('loc', q_img_loc)
+ key = cv2.waitKey(self.loc_config['show_time'])
+ if key == ord('q'):
+ cv2.destroyAllWindows()
+ exit(0)
+
+ success = self.verify_and_update(q_frame=q_frame, ret=ret)
+
+ if not success:
+ continue
+ else:
+ break
+
+ if q_frame.tracking_status is None:
+ print('Failed to find a proper reference frame.')
+ return False
+
+ # do refinement
+ if not self.do_refinement:
+ return True
+ else:
+ t_start = time.time()
+ pred_sub_map = self.sub_maps[q_frame.matched_scene_name]
+ if q_frame.tracking_status is True and np.sum(q_frame.matched_inliers) >= 64:
+ ret = pred_sub_map.refine_pose(q_frame=q_frame, refinement_method=self.loc_config['refinement_method'])
+ else:
+ ret = pred_sub_map.refine_pose(q_frame=q_frame,
+ refinement_method='matching') # do not trust the pose for projection
+
+ q_frame.time_ref = time.time() - t_start
+
+ inlier_mask = np.array(ret['inliers'])
+
+ q_frame.qvec = ret['qvec']
+ q_frame.tvec = ret['tvec']
+ q_frame.matched_keypoints = ret['matched_keypoints'][inlier_mask]
+ q_frame.matched_keypoint_ids = ret['matched_keypoint_ids'][inlier_mask]
+ q_frame.matched_xyzs = ret['matched_xyzs'][inlier_mask]
+ q_frame.matched_point3D_ids = ret['matched_point3D_ids'][inlier_mask]
+ q_frame.matched_sids = ret['matched_sids'][inlier_mask]
+ q_frame.matched_inliers = np.array(ret['inliers'])[inlier_mask]
+
+ q_frame.refinement_reference_frame_ids = ret['refinement_reference_frame_ids']
+ q_frame.reference_frame_id = ret['reference_frame_id']
+
+ q_err, t_err = q_frame.compute_pose_error()
+ ref_full_name = q_frame.matched_scene_name + '/' + pred_sub_map.reference_frames[
+ q_frame.reference_frame_id].name
+ print_text = 'Localization of {:s} success with inliers {:d}/{:d} with ref_name: {:s}, order: {:d}, q_err: {:.2f}, t_err: {:.2f}'.format(
+ q_full_name, ret['num_inliers'], len(ret['inliers']), ref_full_name, q_frame.matched_order, q_err,
+ t_err)
+ print(print_text)
+
+ if show:
+ q_err, t_err = q_frame.compute_pose_error()
+ num_matches = ret['matched_keypoints'].shape[0]
+ num_inliers = ret['num_inliers']
+ show_text = 'Ref:{:d}/{:d},r_err:{:.2f}/t_err:{:.2f}'.format(num_matches, num_inliers, q_err,
+ t_err)
+ q_img_inlier = cv2.putText(img=q_img_inlier, text=show_text, org=(30, 130),
+ fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=1, color=(0, 0, 255),
+ thickness=2, lineType=cv2.LINE_AA)
+ q_frame.image_inlier = q_img_inlier
+
+ return True
+
+ def verify_and_update(self, q_frame: Frame, ret: dict):
+ num_matches = ret['matched_keypoints'].shape[0]
+ num_inliers = ret['num_inliers']
+ if q_frame.matched_keypoints is None or np.sum(q_frame.matched_inliers) < num_inliers:
+ self.update_query_frame(q_frame=q_frame, ret=ret)
+
+ q_err, t_err = q_frame.compute_pose_error(pred_qvec=ret['qvec'], pred_tvec=ret['tvec'])
+
+ if num_inliers < self.loc_config['min_inliers']:
+ print_text = 'Failed due to insufficient {:d} inliers, order {:d}, q_err: {:.2f}, t_err: {:.2f}'.format(
+ ret['num_inliers'], ret['order'], q_err, t_err)
+ print(print_text)
+ q_frame.tracking_status = False
+ return False
+ else:
+ print_text = 'Succeed! Find {}/{} 2D-3D inliers, order {:d}, q_err: {:.2f}, t_err: {:.2f}'.format(
+ num_inliers, num_matches, ret['order'], q_err, t_err)
+ print(print_text)
+ q_frame.tracking_status = True
+ return True
+
+ def update_query_frame(self, q_frame, ret):
+ q_frame.matched_scene_name = ret['matched_scene_name']
+ q_frame.reference_frame_id = ret['reference_frame_id']
+ q_frame.qvec = ret['qvec']
+ q_frame.tvec = ret['tvec']
+
+ inlier_mask = np.array(ret['inliers'])
+ q_frame.matched_keypoints = ret['matched_keypoints']
+ q_frame.matched_keypoint_ids = ret['matched_keypoint_ids']
+ q_frame.matched_xyzs = ret['matched_xyzs']
+ q_frame.matched_point3D_ids = ret['matched_point3D_ids']
+ q_frame.matched_sids = ret['matched_sids']
+ q_frame.matched_inliers = np.array(ret['inliers'])
+ q_frame.matched_order = ret['order']
+
+ if q_frame.image_inlier_tmp is not None:
+ q_frame.image_inlier = deepcopy(q_frame.image_inlier_tmp)
+ if q_frame.image_matching_tmp is not None:
+ q_frame.image_matching = deepcopy(q_frame.image_matching_tmp)
+ if q_frame.reference_frame_name_tmp is not None:
+ q_frame.reference_frame_name = q_frame.reference_frame_name_tmp
+
+ # inlier_mask = np.array(ret['inliers'])
+ # q_frame.matched_keypoints = ret['matched_keypoints'][inlier_mask]
+ # q_frame.matched_keypoint_ids = ret['matched_keypoint_ids'][inlier_mask]
+ # q_frame.matched_xyzs = ret['matched_xyzs'][inlier_mask]
+ # q_frame.matched_point3D_ids = ret['matched_point3D_ids'][inlier_mask]
+ # q_frame.matched_sids = ret['matched_sids'][inlier_mask]
+ # q_frame.matched_inliers = np.array(ret['inliers'])[inlier_mask]
+
+ # print('update_query_frame: ', q_frame.matched_keypoint_ids.shape, q_frame.matched_keypoints.shape,
+ # q_frame.matched_xyzs.shape, q_frame.matched_xyzs.shape, np.sum(q_frame.matched_inliers))
+
+ def process_segmentations(self, segs, topk=10):
+ pred_values, pred_ids = torch.topk(segs, k=segs.shape[-1], largest=True, dim=-1) # [N, C]
+ pred_values = pred_values.numpy()
+ pred_ids = pred_ids.numpy()
+
+ out = []
+ used_sids = []
+ for k in range(segs.shape[-1]):
+ values_k = pred_values[:, k]
+ ids_k = pred_ids[:, k]
+ uids = np.unique(ids_k)
+
+ out_k = []
+ for sid in uids:
+ if sid == 0:
+ continue
+ if sid in used_sids:
+ continue
+ used_sids.append(sid)
+ ids = np.where(ids_k == sid)[0]
+ score = np.mean(values_k[ids])
+ # score = np.median(values_k[ids])
+ # score = 100 - k
+ # out_k.append((ids.shape[0], sid - 1, ids, score))
+ out_k.append((ids.shape[0], sid, ids, score))
+
+ out_k = sorted(out_k, key=lambda item: item[0], reverse=True)
+ for v in out_k:
+ out.append((v[1], v[2], v[3])) # [sid, ids, score]
+ if len(out) >= topk:
+ return out
+ return out
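+# process_segmentations returns up to `topk` (sid, keypoint_ids, score) tuples:
+# segments are taken rank by rank from the top-k predictions, background (sid 0)
+# and already-selected segments are skipped, and within a rank segments with more
+# supporting keypoints come first.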
diff --git a/third_party/pram/localization/point3d.py b/third_party/pram/localization/point3d.py
new file mode 100644
index 0000000000000000000000000000000000000000..6e1babf427759c5f588f44023e9e1bf2648a073b
--- /dev/null
+++ b/third_party/pram/localization/point3d.py
@@ -0,0 +1,21 @@
+# -*- coding: UTF-8 -*-
+'''=================================================
+@Project -> File pram -> point3d
+@IDE PyCharm
+@Author fx221@cam.ac.uk
+@Date 04/03/2024 10:13
+=================================================='''
+import numpy as np
+
+
+class Point3D:
+ def __init__(self, id: int, xyz: np.ndarray, error: float, refframe_id: int, seg_id: int = None,
+ descriptor: np.ndarray = None, rgb: np.ndarray = None, frame_ids: np.ndarray = None):
+ self.id = id
+ self.xyz = xyz
+ self.rgb = rgb
+ self.error = error
+ self.seg_id = seg_id
+ self.refframe_id = refframe_id
+ self.frame_ids = frame_ids
+ self.descriptor = descriptor
diff --git a/third_party/pram/localization/pose_estimator.py b/third_party/pram/localization/pose_estimator.py
new file mode 100644
index 0000000000000000000000000000000000000000..5d28d6001d38cfd5f6f6135c611293ab5e83cf0a
--- /dev/null
+++ b/third_party/pram/localization/pose_estimator.py
@@ -0,0 +1,612 @@
+# -*- coding: UTF-8 -*-
+'''=================================================
+@Project -> File pram -> pose_estimation
+@IDE PyCharm
+@Author fx221@cam.ac.uk
+@Date 08/02/2024 11:01
+=================================================='''
+import torch
+import numpy as np
+import pycolmap
+import cv2
+import os
+import time
+import os.path as osp
+from collections import defaultdict
+
+
+def get_covisibility_frames(frame_id, all_images, points3D, covisibility_frame=50):
+ observed = all_images[frame_id].point3D_ids
+ covis = defaultdict(int)
+ for pid in observed:
+ if pid == -1:
+ continue
+ for img_id in points3D[pid].image_ids:
+ if img_id != frame_id:
+ covis[img_id] += 1
+
+ print('Find {:d} connected frames'.format(len(covis.keys())))
+
+ covis_ids = np.array(list(covis.keys()))
+ covis_num = np.array([covis[i] for i in covis_ids])
+
+ if len(covis_ids) <= covisibility_frame:
+ sel_covis_ids = covis_ids[np.argsort(-covis_num)]
+ else:
+ ind_top = np.argpartition(covis_num, -covisibility_frame)
+ ind_top = ind_top[-covisibility_frame:] # unsorted top k
+ ind_top = ind_top[np.argsort(-covis_num[ind_top])]
+ sel_covis_ids = [covis_ids[i] for i in ind_top]
+
+ print('Retain {:d} valid connected frames'.format(len(sel_covis_ids)))
+ return sel_covis_ids
+
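+# np.argpartition selects the top covisibility_frame counts in O(n); the extra
+# argsort only orders that subset, so the most covisible frames come first
+# without sorting all candidates.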
+
+def feature_matching(query_data, db_data, matcher):
+ db_3D_ids = db_data['db_3D_ids']
+ if db_3D_ids is None:
+ with torch.no_grad():
+ match_data = {
+ 'keypoints0': torch.from_numpy(query_data['keypoints'])[None].float().cuda(),
+ 'scores0': torch.from_numpy(query_data['scores'])[None].float().cuda(),
+ 'descriptors0': torch.from_numpy(query_data['descriptors'])[None].float().cuda(),
+ 'image0': torch.empty((1, 1,) + tuple(query_data['image_size'])[::-1]),
+
+ 'keypoints1': torch.from_numpy(db_data['keypoints'])[None].float().cuda(),
+ 'scores1': torch.from_numpy(db_data['scores'])[None].float().cuda(),
+ 'descriptors1': torch.from_numpy(db_data['descriptors'])[None].float().cuda(), # [B, N, D]
+ 'image1': torch.empty((1, 1,) + tuple(db_data['image_size'])[::-1]),
+ }
+ matches = matcher(match_data)['matches0'][0].cpu().numpy()
+ del match_data
+ else:
+ masks = (db_3D_ids != -1)
+ valid_ids = [i for i in range(masks.shape[0]) if masks[i]]
+ if len(valid_ids) == 0:
+ return np.zeros(shape=(query_data['keypoints'].shape[0],), dtype=int) - 1
+ with torch.no_grad():
+ match_data = {
+ 'keypoints0': torch.from_numpy(query_data['keypoints'])[None].float().cuda(),
+ 'scores0': torch.from_numpy(query_data['scores'])[None].float().cuda(),
+ 'descriptors0': torch.from_numpy(query_data['descriptors'])[None].float().cuda(),
+ 'image0': torch.empty((1, 1,) + tuple(query_data['image_size'])[::-1]),
+
+ 'keypoints1': torch.from_numpy(db_data['keypoints'])[masks][None].float().cuda(),
+ 'scores1': torch.from_numpy(db_data['scores'])[masks][None].float().cuda(),
+ 'descriptors1': torch.from_numpy(db_data['descriptors'][masks])[None].float().cuda(),
+ 'image1': torch.empty((1, 1,) + tuple(db_data['image_size'])[::-1]),
+ }
+ matches = matcher(match_data)['matches0'][0].cpu().numpy()
+ del match_data
+
+ for i in range(matches.shape[0]):
+ if matches[i] >= 0:
+ matches[i] = valid_ids[matches[i]]
+
+ return matches
+
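+# When db_3D_ids is provided, only db keypoints with a valid 3D point are passed
+# to the matcher, and the loop above remaps the matcher's indices back to the
+# original (unfiltered) keypoint indices via valid_ids.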
+
+def find_2D_3D_matches(query_data, db_id, points3D, feature_file, db_images, matcher, obs_th=0):
+ kpq = query_data['keypoints']
+ db_name = db_images[db_id].name
+ kpdb = feature_file[db_name]['keypoints'][()]
+ desc_db = feature_file[db_name]["descriptors"][()]
+ desc_db = desc_db.transpose()
+
+ # print('db_desc: ', desc_db.shape, query_data['descriptors'].shape)
+
+ points3D_ids = db_images[db_id].point3D_ids
+ matches = feature_matching(query_data=query_data,
+ db_data={
+ 'keypoints': kpdb,
+ 'scores': feature_file[db_name]['scores'][()],
+ 'descriptors': desc_db,
+ 'db_3D_ids': points3D_ids,
+ 'image_size': feature_file[db_name]['image_size'][()]
+ },
+ matcher=matcher)
+ mkpdb = []
+ mp3d_ids = []
+ q_ids = []
+ mkpq = []
+ mp3d = []
+ valid_matches = []
+ for idx in range(matches.shape[0]):
+ if matches[idx] == -1:
+ continue
+ if points3D_ids[matches[idx]] == -1:
+ continue
+ id_3D = points3D_ids[matches[idx]]
+
+ # reject 3d points without enough observations
+ if len(points3D[id_3D].image_ids) < obs_th:
+ continue
+ mp3d.append(points3D[id_3D].xyz)
+ mp3d_ids.append(id_3D)
+
+ mkpq.append(kpq[idx])
+ mkpdb.append(kpdb[matches[idx]])
+ q_ids.append(idx)
+ valid_matches.append(matches[idx])
+
+ mp3d = np.array(mp3d, float).reshape(-1, 3)
+ mkpq = np.array(mkpq, float).reshape(-1, 2) + 0.5
+ return mp3d, mkpq, mp3d_ids, q_ids
+
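+# find_2D_3D_matches returns matched 3D points, query keypoints shifted by +0.5
+# (COLMAP's pixel-center convention), the 3D point ids and the query keypoint ids;
+# 3D points observed by fewer than obs_th database images are rejected.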
+
+# hfnet, cvpr 2019
+def pose_estimator_hloc(qname, qinfo, db_ids, db_images, points3D,
+ feature_file,
+ thresh,
+ image_dir,
+ matcher,
+ log_info=None,
+ query_img_prefix='',
+ db_img_prefix=''):
+ kpq = feature_file[qname]['keypoints'][()]
+ score_q = feature_file[qname]['scores'][()]
+ desc_q = feature_file[qname]['descriptors'][()]
+ desc_q = desc_q.transpose()
+ imgsize_q = feature_file[qname]['image_size'][()]
+ query_data = {
+ 'keypoints': kpq,
+ 'scores': score_q,
+ 'descriptors': desc_q,
+ 'image_size': imgsize_q,
+ }
+
+ camera_model, width, height, params = qinfo
+ cam = pycolmap.Camera(model=camera_model, width=width, height=height, params=params)
+ cfg = {
+ 'model': camera_model,
+ 'width': width,
+ 'height': height,
+ 'params': params,
+ }
+ all_mkpts = []
+ all_mp3ds = []
+ all_points3D_ids = []
+ best_db_id = db_ids[0]
+ best_db_name = db_images[best_db_id].name
+
+ t_start = time.time()
+
+ for cluster_idx, db_id in enumerate(db_ids):
+ mp3d, mkpq, mp3d_ids, q_ids = find_2D_3D_matches(
+ query_data=query_data,
+ db_id=db_id,
+ points3D=points3D,
+ feature_file=feature_file,
+ db_images=db_images,
+ matcher=matcher,
+ obs_th=3)
+ if mp3d.shape[0] > 0:
+ all_mkpts.append(mkpq)
+ all_mp3ds.append(mp3d)
+ all_points3D_ids = all_points3D_ids + mp3d_ids
+
+ if len(all_mkpts) == 0:
+        print_text = 'Localization of {:s} failed; using the pose of {:s} as an approximation'.format(qname, best_db_name)
+ print(print_text)
+ if log_info is not None:
+ log_info = log_info + print_text + '\n'
+
+ qvec = db_images[best_db_id].qvec
+ tvec = db_images[best_db_id].tvec
+
+ return {
+ 'qvec': qvec,
+ 'tvec': tvec,
+ 'log_info': log_info,
+ 'qname': qname,
+ 'dbname': best_db_name,
+ 'num_inliers': 0,
+ 'order': -1,
+ 'keypoints_query': np.array([]),
+ 'points3D_ids': [],
+ 'time': time.time() - t_start,
+ }
+
+ all_mkpts = np.vstack(all_mkpts)
+ all_mp3ds = np.vstack(all_mp3ds)
+
+ ret = pycolmap.absolute_pose_estimation(all_mkpts, all_mp3ds, cam,
+ estimation_options={
+ "ransac": {"max_error": thresh}},
+ refinement_options={},
+ )
+ if ret is None:
+ ret = {'success': False, }
+ else:
+ ret['success'] = True
+ ret['qvec'] = ret['cam_from_world'].rotation.quat[[3, 0, 1, 2]]
+ ret['tvec'] = ret['cam_from_world'].translation
+ success = ret['success']
+
+ if success:
+ print_text = 'qname: {:s} localization success with {:d}/{:d} inliers'.format(qname, ret['num_inliers'],
+ all_mp3ds.shape[0])
+ print(print_text)
+ if log_info is not None:
+ log_info = log_info + print_text + '\n'
+
+ qvec = ret['qvec']
+ tvec = ret['tvec']
+ ret['cfg'] = cfg
+ num_inliers = ret['num_inliers']
+ inliers = ret['inliers']
+ return {
+ 'qvec': qvec,
+ 'tvec': tvec,
+ 'log_info': log_info,
+ 'qname': qname,
+ 'dbname': best_db_name,
+ 'num_inliers': num_inliers,
+ 'order': -1,
+ 'keypoints_query': np.array([all_mkpts[i] for i in range(len(inliers)) if inliers[i]]),
+ 'points3D_ids': [all_points3D_ids[i] for i in range(len(inliers)) if inliers[i]],
+ 'time': time.time() - t_start,
+ }
+ else:
+        print_text = 'Localization of {:s} failed; using the pose of {:s} as an approximation'.format(qname, best_db_name)
+ print(print_text)
+ if log_info is not None:
+ log_info = log_info + print_text + '\n'
+
+ qvec = db_images[best_db_id].qvec
+ tvec = db_images[best_db_id].tvec
+
+ return {
+ 'qvec': qvec,
+ 'tvec': tvec,
+ 'log_info': log_info,
+ 'qname': qname,
+ 'dbname': best_db_name,
+ 'num_inliers': 0,
+ 'order': -1,
+ 'keypoints_query': np.array([]),
+ 'points3D_ids': [],
+ 'time': time.time() - t_start,
+ }
+
+
+def pose_refinement(query_data,
+ query_cam, feature_file, db_frame_id, db_images, points3D, matcher,
+ covisibility_frame=50,
+ obs_th=3,
+ opt_th=12,
+ qvec=None,
+ tvec=None,
+ log_info='',
+ **kwargs,
+ ):
+ db_ids = get_covisibility_frames(frame_id=db_frame_id, all_images=db_images, points3D=points3D,
+ covisibility_frame=covisibility_frame)
+
+ mp3d = []
+ mkpq = []
+ mkpdb = []
+ all_3D_ids = []
+ all_score_q = []
+ kpq = query_data['keypoints']
+ for i, db_id in enumerate(db_ids):
+ db_name = db_images[db_id].name
+ kpdb = feature_file[db_name]['keypoints'][()]
+ scores_db = feature_file[db_name]['scores'][()]
+ imgsize_db = feature_file[db_name]['image_size'][()]
+ desc_db = feature_file[db_name]["descriptors"][()]
+ desc_db = desc_db.transpose()
+
+ points3D_ids = db_images[db_id].point3D_ids
+ if points3D_ids.size == 0:
+ print("No 3D points in this db image: ", db_name)
+ continue
+
+ matches = feature_matching(query_data=query_data,
+ db_data={'keypoints': kpdb,
+ 'scores': scores_db,
+ 'descriptors': desc_db,
+ 'image_size': imgsize_db,
+ 'db_3D_ids': points3D_ids,
+ },
+ matcher=matcher,
+ )
+ valid = np.where(matches > -1)[0]
+ valid = valid[points3D_ids[matches[valid]] != -1]
+ inliers = []
+ for idx in valid:
+ id_3D = points3D_ids[matches[idx]]
+ if len(points3D[id_3D].image_ids) < obs_th:
+ continue
+
+ inliers.append(True)
+
+ mp3d.append(points3D[id_3D].xyz)
+ mkpq.append(kpq[idx])
+ mkpdb.append(kpdb[matches[idx]])
+ all_3D_ids.append(id_3D)
+
+ mp3d = np.array(mp3d, float).reshape(-1, 3)
+ mkpq = np.array(mkpq, float).reshape(-1, 2) + 0.5
+ print_text = 'Get {:d} covisible frames with {:d} matches from cluster optimization'.format(len(db_ids),
+ mp3d.shape[0])
+ print(print_text)
+ if log_info is not None:
+ log_info += (print_text + '\n')
+
+ # cam = pycolmap.Camera(model=cfg['model'], params=cfg['params'])
+ ret = pycolmap.absolute_pose_estimation(mkpq, mp3d,
+ query_cam,
+ estimation_options={
+ "ransac": {"max_error": opt_th}},
+ refinement_options={},
+ )
+ if ret is None:
+ ret = {'success': False, }
+ else:
+ ret['success'] = True
+ ret['qvec'] = ret['cam_from_world'].rotation.quat[[3, 0, 1, 2]]
+ ret['tvec'] = ret['cam_from_world'].translation
+
+ if not ret['success']:
+ ret['mkpq'] = mkpq
+ ret['3D_ids'] = all_3D_ids
+ ret['db_ids'] = db_ids
+ ret['score_q'] = all_score_q
+ ret['log_info'] = log_info
+ ret['qvec'] = qvec
+ ret['tvec'] = tvec
+ ret['inliers'] = [False for i in range(mkpq.shape[0])]
+ ret['num_inliers'] = 0
+ ret['keypoints_query'] = np.array([])
+ ret['points3D_ids'] = []
+ return ret
+
+ ret_inliers = ret['inliers']
+ loc_keypoints_query = np.array([mkpq[i] for i in range(len(ret_inliers)) if ret_inliers[i]])
+ loc_points3D_ids = [all_3D_ids[i] for i in range(len(ret_inliers)) if ret_inliers[i]]
+
+ ret['mkpq'] = mkpq
+ ret['3D_ids'] = all_3D_ids
+ ret['db_ids'] = db_ids
+ ret['log_info'] = log_info
+ ret['keypoints_query'] = loc_keypoints_query
+ ret['points3D_ids'] = loc_points3D_ids
+
+ return ret
+
+
+# proposed in efficient large-scale localization by global instance recognition, cvpr 2022
+def pose_estimator_iterative(qname, qinfo, db_ids, db_images, points3D, feature_file, thresh, image_dir,
+ matcher,
+ inlier_th=50,
+ log_info=None,
+ do_covisibility_opt=False,
+ covisibility_frame=50,
+ vis_dir=None,
+ obs_th=0,
+ opt_th=12,
+ gt_qvec=None,
+ gt_tvec=None,
+ query_img_prefix='',
+ db_img_prefix='',
+ ):
+ print("qname: ", qname)
+ db_name_to_id = {image.name: i for i, image in db_images.items()}
+ # q_img = cv2.imread(osp.join(image_dir, query_img_prefix, qname))
+
+ kpq = feature_file[qname]['keypoints'][()]
+ score_q = feature_file[qname]['scores'][()]
+ imgsize_q = feature_file[qname]['image_size'][()]
+ desc_q = feature_file[qname]['descriptors'][()]
+ desc_q = desc_q.transpose() # [N D]
+ query_data = {
+ 'keypoints': kpq,
+ 'scores': score_q,
+ 'descriptors': desc_q,
+ 'image_size': imgsize_q,
+ }
+ camera_model, width, height, params = qinfo
+
+ best_results = {
+ 'tvec': None,
+ 'qvec': None,
+ 'num_inliers': 0,
+ 'single_num_inliers': 0,
+ 'db_id': -1,
+ 'order': -1,
+ 'qname': qname,
+ 'optimize': False,
+ 'dbname': db_images[db_ids[0]].name,
+ "ret_source": "",
+ "inliers": [],
+ 'keypoints_query': np.array([]),
+ 'points3D_ids': [],
+ }
+
+ cam = pycolmap.Camera(model=camera_model, width=width, height=height, params=params)
+
+ for cluster_idx, db_id in enumerate(db_ids):
+ db_name = db_images[db_id].name
+ mp3d, mkpq, mp3d_ids, q_ids = find_2D_3D_matches(
+ query_data=query_data,
+ db_id=db_id,
+ points3D=points3D,
+ feature_file=feature_file,
+ db_images=db_images,
+ matcher=matcher,
+ obs_th=obs_th)
+
+ if mp3d.shape[0] < 8:
+ print_text = "qname: {:s} dbname: {:s}({:d}/{:d}) failed because of insufficient 3d points {:d}".format(
+ qname,
+ db_name,
+ cluster_idx + 1,
+ len(db_ids),
+ mp3d.shape[0])
+ print(print_text)
+ if log_info is not None:
+ log_info += (print_text + '\n')
+ continue
+
+ ret = pycolmap.absolute_pose_estimation(mkpq, mp3d, cam,
+ estimation_options={
+ "ransac": {"max_error": thresh}},
+ refinement_options={},
+ )
+
+ if ret is None:
+ ret = {'success': False, }
+ else:
+ ret['success'] = True
+ ret['qvec'] = ret['cam_from_world'].rotation.quat[[3, 0, 1, 2]]
+ ret['tvec'] = ret['cam_from_world'].translation
+
+ if not ret["success"]:
+ print_text = "qname: {:s} dbname: {:s} ({:d}/{:d}) failed after matching".format(qname, db_name,
+ cluster_idx + 1,
+ len(db_ids))
+ print(print_text)
+ if log_info is not None:
+ log_info += (print_text + '\n')
+ continue
+
+ inliers = ret['inliers']
+ num_inliers = ret['num_inliers']
+ inlier_p3d_ids = [mp3d_ids[i] for i in range(len(inliers)) if inliers[i]]
+ inlier_mkpq = [mkpq[i] for i in range(len(inliers)) if inliers[i]]
+ loc_keypoints_query = np.array(inlier_mkpq)
+ loc_points3D_ids = inlier_p3d_ids
+
+ if ret['num_inliers'] > best_results['num_inliers']:
+ best_results['qvec'] = ret['qvec']
+ best_results['tvec'] = ret['tvec']
+ best_results['inlier'] = ret['inliers']
+ best_results['num_inliers'] = ret['num_inliers']
+ best_results['dbname'] = db_name
+ best_results['order'] = cluster_idx + 1
+ best_results['keypoints_query'] = loc_keypoints_query
+ best_results['points3D_ids'] = loc_points3D_ids
+
+ if ret['num_inliers'] < inlier_th:
+ print_text = "qname: {:s} dbname: {:s} ({:d}/{:d}) failed insufficient {:d} inliers".format(qname,
+ db_name,
+ cluster_idx + 1,
+ len(db_ids),
+ num_inliers,
+ )
+ print(print_text)
+ if log_info is not None:
+ log_info += (print_text + '\n')
+ continue
+
+ print_text = "qname: {:s} dbname: {:s} ({:d}/{:d}) initialization succeed with {:d} inliers".format(
+ qname,
+ db_name,
+ cluster_idx + 1,
+ len(db_ids),
+ ret["num_inliers"]
+ )
+ print(print_text)
+ if log_info is not None:
+ log_info += (print_text + '\n')
+
+ if do_covisibility_opt:
+            ret = pose_refinement(query_data=query_data,
+ query_cam=cam,
+ feature_file=feature_file,
+ db_frame_id=db_id,
+ db_images=db_images,
+ points3D=points3D,
+ thresh=thresh,
+ covisibility_frame=covisibility_frame,
+ matcher=matcher,
+ obs_th=obs_th,
+ opt_th=opt_th,
+ qvec=ret['qvec'],
+ tvec=ret['tvec'],
+ log_info='',
+ image_dir=image_dir,
+ vis_dir=vis_dir,
+ gt_qvec=gt_qvec,
+ gt_tvec=gt_tvec,
+ )
+
+ loc_keypoints_query = ret['keypoints_query']
+ loc_points3D_ids = ret['points3D_ids']
+
+ log_info = log_info + ret['log_info']
+ print_text = 'Find {:d} inliers after optimization'.format(ret['num_inliers'])
+ print(print_text)
+ if log_info is not None:
+ log_info += (print_text + "\n")
+
+ # localization succeed
+ qvec = ret['qvec']
+ tvec = ret['tvec']
+ num_inliers = ret['num_inliers']
+ best_results['keypoints_query'] = loc_keypoints_query
+ best_results['points3D_ids'] = loc_points3D_ids
+
+ best_results['qvec'] = qvec
+ best_results['tvec'] = tvec
+ best_results['num_inliers'] = num_inliers
+ best_results['log_info'] = log_info
+
+ return best_results
+
+ if best_results['num_inliers'] >= 10: # 20 for aachen
+ qvec = best_results['qvec']
+ tvec = best_results['tvec']
+ best_dbname = best_results['dbname']
+
+ best_results['keypoints_query'] = loc_keypoints_query
+ best_results['points3D_ids'] = loc_points3D_ids
+
+ if do_covisibility_opt:
+            ret = pose_refinement(query_data=query_data,
+ query_cam=cam,
+ feature_file=feature_file,
+ db_frame_id=db_name_to_id[best_dbname],
+ db_images=db_images,
+ points3D=points3D,
+ thresh=thresh,
+ covisibility_frame=covisibility_frame,
+ matcher=matcher,
+ obs_th=obs_th,
+ opt_th=opt_th,
+ qvec=qvec,
+ tvec=tvec,
+ log_info='',
+ image_dir=image_dir,
+ vis_dir=vis_dir,
+ gt_qvec=gt_qvec,
+ gt_tvec=gt_tvec,
+ )
+
+ # localization succeed
+ qvec = ret['qvec']
+ tvec = ret['tvec']
+ num_inliers = ret['num_inliers']
+ best_results['keypoints_query'] = loc_keypoints_query
+ best_results['points3D_ids'] = loc_points3D_ids
+
+ best_results['qvec'] = qvec
+ best_results['tvec'] = tvec
+ best_results['num_inliers'] = num_inliers
+ best_results['log_info'] = log_info
+
+ return best_results
+
+    closest = db_images[db_ids[0]]
+    print_text = 'Localization of {:s} failed; using the pose of {:s} as an approximation'.format(qname, closest.name)
+ print(print_text)
+ if log_info is not None:
+ log_info += (print_text + '\n')
+
+ best_results['qvec'] = closest.qvec
+ best_results['tvec'] = closest.tvec
+ best_results['num_inliers'] = -1
+ best_results['log_info'] = log_info
+
+ return best_results
diff --git a/third_party/pram/localization/refframe.py b/third_party/pram/localization/refframe.py
new file mode 100644
index 0000000000000000000000000000000000000000..b7eeafd44557ffdfda5829dab00dd5df125148b4
--- /dev/null
+++ b/third_party/pram/localization/refframe.py
@@ -0,0 +1,147 @@
+# -*- coding: UTF-8 -*-
+'''=================================================
+@Project -> File pram -> refframe
+@IDE PyCharm
+@Author fx221@cam.ac.uk
+@Date 04/03/2024 10:06
+=================================================='''
+import numpy as np
+from localization.camera import Camera
+from colmap_utils.camera_intrinsics import intrinsics_from_camera
+from colmap_utils.read_write_model import qvec2rotmat
+
+
+class RefFrame:
+ def __init__(self, camera: Camera, id: int, qvec: np.ndarray, tvec: np.ndarray,
+ point3D_ids: np.ndarray = None, keypoints: np.ndarray = None,
+ name: str = None, scene_name: str = None):
+ self.camera = camera
+ self.id = id
+ self.qvec = qvec
+ self.tvec = tvec
+ self.name = name
+ self.scene_name = scene_name
+ self.width = camera.width
+ self.height = camera.height
+ self.image_size = np.array([self.height, self.width])
+
+ self.point3D_ids = point3D_ids
+ self.keypoints = keypoints
+ self.descriptors = None
+ self.keypoint_segs = None
+ self.xyzs = None
+
+ def get_keypoints_by_sid(self, sid: int):
+ mask = (self.keypoint_segs == sid)
+ return {
+ 'point3D_ids': self.point3D_ids[mask],
+ 'keypoints': self.keypoints[mask][:, :2],
+ 'descriptors': self.descriptors[mask],
+ 'scores': self.keypoints[mask][:, 2],
+ 'xyzs': self.xyzs[mask],
+ 'camera': self.camera,
+ }
+
+ def get_keypoints(self):
+ return {
+ 'point3D_ids': self.point3D_ids,
+ 'keypoints': self.keypoints[:, :2],
+ 'descriptors': self.descriptors,
+ 'scores': self.keypoints[:, 2],
+ 'xyzs': self.xyzs,
+ 'camera': self.camera,
+ }
+
+ def associate_keypoints_with_point3Ds(self, point3Ds: dict):
+ xyzs = []
+ descs = []
+ scores = []
+ p3d_ids = []
+ kpt_sids = []
+ for i, v in enumerate(self.point3D_ids):
+ if v in point3Ds.keys():
+ p3d = point3Ds[v]
+ p3d_ids.append(v)
+ xyzs.append(p3d.xyz)
+ descs.append(p3d.descriptor)
+ scores.append(p3d.error)
+
+ kpt_sids.append(p3d.seg_id)
+
+ xyzs = np.array(xyzs)
+ if xyzs.shape[0] == 0:
+ return False
+
+ descs = np.array(descs)
+ scores = 1 / np.clip(np.array(scores) * 5, a_min=1., a_max=20.)
+ p3d_ids = np.array(p3d_ids)
+ uvs = self.project(xyzs=xyzs)
+ self.keypoints = np.hstack([uvs, scores.reshape(-1, 1)])
+ self.descriptors = descs
+ self.point3D_ids = p3d_ids
+ self.xyzs = xyzs
+ self.keypoint_segs = np.array(kpt_sids)
+
+ return True
+
+ def project(self, xyzs):
+ '''
+ :param xyzs: [N, 3]
+        :return: [N, 2] pixel coordinates (u, v)
+ '''
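+        # Standard pinhole projection: build Tcw from (qvec, tvec), transform the
+        # homogeneous points into the camera frame, apply K and divide by depth.
+        # Points behind the camera are not filtered here, so callers may need to
+        # handle negative depths themselves.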
+ K = intrinsics_from_camera(camera_model=self.camera.model, params=self.camera.params) # [3, 3]
+ Rcw = qvec2rotmat(self.qvec)
+ tcw = self.tvec.reshape(3, 1)
+ Tcw = np.eye(4, dtype=float)
+ Tcw[:3, :3] = Rcw
+ Tcw[:3, 3:] = tcw
+ xyzs_homo = np.hstack([xyzs, np.ones(shape=(xyzs.shape[0], 1))]) # [N 4]
+
+ xyzs_cam = Tcw @ xyzs_homo.transpose() # [4, N]
+ uvs = K @ xyzs_cam[:3, :] # [3, N]
+ uvs[:2, :] = uvs[:2, :] / uvs[2, :]
+ return uvs[:2, :].transpose()
diff --git a/third_party/pram/localization/singlemap3d.py b/third_party/pram/localization/singlemap3d.py
new file mode 100644
index 0000000000000000000000000000000000000000..77fc0ef2c78321044bb8f8f2952ccb278ea28d8f
--- /dev/null
+++ b/third_party/pram/localization/singlemap3d.py
@@ -0,0 +1,532 @@
+# -*- coding: UTF-8 -*-
+'''=================================================
+@Project -> File pram -> map3d
+@IDE PyCharm
+@Author fx221@cam.ac.uk
+@Date 04/03/2024 10:25
+=================================================='''
+import numpy as np
+from collections import defaultdict
+import os.path as osp
+import pycolmap
+import logging
+import time
+
+import torch
+
+from localization.refframe import RefFrame
+from localization.frame import Frame
+from localization.point3d import Point3D
+from colmap_utils.read_write_model import qvec2rotmat, read_model, read_compressed_model
+from localization.utils import read_gt_pose
+
+
+class SingleMap3D:
+ def __init__(self, config, matcher, with_compress=False, start_sid: int = 0):
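+        # Load the (optionally compressed) COLMAP model, the per-point descriptors, the
+        # point-to-segment assignment and the virtual reference frames (VRFs), then build
+        # the Point3D / RefFrame containers and the covisibility graph used at query time.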
+ self.config = config
+ self.matcher = matcher
+ self.image_path_prefix = self.config['image_path_prefix']
+ self.start_sid = start_sid # for a dataset with multiple scenes
+ if not with_compress:
+ cameras, images, p3ds = read_model(
+ path=osp.join(config['landmark_path'], 'model'), ext='.bin')
+ p3d_descs = np.load(osp.join(config['landmark_path'], 'point3D_desc.npy'),
+ allow_pickle=True)[()]
+ else:
+ cameras, images, p3ds = read_compressed_model(
+ path=osp.join(config['landmark_path'], 'compress_model_{:s}'.format(config['cluster_method'])),
+ ext='.bin')
+ p3d_descs = np.load(osp.join(config['landmark_path'], 'compress_model_{:s}/point3D_desc.npy'.format(
+ config['cluster_method'])), allow_pickle=True)[()]
+
+ print('Load {} cameras {} images {} 3D points'.format(len(cameras), len(images), len(p3d_descs)))
+
+ seg_data = np.load(
+ osp.join(config['landmark_path'], 'point3D_cluster_n{:d}_{:s}_{:s}.npy'.format(config['n_cluster'],
+ config['cluster_mode'],
+ config['cluster_method'])),
+ allow_pickle=True)[()]
+
+ p3d_id = seg_data['id']
+ seg_id = seg_data['label']
+ p3d_seg = {p3d_id[i]: seg_id[i] for i in range(p3d_id.shape[0])}
+ seg_p3d = {}
+ for k in p3d_seg.keys():
+ sid = p3d_seg[k]
+ if sid in seg_p3d.keys():
+ seg_p3d[sid].append(k)
+ else:
+ seg_p3d[sid] = [k]
+
+ print('Load {} segments and {} 3d points'.format(len(seg_p3d.keys()), len(p3d_seg.keys())))
+ seg_vrf = np.load(
+ osp.join(config['landmark_path'], 'point3D_vrf_n{:d}_{:s}_{:s}.npy'.format(config['n_cluster'],
+ config['cluster_mode'],
+ config['cluster_method'])),
+ allow_pickle=True)[()]
+
+ # construct 3D map
+ self.initialize_point3Ds(p3ds=p3ds, p3d_descs=p3d_descs, p3d_seg=p3d_seg)
+ self.initialize_ref_frames(cameras=cameras, images=images)
+
+ all_vrf_frame_ids = []
+ self.seg_ref_frame_ids = {}
+ for sid in seg_vrf.keys():
+ self.seg_ref_frame_ids[sid] = []
+ for vi in seg_vrf[sid].keys():
+ vrf_frame_id = seg_vrf[sid][vi]['image_id']
+ self.seg_ref_frame_ids[sid].append(vrf_frame_id)
+ if with_compress and vrf_frame_id in self.reference_frames.keys():
+ self.reference_frames[vrf_frame_id].point3D_ids = seg_vrf[sid][vi]['original_points3d']
+
+ all_vrf_frame_ids.extend(self.seg_ref_frame_ids[sid])
+
+ if with_compress:
+ all_ref_ids = list(self.reference_frames.keys())
+ for fid in all_ref_ids:
+ valid = self.reference_frames[fid].associate_keypoints_with_point3Ds(point3Ds=self.point3Ds)
+ if not valid:
+ del self.reference_frames[fid]
+
+ all_vrf_frame_ids = np.unique(all_vrf_frame_ids)
+ all_vrf_frame_ids = [v for v in all_vrf_frame_ids if v in self.reference_frames.keys()]
+ self.build_covisibility_graph(frame_ids=all_vrf_frame_ids, n_frame=config['localization'][
+ 'covisibility_frame']) # build covisible frames for vrf frames only
+
+ logging.info(
+ f'Construct {len(self.reference_frames.keys())} ref frames and {len(self.point3Ds.keys())} 3d points')
+
+ self.gt_poses = {}
+ if config['gt_pose_path'] is not None:
+ gt_pose_path = osp.join(config['dataset_path'], config['gt_pose_path'])
+ self.read_gt_pose(path=gt_pose_path)
+
+ def read_gt_pose(self, path, prefix=''):
+ self.gt_poses = read_gt_pose(path=path)
+ print('Load {} gt poses'.format(len(self.gt_poses.keys())))
+
+ def initialize_point3Ds(self, p3ds, p3d_descs, p3d_seg):
+ self.point3Ds = {}
+ for id in p3ds.keys():
+ if id not in p3d_seg.keys():
+ continue
+ self.point3Ds[id] = Point3D(id=id, xyz=p3ds[id].xyz, error=p3ds[id].error,
+ refframe_id=-1, rgb=p3ds[id].rgb,
+ descriptor=p3d_descs[id], seg_id=p3d_seg[id],
+ frame_ids=p3ds[id].image_ids)
+
+ def initialize_ref_frames(self, cameras, images):
+ self.reference_frames = {}
+ for id in images.keys():
+ im = images[id]
+ cam = cameras[im.camera_id]
+ self.reference_frames[id] = RefFrame(camera=cam, id=id, qvec=im.qvec, tvec=im.tvec,
+ point3D_ids=im.point3D_ids,
+ keypoints=im.xys, name=im.name)
+
+ def localize_with_ref_frame(self, q_frame: Frame, q_kpt_ids: np.ndarray, sid, semantic_matching=False):
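+        # Match the selected query keypoints against the first reference frame of segment
+        # `sid` (optionally restricted to the reference 3D points of that segment) and
+        # estimate an absolute pose with pycolmap PnP + RANSAC from the 2D-3D matches.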
+ ref_frame_id = self.seg_ref_frame_ids[sid][0]
+ ref_frame = self.reference_frames[ref_frame_id]
+ if semantic_matching and sid > 0:
+ ref_data = ref_frame.get_keypoints_by_sid(sid=sid)
+ else:
+ ref_data = ref_frame.get_keypoints()
+
+ q_descs = q_frame.descriptors[q_kpt_ids]
+ q_kpts = q_frame.keypoints[q_kpt_ids, :2]
+ q_scores = q_frame.keypoints[q_kpt_ids, 2]
+
+ xyzs = ref_data['xyzs']
+ point3D_ids = ref_data['point3D_ids']
+ ref_sids = np.array([self.point3Ds[v].seg_id for v in point3D_ids])
+ with torch.no_grad():
+ indices0 = self.matcher({
+ 'descriptors0': torch.from_numpy(q_descs)[None].cuda().float(),
+ 'keypoints0': torch.from_numpy(q_kpts)[None].cuda().float(),
+ 'scores0': torch.from_numpy(q_scores)[None].cuda().float(),
+ 'image_shape0': (1, 3, q_frame.camera.width, q_frame.camera.height),
+
+ 'descriptors1': torch.from_numpy(ref_data['descriptors'])[None].cuda().float(),
+ 'keypoints1': torch.from_numpy(ref_data['keypoints'])[None].cuda().float(),
+ 'scores1': torch.from_numpy(ref_data['scores'])[None].cuda().float(),
+ 'image_shape1': (1, 3, ref_frame.camera.width, ref_frame.camera.height),
+ }
+ )['matches0'][0].cpu().numpy()
+
+ valid = indices0 >= 0
+ mkpts = q_kpts[valid]
+ mkpt_ids = q_kpt_ids[valid]
+ mxyzs = xyzs[indices0[valid]]
+ mpoint3D_ids = point3D_ids[indices0[valid]]
+ matched_sids = ref_sids[indices0[valid]]
+ matched_ref_keypoints = ref_data['keypoints'][indices0[valid]]
+
+ # print('mkpts: ', mkpts.shape, mxyzs.shape, np.sum(indices0 >= 0))
+ # cfg = q_frame.camera._asdict()
+ # q_cam = pycolmap.Camera(model=q_frame.camera.model, )
+ # config = {"estimation": {"ransac": {"max_error": ransac_thresh}}, **(config or {})}
+ ret = pycolmap.absolute_pose_estimation(mkpts + 0.5,
+ mxyzs,
+ q_frame.camera,
+ estimation_options={
+ "ransac": {"max_error": self.config['localization']['threshold']}},
+ refinement_options={},
+ # max_error_px=self.config['localization']['threshold']
+ )
+ if ret is None:
+ ret = {'success': False, }
+ else:
+ ret['success'] = True
+ ret['qvec'] = ret['cam_from_world'].rotation.quat[[3, 0, 1, 2]]
+ ret['tvec'] = ret['cam_from_world'].translation
+ ret['matched_keypoints'] = mkpts
+ ret['matched_keypoint_ids'] = mkpt_ids
+ ret['matched_xyzs'] = mxyzs
+ ret['reference_frame_id'] = ref_frame_id
+ ret['matched_point3D_ids'] = mpoint3D_ids
+ ret['matched_sids'] = matched_sids
+ ret['matched_ref_keypoints'] = matched_ref_keypoints
+
+ if not ret['success']:
+ ret['num_inliers'] = 0
+ ret['inliers'] = np.zeros(shape=(mkpts.shape[0],), dtype=bool)
+ return ret
+
+ def match(self, query_data, ref_data):
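+        # Run the feature matcher between query and reference keypoints and return the
+        # matched query keypoints together with the corresponding 3D points and ids.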
+ q_descs = query_data['descriptors']
+ q_kpts = query_data['keypoints']
+ q_scores = query_data['scores']
+ xyzs = ref_data['xyzs']
+ points3D_ids = ref_data['point3D_ids']
+ with torch.no_grad():
+ indices0 = self.matcher({
+ 'descriptors0': torch.from_numpy(q_descs)[None].cuda().float(),
+ 'keypoints0': torch.from_numpy(q_kpts)[None].cuda().float(),
+ 'scores0': torch.from_numpy(q_scores)[None].cuda().float(),
+ 'image_shape0': (1, 3, query_data['camera'].width, query_data['camera'].height),
+
+ 'descriptors1': torch.from_numpy(ref_data['descriptors'])[None].cuda().float(),
+ 'keypoints1': torch.from_numpy(ref_data['keypoints'])[None].cuda().float(),
+ 'scores1': torch.from_numpy(ref_data['scores'])[None].cuda().float(),
+ 'image_shape1': (1, 3, ref_data['camera'].width, ref_data['camera'].height),
+ }
+ )['matches0'][0].cpu().numpy()
+
+ valid = indices0 >= 0
+ mkpts = q_kpts[valid]
+ mkpt_ids = np.where(valid)[0]
+ mxyzs = xyzs[indices0[valid]]
+ mpoints3D_ids = points3D_ids[indices0[valid]]
+
+ return {
+ 'matched_keypoints': mkpts,
+ 'matched_xyzs': mxyzs,
+ 'matched_point3D_ids': mpoints3D_ids,
+ 'matched_keypoint_ids': mkpt_ids,
+ }
+
+ def build_covisibility_graph(self, frame_ids: list = None, n_frame: int = 20):
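+        # For each frame id, collect the frames sharing observed 3D points and keep the
+        # `n_frame` most covisible ones, sorted by the number of shared points.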
+ def find_covisible_frames(frame_id):
+ observed = self.reference_frames[frame_id].point3D_ids
+ covis = defaultdict(int)
+ for pid in observed:
+ if pid == -1:
+ continue
+ if pid not in self.point3Ds.keys():
+ continue
+ for img_id in self.point3Ds[pid].frame_ids:
+ covis[img_id] += 1
+
+ covis_ids = np.array(list(covis.keys()))
+ covis_num = np.array([covis[i] for i in covis_ids])
+
+ if len(covis_ids) <= n_frame:
+ sel_covis_ids = covis_ids[np.argsort(-covis_num)]
+ else:
+ ind_top = np.argpartition(covis_num, -n_frame)
+ ind_top = ind_top[-n_frame:] # unsorted top k
+ ind_top = ind_top[np.argsort(-covis_num[ind_top])]
+ sel_covis_ids = [covis_ids[i] for i in ind_top]
+
+ return sel_covis_ids
+
+ if frame_ids is None:
+            frame_ids = list(self.reference_frames.keys())
+
+ self.covisible_graph = defaultdict()
+ for frame_id in frame_ids:
+ self.covisible_graph[frame_id] = find_covisible_frames(frame_id=frame_id)
+
+ def refine_pose(self, q_frame: Frame, refinement_method='matching'):
+ if refinement_method == 'matching':
+ return self.refine_pose_by_matching(q_frame=q_frame)
+ elif refinement_method == 'projection':
+ return self.refine_pose_by_projection(q_frame=q_frame)
+ else:
+ raise NotImplementedError
+
+ def refine_pose_by_matching(self, q_frame):
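+        # Collect 2D-3D matches against the covisible frames of the current reference frame
+        # (plus the inliers of the initial localization when tracking succeeded) and re-run
+        # PnP + RANSAC on the combined correspondences.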
+ ref_frame_id = q_frame.reference_frame_id
+ db_ids = self.covisible_graph[ref_frame_id]
+ print('Find {} covisible frames'.format(len(db_ids)))
+ loc_success = q_frame.tracking_status
+ if loc_success and ref_frame_id in db_ids:
+ init_kpts = q_frame.matched_keypoints
+ init_kpt_ids = q_frame.matched_keypoint_ids
+ init_point3D_ids = q_frame.matched_point3D_ids
+ init_xyzs = np.array([self.point3Ds[v].xyz for v in init_point3D_ids]).reshape(-1, 3)
+            db_ids = [v for v in db_ids if v != ref_frame_id]  # exclude the current reference frame
+ else:
+ init_kpts = None
+ init_xyzs = None
+ init_point3D_ids = None
+
+ matched_xyzs = []
+ matched_kpts = []
+ matched_point3D_ids = []
+ matched_kpt_ids = []
+ for idx, frame_id in enumerate(db_ids):
+ ref_data = self.reference_frames[frame_id].get_keypoints()
+ match_out = self.match(query_data={
+ 'keypoints': q_frame.keypoints[:, :2],
+ 'scores': q_frame.keypoints[:, 2],
+ 'descriptors': q_frame.descriptors,
+ 'camera': q_frame.camera, },
+ ref_data=ref_data)
+ if match_out['matched_keypoints'].shape[0] > 0:
+ matched_kpts.append(match_out['matched_keypoints'])
+ matched_xyzs.append(match_out['matched_xyzs'])
+ matched_point3D_ids.append(match_out['matched_point3D_ids'])
+ matched_kpt_ids.append(match_out['matched_keypoint_ids'])
+ if len(matched_kpts) > 1:
+ matched_kpts = np.vstack(matched_kpts)
+ matched_xyzs = np.vstack(matched_xyzs).reshape(-1, 3)
+ matched_point3D_ids = np.hstack(matched_point3D_ids)
+ matched_kpt_ids = np.hstack(matched_kpt_ids)
+ else:
+ matched_kpts = matched_kpts[0]
+ matched_xyzs = matched_xyzs[0]
+ matched_point3D_ids = matched_point3D_ids[0]
+ matched_kpt_ids = matched_kpt_ids[0]
+ if init_kpts is not None and init_kpts.shape[0] > 0:
+ matched_kpts = np.vstack([matched_kpts, init_kpts])
+ matched_xyzs = np.vstack([matched_xyzs, init_xyzs])
+ matched_point3D_ids = np.hstack([matched_point3D_ids, init_point3D_ids])
+ matched_kpt_ids = np.hstack([matched_kpt_ids, init_kpt_ids])
+
+ matched_sids = np.array([self.point3Ds[v].seg_id for v in matched_point3D_ids])
+
+ print_text = 'Refinement by matching. Get {:d} covisible frames with {:d} matches for optimization'.format(
+ len(db_ids), matched_xyzs.shape[0])
+ print(print_text)
+
+ t_start = time.time()
+ ret = pycolmap.absolute_pose_estimation(matched_kpts + 0.5,
+ matched_xyzs,
+ q_frame.camera,
+ estimation_options={
+ 'ransac': {
+ 'max_error': self.config['localization']['threshold'],
+ 'min_num_trials': 1000,
+ 'max_num_trials': 10000,
+ 'confidence': 0.995,
+ }},
+ refinement_options={},
+ # max_error_px=self.config['localization']['threshold'],
+ # min_num_trials=1000, max_num_trials=10000, confidence=0.995)
+ )
+ print('Time of RANSAC: {:.2f}s'.format(time.time() - t_start))
+
+ if ret is None:
+ ret = {'success': False, }
+ else:
+ ret['success'] = True
+ ret['qvec'] = ret['cam_from_world'].rotation.quat[[3, 0, 1, 2]]
+ ret['tvec'] = ret['cam_from_world'].translation
+
+ ret['matched_keypoints'] = matched_kpts
+ ret['matched_keypoint_ids'] = matched_kpt_ids
+ ret['matched_xyzs'] = matched_xyzs
+ ret['matched_point3D_ids'] = matched_point3D_ids
+ ret['matched_sids'] = matched_sids
+
+ if ret['success']:
+ inlier_mask = np.array(ret['inliers'])
+ best_reference_frame_ids = self.find_reference_frames(matched_point3D_ids=matched_point3D_ids[inlier_mask],
+ candidate_frame_ids=self.covisible_graph.keys())
+ else:
+ best_reference_frame_ids = self.find_reference_frames(matched_point3D_ids=matched_point3D_ids,
+ candidate_frame_ids=self.covisible_graph.keys())
+
+ ret['refinement_reference_frame_ids'] = best_reference_frame_ids[:self.config['localization'][
+ 'covisibility_frame']]
+ ret['reference_frame_id'] = best_reference_frame_ids[0]
+
+ return ret
+
+ @torch.no_grad()
+ def refine_pose_by_projection(self, q_frame):
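+        # Project the 3D points observed by the covisible frames into the query image using
+        # the current pose estimate, match query descriptors only against points that
+        # reproject close enough (reprojection gating + ratio test on the GPU), then re-run
+        # PnP + RANSAC on the resulting 2D-3D correspondences.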
+ q_Rcw = qvec2rotmat(q_frame.qvec)
+ q_tcw = q_frame.tvec
+ q_Tcw = np.eye(4, dtype=float) # [4 4]
+ q_Tcw[:3, :3] = q_Rcw
+ q_Tcw[:3, 3] = q_tcw
+ cam = q_frame.camera
+ imw = cam.width
+ imh = cam.height
+ K = q_frame.get_intrinsics() # [3, 3]
+ reference_frame_id = q_frame.reference_frame_id
+ covis_frame_ids = self.covisible_graph[reference_frame_id]
+ if reference_frame_id not in covis_frame_ids:
+ covis_frame_ids.append(reference_frame_id)
+ all_point3D_ids = []
+
+ for frame_id in covis_frame_ids:
+ all_point3D_ids.extend(list(self.reference_frames[frame_id].point3D_ids))
+
+ all_point3D_ids = np.unique(all_point3D_ids)
+ all_xyzs = []
+ all_descs = []
+ all_sids = []
+ for pid in all_point3D_ids:
+ all_xyzs.append(self.point3Ds[pid].xyz)
+ all_descs.append(self.point3Ds[pid].descriptor)
+ all_sids.append(self.point3Ds[pid].seg_id)
+
+ all_xyzs = np.array(all_xyzs) # [N 3]
+ all_descs = np.array(all_descs) # [N 3]
+ all_point3D_ids = np.array(all_point3D_ids)
+ all_sids = np.array(all_sids)
+
+ # move to gpu (distortion is not included)
+ # proj_uv = pycolmap.camera.img_from_cam(
+ # np.array([1, 1, 1]).reshape(1, 3),
+ # )
+ all_xyzs_cuda = torch.from_numpy(all_xyzs).cuda()
+ ones = torch.ones(size=(all_xyzs_cuda.shape[0], 1), dtype=all_xyzs_cuda.dtype).cuda()
+ all_xyzs_cuda_homo = torch.cat([all_xyzs_cuda, ones], dim=1) # [N 4]
+ K_cuda = torch.from_numpy(K).cuda()
+ proj_uvs = K_cuda @ (torch.from_numpy(q_Tcw).cuda() @ all_xyzs_cuda_homo.t())[:3, :] # [3, N]
+ proj_uvs[0] /= proj_uvs[2]
+ proj_uvs[1] /= proj_uvs[2]
+ mask = (proj_uvs[2] > 0) * (proj_uvs[2] < 100) * (proj_uvs[0] >= 0) * (proj_uvs[0] < imw) * (
+ proj_uvs[1] >= 0) * (proj_uvs[1] < imh)
+
+ proj_uvs = proj_uvs[:, mask]
+
+        print('Projection: {:d} of {:d} points project inside the image'.format(proj_uvs.shape[1], all_xyzs_cuda.shape[0]))
+
+ mxyzs = all_xyzs[mask.cpu().numpy()]
+ mpoint3D_ids = all_point3D_ids[mask.cpu().numpy()]
+ msids = all_sids[mask.cpu().numpy()]
+
+ q_kpts_cuda = torch.from_numpy(q_frame.keypoints[:, :2]).cuda()
+ proj_error = q_kpts_cuda[..., None] - proj_uvs[:2][None]
+ proj_error = torch.sqrt(torch.sum(proj_error ** 2, dim=1)) # [M N]
+ out_of_range_mask = (proj_error >= 2 * self.config['localization']['threshold'])
+
+ q_descs_cuda = torch.from_numpy(q_frame.descriptors).cuda().float() # [M D]
+ all_descs_cuda = torch.from_numpy(all_descs).cuda().float()[mask] # [N D]
+ desc_dist = torch.sqrt(2 - 2 * q_descs_cuda @ all_descs_cuda.t() + 1e-6)
+ desc_dist[out_of_range_mask] = desc_dist[out_of_range_mask] + 100
+ dists, ids = torch.topk(desc_dist, k=2, largest=False, dim=1)
+ # apply nn ratio
+ ratios = dists[:, 0] / dists[:, 1] # smaller, better
+ ratio_mask = (ratios <= 0.995) * (dists[:, 0] < 100)
+ ratio_mask = ratio_mask.cpu().numpy()
+ ids = ids.cpu().numpy()[ratio_mask, 0]
+
+ ratio_num = torch.sum(ratios <= 0.995)
+ proj_num = torch.sum(dists[:, 0] < 100)
+
+ print('Projection: after ratio {:d}/{:d}, ratio {:d}, proj {:d}'.format(q_kpts_cuda.shape[0],
+ np.sum(ratio_mask),
+ ratio_num, proj_num))
+
+ mkpts = q_frame.keypoints[ratio_mask]
+ mkpt_ids = np.where(ratio_mask)[0]
+ mxyzs = mxyzs[ids]
+ mpoint3D_ids = mpoint3D_ids[ids]
+ msids = msids[ids]
+ print('projection: ', mkpts.shape, mkpt_ids.shape, mxyzs.shape, mpoint3D_ids.shape, msids.shape)
+
+ t_start = time.time()
+ ret = pycolmap.absolute_pose_estimation(mkpts[:, :2] + 0.5, mxyzs, q_frame.camera,
+ estimation_options={
+ "ransac": {"max_error": self.config['localization']['threshold']}},
+ refinement_options={},
+ # max_error_px=self.config['localization']['threshold']
+ )
+ if ret is None:
+ ret = {'success': False, }
+ else:
+ ret['success'] = True
+ ret['qvec'] = ret['cam_from_world'].rotation.quat[[3, 0, 1, 2]]
+ ret['tvec'] = ret['cam_from_world'].translation
+ # inlier_mask = np.ones(shape=(mkpts.shape[0],), dtype=bool).tolist()
+ # ret = pycolmap.pose_refinement(q_frame.tvec, q_frame.qvec, mkpts[:, :2] + 0.5, mxyzs, inlier_mask, cfg)
+ # ret['num_inliers'] = np.sum(inlier_mask).astype(int)
+ # ret['inliers'] = np.array(inlier_mask)
+
+ print_text = 'Refinement by projection. Get {:d} inliers of {:d} matches for optimization'.format(
+ ret['num_inliers'], mxyzs.shape[0])
+ print(print_text)
+ print('Time of RANSAC: {:.2f}s'.format(time.time() - t_start))
+
+ ret['matched_keypoints'] = mkpts
+ ret['matched_xyzs'] = mxyzs
+ ret['matched_point3D_ids'] = mpoint3D_ids
+ ret['matched_sids'] = msids
+ ret['matched_keypoint_ids'] = mkpt_ids
+
+ if ret['success']:
+ inlier_mask = np.array(ret['inliers'])
+ best_reference_frame_ids = self.find_reference_frames(matched_point3D_ids=mpoint3D_ids[inlier_mask],
+ candidate_frame_ids=self.covisible_graph.keys())
+ else:
+ best_reference_frame_ids = self.find_reference_frames(matched_point3D_ids=mpoint3D_ids,
+ candidate_frame_ids=self.covisible_graph.keys())
+
+ ret['refinement_reference_frame_ids'] = best_reference_frame_ids[:self.config['localization'][
+ 'covisibility_frame']]
+ ret['reference_frame_id'] = best_reference_frame_ids[0]
+
+ if not ret['success']:
+ ret['num_inliers'] = 0
+ ret['inliers'] = np.zeros(shape=(mkpts.shape[0],), dtype=bool)
+
+ return ret
+
+ def find_reference_frames(self, matched_point3D_ids, candidate_frame_ids=None):
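+        # Rank candidate reference frames by how many of the given 3D points they observe.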
+ covis_frames = defaultdict(int)
+ for pid in matched_point3D_ids:
+ for im_id in self.point3Ds[pid].frame_ids:
+ if candidate_frame_ids is not None and im_id in candidate_frame_ids:
+ covis_frames[im_id] += 1
+
+ covis_ids = np.array(list(covis_frames.keys()))
+ covis_num = np.array([covis_frames[i] for i in covis_ids])
+ sorted_idxes = np.argsort(covis_num)[::-1] # larger to small
+ sorted_frame_ids = covis_ids[sorted_idxes]
+ return sorted_frame_ids
+
+ def check_semantic_consistency(self, q_frame: Frame, sid, overlap_ratio=0.5):
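+        # Compare the segment ids of the query keypoints with those of the reference
+        # frame's 3D points; both overlap ratios must reach `overlap_ratio` to accept.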
+ ref_frame_id = self.seg_ref_frame_ids[sid][0]
+ ref_frame = self.reference_frames[ref_frame_id]
+
+ q_sids = q_frame.seg_ids
+ ref_sids = np.array([self.point3Ds[v].seg_id for v in ref_frame.point3D_ids]) + self.start_sid
+ overlap_sids = np.intersect1d(q_sids, ref_sids)
+
+ overlap_num1 = 0
+ overlap_num2 = 0
+ for sid in overlap_sids:
+ overlap_num1 += np.sum(q_sids == sid)
+ overlap_num2 += np.sum(ref_sids == sid)
+
+ ratio1 = overlap_num1 / q_sids.shape[0]
+ ratio2 = overlap_num2 / ref_sids.shape[0]
+
+ # print('semantic_check: ', overlap_sids, overlap_num1, ratio1, overlap_num2, ratio2)
+
+ return min(ratio1, ratio2) >= overlap_ratio
diff --git a/third_party/pram/localization/tracker.py b/third_party/pram/localization/tracker.py
new file mode 100644
index 0000000000000000000000000000000000000000..a401fea82c2372cfdf301ab2d2fb34981facf4fe
--- /dev/null
+++ b/third_party/pram/localization/tracker.py
@@ -0,0 +1,338 @@
+# -*- coding: UTF-8 -*-
+'''=================================================
+@Project -> File pram -> tracker
+@IDE PyCharm
+@Author fx221@cam.ac.uk
+@Date 29/02/2024 16:58
+=================================================='''
+import time
+import cv2
+import numpy as np
+import torch
+import pycolmap
+from localization.frame import Frame
+from localization.base_model import dynamic_load
+import localization.matchers as matchers
+from localization.match_features_batch import confs as matcher_confs
+from recognition.vis_seg import vis_seg_point, generate_color_dic, vis_inlier, plot_matches
+from tools.common import resize_img
+
+
+class Tracker:
+ def __init__(self, locMap, matcher, config):
+ self.locMap = locMap
+ self.matcher = matcher
+ self.config = config
+ self.loc_config = config['localization']
+
+ self.lost = True
+
+ self.curr_frame = None
+ self.last_frame = None
+
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
+ Model = dynamic_load(matchers, 'nearest_neighbor')
+ self.nn_matcher = Model(matcher_confs['NNM']['model']).eval().to(device)
+
+ def run(self, frame: Frame):
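+        # Tracking pipeline for one frame: match against the last localized frame, run PnP
+        # with the inherited 3D points, verify the pose, and refine against the sub-map when
+        # the inlier count is low; optionally visualize matches and inliers.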
+ print('Start tracking...')
+ show = self.config['localization']['show']
+ self.curr_frame = frame
+ ref_img = self.last_frame.image
+ curr_img = self.curr_frame.image
+ q_kpts = frame.keypoints
+
+ t_start = time.time()
+ ret = self.track_last_frame(curr_frame=self.curr_frame, last_frame=self.last_frame)
+ self.curr_frame.time_loc = self.curr_frame.time_loc + time.time() - t_start
+
+ if show:
+ curr_matched_kpts = ret['matched_keypoints']
+ ref_matched_kpts = ret['matched_ref_keypoints']
+ img_loc_matching = plot_matches(img1=curr_img, img2=ref_img,
+ pts1=curr_matched_kpts,
+ pts2=ref_matched_kpts,
+ inliers=np.array([True for i in range(curr_matched_kpts.shape[0])]),
+ radius=9, line_thickness=3)
+ self.curr_frame.image_matching = img_loc_matching
+
+ q_ref_img_matching = resize_img(img_loc_matching, nh=512)
+
+ if not ret['success']:
+ show_text = 'Tracking FAILED!'
+ img_inlier = vis_inlier(img=curr_img, kpts=curr_matched_kpts,
+ inliers=[False for i in range(curr_matched_kpts.shape[0])], radius=9 + 2,
+ thickness=2)
+ q_img_inlier = cv2.putText(img=img_inlier, text=show_text, org=(30, 30),
+ fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=1, color=(0, 0, 255),
+ thickness=2, lineType=cv2.LINE_AA)
+
+ q_img_loc = np.hstack([resize_img(q_ref_img_matching, nh=512), resize_img(q_img_inlier, nh=512)])
+
+ cv2.imshow('loc', q_img_loc)
+ key = cv2.waitKey(self.loc_config['show_time'])
+ if key == ord('q'):
+ cv2.destroyAllWindows()
+ exit(0)
+ return False
+
+ ret['matched_scene_name'] = self.last_frame.scene_name
+ success = self.verify_and_update(q_frame=self.curr_frame, ret=ret)
+
+ if not success:
+ return False
+
+ if ret['num_inliers'] < 256:
+ # refinement is necessary for tracking last frame
+ t_start = time.time()
+            ret = self.locMap.sub_maps[self.last_frame.matched_scene_name].refine_pose(
+                self.curr_frame, refinement_method=self.loc_config['refinement_method'])
+ self.curr_frame.time_ref = self.curr_frame.time_ref + time.time() - t_start
+ ret['matched_scene_name'] = self.last_frame.scene_name
+ success = self.verify_and_update(q_frame=self.curr_frame, ret=ret)
+
+ if show:
+ q_err, t_err = self.curr_frame.compute_pose_error()
+ num_matches = ret['matched_keypoints'].shape[0]
+ num_inliers = ret['num_inliers']
+ show_text = 'Tracking, k/m/i: {:d}/{:d}/{:d}'.format(q_kpts.shape[0], num_matches, num_inliers)
+ q_img_inlier = vis_inlier(img=curr_img, kpts=ret['matched_keypoints'], inliers=ret['inliers'],
+ radius=9 + 2, thickness=2)
+ q_img_inlier = cv2.putText(img=q_img_inlier, text=show_text, org=(30, 30),
+ fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=1, color=(0, 0, 255),
+ thickness=2, lineType=cv2.LINE_AA)
+ show_text = 'r_err:{:.2f}, t_err:{:.2f}'.format(q_err, t_err)
+ q_img_inlier = cv2.putText(img=q_img_inlier, text=show_text, org=(30, 80),
+ fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=1, color=(0, 0, 255),
+ thickness=2, lineType=cv2.LINE_AA)
+ self.curr_frame.image_inlier = q_img_inlier
+
+ q_img_loc = np.hstack([resize_img(q_ref_img_matching, nh=512), resize_img(q_img_inlier, nh=512)])
+
+ cv2.imshow('loc', q_img_loc)
+ key = cv2.waitKey(self.loc_config['show_time'])
+ if key == ord('q'):
+ cv2.destroyAllWindows()
+ exit(0)
+
+ self.lost = success
+ return success
+
+ def verify_and_update(self, q_frame: Frame, ret: dict):
+ num_matches = ret['matched_keypoints'].shape[0]
+ num_inliers = ret['num_inliers']
+
+ q_frame.qvec = ret['qvec']
+ q_frame.tvec = ret['tvec']
+
+ q_err, t_err = q_frame.compute_pose_error()
+
+ if num_inliers < self.loc_config['min_inliers']:
+            print_text = 'Failed due to insufficient inliers ({:d}), q_err: {:.2f}, t_err: {:.2f}'.format(
+ ret['num_inliers'], q_err, t_err)
+ print(print_text)
+ q_frame.tracking_status = False
+ q_frame.clear_localization_track()
+ return False
+ else:
+            print_text = 'Succeeded! Found {}/{} 2D-3D inliers, q_err: {:.2f}, t_err: {:.2f}'.format(
+ num_inliers, num_matches, q_err, t_err)
+ print(print_text)
+ q_frame.tracking_status = True
+
+ self.update_current_frame(curr_frame=q_frame, ret=ret)
+ return True
+
+ def update_current_frame(self, curr_frame: Frame, ret: dict):
+ curr_frame.qvec = ret['qvec']
+ curr_frame.tvec = ret['tvec']
+
+ curr_frame.matched_scene_name = ret['matched_scene_name']
+ curr_frame.reference_frame_id = ret['reference_frame_id']
+ inliers = np.array(ret['inliers'])
+
+ curr_frame.matched_keypoints = ret['matched_keypoints'][inliers]
+ curr_frame.matched_xyzs = ret['matched_xyzs'][inliers]
+ curr_frame.matched_point3D_ids = ret['matched_point3D_ids'][inliers]
+ curr_frame.matched_keypoint_ids = ret['matched_keypoint_ids'][inliers]
+ curr_frame.matched_sids = ret['matched_sids'][inliers]
+
+ def track_last_frame(self, curr_frame: Frame, last_frame: Frame):
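+        # Match current keypoints to the last frame's keypoints, transfer the last frame's
+        # 3D-point associations to the matches, and estimate the current pose with
+        # PnP + RANSAC.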
+ curr_kpts = curr_frame.keypoints[:, :2]
+ curr_scores = curr_frame.keypoints[:, 2]
+ curr_descs = curr_frame.descriptors
+ curr_kpt_ids = np.arange(curr_kpts.shape[0])
+
+ last_kpts = last_frame.keypoints[:, :2]
+ last_scores = last_frame.keypoints[:, 2]
+ last_descs = last_frame.descriptors
+ last_xyzs = last_frame.xyzs
+ last_point3D_ids = last_frame.point3D_ids
+ last_sids = last_frame.seg_ids
+
+ indices = self.matcher({
+ 'descriptors0': torch.from_numpy(curr_descs)[None].cuda().float(),
+ 'keypoints0': torch.from_numpy(curr_kpts)[None].cuda().float(),
+ 'scores0': torch.from_numpy(curr_scores)[None].cuda().float(),
+ 'image_shape0': (1, 3, curr_frame.camera.width, curr_frame.camera.height),
+
+ 'descriptors1': torch.from_numpy(last_descs)[None].cuda().float(),
+ 'keypoints1': torch.from_numpy(last_kpts)[None].cuda().float(),
+ 'scores1': torch.from_numpy(last_scores)[None].cuda().float(),
+ 'image_shape1': (1, 3, last_frame.camera.width, last_frame.camera.height),
+ })['matches0'][0].cpu().numpy()
+
+        # Alternative: plain nearest-neighbour matching (kept for reference, disabled).
+        # indices = self.nn_matcher({
+        #     'descriptors0': torch.from_numpy(curr_descs.transpose()).float().cuda()[None],
+        #     'descriptors1': torch.from_numpy(last_descs.transpose()).float().cuda()[None],
+        # })['matches0'][0].cpu().numpy()
+
+ valid = (indices >= 0)
+
+ matched_point3D_ids = last_point3D_ids[indices[valid]]
+ point3D_mask = (matched_point3D_ids >= 0)
+ matched_point3D_ids = matched_point3D_ids[point3D_mask]
+ matched_sids = last_sids[indices[valid]][point3D_mask]
+
+ matched_kpts = curr_kpts[valid][point3D_mask]
+ matched_kpt_ids = curr_kpt_ids[valid][point3D_mask]
+ matched_xyzs = last_xyzs[indices[valid]][point3D_mask]
+ matched_last_kpts = last_kpts[indices[valid]][point3D_mask]
+
+ print('Tracking: {:d} matches from {:d}-{:d} kpts'.format(matched_kpts.shape[0], curr_kpts.shape[0],
+ last_kpts.shape[0]))
+
+ # print('tracking: ', matched_kpts.shape, matched_xyzs.shape)
+ ret = pycolmap.absolute_pose_estimation(matched_kpts + 0.5, matched_xyzs,
+ curr_frame.camera,
+ estimation_options={
+ "ransac": {"max_error": self.config['localization']['threshold']}},
+ refinement_options={},
+ # max_error_px=self.config['localization']['threshold']
+ )
+ if ret is None:
+ ret = {'success': False, }
+ else:
+ ret['success'] = True
+ ret['qvec'] = ret['cam_from_world'].rotation.quat[[3, 0, 1, 2]]
+ ret['tvec'] = ret['cam_from_world'].translation
+
+ ret['matched_keypoints'] = matched_kpts
+ ret['matched_keypoint_ids'] = matched_kpt_ids
+ ret['matched_ref_keypoints'] = matched_last_kpts
+ ret['matched_xyzs'] = matched_xyzs
+ ret['matched_point3D_ids'] = matched_point3D_ids
+ ret['matched_sids'] = matched_sids
+ ret['reference_frame_id'] = last_frame.reference_frame_id
+ ret['matched_scene_name'] = last_frame.matched_scene_name
+ return ret
+
+ def track_last_frame_fast(self, curr_frame: Frame, last_frame: Frame):
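+        # Same idea as track_last_frame, but current keypoints are first restricted to the
+        # bounding box of the last frame's 3D-associated keypoints to reduce matching cost.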
+ curr_kpts = curr_frame.keypoints[:, :2]
+ curr_scores = curr_frame.keypoints[:, 2]
+ curr_descs = curr_frame.descriptors
+ curr_kpt_ids = np.arange(curr_kpts.shape[0])
+
+ last_point3D_ids = last_frame.point3D_ids
+ point3D_mask = (last_point3D_ids >= 0)
+ last_kpts = last_frame.keypoints[:, :2][point3D_mask]
+ last_scores = last_frame.keypoints[:, 2][point3D_mask]
+ last_descs = last_frame.descriptors[point3D_mask]
+ last_xyzs = last_frame.xyzs[point3D_mask]
+ last_sids = last_frame.seg_ids[point3D_mask]
+
+ minx = np.min(last_kpts[:, 0])
+ maxx = np.max(last_kpts[:, 0])
+ miny = np.min(last_kpts[:, 1])
+ maxy = np.max(last_kpts[:, 1])
+ curr_mask = (curr_kpts[:, 0] >= minx) * (curr_kpts[:, 0] <= maxx) * (curr_kpts[:, 1] >= miny) * (
+ curr_kpts[:, 1] <= maxy)
+
+ curr_kpts = curr_kpts[curr_mask]
+ curr_scores = curr_scores[curr_mask]
+ curr_descs = curr_descs[curr_mask]
+ curr_kpt_ids = curr_kpt_ids[curr_mask]
+ indices = self.matcher({
+ 'descriptors0': torch.from_numpy(curr_descs)[None].cuda().float(),
+ 'keypoints0': torch.from_numpy(curr_kpts)[None].cuda().float(),
+ 'scores0': torch.from_numpy(curr_scores)[None].cuda().float(),
+ 'image_shape0': (1, 3, curr_frame.camera.width, curr_frame.camera.height),
+
+ 'descriptors1': torch.from_numpy(last_descs)[None].cuda().float(),
+ 'keypoints1': torch.from_numpy(last_kpts)[None].cuda().float(),
+ 'scores1': torch.from_numpy(last_scores)[None].cuda().float(),
+ 'image_shape1': (1, 3, last_frame.camera.width, last_frame.camera.height),
+ })['matches0'][0].cpu().numpy()
+
+        # Alternative: plain nearest-neighbour matching (kept for reference, disabled).
+        # indices = self.nn_matcher({
+        #     'descriptors0': torch.from_numpy(curr_descs.transpose()).float().cuda()[None],
+        #     'descriptors1': torch.from_numpy(last_descs.transpose()).float().cuda()[None],
+        # })['matches0'][0].cpu().numpy()
+
+ valid = (indices >= 0)
+
+ matched_point3D_ids = last_point3D_ids[indices[valid]]
+ matched_sids = last_sids[indices[valid]]
+
+ matched_kpts = curr_kpts[valid]
+ matched_kpt_ids = curr_kpt_ids[valid]
+ matched_xyzs = last_xyzs[indices[valid]]
+ matched_last_kpts = last_kpts[indices[valid]]
+
+ print('Tracking: {:d} matches from {:d}-{:d} kpts'.format(matched_kpts.shape[0], curr_kpts.shape[0],
+ last_kpts.shape[0]))
+
+ # print('tracking: ', matched_kpts.shape, matched_xyzs.shape)
+        ret = pycolmap.absolute_pose_estimation(matched_kpts + 0.5, matched_xyzs,
+                                                curr_frame.camera,
+                                                estimation_options={
+                                                    "ransac": {"max_error": self.config['localization']['threshold']}},
+                                                refinement_options={},
+                                                )
+        if ret is None:
+            ret = {'success': False, }
+        else:
+            ret['success'] = True
+            ret['qvec'] = ret['cam_from_world'].rotation.quat[[3, 0, 1, 2]]
+            ret['tvec'] = ret['cam_from_world'].translation
+
+ ret['matched_keypoints'] = matched_kpts
+ ret['matched_keypoint_ids'] = matched_kpt_ids
+ ret['matched_ref_keypoints'] = matched_last_kpts
+ ret['matched_xyzs'] = matched_xyzs
+ ret['matched_point3D_ids'] = matched_point3D_ids
+ ret['matched_sids'] = matched_sids
+ ret['reference_frame_id'] = last_frame.reference_frame_id
+ ret['matched_scene_name'] = last_frame.matched_scene_name
+ return ret
+
+ @torch.no_grad()
+ def match_frame(self, frame: Frame, reference_frame: Frame):
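+        # Match the two frames and return the match index pairs (ids1, ids2), keeping only
+        # matches whose reference keypoint has an associated 3D point.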
+ print('match: ', frame.keypoints.shape, reference_frame.keypoints.shape)
+ matches = self.matcher({
+ 'descriptors0': torch.from_numpy(frame.descriptors)[None].cuda().float(),
+ 'keypoints0': torch.from_numpy(frame.keypoints[:, :2])[None].cuda().float(),
+ 'scores0': torch.from_numpy(frame.keypoints[:, 2])[None].cuda().float(),
+ 'image_shape0': (1, 3, frame.image_size[0], frame.image_size[1]),
+
+ # 'descriptors0': torch.from_numpy(reference_frame.descriptors)[None].cuda().float(),
+ # 'keypoints0': torch.from_numpy(reference_frame.keypoints[:, :2])[None].cuda().float(),
+ # 'scores0': torch.from_numpy(reference_frame.keypoints[:, 2])[None].cuda().float(),
+ # 'image_shape0': (1, 3, reference_frame.image_size[0], reference_frame.image_size[1]),
+
+ 'descriptors1': torch.from_numpy(reference_frame.descriptors)[None].cuda().float(),
+ 'keypoints1': torch.from_numpy(reference_frame.keypoints[:, :2])[None].cuda().float(),
+ 'scores1': torch.from_numpy(reference_frame.keypoints[:, 2])[None].cuda().float(),
+ 'image_shape1': (1, 3, reference_frame.image_size[0], reference_frame.image_size[1]),
+
+ })['matches0'][0].cpu().numpy()
+
+ ids1 = np.arange(matches.shape[0])
+ ids2 = matches
+ ids1 = ids1[matches >= 0]
+ ids2 = ids2[matches >= 0]
+
+ mask_p3ds = reference_frame.points3d_mask[ids2]
+ ids1 = ids1[mask_p3ds]
+ ids2 = ids2[mask_p3ds]
+
+ return ids1, ids2
diff --git a/third_party/pram/localization/triangulation.py b/third_party/pram/localization/triangulation.py
new file mode 100644
index 0000000000000000000000000000000000000000..d5b885ec4be9c328353af9c0b0aaf136d694556a
--- /dev/null
+++ b/third_party/pram/localization/triangulation.py
@@ -0,0 +1,317 @@
+# code is from hloc https://github.com/cvg/Hierarchical-Localization/blob/master/hloc/triangulation.py
+import argparse
+import contextlib
+import io
+import sys
+from pathlib import Path
+from typing import Any, Dict, List, Optional
+
+import numpy as np
+import pycolmap
+from tqdm import tqdm
+
+from colmap_utils.database import COLMAPDatabase
+from colmap_utils.geometry import compute_epipolar_errors
+from colmap_utils.io import get_keypoints, get_matches
+from colmap_utils.parsers import parse_retrieval
+import logging
+
+
+class OutputCapture:
+ def __init__(self, verbose: bool):
+ self.verbose = verbose
+
+ def __enter__(self):
+ if not self.verbose:
+ self.capture = contextlib.redirect_stdout(io.StringIO())
+ self.out = self.capture.__enter__()
+
+ def __exit__(self, exc_type, *args):
+ if not self.verbose:
+ self.capture.__exit__(exc_type, *args)
+ if exc_type is not None:
+ # logger.error("Failed with output:\n%s", self.out.getvalue())
+ logging.error("Failed with output:\n%s", self.out.getvalue())
+ sys.stdout.flush()
+
+
+def create_db_from_model(
+ reconstruction: pycolmap.Reconstruction, database_path: Path
+) -> Dict[str, int]:
+ if database_path.exists():
+ # logger.warning("The database already exists, deleting it.")
+ logging.warning("The database already exists, deleting it.")
+ database_path.unlink()
+
+ db = COLMAPDatabase.connect(database_path)
+ db.create_tables()
+
+ for i, camera in reconstruction.cameras.items():
+ db.add_camera(
+ camera.model.value,
+ camera.width,
+ camera.height,
+ camera.params,
+ camera_id=i,
+ prior_focal_length=True,
+ )
+
+ for i, image in reconstruction.images.items():
+ db.add_image(image.name, image.camera_id, image_id=i)
+
+ db.commit()
+ db.close()
+ return {image.name: i for i, image in reconstruction.images.items()}
+
+
+def import_features(
+ image_ids: Dict[str, int], database_path: Path, features_path: Path
+):
+ # logger.info("Importing features into the database...")
+ logging.info("Importing features into the database...")
+ db = COLMAPDatabase.connect(database_path)
+
+ for image_name, image_id in tqdm(image_ids.items()):
+ keypoints = get_keypoints(features_path, image_name)
+ keypoints += 0.5 # COLMAP origin
+ db.add_keypoints(image_id, keypoints)
+
+ db.commit()
+ db.close()
+
+
+def import_matches(
+ image_ids: Dict[str, int],
+ database_path: Path,
+ pairs_path: Path,
+ matches_path: Path,
+ min_match_score: Optional[float] = None,
+ skip_geometric_verification: bool = False,
+):
+ # logger.info("Importing matches into the database...")
+ logging.info("Importing matches into the database...")
+
+ with open(str(pairs_path), "r") as f:
+ pairs = [p.split() for p in f.readlines()]
+
+ db = COLMAPDatabase.connect(database_path)
+
+ matched = set()
+ for name0, name1 in tqdm(pairs):
+ id0, id1 = image_ids[name0], image_ids[name1]
+ if len({(id0, id1), (id1, id0)} & matched) > 0:
+ continue
+ matches, scores = get_matches(matches_path, name0, name1)
+ if min_match_score:
+ matches = matches[scores > min_match_score]
+ db.add_matches(id0, id1, matches)
+ matched |= {(id0, id1), (id1, id0)}
+
+ if skip_geometric_verification:
+ db.add_two_view_geometry(id0, id1, matches)
+
+ db.commit()
+ db.close()
+
+
+def estimation_and_geometric_verification(
+ database_path: Path, pairs_path: Path, verbose: bool = False
+):
+ # logger.info("Performing geometric verification of the matches...")
+ logging.info("Performing geometric verification of the matches...")
+ with OutputCapture(verbose):
+ with pycolmap.ostream():
+ pycolmap.verify_matches(
+ database_path,
+ pairs_path,
+ options=dict(ransac=dict(max_num_trials=20000, min_inlier_ratio=0.1)),
+ )
+
+
+def geometric_verification(
+ image_ids: Dict[str, int],
+ reference: pycolmap.Reconstruction,
+ database_path: Path,
+ features_path: Path,
+ pairs_path: Path,
+ matches_path: Path,
+ max_error: float = 4.0,
+):
+ # logger.info("Performing geometric verification of the matches...")
+ logging.info("Performing geometric verification of the matches...")
+
+ pairs = parse_retrieval(pairs_path)
+ db = COLMAPDatabase.connect(database_path)
+
+ inlier_ratios = []
+ matched = set()
+ for name0 in tqdm(pairs):
+ id0 = image_ids[name0]
+ image0 = reference.images[id0]
+ cam0 = reference.cameras[image0.camera_id]
+ kps0, noise0 = get_keypoints(features_path, name0, return_uncertainty=True)
+ noise0 = 1.0 if noise0 is None else noise0
+ if len(kps0) > 0:
+ kps0 = np.stack(cam0.cam_from_img(kps0))
+ else:
+ kps0 = np.zeros((0, 2))
+
+ for name1 in pairs[name0]:
+ id1 = image_ids[name1]
+ image1 = reference.images[id1]
+ cam1 = reference.cameras[image1.camera_id]
+ kps1, noise1 = get_keypoints(features_path, name1, return_uncertainty=True)
+ noise1 = 1.0 if noise1 is None else noise1
+ if len(kps1) > 0:
+ kps1 = np.stack(cam1.cam_from_img(kps1))
+ else:
+ kps1 = np.zeros((0, 2))
+
+ matches = get_matches(matches_path, name0, name1)[0]
+
+ if len({(id0, id1), (id1, id0)} & matched) > 0:
+ continue
+ matched |= {(id0, id1), (id1, id0)}
+
+ if matches.shape[0] == 0:
+ db.add_two_view_geometry(id0, id1, matches)
+ continue
+
+ cam1_from_cam0 = image1.cam_from_world * image0.cam_from_world.inverse()
+ errors0, errors1 = compute_epipolar_errors(
+ cam1_from_cam0, kps0[matches[:, 0]], kps1[matches[:, 1]]
+ )
+ valid_matches = np.logical_and(
+ errors0 <= cam0.cam_from_img_threshold(noise0 * max_error),
+ errors1 <= cam1.cam_from_img_threshold(noise1 * max_error),
+ )
+ # TODO: We could also add E to the database, but we need
+ # to reverse the transformations if id0 > id1 in utils/database.py.
+ db.add_two_view_geometry(id0, id1, matches[valid_matches, :])
+ inlier_ratios.append(np.mean(valid_matches))
+ # logger.info(
+ logging.info(
+ "mean/med/min/max valid matches %.2f/%.2f/%.2f/%.2f%%.",
+ np.mean(inlier_ratios) * 100,
+ np.median(inlier_ratios) * 100,
+ np.min(inlier_ratios) * 100,
+ np.max(inlier_ratios) * 100,
+ )
+
+ db.commit()
+ db.close()
+
+
+def run_triangulation(
+ model_path: Path,
+ database_path: Path,
+ image_dir: Path,
+ reference_model: pycolmap.Reconstruction,
+ verbose: bool = False,
+ options: Optional[Dict[str, Any]] = None,
+) -> pycolmap.Reconstruction:
+ model_path.mkdir(parents=True, exist_ok=True)
+ # logger.info("Running 3D triangulation...")
+ logging.info("Running 3D triangulation...")
+ if options is None:
+ options = {}
+ with OutputCapture(verbose):
+ with pycolmap.ostream():
+ reconstruction = pycolmap.triangulate_points(
+ reference_model, database_path, image_dir, model_path, options=options
+ )
+ return reconstruction
+
+
+def main(
+ sfm_dir: Path,
+ reference_sfm_model: Path,
+ image_dir: Path,
+ pairs: Path,
+ features: Path,
+ matches: Path,
+ skip_geometric_verification: bool = False,
+ estimate_two_view_geometries: bool = False,
+ min_match_score: Optional[float] = None,
+ verbose: bool = False,
+ mapper_options: Optional[Dict[str, Any]] = None,
+) -> pycolmap.Reconstruction:
+ assert reference_sfm_model.exists(), reference_sfm_model
+ assert features.exists(), features
+ assert pairs.exists(), pairs
+ assert matches.exists(), matches
+
+ sfm_dir.mkdir(parents=True, exist_ok=True)
+ database = sfm_dir / "database.db"
+ reference = pycolmap.Reconstruction(reference_sfm_model)
+
+ image_ids = create_db_from_model(reference, database)
+ import_features(image_ids, database, features)
+ import_matches(
+ image_ids,
+ database,
+ pairs,
+ matches,
+ min_match_score,
+ skip_geometric_verification,
+ )
+ if not skip_geometric_verification:
+ if estimate_two_view_geometries:
+ estimation_and_geometric_verification(database, pairs, verbose)
+ else:
+ geometric_verification(
+ image_ids, reference, database, features, pairs, matches
+ )
+ reconstruction = run_triangulation(
+ sfm_dir, database, image_dir, reference, verbose, mapper_options
+ )
+ # logger.info(
+ logging.info(
+ "Finished the triangulation with statistics:\n%s", reconstruction.summary()
+ )
+ stats = reconstruction.summary()
+ with open(sfm_dir / 'statics.txt', 'w') as f:
+ f.write(stats + '\n')
+
+ # logging.info(f'Statistics:\n{pprint.pformat(stats)}')
+ return reconstruction
+
+
+def parse_option_args(args: List[str], default_options) -> Dict[str, Any]:
+ options = {}
+ for arg in args:
+ idx = arg.find("=")
+ if idx == -1:
+ raise ValueError("Options format: key1=value1 key2=value2 etc.")
+ key, value = arg[:idx], arg[idx + 1:]
+ if not hasattr(default_options, key):
+ raise ValueError(
+ f'Unknown option "{key}", allowed options and default values'
+ f" for {default_options.summary()}"
+ )
+ value = eval(value)
+ target_type = type(getattr(default_options, key))
+ if not isinstance(value, target_type):
+ raise ValueError(
+ f'Incorrect type for option "{key}":' f" {type(value)} vs {target_type}"
+ )
+ options[key] = value
+ return options
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--sfm_dir", type=Path, required=True)
+ parser.add_argument("--reference_sfm_model", type=Path, required=True)
+ parser.add_argument("--image_dir", type=Path, required=True)
+
+ parser.add_argument("--pairs", type=Path, required=True)
+ parser.add_argument("--features", type=Path, required=True)
+ parser.add_argument("--matches", type=Path, required=True)
+
+ parser.add_argument("--skip_geometric_verification", action="store_true")
+ parser.add_argument("--min_match_score", type=float)
+ parser.add_argument("--verbose", action="store_true")
+ args = parser.parse_args().__dict__
+
+ main(**args)
diff --git a/third_party/pram/localization/utils.py b/third_party/pram/localization/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..6e5861afceba6bed7518921145505b01caf66954
--- /dev/null
+++ b/third_party/pram/localization/utils.py
@@ -0,0 +1,83 @@
+# -*- coding: UTF-8 -*-
+'''=================================================
+@Project -> File pram -> utils
+@IDE PyCharm
+@Author fx221@cam.ac.uk
+@Date 07/02/2024 15:27
+=================================================='''
+import numpy as np
+from colmap_utils.read_write_model import qvec2rotmat
+
+
+def read_query_info(query_fn: str, name_prefix='') -> dict:
+ results = {}
+ with open(query_fn, 'r') as f:
+ lines = f.readlines()
+ for l in lines:
+ l = l.strip().split()
+ name, camera_model, width, height = l[:4]
+ params = np.array(l[4:], float)
+ info = (camera_model, int(width), int(height), params)
+ results[name_prefix + name] = info
+ print('Load {} query images'.format(len(results.keys())))
+ return results
+
+
+def quaternion_angular_error(q1, q2):
+ """
+ angular error between two quaternions
+ :param q1: (4, )
+ :param q2: (4, )
+ :return:
+ """
+ d = abs(np.dot(q1, q2))
+ d = min(1.0, max(-1.0, d))
+ theta = 2 * np.arccos(d) * 180 / np.pi
+ return theta
+
+
+def compute_pose_error(pred_qcw, pred_tcw, gt_qcw, gt_tcw):
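+    # Translation error: Euclidean distance between camera centers (twc = -R^T @ tcw).
+    # Rotation error: angular difference between the two quaternions, in degrees.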
+ pred_Rcw = qvec2rotmat(qvec=pred_qcw)
+ pred_tcw = np.array(pred_tcw, float).reshape(3, 1)
+ pred_twc = -pred_Rcw.transpose() @ pred_tcw
+
+ gt_Rcw = qvec2rotmat(gt_qcw)
+ gt_tcw = np.array(gt_tcw, float).reshape(3, 1)
+ gt_twc = -gt_Rcw.transpose() @ gt_tcw
+
+ t_error_xyz = pred_twc - gt_twc
+ t_error = np.sqrt(np.sum(t_error_xyz ** 2))
+
+ q_error = quaternion_angular_error(q1=pred_qcw, q2=gt_qcw)
+
+ return q_error, t_error
+
+
+def read_retrieval_results(path):
+ output = {}
+ with open(path, "r") as f:
+ lines = f.readlines()
+ for p in lines:
+ p = p.strip("\n").split(" ")
+
+ if p[1] == "no_match":
+ continue
+ if p[0] in output.keys():
+ output[p[0]].append(p[1])
+ else:
+ output[p[0]] = [p[1]]
+ return output
+
+
+def read_gt_pose(path):
+ gt_poses = {}
+ with open(path, 'r') as f:
+ lines = f.readlines()
+ for l in lines:
+ l = l.strip().split(' ')
+ gt_poses[l[0]] = {
+ 'qvec': np.array([float(v) for v in l[1:5]], float),
+ 'tvec': np.array([float(v) for v in l[5:]], float),
+ }
+
+ return gt_poses
diff --git a/third_party/pram/localization/viewer.py b/third_party/pram/localization/viewer.py
new file mode 100644
index 0000000000000000000000000000000000000000..33899f60ab362e240b7b0e6736a157a7aa041d31
--- /dev/null
+++ b/third_party/pram/localization/viewer.py
@@ -0,0 +1,548 @@
+# -*- coding: UTF-8 -*-
+'''=================================================
+@Project -> File pram -> viewer
+@IDE PyCharm
+@Author fx221@cam.ac.uk
+@Date 05/03/2024 16:50
+=================================================='''
+import cv2
+import numpy as np
+import pypangolin as pangolin
+from OpenGL.GL import *
+import time
+import threading
+from colmap_utils.read_write_model import qvec2rotmat
+from tools.common import resize_image_with_padding
+from localization.frame import Frame
+
+
+class Viewer:
+ default_config = {
+ 'image_size_indoor': 0.1,
+ 'image_line_width_indoor': 1,
+
+ 'image_size_outdoor': 1,
+ 'image_line_width_outdoor': 3,
+
+ 'point_size_indoor': 1,
+ 'point_size_outdoor': 1,
+
+ 'image_width': 640,
+ 'image_height': 480,
+
+ 'viewpoint_x': 0,
+ 'viewpoint_y': -1,
+ 'viewpoint_z': -5,
+ 'viewpoint_F': 512,
+
+ 'scene': 'indoor',
+ }
+
+ def __init__(self, locMap, seg_color, config={}):
+ self.config = {**self.default_config, **config}
+ self.viewpoint_x = self.config['viewpoint_x']
+ self.viewpoint_y = self.config['viewpoint_y']
+ self.viewpoint_z = self.config['viewpoint_z']
+ self.viewpoint_F = self.config['viewpoint_F']
+ self.img_width = self.config['image_width']
+ self.img_height = self.config['image_height']
+
+ if self.config['scene'] == 'indoor':
+ self.image_size = self.config['image_size_indoor']
+ self.image_line_width = self.config['image_line_width_indoor']
+ self.point_size = self.config['point_size_indoor']
+
+ else:
+ self.image_size = self.config['image_size_outdoor']
+ self.image_line_width = self.config['image_line_width_outdoor']
+ self.point_size = self.config['point_size_outdoor']
+ self.viewpoint_z = -150
+
+ self.locMap = locMap
+ self.seg_colors = seg_color
+
+ # current camera pose
+ self.frame = None
+ self.Tcw = np.eye(4, dtype=float)
+ self.Twc = np.linalg.inv(self.Tcw)
+ self.gt_Tcw = None
+ self.gt_Twc = None
+
+ self.scene = None
+ self.current_vrf_id = None
+ self.reference_frame_ids = None
+ self.subMap = None
+ self.seg_point_clouds = None
+ self.point_clouds = None
+
+ self.start_seg_id = 1
+ self.stop = False
+
+ self.refinement = False
+ self.tracking = False
+
+ # time
+        self.time_feat = np.nan
+        self.time_rec = np.nan
+        self.time_loc = np.nan
+        self.time_ref = np.nan
+
+ # image
+ self.image_rec = None
+
+ def draw_3d_points_white(self):
+ if self.point_clouds is None:
+ return
+
+ point_size = self.point_size * 0.5
+ glColor4f(0.9, 0.95, 1.0, 0.6)
+ glPointSize(point_size)
+ pangolin.glDrawPoints(self.point_clouds)
+
+ def draw_seg_3d_points(self):
+ if self.seg_point_clouds is None:
+ return
+ for sid in self.seg_point_clouds.keys():
+ xyzs = self.seg_point_clouds[sid]
+ point_size = self.point_size * 0.5
+ bgr = self.seg_colors[sid + self.start_seg_id + 1]
+ glColor3f(bgr[2] / 255, bgr[1] / 255, bgr[0] / 255)
+ glPointSize(point_size)
+ pangolin.glDrawPoints(xyzs)
+
+ def draw_ref_3d_points(self, use_seg_color=False):
+ if self.reference_frame_ids is None:
+ return
+
+ ref_point3D_ids = []
+ for fid in self.reference_frame_ids:
+ pids = self.subMap.reference_frames[fid].point3D_ids
+ ref_point3D_ids.extend(list(pids))
+
+ ref_point3D_ids = np.unique(ref_point3D_ids).tolist()
+
+ point_size = self.point_size * 5
+ glPointSize(point_size)
+ glBegin(GL_POINTS)
+
+ for pid in ref_point3D_ids:
+ if pid not in self.subMap.point3Ds.keys():
+ continue
+ xyz = self.subMap.point3Ds[pid].xyz
+ rgb = self.subMap.point3Ds[pid].rgb
+ sid = self.subMap.point3Ds[pid].seg_id
+ if use_seg_color:
+ bgr = self.seg_colors[sid + self.start_seg_id + 1]
+ glColor3f(bgr[2] / 255, bgr[1] / 255, bgr[0] / 255)
+ else:
+ glColor3f(rgb[0] / 255, rgb[1] / 255, rgb[2] / 255)
+
+ glVertex3f(xyz[0], xyz[1], xyz[2])
+
+ glEnd()
+
+ def draw_vrf_frames(self):
+ if self.subMap is None:
+ return
+ w = self.image_size * 1.0
+ image_line_width = self.image_line_width * 1.0
+ h = w * 0.75
+ z = w * 0.6
+ for sid in self.subMap.seg_ref_frame_ids.keys():
+ frame_id = self.subMap.seg_ref_frame_ids[sid][0]
+ qvec = self.subMap.reference_frames[frame_id].qvec
+ tcw = self.subMap.reference_frames[frame_id].tvec
+
+ Rcw = qvec2rotmat(qvec)
+
+ twc = -Rcw.T @ tcw
+ Rwc = Rcw.T
+
+ Twc = np.column_stack((Rwc, twc))
+ Twc = np.vstack((Twc, (0, 0, 0, 1)))
+
+ glPushMatrix()
+
+ glMultMatrixf(Twc.T)
+
+ glLineWidth(image_line_width)
+ glColor3f(1, 0, 0)
+ glBegin(GL_LINES)
+ glVertex3f(0, 0, 0)
+ glVertex3f(w, h, z)
+ glVertex3f(0, 0, 0)
+ glVertex3f(w, -h, z)
+ glVertex3f(0, 0, 0)
+ glVertex3f(-w, -h, z)
+ glVertex3f(0, 0, 0)
+ glVertex3f(-w, h, z)
+
+ glVertex3f(w, h, z)
+ glVertex3f(w, -h, z)
+
+ glVertex3f(-w, h, z)
+ glVertex3f(-w, -h, z)
+
+ glVertex3f(-w, h, z)
+ glVertex3f(w, h, z)
+
+ glVertex3f(-w, -h, z)
+ glVertex3f(w, -h, z)
+ glEnd()
+
+ glPopMatrix()
+
+ def draw_current_vrf_frame(self):
+ if self.current_vrf_id is None:
+ return
+ qvec = self.subMap.reference_frames[self.current_vrf_id].qvec
+ tcw = self.subMap.reference_frames[self.current_vrf_id].tvec
+ Rcw = qvec2rotmat(qvec)
+ twc = -Rcw.T @ tcw
+ Rwc = Rcw.T
+ Twc = np.column_stack((Rwc, twc))
+ Twc = np.vstack((Twc, (0, 0, 0, 1)))
+
+ camera_line_width = self.image_line_width * 2
+ w = self.image_size * 2
+ h = w * 0.75
+ z = w * 0.6
+
+ glPushMatrix()
+
+ glMultMatrixf(Twc.T) # note the .T
+
+ glLineWidth(camera_line_width)
+ glColor3f(1, 0, 0)
+ glBegin(GL_LINES)
+ glVertex3f(0, 0, 0)
+ glVertex3f(w, h, z)
+ glVertex3f(0, 0, 0)
+ glVertex3f(w, -h, z)
+ glVertex3f(0, 0, 0)
+ glVertex3f(-w, -h, z)
+ glVertex3f(0, 0, 0)
+ glVertex3f(-w, h, z)
+
+ glVertex3f(w, h, z)
+ glVertex3f(w, -h, z)
+
+ glVertex3f(-w, h, z)
+ glVertex3f(-w, -h, z)
+
+ glVertex3f(-w, h, z)
+ glVertex3f(w, h, z)
+
+ glVertex3f(-w, -h, z)
+ glVertex3f(w, -h, z)
+ glEnd()
+
+ glPopMatrix()
+
+ def draw_current_frame(self, Tcw, color=(0, 1.0, 0)):
+ Twc = np.linalg.inv(Tcw)
+
+ camera_line_width = self.image_line_width * 2
+ w = self.image_size * 2
+ h = w * 0.75
+ z = w * 0.6
+
+ glPushMatrix()
+
+        glMultMatrixf(Twc.T)  # note the .T
+
+ glLineWidth(camera_line_width)
+ glColor3f(color[0], color[1], color[2])
+ glBegin(GL_LINES)
+ glVertex3f(0, 0, 0)
+ glVertex3f(w, h, z)
+ glVertex3f(0, 0, 0)
+ glVertex3f(w, -h, z)
+ glVertex3f(0, 0, 0)
+ glVertex3f(-w, -h, z)
+ glVertex3f(0, 0, 0)
+ glVertex3f(-w, h, z)
+
+ glVertex3f(w, h, z)
+ glVertex3f(w, -h, z)
+
+ glVertex3f(-w, h, z)
+ glVertex3f(-w, -h, z)
+
+ glVertex3f(-w, h, z)
+ glVertex3f(w, h, z)
+
+ glVertex3f(-w, -h, z)
+ glVertex3f(w, -h, z)
+ glEnd()
+
+ glPopMatrix()
+
+ def draw_ref_frames(self):
+ if self.reference_frame_ids is None:
+ return
+ w = self.image_size * 1.5
+ image_line_width = self.image_line_width * 1.5
+ h = w * 0.75
+ z = w * 0.6
+ for fid in self.reference_frame_ids:
+ qvec = self.subMap.reference_frames[fid].qvec
+ tcw = self.subMap.reference_frames[fid].tvec
+ Rcw = qvec2rotmat(qvec)
+
+ twc = -Rcw.T @ tcw
+ Rwc = Rcw.T
+
+ Twc = np.column_stack((Rwc, twc))
+ Twc = np.vstack((Twc, (0, 0, 0, 1)))
+
+ glPushMatrix()
+
+ glMultMatrixf(Twc.T)
+
+ glLineWidth(image_line_width)
+ glColor3f(100 / 255, 140 / 255, 17 / 255)
+ glBegin(GL_LINES)
+ glVertex3f(0, 0, 0)
+ glVertex3f(w, h, z)
+ glVertex3f(0, 0, 0)
+ glVertex3f(w, -h, z)
+ glVertex3f(0, 0, 0)
+ glVertex3f(-w, -h, z)
+ glVertex3f(0, 0, 0)
+ glVertex3f(-w, h, z)
+
+ glVertex3f(w, h, z)
+ glVertex3f(w, -h, z)
+
+ glVertex3f(-w, h, z)
+ glVertex3f(-w, -h, z)
+
+ glVertex3f(-w, h, z)
+ glVertex3f(w, h, z)
+
+ glVertex3f(-w, -h, z)
+ glVertex3f(w, -h, z)
+ glEnd()
+
+ glPopMatrix()
+
+ def terminate(self):
+ lock = threading.Lock()
+ lock.acquire()
+ self.stop = True
+ lock.release()
+
+ def update_point_clouds(self):
+ # for fast drawing
+ seg_point_clouds = {}
+ point_clouds = []
+ for pid in self.subMap.point3Ds.keys():
+ sid = self.subMap.point3Ds[pid].seg_id
+ xyz = self.subMap.point3Ds[pid].xyz
+ if sid in seg_point_clouds.keys():
+ seg_point_clouds[sid].append(xyz.reshape(3, 1))
+ else:
+ seg_point_clouds[sid] = [xyz.reshape(3, 1)]
+
+ point_clouds.append(xyz.reshape(3, 1))
+
+ self.seg_point_clouds = seg_point_clouds
+ self.point_clouds = point_clouds
+
+ def update(self, curr_frame: Frame):
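+        # Refresh the data shown by the render loop in run(): current sub-map and point
+        # clouds, estimated and ground-truth camera poses, timing statistics and the
+        # stitched recognition/matching/inlier image. Guarded by a lock since run() is
+        # expected to execute in a separate thread.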
+ lock = threading.Lock()
+ lock.acquire()
+
+ # self.frame = curr_frame
+ self.current_vrf_id = curr_frame.reference_frame_id
+ self.reference_frame_ids = [self.current_vrf_id]
+
+ # self.reference_frame_ids = curr_frame.refinement_reference_frame_ids
+ # if self.reference_frame_ids is None:
+ # self.reference_frame_ids = [self.current_vrf_id]
+ self.subMap = self.locMap.sub_maps[curr_frame.matched_scene_name]
+ self.start_seg_id = self.locMap.scene_name_start_sid[curr_frame.matched_scene_name]
+
+ if self.scene is None or self.scene != curr_frame.matched_scene_name:
+ self.scene = curr_frame.matched_scene_name
+ self.update_point_clouds()
+
+ if curr_frame.qvec is not None:
+ Rcw = qvec2rotmat(curr_frame.qvec)
+ Tcw = np.column_stack((Rcw, curr_frame.tvec))
+ self.Tcw = np.vstack((Tcw, (0, 0, 0, 1)))
+ Rwc = Rcw.T
+ twc = -Rcw.T @ curr_frame.tvec
+ Twc = np.column_stack((Rwc, twc))
+ self.Twc = np.vstack((Twc, (0, 0, 0, 1)))
+
+ if curr_frame.gt_qvec is not None:
+ gt_Rcw = qvec2rotmat(curr_frame.gt_qvec)
+ gt_Tcw = np.column_stack((gt_Rcw, curr_frame.gt_tvec))
+ self.gt_Tcw = np.vstack((gt_Tcw, (0, 0, 0, 1)))
+ gt_Rwc = gt_Rcw.T
+ gt_twc = -gt_Rcw.T @ curr_frame.gt_tvec
+ gt_Twc = np.column_stack((gt_Rwc, gt_twc))
+ self.gt_Twc = np.vstack((gt_Twc, (0, 0, 0, 1)))
+ else:
+ self.gt_Tcw = None
+ self.gt_Twc = None
+
+ # update time
+ self.time_feat = curr_frame.time_feat
+ self.time_rec = curr_frame.time_rec
+ self.time_loc = curr_frame.time_loc
+ self.time_ref = curr_frame.time_ref
+
+ # update image
+ image_rec_inlier = np.hstack([curr_frame.image_rec, curr_frame.image_inlier])
+ image_rec_inlier = resize_image_with_padding(image=image_rec_inlier, nw=self.img_width * 2, nh=self.img_height)
+ image_matching = resize_image_with_padding(image=curr_frame.image_matching, nw=self.img_width * 2,
+ nh=self.img_height)
+ image_rec_matching_inliers = resize_image_with_padding(image=np.vstack([image_rec_inlier, image_matching]),
+ nw=self.img_width * 2, nh=self.img_height * 2)
+
+ self.image_rec = cv2.cvtColor(image_rec_matching_inliers, cv2.COLOR_BGR2RGB)
+ lock.release()
+
+ def run(self):
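+        # Pangolin render loop: create the UI panel and 3D view, then repeatedly draw the
+        # map points, reference/VRF frames, the current (and ground-truth) camera pose and
+        # the latest images until the window is closed or terminate() is called.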
+ pangolin.CreateWindowAndBind("Map reviewer", 640, 480)
+ glEnable(GL_DEPTH_TEST)
+ glEnable(GL_BLEND)
+ glBlendFunc(GL_SRC_ALPHA, GL_ONE_MINUS_SRC_ALPHA)
+
+ pangolin.CreatePanel("menu").SetBounds(pangolin.Attach(0),
+ pangolin.Attach(1),
+ pangolin.Attach(0),
+ # pangolin.Attach.Pix(-175),
+ pangolin.Attach.Pix(175),
+ # pangolin.Attach(1)
+ )
+
+ menu = pangolin.Var("menu")
+ menu.Tracking = (False, pangolin.VarMeta(toggle=True))
+ menu.FollowCamera = (True, pangolin.VarMeta(toggle=True))
+ menu.ShowPoints = (True, pangolin.VarMeta(toggle=True))
+ menu.ShowSegs = (False, pangolin.VarMeta(toggle=True))
+ menu.ShowRefSegs = (True, pangolin.VarMeta(toggle=True))
+ menu.ShowRefPoints = (False, pangolin.VarMeta(toggle=True))
+ menu.ShowVRFFrame = (True, pangolin.VarMeta(toggle=True))
+ menu.ShowAllVRFs = (False, pangolin.VarMeta(toggle=True))
+ menu.ShowRefFrames = (False, pangolin.VarMeta(toggle=True))
+
+ menu.Refinement = (self.refinement, pangolin.VarMeta(toggle=True))
+
+ menu.featTime = 'NaN'
+ menu.recTime = 'NaN'
+ menu.locTime = 'NaN'
+ menu.refTime = 'NaN'
+ menu.totalTime = 'NaN'
+
+ pm = pangolin.ProjectionMatrix(640, 480, self.viewpoint_F, self.viewpoint_F, 320, 240, 0.1,
+ 10000)
+
+        # camera (eye) position, look-at point, up-axis direction
+ mv = pangolin.ModelViewLookAt(self.viewpoint_x,
+ self.viewpoint_y,
+ self.viewpoint_z,
+ 0, 0, 0,
+ # 0.0, -1.0, 0.0,
+ pangolin.AxisZ,
+ )
+
+ s_cam = pangolin.OpenGlRenderState(pm, mv)
+ # Attach bottom, Attach top, Attach left, Attach right,
+ scale = 0.42
+ d_img_rec = pangolin.Display('image_rec').SetBounds(pangolin.Attach(1 - scale),
+ pangolin.Attach(1),
+ pangolin.Attach(
+ 1 - 0.3),
+ pangolin.Attach(1),
+ self.img_width / self.img_height
+ ) # .SetLock(0, 1)
+
+ handler = pangolin.Handler3D(s_cam)
+
+ d_cam = pangolin.Display('3D').SetBounds(
+ pangolin.Attach(0), # bottom
+ pangolin.Attach(1), # top
+ pangolin.Attach.Pix(175), # left
+ # pangolin.Attach.Pix(0), # left
+ pangolin.Attach(1), # right
+ -640 / 480, # aspect
+ ).SetHandler(handler)
+
+ d_img_rec_texture = pangolin.GlTexture(self.img_width * 2, self.img_height * 2, GL_RGB, False, 0, GL_RGB,
+ GL_UNSIGNED_BYTE)
+ while not pangolin.ShouldQuit() and not self.stop:
+ glClear(GL_COLOR_BUFFER_BIT | GL_DEPTH_BUFFER_BIT)
+
+ # glClearColor(1.0, 1.0, 1.0, 1.0)
+ glClearColor(0.0, 0.0, 0.0, 1.0)
+
+ d_cam.Activate(s_cam)
+ if menu.FollowCamera:
+ s_cam.Follow(pangolin.OpenGlMatrix(self.Twc.astype(np.float32)), follow=True)
+
+ # pangolin.glDrawColouredCube()
+ if menu.ShowPoints:
+ self.draw_3d_points_white()
+
+ if menu.ShowRefPoints:
+ self.draw_ref_3d_points(use_seg_color=False)
+ if menu.ShowRefSegs:
+ self.draw_ref_3d_points(use_seg_color=True)
+
+ if menu.ShowSegs:
+ self.draw_seg_3d_points()
+
+ if menu.ShowAllVRFs:
+ self.draw_vrf_frames()
+
+ if menu.ShowRefFrames:
+ self.draw_ref_frames()
+
+ if menu.ShowVRFFrame:
+ self.draw_current_vrf_frame()
+
+ if menu.Refinement:
+ self.refinement = True
+ else:
+ self.refinement = False
+
+ if menu.Tracking:
+ self.tracking = True
+ else:
+ self.tracking = False
+
+ self.draw_current_frame(Tcw=self.Tcw)
+
+ if self.gt_Tcw is not None: # draw gt pose with color (0, 0, 1.0)
+ self.draw_current_frame(Tcw=self.gt_Tcw, color=(0., 0., 1.0))
+
+ d_img_rec.Activate()
+ glColor4f(1, 1, 1, 1)
+
+ if self.image_rec is not None:
+ d_img_rec_texture.Upload(self.image_rec, GL_RGB, GL_UNSIGNED_BYTE)
+ d_img_rec_texture.RenderToViewportFlipY()
+
+ time_total = 0
+            # skip timings that are still NaN (uninitialised); "x != np.nan" is always True
+            if not np.isnan(self.time_feat):
+                menu.featTime = '{:.2f}s'.format(self.time_feat)
+                time_total = time_total + self.time_feat
+            if not np.isnan(self.time_rec):
+                menu.recTime = '{:.2f}s'.format(self.time_rec)
+                time_total = time_total + self.time_rec
+            if not np.isnan(self.time_loc):
+                menu.locTime = '{:.2f}s'.format(self.time_loc)
+                time_total = time_total + self.time_loc
+            if not np.isnan(self.time_ref):
+                menu.refTime = '{:.2f}s'.format(self.time_ref)
+                time_total = time_total + self.time_ref
+            menu.totalTime = '{:.2f}s'.format(time_total)
+
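+            # note (editorial): the 50 ms sleep below caps the render loop at
+            # roughly 20 frames per second; reduce it for smoother interaction.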
+ time.sleep(50 / 1000)
+
+ pangolin.FinishFrame()
diff --git a/third_party/pram/main.py b/third_party/pram/main.py
new file mode 100644
index 0000000000000000000000000000000000000000..0f32b1e9087dcf7edd152911cf09bef93f0555d5
--- /dev/null
+++ b/third_party/pram/main.py
@@ -0,0 +1,228 @@
+# -*- coding: UTF-8 -*-
+'''=================================================
+@Project -> File pram -> train
+@IDE PyCharm
+@Author fx221@cam.ac.uk
+@Date 29/01/2024 14:26
+=================================================='''
+import argparse
+import os
+import os.path as osp
+import torch
+import torchvision.transforms.transforms as tvt
+import yaml
+import torch.utils.data as Data
+import torch.multiprocessing as mp
+import torch.distributed as dist
+
+from nets.segnet import SegNet
+from nets.segnetvit import SegNetViT
+from dataset.utils import collect_batch
+from dataset.get_dataset import compose_datasets
+from tools.common import torch_set_gpu
+from trainer import Trainer
+
+from nets.sfd2 import ResNet4x, DescriptorCompressor
+from nets.superpoint import SuperPoint
+
+torch.set_grad_enabled(True)
+
+parser = argparse.ArgumentParser(description='PRAM', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+parser.add_argument('--config', type=str, required=True, help='config of specifications')
+parser.add_argument('--landmark_path', type=str, default=None, help='path of landmarks')
+
+
+def load_feat_network(config):
+ if config['feature'] == 'spp':
+ net = SuperPoint(config={
+ 'weight_path': '/scratches/flyer_2/fx221/Research/Code/third_weights/superpoint_v1.pth',
+ }).eval()
+ elif config['feature'] == 'resnet4x':
+ net = ResNet4x(inputdim=3, outdim=128)
+ net.load_state_dict(
+ torch.load('weights/sfd2_20230511_210205_resnet4x.79.pth', map_location='cpu')['state_dict'],
+ strict=True)
+ net.eval()
+ else:
+        print('Unknown feature type {:s}; expected "spp" or "resnet4x"'.format(config['feature']))
+ net = None
+
+ if config['feat_dim'] != 128:
+ desc_compressor = DescriptorCompressor(inputdim=128, outdim=config['feat_dim']).eval()
+ if config['feat_dim'] == 64:
+ desc_compressor.load_state_dict(
+ torch.load('weights/20230511_210205_resnet4x_B6_R512_I3_O128_pho_resnet4x_e79_to_O64.pth',
+ map_location='cpu'),
+ strict=True)
+ elif config['feat_dim'] == 32:
+ desc_compressor.load_state_dict(
+ torch.load('weights/20230511_210205_resnet4x_B6_R512_I3_O128_pho_resnet4x_e79_to_O32.pth',
+ map_location='cpu'),
+ strict=True)
+ else:
+ desc_compressor = None
+ else:
+ desc_compressor = None
+ return net, desc_compressor
+
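+# Usage sketch (editorial; the hard-coded weight paths above come from the
+# authors' setup and must exist locally). The relevant config keys are
+# 'feature' ('spp' or 'resnet4x') and 'feat_dim' (128, 64 or 32), e.g.:
+#
+#   net, desc_compressor = load_feat_network({'feature': 'resnet4x', 'feat_dim': 128})
+#   # -> ResNet4x backbone; desc_compressor is None since feat_dim == 128
+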
+
+def get_model(config):
+ desc_dim = 256 if config['feature'] == 'spp' else 128
+ if config['use_mid_feature']:
+ desc_dim = 256
+ model_config = {
+ 'network': {
+ 'descriptor_dim': desc_dim,
+ 'n_layers': config['layers'],
+ 'ac_fn': config['ac_fn'],
+ 'norm_fn': config['norm_fn'],
+ 'n_class': config['n_class'],
+ 'output_dim': config['output_dim'],
+ 'with_cls': config['with_cls'],
+ 'with_sc': config['with_sc'],
+ 'with_score': config['with_score'],
+ }
+ }
+
+ if config['network'] == 'segnet':
+ model = SegNet(model_config.get('network', {}))
+ config['with_cls'] = False
+ elif config['network'] == 'segnetvit':
+ model = SegNetViT(model_config.get('network', {}))
+ config['with_cls'] = False
+ else:
+        raise ValueError('{:s} model does not exist'.format(config['network']))
+
+ if config['local_rank'] == 0:
+ if config['weight_path'] is not None:
+ state_dict = torch.load(osp.join(config['save_path'], config['weight_path']), map_location='cpu')['model']
+ model.load_state_dict(state_dict, strict=True)
+ print('Load weight from {:s}'.format(osp.join(config['save_path'], config['weight_path'])))
+
+ if config['resume_path'] is not None and not config['eval']: # only for training
+ model.load_state_dict(
+ torch.load(osp.join(config['save_path'], config['resume_path']), map_location='cpu')['model'],
+ strict=True)
+ print('Load resume weight from {:s}'.format(osp.join(config['save_path'], config['resume_path'])))
+
+ return model
+
+
+def setup(rank, world_size):
+ os.environ['MASTER_ADDR'] = 'localhost'
+ os.environ['MASTER_PORT'] = '12355'
+ # initialize the process group
+ dist.init_process_group("nccl", rank=rank, world_size=world_size)
+
+
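+# Editorial note: setup() assumes single-node training (MASTER_ADDR/PORT are
+# fixed to localhost:12355) and the 'nccl' backend requires CUDA. train_DDP()
+# is the per-process entry point; a minimal launch sketch mirroring the
+# mp.spawn call in __main__ (values are illustrative):
+#
+#   world_size = 2  # number of GPUs listed in config['gpu']
+#   mp.spawn(train_DDP, nprocs=world_size,
+#            args=(world_size, model, config, train_set, test_set, feat_model, img_transforms),
+#            join=True)
+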
+def train_DDP(rank, world_size, model, config, train_set, test_set, feat_model, img_transforms):
+ print('In train_DDP..., rank: ', rank)
+ torch.cuda.set_device(rank)
+
+ device = torch.device(f'cuda:{rank}')
+ if feat_model is not None:
+ feat_model.to(device)
+ model.to(device)
+ model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
+ setup(rank=rank, world_size=world_size)
+ model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[rank])
+ train_sampler = torch.utils.data.distributed.DistributedSampler(train_set,
+ shuffle=True,
+ rank=rank,
+ num_replicas=world_size,
+ drop_last=True, # important?
+ )
+ train_loader = torch.utils.data.DataLoader(train_set,
+ batch_size=config['batch_size'] // world_size,
+ num_workers=config['workers'] // world_size,
+ # num_workers=1,
+ pin_memory=True,
+ # persistent_workers=True,
+ shuffle=False, # must be False
+ drop_last=True,
+ collate_fn=collect_batch,
+ prefetch_factor=4,
+ sampler=train_sampler)
+ config['local_rank'] = rank
+
+    # only rank 0 runs evaluation; the other ranks skip the test set
+    if rank != 0:
+        test_set = None
+
+ trainer = Trainer(model=model, train_loader=train_loader, feat_model=feat_model, eval_loader=test_set,
+ config=config, img_transforms=img_transforms)
+ trainer.train()
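+    # editorial note: batch_size and workers are divided by world_size, so the
+    # effective global batch size matches config['batch_size']; shuffling is
+    # delegated to DistributedSampler, hence shuffle=False in the DataLoader.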
+
+
+if __name__ == '__main__':
+ args = parser.parse_args()
+ with open(args.config, 'rt') as f:
+ config = yaml.load(f, Loader=yaml.Loader)
+ torch_set_gpu(gpus=config['gpu'])
+ if config['local_rank'] == 0:
+ print(config)
+
+ if config['feature'] == 'spp':
+ img_transforms = None
+ else:
+ img_transforms = []
+ img_transforms.append(tvt.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]))
+ img_transforms = tvt.Compose(img_transforms)
+ feat_model, desc_compressor = load_feat_network(config=config)
+
+ dataset = config['dataset']
+ if config['eval'] or config['loc']:
+ if not config['online']:
+ from localization.loc_by_rec_eval import loc_by_rec_eval
+
+ test_set = compose_datasets(datasets=dataset, config=config, train=False, sample_ratio=1)
+ config['n_class'] = test_set.n_class
+
+ model = get_model(config=config)
+ loc_by_rec_eval(rec_model=model.cuda().eval(),
+ loader=test_set,
+ local_feat=feat_model.cuda().eval(),
+ config=config, img_transforms=img_transforms)
+ else:
+ from localization.loc_by_rec_online import loc_by_rec_online
+
+ model = get_model(config=config)
+ loc_by_rec_online(rec_model=model.cuda().eval(),
+ local_feat=feat_model.cuda().eval(),
+ config=config, img_transforms=img_transforms)
+ exit(0)
+
+ train_set = compose_datasets(datasets=dataset, config=config, train=True, sample_ratio=None)
+ if config['do_eval']:
+ test_set = compose_datasets(datasets=dataset, config=config, train=False, sample_ratio=None)
+ else:
+ test_set = None
+ config['n_class'] = train_set.n_class
+ model = get_model(config=config)
+
+ if not config['with_dist'] or len(config['gpu']) == 1:
+ config['with_dist'] = False
+ model = model.cuda()
+ train_loader = Data.DataLoader(dataset=train_set,
+ shuffle=True,
+ batch_size=config['batch_size'],
+ drop_last=True,
+ collate_fn=collect_batch,
+ num_workers=config['workers'])
+ if test_set is not None:
+ test_loader = Data.DataLoader(dataset=test_set,
+ shuffle=False,
+ batch_size=1,
+ drop_last=False,
+ collate_fn=collect_batch,
+ num_workers=4)
+ else:
+ test_loader = None
+ trainer = Trainer(model=model, train_loader=train_loader, feat_model=feat_model, eval_loader=test_loader,
+ config=config, img_transforms=img_transforms)
+ trainer.train()
+ else:
+ mp.spawn(train_DDP, nprocs=len(config['gpu']),
+ args=(len(config['gpu']), model, config, train_set, test_set, feat_model, img_transforms),
+ join=True)
diff --git a/third_party/pram/nets/adagml.py b/third_party/pram/nets/adagml.py
new file mode 100644
index 0000000000000000000000000000000000000000..c6980334a8980a105dc91d4586b3a342fb4e648e
--- /dev/null
+++ b/third_party/pram/nets/adagml.py
@@ -0,0 +1,536 @@
+# -*- coding: UTF-8 -*-
+'''=================================================
+@Project -> File pram -> adagml
+@IDE PyCharm
+@Author fx221@cam.ac.uk
+@Date 11/02/2024 14:29
+=================================================='''
+import torch
+from torch import nn
+import torch.nn.functional as F
+from typing import Callable
+import time
+import numpy as np
+
+torch.backends.cudnn.deterministic = True
+
+eps = 1e-8
+
+
+def arange_like(x, dim: int):
+ return x.new_ones(x.shape[dim]).cumsum(0) - 1 # traceable in 1.1
+
+
+def dual_softmax(M, dustbin):
+ M = torch.cat([M, dustbin.expand([M.shape[0], M.shape[1], 1])], dim=-1)
+ M = torch.cat([M, dustbin.expand([M.shape[0], 1, M.shape[2]])], dim=-2)
+ score = torch.log_softmax(M, dim=-1) + torch.log_softmax(M, dim=1)
+ return torch.exp(score)
+
+
+def sinkhorn(M, r, c, iteration):
+ p = torch.softmax(M, dim=-1)
+ u = torch.ones_like(r)
+ v = torch.ones_like(c)
+ for _ in range(iteration):
+ u = r / ((p * v.unsqueeze(-2)).sum(-1) + eps)
+ v = c / ((p * u.unsqueeze(-1)).sum(-2) + eps)
+ p = p * u.unsqueeze(-1) * v.unsqueeze(-2)
+ return p
+
+
+def sink_algorithm(M, dustbin, iteration):
+ M = torch.cat([M, dustbin.expand([M.shape[0], M.shape[1], 1])], dim=-1)
+ M = torch.cat([M, dustbin.expand([M.shape[0], 1, M.shape[2]])], dim=-2)
+    # the row/column marginals live on the same device as the score matrix,
+    # so the module also runs on CPU
+    r = torch.ones([M.shape[0], M.shape[1] - 1], device=M.device)
+    r = torch.cat([r, torch.ones([M.shape[0], 1], device=M.device) * M.shape[1]], dim=-1)
+    c = torch.ones([M.shape[0], M.shape[2] - 1], device=M.device)
+    c = torch.cat([c, torch.ones([M.shape[0], 1], device=M.device) * M.shape[2]], dim=-1)
+ p = sinkhorn(M, r, c, iteration)
+ return p
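+
+# Editorial illustration (not upstream code): sink_algorithm() appends a dustbin
+# row/column to the m x n score matrix and runs Sinkhorn updates
+#   u_i = r_i / (P v)_i,   v_j = c_j / (P^T u)_j,   P <- diag(u) P diag(v)
+# so that rows sum to r and columns to c. Expected shapes:
+#
+#   M = torch.randn(1, 100, 120, device='cuda')
+#   P = sink_algorithm(M, dustbin=torch.tensor(1., device='cuda'), iteration=20)
+#   P.shape  # -> torch.Size([1, 101, 121]); the extra row/column are the dustbins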
+
+
+def normalize_keypoints(kpts, image_shape):
+ """ Normalize keypoints locations based on image image_shape"""
+ _, _, height, width = image_shape
+ one = kpts.new_tensor(1)
+ size = torch.stack([one * width, one * height])[None]
+ center = size / 2
+ scaling = size.max(1, keepdim=True).values * 0.7
+ return (kpts - center[:, None, :]) / scaling[:, None, :]
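+
+# Worked example (editorial): for an image batch of shape [1, 3, 480, 640],
+# size = (640, 480), center = (320, 240) and scaling = 0.7 * 640 = 448, so a
+# keypoint at (320, 240) maps to (0, 0) and one at (640, 240) maps to
+# (~0.714, 0); normalized coordinates therefore stay roughly within [-1, 1].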
+
+
+def rotate_half(x: torch.Tensor) -> torch.Tensor:
+ x = x.unflatten(-1, (-1, 2))
+ x1, x2 = x.unbind(dim=-1)
+ return torch.stack((-x2, x1), dim=-1).flatten(start_dim=-2)
+
+
+def apply_cached_rotary_emb(
+ freqs: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
+ return (t * freqs[0]) + (rotate_half(t) * freqs[1])
+
+
+class LearnableFourierPositionalEncoding(nn.Module):
+ def __init__(self, M: int, dim: int, F_dim: int = None,
+ gamma: float = 1.0) -> None:
+ super().__init__()
+ F_dim = F_dim if F_dim is not None else dim
+ self.gamma = gamma
+ self.Wr = nn.Linear(M, F_dim // 2, bias=False)
+ nn.init.normal_(self.Wr.weight.data, mean=0, std=self.gamma ** -2)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """ encode position vector """
+ projected = self.Wr(x)
+ cosines, sines = torch.cos(projected), torch.sin(projected)
+ emb = torch.stack([cosines, sines], 0).unsqueeze(-3)
+ return emb.repeat_interleave(2, dim=-1)
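+
+    # Shape sketch (editorial): with M=2 and F_dim equal to the per-head
+    # dimension, keypoints of shape [B, N, 2] produce an embedding of shape
+    # [2, B, 1, N, head_dim] (cos/sin stacked on dim 0), which
+    # apply_cached_rotary_emb() broadcasts over the attention heads.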
+
+
+class KeypointEncoder(nn.Module):
+ """ Joint encoding of visual appearance and location using MLPs"""
+
+ def __init__(self):
+ super().__init__()
+ self.encoder = nn.Sequential(
+ nn.Linear(3, 32),
+ nn.LayerNorm(32, elementwise_affine=True),
+ nn.GELU(),
+ nn.Linear(32, 64),
+ nn.LayerNorm(64, elementwise_affine=True),
+ nn.GELU(),
+ nn.Linear(64, 128),
+ nn.LayerNorm(128, elementwise_affine=True),
+ nn.GELU(),
+ nn.Linear(128, 256),
+ )
+
+ def forward(self, kpts, scores):
+ inputs = [kpts, scores.unsqueeze(2)] # [B, N, 2] + [B, N, 1]
+ return self.encoder(torch.cat(inputs, dim=-1))
+
+
+class PoolingLayer(nn.Module):
+ def __init__(self, hidden_dim: int, score_dim: int = 2):
+ super().__init__()
+
+ self.score_enc = nn.Sequential(
+ nn.Linear(score_dim, hidden_dim),
+ nn.LayerNorm(hidden_dim, elementwise_affine=True),
+ nn.GELU(),
+ nn.Linear(hidden_dim, hidden_dim),
+ )
+ self.proj = nn.Linear(hidden_dim, hidden_dim)
+ self.predict = nn.Sequential(
+ nn.Linear(hidden_dim * 2, hidden_dim),
+ nn.LayerNorm(hidden_dim, elementwise_affine=True),
+ nn.GELU(),
+ nn.Linear(hidden_dim, 1),
+ )
+
+ def forward(self, x, score):
+ score_ = self.score_enc(score)
+ x_ = self.proj(x)
+ confidence = self.predict(torch.cat([x_, score_], -1))
+ confidence = torch.sigmoid(confidence)
+
+ return confidence
+
+
+class Attention(nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, q, k, v):
+ s = q.shape[-1] ** -0.5
+ attn = F.softmax(torch.einsum('...id,...jd->...ij', q, k) * s, -1)
+ return torch.einsum('...ij,...jd->...id', attn, v), torch.mean(torch.mean(attn, dim=1), dim=1)
+
+
+class SelfMultiHeadAttention(nn.Module):
+ def __init__(self, feat_dim: int, hidden_dim: int, num_heads: int):
+ super().__init__()
+ self.feat_dim = feat_dim
+ self.num_heads = num_heads
+
+ assert feat_dim % num_heads == 0
+ self.head_dim = feat_dim // num_heads
+ self.qkv = nn.Linear(feat_dim, hidden_dim * 3)
+ self.attn = Attention()
+ self.proj = nn.Linear(hidden_dim, hidden_dim)
+ self.mlp = nn.Sequential(
+ nn.Linear(feat_dim + hidden_dim, feat_dim * 2),
+ nn.LayerNorm(feat_dim * 2, elementwise_affine=True),
+ nn.GELU(),
+ nn.Linear(feat_dim * 2, feat_dim)
+ )
+
+ def forward_(self, x, encoding=None):
+ qkv = self.qkv(x)
+ qkv = qkv.unflatten(-1, (self.num_heads, -1, 3)).transpose(1, 2)
+ q, k, v = qkv[..., 0], qkv[..., 1], qkv[..., 2]
+ if encoding is not None:
+ q = apply_cached_rotary_emb(encoding, q)
+ k = apply_cached_rotary_emb(encoding, k)
+ attn, attn_score = self.attn(q, k, v)
+ message = self.proj(attn.transpose(1, 2).flatten(start_dim=-2))
+ return x + self.mlp(torch.cat([x, message], -1)), attn_score
+
+ def forward(self, x0, x1, encoding0=None, encoding1=None):
+ x0_, att_score00 = self.forward_(x=x0, encoding=encoding0)
+ x1_, att_score11 = self.forward_(x=x1, encoding=encoding1)
+ return x0_, x1_, att_score00, att_score11
+
+
+class CrossMultiHeadAttention(nn.Module):
+ def __init__(self, feat_dim: int, hidden_dim: int, num_heads: int):
+ super().__init__()
+ self.feat_dim = feat_dim
+ self.num_heads = num_heads
+ assert hidden_dim % num_heads == 0
+ dim_head = hidden_dim // num_heads
+ self.scale = dim_head ** -0.5
+ self.to_qk = nn.Linear(feat_dim, hidden_dim)
+ self.to_v = nn.Linear(feat_dim, hidden_dim)
+ self.proj = nn.Linear(hidden_dim, hidden_dim)
+ self.mlp = nn.Sequential(
+ nn.Linear(feat_dim + hidden_dim, feat_dim * 2),
+ nn.LayerNorm(feat_dim * 2, elementwise_affine=True),
+ nn.GELU(),
+ nn.Linear(feat_dim * 2, feat_dim),
+ )
+
+ def map_(self, func: Callable, x0: torch.Tensor, x1: torch.Tensor):
+ return func(x0), func(x1)
+
+ def forward(self, x0, x1):
+ qk0 = self.to_qk(x0)
+ qk1 = self.to_qk(x1)
+ v0 = self.to_v(x0)
+ v1 = self.to_v(x1)
+
+ qk0, qk1, v0, v1 = map(
+ lambda t: t.unflatten(-1, (self.num_heads, -1)).transpose(1, 2),
+ (qk0, qk1, v0, v1))
+
+ qk0, qk1 = qk0 * self.scale ** 0.5, qk1 * self.scale ** 0.5
+ sim = torch.einsum('b h i d, b h j d -> b h i j', qk0, qk1)
+ attn01 = F.softmax(sim, dim=-1)
+ attn10 = F.softmax(sim.transpose(-2, -1).contiguous(), dim=-1)
+ m0 = torch.einsum('bhij, bhjd -> bhid', attn01, v1)
+ m1 = torch.einsum('bhji, bhjd -> bhid', attn10.transpose(-2, -1), v0)
+
+ m0, m1 = self.map_(lambda t: t.transpose(1, 2).flatten(start_dim=-2),
+ m0, m1)
+ m0, m1 = self.map_(self.proj, m0, m1)
+ x0 = x0 + self.mlp(torch.cat([x0, m0], -1))
+ x1 = x1 + self.mlp(torch.cat([x1, m1], -1))
+ return x0, x1, torch.mean(torch.mean(attn10, dim=1), dim=1), torch.mean(torch.mean(attn01, dim=1), dim=1)
+
+
+class AdaGML(nn.Module):
+ default_config = {
+ 'descriptor_dim': 128,
+ 'hidden_dim': 256,
+ 'weights': 'indoor',
+ 'keypoint_encoder': [32, 64, 128, 256],
+        'GNN_layers': ['self', 'cross'] * 9,  # nine interleaved self/cross pairs (18 layers)
+ 'sinkhorn_iterations': 20,
+ 'match_threshold': 0.2,
+ 'with_pose': True,
+ 'n_layers': 9,
+ 'n_min_tokens': 256,
+ 'with_sinkhorn': True,
+ 'min_confidence': 0.9,
+
+ 'classification_background_weight': 0.05,
+ 'pretrained': True,
+ }
+
+ def __init__(self, config):
+ super().__init__()
+ self.config = {**self.default_config, **config}
+ self.n_layers = self.config['n_layers']
+ self.first_layer_pooling = 0
+ self.n_min_tokens = self.config['n_min_tokens']
+ self.min_confidence = self.config['min_confidence']
+ self.classification_background_weight = self.config['classification_background_weight']
+
+ self.with_sinkhorn = self.config['with_sinkhorn']
+ self.match_threshold = self.config['match_threshold']
+ self.sinkhorn_iterations = self.config['sinkhorn_iterations']
+
+ self.input_proj = nn.Linear(self.config['descriptor_dim'], self.config['hidden_dim'])
+
+ self.self_attn = nn.ModuleList(
+ [SelfMultiHeadAttention(feat_dim=self.config['hidden_dim'],
+ hidden_dim=self.config['hidden_dim'],
+ num_heads=4) for _ in range(self.n_layers)]
+ )
+ self.cross_attn = nn.ModuleList(
+ [CrossMultiHeadAttention(feat_dim=self.config['hidden_dim'],
+ hidden_dim=self.config['hidden_dim'],
+ num_heads=4) for _ in range(self.n_layers)]
+ )
+
+ head_dim = self.config['hidden_dim'] // 4
+ self.poseenc = LearnableFourierPositionalEncoding(2, head_dim, head_dim)
+ self.out_proj = nn.ModuleList(
+ [nn.Linear(self.config['hidden_dim'], self.config['hidden_dim']) for _ in range(self.n_layers)]
+ )
+
+ bin_score = torch.nn.Parameter(torch.tensor(1.))
+ self.register_parameter('bin_score', bin_score)
+
+ self.pooling = nn.ModuleList(
+ [PoolingLayer(score_dim=2, hidden_dim=self.config['hidden_dim']) for _ in range(self.n_layers)]
+ )
+ # self.pretrained = config['pretrained']
+ # if self.pretrained:
+ # bin_score.requires_grad = False
+ # for m in [self.input_proj, self.out_proj, self.poseenc, self.self_attn, self.cross_attn]:
+ # for p in m.parameters():
+ # p.requires_grad = False
+
+ def forward(self, data, mode=0):
+ if not self.training:
+ if mode == 0:
+ return self.produce_matches(data=data)
+ else:
+ return self.run(data=data)
+ return self.forward_train(data=data)
+
+ def forward_train(self, data: dict, p=0.2, **kwargs):
+ pass
+
+ def produce_matches(self, data: dict, p: float = 0.2, **kwargs):
+ desc0, desc1 = data['descriptors0'], data['descriptors1']
+ kpts0, kpts1 = data['keypoints0'], data['keypoints1']
+ scores0, scores1 = data['scores0'], data['scores1']
+
+ # Keypoint normalization.
+ if 'norm_keypoints0' in data.keys() and 'norm_keypoints1' in data.keys():
+ norm_kpts0 = data['norm_keypoints0']
+ norm_kpts1 = data['norm_keypoints1']
+ elif 'image0' in data.keys() and 'image1' in data.keys():
+ norm_kpts0 = normalize_keypoints(kpts0, data['image0'].shape)
+ norm_kpts1 = normalize_keypoints(kpts1, data['image1'].shape)
+ elif 'image_shape0' in data.keys() and 'image_shape1' in data.keys():
+ norm_kpts0 = normalize_keypoints(kpts0, data['image_shape0'])
+ norm_kpts1 = normalize_keypoints(kpts1, data['image_shape1'])
+ else:
+ raise ValueError('Require image shape for keypoint coordinate normalization')
+
+ desc0 = desc0.detach() # [B, N, D]
+ desc1 = desc1.detach()
+
+ desc0 = self.input_proj(desc0)
+ desc1 = self.input_proj(desc1)
+ enc0 = self.poseenc(norm_kpts0)
+ enc1 = self.poseenc(norm_kpts1)
+
+ nI = self.config['n_layers']
+ nB = desc0.shape[0]
+ m = desc0.shape[1]
+ n = desc1.shape[1]
+ dev = desc0.device
+
+ ind0 = torch.arange(0, m, device=dev)[None]
+ ind1 = torch.arange(0, n, device=dev)[None]
+
+ do_pooling = True
+
+ for ni in range(nI):
+ desc0, desc1, att_score00, att_score11 = self.self_attn[ni](desc0, desc1, enc0, enc1)
+ desc0, desc1, att_score01, att_score10 = self.cross_attn[ni](desc0, desc1)
+
+ att_score0 = torch.cat([att_score00.unsqueeze(-1), att_score01.unsqueeze(-1)], dim=-1)
+ att_score1 = torch.cat([att_score11.unsqueeze(-1), att_score10.unsqueeze(-1)], dim=-1)
+
+ conf0 = self.pooling[ni](desc0, att_score0).squeeze(-1)
+ conf1 = self.pooling[ni](desc1, att_score1).squeeze(-1)
+
+ if do_pooling and ni >= 1:
+ if desc0.shape[1] >= self.n_min_tokens:
+ mask0 = conf0 > self.confidence_threshold(layer_index=ni)
+ ind0 = ind0[mask0][None]
+ desc0 = desc0[mask0][None]
+ enc0 = enc0[:, :, mask0][:, None]
+
+ if desc1.shape[1] >= self.n_min_tokens:
+ mask1 = conf1 > self.confidence_threshold(layer_index=ni)
+ ind1 = ind1[mask1][None]
+ desc1 = desc1[mask1][None]
+ enc1 = enc1[:, :, mask1][:, None]
+
+ # print('pooling: ', ni, desc0.shape, desc1.shape)
+ # print('ni: {:d}: pooling: {:.4f}'.format(ni, time.time() - t_start))
+ # t_start = time.time()
+ if self.check_if_stop(confidences0=conf0, confidences1=conf1, layer_index=ni, num_points=m + n):
+ # print('ni:{:d}: checking: {:.4f}'.format(ni, time.time() - t_start))
+ break
+
+ if ni == nI: ni = nI - 1
+ d = desc0.shape[-1]
+ mdesc0 = self.out_proj[ni](desc0) / d ** .25
+ mdesc1 = self.out_proj[ni](desc1) / d ** .25
+
+ dist = torch.einsum('bmd,bnd->bmn', mdesc0, mdesc1)
+ score = self.compute_score(dist=dist, dustbin=self.bin_score, iteration=self.sinkhorn_iterations)
+ indices0, indices1, mscores0, mscores1 = self.compute_matches(scores=score, p=p)
+ valid = indices0 > -1
+ m_indices0 = torch.where(valid)[1]
+ m_indices1 = indices0[valid]
+
+ mind0 = ind0[0, m_indices0]
+ mind1 = ind1[0, m_indices1]
+
+ indices0_full = torch.full((nB, m), -1, device=dev, dtype=indices0.dtype)
+ indices0_full[:, mind0] = mind1
+
+ mscores0_full = torch.zeros((nB, m), device=dev)
+ mscores0_full[:, ind0] = mscores0
+
+ indices0 = indices0_full
+ mscores0 = mscores0_full
+
+ output = {
+ 'matches0': indices0, # use -1 for invalid match
+ # 'matches1': indices1, # use -1 for invalid match
+ 'matching_scores0': mscores0,
+ }
+
+ return output
+
+ def run(self, data, p=0.2):
+ desc0 = data['desc1']
+ # print('desc0: ', torch.sum(desc0 ** 2, dim=-1))
+ # desc0 = torch.nn.functional.normalize(desc0, dim=-1)
+ desc0 = desc0.detach()
+
+ desc1 = data['desc2']
+ # desc1 = torch.nn.functional.normalize(desc1, dim=-1)
+ desc1 = desc1.detach()
+
+ kpts0 = data['x1'][:, :, :2]
+ kpts1 = data['x2'][:, :, :2]
+ # kpts0 = normalize_keypoints(kpts=kpts0, image_shape=data['image_shape1'])
+ # kpts1 = normalize_keypoints(kpts=kpts1, image_shape=data['image_shape2'])
+ scores0 = data['x1'][:, :, -1]
+ scores1 = data['x2'][:, :, -1]
+
+ desc0 = self.input_proj(desc0)
+ desc1 = self.input_proj(desc1)
+ enc0 = self.poseenc(kpts0)
+ enc1 = self.poseenc(kpts1)
+
+ nB = desc0.shape[0]
+ nI = self.n_layers
+ m, n = desc0.shape[1], desc1.shape[1]
+ dev = desc0.device
+ ind0 = torch.arange(0, m, device=dev)[None]
+ ind1 = torch.arange(0, n, device=dev)[None]
+ do_pooling = True
+
+ for ni in range(nI):
+ desc0, desc1, att_score00, att_score11 = self.self_attn[ni](desc0, desc1, enc0, enc1)
+ desc0, desc1, att_score01, att_score10 = self.cross_attn[ni](desc0, desc1)
+
+ att_score0 = torch.cat([att_score00.unsqueeze(-1), att_score01.unsqueeze(-1)], dim=-1)
+ att_score1 = torch.cat([att_score11.unsqueeze(-1), att_score10.unsqueeze(-1)], dim=-1)
+
+ conf0 = self.pooling[ni](desc0, att_score0).squeeze(-1)
+ conf1 = self.pooling[ni](desc1, att_score1).squeeze(-1)
+
+ if do_pooling and ni >= 1:
+ if desc0.shape[1] >= self.n_min_tokens:
+ mask0 = conf0 > self.confidence_threshold(layer_index=ni)
+ ind0 = ind0[mask0][None]
+ desc0 = desc0[mask0][None]
+ enc0 = enc0[:, :, mask0][:, None]
+
+ if desc1.shape[1] >= self.n_min_tokens:
+ mask1 = conf1 > self.confidence_threshold(layer_index=ni)
+ ind1 = ind1[mask1][None]
+ desc1 = desc1[mask1][None]
+ enc1 = enc1[:, :, mask1][:, None]
+ if desc0.shape[1] <= 5 or desc1.shape[1] <= 5:
+ return {
+ 'index0': torch.zeros(size=(1,), device=desc0.device).long(),
+ 'index1': torch.zeros(size=(1,), device=desc1.device).long(),
+ }
+
+ if self.check_if_stop(confidences0=conf0, confidences1=conf1, layer_index=ni,
+ num_points=m + n):
+ break
+
+ if ni == nI: ni = -1
+ d = desc0.shape[-1]
+ mdesc0 = self.out_proj[ni](desc0) / d ** .25
+ mdesc1 = self.out_proj[ni](desc1) / d ** .25
+
+ dist = torch.einsum('bmd,bnd->bmn', mdesc0, mdesc1)
+ score = self.compute_score(dist=dist, dustbin=self.bin_score, iteration=self.sinkhorn_iterations)
+ indices0, indices1, mscores0, mscores1 = self.compute_matches(scores=score, p=p)
+ valid = indices0 > -1
+ m_indices0 = torch.where(valid)[1]
+ m_indices1 = indices0[valid]
+
+ mind0 = ind0[0, m_indices0]
+ mind1 = ind1[0, m_indices1]
+
+ output = {
+ # 'p': score,
+ 'index0': mind0,
+ 'index1': mind1,
+ }
+
+ return output
+
+ def compute_score(self, dist, dustbin, iteration):
+ if self.with_sinkhorn:
+ score = sink_algorithm(M=dist, dustbin=dustbin,
+ iteration=iteration) # [nI * nB, N, M]
+ else:
+ score = dual_softmax(M=dist, dustbin=dustbin)
+ return score
+
+ def compute_matches(self, scores, p=0.2):
+ max0, max1 = scores[:, :-1, :-1].max(2), scores[:, :-1, :-1].max(1)
+ indices0, indices1 = max0.indices, max1.indices
+ mutual0 = arange_like(indices0, 1)[None] == indices1.gather(1, indices0)
+ mutual1 = arange_like(indices1, 1)[None] == indices0.gather(1, indices1)
+ zero = scores.new_tensor(0)
+ # mscores0 = torch.where(mutual0, max0.values.exp(), zero)
+ mscores0 = torch.where(mutual0, max0.values, zero)
+ mscores1 = torch.where(mutual1, mscores0.gather(1, indices1), zero)
+ # valid0 = mutual0 & (mscores0 > self.config['match_threshold'])
+ valid0 = mutual0 & (mscores0 > p)
+ valid1 = mutual1 & valid0.gather(1, indices1)
+ indices0 = torch.where(valid0, indices0, indices0.new_tensor(-1))
+ indices1 = torch.where(valid1, indices1, indices1.new_tensor(-1))
+
+ return indices0, indices1, mscores0, mscores1
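+
+    # Editorial illustration of the mutual check above (not upstream code):
+    # indices0[b, i] is the best column for row i and indices1[b, j] the best
+    # row for column j; a pair survives only if the two point at each other
+    # and its score exceeds p. For example, with
+    #   scores[0, :-1, :-1] = [[0.9, 0.1],
+    #                          [0.2, 0.6]]
+    # indices0 = [0, 1] and indices1 = [0, 1], so both (0, 0) and (1, 1) are
+    # mutual and, with p = 0.2, both matches are kept.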
+
+ def confidence_threshold(self, layer_index: int):
+ """scaled confidence threshold"""
+ # threshold = 0.8 + 0.1 * np.exp(-4.0 * layer_index / self.n_layers)
+ threshold = 0.5 + 0.1 * np.exp(-4.0 * layer_index / self.n_layers)
+ return np.clip(threshold, 0, 1)
+
+ def check_if_stop(self,
+ confidences0: torch.Tensor,
+ confidences1: torch.Tensor,
+ layer_index: int, num_points: int) -> torch.Tensor:
+ """ evaluate stopping condition"""
+ confidences = torch.cat([confidences0, confidences1], -1)
+ threshold = self.confidence_threshold(layer_index)
+ pos = 1.0 - (confidences < threshold).float().sum() / num_points
+ # print('check_stop: ', pos)
+ return pos > 0.95
+
+ def stop_iteration(self, m_last, n_last, m_current, n_current, confidence=0.975):
+ prob = (m_current + n_current) / (m_last + n_last)
+ # print('prob: ', prob)
+ return prob > confidence
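+
+    # Editorial note on the adaptive behaviour: with n_layers = 9 the
+    # confidence threshold decays as
+    #   layer 0: 0.5 + 0.1 * exp(0)       = 0.600
+    #   layer 4: 0.5 + 0.1 * exp(-16/9)  ~= 0.517
+    #   layer 8: 0.5 + 0.1 * exp(-32/9)  ~= 0.503
+    # check_if_stop() ends the attention loop once more than 95% of the
+    # remaining keypoints exceed the threshold, and the pooling step prunes
+    # low-confidence keypoints only while at least n_min_tokens (256) remain.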
diff --git a/third_party/pram/nets/gm.py b/third_party/pram/nets/gm.py
new file mode 100644
index 0000000000000000000000000000000000000000..232a364ce60acb49cb6af26b72a881cbec18c1a9
--- /dev/null
+++ b/third_party/pram/nets/gm.py
@@ -0,0 +1,264 @@
+# -*- coding: UTF-8 -*-
+'''=================================================
+@Project -> File pram -> gm
+@IDE PyCharm
+@Author fx221@cam.ac.uk
+@Date 07/02/2024 10:47
+=================================================='''
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from nets.layers import KeypointEncoder, AttentionalPropagation
+from nets.utils import normalize_keypoints, arange_like
+
+eps = 1e-8
+
+
+def dual_softmax(M, dustbin):
+ M = torch.cat([M, dustbin.expand([M.shape[0], M.shape[1], 1])], dim=-1)
+ M = torch.cat([M, dustbin.expand([M.shape[0], 1, M.shape[2]])], dim=-2)
+ score = torch.log_softmax(M, dim=-1) + torch.log_softmax(M, dim=1)
+ return torch.exp(score)
+
+
+def sinkhorn(M, r, c, iteration):
+ p = torch.softmax(M, dim=-1)
+ u = torch.ones_like(r)
+ v = torch.ones_like(c)
+ for _ in range(iteration):
+ u = r / ((p * v.unsqueeze(-2)).sum(-1) + eps)
+ v = c / ((p * u.unsqueeze(-1)).sum(-2) + eps)
+ p = p * u.unsqueeze(-1) * v.unsqueeze(-2)
+ return p
+
+
+def sink_algorithm(M, dustbin, iteration):
+ M = torch.cat([M, dustbin.expand([M.shape[0], M.shape[1], 1])], dim=-1)
+ M = torch.cat([M, dustbin.expand([M.shape[0], 1, M.shape[2]])], dim=-2)
+    # the row/column marginals live on the same device as the score matrix,
+    # so the module also runs on CPU
+    r = torch.ones([M.shape[0], M.shape[1] - 1], device=M.device)
+    r = torch.cat([r, torch.ones([M.shape[0], 1], device=M.device) * M.shape[1]], dim=-1)
+    c = torch.ones([M.shape[0], M.shape[2] - 1], device=M.device)
+    c = torch.cat([c, torch.ones([M.shape[0], 1], device=M.device) * M.shape[2]], dim=-1)
+ p = sinkhorn(M, r, c, iteration)
+ return p
+
+
+class AttentionalGNN(nn.Module):
+ def __init__(self, feature_dim: int, layer_names: list, hidden_dim: int = 256, ac_fn: str = 'relu',
+ norm_fn: str = 'bn'):
+ super().__init__()
+ self.layers = nn.ModuleList([
+ AttentionalPropagation(feature_dim=feature_dim, num_heads=4, hidden_dim=hidden_dim, ac_fn=ac_fn,
+ norm_fn=norm_fn)
+ for _ in range(len(layer_names))])
+ self.names = layer_names
+
+ def forward(self, desc0, desc1):
+ # desc0s = []
+ # desc1s = []
+
+ for i, (layer, name) in enumerate(zip(self.layers, self.names)):
+ if name == 'cross':
+ src0, src1 = desc1, desc0
+ else:
+ src0, src1 = desc0, desc1
+ delta0 = layer(desc0, src0)
+ # prob0 = layer.attn.prob
+ delta1 = layer(desc1, src1)
+ # prob1 = layer.attn.prob
+ desc0, desc1 = (desc0 + delta0), (desc1 + delta1)
+
+ # if name == 'cross':
+ # desc0s.append(desc0)
+ # desc1s.append(desc1)
+ return [desc0], [desc1]
+
+ def predict(self, desc0, desc1, n_it=-1):
+ for i, (layer, name) in enumerate(zip(self.layers, self.names)):
+ if name == 'cross':
+ src0, src1 = desc1, desc0
+ else:
+ src0, src1 = desc0, desc1
+ delta0 = layer(desc0, src0)
+ # prob0 = layer.attn.prob
+ delta1 = layer(desc1, src1)
+ # prob1 = layer.attn.prob
+ desc0, desc1 = (desc0 + delta0), (desc1 + delta1)
+
+ if name == 'cross' and i == n_it:
+ break
+ return [desc0], [desc1]
+
+
+class GM(nn.Module):
+ default_config = {
+ 'descriptor_dim': 128,
+ 'hidden_dim': 256,
+ 'keypoint_encoder': [32, 64, 128, 256],
+        'GNN_layers': ['self', 'cross'] * 9,  # nine interleaved self/cross pairs (18 layers)
+ 'sinkhorn_iterations': 20,
+ 'match_threshold': 0.2,
+ 'with_pose': False,
+ 'n_layers': 9,
+ 'n_min_tokens': 256,
+ 'with_sinkhorn': True,
+
+ 'ac_fn': 'relu',
+ 'norm_fn': 'bn',
+ 'weight_path': None,
+ }
+
+ required_inputs = [
+ 'image0', 'keypoints0', 'scores0', 'descriptors0',
+ 'image1', 'keypoints1', 'scores1', 'descriptors1',
+ ]
+
+ def __init__(self, config):
+ super().__init__()
+ self.config = {**self.default_config, **config}
+ print('gm: ', self.config)
+
+ self.n_layers = self.config['n_layers']
+
+ self.with_sinkhorn = self.config['with_sinkhorn']
+ self.match_threshold = self.config['match_threshold']
+
+ self.sinkhorn_iterations = self.config['sinkhorn_iterations']
+ self.kenc = KeypointEncoder(
+ self.config['descriptor_dim'] if self.config['descriptor_dim'] > 0 else 128,
+ self.config['keypoint_encoder'],
+ ac_fn=self.config['ac_fn'],
+ norm_fn=self.config['norm_fn'])
+ self.gnn = AttentionalGNN(
+ feature_dim=self.config['descriptor_dim'] if self.config['descriptor_dim'] > 0 else 128,
+ hidden_dim=self.config['hidden_dim'],
+ layer_names=self.config['GNN_layers'],
+ ac_fn=self.config['ac_fn'],
+ norm_fn=self.config['norm_fn'],
+ )
+
+ self.final_proj = nn.ModuleList([nn.Conv1d(
+ self.config['descriptor_dim'] if self.config['descriptor_dim'] > 0 else 128,
+ self.config['descriptor_dim'] if self.config['descriptor_dim'] > 0 else 128,
+ kernel_size=1, bias=True) for _ in range(self.n_layers)])
+
+ bin_score = torch.nn.Parameter(torch.tensor(1.))
+ self.register_parameter('bin_score', bin_score)
+
+ self.match_net = None # GraphLoss(config=self.config)
+
+ self.self_prob0 = None
+ self.self_prob1 = None
+ self.cross_prob0 = None
+ self.cross_prob1 = None
+
+ self.desc_compressor = None
+
+ def forward_train(self, data):
+ pass
+
+ def produce_matches(self, data, p=0.2, n_it=-1, **kwargs):
+ kpts0, kpts1 = data['keypoints0'], data['keypoints1']
+ scores0, scores1 = data['scores0'], data['scores1']
+ if kpts0.shape[1] == 0 or kpts1.shape[1] == 0: # no keypoints
+ shape0, shape1 = kpts0.shape[:-1], kpts1.shape[:-1]
+ return {
+ 'matches0': kpts0.new_full(shape0, -1, dtype=torch.int)[0],
+ 'matches1': kpts1.new_full(shape1, -1, dtype=torch.int)[0],
+ 'matching_scores0': kpts0.new_zeros(shape0)[0],
+ 'matching_scores1': kpts1.new_zeros(shape1)[0],
+ 'skip_train': True
+ }
+
+ if 'norm_keypoints0' in data.keys() and 'norm_keypoints1' in data.keys():
+ norm_kpts0 = data['norm_keypoints0']
+ norm_kpts1 = data['norm_keypoints1']
+ elif 'image0' in data.keys() and 'image1' in data.keys():
+ norm_kpts0 = normalize_keypoints(kpts0, data['image0'].shape)
+ norm_kpts1 = normalize_keypoints(kpts1, data['image1'].shape)
+ elif 'image_shape0' in data.keys() and 'image_shape1' in data.keys():
+ norm_kpts0 = normalize_keypoints(kpts0, data['image_shape0'])
+ norm_kpts1 = normalize_keypoints(kpts1, data['image_shape1'])
+ else:
+ raise ValueError('Require image shape for keypoint coordinate normalization')
+
+ # Keypoint MLP encoder.
+ enc0, enc1 = self.encode_keypoint(norm_kpts0=norm_kpts0, norm_kpts1=norm_kpts1, scores0=scores0,
+ scores1=scores1)
+
+ if self.config['descriptor_dim'] > 0:
+ desc0, desc1 = data['descriptors0'], data['descriptors1']
+            # permute [B, N, D] -> [B, D, N] (torch.Tensor.transpose only swaps two dims)
+            desc0 = desc0.permute(0, 2, 1)
+            desc1 = desc1.permute(0, 2, 1)
+ with torch.no_grad():
+ if desc0.shape[1] != self.config['descriptor_dim']:
+ desc0 = self.desc_compressor(desc0)
+ if desc1.shape[1] != self.config['descriptor_dim']:
+ desc1 = self.desc_compressor(desc1)
+ desc0 = desc0 + enc0
+ desc1 = desc1 + enc1
+ else:
+ desc0 = enc0
+ desc1 = enc1
+
+ desc0s, desc1s = self.gnn.predict(desc0, desc1, n_it=n_it)
+
+ mdescs0 = self.final_proj[n_it](desc0s[-1])
+ mdescs1 = self.final_proj[n_it](desc1s[-1])
+ dist = torch.einsum('bdn,bdm->bnm', mdescs0, mdescs1)
+ if self.config['descriptor_dim'] > 0:
+ dist = dist / self.config['descriptor_dim'] ** .5
+ else:
+ dist = dist / 128 ** .5
+ score = self.compute_score(dist=dist, dustbin=self.bin_score, iteration=self.sinkhorn_iterations)
+
+ indices0, indices1, mscores0, mscores1 = self.compute_matches(scores=score, p=p)
+
+ output = {
+ 'matches0': indices0, # use -1 for invalid match
+ 'matches1': indices1, # use -1 for invalid match
+ 'matching_scores0': mscores0,
+ 'matching_scores1': mscores1,
+ }
+
+ return output
+
+ def forward(self, data, mode=0):
+ if not self.training:
+ return self.produce_matches(data=data, n_it=-1)
+ return self.forward_train(data=data)
+
+ def encode_keypoint(self, norm_kpts0, norm_kpts1, scores0, scores1):
+ return self.kenc(norm_kpts0, scores0), self.kenc(norm_kpts1, scores1)
+
+ def compute_distance(self, desc0, desc1, layer_id=-1):
+ mdesc0 = self.final_proj[layer_id](desc0)
+ mdesc1 = self.final_proj[layer_id](desc1)
+ dist = torch.einsum('bdn,bdm->bnm', mdesc0, mdesc1)
+ dist = dist / self.config['descriptor_dim'] ** .5
+ return dist
+
+ def compute_score(self, dist, dustbin, iteration):
+ if self.with_sinkhorn:
+ score = sink_algorithm(M=dist, dustbin=dustbin,
+ iteration=iteration) # [nI * nB, N, M]
+ else:
+ score = dual_softmax(M=dist, dustbin=dustbin)
+ return score
+
+ def compute_matches(self, scores, p=0.2):
+ max0, max1 = scores[:, :-1, :-1].max(2), scores[:, :-1, :-1].max(1)
+ indices0, indices1 = max0.indices, max1.indices
+ mutual0 = arange_like(indices0, 1)[None] == indices1.gather(1, indices0)
+ mutual1 = arange_like(indices1, 1)[None] == indices0.gather(1, indices1)
+ zero = scores.new_tensor(0)
+ # mscores0 = torch.where(mutual0, max0.values.exp(), zero)
+ mscores0 = torch.where(mutual0, max0.values, zero)
+ mscores1 = torch.where(mutual1, mscores0.gather(1, indices1), zero)
+ # valid0 = mutual0 & (mscores0 > self.config['match_threshold'])
+ valid0 = mutual0 & (mscores0 > p)
+ valid1 = mutual1 & valid0.gather(1, indices1)
+ indices0 = torch.where(valid0, indices0, indices0.new_tensor(-1))
+ indices1 = torch.where(valid1, indices1, indices1.new_tensor(-1))
+
+ return indices0, indices1, mscores0, mscores1
diff --git a/third_party/pram/nets/gml.py b/third_party/pram/nets/gml.py
new file mode 100644
index 0000000000000000000000000000000000000000..996de5f01211e0a315f7f9b4ce35d561dfc74b2f
--- /dev/null
+++ b/third_party/pram/nets/gml.py
@@ -0,0 +1,319 @@
+# -*- coding: UTF-8 -*-
+'''=================================================
+@Project -> File pram -> gml
+@IDE PyCharm
+@Author fx221@cam.ac.uk
+@Date 07/02/2024 10:56
+=================================================='''
+import torch
+from torch import nn
+import torch.nn.functional as F
+from typing import Callable
+from .utils import arange_like, normalize_keypoints
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+torch.backends.cudnn.deterministic = True
+
+eps = 1e-8
+
+
+def dual_softmax(M, dustbin):
+ M = torch.cat([M, dustbin.expand([M.shape[0], M.shape[1], 1])], dim=-1)
+ M = torch.cat([M, dustbin.expand([M.shape[0], 1, M.shape[2]])], dim=-2)
+ score = torch.log_softmax(M, dim=-1) + torch.log_softmax(M, dim=1)
+ return torch.exp(score)
+
+
+def sinkhorn(M, r, c, iteration):
+ p = torch.softmax(M, dim=-1)
+ u = torch.ones_like(r)
+ v = torch.ones_like(c)
+ for _ in range(iteration):
+ u = r / ((p * v.unsqueeze(-2)).sum(-1) + eps)
+ v = c / ((p * u.unsqueeze(-1)).sum(-2) + eps)
+ p = p * u.unsqueeze(-1) * v.unsqueeze(-2)
+ return p
+
+
+def sink_algorithm(M, dustbin, iteration):
+ M = torch.cat([M, dustbin.expand([M.shape[0], M.shape[1], 1])], dim=-1)
+ M = torch.cat([M, dustbin.expand([M.shape[0], 1, M.shape[2]])], dim=-2)
+ r = torch.ones([M.shape[0], M.shape[1] - 1], device=device)
+ r = torch.cat([r, torch.ones([M.shape[0], 1], device=device) * M.shape[1]], dim=-1)
+ c = torch.ones([M.shape[0], M.shape[2] - 1], device=device)
+ c = torch.cat([c, torch.ones([M.shape[0], 1], device=device) * M.shape[2]], dim=-1)
+ p = sinkhorn(M, r, c, iteration)
+ return p
+
+
+def rotate_half(x: torch.Tensor) -> torch.Tensor:
+ x = x.unflatten(-1, (-1, 2))
+ x1, x2 = x.unbind(dim=-1)
+ return torch.stack((-x2, x1), dim=-1).flatten(start_dim=-2)
+
+
+def apply_cached_rotary_emb(
+ freqs: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
+ return (t * freqs[0]) + (rotate_half(t) * freqs[1])
+
+
+class LearnableFourierPositionalEncoding(nn.Module):
+ def __init__(self, M: int, dim: int, F_dim: int = None,
+ gamma: float = 1.0) -> None:
+ super().__init__()
+ F_dim = F_dim if F_dim is not None else dim
+ self.gamma = gamma
+ self.Wr = nn.Linear(M, F_dim // 2, bias=False)
+ nn.init.normal_(self.Wr.weight.data, mean=0, std=self.gamma ** -2)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """ encode position vector """
+ projected = self.Wr(x)
+ cosines, sines = torch.cos(projected), torch.sin(projected)
+ emb = torch.stack([cosines, sines], 0).unsqueeze(-3)
+ return emb.repeat_interleave(2, dim=-1)
+
+
+class KeypointEncoder(nn.Module):
+ """ Joint encoding of visual appearance and location using MLPs"""
+
+ def __init__(self):
+ super().__init__()
+ self.encoder = nn.Sequential(
+ nn.Linear(3, 32),
+ nn.LayerNorm(32, elementwise_affine=True),
+ nn.GELU(),
+ nn.Linear(32, 64),
+ nn.LayerNorm(64, elementwise_affine=True),
+ nn.GELU(),
+ nn.Linear(64, 128),
+ nn.LayerNorm(128, elementwise_affine=True),
+ nn.GELU(),
+ nn.Linear(128, 256),
+ )
+
+ def forward(self, kpts, scores):
+ inputs = [kpts, scores.unsqueeze(2)] # [B, N, 2] + [B, N, 1]
+ return self.encoder(torch.cat(inputs, dim=-1))
+
+
+class Attention(nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, q, k, v):
+ s = q.shape[-1] ** -0.5
+ attn = F.softmax(torch.einsum('...id,...jd->...ij', q, k) * s, -1)
+ return torch.einsum('...ij,...jd->...id', attn, v)
+
+
+class SelfMultiHeadAttention(nn.Module):
+ def __init__(self, feat_dim: int, hidden_dim: int, num_heads: int):
+ super().__init__()
+ self.feat_dim = feat_dim
+ self.num_heads = num_heads
+
+ assert feat_dim % num_heads == 0
+ self.head_dim = feat_dim // num_heads
+ self.qkv = nn.Linear(feat_dim, hidden_dim * 3)
+ self.attn = Attention()
+ self.proj = nn.Linear(hidden_dim, hidden_dim)
+ self.mlp = nn.Sequential(
+ nn.Linear(feat_dim + hidden_dim, feat_dim * 2),
+ nn.LayerNorm(feat_dim * 2, elementwise_affine=True),
+ nn.GELU(),
+ nn.Linear(feat_dim * 2, feat_dim)
+ )
+
+ def forward_(self, x, encoding=None):
+ qkv = self.qkv(x)
+ qkv = qkv.unflatten(-1, (self.num_heads, -1, 3)).transpose(1, 2)
+ q, k, v = qkv[..., 0], qkv[..., 1], qkv[..., 2]
+ if encoding is not None:
+ q = apply_cached_rotary_emb(encoding, q)
+ k = apply_cached_rotary_emb(encoding, k)
+ attn = self.attn(q, k, v)
+ message = self.proj(attn.transpose(1, 2).flatten(start_dim=-2))
+ return x + self.mlp(torch.cat([x, message], -1))
+
+ def forward(self, x0, x1, encoding0=None, encoding1=None):
+ return self.forward_(x0, encoding0), self.forward_(x1, encoding1)
+
+
+class CrossMultiHeadAttention(nn.Module):
+ def __init__(self, feat_dim: int, hidden_dim: int, num_heads: int):
+ super().__init__()
+ self.feat_dim = feat_dim
+ self.num_heads = num_heads
+ assert hidden_dim % num_heads == 0
+ dim_head = hidden_dim // num_heads
+ self.scale = dim_head ** -0.5
+ self.to_qk = nn.Linear(feat_dim, hidden_dim)
+ self.to_v = nn.Linear(feat_dim, hidden_dim)
+ self.proj = nn.Linear(hidden_dim, hidden_dim)
+ self.mlp = nn.Sequential(
+ nn.Linear(feat_dim + hidden_dim, feat_dim * 2),
+ nn.LayerNorm(feat_dim * 2, elementwise_affine=True),
+ nn.GELU(),
+ nn.Linear(feat_dim * 2, feat_dim),
+ )
+
+ def map_(self, func: Callable, x0: torch.Tensor, x1: torch.Tensor):
+ return func(x0), func(x1)
+
+ def forward(self, x0, x1):
+ qk0 = self.to_qk(x0)
+ qk1 = self.to_qk(x1)
+ v0 = self.to_v(x0)
+ v1 = self.to_v(x1)
+
+ qk0, qk1, v0, v1 = map(
+ lambda t: t.unflatten(-1, (self.num_heads, -1)).transpose(1, 2),
+ (qk0, qk1, v0, v1))
+
+ qk0, qk1 = qk0 * self.scale ** 0.5, qk1 * self.scale ** 0.5
+ sim = torch.einsum('b h i d, b h j d -> b h i j', qk0, qk1)
+ attn01 = F.softmax(sim, dim=-1)
+ attn10 = F.softmax(sim.transpose(-2, -1).contiguous(), dim=-1)
+ m0 = torch.einsum('bhij, bhjd -> bhid', attn01, v1)
+ m1 = torch.einsum('bhji, bhjd -> bhid', attn10.transpose(-2, -1), v0)
+
+ m0, m1 = self.map_(lambda t: t.transpose(1, 2).flatten(start_dim=-2),
+ m0, m1)
+ m0, m1 = self.map_(self.proj, m0, m1)
+ x0 = x0 + self.mlp(torch.cat([x0, m0], -1))
+ x1 = x1 + self.mlp(torch.cat([x1, m1], -1))
+ return x0, x1
+
+
+class GML(nn.Module):
+ '''
+    LightGlue-style architecture, trained with IMP.
+ '''
+ default_config = {
+ 'descriptor_dim': 128,
+ 'hidden_dim': 256,
+ 'weights': 'indoor',
+ 'keypoint_encoder': [32, 64, 128, 256],
+        'GNN_layers': ['self', 'cross'] * 9,  # nine interleaved self/cross pairs (18 layers)
+ 'sinkhorn_iterations': 20,
+ 'match_threshold': 0.2,
+ 'with_pose': False,
+ 'n_layers': 9,
+ 'n_min_tokens': 256,
+ 'with_sinkhorn': True,
+
+ 'ac_fn': 'relu',
+ 'norm_fn': 'bn',
+
+ }
+
+ def __init__(self, config):
+ super().__init__()
+ self.config = {**self.default_config, **config}
+ self.n_layers = self.config['n_layers']
+
+ self.with_sinkhorn = self.config['with_sinkhorn']
+ self.match_threshold = self.config['match_threshold']
+ self.sinkhorn_iterations = self.config['sinkhorn_iterations']
+
+ self.input_proj = nn.Linear(self.config['descriptor_dim'], self.config['hidden_dim'])
+
+ self.self_attn = nn.ModuleList(
+ [SelfMultiHeadAttention(feat_dim=self.config['hidden_dim'],
+ hidden_dim=self.config['hidden_dim'],
+ num_heads=4) for _ in range(self.n_layers)]
+ )
+ self.cross_attn = nn.ModuleList(
+ [CrossMultiHeadAttention(feat_dim=self.config['hidden_dim'],
+ hidden_dim=self.config['hidden_dim'],
+ num_heads=4) for _ in range(self.n_layers)]
+ )
+
+ head_dim = self.config['hidden_dim'] // 4
+ self.poseenc = LearnableFourierPositionalEncoding(2, head_dim, head_dim)
+ self.out_proj = nn.ModuleList(
+ [nn.Linear(self.config['hidden_dim'], self.config['hidden_dim']) for _ in range(self.n_layers)]
+ )
+
+ bin_score = torch.nn.Parameter(torch.tensor(1.))
+ self.register_parameter('bin_score', bin_score)
+
+ def forward(self, data, mode=0):
+ if not self.training:
+ return self.produce_matches(data=data)
+ return self.forward_train(data=data)
+
+ def forward_train(self, data: dict, p=0.2, **kwargs):
+ pass
+
+ def produce_matches(self, data: dict, p=0.2, **kwargs):
+ desc0, desc1 = data['descriptors0'], data['descriptors1']
+ kpts0, kpts1 = data['keypoints0'], data['keypoints1']
+ # Keypoint normalization.
+ if 'norm_keypoints0' in data.keys() and 'norm_keypoints1' in data.keys():
+ norm_kpts0 = data['norm_keypoints0']
+ norm_kpts1 = data['norm_keypoints1']
+ elif 'image0' in data.keys() and 'image1' in data.keys():
+ norm_kpts0 = normalize_keypoints(kpts0, data['image0'].shape).float()
+ norm_kpts1 = normalize_keypoints(kpts1, data['image1'].shape).float()
+ elif 'image_shape0' in data.keys() and 'image_shape1' in data.keys():
+ norm_kpts0 = normalize_keypoints(kpts0, data['image_shape0']).float()
+ norm_kpts1 = normalize_keypoints(kpts1, data['image_shape1']).float()
+ else:
+ raise ValueError('Require image shape for keypoint coordinate normalization')
+
+ desc0 = self.input_proj(desc0)
+ desc1 = self.input_proj(desc1)
+ enc0 = self.poseenc(norm_kpts0)
+ enc1 = self.poseenc(norm_kpts1)
+
+ nI = self.n_layers
+ # nI = 5
+
+ for i in range(nI):
+ desc0, desc1 = self.self_attn[i](desc0, desc1, enc0, enc1)
+ desc0, desc1 = self.cross_attn[i](desc0, desc1)
+
+ d = desc0.shape[-1]
+ mdesc0 = self.out_proj[nI - 1](desc0) / d ** .25
+ mdesc1 = self.out_proj[nI - 1](desc1) / d ** .25
+
+ dist = torch.einsum('bmd,bnd->bmn', mdesc0, mdesc1)
+
+ score = self.compute_score(dist=dist, dustbin=self.bin_score, iteration=self.sinkhorn_iterations)
+ indices0, indices1, mscores0, mscores1 = self.compute_matches(scores=score, p=p)
+
+ output = {
+ 'matches0': indices0, # use -1 for invalid match
+ 'matches1': indices1, # use -1 for invalid match
+ 'matching_scores0': mscores0,
+ 'matching_scores1': mscores1,
+ }
+
+ return output
+
+ def compute_score(self, dist, dustbin, iteration):
+ if self.with_sinkhorn:
+ score = sink_algorithm(M=dist, dustbin=dustbin,
+ iteration=iteration) # [nI * nB, N, M]
+ else:
+ score = dual_softmax(M=dist, dustbin=dustbin)
+ return score
+
+ def compute_matches(self, scores, p=0.2):
+ max0, max1 = scores[:, :-1, :-1].max(2), scores[:, :-1, :-1].max(1)
+ indices0, indices1 = max0.indices, max1.indices
+ mutual0 = arange_like(indices0, 1)[None] == indices1.gather(1, indices0)
+ mutual1 = arange_like(indices1, 1)[None] == indices0.gather(1, indices1)
+ zero = scores.new_tensor(0)
+ # mscores0 = torch.where(mutual0, max0.values.exp(), zero)
+ mscores0 = torch.where(mutual0, max0.values, zero)
+ mscores1 = torch.where(mutual1, mscores0.gather(1, indices1), zero)
+ # valid0 = mutual0 & (mscores0 > self.config['match_threshold'])
+ valid0 = mutual0 & (mscores0 > p)
+ valid1 = mutual1 & valid0.gather(1, indices1)
+ indices0 = torch.where(valid0, indices0, indices0.new_tensor(-1))
+ indices1 = torch.where(valid1, indices1, indices1.new_tensor(-1))
+
+ return indices0, indices1, mscores0, mscores1
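+
+    # Usage sketch (editorial; shapes and values are illustrative, and the
+    # unseen utils.normalize_keypoints is assumed to accept an image shape
+    # tuple). On a CUDA machine move the model and the tensors in `data` to
+    # 'cuda', since sink_algorithm() allocates on the module-level `device`:
+    #
+    #   net = GML({'descriptor_dim': 128}).eval()
+    #   data = {
+    #       'descriptors0': torch.randn(1, 500, 128),
+    #       'descriptors1': torch.randn(1, 600, 128),
+    #       'keypoints0': torch.rand(1, 500, 2) * 640,
+    #       'keypoints1': torch.rand(1, 600, 2) * 640,
+    #       'image_shape0': (1, 3, 480, 640),
+    #       'image_shape1': (1, 3, 480, 640),
+    #   }
+    #   out = net(data)  # 'matches0': [1, 500], 'matches1': [1, 600] (-1 = no match)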
diff --git a/third_party/pram/nets/layers.py b/third_party/pram/nets/layers.py
new file mode 100644
index 0000000000000000000000000000000000000000..417488e6a163327895eb435567c4255c7827bca2
--- /dev/null
+++ b/third_party/pram/nets/layers.py
@@ -0,0 +1,109 @@
+# -*- coding: UTF-8 -*-
+'''=================================================
+@Project -> File pram -> layers
+@IDE PyCharm
+@Author fx221@cam.ac.uk
+@Date 29/01/2024 14:46
+=================================================='''
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from copy import deepcopy
+from einops import rearrange
+
+
+def MLP(channels: list, do_bn=True, ac_fn='relu', norm_fn='bn'):
+ """ Multi-layer perceptron """
+ n = len(channels)
+ layers = []
+ for i in range(1, n):
+ layers.append(
+ nn.Conv1d(channels[i - 1], channels[i], kernel_size=1, bias=True))
+ if i < (n - 1):
+ if norm_fn == 'in':
+ layers.append(nn.InstanceNorm1d(channels[i], eps=1e-3))
+ elif norm_fn == 'bn':
+ layers.append(nn.BatchNorm1d(channels[i], eps=1e-3))
+ if ac_fn == 'relu':
+ layers.append(nn.ReLU())
+ elif ac_fn == 'gelu':
+ layers.append(nn.GELU())
+ elif ac_fn == 'lrelu':
+ layers.append(nn.LeakyReLU(negative_slope=0.1))
+ # if norm_fn == 'ln':
+ # layers.append(nn.LayerNorm(channels[i]))
+ return nn.Sequential(*layers)
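+
+# Usage sketch (editorial): MLP([256, 256, 128]) builds
+# Conv1d(256, 256, 1) -> BatchNorm1d(256) -> ReLU -> Conv1d(256, 128, 1),
+# i.e. a per-point MLP over tensors of shape [B, C, N]; only the final layer
+# is left without normalisation and activation.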
+
+
+class MultiHeadedAttention(nn.Module):
+ def __init__(self, num_heads: int, d_model: int):
+ super().__init__()
+ assert d_model % num_heads == 0
+ self.dim = d_model // num_heads
+ self.num_heads = num_heads
+ self.merge = nn.Conv1d(d_model, d_model, kernel_size=1)
+ self.proj = nn.ModuleList([deepcopy(self.merge) for _ in range(3)])
+
+ def forward(self, query, key, value, M=None):
+ '''
+ :param query: [B, D, N]
+ :param key: [B, D, M]
+ :param value: [B, D, M]
+ :param M: [B, N, M]
+ :return:
+ '''
+
+ batch_dim = query.size(0)
+ query, key, value = [l(x).view(batch_dim, self.dim, self.num_heads, -1)
+ for l, x in zip(self.proj, (query, key, value))] # [B, D, NH, N]
+ dim = query.shape[1]
+ scores = torch.einsum('bdhn,bdhm->bhnm', query, key) / dim ** .5
+
+ if M is not None:
+ # print('M: ', scores.shape, M.shape, torch.sum(M, dim=2))
+ # scores = scores * M[:, None, :, :].expand_as(scores)
+ # with torch.no_grad():
+ mask = (1 - M[:, None, :, :]).repeat(1, scores.shape[1], 1, 1).bool() # [B, H, N, M]
+ scores = scores.masked_fill(mask, -torch.finfo(scores.dtype).max)
+ prob = F.softmax(scores, dim=-1) # * (~mask).float() # * mask.float()
+ else:
+ prob = F.softmax(scores, dim=-1)
+
+ x = torch.einsum('bhnm,bdhm->bdhn', prob, value)
+ self.prob = prob
+
+ out = self.merge(x.contiguous().view(batch_dim, self.dim * self.num_heads, -1))
+
+ return out
+
+
+class AttentionalPropagation(nn.Module):
+ def __init__(self, feature_dim: int, num_heads: int, ac_fn='relu', norm_fn='bn'):
+ super().__init__()
+ self.attn = MultiHeadedAttention(num_heads, feature_dim)
+ self.mlp = MLP([feature_dim * 2, feature_dim * 2, feature_dim], ac_fn=ac_fn, norm_fn=norm_fn)
+ nn.init.constant_(self.mlp[-1].bias, 0.0)
+
+ def forward(self, x, source, M=None):
+ message = self.attn(x, source, source, M=M)
+ self.prob = self.attn.prob
+
+ out = self.mlp(torch.cat([x, message], dim=1))
+ return out
+
+
+class KeypointEncoder(nn.Module):
+ """ Joint encoding of visual appearance and location using MLPs"""
+
+ def __init__(self, input_dim, feature_dim, layers, ac_fn='relu', norm_fn='bn'):
+ super().__init__()
+ self.input_dim = input_dim
+ self.encoder = MLP([input_dim] + layers + [feature_dim], ac_fn=ac_fn, norm_fn=norm_fn)
+ nn.init.constant_(self.encoder[-1].bias, 0.0)
+
+ def forward(self, kpts, scores=None):
+ if self.input_dim == 2:
+ return self.encoder(kpts.transpose(1, 2))
+ else:
+ inputs = [kpts.transpose(1, 2), scores.unsqueeze(1)] # [B, 2, N] + [B, 1, N]
+ return self.encoder(torch.cat(inputs, dim=1))
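+
+    # Editorial note: with input_dim == 2 the encoder consumes only the (x, y)
+    # coordinates; otherwise the detection score is appended as a third input
+    # channel before the shared per-point MLP.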
diff --git a/third_party/pram/nets/load_segnet.py b/third_party/pram/nets/load_segnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..51b8c5bc3fc1c25a8e52dd21cc6f3f4e79b418aa
--- /dev/null
+++ b/third_party/pram/nets/load_segnet.py
@@ -0,0 +1,31 @@
+# -*- coding: UTF-8 -*-
+'''=================================================
+@Project -> File pram -> load_segnet
+@IDE PyCharm
+@Author fx221@cam.ac.uk
+@Date 09/04/2024 15:39
+=================================================='''
+from nets.segnet import SegNet
+from nets.segnetvit import SegNetViT
+
+
+def load_segnet(network, n_class, desc_dim, n_layers, output_dim):
+ model_config = {
+ 'network': {
+ 'descriptor_dim': desc_dim,
+ 'n_layers': n_layers,
+ 'n_class': n_class,
+ 'output_dim': output_dim,
+ 'with_score': False,
+ }
+ }
+
+ if network == 'segnet':
+ model = SegNet(model_config.get('network', {}))
+ # config['with_cls'] = False
+ elif network == 'segnetvit':
+ model = SegNetViT(model_config.get('network', {}))
+ else:
+        raise ValueError('{:s} model does not exist'.format(network))
+
+ return model
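+
+# Usage sketch (editorial; the argument values below are hypothetical and must
+# match the checkpoint being loaded):
+#
+#   model = load_segnet(network='segnetvit', n_class=513, desc_dim=128,
+#                       n_layers=15, output_dim=1024)
+#   model.load_state_dict(torch.load(<checkpoint>, map_location='cpu')['model'])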
diff --git a/third_party/pram/nets/retnet.py b/third_party/pram/nets/retnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..c4f3346fcd82193683ec72d0e55a2429d18a974b
--- /dev/null
+++ b/third_party/pram/nets/retnet.py
@@ -0,0 +1,174 @@
+# -*- coding: UTF-8 -*-
+'''=================================================
+@Project -> File pram -> retnet
+@IDE PyCharm
+@Author fx221@cam.ac.uk
+@Date 22/02/2024 15:23
+=================================================='''
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
+ """3x3 convolution with padding"""
+ return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
+ padding=dilation, groups=groups, bias=False, dilation=dilation)
+
+
+def conv1x1(in_planes, out_planes, stride=1):
+ """1x1 convolution"""
+ return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
+
+
+class ResBlock(nn.Module):
+ def __init__(self, inplanes, outplanes, stride=1, groups=32, dilation=1, norm_layer=None, ac_fn=None):
+ super(ResBlock, self).__init__()
+ if norm_layer is None:
+ norm_layer = nn.BatchNorm2d
+ self.conv1 = conv1x1(inplanes, outplanes)
+ self.bn1 = norm_layer(outplanes)
+ self.conv2 = conv3x3(outplanes, outplanes, stride, groups, dilation)
+ self.bn2 = norm_layer(outplanes)
+ self.conv3 = conv1x1(outplanes, outplanes)
+ self.bn3 = norm_layer(outplanes)
+ if ac_fn is None:
+ self.ac_fn = nn.ReLU(inplace=True)
+ else:
+ self.ac_fn = ac_fn
+
+ def forward(self, x):
+ identity = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.ac_fn(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+ out = self.ac_fn(out)
+
+ out = self.conv3(out)
+ out = self.bn3(out)
+
+ out += identity
+ out = self.ac_fn(out)
+
+ return out
+
+
+class GeneralizedMeanPooling(nn.Module):
+ r"""Applies a 2D power-average adaptive pooling over an input signal composed of several input planes.
+ The function computed is: :math:`f(X) = pow(sum(pow(X, p)), 1/p)`
+ - At p = infinity, one gets Max Pooling
+ - At p = 1, one gets Average Pooling
+ The output is of size H x W, for any input size.
+ The number of output features is equal to the number of input planes.
+ Args:
+ output_size: the target output size of the image of the form H x W.
+ Can be a tuple (H, W) or a single H for a square image H x H
+ H and W can be either a ``int``, or ``None`` which means the size will
+ be the same as that of the input.
+ """
+
+ def __init__(self, norm, output_size=1, eps=1e-6):
+ super(GeneralizedMeanPooling, self).__init__()
+ assert norm > 0
+ self.p = float(norm)
+ self.output_size = output_size
+ self.eps = eps
+
+ def forward(self, x):
+ x = x.clamp(min=self.eps).pow(self.p)
+ return torch.nn.functional.adaptive_avg_pool2d(x, self.output_size).pow(1. / self.p)
+
+ def __repr__(self):
+ return self.__class__.__name__ + '(' \
+ + str(self.p) + ', ' \
+ + 'output_size=' + str(self.output_size) + ')'
+
+
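+# Note (added for clarity): GeM pooling interpolates between average pooling
+# (p = 1) and max pooling (p -> inf). Illustrative use of the trainable-p
+# variant below, assuming a [B, C, H, W] feature map:
+#   gem = GeneralizedMeanPoolingP(norm=3)
+#   vec = gem(torch.rand(2, 256, 16, 16))  # -> [2, 256, 1, 1]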
+class GeneralizedMeanPoolingP(GeneralizedMeanPooling):
+ """ Same, but norm is trainable
+ """
+
+ def __init__(self, norm=3, output_size=1, eps=1e-6):
+ super(GeneralizedMeanPoolingP, self).__init__(norm, output_size, eps)
+ self.p = nn.Parameter(torch.ones(1) * norm)
+
+
+class Flatten(nn.Module):
+ def forward(self, input):
+ return input.view(input.size(0), -1)
+
+
+class L2Norm(nn.Module):
+ def __init__(self, dim=1):
+ super().__init__()
+ self.dim = dim
+
+ def forward(self, input):
+ return F.normalize(input, p=2, dim=self.dim)
+
+
+class RetNet(nn.Module):
+ def __init__(self, indim=256, outdim=1024):
+ super().__init__()
+
+ ac_fn = nn.GELU()
+
+ self.convs = nn.Sequential(
+ # no batch normalization
+
+ nn.Conv2d(in_channels=indim, out_channels=512, kernel_size=3, stride=2, padding=1),
+ nn.BatchNorm2d(512),
+ # nn.ReLU(),
+
+ ResBlock(512, 512, groups=32, stride=1, ac_fn=ac_fn),
+ ResBlock(512, 512, groups=32, stride=1, ac_fn=ac_fn),
+
+ nn.Conv2d(in_channels=512, out_channels=1024, kernel_size=3, stride=2, padding=1),
+ nn.BatchNorm2d(1024),
+ # nn.ReLU(),
+ ResBlock(inplanes=1024, outplanes=1024, groups=32, stride=1, ac_fn=ac_fn),
+ ResBlock(inplanes=1024, outplanes=1024, groups=32, stride=1, ac_fn=ac_fn),
+ )
+
+ self.pool = GeneralizedMeanPoolingP()
+ self.fc = nn.Linear(1024, out_features=outdim)
+
+ def initialize(self):
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
+ if m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.BatchNorm2d):
+ nn.init.constant_(m.weight, 1)
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.Linear):
+ nn.init.normal_(m.weight, 0, 0.01)
+ nn.init.constant_(m.bias, 0)
+
+ def forward(self, x):
+ out = self.convs(x)
+ out = self.pool(out).reshape(x.shape[0], -1)
+ out = self.fc(out)
+ out = F.normalize(out, p=2, dim=1)
+ return out
+
+
+if __name__ == '__main__':
+    model = RetNet(indim=256, outdim=1024)
+    state_dict = model.state_dict()
+ keys = state_dict.keys()
+ print(keys)
+ shapes = [state_dict[v].shape for v in keys]
+ print(shapes)
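+
+    # Forward-pass sketch (illustrative, not part of the original script):
+    # RetNet maps a [B, 256, H, W] feature map to L2-normalized global
+    # descriptors of dimension outdim, e.g.
+    #   out = RetNet(indim=256, outdim=1024)(torch.rand(2, 256, 32, 32))  # -> [2, 1024]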
diff --git a/third_party/pram/nets/segnet.py b/third_party/pram/nets/segnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..632a38cb83ca77a23b5c1e1276996bd5574c3a0b
--- /dev/null
+++ b/third_party/pram/nets/segnet.py
@@ -0,0 +1,120 @@
+# -*- coding: UTF-8 -*-
+'''=================================================
+@Project -> File pram -> segnet
+@IDE PyCharm
+@Author fx221@cam.ac.uk
+@Date 29/01/2024 14:46
+=================================================='''
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from nets.layers import MLP, KeypointEncoder
+from nets.layers import AttentionalPropagation
+from nets.utils import normalize_keypoints
+
+
+class SegGNN(nn.Module):
+ def __init__(self, feature_dim: int, n_layers: int, ac_fn: str = 'relu', norm_fn: str = 'bn', **kwargs):
+ super().__init__()
+ self.layers = nn.ModuleList([
+ AttentionalPropagation(feature_dim, 4, ac_fn=ac_fn, norm_fn=norm_fn)
+ for _ in range(n_layers)
+ ])
+
+ def forward(self, desc):
+ for i, layer in enumerate(self.layers):
+ delta = layer(desc, desc)
+ desc = desc + delta
+
+ return desc
+
+
+class SegNet(nn.Module):
+ default_config = {
+ 'descriptor_dim': 256,
+ 'output_dim': 1024,
+ 'n_class': 512,
+ 'keypoint_encoder': [32, 64, 128, 256],
+ 'n_layers': 9,
+ 'ac_fn': 'relu',
+ 'norm_fn': 'in',
+ 'with_score': False,
+ # 'with_global': False,
+ 'with_cls': False,
+ 'with_sc': False,
+ }
+
+ def __init__(self, config={}):
+ super().__init__()
+ self.config = {**self.default_config, **config}
+ self.with_cls = self.config['with_cls']
+ self.with_sc = self.config['with_sc']
+
+ self.n_layers = self.config['n_layers']
+ self.gnn = SegGNN(
+ feature_dim=self.config['descriptor_dim'],
+ n_layers=self.config['n_layers'],
+ ac_fn=self.config['ac_fn'],
+ norm_fn=self.config['norm_fn'],
+ )
+
+ self.with_score = self.config['with_score']
+ self.kenc = KeypointEncoder(
+ input_dim=3 if self.with_score else 2,
+ feature_dim=self.config['descriptor_dim'],
+ layers=self.config['keypoint_encoder'],
+ ac_fn=self.config['ac_fn'],
+ norm_fn=self.config['norm_fn']
+ )
+
+ self.seg = MLP(channels=[self.config['descriptor_dim'],
+ self.config['output_dim'],
+ self.config['n_class']],
+ ac_fn=self.config['ac_fn'],
+ norm_fn=self.config['norm_fn']
+ )
+
+ if self.with_sc:
+ self.sc = MLP(channels=[self.config['descriptor_dim'],
+ self.config['output_dim'],
+ 3],
+ ac_fn=self.config['ac_fn'],
+ norm_fn=self.config['norm_fn']
+ )
+
+ def preprocess(self, data):
+ desc0 = data['seg_descriptors']
+ desc0 = desc0.transpose(1, 2) # [B, N, D] - > [B, D, N]
+
+ if 'norm_keypoints' in data.keys():
+ norm_kpts0 = data['norm_keypoints']
+ elif 'image' in data.keys():
+ kpts0 = data['keypoints']
+ norm_kpts0 = normalize_keypoints(kpts0, data['image'].shape)
+ else:
+            raise ValueError('Image shape is required for keypoint coordinate normalization')
+
+ # Keypoint MLP encoder.
+ if self.with_score:
+ scores0 = data['scores']
+ else:
+ scores0 = None
+ enc0 = self.kenc(norm_kpts0, scores0)
+
+ return desc0, enc0
+
+ def forward(self, data):
+ desc, enc = self.preprocess(data=data)
+ desc = desc + enc
+
+ desc = self.gnn(desc)
+ cls_output = self.seg(desc) # [B, C, N]
+ output = {
+ 'prediction': cls_output.transpose(-1, -2).contiguous(),
+ }
+
+ if self.with_sc:
+ sc_output = self.sc(desc)
+ output['sc'] = sc_output
+
+ return output
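+
+
+# Input/output sketch (added for clarity; shapes are inferred from the code
+# above): SegNet.forward expects a dict with 'seg_descriptors' [B, N, D]
+# (D = descriptor_dim) plus either 'norm_keypoints' [B, N, 2] or
+# 'keypoints' + 'image' for on-the-fly normalization ('scores' is also needed
+# when with_score=True). It returns {'prediction': [B, N, n_class]} and,
+# when with_sc=True, an additional 'sc' entry.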
diff --git a/third_party/pram/nets/segnetvit.py b/third_party/pram/nets/segnetvit.py
new file mode 100644
index 0000000000000000000000000000000000000000..7919b545c26d3098df84d2e8e909d7ed69809dcd
--- /dev/null
+++ b/third_party/pram/nets/segnetvit.py
@@ -0,0 +1,203 @@
+# -*- coding: UTF-8 -*-
+'''=================================================
+@Project -> File pram -> segnetvit
+@IDE PyCharm
+@Author fx221@cam.ac.uk
+@Date 29/01/2024 14:52
+=================================================='''
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+from nets.utils import normalize_keypoints
+
+
+def rotate_half(x: torch.Tensor) -> torch.Tensor:
+ x = x.unflatten(-1, (-1, 2))
+ x1, x2 = x.unbind(dim=-1)
+ return torch.stack((-x2, x1), dim=-1).flatten(start_dim=-2)
+
+
+def apply_cached_rotary_emb(
+ freqs: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
+ return (t * freqs[0]) + (rotate_half(t) * freqs[1])
+
+
+class LearnableFourierPositionalEncoding(nn.Module):
+ def __init__(self, M: int, dim: int, F_dim: int = None,
+ gamma: float = 1.0) -> None:
+ super().__init__()
+ F_dim = F_dim if F_dim is not None else dim
+ self.gamma = gamma
+ self.Wr = nn.Linear(M, F_dim // 2, bias=False)
+ nn.init.normal_(self.Wr.weight.data, mean=0, std=self.gamma ** -2)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """ encode position vector """
+ projected = self.Wr(x)
+ cosines, sines = torch.cos(projected), torch.sin(projected)
+ emb = torch.stack([cosines, sines], 0).unsqueeze(-3)
+ return emb.repeat_interleave(2, dim=-1)
+
+
+class KeypointEncoder(nn.Module):
+ """ Joint encoding of visual appearance and location using MLPs"""
+
+ def __init__(self):
+ super().__init__()
+ self.encoder = nn.Sequential(
+ nn.Linear(2, 32),
+ nn.LayerNorm(32, elementwise_affine=True),
+ nn.GELU(),
+ nn.Linear(32, 64),
+ nn.LayerNorm(64, elementwise_affine=True),
+ nn.GELU(),
+ nn.Linear(64, 128),
+ nn.LayerNorm(128, elementwise_affine=True),
+ nn.GELU(),
+ nn.Linear(128, 256),
+ )
+
+ def forward(self, kpts, scores=None):
+ if scores is not None:
+ inputs = [kpts, scores.unsqueeze(2)] # [B, N, 2] + [B, N, 1]
+ return self.encoder(torch.cat(inputs, dim=-1))
+ else:
+ return self.encoder(kpts)
+
+
+class Attention(nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, q, k, v):
+ s = q.shape[-1] ** -0.5
+ attn = F.softmax(torch.einsum('...id,...jd->...ij', q, k) * s, -1)
+ return torch.einsum('...ij,...jd->...id', attn, v)
+
+
+class SelfMultiHeadAttention(nn.Module):
+ def __init__(self, feat_dim: int, hidden_dim: int, num_heads: int):
+ super().__init__()
+ self.feat_dim = feat_dim
+ self.num_heads = num_heads
+
+ assert feat_dim % num_heads == 0
+ self.head_dim = feat_dim // num_heads
+ self.qkv = nn.Linear(feat_dim, hidden_dim * 3)
+ self.attn = Attention()
+ self.proj = nn.Linear(hidden_dim, hidden_dim)
+ self.mlp = nn.Sequential(
+ nn.Linear(feat_dim + hidden_dim, feat_dim * 2),
+ nn.LayerNorm(feat_dim * 2, elementwise_affine=True),
+ nn.GELU(),
+ nn.Linear(feat_dim * 2, feat_dim)
+ )
+
+ def forward(self, x, encoding=None):
+ qkv = self.qkv(x)
+ qkv = qkv.unflatten(-1, (self.num_heads, -1, 3)).transpose(1, 2)
+ q, k, v = qkv[..., 0], qkv[..., 1], qkv[..., 2]
+ if encoding is not None:
+ q = apply_cached_rotary_emb(encoding, q)
+ k = apply_cached_rotary_emb(encoding, k)
+ attn = self.attn(q, k, v)
+ message = self.proj(attn.transpose(1, 2).flatten(start_dim=-2))
+ return x + self.mlp(torch.cat([x, message], -1))
+
+
+class SegGNNViT(nn.Module):
+ def __init__(self, feature_dim: int, n_layers: int, hidden_dim: int = 256, num_heads: int = 4, **kwargs):
+ super(SegGNNViT, self).__init__()
+ self.layers = nn.ModuleList([
+ SelfMultiHeadAttention(feat_dim=feature_dim, hidden_dim=hidden_dim, num_heads=num_heads)
+ for _ in range(n_layers)
+ ])
+
+ def forward(self, desc, encoding=None):
+ for i, layer in enumerate(self.layers):
+ desc = layer(desc, encoding)
+            # no extra residual here: the skip connection is already applied inside SelfMultiHeadAttention
+ return desc
+
+
+class SegNetViT(nn.Module):
+ default_config = {
+ 'descriptor_dim': 256,
+ 'output_dim': 1024,
+ 'n_class': 512,
+ 'keypoint_encoder': [32, 64, 128, 256],
+ 'n_layers': 15,
+ 'num_heads': 4,
+ 'hidden_dim': 256,
+ 'with_score': False,
+ 'with_global': False,
+ 'with_cls': False,
+ 'with_sc': False,
+ }
+
+ def __init__(self, config={}):
+ super(SegNetViT, self).__init__()
+ self.config = {**self.default_config, **config}
+ self.with_cls = self.config['with_cls']
+ self.with_sc = self.config['with_sc']
+
+ self.n_layers = self.config['n_layers']
+ self.gnn = SegGNNViT(
+ feature_dim=self.config['hidden_dim'],
+ n_layers=self.config['n_layers'],
+ hidden_dim=self.config['hidden_dim'],
+ num_heads=self.config['num_heads'],
+ )
+
+ self.with_score = self.config['with_score']
+ self.kenc = LearnableFourierPositionalEncoding(2, self.config['hidden_dim'] // self.config['num_heads'],
+ self.config['hidden_dim'] // self.config['num_heads'])
+
+ self.input_proj = nn.Linear(in_features=self.config['descriptor_dim'],
+ out_features=self.config['hidden_dim'])
+ self.seg = nn.Sequential(
+ nn.Linear(in_features=self.config['hidden_dim'], out_features=self.config['output_dim']),
+ nn.LayerNorm(self.config['output_dim'], elementwise_affine=True),
+ nn.GELU(),
+ nn.Linear(self.config['output_dim'], self.config['n_class'])
+ )
+
+ if self.with_sc:
+ self.sc = nn.Sequential(
+                nn.Linear(in_features=self.config['hidden_dim'], out_features=self.config['output_dim']),
+ nn.LayerNorm(self.config['output_dim'], elementwise_affine=True),
+ nn.GELU(),
+ nn.Linear(self.config['output_dim'], 3)
+ )
+
+ def preprocess(self, data):
+ desc0 = data['seg_descriptors']
+ if 'norm_keypoints' in data.keys():
+ norm_kpts0 = data['norm_keypoints']
+ elif 'image' in data.keys():
+ kpts0 = data['keypoints']
+ norm_kpts0 = normalize_keypoints(kpts0, data['image'].shape)
+ else:
+            raise ValueError('Image shape is required for keypoint coordinate normalization')
+
+ enc0 = self.kenc(norm_kpts0)
+
+ return desc0, enc0
+
+ def forward(self, data):
+ desc, enc = self.preprocess(data=data)
+ desc = self.input_proj(desc)
+
+ desc = self.gnn(desc, enc)
+ seg_output = self.seg(desc) # [B, N, C]
+
+ output = {
+ 'prediction': seg_output,
+ }
+
+ if self.with_sc:
+ sc_output = self.sc(desc)
+ output['sc'] = sc_output
+
+ return output
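+
+
+# Input/output sketch (added for clarity; shapes follow the code above):
+# SegNetViT.forward consumes the same dict as SegNet ('seg_descriptors' plus
+# 'norm_keypoints' or 'keypoints' + 'image'), projects descriptors to
+# hidden_dim, runs the rotary-encoded self-attention layers and returns
+# {'prediction': [B, N, n_class]} (plus 'sc' when with_sc=True).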
diff --git a/third_party/pram/nets/sfd2.py b/third_party/pram/nets/sfd2.py
new file mode 100644
index 0000000000000000000000000000000000000000..d9c5a099b001ed9cf9e8a82b1b77dc9f7d9e31c8
--- /dev/null
+++ b/third_party/pram/nets/sfd2.py
@@ -0,0 +1,596 @@
+# -*- coding: UTF-8 -*-
+'''=================================================
+@Project -> File pram -> sfd2
+@IDE PyCharm
+@Author fx221@cam.ac.uk
+@Date 07/02/2024 14:53
+=================================================='''
+import torch
+import torch.nn as nn
+import numpy as np
+import torch.nn.functional as F
+import torchvision.transforms as tvf
+
+RGB_mean = [0.485, 0.456, 0.406]
+RGB_std = [0.229, 0.224, 0.225]
+
+norm_RGB = tvf.Compose([tvf.Normalize(mean=RGB_mean, std=RGB_std)])
+
+
+def simple_nms(scores, nms_radius: int):
+ """ Fast Non-maximum suppression to remove nearby points """
+ assert (nms_radius >= 0)
+
+ def max_pool(x):
+ return torch.nn.functional.max_pool2d(
+ x, kernel_size=nms_radius * 2 + 1, stride=1, padding=nms_radius)
+
+ zeros = torch.zeros_like(scores)
+ max_mask = scores == max_pool(scores)
+ for _ in range(2):
+ supp_mask = max_pool(max_mask.float()) > 0
+ supp_scores = torch.where(supp_mask, zeros, scores)
+ new_max_mask = supp_scores == max_pool(supp_scores)
+ max_mask = max_mask | (new_max_mask & (~supp_mask))
+ return torch.where(max_mask, scores, zeros)
+
+
+def remove_borders(keypoints, scores, border: int, height: int, width: int):
+ """ Removes keypoints too close to the border """
+ mask_h = (keypoints[:, 0] >= border) & (keypoints[:, 0] < (height - border))
+ mask_w = (keypoints[:, 1] >= border) & (keypoints[:, 1] < (width - border))
+ mask = mask_h & mask_w
+ return keypoints[mask], scores[mask]
+
+
+def top_k_keypoints(keypoints, scores, k: int):
+ if k >= len(keypoints):
+ return keypoints, scores
+ scores, indices = torch.topk(scores, k, dim=0)
+ return keypoints[indices], scores
+
+
+def sample_descriptors(keypoints, descriptors, s: int = 8):
+ """ Interpolate descriptors at keypoint locations """
+ b, c, h, w = descriptors.shape
+ keypoints = keypoints - s / 2 + 0.5
+ keypoints /= torch.tensor([(w * s - s / 2 - 0.5), (h * s - s / 2 - 0.5)],
+ ).to(keypoints)[None]
+ keypoints = keypoints * 2 - 1 # normalize to (-1, 1)
+ descriptors = torch.nn.functional.grid_sample(
+ descriptors, keypoints.view(b, 1, -1, 2), mode='bilinear', align_corners=True)
+ descriptors = torch.nn.functional.normalize(
+ descriptors.reshape(b, c, -1), p=2, dim=1)
+ return descriptors
+
+
+def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
+ """3x3 convolution with padding"""
+ return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
+ padding=dilation, groups=groups, bias=False, dilation=dilation)
+
+
+def conv1x1(in_planes, out_planes, stride=1):
+ """1x1 convolution"""
+ return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
+
+
+def conv(in_channels, out_channels, kernel_size=3, stride=1, padding=1, use_bn=False, groups=1, dilation=1):
+ if not use_bn:
+ return nn.Sequential(
+ nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
+ kernel_size=kernel_size, stride=stride, padding=padding, groups=groups, dilation=dilation),
+ nn.ReLU(inplace=True),
+ )
+ else:
+ return nn.Sequential(
+ nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
+ kernel_size=kernel_size, stride=stride, padding=padding, groups=groups, dilation=dilation),
+ nn.BatchNorm2d(out_channels),
+ nn.ReLU(inplace=True),
+ )
+
+
+class ResBlock(nn.Module):
+ def __init__(self, inplanes, outplanes, stride=1, groups=32, dilation=1, norm_layer=None):
+ super(ResBlock, self).__init__()
+ if norm_layer is None:
+ norm_layer = nn.BatchNorm2d
+ self.conv1 = conv1x1(inplanes, outplanes)
+ self.bn1 = norm_layer(outplanes)
+ self.conv2 = conv3x3(outplanes, outplanes, stride, groups, dilation)
+ self.bn2 = norm_layer(outplanes)
+ self.conv3 = conv1x1(outplanes, outplanes)
+ self.bn3 = norm_layer(outplanes)
+ self.relu = nn.ReLU(inplace=True)
+
+ def forward(self, x):
+ identity = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+ out = self.relu(out)
+
+ out = self.conv3(out)
+ out = self.bn3(out)
+
+ out += identity
+ out = self.relu(out)
+
+ return out
+
+
+class ResNet4x(nn.Module):
+ default_config = {
+ 'conf_th': 0.005,
+ 'remove_borders': 4,
+ 'min_keypoints': 128,
+ 'max_keypoints': 4096,
+ }
+
+ def __init__(self, inputdim=3, outdim=128, desc_compressor=None):
+ super().__init__()
+ self.outdim = outdim
+ self.desc_compressor = desc_compressor
+
+ d1, d2, d3, d4, d5, d6 = 64, 128, 256, 256, 256, 256
+ self.conv1a = conv(in_channels=inputdim, out_channels=d1, kernel_size=3, use_bn=True)
+ self.conv1b = conv(in_channels=d1, out_channels=d1, kernel_size=3, stride=2, use_bn=True)
+
+ self.conv2a = conv(in_channels=d1, out_channels=d2, kernel_size=3, use_bn=True)
+ self.conv2b = conv(in_channels=d2, out_channels=d2, kernel_size=3, stride=2, use_bn=True)
+
+ self.conv3a = conv(in_channels=d2, out_channels=d3, kernel_size=3, use_bn=True)
+ self.conv3b = conv(in_channels=d3, out_channels=d3, kernel_size=3, use_bn=True)
+
+ self.conv4 = nn.Sequential(
+ ResBlock(inplanes=256, outplanes=256, groups=32),
+ ResBlock(inplanes=256, outplanes=256, groups=32),
+ ResBlock(inplanes=256, outplanes=256, groups=32),
+ )
+
+ self.convPa = nn.Sequential(
+ torch.nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=1),
+ nn.BatchNorm2d(256),
+ nn.ReLU(inplace=True),
+ torch.nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),
+ )
+ self.convDa = nn.Sequential(
+ nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),
+ nn.BatchNorm2d(256),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
+ )
+
+ self.convPb = torch.nn.Conv2d(256, 65, kernel_size=1, stride=1, padding=0)
+ self.convDb = torch.nn.Conv2d(256, outdim, kernel_size=1, stride=1, padding=0)
+
+ def det(self, x):
+ out1a = self.conv1a(x)
+ out1b = self.conv1b(out1a)
+
+ out2a = self.conv2a(out1b)
+ out2b = self.conv2b(out2a)
+
+ out3a = self.conv3a(out2b)
+ out3b = self.conv3b(out3a)
+
+ out4 = self.conv4(out3b)
+
+ cPa = self.convPa(out4)
+ logits = self.convPb(cPa)
+ full_semi = torch.softmax(logits, dim=1)
+ semi = full_semi[:, :-1, :, :]
+ Hc, Wc = semi.size(2), semi.size(3)
+ score = semi.permute([0, 2, 3, 1])
+ score = score.view(score.size(0), Hc, Wc, 8, 8)
+ score = score.permute([0, 1, 3, 2, 4])
+ score = score.contiguous().view(score.size(0), Hc * 8, Wc * 8)
+
+ # Descriptor Head
+ cDa = self.convDa(out4)
+ desc = self.convDb(cDa)
+ desc = F.normalize(desc, dim=1)
+
+ return score, desc
+
+ def forward(self, batch):
+ out1a = self.conv1a(batch['image'])
+ out1b = self.conv1b(out1a)
+
+ out2a = self.conv2a(out1b)
+ out2b = self.conv2b(out2a)
+
+ out3a = self.conv3a(out2b)
+ out3b = self.conv3b(out3a)
+
+ out4 = self.conv4(out3b)
+
+ cPa = self.convPa(out4)
+ logits = self.convPb(cPa)
+ full_semi = torch.softmax(logits, dim=1)
+ semi = full_semi[:, :-1, :, :]
+ Hc, Wc = semi.size(2), semi.size(3)
+ score = semi.permute([0, 2, 3, 1])
+ score = score.view(score.size(0), Hc, Wc, 8, 8)
+ score = score.permute([0, 1, 3, 2, 4])
+ score = score.contiguous().view(score.size(0), Hc * 8, Wc * 8)
+
+ # Descriptor Head
+ cDa = self.convDa(out4)
+ desc = self.convDb(cDa)
+ desc = F.normalize(desc, dim=1)
+
+ return {
+ 'dense_features': desc,
+ 'scores': score,
+ 'logits': logits,
+ 'semi_map': semi,
+ }
+
+ def extract_patches(self, batch):
+ out1a = self.conv1a(batch['image'])
+ out1b = self.conv1b(out1a)
+
+ out2a = self.conv2a(out1b)
+ out2b = self.conv2b(out2a)
+
+ out3a = self.conv3a(out2b)
+ out3b = self.conv3b(out3a)
+
+ out4 = self.conv4(out3b)
+
+ cPa = self.convPa(out4)
+ logits = self.convPb(cPa)
+ full_semi = torch.softmax(logits, dim=1)
+ semi = full_semi[:, :-1, :, :]
+ Hc, Wc = semi.size(2), semi.size(3)
+ score = semi.permute([0, 2, 3, 1])
+ score = score.view(score.size(0), Hc, Wc, 8, 8)
+ score = score.permute([0, 1, 3, 2, 4])
+ score = score.contiguous().view(score.size(0), Hc * 8, Wc * 8)
+
+ # Descriptor Head
+ cDa = self.convDa(out4)
+ desc = self.convDb(cDa)
+ desc = F.normalize(desc, dim=1)
+
+ return {
+ 'dense_features': desc,
+ 'scores': score,
+ 'logits': logits,
+ 'semi_map': semi,
+ }
+
+ def extract_local_global(self, data,
+ config={
+ 'conf_th': 0.005,
+ 'remove_borders': 4,
+ 'min_keypoints': 128,
+ 'max_keypoints': 4096,
+ }
+ ):
+
+ config = {**self.default_config, **config}
+
+ b, ic, ih, iw = data['image'].shape
+ out1a = self.conv1a(data['image'])
+ out1b = self.conv1b(out1a) # 64
+
+ out2a = self.conv2a(out1b)
+ out2b = self.conv2b(out2a) # 128
+
+ out3a = self.conv3a(out2b)
+ out3b = self.conv3b(out3a) # 256
+
+ out4 = self.conv4(out3b) # 256
+
+ cPa = self.convPa(out4)
+ logits = self.convPb(cPa)
+ full_semi = torch.softmax(logits, dim=1)
+ semi = full_semi[:, :-1, :, :]
+ Hc, Wc = semi.size(2), semi.size(3)
+ score = semi.permute([0, 2, 3, 1])
+ score = score.view(score.size(0), Hc, Wc, 8, 8)
+ score = score.permute([0, 1, 3, 2, 4])
+ score = score.contiguous().view(score.size(0), Hc * 8, Wc * 8)
+ if Hc * 8 != ih or Wc * 8 != iw:
+ score = F.interpolate(score.unsqueeze(1), size=[ih, iw], align_corners=True, mode='bilinear')
+ score = score.squeeze(1)
+ # extract keypoints
+ nms_scores = simple_nms(scores=score, nms_radius=4)
+ keypoints = [
+ torch.nonzero(s >= config['conf_th'])
+ for s in nms_scores]
+ scores = [s[tuple(k.t())] for s, k in zip(nms_scores, keypoints)]
+
+ if len(scores[0]) <= config['min_keypoints']:
+ keypoints = [
+ torch.nonzero(s >= config['conf_th'] * 0.5)
+ for s in nms_scores]
+ scores = [s[tuple(k.t())] for s, k in zip(nms_scores, keypoints)]
+
+ # Discard keypoints near the image borders
+ keypoints, scores = list(zip(*[
+ remove_borders(k, s, config['remove_borders'], ih, iw)
+ for k, s in zip(keypoints, scores)]))
+
+ # Keep the k keypoints with highest score
+ if config['max_keypoints'] >= 0:
+ keypoints, scores = list(zip(*[
+ top_k_keypoints(k, s, config['max_keypoints'])
+ for k, s in zip(keypoints, scores)]))
+
+ # Convert (h, w) to (x, y)
+ keypoints = [torch.flip(k, [1]).float() for k in keypoints]
+ # Descriptor Head
+ cDa = self.convDa(out4)
+ desc_map = self.convDb(cDa)
+ desc_map = F.normalize(desc_map, dim=1)
+
+ descriptors = [sample_descriptors(k[None], d[None], 4)[0]
+ for k, d in zip(keypoints, desc_map)]
+
+ return {
+ 'score_map': score,
+ 'desc_map': desc_map,
+ 'mid_features': out4,
+ 'global_descriptors': [out1b, out2b, out3b, out4],
+ 'keypoints': keypoints,
+ 'scores': scores,
+ 'descriptors': descriptors,
+ }
+
+ def sample(self, score_map, semi_descs, kpts, s=4, norm_desc=True):
+ # print('sample: ', score_map.shape, semi_descs.shape, kpts.shape)
+ b, c, h, w = semi_descs.shape
+ norm_kpts = kpts - s / 2 + 0.5
+ norm_kpts = norm_kpts / torch.tensor([(w * s - s / 2 - 0.5), (h * s - s / 2 - 0.5)],
+ ).to(norm_kpts)[None]
+ norm_kpts = norm_kpts * 2 - 1
+ # args = {'align_corners': True} if int(torch.__version__[2]) > 2 else {}
+ descriptors = torch.nn.functional.grid_sample(
+ semi_descs, norm_kpts.view(b, 1, -1, 2), mode='bilinear', align_corners=True)
+
+ if norm_desc:
+ descriptors = torch.nn.functional.normalize(
+ descriptors.reshape(b, c, -1), p=2, dim=1)
+ else:
+ descriptors = descriptors.reshape(b, c, -1)
+
+ # print('max: ', torch.min(kpts[:, 1].long()), torch.max(kpts[:, 1].long()), torch.min(kpts[:, 0].long()),
+ # torch.max(kpts[:, 0].long()))
+ scores = score_map[0, kpts[:, 1].long(), kpts[:, 0].long()]
+
+ return scores, descriptors.squeeze(0)
+
+
+class DescriptorCompressor(nn.Module):
+ def __init__(self, inputdim: int, outdim: int):
+ super().__init__()
+ self.inputdim = inputdim
+ self.outdim = outdim
+ self.conv = nn.Conv1d(in_channels=inputdim, out_channels=outdim, kernel_size=1, padding=0, bias=True)
+
+ def forward(self, x):
+ # b, c, n = x.shape
+ out = self.conv(x)
+ out = F.normalize(out, p=2, dim=1)
+ return out
+
+
+def extract_sfd2_return(model, img, conf_th=0.001,
+ mask=None,
+ topK=-1,
+ min_keypoints=0,
+ **kwargs):
+ old_bm = torch.backends.cudnn.benchmark
+ torch.backends.cudnn.benchmark = False # speedup
+
+ img = norm_RGB(img.squeeze())
+ img = img[None]
+ img = img.cuda()
+
+ B, one, H, W = img.shape
+
+ all_pts = []
+ all_descs = []
+
+ if 'scales' in kwargs.keys():
+ scales = kwargs.get('scales')
+ else:
+ scales = [1.0]
+
+ for s in scales:
+ if s == 1.0:
+ new_img = img
+ else:
+ nh = int(H * s)
+ nw = int(W * s)
+ new_img = F.interpolate(img, size=(nh, nw), mode='bilinear', align_corners=True)
+ nh, nw = new_img.shape[2:]
+
+ with torch.no_grad():
+ heatmap, coarse_desc = model.det(new_img)
+
+ # print("nh, nw, heatmap, desc: ", nh, nw, heatmap.shape, coarse_desc.shape)
+ if len(heatmap.size()) == 3:
+ heatmap = heatmap.unsqueeze(1)
+ if len(heatmap.size()) == 2:
+ heatmap = heatmap.unsqueeze(0)
+ heatmap = heatmap.unsqueeze(1)
+ # print(heatmap.shape)
+ if heatmap.size(2) != nh or heatmap.size(3) != nw:
+ heatmap = F.interpolate(heatmap, size=[nh, nw], mode='bilinear', align_corners=True)
+
+ conf_thresh = conf_th
+ nms_dist = 3
+ border_remove = 4
+ scores = simple_nms(heatmap, nms_radius=nms_dist)
+ keypoints = [
+ torch.nonzero(s > conf_thresh)
+ for s in scores]
+ scores = [s[tuple(k.t())] for s, k in zip(scores, keypoints)]
+ # print('scores in return: ', len(scores[0]))
+
+ # print(keypoints[0].shape)
+ keypoints = [torch.flip(k, [1]).float() for k in keypoints]
+ scores = scores[0].data.cpu().numpy().squeeze()
+ keypoints = keypoints[0].data.cpu().numpy().squeeze()
+ pts = keypoints.transpose()
+ pts[2, :] = scores
+
+ inds = np.argsort(pts[2, :])
+ pts = pts[:, inds[::-1]] # Sort by confidence.
+ # Remove points along border.
+ bord = border_remove
+ toremoveW = np.logical_or(pts[0, :] < bord, pts[0, :] >= (W - bord))
+ toremoveH = np.logical_or(pts[1, :] < bord, pts[1, :] >= (H - bord))
+ toremove = np.logical_or(toremoveW, toremoveH)
+ pts = pts[:, ~toremove]
+
+ # valid_idex = heatmap > conf_thresh
+ # valid_score = heatmap[valid_idex]
+ # """
+ # --- Process descriptor.
+ # coarse_desc = coarse_desc.data.cpu().numpy().squeeze()
+ D = coarse_desc.size(1)
+ if pts.shape[1] == 0:
+ desc = np.zeros((D, 0))
+ else:
+ if coarse_desc.size(2) == nh and coarse_desc.size(3) == nw:
+ desc = coarse_desc[:, :, pts[1, :], pts[0, :]]
+ desc = desc.data.cpu().numpy().reshape(D, -1)
+ else:
+ # Interpolate into descriptor map using 2D point locations.
+ samp_pts = torch.from_numpy(pts[:2, :].copy())
+ samp_pts[0, :] = (samp_pts[0, :] / (float(nw) / 2.)) - 1.
+ samp_pts[1, :] = (samp_pts[1, :] / (float(nh) / 2.)) - 1.
+ samp_pts = samp_pts.transpose(0, 1).contiguous()
+ samp_pts = samp_pts.view(1, 1, -1, 2)
+ samp_pts = samp_pts.float()
+ samp_pts = samp_pts.cuda()
+ desc = torch.nn.functional.grid_sample(coarse_desc, samp_pts, mode='bilinear', align_corners=True)
+ desc = desc.data.cpu().numpy().reshape(D, -1)
+ desc /= np.linalg.norm(desc, axis=0)[np.newaxis, :]
+
+ if pts.shape[1] == 0:
+ continue
+
+ # print(pts.shape, heatmap.shape, new_img.shape, img.shape, nw, nh, W, H)
+ pts[0, :] = pts[0, :] * W / nw
+ pts[1, :] = pts[1, :] * H / nh
+ all_pts.append(np.transpose(pts, [1, 0]))
+ all_descs.append(np.transpose(desc, [1, 0]))
+
+ all_pts = np.vstack(all_pts)
+ all_descs = np.vstack(all_descs)
+
+ torch.backends.cudnn.benchmark = old_bm
+
+ if all_pts.shape[0] == 0:
+ return None, None, None
+
+ keypoints = all_pts[:, 0:2]
+ scores = all_pts[:, 2]
+ descriptors = all_descs
+
+ if mask is not None:
+ # cv2.imshow("mask", mask)
+ # cv2.waitKey(0)
+ labels = []
+ others = []
+ keypoints_with_labels = []
+ scores_with_labels = []
+ descriptors_with_labels = []
+ keypoints_without_labels = []
+ scores_without_labels = []
+ descriptors_without_labels = []
+
+ id_img = np.int32(mask[:, :, 2]) * 256 * 256 + np.int32(mask[:, :, 1]) * 256 + np.int32(mask[:, :, 0])
+ # print(img.shape, id_img.shape)
+
+ for i in range(keypoints.shape[0]):
+ x = keypoints[i, 0]
+ y = keypoints[i, 1]
+ # print("x-y", x, y, int(x), int(y))
+ gid = id_img[int(y), int(x)]
+ if gid == 0:
+ keypoints_without_labels.append(keypoints[i])
+ scores_without_labels.append(scores[i])
+ descriptors_without_labels.append(descriptors[i])
+ others.append(0)
+ else:
+ keypoints_with_labels.append(keypoints[i])
+ scores_with_labels.append(scores[i])
+ descriptors_with_labels.append(descriptors[i])
+ labels.append(gid)
+
+ if topK > 0:
+ if topK <= len(keypoints_with_labels):
+ idxes = np.array(scores_with_labels, float).argsort()[::-1][:topK]
+ keypoints = np.array(keypoints_with_labels, float)[idxes]
+ scores = np.array(scores_with_labels, float)[idxes]
+ labels = np.array(labels, np.int32)[idxes]
+ descriptors = np.array(descriptors_with_labels, float)[idxes]
+ elif topK >= len(keypoints_with_labels) + len(keypoints_without_labels):
+ # keypoints = np.vstack([keypoints_with_labels, keypoints_without_labels])
+ # scores = np.vstack([scorescc_with_labels, scores_without_labels])
+ # descriptors = np.vstack([descriptors_with_labels, descriptors_without_labels])
+ # labels = np.vstack([labels, others])
+ keypoints = keypoints_with_labels
+ scores = scores_with_labels
+ descriptors = descriptors_with_labels
+ for i in range(len(others)):
+ keypoints.append(keypoints_without_labels[i])
+ scores.append(scores_without_labels[i])
+ descriptors.append(descriptors_without_labels[i])
+ labels.append(others[i])
+ else:
+ n = topK - len(keypoints_with_labels)
+ idxes = np.array(scores_without_labels, float).argsort()[::-1][:n]
+ keypoints = keypoints_with_labels
+ scores = scores_with_labels
+ descriptors = descriptors_with_labels
+ for i in idxes:
+ keypoints.append(keypoints_without_labels[i])
+ scores.append(scores_without_labels[i])
+ descriptors.append(descriptors_without_labels[i])
+ labels.append(others[i])
+ keypoints = np.array(keypoints, float)
+ descriptors = np.array(descriptors, float)
+ # print(keypoints.shape, descriptors.shape)
+ return {"keypoints": np.array(keypoints, float),
+ "descriptors": np.array(descriptors, float),
+ "scores": np.array(scores, np.float),
+ "labels": np.array(labels, np.int32),
+ }
+ else:
+ # print(topK)
+ if topK > 0:
+ idxes = np.array(scores, dtype=float).argsort()[::-1][:topK]
+ keypoints = np.array(keypoints[idxes], dtype=float)
+ scores = np.array(scores[idxes], dtype=float)
+ descriptors = np.array(descriptors[idxes], dtype=float)
+
+ keypoints = np.array(keypoints, dtype=float)
+ scores = np.array(scores, dtype=float)
+ descriptors = np.array(descriptors, dtype=float)
+
+ # print(keypoints.shape, descriptors.shape)
+
+ return {"keypoints": np.array(keypoints, dtype=float),
+ "descriptors": descriptors,
+ "scores": scores,
+ }
+
+
+def load_sfd2(weight_path):
+ net = ResNet4x(inputdim=3, outdim=128)
+ net.load_state_dict(torch.load(weight_path, map_location='cpu')['state_dict'], strict=True)
+ # print('Load sfd2 from {:s}'.format(weight_path))
+ return net
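+
+
+# Usage sketch (added for clarity; the checkpoint path is a placeholder):
+#   net = load_sfd2(weight_path='path/to/sfd2_weights.pth').eval()
+#   out = net.extract_local_global({'image': img})  # img: [1, 3, H, W], normalized with norm_RGB above
+# The checkpoint is expected to store its weights under a 'state_dict' key.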
diff --git a/third_party/pram/nets/superpoint.py b/third_party/pram/nets/superpoint.py
new file mode 100644
index 0000000000000000000000000000000000000000..6751016bd71cbbbb072243b3c1aebc100f632693
--- /dev/null
+++ b/third_party/pram/nets/superpoint.py
@@ -0,0 +1,607 @@
+# %BANNER_BEGIN%
+# ---------------------------------------------------------------------
+# %COPYRIGHT_BEGIN%
+#
+# Magic Leap, Inc. ("COMPANY") CONFIDENTIAL
+#
+# Unpublished Copyright (c) 2020
+# Magic Leap, Inc., All Rights Reserved.
+#
+# NOTICE: All information contained herein is, and remains the property
+# of COMPANY. The intellectual and technical concepts contained herein
+# are proprietary to COMPANY and may be covered by U.S. and Foreign
+# Patents, patents in process, and are protected by trade secret or
+# copyright law. Dissemination of this information or reproduction of
+# this material is strictly forbidden unless prior written permission is
+# obtained from COMPANY. Access to the source code contained herein is
+# hereby forbidden to anyone except current COMPANY employees, managers
+# or contractors who have executed Confidentiality and Non-disclosure
+# agreements explicitly covering such access.
+#
+# The copyright notice above does not evidence any actual or intended
+# publication or disclosure of this source code, which includes
+# information that is confidential and/or proprietary, and is a trade
+# secret, of COMPANY. ANY REPRODUCTION, MODIFICATION, DISTRIBUTION,
+# PUBLIC PERFORMANCE, OR PUBLIC DISPLAY OF OR THROUGH USE OF THIS
+# SOURCE CODE WITHOUT THE EXPRESS WRITTEN CONSENT OF COMPANY IS
+# STRICTLY PROHIBITED, AND IN VIOLATION OF APPLICABLE LAWS AND
+# INTERNATIONAL TREATIES. THE RECEIPT OR POSSESSION OF THIS SOURCE
+# CODE AND/OR RELATED INFORMATION DOES NOT CONVEY OR IMPLY ANY RIGHTS
+# TO REPRODUCE, DISCLOSE OR DISTRIBUTE ITS CONTENTS, OR TO MANUFACTURE,
+# USE, OR SELL ANYTHING THAT IT MAY DESCRIBE, IN WHOLE OR IN PART.
+#
+# %COPYRIGHT_END%
+# ----------------------------------------------------------------------
+# %AUTHORS_BEGIN%
+#
+# Originating Authors: Paul-Edouard Sarlin
+#
+# %AUTHORS_END%
+# --------------------------------------------------------------------*/
+# %BANNER_END%
+
+from pathlib import Path
+import torch
+from torch import nn
+import numpy as np
+import cv2
+import torch.nn.functional as F
+
+
+def simple_nms(scores, nms_radius: int):
+ """ Fast Non-maximum suppression to remove nearby points """
+ assert (nms_radius >= 0)
+
+ def max_pool(x):
+ return torch.nn.functional.max_pool2d(
+ x, kernel_size=nms_radius * 2 + 1, stride=1, padding=nms_radius)
+
+ zeros = torch.zeros_like(scores)
+ max_mask = scores == max_pool(scores)
+ for _ in range(2):
+ supp_mask = max_pool(max_mask.float()) > 0
+ supp_scores = torch.where(supp_mask, zeros, scores)
+ new_max_mask = supp_scores == max_pool(supp_scores)
+ max_mask = max_mask | (new_max_mask & (~supp_mask))
+ return torch.where(max_mask, scores, zeros)
+
+
+def remove_borders(keypoints, scores, border: int, height: int, width: int):
+ """ Removes keypoints too close to the border """
+ mask_h = (keypoints[:, 0] >= border) & (keypoints[:, 0] < (height - border))
+ mask_w = (keypoints[:, 1] >= border) & (keypoints[:, 1] < (width - border))
+ mask = mask_h & mask_w
+ return keypoints[mask], scores[mask]
+
+
+def top_k_keypoints(keypoints, scores, k: int):
+ if k >= len(keypoints):
+ return keypoints, scores
+ scores, indices = torch.topk(scores, k, dim=0)
+ return keypoints[indices], scores
+
+
+def sample_descriptors(keypoints, descriptors, s: int = 8):
+ """ Interpolate descriptors at keypoint locations """
+ b, c, h, w = descriptors.shape
+ keypoints = keypoints - s / 2 + 0.5
+ keypoints /= torch.tensor([(w * s - s / 2 - 0.5), (h * s - s / 2 - 0.5)],
+ ).to(keypoints)[None]
+ keypoints = keypoints * 2 - 1 # normalize to (-1, 1)
+ args = {'align_corners': True} if int(torch.__version__[2]) > 2 else {}
+ descriptors = torch.nn.functional.grid_sample(
+ descriptors, keypoints.view(b, 1, -1, 2), mode='bilinear', **args)
+ descriptors = torch.nn.functional.normalize(
+ descriptors.reshape(b, c, -1), p=2, dim=1)
+ return descriptors
+
+
+class SuperPoint(nn.Module):
+ """SuperPoint Convolutional Detector and Descriptor
+
+ SuperPoint: Self-Supervised Interest Point Detection and
+ Description. Daniel DeTone, Tomasz Malisiewicz, and Andrew
+ Rabinovich. In CVPRW, 2019. https://arxiv.org/abs/1712.07629
+
+ """
+ default_config = {
+ 'descriptor_dim': 256,
+ 'nms_radius': 3,
+ 'keypoint_threshold': 0.001,
+ 'max_keypoints': -1,
+ 'min_keypoints': 32,
+ 'remove_borders': 4,
+ }
+
+ def __init__(self, config):
+ super().__init__()
+ self.config = {**self.default_config, **config}
+
+ self.relu = nn.ReLU(inplace=True)
+ self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
+ c1, c2, c3, c4, c5 = 64, 64, 128, 128, 256
+
+ self.conv1a = nn.Conv2d(1, c1, kernel_size=3, stride=1, padding=1)
+ self.conv1b = nn.Conv2d(c1, c1, kernel_size=3, stride=1, padding=1) # 64
+ self.conv2a = nn.Conv2d(c1, c2, kernel_size=3, stride=1, padding=1)
+ self.conv2b = nn.Conv2d(c2, c2, kernel_size=3, stride=1, padding=1) # 64
+ self.conv3a = nn.Conv2d(c2, c3, kernel_size=3, stride=1, padding=1)
+ self.conv3b = nn.Conv2d(c3, c3, kernel_size=3, stride=1, padding=1) # 128
+ self.conv4a = nn.Conv2d(c3, c4, kernel_size=3, stride=1, padding=1)
+ self.conv4b = nn.Conv2d(c4, c4, kernel_size=3, stride=1, padding=1) # 128
+
+ self.convPa = nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1) # 256
+ self.convPb = nn.Conv2d(c5, 65, kernel_size=1, stride=1, padding=0)
+
+ self.convDa = nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1) # 256
+ self.convDb = nn.Conv2d(
+ c5, self.config['descriptor_dim'],
+ kernel_size=1, stride=1, padding=0)
+
+ # path = Path(__file__).parent / 'weights/superpoint_v1.pth'
+ path = config['weight_path']
+ self.load_state_dict(torch.load(str(path), map_location='cpu'), strict=True)
+
+ mk = self.config['max_keypoints']
+ if mk == 0 or mk < -1:
+ raise ValueError('\"max_keypoints\" must be positive or \"-1\"')
+
+ print('Loaded SuperPoint model')
+
+ def extract_global(self, data):
+ # Shared Encoder
+ x0 = self.relu(self.conv1a(data['image']))
+ x0 = self.relu(self.conv1b(x0))
+ x0 = self.pool(x0)
+ x1 = self.relu(self.conv2a(x0))
+ x1 = self.relu(self.conv2b(x1))
+ x1 = self.pool(x1)
+ x2 = self.relu(self.conv3a(x1))
+ x2 = self.relu(self.conv3b(x2))
+ x2 = self.pool(x2)
+ x3 = self.relu(self.conv4a(x2))
+ x3 = self.relu(self.conv4b(x3))
+
+ x4 = self.relu(self.convDa(x3))
+
+ # print('ex_g: ', x0.shape, x1.shape, x2.shape, x3.shape, x4.shape)
+
+ return [x0, x1, x2, x3, x4]
+
+ def extract_local_global(self, data):
+ # Shared Encoder
+ b, ic, ih, iw = data['image'].shape
+ x0 = self.relu(self.conv1a(data['image']))
+ x0 = self.relu(self.conv1b(x0))
+ x0 = self.pool(x0)
+ x1 = self.relu(self.conv2a(x0))
+ x1 = self.relu(self.conv2b(x1))
+ x1 = self.pool(x1)
+ x2 = self.relu(self.conv3a(x1))
+ x2 = self.relu(self.conv3b(x2))
+ x2 = self.pool(x2)
+ x3 = self.relu(self.conv4a(x2))
+ x3 = self.relu(self.conv4b(x3))
+
+ # Compute the dense keypoint scores
+ cPa = self.relu(self.convPa(x3))
+ score = self.convPb(cPa)
+ score = torch.nn.functional.softmax(score, 1)[:, :-1]
+ # print(scores.shape)
+ b, _, h, w = score.shape
+ score = score.permute(0, 2, 3, 1).reshape(b, h, w, 8, 8)
+ score = score.permute(0, 1, 3, 2, 4).reshape(b, h * 8, w * 8)
+ score = torch.nn.functional.interpolate(score.unsqueeze(1), size=(ih, iw), align_corners=True,
+ mode='bilinear')
+ score = score.squeeze(1)
+
+ # extract kpts
+ nms_scores = simple_nms(scores=score, nms_radius=self.config['nms_radius'])
+ keypoints = [
+ torch.nonzero(s >= self.config['keypoint_threshold'])
+ for s in nms_scores]
+ scores = [s[tuple(k.t())] for s, k in zip(nms_scores, keypoints)]
+
+ if len(scores[0]) <= self.config['min_keypoints']:
+ keypoints = [
+ torch.nonzero(s >= self.config['keypoint_threshold'] * 0.5)
+ for s in nms_scores]
+ scores = [s[tuple(k.t())] for s, k in zip(nms_scores, keypoints)]
+
+ # Discard keypoints near the image borders
+ keypoints, scores = list(zip(*[
+ remove_borders(k, s, self.config['remove_borders'], ih, iw)
+ for k, s in zip(keypoints, scores)]))
+
+ # Keep the k keypoints with the highest score
+ if self.config['max_keypoints'] >= 0:
+ keypoints, scores = list(zip(*[
+ top_k_keypoints(k, s, self.config['max_keypoints'])
+ for k, s in zip(keypoints, scores)]))
+
+ # Convert (h, w) to (x, y)
+ keypoints = [torch.flip(k, [1]).float() for k in keypoints]
+
+ # Compute the dense descriptors
+ cDa = self.relu(self.convDa(x3))
+ desc_map = self.convDb(cDa)
+ desc_map = torch.nn.functional.normalize(desc_map, p=2, dim=1)
+ descriptors = [sample_descriptors(k[None], d[None], 8)[0]
+ for k, d in zip(keypoints, desc_map)]
+
+ return {
+ 'score_map': score,
+ 'desc_map': desc_map,
+ 'mid_features': cDa, # 256
+ 'global_descriptors': [x0, x1, x2, x3, cDa],
+ 'keypoints': keypoints,
+ 'scores': scores,
+ 'descriptors': descriptors,
+ }
+
+ def sample(self, score_map, semi_descs, kpts, s=8, norm_desc=True):
+ # print('sample: ', score_map.shape, semi_descs.shape, kpts.shape)
+ b, c, h, w = semi_descs.shape
+ norm_kpts = kpts - s / 2 + 0.5
+ norm_kpts = norm_kpts / torch.tensor([(w * s - s / 2 - 0.5), (h * s - s / 2 - 0.5)],
+ ).to(norm_kpts)[None]
+ norm_kpts = norm_kpts * 2 - 1
+ # args = {'align_corners': True} if int(torch.__version__[2]) > 2 else {}
+ descriptors = torch.nn.functional.grid_sample(
+ semi_descs, norm_kpts.view(b, 1, -1, 2), mode='bilinear', align_corners=True)
+ if norm_desc:
+ descriptors = torch.nn.functional.normalize(
+ descriptors.reshape(b, c, -1), p=2, dim=1)
+ else:
+ descriptors = descriptors.reshape(b, c, -1)
+
+ # print('max: ', torch.min(kpts[:, 1].long()), torch.max(kpts[:, 1].long()), torch.min(kpts[:, 0].long()),
+ # torch.max(kpts[:, 0].long()))
+ scores = score_map[0, kpts[:, 1].long(), kpts[:, 0].long()]
+
+ return scores, descriptors.squeeze(0)
+
+ def extract(self, data):
+ """ Compute keypoints, scores, descriptors for image """
+ # Shared Encoder
+ x = self.relu(self.conv1a(data['image']))
+ x = self.relu(self.conv1b(x))
+ x = self.pool(x)
+ x = self.relu(self.conv2a(x))
+ x = self.relu(self.conv2b(x))
+ x = self.pool(x)
+ x = self.relu(self.conv3a(x))
+ x = self.relu(self.conv3b(x))
+ x = self.pool(x)
+ x = self.relu(self.conv4a(x))
+ x = self.relu(self.conv4b(x))
+
+ # Compute the dense keypoint scores
+ cPa = self.relu(self.convPa(x))
+ scores = self.convPb(cPa)
+ scores = torch.nn.functional.softmax(scores, 1)[:, :-1]
+ b, _, h, w = scores.shape
+ scores = scores.permute(0, 2, 3, 1).reshape(b, h, w, 8, 8)
+ scores = scores.permute(0, 1, 3, 2, 4).reshape(b, h * 8, w * 8)
+
+ # Compute the dense descriptors
+ cDa = self.relu(self.convDa(x))
+ descriptors = self.convDb(cDa)
+ descriptors = torch.nn.functional.normalize(descriptors, p=2, dim=1)
+
+ return scores, descriptors
+
+ def det(self, image):
+ """ Compute keypoints, scores, descriptors for image """
+ # Shared Encoder
+ x = self.relu(self.conv1a(image))
+ x = self.relu(self.conv1b(x))
+ x = self.pool(x)
+ x = self.relu(self.conv2a(x))
+ x = self.relu(self.conv2b(x))
+ x = self.pool(x)
+ x = self.relu(self.conv3a(x))
+ x = self.relu(self.conv3b(x))
+ x = self.pool(x)
+ x = self.relu(self.conv4a(x))
+ x = self.relu(self.conv4b(x))
+
+ # Compute the dense keypoint scores
+ cPa = self.relu(self.convPa(x))
+ scores = self.convPb(cPa)
+ scores = torch.nn.functional.softmax(scores, 1)[:, :-1]
+ # print(scores.shape)
+ b, _, h, w = scores.shape
+ scores = scores.permute(0, 2, 3, 1).reshape(b, h, w, 8, 8)
+ scores = scores.permute(0, 1, 3, 2, 4).reshape(b, h * 8, w * 8)
+
+ # Compute the dense descriptors
+ cDa = self.relu(self.convDa(x))
+ descriptors = self.convDb(cDa)
+ descriptors = torch.nn.functional.normalize(descriptors, p=2, dim=1)
+
+ return scores, descriptors
+
+ def forward(self, data):
+ """ Compute keypoints, scores, descriptors for image """
+ # Shared Encoder
+ x = self.relu(self.conv1a(data['image']))
+ x = self.relu(self.conv1b(x))
+ x = self.pool(x)
+ x = self.relu(self.conv2a(x))
+ x = self.relu(self.conv2b(x))
+ x = self.pool(x)
+ x = self.relu(self.conv3a(x))
+ x = self.relu(self.conv3b(x))
+ x = self.pool(x)
+ x = self.relu(self.conv4a(x))
+ x = self.relu(self.conv4b(x))
+
+ # Compute the dense keypoint scores
+ cPa = self.relu(self.convPa(x))
+ scores = self.convPb(cPa)
+ scores = torch.nn.functional.softmax(scores, 1)[:, :-1]
+ # print(scores.shape)
+ b, _, h, w = scores.shape
+ scores = scores.permute(0, 2, 3, 1).reshape(b, h, w, 8, 8)
+ scores = scores.permute(0, 1, 3, 2, 4).reshape(b, h * 8, w * 8)
+ scores = simple_nms(scores, self.config['nms_radius'])
+
+ # Extract keypoints
+ keypoints = [
+ torch.nonzero(s > self.config['keypoint_threshold'])
+ for s in scores]
+ scores = [s[tuple(k.t())] for s, k in zip(scores, keypoints)]
+
+ # Discard keypoints near the image borders
+ keypoints, scores = list(zip(*[
+ remove_borders(k, s, self.config['remove_borders'], h * 8, w * 8)
+ for k, s in zip(keypoints, scores)]))
+
+ # Keep the k keypoints with highest score
+ if self.config['max_keypoints'] >= 0:
+ keypoints, scores = list(zip(*[
+ top_k_keypoints(k, s, self.config['max_keypoints'])
+ for k, s in zip(keypoints, scores)]))
+
+ # Convert (h, w) to (x, y)
+ keypoints = [torch.flip(k, [1]).float() for k in keypoints]
+
+ # Compute the dense descriptors
+ cDa = self.relu(self.convDa(x))
+ descriptors = self.convDb(cDa)
+ descriptors = torch.nn.functional.normalize(descriptors, p=2, dim=1)
+
+ # Extract descriptors
+ # print(keypoints[0].shape)
+ descriptors = [sample_descriptors(k[None], d[None], 8)[0]
+ for k, d in zip(keypoints, descriptors)]
+
+ return {
+ 'keypoints': keypoints,
+ 'scores': scores,
+ 'descriptors': descriptors,
+ 'global_descriptor': x,
+ }
+
+
+def extract_descriptor(sample_pts, coarse_desc, H, W):
+    '''
+    Sample descriptors from a dense descriptor map at the given keypoint locations.
+    :param sample_pts: keypoint coordinates [2, N] in (x, y) pixel order
+    :param coarse_desc: dense descriptor map [C, H', W']
+    :param H: image height used for coordinate normalization
+    :param W: image width used for coordinate normalization
+    :return: L2-normalized descriptors [C, N]
+    '''
+ with torch.no_grad():
+ norm_sample_pts = torch.zeros_like(sample_pts)
+ norm_sample_pts[0, :] = (sample_pts[0, :] / (float(W) / 2.)) - 1. # x
+ norm_sample_pts[1, :] = (sample_pts[1, :] / (float(H) / 2.)) - 1. # y
+ norm_sample_pts = norm_sample_pts.transpose(0, 1).contiguous()
+ norm_sample_pts = norm_sample_pts.view(1, 1, -1, 2).float()
+ sample_desc = torch.nn.functional.grid_sample(coarse_desc[None], norm_sample_pts, mode='bilinear',
+ align_corners=False)
+ sample_desc = torch.nn.functional.normalize(sample_desc, dim=1).squeeze(2).squeeze(0)
+ return sample_desc
+
+
+def extract_sp_return(model, img, conf_th=0.005,
+ mask=None,
+ topK=-1,
+ **kwargs):
+ old_bm = torch.backends.cudnn.benchmark
+ torch.backends.cudnn.benchmark = False # speedup
+
+ # print(img.shape)
+ img = img.cuda()
+ # if len(img.shape) == 3: # gray image
+ # img = img[None]
+
+ B, one, H, W = img.shape
+
+ all_pts = []
+ all_descs = []
+
+ if 'scales' in kwargs.keys():
+ scales = kwargs.get('scales')
+ else:
+ scales = [1.0]
+
+ for s in scales:
+ if s == 1.0:
+ new_img = img
+ else:
+ nh = int(H * s)
+ nw = int(W * s)
+ new_img = F.interpolate(img, size=(nh, nw), mode='bilinear', align_corners=True)
+ nh, nw = new_img.shape[2:]
+
+ with torch.no_grad():
+ heatmap, coarse_desc = model.det(new_img)
+
+ # print("nh, nw, heatmap, desc: ", nh, nw, heatmap.shape, coarse_desc.shape)
+ if len(heatmap.size()) == 3:
+ heatmap = heatmap.unsqueeze(1)
+ if len(heatmap.size()) == 2:
+ heatmap = heatmap.unsqueeze(0)
+ heatmap = heatmap.unsqueeze(1)
+ # print(heatmap.shape)
+ if heatmap.size(2) != nh or heatmap.size(3) != nw:
+ heatmap = F.interpolate(heatmap, size=[nh, nw], mode='bilinear', align_corners=True)
+
+ conf_thresh = conf_th
+ nms_dist = 4
+ border_remove = 4
+ scores = simple_nms(heatmap, nms_radius=nms_dist)
+ keypoints = [
+ torch.nonzero(s > conf_thresh)
+ for s in scores]
+ scores = [s[tuple(k.t())] for s, k in zip(scores, keypoints)]
+ # print(keypoints[0].shape)
+ keypoints = [torch.flip(k, [1]).float() for k in keypoints]
+ scores = scores[0].data.cpu().numpy().squeeze()
+ keypoints = keypoints[0].data.cpu().numpy().squeeze()
+ pts = keypoints.transpose()
+ pts[2, :] = scores
+
+ inds = np.argsort(pts[2, :])
+ pts = pts[:, inds[::-1]] # Sort by confidence.
+ # Remove points along border.
+ bord = border_remove
+ toremoveW = np.logical_or(pts[0, :] < bord, pts[0, :] >= (W - bord))
+ toremoveH = np.logical_or(pts[1, :] < bord, pts[1, :] >= (H - bord))
+ toremove = np.logical_or(toremoveW, toremoveH)
+ pts = pts[:, ~toremove]
+
+ # valid_idex = heatmap > conf_thresh
+ # valid_score = heatmap[valid_idex]
+ # """
+ # --- Process descriptor.
+ # coarse_desc = coarse_desc.data.cpu().numpy().squeeze()
+ D = coarse_desc.size(1)
+ if pts.shape[1] == 0:
+ desc = np.zeros((D, 0))
+ else:
+ if coarse_desc.size(2) == nh and coarse_desc.size(3) == nw:
+ desc = coarse_desc[:, :, pts[1, :], pts[0, :]]
+ desc = desc.data.cpu().numpy().reshape(D, -1)
+ else:
+ # Interpolate into descriptor map using 2D point locations.
+ samp_pts = torch.from_numpy(pts[:2, :].copy())
+ samp_pts[0, :] = (samp_pts[0, :] / (float(nw) / 2.)) - 1.
+ samp_pts[1, :] = (samp_pts[1, :] / (float(nh) / 2.)) - 1.
+ samp_pts = samp_pts.transpose(0, 1).contiguous()
+ samp_pts = samp_pts.view(1, 1, -1, 2)
+ samp_pts = samp_pts.float()
+ samp_pts = samp_pts.cuda()
+ desc = torch.nn.functional.grid_sample(coarse_desc, samp_pts, mode='bilinear', align_corners=True)
+ desc = desc.data.cpu().numpy().reshape(D, -1)
+ desc /= np.linalg.norm(desc, axis=0)[np.newaxis, :]
+
+ if pts.shape[1] == 0:
+ continue
+
+ # print(pts.shape, heatmap.shape, new_img.shape, img.shape, nw, nh, W, H)
+ pts[0, :] = pts[0, :] * W / nw
+ pts[1, :] = pts[1, :] * H / nh
+ all_pts.append(np.transpose(pts, [1, 0]))
+ all_descs.append(np.transpose(desc, [1, 0]))
+
+ all_pts = np.vstack(all_pts)
+ all_descs = np.vstack(all_descs)
+
+ torch.backends.cudnn.benchmark = old_bm
+
+ if all_pts.shape[0] == 0:
+ return None, None, None
+
+ keypoints = all_pts[:, 0:2]
+ scores = all_pts[:, 2]
+ descriptors = all_descs
+
+ if mask is not None:
+ # cv2.imshow("mask", mask)
+ # cv2.waitKey(0)
+ labels = []
+ others = []
+ keypoints_with_labels = []
+ scores_with_labels = []
+ descriptors_with_labels = []
+ keypoints_without_labels = []
+ scores_without_labels = []
+ descriptors_without_labels = []
+
+ id_img = np.int32(mask[:, :, 2]) * 256 * 256 + np.int32(mask[:, :, 1]) * 256 + np.int32(mask[:, :, 0])
+ # print(img.shape, id_img.shape)
+
+ for i in range(keypoints.shape[0]):
+ x = keypoints[i, 0]
+ y = keypoints[i, 1]
+ # print("x-y", x, y, int(x), int(y))
+ gid = id_img[int(y), int(x)]
+ if gid == 0:
+ keypoints_without_labels.append(keypoints[i])
+ scores_without_labels.append(scores[i])
+ descriptors_without_labels.append(descriptors[i])
+ others.append(0)
+ else:
+ keypoints_with_labels.append(keypoints[i])
+ scores_with_labels.append(scores[i])
+ descriptors_with_labels.append(descriptors[i])
+ labels.append(gid)
+
+ if topK > 0:
+ if topK <= len(keypoints_with_labels):
+ idxes = np.array(scores_with_labels, float).argsort()[::-1][:topK]
+ keypoints = np.array(keypoints_with_labels, float)[idxes]
+ scores = np.array(scores_with_labels, float)[idxes]
+ labels = np.array(labels, np.int32)[idxes]
+ descriptors = np.array(descriptors_with_labels, float)[idxes]
+ elif topK >= len(keypoints_with_labels) + len(keypoints_without_labels):
+ # keypoints = np.vstack([keypoints_with_labels, keypoints_without_labels])
+ # scores = np.vstack([scorescc_with_labels, scores_without_labels])
+ # descriptors = np.vstack([descriptors_with_labels, descriptors_without_labels])
+ # labels = np.vstack([labels, others])
+ keypoints = keypoints_with_labels
+ scores = scores_with_labels
+ descriptors = descriptors_with_labels
+ for i in range(len(others)):
+ keypoints.append(keypoints_without_labels[i])
+ scores.append(scores_without_labels[i])
+ descriptors.append(descriptors_without_labels[i])
+ labels.append(others[i])
+ else:
+ n = topK - len(keypoints_with_labels)
+ idxes = np.array(scores_without_labels, float).argsort()[::-1][:n]
+ keypoints = keypoints_with_labels
+ scores = scores_with_labels
+ descriptors = descriptors_with_labels
+ for i in idxes:
+ keypoints.append(keypoints_without_labels[i])
+ scores.append(scores_without_labels[i])
+ descriptors.append(descriptors_without_labels[i])
+ labels.append(others[i])
+ keypoints = np.array(keypoints, float)
+ descriptors = np.array(descriptors, float)
+ # print(keypoints.shape, descriptors.shape)
+ return {"keypoints": np.array(keypoints, float),
+ "descriptors": np.array(descriptors, float),
+ "scores": np.array(scores, float),
+ "labels": np.array(labels, np.int32),
+ }
+ else:
+ # print(topK)
+ if topK > 0:
+ idxes = np.array(scores, dtype=float).argsort()[::-1][:topK]
+ keypoints = np.array(keypoints[idxes], dtype=float)
+ scores = np.array(scores[idxes], dtype=float)
+ descriptors = np.array(descriptors[idxes], dtype=float)
+
+ keypoints = np.array(keypoints, dtype=float)
+ scores = np.array(scores, dtype=float)
+ descriptors = np.array(descriptors, dtype=float)
+
+ # print(keypoints.shape, descriptors.shape)
+
+ return {"keypoints": np.array(keypoints, dtype=float),
+ "descriptors": descriptors,
+ "scores": scores,
+ }
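+
+
+# Usage sketch (added for clarity; the weight path and thresholds are
+# illustrative):
+#   sp = SuperPoint({'weight_path': 'path/to/superpoint_v1.pth', 'max_keypoints': 4096}).eval().cuda()
+#   pred = extract_sp_return(sp, gray, conf_th=0.005, topK=4096)
+#   # pred: {'keypoints': [N, 2], 'descriptors': [N, 256], 'scores': [N]}
+# where gray is a [1, 1, H, W] grayscale image tensor in [0, 1].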
diff --git a/third_party/pram/nets/utils.py b/third_party/pram/nets/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..066a00510c19e0c87cf5d07a36cea2a90dd0e3eb
--- /dev/null
+++ b/third_party/pram/nets/utils.py
@@ -0,0 +1,24 @@
+# -*- coding: UTF-8 -*-
+'''=================================================
+@Project -> File pram -> utils
+@IDE PyCharm
+@Author fx221@cam.ac.uk
+@Date 07/02/2024 10:48
+=================================================='''
+import torch
+
+eps = 1e-8
+
+
+def arange_like(x, dim: int):
+ return x.new_ones(x.shape[dim]).cumsum(0) - 1
+
+
+def normalize_keypoints(kpts, image_shape):
+ """ Normalize keypoints locations based on image image_shape"""
+ _, _, height, width = image_shape
+ one = kpts.new_tensor(1)
+ size = torch.stack([one * width, one * height])[None]
+ center = size / 2
+ scaling = size.max(1, keepdim=True).values * 0.7
+ return (kpts - center[:, None, :]) / scaling[:, None, :]
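+
+
+# Worked example (added for clarity): for a [B, 1, 480, 640] image,
+# size = (640, 480), center = (320, 240) and scaling = 0.7 * 640 = 448, so a
+# keypoint at (320, 240) maps to (0, 0) and the image corners stay within
+# roughly (-0.71, 0.71) x (-0.54, 0.54).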
diff --git a/third_party/pram/recognition/recmap.py b/third_party/pram/recognition/recmap.py
new file mode 100644
index 0000000000000000000000000000000000000000..c159de286e96fdb594428e88e370e1a7edbecb79
--- /dev/null
+++ b/third_party/pram/recognition/recmap.py
@@ -0,0 +1,1118 @@
+# -*- coding: UTF-8 -*-
+'''=================================================
+@Project -> File pram -> recmap
+@IDE PyCharm
+@Author fx221@cam.ac.uk
+@Date 07/02/2024 11:02
+=================================================='''
+import argparse
+import torch
+import os
+import os.path as osp
+import numpy as np
+import cv2
+import yaml
+import multiprocessing as mp
+from copy import deepcopy
+import logging
+import h5py
+from tqdm import tqdm
+import open3d as o3d
+from sklearn.cluster import KMeans, Birch
+from collections import defaultdict
+from colmap_utils.read_write_model import read_model, qvec2rotmat, write_cameras_binary, write_images_binary
+from colmap_utils.read_write_model import write_points3d_binary, Image, Point3D, Camera
+from colmap_utils.read_write_model import write_compressed_points3d_binary, write_compressed_images_binary
+from recognition.vis_seg import generate_color_dic, vis_seg_point, plot_kpts
+
+
+class RecMap:
+ def __init__(self):
+ self.cameras = None
+ self.images = None
+ self.points3D = None
+ self.pcd = o3d.geometry.PointCloud()
+ self.seg_color_dict = generate_color_dic(n_seg=1000)
+
+ def load_sfm_model(self, path: str, ext='.bin'):
+ self.cameras, self.images, self.points3D = read_model(path, ext)
+ self.name_to_id = {image.name: i for i, image in self.images.items()}
+ print('Load {:d} cameras, {:d} images, {:d} points'.format(len(self.cameras), len(self.images),
+ len(self.points3D)))
+
+ def remove_statics_outlier(self, nb_neighbors: int = 20, std_ratio: float = 2.0):
+ xyzs = []
+ p3d_ids = []
+ for p3d_id in self.points3D.keys():
+ xyzs.append(self.points3D[p3d_id].xyz)
+ p3d_ids.append(p3d_id)
+
+ xyzs = np.array(xyzs)
+ pcd = o3d.geometry.PointCloud()
+ pcd.points = o3d.utility.Vector3dVector(xyzs)
+ new_pcd, inlier_ids = pcd.remove_statistical_outlier(nb_neighbors=nb_neighbors, std_ratio=std_ratio)
+
+ new_point3Ds = {}
+ for i in inlier_ids:
+ new_point3Ds[p3d_ids[i]] = self.points3D[p3d_ids[i]]
+ self.points3D = new_point3Ds
+ n_outlier = xyzs.shape[0] - len(inlier_ids)
+ ratio = n_outlier / xyzs.shape[0]
+ print('Remove {:d} - {:d} = {:d}/{:.2f}% points'.format(xyzs.shape[0], len(inlier_ids), n_outlier, ratio * 100))
+
+ def load_segmentation(self, path: str):
+ data = np.load(path, allow_pickle=True)[()]
+ p3d_id = data['id']
+ seg_id = data['label']
+ self.p3d_seg = {p3d_id[i]: seg_id[i] for i in range(p3d_id.shape[0])}
+ self.seg_p3d = {}
+ for pid in self.p3d_seg.keys():
+ sid = self.p3d_seg[pid]
+ if sid not in self.seg_p3d.keys():
+ self.seg_p3d[sid] = [pid]
+ else:
+ self.seg_p3d[sid].append(pid)
+
+ if 'xyz' not in data.keys():
+ all_xyz = []
+ for pid in p3d_id:
+ xyz = self.points3D[pid].xyz
+ all_xyz.append(xyz)
+ data['xyz'] = np.array(all_xyz)
+ np.save(path, data)
+ print('Add xyz to ', path)
+
+ def cluster(self, k=512, mode='xyz', min_obs=3, save_fn=None, method='kmeans', **kwargs):
+ if save_fn is not None:
+ if osp.isfile(save_fn):
+ print('{:s} exists.'.format(save_fn))
+ return
+ all_xyz = []
+ point3D_ids = []
+ for p3d in self.points3D.values():
+ track_len = len(p3d.point2D_idxs)
+ if track_len < min_obs:
+ continue
+ all_xyz.append(p3d.xyz)
+ point3D_ids.append(p3d.id)
+
+ xyz = np.array(all_xyz)
+ point3D_ids = np.array(point3D_ids)
+
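+ # zero out any axis not named in `mode` so clustering only uses the selected coordinates (e.g. mode='xz' clusters on x/z only)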
+ if mode.find('x') < 0:
+ xyz[:, 0] = 0
+ if mode.find('y') < 0:
+ xyz[:, 1] = 0
+ if mode.find('z') < 0:
+ xyz[:, 2] = 0
+
+ if method == 'kmeans':
+ model = KMeans(n_clusters=k, random_state=0, verbose=True).fit(xyz)
+ elif method == 'birch':
+ model = Birch(threshold=kwargs.get('threshold'), n_clusters=k).fit(xyz) # 0.01 for indoor
+ else:
+ print('Method {:s} for clustering does not exist'.format(method))
+ exit(0)
+ labels = np.array(model.labels_).reshape(-1)
+ if save_fn is not None:
+ np.save(save_fn, {
+ 'id': np.array(point3D_ids), # should be assigned to self.points3D_ids
+ 'label': np.array(labels),
+ 'xyz': np.array(all_xyz),
+ })
+
+ def assign_point3D_descriptor(self, feature_fn: str, save_fn=None, n_process=1):
+ '''
+ assign each 3d point a descriptor for localization
+ :param feature_fn: file name of features [h5py]
+ :param save_fn: path of the output .npy file storing the {point3D_id: descriptor} dict
+ :param n_process: number of worker processes used to split the 3D points
+ :return:
+ '''
+
+ def run(start_id, end_id, points3D_desc):
+ for pi in tqdm(range(start_id, end_id), total=end_id - start_id):
+ p3d_id = all_p3d_ids[pi]
+ img_list = self.points3D[p3d_id].image_ids
+ kpt_ids = self.points3D[p3d_id].point2D_idxs
+ all_descs = []
+ for img_id, p2d_id in zip(img_list, kpt_ids):
+ if img_id not in self.images.keys():
+ continue
+ img_fn = self.images[img_id].name
+ desc = feat_file[img_fn]['descriptors'][()].transpose()[p2d_id]
+ all_descs.append(desc)
+
+ if len(all_descs) == 1:
+ points3D_desc[p3d_id] = all_descs[0]
+ else:
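+ # pick the most representative observation: the descriptor with the smallest median distance to the others
+ # (2 - 2 * dot product equals the squared L2 distance, assuming the descriptors are L2-normalized)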
+ all_descs = np.array(all_descs) # [n, d]
+ dist = all_descs @ all_descs.transpose() # [n, n]
+ dist = 2 - 2 * dist
+ md_dist = np.median(dist, axis=-1) # [n]
+ min_id = np.argmin(md_dist)
+ points3D_desc[p3d_id] = all_descs[min_id]
+
+ if osp.isfile(save_fn):
+ print('{:s} exists.'.format(save_fn))
+ return
+ p3D_desc = {}
+ feat_file = h5py.File(feature_fn, 'r')
+ all_p3d_ids = sorted(self.points3D.keys())
+
+ if n_process > 1:
+ if len(all_p3d_ids) <= n_process:
+ run(start_id=0, end_id=len(all_p3d_ids), points3D_desc=p3D_desc)
+ else:
+ manager = mp.Manager()
+ output = manager.dict()  # a managed dict is required so results written by worker processes are visible in the parent
+ n_sample_per_process = len(all_p3d_ids) // n_process
+ jobs = []
+ for i in range(n_process):
+ start_id = i * n_sample_per_process
+ if i == n_process - 1:
+ end_id = len(all_p3d_ids)
+ else:
+ end_id = (i + 1) * n_sample_per_process
+ p = mp.Process(
+ target=run,
+ args=(start_id, end_id, output),
+ )
+ jobs.append(p)
+ p.start()
+
+ for p in jobs:
+ p.join()
+
+ p3D_desc = {}
+ for k in output.keys():
+ p3D_desc[k] = output[k]
+ else:
+ run(start_id=0, end_id=len(all_p3d_ids), points3D_desc=p3D_desc)
+
+ if save_fn is not None:
+ np.save(save_fn, p3D_desc)
+
+ def reproject(self, img_id, xyzs):
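+ # project world-frame points into image img_id; returns [N, 3] rows of (u, v, depth), with u and v already divided by depth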
+ qvec = self.images[img_id].qvec
+ Rcw = qvec2rotmat(qvec=qvec)
+ tvec = self.images[img_id].tvec
+ tcw = tvec.reshape(3, )
+ Tcw = np.eye(4, dtype=float)
+ Tcw[:3, :3] = Rcw
+ Tcw[:3, 3] = tcw
+ # intrinsics
+ cam = self.cameras[self.images[img_id].camera_id]
+ K = self.get_intrinsics_from_camera(camera=cam)
+
+ xyzs_homo = np.hstack([xyzs, np.ones(shape=(xyzs.shape[0], 1), dtype=float)])
+ kpts = K @ ((Tcw @ xyzs_homo.transpose())[:3, :]) # [3, N]
+ kpts = kpts.transpose() # [N, 3]
+ kpts[:, 0] = kpts[:, 0] / kpts[:, 2]
+ kpts[:, 1] = kpts[:, 1] / kpts[:, 2]
+
+ return kpts
+
+ def find_covisible_frame_ids(self, image_id, images, points3D):
+ covis = defaultdict(int)
+ p3d_ids = images[image_id].point3D_ids
+
+ for pid in p3d_ids:
+ if pid == -1:
+ continue
+ if pid not in points3D.keys():
+ continue
+ for im in points3D[pid].image_ids:
+ covis[im] += 1
+
+ covis_ids = np.array(list(covis.keys()))
+ covis_num = np.array([covis[i] for i in covis_ids])
+ ind_top = np.argsort(covis_num)[::-1]
+ sorted_covis_ids = [covis_ids[i] for i in ind_top]
+ return sorted_covis_ids
+
+ def create_virtual_frame_3(self, save_fn=None, save_vrf_dir=None, show_time=-1, ignored_cameras=[],
+ min_cover_ratio=0.9,
+ depth_scale=1.2,
+ radius=15,
+ min_obs=120,
+ topk_imgs=500,
+ n_vrf=10,
+ covisible_frame=20,
+ **kwargs):
+ def reproject(img_id, xyzs):
+ qvec = self.images[img_id].qvec
+ Rcw = qvec2rotmat(qvec=qvec)
+ tvec = self.images[img_id].tvec
+ tcw = tvec.reshape(3, )
+ Tcw = np.eye(4, dtype=float)
+ Tcw[:3, :3] = Rcw
+ Tcw[:3, 3] = tcw
+ # intrinsics
+ cam = self.cameras[self.images[img_id].camera_id]
+ K = self.get_intrinsics_from_camera(camera=cam)
+
+ xyzs_homo = np.hstack([xyzs, np.ones(shape=(xyzs.shape[0], 1), dtype=float)])
+ kpts = K @ ((Tcw @ xyzs_homo.transpose())[:3, :]) # [3, N]
+ kpts = kpts.transpose() # [N, 3]
+ kpts[:, 0] = kpts[:, 0] / kpts[:, 2]
+ kpts[:, 1] = kpts[:, 1] / kpts[:, 2]
+
+ return kpts
+
+ def find_best_vrf_by_covisibility(p3d_id_list):
+ all_img_ids = []
+ all_xyzs = []
+
+ img_ids_full = []
+ img_id_obs = {}
+ for pid in p3d_id_list:
+ if pid not in self.points3D.keys():
+ continue
+ all_xyzs.append(self.points3D[pid].xyz)
+
+ img_ids = self.points3D[pid].image_ids
+ for iid in img_ids:
+ if iid in all_img_ids:
+ continue
+ # valid_p3ds = [v for v in self.images[iid].point3D_ids if v > 0 and v in p3d_id_list]
+ if len(ignored_cameras) > 0:
+ ignore = False
+ img_name = self.images[iid].name
+ for c in ignored_cameras:
+ if img_name.find(c) >= 0:
+ ignore = True
+ break
+ if ignore:
+ continue
+ # valid_p3ds = np.intersect1d(np.array(self.images[iid].point3D_ids), np.array(p3d_id_list)).tolist()
+ valid_p3ds = [v for v in self.images[iid].point3D_ids if v > 0]
+ img_ids_full.append(iid)
+ if len(valid_p3ds) < min_obs:
+ continue
+
+ all_img_ids.append(iid)
+ img_id_obs[iid] = len(valid_p3ds)
+ all_xyzs = np.array(all_xyzs)
+
+ print('Find {} 3D points and {} images'.format(len(p3d_id_list), len(img_id_obs.keys())))
+ top_img_ids_by_obs = sorted(img_id_obs.items(), key=lambda item: item[1], reverse=True) # [(key, value), ]
+ all_img_ids = []
+ for item in top_img_ids_by_obs:
+ all_img_ids.append(item[0])
+ if len(all_img_ids) >= topk_imgs:
+ break
+
+ # all_img_ids = all_img_ids[:200]
+ if len(all_img_ids) == 0:
+ print('no valid img ids with obs over {:d}'.format(min_obs))
+ all_img_ids = img_ids_full
+
+ img_observations = {}
+ p3d_id_array = np.array(p3d_id_list)
+ for idx, img_id in enumerate(all_img_ids):
+ valid_p3ds = [v for v in self.images[img_id].point3D_ids if v > 0]
+ mask = np.array([False for i in range(len(p3d_id_list))])
+ for pid in valid_p3ds:
+ found_idx = np.where(p3d_id_array == pid)[0]
+ if found_idx.shape[0] == 0:
+ continue
+ mask[found_idx[0]] = True
+
+ img_observations[img_id] = mask
+
+ unobserved_p3d_ids = np.array([True for i in range(len(p3d_id_list))])
+
+ candidate_img_ids = []
+ total_cover_ratio = 0
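+ # greedy set cover: repeatedly take the image observing the most still-uncovered 3D points until min_cover_ratio, n_vrf images, or a negligible (<1%) gain is reached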
+ while total_cover_ratio < min_cover_ratio:
+ best_img_id = -1
+ best_img_obs = -1
+ for idx, im_id in enumerate(all_img_ids):
+ if im_id in candidate_img_ids:
+ continue
+ obs_i = np.sum(img_observations[im_id] * unobserved_p3d_ids)
+ if obs_i > best_img_obs:
+ best_img_id = im_id
+ best_img_obs = obs_i
+
+ if best_img_id >= 0:
+ # keep the valid img_id
+ candidate_img_ids.append(best_img_id)
+ # update the unobserved mask
+ unobserved_p3d_ids[img_observations[best_img_id]] = False
+ total_cover_ratio = 1 - np.sum(unobserved_p3d_ids) / len(p3d_id_list)
+ print(len(candidate_img_ids), best_img_obs, best_img_obs / len(p3d_id_list), total_cover_ratio)
+
+ if best_img_obs / len(p3d_id_list) < 0.01:
+ break
+
+ if len(candidate_img_ids) >= n_vrf:
+ break
+ else:
+ break
+
+ return candidate_img_ids
+ # return [(v, img_observations[v]) for v in candidate_img_ids]
+
+ if save_vrf_dir is not None:
+ os.makedirs(save_vrf_dir, exist_ok=True)
+
+ seg_ref = {}
+ for sid in self.seg_p3d.keys():
+ if sid == -1: # ignore invalid segment
+ continue
+ all_p3d_ids = self.seg_p3d[sid]
+ candidate_img_ids = find_best_vrf_by_covisibility(p3d_id_list=all_p3d_ids)
+
+ seg_ref[sid] = {}
+ for can_idx, img_id in enumerate(candidate_img_ids):
+ cam = self.cameras[self.images[img_id].camera_id]
+ width = cam.width
+ height = cam.height
+ qvec = self.images[img_id].qvec
+ tvec = self.images[img_id].tvec
+
+ img_name = self.images[img_id].name
+ orig_p3d_ids = [p for p in self.images[img_id].point3D_ids if p in self.points3D.keys() and p >= 0]
+ orig_xyzs = []
+ new_xyzs = []
+ for pid in all_p3d_ids:
+ if pid in orig_p3d_ids:
+ orig_xyzs.append(self.points3D[pid].xyz)
+ else:
+ if pid in self.points3D.keys():
+ new_xyzs.append(self.points3D[pid].xyz)
+
+ if len(orig_xyzs) == 0:
+ continue
+
+ orig_xyzs = np.array(orig_xyzs)
+ new_xyzs = np.array(new_xyzs)
+
+ print('img: ', osp.join(kwargs.get('image_root'), img_name))
+ img = cv2.imread(osp.join(kwargs.get('image_root'), img_name))
+ orig_kpts = reproject(img_id=img_id, xyzs=orig_xyzs)
+ max_depth = depth_scale * np.max(orig_kpts[:, 2])
+ orig_kpts = orig_kpts[:, :2]
+ mask_ori = (orig_kpts[:, 0] >= 0) & (orig_kpts[:, 0] < width) & (orig_kpts[:, 1] >= 0) & (
+ orig_kpts[:, 1] < height)
+ orig_kpts = orig_kpts[mask_ori]
+
+ if orig_kpts.shape[0] == 0:
+ continue
+
+ img_kpt = plot_kpts(img=img, kpts=orig_kpts, radius=[3 for i in range(orig_kpts.shape[0])],
+ colors=[(0, 0, 255) for i in range(orig_kpts.shape[0])], thickness=-1)
+ if new_xyzs.shape[0] == 0:
+ img_all = img_kpt
+ else:
+ new_kpts = reproject(img_id=img_id, xyzs=new_xyzs)
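+ # keep unobserved points only if they have positive depth no larger than depth_scale x the frame's max observed depth, project inside the image, and fall within `radius` pixels of an observed keypoint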
+ mask_depth = (new_kpts[:, 2] > 0) & (new_kpts[:, 2] <= max_depth)
+ mask_in_img = (new_kpts[:, 0] >= 0) & (new_kpts[:, 0] < width) & (new_kpts[:, 1] >= 0) & (
+ new_kpts[:, 1] < height)
+ dist_all_orig = torch.from_numpy(new_kpts[:, :2])[..., None] - \
+ torch.from_numpy(orig_kpts[:, :2].transpose())[None]
+ dist_all_orig = torch.sqrt(torch.sum(dist_all_orig ** 2, dim=1)) # [N, M]
+ min_dist = torch.min(dist_all_orig, dim=1)[0].numpy()
+ mask_close_to_img = (min_dist <= radius)
+
+ mask_new = (mask_depth & mask_in_img & mask_close_to_img)
+
+ cover_ratio = np.sum(mask_ori) + np.sum(mask_new)
+ cover_ratio = cover_ratio / len(all_p3d_ids)
+
+ print('idx: {:d}, img: ori {:d}/{:d}/{:.2f}, new {:d}/{:d}'.format(can_idx,
+ orig_kpts.shape[0],
+ np.sum(mask_ori),
+ cover_ratio * 100,
+ new_kpts.shape[0],
+ np.sum(mask_new)))
+
+ new_kpts = new_kpts[mask_new]
+
+ # img_all = img_kpt
+ img_all = plot_kpts(img=img_kpt, kpts=new_kpts, radius=[3 for i in range(new_kpts.shape[0])],
+ colors=[(0, 255, 0) for i in range(new_kpts.shape[0])], thickness=-1)
+
+ cv2.namedWindow('img', cv2.WINDOW_NORMAL)
+ cv2.imshow('img', img_all)
+
+ if save_vrf_dir is not None:
+ cv2.imwrite(osp.join(save_vrf_dir,
+ 'seg-{:05d}_can-{:05d}_'.format(sid, can_idx) + img_name.replace('/', '+')),
+ img_all)
+
+ key = cv2.waitKey(show_time)
+ if key == ord('q'):
+ cv2.destroyAllWindows()
+ exit(0)
+
+ covisible_frame_ids = self.find_covisible_frame_ids(image_id=img_id, images=self.images,
+ points3D=self.points3D)
+ seg_ref[sid][can_idx] = {
+ 'image_name': img_name,
+ 'image_id': img_id,
+ 'qvec': deepcopy(qvec),
+ 'tvec': deepcopy(tvec),
+ 'camera': {
+ 'model': cam.model,
+ 'params': cam.params,
+ 'width': cam.width,
+ 'height': cam.height,
+ },
+ 'original_points3d': np.array(
+ [v for v in self.images[img_id].point3D_ids if v >= 0 and v in self.points3D.keys()]),
+ 'covisible_frame_ids': np.array(covisible_frame_ids[:covisible_frame]),
+ }
+ # save vrf info
+ if save_fn is not None:
+ print('Save {} segments with virtual reference image information to {}'.format(len(seg_ref.keys()),
+ save_fn))
+ np.save(save_fn, seg_ref)
+
+ def visualize_3Dpoints(self):
+ xyz = []
+ rgb = []
+ for point3D in self.points3D.values():
+ xyz.append(point3D.xyz)
+ rgb.append(point3D.rgb / 255)
+
+ pcd = o3d.geometry.PointCloud()
+ pcd.points = o3d.utility.Vector3dVector(xyz)
+ pcd.colors = o3d.utility.Vector3dVector(rgb)
+ o3d.visualization.draw_geometries([pcd])
+
+ def visualize_segmentation(self, p3d_segs, points3D):
+ p3d_ids = p3d_segs.keys()
+ xyzs = []
+ rgbs = []
+ for pid in p3d_ids:
+ xyzs.append(points3D[pid].xyz)
+ seg_color = self.seg_color_dict[p3d_segs[pid]]
+ rgbs.append(np.array([seg_color[2], seg_color[1], seg_color[0]]) / 255)
+ xyzs = np.array(xyzs)
+ rgbs = np.array(rgbs)
+
+ self.pcd.points = o3d.utility.Vector3dVector(xyzs)
+ self.pcd.colors = o3d.utility.Vector3dVector(rgbs)
+
+ o3d.visualization.draw_geometries([self.pcd])
+
+ def visualize_segmentation_on_image(self, p3d_segs, image_path, feat_path):
+ vis_color = generate_color_dic(n_seg=1024)
+ feat_file = h5py.File(feat_path, 'r')
+
+ cv2.namedWindow('img', cv2.WINDOW_NORMAL)
+ for mi in sorted(self.images.keys()):
+ im = self.images[mi]
+ im_name = im.name
+ p3d_ids = im.point3D_ids
+ p2ds = feat_file[im_name]['keypoints'][()]
+ image = cv2.imread(osp.join(image_path, im_name))
+ print('img_name: ', im_name)
+
+ sems = []
+ for pid in p3d_ids:
+ if pid in p3d_segs.keys():
+ sems.append(p3d_segs[pid] + 1)
+ else:
+ sems.append(0)
+ sems = np.array(sems)
+
+ mask = sems > 0
+ img_seg = vis_seg_point(img=image, kpts=p2ds[mask], segs=sems[mask], seg_color=vis_color)
+
+ cv2.imshow('img', img_seg)
+ key = cv2.waitKey(0)
+ if key == ord('q'):
+ exit(0)
+ elif key == ord('r'):
+ # cv2.destroyAllWindows()
+ return
+
+ def extract_query_p3ds(self, log_fn, feat_fn, save_fn=None):
+ if save_fn is not None:
+ if osp.isfile(save_fn):
+ print('{:s} exists'.format(save_fn))
+ return
+
+ loc_log = np.load(log_fn, allow_pickle=True)[()]
+ fns = loc_log.keys()
+ feat_file = h5py.File(feat_fn, 'r')
+
+ out = {}
+ for fn in tqdm(fns, total=len(fns)):
+ matched_kpts = loc_log[fn]['keypoints_query']
+ matched_p3ds = loc_log[fn]['points3D_ids']
+
+ query_kpts = feat_file[fn]['keypoints'][()].astype(float)
+ query_p3d_ids = np.zeros(shape=(query_kpts.shape[0],), dtype=int) - 1
+ print('matched kpts: {}, query kpts: {}'.format(matched_kpts.shape[0], query_kpts.shape[0]))
+
+ if matched_kpts.shape[0] > 0:
+ # [M, 2, 1] - [1, 2, N] = [M, 2, N]
+ dist = torch.from_numpy(matched_kpts).unsqueeze(-1) - torch.from_numpy(
+ query_kpts.transpose()).unsqueeze(0)
+ dist = torch.sum(dist ** 2, dim=1) # [M, N]
+ values, idxes = torch.topk(dist, dim=1, largest=False, k=1) # find the matches kpts with dist of 0
+ values = values.numpy()
+ idxes = idxes.numpy()
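+ # accept a match only if the nearest query keypoint is essentially at the same pixel (squared distance < 1)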
+ for i in range(values.shape[0]):
+ if values[i, 0] < 1:
+ query_p3d_ids[idxes[i, 0]] = matched_p3ds[i]
+
+ out[fn] = query_p3d_ids
+ np.save(save_fn, out)
+ feat_file.close()
+
+ def compute_mean_scale_p3ds(self, min_obs=5, save_fn=None):
+ if save_fn is not None:
+ if osp.isfile(save_fn):
+ with open(save_fn, 'r') as f:
+ lines = f.readlines()
+ l = lines[0].strip().split()
+ self.mean_xyz = np.array([float(v) for v in l[:3]])
+ self.scale_xyz = np.array([float(v) for v in l[3:]])
+ print('{} exists'.format(save_fn))
+ return
+
+ all_xyzs = []
+ for pid in self.points3D.keys():
+ p3d = self.points3D[pid]
+ obs = len(p3d.point2D_idxs)
+ if obs < min_obs:
+ continue
+ all_xyzs.append(p3d.xyz)
+
+ all_xyzs = np.array(all_xyzs)
+ mean_xyz = np.ceil(np.mean(all_xyzs, axis=0))
+ all_xyz_ = all_xyzs - mean_xyz
+
+ dx = np.max(abs(all_xyz_[:, 0]))
+ dy = np.max(abs(all_xyz_[:, 1]))
+ dz = np.max(abs(all_xyz_[:, 2]))
+ scale_xyz = np.ceil(np.array([dx, dy, dz], dtype=float).reshape(3, ))
+ scale_xyz[scale_xyz < 1] = 1
+ scale_xyz[scale_xyz == 0] = 1
+
+ # self.mean_xyz = mean_xyz
+ # self.scale_xyz = scale_xyz
+ #
+ # if save_fn is not None:
+ # with open(save_fn, 'w') as f:
+ # text = '{:.4f} {:.4f} {:.4f} {:.4f} {:.4f} {:.4f}'.format(mean_xyz[0], mean_xyz[1], mean_xyz[2],
+ # scale_xyz[0], scale_xyz[1], scale_xyz[2])
+ # f.write(text + '\n')
+
+ def compute_statics_inlier(self, xyz, nb_neighbors=20, std_ratio=2.0):
+ pcd = o3d.geometry.PointCloud()
+ pcd.points = o3d.utility.Vector3dVector(xyz)
+
+ new_pcd, inlier_ids = pcd.remove_statistical_outlier(nb_neighbors=nb_neighbors, std_ratio=std_ratio)
+ return inlier_ids
+
+ def export_features_to_directory(self, feat_fn, save_dir, with_descriptors=True):
+ def print_grp_name(grp_name, object):
+ try:
+ n_subgroups = len(object.keys())
+ except:
+ n_subgroups = 0
+ dataset_list.append(object.name)
+
+ dataset_list = []
+ feat_file = h5py.File(feat_fn, 'r')
+ feat_file.visititems(print_grp_name)
+ all_keys = []
+ os.makedirs(save_dir, exist_ok=True)
+ for fn in dataset_list:
+ subs = fn[1:].split('/')[:-1] # remove the first '/'
+ subs = '/'.join(map(str, subs))
+ if subs in all_keys:
+ continue
+ all_keys.append(subs)
+
+ for fn in tqdm(all_keys, total=len(all_keys)):
+ feat = feat_file[fn]
+ data = {
+ # 'descriptors': feat['descriptors'][()].transpose(),
+ 'scores': feat['scores'][()],
+ 'keypoints': feat['keypoints'][()],
+ 'image_size': feat['image_size'][()]
+ }
+ np.save(osp.join(save_dir, fn.replace('/', '+')), data)
+ feat_file.close()
+
+ def get_intrinsics_from_camera(self, camera):
+ if camera.model in ("SIMPLE_PINHOLE", "SIMPLE_RADIAL", "RADIAL"):
+ fx = fy = camera.params[0]
+ cx = camera.params[1]
+ cy = camera.params[2]
+ elif camera.model in ("PINHOLE", "OPENCV", "OPENCV_FISHEYE", "FULL_OPENCV"):
+ fx = camera.params[0]
+ fy = camera.params[1]
+ cx = camera.params[2]
+ cy = camera.params[3]
+ else:
+ raise Exception("Camera model not supported")
+
+ # intrinsics
+ K = np.identity(3)
+ K[0, 0] = fx
+ K[1, 1] = fy
+ K[0, 2] = cx
+ K[1, 2] = cy
+ return K
+
+ def compress_map_by_projection_v2(self, vrf_path, point3d_desc_path, vrf_frames=1, covisible_frames=20, radius=20,
+ nkpts=-1, save_dir=None):
+ def sparsify_by_grid(h, w, uvs, scores):
+ nh = np.ceil(h / radius).astype(int)
+ nw = np.ceil(w / radius).astype(int)
+ grid = {}
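+ # grid sparsification: split the image into radius x radius cells and keep only the highest-score keypoint in each cell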
+ for ip in range(uvs.shape[0]):
+ p = uvs[ip]
+ iw = np.rint(p[0] // radius).astype(int)
+ ih = np.rint(p[1] // radius).astype(int)
+ idx = ih * nw + iw
+ if idx in grid.keys():
+ if scores[ip] <= grid[idx]['score']:
+ continue
+ else:
+ grid[idx]['score'] = scores[ip]
+ grid[idx]['ip'] = ip
+ else:
+ grid[idx] = {
+ 'score': scores[ip],
+ 'ip': ip
+ }
+
+ retained_ips = [grid[v]['ip'] for v in grid.keys()]
+ retained_ips = np.array(retained_ips)
+ return retained_ips
+
+ def choose_valid_p3ds(current_frame_id, covisible_frame_ids, reserved_images):
+ curr_p3d_ids = []
+ curr_xyzs = []
+ for pid in self.images[current_frame_id].point3D_ids:
+ if pid == -1:
+ continue
+ if pid not in self.points3D.keys():
+ continue
+ curr_p3d_ids.append(pid)
+ curr_xyzs.append(self.points3D[pid].xyz)
+ curr_xyzs = np.array(curr_xyzs) # [N, 3]
+ curr_xyzs_homo = np.hstack([curr_xyzs, np.ones((curr_xyzs.shape[0], 1), dtype=curr_xyzs.dtype)]) # [N, 4]
+
+ curr_mask = np.array([True for mi in range(curr_xyzs.shape[0])]) # keep all at first
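+ # drop points already covered by an earlier retained frame: they project with positive depth inside that frame and land within `radius` pixels of one of its kept keypoints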
+ for iim in covisible_frame_ids:
+ cam_id = self.images[iim].camera_id
+ width = self.cameras[cam_id].width
+ height = self.cameras[cam_id].height
+ qvec = self.images[iim].qvec
+ tcw = self.images[iim].tvec
+ Rcw = qvec2rotmat(qvec=qvec)
+ Tcw = np.eye(4, dtype=float)
+ Tcw[:3, :3] = Rcw
+ Tcw[:3, 3] = tcw.reshape(3, )
+
+ uvs = reserved_images[iim]['xys']
+ K = self.get_intrinsics_from_camera(camera=self.cameras[cam_id])
+ proj_xys = K @ (Tcw @ curr_xyzs_homo.transpose())[:3, :] # [3, ]
+ proj_xys = proj_xys.transpose()
+ depth = proj_xys[:, 2]
+ proj_xys[:, 0] = proj_xys[:, 0] / depth
+ proj_xys[:, 1] = proj_xys[:, 1] / depth
+
+ mask_in_image = (proj_xys[:, 0] >= 0) * (proj_xys[:, 0] < width) * (proj_xys[:, 1] >= 0) * (
+ proj_xys[:, 1] < height)
+ mask_depth = proj_xys[:, 2] > 0
+
+ dist_proj_uv = torch.from_numpy(proj_xys[:, :2])[..., None] - \
+ torch.from_numpy(uvs[:, :2].transpose())[None]
+ dist_proj_uv = torch.sqrt(torch.sum(dist_proj_uv ** 2, dim=1)) # [N, M]
+ min_dist = torch.min(dist_proj_uv, dim=1)[0].numpy()
+ mask_close_to_img = (min_dist <= radius)
+
+ mask = mask_in_image * mask_depth * mask_close_to_img # p3ds to be discarded
+
+ curr_mask = curr_mask * (1 - mask)
+
+ chosen_p3d_ids = []
+ for mi in range(curr_mask.shape[0]):
+ if curr_mask[mi]:
+ chosen_p3d_ids.append(curr_p3d_ids[mi])
+
+ return chosen_p3d_ids
+
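+ # map compression: start from the images referenced by the virtual reference frames (VRFs),
+ # then add up to `covisible_frames` covisible images per VRF image, keeping from each newly
+ # added image only the 3D points not already covered by previously retained frames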
+ vrf_data = np.load(vrf_path, allow_pickle=True)[()]
+ p3d_ids_in_vrf = []
+ image_ids_in_vrf = []
+ for sid in vrf_data.keys():
+ svrf = vrf_data[sid]
+ svrf_keys = [vi for vi in range(vrf_frames)]
+ for vi in svrf_keys:
+ if vi not in svrf.keys():
+ continue
+ image_id = svrf[vi]['image_id']
+ if image_id in image_ids_in_vrf:
+ continue
+ image_ids_in_vrf.append(image_id)
+ for pid in svrf[vi]['original_points3d']:
+ if pid in p3d_ids_in_vrf:
+ continue
+ p3d_ids_in_vrf.append(pid)
+
+ print('Find {:d} images and {:d} 3D points in vrf'.format(len(image_ids_in_vrf), len(p3d_ids_in_vrf)))
+
+ # first_vrf_images_covis = {}
+ retained_image_ids = {}
+ for frame_id in image_ids_in_vrf:
+ observed = self.images[frame_id].point3D_ids
+ xys = self.images[frame_id].xys
+ covis = defaultdict(int)
+ valid_xys = []
+ valid_p3d_ids = []
+ for xy, pid in zip(xys, observed):
+ if pid == -1:
+ continue
+ if pid not in self.points3D.keys():
+ continue
+ valid_xys.append(xy)
+ valid_p3d_ids.append(pid)
+ for img_id in self.points3D[pid].image_ids:
+ covis[img_id] += 1
+
+ retained_image_ids[frame_id] = {
+ 'xys': np.array(valid_xys),
+ 'p3d_ids': valid_p3d_ids,
+ }
+
+ print('Find {:d} valid connected frames'.format(len(covis.keys())))
+
+ covis_ids = np.array(list(covis.keys()))
+ covis_num = np.array([covis[i] for i in covis_ids])
+
+ if len(covis_ids) <= covisible_frames:
+ sel_covis_ids = covis_ids[np.argsort(-covis_num)]
+ else:
+ ind_top = np.argpartition(covis_num, -covisible_frames)
+ ind_top = ind_top[-covisible_frames:] # unsorted top k
+ ind_top = ind_top[np.argsort(-covis_num[ind_top])]
+ sel_covis_ids = [covis_ids[i] for i in ind_top]
+
+ covis_frame_ids = [frame_id]
+ for iim in sel_covis_ids:
+ if iim == frame_id:
+ continue
+ if iim in retained_image_ids.keys():
+ covis_frame_ids.append(iim)
+ continue
+
+ chosen_p3d_ids = choose_valid_p3ds(current_frame_id=iim, covisible_frame_ids=covis_frame_ids,
+ reserved_images=retained_image_ids)
+ if len(chosen_p3d_ids) == 0:
+ continue
+
+ xys = []
+ for xy, pid in zip(self.images[iim].xys, self.images[iim].point3D_ids):
+ if pid in chosen_p3d_ids:
+ xys.append(xy)
+ xys = np.array(xys)
+
+ covis_frame_ids.append(iim)
+ retained_image_ids[iim] = {
+ 'xys': xys,
+ 'p3d_ids': chosen_p3d_ids,
+ }
+
+ new_images = {}
+ new_point3Ds = {}
+ new_cameras = {}
+ for iim in retained_image_ids.keys():
+ p3d_ids = retained_image_ids[iim]['p3d_ids']
+ ''' this step reduces the performance
+ for v in self.images[iim].point3D_ids:
+ if v == -1 or v not in self.points3D:
+ continue
+ if v in p3d_ids:
+ continue
+ p3d_ids.append(v)
+ '''
+
+ xyzs = np.array([self.points3D[pid].xyz for pid in p3d_ids])
+ obs = np.array([len(self.points3D[pid].point2D_idxs) for pid in p3d_ids])
+ xys = self.images[iim].xys
+ cam_id = self.images[iim].camera_id
+ name = self.images[iim].name
+ qvec = self.images[iim].qvec
+ tvec = self.images[iim].tvec
+
+ if nkpts > 0 and len(p3d_ids) > nkpts:
+ proj_uvs = self.reproject(img_id=iim, xyzs=xyzs)
+ width = self.cameras[cam_id].width
+ height = self.cameras[cam_id].height
+ sparsified_idxs = sparsify_by_grid(h=height, w=width, uvs=proj_uvs[:, :2], scores=obs)
+
+ print('org / new kpts: ', len(p3d_ids), sparsified_idxs.shape)
+
+ p3d_ids = [p3d_ids[k] for k in sparsified_idxs]
+
+ new_images[iim] = Image(id=iim, qvec=qvec, tvec=tvec,
+ camera_id=cam_id,
+ name=name,
+ xys=np.array([]),
+ point3D_ids=np.array(p3d_ids))
+
+ if cam_id not in new_cameras.keys():
+ new_cameras[cam_id] = self.cameras[cam_id]
+
+ for pid in p3d_ids:
+ if pid in new_point3Ds.keys():
+ new_point3Ds[pid]['image_ids'].append(iim)
+ else:
+ xyz = self.points3D[pid].xyz
+ rgb = self.points3D[pid].rgb
+ error = self.points3D[pid].error
+
+ new_point3Ds[pid] = {
+ 'image_ids': [iim],
+ 'rgb': rgb,
+ 'xyz': xyz,
+ 'error': error
+ }
+
+ new_point3Ds_to_save = {}
+ for pid in new_point3Ds.keys():
+ image_ids = new_point3Ds[pid]['image_ids']
+ if len(image_ids) == 0:
+ continue
+ xyz = new_point3Ds[pid]['xyz']
+ rgb = new_point3Ds[pid]['rgb']
+ error = new_point3Ds[pid]['error']
+
+ new_point3Ds_to_save[pid] = Point3D(id=pid, xyz=xyz, rgb=rgb, error=error, image_ids=np.array(image_ids),
+ point2D_idxs=np.array([]))
+
+ print('Retain {:d}/{:d} images and {:d}/{:d} 3D points'.format(len(new_images), len(self.images),
+ len(new_point3Ds), len(self.points3D)))
+
+ if save_dir is not None:
+ os.makedirs(save_dir, exist_ok=True)
+ # write_images_binary(images=new_image_ids,
+ # path_to_model_file=osp.join(save_dir, 'images.bin'))
+ # write_points3d_binary(points3D=new_point3Ds,
+ # path_to_model_file=osp.join(save_dir, 'points3D.bin'))
+ write_compressed_images_binary(images=new_images,
+ path_to_model_file=osp.join(save_dir, 'images.bin'))
+ write_cameras_binary(cameras=new_cameras,
+ path_to_model_file=osp.join(save_dir, 'cameras.bin'))
+ write_compressed_points3d_binary(points3D=new_point3Ds_to_save,
+ path_to_model_file=osp.join(save_dir, 'points3D.bin'))
+
+ # Save 3d descriptors
+ p3d_desc = np.load(point3d_desc_path, allow_pickle=True)[()]
+ comp_p3d_desc = {}
+ for k in new_point3Ds_to_save.keys():
+ if k not in p3d_desc.keys():
+ print('missing descriptor for point3D {}'.format(k))
+ continue
+ comp_p3d_desc[k] = deepcopy(p3d_desc[k])
+ np.save(osp.join(save_dir, point3d_desc_path.split('/')[-1]), comp_p3d_desc)
+ print('Save data to {:s}'.format(save_dir))
+
+
+def process_dataset(dataset, dataset_dir, sfm_dir, save_dir, feature='sfd2', matcher='gml'):
+ # dataset_dir = '/scratches/flyer_3/fx221/dataset'
+ # sfm_dir = '/scratches/flyer_2/fx221/localization/outputs' # your sfm results (cameras, images, points3D) and features
+ # save_dir = '/scratches/flyer_3/fx221/exp/localizer'
+ # local_feat = 'sfd2'
+ # matcher = 'gml'
+ # hloc_results_dir = '/scratches/flyer_2/fx221/exp/sgd2'
+
+ # config_path = 'configs/datasets/CUED.yaml'
+ # config_path = 'configs/datasets/7Scenes.yaml'
+ # config_path = 'configs/datasets/12Scenes.yaml'
+ # config_path = 'configs/datasets/CambridgeLandmarks.yaml'
+ # config_path = 'configs/datasets/Aachen.yaml'
+
+ # config_path = 'configs/datasets/Aria.yaml'
+ # config_path = 'configs/datasets/DarwinRGB.yaml'
+ # config_path = 'configs/datasets/ACUED.yaml'
+ # config_path = 'configs/datasets/JesusCollege.yaml'
+ # config_path = 'configs/datasets/CUED2Kings.yaml'
+
+ config_path = 'configs/datasets/{:s}.yaml'.format(dataset)
+ with open(config_path, 'rt') as f:
+ configs = yaml.load(f, Loader=yaml.Loader)
+ print(configs)
+
+ dataset = configs['dataset']
+ all_scenes = configs['scenes']
+ for scene in all_scenes:
+ n_cluster = configs[scene]['n_cluster']
+ cluster_mode = configs[scene]['cluster_mode']
+ cluster_method = configs[scene]['cluster_method']
+ # if scene not in ['heads']:
+ # continue
+
+ print('scene: ', scene, cluster_mode, cluster_method)
+ # hloc_path = osp.join(hloc_root, dataset, scene)
+ sfm_path = osp.join(sfm_dir, scene)
+ save_path = osp.join(save_dir, feature + '-' + matcher, dataset, scene)
+
+ n_vrf = 1
+ n_cov = 30
+ radius = 20
+ n_kpts = 0
+
+ if dataset in ['Aachen']:
+ image_path = osp.join(dataset_dir, scene, 'images/images_upright')
+ min_obs = 250
+ filtering_outliers = True
+ threshold = 0.2
+ radius = 32
+
+ elif dataset in ['CambridgeLandmarks', ]:
+ image_path = osp.join(dataset_dir, scene)
+ min_obs = 250
+ filtering_outliers = True
+ threshold = 0.2
+ radius = 64
+ elif dataset in ['Aria']:
+ image_path = osp.join(dataset_dir, scene)
+ min_obs = 150
+ filtering_outliers = False
+ threshold = 0.01
+ radius = 15
+ elif dataset in ['DarwinRGB']:
+ image_path = osp.join(dataset_dir, scene)
+ min_obs = 150
+ filtering_outliers = True
+ threshold = 0.2
+ radius = 16
+ elif dataset in ['ACUED']:
+ image_path = osp.join(dataset_dir, scene)
+ min_obs = 250
+ filtering_outliers = True
+ threshold = 0.2
+ radius = 32
+ elif dataset in ['7Scenes', '12Scenes']:
+ image_path = osp.join(dataset_dir, scene)
+ min_obs = 150
+ filtering_outliers = False
+ threshold = 0.01
+ radius = 15
+ else:
+ image_path = osp.join(dataset_dir, scene)
+ min_obs = 250
+ filtering_outliers = True
+ threshold = 0.2
+ radius = 32
+
+ # comp_map_sub_path = 'comp_model_n{:d}_{:s}_{:s}_vrf{:d}_cov{:d}_r{:d}_np{:d}_projection_v2'.format(n_cluster,
+ # cluster_mode,
+ # cluster_method,
+ # n_vrf,
+ # n_cov,
+ # radius,
+ # n_kpts)
+ comp_map_sub_path = 'compress_model_{:s}'.format(cluster_method)
+ seg_fn = osp.join(save_path,
+ 'point3D_cluster_n{:d}_{:s}_{:s}.npy'.format(n_cluster, cluster_mode, cluster_method))
+ vrf_fn = osp.join(save_path,
+ 'point3D_vrf_n{:d}_{:s}_{:s}.npy'.format(n_cluster, cluster_mode, cluster_method))
+ vrf_img_dir = osp.join(save_path,
+ 'point3D_vrf_n{:d}_{:s}_{:s}'.format(n_cluster, cluster_mode, cluster_method))
+ # p3d_query_fn = osp.join(save_path,
+ # 'point3D_query_n{:d}_{:s}_{:s}.npy'.format(n_cluster, cluster_mode, cluster_method))
+ comp_map_path = osp.join(save_path, comp_map_sub_path)
+
+ os.makedirs(save_path, exist_ok=True)
+
+ rmap = RecMap()
+ rmap.load_sfm_model(path=osp.join(sfm_path, 'sfm_{:s}-{:s}'.format(feature, matcher)))
+ if filtering_outliers:
+ rmap.remove_statics_outlier(nb_neighbors=20, std_ratio=2.0)
+
+ # extract keypoints to train the recognition model (descriptors are recomputed from augmented db images)
+ # we do this for ddp training (reading h5py file is not supported)
+ rmap.export_features_to_directory(feat_fn=osp.join(sfm_path, 'feats-{:s}.h5'.format(feature)),
+ save_dir=osp.join(save_path, 'feats')) # only once for training
+
+ rmap.cluster(k=n_cluster, mode=cluster_mode, save_fn=seg_fn, method=cluster_method, threshold=threshold)
+ # rmap.visualize_3Dpoints()
+ rmap.load_segmentation(path=seg_fn)
+ # rmap.visualize_segmentation(p3d_segs=rmap.p3d_seg, points3D=rmap.points3D)
+
+ # Assign each 3D point a descriptor and discard all 2D images and descriptors - for localization
+ rmap.assign_point3D_descriptor(
+ feature_fn=osp.join(sfm_path, 'feats-{:s}.h5'.format(feature)),
+ save_fn=osp.join(save_path, 'point3D_desc.npy'),
+ n_process=32) # only once
+
+ # exit(0)
+ # rmap.visualize_segmentation_on_image(p3d_segs=rmap.p3d_seg, image_path=image_path, feat_path=feat_path)
+
+ # for query images only - for evaluation
+ # rmap.extract_query_p3ds(
+ # log_fn=osp.join(hloc_path, 'hloc_feats-{:s}_{:s}_loc.npy'.format(local_feat, matcher)),
+ # feat_fn=osp.join(sfm_path, 'feats-{:s}.h5'.format(local_feat)),
+ # save_fn=p3d_query_fn)
+ # continue
+
+ # up-to-date
+ rmap.create_virtual_frame_3(
+ save_fn=vrf_fn,
+ save_vrf_dir=vrf_img_dir,
+ image_root=image_path,
+ show_time=5,
+ min_cover_ratio=0.9,
+ radius=radius,
+ depth_scale=2.5, # 1.2 by default
+ min_obs=min_obs,
+ n_vrf=10,
+ covisible_frame=n_cov,
+ ignored_cameras=[])
+
+ # up-to-date
+ rmap.compress_map_by_projection_v2(
+ vrf_frames=n_vrf,
+ vrf_path=vrf_fn,
+ point3d_desc_path=osp.join(save_path, 'point3D_desc.npy'),
+ save_dir=comp_map_path,
+ covisible_frames=n_cov,
+ radius=radius,
+ nkpts=n_kpts,
+ )
+
+ # exit(0)
+ # soft_link_compress_path = osp.join(save_path, 'compress_model_{:s}'.format(cluster_method))
+ os.chdir(save_path)
+ # if osp.isdir(soft_link_compress_path):
+ # os.unlink(soft_link_compress_path)
+ # os.symlink(comp_map_sub_path, 'compress_model_{:s}'.format(cluster_method))
+ # create a soft link of the full model for training
+ if not osp.isdir('3D-models'):
+ os.symlink(osp.join(sfm_path, 'sfm_{:s}-{:s}'.format(feature, matcher)), '3D-models')
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--dataset', type=str, required=True, help='dataset name')
+ parser.add_argument('--dataset_dir', type=str, required=True, help='dataset dir')
+ parser.add_argument('--sfm_dir', type=str, required=True, help='sfm dir')
+ parser.add_argument('--save_dir', type=str, required=True, help='dir to save the landmarks data')
+ parser.add_argument('--feature', type=str, default='sfd2', help='feature name e.g., SP, SFD2')
+ parser.add_argument('--matcher', type=str, default='gml', help='matcher name e.g., SG, LSG, gml')
+
+ args = parser.parse_args()
+
+ process_dataset(
+ dataset=args.dataset,
+ dataset_dir=args.dataset_dir,
+ sfm_dir=args.sfm_dir,
+ save_dir=args.save_dir,
+ feature=args.feature,
+ matcher=args.matcher)
diff --git a/third_party/pram/recognition/vis_seg.py b/third_party/pram/recognition/vis_seg.py
new file mode 100644
index 0000000000000000000000000000000000000000..2ef9b2365787e5921a66c74ff6c0b5ec3e49a31a
--- /dev/null
+++ b/third_party/pram/recognition/vis_seg.py
@@ -0,0 +1,225 @@
+# -*- coding: UTF-8 -*-
+'''=================================================
+@Project -> File pram -> vis_seg
+@IDE PyCharm
+@Author fx221@cam.ac.uk
+@Date 07/02/2024 11:06
+=================================================='''
+import cv2
+import numpy as np
+from copy import deepcopy
+
+
+def myHash(text: str):
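+ # deterministic string hash (Python's built-in hash() is salted per process), so segment colors stay stable across runs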
+ hash = 0
+ for ch in text:
+ hash = (hash * 7879 ^ ord(ch) * 5737) & 0xFFFFFFFF
+ return hash
+
+
+def generate_color_dic(n_seg=1000):
+ out = {}
+ for i in range(n_seg + 1):
+ sid = i
+ if sid == 0:
+ color = (0, 0, 255) # [b, g, r]
+ else:
+ # rgb_new = hash(str(sid * 319993))
+ rgb_new = myHash(str(sid * 319993))
+ r = (rgb_new & 0xFF0000) >> 16
+ g = (rgb_new & 0x00FF00) >> 8
+ b = rgb_new & 0x0000FF
+ color = (b, g, r)
+ out[i] = color
+ return out
+
+
+def vis_seg_point(img, kpts, segs=None, seg_color=None, radius=7, thickness=-1):
+ outimg = deepcopy(img)
+ for i in range(kpts.shape[0]):
+ # print(kpts[i])
+ if segs is not None and seg_color is not None:
+ color = seg_color[segs[i]]
+ else:
+ color = (0, 255, 0)
+ outimg = cv2.circle(outimg,
+ center=(int(kpts[i, 0]), int(kpts[i, 1])),
+ color=color,
+ radius=radius,
+ thickness=thickness, )
+
+ return outimg
+
+
+def vis_corr_incorr_point(img, kpts, pred_segs, gt_segs, radius=7, thickness=-1):
+ outimg = deepcopy(img)
+ for i in range(kpts.shape[0]):
+ # print(kpts[i])
+ p_seg = pred_segs[i]
+ g_seg = gt_segs[i]
+ if p_seg == g_seg:
+ if g_seg != 0:
+ color = (0, 255, 0)
+ else:
+ color = (255, 0, 0)
+ else:
+ color = (0, 0, 255)
+ outimg = cv2.circle(outimg,
+ center=(int(kpts[i, 0]), int(kpts[i, 1])),
+ color=color,
+ radius=radius,
+ thickness=thickness, )
+ return outimg
+
+
+def vis_inlier(img, kpts, inliers, radius=7, thickness=1, with_outlier=True):
+ outimg = deepcopy(img)
+ for i in range(kpts.shape[0]):
+ if not with_outlier:
+ if not inliers[i]:
+ continue
+ if inliers[i]:
+ color = (0, 255, 0)
+ else:
+ color = (0, 0, 255)
+ outimg = cv2.rectangle(outimg,
+ pt1=(int(kpts[i, 0] - radius), int(kpts[i, 1] - radius)),
+ pt2=(int(kpts[i, 0] + radius), int(kpts[i, 1] + radius)),
+ color=color,
+ thickness=thickness, )
+
+ return outimg
+
+
+def vis_global_seg(cls, seg_color, radius=7, thickness=-1):
+ all_patches = []
+ for i in range(cls.shape[0]):
+ if cls[i] == 0:
+ continue
+ color = seg_color[i]
+ patch = np.zeros(shape=(radius, radius, 3), dtype=np.uint8)
+ patch[..., 0] = color[0]
+ patch[..., 1] = color[1]
+ patch[..., 2] = color[2]
+
+ all_patches.append(patch)
+ if len(all_patches) == 0:
+ color = seg_color[0]
+ patch = np.zeros(shape=(radius, radius, 3), dtype=np.uint8)
+ patch[..., 0] = color[0]
+ patch[..., 1] = color[1]
+ patch[..., 2] = color[2]
+ all_patches.append(patch)
+ return np.vstack(all_patches)
+
+
+def plot_matches(img1, img2, pts1, pts2, inliers, radius=3, line_thickness=2, horizon=True, plot_outlier=False,
+ confs=None):
+ rows1 = img1.shape[0]
+ cols1 = img1.shape[1]
+ rows2 = img2.shape[0]
+ cols2 = img2.shape[1]
+ # r = 3
+ if horizon:
+ img_out = np.zeros((max([rows1, rows2]), cols1 + cols2, 3), dtype='uint8')
+ # Place the first image to the left
+ img_out[:rows1, :cols1] = img1
+ # Place the next image to the right of it
+ img_out[:rows2, cols1:] = img2 # np.dstack([img2, img2, img2])
+ for idx in range(inliers.shape[0]):
+ # if idx % 10 > 0:
+ # continue
+ if inliers[idx]:
+ color = (0, 255, 0)
+ else:
+ if not plot_outlier:
+ continue
+ color = (0, 0, 255)
+ pt1 = pts1[idx]
+ pt2 = pts2[idx]
+
+ if confs is not None:
+ nr = int(radius * confs[idx])
+ else:
+ nr = radius
+ img_out = cv2.circle(img_out, (int(pt1[0]), int(pt1[1])), nr, color, 2)
+
+ img_out = cv2.circle(img_out, (int(pt2[0]) + cols1, int(pt2[1])), nr, color, 2)
+
+ img_out = cv2.line(img_out, (int(pt1[0]), int(pt1[1])), (int(pt2[0]) + cols1, int(pt2[1])), color,
+ line_thickness)
+ else:
+ img_out = np.zeros((rows1 + rows2, max([cols1, cols2]), 3), dtype='uint8')
+ # Place the first image to the left
+ img_out[:rows1, :cols1] = img1
+ # Place the next image to the right of it
+ img_out[rows1:, :cols2] = img2 # np.dstack([img2, img2, img2])
+
+ for idx in range(inliers.shape[0]):
+ # print("idx: ", inliers[idx])
+ # if idx % 10 > 0:
+ # continue
+ if inliers[idx]:
+ color = (0, 255, 0)
+ else:
+ if not plot_outlier:
+ continue
+ color = (0, 0, 255)
+
+ if confs is not None:
+ nr = int(radius * confs[idx])
+ else:
+ nr = radius
+
+ pt1 = pts1[idx]
+ pt2 = pts2[idx]
+ img_out = cv2.circle(img_out, (int(pt1[0]), int(pt1[1])), nr, color, 2)
+
+ img_out = cv2.circle(img_out, (int(pt2[0]), int(pt2[1]) + rows1), nr, color, 2)
+
+ img_out = cv2.line(img_out, (int(pt1[0]), int(pt1[1])), (int(pt2[0]), int(pt2[1]) + rows1), color,
+ line_thickness)
+
+ return img_out
+
+
+def plot_kpts(img, kpts, radius=None, colors=None, r=3, color=(0, 0, 255), nh=-1, nw=-1, shape='o', show_text=None,
+ thickness=5):
+ img_out = deepcopy(img)
+ for i in range(kpts.shape[0]):
+ pt = kpts[i]
+ if radius is not None:
+ if shape == 'o':
+ img_out = cv2.circle(img_out, center=(int(pt[0]), int(pt[1])), radius=radius[i],
+ color=color if colors is None else colors[i],
+ thickness=thickness)
+ elif shape == '+':
+ img_out = cv2.line(img_out, pt1=(int(pt[0] - radius[i]), int(pt[1])),
+ pt2=(int(pt[0] + radius[i]), int(pt[1])),
+ color=color if colors is None else colors[i],
+ thickness=thickness)
+ img_out = cv2.line(img_out, pt1=(int(pt[0]), int(pt[1] - radius[i])),
+ pt2=(int(pt[0]), int(pt[1] + radius[i])),
+ color=color if colors is None else colors[i],
+ thickness=thickness)
+ else:
+ if shape == 'o':
+ img_out = cv2.circle(img_out, center=(int(pt[0]), int(pt[1])), radius=r,
+ color=color if colors is None else colors[i],
+ thickness=thickness)
+ elif shape == '+':
+ img_out = cv2.line(img_out, pt1=(int(pt[0] - r), int(pt[1])),
+ pt2=(int(pt[0] + r), int(pt[1])), color=color if colors is None else colors[i],
+ thickness=thickness)
+ img_out = cv2.line(img_out, pt1=(int(pt[0]), int(pt[1] - r)),
+ pt2=(int(pt[0]), int(pt[1] + r)), color=color if colors is None else colors[i],
+ thickness=thickness)
+
+ if show_text is not None:
+ img_out = cv2.putText(img_out, show_text, (50, 50), cv2.FONT_HERSHEY_SIMPLEX, 2,
+ (0, 0, 255), 3)
+ if nh == -1 and nw == -1:
+ return img_out
+ if nh > 0:
+ return cv2.resize(img_out, dsize=(int(img.shape[1] / img.shape[0] * nh), nh))
+ if nw > 0:
+ return cv2.resize(img_out, dsize=(nw, int(img.shape[0] / img.shape[1] * nw)))
diff --git a/third_party/pram/sfm_scripts/reconstruct_12scenes.sh b/third_party/pram/sfm_scripts/reconstruct_12scenes.sh
new file mode 100644
index 0000000000000000000000000000000000000000..4f79e356a73f897f9e5a3db5cdf4cbf4b689275c
--- /dev/null
+++ b/third_party/pram/sfm_scripts/reconstruct_12scenes.sh
@@ -0,0 +1,91 @@
+#!/bin/bash
+# you need to use your own path
+
+dataset_dir=/scratches/flyer_3/fx221/dataset/12Scenes
+ref_sfm_dir=/scratches/flyer_2/fx221/publications/pram_data/3D-models/12Scenes
+output_dir=/scratches/flyer_2/fx221/localization/outputs/12Scenes
+
+feat=sfd2
+matcher=gm
+
+#feat=superpoint-n4096
+#matcher=superglue
+
+extract_feat_db=1
+match_db=1
+triangulation=1
+localize=1
+
+ransac_thresh=8
+opt_thresh=8
+covisibility_frame=20
+inlier_thresh=30
+obs_thresh=3
+
+
+#for scene in apt1 apt2 office1 office2
+for scene in apt2 office1 office2
+do
+ echo $scene
+
+ if [ "$scene" = "apt1" ]; then
+ all_subscenes='kitchen living'
+ elif [ "$scene" = "apt2" ]; then
+ all_subscenes='bed kitchen living luke'
+ elif [ "$scene" = "office1" ]; then
+ all_subscenes='gates362 gates381 lounge manolis'
+ elif [ "$scene" = "office2" ]; then
+ all_subscenes='5a 5b'
+ fi
+
+ for subscene in $all_subscenes
+ do
+ echo $subscene
+
+ image_dir=$dataset_dir/$scene/$subscene
+ ref_sfm=$ref_sfm_dir/$scene/$subscene/3D-models
+ db_pair=$ref_sfm_dir/$scene/$subscene/pairs-db-covis20.txt
+ outputs=$output_dir/$scene/$subscene
+ query_pair=$ref_sfm_dir/$scene/$subscene/pairs-query-netvlad20.txt
+ gt_pose_fn=$ref_sfm_dir/$scene/$subscene/queries_poses.txt
+ query_fn=$ref_sfm_dir/$scene/$subscene/queries_with_intrinsics.txt
+
+ if [ "$extract_feat_db" -gt "0" ]; then
+ python3 -m loc.extract_features --image_dir $image_dir --export_dir $outputs/ --conf $feat
+ fi
+
+ if [ "$match_db" -gt "0" ]; then
+ python3 -m loc.match_features --pairs $db_pair --export_dir $outputs/ --conf $matcher --features feats-$feat
+ fi
+
+ if [ "$triangulation" -gt "0" ]; then
+ python3 -m loc.triangulation \
+ --sfm_dir $outputs/sfm_$feat-$matcher \
+ --reference_sfm_model $ref_sfm \
+ --image_dir $image_dir \
+ --pairs $db_pair \
+ --features $outputs/feats-$feat.h5 \
+ --matches $outputs/feats-$feat-$matcher-pairs-db-covis20.h5
+ fi
+
+ if [ "$localize" -gt "0" ]; then
+ python3 -m loc.localizer \
+ --dataset 12Scenes \
+ --image_dir $image_dir \
+ --save_root $outputs \
+ --gt_pose_fn $gt_pose_fn \
+ --retrieval $query_pair \
+ --reference_sfm $outputs/sfm_$feat-$matcher \
+ --queries $query_fn \
+ --features $outputs/feats-$feat.h5 \
+ --matcher_method $matcher \
+ --ransac_thresh $ransac_thresh \
+ --covisibility_frame $covisibility_frame \
+ --obs_thresh $obs_thresh \
+ --opt_thresh $opt_thresh \
+ --inlier_thresh $inlier_thresh \
+ --use_hloc
+ fi
+ done
+
+done
diff --git a/third_party/pram/sfm_scripts/reconstruct_7scenes.sh b/third_party/pram/sfm_scripts/reconstruct_7scenes.sh
new file mode 100644
index 0000000000000000000000000000000000000000..91fb16dabc2a294476c0865fc4a5e12e2b4cf0b7
--- /dev/null
+++ b/third_party/pram/sfm_scripts/reconstruct_7scenes.sh
@@ -0,0 +1,75 @@
+#!/bin/bash
+
+# you need to use your own path
+dataset_dir=/scratches/flyer_3/fx221/dataset/7Scenes
+ref_sfm_dir=/scratches/flyer_2/fx221/publications/pram_data/3D-models/7Scenes
+output_dir=/scratches/flyer_2/fx221/publications/test_pram/7Scenes
+
+# keypoints and matcher used for sfm
+feat=sfd2
+matcher=gml
+
+
+extract_feat_db=1
+match_db=1
+triangulation=1
+localize=0
+
+
+ransac_thresh=12
+opt_thresh=12
+covisibility_frame=20
+inlier_thresh=30
+obs_thresh=3
+
+
+for scene in heads fire office stairs pumpkin redkitchen chess
+#for scene in fire office pumpkin redkitchen chess
+#for scene in chess
+do
+ echo $scene
+ image_dir=$dataset_dir/$scene
+ ref_sfm=$ref_sfm_dir/$scene/3D-models
+ db_pair=$ref_sfm_dir/$scene/pairs-db-covis20.txt
+ outputs=$output_dir/$scene
+ query_pair=$ref_sfm_dir/$scene/pairs-query-netvlad20.txt
+ gt_pose_fn=$ref_sfm_dir/$scene/queries_poses.txt
+ query_fn=$ref_sfm_dir/$scene/queries_with_intrinsics.txt
+
+ if [ "$extract_feat_db" -gt "0" ]; then
+ python3 -m localization.extract_features --image_dir $image_dir --export_dir $outputs/ --conf $feat
+ fi
+
+ if [ "$match_db" -gt "0" ]; then
+ python3 -m localization.match_features --pairs $db_pair --export_dir $outputs/ --conf $matcher --features feats-$feat
+ fi
+
+ if [ "$triangulation" -gt "0" ]; then
+ python3 -m localization.triangulation \
+ --sfm_dir $outputs/sfm_$feat-$matcher \
+ --reference_sfm_model $ref_sfm \
+ --image_dir $image_dir \
+ --pairs $db_pair \
+ --features $outputs/feats-$feat.h5 \
+ --matches $outputs/feats-$feat-$matcher-pairs-db-covis20.h5
+ fi
+
+ if [ "$localize" -gt "0" ]; then
+ python3 -m localization.localizer \
+ --dataset 7Scenes \
+ --image_dir $image_dir \
+ --save_root $outputs \
+ --gt_pose_fn $gt_pose_fn \
+ --retrieval $query_pair \
+ --reference_sfm $outputs/sfm_$feat-$matcher \
+ --queries $query_fn \
+ --features $outputs/feats-$feat.h5 \
+ --matcher_method $matcher \
+ --ransac_thresh $ransac_thresh \
+ --covisibility_frame $covisibility_frame \
+ --obs_thresh $obs_thresh \
+ --opt_thresh $opt_thresh \
+ --inlier_thresh $inlier_thresh \
+ --use_hloc
+ fi
+done
\ No newline at end of file
diff --git a/third_party/pram/sfm_scripts/reconstruct_aachen.sh b/third_party/pram/sfm_scripts/reconstruct_aachen.sh
new file mode 100644
index 0000000000000000000000000000000000000000..510485e521511f1948060c5d0de5f56984586c8d
--- /dev/null
+++ b/third_party/pram/sfm_scripts/reconstruct_aachen.sh
@@ -0,0 +1,69 @@
+#!/bin/bash
+# you need to use your own path
+dataset_dir=/scratches/flyer_3/fx221/dataset/Aachen/Aachenv11
+ref_sfm_dir=/scratches/flyer_2/fx221/publications/pram_data/3D-models/Aachen/Aachenv11
+output_dir=/scratches/flyer_2/fx221/localization/outputs/Aachen/Aachenv11
+
+# fixed
+outputs=$output_dir
+image_dir=$dataset_dir/images/images_upright
+ref_sfm=$ref_sfm_dir/3D-models
+db_pair=$ref_sfm_dir/pairs-db-covis20.txt
+query_pair=$ref_sfm_dir/pairs-query-netvlad50.txt
+gt_pose_fn=$ref_sfm_dir/queries_pose_spp_spg.txt
+query_fn=$ref_sfm_dir/queries_with_intrinsics.txt
+
+
+
+feat=sfd2
+matcher=gm
+
+#feat=superpoint-n4096
+#matcher=superglue
+
+extract_feat_db=1
+match_db=1
+triangulation=1
+localize=1
+
+if [ "$extract_feat_db" -gt "0" ]; then
+ python3 -m loc.extract_features --image_dir $image_dir --export_dir $outputs/ --conf $feat
+fi
+
+if [ "$match_db" -gt "0" ]; then
+ python3 -m loc.match_features --pairs $ref_sfm_dir/pairs-db-covis20.txt --export_dir $outputs/ --conf $matcher --features feats-$feat
+fi
+
+if [ "$triangulation" -gt "0" ]; then
+ python3 -m loc.triangulation \
+ --sfm_dir $outputs/sfm_$feat-$matcher \
+ --reference_sfm_model $ref_sfm \
+ --image_dir $image_dir \
+ --pairs $db_pair \
+ --features $outputs/feats-$feat.h5 \
+ --matches $outputs/feats-$feat-$matcher-pairs-db-covis20.h5
+fi
+
+ransac_thresh=15
+opt_thresh=15
+covisibility_frame=30
+inlier_thresh=80
+obs_thresh=3
+
+if [ "$localize" -gt "0" ]; then
+ python3 -m loc.localizer \
+ --dataset aachen_v1.1 \
+ --image_dir $image_dir \
+ --save_root $outputs \
+ --gt_pose_fn $gt_pose_fn \
+ --retrieval $query_pair \
+ --reference_sfm $outputs/sfm_$feat-$matcher \
+ --queries $query_fn \
+ --features $outputs/feats-$feat.h5 \
+ --matcher_method $matcher \
+ --ransac_thresh $ransac_thresh \
+ --covisibility_frame $covisibility_frame \
+ --obs_thresh $obs_thresh \
+ --opt_thresh $opt_thresh \
+ --inlier_thresh $inlier_thresh \
+ --use_hloc
+fi
\ No newline at end of file
diff --git a/third_party/pram/sfm_scripts/reconstruct_cambridge.sh b/third_party/pram/sfm_scripts/reconstruct_cambridge.sh
new file mode 100644
index 0000000000000000000000000000000000000000..f1ee967cf94e16e4a2f1848436d236df9a273858
--- /dev/null
+++ b/third_party/pram/sfm_scripts/reconstruct_cambridge.sh
@@ -0,0 +1,76 @@
+#!/bin/bash
+
+# you need to use your own path
+dataset_dir=/scratches/flyer_3/fx221/dataset/CambridgeLandmarks
+ref_sfm_dir=/scratches/flyer_2/fx221/publications/pram_data/3D-models/CambridgeLandmarks
+output_dir=/scratches/flyer_2/fx221/localization/outputs/CambridgeLandmarks
+
+
+feat=sfd2
+matcher=gm
+
+extract_feat_db=0
+match_db=0
+triangulation=0
+localize=1
+
+ransac_thresh=12
+opt_thresh=12
+covisibility_frame=20
+inlier_thresh=30
+radius=30
+obs_thresh=3
+
+
+#for scene in GreatCourt ShopFacade KingsCollege OldHospital StMarysChurch
+for scene in StMarysChurch
+#for scene in GreatCourt ShopFacade
+do
+ echo $scene
+
+ image_dir=$dataset_dir/$scene
+ ref_sfm=$ref_sfm_dir/$scene/3D-models
+ db_pair=$ref_sfm_dir/$scene/pairs-db-covis20.txt
+ outputs=$output_dir/$scene
+ query_pair=$ref_sfm_dir/$scene/pairs-query-netvlad20.txt
+ gt_pose_fn=$ref_sfm_dir/$scene/queries_poses.txt
+ query_fn=$ref_sfm_dir/$scene/queries_with_intrinsics.txt
+
+ if [ "$extract_feat_db" -gt "0" ]; then
+ python3 -m loc.extract_features --image_dir $image_dir --export_dir $outputs/ --conf $feat
+ fi
+
+ if [ "$match_db" -gt "0" ]; then
+ python3 -m loc.match_features --pairs $db_pair --export_dir $outputs/ --conf $matcher --features feats-$feat
+ fi
+
+ if [ "$triangulation" -gt "0" ]; then
+ python3 -m loc.triangulation \
+ --sfm_dir $outputs/sfm_$feat-$matcher \
+ --reference_sfm_model $ref_sfm \
+ --image_dir $image_dir\
+ --pairs $db_pair \
+ --features $outputs/feats-$feat.h5 \
+ --matches $outputs/feats-$feat-$matcher-pairs-db-covis20.h5
+ fi
+
+ if [ "$localize" -gt "0" ]; then
+ python3 -m loc.localizer \
+ --dataset cambridge \
+ --image_dir $image_dir \
+ --save_root $outputs\
+ --gt_pose_fn $gt_pose_fn \
+ --retrieval $query_pair \
+ --reference_sfm $outputs/sfm_$feat-$matcher \
+ --queries $query_fn \
+ --features $outputs/feats-$feat.h5 \
+ --matcher_method adagm2 \
+ --ransac_thresh $ransac_thresh \
+ --covisibility_frame $covisibility_frame \
+ --obs_thresh $obs_thresh \
+ --opt_thresh $opt_thresh \
+ --inlier_thresh $inlier_thresh \
+ --use_hloc
+ fi
+
+done
\ No newline at end of file
diff --git a/third_party/pram/tools/common.py b/third_party/pram/tools/common.py
new file mode 100644
index 0000000000000000000000000000000000000000..8990012575324ed593ebc07bec88d47602005d5f
--- /dev/null
+++ b/third_party/pram/tools/common.py
@@ -0,0 +1,125 @@
+# -*- coding: UTF-8 -*-
+'''=================================================
+@Project -> File pram -> common
+@IDE PyCharm
+@Author fx221@cam.ac.uk
+@Date 29/01/2024 15:05
+=================================================='''
+import os
+import torch
+import json
+import yaml
+import cv2
+import numpy as np
+from typing import Tuple
+from copy import deepcopy
+
+
+def load_args(args, save_path):
+ with open(save_path, "r") as f:
+ args.__dict__ = json.load(f)
+
+
+def save_args_yaml(args, save_path):
+ with open(save_path, 'w') as f:
+ yaml.dump(args, f)
+
+
+def merge_tags(tags: list, connection='_'):
+ out = ''
+ for i, t in enumerate(tags):
+ if i == 0:
+ out = out + t
+ else:
+ out = out + connection + t
+ return out
+
+
+def torch_set_gpu(gpus):
+ if type(gpus) is int:
+ gpus = [gpus]
+
+ cuda = all(gpu >= 0 for gpu in gpus)
+
+ if cuda:
+ os.environ['CUDA_VISIBLE_DEVICES'] = ','.join([str(gpu) for gpu in gpus])
+ # print(os.environ['CUDA_VISIBLE_DEVICES'])
+ assert cuda and torch.cuda.is_available(), "%s has GPUs %s unavailable" % (
+ os.environ['HOSTNAME'], os.environ['CUDA_VISIBLE_DEVICES'])
+ torch.backends.cudnn.benchmark = True # speed-up cudnn
+ torch.backends.cudnn.fastest = True # even more speed-up?
+ print('Launching on GPUs ' + os.environ['CUDA_VISIBLE_DEVICES'])
+
+ else:
+ print('Launching on CPU')
+
+ return cuda
+
+
+def resize_img(img, nh=-1, nw=-1, rmax=-1, mode=cv2.INTER_NEAREST):
+ assert nh > 0 or nw > 0 or rmax > 0
+ if nh > 0:
+ return cv2.resize(img, dsize=(int(img.shape[1] / img.shape[0] * nh), nh), interpolation=mode)
+ if nw > 0:
+ return cv2.resize(img, dsize=(nw, int(img.shape[0] / img.shape[1] * nw)), interpolation=mode)
+ if rmax > 0:
+ oh, ow = img.shape[0], img.shape[1]
+ if oh > ow:
+ return cv2.resize(img, dsize=(int(img.shape[1] / img.shape[0] * rmax), rmax), interpolation=mode)
+ else:
+ return cv2.resize(img, dsize=(rmax, int(img.shape[0] / img.shape[1] * rmax)), interpolation=mode)
+
+ return cv2.resize(img, dsize=(nw, nh), interpolation=mode)
+
+
+def resize_image_with_padding(image: np.array, nw: int, nh: int, padding_color: Tuple[int] = (0, 0, 0)) -> np.array:
+ """Maintains aspect ratio and resizes with padding.
+ Params:
+ image: Image to be resized.
+ nw, nh: Expected width and height of the new image.
+ padding_color: Tuple in BGR of padding color
+ Returns:
+ image: Resized image with padding
+ """
+ original_shape = (image.shape[1], image.shape[0]) # (w, h)
+ ratio_w = nw / original_shape[0]
+ ratio_h = nh / original_shape[1]
+
+ if ratio_w == ratio_h:
+ image = cv2.resize(image, (nw, nh), interpolation=cv2.INTER_NEAREST)
+
+ ratio = ratio_w if ratio_w < ratio_h else ratio_h
+
+ new_size = tuple([int(x * ratio) for x in original_shape])
+ image = cv2.resize(image, new_size, interpolation=cv2.INTER_NEAREST)
+ delta_w = nw - new_size[0] if nw > new_size[0] else new_size[0] - nw
+ delta_h = nh - new_size[1] if nh > new_size[1] else new_size[1] - nh
+
+ left, right = delta_w // 2, delta_w - (delta_w // 2)
+ top, bottom = delta_h // 2, delta_h - (delta_h // 2)
+
+ # print('top, bottom, left, right: ', top, bottom, left, right)
+ image = cv2.copyMakeBorder(image, top, bottom, left, right, cv2.BORDER_CONSTANT, value=padding_color)
+ return image
+
+
+def puttext_with_background(image, text, org=(0, 0), fontFace=cv2.FONT_HERSHEY_SIMPLEX,
+ fontScale=1, text_color=(0, 0, 255),
+ thickness=2, lineType=cv2.LINE_AA, bg_color=None):
+ out_img = deepcopy(image)
+ if bg_color is not None:
+ (text_width, text_height), baseline = cv2.getTextSize(text,
+ fontFace,
+ fontScale=fontScale,
+ thickness=thickness)
+ box_coords = (
+ (org[0], org[1] + baseline),
+ (org[0] + text_width + 2, org[1] - text_height - 2))
+
+ cv2.rectangle(out_img, box_coords[0], box_coords[1], bg_color, cv2.FILLED)
+ out_img = cv2.putText(img=out_img, text=text,
+ org=org,
+ fontFace=fontFace,
+ fontScale=fontScale, color=text_color,
+ thickness=thickness, lineType=lineType)
+ return out_img
diff --git a/third_party/pram/tools/geometry.py b/third_party/pram/tools/geometry.py
new file mode 100644
index 0000000000000000000000000000000000000000..d781a4172dd7f6ad8a4a26e252f614483ebd01e3
--- /dev/null
+++ b/third_party/pram/tools/geometry.py
@@ -0,0 +1,74 @@
+# -*- coding: UTF-8 -*-
+'''=================================================
+@Project -> File pram -> geometry
+@IDE PyCharm
+@Author fx221@cam.ac.uk
+@Date 07/02/2024 11:08
+=================================================='''
+import numpy as np
+
+
+def nms_fast(in_corners, H, W, dist_thresh):
+ """
+ Run a faster approximate Non-Max-Suppression on numpy corners shaped:
+ 3xN [x_i,y_i,conf_i]^T
+
+ Algo summary: Create a grid sized HxW. Assign each corner location a 1, rest
+ are zeros. Iterate through all the 1's and convert them either to -1 or 0.
+ Suppress points by setting nearby values to 0.
+
+ Grid Value Legend:
+ -1 : Kept.
+ 0 : Empty or suppressed.
+ 1 : To be processed (converted to either kept or suppressed).
+
+ NOTE: The NMS first rounds points to integers, so NMS distance might not
+ be exactly dist_thresh. It also assumes points are within image boundaries.
+
+ Inputs
+ in_corners - 3xN numpy array with corners [x_i, y_i, confidence_i]^T.
+ H - Image height.
+ W - Image width.
+ dist_thresh - Distance to suppress, measured as an infinity norm distance.
+ Returns
+ nmsed_corners - 3xN numpy matrix with surviving corners.
+ nmsed_inds - N length numpy vector with surviving corner indices.
+ """
+ grid = np.zeros((H, W)).astype(int) # Track NMS data.
+ inds = np.zeros((H, W)).astype(int) # Store indices of points.
+ # Sort by confidence and round to nearest int.
+ inds1 = np.argsort(-in_corners[2, :])
+ corners = in_corners[:, inds1]
+ rcorners = corners[:2, :].round().astype(int) # Rounded corners.
+ # Check for edge case of 0 or 1 corners.
+ if rcorners.shape[1] == 0:
+ return np.zeros((3, 0)).astype(int), np.zeros(0).astype(int)
+ if rcorners.shape[1] == 1:
+ out = np.vstack((rcorners, in_corners[2])).reshape(3, 1)
+ return out, np.zeros((1)).astype(int)
+ # Initialize the grid.
+ for i, rc in enumerate(rcorners.T):
+ grid[rcorners[1, i], rcorners[0, i]] = 1
+ inds[rcorners[1, i], rcorners[0, i]] = i
+ # Pad the border of the grid, so that we can NMS points near the border.
+ pad = dist_thresh
+ grid = np.pad(grid, ((pad, pad), (pad, pad)), mode='constant')
+ # Iterate through points, highest to lowest conf, suppress neighborhood.
+ count = 0
+ for i, rc in enumerate(rcorners.T):
+ # Account for top and left padding.
+ pt = (rc[0] + pad, rc[1] + pad)
+ if grid[pt[1], pt[0]] == 1: # If not yet suppressed.
+ grid[pt[1] - pad:pt[1] + pad + 1, pt[0] - pad:pt[0] + pad + 1] = 0
+ grid[pt[1], pt[0]] = -1
+ count += 1
+ # Get all surviving -1's and return sorted array of remaining corners.
+ keepy, keepx = np.where(grid == -1)
+ keepy, keepx = keepy - pad, keepx - pad
+ inds_keep = inds[keepy, keepx]
+ out = corners[:, inds_keep]
+ values = out[-1, :]
+ inds2 = np.argsort(-values)
+ out = out[:, inds2]
+ out_inds = inds1[inds_keep[inds2]]
+ return out, out_inds
diff --git a/third_party/pram/tools/image_to_video.py b/third_party/pram/tools/image_to_video.py
new file mode 100644
index 0000000000000000000000000000000000000000..b8f281fd2cf0ef5eb2752117610c042b8764f5f1
--- /dev/null
+++ b/third_party/pram/tools/image_to_video.py
@@ -0,0 +1,66 @@
+# -*- coding: UTF-8 -*-
+'''=================================================
+@Project -> File localizer -> image_to_video
+@IDE PyCharm
+@Author fx221@cam.ac.uk
+@Date 07/09/2023 20:15
+=================================================='''
+import cv2
+import os
+import os.path as osp
+
+import numpy as np
+from tqdm import tqdm
+import argparse
+
+from tools.common import resize_img
+
+parser = argparse.ArgumentParser(description='Image2Video', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+parser.add_argument('--image_dir', type=str, required=True)
+parser.add_argument('--video_path', type=str, required=True)
+parser.add_argument('--height', type=int, default=-1)
+parser.add_argument('--fps', type=int, default=30)
+
+
+def imgs2video(img_dir, video_path, fps=30, height=1024):
+ img_fns = os.listdir(img_dir)
+ # print(img_fns)
+ img_fns = [v for v in img_fns if v.split('.')[-1] in ['jpg', 'png']]
+ img_fns = sorted(img_fns)
+ # print(img_fns)
+    # output video path
+ # fps = 1
+
+ img = cv2.imread(osp.join(img_dir, img_fns[0]))
+    if height == -1:
+        height = img.shape[0]  # default to the original image height
+ new_img = resize_img(img=img, nh=height)
+ img_size = (new_img.shape[1], height)
+
+ # fourcc = cv2.cv.CV_FOURCC('M','J','P','G')#opencv2.4
+ # fourcc = cv2.VideoWriter_fourcc('I','4','2','0')
+
+    fourcc = cv2.VideoWriter_fourcc(*'MP4V')  # write the output video as mp4
+    # fourcc = cv2.VideoWriter_fourcc('M', 'P', '4', 'V')  # write the output video as mp4
+ videoWriter = cv2.VideoWriter(video_path, fourcc, fps, img_size)
+
+    for i in tqdm(range(len(img_fns)), total=len(img_fns)):
+ # fn = img_fns[i].split('-')
+ im_name = os.path.join(img_dir, img_fns[i])
+ print(im_name)
+ frame = cv2.imread(im_name, 1)
+ frame = np.flip(frame, 0)
+
+ frame = cv2.resize(frame, dsize=img_size)
+ # print(frame.shape)
+ # exit(0)
+ cv2.imshow("frame", frame)
+ cv2.waitKey(1)
+ videoWriter.write(frame)
+
+ videoWriter.release()
+
+
+if __name__ == '__main__':
+ args = parser.parse_args()
+ imgs2video(img_dir=args.image_dir, video_path=args.video_path, fps=args.fps, height=args.height)
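+
+# Example invocation (hypothetical paths; assumes the script is run from the
+# pram repository root so that `tools.common` is importable):
+#   python -m tools.image_to_video --image_dir ./frames --video_path ./out.mp4 --fps 30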
diff --git a/third_party/pram/tools/metrics.py b/third_party/pram/tools/metrics.py
new file mode 100644
index 0000000000000000000000000000000000000000..22e14374931fa9ba4151632b65b41c65d6ba55f7
--- /dev/null
+++ b/third_party/pram/tools/metrics.py
@@ -0,0 +1,216 @@
+# -*- coding: UTF-8 -*-
+'''=================================================
+@Project -> File pram -> metrics
+@IDE PyCharm
+@Author fx221@cam.ac.uk
+@Date 29/01/2024 16:32
+=================================================='''
+import torch
+import numpy as np
+import torch.nn.functional as F
+
+
+class SeqIOU:
+ def __init__(self, n_class, ignored_sids=[]):
+ self.n_class = n_class
+ self.ignored_sids = ignored_sids
+ self.class_iou = np.zeros(n_class)
+ self.precisions = []
+
+ def add(self, pred, target):
+ for i in range(self.n_class):
+ inter = np.sum((pred == target) * (target == i))
+ union = np.sum(target == i) + np.sum(pred == i) - inter
+ if union > 0:
+ self.class_iou[i] = inter / union
+
+ acc = (pred == target)
+ if len(self.ignored_sids) == 0:
+ acc_ratio = np.sum(acc) / pred.shape[0]
+ else:
+ pred_mask = (pred >= 0)
+ target_mask = (target >= 0)
+ for i in self.ignored_sids:
+ pred_mask = pred_mask & (pred == i)
+ target_mask = target_mask & (target == i)
+
+ acc = acc & (1 - pred_mask)
+ tgt = (1 - target_mask)
+ if np.sum(tgt) == 0:
+ acc_ratio = 0
+ else:
+ acc_ratio = np.sum(acc) / np.sum(tgt)
+
+ self.precisions.append(acc_ratio)
+
+ def get_mean_iou(self):
+ return np.mean(self.class_iou)
+
+ def get_mean_precision(self):
+ return np.mean(self.precisions)
+
+ def clear(self):
+ self.precisions = []
+ self.class_iou = np.zeros(self.n_class)
+
+
+def compute_iou(pred: np.ndarray, target: np.ndarray, n_class: int, ignored_ids=[]) -> float:
+ class_iou = np.zeros(n_class)
+ for i in range(n_class):
+ if i in ignored_ids:
+ continue
+ inter = np.sum((pred == target) * (target == i))
+ union = np.sum(target == i) + np.sum(pred == i) - inter
+ if union > 0:
+ class_iou[i] = inter / union
+
+ return np.mean(class_iou)
+ # return class_iou
+
+
+def compute_precision(pred: np.ndarray, target: np.ndarray, ignored_ids: list = []) -> float:
+ acc = (pred == target)
+ if len(ignored_ids) == 0:
+ return np.sum(acc) / pred.shape[0]
+ else:
+ pred_mask = (pred >= 0)
+ target_mask = (target >= 0)
+ for i in ignored_ids:
+ pred_mask = pred_mask & (pred == i)
+ target_mask = target_mask & (target == i)
+
+ acc = acc & (1 - pred_mask)
+ tgt = (1 - target_mask)
+ if np.sum(tgt) == 0:
+ return 0
+ return np.sum(acc) / np.sum(tgt)
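+
+# Illustrative (commented) usage of the two metrics above; the label arrays are
+# made up for this sketch and are not part of the original pram code:
+#   pred = np.array([0, 1, 1, 2, 2])
+#   target = np.array([0, 1, 2, 2, 2])
+#   miou = compute_iou(pred, target, n_class=3, ignored_ids=[0])
+#   prec = compute_precision(pred, target)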
+
+
+def compute_cls_corr(pred: torch.Tensor, target: torch.Tensor, k: int = 20) -> torch.Tensor:
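+    '''
+    Mean top-k overlap between predicted and target class scores.
+    (Docstring added for clarity; shapes inferred from how the function is used in pram.)
+    :param pred: [B, n_class] predicted scores
+    :param target: [B, n_class] target scores
+    :param k: number of top-ranked classes to compare
+    :return: a 1-element tensor holding the mean overlap ratio over the batch
+    '''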
+ bs = pred.shape[0]
+ _, target_ids = torch.topk(target, k=k, dim=1)
+ target_ids = target_ids.cpu().numpy()
+ _, top_ids = torch.topk(pred, k=k, dim=1) # [B, k, 1]
+ top_ids = top_ids.cpu().numpy()
+ acc = 0
+ for i in range(bs):
+ # print('top_ids: ', i, top_ids[i], target_ids[i])
+ overlap = [v for v in top_ids[i] if v in target_ids[i] and v >= 0]
+ acc = acc + len(overlap) / k
+ acc = acc / bs
+ return torch.from_numpy(np.array([acc])).to(pred.device)
+
+
+def compute_corr_incorr(pred: torch.Tensor, target: torch.Tensor, ignored_ids: list = []) -> tuple:
+ '''
+ :param pred: [B, N, C]
+ :param target: [B, N]
+ :param ignored_ids: []
+ :return:
+ '''
+ pred_ids = torch.max(pred, dim=-1)[1]
+ if len(ignored_ids) == 0:
+ acc = (pred_ids == target)
+ inacc = torch.logical_not(acc)
+ acc_ratio = torch.sum(acc) / torch.numel(target)
+ inacc_ratio = torch.sum(inacc) / torch.numel(target)
+ else:
+ acc = (pred_ids == target)
+ inacc = torch.logical_not(acc)
+
+        # mark points whose ground-truth label is one of the ignored ids
+        mask = torch.zeros_like(acc)
+        for i in ignored_ids:
+            mask = torch.logical_or(mask, (target == i))
+
+ acc = torch.logical_and(acc, torch.logical_not(mask))
+ acc_ratio = torch.sum(acc) / torch.numel(target)
+ inacc_ratio = torch.sum(inacc) / torch.numel(target)
+
+ return acc_ratio, inacc_ratio
+
+
+def compute_seg_loss_weight(pred: torch.Tensor,
+ target: torch.Tensor,
+ background_id: int = 0,
+ weight_background: float = 0.1) -> torch.Tensor:
+ '''
+    :param pred: [B, N, C]
+ :param target: [B, N]
+ :param background_id:
+ :param weight_background:
+ :return:
+ '''
+ pred = pred.transpose(-2, -1).contiguous() # [B, N, C] -> [B, C, N]
+ weight = torch.ones(size=(pred.shape[1],), device=pred.device).float()
+ pred = torch.log_softmax(pred, dim=1)
+ weight[background_id] = weight_background
+ seg_loss = F.cross_entropy(pred, target.long(), weight=weight)
+ return seg_loss
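+
+# Shape sketch (commented; not part of the original pram code). With B images,
+# N keypoints and C segment classes, a call would look like:
+#   logits = torch.randn(2, 100, 8)          # [B, N, C]
+#   labels = torch.randint(0, 8, (2, 100))   # [B, N]
+#   loss = compute_seg_loss_weight(logits, labels, background_id=0, weight_background=0.1)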
+
+
+def compute_cls_loss_ce(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
+ cls_loss = torch.zeros(size=[], device=pred.device)
+ if len(pred.shape) == 2:
+ n_valid = torch.sum(target > 0)
+ cls_loss = cls_loss + torch.nn.functional.cross_entropy(pred, target, reduction='sum')
+ cls_loss = cls_loss / n_valid
+ else:
+ for i in range(pred.shape[-1]):
+ cls_loss = cls_loss + torch.nn.functional.cross_entropy(pred[..., i], target[..., i], reduction='sum')
+ n_valid = torch.sum(target > 0)
+ cls_loss = cls_loss / n_valid
+
+ return cls_loss
+
+
+def compute_cls_loss_kl(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
+ cls_loss = torch.zeros(size=[], device=pred.device)
+ if len(pred.shape) == 2:
+ cls_loss = cls_loss + torch.nn.functional.kl_div(torch.log_softmax(pred, dim=-1),
+ torch.softmax(target, dim=-1),
+ reduction='sum')
+ else:
+ for i in range(pred.shape[-1]):
+ cls_loss = cls_loss + torch.nn.functional.kl_div(torch.log_softmax(pred[..., i], dim=-1),
+ torch.softmax(target[..., i], dim=-1),
+ reduction='sum')
+
+ cls_loss = cls_loss / pred.shape[-1]
+
+ return cls_loss
+
+
+def compute_sc_loss_l1(pred: torch.Tensor, target: torch.Tensor, mean_xyz=None, scale_xyz=None, mask=None):
+ '''
+ :param pred: [B, N, C]
+ :param target: [B, N, C]
+ :param mean_xyz:
+ :param scale_xyz:
+ :param mask:
+ :return:
+ '''
+ loss = (pred - target)
+ loss = torch.abs(loss).mean(dim=1)
+ if mask is not None:
+ return torch.mean(loss[mask])
+ else:
+ return torch.mean(loss)
+
+
+def compute_sc_loss_geo(pred: torch.Tensor, P, K, p2ds, mean_xyz, scale_xyz, max_value=20, mask=None):
+ b, c, n = pred.shape
+ p3ds = (pred * scale_xyz[..., None].repeat(1, 1, n) + mean_xyz[..., None].repeat(1, 1, n))
+ p3ds_homo = torch.cat(
+        [p3ds, torch.ones(size=(p3ds.shape[0], 1, p3ds.shape[2]), dtype=p3ds.dtype, device=p3ds.device)],
+ dim=1) # [B, 4, N]
+ p3ds = torch.matmul(K, torch.matmul(P, p3ds_homo)[:, :3, :]) # [B, 3, N]
+ # print('p3ds: ', p3ds.shape, P.shape, K.shape, p2ds.shape)
+
+ p2ds_ = p3ds[:, :2, :] / p3ds[:, 2:, :]
+
+ loss = ((p2ds_ - p2ds.permute(0, 2, 1)) ** 2).sum(1)
+ loss = torch.clamp_max(loss, max=max_value)
+ if mask is not None:
+ return torch.mean(loss[mask])
+ else:
+ return torch.mean(loss)
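+
+# Apparent tensor shapes for compute_sc_loss_geo, inferred from the code above and
+# stated as an assumption rather than a guarantee: pred [B, 3, N] scene coordinates,
+# P [B, 3, 4] (or [B, 4, 4]) camera pose, K [B, 3, 3] intrinsics, p2ds [B, N, 2]
+# keypoints, mean_xyz / scale_xyz [B, 3].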
diff --git a/third_party/pram/tools/video_to_image.py b/third_party/pram/tools/video_to_image.py
new file mode 100644
index 0000000000000000000000000000000000000000..7283f3ba24d432410ea326a7d9aedbe011b60ed2
--- /dev/null
+++ b/third_party/pram/tools/video_to_image.py
@@ -0,0 +1,38 @@
+# -*- coding: UTF-8 -*-
+'''=================================================
+@Project -> File localizer -> video_to_image
+@IDE PyCharm
+@Author fx221@cam.ac.uk
+@Date 13/01/2024 15:29
+=================================================='''
+import argparse
+import os
+import os.path as osp
+import cv2
+
+parser = argparse.ArgumentParser(description='Video2Image', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+parser.add_argument('--image_path', type=str, required=True)
+parser.add_argument('--video_path', type=str, required=True)
+parser.add_argument('--height', type=int, default=-1)
+parser.add_argument('--sample_ratio', type=int, default=-1)
+
+
+def main(args):
+    os.makedirs(args.image_path, exist_ok=True)  # ensure the output directory exists
+    video = cv2.VideoCapture(args.video_path)
+ nframe = 0
+ while True:
+ ret, frame = video.read()
+ if ret:
+ if args.sample_ratio > 0:
+ if nframe % args.sample_ratio != 0:
+ nframe += 1
+ continue
+ cv2.imwrite(osp.join(args.image_path, '{:06d}.png'.format(nframe)), frame)
+ nframe += 1
+ else:
+ break
+
+
+if __name__ == '__main__':
+ args = parser.parse_args()
+ main(args=args)
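+
+# Example invocation (hypothetical paths, not part of the original pram code):
+#   python tools/video_to_image.py --video_path ./input.mp4 --image_path ./frames --sample_ratio 5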
diff --git a/third_party/pram/tools/visualize_landmarks.py b/third_party/pram/tools/visualize_landmarks.py
new file mode 100644
index 0000000000000000000000000000000000000000..7f8bcba35c14b929de1159c3a9491a98e1f0aebb
--- /dev/null
+++ b/third_party/pram/tools/visualize_landmarks.py
@@ -0,0 +1,171 @@
+# -*- coding: UTF-8 -*-
+'''=================================================
+@Project -> File pram -> visualize_landmarks
+@IDE PyCharm
+@Author fx221@cam.ac.uk
+@Date 22/03/2024 10:39
+=================================================='''
+import os
+import os.path as osp
+import numpy as np
+from tqdm import tqdm
+from colmap_utils.read_write_model import read_model, write_model, Point3D, Image, read_compressed_model
+from recognition.vis_seg import generate_color_dic
+
+
+def reconstruct_map(valid_image_ids, valid_p3d_ids, cameras, images, point3Ds, p3d_seg: dict):
+ new_point3Ds = {}
+ new_images = {}
+
+ valid_p3d_ids_ = []
+ for pid in tqdm(valid_p3d_ids, total=len(valid_p3d_ids)):
+
+ if pid == -1:
+ continue
+ if pid not in point3Ds.keys():
+ continue
+
+ if pid not in p3d_seg.keys():
+ continue
+
+        sid = p3d_seg[pid]
+ if sid == -1:
+ continue
+ valid_p3d_ids_.append(pid)
+
+ valid_p3d_ids = valid_p3d_ids_
+ print('valid_p3ds: ', len(valid_p3d_ids))
+
+ # for im_id in tqdm(images.keys(), total=len(images.keys())):
+ for im_id in tqdm(valid_image_ids, total=len(valid_image_ids)):
+ im = images[im_id]
+ # print('im: ', im)
+ # exit(0)
+ pids = im.point3D_ids
+ valid_pids = []
+ # for v in pids:
+ # if v not in valid_p3d_ids:
+ # valid_pids.append(-1)
+ # else:
+ # valid_pids.append(v)
+
+ new_im = Image(id=im_id, qvec=im.qvec, tvec=im.tvec, camera_id=im.camera_id, name=im.name, xys=im.xys,
+ point3D_ids=pids)
+ new_images[im_id] = new_im
+
+ for pid in tqdm(valid_p3d_ids, total=len(valid_p3d_ids)):
+        sid = p3d_seg[pid]
+
+        xyz = point3Ds[pid].xyz
+        if show_2D:
+            xyz[1] = 0
+            rgb = point3Ds[pid].rgb
+        else:
+            bgr = seg_color[sid + sid_start]
+            rgb = np.array([bgr[2], bgr[1], bgr[0]])
+
+        error = point3Ds[pid].error
+
+        p3d = Point3D(id=pid, xyz=xyz, rgb=rgb, error=error,
+                      image_ids=point3Ds[pid].image_ids,
+                      point2D_idxs=point3Ds[pid].point2D_idxs)
+ new_point3Ds[pid] = p3d
+
+ return cameras, new_images, new_point3Ds
+
+
+if __name__ == '__main__':
+ save_root = '/scratches/flyer_3/fx221/exp/localizer/vis_clustering/'
+ seg_color = generate_color_dic(n_seg=2000)
+ data_root = '/scratches/flyer_3/fx221/exp/localizer/resnet4x-20230511-210205-pho-0005-gm'
+ show_2D = False
+
+ compress_map = False
+ # compress_map = True
+
+ # scene = 'Aachen/Aachenv11'
+ # seg_data = np.load(osp.join(data_root, scene, 'point3D_cluster_n512_xz_birch.npy'), allow_pickle=True)[()]
+ # sid_start = 1
+ # vrf_file_name = 'point3D_vrf_n512_xz_birch.npy'
+
+ #
+ # scene = 'CambridgeLandmarks/GreatCourt'
+ # seg_data = np.load(osp.join(data_root, scene, 'point3D_cluster_n32_xy_birch.npy'), allow_pickle=True)[()]
+ # sid_start = 1
+
+ # scene = 'CambridgeLandmarks/KingsCollege'
+ # seg_data = np.load(osp.join(data_root, scene, 'point3D_cluster_n32_xy_birch.npy'), allow_pickle=True)[()]
+ # sid_start = 33
+ # vrf_file_name = 'point3D_vrf_n32_xy_birch.npy'
+
+ # scene = 'CambridgeLandmarks/StMarysChurch'
+ # seg_data = np.load(osp.join(data_root, scene, 'point3D_cluster_n32_xz_birch.npy'), allow_pickle=True)[()]
+ # sid_start = 32 * 4 + 1
+ # vrf_file_name = 'point3D_vrf_n32_xz_birch.npy'
+
+ # scene = '7Scenes/office'
+ # seg_data = np.load(osp.join(data_root, scene, 'point3D_cluster_n16_xz_birch.npy'), allow_pickle=True)[()]
+ # sid_start = 33
+
+ # scene = '7Scenes/chess'
+ # seg_data = np.load(osp.join(data_root, scene, 'point3D_cluster_n16_xz_birch.npy'), allow_pickle=True)[()]
+ # sid_start = 1
+ # vrf_file_name = 'point3D_vrf_n16_xz_birch.npy'
+
+ # scene = '7Scenes/redkitchen'
+ # seg_data = np.load(osp.join(data_root, scene, 'point3D_cluster_n16_xz_birch.npy'), allow_pickle=True)[()]
+ # sid_start = 16 * 5 + 1
+ # vrf_file_name = 'point3D_vrf_n16_xz_birch.npy'
+
+ # scene = '12Scenes/apt1/kitchen'
+ # seg_data = np.load(osp.join(data_root, scene, 'point3D_cluster_n16_xy_birch.npy'), allow_pickle=True)[()]
+ # sid_start = 1
+ # vrf_file_name = 'point3D_vrf_n16_xy_birch.npy'
+
+ # data_root = '/scratches/flyer_3/fx221/exp/localizer/resnet4x-20230511-210205-pho-0005-gml2'
+ # scene = 'JesusCollege/jesuscollege'
+ # seg_data = np.load(osp.join(data_root, scene, 'point3D_cluster_n256_xy_birch.npy'), allow_pickle=True)[()]
+ # sid_start = 1
+ # vrf_file_name = 'point3D_vrf_n256_xy_birch.npy'
+
+ scene = 'DarwinRGB/darwin'
+ seg_data = np.load(osp.join(data_root, scene, 'point3D_cluster_n128_xy_birch.npy'), allow_pickle=True)[()]
+ sid_start = 1
+ vrf_file_name = 'point3D_vrf_n128_xy_birch.npy'
+
+ cameras, images, points3D = read_model(osp.join(data_root, scene, 'model'), ext='.bin')
+ print('Load {:d} 3D points from map'.format(len(points3D.keys())))
+
+ if compress_map:
+ vrf_data = np.load(osp.join(data_root, scene, vrf_file_name), allow_pickle=True)[()]
+ valid_image_ids = [vrf_data[v][0]['image_id'] for v in vrf_data.keys()]
+ else:
+ valid_image_ids = list(images.keys())
+
+ if compress_map:
+ _, _, compress_points3D = read_compressed_model(osp.join(data_root, scene, 'compress_model_birch'),
+ ext='.bin')
+ print('Load {:d} 3D points from compressed map'.format(len(compress_points3D.keys())))
+ valid_p3d_ids = list(compress_points3D.keys())
+ else:
+ valid_p3d_ids = list(points3D.keys())
+
+ save_path = osp.join(save_root, scene)
+
+ if compress_map:
+ save_path = save_path + '_comp'
+ if show_2D:
+ save_path = save_path + '_2D'
+
+ os.makedirs(save_path, exist_ok=True)
+ p3d_id = seg_data['id']
+ seg_id = seg_data['label']
+ map_seg = {p3d_id[i]: seg_id[i] for i in range(p3d_id.shape[0])}
+
+ new_cameras, new_images, new_point3Ds = reconstruct_map(valid_image_ids=valid_image_ids,
+ valid_p3d_ids=valid_p3d_ids, cameras=cameras, images=images,
+ point3Ds=points3D, p3d_seg=map_seg)
+
+ # write_model(cameras=cameras, images=images, points3D=new_point3Ds,
+ # path=save_path, ext='.bin')
+ write_model(cameras=new_cameras, images=new_images, points3D=new_point3Ds, path=save_path, ext='.bin')
diff --git a/third_party/pram/train.py b/third_party/pram/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..9a2657f455d29c7c7c5417d8efa7aacaef4207ed
--- /dev/null
+++ b/third_party/pram/train.py
@@ -0,0 +1,170 @@
+# -*- coding: UTF-8 -*-
+'''=================================================
+@Project -> File pram -> train
+@IDE PyCharm
+@Author fx221@cam.ac.uk
+@Date 03/04/2024 16:33
+=================================================='''
+import argparse
+import os
+import os.path as osp
+import torch
+import torchvision.transforms.transforms as tvt
+import yaml
+import torch.utils.data as Data
+import torch.multiprocessing as mp
+import torch.distributed as dist
+
+from nets.sfd2 import load_sfd2
+from nets.segnet import SegNet
+from nets.segnetvit import SegNetViT
+from nets.load_segnet import load_segnet
+from dataset.utils import collect_batch
+from dataset.get_dataset import compose_datasets
+from tools.common import torch_set_gpu
+from trainer import Trainer
+
+
+def get_model(config):
+ desc_dim = 256 if config['feature'] == 'spp' else 128
+ if config['use_mid_feature']:
+ desc_dim = 256
+ model_config = {
+ 'network': {
+ 'descriptor_dim': desc_dim,
+ 'n_layers': config['layers'],
+ 'ac_fn': config['ac_fn'],
+ 'norm_fn': config['norm_fn'],
+ 'n_class': config['n_class'],
+ 'output_dim': config['output_dim'],
+ # 'with_cls': config['with_cls'],
+ # 'with_sc': config['with_sc'],
+ 'with_score': config['with_score'],
+ }
+ }
+
+ if config['network'] == 'segnet':
+ model = SegNet(model_config.get('network', {}))
+ config['with_cls'] = False
+ elif config['network'] == 'segnetvit':
+ model = SegNetViT(model_config.get('network', {}))
+ config['with_cls'] = False
+ else:
+        raise ValueError('{:s} model does not exist'.format(config['network']))
+
+ return model
+
+
+parser = argparse.ArgumentParser(description='PRAM', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+parser.add_argument('--config', type=str, required=True, help='config of specifications')
+# parser.add_argument('--landmark_path', type=str, required=True, help='path of landmarks')
+parser.add_argument('--feat_weight_path', type=str, default='weights/sfd2_20230511_210205_resnet4x.79.pth')
+
+
+def setup(rank, world_size):
+ os.environ['MASTER_ADDR'] = 'localhost'
+ os.environ['MASTER_PORT'] = '12355'
+ # initialize the process group
+ dist.init_process_group("nccl", rank=rank, world_size=world_size)
+
+
+def train_DDP(rank, world_size, model, config, train_set, test_set, feat_model, img_transforms):
+ print('In train_DDP..., rank: ', rank)
+ torch.cuda.set_device(rank)
+
+ device = torch.device(f'cuda:{rank}')
+ if feat_model is not None:
+ feat_model.to(device)
+ model.to(device)
+ model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
+ setup(rank=rank, world_size=world_size)
+ model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[rank])
+ train_sampler = torch.utils.data.distributed.DistributedSampler(train_set,
+ shuffle=True,
+ rank=rank,
+ num_replicas=world_size,
+ drop_last=True, # important?
+ )
+ train_loader = torch.utils.data.DataLoader(train_set,
+ batch_size=config['batch_size'] // world_size,
+ num_workers=config['workers'] // world_size,
+ # num_workers=1,
+ pin_memory=True,
+ # persistent_workers=True,
+ shuffle=False, # must be False
+ drop_last=True,
+ collate_fn=collect_batch,
+ prefetch_factor=4,
+ sampler=train_sampler)
+ config['local_rank'] = rank
+
+    if rank != 0:
+        # only rank 0 keeps the eval set; the other ranks skip evaluation
+        test_set = None
+
+ trainer = Trainer(model=model, train_loader=train_loader, feat_model=feat_model, eval_loader=test_set,
+ config=config, img_transforms=img_transforms)
+ trainer.train()
+
+
+if __name__ == '__main__':
+ args = parser.parse_args()
+ with open(args.config, 'rt') as f:
+ config = yaml.load(f, Loader=yaml.Loader)
+ torch_set_gpu(gpus=config['gpu'])
+ if config['local_rank'] == 0:
+ print(config)
+
+ img_transforms = []
+ img_transforms.append(tvt.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]))
+ img_transforms = tvt.Compose(img_transforms)
+
+ feat_model = load_sfd2(weight_path=args.feat_weight_path).cuda().eval()
+ print('Load SFD2 weight from {:s}'.format(args.feat_weight_path))
+
+ dataset = config['dataset']
+ train_set = compose_datasets(datasets=dataset, config=config, train=True, sample_ratio=None)
+ if config['do_eval']:
+ test_set = compose_datasets(datasets=dataset, config=config, train=False, sample_ratio=None)
+ else:
+ test_set = None
+ config['n_class'] = train_set.n_class
+ # model = get_model(config=config)
+ model = load_segnet(network=config['network'],
+ n_class=config['n_class'],
+ desc_dim=256 if config['use_mid_feature'] else 128,
+ n_layers=config['layers'],
+ output_dim=config['output_dim'])
+ if config['local_rank'] == 0:
+ if config['resume_path'] is not None: # only for training
+ model.load_state_dict(
+ torch.load(osp.join(config['save_path'], config['resume_path']), map_location='cpu')['model'],
+ strict=True)
+ print('Load resume weight from {:s}'.format(osp.join(config['save_path'], config['resume_path'])))
+
+ if not config['with_dist'] or len(config['gpu']) == 1:
+ config['with_dist'] = False
+ model = model.cuda()
+ train_loader = Data.DataLoader(dataset=train_set,
+ shuffle=True,
+ batch_size=config['batch_size'],
+ drop_last=True,
+ collate_fn=collect_batch,
+ num_workers=config['workers'])
+ if test_set is not None:
+ test_loader = Data.DataLoader(dataset=test_set,
+ shuffle=False,
+ batch_size=1,
+ drop_last=False,
+ collate_fn=collect_batch,
+ num_workers=4)
+ else:
+ test_loader = None
+ trainer = Trainer(model=model, train_loader=train_loader, feat_model=feat_model, eval_loader=test_loader,
+ config=config, img_transforms=img_transforms)
+ trainer.train()
+ else:
+ mp.spawn(train_DDP, nprocs=len(config['gpu']),
+ args=(len(config['gpu']), model, config, train_set, test_set, feat_model, img_transforms),
+ join=True)
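+
+# Example launch (the config path below is a placeholder; the weight path is the
+# parser default shipped with pram):
+#   python train.py --config <your_config>.yaml --feat_weight_path weights/sfd2_20230511_210205_resnet4x.79.pth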
diff --git a/third_party/pram/trainer.py b/third_party/pram/trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..002e349323ec587843ea4119a0bc32b343bd34dd
--- /dev/null
+++ b/third_party/pram/trainer.py
@@ -0,0 +1,404 @@
+# -*- coding: UTF-8 -*-
+'''=================================================
+@Project -> File pram -> trainer
+@IDE PyCharm
+@Author fx221@cam.ac.uk
+@Date 29/01/2024 15:04
+=================================================='''
+import datetime
+import os
+import os.path as osp
+import numpy as np
+from pathlib import Path
+from tensorboardX import SummaryWriter
+from tqdm import tqdm
+import torch.optim as optim
+import torch.nn.functional as F
+
+import shutil
+import torch
+from torch.autograd import Variable
+from tools.common import save_args_yaml, merge_tags
+from tools.metrics import compute_iou, compute_precision, SeqIOU, compute_corr_incorr, compute_seg_loss_weight
+from tools.metrics import compute_cls_loss_ce, compute_cls_corr
+
+
+class Trainer:
+ def __init__(self, model, train_loader, feat_model=None, eval_loader=None, config=None, img_transforms=None):
+ self.model = model
+ self.train_loader = train_loader
+ self.eval_loader = eval_loader
+ self.config = config
+ self.with_aug = self.config['with_aug']
+ self.with_cls = False # self.config['with_cls']
+ self.with_sc = False # self.config['with_sc']
+ self.img_transforms = img_transforms
+ self.feat_model = feat_model.cuda().eval() if feat_model is not None else None
+
+ self.init_lr = self.config['lr']
+ self.min_lr = self.config['min_lr']
+
+ params = [p for p in self.model.parameters() if p.requires_grad]
+ self.optimizer = optim.AdamW(params=params, lr=self.init_lr)
+ self.num_epochs = self.config['epochs']
+
+ if config['resume_path'] is not None:
+ log_dir = config['resume_path'].split('/')[-2]
+            resume_log = torch.load(osp.join(config['save_path'], config['resume_path']), map_location='cpu')
+ self.epoch = resume_log['epoch'] + 1
+ if 'iteration' in resume_log.keys():
+ self.iteration = resume_log['iteration']
+ else:
+ self.iteration = len(self.train_loader) * self.epoch
+ self.min_loss = resume_log['min_loss']
+ else:
+ self.iteration = 0
+ self.epoch = 0
+ self.min_loss = 1e10
+
+ now = datetime.datetime.now()
+ all_tags = [now.strftime("%Y%m%d_%H%M%S")]
+ dataset_name = merge_tags(self.config['dataset'], '')
+ all_tags = all_tags + [self.config['network'], 'L' + str(self.config['layers']),
+ dataset_name,
+ str(self.config['feature']), 'B' + str(self.config['batch_size']),
+ 'K' + str(self.config['max_keypoints']), 'od' + str(self.config['output_dim']),
+ 'nc' + str(self.config['n_class'])]
+ if self.config['use_mid_feature']:
+ all_tags.append('md')
+ # if self.with_cls:
+ # all_tags.append(self.config['cls_loss'])
+ # if self.with_sc:
+ # all_tags.append(self.config['sc_loss'])
+ if self.with_aug:
+ all_tags.append('A')
+
+ all_tags.append(self.config['cluster_method'])
+ log_dir = merge_tags(tags=all_tags, connection='_')
+
+ if config['local_rank'] == 0:
+ self.save_dir = osp.join(self.config['save_path'], log_dir)
+ os.makedirs(self.save_dir, exist_ok=True)
+
+ print("save_dir: ", self.save_dir)
+
+ self.log_file = open(osp.join(self.save_dir, "log.txt"), "a+")
+ save_args_yaml(args=config, save_path=Path(self.save_dir, "args.yaml"))
+ self.writer = SummaryWriter(self.save_dir)
+
+ self.tag = log_dir
+
+ self.do_eval = self.config['do_eval']
+ if self.do_eval:
+ self.eval_fun = None
+ self.seq_metric = SeqIOU(n_class=self.config['n_class'], ignored_sids=[0])
+
+ def preprocess_input(self, pred):
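+        '''Move batch tensors to the GPU and, when `with_aug` is enabled, recompute
+        SFD2 features (global descriptors, keypoint scores and sampled descriptors)
+        for every image in the batch. (Docstring added for clarity; not part of the
+        original pram code.)'''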
+ for k in pred.keys():
+ if k.find('name') >= 0:
+ continue
+ if k != 'image' and k != 'depth':
+ if type(pred[k]) == torch.Tensor:
+ pred[k] = Variable(pred[k].float().cuda())
+ else:
+ pred[k] = Variable(torch.stack(pred[k]).float().cuda())
+
+ if self.with_aug:
+ new_scores = []
+ new_descs = []
+ global_descs = []
+ with torch.no_grad():
+ for i, im in enumerate(pred['image']):
+ img = torch.from_numpy(im[0]).cuda().float().permute(2, 0, 1)
+ # img = self.img_transforms(img)[None]
+ if self.img_transforms is not None:
+ img = self.img_transforms(img)[None]
+ else:
+ img = img[None]
+ out = self.feat_model.extract_local_global(data={'image': img})
+ global_descs.append(out['global_descriptors'])
+
+ seg_scores, seg_descs = self.feat_model.sample(score_map=out['score_map'],
+ semi_descs=out['mid_features'] if self.config[
+ 'use_mid_feature'] else out['desc_map'],
+ kpts=pred['keypoints'][i],
+ norm_desc=self.config['norm_desc']) # [D, N]
+ new_scores.append(seg_scores[None])
+ new_descs.append(seg_descs[None])
+ pred['global_descriptors'] = global_descs
+ pred['scores'] = torch.cat(new_scores, dim=0)
+ pred['seg_descriptors'] = torch.cat(new_descs, dim=0).permute(0, 2, 1) # -> [B, N, D]
+
+ def process_epoch(self):
+ self.model.train()
+
+ epoch_cls_losses = []
+ epoch_seg_losses = []
+ epoch_losses = []
+ epoch_acc_corr = []
+ epoch_acc_incorr = []
+ epoch_cls_acc = []
+
+ epoch_sc_losses = []
+
+ for bidx, pred in tqdm(enumerate(self.train_loader), total=len(self.train_loader)):
+ self.preprocess_input(pred)
+ if 0 <= self.config['its_per_epoch'] <= bidx:
+ break
+
+ data = self.model(pred)
+ for k, v in pred.items():
+ pred[k] = v
+ pred = {**pred, **data}
+
+ seg_loss = compute_seg_loss_weight(pred=pred['prediction'],
+ target=pred['gt_seg'],
+ background_id=0,
+ weight_background=0.1)
+ acc_corr, acc_incorr = compute_corr_incorr(pred=pred['prediction'],
+ target=pred['gt_seg'],
+ ignored_ids=[0])
+
+ if self.with_cls:
+ pred_cls_dist = pred['classification']
+ gt_cls_dist = pred['gt_cls_dist']
+ if len(pred_cls_dist.shape) > 2:
+ gt_cls_dist_full = gt_cls_dist.unsqueeze(-1).repeat(1, 1, pred_cls_dist.shape[-1])
+ else:
+ gt_cls_dist_full = gt_cls_dist.unsqueeze(-1)
+ cls_loss = compute_cls_loss_ce(pred=pred_cls_dist, target=gt_cls_dist_full)
+ loss = seg_loss + cls_loss
+
+ # gt_n_seg = pred['gt_n_seg']
+ cls_acc = compute_cls_corr(pred=pred_cls_dist.squeeze(-1), target=gt_cls_dist)
+ else:
+ loss = seg_loss
+ cls_loss = torch.zeros_like(seg_loss)
+ cls_acc = torch.zeros_like(seg_loss)
+
+ if self.with_sc:
+ pass
+ else:
+ sc_loss = torch.zeros_like(seg_loss)
+
+ epoch_losses.append(loss.item())
+ epoch_seg_losses.append(seg_loss.item())
+ epoch_cls_losses.append(cls_loss.item())
+ epoch_sc_losses.append(sc_loss.item())
+
+ epoch_acc_corr.append(acc_corr.item())
+ epoch_acc_incorr.append(acc_incorr.item())
+ epoch_cls_acc.append(cls_acc.item())
+
+ self.optimizer.zero_grad()
+ loss.backward()
+ self.optimizer.step()
+
+ self.iteration += 1
+
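+            # Constant lr until `decay_iter` iterations, then geometric decay by
+            # `decay_rate` per iteration, floored at `min_lr` (assuming decay_rate < 1).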
+ lr = min(self.config['lr'] * self.config['decay_rate'] ** (self.iteration - self.config['decay_iter']),
+ self.config['lr'])
+ if lr < self.min_lr:
+ lr = self.min_lr
+
+ for param_group in self.optimizer.param_groups:
+ param_group['lr'] = lr
+
+ if self.config['local_rank'] == 0 and bidx % self.config['log_intervals'] == 0:
+ print_text = 'Epoch [{:d}/{:d}], Step [{:d}/{:d}/{:d}], Loss [s{:.2f}/c{:.2f}/sc{:.2f}/t{:.2f}], Acc [c{:.2f}/{:.2f}/{:.2f}]'.format(
+ self.epoch,
+ self.num_epochs, bidx,
+ len(self.train_loader),
+ self.iteration,
+ seg_loss.item(),
+ cls_loss.item(),
+ sc_loss.item(),
+ loss.item(),
+
+ np.mean(epoch_acc_corr),
+ np.mean(epoch_acc_incorr),
+ np.mean(epoch_cls_acc)
+ )
+
+ print(print_text)
+ self.log_file.write(print_text + '\n')
+
+ info = {
+ 'lr': lr,
+ 'loss': loss.item(),
+ 'cls_loss': cls_loss.item(),
+ 'sc_loss': sc_loss.item(),
+ 'acc_corr': acc_corr.item(),
+ 'acc_incorr': acc_incorr.item(),
+ 'acc_cls': cls_acc.item(),
+ }
+
+ for k, v in info.items():
+ self.writer.add_scalar(tag=k, scalar_value=v, global_step=self.iteration)
+
+ if self.config['local_rank'] == 0:
+ print_text = 'Epoch [{:d}/{:d}], AVG Loss [s{:.2f}/c{:.2f}/sc{:.2f}/t{:.2f}], Acc [c{:.2f}/{:.2f}/{:.2f}]\n'.format(
+ self.epoch,
+ self.num_epochs,
+ np.mean(epoch_seg_losses),
+ np.mean(epoch_cls_losses),
+ np.mean(epoch_sc_losses),
+ np.mean(epoch_losses),
+ np.mean(epoch_acc_corr),
+ np.mean(epoch_acc_incorr),
+ np.mean(epoch_cls_acc),
+ )
+ print(print_text)
+ self.log_file.write(print_text + '\n')
+ self.log_file.flush()
+ return np.mean(epoch_losses)
+
+ def eval_seg(self, loader):
+ print('Start to do evaluation...')
+
+ self.model.eval()
+ self.seq_metric.clear()
+ mean_iou_day = []
+ mean_iou_night = []
+ mean_prec_day = []
+ mean_prec_night = []
+ mean_cls_day = []
+ mean_cls_night = []
+
+ for bid, pred in tqdm(enumerate(loader), total=len(loader)):
+ for k in pred.keys():
+ if k.find('name') >= 0:
+ continue
+ if k != 'image' and k != 'depth':
+ if type(pred[k]) == torch.Tensor:
+ pred[k] = Variable(pred[k].float().cuda())
+ elif type(pred[k]) == np.ndarray:
+ pred[k] = Variable(torch.from_numpy(pred[k]).float()[None].cuda())
+ else:
+ pred[k] = Variable(torch.stack(pred[k]).float().cuda())
+
+ if self.with_aug:
+ with torch.no_grad():
+ if isinstance(pred['image'][0], list):
+ img = pred['image'][0][0]
+ else:
+ img = pred['image'][0]
+
+ img = torch.from_numpy(img).cuda().float().permute(2, 0, 1)
+ if self.img_transforms is not None:
+ img = self.img_transforms(img)[None]
+ else:
+ img = img[None]
+
+ encoder_out = self.feat_model.extract_local_global(data={'image': img})
+ global_descriptors = [encoder_out['global_descriptors']]
+ pred['global_descriptors'] = global_descriptors
+ if self.config['use_mid_feature']:
+ scores, descs = self.feat_model.sample(score_map=encoder_out['score_map'],
+ semi_descs=encoder_out['mid_features'],
+ kpts=pred['keypoints'][0],
+ norm_desc=self.config['norm_desc'])
+ # print('eval: ', scores.shape, descs.shape)
+ pred['scores'] = scores[None]
+ pred['seg_descriptors'] = descs[None].permute(0, 2, 1) # -> [B, N, D]
+ else:
+ pred['seg_descriptors'] = pred['descriptors']
+
+ image_name = pred['file_name'][0]
+ with torch.no_grad():
+ out = self.model(pred)
+ pred = {**pred, **out}
+
+ pred_seg = torch.max(pred['prediction'], dim=-1)[1] # [B, N, C]
+ pred_seg = pred_seg[0].cpu().numpy()
+ gt_seg = pred['gt_seg'][0].cpu().numpy()
+ iou = compute_iou(pred=pred_seg, target=gt_seg, n_class=self.config['n_class'], ignored_ids=[0])
+ prec = compute_precision(pred=pred_seg, target=gt_seg, ignored_ids=[0])
+
+ if self.with_cls:
+ pred_cls_dist = pred['classification']
+ gt_cls_dist = pred['gt_cls_dist']
+ cls_acc = compute_cls_corr(pred=pred_cls_dist.squeeze(-1), target=gt_cls_dist).item()
+ else:
+ cls_acc = 0.
+
+ if image_name.find('night') >= 0:
+ mean_iou_night.append(iou)
+ mean_prec_night.append(prec)
+ mean_cls_night.append(cls_acc)
+ else:
+ mean_iou_day.append(iou)
+ mean_prec_day.append(prec)
+ mean_cls_day.append(cls_acc)
+
+ print_txt = 'Eval Epoch {:d}, iou day/night {:.3f}/{:.3f}, prec day/night {:.3f}/{:.3f}, cls day/night {:.3f}/{:.3f}'.format(
+ self.epoch, np.mean(mean_iou_day), np.mean(mean_iou_night),
+ np.mean(mean_prec_day), np.mean(mean_prec_night),
+ np.mean(mean_cls_day), np.mean(mean_cls_night))
+ self.log_file.write(print_txt + '\n')
+ print(print_txt)
+
+ info = {
+ 'mean_iou_day': np.mean(mean_iou_day),
+ 'mean_iou_night': np.mean(mean_iou_night),
+ 'mean_prec_day': np.mean(mean_prec_day),
+ 'mean_prec_night': np.mean(mean_prec_night),
+ }
+
+ for k, v in info.items():
+ self.writer.add_scalar(tag=k, scalar_value=v, global_step=self.epoch)
+
+ return np.mean(mean_prec_night)
+
+ def train(self):
+ if self.config['local_rank'] == 0:
+ print('Start to train the model from epoch: {:d}'.format(self.epoch))
+ hist_values = []
+ min_value = self.min_loss
+
+ epoch = self.epoch
+ while epoch < self.num_epochs:
+ if self.config['with_dist']:
+ self.train_loader.sampler.set_epoch(epoch=epoch)
+ self.epoch = epoch
+
+ train_loss = self.process_epoch()
+
+ # return with loss INF/NAN
+ if train_loss is None:
+ continue
+
+ if self.config['local_rank'] == 0:
+ if self.do_eval and self.epoch % self.config['eval_n_epoch'] == 0: # and self.epoch >= 50:
+ eval_ratio = self.eval_seg(loader=self.eval_loader)
+
+ hist_values.append(eval_ratio) # higher better
+ else:
+ hist_values.append(-train_loss) # lower better
+
+ checkpoint_path = os.path.join(self.save_dir,
+ '%s.%02d.pth' % (self.config['network'], self.epoch))
+ checkpoint = {
+ 'epoch': self.epoch,
+ 'iteration': self.iteration,
+ 'model': self.model.state_dict(),
+ 'min_loss': min_value,
+ }
+ # for multi-gpu training
+ if len(self.config['gpu']) > 1:
+ checkpoint['model'] = self.model.module.state_dict()
+
+ torch.save(checkpoint, checkpoint_path)
+
+ if hist_values[-1] < min_value:
+ min_value = hist_values[-1]
+ best_checkpoint_path = os.path.join(
+ self.save_dir,
+ '%s.best.pth' % (self.tag)
+ )
+ shutil.copy(checkpoint_path, best_checkpoint_path)
+ # important!!!
+ epoch += 1
+
+ if self.config['local_rank'] == 0:
+ self.log_file.close()
diff --git a/third_party/pram/weights/imp_gml.920.pth b/third_party/pram/weights/imp_gml.920.pth
new file mode 100644
index 0000000000000000000000000000000000000000..dd9af051ef4af22329dbad4f168e30d948a97655
--- /dev/null
+++ b/third_party/pram/weights/imp_gml.920.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:89ac37d35a667bdcae8566f5a236fcc2f0e3f407c30360bb378084ace4a29531
+size 47597159
diff --git a/third_party/pram/weights/sfd2_20230511_210205_resnet4x.79.pth b/third_party/pram/weights/sfd2_20230511_210205_resnet4x.79.pth
new file mode 100644
index 0000000000000000000000000000000000000000..39bb1f3d11dd93cb7c5dd11d3b6eb47b0b20f07d
--- /dev/null
+++ b/third_party/pram/weights/sfd2_20230511_210205_resnet4x.79.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:06bbddca9f1acfaff09c29d0e3311d2c4ef6b8faaadd312683929c4af8a8898a
+size 16095284
diff --git a/ui/config.yaml b/ui/config.yaml
index d94cc3f67789b454c248b10468b9b2354ba358a9..28d0a5106718e25e6e3fd31cfe95d270bb0d3b17 100644
--- a/ui/config.yaml
+++ b/ui/config.yaml
@@ -389,9 +389,8 @@ matcher_zoo:
sfd2+imp:
matcher: imp
feature: sfd2
- enable: false
+ enable: true
dense: false
- skip_ci: true
info:
name: SFD2+IMP #dispaly name
source: "CVPR 2023"
@@ -403,9 +402,8 @@ matcher_zoo:
sfd2+mnn:
matcher: NN-mutual
feature: sfd2
- enable: false
+ enable: true
dense: false
- skip_ci: true
info:
name: SFD2+MNN #dispaly name
source: "CVPR 2023"