diff --git a/.clang-format b/.clang-format
new file mode 100644
index 0000000000000000000000000000000000000000..27e8fb94966d415ef0aea4a886d481933ca392b8
--- /dev/null
+++ b/.clang-format
@@ -0,0 +1,40 @@
+BasedOnStyle: Chromium
+IncludeBlocks: Preserve
+IncludeCategories:
+ - Regex: '^<.*>'
+ Priority: 1
+ - Regex: '^".*"'
+ Priority: 2
+SortIncludes: true
+Language: Cpp
+AccessModifierOffset: 2
+AlignAfterOpenBracket: true
+AlignConsecutiveAssignments: false
+AlignConsecutiveDeclarations: false
+AlignEscapedNewlines: Right
+AlignOperands: true
+AlignTrailingComments: false
+AllowAllParametersOfDeclarationOnNextLine: true
+AllowShortBlocksOnASingleLine: false
+AllowShortCaseLabelsOnASingleLine: true
+AllowShortFunctionsOnASingleLine: None
+AllowShortIfStatementsOnASingleLine: true
+AllowShortLoopsOnASingleLine: true
+AlwaysBreakAfterReturnType: None
+AlwaysBreakBeforeMultilineStrings: true
+AlwaysBreakTemplateDeclarations: false
+BinPackArguments: false
+BinPackParameters: false
+BreakBeforeBraces: Attach
+BreakBeforeInheritanceComma: false
+BreakBeforeTernaryOperators: true
+BreakStringLiterals: false
+ColumnLimit: 88
+CompactNamespaces: false
+ConstructorInitializerAllOnOneLineOrOnePerLine: true
+ConstructorInitializerIndentWidth: 4
+ContinuationIndentWidth: 4
+IndentCaseLabels: true
+IndentWidth: 4
+TabWidth: 4
+UseTab: Never
diff --git a/.gitattributes b/.gitattributes
new file mode 100644
index 0000000000000000000000000000000000000000..dc41813948358b1c4961863e937f642b2c5bd1c3
--- /dev/null
+++ b/.gitattributes
@@ -0,0 +1,41 @@
+*.7z filter=lfs diff=lfs merge=lfs -text
+*.arrow filter=lfs diff=lfs merge=lfs -text
+*.bin filter=lfs diff=lfs merge=lfs -text
+*.bz2 filter=lfs diff=lfs merge=lfs -text
+*.ckpt filter=lfs diff=lfs merge=lfs -text
+*.ftz filter=lfs diff=lfs merge=lfs -text
+*.gz filter=lfs diff=lfs merge=lfs -text
+*.h5 filter=lfs diff=lfs merge=lfs -text
+*.joblib filter=lfs diff=lfs merge=lfs -text
+*.lfs.* filter=lfs diff=lfs merge=lfs -text
+*.mlmodel filter=lfs diff=lfs merge=lfs -text
+*.model filter=lfs diff=lfs merge=lfs -text
+*.msgpack filter=lfs diff=lfs merge=lfs -text
+*.npy filter=lfs diff=lfs merge=lfs -text
+*.npz filter=lfs diff=lfs merge=lfs -text
+*.onnx filter=lfs diff=lfs merge=lfs -text
+*.ot filter=lfs diff=lfs merge=lfs -text
+*.parquet filter=lfs diff=lfs merge=lfs -text
+*.pb filter=lfs diff=lfs merge=lfs -text
+*.pickle filter=lfs diff=lfs merge=lfs -text
+*.pkl filter=lfs diff=lfs merge=lfs -text
+*.pt filter=lfs diff=lfs merge=lfs -text
+*.pth filter=lfs diff=lfs merge=lfs -text
+*.rar filter=lfs diff=lfs merge=lfs -text
+*.safetensors filter=lfs diff=lfs merge=lfs -text
+saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+*.tar.* filter=lfs diff=lfs merge=lfs -text
+*.tar filter=lfs diff=lfs merge=lfs -text
+*.tflite filter=lfs diff=lfs merge=lfs -text
+*.tgz filter=lfs diff=lfs merge=lfs -text
+*.wasm filter=lfs diff=lfs merge=lfs -text
+*.xz filter=lfs diff=lfs merge=lfs -text
+*.zip filter=lfs diff=lfs merge=lfs -text
+*.zst filter=lfs diff=lfs merge=lfs -text
+*tfevents* filter=lfs diff=lfs merge=lfs -text
+*.gif filter=lfs diff=lfs merge=lfs -text
+*.jpg filter=lfs diff=lfs merge=lfs -text
+*.JPG filter=lfs diff=lfs merge=lfs -text
+*.png filter=lfs diff=lfs merge=lfs -text
+*.bmp filter=lfs diff=lfs merge=lfs -text
+*.pdf filter=lfs diff=lfs merge=lfs -text
diff --git a/.github/.stale.yml b/.github/.stale.yml
new file mode 100644
index 0000000000000000000000000000000000000000..dc90e5a1c3aad4818a813606b52fdecd2fdf6782
--- /dev/null
+++ b/.github/.stale.yml
@@ -0,0 +1,17 @@
+# Number of days of inactivity before an issue becomes stale
+daysUntilStale: 60
+# Number of days of inactivity before a stale issue is closed
+daysUntilClose: 7
+# Issues with these labels will never be considered stale
+exemptLabels:
+ - pinned
+ - security
+# Label to use when marking an issue as stale
+staleLabel: wontfix
+# Comment to post when marking an issue as stale. Set to `false` to disable
+markComment: >
+ This issue has been automatically marked as stale because it has not had
+ recent activity. It will be closed if no further activity occurs. Thank you
+ for your contributions.
+# Comment to post when closing a stale issue. Set to `false` to disable
+closeComment: false
diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md
new file mode 100644
index 0000000000000000000000000000000000000000..036bffc7cbe7f88ef6c4657752a24a73e4bf9a76
--- /dev/null
+++ b/.github/ISSUE_TEMPLATE/bug_report.md
@@ -0,0 +1,30 @@
+---
+name: 🐛 Bug report
+about: If something isn't working 🔧
+title: ""
+labels: bug
+assignees:
+---
+
+## 🐛 Bug Report
+
+
+
+## 🔬 How To Reproduce
+
+Steps to reproduce the behavior:
+
+1. ...
+
+### Environment
+
+- OS: [e.g. Linux / Windows / macOS]
+- Python version, get it with:
+
+```bash
+python --version
+```
+
+## 📎 Additional context
+
+
diff --git a/.github/ISSUE_TEMPLATE/config.yml b/.github/ISSUE_TEMPLATE/config.yml
new file mode 100644
index 0000000000000000000000000000000000000000..8f2da5489e290e6e55426eaeac2234c61f21f638
--- /dev/null
+++ b/.github/ISSUE_TEMPLATE/config.yml
@@ -0,0 +1,3 @@
+# Configuration: https://help.github.com/en/github/building-a-strong-community/configuring-issue-templates-for-your-repository
+
+blank_issues_enabled: false
diff --git a/.github/ISSUE_TEMPLATE/feature_request.md b/.github/ISSUE_TEMPLATE/feature_request.md
new file mode 100644
index 0000000000000000000000000000000000000000..7ce8c1277148183d41908db966ee2f0978c7a02a
--- /dev/null
+++ b/.github/ISSUE_TEMPLATE/feature_request.md
@@ -0,0 +1,15 @@
+---
+name: 🚀 Feature request
+about: Suggest an idea for this project 🏖
+title: ""
+labels: enhancement
+assignees:
+---
+
+## 🚀 Feature Request
+
+
+
+## 📎 Additional context
+
+
diff --git a/.github/ISSUE_TEMPLATE/question.md b/.github/ISSUE_TEMPLATE/question.md
new file mode 100644
index 0000000000000000000000000000000000000000..0b624eefe6041c776f1ab4a6aba44d3d1c8cdd83
--- /dev/null
+++ b/.github/ISSUE_TEMPLATE/question.md
@@ -0,0 +1,25 @@
+---
+name: ❓ Question
+about: Ask a question about this project 🎓
+title: ""
+labels: question
+assignees:
+---
+
+## Checklist
+
+
+
+- [ ] I've searched the project's [`issues`](https://github.com/Vincentqyw/image-matching-webui/issues)
+
+## ❓ Question
+
+
+
+How can I [...]?
+
+Is it possible to [...]?
+
+## 📎 Additional context
+
+
diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md
new file mode 100644
index 0000000000000000000000000000000000000000..4dab74cab6b6b173c9a0a15e1d7652f75bc29818
--- /dev/null
+++ b/.github/PULL_REQUEST_TEMPLATE.md
@@ -0,0 +1,7 @@
+## Description
+
+
+
+## Related Issue
+
+
diff --git a/.github/release-drafter.yml b/.github/release-drafter.yml
new file mode 100644
index 0000000000000000000000000000000000000000..fc4b3c3f511d20f0dd255f76f3fea1fe0dbd1ce5
--- /dev/null
+++ b/.github/release-drafter.yml
@@ -0,0 +1,24 @@
+# Release drafter configuration https://github.com/release-drafter/release-drafter#configuration
+# Emojis were chosen to match the https://gitmoji.carloscuesta.me/
+
+name-template: "v$RESOLVED_VERSION"
+tag-template: "v$RESOLVED_VERSION"
+
+categories:
+ - title: ":rocket: Features"
+ labels: [enhancement, feature]
+ - title: ":wrench: Fixes"
+ labels: [bug, bugfix, fix]
+ - title: ":toolbox: Maintenance & Refactor"
+ labels: [refactor, refactoring, chore]
+ - title: ":package: Build System & CI/CD & Test"
+ labels: [build, ci, testing, test]
+ - title: ":pencil: Documentation"
+ labels: [documentation]
+ - title: ":arrow_up: Dependencies updates"
+ labels: [dependencies]
+
+template: |
+ ## What’s Changed
+
+ $CHANGES
diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
new file mode 100644
index 0000000000000000000000000000000000000000..8d25bdb208d7554ace8acae236a943a311aed12c
--- /dev/null
+++ b/.github/workflows/ci.yml
@@ -0,0 +1,37 @@
+name: CI CPU
+
+on:
+ push:
+ branches:
+ - main
+ pull_request:
+ branches:
+ - main
+
+jobs:
+ build:
+ runs-on: ubuntu-latest
+ # runs-on: self-hosted
+
+ steps:
+ - name: Checkout code
+ uses: actions/checkout@v4
+ with:
+ submodules: recursive
+
+ - name: Set up Python
+ uses: actions/setup-python@v4
+ with:
+ python-version: "3.10"
+
+ - name: Install dependencies
+ run: |
+ pip install -r requirements.txt
+ sudo apt-get update && sudo apt-get install ffmpeg libsm6 libxext6 -y
+
+ - name: Build and install
+ run: pip install .
+
+ - name: Run tests
+ # run: python -m pytest
+ run: python tests/test_basic.py
diff --git a/.github/workflows/format.yml b/.github/workflows/format.yml
new file mode 100644
index 0000000000000000000000000000000000000000..39eca7f82193545f5afe00a3ff841ee93db4967e
--- /dev/null
+++ b/.github/workflows/format.yml
@@ -0,0 +1,23 @@
+# This is a format job. Pre-commit has a first-party GitHub action, so we use
+# that: https://github.com/pre-commit/action
+
+name: Format
+
+on:
+ workflow_dispatch:
+ pull_request:
+ push:
+ branches:
+ - main
+
+jobs:
+ pre-commit:
+ name: Format
+ runs-on: ubuntu-latest
+ # runs-on: self-hosted
+ steps:
+ - uses: actions/checkout@v4
+ - uses: actions/setup-python@v5
+ with:
+ python-version: "3.x"
+ - uses: pre-commit/action@v3.0.1
diff --git a/.github/workflows/pip.yml b/.github/workflows/pip.yml
new file mode 100644
index 0000000000000000000000000000000000000000..87fec4fef633d9f13732c4a762eb6f835c40447b
--- /dev/null
+++ b/.github/workflows/pip.yml
@@ -0,0 +1,62 @@
+name: Pip
+on:
+ workflow_dispatch:
+ pull_request:
+ push:
+ branches:
+ - main
+
+jobs:
+ build:
+ strategy:
+ fail-fast: false
+ matrix:
+ platform: [ubuntu-latest]
+ python-version: ["3.9", "3.10"]
+
+ runs-on: ${{ matrix.platform }}
+ # runs-on: self-hosted
+ steps:
+ - uses: actions/checkout@v4
+ with:
+ submodules: recursive
+
+ - uses: actions/setup-python@v5
+ with:
+ python-version: ${{ matrix.python-version }}
+
+ - name: Upgrade setuptools and wheel
+ run: |
+ pip install --upgrade setuptools wheel
+
+ - name: Install dependencies on Ubuntu
+ if: runner.os == 'Linux'
+ run: |
+ sudo apt-get update
+ sudo apt-get install libopencv-dev -y
+
+ - name: Install dependencies on macOS
+ if: runner.os == 'macOS'
+ run: |
+ brew update
+ brew install opencv
+
+ - name: Install dependencies on Windows
+ if: runner.os == 'Windows'
+ run: |
+ choco install opencv -y
+
+ - name: Add requirements
+ run: python -m pip install --upgrade wheel setuptools
+
+ - name: Install Python dependencies
+ run: |
+ pip install pytest
+ pip install -r requirements.txt
+ sudo apt-get update && sudo apt-get install ffmpeg libsm6 libxext6 -y
+
+ - name: Build and install
+ run: pip install .
+
+ - name: Test
+ run: python -m pytest
diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml
new file mode 100644
index 0000000000000000000000000000000000000000..c272ab19fd460c0d2dca3f0d52567a003e86212c
--- /dev/null
+++ b/.github/workflows/release.yml
@@ -0,0 +1,95 @@
+name: PyPI Release
+on:
+ release:
+ types: [published]
+
+jobs:
+ build:
+ strategy:
+ fail-fast: false
+ matrix:
+ platform: [ubuntu-latest]
+ python-version: ["3.9", "3.10", "3.11"]
+
+ runs-on: ${{ matrix.platform }}
+ steps:
+ - uses: actions/checkout@v4
+ with:
+ submodules: recursive
+
+ - uses: actions/setup-python@v5
+ with:
+ python-version: ${{ matrix.python-version }}
+
+ - name: Upgrade setuptools and wheel
+ run: |
+ pip install --upgrade setuptools wheel
+
+ - name: Install dependencies on Ubuntu
+ if: runner.os == 'Linux'
+ run: |
+ sudo apt-get update
+ sudo apt-get install libopencv-dev -y
+
+ - name: Install dependencies on macOS
+ if: runner.os == 'macOS'
+ run: |
+ brew update
+ brew install opencv
+
+ - name: Install dependencies on Windows
+ if: runner.os == 'Windows'
+ run: |
+ choco install opencv -y
+
+ - name: Add requirements
+ run: python -m pip install --upgrade setuptools wheel build
+
+ - name: Install Python dependencies
+ run: |
+ pip install pytest
+ pip install -r requirements.txt
+ sudo apt-get update && sudo apt-get install ffmpeg libsm6 libxext6 -y
+
+ - name: Build source distribution
+ run: |
+ python -m build --outdir dist/
+ ls -lh dist/
+
+ - name: Upload to GitHub Release
+ if: matrix.python-version == '3.10' && github.event_name == 'release'
+ uses: softprops/action-gh-release@v2
+ with:
+ files: dist/*.whl
+ env:
+ GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
+
+ - name: Archive wheels
+ if: matrix.python-version == '3.10' && github.event_name == 'release'
+ uses: actions/upload-artifact@v4
+ with:
+ name: dist
+ path: dist/*.whl
+
+
+ pypi-publish:
+ name: upload release to PyPI
+ needs: build
+ runs-on: ubuntu-latest
+ environment: pypi
+ permissions:
+ # IMPORTANT: this permission is mandatory for Trusted Publishing
+ id-token: write
+ steps:
+ # retrieve your distributions here
+ - name: Download artifacts
+ uses: actions/download-artifact@v4
+ with:
+ name: dist
+ path: dist
+
+ - name: List dist directory
+ run: ls -lh dist/
+
+ - name: Publish package distributions to PyPI
+ uses: pypa/gh-action-pypi-publish@release/v1
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..6615959c6a79433c573a0855f2f8f1b5fb620199
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,36 @@
+build/
+
+lib/
+bin/
+
+cmake_modules/
+cmake-build-debug/
+.idea/
+.vscode/
+*.pyc
+flagged
+.ipynb_checkpoints
+__pycache__
+Untitled*
+experiments
+third_party/REKD
+hloc/matchers/dedode.py
+gradio_cached_examples
+*.mp4
+hloc/matchers/quadtree.py
+third_party/QuadTreeAttention
+desktop.ini
+*.egg-info
+output.pkl
+log.txt
+experiments*
+gen_example.py
+datasets/lines/terrace0.JPG
+datasets/lines/terrace1.JPG
+datasets/South-Building*
+*.pkl
+oryx-build-commands.txt
+.ruff_cache*
+dist
+tmp
+backup*
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..4a4445ab2065b11dc18e3c820422008c4287b1d1
--- /dev/null
+++ b/.pre-commit-config.yaml
@@ -0,0 +1,88 @@
+# To use:
+#
+# pre-commit run -a
+#
+# Or:
+#
+# pre-commit run --all-files
+#
+# Or:
+#
+# pre-commit install # (runs every time you commit in git)
+#
+# To update this file:
+#
+# pre-commit autoupdate
+#
+# See https://github.com/pre-commit/pre-commit
+
+ci:
+ autoupdate_commit_msg: "chore: update pre-commit hooks"
+ autofix_commit_msg: "style: pre-commit fixes"
+
+repos:
+# Standard hooks
+- repo: https://github.com/pre-commit/pre-commit-hooks
+ rev: v5.0.0
+ hooks:
+ - id: check-added-large-files
+ exclude: ^imcui/third_party/
+ - id: check-case-conflict
+ exclude: ^imcui/third_party/
+ - id: check-merge-conflict
+ exclude: ^imcui/third_party/
+ - id: check-symlinks
+ exclude: ^imcui/third_party/
+ - id: check-yaml
+ exclude: ^imcui/third_party/
+ - id: debug-statements
+ exclude: ^imcui/third_party/
+ - id: end-of-file-fixer
+ exclude: ^imcui/third_party/
+ - id: mixed-line-ending
+ exclude: ^imcui/third_party/
+ - id: requirements-txt-fixer
+ exclude: ^imcui/third_party/
+ - id: trailing-whitespace
+ exclude: ^imcui/third_party/
+
+- repo: https://github.com/astral-sh/ruff-pre-commit
+ rev: "v0.8.4"
+ hooks:
+ - id: ruff
+ args: ["--fix", "--show-fixes", "--extend-ignore=E402"]
+ - id: ruff-format
+ exclude: ^(docs|imcui/third_party/)
+
+# Checking static types
+- repo: https://github.com/pre-commit/mirrors-mypy
+ rev: "v1.14.0"
+ hooks:
+ - id: mypy
+ files: "setup.py"
+ args: []
+ additional_dependencies: [types-setuptools]
+ exclude: ^imcui/third_party/
+# Changes tabs to spaces
+- repo: https://github.com/Lucas-C/pre-commit-hooks
+ rev: v1.5.5
+ hooks:
+ - id: remove-tabs
+ exclude: ^(docs|imcui/third_party/)
+
+# CMake formatting
+- repo: https://github.com/cheshirekow/cmake-format-precommit
+ rev: v0.6.13
+ hooks:
+ - id: cmake-format
+ additional_dependencies: [pyyaml]
+ types: [file]
+ files: (\.cmake|CMakeLists.txt)(.in)?$
+ exclude: ^imcui/third_party/
+
+# Suggested hook if you add a .clang-format file
+- repo: https://github.com/pre-commit/mirrors-clang-format
+ rev: v13.0.0
+ hooks:
+ - id: clang-format
+ exclude: ^imcui/third_party/
diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md
new file mode 100644
index 0000000000000000000000000000000000000000..6c419c020eb769944237d2b27260de40ac8a7626
--- /dev/null
+++ b/CODE_OF_CONDUCT.md
@@ -0,0 +1,128 @@
+# Contributor Covenant Code of Conduct
+
+## Our Pledge
+
+We as members, contributors, and leaders pledge to make participation in our
+community a harassment-free experience for everyone, regardless of age, body
+size, visible or invisible disability, ethnicity, sex characteristics, gender
+identity and expression, level of experience, education, socio-economic status,
+nationality, personal appearance, race, religion, or sexual identity
+and orientation.
+
+We pledge to act and interact in ways that contribute to an open, welcoming,
+diverse, inclusive, and healthy community.
+
+## Our Standards
+
+Examples of behavior that contributes to a positive environment for our
+community include:
+
+* Demonstrating empathy and kindness toward other people
+* Being respectful of differing opinions, viewpoints, and experiences
+* Giving and gracefully accepting constructive feedback
+* Accepting responsibility and apologizing to those affected by our mistakes,
+ and learning from the experience
+* Focusing on what is best not just for us as individuals, but for the
+ overall community
+
+Examples of unacceptable behavior include:
+
+* The use of sexualized language or imagery, and sexual attention or
+ advances of any kind
+* Trolling, insulting or derogatory comments, and personal or political attacks
+* Public or private harassment
+* Publishing others' private information, such as a physical or email
+ address, without their explicit permission
+* Other conduct which could reasonably be considered inappropriate in a
+ professional setting
+
+## Enforcement Responsibilities
+
+Community leaders are responsible for clarifying and enforcing our standards of
+acceptable behavior and will take appropriate and fair corrective action in
+response to any behavior that they deem inappropriate, threatening, offensive,
+or harmful.
+
+Community leaders have the right and responsibility to remove, edit, or reject
+comments, commits, code, wiki edits, issues, and other contributions that are
+not aligned to this Code of Conduct, and will communicate reasons for moderation
+decisions when appropriate.
+
+## Scope
+
+This Code of Conduct applies within all community spaces, and also applies when
+an individual is officially representing the community in public spaces.
+Examples of representing our community include using an official e-mail address,
+posting via an official social media account, or acting as an appointed
+representative at an online or offline event.
+
+## Enforcement
+
+Instances of abusive, harassing, or otherwise unacceptable behavior may be
+reported to the community leaders responsible for enforcement at
+alpharealcat@gmail.com.
+All complaints will be reviewed and investigated promptly and fairly.
+
+All community leaders are obligated to respect the privacy and security of the
+reporter of any incident.
+
+## Enforcement Guidelines
+
+Community leaders will follow these Community Impact Guidelines in determining
+the consequences for any action they deem in violation of this Code of Conduct:
+
+### 1. Correction
+
+**Community Impact**: Use of inappropriate language or other behavior deemed
+unprofessional or unwelcome in the community.
+
+**Consequence**: A private, written warning from community leaders, providing
+clarity around the nature of the violation and an explanation of why the
+behavior was inappropriate. A public apology may be requested.
+
+### 2. Warning
+
+**Community Impact**: A violation through a single incident or series
+of actions.
+
+**Consequence**: A warning with consequences for continued behavior. No
+interaction with the people involved, including unsolicited interaction with
+those enforcing the Code of Conduct, for a specified period of time. This
+includes avoiding interactions in community spaces as well as external channels
+like social media. Violating these terms may lead to a temporary or
+permanent ban.
+
+### 3. Temporary Ban
+
+**Community Impact**: A serious violation of community standards, including
+sustained inappropriate behavior.
+
+**Consequence**: A temporary ban from any sort of interaction or public
+communication with the community for a specified period of time. No public or
+private interaction with the people involved, including unsolicited interaction
+with those enforcing the Code of Conduct, is allowed during this period.
+Violating these terms may lead to a permanent ban.
+
+### 4. Permanent Ban
+
+**Community Impact**: Demonstrating a pattern of violation of community
+standards, including sustained inappropriate behavior, harassment of an
+individual, or aggression toward or disparagement of classes of individuals.
+
+**Consequence**: A permanent ban from any sort of public interaction within
+the community.
+
+## Attribution
+
+This Code of Conduct is adapted from the [Contributor Covenant][homepage],
+version 2.0, available at
+https://www.contributor-covenant.org/version/2/0/code_of_conduct.html.
+
+Community Impact Guidelines were inspired by [Mozilla's code of conduct
+enforcement ladder](https://github.com/mozilla/diversity).
+
+[homepage]: https://www.contributor-covenant.org
+
+For answers to common questions about this code of conduct, see the FAQ at
+https://www.contributor-covenant.org/faq. Translations are available at
+https://www.contributor-covenant.org/translations.
diff --git a/Dockerfile b/Dockerfile
new file mode 100644
index 0000000000000000000000000000000000000000..09fd374039435f2c7a313d252fea4148e413d3f6
--- /dev/null
+++ b/Dockerfile
@@ -0,0 +1,27 @@
+# Use an official conda-based Python image as a parent image
+FROM pytorch/pytorch:2.4.0-cuda12.1-cudnn9-runtime
+LABEL maintainer vincentqyw
+ARG PYTHON_VERSION=3.10.10
+
+# Set the working directory to /code
+WORKDIR /code
+
+# Install Git and Git LFS
+RUN apt-get update && apt-get install -y git-lfs
+RUN git lfs install
+
+# Clone the Git repository
+RUN git clone --recursive https://github.com/Vincentqyw/image-matching-webui.git /code
+
+RUN conda create -n imw python=${PYTHON_VERSION}
+RUN echo "source activate imw" > ~/.bashrc
+ENV PATH /opt/conda/envs/imw/bin:$PATH
+
+# Make RUN commands use the new environment
+SHELL ["conda", "run", "-n", "imw", "/bin/bash", "-c"]
+RUN pip install --upgrade pip
+RUN pip install -r requirements.txt
+RUN apt-get update && apt-get install ffmpeg libsm6 libxext6 -y
+
+# Export port
+EXPOSE 7860
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..261eeb9e9f8b2b4b0d119366dda99c6fd7d35c64
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,201 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
diff --git a/MANIFEST.in b/MANIFEST.in
new file mode 100644
index 0000000000000000000000000000000000000000..2f2db59983f1aca800b0a43c2ab260cfc68fa311
--- /dev/null
+++ b/MANIFEST.in
@@ -0,0 +1,12 @@
+# logo
+include imcui/assets/logo.webp
+
+recursive-include imcui/ui *.yaml
+recursive-include imcui/api *.yaml
+recursive-include imcui/third_party *.yaml *.cfg *.yml
+
+# ui examples
+# recursive-include imcui/datasets *.JPG *.jpg *.png
+
+# model
+recursive-include imcui/third_party/SuperGluePretrainedNetwork *.pth
diff --git a/README.md b/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..7abaf502e4022eb38d09a3530c7a7b6c2b525de2
--- /dev/null
+++ b/README.md
@@ -0,0 +1,194 @@
+[![Contributors][contributors-shield]][contributors-url]
+[![Forks][forks-shield]][forks-url]
+[![Stargazers][stars-shield]][stars-url]
+[![Issues][issues-shield]][issues-url]
+
+
+
+# $\color{red}{\textnormal{Image\ Matching\ WebUI}}$
+
+Matching Keypoints between two images
+
+
+## Description
+
+This simple tool matches image pairs using a range of well-known image matching algorithms. It features a graphical user interface (GUI) built with [gradio](https://gradio.app/): simply select two images and a matching algorithm to obtain the matching result.
+**Note**: the image source can be either local files or webcam captures.
+
+Try it on Hugging Face or Lightning AI (see the "How to use" section below).
+
+
+
+
+Here is a demo of the tool:
+
+https://github.com/Vincentqyw/image-matching-webui/assets/18531182/263534692-c3484d1b-cc00-4fdc-9b31-e5b7af07ecd9
+
+The tool currently supports various popular image matching algorithms, namely:
+- [x] [MINIMA](https://github.com/LSXI7/MINIMA), ARXIV 2024
+- [x] [XoFTR](https://github.com/OnderT/XoFTR), CVPR 2024
+- [x] [EfficientLoFTR](https://github.com/zju3dv/EfficientLoFTR), CVPR 2024
+- [x] [MASt3R](https://github.com/naver/mast3r), CVPR 2024
+- [x] [DUSt3R](https://github.com/naver/dust3r), CVPR 2024
+- [x] [OmniGlue](https://github.com/Vincentqyw/omniglue-onnx), CVPR 2024
+- [x] [XFeat](https://github.com/verlab/accelerated_features), CVPR 2024
+- [x] [RoMa](https://github.com/Vincentqyw/RoMa), CVPR 2024
+- [x] [DeDoDe](https://github.com/Parskatt/DeDoDe), 3DV 2024
+- [ ] [Mickey](https://github.com/nianticlabs/mickey), CVPR 2024
+- [x] [GIM](https://github.com/xuelunshen/gim), ICLR 2024
+- [x] [ALIKED](https://github.com/Shiaoming/ALIKED), ICCV 2023
+- [x] [LightGlue](https://github.com/cvg/LightGlue), ICCV 2023
+- [x] [DarkFeat](https://github.com/THU-LYJ-Lab/DarkFeat), AAAI 2023
+- [x] [SFD2](https://github.com/feixue94/sfd2), CVPR 2023
+- [x] [IMP](https://github.com/feixue94/imp-release), CVPR 2023
+- [ ] [ASTR](https://github.com/ASTR2023/ASTR), CVPR 2023
+- [ ] [SEM](https://github.com/SEM2023/SEM), CVPR 2023
+- [ ] [DeepLSD](https://github.com/cvg/DeepLSD), CVPR 2023
+- [x] [GlueStick](https://github.com/cvg/GlueStick), ICCV 2023
+- [ ] [ConvMatch](https://github.com/SuhZhang/ConvMatch), AAAI 2023
+- [x] [LoFTR](https://github.com/zju3dv/LoFTR), CVPR 2021
+- [x] [SOLD2](https://github.com/cvg/SOLD2), CVPR 2021
+- [ ] [LineTR](https://github.com/yosungho/LineTR), RA-L 2021
+- [x] [DKM](https://github.com/Parskatt/DKM), CVPR 2023
+- [ ] [NCMNet](https://github.com/xinliu29/NCMNet), CVPR 2023
+- [x] [TopicFM](https://github.com/Vincentqyw/TopicFM), AAAI 2023
+- [x] [AspanFormer](https://github.com/Vincentqyw/ml-aspanformer), ECCV 2022
+- [x] [LANet](https://github.com/wangch-g/lanet), ACCV 2022
+- [ ] [LISRD](https://github.com/rpautrat/LISRD), ECCV 2022
+- [ ] [REKD](https://github.com/bluedream1121/REKD), CVPR 2022
+- [x] [CoTR](https://github.com/ubc-vision/COTR), ICCV 2021
+- [x] [ALIKE](https://github.com/Shiaoming/ALIKE), TMM 2022
+- [x] [RoRD](https://github.com/UditSinghParihar/RoRD), IROS 2021
+- [x] [SGMNet](https://github.com/vdvchen/SGMNet), ICCV 2021
+- [x] [SuperPoint](https://github.com/magicleap/SuperPointPretrainedNetwork), CVPRW 2018
+- [x] [SuperGlue](https://github.com/magicleap/SuperGluePretrainedNetwork), CVPR 2020
+- [x] [D2Net](https://github.com/Vincentqyw/d2-net), CVPR 2019
+- [x] [R2D2](https://github.com/naver/r2d2), NeurIPS 2019
+- [x] [DISK](https://github.com/cvlab-epfl/disk), NeurIPS 2020
+- [ ] [Key.Net](https://github.com/axelBarroso/Key.Net), ICCV 2019
+- [ ] [OANet](https://github.com/zjhthu/OANet), ICCV 2019
+- [x] [SOSNet](https://github.com/scape-research/SOSNet), CVPR 2019
+- [x] [HardNet](https://github.com/DagnyT/hardnet), NeurIPS 2017
+- [x] [SIFT](https://docs.opencv.org/4.x/da/df5/tutorial_py_sift_intro.html), IJCV 2004
+
+## How to use
+
+### HuggingFace / Lightning AI
+
+Just try it on Hugging Face or Lightning AI,
+
+
+
+
+or deploy it locally following the instructions below.
+
+### Requirements
+
+- [Python 3.9+](https://www.python.org/downloads/)
+
+#### Install from pip [NEW]
+
+Update: installation from [pip](https://pypi.org/project/imcui) is now supported; just run:
+
+```bash
+pip install imcui
+```
+
+#### Install from source
+
+``` bash
+git clone --recursive https://github.com/Vincentqyw/image-matching-webui.git
+cd image-matching-webui
+conda env create -f environment.yaml
+conda activate imw
+pip install -e .
+```
+
+or using [docker](https://hub.docker.com/r/vincentqin/image-matching-webui):
+
+``` bash
+docker pull vincentqin/image-matching-webui:latest
+docker run -it -p 7860:7860 vincentqin/image-matching-webui:latest python app.py --server_name "0.0.0.0" --server_port=7860
+```
+
+### Deploy to Railway
+
+Deploy to [Railway](https://railway.app/) by setting a `Custom Start Command` in the `Deploy` section:
+
+``` bash
+python -m imcui.api.server
+```
+
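+Once the server is running, you can also query it over HTTP. The snippet below is a minimal sketch rather than part of the project: it assumes the default endpoint used by [imcui/api/client.py](imcui/api/client.py) (`http://127.0.0.1:8001` when running locally), that the `/v1/extract` route accepts a JSON body matching the `ImagesInput` model, and that `image0.png` is a placeholder for your own image.
+
+```python
+import base64
+
+import requests
+
+# Encode a local image as a base64 string (placeholder file name).
+with open("image0.png", "rb") as f:
+    b64_image = base64.b64encode(f.read()).decode("utf-8")
+
+# The payload fields mirror the ImagesInput model in imcui/api/__init__.py.
+response = requests.post(
+    "http://127.0.0.1:8001/v1/extract",
+    json={"data": [b64_image], "max_keypoints": [1024], "grayscale": False},
+)
+print(response.status_code)
+```
+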
+### Run demo
+``` bash
+python app.py --config ./config/config.yaml
+```
+Then open http://localhost:7860 in your browser.
+
+
+
+### Add your own feature / matcher
+
+An example of adding a local feature is provided in [imcui/hloc/extractors/example.py](imcui/hloc/extractors/example.py). Then add the feature settings to `confs` in [imcui/hloc/extract_features.py](imcui/hloc/extract_features.py). The last step is adding an entry to `matcher_zoo` in [imcui/ui/config.yaml](imcui/ui/config.yaml). A simplified extractor sketch is shown below.
+
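+The sketch below is illustrative only, not the actual `example.py`: it assumes the `BaseModel` interface shared by the extractors under `imcui/hloc/extractors/` (which merges `default_conf` into `self.conf` and dispatches to `_init` / `_forward`), and it returns random keypoints so that the plumbing runs end to end.
+
+```python
+import torch
+
+from ..utils.base_model import BaseModel
+
+
+class MyFeature(BaseModel):
+    # Defaults can be overridden from `confs` in extract_features.py.
+    default_conf = {
+        "max_keypoints": 2000,
+        "descriptor_dim": 128,
+    }
+    required_inputs = ["image"]
+
+    def _init(self, conf):
+        # Load your network and weights here; this dummy has nothing to load.
+        pass
+
+    def _forward(self, data):
+        # data["image"] is a B x C x H x W float tensor.
+        b, _, h, w = data["image"].shape
+        n = self.conf["max_keypoints"]
+        d = self.conf["descriptor_dim"]
+        # Placeholder outputs: random keypoints with normalized descriptors.
+        keypoints = torch.rand(b, n, 2) * torch.tensor([w - 1.0, h - 1.0])
+        scores = torch.ones(b, n)
+        descriptors = torch.nn.functional.normalize(torch.rand(b, d, n), dim=1)
+        return {
+            "keypoints": keypoints,  # B x N x 2, (x, y) pixel coordinates
+            "scores": scores,  # B x N
+            "descriptors": descriptors,  # B x D x N
+        }
+```
+
+The matching `confs` entry in `extract_features.py` and the `matcher_zoo` entry in `config.yaml` can mirror existing ones such as the `superpoint_max` / `superpoint+mnn` pair.
+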
+### Upload models
+
+IMCUI hosts all models on [Hugging Face](https://huggingface.co/Realcat/imcui_checkpoints). You can upload your model to Hugging Face and add it to the [Realcat/imcui_checkpoints](https://huggingface.co/Realcat/imcui_checkpoints) repository, for example as sketched below.
+
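+As a rough sketch (not an official workflow), an upload with the `huggingface_hub` client might look like this; the local file name and the path inside the repository are placeholders, and you need write access to the target repository (or upload to your own repo first):
+
+```python
+from huggingface_hub import HfApi
+
+api = HfApi()  # uses the token cached by `huggingface-cli login`
+
+api.upload_file(
+    path_or_fileobj="weights/my_feature_weights.pth",  # placeholder local path
+    path_in_repo="my_feature/my_feature_weights.pth",  # placeholder repo path
+    repo_id="Realcat/imcui_checkpoints",
+    repo_type="model",
+)
+```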
+
+## Contributions welcome!
+
+External contributions are very much welcome. Please follow the [PEP8 style guidelines](https://www.python.org/dev/peps/pep-0008/) using a linter like flake8. This is a non-exhaustive list of features that might be valuable additions:
+
+- [x] support pip install command
+- [x] add [CPU CI](.github/workflows/ci.yml)
+- [x] add webcam support
+- [x] add [line feature matching](https://github.com/Vincentqyw/LineSegmentsDetection) algorithms
+- [x] example to add a new feature extractor / matcher
+- [x] ransac to filter outliers
+- [ ] add [rotation images](https://github.com/pidahbus/deep-image-orientation-angle-detection) options before matching
+- [ ] support export matches to colmap ([#issue 6](https://github.com/Vincentqyw/image-matching-webui/issues/6))
+- [x] add config file to set default parameters
+- [x] dynamically load models and reduce GPU overload
+
+Adding local features / matchers as submodules is very easy. For example, to add [GlueStick](https://github.com/cvg/GlueStick):
+
+``` bash
+git submodule add https://github.com/cvg/GlueStick.git imcui/third_party/GlueStick
+```
+
+If remote submodule repositories are updated, don't forget to pull submodules with:
+
+``` bash
+git submodule update --init --recursive # init and download
+git submodule update --remote # update
+```
+
+If you only want to update one submodule, use `git submodule update --remote imcui/third_party/GlueStick`.
+
+To format code before committing, run:
+
+```bash
+pre-commit run -a # Auto-checks and fixes
+```
+
+## Contributors
+
+
+
+
+
+## Resources
+- [Image Matching: Local Features & Beyond](https://image-matching-workshop.github.io)
+- [Long-term Visual Localization](https://www.visuallocalization.net)
+
+## Acknowledgement
+
+This code is built on top of [Hierarchical-Localization](https://github.com/cvg/Hierarchical-Localization). We are grateful to the authors for their valuable source code.
+
+[contributors-shield]: https://img.shields.io/github/contributors/Vincentqyw/image-matching-webui.svg?style=for-the-badge
+[contributors-url]: https://github.com/Vincentqyw/image-matching-webui/graphs/contributors
+[forks-shield]: https://img.shields.io/github/forks/Vincentqyw/image-matching-webui.svg?style=for-the-badge
+[forks-url]: https://github.com/Vincentqyw/image-matching-webui/network/members
+[stars-shield]: https://img.shields.io/github/stars/Vincentqyw/image-matching-webui.svg?style=for-the-badge
+[stars-url]: https://github.com/Vincentqyw/image-matching-webui/stargazers
+[issues-shield]: https://img.shields.io/github/issues/Vincentqyw/image-matching-webui.svg?style=for-the-badge
+[issues-url]: https://github.com/Vincentqyw/image-matching-webui/issues
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..9beb1e9bc714791d2bca91019443e72eb9cfd805
--- /dev/null
+++ b/app.py
@@ -0,0 +1,31 @@
+import argparse
+from pathlib import Path
+from imcui.ui.app_class import ImageMatchingApp
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--server_name",
+ type=str,
+ default="0.0.0.0",
+ help="server name",
+ )
+ parser.add_argument(
+ "--server_port",
+ type=int,
+ default=7860,
+ help="server port",
+ )
+ parser.add_argument(
+ "--config",
+ type=str,
+ default=Path(__file__).parent / "config/config.yaml",
+ help="config file",
+ )
+ args = parser.parse_args()
+ ImageMatchingApp(
+ args.server_name,
+ args.server_port,
+ config=args.config,
+ example_data_root=Path("imcui/datasets"),
+ ).run()
diff --git a/assets/demo.gif b/assets/demo.gif
new file mode 100644
index 0000000000000000000000000000000000000000..9af81d1ecded321bbd99ac7b84191518d6daf17d
--- /dev/null
+++ b/assets/demo.gif
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3f163c0e2699181897c81c68e01c60fa4289e886a2a40932d53dd529262d3735
+size 8907062
diff --git a/assets/gui.jpg b/assets/gui.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..cc91f306d1cbfb261ce4725891e4843618f5f4f6
--- /dev/null
+++ b/assets/gui.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a783162639d05631f34e8e3e9a7df682197a76f675265ebbaa639927e08473f7
+size 1669098
diff --git a/assets/logo.webp b/assets/logo.webp
new file mode 100644
index 0000000000000000000000000000000000000000..0a799debc1a06cd6e500a8bccd0ddcef7eca0508
Binary files /dev/null and b/assets/logo.webp differ
diff --git a/build_docker.sh b/build_docker.sh
new file mode 100644
index 0000000000000000000000000000000000000000..a5aea45e6ff5024b71818dea6f4e7cfb0d0ae6c0
--- /dev/null
+++ b/build_docker.sh
@@ -0,0 +1,3 @@
+docker build -t image-matching-webui:latest . --no-cache
+docker tag image-matching-webui:latest vincentqin/image-matching-webui:latest
+docker push vincentqin/image-matching-webui:latest
diff --git a/config/config.yaml b/config/config.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..2e1449d99272db6bc5dcf3452fecae87789bdd4b
--- /dev/null
+++ b/config/config.yaml
@@ -0,0 +1,474 @@
+server:
+ name: "0.0.0.0"
+ port: 7861
+
+defaults:
+ setting_threshold: 0.1
+ max_keypoints: 2000
+ keypoint_threshold: 0.05
+ enable_ransac: true
+ ransac_method: CV2_USAC_MAGSAC
+ ransac_reproj_threshold: 8
+ ransac_confidence: 0.999
+ ransac_max_iter: 10000
+ ransac_num_samples: 4
+ match_threshold: 0.2
+ setting_geometry: Homography
+
+matcher_zoo:
+ minima(loftr):
+ matcher: minima_loftr
+ dense: true
+ info:
+ name: MINIMA(LoFTR) # display name
+ source: "ARXIV 2024"
+ paper: https://arxiv.org/abs/2412.19412
+ display: false
+ minima(RoMa):
+ matcher: minima_roma
+ skip_ci: true
+ dense: true
+ info:
+ name: MINIMA(RoMa) # display name
+ source: "ARXIV 2024"
+ paper: https://arxiv.org/abs/2412.19412
+ display: false
+ omniglue:
+ enable: true
+ matcher: omniglue
+ dense: true
+ info:
+ name: OmniGlue
+ source: "CVPR 2024"
+ github: https://github.com/Vincentqyw/omniglue-onnx
+ paper: https://arxiv.org/abs/2405.12979
+ project: https://hwjiang1510.github.io/OmniGlue
+ display: true
+ Mast3R:
+ enable: false
+ matcher: mast3r
+ dense: true
+ info:
+ name: Mast3R # display name
+ source: "CVPR 2024"
+ github: https://github.com/naver/mast3r
+ paper: https://arxiv.org/abs/2406.09756
+ project: https://dust3r.europe.naverlabs.com
+ display: true
+ DUSt3R:
+ # TODO: duster is under development
+ enable: true
+ # skip_ci: true
+ matcher: duster
+ dense: true
+ info:
+ name: DUSt3R # display name
+ source: "CVPR 2024"
+ github: https://github.com/naver/dust3r
+ paper: https://arxiv.org/abs/2312.14132
+ project: https://dust3r.europe.naverlabs.com
+ display: true
+ GIM(dkm):
+ enable: true
+ # skip_ci: true
+ matcher: gim(dkm)
+ dense: true
+ info:
+ name: GIM(DKM) # display name
+ source: "ICLR 2024"
+ github: https://github.com/xuelunshen/gim
+ paper: https://arxiv.org/abs/2402.11095
+ project: https://xuelunshen.com/gim
+ display: true
+ RoMa:
+ matcher: roma
+ skip_ci: true
+ dense: true
+ info:
+ name: RoMa # display name
+ source: "CVPR 2024"
+ github: https://github.com/Parskatt/RoMa
+ paper: https://arxiv.org/abs/2305.15404
+ project: https://parskatt.github.io/RoMa
+ display: true
+ dkm:
+ matcher: dkm
+ skip_ci: true
+ dense: true
+ info:
+ name: DKM # display name
+ source: "CVPR 2023"
+ github: https://github.com/Parskatt/DKM
+ paper: https://arxiv.org/abs/2202.00667
+ project: https://parskatt.github.io/DKM
+ display: true
+ loftr:
+ matcher: loftr
+ dense: true
+ info:
+ name: LoFTR # display name
+ source: "CVPR 2021"
+ github: https://github.com/zju3dv/LoFTR
+ paper: https://arxiv.org/pdf/2104.00680
+ project: https://zju3dv.github.io/loftr
+ display: true
+ eloftr:
+ matcher: eloftr
+ dense: true
+ info:
+ name: Efficient LoFTR # display name
+ source: "CVPR 2024"
+ github: https://github.com/zju3dv/efficientloftr
+ paper: https://zju3dv.github.io/efficientloftr/files/EfficientLoFTR.pdf
+ project: https://zju3dv.github.io/efficientloftr
+ display: true
+ xoftr:
+ matcher: xoftr
+ dense: true
+ info:
+ name: XoFTR # display name
+ source: "CVPR 2024"
+ github: https://github.com/OnderT/XoFTR
+ paper: https://arxiv.org/pdf/2404.09692
+ project: null
+ display: true
+ cotr:
+ enable: false
+ skip_ci: true
+ matcher: cotr
+ dense: true
+ info:
+ name: CoTR # display name
+ source: "ICCV 2021"
+ github: https://github.com/ubc-vision/COTR
+ paper: https://arxiv.org/abs/2103.14167
+ project: null
+ display: true
+ topicfm:
+ matcher: topicfm
+ dense: true
+ info:
+ name: TopicFM # display name
+ source: "AAAI 2023"
+ github: https://github.com/TruongKhang/TopicFM
+ paper: https://arxiv.org/abs/2307.00485
+ project: null
+ display: true
+ aspanformer:
+ matcher: aspanformer
+ dense: true
+ info:
+ name: ASpanformer # display name
+ source: "ECCV 2022"
+ github: https://github.com/Vincentqyw/ml-aspanformer
+ paper: https://arxiv.org/abs/2208.14201
+ project: null
+ display: true
+ xfeat+lightglue:
+ enable: true
+ matcher: xfeat_lightglue
+ dense: true
+ info:
+ name: xfeat+lightglue
+ source: "CVPR 2024"
+ github: https://github.com/verlab/accelerated_features
+ paper: https://arxiv.org/abs/2404.19174
+ project: null
+ display: true
+ xfeat(sparse):
+ matcher: NN-mutual
+ feature: xfeat
+ dense: false
+ info:
+ name: XFeat # display name
+ source: "CVPR 2024"
+ github: https://github.com/verlab/accelerated_features
+ paper: https://arxiv.org/abs/2404.19174
+ project: null
+ display: true
+ xfeat(dense):
+ matcher: xfeat_dense
+ dense: true
+ info:
+ name: XFeat # display name
+ source: "CVPR 2024"
+ github: https://github.com/verlab/accelerated_features
+ paper: https://arxiv.org/abs/2404.19174
+ project: null
+ display: false
+ dedode:
+ matcher: Dual-Softmax
+ feature: dedode
+ dense: false
+ info:
+ name: DeDoDe # display name
+ source: "3DV 2024"
+ github: https://github.com/Parskatt/DeDoDe
+ paper: https://arxiv.org/abs/2308.08479
+ project: null
+ display: true
+ superpoint+superglue:
+ matcher: superglue
+ feature: superpoint_max
+ dense: false
+ info:
+ name: SuperGlue # display name
+ source: "CVPR 2020"
+ github: https://github.com/magicleap/SuperGluePretrainedNetwork
+ paper: https://arxiv.org/abs/1911.11763
+ project: null
+ display: true
+ superpoint+lightglue:
+ matcher: superpoint-lightglue
+ feature: superpoint_max
+ dense: false
+ info:
+ name: LightGlue # display name
+ source: "ICCV 2023"
+ github: https://github.com/cvg/LightGlue
+ paper: https://arxiv.org/pdf/2306.13643
+ project: null
+ display: true
+ disk:
+ matcher: NN-mutual
+ feature: disk
+ dense: false
+ info:
+ name: DISK
+ source: "NeurIPS 2020"
+ github: https://github.com/cvlab-epfl/disk
+ paper: https://arxiv.org/abs/2006.13566
+ project: null
+ display: true
+ disk+dualsoftmax:
+ matcher: Dual-Softmax
+ feature: disk
+ dense: false
+ info:
+ name: DISK
+ source: "NeurIPS 2020"
+ github: https://github.com/cvlab-epfl/disk
+ paper: https://arxiv.org/abs/2006.13566
+ project: null
+ display: false
+ superpoint+dualsoftmax:
+ matcher: Dual-Softmax
+ feature: superpoint_max
+ dense: false
+ info:
+ name: SuperPoint
+ source: "CVPRW 2018"
+ github: https://github.com/magicleap/SuperPointPretrainedNetwork
+ paper: https://arxiv.org/abs/1712.07629
+ project: null
+ display: false
+ sift+lightglue:
+ matcher: sift-lightglue
+ feature: sift
+ dense: false
+ info:
+ name: LightGlue # display name
+ source: "ICCV 2023"
+ github: https://github.com/cvg/LightGlue
+ paper: https://arxiv.org/pdf/2306.13643
+ project: null
+ display: true
+ disk+lightglue:
+ matcher: disk-lightglue
+ feature: disk
+ dense: false
+ info:
+ name: LightGlue
+ source: "ICCV 2023"
+ github: https://github.com/cvg/LightGlue
+ paper: https://arxiv.org/pdf/2306.13643
+ project: null
+ display: true
+ aliked+lightglue:
+ matcher: aliked-lightglue
+ feature: aliked-n16
+ dense: false
+ info:
+ name: ALIKED
+ source: "ICCV 2023"
+ github: https://github.com/Shiaoming/ALIKED
+ paper: https://arxiv.org/pdf/2304.03608.pdf
+ project: null
+ display: true
+ superpoint+mnn:
+ matcher: NN-mutual
+ feature: superpoint_max
+ dense: false
+ info:
+ name: SuperPoint # display name
+ source: "CVPRW 2018"
+ github: https://github.com/magicleap/SuperPointPretrainedNetwork
+ paper: https://arxiv.org/abs/1712.07629
+ project: null
+ display: true
+ sift+sgmnet:
+ matcher: sgmnet
+ feature: sift
+ dense: false
+ info:
+ name: SGMNet # display name
+ source: "ICCV 2021"
+ github: https://github.com/vdvchen/SGMNet
+ paper: https://arxiv.org/abs/2108.08771
+ project: null
+ display: true
+ sosnet:
+ matcher: NN-mutual
+ feature: sosnet
+ dense: false
+ info:
+ name: SOSNet # display name
+ source: "CVPR 2019"
+ github: https://github.com/scape-research/SOSNet
+ paper: https://arxiv.org/abs/1904.05019
+ project: https://research.scape.io/sosnet
+ display: true
+ hardnet:
+ matcher: NN-mutual
+ feature: hardnet
+ dense: false
+ info:
+ name: HardNet # display name
+ source: "NeurIPS 2017"
+ github: https://github.com/DagnyT/hardnet
+ paper: https://arxiv.org/abs/1705.10872
+ project: null
+ display: true
+ d2net:
+ matcher: NN-mutual
+ feature: d2net-ss
+ dense: false
+ info:
+ name: D2Net # display name
+ source: "CVPR 2019"
+ github: https://github.com/Vincentqyw/d2-net
+ paper: https://arxiv.org/abs/1905.03561
+ project: https://dusmanu.com/publications/d2-net.html
+ display: true
+ rord:
+ matcher: NN-mutual
+ feature: rord
+ dense: false
+ info:
+ name: RoRD # display name
+ source: "IROS 2021"
+ github: https://github.com/UditSinghParihar/RoRD
+ paper: https://arxiv.org/abs/2103.08573
+ project: https://uditsinghparihar.github.io/RoRD
+ display: true
+ alike:
+ matcher: NN-mutual
+ feature: alike
+ dense: false
+ info:
+ name: ALIKE # display name
+ source: "TMM 2022"
+ github: https://github.com/Shiaoming/ALIKE
+ paper: https://arxiv.org/abs/2112.02906
+ project: null
+ display: true
+ lanet:
+ matcher: NN-mutual
+ feature: lanet
+ dense: false
+ info:
+ name: LANet # display name
+ source: "ACCV 2022"
+ github: https://github.com/wangch-g/lanet
+ paper: https://openaccess.thecvf.com/content/ACCV2022/papers/Wang_Rethinking_Low-level_Features_for_Interest_Point_Detection_and_Description_ACCV_2022_paper.pdf
+ project: null
+ display: true
+ r2d2:
+ matcher: NN-mutual
+ feature: r2d2
+ dense: false
+ info:
+ name: R2D2 # display name
+ source: "NeurIPS 2019"
+ github: https://github.com/naver/r2d2
+ paper: https://arxiv.org/abs/1906.06195
+ project: null
+ display: true
+ darkfeat:
+ matcher: NN-mutual
+ feature: darkfeat
+ dense: false
+ info:
+ name: DarkFeat # display name
+ source: "AAAI 2023"
+ github: https://github.com/THU-LYJ-Lab/DarkFeat
+ paper: null
+ project: null
+ display: true
+ sift:
+ matcher: NN-mutual
+ feature: sift
+ dense: false
+ info:
+ name: SIFT # display name
+ source: "IJCV 2004"
+ github: null
+ paper: https://www.cs.ubc.ca/~lowe/papers/ijcv04.pdf
+ project: null
+ display: true
+ gluestick:
+ enable: false
+ matcher: gluestick
+ dense: true
+ info:
+ name: GlueStick # display name
+ source: "ICCV 2023"
+ github: https://github.com/cvg/GlueStick
+ paper: https://arxiv.org/abs/2304.02008
+ project: https://iago-suarez.com/gluestick
+ display: true
+ sold2:
+ enable: false
+ matcher: sold2
+ dense: true
+ info:
+ name: SOLD2 # display name
+ source: "CVPR 2021"
+ github: https://github.com/cvg/SOLD2
+ paper: https://arxiv.org/abs/2104.03362
+ project: null
+ display: true
+
+ sfd2+imp:
+ enable: true
+ matcher: imp
+ feature: sfd2
+ dense: false
+ info:
+ name: SFD2+IMP # display name
+ source: "CVPR 2023"
+ github: https://github.com/feixue94/imp-release
+ paper: https://arxiv.org/pdf/2304.14837
+ project: https://feixue94.github.io/
+ display: true
+
+ sfd2+mnn:
+ enable: true
+ matcher: NN-mutual
+ feature: sfd2
+ dense: false
+ info:
+ name: SFD2+MNN # display name
+ source: "CVPR 2023"
+ github: https://github.com/feixue94/sfd2
+ paper: https://arxiv.org/abs/2304.14845
+ project: https://feixue94.github.io/
+ display: true
+
+retrieval_zoo:
+ netvlad:
+ enable: true
+ openibl:
+ enable: true
+ cosplace:
+ enable: true
diff --git a/docker/build_docker.bat b/docker/build_docker.bat
new file mode 100644
index 0000000000000000000000000000000000000000..9f3fc687e1185de2866a1dbe221599549abdbce8
--- /dev/null
+++ b/docker/build_docker.bat
@@ -0,0 +1,3 @@
+docker build -t image-matching-webui:latest . --no-cache
+REM docker tag image-matching-webui:latest vincentqin/image-matching-webui:latest
+REM docker push vincentqin/image-matching-webui:latest
diff --git a/docker/run_docker.bat b/docker/run_docker.bat
new file mode 100644
index 0000000000000000000000000000000000000000..da7686293c14465f0899c4b022f89fcc03db93b3
--- /dev/null
+++ b/docker/run_docker.bat
@@ -0,0 +1 @@
+docker run -it -p 7860:7860 vincentqin/image-matching-webui:latest python app.py --server_name "0.0.0.0" --server_port=7860
diff --git a/docker/run_docker.sh b/docker/run_docker.sh
new file mode 100644
index 0000000000000000000000000000000000000000..da7686293c14465f0899c4b022f89fcc03db93b3
--- /dev/null
+++ b/docker/run_docker.sh
@@ -0,0 +1 @@
+docker run -it -p 7860:7860 vincentqin/image-matching-webui:latest python app.py --server_name "0.0.0.0" --server_port=7860
diff --git a/environment.yaml b/environment.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..aab94e3a4a1e8e4b5292e2a7767c7e916e3b8e2f
--- /dev/null
+++ b/environment.yaml
@@ -0,0 +1,13 @@
+name: imw
+channels:
+ - pytorch
+ - nvidia
+ - conda-forge
+ - defaults
+dependencies:
+ - python=3.10.10
+ - pytorch-cuda=12.1
+ - pytorch=2.4.0
+ - pip
+ - pip:
+ - -r requirements.txt
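Assuming conda (or mamba) is available, the environment above would typically be created with "conda env create -f environment.yaml" and activated with "conda activate imw"; the pip section then installs the remaining dependencies from requirements.txt.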
diff --git a/imcui/__init__.py b/imcui/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/imcui/api/__init__.py b/imcui/api/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..251563d95841af0b77ef47f92f0c7d7c78a90bb8
--- /dev/null
+++ b/imcui/api/__init__.py
@@ -0,0 +1,47 @@
+import base64
+import io
+from typing import List
+
+import numpy as np
+from fastapi.exceptions import HTTPException
+from PIL import Image
+from pydantic import BaseModel
+
+from ..hloc import logger
+from .core import ImageMatchingAPI
+
+
+class ImagesInput(BaseModel):
+ data: List[str] = []
+ max_keypoints: List[int] = []
+ timestamps: List[str] = []
+ grayscale: bool = False
+ image_hw: List[List[int]] = [[], []]
+ feature_type: int = 0
+ rotates: List[float] = []
+ scales: List[float] = []
+ reference_points: List[List[float]] = []
+ binarize: bool = False
+
+
+def decode_base64_to_image(encoding):
+ if encoding.startswith("data:image/"):
+ encoding = encoding.split(";")[1].split(",")[1]
+ try:
+ image = Image.open(io.BytesIO(base64.b64decode(encoding)))
+ return image
+ except Exception as e:
+ logger.warning(f"API cannot decode image: {e}")
+ raise HTTPException(status_code=500, detail="Invalid encoded image") from e
+
+
+def to_base64_nparray(encoding: str) -> np.ndarray:
+ return np.array(decode_base64_to_image(encoding)).astype("uint8")
+
+
+__all__ = [
+ "ImageMatchingAPI",
+ "ImagesInput",
+ "decode_base64_to_image",
+ "to_base64_nparray",
+]
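For reference, a minimal client-side sketch (not part of the diff) of how the ImagesInput payload defined above can be built and round-tripped through decode_base64_to_image; it assumes the package is importable as imcui and the image path is a placeholder:

    import base64

    from imcui.api import ImagesInput, decode_base64_to_image

    # Encode a local image as a base64 string (placeholder path).
    with open("example.jpg", "rb") as f:
        b64 = base64.b64encode(f.read()).decode("utf-8")

    # Build the request body accepted by the /v1/extract endpoint.
    payload = ImagesInput(data=[b64], max_keypoints=[1024], grayscale=True)

    # Round-trip check: decode the first entry back into a PIL image.
    image = decode_base64_to_image(payload.data[0])
    print(image.size)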
diff --git a/imcui/api/client.py b/imcui/api/client.py
new file mode 100644
index 0000000000000000000000000000000000000000..2e5130bf8452e904d6c999de5dcc932fb216dffa
--- /dev/null
+++ b/imcui/api/client.py
@@ -0,0 +1,232 @@
+import argparse
+import base64
+import os
+import pickle
+import time
+from typing import Dict, List
+
+import cv2
+import numpy as np
+import requests
+
+ENDPOINT = "http://127.0.0.1:8001"
+if "REMOTE_URL_RAILWAY" in os.environ:
+ ENDPOINT = os.environ["REMOTE_URL_RAILWAY"]
+
+print(f"API ENDPOINT: {ENDPOINT}")
+
+API_VERSION = f"{ENDPOINT}/version"
+API_URL_MATCH = f"{ENDPOINT}/v1/match"
+API_URL_EXTRACT = f"{ENDPOINT}/v1/extract"
+
+
+def read_image(path: str) -> str:
+ """
+ Read an image from a file, encode it as a PNG and then as a base64 string.
+
+ Args:
+ path (str): The path to the image to read.
+
+ Returns:
+ str: The base64 encoded image.
+ """
+ # Read the image from the file
+ img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
+
+ # Encode the image as a PNG (lossless)
+ retval, buffer = cv2.imencode(".png", img)
+
+ # Encode the PNG bytes as a base64 string
+ b64img = base64.b64encode(buffer).decode("utf-8")
+
+ return b64img
+
+
+def do_api_requests(url=API_URL_EXTRACT, **kwargs):
+ """
+ Helper function to send an API request to the image matching service.
+
+ Args:
+ url (str): The URL of the API endpoint to use. Defaults to the
+ feature extraction endpoint.
+ **kwargs: Additional keyword arguments to pass to the API.
+
+ Returns:
+ List[Dict[str, np.ndarray]]: A list of dictionaries containing the
+ extracted features. The keys are "keypoints", "descriptors", and
+ "scores", and the values are ndarrays of shape (N, 2), (N, ?),
+ and (N,), respectively.
+ """
+ # Set up the request body
+ reqbody = {
+ # List of image data base64 encoded
+ "data": [],
+ # List of maximum number of keypoints to extract from each image
+ "max_keypoints": [100, 100],
+ # List of timestamps for each image (not used?)
+ "timestamps": ["0", "1"],
+ # Whether to convert the images to grayscale
+ "grayscale": 0,
+ # List of image height and width
+ "image_hw": [[640, 480], [320, 240]],
+ # Type of feature to extract
+ "feature_type": 0,
+ # List of rotation angles for each image
+ "rotates": [0.0, 0.0],
+ # List of scale factors for each image
+ "scales": [1.0, 1.0],
+ # List of reference points for each image (not used)
+ "reference_points": [[640, 480], [320, 240]],
+ # Whether to binarize the descriptors
+ "binarize": True,
+ }
+ # Update the request body with the additional keyword arguments
+ reqbody.update(kwargs)
+ try:
+ # Send the request
+ r = requests.post(url, json=reqbody)
+ if r.status_code == 200:
+ # Return the response
+ return r.json()
+ else:
+ # Print an error message if the response code is not 200
+ print(f"Error: Response code {r.status_code} - {r.text}")
+ except Exception as e:
+ # Print an error message if an exception occurs
+ print(f"An error occurred: {e}")
+
+
+def send_request_match(path0: str, path1: str) -> Dict[str, np.ndarray]:
+ """
+ Send a request to the API to generate a match between two images.
+
+ Args:
+ path0 (str): The path to the first image.
+ path1 (str): The path to the second image.
+
+ Returns:
+ Dict[str, np.ndarray]: A dictionary of match results as returned by the
+ /v1/match endpoint (e.g. "keypoints0_orig", "keypoints1_orig",
+ "mkeypoints0_orig", "mmkeypoints0_orig", "mconf", "mmconf"),
+ with each value converted to a NumPy array.
+ """
+ files = {"image0": open(path0, "rb"), "image1": open(path1, "rb")}
+ try:
+ # TODO: replace files with post json
+ response = requests.post(API_URL_MATCH, files=files)
+ pred = {}
+ if response.status_code == 200:
+ pred = response.json()
+ for key in list(pred.keys()):
+ pred[key] = np.array(pred[key])
+ else:
+ print(f"Error: Response code {response.status_code} - {response.text}")
+ finally:
+ files["image0"].close()
+ files["image1"].close()
+ return pred
+
+
+def send_request_extract(
+ input_images: str, viz: bool = False
+) -> List[Dict[str, np.ndarray]]:
+ """
+ Send a request to the API to extract features from an image.
+
+ Args:
+ input_images (str): The path to the image.
+
+ Returns:
+ List[Dict[str, np.ndarray]]: A list of dictionaries containing the
+ extracted features. The keys are "keypoints", "descriptors", and
+ "scores", and the values are ndarrays of shape (N, 2), (N, 128),
+ and (N,), respectively.
+ """
+ image_data = read_image(input_images)
+ inputs = {
+ "data": [image_data],
+ }
+ response = do_api_requests(
+ url=API_URL_EXTRACT,
+ **inputs,
+ )
+ # breakpoint()
+ # print("Keypoints detected: {}".format(len(response[0]["keypoints"])))
+
+ # draw matching, debug only
+ if viz:
+ from hloc.utils.viz import plot_keypoints
+ from ui.viz import fig2im, plot_images
+
+ kpts = np.array(response[0]["keypoints_orig"])
+ if "image_orig" in response[0].keys():
+ img_orig = np.array(response[0]["image_orig"])
+
+ output_keypoints = plot_images([img_orig], titles="titles", dpi=300)
+ plot_keypoints([kpts])
+ output_keypoints = fig2im(output_keypoints)
+ cv2.imwrite(
+ "demo_match.jpg",
+ output_keypoints[:, :, ::-1].copy(), # RGB -> BGR
+ )
+ return response
+
+
+def get_api_version():
+ try:
+ response = requests.get(API_VERSION).json()
+ print("API VERSION: {}".format(response["version"]))
+ except Exception as e:
+ print(f"An error occurred: {e}")
+
+
+if __name__ == "__main__":
+ from pathlib import Path
+
+ parser = argparse.ArgumentParser(
+ description="Send text to stable audio server and receive generated audio."
+ )
+ parser.add_argument(
+ "--image0",
+ required=False,
+ help="Path for the file's melody",
+ default=str(
+ Path(__file__).parents[1]
+ / "datasets/sacre_coeur/mapping_rot/02928139_3448003521_rot45.jpg"
+ ),
+ )
+ parser.add_argument(
+ "--image1",
+ required=False,
+ help="Path for the file's melody",
+ default=str(
+ Path(__file__).parents[1]
+ / "datasets/sacre_coeur/mapping_rot/02928139_3448003521_rot90.jpg"
+ ),
+ )
+ args = parser.parse_args()
+
+ # get api version
+ get_api_version()
+
+ # request match
+ # for i in range(10):
+ # t1 = time.time()
+ # preds = send_request_match(args.image0, args.image1)
+ # t2 = time.time()
+ # print(
+ # "Time cost1: {} seconds, matched: {}".format(
+ # (t2 - t1), len(preds["mmkeypoints0_orig"])
+ # )
+ # )
+
+ # request extract
+ for i in range(1000):
+ t1 = time.time()
+ preds = send_request_extract(args.image0)
+ t2 = time.time()
+ print(f"Time cost2: {(t2 - t1)} seconds")
+
+ # dump preds
+ with open("preds.pkl", "wb") as f:
+ pickle.dump(preds, f)
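As a small follow-up sketch (not part of the diff), the JSON returned by send_request_extract above can be converted back into NumPy arrays; the key names follow the extract() output in imcui/api/core.py and the image path is a placeholder:

    import numpy as np

    preds = send_request_extract("example.jpg")
    kpts = np.asarray(preds[0]["keypoints_orig"], dtype=np.float32)  # (N, 2) keypoints at original scale
    desc = np.asarray(preds[0]["descriptors"])                       # (N, D) when binarize=True
    print(kpts.shape, desc.shape)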
diff --git a/imcui/api/config/api.yaml b/imcui/api/config/api.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..f4f4be347bc713ed03ff2252be5b95096bbe9e43
--- /dev/null
+++ b/imcui/api/config/api.yaml
@@ -0,0 +1,51 @@
+# This file was generated using the `serve build` command on Ray v2.38.0.
+
+proxy_location: EveryNode
+http_options:
+ host: 0.0.0.0
+ port: 8001
+
+grpc_options:
+ port: 9000
+ grpc_servicer_functions: []
+
+logging_config:
+ encoding: TEXT
+ log_level: INFO
+ logs_dir: null
+ enable_access_log: true
+
+applications:
+- name: app1
+ route_prefix: /
+ import_path: api.server:service
+ runtime_env: {}
+ deployments:
+ - name: ImageMatchingService
+ num_replicas: 4
+ ray_actor_options:
+ num_cpus: 2.0
+ num_gpus: 1.0
+
+api:
+ feature:
+ output: feats-superpoint-n4096-rmax1600
+ model:
+ name: superpoint
+ nms_radius: 3
+ max_keypoints: 4096
+ keypoint_threshold: 0.005
+ preprocessing:
+ grayscale: True
+ force_resize: True
+ resize_max: 1600
+ width: 640
+ height: 480
+ dfactor: 8
+ matcher:
+ output: matches-NN-mutual
+ model:
+ name: nearest_neighbor
+ do_mutual_check: True
+ match_threshold: 0.2
+ dense: False
diff --git a/imcui/api/core.py b/imcui/api/core.py
new file mode 100644
index 0000000000000000000000000000000000000000..c4d58d08662d853d5fdd29baee5f8a349b61d369
--- /dev/null
+++ b/imcui/api/core.py
@@ -0,0 +1,308 @@
+# api.py
+import warnings
+from pathlib import Path
+from typing import Any, Dict, Optional
+
+import cv2
+import matplotlib.pyplot as plt
+import numpy as np
+import torch
+
+from ..hloc import extract_features, logger, match_dense, match_features
+from ..hloc.utils.viz import add_text, plot_keypoints
+from ..ui.utils import filter_matches, get_feature_model, get_model
+from ..ui.viz import display_matches, fig2im, plot_images
+
+warnings.simplefilter("ignore")
+
+
+class ImageMatchingAPI(torch.nn.Module):
+ default_conf = {
+ "ransac": {
+ "enable": True,
+ "estimator": "poselib",
+ "geometry": "homography",
+ "method": "RANSAC",
+ "reproj_threshold": 3,
+ "confidence": 0.9999,
+ "max_iter": 10000,
+ },
+ }
+
+ def __init__(
+ self,
+ conf: dict = {},
+ device: str = "cpu",
+ detect_threshold: float = 0.015,
+ max_keypoints: int = 1024,
+ match_threshold: float = 0.2,
+ ) -> None:
+ """
+ Initializes an instance of the ImageMatchingAPI class.
+
+ Args:
+ conf (dict): A dictionary containing the configuration parameters.
+ device (str, optional): The device to use for computation. Defaults to "cpu".
+ detect_threshold (float, optional): The threshold for detecting keypoints. Defaults to 0.015.
+ max_keypoints (int, optional): The maximum number of keypoints to extract. Defaults to 1024.
+ match_threshold (float, optional): The threshold for matching keypoints. Defaults to 0.2.
+
+ Returns:
+ None
+ """
+ super().__init__()
+ self.device = device
+ self.conf = {**self.default_conf, **conf}
+ self._update_config(detect_threshold, max_keypoints, match_threshold)
+ self._init_models()
+ if device == "cuda":
+ memory_allocated = torch.cuda.memory_allocated(device)
+ memory_reserved = torch.cuda.memory_reserved(device)
+ logger.info(f"GPU memory allocated: {memory_allocated / 1024**2:.3f} MB")
+ logger.info(f"GPU memory reserved: {memory_reserved / 1024**2:.3f} MB")
+ self.pred = None
+
+ def parse_match_config(self, conf):
+ if conf["dense"]:
+ return {
+ **conf,
+ "matcher": match_dense.confs.get(conf["matcher"]["model"]["name"]),
+ "dense": True,
+ }
+ else:
+ return {
+ **conf,
+ "feature": extract_features.confs.get(conf["feature"]["model"]["name"]),
+ "matcher": match_features.confs.get(conf["matcher"]["model"]["name"]),
+ "dense": False,
+ }
+
+ def _update_config(
+ self,
+ detect_threshold: float = 0.015,
+ max_keypoints: int = 1024,
+ match_threshold: float = 0.2,
+ ):
+ self.dense = self.conf["dense"]
+ if self.conf["dense"]:
+ try:
+ self.conf["matcher"]["model"]["match_threshold"] = match_threshold
+ except TypeError as e:
+ logger.error(e)
+ else:
+ self.conf["feature"]["model"]["max_keypoints"] = max_keypoints
+ self.conf["feature"]["model"]["keypoint_threshold"] = detect_threshold
+ self.extract_conf = self.conf["feature"]
+
+ self.match_conf = self.conf["matcher"]
+
+ def _init_models(self):
+ # initialize matcher
+ self.matcher = get_model(self.match_conf)
+ # initialize extractor
+ if self.dense:
+ self.extractor = None
+ else:
+ self.extractor = get_feature_model(self.conf["feature"])
+
+ def _forward(self, img0, img1):
+ if self.dense:
+ pred = match_dense.match_images(
+ self.matcher,
+ img0,
+ img1,
+ self.match_conf["preprocessing"],
+ device=self.device,
+ )
+ last_fixed = "{}".format( # noqa: F841
+ self.match_conf["model"]["name"]
+ )
+ else:
+ pred0 = extract_features.extract(
+ self.extractor, img0, self.extract_conf["preprocessing"]
+ )
+ pred1 = extract_features.extract(
+ self.extractor, img1, self.extract_conf["preprocessing"]
+ )
+ pred = match_features.match_images(self.matcher, pred0, pred1)
+ return pred
+
+ def _convert_pred(self, pred):
+ ret = {
+ k: v.cpu().detach()[0].numpy() if isinstance(v, torch.Tensor) else v
+ for k, v in pred.items()
+ }
+ ret = {
+ k: v[0].cpu().detach().numpy() if isinstance(v, list) else v
+ for k, v in ret.items()
+ }
+ return ret
+
+ @torch.inference_mode()
+ def extract(self, img0: np.ndarray, **kwargs) -> Dict[str, np.ndarray]:
+ """Extract features from a single image.
+
+ Args:
+ img0 (np.ndarray): image
+
+ Returns:
+ Dict[str, np.ndarray]: feature dict
+ """
+
+ # set extraction parameters
+ self.extractor.conf["max_keypoints"] = kwargs.get("max_keypoints", 512)
+ self.extractor.conf["keypoint_threshold"] = kwargs.get(
+ "keypoint_threshold", 0.0
+ )
+
+ pred = extract_features.extract(
+ self.extractor, img0, self.extract_conf["preprocessing"]
+ )
+ pred = self._convert_pred(pred)
+ # back to origin scale
+ s0 = pred["original_size"] / pred["size"]
+ pred["keypoints_orig"] = (
+ match_features.scale_keypoints(pred["keypoints"] + 0.5, s0) - 0.5
+ )
+ # TODO: rotate back
+ binarize = kwargs.get("binarize", False)
+ if binarize:
+ assert "descriptors" in pred
+ pred["descriptors"] = (pred["descriptors"] > 0).astype(np.uint8)
+ pred["descriptors"] = pred["descriptors"].T # N x DIM
+ return pred
+
+ @torch.inference_mode()
+ def forward(
+ self,
+ img0: np.ndarray,
+ img1: np.ndarray,
+ ) -> Dict[str, np.ndarray]:
+ """
+ Forward pass of the image matching API.
+
+ Args:
+ img0: A 3D NumPy array of shape (H, W, C) representing the first image.
+ Values are in the range [0, 1] and are in RGB mode.
+ img1: A 3D NumPy array of shape (H, W, C) representing the second image.
+ Values are in the range [0, 1] and are in RGB mode.
+
+ Returns:
+ A dictionary containing the following keys:
+ - image0_orig: The original image 0.
+ - image1_orig: The original image 1.
+ - keypoints0_orig: The keypoints detected in image 0.
+ - keypoints1_orig: The keypoints detected in image 1.
+ - mkeypoints0_orig: The raw matched keypoints in image 0.
+ - mkeypoints1_orig: The raw matched keypoints in image 1.
+ - mmkeypoints0_orig: The RANSAC inliers in image 0.
+ - mmkeypoints1_orig: The RANSAC inliers in image 1.
+ - mconf: The confidence scores for the raw matches.
+ - mmconf: The confidence scores for the RANSAC inliers.
+ """
+ # Take as input a pair of images (not a batch)
+ assert isinstance(img0, np.ndarray)
+ assert isinstance(img1, np.ndarray)
+ self.pred = self._forward(img0, img1)
+ if self.conf["ransac"]["enable"]:
+ self.pred = self._geometry_check(self.pred)
+ return self.pred
+
+ def _geometry_check(
+ self,
+ pred: Dict[str, Any],
+ ) -> Dict[str, Any]:
+ """
+ Filter matches using RANSAC. If keypoints are available, filter by keypoints.
+ If lines are available, filter by lines. If both keypoints and lines are
+ available, filter by keypoints.
+
+ Args:
+ pred (Dict[str, Any]): dict of matches, including original keypoints.
+ See :func:`filter_matches` for the expected keys.
+
+ Returns:
+ Dict[str, Any]: filtered matches
+ """
+ pred = filter_matches(
+ pred,
+ ransac_method=self.conf["ransac"]["method"],
+ ransac_reproj_threshold=self.conf["ransac"]["reproj_threshold"],
+ ransac_confidence=self.conf["ransac"]["confidence"],
+ ransac_max_iter=self.conf["ransac"]["max_iter"],
+ )
+ return pred
+
+ def visualize(
+ self,
+ log_path: Optional[Path] = None,
+ ) -> None:
+ """
+ Visualize the matches.
+
+ Args:
+ log_path (Path, optional): The directory to save the images. Defaults to None.
+
+ Returns:
+ None
+ """
+ if self.conf["dense"]:
+ postfix = str(self.conf["matcher"]["model"]["name"])
+ else:
+ postfix = "{}_{}".format(
+ str(self.conf["feature"]["model"]["name"]),
+ str(self.conf["matcher"]["model"]["name"]),
+ )
+ titles = [
+ "Image 0 - Keypoints",
+ "Image 1 - Keypoints",
+ ]
+ pred: Dict[str, Any] = self.pred
+ image0: np.ndarray = pred["image0_orig"]
+ image1: np.ndarray = pred["image1_orig"]
+ output_keypoints: np.ndarray = plot_images(
+ [image0, image1], titles=titles, dpi=300
+ )
+ if "keypoints0_orig" in pred.keys() and "keypoints1_orig" in pred.keys():
+ plot_keypoints([pred["keypoints0_orig"], pred["keypoints1_orig"]])
+ text: str = (
+ f"# keypoints0: {len(pred['keypoints0_orig'])} \n"
+ + f"# keypoints1: {len(pred['keypoints1_orig'])}"
+ )
+ add_text(0, text, fs=15)
+ output_keypoints = fig2im(output_keypoints)
+ # plot images with raw matches
+ titles = [
+ "Image 0 - Raw matched keypoints",
+ "Image 1 - Raw matched keypoints",
+ ]
+ output_matches_raw, num_matches_raw = display_matches(
+ pred, titles=titles, tag="KPTS_RAW"
+ )
+ # plot images with ransac matches
+ titles = [
+ "Image 0 - Ransac matched keypoints",
+ "Image 1 - Ransac matched keypoints",
+ ]
+ output_matches_ransac, num_matches_ransac = display_matches(
+ pred, titles=titles, tag="KPTS_RANSAC"
+ )
+ if log_path is not None:
+ img_keypoints_path: Path = log_path / f"img_keypoints_{postfix}.png"
+ img_matches_raw_path: Path = log_path / f"img_matches_raw_{postfix}.png"
+ img_matches_ransac_path: Path = (
+ log_path / f"img_matches_ransac_{postfix}.png"
+ )
+ cv2.imwrite(
+ str(img_keypoints_path),
+ output_keypoints[:, :, ::-1].copy(), # RGB -> BGR
+ )
+ cv2.imwrite(
+ str(img_matches_raw_path),
+ output_matches_raw[:, :, ::-1].copy(), # RGB -> BGR
+ )
+ cv2.imwrite(
+ str(img_matches_ransac_path),
+ output_matches_ransac[:, :, ::-1].copy(), # RGB -> BGR
+ )
+ plt.close("all")
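A usage sketch for ImageMatchingAPI (not part of the diff): it loads the sparse configuration from the api block of imcui/api/config/api.yaml, converts images to the RGB [0, 1] convention stated in the forward() docstring, and writes the visualizations; the image paths are placeholders:

    from pathlib import Path

    import cv2
    import numpy as np
    import yaml

    from imcui.api.core import ImageMatchingAPI


    def load_rgb01(path: str) -> np.ndarray:
        # BGR uint8 -> RGB float in [0, 1], as expected by forward().
        bgr = cv2.imread(path)
        return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0


    conf = yaml.safe_load(Path("imcui/api/config/api.yaml").read_text())["api"]
    api = ImageMatchingAPI(conf=conf, device="cpu")
    pred = api(load_rgb01("img0.jpg"), load_rgb01("img1.jpg"))
    print(f"RANSAC inliers: {len(pred['mmkeypoints0_orig'])}")
    api.visualize(log_path=Path("."))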
diff --git a/imcui/api/server.py b/imcui/api/server.py
new file mode 100644
index 0000000000000000000000000000000000000000..5d1932a639374fdf0533caa0e68a2c109ce1a07b
--- /dev/null
+++ b/imcui/api/server.py
@@ -0,0 +1,170 @@
+# server.py
+import warnings
+from pathlib import Path
+from typing import Union
+
+import numpy as np
+import ray
+import torch
+import yaml
+from fastapi import FastAPI, File, UploadFile
+from fastapi.responses import JSONResponse
+from PIL import Image
+from ray import serve
+
+from . import ImagesInput, to_base64_nparray
+from .core import ImageMatchingAPI
+from ..hloc import DEVICE
+from ..ui import get_version
+
+warnings.simplefilter("ignore")
+app = FastAPI()
+if ray.is_initialized():
+ ray.shutdown()
+ray.init(
+ dashboard_port=8265,
+ ignore_reinit_error=True,
+)
+serve.start(
+ http_options={"host": "0.0.0.0", "port": 8001},
+)
+
+num_gpus = 1 if torch.cuda.is_available() else 0
+
+
+@serve.deployment(
+ num_replicas=4, ray_actor_options={"num_cpus": 2, "num_gpus": num_gpus}
+)
+@serve.ingress(app)
+class ImageMatchingService:
+ def __init__(self, conf: dict, device: str):
+ self.conf = conf
+ self.api = ImageMatchingAPI(conf=conf, device=device)
+
+ @app.get("/")
+ def root(self):
+ return "Hello, world!"
+
+ @app.get("/version")
+ async def version(self):
+ return {"version": get_version()}
+
+ @app.post("/v1/match")
+ async def match(
+ self, image0: UploadFile = File(...), image1: UploadFile = File(...)
+ ):
+ """
+ Handle the image matching request and return the processed result.
+
+ Args:
+ image0 (UploadFile): The first image file for matching.
+ image1 (UploadFile): The second image file for matching.
+
+ Returns:
+ JSONResponse: A JSON response containing the filtered match results
+ or an error message in case of failure.
+ """
+ try:
+ # Load the images from the uploaded files
+ image0_array = self.load_image(image0)
+ image1_array = self.load_image(image1)
+
+ # Perform image matching using the API
+ output = self.api(image0_array, image1_array)
+
+ # Keys to skip in the output
+ skip_keys = ["image0_orig", "image1_orig"]
+
+ # Postprocess the output to filter unwanted data
+ pred = self.postprocess(output, skip_keys)
+
+ # Return the filtered prediction as a JSON response
+ return JSONResponse(content=pred)
+ except Exception as e:
+ # Return an error message with status code 500 in case of exception
+ return JSONResponse(content={"error": str(e)}, status_code=500)
+
+ @app.post("/v1/extract")
+ async def extract(self, input_info: ImagesInput):
+ """
+ Extract keypoints and descriptors from images.
+
+ Args:
+ input_info: An object containing the image data and options.
+
+ Returns:
+ A list of dictionaries containing the keypoints and descriptors.
+ """
+ try:
+ preds = []
+ for i, input_image in enumerate(input_info.data):
+ # Load the image from the input data
+ image_array = to_base64_nparray(input_image)
+ # Extract keypoints and descriptors
+ output = self.api.extract(
+ image_array,
+ max_keypoints=input_info.max_keypoints[i],
+ binarize=input_info.binarize,
+ )
+ # Do not return the original image and image_orig
+ # skip_keys = ["image", "image_orig"]
+ skip_keys = []
+
+ # Postprocess the output
+ pred = self.postprocess(output, skip_keys)
+ preds.append(pred)
+ # Return the list of extracted features
+ return JSONResponse(content=preds)
+ except Exception as e:
+ # Return an error message if an exception occurs
+ return JSONResponse(content={"error": str(e)}, status_code=500)
+
+ def load_image(self, file_path: Union[str, UploadFile]) -> np.ndarray:
+ """
+ Reads an image from a file path or an UploadFile object.
+
+ Args:
+ file_path: A file path or an UploadFile object.
+
+ Returns:
+ A numpy array representing the image.
+ """
+ if isinstance(file_path, str):
+ file_path = Path(file_path).resolve(strict=False)
+ else:
+ file_path = file_path.file
+ with Image.open(file_path) as img:
+ image_array = np.array(img)
+ return image_array
+
+ def postprocess(self, output: dict, skip_keys: list, binarize: bool = True) -> dict:
+ pred = {}
+ for key, value in output.items():
+ if key in skip_keys:
+ continue
+ if isinstance(value, np.ndarray):
+ pred[key] = value.tolist()
+ return pred
+
+ def run(self, host: str = "0.0.0.0", port: int = 8001):
+ import uvicorn
+
+ uvicorn.run(app, host=host, port=port)
+
+
+def read_config(config_path: Path) -> dict:
+ with open(config_path, "r") as f:
+ conf = yaml.safe_load(f)
+ return conf
+
+
+# api server
+conf = read_config(Path(__file__).parent / "config/api.yaml")
+service = ImageMatchingService.bind(conf=conf["api"], device=DEVICE)
+handle = serve.run(service, route_prefix="/")
+
+# serve run api.server_ray:service
+
+# build to generate config file
+# serve build api.server_ray:service -o api/config/ray.yaml
+# serve run api/config/ray.yaml
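For completeness, a minimal sketch (not part of the diff) of calling the running service defined above; 127.0.0.1:8001 matches the http_options configured in this file and the image paths are placeholders:

    import requests

    print(requests.get("http://127.0.0.1:8001/version").json())

    with open("img0.jpg", "rb") as f0, open("img1.jpg", "rb") as f1:
        r = requests.post(
            "http://127.0.0.1:8001/v1/match",
            files={"image0": f0, "image1": f1},
        )
    print(sorted(r.json().keys()))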
diff --git a/imcui/api/test/CMakeLists.txt b/imcui/api/test/CMakeLists.txt
new file mode 100644
index 0000000000000000000000000000000000000000..1da6c924042e615ebfa51e4e55de1dcaaddeff8b
--- /dev/null
+++ b/imcui/api/test/CMakeLists.txt
@@ -0,0 +1,17 @@
+cmake_minimum_required(VERSION 3.10)
+project(imatchui)
+
+set(OpenCV_DIR /usr/include/opencv4)
+find_package(OpenCV REQUIRED)
+
+find_package(Boost REQUIRED COMPONENTS system)
+if(Boost_FOUND)
+ include_directories(${Boost_INCLUDE_DIRS})
+endif()
+
+add_executable(client client.cpp)
+
+target_include_directories(client PRIVATE ${Boost_INCLUDE_DIRS}
+ ${OpenCV_INCLUDE_DIRS})
+
+target_link_libraries(client PRIVATE curl jsoncpp b64 ${OpenCV_LIBS})
diff --git a/imcui/api/test/build_and_run.sh b/imcui/api/test/build_and_run.sh
new file mode 100644
index 0000000000000000000000000000000000000000..40921bb9b925c67722247df7ab901668d713e888
--- /dev/null
+++ b/imcui/api/test/build_and_run.sh
@@ -0,0 +1,16 @@
+# g++ main.cpp -I/usr/include/opencv4 -lcurl -ljsoncpp -lb64 -lopencv_core -lopencv_imgcodecs -o main
+# sudo apt-get update
+# sudo apt-get install libboost-all-dev -y
+# sudo apt-get install libcurl4-openssl-dev libjsoncpp-dev libb64-dev libopencv-dev -y
+
+mkdir -p build && cd build
+cmake ..
+make -j12
+
+echo " ======== RUN DEMO ========"
+
+./client
+
+echo " ======== END DEMO ========"
+
+cd ..
diff --git a/imcui/api/test/client.cpp b/imcui/api/test/client.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..b970457fbbef8a81e16856d4e5eeadd1e70715c8
--- /dev/null
+++ b/imcui/api/test/client.cpp
@@ -0,0 +1,81 @@
+#include <curl/curl.h>
+#include <opencv2/opencv.hpp>
+#include "helper.h"
+
+int main() {
+ std::string img_path =
+ "../../../datasets/sacre_coeur/mapping_rot/02928139_3448003521_rot45.jpg";
+ cv::Mat original_img = cv::imread(img_path, cv::IMREAD_GRAYSCALE);
+
+ if (original_img.empty()) {
+ throw std::runtime_error("Failed to decode image");
+ }
+
+ // Convert the image to Base64
+ std::string base64_img = image_to_base64(original_img);
+
+ // Convert the Base64 back to an image
+ cv::Mat decoded_img = base64_to_image(base64_img);
+ cv::imwrite("decoded_image.jpg", decoded_img);
+ cv::imwrite("original_img.jpg", original_img);
+
+ // The images should be identical
+ if (cv::countNonZero(original_img != decoded_img) != 0) {
+ std::cerr << "The images are not identical" << std::endl;
+ return -1;
+ } else {
+ std::cout << "The images are identical!" << std::endl;
+ }
+
+ // construct params
+ APIParams params{.data = {base64_img},
+ .max_keypoints = {100, 100},
+ .timestamps = {"0", "1"},
+ .grayscale = {0},
+ .image_hw = {{480, 640}, {240, 320}},
+ .feature_type = 0,
+ .rotates = {0.0f, 0.0f},
+ .scales = {1.0f, 1.0f},
+ .reference_points = {{1.23e+2f, 1.2e+1f},
+ {5.0e-1f, 3.0e-1f},
+ {2.3e+2f, 2.2e+1f},
+ {6.0e-1f, 4.0e-1f}},
+ .binarize = {1}};
+
+ KeyPointResults kpts_results;
+
+ // Convert the parameters to JSON
+ Json::Value jsonData = paramsToJson(params);
+ std::string url = "http://127.0.0.1:8001/v1/extract";
+ Json::StreamWriterBuilder writer;
+ std::string output = Json::writeString(writer, jsonData);
+
+ CURL* curl;
+ CURLcode res;
+ std::string readBuffer;
+
+ curl_global_init(CURL_GLOBAL_DEFAULT);
+ curl = curl_easy_init();
+ if (curl) {
+ struct curl_slist* hs = NULL;
+ hs = curl_slist_append(hs, "Content-Type: application/json");
+ curl_easy_setopt(curl, CURLOPT_HTTPHEADER, hs);
+ curl_easy_setopt(curl, CURLOPT_URL, url.c_str());
+ curl_easy_setopt(curl, CURLOPT_POSTFIELDS, output.c_str());
+ curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, WriteCallback);
+ curl_easy_setopt(curl, CURLOPT_WRITEDATA, &readBuffer);
+ res = curl_easy_perform(curl);
+
+ if (res != CURLE_OK)
+ fprintf(
+ stderr, "curl_easy_perform() failed: %s\n", curl_easy_strerror(res));
+ else {
+ // std::cout << "Response from server: " << readBuffer << std::endl;
+ kpts_results = decode_response(readBuffer);
+ }
+ curl_easy_cleanup(curl);
+ }
+ curl_global_cleanup();
+
+ return 0;
+}
diff --git a/imcui/api/test/helper.h b/imcui/api/test/helper.h
new file mode 100644
index 0000000000000000000000000000000000000000..9cad50bc0e8de2ee1beb2a2d78ff9d791b3d0f21
--- /dev/null
+++ b/imcui/api/test/helper.h
@@ -0,0 +1,405 @@
+
+#include <curl/curl.h>
+#include <json/json.h>
+#include <opencv2/opencv.hpp>
+#include <iostream>
+#include <string>
+#include <vector>
+
+// base64 to image
+#include <boost/archive/iterators/base64_from_binary.hpp>
+#include <boost/archive/iterators/binary_from_base64.hpp>
+#include <boost/archive/iterators/transform_width.hpp>
+
+/// Parameters used in the API
+struct APIParams {
+ /// A list of images, base64 encoded
+ std::vector<std::string> data;
+
+ /// The maximum number of keypoints to detect for each image
+ std::vector<int> max_keypoints;
+
+ /// The timestamps of the images
+ std::vector<std::string> timestamps;
+
+ /// Whether to convert the images to grayscale
+ bool grayscale;
+
+ /// The height and width of each image
+ std::vector<std::vector<int>> image_hw;
+
+ /// The type of feature detector to use
+ int feature_type;
+
+ /// The rotations of the images
+ std::vector<double> rotates;
+
+ /// The scales of the images
+ std::vector<double> scales;
+
+ /// The reference points of the images
+ std::vector<std::vector<float>> reference_points;
+
+ /// Whether to binarize the descriptors
+ bool binarize;
+};
+
+/**
+ * @brief Contains the results of a keypoint detector.
+ *
+ * @details Stores the keypoints and descriptors for each image.
+ */
+class KeyPointResults {
+ public:
+ KeyPointResults() {
+ }
+
+ /**
+ * @brief Constructor.
+ *
+ * @param kp The keypoints for each image.
+ */
+ KeyPointResults(const std::vector<std::vector<cv::KeyPoint>>& kp,
+ const std::vector<cv::Mat>& desc)
+ : keypoints(kp), descriptors(desc) {
+ }
+
+ /**
+ * @brief Append keypoints to the result.
+ *
+ * @param kpts The keypoints to append.
+ */
+ inline void append_keypoints(std::vector<cv::KeyPoint>& kpts) {
+ keypoints.emplace_back(kpts);
+ }
+
+ /**
+ * @brief Append descriptors to the result.
+ *
+ * @param desc The descriptors to append.
+ */
+ inline void append_descriptors(cv::Mat& desc) {
+ descriptors.emplace_back(desc);
+ }
+
+ /**
+ * @brief Get the keypoints.
+ *
+ * @return The keypoints.
+ */
+ inline std::vector<std::vector<cv::KeyPoint>> get_keypoints() {
+ return keypoints;
+ }
+
+ /**
+ * @brief Get the descriptors.
+ *
+ * @return The descriptors.
+ */
+ inline std::vector<cv::Mat> get_descriptors() {
+ return descriptors;
+ }
+
+ private:
+ std::vector<std::vector<cv::KeyPoint>> keypoints;
+ std::vector<cv::Mat> descriptors;
+ std::vector<std::vector<float>> scores;
+};
+
+/**
+ * @brief Decodes a base64 encoded string.
+ *
+ * @param base64 The base64 encoded string to decode.
+ * @return The decoded string.
+ */
+std::string base64_decode(const std::string& base64) {
+ using namespace boost::archive::iterators;
+ using It = transform_width<binary_from_base64<std::string::const_iterator>, 8, 6>;
+
+ // Find the position of the last non-whitespace character
+ auto end = base64.find_last_not_of(" \t\n\r");
+ if (end != std::string::npos) {
+ // Move one past the last non-whitespace character
+ end += 1;
+ }
+
+ // Decode the base64 string and return the result
+ return std::string(It(base64.begin()), It(base64.begin() + end));
+}
+
+/**
+ * @brief Decodes a base64 string into an OpenCV image
+ *
+ * @param base64 The base64 encoded string
+ * @return The decoded OpenCV image
+ */
+cv::Mat base64_to_image(const std::string& base64) {
+ // Decode the base64 string
+ std::string decodedStr = base64_decode(base64);
+
+ // Decode the image
+ std::vector<uchar> data(decodedStr.begin(), decodedStr.end());
+ cv::Mat img = cv::imdecode(data, cv::IMREAD_GRAYSCALE);
+
+ // Check for errors
+ if (img.empty()) {
+ throw std::runtime_error("Failed to decode image");
+ }
+
+ return img;
+}
+
+/**
+ * @brief Encodes an OpenCV image into a base64 string
+ *
+ * This function takes an OpenCV image and encodes it into a base64 string.
+ * The image is first encoded as a PNG image, and then the resulting
+ * bytes are encoded as a base64 string.
+ *
+ * @param img The OpenCV image
+ * @return The base64 encoded string
+ *
+ * @throws std::runtime_error if the image is empty or encoding fails
+ */
+std::string image_to_base64(cv::Mat& img) {
+ if (img.empty()) {
+ throw std::runtime_error("Failed to read image");
+ }
+
+ // Encode the image as a PNG
+ std::vector<uchar> buf;
+ if (!cv::imencode(".png", img, buf)) {
+ throw std::runtime_error("Failed to encode image");
+ }
+
+ // Encode the bytes as a base64 string
+ using namespace boost::archive::iterators;
+ using It =
+ base64_from_binary<transform_width<std::vector<uchar>::const_iterator, 6, 8>>;
+ std::string base64(It(buf.begin()), It(buf.end()));
+
+ // Pad the string with '=' characters to a multiple of 4 bytes
+ base64.append((3 - buf.size() % 3) % 3, '=');
+
+ return base64;
+}
+
+/**
+ * @brief Callback function for libcurl to write data to a string
+ *
+ * This function is used as a callback for libcurl to write data to a string.
+ * It takes the contents, size, and nmemb as parameters, and writes the data to
+ * the string.
+ *
+ * @param contents The data to write
+ * @param size The size of the data
+ * @param nmemb The number of members in the data
+ * @param s The string to write the data to
+ * @return The number of bytes written
+ */
+size_t WriteCallback(void* contents, size_t size, size_t nmemb, std::string* s) {
+ size_t newLength = size * nmemb;
+ try {
+ // Resize the string to fit the new data
+ s->resize(s->size() + newLength);
+ } catch (std::bad_alloc& e) {
+ // If there's an error allocating memory, return 0
+ return 0;
+ }
+
+ // Copy the data to the string
+ std::copy(static_cast<char*>(contents),
+ static_cast<char*>(contents) + newLength,
+ s->begin() + s->size() - newLength);
+ return newLength;
+}
+
+// Helper functions
+
+/**
+ * @brief Helper function to convert a type to a Json::Value
+ *
+ * This function takes a value of type T and converts it to a Json::Value.
+ * It is used to simplify the process of converting a type to a Json::Value.
+ *
+ * @param val The value to convert
+ * @return The converted Json::Value
+ */
+template <typename T> Json::Value toJson(const T& val) {
+ return Json::Value(val);
+}
+
+/**
+ * @brief Converts a vector to a Json::Value
+ *
+ * This function takes a vector of type T and converts it to a Json::Value.
+ * Each element in the vector is appended to the Json::Value array.
+ *
+ * @param vec The vector to convert to Json::Value
+ * @return The Json::Value representing the vector
+ */
+template <typename T> Json::Value vectorToJson(const std::vector<T>& vec) {
+ Json::Value json(Json::arrayValue);
+ for (const auto& item : vec) {
+ json.append(item);
+ }
+ return json;
+}
+
+/**
+ * @brief Converts a nested vector to a Json::Value
+ *
+ * This function takes a nested vector of type T and converts it to a
+ * Json::Value. Each sub-vector is converted to a Json::Value array and appended
+ * to the main Json::Value array.
+ *
+ * @param vec The nested vector to convert to Json::Value
+ * @return The Json::Value representing the nested vector
+ */
+template <typename T>
+Json::Value nestedVectorToJson(const std::vector<std::vector<T>>& vec) {
+ Json::Value json(Json::arrayValue);
+ for (const auto& subVec : vec) {
+ json.append(vectorToJson(subVec));
+ }
+ return json;
+}
+
+/**
+ * @brief Converts the APIParams struct to a Json::Value
+ *
+ * This function takes an APIParams struct and converts it to a Json::Value.
+ * The Json::Value is a JSON object with the following fields:
+ * - data: a JSON array of base64 encoded images
+ * - max_keypoints: a JSON array of integers, max number of keypoints for each
+ * image
+ * - timestamps: a JSON array of timestamps, one for each image
+ * - grayscale: a JSON boolean, whether to convert images to grayscale
+ * - image_hw: a nested JSON array, each sub-array contains the height and width
+ * of an image
+ * - feature_type: a JSON integer, the type of feature detector to use
+ * - rotates: a JSON array of doubles, the rotation of each image
+ * - scales: a JSON array of doubles, the scale of each image
+ * - reference_points: a nested JSON array, each sub-array contains the
+ * reference points of an image
+ * - binarize: a JSON boolean, whether to binarize the descriptors
+ *
+ * @param params The APIParams struct to convert
+ * @return The Json::Value representing the APIParams struct
+ */
+Json::Value paramsToJson(const APIParams& params) {
+ Json::Value json;
+ json["data"] = vectorToJson(params.data);
+ json["max_keypoints"] = vectorToJson(params.max_keypoints);
+ json["timestamps"] = vectorToJson(params.timestamps);
+ json["grayscale"] = toJson(params.grayscale);
+ json["image_hw"] = nestedVectorToJson(params.image_hw);
+ json["feature_type"] = toJson(params.feature_type);
+ json["rotates"] = vectorToJson(params.rotates);
+ json["scales"] = vectorToJson(params.scales);
+ json["reference_points"] = nestedVectorToJson(params.reference_points);
+ json["binarize"] = toJson(params.binarize);
+ return json;
+}
+
+template <typename T> cv::Mat jsonToMat(Json::Value json) {
+ int rows = json.size();
+ int cols = json[0].size();
+
+ // Create a single array to hold all the data.
+ std::vector<T> data;
+ data.reserve(rows * cols);
+
+ for (int i = 0; i < rows; i++) {
+ for (int j = 0; j < cols; j++) {
+ data.push_back(static_cast<T>(json[i][j].asInt()));
+ }
+ }
+
+ // Create a cv::Mat object that points to the data.
+ cv::Mat mat(rows, cols, CV_8UC1,
+ data.data()); // Change the type if necessary.
+ // cv::Mat mat(cols, rows,CV_8UC1, data.data()); // Change the type if
+ // necessary.
+
+ return mat.clone(); // deep copy: data is a local buffer that would otherwise dangle
+}
+
+/**
+ * @brief Decodes the response of the server and prints the keypoints
+ *
+ * This function takes the response of the server, a JSON string, and decodes
+ * it. It then prints the keypoints and draws them on the original image.
+ *
+ * @param response The response of the server
+ * @return The keypoints and descriptors
+ */
+KeyPointResults decode_response(const std::string& response, bool viz = true) {
+ Json::CharReaderBuilder builder;
+ Json::CharReader* reader = builder.newCharReader();
+
+ Json::Value jsonData;
+ std::string errors;
+
+ // Parse the JSON response
+ bool parsingSuccessful = reader->parse(
+ response.c_str(), response.c_str() + response.size(), &jsonData, &errors);
+ delete reader;
+
+ if (!parsingSuccessful) {
+ // Handle error
+ std::cout << "Failed to parse the JSON, errors:" << std::endl;
+ std::cout << errors << std::endl;
+ return KeyPointResults();
+ }
+
+ KeyPointResults kpts_results;
+
+ // Iterate over the images
+ for (const auto& jsonItem : jsonData) {
+ auto jkeypoints = jsonItem["keypoints"];
+ auto jkeypoints_orig = jsonItem["keypoints_orig"];
+ auto jdescriptors = jsonItem["descriptors"];
+ auto jscores = jsonItem["scores"];
+ auto jimageSize = jsonItem["image_size"];
+ auto joriginalSize = jsonItem["original_size"];
+ auto jsize = jsonItem["size"];
+
+ std::vector<cv::KeyPoint> vkeypoints;
+ std::vector<float> vscores;
+
+ // Iterate over the keypoints
+ int counter = 0;
+ for (const auto& keypoint : jkeypoints_orig) {
+ if (counter < 10) {
+ // Print the first 10 keypoints
+ std::cout << keypoint[0].asFloat() << ", " << keypoint[1].asFloat()
+ << std::endl;
+ }
+ counter++;
+ // Convert the Json::Value to a cv::KeyPoint
+ vkeypoints.emplace_back(
+ cv::KeyPoint(keypoint[0].asFloat(), keypoint[1].asFloat(), 0.0));
+ }
+
+ if (viz && jsonItem.isMember("image_orig")) {
+ auto jimg_orig = jsonItem["image_orig"];
+ cv::Mat img = jsonToMat<uchar>(jimg_orig);
+ cv::imwrite("viz_image_orig.jpg", img);
+
+ // Draw keypoints on the image
+ cv::Mat imgWithKeypoints;
+ cv::drawKeypoints(img, vkeypoints, imgWithKeypoints, cv::Scalar(0, 0, 255));
+
+ // Write the image with keypoints
+ std::string filename = "viz_image_orig_keypoints.jpg";
+ cv::imwrite(filename, imgWithKeypoints);
+ }
+
+ // Iterate over the descriptors
+ cv::Mat descriptors = jsonToMat<uchar>(jdescriptors);
+ kpts_results.append_keypoints(vkeypoints);
+ kpts_results.append_descriptors(descriptors);
+ }
+ return kpts_results;
+}
diff --git a/imcui/assets/logo.webp b/imcui/assets/logo.webp
new file mode 100644
index 0000000000000000000000000000000000000000..0a799debc1a06cd6e500a8bccd0ddcef7eca0508
Binary files /dev/null and b/imcui/assets/logo.webp differ
diff --git a/imcui/datasets/.gitignore b/imcui/datasets/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/imcui/datasets/lines/terrace0.JPG b/imcui/datasets/lines/terrace0.JPG
new file mode 100644
index 0000000000000000000000000000000000000000..e3f688c4d14b490da30b57cd1312b144588efe32
--- /dev/null
+++ b/imcui/datasets/lines/terrace0.JPG
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c4198d3c47d8b397f3a40d58e32e516b8e4f9db4e989992dd069b374880412f5
+size 66986
diff --git a/imcui/datasets/lines/terrace1.JPG b/imcui/datasets/lines/terrace1.JPG
new file mode 100644
index 0000000000000000000000000000000000000000..4605fcf9bec3ed31c92b0a0f067d5cc16411fc9d
--- /dev/null
+++ b/imcui/datasets/lines/terrace1.JPG
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d94851889de709b8c8a11b2057e93627a21f623534e6ba2b3a1442b233fd7f20
+size 67363
diff --git a/imcui/datasets/sacre_coeur/README.md b/imcui/datasets/sacre_coeur/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..d69115f7f262f6d97aa52bed9083bf3374249645
--- /dev/null
+++ b/imcui/datasets/sacre_coeur/README.md
@@ -0,0 +1,3 @@
+# Sacre Coeur demo
+
+We provide here a subset of images depicting the Sacre Coeur. These images were obtained from the [Image Matching Challenge 2021](https://www.cs.ubc.ca/research/image-matching-challenge/2021/data/) and were originally collected by the [Yahoo Flickr Creative Commons 100M (YFCC) dataset](https://multimediacommons.wordpress.com/yfcc100m-core-dataset/).
diff --git a/imcui/datasets/sacre_coeur/mapping/02928139_3448003521.jpg b/imcui/datasets/sacre_coeur/mapping/02928139_3448003521.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..102589fa1a501f365fef0051f5ae97c42eb560ff
--- /dev/null
+++ b/imcui/datasets/sacre_coeur/mapping/02928139_3448003521.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4f52d9dcdb3ba9d8cf025025fb1be3f8f8d1ba0e0d84ab7eeb271215589ca608
+size 518060
diff --git a/imcui/datasets/sacre_coeur/mapping/03903474_1471484089.jpg b/imcui/datasets/sacre_coeur/mapping/03903474_1471484089.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..7b44afda94ce4044c6d13df6eb5b5b7218e7d38e
--- /dev/null
+++ b/imcui/datasets/sacre_coeur/mapping/03903474_1471484089.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:70969e7b24bbc8fe8de0a114d245ab9322df2bf10cd8533cc0d3ec6ec5f59c15
+size 348789
diff --git a/imcui/datasets/sacre_coeur/mapping/10265353_3838484249.jpg b/imcui/datasets/sacre_coeur/mapping/10265353_3838484249.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..0c91c5a55b830f456163fd09092a3887a1c37ef2
--- /dev/null
+++ b/imcui/datasets/sacre_coeur/mapping/10265353_3838484249.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d9c8ad3c8694703c08da1a4ec3980e1a35011690368f05e167092b33be8ff25c
+size 454554
diff --git a/imcui/datasets/sacre_coeur/mapping/17295357_9106075285.jpg b/imcui/datasets/sacre_coeur/mapping/17295357_9106075285.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..3d38e80b2a28c7d06b28cc9a36b97d656b60b912
--- /dev/null
+++ b/imcui/datasets/sacre_coeur/mapping/17295357_9106075285.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:54dff1885bf44b5c0e0c0ce702220832e99e5b30f38462d1ef5b9d4a0d794f98
+size 535133
diff --git a/imcui/datasets/sacre_coeur/mapping/32809961_8274055477.jpg b/imcui/datasets/sacre_coeur/mapping/32809961_8274055477.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..a8ca6a7b7f3dfc851bf3c06c1e469d12503418d1
--- /dev/null
+++ b/imcui/datasets/sacre_coeur/mapping/32809961_8274055477.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4c2bf90ab409dfed11b7c112c120917a51141010335530c8607b71010d3919fa
+size 458230
diff --git a/imcui/datasets/sacre_coeur/mapping/44120379_8371960244.jpg b/imcui/datasets/sacre_coeur/mapping/44120379_8371960244.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..c2d24e63b8141978e577f9c393d5214afe25e241
--- /dev/null
+++ b/imcui/datasets/sacre_coeur/mapping/44120379_8371960244.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fc7bbe1e4b47eeebf986e658d44704dbcce902a3af0d4853ef2e540c95a77659
+size 357768
diff --git a/imcui/datasets/sacre_coeur/mapping/51091044_3486849416.jpg b/imcui/datasets/sacre_coeur/mapping/51091044_3486849416.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..38d612bf9b595f66c788a3dd7f16ada85f34f99e
--- /dev/null
+++ b/imcui/datasets/sacre_coeur/mapping/51091044_3486849416.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fe4d46cb3feab25b4d184dca745921d7de7dfd18a0f3343660d8d3de07a0c054
+size 491952
diff --git a/imcui/datasets/sacre_coeur/mapping/60584745_2207571072.jpg b/imcui/datasets/sacre_coeur/mapping/60584745_2207571072.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..bfda213d8d27cb0d8840682497d48c88828604b6
--- /dev/null
+++ b/imcui/datasets/sacre_coeur/mapping/60584745_2207571072.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:186e6b8a426403e56a608429f7f63d5a56d32a895487f3da7cb9ceaff97f563f
+size 470664
diff --git a/imcui/datasets/sacre_coeur/mapping/71295362_4051449754.jpg b/imcui/datasets/sacre_coeur/mapping/71295362_4051449754.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..a8c27b023e9ac5a205e3e5e9ae83df437fcf1e44
--- /dev/null
+++ b/imcui/datasets/sacre_coeur/mapping/71295362_4051449754.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f7e7a59fab6f497bd94b066589a62cb74ceffe63ff3c8cd77b9b6b39bfc66bae
+size 368510
diff --git a/imcui/datasets/sacre_coeur/mapping/93341989_396310999.jpg b/imcui/datasets/sacre_coeur/mapping/93341989_396310999.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..2bab9cbcccfb317965c25ff0a5892560ac89c0ab
--- /dev/null
+++ b/imcui/datasets/sacre_coeur/mapping/93341989_396310999.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2932ea734fb26d5417aba1b1891760299ddf16facfd6a76a3712a7f11652e1f6
+size 364593
diff --git a/imcui/datasets/sacre_coeur/mapping_rot/02928139_3448003521_rot135.jpg b/imcui/datasets/sacre_coeur/mapping_rot/02928139_3448003521_rot135.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..065d5ac91a7e84130a452f7455235f6e47939ad1
--- /dev/null
+++ b/imcui/datasets/sacre_coeur/mapping_rot/02928139_3448003521_rot135.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2b73c04d53516237028bd6c74011d2b94eb09a99e2741ee2c491f070a4b9dd28
+size 134199
diff --git a/imcui/datasets/sacre_coeur/mapping_rot/02928139_3448003521_rot180.jpg b/imcui/datasets/sacre_coeur/mapping_rot/02928139_3448003521_rot180.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..5ab52d7abbd300abcc377758aee8fed52be247b0
--- /dev/null
+++ b/imcui/datasets/sacre_coeur/mapping_rot/02928139_3448003521_rot180.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:85a3bac5d7072d1bb06022b052d4c7b27e7216a8e02688ab5d9d954799254a06
+size 127812
diff --git a/imcui/datasets/sacre_coeur/mapping_rot/02928139_3448003521_rot225.jpg b/imcui/datasets/sacre_coeur/mapping_rot/02928139_3448003521_rot225.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..25e5aedf11797e410e45e55a79fbae55d03acb1d
--- /dev/null
+++ b/imcui/datasets/sacre_coeur/mapping_rot/02928139_3448003521_rot225.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d3d1ccde193e18620aa6da0aec5ddbbe612f30f2b398cd596e6585b9e114e45f
+size 133556
diff --git a/imcui/datasets/sacre_coeur/mapping_rot/02928139_3448003521_rot270.jpg b/imcui/datasets/sacre_coeur/mapping_rot/02928139_3448003521_rot270.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..43890e27b5281a7c404fe7ff57460cb2825b3517
--- /dev/null
+++ b/imcui/datasets/sacre_coeur/mapping_rot/02928139_3448003521_rot270.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b4d238d8a052162da641b0d54506c85641c91f6f95cdf471e79147f4c373162d
+size 115076
diff --git a/imcui/datasets/sacre_coeur/mapping_rot/02928139_3448003521_rot315.jpg b/imcui/datasets/sacre_coeur/mapping_rot/02928139_3448003521_rot315.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..9fef9b14bb670469d8b95d20e51641fff886c1ba
--- /dev/null
+++ b/imcui/datasets/sacre_coeur/mapping_rot/02928139_3448003521_rot315.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2110adbf1d114c498c5d011adfbe779c813e894500ff429a5b365447a3d9d106
+size 134430
diff --git a/imcui/datasets/sacre_coeur/mapping_rot/02928139_3448003521_rot45.jpg b/imcui/datasets/sacre_coeur/mapping_rot/02928139_3448003521_rot45.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..dbcdda46077574aa07505e0af4404b0121efdc53
--- /dev/null
+++ b/imcui/datasets/sacre_coeur/mapping_rot/02928139_3448003521_rot45.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:bbf08e8eadcadeed24a6843cf79ee0edf1771d09c71b7e1d387a997ea1922cfb
+size 133104
diff --git a/imcui/datasets/sacre_coeur/mapping_rot/02928139_3448003521_rot90.jpg b/imcui/datasets/sacre_coeur/mapping_rot/02928139_3448003521_rot90.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..59a633f7b02c3837eef980845ef309dbb882b611
--- /dev/null
+++ b/imcui/datasets/sacre_coeur/mapping_rot/02928139_3448003521_rot90.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ede5a1cf1b99a230b407e24e7bf1a7926cf684d889674d085b299f8937ee3ae3
+size 114747
diff --git a/imcui/datasets/sacre_coeur/mapping_rot/03903474_1471484089_rot135.jpg b/imcui/datasets/sacre_coeur/mapping_rot/03903474_1471484089_rot135.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..40e6ec78de130ca828d21dccc8ec18c5208a7047
--- /dev/null
+++ b/imcui/datasets/sacre_coeur/mapping_rot/03903474_1471484089_rot135.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5489c853477a1e0eab0dc7862261e5ff3bca8b18e0dc742fe8be04473e993bb2
+size 82274
diff --git a/imcui/datasets/sacre_coeur/mapping_rot/03903474_1471484089_rot180.jpg b/imcui/datasets/sacre_coeur/mapping_rot/03903474_1471484089_rot180.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..c29068ed9422942a412c604d693b987ada0c148d
--- /dev/null
+++ b/imcui/datasets/sacre_coeur/mapping_rot/03903474_1471484089_rot180.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2991086430d5880b01250023617c255dc165e14a00f199706445132ad7f3501e
+size 79432
diff --git a/imcui/datasets/sacre_coeur/mapping_rot/03903474_1471484089_rot225.jpg b/imcui/datasets/sacre_coeur/mapping_rot/03903474_1471484089_rot225.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..87e1fd29e80b566665e3c0b0f1a9e5f512a746f1
--- /dev/null
+++ b/imcui/datasets/sacre_coeur/mapping_rot/03903474_1471484089_rot225.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:dcca8fbd0b68c41fa987982e56fee055997c09e6e88f7de8d46034a8683c931e
+size 81912
diff --git a/imcui/datasets/sacre_coeur/mapping_rot/03903474_1471484089_rot270.jpg b/imcui/datasets/sacre_coeur/mapping_rot/03903474_1471484089_rot270.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..9fa21bc56b688822d8f72529f34be3fde624bb47
--- /dev/null
+++ b/imcui/datasets/sacre_coeur/mapping_rot/03903474_1471484089_rot270.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fbd874af9c4d1406a5adfa10f9454b1444d550dbe21bd57619df906a66e79571
+size 66472
diff --git a/imcui/datasets/sacre_coeur/mapping_rot/03903474_1471484089_rot315.jpg b/imcui/datasets/sacre_coeur/mapping_rot/03903474_1471484089_rot315.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..04514db419ca03a60e08b8edb9f74e0eb724b7b9
--- /dev/null
+++ b/imcui/datasets/sacre_coeur/mapping_rot/03903474_1471484089_rot315.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e6fc41beb78ec8e5adef2e601a344cddfe5fe353b4893e186b6036452f8c8198
+size 82027
diff --git a/imcui/datasets/sacre_coeur/mapping_rot/03903474_1471484089_rot45.jpg b/imcui/datasets/sacre_coeur/mapping_rot/03903474_1471484089_rot45.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..af0f8cbad53e36cbb2efd00b75a78e54d3df07d1
--- /dev/null
+++ b/imcui/datasets/sacre_coeur/mapping_rot/03903474_1471484089_rot45.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:cc55ac4f079176709f577564d390bf3e2c4e08f53378270465062d198699c100
+size 81684
diff --git a/imcui/datasets/sacre_coeur/mapping_rot/03903474_1471484089_rot90.jpg b/imcui/datasets/sacre_coeur/mapping_rot/03903474_1471484089_rot90.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..ceddbfe43526b5b904d91d43b1261f5f8950158d
--- /dev/null
+++ b/imcui/datasets/sacre_coeur/mapping_rot/03903474_1471484089_rot90.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f5ef044c01f3e94868480cab60841caa92a410c33643d9af6b60be14ee37d60f
+size 66563
diff --git a/imcui/datasets/sacre_coeur/mapping_rot/10265353_3838484249_rot135.jpg b/imcui/datasets/sacre_coeur/mapping_rot/10265353_3838484249_rot135.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..e0fb1c291af7ed2e45cf25909a17d6196a7927cc
--- /dev/null
+++ b/imcui/datasets/sacre_coeur/mapping_rot/10265353_3838484249_rot135.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7b5cd34c3b6ff6fed9c32fe94e9311d9fcd94e4b6ed23ff205fca44c570e1827
+size 96685
diff --git a/imcui/datasets/sacre_coeur/mapping_rot/10265353_3838484249_rot180.jpg b/imcui/datasets/sacre_coeur/mapping_rot/10265353_3838484249_rot180.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..e27c62e234992f2e424c826e9fdf93772fe5e411
--- /dev/null
+++ b/imcui/datasets/sacre_coeur/mapping_rot/10265353_3838484249_rot180.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:225c06d54c04d2ed14ec6d885a91760e9faaf298e0a42e776603d44759736b30
+size 104189
diff --git a/imcui/datasets/sacre_coeur/mapping_rot/10265353_3838484249_rot225.jpg b/imcui/datasets/sacre_coeur/mapping_rot/10265353_3838484249_rot225.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..a097e96b8d913f91bd03ff69cbf8ba8c168532b5
--- /dev/null
+++ b/imcui/datasets/sacre_coeur/mapping_rot/10265353_3838484249_rot225.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9bda2da8868f114170621d93595d1b4600dff8ddda6e8211842b5598ac866ed3
+size 101098
diff --git a/imcui/datasets/sacre_coeur/mapping_rot/10265353_3838484249_rot270.jpg b/imcui/datasets/sacre_coeur/mapping_rot/10265353_3838484249_rot270.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..20c7db4102c7c878ada35792e34fa49b1db177d2
--- /dev/null
+++ b/imcui/datasets/sacre_coeur/mapping_rot/10265353_3838484249_rot270.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:de7d6445cadee22083af4a99b970674ee2eb4286162ec65e9e134cb29d1b2748
+size 83143
diff --git a/imcui/datasets/sacre_coeur/mapping_rot/10265353_3838484249_rot315.jpg b/imcui/datasets/sacre_coeur/mapping_rot/10265353_3838484249_rot315.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..72bcfa0bc7488c9e6f2f7fa48414751fbab323aa
--- /dev/null
+++ b/imcui/datasets/sacre_coeur/mapping_rot/10265353_3838484249_rot315.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f14f08b0d72e7c3d946a50df958db3aa4058d1f9e5acb3ebb3d39a53503b1126
+size 96754
diff --git a/imcui/datasets/sacre_coeur/mapping_rot/10265353_3838484249_rot45.jpg b/imcui/datasets/sacre_coeur/mapping_rot/10265353_3838484249_rot45.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..6cc5f76d9d7435ff36d504fda63719e84cc801d2
--- /dev/null
+++ b/imcui/datasets/sacre_coeur/mapping_rot/10265353_3838484249_rot45.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b5ec4c6fa41d84c07c8fffad7f19630fe2ddb88b042e1a80b470a3566320cb77
+size 101953
diff --git a/imcui/datasets/sacre_coeur/mapping_rot/10265353_3838484249_rot90.jpg b/imcui/datasets/sacre_coeur/mapping_rot/10265353_3838484249_rot90.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..92e85895809de28aee13c9e588c4c4b642234a6c
--- /dev/null
+++ b/imcui/datasets/sacre_coeur/mapping_rot/10265353_3838484249_rot90.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:095ca3d56fc0ce8ef5879a4dfcd4f2af000e7df41a602d546f164896e5add1b5
+size 82961
diff --git a/imcui/datasets/sacre_coeur/mapping_rot/17295357_9106075285_rot135.jpg b/imcui/datasets/sacre_coeur/mapping_rot/17295357_9106075285_rot135.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..6f5675e6886adc9ab72f023ebaaaaea34addbd50
--- /dev/null
+++ b/imcui/datasets/sacre_coeur/mapping_rot/17295357_9106075285_rot135.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fed7e7f7adc94c88b9755fce914bd745855392fdce1057360731cad09adc2c3a
+size 119729
diff --git a/imcui/datasets/sacre_coeur/mapping_rot/17295357_9106075285_rot180.jpg b/imcui/datasets/sacre_coeur/mapping_rot/17295357_9106075285_rot180.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..708cf8448bf95d8d6f5b7bfc2cf8887b90423038
--- /dev/null
+++ b/imcui/datasets/sacre_coeur/mapping_rot/17295357_9106075285_rot180.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0a8f76fdbfe1223bc2342ec915b129f480425d2e1fd19a794f6e991146199928
+size 125780
diff --git a/imcui/datasets/sacre_coeur/mapping_rot/17295357_9106075285_rot225.jpg b/imcui/datasets/sacre_coeur/mapping_rot/17295357_9106075285_rot225.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..85c17020b2e6b48021409dce19ec33bb3a63f3b7
--- /dev/null
+++ b/imcui/datasets/sacre_coeur/mapping_rot/17295357_9106075285_rot225.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:335726de47840b164cd71f68ad9caa30b496a52783d4669dc5fab45da8e0427f
+size 111548
diff --git a/imcui/datasets/sacre_coeur/mapping_rot/17295357_9106075285_rot270.jpg b/imcui/datasets/sacre_coeur/mapping_rot/17295357_9106075285_rot270.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..22831f09d5d8d80dce6056e67d29c0cf0b5633ca
--- /dev/null
+++ b/imcui/datasets/sacre_coeur/mapping_rot/17295357_9106075285_rot270.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ffc1822c65765801d9320e0dff5fecc42a013c1d4ba50855314aed0789f1d8f2
+size 87725
diff --git a/imcui/datasets/sacre_coeur/mapping_rot/17295357_9106075285_rot315.jpg b/imcui/datasets/sacre_coeur/mapping_rot/17295357_9106075285_rot315.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..1f38d0ce7d34fb720efe81d0b7221278f3a2b2a8
--- /dev/null
+++ b/imcui/datasets/sacre_coeur/mapping_rot/17295357_9106075285_rot315.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:af42b452bca3f3e5305700f2a977a41286ffcc1ef8d4a992450c3e98cd1c1d05
+size 119644
diff --git a/imcui/datasets/sacre_coeur/mapping_rot/17295357_9106075285_rot45.jpg b/imcui/datasets/sacre_coeur/mapping_rot/17295357_9106075285_rot45.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..684dd21fbc6acb244f36bdb9b576681214de06e6
--- /dev/null
+++ b/imcui/datasets/sacre_coeur/mapping_rot/17295357_9106075285_rot45.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2452d8839721d7c54755be96e75213277de99b9b9edc6c47b3d5bc94583b42c1
+size 111275
diff --git a/imcui/datasets/sacre_coeur/mapping_rot/17295357_9106075285_rot90.jpg b/imcui/datasets/sacre_coeur/mapping_rot/17295357_9106075285_rot90.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..22bf9bdad97bfa626738eb88110a0ba12a967ed9
--- /dev/null
+++ b/imcui/datasets/sacre_coeur/mapping_rot/17295357_9106075285_rot90.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e7d580e865a744f52057137ca79848cc173be19aaa6db9dcaa60f9be46f4c465
+size 87490
diff --git a/imcui/datasets/sacre_coeur/mapping_rot/32809961_8274055477_rot135.jpg b/imcui/datasets/sacre_coeur/mapping_rot/32809961_8274055477_rot135.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..4702aed5bddf1b3db6fcf7479c39ec33514ec44a
--- /dev/null
+++ b/imcui/datasets/sacre_coeur/mapping_rot/32809961_8274055477_rot135.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:55c35e80d3783c5a93a724997ae62613460dcda6a6672593a8e9af6cc40f43c0
+size 98363
diff --git a/imcui/datasets/sacre_coeur/mapping_rot/32809961_8274055477_rot180.jpg b/imcui/datasets/sacre_coeur/mapping_rot/32809961_8274055477_rot180.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..5deeb08577e2b4517b23e1179f81f59b6ac3b1f5
--- /dev/null
+++ b/imcui/datasets/sacre_coeur/mapping_rot/32809961_8274055477_rot180.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:23d8d3c4dcfea64e6779f8e216377a080abfdd4a3bc0bf554783c8b7f540d27f
+size 102149
diff --git a/imcui/datasets/sacre_coeur/mapping_rot/32809961_8274055477_rot225.jpg b/imcui/datasets/sacre_coeur/mapping_rot/32809961_8274055477_rot225.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..3d6d066deabecff0cf4c1de14e199ab947244460
--- /dev/null
+++ b/imcui/datasets/sacre_coeur/mapping_rot/32809961_8274055477_rot225.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:516d60da5e943828a525407524cf9c2ee96f580afb713be63b5714e98be259b7
+size 92591
diff --git a/imcui/datasets/sacre_coeur/mapping_rot/32809961_8274055477_rot270.jpg b/imcui/datasets/sacre_coeur/mapping_rot/32809961_8274055477_rot270.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..7f21448e6c3b3b7a9c40879d9743e91037738266
--- /dev/null
+++ b/imcui/datasets/sacre_coeur/mapping_rot/32809961_8274055477_rot270.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1b0cf682dbbdd398ff1aa94e0b3ca7e4dafac1f373bbc889645ee6e442b78284
+size 79136
diff --git a/imcui/datasets/sacre_coeur/mapping_rot/32809961_8274055477_rot315.jpg b/imcui/datasets/sacre_coeur/mapping_rot/32809961_8274055477_rot315.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..87d31072cf1bea18293fba156dcdf676c12da898
--- /dev/null
+++ b/imcui/datasets/sacre_coeur/mapping_rot/32809961_8274055477_rot315.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:88dd8ab393c2b7fd7f1dbbff1a212698dccee8e05f3f8b225fdd8f7ff110f5f1
+size 98588
diff --git a/imcui/datasets/sacre_coeur/mapping_rot/32809961_8274055477_rot45.jpg b/imcui/datasets/sacre_coeur/mapping_rot/32809961_8274055477_rot45.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..992d761ff6b17d26d8e92a4e63df4206e611324c
--- /dev/null
+++ b/imcui/datasets/sacre_coeur/mapping_rot/32809961_8274055477_rot45.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:693ef906b4a3cdedcf12f36ce309cc8b6de911cc2d06ec7752547119ffeee530
+size 93063
diff --git a/imcui/datasets/sacre_coeur/mapping_rot/32809961_8274055477_rot90.jpg b/imcui/datasets/sacre_coeur/mapping_rot/32809961_8274055477_rot90.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..8768bc44b96253a42b3736ac6b2ebfe4b9a6f8dc
--- /dev/null
+++ b/imcui/datasets/sacre_coeur/mapping_rot/32809961_8274055477_rot90.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:06d0296ad6690beb85c52bf0df5a68b7fd6ffcf85072376a2090989d993dfbf8
+size 79729
diff --git a/imcui/datasets/sacre_coeur/mapping_rot/44120379_8371960244_rot135.jpg b/imcui/datasets/sacre_coeur/mapping_rot/44120379_8371960244_rot135.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..05fb0e416b83a4189619c894a9b0d0e06d0641b4
--- /dev/null
+++ b/imcui/datasets/sacre_coeur/mapping_rot/44120379_8371960244_rot135.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:bfc70746d378f4df6befc38e3fc12a946b74f020ed84f2d14148132bbf6f90c7
+size 73581
diff --git a/imcui/datasets/sacre_coeur/mapping_rot/44120379_8371960244_rot180.jpg b/imcui/datasets/sacre_coeur/mapping_rot/44120379_8371960244_rot180.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..e59b0eb3d7cda3065c3e85e29ed0b75bcb8c3e08
--- /dev/null
+++ b/imcui/datasets/sacre_coeur/mapping_rot/44120379_8371960244_rot180.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:56d78dbcbb813c47bd13a9dfbf00bb298c3fc63a5d3b807d53aae6b207360d70
+size 79424
diff --git a/imcui/datasets/sacre_coeur/mapping_rot/44120379_8371960244_rot225.jpg b/imcui/datasets/sacre_coeur/mapping_rot/44120379_8371960244_rot225.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..11c63767e7494af427722008d7eb2c93fbcd8399
--- /dev/null
+++ b/imcui/datasets/sacre_coeur/mapping_rot/44120379_8371960244_rot225.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b97eed0c96feac3e037d1bc072dcb490c81cad37e9d90c459cb72a8ea98c6406
+size 78572
diff --git a/imcui/datasets/sacre_coeur/mapping_rot/44120379_8371960244_rot270.jpg b/imcui/datasets/sacre_coeur/mapping_rot/44120379_8371960244_rot270.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..e9d45d0f296d87d726dfcdd742dadad727bda193
--- /dev/null
+++ b/imcui/datasets/sacre_coeur/mapping_rot/44120379_8371960244_rot270.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0da47e47bf521578a1c1a72dee34cfec13a597925ade92f5552c5278e79a7a14
+size 62148
diff --git a/imcui/datasets/sacre_coeur/mapping_rot/44120379_8371960244_rot315.jpg b/imcui/datasets/sacre_coeur/mapping_rot/44120379_8371960244_rot315.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..c3282a63d04c032c509ee4fe15ec468d0b4b8f15
--- /dev/null
+++ b/imcui/datasets/sacre_coeur/mapping_rot/44120379_8371960244_rot315.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:29e6c31d2c3f255cd95fa119512305aafd60b7daf8472a73b263ed4dae7184ac
+size 75286
diff --git a/imcui/datasets/sacre_coeur/mapping_rot/44120379_8371960244_rot45.jpg b/imcui/datasets/sacre_coeur/mapping_rot/44120379_8371960244_rot45.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..78dee48f38cb0595157636b72531eb8a3967ea06
--- /dev/null
+++ b/imcui/datasets/sacre_coeur/mapping_rot/44120379_8371960244_rot45.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:af952a441862a0cadfc77c5f00b9195d0684e3888f63c5a3e8375eda36802957
+size 78315
diff --git a/imcui/datasets/sacre_coeur/mapping_rot/44120379_8371960244_rot90.jpg b/imcui/datasets/sacre_coeur/mapping_rot/44120379_8371960244_rot90.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..d5765a113572929f7e99cd267a53c5457fb46272
--- /dev/null
+++ b/imcui/datasets/sacre_coeur/mapping_rot/44120379_8371960244_rot90.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9c68e15b9a7820af38f0a0d914454bb6b7add78a0946fd83d909fde2ca56f721
+size 62828
diff --git a/imcui/datasets/sacre_coeur/mapping_rot/51091044_3486849416_rot135.jpg b/imcui/datasets/sacre_coeur/mapping_rot/51091044_3486849416_rot135.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..41b2570b6f773ad1a8c2b06587bf3921a11827ee
--- /dev/null
+++ b/imcui/datasets/sacre_coeur/mapping_rot/51091044_3486849416_rot135.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3ccabe0811539e04ddb47879a1a864b0ccb0637c07de89e33b71df655fd08529
+size 103833
diff --git a/imcui/datasets/sacre_coeur/mapping_rot/51091044_3486849416_rot180.jpg b/imcui/datasets/sacre_coeur/mapping_rot/51091044_3486849416_rot180.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..8b75de67cf95e5598bdbde4dc6bb6371e99c158b
--- /dev/null
+++ b/imcui/datasets/sacre_coeur/mapping_rot/51091044_3486849416_rot180.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f629597738b0cfbb567acdecfbe37b18b9cbbdaf435ebd59cd219c28db199b38
+size 109762
diff --git a/imcui/datasets/sacre_coeur/mapping_rot/51091044_3486849416_rot225.jpg b/imcui/datasets/sacre_coeur/mapping_rot/51091044_3486849416_rot225.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..4d592f7c0490225cfea3546cc1d4ca5b5f384715
--- /dev/null
+++ b/imcui/datasets/sacre_coeur/mapping_rot/51091044_3486849416_rot225.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e44cb9873856cd2c785905bc53ef5347ad375ed539b1b16434ceb1f38548db95
+size 109015
diff --git a/imcui/datasets/sacre_coeur/mapping_rot/51091044_3486849416_rot270.jpg b/imcui/datasets/sacre_coeur/mapping_rot/51091044_3486849416_rot270.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..6b380dddec3ccaa824f39d0d2f86a466904921f3
--- /dev/null
+++ b/imcui/datasets/sacre_coeur/mapping_rot/51091044_3486849416_rot270.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c4c260b40e202d4c2144e4885eb61a52009ee2743c8d04cd01f1beb33105b4c0
+size 95253
diff --git a/imcui/datasets/sacre_coeur/mapping_rot/51091044_3486849416_rot315.jpg b/imcui/datasets/sacre_coeur/mapping_rot/51091044_3486849416_rot315.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..cf31081cae7339d043d9b37174d32c2f39bd46e7
--- /dev/null
+++ b/imcui/datasets/sacre_coeur/mapping_rot/51091044_3486849416_rot315.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c1c256473100905130d2f1d7e82cd2d951792fc5ff4180d2ba26edcc0e8d17f0
+size 103940
diff --git a/imcui/datasets/sacre_coeur/mapping_rot/51091044_3486849416_rot45.jpg b/imcui/datasets/sacre_coeur/mapping_rot/51091044_3486849416_rot45.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..52a36e0065b37f73d28c6d8e3b332d62d36605b7
--- /dev/null
+++ b/imcui/datasets/sacre_coeur/mapping_rot/51091044_3486849416_rot45.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:65d4c53b714fd5d1ba73604abcdb5a12e9844d2f4d0c8436d192c06fe23a949c
+size 108605
diff --git a/imcui/datasets/sacre_coeur/mapping_rot/51091044_3486849416_rot90.jpg b/imcui/datasets/sacre_coeur/mapping_rot/51091044_3486849416_rot90.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..b5d80a6ac1c5f37b1608a45b4588f15cc7e6f9a4
--- /dev/null
+++ b/imcui/datasets/sacre_coeur/mapping_rot/51091044_3486849416_rot90.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c4434b4fe8eba5ca92044385410c654c85ac1f67956eb8e757c509756435b33c
+size 95080
diff --git a/imcui/datasets/sacre_coeur/mapping_rot/60584745_2207571072_rot135.jpg b/imcui/datasets/sacre_coeur/mapping_rot/60584745_2207571072_rot135.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..7b7f2992ee6dfe32f9476b48a7c29c93c40aa5b1
--- /dev/null
+++ b/imcui/datasets/sacre_coeur/mapping_rot/60584745_2207571072_rot135.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f8e7e6ec0df2dbd9adbfa8b312cd6c5f04951df86eb4e4bd3f16acdb945b2d7b
+size 106398
diff --git a/imcui/datasets/sacre_coeur/mapping_rot/60584745_2207571072_rot180.jpg b/imcui/datasets/sacre_coeur/mapping_rot/60584745_2207571072_rot180.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..2356d37020783c7f06a3c75b0ac93f53244b2a21
--- /dev/null
+++ b/imcui/datasets/sacre_coeur/mapping_rot/60584745_2207571072_rot180.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a85ee7b1041a18f0adc7673283d25c7c19ab4fdbe02aa8b4aaf37d49e66c2fcc
+size 109233
diff --git a/imcui/datasets/sacre_coeur/mapping_rot/60584745_2207571072_rot225.jpg b/imcui/datasets/sacre_coeur/mapping_rot/60584745_2207571072_rot225.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..580ea14ffadcbef569d9947566cd6601ea5e3a31
--- /dev/null
+++ b/imcui/datasets/sacre_coeur/mapping_rot/60584745_2207571072_rot225.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9ba97ee2207a37249b3041a22318b8b85ac4ac1fcec3e2be6eabab9efcb526d7
+size 111988
diff --git a/imcui/datasets/sacre_coeur/mapping_rot/60584745_2207571072_rot270.jpg b/imcui/datasets/sacre_coeur/mapping_rot/60584745_2207571072_rot270.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..1bf37416ec0e9490efa4118ca15e3a5d3bc8f13b
--- /dev/null
+++ b/imcui/datasets/sacre_coeur/mapping_rot/60584745_2207571072_rot270.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ea908aa57fff053cb7689ebcdcb58f84cf2948dc34932f6cab6d7e4280f5929f
+size 93144
diff --git a/imcui/datasets/sacre_coeur/mapping_rot/60584745_2207571072_rot315.jpg b/imcui/datasets/sacre_coeur/mapping_rot/60584745_2207571072_rot315.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..5083e2f231c4489ad251789cf17fd279e94de123
--- /dev/null
+++ b/imcui/datasets/sacre_coeur/mapping_rot/60584745_2207571072_rot315.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7c83f58489127bbf5aefd0d2268a3a0ad5c4f4455f1b80bcb26f4e4f98547a52
+size 106249
diff --git a/imcui/datasets/sacre_coeur/mapping_rot/60584745_2207571072_rot45.jpg b/imcui/datasets/sacre_coeur/mapping_rot/60584745_2207571072_rot45.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..61d8dd667a41288b5466a590a4f7917113b4f8eb
--- /dev/null
+++ b/imcui/datasets/sacre_coeur/mapping_rot/60584745_2207571072_rot45.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:daaf1a21a08272e0870114927af705fbac76ec146ebceefe2e46085240e729af
+size 112103
diff --git a/imcui/datasets/sacre_coeur/mapping_rot/60584745_2207571072_rot90.jpg b/imcui/datasets/sacre_coeur/mapping_rot/60584745_2207571072_rot90.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..1f2337cd550356b50310d8f66900bb49a5cc1f78
--- /dev/null
+++ b/imcui/datasets/sacre_coeur/mapping_rot/60584745_2207571072_rot90.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:57e4c53f861c5bab1f7d67191b95e75b75bdb236d5690afb44c4494a77472c29
+size 92118
diff --git a/imcui/datasets/sacre_coeur/mapping_rot/71295362_4051449754_rot135.jpg b/imcui/datasets/sacre_coeur/mapping_rot/71295362_4051449754_rot135.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..3e3139d8c7e911e5c1597f8d4cc211ba0bc62c8e
--- /dev/null
+++ b/imcui/datasets/sacre_coeur/mapping_rot/71295362_4051449754_rot135.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e2367c7a1fc593fe02cefa52da073ce65090b11e7916b1b6b6af0a703d574070
+size 79924
diff --git a/imcui/datasets/sacre_coeur/mapping_rot/71295362_4051449754_rot180.jpg b/imcui/datasets/sacre_coeur/mapping_rot/71295362_4051449754_rot180.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..7b96616f4dfcefeedc876ded2f7a05bbf5989198
--- /dev/null
+++ b/imcui/datasets/sacre_coeur/mapping_rot/71295362_4051449754_rot180.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d93b4acc0d8fb7c150a6b5bcb50b346890af06557b3bfb19f4b86aa5d58c43ed
+size 81568
diff --git a/imcui/datasets/sacre_coeur/mapping_rot/71295362_4051449754_rot225.jpg b/imcui/datasets/sacre_coeur/mapping_rot/71295362_4051449754_rot225.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..8417a80887bd923f92e37187b193ac38626f7815
--- /dev/null
+++ b/imcui/datasets/sacre_coeur/mapping_rot/71295362_4051449754_rot225.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:da12abfef15e8a462c01ad88947d0aca45bae8d972366dd55eecbeb5febb25cb
+size 80924
diff --git a/imcui/datasets/sacre_coeur/mapping_rot/71295362_4051449754_rot270.jpg b/imcui/datasets/sacre_coeur/mapping_rot/71295362_4051449754_rot270.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..6444b8d7c844fc82804fba969839318da0c922cd
--- /dev/null
+++ b/imcui/datasets/sacre_coeur/mapping_rot/71295362_4051449754_rot270.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:42280a19b1792682a6743dfdbca8e4fb84c1354c999d68b103b1a9a460e47ca5
+size 63425
diff --git a/imcui/datasets/sacre_coeur/mapping_rot/71295362_4051449754_rot315.jpg b/imcui/datasets/sacre_coeur/mapping_rot/71295362_4051449754_rot315.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..3ff18e6d4b0eabbac042eefdcfa4fadca6921c73
--- /dev/null
+++ b/imcui/datasets/sacre_coeur/mapping_rot/71295362_4051449754_rot315.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b9ff7cbd7f10f3f6e8734aab05b20a1fd1e0d4b00a0b44c14ed4f34a3f64642c
+size 80202
diff --git a/imcui/datasets/sacre_coeur/mapping_rot/71295362_4051449754_rot45.jpg b/imcui/datasets/sacre_coeur/mapping_rot/71295362_4051449754_rot45.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..959468fa7db4f3dde517e0ee4c59f7dc87ece238
--- /dev/null
+++ b/imcui/datasets/sacre_coeur/mapping_rot/71295362_4051449754_rot45.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9cc498c25f22f75f099a9b2cf54477e5204e3f91bd0a3d0d4498f401847f4ac8
+size 80296
diff --git a/imcui/datasets/sacre_coeur/mapping_rot/71295362_4051449754_rot90.jpg b/imcui/datasets/sacre_coeur/mapping_rot/71295362_4051449754_rot90.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..16054efff56ce617e17ff93d2c8029a7264f5951
--- /dev/null
+++ b/imcui/datasets/sacre_coeur/mapping_rot/71295362_4051449754_rot90.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:63eb14c5563ee17dadd4bed86040f229fdee7162c31a22c1be0cb543922442f7
+size 62355
diff --git a/imcui/datasets/sacre_coeur/mapping_rot/93341989_396310999_rot135.jpg b/imcui/datasets/sacre_coeur/mapping_rot/93341989_396310999_rot135.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..d73dfa10ebdeb97740b773790fd1fa313bb0556a
--- /dev/null
+++ b/imcui/datasets/sacre_coeur/mapping_rot/93341989_396310999_rot135.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2a21943a12405413ce1e3d5d0978cd5fc2497051aadf1b32ee5c1516980b063a
+size 79615
diff --git a/imcui/datasets/sacre_coeur/mapping_rot/93341989_396310999_rot180.jpg b/imcui/datasets/sacre_coeur/mapping_rot/93341989_396310999_rot180.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..0d6a78bb7cc2274b2714c97706cac35bd601e933
--- /dev/null
+++ b/imcui/datasets/sacre_coeur/mapping_rot/93341989_396310999_rot180.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e7e5a61e764ba79cb839b4bbe3bd7ecd279edf9ccc868914d44d562f2b5f6bb7
+size 81016
diff --git a/imcui/datasets/sacre_coeur/mapping_rot/93341989_396310999_rot225.jpg b/imcui/datasets/sacre_coeur/mapping_rot/93341989_396310999_rot225.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..83b348efda25d253fa197b29138ee2034124c4b6
--- /dev/null
+++ b/imcui/datasets/sacre_coeur/mapping_rot/93341989_396310999_rot225.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:58a09610d7ae581682b7d0c8ce27795cfd00321874fd61b3f2bbe735331048c6
+size 81638
diff --git a/imcui/datasets/sacre_coeur/mapping_rot/93341989_396310999_rot270.jpg b/imcui/datasets/sacre_coeur/mapping_rot/93341989_396310999_rot270.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..8d8f611ec766ae6c4c0ea023cff1d17b0851def6
--- /dev/null
+++ b/imcui/datasets/sacre_coeur/mapping_rot/93341989_396310999_rot270.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:65b3cbda1b6975f9a7da6d4b41fe2132816e64c739f1de6d76cd77495396bfc0
+size 68811
diff --git a/imcui/datasets/sacre_coeur/mapping_rot/93341989_396310999_rot315.jpg b/imcui/datasets/sacre_coeur/mapping_rot/93341989_396310999_rot315.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..93754c8836ed911f93106c77c94047ef9c265343
--- /dev/null
+++ b/imcui/datasets/sacre_coeur/mapping_rot/93341989_396310999_rot315.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:228fc8095ed6614f1b4047dc66d2467c827229124c45a645c1916b80c6045c72
+size 78909
diff --git a/imcui/datasets/sacre_coeur/mapping_rot/93341989_396310999_rot45.jpg b/imcui/datasets/sacre_coeur/mapping_rot/93341989_396310999_rot45.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..bf980bf62804f81ed9384c09897091609b63d933
--- /dev/null
+++ b/imcui/datasets/sacre_coeur/mapping_rot/93341989_396310999_rot45.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5b6a6c9006f4c7e38d5b4b232c787b4bf1136d3776c241ba6aadb3c60a5edf5e
+size 81485
diff --git a/imcui/datasets/sacre_coeur/mapping_rot/93341989_396310999_rot90.jpg b/imcui/datasets/sacre_coeur/mapping_rot/93341989_396310999_rot90.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..522bc81a65f743be92e2ecc73d4c1ef41b8fdfbd
--- /dev/null
+++ b/imcui/datasets/sacre_coeur/mapping_rot/93341989_396310999_rot90.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2650e275a279ccb1024c2222e2955f2f7c8cef15200712e89bf844a64d547513
+size 68109
diff --git a/imcui/datasets/sacre_coeur/mapping_scale/02928139_3448003521_scale0.2.jpg b/imcui/datasets/sacre_coeur/mapping_scale/02928139_3448003521_scale0.2.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..d27fe8d6f813e1525d581e2665feff4704710e98
--- /dev/null
+++ b/imcui/datasets/sacre_coeur/mapping_scale/02928139_3448003521_scale0.2.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:65632a80a46b26408c91314f2c78431be3df41d8c740e1b1b2614c98f17091a0
+size 21890
diff --git a/imcui/datasets/sacre_coeur/mapping_scale/02928139_3448003521_scale0.4.jpg b/imcui/datasets/sacre_coeur/mapping_scale/02928139_3448003521_scale0.4.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..26c82557bab1bb66463603bcbb23fb179ab1456a
--- /dev/null
+++ b/imcui/datasets/sacre_coeur/mapping_scale/02928139_3448003521_scale0.4.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b95394b52baf736f467b28633536223abaf34b4219fe9c9541cd902c47539de7
+size 40252
diff --git a/imcui/datasets/sacre_coeur/mapping_scale/02928139_3448003521_scale0.5.jpg b/imcui/datasets/sacre_coeur/mapping_scale/02928139_3448003521_scale0.5.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..4e76cf942182757709234e6b44fdef9ea255035f
--- /dev/null
+++ b/imcui/datasets/sacre_coeur/mapping_scale/02928139_3448003521_scale0.5.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:de85408453e630cc7d28cf91df73499e8c98caffef40b1b196ef019d59f09486
+size 52742
diff --git a/imcui/datasets/sacre_coeur/mapping_scale/02928139_3448003521_scale0.6.jpg b/imcui/datasets/sacre_coeur/mapping_scale/02928139_3448003521_scale0.6.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..3e87b7854fded8fa9bd423e47b93da3e4e76414d
--- /dev/null
+++ b/imcui/datasets/sacre_coeur/mapping_scale/02928139_3448003521_scale0.6.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:35bc6e9a0c6b570e9d19c91d124422f17e26012908ee089eede78fd33ce7c051
+size 65170
diff --git a/imcui/datasets/sacre_coeur/mapping_scale/02928139_3448003521_scale0.7.jpg b/imcui/datasets/sacre_coeur/mapping_scale/02928139_3448003521_scale0.7.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..4cdcbbf48d27668ecc6e2c4ddf109a6cf1405e69
--- /dev/null
+++ b/imcui/datasets/sacre_coeur/mapping_scale/02928139_3448003521_scale0.7.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6120fd5688d41a61a6d00605b2355a1775e0dba3b3ade8b5ad32bae35de897f8
+size 80670
diff --git a/imcui/datasets/sacre_coeur/mapping_scale/02928139_3448003521_scale0.8.jpg b/imcui/datasets/sacre_coeur/mapping_scale/02928139_3448003521_scale0.8.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..e60569c28a50e261f661a20816fe5ab4dd265754
--- /dev/null
+++ b/imcui/datasets/sacre_coeur/mapping_scale/02928139_3448003521_scale0.8.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ea90074a3c2680086fad36ea3a875168af69d56b055a1166940fbd299abb8e9f
+size 96216
diff --git a/imcui/datasets/sacre_coeur/mapping_scale/02928139_3448003521_scale0.9.jpg b/imcui/datasets/sacre_coeur/mapping_scale/02928139_3448003521_scale0.9.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..022cf34ec13b5accbfcaf5c6ac9adc9ed63f4387
--- /dev/null
+++ b/imcui/datasets/sacre_coeur/mapping_scale/02928139_3448003521_scale0.9.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4371f89589919aa8bc7f4f093bf88a87994c45de67a471b5b1a6412eaa8e54e4
+size 112868
diff --git a/imcui/datasets/sacre_coeur/mapping_scale/03903474_1471484089_scale0.2.jpg b/imcui/datasets/sacre_coeur/mapping_scale/03903474_1471484089_scale0.2.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..ae239cdbed0fa5094181cfa85977a4cc5dd401b1
--- /dev/null
+++ b/imcui/datasets/sacre_coeur/mapping_scale/03903474_1471484089_scale0.2.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:394013dede3a72016cfefb51dd45b84439523d427de47ef43f45bbd2ce07d1b1
+size 18346
diff --git a/imcui/datasets/sacre_coeur/mapping_scale/03903474_1471484089_scale0.4.jpg b/imcui/datasets/sacre_coeur/mapping_scale/03903474_1471484089_scale0.4.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..7816ec7253dac16d13da7ab2bb05eb918cf90ec9
--- /dev/null
+++ b/imcui/datasets/sacre_coeur/mapping_scale/03903474_1471484089_scale0.4.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fd82ebb523f88dddb3cfcdeb20f149cb18a70062cd67eb394b6442a8f39ec9b1
+size 29287
diff --git a/imcui/datasets/sacre_coeur/mapping_scale/03903474_1471484089_scale0.5.jpg b/imcui/datasets/sacre_coeur/mapping_scale/03903474_1471484089_scale0.5.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..75421d467f758d274c43191085d6b99ec6987741
--- /dev/null
+++ b/imcui/datasets/sacre_coeur/mapping_scale/03903474_1471484089_scale0.5.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8927ec9505d0043ae8030752132cf47253ebc034d93394e4b972961a5307f151
+size 37883
diff --git a/imcui/datasets/sacre_coeur/mapping_scale/03903474_1471484089_scale0.6.jpg b/imcui/datasets/sacre_coeur/mapping_scale/03903474_1471484089_scale0.6.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..4b77784210fb86c4db0f587c70061dff44d72e06
--- /dev/null
+++ b/imcui/datasets/sacre_coeur/mapping_scale/03903474_1471484089_scale0.6.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:27f0c0465ba88b9296383ed21e9b4fa4116d145fae52913706a134a58247488b
+size 44461
diff --git a/imcui/datasets/sacre_coeur/mapping_scale/03903474_1471484089_scale0.7.jpg b/imcui/datasets/sacre_coeur/mapping_scale/03903474_1471484089_scale0.7.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..13b6b9d3b1d66e9356a0b97567f221882ea9bc7e
--- /dev/null
+++ b/imcui/datasets/sacre_coeur/mapping_scale/03903474_1471484089_scale0.7.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1344408ec4e79932b9390a3340d2241eedc535c67a2811bda7390a64ddadaa75
+size 52812
diff --git a/imcui/datasets/sacre_coeur/mapping_scale/03903474_1471484089_scale0.8.jpg b/imcui/datasets/sacre_coeur/mapping_scale/03903474_1471484089_scale0.8.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..c1a7d765ecf17c513f6c716b8ba1b9421d8ad5c1
--- /dev/null
+++ b/imcui/datasets/sacre_coeur/mapping_scale/03903474_1471484089_scale0.8.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:dc4b253faf50d53822fdca31acfd269c5fd3df28ef6bd8182a753ca0ff626172
+size 63930
diff --git a/imcui/datasets/sacre_coeur/mapping_scale/03903474_1471484089_scale0.9.jpg b/imcui/datasets/sacre_coeur/mapping_scale/03903474_1471484089_scale0.9.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..735e7c4475c84772447f96b0ca40703482044e94
--- /dev/null
+++ b/imcui/datasets/sacre_coeur/mapping_scale/03903474_1471484089_scale0.9.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b0822b92d430c4d42076b0469a74e8cc6b379657b9b310b5f7674c6a47af6e8f
+size 73652
diff --git a/imcui/datasets/sacre_coeur/mapping_scale/10265353_3838484249_scale0.2.jpg b/imcui/datasets/sacre_coeur/mapping_scale/10265353_3838484249_scale0.2.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..06204bf54d2df8157c21f061496e644b6fa47b23
--- /dev/null
+++ b/imcui/datasets/sacre_coeur/mapping_scale/10265353_3838484249_scale0.2.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7bcef8bb651e67a969b5b17f55dd9b0827aaa69f42efe0da096787cae0ebc51e
+size 19404
diff --git a/imcui/datasets/sacre_coeur/mapping_scale/10265353_3838484249_scale0.4.jpg b/imcui/datasets/sacre_coeur/mapping_scale/10265353_3838484249_scale0.4.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..d16aaf7bdbc453b42fc2d8c72fd9cf90bd5e3317
--- /dev/null
+++ b/imcui/datasets/sacre_coeur/mapping_scale/10265353_3838484249_scale0.4.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:bdebf79043e7d889ecebf7f975d57d1202f49977dff193b158497ebcfc809a8f
+size 33947
diff --git a/imcui/datasets/sacre_coeur/mapping_scale/10265353_3838484249_scale0.5.jpg b/imcui/datasets/sacre_coeur/mapping_scale/10265353_3838484249_scale0.5.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..c24e10ded3cca68b8851073f8eb19f36ee07acc2
--- /dev/null
+++ b/imcui/datasets/sacre_coeur/mapping_scale/10265353_3838484249_scale0.5.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:89bf1005e2bb111bf8bd9914bff13eb3252bf1427426bc2925c4f618bd8428b6
+size 44233
diff --git a/imcui/datasets/sacre_coeur/mapping_scale/10265353_3838484249_scale0.6.jpg b/imcui/datasets/sacre_coeur/mapping_scale/10265353_3838484249_scale0.6.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..cd4f55bd2b3587f56a8c60ab0f4c6c0298d9b197
--- /dev/null
+++ b/imcui/datasets/sacre_coeur/mapping_scale/10265353_3838484249_scale0.6.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b69b20d8695366ccbf25d8f40c7d722e4d3a48ecb99163f46c7d9684d7e7b2e5
+size 55447
diff --git a/imcui/datasets/sacre_coeur/mapping_scale/10265353_3838484249_scale0.7.jpg b/imcui/datasets/sacre_coeur/mapping_scale/10265353_3838484249_scale0.7.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..da7c59860bdb2812ea138b55965ebc0711c80f5f
--- /dev/null
+++ b/imcui/datasets/sacre_coeur/mapping_scale/10265353_3838484249_scale0.7.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f92f6225593780ad721b1c14dcb8772146433b5dbf7f0b927a1e2ac86bb8c4bd
+size 65029
diff --git a/imcui/datasets/sacre_coeur/mapping_scale/10265353_3838484249_scale0.8.jpg b/imcui/datasets/sacre_coeur/mapping_scale/10265353_3838484249_scale0.8.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..807d37f1572761bf8d876384e33577f26ca257b4
--- /dev/null
+++ b/imcui/datasets/sacre_coeur/mapping_scale/10265353_3838484249_scale0.8.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f91146bc682fc502cbefd103e44e68a6865fbf69f28946de2f63e14aed277eb5
+size 78525
diff --git a/imcui/datasets/sacre_coeur/mapping_scale/10265353_3838484249_scale0.9.jpg b/imcui/datasets/sacre_coeur/mapping_scale/10265353_3838484249_scale0.9.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..2c8384e6688273d4f1f31a80a3ef47a4e89a21b8
--- /dev/null
+++ b/imcui/datasets/sacre_coeur/mapping_scale/10265353_3838484249_scale0.9.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:406f548692e16ac5449d1353d1b73b396285d5af82928e271186ac4692fa6b61
+size 92311
diff --git a/imcui/datasets/sacre_coeur/mapping_scale/17295357_9106075285_scale0.2.jpg b/imcui/datasets/sacre_coeur/mapping_scale/17295357_9106075285_scale0.2.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..52b5e6a247240106fd5e5c886b7a22cf7de939c3
--- /dev/null
+++ b/imcui/datasets/sacre_coeur/mapping_scale/17295357_9106075285_scale0.2.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6f1667f20a37d2ebc6ecf94c418078492ab465277a97c313317ff01bbafad6ae
+size 19279
diff --git a/imcui/datasets/sacre_coeur/mapping_scale/17295357_9106075285_scale0.4.jpg b/imcui/datasets/sacre_coeur/mapping_scale/17295357_9106075285_scale0.4.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..1ecf2d221eb38ae1bebfce5e722f0999325ac519
--- /dev/null
+++ b/imcui/datasets/sacre_coeur/mapping_scale/17295357_9106075285_scale0.4.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:107c4cc346f8b1f992e369aedf3abdafea4698539194f40eade5b5776eef5dd5
+size 36526
diff --git a/imcui/datasets/sacre_coeur/mapping_scale/17295357_9106075285_scale0.5.jpg b/imcui/datasets/sacre_coeur/mapping_scale/17295357_9106075285_scale0.5.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..6976776119ba2f20a89e7cee4e04fdea6894677a
--- /dev/null
+++ b/imcui/datasets/sacre_coeur/mapping_scale/17295357_9106075285_scale0.5.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:407f9290be844f5831a52b926b618eff02a4db7bc6597f063051eeeaf39d151b
+size 47460
diff --git a/imcui/datasets/sacre_coeur/mapping_scale/17295357_9106075285_scale0.6.jpg b/imcui/datasets/sacre_coeur/mapping_scale/17295357_9106075285_scale0.6.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..bcd235542d03cdbb3480a44ee181c01ca48a6558
--- /dev/null
+++ b/imcui/datasets/sacre_coeur/mapping_scale/17295357_9106075285_scale0.6.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:967e490e89b602ecaf87863d49d590da4536c10059537aaec2d43f4038e3d771
+size 62509
diff --git a/imcui/datasets/sacre_coeur/mapping_scale/17295357_9106075285_scale0.7.jpg b/imcui/datasets/sacre_coeur/mapping_scale/17295357_9106075285_scale0.7.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..b987b854cdb4d833cf6aa74df88927a975c44362
--- /dev/null
+++ b/imcui/datasets/sacre_coeur/mapping_scale/17295357_9106075285_scale0.7.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a5afc12af54be3cb0dba0922a7f566ee71135cea8369b2e4842e7ae42fe64b53
+size 76306
diff --git a/imcui/datasets/sacre_coeur/mapping_scale/17295357_9106075285_scale0.8.jpg b/imcui/datasets/sacre_coeur/mapping_scale/17295357_9106075285_scale0.8.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..877bb2c923b487518c640125c5197b0edf0e0bfb
--- /dev/null
+++ b/imcui/datasets/sacre_coeur/mapping_scale/17295357_9106075285_scale0.8.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:61f266200d9a1f2e9aefb1e3713aa70ccd0eed2d5c5af996aef120dcf18b0780
+size 91762
diff --git a/imcui/datasets/sacre_coeur/mapping_scale/17295357_9106075285_scale0.9.jpg b/imcui/datasets/sacre_coeur/mapping_scale/17295357_9106075285_scale0.9.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..4d0c041790cf9f3de32e608c04e589917d546556
--- /dev/null
+++ b/imcui/datasets/sacre_coeur/mapping_scale/17295357_9106075285_scale0.9.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:12aaae2e7a12b49ec263d9d2a29ff730ff90fbed1d0c5ce4789550695a891cbe
+size 107595
diff --git a/imcui/datasets/sacre_coeur/mapping_scale/32809961_8274055477_scale0.2.jpg b/imcui/datasets/sacre_coeur/mapping_scale/32809961_8274055477_scale0.2.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..5f9313cfc25615fe53efa05ffcdfe54b62ccdcdb
--- /dev/null
+++ b/imcui/datasets/sacre_coeur/mapping_scale/32809961_8274055477_scale0.2.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b7e8dc5991d5af4b134fe7534c596aad5b88265981bba3d90156fff2fa3b7bd8
+size 19569
diff --git a/imcui/datasets/sacre_coeur/mapping_scale/32809961_8274055477_scale0.4.jpg b/imcui/datasets/sacre_coeur/mapping_scale/32809961_8274055477_scale0.4.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..e50f64360fe0805698b11202f12caaf6f939049c
--- /dev/null
+++ b/imcui/datasets/sacre_coeur/mapping_scale/32809961_8274055477_scale0.4.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ad8c1fc561fc5967d406c8ca8d0cce65344838edc46cc72e85515732a1d1a266
+size 34434
diff --git a/imcui/datasets/sacre_coeur/mapping_scale/32809961_8274055477_scale0.5.jpg b/imcui/datasets/sacre_coeur/mapping_scale/32809961_8274055477_scale0.5.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..8ad560b70e0923b134bdf4fd35c75036b9c1feb6
--- /dev/null
+++ b/imcui/datasets/sacre_coeur/mapping_scale/32809961_8274055477_scale0.5.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:132a899e4725b23eb0d0740bfda5528e68fdfcafef932b6f48b1a47e1d906bdb
+size 44062
diff --git a/imcui/datasets/sacre_coeur/mapping_scale/32809961_8274055477_scale0.6.jpg b/imcui/datasets/sacre_coeur/mapping_scale/32809961_8274055477_scale0.6.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..7713f99755d6fb12f24d59ace7c6bb30af8d0140
--- /dev/null
+++ b/imcui/datasets/sacre_coeur/mapping_scale/32809961_8274055477_scale0.6.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c6ca571a824efff8b35a3aa234756d652aa9a54c33dfd98c2c5854eeb8bc2987
+size 56056
diff --git a/imcui/datasets/sacre_coeur/mapping_scale/32809961_8274055477_scale0.7.jpg b/imcui/datasets/sacre_coeur/mapping_scale/32809961_8274055477_scale0.7.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..cd312a40e414dabc167d982e7283703e8d533bae
--- /dev/null
+++ b/imcui/datasets/sacre_coeur/mapping_scale/32809961_8274055477_scale0.7.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:66177f53a20ae430a858f7aeb2049049da83b910c277b5344113508a41859179
+size 65465
diff --git a/imcui/datasets/sacre_coeur/mapping_scale/32809961_8274055477_scale0.8.jpg b/imcui/datasets/sacre_coeur/mapping_scale/32809961_8274055477_scale0.8.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..f40be8bd5671519bfa985a592d4b4a2984582d4d
--- /dev/null
+++ b/imcui/datasets/sacre_coeur/mapping_scale/32809961_8274055477_scale0.8.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a6db2f05714a37dfd8ab21d52d05afd990b47271c1469e0c965d35dcd81e2c39
+size 77708
diff --git a/imcui/datasets/sacre_coeur/mapping_scale/32809961_8274055477_scale0.9.jpg b/imcui/datasets/sacre_coeur/mapping_scale/32809961_8274055477_scale0.9.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..9ecb09015abbcbf32907ff8584f9427161e3ed64
--- /dev/null
+++ b/imcui/datasets/sacre_coeur/mapping_scale/32809961_8274055477_scale0.9.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5a3a58d711cfe35f937f5b9e9914c7dcb427065d685a816c4dc5da610403d981
+size 92547
diff --git a/imcui/datasets/sacre_coeur/mapping_scale/44120379_8371960244_scale0.2.jpg b/imcui/datasets/sacre_coeur/mapping_scale/44120379_8371960244_scale0.2.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..e3cd992394f9ace830fe1df7911f04c80e1abe0a
--- /dev/null
+++ b/imcui/datasets/sacre_coeur/mapping_scale/44120379_8371960244_scale0.2.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:bea2d09de8c92e49916c72f57ef8d0bc39fb63b51cdd2568706cbe01028a4480
+size 18119
diff --git a/imcui/datasets/sacre_coeur/mapping_scale/44120379_8371960244_scale0.4.jpg b/imcui/datasets/sacre_coeur/mapping_scale/44120379_8371960244_scale0.4.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..cf490019463c6e01cc23fc2962c468c633a53aaa
--- /dev/null
+++ b/imcui/datasets/sacre_coeur/mapping_scale/44120379_8371960244_scale0.4.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0312626c40b38df7b64471873e2efaebc90a5a34cfe9778da7b7a9a6793975be
+size 28961
diff --git a/imcui/datasets/sacre_coeur/mapping_scale/44120379_8371960244_scale0.5.jpg b/imcui/datasets/sacre_coeur/mapping_scale/44120379_8371960244_scale0.5.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..f7ac21a0279982bb6f0a8ddf715f794b1577fc71
--- /dev/null
+++ b/imcui/datasets/sacre_coeur/mapping_scale/44120379_8371960244_scale0.5.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:65ab9561483b9b02fd62304ce34c337107494050635e8a16bbf4b6426f602dbb
+size 36523
diff --git a/imcui/datasets/sacre_coeur/mapping_scale/44120379_8371960244_scale0.6.jpg b/imcui/datasets/sacre_coeur/mapping_scale/44120379_8371960244_scale0.6.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..bfe516659198de7a9fd0d8bfe1f961804d9c0e00
--- /dev/null
+++ b/imcui/datasets/sacre_coeur/mapping_scale/44120379_8371960244_scale0.6.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:36fe588a2cb429b6655f5fff6253fec5d3272fa98125ffd1f84bdf0bab0f63e5
+size 44134
diff --git a/imcui/datasets/sacre_coeur/mapping_scale/44120379_8371960244_scale0.7.jpg b/imcui/datasets/sacre_coeur/mapping_scale/44120379_8371960244_scale0.7.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..44574a69046cb62cedeffee56c34ae91cf160c7f
--- /dev/null
+++ b/imcui/datasets/sacre_coeur/mapping_scale/44120379_8371960244_scale0.7.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:04b725bd69e0078cc6becbe04958d8492d7845a290d72ade34db03eeaebb9464
+size 52784
diff --git a/imcui/datasets/sacre_coeur/mapping_scale/44120379_8371960244_scale0.8.jpg b/imcui/datasets/sacre_coeur/mapping_scale/44120379_8371960244_scale0.8.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..7ccebe81457b6d7a75583ad0bd3e1a080d3be65b
--- /dev/null
+++ b/imcui/datasets/sacre_coeur/mapping_scale/44120379_8371960244_scale0.8.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2febc81d0fede4f6fe42ca4232395798b183358193e1d63a71264b1f57a126af
+size 62516
diff --git a/imcui/datasets/sacre_coeur/mapping_scale/44120379_8371960244_scale0.9.jpg b/imcui/datasets/sacre_coeur/mapping_scale/44120379_8371960244_scale0.9.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..e78525e313a8a97bcbc16e035b73b4e69949b1f6
--- /dev/null
+++ b/imcui/datasets/sacre_coeur/mapping_scale/44120379_8371960244_scale0.9.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c0cd73a22d565e467441ed367a794237e5aeb51c64e12d31eba36e1109f009f5
+size 72003
diff --git a/imcui/datasets/sacre_coeur/mapping_scale/51091044_3486849416_scale0.2.jpg b/imcui/datasets/sacre_coeur/mapping_scale/51091044_3486849416_scale0.2.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..508aa2e17c3eafd0cbfdbeb5f7e8383afd1eff49
--- /dev/null
+++ b/imcui/datasets/sacre_coeur/mapping_scale/51091044_3486849416_scale0.2.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b8fd1041d9aa8b9236dc9a0fe637d96972897295236961984402858a8688dea4
+size 19103
diff --git a/imcui/datasets/sacre_coeur/mapping_scale/51091044_3486849416_scale0.4.jpg b/imcui/datasets/sacre_coeur/mapping_scale/51091044_3486849416_scale0.4.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..f749213b5d42af17b5e8fa85be5f2a0099900482
--- /dev/null
+++ b/imcui/datasets/sacre_coeur/mapping_scale/51091044_3486849416_scale0.4.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:27d421ef6ed9d78b08d0b4796cfc604d3511a274e2c4be4bd3e910efcc236f07
+size 34317
diff --git a/imcui/datasets/sacre_coeur/mapping_scale/51091044_3486849416_scale0.5.jpg b/imcui/datasets/sacre_coeur/mapping_scale/51091044_3486849416_scale0.5.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..0af327e6da018baaf5e862a25b002fec00ee1481
--- /dev/null
+++ b/imcui/datasets/sacre_coeur/mapping_scale/51091044_3486849416_scale0.5.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d5480d172fc730905eb0de44d779ec45ecd11bd20907310833490eb8e816674e
+size 45635
diff --git a/imcui/datasets/sacre_coeur/mapping_scale/51091044_3486849416_scale0.6.jpg b/imcui/datasets/sacre_coeur/mapping_scale/51091044_3486849416_scale0.6.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..87a0920fc65ff0f224724e182baffa5efabb2727
--- /dev/null
+++ b/imcui/datasets/sacre_coeur/mapping_scale/51091044_3486849416_scale0.6.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4c139a0636fd9bb1c78c6999e09bd5f7229359a618754244553f1f3ff225e503
+size 55257
diff --git a/imcui/datasets/sacre_coeur/mapping_scale/51091044_3486849416_scale0.7.jpg b/imcui/datasets/sacre_coeur/mapping_scale/51091044_3486849416_scale0.7.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..c9429d3921a2f46d2bc2fa1b845c7b5ca8c34b4d
--- /dev/null
+++ b/imcui/datasets/sacre_coeur/mapping_scale/51091044_3486849416_scale0.7.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5c7f19e2d7bf07bd24785cd33a4b39a81551b0af81d6478a8e6caffd9c8702a4
+size 68433
diff --git a/imcui/datasets/sacre_coeur/mapping_scale/51091044_3486849416_scale0.8.jpg b/imcui/datasets/sacre_coeur/mapping_scale/51091044_3486849416_scale0.8.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..816eb35ce7690a8f338698a32bf882f77c226515
--- /dev/null
+++ b/imcui/datasets/sacre_coeur/mapping_scale/51091044_3486849416_scale0.8.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4a54a79a52a214e511e7677570a548d84949989279e9d45f1ef2a3b64ec706d4
+size 83428
diff --git a/imcui/datasets/sacre_coeur/mapping_scale/51091044_3486849416_scale0.9.jpg b/imcui/datasets/sacre_coeur/mapping_scale/51091044_3486849416_scale0.9.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..37c041c8c84ee938aa69efdba406855e71853961
--- /dev/null
+++ b/imcui/datasets/sacre_coeur/mapping_scale/51091044_3486849416_scale0.9.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:44b0fccc48763c01af6114d224630f0e9e9efb3545a4a44c0ea5d0bfc9388bc6
+size 97890
diff --git a/imcui/datasets/sacre_coeur/mapping_scale/60584745_2207571072_scale0.2.jpg b/imcui/datasets/sacre_coeur/mapping_scale/60584745_2207571072_scale0.2.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..a2fac4209480c24b542cdfae90539ca30cb34730
--- /dev/null
+++ b/imcui/datasets/sacre_coeur/mapping_scale/60584745_2207571072_scale0.2.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0a0621b195be135487a03013a9bbefd1c8d06219e89255cff3e9a7f353ddf512
+size 20967
diff --git a/imcui/datasets/sacre_coeur/mapping_scale/60584745_2207571072_scale0.4.jpg b/imcui/datasets/sacre_coeur/mapping_scale/60584745_2207571072_scale0.4.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..b9bb9631981f5525ccab16e8c57e9a5fc1cd8e23
--- /dev/null
+++ b/imcui/datasets/sacre_coeur/mapping_scale/60584745_2207571072_scale0.4.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f463bf0ca416815588402256bc118649a4021a3613fda9fee7a498fe3c088211
+size 36897
diff --git a/imcui/datasets/sacre_coeur/mapping_scale/60584745_2207571072_scale0.5.jpg b/imcui/datasets/sacre_coeur/mapping_scale/60584745_2207571072_scale0.5.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..ef911222845da2a6392efd777b08888b386e36ff
--- /dev/null
+++ b/imcui/datasets/sacre_coeur/mapping_scale/60584745_2207571072_scale0.5.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:98f2e11e0dfdc72f62b23da6b67f233a57625c9bf1b3c71d4cb832749754cc8b
+size 46863
diff --git a/imcui/datasets/sacre_coeur/mapping_scale/60584745_2207571072_scale0.6.jpg b/imcui/datasets/sacre_coeur/mapping_scale/60584745_2207571072_scale0.6.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..3b3ed4111f137eba6204f7678b434da1a38b8b5f
--- /dev/null
+++ b/imcui/datasets/sacre_coeur/mapping_scale/60584745_2207571072_scale0.6.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3fa19ae2f619f40298b7d5afe1dd9828a870f48fd288d63d5ea348b2a86da89c
+size 58545
diff --git a/imcui/datasets/sacre_coeur/mapping_scale/60584745_2207571072_scale0.7.jpg b/imcui/datasets/sacre_coeur/mapping_scale/60584745_2207571072_scale0.7.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..fd502a98dd92d22c8246e9a2febad4e63fd49573
--- /dev/null
+++ b/imcui/datasets/sacre_coeur/mapping_scale/60584745_2207571072_scale0.7.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:44a1ff7cfbd900bbac8ad369232ff71b559a154e05f1f6fb43c5da24f248952d
+size 70581
diff --git a/imcui/datasets/sacre_coeur/mapping_scale/60584745_2207571072_scale0.8.jpg b/imcui/datasets/sacre_coeur/mapping_scale/60584745_2207571072_scale0.8.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..e746d3e8df82d3022cb1e3192f51b067f8d22020
--- /dev/null
+++ b/imcui/datasets/sacre_coeur/mapping_scale/60584745_2207571072_scale0.8.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:351b645bd3243f8676059d5096249bde326df7794c43c0910dc63c704d77dc28
+size 83757
diff --git a/imcui/datasets/sacre_coeur/mapping_scale/60584745_2207571072_scale0.9.jpg b/imcui/datasets/sacre_coeur/mapping_scale/60584745_2207571072_scale0.9.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..cf18599607edcacf5f706fdb677fe63d2c6d2b37
--- /dev/null
+++ b/imcui/datasets/sacre_coeur/mapping_scale/60584745_2207571072_scale0.9.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:592c76689f52ed651ee40640a85e826597c6a24b548b67183cc0cb556b4727e4
+size 97236
diff --git a/imcui/datasets/sacre_coeur/mapping_scale/71295362_4051449754_scale0.2.jpg b/imcui/datasets/sacre_coeur/mapping_scale/71295362_4051449754_scale0.2.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..cf7917d922cdbd30d8a415c3f71c4f1970aaace5
--- /dev/null
+++ b/imcui/datasets/sacre_coeur/mapping_scale/71295362_4051449754_scale0.2.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5240431398627407ddf7f59f7d213ccf910ddf749bfe1d54f421e917b072e53b
+size 17605
diff --git a/imcui/datasets/sacre_coeur/mapping_scale/71295362_4051449754_scale0.4.jpg b/imcui/datasets/sacre_coeur/mapping_scale/71295362_4051449754_scale0.4.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..7f0604df7dde207aa3571ad68eeb37ea9c9646f6
--- /dev/null
+++ b/imcui/datasets/sacre_coeur/mapping_scale/71295362_4051449754_scale0.4.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d3436f60a6f7f69affcd67ac82c29845f31637420be8948baa938907c7e5f0a7
+size 28290
diff --git a/imcui/datasets/sacre_coeur/mapping_scale/71295362_4051449754_scale0.5.jpg b/imcui/datasets/sacre_coeur/mapping_scale/71295362_4051449754_scale0.5.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..fee6de7cbea4e841d6811d89a80baf7c8592b79c
--- /dev/null
+++ b/imcui/datasets/sacre_coeur/mapping_scale/71295362_4051449754_scale0.5.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:bb8aa81f49cda021a0d92c97ff05f960274ab442f652d6041b163f86e2b35778
+size 37127
diff --git a/imcui/datasets/sacre_coeur/mapping_scale/71295362_4051449754_scale0.6.jpg b/imcui/datasets/sacre_coeur/mapping_scale/71295362_4051449754_scale0.6.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..b6fcf0e854278c71762b85c0eff8b74a72a8cf78
--- /dev/null
+++ b/imcui/datasets/sacre_coeur/mapping_scale/71295362_4051449754_scale0.6.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a8e931c5c4dfe453042cca041f37fa9cdc168fbe83f7a2ccd3f7a8aced93a44e
+size 45866
diff --git a/imcui/datasets/sacre_coeur/mapping_scale/71295362_4051449754_scale0.7.jpg b/imcui/datasets/sacre_coeur/mapping_scale/71295362_4051449754_scale0.7.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..d2b263a84eb2fa2ecc1ff8e8e516ee47d8676904
--- /dev/null
+++ b/imcui/datasets/sacre_coeur/mapping_scale/71295362_4051449754_scale0.7.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:60fc4db91f0f34837601f066cd334840ce391b6d78b524ce28b7cc6a328227e2
+size 53956
diff --git a/imcui/datasets/sacre_coeur/mapping_scale/71295362_4051449754_scale0.8.jpg b/imcui/datasets/sacre_coeur/mapping_scale/71295362_4051449754_scale0.8.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..7ec774e14c9d0a09ba1d51f8834f5719a132b89e
--- /dev/null
+++ b/imcui/datasets/sacre_coeur/mapping_scale/71295362_4051449754_scale0.8.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ff223aa66f53e55ecd076c15ccb9395bee4d1baa8286d2acf236979787417626
+size 63961
diff --git a/imcui/datasets/sacre_coeur/mapping_scale/71295362_4051449754_scale0.9.jpg b/imcui/datasets/sacre_coeur/mapping_scale/71295362_4051449754_scale0.9.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..7c9d39b86fe5e72e9fe4061af79fe9bda7096a86
--- /dev/null
+++ b/imcui/datasets/sacre_coeur/mapping_scale/71295362_4051449754_scale0.9.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c54568031f5558c7dfff0bb2ed2dcc679a97f89045512f44cb24869d7a2783d5
+size 74563
diff --git a/imcui/datasets/sacre_coeur/mapping_scale/93341989_396310999_scale0.2.jpg b/imcui/datasets/sacre_coeur/mapping_scale/93341989_396310999_scale0.2.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..554e7360ffea8643309bc60a4397b6f0a33600b1
--- /dev/null
+++ b/imcui/datasets/sacre_coeur/mapping_scale/93341989_396310999_scale0.2.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7d5e5f2d3cf4b6ebb1a33cf667622d5f10f4a6f6905b6d317c8793368768747b
+size 18058
diff --git a/imcui/datasets/sacre_coeur/mapping_scale/93341989_396310999_scale0.4.jpg b/imcui/datasets/sacre_coeur/mapping_scale/93341989_396310999_scale0.4.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..b9c001354e3160e3849485eb79cfd9bd91bd559b
--- /dev/null
+++ b/imcui/datasets/sacre_coeur/mapping_scale/93341989_396310999_scale0.4.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f7613d4ae7a19dd03b69cce7ae7d9084568b96de51a6008a0e7ba1684b2e1d18
+size 28921
diff --git a/imcui/datasets/sacre_coeur/mapping_scale/93341989_396310999_scale0.5.jpg b/imcui/datasets/sacre_coeur/mapping_scale/93341989_396310999_scale0.5.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..654b114c83a846ddc7d3004d4ecb33bb48f9dd95
--- /dev/null
+++ b/imcui/datasets/sacre_coeur/mapping_scale/93341989_396310999_scale0.5.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:120f479a4a45f31a7d01a4401c7fa96a2028a2521049bd4c9effd141ac3c5498
+size 36174
diff --git a/imcui/datasets/sacre_coeur/mapping_scale/93341989_396310999_scale0.6.jpg b/imcui/datasets/sacre_coeur/mapping_scale/93341989_396310999_scale0.6.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..548a5f7d5ed22f5a93fd58723d47a979e3d0a9fe
--- /dev/null
+++ b/imcui/datasets/sacre_coeur/mapping_scale/93341989_396310999_scale0.6.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4d490f0baec9460ad46f4fbfab2fe8fa8f8db37019e1c055300008f3de9421ec
+size 43542
diff --git a/imcui/datasets/sacre_coeur/mapping_scale/93341989_396310999_scale0.7.jpg b/imcui/datasets/sacre_coeur/mapping_scale/93341989_396310999_scale0.7.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..a6ff24a52f698f00033277c0fb149c1e3ea92fea
--- /dev/null
+++ b/imcui/datasets/sacre_coeur/mapping_scale/93341989_396310999_scale0.7.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:68697e9c9cda399c66c43efad0c89ea2500ef50b86be536cea9cdaeebd3742db
+size 53236
diff --git a/imcui/datasets/sacre_coeur/mapping_scale/93341989_396310999_scale0.8.jpg b/imcui/datasets/sacre_coeur/mapping_scale/93341989_396310999_scale0.8.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..cad1a64516428929de5303d83eb9cd0cc6592f42
--- /dev/null
+++ b/imcui/datasets/sacre_coeur/mapping_scale/93341989_396310999_scale0.8.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:161e5e0aa1681fb3de437895f262811260bc7e7632245094663c8d527ed7634d
+size 61754
diff --git a/imcui/datasets/sacre_coeur/mapping_scale/93341989_396310999_scale0.9.jpg b/imcui/datasets/sacre_coeur/mapping_scale/93341989_396310999_scale0.9.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..f7e9b31632c13cf41edcfe060442ec44acdae202
--- /dev/null
+++ b/imcui/datasets/sacre_coeur/mapping_scale/93341989_396310999_scale0.9.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:bf1f3fd9123cd4d9246dadefdd6b25e3ee8efddd070aa9642b79545b5845898d
+size 72921
diff --git a/imcui/datasets/wxbs_benchmark/.EVD/EVD/1/adam.png b/imcui/datasets/wxbs_benchmark/.EVD/EVD/1/adam.png
new file mode 100644
index 0000000000000000000000000000000000000000..d203797ac4dbe32b56bf858bf574ccba301b2890
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.EVD/EVD/1/adam.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2ab84505b1b857968e76515f3a70f8a137de06f1005ba67f7ba13560f26591a1
+size 204675
diff --git a/imcui/datasets/wxbs_benchmark/.EVD/EVD/1/cafe.png b/imcui/datasets/wxbs_benchmark/.EVD/EVD/1/cafe.png
new file mode 100644
index 0000000000000000000000000000000000000000..721255df1cc096fd276591904a6635bbc9335934
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.EVD/EVD/1/cafe.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8b75eb268829f64dc13acb5e14becd4fc151c867a0b7340c8149e6d0aa96b511
+size 677870
diff --git a/imcui/datasets/wxbs_benchmark/.EVD/EVD/1/cat.png b/imcui/datasets/wxbs_benchmark/.EVD/EVD/1/cat.png
new file mode 100644
index 0000000000000000000000000000000000000000..44a96af0b33b709f46d89d15ad9758f0207eefee
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.EVD/EVD/1/cat.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0b9e70a57eb37ba42064edfd72d85103a3824af4e14fb6e8196453834ceb8eaf
+size 858661
diff --git a/imcui/datasets/wxbs_benchmark/.EVD/EVD/1/dum.png b/imcui/datasets/wxbs_benchmark/.EVD/EVD/1/dum.png
new file mode 100644
index 0000000000000000000000000000000000000000..d4bb86ef5af0caa7ed1bd233f45f8e79a44a1c45
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.EVD/EVD/1/dum.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3c025b6776edf6e75dee24bba0e2eb2edd9541129a5c615dcb26f63da9d3e2c5
+size 1227154
diff --git a/imcui/datasets/wxbs_benchmark/.EVD/EVD/1/face.png b/imcui/datasets/wxbs_benchmark/.EVD/EVD/1/face.png
new file mode 100644
index 0000000000000000000000000000000000000000..af9eb4b47e1e1adb17f29a6e3d06cd03b51b00dd
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.EVD/EVD/1/face.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8819f9fb3a5d3c2742685ad52429319268271f8d68e25a22832373c6791ef629
+size 1653453
diff --git a/imcui/datasets/wxbs_benchmark/.EVD/EVD/1/fox.png b/imcui/datasets/wxbs_benchmark/.EVD/EVD/1/fox.png
new file mode 100644
index 0000000000000000000000000000000000000000..0ee9996c5a6778e44ec12d68fabf4a2d8dc45cef
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.EVD/EVD/1/fox.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c4ea0030c9270f97c5d31fa303f80cb36ccfbdbc02511c1e7c7ebee2b237360f
+size 847938
diff --git a/imcui/datasets/wxbs_benchmark/.EVD/EVD/1/girl.png b/imcui/datasets/wxbs_benchmark/.EVD/EVD/1/girl.png
new file mode 100644
index 0000000000000000000000000000000000000000..33c2c22d3a73ebc78da6c143a98c9b18468ff9b5
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.EVD/EVD/1/girl.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:699cbd2b94405170aac86aafb745a2fe158d5b3451f81aa9a2ece9c7b76d1bca
+size 1089404
diff --git a/imcui/datasets/wxbs_benchmark/.EVD/EVD/1/graf.png b/imcui/datasets/wxbs_benchmark/.EVD/EVD/1/graf.png
new file mode 100644
index 0000000000000000000000000000000000000000..70eb91f5aa6f35270889b83edc84c6a71de08846
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.EVD/EVD/1/graf.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:33d874bf68a119429df9ed41d65baab5dbc864f5a7c92b939ab37b8a6f04eaa5
+size 947738
diff --git a/imcui/datasets/wxbs_benchmark/.EVD/EVD/1/grand.png b/imcui/datasets/wxbs_benchmark/.EVD/EVD/1/grand.png
new file mode 100644
index 0000000000000000000000000000000000000000..dc07235113020505da8521231982426ab737a8c2
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.EVD/EVD/1/grand.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6e86a53256e15f7fa9ab27c266a7569a155c92f923b177106ed375fac21b64be
+size 1106556
diff --git a/imcui/datasets/wxbs_benchmark/.EVD/EVD/1/index.png b/imcui/datasets/wxbs_benchmark/.EVD/EVD/1/index.png
new file mode 100644
index 0000000000000000000000000000000000000000..b1400bb6e79c8ac25fd5555648aca002d6af46c4
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.EVD/EVD/1/index.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:44da42873d5ef49d8ddde2bc19d873a49106961b1421ea71761340674228b840
+size 764609
diff --git a/imcui/datasets/wxbs_benchmark/.EVD/EVD/1/mag.png b/imcui/datasets/wxbs_benchmark/.EVD/EVD/1/mag.png
new file mode 100644
index 0000000000000000000000000000000000000000..64acb056fcd7fda7a95fc3844772f1f628653489
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.EVD/EVD/1/mag.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ba4102c8d465a122014cf717bfdf3a8ec23fea66f23c98620442186a1da205ff
+size 109706
diff --git a/imcui/datasets/wxbs_benchmark/.EVD/EVD/1/pkk.png b/imcui/datasets/wxbs_benchmark/.EVD/EVD/1/pkk.png
new file mode 100644
index 0000000000000000000000000000000000000000..44d7844ecea3980342b5564935f9406bc3bac23a
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.EVD/EVD/1/pkk.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:965fced32bbcd19d6f16bbe5b8fa1b4705ce2e1354674a6778bf26ddaf73a793
+size 1030148
diff --git a/imcui/datasets/wxbs_benchmark/.EVD/EVD/1/shop.png b/imcui/datasets/wxbs_benchmark/.EVD/EVD/1/shop.png
new file mode 100644
index 0000000000000000000000000000000000000000..f58917a71cb265bcd4e404147c28444b811cebe3
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.EVD/EVD/1/shop.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:20a9b74db30bb47f740cdc9ac80d55c57169dcaaeaa9e6922615680ac8651562
+size 950858
diff --git a/imcui/datasets/wxbs_benchmark/.EVD/EVD/1/there.png b/imcui/datasets/wxbs_benchmark/.EVD/EVD/1/there.png
new file mode 100644
index 0000000000000000000000000000000000000000..b39c566fa6861e9e3d9c4521513fdc74bc6241b0
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.EVD/EVD/1/there.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0aa0ec27f2a82680c1f69d507b6cf1832fa4543c0a3cee4da0c9e177bf8c5cf4
+size 2194678
diff --git a/imcui/datasets/wxbs_benchmark/.EVD/EVD/1/vin.png b/imcui/datasets/wxbs_benchmark/.EVD/EVD/1/vin.png
new file mode 100644
index 0000000000000000000000000000000000000000..0c96c24a8181a0313caa0e05c3c521194c7155c2
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.EVD/EVD/1/vin.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8a5420934a1fc2061b30b69cecb47be520415cf736c1ebe0ecb433afe9f4d3ea
+size 971107
diff --git a/imcui/datasets/wxbs_benchmark/.EVD/EVD/2/adam.png b/imcui/datasets/wxbs_benchmark/.EVD/EVD/2/adam.png
new file mode 100644
index 0000000000000000000000000000000000000000..df6cee2629c7fe032fbef3ac81521cb28551f4c8
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.EVD/EVD/2/adam.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:17f571379d9764726f9ba43cac717d4f29e7a21240bed40a95a2251407ae2c5d
+size 170216
diff --git a/imcui/datasets/wxbs_benchmark/.EVD/EVD/2/cafe.png b/imcui/datasets/wxbs_benchmark/.EVD/EVD/2/cafe.png
new file mode 100644
index 0000000000000000000000000000000000000000..68ab7b57cd637ba9d704bca9b9867c1ee6a9a4e0
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.EVD/EVD/2/cafe.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:69aaaa6832d73e3b482e7c065bad3fce413dd3d9daa21f62f5890554b46d8d62
+size 668573
diff --git a/imcui/datasets/wxbs_benchmark/.EVD/EVD/2/cat.png b/imcui/datasets/wxbs_benchmark/.EVD/EVD/2/cat.png
new file mode 100644
index 0000000000000000000000000000000000000000..16844477316e600d83398302f7b1b938a24d98f6
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.EVD/EVD/2/cat.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c6f97f05c6f554fc02a4f11af84026e02fd8255453a4130a2c12619923c669ae
+size 855468
diff --git a/imcui/datasets/wxbs_benchmark/.EVD/EVD/2/dum.png b/imcui/datasets/wxbs_benchmark/.EVD/EVD/2/dum.png
new file mode 100644
index 0000000000000000000000000000000000000000..6075f218a49b1ab906ac6b8a229c6d374c1b5fd1
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.EVD/EVD/2/dum.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4e73d572607d2bd6283f178f681067dae54b01ceea3afcd9d932298d026e7202
+size 1124916
diff --git a/imcui/datasets/wxbs_benchmark/.EVD/EVD/2/face.png b/imcui/datasets/wxbs_benchmark/.EVD/EVD/2/face.png
new file mode 100644
index 0000000000000000000000000000000000000000..4d8e187deae660c6930b28440f44349334c935c6
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.EVD/EVD/2/face.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:abc729a77234d4f32fa7faf746a729596fc1dd8c3bcf9c3e301305aedbd3ba1f
+size 1052127
diff --git a/imcui/datasets/wxbs_benchmark/.EVD/EVD/2/fox.png b/imcui/datasets/wxbs_benchmark/.EVD/EVD/2/fox.png
new file mode 100644
index 0000000000000000000000000000000000000000..faf3b8f755a711d53f0b3f7c59d5637b27f3d879
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.EVD/EVD/2/fox.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ceec2ea119120128d44f09aef9354cc39098cf639edb451894f6d4c099950161
+size 737697
diff --git a/imcui/datasets/wxbs_benchmark/.EVD/EVD/2/girl.png b/imcui/datasets/wxbs_benchmark/.EVD/EVD/2/girl.png
new file mode 100644
index 0000000000000000000000000000000000000000..3e196cb63dd8977f0ab483e7f8e83da4054f87e1
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.EVD/EVD/2/girl.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2b458d84eaf32b19b52f8871d03913cb872ce87d55d758a1dfc962eff148f929
+size 1044256
diff --git a/imcui/datasets/wxbs_benchmark/.EVD/EVD/2/graf.png b/imcui/datasets/wxbs_benchmark/.EVD/EVD/2/graf.png
new file mode 100644
index 0000000000000000000000000000000000000000..4614b40313c575a31bccef6fca21e955ba98dc85
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.EVD/EVD/2/graf.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d3f3bce412ad9c05e8048a5fdb8a2c32dbba4a6a37db02631b5fff035712d556
+size 1037985
diff --git a/imcui/datasets/wxbs_benchmark/.EVD/EVD/2/grand.png b/imcui/datasets/wxbs_benchmark/.EVD/EVD/2/grand.png
new file mode 100644
index 0000000000000000000000000000000000000000..855963559608e2599620426bc341bbccee82d0d8
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.EVD/EVD/2/grand.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8af49e086e21804e9ec6d3a117cad972187f6ea8bc05758893e16d56df720631
+size 1156742
diff --git a/imcui/datasets/wxbs_benchmark/.EVD/EVD/2/index.png b/imcui/datasets/wxbs_benchmark/.EVD/EVD/2/index.png
new file mode 100644
index 0000000000000000000000000000000000000000..094bba34dd52c273985b48650b403d06b4105354
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.EVD/EVD/2/index.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:74f4f0cf2ff09fe501e92ebbe89cc5c5a1e0df00aec1ce59f5aebed8ecb6ac5e
+size 764123
diff --git a/imcui/datasets/wxbs_benchmark/.EVD/EVD/2/mag.png b/imcui/datasets/wxbs_benchmark/.EVD/EVD/2/mag.png
new file mode 100644
index 0000000000000000000000000000000000000000..0583f0bf356cc2b2eaa61394561819b9fcdc2ef6
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.EVD/EVD/2/mag.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0fcb58c53ce102e0302f321027995dfba64fe1195f481ad48be12e70f80fe66b
+size 114421
diff --git a/imcui/datasets/wxbs_benchmark/.EVD/EVD/2/pkk.png b/imcui/datasets/wxbs_benchmark/.EVD/EVD/2/pkk.png
new file mode 100644
index 0000000000000000000000000000000000000000..ccdc42a0aeee171029441e59eda873d154046ae0
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.EVD/EVD/2/pkk.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1f608cf8f1498bf65452f54652c755f7364dd0460ee6a793224b179c7426b2c6
+size 1030783
diff --git a/imcui/datasets/wxbs_benchmark/.EVD/EVD/2/shop.png b/imcui/datasets/wxbs_benchmark/.EVD/EVD/2/shop.png
new file mode 100644
index 0000000000000000000000000000000000000000..8a8a62b6362e04ae8650c7fcb7efc85d84c2deca
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.EVD/EVD/2/shop.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:863f7ecf093d3f7dac75399f84cf2ad78d466a829f8cd2b79d8b85d297fe806c
+size 787662
diff --git a/imcui/datasets/wxbs_benchmark/.EVD/EVD/2/there.png b/imcui/datasets/wxbs_benchmark/.EVD/EVD/2/there.png
new file mode 100644
index 0000000000000000000000000000000000000000..b6cdefd446bcecfb197ab24860512de6c78fc380
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.EVD/EVD/2/there.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:597a5750d05402480df0ddc361585b842ceb692438e3cc69671a7a81a9199997
+size 2208079
diff --git a/imcui/datasets/wxbs_benchmark/.EVD/EVD/2/vin.png b/imcui/datasets/wxbs_benchmark/.EVD/EVD/2/vin.png
new file mode 100644
index 0000000000000000000000000000000000000000..cf45bf5859e3392daa77ee273adafb89ed7eb29a
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.EVD/EVD/2/vin.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:edfc46c5e1ad9b990a9d771cff268e309f5554360d9017d5153c166c73def86e
+size 1173916
diff --git a/imcui/datasets/wxbs_benchmark/.EVD/EVD/h/adam.txt b/imcui/datasets/wxbs_benchmark/.EVD/EVD/h/adam.txt
new file mode 100644
index 0000000000000000000000000000000000000000..62fc9cce248fa1a8152402c90cd5ea11e8df3571
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.EVD/EVD/h/adam.txt
@@ -0,0 +1,3 @@
+-0.475719 0.28239 -808.118
+-0.0337691 -2.62434 -12.0121
+0.000234167 8.96558e-05 -2.96958
diff --git a/imcui/datasets/wxbs_benchmark/.EVD/EVD/h/cafe.txt b/imcui/datasets/wxbs_benchmark/.EVD/EVD/h/cafe.txt
new file mode 100644
index 0000000000000000000000000000000000000000..a7cb081273eb3092eb5bfa927dc167b548c21f8b
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.EVD/EVD/h/cafe.txt
@@ -0,0 +1,3 @@
+-15.8525 20.632 -111.81
+-0.866106 -0.309191 -335.089
+0.00197221 0.00575211 -5.53643
diff --git a/imcui/datasets/wxbs_benchmark/.EVD/EVD/h/cat.txt b/imcui/datasets/wxbs_benchmark/.EVD/EVD/h/cat.txt
new file mode 100644
index 0000000000000000000000000000000000000000..68070704b87b0e675be8130a9f9233f5206fd22a
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.EVD/EVD/h/cat.txt
@@ -0,0 +1,3 @@
+-39.2129 0.823102 5118.35
+1.39233 -0.138102 -1711.78
+0.00592927 0.00421801 -11.1641
diff --git a/imcui/datasets/wxbs_benchmark/.EVD/EVD/h/dum.txt b/imcui/datasets/wxbs_benchmark/.EVD/EVD/h/dum.txt
new file mode 100644
index 0000000000000000000000000000000000000000..9fcc4dc9b2e9f94c6b340c1fe2132e0c841ab0d3
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.EVD/EVD/h/dum.txt
@@ -0,0 +1,3 @@
+-0.778836 0.981896 -748.395
+3.80065 -6.24622 -695.732
+0.00344236 0.000197217 -4.13445
diff --git a/imcui/datasets/wxbs_benchmark/.EVD/EVD/h/face.txt b/imcui/datasets/wxbs_benchmark/.EVD/EVD/h/face.txt
new file mode 100644
index 0000000000000000000000000000000000000000..44fc3af247ae00c5d20fdd07b712844c4150fe95
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.EVD/EVD/h/face.txt
@@ -0,0 +1,3 @@
+-0.260478 -0.124431 -3232.91
+-0.0968348 -2.12523 -254.238
+4.85011e-05 -0.000178626 -4.2015
diff --git a/imcui/datasets/wxbs_benchmark/.EVD/EVD/h/fox.txt b/imcui/datasets/wxbs_benchmark/.EVD/EVD/h/fox.txt
new file mode 100644
index 0000000000000000000000000000000000000000..95663e958f1fc7e10ec61cbda63a506a7bf5188f
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.EVD/EVD/h/fox.txt
@@ -0,0 +1,3 @@
+-1.39944 -4.86818 -1653
+3.74922 -40.7447 2368.02
+6.58821e-05 -0.016356 -6.58091
diff --git a/imcui/datasets/wxbs_benchmark/.EVD/EVD/h/girl.txt b/imcui/datasets/wxbs_benchmark/.EVD/EVD/h/girl.txt
new file mode 100644
index 0000000000000000000000000000000000000000..b4e1dabd05273f50281a0d8c06332c90d0e6b5d5
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.EVD/EVD/h/girl.txt
@@ -0,0 +1,3 @@
+1.80543 -28.3029 -3740.88
+1.55303 -18.1493 -474.282
+0.00279485 -0.0334053 -4.76287
diff --git a/imcui/datasets/wxbs_benchmark/.EVD/EVD/h/graf.txt b/imcui/datasets/wxbs_benchmark/.EVD/EVD/h/graf.txt
new file mode 100644
index 0000000000000000000000000000000000000000..c088d86285144f45038af4018183d5fc61e332dc
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.EVD/EVD/h/graf.txt
@@ -0,0 +1,3 @@
+ 4.2714590e-01 -6.7181765e-01 4.5361534e+02
+ 4.4106579e-01 1.0133230e+00 -4.6534569e+01
+ 5.1887712e-04 -7.8853731e-05 1.0000000e+00
diff --git a/imcui/datasets/wxbs_benchmark/.EVD/EVD/h/grand.txt b/imcui/datasets/wxbs_benchmark/.EVD/EVD/h/grand.txt
new file mode 100644
index 0000000000000000000000000000000000000000..e8cad131c7f7c99e21ce61663588aa6a97c60553
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.EVD/EVD/h/grand.txt
@@ -0,0 +1,3 @@
+-0.964008 0.742193 -686.412
+3.65869 -3.74358 -1880.34
+0.00239552 0.000705499 -3.33892
diff --git a/imcui/datasets/wxbs_benchmark/.EVD/EVD/h/index.txt b/imcui/datasets/wxbs_benchmark/.EVD/EVD/h/index.txt
new file mode 100644
index 0000000000000000000000000000000000000000..0a4f3aa9a93608b06eeadb1c1006f27a03908680
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.EVD/EVD/h/index.txt
@@ -0,0 +1,3 @@
+0.420191 -1.04493 1337.46
+-2.95716 8.01143 1193.82
+-0.00241706 -0.000152038 5.08127
diff --git a/imcui/datasets/wxbs_benchmark/.EVD/EVD/h/mag.txt b/imcui/datasets/wxbs_benchmark/.EVD/EVD/h/mag.txt
new file mode 100644
index 0000000000000000000000000000000000000000..f88872a57071cad157cf98018bb3cfdfb75adb7e
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.EVD/EVD/h/mag.txt
@@ -0,0 +1,3 @@
+-0.296887 -16.5762 4405.55
+0.672497 0.906427 398.264
+-0.000777876 0.00319639 3.31196
diff --git a/imcui/datasets/wxbs_benchmark/.EVD/EVD/h/pkk.txt b/imcui/datasets/wxbs_benchmark/.EVD/EVD/h/pkk.txt
new file mode 100644
index 0000000000000000000000000000000000000000..840ea53a262a6d312b32178b06a6a3e5932fca88
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.EVD/EVD/h/pkk.txt
@@ -0,0 +1,3 @@
+-2.63017 -5.53715 -1101.86
+-2.06413 -10.0132 1466.65
+-0.00338896 -0.00999799 -2.34239
diff --git a/imcui/datasets/wxbs_benchmark/.EVD/EVD/h/shop.txt b/imcui/datasets/wxbs_benchmark/.EVD/EVD/h/shop.txt
new file mode 100644
index 0000000000000000000000000000000000000000..2cb0af6b4b1973004cc5f3b2ef9d5221ccc0f52a
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.EVD/EVD/h/shop.txt
@@ -0,0 +1,3 @@
+-0.0379327 -0.792592 -2778.33
+2.80747 -11.5896 919.535
+0.00281406 0.000153738 -8.54098
diff --git a/imcui/datasets/wxbs_benchmark/.EVD/EVD/h/there.txt b/imcui/datasets/wxbs_benchmark/.EVD/EVD/h/there.txt
new file mode 100644
index 0000000000000000000000000000000000000000..5c2031cce2da4817c9a8d122cd4391452d3f6c13
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.EVD/EVD/h/there.txt
@@ -0,0 +1,3 @@
+0.314825 0.115834 690.506
+ 0.175462 0.706365 14.4974
+0.000267118 0.000126909 1.0
diff --git a/imcui/datasets/wxbs_benchmark/.EVD/EVD/h/vin.txt b/imcui/datasets/wxbs_benchmark/.EVD/EVD/h/vin.txt
new file mode 100644
index 0000000000000000000000000000000000000000..ea4b033024b230a7aadff1bd8034cf1194b10122
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.EVD/EVD/h/vin.txt
@@ -0,0 +1,3 @@
+4.90207 -19.4931 2521.64
+7.0392 -28.4826 5753.82
+0.00953653 -0.0350753 4.24836
diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/README.txt b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/README.txt
new file mode 100644
index 0000000000000000000000000000000000000000..fccf33fdccf7d4ec7c55b62f82e556ec47b0ead4
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/README.txt
@@ -0,0 +1,63 @@
+Welcome to WxBS version 1.1 -- Wide (multiple) Baseline Dataset.
+
+It contains 34 very challenging image pairs with manually annotated ground truth correspondences.
+The images are organized into several categories:
+
+- WGALBS: with Geometric, Appearance and iLlumination changes
+- WGBS: with Geometric (viewpoint) changes
+- WLABS: with iLlumination and Appearance changes. The viewpoint change is present, but not significant.
+- WGSBS: with Geometric and Sensor (thermal camera vs visible) changes
+- WGABS: with Geometric and Appearance changes.
+
+Compared to the original 2015 dataset, v1.1 contains more correspondences (which have also been cleaned) and 3 additional image pairs: WGALBS/kyiv_dolltheater, WGALBS/kyiv_dolltheater2, WGBS/kn-church.
+We also provide cross-validation errors for each of the GT correspondences.
+They are estimated in the following way:
+
+- the fundamental matrix F is estimated with the OpenCV 8-point algorithm (no RANSAC), using all points except one:
+F, _ = cv2.findFundamentalMat(corrs_cur[:,:2], corrs_cur[:,2:], cv2.FM_8POINT)
+Then the symmetrical epipolar distance is calculated on that held-out point. We used the kornia implementation of the symmetrical epipolar distance (see the sketch below):
+
+
+From Hartley and Zisserman, the symmetric epipolar distance, Eq. (11.10):
+sed = (x'^T F x)**2 * ( 1 / ((F x)_1**2 + (F x)_2**2) + 1 / ((F^T x')_1**2 + (F^T x')_2**2) )
+
+https://kornia.readthedocs.io/en/latest/geometry.epipolar.html#kornia.geometry.epipolar.symmetrical_epipolar_distance
+
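+A minimal sketch of this leave-one-out estimate (an assumed illustration, not the official evaluation code), where corrs is an (N, 4) numpy array of x1 y1 x2 y2 rows as stored in corrs.txt:
+
+import cv2
+import numpy as np
+import torch
+import kornia.geometry.epipolar as epi
+
+errs = []
+for i in range(len(corrs)):
+    rest = np.delete(corrs, i, axis=0)                          # hold out point i
+    F, _ = cv2.findFundamentalMat(rest[:, :2], rest[:, 2:], cv2.FM_8POINT)
+    pt1 = torch.from_numpy(corrs[i:i + 1, :2]).float()[None]    # shape (1, 1, 2)
+    pt2 = torch.from_numpy(corrs[i:i + 1, 2:]).float()[None]
+    Fm = torch.from_numpy(F).float()[None]                      # shape (1, 3, 3)
+    # kornia returns the squared distance unless squared=False is passed
+    errs.append(epi.symmetrical_epipolar_distance(pt1, pt2, Fm).item())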
+
+The labeling was done using [pixelstitch](https://pypi.org/project/pixelstitch/).
+
+There are two main intended ways of using the dataset:
+a) Evaluation of image matchers that estimate a fundamental matrix: one calculates the reprojection error on the GT correspondences and reports either the mean error or the percentage of GT correspondences that are in agreement with the estimated F. For more details see the paper [1].
+
+b) For methods like [CoTR](https://arxiv.org/abs/2103.14167), which predict the corresponding point in image 2 given a query point in image 1, one can directly calculate the error between the returned point and the GT correspondence.
+
+
+***
+If you are using this dataset, please cite us:
+
+[1] WxBS: Wide Baseline Stereo Generalizations. D. Mishkin, M. Perdoch, J. Matas and K. Lenc. In Proc. BMVC, 2015.
+
+@InProceedings{Mishkin2015WXBS,
+ author = {{Mishkin}, D. and {Matas}, J. and {Perdoch}, M. and {Lenc}, K. },
+ booktitle = {Proceedings of the British Machine Vision Conference},
+ publisher = {BMVA},
+ title = "{WxBS: Wide Baseline Stereo Generalizations}",
+ year = 2015,
+ month = sep
+}
diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGABS/.DS_Store b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGABS/.DS_Store
new file mode 100644
index 0000000000000000000000000000000000000000..0066f66d4bd858d6c5b47955fe4763a19eb60dd7
Binary files /dev/null and b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGABS/.DS_Store differ
diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGABS/kremlin/01.png b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGABS/kremlin/01.png
new file mode 100644
index 0000000000000000000000000000000000000000..69dd5333685a7d65278630e65d7b07f02fc701d5
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGABS/kremlin/01.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:850e7ce7de02bf28d5d3301d1f5caf95c6c87df0f8c0c20c7af679a1e6a1dd8e
+size 845171
diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGABS/kremlin/02.png b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGABS/kremlin/02.png
new file mode 100644
index 0000000000000000000000000000000000000000..350c4b1d0559fc2fb78734d4f782073b7e685c15
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGABS/kremlin/02.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8b31024c1d52187acb022d24118bc7b37363801d7f8c6132027ecbb33259bc9e
+size 615221
diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGABS/kremlin/corrs.txt b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGABS/kremlin/corrs.txt
new file mode 100644
index 0000000000000000000000000000000000000000..a52583361953fd2c94cb5e3477ce5fe806a56a08
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGABS/kremlin/corrs.txt
@@ -0,0 +1,49 @@
+4.907274400000000014e+02 3.917342500000000172e+01 4.622376800000000117e+02 1.865405900000000017e+01
+4.199036600000000021e+02 1.615936800000000062e+02 3.848875699999999824e+02 1.613663899999999956e+02
+4.948897400000000175e+02 1.302602300000000071e+02 4.316474000000000046e+02 1.558661099999999919e+02
+5.940666800000000194e+02 1.390053499999999929e+02 5.624503499999999576e+02 1.447967299999999966e+02
+4.473480099999999879e+02 2.174977000000000089e+02 3.973540699999999788e+02 2.388423799999999915e+02
+5.604767500000000382e+02 2.211028900000000021e+02 5.063893199999999979e+02 2.504398300000000006e+02
+1.570469300000000032e+02 2.428232000000000141e+02 1.659867399999999975e+02 1.769992300000000114e+02
+1.283247700000000009e+02 2.466919000000000040e+02 1.308167500000000132e+02 1.806334600000000137e+02
+1.856345599999999934e+02 2.481083900000000142e+02 1.985409400000000062e+02 1.834373699999999872e+02
+3.945598299999999767e+02 3.011998500000000263e+02 3.165450599999999781e+02 3.465213200000000029e+02
+4.044084700000000225e+02 2.958942599999999743e+02 3.425922899999999913e+02 3.308400700000000256e+02
+4.486278899999999794e+02 4.079485399999999800e+02 3.311905699999999797e+02 4.730156900000000064e+02
+4.880407599999999775e+02 4.070106700000000046e+02 3.615463899999999740e+02 4.750086600000000203e+02
+8.308690199999999493e+01 2.656556800000000180e+02 6.541268700000000536e+01 2.314867700000000070e+02
+1.554153399999999863e+02 3.250390100000000189e+02 1.665729100000000074e+02 2.826928100000000086e+02
+4.981625900000000229e+02 2.744976300000000151e+02 4.228344999999999914e+02 3.175723100000000159e+02
+4.522427599999999757e+02 2.984249800000000050e+02 3.946770700000000147e+02 3.311724100000000135e+02
+6.605721899999999778e+02 1.820804599999999880e+02 5.863873800000000074e+02 2.251841799999999978e+02
+4.103960099999999898e+02 3.502666500000000269e+02 3.240082199999999943e+02 4.019584899999999834e+02
+5.709502800000000207e+02 2.934323600000000170e+02 5.069851699999999823e+02 3.328039899999999989e+02
+1.218060499999999990e+02 2.819134199999999737e+02 1.248734800000000007e+02 2.297389900000000011e+02
+4.939615600000000200e+02 3.342423400000000129e+02 4.141000999999999976e+02 3.836088899999999740e+02
+3.687382882065904823e+02 3.495395049102415328e+02 2.896320927956163587e+02 4.006195633380069694e+02
+3.910059262681439236e+02 3.707051060659499058e+02 3.076935963192913732e+02 4.208602672369242441e+02
+3.975209120765460398e+02 3.629260185335294295e+02 3.133447415935700633e+02 4.132146001011354315e+02
+3.849771334305181085e+02 3.630232571276847011e+02 3.022640645851805061e+02 4.129929865609676654e+02
+5.620418645411094758e+02 3.036978250872198828e+02 4.978339578685253741e+02 3.451610927493376835e+02
+5.528727291807808797e+02 2.946790034213228182e+02 4.903915169107766019e+02 3.335120547285135331e+02
+5.608393549856565414e+02 2.664200288681786333e+02 5.007462173737313833e+02 3.035804987027848370e+02
+5.969146416492446861e+02 2.474805025813681141e+02 5.547848104147767572e+02 2.742961106073569795e+02
+6.084887961204792646e+02 2.737853991069012523e+02 5.619036669830582014e+02 3.040658744383520116e+02
+5.981171512046976204e+02 2.816017112173453825e+02 5.513871743253697559e+02 3.152295358749751131e+02
+4.966124691639038247e+02 2.254211022753156897e+02 4.251663048820389577e+02 2.657820691066586960e+02
+4.968032309070188717e+02 2.326700485136870213e+02 4.251663048820389577e+02 2.722625024200451662e+02
+5.078674120076908594e+02 2.578505986048716068e+02 4.311067020859765080e+02 3.022345064944572641e+02
+5.078674120076908594e+02 2.626196421827474978e+02 4.313767201407009111e+02 3.076348675889458946e+02
+4.861205732925768643e+02 2.631919274120925820e+02 4.100452938174706787e+02 3.065547953700481685e+02
+5.143533112736021167e+02 2.296178606238464681e+02 4.443375867674737378e+02 2.692923038180763342e+02
+5.048152241178503346e+02 2.403005182382884186e+02 4.316467381954253710e+02 2.809030801712269749e+02
+4.905080933842227182e+02 1.445381231945409013e+02 4.521681103544822804e+02 1.521044680676722862e+02
+5.090119824663810846e+02 1.451104084238859855e+02 4.699893019662949314e+02 1.521044680676722862e+02
+5.154978817322922851e+02 1.496886902586468295e+02 4.845702769214142904e+02 1.526445041771211493e+02
+4.811607679715859831e+02 2.319070015412268617e+02 4.127454743647149940e+02 2.682122315991786081e+02
+1.659310260943734363e+02 1.904096720602631763e+02 1.750965926567237716e+02 1.086081899286738235e+02
+1.514724884758092287e+02 1.897717954006206185e+02 1.568970252261333371e+02 1.083940773706668779e+02
+1.642300216686599867e+02 1.191801117335130584e+02 1.731695796346612610e+02 9.259963013450544622e+00
+9.197711414192471580e+01 2.817466503197928205e+02 7.713222472068414959e+01 2.509560380918086082e+02
+8.139120388717148558e+01 2.891666808721899997e+02 6.421399252360467358e+01 2.628805601198819772e+02
+6.942022126263745463e+01 2.819445178011901021e+02 5.052287463952045243e+01 2.516185115378127080e+02
diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGABS/kremlin/crossval_errors.txt b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGABS/kremlin/crossval_errors.txt
new file mode 100644
index 0000000000000000000000000000000000000000..897f08d9475b367a75f4c0366d8dc21d69560ab4
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGABS/kremlin/crossval_errors.txt
@@ -0,0 +1,49 @@
+5.946305398998352754e-01
+6.151500492287013122e-01
+4.435131637366677704e-01
+8.607236886172928703e-01
+1.119487091546507093e-02
+1.842455765244689070e-01
+1.139326684695094460e+00
+1.741695769999793031e+00
+1.199652052219880360e+00
+1.070740677484232517e+00
+4.284088135921398921e-02
+2.060204916077549853e-01
+4.055000220707856706e-01
+6.973383799953524198e-01
+2.092896295374734983e+00
+1.402160800861952161e+00
+1.016634654724758224e+00
+1.492331175763036866e+00
+2.803261140423741193e-01
+1.549618715189659435e+00
+2.228131674171225374e+00
+1.047571352924588028e+00
+4.280705770200255778e+00
+1.020399628357264721e-01
+5.543407989816054871e-01
+4.932068148368541627e-02
+7.650761824905847330e-01
+5.087262963715805109e-01
+1.022079507471783000e+00
+8.715133918375695954e-01
+2.017157710400199200e+00
+2.755255421845337782e+00
+1.540080127238647512e+00
+1.175615709542446596e-01
+1.631398780409486493e-01
+5.850237557411903655e-01
+8.175730992894985061e-01
+1.138357037213331369e+00
+8.535338399374122753e-01
+2.322429340540620224e+00
+3.717534098826499322e-01
+1.258508672327652844e+00
+9.834160308617387880e-01
+1.191211366876187006e+00
+3.598025256102476699e-01
+1.039241620045164405e+00
+2.476248683496967473e+00
+2.857957423093941074e+00
+3.709824113185862249e-01
diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGABS/kyiv/01.png b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGABS/kyiv/01.png
new file mode 100644
index 0000000000000000000000000000000000000000..102b10b6b368dc04440b2bf5da2a12be205b00e9
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGABS/kyiv/01.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2e79c3270fd5365f355307e7e944f05dfdc43691f87b5aff4506607db42c2a56
+size 616896
diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGABS/kyiv/02.png b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGABS/kyiv/02.png
new file mode 100644
index 0000000000000000000000000000000000000000..a4f26819f208e9693f6dd59ec478fe1d10143837
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGABS/kyiv/02.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8d7e20317377179b99d143d6197f0647d5882edfba14c326f354fe33ee912362
+size 523819
diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGABS/kyiv/corrs.txt b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGABS/kyiv/corrs.txt
new file mode 100644
index 0000000000000000000000000000000000000000..1717b77ae6409871a49fb834330096a9d4a1e7e8
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGABS/kyiv/corrs.txt
@@ -0,0 +1,31 @@
+8.218366899999999475e+01 1.530866499999999917e+02 2.595849000000000046e+02 2.658549699999999802e+02
+2.744326800000000048e+02 1.119666299999999950e+02 3.892155200000000264e+02 2.346591600000000142e+02
+2.210166600000000017e+02 1.117554300000000040e+02 3.450251400000000217e+02 2.343074599999999919e+02
+2.451667199999999980e+02 1.114037299999999959e+02 3.629618300000000204e+02 2.341902200000000107e+02
+3.003384300000000167e+02 1.625225299999999891e+02 4.159758299999999736e+02 2.712874699999999848e+02
+2.449410799999999995e+02 1.619131099999999890e+02 3.759574400000000196e+02 2.722447000000000230e+02
+2.866947200000000180e+02 3.839937800000000152e+02 5.190999699999999848e+02 4.356803499999999758e+02
+3.922720800000000168e+02 1.174523999999999972e+02 4.723207800000000134e+02 2.384106199999999944e+02
+3.775006799999999885e+02 1.309342300000000137e+02 4.621214800000000196e+02 2.484926900000000103e+02
+3.396873899999999935e+02 1.175705700000000036e+02 4.378811600000000226e+02 2.388010600000000068e+02
+5.545821399999999812e+01 1.701533900000000017e+02 2.369180700000000002e+02 2.824644000000000119e+02
+9.402142899999999770e+01 1.694848800000000040e+02 2.677825799999999958e+02 2.776578400000000215e+02
+1.542611699999999928e+02 1.653593899999999906e+02 3.074880099999999743e+02 2.772470000000000141e+02
+2.136086400000000083e+02 1.141107899999999944e+02 3.302527099999999791e+02 2.362122400000000084e+02
+2.361495099999999923e+02 3.121274900000000230e+02 4.639961799999999812e+02 3.821199599999999919e+02
+2.772462800000000129e+02 2.915795699999999897e+02 4.827351699999999823e+02 3.661461400000000026e+02
+8.007346900000000289e+01 1.227232300000000009e+02 2.575919299999999907e+02 2.435806399999999883e+02
+3.099508300000000105e+01 9.362195800000000645e+01 2.155214400000000126e+02 2.217527800000000013e+02
+2.801296500000000265e+02 3.661743200000000229e+02 5.140589400000000069e+02 4.218468199999999797e+02
+2.433183999999999969e+02 3.944275499999999965e+02 5.046802700000000073e+02 4.427143500000000245e+02
+6.178881200000000007e+01 1.362729699999999866e+02 2.441865300000000047e+02 2.544456500000000005e+02
+2.756259297043956735e+02 3.740927471033548954e+02 5.114641728026981013e+02 4.273074362362154943e+02
+8.221162108383408906e+01 4.056540777987604542e+02 4.373557405352174214e+02 4.522597703330104650e+02
+7.411897218757627570e+01 4.113189320261409421e+02 4.341119371026341014e+02 4.560026204475296936e+02
+1.251026602340004672e+02 3.761159093274194447e+02 4.495823842426468673e+02 4.300521929868630195e+02
+1.360277362439485103e+02 3.684278928759745213e+02 4.538242810391020612e+02 4.243131561446001569e+02
+3.915916690636100839e+02 1.310116728485639044e+02 4.712393946580673401e+02 2.479437873807888479e+02
+3.764727575077063193e+02 1.158927612926601398e+02 4.608862906124819574e+02 2.377897814899262414e+02
+3.920793758879941038e+02 1.617372027847554818e+02 4.792033208469791816e+02 2.708400751739104635e+02
+4.655489607273053707e+02 1.858487678854069998e+02 5.446299327018380154e+02 2.889468305880199068e+02
+4.021047738999301373e+02 2.268103059826929666e+02 5.263277510484970207e+02 3.166409212476804669e+02
diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGABS/kyiv/crossval_errors.txt b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGABS/kyiv/crossval_errors.txt
new file mode 100644
index 0000000000000000000000000000000000000000..b8e8634a7e03b482b4790715a0f21e6220656125
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGABS/kyiv/crossval_errors.txt
@@ -0,0 +1,31 @@
+1.065223937182152358e+00
+5.736516321493103643e-01
+4.019874183288125180e-01
+6.051618580623756294e-02
+1.331781109861698509e+00
+4.707776137115839421e-01
+1.562036957800347237e+00
+5.946548909007817185e-01
+7.176612972602752771e-01
+1.036625416881644890e+00
+4.840386519356732364e+00
+3.143029902439770762e+00
+3.486454358398270337e+00
+1.060173865479107747e+00
+5.061210528868288067e-02
+1.198614553697012336e-01
+1.133641006763047798e+00
+1.628127580495658533e+00
+1.076465803298405577e+00
+6.493052438943768268e-01
+1.691393384139549205e+00
+6.172534509656668611e-01
+5.961055533880751378e-01
+2.768433902281529857e+00
+1.021596472702978486e+00
+1.800412927080836445e+00
+3.333198235820573618e-01
+1.588410312156819382e+00
+2.664948685440807763e-01
+3.174975276305068483e+00
+2.025259979974938673e+00
diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGABS/petrzin/01.png b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGABS/petrzin/01.png
new file mode 100644
index 0000000000000000000000000000000000000000..548bcc64b3b7d0b5f025e10a38dc1167ee0a1033
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGABS/petrzin/01.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7e263b6c695097b44806b17f0b76cc0ae0f92fe3909293cc78755e3caecf5d0b
+size 1452608
diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGABS/petrzin/02.png b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGABS/petrzin/02.png
new file mode 100644
index 0000000000000000000000000000000000000000..26992cc042c17e10c1c3bdc07fa862327e734978
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGABS/petrzin/02.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:920a6f7a68df8766dd8a4ed3f733d67b6432770842fb606467a6f58276afdea1
+size 1368978
diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGABS/petrzin/corrs.txt b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGABS/petrzin/corrs.txt
new file mode 100644
index 0000000000000000000000000000000000000000..c220d9b509c047ca449a4a4b6f3121b06af474f3
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGABS/petrzin/corrs.txt
@@ -0,0 +1,38 @@
+5.939102000000000317e+02 2.418589599999999962e+02 4.982037500000000136e+02 1.194755399999999952e+02
+6.916765400000000454e+02 2.234832499999999982e+02 6.342270399999999881e+02 1.034648899999999969e+02
+4.017696900000000255e+02 4.162991499999999974e+02 4.903268100000000231e+02 3.427198700000000144e+02
+3.992128599999999778e+02 5.331783199999999852e+02 4.892717099999999846e+02 5.246871999999999616e+02
+3.671711700000000178e+02 5.252994800000000168e+02 4.202384299999999939e+02 5.129918800000000374e+02
+2.765707600000000070e+02 4.378662699999999859e+02 1.672225400000000093e+02 3.733789699999999812e+02
+5.865748200000000452e+02 3.992576099999999997e+02 5.190209599999999455e+02 3.079799199999999928e+02
+5.466300899999999956e+02 3.986249399999999810e+02 4.676822900000000232e+02 3.068783000000000243e+02
+6.762335600000000113e+02 3.514155400000000213e+02 6.263054899999999634e+02 2.511592899999999986e+02
+3.721860899999999788e+02 2.715075699999999870e+02 2.301838999999999942e+02 1.436160800000000108e+02
+4.851462399999999775e+02 3.550274699999999939e+02 3.759133499999999799e+02 2.528731799999999907e+02
+6.853274400000000242e+02 2.796412799999999947e+02 6.251331599999999753e+02 1.695877499999999998e+02
+3.793781299999999987e+02 4.288431100000000242e+02 4.532810900000000061e+02 3.627667700000000082e+02
+4.257075899999999820e+02 5.276683500000000322e+02 5.130700699999999870e+02 5.100330400000000282e+02
+3.989784000000000219e+02 5.130141899999999850e+02 4.866925800000000208e+02 4.951444099999999935e+02
+3.215433800000000133e+02 3.185416200000000231e+02 1.649607300000000123e+02 2.048264900000000068e+02
+8.231494699999999511e+02 3.366846400000000017e+02 8.477356399999999894e+02 2.366834699999999998e+02
+8.274871000000000549e+02 2.987010599999999840e+02 8.561764399999999569e+02 1.924865200000000129e+02
+5.764467172506857651e+02 2.733143446995878207e+02 4.803264665374170477e+02 1.560546913685867025e+02
+5.413348219803513075e+02 2.441738601895142722e+02 4.316845558332415749e+02 1.199110493870118717e+02
+5.738192965161708798e+02 2.494287016585439289e+02 4.762729739787357630e+02 1.276802434578176815e+02
+4.947578180503157341e+02 2.909897205499603388e+02 3.837182272221796211e+02 1.756465720688796068e+02
+5.071783524316585385e+02 3.748283276240243822e+02 4.097281378070512119e+02 2.790106323152525647e+02
+5.723861579337083185e+02 2.897954383979081285e+02 5.043096308429479677e+02 1.749709899757660594e+02
+6.676898736674735346e+02 2.924228591324229001e+02 6.120649746945589413e+02 1.820646019534583218e+02
+3.377939878780084655e+02 3.703309743465495671e+02 1.918655051113914283e+02 2.736075751686765898e+02
+3.210982358910493417e+02 3.827509849709947503e+02 1.703373326379488049e+02 2.901381361750700307e+02
+2.563512952586955862e+02 3.888591869174432532e+02 7.230726155352260776e+01 2.982112008526110003e+02
+3.756651435405149186e+02 4.833721106941258654e+02 4.502639317766756903e+02 4.503528604518788825e+02
+3.910914280535924945e+02 5.100965190759362713e+02 4.729039071343397609e+02 4.902707117403917891e+02
+3.956541319236577010e+02 5.083583461730543149e+02 4.776702177359532584e+02 4.843128234883749315e+02
+3.895705267635707969e+02 4.492604674750670029e+02 4.737975903721422810e+02 3.952423941207230200e+02
+3.889951219559405331e+02 4.280560699476575905e+02 4.665334998202267798e+02 3.603993349145116554e+02
+2.811901416325709420e+02 4.424291250164491203e+02 1.737224299407419039e+02 3.796104651314472562e+02
+2.423185076331632786e+02 4.371700215929998308e+02 1.180053812948370933e+02 3.731101427894250264e+02
+2.720438748091809202e+02 4.632368820396614524e+02 1.656744118030001118e+02 4.102548418866948623e+02
+2.715865614680114390e+02 4.502034518163306984e+02 1.638171768481366257e+02 3.904443357014843059e+02
+2.503214911036295973e+02 4.671240454396022415e+02 1.325537217746011436e+02 4.173742425470049398e+02
diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGABS/petrzin/crossval_errors.txt b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGABS/petrzin/crossval_errors.txt
new file mode 100644
index 0000000000000000000000000000000000000000..3ce5479cdf0f30db0390c01acb5ac9e487c86724
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGABS/petrzin/crossval_errors.txt
@@ -0,0 +1,38 @@
+1.792424591998240890e+00
+3.095348507057420928e+00
+4.094406044190157523e-01
+7.302793957725348672e-01
+4.238650623157072528e-01
+4.165253741566223411e+00
+1.883763528268577736e-01
+8.053748309044221898e-01
+1.595531486099816265e-01
+1.697912170300142742e+00
+8.506116392505531643e-01
+3.122924043029775554e+00
+7.353410671850451052e-01
+8.858066508520040516e-01
+2.139081656124904196e+00
+3.931447770611433690e+00
+3.191792767739562997e+00
+3.623124474113926574e+00
+1.426124920876860269e+00
+1.237541469506047598e+00
+8.962695327020072655e-01
+1.530807091557143451e+00
+1.467530857533184907e+00
+1.121653280350474846e+00
+9.875670003547415421e-01
+2.657680183199366830e+00
+2.536812969979429511e+00
+2.003125655478974920e+00
+1.223910714981061520e+00
+9.496930900765404582e-01
+2.299513490038819441e+00
+4.825846860592630100e-03
+1.709008522450382817e+00
+4.024040354942465036e+00
+2.351401261463401671e+00
+2.432080444353363458e+00
+1.727733901591856469e+00
+1.847889070782418930e+00
diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGABS/strahov/01.png b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGABS/strahov/01.png
new file mode 100644
index 0000000000000000000000000000000000000000..e9f3507f029d45e4861a98c323f9f3b1b95ba255
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGABS/strahov/01.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e5ef0c6fb8618c1cf867df09c6f6bcfef0e9c85b3e8164ecd2a865da34bf3027
+size 553707
diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGABS/strahov/02.png b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGABS/strahov/02.png
new file mode 100644
index 0000000000000000000000000000000000000000..d7145c3d436cb1842459cf55db663758c9e88c73
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGABS/strahov/02.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6ec6c1ba02df41d23ae4761ed875969c80e154dcd303e69640ba9f6de4fff7ba
+size 608585
diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGABS/strahov/corrs.txt b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGABS/strahov/corrs.txt
new file mode 100644
index 0000000000000000000000000000000000000000..769ed6a898959ff816ae096a10846687fd933f85
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGABS/strahov/corrs.txt
@@ -0,0 +1,28 @@
+4.097915199999999913e+02 1.833357399999999870e+02 6.136951299999999492e+02 2.662372500000000173e+02
+3.972475499999999897e+02 1.823978700000000117e+02 5.958756700000000137e+02 2.663544800000000237e+02
+2.113397699999999872e+02 2.108091400000000135e+02 3.278392200000000116e+02 2.986441300000000183e+02
+3.930648499999999785e+02 2.325574699999999950e+02 5.847751600000000280e+02 3.389658099999999763e+02
+4.474652500000000259e+02 2.239596699999999885e+02 6.665026699999999664e+02 3.319995900000000120e+02
+2.731753100000000245e+02 2.438032200000000103e+02 4.108272499999999923e+02 3.488628499999999804e+02
+5.092126000000000090e+02 2.347558500000000095e+02 7.640383399999999483e+02 3.494393400000000156e+02
+1.692692700000000059e+02 2.209019199999999898e+02 1.972320100000000025e+02 3.109740199999999959e+02
+4.040470799999999940e+02 2.020930700000000115e+02 6.080679300000000467e+02 2.948421700000000101e+02
+1.872901500000000112e+02 1.894243800000000135e+02 3.020479000000000269e+02 2.720321700000000078e+02
+2.597743300000000204e+02 2.072379799999999932e+02 3.954828400000000101e+02 2.965339299999999980e+02
+2.937157100000000014e+02 2.189660000000000082e+02 4.434426500000000146e+02 3.158611599999999839e+02
+3.841037499999999909e+02 2.189624900000000025e+02 5.730518299999999954e+02 3.188016799999999762e+02
+4.726703999999999724e+02 2.223183999999999969e+02 7.063619899999999916e+02 3.270757899999999836e+02
+2.800920800000000099e+02 2.620916100000000029e+02 4.103583199999999920e+02 3.753575799999999845e+02
+4.936205800000000181e+02 2.189293600000000026e+02 7.385987199999999575e+02 3.250548099999999749e+02
+2.900925279496414078e+02 2.417676896878288915e+02 4.357913877379355085e+02 3.470031760735771513e+02
+3.921982208998750252e+02 2.025229845743038197e+02 5.874506446412149216e+02 2.962461188747516303e+02
+3.710130261040782216e+02 2.011337914729400893e+02 5.575755385475616777e+02 2.931013708648933971e+02
+4.272753467093090194e+02 2.219716879933959319e+02 6.358011452927852361e+02 3.269074119708694184e+02
+3.553846037137363396e+02 2.283047722814643521e+02 4.919289238417711658e+02 3.302140225398936764e+02
+3.828211674656698733e+02 2.285907842234612701e+02 5.292728064588377492e+02 3.313933030435904925e+02
+1.417557024889432000e+02 2.020893680292453212e+02 2.413374582723404842e+02 2.847164876775290736e+02
+1.349735569233231445e+02 1.758650718421810666e+02 1.759893184350587489e+02 2.526156470557064608e+02
+4.226454386301443265e+02 2.434042994222771199e+02 6.254660433374846207e+02 3.575691427733916044e+02
+4.158557961245051047e+02 2.434042994222771199e+02 6.163289032236831417e+02 3.587112852876167608e+02
+4.084757499227233097e+02 2.391499171974719218e+02 6.052881922528397354e+02 3.505595196120268611e+02
+3.981436852402288196e+02 2.373787061090442876e+02 5.900596253965038613e+02 3.471330920693512780e+02
diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGABS/strahov/crossval_errors.txt b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGABS/strahov/crossval_errors.txt
new file mode 100644
index 0000000000000000000000000000000000000000..03e080f6eefd0d10531cfc16ecfaa0d16bb932a3
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGABS/strahov/crossval_errors.txt
@@ -0,0 +1,28 @@
+2.795827317827273895e+00
+7.504814977516608421e-01
+2.358292759495422164e+00
+1.292287445027972881e+00
+2.551228559052934219e+00
+1.430468433774237103e-01
+1.822288632049907875e+00
+2.256415147302962332e+00
+4.115527888980367588e-01
+2.543433883252152139e+00
+9.771310714237624317e-01
+1.339901065391976509e+00
+5.753504995346042650e-01
+2.368484738426400060e+00
+8.237638325866948330e-01
+4.048545226987346202e-01
+2.386360497959467142e-01
+1.256918929674968322e+00
+7.735100630893360085e-01
+1.320422016321886716e+00
+7.705570551665671397e-01
+2.159805887096448274e+00
+2.559822007875808048e+00
+4.642055664900872181e+00
+9.192902876777044874e-01
+1.471532469196588755e+00
+2.965823050434360231e-01
+4.094884609445577084e-01
diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGABS/vatutin/01.png b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGABS/vatutin/01.png
new file mode 100644
index 0000000000000000000000000000000000000000..88f10f54a598c0677cdd8c83f6cf1c982153de44
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGABS/vatutin/01.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:bc4ee14c70d6fb6bc44a0be7507f90b62c67773d93afc75a4f586d3cd01f6108
+size 917944
diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGABS/vatutin/02.png b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGABS/vatutin/02.png
new file mode 100644
index 0000000000000000000000000000000000000000..f9339231ab07fa31f4bb302d50e8ddc68ad7c591
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGABS/vatutin/02.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0ffbab080ab2c02eb02c03e7128697c49b9c6cd13f8c99d72263379ca9c5ff83
+size 565060
diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGABS/vatutin/corrs.txt b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGABS/vatutin/corrs.txt
new file mode 100644
index 0000000000000000000000000000000000000000..4ada597fe23f7f3ffb92f7874c0f6399f709efca
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGABS/vatutin/corrs.txt
@@ -0,0 +1,44 @@
+3.259731699999999819e+02 3.123227999999999724e+02 1.918973599999999919e+02 3.167674200000000155e+02
+4.166598599999999806e+02 3.097243000000000279e+02 2.435228899999999896e+02 3.129422900000000141e+02
+4.208318499999999744e+02 3.786820200000000227e+02 2.431711899999999957e+02 3.723449800000000209e+02
+3.264614700000000198e+02 3.788186200000000099e+02 1.913160300000000120e+02 3.726870000000000118e+02
+3.147381399999999871e+02 3.924176800000000185e+02 1.795927000000000078e+02 3.835896999999999935e+02
+4.289209500000000048e+02 3.902881199999999922e+02 2.436401199999999960e+02 3.814891799999999762e+02
+3.367984099999999899e+02 2.758902199999999993e+02 2.027032199999999875e+02 2.864261300000000006e+02
+4.080620700000000056e+02 2.735939700000000130e+02 2.440993699999999933e+02 2.813162800000000061e+02
+3.733191699999999855e+02 1.544756099999999890e+02 2.283057300000000112e+02 1.825394499999999880e+02
+3.746969199999999773e+02 1.133440200000000004e+02 2.360974199999999996e+02 1.435132199999999898e+02
+3.700463199999999802e+02 2.459151500000000112e+02 2.259659100000000080e+02 2.595727899999999977e+02
+3.808317799999999806e+02 2.440394200000000069e+02 2.326482000000000028e+02 2.576970600000000218e+02
+4.701611300000000142e+02 4.732083600000000274e+02 2.428049599999999941e+02 4.495536500000000046e+02
+5.705985799999999699e+02 5.078243400000000065e+02 1.030767199999999946e+02 4.655914399999999773e+02
+3.791130499999999870e+02 4.844928400000000011e+02 7.086453899999999351e+01 4.475752100000000269e+02
+2.952947100000000091e+02 4.967646399999999858e+02 1.516005400000000058e+01 4.524554299999999785e+02
+3.762500000000000000e+02 3.383506600000000049e+02 2.201671800000000019e+02 3.373724599999999896e+02
+4.003322800000000257e+02 4.686663500000000226e+02 1.770076899999999966e+02 4.441851300000000151e+02
+4.388674500000000194e+02 4.084516699999999787e+02 3.164035999999999831e+02 3.992031600000000253e+02
+4.069875999999999863e+02 2.193011199999999974e+02 2.671130600000000186e+02 2.328415200000000027e+02
+3.385181699999999978e+02 1.586970499999999902e+02 2.117699599999999975e+02 1.866726999999999919e+02
+3.301161099999999919e+02 2.897237499999999955e+02 1.969587899999999934e+02 2.989700900000000274e+02
+3.902007600000000025e+02 1.570547500000000127e+02 2.436632899999999893e+02 1.806637200000000121e+02
+3.722350200000000200e+02 1.264741499999999945e+02 2.335182900000000075e+02 1.539469799999999964e+02
+3.896441404448885351e+02 1.684838064643168423e+02 2.422721730666990538e+02 1.898239388128006908e+02
+4.044041567098074665e+02 1.982381249666137535e+02 2.636038469030721672e+02 2.123862861397337838e+02
+3.938612879491511194e+02 1.373237721272657268e+02 2.508868875006189967e+02 1.619286730267743053e+02
+3.727755504278383114e+02 1.384952019895608828e+02 2.318114483969391699e+02 1.654155812500275999e+02
+3.610612518048868083e+02 1.661409467397265303e+02 2.197098257397659609e+02 1.910546123033606705e+02
+3.111402209243742618e+02 4.100159911502355499e+02 1.790477254239340255e+02 3.982068908671177496e+02
+4.722420884743182796e+02 4.605639990494716471e+02 3.538258560724232211e+02 4.464215475977354117e+02
+4.315213324560577917e+02 4.083185007618920963e+02 2.414684863698230401e+02 3.977764028605943167e+02
+4.345945970612095266e+02 3.925680196604894832e+02 3.090551033939925105e+02 3.827093226322763257e+02
+7.209615812075318217e+02 5.667660675231618370e+02 7.917544657994415047e+01 5.005743734609804392e+02
+6.606642064527560478e+02 5.493360788255572515e+02 5.787154836054548923e+01 4.867895001037628049e+02
+6.609687386484872604e+02 5.669989461779662179e+02 5.745990518108130374e+01 4.981376862110609522e+02
+7.197434524246070850e+02 5.520768685871379375e+02 7.959145443812062126e+01 4.914611294519540934e+02
+2.786691872047798029e+02 4.833365024492061366e+02 3.202105196908468088e+01 4.448035153264768269e+02
+5.282652496148941736e+02 5.076146177348316542e+02 8.452211168355722748e+01 4.630119753430569176e+02
+3.647074203222595088e+02 5.080405495819479711e+02 2.595156529689131730e+01 4.593702833397409222e+02
+3.460325768818356664e+02 3.378550919150705454e+02 2.019291652358627971e+02 3.379567249800489890e+02
+3.951785592414013308e+02 3.276982555607603445e+02 2.298960110465138200e+02 3.285692522603899306e+02
+3.938679997118128995e+02 3.548923657997199825e+02 2.287225769565564519e+02 3.514512170145588925e+02
+3.506195352353951193e+02 3.240942168543921866e+02 2.048627504607562742e+02 3.270046734737800307e+02
diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGABS/vatutin/crossval_errors.txt b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGABS/vatutin/crossval_errors.txt
new file mode 100644
index 0000000000000000000000000000000000000000..73e5fbcaa5d23546a11a3efa73431eb8c9fcec40
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGABS/vatutin/crossval_errors.txt
@@ -0,0 +1,44 @@
+1.616226771291696451e+00
+1.754369060755345267e-02
+9.779284929175396934e-01
+6.576697555919837068e-01
+3.066270550408434770e-01
+1.612998834698591566e-01
+4.201040675301740968e-01
+8.155385261083895054e-01
+3.452953469835793321e+00
+8.020687774441207507e-01
+3.442170255501640352e-01
+8.943909552976530009e-01
+2.589561677756893943e+00
+2.425722120341224919e+00
+7.024860799446467352e-01
+4.603193840964353578e-01
+1.342611948197300231e+00
+1.205190859321352503e+00
+2.079690518854804271e+00
+2.371458487568577578e+00
+8.220301036726821442e-01
+6.511991988937511078e-01
+6.705641250402661901e-01
+1.972532867873080908e+00
+1.110895732960297089e+00
+1.907553063517399838e+00
+6.222314791948166945e-01
+8.391836101199339204e-01
+1.321392328191198562e+00
+1.251759274178413817e+00
+8.858787342394179865e-01
+1.676465379553625290e+00
+1.921688964633968766e+00
+2.509639115201356852e+00
+1.290504458418172407e+00
+8.185896321589704039e-01
+4.325984259078770044e+00
+9.623694333738862516e-01
+4.290098346935177220e-01
+1.840930682367824200e+00
+7.234322594695873354e-01
+5.028036646473726945e-01
+6.933085218108160364e-01
+2.703262904684516910e-01
diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/.DS_Store b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/.DS_Store
new file mode 100644
index 0000000000000000000000000000000000000000..abdf1339cdec7cf4dbd289a9c5df8eb4a976cc4b
Binary files /dev/null and b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/.DS_Store differ
diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/bridge/01.png b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/bridge/01.png
new file mode 100644
index 0000000000000000000000000000000000000000..0b4dea47bd01e2c4104b4f6caf1244aed61877f7
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/bridge/01.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:dc091548ebd8d946965f745cd3688ba6bc4494dec3c1098ab504e5187284e393
+size 865043
diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/bridge/02.png b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/bridge/02.png
new file mode 100644
index 0000000000000000000000000000000000000000..1c94dc870ee672c88df769e6ec1ea0e313202499
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/bridge/02.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5e56dea63ce9ecbf803681654d917245299884c81f46a865efdfd57615c1efea
+size 581008
diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/bridge/corrs.txt b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/bridge/corrs.txt
new file mode 100644
index 0000000000000000000000000000000000000000..9a58a592c33b046940d869b241be0819a7499221
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/bridge/corrs.txt
@@ -0,0 +1,29 @@
+2.538888800000000003e+02 4.090991300000000024e+02 9.853851099999999974e+01 5.648856100000000424e+02
+4.446807000000000016e+02 3.349336900000000128e+02 2.318818599999999890e+02 5.031479400000000055e+02
+4.724476900000000228e+02 4.192500200000000063e+02 2.566540499999999838e+02 5.702202499999999645e+02
+5.367548199999999952e+02 1.505413000000000068e+02 2.976936600000000226e+02 3.666731800000000021e+02
+6.043347899999999981e+02 2.343101900000000057e+02 3.508577399999999784e+02 4.230219299999999976e+02
+6.183655300000000210e+01 3.329600899999999797e+02 3.854480900000000076e+01 4.845407000000000153e+02
+4.316377100000000269e+02 3.135661099999999806e+02 1.981795400000000029e+02 4.685077499999999873e+02
+3.278198600000000056e+02 3.057405299999999784e+02 1.495681899999999871e+02 4.782585100000000011e+02
+5.888589600000000246e+02 3.169766000000000190e+02 3.457091599999999971e+02 4.851171800000000189e+02
+1.975148900000000083e+02 2.537826699999999960e+02 8.469426300000000651e+01 4.480296200000000226e+02
+4.739104947356379398e+02 3.122058046065272379e+02 2.327137685012712893e+02 4.752906051208748295e+02
+2.660098803478572336e+02 3.302083717741842861e+02 1.072749444745419396e+02 5.030318835114017020e+02
+2.654291523747069732e+02 3.964113607133099890e+02 1.093856939172994203e+02 5.476591574439880787e+02
+4.345871269560143446e+02 2.563627422044701802e+02 1.969468277504925879e+02 4.412749581862206014e+02
+4.596464188883149973e+02 2.459448343225024871e+02 2.062400571471764863e+02 4.357385662052174666e+02
+4.576754633430778654e+02 2.848008150714630347e+02 2.068332420022839813e+02 4.513591007230478453e+02
+4.796375394185772620e+02 2.369347518299899207e+02 2.117764491281796779e+02 4.303999025092500688e+02
+4.773850187954491275e+02 2.724119516442582380e+02 2.119741774132155001e+02 4.460204370270804475e+02
+5.055415265845509225e+02 2.225749328575479637e+02 2.179060259642903361e+02 4.260498802384619239e+02
+5.122990884539353829e+02 2.329928407395156569e+02 2.226515048051501822e+02 4.260498802384619239e+02
+4.408767197835879870e+02 4.253480886759214172e+02 2.329805935534139110e+02 5.728758021914462688e+02
+4.960907211153821663e+02 4.011672450316198706e+02 2.494870383241641889e+02 5.275034761405271411e+02
+1.316883174179574780e+02 3.062138219672158357e+02 6.982777892872408643e+01 4.722154469113214645e+02
+4.400374437744409875e+02 2.934501406239173775e+02 1.996894706097435801e+02 4.580269697371936672e+02
+2.510945103708573924e+02 3.324661813641471326e+02 9.511189071275245510e+01 5.078397482162217216e+02
+3.239784450498666502e+02 3.717378655408519421e+02 1.480045826113571650e+02 5.162190666884113170e+02
+4.210246903248170156e+02 1.588813566390604421e+02 1.849670454514543110e+02 4.016497702999811850e+02
+4.220810507607419026e+02 1.702372313252526226e+02 1.853919095185078447e+02 4.061816536818855639e+02
+6.723896457485888050e+02 2.257854111677311266e+02 2.913295806495771103e+02 4.216122987886223541e+02
diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/bridge/crossval_errors.txt b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/bridge/crossval_errors.txt
new file mode 100644
index 0000000000000000000000000000000000000000..1004f9035428c12151d0b31185f15df88ad36642
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/bridge/crossval_errors.txt
@@ -0,0 +1,29 @@
+3.031160677158884820e+00
+7.311932777312657450e-01
+3.773723759246768505e+00
+5.271593918761533715e+00
+7.408896420633541702e-01
+1.344973370763895837e-01
+7.609603320007013449e-01
+7.372587473637195465e-01
+3.574722318348853456e+00
+3.117516676048450286e+00
+1.765242167196014789e+00
+2.846135323394646033e+00
+2.075394853079954416e+00
+6.644829410992774132e-01
+4.763596528023891774e-01
+3.482295325426788768e+00
+2.091971221017461691e+00
+4.151898915801082723e-01
+3.310514074093689363e+00
+8.052084405660506761e+00
+8.410137898663668787e-01
+2.015056363810584728e+00
+5.752254255109191305e+00
+1.772150454723667945e+00
+3.570929124981857772e+00
+2.216294034905811061e-01
+5.777857679990370698e+00
+4.744082439652409278e+00
+5.579018075674269106e+00
diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/flood/01.png b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/flood/01.png
new file mode 100644
index 0000000000000000000000000000000000000000..9c9bfbab79ef181a08f44691db648d3b462ad6b8
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/flood/01.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:38b11253931290a4bcd08b36042f785dbbceb74dbc2ea3c634ba733fd96e3a9d
+size 621493
diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/flood/02.png b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/flood/02.png
new file mode 100644
index 0000000000000000000000000000000000000000..f8c500a2985fe355d566b920e99489140c42f97b
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/flood/02.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:31beee2ca02cff5dcf9ff514bc9f6b1a40d999b72f9990e33caa79f98f07d784
+size 727477
diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/flood/corrs.txt b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/flood/corrs.txt
new file mode 100644
index 0000000000000000000000000000000000000000..fb1a3e32f2abe80f7230b4eb6c0d52a17b241aa8
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/flood/corrs.txt
@@ -0,0 +1,30 @@
+4.985831099999999765e+02 1.960997399999999971e+02 2.980889399999999796e+02 1.978388699999999858e+02
+5.990689800000000105e+02 1.914491399999999999e+02 4.239487300000000118e+02 2.013558700000000101e+02
+4.162984799999999836e+02 2.040328700000000026e+02 6.599885299999999688e+01 1.785621299999999962e+02
+5.601637799999999743e+02 2.146042700000000139e+02 3.768953000000000202e+02 2.268169399999999882e+02
+5.480296200000000226e+02 1.210943000000000040e+02 3.666562400000000252e+02 1.103174800000000033e+02
+6.592826199999999517e+02 2.331184900000000084e+02 4.989541699999999764e+02 2.562348799999999756e+02
+6.768396000000000186e+02 2.438759400000000142e+02 5.178985900000000129e+02 2.699328499999999735e+02
+7.861383700000000090e+02 2.580815699999999993e+02 6.324438099999999849e+02 2.911928800000000024e+02
+5.068776199999999790e+02 2.486631399999999985e+02 3.205019399999999905e+02 2.688583800000000110e+02
+5.223340899999999465e+02 1.412895499999999913e+02 3.293438899999999876e+02 1.326423000000000059e+02
+6.823205199999999877e+02 2.695918699999999717e+02 5.282742500000000518e+02 3.020578800000000115e+02
+4.333758100000000013e+02 2.568404199999999946e+02 2.319244199999999978e+02 2.734003999999999905e+02
+4.219256799999999998e+02 2.166940700000000106e+02 7.549474999999999625e+01 1.996641199999999969e+02
+5.458613199999999779e+02 2.142525699999999915e+02 3.584896800000000212e+02 2.266997000000000071e+02
+5.531297799999999825e+02 1.849442500000000109e+02 3.698613100000000031e+02 1.901229199999999935e+02
+6.479109899999999698e+02 2.222157899999999984e+02 4.880514800000000264e+02 2.449804900000000032e+02
+7.624572500000000446e+02 2.641777000000000157e+02 6.088799199999999701e+02 2.968200800000000186e+02
+3.819970700000000079e+02 2.591869899999999802e+02 1.369325900000000047e+02 2.699572099999999750e+02
+5.965183347360991775e+02 2.029132062807596526e+02 4.219686410382408894e+02 2.159518907465023574e+02
+6.023373115966885507e+02 2.511254858522926554e+02 4.348509607094081275e+02 2.762665105431891561e+02
+6.259073538232618148e+02 1.517615823481111761e+02 4.452224995348522611e+02 1.489270616307932187e+02
+6.724091600182804314e+02 2.338288746764192467e+02 5.136538292206985261e+02 2.581492043719323419e+02
+6.047190589338688369e+02 1.805424661357062064e+02 4.299513950707617482e+02 1.875140898765805559e+02
+6.379103839944622223e+02 2.238172317210368192e+02 4.564564387357852979e+02 2.401400461390186081e+02
+3.975935259747074042e+02 2.453914551063811587e+02 1.686753387917444229e+02 2.514382534722063269e+02
+3.294271181369521173e+02 2.496277855369761198e+02 3.153879244165455020e+01 2.550952280415420432e+02
+3.764118738217325699e+02 2.519385112263915687e+02 1.260106354828275244e+02 2.593616983724337501e+02
+4.084209182343442990e+02 2.160644250742822692e+02 5.376696716789734865e+01 1.973480266865232977e+02
+4.087249287016336439e+02 2.210805977845561756e+02 5.349111045331716241e+01 2.047961579801880703e+02
+4.195173002904047053e+02 2.093761947939171364e+02 7.169765361560885708e+01 1.882447551053774646e+02
diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/flood/crossval_errors.txt b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/flood/crossval_errors.txt
new file mode 100644
index 0000000000000000000000000000000000000000..6402176f76ec6c4f2f202b5b6e28e2ef8798c882
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/flood/crossval_errors.txt
@@ -0,0 +1,30 @@
+9.966072664324915342e-01
+1.544877954599090109e-01
+1.216055812894111021e+00
+1.471117641364499684e-01
+2.490423304593254894e+00
+1.247993306307343175e+00
+1.160910691169436904e+00
+1.703539622639616002e+00
+3.285086051858440825e+00
+2.630431498484563591e+00
+1.403033950025595278e+00
+9.816643371412056007e-01
+9.521494556258358957e-03
+2.043210232592532982e+00
+4.043898815586809969e-01
+1.625646765139507899e+00
+2.487338704638109910e+00
+1.876527263518038113e+00
+7.103605442748321952e-01
+1.457778219094958194e+00
+3.647273024353496762e+00
+8.317520446923173383e-01
+7.234716387740043331e-01
+1.089039901934546029e+00
+5.020792994495852035e+00
+5.186133597713616261e+00
+2.199157286183403226e+00
+7.252937165707351586e-02
+7.685001472076117279e-01
+4.437382505674167255e-01
diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/kyiv_dolltheater/.DS_Store b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/kyiv_dolltheater/.DS_Store
new file mode 100644
index 0000000000000000000000000000000000000000..8f640ed9c57d69f83944dd400ae8326e21693a9b
Binary files /dev/null and b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/kyiv_dolltheater/.DS_Store differ
diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/kyiv_dolltheater/01.jpg b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/kyiv_dolltheater/01.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..bd82688419a849ed4652fc5bc94466b92e13579d
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/kyiv_dolltheater/01.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9d87f4222d67b8dd076e4a3c88c4edce0710f8c83824e4102bc72881f7924ffd
+size 684132
diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/kyiv_dolltheater/02.jpg b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/kyiv_dolltheater/02.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..9d389a772e942aa9363e946c97771b452a306b46
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/kyiv_dolltheater/02.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0c519c699a7dbab7a5cd750b8ac8a34eafc3ee09ae0ebe8d5e3363a0078a9e20
+size 252451
diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/kyiv_dolltheater/corrs.txt b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/kyiv_dolltheater/corrs.txt
new file mode 100644
index 0000000000000000000000000000000000000000..9e1be7157a9f46fe2e1a51f701b4fbef6d2761be
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/kyiv_dolltheater/corrs.txt
@@ -0,0 +1,64 @@
+1.403517121234611750e+03 5.466632954662579778e+02 9.157400829009418430e+02 4.245773770158315870e+02
+1.446026508108539019e+03 5.782981880235990957e+02 9.655740039421418714e+02 4.516423513744143179e+02
+1.398574169272527115e+03 5.654465129221791813e+02 9.247617410204694579e+02 4.400430766493074088e+02
+1.427879044896787491e+03 5.599802860549647221e+02 9.514510915071646195e+02 4.365780725750759075e+02
+1.460029932428406937e+03 6.078546003356241272e+02 9.741354618952484543e+02 4.774099392736268896e+02
+1.427409688874427957e+03 5.487157415183387457e+02 9.496363418761178536e+02 4.273228494567376856e+02
+1.463550102596102306e+03 5.620923881555819435e+02 9.792167608621792851e+02 4.391187220585413229e+02
+8.721037604001448926e+02 5.140213980344973379e+02 4.820933786500399947e+02 3.755344490808192290e+02
+8.900586507513088463e+02 5.001811700554750928e+02 4.969637109152347989e+02 3.650377439524464194e+02
+8.530266894020330710e+02 4.983108689772288358e+02 4.663483209574807802e+02 3.606641168156244248e+02
+9.132503841215623197e+02 5.286097464448180290e+02 5.249549245908956436e+02 3.882179677776030076e+02
+8.410567625012570261e+02 4.990589894085273386e+02 4.541021649743792068e+02 3.624135676703532454e+02
+8.679890980280031272e+02 5.708785508131833240e+02 4.768450260858535898e+02 4.240817102995434311e+02
+9.024026378677340290e+02 5.263653851509225206e+02 5.078977787572898137e+02 3.869058796365563921e+02
+9.076394808868235486e+02 4.766153764695723112e+02 5.131461313214762185e+02 3.453564218367474723e+02
+8.590116528524209798e+02 4.552939441775650948e+02 4.711593108079849799e+02 3.248003742936840581e+02
+9.293349733944800164e+02 4.908296646642438645e+02 5.293285517277175813e+02 3.576025778198489888e+02
+8.691112786749508814e+02 4.201322839065355765e+02 4.807812905089933224e+02 2.946223470496121308e+02
+7.972917172702948392e+02 4.710044732348335401e+02 4.173636970250743161e+02 3.348597167083745489e+02
+8.002841989954887367e+02 4.840965807825573393e+02 4.199878733071674901e+02 3.471058726914761792e+02
+8.317052571100257410e+02 4.119029591622520456e+02 4.475417242691461297e+02 2.858750927759681417e+02
+8.638744356558614754e+02 3.348465547385064838e+02 4.768450260858535330e+02 2.189585975825915227e+02
+9.970361884849378384e+02 4.624331693670503114e+02 5.835681049941797482e+02 3.363546786606245860e+02
+1.067836733301809318e+03 4.949436236196954155e+02 6.476321625989970698e+02 3.691316383654148581e+02
+1.052665187983908254e+03 5.606869866639333395e+02 6.349683372585099050e+02 4.220217324345082375e+02
+9.884432769467641720e+02 5.193298235340000701e+02 5.794862958419262213e+02 3.850917673052189230e+02
+9.930452410428733856e+02 5.395784655568805874e+02 5.831225687213482161e+02 4.016822623175819444e+02
+9.543887426355561274e+02 5.572960273269010258e+02 5.592595279501410914e+02 4.150910185604507205e+02
+9.210245029387642717e+02 5.766242765305596549e+02 5.303966119697288377e+02 4.305451782979943118e+02
+1.004089954873535362e+03 5.761640801209487108e+02 5.908496485901200685e+02 4.332723829575608079e+02
+9.631324744181634969e+02 5.782349639641978456e+02 5.683502101486963056e+02 4.337269170674885572e+02
+9.033069411687438333e+02 5.439503314481843290e+02 5.131243157924741354e+02 4.007731940977264458e+02
+9.132011639753786767e+02 5.540746524596245308e+02 5.217604638811014865e+02 4.114547456810286121e+02
+9.210245029387642717e+02 5.432600368337679129e+02 5.313056801895843364e+02 4.012277282076541951e+02
+7.784212227798731192e+02 6.446639064584231846e+02 3.978784008196220157e+02 4.852386616555929777e+02
+7.823088294167193908e+02 6.252258732741913718e+02 4.013899429846774183e+02 4.676809508303159646e+02
+7.698684881788110488e+02 6.310572832294609498e+02 3.841833863759060250e+02 4.718948014283824932e+02
+7.578169076045874135e+02 6.641019396426549974e+02 3.715418345817065529e+02 4.999871387488257142e+02
+7.480978910124714503e+02 6.252258732741913718e+02 3.631141333855736093e+02 4.662763339642938263e+02
+8.204073744578138303e+02 6.664345036247627831e+02 4.347495935527038000e+02 5.059567604294198873e+02
+8.145759645025442524e+02 6.116192500452291370e+02 4.294822803051206392e+02 4.585509412011718950e+02
+7.434327630482557652e+02 6.104529680541752441e+02 3.567933574884738164e+02 4.525813195205777220e+02
+8.122434005204364666e+02 6.843174941542561101e+02 4.565211549760472280e+02 5.266748592032467968e+02
+1.353822027092758844e+03 5.167364512312507259e+02 8.780895812280047039e+02 3.982440569932516041e+02
+1.308385355407982388e+03 4.634288600420246667e+02 8.399007632469630380e+02 3.516356570951298863e+02
+1.358473024981751678e+03 4.566312477427273961e+02 8.801944767072747027e+02 3.483279641991341578e+02
+1.334860266468403324e+03 4.949125380598226229e+02 8.585441232062116796e+02 3.783978996172771758e+02
+1.306596510066061910e+03 4.838216969399165350e+02 8.377958677676930392e+02 3.684748209292899901e+02
+1.285130365963017994e+03 4.920503855127500970e+02 8.200546058709886665e+02 3.753909060754629081e+02
+1.325558270690417885e+03 3.636112899628698756e+02 8.537329335393088741e+02 2.716496288828693650e+02
+7.619232704200128410e+02 7.187420337489038502e+02 4.538960459578931363e+02 5.910054474930707329e+02
+7.531660864475670678e+02 7.187420337489038502e+02 4.429703050095473600e+02 5.928264043177950953e+02
+7.573825083602262112e+02 7.550681302271973436e+02 4.721056142051361917e+02 6.602018068325943432e+02
+7.622476105671405548e+02 7.427432046363477411e+02 4.915291536688620795e+02 6.462411378430413151e+02
+7.732751755694796429e+02 7.424188644892201410e+02 5.060968082666565238e+02 6.450271666265584827e+02
+7.658153521855443842e+02 7.657713550824089452e+02 5.127736499573122728e+02 6.984419001518047025e+02
+7.700317740982034138e+02 7.523016972315615476e+02 5.413019735446596314e+02 6.853738526981995847e+02
+7.800863186591597014e+02 7.535990578200720620e+02 6.244590018737362698e+02 7.363606437904802533e+02
+7.917625639557541035e+02 7.516530169373063472e+02 7.913800441401308490e+02 8.219456145525227839e+02
+8.083039114592627357e+02 7.513286767901787471e+02 8.247642525934097648e+02 8.207316433360399515e+02
+5.656885711194121313e+02 7.433783793471844774e+02 1.960135294447161698e+02 6.822218008636290278e+02
+5.542464120866965231e+02 7.429296672282544023e+02 1.741398289424236907e+02 6.815968379921349651e+02
+5.486375106000713231e+02 7.507821293095298643e+02 1.544534984903604595e+02 7.325313120189016445e+02
+5.349517909727056804e+02 7.501090611311348084e+02 1.244552806586450799e+02 7.328437934546486758e+02
diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/kyiv_dolltheater/crossval_errors.txt b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/kyiv_dolltheater/crossval_errors.txt
new file mode 100644
index 0000000000000000000000000000000000000000..b8231162755761c1514e8aeba67f9e97753bb132
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/kyiv_dolltheater/crossval_errors.txt
@@ -0,0 +1,64 @@
+2.109604541263442046e-01
+6.343754191545079024e-01
+9.604626447087906138e-01
+1.112371853549305678e+00
+1.719472206398944270e-01
+1.389983618486540617e+00
+1.772333438562186669e-01
+1.043528763362485901e+00
+1.662313702257255166e+00
+6.873164968361269445e-01
+9.132861587639566903e-01
+2.079079424822051880e+00
+3.929539696712000540e-02
+5.858750028825535777e-02
+1.725362960485958386e+00
+1.498706429453000144e+00
+7.361902992256655898e-01
+1.155112462959435371e+00
+6.905814524075082339e-01
+1.952610849768650514e+00
+3.943618478419495532e-01
+5.791753657862884985e+00
+2.647108694363052184e+00
+3.032848331037355738e+00
+8.897047696970773467e-01
+3.825412735926300156e-01
+1.052341249583254479e+00
+3.080588116791681541e-01
+2.719112836909035394e-02
+1.413006610379454075e+00
+2.183674009730531551e-01
+7.847240469606918678e-01
+1.378023295799093217e+00
+3.976818240761115231e-01
+9.738282951090618256e-02
+9.351098173156836557e-01
+6.955525944570016827e-01
+5.584256764423154440e-01
+2.150741620512164332e-01
+3.337131805125549688e-02
+1.575918828991396126e+00
+2.025164692696442614e+00
+5.579441790759137376e+00
+8.872051410423215101e-01
+7.742449638499517839e-01
+4.386510970436738321e-01
+1.772402648524209701e+00
+7.499468297338771627e-01
+8.822004627697077606e-01
+1.928876693608766679e+00
+5.351425384933634621e-01
+3.008437227703353312e+00
+5.620997066040206214e+00
+5.690692638023513439e-01
+8.616311524736414151e-01
+5.594131258364883230e+00
+3.248873051656249622e+00
+6.803024422893112766e-02
+4.311064498905580855e+00
+2.259541286881223687e+00
+1.085469088010938998e+00
+1.497713528125711679e+00
+3.860337864713685008e+00
+2.744465939206676364e+00
diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/kyiv_dolltheater2/.DS_Store b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/kyiv_dolltheater2/.DS_Store
new file mode 100644
index 0000000000000000000000000000000000000000..8f640ed9c57d69f83944dd400ae8326e21693a9b
Binary files /dev/null and b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/kyiv_dolltheater2/.DS_Store differ
diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/kyiv_dolltheater2/01.jpg b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/kyiv_dolltheater2/01.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..10f1464582f2d63d89f59e7f0fa7ab13252c0bf7
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/kyiv_dolltheater2/01.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:db2f587ddd146911e65c48b5c9fa81af5e060f51f36a3b69bbe0359dd9263dd3
+size 297557
diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/kyiv_dolltheater2/02.jpg b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/kyiv_dolltheater2/02.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..bd82688419a849ed4652fc5bc94466b92e13579d
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/kyiv_dolltheater2/02.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9d87f4222d67b8dd076e4a3c88c4edce0710f8c83824e4102bc72881f7924ffd
+size 684132
diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/kyiv_dolltheater2/corrs.txt b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/kyiv_dolltheater2/corrs.txt
new file mode 100644
index 0000000000000000000000000000000000000000..df5e6ba412a763a0c7d1231c55fd8c3dc602ed9c
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/kyiv_dolltheater2/corrs.txt
@@ -0,0 +1,50 @@
+2.987967877092588651e+02 4.223206853946798560e+02 9.225885599702967284e+02 5.269050686756181676e+02
+3.444547977198159288e+02 4.320679010149111150e+02 9.675792986429532903e+02 5.294045541574324716e+02
+3.234213324340537383e+02 4.510493209069404656e+02 9.482975534975291794e+02 5.483292299483117631e+02
+2.987967877092588651e+02 4.751608542833020010e+02 9.258021841612007847e+02 5.761806396028134714e+02
+2.885365607405943820e+02 4.223206853946798560e+02 9.133047567521296060e+02 5.286904154483427192e+02
+6.066166940245287833e+02 4.446378419628800884e+02 1.350580024253208649e+03 4.783373551138423068e+02
+6.152647798956927545e+02 4.460215357022663056e+02 1.362009688343699736e+03 4.783373551138423068e+02
+6.308313344637878117e+02 4.844190369702341741e+02 1.384869016524681683e+03 5.190555334362164785e+02
+6.045411534154493438e+02 4.795761088823824139e+02 1.348794139239069409e+03 5.172696484220772390e+02
+5.723702739747195665e+02 4.771546448384565338e+02 1.301646774865794214e+03 5.208414184503557181e+02
+7.321326865086233511e+02 5.189600829123442054e+02 1.426648962845904407e+03 5.502526159436411035e+02
+6.952341867916571800e+02 5.302602484506650171e+02 1.399037887857698024e+03 5.654962302600466728e+02
+7.623433331518893965e+02 5.261091672325063655e+02 1.467777959963753347e+03 5.508278466725621456e+02
+7.457390082792546764e+02 5.480176514394549940e+02 1.446494422993677745e+03 5.775760755673869653e+02
+7.215243678399956480e+02 4.693777239176708918e+02 1.427511808939285856e+03 4.950304659672286789e+02
+2.137836007884084779e+02 3.473788774328126010e+02 8.906646694904027299e+02 4.566001456861312136e+02
+1.820175798862397301e+02 3.409722849819551129e+02 8.584382324718022801e+02 4.551989962505399490e+02
+2.153852489011228499e+02 3.943605554057680251e+02 8.946345928912448926e+02 5.000357781894623486e+02
+1.798820490692872340e+02 3.858184321379579842e+02 8.593723320955298277e+02 4.977005291301434795e+02
+4.348808888852079235e+01 4.488766776777111431e+02 7.668412146435323393e+02 5.669073554834446895e+02
+4.447047982195066851e+01 4.129908360862074801e+02 7.672572798517172714e+02 5.381466887008497224e+02
+1.941838714601171318e+02 4.986112097318912788e+02 8.817404582750726831e+02 5.998069580836389605e+02
+1.797534379770573310e+02 4.966611511530994107e+02 8.677929218487254275e+02 5.998069580836389605e+02
+1.746832856721984513e+02 5.099215494888841249e+02 8.633550693494330517e+02 6.115355682603401419e+02
+1.731232388091649455e+02 4.650702021766711596e+02 8.595511957786111452e+02 5.703269379097686169e+02
+2.035441526383181383e+02 4.814506942385228285e+02 8.899821843451870791e+02 5.823725375507049193e+02
+1.523935219787211395e+02 4.133482128927837493e+02 8.355361283176038114e+02 5.236464652993014397e+02
+1.963543895283957568e+02 4.074867638861604746e+02 8.762130226334011240e+02 5.146416566339721612e+02
+4.434197361592768516e+02 4.544329741768374902e+02 1.128685088353774518e+03 5.237216515323132171e+02
+4.355340789320897557e+02 4.587908373813356206e+02 1.123810639541448154e+03 5.276676339041963502e+02
+4.498527723182979230e+02 4.556780779495512661e+02 1.134952236826765329e+03 5.241858847525347755e+02
+4.639639484090537849e+02 4.581682854949787611e+02 1.151200399534519192e+03 5.248822345828671132e+02
+4.556632565909620780e+02 4.390766943133678524e+02 1.140987268689645362e+03 5.058486725537838993e+02
+4.550407047046052185e+02 4.463397996541980888e+02 1.140755152079534582e+03 5.132764040773286069e+02
+4.587760160227464894e+02 4.463397996541980888e+02 1.145397484281749939e+03 5.135085206874392725e+02
+4.465324955910612630e+02 4.839004301310629899e+02 1.132863187335768316e+03 5.522719945759380380e+02
+4.676992597271950558e+02 4.677140810857841871e+02 1.157699664617620783e+03 5.323099661064117072e+02
+1.889113811606345052e+02 3.089447292260057338e+02 8.693428810072523447e+02 4.208949625426010357e+02
+1.834191741745515571e+02 2.252941920533577900e+02 8.635352561421038899e+02 3.351208106881026652e+02
+8.742537370746526904e+01 5.356845514731066942e+02 7.785539052876317783e+02 6.439940125048218533e+02
+3.263139680571970302e+01 5.831726647879529537e+02 7.599659899561384009e+02 6.915572076177609233e+02
+9.107830550091497912e+01 5.159587197884783336e+02 7.823808290323510164e+02 6.254060971733283623e+02
+5.729830842288644135e+02 3.426182336336015055e+02 1.326426310886490228e+03 3.645918379055057130e+02
+5.953251222022986440e+02 4.159330672469500314e+02 1.346409458154053709e+03 4.449439870747970645e+02
+5.432431460913823003e+02 4.486905092352639031e+02 1.285043155030771914e+03 4.921488356311676284e+02
+5.521984180018566803e+02 4.625947472015266158e+02 1.291337134838287966e+03 5.055235427221392683e+02
+6.269042389392343466e+02 4.493975043860908158e+02 1.379059478405543359e+03 4.799542497541052057e+02
+3.778310515211186384e+02 4.781719669646308262e+02 1.051826775237077982e+03 5.583517754714107468e+02
+3.808586594790852473e+02 4.958330133861024933e+02 1.057766404338827897e+03 5.745507639307290901e+02
+3.427612593413391551e+02 5.003744253230523782e+02 1.032387989085895697e+03 5.815703255964338041e+02
diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/kyiv_dolltheater2/crossval_errors.txt b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/kyiv_dolltheater2/crossval_errors.txt
new file mode 100644
index 0000000000000000000000000000000000000000..a14b3cdd5a55383687b953a78e61923fd2c5f038
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/kyiv_dolltheater2/crossval_errors.txt
@@ -0,0 +1,50 @@
+1.205625952224563369e+00
+4.939422348912164695e-01
+1.408368284957418615e+00
+4.354895719375530660e+00
+5.845732642066374662e-01
+5.671421046587751258e-01
+3.241250475917350493e-01
+1.176934650482001077e+00
+9.602228893504475282e-01
+1.539524729734165653e+00
+5.977619023052603842e-01
+3.805437463550794264e-01
+2.688021655154999578e+00
+1.321323261517616254e+00
+2.544837266919131480e+00
+1.340388986541366956e+00
+1.049147679116919640e+00
+6.487360989178689863e-01
+1.221299429389517410e+00
+3.649503550987649003e+00
+5.805746815258191695e+00
+1.008171333833599137e+00
+1.074115212606320480e+00
+2.646741835059085446e+00
+3.351911533805484567e+00
+1.617215307823174131e+00
+7.151241593030583488e+00
+2.331406646868857457e+00
+7.210754257250947541e-01
+8.528073969880842764e-01
+8.421739174470552758e-01
+1.636074672456483636e+00
+1.697227528685974207e-01
+1.095537504541505142e-01
+1.492383094478506145e+00
+3.682014845518886692e-01
+3.537468933850831943e-01
+2.031705423621828210e-02
+5.066690777812540070e+00
+3.293129568230977355e+00
+2.760852141563832074e+00
+5.163656053651205724e+00
+1.366074023533330628e+00
+2.186481125400338232e+00
+4.904450483949484019e-01
+2.338855629843699102e+00
+2.778060616070257560e-01
+4.025423249494965994e-01
+7.730484211806706862e-02
+1.283781951389742604e+00
diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/rovenki/01.png b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/rovenki/01.png
new file mode 100644
index 0000000000000000000000000000000000000000..4beb4f11e66b9f66ec03e273cacef6412a9ff615
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/rovenki/01.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:dc924ed57a48a75b7f073db4741b8b7e2d803d460f1a862b835f4d99311651fb
+size 691893
diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/rovenki/02.png b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/rovenki/02.png
new file mode 100644
index 0000000000000000000000000000000000000000..5a3959270387f20c9d983f0ca3575b5fe8c42a85
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/rovenki/02.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4e26283b4d7ef459bf967312a5a9e8e5b32fa51c0efea20fcc5531b271e88266
+size 496664
diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/rovenki/corrs.txt b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/rovenki/corrs.txt
new file mode 100644
index 0000000000000000000000000000000000000000..6871c86a0e69b76610a82bb89ed95538d7bf2414
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/rovenki/corrs.txt
@@ -0,0 +1,22 @@
+3.925862300000000005e+02 2.545839400000000126e+02 2.499319999999999879e+02 3.251010099999999738e+02
+6.381495099999999638e+02 2.215317599999999914e+02 4.726627899999999727e+02 3.199427400000000148e+02
+5.014558400000000233e+02 2.289949200000000076e+02 3.575410899999999970e+02 3.167764099999999985e+02
+2.755296599999999785e+02 2.137739599999999882e+02 9.343351699999999482e+01 2.898795000000000073e+02
+2.447049199999999871e+02 3.068080900000000000e+01 7.066881899999999916e+01 1.679281799999999976e+02
+7.458263799999999719e+02 2.498548299999999927e+02 5.190902899999999818e+02 3.398734400000000164e+02
+5.115282199999999762e+02 6.731880400000000009e+01 3.293148400000000038e+02 2.256605500000000006e+02
+6.698830699999999752e+02 2.581891200000000026e+02 4.624721400000000244e+02 3.405574700000000234e+02
+6.418331899999999450e+02 3.179037400000000275e+02 4.336016300000000001e+02 3.712456199999999740e+02
+5.147312200000000075e+02 2.215231100000000026e+02 3.794336299999999937e+02 3.116288599999999747e+02
+7.477903000000000020e+02 5.901460699999999804e+01 5.822250800000000481e+02 2.456988000000000056e+02
+3.554008000000000038e+02 2.895215299999999843e+02 1.787855700000000070e+02 3.432538400000000252e+02
+6.026838400000000320e+02 2.625073800000000119e+02 4.086288700000000063e+02 3.397562100000000100e+02
+4.577935299999999756e+02 1.795010299999999859e+02 3.282694299999999998e+02 2.821538499999999772e+02
+5.587376199999999926e+02 2.643346999999999980e+02 3.724694900000000075e+02 3.408016200000000140e+02
+4.964148099999999886e+02 2.184439199999999914e+02 3.535551600000000008e+02 3.105630400000000009e+02
+3.180709299999999757e+02 1.195023400000000038e+02 1.410334100000000035e+02 2.347362300000000062e+02
+3.730920300000000225e+02 1.536518200000000078e+01 2.000334100000000035e+02 1.823751499999999908e+02
+4.079947199999999725e+02 9.514830000000000609e+01 2.322432599999999923e+02 2.305287199999999928e+02
+4.558833500000000072e+02 3.438276700000000119e+01 2.814226300000000265e+02 2.027807699999999897e+02
+4.566216911046785185e+02 2.287329680078115075e+02 3.273232766414439538e+02 3.111077682432915594e+02
+3.965061988826488459e+02 2.726309488309927360e+02 2.496704751160283706e+02 3.352331852304391191e+02
diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/rovenki/crossval_errors.txt b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/rovenki/crossval_errors.txt
new file mode 100644
index 0000000000000000000000000000000000000000..2180b0f470557d9e7eb64fd1ac1ce6aef0c09fa1
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/rovenki/crossval_errors.txt
@@ -0,0 +1,22 @@
+2.223769488158787055e+00
+2.041660990806244691e-01
+2.591624711512738433e+00
+2.980969265769438881e+00
+4.279999916905446788e+00
+3.607338159981161585e-01
+5.093382919492810856e+00
+1.723610403284629444e-01
+1.385046455528339049e+00
+1.359466349711395239e+00
+1.376832687544708378e+01
+2.418395247191616804e+00
+2.410329102105791821e-01
+3.259756579203542781e+00
+5.259098693002325575e+00
+3.273444131738968643e+00
+2.662659910154434151e+00
+5.280286776725191977e+00
+4.117411483340728540e-02
+5.288077931743116224e-03
+3.267937158693952515e+00
+1.394501496347928304e+00
diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/stadium/01.png b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/stadium/01.png
new file mode 100644
index 0000000000000000000000000000000000000000..d1b4ece1baf47d952aa3746bfa2335b5e1fa1684
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/stadium/01.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:53886fb0c41be6ded1d7d3918fe3948e6dba991d72599710b2570641b952f22c
+size 990530
diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/stadium/02.png b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/stadium/02.png
new file mode 100644
index 0000000000000000000000000000000000000000..7deb1d6677c76ec50ab4c324ada4838947c490b0
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/stadium/02.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0ecab06081c7b8611ceca3bf64630b13a85cc709389bf925696ae456c30c3553
+size 759959
diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/stadium/corrs.txt b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/stadium/corrs.txt
new file mode 100644
index 0000000000000000000000000000000000000000..fea259ddf089697453947ba96caebe9fe232bc8c
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/stadium/corrs.txt
@@ -0,0 +1,25 @@
+1.768720399999999984e+02 3.979225299999999947e+02 4.708233799999999860e+02 1.959799399999999991e+02
+1.766017400000000066e+02 4.222839299999999980e+02 4.707061499999999796e+02 2.611706399999999917e+02
+1.758983399999999904e+02 4.596776399999999967e+02 4.682442500000000223e+02 3.509988299999999981e+02
+3.112423999999999751e+02 4.943585600000000113e+02 7.164331600000000435e+02 2.956984299999999735e+02
+3.310578300000000240e+02 4.842376600000000053e+02 8.092814100000000508e+02 2.980654000000000110e+02
+3.111968499999999835e+02 5.296323700000000372e+02 7.189657799999999952e+02 4.042986500000000092e+02
+3.314909200000000169e+02 5.161572499999999764e+02 8.047335199999999986e+02 3.949218900000000190e+02
+2.031905600000000049e+02 3.987170399999999972e+02 5.484593099999999595e+02 2.075618299999999863e+02
+2.019465500000000020e+02 4.416759700000000066e+02 5.445906099999999697e+02 3.245060100000000034e+02
+3.006914100000000190e+02 4.595604099999999903e+02 7.968062599999999520e+02 3.468510499999999865e+02
+8.590342699999999354e+01 3.966785199999999918e+02 2.417377799999999866e+02 1.957473699999999894e+02
+6.083863300000000152e+01 3.960206699999999955e+02 1.838174500000000080e+02 1.984437399999999911e+02
+1.653182099999999934e+02 3.722618600000000129e+02 4.425554599999999823e+02 1.297882600000000082e+02
+1.193828900000000033e+02 4.585934100000000058e+02 3.248836699999999951e+02 3.518213600000000270e+02
+1.579033799999999985e+02 4.215282700000000204e+02 4.204320900000000165e+02 2.576441300000000183e+02
+1.842219000000000051e+02 4.180635300000000143e+02 4.893462400000000230e+02 2.474543400000000020e+02
+1.923110000000000070e+02 4.434994399999999928e+02 5.161926600000000462e+02 3.249730400000000259e+02
+3.200348999999999933e+02 4.505334399999999846e+02 8.501474100000000362e+02 3.175427300000000059e+02
+3.200348999999999933e+02 4.593259400000000028e+02 8.503818800000000238e+02 3.455614899999999921e+02
+3.247481557361764999e+02 4.943084603998072453e+02 7.579714323690594711e+02 3.001010955510471945e+02
+3.436647225311791090e+02 4.846647204651000607e+02 8.483393450439383514e+02 3.027986451831331465e+02
+3.102825458341156377e+02 4.572171529586256611e+02 8.240613983551650108e+02 3.405643400323361334e+02
+3.556460695439168944e+02 4.606928978841127105e+02 9.578515049123923291e+02 3.496653086346104260e+02
+3.555344823310228435e+02 4.505384615107533932e+02 9.556024229066407543e+02 3.187404310555252778e+02
+1.521515380648792188e+02 4.448291019826118600e+02 4.073903440174262300e+02 3.232320381043676889e+02
diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/stadium/crossval_errors.txt b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/stadium/crossval_errors.txt
new file mode 100644
index 0000000000000000000000000000000000000000..b0ba1873dfcfe6c7e870018858c8694b6d36fa0d
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/stadium/crossval_errors.txt
@@ -0,0 +1,25 @@
+2.299443554484710273e+00
+1.885525491243613239e+00
+6.086572087327464686e+00
+2.034851019204923617e+00
+8.692972926988611349e+00
+6.310085817588906743e+00
+5.782550545440895640e+00
+9.802227095767416243e-01
+2.634379550856073671e+00
+3.004014133133211306e+00
+7.441118090272435204e+00
+6.820322738563534770e+00
+2.857044245627155199e+00
+3.251653039626808406e-01
+2.033689779695354805e+00
+2.800776884471881445e+00
+2.064570116053137561e+00
+2.820266171514703490e+00
+1.978516375695342111e+00
+4.965627338948699787e+00
+5.329964473568547412e+00
+2.740790105599770765e+00
+6.842949138841069257e+00
+2.904103786325705983e+00
+2.871389478412832652e+00
diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/submarine/01.png b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/submarine/01.png
new file mode 100644
index 0000000000000000000000000000000000000000..0a577f2ae6655ebf32fa35fc811ea6030ac40d4c
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/submarine/01.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e7713564379e217f0cfee10f12ee5a9681679afe355c01c60aecf7e1f72cb095
+size 439590
diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/submarine/02.png b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/submarine/02.png
new file mode 100644
index 0000000000000000000000000000000000000000..160568f2d168e6271365d7eb44c612dd60cade56
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/submarine/02.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:83a8c09a57a979df54e364175754a043301ac7726ce5f7e2b48ebcb77b8976cd
+size 732859
diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/submarine/corrs.txt b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/submarine/corrs.txt
new file mode 100644
index 0000000000000000000000000000000000000000..a428ea7b70331b541856e006a7ff8c48aa1276d0
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/submarine/corrs.txt
@@ -0,0 +1,24 @@
+5.728948299999999563e+02 1.939825099999999907e+02 6.367427099999999882e+02 3.378577599999999848e+02
+4.687543299999999817e+02 1.892049900000000093e+02 5.377120499999999765e+02 3.363918300000000272e+02
+4.121361800000000244e+02 1.856385400000000061e+02 4.508359600000000000e+02 3.342429000000000201e+02
+5.869821899999999459e+02 9.877109900000000664e+01 5.553046600000000126e+02 2.652346800000000258e+02
+6.642043499999999767e+02 9.188649300000000153e+01 5.798319999999999936e+02 2.647494899999999802e+02
+3.670359500000000139e+02 2.320079399999999907e+02 3.654839000000000055e+02 3.947103700000000117e+02
+5.437276100000000270e+02 2.619335500000000252e+02 6.116904100000000426e+02 3.971625900000000229e+02
+4.302309099999999944e+02 2.261345200000000091e+02 4.775478600000000142e+02 3.763801399999999830e+02
+6.861916800000000194e+02 2.373991800000000012e+02 7.011787799999999606e+02 3.622585000000000264e+02
+5.732465300000000070e+02 2.929346600000000080e+02 6.364006900000000542e+02 4.162532899999999927e+02
+3.287964600000000246e+02 2.068985700000000065e+02 2.830723600000000033e+02 3.759810600000000136e+02
+5.876089626368114978e+02 8.289871306694442410e+01 5.532441718240431783e+02 2.492208339372825208e+02
+4.687863450413139503e+02 2.336412410594688822e+02 5.361196246855870413e+02 3.800426423573028387e+02
+6.047161665443178435e+02 2.697475998962042922e+02 6.569702977946423061e+02 3.925444361272050742e+02
+3.965736273678431303e+02 2.421368549034066291e+02 4.186027632485056529e+02 4.033793240611204283e+02
+3.965736273678431303e+02 2.187739168325778110e+02 4.186027632485056529e+02 3.742084719313484129e+02
+4.893174118308301104e+02 2.541723078489850991e+02 5.577894005534176358e+02 3.983786065531595000e+02
+5.751068647856137659e+02 1.106118303041440356e+02 5.480509012947979954e+02 2.721983718604265050e+02
+5.925572401533436278e+02 1.051961965693313061e+02 5.627727784119069838e+02 2.706068175774958036e+02
+6.085032728169588836e+02 1.018866426202791047e+02 5.751073241046200337e+02 2.682194861530997514e+02
+5.723990479182074296e+02 8.533887287501804053e+01 5.400931298801444882e+02 2.495187233286639525e+02
+5.648773343976341721e+02 8.714408411995560755e+01 5.257691413337681752e+02 2.491208347579312772e+02
+5.627712546118737009e+02 1.067005392734459406e+02 5.261670299045008505e+02 2.666279318701690499e+02
+6.623587416242630752e+02 8.052497622185117621e+01 5.790862098119467873e+02 2.546912747481887322e+02
diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/submarine/crossval_errors.txt b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/submarine/crossval_errors.txt
new file mode 100644
index 0000000000000000000000000000000000000000..1149a1de18179182c045e56198b0d3fec185f407
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/submarine/crossval_errors.txt
@@ -0,0 +1,24 @@
+2.396012140755679365e+00
+4.480711277911260115e-01
+2.987724645529541867e+00
+4.277000771411354485e+00
+8.508079482254223835e-01
+3.462941354139007277e+00
+4.003849617755437684e+00
+3.053093833108082134e+00
+4.069386060501227753e+00
+1.160811959330638521e+00
+1.068969419958323996e+01
+2.050613156701358797e+00
+2.598761125557944029e-01
+1.744324021645896394e+00
+1.059052432611495398e+00
+1.831625933590129485e+00
+3.212063708729713252e+00
+1.474694824562631679e+00
+3.835499616059664163e+00
+1.586777748179354131e+00
+1.313007222469624091e+00
+2.977945246196925133e+00
+7.088566085793469584e-01
+4.436947495199563996e+00
diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/submarine2/01.png b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/submarine2/01.png
new file mode 100644
index 0000000000000000000000000000000000000000..dd882a8f19e9abcc96b9e04102719bd03f6b6566
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/submarine2/01.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:62a4fd2d8ce7b7b795c5d24b12fea707c760af4fef0c06435927a9d75b5a924a
+size 837810
diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/submarine2/02.png b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/submarine2/02.png
new file mode 100644
index 0000000000000000000000000000000000000000..fb568f60b774b17ce78b9af75475e46540a94264
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/submarine2/02.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f540475363f6e65ccd0b9b37f132d1f3b2cfbd670ca39bd055f532566ba02c34
+size 1081407
diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/submarine2/corrs.txt b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/submarine2/corrs.txt
new file mode 100644
index 0000000000000000000000000000000000000000..055872e435cfe2718d5eebeb8266829fdbd0fe88
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/submarine2/corrs.txt
@@ -0,0 +1,16 @@
+5.720838800000000219e+02 2.991620199999999841e+02 4.506014900000000125e+02 3.717636600000000158e+02
+5.265157500000000255e+02 2.894499799999999823e+02 3.801778299999999717e+02 3.718808900000000222e+02
+6.661703400000000101e+02 3.181354799999999727e+02 5.339412200000000439e+02 3.701223899999999958e+02
+5.609456800000000385e+02 1.807470600000000047e+02 2.540061200000000099e+02 2.981746999999999730e+02
+4.714216400000000249e+02 2.935240999999999758e+02 2.144760200000000054e+02 3.865360900000000015e+02
+5.766269300000000158e+02 3.185162300000000073e+02 4.876987399999999866e+02 3.795601899999999773e+02
+7.010207799999999452e+02 9.238567299999999705e+01 4.310612300000000232e+02 2.756551900000000046e+02
+5.256166200000000117e+02 2.393388099999999952e+02 3.795722999999999843e+02 3.411927499999999895e+02
+5.705888999999999669e+02 2.400422100000000114e+02 4.503573499999999967e+02 3.406065800000000081e+02
+6.636008900000000494e+02 2.423675100000000100e+02 5.344295200000000250e+02 3.397956300000000169e+02
+7.031309800000000223e+02 1.568639799999999980e+02 4.430190299999999866e+02 3.019154500000000212e+02
+5.248059264207870456e+02 2.251150199168595236e+02 3.790005248323007550e+02 3.295326686123989930e+02
+5.710311283966667588e+02 2.707436201554564263e+02 4.491400054071735326e+02 3.548917961443609670e+02
+5.556861938875076703e+02 1.970657522714686820e+02 2.508518240386408706e+02 3.072043463526618439e+02
+6.934073850204836162e+02 1.585616007408017367e+02 4.255523570510792410e+02 3.026881814837404363e+02
+6.839389858118621532e+02 2.837548791659100971e+02 5.470150264116562084e+02 3.555176696442887305e+02
diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/submarine2/crossval_errors.txt b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/submarine2/crossval_errors.txt
new file mode 100644
index 0000000000000000000000000000000000000000..d2e3cfa5881603cd96fd7b37a87205da61ef67ae
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/submarine2/crossval_errors.txt
@@ -0,0 +1,16 @@
+2.089796770444086516e+00
+3.560139367682131351e-01
+1.824197790271693709e+00
+1.579644173121356632e+00
+5.693810774341417913e-01
+5.688931313970732262e+00
+2.318580807601463345e+00
+3.531501782675507961e+00
+2.223367754851889533e+00
+9.060112503001150897e-01
+3.469717531158602153e+00
+3.655183345533559169e+00
+3.059823892691055924e+00
+2.766381135592363094e+00
+3.540309193863722115e+00
+1.130637216535496786e-01
diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/tyn/01.png b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/tyn/01.png
new file mode 100644
index 0000000000000000000000000000000000000000..88e667f0235eee842466d8d368b25e2e693ed3b1
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/tyn/01.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:18b637e3a4bb1e7211f2d7af844a00ba96bc2529e99701b0e20818eb0b88493f
+size 362970
diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/tyn/02.png b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/tyn/02.png
new file mode 100644
index 0000000000000000000000000000000000000000..733b0f62a17af04175905d6df843be81c4cebcc3
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/tyn/02.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6b6ed4c9fa8fb005a7afdd4dc23cf3d902ded49ea42782b1cd99fde9511a4e71
+size 396068
diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/tyn/corrs.txt b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/tyn/corrs.txt
new file mode 100644
index 0000000000000000000000000000000000000000..a0fc970672499a59f2d23651d1ec8b8670474e96
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/tyn/corrs.txt
@@ -0,0 +1,29 @@
+1.699193999999999960e+02 2.080499700000000018e+02 1.176834999999999951e+02 3.173487400000000207e+02
+2.292048599999999965e+02 2.058467400000000112e+02 2.624721400000000244e+02 3.187410199999999918e+02
+1.994497399999999914e+02 2.146790100000000052e+02 1.887659500000000037e+02 3.375730399999999918e+02
+1.795626200000000097e+02 1.855245799999999861e+02 1.436619100000000060e+02 2.663481100000000197e+02
+2.182769199999999898e+02 1.848357000000000028e+02 2.359317700000000002e+02 2.660797499999999900e+02
+1.542284699999999873e+02 2.212005000000000052e+02 4.179829199999999645e+01 3.534005700000000161e+02
+2.153751399999999876e+02 2.377072200000000066e+02 2.079303099999999915e+02 4.203798499999999763e+02
+2.204801500000000090e+02 1.973710199999999872e+02 2.409243899999999883e+02 2.938661200000000235e+02
+1.779310399999999959e+02 1.991198400000000106e+02 1.400422100000000114e+02 2.972416799999999739e+02
+2.127223500000000058e+02 2.496951099999999997e+02 1.978181600000000060e+02 4.553668999999999869e+02
+2.283880399999999895e+02 2.374727599999999939e+02 2.469689899999999909e+02 4.207315500000000270e+02
+2.189803200000000061e+02 2.132964200000000119e+02 2.351401899999999898e+02 3.365421499999999924e+02
+1.708922000000000025e+02 2.512869200000000092e+02 7.552689499999999612e+01 4.546662699999999973e+02
+1.787323338049845063e+02 2.414601446406908281e+02 1.001912872441162392e+02 4.245428804908389111e+02
+1.981264334817288955e+02 2.411944720423792887e+02 1.559283702540789989e+02 4.242349408057009441e+02
+1.787323338049845063e+02 2.527512300689324434e+02 9.895952850356457020e+01 4.590321252262854728e+02
+1.819204049847233193e+02 2.527512300689324434e+02 1.085056587428399553e+02 4.587241855411475626e+02
+2.062294477302316977e+02 2.467735966069221831e+02 1.808714847502502323e+02 4.460986584504929624e+02
+2.145981345770460678e+02 2.407959631449119229e+02 2.070463579869730495e+02 4.288540360827697100e+02
+1.973835808866980983e+02 1.868103617409780099e+02 1.870956487489719677e+02 2.717154015176403732e+02
+1.973835808866980983e+02 1.927719464278879400e+02 1.868461254080040419e+02 2.854391852708775446e+02
+1.718061239043292971e+02 1.418526518919348405e+02 1.330619491545299979e+02 1.649613076970173893e+02
+1.828660773958816606e+02 1.412996542173572152e+02 1.584983998654981860e+02 1.625951262355319500e+02
+2.218524134536037309e+02 1.410231553800684026e+02 2.501879314980578783e+02 1.578627633125611283e+02
+1.626198460074653553e+02 2.383619460711001352e+02 5.223644950679010890e+01 4.145819853141157409e+02
+1.631010251149683086e+02 2.424519684848752945e+02 5.223644950679010890e+01 4.251941023577143710e+02
+1.877643568234823022e+02 2.458244282717295732e+02 1.269324872222991871e+02 4.371435348179112452e+02
+1.746148151192849696e+02 1.935591612449198919e+02 1.314704047950642689e+02 2.889048941075864718e+02
+2.287110309909827777e+02 1.759154723759968704e+02 2.635742274688911948e+02 2.470552098254199223e+02
diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/tyn/crossval_errors.txt b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/tyn/crossval_errors.txt
new file mode 100644
index 0000000000000000000000000000000000000000..a2c4648311a92a8c241d75278c3062eebbee5f5a
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/tyn/crossval_errors.txt
@@ -0,0 +1,29 @@
+4.094747428255010924e+00
+5.334159760209342238e+00
+2.691091772186884334e-01
+1.560500265154325140e-02
+3.124908881356045054e+00
+5.995524641871649685e+00
+1.064547777858604327e+00
+6.520279900533545003e+00
+2.616296339503305646e+00
+2.968194822467821403e+00
+3.153913249945764274e+00
+3.083370555874225261e+00
+1.787218154227150801e+00
+5.843889444021538315e-01
+3.743974298750481378e+00
+5.431708954181251325e-01
+1.817303339677495411e-01
+3.599477957313317877e-01
+4.085401545563981385e-01
+8.127999326452226558e-01
+1.391102985596567976e+00
+1.998585421031668696e+00
+9.684885938784187909e-01
+5.681982249348700442e+00
+2.961368819060811841e+00
+3.282659730971194123e-01
+1.145096508198633289e+00
+4.981046127640841981e+00
+4.610934227319352985e+00
diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/zanky/01.png b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/zanky/01.png
new file mode 100644
index 0000000000000000000000000000000000000000..34b104d65cf96f3bd7746f39cef0a5d9ba7e0c6a
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/zanky/01.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:198c31a0d634f2730279b292f5b9657568337707cb08d4afa3c8b5bae6af3200
+size 583920
diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/zanky/02.png b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/zanky/02.png
new file mode 100644
index 0000000000000000000000000000000000000000..1c8073bc28774c599bb47b82e4cac84a2f98bc3a
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/zanky/02.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b78338df059ad86b55ec507003d6b7695364429f1e8fe021d3e10fe758ea7a44
+size 546186
diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/zanky/corrs.txt b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/zanky/corrs.txt
new file mode 100644
index 0000000000000000000000000000000000000000..17367ca09d5f4a30996ca64f1360e1148a998a5f
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/zanky/corrs.txt
@@ -0,0 +1,12 @@
+6.312811599999999999e+02 1.731887600000000020e+02 4.330541999999999803e+02 2.733132499999999823e+02
+6.462870199999999841e+02 1.682649600000000021e+02 4.463015599999999949e+02 2.687411500000000046e+02
+2.651443199999999933e+02 2.370269499999999994e+02 1.882136900000000139e+02 3.347089100000000030e+02
+1.669149299999999982e+02 2.277354400000000112e+02 1.235839100000000030e+02 3.306541599999999903e+02
+7.190660699999999679e+02 1.947026299999999992e+02 5.051288099999999872e+02 2.947001999999999953e+02
+5.631307199999999966e+02 2.854144200000000069e+02 4.356277799999999729e+02 3.759208699999999794e+02
+5.389725700000000188e+02 2.043652099999999905e+02 3.457382200000000125e+02 2.980709100000000262e+02
+1.657813299999999970e+02 2.919350099999999770e+02 1.120262299999999982e+02 3.869109199999999760e+02
+2.623597700000000259e+02 2.942409499999999980e+02 1.863379600000000096e+02 3.861978399999999851e+02
+6.702347700000000259e+02 2.667869200000000092e+02 4.865651399999999853e+02 3.571412700000000200e+02
+6.644581799999999703e+02 1.676787999999999954e+02 4.610729499999999916e+02 2.699134799999999927e+02
+6.771543175316711540e+02 1.770803000551282196e+02 4.715674571393174119e+02 2.789517777805050400e+02
diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/zanky/crossval_errors.txt b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/zanky/crossval_errors.txt
new file mode 100644
index 0000000000000000000000000000000000000000..2478510d2a7fb6c2df4e60ce18608a1c138b66e9
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGALBS/zanky/crossval_errors.txt
@@ -0,0 +1,12 @@
+2.942755404534845853e-01
+9.185852889156850276e-01
+3.346748789702642668e+00
+2.162153226798289030e+00
+1.122322047852059246e+00
+6.403388640390542896e+00
+5.048211887924723307e+00
+2.993677282433861997e+00
+4.056310620435071179e+00
+2.749595016855273855e+00
+1.479053446042520203e+00
+2.570293849026595190e+00
diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGBS/kn-church/01.jpg b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGBS/kn-church/01.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..b380871be1e0c70c7176328b680573d1e880be8c
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGBS/kn-church/01.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5bf7ff2c7b8d279888a54057c2f8a485fb42a6d4be9443550fd58d215ce91670
+size 209291
diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGBS/kn-church/02.jpg b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGBS/kn-church/02.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..3e9fdbb4706c2e7db4d6014a8253625c6aa2f579
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGBS/kn-church/02.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9d4ea5306dda72998b7b3eb240098835f4f9f04732a89b166878d6add7956558
+size 233057
diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGBS/kn-church/corrs.txt b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGBS/kn-church/corrs.txt
new file mode 100644
index 0000000000000000000000000000000000000000..098d2e39bd91a58f3c335dbbdf75200eb0e7188e
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGBS/kn-church/corrs.txt
@@ -0,0 +1,50 @@
+3.435071438397347947e+02 6.959497241954786659e+02 1.781599902477694854e+02 6.095230902465996223e+02
+3.805172005450424990e+02 7.016261132607098716e+02 2.785537957219011673e+02 5.854025006196978893e+02
+3.961840343650807199e+02 7.029884466363654383e+02 3.170163575593931569e+02 5.704086205813534889e+02
+4.039039234937952187e+02 7.483995591582153111e+02 2.739904409276225010e+02 6.877520295770918892e+02
+3.955028676772529366e+02 7.309162808373031339e+02 3.163644497316391266e+02 6.551566381893867401e+02
+4.050392013068414485e+02 7.084377801389873639e+02 3.906819420956067006e+02 5.795353301699109352e+02
+3.948217009894252101e+02 7.100271690772522106e+02 3.228835280091801678e+02 5.925734867249930176e+02
+3.283572984315276813e+02 6.949318290264856159e+02 1.321443203711017418e+02 6.181286591705385263e+02
+2.433024337899045406e+02 5.431589549166819779e+02 1.129317308936844029e+02 3.525648714114240079e+02
+3.179748834014487215e+02 5.187761142271981498e+02 4.267970256861472080e+02 2.272259253061831714e+02
+3.057834630567068075e+02 5.299515828765449896e+02 3.988288311006801905e+02 2.686602876550231258e+02
+2.488901681145779321e+02 5.076006455778514237e+02 1.139675899524054330e+02 2.676244285963020957e+02
+2.605736126116221953e+02 4.689944811528353057e+02 7.460494572100719779e+01 1.661102408416440994e+02
+2.554938541346464262e+02 4.867736358222505828e+02 9.221454971926425515e+01 2.137597575428100072e+02
+2.204435206435134091e+02 4.639147226758595366e+02 4.166529727979423114e+00 2.034011669556000470e+02
+2.712411054132713275e+02 4.217527273169604314e+02 2.631312944082292233e+02 7.288292555675411677e+01
+2.930840668642672426e+02 4.293723650324241703e+02 3.594661868692821827e+02 6.148847591082312647e+01
+3.849691714495721726e+02 6.482676180430371460e+02 3.231634799028571479e+02 4.351514987440575624e+02
+4.471865534723270912e+02 7.245437354267030514e+02 6.353598906562697266e+02 5.831930225529338259e+02
+4.217611810111051227e+02 7.257402235425252002e+02 6.091756755608221283e+02 6.154197488242539293e+02
+2.859597798652843039e+02 5.743844768910154244e+02 3.473335246063475097e+02 3.928539205129500260e+02
+2.425870856667291378e+02 5.561380331247266895e+02 1.197322703151500036e+02 3.878184945330563096e+02
+2.022056117577295424e+02 5.366951012426158059e+02 3.312294346097746711e+01 3.858043241410987321e+02
+2.798073496144208434e+02 5.806978845794687913e+02 2.953217305659241561e+02 4.074589025475969493e+02
+2.585312994386380296e+02 5.667866210029953891e+02 1.757320048165831849e+02 3.942236759824524484e+02
+2.620773078012684891e+02 5.866988218085356266e+02 2.083473845664034343e+02 4.362927889930901983e+02
+2.499390484061103734e+02 5.776974159649353169e+02 1.360263251211497959e+02 4.268390557322727545e+02
+2.538995865089028712e+02 3.916563015146545581e+02 1.363968836700531710e+02 1.849164094976345041e+01
+2.400746446444624667e+02 3.985687724468747319e+02 1.254220968370807441e+02 5.580591618186952019e+01
+3.045910400118511348e+02 4.206886794299794587e+02 4.744203181256031030e+02 3.714877856581654214e+01
+2.995218946615563027e+02 4.607810108368566944e+02 4.535682231429553894e+02 1.227521158630012224e+02
+2.004431446330665381e+02 3.985687724468747319e+02 4.420867427308530750e+01 1.029974995636508766e+02
+2.124247609155815724e+02 3.815180108140648372e+02 5.518346110605773447e+01 5.361095881527501206e+01
+1.981389876556598040e+02 4.252969933847928701e+02 3.652632349000464274e+01 1.578714337285127840e+02
+2.064339527743240410e+02 5.008733422437339868e+02 3.103893007351837241e+01 2.994461838738567963e+02
+2.105814353336561737e+02 4.861267375883307977e+02 2.445405797373496171e+01 2.632293873250479237e+02
+2.474479469721640328e+02 4.497210573453043025e+02 1.034725231711362312e+02 1.468966468955404707e+02
+4.037349278460346795e+02 8.149293409312576841e+02 3.764499032891841352e+02 8.998383716029238713e+02
+3.392049459191438814e+02 7.821174857141945722e+02 1.668952560366415128e+02 8.572170874159660343e+02
+3.935267951118372594e+02 7.883152805885287080e+02 3.080782599059392624e+02 8.252511242757477703e+02
+3.268093561704755530e+02 7.795654525306451887e+02 1.180583679057523341e+02 8.616568045187741518e+02
+3.286322370158679860e+02 7.591491870622503484e+02 1.180583679057523341e+02 7.968369348177757274e+02
+4.011828946624853529e+02 7.573263062168580291e+02 2.858796743918987886e+02 7.204738006494764022e+02
+3.475901978079489254e+02 7.518576636806808438e+02 1.438087271020393700e+02 7.577674243130644527e+02
+3.402986744263793071e+02 7.423786832846403740e+02 1.278257455319301243e+02 7.364567822195855342e+02
+3.282676608467895107e+02 7.358163122412277062e+02 8.520446134497240109e+01 7.320170651167774167e+02
+3.068108708836140295e+02 6.952626313721491442e+02 1.212063200657838706e+02 6.466681953866445838e+02
+2.906379928440490517e+02 5.406954008811891299e+02 3.609134926918036399e+02 3.078481908902177793e+02
+2.981969097652342384e+02 5.297006126321925876e+02 3.754353015445329902e+02 2.742187388102130399e+02
+2.820483145245204355e+02 5.602798674497142883e+02 3.685565499827138183e+02 3.674640377593170797e+02
diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGBS/kn-church/crossval_errors.txt b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGBS/kn-church/crossval_errors.txt
new file mode 100644
index 0000000000000000000000000000000000000000..4e030d242bb4adec67a3a3b2eed901416184964f
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGBS/kn-church/crossval_errors.txt
@@ -0,0 +1,50 @@
+4.852766667173328829e-01
+5.773447457216664136e-02
+9.383705325914409867e-01
+3.297959889110597231e+00
+2.931178080395086649e+00
+1.276087969696925839e+00
+1.243149818607734325e+00
+3.525279717474206986e+00
+3.009354776434531864e+00
+2.895347087318727475e+00
+3.539998841178534139e+00
+2.350301461569856531e+00
+9.330662437729901892e-01
+7.953587772446200910e-02
+2.113451162925480187e+00
+3.793867642603225399e+00
+2.037081321239481713e-01
+1.528024863707494685e+00
+6.803045160843667283e+00
+1.168806582172539699e+01
+2.174503807975670888e+00
+2.372178250210668082e+00
+9.428519074525644195e-01
+2.516024716133358208e+00
+3.173357005405610942e-01
+2.952660548568692978e+00
+1.363268908279729741e-01
+6.417730434506982995e+00
+2.527729278270938323e-02
+1.552720762154066403e+00
+5.804936135579450429e+00
+1.385047607413002257e+00
+2.062289662207262175e+00
+1.759740112676018597e+00
+4.297790461315768695e+00
+3.494339002420721041e+00
+7.779047508699438174e-01
+6.133546870807665918e+00
+1.148380196528532515e+00
+1.967728399278650508e+00
+5.405149290397431194e+00
+7.113500860836646789e-01
+1.241186207976585854e+00
+1.504359218668552267e-02
+1.474388622498197021e+00
+3.022153888227849805e+00
+5.445628133836079299e-02
+3.960920553836733138e-01
+4.612436046632942821e-01
+2.592764738552643777e+00
diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/alupka/01.png b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/alupka/01.png
new file mode 100644
index 0000000000000000000000000000000000000000..477037d16e537445f82fed5ea99ddf3bdb1a96be
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/alupka/01.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ad67e56da8504b8bd3f95b72e736b9d06def8672ac948f05f4f2204d187034e4
+size 760888
diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/alupka/02.png b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/alupka/02.png
new file mode 100644
index 0000000000000000000000000000000000000000..18f7fb65fe31d9058c6d5070c78bb3e696a481b3
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/alupka/02.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:81140ccca5518980d3feb3ec30dd56cfe3b6fbde661d89f6e551bd4d1e889168
+size 673034
diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/alupka/corrs.txt b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/alupka/corrs.txt
new file mode 100644
index 0000000000000000000000000000000000000000..be564efc8d3bd58e5330cf5227328a0510c23fb3
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/alupka/corrs.txt
@@ -0,0 +1,34 @@
+5.632709800000000087e+02 1.313976600000000019e+02 2.893448599999999828e+02 2.536246800000000121e+02
+5.506097899999999754e+02 1.301081000000000074e+02 2.860623299999999745e+02 2.525695900000000051e+02
+4.947821900000000142e+02 1.264932299999999969e+02 2.713383099999999786e+02 2.487299399999999991e+02
+4.522330799999999726e+02 1.441868100000000084e+02 2.543481299999999976e+02 2.505078000000000031e+02
+3.927916500000000042e+02 1.576599899999999934e+02 2.347196800000000110e+02 2.504593900000000133e+02
+3.001216800000000262e+02 1.267567499999999967e+02 2.091220199999999920e+02 2.363021699999999896e+02
+6.090455600000000231e+02 1.744962800000000129e+02 3.033945299999999747e+02 2.699415500000000065e+02
+4.882655399999999872e+02 2.175810700000000111e+02 2.666683499999999754e+02 2.783823499999999740e+02
+7.410004500000000007e+02 2.281611200000000110e+02 3.389677399999999921e+02 2.931838300000000004e+02
+4.153606199999999831e+02 2.576769300000000271e+02 2.458869300000000067e+02 2.926073499999999967e+02
+5.188150200000000041e+02 1.337616899999999873e+02 2.764965799999999945e+02 2.524814000000000078e+02
+6.467946899999999459e+02 1.982946399999999869e+02 3.093734299999999848e+02 2.772100199999999859e+02
+2.725054383200736652e+02 1.379274566053054514e+02 2.004945675835598990e+02 2.362944654439720296e+02
+2.693994122349491249e+02 1.472455348606790722e+02 1.980667729146825877e+02 2.411500547817267943e+02
+4.464826675509639813e+02 1.676949277233762814e+02 2.507904124676491620e+02 2.574015523786398489e+02
+3.146731428626089269e+02 1.491177463914604004e+02 2.128540067898943562e+02 2.443481654787672142e+02
+5.137143714188497370e+02 1.730026938182093090e+02 2.736338395424262444e+02 2.639282458285761663e+02
+5.632535216372918967e+02 1.650410446759597107e+02 2.879109814641619209e+02 2.639282458285761663e+02
+2.863650570234991619e+02 1.655093765814348217e+02 2.046956399774740021e+02 2.474195504486096411e+02
+1.077564275177346786e+02 1.663604524797651720e+02 1.403882031816047515e+02 2.366652162690947137e+02
+2.072420979386583042e+02 1.381041673897987039e+02 1.768318838209144133e+02 2.330608962058662996e+02
+1.672123607278724648e+02 1.434022208441674593e+02 1.616136435539499416e+02 2.318594561847901332e+02
+2.534226267199915483e+01 2.028581540543052029e+02 1.067478825914727736e+02 2.418714563604246734e+02
+1.683897059399544105e+02 1.701348880022814569e+02 1.616136435539499416e+02 2.416358834803155844e+02
+1.370441600788676624e+02 1.634171506956903954e+02 1.503214646539008186e+02 2.360283531679565954e+02
+3.235426251126277748e+02 1.920638149445047986e+02 2.137522429888200293e+02 2.553334511431434635e+02
+2.835902178531248978e+02 1.889905528476199663e+02 2.029687883345565069e+02 2.553334511431434635e+02
+2.594431585204583826e+02 1.911857400596805689e+02 1.961305000172186510e+02 2.550704400540150800e+02
+2.541747092115129476e+02 2.052349382168683860e+02 1.924483447694213680e+02 2.600676507474543087e+02
+3.015907529920217485e+02 2.096253126409895629e+02 2.090180433845091272e+02 2.645388392626367704e+02
+4.578482005041917660e+02 2.124169269111976064e+02 2.558986741134780800e+02 2.746111778115074458e+02
+5.788114731647209510e+02 2.259128292328269083e+02 2.944315520780464226e+02 2.855384118611611939e+02
+5.383237661998331305e+02 1.989210245895683329e+02 2.823540828652712662e+02 2.731733838576056428e+02
+5.978057060618289142e+02 2.034196586967780718e+02 2.975946987766304233e+02 2.763365305561895866e+02
diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/alupka/crossval_errors.txt b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/alupka/crossval_errors.txt
new file mode 100644
index 0000000000000000000000000000000000000000..7533d4e3db031e37eed91a7e7883687beebb63ee
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/alupka/crossval_errors.txt
@@ -0,0 +1,34 @@
+1.035615702231201585e+00
+1.196879631355693974e+00
+6.112053699849899768e-01
+1.006642457920902262e-01
+3.007740766737664884e-01
+1.879998269793030863e+00
+1.811132848182660293e+00
+3.217420525980498436e+00
+6.823652600275327551e-02
+3.178430665916133491e-02
+2.209381490455120112e+00
+2.370019498457026419e+00
+6.773946495656722355e+00
+3.532300831263593288e+00
+2.910156346175509778e+00
+2.171179401819538946e+00
+2.810952725959120269e+00
+2.118168175101782680e+00
+2.200726898729873926e-01
+1.483904459267001741e+01
+1.082194779711510346e+00
+1.053720698982808024e+01
+5.646140718108199508e+00
+9.828258767995476930e+00
+1.611215092864966536e+01
+1.013231868435379290e+01
+1.197727714496128693e+00
+2.927725272338077112e+00
+3.251141531427588660e+00
+3.606444000992193200e+00
+4.698983654702942658e+00
+9.882536248838456050e-01
+1.239308029690608270e+00
+4.261074506023826203e+00
diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/berlin/01.png b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/berlin/01.png
new file mode 100644
index 0000000000000000000000000000000000000000..5a161f26bbc1b77cb2eed80c2c211fc0db7f74ba
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/berlin/01.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:bcc33b07aebaf526136256c3ef9f4ea10f45e03ebd495b44a7b0cce02946674c
+size 973773
diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/berlin/02.png b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/berlin/02.png
new file mode 100644
index 0000000000000000000000000000000000000000..da48563252d1b9c6c09d9f446098a40aa3405765
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/berlin/02.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3e11e32c52622d12eaa12787ab7635eff03683301a947f17c6db7fee3594baf2
+size 753067
diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/berlin/corrs.txt b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/berlin/corrs.txt
new file mode 100644
index 0000000000000000000000000000000000000000..b50bbd7590e4aaba33bff7d776da4e8881fd38e4
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/berlin/corrs.txt
@@ -0,0 +1,32 @@
+2.360286199999999894e+02 9.481852299999999900e+01 2.781969799999999964e+02 7.274162200000000666e+01
+4.471002300000000105e+02 1.218223899999999986e+02 4.180615500000000111e+02 8.753985900000000697e+01
+7.609214600000000246e+02 2.900361500000000206e+02 6.275490600000000541e+02 1.801840300000000070e+02
+1.597529800000000080e+02 2.783806000000000154e+02 2.159903700000000129e+02 1.841312299999999880e+02
+4.597864999999999895e+02 3.497691100000000120e+02 4.174805000000000064e+02 2.315654000000000110e+02
+4.714313300000000027e+02 2.345709699999999884e+02 4.179397500000000036e+02 1.500723600000000033e+02
+3.982929700000000253e+02 3.084224100000000135e+02 3.645750899999999888e+02 1.984153599999999926e+02
+2.304207800000000077e+02 1.884263599999999883e+02 2.663360099999999875e+02 1.280013800000000117e+02
+6.925498999999999796e+02 1.784217600000000061e+02 5.756890600000000404e+02 1.011625700000000023e+02
+6.969950800000000299e+02 3.943423099999999977e+02 5.771937299999999595e+02 2.589903800000000160e+02
+2.236513400000000047e+02 3.598176199999999767e+02 2.519356899999999939e+02 2.340100099999999941e+02
+4.623742799999999988e+02 4.189536899999999946e+02 4.180279300000000262e+02 2.751502500000000282e+02
+4.317743100000000140e+02 7.095999299999999721e+01 4.198057999999999765e+02 6.238532299999999964e+01
+4.414981000000000222e+02 3.498863400000000183e+02 4.029435700000000224e+02 2.332066700000000026e+02
+4.704547299999999836e+02 3.528171699999999760e+02 4.251006600000000049e+02 2.340272999999999968e+02
+2.306552499999999952e+02 1.557182699999999897e+02 2.644602800000000116e+02 1.066649199999999951e+02
+2.063879599999999925e+02 1.704896699999999896e+02 2.509784500000000094e+02 1.153401900000000069e+02
+2.550397800000000075e+02 1.704896699999999896e+02 2.761836099999999874e+02 1.149884899999999988e+02
+7.186847736100646671e+02 1.612839620694029463e+02 5.970815920021781267e+02 8.742081692709524532e+01
+6.520464260002647734e+02 1.600024553845991022e+02 5.450039535103417165e+02 9.082776523964528792e+01
+7.242379692442146961e+02 1.873412646604144527e+02 5.995151265111425118e+02 1.066457395479133794e+02
+4.569601499328094860e+02 4.444818995056880340e+02 4.138725246013769379e+02 2.920795103148053613e+02
+4.246167624506033462e+02 3.411289695813454728e+02 3.828439599583084032e+02 2.230627029778865733e+02
+5.301583426556967424e+02 3.416153363104012897e+02 4.538906733933718556e+02 2.221927432215388762e+02
+6.667335001808941115e+02 1.035710570582327819e+02 5.605959728884329252e+02 5.568705899556496774e+01
+2.431912216314050852e+02 3.232394024770093779e+02 2.702069573547900063e+02 2.114583240039939938e+02
+6.277595469798195609e+02 1.869733877933344957e+02 5.310581761039995854e+02 1.130581252647631914e+02
+7.397337531878540631e+02 2.504353068495012735e+02 6.119037385776150586e+02 1.497963717166607296e+02
+4.698633348550070536e+02 2.115440656563695256e+02 4.163448779995451901e+02 1.327446834734667220e+02
+4.874772549326754074e+02 1.687674026106035399e+02 4.422008912230889450e+02 1.149003926572182763e+02
+3.245484942142433056e+02 1.505244139587327936e+02 3.507944219398570453e+02 1.152645618575498929e+02
+1.490383620117622741e+02 3.789503133324984105e+02 2.135026334148435865e+02 2.485933319605100849e+02
diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/berlin/crossval_errors.txt b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/berlin/crossval_errors.txt
new file mode 100644
index 0000000000000000000000000000000000000000..0224275375750f200ecedbc5425595014ed1e3c4
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/berlin/crossval_errors.txt
@@ -0,0 +1,32 @@
+5.191993455264680424e-02
+1.465544064612110287e+00
+5.936445158653792520e-01
+9.030548139514468220e-01
+4.941321813867107782e-01
+8.177992914893557064e-01
+2.908001862198693832e+00
+1.591621936844796747e-01
+3.293545072056112133e-01
+1.623973304517367655e+00
+1.129899579273499599e+00
+3.384341811700881220e+00
+3.831337751029900041e+00
+5.062077129708542955e+00
+1.124576558576982688e+00
+9.286197303745061804e-01
+3.181643372606760334e+00
+3.213996797500379365e+00
+1.887924995137213013e+00
+1.325644623510203290e+00
+1.121749737644925915e+00
+3.512606009773121318e+00
+2.341729310650397800e+00
+2.537510842859704852e+00
+3.453915924287136896e+00
+2.267845496619912193e+00
+2.052534989207430272e+00
+3.696536742617343219e+00
+3.207315167365495157e+00
+8.871948800972830895e-01
+1.100981285460365527e-01
+2.052016565246739255e+00
diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/charlottenburg/01.png b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/charlottenburg/01.png
new file mode 100644
index 0000000000000000000000000000000000000000..615009b8f25cc3e32c086e37e6f4382cb14afbf1
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/charlottenburg/01.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4ffd41d1811e888a0a0c842b67735eb00aa5184558a48949783f107fb41616df
+size 402968
diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/charlottenburg/02.png b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/charlottenburg/02.png
new file mode 100644
index 0000000000000000000000000000000000000000..24681732835bf9bcf421dac2d071a0b5addaae7f
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/charlottenburg/02.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:81c79e7aad63e61801dcbe3d2acd0fdb5d6389f1d36c3afcfd651926c1848ebb
+size 607963
diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/charlottenburg/corrs.txt b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/charlottenburg/corrs.txt
new file mode 100644
index 0000000000000000000000000000000000000000..701ba26d2e27c0c1087836983d9f5bd3bd59174a
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/charlottenburg/corrs.txt
@@ -0,0 +1,40 @@
+1.958494000000000028e+02 6.222665799999999692e+00 3.766220999999999890e+02 4.925415799999999678e+01
+3.864019000000000119e+02 2.038731400000000065e+02 4.052995300000000043e+02 1.847192300000000103e+02
+3.676445800000000190e+02 1.994182800000000100e+02 3.763429100000000176e+02 1.814366899999999987e+02
+2.598528599999999855e+01 2.009665200000000027e+02 1.552787399999999991e+02 2.763817199999999730e+02
+1.751416399999999953e+02 1.976200100000000077e+02 1.787758900000000040e+02 2.392446999999999946e+02
+2.103825199999999995e+02 1.999307900000000018e+02 1.993615700000000004e+02 2.291325500000000090e+02
+1.947942999999999927e+02 1.565561999999999898e+02 3.251428599999999847e+02 2.015612900000000138e+02
+1.921269800000000032e+02 2.566903699999999731e+02 1.854097700000000088e+02 3.081927499999999895e+02
+2.822226699999999937e+02 2.667928400000000124e+02 3.187434400000000210e+02 3.170346900000000119e+02
+3.175098899999999844e+02 2.493250800000000140e+02 3.567270300000000134e+02 2.835059699999999907e+02
+2.488111800000000073e+02 2.758197999999999865e+02 2.940072200000000180e+02 3.350886199999999917e+02
+2.768154099999999858e+02 3.384818599999999833e+02 3.104392500000000155e+02 4.314336299999999937e+02
+2.413777599999999950e+02 2.045910700000000020e+02 2.937018600000000106e+02 2.384143899999999974e+02
+3.662948400000000220e+02 2.708717899999999759e+02 3.770028500000000236e+02 3.046951099999999997e+02
+1.327947999999999951e+01 2.718929999999999723e+02 1.429682099999999991e+02 3.506149399999999901e+02
+1.903684800000000052e+02 3.095667300000000068e+02 1.749459200000000010e+02 3.826023400000000265e+02
+1.946819099999999878e+02 9.647915500000000577e+01 3.337697099999999750e+02 1.337759000000000071e+02
+1.952680799999999977e+02 5.459955200000000275e+01 3.610968199999999797e+02 9.506643599999999594e+01
+1.860018000000000029e+02 1.009136499999999970e+01 3.681813099999999963e+02 5.652262199999999837e+01
+3.155240499999999884e+02 2.760003899999999817e+02 3.526113700000000222e+02 3.268587299999999800e+02
+3.825456899999999791e+02 2.486293200000000070e+02 3.975715099999999893e+02 2.614097300000000246e+02
+2.244132299999999987e+02 9.858509200000000305e+01 3.674812400000000139e+02 1.278130100000000056e+02
+2.366054900000000032e+02 1.043295200000000023e+02 4.054648300000000063e+02 1.343780800000000113e+02
+2.180582805126952053e+02 2.673983552561883243e+02 2.153567351223074979e+02 3.225453098178572873e+02
+2.191047260962589576e+02 2.520504866972534046e+02 2.158411780145567604e+02 3.031675941278855930e+02
+2.187559109017377068e+02 3.092561785987379608e+02 2.076056488463187861e+02 3.816473426722710656e+02
+2.476600803710741729e+02 2.496246586406160759e+02 2.962334247218788050e+02 2.989252798449094826e+02
+2.762871664014870703e+02 2.492479864560053784e+02 3.184615294427110257e+02 2.939488384894992805e+02
+2.864573153859758463e+02 2.575347745174407237e+02 3.257603100973126402e+02 3.042334839573470049e+02
+2.480367525556848705e+02 2.571581023328300262e+02 2.959016619648515416e+02 3.098734508268119043e+02
+3.039725719703732238e+02 2.496246586406161327e+02 3.430119734627346588e+02 2.879771088630070608e+02
+1.931669929845307934e+02 1.731478943608864256e+02 3.276824680474070419e+02 2.227384219525362141e+02
+1.931669929845307934e+02 1.671728696567184898e+02 3.248873840844720462e+02 2.149121868563181579e+02
+2.208014822413075535e+02 1.753885286249494015e+02 3.623415091878012504e+02 2.193843211970141738e+02
+2.252827507694335338e+02 1.682931867887499777e+02 3.612234756026272180e+02 2.093220189304481664e+02
+2.472352169056881053e+02 3.121931499331572013e+02 2.880808778718156304e+02 3.891075898947697169e+02
+2.754026383392823618e+02 3.121931499331572013e+02 3.105887966431296832e+02 3.881498061172669622e+02
+2.858350166480209396e+02 3.124539593908756387e+02 3.211244181956597004e+02 3.881498061172669622e+02
+2.521905966023389283e+02 3.234079566150512051e+02 2.943064724255833653e+02 4.068265897785701100e+02
+3.922452753971548418e+02 2.490772611652885473e+02 4.145083365021756663e+02 2.622012393756584743e+02
diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/charlottenburg/crossval_errors.txt b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/charlottenburg/crossval_errors.txt
new file mode 100644
index 0000000000000000000000000000000000000000..9a806d88120fb851486d0a0f937bf4095df7b301
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/charlottenburg/crossval_errors.txt
@@ -0,0 +1,40 @@
+1.619268508549998931e+00
+3.886848298702200832e+00
+2.588669894928778081e+00
+4.542110231958436106e-01
+3.446516145153667221e+00
+3.679736430290312832e+00
+8.399536963373055790e-03
+2.585070026123002407e+00
+6.175314108689570203e-01
+9.191016794758739561e-01
+1.035661036366828158e+00
+1.060660686176270406e+00
+6.119007703997929593e-01
+1.708246958790519354e+00
+6.962075281211588251e-01
+6.766506035606598690e-01
+1.950827625135682419e+00
+1.401927986179450983e-02
+7.156366045273916676e-02
+3.324444405591722163e-01
+6.130206632473679251e-01
+6.648246702775817418e-01
+1.672542182573007796e+00
+3.336871484617892625e-01
+3.435317825533382941e+00
+1.453007621079772382e-01
+3.963916321132272130e-02
+1.382163780362680061e+00
+1.401298777655322736e+00
+5.694420338034895668e-01
+4.020510732517361130e-01
+1.105574600334674251e+00
+5.122245513220390345e-01
+1.421834385649536903e-01
+9.107395820013965970e-01
+5.779344119513706302e-01
+4.792392571718611105e-01
+6.734883886232745365e-01
+1.843982514331961031e-01
+2.359130939770464508e+00
diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/church/01.png b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/church/01.png
new file mode 100644
index 0000000000000000000000000000000000000000..69f2e2915d7f69643463008d672b15481da85b6a
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/church/01.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e7ca21ddaf898862d1241f55198ea8bf660b77213ca0a3a6929064358fe74e34
+size 493636
diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/church/02.png b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/church/02.png
new file mode 100644
index 0000000000000000000000000000000000000000..05dd88b5d711c1c21952cd970a1d579ae5db92f0
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/church/02.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:729bbec0906acedc9a489d40c73d13d101777bb9d2c2ee48b4fa901762063407
+size 341760
diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/church/corrs.txt b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/church/corrs.txt
new file mode 100644
index 0000000000000000000000000000000000000000..6bdfad54d2f1b131de03c827d8c1db96833e32b3
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/church/corrs.txt
@@ -0,0 +1,32 @@
+5.782488299999999981e+02 3.321297799999999825e+02 2.645256800000000226e+02 2.530823000000000036e+02
+5.151237200000000485e+02 3.282901299999999765e+02 1.738835500000000138e+02 2.547565199999999948e+02
+5.437307200000000194e+02 2.607585700000000202e+02 2.158755600000000072e+02 1.678416599999999903e+02
+4.656277600000000234e+02 1.522212900000000104e+02 1.256151799999999952e+02 5.771545400000000114e+01
+5.945356100000000197e+02 1.904791199999999947e+02 3.027003100000000018e+02 7.952084600000000592e+01
+6.255754600000000210e+02 2.941043500000000108e+02 3.369732999999999947e+02 2.012910899999999970e+02
+4.581635600000000181e+02 2.818820000000000050e+02 9.779877999999999361e+01 2.046366099999999904e+02
+4.146572199999999953e+02 3.733990099999999757e+02 6.323125600000000190e+01 3.186344300000000089e+02
+5.010740599999999745e+02 4.452488299999999981e+02 4.459785200000000316e+01 4.015304100000000176e+02
+5.371559700000000248e+02 3.425732199999999921e+02 2.075248799999999960e+02 2.702041899999999828e+02
+5.083822900000000118e+02 3.983911400000000071e+02 1.642316299999999956e+02 3.398770099999999843e+02
+5.221781200000000354e+02 3.980588000000000193e+02 1.852397300000000087e+02 3.401686300000000074e+02
+5.729335600000000568e+02 3.989966699999999946e+02 2.583335900000000152e+02 3.389691799999999944e+02
+6.058200900000000502e+02 3.091112499999999841e+02 3.082761899999999855e+02 2.216287499999999966e+02
+4.797849699999999871e+02 2.989894100000000208e+02 1.290178600000000131e+02 2.216258300000000077e+02
+6.343098599999999578e+02 3.721481699999999933e+02 3.581887500000000273e+02 3.003527300000000082e+02
+5.014257600000000252e+02 4.854920099999999934e+02 4.433334700000000339e+01 4.526695399999999836e+02
+3.656204999999999927e+02 3.221445400000000063e+02 2.374288500000000024e+00 2.647601399999999785e+02
+4.079351500000000215e+02 2.986570800000000077e+02 8.063225099999999657e+01 2.342744999999999891e+02
+5.545559500000000526e+02 2.884277000000000157e+02 2.333967100000000130e+02 2.012610500000000116e+02
+5.073271899999999732e+02 2.014827899999999943e+02 2.081594499999999925e+02 1.116769400000000019e+02
+4.988272600000000239e+02 2.918962799999999902e+02 1.504106199999999944e+02 2.125329600000000028e+02
+6.400349200000000565e+02 4.073600000000000136e+02 3.692232700000000136e+02 3.471842700000000264e+02
+4.092644799999999918e+02 3.879359299999999848e+02 5.773772699999999958e+01 3.332048700000000281e+02
+4.194637799999999856e+02 3.879359299999999848e+02 6.793702399999999386e+01 3.333220999999999776e+02
+3.642394613790673361e+02 3.767043016888868578e+02 6.488157981007134367e-01 3.242811770475954631e+02
+3.751359519042346164e+02 3.569291892543239442e+02 1.372472163530062517e+01 3.012426762868147421e+02
+3.844181475367845451e+02 3.476469936217740155e+02 2.368731655888154819e+01 2.887894326323386167e+02
+4.130718818807430353e+02 3.658078111637195207e+02 6.166970970503348326e+01 3.087146224795003491e+02
+5.707847220192514897e+02 2.695261601368335391e+02 2.558068126731624261e+02 1.733098156215094718e+02
+6.054041967536991251e+02 2.711417356244410826e+02 3.075495409571985874e+02 1.757406820509608281e+02
+6.176364111598705904e+02 2.692953636386038738e+02 3.252601392289156479e+02 1.715734824576156257e+02
diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/church/crossval_errors.txt b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/church/crossval_errors.txt
new file mode 100644
index 0000000000000000000000000000000000000000..38d7fa5119ff211ac4f06a1f7cd1eee30ea7534f
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/church/crossval_errors.txt
@@ -0,0 +1,32 @@
+8.715988390665808661e-01
+1.187855777575472560e+00
+1.488191769406739251e+00
+1.540662910329874125e+00
+2.503540793373755236e+00
+2.841651356896309100e+00
+2.629872184775364219e-01
+3.090908285052762494e+00
+2.992049923663860245e+00
+9.722442946457870994e-01
+4.463918797905627178e-01
+1.835074225796184066e+00
+2.557828320585105786e+00
+1.158792722156904764e+00
+1.650155615904922879e+00
+1.934096082935798488e+00
+4.625630744759416935e+00
+4.197963781019461216e+00
+5.461443548326146979e-01
+1.679977028297218733e-01
+9.915869307461986359e-01
+1.248967533235761707e+00
+3.646619983191577319e+00
+9.866973171593411696e-01
+2.850378296718754090e-01
+1.203073091924289351e+00
+1.284716092338700877e+00
+4.604175489226451368e-01
+1.112167057576260554e+00
+3.792437685853049967e+00
+3.502812721344579217e+00
+3.510454335316306018e+00
diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/him/01.png b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/him/01.png
new file mode 100644
index 0000000000000000000000000000000000000000..66146c46a4400256843da6a8287d1d2556dea107
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/him/01.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0fe6a0defd25979bb53bd4d7842c3b6e2730487764122463bcbee3e45fdf8f4d
+size 597486
diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/him/02.png b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/him/02.png
new file mode 100644
index 0000000000000000000000000000000000000000..ec411be52ea5df3a9af7358be8640d44f61370f6
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/him/02.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:430b8c56acf2d7032aa3b4774784d82a5846ac3cdd2af66fa4b2386fd477900e
+size 705899
diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/him/corrs.txt b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/him/corrs.txt
new file mode 100644
index 0000000000000000000000000000000000000000..e1f60a837fe1399550c2af4081806d6ae82fffd4
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/him/corrs.txt
@@ -0,0 +1,35 @@
+3.223486300000000142e+02 2.839303400000000011e+02 2.607185000000000059e+02 2.346901599999999917e+02
+3.738558800000000133e+02 3.103271899999999732e+02 2.862871099999999842e+02 2.504682400000000086e+02
+5.751319499999999607e+02 2.281521899999999903e+02 4.111983200000000238e+02 2.059368800000000022e+02
+6.758038699999999608e+02 2.683759999999999764e+02 4.804883699999999749e+02 2.325014600000000087e+02
+6.953925500000000284e+02 1.296883400000000108e+02 5.107850500000000125e+02 1.445208199999999863e+02
+4.742061899999999923e+02 2.980799700000000030e+01 3.646923199999999952e+02 6.850869000000000142e+01
+2.896083800000000110e+02 7.866834099999999808e+01 2.492576500000000124e+02 9.834715400000000329e+01
+1.757568699999999922e+02 9.239535600000000670e+01 1.770270700000000090e+02 1.072966499999999996e+02
+9.089311700000000371e+01 1.456234200000000101e+02 1.333346799999999917e+02 1.387957499999999982e+02
+4.809985499999999803e+01 3.841322900000000118e+02 1.161294000000000040e+02 2.905360800000000268e+02
+1.514875099999999861e+02 5.156977799999999661e+02 1.701780899999999974e+02 3.710622099999999932e+02
+2.158129700000000071e+02 4.196958200000000261e+02 2.132144700000000057e+02 3.131641799999999876e+02
+1.912347800000000007e+02 3.351159799999999791e+02 1.951712599999999895e+02 2.637563999999999851e+02
+4.668785899999999742e+02 3.555844299999999976e+02 3.098036299999999983e+02 2.781965000000000146e+02
+5.274342500000000200e+02 3.454529099999999744e+02 3.402573100000000181e+02 2.732823900000000208e+02
+4.883730899999999906e+02 1.831799100000000067e+02 3.679458000000000197e+02 1.750723600000000033e+02
+1.623815600000000074e+02 2.169064400000000035e+02 1.714579799999999921e+02 1.866988599999999963e+02
+5.728382100000000321e+01 2.690503600000000120e+02 1.178298000000000059e+02 2.197908099999999934e+02
+6.048134099999999762e+02 2.930939000000000192e+02 4.289413499999999999e+02 2.442742399999999918e+02
+6.595461400000000367e+02 4.517810799999999745e+02 4.601274799999999914e+02 3.389769499999999880e+02
+7.747463400000000320e+02 3.500443799999999896e+02 5.464174000000000433e+02 2.836386800000000221e+02
+3.286211299999999937e+02 3.751343800000000215e+02 2.635514600000000200e+02 2.878494000000000028e+02
+4.516448399999999879e+02 1.776312200000000132e+02 3.483066400000000158e+02 1.717898299999999949e+02
+2.345733999999999924e+02 5.146814100000000280e+02 2.136747499999999889e+02 3.677011800000000221e+02
+6.956982956609938356e+02 4.448380497371611000e+02 4.897555306141204028e+02 3.345909281372238979e+02
+6.850766939846554351e+02 3.374418550097398111e+02 4.905243322592649520e+02 2.730867965256645675e+02
+7.794909311076631866e+02 4.082525328519956247e+02 5.458780507096682868e+02 3.153708870086115894e+02
+7.755681227403512139e+02 3.093362791412923798e+02 5.515505328082300593e+02 2.611494394884711028e+02
+7.633183634593142415e+02 2.678232060222228483e+02 5.455192559794936642e+02 2.326013958324519422e+02
+8.911064507318195638e+01 3.751099451192026208e+02 1.393658564074330002e+02 2.864113796803452487e+02
+1.193134095552586729e+02 3.668285419547622155e+02 1.526517208245552411e+02 2.806802224808023993e+02
+1.422090535981232620e+02 4.822810448943134816e+02 1.659375852416775388e+02 3.491936017298837100e+02
+1.378247813345959969e+02 4.939724375970528740e+02 1.638535280782073755e+02 3.562272946565955181e+02
+7.790639373305670290e+01 4.257726468310732457e+02 1.193068062090328141e+02 3.174117299869639055e+02
+5.257504287712141888e+01 4.272340709189156769e+02 9.820572742889750373e+01 3.174117299869639055e+02
diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/him/crossval_errors.txt b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/him/crossval_errors.txt
new file mode 100644
index 0000000000000000000000000000000000000000..6bfcc6f9edeb89e6ef7afa5a15d6581b2c8d85cb
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/him/crossval_errors.txt
@@ -0,0 +1,35 @@
+2.043379849940813120e+00
+8.169318573230145430e-01
+3.638987395587739204e+00
+2.649559860979942894e+00
+9.965473481391702304e-01
+7.991661399239784025e+00
+3.273218782217096612e-01
+6.738468738498821331e+00
+7.367010900988238964e-01
+1.904385438409952336e-01
+7.762423758884946956e+00
+2.328785188654431959e+00
+1.671697819648116257e+00
+1.777887754345692795e+00
+1.656063175527794584e+00
+2.665111671711846153e+00
+2.313613197174701597e+00
+1.391896839663567587e+00
+1.051739843204129166e+00
+3.623669403787067367e+00
+3.950511428855001395e+00
+2.630721197748733697e+00
+5.315878363056087963e+00
+1.207551898246299604e+00
+2.257441047339381335e-01
+4.487693040202633266e-01
+1.532589474507056217e-02
+7.002406873062206216e+00
+3.276273526147291948e-01
+1.302720833184545457e+00
+9.656081119072771335e-01
+1.766602521388588656e-02
+9.723366725220530249e-01
+4.123774285849783894e+00
+3.611205377490628088e+00
diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/maidan/01.png b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/maidan/01.png
new file mode 100644
index 0000000000000000000000000000000000000000..71ba98f2229cb32a4c44658dad6063580745c0d1
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/maidan/01.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d1c8df896d59d5afe51af5c2cd9e94c20fc41272179aa8118a4056092f92c74c
+size 822512
diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/maidan/02.png b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/maidan/02.png
new file mode 100644
index 0000000000000000000000000000000000000000..c3c8bb78563d97117b17de9da23a6b2e07e89788
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/maidan/02.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d3563b4ef12305c69c89249264f4d66c72c7cd0558da2fbf5e6c73a2d91562f2
+size 697851
diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/maidan/03.png b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/maidan/03.png
new file mode 100644
index 0000000000000000000000000000000000000000..51128df81280589a20ea36fb4cb8c350933e0036
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/maidan/03.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f1ed226c9076112816b88b2495fbdfa6ef9b8cdba1bf89980e31c3c71722247c
+size 730090
diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/maidan/corrs.txt b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/maidan/corrs.txt
new file mode 100644
index 0000000000000000000000000000000000000000..9c73430138216837f251073fca764b5cd41e7aa3
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/maidan/corrs.txt
@@ -0,0 +1,27 @@
+4.280615799999999922e+02 1.554128499999999917e+02 4.518512200000000121e+02 2.504033300000000111e+02
+4.847582300000000259e+02 1.798080899999999929e+02 5.348963999999999714e+02 2.607818399999999883e+02
+2.941127000000000180e+02 1.590083500000000072e+02 2.959175900000000183e+02 2.538389400000000080e+02
+2.353940299999999866e+02 1.586372800000000041e+02 2.221174000000000035e+02 2.382013699999999972e+02
+2.054092899999999986e+02 1.781485000000000127e+02 1.566796099999999967e+02 2.132441300000000126e+02
+1.970556400000000110e+02 1.930091099999999926e+02 1.424913999999999987e+02 2.384619700000000080e+02
+5.121046999999999798e+02 1.857589800000000082e+02 6.022913099999999531e+02 2.330498299999999858e+02
+2.384227300000000014e+02 3.142075800000000072e+02 9.785242499999999666e+01 3.494049200000000042e+02
+4.045547500000000127e+02 2.504769400000000132e+02 4.488389900000000239e+02 3.224188800000000015e+02
+3.780600299999999834e+02 2.505941699999999912e+02 4.038214100000000144e+02 3.221844100000000140e+02
+2.540943000000000040e+02 1.889834099999999921e+02 2.462995600000000138e+02 2.800431800000000067e+02
+4.550273099999999999e+02 1.833949499999999944e+02 4.811752000000000180e+02 2.812610599999999863e+02
+4.109541699999999764e+02 1.833164400000000001e+02 4.311599299999999744e+02 2.832995700000000170e+02
+3.209515099999999848e+02 1.857987500000000125e+02 3.270717599999999834e+02 2.852663999999999760e+02
+3.323552999999999997e+02 2.624831500000000233e+02 3.413861299999999801e+02 3.458199900000000184e+02
+4.950069700000000239e+02 2.153823399999999992e+02 5.528816199999999981e+02 3.004902900000000159e+02
+3.173656899999999723e+02 3.218481499999999755e+02 2.990111999999999739e+02 3.575231400000000122e+02
+3.194274800000000027e+02 2.037354400000000112e+02 3.261338900000000081e+02 3.078924200000000155e+02
+2.290635804859078348e+02 1.900059737384212326e+02 2.118688394443077527e+02 2.747577352684748462e+02
+2.183714045427048518e+02 1.799618084584426754e+02 2.000627579024160241e+02 2.638261782852417809e+02
+2.433198150768451455e+02 1.994021283551753640e+02 2.306711174554686181e+02 2.909364396036597782e+02
+2.373870513700753691e+02 2.957831966989249395e+02 9.725213989465640907e+01 3.038423278124165563e+02
+5.110142288247588453e+02 1.927576376379860790e+02 6.016781362971241833e+02 2.409929073642294952e+02
+4.989069086051293880e+02 1.915706454595910202e+02 5.583054420985690740e+02 2.691103780860514689e+02
+2.920441067044379224e+02 3.204059275408216081e+02 2.264041560013508274e+02 3.558669926242116617e+02
+3.894701654697320805e+02 2.555707881893844444e+02 4.243882654767081135e+02 3.327558240056685008e+02
+3.573868057796908602e+02 3.125699484897768912e+02 4.077506998783375707e+02 3.483199982751119705e+02
diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/maidan/crossval_errors.txt b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/maidan/crossval_errors.txt
new file mode 100644
index 0000000000000000000000000000000000000000..214a9e6dc7aa32cf5b5de3fbddc9a24173f7062f
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/maidan/crossval_errors.txt
@@ -0,0 +1,27 @@
+9.619610733823423521e-02
+2.222023173583103528e+00
+3.603161857524858069e-01
+3.100090121775084251e-01
+3.417931173654581212e+00
+2.637932478789538404e+00
+4.066589147187809616e-01
+5.500156820911586308e+00
+3.241004267474435263e+00
+4.430177782216134119e+00
+1.294956167617228537e+00
+4.710739887513437196e+00
+2.130034143130785651e-01
+1.209101134344288342e+00
+5.132132907292509927e+00
+1.079070568921344231e+00
+9.880345585358989879e+00
+3.341722142539176765e+00
+2.132997868911963302e-01
+3.638139721764838130e-01
+8.544621529274050165e-01
+8.368403497652689538e+00
+1.494648836718671658e+00
+1.149710096719515517e+00
+7.077658455701806517e+00
+2.806159435480652053e+00
+8.022198499442469100e+00
diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/ministry/01.png b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/ministry/01.png
new file mode 100644
index 0000000000000000000000000000000000000000..b2b3c9292b9ab6fd8e3be009edb3c0ff36a2dec7
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/ministry/01.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:24ad7fea46687df48cfff5c5d8b7fec4c8394a3fe8aba5958443ad1e30840bd0
+size 701851
diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/ministry/02.png b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/ministry/02.png
new file mode 100644
index 0000000000000000000000000000000000000000..92f4b7dfa6ed91009b747063dd129aa50fde5236
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/ministry/02.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4ddfd13cf66c2d2ee29222e7fd17a65bbe1314d0d9f2633bb6890f8500d6a0a1
+size 798718
diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/ministry/corrs.txt b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/ministry/corrs.txt
new file mode 100644
index 0000000000000000000000000000000000000000..261f92cdaba758ab740a8394f9a2e1db3dc4100d
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/ministry/corrs.txt
@@ -0,0 +1,20 @@
+5.382691700000000310e+02 3.201234299999999848e+02 8.401269600000000537e+02 2.890135200000000282e+02
+5.136501799999999776e+02 3.198889699999999721e+02 7.974540399999999636e+02 2.904203200000000038e+02
+1.237602799999999945e+02 3.225853299999999990e+02 2.599307699999999954e+02 3.111280600000000049e+02
+1.537719999999999914e+02 3.224680999999999926e+02 2.903263200000000097e+02 3.092262000000000057e+02
+1.461141399999999919e+02 3.159998600000000124e+02 3.083832499999999754e+02 3.169733100000000263e+02
+1.357277599999999893e+02 4.126277600000000234e+02 2.628877299999999764e+02 4.291835699999999747e+02
+3.249277599999999779e+02 3.920216699999999719e+02 4.928279400000000123e+02 4.026567099999999755e+02
+3.176001600000000167e+02 3.420965499999999793e+02 4.826547699999999850e+02 3.305104000000000042e+02
+3.671542200000000093e+02 3.417739000000000260e+02 5.515446899999999459e+02 3.293574800000000096e+02
+3.581272599999999784e+02 3.794777100000000019e+02 5.398474899999999934e+02 3.856578900000000090e+02
+5.280300999999999476e+02 3.967701400000000262e+02 8.180482600000000275e+02 4.170241500000000201e+02
+1.984237099999999998e+02 3.644407299999999736e+02 3.422599599999999782e+02 3.659925299999999879e+02
+2.103922200000000089e+02 3.919926199999999881e+02 3.558881600000000276e+02 4.075805100000000039e+02
+4.188388800000000174e+02 3.609237299999999777e+02 6.433891999999999598e+02 3.599419500000000198e+02
+5.105621300000000247e+02 3.594185200000000009e+02 7.931084399999999732e+02 3.569712799999999788e+02
+3.755597900000000209e+02 3.891606100000000197e+02 5.768200500000000375e+02 4.055832399999999893e+02
+3.593878532316947485e+02 3.070974129196879403e+02 5.717608921013747931e+02 2.883844101883552185e+02
+3.098401408430057700e+02 3.068076602156605190e+02 5.045165838838461809e+02 2.903391865900275661e+02
+3.417129382860220517e+02 2.931892831263717198e+02 5.522131280846514301e+02 2.680547356109628936e+02
+3.379461531336655753e+02 3.105744453680169954e+02 5.447849777582964634e+02 2.938577841130378374e+02
diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/ministry/crossval_errors.txt b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/ministry/crossval_errors.txt
new file mode 100644
index 0000000000000000000000000000000000000000..499b15ab86c373201270fa389256c34de410fc1e
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/ministry/crossval_errors.txt
@@ -0,0 +1,20 @@
+1.045113346684135858e-01
+5.782519597789786969e-01
+1.288780350433794464e+00
+7.571290543942928997e-01
+2.320711714614160037e+00
+4.092129552015536298e+00
+1.144462481685707411e+00
+7.868003583494459496e-01
+1.517594664920262293e+00
+5.049213364468354559e-01
+5.365094462105434170e-01
+2.255850096218184664e+00
+6.997883488536579266e-01
+2.091201714459292038e+00
+3.756746109812822088e+00
+2.731580148502795180e-01
+2.652923335284723461e+00
+3.001299658395643888e-02
+1.538170234708976736e+00
+6.465584911931659962e-01
diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/silasveta2/01.png b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/silasveta2/01.png
new file mode 100644
index 0000000000000000000000000000000000000000..0fe8b56b9d4182e98e519a53cf45c89cd81ca05f
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/silasveta2/01.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c23e0225a23cca98354330e16bb25b087521cfdc02e260157917a73d394943b5
+size 455694
diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/silasveta2/02.png b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/silasveta2/02.png
new file mode 100644
index 0000000000000000000000000000000000000000..eed716c239494c2a8d858ac87f3aa15cc3daf798
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/silasveta2/02.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0a45c8c303c9519f388b4d0a020cc42ec1568f60fb264e1bab1c37ba36071a47
+size 440407
diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/silasveta2/corrs.txt b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/silasveta2/corrs.txt
new file mode 100644
index 0000000000000000000000000000000000000000..98548c1c0fed67b5e082bbe57de6c0f0cd217708
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/silasveta2/corrs.txt
@@ -0,0 +1,28 @@
+2.939664000000000215e+02 3.045692700000000173e+02 4.071241999999999734e+02 2.350125999999999920e+02
+2.789837099999999737e+02 2.884094099999999798e+02 3.715412999999999784e+02 2.365269499999999994e+02
+5.225245799999999718e+01 2.785376300000000072e+01 1.143709000000000060e+02 2.453592199999999934e+02
+3.934123499999999751e+01 8.992861499999999353e+01 1.233397599999999983e+02 2.691005200000000173e+02
+3.389435199999999782e+02 4.382933800000000133e+02 4.989929099999999949e+02 3.114054800000000114e+02
+3.443217300000000023e+02 3.877132700000000227e+02 4.141194600000000037e+02 2.683777499999999918e+02
+3.975691499999999792e+02 2.972368299999999977e+02 4.135139300000000162e+02 1.597436500000000024e+02
+2.289945999999999913e+02 3.682853900000000067e+02 3.708185399999999845e+02 3.206088100000000054e+02
+1.120068500000000000e+02 2.817852100000000064e+02 2.133036899999999889e+02 3.207066800000000057e+02
+1.708717900000000043e+02 1.632734199999999873e+02 2.057512999999999863e+02 2.416852199999999868e+02
+2.271479199999999992e+02 2.101553299999999922e+02 2.645968799999999987e+02 2.313095500000000015e+02
+3.922296799999999735e+02 3.802393799999999828e+02 5.436231699999999591e+02 2.153056500000000142e+02
+3.574577300000000264e+02 3.783975399999999922e+02 4.509628799999999842e+02 2.486417500000000018e+02
+3.725721800000000030e+02 3.528192399999999793e+02 4.507284099999999967e+02 2.180116999999999905e+02
+3.500021800000000098e+02 4.498192399999999793e+02 5.468960200000000214e+02 3.127756200000000035e+02
+1.992346400000000131e+02 3.168652799999999843e+02 3.004045600000000036e+02 3.025635300000000143e+02
+1.692256899999999860e+02 2.660894400000000246e+02 2.471184000000000083e+02 2.895495999999999981e+02
+3.875780399999999872e+02 3.331240500000000111e+02 4.582313399999999888e+02 1.899929500000000075e+02
+3.346446199999999749e+02 4.545085700000000202e+02 5.133672900000000254e+02 3.283676500000000260e+02
+1.324547699999999963e+02 2.970682899999999904e+02 2.345914400000000057e+02 3.211758500000000254e+02
+1.349730299999999943e+02 2.913113200000000091e+01 1.460674099999999953e+02 2.118581499999999949e+02
+3.387018081266798504e+02 4.211557926482452103e+02 4.577099168638121682e+02 2.986015155657120772e+02
+3.576395934174669833e+02 4.358770839507731694e+02 5.373521867041577025e+02 2.962761938185487338e+02
+2.004951893092595014e+02 1.164639536435136620e+02 2.074144164015886247e+02 2.075142798410796274e+02
+2.348958052953859124e+02 1.630370952862693912e+02 2.508620034821227591e+02 2.048542234892101987e+02
+2.470683309520152591e+02 1.805020234023028252e+02 2.703690833958319217e+02 2.037458666759312678e+02
+2.486560516898364881e+02 2.022008734858594607e+02 2.834476937925233528e+02 2.126127211821627156e+02
+1.973197478336170718e+02 1.947915100426937443e+02 2.373400503601197897e+02 2.407649842394475854e+02
diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/silasveta2/crossval_errors.txt b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/silasveta2/crossval_errors.txt
new file mode 100644
index 0000000000000000000000000000000000000000..45575c3a43a37bd1cce73af0af5e68337390ba23
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/silasveta2/crossval_errors.txt
@@ -0,0 +1,28 @@
+4.035863378755469855e-01
+9.721997233306670649e-01
+1.633407097610160008e+00
+2.340841776830394072e-01
+3.307562772547593677e-01
+7.631667399378996297e-01
+6.225453458632575376e+00
+1.546114169351287426e+00
+6.222134223220379123e-01
+2.352856144041039210e-02
+3.672563650444910649e-02
+1.676348197120265171e+00
+6.024728066961207995e-01
+1.600202597754185385e+00
+1.452115718380216469e+00
+2.531982995760208022e-01
+6.108426706884196866e-01
+2.584720721075003613e+00
+1.257587290875826991e+00
+1.795838969385847195e+00
+4.973690863625047420e+00
+2.269005600197053329e+00
+2.040850506195554637e+00
+1.557553158164507678e+00
+2.292730519237752285e-01
+7.590197588485040336e-01
+6.090837038593645003e-01
+5.670417249077306376e-01
diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/warsaw/01.png b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/warsaw/01.png
new file mode 100644
index 0000000000000000000000000000000000000000..09954993abacdecd352ae6e8a1a0371ff7c41024
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/warsaw/01.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:33b085670bc7a1f0c584de21d7ddc720bb4c758a00b58ec781567cccc73b6ebe
+size 560605
diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/warsaw/02.png b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/warsaw/02.png
new file mode 100644
index 0000000000000000000000000000000000000000..20aa465c068c1365a8855a88db7c6f2a72361231
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/warsaw/02.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:07f191bb344878e5f8b78f8b41cf81d5f34363227d44cea390aa6e2b9aa271d0
+size 671791
diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/warsaw/corrs.txt b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/warsaw/corrs.txt
new file mode 100644
index 0000000000000000000000000000000000000000..59b10f1baad7423b4d2401478e040780edbff085
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/warsaw/corrs.txt
@@ -0,0 +1,29 @@
+3.627090400000000159e+02 2.702030500000000188e+02 1.724055199999999957e+02 2.529717200000000048e+02
+4.603232100000000173e+02 2.732317499999999768e+02 2.515151700000000119e+02 2.542031900000000064e+02
+3.189187699999999950e+02 2.363893700000000138e+02 3.194758899999999926e+02 2.518294699999999864e+02
+5.270825499999999693e+02 2.679455300000000193e+02 4.677583599999999819e+02 2.706653000000000020e+02
+5.266039399999999659e+02 3.173941199999999867e+02 4.673281600000000253e+02 3.058190299999999979e+02
+3.091980899999999792e+02 3.267136399999999981e+02 3.137508300000000077e+02 3.146115300000000161e+02
+6.601129399999999805e+02 2.463251500000000078e+02 5.610241800000000012e+02 2.537342500000000030e+02
+2.104019000000000119e+02 1.983037699999999859e+02 2.456234200000000101e+02 2.264652399999999943e+02
+3.583122700000000123e+02 1.349441999999999950e+02 1.771539899999999932e+02 1.422758300000000133e+02
+4.208415299999999775e+02 1.127763800000000032e+02 2.445876800000000060e+02 1.253243800000000050e+02
+4.750268300000000181e+02 2.091580500000000029e+02 4.303675099999999816e+02 2.300800999999999874e+02
+6.398789600000000064e+02 2.135752200000000016e+02 5.467690999999999804e+02 2.309007400000000132e+02
+4.855788600000000201e+02 3.579385100000000079e+02 4.400204100000000267e+02 3.348551899999999932e+02
+3.184401599999999917e+02 3.619534899999999880e+02 3.223776799999999980e+02 3.390755899999999770e+02
+3.963968300000000227e+02 2.484256699999999967e+02 2.047155700000000138e+02 2.361374999999999886e+02
+4.147153099999999881e+02 1.501392199999999946e+02 2.245483999999999867e+02 1.525546600000000126e+02
+6.291612800000000334e+02 2.689328499999999735e+02 5.399007500000000164e+02 2.703921000000000276e+02
+7.098627300000000560e+02 3.155754400000000146e+02 5.514003300000000536e+02 2.971717800000000125e+02
+7.261398299999999608e+02 2.930355299999999943e+02 6.026451100000000451e+02 2.853398599999999874e+02
+5.579149099999999635e+02 2.668904299999999807e+02 4.895637500000000273e+02 2.692585000000000264e+02
+8.792362759200152311e+01 3.270246234764344990e+02 1.419145446904389019e+02 3.147235164869259165e+02
+2.207784042665544177e+01 2.922096095729181684e+02 4.256379174071378202e+01 2.857788089346437914e+02
+8.716677946366421281e+01 2.853979764178823757e+02 1.426968340837437381e+02 2.849965195413388983e+02
+1.136564639554701017e+02 2.687473175944615491e+02 1.653832264895864910e+02 2.740444680350699969e+02
+3.570110673672705559e+01 2.384733924609690803e+02 9.497718109214360993e+01 2.529226544158370871e+02
+6.768215978607523766e+02 2.604095934004204196e+02 5.730242994595705568e+02 2.641227329797326320e+02
+7.547998898712631899e+02 3.500846292125078776e+02 6.176350265001917705e+02 3.257280227024951955e+02
+7.415435802294763334e+02 3.126550490474626827e+02 6.105539587159662460e+02 2.981118583440154453e+02
+7.470020606702121313e+02 2.526117641993693610e+02 6.225917739491496832e+02 2.577497719739296258e+02
diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/warsaw/crossval_errors.txt b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/warsaw/crossval_errors.txt
new file mode 100644
index 0000000000000000000000000000000000000000..aab023e8dd2cef9877da7fbebae616534e2d4b0c
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGLBS/warsaw/crossval_errors.txt
@@ -0,0 +1,29 @@
+8.374156557111057664e-01
+1.328765030518358758e-01
+9.603049900896819535e-01
+1.439087407591574275e+00
+7.279438076244393319e-01
+9.105730160893195091e-01
+5.490532157496890164e-01
+7.626828820733734526e-02
+1.322878606601286089e+00
+8.847866436993027106e-01
+1.026845569665250979e+00
+1.485877897448205642e-01
+1.937133625975002438e-01
+1.459832159161253751e+00
+2.039198747794767730e+00
+1.901781580371950442e+00
+7.648652934105401036e-01
+7.588993994218221628e-01
+8.633612883264278892e-01
+1.994298324411831080e+00
+4.200922354091934374e-01
+2.263208001745716125e+00
+1.945820659482044857e+00
+1.067577664256973868e+00
+1.035513580979050774e+00
+4.557118879492584318e-01
+1.209804537310066763e+00
+2.390102886096314716e+00
+1.692217931017956589e+00
diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGSBS/kettle/01.png b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGSBS/kettle/01.png
new file mode 100644
index 0000000000000000000000000000000000000000..b8c2731f8b256d6eb9486debd49e4424f5ae793e
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGSBS/kettle/01.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7cde3fbbd4a580b626e23b3d893fef722a0f49b248084d97e62a2373a515d6c3
+size 33920
diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGSBS/kettle/02.png b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGSBS/kettle/02.png
new file mode 100644
index 0000000000000000000000000000000000000000..6cf80e281f3483e87a7c64f52d5e6e4ac77feae0
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGSBS/kettle/02.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4a7d3befa0dd5d7f385133211d360842177e899c75ea1b3e86fbccf1d6340efd
+size 571399
diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGSBS/kettle/corrs.txt b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGSBS/kettle/corrs.txt
new file mode 100644
index 0000000000000000000000000000000000000000..610b18654ece85e6dfff13d32904b048f54888f5
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGSBS/kettle/corrs.txt
@@ -0,0 +1,16 @@
+9.781961300000000392e+01 1.640526199999999903e+02 3.646632700000000114e+02 3.071947599999999738e+02
+5.911208700000000249e+01 1.617282899999999870e+02 2.837176600000000235e+02 3.020257799999999975e+02
+9.678194299999999828e+01 1.289415600000000097e+02 3.626208500000000186e+02 2.492279200000000117e+02
+1.213090100000000007e+02 4.938435900000000345e+01 3.456403599999999869e+02 9.959778400000000431e+01
+1.372421600000000126e+02 9.650363299999999356e+01 3.866751199999999926e+02 1.909491399999999999e+02
+1.627041900000000112e+02 7.731990999999999303e+01 4.574601599999999735e+02 1.513018099999999890e+02
+1.630471799999999973e+02 9.672647999999999513e+01 4.436449600000000260e+02 1.959417599999999879e+02
+1.267521700000000067e+02 1.210687199999999919e+01 3.491089400000000182e+02 1.299489400000000039e+01
+7.839227300000000120e+00 1.478580799999999940e+01 1.648144100000000094e+02 5.922769000000000261e+01
+3.307260300000000086e+01 1.982403400000000033e+02 2.107729699999999866e+02 3.684140499999999747e+02
+8.604231599999999647e-01 2.169522200000000112e+02 1.594700899999999990e+02 3.949990399999999795e+02
+5.562704200000000299e+01 8.441304300000000183e+01 2.715727800000000229e+02 1.740955600000000061e+02
+1.220638988325962941e+02 1.185185435128226459e+02 3.517891147119119069e+02 2.390017363973832971e+02
+2.390941223160494644e+02 1.172779404370260750e+02 5.516424743042626915e+02 2.473289597137311375e+02
+1.878158618498014221e+02 1.346463834981746004e+02 5.183335810388708751e+02 2.689797403362358637e+02
+9.644884478416841489e+01 2.165869733815735287e+02 2.693040235684538857e+02 4.296418009957009190e+02
diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGSBS/kettle/crossval_errors.txt b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGSBS/kettle/crossval_errors.txt
new file mode 100644
index 0000000000000000000000000000000000000000..81a8f649fae74130a4a1e12c0b470148f748c8f0
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGSBS/kettle/crossval_errors.txt
@@ -0,0 +1,16 @@
+4.027252550598703351e-01
+3.191271567188723068e+00
+5.821331243373673026e-01
+2.905143316101345352e+00
+2.687513676074431590e+00
+4.783819982168825424e-01
+5.648388513287787127e+00
+8.213834063298731891e+00
+7.330748034364480858e+00
+1.110741198399469026e-01
+2.536472718509040103e+00
+3.155581166483636846e+00
+3.507286154532603728e+00
+1.179596267993104775e+01
+4.008971269381730984e+00
+9.571523717347277582e+00
diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGSBS/kettle2/01.png b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGSBS/kettle2/01.png
new file mode 100644
index 0000000000000000000000000000000000000000..cfefe5bcf68ced8df5f98eef309c47ec288e0dc0
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGSBS/kettle2/01.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:76e1af809211af74946975764fb4ed3eb450018f4400ddd220a5b2b8c28f3325
+size 39460
diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGSBS/kettle2/02.png b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGSBS/kettle2/02.png
new file mode 100644
index 0000000000000000000000000000000000000000..40b296d4a5a7772c1c77b7bc1db799682897ca30
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGSBS/kettle2/02.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a3587f375b8f16a462cc146d69470725708d2e6c3e9bc008828c6226e90404ae
+size 571656
diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGSBS/kettle2/corrs.txt b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGSBS/kettle2/corrs.txt
new file mode 100644
index 0000000000000000000000000000000000000000..6e1cfab21473a292a75450a8d91fd13576c763f6
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGSBS/kettle2/corrs.txt
@@ -0,0 +1,23 @@
+2.031673099999999863e+02 1.511227499999999964e+02 5.470917500000000473e+02 3.812720199999999977e+02
+1.278246900000000039e+02 1.582333099999999888e+02 3.950975900000000252e+02 3.912282000000000153e+02
+6.981330699999999467e+01 1.378685399999999959e+02 2.838833099999999945e+02 3.628750400000000127e+02
+2.062450300000000070e+01 1.665811500000000080e+02 2.064353399999999965e+02 4.181154300000000035e+02
+5.874876700000000085e+01 9.520412799999999720e+01 2.468054300000000012e+02 2.787631000000000085e+02
+4.306461399999999884e+01 1.194967299999999994e+02 2.168206900000000132e+02 3.224748599999999783e+02
+4.306461399999999884e+01 9.546390700000000606e+01 2.136284100000000024e+02 2.799354299999999967e+02
+1.363042900000000088e+02 2.272685700000000111e+02 4.315301600000000235e+02 5.132182599999999866e+02
+1.249141999999999939e+02 2.089112900000000081e+02 4.102798199999999724e+02 4.857932799999999816e+02
+7.821797799999999690e+00 1.875146499999999889e+02 1.872661400000000071e+02 4.604397599999999784e+02
+1.791764400000000137e+02 4.797713999999999857e+01 4.946746400000000108e+02 2.018916099999999858e+02
+1.969959000000000060e+02 4.750820600000000127e+01 5.294929300000000012e+02 2.007192699999999945e+02
+6.168053900000000311e+01 4.406037500000000051e+01 2.541400899999999865e+02 1.951090299999999900e+02
+2.074171461111978942e+02 1.130778721436135328e+02 5.534743734514790958e+02 3.064526860507694437e+02
+2.077209801180918589e+02 9.697466977823438583e+01 5.548043328588382792e+02 2.798534979035860601e+02
+1.779452474424851403e+02 6.780660511641553967e+01 4.956211392313551869e+02 2.379597765717722382e+02
+1.803759194976366871e+02 9.059415563346149725e+01 4.969510986387143703e+02 2.665539038299943400e+02
+9.408706153975595043e+01 8.176509711761033827e+01 3.220614365709834601e+02 2.578700518258719399e+02
+1.056327538017258973e+02 8.176509711761033827e+01 3.453357261997688852e+02 2.558751127148332216e+02
+1.065442558224077345e+02 7.052323886253432761e+01 3.566403811623218303e+02 2.445704577522802765e+02
+8.832415683445049126e+01 1.547573300627117305e+02 3.291543664475755122e+02 3.951877991166010133e+02
+7.327830040831864267e+01 1.730162867597605043e+02 2.886199438891764544e+02 4.214042890366270626e+02
+7.969089821058844336e+01 1.811777748717402687e+02 3.093775883080778044e+02 4.366993954505543343e+02
diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGSBS/kettle2/crossval_errors.txt b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGSBS/kettle2/crossval_errors.txt
new file mode 100644
index 0000000000000000000000000000000000000000..4789cb2b13b2f53035fbcadeca812fff6cceef12
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGSBS/kettle2/crossval_errors.txt
@@ -0,0 +1,23 @@
+6.956774605531451883e+00
+2.860879616904584033e+00
+1.697843297093116988e+00
+5.438113664493887889e+00
+4.995418497168292449e-01
+1.942702354772605444e+00
+4.400681895413596084e-02
+3.868755083381945958e+00
+1.391979085576086628e-01
+7.054838917641514939e+00
+7.125814072624173656e-01
+1.492414840109906216e+00
+3.395624676571964873e+00
+3.332879502809552030e-01
+2.233287232798145894e+00
+6.365794263520651031e-01
+9.102453491498465610e-01
+1.003026023461308425e+00
+2.251558963684663073e-01
+3.920756389386108598e+00
+1.955627908840252616e+00
+2.169390693519445357e+00
+4.825353733967688852e-01
diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGSBS/lab/01.png b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGSBS/lab/01.png
new file mode 100644
index 0000000000000000000000000000000000000000..051db61918b367bb72469d3107bbd28172e95a22
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGSBS/lab/01.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:24a3f33faf8b550c9a1d242cfc309ab8cb137f228796cda303aa5b0ed5b77522
+size 40390
diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGSBS/lab/02.png b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGSBS/lab/02.png
new file mode 100644
index 0000000000000000000000000000000000000000..e6a7a888167d82497f909ccd26bded4755f0e491
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGSBS/lab/02.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f16337a7f7eaad8a67af97b15e5f4d77ca0641f62719cbf87a3934b2b5d824ce
+size 589826
diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGSBS/lab/corrs.txt b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGSBS/lab/corrs.txt
new file mode 100644
index 0000000000000000000000000000000000000000..a11078c55b58daf033d7ba3cb1e16ae9e664a9d1
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGSBS/lab/corrs.txt
@@ -0,0 +1,28 @@
+1.628443699999999978e+02 2.167013399999999876e+02 2.607572299999999927e+02 3.298615899999999783e+02
+1.470178799999999910e+02 2.174047400000000039e+02 2.319178400000000124e+02 3.297443599999999719e+02
+1.615533499999999947e+02 1.005133799999999979e+02 4.440848500000000172e+02 1.440032600000000116e+02
+1.929841299999999933e+02 4.930592200000000247e+01 4.686370900000000006e+02 3.985307399999999944e+01
+1.940406900000000121e+02 1.124077800000000025e+02 4.498787300000000187e+02 1.653708399999999870e+02
+1.767331900000000076e+02 1.142835100000000068e+02 4.307095299999999725e+02 1.706269800000000032e+02
+1.740266399999999862e+02 2.291979100000000074e+02 4.231765100000000075e+02 3.930738499999999931e+02
+1.904393000000000029e+02 2.389282800000000009e+02 4.407615099999999870e+02 4.050316500000000133e+02
+4.218699300000000108e+01 9.592236900000000333e+01 1.025002499999999941e+02 1.362164099999999962e+02
+9.343587299999999374e+00 7.190201600000000326e+01 3.459909999999999819e+00 9.663790299999999434e+01
+6.077387399999999928e+01 1.963945599999999914e+01 8.101543499999999653e+01 6.908445300000000344e+00
+4.348383400000000165e+01 1.419482800000000111e+02 1.556304399999999930e+02 2.232398200000000088e+02
+1.144621499999999941e+02 5.671926500000000004e+01 1.966845700000000079e+02 6.646710500000000366e+01
+1.264752399999999994e+02 7.190783600000000320e+01 2.193707699999999932e+02 9.096125299999999925e+01
+1.480104100000000074e+02 5.954450400000000343e+01 2.564196000000000026e+02 6.945707199999999659e+01
+1.214080100000000044e+02 9.816435199999999384e+01 3.516977499999999850e+02 1.410918000000000063e+02
+1.333015800000000013e+02 2.099018100000000118e+02 2.282836100000000101e+02 3.237654600000000187e+02
+1.450249100000000055e+02 2.088467100000000016e+02 2.544266299999999887e+02 3.232965300000000184e+02
+1.690577299999999923e+02 2.489404999999999859e+02 2.669705999999999904e+02 3.816787100000000237e+02
+2.788453099999999907e+01 9.568790199999999402e+01 7.600552500000000578e+01 1.351613099999999861e+02
+1.189651512583989756e+02 4.266819208086550930e+01 2.160874152713540184e+02 3.704794679781866762e+01
+1.507345979079198912e+02 4.911762861873818053e+01 2.834910754242177404e+02 5.065235526903887830e+01
+1.311474202743806927e+02 3.430781138362316085e+01 2.222712373037268776e+02 2.344353832659845693e+01
+3.044997022037620837e+01 7.979086484477357999e+01 1.416168022564780244e+02 1.083209577219725475e+02
+2.806449069218989933e+01 1.184886438575735923e+02 1.416168022564780244e+02 1.817617464555121956e+02
+2.089269163050161637e+00 1.182235905766639945e+02 8.699142220673780912e+01 1.799409004538541694e+02
+7.655388062151530448e+00 7.634517219294892243e+01 1.977627943726929516e-01 1.010375737153404998e+02
+3.469082271492963088e+01 9.754943466571606336e+01 8.820531954117637952e+01 1.380614424157199664e+02
diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGSBS/lab/crossval_errors.txt b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGSBS/lab/crossval_errors.txt
new file mode 100644
index 0000000000000000000000000000000000000000..930b7074944da86e735567c7ded9fac73bdebaba
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGSBS/lab/crossval_errors.txt
@@ -0,0 +1,28 @@
+7.752240801306259366e-01
+1.148758241413943004e+00
+1.768738052957149032e+00
+1.568654570431663009e+00
+2.271406557678926907e-01
+2.449004580308638268e-01
+2.812550544505385108e+00
+2.360283068848247012e+00
+8.507221766035513166e-01
+1.476137889129510050e+00
+3.842955461712121989e-01
+3.790487856614814244e-01
+1.834805512081233880e+00
+1.705524880090546203e-02
+1.197058223742027971e+00
+1.595218938876604153e-01
+4.930223213755020595e-01
+2.767083777630560126e-01
+1.538363609208045046e+00
+3.359096424522557411e-01
+3.107673526549437959e+00
+2.830308237137624694e+00
+3.095121535859486794e+00
+2.387996241400114794e+00
+2.596366077667401351e-01
+1.379228150193078539e+00
+1.946997708432720353e+00
+5.931413354392515158e-01
diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGSBS/lab2/01.png b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGSBS/lab2/01.png
new file mode 100644
index 0000000000000000000000000000000000000000..c7148ac014161852700da15270e74668497c8b8b
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGSBS/lab2/01.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:989fbd9964db755bb62ff02415a77d5fec359750c4a4674a6d9f30e047cd5f56
+size 42553
diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGSBS/lab2/02.png b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGSBS/lab2/02.png
new file mode 100644
index 0000000000000000000000000000000000000000..8be58f1a21e8e277f17e0594ccd7e1c291439fb5
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGSBS/lab2/02.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:bcb2a79c6312164966841bcf36b8f9f01af95fd70b135f118740a0bd6de8ff9e
+size 554068
diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGSBS/lab2/corrs.txt b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGSBS/lab2/corrs.txt
new file mode 100644
index 0000000000000000000000000000000000000000..31d982d310eec85faadee731bf733a144e7588d5
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGSBS/lab2/corrs.txt
@@ -0,0 +1,26 @@
+1.163067999999999991e+01 1.006243099999999941e+02 4.788471099999999865e+02 2.956166799999999739e+02
+8.383898000000000650e+01 1.648345700000000136e+02 6.089369799999999486e+02 4.227757199999999784e+02
+6.257121200000000272e+01 1.188440099999999973e+02 5.714095300000000179e+02 3.278976799999999798e+02
+5.359749899999999911e+01 1.795072900000000118e+02 5.490653499999999667e+02 4.416253699999999753e+02
+4.060295299999999941e+01 1.382418600000000026e+02 5.267889499999999998e+02 3.491276199999999790e+02
+5.382905699999999882e+01 1.715750299999999982e+02 5.498956699999999955e+02 4.306991699999999810e+02
+9.625899200000000633e+01 1.509786799999999971e+02 5.973007999999999811e+02 3.364266600000000267e+02
+1.111350900000000053e+02 1.272402799999999985e+02 6.381398199999999861e+02 3.165250100000000089e+02
+1.594964899999999943e+02 5.475679999999999836e+01 7.662270399999999881e+02 1.909491399999999999e+02
+1.576062100000000044e+02 1.025320300000000060e+02 7.599642300000000432e+02 2.903702299999999923e+02
+7.635349899999999934e+01 1.282101399999999956e+02 5.923780299999999670e+02 3.443898800000000051e+02
+1.604430800000000090e+02 7.735484999999999900e+01 7.678683099999999513e+02 2.363184300000000064e+02
+1.629253400000000056e+02 1.936663399999999911e+02 6.783636400000000322e+02 3.636074899999999843e+02
+3.278412300000000101e+01 1.553686600000000055e+02 4.981625900000000229e+02 3.797770399999999995e+02
+1.694398599999999888e+01 1.558463199999999915e+02 4.681584799999999973e+02 3.737583700000000135e+02
+1.595255699999999877e+02 1.682267200000000003e+02 7.001690300000000207e+02 3.496567299999999818e+02
+1.689042300000000125e+02 1.947214400000000012e+02 6.910248299999999517e+02 3.653659900000000107e+02
+1.670709897623286508e+02 3.286760692285372443e+01 7.881821280378179608e+02 1.521775730514374345e+02
+1.598054208230585118e+02 6.878027625124609301e+01 7.676396469214458875e+02 2.198916774720711373e+02
+3.880017453908246239e+01 1.591056886131800923e+02 5.224260410634325353e+02 3.925349736709622448e+02
+4.383920758304722654e+01 1.587697530769157765e+02 5.382471696556165170e+02 3.918470985147803276e+02
+5.324540259844809498e+01 1.587697530769157765e+02 5.568197988725282812e+02 4.069803519507824490e+02
+1.285673750694253954e+02 1.236897702190242683e+02 6.630219832663985926e+02 2.996768242875007218e+02
+1.208212523026747789e+02 1.600470593496146421e+02 6.767368812378181246e+02 4.147277185279446599e+02
+7.859565859050952952e+01 1.699149970106097953e+02 6.024193888738506075e+02 4.321538891512198575e+02
+3.315724796546213327e+01 1.662432062530304790e+02 5.004250372846813093e+02 3.988391511949585038e+02
diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGSBS/lab2/crossval_errors.txt b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGSBS/lab2/crossval_errors.txt
new file mode 100644
index 0000000000000000000000000000000000000000..11a7637176f2973f3cd07984b07696dc68999c64
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGSBS/lab2/crossval_errors.txt
@@ -0,0 +1,26 @@
+7.907225210316218167e+00
+2.535624236718984736e+00
+3.086222092834641773e-01
+1.874681247790922489e+00
+6.968237522558005992e+00
+5.594632224115625441e+00
+9.166892174461624521e+00
+5.255621098281488379e+00
+4.583878768840595752e-01
+2.145340407892266654e+00
+2.464833084486542703e+00
+2.055031501699763563e+00
+2.236388676391594998e+00
+7.235284117576637364e+00
+7.775682560729027415e+00
+5.285462960500259655e+00
+2.039091324975429664e+00
+2.681807750177786165e+00
+1.585768370720183507e+00
+3.566932614906344501e+00
+1.104428367277808753e+01
+2.088649663206282092e+00
+1.125499910202593234e+01
+7.554842478097656411e+00
+1.402237521628181938e+00
+3.734888943104968106e+00
diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGSBS/window/01.png b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGSBS/window/01.png
new file mode 100644
index 0000000000000000000000000000000000000000..a923b184c80023980d26b184a600d9c1f37bd5b2
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGSBS/window/01.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9478623f10f039e55d1a34d0eb99a618cd5ffcc4f127b232b3d835db352acd64
+size 43062
diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGSBS/window/02.png b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGSBS/window/02.png
new file mode 100644
index 0000000000000000000000000000000000000000..fe6e8ac17368220990dd082d6e407a5200df6b45
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGSBS/window/02.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:cf74fb4e2859475a9181a31a2dbca5b36e1fbb890f1b3dfca1beaa341a2d5611
+size 770313
diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGSBS/window/corrs.txt b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGSBS/window/corrs.txt
new file mode 100644
index 0000000000000000000000000000000000000000..2cc8d21fbc51a953aac5bd97e79c66ae4b9d7d68
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGSBS/window/corrs.txt
@@ -0,0 +1,31 @@
+9.897087299999999743e+01 4.934491899999999731e+01 3.632952000000000226e+02 1.367832099999999969e+02
+1.261303699999999992e+02 5.910536499999999904e+01 3.965527999999999906e+02 1.532947800000000029e+02
+1.599906200000000069e+02 7.385744900000000257e+01 4.414175299999999993e+02 1.768897900000000050e+02
+1.897854499999999973e+02 1.028151900000000012e+02 4.796774199999999837e+02 2.206693400000000054e+02
+1.413513900000000092e+02 8.731513800000000458e+01 4.165426299999999742e+02 1.930496599999999887e+02
+8.060585000000000377e+01 1.388418299999999874e+02 3.286501700000000028e+02 2.480362199999999859e+02
+7.987434899999999516e+01 1.242497299999999960e+01 3.395829600000000141e+02 9.155710200000000043e+01
+1.868148799999999881e+02 4.807181800000000038e+01 4.802248599999999783e+02 1.500509800000000098e+02
+1.995904700000000105e+02 1.538564800000000048e+02 4.841914199999999937e+02 2.815089199999999892e+02
+9.005026000000000863e+00 5.210730299999999815e+00 2.472646799999999985e+02 7.471320400000000461e+01
+5.654654400000000081e+00 1.297964200000000119e+02 2.329611899999999878e+02 2.294531900000000064e+02
+4.750520800000000321e+01 1.867672300000000121e+02 2.826131100000000060e+02 3.014105599999999754e+02
+2.378445499999999981e+02 8.430482899999999802e+01 5.434962500000000318e+02 2.020185199999999952e+02
+1.966315299999999979e+02 6.298087300000000255e+01 4.909328600000000051e+02 1.670217900000000100e+02
+2.435347500000000025e+02 1.771240099999999984e+02 5.369892899999999827e+02 3.132715299999999843e+02
+2.639605799999999931e+02 1.931162699999999859e+02 5.629881000000000313e+02 3.345121899999999755e+02
+1.291009400000000085e+02 1.613778599999999983e+02 3.893327499999999759e+02 2.781672500000000241e+02
+1.546802099999999882e+02 1.858129000000000133e+02 4.204801499999999805e+02 3.131155600000000163e+02
+1.518268700000000138e+02 2.105879899999999907e+02 4.132891500000000065e+02 3.402469399999999951e+02
+1.056154099999999971e+02 2.148762299999999925e+02 3.533884699999999839e+02 3.404814099999999826e+02
+1.066277800000000013e+02 7.896014099999999303e+01 3.683018799999999828e+02 1.776588500000000010e+02
+2.017403999999999940e+02 1.114033000000000015e+02 4.912942400000000021e+02 2.288863900000000058e+02
+3.921801490628814690e+01 7.235750264493356099e+01 2.786385103920094934e+02 1.611776639627567533e+02
+3.970357384006362622e+01 4.613732022105824626e+01 2.800882859621938223e+02 1.271079380634251947e+02
+6.200007409556292259e+00 5.147846849258843349e+01 2.380447944268485116e+02 1.300074892037938525e+02
+1.396895034996379081e+01 1.014910386714616948e+02 2.460185600628623206e+02 1.952473898620882551e+02
+4.941475251557298520e+01 7.184338185676199373e+01 2.931362660938526687e+02 1.604101363703597372e+02
+4.504472211159377082e+01 3.865131128196424015e+00 2.938611538789447195e+02 7.704804108476133706e+01
+2.638002771336840624e+02 7.635581717826380554e+01 5.788273089375950349e+02 1.954089779812106826e+02
+2.352291152034932509e+02 1.095780984924392101e+02 5.360306840245516469e+02 2.303220140944829382e+02
+1.694489982014259510e+02 3.914686210638734565e+01 4.594472499696317982e+02 1.370439269645612512e+02
diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGSBS/window/crossval_errors.txt b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGSBS/window/crossval_errors.txt
new file mode 100644
index 0000000000000000000000000000000000000000..f559b819b0225783996edbd143c240f10c4aff0d
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WGSBS/window/crossval_errors.txt
@@ -0,0 +1,31 @@
+3.584996310465787595e+00
+4.946349073889345083e-01
+1.107964495161395613e+00
+5.027808283948981272e-01
+1.233645956010521960e+00
+1.140904304662818447e+00
+1.746155141624057272e+00
+6.075603257546631220e-01
+8.903860399682645976e-01
+1.231921832177196446e-01
+4.462442879723696465e-01
+1.348039545561001029e+00
+1.267108840243132750e+00
+1.770145183486460772e+00
+1.586913035140011363e+00
+7.921288124968347555e-01
+7.384951230528175037e-01
+1.349153872928968045e+00
+5.354941394430628998e-01
+7.480195637740312264e-01
+1.645405748374197286e+00
+6.062212534539174191e-01
+2.481283471472908175e+00
+3.224114288960418850e+00
+9.488628321632434082e-01
+6.451342801452116804e-01
+1.852461708152915998e-01
+1.880127786478289709e+00
+2.137403187048483755e+00
+1.837261225400464437e-01
+7.754601935878026042e-01
diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WLABS/.DS_Store b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WLABS/.DS_Store
new file mode 100644
index 0000000000000000000000000000000000000000..61b34e0771a875eb567d8c70906d04413e0987ca
Binary files /dev/null and b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WLABS/.DS_Store differ
diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WLABS/dh/01.png b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WLABS/dh/01.png
new file mode 100644
index 0000000000000000000000000000000000000000..bd7292fa5d275ccc539d43881c5ac5647f3da656
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WLABS/dh/01.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:da934d6a058f970ddcee0cb0434cabd06e0f23588c0466b925de5b2b290d8ce4
+size 352843
diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WLABS/dh/02.png b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WLABS/dh/02.png
new file mode 100644
index 0000000000000000000000000000000000000000..ed05585af37639e01b384253800f6a086740abac
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WLABS/dh/02.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c0371595a5ccab153f60b0e00bb5df8056c21722b53800dc1b484cb64945e869
+size 602365
diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WLABS/dh/corrs.txt b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WLABS/dh/corrs.txt
new file mode 100644
index 0000000000000000000000000000000000000000..8b6e2dfb82a70a2658388de80ad1e7613c9196a2
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WLABS/dh/corrs.txt
@@ -0,0 +1,40 @@
+1.169112800000000050e+02 4.929016299999999973e+01 3.541015600000000063e+02 1.477567500000000109e+02
+1.563096199999999953e+02 6.461383399999999710e+01 3.742968000000000188e+02 1.550746599999999944e+02
+1.300493600000000072e+02 6.308980100000000135e+01 3.624562399999999798e+02 1.543712600000000066e+02
+3.226906299999999987e+02 1.387599099999999908e+02 4.564341099999999756e+02 1.899251099999999894e+02
+2.597951400000000035e+02 1.359850399999999979e+02 4.243488500000000272e+02 1.898659700000000043e+02
+2.609577899999999886e+02 1.679579099999999983e+02 4.245833099999999831e+02 2.069820400000000120e+02
+3.232767900000000054e+02 1.715852299999999957e+02 4.563168800000000260e+02 2.086824400000000139e+02
+3.150161699999999882e+02 2.376190400000000125e+02 4.504939400000000091e+02 2.430843500000000006e+02
+2.615342800000000238e+02 2.722495400000000245e+02 4.229323600000000170e+02 2.602702699999999822e+02
+2.843377100000000155e+02 2.706373300000000199e+02 4.345491799999999785e+02 2.606122799999999984e+02
+3.165208400000000211e+02 3.306116599999999721e+02 4.491161900000000173e+02 2.890138600000000224e+02
+2.994231100000000083e+02 3.597612500000000182e+02 4.395611599999999726e+02 3.028581100000000106e+02
+2.541969899999999996e+02 4.175856200000000058e+02 4.182032600000000002e+02 3.301561800000000062e+02
+1.484055400000000020e+02 4.088754299999999944e+02 3.732707500000000209e+02 3.302637300000000096e+02
+1.363789299999999969e+02 3.112467300000000137e+02 3.649461499999999887e+02 2.824778499999999894e+02
+1.341020399999999881e+02 1.730726099999999974e+02 3.628068999999999846e+02 2.122596100000000092e+02
+2.236309200000000033e+02 2.367838800000000106e+02 4.060594199999999887e+02 2.489460200000000043e+02
+2.237529999999999859e+02 2.510873799999999960e+02 4.061863299999999981e+02 2.571727500000000077e+02
+2.515199900000000071e+02 2.541112400000000093e+02 4.197573800000000119e+02 2.533223800000000097e+02
+3.215183000000000106e+02 4.430418399999999792e+02 4.511779799999999909e+02 3.412158800000000269e+02
+2.831653699999999958e+02 2.885740200000000186e+02 4.341974799999999846e+02 2.681152200000000221e+02
+2.585463799999999992e+02 2.877533900000000244e+02 4.217707500000000209e+02 2.681152200000000221e+02
+1.490401200000000017e+02 2.910826000000000136e+02 3.710422800000000052e+02 2.746232200000000034e+02
+2.603124900000000252e+02 2.408638799999999947e+02 4.239777799999999957e+02 2.441781799999999976e+02
+2.434460722294301149e+02 2.803174670547233518e+02 4.122604712625906700e+02 2.872925018567745497e+02
+2.733039336263267387e+02 4.925272668931202134e+02 4.241827474016961332e+02 3.639119046943361582e+02
+5.872857669477701847e+01 4.268111770917495846e+02 3.644077504740858444e+02 3.464977785170541438e+02
+8.004596266233613733e+01 4.657705376531507682e+02 3.708833822336325738e+02 3.569927679204574247e+02
+1.109194182015597079e+02 5.466295878749267558e+02 3.603883928302292361e+02 3.824486996648824970e+02
+3.020408096348485287e+02 4.084341202231640864e+02 4.401056527667184355e+02 3.250612044164856229e+02
+2.402938985564013592e+02 3.613888546395853609e+02 4.124167445534841363e+02 3.027314397283934682e+02
+1.256210636964280951e+02 3.562432787163814396e+02 3.628446669459195277e+02 3.047411185503217439e+02
+1.131246650257899944e+02 3.893219810798351546e+02 3.577088210676583913e+02 3.208185491257481772e+02
+1.168000763995070770e+02 3.555081964416379492e+02 3.590486069489438705e+02 3.042945232565599554e+02
+2.511865675635852710e+02 2.673496829800556043e+02 4.184642774916459871e+02 2.594638950006838627e+02
+2.144765395951733922e+02 2.560330578168609463e+02 4.015903118942297851e+02 2.573677502059737776e+02
+2.260691800062507753e+02 2.215311518315115222e+02 4.066210594015339552e+02 2.316899764707751501e+02
+2.487204598853905395e+02 6.543366483336311035e+01 4.200943651318810339e+02 1.384655147177983849e+02
+2.892832720200565291e+02 6.791710231099571615e+01 4.423739670532051491e+02 1.400020389882345455e+02
+2.093993664895408813e+02 1.357977266996203980e+02 3.993512874809929940e+02 1.807199321547924740e+02
diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WLABS/dh/crossval_errors.txt b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WLABS/dh/crossval_errors.txt
new file mode 100644
index 0000000000000000000000000000000000000000..4c4c00ab7cffad294a2d3b1c0761783529b18c8f
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WLABS/dh/crossval_errors.txt
@@ -0,0 +1,40 @@
+3.107206334320491425e+00
+6.237904612903644175e-01
+3.844272521878835835e+00
+4.161468221374081677e-01
+6.602655361141566148e-01
+2.933609060206928687e-01
+1.621433032763675008e+00
+8.156725260303979708e-01
+1.374326397156321988e+00
+5.699350862722842859e-01
+2.645182804425461054e-01
+2.216777645068836122e+00
+2.065341874831398972e-01
+2.487949740495516426e+00
+1.915661557083224054e+00
+1.524632930681501408e+00
+8.505375980072693576e-01
+1.324056400088596597e+00
+2.622759999458317726e+00
+1.273873093214882291e+00
+9.289628752266489986e-01
+1.223324686678771839e-01
+2.363886694428318780e+00
+1.428140686106678992e+00
+7.601475707330846099e+00
+2.027895169522786478e+00
+1.569901890632411812e+00
+1.942497203188160437e+00
+3.283187903146319364e+00
+3.453279298906811512e+00
+2.921981659860365954e-01
+1.235501646779935125e+00
+2.856988289013143945e+00
+1.343701685909615584e+00
+1.528449211300779431e-01
+6.289527516872593926e-01
+6.832133453264218614e-01
+1.883882880515432845e+00
+7.781953177049958370e-01
+7.535427800225099615e-01
diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WLABS/kpi/01.png b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WLABS/kpi/01.png
new file mode 100644
index 0000000000000000000000000000000000000000..c85e935768a0ff0631d7798063fd549a8ad7b237
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WLABS/kpi/01.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:bd32df9ca232efdd702f4694c74f262286989e39ed26ef8a2960012c58a987e9
+size 1003664
diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WLABS/kpi/02.png b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WLABS/kpi/02.png
new file mode 100644
index 0000000000000000000000000000000000000000..98d9c30ecdf718c3a00932987710d8d07556db66
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WLABS/kpi/02.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:bd51d44e58fe701ac03c3034c28efb14bb91558650ce8aaff0587c18f3c3c37f
+size 882713
diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WLABS/kpi/corrs.txt b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WLABS/kpi/corrs.txt
new file mode 100644
index 0000000000000000000000000000000000000000..425014b75a07d62405a16ec51b1a87c0b4418d35
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WLABS/kpi/corrs.txt
@@ -0,0 +1,51 @@
+4.594434400000000096e+02 1.421354699999999980e+02 4.110907700000000204e+02 1.401134500000000003e+02
+4.247887299999999868e+02 1.560969499999999925e+02 3.760649900000000230e+02 1.548364300000000071e+02
+3.429439899999999852e+02 2.414490099999999870e+02 2.926467599999999720e+02 2.486593699999999956e+02
+3.713144500000000221e+02 2.355873400000000117e+02 3.224240199999999845e+02 2.410392099999999971e+02
+4.057173999999999978e+02 2.017435800000000086e+02 3.556643199999999752e+02 2.049583299999999895e+02
+4.831266400000000090e+02 1.719836200000000019e+02 4.367282000000000153e+02 1.692571599999999989e+02
+4.923009200000000192e+02 2.504372400000000027e+02 4.481890500000000088e+02 2.526162500000000080e+02
+4.432739000000000260e+02 2.727330000000000041e+02 3.958687699999999836e+02 2.772373200000000111e+02
+4.377047999999999774e+02 3.714216400000000249e+02 3.913257199999999898e+02 3.837612199999999802e+02
+4.153218800000000215e+02 3.899476099999999974e+02 3.685330099999999902e+02 4.022775100000000066e+02
+5.751513099999999667e+02 3.342953499999999849e+02 5.388746999999999616e+02 3.367766100000000051e+02
+3.755078700000000254e+02 2.774814700000000016e+02 3.261097700000000259e+02 2.857081999999999766e+02
+2.731549099999999726e+02 4.166197399999999789e+02 1.820476999999999919e+02 4.358663999999999987e+02
+3.387235900000000015e+02 4.110980200000000195e+02 2.598591299999999933e+02 4.305704999999999814e+02
+5.390134799999999586e+02 2.733323599999999942e+02 4.980345800000000054e+02 2.736875699999999938e+02
+2.177295000000000016e+02 1.132552699999999959e+02 1.542236399999999890e+02 1.107051900000000018e+02
+1.145859999999999985e+02 2.526850699999999961e+02 5.192982200000000148e+01 2.623476499999999874e+02
+4.961008100000000240e+02 3.077200500000000147e+02 4.523599899999999820e+02 3.136505300000000034e+02
+4.634293799999999806e+02 3.142657500000000255e+02 4.179203800000000228e+02 3.208705800000000181e+02
+3.886099899999999820e+02 3.253458499999999844e+02 3.399249699999999734e+02 3.348417499999999905e+02
+5.037316900000000146e+02 4.045523200000000088e+02 4.621204399999999737e+02 4.156023400000000265e+02
+7.397184899999999743e+02 4.309706199999999967e+02 7.718919399999999769e+02 4.487039700000000266e+02
+7.284640900000000556e+02 4.433973500000000172e+02 7.557137400000000298e+02 4.653510999999999740e+02
+6.990568700000000035e+02 3.712936900000000264e+02 6.837079499999999825e+02 3.754366200000000049e+02
+4.164350800000000277e+02 2.141025300000000016e+02 3.669003900000000158e+02 2.148543400000000076e+02
+4.680218800000000101e+02 1.948160900000000026e+02 4.207339799999999741e+02 1.938685499999999990e+02
+4.289209500000000048e+02 4.063602700000000141e+02 3.828354699999999866e+02 4.189246400000000108e+02
+4.493195400000000177e+02 3.867823099999999954e+02 4.037029899999999998e+02 3.984088100000000168e+02
+6.950526099999999587e+02 4.259295900000000188e+02 7.128063600000000406e+02 4.427250700000000165e+02
+2.044560836833049393e+02 1.242459178115272778e+02 1.406147127704518027e+02 1.236865742051300145e+02
+2.281308525063912498e+02 1.210389770147745594e+02 1.662544620520748424e+02 1.197919287446302974e+02
+2.348276994643161402e+02 1.438019680727641116e+02 1.722587071370118679e+02 1.445662008445619620e+02
+6.717463228213715638e+02 3.649054578569836167e+02 6.445953051321157545e+02 3.698576462864243695e+02
+6.726109392126528519e+02 4.008951151440707577e+02 6.462311986451652501e+02 4.084647331943905897e+02
+6.521843769686304313e+02 3.636085332700615709e+02 6.234377490300099680e+02 3.709482419617906430e+02
+4.940707018408172075e+02 1.776044948851611878e+02 4.486826560612118442e+02 1.751879062404033505e+02
+4.926196935077252306e+02 2.074307772876081799e+02 4.476141575501331999e+02 2.061743630616842609e+02
+4.855258749903864555e+02 1.848595365506212715e+02 4.403812445520623555e+02 1.831605489769133044e+02
+8.207629282561811124e+01 2.572739565469469198e+02 1.809768513909574494e+01 2.677228259828372074e+02
+7.478616494490586319e+01 2.775595297802331629e+02 6.228466757917601626e+00 2.887637131131075421e+02
+5.388479506127032437e+02 1.739406870605558879e+02 4.964077749761838732e+02 1.681577771226033917e+02
+5.791628298642275468e+02 1.504787491354884708e+02 5.394336916079546427e+02 1.415599377502360596e+02
+5.418219990820780367e+02 1.924458775366653924e+02 4.983634984594461912e+02 1.873238672585739835e+02
+6.054122596184905660e+02 4.225034913737084707e+02 5.760699684299472665e+02 4.370432786270732208e+02
+6.292047167561709102e+02 4.238849888849285890e+02 6.045560925447747422e+02 4.376367395461321053e+02
+6.330922480021150704e+02 3.748893448863988738e+02 6.008324998773894094e+02 3.828873228498626986e+02
+6.940164133197212095e+02 3.739800289861361193e+02 6.688765891054499662e+02 3.792636731276583077e+02
+7.219021009277796566e+02 3.603402904821944048e+02 6.998789256176431763e+02 3.643664464919290822e+02
+7.322076811307578055e+02 4.003501900937566234e+02 7.123603857719028838e+02 4.086554986522052104e+02
+6.843170437169181923e+02 4.079278225959464521e+02 6.600187786733947632e+02 4.175133090842604133e+02
+6.993143883506772909e+02 4.338220179609293723e+02 7.397286915533125011e+02 4.586266329625896105e+02
diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WLABS/kpi/crossval_errors.txt b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WLABS/kpi/crossval_errors.txt
new file mode 100644
index 0000000000000000000000000000000000000000..c7c2bf71fa21c921dc45a9dabc6915ee11230c9f
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WLABS/kpi/crossval_errors.txt
@@ -0,0 +1,51 @@
+1.059984065478942972e+00
+1.172619320842898327e+00
+2.759231593787795367e-01
+1.246239438681858713e+00
+1.478182636100590752e+00
+6.418757997184851294e-02
+5.476242100183045247e-01
+2.837136386542546962e-01
+2.091690129035768120e+00
+3.300151015928541320e-01
+1.270728842324514885e+00
+1.133749249740539478e+00
+1.009659388247202694e+00
+3.648056175214751118e+00
+5.856684777977440998e-02
+8.687130066573967024e-01
+6.579392662125330693e-01
+6.064583546252843016e-01
+6.361053578088456117e-02
+3.291927550264510116e-01
+1.698324996372617024e+00
+1.442326735645127700e+00
+1.464286713857532840e+00
+5.724373708212716627e-01
+6.418198956336912397e-01
+5.189720473721508576e-01
+1.443933549361242852e+00
+1.376498883837392051e-01
+1.606222283955370589e+00
+3.479913123925851837e-01
+5.277859297109415149e-01
+1.946469222702315705e+00
+1.392742694583132490e+00
+2.400295113941870007e+00
+1.841967347264411803e+00
+2.352230930271596021e-01
+4.439526535957881159e-01
+1.145484789065688735e+00
+7.598014255470626477e-01
+7.455699952272273334e-01
+5.072581507499479558e-01
+7.148961590925400067e-01
+1.950846228854876463e+00
+1.470122079859542685e+00
+1.948916047213692104e-01
+1.393201192210905548e-01
+1.750177991124282739e+00
+8.992849163506809740e-01
+2.437138945474842666e-02
+4.008038096484493606e-01
+2.857518653713483348e-01
diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WLABS/kyiv/01.png b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WLABS/kyiv/01.png
new file mode 100644
index 0000000000000000000000000000000000000000..5a893e4c9ae1e30ab4e76c2808e2648889a7c6d6
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WLABS/kyiv/01.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7ea892f78feda9190d466ec48584ccbd178f4645076c49add4b43cbe444b2dcf
+size 608612
diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WLABS/kyiv/02.png b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WLABS/kyiv/02.png
new file mode 100644
index 0000000000000000000000000000000000000000..4b0cbf6e83fbc1cdbf92897f3fe73286c50ed7df
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WLABS/kyiv/02.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:86ef3afd1ac043c1d51702a7efc71774cf402a989072f4fc2fb19c29cbbe9288
+size 584762
diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WLABS/kyiv/corrs.txt b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WLABS/kyiv/corrs.txt
new file mode 100644
index 0000000000000000000000000000000000000000..32daf93f65fb756166ea98850677583674250d46
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WLABS/kyiv/corrs.txt
@@ -0,0 +1,38 @@
+4.797559299999999780e+02 3.636057400000000257e+02 5.910282999999999447e+02 2.342230399999999975e+02
+4.575773899999999799e+02 3.698191100000000233e+02 5.583859200000000556e+02 2.408268300000000011e+02
+2.726085100000000239e+02 3.609577899999999886e+02 2.741034999999999968e+02 2.178287100000000009e+02
+2.530779300000000092e+02 3.872298099999999863e+02 2.463365000000000009e+02 2.571533800000000269e+02
+1.570759799999999871e+02 3.638692599999999970e+02 9.238810100000000602e+01 2.168714699999999880e+02
+3.391323499999999740e+02 3.876095199999999750e+02 3.732686800000000176e+02 2.604542500000000018e+02
+5.019839099999999803e+02 3.695469499999999812e+02 6.249612799999999879e+02 2.458215199999999925e+02
+5.469841999999999871e+02 3.587206800000000158e+02 6.836004000000000360e+02 2.281850000000000023e+02
+4.328263000000000034e+02 3.796009999999999991e+02 5.088481100000000197e+02 2.551237600000000043e+02
+4.344319499999999721e+02 3.635670099999999820e+02 5.068776199999999790e+02 2.286345699999999965e+02
+2.574950999999999794e+02 4.067603899999999726e+02 2.631620700000000284e+02 2.901571500000000015e+02
+4.085600499999999897e+02 4.326710199999999986e+02 4.939809200000000260e+02 3.393498299999999972e+02
+5.832414499999999862e+02 3.753785300000000120e+02 7.470578500000000304e+02 2.571437000000000239e+02
+2.331665999999999883e+02 3.660773300000000177e+02 2.143684700000000021e+02 2.241108900000000119e+02
+3.921076199999999972e+02 3.674550800000000095e+02 4.443483600000000138e+02 2.336271900000000130e+02
+4.930527399999999716e+02 3.986024699999999825e+02 6.086143299999999954e+02 2.861023999999999887e+02
+5.256456699999999955e+02 4.605939799999999877e+02 6.824474400000000287e+02 3.859730500000000006e+02
+4.944595400000000041e+02 4.609456799999999816e+02 6.369771700000000010e+02 3.860902899999999818e+02
+5.826532099999999446e+02 4.119304099999999949e+02 7.677974199999999882e+02 3.165097400000000221e+02
+4.631633807972402792e+02 3.838893306705910504e+02 5.657415136605623047e+02 2.633477841694225390e+02
+4.772016069320450811e+02 3.809024740461644569e+02 5.853869735965520249e+02 2.612353691225419539e+02
+4.999017172776869415e+02 3.826945880208203903e+02 6.215092708982107297e+02 2.637702671787986901e+02
+2.298391962234342714e+02 3.919639079503628523e+02 2.112270426050509968e+02 2.649053558881469144e+02
+2.258501706000224658e+02 3.922488383520351363e+02 2.065109531980616566e+02 2.662808819651854719e+02
+1.987817824411567926e+02 3.950981423687578626e+02 1.648521634363227690e+02 2.678529117675152520e+02
+2.467925551229343455e+02 3.931036295570519314e+02 2.385410604205307550e+02 2.670668968663503620e+02
+2.581897711898251941e+02 3.922488383520351363e+02 2.544578621691196645e+02 2.676564080422240295e+02
+2.522062327547075427e+02 3.617612853731021687e+02 2.428641423769375933e+02 2.187269804447099375e+02
+2.131707677256064812e+02 3.660352413981862583e+02 1.835200173389887652e+02 2.234430698516992493e+02
+1.701462770730936143e+02 3.785721790717661861e+02 1.198528103446330704e+02 2.419144200290740514e+02
+1.983543868386483950e+02 3.688845454149090415e+02 1.644591559857403240e+02 2.281591592586885611e+02
+5.589168186820540996e+02 4.165899106386119115e+02 7.329688054814949965e+02 3.228653612625751634e+02
+6.007144447556318028e+02 4.167355469663943950e+02 7.965020870332186860e+02 3.259219452547543483e+02
+4.800636057737958708e+02 4.608663719209688452e+02 6.134348439460603686e+02 3.839675846978299205e+02
+1.447956348980008556e+02 3.697010709907046930e+02 6.855698710166603860e+01 2.243443305002154773e+02
+1.529894419054618879e+02 3.697010709907046930e+02 8.174821679819848441e+01 2.246825671591009268e+02
+1.623971462473616043e+02 3.748601346620690720e+02 1.061012562379507358e+02 2.358443769023207039e+02
+4.879100783360448190e+02 3.810894645944220542e+02 5.999324939202114138e+02 2.625451796870169119e+02
diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WLABS/kyiv/crossval_errors.txt b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WLABS/kyiv/crossval_errors.txt
new file mode 100644
index 0000000000000000000000000000000000000000..9ad28740ce26bd03baeb9766ce5198e11d65d5a8
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WLABS/kyiv/crossval_errors.txt
@@ -0,0 +1,38 @@
+3.396145170153651027e-01
+1.985991266213804840e+00
+8.730360763178448558e-01
+1.759074792352649741e+00
+6.690911651420399231e-01
+3.127928775947395046e+00
+2.242026380339626535e+00
+1.505917217336340819e-01
+2.060474379419809488e+00
+1.356362993459239386e+00
+4.635784250486629787e-02
+2.898097811114895173e+00
+7.993129866486650137e-01
+3.993484324439461886e-01
+1.941237862719949803e+00
+2.113590891431526231e+00
+1.301169151744067665e-01
+1.814139350301335440e+00
+1.942498926613547283e+00
+4.498550778761978841e-01
+2.105282325971198798e+00
+1.643782548361214652e-02
+7.447399375311934688e-02
+1.379783288516800654e+00
+8.107935155796172078e-01
+6.991320042190590778e-01
+1.313621562781517982e+00
+2.148165976067879790e-01
+2.133775335796964101e-01
+1.210735840397118235e+00
+9.441267027822466407e-01
+3.666260745439142266e+00
+3.918843643756335204e+00
+1.097652941210957911e+00
+6.105290073251094796e-02
+1.356574778919217295e-01
+1.617277148218513982e+00
+2.955076440033605589e+00
diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WLABS/ministry/02.png b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WLABS/ministry/02.png
new file mode 100644
index 0000000000000000000000000000000000000000..430c0e076263407c36895282121b585e09d5cd98
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WLABS/ministry/02.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1c2abc34567f59818457ba0c69eef03868b6b0be276c73388bade0fbfff21970
+size 540896
diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WLABS/ministry/corrs.txt b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WLABS/ministry/corrs.txt
new file mode 100644
index 0000000000000000000000000000000000000000..8a045058924894eba94c7b02ccf8038a8d67f1ef
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WLABS/ministry/corrs.txt
@@ -0,0 +1,36 @@
+4.615342800000000238e+02 2.389287400000000048e+02 4.738835399999999822e+02 2.814227500000000077e+02
+4.576365299999999934e+02 2.757140499999999861e+02 4.694200299999999970e+02 3.442165600000000154e+02
+2.591256500000000074e+02 2.957716599999999971e+02 1.176534300000000002e+02 3.840285000000000082e+02
+2.595849000000000046e+02 2.628464000000000169e+02 1.152990800000000036e+02 3.305369299999999839e+02
+3.461092899999999872e+02 1.858082300000000089e+02 4.895454199999999787e+02 1.801452899999999886e+02
+3.563085899999999810e+02 1.855737599999999929e+02 5.023238499999999931e+02 1.801452899999999886e+02
+3.703001499999999737e+02 2.577946499999999901e+02 3.055337799999999788e+02 3.182941700000000083e+02
+3.900855999999999995e+02 2.519620400000000018e+02 3.353916100000000142e+02 3.058674500000000194e+02
+5.045426400000000058e+02 2.437352999999999952e+02 6.176713799999999992e+02 2.771141700000000014e+02
+5.591871800000000121e+02 2.706719800000000191e+02 6.825259399999999914e+02 3.283976799999999798e+02
+4.768939100000000053e+02 2.398385900000000106e+02 5.635075199999999995e+02 2.751426399999999717e+02
+3.574819499999999834e+02 2.812627499999999827e+02 2.861010600000000181e+02 3.566723400000000197e+02
+4.566492099999999823e+02 2.588013300000000072e+02 4.655793400000000020e+02 3.148266300000000228e+02
+2.816849300000000085e+02 2.655427700000000186e+02 1.499537899999999979e+02 3.346798600000000192e+02
+2.157655900000000031e+02 2.429049899999999980e+02 1.148688800000000043e+02 2.944754100000000108e+02
+2.927736800000000130e+02 2.799935800000000086e+02 1.727572199999999896e+02 3.558925100000000157e+02
+3.823568500000000085e+02 2.611653699999999958e+02 3.271164600000000178e+02 3.231802799999999820e+02
+4.123609700000000089e+02 2.038438199999999938e+02 6.026451100000000451e+02 2.004868299999999977e+02
+3.833538500000000226e+02 2.794558299999999917e+02 3.290319600000000264e+02 3.525691800000000171e+02
+3.647116899999999760e+02 2.336124299999999891e+02 3.032482499999999845e+02 2.754504200000000083e+02
+3.588877200000000016e+02 2.104272200000000055e+02 5.073648800000000278e+02 2.216458800000000053e+02
+3.409510200000000282e+02 2.089031799999999919e+02 4.814563299999999799e+02 2.183633499999999970e+02
+3.613904200000000060e+02 2.614288799999999924e+02 2.906451500000000010e+02 3.242730700000000184e+02
+4.981131399999999871e+02 2.399558199999999886e+02 6.060632100000000264e+02 2.727979700000000207e+02
+4.695820600000000127e+02 2.615820600000000127e+02 5.372573300000000245e+02 3.162080900000000270e+02
+1.721220706987422773e+02 2.428520120853525555e+02 2.947191743642216011e+01 2.989297157012354091e+02
+1.647649712238389554e+02 2.405529184994452692e+02 1.660567796869992208e+01 2.947568812792713970e+02
+1.954195523692694394e+02 2.331958190245419473e+02 6.946158064691022105e+01 2.806735651051429841e+02
+1.744211642846495636e+02 2.557269361664333474e+02 3.260154325289514077e+01 3.187506792055642109e+02
+1.800922617965541974e+02 2.594054859038849941e+02 4.425070601421123229e+01 3.236189860311888538e+02
+1.558751426916641094e+02 2.331958190245419473e+02 1.826889390910878319e+00 2.846725314261918243e+02
+1.566415072202998715e+02 2.754991410052360266e+02 4.608779005553515162e+00 3.490037287648030428e+02
+1.960326439921780661e+02 2.738131390422373670e+02 7.293894266521351710e+01 3.441354219391784000e+02
+4.117581439734341870e+02 2.195749695714285963e+02 6.025075896224828966e+02 2.308977018826773815e+02
+4.066962080099323202e+02 2.078240467990135585e+02 5.947414811149388925e+02 2.116562987147324293e+02
+4.067502800813609269e+02 2.193798131157576563e+02 5.957846897204299239e+02 2.306658781902932844e+02
diff --git a/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WLABS/ministry/crossval_errors.txt b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WLABS/ministry/crossval_errors.txt
new file mode 100644
index 0000000000000000000000000000000000000000..aec67f6c0ec471fac7a21dd8402b563e6906ec51
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/.WxBS/v1.1/WLABS/ministry/crossval_errors.txt
@@ -0,0 +1,36 @@
+2.698356670634800558e+00
+3.569808760282799365e-01
+2.495950953406511630e+00
+7.742120527396407770e-01
+1.402599750820609226e-01
+2.236933413898831713e+00
+7.785110674558124444e-01
+1.881489831117169276e+00
+3.487995336350284248e+00
+3.396485994228788829e+00
+6.879465990330972947e-01
+5.618151694300682619e-01
+5.303773159161432327e-01
+7.361475832658961882e-02
+8.930819946990574687e-01
+1.779976084615252585e+00
+4.086245552329793029e-01
+5.164113151014938730e+00
+1.462905037131466912e+00
+7.370368629768874191e-01
+1.987278302009876985e+00
+9.721266174071729882e-01
+1.473186291581487783e-01
+6.483443779751860703e-01
+2.697537697590622230e+00
+1.015760142012948029e-02
+1.506916098273485272e+00
+2.212622623304430380e+00
+1.220237821851318932e-01
+4.458375803433419216e-01
+3.622931236761927076e-01
+9.706837058276953645e-01
+1.245315119310226093e+00
+1.259635152519284762e-01
+2.739689609504042389e-01
+9.760547549081137475e-02
diff --git a/imcui/datasets/wxbs_benchmark/download.py b/imcui/datasets/wxbs_benchmark/download.py
new file mode 100644
index 0000000000000000000000000000000000000000..afbcc20568c56786e547998ac1748173f44474ee
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/download.py
@@ -0,0 +1,4 @@
+from wxbs_benchmark.dataset import * # noqa: F403
+
+dset = EVDDataset(".EVD", download=True) # noqa: F405
+dset = WxBSDataset(".WxBS", subset="test", download=True) # noqa: F405
diff --git a/imcui/datasets/wxbs_benchmark/example.py b/imcui/datasets/wxbs_benchmark/example.py
new file mode 100644
index 0000000000000000000000000000000000000000..fcd81950482920b39184086f525a60ef35c0ab7c
--- /dev/null
+++ b/imcui/datasets/wxbs_benchmark/example.py
@@ -0,0 +1,19 @@
+import os
+from pathlib import Path
+
+ROOT_PATH = Path("/teamspace/studios/this_studio/image-matching-webui")
+prefix = "datasets/wxbs_benchmark/.WxBS/v1.1"
+wxbs_path = ROOT_PATH / prefix
+
+pairs = []
+for catg in os.listdir(wxbs_path):
+ catg_path = wxbs_path / catg
+ if not catg_path.is_dir():
+ continue
+ for scene in os.listdir(catg_path):
+ scene_path = catg_path / scene
+ if not scene_path.is_dir():
+ continue
+ img1_path = scene_path / "01.png"
+ img2_path = scene_path / "02.png"
+ pairs.append([str(img1_path), str(img2_path)])
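+
+# `pairs` now holds one [img1_path, img2_path] entry per WxBS scene. An
+# illustrative (not part of the original script) way to consume it:
+#     for img1, img2 in pairs:
+#         print(img1, img2)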
diff --git a/imcui/hloc/__init__.py b/imcui/hloc/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..eaf2753de6b5a09703aa58067891c9cd4bb5dee7
--- /dev/null
+++ b/imcui/hloc/__init__.py
@@ -0,0 +1,68 @@
+import logging
+import sys
+
+import torch
+from packaging import version
+
+__version__ = "1.5"
+
+LOG_PATH = "log.txt"
+
+
+def read_logs():
+ sys.stdout.flush()
+ with open(LOG_PATH, "r") as f:
+ return f.read()
+
+
+def flush_logs():
+ sys.stdout.flush()
+ logs = open(LOG_PATH, "w")
+ logs.close()
+
+
+formatter = logging.Formatter(
+ fmt="[%(asctime)s %(name)s %(levelname)s] %(message)s",
+ datefmt="%Y/%m/%d %H:%M:%S",
+)
+
+logs_file = open(LOG_PATH, "w")
+logs_file.close()
+
+file_handler = logging.FileHandler(filename=LOG_PATH)
+file_handler.setFormatter(formatter)
+file_handler.setLevel(logging.INFO)
+stdout_handler = logging.StreamHandler()
+stdout_handler.setFormatter(formatter)
+stdout_handler.setLevel(logging.INFO)
+logger = logging.getLogger("hloc")
+logger.setLevel(logging.INFO)
+logger.addHandler(file_handler)
+logger.addHandler(stdout_handler)
+logger.propagate = False
+
+try:
+ import pycolmap
+except ImportError:
+ logger.warning("pycolmap is not installed, some features may not work.")
+else:
+ min_version = version.parse("0.6.0")
+ found_version = pycolmap.__version__
+ if found_version != "dev":
+        found = version.parse(found_version)
+        if found < min_version:
+ s = f"pycolmap>={min_version}"
+ logger.warning(
+ "hloc requires %s but found pycolmap==%s, "
+ 'please upgrade with `pip install --upgrade "%s"`',
+ s,
+ found_version,
+ s,
+ )
+
+DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+# model hub: https://huggingface.co/Realcat/imcui_checkpoints
+MODEL_REPO_ID = "Realcat/imcui_checkpoints"
+
+DATASETS_REPO_ID = "Realcat/imcui_datasets"
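+
+# Submodules typically import these shared objects, e.g.
+# `from .. import MODEL_REPO_ID, logger` in the extractors below.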
diff --git a/imcui/hloc/colmap_from_nvm.py b/imcui/hloc/colmap_from_nvm.py
new file mode 100644
index 0000000000000000000000000000000000000000..121ac42182c1942a96d5b1585319cdc634d40db7
--- /dev/null
+++ b/imcui/hloc/colmap_from_nvm.py
@@ -0,0 +1,216 @@
+import argparse
+import sqlite3
+from collections import defaultdict
+from pathlib import Path
+
+import numpy as np
+from tqdm import tqdm
+
+from . import logger
+from .utils.read_write_model import (
+ CAMERA_MODEL_NAMES,
+ Camera,
+ Image,
+ Point3D,
+ write_model,
+)
+
+
+def recover_database_images_and_ids(database_path):
+ images = {}
+ cameras = {}
+ db = sqlite3.connect(str(database_path))
+ ret = db.execute("SELECT name, image_id, camera_id FROM images;")
+ for name, image_id, camera_id in ret:
+ images[name] = image_id
+ cameras[name] = camera_id
+ db.close()
+ logger.info(f"Found {len(images)} images and {len(cameras)} cameras in database.")
+ return images, cameras
+
+
+def quaternion_to_rotation_matrix(qvec):
+ qvec = qvec / np.linalg.norm(qvec)
+ w, x, y, z = qvec
+ R = np.array(
+ [
+ [
+ 1 - 2 * y * y - 2 * z * z,
+ 2 * x * y - 2 * z * w,
+ 2 * x * z + 2 * y * w,
+ ],
+ [
+ 2 * x * y + 2 * z * w,
+ 1 - 2 * x * x - 2 * z * z,
+ 2 * y * z - 2 * x * w,
+ ],
+ [
+ 2 * x * z - 2 * y * w,
+ 2 * y * z + 2 * x * w,
+ 1 - 2 * x * x - 2 * y * y,
+ ],
+ ]
+ )
+ return R
+
+
+def camera_center_to_translation(c, qvec):
+ R = quaternion_to_rotation_matrix(qvec)
+ return (-1) * np.matmul(R, c)
+
+
+def read_nvm_model(nvm_path, intrinsics_path, image_ids, camera_ids, skip_points=False):
+ with open(intrinsics_path, "r") as f:
+ raw_intrinsics = f.readlines()
+
+ logger.info(f"Reading {len(raw_intrinsics)} cameras...")
+ cameras = {}
+ for intrinsics in raw_intrinsics:
+ intrinsics = intrinsics.strip("\n").split(" ")
+ name, camera_model, width, height = intrinsics[:4]
+ params = [float(p) for p in intrinsics[4:]]
+ camera_model = CAMERA_MODEL_NAMES[camera_model]
+ assert len(params) == camera_model.num_params
+ camera_id = camera_ids[name]
+ camera = Camera(
+ id=camera_id,
+ model=camera_model.model_name,
+ width=int(width),
+ height=int(height),
+ params=params,
+ )
+ cameras[camera_id] = camera
+
+ nvm_f = open(nvm_path, "r")
+ line = nvm_f.readline()
+ while line == "\n" or line.startswith("NVM_V3"):
+ line = nvm_f.readline()
+ num_images = int(line)
+ assert num_images == len(cameras)
+
+ logger.info(f"Reading {num_images} images...")
+ image_idx_to_db_image_id = []
+ image_data = []
+ i = 0
+ while i < num_images:
+ line = nvm_f.readline()
+ if line == "\n":
+ continue
+ data = line.strip("\n").split(" ")
+ image_data.append(data)
+ image_idx_to_db_image_id.append(image_ids[data[0]])
+ i += 1
+
+ line = nvm_f.readline()
+ while line == "\n":
+ line = nvm_f.readline()
+ num_points = int(line)
+
+ if skip_points:
+ logger.info(f"Skipping {num_points} points.")
+ num_points = 0
+ else:
+ logger.info(f"Reading {num_points} points...")
+ points3D = {}
+ image_idx_to_keypoints = defaultdict(list)
+ i = 0
+ pbar = tqdm(total=num_points, unit="pts")
+ while i < num_points:
+ line = nvm_f.readline()
+ if line == "\n":
+ continue
+
+ data = line.strip("\n").split(" ")
+ x, y, z, r, g, b, num_observations = data[:7]
+ obs_image_ids, point2D_idxs = [], []
+ for j in range(int(num_observations)):
+ s = 7 + 4 * j
+ img_index, kp_index, kx, ky = data[s : s + 4]
+ image_idx_to_keypoints[int(img_index)].append(
+ (int(kp_index), float(kx), float(ky), i)
+ )
+ db_image_id = image_idx_to_db_image_id[int(img_index)]
+ obs_image_ids.append(db_image_id)
+ point2D_idxs.append(kp_index)
+
+ point = Point3D(
+ id=i,
+ xyz=np.array([x, y, z], float),
+ rgb=np.array([r, g, b], int),
+ error=1.0, # fake
+ image_ids=np.array(obs_image_ids, int),
+ point2D_idxs=np.array(point2D_idxs, int),
+ )
+ points3D[i] = point
+
+ i += 1
+ pbar.update(1)
+ pbar.close()
+
+ logger.info("Parsing image data...")
+ images = {}
+ for i, data in enumerate(image_data):
+ # Skip the focal length. Skip the distortion and terminal 0.
+ name, _, qw, qx, qy, qz, cx, cy, cz, _, _ = data
+ qvec = np.array([qw, qx, qy, qz], float)
+ c = np.array([cx, cy, cz], float)
+ t = camera_center_to_translation(c, qvec)
+
+ if i in image_idx_to_keypoints:
+ # NVM only stores triangulated 2D keypoints: add dummy ones
+ keypoints = image_idx_to_keypoints[i]
+ point2D_idxs = np.array([d[0] for d in keypoints])
+ tri_xys = np.array([[x, y] for _, x, y, _ in keypoints])
+ tri_ids = np.array([i for _, _, _, i in keypoints])
+
+ num_2Dpoints = max(point2D_idxs) + 1
+ xys = np.zeros((num_2Dpoints, 2), float)
+ point3D_ids = np.full(num_2Dpoints, -1, int)
+ xys[point2D_idxs] = tri_xys
+ point3D_ids[point2D_idxs] = tri_ids
+ else:
+ xys = np.zeros((0, 2), float)
+ point3D_ids = np.full(0, -1, int)
+
+ image_id = image_ids[name]
+ image = Image(
+ id=image_id,
+ qvec=qvec,
+ tvec=t,
+ camera_id=camera_ids[name],
+ name=name,
+ xys=xys,
+ point3D_ids=point3D_ids,
+ )
+ images[image_id] = image
+
+ return cameras, images, points3D
+
+
+def main(nvm, intrinsics, database, output, skip_points=False):
+ assert nvm.exists(), nvm
+ assert intrinsics.exists(), intrinsics
+ assert database.exists(), database
+
+ image_ids, camera_ids = recover_database_images_and_ids(database)
+
+ logger.info("Reading the NVM model...")
+ model = read_nvm_model(
+ nvm, intrinsics, image_ids, camera_ids, skip_points=skip_points
+ )
+
+ logger.info("Writing the COLMAP model...")
+ output.mkdir(exist_ok=True, parents=True)
+ write_model(*model, path=str(output), ext=".bin")
+ logger.info("Done.")
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--nvm", required=True, type=Path)
+ parser.add_argument("--intrinsics", required=True, type=Path)
+ parser.add_argument("--database", required=True, type=Path)
+ parser.add_argument("--output", required=True, type=Path)
+ parser.add_argument("--skip_points", action="store_true")
+ args = parser.parse_args()
+ main(**args.__dict__)
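+
+# Example invocation from the repository root (file names are illustrative only):
+#   python -m imcui.hloc.colmap_from_nvm --nvm model.nvm --intrinsics intrinsics.txt \
+#       --database database.db --output sfm_colmap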
diff --git a/imcui/hloc/extract_features.py b/imcui/hloc/extract_features.py
new file mode 100644
index 0000000000000000000000000000000000000000..8b6a5c76f7c8bffa41fb82b4d0ec3dbf09ffcf3e
--- /dev/null
+++ b/imcui/hloc/extract_features.py
@@ -0,0 +1,607 @@
+import argparse
+import collections.abc as collections
+import pprint
+from pathlib import Path
+from types import SimpleNamespace
+from typing import Dict, List, Optional, Union
+
+import cv2
+import h5py
+import numpy as np
+import PIL.Image
+import torch
+import torchvision.transforms.functional as F
+from tqdm import tqdm
+
+from . import extractors, logger
+from .utils.base_model import dynamic_load
+from .utils.io import list_h5_names, read_image
+from .utils.parsers import parse_image_lists
+
+"""
+A set of standard configurations that can be directly selected from the command
+line using their name. Each is a dictionary with the following entries:
+ - output: the name of the feature file that will be generated.
+ - model: the model configuration, as passed to a feature extractor.
+ - preprocessing: how to preprocess the images read from disk.
+"""
+confs = {
+ "superpoint_aachen": {
+ "output": "feats-superpoint-n4096-r1024",
+ "model": {
+ "name": "superpoint",
+ "nms_radius": 3,
+ "max_keypoints": 4096,
+ "keypoint_threshold": 0.005,
+ },
+ "preprocessing": {
+ "grayscale": True,
+ "force_resize": True,
+ "resize_max": 1600,
+ "width": 640,
+ "height": 480,
+ "dfactor": 8,
+ },
+ },
+ # Resize images to 1600px even if they are originally smaller.
+ # Improves the keypoint localization if the images are of good quality.
+ "superpoint_max": {
+ "output": "feats-superpoint-n4096-rmax1600",
+ "model": {
+ "name": "superpoint",
+ "nms_radius": 3,
+ "max_keypoints": 4096,
+ "keypoint_threshold": 0.005,
+ },
+ "preprocessing": {
+ "grayscale": True,
+ "force_resize": True,
+ "resize_max": 1600,
+ "width": 640,
+ "height": 480,
+ "dfactor": 8,
+ },
+ },
+ "superpoint_inloc": {
+ "output": "feats-superpoint-n4096-r1600",
+ "model": {
+ "name": "superpoint",
+ "nms_radius": 4,
+ "max_keypoints": 4096,
+ "keypoint_threshold": 0.005,
+ },
+ "preprocessing": {
+ "grayscale": True,
+ "resize_max": 1600,
+ },
+ },
+ "r2d2": {
+ "output": "feats-r2d2-n5000-r1024",
+ "model": {
+ "name": "r2d2",
+ "max_keypoints": 5000,
+ "reliability_threshold": 0.7,
+ "repetability_threshold": 0.7,
+ },
+ "preprocessing": {
+ "grayscale": False,
+ "force_resize": True,
+ "resize_max": 1024,
+ "width": 640,
+ "height": 480,
+ "dfactor": 8,
+ },
+ },
+ "d2net-ss": {
+ "output": "feats-d2net-ss-n5000-r1600",
+ "model": {
+ "name": "d2net",
+ "multiscale": False,
+ "max_keypoints": 5000,
+ },
+ "preprocessing": {
+ "grayscale": False,
+ "resize_max": 1600,
+ },
+ },
+ "d2net-ms": {
+ "output": "feats-d2net-ms-n5000-r1600",
+ "model": {
+ "name": "d2net",
+ "multiscale": True,
+ "max_keypoints": 5000,
+ },
+ "preprocessing": {
+ "grayscale": False,
+ "resize_max": 1600,
+ },
+ },
+ "rord": {
+ "output": "feats-rord-ss-n5000-r1600",
+ "model": {
+ "name": "rord",
+ "multiscale": False,
+ "max_keypoints": 5000,
+ },
+ "preprocessing": {
+ "grayscale": False,
+ "resize_max": 1600,
+ },
+ },
+ "rootsift": {
+ "output": "feats-rootsift-n5000-r1600",
+ "model": {
+ "name": "dog",
+ "descriptor": "rootsift",
+ "max_keypoints": 5000,
+ },
+ "preprocessing": {
+ "grayscale": True,
+ "force_resize": True,
+ "resize_max": 1600,
+ "width": 640,
+ "height": 480,
+ "dfactor": 8,
+ },
+ },
+ "sift": {
+ "output": "feats-sift-n5000-r1600",
+ "model": {
+ "name": "sift",
+ "rootsift": True,
+ "max_keypoints": 5000,
+ },
+ "preprocessing": {
+ "grayscale": True,
+ "force_resize": True,
+ "resize_max": 1600,
+ "width": 640,
+ "height": 480,
+ "dfactor": 8,
+ },
+ },
+ "sosnet": {
+ "output": "feats-sosnet-n5000-r1600",
+ "model": {
+ "name": "dog",
+ "descriptor": "sosnet",
+ "max_keypoints": 5000,
+ },
+ "preprocessing": {
+ "grayscale": True,
+ "resize_max": 1600,
+ "force_resize": True,
+ "width": 640,
+ "height": 480,
+ "dfactor": 8,
+ },
+ },
+ "hardnet": {
+ "output": "feats-hardnet-n5000-r1600",
+ "model": {
+ "name": "dog",
+ "descriptor": "hardnet",
+ "max_keypoints": 5000,
+ },
+ "preprocessing": {
+ "grayscale": True,
+ "resize_max": 1600,
+ "force_resize": True,
+ "width": 640,
+ "height": 480,
+ "dfactor": 8,
+ },
+ },
+ "disk": {
+ "output": "feats-disk-n5000-r1600",
+ "model": {
+ "name": "disk",
+ "max_keypoints": 5000,
+ },
+ "preprocessing": {
+ "grayscale": False,
+ "resize_max": 1600,
+ },
+ },
+ "xfeat": {
+ "output": "feats-xfeat-n5000-r1600",
+ "model": {
+ "name": "xfeat",
+ "max_keypoints": 5000,
+ },
+ "preprocessing": {
+ "grayscale": False,
+ "resize_max": 1600,
+ },
+ },
+ "aliked-n16-rot": {
+ "output": "feats-aliked-n16-rot",
+ "model": {
+ "name": "aliked",
+ "model_name": "aliked-n16rot",
+ "max_num_keypoints": -1,
+ "detection_threshold": 0.2,
+ "nms_radius": 2,
+ },
+ "preprocessing": {
+ "grayscale": False,
+ "resize_max": 1024,
+ },
+ },
+ "aliked-n16": {
+ "output": "feats-aliked-n16",
+ "model": {
+ "name": "aliked",
+ "model_name": "aliked-n16",
+ "max_num_keypoints": -1,
+ "detection_threshold": 0.2,
+ "nms_radius": 2,
+ },
+ "preprocessing": {
+ "grayscale": False,
+ "resize_max": 1024,
+ },
+ },
+ "alike": {
+ "output": "feats-alike-n5000-r1600",
+ "model": {
+ "name": "alike",
+ "max_keypoints": 5000,
+ "use_relu": True,
+ "multiscale": False,
+ "detection_threshold": 0.5,
+ "top_k": -1,
+ "sub_pixel": False,
+ },
+ "preprocessing": {
+ "grayscale": False,
+ "resize_max": 1600,
+ },
+ },
+ "lanet": {
+ "output": "feats-lanet-n5000-r1600",
+ "model": {
+ "name": "lanet",
+ "keypoint_threshold": 0.1,
+ "max_keypoints": 5000,
+ },
+ "preprocessing": {
+ "grayscale": False,
+ "resize_max": 1600,
+ },
+ },
+ "darkfeat": {
+ "output": "feats-darkfeat-n5000-r1600",
+ "model": {
+ "name": "darkfeat",
+ "max_keypoints": 5000,
+ "reliability_threshold": 0.7,
+ "repetability_threshold": 0.7,
+ },
+ "preprocessing": {
+ "grayscale": False,
+ "force_resize": True,
+ "resize_max": 1600,
+ "width": 640,
+ "height": 480,
+ "dfactor": 8,
+ },
+ },
+ "dedode": {
+ "output": "feats-dedode-n5000-r1600",
+ "model": {
+ "name": "dedode",
+ "max_keypoints": 5000,
+ },
+ "preprocessing": {
+ "grayscale": False,
+ "force_resize": True,
+ "resize_max": 1600,
+ "width": 768,
+ "height": 768,
+ "dfactor": 8,
+ },
+ },
+ "example": {
+ "output": "feats-example-n2000-r1024",
+ "model": {
+ "name": "example",
+ "keypoint_threshold": 0.1,
+ "max_keypoints": 2000,
+ "model_name": "model.pth",
+ },
+ "preprocessing": {
+ "grayscale": False,
+ "force_resize": True,
+ "resize_max": 1024,
+ "width": 768,
+ "height": 768,
+ "dfactor": 8,
+ },
+ },
+ "sfd2": {
+ "output": "feats-sfd2-n4096-r1600",
+ "model": {
+ "name": "sfd2",
+ "max_keypoints": 4096,
+ },
+ "preprocessing": {
+ "grayscale": False,
+ "force_resize": True,
+ "resize_max": 1600,
+ "width": 640,
+ "height": 480,
+ "conf_th": 0.001,
+ "multiscale": False,
+ "scales": [1.0],
+ },
+ },
+ # Global descriptors
+ "dir": {
+ "output": "global-feats-dir",
+ "model": {"name": "dir"},
+ "preprocessing": {"resize_max": 1024},
+ },
+ "netvlad": {
+ "output": "global-feats-netvlad",
+ "model": {"name": "netvlad"},
+ "preprocessing": {"resize_max": 1024},
+ },
+ "openibl": {
+ "output": "global-feats-openibl",
+ "model": {"name": "openibl"},
+ "preprocessing": {"resize_max": 1024},
+ },
+ "cosplace": {
+ "output": "global-feats-cosplace",
+ "model": {"name": "cosplace"},
+ "preprocessing": {"resize_max": 1024},
+ },
+ "eigenplaces": {
+ "output": "global-feats-eigenplaces",
+ "model": {"name": "eigenplaces"},
+ "preprocessing": {"resize_max": 1024},
+ },
+}
+
+
+def resize_image(image, size, interp):
+ if interp.startswith("cv2_"):
+ interp = getattr(cv2, "INTER_" + interp[len("cv2_") :].upper())
+ h, w = image.shape[:2]
+ if interp == cv2.INTER_AREA and (w < size[0] or h < size[1]):
+ interp = cv2.INTER_LINEAR
+ resized = cv2.resize(image, size, interpolation=interp)
+ elif interp.startswith("pil_"):
+ interp = getattr(PIL.Image, interp[len("pil_") :].upper())
+ resized = PIL.Image.fromarray(image.astype(np.uint8))
+ resized = resized.resize(size, resample=interp)
+ resized = np.asarray(resized, dtype=image.dtype)
+ else:
+ raise ValueError(f"Unknown interpolation {interp}.")
+ return resized
+
+
+class ImageDataset(torch.utils.data.Dataset):
+ default_conf = {
+ "globs": ["*.jpg", "*.png", "*.jpeg", "*.JPG", "*.PNG"],
+ "grayscale": False,
+ "resize_max": None,
+ "force_resize": False,
+ "interpolation": "cv2_area", # pil_linear is more accurate but slower
+ }
+
+ def __init__(self, root, conf, paths=None):
+ self.conf = conf = SimpleNamespace(**{**self.default_conf, **conf})
+ self.root = root
+
+ if paths is None:
+ paths = []
+ for g in conf.globs:
+ paths += list(Path(root).glob("**/" + g))
+ if len(paths) == 0:
+ raise ValueError(f"Could not find any image in root: {root}.")
+ paths = sorted(list(set(paths)))
+ self.names = [i.relative_to(root).as_posix() for i in paths]
+ logger.info(f"Found {len(self.names)} images in root {root}.")
+ else:
+ if isinstance(paths, (Path, str)):
+ self.names = parse_image_lists(paths)
+ elif isinstance(paths, collections.Iterable):
+ self.names = [p.as_posix() if isinstance(p, Path) else p for p in paths]
+ else:
+ raise ValueError(f"Unknown format for path argument {paths}.")
+
+ for name in self.names:
+ if not (root / name).exists():
+ raise ValueError(f"Image {name} does not exists in root: {root}.")
+
+ def __getitem__(self, idx):
+ name = self.names[idx]
+ image = read_image(self.root / name, self.conf.grayscale)
+ image = image.astype(np.float32)
+ size = image.shape[:2][::-1]
+
+ if self.conf.resize_max and (
+ self.conf.force_resize or max(size) > self.conf.resize_max
+ ):
+ scale = self.conf.resize_max / max(size)
+ size_new = tuple(int(round(x * scale)) for x in size)
+ image = resize_image(image, size_new, self.conf.interpolation)
+
+ if self.conf.grayscale:
+ image = image[None]
+ else:
+ image = image.transpose((2, 0, 1)) # HxWxC to CxHxW
+ image = image / 255.0
+
+ data = {
+ "image": image,
+ "original_size": np.array(size),
+ }
+ return data
+
+ def __len__(self):
+ return len(self.names)
+
+
+def extract(model, image_0, conf):
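+    """Run an already-instantiated extractor on a single image.
+
+    `image_0` is expected to be an HxWxC (RGB) or HxW numpy array; the returned
+    dict contains the raw model predictions plus the preprocessed tensors.
+    """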
+ default_conf = {
+ "grayscale": True,
+ "resize_max": 1024,
+ "dfactor": 8,
+ "cache_images": False,
+ "force_resize": False,
+ "width": 320,
+ "height": 240,
+ "interpolation": "cv2_area",
+ }
+ conf = SimpleNamespace(**{**default_conf, **conf})
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ def preprocess(image: np.ndarray, conf: SimpleNamespace):
+ image = image.astype(np.float32, copy=False)
+ size = image.shape[:2][::-1]
+ scale = np.array([1.0, 1.0])
+ if conf.resize_max:
+ scale = conf.resize_max / max(size)
+ if scale < 1.0:
+ size_new = tuple(int(round(x * scale)) for x in size)
+ image = resize_image(image, size_new, "cv2_area")
+ scale = np.array(size) / np.array(size_new)
+ if conf.force_resize:
+ image = resize_image(image, (conf.width, conf.height), "cv2_area")
+ size_new = (conf.width, conf.height)
+ scale = np.array(size) / np.array(size_new)
+ if conf.grayscale:
+ assert image.ndim == 2, image.shape
+ image = image[None]
+ else:
+ image = image.transpose((2, 0, 1)) # HxWxC to CxHxW
+ image = torch.from_numpy(image / 255.0).float()
+
+        # ensure that the size is divisible by dfactor
+ size_new = tuple(
+ map(
+ lambda x: int(x // conf.dfactor * conf.dfactor),
+ image.shape[-2:],
+ )
+ )
+ image = F.resize(image, size=size_new, antialias=True)
+ input_ = image.to(device, non_blocking=True)[None]
+ data = {
+ "image": input_,
+ "image_orig": image_0,
+ "original_size": np.array(size),
+ "size": np.array(image.shape[1:][::-1]),
+ }
+ return data
+
+ # convert to grayscale if needed
+ if len(image_0.shape) == 3 and conf.grayscale:
+ image0 = cv2.cvtColor(image_0, cv2.COLOR_RGB2GRAY)
+ else:
+ image0 = image_0
+    # The lines below stay commented out: the input image is already in RGB mode.
+ # if not conf.grayscale and len(image_0.shape) == 3:
+ # image0 = image_0[:, :, ::-1] # BGR to RGB
+ data = preprocess(image0, conf)
+ pred = model({"image": data["image"]})
+ pred["image_size"] = data["original_size"]
+ pred = {**pred, **data}
+ return pred
+
+
+@torch.no_grad()
+def main(
+ conf: Dict,
+ image_dir: Path,
+ export_dir: Optional[Path] = None,
+ as_half: bool = True,
+ image_list: Optional[Union[Path, List[str]]] = None,
+ feature_path: Optional[Path] = None,
+ overwrite: bool = False,
+) -> Path:
+ logger.info(
+ "Extracting local features with configuration:" f"\n{pprint.pformat(conf)}"
+ )
+
+ dataset = ImageDataset(image_dir, conf["preprocessing"], image_list)
+ if feature_path is None:
+ feature_path = Path(export_dir, conf["output"] + ".h5")
+ feature_path.parent.mkdir(exist_ok=True, parents=True)
+ skip_names = set(
+ list_h5_names(feature_path) if feature_path.exists() and not overwrite else ()
+ )
+ dataset.names = [n for n in dataset.names if n not in skip_names]
+ if len(dataset.names) == 0:
+ logger.info("Skipping the extraction.")
+ return feature_path
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ Model = dynamic_load(extractors, conf["model"]["name"])
+ model = Model(conf["model"]).eval().to(device)
+
+ loader = torch.utils.data.DataLoader(
+ dataset, num_workers=1, shuffle=False, pin_memory=True
+ )
+ for idx, data in enumerate(tqdm(loader)):
+ name = dataset.names[idx]
+ pred = model({"image": data["image"].to(device, non_blocking=True)})
+ pred = {k: v[0].cpu().numpy() for k, v in pred.items()}
+
+ pred["image_size"] = original_size = data["original_size"][0].numpy()
+ if "keypoints" in pred:
+ size = np.array(data["image"].shape[-2:][::-1])
+ scales = (original_size / size).astype(np.float32)
+ pred["keypoints"] = (pred["keypoints"] + 0.5) * scales[None] - 0.5
+ if "scales" in pred:
+ pred["scales"] *= scales.mean()
+ # add keypoint uncertainties scaled to the original resolution
+ uncertainty = getattr(model, "detection_noise", 1) * scales.mean()
+
+ if as_half:
+ for k in pred:
+ dt = pred[k].dtype
+ if (dt == np.float32) and (dt != np.float16):
+ pred[k] = pred[k].astype(np.float16)
+
+ with h5py.File(str(feature_path), "a", libver="latest") as fd:
+ try:
+ if name in fd:
+ del fd[name]
+ grp = fd.create_group(name)
+ for k, v in pred.items():
+ grp.create_dataset(k, data=v)
+ if "keypoints" in pred:
+ grp["keypoints"].attrs["uncertainty"] = uncertainty
+ except OSError as error:
+ if "No space left on device" in error.args[0]:
+ logger.error(
+ "Out of disk space: storing features on disk can take "
+ "significant space, did you enable the as_half flag?"
+ )
+ del grp, fd[name]
+ raise error
+
+ del pred
+
+ logger.info("Finished exporting features.")
+ return feature_path
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--image_dir", type=Path, required=True)
+ parser.add_argument("--export_dir", type=Path, required=True)
+ parser.add_argument(
+ "--conf",
+ type=str,
+ default="superpoint_aachen",
+ choices=list(confs.keys()),
+ )
+ parser.add_argument("--as_half", action="store_true")
+ parser.add_argument("--image_list", type=Path)
+ parser.add_argument("--feature_path", type=Path)
+ args = parser.parse_args()
+    main(
+        confs[args.conf],
+        args.image_dir,
+        args.export_dir,
+        args.as_half,
+        args.image_list,
+        args.feature_path,
+    )
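+
+# Example invocation from the repository root (paths are illustrative only):
+#   python -m imcui.hloc.extract_features --image_dir datasets/images \
+#       --export_dir outputs/features --conf superpoint_aachen --as_half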
diff --git a/imcui/hloc/extractors/__init__.py b/imcui/hloc/extractors/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/imcui/hloc/extractors/alike.py b/imcui/hloc/extractors/alike.py
new file mode 100644
index 0000000000000000000000000000000000000000..64724dc035c98ce01fa0bbb98b4772a993eb1526
--- /dev/null
+++ b/imcui/hloc/extractors/alike.py
@@ -0,0 +1,61 @@
+import sys
+from pathlib import Path
+
+import torch
+
+from .. import MODEL_REPO_ID, logger
+
+from ..utils.base_model import BaseModel
+
+alike_path = Path(__file__).parent / "../../third_party/ALIKE"
+sys.path.append(str(alike_path))
+from alike import ALike as Alike_
+from alike import configs
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+
+class Alike(BaseModel):
+ default_conf = {
+ "model_name": "alike-t", # 'alike-t', 'alike-s', 'alike-n', 'alike-l'
+ "use_relu": True,
+ "multiscale": False,
+ "max_keypoints": 1000,
+ "detection_threshold": 0.5,
+ "top_k": -1,
+ "sub_pixel": False,
+ }
+
+ required_inputs = ["image"]
+
+ def _init(self, conf):
+ model_path = self._download_model(
+ repo_id=MODEL_REPO_ID,
+ filename="{}/{}.pth".format(Path(__file__).stem, self.conf["model_name"]),
+ )
+ logger.info("Loaded Alike model from {}".format(model_path))
+ configs[conf["model_name"]]["model_path"] = model_path
+ self.net = Alike_(
+ **configs[conf["model_name"]],
+ device=device,
+ top_k=conf["top_k"],
+ scores_th=conf["detection_threshold"],
+ n_limit=conf["max_keypoints"],
+ )
+ logger.info("Load Alike model done.")
+
+ def _forward(self, data):
+ image = data["image"]
+ image = image.permute(0, 2, 3, 1).squeeze()
+ image = image.cpu().numpy() * 255.0
+ pred = self.net(image, sub_pixel=self.conf["sub_pixel"])
+
+ keypoints = pred["keypoints"]
+ descriptors = pred["descriptors"]
+ scores = pred["scores"]
+
+ return {
+ "keypoints": torch.from_numpy(keypoints)[None],
+ "scores": torch.from_numpy(scores)[None],
+ "descriptors": torch.from_numpy(descriptors.T)[None],
+ }
diff --git a/imcui/hloc/extractors/aliked.py b/imcui/hloc/extractors/aliked.py
new file mode 100644
index 0000000000000000000000000000000000000000..4f712bebd7c8a1a8052cff22064f19c0a7b13615
--- /dev/null
+++ b/imcui/hloc/extractors/aliked.py
@@ -0,0 +1,32 @@
+import sys
+from pathlib import Path
+
+from ..utils.base_model import BaseModel
+
+lightglue_path = Path(__file__).parent / "../../third_party/LightGlue"
+sys.path.append(str(lightglue_path))
+
+from lightglue import ALIKED as ALIKED_
+
+
+class ALIKED(BaseModel):
+ default_conf = {
+ "model_name": "aliked-n16",
+ "max_num_keypoints": -1,
+ "detection_threshold": 0.2,
+ "nms_radius": 2,
+ }
+ required_inputs = ["image"]
+
+ def _init(self, conf):
+ conf.pop("name")
+ self.model = ALIKED_(**conf)
+
+ def _forward(self, data):
+ features = self.model(data)
+
+ return {
+ "keypoints": [f for f in features["keypoints"]],
+ "scores": [f for f in features["keypoint_scores"]],
+ "descriptors": [f.t() for f in features["descriptors"]],
+ }
diff --git a/imcui/hloc/extractors/cosplace.py b/imcui/hloc/extractors/cosplace.py
new file mode 100644
index 0000000000000000000000000000000000000000..8d13a84d57d80bee090709623cce74453784844b
--- /dev/null
+++ b/imcui/hloc/extractors/cosplace.py
@@ -0,0 +1,44 @@
+"""
+Code for loading models trained with CosPlace as a global features extractor
+for geolocalization through image retrieval.
+Multiple models are available with different backbones. Below is a summary of
+models available (backbone : list of available output descriptors
+dimensionality). For example you can use a model based on a ResNet50 with
+descriptors dimensionality 1024.
+ ResNet18: [32, 64, 128, 256, 512]
+ ResNet50: [32, 64, 128, 256, 512, 1024, 2048]
+ ResNet101: [32, 64, 128, 256, 512, 1024, 2048]
+ ResNet152: [32, 64, 128, 256, 512, 1024, 2048]
+ VGG16: [ 64, 128, 256, 512]
+
+CosPlace paper: https://arxiv.org/abs/2204.02287
+"""
+
+import torch
+import torchvision.transforms as tvf
+
+from ..utils.base_model import BaseModel
+
+
+class CosPlace(BaseModel):
+ default_conf = {"backbone": "ResNet50", "fc_output_dim": 2048}
+ required_inputs = ["image"]
+
+ def _init(self, conf):
+ self.net = torch.hub.load(
+ "gmberton/CosPlace",
+ "get_trained_model",
+ backbone=conf["backbone"],
+ fc_output_dim=conf["fc_output_dim"],
+ ).eval()
+
+ mean = [0.485, 0.456, 0.406]
+ std = [0.229, 0.224, 0.225]
+ self.norm_rgb = tvf.Normalize(mean=mean, std=std)
+
+ def _forward(self, data):
+ image = self.norm_rgb(data["image"])
+ desc = self.net(image)
+ return {
+ "global_descriptor": desc,
+ }
diff --git a/imcui/hloc/extractors/d2net.py b/imcui/hloc/extractors/d2net.py
new file mode 100644
index 0000000000000000000000000000000000000000..207977c732e14ae6fde1e02d3e7f4335fbdf57e9
--- /dev/null
+++ b/imcui/hloc/extractors/d2net.py
@@ -0,0 +1,60 @@
+import sys
+from pathlib import Path
+
+import torch
+
+from .. import MODEL_REPO_ID, logger
+from ..utils.base_model import BaseModel
+
+d2net_path = Path(__file__).parent / "../../third_party/d2net"
+sys.path.append(str(d2net_path))
+from lib.model_test import D2Net as _D2Net
+from lib.pyramid import process_multiscale
+
+
+class D2Net(BaseModel):
+ default_conf = {
+ "model_name": "d2_tf.pth",
+ "checkpoint_dir": d2net_path / "models",
+ "use_relu": True,
+ "multiscale": False,
+ "max_keypoints": 1024,
+ }
+ required_inputs = ["image"]
+
+ def _init(self, conf):
+ logger.info("Loading D2Net model...")
+ model_path = self._download_model(
+ repo_id=MODEL_REPO_ID,
+ filename="{}/{}".format(Path(__file__).stem, self.conf["model_name"]),
+ )
+ logger.info(f"Loading model from {model_path}...")
+ self.net = _D2Net(
+ model_file=model_path, use_relu=conf["use_relu"], use_cuda=False
+ )
+ logger.info("Load D2Net model done.")
+
+ def _forward(self, data):
+ image = data["image"]
+ image = image.flip(1) # RGB -> BGR
+ norm = image.new_tensor([103.939, 116.779, 123.68])
+ image = image * 255 - norm.view(1, 3, 1, 1) # caffe normalization
+
+ if self.conf["multiscale"]:
+ keypoints, scores, descriptors = process_multiscale(image, self.net)
+ else:
+ keypoints, scores, descriptors = process_multiscale(
+ image, self.net, scales=[1]
+ )
+ keypoints = keypoints[:, [1, 0]] # (x, y) and remove the scale
+
+ idxs = scores.argsort()[-self.conf["max_keypoints"] or None :]
+ keypoints = keypoints[idxs, :2]
+ descriptors = descriptors[idxs]
+ scores = scores[idxs]
+
+ return {
+ "keypoints": torch.from_numpy(keypoints)[None],
+ "scores": torch.from_numpy(scores)[None],
+ "descriptors": torch.from_numpy(descriptors.T)[None],
+ }
diff --git a/imcui/hloc/extractors/darkfeat.py b/imcui/hloc/extractors/darkfeat.py
new file mode 100644
index 0000000000000000000000000000000000000000..8833041e9a168f465df0d07191245777612da890
--- /dev/null
+++ b/imcui/hloc/extractors/darkfeat.py
@@ -0,0 +1,44 @@
+import sys
+from pathlib import Path
+
+from .. import MODEL_REPO_ID, logger
+
+from ..utils.base_model import BaseModel
+
+darkfeat_path = Path(__file__).parent / "../../third_party/DarkFeat"
+sys.path.append(str(darkfeat_path))
+from darkfeat import DarkFeat as DarkFeat_
+
+
+class DarkFeat(BaseModel):
+ default_conf = {
+ "model_name": "DarkFeat.pth",
+ "max_keypoints": 1000,
+ "detection_threshold": 0.5,
+ "sub_pixel": False,
+ }
+ required_inputs = ["image"]
+
+ def _init(self, conf):
+ model_path = self._download_model(
+ repo_id=MODEL_REPO_ID,
+ filename="{}/{}".format(Path(__file__).stem, self.conf["model_name"]),
+ )
+ logger.info("Loaded DarkFeat model: {}".format(model_path))
+ self.net = DarkFeat_(model_path)
+ logger.info("Load DarkFeat model done.")
+
+ def _forward(self, data):
+ pred = self.net({"image": data["image"]})
+ keypoints = pred["keypoints"]
+ descriptors = pred["descriptors"]
+ scores = pred["scores"]
+ idxs = scores.argsort()[-self.conf["max_keypoints"] or None :]
+ keypoints = keypoints[idxs, :2]
+ descriptors = descriptors[:, idxs]
+ scores = scores[idxs]
+ return {
+ "keypoints": keypoints[None], # 1 x N x 2
+ "scores": scores[None], # 1 x N
+ "descriptors": descriptors[None], # 1 x 128 x N
+ }
diff --git a/imcui/hloc/extractors/dedode.py b/imcui/hloc/extractors/dedode.py
new file mode 100644
index 0000000000000000000000000000000000000000..a7108e31340535afcef062c1d8eb495014b70ee1
--- /dev/null
+++ b/imcui/hloc/extractors/dedode.py
@@ -0,0 +1,86 @@
+import sys
+from pathlib import Path
+
+import torch
+import torchvision.transforms as transforms
+
+from .. import MODEL_REPO_ID, logger
+
+from ..utils.base_model import BaseModel
+
+dedode_path = Path(__file__).parent / "../../third_party/DeDoDe"
+sys.path.append(str(dedode_path))
+
+from DeDoDe import dedode_descriptor_B, dedode_detector_L
+from DeDoDe.utils import to_pixel_coords
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+
+class DeDoDe(BaseModel):
+ default_conf = {
+ "name": "dedode",
+ "model_detector_name": "dedode_detector_L.pth",
+ "model_descriptor_name": "dedode_descriptor_B.pth",
+ "max_keypoints": 2000,
+ "match_threshold": 0.2,
+ "dense": False, # Now fixed to be false
+ }
+ required_inputs = [
+ "image",
+ ]
+
+ # Initialize the line matcher
+ def _init(self, conf):
+ model_detector_path = self._download_model(
+ repo_id=MODEL_REPO_ID,
+ filename="{}/{}".format(Path(__file__).stem, conf["model_detector_name"]),
+ )
+ model_descriptor_path = self._download_model(
+ repo_id=MODEL_REPO_ID,
+ filename="{}/{}".format(Path(__file__).stem, conf["model_descriptor_name"]),
+ )
+ logger.info("Loaded DarkFeat model: {}".format(model_detector_path))
+ self.normalizer = transforms.Normalize(
+ mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
+ )
+
+ # load the model
+ weights_detector = torch.load(model_detector_path, map_location="cpu")
+ weights_descriptor = torch.load(model_descriptor_path, map_location="cpu")
+ self.detector = dedode_detector_L(weights=weights_detector, device=device)
+ self.descriptor = dedode_descriptor_B(weights=weights_descriptor, device=device)
+ logger.info("Load DeDoDe model done.")
+
+ def _forward(self, data):
+ """
+        data: dict, keys: {'image'}
+ image shape: N x C x H x W
+ color mode: RGB
+ """
+ img0 = self.normalizer(data["image"].squeeze()).float()[None]
+ H_A, W_A = img0.shape[2:]
+
+ # step 1: detect keypoints
+ detections_A = None
+ batch_A = {"image": img0}
+ if self.conf["dense"]:
+ detections_A = self.detector.detect_dense(batch_A)
+ else:
+ detections_A = self.detector.detect(
+ batch_A, num_keypoints=self.conf["max_keypoints"]
+ )
+ keypoints_A, P_A = detections_A["keypoints"], detections_A["confidence"]
+
+ # step 2: describe keypoints
+ # dim: 1 x N x 256
+ description_A = self.descriptor.describe_keypoints(batch_A, keypoints_A)[
+ "descriptions"
+ ]
+ keypoints_A = to_pixel_coords(keypoints_A, H_A, W_A)
+
+ return {
+ "keypoints": keypoints_A, # 1 x N x 2
+ "descriptors": description_A.permute(0, 2, 1), # 1 x 256 x N
+ "scores": P_A, # 1 x N
+ }
diff --git a/imcui/hloc/extractors/dir.py b/imcui/hloc/extractors/dir.py
new file mode 100644
index 0000000000000000000000000000000000000000..cd7322a922a151b0a5ad5e185fbb312a0b5d12a7
--- /dev/null
+++ b/imcui/hloc/extractors/dir.py
@@ -0,0 +1,78 @@
+import os
+import sys
+from pathlib import Path
+from zipfile import ZipFile
+
+import gdown
+import sklearn
+import torch
+
+from ..utils.base_model import BaseModel
+
+sys.path.append(str(Path(__file__).parent / "../../third_party/deep-image-retrieval"))
+os.environ["DB_ROOT"] = "" # required by dirtorch
+
+from dirtorch.extract_features import load_model # noqa: E402
+from dirtorch.utils import common # noqa: E402
+
+# The DIR model checkpoints (pickle files) include sklearn.decomposition.pca,
+# which has been deprecated in sklearn v0.24
+# and must be explicitly imported with `from sklearn.decomposition import PCA`.
+# This is a hacky workaround to maintain forward compatibility.
+sys.modules["sklearn.decomposition.pca"] = sklearn.decomposition._pca
+
+
+class DIR(BaseModel):
+ default_conf = {
+ "model_name": "Resnet-101-AP-GeM",
+ "whiten_name": "Landmarks_clean",
+ "whiten_params": {
+ "whitenp": 0.25,
+ "whitenv": None,
+ "whitenm": 1.0,
+ },
+ "pooling": "gem",
+ "gemp": 3,
+ }
+ required_inputs = ["image"]
+
+ dir_models = {
+ "Resnet-101-AP-GeM": "https://docs.google.com/uc?export=download&id=1UWJGDuHtzaQdFhSMojoYVQjmCXhIwVvy",
+ }
+
+ def _init(self, conf):
+ # todo: download from google drive -> huggingface models
+ checkpoint = Path(torch.hub.get_dir(), "dirtorch", conf["model_name"] + ".pt")
+ if not checkpoint.exists():
+ checkpoint.parent.mkdir(exist_ok=True, parents=True)
+ link = self.dir_models[conf["model_name"]]
+ gdown.download(str(link), str(checkpoint) + ".zip", quiet=False)
+ zf = ZipFile(str(checkpoint) + ".zip", "r")
+ zf.extractall(checkpoint.parent)
+ zf.close()
+ os.remove(str(checkpoint) + ".zip")
+
+ self.net = load_model(checkpoint, False) # first load on CPU
+ if conf["whiten_name"]:
+ assert conf["whiten_name"] in self.net.pca
+
+ def _forward(self, data):
+ image = data["image"]
+ assert image.shape[1] == 3
+ mean = self.net.preprocess["mean"]
+ std = self.net.preprocess["std"]
+ image = image - image.new_tensor(mean)[:, None, None]
+ image = image / image.new_tensor(std)[:, None, None]
+
+ desc = self.net(image)
+ desc = desc.unsqueeze(0) # batch dimension
+ if self.conf["whiten_name"]:
+ pca = self.net.pca[self.conf["whiten_name"]]
+ desc = common.whiten_features(
+ desc.cpu().numpy(), pca, **self.conf["whiten_params"]
+ )
+ desc = torch.from_numpy(desc)
+
+ return {
+ "global_descriptor": desc,
+ }
diff --git a/imcui/hloc/extractors/disk.py b/imcui/hloc/extractors/disk.py
new file mode 100644
index 0000000000000000000000000000000000000000..a062a908af68656c29e7ee1e8c5047c92790bcc9
--- /dev/null
+++ b/imcui/hloc/extractors/disk.py
@@ -0,0 +1,35 @@
+import kornia
+
+from .. import logger
+
+from ..utils.base_model import BaseModel
+
+
+class DISK(BaseModel):
+ default_conf = {
+ "weights": "depth",
+ "max_keypoints": None,
+ "nms_window_size": 5,
+ "detection_threshold": 0.0,
+ "pad_if_not_divisible": True,
+ }
+ required_inputs = ["image"]
+
+ def _init(self, conf):
+ self.model = kornia.feature.DISK.from_pretrained(conf["weights"])
+ logger.info("Load DISK model done.")
+
+ def _forward(self, data):
+ image = data["image"]
+ features = self.model(
+ image,
+ n=self.conf["max_keypoints"],
+ window_size=self.conf["nms_window_size"],
+ score_threshold=self.conf["detection_threshold"],
+ pad_if_not_divisible=self.conf["pad_if_not_divisible"],
+ )
+ return {
+ "keypoints": [f.keypoints for f in features][0][None],
+ "scores": [f.detection_scores for f in features][0][None],
+ "descriptors": [f.descriptors.t() for f in features][0][None],
+ }
diff --git a/imcui/hloc/extractors/dog.py b/imcui/hloc/extractors/dog.py
new file mode 100644
index 0000000000000000000000000000000000000000..b280bbc42376f3af827002bb85ff4996ccdf50b4
--- /dev/null
+++ b/imcui/hloc/extractors/dog.py
@@ -0,0 +1,135 @@
+import kornia
+import numpy as np
+import pycolmap
+import torch
+from kornia.feature.laf import (
+ extract_patches_from_pyramid,
+ laf_from_center_scale_ori,
+)
+
+from ..utils.base_model import BaseModel
+
+EPS = 1e-6
+
+
+def sift_to_rootsift(x):
+ x = x / (np.linalg.norm(x, ord=1, axis=-1, keepdims=True) + EPS)
+ x = np.sqrt(x.clip(min=EPS))
+ x = x / (np.linalg.norm(x, axis=-1, keepdims=True) + EPS)
+ return x
+
+
+class DoG(BaseModel):
+ default_conf = {
+ "options": {
+ "first_octave": 0,
+ "peak_threshold": 0.01,
+ },
+ "descriptor": "rootsift",
+ "max_keypoints": -1,
+ "patch_size": 32,
+ "mr_size": 12,
+ }
+ required_inputs = ["image"]
+ detection_noise = 1.0
+ max_batch_size = 1024
+
+ def _init(self, conf):
+ if conf["descriptor"] == "sosnet":
+ self.describe = kornia.feature.SOSNet(pretrained=True)
+ elif conf["descriptor"] == "hardnet":
+ self.describe = kornia.feature.HardNet(pretrained=True)
+ elif conf["descriptor"] not in ["sift", "rootsift"]:
+ raise ValueError(f'Unknown descriptor: {conf["descriptor"]}')
+
+ self.sift = None # lazily instantiated on the first image
+ self.dummy_param = torch.nn.Parameter(torch.empty(0))
+ self.device = torch.device("cpu")
+
+ def to(self, *args, **kwargs):
+ device = kwargs.get("device")
+ if device is None:
+ match = [a for a in args if isinstance(a, (torch.device, str))]
+ if len(match) > 0:
+ device = match[0]
+ if device is not None:
+ self.device = torch.device(device)
+ return super().to(*args, **kwargs)
+
+ def _forward(self, data):
+ image = data["image"]
+ image_np = image.cpu().numpy()[0, 0]
+ assert image.shape[1] == 1
+ assert image_np.min() >= -EPS and image_np.max() <= 1 + EPS
+
+ if self.sift is None:
+ device = self.dummy_param.device
+ use_gpu = pycolmap.has_cuda and device.type == "cuda"
+ options = {**self.conf["options"]}
+ if self.conf["descriptor"] == "rootsift":
+ options["normalization"] = pycolmap.Normalization.L1_ROOT
+ else:
+ options["normalization"] = pycolmap.Normalization.L2
+ self.sift = pycolmap.Sift(
+ options=pycolmap.SiftExtractionOptions(options),
+ device=getattr(pycolmap.Device, "cuda" if use_gpu else "cpu"),
+ )
+ keypoints, descriptors = self.sift.extract(image_np)
+ scales = keypoints[:, 2]
+ oris = np.rad2deg(keypoints[:, 3])
+
+ if self.conf["descriptor"] in ["sift", "rootsift"]:
+ # We still renormalize because COLMAP does not normalize well,
+ # maybe due to numerical errors
+ if self.conf["descriptor"] == "rootsift":
+ descriptors = sift_to_rootsift(descriptors)
+ descriptors = torch.from_numpy(descriptors)
+ elif self.conf["descriptor"] in ("sosnet", "hardnet"):
+ center = keypoints[:, :2] + 0.5
+ laf_scale = scales * self.conf["mr_size"] / 2
+ laf_ori = -oris
+ lafs = laf_from_center_scale_ori(
+ torch.from_numpy(center)[None],
+ torch.from_numpy(laf_scale)[None, :, None, None],
+ torch.from_numpy(laf_ori)[None, :, None],
+ ).to(image.device)
+ patches = extract_patches_from_pyramid(
+ image, lafs, PS=self.conf["patch_size"]
+ )[0]
+ descriptors = patches.new_zeros((len(patches), 128))
+ if len(patches) > 0:
+ for start_idx in range(0, len(patches), self.max_batch_size):
+ end_idx = min(len(patches), start_idx + self.max_batch_size)
+ descriptors[start_idx:end_idx] = self.describe(
+ patches[start_idx:end_idx]
+ )
+ else:
+ raise ValueError(f'Unknown descriptor: {self.conf["descriptor"]}')
+
+ keypoints = torch.from_numpy(keypoints[:, :2]) # keep only x, y
+ scales = torch.from_numpy(scales)
+ oris = torch.from_numpy(oris)
+ scores = keypoints.new_zeros(len(keypoints)) # no scores for SIFT yet
+
+ if self.conf["max_keypoints"] != -1:
+ # TODO: check that the scores from PyCOLMAP are 100% correct,
+ # follow https://github.com/mihaidusmanu/pycolmap/issues/8
+ max_number = (
+ scores.shape[0]
+ if scores.shape[0] < self.conf["max_keypoints"]
+ else self.conf["max_keypoints"]
+ )
+ values, indices = torch.topk(scores, max_number)
+ keypoints = keypoints[indices]
+ scales = scales[indices]
+ oris = oris[indices]
+ scores = scores[indices]
+ descriptors = descriptors[indices]
+
+ return {
+ "keypoints": keypoints[None],
+ "scales": scales[None],
+ "oris": oris[None],
+ "scores": scores[None],
+ "descriptors": descriptors.T[None],
+ }
diff --git a/imcui/hloc/extractors/eigenplaces.py b/imcui/hloc/extractors/eigenplaces.py
new file mode 100644
index 0000000000000000000000000000000000000000..fd9953b27c00682c830842736fd0bdab93857f14
--- /dev/null
+++ b/imcui/hloc/extractors/eigenplaces.py
@@ -0,0 +1,57 @@
+"""
+Code for loading models trained with EigenPlaces (or CosPlace) as a global
+features extractor for geolocalization through image retrieval.
+Multiple models are available with different backbones. Below is a summary of
+models available (backbone : list of available output descriptors
+dimensionality). For example you can use a model based on a ResNet50 with
+descriptors dimensionality 1024.
+
+EigenPlaces trained models:
+ ResNet18: [ 256, 512]
+ ResNet50: [128, 256, 512, 2048]
+ ResNet101: [128, 256, 512, 2048]
+ VGG16: [ 512]
+
+CosPlace trained models:
+ ResNet18: [32, 64, 128, 256, 512]
+ ResNet50: [32, 64, 128, 256, 512, 1024, 2048]
+ ResNet101: [32, 64, 128, 256, 512, 1024, 2048]
+ ResNet152: [32, 64, 128, 256, 512, 1024, 2048]
+ VGG16: [ 64, 128, 256, 512]
+
+EigenPlaces paper (ICCV 2023): https://arxiv.org/abs/2308.10832
+CosPlace paper (CVPR 2022): https://arxiv.org/abs/2204.02287
+"""
+
+import torch
+import torchvision.transforms as tvf
+
+from ..utils.base_model import BaseModel
+
+
+class EigenPlaces(BaseModel):
+ default_conf = {
+ "variant": "EigenPlaces",
+ "backbone": "ResNet101",
+ "fc_output_dim": 2048,
+ }
+ required_inputs = ["image"]
+
+ def _init(self, conf):
+ self.net = torch.hub.load(
+ "gmberton/" + conf["variant"],
+ "get_trained_model",
+ backbone=conf["backbone"],
+ fc_output_dim=conf["fc_output_dim"],
+ ).eval()
+
+ mean = [0.485, 0.456, 0.406]
+ std = [0.229, 0.224, 0.225]
+ self.norm_rgb = tvf.Normalize(mean=mean, std=std)
+
+ def _forward(self, data):
+ image = self.norm_rgb(data["image"])
+ desc = self.net(image)
+ return {
+ "global_descriptor": desc,
+ }
diff --git a/imcui/hloc/extractors/example.py b/imcui/hloc/extractors/example.py
new file mode 100644
index 0000000000000000000000000000000000000000..3d952c4014e006d74409a8f32ee7159d58305de5
--- /dev/null
+++ b/imcui/hloc/extractors/example.py
@@ -0,0 +1,56 @@
+import sys
+from pathlib import Path
+
+import torch
+
+from .. import logger
+from ..utils.base_model import BaseModel
+
+example_path = Path(__file__).parent / "../../third_party/example"
+sys.path.append(str(example_path))
+
+# import some modules here
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+
+class Example(BaseModel):
+ # change to your default configs
+ default_conf = {
+ "name": "example",
+ "keypoint_threshold": 0.1,
+ "max_keypoints": 2000,
+ "model_name": "model.pth",
+ }
+ required_inputs = ["image"]
+
+ def _init(self, conf):
+ # set checkpoints paths if needed
+ model_path = example_path / "checkpoints" / f'{conf["model_name"]}'
+ if not model_path.exists():
+ logger.info(f"No model found at {model_path}")
+
+        # init model: replace the "callable" placeholder below with your own network class
+ self.net = callable
+ # self.net = ExampleNet(is_test=True)
+ state_dict = torch.load(model_path, map_location="cpu")
+ self.net.load_state_dict(state_dict["model_state"])
+ logger.info("Load example model done.")
+
+ def _forward(self, data):
+ # data: dict, keys: 'image'
+ # image color mode: RGB
+ # image value range in [0, 1]
+ image = data["image"]
+
+ # B: batch size, N: number of keypoints
+ # keypoints shape: B x N x 2, type: torch tensor
+ # scores shape: B x N, type: torch tensor
+ # descriptors shape: B x 128 x N, type: torch tensor
+ keypoints, scores, descriptors = self.net(image)
+
+ return {
+ "keypoints": keypoints,
+ "scores": scores,
+ "descriptors": descriptors,
+ }
diff --git a/imcui/hloc/extractors/fire.py b/imcui/hloc/extractors/fire.py
new file mode 100644
index 0000000000000000000000000000000000000000..980f18e63d1a395835891c8e6595cfc66c21db2d
--- /dev/null
+++ b/imcui/hloc/extractors/fire.py
@@ -0,0 +1,72 @@
+import logging
+import subprocess
+import sys
+from pathlib import Path
+
+import torch
+import torchvision.transforms as tvf
+
+from ..utils.base_model import BaseModel
+
+logger = logging.getLogger(__name__)
+fire_path = Path(__file__).parent / "../../third_party/fire"
+sys.path.append(str(fire_path))
+
+
+import fire_network
+
+
+class FIRe(BaseModel):
+ default_conf = {
+ "global": True,
+ "asmk": False,
+ "model_name": "fire_SfM_120k.pth",
+ "scales": [2.0, 1.414, 1.0, 0.707, 0.5, 0.353, 0.25], # default params
+ "features_num": 1000, # TODO:not supported now
+ "asmk_name": "asmk_codebook.bin", # TODO:not supported now
+ "config_name": "eval_fire.yml",
+ }
+ required_inputs = ["image"]
+
+    # Available model weights and their download URLs.
+ fire_models = {
+ "fire_SfM_120k.pth": "http://download.europe.naverlabs.com/ComputerVision/FIRe/official/fire.pth",
+ "fire_imagenet.pth": "http://download.europe.naverlabs.com/ComputerVision/FIRe/pretraining/fire_imagenet.pth",
+ }
+
+ def _init(self, conf):
+ assert conf["model_name"] in self.fire_models.keys()
+ # Config paths
+ model_path = fire_path / "model" / conf["model_name"]
+
+ # Download the model.
+ if not model_path.exists():
+ model_path.parent.mkdir(exist_ok=True)
+ link = self.fire_models[conf["model_name"]]
+ cmd = ["wget", "--quiet", link, "-O", str(model_path)]
+ logger.info(f"Downloading the FIRe model with `{cmd}`.")
+ subprocess.run(cmd, check=True)
+
+ logger.info("Loading fire model...")
+
+ # Load net
+ state = torch.load(model_path)
+ state["net_params"]["pretrained"] = None
+ net = fire_network.init_network(**state["net_params"])
+ net.load_state_dict(state["state_dict"])
+ self.net = net
+
+ self.norm_rgb = tvf.Normalize(
+ **dict(zip(["mean", "std"], net.runtime["mean_std"]))
+ )
+
+ # params
+ self.scales = conf["scales"]
+
+ def _forward(self, data):
+ image = self.norm_rgb(data["image"])
+
+ # Feature extraction.
+ desc = self.net.forward_global(image, scales=self.scales)
+
+ return {"global_descriptor": desc}
diff --git a/imcui/hloc/extractors/fire_local.py b/imcui/hloc/extractors/fire_local.py
new file mode 100644
index 0000000000000000000000000000000000000000..a8e9ba9f4c3d86280e8232f61263b729ccb933be
--- /dev/null
+++ b/imcui/hloc/extractors/fire_local.py
@@ -0,0 +1,84 @@
+import subprocess
+import sys
+from pathlib import Path
+
+import torch
+import torchvision.transforms as tvf
+
+from .. import logger
+from ..utils.base_model import BaseModel
+
+fire_path = Path(__file__).parent / "../../third_party/fire"
+
+sys.path.append(str(fire_path))
+
+
+import fire_network
+
+EPS = 1e-6
+
+
+class FIRe(BaseModel):
+ default_conf = {
+ "global": True,
+ "asmk": False,
+ "model_name": "fire_SfM_120k.pth",
+ "scales": [2.0, 1.414, 1.0, 0.707, 0.5, 0.353, 0.25], # default params
+ "features_num": 1000,
+ "asmk_name": "asmk_codebook.bin",
+ "config_name": "eval_fire.yml",
+ }
+ required_inputs = ["image"]
+
+    # Available model weights and their download URLs.
+ fire_models = {
+ "fire_SfM_120k.pth": "http://download.europe.naverlabs.com/ComputerVision/FIRe/official/fire.pth",
+ "fire_imagenet.pth": "http://download.europe.naverlabs.com/ComputerVision/FIRe/pretraining/fire_imagenet.pth",
+ }
+
+ def _init(self, conf):
+ assert conf["model_name"] in self.fire_models.keys()
+
+ # Config paths
+ model_path = fire_path / "model" / conf["model_name"]
+ config_path = fire_path / conf["config_name"] # noqa: F841
+ asmk_bin_path = fire_path / "model" / conf["asmk_name"] # noqa: F841
+
+ # Download the model.
+ if not model_path.exists():
+ model_path.parent.mkdir(exist_ok=True)
+ link = self.fire_models[conf["model_name"]]
+ cmd = ["wget", "--quiet", link, "-O", str(model_path)]
+ logger.info(f"Downloading the FIRe model with `{cmd}`.")
+ subprocess.run(cmd, check=True)
+
+ logger.info("Loading fire model...")
+
+ # Load net
+ state = torch.load(model_path)
+ state["net_params"]["pretrained"] = None
+ net = fire_network.init_network(**state["net_params"])
+ net.load_state_dict(state["state_dict"])
+ self.net = net
+
+ self.norm_rgb = tvf.Normalize(
+ **dict(zip(["mean", "std"], net.runtime["mean_std"]))
+ )
+
+ # params
+ self.scales = conf["scales"]
+ self.features_num = conf["features_num"]
+
+ def _forward(self, data):
+ image = self.norm_rgb(data["image"])
+
+ local_desc = self.net.forward_local(
+ image, features_num=self.features_num, scales=self.scales
+ )
+
+ logger.info(f"output[0].shape = {local_desc[0].shape}\n")
+
+ return {
+ # 'global_descriptor': desc
+ "local_descriptor": local_desc
+ }
diff --git a/imcui/hloc/extractors/lanet.py b/imcui/hloc/extractors/lanet.py
new file mode 100644
index 0000000000000000000000000000000000000000..c5f7af8692d9c216bd613fe2cf488e3c148392fa
--- /dev/null
+++ b/imcui/hloc/extractors/lanet.py
@@ -0,0 +1,63 @@
+import sys
+from pathlib import Path
+
+import torch
+
+from .. import MODEL_REPO_ID, logger
+
+from ..utils.base_model import BaseModel
+
+lib_path = Path(__file__).parent / "../../third_party"
+sys.path.append(str(lib_path))
+from lanet.network_v0.model import PointModel
+
+lanet_path = Path(__file__).parent / "../../third_party/lanet"
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+
+class LANet(BaseModel):
+ default_conf = {
+ "model_name": "PointModel_v0.pth",
+ "keypoint_threshold": 0.1,
+ "max_keypoints": 1024,
+ }
+ required_inputs = ["image"]
+
+ def _init(self, conf):
+ logger.info("Loading LANet model...")
+
+ model_path = self._download_model(
+ repo_id=MODEL_REPO_ID,
+ filename="{}/{}".format(Path(__file__).stem, self.conf["model_name"]),
+ )
+ self.net = PointModel(is_test=True)
+ state_dict = torch.load(model_path, map_location="cpu")
+ self.net.load_state_dict(state_dict["model_state"])
+ logger.info("Load LANet model done.")
+
+ def _forward(self, data):
+ image = data["image"]
+ keypoints, scores, descriptors = self.net(image)
+ _, _, Hc, Wc = descriptors.shape
+
+ # Scores & Descriptors
+ kpts_score = torch.cat([keypoints, scores], dim=1).view(3, -1).t()
+ descriptors = descriptors.view(256, Hc, Wc).view(256, -1).t()
+
+ # Filter based on confidence threshold
+ descriptors = descriptors[kpts_score[:, 0] > self.conf["keypoint_threshold"], :]
+ kpts_score = kpts_score[kpts_score[:, 0] > self.conf["keypoint_threshold"], :]
+ keypoints = kpts_score[:, 1:]
+ scores = kpts_score[:, 0]
+
+ idxs = scores.argsort()[-self.conf["max_keypoints"] or None :]
+ keypoints = keypoints[idxs, :2]
+ descriptors = descriptors[idxs]
+ scores = scores[idxs]
+
+ return {
+ "keypoints": keypoints[None],
+ "scores": scores[None],
+ "descriptors": descriptors.T[None],
+ }
diff --git a/imcui/hloc/extractors/netvlad.py b/imcui/hloc/extractors/netvlad.py
new file mode 100644
index 0000000000000000000000000000000000000000..3ba5f9a2feebf7ed0accd23e318b2a83e0f9df12
--- /dev/null
+++ b/imcui/hloc/extractors/netvlad.py
@@ -0,0 +1,146 @@
+import subprocess
+from pathlib import Path
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torchvision.models as models
+from scipy.io import loadmat
+
+from .. import logger
+from ..utils.base_model import BaseModel
+
+EPS = 1e-6
+
+
+class NetVLADLayer(nn.Module):
+ def __init__(self, input_dim=512, K=64, score_bias=False, intranorm=True):
+ super().__init__()
+ self.score_proj = nn.Conv1d(input_dim, K, kernel_size=1, bias=score_bias)
+ centers = nn.parameter.Parameter(torch.empty([input_dim, K]))
+ nn.init.xavier_uniform_(centers)
+ self.register_parameter("centers", centers)
+ self.intranorm = intranorm
+ self.output_dim = input_dim * K
+
+ def forward(self, x):
+ b = x.size(0)
+ scores = self.score_proj(x)
+ scores = F.softmax(scores, dim=1)
+ diff = x.unsqueeze(2) - self.centers.unsqueeze(0).unsqueeze(-1)
+ desc = (scores.unsqueeze(1) * diff).sum(dim=-1)
+ if self.intranorm:
+ # From the official MATLAB implementation.
+ desc = F.normalize(desc, dim=1)
+ desc = desc.view(b, -1)
+ desc = F.normalize(desc, dim=1)
+ return desc
+
+
+class NetVLAD(BaseModel):
+ default_conf = {"model_name": "VGG16-NetVLAD-Pitts30K", "whiten": True}
+ required_inputs = ["image"]
+
+ # Models exported using
+ # https://github.com/uzh-rpg/netvlad_tf_open/blob/master/matlab/net_class2struct.m.
+ dir_models = {
+ "VGG16-NetVLAD-Pitts30K": "https://cvg-data.inf.ethz.ch/hloc/netvlad/Pitts30K_struct.mat",
+ "VGG16-NetVLAD-TokyoTM": "https://cvg-data.inf.ethz.ch/hloc/netvlad/TokyoTM_struct.mat",
+ }
+
+ def _init(self, conf):
+ assert conf["model_name"] in self.dir_models.keys()
+
+ # Download the checkpoint.
+ checkpoint = Path(torch.hub.get_dir(), "netvlad", conf["model_name"] + ".mat")
+ if not checkpoint.exists():
+ checkpoint.parent.mkdir(exist_ok=True, parents=True)
+ link = self.dir_models[conf["model_name"]]
+ cmd = ["wget", "--quiet", link, "-O", str(checkpoint)]
+ logger.info(f"Downloading the NetVLAD model with `{cmd}`.")
+ subprocess.run(cmd, check=True)
+
+ # Create the network.
+ # Remove classification head.
+ backbone = list(models.vgg16().children())[0]
+ # Remove last ReLU + MaxPool2d.
+ self.backbone = nn.Sequential(*list(backbone.children())[:-2])
+
+ self.netvlad = NetVLADLayer()
+
+ if conf["whiten"]:
+ self.whiten = nn.Linear(self.netvlad.output_dim, 4096)
+
+ # Parse MATLAB weights using https://github.com/uzh-rpg/netvlad_tf_open
+ mat = loadmat(checkpoint, struct_as_record=False, squeeze_me=True)
+
+ # CNN weights.
+ for layer, mat_layer in zip(self.backbone.children(), mat["net"].layers):
+ if isinstance(layer, nn.Conv2d):
+ w = mat_layer.weights[0] # Shape: S x S x IN x OUT
+ b = mat_layer.weights[1] # Shape: OUT
+ # Prepare for PyTorch - enforce float32 and right shape.
+ # w should have shape: OUT x IN x S x S
+ # b should have shape: OUT
+ w = torch.tensor(w).float().permute([3, 2, 0, 1])
+ b = torch.tensor(b).float()
+ # Update layer weights.
+ layer.weight = nn.Parameter(w)
+ layer.bias = nn.Parameter(b)
+
+ # NetVLAD weights.
+ score_w = mat["net"].layers[30].weights[0] # D x K
+ # Centers are stored with the opposite sign in the official MATLAB code.
+ center_w = -mat["net"].layers[30].weights[1] # D x K
+ # Prepare for PyTorch - make sure it is float32 and has right shape.
+ # score_w should have shape K x D x 1
+ # center_w should have shape D x K
+ score_w = torch.tensor(score_w).float().permute([1, 0]).unsqueeze(-1)
+ center_w = torch.tensor(center_w).float()
+ # Update layer weights.
+ self.netvlad.score_proj.weight = nn.Parameter(score_w)
+ self.netvlad.centers = nn.Parameter(center_w)
+
+ # Whitening weights.
+ if conf["whiten"]:
+ w = mat["net"].layers[33].weights[0] # Shape: 1 x 1 x IN x OUT
+ b = mat["net"].layers[33].weights[1] # Shape: OUT
+ # Prepare for PyTorch - make sure it is float32 and has right shape
+ w = torch.tensor(w).float().squeeze().permute([1, 0]) # OUT x IN
+ b = torch.tensor(b.squeeze()).float() # Shape: OUT
+ # Update layer weights.
+ self.whiten.weight = nn.Parameter(w)
+ self.whiten.bias = nn.Parameter(b)
+
+ # Preprocessing parameters.
+ self.preprocess = {
+ "mean": mat["net"].meta.normalization.averageImage[0, 0],
+ "std": np.array([1, 1, 1], dtype=np.float32),
+ }
+
+ def _forward(self, data):
+ image = data["image"]
+ assert image.shape[1] == 3
+ assert image.min() >= -EPS and image.max() <= 1 + EPS
+ image = torch.clamp(image * 255, 0.0, 255.0) # Input should be 0-255.
+ mean = self.preprocess["mean"]
+ std = self.preprocess["std"]
+ image = image - image.new_tensor(mean).view(1, -1, 1, 1)
+ image = image / image.new_tensor(std).view(1, -1, 1, 1)
+
+ # Feature extraction.
+ descriptors = self.backbone(image)
+ b, c, _, _ = descriptors.size()
+ descriptors = descriptors.view(b, c, -1)
+
+ # NetVLAD layer.
+ descriptors = F.normalize(descriptors, dim=1) # Pre-normalization.
+ desc = self.netvlad(descriptors)
+
+ # Whiten if needed.
+ if hasattr(self, "whiten"):
+ desc = self.whiten(desc)
+ desc = F.normalize(desc, dim=1) # Final L2 normalization.
+
+ return {"global_descriptor": desc}
diff --git a/imcui/hloc/extractors/openibl.py b/imcui/hloc/extractors/openibl.py
new file mode 100644
index 0000000000000000000000000000000000000000..9e332a4e0016fceb184dd850bd3b6f86231dad54
--- /dev/null
+++ b/imcui/hloc/extractors/openibl.py
@@ -0,0 +1,26 @@
+import torch
+import torchvision.transforms as tvf
+
+from ..utils.base_model import BaseModel
+
+
+class OpenIBL(BaseModel):
+ default_conf = {
+ "model_name": "vgg16_netvlad",
+ }
+ required_inputs = ["image"]
+
+ def _init(self, conf):
+ self.net = torch.hub.load(
+ "yxgeee/OpenIBL", conf["model_name"], pretrained=True
+ ).eval()
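+ # These statistics match the normalization used in the OpenIBL examples:
+ # ImageNet-like channel means with std = 1/255, i.e. (image - mean) is
+ # rescaled by 255 to the range the pretrained weights expect.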
+ mean = [0.48501960784313836, 0.4579568627450961, 0.4076039215686255]
+ std = [0.00392156862745098, 0.00392156862745098, 0.00392156862745098]
+ self.norm_rgb = tvf.Normalize(mean=mean, std=std)
+
+ def _forward(self, data):
+ image = self.norm_rgb(data["image"])
+ desc = self.net(image)
+ return {
+ "global_descriptor": desc,
+ }
diff --git a/imcui/hloc/extractors/r2d2.py b/imcui/hloc/extractors/r2d2.py
new file mode 100644
index 0000000000000000000000000000000000000000..66769b040dd10e1b4f38eca0cf41c2023d096482
--- /dev/null
+++ b/imcui/hloc/extractors/r2d2.py
@@ -0,0 +1,73 @@
+import sys
+from pathlib import Path
+
+import torchvision.transforms as tvf
+
+from .. import MODEL_REPO_ID, logger
+
+from ..utils.base_model import BaseModel
+
+r2d2_path = Path(__file__).parents[2] / "third_party/r2d2"
+sys.path.append(str(r2d2_path))
+
+gim_path = Path(__file__).parents[2] / "third_party/gim"
+if str(gim_path) in sys.path:
+ sys.path.remove(str(gim_path))
+
+from extract import NonMaxSuppression, extract_multiscale, load_network
+
+
+class R2D2(BaseModel):
+ default_conf = {
+ "model_name": "r2d2_WASF_N16.pt",
+ "max_keypoints": 5000,
+ "scale_factor": 2**0.25,
+ "min_size": 256,
+ "max_size": 1024,
+ "min_scale": 0,
+ "max_scale": 1,
+ "reliability_threshold": 0.7,
+ "repetability_threshold": 0.7,
+ }
+ required_inputs = ["image"]
+
+ def _init(self, conf):
+ model_path = self._download_model(
+ repo_id=MODEL_REPO_ID,
+ filename="{}/{}".format(Path(__file__).stem, self.conf["model_name"]),
+ )
+ self.norm_rgb = tvf.Normalize(
+ mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
+ )
+ self.net = load_network(model_path)
+ self.detector = NonMaxSuppression(
+ rel_thr=conf["reliability_threshold"],
+ rep_thr=conf["repetability_threshold"],
+ )
+ logger.info("Load R2D2 model done.")
+
+ def _forward(self, data):
+ img = data["image"]
+ img = self.norm_rgb(img)
+
+ xys, desc, scores = extract_multiscale(
+ self.net,
+ img,
+ self.detector,
+ scale_f=self.conf["scale_factor"],
+ min_size=self.conf["min_size"],
+ max_size=self.conf["max_size"],
+ min_scale=self.conf["min_scale"],
+ max_scale=self.conf["max_scale"],
+ )
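+ # Keep the `max_keypoints` highest-scoring detections; note the operator
+ # precedence: this is (-max_keypoints) or None, so a value of 0 keeps all.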
+ idxs = scores.argsort()[-self.conf["max_keypoints"] or None :]
+ xy = xys[idxs, :2]
+ desc = desc[idxs].t()
+ scores = scores[idxs]
+
+ pred = {
+ "keypoints": xy[None],
+ "descriptors": desc[None],
+ "scores": scores[None],
+ }
+ return pred
diff --git a/imcui/hloc/extractors/rekd.py b/imcui/hloc/extractors/rekd.py
new file mode 100644
index 0000000000000000000000000000000000000000..82fc522920e21e171cb269e680506ad7aeeeaf9a
--- /dev/null
+++ b/imcui/hloc/extractors/rekd.py
@@ -0,0 +1,60 @@
+import sys
+from pathlib import Path
+
+import torch
+
+from .. import MODEL_REPO_ID, logger
+
+from ..utils.base_model import BaseModel
+
+rekd_path = Path(__file__).parent / "../../third_party"
+sys.path.append(str(rekd_path))
+from REKD.training.model.REKD import REKD as REKD_
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+
+class REKD(BaseModel):
+ default_conf = {
+ "model_name": "v0",
+ "keypoint_threshold": 0.1,
+ }
+ required_inputs = ["image"]
+
+ def _init(self, conf):
+ # TODO: download model
+ model_path = self._download_model(
+ repo_id=MODEL_REPO_ID,
+ filename="{}/{}".format(Path(__file__).stem, self.conf["model_name"]),
+ )
+ if not model_path.exists():
+ print(f"No model found at {model_path}")
+ self.net = REKD_(is_test=True)
+ state_dict = torch.load(model_path, map_location="cpu")
+ self.net.load_state_dict(state_dict["model_state"])
+ logger.info("Load REKD model done.")
+
+ def _forward(self, data):
+ image = data["image"]
+ keypoints, scores, descriptors = self.net(image)
+ _, _, Hc, Wc = descriptors.shape
+
+ # Scores & Descriptors
+ kpts_score = (
+ torch.cat([keypoints, scores], dim=1).view(3, -1).t().cpu().detach().numpy()
+ )
+ descriptors = (
+ descriptors.view(256, Hc, Wc).view(256, -1).t().cpu().detach().numpy()
+ )
+
+ # Filter based on confidence threshold
+ descriptors = descriptors[kpts_score[:, 0] > self.conf["keypoint_threshold"], :]
+ kpts_score = kpts_score[kpts_score[:, 0] > self.conf["keypoint_threshold"], :]
+ keypoints = kpts_score[:, 1:]
+ scores = kpts_score[:, 0]
+
+ return {
+ "keypoints": torch.from_numpy(keypoints)[None],
+ "scores": torch.from_numpy(scores)[None],
+ "descriptors": torch.from_numpy(descriptors.T)[None],
+ }
diff --git a/imcui/hloc/extractors/rord.py b/imcui/hloc/extractors/rord.py
new file mode 100644
index 0000000000000000000000000000000000000000..ba71113e4f9a57609879c95bb453af4104dbb72d
--- /dev/null
+++ b/imcui/hloc/extractors/rord.py
@@ -0,0 +1,59 @@
+import sys
+from pathlib import Path
+
+import torch
+
+from .. import MODEL_REPO_ID, logger
+
+from ..utils.base_model import BaseModel
+
+rord_path = Path(__file__).parent / "../../third_party"
+sys.path.append(str(rord_path))
+from RoRD.lib.model_test import D2Net as _RoRD
+from RoRD.lib.pyramid import process_multiscale
+
+
+class RoRD(BaseModel):
+ default_conf = {
+ "model_name": "rord.pth",
+ "checkpoint_dir": rord_path / "RoRD" / "models",
+ "use_relu": True,
+ "multiscale": False,
+ "max_keypoints": 1024,
+ }
+ required_inputs = ["image"]
+
+ def _init(self, conf):
+ model_path = self._download_model(
+ repo_id=MODEL_REPO_ID,
+ filename="{}/{}".format(Path(__file__).stem, self.conf["model_name"]),
+ )
+ self.net = _RoRD(
+ model_file=model_path, use_relu=conf["use_relu"], use_cuda=False
+ )
+ logger.info("Load RoRD model done.")
+
+ def _forward(self, data):
+ image = data["image"]
+ image = image.flip(1) # RGB -> BGR
+ norm = image.new_tensor([103.939, 116.779, 123.68])
+ image = image * 255 - norm.view(1, 3, 1, 1) # caffe normalization
+
+ if self.conf["multiscale"]:
+ keypoints, scores, descriptors = process_multiscale(image, self.net)
+ else:
+ keypoints, scores, descriptors = process_multiscale(
+ image, self.net, scales=[1]
+ )
+ keypoints = keypoints[:, [1, 0]] # (x, y) and remove the scale
+
+ idxs = scores.argsort()[-self.conf["max_keypoints"] or None :]
+ keypoints = keypoints[idxs, :2]
+ descriptors = descriptors[idxs]
+ scores = scores[idxs]
+
+ return {
+ "keypoints": torch.from_numpy(keypoints)[None],
+ "scores": torch.from_numpy(scores)[None],
+ "descriptors": torch.from_numpy(descriptors.T)[None],
+ }
diff --git a/imcui/hloc/extractors/sfd2.py b/imcui/hloc/extractors/sfd2.py
new file mode 100644
index 0000000000000000000000000000000000000000..2724ed3046644da4903c54a7f0d9d8d8f585aafe
--- /dev/null
+++ b/imcui/hloc/extractors/sfd2.py
@@ -0,0 +1,44 @@
+import sys
+from pathlib import Path
+
+import torchvision.transforms as tvf
+
+from .. import MODEL_REPO_ID, logger
+from ..utils.base_model import BaseModel
+
+tp_path = Path(__file__).parent / "../../third_party"
+sys.path.append(str(tp_path))
+from pram.nets.sfd2 import load_sfd2
+
+
+class SFD2(BaseModel):
+ default_conf = {
+ "max_keypoints": 4096,
+ "model_name": "sfd2_20230511_210205_resnet4x.79.pth",
+ "conf_th": 0.001,
+ }
+ required_inputs = ["image"]
+
+ def _init(self, conf):
+ self.conf = {**self.default_conf, **conf}
+ self.norm_rgb = tvf.Normalize(
+ mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
+ )
+ model_path = self._download_model(
+ repo_id=MODEL_REPO_ID,
+ filename="{}/{}".format("pram", self.conf["model_name"]),
+ )
+ self.net = load_sfd2(weight_path=model_path).eval()
+
+ logger.info("Load SFD2 model done.")
+
+ def _forward(self, data):
+ pred = self.net.extract_local_global(
+ data={"image": self.norm_rgb(data["image"])}, config=self.conf
+ )
+ out = {
+ "keypoints": pred["keypoints"][0][None],
+ "scores": pred["scores"][0][None],
+ "descriptors": pred["descriptors"][0][None],
+ }
+ return out
diff --git a/imcui/hloc/extractors/sift.py b/imcui/hloc/extractors/sift.py
new file mode 100644
index 0000000000000000000000000000000000000000..05df8a76f18b7eae32ef52cbfc91fb13d37c2a9f
--- /dev/null
+++ b/imcui/hloc/extractors/sift.py
@@ -0,0 +1,216 @@
+import warnings
+
+import cv2
+import numpy as np
+import torch
+from kornia.color import rgb_to_grayscale
+from omegaconf import OmegaConf
+from packaging import version
+
+try:
+ import pycolmap
+except ImportError:
+ pycolmap = None
+from .. import logger
+
+from ..utils.base_model import BaseModel
+
+
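+# Deduplicate DoG keypoints detected at the same pixel: keep the detection with
+# the highest score (or scale if no scores are given), break remaining ties by
+# the smallest absolute orientation, then optionally apply a grid NMS of the
+# given radius.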
+def filter_dog_point(points, scales, angles, image_shape, nms_radius, scores=None):
+ h, w = image_shape
+ ij = np.round(points - 0.5).astype(int).T[::-1]
+
+ # Remove duplicate points (identical coordinates).
+ # Pick highest scale or score
+ s = scales if scores is None else scores
+ buffer = np.zeros((h, w))
+ np.maximum.at(buffer, tuple(ij), s)
+ keep = np.where(buffer[tuple(ij)] == s)[0]
+
+ # Pick lowest angle (arbitrary).
+ ij = ij[:, keep]
+ buffer[:] = np.inf
+ o_abs = np.abs(angles[keep])
+ np.minimum.at(buffer, tuple(ij), o_abs)
+ mask = buffer[tuple(ij)] == o_abs
+ ij = ij[:, mask]
+ keep = keep[mask]
+
+ if nms_radius > 0:
+ # Apply NMS on the remaining points
+ buffer[:] = 0
+ buffer[tuple(ij)] = s[keep] # scores or scale
+
+ local_max = torch.nn.functional.max_pool2d(
+ torch.from_numpy(buffer).unsqueeze(0),
+ kernel_size=nms_radius * 2 + 1,
+ stride=1,
+ padding=nms_radius,
+ ).squeeze(0)
+ is_local_max = buffer == local_max.numpy()
+ keep = keep[is_local_max[tuple(ij)]]
+ return keep
+
+
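+# RootSIFT (Arandjelovic and Zisserman, CVPR 2012): L1-normalize, take the
+# element-wise square root, then L2-normalize, so that Euclidean distances
+# between descriptors approximate the Hellinger kernel on the original SIFT.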
+def sift_to_rootsift(x: torch.Tensor, eps=1e-6) -> torch.Tensor:
+ x = torch.nn.functional.normalize(x, p=1, dim=-1, eps=eps)
+ x.clip_(min=eps).sqrt_()
+ return torch.nn.functional.normalize(x, p=2, dim=-1, eps=eps)
+
+
+def run_opencv_sift(features: cv2.Feature2D, image: np.ndarray) -> tuple:
+ """
+ Detect keypoints and compute descriptors with an OpenCV feature detector.
+ Args:
+ features: OpenCV-based keypoint detector and descriptor
+ image: grayscale image of uint8 data type
+ Returns:
+ points: (N, 2) array of keypoint locations (x, y)
+ scores: (N,) array of detector responses
+ scales: (N,) array of keypoint sizes
+ angles: (N,) array of keypoint orientations in radians
+ descriptors: (N, D) array of descriptors
+ """
+ detections, descriptors = features.detectAndCompute(image, None)
+ points = np.array([k.pt for k in detections], dtype=np.float32)
+ scores = np.array([k.response for k in detections], dtype=np.float32)
+ scales = np.array([k.size for k in detections], dtype=np.float32)
+ angles = np.deg2rad(np.array([k.angle for k in detections], dtype=np.float32))
+ return points, scores, scales, angles, descriptors
+
+
+class SIFT(BaseModel):
+ default_conf = {
+ "rootsift": True,
+ "nms_radius": 0, # None to disable filtering entirely.
+ "max_keypoints": 4096,
+ "backend": "opencv", # in {opencv, pycolmap, pycolmap_cpu, pycolmap_cuda}
+ "detection_threshold": 0.0066667, # from COLMAP
+ "edge_threshold": 10,
+ "first_octave": -1, # only used by pycolmap, the default of COLMAP
+ "num_octaves": 4,
+ }
+
+ required_data_keys = ["image"]
+
+ def _init(self, conf):
+ self.conf = OmegaConf.create(self.conf)
+ backend = self.conf.backend
+ if backend.startswith("pycolmap"):
+ if pycolmap is None:
+ raise ImportError(
+ "Cannot find module pycolmap: install it with pip"
+ "or use backend=opencv."
+ )
+ options = {
+ "peak_threshold": self.conf.detection_threshold,
+ "edge_threshold": self.conf.edge_threshold,
+ "first_octave": self.conf.first_octave,
+ "num_octaves": self.conf.num_octaves,
+ "normalization": pycolmap.Normalization.L2, # L1_ROOT is buggy.
+ }
+ device = (
+ "auto" if backend == "pycolmap" else backend.replace("pycolmap_", "")
+ )
+ if (
+ backend == "pycolmap_cpu" or not pycolmap.has_cuda
+ ) and version.parse(pycolmap.__version__) < version.parse("0.5.0"):
+ warnings.warn(
+ "The pycolmap CPU SIFT is buggy in version < 0.5.0, "
+ "consider upgrading pycolmap or use the CUDA version.",
+ stacklevel=1,
+ )
+ else:
+ options["max_num_features"] = self.conf.max_keypoints
+ self.sift = pycolmap.Sift(options=options, device=device)
+ elif backend == "opencv":
+ self.sift = cv2.SIFT_create(
+ contrastThreshold=self.conf.detection_threshold,
+ nfeatures=self.conf.max_keypoints,
+ edgeThreshold=self.conf.edge_threshold,
+ nOctaveLayers=self.conf.num_octaves,
+ )
+ else:
+ backends = {"opencv", "pycolmap", "pycolmap_cpu", "pycolmap_cuda"}
+ raise ValueError(
+ f"Unknown backend: {backend} not in " f"{{{','.join(backends)}}}."
+ )
+ logger.info("Load SIFT model done.")
+
+ def extract_single_image(self, image: torch.Tensor):
+ image_np = image.cpu().numpy().squeeze(0)
+
+ if self.conf.backend.startswith("pycolmap"):
+ if version.parse(pycolmap.__version__) >= version.parse("0.5.0"):
+ detections, descriptors = self.sift.extract(image_np)
+ scores = None # Scores are not exposed by COLMAP anymore.
+ else:
+ detections, scores, descriptors = self.sift.extract(image_np)
+ keypoints = detections[:, :2] # Keep only (x, y).
+ scales, angles = detections[:, -2:].T
+ if scores is not None and (
+ self.conf.backend == "pycolmap_cpu" or not pycolmap.has_cuda
+ ):
+ # Set the scores as a combination of abs. response and scale.
+ scores = np.abs(scores) * scales
+ elif self.conf.backend == "opencv":
+ # TODO: Check if opencv keypoints are already in corner convention
+ keypoints, scores, scales, angles, descriptors = run_opencv_sift(
+ self.sift, (image_np * 255.0).astype(np.uint8)
+ )
+ pred = {
+ "keypoints": keypoints,
+ "scales": scales,
+ "oris": angles,
+ "descriptors": descriptors,
+ }
+ if scores is not None:
+ pred["scores"] = scores
+
+ # sometimes pycolmap returns points outside the image. We remove them
+ if self.conf.backend.startswith("pycolmap"):
+ is_inside = (
+ pred["keypoints"] + 0.5 < np.array([image_np.shape[-2:][::-1]])
+ ).all(-1)
+ pred = {k: v[is_inside] for k, v in pred.items()}
+
+ if self.conf.nms_radius is not None:
+ keep = filter_dog_point(
+ pred["keypoints"],
+ pred["scales"],
+ pred["oris"],
+ image_np.shape,
+ self.conf.nms_radius,
+ scores=pred.get("scores"),
+ )
+ pred = {k: v[keep] for k, v in pred.items()}
+
+ pred = {k: torch.from_numpy(v) for k, v in pred.items()}
+ if scores is not None:
+ # Keep the k keypoints with highest score
+ num_points = self.conf.max_keypoints
+ if num_points is not None and len(pred["keypoints"]) > num_points:
+ indices = torch.topk(pred["scores"], num_points).indices
+ pred = {k: v[indices] for k, v in pred.items()}
+ return pred
+
+ def _forward(self, data: dict) -> dict:
+ image = data["image"]
+ if image.shape[1] == 3:
+ image = rgb_to_grayscale(image)
+ device = image.device
+ image = image.cpu()
+ pred = []
+ for k in range(len(image)):
+ img = image[k]
+ if "image_size" in data.keys():
+ # avoid extracting points in padded areas
+ w, h = data["image_size"][k]
+ img = img[:, :h, :w]
+ p = self.extract_single_image(img)
+ pred.append(p)
+ pred = {k: torch.stack([p[k] for p in pred], 0).to(device) for k in pred[0]}
+ if self.conf.rootsift:
+ pred["descriptors"] = sift_to_rootsift(pred["descriptors"])
+ pred["descriptors"] = pred["descriptors"].permute(0, 2, 1)
+ pred["keypoint_scores"] = pred["scores"].clone()
+ return pred
diff --git a/imcui/hloc/extractors/superpoint.py b/imcui/hloc/extractors/superpoint.py
new file mode 100644
index 0000000000000000000000000000000000000000..5f4c03e314743be2b862f3b8d8df078d2f85bc39
--- /dev/null
+++ b/imcui/hloc/extractors/superpoint.py
@@ -0,0 +1,51 @@
+import sys
+from pathlib import Path
+
+import torch
+
+from .. import logger
+
+from ..utils.base_model import BaseModel
+
+sys.path.append(str(Path(__file__).parent / "../../third_party"))
+from SuperGluePretrainedNetwork.models import superpoint # noqa E402
+
+
+# The original keypoint sampling is incorrect. We patch it here but
+# do not fix it upstream so as not to impact existing evaluations.
+def sample_descriptors_fix_sampling(keypoints, descriptors, s: int = 8):
+ """Interpolate descriptors at keypoint locations"""
+ b, c, h, w = descriptors.shape
+ keypoints = (keypoints + 0.5) / (keypoints.new_tensor([w, h]) * s)
+ keypoints = keypoints * 2 - 1 # normalize to (-1, 1)
+ descriptors = torch.nn.functional.grid_sample(
+ descriptors,
+ keypoints.view(b, 1, -1, 2),
+ mode="bilinear",
+ align_corners=False,
+ )
+ descriptors = torch.nn.functional.normalize(
+ descriptors.reshape(b, c, -1), p=2, dim=1
+ )
+ return descriptors
+
+
+class SuperPoint(BaseModel):
+ default_conf = {
+ "nms_radius": 4,
+ "keypoint_threshold": 0.005,
+ "max_keypoints": -1,
+ "remove_borders": 4,
+ "fix_sampling": False,
+ }
+ required_inputs = ["image"]
+ detection_noise = 2.0
+
+ def _init(self, conf):
+ if conf["fix_sampling"]:
+ superpoint.sample_descriptors = sample_descriptors_fix_sampling
+ self.net = superpoint.SuperPoint(conf)
+ logger.info("Load SuperPoint model done.")
+
+ def _forward(self, data):
+ return self.net(data, self.conf)
diff --git a/imcui/hloc/extractors/xfeat.py b/imcui/hloc/extractors/xfeat.py
new file mode 100644
index 0000000000000000000000000000000000000000..f29e115dca54db10bc2b58369eb1ff28dc6e3b2c
--- /dev/null
+++ b/imcui/hloc/extractors/xfeat.py
@@ -0,0 +1,33 @@
+import torch
+
+from .. import logger
+
+from ..utils.base_model import BaseModel
+
+
+class XFeat(BaseModel):
+ default_conf = {
+ "keypoint_threshold": 0.005,
+ "max_keypoints": -1,
+ }
+ required_inputs = ["image"]
+
+ def _init(self, conf):
+ self.net = torch.hub.load(
+ "verlab/accelerated_features",
+ "XFeat",
+ pretrained=True,
+ top_k=self.conf["max_keypoints"],
+ )
+ logger.info("Load XFeat(sparse) model done.")
+
+ def _forward(self, data):
+ pred = self.net.detectAndCompute(
+ data["image"], top_k=self.conf["max_keypoints"]
+ )[0]
+ pred = {
+ "keypoints": pred["keypoints"][None],
+ "scores": pred["scores"][None],
+ "descriptors": pred["descriptors"].T[None],
+ }
+ return pred
diff --git a/imcui/hloc/localize_inloc.py b/imcui/hloc/localize_inloc.py
new file mode 100644
index 0000000000000000000000000000000000000000..acda7520012c53f468b1603d6a26a34855ebbffb
--- /dev/null
+++ b/imcui/hloc/localize_inloc.py
@@ -0,0 +1,179 @@
+import argparse
+import pickle
+from pathlib import Path
+
+import cv2
+import h5py
+import numpy as np
+import pycolmap
+import torch
+from scipy.io import loadmat
+from tqdm import tqdm
+
+from . import logger
+from .utils.parsers import names_to_pair, parse_retrieval
+
+
+def interpolate_scan(scan, kp):
+ h, w, c = scan.shape
+ kp = kp / np.array([[w - 1, h - 1]]) * 2 - 1
+ assert np.all(kp > -1) and np.all(kp < 1)
+ scan = torch.from_numpy(scan).permute(2, 0, 1)[None]
+ kp = torch.from_numpy(kp)[None, None]
+ grid_sample = torch.nn.functional.grid_sample
+
+ # To maximize the number of points that have depth:
+ # do bilinear interpolation first and then nearest for the remaining points
+ interp_lin = grid_sample(scan, kp, align_corners=True, mode="bilinear")[0, :, 0]
+ interp_nn = torch.nn.functional.grid_sample(
+ scan, kp, align_corners=True, mode="nearest"
+ )[0, :, 0]
+ interp = torch.where(torch.isnan(interp_lin), interp_nn, interp_lin)
+ valid = ~torch.any(torch.isnan(interp), 0)
+
+ kp3d = interp.T.numpy()
+ valid = valid.numpy()
+ return kp3d, valid
+
+
+def get_scan_pose(dataset_dir, rpath):
+ split_image_rpath = rpath.split("/")
+ floor_name = split_image_rpath[-3]
+ scan_id = split_image_rpath[-2]
+ image_name = split_image_rpath[-1]
+ building_name = image_name[:3]
+
+ path = Path(
+ dataset_dir,
+ "database/alignments",
+ floor_name,
+ f"transformations/{building_name}_trans_{scan_id}.txt",
+ )
+ with open(path) as f:
+ raw_lines = f.readlines()
+
+ P_after_GICP = np.array(
+ [
+ np.fromstring(raw_lines[7], sep=" "),
+ np.fromstring(raw_lines[8], sep=" "),
+ np.fromstring(raw_lines[9], sep=" "),
+ np.fromstring(raw_lines[10], sep=" "),
+ ]
+ )
+
+ return P_after_GICP
+
+
+def pose_from_cluster(dataset_dir, q, retrieved, feature_file, match_file, skip=None):
+ height, width = cv2.imread(str(dataset_dir / q)).shape[:2]
+ cx = 0.5 * width
+ cy = 0.5 * height
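+ # Rough focal-length prior for the InLoc query photos: image width (4032 px)
+ # times the 28 mm (35 mm-equivalent) focal length, divided by the 36 mm
+ # full-frame sensor width.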
+ focal_length = 4032.0 * 28.0 / 36.0
+
+ all_mkpq = []
+ all_mkpr = []
+ all_mkp3d = []
+ all_indices = []
+ kpq = feature_file[q]["keypoints"].__array__()
+ num_matches = 0
+
+ for i, r in enumerate(retrieved):
+ kpr = feature_file[r]["keypoints"].__array__()
+ pair = names_to_pair(q, r)
+ m = match_file[pair]["matches0"].__array__()
+ v = m > -1
+
+ if skip and (np.count_nonzero(v) < skip):
+ continue
+
+ mkpq, mkpr = kpq[v], kpr[m[v]]
+ num_matches += len(mkpq)
+
+ scan_r = loadmat(Path(dataset_dir, r + ".mat"))["XYZcut"]
+ mkp3d, valid = interpolate_scan(scan_r, mkpr)
+ Tr = get_scan_pose(dataset_dir, r)
+ mkp3d = (Tr[:3, :3] @ mkp3d.T + Tr[:3, -1:]).T
+
+ all_mkpq.append(mkpq[valid])
+ all_mkpr.append(mkpr[valid])
+ all_mkp3d.append(mkp3d[valid])
+ all_indices.append(np.full(np.count_nonzero(valid), i))
+
+ all_mkpq = np.concatenate(all_mkpq, 0)
+ all_mkpr = np.concatenate(all_mkpr, 0)
+ all_mkp3d = np.concatenate(all_mkp3d, 0)
+ all_indices = np.concatenate(all_indices, 0)
+
+ cfg = {
+ "model": "SIMPLE_PINHOLE",
+ "width": width,
+ "height": height,
+ "params": [focal_length, cx, cy],
+ }
+ ret = pycolmap.absolute_pose_estimation(all_mkpq, all_mkp3d, cfg, 48.00)
+ ret["cfg"] = cfg
+ return ret, all_mkpq, all_mkpr, all_mkp3d, all_indices, num_matches
+
+
+def main(dataset_dir, retrieval, features, matches, results, skip_matches=None):
+ assert retrieval.exists(), retrieval
+ assert features.exists(), features
+ assert matches.exists(), matches
+
+ retrieval_dict = parse_retrieval(retrieval)
+ queries = list(retrieval_dict.keys())
+
+ feature_file = h5py.File(features, "r", libver="latest")
+ match_file = h5py.File(matches, "r", libver="latest")
+
+ poses = {}
+ logs = {
+ "features": features,
+ "matches": matches,
+ "retrieval": retrieval,
+ "loc": {},
+ }
+ logger.info("Starting localization...")
+ for q in tqdm(queries):
+ db = retrieval_dict[q]
+ ret, mkpq, mkpr, mkp3d, indices, num_matches = pose_from_cluster(
+ dataset_dir, q, db, feature_file, match_file, skip_matches
+ )
+
+ poses[q] = (ret["qvec"], ret["tvec"])
+ logs["loc"][q] = {
+ "db": db,
+ "PnP_ret": ret,
+ "keypoints_query": mkpq,
+ "keypoints_db": mkpr,
+ "3d_points": mkp3d,
+ "indices_db": indices,
+ "num_matches": num_matches,
+ }
+
+ logger.info(f"Writing poses to {results}...")
+ with open(results, "w") as f:
+ for q in queries:
+ qvec, tvec = poses[q]
+ qvec = " ".join(map(str, qvec))
+ tvec = " ".join(map(str, tvec))
+ name = q.split("/")[-1]
+ f.write(f"{name} {qvec} {tvec}\n")
+
+ logs_path = f"{results}_logs.pkl"
+ logger.info(f"Writing logs to {logs_path}...")
+ with open(logs_path, "wb") as f:
+ pickle.dump(logs, f)
+ logger.info("Done!")
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--dataset_dir", type=Path, required=True)
+ parser.add_argument("--retrieval", type=Path, required=True)
+ parser.add_argument("--features", type=Path, required=True)
+ parser.add_argument("--matches", type=Path, required=True)
+ parser.add_argument("--results", type=Path, required=True)
+ parser.add_argument("--skip_matches", type=int)
+ args = parser.parse_args()
+ main(**args.__dict__)
diff --git a/imcui/hloc/localize_sfm.py b/imcui/hloc/localize_sfm.py
new file mode 100644
index 0000000000000000000000000000000000000000..8122e2aca1aa057c424f0e39204c193f01cd57e7
--- /dev/null
+++ b/imcui/hloc/localize_sfm.py
@@ -0,0 +1,243 @@
+import argparse
+import pickle
+from collections import defaultdict
+from pathlib import Path
+from typing import Dict, List, Union
+
+import numpy as np
+import pycolmap
+from tqdm import tqdm
+
+from . import logger
+from .utils.io import get_keypoints, get_matches
+from .utils.parsers import parse_image_lists, parse_retrieval
+
+
+def do_covisibility_clustering(
+ frame_ids: List[int], reconstruction: pycolmap.Reconstruction
+):
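+ # Cluster the retrieved database images into connected components of the
+ # covisibility graph: two images are connected if they observe a common
+ # 3D point. Clusters are returned largest first.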
+ clusters = []
+ visited = set()
+ for frame_id in frame_ids:
+ # Check if already labeled
+ if frame_id in visited:
+ continue
+
+ # New component
+ clusters.append([])
+ queue = {frame_id}
+ while len(queue):
+ exploration_frame = queue.pop()
+
+ # Already part of the component
+ if exploration_frame in visited:
+ continue
+ visited.add(exploration_frame)
+ clusters[-1].append(exploration_frame)
+
+ observed = reconstruction.images[exploration_frame].points2D
+ connected_frames = {
+ obs.image_id
+ for p2D in observed
+ if p2D.has_point3D()
+ for obs in reconstruction.points3D[p2D.point3D_id].track.elements
+ }
+ connected_frames &= set(frame_ids)
+ connected_frames -= visited
+ queue |= connected_frames
+
+ clusters = sorted(clusters, key=len, reverse=True)
+ return clusters
+
+
+class QueryLocalizer:
+ def __init__(self, reconstruction, config=None):
+ self.reconstruction = reconstruction
+ self.config = config or {}
+
+ def localize(self, points2D_all, points2D_idxs, points3D_id, query_camera):
+ points2D = points2D_all[points2D_idxs]
+ points3D = [self.reconstruction.points3D[j].xyz for j in points3D_id]
+ ret = pycolmap.absolute_pose_estimation(
+ points2D,
+ points3D,
+ query_camera,
+ estimation_options=self.config.get("estimation", {}),
+ refinement_options=self.config.get("refinement", {}),
+ )
+ return ret
+
+
+def pose_from_cluster(
+ localizer: QueryLocalizer,
+ qname: str,
+ query_camera: pycolmap.Camera,
+ db_ids: List[int],
+ features_path: Path,
+ matches_path: Path,
+ **kwargs,
+):
+ kpq = get_keypoints(features_path, qname)
+ kpq += 0.5 # COLMAP coordinates
+
+ kp_idx_to_3D = defaultdict(list)
+ kp_idx_to_3D_to_db = defaultdict(lambda: defaultdict(list))
+ num_matches = 0
+ for i, db_id in enumerate(db_ids):
+ image = localizer.reconstruction.images[db_id]
+ if image.num_points3D == 0:
+ logger.debug(f"No 3D points found for {image.name}.")
+ continue
+ points3D_ids = np.array(
+ [p.point3D_id if p.has_point3D() else -1 for p in image.points2D]
+ )
+
+ matches, _ = get_matches(matches_path, qname, image.name)
+ matches = matches[points3D_ids[matches[:, 1]] != -1]
+ num_matches += len(matches)
+ for idx, m in matches:
+ id_3D = points3D_ids[m]
+ kp_idx_to_3D_to_db[idx][id_3D].append(i)
+ # avoid duplicate observations
+ if id_3D not in kp_idx_to_3D[idx]:
+ kp_idx_to_3D[idx].append(id_3D)
+
+ idxs = list(kp_idx_to_3D.keys())
+ mkp_idxs = [i for i in idxs for _ in kp_idx_to_3D[i]]
+ mp3d_ids = [j for i in idxs for j in kp_idx_to_3D[i]]
+ ret = localizer.localize(kpq, mkp_idxs, mp3d_ids, query_camera, **kwargs)
+ if ret is not None:
+ ret["camera"] = query_camera
+
+ # mostly for logging and post-processing
+ mkp_to_3D_to_db = [
+ (j, kp_idx_to_3D_to_db[i][j]) for i in idxs for j in kp_idx_to_3D[i]
+ ]
+ log = {
+ "db": db_ids,
+ "PnP_ret": ret,
+ "keypoints_query": kpq[mkp_idxs],
+ "points3D_ids": mp3d_ids,
+ "points3D_xyz": None, # we don't log xyz anymore because of file size
+ "num_matches": num_matches,
+ "keypoint_index_to_db": (mkp_idxs, mkp_to_3D_to_db),
+ }
+ return ret, log
+
+
+def main(
+ reference_sfm: Union[Path, pycolmap.Reconstruction],
+ queries: Path,
+ retrieval: Path,
+ features: Path,
+ matches: Path,
+ results: Path,
+ ransac_thresh: float = 12.0,
+ covisibility_clustering: bool = False,
+ prepend_camera_name: bool = False,
+ config: Dict = None,
+):
+ assert retrieval.exists(), retrieval
+ assert features.exists(), features
+ assert matches.exists(), matches
+
+ queries = parse_image_lists(queries, with_intrinsics=True)
+ retrieval_dict = parse_retrieval(retrieval)
+
+ logger.info("Reading the 3D model...")
+ if not isinstance(reference_sfm, pycolmap.Reconstruction):
+ reference_sfm = pycolmap.Reconstruction(reference_sfm)
+ db_name_to_id = {img.name: i for i, img in reference_sfm.images.items()}
+
+ config = {
+ "estimation": {"ransac": {"max_error": ransac_thresh}},
+ **(config or {}),
+ }
+ localizer = QueryLocalizer(reference_sfm, config)
+
+ cam_from_world = {}
+ logs = {
+ "features": features,
+ "matches": matches,
+ "retrieval": retrieval,
+ "loc": {},
+ }
+ logger.info("Starting localization...")
+ for qname, qcam in tqdm(queries):
+ if qname not in retrieval_dict:
+ logger.warning(f"No images retrieved for query image {qname}. Skipping...")
+ continue
+ db_names = retrieval_dict[qname]
+ db_ids = []
+ for n in db_names:
+ if n not in db_name_to_id:
+ logger.warning(f"Image {n} was retrieved but not in database")
+ continue
+ db_ids.append(db_name_to_id[n])
+
+ if covisibility_clustering:
+ clusters = do_covisibility_clustering(db_ids, reference_sfm)
+ best_inliers = 0
+ best_cluster = None
+ logs_clusters = []
+ for i, cluster_ids in enumerate(clusters):
+ ret, log = pose_from_cluster(
+ localizer, qname, qcam, cluster_ids, features, matches
+ )
+ if ret is not None and ret["num_inliers"] > best_inliers:
+ best_cluster = i
+ best_inliers = ret["num_inliers"]
+ logs_clusters.append(log)
+ if best_cluster is not None:
+ ret = logs_clusters[best_cluster]["PnP_ret"]
+ cam_from_world[qname] = ret["cam_from_world"]
+ logs["loc"][qname] = {
+ "db": db_ids,
+ "best_cluster": best_cluster,
+ "log_clusters": logs_clusters,
+ "covisibility_clustering": covisibility_clustering,
+ }
+ else:
+ ret, log = pose_from_cluster(
+ localizer, qname, qcam, db_ids, features, matches
+ )
+ if ret is not None:
+ cam_from_world[qname] = ret["cam_from_world"]
+ else:
+ closest = reference_sfm.images[db_ids[0]]
+ cam_from_world[qname] = closest.cam_from_world
+ log["covisibility_clustering"] = covisibility_clustering
+ logs["loc"][qname] = log
+
+ logger.info(f"Localized {len(cam_from_world)} / {len(queries)} images.")
+ logger.info(f"Writing poses to {results}...")
+ with open(results, "w") as f:
+ for query, t in cam_from_world.items():
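+ # pycolmap stores quaternions as (x, y, z, w); reorder to (w, x, y, z) as
+ # expected by the pose file format.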
+ qvec = " ".join(map(str, t.rotation.quat[[3, 0, 1, 2]]))
+ tvec = " ".join(map(str, t.translation))
+ name = query.split("/")[-1]
+ if prepend_camera_name:
+ name = query.split("/")[-2] + "/" + name
+ f.write(f"{name} {qvec} {tvec}\n")
+
+ logs_path = f"{results}_logs.pkl"
+ logger.info(f"Writing logs to {logs_path}...")
+ # TODO: Resolve pickling issue with pycolmap objects.
+ with open(logs_path, "wb") as f:
+ pickle.dump(logs, f)
+ logger.info("Done!")
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--reference_sfm", type=Path, required=True)
+ parser.add_argument("--queries", type=Path, required=True)
+ parser.add_argument("--features", type=Path, required=True)
+ parser.add_argument("--matches", type=Path, required=True)
+ parser.add_argument("--retrieval", type=Path, required=True)
+ parser.add_argument("--results", type=Path, required=True)
+ parser.add_argument("--ransac_thresh", type=float, default=12.0)
+ parser.add_argument("--covisibility_clustering", action="store_true")
+ parser.add_argument("--prepend_camera_name", action="store_true")
+ args = parser.parse_args()
+ main(**args.__dict__)
diff --git a/imcui/hloc/match_dense.py b/imcui/hloc/match_dense.py
new file mode 100644
index 0000000000000000000000000000000000000000..978db58b2438b337549a810196fdf017e0195e85
--- /dev/null
+++ b/imcui/hloc/match_dense.py
@@ -0,0 +1,1158 @@
+import argparse
+import pprint
+from collections import Counter, defaultdict
+from itertools import chain
+from pathlib import Path
+from types import SimpleNamespace
+from typing import Dict, Iterable, List, Optional, Set, Tuple, Union
+
+import cv2
+import h5py
+import numpy as np
+import torch
+import torchvision.transforms.functional as F
+from scipy.spatial import KDTree
+from tqdm import tqdm
+
+from . import logger, matchers
+from .extract_features import read_image, resize_image
+from .match_features import find_unique_new_pairs
+from .utils.base_model import dynamic_load
+from .utils.io import list_h5_names
+from .utils.parsers import names_to_pair, parse_retrieval
+
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
+confs = {
+ # Best quality but loads of points. Only use for small scenes
+ "loftr": {
+ "output": "matches-loftr",
+ "model": {
+ "name": "loftr",
+ "weights": "outdoor",
+ "max_keypoints": 2000,
+ "match_threshold": 0.2,
+ },
+ "preprocessing": {
+ "grayscale": True,
+ "resize_max": 1024,
+ "dfactor": 8,
+ "width": 640,
+ "height": 480,
+ "force_resize": True,
+ },
+ "max_error": 1, # max error for assigned keypoints (in px)
+ "cell_size": 1, # size of quantization patch (max 1 kp/patch)
+ },
+ "minima_loftr": {
+ "output": "matches-minima_loftr",
+ "model": {
+ "name": "loftr",
+ "weights": "outdoor",
+ "model_name": "minima_loftr.ckpt",
+ "max_keypoints": 2000,
+ "match_threshold": 0.2,
+ },
+ "preprocessing": {
+ "grayscale": True,
+ "resize_max": 1024,
+ "dfactor": 8,
+ "width": 640,
+ "height": 480,
+ "force_resize": False,
+ },
+ "max_error": 1, # max error for assigned keypoints (in px)
+ "cell_size": 1, # size of quantization patch (max 1 kp/patch)
+ },
+ "eloftr": {
+ "output": "matches-eloftr",
+ "model": {
+ "name": "eloftr",
+ "model_name": "eloftr_outdoor.ckpt",
+ "max_keypoints": 2000,
+ "match_threshold": 0.2,
+ },
+ "preprocessing": {
+ "grayscale": True,
+ "resize_max": 1024,
+ "dfactor": 32,
+ "width": 640,
+ "height": 480,
+ "force_resize": True,
+ },
+ "max_error": 1, # max error for assigned keypoints (in px)
+ "cell_size": 1, # size of quantization patch (max 1 kp/patch)
+ },
+ "xoftr": {
+ "output": "matches-xoftr",
+ "model": {
+ "name": "xoftr",
+ "weights": "weights_xoftr_640.ckpt",
+ "max_keypoints": 2000,
+ "match_threshold": 0.3,
+ },
+ "preprocessing": {
+ "grayscale": True,
+ "resize_max": 1024,
+ "dfactor": 8,
+ "width": 640,
+ "height": 480,
+ "force_resize": True,
+ },
+ "max_error": 1, # max error for assigned keypoints (in px)
+ "cell_size": 1, # size of quantization patch (max 1 kp/patch)
+ },
+ # "loftr_quadtree": {
+ # "output": "matches-loftr-quadtree",
+ # "model": {
+ # "name": "quadtree",
+ # "weights": "outdoor",
+ # "max_keypoints": 2000,
+ # "match_threshold": 0.2,
+ # },
+ # "preprocessing": {
+ # "grayscale": True,
+ # "resize_max": 1024,
+ # "dfactor": 8,
+ # "width": 640,
+ # "height": 480,
+ # "force_resize": True,
+ # },
+ # "max_error": 1, # max error for assigned keypoints (in px)
+ # "cell_size": 1, # size of quantization patch (max 1 kp/patch)
+ # },
+ "cotr": {
+ "output": "matches-cotr",
+ "model": {
+ "name": "cotr",
+ "weights": "out/default",
+ "max_keypoints": 2000,
+ "match_threshold": 0.2,
+ },
+ "preprocessing": {
+ "grayscale": False,
+ "resize_max": 1024,
+ "dfactor": 8,
+ "width": 640,
+ "height": 480,
+ "force_resize": True,
+ },
+ "max_error": 1, # max error for assigned keypoints (in px)
+ "cell_size": 1, # size of quantization patch (max 1 kp/patch)
+ },
+ # Semi-scalable loftr which limits detected keypoints
+ "loftr_aachen": {
+ "output": "matches-loftr_aachen",
+ "model": {
+ "name": "loftr",
+ "weights": "outdoor",
+ "max_keypoints": 2000,
+ "match_threshold": 0.2,
+ },
+ "preprocessing": {"grayscale": True, "resize_max": 1024, "dfactor": 8},
+ "max_error": 2, # max error for assigned keypoints (in px)
+ "cell_size": 8, # size of quantization patch (max 1 kp/patch)
+ },
+ # Use for matching superpoint feats with loftr
+ "loftr_superpoint": {
+ "output": "matches-loftr_aachen",
+ "model": {
+ "name": "loftr",
+ "weights": "outdoor",
+ "max_keypoints": 2000,
+ "match_threshold": 0.2,
+ },
+ "preprocessing": {
+ "grayscale": True,
+ "resize_max": 1024,
+ "dfactor": 8,
+ "width": 640,
+ "height": 480,
+ "force_resize": True,
+ },
+ "max_error": 4, # max error for assigned keypoints (in px)
+ "cell_size": 4, # size of quantization patch (max 1 kp/patch)
+ },
+ # Use topicfm for matching feats
+ "topicfm": {
+ "output": "matches-topicfm",
+ "model": {
+ "name": "topicfm",
+ "weights": "outdoor",
+ "max_keypoints": 2000,
+ "match_threshold": 0.2,
+ },
+ "preprocessing": {
+ "grayscale": True,
+ "force_resize": True,
+ "resize_max": 1024,
+ "dfactor": 8,
+ "width": 640,
+ "height": 480,
+ },
+ },
+ # Use aspanformer for matching feats
+ "aspanformer": {
+ "output": "matches-aspanformer",
+ "model": {
+ "name": "aspanformer",
+ "weights": "outdoor",
+ "max_keypoints": 2000,
+ "match_threshold": 0.2,
+ },
+ "preprocessing": {
+ "grayscale": True,
+ "force_resize": True,
+ "resize_max": 1024,
+ "width": 640,
+ "height": 480,
+ "dfactor": 8,
+ },
+ },
+ "duster": {
+ "output": "matches-duster",
+ "model": {
+ "name": "duster",
+ "weights": "vit_large",
+ "max_keypoints": 2000,
+ "match_threshold": 0.2,
+ },
+ "preprocessing": {
+ "grayscale": False,
+ "resize_max": 512,
+ "dfactor": 16,
+ },
+ },
+ "mast3r": {
+ "output": "matches-mast3r",
+ "model": {
+ "name": "mast3r",
+ "weights": "vit_large",
+ "max_keypoints": 2000,
+ "match_threshold": 0.2,
+ },
+ "preprocessing": {
+ "grayscale": False,
+ "resize_max": 512,
+ "dfactor": 16,
+ },
+ },
+ "xfeat_lightglue": {
+ "output": "matches-xfeat_lightglue",
+ "model": {
+ "name": "xfeat_lightglue",
+ "max_keypoints": 8000,
+ },
+ "preprocessing": {
+ "grayscale": False,
+ "force_resize": False,
+ "resize_max": 1024,
+ "width": 640,
+ "height": 480,
+ "dfactor": 8,
+ },
+ },
+ "xfeat_dense": {
+ "output": "matches-xfeat_dense",
+ "model": {
+ "name": "xfeat_dense",
+ "max_keypoints": 8000,
+ },
+ "preprocessing": {
+ "grayscale": False,
+ "force_resize": False,
+ "resize_max": 1024,
+ "width": 640,
+ "height": 480,
+ "dfactor": 8,
+ },
+ },
+ "dkm": {
+ "output": "matches-dkm",
+ "model": {
+ "name": "dkm",
+ "weights": "outdoor",
+ "max_keypoints": 2000,
+ "match_threshold": 0.2,
+ },
+ "preprocessing": {
+ "grayscale": False,
+ "force_resize": True,
+ "resize_max": 1024,
+ "width": 80,
+ "height": 60,
+ "dfactor": 8,
+ },
+ },
+ "roma": {
+ "output": "matches-roma",
+ "model": {
+ "name": "roma",
+ "weights": "outdoor",
+ "model_name": "roma_outdoor.pth",
+ "max_keypoints": 2000,
+ "match_threshold": 0.2,
+ },
+ "preprocessing": {
+ "grayscale": False,
+ "force_resize": True,
+ "resize_max": 1024,
+ "width": 320,
+ "height": 240,
+ "dfactor": 8,
+ },
+ },
+ "minima_roma": {
+ "output": "matches-minima_roma",
+ "model": {
+ "name": "roma",
+ "weights": "outdoor",
+ "model_name": "minima_roma.pth",
+ "max_keypoints": 2000,
+ "match_threshold": 0.2,
+ },
+ "preprocessing": {
+ "grayscale": False,
+ "force_resize": False,
+ "resize_max": 1024,
+ "width": 320,
+ "height": 240,
+ "dfactor": 8,
+ },
+ },
+ "gim(dkm)": {
+ "output": "matches-gim",
+ "model": {
+ "name": "gim",
+ "model_name": "gim_dkm_100h.ckpt",
+ "max_keypoints": 2000,
+ "match_threshold": 0.2,
+ },
+ "preprocessing": {
+ "grayscale": False,
+ "force_resize": True,
+ "resize_max": 1024,
+ "width": 320,
+ "height": 240,
+ "dfactor": 8,
+ },
+ },
+ "omniglue": {
+ "output": "matches-omniglue",
+ "model": {
+ "name": "omniglue",
+ "match_threshold": 0.2,
+ "max_keypoints": 2000,
+ "features": "null",
+ },
+ "preprocessing": {
+ "grayscale": False,
+ "resize_max": 1024,
+ "dfactor": 8,
+ "force_resize": False,
+ "width": 640,
+ "height": 480,
+ },
+ },
+ "sold2": {
+ "output": "matches-sold2",
+ "model": {
+ "name": "sold2",
+ "max_keypoints": 2000,
+ "match_threshold": 0.2,
+ },
+ "preprocessing": {
+ "grayscale": True,
+ "force_resize": True,
+ "resize_max": 1024,
+ "width": 640,
+ "height": 480,
+ "dfactor": 8,
+ },
+ },
+ "gluestick": {
+ "output": "matches-gluestick",
+ "model": {
+ "name": "gluestick",
+ "use_lines": True,
+ "max_keypoints": 1000,
+ "max_lines": 300,
+ "force_num_keypoints": False,
+ },
+ "preprocessing": {
+ "grayscale": True,
+ "force_resize": True,
+ "resize_max": 1024,
+ "width": 640,
+ "height": 480,
+ "dfactor": 8,
+ },
+ },
+}
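+
+# Dense and semi-dense matchers return per-pair match coordinates rather than a
+# repeatable set of keypoints. The helpers below quantize those coordinates into
+# per-image keypoints: matches within `max_error` pixels are merged into a single
+# keypoint and `cell_size` sets the quantization grid (at most one keypoint per
+# cell), so the results can be stored in the sparse hloc keypoints/matches format.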
+
+
+def to_cpts(kpts, ps):
+ if ps > 0.0:
+ kpts = np.round(np.round((kpts + 0.5) / ps) * ps - 0.5, 2)
+ return [tuple(cpt) for cpt in kpts]
+
+
+def assign_keypoints(
+ kpts: np.ndarray,
+ other_cpts: Union[List[Tuple], np.ndarray],
+ max_error: float,
+ update: bool = False,
+ ref_bins: Optional[List[Counter]] = None,
+ scores: Optional[np.ndarray] = None,
+ cell_size: Optional[int] = None,
+):
+ if not update:
+ # Without update this is just a NN search
+ if len(other_cpts) == 0 or len(kpts) == 0:
+ return np.full(len(kpts), -1)
+ dist, kpt_ids = KDTree(np.array(other_cpts)).query(kpts)
+ valid = dist <= max_error
+ kpt_ids[~valid] = -1
+ return kpt_ids
+ else:
+ ps = cell_size if cell_size is not None else max_error
+ ps = max(ps, max_error)
+ # With update we quantize and bin (optionally)
+ assert isinstance(other_cpts, list)
+ kpt_ids = []
+ cpts = to_cpts(kpts, ps)
+ bpts = to_cpts(kpts, int(max_error))
+ cp_to_id = {val: i for i, val in enumerate(other_cpts)}
+ for i, (cpt, bpt) in enumerate(zip(cpts, bpts)):
+ try:
+ kid = cp_to_id[cpt]
+ except KeyError:
+ kid = len(cp_to_id)
+ cp_to_id[cpt] = kid
+ other_cpts.append(cpt)
+ if ref_bins is not None:
+ ref_bins.append(Counter())
+ if ref_bins is not None:
+ score = scores[i] if scores is not None else 1
+ ref_bins[cp_to_id[cpt]][bpt] += score
+ kpt_ids.append(kid)
+ return np.array(kpt_ids)
+
+
+def get_grouped_ids(array):
+ # Group array indices based on its values
+ # all duplicates are grouped as a set
+ idx_sort = np.argsort(array)
+ sorted_array = array[idx_sort]
+ _, ids, _ = np.unique(sorted_array, return_counts=True, return_index=True)
+ res = np.split(idx_sort, ids[1:])
+ return res
+
+
+def get_unique_matches(match_ids, scores):
+ if len(match_ids.shape) == 1:
+ return [0]
+
+ isets1 = get_grouped_ids(match_ids[:, 0])
+ isets2 = get_grouped_ids(match_ids[:, 1])
+ uid1s = [ids[scores[ids].argmax()] for ids in isets1 if len(ids) > 0]
+ uid2s = [ids[scores[ids].argmax()] for ids in isets2 if len(ids) > 0]
+ uids = list(set(uid1s).intersection(uid2s))
+ return match_ids[uids], scores[uids]
+
+
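+# Convert an (M, 2) list of keypoint-index pairs into hloc's "matches0" format:
+# for each keypoint of image0, the index of its match in image1 (-1 if
+# unmatched), together with the corresponding matching scores.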
+def matches_to_matches0(matches, scores):
+ if len(matches) == 0:
+ return np.zeros(0, dtype=np.int32), np.zeros(0, dtype=np.float16)
+ n_kps0 = np.max(matches[:, 0]) + 1
+ matches0 = -np.ones((n_kps0,))
+ scores0 = np.zeros((n_kps0,))
+ matches0[matches[:, 0]] = matches[:, 1]
+ scores0[matches[:, 0]] = scores
+ return matches0.astype(np.int32), scores0.astype(np.float16)
+
+
+def kpids_to_matches0(kpt_ids0, kpt_ids1, scores):
+ valid = (kpt_ids0 != -1) & (kpt_ids1 != -1)
+ matches = np.dstack([kpt_ids0[valid], kpt_ids1[valid]])
+ matches = matches.reshape(-1, 2)
+ scores = scores[valid]
+
+ # Remove n-to-1 matches
+ matches, scores = get_unique_matches(matches, scores)
+ return matches_to_matches0(matches, scores)
+
+
+def scale_keypoints(kpts, scale):
+ if np.any(scale != 1.0):
+ kpts *= kpts.new_tensor(scale)
+ return kpts
+
+
+class ImagePairDataset(torch.utils.data.Dataset):
+ default_conf = {
+ "grayscale": True,
+ "resize_max": 1024,
+ "dfactor": 8,
+ "cache_images": False,
+ }
+
+ def __init__(self, image_dir, conf, pairs):
+ self.image_dir = image_dir
+ self.conf = conf = SimpleNamespace(**{**self.default_conf, **conf})
+ self.pairs = pairs
+ if self.conf.cache_images:
+ image_names = set(sum(pairs, ())) # unique image names in pairs
+ logger.info(f"Loading and caching {len(image_names)} unique images.")
+ self.images = {}
+ self.scales = {}
+ for name in tqdm(image_names):
+ image = read_image(self.image_dir / name, self.conf.grayscale)
+ self.images[name], self.scales[name] = self.preprocess(image)
+
+ def preprocess(self, image: np.ndarray):
+ image = image.astype(np.float32, copy=False)
+ size = image.shape[:2][::-1]
+ scale = np.array([1.0, 1.0])
+
+ if self.conf.resize_max:
+ scale = self.conf.resize_max / max(size)
+ if scale < 1.0:
+ size_new = tuple(int(round(x * scale)) for x in size)
+ image = resize_image(image, size_new, "cv2_area")
+ scale = np.array(size) / np.array(size_new)
+
+ if self.conf.grayscale:
+ assert image.ndim == 2, image.shape
+ image = image[None]
+ else:
+ image = image.transpose((2, 0, 1)) # HxWxC to CxHxW
+ image = torch.from_numpy(image / 255.0).float()
+
+ # ensure that the size is divisible by dfactor
+ size_new = tuple(
+ map(
+ lambda x: int(x // self.conf.dfactor * self.conf.dfactor),
+ image.shape[-2:],
+ )
+ )
+ image = F.resize(image, size=size_new)
+ scale = np.array(size) / np.array(size_new)[::-1]
+ return image, scale
+
+ def __len__(self):
+ return len(self.pairs)
+
+ def __getitem__(self, idx):
+ name0, name1 = self.pairs[idx]
+ if self.conf.cache_images:
+ image0, scale0 = self.images[name0], self.scales[name0]
+ image1, scale1 = self.images[name1], self.scales[name1]
+ else:
+ image0 = read_image(self.image_dir / name0, self.conf.grayscale)
+ image1 = read_image(self.image_dir / name1, self.conf.grayscale)
+ image0, scale0 = self.preprocess(image0)
+ image1, scale1 = self.preprocess(image1)
+ return image0, image1, scale0, scale1, name0, name1
+
+
+@torch.no_grad()
+def match_dense(
+ conf: Dict,
+ pairs: List[Tuple[str, str]],
+ image_dir: Path,
+ match_path: Path, # out
+ existing_refs: Optional[List] = [],
+):
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ Model = dynamic_load(matchers, conf["model"]["name"])
+ model = Model(conf["model"]).eval().to(device)
+
+ dataset = ImagePairDataset(image_dir, conf["preprocessing"], pairs)
+ loader = torch.utils.data.DataLoader(
+ dataset, num_workers=16, batch_size=1, shuffle=False
+ )
+
+ logger.info("Performing dense matching...")
+ with h5py.File(str(match_path), "a") as fd:
+ for data in tqdm(loader, smoothing=0.1):
+ # load image-pair data
+ image0, image1, scale0, scale1, (name0,), (name1,) = data
+ scale0, scale1 = scale0[0].numpy(), scale1[0].numpy()
+ image0, image1 = image0.to(device), image1.to(device)
+
+ # match semi-dense
+ # for consistency with pairs_from_*: refine kpts of image0
+ if name0 in existing_refs:
+ # special case: flip to enable refinement in query image
+ pred = model({"image0": image1, "image1": image0})
+ pred = {
+ **pred,
+ "keypoints0": pred["keypoints1"],
+ "keypoints1": pred["keypoints0"],
+ }
+ else:
+ # usual case
+ pred = model({"image0": image0, "image1": image1})
+
+ # Rescale keypoints and move to cpu
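+ # The +0.5 / -0.5 shift applies the pixel-center convention:
+ # x_orig = (x + 0.5) * scale - 0.5.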
+ kpts0, kpts1 = pred["keypoints0"], pred["keypoints1"]
+ kpts0 = scale_keypoints(kpts0 + 0.5, scale0) - 0.5
+ kpts1 = scale_keypoints(kpts1 + 0.5, scale1) - 0.5
+ kpts0 = kpts0.cpu().numpy()
+ kpts1 = kpts1.cpu().numpy()
+ scores = pred["scores"].cpu().numpy()
+
+ # Write matches and matching scores in hloc format
+ pair = names_to_pair(name0, name1)
+ if pair in fd:
+ del fd[pair]
+ grp = fd.create_group(pair)
+
+ # Write dense matching output
+ grp.create_dataset("keypoints0", data=kpts0)
+ grp.create_dataset("keypoints1", data=kpts1)
+ grp.create_dataset("scores", data=scores)
+ del model, loader
+
+
+# default: quantize all!
+def load_keypoints(
+ conf: Dict, feature_paths_refs: List[Path], quantize: Optional[set] = None
+):
+ name2ref = {
+ n: i for i, p in enumerate(feature_paths_refs) for n in list_h5_names(p)
+ }
+
+ existing_refs = set(name2ref.keys())
+ if quantize is None:
+ quantize = existing_refs # quantize all
+ if len(existing_refs) > 0:
+ logger.info(f"Loading keypoints from {len(existing_refs)} images.")
+
+ # Load query keypoints
+ cpdict = defaultdict(list)
+ bindict = defaultdict(list)
+ for name in existing_refs:
+ with h5py.File(str(feature_paths_refs[name2ref[name]]), "r") as fd:
+ kps = fd[name]["keypoints"].__array__()
+ if name not in quantize:
+ cpdict[name] = kps
+ else:
+ if "scores" in fd[name].keys():
+ kp_scores = fd[name]["scores"].__array__()
+ else:
+ # we set the score to 1.0 if not provided
+ # increase for more weight on reference keypoints for
+ # stronger anchoring
+ kp_scores = [1.0 for _ in range(kps.shape[0])]
+ # bin existing keypoints of reference images for association
+ assign_keypoints(
+ kps,
+ cpdict[name],
+ conf["max_error"],
+ True,
+ bindict[name],
+ kp_scores,
+ conf["cell_size"],
+ )
+ return cpdict, bindict
+
+
+def aggregate_matches(
+ conf: Dict,
+ pairs: List[Tuple[str, str]],
+ match_path: Path,
+ feature_path: Path,
+ required_queries: Optional[Set[str]] = None,
+ max_kps: Optional[int] = None,
+ cpdict: Dict[str, Iterable] = defaultdict(list),
+ bindict: Dict[str, List[Counter]] = defaultdict(list),
+):
+ if required_queries is None:
+ required_queries = set(sum(pairs, ()))
+ # default: do not overwrite existing features in feature_path!
+ required_queries -= set(list_h5_names(feature_path))
+
+ # if an entry in cpdict is provided as np.ndarray we assume it is fixed
+ required_queries -= set([k for k, v in cpdict.items() if isinstance(v, np.ndarray)])
+
+ # sort pairs for reduced RAM
+ pairs_per_q = Counter(list(chain(*pairs)))
+ pairs_score = [min(pairs_per_q[i], pairs_per_q[j]) for i, j in pairs]
+ pairs = [p for _, p in sorted(zip(pairs_score, pairs))]
+
+ if len(required_queries) > 0:
+ logger.info(f"Aggregating keypoints for {len(required_queries)} images.")
+ n_kps = 0
+ with h5py.File(str(match_path), "a") as fd:
+ for name0, name1 in tqdm(pairs, smoothing=0.1):
+ pair = names_to_pair(name0, name1)
+ grp = fd[pair]
+ kpts0 = grp["keypoints0"].__array__()
+ kpts1 = grp["keypoints1"].__array__()
+ scores = grp["scores"].__array__()
+
+ # Aggregate local features
+ update0 = name0 in required_queries
+ update1 = name1 in required_queries
+
+ # in localization we do not want to bin the query kp
+ # assumes that the query is name0!
+ if update0 and not update1 and max_kps is None:
+ max_error0 = cell_size0 = 0.0
+ else:
+ max_error0 = conf["max_error"]
+ cell_size0 = conf["cell_size"]
+
+ # Get match ids and extend query keypoints (cpdict)
+ mkp_ids0 = assign_keypoints(
+ kpts0,
+ cpdict[name0],
+ max_error0,
+ update0,
+ bindict[name0],
+ scores,
+ cell_size0,
+ )
+ mkp_ids1 = assign_keypoints(
+ kpts1,
+ cpdict[name1],
+ conf["max_error"],
+ update1,
+ bindict[name1],
+ scores,
+ conf["cell_size"],
+ )
+
+ # Build matches from assignments
+ matches0, scores0 = kpids_to_matches0(mkp_ids0, mkp_ids1, scores)
+
+ assert kpts0.shape[0] == scores.shape[0]
+ grp.create_dataset("matches0", data=matches0)
+ grp.create_dataset("matching_scores0", data=scores0)
+
+ # Convert bins to kps if finished, and store them
+ for name in (name0, name1):
+ pairs_per_q[name] -= 1
+ if pairs_per_q[name] > 0 or name not in required_queries:
+ continue
+ kp_score = [c.most_common(1)[0][1] for c in bindict[name]]
+ cpdict[name] = [c.most_common(1)[0][0] for c in bindict[name]]
+ cpdict[name] = np.array(cpdict[name], dtype=np.float32)
+
+ # Select top-k query kps by score (reassign matches later)
+ if max_kps:
+ top_k = min(max_kps, cpdict[name].shape[0])
+ top_k = np.argsort(kp_score)[::-1][:top_k]
+ cpdict[name] = cpdict[name][top_k]
+ kp_score = np.array(kp_score)[top_k]
+
+ # Write query keypoints
+ with h5py.File(feature_path, "a") as kfd:
+ if name in kfd:
+ del kfd[name]
+ kgrp = kfd.create_group(name)
+ kgrp.create_dataset("keypoints", data=cpdict[name])
+ kgrp.create_dataset("score", data=kp_score)
+ n_kps += cpdict[name].shape[0]
+ del bindict[name]
+
+ if len(required_queries) > 0:
+ avg_kp_per_image = round(n_kps / len(required_queries), 1)
+ logger.info(
+ f"Finished assignment, found {avg_kp_per_image} "
+ f"keypoints/image (avg.), total {n_kps}."
+ )
+ return cpdict
+
+
+def assign_matches(
+ pairs: List[Tuple[str, str]],
+ match_path: Path,
+ keypoints: Union[List[Path], Dict[str, np.array]],
+ max_error: float,
+):
+ if isinstance(keypoints, list):
+ keypoints, _ = load_keypoints({}, keypoints, quantize=set())
+ assert len(set(sum(pairs, ())) - set(keypoints.keys())) == 0
+ with h5py.File(str(match_path), "a") as fd:
+ for name0, name1 in tqdm(pairs):
+ pair = names_to_pair(name0, name1)
+ grp = fd[pair]
+ kpts0 = grp["keypoints0"].__array__()
+ kpts1 = grp["keypoints1"].__array__()
+ scores = grp["scores"].__array__()
+
+ # NN search across cell boundaries
+ mkp_ids0 = assign_keypoints(kpts0, keypoints[name0], max_error)
+ mkp_ids1 = assign_keypoints(kpts1, keypoints[name1], max_error)
+
+ matches0, scores0 = kpids_to_matches0(mkp_ids0, mkp_ids1, scores)
+
+ # overwrite matches0 and matching_scores0
+ del grp["matches0"], grp["matching_scores0"]
+ grp.create_dataset("matches0", data=matches0)
+ grp.create_dataset("matching_scores0", data=scores0)
+
+
+@torch.no_grad()
+def match_and_assign(
+ conf: Dict,
+ pairs_path: Path,
+ image_dir: Path,
+ match_path: Path, # out
+ feature_path_q: Path, # out
+ feature_paths_refs: Optional[List[Path]] = [],
+ max_kps: Optional[int] = 8192,
+ overwrite: bool = False,
+) -> Path:
+ for path in feature_paths_refs:
+ if not path.exists():
+ raise FileNotFoundError(f"Reference feature file {path}.")
+ pairs = parse_retrieval(pairs_path)
+ pairs = [(q, r) for q, rs in pairs.items() for r in rs]
+ pairs = find_unique_new_pairs(pairs, None if overwrite else match_path)
+ required_queries = set(sum(pairs, ()))
+
+ name2ref = {
+ n: i for i, p in enumerate(feature_paths_refs) for n in list_h5_names(p)
+ }
+ existing_refs = required_queries.intersection(set(name2ref.keys()))
+
+ # images which require feature extraction
+ required_queries = required_queries - existing_refs
+
+ if feature_path_q.exists():
+ existing_queries = set(list_h5_names(feature_path_q))
+ feature_paths_refs.append(feature_path_q)
+ existing_refs = set.union(existing_refs, existing_queries)
+ if not overwrite:
+ required_queries = required_queries - existing_queries
+
+ if len(pairs) == 0 and len(required_queries) == 0:
+ logger.info("All pairs exist. Skipping dense matching.")
+ return
+
+ # extract semi-dense matches
+ match_dense(conf, pairs, image_dir, match_path, existing_refs=existing_refs)
+
+ logger.info("Assigning matches...")
+
+ # Pre-load existing keypoints
+ cpdict, bindict = load_keypoints(
+ conf, feature_paths_refs, quantize=required_queries
+ )
+
+ # Reassign matches by aggregation
+ cpdict = aggregate_matches(
+ conf,
+ pairs,
+ match_path,
+ feature_path=feature_path_q,
+ required_queries=required_queries,
+ max_kps=max_kps,
+ cpdict=cpdict,
+ bindict=bindict,
+ )
+
+ # Invalidate matches that are far from selected bin by reassignment
+ if max_kps is not None:
+ logger.info(f'Reassign matches with max_error={conf["max_error"]}.')
+ assign_matches(pairs, match_path, cpdict, max_error=conf["max_error"])
+
+
+def scale_lines(lines, scale):
+ if np.any(scale != 1.0):
+ lines *= lines.new_tensor(scale)
+ return lines
+
+
+def match(model, path_0, path_1, conf):
+ default_conf = {
+ "grayscale": True,
+ "resize_max": 1024,
+ "dfactor": 8,
+ "cache_images": False,
+ "force_resize": False,
+ "width": 320,
+ "height": 240,
+ }
+
+ def preprocess(image: np.ndarray):
+ image = image.astype(np.float32, copy=False)
+ size = image.shape[:2][::-1]
+ scale = np.array([1.0, 1.0])
+ if conf.resize_max:
+ scale = conf.resize_max / max(size)
+ if scale < 1.0:
+ size_new = tuple(int(round(x * scale)) for x in size)
+ image = resize_image(image, size_new, "cv2_area")
+ scale = np.array(size) / np.array(size_new)
+ if conf.force_resize:
+ size = image.shape[:2][::-1]
+ image = resize_image(image, (conf.width, conf.height), "cv2_area")
+ size_new = (conf.width, conf.height)
+ scale = np.array(size) / np.array(size_new)
+ if conf.grayscale:
+ assert image.ndim == 2, image.shape
+ image = image[None]
+ else:
+ image = image.transpose((2, 0, 1)) # HxWxC to CxHxW
+ image = torch.from_numpy(image / 255.0).float()
+        # ensure that the size is divisible by dfactor
+ size_new = tuple(
+ map(
+ lambda x: int(x // conf.dfactor * conf.dfactor),
+ image.shape[-2:],
+ )
+ )
+ image = F.resize(image, size=size_new, antialias=True)
+ scale = np.array(size) / np.array(size_new)[::-1]
+ return image, scale
+
+ conf = SimpleNamespace(**{**default_conf, **conf})
+ image0 = read_image(path_0, conf.grayscale)
+ image1 = read_image(path_1, conf.grayscale)
+ image0, scale0 = preprocess(image0)
+ image1, scale1 = preprocess(image1)
+ image0 = image0.to(device)[None]
+ image1 = image1.to(device)[None]
+ pred = model({"image0": image0, "image1": image1})
+
+ # Rescale keypoints and move to cpu
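+    # the +/- 0.5 shift scales keypoints around pixel centers rather than pixel corners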
+ kpts0, kpts1 = pred["keypoints0"], pred["keypoints1"]
+ kpts0 = scale_keypoints(kpts0 + 0.5, scale0) - 0.5
+ kpts1 = scale_keypoints(kpts1 + 0.5, scale1) - 0.5
+
+ ret = {
+ "image0": image0.squeeze().cpu().numpy(),
+ "image1": image1.squeeze().cpu().numpy(),
+ "keypoints0": kpts0.cpu().numpy(),
+ "keypoints1": kpts1.cpu().numpy(),
+ }
+ if "mconf" in pred.keys():
+ ret["mconf"] = pred["mconf"].cpu().numpy()
+ return ret
+
+
+@torch.no_grad()
+def match_images(model, image_0, image_1, conf, device="cpu"):
+ default_conf = {
+ "grayscale": True,
+ "resize_max": 1024,
+ "dfactor": 8,
+ "cache_images": False,
+ "force_resize": False,
+ "width": 320,
+ "height": 240,
+ }
+
+ def preprocess(image: np.ndarray):
+ image = image.astype(np.float32, copy=False)
+ size = image.shape[:2][::-1]
+ scale = np.array([1.0, 1.0])
+ if conf.resize_max:
+ scale = conf.resize_max / max(size)
+ if scale < 1.0:
+ size_new = tuple(int(round(x * scale)) for x in size)
+ image = resize_image(image, size_new, "cv2_area")
+ scale = np.array(size) / np.array(size_new)
+ if conf.force_resize:
+ size = image.shape[:2][::-1]
+ image = resize_image(image, (conf.width, conf.height), "cv2_area")
+ size_new = (conf.width, conf.height)
+ scale = np.array(size) / np.array(size_new)
+ if conf.grayscale:
+ assert image.ndim == 2, image.shape
+ image = image[None]
+ else:
+ image = image.transpose((2, 0, 1)) # HxWxC to CxHxW
+ image = torch.from_numpy(image / 255.0).float()
+
+        # ensure that the size is divisible by dfactor
+ size_new = tuple(
+ map(
+ lambda x: int(x // conf.dfactor * conf.dfactor),
+ image.shape[-2:],
+ )
+ )
+ image = F.resize(image, size=size_new)
+ scale = np.array(size) / np.array(size_new)[::-1]
+ return image, scale
+
+ conf = SimpleNamespace(**{**default_conf, **conf})
+
+ if len(image_0.shape) == 3 and conf.grayscale:
+ image0 = cv2.cvtColor(image_0, cv2.COLOR_RGB2GRAY)
+ else:
+ image0 = image_0
+    if len(image_1.shape) == 3 and conf.grayscale:
+ image1 = cv2.cvtColor(image_1, cv2.COLOR_RGB2GRAY)
+ else:
+ image1 = image_1
+
+    # the BGR-to-RGB conversion below is commented out because the input images are already RGB
+ # if not conf.grayscale and len(image0.shape) == 3:
+ # image0 = image0[:, :, ::-1] # BGR to RGB
+ # if not conf.grayscale and len(image1.shape) == 3:
+ # image1 = image1[:, :, ::-1] # BGR to RGB
+
+ image0, scale0 = preprocess(image0)
+ image1, scale1 = preprocess(image1)
+ image0 = image0.to(device)[None]
+ image1 = image1.to(device)[None]
+ pred = model({"image0": image0, "image1": image1})
+
+ s0 = np.array(image_0.shape[:2][::-1]) / np.array(image0.shape[-2:][::-1])
+ s1 = np.array(image_1.shape[:2][::-1]) / np.array(image1.shape[-2:][::-1])
+
+ # Rescale keypoints and move to cpu
+ if "keypoints0" in pred.keys() and "keypoints1" in pred.keys():
+ kpts0, kpts1 = pred["keypoints0"], pred["keypoints1"]
+ kpts0_origin = scale_keypoints(kpts0 + 0.5, s0) - 0.5
+ kpts1_origin = scale_keypoints(kpts1 + 0.5, s1) - 0.5
+
+ ret = {
+ "image0": image0.squeeze().cpu().numpy(),
+ "image1": image1.squeeze().cpu().numpy(),
+ "image0_orig": image_0,
+ "image1_orig": image_1,
+ "keypoints0": kpts0.cpu().numpy(),
+ "keypoints1": kpts1.cpu().numpy(),
+ "keypoints0_orig": kpts0_origin.cpu().numpy(),
+ "keypoints1_orig": kpts1_origin.cpu().numpy(),
+ "mkeypoints0": kpts0.cpu().numpy(),
+ "mkeypoints1": kpts1.cpu().numpy(),
+ "mkeypoints0_orig": kpts0_origin.cpu().numpy(),
+ "mkeypoints1_orig": kpts1_origin.cpu().numpy(),
+ "original_size0": np.array(image_0.shape[:2][::-1]),
+ "original_size1": np.array(image_1.shape[:2][::-1]),
+ "new_size0": np.array(image0.shape[-2:][::-1]),
+ "new_size1": np.array(image1.shape[-2:][::-1]),
+ "scale0": s0,
+ "scale1": s1,
+ }
+ if "mconf" in pred.keys():
+ ret["mconf"] = pred["mconf"].cpu().numpy()
+ elif "scores" in pred.keys(): # adapting loftr
+ ret["mconf"] = pred["scores"].cpu().numpy()
+ else:
+ ret["mconf"] = np.ones_like(kpts0.cpu().numpy()[:, 0])
+ if "lines0" in pred.keys() and "lines1" in pred.keys():
+ if "keypoints0" in pred.keys() and "keypoints1" in pred.keys():
+ kpts0, kpts1 = pred["keypoints0"], pred["keypoints1"]
+ kpts0_origin = scale_keypoints(kpts0 + 0.5, s0) - 0.5
+ kpts1_origin = scale_keypoints(kpts1 + 0.5, s1) - 0.5
+ kpts0_origin = kpts0_origin.cpu().numpy()
+ kpts1_origin = kpts1_origin.cpu().numpy()
+ else:
+ kpts0_origin, kpts1_origin = (
+ None,
+ None,
+ ) # np.zeros([0]), np.zeros([0])
+ lines0, lines1 = pred["lines0"], pred["lines1"]
+ lines0_raw, lines1_raw = pred["raw_lines0"], pred["raw_lines1"]
+
+ lines0_raw = torch.from_numpy(lines0_raw.copy())
+ lines1_raw = torch.from_numpy(lines1_raw.copy())
+ lines0_raw = scale_lines(lines0_raw + 0.5, s0) - 0.5
+ lines1_raw = scale_lines(lines1_raw + 0.5, s1) - 0.5
+
+ lines0 = torch.from_numpy(lines0.copy())
+ lines1 = torch.from_numpy(lines1.copy())
+ lines0 = scale_lines(lines0 + 0.5, s0) - 0.5
+ lines1 = scale_lines(lines1 + 0.5, s1) - 0.5
+
+ ret = {
+ "image0_orig": image_0,
+ "image1_orig": image_1,
+ "line0": lines0_raw.cpu().numpy(),
+ "line1": lines1_raw.cpu().numpy(),
+ "line0_orig": lines0.cpu().numpy(),
+ "line1_orig": lines1.cpu().numpy(),
+ "line_keypoints0_orig": kpts0_origin,
+ "line_keypoints1_orig": kpts1_origin,
+ }
+ del pred
+ torch.cuda.empty_cache()
+ return ret
+
+
+@torch.no_grad()
+def main(
+ conf: Dict,
+ pairs: Path,
+ image_dir: Path,
+ export_dir: Optional[Path] = None,
+ matches: Optional[Path] = None, # out
+ features: Optional[Path] = None, # out
+ features_ref: Optional[Path] = None,
+ max_kps: Optional[int] = 8192,
+ overwrite: bool = False,
+) -> Path:
+ logger.info(
+ "Extracting semi-dense features with configuration:" f"\n{pprint.pformat(conf)}"
+ )
+
+ if features is None:
+ features = "feats_"
+
+ if isinstance(features, Path):
+ features_q = features
+ if matches is None:
+ raise ValueError(
+ "Either provide both features and matches as Path" " or both as names."
+ )
+ else:
+ if export_dir is None:
+ raise ValueError(
+ "Provide an export_dir if features and matches"
+ f" are not file paths: {features}, {matches}."
+ )
+ features_q = Path(export_dir, f'{features}{conf["output"]}.h5')
+ if matches is None:
+ matches = Path(export_dir, f'{conf["output"]}_{pairs.stem}.h5')
+
+ if features_ref is None:
+ features_ref = []
+ elif isinstance(features_ref, list):
+ features_ref = list(features_ref)
+ elif isinstance(features_ref, Path):
+ features_ref = [features_ref]
+ else:
+ raise TypeError(str(features_ref))
+
+ match_and_assign(
+ conf,
+ pairs,
+ image_dir,
+ matches,
+ features_q,
+ features_ref,
+ max_kps,
+ overwrite,
+ )
+
+ return features_q, matches
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--pairs", type=Path, required=True)
+ parser.add_argument("--image_dir", type=Path, required=True)
+ parser.add_argument("--export_dir", type=Path, required=True)
+ parser.add_argument("--matches", type=Path, default=confs["loftr"]["output"])
+ parser.add_argument(
+ "--features", type=str, default="feats_" + confs["loftr"]["output"]
+ )
+ parser.add_argument("--conf", type=str, default="loftr", choices=list(confs.keys()))
+ args = parser.parse_args()
+ main(
+ confs[args.conf],
+ args.pairs,
+ args.image_dir,
+ args.export_dir,
+ args.matches,
+ args.features,
+ )
diff --git a/imcui/hloc/match_features.py b/imcui/hloc/match_features.py
new file mode 100644
index 0000000000000000000000000000000000000000..50917b5a586a172a092a6d00c5ec9235e1b81a36
--- /dev/null
+++ b/imcui/hloc/match_features.py
@@ -0,0 +1,459 @@
+import argparse
+import pprint
+from functools import partial
+from pathlib import Path
+from queue import Queue
+from threading import Thread
+from typing import Dict, List, Optional, Tuple, Union
+
+import h5py
+import numpy as np
+import torch
+from tqdm import tqdm
+
+from . import logger, matchers
+from .utils.base_model import dynamic_load
+from .utils.parsers import names_to_pair, names_to_pair_old, parse_retrieval
+
+"""
+A set of standard configurations that can be directly selected from the command
+line using their name. Each is a dictionary with the following entries:
+ - output: the name of the match file that will be generated.
+ - model: the model configuration, as passed to a feature matcher.
+"""
+confs = {
+ "superglue": {
+ "output": "matches-superglue",
+ "model": {
+ "name": "superglue",
+ "weights": "outdoor",
+ "sinkhorn_iterations": 50,
+ "match_threshold": 0.2,
+ },
+ "preprocessing": {
+ "grayscale": True,
+ "resize_max": 1024,
+ "dfactor": 8,
+ "force_resize": False,
+ },
+ },
+ "superglue-fast": {
+ "output": "matches-superglue-it5",
+ "model": {
+ "name": "superglue",
+ "weights": "outdoor",
+ "sinkhorn_iterations": 5,
+ "match_threshold": 0.2,
+ },
+ },
+ "superpoint-lightglue": {
+ "output": "matches-lightglue",
+ "model": {
+ "name": "lightglue",
+ "match_threshold": 0.2,
+ "width_confidence": 0.99, # for point pruning
+ "depth_confidence": 0.95, # for early stopping,
+ "features": "superpoint",
+ "model_name": "superpoint_lightglue.pth",
+ },
+ "preprocessing": {
+ "grayscale": True,
+ "resize_max": 1024,
+ "dfactor": 8,
+ "force_resize": False,
+ },
+ },
+ "disk-lightglue": {
+ "output": "matches-disk-lightglue",
+ "model": {
+ "name": "lightglue",
+ "match_threshold": 0.2,
+ "width_confidence": 0.99, # for point pruning
+ "depth_confidence": 0.95, # for early stopping,
+ "features": "disk",
+ "model_name": "disk_lightglue.pth",
+ },
+ "preprocessing": {
+ "grayscale": True,
+ "resize_max": 1024,
+ "dfactor": 8,
+ "force_resize": False,
+ },
+ },
+ "aliked-lightglue": {
+ "output": "matches-aliked-lightglue",
+ "model": {
+ "name": "lightglue",
+ "match_threshold": 0.2,
+ "width_confidence": 0.99, # for point pruning
+ "depth_confidence": 0.95, # for early stopping,
+ "features": "aliked",
+ "model_name": "aliked_lightglue.pth",
+ },
+ "preprocessing": {
+ "grayscale": True,
+ "resize_max": 1024,
+ "dfactor": 8,
+ "force_resize": False,
+ },
+ },
+ "sift-lightglue": {
+ "output": "matches-sift-lightglue",
+ "model": {
+ "name": "lightglue",
+ "match_threshold": 0.2,
+ "width_confidence": 0.99, # for point pruning
+ "depth_confidence": 0.95, # for early stopping,
+ "features": "sift",
+ "add_scale_ori": True,
+ "model_name": "sift_lightglue.pth",
+ },
+ "preprocessing": {
+ "grayscale": True,
+ "resize_max": 1024,
+ "dfactor": 8,
+ "force_resize": False,
+ },
+ },
+ "sgmnet": {
+ "output": "matches-sgmnet",
+ "model": {
+ "name": "sgmnet",
+ "seed_top_k": [256, 256],
+ "seed_radius_coe": 0.01,
+ "net_channels": 128,
+ "layer_num": 9,
+ "head": 4,
+ "seedlayer": [0, 6],
+ "use_mc_seeding": True,
+ "use_score_encoding": False,
+ "conf_bar": [1.11, 0.1],
+ "sink_iter": [10, 100],
+ "detach_iter": 1000000,
+ "match_threshold": 0.2,
+ },
+ "preprocessing": {
+ "grayscale": True,
+ "resize_max": 1024,
+ "dfactor": 8,
+ "force_resize": False,
+ },
+ },
+ "NN-superpoint": {
+ "output": "matches-NN-mutual-dist.7",
+ "model": {
+ "name": "nearest_neighbor",
+ "do_mutual_check": True,
+ "distance_threshold": 0.7,
+ "match_threshold": 0.2,
+ },
+ },
+ "NN-ratio": {
+ "output": "matches-NN-mutual-ratio.8",
+ "model": {
+ "name": "nearest_neighbor",
+ "do_mutual_check": True,
+ "ratio_threshold": 0.8,
+ "match_threshold": 0.2,
+ },
+ },
+ "NN-mutual": {
+ "output": "matches-NN-mutual",
+ "model": {
+ "name": "nearest_neighbor",
+ "do_mutual_check": True,
+ "match_threshold": 0.2,
+ },
+ },
+ "Dual-Softmax": {
+ "output": "matches-Dual-Softmax",
+ "model": {
+ "name": "dual_softmax",
+ "match_threshold": 0.01,
+ "inv_temperature": 20,
+ },
+ },
+ "adalam": {
+ "output": "matches-adalam",
+ "model": {
+ "name": "adalam",
+ "match_threshold": 0.2,
+ },
+ },
+ "imp": {
+ "output": "matches-imp",
+ "model": {
+ "name": "imp",
+ "match_threshold": 0.2,
+ },
+ },
+}
+
+
+class WorkQueue:
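+    """Apply work_fn to queued items in worker threads; None is the stop sentinel."""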
+ def __init__(self, work_fn, num_threads=1):
+ self.queue = Queue(num_threads)
+ self.threads = [
+ Thread(target=self.thread_fn, args=(work_fn,)) for _ in range(num_threads)
+ ]
+ for thread in self.threads:
+ thread.start()
+
+ def join(self):
+ for thread in self.threads:
+ self.queue.put(None)
+ for thread in self.threads:
+ thread.join()
+
+ def thread_fn(self, work_fn):
+ item = self.queue.get()
+ while item is not None:
+ work_fn(item)
+ item = self.queue.get()
+
+ def put(self, data):
+ self.queue.put(data)
+
+
+class FeaturePairsDataset(torch.utils.data.Dataset):
+ def __init__(self, pairs, feature_path_q, feature_path_r):
+ self.pairs = pairs
+ self.feature_path_q = feature_path_q
+ self.feature_path_r = feature_path_r
+
+ def __getitem__(self, idx):
+ name0, name1 = self.pairs[idx]
+ data = {}
+ with h5py.File(self.feature_path_q, "r") as fd:
+ grp = fd[name0]
+ for k, v in grp.items():
+ data[k + "0"] = torch.from_numpy(v.__array__()).float()
+ # some matchers might expect an image but only use its size
+ data["image0"] = torch.empty((1,) + tuple(grp["image_size"])[::-1])
+ with h5py.File(self.feature_path_r, "r") as fd:
+ grp = fd[name1]
+ for k, v in grp.items():
+ data[k + "1"] = torch.from_numpy(v.__array__()).float()
+ data["image1"] = torch.empty((1,) + tuple(grp["image_size"])[::-1])
+ return data
+
+ def __len__(self):
+ return len(self.pairs)
+
+
+def writer_fn(inp, match_path):
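+    # store matches as int16 and scores as float16 to keep the match file compact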
+ pair, pred = inp
+ with h5py.File(str(match_path), "a", libver="latest") as fd:
+ if pair in fd:
+ del fd[pair]
+ grp = fd.create_group(pair)
+ matches = pred["matches0"][0].cpu().short().numpy()
+ grp.create_dataset("matches0", data=matches)
+ if "matching_scores0" in pred:
+ scores = pred["matching_scores0"][0].cpu().half().numpy()
+ grp.create_dataset("matching_scores0", data=scores)
+
+
+def main(
+ conf: Dict,
+ pairs: Path,
+ features: Union[Path, str],
+ export_dir: Optional[Path] = None,
+ matches: Optional[Path] = None,
+ features_ref: Optional[Path] = None,
+ overwrite: bool = False,
+) -> Path:
+ if isinstance(features, Path) or Path(features).exists():
+ features_q = features
+ if matches is None:
+ raise ValueError(
+ "Either provide both features and matches as Path" " or both as names."
+ )
+ else:
+ if export_dir is None:
+ raise ValueError(
+ "Provide an export_dir if features is not" f" a file path: {features}."
+ )
+ features_q = Path(export_dir, features + ".h5")
+ if matches is None:
+ matches = Path(export_dir, f'{features}_{conf["output"]}_{pairs.stem}.h5')
+
+ if features_ref is None:
+ features_ref = features_q
+ match_from_paths(conf, pairs, matches, features_q, features_ref, overwrite)
+
+ return matches
+
+
+def find_unique_new_pairs(pairs_all: List[Tuple[str, str]], match_path: Path = None):
+    """Avoid recomputing duplicate pairs to save time."""
+ pairs = set()
+ for i, j in pairs_all:
+ if (j, i) not in pairs:
+ pairs.add((i, j))
+ pairs = list(pairs)
+ if match_path is not None and match_path.exists():
+ with h5py.File(str(match_path), "r", libver="latest") as fd:
+ pairs_filtered = []
+ for i, j in pairs:
+ if (
+ names_to_pair(i, j) in fd
+ or names_to_pair(j, i) in fd
+ or names_to_pair_old(i, j) in fd
+ or names_to_pair_old(j, i) in fd
+ ):
+ continue
+ pairs_filtered.append((i, j))
+ return pairs_filtered
+ return pairs
+
+
+@torch.no_grad()
+def match_from_paths(
+ conf: Dict,
+ pairs_path: Path,
+ match_path: Path,
+ feature_path_q: Path,
+ feature_path_ref: Path,
+ overwrite: bool = False,
+) -> Path:
+ logger.info(
+ "Matching local features with configuration:" f"\n{pprint.pformat(conf)}"
+ )
+
+ if not feature_path_q.exists():
+ raise FileNotFoundError(f"Query feature file {feature_path_q}.")
+ if not feature_path_ref.exists():
+ raise FileNotFoundError(f"Reference feature file {feature_path_ref}.")
+ match_path.parent.mkdir(exist_ok=True, parents=True)
+
+ assert pairs_path.exists(), pairs_path
+ pairs = parse_retrieval(pairs_path)
+ pairs = [(q, r) for q, rs in pairs.items() for r in rs]
+ pairs = find_unique_new_pairs(pairs, None if overwrite else match_path)
+ if len(pairs) == 0:
+ logger.info("Skipping the matching.")
+ return
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ Model = dynamic_load(matchers, conf["model"]["name"])
+ model = Model(conf["model"]).eval().to(device)
+
+ dataset = FeaturePairsDataset(pairs, feature_path_q, feature_path_ref)
+ loader = torch.utils.data.DataLoader(
+ dataset, num_workers=5, batch_size=1, shuffle=False, pin_memory=True
+ )
+ writer_queue = WorkQueue(partial(writer_fn, match_path=match_path), 5)
+
+ for idx, data in enumerate(tqdm(loader, smoothing=0.1)):
+ data = {
+ k: v if k.startswith("image") else v.to(device, non_blocking=True)
+ for k, v in data.items()
+ }
+ pred = model(data)
+ pair = names_to_pair(*pairs[idx])
+ writer_queue.put((pair, pred))
+ writer_queue.join()
+ logger.info("Finished exporting matches.")
+
+
+def scale_keypoints(kpts, scale):
+ if (
+ isinstance(scale, (list, tuple, np.ndarray))
+ and len(scale) == 2
+ and np.any(scale != np.array([1.0, 1.0]))
+ ):
+ if isinstance(kpts, torch.Tensor):
+ kpts[:, 0] *= scale[0] # scale x-dimension
+ kpts[:, 1] *= scale[1] # scale y-dimension
+ elif isinstance(kpts, np.ndarray):
+ kpts[:, 0] *= scale[0] # scale x-dimension
+ kpts[:, 1] *= scale[1] # scale y-dimension
+ return kpts
+
+
+@torch.no_grad()
+def match_images(model, feat0, feat1):
+ # forward pass to match keypoints
+ desc0 = feat0["descriptors"][0]
+ desc1 = feat1["descriptors"][0]
+ if len(desc0.shape) == 2:
+ desc0 = desc0.unsqueeze(0)
+ if len(desc1.shape) == 2:
+ desc1 = desc1.unsqueeze(0)
+ if isinstance(feat0["keypoints"], list):
+ feat0["keypoints"] = feat0["keypoints"][0][None]
+ if isinstance(feat1["keypoints"], list):
+ feat1["keypoints"] = feat1["keypoints"][0][None]
+ input_dict = {
+ "image0": feat0["image"],
+ "keypoints0": feat0["keypoints"],
+ "scores0": feat0["scores"][0].unsqueeze(0),
+ "descriptors0": desc0,
+ "image1": feat1["image"],
+ "keypoints1": feat1["keypoints"],
+ "scores1": feat1["scores"][0].unsqueeze(0),
+ "descriptors1": desc1,
+ }
+ if "scales" in feat0:
+ input_dict = {**input_dict, "scales0": feat0["scales"]}
+ if "scales" in feat1:
+ input_dict = {**input_dict, "scales1": feat1["scales"]}
+ if "oris" in feat0:
+ input_dict = {**input_dict, "oris0": feat0["oris"]}
+ if "oris" in feat1:
+ input_dict = {**input_dict, "oris1": feat1["oris"]}
+ pred = model(input_dict)
+ pred = {
+ k: v.cpu().detach()[0] if isinstance(v, torch.Tensor) else v
+ for k, v in pred.items()
+ }
+ kpts0, kpts1 = (
+ feat0["keypoints"][0].cpu().numpy(),
+ feat1["keypoints"][0].cpu().numpy(),
+ )
+ matches, confid = pred["matches0"], pred["matching_scores0"]
+ # Keep the matching keypoints.
+ valid = matches > -1
+ mkpts0 = kpts0[valid]
+ mkpts1 = kpts1[matches[valid]]
+ mconfid = confid[valid]
+ # rescale the keypoints to their original size
+ s0 = feat0["original_size"] / feat0["size"]
+ s1 = feat1["original_size"] / feat1["size"]
+ kpts0_origin = scale_keypoints(torch.from_numpy(kpts0 + 0.5), s0) - 0.5
+ kpts1_origin = scale_keypoints(torch.from_numpy(kpts1 + 0.5), s1) - 0.5
+
+ mkpts0_origin = scale_keypoints(torch.from_numpy(mkpts0 + 0.5), s0) - 0.5
+ mkpts1_origin = scale_keypoints(torch.from_numpy(mkpts1 + 0.5), s1) - 0.5
+
+ ret = {
+ "image0_orig": feat0["image_orig"],
+ "image1_orig": feat1["image_orig"],
+ "keypoints0": kpts0,
+ "keypoints1": kpts1,
+ "keypoints0_orig": kpts0_origin.numpy(),
+ "keypoints1_orig": kpts1_origin.numpy(),
+ "mkeypoints0": mkpts0,
+ "mkeypoints1": mkpts1,
+ "mkeypoints0_orig": mkpts0_origin.numpy(),
+ "mkeypoints1_orig": mkpts1_origin.numpy(),
+ "mconf": mconfid.numpy(),
+ }
+ del feat0, feat1, desc0, desc1, kpts0, kpts1, kpts0_origin, kpts1_origin
+ torch.cuda.empty_cache()
+
+ return ret
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--pairs", type=Path, required=True)
+ parser.add_argument("--export_dir", type=Path)
+ parser.add_argument("--features", type=str, default="feats-superpoint-n4096-r1024")
+ parser.add_argument("--matches", type=Path)
+ parser.add_argument(
+ "--conf", type=str, default="superglue", choices=list(confs.keys())
+ )
+ args = parser.parse_args()
+ main(confs[args.conf], args.pairs, args.features, args.export_dir)
diff --git a/imcui/hloc/matchers/__init__.py b/imcui/hloc/matchers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a9fd381eb391604db3d3c3c03d278c0e08de2531
--- /dev/null
+++ b/imcui/hloc/matchers/__init__.py
@@ -0,0 +1,3 @@
+def get_matcher(matcher):
+ mod = __import__(f"{__name__}.{matcher}", fromlist=[""])
+ return getattr(mod, "Model")
diff --git a/imcui/hloc/matchers/adalam.py b/imcui/hloc/matchers/adalam.py
new file mode 100644
index 0000000000000000000000000000000000000000..7820428a5a087d0b5d6855de15c0230327ce7dc1
--- /dev/null
+++ b/imcui/hloc/matchers/adalam.py
@@ -0,0 +1,68 @@
+import torch
+from kornia.feature.adalam import AdalamFilter
+from kornia.utils.helpers import get_cuda_device_if_available
+
+from ..utils.base_model import BaseModel
+
+
+class AdaLAM(BaseModel):
+ # See https://kornia.readthedocs.io/en/latest/_modules/kornia/feature/adalam/adalam.html.
+ default_conf = {
+ "area_ratio": 100,
+ "search_expansion": 4,
+ "ransac_iters": 128,
+ "min_inliers": 6,
+ "min_confidence": 200,
+ "orientation_difference_threshold": 30,
+ "scale_rate_threshold": 1.5,
+ "detected_scale_rate_threshold": 5,
+ "refit": True,
+ "force_seed_mnn": True,
+ "device": get_cuda_device_if_available(),
+ }
+ required_inputs = [
+ "image0",
+ "image1",
+ "descriptors0",
+ "descriptors1",
+ "keypoints0",
+ "keypoints1",
+ "scales0",
+ "scales1",
+ "oris0",
+ "oris1",
+ ]
+
+ def _init(self, conf):
+ self.adalam = AdalamFilter(conf)
+
+ def _forward(self, data):
+ assert data["keypoints0"].size(0) == 1
+ if data["keypoints0"].size(1) < 2 or data["keypoints1"].size(1) < 2:
+ matches = torch.zeros(
+ (0, 2), dtype=torch.int64, device=data["keypoints0"].device
+ )
+ else:
+ matches = self.adalam.match_and_filter(
+ data["keypoints0"][0],
+ data["keypoints1"][0],
+ data["descriptors0"][0].T,
+ data["descriptors1"][0].T,
+ data["image0"].shape[2:],
+ data["image1"].shape[2:],
+ data["oris0"][0],
+ data["oris1"][0],
+ data["scales0"][0],
+ data["scales1"][0],
+ )
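+        # scatter (idx0, idx1) pairs into a dense matches0 vector; -1 means unmatched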
+ matches_new = torch.full(
+ (data["keypoints0"].size(1),),
+ -1,
+ dtype=torch.int64,
+ device=data["keypoints0"].device,
+ )
+ matches_new[matches[:, 0]] = matches[:, 1]
+ return {
+ "matches0": matches_new.unsqueeze(0),
+ "matching_scores0": torch.zeros(matches_new.size(0)).unsqueeze(0),
+ }
diff --git a/imcui/hloc/matchers/aspanformer.py b/imcui/hloc/matchers/aspanformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..0636ff95a60edc992a8ded22590a7ab1baad4210
--- /dev/null
+++ b/imcui/hloc/matchers/aspanformer.py
@@ -0,0 +1,66 @@
+import sys
+from pathlib import Path
+
+import torch
+
+from .. import MODEL_REPO_ID, logger
+from ..utils.base_model import BaseModel
+
+sys.path.append(str(Path(__file__).parent / "../../third_party"))
+from ASpanFormer.src.ASpanFormer.aspanformer import ASpanFormer as _ASpanFormer
+from ASpanFormer.src.config.default import get_cfg_defaults
+from ASpanFormer.src.utils.misc import lower_config
+
+aspanformer_path = Path(__file__).parent / "../../third_party/ASpanFormer"
+
+
+class ASpanFormer(BaseModel):
+ default_conf = {
+ "model_name": "outdoor.ckpt",
+ "match_threshold": 0.2,
+ "sinkhorn_iterations": 20,
+ "max_keypoints": 2048,
+ "config_path": aspanformer_path / "configs/aspan/outdoor/aspan_test.py",
+ }
+ required_inputs = ["image0", "image1"]
+
+ def _init(self, conf):
+ config = get_cfg_defaults()
+ config.merge_from_file(conf["config_path"])
+ _config = lower_config(config)
+
+ # update: match threshold
+ _config["aspan"]["match_coarse"]["thr"] = conf["match_threshold"]
+ _config["aspan"]["match_coarse"]["skh_iters"] = conf["sinkhorn_iterations"]
+
+ self.net = _ASpanFormer(config=_config["aspan"])
+ model_path = self._download_model(
+ repo_id=MODEL_REPO_ID,
+ filename="{}/{}".format(Path(__file__).stem, self.conf["model_name"]),
+ )
+ state_dict = torch.load(str(model_path), map_location="cpu")["state_dict"]
+ self.net.load_state_dict(state_dict, strict=False)
+        logger.info("Loaded ASpanFormer model")
+
+ def _forward(self, data):
+ data_ = {
+ "image0": data["image0"],
+ "image1": data["image1"],
+ }
+ self.net(data_, online_resize=True)
+ pred = {
+ "keypoints0": data_["mkpts0_f"],
+ "keypoints1": data_["mkpts1_f"],
+ "mconf": data_["mconf"],
+ }
+ scores = data_["mconf"]
+ top_k = self.conf["max_keypoints"]
+ if top_k is not None and len(scores) > top_k:
+ keep = torch.argsort(scores, descending=True)[:top_k]
+ scores = scores[keep]
+ pred["keypoints0"], pred["keypoints1"], pred["mconf"] = (
+ pred["keypoints0"][keep],
+ pred["keypoints1"][keep],
+ scores,
+ )
+ return pred
diff --git a/imcui/hloc/matchers/cotr.py b/imcui/hloc/matchers/cotr.py
new file mode 100644
index 0000000000000000000000000000000000000000..5ec0234b2917ad0e3da9fbff76da9bcf83a19c04
--- /dev/null
+++ b/imcui/hloc/matchers/cotr.py
@@ -0,0 +1,77 @@
+import argparse
+import sys
+from pathlib import Path
+
+import numpy as np
+import torch
+from torchvision.transforms import ToPILImage
+
+from .. import DEVICE, MODEL_REPO_ID
+
+from ..utils.base_model import BaseModel
+
+sys.path.append(str(Path(__file__).parent / "../../third_party/COTR"))
+from COTR.inference.sparse_engine import SparseEngine
+from COTR.models import build_model
+from COTR.options.options import * # noqa: F403
+from COTR.options.options_utils import * # noqa: F403
+from COTR.utils import utils as utils_cotr
+
+utils_cotr.fix_randomness(0)
+torch.set_grad_enabled(False)
+
+
+class COTR(BaseModel):
+ default_conf = {
+ "weights": "out/default",
+ "match_threshold": 0.2,
+ "max_keypoints": -1,
+ "model_name": "checkpoint.pth.tar",
+ }
+ required_inputs = ["image0", "image1"]
+
+ def _init(self, conf):
+ parser = argparse.ArgumentParser()
+ set_COTR_arguments(parser) # noqa: F405
+ opt = parser.parse_args()
+ opt.command = " ".join(sys.argv)
+ opt.load_weights_path = self._download_model(
+ repo_id=MODEL_REPO_ID,
+ filename="{}/{}".format(Path(__file__).stem, self.conf["model_name"]),
+ )
+
+ layer_2_channels = {
+ "layer1": 256,
+ "layer2": 512,
+ "layer3": 1024,
+ "layer4": 2048,
+ }
+ opt.dim_feedforward = layer_2_channels[opt.layer]
+
+ model = build_model(opt)
+ model = model.to(DEVICE)
+ weights = torch.load(opt.load_weights_path, map_location="cpu")[
+ "model_state_dict"
+ ]
+ utils_cotr.safe_load_weights(model, weights)
+ self.net = model.eval()
+ self.to_pil_func = ToPILImage(mode="RGB")
+
+ def _forward(self, data):
+ img_a = np.array(self.to_pil_func(data["image0"][0].cpu()))
+ img_b = np.array(self.to_pil_func(data["image1"][0].cpu()))
+ corrs = SparseEngine(
+ self.net, 32, mode="tile"
+ ).cotr_corr_multiscale_with_cycle_consistency(
+ img_a,
+ img_b,
+ np.linspace(0.5, 0.0625, 4),
+ 1,
+ max_corrs=self.conf["max_keypoints"],
+ queries_a=None,
+ )
+ pred = {
+ "keypoints0": torch.from_numpy(corrs[:, :2]),
+ "keypoints1": torch.from_numpy(corrs[:, 2:]),
+ }
+ return pred
diff --git a/imcui/hloc/matchers/dkm.py b/imcui/hloc/matchers/dkm.py
new file mode 100644
index 0000000000000000000000000000000000000000..2deca95ca987dbd4d7e1fbb5c65e587222d3dd4c
--- /dev/null
+++ b/imcui/hloc/matchers/dkm.py
@@ -0,0 +1,53 @@
+import sys
+from pathlib import Path
+
+from PIL import Image
+
+from .. import DEVICE, MODEL_REPO_ID, logger
+from ..utils.base_model import BaseModel
+
+sys.path.append(str(Path(__file__).parent / "../../third_party"))
+from DKM.dkm import DKMv3_outdoor
+
+
+class DKMv3(BaseModel):
+ default_conf = {
+ "model_name": "DKMv3_outdoor.pth",
+ "match_threshold": 0.2,
+ "max_keypoints": -1,
+ }
+ required_inputs = [
+ "image0",
+ "image1",
+ ]
+
+ def _init(self, conf):
+ model_path = self._download_model(
+ repo_id=MODEL_REPO_ID,
+ filename="{}/{}".format(Path(__file__).stem, self.conf["model_name"]),
+ )
+
+ self.net = DKMv3_outdoor(path_to_weights=str(model_path), device=DEVICE)
+        logger.info("Loaded DKMv3 model")
+
+ def _forward(self, data):
+ img0 = data["image0"].cpu().numpy().squeeze() * 255
+ img1 = data["image1"].cpu().numpy().squeeze() * 255
+ img0 = img0.transpose(1, 2, 0)
+ img1 = img1.transpose(1, 2, 0)
+ img0 = Image.fromarray(img0.astype("uint8"))
+ img1 = Image.fromarray(img1.astype("uint8"))
+ W_A, H_A = img0.size
+ W_B, H_B = img1.size
+
+ warp, certainty = self.net.match(img0, img1, device=DEVICE)
+ matches, certainty = self.net.sample(
+ warp, certainty, num=self.conf["max_keypoints"]
+ )
+ kpts1, kpts2 = self.net.to_pixel_coordinates(matches, H_A, W_A, H_B, W_B)
+ pred = {
+ "keypoints0": kpts1,
+ "keypoints1": kpts2,
+ "mconf": certainty,
+ }
+ return pred
diff --git a/imcui/hloc/matchers/dual_softmax.py b/imcui/hloc/matchers/dual_softmax.py
new file mode 100644
index 0000000000000000000000000000000000000000..1cef54473a0483ce2bca3b158c507c9e9480b641
--- /dev/null
+++ b/imcui/hloc/matchers/dual_softmax.py
@@ -0,0 +1,71 @@
+import numpy as np
+import torch
+
+from ..utils.base_model import BaseModel
+
+
+# borrowed from DeDoDe
+def dual_softmax_matcher(
+ desc_A: tuple["B", "C", "N"], # noqa: F821
+ desc_B: tuple["B", "C", "M"], # noqa: F821
+ threshold=0.1,
+ inv_temperature=20,
+ normalize=True,
+):
+    if len(desc_A.shape) < 3:
+        desc_A, desc_B = desc_A[None], desc_B[None]
+    B, C, N = desc_A.shape
+ if normalize:
+ desc_A = desc_A / desc_A.norm(dim=1, keepdim=True)
+ desc_B = desc_B / desc_B.norm(dim=1, keepdim=True)
+ sim = torch.einsum("b c n, b c m -> b n m", desc_A, desc_B) * inv_temperature
+ P = sim.softmax(dim=-2) * sim.softmax(dim=-1)
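+    # keep entries that are row- and column-wise maxima (mutual best) above the threshold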
+ mask = torch.nonzero(
+ (P == P.max(dim=-1, keepdim=True).values)
+ * (P == P.max(dim=-2, keepdim=True).values)
+ * (P > threshold)
+ )
+ mask = mask.cpu().numpy()
+ matches0 = np.ones((B, P.shape[-2]), dtype=int) * (-1)
+ scores0 = np.zeros((B, P.shape[-2]), dtype=float)
+ matches0[:, mask[:, 1]] = mask[:, 2]
+ tmp_P = P.cpu().numpy()
+ scores0[:, mask[:, 1]] = tmp_P[mask[:, 0], mask[:, 1], mask[:, 2]]
+ matches0 = torch.from_numpy(matches0).to(P.device)
+ scores0 = torch.from_numpy(scores0).to(P.device)
+ return matches0, scores0
+
+
+class DualSoftMax(BaseModel):
+ default_conf = {
+ "match_threshold": 0.2,
+ "inv_temperature": 20,
+ }
+ # shape: B x DIM x M
+ required_inputs = ["descriptors0", "descriptors1"]
+
+ def _init(self, conf):
+ pass
+
+ def _forward(self, data):
+ if data["descriptors0"].size(-1) == 0 or data["descriptors1"].size(-1) == 0:
+ matches0 = torch.full(
+ data["descriptors0"].shape[:2],
+ -1,
+ device=data["descriptors0"].device,
+ )
+ return {
+ "matches0": matches0,
+ "matching_scores0": torch.zeros_like(matches0),
+ }
+
+ matches0, scores0 = dual_softmax_matcher(
+ data["descriptors0"],
+ data["descriptors1"],
+ threshold=self.conf["match_threshold"],
+ inv_temperature=self.conf["inv_temperature"],
+ )
+ return {
+ "matches0": matches0, # 1 x M
+ "matching_scores0": scores0,
+ }
diff --git a/imcui/hloc/matchers/duster.py b/imcui/hloc/matchers/duster.py
new file mode 100644
index 0000000000000000000000000000000000000000..36fa34bc6d433295800e0223db9dc97fec93f9f9
--- /dev/null
+++ b/imcui/hloc/matchers/duster.py
@@ -0,0 +1,109 @@
+import sys
+from pathlib import Path
+
+import numpy as np
+import torch
+import torchvision.transforms as tfm
+
+from .. import MODEL_REPO_ID, logger
+from ..utils.base_model import BaseModel
+
+duster_path = Path(__file__).parent / "../../third_party/dust3r"
+sys.path.append(str(duster_path))
+
+from dust3r.cloud_opt import GlobalAlignerMode, global_aligner
+from dust3r.image_pairs import make_pairs
+from dust3r.inference import inference
+from dust3r.model import AsymmetricCroCo3DStereo
+from dust3r.utils.geometry import find_reciprocal_matches, xy_grid
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+
+class Duster(BaseModel):
+ default_conf = {
+ "name": "Duster3r",
+ "model_name": "duster_vit_large.pth",
+ "max_keypoints": 3000,
+ "vit_patch_size": 16,
+ }
+
+ def _init(self, conf):
+ self.normalize = tfm.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
+ model_path = self._download_model(
+ repo_id=MODEL_REPO_ID,
+ filename="{}/{}".format(Path(__file__).stem, self.conf["model_name"]),
+ )
+ self.net = AsymmetricCroCo3DStereo.from_pretrained(model_path).to(device)
+ logger.info("Loaded Dust3r model")
+
+ def preprocess(self, img):
+ # the super-class already makes sure that img0,img1 have
+ # same resolution and that h == w
+ _, h, _ = img.shape
+ imsize = h
+ if not ((h % self.vit_patch_size) == 0):
+ imsize = int(self.vit_patch_size * round(h / self.vit_patch_size, 0))
+ img = tfm.functional.resize(img, imsize, antialias=True)
+
+ _, new_h, new_w = img.shape
+ if not ((new_w % self.vit_patch_size) == 0):
+ safe_w = int(self.vit_patch_size * round(new_w / self.vit_patch_size, 0))
+ img = tfm.functional.resize(img, (new_h, safe_w), antialias=True)
+
+ img = self.normalize(img).unsqueeze(0)
+
+ return img
+
+ def _forward(self, data):
+ img0, img1 = data["image0"], data["image1"]
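+        # normalize [0, 1] images to [-1, 1] before feeding them to the network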
+ mean = torch.tensor([0.5, 0.5, 0.5]).to(device)
+ std = torch.tensor([0.5, 0.5, 0.5]).to(device)
+
+ img0 = (img0 - mean.view(1, 3, 1, 1)) / std.view(1, 3, 1, 1)
+ img1 = (img1 - mean.view(1, 3, 1, 1)) / std.view(1, 3, 1, 1)
+
+ images = [
+ {"img": img0, "idx": 0, "instance": 0},
+ {"img": img1, "idx": 1, "instance": 1},
+ ]
+ pairs = make_pairs(
+ images, scene_graph="complete", prefilter=None, symmetrize=True
+ )
+ output = inference(pairs, self.net, device, batch_size=1)
+ scene = global_aligner(output, device=device, mode=GlobalAlignerMode.PairViewer)
+ # retrieve useful values from scene:
+ imgs = scene.imgs
+ confidence_masks = scene.get_masks()
+ pts3d = scene.get_pts3d()
+ pts2d_list, pts3d_list = [], []
+ for i in range(2):
+ conf_i = confidence_masks[i].cpu().numpy()
+ pts2d_list.append(
+ xy_grid(*imgs[i].shape[:2][::-1])[conf_i]
+ ) # imgs[i].shape[:2] = (H, W)
+ pts3d_list.append(pts3d[i].detach().cpu().numpy()[conf_i])
+
+ if len(pts3d_list[1]) == 0:
+ pred = {
+ "keypoints0": torch.zeros([0, 2]),
+ "keypoints1": torch.zeros([0, 2]),
+ }
+            logger.warning("Matched 0 points")
+ else:
+ reciprocal_in_P2, nn2_in_P1, num_matches = find_reciprocal_matches(
+ *pts3d_list
+ )
+ logger.info(f"Found {num_matches} matches")
+ mkpts1 = pts2d_list[1][reciprocal_in_P2]
+ mkpts0 = pts2d_list[0][nn2_in_P1][reciprocal_in_P2]
+ top_k = self.conf["max_keypoints"]
+ if top_k is not None and len(mkpts0) > top_k:
+ keep = np.round(np.linspace(0, len(mkpts0) - 1, top_k)).astype(int)
+ mkpts0 = mkpts0[keep]
+ mkpts1 = mkpts1[keep]
+ pred = {
+ "keypoints0": torch.from_numpy(mkpts0),
+ "keypoints1": torch.from_numpy(mkpts1),
+ }
+ return pred
diff --git a/imcui/hloc/matchers/eloftr.py b/imcui/hloc/matchers/eloftr.py
new file mode 100644
index 0000000000000000000000000000000000000000..7ca352808e7b5a2a8bc7253be2d591c439798491
--- /dev/null
+++ b/imcui/hloc/matchers/eloftr.py
@@ -0,0 +1,97 @@
+import sys
+import warnings
+from copy import deepcopy
+from pathlib import Path
+
+import torch
+
+from .. import MODEL_REPO_ID, logger
+
+tp_path = Path(__file__).parent / "../../third_party"
+sys.path.append(str(tp_path))
+
+from EfficientLoFTR.src.loftr import LoFTR as ELoFTR_
+from EfficientLoFTR.src.loftr import (
+ full_default_cfg,
+ opt_default_cfg,
+ reparameter,
+)
+
+
+from ..utils.base_model import BaseModel
+
+
+class ELoFTR(BaseModel):
+ default_conf = {
+ "model_name": "eloftr_outdoor.ckpt",
+ "match_threshold": 0.2,
+ # "sinkhorn_iterations": 20,
+ "max_keypoints": -1,
+ # You can choose model type in ['full', 'opt']
+ "model_type": "full", # 'full' for best quality, 'opt' for best efficiency
+ # You can choose numerical precision in ['fp32', 'mp', 'fp16']. 'fp16' for best efficiency
+ "precision": "fp32",
+ }
+ required_inputs = ["image0", "image1"]
+
+ def _init(self, conf):
+ if self.conf["model_type"] == "full":
+ _default_cfg = deepcopy(full_default_cfg)
+ elif self.conf["model_type"] == "opt":
+ _default_cfg = deepcopy(opt_default_cfg)
+
+ if self.conf["precision"] == "mp":
+ _default_cfg["mp"] = True
+ elif self.conf["precision"] == "fp16":
+ _default_cfg["half"] = True
+
+ model_path = self._download_model(
+ repo_id=MODEL_REPO_ID,
+ filename="{}/{}".format(Path(__file__).stem, self.conf["model_name"]),
+ )
+
+ cfg = _default_cfg
+ cfg["match_coarse"]["thr"] = conf["match_threshold"]
+ # cfg["match_coarse"]["skh_iters"] = conf["sinkhorn_iterations"]
+ state_dict = torch.load(model_path, map_location="cpu")["state_dict"]
+ matcher = ELoFTR_(config=cfg)
+ matcher.load_state_dict(state_dict)
+ self.net = reparameter(matcher)
+
+ if self.conf["precision"] == "fp16":
+ self.net = self.net.half()
+ logger.info(f"Loaded Efficient LoFTR with weights {conf['model_name']}")
+
+ def _forward(self, data):
+ # For consistency with hloc pairs, we refine kpts in image0!
+ rename = {
+ "keypoints0": "keypoints1",
+ "keypoints1": "keypoints0",
+ "image0": "image1",
+ "image1": "image0",
+ "mask0": "mask1",
+ "mask1": "mask0",
+ }
+ data_ = {rename[k]: v for k, v in data.items()}
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ pred = self.net(data_)
+ pred = {
+ "keypoints0": data_["mkpts0_f"],
+ "keypoints1": data_["mkpts1_f"],
+ }
+ scores = data_["mconf"]
+
+ top_k = self.conf["max_keypoints"]
+ if top_k is not None and len(scores) > top_k:
+ keep = torch.argsort(scores, descending=True)[:top_k]
+ pred["keypoints0"], pred["keypoints1"] = (
+ pred["keypoints0"][keep],
+ pred["keypoints1"][keep],
+ )
+ scores = scores[keep]
+
+ # Switch back indices
+ pred = {(rename[k] if k in rename else k): v for k, v in pred.items()}
+ pred["scores"] = scores
+ return pred
diff --git a/imcui/hloc/matchers/gim.py b/imcui/hloc/matchers/gim.py
new file mode 100644
index 0000000000000000000000000000000000000000..afaf78c2ac832c47d8d0f7210e8188e8a0aa9899
--- /dev/null
+++ b/imcui/hloc/matchers/gim.py
@@ -0,0 +1,200 @@
+import sys
+from pathlib import Path
+
+import torch
+
+from .. import DEVICE, MODEL_REPO_ID, logger
+from ..utils.base_model import BaseModel
+
+gim_path = Path(__file__).parents[2] / "third_party/gim"
+sys.path.append(str(gim_path))
+
+
+def load_model(weight_name, checkpoints_path):
+ # load model
+ model = None
+ detector = None
+ if weight_name == "gim_dkm":
+ from networks.dkm.models.model_zoo.DKMv3 import DKMv3
+
+ model = DKMv3(weights=None, h=672, w=896)
+ elif weight_name == "gim_loftr":
+ from networks.loftr.config import get_cfg_defaults
+ from networks.loftr.loftr import LoFTR
+ from networks.loftr.misc import lower_config
+
+ model = LoFTR(lower_config(get_cfg_defaults())["loftr"])
+ elif weight_name == "gim_lightglue":
+ from networks.lightglue.models.matchers.lightglue import LightGlue
+ from networks.lightglue.superpoint import SuperPoint
+
+ detector = SuperPoint(
+ {
+ "max_num_keypoints": 2048,
+ "force_num_keypoints": True,
+ "detection_threshold": 0.0,
+ "nms_radius": 3,
+ "trainable": False,
+ }
+ )
+ model = LightGlue(
+ {
+ "filter_threshold": 0.1,
+ "flash": False,
+ "checkpointed": True,
+ }
+ )
+
+ # load state dict
+ if weight_name == "gim_dkm":
+ state_dict = torch.load(checkpoints_path, map_location="cpu")
+ if "state_dict" in state_dict.keys():
+ state_dict = state_dict["state_dict"]
+ for k in list(state_dict.keys()):
+ if k.startswith("model."):
+ state_dict[k.replace("model.", "", 1)] = state_dict.pop(k)
+ if "encoder.net.fc" in k:
+ state_dict.pop(k)
+ model.load_state_dict(state_dict)
+
+ elif weight_name == "gim_loftr":
+ state_dict = torch.load(checkpoints_path, map_location="cpu")
+ if "state_dict" in state_dict.keys():
+ state_dict = state_dict["state_dict"]
+ model.load_state_dict(state_dict)
+
+ elif weight_name == "gim_lightglue":
+ state_dict = torch.load(checkpoints_path, map_location="cpu")
+ if "state_dict" in state_dict.keys():
+ state_dict = state_dict["state_dict"]
+ for k in list(state_dict.keys()):
+ if k.startswith("model."):
+ state_dict.pop(k)
+ if k.startswith("superpoint."):
+ state_dict[k.replace("superpoint.", "", 1)] = state_dict.pop(k)
+ detector.load_state_dict(state_dict)
+
+ state_dict = torch.load(checkpoints_path, map_location="cpu")
+ if "state_dict" in state_dict.keys():
+ state_dict = state_dict["state_dict"]
+ for k in list(state_dict.keys()):
+ if k.startswith("superpoint."):
+ state_dict.pop(k)
+ if k.startswith("model."):
+ state_dict[k.replace("model.", "", 1)] = state_dict.pop(k)
+ model.load_state_dict(state_dict)
+
+ # eval mode
+ if detector is not None:
+ detector = detector.eval().to(DEVICE)
+ model = model.eval().to(DEVICE)
+ return model
+
+
+class GIM(BaseModel):
+ default_conf = {
+ "match_threshold": 0.2,
+ "checkpoint_dir": gim_path / "weights",
+ "weights": "gim_dkm",
+ }
+ required_inputs = [
+ "image0",
+ "image1",
+ ]
+ ckpt_name_dict = {
+ "gim_dkm": "gim_dkm_100h.ckpt",
+ "gim_loftr": "gim_loftr_50h.ckpt",
+ "gim_lightglue": "gim_lightglue_100h.ckpt",
+ }
+
+ def _init(self, conf):
+ ckpt_name = self.ckpt_name_dict[conf["weights"]]
+ model_path = self._download_model(
+ repo_id=MODEL_REPO_ID,
+ filename="{}/{}".format(Path(__file__).stem, ckpt_name),
+ )
+ self.aspect_ratio = 896 / 672
+ model = load_model(conf["weights"], model_path)
+ self.net = model
+ logger.info("Loaded GIM model")
+
+ def pad_image(self, image, aspect_ratio):
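+        # zero-pad the image symmetrically so its width / height ratio matches the requested aspect_ratio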
+ new_width = max(image.shape[3], int(image.shape[2] * aspect_ratio))
+ new_height = max(image.shape[2], int(image.shape[3] / aspect_ratio))
+ pad_width = new_width - image.shape[3]
+ pad_height = new_height - image.shape[2]
+ return torch.nn.functional.pad(
+ image,
+ (
+ pad_width // 2,
+ pad_width - pad_width // 2,
+ pad_height // 2,
+ pad_height - pad_height // 2,
+ ),
+ )
+
+ def rescale_kpts(self, sparse_matches, shape0, shape1):
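+        # sparse matches are in normalized [-1, 1] coordinates; convert them back to pixel coordinates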
+ kpts0 = torch.stack(
+ (
+ shape0[1] * (sparse_matches[:, 0] + 1) / 2,
+ shape0[0] * (sparse_matches[:, 1] + 1) / 2,
+ ),
+ dim=-1,
+ )
+ kpts1 = torch.stack(
+ (
+ shape1[1] * (sparse_matches[:, 2] + 1) / 2,
+ shape1[0] * (sparse_matches[:, 3] + 1) / 2,
+ ),
+ dim=-1,
+ )
+ return kpts0, kpts1
+
+ def compute_mask(self, kpts0, kpts1, orig_shape0, orig_shape1):
+ mask = (
+ (kpts0[:, 0] > 0)
+ & (kpts0[:, 1] > 0)
+ & (kpts1[:, 0] > 0)
+ & (kpts1[:, 1] > 0)
+ )
+ mask &= (
+ (kpts0[:, 0] <= (orig_shape0[1] - 1))
+ & (kpts1[:, 0] <= (orig_shape1[1] - 1))
+ & (kpts0[:, 1] <= (orig_shape0[0] - 1))
+ & (kpts1[:, 1] <= (orig_shape1[0] - 1))
+ )
+ return mask
+
+ def _forward(self, data):
+        # TODO: only the gim_dkm weights are currently supported
+ image0, image1 = (
+ self.pad_image(data["image0"], self.aspect_ratio),
+ self.pad_image(data["image1"], self.aspect_ratio),
+ )
+ dense_matches, dense_certainty = self.net.match(image0, image1)
+ sparse_matches, mconf = self.net.sample(
+ dense_matches, dense_certainty, self.conf["max_keypoints"]
+ )
+ kpts0, kpts1 = self.rescale_kpts(
+ sparse_matches, image0.shape[-2:], image1.shape[-2:]
+ )
+ mask = self.compute_mask(
+ kpts0, kpts1, data["image0"].shape[-2:], data["image1"].shape[-2:]
+ )
+ b_ids, i_ids = torch.where(mconf[None])
+ pred = {
+ "keypoints0": kpts0[i_ids],
+ "keypoints1": kpts1[i_ids],
+ "confidence": mconf[i_ids],
+ "batch_indexes": b_ids,
+ }
+ scores, b_ids = pred["confidence"], pred["batch_indexes"]
+ kpts0, kpts1 = pred["keypoints0"], pred["keypoints1"]
+ pred["confidence"], pred["batch_indexes"] = scores[mask], b_ids[mask]
+ pred["keypoints0"], pred["keypoints1"] = kpts0[mask], kpts1[mask]
+
+ out = {
+ "keypoints0": pred["keypoints0"],
+ "keypoints1": pred["keypoints1"],
+ }
+ return out
diff --git a/imcui/hloc/matchers/gluestick.py b/imcui/hloc/matchers/gluestick.py
new file mode 100644
index 0000000000000000000000000000000000000000..9f775325fde3e39570ab93a7071455a5a2661dda
--- /dev/null
+++ b/imcui/hloc/matchers/gluestick.py
@@ -0,0 +1,99 @@
+import sys
+from pathlib import Path
+
+import torch
+
+from .. import MODEL_REPO_ID, logger
+from ..utils.base_model import BaseModel
+
+gluestick_path = Path(__file__).parent / "../../third_party/GlueStick"
+sys.path.append(str(gluestick_path))
+
+from gluestick import batch_to_np
+from gluestick.models.two_view_pipeline import TwoViewPipeline
+
+
+class GlueStick(BaseModel):
+ default_conf = {
+ "name": "two_view_pipeline",
+ "model_name": "checkpoint_GlueStick_MD.tar",
+ "use_lines": True,
+ "max_keypoints": 1000,
+ "max_lines": 300,
+ "force_num_keypoints": False,
+ }
+ required_inputs = [
+ "image0",
+ "image1",
+ ]
+
+ # Initialize the line matcher
+ def _init(self, conf):
+ # Download the model.
+ model_path = self._download_model(
+ repo_id=MODEL_REPO_ID,
+ filename="{}/{}".format(Path(__file__).stem, self.conf["model_name"]),
+ )
+ logger.info("Loading GlueStick model...")
+
+ gluestick_conf = {
+ "name": "two_view_pipeline",
+ "use_lines": True,
+ "extractor": {
+ "name": "wireframe",
+ "sp_params": {
+ "force_num_keypoints": False,
+ "max_num_keypoints": 1000,
+ },
+ "wireframe_params": {
+ "merge_points": True,
+ "merge_line_endpoints": True,
+ },
+ "max_n_lines": 300,
+ },
+ "matcher": {
+ "name": "gluestick",
+ "weights": str(model_path),
+ "trainable": False,
+ },
+ "ground_truth": {
+ "from_pose_depth": False,
+ },
+ }
+ gluestick_conf["extractor"]["sp_params"]["max_num_keypoints"] = conf[
+ "max_keypoints"
+ ]
+ gluestick_conf["extractor"]["sp_params"]["force_num_keypoints"] = conf[
+ "force_num_keypoints"
+ ]
+ gluestick_conf["extractor"]["max_n_lines"] = conf["max_lines"]
+ self.net = TwoViewPipeline(gluestick_conf)
+
+ def _forward(self, data):
+ pred = self.net(data)
+
+ pred = batch_to_np(pred)
+ kp0, kp1 = pred["keypoints0"], pred["keypoints1"]
+ m0 = pred["matches0"]
+
+ line_seg0, line_seg1 = pred["lines0"], pred["lines1"]
+ line_matches = pred["line_matches0"]
+
+ valid_matches = m0 != -1
+ match_indices = m0[valid_matches]
+ matched_kps0 = kp0[valid_matches]
+ matched_kps1 = kp1[match_indices]
+
+ valid_matches = line_matches != -1
+ match_indices = line_matches[valid_matches]
+ matched_lines0 = line_seg0[valid_matches]
+ matched_lines1 = line_seg1[match_indices]
+
+ pred["raw_lines0"], pred["raw_lines1"] = line_seg0, line_seg1
+ pred["lines0"], pred["lines1"] = matched_lines0, matched_lines1
+ pred["keypoints0"], pred["keypoints1"] = (
+ torch.from_numpy(matched_kps0),
+ torch.from_numpy(matched_kps1),
+ )
+ pred = {**pred, **data}
+ return pred
diff --git a/imcui/hloc/matchers/imp.py b/imcui/hloc/matchers/imp.py
new file mode 100644
index 0000000000000000000000000000000000000000..f37d218e7a8e46ea31b834307f48e4e7649f87a0
--- /dev/null
+++ b/imcui/hloc/matchers/imp.py
@@ -0,0 +1,50 @@
+import sys
+from pathlib import Path
+
+import torch
+
+from .. import DEVICE, MODEL_REPO_ID, logger
+from ..utils.base_model import BaseModel
+
+tp_path = Path(__file__).parent / "../../third_party"
+sys.path.append(str(tp_path))
+from pram.nets.gml import GML
+
+
+class IMP(BaseModel):
+ default_conf = {
+ "match_threshold": 0.2,
+ "features": "sfd2",
+ "model_name": "imp_gml.920.pth",
+ "sinkhorn_iterations": 20,
+ }
+ required_inputs = [
+ "image0",
+ "keypoints0",
+ "scores0",
+ "descriptors0",
+ "image1",
+ "keypoints1",
+ "scores1",
+ "descriptors1",
+ ]
+
+ def _init(self, conf):
+ self.conf = {**self.default_conf, **conf}
+ model_path = self._download_model(
+ repo_id=MODEL_REPO_ID,
+ filename="{}/{}".format("pram", self.conf["model_name"]),
+ )
+
+ # self.net = nets.gml(self.conf).eval().to(DEVICE)
+ self.net = GML(self.conf).eval().to(DEVICE)
+ self.net.load_state_dict(
+ torch.load(model_path, map_location="cpu")["model"], strict=True
+ )
+        logger.info("Loaded IMP model.")
+
+ def _forward(self, data):
+ data["descriptors0"] = data["descriptors0"].transpose(2, 1).float()
+ data["descriptors1"] = data["descriptors1"].transpose(2, 1).float()
+
+ return self.net.produce_matches(data, p=0.2)
diff --git a/imcui/hloc/matchers/lightglue.py b/imcui/hloc/matchers/lightglue.py
new file mode 100644
index 0000000000000000000000000000000000000000..39bd3693813d70545bbfbfc24c4b578e10092759
--- /dev/null
+++ b/imcui/hloc/matchers/lightglue.py
@@ -0,0 +1,67 @@
+import sys
+from pathlib import Path
+
+from .. import MODEL_REPO_ID, logger
+from ..utils.base_model import BaseModel
+
+lightglue_path = Path(__file__).parent / "../../third_party/LightGlue"
+sys.path.append(str(lightglue_path))
+from lightglue import LightGlue as LG
+
+
+class LightGlue(BaseModel):
+ default_conf = {
+ "match_threshold": 0.2,
+ "filter_threshold": 0.2,
+ "width_confidence": 0.99, # for point pruning
+ "depth_confidence": 0.95, # for early stopping,
+ "features": "superpoint",
+ "model_name": "superpoint_lightglue.pth",
+ "flash": True, # enable FlashAttention if available.
+ "mp": False, # enable mixed precision
+ "add_scale_ori": False,
+ }
+ required_inputs = [
+ "image0",
+ "keypoints0",
+ "scores0",
+ "descriptors0",
+ "image1",
+ "keypoints1",
+ "scores1",
+ "descriptors1",
+ ]
+
+ def _init(self, conf):
+ logger.info("Loading lightglue model, {}".format(conf["model_name"]))
+ model_path = self._download_model(
+ repo_id=MODEL_REPO_ID,
+ filename="{}/{}".format(Path(__file__).stem, self.conf["model_name"]),
+ )
+ conf["weights"] = str(model_path)
+ conf["filter_threshold"] = conf["match_threshold"]
+ self.net = LG(**conf)
+        logger.info("Loaded LightGlue model.")
+
+ def _forward(self, data):
+ input = {}
+ input["image0"] = {
+ "image": data["image0"],
+ "keypoints": data["keypoints0"],
+ "descriptors": data["descriptors0"].permute(0, 2, 1),
+ }
+ if "scales0" in data:
+ input["image0"] = {**input["image0"], "scales": data["scales0"]}
+ if "oris0" in data:
+ input["image0"] = {**input["image0"], "oris": data["oris0"]}
+
+ input["image1"] = {
+ "image": data["image1"],
+ "keypoints": data["keypoints1"],
+ "descriptors": data["descriptors1"].permute(0, 2, 1),
+ }
+ if "scales1" in data:
+ input["image1"] = {**input["image1"], "scales": data["scales1"]}
+ if "oris1" in data:
+ input["image1"] = {**input["image1"], "oris": data["oris1"]}
+ return self.net(input)
diff --git a/imcui/hloc/matchers/loftr.py b/imcui/hloc/matchers/loftr.py
new file mode 100644
index 0000000000000000000000000000000000000000..ce7fe28d87337905069c675452bf0bf2522068c7
--- /dev/null
+++ b/imcui/hloc/matchers/loftr.py
@@ -0,0 +1,71 @@
+import warnings
+
+import torch
+from kornia.feature import LoFTR as LoFTR_
+from kornia.feature.loftr.loftr import default_cfg
+from pathlib import Path
+from .. import logger, MODEL_REPO_ID
+
+from ..utils.base_model import BaseModel
+
+
+class LoFTR(BaseModel):
+ default_conf = {
+ "weights": "outdoor",
+ "match_threshold": 0.2,
+ "sinkhorn_iterations": 20,
+ "max_keypoints": -1,
+ }
+ required_inputs = ["image0", "image1"]
+
+ def _init(self, conf):
+ cfg = default_cfg
+ cfg["match_coarse"]["thr"] = conf["match_threshold"]
+ cfg["match_coarse"]["skh_iters"] = conf["sinkhorn_iterations"]
+
+ model_name = conf.get("model_name", None)
+ if model_name is not None and "minima" in model_name:
+ cfg["coarse"]["temp_bug_fix"] = True
+ model_path = self._download_model(
+ repo_id=MODEL_REPO_ID,
+ filename="{}/{}".format(Path(__file__).stem, self.conf["model_name"]),
+ )
+ state_dict = torch.load(model_path, map_location="cpu")["state_dict"]
+ self.net = LoFTR_(pretrained=conf["weights"], config=cfg)
+ self.net.load_state_dict(state_dict)
+            logger.info(f"Loaded LoFTR (MINIMA) with weights {conf['model_name']}")
+ else:
+ self.net = LoFTR_(pretrained=conf["weights"], config=cfg)
+ logger.info(f"Loaded LoFTR with weights {conf['weights']}")
+
+ def _forward(self, data):
+ # For consistency with hloc pairs, we refine kpts in image0!
+ rename = {
+ "keypoints0": "keypoints1",
+ "keypoints1": "keypoints0",
+ "image0": "image1",
+ "image1": "image0",
+ "mask0": "mask1",
+ "mask1": "mask0",
+ }
+ data_ = {rename[k]: v for k, v in data.items()}
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ pred = self.net(data_)
+
+ scores = pred["confidence"]
+
+ top_k = self.conf["max_keypoints"]
+ if top_k is not None and len(scores) > top_k:
+ keep = torch.argsort(scores, descending=True)[:top_k]
+ pred["keypoints0"], pred["keypoints1"] = (
+ pred["keypoints0"][keep],
+ pred["keypoints1"][keep],
+ )
+ scores = scores[keep]
+
+ # Switch back indices
+ pred = {(rename[k] if k in rename else k): v for k, v in pred.items()}
+ pred["scores"] = scores
+ del pred["confidence"]
+ return pred
diff --git a/imcui/hloc/matchers/mast3r.py b/imcui/hloc/matchers/mast3r.py
new file mode 100644
index 0000000000000000000000000000000000000000..47a5a3ffdd6855332f61012ef97037a9f6fe469e
--- /dev/null
+++ b/imcui/hloc/matchers/mast3r.py
@@ -0,0 +1,96 @@
+import sys
+from pathlib import Path
+
+import numpy as np
+import torch
+import torchvision.transforms as tfm
+
+from .. import DEVICE, MODEL_REPO_ID, logger
+
+mast3r_path = Path(__file__).parent / "../../third_party/mast3r"
+sys.path.append(str(mast3r_path))
+
+dust3r_path = Path(__file__).parent / "../../third_party/dust3r"
+sys.path.append(str(dust3r_path))
+
+from dust3r.image_pairs import make_pairs
+from dust3r.inference import inference
+from mast3r.fast_nn import fast_reciprocal_NNs
+from mast3r.model import AsymmetricMASt3R
+
+from .duster import Duster
+
+
+class Mast3r(Duster):
+ default_conf = {
+ "name": "Mast3r",
+ "model_name": "MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric.pth",
+ "max_keypoints": 2000,
+ "vit_patch_size": 16,
+ }
+
+ def _init(self, conf):
+ self.normalize = tfm.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
+ model_path = self._download_model(
+ repo_id=MODEL_REPO_ID,
+ filename="{}/{}".format(Path(__file__).stem, self.conf["model_name"]),
+ )
+ self.net = AsymmetricMASt3R.from_pretrained(model_path).to(DEVICE)
+ logger.info("Loaded Mast3r model")
+
+ def _forward(self, data):
+ img0, img1 = data["image0"], data["image1"]
+ mean = torch.tensor([0.5, 0.5, 0.5]).to(DEVICE)
+ std = torch.tensor([0.5, 0.5, 0.5]).to(DEVICE)
+
+ img0 = (img0 - mean.view(1, 3, 1, 1)) / std.view(1, 3, 1, 1)
+ img1 = (img1 - mean.view(1, 3, 1, 1)) / std.view(1, 3, 1, 1)
+
+ images = [
+ {"img": img0, "idx": 0, "instance": 0},
+ {"img": img1, "idx": 1, "instance": 1},
+ ]
+ pairs = make_pairs(
+ images, scene_graph="complete", prefilter=None, symmetrize=True
+ )
+ output = inference(pairs, self.net, DEVICE, batch_size=1)
+
+ # at this stage, you have the raw dust3r predictions
+ _, pred1 = output["view1"], output["pred1"]
+ _, pred2 = output["view2"], output["pred2"]
+
+ desc1, desc2 = (
+ pred1["desc"][1].squeeze(0).detach(),
+ pred2["desc"][1].squeeze(0).detach(),
+ )
+
+ # find 2D-2D matches between the two images
+ matches_im0, matches_im1 = fast_reciprocal_NNs(
+ desc1,
+ desc2,
+ subsample_or_initxy1=2,
+ device=DEVICE,
+ dist="dot",
+ block_size=2**13,
+ )
+
+ mkpts0 = matches_im0.copy()
+ mkpts1 = matches_im1.copy()
+
+ if len(mkpts0) == 0:
+ pred = {
+ "keypoints0": torch.zeros([0, 2]),
+ "keypoints1": torch.zeros([0, 2]),
+ }
+            logger.warning("Matched 0 points")
+ else:
+ top_k = self.conf["max_keypoints"]
+ if top_k is not None and len(mkpts0) > top_k:
+ keep = np.round(np.linspace(0, len(mkpts0) - 1, top_k)).astype(int)
+ mkpts0 = mkpts0[keep]
+ mkpts1 = mkpts1[keep]
+ pred = {
+ "keypoints0": torch.from_numpy(mkpts0),
+ "keypoints1": torch.from_numpy(mkpts1),
+ }
+ return pred
diff --git a/imcui/hloc/matchers/mickey.py b/imcui/hloc/matchers/mickey.py
new file mode 100644
index 0000000000000000000000000000000000000000..d18e908ee64b01ab394b8533b1e5257791424f4e
--- /dev/null
+++ b/imcui/hloc/matchers/mickey.py
@@ -0,0 +1,50 @@
+import sys
+from pathlib import Path
+
+from .. import MODEL_REPO_ID, logger
+from ..utils.base_model import BaseModel
+
+mickey_path = Path(__file__).parent / "../../third_party"
+sys.path.append(str(mickey_path))
+
+from mickey.config.default import cfg
+from mickey.lib.models.builder import build_model
+
+
+class Mickey(BaseModel):
+ default_conf = {
+ "config_path": "config.yaml",
+ "model_name": "mickey.ckpt",
+ "max_keypoints": 3000,
+ }
+ required_inputs = [
+ "image0",
+ "image1",
+ ]
+
+    # Initialize the matcher
+ def _init(self, conf):
+ model_path = self._download_model(
+ repo_id=MODEL_REPO_ID,
+ filename="{}/{}".format(Path(__file__).stem, self.conf["model_name"]),
+ )
+ # TODO: config path of mickey
+ config_path = model_path.parent / self.conf["config_path"]
+ logger.info("Loading mickey model...")
+ cfg.merge_from_file(config_path)
+ self.net = build_model(cfg, checkpoint=model_path)
+ logger.info("Loaded MicKey model.")
+
+ def _forward(self, data):
+ pred = self.net(data)
+ pred = {
+ **pred,
+ **data,
+ }
+ inliers = data["inliers_list"]
+ pred = {
+ "keypoints0": inliers[:, :2],
+ "keypoints1": inliers[:, 2:4],
+ }
+
+ return pred
diff --git a/imcui/hloc/matchers/nearest_neighbor.py b/imcui/hloc/matchers/nearest_neighbor.py
new file mode 100644
index 0000000000000000000000000000000000000000..fab96e780a5c1a1672cdaf5b624ecdb310db23d3
--- /dev/null
+++ b/imcui/hloc/matchers/nearest_neighbor.py
@@ -0,0 +1,66 @@
+import torch
+
+from ..utils.base_model import BaseModel
+
+
+def find_nn(sim, ratio_thresh, distance_thresh):
+ sim_nn, ind_nn = sim.topk(2 if ratio_thresh else 1, dim=-1, largest=True)
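+ # Assuming L2-normalized descriptors, the squared Euclidean distance is
+ # ||a - b||^2 = 2 * (1 - a.b), so the thresholds below act on squared distances.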
+ dist_nn = 2 * (1 - sim_nn)
+ mask = torch.ones(ind_nn.shape[:-1], dtype=torch.bool, device=sim.device)
+ if ratio_thresh:
+ mask = mask & (dist_nn[..., 0] <= (ratio_thresh**2) * dist_nn[..., 1])
+ if distance_thresh:
+ mask = mask & (dist_nn[..., 0] <= distance_thresh**2)
+ matches = torch.where(mask, ind_nn[..., 0], ind_nn.new_tensor(-1))
+ scores = torch.where(mask, (sim_nn[..., 0] + 1) / 2, sim_nn.new_tensor(0))
+ return matches, scores
+
+
+def mutual_check(m0, m1):
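+ # Keep a match only if the image0->image1 and image1->image0 assignments agree.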
+ inds0 = torch.arange(m0.shape[-1], device=m0.device)
+ loop = torch.gather(m1, -1, torch.where(m0 > -1, m0, m0.new_tensor(0)))
+ ok = (m0 > -1) & (inds0 == loop)
+ m0_new = torch.where(ok, m0, m0.new_tensor(-1))
+ return m0_new
+
+
+class NearestNeighbor(BaseModel):
+ default_conf = {
+ "ratio_threshold": None,
+ "distance_threshold": None,
+ "do_mutual_check": True,
+ }
+ required_inputs = ["descriptors0", "descriptors1"]
+
+ def _init(self, conf):
+ pass
+
+ def _forward(self, data):
+ if data["descriptors0"].size(-1) == 0 or data["descriptors1"].size(-1) == 0:
+ matches0 = torch.full(
+ data["descriptors0"].shape[:2],
+ -1,
+ device=data["descriptors0"].device,
+ )
+ return {
+ "matches0": matches0,
+ "matching_scores0": torch.zeros_like(matches0),
+ }
+ ratio_threshold = self.conf["ratio_threshold"]
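+ # The ratio test needs at least two candidates; disable it when only one descriptor is available.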
+ if data["descriptors0"].size(-1) == 1 or data["descriptors1"].size(-1) == 1:
+ ratio_threshold = None
+ sim = torch.einsum("bdn,bdm->bnm", data["descriptors0"], data["descriptors1"])
+ matches0, scores0 = find_nn(
+ sim, ratio_threshold, self.conf["distance_threshold"]
+ )
+ if self.conf["do_mutual_check"]:
+ matches1, scores1 = find_nn(
+ sim.transpose(1, 2),
+ ratio_threshold,
+ self.conf["distance_threshold"],
+ )
+ matches0 = mutual_check(matches0, matches1)
+ return {
+ "matches0": matches0,
+ "matching_scores0": scores0,
+ }
diff --git a/imcui/hloc/matchers/omniglue.py b/imcui/hloc/matchers/omniglue.py
new file mode 100644
index 0000000000000000000000000000000000000000..07539535ff61ca9a3bdc075926995d2319a70fee
--- /dev/null
+++ b/imcui/hloc/matchers/omniglue.py
@@ -0,0 +1,80 @@
+import sys
+from pathlib import Path
+
+import numpy as np
+import torch
+
+from .. import MODEL_REPO_ID, logger
+from ..utils.base_model import BaseModel
+
+thirdparty_path = Path(__file__).parent / "../../third_party"
+sys.path.append(str(thirdparty_path))
+from omniglue.src import omniglue
+
+omniglue_path = thirdparty_path / "omniglue"
+
+
+class OmniGlue(BaseModel):
+ default_conf = {
+ "match_threshold": 0.02,
+ "max_keypoints": 2048,
+ }
+ required_inputs = ["image0", "image1"]
+ dino_v2_link_dict = {
+ "dinov2_vitb14_pretrain.pth": "https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_pretrain.pth"
+ }
+
+ def _init(self, conf):
+ logger.info("Loading OmniGlue model")
+ og_model_path = self._download_model(
+ repo_id=MODEL_REPO_ID,
+ filename="{}/{}".format(Path(__file__).stem, "omniglue.onnx"),
+ )
+ sp_model_path = self._download_model(
+ repo_id=MODEL_REPO_ID,
+ filename="{}/{}".format(Path(__file__).stem, "sp_v6.onnx"),
+ )
+ dino_model_path = self._download_model(
+ repo_id=MODEL_REPO_ID,
+ filename="{}/{}".format(Path(__file__).stem, "dinov2_vitb14_pretrain.pth"),
+ )
+
+ self.net = omniglue.OmniGlue(
+ og_export=str(og_model_path),
+ sp_export=str(sp_model_path),
+ dino_export=str(dino_model_path),
+ max_keypoints=self.conf["max_keypoints"],
+ )
+ logger.info("Loaded OmniGlue model.")
+
+ def _forward(self, data):
+ image0_rgb_np = data["image0"][0].permute(1, 2, 0).cpu().numpy() * 255
+ image1_rgb_np = data["image1"][0].permute(1, 2, 0).cpu().numpy() * 255
+ image0_rgb_np = image0_rgb_np.astype(np.uint8) # RGB, 0-255
+ image1_rgb_np = image1_rgb_np.astype(np.uint8) # RGB, 0-255
+ match_kp0, match_kp1, match_confidences = self.net.FindMatches(
+ image0_rgb_np, image1_rgb_np, self.conf["max_keypoints"]
+ )
+ # filter matches
+ match_threshold = self.conf["match_threshold"]
+ keep_idx = []
+ for i in range(match_kp0.shape[0]):
+ if match_confidences[i] > match_threshold:
+ keep_idx.append(i)
+ scores = torch.from_numpy(match_confidences[keep_idx]).reshape(-1, 1)
+ pred = {
+ "keypoints0": torch.from_numpy(match_kp0[keep_idx]),
+ "keypoints1": torch.from_numpy(match_kp1[keep_idx]),
+ "mconf": scores,
+ }
+
+ top_k = self.conf["max_keypoints"]
+ if top_k is not None and len(scores) > top_k:
+ # scores has shape (N, 1); sort on the flattened confidences.
+ keep = torch.argsort(scores.squeeze(-1), descending=True)[:top_k]
+ scores = scores[keep]
+ pred["keypoints0"], pred["keypoints1"], pred["mconf"] = (
+ pred["keypoints0"][keep],
+ pred["keypoints1"][keep],
+ scores,
+ )
+ return pred
diff --git a/imcui/hloc/matchers/roma.py b/imcui/hloc/matchers/roma.py
new file mode 100644
index 0000000000000000000000000000000000000000..0fb83e1c22a8d22d85fa39a989ed4a8897989354
--- /dev/null
+++ b/imcui/hloc/matchers/roma.py
@@ -0,0 +1,80 @@
+import sys
+from pathlib import Path
+
+import torch
+from PIL import Image
+
+from .. import MODEL_REPO_ID, logger
+from ..utils.base_model import BaseModel
+
+roma_path = Path(__file__).parent / "../../third_party/RoMa"
+sys.path.append(str(roma_path))
+from romatch.models.model_zoo import roma_model
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+
+class Roma(BaseModel):
+ default_conf = {
+ "name": "two_view_pipeline",
+ "model_name": "roma_outdoor.pth",
+ "model_utils_name": "dinov2_vitl14_pretrain.pth",
+ "max_keypoints": 3000,
+ }
+ required_inputs = [
+ "image0",
+ "image1",
+ ]
+
+ # Initialize the matcher
+ def _init(self, conf):
+ model_path = self._download_model(
+ repo_id=MODEL_REPO_ID,
+ filename="{}/{}".format(Path(__file__).stem, self.conf["model_name"]),
+ )
+
+ dinov2_weights = self._download_model(
+ repo_id=MODEL_REPO_ID,
+ filename="{}/{}".format(Path(__file__).stem, self.conf["model_utils_name"]),
+ )
+
+ logger.info("Loading Roma model")
+ # load the model
+ weights = torch.load(model_path, map_location="cpu")
+ dinov2_weights = torch.load(dinov2_weights, map_location="cpu")
+
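+ # 14 * 8 * 6 = 672: keeps the coarse resolution divisible by the DINOv2 (ViT-L/14) patch size.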
+ self.net = roma_model(
+ resolution=(14 * 8 * 6, 14 * 8 * 6),
+ upsample_preds=False,
+ weights=weights,
+ dinov2_weights=dinov2_weights,
+ device=device,
+ # temp fix issue: https://github.com/Parskatt/RoMa/issues/26
+ amp_dtype=torch.float32,
+ )
+ logger.info("Loaded RoMa model.")
+
+ def _forward(self, data):
+ img0 = data["image0"].cpu().numpy().squeeze() * 255
+ img1 = data["image1"].cpu().numpy().squeeze() * 255
+ img0 = img0.transpose(1, 2, 0)
+ img1 = img1.transpose(1, 2, 0)
+ img0 = Image.fromarray(img0.astype("uint8"))
+ img1 = Image.fromarray(img1.astype("uint8"))
+ W_A, H_A = img0.size
+ W_B, H_B = img1.size
+
+ # Match
+ warp, certainty = self.net.match(img0, img1, device=device)
+ # Sample matches for estimation
+ matches, certainty = self.net.sample(
+ warp, certainty, num=self.conf["max_keypoints"]
+ )
+ kpts1, kpts2 = self.net.to_pixel_coordinates(matches, H_A, W_A, H_B, W_B)
+ pred = {
+ "keypoints0": kpts1,
+ "keypoints1": kpts2,
+ "mconf": certainty,
+ }
+
+ return pred
diff --git a/imcui/hloc/matchers/sgmnet.py b/imcui/hloc/matchers/sgmnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..b1141e4be0b74a5dc74f4cf1b5189ef4893a8cef
--- /dev/null
+++ b/imcui/hloc/matchers/sgmnet.py
@@ -0,0 +1,106 @@
+import sys
+from collections import OrderedDict, namedtuple
+from pathlib import Path
+
+import torch
+
+from .. import MODEL_REPO_ID, logger
+from ..utils.base_model import BaseModel
+
+sgmnet_path = Path(__file__).parent / "../../third_party/SGMNet"
+sys.path.append(str(sgmnet_path))
+
+from sgmnet import matcher as SGM_Model
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+
+class SGMNet(BaseModel):
+ default_conf = {
+ "name": "SGM",
+ "model_name": "weights/sgm/root/model_best.pth",
+ "seed_top_k": [256, 256],
+ "seed_radius_coe": 0.01,
+ "net_channels": 128,
+ "layer_num": 9,
+ "head": 4,
+ "seedlayer": [0, 6],
+ "use_mc_seeding": True,
+ "use_score_encoding": False,
+ "conf_bar": [1.11, 0.1],
+ "sink_iter": [10, 100],
+ "detach_iter": 1000000,
+ "match_threshold": 0.2,
+ }
+ required_inputs = [
+ "image0",
+ "image1",
+ ]
+
+ # Initialize the matcher
+ def _init(self, conf):
+ model_path = self._download_model(
+ repo_id=MODEL_REPO_ID,
+ filename="{}/{}".format(Path(__file__).stem, self.conf["model_name"]),
+ )
+
+ # config
+ config = namedtuple("config", conf.keys())(*conf.values())
+ self.net = SGM_Model(config)
+ checkpoint = torch.load(model_path, map_location="cpu")
+ # Strip the "module." prefix from checkpoints saved with DistributedDataParallel.
+ if list(checkpoint["state_dict"].items())[0][0].split(".")[0] == "module":
+ new_state_dict = OrderedDict()
+ for key, value in checkpoint["state_dict"].items():
+ new_state_dict[key[7:]] = value
+ checkpoint["state_dict"] = new_state_dict
+ self.net.load_state_dict(checkpoint["state_dict"])
+ logger.info("Loaded SGMNet model.")
+
+ def _forward(self, data):
+ x1 = data["keypoints0"].squeeze() # N x 2
+ x2 = data["keypoints1"].squeeze()
+ score1 = data["scores0"].reshape(-1, 1) # N x 1
+ score2 = data["scores1"].reshape(-1, 1)
+ desc1 = data["descriptors0"].permute(0, 2, 1) # 1 x N x 128
+ desc2 = data["descriptors1"].permute(0, 2, 1)
+ size1 = (
+ torch.tensor(data["image0"].shape[2:]).flip(0).to(x1.device)
+ ) # W x H -> x & y
+ size2 = torch.tensor(data["image1"].shape[2:]).flip(0).to(x2.device) # W x H
+ norm_x1 = self.normalize_size(x1, size1)
+ norm_x2 = self.normalize_size(x2, size2)
+
+ x1 = torch.cat((norm_x1, score1), dim=-1) # N x 3
+ x2 = torch.cat((norm_x2, score2), dim=-1)
+ input = {"x1": x1[None], "x2": x2[None], "desc1": desc1, "desc2": desc2}
+ input = {
+ k: v.to(device).float() if isinstance(v, torch.Tensor) else v
+ for k, v in input.items()
+ }
+ pred = self.net(input, test_mode=True)
+
+ p = pred["p"] # shape: N * M
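+ # The assignment matrix carries a dust-bin row/column for unmatched keypoints; drop it here.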
+ indices0 = self.match_p(p[0, :-1, :-1])
+ pred = {
+ "matches0": indices0.unsqueeze(0),
+ "matching_scores0": torch.zeros(indices0.size(0)).unsqueeze(0),
+ }
+ return pred
+
+ def match_p(self, p):
+ score, index = torch.topk(p, k=1, dim=-1)
+ _, index2 = torch.topk(p, k=1, dim=-2)
+ mask_th, index, index2 = (
+ score[:, 0] > self.conf["match_threshold"],
+ index[:, 0],
+ index2.squeeze(0),
+ )
+ mask_mc = index2[index] == torch.arange(len(p)).to(device)
+ mask = mask_th & mask_mc
+ indices0 = torch.where(mask, index, index.new_tensor(-1))
+ return indices0
+
+ def normalize_size(self, x, size, scale=1):
+ norm_fac = size.max()
+ return (x - size / 2 + 0.5) / (norm_fac * scale)
diff --git a/imcui/hloc/matchers/sold2.py b/imcui/hloc/matchers/sold2.py
new file mode 100644
index 0000000000000000000000000000000000000000..daed4f029f4fcd23771ffe4a848ed12bc0b81478
--- /dev/null
+++ b/imcui/hloc/matchers/sold2.py
@@ -0,0 +1,144 @@
+import sys
+from pathlib import Path
+
+import torch
+
+from .. import MODEL_REPO_ID, logger
+from ..utils.base_model import BaseModel
+
+sold2_path = Path(__file__).parent / "../../third_party/SOLD2"
+sys.path.append(str(sold2_path))
+
+from sold2.model.line_matcher import LineMatcher
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+
+class SOLD2(BaseModel):
+ default_conf = {
+ "model_name": "sold2_wireframe.tar",
+ "match_threshold": 0.2,
+ "checkpoint_dir": sold2_path / "pretrained",
+ "detect_thresh": 0.25,
+ "multiscale": False,
+ "valid_thresh": 1e-3,
+ "num_blocks": 20,
+ "overlap_ratio": 0.5,
+ }
+ required_inputs = [
+ "image0",
+ "image1",
+ ]
+
+ # Initialize the line matcher
+ def _init(self, conf):
+ model_path = self._download_model(
+ repo_id=MODEL_REPO_ID,
+ filename="{}/{}".format(Path(__file__).stem, self.conf["model_name"]),
+ )
+ logger.info("Loading SOLD2 model: {}".format(model_path))
+
+ mode = "dynamic" # 'dynamic' or 'static'
+ match_config = {
+ "model_cfg": {
+ "model_name": "lcnn_simple",
+ "model_architecture": "simple",
+ # Backbone related config
+ "backbone": "lcnn",
+ "backbone_cfg": {
+ "input_channel": 1, # Use RGB images or grayscale images.
+ "depth": 4,
+ "num_stacks": 2,
+ "num_blocks": 1,
+ "num_classes": 5,
+ },
+ # Junction decoder related config
+ "junction_decoder": "superpoint_decoder",
+ "junc_decoder_cfg": {},
+ # Heatmap decoder related config
+ "heatmap_decoder": "pixel_shuffle",
+ "heatmap_decoder_cfg": {},
+ # Descriptor decoder related config
+ "descriptor_decoder": "superpoint_descriptor",
+ "descriptor_decoder_cfg": {},
+ # Shared configurations
+ "grid_size": 8,
+ "keep_border_valid": True,
+ # Threshold of junction detection
+ "detection_thresh": 0.0153846, # 1/65
+ "max_num_junctions": 300,
+ # Threshold of heatmap detection
+ "prob_thresh": 0.5,
+ # Weighting related parameters
+ "weighting_policy": mode,
+ # [Heatmap loss]
+ "w_heatmap": 0.0,
+ "w_heatmap_class": 1,
+ "heatmap_loss_func": "cross_entropy",
+ "heatmap_loss_cfg": {"policy": mode},
+ # [Heatmap consistency loss]
+ # [Junction loss]
+ "w_junc": 0.0,
+ "junction_loss_func": "superpoint",
+ "junction_loss_cfg": {"policy": mode},
+ # [Descriptor loss]
+ "w_desc": 0.0,
+ "descriptor_loss_func": "regular_sampling",
+ "descriptor_loss_cfg": {
+ "dist_threshold": 8,
+ "grid_size": 4,
+ "margin": 1,
+ "policy": mode,
+ },
+ },
+ "line_detector_cfg": {
+ "detect_thresh": 0.25, # depending on your images, you might need to tune this parameter
+ "num_samples": 64,
+ "sampling_method": "local_max",
+ "inlier_thresh": 0.9,
+ "use_candidate_suppression": True,
+ "nms_dist_tolerance": 3.0,
+ "use_heatmap_refinement": True,
+ "heatmap_refine_cfg": {
+ "mode": "local",
+ "ratio": 0.2,
+ "valid_thresh": 1e-3,
+ "num_blocks": 20,
+ "overlap_ratio": 0.5,
+ },
+ },
+ "multiscale": False,
+ "line_matcher_cfg": {
+ "cross_check": True,
+ "num_samples": 5,
+ "min_dist_pts": 8,
+ "top_k_candidates": 10,
+ "grid_size": 4,
+ },
+ }
+ self.net = LineMatcher(
+ match_config["model_cfg"],
+ model_path,
+ device,
+ match_config["line_detector_cfg"],
+ match_config["line_matcher_cfg"],
+ match_config["multiscale"],
+ )
+
+ def _forward(self, data):
+ img0 = data["image0"]
+ img1 = data["image1"]
+ pred = self.net([img0, img1])
+ line_seg1 = pred["line_segments"][0]
+ line_seg2 = pred["line_segments"][1]
+ matches = pred["matches"]
+
+ valid_matches = matches != -1
+ match_indices = matches[valid_matches]
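+ # Flip line endpoints from (row, col) to (x, y) ordering.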
+ matched_lines1 = line_seg1[valid_matches][:, :, ::-1]
+ matched_lines2 = line_seg2[match_indices][:, :, ::-1]
+
+ pred["raw_lines0"], pred["raw_lines1"] = line_seg1, line_seg2
+ pred["lines0"], pred["lines1"] = matched_lines1, matched_lines2
+ pred = {**pred, **data}
+ return pred
diff --git a/imcui/hloc/matchers/superglue.py b/imcui/hloc/matchers/superglue.py
new file mode 100644
index 0000000000000000000000000000000000000000..6fae344e8dfe0fd46f090c6915036e5f9c09a635
--- /dev/null
+++ b/imcui/hloc/matchers/superglue.py
@@ -0,0 +1,33 @@
+import sys
+from pathlib import Path
+
+from ..utils.base_model import BaseModel
+
+sys.path.append(str(Path(__file__).parent / "../../third_party"))
+from SuperGluePretrainedNetwork.models.superglue import ( # noqa: E402
+ SuperGlue as SG,
+)
+
+
+class SuperGlue(BaseModel):
+ default_conf = {
+ "weights": "outdoor",
+ "sinkhorn_iterations": 100,
+ "match_threshold": 0.2,
+ }
+ required_inputs = [
+ "image0",
+ "keypoints0",
+ "scores0",
+ "descriptors0",
+ "image1",
+ "keypoints1",
+ "scores1",
+ "descriptors1",
+ ]
+
+ def _init(self, conf):
+ self.net = SG(conf)
+
+ def _forward(self, data):
+ return self.net(data)
diff --git a/imcui/hloc/matchers/topicfm.py b/imcui/hloc/matchers/topicfm.py
new file mode 100644
index 0000000000000000000000000000000000000000..5c99adc740e82cdcd644e18fff450a4efeaaf9bc
--- /dev/null
+++ b/imcui/hloc/matchers/topicfm.py
@@ -0,0 +1,60 @@
+import sys
+from pathlib import Path
+
+import torch
+
+from .. import MODEL_REPO_ID
+
+from ..utils.base_model import BaseModel
+
+sys.path.append(str(Path(__file__).parent / "../../third_party"))
+from TopicFM.src import get_model_cfg
+from TopicFM.src.models.topic_fm import TopicFM as _TopicFM
+
+topicfm_path = Path(__file__).parent / "../../third_party/TopicFM"
+
+
+class TopicFM(BaseModel):
+ default_conf = {
+ "weights": "outdoor",
+ "model_name": "model_best.ckpt",
+ "match_threshold": 0.2,
+ "n_sampling_topics": 4,
+ "max_keypoints": -1,
+ }
+ required_inputs = ["image0", "image1"]
+
+ def _init(self, conf):
+ _conf = dict(get_model_cfg())
+ _conf["match_coarse"]["thr"] = conf["match_threshold"]
+ _conf["coarse"]["n_samples"] = conf["n_sampling_topics"]
+ model_path = self._download_model(
+ repo_id=MODEL_REPO_ID,
+ filename="{}/{}".format(Path(__file__).stem, self.conf["model_name"]),
+ )
+ self.net = _TopicFM(config=_conf)
+ ckpt_dict = torch.load(model_path, map_location="cpu")
+ self.net.load_state_dict(ckpt_dict["state_dict"])
+
+ def _forward(self, data):
+ data_ = {
+ "image0": data["image0"],
+ "image1": data["image1"],
+ }
+ self.net(data_)
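+ # TopicFM writes its outputs (mkpts0_f, mkpts1_f, mconf) back into the input dict.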
+ pred = {
+ "keypoints0": data_["mkpts0_f"],
+ "keypoints1": data_["mkpts1_f"],
+ "mconf": data_["mconf"],
+ }
+ scores = data_["mconf"]
+ top_k = self.conf["max_keypoints"]
+ if top_k is not None and 0 < top_k < len(scores):
+ keep = torch.argsort(scores, descending=True)[:top_k]
+ scores = scores[keep]
+ pred["keypoints0"], pred["keypoints1"], pred["mconf"] = (
+ pred["keypoints0"][keep],
+ pred["keypoints1"][keep],
+ scores,
+ )
+ return pred
diff --git a/imcui/hloc/matchers/xfeat_dense.py b/imcui/hloc/matchers/xfeat_dense.py
new file mode 100644
index 0000000000000000000000000000000000000000..7b6956636c126d1d482c1dc3ff18e23978a9163b
--- /dev/null
+++ b/imcui/hloc/matchers/xfeat_dense.py
@@ -0,0 +1,54 @@
+import torch
+
+from .. import logger
+
+from ..utils.base_model import BaseModel
+
+
+class XFeatDense(BaseModel):
+ default_conf = {
+ "keypoint_threshold": 0.005,
+ "max_keypoints": 8000,
+ }
+ required_inputs = [
+ "image0",
+ "image1",
+ ]
+
+ def _init(self, conf):
+ self.net = torch.hub.load(
+ "verlab/accelerated_features",
+ "XFeat",
+ pretrained=True,
+ top_k=self.conf["max_keypoints"],
+ )
+ logger.info("Load XFeat(dense) model done.")
+
+ def _forward(self, data):
+ # Compute coarse feats
+ out0 = self.net.detectAndComputeDense(
+ data["image0"], top_k=self.conf["max_keypoints"]
+ )
+ out1 = self.net.detectAndComputeDense(
+ data["image1"], top_k=self.conf["max_keypoints"]
+ )
+
+ # Match batches of pairs
+ idxs_list = self.net.batch_match(out0["descriptors"], out1["descriptors"])
+ B = len(data["image0"])
+
+ # Refine coarse matches
+ # This part is harder to batch, so we iterate over the batch dimension.
+ matches = []
+ for b in range(B):
+ matches.append(
+ self.net.refine_matches(out0, out1, matches=idxs_list, batch_idx=b)
+ )
+ # Only the first batch element is used.
+ matches = matches[0]
+ pred = {
+ "keypoints0": matches[:, :2],
+ "keypoints1": matches[:, 2:],
+ "mconf": torch.ones_like(matches[:, 0]),
+ }
+ return pred
diff --git a/imcui/hloc/matchers/xfeat_lightglue.py b/imcui/hloc/matchers/xfeat_lightglue.py
new file mode 100644
index 0000000000000000000000000000000000000000..6041d29e1473cc2b0cebc3ca64f2b01df2753f06
--- /dev/null
+++ b/imcui/hloc/matchers/xfeat_lightglue.py
@@ -0,0 +1,48 @@
+import torch
+
+from .. import logger
+
+from ..utils.base_model import BaseModel
+
+
+class XFeatLightGlue(BaseModel):
+ default_conf = {
+ "keypoint_threshold": 0.005,
+ "max_keypoints": 8000,
+ }
+ required_inputs = [
+ "image0",
+ "image1",
+ ]
+
+ def _init(self, conf):
+ self.net = torch.hub.load(
+ "verlab/accelerated_features",
+ "XFeat",
+ pretrained=True,
+ top_k=self.conf["max_keypoints"],
+ )
+ logger.info("Loaded XFeat (LightGlue) model.")
+
+ def _forward(self, data):
+ # Only the first pair in the batch is processed.
+ im0 = data["image0"]
+ im1 = data["image1"]
+ # Compute coarse feats
+ out0 = self.net.detectAndCompute(im0, top_k=self.conf["max_keypoints"])[0]
+ out1 = self.net.detectAndCompute(im1, top_k=self.conf["max_keypoints"])[0]
+ out0.update({"image_size": (im0.shape[-1], im0.shape[-2])}) # W H
+ out1.update({"image_size": (im1.shape[-1], im1.shape[-2])}) # W H
+ pred = self.net.match_lighterglue(out0, out1)
+ if len(pred) == 3:
+ mkpts_0, mkpts_1, _ = pred
+ else:
+ mkpts_0, mkpts_1 = pred
+ mkpts_0 = torch.from_numpy(mkpts_0) # n x 2
+ mkpts_1 = torch.from_numpy(mkpts_1) # n x 2
+ pred = {
+ "keypoints0": mkpts_0,
+ "keypoints1": mkpts_1,
+ "mconf": torch.ones_like(mkpts_0[:, 0]),
+ }
+ return pred
diff --git a/imcui/hloc/matchers/xoftr.py b/imcui/hloc/matchers/xoftr.py
new file mode 100644
index 0000000000000000000000000000000000000000..135f67f811468a13fb172bf06115aafb3084ccfb
--- /dev/null
+++ b/imcui/hloc/matchers/xoftr.py
@@ -0,0 +1,90 @@
+import sys
+import warnings
+from pathlib import Path
+
+import torch
+
+from .. import DEVICE, MODEL_REPO_ID, logger
+
+tp_path = Path(__file__).parent / "../../third_party"
+sys.path.append(str(tp_path))
+
+from XoFTR.src.config.default import get_cfg_defaults
+from XoFTR.src.utils.misc import lower_config
+from XoFTR.src.xoftr import XoFTR as XoFTR_
+
+
+from ..utils.base_model import BaseModel
+
+
+class XoFTR(BaseModel):
+ default_conf = {
+ "model_name": "weights_xoftr_640.ckpt",
+ "match_threshold": 0.3,
+ "max_keypoints": -1,
+ }
+ required_inputs = ["image0", "image1"]
+
+ def _init(self, conf):
+ # Get default configurations
+ config_ = get_cfg_defaults(inference=True)
+ config_ = lower_config(config_)
+
+ # Coarse level threshold
+ config_["xoftr"]["match_coarse"]["thr"] = self.conf["match_threshold"]
+
+ # Fine level threshold
+ config_["xoftr"]["fine"]["thr"] = 0.1 # Default 0.1
+
+ # It is possible to get denser matches.
+ # If True, xoftr returns all fine-level matches for each fine-level window (at 1/2 resolution)
+ config_["xoftr"]["fine"]["denser"] = False # Default False
+
+ # XoFTR model
+ matcher = XoFTR_(config=config_["xoftr"])
+
+ model_path = self._download_model(
+ repo_id=MODEL_REPO_ID,
+ filename="{}/{}".format(Path(__file__).stem, self.conf["model_name"]),
+ )
+
+ # Load model
+ state_dict = torch.load(model_path, map_location="cpu")["state_dict"]
+ matcher.load_state_dict(state_dict, strict=True)
+ matcher = matcher.eval().to(DEVICE)
+ self.net = matcher
+ logger.info(f"Loaded XoFTR with weights {conf['model_name']}")
+
+ def _forward(self, data):
+ # For consistency with hloc pairs, we refine kpts in image0!
+ rename = {
+ "keypoints0": "keypoints1",
+ "keypoints1": "keypoints0",
+ "image0": "image1",
+ "image1": "image0",
+ "mask0": "mask1",
+ "mask1": "mask0",
+ }
+ data_ = {rename[k]: v for k, v in data.items()}
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ pred = self.net(data_)
+ pred = {
+ "keypoints0": data_["mkpts0_f"],
+ "keypoints1": data_["mkpts1_f"],
+ }
+ scores = data_["mconf_f"]
+
+ top_k = self.conf["max_keypoints"]
+ if top_k is not None and 0 < top_k < len(scores):
+ keep = torch.argsort(scores, descending=True)[:top_k]
+ pred["keypoints0"], pred["keypoints1"] = (
+ pred["keypoints0"][keep],
+ pred["keypoints1"][keep],
+ )
+ scores = scores[keep]
+
+ # Switch back indices
+ pred = {(rename[k] if k in rename else k): v for k, v in pred.items()}
+ pred["scores"] = scores
+ return pred
diff --git a/imcui/hloc/pairs_from_covisibility.py b/imcui/hloc/pairs_from_covisibility.py
new file mode 100644
index 0000000000000000000000000000000000000000..49f3e57f2bd1aec20e12ecca6df8f94a68b7fd4e
--- /dev/null
+++ b/imcui/hloc/pairs_from_covisibility.py
@@ -0,0 +1,60 @@
+import argparse
+from collections import defaultdict
+from pathlib import Path
+
+import numpy as np
+from tqdm import tqdm
+
+from . import logger
+from .utils.read_write_model import read_model
+
+
+def main(model, output, num_matched):
+ logger.info("Reading the COLMAP model...")
+ cameras, images, points3D = read_model(model)
+
+ logger.info("Extracting image pairs from covisibility info...")
+ pairs = []
+ for image_id, image in tqdm(images.items()):
+ matched = image.point3D_ids != -1
+ points3D_covis = image.point3D_ids[matched]
+
+ covis = defaultdict(int)
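+ # Count how many 3D points each other image shares with the current one.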
+ for point_id in points3D_covis:
+ for image_covis_id in points3D[point_id].image_ids:
+ if image_covis_id != image_id:
+ covis[image_covis_id] += 1
+
+ if len(covis) == 0:
+ logger.info(f"Image {image_id} does not have any covisibility.")
+ continue
+
+ covis_ids = np.array(list(covis.keys()))
+ covis_num = np.array([covis[i] for i in covis_ids])
+
+ if len(covis_ids) <= num_matched:
+ top_covis_ids = covis_ids[np.argsort(-covis_num)]
+ else:
+ # get covisible image ids with top k number of common matches
+ ind_top = np.argpartition(covis_num, -num_matched)
+ ind_top = ind_top[-num_matched:] # unsorted top k
+ ind_top = ind_top[np.argsort(-covis_num[ind_top])]
+ top_covis_ids = [covis_ids[i] for i in ind_top]
+ assert covis_num[ind_top[0]] == np.max(covis_num)
+
+ for i in top_covis_ids:
+ pair = (image.name, images[i].name)
+ pairs.append(pair)
+
+ logger.info(f"Found {len(pairs)} pairs.")
+ with open(output, "w") as f:
+ f.write("\n".join(" ".join([i, j]) for i, j in pairs))
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--model", required=True, type=Path)
+ parser.add_argument("--output", required=True, type=Path)
+ parser.add_argument("--num_matched", required=True, type=int)
+ args = parser.parse_args()
+ main(**args.__dict__)
diff --git a/imcui/hloc/pairs_from_exhaustive.py b/imcui/hloc/pairs_from_exhaustive.py
new file mode 100644
index 0000000000000000000000000000000000000000..0d54ed1dcdbb16d490fcadf9ac2577fd064c3828
--- /dev/null
+++ b/imcui/hloc/pairs_from_exhaustive.py
@@ -0,0 +1,64 @@
+import argparse
+import collections.abc as collections
+from pathlib import Path
+from typing import List, Optional, Union
+
+from . import logger
+from .utils.io import list_h5_names
+from .utils.parsers import parse_image_lists
+
+
+def main(
+ output: Path,
+ image_list: Optional[Union[Path, List[str]]] = None,
+ features: Optional[Path] = None,
+ ref_list: Optional[Union[Path, List[str]]] = None,
+ ref_features: Optional[Path] = None,
+):
+ if image_list is not None:
+ if isinstance(image_list, (str, Path)):
+ names_q = parse_image_lists(image_list)
+ elif isinstance(image_list, collections.Iterable):
+ names_q = list(image_list)
+ else:
+ raise ValueError(f"Unknown type for image list: {image_list}")
+ elif features is not None:
+ names_q = list_h5_names(features)
+ else:
+ raise ValueError("Provide either a list of images or a feature file.")
+
+ self_matching = False
+ if ref_list is not None:
+ if isinstance(ref_list, (str, Path)):
+ names_ref = parse_image_lists(ref_list)
+ elif isinstance(ref_list, collections.Iterable):
+ names_ref = list(ref_list)
+ else:
+ raise ValueError(f"Unknown type for reference image list: {ref_list}")
+ elif ref_features is not None:
+ names_ref = list_h5_names(ref_features)
+ else:
+ self_matching = True
+ names_ref = names_q
+
+ pairs = []
+ for i, n1 in enumerate(names_q):
+ for j, n2 in enumerate(names_ref):
+ if self_matching and j <= i:
+ continue
+ pairs.append((n1, n2))
+
+ logger.info(f"Found {len(pairs)} pairs.")
+ with open(output, "w") as f:
+ f.write("\n".join(" ".join([i, j]) for i, j in pairs))
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--output", required=True, type=Path)
+ parser.add_argument("--image_list", type=Path)
+ parser.add_argument("--features", type=Path)
+ parser.add_argument("--ref_list", type=Path)
+ parser.add_argument("--ref_features", type=Path)
+ args = parser.parse_args()
+ main(**args.__dict__)
diff --git a/imcui/hloc/pairs_from_poses.py b/imcui/hloc/pairs_from_poses.py
new file mode 100644
index 0000000000000000000000000000000000000000..83ee1b8cce2b680fac9a4de35d68c5f234092361
--- /dev/null
+++ b/imcui/hloc/pairs_from_poses.py
@@ -0,0 +1,68 @@
+import argparse
+from pathlib import Path
+
+import numpy as np
+import scipy.spatial
+
+from . import logger
+from .pairs_from_retrieval import pairs_from_score_matrix
+from .utils.read_write_model import read_images_binary
+
+DEFAULT_ROT_THRESH = 30 # in degrees
+
+
+def get_pairwise_distances(images):
+ ids = np.array(list(images.keys()))
+ Rs = []
+ ts = []
+ for id_ in ids:
+ image = images[id_]
+ R = image.qvec2rotmat()
+ t = image.tvec
+ Rs.append(R)
+ ts.append(t)
+ Rs = np.stack(Rs, 0)
+ ts = np.stack(ts, 0)
+
+ # Invert the poses from world-to-camera to camera-to-world.
+ Rs = Rs.transpose(0, 2, 1)
+ ts = -(Rs @ ts[:, :, None])[:, :, 0]
+
+ dist = scipy.spatial.distance.squareform(scipy.spatial.distance.pdist(ts))
+
+ # Instead of computing the angle between two camera orientations,
+ # we compute the angle between the principal axes, as two images rotated
+ # around their principal axis still observe the same scene.
+ axes = Rs[:, :, -1]
+ dots = np.einsum("mi,ni->mn", axes, axes, optimize=True)
+ dR = np.rad2deg(np.arccos(np.clip(dots, -1.0, 1.0)))
+
+ return ids, dist, dR
+
+
+def main(model, output, num_matched, rotation_threshold=DEFAULT_ROT_THRESH):
+ logger.info("Reading the COLMAP model...")
+ images = read_images_binary(model / "images.bin")
+
+ logger.info(f"Obtaining pairwise distances between {len(images)} images...")
+ ids, dist, dR = get_pairwise_distances(images)
+ scores = -dist
+
+ invalid = dR >= rotation_threshold
+ np.fill_diagonal(invalid, True)
+ pairs = pairs_from_score_matrix(scores, invalid, num_matched)
+ pairs = [(images[ids[i]].name, images[ids[j]].name) for i, j in pairs]
+
+ logger.info(f"Found {len(pairs)} pairs.")
+ with open(output, "w") as f:
+ f.write("\n".join(" ".join(p) for p in pairs))
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--model", required=True, type=Path)
+ parser.add_argument("--output", required=True, type=Path)
+ parser.add_argument("--num_matched", required=True, type=int)
+ parser.add_argument("--rotation_threshold", default=DEFAULT_ROT_THRESH, type=float)
+ args = parser.parse_args()
+ main(**args.__dict__)
diff --git a/imcui/hloc/pairs_from_retrieval.py b/imcui/hloc/pairs_from_retrieval.py
new file mode 100644
index 0000000000000000000000000000000000000000..323368011086b10065aba177a360284f558904e8
--- /dev/null
+++ b/imcui/hloc/pairs_from_retrieval.py
@@ -0,0 +1,133 @@
+import argparse
+import collections.abc as collections
+from pathlib import Path
+from typing import Optional
+
+import h5py
+import numpy as np
+import torch
+
+from . import logger
+from .utils.io import list_h5_names
+from .utils.parsers import parse_image_lists
+from .utils.read_write_model import read_images_binary
+
+
+def parse_names(prefix, names, names_all):
+ if prefix is not None:
+ if not isinstance(prefix, str):
+ prefix = tuple(prefix)
+ names = [n for n in names_all if n.startswith(prefix)]
+ if len(names) == 0:
+ raise ValueError(f"Could not find any image with the prefix `{prefix}`.")
+ elif names is not None:
+ if isinstance(names, (str, Path)):
+ names = parse_image_lists(names)
+ elif isinstance(names, collections.Iterable):
+ names = list(names)
+ else:
+ raise ValueError(
+ f"Unknown type of image list: {names}."
+ "Provide either a list or a path to a list file."
+ )
+ else:
+ names = names_all
+ return names
+
+
+def get_descriptors(names, path, name2idx=None, key="global_descriptor"):
+ if name2idx is None:
+ with h5py.File(str(path), "r", libver="latest") as fd:
+ desc = [fd[n][key].__array__() for n in names]
+ else:
+ desc = []
+ for n in names:
+ with h5py.File(str(path[name2idx[n]]), "r", libver="latest") as fd:
+ desc.append(fd[n][key].__array__())
+ return torch.from_numpy(np.stack(desc, 0)).float()
+
+
+def pairs_from_score_matrix(
+ scores: torch.Tensor,
+ invalid: np.ndarray,
+ num_select: int,
+ min_score: Optional[float] = None,
+):
+ assert scores.shape == invalid.shape
+ if isinstance(scores, np.ndarray):
+ scores = torch.from_numpy(scores)
+ invalid = torch.from_numpy(invalid).to(scores.device)
+ if min_score is not None:
+ invalid |= scores < min_score
+ scores.masked_fill_(invalid, float("-inf"))
+
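+ # Invalid entries are set to -inf; any that still land in the top-k are removed by the isfinite check below.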
+ topk = torch.topk(scores, num_select, dim=1)
+ indices = topk.indices.cpu().numpy()
+ valid = topk.values.isfinite().cpu().numpy()
+
+ pairs = []
+ for i, j in zip(*np.where(valid)):
+ pairs.append((i, indices[i, j]))
+ return pairs
+
+
+def main(
+ descriptors,
+ output,
+ num_matched,
+ query_prefix=None,
+ query_list=None,
+ db_prefix=None,
+ db_list=None,
+ db_model=None,
+ db_descriptors=None,
+):
+ logger.info("Extracting image pairs from a retrieval database.")
+
+ # We handle multiple reference feature files.
+ # We only assume that names are unique among them and map names to files.
+ if db_descriptors is None:
+ db_descriptors = descriptors
+ if isinstance(db_descriptors, (Path, str)):
+ db_descriptors = [db_descriptors]
+ name2db = {n: i for i, p in enumerate(db_descriptors) for n in list_h5_names(p)}
+ db_names_h5 = list(name2db.keys())
+ query_names_h5 = list_h5_names(descriptors)
+
+ if db_model:
+ images = read_images_binary(db_model / "images.bin")
+ db_names = [i.name for i in images.values()]
+ else:
+ db_names = parse_names(db_prefix, db_list, db_names_h5)
+ if len(db_names) == 0:
+ raise ValueError("Could not find any database image.")
+ query_names = parse_names(query_prefix, query_list, query_names_h5)
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ db_desc = get_descriptors(db_names, db_descriptors, name2db)
+ query_desc = get_descriptors(query_names, descriptors)
+ sim = torch.einsum("id,jd->ij", query_desc.to(device), db_desc.to(device))
+
+ # Avoid self-matching
+ self = np.array(query_names)[:, None] == np.array(db_names)[None]
+ pairs = pairs_from_score_matrix(sim, self, num_matched, min_score=0)
+ pairs = [(query_names[i], db_names[j]) for i, j in pairs]
+
+ logger.info(f"Found {len(pairs)} pairs.")
+ with open(output, "w") as f:
+ f.write("\n".join(" ".join([i, j]) for i, j in pairs))
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--descriptors", type=Path, required=True)
+ parser.add_argument("--output", type=Path, required=True)
+ parser.add_argument("--num_matched", type=int, required=True)
+ parser.add_argument("--query_prefix", type=str, nargs="+")
+ parser.add_argument("--query_list", type=Path)
+ parser.add_argument("--db_prefix", type=str, nargs="+")
+ parser.add_argument("--db_list", type=Path)
+ parser.add_argument("--db_model", type=Path)
+ parser.add_argument("--db_descriptors", type=Path)
+ args = parser.parse_args()
+ main(**args.__dict__)
diff --git a/imcui/hloc/pipelines/4Seasons/README.md b/imcui/hloc/pipelines/4Seasons/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..ad23ac8348ae9f0963611bc9a342240d5ae97255
--- /dev/null
+++ b/imcui/hloc/pipelines/4Seasons/README.md
@@ -0,0 +1,43 @@
+# 4Seasons dataset
+
+This pipeline localizes sequences from the [4Seasons dataset](https://arxiv.org/abs/2009.06364) and can reproduce our winning submission to the challenge of the [ECCV 2020 Workshop on Map-based Localization for Autonomous Driving](https://sites.google.com/view/mlad-eccv2020/home).
+
+## Installation
+
+Download the sequences from the [challenge webpage](https://sites.google.com/view/mlad-eccv2020/challenge) and run:
+```bash
+unzip recording_2020-04-07_10-20-32.zip -d datasets/4Seasons/reference
+unzip recording_2020-03-24_17-36-22.zip -d datasets/4Seasons/training
+unzip recording_2020-03-03_12-03-23.zip -d datasets/4Seasons/validation
+unzip recording_2020-03-24_17-45-31.zip -d datasets/4Seasons/test0
+unzip recording_2020-04-23_19-37-00.zip -d datasets/4Seasons/test1
+```
+Note that the provided scripts might modify the dataset files by deleting unused images to speed up the feature extraction.
+
+## Pipeline
+
+The process is presented in our workshop talk, whose recording can be found [here](https://youtu.be/M-X6HX1JxYk?t=5245).
+
+We first triangulate a 3D model from the given poses of the reference sequence:
+```bash
+python3 -m hloc.pipelines.4Seasons.prepare_reference
+```
+
+We then relocalize a given sequence:
+```bash
+python3 -m hloc.pipelines.4Seasons.localize --sequence [training|validation|test0|test1]
+```
+
+The final submission files can be found in `outputs/4Seasons/submission_hloc+superglue/`. The script will also evaluate these results if the training or validation sequences are selected.
+
+## Results
+
+We evaluate the localization recall at distance thresholds 0.1m, 0.2m, and 0.5m.
+
+| Methods | test0 | test1 |
+| -------------------- | ---------------------- | ---------------------- |
+| **hloc + SuperGlue** | **91.8 / 97.7 / 99.2** | **67.3 / 93.5 / 98.7** |
+| Baseline SuperGlue | 21.2 / 33.9 / 60.0 | 12.4 / 26.5 / 54.4 |
+| Baseline R2D2 | 21.5 / 33.1 / 53.0 | 12.3 / 23.7 / 42.0 |
+| Baseline D2Net | 12.5 / 29.3 / 56.7 | 7.5 / 21.4 / 47.7 |
+| Baseline SuperPoint | 15.5 / 27.5 / 47.5 | 9.0 / 19.4 / 36.4 |
diff --git a/imcui/hloc/pipelines/4Seasons/__init__.py b/imcui/hloc/pipelines/4Seasons/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/imcui/hloc/pipelines/4Seasons/localize.py b/imcui/hloc/pipelines/4Seasons/localize.py
new file mode 100644
index 0000000000000000000000000000000000000000..50ed957bcf159915d0be98fa9b54c8bce0059b56
--- /dev/null
+++ b/imcui/hloc/pipelines/4Seasons/localize.py
@@ -0,0 +1,89 @@
+import argparse
+from pathlib import Path
+
+from ... import extract_features, localize_sfm, logger, match_features
+from .utils import (
+ delete_unused_images,
+ evaluate_submission,
+ generate_localization_pairs,
+ generate_query_lists,
+ get_timestamps,
+ prepare_submission,
+)
+
+relocalization_files = {
+ "training": "RelocalizationFilesTrain//relocalizationFile_recording_2020-03-24_17-36-22.txt", # noqa: E501
+ "validation": "RelocalizationFilesVal/relocalizationFile_recording_2020-03-03_12-03-23.txt", # noqa: E501
+ "test0": "RelocalizationFilesTest/relocalizationFile_recording_2020-03-24_17-45-31_*.txt", # noqa: E501
+ "test1": "RelocalizationFilesTest/relocalizationFile_recording_2020-04-23_19-37-00_*.txt", # noqa: E501
+}
+
+parser = argparse.ArgumentParser()
+parser.add_argument(
+ "--sequence",
+ type=str,
+ required=True,
+ choices=["training", "validation", "test0", "test1"],
+ help="Sequence to be relocalized.",
+)
+parser.add_argument(
+ "--dataset",
+ type=Path,
+ default="datasets/4Seasons",
+ help="Path to the dataset, default: %(default)s",
+)
+parser.add_argument(
+ "--outputs",
+ type=Path,
+ default="outputs/4Seasons",
+ help="Path to the output directory, default: %(default)s",
+)
+args = parser.parse_args()
+sequence = args.sequence
+
+data_dir = args.dataset
+ref_dir = data_dir / "reference"
+assert ref_dir.exists(), f"{ref_dir} does not exist"
+seq_dir = data_dir / sequence
+assert seq_dir.exists(), f"{seq_dir} does not exist"
+seq_images = seq_dir / "undistorted_images"
+reloc = ref_dir / relocalization_files[sequence]
+
+output_dir = args.outputs
+output_dir.mkdir(exist_ok=True, parents=True)
+query_list = output_dir / f"{sequence}_queries_with_intrinsics.txt"
+ref_pairs = output_dir / "pairs-db-dist20.txt"
+ref_sfm = output_dir / "sfm_superpoint+superglue"
+results_path = output_dir / f"localization_{sequence}_hloc+superglue.txt"
+submission_dir = output_dir / "submission_hloc+superglue"
+
+num_loc_pairs = 10
+loc_pairs = output_dir / f"pairs-query-{sequence}-dist{num_loc_pairs}.txt"
+
+fconf = extract_features.confs["superpoint_max"]
+mconf = match_features.confs["superglue"]
+
+# Not all query images are used for the evaluation.
+# To save time in feature extraction, we delete the unused images.
+timestamps = get_timestamps(reloc, 1)
+delete_unused_images(seq_images, timestamps)
+
+# Generate a list of query images with their intrinsics.
+generate_query_lists(timestamps, seq_dir, query_list)
+
+# Generate the localization pairs from the given reference frames.
+generate_localization_pairs(sequence, reloc, num_loc_pairs, ref_pairs, loc_pairs)
+
+# Extract, match, and localize.
+ffile = extract_features.main(fconf, seq_images, output_dir)
+mfile = match_features.main(mconf, loc_pairs, fconf["output"], output_dir)
+localize_sfm.main(ref_sfm, query_list, loc_pairs, ffile, mfile, results_path)
+
+# Convert the absolute poses to relative poses with the reference frames.
+submission_dir.mkdir(exist_ok=True)
+prepare_submission(results_path, reloc, ref_dir / "poses.txt", submission_dir)
+
+# If not a test sequence, evaluate the localization accuracy.
+if "test" not in sequence:
+ logger.info("Evaluating the relocalization submission...")
+ evaluate_submission(submission_dir, reloc)
diff --git a/imcui/hloc/pipelines/4Seasons/prepare_reference.py b/imcui/hloc/pipelines/4Seasons/prepare_reference.py
new file mode 100644
index 0000000000000000000000000000000000000000..f47aee778ba24ef89a1cc4418f5db9cfab209b9d
--- /dev/null
+++ b/imcui/hloc/pipelines/4Seasons/prepare_reference.py
@@ -0,0 +1,51 @@
+import argparse
+from pathlib import Path
+
+from ... import extract_features, match_features, pairs_from_poses, triangulation
+from .utils import build_empty_colmap_model, delete_unused_images, get_timestamps
+
+parser = argparse.ArgumentParser()
+parser.add_argument(
+ "--dataset",
+ type=Path,
+ default="datasets/4Seasons",
+ help="Path to the dataset, default: %(default)s",
+)
+parser.add_argument(
+ "--outputs",
+ type=Path,
+ default="outputs/4Seasons",
+ help="Path to the output directory, default: %(default)s",
+)
+args = parser.parse_args()
+
+ref_dir = args.dataset / "reference"
+assert ref_dir.exists(), f"{ref_dir} does not exist"
+ref_images = ref_dir / "undistorted_images"
+
+output_dir = args.outputs
+output_dir.mkdir(exist_ok=True, parents=True)
+ref_sfm_empty = output_dir / "sfm_reference_empty"
+ref_sfm = output_dir / "sfm_superpoint+superglue"
+
+num_ref_pairs = 20
+ref_pairs = output_dir / f"pairs-db-dist{num_ref_pairs}.txt"
+
+fconf = extract_features.confs["superpoint_max"]
+mconf = match_features.confs["superglue"]
+
+# Only reference images that have a pose are used in the pipeline.
+# To save time in feature extraction, we delete the unused images.
+delete_unused_images(ref_images, get_timestamps(ref_dir / "poses.txt", 0))
+
+# Build an empty COLMAP model containing only camera and images
+# from the provided poses and intrinsics.
+build_empty_colmap_model(ref_dir, ref_sfm_empty)
+
+# Match reference images that are spatially close.
+pairs_from_poses.main(ref_sfm_empty, ref_pairs, num_ref_pairs)
+
+# Extract, match, and triangulate the reference SfM model.
+ffile = extract_features.main(fconf, ref_images, output_dir)
+mfile = match_features.main(mconf, ref_pairs, fconf["output"], output_dir)
+triangulation.main(ref_sfm, ref_sfm_empty, ref_images, ref_pairs, ffile, mfile)
diff --git a/imcui/hloc/pipelines/4Seasons/utils.py b/imcui/hloc/pipelines/4Seasons/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..e5aace9dd9a31b9c39a58691c0c23795313bc462
--- /dev/null
+++ b/imcui/hloc/pipelines/4Seasons/utils.py
@@ -0,0 +1,231 @@
+import glob
+import logging
+import os
+from pathlib import Path
+
+import numpy as np
+
+from ...utils.parsers import parse_retrieval
+from ...utils.read_write_model import (
+ Camera,
+ Image,
+ qvec2rotmat,
+ rotmat2qvec,
+ write_model,
+)
+
+logger = logging.getLogger(__name__)
+
+
+def get_timestamps(files, idx):
+ """Extract timestamps from a pose or relocalization file."""
+ lines = []
+ for p in files.parent.glob(files.name):
+ with open(p) as f:
+ lines += f.readlines()
+ timestamps = set()
+ for line in lines:
+ line = line.rstrip("\n")
+ if line == "" or line[0] == "#":
+ continue
+ ts = line.replace(",", " ").split()[idx]
+ timestamps.add(ts)
+ return timestamps
+
+
+def delete_unused_images(root, timestamps):
+ """Delete all images in root if they are not contained in timestamps."""
+ images = glob.glob((root / "**/*.png").as_posix(), recursive=True)
+ deleted = 0
+ for image in images:
+ ts = Path(image).stem
+ if ts not in timestamps:
+ os.remove(image)
+ deleted += 1
+ logger.info(f"Deleted {deleted} images in {root}.")
+
+
+def camera_from_calibration_file(id_, path):
+ """Create a COLMAP camera from an MLAD calibration file."""
+ with open(path, "r") as f:
+ data = f.readlines()
+ model, fx, fy, cx, cy = data[0].split()[:5]
+ width, height = data[1].split()
+ assert model == "Pinhole"
+ model_name = "PINHOLE"
+ params = [float(i) for i in [fx, fy, cx, cy]]
+ camera = Camera(
+ id=id_, model=model_name, width=int(width), height=int(height), params=params
+ )
+ return camera
+
+
+def parse_poses(path, colmap=False):
+ """Parse a list of poses in COLMAP or MLAD quaternion convention."""
+ poses = []
+ with open(path) as f:
+ for line in f.readlines():
+ line = line.rstrip("\n")
+ if line == "" or line[0] == "#":
+ continue
+ data = line.replace(",", " ").split()
+ ts, p = data[0], np.array(data[1:], float)
+ if colmap:
+ q, t = np.split(p, [4])
+ else:
+ t, q = np.split(p, [3])
+ q = q[[3, 0, 1, 2]] # xyzw to wxyz
+ R = qvec2rotmat(q)
+ poses.append((ts, R, t))
+ return poses
+
+
+def parse_relocalization(path, has_poses=False):
+ """Parse a relocalization file, possibly with poses."""
+ reloc = []
+ with open(path) as f:
+ for line in f.readlines():
+ line = line.rstrip("\n")
+ if line == "" or line[0] == "#":
+ continue
+ data = line.replace(",", " ").split()
+ out = data[:2] # ref_ts, q_ts
+ if has_poses:
+ assert len(data) == 9
+ t, q = np.split(np.array(data[2:], float), [3])
+ q = q[[3, 0, 1, 2]] # xyzw to wxyz
+ R = qvec2rotmat(q)
+ out += [R, t]
+ reloc.append(out)
+ return reloc
+
+
+def build_empty_colmap_model(root, sfm_dir):
+ """Build a COLMAP model with images and cameras only."""
+ calibration = "Calibration/undistorted_calib_{}.txt"
+ cam0 = camera_from_calibration_file(0, root / calibration.format(0))
+ cam1 = camera_from_calibration_file(1, root / calibration.format(1))
+ cameras = {0: cam0, 1: cam1}
+
+ T_0to1 = np.loadtxt(root / "Calibration/undistorted_calib_stereo.txt")
+ poses = parse_poses(root / "poses.txt")
+ images = {}
+ id_ = 0
+ for ts, R_cam0_to_w, t_cam0_to_w in poses:
+ R_w_to_cam0 = R_cam0_to_w.T
+ t_w_to_cam0 = -(R_w_to_cam0 @ t_cam0_to_w)
+
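+ # Chain the world-to-cam0 pose with the cam0-to-cam1 stereo extrinsics.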
+ R_w_to_cam1 = T_0to1[:3, :3] @ R_w_to_cam0
+ t_w_to_cam1 = T_0to1[:3, :3] @ t_w_to_cam0 + T_0to1[:3, 3]
+
+ for idx, (R_w_to_cam, t_w_to_cam) in enumerate(
+ zip([R_w_to_cam0, R_w_to_cam1], [t_w_to_cam0, t_w_to_cam1])
+ ):
+ image = Image(
+ id=id_,
+ qvec=rotmat2qvec(R_w_to_cam),
+ tvec=t_w_to_cam,
+ camera_id=idx,
+ name=f"cam{idx}/{ts}.png",
+ xys=np.zeros((0, 2), float),
+ point3D_ids=np.full(0, -1, int),
+ )
+ images[id_] = image
+ id_ += 1
+
+ sfm_dir.mkdir(exist_ok=True, parents=True)
+ write_model(cameras, images, {}, path=str(sfm_dir), ext=".bin")
+
+
+def generate_query_lists(timestamps, seq_dir, out_path):
+ """Create a list of query images with intrinsics from timestamps."""
+ cam0 = camera_from_calibration_file(
+ 0, seq_dir / "Calibration/undistorted_calib_0.txt"
+ )
+ intrinsics = [cam0.model, cam0.width, cam0.height] + cam0.params
+ intrinsics = [str(p) for p in intrinsics]
+ data = map(lambda ts: " ".join([f"cam0/{ts}.png"] + intrinsics), timestamps)
+ with open(out_path, "w") as f:
+ f.write("\n".join(data))
+
+
+def generate_localization_pairs(sequence, reloc, num, ref_pairs, out_path):
+ """Create the matching pairs for the localization.
+ We simply lookup the corresponding reference frame
+ and extract its `num` closest frames from the existing pair list.
+ """
+ if "test" in sequence:
+ # hard pairs will be overwritten by easy ones if available
+ relocs = [str(reloc).replace("*", d) for d in ["hard", "moderate", "easy"]]
+ else:
+ relocs = [reloc]
+ query_to_ref_ts = {}
+ for reloc in relocs:
+ with open(reloc, "r") as f:
+ for line in f.readlines():
+ line = line.rstrip("\n")
+ if line == "" or line[0] == "#":
+ continue
+ ref_ts, q_ts = line.split()[:2]
+ query_to_ref_ts[q_ts] = ref_ts
+
+ ts_to_name = "cam0/{}.png".format
+ ref_pairs = parse_retrieval(ref_pairs)
+ loc_pairs = []
+ for q_ts, ref_ts in query_to_ref_ts.items():
+ ref_name = ts_to_name(ref_ts)
+ selected = [ref_name] + ref_pairs[ref_name][: num - 1]
+ loc_pairs.extend([" ".join((ts_to_name(q_ts), s)) for s in selected])
+ with open(out_path, "w") as f:
+ f.write("\n".join(loc_pairs))
+
+
+def prepare_submission(results, relocs, poses_path, out_dir):
+ """Obtain relative poses from estimated absolute and reference poses."""
+ gt_poses = parse_poses(poses_path)
+ all_T_ref0_to_w = {ts: (R, t) for ts, R, t in gt_poses}
+
+ pred_poses = parse_poses(results, colmap=True)
+ all_T_w_to_q0 = {Path(name).stem: (R, t) for name, R, t in pred_poses}
+
+ for reloc in relocs.parent.glob(relocs.name):
+ relative_poses = []
+ reloc_ts = parse_relocalization(reloc)
+ for ref_ts, q_ts in reloc_ts:
+ R_w_to_q0, t_w_to_q0 = all_T_w_to_q0[q_ts]
+ R_ref0_to_w, t_ref0_to_w = all_T_ref0_to_w[ref_ts]
+
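+ # Compose the transforms: T_ref0->q0 = T_w->q0 * T_ref0->w.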
+ R_ref0_to_q0 = R_w_to_q0 @ R_ref0_to_w
+ t_ref0_to_q0 = R_w_to_q0 @ t_ref0_to_w + t_w_to_q0
+
+ tvec = t_ref0_to_q0.tolist()
+ qvec = rotmat2qvec(R_ref0_to_q0)[[1, 2, 3, 0]] # wxyz to xyzw
+
+ out = [ref_ts, q_ts] + list(map(str, tvec)) + list(map(str, qvec))
+ relative_poses.append(" ".join(out))
+
+ out_path = out_dir / reloc.name
+ with open(out_path, "w") as f:
+ f.write("\n".join(relative_poses))
+ logger.info(f"Submission file written to {out_path}.")
+
+
+def evaluate_submission(submission_dir, relocs, ths=[0.1, 0.2, 0.5]):
+ """Compute the relocalization recall from predicted and ground truth poses."""
+ for reloc in relocs.parent.glob(relocs.name):
+ poses_gt = parse_relocalization(reloc, has_poses=True)
+ poses_pred = parse_relocalization(submission_dir / reloc.name, has_poses=True)
+ poses_pred = {(ref_ts, q_ts): (R, t) for ref_ts, q_ts, R, t in poses_pred}
+
+ error = []
+ for ref_ts, q_ts, R_gt, t_gt in poses_gt:
+ R, t = poses_pred[(ref_ts, q_ts)]
+ e = np.linalg.norm(t - t_gt)
+ error.append(e)
+
+ error = np.array(error)
+ recall = [np.mean(error <= th) for th in ths]
+ s = f"Relocalization evaluation {submission_dir.name}/{reloc.name}\n"
+ s += " / ".join([f"{th:>7}m" for th in ths]) + "\n"
+ s += " / ".join([f"{100*r:>7.3f}%" for r in recall])
+ logger.info(s)
diff --git a/imcui/hloc/pipelines/7Scenes/README.md b/imcui/hloc/pipelines/7Scenes/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..2124779c43ec8d1ffc552e07790d39c3578526a9
--- /dev/null
+++ b/imcui/hloc/pipelines/7Scenes/README.md
@@ -0,0 +1,65 @@
+# 7Scenes dataset
+
+## Installation
+
+Download the images from the [7Scenes project page](https://www.microsoft.com/en-us/research/project/rgb-d-dataset-7-scenes/):
+```bash
+export dataset=datasets/7scenes
+for scene in chess fire heads office pumpkin redkitchen stairs; \
+do wget http://download.microsoft.com/download/2/8/5/28564B23-0828-408F-8631-23B1EFF1DAC8/$scene.zip -P $dataset \
+&& unzip $dataset/$scene.zip -d $dataset && unzip $dataset/$scene/'*.zip' -d $dataset/$scene; done
+```
+
+Download the SIFT SfM models and DenseVLAD image pairs, courtesy of Torsten Sattler:
+```bash
+function download {
+wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate "https://docs.google.com/uc?export=download&id=$1" -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=$1" -O $2 && rm -rf /tmp/cookies.txt
+unzip $2 -d $dataset && rm $2;
+}
+download 1cu6KUR7WHO7G4EO49Qi3HEKU6n_yYDjb $dataset/7scenes_sfm_triangulated.zip
+download 1IbS2vLmxr1N0f3CEnd_wsYlgclwTyvB1 $dataset/7scenes_densevlad_retrieval_top_10.zip
+```
+
+Download the rendered depth maps, courtesy of Eric Brachmann for [DSAC\*](https://github.com/vislearn/dsacstar):
+```bash
+wget https://heidata.uni-heidelberg.de/api/access/datafile/4037 -O $dataset/7scenes_rendered_depth.tar.gz
+mkdir $dataset/depth/
+tar xzf $dataset/7scenes_rendered_depth.tar.gz -C $dataset/depth/ && rm $dataset/7scenes_rendered_depth.tar.gz
+```
+
+## Pipeline
+
+```bash
+python3 -m hloc.pipelines.7Scenes.pipeline [--use_dense_depth]
+```
+By default, hloc triangulates a sparse point cloud that can be noisy in indoor environments due to image noise and lack of texture. With the flag `--use_dense_depth`, the pipeline improves the accuracy of the sparse point cloud using dense depth maps provided by the dataset. The original depth maps captured by the RGBD sensor are miscalibrated, so we use depth maps rendered from the mesh obtained by fusing the RGBD data.
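+
+As a rough sketch of what the depth-based correction in `create_gt_sfm.py` does (simplified here, with `K`, `R_w2c`, and `t_w2c` standing in for the pinhole intrinsics and the world-to-camera pose), each keypoint with a valid rendered depth is back-projected to a 3D scene coordinate:
+```python
+import numpy as np
+
+
+def backproject(p2D, depth, K, R_w2c, t_w2c):
+    """Lift pixel keypoints (N x 2) with per-point depth (N,) into world coordinates."""
+    p2D_h = np.concatenate([p2D, np.ones_like(p2D[:, :1])], axis=1)  # homogeneous pixels
+    rays = (np.linalg.inv(K) @ p2D_h.T).T  # normalized camera rays
+    p3D_c = rays * depth[:, None]  # 3D points in the camera frame
+    return (p3D_c - t_w2c) @ R_w2c  # world coordinates (R_w2c maps world -> camera)
+```
+The pipeline then replaces the triangulated 3D points with these back-projected coordinates and drops observations whose rendered depth is missing or inconsistent.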
+
+## Results
+We report the median error in translation/rotation in cm/deg over all scenes:
+| Method \ Scene | Chess | Fire | Heads | Office | Pumpkin | Kitchen | Stairs |
+| ------------------------------- | -------------- | -------------- | -------------- | -------------- | -------------- | -------------- | ---------- |
+| Active Search | 3/0.87 | **2**/1.01 | **1**/0.82 | 4/1.15 | 7/1.69 | 5/1.72 | 4/**1.01** |
+| DSAC* | **2**/1.10 | **2**/1.24 | **1**/1.82 | **3**/1.15 | **4**/1.34 | 4/1.68 | **3**/1.16 |
+| **SuperPoint+SuperGlue** (sfm) | **2**/0.84 | **2**/0.93 | **1**/**0.74** | **3**/0.92 | 5/1.27 | 4/1.40 | 5/1.47 |
+| **SuperPoint+SuperGlue** (RGBD) | **2**/**0.80** | **2**/**0.77** | **1**/0.79 | **3**/**0.80** | **4**/**1.07** | **3**/**1.13** | 4/1.15 |
+
+## Citation
+Please cite the following paper if you use the 7Scenes dataset:
+```
+@inproceedings{shotton2013scene,
+ title={Scene coordinate regression forests for camera relocalization in {RGB-D} images},
+ author={Shotton, Jamie and Glocker, Ben and Zach, Christopher and Izadi, Shahram and Criminisi, Antonio and Fitzgibbon, Andrew},
+ booktitle={CVPR},
+ year={2013}
+}
+```
+
+Also cite DSAC* if you use dense depth maps with the flag `--use_dense_depth`:
+```
+@article{brachmann2020dsacstar,
+ title={Visual Camera Re-Localization from {RGB} and {RGB-D} Images Using {DSAC}},
+ author={Brachmann, Eric and Rother, Carsten},
+ journal={TPAMI},
+ year={2021}
+}
+```
diff --git a/imcui/hloc/pipelines/7Scenes/__init__.py b/imcui/hloc/pipelines/7Scenes/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/imcui/hloc/pipelines/7Scenes/create_gt_sfm.py b/imcui/hloc/pipelines/7Scenes/create_gt_sfm.py
new file mode 100644
index 0000000000000000000000000000000000000000..95dfa461e17de99e0bdde0c52c5f02568c4fbab3
--- /dev/null
+++ b/imcui/hloc/pipelines/7Scenes/create_gt_sfm.py
@@ -0,0 +1,134 @@
+from pathlib import Path
+
+import numpy as np
+import PIL.Image
+import pycolmap
+import torch
+from tqdm import tqdm
+
+from ...utils.read_write_model import read_model, write_model
+
+
+def scene_coordinates(p2D, R_w2c, t_w2c, depth, camera):
+ assert len(depth) == len(p2D)
+ p2D_norm = np.stack(pycolmap.Camera(camera._asdict()).image_to_world(p2D))
+ p2D_h = np.concatenate([p2D_norm, np.ones_like(p2D_norm[:, :1])], 1)
+ p3D_c = p2D_h * depth[:, None]
+ p3D_w = (p3D_c - t_w2c) @ R_w2c
+ return p3D_w
+
+
+def interpolate_depth(depth, kp):
+ h, w = depth.shape
+ kp = kp / np.array([[w - 1, h - 1]]) * 2 - 1
+ assert np.all(kp > -1) and np.all(kp < 1)
+ depth = torch.from_numpy(depth)[None, None]
+ kp = torch.from_numpy(kp)[None, None]
+ grid_sample = torch.nn.functional.grid_sample
+
+ # To maximize the number of points that have depth:
+ # do bilinear interpolation first and then nearest for the remaining points
+ interp_lin = grid_sample(depth, kp, align_corners=True, mode="bilinear")[0, :, 0]
+ interp_nn = torch.nn.functional.grid_sample(
+ depth, kp, align_corners=True, mode="nearest"
+ )[0, :, 0]
+ interp = torch.where(torch.isnan(interp_lin), interp_nn, interp_lin)
+ valid = ~torch.any(torch.isnan(interp), 0)
+
+ interp_depth = interp.T.numpy().flatten()
+ valid = valid.numpy()
+ return interp_depth, valid
+
+
+def image_path_to_rendered_depth_path(image_name):
+ parts = image_name.split("/")
+ name = "_".join(["".join(parts[0].split("-")), parts[1]])
+ name = name.replace("color", "pose")
+ name = name.replace("png", "depth.tiff")
+ return name
+
+
+def project_to_image(p3D, R, t, camera, eps: float = 1e-4, pad: int = 1):
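+    # Project world points into the image; keep only points in front of the camera
+    # and at least `pad` pixels away from the image border.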
+ p3D = (p3D @ R.T) + t
+ visible = p3D[:, -1] >= eps # keep points in front of the camera
+ p2D_norm = p3D[:, :-1] / p3D[:, -1:].clip(min=eps)
+ p2D = np.stack(pycolmap.Camera(camera._asdict()).world_to_image(p2D_norm))
+ size = np.array([camera.width - pad - 1, camera.height - pad - 1])
+ valid = np.all((p2D >= pad) & (p2D <= size), -1)
+ valid &= visible
+ return p2D[valid], valid
+
+
+def correct_sfm_with_gt_depth(sfm_path, depth_folder_path, output_path):
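+    # For each image: project its 3D points into the image, sample the rendered
+    # depth map, drop observations without valid depth, and replace the point
+    # positions with the depth-backprojected scene coordinates.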
+ cameras, images, points3D = read_model(sfm_path)
+ for imgid, img in tqdm(images.items()):
+ image_name = img.name
+ depth_name = image_path_to_rendered_depth_path(image_name)
+
+ depth = PIL.Image.open(Path(depth_folder_path) / depth_name)
+ depth = np.array(depth).astype("float64")
+ depth = depth / 1000.0 # mm to meter
+ depth[(depth == 0.0) | (depth > 1000.0)] = np.nan
+
+ R_w2c, t_w2c = img.qvec2rotmat(), img.tvec
+ camera = cameras[img.camera_id]
+ p3D_ids = img.point3D_ids
+ p3Ds = np.stack([points3D[i].xyz for i in p3D_ids[p3D_ids != -1]], 0)
+
+ p2Ds, valids_projected = project_to_image(p3Ds, R_w2c, t_w2c, camera)
+ invalid_p3D_ids = p3D_ids[p3D_ids != -1][~valids_projected]
+ interp_depth, valids_backprojected = interpolate_depth(depth, p2Ds)
+ scs = scene_coordinates(
+ p2Ds[valids_backprojected],
+ R_w2c,
+ t_w2c,
+ interp_depth[valids_backprojected],
+ camera,
+ )
+ invalid_p3D_ids = np.append(
+ invalid_p3D_ids,
+ p3D_ids[p3D_ids != -1][valids_projected][~valids_backprojected],
+ )
+ for p3did in invalid_p3D_ids:
+ if p3did == -1:
+ continue
+ else:
+ obs_imgids = points3D[p3did].image_ids
+ invalid_imgids = list(np.where(obs_imgids == img.id)[0])
+ points3D[p3did] = points3D[p3did]._replace(
+ image_ids=np.delete(obs_imgids, invalid_imgids),
+ point2D_idxs=np.delete(
+ points3D[p3did].point2D_idxs, invalid_imgids
+ ),
+ )
+
+ new_p3D_ids = p3D_ids.copy()
+ sub_p3D_ids = new_p3D_ids[new_p3D_ids != -1]
+ valids = np.ones(np.count_nonzero(new_p3D_ids != -1), dtype=bool)
+ valids[~valids_projected] = False
+ valids[valids_projected] = valids_backprojected
+ sub_p3D_ids[~valids] = -1
+ new_p3D_ids[new_p3D_ids != -1] = sub_p3D_ids
+ img = img._replace(point3D_ids=new_p3D_ids)
+
+ assert len(img.point3D_ids[img.point3D_ids != -1]) == len(
+ scs
+ ), f"{len(scs)}, {len(img.point3D_ids[img.point3D_ids != -1])}"
+ for i, p3did in enumerate(img.point3D_ids[img.point3D_ids != -1]):
+ points3D[p3did] = points3D[p3did]._replace(xyz=scs[i])
+ images[imgid] = img
+
+ output_path.mkdir(parents=True, exist_ok=True)
+ write_model(cameras, images, points3D, output_path)
+
+
+if __name__ == "__main__":
+ dataset = Path("datasets/7scenes")
+ outputs = Path("outputs/7Scenes")
+
+ SCENES = ["chess", "fire", "heads", "office", "pumpkin", "redkitchen", "stairs"]
+ for scene in SCENES:
+ sfm_path = outputs / scene / "sfm_superpoint+superglue"
+ depth_path = dataset / f"depth/7scenes_{scene}/train/depth"
+ output_path = outputs / scene / "sfm_superpoint+superglue+depth"
+ correct_sfm_with_gt_depth(sfm_path, depth_path, output_path)
diff --git a/imcui/hloc/pipelines/7Scenes/pipeline.py b/imcui/hloc/pipelines/7Scenes/pipeline.py
new file mode 100644
index 0000000000000000000000000000000000000000..6fc28c6d29828ddbe4b9efaf1678be624a998819
--- /dev/null
+++ b/imcui/hloc/pipelines/7Scenes/pipeline.py
@@ -0,0 +1,139 @@
+import argparse
+from pathlib import Path
+
+from ... import (
+ extract_features,
+ localize_sfm,
+ logger,
+ match_features,
+ pairs_from_covisibility,
+ triangulation,
+)
+from ..Cambridge.utils import create_query_list_with_intrinsics, evaluate
+from .create_gt_sfm import correct_sfm_with_gt_depth
+from .utils import create_reference_sfm
+
+SCENES = ["chess", "fire", "heads", "office", "pumpkin", "redkitchen", "stairs"]
+
+
+def run_scene(
+ images,
+ gt_dir,
+ retrieval,
+ outputs,
+ results,
+ num_covis,
+ use_dense_depth,
+ depth_dir=None,
+):
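+    # Triangulate a SuperPoint+SuperGlue reference model from the SIFT poses,
+    # optionally refine it with rendered depth maps, then localize the queries.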
+ outputs.mkdir(exist_ok=True, parents=True)
+ ref_sfm_sift = outputs / "sfm_sift"
+ ref_sfm = outputs / "sfm_superpoint+superglue"
+ query_list = outputs / "query_list_with_intrinsics.txt"
+
+ feature_conf = {
+ "output": "feats-superpoint-n4096-r1024",
+ "model": {
+ "name": "superpoint",
+ "nms_radius": 3,
+ "max_keypoints": 4096,
+ },
+ "preprocessing": {
+ "globs": ["*.color.png"],
+ "grayscale": True,
+ "resize_max": 1024,
+ },
+ }
+ matcher_conf = match_features.confs["superglue"]
+ matcher_conf["model"]["sinkhorn_iterations"] = 5
+
+ test_list = gt_dir / "list_test.txt"
+ create_reference_sfm(gt_dir, ref_sfm_sift, test_list)
+ create_query_list_with_intrinsics(gt_dir, query_list, test_list)
+
+ features = extract_features.main(feature_conf, images, outputs, as_half=True)
+
+ sfm_pairs = outputs / f"pairs-db-covis{num_covis}.txt"
+ pairs_from_covisibility.main(ref_sfm_sift, sfm_pairs, num_matched=num_covis)
+ sfm_matches = match_features.main(
+ matcher_conf, sfm_pairs, feature_conf["output"], outputs
+ )
+ if not (use_dense_depth and ref_sfm.exists()):
+ triangulation.main(
+ ref_sfm, ref_sfm_sift, images, sfm_pairs, features, sfm_matches
+ )
+ if use_dense_depth:
+ assert depth_dir is not None
+ ref_sfm_fix = outputs / "sfm_superpoint+superglue+depth"
+ correct_sfm_with_gt_depth(ref_sfm, depth_dir, ref_sfm_fix)
+ ref_sfm = ref_sfm_fix
+
+ loc_matches = match_features.main(
+ matcher_conf, retrieval, feature_conf["output"], outputs
+ )
+
+ localize_sfm.main(
+ ref_sfm,
+ query_list,
+ retrieval,
+ features,
+ loc_matches,
+ results,
+ covisibility_clustering=False,
+ prepend_camera_name=True,
+ )
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--scenes", default=SCENES, choices=SCENES, nargs="+")
+ parser.add_argument("--overwrite", action="store_true")
+ parser.add_argument(
+ "--dataset",
+ type=Path,
+ default="datasets/7scenes",
+ help="Path to the dataset, default: %(default)s",
+ )
+ parser.add_argument(
+ "--outputs",
+ type=Path,
+ default="outputs/7scenes",
+ help="Path to the output directory, default: %(default)s",
+ )
+ parser.add_argument("--use_dense_depth", action="store_true")
+ parser.add_argument(
+ "--num_covis",
+ type=int,
+ default=30,
+ help="Number of image pairs for SfM, default: %(default)s",
+ )
+ args = parser.parse_args()
+
+ gt_dirs = args.dataset / "7scenes_sfm_triangulated/{scene}/triangulated"
+ retrieval_dirs = args.dataset / "7scenes_densevlad_retrieval_top_10"
+
+ all_results = {}
+ for scene in args.scenes:
+ logger.info(f'Working on scene "{scene}".')
+ results = (
+ args.outputs
+ / scene
+ / "results_{}.txt".format("dense" if args.use_dense_depth else "sparse")
+ )
+ if args.overwrite or not results.exists():
+ run_scene(
+ args.dataset / scene,
+ Path(str(gt_dirs).format(scene=scene)),
+ retrieval_dirs / f"{scene}_top10.txt",
+ args.outputs / scene,
+ results,
+ args.num_covis,
+ args.use_dense_depth,
+ depth_dir=args.dataset / f"depth/7scenes_{scene}/train/depth",
+ )
+ all_results[scene] = results
+
+ for scene in args.scenes:
+ logger.info(f'Evaluate scene "{scene}".')
+ gt_dir = Path(str(gt_dirs).format(scene=scene))
+ evaluate(gt_dir, all_results[scene], gt_dir / "list_test.txt")
diff --git a/imcui/hloc/pipelines/7Scenes/utils.py b/imcui/hloc/pipelines/7Scenes/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..1cb021286e550de6fe89e03370c11e3f7d567c5f
--- /dev/null
+++ b/imcui/hloc/pipelines/7Scenes/utils.py
@@ -0,0 +1,34 @@
+import logging
+
+import numpy as np
+
+from hloc.utils.read_write_model import read_model, write_model
+
+logger = logging.getLogger(__name__)
+
+
+def create_reference_sfm(full_model, ref_model, blacklist=None, ext=".bin"):
+ """Create a new COLMAP model with only training images."""
+ logger.info("Creating the reference model.")
+ ref_model.mkdir(exist_ok=True)
+ cameras, images, points3D = read_model(full_model, ext)
+
+ if blacklist is not None:
+ with open(blacklist, "r") as f:
+ blacklist = f.read().rstrip().split("\n")
+
+ images_ref = dict()
+ for id_, image in images.items():
+ if blacklist and image.name in blacklist:
+ continue
+ images_ref[id_] = image
+
+ points3D_ref = dict()
+ for id_, point3D in points3D.items():
+ ref_ids = [i for i in point3D.image_ids if i in images_ref]
+ if len(ref_ids) == 0:
+ continue
+ points3D_ref[id_] = point3D._replace(image_ids=np.array(ref_ids))
+
+ write_model(cameras, images_ref, points3D_ref, ref_model, ".bin")
+ logger.info(f"Kept {len(images_ref)} images out of {len(images)}.")
diff --git a/imcui/hloc/pipelines/Aachen/README.md b/imcui/hloc/pipelines/Aachen/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..57b66d6ad1e5cdb3e74c6c1866d394d487c608d1
--- /dev/null
+++ b/imcui/hloc/pipelines/Aachen/README.md
@@ -0,0 +1,16 @@
+# Aachen-Day-Night dataset
+
+## Installation
+
+Download the dataset from [visuallocalization.net](https://www.visuallocalization.net):
+```bash
+export dataset=datasets/aachen
+wget -r -np -nH -R "index.html*,aachen_v1_1.zip" --cut-dirs=4 https://data.ciirc.cvut.cz/public/projects/2020VisualLocalization/Aachen-Day-Night/ -P $dataset
+unzip $dataset/images/database_and_query_images.zip -d $dataset
+```
+
+## Pipeline
+
+```bash
+python3 -m hloc.pipelines.Aachen.pipeline
+```
diff --git a/imcui/hloc/pipelines/Aachen/__init__.py b/imcui/hloc/pipelines/Aachen/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/imcui/hloc/pipelines/Aachen/pipeline.py b/imcui/hloc/pipelines/Aachen/pipeline.py
new file mode 100644
index 0000000000000000000000000000000000000000..e31ce7255ce5e178f505a6fb415b5bbf13b76879
--- /dev/null
+++ b/imcui/hloc/pipelines/Aachen/pipeline.py
@@ -0,0 +1,109 @@
+import argparse
+from pathlib import Path
+from pprint import pformat
+
+from ... import (
+ colmap_from_nvm,
+ extract_features,
+ localize_sfm,
+ logger,
+ match_features,
+ pairs_from_covisibility,
+ pairs_from_retrieval,
+ triangulation,
+)
+
+
+def run(args):
+ # Setup the paths
+ dataset = args.dataset
+ images = dataset / "images_upright/"
+
+ outputs = args.outputs # where everything will be saved
+ sift_sfm = outputs / "sfm_sift" # from which we extract the reference poses
+ reference_sfm = outputs / "sfm_superpoint+superglue" # the SfM model we will build
+ sfm_pairs = (
+ outputs / f"pairs-db-covis{args.num_covis}.txt"
+ ) # top-k most covisible in SIFT model
+ loc_pairs = (
+ outputs / f"pairs-query-netvlad{args.num_loc}.txt"
+ ) # top-k retrieved by NetVLAD
+ results = outputs / f"Aachen_hloc_superpoint+superglue_netvlad{args.num_loc}.txt"
+
+ # list the standard configurations available
+ logger.info("Configs for feature extractors:\n%s", pformat(extract_features.confs))
+ logger.info("Configs for feature matchers:\n%s", pformat(match_features.confs))
+
+ # pick one of the configurations for extraction and matching
+ retrieval_conf = extract_features.confs["netvlad"]
+ feature_conf = extract_features.confs["superpoint_aachen"]
+ matcher_conf = match_features.confs["superglue"]
+
+ features = extract_features.main(feature_conf, images, outputs)
+
+ colmap_from_nvm.main(
+ dataset / "3D-models/aachen_cvpr2018_db.nvm",
+ dataset / "3D-models/database_intrinsics.txt",
+ dataset / "aachen.db",
+ sift_sfm,
+ )
+ pairs_from_covisibility.main(sift_sfm, sfm_pairs, num_matched=args.num_covis)
+ sfm_matches = match_features.main(
+ matcher_conf, sfm_pairs, feature_conf["output"], outputs
+ )
+
+ triangulation.main(
+ reference_sfm, sift_sfm, images, sfm_pairs, features, sfm_matches
+ )
+
+ global_descriptors = extract_features.main(retrieval_conf, images, outputs)
+ pairs_from_retrieval.main(
+ global_descriptors,
+ loc_pairs,
+ args.num_loc,
+ query_prefix="query",
+ db_model=reference_sfm,
+ )
+ loc_matches = match_features.main(
+ matcher_conf, loc_pairs, feature_conf["output"], outputs
+ )
+
+ localize_sfm.main(
+ reference_sfm,
+ dataset / "queries/*_time_queries_with_intrinsics.txt",
+ loc_pairs,
+ features,
+ loc_matches,
+ results,
+ covisibility_clustering=False,
+ ) # not required with SuperPoint+SuperGlue
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--dataset",
+ type=Path,
+ default="datasets/aachen",
+ help="Path to the dataset, default: %(default)s",
+ )
+ parser.add_argument(
+ "--outputs",
+ type=Path,
+ default="outputs/aachen",
+ help="Path to the output directory, default: %(default)s",
+ )
+ parser.add_argument(
+ "--num_covis",
+ type=int,
+ default=20,
+ help="Number of image pairs for SfM, default: %(default)s",
+ )
+ parser.add_argument(
+ "--num_loc",
+ type=int,
+ default=50,
+ help="Number of image pairs for loc, default: %(default)s",
+ )
+ args = parser.parse_args()
+ run(args)
diff --git a/imcui/hloc/pipelines/Aachen_v1_1/README.md b/imcui/hloc/pipelines/Aachen_v1_1/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..c17e751777b56e36b8633c1eec37ff656f2d3979
--- /dev/null
+++ b/imcui/hloc/pipelines/Aachen_v1_1/README.md
@@ -0,0 +1,17 @@
+# Aachen-Day-Night dataset v1.1
+
+## Installation
+
+Download the dataset from [visuallocalization.net](https://www.visuallocalization.net):
+```bash
+export dataset=datasets/aachen_v1.1
+wget -r -np -nH -R "index.html*" --cut-dirs=4 https://data.ciirc.cvut.cz/public/projects/2020VisualLocalization/Aachen-Day-Night/ -P $dataset
+unzip $dataset/images/database_and_query_images.zip -d $dataset
+unzip $dataset/aachen_v1_1.zip -d $dataset
+```
+
+## Pipeline
+
+```bash
+python3 -m hloc.pipelines.Aachen_v1_1.pipeline
+```
diff --git a/imcui/hloc/pipelines/Aachen_v1_1/__init__.py b/imcui/hloc/pipelines/Aachen_v1_1/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/imcui/hloc/pipelines/Aachen_v1_1/pipeline.py b/imcui/hloc/pipelines/Aachen_v1_1/pipeline.py
new file mode 100644
index 0000000000000000000000000000000000000000..0753604624d31984952942ab5b297a247e4d5123
--- /dev/null
+++ b/imcui/hloc/pipelines/Aachen_v1_1/pipeline.py
@@ -0,0 +1,104 @@
+import argparse
+from pathlib import Path
+from pprint import pformat
+
+from ... import (
+ extract_features,
+ localize_sfm,
+ logger,
+ match_features,
+ pairs_from_covisibility,
+ pairs_from_retrieval,
+ triangulation,
+)
+
+
+def run(args):
+ # Setup the paths
+ dataset = args.dataset
+ images = dataset / "images_upright/"
+ sift_sfm = dataset / "3D-models/aachen_v_1_1"
+
+ outputs = args.outputs # where everything will be saved
+ reference_sfm = outputs / "sfm_superpoint+superglue" # the SfM model we will build
+ sfm_pairs = (
+ outputs / f"pairs-db-covis{args.num_covis}.txt"
+ ) # top-k most covisible in SIFT model
+ loc_pairs = (
+ outputs / f"pairs-query-netvlad{args.num_loc}.txt"
+ ) # top-k retrieved by NetVLAD
+ results = (
+ outputs / f"Aachen-v1.1_hloc_superpoint+superglue_netvlad{args.num_loc}.txt"
+ )
+
+ # list the standard configurations available
+ logger.info("Configs for feature extractors:\n%s", pformat(extract_features.confs))
+ logger.info("Configs for feature matchers:\n%s", pformat(match_features.confs))
+
+ # pick one of the configurations for extraction and matching
+ retrieval_conf = extract_features.confs["netvlad"]
+ feature_conf = extract_features.confs["superpoint_max"]
+ matcher_conf = match_features.confs["superglue"]
+
+ features = extract_features.main(feature_conf, images, outputs)
+
+ pairs_from_covisibility.main(sift_sfm, sfm_pairs, num_matched=args.num_covis)
+ sfm_matches = match_features.main(
+ matcher_conf, sfm_pairs, feature_conf["output"], outputs
+ )
+
+ triangulation.main(
+ reference_sfm, sift_sfm, images, sfm_pairs, features, sfm_matches
+ )
+
+ global_descriptors = extract_features.main(retrieval_conf, images, outputs)
+ pairs_from_retrieval.main(
+ global_descriptors,
+ loc_pairs,
+ args.num_loc,
+ query_prefix="query",
+ db_model=reference_sfm,
+ )
+ loc_matches = match_features.main(
+ matcher_conf, loc_pairs, feature_conf["output"], outputs
+ )
+
+ localize_sfm.main(
+ reference_sfm,
+ dataset / "queries/*_time_queries_with_intrinsics.txt",
+ loc_pairs,
+ features,
+ loc_matches,
+ results,
+ covisibility_clustering=False,
+ ) # not required with SuperPoint+SuperGlue
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--dataset",
+ type=Path,
+ default="datasets/aachen_v1.1",
+ help="Path to the dataset, default: %(default)s",
+ )
+ parser.add_argument(
+ "--outputs",
+ type=Path,
+ default="outputs/aachen_v1.1",
+ help="Path to the output directory, default: %(default)s",
+ )
+ parser.add_argument(
+ "--num_covis",
+ type=int,
+ default=20,
+ help="Number of image pairs for SfM, default: %(default)s",
+ )
+ parser.add_argument(
+ "--num_loc",
+ type=int,
+ default=50,
+ help="Number of image pairs for loc, default: %(default)s",
+ )
+ args = parser.parse_args()
+ run(args)
diff --git a/imcui/hloc/pipelines/Aachen_v1_1/pipeline_loftr.py b/imcui/hloc/pipelines/Aachen_v1_1/pipeline_loftr.py
new file mode 100644
index 0000000000000000000000000000000000000000..9c0a769897604fa9838106c7a38ba585ceeefe5c
--- /dev/null
+++ b/imcui/hloc/pipelines/Aachen_v1_1/pipeline_loftr.py
@@ -0,0 +1,104 @@
+import argparse
+from pathlib import Path
+from pprint import pformat
+
+from ... import (
+ extract_features,
+ localize_sfm,
+ logger,
+ match_dense,
+ pairs_from_covisibility,
+ pairs_from_retrieval,
+ triangulation,
+)
+
+
+def run(args):
+ # Setup the paths
+ dataset = args.dataset
+ images = dataset / "images_upright/"
+ sift_sfm = dataset / "3D-models/aachen_v_1_1"
+
+ outputs = args.outputs # where everything will be saved
+    outputs.mkdir(exist_ok=True, parents=True)
+ reference_sfm = outputs / "sfm_loftr" # the SfM model we will build
+ sfm_pairs = (
+ outputs / f"pairs-db-covis{args.num_covis}.txt"
+ ) # top-k most covisible in SIFT model
+ loc_pairs = (
+ outputs / f"pairs-query-netvlad{args.num_loc}.txt"
+ ) # top-k retrieved by NetVLAD
+ results = outputs / f"Aachen-v1.1_hloc_loftr_netvlad{args.num_loc}.txt"
+
+ # list the standard configurations available
+ logger.info("Configs for dense feature matchers:\n%s", pformat(match_dense.confs))
+
+ # pick one of the configurations for extraction and matching
+ retrieval_conf = extract_features.confs["netvlad"]
+ matcher_conf = match_dense.confs["loftr_aachen"]
+
+ pairs_from_covisibility.main(sift_sfm, sfm_pairs, num_matched=args.num_covis)
+ features, sfm_matches = match_dense.main(
+ matcher_conf, sfm_pairs, images, outputs, max_kps=8192, overwrite=False
+ )
+
+ triangulation.main(
+ reference_sfm, sift_sfm, images, sfm_pairs, features, sfm_matches
+ )
+
+ global_descriptors = extract_features.main(retrieval_conf, images, outputs)
+ pairs_from_retrieval.main(
+ global_descriptors,
+ loc_pairs,
+ args.num_loc,
+ query_prefix="query",
+ db_model=reference_sfm,
+ )
+ features, loc_matches = match_dense.main(
+ matcher_conf,
+ loc_pairs,
+ images,
+ outputs,
+ features=features,
+ max_kps=None,
+ matches=sfm_matches,
+ )
+
+ localize_sfm.main(
+ reference_sfm,
+ dataset / "queries/*_time_queries_with_intrinsics.txt",
+ loc_pairs,
+ features,
+ loc_matches,
+ results,
+ covisibility_clustering=False,
+ ) # not required with loftr
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--dataset",
+ type=Path,
+ default="datasets/aachen_v1.1",
+ help="Path to the dataset, default: %(default)s",
+ )
+ parser.add_argument(
+ "--outputs",
+ type=Path,
+ default="outputs/aachen_v1.1",
+ help="Path to the output directory, default: %(default)s",
+ )
+ parser.add_argument(
+ "--num_covis",
+ type=int,
+ default=20,
+ help="Number of image pairs for SfM, default: %(default)s",
+ )
+ parser.add_argument(
+ "--num_loc",
+ type=int,
+ default=50,
+ help="Number of image pairs for loc, default: %(default)s",
+ )
+ args = parser.parse_args()
diff --git a/imcui/hloc/pipelines/CMU/README.md b/imcui/hloc/pipelines/CMU/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..566ba352c53ada2a13dce21c8ec1041b56969d03
--- /dev/null
+++ b/imcui/hloc/pipelines/CMU/README.md
@@ -0,0 +1,16 @@
+# Extended CMU Seasons dataset
+
+## Installation
+
+Download the dataset from [visuallocalization.net](https://www.visuallocalization.net):
+```bash
+export dataset=datasets/cmu_extended
+wget -r -np -nH -R "index.html*" --cut-dirs=4 https://data.ciirc.cvut.cz/public/projects/2020VisualLocalization/Extended-CMU-Seasons/ -P $dataset
+for slice in $dataset/*.tar; do tar -xf $slice -C $dataset && rm $slice; done
+```
+
+## Pipeline
+
+```bash
+python3 -m hloc.pipelines.CMU.pipeline
+```
diff --git a/imcui/hloc/pipelines/CMU/__init__.py b/imcui/hloc/pipelines/CMU/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/imcui/hloc/pipelines/CMU/pipeline.py b/imcui/hloc/pipelines/CMU/pipeline.py
new file mode 100644
index 0000000000000000000000000000000000000000..4706a05c6aef134dc244419501b22a2ba95ede04
--- /dev/null
+++ b/imcui/hloc/pipelines/CMU/pipeline.py
@@ -0,0 +1,133 @@
+import argparse
+from pathlib import Path
+
+from ... import (
+ extract_features,
+ localize_sfm,
+ logger,
+ match_features,
+ pairs_from_covisibility,
+ pairs_from_retrieval,
+ triangulation,
+)
+
+TEST_SLICES = [2, 3, 4, 5, 6, 13, 14, 15, 16, 17, 18, 19, 20, 21]
+
+
+def generate_query_list(dataset, path, slice_):
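+    # Read the intrinsics of the two cameras and assign them to each query image
+    # based on the camera identifier embedded in its filename.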
+ cameras = {}
+ with open(dataset / "intrinsics.txt", "r") as f:
+ for line in f.readlines():
+ if line[0] == "#" or line == "\n":
+ continue
+ data = line.split()
+ cameras[data[0]] = data[1:]
+ assert len(cameras) == 2
+
+ queries = dataset / f"{slice_}/test-images-{slice_}.txt"
+ with open(queries, "r") as f:
+ queries = [q.rstrip("\n") for q in f.readlines()]
+
+ out = [[q] + cameras[q.split("_")[2]] for q in queries]
+ with open(path, "w") as f:
+ f.write("\n".join(map(" ".join, out)))
+
+
+def run_slice(slice_, root, outputs, num_covis, num_loc):
+ dataset = root / slice_
+ ref_images = dataset / "database"
+ query_images = dataset / "query"
+ sift_sfm = dataset / "sparse"
+
+ outputs = outputs / slice_
+ outputs.mkdir(exist_ok=True, parents=True)
+ query_list = dataset / "queries_with_intrinsics.txt"
+ sfm_pairs = outputs / f"pairs-db-covis{num_covis}.txt"
+ loc_pairs = outputs / f"pairs-query-netvlad{num_loc}.txt"
+ ref_sfm = outputs / "sfm_superpoint+superglue"
+ results = outputs / f"CMU_hloc_superpoint+superglue_netvlad{num_loc}.txt"
+
+ # pick one of the configurations for extraction and matching
+ retrieval_conf = extract_features.confs["netvlad"]
+ feature_conf = extract_features.confs["superpoint_aachen"]
+ matcher_conf = match_features.confs["superglue"]
+
+ pairs_from_covisibility.main(sift_sfm, sfm_pairs, num_matched=num_covis)
+ features = extract_features.main(feature_conf, ref_images, outputs, as_half=True)
+ sfm_matches = match_features.main(
+ matcher_conf, sfm_pairs, feature_conf["output"], outputs
+ )
+ triangulation.main(ref_sfm, sift_sfm, ref_images, sfm_pairs, features, sfm_matches)
+
+ generate_query_list(root, query_list, slice_)
+ global_descriptors = extract_features.main(retrieval_conf, ref_images, outputs)
+ global_descriptors = extract_features.main(retrieval_conf, query_images, outputs)
+ pairs_from_retrieval.main(
+ global_descriptors, loc_pairs, num_loc, query_list=query_list, db_model=ref_sfm
+ )
+
+ features = extract_features.main(feature_conf, query_images, outputs, as_half=True)
+ loc_matches = match_features.main(
+ matcher_conf, loc_pairs, feature_conf["output"], outputs
+ )
+
+ localize_sfm.main(
+ ref_sfm,
+ dataset / "queries/*_time_queries_with_intrinsics.txt",
+ loc_pairs,
+ features,
+ loc_matches,
+ results,
+ )
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--slices",
+ type=str,
+ default="*",
+ help="a single number, an interval (e.g. 2-6), "
+ "or a Python-style list or int (e.g. [2, 3, 4]",
+ )
+ parser.add_argument(
+ "--dataset",
+ type=Path,
+ default="datasets/cmu_extended",
+ help="Path to the dataset, default: %(default)s",
+ )
+ parser.add_argument(
+ "--outputs",
+ type=Path,
+ default="outputs/aachen_extended",
+ help="Path to the output directory, default: %(default)s",
+ )
+ parser.add_argument(
+ "--num_covis",
+ type=int,
+ default=20,
+ help="Number of image pairs for SfM, default: %(default)s",
+ )
+ parser.add_argument(
+ "--num_loc",
+ type=int,
+ default=10,
+ help="Number of image pairs for loc, default: %(default)s",
+ )
+ args = parser.parse_args()
+
+ if args.slice == "*":
+ slices = TEST_SLICES
+ if "-" in args.slices:
+ min_, max_ = args.slices.split("-")
+ slices = list(range(int(min_), int(max_) + 1))
+ else:
+ slices = eval(args.slices)
+ if isinstance(slices, int):
+ slices = [slices]
+
+ for slice_ in slices:
+ logger.info("Working on slice %s.", slice_)
+ run_slice(
+ f"slice{slice_}", args.dataset, args.outputs, args.num_covis, args.num_loc
+ )
diff --git a/imcui/hloc/pipelines/Cambridge/README.md b/imcui/hloc/pipelines/Cambridge/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..d5ae07b71c48a98fa9235f0dfb0234c3c18c74c6
--- /dev/null
+++ b/imcui/hloc/pipelines/Cambridge/README.md
@@ -0,0 +1,47 @@
+# Cambridge Landmarks dataset
+
+## Installation
+
+Download the dataset from the [PoseNet project page](http://mi.eng.cam.ac.uk/projects/relocalisation/):
+```bash
+export dataset=datasets/cambridge
+export scenes=( "KingsCollege" "OldHospital" "StMarysChurch" "ShopFacade" "GreatCourt" )
+export IDs=( "251342" "251340" "251294" "251336" "251291" )
+for i in "${!scenes[@]}"; do
+wget https://www.repository.cam.ac.uk/bitstream/handle/1810/${IDs[i]}/${scenes[i]}.zip -P $dataset \
+&& unzip $dataset/${scenes[i]}.zip -d $dataset && rm $dataset/${scenes[i]}.zip; done
+```
+
+Download the SIFT SfM models, courtesy of Torsten Sattler:
+```bash
+export fileid=1esqzZ1zEQlzZVic-H32V6kkZvc4NeS15
+export filename=$dataset/CambridgeLandmarks_Colmap_Retriangulated_1024px.zip
+wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate "https://docs.google.com/uc?export=download&id=$fileid" -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=$fileid" -O $filename && rm -rf /tmp/cookies.txt
+unzip $filename -d $dataset
+```
+
+## Pipeline
+
+```bash
+python3 -m hloc.pipelines.Cambridge.pipeline
+```
+
+## Results
+We report the median errors in translation/rotation in cm/deg for each scene:
+| Method \ Scene | Court | King's | Hospital | Shop | St. Mary's |
+| ------------------------ | --------------- | --------------- | --------------- | -------------- | -------------- |
+| Active Search | 24/0.13 | 13/0.22 | 20/0.36 | **4**/0.21 | 8/0.25 |
+| DSAC* | 49/0.3 | 15/0.3 | 21/0.4 | 5/0.3 | 13/0.4 |
+| **SuperPoint+SuperGlue** | **17**/**0.11** | **12**/**0.21** | **14**/**0.30** | **4**/**0.19** | **7**/**0.22** |
+
+## Citation
+
+Please cite the following paper if you use the Cambridge Landmarks dataset:
+```
+@inproceedings{kendall2015posenet,
+ title={{PoseNet}: A convolutional network for real-time {6-DoF} camera relocalization},
+ author={Kendall, Alex and Grimes, Matthew and Cipolla, Roberto},
+ booktitle={ICCV},
+ year={2015}
+}
+```
diff --git a/imcui/hloc/pipelines/Cambridge/__init__.py b/imcui/hloc/pipelines/Cambridge/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/imcui/hloc/pipelines/Cambridge/pipeline.py b/imcui/hloc/pipelines/Cambridge/pipeline.py
new file mode 100644
index 0000000000000000000000000000000000000000..3a676e5af411858f2459cb7f58f777b30be67d29
--- /dev/null
+++ b/imcui/hloc/pipelines/Cambridge/pipeline.py
@@ -0,0 +1,140 @@
+import argparse
+from pathlib import Path
+
+from ... import (
+ extract_features,
+ localize_sfm,
+ logger,
+ match_features,
+ pairs_from_covisibility,
+ pairs_from_retrieval,
+ triangulation,
+)
+from .utils import create_query_list_with_intrinsics, evaluate, scale_sfm_images
+
+SCENES = ["KingsCollege", "OldHospital", "ShopFacade", "StMarysChurch", "GreatCourt"]
+
+
+def run_scene(images, gt_dir, outputs, results, num_covis, num_loc):
+ ref_sfm_sift = gt_dir / "model_train"
+ test_list = gt_dir / "list_query.txt"
+
+ outputs.mkdir(exist_ok=True, parents=True)
+ ref_sfm = outputs / "sfm_superpoint+superglue"
+ ref_sfm_scaled = outputs / "sfm_sift_scaled"
+ query_list = outputs / "query_list_with_intrinsics.txt"
+ sfm_pairs = outputs / f"pairs-db-covis{num_covis}.txt"
+ loc_pairs = outputs / f"pairs-query-netvlad{num_loc}.txt"
+
+ feature_conf = {
+ "output": "feats-superpoint-n4096-r1024",
+ "model": {
+ "name": "superpoint",
+ "nms_radius": 3,
+ "max_keypoints": 4096,
+ },
+ "preprocessing": {
+ "grayscale": True,
+ "resize_max": 1024,
+ },
+ }
+ matcher_conf = match_features.confs["superglue"]
+ retrieval_conf = extract_features.confs["netvlad"]
+
+ create_query_list_with_intrinsics(
+ gt_dir / "empty_all", query_list, test_list, ext=".txt", image_dir=images
+ )
+ with open(test_list, "r") as f:
+ query_seqs = {q.split("/")[0] for q in f.read().rstrip().split("\n")}
+
+ global_descriptors = extract_features.main(retrieval_conf, images, outputs)
+ pairs_from_retrieval.main(
+ global_descriptors,
+ loc_pairs,
+ num_loc,
+ db_model=ref_sfm_sift,
+ query_prefix=query_seqs,
+ )
+
+ features = extract_features.main(feature_conf, images, outputs, as_half=True)
+ pairs_from_covisibility.main(ref_sfm_sift, sfm_pairs, num_matched=num_covis)
+ sfm_matches = match_features.main(
+ matcher_conf, sfm_pairs, feature_conf["output"], outputs
+ )
+
+ scale_sfm_images(ref_sfm_sift, ref_sfm_scaled, images)
+ triangulation.main(
+ ref_sfm, ref_sfm_scaled, images, sfm_pairs, features, sfm_matches
+ )
+
+ loc_matches = match_features.main(
+ matcher_conf, loc_pairs, feature_conf["output"], outputs
+ )
+
+ localize_sfm.main(
+ ref_sfm,
+ query_list,
+ loc_pairs,
+ features,
+ loc_matches,
+ results,
+ covisibility_clustering=False,
+ prepend_camera_name=True,
+ )
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--scenes", default=SCENES, choices=SCENES, nargs="+")
+ parser.add_argument("--overwrite", action="store_true")
+ parser.add_argument(
+ "--dataset",
+ type=Path,
+ default="datasets/cambridge",
+ help="Path to the dataset, default: %(default)s",
+ )
+ parser.add_argument(
+ "--outputs",
+ type=Path,
+ default="outputs/cambridge",
+ help="Path to the output directory, default: %(default)s",
+ )
+ parser.add_argument(
+ "--num_covis",
+ type=int,
+ default=20,
+ help="Number of image pairs for SfM, default: %(default)s",
+ )
+ parser.add_argument(
+ "--num_loc",
+ type=int,
+ default=10,
+ help="Number of image pairs for loc, default: %(default)s",
+ )
+ args = parser.parse_args()
+
+ gt_dirs = args.dataset / "CambridgeLandmarks_Colmap_Retriangulated_1024px"
+
+ all_results = {}
+ for scene in args.scenes:
+ logger.info(f'Working on scene "{scene}".')
+ results = args.outputs / scene / "results.txt"
+ if args.overwrite or not results.exists():
+ run_scene(
+ args.dataset / scene,
+ gt_dirs / scene,
+ args.outputs / scene,
+ results,
+ args.num_covis,
+ args.num_loc,
+ )
+ all_results[scene] = results
+
+ for scene in args.scenes:
+ logger.info(f'Evaluate scene "{scene}".')
+ evaluate(
+ gt_dirs / scene / "empty_all",
+ all_results[scene],
+ gt_dirs / scene / "list_query.txt",
+ ext=".txt",
+ )
diff --git a/imcui/hloc/pipelines/Cambridge/utils.py b/imcui/hloc/pipelines/Cambridge/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..36460f067369065668837fa317b1c0f7047e9203
--- /dev/null
+++ b/imcui/hloc/pipelines/Cambridge/utils.py
@@ -0,0 +1,145 @@
+import logging
+
+import cv2
+import numpy as np
+
+from hloc.utils.read_write_model import (
+ qvec2rotmat,
+ read_cameras_binary,
+ read_cameras_text,
+ read_images_binary,
+ read_images_text,
+ read_model,
+ write_model,
+)
+
+logger = logging.getLogger(__name__)
+
+
+def scale_sfm_images(full_model, scaled_model, image_dir):
+ """Duplicate the provided model and scale the camera intrinsics so that
+ they match the original image resolution - makes everything easier.
+ """
+ logger.info("Scaling the COLMAP model to the original image size.")
+ scaled_model.mkdir(exist_ok=True)
+ cameras, images, points3D = read_model(full_model)
+
+ scaled_cameras = {}
+ for id_, image in images.items():
+ name = image.name
+ img = cv2.imread(str(image_dir / name))
+ assert img is not None, image_dir / name
+ h, w = img.shape[:2]
+
+ cam_id = image.camera_id
+ if cam_id in scaled_cameras:
+ assert scaled_cameras[cam_id].width == w
+ assert scaled_cameras[cam_id].height == h
+ continue
+
+ camera = cameras[cam_id]
+ assert camera.model == "SIMPLE_RADIAL"
+ sx = w / camera.width
+ sy = h / camera.height
+ assert sx == sy, (sx, sy)
+ scaled_cameras[cam_id] = camera._replace(
+ width=w, height=h, params=camera.params * np.array([sx, sx, sy, 1.0])
+ )
+
+ write_model(scaled_cameras, images, points3D, scaled_model)
+
+
+def create_query_list_with_intrinsics(
+ model, out, list_file=None, ext=".bin", image_dir=None
+):
+ """Create a list of query images with intrinsics from the colmap model."""
+ if ext == ".bin":
+ images = read_images_binary(model / "images.bin")
+ cameras = read_cameras_binary(model / "cameras.bin")
+ else:
+ images = read_images_text(model / "images.txt")
+ cameras = read_cameras_text(model / "cameras.txt")
+
+ name2id = {image.name: i for i, image in images.items()}
+ if list_file is None:
+ names = list(name2id)
+ else:
+ with open(list_file, "r") as f:
+ names = f.read().rstrip().split("\n")
+ data = []
+ for name in names:
+ image = images[name2id[name]]
+ camera = cameras[image.camera_id]
+ w, h, params = camera.width, camera.height, camera.params
+
+ if image_dir is not None:
+ # Check the original image size and rescale the camera intrinsics
+ img = cv2.imread(str(image_dir / name))
+ assert img is not None, image_dir / name
+ h_orig, w_orig = img.shape[:2]
+ assert camera.model == "SIMPLE_RADIAL"
+ sx = w_orig / w
+ sy = h_orig / h
+ assert sx == sy, (sx, sy)
+ w, h = w_orig, h_orig
+ params = params * np.array([sx, sx, sy, 1.0])
+
+ p = [name, camera.model, w, h] + params.tolist()
+ data.append(" ".join(map(str, p)))
+ with open(out, "w") as f:
+ f.write("\n".join(data))
+
+
+def evaluate(model, results, list_file=None, ext=".bin", only_localized=False):
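+    # Compare estimated poses to the ground truth: the translation error is the
+    # distance between camera centers, the rotation error the angle of the
+    # relative rotation; also report localization ratios at fixed thresholds.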
+ predictions = {}
+ with open(results, "r") as f:
+ for data in f.read().rstrip().split("\n"):
+ data = data.split()
+ name = data[0]
+ q, t = np.split(np.array(data[1:], float), [4])
+ predictions[name] = (qvec2rotmat(q), t)
+ if ext == ".bin":
+ images = read_images_binary(model / "images.bin")
+ else:
+ images = read_images_text(model / "images.txt")
+ name2id = {image.name: i for i, image in images.items()}
+
+ if list_file is None:
+ test_names = list(name2id)
+ else:
+ with open(list_file, "r") as f:
+ test_names = f.read().rstrip().split("\n")
+
+ errors_t = []
+ errors_R = []
+ for name in test_names:
+ if name not in predictions:
+ if only_localized:
+ continue
+ e_t = np.inf
+ e_R = 180.0
+ else:
+ image = images[name2id[name]]
+ R_gt, t_gt = image.qvec2rotmat(), image.tvec
+ R, t = predictions[name]
+ e_t = np.linalg.norm(-R_gt.T @ t_gt + R.T @ t, axis=0)
+ cos = np.clip((np.trace(np.dot(R_gt.T, R)) - 1) / 2, -1.0, 1.0)
+ e_R = np.rad2deg(np.abs(np.arccos(cos)))
+ errors_t.append(e_t)
+ errors_R.append(e_R)
+
+ errors_t = np.array(errors_t)
+ errors_R = np.array(errors_R)
+
+ med_t = np.median(errors_t)
+ med_R = np.median(errors_R)
+ out = f"Results for file {results.name}:"
+ out += f"\nMedian errors: {med_t:.3f}m, {med_R:.3f}deg"
+
+ out += "\nPercentage of test images localized within:"
+ threshs_t = [0.01, 0.02, 0.03, 0.05, 0.25, 0.5, 5.0]
+ threshs_R = [1.0, 2.0, 3.0, 5.0, 2.0, 5.0, 10.0]
+ for th_t, th_R in zip(threshs_t, threshs_R):
+ ratio = np.mean((errors_t < th_t) & (errors_R < th_R))
+ out += f"\n\t{th_t*100:.0f}cm, {th_R:.0f}deg : {ratio*100:.2f}%"
+ logger.info(out)
diff --git a/imcui/hloc/pipelines/RobotCar/README.md b/imcui/hloc/pipelines/RobotCar/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..9881d153d4930cf32b5481ecd4fa2c900fa58c8c
--- /dev/null
+++ b/imcui/hloc/pipelines/RobotCar/README.md
@@ -0,0 +1,16 @@
+# RobotCar Seasons dataset
+
+## Installation
+
+Download the dataset from [visuallocalization.net](https://www.visuallocalization.net):
+```bash
+export dataset=datasets/robotcar
+wget -r -np -nH -R "index.html*" --cut-dirs=4 https://data.ciirc.cvut.cz/public/projects/2020VisualLocalization/RobotCar-Seasons/ -P $dataset
+for condition in $dataset/images/*.zip; do unzip $condition -d $dataset/images/; done
+```
+
+## Pipeline
+
+```bash
+python3 -m hloc.pipelines.RobotCar.pipeline
+```
diff --git a/imcui/hloc/pipelines/RobotCar/__init__.py b/imcui/hloc/pipelines/RobotCar/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/imcui/hloc/pipelines/RobotCar/colmap_from_nvm.py b/imcui/hloc/pipelines/RobotCar/colmap_from_nvm.py
new file mode 100644
index 0000000000000000000000000000000000000000..e90ed72b5391990d26961b5acfaaada6517ac191
--- /dev/null
+++ b/imcui/hloc/pipelines/RobotCar/colmap_from_nvm.py
@@ -0,0 +1,176 @@
+import argparse
+import logging
+import sqlite3
+from collections import defaultdict
+from pathlib import Path
+
+import numpy as np
+from tqdm import tqdm
+
+from ...colmap_from_nvm import (
+ camera_center_to_translation,
+ recover_database_images_and_ids,
+)
+from ...utils.read_write_model import (
+ CAMERA_MODEL_IDS,
+ Camera,
+ Image,
+ Point3D,
+ write_model,
+)
+
+logger = logging.getLogger(__name__)
+
+
+def read_nvm_model(nvm_path, database_path, image_ids, camera_ids, skip_points=False):
+ # Extract the intrinsics from the db file instead of the NVM model
+ db = sqlite3.connect(str(database_path))
+ ret = db.execute("SELECT camera_id, model, width, height, params FROM cameras;")
+ cameras = {}
+ for camera_id, camera_model, width, height, params in ret:
+ params = np.fromstring(params, dtype=np.double).reshape(-1)
+ camera_model = CAMERA_MODEL_IDS[camera_model]
+ assert len(params) == camera_model.num_params, (
+ len(params),
+ camera_model.num_params,
+ )
+ camera = Camera(
+ id=camera_id,
+ model=camera_model.model_name,
+ width=int(width),
+ height=int(height),
+ params=params,
+ )
+ cameras[camera_id] = camera
+
+ nvm_f = open(nvm_path, "r")
+ line = nvm_f.readline()
+ while line == "\n" or line.startswith("NVM_V3"):
+ line = nvm_f.readline()
+ num_images = int(line)
+ # assert num_images == len(cameras), (num_images, len(cameras))
+
+ logger.info(f"Reading {num_images} images...")
+ image_idx_to_db_image_id = []
+ image_data = []
+ i = 0
+ while i < num_images:
+ line = nvm_f.readline()
+ if line == "\n":
+ continue
+ data = line.strip("\n").lstrip("./").split(" ")
+ image_data.append(data)
+ image_idx_to_db_image_id.append(image_ids[data[0]])
+ i += 1
+
+ line = nvm_f.readline()
+ while line == "\n":
+ line = nvm_f.readline()
+ num_points = int(line)
+
+ if skip_points:
+ logger.info(f"Skipping {num_points} points.")
+ num_points = 0
+ else:
+ logger.info(f"Reading {num_points} points...")
+ points3D = {}
+ image_idx_to_keypoints = defaultdict(list)
+ i = 0
+ pbar = tqdm(total=num_points, unit="pts")
+ while i < num_points:
+ line = nvm_f.readline()
+ if line == "\n":
+ continue
+
+ data = line.strip("\n").split(" ")
+ x, y, z, r, g, b, num_observations = data[:7]
+ obs_image_ids, point2D_idxs = [], []
+ for j in range(int(num_observations)):
+ s = 7 + 4 * j
+ img_index, kp_index, kx, ky = data[s : s + 4]
+ image_idx_to_keypoints[int(img_index)].append(
+ (int(kp_index), float(kx), float(ky), i)
+ )
+ db_image_id = image_idx_to_db_image_id[int(img_index)]
+ obs_image_ids.append(db_image_id)
+ point2D_idxs.append(kp_index)
+
+ point = Point3D(
+ id=i,
+ xyz=np.array([x, y, z], float),
+ rgb=np.array([r, g, b], int),
+ error=1.0, # fake
+ image_ids=np.array(obs_image_ids, int),
+ point2D_idxs=np.array(point2D_idxs, int),
+ )
+ points3D[i] = point
+
+ i += 1
+ pbar.update(1)
+ pbar.close()
+
+ logger.info("Parsing image data...")
+ images = {}
+ for i, data in enumerate(image_data):
+ # Skip the focal length. Skip the distortion and terminal 0.
+ name, _, qw, qx, qy, qz, cx, cy, cz, _, _ = data
+ qvec = np.array([qw, qx, qy, qz], float)
+ c = np.array([cx, cy, cz], float)
+ t = camera_center_to_translation(c, qvec)
+
+ if i in image_idx_to_keypoints:
+ # NVM only stores triangulated 2D keypoints: add dummy ones
+ keypoints = image_idx_to_keypoints[i]
+ point2D_idxs = np.array([d[0] for d in keypoints])
+ tri_xys = np.array([[x, y] for _, x, y, _ in keypoints])
+ tri_ids = np.array([i for _, _, _, i in keypoints])
+
+ num_2Dpoints = max(point2D_idxs) + 1
+ xys = np.zeros((num_2Dpoints, 2), float)
+ point3D_ids = np.full(num_2Dpoints, -1, int)
+ xys[point2D_idxs] = tri_xys
+ point3D_ids[point2D_idxs] = tri_ids
+ else:
+ xys = np.zeros((0, 2), float)
+ point3D_ids = np.full(0, -1, int)
+
+ image_id = image_ids[name]
+ image = Image(
+ id=image_id,
+ qvec=qvec,
+ tvec=t,
+ camera_id=camera_ids[name],
+ name=name.replace("png", "jpg"), # some hack required for RobotCar
+ xys=xys,
+ point3D_ids=point3D_ids,
+ )
+ images[image_id] = image
+
+ return cameras, images, points3D
+
+
+def main(nvm, database, output, skip_points=False):
+ assert nvm.exists(), nvm
+ assert database.exists(), database
+
+ image_ids, camera_ids = recover_database_images_and_ids(database)
+
+ logger.info("Reading the NVM model...")
+ model = read_nvm_model(
+ nvm, database, image_ids, camera_ids, skip_points=skip_points
+ )
+
+ logger.info("Writing the COLMAP model...")
+ output.mkdir(exist_ok=True, parents=True)
+ write_model(*model, path=str(output), ext=".bin")
+ logger.info("Done.")
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--nvm", required=True, type=Path)
+ parser.add_argument("--database", required=True, type=Path)
+ parser.add_argument("--output", required=True, type=Path)
+ parser.add_argument("--skip_points", action="store_true")
+ args = parser.parse_args()
+ main(**args.__dict__)
diff --git a/imcui/hloc/pipelines/RobotCar/pipeline.py b/imcui/hloc/pipelines/RobotCar/pipeline.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a7ee314480d09b2200b9ff1992a3217e24bea2f
--- /dev/null
+++ b/imcui/hloc/pipelines/RobotCar/pipeline.py
@@ -0,0 +1,143 @@
+import argparse
+import glob
+from pathlib import Path
+
+from ... import (
+ extract_features,
+ localize_sfm,
+ match_features,
+ pairs_from_covisibility,
+ pairs_from_retrieval,
+ triangulation,
+)
+from . import colmap_from_nvm
+
+CONDITIONS = [
+ "dawn",
+ "dusk",
+ "night",
+ "night-rain",
+ "overcast-summer",
+ "overcast-winter",
+ "rain",
+ "snow",
+ "sun",
+]
+
+
+def generate_query_list(dataset, image_dir, path):
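+    # All query images are 1024x1024; read the intrinsics of the left/right/rear
+    # cameras and write one line per query with a SIMPLE_RADIAL model (zero distortion).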
+ h, w = 1024, 1024
+ intrinsics_filename = "intrinsics/{}_intrinsics.txt"
+ cameras = {}
+ for side in ["left", "right", "rear"]:
+ with open(dataset / intrinsics_filename.format(side), "r") as f:
+ fx = f.readline().split()[1]
+ fy = f.readline().split()[1]
+ cx = f.readline().split()[1]
+ cy = f.readline().split()[1]
+ assert fx == fy
+ params = ["SIMPLE_RADIAL", w, h, fx, cx, cy, 0.0]
+ cameras[side] = [str(p) for p in params]
+
+ queries = glob.glob((image_dir / "**/*.jpg").as_posix(), recursive=True)
+ queries = [
+ Path(q).relative_to(image_dir.parents[0]).as_posix() for q in sorted(queries)
+ ]
+
+ out = [[q] + cameras[Path(q).parent.name] for q in queries]
+ with open(path, "w") as f:
+ f.write("\n".join(map(" ".join, out)))
+
+
+def run(args):
+ # Setup the paths
+ dataset = args.dataset
+ images = dataset / "images/"
+
+ outputs = args.outputs # where everything will be saved
+ outputs.mkdir(exist_ok=True, parents=True)
+ query_list = outputs / "{condition}_queries_with_intrinsics.txt"
+ sift_sfm = outputs / "sfm_sift"
+ reference_sfm = outputs / "sfm_superpoint+superglue"
+ sfm_pairs = outputs / f"pairs-db-covis{args.num_covis}.txt"
+ loc_pairs = outputs / f"pairs-query-netvlad{args.num_loc}.txt"
+ results = outputs / f"RobotCar_hloc_superpoint+superglue_netvlad{args.num_loc}.txt"
+
+ # pick one of the configurations for extraction and matching
+ retrieval_conf = extract_features.confs["netvlad"]
+ feature_conf = extract_features.confs["superpoint_aachen"]
+ matcher_conf = match_features.confs["superglue"]
+
+ for condition in CONDITIONS:
+ generate_query_list(
+ dataset, images / condition, str(query_list).format(condition=condition)
+ )
+
+ features = extract_features.main(feature_conf, images, outputs, as_half=True)
+
+ colmap_from_nvm.main(
+ dataset / "3D-models/all-merged/all.nvm",
+ dataset / "3D-models/overcast-reference.db",
+ sift_sfm,
+ )
+ pairs_from_covisibility.main(sift_sfm, sfm_pairs, num_matched=args.num_covis)
+ sfm_matches = match_features.main(
+ matcher_conf, sfm_pairs, feature_conf["output"], outputs
+ )
+
+ triangulation.main(
+ reference_sfm, sift_sfm, images, sfm_pairs, features, sfm_matches
+ )
+
+ global_descriptors = extract_features.main(retrieval_conf, images, outputs)
+ # TODO: do per location and per camera
+ pairs_from_retrieval.main(
+ global_descriptors,
+ loc_pairs,
+ args.num_loc,
+ query_prefix=CONDITIONS,
+ db_model=reference_sfm,
+ )
+ loc_matches = match_features.main(
+ matcher_conf, loc_pairs, feature_conf["output"], outputs
+ )
+
+ localize_sfm.main(
+ reference_sfm,
+ Path(str(query_list).format(condition="*")),
+ loc_pairs,
+ features,
+ loc_matches,
+ results,
+ covisibility_clustering=False,
+ prepend_camera_name=True,
+ )
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--dataset",
+ type=Path,
+ default="datasets/robotcar",
+ help="Path to the dataset, default: %(default)s",
+ )
+ parser.add_argument(
+ "--outputs",
+ type=Path,
+ default="outputs/robotcar",
+ help="Path to the output directory, default: %(default)s",
+ )
+ parser.add_argument(
+ "--num_covis",
+ type=int,
+ default=20,
+ help="Number of image pairs for SfM, default: %(default)s",
+ )
+ parser.add_argument(
+ "--num_loc",
+ type=int,
+ default=20,
+ help="Number of image pairs for loc, default: %(default)s",
+ )
+ args = parser.parse_args()
diff --git a/imcui/hloc/pipelines/__init__.py b/imcui/hloc/pipelines/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/imcui/hloc/reconstruction.py b/imcui/hloc/reconstruction.py
new file mode 100644
index 0000000000000000000000000000000000000000..ea1e7fc09c52cca2935c217e912bb077fe712e05
--- /dev/null
+++ b/imcui/hloc/reconstruction.py
@@ -0,0 +1,194 @@
+import argparse
+import multiprocessing
+import shutil
+from pathlib import Path
+from typing import Any, Dict, List, Optional
+
+import pycolmap
+
+from . import logger
+from .triangulation import (
+ OutputCapture,
+ estimation_and_geometric_verification,
+ import_features,
+ import_matches,
+ parse_option_args,
+)
+from .utils.database import COLMAPDatabase
+
+
+def create_empty_db(database_path: Path):
+ if database_path.exists():
+ logger.warning("The database already exists, deleting it.")
+ database_path.unlink()
+ logger.info("Creating an empty database...")
+ db = COLMAPDatabase.connect(database_path)
+ db.create_tables()
+ db.commit()
+ db.close()
+
+
+def import_images(
+ image_dir: Path,
+ database_path: Path,
+ camera_mode: pycolmap.CameraMode,
+ image_list: Optional[List[str]] = None,
+ options: Optional[Dict[str, Any]] = None,
+):
+ logger.info("Importing images into the database...")
+ if options is None:
+ options = {}
+ images = list(image_dir.iterdir())
+ if len(images) == 0:
+ raise IOError(f"No images found in {image_dir}.")
+ with pycolmap.ostream():
+ pycolmap.import_images(
+ database_path,
+ image_dir,
+ camera_mode,
+ image_list=image_list or [],
+ options=options,
+ )
+
+
+def get_image_ids(database_path: Path) -> Dict[str, int]:
+ db = COLMAPDatabase.connect(database_path)
+ images = {}
+ for name, image_id in db.execute("SELECT name, image_id FROM images;"):
+ images[name] = image_id
+ db.close()
+ return images
+
+
+def run_reconstruction(
+ sfm_dir: Path,
+ database_path: Path,
+ image_dir: Path,
+ verbose: bool = False,
+ options: Optional[Dict[str, Any]] = None,
+) -> pycolmap.Reconstruction:
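+    # Run COLMAP incremental mapping and keep the model with the most registered
+    # images, moving its binary files to the top level of sfm_dir.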
+ models_path = sfm_dir / "models"
+ models_path.mkdir(exist_ok=True, parents=True)
+ logger.info("Running 3D reconstruction...")
+ if options is None:
+ options = {}
+ options = {"num_threads": min(multiprocessing.cpu_count(), 16), **options}
+ with OutputCapture(verbose):
+ with pycolmap.ostream():
+ reconstructions = pycolmap.incremental_mapping(
+ database_path, image_dir, models_path, options=options
+ )
+
+ if len(reconstructions) == 0:
+ logger.error("Could not reconstruct any model!")
+ return None
+ logger.info(f"Reconstructed {len(reconstructions)} model(s).")
+
+ largest_index = None
+ largest_num_images = 0
+ for index, rec in reconstructions.items():
+ num_images = rec.num_reg_images()
+ if num_images > largest_num_images:
+ largest_index = index
+ largest_num_images = num_images
+ assert largest_index is not None
+ logger.info(
+ f"Largest model is #{largest_index} " f"with {largest_num_images} images."
+ )
+
+ for filename in ["images.bin", "cameras.bin", "points3D.bin"]:
+ if (sfm_dir / filename).exists():
+ (sfm_dir / filename).unlink()
+ shutil.move(str(models_path / str(largest_index) / filename), str(sfm_dir))
+ return reconstructions[largest_index]
+
+
+def main(
+ sfm_dir: Path,
+ image_dir: Path,
+ pairs: Path,
+ features: Path,
+ matches: Path,
+ camera_mode: pycolmap.CameraMode = pycolmap.CameraMode.AUTO,
+ verbose: bool = False,
+ skip_geometric_verification: bool = False,
+ min_match_score: Optional[float] = None,
+ image_list: Optional[List[str]] = None,
+ image_options: Optional[Dict[str, Any]] = None,
+ mapper_options: Optional[Dict[str, Any]] = None,
+) -> pycolmap.Reconstruction:
+ assert features.exists(), features
+ assert pairs.exists(), pairs
+ assert matches.exists(), matches
+
+ sfm_dir.mkdir(parents=True, exist_ok=True)
+ database = sfm_dir / "database.db"
+
+ create_empty_db(database)
+ import_images(image_dir, database, camera_mode, image_list, image_options)
+ image_ids = get_image_ids(database)
+ import_features(image_ids, database, features)
+ import_matches(
+ image_ids,
+ database,
+ pairs,
+ matches,
+ min_match_score,
+ skip_geometric_verification,
+ )
+ if not skip_geometric_verification:
+ estimation_and_geometric_verification(database, pairs, verbose)
+ reconstruction = run_reconstruction(
+ sfm_dir, database, image_dir, verbose, mapper_options
+ )
+ if reconstruction is not None:
+ logger.info(
+ f"Reconstruction statistics:\n{reconstruction.summary()}"
+ + f"\n\tnum_input_images = {len(image_ids)}"
+ )
+ return reconstruction
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--sfm_dir", type=Path, required=True)
+ parser.add_argument("--image_dir", type=Path, required=True)
+
+ parser.add_argument("--pairs", type=Path, required=True)
+ parser.add_argument("--features", type=Path, required=True)
+ parser.add_argument("--matches", type=Path, required=True)
+
+ parser.add_argument(
+ "--camera_mode",
+ type=str,
+ default="AUTO",
+ choices=list(pycolmap.CameraMode.__members__.keys()),
+ )
+ parser.add_argument("--skip_geometric_verification", action="store_true")
+ parser.add_argument("--min_match_score", type=float)
+ parser.add_argument("--verbose", action="store_true")
+
+ parser.add_argument(
+ "--image_options",
+ nargs="+",
+ default=[],
+ help="List of key=value from {}".format(pycolmap.ImageReaderOptions().todict()),
+ )
+ parser.add_argument(
+ "--mapper_options",
+ nargs="+",
+ default=[],
+ help="List of key=value from {}".format(
+ pycolmap.IncrementalMapperOptions().todict()
+ ),
+ )
+ args = parser.parse_args().__dict__
+
+ image_options = parse_option_args(
+ args.pop("image_options"), pycolmap.ImageReaderOptions()
+ )
+ mapper_options = parse_option_args(
+ args.pop("mapper_options"), pycolmap.IncrementalMapperOptions()
+ )
+
+ main(**args, image_options=image_options, mapper_options=mapper_options)
diff --git a/imcui/hloc/triangulation.py b/imcui/hloc/triangulation.py
new file mode 100644
index 0000000000000000000000000000000000000000..83203c38f4e4a2493b8b1b11773fb2140d76b8bc
--- /dev/null
+++ b/imcui/hloc/triangulation.py
@@ -0,0 +1,311 @@
+import argparse
+import contextlib
+import io
+import sys
+from pathlib import Path
+from typing import Any, Dict, List, Optional
+
+import numpy as np
+import pycolmap
+from tqdm import tqdm
+
+from . import logger
+from .utils.database import COLMAPDatabase
+from .utils.geometry import compute_epipolar_errors
+from .utils.io import get_keypoints, get_matches
+from .utils.parsers import parse_retrieval
+
+
+class OutputCapture:
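+    # Redirects stdout to a buffer when not verbose and logs the captured output
+    # only if an exception is raised inside the context.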
+ def __init__(self, verbose: bool):
+ self.verbose = verbose
+
+ def __enter__(self):
+ if not self.verbose:
+ self.capture = contextlib.redirect_stdout(io.StringIO())
+ self.out = self.capture.__enter__()
+
+ def __exit__(self, exc_type, *args):
+ if not self.verbose:
+ self.capture.__exit__(exc_type, *args)
+ if exc_type is not None:
+ logger.error("Failed with output:\n%s", self.out.getvalue())
+ sys.stdout.flush()
+
+
+def create_db_from_model(
+ reconstruction: pycolmap.Reconstruction, database_path: Path
+) -> Dict[str, int]:
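+    # Initialize a fresh COLMAP database with the cameras and images of the
+    # reference reconstruction and return a name -> image_id mapping.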
+ if database_path.exists():
+ logger.warning("The database already exists, deleting it.")
+ database_path.unlink()
+
+ db = COLMAPDatabase.connect(database_path)
+ db.create_tables()
+
+ for i, camera in reconstruction.cameras.items():
+ db.add_camera(
+ camera.model.value,
+ camera.width,
+ camera.height,
+ camera.params,
+ camera_id=i,
+ prior_focal_length=True,
+ )
+
+ for i, image in reconstruction.images.items():
+ db.add_image(image.name, image.camera_id, image_id=i)
+
+ db.commit()
+ db.close()
+ return {image.name: i for i, image in reconstruction.images.items()}
+
+
+def import_features(
+ image_ids: Dict[str, int], database_path: Path, features_path: Path
+):
+ logger.info("Importing features into the database...")
+ db = COLMAPDatabase.connect(database_path)
+
+ for image_name, image_id in tqdm(image_ids.items()):
+ keypoints = get_keypoints(features_path, image_name)
+ keypoints += 0.5 # COLMAP origin
+ db.add_keypoints(image_id, keypoints)
+
+ db.commit()
+ db.close()
+
+
+def import_matches(
+ image_ids: Dict[str, int],
+ database_path: Path,
+ pairs_path: Path,
+ matches_path: Path,
+ min_match_score: Optional[float] = None,
+ skip_geometric_verification: bool = False,
+):
+ logger.info("Importing matches into the database...")
+
+ with open(str(pairs_path), "r") as f:
+ pairs = [p.split() for p in f.readlines()]
+
+ db = COLMAPDatabase.connect(database_path)
+
+ matched = set()
+ for name0, name1 in tqdm(pairs):
+ id0, id1 = image_ids[name0], image_ids[name1]
+ if len({(id0, id1), (id1, id0)} & matched) > 0:
+ continue
+ matches, scores = get_matches(matches_path, name0, name1)
+ if min_match_score:
+ matches = matches[scores > min_match_score]
+ db.add_matches(id0, id1, matches)
+ matched |= {(id0, id1), (id1, id0)}
+
+ if skip_geometric_verification:
+ db.add_two_view_geometry(id0, id1, matches)
+
+ db.commit()
+ db.close()
+
+
+def estimation_and_geometric_verification(
+ database_path: Path, pairs_path: Path, verbose: bool = False
+):
+ logger.info("Performing geometric verification of the matches...")
+ with OutputCapture(verbose):
+ with pycolmap.ostream():
+ pycolmap.verify_matches(
+ database_path,
+ pairs_path,
+ options=dict(ransac=dict(max_num_trials=20000, min_inlier_ratio=0.1)),
+ )
+
+
+def geometric_verification(
+ image_ids: Dict[str, int],
+ reference: pycolmap.Reconstruction,
+ database_path: Path,
+ features_path: Path,
+ pairs_path: Path,
+ matches_path: Path,
+ max_error: float = 4.0,
+):
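+    # Verify matches against the two-view geometry of the reference reconstruction:
+    # keep only matches whose epipolar error in both images is below max_error
+    # pixels (scaled by the keypoint detection noise).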
+ logger.info("Performing geometric verification of the matches...")
+
+ pairs = parse_retrieval(pairs_path)
+ db = COLMAPDatabase.connect(database_path)
+
+ inlier_ratios = []
+ matched = set()
+ for name0 in tqdm(pairs):
+ id0 = image_ids[name0]
+ image0 = reference.images[id0]
+ cam0 = reference.cameras[image0.camera_id]
+ kps0, noise0 = get_keypoints(features_path, name0, return_uncertainty=True)
+ noise0 = 1.0 if noise0 is None else noise0
+ if len(kps0) > 0:
+ kps0 = np.stack(cam0.cam_from_img(kps0))
+ else:
+ kps0 = np.zeros((0, 2))
+
+ for name1 in pairs[name0]:
+ id1 = image_ids[name1]
+ image1 = reference.images[id1]
+ cam1 = reference.cameras[image1.camera_id]
+ kps1, noise1 = get_keypoints(features_path, name1, return_uncertainty=True)
+ noise1 = 1.0 if noise1 is None else noise1
+ if len(kps1) > 0:
+ kps1 = np.stack(cam1.cam_from_img(kps1))
+ else:
+ kps1 = np.zeros((0, 2))
+
+ matches = get_matches(matches_path, name0, name1)[0]
+
+ if len({(id0, id1), (id1, id0)} & matched) > 0:
+ continue
+ matched |= {(id0, id1), (id1, id0)}
+
+ if matches.shape[0] == 0:
+ db.add_two_view_geometry(id0, id1, matches)
+ continue
+
+ cam1_from_cam0 = image1.cam_from_world * image0.cam_from_world.inverse()
+ errors0, errors1 = compute_epipolar_errors(
+ cam1_from_cam0, kps0[matches[:, 0]], kps1[matches[:, 1]]
+ )
+ valid_matches = np.logical_and(
+ errors0 <= cam0.cam_from_img_threshold(noise0 * max_error),
+ errors1 <= cam1.cam_from_img_threshold(noise1 * max_error),
+ )
+ # TODO: We could also add E to the database, but we need
+ # to reverse the transformations if id0 > id1 in utils/database.py.
+ db.add_two_view_geometry(id0, id1, matches[valid_matches, :])
+ inlier_ratios.append(np.mean(valid_matches))
+ logger.info(
+ "mean/med/min/max valid matches %.2f/%.2f/%.2f/%.2f%%.",
+ np.mean(inlier_ratios) * 100,
+ np.median(inlier_ratios) * 100,
+ np.min(inlier_ratios) * 100,
+ np.max(inlier_ratios) * 100,
+ )
+
+ db.commit()
+ db.close()
+
+
+def run_triangulation(
+ model_path: Path,
+ database_path: Path,
+ image_dir: Path,
+ reference_model: pycolmap.Reconstruction,
+ verbose: bool = False,
+ options: Optional[Dict[str, Any]] = None,
+) -> pycolmap.Reconstruction:
+ model_path.mkdir(parents=True, exist_ok=True)
+ logger.info("Running 3D triangulation...")
+ if options is None:
+ options = {}
+ with OutputCapture(verbose):
+ with pycolmap.ostream():
+ reconstruction = pycolmap.triangulate_points(
+ reference_model,
+ database_path,
+ image_dir,
+ model_path,
+ options=options,
+ )
+ return reconstruction
+
+
+def main(
+ sfm_dir: Path,
+ reference_model: Path,
+ image_dir: Path,
+ pairs: Path,
+ features: Path,
+ matches: Path,
+ skip_geometric_verification: bool = False,
+ estimate_two_view_geometries: bool = False,
+ min_match_score: Optional[float] = None,
+ verbose: bool = False,
+ mapper_options: Optional[Dict[str, Any]] = None,
+) -> pycolmap.Reconstruction:
+ assert reference_model.exists(), reference_model
+ assert features.exists(), features
+ assert pairs.exists(), pairs
+ assert matches.exists(), matches
+
+ sfm_dir.mkdir(parents=True, exist_ok=True)
+ database = sfm_dir / "database.db"
+ reference = pycolmap.Reconstruction(reference_model)
+
+ image_ids = create_db_from_model(reference, database)
+ import_features(image_ids, database, features)
+ import_matches(
+ image_ids,
+ database,
+ pairs,
+ matches,
+ min_match_score,
+ skip_geometric_verification,
+ )
+ if not skip_geometric_verification:
+ if estimate_two_view_geometries:
+ estimation_and_geometric_verification(database, pairs, verbose)
+ else:
+ geometric_verification(
+ image_ids, reference, database, features, pairs, matches
+ )
+ reconstruction = run_triangulation(
+ sfm_dir, database, image_dir, reference, verbose, mapper_options
+ )
+ logger.info(
+ "Finished the triangulation with statistics:\n%s",
+ reconstruction.summary(),
+ )
+ return reconstruction
+
+
+def parse_option_args(args: List[str], default_options) -> Dict[str, Any]:
+ options = {}
+ for arg in args:
+ idx = arg.find("=")
+ if idx == -1:
+ raise ValueError("Options format: key1=value1 key2=value2 etc.")
+ key, value = arg[:idx], arg[idx + 1 :]
+ if not hasattr(default_options, key):
+ raise ValueError(
+ f'Unknown option "{key}", allowed options and default values'
+ f" for {default_options.summary()}"
+ )
+ value = eval(value)
+ target_type = type(getattr(default_options, key))
+ if not isinstance(value, target_type):
+ raise ValueError(
+ f'Incorrect type for option "{key}":' f" {type(value)} vs {target_type}"
+ )
+ options[key] = value
+ return options
+
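+ # Illustrative usage of parse_option_args (a sketch, not part of the original
+ # CLI): keys must be attributes of the given default options object and values
+ # are parsed with eval(), so Python literal syntax is expected. With pycolmap
+ # installed, something along these lines would be expected (the option names
+ # below are assumptions, not guaranteed to exist in every pycolmap version):
+ #
+ # parse_option_args(["num_threads=4", "ba_refine_focal_length=False"],
+ # pycolmap.IncrementalMapperOptions())
+ # -> {"num_threads": 4, "ba_refine_focal_length": False}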
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--sfm_dir", type=Path, required=True)
+ parser.add_argument(
+ "--reference_sfm_model", type=Path, required=True, dest="reference_model"
+ )
+ parser.add_argument("--image_dir", type=Path, required=True)
+
+ parser.add_argument("--pairs", type=Path, required=True)
+ parser.add_argument("--features", type=Path, required=True)
+ parser.add_argument("--matches", type=Path, required=True)
+
+ parser.add_argument("--skip_geometric_verification", action="store_true")
+ parser.add_argument("--min_match_score", type=float)
+ parser.add_argument("--verbose", action="store_true")
+
+ parser.add_argument(
+ "--mapper_options",
+ nargs="+",
+ default=[],
+ help="Mapper options as a list of key=value pairs",
+ )
+ args = parser.parse_args().__dict__
+
+ mapper_options = parse_option_args(
+ args.pop("mapper_options"), pycolmap.IncrementalMapperOptions()
+ )
+
+ main(**args, mapper_options=mapper_options)
diff --git a/imcui/hloc/utils/__init__.py b/imcui/hloc/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c6b030ce404e986f2dcf81cf39640cb8e841e87a
--- /dev/null
+++ b/imcui/hloc/utils/__init__.py
@@ -0,0 +1,12 @@
+import os
+import sys
+from .. import logger
+
+
+def do_system(cmd, verbose=False):
+ if verbose:
+ logger.info(f"Run cmd: `{cmd}`.")
+ err = os.system(cmd)
+ if err:
+ logger.error(f"Command failed with exit status {err}.")
+ sys.exit(err)
diff --git a/imcui/hloc/utils/base_model.py b/imcui/hloc/utils/base_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..e6cf3971f8ea8bc6c4bf6081f82c4fd9cc4c22b6
--- /dev/null
+++ b/imcui/hloc/utils/base_model.py
@@ -0,0 +1,56 @@
+import sys
+from abc import ABCMeta, abstractmethod
+from torch import nn
+from copy import copy
+import inspect
+from huggingface_hub import hf_hub_download
+
+
+class BaseModel(nn.Module, metaclass=ABCMeta):
+ default_conf = {}
+ required_inputs = []
+
+ def __init__(self, conf):
+ """Perform some logic and call the _init method of the child model."""
+ super().__init__()
+ self.conf = conf = {**self.default_conf, **conf}
+ self.required_inputs = copy(self.required_inputs)
+ self._init(conf)
+ sys.stdout.flush()
+
+ def forward(self, data):
+ """Check the data and call the _forward method of the child model."""
+ for key in self.required_inputs:
+ assert key in data, "Missing key {} in data".format(key)
+ return self._forward(data)
+
+ @abstractmethod
+ def _init(self, conf):
+ """To be implemented by the child class."""
+ raise NotImplementedError
+
+ @abstractmethod
+ def _forward(self, data):
+ """To be implemented by the child class."""
+ raise NotImplementedError
+
+ def _download_model(self, repo_id=None, filename=None, **kwargs):
+ """Download model from hf hub and return the path."""
+ return hf_hub_download(
+ repo_type="model",
+ repo_id=repo_id,
+ filename=filename,
+ )
+
+
+def dynamic_load(root, model):
+ module_path = f"{root.__name__}.{model}"
+ module = __import__(module_path, fromlist=[""])
+ classes = inspect.getmembers(module, inspect.isclass)
+ # Filter classes defined in the module
+ classes = [c for c in classes if c[1].__module__ == module_path]
+ # Filter classes inherited from BaseModel
+ classes = [c for c in classes if issubclass(c[1], BaseModel)]
+ assert len(classes) == 1, classes
+ return classes[0][1]
+ # return getattr(module, 'Model')
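+
+ # Hypothetical usage sketch (the package layout and configuration keys are
+ # assumptions): given a subpackage whose modules each define exactly one
+ # BaseModel subclass, e.g.
+ # from imcui.hloc import extractors
+ # Model = dynamic_load(extractors, "superpoint")
+ # model = Model({"max_keypoints": 1024}).eval()
+ # dynamic_load returns the class itself, so the configuration dict is passed to
+ # BaseModel.__init__ and merged with the class's default_conf.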
diff --git a/imcui/hloc/utils/database.py b/imcui/hloc/utils/database.py
new file mode 100644
index 0000000000000000000000000000000000000000..b2e5e0b7342677757f4b654c1aaeaa76cfe68187
--- /dev/null
+++ b/imcui/hloc/utils/database.py
@@ -0,0 +1,412 @@
+# Copyright (c) 2018, ETH Zurich and UNC Chapel Hill.
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+#
+# * Redistributions in binary form must reproduce the above copyright
+# notice, this list of conditions and the following disclaimer in the
+# documentation and/or other materials provided with the distribution.
+#
+# * Neither the name of ETH Zurich and UNC Chapel Hill nor the names of
+# its contributors may be used to endorse or promote products derived
+# from this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE
+# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+# POSSIBILITY OF SUCH DAMAGE.
+#
+# Author: Johannes L. Schoenberger (jsch-at-demuc-dot-de)
+
+# This script is based on an original implementation by True Price.
+
+import sqlite3
+import sys
+
+import numpy as np
+
+IS_PYTHON3 = sys.version_info[0] >= 3
+
+MAX_IMAGE_ID = 2**31 - 1
+
+CREATE_CAMERAS_TABLE = """CREATE TABLE IF NOT EXISTS cameras (
+ camera_id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL,
+ model INTEGER NOT NULL,
+ width INTEGER NOT NULL,
+ height INTEGER NOT NULL,
+ params BLOB,
+ prior_focal_length INTEGER NOT NULL)"""
+
+CREATE_DESCRIPTORS_TABLE = """CREATE TABLE IF NOT EXISTS descriptors (
+ image_id INTEGER PRIMARY KEY NOT NULL,
+ rows INTEGER NOT NULL,
+ cols INTEGER NOT NULL,
+ data BLOB,
+ FOREIGN KEY(image_id) REFERENCES images(image_id) ON DELETE CASCADE)"""
+
+CREATE_IMAGES_TABLE = """CREATE TABLE IF NOT EXISTS images (
+ image_id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL,
+ name TEXT NOT NULL UNIQUE,
+ camera_id INTEGER NOT NULL,
+ prior_qw REAL,
+ prior_qx REAL,
+ prior_qy REAL,
+ prior_qz REAL,
+ prior_tx REAL,
+ prior_ty REAL,
+ prior_tz REAL,
+ CONSTRAINT image_id_check CHECK(image_id >= 0 and image_id < {}),
+ FOREIGN KEY(camera_id) REFERENCES cameras(camera_id))
+""".format(MAX_IMAGE_ID)
+
+CREATE_TWO_VIEW_GEOMETRIES_TABLE = """
+CREATE TABLE IF NOT EXISTS two_view_geometries (
+ pair_id INTEGER PRIMARY KEY NOT NULL,
+ rows INTEGER NOT NULL,
+ cols INTEGER NOT NULL,
+ data BLOB,
+ config INTEGER NOT NULL,
+ F BLOB,
+ E BLOB,
+ H BLOB,
+ qvec BLOB,
+ tvec BLOB)
+"""
+
+CREATE_KEYPOINTS_TABLE = """CREATE TABLE IF NOT EXISTS keypoints (
+ image_id INTEGER PRIMARY KEY NOT NULL,
+ rows INTEGER NOT NULL,
+ cols INTEGER NOT NULL,
+ data BLOB,
+ FOREIGN KEY(image_id) REFERENCES images(image_id) ON DELETE CASCADE)
+"""
+
+CREATE_MATCHES_TABLE = """CREATE TABLE IF NOT EXISTS matches (
+ pair_id INTEGER PRIMARY KEY NOT NULL,
+ rows INTEGER NOT NULL,
+ cols INTEGER NOT NULL,
+ data BLOB)"""
+
+CREATE_NAME_INDEX = "CREATE UNIQUE INDEX IF NOT EXISTS index_name ON images(name)"
+
+CREATE_ALL = "; ".join(
+ [
+ CREATE_CAMERAS_TABLE,
+ CREATE_IMAGES_TABLE,
+ CREATE_KEYPOINTS_TABLE,
+ CREATE_DESCRIPTORS_TABLE,
+ CREATE_MATCHES_TABLE,
+ CREATE_TWO_VIEW_GEOMETRIES_TABLE,
+ CREATE_NAME_INDEX,
+ ]
+)
+
+
+def image_ids_to_pair_id(image_id1, image_id2):
+ if image_id1 > image_id2:
+ image_id1, image_id2 = image_id2, image_id1
+ return image_id1 * MAX_IMAGE_ID + image_id2
+
+
+def pair_id_to_image_ids(pair_id):
+ image_id2 = pair_id % MAX_IMAGE_ID
+ image_id1 = (pair_id - image_id2) / MAX_IMAGE_ID
+ return image_id1, image_id2
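+
+ # Worked example, following directly from the two functions above: with
+ # MAX_IMAGE_ID = 2**31 - 1, image_ids_to_pair_id(2, 1) first reorders the pair
+ # to (1, 2) and returns 1 * MAX_IMAGE_ID + 2 = 2147483649, and
+ # pair_id_to_image_ids(2147483649) recovers (1.0, 2.0). The reordering makes
+ # the pair id independent of the order in which the two image ids are given.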
+
+
+def array_to_blob(array):
+ if IS_PYTHON3:
+ return array.tobytes()
+ else:
+ return np.getbuffer(array)
+
+
+def blob_to_array(blob, dtype, shape=(-1,)):
+ if IS_PYTHON3:
+ return np.fromstring(blob, dtype=dtype).reshape(*shape)
+ else:
+ return np.frombuffer(blob, dtype=dtype).reshape(*shape)
+
+
+class COLMAPDatabase(sqlite3.Connection):
+ @staticmethod
+ def connect(database_path):
+ return sqlite3.connect(str(database_path), factory=COLMAPDatabase)
+
+ def __init__(self, *args, **kwargs):
+ super(COLMAPDatabase, self).__init__(*args, **kwargs)
+
+ self.create_tables = lambda: self.executescript(CREATE_ALL)
+ self.create_cameras_table = lambda: self.executescript(CREATE_CAMERAS_TABLE)
+ self.create_descriptors_table = lambda: self.executescript(
+ CREATE_DESCRIPTORS_TABLE
+ )
+ self.create_images_table = lambda: self.executescript(CREATE_IMAGES_TABLE)
+ self.create_two_view_geometries_table = lambda: self.executescript(
+ CREATE_TWO_VIEW_GEOMETRIES_TABLE
+ )
+ self.create_keypoints_table = lambda: self.executescript(CREATE_KEYPOINTS_TABLE)
+ self.create_matches_table = lambda: self.executescript(CREATE_MATCHES_TABLE)
+ self.create_name_index = lambda: self.executescript(CREATE_NAME_INDEX)
+
+ def add_camera(
+ self, model, width, height, params, prior_focal_length=False, camera_id=None
+ ):
+ params = np.asarray(params, np.float64)
+ cursor = self.execute(
+ "INSERT INTO cameras VALUES (?, ?, ?, ?, ?, ?)",
+ (
+ camera_id,
+ model,
+ width,
+ height,
+ array_to_blob(params),
+ prior_focal_length,
+ ),
+ )
+ return cursor.lastrowid
+
+ def add_image(
+ self,
+ name,
+ camera_id,
+ prior_q=np.full(4, np.nan),
+ prior_t=np.full(3, np.nan),
+ image_id=None,
+ ):
+ cursor = self.execute(
+ "INSERT INTO images VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
+ (
+ image_id,
+ name,
+ camera_id,
+ prior_q[0],
+ prior_q[1],
+ prior_q[2],
+ prior_q[3],
+ prior_t[0],
+ prior_t[1],
+ prior_t[2],
+ ),
+ )
+ return cursor.lastrowid
+
+ def add_keypoints(self, image_id, keypoints):
+ assert len(keypoints.shape) == 2
+ assert keypoints.shape[1] in [2, 4, 6]
+
+ keypoints = np.asarray(keypoints, np.float32)
+ self.execute(
+ "INSERT INTO keypoints VALUES (?, ?, ?, ?)",
+ (image_id,) + keypoints.shape + (array_to_blob(keypoints),),
+ )
+
+ def add_descriptors(self, image_id, descriptors):
+ descriptors = np.ascontiguousarray(descriptors, np.uint8)
+ self.execute(
+ "INSERT INTO descriptors VALUES (?, ?, ?, ?)",
+ (image_id,) + descriptors.shape + (array_to_blob(descriptors),),
+ )
+
+ def add_matches(self, image_id1, image_id2, matches):
+ assert len(matches.shape) == 2
+ assert matches.shape[1] == 2
+
+ if image_id1 > image_id2:
+ matches = matches[:, ::-1]
+
+ pair_id = image_ids_to_pair_id(image_id1, image_id2)
+ matches = np.asarray(matches, np.uint32)
+ self.execute(
+ "INSERT INTO matches VALUES (?, ?, ?, ?)",
+ (pair_id,) + matches.shape + (array_to_blob(matches),),
+ )
+
+ def add_two_view_geometry(
+ self,
+ image_id1,
+ image_id2,
+ matches,
+ F=np.eye(3),
+ E=np.eye(3),
+ H=np.eye(3),
+ qvec=np.array([1.0, 0.0, 0.0, 0.0]),
+ tvec=np.zeros(3),
+ config=2,
+ ):
+ assert len(matches.shape) == 2
+ assert matches.shape[1] == 2
+
+ if image_id1 > image_id2:
+ matches = matches[:, ::-1]
+
+ pair_id = image_ids_to_pair_id(image_id1, image_id2)
+ matches = np.asarray(matches, np.uint32)
+ F = np.asarray(F, dtype=np.float64)
+ E = np.asarray(E, dtype=np.float64)
+ H = np.asarray(H, dtype=np.float64)
+ qvec = np.asarray(qvec, dtype=np.float64)
+ tvec = np.asarray(tvec, dtype=np.float64)
+ self.execute(
+ "INSERT INTO two_view_geometries VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
+ (pair_id,)
+ + matches.shape
+ + (
+ array_to_blob(matches),
+ config,
+ array_to_blob(F),
+ array_to_blob(E),
+ array_to_blob(H),
+ array_to_blob(qvec),
+ array_to_blob(tvec),
+ ),
+ )
+
+
+def example_usage():
+ import os
+ import argparse
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--database_path", default="database.db")
+ args = parser.parse_args()
+
+ if os.path.exists(args.database_path):
+ print("ERROR: database path already exists -- will not modify it.")
+ return
+
+ # Open the database.
+
+ db = COLMAPDatabase.connect(args.database_path)
+
+ # For convenience, try creating all the tables upfront.
+
+ db.create_tables()
+
+ # Create dummy cameras.
+
+ model1, width1, height1, params1 = (
+ 0,
+ 1024,
+ 768,
+ np.array((1024.0, 512.0, 384.0)),
+ )
+ model2, width2, height2, params2 = (
+ 2,
+ 1024,
+ 768,
+ np.array((1024.0, 512.0, 384.0, 0.1)),
+ )
+
+ camera_id1 = db.add_camera(model1, width1, height1, params1)
+ camera_id2 = db.add_camera(model2, width2, height2, params2)
+
+ # Create dummy images.
+
+ image_id1 = db.add_image("image1.png", camera_id1)
+ image_id2 = db.add_image("image2.png", camera_id1)
+ image_id3 = db.add_image("image3.png", camera_id2)
+ image_id4 = db.add_image("image4.png", camera_id2)
+
+ # Create dummy keypoints.
+ #
+ # Note that COLMAP supports:
+ # - 2D keypoints: (x, y)
+ # - 4D keypoints: (x, y, theta, scale)
+ # - 6D affine keypoints: (x, y, a_11, a_12, a_21, a_22)
+
+ num_keypoints = 1000
+ keypoints1 = np.random.rand(num_keypoints, 2) * (width1, height1)
+ keypoints2 = np.random.rand(num_keypoints, 2) * (width1, height1)
+ keypoints3 = np.random.rand(num_keypoints, 2) * (width2, height2)
+ keypoints4 = np.random.rand(num_keypoints, 2) * (width2, height2)
+
+ db.add_keypoints(image_id1, keypoints1)
+ db.add_keypoints(image_id2, keypoints2)
+ db.add_keypoints(image_id3, keypoints3)
+ db.add_keypoints(image_id4, keypoints4)
+
+ # Create dummy matches.
+
+ M = 50
+ matches12 = np.random.randint(num_keypoints, size=(M, 2))
+ matches23 = np.random.randint(num_keypoints, size=(M, 2))
+ matches34 = np.random.randint(num_keypoints, size=(M, 2))
+
+ db.add_matches(image_id1, image_id2, matches12)
+ db.add_matches(image_id2, image_id3, matches23)
+ db.add_matches(image_id3, image_id4, matches34)
+
+ # Commit the data to the file.
+
+ db.commit()
+
+ # Read and check cameras.
+
+ rows = db.execute("SELECT * FROM cameras")
+
+ camera_id, model, width, height, params, prior = next(rows)
+ params = blob_to_array(params, np.float64)
+ assert camera_id == camera_id1
+ assert model == model1 and width == width1 and height == height1
+ assert np.allclose(params, params1)
+
+ camera_id, model, width, height, params, prior = next(rows)
+ params = blob_to_array(params, np.float64)
+ assert camera_id == camera_id2
+ assert model == model2 and width == width2 and height == height2
+ assert np.allclose(params, params2)
+
+ # Read and check keypoints.
+
+ keypoints = dict(
+ (image_id, blob_to_array(data, np.float32, (-1, 2)))
+ for image_id, data in db.execute("SELECT image_id, data FROM keypoints")
+ )
+
+ assert np.allclose(keypoints[image_id1], keypoints1)
+ assert np.allclose(keypoints[image_id2], keypoints2)
+ assert np.allclose(keypoints[image_id3], keypoints3)
+ assert np.allclose(keypoints[image_id4], keypoints4)
+
+ # Read and check matches.
+
+ pair_ids = [ # noqa: F841
+ image_ids_to_pair_id(*pair)
+ for pair in (
+ (image_id1, image_id2),
+ (image_id2, image_id3),
+ (image_id3, image_id4),
+ )
+ ]
+
+ matches = dict(
+ (pair_id_to_image_ids(pair_id), blob_to_array(data, np.uint32, (-1, 2)))
+ for pair_id, data in db.execute("SELECT pair_id, data FROM matches")
+ )
+
+ assert np.all(matches[(image_id1, image_id2)] == matches12)
+ assert np.all(matches[(image_id2, image_id3)] == matches23)
+ assert np.all(matches[(image_id3, image_id4)] == matches34)
+
+ # Clean up.
+
+ db.close()
+
+ if os.path.exists(args.database_path):
+ os.remove(args.database_path)
+
+
+if __name__ == "__main__":
+ example_usage()
diff --git a/imcui/hloc/utils/geometry.py b/imcui/hloc/utils/geometry.py
new file mode 100644
index 0000000000000000000000000000000000000000..5995cccda40354f2346ad84fa966614b4ccbfed0
--- /dev/null
+++ b/imcui/hloc/utils/geometry.py
@@ -0,0 +1,16 @@
+import numpy as np
+import pycolmap
+
+
+def to_homogeneous(p):
+ return np.pad(p, ((0, 0),) * (p.ndim - 1) + ((0, 1),), constant_values=1)
+
+
+def compute_epipolar_errors(j_from_i: pycolmap.Rigid3d, p2d_i, p2d_j):
+ j_E_i = j_from_i.essential_matrix()
+ l2d_j = to_homogeneous(p2d_i) @ j_E_i.T
+ l2d_i = to_homogeneous(p2d_j) @ j_E_i
+ dist = np.abs(np.sum(to_homogeneous(p2d_i) * l2d_i, axis=1))
+ errors_i = dist / np.linalg.norm(l2d_i[:, :2], axis=1)
+ errors_j = dist / np.linalg.norm(l2d_j[:, :2], axis=1)
+ return errors_i, errors_j
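+
+ # Note on the computation above: j_E_i is the essential matrix of the relative
+ # pose j_from_i, and l2d_i / l2d_j are the epipolar lines that the points of
+ # the other image induce in images i and j. The returned errors are
+ # point-to-epipolar-line distances, |x^T l| divided by the norm of the line's
+ # first two coefficients, expressed in the normalized camera coordinates in
+ # which p2d_i and p2d_j are given.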
diff --git a/imcui/hloc/utils/io.py b/imcui/hloc/utils/io.py
new file mode 100644
index 0000000000000000000000000000000000000000..1cd55d4c30b41c3754634a164312dc5e8c294274
--- /dev/null
+++ b/imcui/hloc/utils/io.py
@@ -0,0 +1,77 @@
+from typing import Tuple
+from pathlib import Path
+import numpy as np
+import cv2
+import h5py
+
+from .parsers import names_to_pair, names_to_pair_old
+
+
+def read_image(path, grayscale=False):
+ if grayscale:
+ mode = cv2.IMREAD_GRAYSCALE
+ else:
+ mode = cv2.IMREAD_COLOR
+ image = cv2.imread(str(path), mode)
+ if image is None:
+ raise ValueError(f"Cannot read image {path}.")
+ if not grayscale and len(image.shape) == 3:
+ image = image[:, :, ::-1] # BGR to RGB
+ return image
+
+
+def list_h5_names(path):
+ names = []
+ with h5py.File(str(path), "r", libver="latest") as fd:
+
+ def visit_fn(_, obj):
+ if isinstance(obj, h5py.Dataset):
+ names.append(obj.parent.name.strip("/"))
+
+ fd.visititems(visit_fn)
+ return list(set(names))
+
+
+def get_keypoints(
+ path: Path, name: str, return_uncertainty: bool = False
+) -> np.ndarray:
+ with h5py.File(str(path), "r", libver="latest") as hfile:
+ dset = hfile[name]["keypoints"]
+ p = dset.__array__()
+ uncertainty = dset.attrs.get("uncertainty")
+ if return_uncertainty:
+ return p, uncertainty
+ return p
+
+
+def find_pair(hfile: h5py.File, name0: str, name1: str):
+ pair = names_to_pair(name0, name1)
+ if pair in hfile:
+ return pair, False
+ pair = names_to_pair(name1, name0)
+ if pair in hfile:
+ return pair, True
+ # older, less efficient format
+ pair = names_to_pair_old(name0, name1)
+ if pair in hfile:
+ return pair, False
+ pair = names_to_pair_old(name1, name0)
+ if pair in hfile:
+ return pair, True
+ raise ValueError(
+ f"Could not find pair {(name0, name1)}... "
+ "Maybe you matched with a different list of pairs? "
+ )
+
+
+def get_matches(path: Path, name0: str, name1: str) -> Tuple[np.ndarray]:
+ with h5py.File(str(path), "r", libver="latest") as hfile:
+ pair, reverse = find_pair(hfile, name0, name1)
+ matches = hfile[pair]["matches0"].__array__()
+ scores = hfile[pair]["matching_scores0"].__array__()
+ idx = np.where(matches != -1)[0]
+ matches = np.stack([idx, matches[idx]], -1)
+ if reverse:
+ matches = np.flip(matches, -1)
+ scores = scores[idx]
+ return matches, scores
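+
+ # The returned `matches` is an (N, 2) array of keypoint indices, one row per
+ # match: (index into name0's keypoints, index into name1's keypoints). Entries
+ # of matches0 equal to -1 mark unmatched keypoints and are dropped; `scores`
+ # keeps the corresponding matching_scores0 values. If the pair is stored in
+ # reverse order in the file, the two columns are flipped so the convention
+ # above still holds.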
diff --git a/imcui/hloc/utils/parsers.py b/imcui/hloc/utils/parsers.py
new file mode 100644
index 0000000000000000000000000000000000000000..9407dcf916d67170e3f8d19581041a774a61e84b
--- /dev/null
+++ b/imcui/hloc/utils/parsers.py
@@ -0,0 +1,59 @@
+import logging
+from collections import defaultdict
+from pathlib import Path
+
+import numpy as np
+import pycolmap
+
+logger = logging.getLogger(__name__)
+
+
+def parse_image_list(path, with_intrinsics=False):
+ images = []
+ with open(path, "r") as f:
+ for line in f:
+ line = line.strip("\n")
+ if len(line) == 0 or line[0] == "#":
+ continue
+ name, *data = line.split()
+ if with_intrinsics:
+ model, width, height, *params = data
+ params = np.array(params, float)
+ cam = pycolmap.Camera(
+ model=model, width=int(width), height=int(height), params=params
+ )
+ images.append((name, cam))
+ else:
+ images.append(name)
+
+ assert len(images) > 0
+ logger.info(f"Imported {len(images)} images from {path.name}")
+ return images
+
+
+def parse_image_lists(paths, with_intrinsics=False):
+ images = []
+ files = list(Path(paths.parent).glob(paths.name))
+ assert len(files) > 0
+ for lfile in files:
+ images += parse_image_list(lfile, with_intrinsics=with_intrinsics)
+ return images
+
+
+def parse_retrieval(path):
+ retrieval = defaultdict(list)
+ with open(path, "r") as f:
+ for p in f.read().rstrip("\n").split("\n"):
+ if len(p) == 0:
+ continue
+ q, r = p.split()
+ retrieval[q].append(r)
+ return dict(retrieval)
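+
+ # Expected retrieval file format: one whitespace-separated "query reference"
+ # pair per line. For example, a file containing
+ # query/q1.jpg db/a.jpg
+ # query/q1.jpg db/b.jpg
+ # parses to {"query/q1.jpg": ["db/a.jpg", "db/b.jpg"]}.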
+
+
+def names_to_pair(name0, name1, separator="/"):
+ return separator.join((name0.replace("/", "-"), name1.replace("/", "-")))
+
+
+def names_to_pair_old(name0, name1):
+ return names_to_pair(name0, name1, separator="_")
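+
+ # Example: names_to_pair("a/img0.jpg", "b/img1.jpg") returns
+ # "a-img0.jpg/b-img1.jpg": slashes inside each image name are replaced by
+ # dashes so that "/" can serve as an unambiguous separator between the two
+ # names (and as the group separator of the HDF5 files written by hloc).
+ # names_to_pair_old uses "_" instead and is kept to read older files.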
diff --git a/imcui/hloc/utils/read_write_model.py b/imcui/hloc/utils/read_write_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..197921ded6d9cad3f365fd68a225822dc5411aee
--- /dev/null
+++ b/imcui/hloc/utils/read_write_model.py
@@ -0,0 +1,588 @@
+# Copyright (c) 2018, ETH Zurich and UNC Chapel Hill.
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+#
+# * Redistributions in binary form must reproduce the above copyright
+# notice, this list of conditions and the following disclaimer in the
+# documentation and/or other materials provided with the distribution.
+#
+# * Neither the name of ETH Zurich and UNC Chapel Hill nor the names of
+# its contributors may be used to endorse or promote products derived
+# from this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE
+# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+# POSSIBILITY OF SUCH DAMAGE.
+#
+# Author: Johannes L. Schoenberger (jsch-at-demuc-dot-de)
+
+import argparse
+import collections
+import logging
+import os
+import struct
+
+import numpy as np
+
+logger = logging.getLogger(__name__)
+
+
+CameraModel = collections.namedtuple(
+ "CameraModel", ["model_id", "model_name", "num_params"]
+)
+Camera = collections.namedtuple("Camera", ["id", "model", "width", "height", "params"])
+BaseImage = collections.namedtuple(
+ "Image", ["id", "qvec", "tvec", "camera_id", "name", "xys", "point3D_ids"]
+)
+Point3D = collections.namedtuple(
+ "Point3D", ["id", "xyz", "rgb", "error", "image_ids", "point2D_idxs"]
+)
+
+
+class Image(BaseImage):
+ def qvec2rotmat(self):
+ return qvec2rotmat(self.qvec)
+
+
+CAMERA_MODELS = {
+ CameraModel(model_id=0, model_name="SIMPLE_PINHOLE", num_params=3),
+ CameraModel(model_id=1, model_name="PINHOLE", num_params=4),
+ CameraModel(model_id=2, model_name="SIMPLE_RADIAL", num_params=4),
+ CameraModel(model_id=3, model_name="RADIAL", num_params=5),
+ CameraModel(model_id=4, model_name="OPENCV", num_params=8),
+ CameraModel(model_id=5, model_name="OPENCV_FISHEYE", num_params=8),
+ CameraModel(model_id=6, model_name="FULL_OPENCV", num_params=12),
+ CameraModel(model_id=7, model_name="FOV", num_params=5),
+ CameraModel(model_id=8, model_name="SIMPLE_RADIAL_FISHEYE", num_params=4),
+ CameraModel(model_id=9, model_name="RADIAL_FISHEYE", num_params=5),
+ CameraModel(model_id=10, model_name="THIN_PRISM_FISHEYE", num_params=12),
+}
+CAMERA_MODEL_IDS = dict(
+ [(camera_model.model_id, camera_model) for camera_model in CAMERA_MODELS]
+)
+CAMERA_MODEL_NAMES = dict(
+ [(camera_model.model_name, camera_model) for camera_model in CAMERA_MODELS]
+)
+
+
+def read_next_bytes(fid, num_bytes, format_char_sequence, endian_character="<"):
+ """Read and unpack the next bytes from a binary file.
+ :param fid:
+ :param num_bytes: Sum of combination of {2, 4, 8}, e.g. 2, 6, 16, 30, etc.
+ :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}.
+ :param endian_character: Any of {@, =, <, >, !}
+ :return: Tuple of read and unpacked values.
+ """
+ data = fid.read(num_bytes)
+ return struct.unpack(endian_character + format_char_sequence, data)
+
+
+def write_next_bytes(fid, data, format_char_sequence, endian_character="<"):
+ """pack and write to a binary file.
+ :param fid:
+ :param data: data to send; if multiple elements are sent at the same time,
+ they should be encapsulated either in a list or a tuple
+ :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}.
+ should be the same length as the data list or tuple
+ :param endian_character: Any of {@, =, <, >, !}
+ """
+ if isinstance(data, (list, tuple)):
+ bytes = struct.pack(endian_character + format_char_sequence, *data)
+ else:
+ bytes = struct.pack(endian_character + format_char_sequence, data)
+ fid.write(bytes)
+
+
+def read_cameras_text(path):
+ """
+ see: src/base/reconstruction.cc
+ void Reconstruction::WriteCamerasText(const std::string& path)
+ void Reconstruction::ReadCamerasText(const std::string& path)
+ """
+ cameras = {}
+ with open(path, "r") as fid:
+ while True:
+ line = fid.readline()
+ if not line:
+ break
+ line = line.strip()
+ if len(line) > 0 and line[0] != "#":
+ elems = line.split()
+ camera_id = int(elems[0])
+ model = elems[1]
+ width = int(elems[2])
+ height = int(elems[3])
+ params = np.array(tuple(map(float, elems[4:])))
+ cameras[camera_id] = Camera(
+ id=camera_id, model=model, width=width, height=height, params=params
+ )
+ return cameras
+
+
+def read_cameras_binary(path_to_model_file):
+ """
+ see: src/base/reconstruction.cc
+ void Reconstruction::WriteCamerasBinary(const std::string& path)
+ void Reconstruction::ReadCamerasBinary(const std::string& path)
+ """
+ cameras = {}
+ with open(path_to_model_file, "rb") as fid:
+ num_cameras = read_next_bytes(fid, 8, "Q")[0]
+ for _ in range(num_cameras):
+ camera_properties = read_next_bytes(
+ fid, num_bytes=24, format_char_sequence="iiQQ"
+ )
+ camera_id = camera_properties[0]
+ model_id = camera_properties[1]
+ model_name = CAMERA_MODEL_IDS[camera_properties[1]].model_name
+ width = camera_properties[2]
+ height = camera_properties[3]
+ num_params = CAMERA_MODEL_IDS[model_id].num_params
+ params = read_next_bytes(
+ fid, num_bytes=8 * num_params, format_char_sequence="d" * num_params
+ )
+ cameras[camera_id] = Camera(
+ id=camera_id,
+ model=model_name,
+ width=width,
+ height=height,
+ params=np.array(params),
+ )
+ assert len(cameras) == num_cameras
+ return cameras
+
+
+def write_cameras_text(cameras, path):
+ """
+ see: src/base/reconstruction.cc
+ void Reconstruction::WriteCamerasText(const std::string& path)
+ void Reconstruction::ReadCamerasText(const std::string& path)
+ """
+ HEADER = (
+ "# Camera list with one line of data per camera:\n"
+ + "# CAMERA_ID, MODEL, WIDTH, HEIGHT, PARAMS[]\n"
+ + "# Number of cameras: {}\n".format(len(cameras))
+ )
+ with open(path, "w") as fid:
+ fid.write(HEADER)
+ for _, cam in cameras.items():
+ to_write = [cam.id, cam.model, cam.width, cam.height, *cam.params]
+ line = " ".join([str(elem) for elem in to_write])
+ fid.write(line + "\n")
+
+
+def write_cameras_binary(cameras, path_to_model_file):
+ """
+ see: src/base/reconstruction.cc
+ void Reconstruction::WriteCamerasBinary(const std::string& path)
+ void Reconstruction::ReadCamerasBinary(const std::string& path)
+ """
+ with open(path_to_model_file, "wb") as fid:
+ write_next_bytes(fid, len(cameras), "Q")
+ for _, cam in cameras.items():
+ model_id = CAMERA_MODEL_NAMES[cam.model].model_id
+ camera_properties = [cam.id, model_id, cam.width, cam.height]
+ write_next_bytes(fid, camera_properties, "iiQQ")
+ for p in cam.params:
+ write_next_bytes(fid, float(p), "d")
+ return cameras
+
+
+def read_images_text(path):
+ """
+ see: src/base/reconstruction.cc
+ void Reconstruction::ReadImagesText(const std::string& path)
+ void Reconstruction::WriteImagesText(const std::string& path)
+ """
+ images = {}
+ with open(path, "r") as fid:
+ while True:
+ line = fid.readline()
+ if not line:
+ break
+ line = line.strip()
+ if len(line) > 0 and line[0] != "#":
+ elems = line.split()
+ image_id = int(elems[0])
+ qvec = np.array(tuple(map(float, elems[1:5])))
+ tvec = np.array(tuple(map(float, elems[5:8])))
+ camera_id = int(elems[8])
+ image_name = elems[9]
+ elems = fid.readline().split()
+ xys = np.column_stack(
+ [tuple(map(float, elems[0::3])), tuple(map(float, elems[1::3]))]
+ )
+ point3D_ids = np.array(tuple(map(int, elems[2::3])))
+ images[image_id] = Image(
+ id=image_id,
+ qvec=qvec,
+ tvec=tvec,
+ camera_id=camera_id,
+ name=image_name,
+ xys=xys,
+ point3D_ids=point3D_ids,
+ )
+ return images
+
+
+def read_images_binary(path_to_model_file):
+ """
+ see: src/base/reconstruction.cc
+ void Reconstruction::ReadImagesBinary(const std::string& path)
+ void Reconstruction::WriteImagesBinary(const std::string& path)
+ """
+ images = {}
+ with open(path_to_model_file, "rb") as fid:
+ num_reg_images = read_next_bytes(fid, 8, "Q")[0]
+ for _ in range(num_reg_images):
+ binary_image_properties = read_next_bytes(
+ fid, num_bytes=64, format_char_sequence="idddddddi"
+ )
+ image_id = binary_image_properties[0]
+ qvec = np.array(binary_image_properties[1:5])
+ tvec = np.array(binary_image_properties[5:8])
+ camera_id = binary_image_properties[8]
+ image_name = ""
+ current_char = read_next_bytes(fid, 1, "c")[0]
+ while current_char != b"\x00": # look for the ASCII 0 entry
+ image_name += current_char.decode("utf-8")
+ current_char = read_next_bytes(fid, 1, "c")[0]
+ num_points2D = read_next_bytes(fid, num_bytes=8, format_char_sequence="Q")[
+ 0
+ ]
+ x_y_id_s = read_next_bytes(
+ fid,
+ num_bytes=24 * num_points2D,
+ format_char_sequence="ddq" * num_points2D,
+ )
+ xys = np.column_stack(
+ [tuple(map(float, x_y_id_s[0::3])), tuple(map(float, x_y_id_s[1::3]))]
+ )
+ point3D_ids = np.array(tuple(map(int, x_y_id_s[2::3])))
+ images[image_id] = Image(
+ id=image_id,
+ qvec=qvec,
+ tvec=tvec,
+ camera_id=camera_id,
+ name=image_name,
+ xys=xys,
+ point3D_ids=point3D_ids,
+ )
+ return images
+
+
+def write_images_text(images, path):
+ """
+ see: src/base/reconstruction.cc
+ void Reconstruction::ReadImagesText(const std::string& path)
+ void Reconstruction::WriteImagesText(const std::string& path)
+ """
+ if len(images) == 0:
+ mean_observations = 0
+ else:
+ mean_observations = sum(
+ (len(img.point3D_ids) for _, img in images.items())
+ ) / len(images)
+ HEADER = (
+ "# Image list with two lines of data per image:\n"
+ + "# IMAGE_ID, QW, QX, QY, QZ, TX, TY, TZ, CAMERA_ID, NAME\n"
+ + "# POINTS2D[] as (X, Y, POINT3D_ID)\n"
+ + "# Number of images: {}, mean observations per image: {}\n".format(
+ len(images), mean_observations
+ )
+ )
+
+ with open(path, "w") as fid:
+ fid.write(HEADER)
+ for _, img in images.items():
+ image_header = [img.id, *img.qvec, *img.tvec, img.camera_id, img.name]
+ first_line = " ".join(map(str, image_header))
+ fid.write(first_line + "\n")
+
+ points_strings = []
+ for xy, point3D_id in zip(img.xys, img.point3D_ids):
+ points_strings.append(" ".join(map(str, [*xy, point3D_id])))
+ fid.write(" ".join(points_strings) + "\n")
+
+
+def write_images_binary(images, path_to_model_file):
+ """
+ see: src/base/reconstruction.cc
+ void Reconstruction::ReadImagesBinary(const std::string& path)
+ void Reconstruction::WriteImagesBinary(const std::string& path)
+ """
+ with open(path_to_model_file, "wb") as fid:
+ write_next_bytes(fid, len(images), "Q")
+ for _, img in images.items():
+ write_next_bytes(fid, img.id, "i")
+ write_next_bytes(fid, img.qvec.tolist(), "dddd")
+ write_next_bytes(fid, img.tvec.tolist(), "ddd")
+ write_next_bytes(fid, img.camera_id, "i")
+ for char in img.name:
+ write_next_bytes(fid, char.encode("utf-8"), "c")
+ write_next_bytes(fid, b"\x00", "c")
+ write_next_bytes(fid, len(img.point3D_ids), "Q")
+ for xy, p3d_id in zip(img.xys, img.point3D_ids):
+ write_next_bytes(fid, [*xy, p3d_id], "ddq")
+
+
+def read_points3D_text(path):
+ """
+ see: src/base/reconstruction.cc
+ void Reconstruction::ReadPoints3DText(const std::string& path)
+ void Reconstruction::WritePoints3DText(const std::string& path)
+ """
+ points3D = {}
+ with open(path, "r") as fid:
+ while True:
+ line = fid.readline()
+ if not line:
+ break
+ line = line.strip()
+ if len(line) > 0 and line[0] != "#":
+ elems = line.split()
+ point3D_id = int(elems[0])
+ xyz = np.array(tuple(map(float, elems[1:4])))
+ rgb = np.array(tuple(map(int, elems[4:7])))
+ error = float(elems[7])
+ image_ids = np.array(tuple(map(int, elems[8::2])))
+ point2D_idxs = np.array(tuple(map(int, elems[9::2])))
+ points3D[point3D_id] = Point3D(
+ id=point3D_id,
+ xyz=xyz,
+ rgb=rgb,
+ error=error,
+ image_ids=image_ids,
+ point2D_idxs=point2D_idxs,
+ )
+ return points3D
+
+
+def read_points3D_binary(path_to_model_file):
+ """
+ see: src/base/reconstruction.cc
+ void Reconstruction::ReadPoints3DBinary(const std::string& path)
+ void Reconstruction::WritePoints3DBinary(const std::string& path)
+ """
+ points3D = {}
+ with open(path_to_model_file, "rb") as fid:
+ num_points = read_next_bytes(fid, 8, "Q")[0]
+ for _ in range(num_points):
+ binary_point_line_properties = read_next_bytes(
+ fid, num_bytes=43, format_char_sequence="QdddBBBd"
+ )
+ point3D_id = binary_point_line_properties[0]
+ xyz = np.array(binary_point_line_properties[1:4])
+ rgb = np.array(binary_point_line_properties[4:7])
+ error = np.array(binary_point_line_properties[7])
+ track_length = read_next_bytes(fid, num_bytes=8, format_char_sequence="Q")[
+ 0
+ ]
+ track_elems = read_next_bytes(
+ fid,
+ num_bytes=8 * track_length,
+ format_char_sequence="ii" * track_length,
+ )
+ image_ids = np.array(tuple(map(int, track_elems[0::2])))
+ point2D_idxs = np.array(tuple(map(int, track_elems[1::2])))
+ points3D[point3D_id] = Point3D(
+ id=point3D_id,
+ xyz=xyz,
+ rgb=rgb,
+ error=error,
+ image_ids=image_ids,
+ point2D_idxs=point2D_idxs,
+ )
+ return points3D
+
+
+def write_points3D_text(points3D, path):
+ """
+ see: src/base/reconstruction.cc
+ void Reconstruction::ReadPoints3DText(const std::string& path)
+ void Reconstruction::WritePoints3DText(const std::string& path)
+ """
+ if len(points3D) == 0:
+ mean_track_length = 0
+ else:
+ mean_track_length = sum(
+ (len(pt.image_ids) for _, pt in points3D.items())
+ ) / len(points3D)
+ HEADER = (
+ "# 3D point list with one line of data per point:\n"
+ + "# POINT3D_ID, X, Y, Z, R, G, B, ERROR, TRACK[] as (IMAGE_ID, POINT2D_IDX)\n" # noqa: E501
+ + "# Number of points: {}, mean track length: {}\n".format(
+ len(points3D), mean_track_length
+ )
+ )
+
+ with open(path, "w") as fid:
+ fid.write(HEADER)
+ for _, pt in points3D.items():
+ point_header = [pt.id, *pt.xyz, *pt.rgb, pt.error]
+ fid.write(" ".join(map(str, point_header)) + " ")
+ track_strings = []
+ for image_id, point2D in zip(pt.image_ids, pt.point2D_idxs):
+ track_strings.append(" ".join(map(str, [image_id, point2D])))
+ fid.write(" ".join(track_strings) + "\n")
+
+
+def write_points3D_binary(points3D, path_to_model_file):
+ """
+ see: src/base/reconstruction.cc
+ void Reconstruction::ReadPoints3DBinary(const std::string& path)
+ void Reconstruction::WritePoints3DBinary(const std::string& path)
+ """
+ with open(path_to_model_file, "wb") as fid:
+ write_next_bytes(fid, len(points3D), "Q")
+ for _, pt in points3D.items():
+ write_next_bytes(fid, pt.id, "Q")
+ write_next_bytes(fid, pt.xyz.tolist(), "ddd")
+ write_next_bytes(fid, pt.rgb.tolist(), "BBB")
+ write_next_bytes(fid, pt.error, "d")
+ track_length = pt.image_ids.shape[0]
+ write_next_bytes(fid, track_length, "Q")
+ for image_id, point2D_id in zip(pt.image_ids, pt.point2D_idxs):
+ write_next_bytes(fid, [image_id, point2D_id], "ii")
+
+
+def detect_model_format(path, ext):
+ if (
+ os.path.isfile(os.path.join(path, "cameras" + ext))
+ and os.path.isfile(os.path.join(path, "images" + ext))
+ and os.path.isfile(os.path.join(path, "points3D" + ext))
+ ):
+ return True
+
+ return False
+
+
+def read_model(path, ext=""):
+ # try to detect the extension automatically
+ if ext == "":
+ if detect_model_format(path, ".bin"):
+ ext = ".bin"
+ elif detect_model_format(path, ".txt"):
+ ext = ".txt"
+ else:
+ try:
+ cameras, images, points3D = read_model(os.path.join(path, "model/"))
+ logger.warning("This SfM file structure was deprecated in hloc v1.1")
+ return cameras, images, points3D
+ except FileNotFoundError:
+ raise FileNotFoundError(
+ f"Could not find binary or text COLMAP model at {path}"
+ )
+
+ if ext == ".txt":
+ cameras = read_cameras_text(os.path.join(path, "cameras" + ext))
+ images = read_images_text(os.path.join(path, "images" + ext))
+ points3D = read_points3D_text(os.path.join(path, "points3D") + ext)
+ else:
+ cameras = read_cameras_binary(os.path.join(path, "cameras" + ext))
+ images = read_images_binary(os.path.join(path, "images" + ext))
+ points3D = read_points3D_binary(os.path.join(path, "points3D") + ext)
+ return cameras, images, points3D
+
+
+def write_model(cameras, images, points3D, path, ext=".bin"):
+ if ext == ".txt":
+ write_cameras_text(cameras, os.path.join(path, "cameras" + ext))
+ write_images_text(images, os.path.join(path, "images" + ext))
+ write_points3D_text(points3D, os.path.join(path, "points3D") + ext)
+ else:
+ write_cameras_binary(cameras, os.path.join(path, "cameras" + ext))
+ write_images_binary(images, os.path.join(path, "images" + ext))
+ write_points3D_binary(points3D, os.path.join(path, "points3D") + ext)
+ return cameras, images, points3D
+
+
+def qvec2rotmat(qvec):
+ return np.array(
+ [
+ [
+ 1 - 2 * qvec[2] ** 2 - 2 * qvec[3] ** 2,
+ 2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3],
+ 2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2],
+ ],
+ [
+ 2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3],
+ 1 - 2 * qvec[1] ** 2 - 2 * qvec[3] ** 2,
+ 2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1],
+ ],
+ [
+ 2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2],
+ 2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1],
+ 1 - 2 * qvec[1] ** 2 - 2 * qvec[2] ** 2,
+ ],
+ ]
+ )
+
+
+def rotmat2qvec(R):
+ Rxx, Ryx, Rzx, Rxy, Ryy, Rzy, Rxz, Ryz, Rzz = R.flat
+ K = (
+ np.array(
+ [
+ [Rxx - Ryy - Rzz, 0, 0, 0],
+ [Ryx + Rxy, Ryy - Rxx - Rzz, 0, 0],
+ [Rzx + Rxz, Rzy + Ryz, Rzz - Rxx - Ryy, 0],
+ [Ryz - Rzy, Rzx - Rxz, Rxy - Ryx, Rxx + Ryy + Rzz],
+ ]
+ )
+ / 3.0
+ )
+ eigvals, eigvecs = np.linalg.eigh(K)
+ qvec = eigvecs[[3, 0, 1, 2], np.argmax(eigvals)]
+ if qvec[0] < 0:
+ qvec *= -1
+ return qvec
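+
+ # Convention note: qvec is (qw, qx, qy, qz), the Hamilton quaternion ordering
+ # used by COLMAP. A quick sanity check that follows from the formulas above:
+ # qvec2rotmat([1, 0, 0, 0]) is the 3x3 identity matrix, and
+ # rotmat2qvec(np.eye(3)) returns [1, 0, 0, 0] (the sign is normalized so that
+ # qw >= 0).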
+
+
+def main():
+ parser = argparse.ArgumentParser(
+ description="Read and write COLMAP binary and text models"
+ )
+ parser.add_argument("--input_model", help="path to input model folder")
+ parser.add_argument(
+ "--input_format",
+ choices=[".bin", ".txt"],
+ help="input model format",
+ default="",
+ )
+ parser.add_argument("--output_model", help="path to output model folder")
+ parser.add_argument(
+ "--output_format",
+ choices=[".bin", ".txt"],
+ help="output model format",
+ default=".txt",
+ )
+ args = parser.parse_args()
+
+ cameras, images, points3D = read_model(path=args.input_model, ext=args.input_format)
+
+ print("num_cameras:", len(cameras))
+ print("num_images:", len(images))
+ print("num_points3D:", len(points3D))
+
+ if args.output_model is not None:
+ write_model(
+ cameras, images, points3D, path=args.output_model, ext=args.output_format
+ )
+
+
+if __name__ == "__main__":
+ main()
diff --git a/imcui/hloc/utils/viz.py b/imcui/hloc/utils/viz.py
new file mode 100644
index 0000000000000000000000000000000000000000..f87b51706652e47e6f8fe7f5f67fc5362a970ecd
--- /dev/null
+++ b/imcui/hloc/utils/viz.py
@@ -0,0 +1,145 @@
+"""
+2D visualization primitives based on Matplotlib.
+
+1) Plot images with `plot_images`.
+2) Call `plot_keypoints` or `plot_matches` any number of times.
+3) Optionally: save a .png or .pdf plot (nice in papers!) with `save_plot`.
+"""
+
+import matplotlib
+import matplotlib.pyplot as plt
+import matplotlib.patheffects as path_effects
+import numpy as np
+
+
+def cm_RdGn(x):
+ """Custom colormap: red (0) -> yellow (0.5) -> green (1)."""
+ x = np.clip(x, 0, 1)[..., None] * 2
+ c = x * np.array([[0, 1.0, 0]]) + (2 - x) * np.array([[1.0, 0, 0]])
+ return np.clip(c, 0, 1)
+
+
+def plot_images(
+ imgs, titles=None, cmaps="gray", dpi=100, pad=0.5, adaptive=True, figsize=4.5
+):
+ """Plot a set of images horizontally.
+ Args:
+ imgs: a list of NumPy or PyTorch images, RGB (H, W, 3) or mono (H, W).
+ titles: a list of strings, as titles for each image.
+ cmaps: colormaps for monochrome images.
+ adaptive: whether the figure size should fit the image aspect ratios.
+ """
+ n = len(imgs)
+ if not isinstance(cmaps, (list, tuple)):
+ cmaps = [cmaps] * n
+
+ if adaptive:
+ ratios = [i.shape[1] / i.shape[0] for i in imgs] # W / H
+ else:
+ ratios = [4 / 3] * n
+ figsize = [sum(ratios) * figsize, figsize]
+ fig, axs = plt.subplots(
+ 1, n, figsize=figsize, dpi=dpi, gridspec_kw={"width_ratios": ratios}
+ )
+ if n == 1:
+ axs = [axs]
+ for i, (img, ax) in enumerate(zip(imgs, axs)):
+ ax.imshow(img, cmap=plt.get_cmap(cmaps[i]))
+ ax.set_axis_off()
+ if titles:
+ ax.set_title(titles[i])
+ fig.tight_layout(pad=pad)
+ return fig
+
+
+def plot_keypoints(kpts, colors="lime", ps=4):
+ """Plot keypoints for existing images.
+ Args:
+ kpts: list of ndarrays of size (N, 2).
+ colors: string, or a list of colors (one per set of keypoints).
+ ps: size of the keypoints as float.
+ """
+ if not isinstance(colors, list):
+ colors = [colors] * len(kpts)
+ axes = plt.gcf().axes
+ try:
+ for a, k, c in zip(axes, kpts, colors):
+ a.scatter(k[:, 0], k[:, 1], c=c, s=ps, linewidths=0)
+ except IndexError:
+ pass
+
+
+def plot_matches(kpts0, kpts1, color=None, lw=1.5, ps=4, indices=(0, 1), a=1.0):
+ """Plot matches for a pair of existing images.
+ Args:
+ kpts0, kpts1: corresponding keypoints of size (N, 2).
+ color: color of each match, string or RGB tuple. Random if not given.
+ lw: width of the lines.
+ ps: size of the end points (no endpoint if ps=0)
+ indices: indices of the images to draw the matches on.
+ a: alpha opacity of the match lines.
+ """
+ fig = plt.gcf()
+ ax = fig.axes
+ assert len(ax) > max(indices)
+ ax0, ax1 = ax[indices[0]], ax[indices[1]]
+ fig.canvas.draw()
+
+ assert len(kpts0) == len(kpts1)
+ if color is None:
+ color = matplotlib.cm.hsv(np.random.rand(len(kpts0))).tolist()
+ elif len(color) > 0 and not isinstance(color[0], (tuple, list)):
+ color = [color] * len(kpts0)
+
+ if lw > 0:
+ # transform the points into the figure coordinate system
+ for i in range(len(kpts0)):
+ fig.add_artist(
+ matplotlib.patches.ConnectionPatch(
+ xyA=(kpts0[i, 0], kpts0[i, 1]),
+ coordsA=ax0.transData,
+ xyB=(kpts1[i, 0], kpts1[i, 1]),
+ coordsB=ax1.transData,
+ zorder=1,
+ color=color[i],
+ linewidth=lw,
+ alpha=a,
+ )
+ )
+
+ # freeze the axes to prevent the transform to change
+ ax0.autoscale(enable=False)
+ ax1.autoscale(enable=False)
+
+ if ps > 0:
+ ax0.scatter(kpts0[:, 0], kpts0[:, 1], c=color, s=ps)
+ ax1.scatter(kpts1[:, 0], kpts1[:, 1], c=color, s=ps)
+
+
+def add_text(
+ idx,
+ text,
+ pos=(0.01, 0.99),
+ fs=15,
+ color="w",
+ lcolor="k",
+ lwidth=2,
+ ha="left",
+ va="top",
+):
+ ax = plt.gcf().axes[idx]
+ t = ax.text(
+ *pos, text, fontsize=fs, ha=ha, va=va, color=color, transform=ax.transAxes
+ )
+ if lcolor is not None:
+ t.set_path_effects(
+ [
+ path_effects.Stroke(linewidth=lwidth, foreground=lcolor),
+ path_effects.Normal(),
+ ]
+ )
+
+
+def save_plot(path, **kw):
+ """Save the current figure without any white margin."""
+ plt.savefig(path, bbox_inches="tight", pad_inches=0, **kw)
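+
+
+ # Minimal usage sketch of the workflow described in the module docstring
+ # (variable names and the output path are placeholders):
+ #
+ # plot_images([image0, image1], titles=["query", "reference"])
+ # plot_keypoints([kpts0, kpts1])
+ # plot_matches(mkpts0, mkpts1, color=None, a=0.5)
+ # save_plot("matches.pdf")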
diff --git a/imcui/hloc/utils/viz_3d.py b/imcui/hloc/utils/viz_3d.py
new file mode 100644
index 0000000000000000000000000000000000000000..f9fd1b1a02eaee99e061bb392a561ebbc00d93b1
--- /dev/null
+++ b/imcui/hloc/utils/viz_3d.py
@@ -0,0 +1,203 @@
+"""
+3D visualization based on plotly.
+Works for a small number of points and cameras, might be slow otherwise.
+
+1) Initialize a figure with `init_figure`
+2) Add 3D points, camera frustums, or both as a pycolmap.Reconstruction
+
+Written by Paul-Edouard Sarlin and Philipp Lindenberger.
+"""
+
+from typing import Optional
+
+import numpy as np
+import plotly.graph_objects as go
+import pycolmap
+
+
+def to_homogeneous(points):
+ pad = np.ones((points.shape[:-1] + (1,)), dtype=points.dtype)
+ return np.concatenate([points, pad], axis=-1)
+
+
+def init_figure(height: int = 800) -> go.Figure:
+ """Initialize a 3D figure."""
+ fig = go.Figure()
+ axes = dict(
+ visible=False,
+ showbackground=False,
+ showgrid=False,
+ showline=False,
+ showticklabels=True,
+ autorange=True,
+ )
+ fig.update_layout(
+ template="plotly_dark",
+ height=height,
+ scene_camera=dict(
+ eye=dict(x=0.0, y=-0.1, z=-2),
+ up=dict(x=0, y=-1.0, z=0),
+ projection=dict(type="orthographic"),
+ ),
+ scene=dict(
+ xaxis=axes,
+ yaxis=axes,
+ zaxis=axes,
+ aspectmode="data",
+ dragmode="orbit",
+ ),
+ margin=dict(l=0, r=0, b=0, t=0, pad=0),
+ legend=dict(orientation="h", yanchor="top", y=0.99, xanchor="left", x=0.1),
+ )
+ return fig
+
+
+def plot_points(
+ fig: go.Figure,
+ pts: np.ndarray,
+ color: str = "rgba(255, 0, 0, 1)",
+ ps: int = 2,
+ colorscale: Optional[str] = None,
+ name: Optional[str] = None,
+):
+ """Plot a set of 3D points."""
+ x, y, z = pts.T
+ tr = go.Scatter3d(
+ x=x,
+ y=y,
+ z=z,
+ mode="markers",
+ name=name,
+ legendgroup=name,
+ marker=dict(size=ps, color=color, line_width=0.0, colorscale=colorscale),
+ )
+ fig.add_trace(tr)
+
+
+def plot_camera(
+ fig: go.Figure,
+ R: np.ndarray,
+ t: np.ndarray,
+ K: np.ndarray,
+ color: str = "rgb(0, 0, 255)",
+ name: Optional[str] = None,
+ legendgroup: Optional[str] = None,
+ fill: bool = False,
+ size: float = 1.0,
+ text: Optional[str] = None,
+):
+ """Plot a camera frustum from pose and intrinsic matrix."""
+ W, H = K[0, 2] * 2, K[1, 2] * 2
+ corners = np.array([[0, 0], [W, 0], [W, H], [0, H], [0, 0]])
+ if size is not None:
+ image_extent = max(size * W / 1024.0, size * H / 1024.0)
+ world_extent = max(W, H) / (K[0, 0] + K[1, 1]) / 0.5
+ scale = 0.5 * image_extent / world_extent
+ else:
+ scale = 1.0
+ corners = to_homogeneous(corners) @ np.linalg.inv(K).T
+ corners = (corners / 2 * scale) @ R.T + t
+ legendgroup = legendgroup if legendgroup is not None else name
+
+ x, y, z = np.concatenate(([t], corners)).T
+ i = [0, 0, 0, 0]
+ j = [1, 2, 3, 4]
+ k = [2, 3, 4, 1]
+
+ if fill:
+ pyramid = go.Mesh3d(
+ x=x,
+ y=y,
+ z=z,
+ color=color,
+ i=i,
+ j=j,
+ k=k,
+ legendgroup=legendgroup,
+ name=name,
+ showlegend=False,
+ hovertemplate=text.replace("\n", "<br>"),
+ )
+ fig.add_trace(pyramid)
+
+ triangles = np.vstack((i, j, k)).T
+ vertices = np.concatenate(([t], corners))
+ tri_points = np.array([vertices[i] for i in triangles.reshape(-1)])
+ x, y, z = tri_points.T
+
+ pyramid = go.Scatter3d(
+ x=x,
+ y=y,
+ z=z,
+ mode="lines",
+ legendgroup=legendgroup,
+ name=name,
+ line=dict(color=color, width=1),
+ showlegend=False,
+ hovertemplate=text.replace("\n", "<br>"),
+ )
+ fig.add_trace(pyramid)
+
+
+def plot_camera_colmap(
+ fig: go.Figure,
+ image: pycolmap.Image,
+ camera: pycolmap.Camera,
+ name: Optional[str] = None,
+ **kwargs,
+):
+ """Plot a camera frustum from PyCOLMAP objects"""
+ world_t_camera = image.cam_from_world.inverse()
+ plot_camera(
+ fig,
+ world_t_camera.rotation.matrix(),
+ world_t_camera.translation,
+ camera.calibration_matrix(),
+ name=name or str(image.image_id),
+ text=str(image),
+ **kwargs,
+ )
+
+
+def plot_cameras(fig: go.Figure, reconstruction: pycolmap.Reconstruction, **kwargs):
+ """Plot the camera frustums of all images in a reconstruction."""
+ for image_id, image in reconstruction.images.items():
+ plot_camera_colmap(
+ fig, image, reconstruction.cameras[image.camera_id], **kwargs
+ )
+
+
+def plot_reconstruction(
+ fig: go.Figure,
+ rec: pycolmap.Reconstruction,
+ max_reproj_error: float = 6.0,
+ color: str = "rgb(0, 0, 255)",
+ name: Optional[str] = None,
+ min_track_length: int = 2,
+ points: bool = True,
+ cameras: bool = True,
+ points_rgb: bool = True,
+ cs: float = 1.0,
+):
+ # Filter outliers
+ bbs = rec.compute_bounding_box(0.001, 0.999)
+ # Filter points, use original reproj error here
+ p3Ds = [
+ p3D
+ for _, p3D in rec.points3D.items()
+ if (
+ (p3D.xyz >= bbs[0]).all()
+ and (p3D.xyz <= bbs[1]).all()
+ and p3D.error <= max_reproj_error
+ and p3D.track.length() >= min_track_length
+ )
+ ]
+ xyzs = [p3D.xyz for p3D in p3Ds]
+ if points_rgb:
+ pcolor = [p3D.color for p3D in p3Ds]
+ else:
+ pcolor = color
+ if points:
+ plot_points(fig, np.array(xyzs), color=pcolor, ps=1, name=name)
+ if cameras:
+ plot_cameras(fig, rec, color=color, legendgroup=name, size=cs)
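+
+
+ # Minimal usage sketch of the workflow described in the module docstring
+ # (the reconstruction path is a placeholder):
+ #
+ # rec = pycolmap.Reconstruction("path/to/sfm_model")
+ # fig = init_figure()
+ # plot_reconstruction(fig, rec, color="rgba(255, 0, 0, 0.5)", name="sfm")
+ # fig.show()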
diff --git a/imcui/hloc/visualization.py b/imcui/hloc/visualization.py
new file mode 100644
index 0000000000000000000000000000000000000000..456c2ee991efe4895c664f5bbd475c24fa789bf8
--- /dev/null
+++ b/imcui/hloc/visualization.py
@@ -0,0 +1,178 @@
+import pickle
+import random
+
+import numpy as np
+import pycolmap
+from matplotlib import cm
+
+from .utils.io import read_image
+from .utils.viz import (
+ add_text,
+ cm_RdGn,
+ plot_images,
+ plot_keypoints,
+ plot_matches,
+)
+
+
+def visualize_sfm_2d(
+ reconstruction,
+ image_dir,
+ color_by="visibility",
+ selected=[],
+ n=1,
+ seed=0,
+ dpi=75,
+):
+ assert image_dir.exists()
+ if not isinstance(reconstruction, pycolmap.Reconstruction):
+ reconstruction = pycolmap.Reconstruction(reconstruction)
+
+ if not selected:
+ image_ids = reconstruction.reg_image_ids()
+ selected = random.Random(seed).sample(image_ids, min(n, len(image_ids)))
+
+ for i in selected:
+ image = reconstruction.images[i]
+ keypoints = np.array([p.xy for p in image.points2D])
+ visible = np.array([p.has_point3D() for p in image.points2D])
+
+ if color_by == "visibility":
+ color = [(0, 0, 1) if v else (1, 0, 0) for v in visible]
+ text = f"visible: {np.count_nonzero(visible)}/{len(visible)}"
+ elif color_by == "track_length":
+ tl = np.array(
+ [
+ (
+ reconstruction.points3D[p.point3D_id].track.length()
+ if p.has_point3D()
+ else 1
+ )
+ for p in image.points2D
+ ]
+ )
+ max_, med_ = np.max(tl), np.median(tl[tl > 1])
+ tl = np.log(tl)
+ color = cm.jet(tl / tl.max()).tolist()
+ text = f"max/median track length: {max_}/{med_}"
+ elif color_by == "depth":
+ p3ids = [p.point3D_id for p in image.points2D if p.has_point3D()]
+ z = np.array(
+ [
+ (image.cam_from_world * reconstruction.points3D[j].xyz)[-1]
+ for j in p3ids
+ ]
+ )
+ z -= z.min()
+ color = cm.jet(z / np.percentile(z, 99.9))
+ text = f"visible: {np.count_nonzero(visible)}/{len(visible)}"
+ keypoints = keypoints[visible]
+ else:
+ raise NotImplementedError(f"Coloring not implemented: {color_by}.")
+
+ name = image.name
+ fig = plot_images([read_image(image_dir / name)], dpi=dpi)
+ plot_keypoints([keypoints], colors=[color], ps=4)
+ add_text(0, text)
+ add_text(0, name, pos=(0.01, 0.01), fs=5, lcolor=None, va="bottom")
+ return fig
+
+
+def visualize_loc(
+ results,
+ image_dir,
+ reconstruction=None,
+ db_image_dir=None,
+ selected=[],
+ n=1,
+ seed=0,
+ prefix=None,
+ **kwargs,
+):
+ assert image_dir.exists()
+
+ with open(str(results) + "_logs.pkl", "rb") as f:
+ logs = pickle.load(f)
+
+ if not selected:
+ queries = list(logs["loc"].keys())
+ if prefix:
+ queries = [q for q in queries if q.startswith(prefix)]
+ selected = random.Random(seed).sample(queries, min(n, len(queries)))
+
+ if reconstruction is not None:
+ if not isinstance(reconstruction, pycolmap.Reconstruction):
+ reconstruction = pycolmap.Reconstruction(reconstruction)
+
+ for qname in selected:
+ loc = logs["loc"][qname]
+ visualize_loc_from_log(
+ image_dir, qname, loc, reconstruction, db_image_dir, **kwargs
+ )
+
+
+def visualize_loc_from_log(
+ image_dir,
+ query_name,
+ loc,
+ reconstruction=None,
+ db_image_dir=None,
+ top_k_db=2,
+ dpi=75,
+):
+ q_image = read_image(image_dir / query_name)
+ if loc.get("covisibility_clustering", False):
+ # select the first, largest cluster if the localization failed
+ loc = loc["log_clusters"][loc["best_cluster"] or 0]
+
+ inliers = np.array(loc["PnP_ret"]["inliers"])
+ mkp_q = loc["keypoints_query"]
+ n = len(loc["db"])
+ if reconstruction is not None:
+ # for each pair of query keypoint and its matched 3D point,
+ # we need to find its corresponding keypoint in each database image
+ # that observes it. We also count the number of inliers in each.
+ kp_idxs, kp_to_3D_to_db = loc["keypoint_index_to_db"]
+ counts = np.zeros(n)
+ dbs_kp_q_db = [[] for _ in range(n)]
+ inliers_dbs = [[] for _ in range(n)]
+ for i, (inl, (p3D_id, db_idxs)) in enumerate(zip(inliers, kp_to_3D_to_db)):
+ track = reconstruction.points3D[p3D_id].track
+ track = {el.image_id: el.point2D_idx for el in track.elements}
+ for db_idx in db_idxs:
+ counts[db_idx] += inl
+ kp_db = track[loc["db"][db_idx]]
+ dbs_kp_q_db[db_idx].append((i, kp_db))
+ inliers_dbs[db_idx].append(inl)
+ else:
+ # for inloc the database keypoints are already in the logs
+ assert "keypoints_db" in loc
+ assert "indices_db" in loc
+ counts = np.array([np.sum(loc["indices_db"][inliers] == i) for i in range(n)])
+
+ # display the database images with the most inlier matches
+ db_sort = np.argsort(-counts)
+ for db_idx in db_sort[:top_k_db]:
+ if reconstruction is not None:
+ db = reconstruction.images[loc["db"][db_idx]]
+ db_name = db.name
+ db_kp_q_db = np.array(dbs_kp_q_db[db_idx])
+ kp_q = mkp_q[db_kp_q_db[:, 0]]
+ kp_db = np.array([db.points2D[i].xy for i in db_kp_q_db[:, 1]])
+ inliers_db = inliers_dbs[db_idx]
+ else:
+ db_name = loc["db"][db_idx]
+ kp_q = mkp_q[loc["indices_db"] == db_idx]
+ kp_db = loc["keypoints_db"][loc["indices_db"] == db_idx]
+ inliers_db = inliers[loc["indices_db"] == db_idx]
+
+ db_image = read_image((db_image_dir or image_dir) / db_name)
+ color = cm_RdGn(inliers_db).tolist()
+ text = f"inliers: {sum(inliers_db)}/{len(inliers_db)}"
+
+ plot_images([q_image, db_image], dpi=dpi)
+ plot_matches(kp_q, kp_db, color, a=0.1)
+ add_text(0, text)
+ opts = dict(pos=(0.01, 0.01), fs=5, lcolor=None, va="bottom")
+ add_text(0, query_name, **opts)
+ add_text(1, db_name, **opts)
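+
+
+# Usage sketch (illustrative; the paths below are placeholders): given a results
+# path whose logs were dumped alongside it as `<results>_logs.pkl` and an SfM
+# reconstruction, a few random queries can be visualized with
+#
+#   rec = pycolmap.Reconstruction("outputs/sfm")
+#   visualize_loc(Path("outputs/results"), Path("datasets/images"),
+#                 reconstruction=rec, n=2, top_k_db=2)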
diff --git a/imcui/third_party/ALIKE/alike.py b/imcui/third_party/ALIKE/alike.py
new file mode 100644
index 0000000000000000000000000000000000000000..303616d52581efce0ae0eb86af70f5ea8984909d
--- /dev/null
+++ b/imcui/third_party/ALIKE/alike.py
@@ -0,0 +1,143 @@
+import logging
+import os
+import cv2
+import torch
+from copy import deepcopy
+import torch.nn.functional as F
+from torchvision.transforms import ToTensor
+import math
+
+from alnet import ALNet
+from soft_detect import DKD
+import time
+
+configs = {
+ 'alike-t': {'c1': 8, 'c2': 16, 'c3': 32, 'c4': 64, 'dim': 64, 'single_head': True, 'radius': 2,
+ 'model_path': os.path.join(os.path.split(__file__)[0], 'models', 'alike-t.pth')},
+ 'alike-s': {'c1': 8, 'c2': 16, 'c3': 48, 'c4': 96, 'dim': 96, 'single_head': True, 'radius': 2,
+ 'model_path': os.path.join(os.path.split(__file__)[0], 'models', 'alike-s.pth')},
+ 'alike-n': {'c1': 16, 'c2': 32, 'c3': 64, 'c4': 128, 'dim': 128, 'single_head': True, 'radius': 2,
+ 'model_path': os.path.join(os.path.split(__file__)[0], 'models', 'alike-n.pth')},
+ 'alike-l': {'c1': 32, 'c2': 64, 'c3': 128, 'c4': 128, 'dim': 128, 'single_head': False, 'radius': 2,
+ 'model_path': os.path.join(os.path.split(__file__)[0], 'models', 'alike-l.pth')},
+}
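+
+# Usage sketch (illustrative; assumes the weight files listed above exist and
+# that CUDA is available -- otherwise pass device='cpu'):
+#
+#   model = ALike(**configs['alike-t'], device='cuda',
+#                 top_k=-1, scores_th=0.2, n_limit=5000)
+#   pred = model(rgb_image)  # rgb_image: HxWx3 RGB numpy array
+#   keypoints, descriptors = pred['keypoints'], pred['descriptors']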
+
+
+class ALike(ALNet):
+ def __init__(self,
+ # ================================== feature encoder
+ c1: int = 32, c2: int = 64, c3: int = 128, c4: int = 128, dim: int = 128,
+ single_head: bool = False,
+ # ================================== detect parameters
+ radius: int = 2,
+ top_k: int = 500, scores_th: float = 0.5,
+ n_limit: int = 5000,
+ device: str = 'cpu',
+ model_path: str = ''
+ ):
+ super().__init__(c1, c2, c3, c4, dim, single_head)
+ self.radius = radius
+ self.top_k = top_k
+ self.n_limit = n_limit
+ self.scores_th = scores_th
+ self.dkd = DKD(radius=self.radius, top_k=self.top_k,
+ scores_th=self.scores_th, n_limit=self.n_limit)
+ self.device = device
+
+ if model_path != '':
+ state_dict = torch.load(model_path, self.device)
+ self.load_state_dict(state_dict)
+ self.to(self.device)
+ self.eval()
+ logging.info(f'Loaded model parameters from {model_path}')
+ logging.info(
+ f"Number of model parameters: {sum(p.numel() for p in self.parameters() if p.requires_grad) / 1e3}KB")
+
+ def extract_dense_map(self, image, ret_dict=False):
+ # ====================================================
+        # check the image size; it should be an integer multiple of 2^5.
+        # if it is not, pad the image with zeros.
+ device = image.device
+ b, c, h, w = image.shape
+ h_ = math.ceil(h / 32) * 32 if h % 32 != 0 else h
+ w_ = math.ceil(w / 32) * 32 if w % 32 != 0 else w
+ if h_ != h:
+ h_padding = torch.zeros(b, c, h_ - h, w, device=device)
+ image = torch.cat([image, h_padding], dim=2)
+ if w_ != w:
+ w_padding = torch.zeros(b, c, h_, w_ - w, device=device)
+ image = torch.cat([image, w_padding], dim=3)
+ # ====================================================
+
+ scores_map, descriptor_map = super().forward(image)
+
+ # ====================================================
+ if h_ != h or w_ != w:
+ descriptor_map = descriptor_map[:, :, :h, :w]
+ scores_map = scores_map[:, :, :h, :w] # Bx1xHxW
+ # ====================================================
+
+ # BxCxHxW
+ descriptor_map = torch.nn.functional.normalize(descriptor_map, p=2, dim=1)
+
+ if ret_dict:
+ return {'descriptor_map': descriptor_map, 'scores_map': scores_map, }
+ else:
+ return descriptor_map, scores_map
+
+ def forward(self, img, image_size_max=99999, sort=False, sub_pixel=False):
+ """
+ :param img: np.array HxWx3, RGB
+        :param image_size_max: maximum image size; larger images are resized to fit
+        :param sort: sort keypoints by scores
+        :param sub_pixel: whether to use sub-pixel accuracy
+        :return: a dictionary with 'keypoints', 'descriptors', 'scores', 'scores_map', and 'time'
+ """
+ H, W, three = img.shape
+ assert three == 3, "input image shape should be [HxWx3]"
+
+ # ==================== image size constraint
+ image = deepcopy(img)
+ max_hw = max(H, W)
+ if max_hw > image_size_max:
+ ratio = float(image_size_max / max_hw)
+ image = cv2.resize(image, dsize=None, fx=ratio, fy=ratio)
+
+ # ==================== convert image to tensor
+ image = torch.from_numpy(image).to(self.device).to(torch.float32).permute(2, 0, 1)[None] / 255.0
+
+ # ==================== extract keypoints
+ start = time.time()
+
+ with torch.no_grad():
+ descriptor_map, scores_map = self.extract_dense_map(image)
+ keypoints, descriptors, scores, _ = self.dkd(scores_map, descriptor_map,
+ sub_pixel=sub_pixel)
+ keypoints, descriptors, scores = keypoints[0], descriptors[0], scores[0]
+ keypoints = (keypoints + 1) / 2 * keypoints.new_tensor([[W - 1, H - 1]])
+
+ if sort:
+ indices = torch.argsort(scores, descending=True)
+ keypoints = keypoints[indices]
+ descriptors = descriptors[indices]
+ scores = scores[indices]
+
+ end = time.time()
+
+ return {'keypoints': keypoints.cpu().numpy(),
+ 'descriptors': descriptors.cpu().numpy(),
+ 'scores': scores.cpu().numpy(),
+ 'scores_map': scores_map.cpu().numpy(),
+ 'time': end - start, }
+
+
+if __name__ == '__main__':
+ import numpy as np
+ from thop import profile
+
+ net = ALike(c1=32, c2=64, c3=128, c4=128, dim=128, single_head=False)
+
+ image = np.random.random((640, 480, 3)).astype(np.float32)
+ flops, params = profile(net, inputs=(image, 9999, False), verbose=False)
+    print('{:<30} {:<8} GFLOPs'.format('Computational complexity: ', flops / 1e9))
+    print('{:<30} {:<8} K'.format('Number of parameters: ', params / 1e3))
diff --git a/imcui/third_party/ALIKE/alnet.py b/imcui/third_party/ALIKE/alnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..53127063233660c7b96aa15e89aa4a8a1a340dd1
--- /dev/null
+++ b/imcui/third_party/ALIKE/alnet.py
@@ -0,0 +1,164 @@
+import torch
+from torch import nn
+from torchvision.models import resnet
+from typing import Optional, Callable
+
+
+class ConvBlock(nn.Module):
+ def __init__(self, in_channels, out_channels,
+ gate: Optional[Callable[..., nn.Module]] = None,
+ norm_layer: Optional[Callable[..., nn.Module]] = None):
+ super().__init__()
+ if gate is None:
+ self.gate = nn.ReLU(inplace=True)
+ else:
+ self.gate = gate
+ if norm_layer is None:
+ norm_layer = nn.BatchNorm2d
+ self.conv1 = resnet.conv3x3(in_channels, out_channels)
+ self.bn1 = norm_layer(out_channels)
+ self.conv2 = resnet.conv3x3(out_channels, out_channels)
+ self.bn2 = norm_layer(out_channels)
+
+ def forward(self, x):
+ x = self.gate(self.bn1(self.conv1(x))) # B x in_channels x H x W
+ x = self.gate(self.bn2(self.conv2(x))) # B x out_channels x H x W
+ return x
+
+
+# copied from torchvision/models/resnet.py (BasicBlock)
+class ResBlock(nn.Module):
+ expansion: int = 1
+
+ def __init__(
+ self,
+ inplanes: int,
+ planes: int,
+ stride: int = 1,
+ downsample: Optional[nn.Module] = None,
+ groups: int = 1,
+ base_width: int = 64,
+ dilation: int = 1,
+ gate: Optional[Callable[..., nn.Module]] = None,
+ norm_layer: Optional[Callable[..., nn.Module]] = None
+ ) -> None:
+ super(ResBlock, self).__init__()
+ if gate is None:
+ self.gate = nn.ReLU(inplace=True)
+ else:
+ self.gate = gate
+ if norm_layer is None:
+ norm_layer = nn.BatchNorm2d
+ if groups != 1 or base_width != 64:
+ raise ValueError('ResBlock only supports groups=1 and base_width=64')
+ if dilation > 1:
+ raise NotImplementedError("Dilation > 1 not supported in ResBlock")
+ # Both self.conv1 and self.downsample layers downsample the input when stride != 1
+ self.conv1 = resnet.conv3x3(inplanes, planes, stride)
+ self.bn1 = norm_layer(planes)
+ self.conv2 = resnet.conv3x3(planes, planes)
+ self.bn2 = norm_layer(planes)
+ self.downsample = downsample
+ self.stride = stride
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ identity = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.gate(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+
+ if self.downsample is not None:
+ identity = self.downsample(x)
+
+ out += identity
+ out = self.gate(out)
+
+ return out
+
+
+class ALNet(nn.Module):
+ def __init__(self, c1: int = 32, c2: int = 64, c3: int = 128, c4: int = 128, dim: int = 128,
+ single_head: bool = True,
+ ):
+ super().__init__()
+
+ self.gate = nn.ReLU(inplace=True)
+
+ self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
+ self.pool4 = nn.MaxPool2d(kernel_size=4, stride=4)
+
+ self.block1 = ConvBlock(3, c1, self.gate, nn.BatchNorm2d)
+
+ self.block2 = ResBlock(inplanes=c1, planes=c2, stride=1,
+ downsample=nn.Conv2d(c1, c2, 1),
+ gate=self.gate,
+ norm_layer=nn.BatchNorm2d)
+ self.block3 = ResBlock(inplanes=c2, planes=c3, stride=1,
+ downsample=nn.Conv2d(c2, c3, 1),
+ gate=self.gate,
+ norm_layer=nn.BatchNorm2d)
+ self.block4 = ResBlock(inplanes=c3, planes=c4, stride=1,
+ downsample=nn.Conv2d(c3, c4, 1),
+ gate=self.gate,
+ norm_layer=nn.BatchNorm2d)
+
+ # ================================== feature aggregation
+ self.conv1 = resnet.conv1x1(c1, dim // 4)
+ self.conv2 = resnet.conv1x1(c2, dim // 4)
+ self.conv3 = resnet.conv1x1(c3, dim // 4)
+ self.conv4 = resnet.conv1x1(dim, dim // 4)
+ self.upsample2 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
+ self.upsample4 = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True)
+ self.upsample8 = nn.Upsample(scale_factor=8, mode='bilinear', align_corners=True)
+ self.upsample32 = nn.Upsample(scale_factor=32, mode='bilinear', align_corners=True)
+
+ # ================================== detector and descriptor head
+ self.single_head = single_head
+ if not self.single_head:
+ self.convhead1 = resnet.conv1x1(dim, dim)
+ self.convhead2 = resnet.conv1x1(dim, dim + 1)
+
+ def forward(self, image):
+ # ================================== feature encoder
+ x1 = self.block1(image) # B x c1 x H x W
+ x2 = self.pool2(x1)
+ x2 = self.block2(x2) # B x c2 x H/2 x W/2
+ x3 = self.pool4(x2)
+ x3 = self.block3(x3) # B x c3 x H/8 x W/8
+ x4 = self.pool4(x3)
+ x4 = self.block4(x4) # B x dim x H/32 x W/32
+
+ # ================================== feature aggregation
+ x1 = self.gate(self.conv1(x1)) # B x dim//4 x H x W
+ x2 = self.gate(self.conv2(x2)) # B x dim//4 x H//2 x W//2
+ x3 = self.gate(self.conv3(x3)) # B x dim//4 x H//8 x W//8
+ x4 = self.gate(self.conv4(x4)) # B x dim//4 x H//32 x W//32
+ x2_up = self.upsample2(x2) # B x dim//4 x H x W
+ x3_up = self.upsample8(x3) # B x dim//4 x H x W
+ x4_up = self.upsample32(x4) # B x dim//4 x H x W
+ x1234 = torch.cat([x1, x2_up, x3_up, x4_up], dim=1)
+
+ # ================================== detector and descriptor head
+ if not self.single_head:
+ x1234 = self.gate(self.convhead1(x1234))
+ x = self.convhead2(x1234) # B x dim+1 x H x W
+
+ descriptor_map = x[:, :-1, :, :]
+ scores_map = torch.sigmoid(x[:, -1, :, :]).unsqueeze(1)
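+        # scores_map: B x 1 x H x W in (0, 1); descriptor_map: B x dim x H x W
+        # (left unnormalised here, L2-normalisation is applied in ALike.extract_dense_map)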
+
+ return scores_map, descriptor_map
+
+
+if __name__ == '__main__':
+ from thop import profile
+
+ net = ALNet(c1=16, c2=32, c3=64, c4=128, dim=128, single_head=True)
+
+ image = torch.randn(1, 3, 640, 480)
+ flops, params = profile(net, inputs=(image,), verbose=False)
+    print('{:<30} {:<8} GFLOPs'.format('Computational complexity: ', flops / 1e9))
+    print('{:<30} {:<8} K'.format('Number of parameters: ', params / 1e3))
diff --git a/imcui/third_party/ALIKE/demo.py b/imcui/third_party/ALIKE/demo.py
new file mode 100644
index 0000000000000000000000000000000000000000..9bfbefdd26cfeceefc75f90d1c44a7f922c624a5
--- /dev/null
+++ b/imcui/third_party/ALIKE/demo.py
@@ -0,0 +1,167 @@
+import copy
+import os
+import cv2
+import glob
+import logging
+import argparse
+import numpy as np
+from tqdm import tqdm
+from alike import ALike, configs
+
+
+class ImageLoader(object):
+ def __init__(self, filepath: str):
+ self.N = 3000
+ if filepath.startswith('camera'):
+ camera = int(filepath[6:])
+ self.cap = cv2.VideoCapture(camera)
+ if not self.cap.isOpened():
+ raise IOError(f"Can't open camera {camera}!")
+ logging.info(f'Opened camera {camera}')
+ self.mode = 'camera'
+ elif os.path.exists(filepath):
+ if os.path.isfile(filepath):
+ self.cap = cv2.VideoCapture(filepath)
+ if not self.cap.isOpened():
+ raise IOError(f"Can't open video {filepath}!")
+ rate = self.cap.get(cv2.CAP_PROP_FPS)
+ self.N = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT)) - 1
+ duration = self.N / rate
+ logging.info(f'Opened video {filepath}')
+ logging.info(f'Frames: {self.N}, FPS: {rate}, Duration: {duration}s')
+ self.mode = 'video'
+ else:
+ self.images = glob.glob(os.path.join(filepath, '*.png')) + \
+ glob.glob(os.path.join(filepath, '*.jpg')) + \
+ glob.glob(os.path.join(filepath, '*.ppm'))
+ self.images.sort()
+ self.N = len(self.images)
+ logging.info(f'Loading {self.N} images')
+ self.mode = 'images'
+ else:
+            raise IOError(f"Invalid filepath (expected 'cameraX', an image directory, or a video file): {filepath}")
+
+ def __getitem__(self, item):
+ if self.mode == 'camera' or self.mode == 'video':
+ if item > self.N:
+ return None
+ ret, img = self.cap.read()
+ if not ret:
+                raise IOError("Can't read image from camera")
+ if self.mode == 'video':
+ self.cap.set(cv2.CAP_PROP_POS_FRAMES, item)
+ elif self.mode == 'images':
+ filename = self.images[item]
+ img = cv2.imread(filename)
+ if img is None:
+ raise Exception('Error reading image %s' % filename)
+ return img
+
+ def __len__(self):
+ return self.N
+
+
+class SimpleTracker(object):
+ def __init__(self):
+ self.pts_prev = None
+ self.desc_prev = None
+
+ def update(self, img, pts, desc):
+ N_matches = 0
+ if self.pts_prev is None:
+ self.pts_prev = pts
+ self.desc_prev = desc
+
+ out = copy.deepcopy(img)
+ for pt1 in pts:
+ p1 = (int(round(pt1[0])), int(round(pt1[1])))
+ cv2.circle(out, p1, 1, (0, 0, 255), -1, lineType=16)
+ else:
+ matches = self.mnn_mather(self.desc_prev, desc)
+ mpts1, mpts2 = self.pts_prev[matches[:, 0]], pts[matches[:, 1]]
+ N_matches = len(matches)
+
+ out = copy.deepcopy(img)
+ for pt1, pt2 in zip(mpts1, mpts2):
+ p1 = (int(round(pt1[0])), int(round(pt1[1])))
+ p2 = (int(round(pt2[0])), int(round(pt2[1])))
+ cv2.line(out, p1, p2, (0, 255, 0), lineType=16)
+ cv2.circle(out, p2, 1, (0, 0, 255), -1, lineType=16)
+
+ self.pts_prev = pts
+ self.desc_prev = desc
+
+ return out, N_matches
+
+ def mnn_mather(self, desc1, desc2):
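+        # Mutual nearest-neighbour matching: keep a pair (i, j) only if j is the
+        # best match of descriptor i and i is the best match of descriptor j;
+        # similarities below 0.9 are zeroed first to discard weak candidates.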
+ sim = desc1 @ desc2.transpose()
+ sim[sim < 0.9] = 0
+ nn12 = np.argmax(sim, axis=1)
+ nn21 = np.argmax(sim, axis=0)
+ ids1 = np.arange(0, sim.shape[0])
+ mask = (ids1 == nn21[nn12])
+ matches = np.stack([ids1[mask], nn12[mask]])
+ return matches.transpose()
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser(description='ALike Demo.')
+ parser.add_argument('input', type=str, default='',
+ help='Image directory or movie file or "camera0" (for webcam0).')
+ parser.add_argument('--model', choices=['alike-t', 'alike-s', 'alike-n', 'alike-l'], default="alike-t",
+ help="The model configuration")
+ parser.add_argument('--device', type=str, default='cuda', help="Running device (default: cuda).")
+ parser.add_argument('--top_k', type=int, default=-1,
+ help='Detect top K keypoints. -1 for threshold based mode, >0 for top K mode. (default: -1)')
+ parser.add_argument('--scores_th', type=float, default=0.2,
+ help='Detector score threshold (default: 0.2).')
+ parser.add_argument('--n_limit', type=int, default=5000,
+ help='Maximum number of keypoints to be detected (default: 5000).')
+ parser.add_argument('--no_display', action='store_true',
+ help='Do not display images to screen. Useful if running remotely (default: False).')
+ parser.add_argument('--no_sub_pixel', action='store_true',
+ help='Do not detect sub-pixel keypoints (default: False).')
+ args = parser.parse_args()
+
+ logging.basicConfig(level=logging.INFO)
+
+ image_loader = ImageLoader(args.input)
+ model = ALike(**configs[args.model],
+ device=args.device,
+ top_k=args.top_k,
+ scores_th=args.scores_th,
+ n_limit=args.n_limit)
+ tracker = SimpleTracker()
+
+ if not args.no_display:
+ logging.info("Press 'q' to stop!")
+ cv2.namedWindow(args.model)
+
+ runtime = []
+ progress_bar = tqdm(image_loader)
+ for img in progress_bar:
+ if img is None:
+ break
+
+ img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+ pred = model(img_rgb, sub_pixel=not args.no_sub_pixel)
+ kpts = pred['keypoints']
+ desc = pred['descriptors']
+ runtime.append(pred['time'])
+
+ out, N_matches = tracker.update(img, kpts, desc)
+
+ ave_fps = (1. / np.stack(runtime)).mean()
+ status = f"Fps:{ave_fps:.1f}, Keypoints/Matches: {len(kpts)}/{N_matches}"
+ progress_bar.set_description(status)
+
+ if not args.no_display:
+ cv2.setWindowTitle(args.model, args.model + ': ' + status)
+ cv2.imshow(args.model, out)
+ if cv2.waitKey(1) == ord('q'):
+ break
+
+ logging.info('Finished!')
+ if not args.no_display:
+ logging.info('Press any key to exit!')
+ cv2.waitKey()
diff --git a/imcui/third_party/ALIKE/hseq/eval.py b/imcui/third_party/ALIKE/hseq/eval.py
new file mode 100644
index 0000000000000000000000000000000000000000..abca625044013a0cd34a518223c32d3ec8abb8a3
--- /dev/null
+++ b/imcui/third_party/ALIKE/hseq/eval.py
@@ -0,0 +1,162 @@
+import cv2
+import os
+from tqdm import tqdm
+import torch
+import numpy as np
+from extract import extract_method
+
+use_cuda = torch.cuda.is_available()
+device = torch.device('cuda' if use_cuda else 'cpu')
+
+methods = ['d2', 'lfnet', 'superpoint', 'r2d2', 'aslfeat', 'disk',
+ 'alike-n', 'alike-l', 'alike-n-ms', 'alike-l-ms']
+names = ['D2-Net(MS)', 'LF-Net(MS)', 'SuperPoint', 'R2D2(MS)', 'ASLFeat(MS)', 'DISK',
+ 'ALike-N', 'ALike-L', 'ALike-N(MS)', 'ALike-L(MS)']
+
+top_k = None
+n_i = 52
+n_v = 56
+cache_dir = 'hseq/cache'
+dataset_path = 'hseq/hpatches-sequences-release'
+
+
+def generate_read_function(method, extension='ppm'):
+ def read_function(seq_name, im_idx):
+ aux = np.load(os.path.join(dataset_path, seq_name, '%d.%s.%s' % (im_idx, extension, method)))
+ if top_k is None:
+ return aux['keypoints'], aux['descriptors']
+ else:
+ assert ('scores' in aux)
+ ids = np.argsort(aux['scores'])[-top_k:]
+ return aux['keypoints'][ids, :], aux['descriptors'][ids, :]
+
+ return read_function
+
+
+def mnn_matcher(descriptors_a, descriptors_b):
+ device = descriptors_a.device
+ sim = descriptors_a @ descriptors_b.t()
+ nn12 = torch.max(sim, dim=1)[1]
+ nn21 = torch.max(sim, dim=0)[1]
+ ids1 = torch.arange(0, sim.shape[0], device=device)
+ mask = (ids1 == nn21[nn12])
+ matches = torch.stack([ids1[mask], nn12[mask]])
+ return matches.t().data.cpu().numpy()
+
+
+def homo_trans(coord, H):
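+    # Apply homography H to 2D points: lift to homogeneous coordinates [x, y, 1],
+    # multiply by H, then dehomogenise by dividing by the third coordinate.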
+ kpt_num = coord.shape[0]
+ homo_coord = np.concatenate((coord, np.ones((kpt_num, 1))), axis=-1)
+ proj_coord = np.matmul(H, homo_coord.T).T
+ proj_coord = proj_coord / proj_coord[:, 2][..., None]
+ proj_coord = proj_coord[:, 0:2]
+ return proj_coord
+
+
+def benchmark_features(read_feats):
+ lim = [1, 5]
+ rng = np.arange(lim[0], lim[1] + 1)
+
+ seq_names = sorted(os.listdir(dataset_path))
+
+ n_feats = []
+ n_matches = []
+ seq_type = []
+ i_err = {thr: 0 for thr in rng}
+ v_err = {thr: 0 for thr in rng}
+
+ i_err_homo = {thr: 0 for thr in rng}
+ v_err_homo = {thr: 0 for thr in rng}
+
+ for seq_idx, seq_name in tqdm(enumerate(seq_names), total=len(seq_names)):
+ keypoints_a, descriptors_a = read_feats(seq_name, 1)
+ n_feats.append(keypoints_a.shape[0])
+
+ # =========== compute homography
+ ref_img = cv2.imread(os.path.join(dataset_path, seq_name, '1.ppm'))
+ ref_img_shape = ref_img.shape
+
+ for im_idx in range(2, 7):
+ keypoints_b, descriptors_b = read_feats(seq_name, im_idx)
+ n_feats.append(keypoints_b.shape[0])
+
+ matches = mnn_matcher(
+ torch.from_numpy(descriptors_a).to(device=device),
+ torch.from_numpy(descriptors_b).to(device=device)
+ )
+
+ homography = np.loadtxt(os.path.join(dataset_path, seq_name, "H_1_" + str(im_idx)))
+
+ pos_a = keypoints_a[matches[:, 0], : 2]
+ pos_a_h = np.concatenate([pos_a, np.ones([matches.shape[0], 1])], axis=1)
+ pos_b_proj_h = np.transpose(np.dot(homography, np.transpose(pos_a_h)))
+ pos_b_proj = pos_b_proj_h[:, : 2] / pos_b_proj_h[:, 2:]
+
+ pos_b = keypoints_b[matches[:, 1], : 2]
+
+ dist = np.sqrt(np.sum((pos_b - pos_b_proj) ** 2, axis=1))
+
+ n_matches.append(matches.shape[0])
+ seq_type.append(seq_name[0])
+
+ if dist.shape[0] == 0:
+ dist = np.array([float("inf")])
+
+ for thr in rng:
+ if seq_name[0] == 'i':
+ i_err[thr] += np.mean(dist <= thr)
+ else:
+ v_err[thr] += np.mean(dist <= thr)
+
+ # =========== compute homography
+ gt_homo = homography
+ pred_homo, _ = cv2.findHomography(keypoints_a[matches[:, 0], : 2], keypoints_b[matches[:, 1], : 2],
+ cv2.RANSAC)
+ if pred_homo is None:
+ homo_dist = np.array([float("inf")])
+ else:
+ corners = np.array([[0, 0],
+ [ref_img_shape[1] - 1, 0],
+ [0, ref_img_shape[0] - 1],
+ [ref_img_shape[1] - 1, ref_img_shape[0] - 1]])
+ real_warped_corners = homo_trans(corners, gt_homo)
+ warped_corners = homo_trans(corners, pred_homo)
+ homo_dist = np.mean(np.linalg.norm(real_warped_corners - warped_corners, axis=1))
+
+ for thr in rng:
+ if seq_name[0] == 'i':
+ i_err_homo[thr] += np.mean(homo_dist <= thr)
+ else:
+ v_err_homo[thr] += np.mean(homo_dist <= thr)
+
+ seq_type = np.array(seq_type)
+ n_feats = np.array(n_feats)
+ n_matches = np.array(n_matches)
+
+ return i_err, v_err, i_err_homo, v_err_homo, [seq_type, n_feats, n_matches]
+
+
+if __name__ == '__main__':
+ errors = {}
+ for method in methods:
+ output_file = os.path.join(cache_dir, method + '.npy')
+ read_function = generate_read_function(method)
+ if os.path.exists(output_file):
+ errors[method] = np.load(output_file, allow_pickle=True)
+ else:
+ extract_method(method)
+ errors[method] = benchmark_features(read_function)
+ np.save(output_file, errors[method])
+
+ for name, method in zip(names, methods):
+ i_err, v_err, i_err_hom, v_err_hom, _ = errors[method]
+
+ print(f"====={name}=====")
+ print(f"MMA@1 MMA@2 MMA@3 MHA@1 MHA@2 MHA@3: ", end='')
+ for thr in range(1, 4):
+ err = (i_err[thr] + v_err[thr]) / ((n_i + n_v) * 5)
+ print(f"{err * 100:.2f}%", end=' ')
+ for thr in range(1, 4):
+ err_hom = (i_err_hom[thr] + v_err_hom[thr]) / ((n_i + n_v) * 5)
+ print(f"{err_hom * 100:.2f}%", end=' ')
+ print('')
diff --git a/imcui/third_party/ALIKE/hseq/extract.py b/imcui/third_party/ALIKE/hseq/extract.py
new file mode 100644
index 0000000000000000000000000000000000000000..1342e40dd2d0e1d1986e90f995c95b17972ec4e1
--- /dev/null
+++ b/imcui/third_party/ALIKE/hseq/extract.py
@@ -0,0 +1,159 @@
+import os
+import sys
+import cv2
+from pathlib import Path
+import numpy as np
+import torch
+import torch.utils.data as data
+from tqdm import tqdm
+from copy import deepcopy
+from torchvision.transforms import ToTensor
+
+sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
+from alike import ALike, configs
+
+dataset_root = 'hseq/hpatches-sequences-release'
+use_cuda = torch.cuda.is_available()
+device = 'cuda' if use_cuda else 'cpu'
+methods = ['alike-n', 'alike-l', 'alike-n-ms', 'alike-l-ms']
+
+
+class HPatchesDataset(data.Dataset):
+ def __init__(self, root: str = dataset_root, alteration: str = 'all'):
+ """
+ Args:
+ root: dataset root path
+            alteration: 'all', 'i' for illumination, or 'v' for viewpoint
+        """
+        assert (Path(root).exists()), f"Dataset root path {root} does not exist!"
+ self.root = root
+
+ # get all image file name
+ self.image0_list = []
+ self.image1_list = []
+ self.homographies = []
+ folders = [x for x in Path(self.root).iterdir() if x.is_dir()]
+ self.seqs = []
+ for folder in folders:
+ if alteration == 'i' and folder.stem[0] != 'i':
+ continue
+ if alteration == 'v' and folder.stem[0] != 'v':
+ continue
+
+ self.seqs.append(folder)
+
+ self.len = len(self.seqs)
+        assert (self.len > 0), f'Cannot find any HPatches sequences in {self.root}'
+
+ def __getitem__(self, item):
+ folder = self.seqs[item]
+
+ imgs = []
+ homos = []
+ for i in range(1, 7):
+ img = cv2.imread(str(folder / f'{i}.ppm'), cv2.IMREAD_COLOR)
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # HxWxC
+ imgs.append(img)
+
+ if i != 1:
+ homo = np.loadtxt(str(folder / f'H_1_{i}')).astype('float32')
+ homos.append(homo)
+
+ return imgs, homos, folder.stem
+
+ def __len__(self):
+ return self.len
+
+ def name(self):
+ return self.__class__
+
+
+def extract_multiscale(model, img, scale_f=2 ** 0.5,
+ min_scale=1., max_scale=1.,
+ min_size=0., max_size=99999.,
+ image_size_max=99999,
+ n_k=0, sort=False):
+ H_, W_, three = img.shape
+ assert three == 3, "input image shape should be [HxWx3]"
+
+ old_bm = torch.backends.cudnn.benchmark
+    torch.backends.cudnn.benchmark = False  # input sizes vary across scales, so benchmarking would not help
+
+ # ==================== image size constraint
+ image = deepcopy(img)
+ max_hw = max(H_, W_)
+ if max_hw > image_size_max:
+ ratio = float(image_size_max / max_hw)
+ image = cv2.resize(image, dsize=None, fx=ratio, fy=ratio)
+
+ # ==================== convert image to tensor
+ H, W, three = image.shape
+ image = ToTensor()(image).unsqueeze(0)
+ image = image.to(device)
+
+ s = 1.0 # current scale factor
+ keypoints, descriptors, scores, scores_maps, descriptor_maps = [], [], [], [], []
+ while s + 0.001 >= max(min_scale, min_size / max(H, W)):
+ if s - 0.001 <= min(max_scale, max_size / max(H, W)):
+ nh, nw = image.shape[2:]
+
+ # extract descriptors
+ with torch.no_grad():
+ descriptor_map, scores_map = model.extract_dense_map(image)
+ keypoints_, descriptors_, scores_, _ = model.dkd(scores_map, descriptor_map)
+
+ keypoints.append(keypoints_[0])
+ descriptors.append(descriptors_[0])
+ scores.append(scores_[0])
+
+ s /= scale_f
+
+ # down-scale the image for next iteration
+ nh, nw = round(H * s), round(W * s)
+ image = torch.nn.functional.interpolate(image, (nh, nw), mode='bilinear', align_corners=False)
+
+ # restore value
+ torch.backends.cudnn.benchmark = old_bm
+
+ keypoints = torch.cat(keypoints)
+ descriptors = torch.cat(descriptors)
+ scores = torch.cat(scores)
+ keypoints = (keypoints + 1) / 2 * keypoints.new_tensor([[W_ - 1, H_ - 1]])
+
+ if sort or 0 < n_k < len(keypoints):
+ indices = torch.argsort(scores, descending=True)
+ keypoints = keypoints[indices]
+ descriptors = descriptors[indices]
+ scores = scores[indices]
+
+ if 0 < n_k < len(keypoints):
+ keypoints = keypoints[0:n_k]
+ descriptors = descriptors[0:n_k]
+ scores = scores[0:n_k]
+
+ return {'keypoints': keypoints, 'descriptors': descriptors, 'scores': scores}
+
+
+def extract_method(m):
+ hpatches = HPatchesDataset(root=dataset_root, alteration='all')
+ model = m[:7]
+ min_scale = 0.3 if m[8:] == 'ms' else 1.0
+
+ model = ALike(**configs[model], device=device, top_k=0, scores_th=0.2, n_limit=5000)
+
+ progbar = tqdm(hpatches, desc='Extracting for {}'.format(m))
+ for imgs, homos, seq_name in progbar:
+ for i in range(1, 7):
+ img = imgs[i - 1]
+ pred = extract_multiscale(model, img, min_scale=min_scale, max_scale=1, sort=False, n_k=5000)
+ kpts, descs, scores = pred['keypoints'], pred['descriptors'], pred['scores']
+
+ with open(os.path.join(dataset_root, seq_name, f'{i}.ppm.{m}'), 'wb') as f:
+ np.savez(f, keypoints=kpts.cpu().numpy(),
+ scores=scores.cpu().numpy(),
+ descriptors=descs.cpu().numpy())
+
+
+if __name__ == '__main__':
+ for method in methods:
+ extract_method(method)
diff --git a/imcui/third_party/ALIKE/soft_detect.py b/imcui/third_party/ALIKE/soft_detect.py
new file mode 100644
index 0000000000000000000000000000000000000000..2d23cd13b8a7db9b0398fdc1b235564222d30c90
--- /dev/null
+++ b/imcui/third_party/ALIKE/soft_detect.py
@@ -0,0 +1,194 @@
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+
+# coordinates system
+# ------------------------------> [ x: range=-1.0~1.0; w: range=0~W ]
+# | -----------------------------
+# | | |
+# | | |
+# | | |
+# | | image |
+# | | |
+# | | |
+# | | |
+# | |---------------------------|
+# v
+# [ y: range=-1.0~1.0; h: range=0~H ]
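+#
+# conversion used throughout this file (see detect_keypoints / sample_descriptor):
+#   x_pixel = (x_norm + 1) / 2 * (W - 1),  y_pixel = (y_norm + 1) / 2 * (H - 1)
+# and the inverse x_norm = x_pixel / (W - 1) * 2 - 1 (likewise for y)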
+
+def simple_nms(scores, nms_radius: int):
+ """ Fast Non-maximum suppression to remove nearby points """
+ assert (nms_radius >= 0)
+
+ def max_pool(x):
+ return torch.nn.functional.max_pool2d(
+ x, kernel_size=nms_radius * 2 + 1, stride=1, padding=nms_radius)
+
+ zeros = torch.zeros_like(scores)
+ max_mask = scores == max_pool(scores)
+
+ for _ in range(2):
+ supp_mask = max_pool(max_mask.float()) > 0
+ supp_scores = torch.where(supp_mask, zeros, scores)
+ new_max_mask = supp_scores == max_pool(supp_scores)
+ max_mask = max_mask | (new_max_mask & (~supp_mask))
+ return torch.where(max_mask, scores, zeros)
+
+
+def sample_descriptor(descriptor_map, kpts, bilinear_interp=False):
+ """
+ :param descriptor_map: BxCxHxW
+ :param kpts: list, len=B, each is Nx2 (keypoints) [h,w]
+ :param bilinear_interp: bool, whether to use bilinear interpolation
+ :return: descriptors: list, len=B, each is NxD
+ """
+ batch_size, channel, height, width = descriptor_map.shape
+
+ descriptors = []
+ for index in range(batch_size):
+ kptsi = kpts[index] # Nx2,(x,y)
+
+ if bilinear_interp:
+ descriptors_ = torch.nn.functional.grid_sample(descriptor_map[index].unsqueeze(0), kptsi.view(1, 1, -1, 2),
+ mode='bilinear', align_corners=True)[0, :, 0, :] # CxN
+ else:
+ kptsi = (kptsi + 1) / 2 * kptsi.new_tensor([[width - 1, height - 1]])
+ kptsi = kptsi.long()
+ descriptors_ = descriptor_map[index, :, kptsi[:, 1], kptsi[:, 0]] # CxN
+
+ descriptors_ = torch.nn.functional.normalize(descriptors_, p=2, dim=0)
+ descriptors.append(descriptors_.t())
+
+ return descriptors
+
+
+class DKD(nn.Module):
+ def __init__(self, radius=2, top_k=0, scores_th=0.2, n_limit=20000):
+ """
+ Args:
+ radius: soft detection radius, kernel size is (2 * radius + 1)
+            top_k: if top_k > 0, return the top k keypoints
+            scores_th: used in threshold mode (top_k <= 0); if scores_th > 0, return keypoints
+                with scores > scores_th, otherwise return keypoints with scores > scores.mean()
+            n_limit: maximum number of keypoints in threshold mode
+ """
+ super().__init__()
+ self.radius = radius
+ self.top_k = top_k
+ self.scores_th = scores_th
+ self.n_limit = n_limit
+ self.kernel_size = 2 * self.radius + 1
+ self.temperature = 0.1 # tuned temperature
+ self.unfold = nn.Unfold(kernel_size=self.kernel_size, padding=self.radius)
+
+ # local xy grid
+ x = torch.linspace(-self.radius, self.radius, self.kernel_size)
+ # (kernel_size*kernel_size) x 2 : (w,h)
+ self.hw_grid = torch.stack(torch.meshgrid([x, x])).view(2, -1).t()[:, [1, 0]]
+
+ def detect_keypoints(self, scores_map, sub_pixel=True):
+ b, c, h, w = scores_map.shape
+ scores_nograd = scores_map.detach()
+ # nms_scores = simple_nms(scores_nograd, self.radius)
+ nms_scores = simple_nms(scores_nograd, 2)
+
+ # remove border
+ nms_scores[:, :, :self.radius + 1, :] = 0
+ nms_scores[:, :, :, :self.radius + 1] = 0
+ nms_scores[:, :, h - self.radius:, :] = 0
+ nms_scores[:, :, :, w - self.radius:] = 0
+
+ # detect keypoints without grad
+ if self.top_k > 0:
+ topk = torch.topk(nms_scores.view(b, -1), self.top_k)
+ indices_keypoints = topk.indices # B x top_k
+ else:
+ if self.scores_th > 0:
+ masks = nms_scores > self.scores_th
+ if masks.sum() == 0:
+ th = scores_nograd.reshape(b, -1).mean(dim=1) # th = self.scores_th
+ masks = nms_scores > th.reshape(b, 1, 1, 1)
+ else:
+ th = scores_nograd.reshape(b, -1).mean(dim=1) # th = self.scores_th
+ masks = nms_scores > th.reshape(b, 1, 1, 1)
+ masks = masks.reshape(b, -1)
+
+ indices_keypoints = [] # list, B x (any size)
+ scores_view = scores_nograd.reshape(b, -1)
+ for mask, scores in zip(masks, scores_view):
+ indices = mask.nonzero(as_tuple=False)[:, 0]
+ if len(indices) > self.n_limit:
+ kpts_sc = scores[indices]
+ sort_idx = kpts_sc.sort(descending=True)[1]
+ sel_idx = sort_idx[:self.n_limit]
+ indices = indices[sel_idx]
+ indices_keypoints.append(indices)
+
+ keypoints = []
+ scoredispersitys = []
+ kptscores = []
+ if sub_pixel:
+ # detect soft keypoints with grad backpropagation
+ patches = self.unfold(scores_map) # B x (kernel**2) x (H*W)
+ self.hw_grid = self.hw_grid.to(patches) # to device
+ for b_idx in range(b):
+ patch = patches[b_idx].t() # (H*W) x (kernel**2)
+ indices_kpt = indices_keypoints[b_idx] # one dimension vector, say its size is M
+ patch_scores = patch[indices_kpt] # M x (kernel**2)
+
+ # max is detached to prevent undesired backprop loops in the graph
+ max_v = patch_scores.max(dim=1).values.detach()[:, None]
+ x_exp = ((patch_scores - max_v) / self.temperature).exp() # M * (kernel**2), in [0, 1]
+
+ # \frac{ \sum{(i,j) \times \exp(x/T)} }{ \sum{\exp(x/T)} }
+ xy_residual = x_exp @ self.hw_grid / x_exp.sum(dim=1)[:, None] # Soft-argmax, Mx2
+
+ hw_grid_dist2 = torch.norm((self.hw_grid[None, :, :] - xy_residual[:, None, :]) / self.radius,
+ dim=-1) ** 2
+ scoredispersity = (x_exp * hw_grid_dist2).sum(dim=1) / x_exp.sum(dim=1)
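+                # scoredispersity: softmax-weighted mean squared (radius-normalised) distance
+                # of the local grid to the soft-argmax location; small values mean a sharp peak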
+
+ # compute result keypoints
+ keypoints_xy_nms = torch.stack([indices_kpt % w, indices_kpt // w], dim=1) # Mx2
+ keypoints_xy = keypoints_xy_nms + xy_residual
+ keypoints_xy = keypoints_xy / keypoints_xy.new_tensor(
+ [w - 1, h - 1]) * 2 - 1 # (w,h) -> (-1~1,-1~1)
+
+ kptscore = torch.nn.functional.grid_sample(scores_map[b_idx].unsqueeze(0),
+ keypoints_xy.view(1, 1, -1, 2),
+ mode='bilinear', align_corners=True)[0, 0, 0, :] # CxN
+
+ keypoints.append(keypoints_xy)
+ scoredispersitys.append(scoredispersity)
+ kptscores.append(kptscore)
+ else:
+ for b_idx in range(b):
+ indices_kpt = indices_keypoints[b_idx] # one dimension vector, say its size is M
+ keypoints_xy_nms = torch.stack([indices_kpt % w, indices_kpt // w], dim=1) # Mx2
+ keypoints_xy = keypoints_xy_nms / keypoints_xy_nms.new_tensor(
+ [w - 1, h - 1]) * 2 - 1 # (w,h) -> (-1~1,-1~1)
+ kptscore = torch.nn.functional.grid_sample(scores_map[b_idx].unsqueeze(0),
+ keypoints_xy.view(1, 1, -1, 2),
+ mode='bilinear', align_corners=True)[0, 0, 0, :] # CxN
+ keypoints.append(keypoints_xy)
+ scoredispersitys.append(None)
+ kptscores.append(kptscore)
+
+ return keypoints, scoredispersitys, kptscores
+
+ def forward(self, scores_map, descriptor_map, sub_pixel=False):
+ """
+ :param scores_map: Bx1xHxW
+ :param descriptor_map: BxCxHxW
+ :param sub_pixel: whether to use sub-pixel keypoint detection
+        :return: keypoints (list of Nx2 tensors, normalised positions in -1.0 ~ 1.0), descriptors, kptscores, scoredispersitys
+ """
+ keypoints, scoredispersitys, kptscores = self.detect_keypoints(scores_map,
+ sub_pixel)
+
+ descriptors = sample_descriptor(descriptor_map, keypoints, sub_pixel)
+
+        # keypoints: list of B tensors, each M x 2
+        # descriptors: list of B tensors, each M x D
+        # kptscores: list of B tensors, each M
+        # scoredispersitys: list of B tensors, each M (None when sub_pixel is False)
+ return keypoints, descriptors, kptscores, scoredispersitys
diff --git a/imcui/third_party/ASpanFormer/.github/workflows/sync.yml b/imcui/third_party/ASpanFormer/.github/workflows/sync.yml
new file mode 100644
index 0000000000000000000000000000000000000000..42e762d5299095226503f3a8cebfeef440ef68d7
--- /dev/null
+++ b/imcui/third_party/ASpanFormer/.github/workflows/sync.yml
@@ -0,0 +1,39 @@
+name: Upstream Sync
+
+permissions:
+ contents: write
+
+on:
+ schedule:
+ - cron: "0 0 * * *" # every day
+ workflow_dispatch:
+
+jobs:
+ sync_latest_from_upstream:
+ name: Sync latest commits from upstream repo
+ runs-on: ubuntu-latest
+ if: ${{ github.event.repository.fork }}
+
+ steps:
+ # Step 1: run a standard checkout action
+ - name: Checkout target repo
+ uses: actions/checkout@v3
+
+ # Step 2: run the sync action
+ - name: Sync upstream changes
+ id: sync
+ uses: aormsby/Fork-Sync-With-Upstream-action@v3.4
+ with:
+ upstream_sync_repo: apple/ml-aspanformer
+ upstream_sync_branch: main
+ target_sync_branch: main
+ target_repo_token: ${{ secrets.GITHUB_TOKEN }} # automatically generated, no need to set
+
+ # Set test_mode true to run tests instead of the true action!!
+ test_mode: false
+
+ - name: Sync check
+ if: failure()
+ run: |
+ echo "::error::Due to insufficient permissions, synchronization failed (as expected). Please go to the repository homepage and manually perform [Sync fork]."
+ exit 1
diff --git a/imcui/third_party/ASpanFormer/configs/aspan/indoor/aspan_test.py b/imcui/third_party/ASpanFormer/configs/aspan/indoor/aspan_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..fc2b44807696ec280672c8f40650fd04fa4d8a36
--- /dev/null
+++ b/imcui/third_party/ASpanFormer/configs/aspan/indoor/aspan_test.py
@@ -0,0 +1,10 @@
+import sys
+from pathlib import Path
+sys.path.append(str(Path(__file__).parent / '../../../'))
+from src.config.default import _CN as cfg
+
+cfg.ASPAN.MATCH_COARSE.MATCH_TYPE = 'dual_softmax'
+
+cfg.ASPAN.MATCH_COARSE.BORDER_RM = 0
+cfg.ASPAN.COARSE.COARSEST_LEVEL = [15, 20]
+cfg.ASPAN.COARSE.TRAIN_RES = [480, 640]
diff --git a/imcui/third_party/ASpanFormer/configs/aspan/indoor/aspan_train.py b/imcui/third_party/ASpanFormer/configs/aspan/indoor/aspan_train.py
new file mode 100644
index 0000000000000000000000000000000000000000..886d10d8f55533c8021bcca8395b5a2897fb8734
--- /dev/null
+++ b/imcui/third_party/ASpanFormer/configs/aspan/indoor/aspan_train.py
@@ -0,0 +1,11 @@
+import sys
+from pathlib import Path
+sys.path.append(str(Path(__file__).parent / '../../../'))
+from src.config.default import _CN as cfg
+
+cfg.ASPAN.COARSE.COARSEST_LEVEL = [15, 20]
+cfg.ASPAN.MATCH_COARSE.MATCH_TYPE = 'dual_softmax'
+
+cfg.ASPAN.MATCH_COARSE.SPARSE_SPVS = False
+cfg.ASPAN.MATCH_COARSE.BORDER_RM = 0
+cfg.TRAINER.MSLR_MILESTONES = [3, 6, 9, 12, 17, 20, 23, 26, 29]
diff --git a/imcui/third_party/ASpanFormer/configs/aspan/outdoor/aspan_test.py b/imcui/third_party/ASpanFormer/configs/aspan/outdoor/aspan_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..f0b9c04cbf3f466e413b345272afe7d7fe4274ea
--- /dev/null
+++ b/imcui/third_party/ASpanFormer/configs/aspan/outdoor/aspan_test.py
@@ -0,0 +1,21 @@
+import sys
+from pathlib import Path
+sys.path.append(str(Path(__file__).parent / '../../../'))
+from src.config.default import _CN as cfg
+
+cfg.ASPAN.COARSE.COARSEST_LEVEL = [36, 36]
+cfg.ASPAN.COARSE.TRAIN_RES = [832, 832]
+cfg.ASPAN.COARSE.TEST_RES = [1152, 1152]
+cfg.ASPAN.MATCH_COARSE.MATCH_TYPE = 'dual_softmax'
+
+cfg.TRAINER.CANONICAL_LR = 8e-3
+cfg.TRAINER.WARMUP_STEP = 1875 # 3 epochs
+cfg.TRAINER.WARMUP_RATIO = 0.1
+cfg.TRAINER.MSLR_MILESTONES = [8, 12, 16, 20, 24]
+
+# pose estimation
+cfg.TRAINER.RANSAC_PIXEL_THR = 0.5
+
+cfg.TRAINER.OPTIMIZER = "adamw"
+cfg.TRAINER.ADAMW_DECAY = 0.1
+cfg.ASPAN.MATCH_COARSE.TRAIN_COARSE_PERCENT = 0.3
diff --git a/imcui/third_party/ASpanFormer/configs/aspan/outdoor/aspan_train.py b/imcui/third_party/ASpanFormer/configs/aspan/outdoor/aspan_train.py
new file mode 100644
index 0000000000000000000000000000000000000000..1202080b234562d8cc65d924d7cccf0336b9f7c0
--- /dev/null
+++ b/imcui/third_party/ASpanFormer/configs/aspan/outdoor/aspan_train.py
@@ -0,0 +1,20 @@
+import sys
+from pathlib import Path
+sys.path.append(str(Path(__file__).parent / '../../../'))
+from src.config.default import _CN as cfg
+
+cfg.ASPAN.COARSE.COARSEST_LEVEL = [26, 26]
+cfg.ASPAN.MATCH_COARSE.MATCH_TYPE = 'dual_softmax'
+cfg.ASPAN.MATCH_COARSE.SPARSE_SPVS = False
+
+cfg.TRAINER.CANONICAL_LR = 8e-3
+cfg.TRAINER.WARMUP_STEP = 1875 # 3 epochs
+cfg.TRAINER.WARMUP_RATIO = 0.1
+cfg.TRAINER.MSLR_MILESTONES = [8, 12, 16, 20, 24]
+
+# pose estimation
+cfg.TRAINER.RANSAC_PIXEL_THR = 0.5
+
+cfg.TRAINER.OPTIMIZER = "adamw"
+cfg.TRAINER.ADAMW_DECAY = 0.1
+cfg.ASPAN.MATCH_COARSE.TRAIN_COARSE_PERCENT = 0.3
diff --git a/imcui/third_party/ASpanFormer/configs/data/__init__.py b/imcui/third_party/ASpanFormer/configs/data/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/imcui/third_party/ASpanFormer/configs/data/base.py b/imcui/third_party/ASpanFormer/configs/data/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..03aab160fa4137ccc04380f94854a56fbb549074
--- /dev/null
+++ b/imcui/third_party/ASpanFormer/configs/data/base.py
@@ -0,0 +1,35 @@
+"""
+The data config will be the last one merged into the main config.
+Settings in data configs will override all existing settings!
+"""
+
+from yacs.config import CfgNode as CN
+_CN = CN()
+_CN.DATASET = CN()
+_CN.TRAINER = CN()
+
+# training data config
+_CN.DATASET.TRAIN_DATA_ROOT = None
+_CN.DATASET.TRAIN_POSE_ROOT = None
+_CN.DATASET.TRAIN_NPZ_ROOT = None
+_CN.DATASET.TRAIN_LIST_PATH = None
+_CN.DATASET.TRAIN_INTRINSIC_PATH = None
+# validation set config
+_CN.DATASET.VAL_DATA_ROOT = None
+_CN.DATASET.VAL_POSE_ROOT = None
+_CN.DATASET.VAL_NPZ_ROOT = None
+_CN.DATASET.VAL_LIST_PATH = None
+_CN.DATASET.VAL_INTRINSIC_PATH = None
+
+# testing data config
+_CN.DATASET.TEST_DATA_ROOT = None
+_CN.DATASET.TEST_POSE_ROOT = None
+_CN.DATASET.TEST_NPZ_ROOT = None
+_CN.DATASET.TEST_LIST_PATH = None
+_CN.DATASET.TEST_INTRINSIC_PATH = None
+
+# dataset config
+_CN.DATASET.MIN_OVERLAP_SCORE_TRAIN = 0.4
+_CN.DATASET.MIN_OVERLAP_SCORE_TEST = 0.0 # for both test and val
+
+cfg = _CN
diff --git a/imcui/third_party/ASpanFormer/configs/data/megadepth_test_1500.py b/imcui/third_party/ASpanFormer/configs/data/megadepth_test_1500.py
new file mode 100644
index 0000000000000000000000000000000000000000..9616432f52a693ed84f3f12b9b85470b23410eee
--- /dev/null
+++ b/imcui/third_party/ASpanFormer/configs/data/megadepth_test_1500.py
@@ -0,0 +1,13 @@
+from configs.data.base import cfg
+
+TEST_BASE_PATH = "assets/megadepth_test_1500_scene_info"
+
+cfg.DATASET.TEST_DATA_SOURCE = "MegaDepth"
+cfg.DATASET.TEST_DATA_ROOT = "data/megadepth/test"
+cfg.DATASET.TEST_NPZ_ROOT = f"{TEST_BASE_PATH}"
+cfg.DATASET.TEST_LIST_PATH = f"{TEST_BASE_PATH}/megadepth_test_1500.txt"
+
+cfg.DATASET.MGDPT_IMG_RESIZE = 1152
+cfg.DATASET.MGDPT_IMG_PAD = True
+cfg.DATASET.MGDPT_DF = 8
+cfg.DATASET.MIN_OVERLAP_SCORE_TEST = 0.0
\ No newline at end of file
diff --git a/imcui/third_party/ASpanFormer/configs/data/megadepth_trainval_832.py b/imcui/third_party/ASpanFormer/configs/data/megadepth_trainval_832.py
new file mode 100644
index 0000000000000000000000000000000000000000..8f9b01fdaed254e10b3d55980499b88a00060f04
--- /dev/null
+++ b/imcui/third_party/ASpanFormer/configs/data/megadepth_trainval_832.py
@@ -0,0 +1,22 @@
+from configs.data.base import cfg
+
+
+TRAIN_BASE_PATH = "data/megadepth/index"
+cfg.DATASET.TRAINVAL_DATA_SOURCE = "MegaDepth"
+cfg.DATASET.TRAIN_DATA_ROOT = "data/megadepth/train"
+cfg.DATASET.TRAIN_NPZ_ROOT = f"{TRAIN_BASE_PATH}/scene_info_0.1_0.7"
+cfg.DATASET.TRAIN_LIST_PATH = f"{TRAIN_BASE_PATH}/trainvaltest_list/train_list.txt"
+cfg.DATASET.MIN_OVERLAP_SCORE_TRAIN = 0.0
+
+TEST_BASE_PATH = "data/megadepth/index"
+cfg.DATASET.TEST_DATA_SOURCE = "MegaDepth"
+cfg.DATASET.VAL_DATA_ROOT = cfg.DATASET.TEST_DATA_ROOT = "data/megadepth/test"
+cfg.DATASET.VAL_NPZ_ROOT = cfg.DATASET.TEST_NPZ_ROOT = f"{TEST_BASE_PATH}/scene_info_val_1500"
+cfg.DATASET.VAL_LIST_PATH = cfg.DATASET.TEST_LIST_PATH = f"{TEST_BASE_PATH}/trainvaltest_list/val_list.txt"
+cfg.DATASET.MIN_OVERLAP_SCORE_TEST = 0.0 # for both test and val
+
+# 368 scenes in total for MegaDepth
+# (with difficulty balanced (further split each scene to 3 sub-scenes))
+cfg.TRAINER.N_SAMPLES_PER_SUBSET = 100
+
+cfg.DATASET.MGDPT_IMG_RESIZE = 832 # for training on GPUs with 32GB memory
diff --git a/imcui/third_party/ASpanFormer/configs/data/scannet_test_1500.py b/imcui/third_party/ASpanFormer/configs/data/scannet_test_1500.py
new file mode 100644
index 0000000000000000000000000000000000000000..60e560fa01d73345200aaca10961449fdf3e9fbe
--- /dev/null
+++ b/imcui/third_party/ASpanFormer/configs/data/scannet_test_1500.py
@@ -0,0 +1,11 @@
+from configs.data.base import cfg
+
+TEST_BASE_PATH = "assets/scannet_test_1500"
+
+cfg.DATASET.TEST_DATA_SOURCE = "ScanNet"
+cfg.DATASET.TEST_DATA_ROOT = "data/scannet/test"
+cfg.DATASET.TEST_NPZ_ROOT = f"{TEST_BASE_PATH}"
+cfg.DATASET.TEST_LIST_PATH = f"{TEST_BASE_PATH}/scannet_test.txt"
+cfg.DATASET.TEST_INTRINSIC_PATH = f"{TEST_BASE_PATH}/intrinsics.npz"
+
+cfg.DATASET.MIN_OVERLAP_SCORE_TEST = 0.0
diff --git a/imcui/third_party/ASpanFormer/configs/data/scannet_trainval.py b/imcui/third_party/ASpanFormer/configs/data/scannet_trainval.py
new file mode 100644
index 0000000000000000000000000000000000000000..c38d6440e2b4ec349e5f168909c7f8c367408813
--- /dev/null
+++ b/imcui/third_party/ASpanFormer/configs/data/scannet_trainval.py
@@ -0,0 +1,17 @@
+from configs.data.base import cfg
+
+
+TRAIN_BASE_PATH = "data/scannet/index"
+cfg.DATASET.TRAINVAL_DATA_SOURCE = "ScanNet"
+cfg.DATASET.TRAIN_DATA_ROOT = "data/scannet/train"
+cfg.DATASET.TRAIN_NPZ_ROOT = f"{TRAIN_BASE_PATH}/scene_data/train"
+cfg.DATASET.TRAIN_LIST_PATH = f"{TRAIN_BASE_PATH}/scene_data/train_list/scannet_all.txt"
+cfg.DATASET.TRAIN_INTRINSIC_PATH = f"{TRAIN_BASE_PATH}/intrinsics.npz"
+
+TEST_BASE_PATH = "assets/scannet_test_1500"
+cfg.DATASET.TEST_DATA_SOURCE = "ScanNet"
+cfg.DATASET.VAL_DATA_ROOT = cfg.DATASET.TEST_DATA_ROOT = "data/scannet/test"
+cfg.DATASET.VAL_NPZ_ROOT = cfg.DATASET.TEST_NPZ_ROOT = TEST_BASE_PATH
+cfg.DATASET.VAL_LIST_PATH = cfg.DATASET.TEST_LIST_PATH = f"{TEST_BASE_PATH}/scannet_test.txt"
+cfg.DATASET.VAL_INTRINSIC_PATH = cfg.DATASET.TEST_INTRINSIC_PATH = f"{TEST_BASE_PATH}/intrinsics.npz"
+cfg.DATASET.MIN_OVERLAP_SCORE_TEST = 0.0 # for both test and val
diff --git a/imcui/third_party/ASpanFormer/demo/demo.py b/imcui/third_party/ASpanFormer/demo/demo.py
new file mode 100644
index 0000000000000000000000000000000000000000..f4d8a91f131b30e131cbdd6bf8ee44d53a0b256d
--- /dev/null
+++ b/imcui/third_party/ASpanFormer/demo/demo.py
@@ -0,0 +1,64 @@
+import os
+import sys
+ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
+sys.path.insert(0, ROOT_DIR)
+
+from src.ASpanFormer.aspanformer import ASpanFormer
+from src.config.default import get_cfg_defaults
+from src.utils.misc import lower_config
+import demo_utils
+import cv2
+import torch
+import numpy as np
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+import argparse
+parser = argparse.ArgumentParser()
+parser.add_argument('--config_path', type=str, default='../configs/aspan/outdoor/aspan_test.py',
+ help='path for config file.')
+parser.add_argument('--img0_path', type=str, default='../assets/phototourism_sample_images/piazza_san_marco_06795901_3725050516.jpg',
+ help='path for image0.')
+parser.add_argument('--img1_path', type=str, default='../assets/phototourism_sample_images/piazza_san_marco_15148634_5228701572.jpg',
+ help='path for image1.')
+parser.add_argument('--weights_path', type=str, default='../weights/outdoor.ckpt',
+ help='path for model weights.')
+parser.add_argument('--long_dim0', type=int, default=1024,
+ help='resize for longest dim of image0.')
+parser.add_argument('--long_dim1', type=int, default=1024,
+ help='resize for longest dim of image1.')
+
+args = parser.parse_args()
+
+
+if __name__=='__main__':
+ config = get_cfg_defaults()
+ config.merge_from_file(args.config_path)
+ _config = lower_config(config)
+ matcher = ASpanFormer(config=_config['aspan'])
+ state_dict = torch.load(args.weights_path, map_location='cpu')['state_dict']
+ matcher.load_state_dict(state_dict,strict=False)
+ matcher.to(device),matcher.eval()
+
+ img0,img1=cv2.imread(args.img0_path),cv2.imread(args.img1_path)
+ img0_g,img1_g=cv2.imread(args.img0_path,0),cv2.imread(args.img1_path,0)
+ img0,img1=demo_utils.resize(img0,args.long_dim0),demo_utils.resize(img1,args.long_dim1)
+ img0_g,img1_g=demo_utils.resize(img0_g,args.long_dim0),demo_utils.resize(img1_g,args.long_dim1)
+ data={'image0':torch.from_numpy(img0_g/255.)[None,None].to(device).float(),
+ 'image1':torch.from_numpy(img1_g/255.)[None,None].to(device).float()}
+ with torch.no_grad():
+ matcher(data,online_resize=True)
+ corr0,corr1=data['mkpts0_f'].cpu().numpy(),data['mkpts1_f'].cpu().numpy()
+
+ F_hat,mask_F=cv2.findFundamentalMat(corr0,corr1,method=cv2.FM_RANSAC,ransacReprojThreshold=1)
+ if mask_F is not None:
+ mask_F=mask_F[:,0].astype(bool)
+ else:
+ mask_F=np.zeros_like(corr0[:,0]).astype(bool)
+
+ #visualize match
+ display=demo_utils.draw_match(img0,img1,corr0,corr1)
+ display_ransac=demo_utils.draw_match(img0,img1,corr0[mask_F],corr1[mask_F])
+ cv2.imwrite('match.png',display)
+ cv2.imwrite('match_ransac.png',display_ransac)
+ print(len(corr1),len(corr1[mask_F]))
\ No newline at end of file
diff --git a/imcui/third_party/ASpanFormer/demo/demo_utils.py b/imcui/third_party/ASpanFormer/demo/demo_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..a104e25d3f5ee8b7efb6cc5fa0dc27378e22c83f
--- /dev/null
+++ b/imcui/third_party/ASpanFormer/demo/demo_utils.py
@@ -0,0 +1,44 @@
+import cv2
+import numpy as np
+
+def resize(image,long_dim):
+ h,w=image.shape[0],image.shape[1]
+ image=cv2.resize(image,(int(w*long_dim/max(h,w)),int(h*long_dim/max(h,w))))
+ return image
+
+def draw_points(img,points,color=(0,255,0),radius=3):
+ dp = [(int(points[i, 0]), int(points[i, 1])) for i in range(points.shape[0])]
+ for i in range(points.shape[0]):
+ cv2.circle(img, dp[i],radius=radius,color=color)
+ return img
+
+
+def draw_match(img1, img2, corr1, corr2,inlier=[True],color=None,radius1=1,radius2=1,resize=None):
+ if resize is not None:
+ scale1,scale2=[img1.shape[1]/resize[0],img1.shape[0]/resize[1]],[img2.shape[1]/resize[0],img2.shape[0]/resize[1]]
+ img1,img2=cv2.resize(img1, resize, interpolation=cv2.INTER_AREA),cv2.resize(img2, resize, interpolation=cv2.INTER_AREA)
+ corr1,corr2=corr1/np.asarray(scale1)[np.newaxis],corr2/np.asarray(scale2)[np.newaxis]
+ corr1_key = [cv2.KeyPoint(corr1[i, 0], corr1[i, 1], radius1) for i in range(corr1.shape[0])]
+ corr2_key = [cv2.KeyPoint(corr2[i, 0], corr2[i, 1], radius2) for i in range(corr2.shape[0])]
+
+ assert len(corr1) == len(corr2)
+
+ draw_matches = [cv2.DMatch(i, i, 0) for i in range(len(corr1))]
+ if color is None:
+ color = [(0, 255, 0) if cur_inlier else (0,0,255) for cur_inlier in inlier]
+ if len(color)==1:
+ display = cv2.drawMatches(img1, corr1_key, img2, corr2_key, draw_matches, None,
+ matchColor=color[0],
+ singlePointColor=color[0],
+ flags=4
+ )
+ else:
+ height,width=max(img1.shape[0],img2.shape[0]),img1.shape[1]+img2.shape[1]
+ display=np.zeros([height,width,3],np.uint8)
+ display[:img1.shape[0],:img1.shape[1]]=img1
+ display[:img2.shape[0],img1.shape[1]:]=img2
+ for i in range(len(corr1)):
+ left_x,left_y,right_x,right_y=int(corr1[i][0]),int(corr1[i][1]),int(corr2[i][0]+img1.shape[1]),int(corr2[i][1])
+ cur_color=(int(color[i][0]),int(color[i][1]),int(color[i][2]))
+ cv2.line(display, (left_x,left_y), (right_x,right_y),cur_color,1,lineType=cv2.LINE_AA)
+ return display
\ No newline at end of file
diff --git a/imcui/third_party/ASpanFormer/environment.yaml b/imcui/third_party/ASpanFormer/environment.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..5c52328762e971c94b447198869ec0036771bf76
--- /dev/null
+++ b/imcui/third_party/ASpanFormer/environment.yaml
@@ -0,0 +1,12 @@
+name: ASpanFormer
+channels:
+ - pytorch
+ - conda-forge
+ - defaults
+dependencies:
+ - python=3.8
+ - cudatoolkit=10.2
+ - pytorch=1.8.1
+ - pip
+ - pip:
+ - -r requirements.txt
diff --git a/imcui/third_party/ASpanFormer/src/ASpanFormer/__init__.py b/imcui/third_party/ASpanFormer/src/ASpanFormer/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..3bfd5a901e83c7e8d3b439f21afa20ac8237635e
--- /dev/null
+++ b/imcui/third_party/ASpanFormer/src/ASpanFormer/__init__.py
@@ -0,0 +1,2 @@
+from .aspanformer import LocalFeatureTransformer_Flow
+from .utils.cvpr_ds_config import default_cfg
diff --git a/imcui/third_party/ASpanFormer/src/ASpanFormer/aspan_module/__init__.py b/imcui/third_party/ASpanFormer/src/ASpanFormer/aspan_module/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..dff6704976cbe9e916c6de6af9e3b755dfbd20bf
--- /dev/null
+++ b/imcui/third_party/ASpanFormer/src/ASpanFormer/aspan_module/__init__.py
@@ -0,0 +1,3 @@
+from .transformer import LocalFeatureTransformer_Flow
+from .loftr import LocalFeatureTransformer
+from .fine_preprocess import FinePreprocess
diff --git a/imcui/third_party/ASpanFormer/src/ASpanFormer/aspan_module/attention.py b/imcui/third_party/ASpanFormer/src/ASpanFormer/aspan_module/attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..6a1fb6794461d043b0df4a20664e974a38240727
--- /dev/null
+++ b/imcui/third_party/ASpanFormer/src/ASpanFormer/aspan_module/attention.py
@@ -0,0 +1,199 @@
+import torch
+from torch.nn import Module
+import torch.nn as nn
+from itertools import product
+from torch.nn import functional as F
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+class layernorm2d(nn.Module):
+
+ def __init__(self,dim) :
+ super().__init__()
+ self.dim=dim
+ self.affine=nn.parameter.Parameter(torch.ones(dim), requires_grad=True)
+ self.bias=nn.parameter.Parameter(torch.zeros(dim), requires_grad=True)
+
+ def forward(self,x):
+ #x: B*C*H*W
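+        # channel-wise normalisation per spatial location:
+        #   y = affine * (x - mean_c) / (std_c + 1e-6) + bias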
+ mean,std=x.mean(dim=1,keepdim=True),x.std(dim=1,keepdim=True)
+ return self.affine[None,:,None,None]*(x-mean)/(std+1e-6)+self.bias[None,:,None,None]
+
+
+class HierachicalAttention(Module):
+ def __init__(self,d_model,nhead,nsample,radius_scale,nlevel=3):
+ super().__init__()
+ self.d_model=d_model
+ self.nhead=nhead
+ self.nsample=nsample
+ self.nlevel=nlevel
+ self.radius_scale=radius_scale
+ self.merge_head = nn.Sequential(
+ nn.Conv1d(d_model*3, d_model, kernel_size=1,bias=False),
+ nn.ReLU(True),
+ nn.Conv1d(d_model, d_model, kernel_size=1,bias=False),
+ )
+ self.fullattention=FullAttention(d_model,nhead)
+ self.temp=nn.parameter.Parameter(torch.tensor(1.),requires_grad=True)
+ sample_offset=torch.tensor([[pos[0]-nsample[1]/2+0.5, pos[1]-nsample[1]/2+0.5] for pos in product(range(nsample[1]), range(nsample[1]))]) #r^2*2
+ self.sample_offset=nn.parameter.Parameter(sample_offset,requires_grad=False)
+
+ def forward(self,query,key,value,flow,size_q,size_kv,mask0=None, mask1=None,ds0=[4,4],ds1=[4,4]):
+ """
+ Args:
+ q,k,v (torch.Tensor): [B, C, L]
+ mask (torch.Tensor): [B, L]
+ flow (torch.Tensor): [B, H, W, 4]
+ Return:
+ all_message (torch.Tensor): [B, C, H, W]
+ """
+
+ variance=flow[:,:,:,2:]
+ offset=flow[:,:,:,:2] #B*H*W*2
+ bs=query.shape[0]
+ h0,w0=size_q[0],size_q[1]
+ h1,w1=size_kv[0],size_kv[1]
+ variance=torch.exp(0.5*variance)*self.radius_scale #b*h*w*2(pixel scale)
+ span_scale=torch.clamp((variance*2/self.nsample[1]),min=1) #b*h*w*2
+
+ sub_sample0,sub_sample1=[ds0,2,1],[ds1,2,1]
+ q_list=[F.avg_pool2d(query.view(bs,-1,h0,w0),kernel_size=sub_size,stride=sub_size) for sub_size in sub_sample0]
+ k_list=[F.avg_pool2d(key.view(bs,-1,h1,w1),kernel_size=sub_size,stride=sub_size) for sub_size in sub_sample1]
+ v_list=[F.avg_pool2d(value.view(bs,-1,h1,w1),kernel_size=sub_size,stride=sub_size) for sub_size in sub_sample1] #n_level
+
+ offset_list=[F.avg_pool2d(offset.permute(0,3,1,2),kernel_size=sub_size*self.nsample[0],stride=sub_size*self.nsample[0]).permute(0,2,3,1)/sub_size for sub_size in sub_sample0[1:]] #n_level-1
+ span_list=[F.avg_pool2d(span_scale.permute(0,3,1,2),kernel_size=sub_size*self.nsample[0],stride=sub_size*self.nsample[0]).permute(0,2,3,1) for sub_size in sub_sample0[1:]] #n_level-1
+
+ if mask0 is not None:
+ mask0,mask1=mask0.view(bs,1,h0,w0),mask1.view(bs,1,h1,w1)
+ mask0_list=[-F.max_pool2d(-mask0,kernel_size=sub_size,stride=sub_size) for sub_size in sub_sample0]
+ mask1_list=[-F.max_pool2d(-mask1,kernel_size=sub_size,stride=sub_size) for sub_size in sub_sample1]
+ else:
+ mask0_list=mask1_list=[None,None,None]
+
+ message_list=[]
+ #full attention at coarse scale
+ mask0_flatten=mask0_list[0].view(bs,-1) if mask0 is not None else None
+ mask1_flatten=mask1_list[0].view(bs,-1) if mask1 is not None else None
+ message_list.append(self.fullattention(q_list[0],k_list[0],v_list[0],mask0_flatten,mask1_flatten,self.temp).view(bs,self.d_model,h0//ds0[0],w0//ds0[1]))
+
+ for index in range(1,self.nlevel):
+ q,k,v=q_list[index],k_list[index],v_list[index]
+ mask0,mask1=mask0_list[index],mask1_list[index]
+ s,o=span_list[index-1],offset_list[index-1] #B*h*w(*2)
+ q,k,v,sample_pixel,mask_sample=self.partition_token(q,k,v,o,s,mask0) #B*Head*D*G*N(G*N=H*W for q)
+ message_list.append(self.group_attention(q,k,v,1,mask_sample).view(bs,self.d_model,h0//sub_sample0[index],w0//sub_sample0[index]))
+ #fuse
+ all_message=torch.cat([F.upsample(message_list[idx],scale_factor=sub_sample0[idx],mode='nearest') \
+ for idx in range(self.nlevel)],dim=1).view(bs,-1,h0*w0) #b*3d*H*W
+
+ all_message=self.merge_head(all_message).view(bs,-1,h0,w0) #b*d*H*W
+ return all_message
+
+ def partition_token(self,q,k,v,offset,span_scale,maskv):
+ #q,k,v: B*C*H*W
+ #o: B*H/2*W/2*2
+ #span_scale:B*H*W
+ bs=q.shape[0]
+ h,w=q.shape[2],q.shape[3]
+ hk,wk=k.shape[2],k.shape[3]
+ offset=offset.view(bs,-1,2)
+ span_scale=span_scale.view(bs,-1,1,2)
+ #B*G*2
+ offset_sample=self.sample_offset[None,None]*span_scale
+ sample_pixel=offset[:,:,None]+offset_sample#B*G*r^2*2
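+ # normalize the sampled pixel coordinates to the [-1, 1] range expected by F.grid_sample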
+ sample_norm=sample_pixel/torch.tensor([wk/2,hk/2]).to(device)[None,None,None]-1
+
+ q = q.view(bs, -1 , h // self.nsample[0], self.nsample[0], w // self.nsample[0], self.nsample[0]).\
+ permute(0, 1, 2, 4, 3, 5).contiguous().view(bs, self.nhead,self.d_model//self.nhead, -1,self.nsample[0]**2)#B*head*D*G*N(G*N=H*W for q)
+ #sample token
+ k=F.grid_sample(k, grid=sample_norm).view(bs, self.nhead,self.d_model//self.nhead,-1, self.nsample[1]**2) #B*head*D*G*r^2
+ v=F.grid_sample(v, grid=sample_norm).view(bs, self.nhead,self.d_model//self.nhead,-1, self.nsample[1]**2) #B*head*D*G*r^2
+ if maskv is not None:
+ mask_sample=F.grid_sample(maskv.view(bs,-1,h,w).float(),grid=sample_norm,mode='nearest')==1 #B*1*G*r^2
+ else:
+ mask_sample=None
+ return q,k,v,sample_pixel,mask_sample
+
+
+ def group_attention(self,query,key,value,temp,mask_sample=None):
+ #q,k,v: B*Head*D*G*N(G*N=H*W for q)
+ bs=query.shape[0]
+ QK = torch.einsum("bhdgn,bhdgm->bhgnm", query, key)
+ if mask_sample is not None:
+ num_head,number_n=QK.shape[1],QK.shape[3]
+ QK.masked_fill_(~(mask_sample[:,:,:,None]).expand(-1,num_head,-1,number_n,-1).bool(), float(-1e8))
+ # Compute the attention and the weighted average
+ softmax_temp = temp / query.size(2)**.5 # sqrt(D)
+ A = torch.softmax(softmax_temp * QK, dim=-1)
+ queried_values = torch.einsum("bhgnm,bhdgm->bhdgn", A, value).contiguous().view(bs,self.d_model,-1)
+ return queried_values
+
+
+
+class FullAttention(Module):
+ def __init__(self,d_model,nhead):
+ super().__init__()
+ self.d_model=d_model
+ self.nhead=nhead
+
+ def forward(self, q, k, v, mask0=None, mask1=None, temp=1):
+ """ Multi-head scaled dot-product attention, a.k.a full attention.
+ Args:
+ q,k,v: [N, D, L]
+ mask: [N, L]
+ Returns:
+ msg: [N, D, L]
+ """
+ bs=q.shape[0]
+ q,k,v=q.view(bs,self.nhead,self.d_model//self.nhead,-1),k.view(bs,self.nhead,self.d_model//self.nhead,-1),v.view(bs,self.nhead,self.d_model//self.nhead,-1)
+ # Compute the unnormalized attention and apply the masks
+ QK = torch.einsum("nhdl,nhds->nhls", q, k)
+ if mask0 is not None:
+ QK.masked_fill_(~(mask0[:,None, :, None] * mask1[:, None, None]).bool(), float(-1e8))
+ # Compute the attention and the weighted average
+ softmax_temp = temp / q.size(2)**.5 # sqrt(D)
+ A = torch.softmax(softmax_temp * QK, dim=-1)
+ queried_values = torch.einsum("nhls,nhds->nhdl", A, v).contiguous().view(bs,self.d_model,-1)
+ return queried_values
+
+
+
+def elu_feature_map(x):
+ return F.elu(x) + 1
+
+class LinearAttention(Module):
+ def __init__(self, eps=1e-6):
+ super().__init__()
+ self.feature_map = elu_feature_map
+ self.eps = eps
+
+ def forward(self, queries, keys, values, q_mask=None, kv_mask=None):
+ """ Multi-Head linear attention proposed in "Transformers are RNNs"
+ Args:
+ queries: [N, L, H, D]
+ keys: [N, S, H, D]
+ values: [N, S, H, D]
+ q_mask: [N, L]
+ kv_mask: [N, S]
+ Returns:
+ queried_values: (N, L, H, D)
+ """
+ Q = self.feature_map(queries)
+ K = self.feature_map(keys)
+
+ # set padded position to zero
+ if q_mask is not None:
+ Q = Q * q_mask[:, :, None, None]
+ if kv_mask is not None:
+ K = K * kv_mask[:, :, None, None]
+ values = values * kv_mask[:, :, None, None]
+
+ v_length = values.size(1)
+ values = values / v_length # prevent fp16 overflow
+ KV = torch.einsum("nshd,nshv->nhdv", K, values) # (S,D)' @ S,V
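+ # Z is the per-query normalizer of linear attention: 1 / (phi(Q) . sum_s phi(K_s))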
+ Z = 1 / (torch.einsum("nlhd,nhd->nlh", Q, K.sum(dim=1)) + self.eps)
+ queried_values = torch.einsum("nlhd,nhdv,nlh->nlhv", Q, KV, Z) * v_length
+
+ return queried_values.contiguous()
\ No newline at end of file
diff --git a/imcui/third_party/ASpanFormer/src/ASpanFormer/aspan_module/fine_preprocess.py b/imcui/third_party/ASpanFormer/src/ASpanFormer/aspan_module/fine_preprocess.py
new file mode 100644
index 0000000000000000000000000000000000000000..5bb8eefd362240a9901a335f0e6e07770ff04567
--- /dev/null
+++ b/imcui/third_party/ASpanFormer/src/ASpanFormer/aspan_module/fine_preprocess.py
@@ -0,0 +1,59 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops.einops import rearrange, repeat
+
+
+class FinePreprocess(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+
+ self.config = config
+ self.cat_c_feat = config['fine_concat_coarse_feat']
+ self.W = self.config['fine_window_size']
+
+ d_model_c = self.config['coarse']['d_model']
+ d_model_f = self.config['fine']['d_model']
+ self.d_model_f = d_model_f
+ if self.cat_c_feat:
+ self.down_proj = nn.Linear(d_model_c, d_model_f, bias=True)
+ self.merge_feat = nn.Linear(2*d_model_f, d_model_f, bias=True)
+
+ self._reset_parameters()
+
+ def _reset_parameters(self):
+ for p in self.parameters():
+ if p.dim() > 1:
+ nn.init.kaiming_normal_(p, mode="fan_out", nonlinearity="relu")
+
+ def forward(self, feat_f0, feat_f1, feat_c0, feat_c1, data):
+ W = self.W
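+ # stride equals the fine-to-coarse resolution ratio, so unfolding yields one WxW fine window per coarse cell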
+ stride = data['hw0_f'][0] // data['hw0_c'][0]
+
+ data.update({'W': W})
+ if data['b_ids'].shape[0] == 0:
+ feat0 = torch.empty(0, self.W**2, self.d_model_f, device=feat_f0.device)
+ feat1 = torch.empty(0, self.W**2, self.d_model_f, device=feat_f0.device)
+ return feat0, feat1
+
+ # 1. unfold(crop) all local windows
+ feat_f0_unfold = F.unfold(feat_f0, kernel_size=(W, W), stride=stride, padding=W//2)
+ feat_f0_unfold = rearrange(feat_f0_unfold, 'n (c ww) l -> n l ww c', ww=W**2)
+ feat_f1_unfold = F.unfold(feat_f1, kernel_size=(W, W), stride=stride, padding=W//2)
+ feat_f1_unfold = rearrange(feat_f1_unfold, 'n (c ww) l -> n l ww c', ww=W**2)
+
+ # 2. select only the predicted matches
+ feat_f0_unfold = feat_f0_unfold[data['b_ids'], data['i_ids']] # [n, ww, cf]
+ feat_f1_unfold = feat_f1_unfold[data['b_ids'], data['j_ids']]
+
+ # option: use coarse-level loftr feature as context: concat and linear
+ if self.cat_c_feat:
+ feat_c_win = self.down_proj(torch.cat([feat_c0[data['b_ids'], data['i_ids']],
+ feat_c1[data['b_ids'], data['j_ids']]], 0)) # [2n, c]
+ feat_cf_win = self.merge_feat(torch.cat([
+ torch.cat([feat_f0_unfold, feat_f1_unfold], 0), # [2n, ww, cf]
+ repeat(feat_c_win, 'n c -> n ww c', ww=W**2), # [2n, ww, cf]
+ ], -1))
+ feat_f0_unfold, feat_f1_unfold = torch.chunk(feat_cf_win, 2, dim=0)
+
+ return feat_f0_unfold, feat_f1_unfold
diff --git a/imcui/third_party/ASpanFormer/src/ASpanFormer/aspan_module/loftr.py b/imcui/third_party/ASpanFormer/src/ASpanFormer/aspan_module/loftr.py
new file mode 100644
index 0000000000000000000000000000000000000000..7dcebaa7beee978b9b8abcec8bb1bd2cc6b60870
--- /dev/null
+++ b/imcui/third_party/ASpanFormer/src/ASpanFormer/aspan_module/loftr.py
@@ -0,0 +1,112 @@
+import copy
+import torch
+import torch.nn as nn
+from .attention import LinearAttention
+
+class LoFTREncoderLayer(nn.Module):
+ def __init__(self,
+ d_model,
+ nhead,
+ attention='linear'):
+ super(LoFTREncoderLayer, self).__init__()
+
+ self.dim = d_model // nhead
+ self.nhead = nhead
+
+ # multi-head attention
+ self.q_proj = nn.Linear(d_model, d_model, bias=False)
+ self.k_proj = nn.Linear(d_model, d_model, bias=False)
+ self.v_proj = nn.Linear(d_model, d_model, bias=False)
+ self.attention = LinearAttention()
+ self.merge = nn.Linear(d_model, d_model, bias=False)
+
+ # feed-forward network
+ self.mlp = nn.Sequential(
+ nn.Linear(d_model*2, d_model*2, bias=False),
+ nn.ReLU(True),
+ nn.Linear(d_model*2, d_model, bias=False),
+ )
+
+ # norm and dropout
+ self.norm1 = nn.LayerNorm(d_model)
+ self.norm2 = nn.LayerNorm(d_model)
+
+ def forward(self, x, source, x_mask=None, source_mask=None, type=None, index=0):
+ """
+ Args:
+ x (torch.Tensor): [N, L, C]
+ source (torch.Tensor): [N, S, C]
+ x_mask (torch.Tensor): [N, L] (optional)
+ source_mask (torch.Tensor): [N, S] (optional)
+ """
+ bs = x.size(0)
+ query, key, value = x, source, source
+
+ # multi-head attention
+ query = self.q_proj(query).view(
+ bs, -1, self.nhead, self.dim) # [N, L, (H, D)]
+ key = self.k_proj(key).view(bs, -1, self.nhead,
+ self.dim) # [N, S, (H, D)]
+ value = self.v_proj(value).view(bs, -1, self.nhead, self.dim)
+
+ message = self.attention(
+ query, key, value, q_mask=x_mask, kv_mask=source_mask) # [N, L, (H, D)]
+ message = self.merge(message.view(
+ bs, -1, self.nhead*self.dim)) # [N, L, C]
+ message = self.norm1(message)
+
+ # feed-forward network
+ message = self.mlp(torch.cat([x, message], dim=2))
+ message = self.norm2(message)
+
+ return x + message
+
+
+class LocalFeatureTransformer(nn.Module):
+ """A Local Feature Transformer (LoFTR) module."""
+
+ def __init__(self, config):
+ super(LocalFeatureTransformer, self).__init__()
+
+ self.config = config
+ self.d_model = config['d_model']
+ self.nhead = config['nhead']
+ self.layer_names = config['layer_names']
+ encoder_layer = LoFTREncoderLayer(
+ config['d_model'], config['nhead'], config['attention'])
+ self.layers = nn.ModuleList(
+ [copy.deepcopy(encoder_layer) for _ in range(len(self.layer_names))])
+ self._reset_parameters()
+
+ def _reset_parameters(self):
+ for p in self.parameters():
+ if p.dim() > 1:
+ nn.init.xavier_uniform_(p)
+
+ def forward(self, feat0, feat1, mask0=None, mask1=None):
+ """
+ Args:
+ feat0 (torch.Tensor): [N, L, C]
+ feat1 (torch.Tensor): [N, S, C]
+ mask0 (torch.Tensor): [N, L] (optional)
+ mask1 (torch.Tensor): [N, S] (optional)
+ """
+
+ assert self.d_model == feat0.size(
+ 2), "the feature number of src and transformer must be equal"
+
+ index = 0
+ for layer, name in zip(self.layers, self.layer_names):
+ if name == 'self':
+ feat0 = layer(feat0, feat0, mask0, mask0,
+ type='self', index=index)
+ feat1 = layer(feat1, feat1, mask1, mask1)
+ elif name == 'cross':
+ feat0 = layer(feat0, feat1, mask0, mask1)
+ feat1 = layer(feat1, feat0, mask1, mask0,
+ type='cross', index=index)
+ index += 1
+ else:
+ raise KeyError
+ return feat0, feat1
+
diff --git a/imcui/third_party/ASpanFormer/src/ASpanFormer/aspan_module/transformer.py b/imcui/third_party/ASpanFormer/src/ASpanFormer/aspan_module/transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..1bed7b4f65c6b5936e9e265dfefc0d058dbfa33f
--- /dev/null
+++ b/imcui/third_party/ASpanFormer/src/ASpanFormer/aspan_module/transformer.py
@@ -0,0 +1,245 @@
+import copy
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from .attention import FullAttention, HierachicalAttention ,layernorm2d
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+
+class messageLayer_ini(nn.Module):
+
+ def __init__(self, d_model, d_flow,d_value, nhead):
+ super().__init__()
+
+ self.d_model = d_model
+ self.d_flow = d_flow
+ self.d_value=d_value
+ self.nhead = nhead
+ self.attention = FullAttention(d_model,nhead)
+
+ self.q_proj = nn.Conv1d(d_model, d_model, kernel_size=1,bias=False)
+ self.k_proj = nn.Conv1d(d_model, d_model, kernel_size=1,bias=False)
+ self.v_proj = nn.Conv1d(d_value, d_model, kernel_size=1,bias=False)
+ self.merge_head=nn.Conv1d(d_model,d_model,kernel_size=1,bias=False)
+
+ self.merge_f = nn.Sequential(
+ nn.Conv2d(d_model*2, d_model*2, kernel_size=1, bias=False),
+ nn.ReLU(True),
+ nn.Conv2d(d_model*2, d_model, kernel_size=1, bias=False),
+ )
+
+ self.norm1 = layernorm2d(d_model)
+ self.norm2 = layernorm2d(d_model)
+
+
+ def forward(self, x0, x1,pos0,pos1,mask0=None,mask1=None):
+ #x0,x1: B*D*H*W
+ x0,x1=self.update(x0,x1,pos1,mask0,mask1),\
+ self.update(x1,x0,pos0,mask1,mask0)
+ return x0,x1
+
+
+ def update(self,f0,f1,pos1,mask0,mask1):
+ """
+ Args:
+ f0: [N, D, H, W]
+ f1: [N, D, H, W]
+ Returns:
+ f0_new: (N, d, h, w)
+ """
+ bs,h,w=f0.shape[0],f0.shape[2],f0.shape[3]
+
+ f0_flatten,f1_flatten=f0.view(bs,self.d_model,-1),f1.view(bs,self.d_model,-1)
+ pos1_flatten=pos1.view(bs,self.d_value-self.d_model,-1)
+ f1_flatten_v=torch.cat([f1_flatten,pos1_flatten],dim=1)
+
+ queries,keys=self.q_proj(f0_flatten),self.k_proj(f1_flatten)
+ values=self.v_proj(f1_flatten_v).view(bs,self.nhead,self.d_model//self.nhead,-1)
+
+ queried_values=self.attention(queries,keys,values,mask0,mask1)
+ msg=self.merge_head(queried_values).view(bs,-1,h,w)
+ msg=self.norm2(self.merge_f(torch.cat([f0,self.norm1(msg)],dim=1)))
+ return f0+msg
+
+
+
+class messageLayer_gla(nn.Module):
+
+ def __init__(self,d_model,d_flow,d_value,
+ nhead,radius_scale,nsample,update_flow=True):
+ super().__init__()
+ self.d_model = d_model
+ self.d_flow=d_flow
+ self.d_value=d_value
+ self.nhead = nhead
+ self.radius_scale=radius_scale
+ self.update_flow=update_flow
+ self.flow_decoder=nn.Sequential(
+ nn.Conv1d(d_flow, d_flow//2, kernel_size=1, bias=False),
+ nn.ReLU(True),
+ nn.Conv1d(d_flow//2, 4, kernel_size=1, bias=False))
+ self.attention=HierachicalAttention(d_model,nhead,nsample,radius_scale)
+
+ self.q_proj = nn.Conv1d(d_model, d_model, kernel_size=1,bias=False)
+ self.k_proj = nn.Conv1d(d_model, d_model, kernel_size=1,bias=False)
+ self.v_proj = nn.Conv1d(d_value, d_model, kernel_size=1,bias=False)
+
+ d_extra=d_flow if update_flow else 0
+ self.merge_f=nn.Sequential(
+ nn.Conv2d(d_model*2+d_extra, d_model+d_flow, kernel_size=1, bias=False),
+ nn.ReLU(True),
+ nn.Conv2d(d_model+d_flow, d_model+d_extra, kernel_size=3,padding=1, bias=False),
+ )
+ self.norm1 = layernorm2d(d_model)
+ self.norm2 = layernorm2d(d_model+d_extra)
+
+ def forward(self, x0, x1, flow_feature0,flow_feature1,pos0,pos1,mask0=None,mask1=None,ds0=[4,4],ds1=[4,4]):
+ """
+ Args:
+ x0 (torch.Tensor): [B, C, H, W]
+ x1 (torch.Tensor): [B, C, H, W]
+ flow_feature0 (torch.Tensor): [B, C', H, W]
+ flow_feature1 (torch.Tensor): [B, C', H, W]
+ """
+ flow0,flow1=self.decode_flow(flow_feature0,flow_feature1.shape[2:]),self.decode_flow(flow_feature1,flow_feature0.shape[2:])
+ x0_new,flow_feature0_new=self.update(x0,x1,flow0.detach(),flow_feature0,pos1,mask0,mask1,ds0,ds1)
+ x1_new,flow_feature1_new=self.update(x1,x0,flow1.detach(),flow_feature1,pos0,mask1,mask0,ds1,ds0)
+ return x0_new,x1_new,flow_feature0_new,flow_feature1_new,flow0,flow1
+
+ def update(self,x0,x1,flow0,flow_feature0,pos1,mask0,mask1,ds0,ds1):
+ bs=x0.shape[0]
+ queries,keys=self.q_proj(x0.view(bs,self.d_model,-1)),self.k_proj(x1.view(bs,self.d_model,-1))
+ x1_pos=torch.cat([x1,pos1],dim=1)
+ values=self.v_proj(x1_pos.view(bs,self.d_value,-1))
+ msg=self.attention(queries,keys,values,flow0,x0.shape[2:],x1.shape[2:],mask0,mask1,ds0,ds1)
+
+ if self.update_flow:
+ update_feature=torch.cat([x0,flow_feature0],dim=1)
+ else:
+ update_feature=x0
+ msg=self.norm2(self.merge_f(torch.cat([update_feature,self.norm1(msg)],dim=1)))
+ update_feature=update_feature+msg
+
+ x0_new,flow_feature0_new=update_feature[:,:self.d_model],update_feature[:,self.d_model:]
+ return x0_new,flow_feature0_new
+
+ def decode_flow(self,flow_feature,kshape):
+ bs,h,w=flow_feature.shape[0],flow_feature.shape[2],flow_feature.shape[3]
+ scale_factor=torch.tensor([kshape[1],kshape[0]]).to(device)[None,None,None]
+ flow=self.flow_decoder(flow_feature.view(bs,-1,h*w)).permute(0,2,1).view(bs,h,w,4)
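+ # channels 0:2 are match coordinates, mapped onto the key feature map via sigmoid; channels 2:4 keep the raw (log-)variance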
+ flow_coordinates=torch.sigmoid(flow[:,:,:,:2])*scale_factor
+ flow_var=flow[:,:,:,2:]
+ flow=torch.cat([flow_coordinates,flow_var],dim=-1) #B*H*W*4
+ return flow
+
+
+class flow_initializer(nn.Module):
+
+ def __init__(self, dim, dim_flow, nhead, layer_num):
+ super().__init__()
+ self.layer_num= layer_num
+ self.dim = dim
+ self.dim_flow = dim_flow
+
+ encoder_layer = messageLayer_ini(
+ dim ,dim_flow,dim+dim_flow , nhead)
+ self.layers_coarse = nn.ModuleList(
+ [copy.deepcopy(encoder_layer) for _ in range(layer_num)])
+ self.decoupler = nn.Conv2d(
+ self.dim, self.dim+self.dim_flow, kernel_size=1)
+ self.up_merge = nn.Conv2d(2*dim, dim, kernel_size=1)
+
+ def forward(self, feat0, feat1,pos0,pos1,mask0=None,mask1=None,ds0=[4,4],ds1=[4,4]):
+ # feat0: [B, C, H0, W0]
+ # feat1: [B, C, H1, W1]
+ # use low-res MHA to initialize flow feature
+ bs = feat0.size(0)
+ h0,w0,h1,w1=feat0.shape[2],feat0.shape[3],feat1.shape[2],feat1.shape[3]
+
+ # coarse level
+ sub_feat0, sub_feat1 = F.avg_pool2d(feat0, ds0, stride=ds0), \
+ F.avg_pool2d(feat1, ds1, stride=ds1)
+
+ sub_pos0,sub_pos1=F.avg_pool2d(pos0, ds0, stride=ds0), \
+ F.avg_pool2d(pos1, ds1, stride=ds1)
+
+ if mask0 is not None:
+ mask0,mask1=-F.max_pool2d(-mask0.view(bs,1,h0,w0),ds0,stride=ds0).view(bs,-1),\
+ -F.max_pool2d(-mask1.view(bs,1,h1,w1),ds1,stride=ds1).view(bs,-1)
+
+ for layer in self.layers_coarse:
+ sub_feat0, sub_feat1 = layer(sub_feat0, sub_feat1,sub_pos0,sub_pos1,mask0,mask1)
+ # decouple flow and visual features
+ decoupled_feature0, decoupled_feature1 = self.decoupler(sub_feat0),self.decoupler(sub_feat1)
+
+ sub_feat0, sub_flow_feature0 = decoupled_feature0[:,:self.dim], decoupled_feature0[:, self.dim:]
+ sub_feat1, sub_flow_feature1 = decoupled_feature1[:,:self.dim], decoupled_feature1[:, self.dim:]
+ update_feat0, flow_feature0 = F.interpolate(sub_feat0, scale_factor=ds0, mode='bilinear'),\
+ F.interpolate(sub_flow_feature0, scale_factor=ds0, mode='bilinear')
+ update_feat1, flow_feature1 = F.interpolate(sub_feat1, scale_factor=ds1, mode='bilinear'),\
+ F.interpolate(sub_flow_feature1, scale_factor=ds1, mode='bilinear')
+
+ feat0 = feat0+self.up_merge(torch.cat([feat0, update_feat0], dim=1))
+ feat1 = feat1+self.up_merge(torch.cat([feat1, update_feat1], dim=1))
+
+ return feat0,feat1,flow_feature0,flow_feature1 #b*c*h*w
+
+
+class LocalFeatureTransformer_Flow(nn.Module):
+ """A Local Feature Transformer (LoFTR) module."""
+
+ def __init__(self, config):
+ super(LocalFeatureTransformer_Flow, self).__init__()
+
+ self.config = config
+ self.d_model = config['d_model']
+ self.nhead = config['nhead']
+
+ self.pos_transform=nn.Conv2d(config['d_model'],config['d_flow'],kernel_size=1,bias=False)
+ self.ini_layer = flow_initializer(self.d_model, config['d_flow'], config['nhead'],config['ini_layer_num'])
+
+ encoder_layer = messageLayer_gla(
+ config['d_model'], config['d_flow'], config['d_flow']+config['d_model'], config['nhead'],config['radius_scale'],config['nsample'])
+ encoder_layer_last=messageLayer_gla(
+ config['d_model'], config['d_flow'], config['d_flow']+config['d_model'], config['nhead'],config['radius_scale'],config['nsample'],update_flow=False)
+ self.layers = nn.ModuleList([copy.deepcopy(encoder_layer) for _ in range(config['layer_num']-1)]+[encoder_layer_last])
+ self._reset_parameters()
+
+ def _reset_parameters(self):
+ for name,p in self.named_parameters():
+ if 'temp' in name or 'sample_offset' in name:
+ continue
+ if p.dim() > 1:
+ nn.init.xavier_uniform_(p)
+
+ def forward(self, feat0, feat1,pos0,pos1,mask0=None,mask1=None,ds0=[4,4],ds1=[4,4]):
+ """
+ Args:
+ feat0 (torch.Tensor): [N, C, H, W]
+ feat1 (torch.Tensor): [N, C, H, W]
+ pos0,pos1: [N, C, H, W]
+ Outputs:
+ feat0: [N,-1,C]
+ feat1: [N,-1,C]
+ flow_list: [L,N,H,W,4]*1(2)
+ """
+ bs = feat0.size(0)
+
+ pos0,pos1=self.pos_transform(pos0),self.pos_transform(pos1)
+ pos0,pos1=pos0.expand(bs,-1,-1,-1),pos1.expand(bs,-1,-1,-1)
+ assert self.d_model == feat0.size(
+ 1), "the feature number of src and transformer must be equal"
+
+ flow_list=[[],[]]# [px,py,sx,sy]
+ if mask0 is not None:
+ mask0,mask1=mask0[:,None].float(),mask1[:,None].float()
+ feat0,feat1, flow_feature0, flow_feature1 = self.ini_layer(feat0, feat1,pos0,pos1,mask0,mask1,ds0,ds1)
+ for layer in self.layers:
+ feat0,feat1,flow_feature0,flow_feature1,flow0,flow1=layer(feat0,feat1,flow_feature0,flow_feature1,pos0,pos1,mask0,mask1,ds0,ds1)
+ flow_list[0].append(flow0)
+ flow_list[1].append(flow1)
+ flow_list[0]=torch.stack(flow_list[0],dim=0)
+ flow_list[1]=torch.stack(flow_list[1],dim=0)
+ feat0, feat1 = feat0.permute(0, 2, 3, 1).view(bs, -1, self.d_model), feat1.permute(0, 2, 3, 1).view(bs, -1, self.d_model)
+ return feat0, feat1, flow_list
\ No newline at end of file
diff --git a/imcui/third_party/ASpanFormer/src/ASpanFormer/aspanformer.py b/imcui/third_party/ASpanFormer/src/ASpanFormer/aspanformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..e42f2438abda5883796cea9f379380fa6ad7d7c1
--- /dev/null
+++ b/imcui/third_party/ASpanFormer/src/ASpanFormer/aspanformer.py
@@ -0,0 +1,134 @@
+import torch
+import torch.nn as nn
+from torchvision import transforms
+from einops.einops import rearrange
+
+from .backbone import build_backbone
+from .utils.position_encoding import PositionEncodingSine
+from .aspan_module import LocalFeatureTransformer_Flow, LocalFeatureTransformer, FinePreprocess
+from .utils.coarse_matching import CoarseMatching
+from .utils.fine_matching import FineMatching
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+
+class ASpanFormer(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ # Misc
+ self.config = config
+
+ # Modules
+ self.backbone = build_backbone(config)
+ self.pos_encoding = PositionEncodingSine(
+ config['coarse']['d_model'],pre_scaling=[config['coarse']['train_res'],config['coarse']['test_res']])
+ self.loftr_coarse = LocalFeatureTransformer_Flow(config['coarse'])
+ self.coarse_matching = CoarseMatching(config['match_coarse'])
+ self.fine_preprocess = FinePreprocess(config)
+ self.loftr_fine = LocalFeatureTransformer(config["fine"])
+ self.fine_matching = FineMatching()
+ self.coarsest_level=config['coarse']['coarsest_level']
+
+ def forward(self, data, online_resize=False):
+ """
+ Update:
+ data (dict): {
+ 'image0': (torch.Tensor): (N, 1, H, W)
+ 'image1': (torch.Tensor): (N, 1, H, W)
+ 'mask0'(optional) : (torch.Tensor): (N, H, W) '0' indicates a padded position
+ 'mask1'(optional) : (torch.Tensor): (N, H, W)
+ }
+ """
+ if online_resize:
+ assert data['image0'].shape[0]==1 and data['image1'].shape[0]==1 # online resize only supports batch size 1
+ self.resize_input(data,self.config['coarse']['train_res'])
+ else:
+ data['pos_scale0'],data['pos_scale1']=None,None
+
+ # 1. Local Feature CNN
+ data.update({
+ 'bs': data['image0'].size(0),
+ 'hw0_i': data['image0'].shape[2:], 'hw1_i': data['image1'].shape[2:]
+ })
+
+ if data['hw0_i'] == data['hw1_i']: # faster & better BN convergence
+ feats_c, feats_f = self.backbone(
+ torch.cat([data['image0'], data['image1']], dim=0))
+ (feat_c0, feat_c1), (feat_f0, feat_f1) = feats_c.split(
+ data['bs']), feats_f.split(data['bs'])
+ else: # handle different input shapes
+ (feat_c0, feat_f0), (feat_c1, feat_f1) = self.backbone(
+ data['image0']), self.backbone(data['image1'])
+
+ data.update({
+ 'hw0_c': feat_c0.shape[2:], 'hw1_c': feat_c1.shape[2:],
+ 'hw0_f': feat_f0.shape[2:], 'hw1_f': feat_f1.shape[2:]
+ })
+
+ # 2. coarse-level loftr module
+ # add featmap with positional encoding, then flatten it to sequence [N, HW, C]
+ [feat_c0, pos_encoding0], [feat_c1, pos_encoding1] = self.pos_encoding(feat_c0,data['pos_scale0']), self.pos_encoding(feat_c1,data['pos_scale1'])
+ feat_c0 = rearrange(feat_c0, 'n c h w -> n c h w ')
+ feat_c1 = rearrange(feat_c1, 'n c h w -> n c h w ')
+
+ #TODO:adjust ds
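+ # ds: pooling factors that reduce the coarse feature map to the configured coarsest attention resolution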
+ ds0=[int(data['hw0_c'][0]/self.coarsest_level[0]),int(data['hw0_c'][1]/self.coarsest_level[1])]
+ ds1=[int(data['hw1_c'][0]/self.coarsest_level[0]),int(data['hw1_c'][1]/self.coarsest_level[1])]
+ if online_resize:
+ ds0,ds1=[4,4],[4,4]
+
+ mask_c0 = mask_c1 = None # mask is useful in training
+ if 'mask0' in data:
+ mask_c0, mask_c1 = data['mask0'].flatten(
+ -2), data['mask1'].flatten(-2)
+ feat_c0, feat_c1, flow_list = self.loftr_coarse(
+ feat_c0, feat_c1,pos_encoding0,pos_encoding1,mask_c0,mask_c1,ds0,ds1)
+
+ # 3. match coarse-level and register predicted offset
+ self.coarse_matching(feat_c0, feat_c1, flow_list,data,
+ mask_c0=mask_c0, mask_c1=mask_c1)
+
+ # 4. fine-level refinement
+ feat_f0_unfold, feat_f1_unfold = self.fine_preprocess(
+ feat_f0, feat_f1, feat_c0, feat_c1, data)
+ if feat_f0_unfold.size(0) != 0: # at least one coarse level predicted
+ feat_f0_unfold, feat_f1_unfold = self.loftr_fine(
+ feat_f0_unfold, feat_f1_unfold)
+
+ # 5. match fine-level
+ self.fine_matching(feat_f0_unfold, feat_f1_unfold, data)
+
+ # 6. resize match coordinates back to input resolution
+ if online_resize:
+ data['mkpts0_f']*=data['online_resize_scale0']
+ data['mkpts1_f']*=data['online_resize_scale1']
+
+ def load_state_dict(self, state_dict, *args, **kwargs):
+ for k in list(state_dict.keys()):
+ if k.startswith('matcher.'):
+ if 'sample_offset' in k:
+ state_dict.pop(k)
+ else:
+ state_dict[k.replace('matcher.', '', 1)] = state_dict.pop(k)
+ return super().load_state_dict(state_dict, *args, **kwargs)
+
+ def resize_input(self,data,train_res,df=32):
+ h0,w0,h1,w1=data['image0'].shape[2],data['image0'].shape[3],data['image1'].shape[2],data['image1'].shape[3]
+ data['image0'],data['image1']=self.resize_df(data['image0'],df),self.resize_df(data['image1'],df)
+
+ if len(train_res)==1:
+ train_res_h=train_res_w=train_res[0]
+ else:
+ train_res_h,train_res_w=train_res[0],train_res[1]
+ data['pos_scale0'],data['pos_scale1']=[train_res_h/data['image0'].shape[2],train_res_w/data['image0'].shape[3]],\
+ [train_res_h/data['image1'].shape[2],train_res_w/data['image1'].shape[3]]
+ data['online_resize_scale0'],data['online_resize_scale1']=torch.tensor([w0/data['image0'].shape[3],h0/data['image0'].shape[2]])[None].to(device),\
+ torch.tensor([w1/data['image1'].shape[3],h1/data['image1'].shape[2]])[None].to(device)
+
+ def resize_df(self,image,df=32):
+ h,w=image.shape[2],image.shape[3]
+ h_new,w_new=h//df*df,w//df*df
+ if h!=h_new or w!=w_new:
+ img_new=transforms.Resize([h_new,w_new]).forward(image)
+ else:
+ img_new=image
+ return img_new
diff --git a/imcui/third_party/ASpanFormer/src/ASpanFormer/backbone/__init__.py b/imcui/third_party/ASpanFormer/src/ASpanFormer/backbone/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b6e731b3f53ab367c89ef0ea8e1cbffb0d990775
--- /dev/null
+++ b/imcui/third_party/ASpanFormer/src/ASpanFormer/backbone/__init__.py
@@ -0,0 +1,11 @@
+from .resnet_fpn import ResNetFPN_8_2, ResNetFPN_16_4
+
+
+def build_backbone(config):
+ if config['backbone_type'] == 'ResNetFPN':
+ if config['resolution'] == (8, 2):
+ return ResNetFPN_8_2(config['resnetfpn'])
+ elif config['resolution'] == (16, 4):
+ return ResNetFPN_16_4(config['resnetfpn'])
+ else:
+ raise ValueError(f"LOFTR.BACKBONE_TYPE {config['backbone_type']} not supported.")
diff --git a/imcui/third_party/ASpanFormer/src/ASpanFormer/backbone/resnet_fpn.py b/imcui/third_party/ASpanFormer/src/ASpanFormer/backbone/resnet_fpn.py
new file mode 100644
index 0000000000000000000000000000000000000000..985e5b3f273a51e51447a8025ca3aadbe46752eb
--- /dev/null
+++ b/imcui/third_party/ASpanFormer/src/ASpanFormer/backbone/resnet_fpn.py
@@ -0,0 +1,199 @@
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+def conv1x1(in_planes, out_planes, stride=1):
+ """1x1 convolution without padding"""
+ return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, padding=0, bias=False)
+
+
+def conv3x3(in_planes, out_planes, stride=1):
+ """3x3 convolution with padding"""
+ return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
+
+
+class BasicBlock(nn.Module):
+ def __init__(self, in_planes, planes, stride=1):
+ super().__init__()
+ self.conv1 = conv3x3(in_planes, planes, stride)
+ self.conv2 = conv3x3(planes, planes)
+ self.bn1 = nn.BatchNorm2d(planes)
+ self.bn2 = nn.BatchNorm2d(planes)
+ self.relu = nn.ReLU(inplace=True)
+
+ if stride == 1:
+ self.downsample = None
+ else:
+ self.downsample = nn.Sequential(
+ conv1x1(in_planes, planes, stride=stride),
+ nn.BatchNorm2d(planes)
+ )
+
+ def forward(self, x):
+ y = x
+ y = self.relu(self.bn1(self.conv1(y)))
+ y = self.bn2(self.conv2(y))
+
+ if self.downsample is not None:
+ x = self.downsample(x)
+
+ return self.relu(x+y)
+
+
+class ResNetFPN_8_2(nn.Module):
+ """
+ ResNet+FPN, output resolution are 1/8 and 1/2.
+ Each block has 2 layers.
+ """
+
+ def __init__(self, config):
+ super().__init__()
+ # Config
+ block = BasicBlock
+ initial_dim = config['initial_dim']
+ block_dims = config['block_dims']
+
+ # Class Variable
+ self.in_planes = initial_dim
+
+ # Networks
+ self.conv1 = nn.Conv2d(1, initial_dim, kernel_size=7, stride=2, padding=3, bias=False)
+ self.bn1 = nn.BatchNorm2d(initial_dim)
+ self.relu = nn.ReLU(inplace=True)
+
+ self.layer1 = self._make_layer(block, block_dims[0], stride=1) # 1/2
+ self.layer2 = self._make_layer(block, block_dims[1], stride=2) # 1/4
+ self.layer3 = self._make_layer(block, block_dims[2], stride=2) # 1/8
+
+ # 3. FPN upsample
+ self.layer3_outconv = conv1x1(block_dims[2], block_dims[2])
+ self.layer2_outconv = conv1x1(block_dims[1], block_dims[2])
+ self.layer2_outconv2 = nn.Sequential(
+ conv3x3(block_dims[2], block_dims[2]),
+ nn.BatchNorm2d(block_dims[2]),
+ nn.LeakyReLU(),
+ conv3x3(block_dims[2], block_dims[1]),
+ )
+ self.layer1_outconv = conv1x1(block_dims[0], block_dims[1])
+ self.layer1_outconv2 = nn.Sequential(
+ conv3x3(block_dims[1], block_dims[1]),
+ nn.BatchNorm2d(block_dims[1]),
+ nn.LeakyReLU(),
+ conv3x3(block_dims[1], block_dims[0]),
+ )
+
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+ elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
+ nn.init.constant_(m.weight, 1)
+ nn.init.constant_(m.bias, 0)
+
+ def _make_layer(self, block, dim, stride=1):
+ layer1 = block(self.in_planes, dim, stride=stride)
+ layer2 = block(dim, dim, stride=1)
+ layers = (layer1, layer2)
+
+ self.in_planes = dim
+ return nn.Sequential(*layers)
+
+ def forward(self, x):
+ # ResNet Backbone
+ x0 = self.relu(self.bn1(self.conv1(x)))
+ x1 = self.layer1(x0) # 1/2
+ x2 = self.layer2(x1) # 1/4
+ x3 = self.layer3(x2) # 1/8
+
+ # FPN
+ x3_out = self.layer3_outconv(x3)
+
+ x3_out_2x = F.interpolate(x3_out, scale_factor=2., mode='bilinear', align_corners=True)
+ x2_out = self.layer2_outconv(x2)
+ x2_out = self.layer2_outconv2(x2_out+x3_out_2x)
+
+ x2_out_2x = F.interpolate(x2_out, scale_factor=2., mode='bilinear', align_corners=True)
+ x1_out = self.layer1_outconv(x1)
+ x1_out = self.layer1_outconv2(x1_out+x2_out_2x)
+
+ return [x3_out, x1_out]
+
+
+class ResNetFPN_16_4(nn.Module):
+ """
+ ResNet+FPN, output resolution are 1/16 and 1/4.
+ Each block has 2 layers.
+ """
+
+ def __init__(self, config):
+ super().__init__()
+ # Config
+ block = BasicBlock
+ initial_dim = config['initial_dim']
+ block_dims = config['block_dims']
+
+ # Class Variable
+ self.in_planes = initial_dim
+
+ # Networks
+ self.conv1 = nn.Conv2d(1, initial_dim, kernel_size=7, stride=2, padding=3, bias=False)
+ self.bn1 = nn.BatchNorm2d(initial_dim)
+ self.relu = nn.ReLU(inplace=True)
+
+ self.layer1 = self._make_layer(block, block_dims[0], stride=1) # 1/2
+ self.layer2 = self._make_layer(block, block_dims[1], stride=2) # 1/4
+ self.layer3 = self._make_layer(block, block_dims[2], stride=2) # 1/8
+ self.layer4 = self._make_layer(block, block_dims[3], stride=2) # 1/16
+
+ # 3. FPN upsample
+ self.layer4_outconv = conv1x1(block_dims[3], block_dims[3])
+ self.layer3_outconv = conv1x1(block_dims[2], block_dims[3])
+ self.layer3_outconv2 = nn.Sequential(
+ conv3x3(block_dims[3], block_dims[3]),
+ nn.BatchNorm2d(block_dims[3]),
+ nn.LeakyReLU(),
+ conv3x3(block_dims[3], block_dims[2]),
+ )
+
+ self.layer2_outconv = conv1x1(block_dims[1], block_dims[2])
+ self.layer2_outconv2 = nn.Sequential(
+ conv3x3(block_dims[2], block_dims[2]),
+ nn.BatchNorm2d(block_dims[2]),
+ nn.LeakyReLU(),
+ conv3x3(block_dims[2], block_dims[1]),
+ )
+
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+ elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
+ nn.init.constant_(m.weight, 1)
+ nn.init.constant_(m.bias, 0)
+
+ def _make_layer(self, block, dim, stride=1):
+ layer1 = block(self.in_planes, dim, stride=stride)
+ layer2 = block(dim, dim, stride=1)
+ layers = (layer1, layer2)
+
+ self.in_planes = dim
+ return nn.Sequential(*layers)
+
+ def forward(self, x):
+ # ResNet Backbone
+ x0 = self.relu(self.bn1(self.conv1(x)))
+ x1 = self.layer1(x0) # 1/2
+ x2 = self.layer2(x1) # 1/4
+ x3 = self.layer3(x2) # 1/8
+ x4 = self.layer4(x3) # 1/16
+
+ # FPN
+ x4_out = self.layer4_outconv(x4)
+
+ x4_out_2x = F.interpolate(x4_out, scale_factor=2., mode='bilinear', align_corners=True)
+ x3_out = self.layer3_outconv(x3)
+ x3_out = self.layer3_outconv2(x3_out+x4_out_2x)
+
+ x3_out_2x = F.interpolate(x3_out, scale_factor=2., mode='bilinear', align_corners=True)
+ x2_out = self.layer2_outconv(x2)
+ x2_out = self.layer2_outconv2(x2_out+x3_out_2x)
+
+ return [x4_out, x2_out]
diff --git a/imcui/third_party/ASpanFormer/src/ASpanFormer/utils/coarse_matching.py b/imcui/third_party/ASpanFormer/src/ASpanFormer/utils/coarse_matching.py
new file mode 100644
index 0000000000000000000000000000000000000000..281a410e02465dec1d68ab69f48673268d1d3002
--- /dev/null
+++ b/imcui/third_party/ASpanFormer/src/ASpanFormer/utils/coarse_matching.py
@@ -0,0 +1,331 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops.einops import rearrange
+
+from time import time
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+INF = 1e9
+
+def mask_border(m, b: int, v):
+ """ Mask borders with value
+ Args:
+ m (torch.Tensor): [N, H0, W0, H1, W1]
+ b (int)
+ v (m.dtype)
+ """
+ if b <= 0:
+ return
+
+ m[:, :b] = v
+ m[:, :, :b] = v
+ m[:, :, :, :b] = v
+ m[:, :, :, :, :b] = v
+ m[:, -b:] = v
+ m[:, :, -b:] = v
+ m[:, :, :, -b:] = v
+ m[:, :, :, :, -b:] = v
+
+
+def mask_border_with_padding(m, bd, v, p_m0, p_m1):
+ if bd <= 0:
+ return
+
+ m[:, :bd] = v
+ m[:, :, :bd] = v
+ m[:, :, :, :bd] = v
+ m[:, :, :, :, :bd] = v
+
+ h0s, w0s = p_m0.sum(1).max(-1)[0].int(), p_m0.sum(-1).max(-1)[0].int()
+ h1s, w1s = p_m1.sum(1).max(-1)[0].int(), p_m1.sum(-1).max(-1)[0].int()
+ for b_idx, (h0, w0, h1, w1) in enumerate(zip(h0s, w0s, h1s, w1s)):
+ m[b_idx, h0 - bd:] = v
+ m[b_idx, :, w0 - bd:] = v
+ m[b_idx, :, :, h1 - bd:] = v
+ m[b_idx, :, :, :, w1 - bd:] = v
+
+
+def compute_max_candidates(p_m0, p_m1):
+ """Compute the max candidates of all pairs within a batch
+
+ Args:
+ p_m0, p_m1 (torch.Tensor): padded masks
+ """
+ h0s, w0s = p_m0.sum(1).max(-1)[0], p_m0.sum(-1).max(-1)[0]
+ h1s, w1s = p_m1.sum(1).max(-1)[0], p_m1.sum(-1).max(-1)[0]
+ max_cand = torch.sum(
+ torch.min(torch.stack([h0s * w0s, h1s * w1s], -1), -1)[0])
+ return max_cand
+
+
+class CoarseMatching(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ # general config
+ self.thr = config['thr']
+ self.border_rm = config['border_rm']
+ # -- # for training fine-level LoFTR
+ self.train_coarse_percent = config['train_coarse_percent']
+ self.train_pad_num_gt_min = config['train_pad_num_gt_min']
+
+ # we provide 2 options for differentiable matching
+ self.match_type = config['match_type']
+ if self.match_type == 'dual_softmax':
+ self.temperature=nn.parameter.Parameter(torch.tensor(10.), requires_grad=True)
+ elif self.match_type == 'sinkhorn':
+ try:
+ from .superglue import log_optimal_transport
+ except ImportError:
+ raise ImportError("download superglue.py first!")
+ self.log_optimal_transport = log_optimal_transport
+ self.bin_score = nn.Parameter(
+ torch.tensor(config['skh_init_bin_score'], requires_grad=True))
+ self.skh_iters = config['skh_iters']
+ self.skh_prefilter = config['skh_prefilter']
+ else:
+ raise NotImplementedError()
+
+ def forward(self, feat_c0, feat_c1, flow_list, data, mask_c0=None, mask_c1=None):
+ """
+ Args:
+ feat0 (torch.Tensor): [N, L, C]
+ feat1 (torch.Tensor): [N, S, C]
+ offset: [layer, B, H, W, 4] (*2)
+ data (dict)
+ mask_c0 (torch.Tensor): [N, L] (optional)
+ mask_c1 (torch.Tensor): [N, S] (optional)
+ Update:
+ data (dict): {
+ 'b_ids' (torch.Tensor): [M'],
+ 'i_ids' (torch.Tensor): [M'],
+ 'j_ids' (torch.Tensor): [M'],
+ 'gt_mask' (torch.Tensor): [M'],
+ 'mkpts0_c' (torch.Tensor): [M, 2],
+ 'mkpts1_c' (torch.Tensor): [M, 2],
+ 'mconf' (torch.Tensor): [M]}
+ NOTE: M' != M during training.
+ """
+ N, L, S, C = feat_c0.size(0), feat_c0.size(1), feat_c1.size(1), feat_c0.size(2)
+ # normalize
+ feat_c0, feat_c1 = map(lambda feat: feat / feat.shape[-1]**.5,
+ [feat_c0, feat_c1])
+
+ if self.match_type == 'dual_softmax':
+ sim_matrix = torch.einsum("nlc,nsc->nls", feat_c0,
+ feat_c1) * self.temperature
+ if mask_c0 is not None:
+ sim_matrix.masked_fill_(
+ ~(mask_c0[..., None] * mask_c1[:, None]).bool(),
+ -INF)
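+ # dual-softmax: the product of row-wise and column-wise softmax favours mutually confident matches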
+ conf_matrix = F.softmax(sim_matrix, 1) * F.softmax(sim_matrix, 2)
+
+ elif self.match_type == 'sinkhorn':
+ # sinkhorn, dustbin included
+ sim_matrix = torch.einsum("nlc,nsc->nls", feat_c0, feat_c1)
+ if mask_c0 is not None:
+ sim_matrix[:, :L, :S].masked_fill_(
+ ~(mask_c0[..., None] * mask_c1[:, None]).bool(),
+ -INF)
+
+ # build uniform prior & use sinkhorn
+ log_assign_matrix = self.log_optimal_transport(
+ sim_matrix, self.bin_score, self.skh_iters)
+ assign_matrix = log_assign_matrix.exp()
+ conf_matrix = assign_matrix[:, :-1, :-1]
+
+ # filter prediction with dustbin score (only in evaluation mode)
+ if not self.training and self.skh_prefilter:
+ filter0 = (assign_matrix.max(dim=2)[1] == S)[:, :-1] # [N, L]
+ filter1 = (assign_matrix.max(dim=1)[1] == L)[:, :-1] # [N, S]
+ conf_matrix[filter0[..., None].repeat(1, 1, S)] = 0
+ conf_matrix[filter1[:, None].repeat(1, L, 1)] = 0
+
+ if self.config['sparse_spvs']:
+ data.update({'conf_matrix_with_bin': assign_matrix.clone()})
+
+ data.update({'conf_matrix': conf_matrix})
+ # predict coarse matches from conf_matrix
+ data.update(**self.get_coarse_match(conf_matrix, data))
+
+ #update predicted offset
+ if flow_list[0].shape[2]==flow_list[1].shape[2] and flow_list[0].shape[3]==flow_list[1].shape[3]:
+ flow_list=torch.stack(flow_list,dim=0)
+ data.update({'predict_flow':flow_list}) #[2*L*B*H*W*4]
+ self.get_offset_match(flow_list,data,mask_c0,mask_c1)
+
+ @torch.no_grad()
+ def get_coarse_match(self, conf_matrix, data):
+ """
+ Args:
+ conf_matrix (torch.Tensor): [N, L, S]
+ data (dict): with keys ['hw0_i', 'hw1_i', 'hw0_c', 'hw1_c']
+ Returns:
+ coarse_matches (dict): {
+ 'b_ids' (torch.Tensor): [M'],
+ 'i_ids' (torch.Tensor): [M'],
+ 'j_ids' (torch.Tensor): [M'],
+ 'gt_mask' (torch.Tensor): [M'],
+ 'm_bids' (torch.Tensor): [M],
+ 'mkpts0_c' (torch.Tensor): [M, 2],
+ 'mkpts1_c' (torch.Tensor): [M, 2],
+ 'mconf' (torch.Tensor): [M]}
+ """
+ axes_lengths = {
+ 'h0c': data['hw0_c'][0],
+ 'w0c': data['hw0_c'][1],
+ 'h1c': data['hw1_c'][0],
+ 'w1c': data['hw1_c'][1]
+ }
+ _device = conf_matrix.device
+ # 1. confidence thresholding
+ mask = conf_matrix > self.thr
+ mask = rearrange(mask, 'b (h0c w0c) (h1c w1c) -> b h0c w0c h1c w1c',
+ **axes_lengths)
+ if 'mask0' not in data:
+ mask_border(mask, self.border_rm, False)
+ else:
+ mask_border_with_padding(mask, self.border_rm, False,
+ data['mask0'], data['mask1'])
+ mask = rearrange(mask, 'b h0c w0c h1c w1c -> b (h0c w0c) (h1c w1c)',
+ **axes_lengths)
+
+ # 2. mutual nearest
+ mask = mask \
+ * (conf_matrix == conf_matrix.max(dim=2, keepdim=True)[0]) \
+ * (conf_matrix == conf_matrix.max(dim=1, keepdim=True)[0])
+
+ # 3. find all valid coarse matches
+ # this only works when at most one `True` in each row
+ mask_v, all_j_ids = mask.max(dim=2)
+ b_ids, i_ids = torch.where(mask_v)
+ j_ids = all_j_ids[b_ids, i_ids]
+ mconf = conf_matrix[b_ids, i_ids, j_ids]
+
+ # 4. Random sampling of training samples for fine-level LoFTR
+ # (optional) pad samples with gt coarse-level matches
+ if self.training:
+ # NOTE:
+ # The sampling is performed across all pairs in a batch without manual balancing,
+ # so the number of fine-level samples grows with the batch size.
+ if 'mask0' not in data:
+ num_candidates_max = mask.size(0) * max(
+ mask.size(1), mask.size(2))
+ else:
+ num_candidates_max = compute_max_candidates(
+ data['mask0'], data['mask1'])
+ num_matches_train = int(num_candidates_max *
+ self.train_coarse_percent)
+ num_matches_pred = len(b_ids)
+ assert self.train_pad_num_gt_min < num_matches_train, "min-num-gt-pad should be less than num-train-matches"
+
+ # pred_indices is to select from prediction
+ if num_matches_pred <= num_matches_train - self.train_pad_num_gt_min:
+ pred_indices = torch.arange(num_matches_pred, device=_device)
+ else:
+ pred_indices = torch.randint(
+ num_matches_pred,
+ (num_matches_train - self.train_pad_num_gt_min, ),
+ device=_device)
+
+ # gt_pad_indices is to select from gt padding. e.g. max(3787-4800, 200)
+ gt_pad_indices = torch.randint(
+ len(data['spv_b_ids']),
+ (max(num_matches_train - num_matches_pred,
+ self.train_pad_num_gt_min), ),
+ device=_device)
+ mconf_gt = torch.zeros(len(data['spv_b_ids']), device=_device) # set conf of gt paddings to all zero
+
+ b_ids, i_ids, j_ids, mconf = map(
+ lambda x, y: torch.cat([x[pred_indices], y[gt_pad_indices]],
+ dim=0),
+ *zip([b_ids, data['spv_b_ids']], [i_ids, data['spv_i_ids']],
+ [j_ids, data['spv_j_ids']], [mconf, mconf_gt]))
+
+ # These matches select patches that feed into the fine-level network
+ coarse_matches = {'b_ids': b_ids, 'i_ids': i_ids, 'j_ids': j_ids}
+
+ # 4. Update with matches in original image resolution
+ scale = data['hw0_i'][0] / data['hw0_c'][0]
+ scale0 = scale * data['scale0'][b_ids] if 'scale0' in data else scale
+ scale1 = scale * data['scale1'][b_ids] if 'scale1' in data else scale
+ mkpts0_c = torch.stack(
+ [i_ids % data['hw0_c'][1], i_ids // data['hw0_c'][1]],
+ dim=1) * scale0
+ mkpts1_c = torch.stack(
+ [j_ids % data['hw1_c'][1], j_ids // data['hw1_c'][1]],
+ dim=1) * scale1
+
+ # These matches are the current prediction (for visualization)
+ coarse_matches.update({
+ 'gt_mask': mconf == 0,
+ 'm_bids': b_ids[mconf != 0], # mconf == 0 => gt matches
+ 'mkpts0_c': mkpts0_c[mconf != 0],
+ 'mkpts1_c': mkpts1_c[mconf != 0],
+ 'mconf': mconf[mconf != 0]
+ })
+
+ return coarse_matches
+
+ @torch.no_grad()
+ def get_offset_match(self, flow_list, data,mask1,mask2):
+ """
+ Args:
+ offset (torch.Tensor): [L, B, H, W, 2]
+ data (dict): with keys ['hw0_i', 'hw1_i', 'hw0_c', 'hw1_c']
+ Returns:
+ coarse_matches (dict): {
+ 'm_bids' (torch.Tensor): [M],
+ 'mkpts0_c' (torch.Tensor): [M, 2],
+ 'mkpts1_c' (torch.Tensor): [M, 2],
+ 'mconf' (torch.Tensor): [M]}
+ """
+ offset1=flow_list[0]
+ bs,layer_num=offset1.shape[1],offset1.shape[0]
+
+ #left side
+ offset1=offset1.view(layer_num,bs,-1,4)
+ conf1=offset1[:,:,:,2:].mean(dim=-1)
+ if mask1 is not None:
+ conf1.masked_fill_(~mask1.bool()[None].expand(layer_num,-1,-1),100)
+ offset1=offset1[:,:,:,:2]
+ self.get_offset_match_work(offset1,conf1,data,'left')
+
+ # right side
+ if len(flow_list)==2:
+ offset2=flow_list[1].view(layer_num,bs,-1,4)
+ conf2=offset2[:,:,:,2:].mean(dim=-1)
+ if mask2 is not None:
+ conf2.masked_fill_(~mask2.bool()[None].expand(layer_num,-1,-1),100)
+ offset2=offset2[:,:,:,:2]
+ self.get_offset_match_work(offset2,conf2,data,'right')
+
+
+ @torch.no_grad()
+ def get_offset_match_work(self, offset,conf, data,side):
+ bs,layer_num=offset.shape[1],offset.shape[0]
+ # 1. confidence thresholding
+ mask_conf= conf<2
+ for index in range(bs):
+ mask_conf[:,index,0]=True # safeguard in case no match survives
+ # 2. find offset matches
+ scale = data['hw0_i'][0] / data['hw0_c'][0]
+ l_ids,b_ids,i_ids = torch.where(mask_conf)
+ j_coor=offset[l_ids,b_ids,i_ids,:2] *scale#[N,2]
+ i_coor=torch.stack([i_ids%data['hw0_c'][1],i_ids//data['hw0_c'][1]],dim=1)*scale
+ #i_coor=torch.as_tensor([[index%data['hw0_c'][1],index//data['hw0_c'][1]] for index in i_ids]).to(device).float()*scale #[N,2]
+ # These matches are the current prediction (for visualization)
+ data.update({
+ 'offset_bids_'+side: b_ids,
+ 'offset_lids_'+side: l_ids,
+ 'conf'+side: conf[mask_conf]
+ })
+
+ if side=='right':
+ data.update({'offset_kpts0_f_'+side: j_coor.detach(),
+ 'offset_kpts1_f_'+side: i_coor})
+ else:
+ data.update({'offset_kpts0_f_'+side: i_coor,
+ 'offset_kpts1_f_'+side: j_coor.detach()})
+
+
diff --git a/imcui/third_party/ASpanFormer/src/ASpanFormer/utils/cvpr_ds_config.py b/imcui/third_party/ASpanFormer/src/ASpanFormer/utils/cvpr_ds_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..fdc57e84936c805cb387b6239ca4a5ff6154e22e
--- /dev/null
+++ b/imcui/third_party/ASpanFormer/src/ASpanFormer/utils/cvpr_ds_config.py
@@ -0,0 +1,50 @@
+from yacs.config import CfgNode as CN
+
+
+def lower_config(yacs_cfg):
+ if not isinstance(yacs_cfg, CN):
+ return yacs_cfg
+ return {k.lower(): lower_config(v) for k, v in yacs_cfg.items()}
+
+
+_CN = CN()
+_CN.BACKBONE_TYPE = 'ResNetFPN'
+_CN.RESOLUTION = (8, 2) # options: [(8, 2), (16, 4)]
+_CN.FINE_WINDOW_SIZE = 5 # window_size in fine_level, must be odd
+_CN.FINE_CONCAT_COARSE_FEAT = True
+
+# 1. LoFTR-backbone (local feature CNN) config
+_CN.RESNETFPN = CN()
+_CN.RESNETFPN.INITIAL_DIM = 128
+_CN.RESNETFPN.BLOCK_DIMS = [128, 196, 256] # s1, s2, s3
+
+# 2. LoFTR-coarse module config
+_CN.COARSE = CN()
+_CN.COARSE.D_MODEL = 256
+_CN.COARSE.D_FFN = 256
+_CN.COARSE.NHEAD = 8
+_CN.COARSE.LAYER_NAMES = ['self', 'cross'] * 4
+_CN.COARSE.ATTENTION = 'linear' # options: ['linear', 'full']
+_CN.COARSE.TEMP_BUG_FIX = False
+
+# 3. Coarse-Matching config
+_CN.MATCH_COARSE = CN()
+_CN.MATCH_COARSE.THR = 0.1
+_CN.MATCH_COARSE.BORDER_RM = 2
+ _CN.MATCH_COARSE.MATCH_TYPE = 'dual_softmax' # options: ['dual_softmax', 'sinkhorn']
+_CN.MATCH_COARSE.DSMAX_TEMPERATURE = 0.1
+_CN.MATCH_COARSE.SKH_ITERS = 3
+_CN.MATCH_COARSE.SKH_INIT_BIN_SCORE = 1.0
+_CN.MATCH_COARSE.SKH_PREFILTER = True
+_CN.MATCH_COARSE.TRAIN_COARSE_PERCENT = 0.4 # training tricks: save GPU memory
+_CN.MATCH_COARSE.TRAIN_PAD_NUM_GT_MIN = 200 # training tricks: avoid DDP deadlock
+
+# 4. LoFTR-fine module config
+_CN.FINE = CN()
+_CN.FINE.D_MODEL = 128
+_CN.FINE.D_FFN = 128
+_CN.FINE.NHEAD = 8
+_CN.FINE.LAYER_NAMES = ['self', 'cross'] * 1
+_CN.FINE.ATTENTION = 'linear'
+
+default_cfg = lower_config(_CN)
diff --git a/imcui/third_party/ASpanFormer/src/ASpanFormer/utils/fine_matching.py b/imcui/third_party/ASpanFormer/src/ASpanFormer/utils/fine_matching.py
new file mode 100644
index 0000000000000000000000000000000000000000..6e77aded52e1eb5c01e22c2738104f3b09d6922a
--- /dev/null
+++ b/imcui/third_party/ASpanFormer/src/ASpanFormer/utils/fine_matching.py
@@ -0,0 +1,74 @@
+import math
+import torch
+import torch.nn as nn
+
+from kornia.geometry.subpix import dsnt
+from kornia.utils.grid import create_meshgrid
+
+
+class FineMatching(nn.Module):
+ """FineMatching with s2d paradigm"""
+
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, feat_f0, feat_f1, data):
+ """
+ Args:
+ feat0 (torch.Tensor): [M, WW, C]
+ feat1 (torch.Tensor): [M, WW, C]
+ data (dict)
+ Update:
+ data (dict):{
+ 'expec_f' (torch.Tensor): [M, 3],
+ 'mkpts0_f' (torch.Tensor): [M, 2],
+ 'mkpts1_f' (torch.Tensor): [M, 2]}
+ """
+ M, WW, C = feat_f0.shape
+ W = int(math.sqrt(WW))
+ scale = data['hw0_i'][0] / data['hw0_f'][0]
+ self.M, self.W, self.WW, self.C, self.scale = M, W, WW, C, scale
+
+ # corner case: if no coarse matches found
+ if M == 0:
+ assert not self.training, "M is always > 0 during training; see coarse_matching.py"
+ # logger.warning('No matches found in coarse-level.')
+ data.update({
+ 'expec_f': torch.empty(0, 3, device=feat_f0.device),
+ 'mkpts0_f': data['mkpts0_c'],
+ 'mkpts1_f': data['mkpts1_c'],
+ })
+ return
+
+ feat_f0_picked = feat_f0[:, WW//2, :]
+ sim_matrix = torch.einsum('mc,mrc->mr', feat_f0_picked, feat_f1)
+ softmax_temp = 1. / C**.5
+ heatmap = torch.softmax(softmax_temp * sim_matrix, dim=1).view(-1, W, W)
+
+ # compute coordinates from heatmap
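+ # soft-argmax (spatial expectation) yields sub-pixel coordinates in the normalized [-1, 1] window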
+ coords_normalized = dsnt.spatial_expectation2d(heatmap[None], True)[0] # [M, 2]
+ grid_normalized = create_meshgrid(W, W, True, heatmap.device).reshape(1, -1, 2) # [1, WW, 2]
+
+ # compute std over the x and y coordinates
+ var = torch.sum(grid_normalized**2 * heatmap.view(-1, WW, 1), dim=1) - coords_normalized**2 # [M, 2]
+ std = torch.sum(torch.sqrt(torch.clamp(var, min=1e-10)), -1) # [M] clamp needed for numerical stability
+
+ # for fine-level supervision
+ data.update({'expec_f': torch.cat([coords_normalized, std.unsqueeze(1)], -1)})
+
+ # compute absolute kpt coords
+ self.get_fine_match(coords_normalized, data)
+
+ @torch.no_grad()
+ def get_fine_match(self, coords_normed, data):
+ W, WW, C, scale = self.W, self.WW, self.C, self.scale
+
+ # mkpts0_f and mkpts1_f
+ mkpts0_f = data['mkpts0_c']
+ scale1 = scale * data['scale1'][data['b_ids']] if 'scale0' in data else scale
+ mkpts1_f = data['mkpts1_c'] + (coords_normed * (W // 2) * scale1)[:len(data['mconf'])]
+
+ data.update({
+ "mkpts0_f": mkpts0_f,
+ "mkpts1_f": mkpts1_f
+ })
diff --git a/imcui/third_party/ASpanFormer/src/ASpanFormer/utils/geometry.py b/imcui/third_party/ASpanFormer/src/ASpanFormer/utils/geometry.py
new file mode 100644
index 0000000000000000000000000000000000000000..f95cdb65b48324c4f4ceb20231b1bed992b41116
--- /dev/null
+++ b/imcui/third_party/ASpanFormer/src/ASpanFormer/utils/geometry.py
@@ -0,0 +1,54 @@
+import torch
+
+
+@torch.no_grad()
+def warp_kpts(kpts0, depth0, depth1, T_0to1, K0, K1):
+ """ Warp kpts0 from I0 to I1 with depth, K and Rt
+ Also check covisibility and depth consistency.
+ Depth is consistent if relative error < 0.2 (hard-coded).
+
+ Args:
+ kpts0 (torch.Tensor): [N, L, 2] - (x, y),
+ depth0 (torch.Tensor): [N, H, W],
+ depth1 (torch.Tensor): [N, H, W],
+ T_0to1 (torch.Tensor): [N, 3, 4],
+ K0 (torch.Tensor): [N, 3, 3],
+ K1 (torch.Tensor): [N, 3, 3],
+ Returns:
+ calculable_mask (torch.Tensor): [N, L]
+ warped_keypoints0 (torch.Tensor): [N, L, 2]
+ """
+ kpts0_long = kpts0.round().long()
+
+ # Sample depth, get calculable_mask on depth != 0
+ kpts0_depth = torch.stack(
+ [depth0[i, kpts0_long[i, :, 1], kpts0_long[i, :, 0]] for i in range(kpts0.shape[0])], dim=0
+ ) # (N, L)
+ nonzero_mask = kpts0_depth != 0
+
+ # Unproject
+ kpts0_h = torch.cat([kpts0, torch.ones_like(kpts0[:, :, [0]])], dim=-1) * kpts0_depth[..., None] # (N, L, 3)
+ kpts0_cam = K0.inverse() @ kpts0_h.transpose(2, 1) # (N, 3, L)
+
+ # Rigid Transform
+ w_kpts0_cam = T_0to1[:, :3, :3] @ kpts0_cam + T_0to1[:, :3, [3]] # (N, 3, L)
+ w_kpts0_depth_computed = w_kpts0_cam[:, 2, :]
+
+ # Project
+ w_kpts0_h = (K1 @ w_kpts0_cam).transpose(2, 1) # (N, L, 3)
+ w_kpts0 = w_kpts0_h[:, :, :2] / (w_kpts0_h[:, :, [2]] + 1e-4) # (N, L, 2), +1e-4 to avoid zero depth
+
+ # Covisible Check
+ h, w = depth1.shape[1:3]
+ covisible_mask = (w_kpts0[:, :, 0] > 0) * (w_kpts0[:, :, 0] < w-1) * \
+ (w_kpts0[:, :, 1] > 0) * (w_kpts0[:, :, 1] < h-1)
+ w_kpts0_long = w_kpts0.long()
+ w_kpts0_long[~covisible_mask, :] = 0
+
+ w_kpts0_depth = torch.stack(
+ [depth1[i, w_kpts0_long[i, :, 1], w_kpts0_long[i, :, 0]] for i in range(w_kpts0_long.shape[0])], dim=0
+ ) # (N, L)
+ consistent_mask = ((w_kpts0_depth - w_kpts0_depth_computed) / w_kpts0_depth).abs() < 0.2
+ valid_mask = nonzero_mask * covisible_mask * consistent_mask
+
+ return valid_mask, w_kpts0
diff --git a/imcui/third_party/ASpanFormer/src/ASpanFormer/utils/position_encoding.py b/imcui/third_party/ASpanFormer/src/ASpanFormer/utils/position_encoding.py
new file mode 100644
index 0000000000000000000000000000000000000000..07d384ae18370acb99ef00a788f628c967249ace
--- /dev/null
+++ b/imcui/third_party/ASpanFormer/src/ASpanFormer/utils/position_encoding.py
@@ -0,0 +1,61 @@
+import math
+import torch
+from torch import nn
+
+
+class PositionEncodingSine(nn.Module):
+ """
+ This is a sinusoidal position encoding generalized to 2-dimensional images
+ """
+
+ def __init__(self, d_model, max_shape=(256, 256),pre_scaling=None):
+ """
+ Args:
+ max_shape (tuple): for 1/8 featmap, the max length of 256 corresponds to 2048 pixels
+ pre_scaling (list, optional): [train_res, test_res]; when both entries are given, the x/y
+ position grids are rescaled by train_res/test_res so that test-time encodings match the
+ resolution used during training.
+ """
+ super().__init__()
+ self.d_model=d_model
+ self.max_shape=max_shape
+ self.pre_scaling=pre_scaling
+
+ pe = torch.zeros((d_model, *max_shape))
+ y_position = torch.ones(max_shape).cumsum(0).float().unsqueeze(0)
+ x_position = torch.ones(max_shape).cumsum(1).float().unsqueeze(0)
+
+ if pre_scaling is not None and pre_scaling[0] is not None and pre_scaling[1] is not None:
+ train_res,test_res=pre_scaling[0],pre_scaling[1]
+ x_position,y_position=x_position*train_res[1]/test_res[1],y_position*train_res[0]/test_res[0]
+
+ div_term = torch.exp(torch.arange(0, d_model//2, 2).float() * (-math.log(10000.0) / (d_model//2)))
+ div_term = div_term[:, None, None] # [C//4, 1, 1]
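+ # interleave sin/cos of the x and y positions across the channel dimension (4-way split)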
+ pe[0::4, :, :] = torch.sin(x_position * div_term)
+ pe[1::4, :, :] = torch.cos(x_position * div_term)
+ pe[2::4, :, :] = torch.sin(y_position * div_term)
+ pe[3::4, :, :] = torch.cos(y_position * div_term)
+
+ self.register_buffer('pe', pe.unsqueeze(0), persistent=False) # [1, C, H, W]
+
+ def forward(self, x,scaling=None):
+ """
+ Args:
+ x: [N, C, H, W]
+ """
+ if scaling is None: # online scaling, when provided, overrides the precomputed encoding
+ return x + self.pe[:, :, :x.size(2), :x.size(3)],self.pe[:, :, :x.size(2), :x.size(3)]
+ else:
+ pe = torch.zeros((self.d_model, *self.max_shape))
+ y_position = torch.ones(self.max_shape).cumsum(0).float().unsqueeze(0)*scaling[0]
+ x_position = torch.ones(self.max_shape).cumsum(1).float().unsqueeze(0)*scaling[1]
+
+ div_term = torch.exp(torch.arange(0, self.d_model//2, 2).float() * (-math.log(10000.0) / (self.d_model//2)))
+ div_term = div_term[:, None, None] # [C//4, 1, 1]
+ pe[0::4, :, :] = torch.sin(x_position * div_term)
+ pe[1::4, :, :] = torch.cos(x_position * div_term)
+ pe[2::4, :, :] = torch.sin(y_position * div_term)
+ pe[3::4, :, :] = torch.cos(y_position * div_term)
+ pe=pe.unsqueeze(0).to(x.device)
+ return x + pe[:, :, :x.size(2), :x.size(3)],pe[:, :, :x.size(2), :x.size(3)]
\ No newline at end of file
diff --git a/imcui/third_party/ASpanFormer/src/ASpanFormer/utils/supervision.py b/imcui/third_party/ASpanFormer/src/ASpanFormer/utils/supervision.py
new file mode 100644
index 0000000000000000000000000000000000000000..5cef3a7968413136f6dc9f52b6a1ec87192b006b
--- /dev/null
+++ b/imcui/third_party/ASpanFormer/src/ASpanFormer/utils/supervision.py
@@ -0,0 +1,151 @@
+from math import log
+from loguru import logger
+
+import torch
+from einops import repeat
+from kornia.utils import create_meshgrid
+
+from .geometry import warp_kpts
+
+############## ↓ Coarse-Level supervision ↓ ##############
+
+
+@torch.no_grad()
+def mask_pts_at_padded_regions(grid_pt, mask):
+ """For megadepth dataset, zero-padding exists in images"""
+ mask = repeat(mask, 'n h w -> n (h w) c', c=2)
+ grid_pt[~mask.bool()] = 0
+ return grid_pt
+
+
+@torch.no_grad()
+def spvs_coarse(data, config):
+ """
+ Update:
+ data (dict): {
+ "conf_matrix_gt": [N, hw0, hw1],
+ 'spv_b_ids': [M]
+ 'spv_i_ids': [M]
+ 'spv_j_ids': [M]
+ 'spv_w_pt0_i': [N, hw0, 2], in original image resolution
+ 'spv_pt1_i': [N, hw1, 2], in original image resolution
+ }
+
+ NOTE:
+ - for scannet dataset, there are 3 kinds of resolution {i, c, f}
+ - for megadepth dataset, there are 4 kinds of resolution {i, i_resize, c, f}
+ """
+ # 1. misc
+ device = data['image0'].device
+ N, _, H0, W0 = data['image0'].shape
+ _, _, H1, W1 = data['image1'].shape
+ scale = config['ASPAN']['RESOLUTION'][0]
+ scale0 = scale * data['scale0'][:, None] if 'scale0' in data else scale
+ scale1 = scale * data['scale1'][:, None] if 'scale0' in data else scale
+ h0, w0, h1, w1 = map(lambda x: x // scale, [H0, W0, H1, W1])
+
+ # 2. warp grids
+ # create kpts in meshgrid and resize them to image resolution
+ grid_pt0_c = create_meshgrid(h0, w0, False, device).reshape(1, h0*w0, 2).repeat(N, 1, 1) # [N, hw, 2]
+ grid_pt0_i = scale0 * grid_pt0_c
+ grid_pt1_c = create_meshgrid(h1, w1, False, device).reshape(1, h1*w1, 2).repeat(N, 1, 1)
+ grid_pt1_i = scale1 * grid_pt1_c
+
+ # mask padded region to (0, 0), so no need to manually mask conf_matrix_gt
+ if 'mask0' in data:
+ grid_pt0_i = mask_pts_at_padded_regions(grid_pt0_i, data['mask0'])
+ grid_pt1_i = mask_pts_at_padded_regions(grid_pt1_i, data['mask1'])
+
+ # warp kpts bi-directionally and resize them to coarse-level resolution
+ # (no depth consistency check, since it leads to worse results experimentally)
+ # (unhandled edge case: points with 0-depth will be warped to the left-up corner)
+ _, w_pt0_i = warp_kpts(grid_pt0_i, data['depth0'], data['depth1'], data['T_0to1'], data['K0'], data['K1'])
+ _, w_pt1_i = warp_kpts(grid_pt1_i, data['depth1'], data['depth0'], data['T_1to0'], data['K1'], data['K0'])
+ w_pt0_c = w_pt0_i / scale1
+ w_pt1_c = w_pt1_i / scale0
+
+ # 3. check if mutual nearest neighbor
+ w_pt0_c_round = w_pt0_c[:, :, :].round().long()
+ nearest_index1 = w_pt0_c_round[..., 0] + w_pt0_c_round[..., 1] * w1
+ w_pt1_c_round = w_pt1_c[:, :, :].round().long()
+ nearest_index0 = w_pt1_c_round[..., 0] + w_pt1_c_round[..., 1] * w0
+
+ # corner case: out of boundary
+ def out_bound_mask(pt, w, h):
+ return (pt[..., 0] < 0) + (pt[..., 0] >= w) + (pt[..., 1] < 0) + (pt[..., 1] >= h)
+ nearest_index1[out_bound_mask(w_pt0_c_round, w1, h1)] = 0
+ nearest_index0[out_bound_mask(w_pt1_c_round, w0, h0)] = 0
+
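+ # warp each coarse cell of image0 into image1 (nearest_index1) and then back into
+ # image0 (nearest_index0); a cell is a valid gt match only if it returns to its own
+ # index, i.e. the pair is a mutual nearest neighbour. Invalid warps were routed to
+ # index 0 above, which is why the top-left corner is ignored below.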
+ loop_back = torch.stack([nearest_index0[_b][_i] for _b, _i in enumerate(nearest_index1)], dim=0)
+ correct_0to1 = loop_back == torch.arange(h0*w0, device=device)[None].repeat(N, 1)
+ correct_0to1[:, 0] = False # ignore the top-left corner
+
+ # 4. construct a gt conf_matrix
+ conf_matrix_gt = torch.zeros(N, h0*w0, h1*w1, device=device)
+ b_ids, i_ids = torch.where(correct_0to1 != 0)
+ j_ids = nearest_index1[b_ids, i_ids]
+
+ conf_matrix_gt[b_ids, i_ids, j_ids] = 1
+ data.update({'conf_matrix_gt': conf_matrix_gt})
+
+ # 5. save coarse matches(gt) for training fine level
+ if len(b_ids) == 0:
+ logger.warning(f"No groundtruth coarse match found for: {data['pair_names']}")
+ # this won't affect fine-level loss calculation
+ b_ids = torch.tensor([0], device=device)
+ i_ids = torch.tensor([0], device=device)
+ j_ids = torch.tensor([0], device=device)
+
+ data.update({
+ 'spv_b_ids': b_ids,
+ 'spv_i_ids': i_ids,
+ 'spv_j_ids': j_ids
+ })
+
+ # 6. save intermediate results (for fast fine-level computation)
+ data.update({
+ 'spv_w_pt0_i': w_pt0_i,
+ 'spv_pt1_i': grid_pt1_i
+ })
+
+
+def compute_supervision_coarse(data, config):
+ assert len(set(data['dataset_name'])) == 1, "Mixed-dataset training is not supported!"
+ data_source = data['dataset_name'][0]
+ if data_source.lower() in ['scannet', 'megadepth']:
+ spvs_coarse(data, config)
+ else:
+ raise ValueError(f'Unknown data source: {data_source}')
+
+
+############## ↓ Fine-Level supervision ↓ ##############
+
+@torch.no_grad()
+def spvs_fine(data, config):
+ """
+ Update:
+ data (dict):{
+ "expec_f_gt": [M, 2]}
+ """
+ # 1. misc
+ # w_pt0_i, pt1_i = data.pop('spv_w_pt0_i'), data.pop('spv_pt1_i')
+ w_pt0_i, pt1_i = data['spv_w_pt0_i'], data['spv_pt1_i']
+ scale = config['ASPAN']['RESOLUTION'][1]
+ radius = config['ASPAN']['FINE_WINDOW_SIZE'] // 2
+
+ # 2. get coarse prediction
+ b_ids, i_ids, j_ids = data['b_ids'], data['i_ids'], data['j_ids']
+
+ # 3. compute gt
+ scale = scale * data['scale1'][b_ids] if 'scale0' in data else scale
+ # `expec_f_gt` might exceed the window, i.e. abs(*) > 1, which would be filtered later
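+ # dividing by (scale * radius) expresses the offset in half-window units at the fine
+ # resolution, so gt offsets that fall inside the fine window satisfy abs(expec_f_gt) <= 1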
+ expec_f_gt = (w_pt0_i[b_ids, i_ids] - pt1_i[b_ids, j_ids]) / scale / radius # [M, 2]
+ data.update({"expec_f_gt": expec_f_gt})
+
+
+def compute_supervision_fine(data, config):
+ data_source = data['dataset_name'][0]
+ if data_source.lower() in ['scannet', 'megadepth']:
+ spvs_fine(data, config)
+ else:
+ raise NotImplementedError
diff --git a/imcui/third_party/ASpanFormer/src/__init__.py b/imcui/third_party/ASpanFormer/src/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/imcui/third_party/ASpanFormer/src/config/default.py b/imcui/third_party/ASpanFormer/src/config/default.py
new file mode 100644
index 0000000000000000000000000000000000000000..40abd51c3f28ea6dee3c4e9fcee6efac5c080a2f
--- /dev/null
+++ b/imcui/third_party/ASpanFormer/src/config/default.py
@@ -0,0 +1,180 @@
+from yacs.config import CfgNode as CN
+_CN = CN()
+
+############## ↓ ASPAN Pipeline ↓ ##############
+_CN.ASPAN = CN()
+_CN.ASPAN.BACKBONE_TYPE = 'ResNetFPN'
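+# RESOLUTION = (coarse stride, fine stride) w.r.t. the input image, i.e. coarse features
+# at 1/8 and fine features at 1/2 of the input resolution for the default (8, 2)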
+_CN.ASPAN.RESOLUTION = (8, 2) # options: [(8, 2), (16, 4)]
+_CN.ASPAN.FINE_WINDOW_SIZE = 5 # window_size in fine_level, must be odd
+_CN.ASPAN.FINE_CONCAT_COARSE_FEAT = True
+
+# 1. ASPAN-backbone (local feature CNN) config
+_CN.ASPAN.RESNETFPN = CN()
+_CN.ASPAN.RESNETFPN.INITIAL_DIM = 128
+_CN.ASPAN.RESNETFPN.BLOCK_DIMS = [128, 196, 256] # s1, s2, s3
+
+# 2. ASPAN-coarse module config
+_CN.ASPAN.COARSE = CN()
+_CN.ASPAN.COARSE.D_MODEL = 256
+_CN.ASPAN.COARSE.D_FFN = 256
+_CN.ASPAN.COARSE.D_FLOW= 128
+_CN.ASPAN.COARSE.NHEAD = 8
+_CN.ASPAN.COARSE.NLEVEL= 3
+_CN.ASPAN.COARSE.INI_LAYER_NUM = 2
+_CN.ASPAN.COARSE.LAYER_NUM = 4
+_CN.ASPAN.COARSE.NSAMPLE = [2,8]
+_CN.ASPAN.COARSE.RADIUS_SCALE= 5
+_CN.ASPAN.COARSE.COARSEST_LEVEL= [26,26]
+_CN.ASPAN.COARSE.TRAIN_RES = None
+_CN.ASPAN.COARSE.TEST_RES = None
+
+# 3. Coarse-Matching config
+_CN.ASPAN.MATCH_COARSE = CN()
+_CN.ASPAN.MATCH_COARSE.THR = 0.2
+_CN.ASPAN.MATCH_COARSE.BORDER_RM = 2
+_CN.ASPAN.MATCH_COARSE.MATCH_TYPE = 'dual_softmax' # options: ['dual_softmax', 'sinkhorn']
+_CN.ASPAN.MATCH_COARSE.SKH_ITERS = 3
+_CN.ASPAN.MATCH_COARSE.SKH_INIT_BIN_SCORE = 1.0
+_CN.ASPAN.MATCH_COARSE.SKH_PREFILTER = False
+_CN.ASPAN.MATCH_COARSE.TRAIN_COARSE_PERCENT = 0.2 # training tricks: save GPU memory
+_CN.ASPAN.MATCH_COARSE.TRAIN_PAD_NUM_GT_MIN = 200 # training tricks: avoid DDP deadlock
+_CN.ASPAN.MATCH_COARSE.SPARSE_SPVS = True
+_CN.ASPAN.MATCH_COARSE.LEARNABLE_DS_TEMP = True
+
+# 4. ASPAN-fine module config
+_CN.ASPAN.FINE = CN()
+_CN.ASPAN.FINE.D_MODEL = 128
+_CN.ASPAN.FINE.D_FFN = 128
+_CN.ASPAN.FINE.NHEAD = 8
+_CN.ASPAN.FINE.LAYER_NAMES = ['self', 'cross'] * 1
+_CN.ASPAN.FINE.ATTENTION = 'linear'
+
+# 5. ASPAN Losses
+# -- # coarse-level
+_CN.ASPAN.LOSS = CN()
+_CN.ASPAN.LOSS.COARSE_TYPE = 'focal' # ['focal', 'cross_entropy']
+_CN.ASPAN.LOSS.COARSE_WEIGHT = 1.0
+# _CN.ASPAN.LOSS.SPARSE_SPVS = False
+# -- - -- # focal loss (coarse)
+_CN.ASPAN.LOSS.FOCAL_ALPHA = 0.25
+_CN.ASPAN.LOSS.FOCAL_GAMMA = 2.0
+_CN.ASPAN.LOSS.POS_WEIGHT = 1.0
+_CN.ASPAN.LOSS.NEG_WEIGHT = 1.0
+# _CN.ASPAN.LOSS.DUAL_SOFTMAX = False # whether coarse-level use dual-softmax or not.
+# use `_CN.ASPAN.MATCH_COARSE.MATCH_TYPE`
+
+# -- # fine-level
+_CN.ASPAN.LOSS.FINE_TYPE = 'l2_with_std' # ['l2_with_std', 'l2']
+_CN.ASPAN.LOSS.FINE_WEIGHT = 1.0
+_CN.ASPAN.LOSS.FINE_CORRECT_THR = 1.0 # for filtering valid fine-level gts (some gt matches might fall out of the fine-level window)
+
+# -- # flow loss
+_CN.ASPAN.LOSS.FLOW_WEIGHT = 0.1
+
+
+############## Dataset ##############
+_CN.DATASET = CN()
+# 1. data config
+# training and validating
+_CN.DATASET.TRAINVAL_DATA_SOURCE = None # options: ['ScanNet', 'MegaDepth']
+_CN.DATASET.TRAIN_DATA_ROOT = None
+_CN.DATASET.TRAIN_POSE_ROOT = None # (optional directory for poses)
+_CN.DATASET.TRAIN_NPZ_ROOT = None
+_CN.DATASET.TRAIN_LIST_PATH = None
+_CN.DATASET.TRAIN_INTRINSIC_PATH = None
+_CN.DATASET.VAL_DATA_ROOT = None
+_CN.DATASET.VAL_POSE_ROOT = None # (optional directory for poses)
+_CN.DATASET.VAL_NPZ_ROOT = None
+_CN.DATASET.VAL_LIST_PATH = None # None if val data from all scenes are bundled into a single npz file
+_CN.DATASET.VAL_INTRINSIC_PATH = None
+# testing
+_CN.DATASET.TEST_DATA_SOURCE = None
+_CN.DATASET.TEST_DATA_ROOT = None
+_CN.DATASET.TEST_POSE_ROOT = None # (optional directory for poses)
+_CN.DATASET.TEST_NPZ_ROOT = None
+_CN.DATASET.TEST_LIST_PATH = None # None if test data from all scenes are bundled into a single npz file
+_CN.DATASET.TEST_INTRINSIC_PATH = None
+
+# 2. dataset config
+# general options
+_CN.DATASET.MIN_OVERLAP_SCORE_TRAIN = 0.4 # discard data with overlap_score < min_overlap_score
+_CN.DATASET.MIN_OVERLAP_SCORE_TEST = 0.0
+_CN.DATASET.AUGMENTATION_TYPE = None # options: [None, 'dark', 'mobile']
+
+# MegaDepth options
+_CN.DATASET.MGDPT_IMG_RESIZE = 640 # resize the longer side, zero-pad bottom-right to square.
+_CN.DATASET.MGDPT_IMG_PAD = True # pad img to square with size = MGDPT_IMG_RESIZE
+_CN.DATASET.MGDPT_DEPTH_PAD = True # pad depthmap to square with size = 2000
+_CN.DATASET.MGDPT_DF = 8
+
+############## Trainer ##############
+_CN.TRAINER = CN()
+_CN.TRAINER.WORLD_SIZE = 1
+_CN.TRAINER.CANONICAL_BS = 64
+_CN.TRAINER.CANONICAL_LR = 6e-3
+_CN.TRAINER.SCALING = None # this will be calculated automatically
+_CN.TRAINER.FIND_LR = False # use learning rate finder from pytorch-lightning
+
+# optimizer
+_CN.TRAINER.OPTIMIZER = "adamw" # [adam, adamw]
+_CN.TRAINER.TRUE_LR = None # this will be calculated automatically at runtime
+_CN.TRAINER.ADAM_DECAY = 0. # ADAM: for adam
+_CN.TRAINER.ADAMW_DECAY = 0.1
+
+# step-based warm-up
+_CN.TRAINER.WARMUP_TYPE = 'linear' # [linear, constant]
+_CN.TRAINER.WARMUP_RATIO = 0.
+_CN.TRAINER.WARMUP_STEP = 4800
+
+# learning rate scheduler
+_CN.TRAINER.SCHEDULER = 'MultiStepLR' # [MultiStepLR, CosineAnnealing, ExponentialLR]
+_CN.TRAINER.SCHEDULER_INTERVAL = 'epoch' # [epoch, step]
+_CN.TRAINER.MSLR_MILESTONES = [3, 6, 9, 12] # MSLR: MultiStepLR
+_CN.TRAINER.MSLR_GAMMA = 0.5
+_CN.TRAINER.COSA_TMAX = 30 # COSA: CosineAnnealing
+_CN.TRAINER.ELR_GAMMA = 0.999992 # ELR: ExponentialLR, this value for 'step' interval
+
+# plotting related
+_CN.TRAINER.ENABLE_PLOTTING = True
+_CN.TRAINER.N_VAL_PAIRS_TO_PLOT = 32 # number of val/test pairs for plotting
+_CN.TRAINER.PLOT_MODE = 'evaluation' # ['evaluation', 'confidence']
+_CN.TRAINER.PLOT_MATCHES_ALPHA = 'dynamic'
+
+# geometric metrics and pose solver
+_CN.TRAINER.EPI_ERR_THR = 5e-4 # recommendation: 5e-4 for ScanNet, 1e-4 for MegaDepth (from SuperGlue)
+_CN.TRAINER.POSE_GEO_MODEL = 'E' # ['E', 'F', 'H']
+_CN.TRAINER.POSE_ESTIMATION_METHOD = 'RANSAC' # [RANSAC, DEGENSAC, MAGSAC]
+_CN.TRAINER.RANSAC_PIXEL_THR = 0.5
+_CN.TRAINER.RANSAC_CONF = 0.99999
+_CN.TRAINER.RANSAC_MAX_ITERS = 10000
+_CN.TRAINER.USE_MAGSACPP = False
+
+# data sampler for train_dataloader
+_CN.TRAINER.DATA_SAMPLER = 'scene_balance' # options: ['scene_balance', 'random', 'normal']
+# 'scene_balance' config
+_CN.TRAINER.N_SAMPLES_PER_SUBSET = 200
+_CN.TRAINER.SB_SUBSET_SAMPLE_REPLACEMENT = True # whether to sample each scene with replacement or not
+_CN.TRAINER.SB_SUBSET_SHUFFLE = True # after sampling from scenes, whether to shuffle within the epoch or not
+_CN.TRAINER.SB_REPEAT = 1 # repeat N times for training the sampled data
+# 'random' config
+_CN.TRAINER.RDM_REPLACEMENT = True
+_CN.TRAINER.RDM_NUM_SAMPLES = None
+
+# gradient clipping
+_CN.TRAINER.GRADIENT_CLIPPING = 0.5
+
+# reproducibility
+# This seed affects the data sampling. With the same seed, the data sampling is guaranteed
+# to be the same. When resuming training from a checkpoint, it is better to use a different
+# seed; otherwise the sampled data will be exactly the same as before the resume, which
+# reduces the number of unique data items seen over the entire training run.
+# Different seed values might affect the final training result, since not all data items
+# are used during training on ScanNet (60M image pairs are sampled during training out of 230M pairs in total).
+_CN.TRAINER.SEED = 66
+
+
+def get_cfg_defaults():
+ """Get a yacs CfgNode object with default values for my_project."""
+ # Return a clone so that the defaults will not be altered
+ # This is for the "local variable" use pattern
+ return _CN.clone()
diff --git a/imcui/third_party/ASpanFormer/src/datasets/__init__.py b/imcui/third_party/ASpanFormer/src/datasets/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..1860e3ae060a26e4625925861cecdc355f2b08b7
--- /dev/null
+++ b/imcui/third_party/ASpanFormer/src/datasets/__init__.py
@@ -0,0 +1,3 @@
+from .scannet import ScanNetDataset
+from .megadepth import MegaDepthDataset
+
diff --git a/imcui/third_party/ASpanFormer/src/datasets/megadepth.py b/imcui/third_party/ASpanFormer/src/datasets/megadepth.py
new file mode 100644
index 0000000000000000000000000000000000000000..a70ac715a3f807e37bc5b87ae9446ddd2aa4fc86
--- /dev/null
+++ b/imcui/third_party/ASpanFormer/src/datasets/megadepth.py
@@ -0,0 +1,127 @@
+import os.path as osp
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torch.utils.data import Dataset
+from loguru import logger
+
+from src.utils.dataset import read_megadepth_gray, read_megadepth_depth
+
+
+class MegaDepthDataset(Dataset):
+ def __init__(self,
+ root_dir,
+ npz_path,
+ mode='train',
+ min_overlap_score=0.4,
+ img_resize=None,
+ df=None,
+ img_padding=False,
+ depth_padding=False,
+ augment_fn=None,
+ **kwargs):
+ """
+ Manage one scene (npz_path) of the MegaDepth dataset.
+
+ Args:
+ root_dir (str): megadepth root directory that has `phoenix`.
+ npz_path (str): {scene_id}.npz path. This contains image pair information of a scene.
+ mode (str): options are ['train', 'val', 'test']
+ min_overlap_score (float): how much a pair should have in common. In range of [0, 1]. Set to 0 when testing.
+ img_resize (int, optional): the longer edge of resized images. None for no resize. 640 is recommended.
+ This is useful during training with batches and testing with memory intensive algorithms.
+ df (int, optional): image size division factor. NOTE: this will change the final image size after img_resize.
+ img_padding (bool): If set to 'True', zero-pad the image to squared size. This is useful during training.
+ depth_padding (bool): If set to 'True', zero-pad depthmap to (2000, 2000). This is useful during training.
+ augment_fn (callable, optional): augments images with pre-defined visual effects.
+ """
+ super().__init__()
+ self.root_dir = root_dir
+ self.mode = mode
+ self.scene_id = npz_path.split('.')[0]
+
+ # prepare scene_info and pair_info
+ if mode == 'test' and min_overlap_score != 0:
+ logger.warning("You are using `min_overlap_score`!=0 in test mode. Set to 0.")
+ min_overlap_score = 0
+ self.scene_info = np.load(npz_path, allow_pickle=True)
+ self.pair_infos = self.scene_info['pair_infos'].copy()
+ del self.scene_info['pair_infos']
+ self.pair_infos = [pair_info for pair_info in self.pair_infos if pair_info[1] > min_overlap_score]
+
+ # parameters for image resizing, padding and depthmap padding
+ if mode == 'train':
+ assert img_resize is not None and img_padding and depth_padding
+ self.img_resize = img_resize
+ self.df = df
+ self.img_padding = img_padding
+ self.depth_max_size = 2000 if depth_padding else None # the upper bound of depthmap size in MegaDepth.
+
+ # for training LoFTR
+ self.augment_fn = augment_fn if mode == 'train' else None
+ self.coarse_scale = kwargs.get('coarse_scale', 0.125)
+
+ def __len__(self):
+ return len(self.pair_infos)
+
+ def __getitem__(self, idx):
+ (idx0, idx1), overlap_score, central_matches = self.pair_infos[idx]
+
+ # read grayscale image and mask. (1, h, w) and (h, w)
+ img_name0 = osp.join(self.root_dir, self.scene_info['image_paths'][idx0])
+ img_name1 = osp.join(self.root_dir, self.scene_info['image_paths'][idx1])
+
+ # TODO: Support augmentation & handle seeds for each worker correctly.
+ image0, mask0, scale0 = read_megadepth_gray(
+ img_name0, self.img_resize, self.df, self.img_padding, None)
+ # np.random.choice([self.augment_fn, None], p=[0.5, 0.5]))
+ image1, mask1, scale1 = read_megadepth_gray(
+ img_name1, self.img_resize, self.df, self.img_padding, None)
+ # np.random.choice([self.augment_fn, None], p=[0.5, 0.5]))
+
+ # read depth. shape: (h, w)
+ if self.mode in ['train', 'val']:
+ depth0 = read_megadepth_depth(
+ osp.join(self.root_dir, self.scene_info['depth_paths'][idx0]), pad_to=self.depth_max_size)
+ depth1 = read_megadepth_depth(
+ osp.join(self.root_dir, self.scene_info['depth_paths'][idx1]), pad_to=self.depth_max_size)
+ else:
+ depth0 = depth1 = torch.tensor([])
+
+ # read intrinsics of original size
+ K_0 = torch.tensor(self.scene_info['intrinsics'][idx0].copy(), dtype=torch.float).reshape(3, 3)
+ K_1 = torch.tensor(self.scene_info['intrinsics'][idx1].copy(), dtype=torch.float).reshape(3, 3)
+
+ # read and compute relative poses
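+ # with world-to-camera pose matrices T, points in camera-0 coordinates map to
+ # camera-1 coordinates via T_0to1 = T1 @ T0^{-1}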
+ T0 = self.scene_info['poses'][idx0]
+ T1 = self.scene_info['poses'][idx1]
+ T_0to1 = torch.tensor(np.matmul(T1, np.linalg.inv(T0)), dtype=torch.float)[:4, :4] # (4, 4)
+ T_1to0 = T_0to1.inverse()
+
+ data = {
+ 'image0': image0, # (1, h, w)
+ 'depth0': depth0, # (h, w)
+ 'image1': image1,
+ 'depth1': depth1,
+ 'T_0to1': T_0to1, # (4, 4)
+ 'T_1to0': T_1to0,
+ 'K0': K_0, # (3, 3)
+ 'K1': K_1,
+ 'scale0': scale0, # [scale_w, scale_h]
+ 'scale1': scale1,
+ 'dataset_name': 'MegaDepth',
+ 'scene_id': self.scene_id,
+ 'pair_id': idx,
+ 'pair_names': (self.scene_info['image_paths'][idx0], self.scene_info['image_paths'][idx1]),
+ }
+
+ # for LoFTR training
+ if mask0 is not None: # img_padding is True
+ if self.coarse_scale:
+ [ts_mask_0, ts_mask_1] = F.interpolate(torch.stack([mask0, mask1], dim=0)[None].float(),
+ scale_factor=self.coarse_scale,
+ mode='nearest',
+ recompute_scale_factor=False)[0].bool()
+ data.update({'mask0': ts_mask_0, 'mask1': ts_mask_1})
+
+ return data
diff --git a/imcui/third_party/ASpanFormer/src/datasets/sampler.py b/imcui/third_party/ASpanFormer/src/datasets/sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..81b6f435645632a013476f9a665a0861ab7fcb61
--- /dev/null
+++ b/imcui/third_party/ASpanFormer/src/datasets/sampler.py
@@ -0,0 +1,77 @@
+import torch
+from torch.utils.data import Sampler, ConcatDataset
+
+
+class RandomConcatSampler(Sampler):
+ """ Random sampler for ConcatDataset. At each epoch, `n_samples_per_subset` samples will be draw from each subset
+ in the ConcatDataset. If `subset_replacement` is ``True``, sampling within each subset will be done with replacement.
+ However, it is impossible to sample data without replacement between epochs, unless bulding a stateful sampler lived along the entire training phase.
+
+ For current implementation, the randomness of sampling is ensured no matter the sampler is recreated across epochs or not and call `torch.manual_seed()` or not.
+ Args:
+ shuffle (bool): shuffle the random sampled indices across all sub-datsets.
+ repeat (int): repeatedly use the sampled indices multiple times for training.
+ [arXiv:1902.05509, arXiv:1901.09335]
+ NOTE: Don't re-initialize the sampler between epochs (will lead to repeated samples)
+ NOTE: This sampler behaves differently with DistributedSampler.
+ It assume the dataset is splitted across ranks instead of replicated.
+ TODO: Add a `set_epoch()` method to fullfill sampling without replacement across epochs.
+ ref: https://github.com/PyTorchLightning/pytorch-lightning/blob/e9846dd758cfb1500eb9dba2d86f6912eb487587/pytorch_lightning/trainer/training_loop.py#L373
+ """
+ def __init__(self,
+ data_source: ConcatDataset,
+ n_samples_per_subset: int,
+ subset_replacement: bool=True,
+ shuffle: bool=True,
+ repeat: int=1,
+ seed: int=None):
+ if not isinstance(data_source, ConcatDataset):
+ raise TypeError("data_source should be torch.utils.data.ConcatDataset")
+
+ self.data_source = data_source
+ self.n_subset = len(self.data_source.datasets)
+ self.n_samples_per_subset = n_samples_per_subset
+ self.n_samples = self.n_subset * self.n_samples_per_subset * repeat
+ self.subset_replacement = subset_replacement
+ self.repeat = repeat
+ self.shuffle = shuffle
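+ # torch.manual_seed returns the seeded default (global) generator, so sampling below
+ # shares the process-wide RNG stream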
+ self.generator = torch.manual_seed(seed)
+ assert self.repeat >= 1
+
+ def __len__(self):
+ return self.n_samples
+
+ def __iter__(self):
+ indices = []
+ # sample from each sub-dataset
+ for d_idx in range(self.n_subset):
+ low = 0 if d_idx==0 else self.data_source.cumulative_sizes[d_idx-1]
+ high = self.data_source.cumulative_sizes[d_idx]
+ if self.subset_replacement:
+ rand_tensor = torch.randint(low, high, (self.n_samples_per_subset, ),
+ generator=self.generator, dtype=torch.int64)
+ else: # sample without replacement
+ len_subset = len(self.data_source.datasets[d_idx])
+ rand_tensor = torch.randperm(len_subset, generator=self.generator) + low
+ if len_subset >= self.n_samples_per_subset:
+ rand_tensor = rand_tensor[:self.n_samples_per_subset]
+ else: # padding with replacement
+ rand_tensor_replacement = torch.randint(low, high, (self.n_samples_per_subset - len_subset, ),
+ generator=self.generator, dtype=torch.int64)
+ rand_tensor = torch.cat([rand_tensor, rand_tensor_replacement])
+ indices.append(rand_tensor)
+ indices = torch.cat(indices)
+ if self.shuffle: # shuffle the sampled dataset (from multiple subsets)
+ rand_tensor = torch.randperm(len(indices), generator=self.generator)
+ indices = indices[rand_tensor]
+
+ # repeat the sampled indices (can be used for RepeatAugmentation or pure RepeatSampling)
+ if self.repeat > 1:
+ repeat_indices = [indices.clone() for _ in range(self.repeat - 1)]
+ if self.shuffle:
+ _choice = lambda x: x[torch.randperm(len(x), generator=self.generator)]
+ repeat_indices = map(_choice, repeat_indices)
+ indices = torch.cat([indices, *repeat_indices], 0)
+
+ assert indices.shape[0] == self.n_samples
+ return iter(indices.tolist())
diff --git a/imcui/third_party/ASpanFormer/src/datasets/scannet.py b/imcui/third_party/ASpanFormer/src/datasets/scannet.py
new file mode 100644
index 0000000000000000000000000000000000000000..3520d34c0f08a784ddbf923846a7cb2a847b1787
--- /dev/null
+++ b/imcui/third_party/ASpanFormer/src/datasets/scannet.py
@@ -0,0 +1,113 @@
+from os import path as osp
+from typing import Dict
+from unicodedata import name
+
+import numpy as np
+import torch
+import torch.utils as utils
+from numpy.linalg import inv
+from src.utils.dataset import (
+ read_scannet_gray,
+ read_scannet_depth,
+ read_scannet_pose,
+ read_scannet_intrinsic
+)
+
+
+class ScanNetDataset(utils.data.Dataset):
+ def __init__(self,
+ root_dir,
+ npz_path,
+ intrinsic_path,
+ mode='train',
+ min_overlap_score=0.4,
+ augment_fn=None,
+ pose_dir=None,
+ **kwargs):
+ """Manage one scene of ScanNet Dataset.
+ Args:
+ root_dir (str): ScanNet root directory that contains scene folders.
+ npz_path (str): {scene_id}.npz path. This contains image pair information of a scene.
+ intrinsic_path (str): path to depth-camera intrinsic file.
+ mode (str): options are ['train', 'val', 'test'].
+ augment_fn (callable, optional): augments images with pre-defined visual effects.
+ pose_dir (str): ScanNet root directory that contains all poses.
+ (we use a separate (optional) pose_dir since we store images and poses separately.)
+ """
+ super().__init__()
+ self.root_dir = root_dir
+ self.pose_dir = pose_dir if pose_dir is not None else root_dir
+ self.mode = mode
+
+ # prepare data_names, intrinsics and extrinsics(T)
+ with np.load(npz_path) as data:
+ self.data_names = data['name']
+ if 'score' in data.keys() and mode not in ['val', 'test']:
+ kept_mask = data['score'] > min_overlap_score
+ self.data_names = self.data_names[kept_mask]
+ self.intrinsics = dict(np.load(intrinsic_path))
+
+ # for training LoFTR
+ self.augment_fn = augment_fn if mode == 'train' else None
+
+ def __len__(self):
+ return len(self.data_names)
+
+ def _read_abs_pose(self, scene_name, name):
+ pth = osp.join(self.pose_dir,
+ scene_name,
+ 'pose', f'{name}.txt')
+ return read_scannet_pose(pth)
+
+ def _compute_rel_pose(self, scene_name, name0, name1):
+ pose0 = self._read_abs_pose(scene_name, name0)
+ pose1 = self._read_abs_pose(scene_name, name1)
+
+ return np.matmul(pose1, inv(pose0)) # (4, 4)
+
+ def __getitem__(self, idx):
+ data_name = self.data_names[idx]
+ scene_name, scene_sub_name, stem_name_0, stem_name_1 = data_name
+ scene_name = f'scene{scene_name:04d}_{scene_sub_name:02d}'
+
+ # read the grayscale image which will be resized to (1, 480, 640)
+ img_name0 = osp.join(self.root_dir, scene_name, 'color', f'{stem_name_0}.jpg')
+ img_name1 = osp.join(self.root_dir, scene_name, 'color', f'{stem_name_1}.jpg')
+ # TODO: Support augmentation & handle seeds for each worker correctly.
+ image0 = read_scannet_gray(img_name0, resize=(640, 480), augment_fn=None)
+ # augment_fn=np.random.choice([self.augment_fn, None], p=[0.5, 0.5]))
+ image1 = read_scannet_gray(img_name1, resize=(640, 480), augment_fn=None)
+ # augment_fn=np.random.choice([self.augment_fn, None], p=[0.5, 0.5]))
+
+ # read the depthmap which is stored as (480, 640)
+ if self.mode in ['train', 'val']:
+ depth0 = read_scannet_depth(osp.join(self.root_dir, scene_name, 'depth', f'{stem_name_0}.png'))
+ depth1 = read_scannet_depth(osp.join(self.root_dir, scene_name, 'depth', f'{stem_name_1}.png'))
+ else:
+ depth0 = depth1 = torch.tensor([])
+
+ # read the intrinsic of depthmap
+ K_0 = K_1 = torch.tensor(self.intrinsics[scene_name].copy(), dtype=torch.float).reshape(3, 3)
+
+ # read and compute relative poses
+ T_0to1 = torch.tensor(self._compute_rel_pose(scene_name, stem_name_0, stem_name_1),
+ dtype=torch.float32)
+ T_1to0 = T_0to1.inverse()
+
+ data = {
+ 'image0': image0, # (1, h, w)
+ 'depth0': depth0, # (h, w)
+ 'image1': image1,
+ 'depth1': depth1,
+ 'T_0to1': T_0to1, # (4, 4)
+ 'T_1to0': T_1to0,
+ 'K0': K_0, # (3, 3)
+ 'K1': K_1,
+ 'dataset_name': 'ScanNet',
+ 'scene_id': scene_name,
+ 'pair_id': idx,
+ 'pair_names': (osp.join(scene_name, 'color', f'{stem_name_0}.jpg'),
+ osp.join(scene_name, 'color', f'{stem_name_1}.jpg'))
+ }
+
+ return data
diff --git a/imcui/third_party/ASpanFormer/src/lightning/data.py b/imcui/third_party/ASpanFormer/src/lightning/data.py
new file mode 100644
index 0000000000000000000000000000000000000000..73db514b8924d647814e6c5def919c23393d3ccf
--- /dev/null
+++ b/imcui/third_party/ASpanFormer/src/lightning/data.py
@@ -0,0 +1,326 @@
+import os
+import math
+from collections import abc
+from loguru import logger
+from torch.utils.data.dataset import Dataset
+from tqdm import tqdm
+from os import path as osp
+from pathlib import Path
+from joblib import Parallel, delayed
+
+import pytorch_lightning as pl
+from torch import distributed as dist
+from torch.utils.data import (
+ Dataset,
+ DataLoader,
+ ConcatDataset,
+ DistributedSampler,
+ RandomSampler,
+ dataloader
+)
+
+from src.utils.augment import build_augmentor
+from src.utils.dataloader import get_local_split
+from src.utils.misc import tqdm_joblib
+from src.utils import comm
+from src.datasets.megadepth import MegaDepthDataset
+from src.datasets.scannet import ScanNetDataset
+from src.datasets.sampler import RandomConcatSampler
+
+
+class MultiSceneDataModule(pl.LightningDataModule):
+ """
+ For distributed training, each training process is assigned
+ only a part of the training scenes to reduce memory overhead.
+ """
+ def __init__(self, args, config):
+ super().__init__()
+
+ # 1. data config
+ # Train and Val should come from the same data source
+ self.trainval_data_source = config.DATASET.TRAINVAL_DATA_SOURCE
+ self.test_data_source = config.DATASET.TEST_DATA_SOURCE
+ # training and validating
+ self.train_data_root = config.DATASET.TRAIN_DATA_ROOT
+ self.train_pose_root = config.DATASET.TRAIN_POSE_ROOT # (optional)
+ self.train_npz_root = config.DATASET.TRAIN_NPZ_ROOT
+ self.train_list_path = config.DATASET.TRAIN_LIST_PATH
+ self.train_intrinsic_path = config.DATASET.TRAIN_INTRINSIC_PATH
+ self.val_data_root = config.DATASET.VAL_DATA_ROOT
+ self.val_pose_root = config.DATASET.VAL_POSE_ROOT # (optional)
+ self.val_npz_root = config.DATASET.VAL_NPZ_ROOT
+ self.val_list_path = config.DATASET.VAL_LIST_PATH
+ self.val_intrinsic_path = config.DATASET.VAL_INTRINSIC_PATH
+ # testing
+ self.test_data_root = config.DATASET.TEST_DATA_ROOT
+ self.test_pose_root = config.DATASET.TEST_POSE_ROOT # (optional)
+ self.test_npz_root = config.DATASET.TEST_NPZ_ROOT
+ self.test_list_path = config.DATASET.TEST_LIST_PATH
+ self.test_intrinsic_path = config.DATASET.TEST_INTRINSIC_PATH
+
+ # 2. dataset config
+ # general options
+ self.min_overlap_score_test = config.DATASET.MIN_OVERLAP_SCORE_TEST # 0.4, omit data with overlap_score < min_overlap_score
+ self.min_overlap_score_train = config.DATASET.MIN_OVERLAP_SCORE_TRAIN
+ self.augment_fn = build_augmentor(config.DATASET.AUGMENTATION_TYPE) # None, options: [None, 'dark', 'mobile']
+
+ # MegaDepth options
+ self.mgdpt_img_resize = config.DATASET.MGDPT_IMG_RESIZE # 840
+ self.mgdpt_img_pad = config.DATASET.MGDPT_IMG_PAD # True
+ self.mgdpt_depth_pad = config.DATASET.MGDPT_DEPTH_PAD # True
+ self.mgdpt_df = config.DATASET.MGDPT_DF # 8
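+ # masks produced by the datasets are downsampled by this factor so that they align
+ # with the coarse feature map (stride = ASPAN.RESOLUTION[0])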
+ self.coarse_scale = 1 / config.ASPAN.RESOLUTION[0] # 0.125. for training loftr.
+
+ # 3.loader parameters
+ self.train_loader_params = {
+ 'batch_size': args.batch_size,
+ 'num_workers': args.num_workers,
+ 'pin_memory': getattr(args, 'pin_memory', True)
+ }
+ self.val_loader_params = {
+ 'batch_size': 1,
+ 'shuffle': False,
+ 'num_workers': args.num_workers,
+ 'pin_memory': getattr(args, 'pin_memory', True)
+ }
+ self.test_loader_params = {
+ 'batch_size': 1,
+ 'shuffle': False,
+ 'num_workers': args.num_workers,
+ 'pin_memory': True
+ }
+
+ # 4. sampler
+ self.data_sampler = config.TRAINER.DATA_SAMPLER
+ self.n_samples_per_subset = config.TRAINER.N_SAMPLES_PER_SUBSET
+ self.subset_replacement = config.TRAINER.SB_SUBSET_SAMPLE_REPLACEMENT
+ self.shuffle = config.TRAINER.SB_SUBSET_SHUFFLE
+ self.repeat = config.TRAINER.SB_REPEAT
+
+ # (optional) RandomSampler for debugging
+
+ # misc configurations
+ self.parallel_load_data = getattr(args, 'parallel_load_data', False)
+ self.seed = config.TRAINER.SEED # 66
+
+ def setup(self, stage=None):
+ """
+ Setup train / val / test dataset. This method will be called by PL automatically.
+ Args:
+ stage (str): 'fit' in training phase, and 'test' in testing phase.
+ """
+
+ assert stage in ['fit', 'test'], "stage must be either fit or test"
+
+ try:
+ self.world_size = dist.get_world_size()
+ self.rank = dist.get_rank()
+ logger.info(f"[rank:{self.rank}] world_size: {self.world_size}")
+ except AssertionError as ae:
+ self.world_size = 1
+ self.rank = 0
+ logger.warning(str(ae) + " (set world_size=1 and rank=0)")
+
+ if stage == 'fit':
+ self.train_dataset = self._setup_dataset(
+ self.train_data_root,
+ self.train_npz_root,
+ self.train_list_path,
+ self.train_intrinsic_path,
+ mode='train',
+ min_overlap_score=self.min_overlap_score_train,
+ pose_dir=self.train_pose_root)
+ # setup multiple (optional) validation subsets
+ if isinstance(self.val_list_path, (list, tuple)):
+ self.val_dataset = []
+ if not isinstance(self.val_npz_root, (list, tuple)):
+ self.val_npz_root = [self.val_npz_root for _ in range(len(self.val_list_path))]
+ for npz_list, npz_root in zip(self.val_list_path, self.val_npz_root):
+ self.val_dataset.append(self._setup_dataset(
+ self.val_data_root,
+ npz_root,
+ npz_list,
+ self.val_intrinsic_path,
+ mode='val',
+ min_overlap_score=self.min_overlap_score_test,
+ pose_dir=self.val_pose_root))
+ else:
+ self.val_dataset = self._setup_dataset(
+ self.val_data_root,
+ self.val_npz_root,
+ self.val_list_path,
+ self.val_intrinsic_path,
+ mode='val',
+ min_overlap_score=self.min_overlap_score_test,
+ pose_dir=self.val_pose_root)
+ logger.info(f'[rank:{self.rank}] Train & Val Dataset loaded!')
+ else: # stage == 'test'
+ self.test_dataset = self._setup_dataset(
+ self.test_data_root,
+ self.test_npz_root,
+ self.test_list_path,
+ self.test_intrinsic_path,
+ mode='test',
+ min_overlap_score=self.min_overlap_score_test,
+ pose_dir=self.test_pose_root)
+ logger.info(f'[rank:{self.rank}]: Test Dataset loaded!')
+
+ def _setup_dataset(self,
+ data_root,
+ split_npz_root,
+ scene_list_path,
+ intri_path,
+ mode='train',
+ min_overlap_score=0.,
+ pose_dir=None):
+ """ Setup train / val / test set"""
+ with open(scene_list_path, 'r') as f:
+ npz_names = [name.split()[0] for name in f.readlines()]
+
+ if mode == 'train':
+ local_npz_names = get_local_split(npz_names, self.world_size, self.rank, self.seed)
+ else:
+ local_npz_names = npz_names
+ logger.info(f'[rank {self.rank}]: {len(local_npz_names)} scene(s) assigned.')
+
+ dataset_builder = self._build_concat_dataset_parallel \
+ if self.parallel_load_data \
+ else self._build_concat_dataset
+ return dataset_builder(data_root, local_npz_names, split_npz_root, intri_path,
+ mode=mode, min_overlap_score=min_overlap_score, pose_dir=pose_dir)
+
+ def _build_concat_dataset(
+ self,
+ data_root,
+ npz_names,
+ npz_dir,
+ intrinsic_path,
+ mode,
+ min_overlap_score=0.,
+ pose_dir=None
+ ):
+ datasets = []
+ augment_fn = self.augment_fn if mode == 'train' else None
+ data_source = self.trainval_data_source if mode in ['train', 'val'] else self.test_data_source
+ if data_source=='GL3D' and mode=='val':
+ data_source='MegaDepth'
+ if str(data_source).lower() == 'megadepth':
+ npz_names = [f'{n}.npz' for n in npz_names]
+ if str(data_source).lower() == 'gl3d':
+ npz_names = [f'{n}.txt' for n in npz_names]
+ #npz_names=npz_names[:8]
+ for npz_name in tqdm(npz_names,
+ desc=f'[rank:{self.rank}] loading {mode} datasets',
+ disable=int(self.rank) != 0):
+ # `ScanNetDataset`/`MegaDepthDataset` load all data from npz_path when initialized, which might take time.
+ npz_path = osp.join(npz_dir, npz_name)
+ if data_source == 'ScanNet':
+ datasets.append(
+ ScanNetDataset(data_root,
+ npz_path,
+ intrinsic_path,
+ mode=mode,
+ min_overlap_score=min_overlap_score,
+ augment_fn=augment_fn,
+ pose_dir=pose_dir))
+ elif data_source == 'MegaDepth':
+ datasets.append(
+ MegaDepthDataset(data_root,
+ npz_path,
+ mode=mode,
+ min_overlap_score=min_overlap_score,
+ img_resize=self.mgdpt_img_resize,
+ df=self.mgdpt_df,
+ img_padding=self.mgdpt_img_pad,
+ depth_padding=self.mgdpt_depth_pad,
+ augment_fn=augment_fn,
+ coarse_scale=self.coarse_scale))
+ else:
+ raise NotImplementedError()
+ return ConcatDataset(datasets)
+
+ def _build_concat_dataset_parallel(
+ self,
+ data_root,
+ npz_names,
+ npz_dir,
+ intrinsic_path,
+ mode,
+ min_overlap_score=0.,
+ pose_dir=None,
+ ):
+ augment_fn = self.augment_fn if mode == 'train' else None
+ data_source = self.trainval_data_source if mode in ['train', 'val'] else self.test_data_source
+ if str(data_source).lower() == 'megadepth':
+ npz_names = [f'{n}.npz' for n in npz_names]
+ #npz_names=npz_names[:8]
+ with tqdm_joblib(tqdm(desc=f'[rank:{self.rank}] loading {mode} datasets',
+ total=len(npz_names), disable=int(self.rank) != 0)):
+ if data_source == 'ScanNet':
+ datasets = Parallel(n_jobs=math.floor(len(os.sched_getaffinity(0)) * 0.9 / comm.get_local_size()))(
+ delayed(lambda x: _build_dataset(
+ ScanNetDataset,
+ data_root,
+ osp.join(npz_dir, x),
+ intrinsic_path,
+ mode=mode,
+ min_overlap_score=min_overlap_score,
+ augment_fn=augment_fn,
+ pose_dir=pose_dir))(name)
+ for name in npz_names)
+ elif data_source == 'MegaDepth':
+ # TODO: _pickle.PicklingError: Could not pickle the task to send it to the workers.
+ raise NotImplementedError()
+ datasets = Parallel(n_jobs=math.floor(len(os.sched_getaffinity(0)) * 0.9 / comm.get_local_size()))(
+ delayed(lambda x: _build_dataset(
+ MegaDepthDataset,
+ data_root,
+ osp.join(npz_dir, x),
+ mode=mode,
+ min_overlap_score=min_overlap_score,
+ img_resize=self.mgdpt_img_resize,
+ df=self.mgdpt_df,
+ img_padding=self.mgdpt_img_pad,
+ depth_padding=self.mgdpt_depth_pad,
+ augment_fn=augment_fn,
+ coarse_scale=self.coarse_scale))(name)
+ for name in npz_names)
+ else:
+ raise ValueError(f'Unknown dataset: {data_source}')
+ return ConcatDataset(datasets)
+
+ def train_dataloader(self):
+ """ Build training dataloader for ScanNet / MegaDepth. """
+ assert self.data_sampler in ['scene_balance']
+ logger.info(f'[rank:{self.rank}/{self.world_size}]: Train Sampler and DataLoader re-init (should not re-init between epochs!).')
+ if self.data_sampler == 'scene_balance':
+ sampler = RandomConcatSampler(self.train_dataset,
+ self.n_samples_per_subset,
+ self.subset_replacement,
+ self.shuffle, self.repeat, self.seed)
+ else:
+ sampler = None
+ dataloader = DataLoader(self.train_dataset, sampler=sampler, **self.train_loader_params)
+ return dataloader
+
+ def val_dataloader(self):
+ """ Build validation dataloader for ScanNet / MegaDepth. """
+ logger.info(f'[rank:{self.rank}/{self.world_size}]: Val Sampler and DataLoader re-init.')
+ if not isinstance(self.val_dataset, abc.Sequence):
+ sampler = DistributedSampler(self.val_dataset, shuffle=False)
+ return DataLoader(self.val_dataset, sampler=sampler, **self.val_loader_params)
+ else:
+ dataloaders = []
+ for dataset in self.val_dataset:
+ sampler = DistributedSampler(dataset, shuffle=False)
+ dataloaders.append(DataLoader(dataset, sampler=sampler, **self.val_loader_params))
+ return dataloaders
+
+ def test_dataloader(self, *args, **kwargs):
+ logger.info(f'[rank:{self.rank}/{self.world_size}]: Test Sampler and DataLoader re-init.')
+ sampler = DistributedSampler(self.test_dataset, shuffle=False)
+ return DataLoader(self.test_dataset, sampler=sampler, **self.test_loader_params)
+
+
+def _build_dataset(dataset: Dataset, *args, **kwargs):
+ return dataset(*args, **kwargs)
diff --git a/imcui/third_party/ASpanFormer/src/lightning/lightning_aspanformer.py b/imcui/third_party/ASpanFormer/src/lightning/lightning_aspanformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..ee20cbec4628b73c08358ebf1e1906fb2c0ac13c
--- /dev/null
+++ b/imcui/third_party/ASpanFormer/src/lightning/lightning_aspanformer.py
@@ -0,0 +1,276 @@
+
+from collections import defaultdict
+import pprint
+from loguru import logger
+from pathlib import Path
+
+import torch
+import numpy as np
+import pytorch_lightning as pl
+from matplotlib import pyplot as plt
+
+from src.ASpanFormer.aspanformer import ASpanFormer
+from src.ASpanFormer.utils.supervision import compute_supervision_coarse, compute_supervision_fine
+from src.losses.aspan_loss import ASpanLoss
+from src.optimizers import build_optimizer, build_scheduler
+from src.utils.metrics import (
+ compute_symmetrical_epipolar_errors,compute_symmetrical_epipolar_errors_offset_bidirectional,
+ compute_pose_errors,
+ aggregate_metrics
+)
+from src.utils.plotting import make_matching_figures,make_matching_figures_offset
+from src.utils.comm import gather, all_gather
+from src.utils.misc import lower_config, flattenList
+from src.utils.profiler import PassThroughProfiler
+
+
+class PL_ASpanFormer(pl.LightningModule):
+ def __init__(self, config, pretrained_ckpt=None, profiler=None, dump_dir=None):
+ """
+ TODO:
+ - use the new version of PL logging API.
+ """
+ super().__init__()
+ # Misc
+ self.config = config # full config
+ _config = lower_config(self.config)
+ self.loftr_cfg = lower_config(_config['aspan'])
+ self.profiler = profiler or PassThroughProfiler()
+ self.n_vals_plot = max(config.TRAINER.N_VAL_PAIRS_TO_PLOT // config.TRAINER.WORLD_SIZE, 1)
+
+ # Matcher: LoFTR
+ self.matcher = ASpanFormer(config=_config['aspan'])
+ self.loss = ASpanLoss(_config)
+
+ # Pretrained weights
+ print(pretrained_ckpt)
+ if pretrained_ckpt:
+ print('load')
+ state_dict = torch.load(pretrained_ckpt, map_location='cpu')['state_dict']
+ msg=self.matcher.load_state_dict(state_dict, strict=False)
+ print(msg)
+ logger.info(f"Load \'{pretrained_ckpt}\' as pretrained checkpoint")
+
+ # Testing
+ self.dump_dir = dump_dir
+
+ def configure_optimizers(self):
+ # FIXME: The scheduler did not work properly when `--resume_from_checkpoint`
+ optimizer = build_optimizer(self, self.config)
+ scheduler = build_scheduler(self.config, optimizer)
+ return [optimizer], [scheduler]
+
+ def optimizer_step(
+ self, epoch, batch_idx, optimizer, optimizer_idx,
+ optimizer_closure, on_tpu, using_native_amp, using_lbfgs):
+ # learning rate warm up
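+ # linearly ramp the lr from WARMUP_RATIO * TRUE_LR up to TRUE_LR over the first
+ # WARMUP_STEP optimizer steps ('constant' warm-up leaves the lr untouched)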
+ warmup_step = self.config.TRAINER.WARMUP_STEP
+ if self.trainer.global_step < warmup_step:
+ if self.config.TRAINER.WARMUP_TYPE == 'linear':
+ base_lr = self.config.TRAINER.WARMUP_RATIO * self.config.TRAINER.TRUE_LR
+ lr = base_lr + \
+ (self.trainer.global_step / self.config.TRAINER.WARMUP_STEP) * \
+ abs(self.config.TRAINER.TRUE_LR - base_lr)
+ for pg in optimizer.param_groups:
+ pg['lr'] = lr
+ elif self.config.TRAINER.WARMUP_TYPE == 'constant':
+ pass
+ else:
+ raise ValueError(f'Unknown lr warm-up strategy: {self.config.TRAINER.WARMUP_TYPE}')
+
+ # update params
+ optimizer.step(closure=optimizer_closure)
+ optimizer.zero_grad()
+
+ def _trainval_inference(self, batch):
+ with self.profiler.profile("Compute coarse supervision"):
+ compute_supervision_coarse(batch, self.config)
+
+ with self.profiler.profile("LoFTR"):
+ self.matcher(batch)
+
+ with self.profiler.profile("Compute fine supervision"):
+ compute_supervision_fine(batch, self.config)
+
+ with self.profiler.profile("Compute losses"):
+ self.loss(batch)
+
+ def _compute_metrics(self, batch):
+ with self.profiler.profile("Copmute metrics"):
+ compute_symmetrical_epipolar_errors(batch) # compute epi_errs for each match
+ compute_symmetrical_epipolar_errors_offset_bidirectional(batch) # compute epi_errs for offset match
+ compute_pose_errors(batch, self.config) # compute R_errs, t_errs, pose_errs for each pair
+
+ rel_pair_names = list(zip(*batch['pair_names']))
+ bs = batch['image0'].size(0)
+ metrics = {
+ # to filter duplicate pairs caused by DistributedSampler
+ 'identifiers': ['#'.join(rel_pair_names[b]) for b in range(bs)],
+ 'epi_errs': [batch['epi_errs'][batch['m_bids'] == b].cpu().numpy() for b in range(bs)],
+ 'epi_errs_offset': [batch['epi_errs_offset_left'][batch['offset_bids_left'] == b].cpu().numpy() for b in range(bs)], #only consider left side
+ 'R_errs': batch['R_errs'],
+ 't_errs': batch['t_errs'],
+ 'inliers': batch['inliers']}
+ ret_dict = {'metrics': metrics}
+ return ret_dict, rel_pair_names
+
+
+ def training_step(self, batch, batch_idx):
+ self._trainval_inference(batch)
+
+ # logging
+ if self.trainer.global_rank == 0 and self.global_step % self.trainer.log_every_n_steps == 0:
+ # scalars
+ for k, v in batch['loss_scalars'].items():
+ if not k.startswith('loss_flow') and not k.startswith('conf_'):
+ self.logger.experiment.add_scalar(f'train/{k}', v, self.global_step)
+
+ #log offset_loss and conf for each layer and level
+ layer_num=self.loftr_cfg['coarse']['layer_num']
+ for layer_index in range(layer_num):
+ log_title='layer_'+str(layer_index)
+ self.logger.experiment.add_scalar(log_title+'/offset_loss', batch['loss_scalars']['loss_flow_'+str(layer_index)], self.global_step)
+ self.logger.experiment.add_scalar(log_title+'/conf_', batch['loss_scalars']['conf_'+str(layer_index)],self.global_step)
+
+ # net-params
+ if self.config.ASPAN.MATCH_COARSE.MATCH_TYPE == 'sinkhorn':
+ self.logger.experiment.add_scalar(
+ f'skh_bin_score', self.matcher.coarse_matching.bin_score.clone().detach().cpu().data, self.global_step)
+
+ # figures
+ if self.config.TRAINER.ENABLE_PLOTTING:
+ compute_symmetrical_epipolar_errors(batch) # compute epi_errs for each match
+ figures = make_matching_figures(batch, self.config, self.config.TRAINER.PLOT_MODE)
+ for k, v in figures.items():
+ self.logger.experiment.add_figure(f'train_match/{k}', v, self.global_step)
+
+ #plot offset
+ if self.global_step%200==0:
+ compute_symmetrical_epipolar_errors_offset_bidirectional(batch)
+ figures_left = make_matching_figures_offset(batch, self.config, self.config.TRAINER.PLOT_MODE,side='_left')
+ figures_right = make_matching_figures_offset(batch, self.config, self.config.TRAINER.PLOT_MODE,side='_right')
+ for k, v in figures_left.items():
+ self.logger.experiment.add_figure(f'train_offset/{k}'+'_left', v, self.global_step)
+ figures = make_matching_figures_offset(batch, self.config, self.config.TRAINER.PLOT_MODE,side='_right')
+ for k, v in figures_right.items():
+ self.logger.experiment.add_figure(f'train_offset/{k}'+'_right', v, self.global_step)
+
+ return {'loss': batch['loss']}
+
+ def training_epoch_end(self, outputs):
+ avg_loss = torch.stack([x['loss'] for x in outputs]).mean()
+ if self.trainer.global_rank == 0:
+ self.logger.experiment.add_scalar(
+ 'train/avg_loss_on_epoch', avg_loss,
+ global_step=self.current_epoch)
+
+ def validation_step(self, batch, batch_idx):
+ self._trainval_inference(batch)
+
+ ret_dict, _ = self._compute_metrics(batch) # this func also computes the epi_errors
+
+ val_plot_interval = max(self.trainer.num_val_batches[0] // self.n_vals_plot, 1)
+ figures = {self.config.TRAINER.PLOT_MODE: []}
+ figures_offset = {self.config.TRAINER.PLOT_MODE: []}
+ if batch_idx % val_plot_interval == 0:
+ figures = make_matching_figures(batch, self.config, mode=self.config.TRAINER.PLOT_MODE)
+ figures_offset=make_matching_figures_offset(batch, self.config, self.config.TRAINER.PLOT_MODE,'_left')
+ return {
+ **ret_dict,
+ 'loss_scalars': batch['loss_scalars'],
+ 'figures': figures,
+ 'figures_offset_left':figures_offset
+ }
+
+ def validation_epoch_end(self, outputs):
+ # handle multiple validation sets
+ multi_outputs = [outputs] if not isinstance(outputs[0], (list, tuple)) else outputs
+ multi_val_metrics = defaultdict(list)
+
+ for valset_idx, outputs in enumerate(multi_outputs):
+ # pl performs a sanity check at the very beginning of training, so log that pass as epoch -1
+ cur_epoch = self.trainer.current_epoch
+ if not self.trainer.resume_from_checkpoint and self.trainer.running_sanity_check:
+ cur_epoch = -1
+
+ # 1. loss_scalars: dict of list, on cpu
+ _loss_scalars = [o['loss_scalars'] for o in outputs]
+ loss_scalars = {k: flattenList(all_gather([_ls[k] for _ls in _loss_scalars])) for k in _loss_scalars[0]}
+
+ # 2. val metrics: dict of list, numpy
+ _metrics = [o['metrics'] for o in outputs]
+ metrics = {k: flattenList(all_gather(flattenList([_me[k] for _me in _metrics]))) for k in _metrics[0]}
+ # NOTE: all ranks need to `aggregate_metrics`, but only log at rank-0
+ val_metrics_4tb = aggregate_metrics(metrics, self.config.TRAINER.EPI_ERR_THR)
+ for thr in [5, 10, 20]:
+ multi_val_metrics[f'auc@{thr}'].append(val_metrics_4tb[f'auc@{thr}'])
+
+ # 3. figures
+ _figures = [o['figures'] for o in outputs]
+ figures = {k: flattenList(gather(flattenList([_me[k] for _me in _figures]))) for k in _figures[0]}
+
+ # tensorboard records only on rank 0
+ if self.trainer.global_rank == 0:
+ for k, v in loss_scalars.items():
+ mean_v = torch.stack(v).mean()
+ self.logger.experiment.add_scalar(f'val_{valset_idx}/avg_{k}', mean_v, global_step=cur_epoch)
+
+ for k, v in val_metrics_4tb.items():
+ self.logger.experiment.add_scalar(f"metrics_{valset_idx}/{k}", v, global_step=cur_epoch)
+
+ for k, v in figures.items():
+ if self.trainer.global_rank == 0:
+ for plot_idx, fig in enumerate(v):
+ self.logger.experiment.add_figure(
+ f'val_match_{valset_idx}/{k}/pair-{plot_idx}', fig, cur_epoch, close=True)
+ plt.close('all')
+
+ for thr in [5, 10, 20]:
+ # log on all ranks for ModelCheckpoint callback to work properly
+ self.log(f'auc@{thr}', torch.tensor(np.mean(multi_val_metrics[f'auc@{thr}']))) # ckpt monitors on this
+
+ def test_step(self, batch, batch_idx):
+ with self.profiler.profile("LoFTR"):
+ self.matcher(batch)
+
+ ret_dict, rel_pair_names = self._compute_metrics(batch)
+
+ with self.profiler.profile("dump_results"):
+ if self.dump_dir is not None:
+ # dump results for further analysis
+ keys_to_save = {'mkpts0_f', 'mkpts1_f', 'mconf', 'epi_errs'}
+ pair_names = list(zip(*batch['pair_names']))
+ bs = batch['image0'].shape[0]
+ dumps = []
+ for b_id in range(bs):
+ item = {}
+ mask = batch['m_bids'] == b_id
+ item['pair_names'] = pair_names[b_id]
+ item['identifier'] = '#'.join(rel_pair_names[b_id])
+ for key in keys_to_save:
+ item[key] = batch[key][mask].cpu().numpy()
+ for key in ['R_errs', 't_errs', 'inliers']:
+ item[key] = batch[key][b_id]
+ dumps.append(item)
+ ret_dict['dumps'] = dumps
+
+ return ret_dict
+
+ def test_epoch_end(self, outputs):
+ # metrics: dict of list, numpy
+ _metrics = [o['metrics'] for o in outputs]
+ metrics = {k: flattenList(gather(flattenList([_me[k] for _me in _metrics]))) for k in _metrics[0]}
+
+ # [{key: [{...}, *#bs]}, *#batch]
+ if self.dump_dir is not None:
+ Path(self.dump_dir).mkdir(parents=True, exist_ok=True)
+ _dumps = flattenList([o['dumps'] for o in outputs]) # [{...}, #bs*#batch]
+ dumps = flattenList(gather(_dumps)) # [{...}, #proc*#bs*#batch]
+ logger.info(f'Prediction and evaluation results will be saved to: {self.dump_dir}')
+
+ if self.trainer.global_rank == 0:
+ print(self.profiler.summary())
+ val_metrics_4tb = aggregate_metrics(metrics, self.config.TRAINER.EPI_ERR_THR)
+ logger.info('\n' + pprint.pformat(val_metrics_4tb))
+ if self.dump_dir is not None:
+ np.save(Path(self.dump_dir) / 'LoFTR_pred_eval', dumps)
diff --git a/imcui/third_party/ASpanFormer/src/losses/aspan_loss.py b/imcui/third_party/ASpanFormer/src/losses/aspan_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..0cca52b36fc997415937969f26caba8c41ac2b8e
--- /dev/null
+++ b/imcui/third_party/ASpanFormer/src/losses/aspan_loss.py
@@ -0,0 +1,231 @@
+from loguru import logger
+
+import torch
+import torch.nn as nn
+
+class ASpanLoss(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config # config under the global namespace
+ self.loss_config = config['aspan']['loss']
+ self.match_type = self.config['aspan']['match_coarse']['match_type']
+ self.sparse_spvs = self.config['aspan']['match_coarse']['sparse_spvs']
+ self.flow_weight=self.config['aspan']['loss']['flow_weight']
+
+ # coarse-level
+ self.correct_thr = self.loss_config['fine_correct_thr']
+ self.c_pos_w = self.loss_config['pos_weight']
+ self.c_neg_w = self.loss_config['neg_weight']
+ # fine-level
+ self.fine_type = self.loss_config['fine_type']
+
+ def compute_flow_loss(self,coarse_corr_gt,flow_list,h0,w0,h1,w1):
+ #coarse_corr_gt:[[batch_indices],[left_indices],[right_indices]]
+ #flow_list: [L,B,H,W,4]
+ loss1=self.flow_loss_worker(flow_list[0],coarse_corr_gt[0],coarse_corr_gt[1],coarse_corr_gt[2],w1)
+ loss2=self.flow_loss_worker(flow_list[1],coarse_corr_gt[0],coarse_corr_gt[2],coarse_corr_gt[1],w0)
+ total_loss=(loss1+loss2)/2
+ return total_loss
+
+ def flow_loss_worker(self,flow,batch_indicies,self_indicies,cross_indicies,w):
+ bs,layer_num=flow.shape[1],flow.shape[0]
+ flow=flow.view(layer_num,bs,-1,4)
+ gt_flow=torch.stack([cross_indicies%w,cross_indicies//w],dim=1)
+
+ total_loss_list=[]
+ for layer_index in range(layer_num):
+ cur_flow_list=flow[layer_index]
+ spv_flow=cur_flow_list[batch_indicies,self_indicies][:,:2]
+ spv_conf=cur_flow_list[batch_indicies,self_indicies][:,2:]#[#coarse,2]
+ l2_flow_dis=((gt_flow-spv_flow)**2) #[#coarse,2]
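+ # uncertainty-weighted regression (heteroscedastic-style): with s the predicted
+ # confidence/log-variance term, the per-point loss is s + exp(-s) * ||flow error||^2,
+ # which down-weights uncertain predictions while penalising large s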
+ total_loss=(spv_conf+torch.exp(-spv_conf)*l2_flow_dis) #[#coarse,2]
+ total_loss_list.append(total_loss.mean())
+ total_loss=torch.stack(total_loss_list,dim=-1)*self.flow_weight
+ return total_loss
+
+ def compute_coarse_loss(self, conf, conf_gt, weight=None):
+ """ Point-wise CE / Focal Loss with 0 / 1 confidence as gt.
+ Args:
+ conf (torch.Tensor): (N, HW0, HW1) / (N, HW0+1, HW1+1)
+ conf_gt (torch.Tensor): (N, HW0, HW1)
+ weight (torch.Tensor): (N, HW0, HW1)
+ """
+ pos_mask, neg_mask = conf_gt == 1, conf_gt == 0
+ c_pos_w, c_neg_w = self.c_pos_w, self.c_neg_w
+ # corner case: no gt coarse-level match at all
+ if not pos_mask.any(): # assign a wrong gt
+ pos_mask[0, 0, 0] = True
+ if weight is not None:
+ weight[0, 0, 0] = 0.
+ c_pos_w = 0.
+ if not neg_mask.any():
+ neg_mask[0, 0, 0] = True
+ if weight is not None:
+ weight[0, 0, 0] = 0.
+ c_neg_w = 0.
+
+ if self.loss_config['coarse_type'] == 'cross_entropy':
+ assert not self.sparse_spvs, 'Sparse Supervision for cross-entropy not implemented!'
+ conf = torch.clamp(conf, 1e-6, 1-1e-6)
+ loss_pos = - torch.log(conf[pos_mask])
+ loss_neg = - torch.log(1 - conf[neg_mask])
+ if weight is not None:
+ loss_pos = loss_pos * weight[pos_mask]
+ loss_neg = loss_neg * weight[neg_mask]
+ return c_pos_w * loss_pos.mean() + c_neg_w * loss_neg.mean()
+ elif self.loss_config['coarse_type'] == 'focal':
+ conf = torch.clamp(conf, 1e-6, 1-1e-6)
+ alpha = self.loss_config['focal_alpha']
+ gamma = self.loss_config['focal_gamma']
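+ # focal loss: positives contribute -alpha * (1 - p)^gamma * log(p), focusing the
+ # loss on hard, low-confidence matches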
+
+ if self.sparse_spvs:
+ pos_conf = conf[:, :-1, :-1][pos_mask] \
+ if self.match_type == 'sinkhorn' \
+ else conf[pos_mask]
+ loss_pos = - alpha * torch.pow(1 - pos_conf, gamma) * pos_conf.log()
+ # calculate losses for negative samples
+ if self.match_type == 'sinkhorn':
+ neg0, neg1 = conf_gt.sum(-1) == 0, conf_gt.sum(1) == 0
+ neg_conf = torch.cat([conf[:, :-1, -1][neg0], conf[:, -1, :-1][neg1]], 0)
+ loss_neg = - alpha * torch.pow(1 - neg_conf, gamma) * neg_conf.log()
+ else:
+ # There is no dustbin for dual_softmax, so we leave unmatchable patches without supervision.
+ # we could also add 'pseudo negative samples'
+ pass
+ # handle loss weights
+ if weight is not None:
+ # Different from dense-spvs, the loss w.r.t. padded regions isn't directly zeroed out,
+ # but only through manually setting corresponding regions in sim_matrix to '-inf'.
+ loss_pos = loss_pos * weight[pos_mask]
+ if self.match_type == 'sinkhorn':
+ neg_w0 = (weight.sum(-1) != 0)[neg0]
+ neg_w1 = (weight.sum(1) != 0)[neg1]
+ neg_mask = torch.cat([neg_w0, neg_w1], 0)
+ loss_neg = loss_neg[neg_mask]
+
+ loss = c_pos_w * loss_pos.mean() + c_neg_w * loss_neg.mean() \
+ if self.match_type == 'sinkhorn' \
+ else c_pos_w * loss_pos.mean()
+ return loss
+ # positive and negative elements occupy similar proportions. => more balanced loss weights needed
+ else: # dense supervision (in the case of match_type=='sinkhorn', the dustbin is not supervised.)
+ loss_pos = - alpha * torch.pow(1 - conf[pos_mask], gamma) * (conf[pos_mask]).log()
+ loss_neg = - alpha * torch.pow(conf[neg_mask], gamma) * (1 - conf[neg_mask]).log()
+ if weight is not None:
+ loss_pos = loss_pos * weight[pos_mask]
+ loss_neg = loss_neg * weight[neg_mask]
+ return c_pos_w * loss_pos.mean() + c_neg_w * loss_neg.mean()
+ # each negative element occupies a smaller proportion than the positive elements. => higher negative loss weight needed
+ else:
+ raise ValueError('Unknown coarse loss: {type}'.format(type=self.loss_config['coarse_type']))
+
+ def compute_fine_loss(self, expec_f, expec_f_gt):
+ if self.fine_type == 'l2_with_std':
+ return self._compute_fine_loss_l2_std(expec_f, expec_f_gt)
+ elif self.fine_type == 'l2':
+ return self._compute_fine_loss_l2(expec_f, expec_f_gt)
+ else:
+ raise NotImplementedError()
+
+ def _compute_fine_loss_l2(self, expec_f, expec_f_gt):
+ """
+ Args:
+ expec_f (torch.Tensor): [M, 2]
+ expec_f_gt (torch.Tensor): [M, 2]
+ """
+ correct_mask = torch.linalg.norm(expec_f_gt, ord=float('inf'), dim=1) < self.correct_thr
+ if correct_mask.sum() == 0:
+ if self.training: # this seldom happens during training, since we pad the prediction with gt
+ logger.warning("assign a false supervision to avoid ddp deadlock")
+ correct_mask[0] = True
+ else:
+ return None
+ flow_l2 = ((expec_f_gt[correct_mask] - expec_f[correct_mask]) ** 2).sum(-1)
+ return flow_l2.mean()
+
+ def _compute_fine_loss_l2_std(self, expec_f, expec_f_gt):
+ """
+ Args:
+ expec_f (torch.Tensor): [M, 3]
+ expec_f_gt (torch.Tensor): [M, 2]
+ """
+ # correct_mask tells you which pair to compute fine-loss
+ correct_mask = torch.linalg.norm(expec_f_gt, ord=float('inf'), dim=1) < self.correct_thr
+
+ # use std as weight that measures uncertainty
+ std = expec_f[:, 2]
+ inverse_std = 1. / torch.clamp(std, min=1e-10)
+ weight = (inverse_std / torch.mean(inverse_std)).detach() # avoid minimizing the loss by increasing std
+
+ # corner case: no correct coarse match found
+ if not correct_mask.any():
+ if self.training: # this seldom happens during training, since we pad the prediction with gt
+ # sometimes there is no coarse-level gt at all.
+ logger.warning("assign a false supervision to avoid ddp deadlock")
+ correct_mask[0] = True
+ weight[0] = 0.
+ else:
+ return None
+
+ # l2 loss with std
+ flow_l2 = ((expec_f_gt[correct_mask] - expec_f[correct_mask, :2]) ** 2).sum(-1)
+ loss = (flow_l2 * weight[correct_mask]).mean()
+
+ return loss
+
+ @torch.no_grad()
+ def compute_c_weight(self, data):
+ """ compute element-wise weights for computing coarse-level loss. """
+ if 'mask0' in data:
+ c_weight = (data['mask0'].flatten(-2)[..., None] * data['mask1'].flatten(-2)[:, None]).float()
+ else:
+ c_weight = None
+ return c_weight
+
+ def forward(self, data):
+ """
+ Update:
+ data (dict): update{
+ 'loss': [1] the reduced loss across a batch,
+ 'loss_scalars' (dict): loss scalars for tensorboard_record
+ }
+ """
+ loss_scalars = {}
+ # 0. compute element-wise loss weight
+ c_weight = self.compute_c_weight(data)
+
+ # 1. coarse-level loss
+ loss_c = self.compute_coarse_loss(
+ data['conf_matrix_with_bin'] if self.sparse_spvs and self.match_type == 'sinkhorn' \
+ else data['conf_matrix'],
+ data['conf_matrix_gt'],
+ weight=c_weight)
+ loss = loss_c * self.loss_config['coarse_weight']
+ loss_scalars.update({"loss_c": loss_c.clone().detach().cpu()})
+
+ # 2. fine-level loss
+ loss_f = self.compute_fine_loss(data['expec_f'], data['expec_f_gt'])
+ if loss_f is not None:
+ loss += loss_f * self.loss_config['fine_weight']
+ loss_scalars.update({"loss_f": loss_f.clone().detach().cpu()})
+ else:
+ assert self.training is False
+ loss_scalars.update({'loss_f': torch.tensor(1.)}) # 1 is the upper bound
+
+ # 3. flow loss
+ coarse_corr=[data['spv_b_ids'],data['spv_i_ids'],data['spv_j_ids']]
+ loss_flow = self.compute_flow_loss(coarse_corr,data['predict_flow'],\
+ data['hw0_c'][0],data['hw0_c'][1],data['hw1_c'][0],data['hw1_c'][1])
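+ # NOTE: flow_loss_worker already multiplies by self.flow_weight, so the flow term is
+ # scaled by flow_weight a second time here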
+ loss_flow=loss_flow*self.flow_weight
+ for index,loss_off in enumerate(loss_flow):
+ loss_scalars.update({'loss_flow_'+str(index): loss_off.clone().detach().cpu()})
+ conf=data['predict_flow'][0][:,:,:,:,2:]
+ layer_num=conf.shape[0]
+ for layer_index in range(layer_num):
+ loss_scalars.update({'conf_'+str(layer_index): conf[layer_index].mean().clone().detach().cpu()}) # 1 is the upper bound
+
+
+ loss+=loss_flow.sum()
+ #print((loss_c * self.loss_config['coarse_weight']).data,loss_flow.data)
+ loss_scalars.update({'loss': loss.clone().detach().cpu()})
+ data.update({"loss": loss, "loss_scalars": loss_scalars})
diff --git a/imcui/third_party/ASpanFormer/src/optimizers/__init__.py b/imcui/third_party/ASpanFormer/src/optimizers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e1db2285352586c250912bdd2c4ae5029620ab5f
--- /dev/null
+++ b/imcui/third_party/ASpanFormer/src/optimizers/__init__.py
@@ -0,0 +1,42 @@
+import torch
+from torch.optim.lr_scheduler import MultiStepLR, CosineAnnealingLR, ExponentialLR
+
+
+def build_optimizer(model, config):
+ name = config.TRAINER.OPTIMIZER
+ lr = config.TRAINER.TRUE_LR
+
+ if name == "adam":
+ return torch.optim.Adam(model.parameters(), lr=lr, weight_decay=config.TRAINER.ADAM_DECAY)
+ elif name == "adamw":
+ return torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=config.TRAINER.ADAMW_DECAY)
+ else:
+ raise ValueError(f"TRAINER.OPTIMIZER = {name} is not a valid optimizer!")
+
+
+def build_scheduler(config, optimizer):
+ """
+ Returns:
+ scheduler (dict):{
+ 'scheduler': lr_scheduler,
+ 'interval': 'step', # or 'epoch'
+ 'monitor': 'val_f1', (optional)
+ 'frequency': x, (optional)
+ }
+ """
+ scheduler = {'interval': config.TRAINER.SCHEDULER_INTERVAL}
+ name = config.TRAINER.SCHEDULER
+
+ if name == 'MultiStepLR':
+ scheduler.update(
+ {'scheduler': MultiStepLR(optimizer, config.TRAINER.MSLR_MILESTONES, gamma=config.TRAINER.MSLR_GAMMA)})
+ elif name == 'CosineAnnealing':
+ scheduler.update(
+ {'scheduler': CosineAnnealingLR(optimizer, config.TRAINER.COSA_TMAX)})
+ elif name == 'ExponentialLR':
+ scheduler.update(
+ {'scheduler': ExponentialLR(optimizer, config.TRAINER.ELR_GAMMA)})
+ else:
+ raise NotImplementedError()
+
+ return scheduler
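+
+# Usage sketch (illustrative only, not part of the original file): these builders are
+# meant to be called from a PyTorch Lightning module's configure_optimizers; the
+# attribute names below are assumptions.
+#
+#     def configure_optimizers(self):
+#         optimizer = build_optimizer(self, self.config)
+#         scheduler = build_scheduler(self.config, optimizer)
+#         return [optimizer], [scheduler]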
diff --git a/imcui/third_party/ASpanFormer/src/utils/augment.py b/imcui/third_party/ASpanFormer/src/utils/augment.py
new file mode 100644
index 0000000000000000000000000000000000000000..d7c5d3e11b6fe083aaeff7555bb7ce3a4bfb755d
--- /dev/null
+++ b/imcui/third_party/ASpanFormer/src/utils/augment.py
@@ -0,0 +1,55 @@
+import albumentations as A
+
+
+class DarkAug(object):
+ """
+ Extreme dark augmentation aiming at Aachen Day-Night
+ """
+
+ def __init__(self) -> None:
+ self.augmentor = A.Compose([
+ A.RandomBrightnessContrast(p=0.75, brightness_limit=(-0.6, 0.0), contrast_limit=(-0.5, 0.3)),
+ A.Blur(p=0.1, blur_limit=(3, 9)),
+ A.MotionBlur(p=0.2, blur_limit=(3, 25)),
+ A.RandomGamma(p=0.1, gamma_limit=(15, 65)),
+ A.HueSaturationValue(p=0.1, val_shift_limit=(-100, -40))
+ ], p=0.75)
+
+ def __call__(self, x):
+ return self.augmentor(image=x)['image']
+
+
+class MobileAug(object):
+ """
+    Random augmentations aiming at images from mobile/handheld devices.
+ """
+
+ def __init__(self):
+ self.augmentor = A.Compose([
+ A.MotionBlur(p=0.25),
+ A.ColorJitter(p=0.5),
+ A.RandomRain(p=0.1), # random occlusion
+ A.RandomSunFlare(p=0.1),
+ A.JpegCompression(p=0.25),
+ A.ISONoise(p=0.25)
+ ], p=1.0)
+
+ def __call__(self, x):
+ return self.augmentor(image=x)['image']
+
+
+def build_augmentor(method=None, **kwargs):
+ if method is not None:
+        raise NotImplementedError('Use of augmentation functions is not supported yet!')
+ if method == 'dark':
+ return DarkAug()
+ elif method == 'mobile':
+ return MobileAug()
+ elif method is None:
+ return None
+ else:
+ raise ValueError(f'Invalid augmentation method: {method}')
+
+
+if __name__ == '__main__':
+ augmentor = build_augmentor('FDA')
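+
+# Usage sketch (illustrative only; build_augmentor currently rejects any non-None method):
+# the augmentors can still be used directly on an HxWx3 uint8 numpy image. 'example.jpg'
+# is a hypothetical path.
+#
+#     import cv2
+#     aug = DarkAug()
+#     img = cv2.imread('example.jpg')
+#     img_dark = aug(img)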
diff --git a/imcui/third_party/ASpanFormer/src/utils/comm.py b/imcui/third_party/ASpanFormer/src/utils/comm.py
new file mode 100644
index 0000000000000000000000000000000000000000..26ec9517cc47e224430106d8ae9aa99a3fe49167
--- /dev/null
+++ b/imcui/third_party/ASpanFormer/src/utils/comm.py
@@ -0,0 +1,265 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+"""
+[Copied from detectron2]
+This file contains primitives for multi-gpu communication.
+This is useful when doing distributed training.
+"""
+
+import functools
+import logging
+import numpy as np
+import pickle
+import torch
+import torch.distributed as dist
+
+_LOCAL_PROCESS_GROUP = None
+"""
+A torch process group which only includes processes that are on the same machine as the current process.
+This variable is set when processes are spawned by `launch()` in "engine/launch.py".
+"""
+
+
+def get_world_size() -> int:
+ if not dist.is_available():
+ return 1
+ if not dist.is_initialized():
+ return 1
+ return dist.get_world_size()
+
+
+def get_rank() -> int:
+ if not dist.is_available():
+ return 0
+ if not dist.is_initialized():
+ return 0
+ return dist.get_rank()
+
+
+def get_local_rank() -> int:
+ """
+ Returns:
+ The rank of the current process within the local (per-machine) process group.
+ """
+ if not dist.is_available():
+ return 0
+ if not dist.is_initialized():
+ return 0
+ assert _LOCAL_PROCESS_GROUP is not None
+ return dist.get_rank(group=_LOCAL_PROCESS_GROUP)
+
+
+def get_local_size() -> int:
+ """
+ Returns:
+ The size of the per-machine process group,
+ i.e. the number of processes per machine.
+ """
+ if not dist.is_available():
+ return 1
+ if not dist.is_initialized():
+ return 1
+ return dist.get_world_size(group=_LOCAL_PROCESS_GROUP)
+
+
+def is_main_process() -> bool:
+ return get_rank() == 0
+
+
+def synchronize():
+ """
+ Helper function to synchronize (barrier) among all processes when
+ using distributed training
+ """
+ if not dist.is_available():
+ return
+ if not dist.is_initialized():
+ return
+ world_size = dist.get_world_size()
+ if world_size == 1:
+ return
+ dist.barrier()
+
+
+@functools.lru_cache()
+def _get_global_gloo_group():
+ """
+ Return a process group based on gloo backend, containing all the ranks
+ The result is cached.
+ """
+ if dist.get_backend() == "nccl":
+ return dist.new_group(backend="gloo")
+ else:
+ return dist.group.WORLD
+
+
+def _serialize_to_tensor(data, group):
+ backend = dist.get_backend(group)
+ assert backend in ["gloo", "nccl"]
+ device = torch.device("cpu" if backend == "gloo" else "cuda")
+
+ buffer = pickle.dumps(data)
+ if len(buffer) > 1024 ** 3:
+ logger = logging.getLogger(__name__)
+ logger.warning(
+ "Rank {} trying to all-gather {:.2f} GB of data on device {}".format(
+ get_rank(), len(buffer) / (1024 ** 3), device
+ )
+ )
+ storage = torch.ByteStorage.from_buffer(buffer)
+ tensor = torch.ByteTensor(storage).to(device=device)
+ return tensor
+
+
+def _pad_to_largest_tensor(tensor, group):
+ """
+ Returns:
+ list[int]: size of the tensor, on each rank
+ Tensor: padded tensor that has the max size
+ """
+ world_size = dist.get_world_size(group=group)
+ assert (
+ world_size >= 1
+ ), "comm.gather/all_gather must be called from ranks within the given group!"
+ local_size = torch.tensor([tensor.numel()], dtype=torch.int64, device=tensor.device)
+ size_list = [
+ torch.zeros([1], dtype=torch.int64, device=tensor.device) for _ in range(world_size)
+ ]
+ dist.all_gather(size_list, local_size, group=group)
+
+ size_list = [int(size.item()) for size in size_list]
+
+ max_size = max(size_list)
+
+ # we pad the tensor because torch all_gather does not support
+ # gathering tensors of different shapes
+ if local_size != max_size:
+ padding = torch.zeros((max_size - local_size,), dtype=torch.uint8, device=tensor.device)
+ tensor = torch.cat((tensor, padding), dim=0)
+ return size_list, tensor
+
+
+def all_gather(data, group=None):
+ """
+ Run all_gather on arbitrary picklable data (not necessarily tensors).
+
+ Args:
+ data: any picklable object
+ group: a torch process group. By default, will use a group which
+ contains all ranks on gloo backend.
+
+ Returns:
+ list[data]: list of data gathered from each rank
+ """
+ if get_world_size() == 1:
+ return [data]
+ if group is None:
+ group = _get_global_gloo_group()
+ if dist.get_world_size(group) == 1:
+ return [data]
+
+ tensor = _serialize_to_tensor(data, group)
+
+ size_list, tensor = _pad_to_largest_tensor(tensor, group)
+ max_size = max(size_list)
+
+ # receiving Tensor from all ranks
+ tensor_list = [
+ torch.empty((max_size,), dtype=torch.uint8, device=tensor.device) for _ in size_list
+ ]
+ dist.all_gather(tensor_list, tensor, group=group)
+
+ data_list = []
+ for size, tensor in zip(size_list, tensor_list):
+ buffer = tensor.cpu().numpy().tobytes()[:size]
+ data_list.append(pickle.loads(buffer))
+
+ return data_list
+
+
+def gather(data, dst=0, group=None):
+ """
+ Run gather on arbitrary picklable data (not necessarily tensors).
+
+ Args:
+ data: any picklable object
+ dst (int): destination rank
+ group: a torch process group. By default, will use a group which
+ contains all ranks on gloo backend.
+
+ Returns:
+ list[data]: on dst, a list of data gathered from each rank. Otherwise,
+ an empty list.
+ """
+ if get_world_size() == 1:
+ return [data]
+ if group is None:
+ group = _get_global_gloo_group()
+ if dist.get_world_size(group=group) == 1:
+ return [data]
+ rank = dist.get_rank(group=group)
+
+ tensor = _serialize_to_tensor(data, group)
+ size_list, tensor = _pad_to_largest_tensor(tensor, group)
+
+ # receiving Tensor from all ranks
+ if rank == dst:
+ max_size = max(size_list)
+ tensor_list = [
+ torch.empty((max_size,), dtype=torch.uint8, device=tensor.device) for _ in size_list
+ ]
+ dist.gather(tensor, tensor_list, dst=dst, group=group)
+
+ data_list = []
+ for size, tensor in zip(size_list, tensor_list):
+ buffer = tensor.cpu().numpy().tobytes()[:size]
+ data_list.append(pickle.loads(buffer))
+ return data_list
+ else:
+ dist.gather(tensor, [], dst=dst, group=group)
+ return []
+
+
+def shared_random_seed():
+ """
+ Returns:
+ int: a random number that is the same across all workers.
+ If workers need a shared RNG, they can use this shared seed to
+ create one.
+
+ All workers must call this function, otherwise it will deadlock.
+ """
+ ints = np.random.randint(2 ** 31)
+ all_ints = all_gather(ints)
+ return all_ints[0]
+
+
+def reduce_dict(input_dict, average=True):
+ """
+ Reduce the values in the dictionary from all processes so that process with rank
+ 0 has the reduced results.
+
+ Args:
+ input_dict (dict): inputs to be reduced. All the values must be scalar CUDA Tensor.
+ average (bool): whether to do average or sum
+
+ Returns:
+ a dict with the same keys as input_dict, after reduction.
+ """
+ world_size = get_world_size()
+ if world_size < 2:
+ return input_dict
+ with torch.no_grad():
+ names = []
+ values = []
+ # sort the keys so that they are consistent across processes
+ for k in sorted(input_dict.keys()):
+ names.append(k)
+ values.append(input_dict[k])
+ values = torch.stack(values, dim=0)
+ dist.reduce(values, dst=0)
+ if dist.get_rank() == 0 and average:
+ # only main process gets accumulated, so only divide by
+ # world_size in this case
+ values /= world_size
+ reduced_dict = {k: v for k, v in zip(names, values)}
+ return reduced_dict
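+
+# Usage sketch (illustrative only; variable names and the merge helper are assumptions):
+# gathering arbitrary per-rank results and reducing scalar CUDA tensors onto rank 0.
+#
+#     all_metrics = all_gather(local_metrics)      # list with one entry per rank
+#     if is_main_process():
+#         merged = flatten_and_merge(all_metrics)  # hypothetical helper
+#     reduced = reduce_dict({'loss': loss})        # rank 0 ends up with the averaged value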
diff --git a/imcui/third_party/ASpanFormer/src/utils/dataloader.py b/imcui/third_party/ASpanFormer/src/utils/dataloader.py
new file mode 100644
index 0000000000000000000000000000000000000000..6da37b880a290c2bb3ebb028d0c8dab592acc5c1
--- /dev/null
+++ b/imcui/third_party/ASpanFormer/src/utils/dataloader.py
@@ -0,0 +1,23 @@
+import numpy as np
+
+
+# --- PL-DATAMODULE ---
+
+def get_local_split(items: list, world_size: int, rank: int, seed: int):
+ """ The local rank only loads a split of the dataset. """
+ n_items = len(items)
+ items_permute = np.random.RandomState(seed).permutation(items)
+ if n_items % world_size == 0:
+ padded_items = items_permute
+ else:
+ padding = np.random.RandomState(seed).choice(
+ items,
+ world_size - (n_items % world_size),
+ replace=True)
+ padded_items = np.concatenate([items_permute, padding])
+ assert len(padded_items) % world_size == 0, \
+ f'len(padded_items): {len(padded_items)}; world_size: {world_size}; len(padding): {len(padding)}'
+ n_per_rank = len(padded_items) // world_size
+ local_items = padded_items[n_per_rank * rank: n_per_rank * (rank+1)]
+
+ return local_items
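+
+# Example (illustrative only): with 10 items and world_size=4, the permuted list is padded
+# with 2 re-sampled items to length 12, so every rank loads exactly 3 items and rank k
+# receives padded_items[3 * k : 3 * (k + 1)].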
diff --git a/imcui/third_party/ASpanFormer/src/utils/dataset.py b/imcui/third_party/ASpanFormer/src/utils/dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..209bf554acc20e33ea89eb9e7024ba68d0b3a30b
--- /dev/null
+++ b/imcui/third_party/ASpanFormer/src/utils/dataset.py
@@ -0,0 +1,222 @@
+import io
+import cv2
+import numpy as np
+import h5py
+import torch
+from numpy.linalg import inv
+import re
+
+
+try:
+ # for internel use only
+ from .client import MEGADEPTH_CLIENT, SCANNET_CLIENT
+except Exception:
+ MEGADEPTH_CLIENT = SCANNET_CLIENT = None
+
+# --- DATA IO ---
+
+def load_array_from_s3(
+ path, client, cv_type,
+ use_h5py=False,
+):
+ byte_str = client.Get(path)
+ try:
+ if not use_h5py:
+ raw_array = np.fromstring(byte_str, np.uint8)
+ data = cv2.imdecode(raw_array, cv_type)
+ else:
+ f = io.BytesIO(byte_str)
+ data = np.array(h5py.File(f, 'r')['/depth'])
+ except Exception as ex:
+ print(f"==> Data loading failure: {path}")
+ raise ex
+
+ assert data is not None
+ return data
+
+
+def imread_gray(path, augment_fn=None, client=SCANNET_CLIENT):
+ cv_type = cv2.IMREAD_GRAYSCALE if augment_fn is None \
+ else cv2.IMREAD_COLOR
+ if str(path).startswith('s3://'):
+ image = load_array_from_s3(str(path), client, cv_type)
+ else:
+ image = cv2.imread(str(path), cv_type)
+
+ if augment_fn is not None:
+ image = cv2.imread(str(path), cv2.IMREAD_COLOR)
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+ image = augment_fn(image)
+ image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
+ return image # (h, w)
+
+
+def get_resized_wh(w, h, resize=None):
+ if resize is not None: # resize the longer edge
+ scale = resize / max(h, w)
+ w_new, h_new = int(round(w*scale)), int(round(h*scale))
+ else:
+ w_new, h_new = w, h
+ return w_new, h_new
+
+
+def get_divisible_wh(w, h, df=None):
+ if df is not None:
+ w_new, h_new = map(lambda x: int(x // df * df), [w, h])
+ else:
+ w_new, h_new = w, h
+ return w_new, h_new
+
+
+def pad_bottom_right(inp, pad_size, ret_mask=False):
+ assert isinstance(pad_size, int) and pad_size >= max(inp.shape[-2:]), f"{pad_size} < {max(inp.shape[-2:])}"
+ mask = None
+ if inp.ndim == 2:
+ padded = np.zeros((pad_size, pad_size), dtype=inp.dtype)
+ padded[:inp.shape[0], :inp.shape[1]] = inp
+ if ret_mask:
+ mask = np.zeros((pad_size, pad_size), dtype=bool)
+ mask[:inp.shape[0], :inp.shape[1]] = True
+ elif inp.ndim == 3:
+ padded = np.zeros((inp.shape[0], pad_size, pad_size), dtype=inp.dtype)
+ padded[:, :inp.shape[1], :inp.shape[2]] = inp
+ if ret_mask:
+ mask = np.zeros((inp.shape[0], pad_size, pad_size), dtype=bool)
+ mask[:, :inp.shape[1], :inp.shape[2]] = True
+ else:
+ raise NotImplementedError()
+ return padded, mask
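+
+# Example (illustrative only): a (480, 640) image padded with pad_size=640 becomes a
+# (640, 640) array whose bottom 160 rows are zeros; the returned mask is True exactly on
+# the original 480x640 region, which lets downstream code ignore the padded area.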
+
+
+# --- MEGADEPTH ---
+
+def read_megadepth_gray(path, resize=None, df=None, padding=False, augment_fn=None):
+ """
+ Args:
+ resize (int, optional): the longer edge of resized images. None for no resize.
+        padding (bool): If set to 'True', zero-pad the resized image to a square of side max(h_new, w_new).
+ augment_fn (callable, optional): augments images with pre-defined visual effects
+ Returns:
+ image (torch.tensor): (1, h, w)
+ mask (torch.tensor): (h, w)
+ scale (torch.tensor): [w/w_new, h/h_new]
+ """
+ # read image
+ image = imread_gray(path, augment_fn, client=MEGADEPTH_CLIENT)
+
+ # resize image
+ w, h = image.shape[1], image.shape[0]
+ w_new, h_new = get_resized_wh(w, h, resize)
+ w_new, h_new = get_divisible_wh(w_new, h_new, df)
+
+ image = cv2.resize(image, (w_new, h_new))
+ scale = torch.tensor([w/w_new, h/h_new], dtype=torch.float)
+
+ if padding: # padding
+ pad_to = max(h_new, w_new)
+ image, mask = pad_bottom_right(image, pad_to, ret_mask=True)
+ else:
+ mask = None
+
+ image = torch.from_numpy(image).float()[None] / 255 # (h, w) -> (1, h, w) and normalized
+ if mask is not None:
+ mask = torch.from_numpy(mask)
+
+ return image, mask, scale
+
+
+def read_megadepth_depth(path, pad_to=None):
+ if str(path).startswith('s3://'):
+ depth = load_array_from_s3(path, MEGADEPTH_CLIENT, None, use_h5py=True)
+ else:
+ depth = np.array(h5py.File(path, 'r')['depth'])
+ if pad_to is not None:
+ depth, _ = pad_bottom_right(depth, pad_to, ret_mask=False)
+ depth = torch.from_numpy(depth).float() # (h, w)
+ return depth
+
+
+# --- ScanNet ---
+
+def read_scannet_gray(path, resize=(640, 480), augment_fn=None):
+ """
+ Args:
+ resize (tuple): align image to depthmap, in (w, h).
+ augment_fn (callable, optional): augments images with pre-defined visual effects
+ Returns:
+ image (torch.tensor): (1, h, w)
+ """
+ # read and resize image
+ image = imread_gray(path, augment_fn)
+ image = cv2.resize(image, resize)
+
+ # (h, w) -> (1, h, w) and normalized
+ image = torch.from_numpy(image).float()[None] / 255
+ return image
+
+
+def read_scannet_depth(path):
+ if str(path).startswith('s3://'):
+ depth = load_array_from_s3(str(path), SCANNET_CLIENT, cv2.IMREAD_UNCHANGED)
+ else:
+ depth = cv2.imread(str(path), cv2.IMREAD_UNCHANGED)
+ depth = depth / 1000
+ depth = torch.from_numpy(depth).float() # (h, w)
+ return depth
+
+
+def read_scannet_pose(path):
+ """ Read ScanNet's Camera2World pose and transform it to World2Camera.
+
+ Returns:
+ pose_w2c (np.ndarray): (4, 4)
+ """
+ cam2world = np.loadtxt(path, delimiter=' ')
+ world2cam = inv(cam2world)
+ return world2cam
+
+
+def read_scannet_intrinsic(path):
+ """ Read ScanNet's intrinsic matrix and return the 3x3 matrix.
+ """
+ intrinsic = np.loadtxt(path, delimiter=' ')
+ return intrinsic[:-1, :-1]
+
+
+def read_gl3d_gray(path, resize):
+    img = cv2.resize(cv2.imread(path, cv2.IMREAD_GRAYSCALE), (int(resize), int(resize)))
+    img = torch.from_numpy(img).float()[None] / 255  # (h, w) -> (1, h, w) and normalized
+ return img
+
+def read_gl3d_depth(file_path):
+ with open(file_path, 'rb') as fin:
+ color = None
+ width = None
+ height = None
+ scale = None
+ data_type = None
+ header = str(fin.readline().decode('UTF-8')).rstrip()
+ if header == 'PF':
+ color = True
+ elif header == 'Pf':
+ color = False
+ else:
+ raise Exception('Not a PFM file.')
+ dim_match = re.match(r'^(\d+)\s(\d+)\s$', fin.readline().decode('UTF-8'))
+ if dim_match:
+ width, height = map(int, dim_match.groups())
+ else:
+ raise Exception('Malformed PFM header.')
+ scale = float((fin.readline().decode('UTF-8')).rstrip())
+ if scale < 0: # little-endian
+            data_type = '<f'
+        if n > best_num_inliers:
+            ret = (R, t[:, 0], mask.ravel() > 0)
+            best_num_inliers = n
+
+ return ret
+
+
+def compute_pose_errors(data, config):
+ """
+ Update:
+ data (dict):{
+ "R_errs" List[float]: [N]
+ "t_errs" List[float]: [N]
+ "inliers" List[np.ndarray]: [N]
+ }
+ """
+ pixel_thr = config.TRAINER.RANSAC_PIXEL_THR # 0.5
+ conf = config.TRAINER.RANSAC_CONF # 0.99999
+ data.update({'R_errs': [], 't_errs': [], 'inliers': []})
+
+ m_bids = data['m_bids'].cpu().numpy()
+ pts0 = data['mkpts0_f'].cpu().numpy()
+ pts1 = data['mkpts1_f'].cpu().numpy()
+ K0 = data['K0'].cpu().numpy()
+ K1 = data['K1'].cpu().numpy()
+ T_0to1 = data['T_0to1'].cpu().numpy()
+
+ for bs in range(K0.shape[0]):
+ mask = m_bids == bs
+ ret = estimate_pose(pts0[mask], pts1[mask], K0[bs], K1[bs], pixel_thr, conf=conf)
+
+ if ret is None:
+ data['R_errs'].append(np.inf)
+ data['t_errs'].append(np.inf)
+            data['inliers'].append(np.array([]).astype(bool))
+ else:
+ R, t, inliers = ret
+ t_err, R_err = relative_pose_error(T_0to1[bs], R, t, ignore_gt_t_thr=0.0)
+ data['R_errs'].append(R_err)
+ data['t_errs'].append(t_err)
+ data['inliers'].append(inliers)
+
+
+# --- METRIC AGGREGATION ---
+
+def error_auc(errors, thresholds):
+ """
+ Args:
+ errors (list): [N,]
+ thresholds (list)
+ """
+ errors = [0] + sorted(list(errors))
+ recall = list(np.linspace(0, 1, len(errors)))
+
+ aucs = []
+ thresholds = [5, 10, 20]
+ for thr in thresholds:
+ last_index = np.searchsorted(errors, thr)
+ y = recall[:last_index] + [recall[last_index-1]]
+ x = errors[:last_index] + [thr]
+ aucs.append(np.trapz(y, x) / thr)
+
+ return {f'auc@{t}': auc for t, auc in zip(thresholds, aucs)}
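+
+# Worked example (illustrative only): pose errors [2, 4, 8, 30] (deg) give
+# errors = [0, 2, 4, 8, 30] and recall = [0, .25, .5, .75, 1.], so
+# auc@5 = np.trapz([0, .25, .5, .5], [0, 2, 4, 5]) / 5 = 0.3.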
+
+
+def epidist_prec(errors, thresholds, ret_dict=False,offset=False):
+ precs = []
+ for thr in thresholds:
+ prec_ = []
+ for errs in errors:
+ correct_mask = errs < thr
+ prec_.append(np.mean(correct_mask) if len(correct_mask) > 0 else 0)
+ precs.append(np.mean(prec_) if len(prec_) > 0 else 0)
+ if ret_dict:
+ return {f'prec@{t:.0e}': prec for t, prec in zip(thresholds, precs)} if not offset else {f'prec_flow@{t:.0e}': prec for t, prec in zip(thresholds, precs)}
+ else:
+ return precs
+
+
+def aggregate_metrics(metrics, epi_err_thr=5e-4):
+ """ Aggregate metrics for the whole dataset:
+ (This method should be called once per dataset)
+ 1. AUC of the pose error (angular) at the threshold [5, 10, 20]
+ 2. Mean matching precision at the threshold 5e-4(ScanNet), 1e-4(MegaDepth)
+ """
+ # filter duplicates
+ unq_ids = OrderedDict((iden, id) for id, iden in enumerate(metrics['identifiers']))
+ unq_ids = list(unq_ids.values())
+ logger.info(f'Aggregating metrics over {len(unq_ids)} unique items...')
+
+ # pose auc
+ angular_thresholds = [5, 10, 20]
+ pose_errors = np.max(np.stack([metrics['R_errs'], metrics['t_errs']]), axis=0)[unq_ids]
+ aucs = error_auc(pose_errors, angular_thresholds) # (auc@5, auc@10, auc@20)
+
+ # matching precision
+ dist_thresholds = [epi_err_thr]
+ precs = epidist_prec(np.array(metrics['epi_errs'], dtype=object)[unq_ids], dist_thresholds, True) # (prec@err_thr)
+
+ #offset precision
+ try:
+ precs_offset = epidist_prec(np.array(metrics['epi_errs_offset'], dtype=object)[unq_ids], [2e-3], True,offset=True)
+ return {**aucs, **precs,**precs_offset}
+    except KeyError:
+ return {**aucs, **precs}
diff --git a/imcui/third_party/ASpanFormer/src/utils/misc.py b/imcui/third_party/ASpanFormer/src/utils/misc.py
new file mode 100644
index 0000000000000000000000000000000000000000..25e4433f5ffa41adc4c0435cfe2b5696e43b58b3
--- /dev/null
+++ b/imcui/third_party/ASpanFormer/src/utils/misc.py
@@ -0,0 +1,139 @@
+import os
+import contextlib
+import joblib
+from typing import Union
+from loguru import _Logger, logger
+from itertools import chain
+
+import torch
+from yacs.config import CfgNode as CN
+from pytorch_lightning.utilities import rank_zero_only
+import cv2
+import numpy as np
+
+def lower_config(yacs_cfg):
+ if not isinstance(yacs_cfg, CN):
+ return yacs_cfg
+ return {k.lower(): lower_config(v) for k, v in yacs_cfg.items()}
+
+
+def upper_config(dict_cfg):
+ if not isinstance(dict_cfg, dict):
+ return dict_cfg
+ return {k.upper(): upper_config(v) for k, v in dict_cfg.items()}
+
+
+def log_on(condition, message, level):
+ if condition:
+ assert level in ['INFO', 'DEBUG', 'WARNING', 'ERROR', 'CRITICAL']
+ logger.log(level, message)
+
+
+def get_rank_zero_only_logger(logger: _Logger):
+ if rank_zero_only.rank == 0:
+ return logger
+ else:
+ for _level in logger._core.levels.keys():
+ level = _level.lower()
+ setattr(logger, level,
+ lambda x: None)
+ logger._log = lambda x: None
+ return logger
+
+
+def setup_gpus(gpus: Union[str, int]) -> int:
+ """ A temporary fix for pytorch-lighting 1.3.x """
+ gpus = str(gpus)
+ gpu_ids = []
+
+ if ',' not in gpus:
+ n_gpus = int(gpus)
+ return n_gpus if n_gpus != -1 else torch.cuda.device_count()
+ else:
+ gpu_ids = [i.strip() for i in gpus.split(',') if i != '']
+
+ # setup environment variables
+ visible_devices = os.getenv('CUDA_VISIBLE_DEVICES')
+ if visible_devices is None:
+ os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
+ os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(str(i) for i in gpu_ids)
+ visible_devices = os.getenv('CUDA_VISIBLE_DEVICES')
+ logger.warning(f'[Temporary Fix] manually set CUDA_VISIBLE_DEVICES when specifying gpus to use: {visible_devices}')
+ else:
+ logger.warning('[Temporary Fix] CUDA_VISIBLE_DEVICES already set by user or the main process.')
+ return len(gpu_ids)
+
+
+def flattenList(x):
+ return list(chain(*x))
+
+
+@contextlib.contextmanager
+def tqdm_joblib(tqdm_object):
+ """Context manager to patch joblib to report into tqdm progress bar given as argument
+
+ Usage:
+ with tqdm_joblib(tqdm(desc="My calculation", total=10)) as progress_bar:
+ Parallel(n_jobs=16)(delayed(sqrt)(i**2) for i in range(10))
+
+    When iterating over a generator, using tqdm directly is also a solution (but it monitors task queuing instead of task completion)
+ ret_vals = Parallel(n_jobs=args.world_size)(
+ delayed(lambda x: _compute_cov_score(pid, *x))(param)
+ for param in tqdm(combinations(image_ids, 2),
+ desc=f'Computing cov_score of [{pid}]',
+ total=len(image_ids)*(len(image_ids)-1)/2))
+ Src: https://stackoverflow.com/a/58936697
+ """
+ class TqdmBatchCompletionCallback(joblib.parallel.BatchCompletionCallBack):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ def __call__(self, *args, **kwargs):
+ tqdm_object.update(n=self.batch_size)
+ return super().__call__(*args, **kwargs)
+
+ old_batch_callback = joblib.parallel.BatchCompletionCallBack
+ joblib.parallel.BatchCompletionCallBack = TqdmBatchCompletionCallback
+ try:
+ yield tqdm_object
+ finally:
+ joblib.parallel.BatchCompletionCallBack = old_batch_callback
+ tqdm_object.close()
+
+
+def draw_points(img,points,color=(0,255,0),radius=3):
+ dp = [(int(points[i, 0]), int(points[i, 1])) for i in range(points.shape[0])]
+ for i in range(points.shape[0]):
+ cv2.circle(img, dp[i],radius=radius,color=color)
+ return img
+
+
+def draw_match(img1, img2, corr1, corr2,inlier=[True],color=None,radius1=1,radius2=1,resize=None):
+ if resize is not None:
+ scale1,scale2=[img1.shape[1]/resize[0],img1.shape[0]/resize[1]],[img2.shape[1]/resize[0],img2.shape[0]/resize[1]]
+ img1,img2=cv2.resize(img1, resize, interpolation=cv2.INTER_AREA),cv2.resize(img2, resize, interpolation=cv2.INTER_AREA)
+ corr1,corr2=corr1/np.asarray(scale1)[np.newaxis],corr2/np.asarray(scale2)[np.newaxis]
+ corr1_key = [cv2.KeyPoint(corr1[i, 0], corr1[i, 1], radius1) for i in range(corr1.shape[0])]
+ corr2_key = [cv2.KeyPoint(corr2[i, 0], corr2[i, 1], radius2) for i in range(corr2.shape[0])]
+
+ assert len(corr1) == len(corr2)
+
+ draw_matches = [cv2.DMatch(i, i, 0) for i in range(len(corr1))]
+ if color is None:
+ color = [(0, 255, 0) if cur_inlier else (0,0,255) for cur_inlier in inlier]
+ if len(color)==1:
+ display = cv2.drawMatches(img1, corr1_key, img2, corr2_key, draw_matches, None,
+ matchColor=color[0],
+ singlePointColor=color[0],
+ flags=4
+ )
+ else:
+ height,width=max(img1.shape[0],img2.shape[0]),img1.shape[1]+img2.shape[1]
+ display=np.zeros([height,width,3],np.uint8)
+ display[:img1.shape[0],:img1.shape[1]]=img1
+ display[:img2.shape[0],img1.shape[1]:]=img2
+ for i in range(len(corr1)):
+ left_x,left_y,right_x,right_y=int(corr1[i][0]),int(corr1[i][1]),int(corr2[i][0]+img1.shape[1]),int(corr2[i][1])
+ cur_color=(int(color[i][0]),int(color[i][1]),int(color[i][2]))
+ cv2.line(display, (left_x,left_y), (right_x,right_y),cur_color,1,lineType=cv2.LINE_AA)
+ return display
diff --git a/imcui/third_party/ASpanFormer/src/utils/plotting.py b/imcui/third_party/ASpanFormer/src/utils/plotting.py
new file mode 100644
index 0000000000000000000000000000000000000000..8696880237b6ad9fe48d3c1fc44ed13b691a6c4d
--- /dev/null
+++ b/imcui/third_party/ASpanFormer/src/utils/plotting.py
@@ -0,0 +1,219 @@
+import bisect
+import numpy as np
+import matplotlib.pyplot as plt
+import matplotlib
+from copy import deepcopy
+
+def _compute_conf_thresh(data):
+ dataset_name = data['dataset_name'][0].lower()
+ if dataset_name == 'scannet':
+ thr = 5e-4
+ elif dataset_name == 'megadepth' or dataset_name=='gl3d':
+ thr = 1e-4
+ else:
+ raise ValueError(f'Unknown dataset: {dataset_name}')
+ return thr
+
+
+# --- VISUALIZATION --- #
+
+def make_matching_figure(
+ img0, img1, mkpts0, mkpts1, color,
+ kpts0=None, kpts1=None, text=[], dpi=75, path=None):
+ # draw image pair
+ assert mkpts0.shape[0] == mkpts1.shape[0], f'mkpts0: {mkpts0.shape[0]} v.s. mkpts1: {mkpts1.shape[0]}'
+ fig, axes = plt.subplots(1, 2, figsize=(10, 6), dpi=dpi)
+ axes[0].imshow(img0, cmap='gray')
+ axes[1].imshow(img1, cmap='gray')
+ for i in range(2): # clear all frames
+ axes[i].get_yaxis().set_ticks([])
+ axes[i].get_xaxis().set_ticks([])
+ for spine in axes[i].spines.values():
+ spine.set_visible(False)
+ plt.tight_layout(pad=1)
+
+ if kpts0 is not None:
+ assert kpts1 is not None
+ axes[0].scatter(kpts0[:, 0], kpts0[:, 1], c='w', s=2)
+ axes[1].scatter(kpts1[:, 0], kpts1[:, 1], c='w', s=2)
+
+ # draw matches
+ if mkpts0.shape[0] != 0 and mkpts1.shape[0] != 0:
+ fig.canvas.draw()
+ transFigure = fig.transFigure.inverted()
+ fkpts0 = transFigure.transform(axes[0].transData.transform(mkpts0))
+ fkpts1 = transFigure.transform(axes[1].transData.transform(mkpts1))
+ fig.lines = [matplotlib.lines.Line2D((fkpts0[i, 0], fkpts1[i, 0]),
+ (fkpts0[i, 1], fkpts1[i, 1]),
+ transform=fig.transFigure, c=color[i], linewidth=1)
+ for i in range(len(mkpts0))]
+
+ axes[0].scatter(mkpts0[:, 0], mkpts0[:, 1], c=color, s=4)
+ axes[1].scatter(mkpts1[:, 0], mkpts1[:, 1], c=color, s=4)
+
+ # put txts
+ txt_color = 'k' if img0[:100, :200].mean() > 200 else 'w'
+ fig.text(
+ 0.01, 0.99, '\n'.join(text), transform=fig.axes[0].transAxes,
+ fontsize=15, va='top', ha='left', color=txt_color)
+
+ # save or return figure
+ if path:
+ plt.savefig(str(path), bbox_inches='tight', pad_inches=0)
+ plt.close()
+ else:
+ return fig
+
+
+def _make_evaluation_figure(data, b_id, alpha='dynamic'):
+ b_mask = data['m_bids'] == b_id
+ conf_thr = _compute_conf_thresh(data)
+
+ img0 = (data['image0'][b_id][0].cpu().numpy() * 255).round().astype(np.int32)
+ img1 = (data['image1'][b_id][0].cpu().numpy() * 255).round().astype(np.int32)
+ kpts0 = data['mkpts0_f'][b_mask].cpu().numpy()
+ kpts1 = data['mkpts1_f'][b_mask].cpu().numpy()
+
+ # for megadepth, we visualize matches on the resized image
+ if 'scale0' in data:
+ kpts0 = kpts0 / data['scale0'][b_id].cpu().numpy()[[1, 0]]
+ kpts1 = kpts1 / data['scale1'][b_id].cpu().numpy()[[1, 0]]
+ epi_errs = data['epi_errs'][b_mask].cpu().numpy()
+ correct_mask = epi_errs < conf_thr
+ precision = np.mean(correct_mask) if len(correct_mask) > 0 else 0
+ n_correct = np.sum(correct_mask)
+ n_gt_matches = int(data['conf_matrix_gt'][b_id].sum().cpu())
+ recall = 0 if n_gt_matches == 0 else n_correct / (n_gt_matches)
+ # recall might be larger than 1, since the calculation of conf_matrix_gt
+ # uses groundtruth depths and camera poses, but epipolar distance is used here.
+
+ # matching info
+ if alpha == 'dynamic':
+ alpha = dynamic_alpha(len(correct_mask))
+ color = error_colormap(epi_errs, conf_thr, alpha=alpha)
+
+ text = [
+ f'#Matches {len(kpts0)}',
+ f'Precision({conf_thr:.2e}) ({100 * precision:.1f}%): {n_correct}/{len(kpts0)}',
+ f'Recall({conf_thr:.2e}) ({100 * recall:.1f}%): {n_correct}/{n_gt_matches}'
+ ]
+
+ # make the figure
+ figure = make_matching_figure(img0, img1, kpts0, kpts1,
+ color, text=text)
+ return figure
+
+def _make_evaluation_figure_offset(data, b_id, alpha='dynamic',side=''):
+ layer_num=data['predict_flow'][0].shape[0]
+
+ b_mask = data['offset_bids'+side] == b_id
+ conf_thr = 2e-3 #hardcode for scannet(coarse level)
+ img0 = (data['image0'][b_id][0].cpu().numpy() * 255).round().astype(np.int32)
+ img1 = (data['image1'][b_id][0].cpu().numpy() * 255).round().astype(np.int32)
+
+ figure_list=[]
+ #draw offset matches in different layers
+ for layer_index in range(layer_num):
+ l_mask=data['offset_lids'+side]==layer_index
+ mask=l_mask&b_mask
+ kpts0 = data['offset_kpts0_f'+side][mask].cpu().numpy()
+ kpts1 = data['offset_kpts1_f'+side][mask].cpu().numpy()
+
+ epi_errs = data['epi_errs_offset'+side][mask].cpu().numpy()
+ correct_mask = epi_errs < conf_thr
+
+ precision = np.mean(correct_mask) if len(correct_mask) > 0 else 0
+ n_correct = np.sum(correct_mask)
+ n_gt_matches = int(data['conf_matrix_gt'][b_id].sum().cpu())
+ recall = 0 if n_gt_matches == 0 else n_correct / (n_gt_matches)
+ # recall might be larger than 1, since the calculation of conf_matrix_gt
+ # uses groundtruth depths and camera poses, but epipolar distance is used here.
+
+ # matching info
+ if alpha == 'dynamic':
+ alpha = dynamic_alpha(len(correct_mask))
+ color = error_colormap(epi_errs, conf_thr, alpha=alpha)
+
+ text = [
+ f'#Matches {len(kpts0)}',
+ f'Precision({conf_thr:.2e}) ({100 * precision:.1f}%): {n_correct}/{len(kpts0)}',
+ f'Recall({conf_thr:.2e}) ({100 * recall:.1f}%): {n_correct}/{n_gt_matches}'
+ ]
+
+ # make the figure
+ #import pdb;pdb.set_trace()
+ figure = make_matching_figure(deepcopy(img0), deepcopy(img1) , kpts0, kpts1,
+ color, text=text)
+ figure_list.append(figure)
+ return figure
+
+def _make_confidence_figure(data, b_id):
+ # TODO: Implement confidence figure
+ raise NotImplementedError()
+
+
+def make_matching_figures(data, config, mode='evaluation'):
+ """ Make matching figures for a batch.
+
+ Args:
+ data (Dict): a batch updated by PL_LoFTR.
+ config (Dict): matcher config
+ Returns:
+ figures (Dict[str, List[plt.figure]]
+ """
+ assert mode in ['evaluation', 'confidence'] # 'confidence'
+ figures = {mode: []}
+ for b_id in range(data['image0'].size(0)):
+ if mode == 'evaluation':
+ fig = _make_evaluation_figure(
+ data, b_id,
+ alpha=config.TRAINER.PLOT_MATCHES_ALPHA)
+ elif mode == 'confidence':
+ fig = _make_confidence_figure(data, b_id)
+ else:
+ raise ValueError(f'Unknown plot mode: {mode}')
+ figures[mode].append(fig)
+ return figures
+
+def make_matching_figures_offset(data, config, mode='evaluation',side=''):
+ """ Make matching figures for a batch.
+
+ Args:
+ data (Dict): a batch updated by PL_LoFTR.
+ config (Dict): matcher config
+ Returns:
+ figures (Dict[str, List[plt.figure]]
+ """
+ assert mode in ['evaluation', 'confidence'] # 'confidence'
+ figures = {mode: []}
+ for b_id in range(data['image0'].size(0)):
+ if mode == 'evaluation':
+ fig = _make_evaluation_figure_offset(
+ data, b_id,
+ alpha=config.TRAINER.PLOT_MATCHES_ALPHA,side=side)
+ elif mode == 'confidence':
+ fig = _make_evaluation_figure_offset(data, b_id)
+ else:
+ raise ValueError(f'Unknown plot mode: {mode}')
+ figures[mode].append(fig)
+ return figures
+
+def dynamic_alpha(n_matches,
+ milestones=[0, 300, 1000, 2000],
+ alphas=[1.0, 0.8, 0.4, 0.2]):
+ if n_matches == 0:
+ return 1.0
+ ranges = list(zip(alphas, alphas[1:] + [None]))
+ loc = bisect.bisect_right(milestones, n_matches) - 1
+ _range = ranges[loc]
+ if _range[1] is None:
+ return _range[0]
+ return _range[1] + (milestones[loc + 1] - n_matches) / (
+ milestones[loc + 1] - milestones[loc]) * (_range[0] - _range[1])
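+
+# Worked example (illustrative only): n_matches=650 falls into the [300, 1000) bucket, so
+# alpha is interpolated between 0.8 and 0.4: 0.4 + (1000 - 650) / (1000 - 300) * 0.4 = 0.6.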
+
+
+def error_colormap(err, thr, alpha=1.0):
+    assert alpha <= 1.0 and alpha > 0, f"Invalid alpha value: {alpha}"
+ x = 1 - np.clip(err / (thr * 2), 0, 1)
+ return np.clip(
+ np.stack([2-x*2, x*2, np.zeros_like(x), np.ones_like(x)*alpha], -1), 0, 1)
diff --git a/imcui/third_party/ASpanFormer/src/utils/profiler.py b/imcui/third_party/ASpanFormer/src/utils/profiler.py
new file mode 100644
index 0000000000000000000000000000000000000000..6d21ed79fb506ef09c75483355402c48a195aaa9
--- /dev/null
+++ b/imcui/third_party/ASpanFormer/src/utils/profiler.py
@@ -0,0 +1,39 @@
+import torch
+from pytorch_lightning.profiler import SimpleProfiler, PassThroughProfiler
+from contextlib import contextmanager
+from pytorch_lightning.utilities import rank_zero_only
+
+
+class InferenceProfiler(SimpleProfiler):
+ """
+ This profiler records duration of actions with cuda.synchronize()
+ Use this in test time.
+ """
+
+ def __init__(self):
+ super().__init__()
+ self.start = rank_zero_only(self.start)
+ self.stop = rank_zero_only(self.stop)
+ self.summary = rank_zero_only(self.summary)
+
+ @contextmanager
+ def profile(self, action_name: str) -> None:
+ try:
+ torch.cuda.synchronize()
+ self.start(action_name)
+ yield action_name
+ finally:
+ torch.cuda.synchronize()
+ self.stop(action_name)
+
+
+def build_profiler(name):
+ if name == 'inference':
+ return InferenceProfiler()
+ elif name == 'pytorch':
+ from pytorch_lightning.profiler import PyTorchProfiler
+ return PyTorchProfiler(use_cuda=True, profile_memory=True, row_limit=100)
+ elif name is None:
+ return PassThroughProfiler()
+ else:
+ raise ValueError(f'Invalid profiler: {name}')
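+
+# Usage sketch (illustrative only): test.py below builds the profiler from the
+# --profiler_name flag and passes it to the lightning module, roughly:
+#
+#     profiler = build_profiler('inference')
+#     model = PL_ASpanFormer(config, pretrained_ckpt=args.ckpt_path, profiler=profiler)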
diff --git a/imcui/third_party/ASpanFormer/test.py b/imcui/third_party/ASpanFormer/test.py
new file mode 100644
index 0000000000000000000000000000000000000000..541ce84662ab4888c6fece30403c5c9983118637
--- /dev/null
+++ b/imcui/third_party/ASpanFormer/test.py
@@ -0,0 +1,69 @@
+import pytorch_lightning as pl
+import argparse
+import pprint
+from loguru import logger as loguru_logger
+
+from src.config.default import get_cfg_defaults
+from src.utils.profiler import build_profiler
+
+from src.lightning.data import MultiSceneDataModule
+from src.lightning.lightning_aspanformer import PL_ASpanFormer
+import torch
+
+def parse_args():
+    # init a custom parser which will be added into the pl.Trainer parser
+ # check documentation: https://pytorch-lightning.readthedocs.io/en/latest/common/trainer.html#trainer-flags
+ parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+ parser.add_argument(
+ 'data_cfg_path', type=str, help='data config path')
+ parser.add_argument(
+ 'main_cfg_path', type=str, help='main config path')
+ parser.add_argument(
+ '--ckpt_path', type=str, default="weights/indoor_ds.ckpt", help='path to the checkpoint')
+ parser.add_argument(
+        '--dump_dir', type=str, default=None, help="if set, the matching results will be dumped to dump_dir")
+ parser.add_argument(
+ '--profiler_name', type=str, default=None, help='options: [inference, pytorch], or leave it unset')
+ parser.add_argument(
+ '--batch_size', type=int, default=1, help='batch_size per gpu')
+ parser.add_argument(
+ '--num_workers', type=int, default=2)
+ parser.add_argument(
+ '--thr', type=float, default=None, help='modify the coarse-level matching threshold.')
+ parser.add_argument(
+        '--mode', type=str, default='vanilla', help='inference mode.')
+ parser = pl.Trainer.add_argparse_args(parser)
+ return parser.parse_args()
+
+
+if __name__ == '__main__':
+ # parse arguments
+ args = parse_args()
+ pprint.pprint(vars(args))
+
+ # init default-cfg and merge it with the main- and data-cfg
+ config = get_cfg_defaults()
+ config.merge_from_file(args.main_cfg_path)
+ config.merge_from_file(args.data_cfg_path)
+ pl.seed_everything(config.TRAINER.SEED) # reproducibility
+
+ # tune when testing
+ if args.thr is not None:
+ config.ASPAN.MATCH_COARSE.THR = args.thr
+
+ loguru_logger.info(f"Args and config initialized!")
+
+ # lightning module
+ profiler = build_profiler(args.profiler_name)
+ model = PL_ASpanFormer(config, pretrained_ckpt=args.ckpt_path, profiler=profiler, dump_dir=args.dump_dir)
+ loguru_logger.info(f"ASpanFormer-lightning initialized!")
+
+ # lightning data
+ data_module = MultiSceneDataModule(args, config)
+ loguru_logger.info(f"DataModule initialized!")
+
+ # lightning trainer
+ trainer = pl.Trainer.from_argparse_args(args, replace_sampler_ddp=False, logger=False)
+
+ loguru_logger.info(f"Start testing!")
+ trainer.test(model, datamodule=data_module, verbose=False)
diff --git a/imcui/third_party/ASpanFormer/tools/SensorData.py b/imcui/third_party/ASpanFormer/tools/SensorData.py
new file mode 100644
index 0000000000000000000000000000000000000000..a3ec2644bf8b3b988ef0f36851cd3317c00511b2
--- /dev/null
+++ b/imcui/third_party/ASpanFormer/tools/SensorData.py
@@ -0,0 +1,125 @@
+
+import os, struct
+import numpy as np
+import zlib
+import imageio
+import cv2
+import png
+
+COMPRESSION_TYPE_COLOR = {-1:'unknown', 0:'raw', 1:'png', 2:'jpeg'}
+COMPRESSION_TYPE_DEPTH = {-1:'unknown', 0:'raw_ushort', 1:'zlib_ushort', 2:'occi_ushort'}
+
+class RGBDFrame():
+
+ def load(self, file_handle):
+ self.camera_to_world = np.asarray(struct.unpack('f'*16, file_handle.read(16*4)), dtype=np.float32).reshape(4, 4)
+ self.timestamp_color = struct.unpack('Q', file_handle.read(8))[0]
+ self.timestamp_depth = struct.unpack('Q', file_handle.read(8))[0]
+ self.color_size_bytes = struct.unpack('Q', file_handle.read(8))[0]
+ self.depth_size_bytes = struct.unpack('Q', file_handle.read(8))[0]
+        self.color_data = b''.join(struct.unpack('c'*self.color_size_bytes, file_handle.read(self.color_size_bytes)))  # bytes (Python 3)
+        self.depth_data = b''.join(struct.unpack('c'*self.depth_size_bytes, file_handle.read(self.depth_size_bytes)))
+
+
+ def decompress_depth(self, compression_type):
+ if compression_type == 'zlib_ushort':
+ return self.decompress_depth_zlib()
+ else:
+ raise
+
+
+ def decompress_depth_zlib(self):
+ return zlib.decompress(self.depth_data)
+
+
+ def decompress_color(self, compression_type):
+ if compression_type == 'jpeg':
+ return self.decompress_color_jpeg()
+ else:
+ raise
+
+
+ def decompress_color_jpeg(self):
+ return imageio.imread(self.color_data)
+
+
+class SensorData:
+
+ def __init__(self, filename):
+ self.version = 4
+ self.load(filename)
+
+
+ def load(self, filename):
+ with open(filename, 'rb') as f:
+ version = struct.unpack('I', f.read(4))[0]
+ assert self.version == version
+ strlen = struct.unpack('Q', f.read(8))[0]
+            self.sensor_name = b''.join(struct.unpack('c'*strlen, f.read(strlen)))
+ self.intrinsic_color = np.asarray(struct.unpack('f'*16, f.read(16*4)), dtype=np.float32).reshape(4, 4)
+ self.extrinsic_color = np.asarray(struct.unpack('f'*16, f.read(16*4)), dtype=np.float32).reshape(4, 4)
+ self.intrinsic_depth = np.asarray(struct.unpack('f'*16, f.read(16*4)), dtype=np.float32).reshape(4, 4)
+ self.extrinsic_depth = np.asarray(struct.unpack('f'*16, f.read(16*4)), dtype=np.float32).reshape(4, 4)
+ self.color_compression_type = COMPRESSION_TYPE_COLOR[struct.unpack('i', f.read(4))[0]]
+ self.depth_compression_type = COMPRESSION_TYPE_DEPTH[struct.unpack('i', f.read(4))[0]]
+ self.color_width = struct.unpack('I', f.read(4))[0]
+ self.color_height = struct.unpack('I', f.read(4))[0]
+ self.depth_width = struct.unpack('I', f.read(4))[0]
+ self.depth_height = struct.unpack('I', f.read(4))[0]
+ self.depth_shift = struct.unpack('f', f.read(4))[0]
+ num_frames = struct.unpack('Q', f.read(8))[0]
+ self.frames = []
+ for i in range(num_frames):
+ frame = RGBDFrame()
+ frame.load(f)
+ self.frames.append(frame)
+
+
+ def export_depth_images(self, output_path, image_size=None, frame_skip=1):
+ if not os.path.exists(output_path):
+ os.makedirs(output_path)
+        print('exporting', len(self.frames)//frame_skip, 'depth frames to', output_path)
+ for f in range(0, len(self.frames), frame_skip):
+ depth_data = self.frames[f].decompress_depth(self.depth_compression_type)
+ depth = np.fromstring(depth_data, dtype=np.uint16).reshape(self.depth_height, self.depth_width)
+ if image_size is not None:
+ depth = cv2.resize(depth, (image_size[1], image_size[0]), interpolation=cv2.INTER_NEAREST)
+ #imageio.imwrite(os.path.join(output_path, str(f) + '.png'), depth)
+ with open(os.path.join(output_path, str(f) + '.png'), 'wb') as f: # write 16-bit
+ writer = png.Writer(width=depth.shape[1], height=depth.shape[0], bitdepth=16)
+ depth = depth.reshape(-1, depth.shape[1]).tolist()
+ writer.write(f, depth)
+
+ def export_color_images(self, output_path, image_size=None, frame_skip=1):
+ if not os.path.exists(output_path):
+ os.makedirs(output_path)
+        print('exporting', len(self.frames)//frame_skip, 'color frames to', output_path)
+ for f in range(0, len(self.frames), frame_skip):
+ color = self.frames[f].decompress_color(self.color_compression_type)
+ if image_size is not None:
+ color = cv2.resize(color, (image_size[1], image_size[0]), interpolation=cv2.INTER_NEAREST)
+ imageio.imwrite(os.path.join(output_path, str(f) + '.jpg'), color)
+
+
+ def save_mat_to_file(self, matrix, filename):
+ with open(filename, 'w') as f:
+ for line in matrix:
+ np.savetxt(f, line[np.newaxis], fmt='%f')
+
+
+ def export_poses(self, output_path, frame_skip=1):
+ if not os.path.exists(output_path):
+ os.makedirs(output_path)
+        print('exporting', len(self.frames)//frame_skip, 'camera poses to', output_path)
+ for f in range(0, len(self.frames), frame_skip):
+ self.save_mat_to_file(self.frames[f].camera_to_world, os.path.join(output_path, str(f) + '.txt'))
+
+
+ def export_intrinsics(self, output_path):
+ if not os.path.exists(output_path):
+ os.makedirs(output_path)
+        print('exporting camera intrinsics to', output_path)
+ self.save_mat_to_file(self.intrinsic_color, os.path.join(output_path, 'intrinsic_color.txt'))
+ self.save_mat_to_file(self.extrinsic_color, os.path.join(output_path, 'extrinsic_color.txt'))
+ self.save_mat_to_file(self.intrinsic_depth, os.path.join(output_path, 'intrinsic_depth.txt'))
+ self.save_mat_to_file(self.extrinsic_depth, os.path.join(output_path, 'extrinsic_depth.txt'))
\ No newline at end of file
diff --git a/imcui/third_party/ASpanFormer/tools/extract.py b/imcui/third_party/ASpanFormer/tools/extract.py
new file mode 100644
index 0000000000000000000000000000000000000000..12f55e2f94120d5765f124f8eec867f1d82e0aa7
--- /dev/null
+++ b/imcui/third_party/ASpanFormer/tools/extract.py
@@ -0,0 +1,47 @@
+import os
+import glob
+from re import split
+from tqdm import tqdm
+from multiprocessing import Pool
+from functools import partial
+
+scannet_dir='/root/data/ScanNet-v2-1.0.0/data/raw'
+dump_dir='/root/data/scannet_dump'
+num_process=32
+
+def extract(seq,scannet_dir,split,dump_dir):
+ assert split=='train' or split=='test'
+ if not os.path.exists(os.path.join(dump_dir,split,seq)):
+ os.mkdir(os.path.join(dump_dir,split,seq))
+ cmd='python reader.py --filename '+os.path.join(scannet_dir,'scans' if split=='train' else 'scans_test',seq,seq+'.sens')+' --output_path '+os.path.join(dump_dir,split,seq)+\
+ ' --export_depth_images --export_color_images --export_poses --export_intrinsics'
+ os.system(cmd)
+
+if __name__=='__main__':
+ if not os.path.exists(dump_dir):
+ os.mkdir(dump_dir)
+ os.mkdir(os.path.join(dump_dir,'train'))
+ os.mkdir(os.path.join(dump_dir,'test'))
+
+ train_seq_list=[seq.split('/')[-1] for seq in glob.glob(os.path.join(scannet_dir,'scans','scene*'))]
+ test_seq_list=[seq.split('/')[-1] for seq in glob.glob(os.path.join(scannet_dir,'scans_test','scene*'))]
+
+ extract_train=partial(extract,scannet_dir=scannet_dir,split='train',dump_dir=dump_dir)
+ extract_test=partial(extract,scannet_dir=scannet_dir,split='test',dump_dir=dump_dir)
+
+ num_train_iter=len(train_seq_list)//num_process if len(train_seq_list)%num_process==0 else len(train_seq_list)//num_process+1
+ num_test_iter=len(test_seq_list)//num_process if len(test_seq_list)%num_process==0 else len(test_seq_list)//num_process+1
+
+ pool = Pool(num_process)
+ for index in tqdm(range(num_train_iter)):
+ seq_list=train_seq_list[index*num_process:min((index+1)*num_process,len(train_seq_list))]
+ pool.map(extract_train,seq_list)
+ pool.close()
+ pool.join()
+
+ pool = Pool(num_process)
+ for index in tqdm(range(num_test_iter)):
+ seq_list=test_seq_list[index*num_process:min((index+1)*num_process,len(test_seq_list))]
+ pool.map(extract_test,seq_list)
+ pool.close()
+ pool.join()
\ No newline at end of file
diff --git a/imcui/third_party/ASpanFormer/tools/preprocess_scene.py b/imcui/third_party/ASpanFormer/tools/preprocess_scene.py
new file mode 100644
index 0000000000000000000000000000000000000000..d20c0d070243519d67bbd25668ff5eb1657474be
--- /dev/null
+++ b/imcui/third_party/ASpanFormer/tools/preprocess_scene.py
@@ -0,0 +1,242 @@
+import argparse
+
+import imagesize
+
+import numpy as np
+
+import os
+
+parser = argparse.ArgumentParser(description='MegaDepth preprocessing script')
+
+parser.add_argument(
+ '--base_path', type=str, required=True,
+ help='path to MegaDepth'
+)
+parser.add_argument(
+ '--scene_id', type=str, required=True,
+ help='scene ID'
+)
+
+parser.add_argument(
+ '--output_path', type=str, required=True,
+ help='path to the output directory'
+)
+
+args = parser.parse_args()
+
+base_path = args.base_path
+# Remove the trailing / if need be.
+if base_path[-1] in ['/', '\\']:
+ base_path = base_path[: - 1]
+scene_id = args.scene_id
+
+base_depth_path = os.path.join(
+ base_path, 'phoenix/S6/zl548/MegaDepth_v1'
+)
+base_undistorted_sfm_path = os.path.join(
+ base_path, 'Undistorted_SfM'
+)
+
+undistorted_sparse_path = os.path.join(
+ base_undistorted_sfm_path, scene_id, 'sparse-txt'
+)
+if not os.path.exists(undistorted_sparse_path):
+ exit()
+
+depths_path = os.path.join(
+ base_depth_path, scene_id, 'dense0', 'depths'
+)
+if not os.path.exists(depths_path):
+ exit()
+
+images_path = os.path.join(
+ base_undistorted_sfm_path, scene_id, 'images'
+)
+if not os.path.exists(images_path):
+ exit()
+
+# Process cameras.txt
+with open(os.path.join(undistorted_sparse_path, 'cameras.txt'), 'r') as f:
+ raw = f.readlines()[3 :] # skip the header
+
+camera_intrinsics = {}
+for camera in raw:
+ camera = camera.split(' ')
+ camera_intrinsics[int(camera[0])] = [float(elem) for elem in camera[2 :]]
+
+# Process points3D.txt
+with open(os.path.join(undistorted_sparse_path, 'points3D.txt'), 'r') as f:
+ raw = f.readlines()[3 :] # skip the header
+
+points3D = {}
+for point3D in raw:
+ point3D = point3D.split(' ')
+ points3D[int(point3D[0])] = np.array([
+ float(point3D[1]), float(point3D[2]), float(point3D[3])
+ ])
+
+# Process images.txt
+with open(os.path.join(undistorted_sparse_path, 'images.txt'), 'r') as f:
+ raw = f.readlines()[4 :] # skip the header
+
+image_id_to_idx = {}
+image_names = []
+raw_pose = []
+camera = []
+points3D_id_to_2D = []
+n_points3D = []
+for idx, (image, points) in enumerate(zip(raw[:: 2], raw[1 :: 2])):
+ image = image.split(' ')
+ points = points.split(' ')
+
+ image_id_to_idx[int(image[0])] = idx
+
+ image_name = image[-1].strip('\n')
+ image_names.append(image_name)
+
+ raw_pose.append([float(elem) for elem in image[1 : -2]])
+ camera.append(int(image[-2]))
+ current_points3D_id_to_2D = {}
+ for x, y, point3D_id in zip(points[:: 3], points[1 :: 3], points[2 :: 3]):
+ if int(point3D_id) == -1:
+ continue
+ current_points3D_id_to_2D[int(point3D_id)] = [float(x), float(y)]
+ points3D_id_to_2D.append(current_points3D_id_to_2D)
+ n_points3D.append(len(current_points3D_id_to_2D))
+n_images = len(image_names)
+
+# Image and depthmaps paths
+image_paths = []
+depth_paths = []
+for image_name in image_names:
+ image_path = os.path.join(images_path, image_name)
+
+ # Path to the depth file
+ depth_path = os.path.join(
+ depths_path, '%s.h5' % os.path.splitext(image_name)[0]
+ )
+
+ if os.path.exists(depth_path):
+ # Check if depth map or background / foreground mask
+ file_size = os.stat(depth_path).st_size
+ # Rough estimate - 75KB might work as well
+ if file_size < 100 * 1024:
+ depth_paths.append(None)
+ image_paths.append(None)
+ else:
+ depth_paths.append(depth_path[len(base_path) + 1 :])
+ image_paths.append(image_path[len(base_path) + 1 :])
+ else:
+ depth_paths.append(None)
+ image_paths.append(None)
+
+# Camera configuration
+intrinsics = []
+poses = []
+principal_axis = []
+points3D_id_to_ndepth = []
+for idx, image_name in enumerate(image_names):
+ if image_paths[idx] is None:
+ intrinsics.append(None)
+ poses.append(None)
+ principal_axis.append([0, 0, 0])
+ points3D_id_to_ndepth.append({})
+ continue
+ image_intrinsics = camera_intrinsics[camera[idx]]
+ K = np.zeros([3, 3])
+ K[0, 0] = image_intrinsics[2]
+ K[0, 2] = image_intrinsics[4]
+ K[1, 1] = image_intrinsics[3]
+ K[1, 2] = image_intrinsics[5]
+ K[2, 2] = 1
+ intrinsics.append(K)
+
+ image_pose = raw_pose[idx]
+ qvec = image_pose[: 4]
+ qvec = qvec / np.linalg.norm(qvec)
+ w, x, y, z = qvec
+ R = np.array([
+ [
+ 1 - 2 * y * y - 2 * z * z,
+ 2 * x * y - 2 * z * w,
+ 2 * x * z + 2 * y * w
+ ],
+ [
+ 2 * x * y + 2 * z * w,
+ 1 - 2 * x * x - 2 * z * z,
+ 2 * y * z - 2 * x * w
+ ],
+ [
+ 2 * x * z - 2 * y * w,
+ 2 * y * z + 2 * x * w,
+ 1 - 2 * x * x - 2 * y * y
+ ]
+ ])
+ principal_axis.append(R[2, :])
+ t = image_pose[4 : 7]
+ # World-to-Camera pose
+ current_pose = np.zeros([4, 4])
+ current_pose[: 3, : 3] = R
+ current_pose[: 3, 3] = t
+ current_pose[3, 3] = 1
+ # Camera-to-World pose
+ # pose = np.zeros([4, 4])
+ # pose[: 3, : 3] = np.transpose(R)
+ # pose[: 3, 3] = -np.matmul(np.transpose(R), t)
+ # pose[3, 3] = 1
+ poses.append(current_pose)
+
+ current_points3D_id_to_ndepth = {}
+ for point3D_id in points3D_id_to_2D[idx].keys():
+ p3d = points3D[point3D_id]
+ current_points3D_id_to_ndepth[point3D_id] = (np.dot(R[2, :], p3d) + t[2]) / (.5 * (K[0, 0] + K[1, 1]))
+ points3D_id_to_ndepth.append(current_points3D_id_to_ndepth)
+principal_axis = np.array(principal_axis)
+angles = np.rad2deg(np.arccos(
+ np.clip(
+ np.dot(principal_axis, np.transpose(principal_axis)),
+ -1, 1
+ )
+))
+
+# Compute overlap score
+overlap_matrix = np.full([n_images, n_images], -1.)
+scale_ratio_matrix = np.full([n_images, n_images], -1.)
+for idx1 in range(n_images):
+ if image_paths[idx1] is None or depth_paths[idx1] is None:
+ continue
+ for idx2 in range(idx1 + 1, n_images):
+ if image_paths[idx2] is None or depth_paths[idx2] is None:
+ continue
+ matches = (
+ points3D_id_to_2D[idx1].keys() &
+ points3D_id_to_2D[idx2].keys()
+ )
+ min_num_points3D = min(
+ len(points3D_id_to_2D[idx1]), len(points3D_id_to_2D[idx2])
+ )
+ overlap_matrix[idx1, idx2] = len(matches) / len(points3D_id_to_2D[idx1]) # min_num_points3D
+ overlap_matrix[idx2, idx1] = len(matches) / len(points3D_id_to_2D[idx2]) # min_num_points3D
+ if len(matches) == 0:
+ continue
+ points3D_id_to_ndepth1 = points3D_id_to_ndepth[idx1]
+ points3D_id_to_ndepth2 = points3D_id_to_ndepth[idx2]
+ nd1 = np.array([points3D_id_to_ndepth1[match] for match in matches])
+ nd2 = np.array([points3D_id_to_ndepth2[match] for match in matches])
+ min_scale_ratio = np.min(np.maximum(nd1 / nd2, nd2 / nd1))
+ scale_ratio_matrix[idx1, idx2] = min_scale_ratio
+ scale_ratio_matrix[idx2, idx1] = min_scale_ratio
+
+np.savez(
+ os.path.join(args.output_path, '%s.npz' % scene_id),
+ image_paths=image_paths,
+ depth_paths=depth_paths,
+ intrinsics=intrinsics,
+ poses=poses,
+ overlap_matrix=overlap_matrix,
+ scale_ratio_matrix=scale_ratio_matrix,
+ angles=angles,
+ n_points3D=n_points3D,
+ points3D_id_to_2D=points3D_id_to_2D,
+ points3D_id_to_ndepth=points3D_id_to_ndepth
+)
\ No newline at end of file
diff --git a/imcui/third_party/ASpanFormer/tools/reader.py b/imcui/third_party/ASpanFormer/tools/reader.py
new file mode 100644
index 0000000000000000000000000000000000000000..f419fbaa8a099fcfede1cea51fcf95a2c1589160
--- /dev/null
+++ b/imcui/third_party/ASpanFormer/tools/reader.py
@@ -0,0 +1,39 @@
+import argparse
+import os, sys
+
+from SensorData import SensorData
+
+# params
+parser = argparse.ArgumentParser()
+# data paths
+parser.add_argument('--filename', required=True, help='path to sens file to read')
+parser.add_argument('--output_path', required=True, help='path to output folder')
+parser.add_argument('--export_depth_images', dest='export_depth_images', action='store_true')
+parser.add_argument('--export_color_images', dest='export_color_images', action='store_true')
+parser.add_argument('--export_poses', dest='export_poses', action='store_true')
+parser.add_argument('--export_intrinsics', dest='export_intrinsics', action='store_true')
+parser.set_defaults(export_depth_images=False, export_color_images=False, export_poses=False, export_intrinsics=False)
+
+opt = parser.parse_args()
+print(opt)
+
+
+def main():
+ if not os.path.exists(opt.output_path):
+ os.makedirs(opt.output_path)
+ # load the data
+ sys.stdout.write('loading %s...' % opt.filename)
+ sd = SensorData(opt.filename)
+ sys.stdout.write('loaded!\n')
+ if opt.export_depth_images:
+ sd.export_depth_images(os.path.join(opt.output_path, 'depth'))
+ if opt.export_color_images:
+ sd.export_color_images(os.path.join(opt.output_path, 'color'))
+ if opt.export_poses:
+ sd.export_poses(os.path.join(opt.output_path, 'pose'))
+ if opt.export_intrinsics:
+ sd.export_intrinsics(os.path.join(opt.output_path, 'intrinsic'))
+
+
+if __name__ == '__main__':
+ main()
\ No newline at end of file
diff --git a/imcui/third_party/ASpanFormer/tools/undistort_mega.py b/imcui/third_party/ASpanFormer/tools/undistort_mega.py
new file mode 100644
index 0000000000000000000000000000000000000000..68798ff30e6afa37a0f98571ecfd3f05751868c8
--- /dev/null
+++ b/imcui/third_party/ASpanFormer/tools/undistort_mega.py
@@ -0,0 +1,69 @@
+import argparse
+
+import imagesize
+
+import os
+
+import subprocess
+
+parser = argparse.ArgumentParser(description='MegaDepth Undistortion')
+
+parser.add_argument(
+ '--colmap_path', type=str,default='/usr/bin/',
+ help='path to colmap executable'
+)
+parser.add_argument(
+ '--base_path', type=str,default='/root/MegaDepth',
+ help='path to MegaDepth'
+)
+
+args = parser.parse_args()
+
+sfm_path = os.path.join(
+ args.base_path, 'MegaDepth_v1_SfM'
+)
+base_depth_path = os.path.join(
+ args.base_path, 'phoenix/S6/zl548/MegaDepth_v1'
+)
+output_path = os.path.join(
+ args.base_path, 'Undistorted_SfM'
+)
+
+os.mkdir(output_path)
+
+for scene_name in os.listdir(base_depth_path):
+ current_output_path = os.path.join(output_path, scene_name)
+ os.mkdir(current_output_path)
+
+ image_path = os.path.join(
+ base_depth_path, scene_name, 'dense0', 'imgs'
+ )
+ if not os.path.exists(image_path):
+ continue
+
+ # Find the maximum image size in scene.
+ max_image_size = 0
+ for image_name in os.listdir(image_path):
+ max_image_size = max(
+ max_image_size,
+ max(imagesize.get(os.path.join(image_path, image_name)))
+ )
+
+ # Undistort the images and update the reconstruction.
+ subprocess.call([
+ os.path.join(args.colmap_path, 'colmap'), 'image_undistorter',
+ '--image_path', os.path.join(sfm_path, scene_name, 'images'),
+ '--input_path', os.path.join(sfm_path, scene_name, 'sparse', 'manhattan', '0'),
+ '--output_path', current_output_path,
+ '--max_image_size', str(max_image_size)
+ ])
+
+ # Transform the reconstruction to raw text format.
+ sparse_txt_path = os.path.join(current_output_path, 'sparse-txt')
+ os.mkdir(sparse_txt_path)
+ subprocess.call([
+ os.path.join(args.colmap_path, 'colmap'), 'model_converter',
+ '--input_path', os.path.join(current_output_path, 'sparse'),
+ '--output_path', sparse_txt_path,
+ '--output_type', 'TXT'
+ ])
\ No newline at end of file
diff --git a/imcui/third_party/ASpanFormer/train.py b/imcui/third_party/ASpanFormer/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..21f644763711481e84863ed5d861ec57d95f2d5c
--- /dev/null
+++ b/imcui/third_party/ASpanFormer/train.py
@@ -0,0 +1,134 @@
+import math
+import argparse
+import pprint
+from distutils.util import strtobool
+from pathlib import Path
+from loguru import logger as loguru_logger
+
+import pytorch_lightning as pl
+from pytorch_lightning.utilities import rank_zero_only
+from pytorch_lightning.loggers import TensorBoardLogger
+from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
+from pytorch_lightning.plugins import DDPPlugin
+
+from src.config.default import get_cfg_defaults
+from src.utils.misc import get_rank_zero_only_logger, setup_gpus
+from src.utils.profiler import build_profiler
+from src.lightning.data import MultiSceneDataModule
+from src.lightning.lightning_aspanformer import PL_ASpanFormer
+
+loguru_logger = get_rank_zero_only_logger(loguru_logger)
+
+
+def parse_args():
+ def str2bool(v):
+ return v.lower() in ("true", "1")
+    # init a custom parser which will be added into the pl.Trainer parser
+ # check documentation: https://pytorch-lightning.readthedocs.io/en/latest/common/trainer.html#trainer-flags
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+ parser.add_argument(
+ 'data_cfg_path', type=str, help='data config path')
+ parser.add_argument(
+ 'main_cfg_path', type=str, help='main config path')
+ parser.add_argument(
+ '--exp_name', type=str, default='default_exp_name')
+ parser.add_argument(
+ '--batch_size', type=int, default=4, help='batch_size per gpu')
+ parser.add_argument(
+ '--num_workers', type=int, default=4)
+ parser.add_argument(
+ '--pin_memory', type=lambda x: bool(strtobool(x)),
+        nargs='?', default=True, help='whether to load data into pinned memory')
+ parser.add_argument(
+ '--ckpt_path', type=str, default=None,
+ help='pretrained checkpoint path, helpful for using a pre-trained coarse-only ASpanFormer')
+ parser.add_argument(
+ '--disable_ckpt', action='store_true',
+ help='disable checkpoint saving (useful for debugging).')
+ parser.add_argument(
+ '--profiler_name', type=str, default=None,
+ help='options: [inference, pytorch], or leave it unset')
+ parser.add_argument(
+ '--parallel_load_data', action='store_true',
+        help='load datasets with multiple processes.')
+ parser.add_argument(
+ '--mode', type=str, default='vanilla',
+ help='pretrained checkpoint path, helpful for using a pre-trained coarse-only ASpanFormer')
+ parser.add_argument(
+ '--ini', type=str2bool, default=False,
+ help='pretrained checkpoint path, helpful for using a pre-trained coarse-only ASpanFormer')
+
+ parser = pl.Trainer.add_argparse_args(parser)
+ return parser.parse_args()
+
+
+def main():
+ # parse arguments
+ args = parse_args()
+ rank_zero_only(pprint.pprint)(vars(args))
+
+ # init default-cfg and merge it with the main- and data-cfg
+ config = get_cfg_defaults()
+ config.merge_from_file(args.main_cfg_path)
+ config.merge_from_file(args.data_cfg_path)
+ pl.seed_everything(config.TRAINER.SEED) # reproducibility
+ # TODO: Use different seeds for each dataloader workers
+ # This is needed for data augmentation
+
+ # scale lr and warmup-step automatically
+ args.gpus = _n_gpus = setup_gpus(args.gpus)
+ config.TRAINER.WORLD_SIZE = _n_gpus * args.num_nodes
+ config.TRAINER.TRUE_BATCH_SIZE = config.TRAINER.WORLD_SIZE * args.batch_size
+ _scaling = config.TRAINER.TRUE_BATCH_SIZE / config.TRAINER.CANONICAL_BS
+ config.TRAINER.SCALING = _scaling
+ config.TRAINER.TRUE_LR = config.TRAINER.CANONICAL_LR * _scaling
+ config.TRAINER.WARMUP_STEP = math.floor(
+ config.TRAINER.WARMUP_STEP / _scaling)
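+    # Illustrative numbers for the scaling above (assuming CANONICAL_BS = 64): with 2 nodes of
+    # 4 GPUs and --batch_size 4, TRUE_BATCH_SIZE = 32, so _scaling = 0.5, TRUE_LR is half of
+    # CANONICAL_LR and WARMUP_STEP is doubled.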
+
+ # lightning module
+ profiler = build_profiler(args.profiler_name)
+ model = PL_ASpanFormer(config, pretrained_ckpt=args.ckpt_path, profiler=profiler)
+ loguru_logger.info(f"ASpanFormer LightningModule initialized!")
+
+ # lightning data
+ data_module = MultiSceneDataModule(args, config)
+ loguru_logger.info(f"ASpanFormer DataModule initialized!")
+
+ # TensorBoard Logger
+ logger = TensorBoardLogger(
+ save_dir='logs/tb_logs', name=args.exp_name, default_hp_metric=False)
+ ckpt_dir = Path(logger.log_dir) / 'checkpoints'
+
+ # Callbacks
+ # TODO: update ModelCheckpoint to monitor multiple metrics
+ ckpt_callback = ModelCheckpoint(monitor='auc@10', verbose=True, save_top_k=5, mode='max',
+ save_last=True,
+ dirpath=str(ckpt_dir),
+ filename='{epoch}-{auc@5:.3f}-{auc@10:.3f}-{auc@20:.3f}')
+ lr_monitor = LearningRateMonitor(logging_interval='step')
+ callbacks = [lr_monitor]
+ if not args.disable_ckpt:
+ callbacks.append(ckpt_callback)
+
+ # Lightning Trainer
+ trainer = pl.Trainer.from_argparse_args(
+ args,
+ plugins=DDPPlugin(find_unused_parameters=False,
+ num_nodes=args.num_nodes,
+ sync_batchnorm=config.TRAINER.WORLD_SIZE > 0),
+ gradient_clip_val=config.TRAINER.GRADIENT_CLIPPING,
+ callbacks=callbacks,
+ logger=logger,
+ sync_batchnorm=config.TRAINER.WORLD_SIZE > 0,
+ replace_sampler_ddp=False, # use custom sampler
+ reload_dataloaders_every_epoch=False, # avoid repeated samples!
+ weights_summary='full',
+ profiler=profiler)
+ loguru_logger.info(f"Trainer initialized!")
+ loguru_logger.info(f"Start training!")
+ trainer.fit(model, datamodule=data_module)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/imcui/third_party/COTR/COTR/cameras/camera_pose.py b/imcui/third_party/COTR/COTR/cameras/camera_pose.py
new file mode 100644
index 0000000000000000000000000000000000000000..1fd2263cbaace94036c3ef85eb082ca749ba51eb
--- /dev/null
+++ b/imcui/third_party/COTR/COTR/cameras/camera_pose.py
@@ -0,0 +1,164 @@
+'''
+Extrinsic camera pose
+'''
+import math
+import copy
+
+import numpy as np
+
+from COTR.transformations import transformations
+from COTR.transformations.transform_basics import Translation, Rotation, UnstableRotation
+
+
+class CameraPose():
+ def __init__(self, t: Translation, r: Rotation):
+ '''
+ WARN: World 2 cam
+ Translation and rotation are world to camera
+ translation_vector is not the coordinate of the camera in world space.
+ '''
+ assert isinstance(t, Translation)
+ assert isinstance(r, Rotation) or isinstance(r, UnstableRotation)
+ self.t = t
+ self.r = r
+
+ def __str__(self):
+ string = f'center in world: {self.camera_center_in_world}, translation(w2c): {self.t}, rotation(w2c): {self.r}'
+ return string
+
+ @classmethod
+ def from_world_to_camera(cls, world_to_camera, unstable=False):
+ assert isinstance(world_to_camera, np.ndarray)
+ assert world_to_camera.shape == (4, 4)
+ vec = transformations.translation_from_matrix(world_to_camera).astype(np.float32)
+ t = Translation(vec)
+ if unstable:
+ r = UnstableRotation(world_to_camera)
+ else:
+ quat = transformations.quaternion_from_matrix(world_to_camera).astype(np.float32)
+ r = Rotation(quat)
+ return cls(t, r)
+
+ @classmethod
+ def from_camera_to_world(cls, camera_to_world, unstable=False):
+ assert isinstance(camera_to_world, np.ndarray)
+ assert camera_to_world.shape == (4, 4)
+ world_to_camera = np.linalg.inv(camera_to_world)
+ world_to_camera /= world_to_camera[3, 3]
+ return cls.from_world_to_camera(world_to_camera, unstable)
+
+ @classmethod
+ def from_pose_vector(cls, pose_vector):
+ t = Translation(pose_vector[:3])
+ r = Rotation(pose_vector[3:])
+ return cls(t, r)
+
+ @property
+ def translation_vector(self):
+ return self.t.translation_vector
+
+ @property
+ def translation_matrix(self):
+ return self.t.translation_matrix
+
+ @property
+ def quaternion(self):
+ '''
+ quaternion format (w, x, y, z)
+ '''
+ return self.r.quaternion
+
+ @property
+ def rotation_matrix(self):
+ return self.r.rotation_matrix
+
+ @property
+ def pose_vector(self):
+ '''
+        Pose vector is the concatenation of the translation vector and the quaternion
+        (X, Y, Z, w, x, y, z), in world-to-camera (w2c) convention
+ '''
+ return np.concatenate([self.translation_vector, self.quaternion])
+
+ @property
+ def inv_pose_vector(self):
+ inv_quat = transformations.quaternion_inverse(self.quaternion)
+ return np.concatenate([self.camera_center_in_world, inv_quat])
+
+ @property
+ def pose_vector_6_dof(self):
+ '''
+        Here we assume the quaternion is normalized and remove the W component
+ (X, Y, Z, x, y, z)
+ '''
+ return np.concatenate([self.translation_vector, self.quaternion[1:]])
+
+ @property
+ def world_to_camera(self):
+ M = np.matmul(self.translation_matrix, self.rotation_matrix)
+ M /= M[3, 3]
+ return M
+
+ @property
+ def world_to_camera_3x4(self):
+ M = self.world_to_camera
+ M = M[0:3, 0:4]
+ return M
+
+ @property
+ def extrinsic_mat(self):
+ return self.world_to_camera_3x4
+
+ @property
+ def camera_to_world(self):
+ M = np.linalg.inv(self.world_to_camera)
+ M /= M[3, 3]
+ return M
+
+ @property
+ def camera_to_world_3x4(self):
+ M = self.camera_to_world
+ M = M[0:3, 0:4]
+ return M
+
+ @property
+ def camera_center_in_world(self):
+ return self.camera_to_world[:3, 3]
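+    # Note: the stored pose is world-to-camera, so the camera centre above is taken from the
+    # inverted transform; equivalently it equals -R^T @ t for rotation R and translation t.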
+
+ @property
+ def forward(self):
+ return self.camera_to_world[:3, 2]
+
+ @property
+ def up(self):
+ return self.camera_to_world[:3, 1]
+
+ @property
+ def right(self):
+ return self.camera_to_world[:3, 0]
+
+ @property
+ def essential_matrix(self):
+ E = np.cross(self.rotation_matrix[:3, :3], self.camera_center_in_world)
+ return E / np.linalg.norm(E)
+
+
+def inverse_camera_pose(cam_pose: CameraPose):
+ return CameraPose.from_world_to_camera(np.linalg.inv(cam_pose.world_to_camera))
+
+
+def rotate_camera_pose(cam_pose, rot):
+ if rot == 0:
+ return copy.deepcopy(cam_pose)
+ else:
+ rot = rot / 180 * np.pi
+ sin_rot = np.sin(rot)
+ cos_rot = np.cos(rot)
+
+ rot_mat = np.stack([np.stack([cos_rot, -sin_rot, 0, 0], axis=-1),
+ np.stack([sin_rot, cos_rot, 0, 0], axis=-1),
+ np.stack([0, 0, 1, 0], axis=-1),
+ np.stack([0, 0, 0, 1], axis=-1)], axis=1)
+ new_world2cam = np.matmul(rot_mat, cam_pose.world_to_camera)
+ return CameraPose.from_world_to_camera(new_world2cam)
diff --git a/imcui/third_party/COTR/COTR/cameras/capture.py b/imcui/third_party/COTR/COTR/cameras/capture.py
new file mode 100644
index 0000000000000000000000000000000000000000..c09180def5fcb030e5cde0e14ab85b71831642ae
--- /dev/null
+++ b/imcui/third_party/COTR/COTR/cameras/capture.py
@@ -0,0 +1,432 @@
+'''
+Capture from a pinhole camera
+Separate the captured content and the camera...
+'''
+
+import os
+import time
+import abc
+import copy
+
+import cv2
+import torch
+import numpy as np
+import imageio
+import PIL
+from PIL import Image
+
+from COTR.cameras.camera_pose import CameraPose, rotate_camera_pose
+from COTR.cameras.pinhole_camera import PinholeCamera, rotate_pinhole_camera, crop_pinhole_camera
+from COTR.utils import debug_utils, utils, constants
+from COTR.utils.utils import Point2D
+from COTR.projector import pcd_projector
+from COTR.utils.constants import MAX_SIZE
+from COTR.utils.utils import CropCamConfig
+
+
+def crop_center_max_xy(p2d, shape):
+ h, w = shape
+ crop_x = min(h, w)
+ crop_y = crop_x
+ start_x = w // 2 - crop_x // 2
+ start_y = h // 2 - crop_y // 2
+ mask = (p2d.xy[:, 0] > start_x) & (p2d.xy[:, 0] < start_x + crop_x) & (p2d.xy[:, 1] > start_y) & (p2d.xy[:, 1] < start_y + crop_y)
+ out_xy = (p2d.xy - [start_x, start_y])[mask]
+ out = Point2D(p2d.id_3d[mask], out_xy)
+ return out
+
+
+def crop_center_max(img):
+ if isinstance(img, torch.Tensor):
+ return crop_center_max_torch(img)
+ elif isinstance(img, np.ndarray):
+ return crop_center_max_np(img)
+ else:
+ raise ValueError
+
+
+def crop_center_max_torch(img):
+ if len(img.shape) == 2:
+ h, w = img.shape
+ elif len(img.shape) == 3:
+ c, h, w = img.shape
+ elif len(img.shape) == 4:
+ b, c, h, w = img.shape
+ else:
+ raise ValueError
+ crop_x = min(h, w)
+ crop_y = crop_x
+ start_x = w // 2 - crop_x // 2
+ start_y = h // 2 - crop_y // 2
+ if len(img.shape) == 2:
+ return img[start_y:start_y + crop_y, start_x:start_x + crop_x]
+ elif len(img.shape) in [3, 4]:
+ return img[..., start_y:start_y + crop_y, start_x:start_x + crop_x]
+
+
+def crop_center_max_np(img, return_starts=False):
+ if len(img.shape) == 2:
+ h, w = img.shape
+ elif len(img.shape) == 3:
+ h, w, c = img.shape
+ elif len(img.shape) == 4:
+ b, h, w, c = img.shape
+ else:
+ raise ValueError
+ crop_x = min(h, w)
+ crop_y = crop_x
+ start_x = w // 2 - crop_x // 2
+ start_y = h // 2 - crop_y // 2
+ if len(img.shape) == 2:
+ canvas = img[start_y:start_y + crop_y, start_x:start_x + crop_x]
+ elif len(img.shape) == 3:
+ canvas = img[start_y:start_y + crop_y, start_x:start_x + crop_x, :]
+ elif len(img.shape) == 4:
+ canvas = img[:, start_y:start_y + crop_y, start_x:start_x + crop_x, :]
+ if return_starts:
+ return canvas, -start_x, -start_y
+ else:
+ return canvas
+
+
+def pad_to_square_np(img, till_divisible_by=1, return_starts=False):
+ if len(img.shape) == 2:
+ h, w = img.shape
+ elif len(img.shape) == 3:
+ h, w, c = img.shape
+ elif len(img.shape) == 4:
+ b, h, w, c = img.shape
+ else:
+ raise ValueError
+ if till_divisible_by == 1:
+ size = max(h, w)
+ else:
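+        # Rounds the square size up to the next multiple of till_divisible_by; note that an
+        # already divisible max(h, w) is still bumped by a full step (e.g. 96 -> 112 for 16).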
+ size = (max(h, w) + till_divisible_by) - (max(h, w) % till_divisible_by)
+ start_x = size // 2 - w // 2
+ start_y = size // 2 - h // 2
+ if len(img.shape) == 2:
+ canvas = np.zeros([size, size], dtype=img.dtype)
+ canvas[start_y:start_y + h, start_x:start_x + w] = img
+ elif len(img.shape) == 3:
+ canvas = np.zeros([size, size, c], dtype=img.dtype)
+ canvas[start_y:start_y + h, start_x:start_x + w, :] = img
+ elif len(img.shape) == 4:
+ canvas = np.zeros([b, size, size, c], dtype=img.dtype)
+ canvas[:, start_y:start_y + h, start_x:start_x + w, :] = img
+ if return_starts:
+ return canvas, start_x, start_y
+ else:
+ return canvas
+
+
+def stretch_to_square_np(img):
+ size = max(*img.shape[:2])
+ return np.array(PIL.Image.fromarray(img).resize((size, size), resample=PIL.Image.BILINEAR))
+
+
+def rotate_image(image, angle, interpolation=cv2.INTER_LINEAR):
+ image_center = tuple(np.array(image.shape[1::-1]) / 2)
+ rot_mat = cv2.getRotationMatrix2D(image_center, angle, 1.0)
+ result = cv2.warpAffine(image, rot_mat, image.shape[1::-1], flags=interpolation)
+ return result
+
+
+def read_array(path):
+ '''
+ https://github.com/colmap/colmap/blob/dev/scripts/python/read_dense.py
+ '''
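+    # The file starts with an ASCII header "width&height&channels&" followed by raw float32
+    # values in column-major (Fortran) order, hence the reshape with order="F" and the
+    # transpose below.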
+ with open(path, "rb") as fid:
+ width, height, channels = np.genfromtxt(fid, delimiter="&", max_rows=1,
+ usecols=(0, 1, 2), dtype=int)
+ fid.seek(0)
+ num_delimiter = 0
+ byte = fid.read(1)
+ while True:
+ if byte == b"&":
+ num_delimiter += 1
+ if num_delimiter >= 3:
+ break
+ byte = fid.read(1)
+ array = np.fromfile(fid, np.float32)
+ array = array.reshape((width, height, channels), order="F")
+ return np.transpose(array, (1, 0, 2)).squeeze()
+
+
+################ Content ################
+
+
+class CapturedContent(abc.ABC):
+ def __init__(self):
+ self._rotation = 0
+
+ @property
+ def rotation(self):
+ return self._rotation
+
+ @rotation.setter
+ def rotation(self, rot):
+ self._rotation = rot
+
+
+class CapturedImage(CapturedContent):
+ def __init__(self, img_path, crop_cam, pinhole_cam_before=None):
+ super(CapturedImage, self).__init__()
+ assert os.path.isfile(img_path), 'file does not exist: {0}'.format(img_path)
+ self.crop_cam = crop_cam
+ self._image = None
+ self.img_path = img_path
+ self.pinhole_cam_before = pinhole_cam_before
+ self._p2d = None
+
+ def read_image_to_ram(self) -> int:
+ # raise NotImplementedError
+ assert self._image is None
+ _image = self.image
+ self._image = _image
+ return self._image.nbytes
+
+ @property
+ def image(self):
+ if self._image is not None:
+ _image = self._image
+ else:
+ _image = imageio.imread(self.img_path, pilmode='RGB')
+ if self.rotation != 0:
+ _image = rotate_image(_image, self.rotation)
+ if _image.shape[:2] != self.pinhole_cam_before.shape:
+ _image = np.array(PIL.Image.fromarray(_image).resize(self.pinhole_cam_before.shape[::-1], resample=PIL.Image.BILINEAR))
+ assert _image.shape[:2] == self.pinhole_cam_before.shape
+ if self.crop_cam == 'no_crop':
+ pass
+ elif self.crop_cam == 'crop_center':
+ _image = crop_center_max(_image)
+ elif self.crop_cam == 'crop_center_and_resize':
+ _image = crop_center_max(_image)
+ _image = np.array(PIL.Image.fromarray(_image).resize((MAX_SIZE, MAX_SIZE), resample=PIL.Image.BILINEAR))
+ elif isinstance(self.crop_cam, CropCamConfig):
+ assert _image.shape[0] == self.crop_cam.orig_h
+ assert _image.shape[1] == self.crop_cam.orig_w
+ _image = _image[self.crop_cam.y:self.crop_cam.y + self.crop_cam.h,
+ self.crop_cam.x:self.crop_cam.x + self.crop_cam.w, ]
+ _image = np.array(PIL.Image.fromarray(_image).resize((self.crop_cam.out_w, self.crop_cam.out_h), resample=PIL.Image.BILINEAR))
+ assert _image.shape[:2] == (self.crop_cam.out_h, self.crop_cam.out_w)
+ else:
+ raise ValueError()
+ return _image
+
+ @property
+ def p2d(self):
+ if self._p2d is None:
+ return self._p2d
+ else:
+ _p2d = self._p2d
+ if self.crop_cam == 'no_crop':
+ pass
+ elif self.crop_cam == 'crop_center':
+ _p2d = crop_center_max_xy(_p2d, self.pinhole_cam_before.shape)
+ else:
+ raise ValueError()
+ return _p2d
+
+ @p2d.setter
+ def p2d(self, value):
+ if value is not None:
+ assert isinstance(value, Point2D)
+ self._p2d = value
+
+
+class CapturedDepth(CapturedContent):
+ def __init__(self, depth_path, crop_cam, pinhole_cam_before=None):
+ super(CapturedDepth, self).__init__()
+ if not depth_path.endswith('dummy'):
+ assert os.path.isfile(depth_path), 'file does not exist: {0}'.format(depth_path)
+ self.crop_cam = crop_cam
+ self._depth = None
+ self.depth_path = depth_path
+ self.pinhole_cam_before = pinhole_cam_before
+
+ def read_depth(self):
+ import tables
+ if self.depth_path.endswith('dummy'):
+ image_path = self.depth_path[:-5]
+ w, h = Image.open(image_path).size
+ _depth = np.zeros([h, w], dtype=np.float32)
+ elif self.depth_path.endswith('.h5'):
+ depth_h5 = tables.open_file(self.depth_path, mode='r')
+ _depth = np.array(depth_h5.root.depth)
+ depth_h5.close()
+ else:
+ raise ValueError
+ return _depth.astype(np.float32)
+
+ def read_depth_to_ram(self) -> int:
+ # raise NotImplementedError
+ assert self._depth is None
+ _depth = self.depth_map
+ self._depth = _depth
+ return self._depth.nbytes
+
+ @property
+ def depth_map(self):
+ if self._depth is not None:
+ _depth = self._depth
+ else:
+ _depth = self.read_depth()
+ if self.rotation != 0:
+ _depth = rotate_image(_depth, self.rotation, interpolation=cv2.INTER_NEAREST)
+ if _depth.shape != self.pinhole_cam_before.shape:
+ _depth = np.array(PIL.Image.fromarray(_depth).resize(self.pinhole_cam_before.shape[::-1], resample=PIL.Image.NEAREST))
+ assert _depth.shape[:2] == self.pinhole_cam_before.shape
+ if self.crop_cam == 'no_crop':
+ pass
+ elif self.crop_cam == 'crop_center':
+ _depth = crop_center_max(_depth)
+ elif self.crop_cam == 'crop_center_and_resize':
+ _depth = crop_center_max(_depth)
+ _depth = np.array(PIL.Image.fromarray(_depth).resize((MAX_SIZE, MAX_SIZE), resample=PIL.Image.NEAREST))
+ elif isinstance(self.crop_cam, CropCamConfig):
+ assert _depth.shape[0] == self.crop_cam.orig_h
+ assert _depth.shape[1] == self.crop_cam.orig_w
+ _depth = _depth[self.crop_cam.y:self.crop_cam.y + self.crop_cam.h,
+ self.crop_cam.x:self.crop_cam.x + self.crop_cam.w, ]
+ _depth = np.array(PIL.Image.fromarray(_depth).resize((self.crop_cam.out_w, self.crop_cam.out_h), resample=PIL.Image.NEAREST))
+ assert _depth.shape[:2] == (self.crop_cam.out_h, self.crop_cam.out_w)
+ else:
+ raise ValueError()
+ assert (_depth >= 0).all()
+ return _depth
+
+
+################ Pinhole Capture ################
+class BasePinholeCapture():
+ def __init__(self, pinhole_cam, cam_pose, crop_cam):
+ self.crop_cam = crop_cam
+ self.cam_pose = cam_pose
+        # modify the camera intrinsics
+ self.pinhole_cam = crop_pinhole_camera(pinhole_cam, crop_cam)
+ self.pinhole_cam_before = pinhole_cam
+
+ def __str__(self):
+ string = 'pinhole camera: {0}\ncamera pose: {1}'.format(self.pinhole_cam, self.cam_pose)
+ return string
+
+ @property
+ def intrinsic_mat(self):
+ return self.pinhole_cam.intrinsic_mat
+
+ @property
+ def extrinsic_mat(self):
+ return self.cam_pose.extrinsic_mat
+
+ @property
+ def shape(self):
+ return self.pinhole_cam.shape
+
+ @property
+ def size(self):
+ return self.shape
+
+ @property
+ def mvp_mat(self):
+ '''
+        model-view-projection matrix (naming from OpenGL)
+ '''
+ return np.matmul(self.pinhole_cam.intrinsic_mat, self.cam_pose.world_to_camera_3x4)
+
+
+class RGBPinholeCapture(BasePinholeCapture):
+ def __init__(self, img_path, pinhole_cam, cam_pose, crop_cam):
+ BasePinholeCapture.__init__(self, pinhole_cam, cam_pose, crop_cam)
+ self.captured_image = CapturedImage(img_path, crop_cam, self.pinhole_cam_before)
+
+ def read_image_to_ram(self) -> int:
+ return self.captured_image.read_image_to_ram()
+
+ @property
+ def img_path(self):
+ return self.captured_image.img_path
+
+ @property
+ def image(self):
+ _image = self.captured_image.image
+ assert _image.shape[0:2] == self.pinhole_cam.shape, 'image shape: {0}, pinhole camera: {1}'.format(_image.shape, self.pinhole_cam)
+ return _image
+
+ @property
+ def seq_id(self):
+ return os.path.dirname(self.captured_image.img_path)
+
+ @property
+ def p2d(self):
+ return self.captured_image.p2d
+
+ @p2d.setter
+ def p2d(self, value):
+ self.captured_image.p2d = value
+
+
+class DepthPinholeCapture(BasePinholeCapture):
+ def __init__(self, depth_path, pinhole_cam, cam_pose, crop_cam):
+ BasePinholeCapture.__init__(self, pinhole_cam, cam_pose, crop_cam)
+ self.captured_depth = CapturedDepth(depth_path, crop_cam, self.pinhole_cam_before)
+
+ def read_depth_to_ram(self) -> int:
+ return self.captured_depth.read_depth_to_ram()
+
+ @property
+ def depth_path(self):
+ return self.captured_depth.depth_path
+
+ @property
+ def depth_map(self):
+ _depth = self.captured_depth.depth_map
+ # if self.pinhole_cam.shape != _depth.shape:
+ # _depth = misc.imresize(_depth, self.pinhole_cam.shape, interp='nearest', mode='F')
+ assert (_depth >= 0).all()
+ return _depth
+
+ @property
+ def point_cloud_world(self):
+ return self.get_point_cloud_world_from_depth(feat_map=None)
+
+ def get_point_cloud_world_from_depth(self, feat_map=None):
+ _pcd = pcd_projector.PointCloudProjector.img_2d_to_pcd_3d_np(self.depth_map, self.pinhole_cam.intrinsic_mat, img=feat_map, motion=self.cam_pose.camera_to_world).astype(constants.DEFAULT_PRECISION)
+ return _pcd
+
+
+class RGBDPinholeCapture(RGBPinholeCapture, DepthPinholeCapture):
+ def __init__(self, img_path, depth_path, pinhole_cam, cam_pose, crop_cam):
+ RGBPinholeCapture.__init__(self, img_path, pinhole_cam, cam_pose, crop_cam)
+ DepthPinholeCapture.__init__(self, depth_path, pinhole_cam, cam_pose, crop_cam)
+
+ @property
+ def point_cloud_w_rgb_world(self):
+ return self.get_point_cloud_world_from_depth(feat_map=self.image)
+
+
+def rotate_capture(cap, rot):
+ if rot == 0:
+ return copy.deepcopy(cap)
+ else:
+ rot_pose = rotate_camera_pose(cap.cam_pose, rot)
+ rot_cap = copy.deepcopy(cap)
+ rot_cap.cam_pose = rot_pose
+ if hasattr(rot_cap, 'captured_image'):
+ rot_cap.captured_image.rotation = rot
+ if hasattr(rot_cap, 'captured_depth'):
+ rot_cap.captured_depth.rotation = rot
+ return rot_cap
+
+
+def crop_capture(cap, crop_cam):
+ if isinstance(cap, RGBDPinholeCapture):
+ cropped_cap = RGBDPinholeCapture(cap.img_path, cap.depth_path, cap.pinhole_cam, cap.cam_pose, crop_cam)
+ elif isinstance(cap, RGBPinholeCapture):
+ cropped_cap = RGBPinholeCapture(cap.img_path, cap.pinhole_cam, cap.cam_pose, crop_cam)
+ else:
+ raise ValueError
+ if hasattr(cropped_cap, 'captured_image'):
+ cropped_cap.captured_image.rotation = cap.captured_image.rotation
+ if hasattr(cropped_cap, 'captured_depth'):
+ cropped_cap.captured_depth.rotation = cap.captured_depth.rotation
+ return cropped_cap
diff --git a/imcui/third_party/COTR/COTR/cameras/pinhole_camera.py b/imcui/third_party/COTR/COTR/cameras/pinhole_camera.py
new file mode 100644
index 0000000000000000000000000000000000000000..2f06cc167ba4a56092c9bdc7e15dc77f245f5d79
--- /dev/null
+++ b/imcui/third_party/COTR/COTR/cameras/pinhole_camera.py
@@ -0,0 +1,73 @@
+"""
+Static pinhole camera
+"""
+
+import copy
+
+import numpy as np
+
+from COTR.utils import constants
+from COTR.utils.constants import MAX_SIZE
+from COTR.utils.utils import CropCamConfig
+
+
+class PinholeCamera():
+ def __init__(self, width, height, fx, fy, cx, cy):
+ self.width = int(width)
+ self.height = int(height)
+ self.fx = fx
+ self.fy = fy
+ self.cx = cx
+ self.cy = cy
+
+ def __str__(self):
+ string = 'width: {0}, height: {1}, fx: {2}, fy: {3}, cx: {4}, cy: {5}'.format(self.width, self.height, self.fx, self.fy, self.cx, self.cy)
+ return string
+
+ @property
+ def shape(self):
+ return (self.height, self.width)
+
+ @property
+ def intrinsic_mat(self):
+ mat = np.array([[self.fx, 0.0, self.cx],
+ [0.0, self.fy, self.cy],
+ [0.0, 0.0, 1.0]], dtype=constants.DEFAULT_PRECISION)
+ return mat
+
+
+def rotate_pinhole_camera(cam, rot):
+    assert 0, 'TODO: the camera should stay the same under rotation'
+ assert rot in [0, 90, 180, 270], 'only support 0/90/180/270 degrees rotation'
+ if rot in [0, 180]:
+ return copy.deepcopy(cam)
+ elif rot in [90, 270]:
+ return PinholeCamera(width=cam.height, height=cam.width, fx=cam.fy, fy=cam.fx, cx=cam.cy, cy=cam.cx)
+ else:
+ raise NotImplementedError
+
+
+def crop_pinhole_camera(pinhole_cam, crop_cam):
+ if crop_cam == 'no_crop':
+ cropped_pinhole_cam = pinhole_cam
+ elif crop_cam == 'crop_center':
+ _h = _w = min(*pinhole_cam.shape)
+ _cx = _cy = _h / 2
+ cropped_pinhole_cam = PinholeCamera(_w, _h, pinhole_cam.fx, pinhole_cam.fy, _cx, _cy)
+ elif crop_cam == 'crop_center_and_resize':
+ _h = _w = MAX_SIZE
+ _cx = _cy = MAX_SIZE / 2
+ scale = MAX_SIZE / min(*pinhole_cam.shape)
+ cropped_pinhole_cam = PinholeCamera(_w, _h, pinhole_cam.fx * scale, pinhole_cam.fy * scale, _cx, _cy)
+ elif isinstance(crop_cam, CropCamConfig):
+ scale = crop_cam.out_h / crop_cam.h
+ cropped_pinhole_cam = PinholeCamera(crop_cam.out_w,
+ crop_cam.out_h,
+ pinhole_cam.fx * scale,
+ pinhole_cam.fy * scale,
+ (pinhole_cam.cx - crop_cam.x) * scale,
+ (pinhole_cam.cy - crop_cam.y) * scale
+ )
+ else:
+ raise ValueError
+ return cropped_pinhole_cam
diff --git a/imcui/third_party/COTR/COTR/datasets/colmap_helper.py b/imcui/third_party/COTR/COTR/datasets/colmap_helper.py
new file mode 100644
index 0000000000000000000000000000000000000000..d84ed8645ba156fce7069361609d8cd4f577cfe4
--- /dev/null
+++ b/imcui/third_party/COTR/COTR/datasets/colmap_helper.py
@@ -0,0 +1,312 @@
+import sys
+assert sys.version_info >= (3, 7), 'ordered dict is required'
+import os
+import re
+from collections import namedtuple
+import json
+
+import numpy as np
+from tqdm import tqdm
+
+from COTR.utils import debug_utils
+from COTR.cameras.pinhole_camera import PinholeCamera
+from COTR.cameras.camera_pose import CameraPose
+from COTR.cameras.capture import RGBPinholeCapture, RGBDPinholeCapture
+from COTR.cameras import capture
+from COTR.transformations import transformations
+from COTR.transformations.transform_basics import Translation, Rotation
+from COTR.sfm_scenes import sfm_scenes
+from COTR.global_configs import dataset_config
+from COTR.utils.utils import Point2D, Point3D
+
+ImageMeta = namedtuple('ImageMeta', ['image_id', 'r', 't', 'camera_id', 'image_path', 'point3d_id', 'p2d'])
+COVISIBILITY_CHECK = False
+LOAD_PCD = False
+
+
+class ColmapAsciiReader():
+ def __init__(self):
+ pass
+
+ @classmethod
+ def read_sfm_scene(cls, scene_dir, images_dir, crop_cam):
+ point_cloud_path = os.path.join(scene_dir, 'points3D.txt')
+ cameras_path = os.path.join(scene_dir, 'cameras.txt')
+ images_path = os.path.join(scene_dir, 'images.txt')
+ captures = cls.read_captures(images_path, cameras_path, images_dir, crop_cam)
+ if LOAD_PCD:
+ point_cloud = cls.read_point_cloud(point_cloud_path)
+ else:
+ point_cloud = None
+ sfm_scene = sfm_scenes.SfmScene(captures, point_cloud)
+ return sfm_scene
+
+ @staticmethod
+ def read_point_cloud(points_txt_path):
+ with open(points_txt_path, "r") as fid:
+ line = fid.readline()
+ assert line == '# 3D point list with one line of data per point:\n'
+ line = fid.readline()
+ assert line == '# POINT3D_ID, X, Y, Z, R, G, B, ERROR, TRACK[] as (IMAGE_ID, POINT2D_IDX)\n'
+ line = fid.readline()
+            assert re.search(r'^# Number of points: \d+, mean track length: [-+]?\d*\.\d+|\d+\n$', line)
+ num_points, mean_track_length = re.findall(r"[-+]?\d*\.\d+|\d+", line)
+ num_points = int(num_points)
+ mean_track_length = float(mean_track_length)
+
+ xyz = np.zeros((num_points, 3), dtype=np.float32)
+ rgb = np.zeros((num_points, 3), dtype=np.float32)
+ if COVISIBILITY_CHECK:
+ point_meta = {}
+
+ for i in tqdm(range(num_points), desc='reading point cloud'):
+ elems = fid.readline().split()
+ xyz[i] = list(map(float, elems[1:4]))
+ rgb[i] = list(map(int, elems[4:7]))
+ if COVISIBILITY_CHECK:
+ point_id = int(elems[0])
+ image_ids = np.array(tuple(map(int, elems[8::2])))
+ point_meta[point_id] = Point3D(id=point_id,
+ arr_idx=i,
+ image_ids=image_ids)
+ pcd = np.concatenate([xyz, rgb], axis=1)
+ if COVISIBILITY_CHECK:
+ return pcd, point_meta
+ else:
+ return pcd
+
+ @classmethod
+ def read_captures(cls, images_txt_path, cameras_txt_path, images_dir, crop_cam):
+ captures = []
+ cameras = cls.read_cameras(cameras_txt_path)
+ images_meta = cls.read_images_meta(images_txt_path, images_dir)
+ for key in images_meta.keys():
+ cur_cam_id = images_meta[key].camera_id
+ cur_cam = cameras[cur_cam_id]
+ cur_camera_pose = CameraPose(images_meta[key].t, images_meta[key].r)
+ cur_image_path = images_meta[key].image_path
+ cap = RGBPinholeCapture(cur_image_path, cur_cam, cur_camera_pose, crop_cam)
+ captures.append(cap)
+ return captures
+
+ @classmethod
+ def read_cameras(cls, cameras_txt_path):
+ cameras = {}
+ with open(cameras_txt_path, "r") as fid:
+ line = fid.readline()
+ assert line == '# Camera list with one line of data per camera:\n'
+ line = fid.readline()
+ assert line == '# CAMERA_ID, MODEL, WIDTH, HEIGHT, PARAMS[]\n'
+ line = fid.readline()
+            assert re.search(r'^# Number of cameras: \d+\n$', line)
+ num_cams = int(re.findall(r"[-+]?\d*\.\d+|\d+", line)[0])
+
+ for _ in tqdm(range(num_cams), desc='reading cameras'):
+ elems = fid.readline().split()
+ camera_id = int(elems[0])
+ camera_type = elems[1]
+ if camera_type == "PINHOLE":
+ width, height, focal_length_x, focal_length_y, cx, cy = list(map(float, elems[2:8]))
+ else:
+ raise ValueError('Please rectify the 3D model to pinhole camera.')
+ cur_cam = PinholeCamera(width, height, focal_length_x, focal_length_y, cx, cy)
+ assert camera_id not in cameras
+ cameras[camera_id] = cur_cam
+ return cameras
+
+ @classmethod
+ def read_images_meta(cls, images_txt_path, images_dir):
+ images_meta = {}
+ with open(images_txt_path, "r") as fid:
+ line = fid.readline()
+ assert line == '# Image list with two lines of data per image:\n'
+ line = fid.readline()
+ assert line == '# IMAGE_ID, QW, QX, QY, QZ, TX, TY, TZ, CAMERA_ID, NAME\n'
+ line = fid.readline()
+ assert line == '# POINTS2D[] as (X, Y, POINT3D_ID)\n'
+ line = fid.readline()
+            assert re.search(r'^# Number of images: \d+, mean observations per image: [-+]?\d*\.\d+|\d+\n$', line)
+ num_images, mean_ob_per_img = re.findall(r"[-+]?\d*\.\d+|\d+", line)
+ num_images = int(num_images)
+ mean_ob_per_img = float(mean_ob_per_img)
+
+ for _ in tqdm(range(num_images), desc='reading images meta'):
+ elems = fid.readline().split()
+ assert len(elems) == 10
+
+ image_path = os.path.join(images_dir, elems[9])
+ assert os.path.isfile(image_path)
+ image_id = int(elems[0])
+ qw, qx, qy, qz, tx, ty, tz = list(map(float, elems[1:8]))
+ t = Translation(np.array([tx, ty, tz], dtype=np.float32))
+ r = Rotation(np.array([qw, qx, qy, qz], dtype=np.float32))
+ camera_id = int(elems[8])
+ assert image_id not in images_meta
+
+ line = fid.readline()
+ if COVISIBILITY_CHECK:
+ elems = line.split()
+ elems = list(map(float, elems))
+ elems = np.array(elems).reshape(-1, 3)
+                    point3d_id = set(elems[elems[:, 2] != -1][:, 2].astype(int))
+ point3d_id = np.sort(np.array(list(point3d_id)))
+ xyi = elems[elems[:, 2] != -1]
+ xy = xyi[:, :2]
+                    idx = xyi[:, 2].astype(int)
+ p2d = Point2D(idx, xy)
+ else:
+ point3d_id = None
+ p2d = None
+
+ images_meta[image_id] = ImageMeta(image_id, r, t, camera_id, image_path, point3d_id, p2d)
+ return images_meta
+
+
+class ColmapWithDepthAsciiReader(ColmapAsciiReader):
+ '''
+    Not all images have a usable depth estimate from COLMAP.
+ A valid list is needed.
+ '''
+
+ @classmethod
+ def read_sfm_scene(cls, scene_dir, images_dir, depth_dir, crop_cam):
+ point_cloud_path = os.path.join(scene_dir, 'points3D.txt')
+ cameras_path = os.path.join(scene_dir, 'cameras.txt')
+ images_path = os.path.join(scene_dir, 'images.txt')
+ captures = cls.read_captures(images_path, cameras_path, images_dir, depth_dir, crop_cam)
+ if LOAD_PCD:
+ point_cloud = cls.read_point_cloud(point_cloud_path)
+ else:
+ point_cloud = None
+ sfm_scene = sfm_scenes.SfmScene(captures, point_cloud)
+ return sfm_scene
+
+ @classmethod
+ def read_sfm_scene_given_valid_list_path(cls, scene_dir, images_dir, depth_dir, valid_list_json_path, crop_cam):
+ point_cloud_path = os.path.join(scene_dir, 'points3D.txt')
+ cameras_path = os.path.join(scene_dir, 'cameras.txt')
+ images_path = os.path.join(scene_dir, 'images.txt')
+ valid_list = cls.read_valid_list(valid_list_json_path)
+ captures = cls.read_captures_with_depth_given_valid_list(images_path, cameras_path, images_dir, depth_dir, valid_list, crop_cam)
+ if LOAD_PCD:
+ point_cloud = cls.read_point_cloud(point_cloud_path)
+ else:
+ point_cloud = None
+ sfm_scene = sfm_scenes.SfmScene(captures, point_cloud)
+ return sfm_scene
+
+ @classmethod
+ def read_captures(cls, images_txt_path, cameras_txt_path, images_dir, depth_dir, crop_cam):
+ captures = []
+ cameras = cls.read_cameras(cameras_txt_path)
+ images_meta = cls.read_images_meta(images_txt_path, images_dir)
+ for key in images_meta.keys():
+ cur_cam_id = images_meta[key].camera_id
+ cur_cam = cameras[cur_cam_id]
+ cur_camera_pose = CameraPose(images_meta[key].t, images_meta[key].r)
+ cur_image_path = images_meta[key].image_path
+ try:
+ cur_depth_path = cls.image_path_2_depth_path(cur_image_path[len(images_dir) + 1:], depth_dir)
+            except Exception:
+ print('{0} does not have depth at {1}'.format(cur_image_path, depth_dir))
+ # TODO
+ # continue
+ # exec(debug_utils.embed_breakpoint())
+ cur_depth_path = f'{cur_image_path}dummy'
+
+ cap = RGBDPinholeCapture(cur_image_path, cur_depth_path, cur_cam, cur_camera_pose, crop_cam)
+ cap.point3d_id = images_meta[key].point3d_id
+ cap.p2d = images_meta[key].p2d
+ cap.image_id = key
+ captures.append(cap)
+ return captures
+
+ @classmethod
+ def read_captures_with_depth_given_valid_list(cls, images_txt_path, cameras_txt_path, images_dir, depth_dir, valid_list, crop_cam):
+ captures = []
+ cameras = cls.read_cameras(cameras_txt_path)
+ images_meta = cls.read_images_meta_given_valid_list(images_txt_path, images_dir, valid_list)
+ for key in images_meta.keys():
+ cur_cam_id = images_meta[key].camera_id
+ cur_cam = cameras[cur_cam_id]
+ cur_camera_pose = CameraPose(images_meta[key].t, images_meta[key].r)
+ cur_image_path = images_meta[key].image_path
+ try:
+ cur_depth_path = cls.image_path_2_depth_path(cur_image_path, depth_dir)
+            except Exception:
+ print('{0} does not have depth at {1}'.format(cur_image_path, depth_dir))
+ continue
+ cap = RGBDPinholeCapture(cur_image_path, cur_depth_path, cur_cam, cur_camera_pose, crop_cam)
+ cap.point3d_id = images_meta[key].point3d_id
+ cap.p2d = images_meta[key].p2d
+ cap.image_id = key
+ captures.append(cap)
+ return captures
+
+ @classmethod
+ def read_images_meta_given_valid_list(cls, images_txt_path, images_dir, valid_list):
+ images_meta = {}
+ with open(images_txt_path, "r") as fid:
+ line = fid.readline()
+ assert line == '# Image list with two lines of data per image:\n'
+ line = fid.readline()
+ assert line == '# IMAGE_ID, QW, QX, QY, QZ, TX, TY, TZ, CAMERA_ID, NAME\n'
+ line = fid.readline()
+ assert line == '# POINTS2D[] as (X, Y, POINT3D_ID)\n'
+ line = fid.readline()
+            assert re.search(r'^# Number of images: \d+, mean observations per image:[-+]?\d*\.\d+|\d+\n$', line), line
+ num_images, mean_ob_per_img = re.findall(r"[-+]?\d*\.\d+|\d+", line)
+ num_images = int(num_images)
+ mean_ob_per_img = float(mean_ob_per_img)
+
+ for _ in tqdm(range(num_images), desc='reading images meta'):
+ elems = fid.readline().split()
+ assert len(elems) == 10
+ line = fid.readline()
+ image_path = os.path.join(images_dir, elems[9])
+ prefix = os.path.abspath(os.path.join(image_path, '../../../../')) + '/'
+ rel_image_path = image_path.replace(prefix, '')
+ if rel_image_path not in valid_list:
+ continue
+                assert os.path.isfile(image_path), '{0} does not exist'.format(image_path)
+ image_id = int(elems[0])
+ qw, qx, qy, qz, tx, ty, tz = list(map(float, elems[1:8]))
+ t = Translation(np.array([tx, ty, tz], dtype=np.float32))
+ r = Rotation(np.array([qw, qx, qy, qz], dtype=np.float32))
+ camera_id = int(elems[8])
+ assert image_id not in images_meta
+
+ if COVISIBILITY_CHECK:
+ elems = line.split()
+ elems = list(map(float, elems))
+ elems = np.array(elems).reshape(-1, 3)
+                    point3d_id = set(elems[elems[:, 2] != -1][:, 2].astype(int))
+ point3d_id = np.sort(np.array(list(point3d_id)))
+ xyi = elems[elems[:, 2] != -1]
+ xy = xyi[:, :2]
+                    idx = xyi[:, 2].astype(int)
+ p2d = Point2D(idx, xy)
+ else:
+ point3d_id = None
+ p2d = None
+ images_meta[image_id] = ImageMeta(image_id, r, t, camera_id, image_path, point3d_id, p2d)
+ return images_meta
+
+ @classmethod
+ def read_valid_list(cls, valid_list_json_path):
+ assert os.path.isfile(valid_list_json_path), valid_list_json_path
+ with open(valid_list_json_path, 'r') as f:
+ valid_list = json.load(f)
+ assert len(valid_list) == len(set(valid_list))
+ return set(valid_list)
+
+ @classmethod
+ def image_path_2_depth_path(cls, image_path, depth_dir):
+ depth_file = os.path.splitext(os.path.basename(image_path))[0] + '.h5'
+ depth_path = os.path.join(depth_dir, depth_file)
+ if not os.path.isfile(depth_path):
+ # depth_file = image_path + '.photometric.bin'
+ depth_file = image_path + '.geometric.bin'
+ depth_path = os.path.join(depth_dir, depth_file)
+        assert os.path.isfile(depth_path), '{0} is not a file'.format(depth_path)
+ return depth_path
diff --git a/imcui/third_party/COTR/COTR/datasets/cotr_dataset.py b/imcui/third_party/COTR/COTR/datasets/cotr_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..f99cfa54fc30fcad0ed91674912cbc22f4cb4564
--- /dev/null
+++ b/imcui/third_party/COTR/COTR/datasets/cotr_dataset.py
@@ -0,0 +1,243 @@
+'''
+COTR dataset
+'''
+
+import random
+
+import numpy as np
+import torch
+from torchvision.transforms import functional as tvtf
+from torch.utils import data
+
+from COTR.datasets import megadepth_dataset
+from COTR.utils import debug_utils, utils, constants
+from COTR.projector import pcd_projector
+from COTR.cameras import capture
+from COTR.utils.utils import CropCamConfig
+from COTR.inference import inference_helper
+from COTR.inference.inference_helper import two_images_side_by_side
+
+
+class COTRDataset(data.Dataset):
+ def __init__(self, opt, dataset_type: str):
+ assert dataset_type in ['train', 'val', 'test']
+ assert len(opt.scenes_name_list) > 0
+ self.opt = opt
+ self.dataset_type = dataset_type
+ self.sfm_dataset = megadepth_dataset.MegadepthDataset(opt, dataset_type)
+
+ self.kp_pool = opt.kp_pool
+ self.num_kp = opt.num_kp
+ self.bidirectional = opt.bidirectional
+ self.need_rotation = opt.need_rotation
+ self.max_rotation = opt.max_rotation
+ self.rotation_chance = opt.rotation_chance
+
+ def _trim_corrs(self, in_corrs):
+ length = in_corrs.shape[0]
+ if length >= self.num_kp:
+ mask = np.random.choice(length, self.num_kp)
+ return in_corrs[mask]
+ else:
+ mask = np.random.choice(length, self.num_kp - length)
+ return np.concatenate([in_corrs, in_corrs[mask]], axis=0)
+
+ def __len__(self):
+ if self.dataset_type == 'val':
+ return min(1000, self.sfm_dataset.num_queries)
+ else:
+ return self.sfm_dataset.num_queries
+
+ def augment_with_rotation(self, query_cap, nn_cap):
+ if random.random() < self.rotation_chance:
+ theta = np.random.uniform(low=-1, high=1) * self.max_rotation
+ query_cap = capture.rotate_capture(query_cap, theta)
+ if random.random() < self.rotation_chance:
+ theta = np.random.uniform(low=-1, high=1) * self.max_rotation
+ nn_cap = capture.rotate_capture(nn_cap, theta)
+ return query_cap, nn_cap
+
+ def __getitem__(self, index):
+ assert self.opt.k_size == 1
+ query_cap, nn_caps = self.sfm_dataset.get_query_with_knn(index)
+ nn_cap = nn_caps[0]
+
+ if self.need_rotation:
+ query_cap, nn_cap = self.augment_with_rotation(query_cap, nn_cap)
+
+ nn_keypoints_y, nn_keypoints_x = np.where(nn_cap.depth_map > 0)
+ nn_keypoints_y = nn_keypoints_y[..., None]
+ nn_keypoints_x = nn_keypoints_x[..., None]
+ nn_keypoints_z = nn_cap.depth_map[np.floor(nn_keypoints_y).astype('int'), np.floor(nn_keypoints_x).astype('int')]
+ nn_keypoints_xy = np.concatenate([nn_keypoints_x, nn_keypoints_y], axis=1)
+ nn_keypoints_3d_world, valid_index_1 = pcd_projector.PointCloudProjector.pcd_2d_to_pcd_3d_np(nn_keypoints_xy, nn_keypoints_z, nn_cap.pinhole_cam.intrinsic_mat, motion=nn_cap.cam_pose.camera_to_world, return_index=True)
+
+ query_keypoints_xyz, valid_index_2 = pcd_projector.PointCloudProjector.pcd_3d_to_pcd_2d_np(
+ nn_keypoints_3d_world,
+ query_cap.pinhole_cam.intrinsic_mat,
+ query_cap.cam_pose.world_to_camera[0:3, :],
+ query_cap.image.shape[:2],
+ keep_z=True,
+ crop=True,
+ filter_neg=True,
+ norm_coord=False,
+ return_index=True,
+ )
+ query_keypoints_xy = query_keypoints_xyz[:, 0:2]
+ query_keypoints_z_proj = query_keypoints_xyz[:, 2:3]
+ query_keypoints_z = query_cap.depth_map[np.floor(query_keypoints_xy[:, 1:2]).astype('int'), np.floor(query_keypoints_xy[:, 0:1]).astype('int')]
+ mask = (abs(query_keypoints_z - query_keypoints_z_proj) < 0.5)[:, 0]
+ query_keypoints_xy = query_keypoints_xy[mask]
+
+ if query_keypoints_xy.shape[0] < self.num_kp:
+ return self.__getitem__(random.randint(0, self.__len__() - 1))
+
+ nn_keypoints_xy = nn_keypoints_xy[valid_index_1][valid_index_2][mask]
+ assert nn_keypoints_xy.shape == query_keypoints_xy.shape
+ corrs = np.concatenate([query_keypoints_xy, nn_keypoints_xy], axis=1)
+ corrs = self._trim_corrs(corrs)
+ # flip augmentation
+ if np.random.uniform() < 0.5:
+ corrs[:, 0] = constants.MAX_SIZE - 1 - corrs[:, 0]
+ corrs[:, 2] = constants.MAX_SIZE - 1 - corrs[:, 2]
+ sbs_img = two_images_side_by_side(np.fliplr(query_cap.image), np.fliplr(nn_cap.image))
+ else:
+ sbs_img = two_images_side_by_side(query_cap.image, nn_cap.image)
+ corrs[:, 2] += constants.MAX_SIZE
+ corrs /= np.array([constants.MAX_SIZE * 2, constants.MAX_SIZE, constants.MAX_SIZE * 2, constants.MAX_SIZE])
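+        # After normalization the query coordinates live in the left half of the side-by-side
+        # image (x in [0, 0.5]) and the nearest-neighbour coordinates in the right half
+        # (x in [0.5, 1]); the assertions below double-check this.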
+ assert (0.0 <= corrs[:, 0]).all() and (corrs[:, 0] <= 0.5).all()
+ assert (0.0 <= corrs[:, 1]).all() and (corrs[:, 1] <= 1.0).all()
+ assert (0.5 <= corrs[:, 2]).all() and (corrs[:, 2] <= 1.0).all()
+ assert (0.0 <= corrs[:, 3]).all() and (corrs[:, 3] <= 1.0).all()
+ out = {
+ 'image': tvtf.normalize(tvtf.to_tensor(sbs_img), (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
+ 'corrs': torch.from_numpy(corrs).float(),
+ }
+ if self.bidirectional:
+ out['queries'] = torch.from_numpy(np.concatenate([corrs[:, :2], corrs[:, 2:]], axis=0)).float()
+ out['targets'] = torch.from_numpy(np.concatenate([corrs[:, 2:], corrs[:, :2]], axis=0)).float()
+ else:
+ out['queries'] = torch.from_numpy(corrs[:, :2]).float()
+ out['targets'] = torch.from_numpy(corrs[:, 2:]).float()
+ return out
+
+
+class COTRZoomDataset(COTRDataset):
+ def __init__(self, opt, dataset_type: str):
+ assert opt.crop_cam in ['no_crop', 'crop_center']
+ assert opt.use_ram == False
+ super().__init__(opt, dataset_type)
+ self.zoom_start = opt.zoom_start
+ self.zoom_end = opt.zoom_end
+ self.zoom_levels = opt.zoom_levels
+ self.zoom_jitter = opt.zoom_jitter
+ self.zooms = np.logspace(np.log10(opt.zoom_start),
+ np.log10(opt.zoom_end),
+ num=opt.zoom_levels)
+
+ def get_corrs(self, from_cap, to_cap, reduced_size=None):
+ from_y, from_x = np.where(from_cap.depth_map > 0)
+ from_y, from_x = from_y[..., None], from_x[..., None]
+ if reduced_size is not None:
+ filter_idx = np.random.choice(from_y.shape[0], reduced_size, replace=False)
+ from_y, from_x = from_y[filter_idx], from_x[filter_idx]
+ from_z = from_cap.depth_map[np.floor(from_y).astype('int'), np.floor(from_x).astype('int')]
+ from_xy = np.concatenate([from_x, from_y], axis=1)
+ from_3d_world, valid_index_1 = pcd_projector.PointCloudProjector.pcd_2d_to_pcd_3d_np(from_xy, from_z, from_cap.pinhole_cam.intrinsic_mat, motion=from_cap.cam_pose.camera_to_world, return_index=True)
+
+ to_xyz, valid_index_2 = pcd_projector.PointCloudProjector.pcd_3d_to_pcd_2d_np(
+ from_3d_world,
+ to_cap.pinhole_cam.intrinsic_mat,
+ to_cap.cam_pose.world_to_camera[0:3, :],
+ to_cap.image.shape[:2],
+ keep_z=True,
+ crop=True,
+ filter_neg=True,
+ norm_coord=False,
+ return_index=True,
+ )
+
+ to_xy = to_xyz[:, 0:2]
+ to_z_proj = to_xyz[:, 2:3]
+ to_z = to_cap.depth_map[np.floor(to_xy[:, 1:2]).astype('int'), np.floor(to_xy[:, 0:1]).astype('int')]
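+        # Keep only points whose depth-map value agrees with the reprojected depth to within
+        # 0.5 (in the scene's depth units); this filters occluded or badly reconstructed pixels.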
+ mask = (abs(to_z - to_z_proj) < 0.5)[:, 0]
+ if mask.sum() > 0:
+ return np.concatenate([from_xy[valid_index_1][valid_index_2][mask], to_xy[mask]], axis=1)
+ else:
+ return None
+
+ def get_seed_corr(self, from_cap, to_cap, max_try=100):
+ seed_corr = self.get_corrs(from_cap, to_cap, reduced_size=max_try)
+ if seed_corr is None:
+ return None
+ shuffle = np.random.permutation(seed_corr.shape[0])
+ seed_corr = np.take(seed_corr, shuffle, axis=0)
+ return seed_corr[0]
+
+ def get_zoomed_cap(self, cap, pos, scale, jitter):
+ patch = inference_helper.get_patch_centered_at(cap.image, pos, scale=scale, return_content=False)
+ patch = inference_helper.get_patch_centered_at(cap.image,
+ pos + np.array([patch.w, patch.h]) * np.random.uniform(-jitter, jitter, 2),
+ scale=scale,
+ return_content=False)
+ zoom_config = CropCamConfig(x=patch.x,
+ y=patch.y,
+ w=patch.w,
+ h=patch.h,
+ out_w=constants.MAX_SIZE,
+ out_h=constants.MAX_SIZE,
+ orig_w=cap.shape[1],
+ orig_h=cap.shape[0])
+ zoom_cap = capture.crop_capture(cap, zoom_config)
+ return zoom_cap
+
+ def __getitem__(self, index):
+ assert self.opt.k_size == 1
+ query_cap, nn_caps = self.sfm_dataset.get_query_with_knn(index)
+ nn_cap = nn_caps[0]
+ if self.need_rotation:
+ query_cap, nn_cap = self.augment_with_rotation(query_cap, nn_cap)
+
+ # find seed
+ seed_corr = self.get_seed_corr(nn_cap, query_cap)
+ if seed_corr is None:
+ return self.__getitem__(random.randint(0, self.__len__() - 1))
+
+ # crop cap
+ s = np.random.choice(self.zooms)
+ nn_zoom_cap = self.get_zoomed_cap(nn_cap, seed_corr[:2], s, 0)
+ query_zoom_cap = self.get_zoomed_cap(query_cap, seed_corr[2:], s, self.zoom_jitter)
+ assert nn_zoom_cap.shape == query_zoom_cap.shape == (constants.MAX_SIZE, constants.MAX_SIZE)
+ corrs = self.get_corrs(query_zoom_cap, nn_zoom_cap)
+ if corrs is None or corrs.shape[0] < self.num_kp:
+ return self.__getitem__(random.randint(0, self.__len__() - 1))
+ shuffle = np.random.permutation(corrs.shape[0])
+ corrs = np.take(corrs, shuffle, axis=0)
+ corrs = self._trim_corrs(corrs)
+
+ # flip augmentation
+ if np.random.uniform() < 0.5:
+ corrs[:, 0] = constants.MAX_SIZE - 1 - corrs[:, 0]
+ corrs[:, 2] = constants.MAX_SIZE - 1 - corrs[:, 2]
+ sbs_img = two_images_side_by_side(np.fliplr(query_zoom_cap.image), np.fliplr(nn_zoom_cap.image))
+ else:
+ sbs_img = two_images_side_by_side(query_zoom_cap.image, nn_zoom_cap.image)
+
+ corrs[:, 2] += constants.MAX_SIZE
+ corrs /= np.array([constants.MAX_SIZE * 2, constants.MAX_SIZE, constants.MAX_SIZE * 2, constants.MAX_SIZE])
+ assert (0.0 <= corrs[:, 0]).all() and (corrs[:, 0] <= 0.5).all()
+ assert (0.0 <= corrs[:, 1]).all() and (corrs[:, 1] <= 1.0).all()
+ assert (0.5 <= corrs[:, 2]).all() and (corrs[:, 2] <= 1.0).all()
+ assert (0.0 <= corrs[:, 3]).all() and (corrs[:, 3] <= 1.0).all()
+ out = {
+ 'image': tvtf.normalize(tvtf.to_tensor(sbs_img), (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
+ 'corrs': torch.from_numpy(corrs).float(),
+ }
+ if self.bidirectional:
+ out['queries'] = torch.from_numpy(np.concatenate([corrs[:, :2], corrs[:, 2:]], axis=0)).float()
+ out['targets'] = torch.from_numpy(np.concatenate([corrs[:, 2:], corrs[:, :2]], axis=0)).float()
+ else:
+ out['queries'] = torch.from_numpy(corrs[:, :2]).float()
+ out['targets'] = torch.from_numpy(corrs[:, 2:]).float()
+
+ return out
diff --git a/imcui/third_party/COTR/COTR/datasets/megadepth_dataset.py b/imcui/third_party/COTR/COTR/datasets/megadepth_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..cc1ac93e369f6ee9998b693f383581403d3369c4
--- /dev/null
+++ b/imcui/third_party/COTR/COTR/datasets/megadepth_dataset.py
@@ -0,0 +1,140 @@
+'''
+dataset specific layer for megadepth
+'''
+
+import os
+import json
+import random
+from collections import namedtuple
+
+import numpy as np
+
+from COTR.datasets import colmap_helper
+from COTR.global_configs import dataset_config
+from COTR.sfm_scenes import knn_search
+from COTR.utils import debug_utils, utils, constants
+
+SceneCapIndex = namedtuple('SceneCapIndex', ['scene_index', 'capture_index'])
+
+
+def prefix_of_img_path_for_magedepth(img_path):
+ '''
+    get the prefix for an image of the MegaDepth dataset
+ '''
+ prefix = os.path.abspath(os.path.join(img_path, '../../../..')) + '/'
+ return prefix
+
+
+class MegadepthSceneDataBase():
+ scenes = {}
+ knn_engine_dict = {}
+
+ @classmethod
+ def _load_scene(cls, opt, scene_dir_dict):
+ if scene_dir_dict['scene_dir'] not in cls.scenes:
+ if opt.info_level == 'rgb':
+ assert 0
+ elif opt.info_level == 'rgbd':
+ scene_dir = scene_dir_dict['scene_dir']
+ images_dir = scene_dir_dict['image_dir']
+ depth_dir = scene_dir_dict['depth_dir']
+ scene = colmap_helper.ColmapWithDepthAsciiReader.read_sfm_scene_given_valid_list_path(scene_dir, images_dir, depth_dir, dataset_config[opt.dataset_name]['valid_list_json'], opt.crop_cam)
+ if opt.use_ram:
+ scene.read_data_to_ram(['image', 'depth'])
+ else:
+ raise ValueError()
+ knn_engine = knn_search.ReprojRatioKnnSearch(scene)
+ cls.scenes[scene_dir_dict['scene_dir']] = scene
+ cls.knn_engine_dict[scene_dir_dict['scene_dir']] = knn_engine
+ else:
+ pass
+
+
+class MegadepthDataset():
+
+ def __init__(self, opt, dataset_type):
+ assert dataset_type in ['train', 'val', 'test']
+ assert len(opt.scenes_name_list) > 0
+ self.opt = opt
+ self.dataset_type = dataset_type
+ self.use_ram = opt.use_ram
+ self.scenes_name_list = opt.scenes_name_list
+ self.scenes = None
+ self.knn_engine_list = None
+ self.total_caps_set = None
+ self.query_caps_set = None
+ self.db_caps_set = None
+ self.img_path_to_scene_cap_index_dict = {}
+ self.scene_index_to_db_caps_mask_dict = {}
+ self._load_scenes()
+
+ @property
+ def num_scenes(self):
+ return len(self.scenes)
+
+ @property
+ def num_queries(self):
+ return len(self.query_caps_set)
+
+ @property
+ def num_db(self):
+ return len(self.db_caps_set)
+
+ def get_scene_cap_index_by_index(self, index):
+ assert index < len(self.query_caps_set)
+ img_path = sorted(list(self.query_caps_set))[index]
+ scene_cap_index = self.img_path_to_scene_cap_index_dict[img_path]
+ return scene_cap_index
+
+ def _get_common_subset_caps_from_json(self, json_path, total_caps):
+ prefix = prefix_of_img_path_for_magedepth(list(total_caps)[0])
+ with open(json_path, 'r') as f:
+ common_caps = [prefix + cap for cap in json.load(f)]
+ common_caps = set(total_caps) & set(common_caps)
+ return common_caps
+
+ def _extend_img_path_to_scene_cap_index_dict(self, img_path_to_cap_index_dict, scene_id):
+ for key in img_path_to_cap_index_dict.keys():
+ self.img_path_to_scene_cap_index_dict[key] = SceneCapIndex(scene_id, img_path_to_cap_index_dict[key])
+
+ def _create_scene_index_to_db_caps_mask_dict(self, db_caps_set):
+ scene_index_to_db_caps_mask_dict = {}
+ for cap in db_caps_set:
+ scene_id, cap_id = self.img_path_to_scene_cap_index_dict[cap]
+ if scene_id not in scene_index_to_db_caps_mask_dict:
+ scene_index_to_db_caps_mask_dict[scene_id] = []
+ scene_index_to_db_caps_mask_dict[scene_id].append(cap_id)
+ for _k, _v in scene_index_to_db_caps_mask_dict.items():
+ scene_index_to_db_caps_mask_dict[_k] = np.array(sorted(_v))
+ return scene_index_to_db_caps_mask_dict
+
+ def _load_scenes(self):
+ scenes = []
+ knn_engine_list = []
+ total_caps_set = set()
+ for scene_id, scene_dir_dict in enumerate(self.scenes_name_list):
+ MegadepthSceneDataBase._load_scene(self.opt, scene_dir_dict)
+ scene = MegadepthSceneDataBase.scenes[scene_dir_dict['scene_dir']]
+ knn_engine = MegadepthSceneDataBase.knn_engine_dict[scene_dir_dict['scene_dir']]
+ total_caps_set = total_caps_set | set(scene.img_path_to_index_dict.keys())
+ self._extend_img_path_to_scene_cap_index_dict(scene.img_path_to_index_dict, scene_id)
+ scenes.append(scene)
+ knn_engine_list.append(knn_engine)
+ self.scenes = scenes
+ self.knn_engine_list = knn_engine_list
+ self.total_caps_set = total_caps_set
+ self.query_caps_set = self._get_common_subset_caps_from_json(dataset_config[self.opt.dataset_name][f'{self.dataset_type}_json'], total_caps_set)
+ self.db_caps_set = self._get_common_subset_caps_from_json(dataset_config[self.opt.dataset_name]['train_json'], total_caps_set)
+ self.scene_index_to_db_caps_mask_dict = self._create_scene_index_to_db_caps_mask_dict(self.db_caps_set)
+
+ def get_query_with_knn(self, index):
+ scene_index, cap_index = self.get_scene_cap_index_by_index(index)
+ query_cap = self.scenes[scene_index].captures[cap_index]
+ knn_engine = self.knn_engine_list[scene_index]
+ if scene_index in self.scene_index_to_db_caps_mask_dict:
+ db_mask = self.scene_index_to_db_caps_mask_dict[scene_index]
+ else:
+ db_mask = None
+ pool = knn_engine.get_knn(query_cap, self.opt.pool_size, db_mask=db_mask)
+ nn_caps = random.sample(pool, min(len(pool), self.opt.k_size))
+ return query_cap, nn_caps
diff --git a/imcui/third_party/COTR/COTR/global_configs/__init__.py b/imcui/third_party/COTR/COTR/global_configs/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..6db45547cfe941e399571c30b8645a5aa3931368
--- /dev/null
+++ b/imcui/third_party/COTR/COTR/global_configs/__init__.py
@@ -0,0 +1,10 @@
+import os
+import json
+
+__location__ = os.path.realpath(os.path.join(os.getcwd(), os.path.dirname(__file__)))
+with open(os.path.join(__location__, 'dataset_config.json'), 'r') as f:
+ dataset_config = json.load(f)
+with open(os.path.join(__location__, 'commons.json'), 'r') as f:
+ general_config = json.load(f)
+# assert os.path.isdir(general_config['out']), f'Please create {general_config["out"]}'
+# assert os.path.isdir(general_config['tb_out']), f'Please create {general_config["tb_out"]}'
diff --git a/imcui/third_party/COTR/COTR/inference/inference_helper.py b/imcui/third_party/COTR/COTR/inference/inference_helper.py
new file mode 100644
index 0000000000000000000000000000000000000000..4c477d1f6c60ff0e00a491c1daa3e105c7a59cdf
--- /dev/null
+++ b/imcui/third_party/COTR/COTR/inference/inference_helper.py
@@ -0,0 +1,311 @@
+import warnings
+
+import cv2
+import numpy as np
+import torch
+from torchvision.transforms import functional as tvtf
+from tqdm import tqdm
+import PIL
+
+from COTR.utils import utils, debug_utils
+from COTR.utils.constants import MAX_SIZE
+from COTR.cameras.capture import crop_center_max_np, pad_to_square_np
+from COTR.utils.utils import ImagePatch
+
+THRESHOLD_SPARSE = 0.02  # cycle-consistency error (normalized coords) below which a dense-flow location seeds a sparse match
+THRESHOLD_PIXELS_RELATIVE = 0.02  # max allowed std of a refinement trajectory, relative to the larger image dimension
+BASE_ZOOM = 1.0  # base patch scale used by the refinement tasks
+THRESHOLD_AREA = 0.02  # confidence threshold used to estimate the co-visible area ratio
+LARGE_GPU = True  # if True, run all dense queries in a single forward pass (needs more GPU memory)
+
+
+def find_prediction_loop(arr):
+ '''
+ loop ends at last element
+ '''
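+    # Worked example (illustrative): arr = [[1, 1], [2, 2], [3, 3], [2, 2]] revisits
+    # [2, 2] at the end, so the detected loop is [[2, 2], [3, 3]].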
+ assert arr.shape[1] == 2, 'requires shape (N, 2)'
+ start_index = np.where(np.prod(arr[:-1] == arr[-1], axis=1))[0][0]
+ return arr[start_index:-1]
+
+
+def two_images_side_by_side(img_a, img_b):
+ assert img_a.shape == img_b.shape, f'{img_a.shape} vs {img_b.shape}'
+ assert img_a.dtype == img_b.dtype
+ h, w, c = img_a.shape
+ canvas = np.zeros((h, 2 * w, c), dtype=img_a.dtype)
+ canvas[:, 0 * w:1 * w, :] = img_a
+ canvas[:, 1 * w:2 * w, :] = img_b
+ return canvas
+
+
+def to_square_patches(img):
+ patches = []
+ h, w, _ = img.shape
+ short = size = min(h, w)
+ long = max(h, w)
+ if long == short:
+ patch_0 = ImagePatch(img[:size, :size], 0, 0, size, size, w, h)
+ patches = [patch_0]
+ elif long <= size * 2:
+ warnings.warn('Spatial smoothness in dense optical flow is lost, but sparse matching and triangulation should be fine')
+ patch_0 = ImagePatch(img[:size, :size], 0, 0, size, size, w, h)
+ patch_1 = ImagePatch(img[-size:, -size:], w - size, h - size, size, size, w, h)
+ patches = [patch_0, patch_1]
+ # patches += subdivide_patch(patch_0)
+ # patches += subdivide_patch(patch_1)
+ else:
+ raise NotImplementedError
+ return patches
+
+
+def merge_flow_patches(corrs):
+ confidence = np.ones([corrs[0].oh, corrs[0].ow]) * 100
+ flow = np.zeros([corrs[0].oh, corrs[0].ow, 2])
+ cmap = np.ones([corrs[0].oh, corrs[0].ow]) * -1
+ for i, c in enumerate(corrs):
+ temp = np.ones([c.oh, c.ow]) * 100
+ temp[c.y:c.y + c.h, c.x:c.x + c.w] = c.patch[..., 2]
+ tempf = np.zeros([c.oh, c.ow, 2])
+ tempf[c.y:c.y + c.h, c.x:c.x + c.w] = c.patch[..., :2]
+ min_ind = np.stack([temp, confidence], axis=-1).argmin(axis=-1)
+ min_ind = min_ind == 0
+ confidence[min_ind] = temp[min_ind]
+ flow[min_ind] = tempf[min_ind]
+ cmap[min_ind] = i
+ return flow, confidence, cmap
+
+
+def get_patch_centered_at(img, pos, scale=1.0, return_content=True, img_shape=None):
+ '''
+    pos - [x, y]
+    Returns a square ImagePatch of side min(h, w) * scale (rounded down to an even
+    number), shifted when necessary so that it stays inside the image bounds.
+ '''
+ if img_shape is None:
+ img_shape = img.shape
+ h, w, _ = img_shape
+ short = min(h, w)
+ scale = np.clip(scale, 0.0, 1.0)
+ size = short * scale
+ size = int((size // 2) * 2)
+ lu_y = int(pos[1] - size // 2)
+ lu_x = int(pos[0] - size // 2)
+    if lu_y < 0:
+        lu_y = 0
+    if lu_x < 0:
+        lu_x = 0
+ if lu_y + size > h:
+ lu_y -= (lu_y + size) - (h)
+ if lu_x + size > w:
+ lu_x -= (lu_x + size) - (w)
+ if return_content:
+ return ImagePatch(img[lu_y:lu_y + size, lu_x:lu_x + size], lu_x, lu_y, size, size, w, h)
+ else:
+ return ImagePatch(None, lu_x, lu_y, size, size, w, h)
+
+
+def cotr_patch_flow_exhaustive(model, patches_a, patches_b):
+ def one_pass(model, img_a, img_b):
+ device = next(model.parameters()).device
+ assert img_a.shape[0] == img_a.shape[1]
+ assert img_b.shape[0] == img_b.shape[1]
+ img_a = np.array(PIL.Image.fromarray(img_a).resize((MAX_SIZE, MAX_SIZE), resample=PIL.Image.BILINEAR))
+ img_b = np.array(PIL.Image.fromarray(img_b).resize((MAX_SIZE, MAX_SIZE), resample=PIL.Image.BILINEAR))
+ img = two_images_side_by_side(img_a, img_b)
+ img = tvtf.normalize(tvtf.to_tensor(img), (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)).float()[None]
+ img = img.to(device)
+
+ q_list = []
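+        # Build one query per pixel of the side-by-side canvas: x is normalized by
+        # 2 * MAX_SIZE (two images concatenated horizontally), y by MAX_SIZE.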
+ for i in range(MAX_SIZE):
+ queries = []
+ for j in range(MAX_SIZE * 2):
+ queries.append([(j) / (MAX_SIZE * 2), i / MAX_SIZE])
+ queries = np.array(queries)
+ q_list.append(queries)
+ if LARGE_GPU:
+ try:
+ queries = torch.from_numpy(np.concatenate(q_list))[None].float().to(device)
+ out = model.forward(img, queries)['pred_corrs'].detach().cpu().numpy()[0]
+ out_list = out.reshape(MAX_SIZE, MAX_SIZE * 2, -1)
+            except Exception as err:
+                raise RuntimeError('single-pass inference failed (likely out of GPU memory), set LARGE_GPU to False') from err
+ else:
+ out_list = []
+ for q in q_list:
+ queries = torch.from_numpy(q)[None].float().to(device)
+ out = model.forward(img, queries)['pred_corrs'].detach().cpu().numpy()[0]
+ out_list.append(out)
+ out_list = np.array(out_list)
+ in_grid = torch.from_numpy(np.array(q_list)).float()[None] * 2 - 1
+ out_grid = torch.from_numpy(out_list).float()[None] * 2 - 1
+ cycle_grid = torch.nn.functional.grid_sample(out_grid.permute(0, 3, 1, 2), out_grid).permute(0, 2, 3, 1)
+ confidence = torch.norm(cycle_grid[0, ...] - in_grid[0, ...], dim=-1)
+ corr = out_grid[0].clone()
+ corr[:, :MAX_SIZE, 0] = corr[:, :MAX_SIZE, 0] * 2 - 1
+ corr[:, MAX_SIZE:, 0] = corr[:, MAX_SIZE:, 0] * 2 + 1
+ corr = torch.cat([corr, confidence[..., None]], dim=-1).numpy()
+ return corr[:, :MAX_SIZE, :], corr[:, MAX_SIZE:, :]
+ corrs_a = []
+ corrs_b = []
+
+ for p_i in patches_a:
+ for p_j in patches_b:
+ c_i, c_j = one_pass(model, p_i.patch, p_j.patch)
+ base_corners = np.array([[-1, -1], [1, -1], [1, 1], [-1, 1]])
+ real_corners_j = (np.array([[p_j.x, p_j.y], [p_j.x + p_j.w, p_j.y], [p_j.x + p_j.w, p_j.y + p_j.h], [p_j.x, p_j.y + p_j.h]]) / np.array([p_j.ow, p_j.oh])) * 2 + np.array([-1, -1])
+ real_corners_i = (np.array([[p_i.x, p_i.y], [p_i.x + p_i.w, p_i.y], [p_i.x + p_i.w, p_i.y + p_i.h], [p_i.x, p_i.y + p_i.h]]) / np.array([p_i.ow, p_i.oh])) * 2 + np.array([-1, -1])
+ T_i = cv2.getAffineTransform(base_corners[:3].astype(np.float32), real_corners_j[:3].astype(np.float32))
+ T_j = cv2.getAffineTransform(base_corners[:3].astype(np.float32), real_corners_i[:3].astype(np.float32))
+ c_i[..., :2] = c_i[..., :2] @ T_i[:2, :2] + T_i[:, 2]
+ c_j[..., :2] = c_j[..., :2] @ T_j[:2, :2] + T_j[:, 2]
+ c_i = utils.float_image_resize(c_i, (p_i.h, p_i.w))
+ c_j = utils.float_image_resize(c_j, (p_j.h, p_j.w))
+ c_i = ImagePatch(c_i, p_i.x, p_i.y, p_i.w, p_i.h, p_i.ow, p_i.oh)
+ c_j = ImagePatch(c_j, p_j.x, p_j.y, p_j.w, p_j.h, p_j.ow, p_j.oh)
+ corrs_a.append(c_i)
+ corrs_b.append(c_j)
+ return corrs_a, corrs_b
+
+
+def cotr_flow(model, img_a, img_b):
+ # assert img_a.shape[0] == img_a.shape[1]
+ # assert img_b.shape[0] == img_b.shape[1]
+ patches_a = to_square_patches(img_a)
+ patches_b = to_square_patches(img_b)
+
+ corrs_a, corrs_b = cotr_patch_flow_exhaustive(model, patches_a, patches_b)
+ corr_a, con_a, cmap_a = merge_flow_patches(corrs_a)
+ corr_b, con_b, cmap_b = merge_flow_patches(corrs_b)
+
+ resample_a = utils.torch_img_to_np_img(torch.nn.functional.grid_sample(utils.np_img_to_torch_img(img_b)[None].float(),
+ torch.from_numpy(corr_a)[None].float())[0])
+ resample_b = utils.torch_img_to_np_img(torch.nn.functional.grid_sample(utils.np_img_to_torch_img(img_a)[None].float(),
+ torch.from_numpy(corr_b)[None].float())[0])
+ return corr_a, con_a, resample_a, corr_b, con_b, resample_b
+
+
+def cotr_corr_base(model, img_a, img_b, queries_a):
+ def one_pass(model, img_a, img_b, queries):
+ device = next(model.parameters()).device
+ assert img_a.shape[0] == img_a.shape[1]
+ assert img_b.shape[0] == img_b.shape[1]
+ img_a = np.array(PIL.Image.fromarray(img_a).resize((MAX_SIZE, MAX_SIZE), resample=PIL.Image.BILINEAR))
+ img_b = np.array(PIL.Image.fromarray(img_b).resize((MAX_SIZE, MAX_SIZE), resample=PIL.Image.BILINEAR))
+ img = two_images_side_by_side(img_a, img_b)
+ img = tvtf.normalize(tvtf.to_tensor(img), (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)).float()[None]
+ img = img.to(device)
+
+ queries = torch.from_numpy(queries)[None].float().to(device)
+ out = model.forward(img, queries)['pred_corrs'].clone().detach()
+ cycle = model.forward(img, out)['pred_corrs'].clone().detach()
+
+ queries = queries.cpu().numpy()[0]
+ out = out.cpu().numpy()[0]
+ cycle = cycle.cpu().numpy()[0]
+ conf = np.linalg.norm(queries - cycle, axis=1, keepdims=True)
+ return np.concatenate([out, conf], axis=1)
+
+ patches_a = to_square_patches(img_a)
+ patches_b = to_square_patches(img_b)
+ pred_list = []
+
+ for p_i in patches_a:
+ for p_j in patches_b:
+ normalized_queries_a = queries_a.copy()
+ mask = (normalized_queries_a[:, 0] >= p_i.x) & (normalized_queries_a[:, 1] >= p_i.y) & (normalized_queries_a[:, 0] <= p_i.x + p_i.w) & (normalized_queries_a[:, 1] <= p_i.y + p_i.h)
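+            # Map the queries into patch-local normalized coordinates; x is divided by
+            # 2 * w because the query lives in the left half of the side-by-side input.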
+ normalized_queries_a[:, 0] -= p_i.x
+ normalized_queries_a[:, 1] -= p_i.y
+ normalized_queries_a[:, 0] /= 2 * p_i.w
+ normalized_queries_a[:, 1] /= p_i.h
+ pred = one_pass(model, p_i.patch, p_j.patch, normalized_queries_a)
+ pred[~mask, 2] = np.inf
+ pred[:, 0] -= 0.5
+ pred[:, 0] *= 2 * p_j.w
+ pred[:, 0] += p_j.x
+ pred[:, 1] *= p_j.h
+ pred[:, 1] += p_j.y
+ pred_list.append(pred)
+
+ pred_list = np.stack(pred_list).transpose(1, 0, 2)
+ out = []
+ for item in pred_list:
+ out.append(item[np.argmin(item[..., 2], axis=0)])
+ out = np.array(out)[..., :2]
+ return np.concatenate([queries_a, out], axis=1)
+
+
+try:
+ from vispy import gloo
+ from vispy import app
+ from vispy.util.ptime import time
+ from scipy.spatial import Delaunay
+ from vispy.gloo.wrappers import read_pixels
+
+ app.use_app('glfw')
+
+
+ vertex_shader = """
+ attribute vec4 color;
+ attribute vec2 position;
+ varying vec4 v_color;
+ void main()
+ {
+ gl_Position = vec4(position, 0.0, 1.0);
+ v_color = color;
+ } """
+
+ fragment_shader = """
+ varying vec4 v_color;
+ void main()
+ {
+ gl_FragColor = v_color;
+ } """
+
+
+ class Canvas(app.Canvas):
+ def __init__(self, mesh, color, size):
+ # We hide the canvas upon creation.
+ app.Canvas.__init__(self, show=False, size=size)
+ self._t0 = time()
+ # Texture where we render the scene.
+ self._rendertex = gloo.Texture2D(shape=self.size[::-1] + (4,), internalformat='rgba32f')
+ # FBO.
+ self._fbo = gloo.FrameBuffer(self._rendertex,
+ gloo.RenderBuffer(self.size[::-1]))
+ # Regular program that will be rendered to the FBO.
+ self.program = gloo.Program(vertex_shader, fragment_shader)
+ self.program["position"] = mesh
+ self.program['color'] = color
+ # We manually draw the hidden canvas.
+ self.update()
+
+ def on_draw(self, event):
+ # Render in the FBO.
+ with self._fbo:
+ gloo.clear('black')
+ gloo.set_viewport(0, 0, *self.size)
+ self.program.draw()
+ # Retrieve the contents of the FBO texture.
+ self.im = read_pixels((0, 0, self.size[0], self.size[1]), True, out_type='float')
+ self._time = time() - self._t0
+ # Immediately exit the application.
+ app.quit()
+
+
+ def triangulate_corr(corr, from_shape, to_shape):
+ corr = corr.copy()
+ to_shape = to_shape[:2]
+ from_shape = from_shape[:2]
+ corr = corr / np.concatenate([from_shape[::-1], to_shape[::-1]])
+ tri = Delaunay(corr[:, :2])
+ mesh = corr[:, :2][tri.simplices].astype(np.float32) * 2 - 1
+ mesh[..., 1] *= -1
+ color = corr[:, 2:][tri.simplices].astype(np.float32)
+ color = np.concatenate([color, np.ones_like(color[..., 0:2])], axis=-1)
+ c = Canvas(mesh.reshape(-1, 2), color.reshape(-1, 4), size=(from_shape[::-1]))
+ app.run()
+ render = c.im.copy()
+ render = render[..., :2]
+ render *= np.array(to_shape[::-1])
+ return render
+except Exception as err:
+    print(f'cannot use vispy ({err}), setting triangulate_corr to None')
+    triangulate_corr = None
diff --git a/imcui/third_party/COTR/COTR/inference/refinement_task.py b/imcui/third_party/COTR/COTR/inference/refinement_task.py
new file mode 100644
index 0000000000000000000000000000000000000000..0381b91f548ff71839d363e13002d59a0618c0b6
--- /dev/null
+++ b/imcui/third_party/COTR/COTR/inference/refinement_task.py
@@ -0,0 +1,191 @@
+import time
+
+import numpy as np
+import torch
+from torchvision.transforms import functional as tvtf
+import imageio
+import PIL
+
+from COTR.inference.inference_helper import BASE_ZOOM, THRESHOLD_PIXELS_RELATIVE, get_patch_centered_at, two_images_side_by_side, find_prediction_loop
+from COTR.utils import debug_utils, utils
+from COTR.utils.constants import MAX_SIZE
+from COTR.utils.utils import ImagePatch
+
+
+class RefinementTask():
+ def __init__(self, image_from, image_to, loc_from, loc_to, area_from, area_to, converge_iters, zoom_ins, identifier=None):
+ self.identifier = identifier
+ self.image_from = image_from
+ self.image_to = image_to
+ self.loc_from = loc_from
+ self.best_loc_to = loc_to
+ self.cur_loc_to = loc_to
+ self.area_from = area_from
+ self.area_to = area_to
+ if self.area_from < self.area_to:
+ self.s_from = BASE_ZOOM
+ self.s_to = BASE_ZOOM * np.sqrt(self.area_to / self.area_from)
+ else:
+ self.s_to = BASE_ZOOM
+ self.s_from = BASE_ZOOM * np.sqrt(self.area_from / self.area_to)
+
+ self.cur_job = {}
+ self.status = 'unfinished'
+ self.result = 'unknown'
+
+ self.converge_iters = converge_iters
+ self.zoom_ins = zoom_ins
+ self.cur_zoom_idx = 0
+ self.cur_iter = 0
+ self.total_iter = 0
+
+ self.loc_to_at_zoom = []
+ self.loc_history = [loc_to]
+ self.all_loc_to_dict = {}
+ self.job_history = []
+ self.submitted = False
+
+ @property
+ def cur_zoom(self):
+ return self.zoom_ins[self.cur_zoom_idx]
+
+ @property
+ def confidence_scaling_factor(self):
+ if self.cur_zoom_idx > 0:
+ conf_scaling = float(self.cur_zoom) / float(self.zoom_ins[0])
+ else:
+ conf_scaling = 1.0
+ return conf_scaling
+
+ def peek(self):
+ assert self.status == 'unfinished'
+ patch_from = get_patch_centered_at(None, self.loc_from, scale=self.s_from * self.cur_zoom, return_content=False, img_shape=self.image_from.shape)
+ patch_to = get_patch_centered_at(None, self.cur_loc_to, scale=self.s_to * self.cur_zoom, return_content=False, img_shape=self.image_to.shape)
+ top_job = {'patch_from': patch_from,
+ 'patch_to': patch_to,
+ 'loc_from': self.loc_from,
+ 'loc_to': self.cur_loc_to,
+ }
+ return top_job
+
+ def get_task_pilot(self, pilot):
+ assert self.status == 'unfinished'
+ patch_from = ImagePatch(None, pilot.cur_job['patch_from'].x, pilot.cur_job['patch_from'].y, pilot.cur_job['patch_from'].w, pilot.cur_job['patch_from'].h, pilot.cur_job['patch_from'].ow, pilot.cur_job['patch_from'].oh)
+ patch_to = ImagePatch(None, pilot.cur_job['patch_to'].x, pilot.cur_job['patch_to'].y, pilot.cur_job['patch_to'].w, pilot.cur_job['patch_to'].h, pilot.cur_job['patch_to'].ow, pilot.cur_job['patch_to'].oh)
+ query = torch.from_numpy((np.array(self.loc_from) - np.array([patch_from.x, patch_from.y])) / np.array([patch_from.w * 2, patch_from.h]))[None].float()
+ self.cur_job = {'patch_from': patch_from,
+ 'patch_to': patch_to,
+ 'loc_from': self.loc_from,
+ 'loc_to': self.cur_loc_to,
+ 'img': None,
+ }
+ self.job_history.append((patch_from.h, patch_from.w, patch_to.h, patch_to.w))
+ assert self.submitted == False
+ self.submitted = True
+ return None, query
+
+ def get_task_fast(self):
+ assert self.status == 'unfinished'
+ patch_from = get_patch_centered_at(self.image_from, self.loc_from, scale=self.s_from * self.cur_zoom, return_content=False)
+ patch_to = get_patch_centered_at(self.image_to, self.cur_loc_to, scale=self.s_to * self.cur_zoom, return_content=False)
+ query = torch.from_numpy((np.array(self.loc_from) - np.array([patch_from.x, patch_from.y])) / np.array([patch_from.w * 2, patch_from.h]))[None].float()
+ self.cur_job = {'patch_from': patch_from,
+ 'patch_to': patch_to,
+ 'loc_from': self.loc_from,
+ 'loc_to': self.cur_loc_to,
+ 'img': None,
+ }
+
+ self.job_history.append((patch_from.h, patch_from.w, patch_to.h, patch_to.w))
+ assert self.submitted == False
+ self.submitted = True
+
+ return None, query
+
+ def get_task(self):
+ assert self.status == 'unfinished'
+ patch_from = get_patch_centered_at(self.image_from, self.loc_from, scale=self.s_from * self.cur_zoom)
+ patch_to = get_patch_centered_at(self.image_to, self.cur_loc_to, scale=self.s_to * self.cur_zoom)
+
+ query = torch.from_numpy((np.array(self.loc_from) - np.array([patch_from.x, patch_from.y])) / np.array([patch_from.w * 2, patch_from.h]))[None].float()
+
+ img_from = patch_from.patch
+ img_to = patch_to.patch
+ assert img_from.shape[0] == img_from.shape[1]
+ assert img_to.shape[0] == img_to.shape[1]
+
+ img_from = np.array(PIL.Image.fromarray(img_from).resize((MAX_SIZE, MAX_SIZE), resample=PIL.Image.BILINEAR))
+ img_to = np.array(PIL.Image.fromarray(img_to).resize((MAX_SIZE, MAX_SIZE), resample=PIL.Image.BILINEAR))
+ img = two_images_side_by_side(img_from, img_to)
+ img = tvtf.normalize(tvtf.to_tensor(img), (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)).float()
+
+ self.cur_job = {'patch_from': ImagePatch(None, patch_from.x, patch_from.y, patch_from.w, patch_from.h, patch_from.ow, patch_from.oh),
+ 'patch_to': ImagePatch(None, patch_to.x, patch_to.y, patch_to.w, patch_to.h, patch_to.ow, patch_to.oh),
+ 'loc_from': self.loc_from,
+ 'loc_to': self.cur_loc_to,
+ }
+
+ self.job_history.append((patch_from.h, patch_from.w, patch_to.h, patch_to.w))
+ assert self.submitted == False
+ self.submitted = True
+
+ return img, query
+
+ def next_zoom(self):
+ if self.cur_zoom_idx >= len(self.zoom_ins) - 1:
+ self.status = 'finished'
+ if self.conclude() is None:
+ self.result = 'bad'
+ else:
+ self.result = 'good'
+ self.cur_zoom_idx += 1
+ self.cur_iter = 0
+ self.loc_to_at_zoom = []
+
+ def scale_to_loc(self, raw_to_loc):
+ raw_to_loc = raw_to_loc.copy()
+ patch_b = self.cur_job['patch_to']
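+        # The raw prediction lies in the right half of the side-by-side canvas
+        # (x in [0.5, 1]); map it back to [0, 1] before scaling to patch_to pixels.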
+ raw_to_loc[0] = (raw_to_loc[0] - 0.5) * 2
+ loc_to = raw_to_loc * np.array([patch_b.w, patch_b.h])
+ loc_to = loc_to + np.array([patch_b.x, patch_b.y])
+ return loc_to
+
+ def step(self, raw_to_loc):
+ assert self.submitted == True
+ self.submitted = False
+ loc_to = self.scale_to_loc(raw_to_loc)
+ self.total_iter += 1
+ self.loc_to_at_zoom.append(loc_to)
+ self.cur_loc_to = loc_to
+ zoom_finished = False
+ if self.cur_zoom_idx == len(self.zoom_ins) - 1:
+ # converge at the last level
+ if len(self.loc_to_at_zoom) >= 2:
+ zoom_finished = np.prod(self.loc_to_at_zoom[:-1] == loc_to, axis=1, keepdims=True).any()
+ if self.cur_iter >= self.converge_iters - 1:
+ zoom_finished = True
+ self.cur_iter += 1
+ else:
+ # finish immediately for other levels
+ zoom_finished = True
+ if zoom_finished:
+ self.all_loc_to_dict[self.cur_zoom] = np.array(self.loc_to_at_zoom).copy()
+ last_level_loc_to = self.all_loc_to_dict[self.cur_zoom]
+ if len(last_level_loc_to) >= 2:
+ has_loop = np.prod(last_level_loc_to[:-1] == last_level_loc_to[-1], axis=1, keepdims=True).any()
+ if has_loop:
+ loop = find_prediction_loop(last_level_loc_to)
+ loc_to = loop.mean(axis=0)
+ self.loc_history.append(loc_to)
+ self.best_loc_to = loc_to
+ self.cur_loc_to = self.best_loc_to
+ self.next_zoom()
+
+ def conclude(self, force=False):
+ loc_history = np.array(self.loc_history)
+ if (force == False) and (max(loc_history.std(axis=0)) >= THRESHOLD_PIXELS_RELATIVE * max(*self.image_to.shape)):
+ return None
+ return np.concatenate([self.loc_from, self.best_loc_to])
+
+ def conclude_intermedia(self):
+ return np.concatenate([np.array(self.loc_history), np.array(self.job_history)], axis=1)
diff --git a/imcui/third_party/COTR/COTR/inference/sparse_engine.py b/imcui/third_party/COTR/COTR/inference/sparse_engine.py
new file mode 100644
index 0000000000000000000000000000000000000000..d814e09bfbccc78a8263fec9477ba08b5ed12e34
--- /dev/null
+++ b/imcui/third_party/COTR/COTR/inference/sparse_engine.py
@@ -0,0 +1,427 @@
+'''
+Inference engine for sparse image pair correspondences
+'''
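+# Minimal usage sketch (assuming a trained COTR model has been loaded and img_a /
+# img_b are RGB numpy arrays; the names below are illustrative only):
+#   engine = SparseEngine(model, batch_size=32, mode='tile')
+#   corrs = engine.cotr_corr_multiscale(img_a, img_b, zoom_ins=[1.0, 0.5],
+#                                       converge_iters=1, max_corrs=100)
+# Each row of corrs is (x_a, y_a, x_b, y_b).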
+
+import time
+import random
+
+import numpy as np
+import torch
+
+from COTR.inference.inference_helper import THRESHOLD_SPARSE, THRESHOLD_AREA, cotr_flow, cotr_corr_base
+from COTR.inference.refinement_task import RefinementTask
+from COTR.utils import debug_utils, utils
+from COTR.cameras.capture import stretch_to_square_np
+
+
+class SparseEngine():
+ def __init__(self, model, batch_size, mode='stretching'):
+ assert mode in ['stretching', 'tile']
+ self.model = model
+ self.batch_size = batch_size
+ self.total_tasks = 0
+ self.mode = mode
+
+ def form_batch(self, tasks, zoom=None):
+ counter = 0
+ task_ref = []
+ img_batch = []
+ query_batch = []
+ for t in tasks:
+ if t.status == 'unfinished' and t.submitted == False:
+ if zoom is not None and t.cur_zoom != zoom:
+ continue
+ task_ref.append(t)
+ img, query = t.get_task()
+ img_batch.append(img)
+ query_batch.append(query)
+ counter += 1
+ if counter >= self.batch_size:
+ break
+ if len(task_ref) == 0:
+ return [], [], []
+ img_batch = torch.stack(img_batch)
+ query_batch = torch.stack(query_batch)
+ return task_ref, img_batch, query_batch
+
+ def infer_batch(self, img_batch, query_batch):
+ self.total_tasks += img_batch.shape[0]
+ device = next(self.model.parameters()).device
+ img_batch = img_batch.to(device)
+ query_batch = query_batch.to(device)
+ out = self.model(img_batch, query_batch)['pred_corrs'].clone().detach()
+ out = out.cpu().numpy()[:, 0, :]
+ if utils.has_nan(out):
+ raise ValueError('NaN in prediction')
+ return out
+
+ def conclude_tasks(self, tasks, return_idx=False, force=False,
+ offset_x_from=0,
+ offset_y_from=0,
+ offset_x_to=0,
+ offset_y_to=0,
+ img_a_shape=None,
+ img_b_shape=None):
+ corrs = []
+ idx = []
+ for t in tasks:
+ if t.status == 'finished':
+ out = t.conclude(force)
+ if out is not None:
+ corrs.append(np.array(out))
+ idx.append(t.identifier)
+ corrs = np.array(corrs)
+ idx = np.array(idx)
+ if corrs.shape[0] > 0:
+ corrs -= np.array([offset_x_from, offset_y_from, offset_x_to, offset_y_to])
+ if img_a_shape is not None and img_b_shape is not None and not force:
+ border_mask = np.prod(corrs < np.concatenate([img_a_shape[::-1], img_b_shape[::-1]]), axis=1)
+                border_mask = (np.prod(corrs > np.array([0, 0, 0, 0]), axis=1) * border_mask).astype(bool)  # np.bool was removed from numpy
+ corrs = corrs[border_mask]
+ idx = idx[border_mask]
+ if return_idx:
+ return corrs, idx
+ return corrs
+
+ def num_finished_tasks(self, tasks):
+ counter = 0
+ for t in tasks:
+ if t.status == 'finished':
+ counter += 1
+ return counter
+
+ def num_good_tasks(self, tasks):
+ counter = 0
+ for t in tasks:
+ if t.result == 'good':
+ counter += 1
+ return counter
+
+ def gen_tasks_w_known_scale(self, img_a, img_b, queries_a, areas, zoom_ins=[1.0], converge_iters=1, max_corrs=1000):
+ assert self.mode == 'tile'
+ corr_a = cotr_corr_base(self.model, img_a, img_b, queries_a)
+ tasks = []
+ for c in corr_a:
+ tasks.append(RefinementTask(img_a, img_b, c[:2], c[2:], areas[0], areas[1], converge_iters, zoom_ins))
+ return tasks
+
+ def gen_tasks(self, img_a, img_b, zoom_ins=[1.0], converge_iters=1, max_corrs=1000, queries_a=None, force=False, areas=None):
+ if areas is not None:
+ assert queries_a is not None
+ assert force == True
+ assert max_corrs >= queries_a.shape[0]
+ return self.gen_tasks_w_known_scale(img_a, img_b, queries_a, areas, zoom_ins=zoom_ins, converge_iters=converge_iters, max_corrs=max_corrs)
+ if self.mode == 'stretching':
+ if img_a.shape[0] != img_a.shape[1] or img_b.shape[0] != img_b.shape[1]:
+ img_a_shape = img_a.shape
+ img_b_shape = img_b.shape
+ img_a_sq = stretch_to_square_np(img_a.copy())
+ img_b_sq = stretch_to_square_np(img_b.copy())
+ corr_a, con_a, resample_a, corr_b, con_b, resample_b = cotr_flow(self.model,
+ img_a_sq,
+ img_b_sq
+ )
+ corr_a = utils.float_image_resize(corr_a, img_a_shape[:2])
+ con_a = utils.float_image_resize(con_a, img_a_shape[:2])
+ resample_a = utils.float_image_resize(resample_a, img_a_shape[:2])
+ corr_b = utils.float_image_resize(corr_b, img_b_shape[:2])
+ con_b = utils.float_image_resize(con_b, img_b_shape[:2])
+ resample_b = utils.float_image_resize(resample_b, img_b_shape[:2])
+ else:
+ corr_a, con_a, resample_a, corr_b, con_b, resample_b = cotr_flow(self.model,
+ img_a,
+ img_b
+ )
+ elif self.mode == 'tile':
+ corr_a, con_a, resample_a, corr_b, con_b, resample_b = cotr_flow(self.model,
+ img_a,
+ img_b
+ )
+ else:
+ raise ValueError(f'unsupported mode: {self.mode}')
+ mask_a = con_a < THRESHOLD_SPARSE
+ mask_b = con_b < THRESHOLD_SPARSE
+ area_a = (con_a < THRESHOLD_AREA).sum() / mask_a.size
+ area_b = (con_b < THRESHOLD_AREA).sum() / mask_b.size
+ tasks = []
+
+ if queries_a is None:
+ index_a = np.where(mask_a)
+ index_a = np.array(index_a).T
+ index_a = index_a[np.random.choice(len(index_a), min(max_corrs, len(index_a)))]
+ index_b = np.where(mask_b)
+ index_b = np.array(index_b).T
+ index_b = index_b[np.random.choice(len(index_b), min(max_corrs, len(index_b)))]
+ for pos in index_a:
+ loc_from = pos[::-1]
+ loc_to = (corr_a[tuple(np.floor(pos).astype('int'))].copy() * 0.5 + 0.5) * img_b.shape[:2][::-1]
+ tasks.append(RefinementTask(img_a, img_b, loc_from, loc_to, area_a, area_b, converge_iters, zoom_ins))
+ for pos in index_b:
+ '''
+                trick: the query point location (loc_from) is supposed to stay fixed,
+                but here the first guess (loc_to) is the point that is actually held fixed.
+ '''
+ loc_from = pos[::-1]
+ loc_to = (corr_b[tuple(np.floor(pos).astype('int'))].copy() * 0.5 + 0.5) * img_a.shape[:2][::-1]
+ tasks.append(RefinementTask(img_a, img_b, loc_to, loc_from, area_a, area_b, converge_iters, zoom_ins))
+ else:
+ if force:
+ for i, loc_from in enumerate(queries_a):
+ pos = loc_from[::-1]
+                    pos = np.array([np.clip(pos[0], 0, corr_a.shape[0] - 1), np.clip(pos[1], 0, corr_a.shape[1] - 1)], dtype=int)  # np.int was removed from numpy
+ loc_to = (corr_a[tuple(pos)].copy() * 0.5 + 0.5) * img_b.shape[:2][::-1]
+ tasks.append(RefinementTask(img_a, img_b, loc_from, loc_to, area_a, area_b, converge_iters, zoom_ins, identifier=i))
+ else:
+ for i, loc_from in enumerate(queries_a):
+ pos = loc_from[::-1]
+ if (pos > np.array(img_a.shape[:2]) - 1).any() or (pos < 0).any():
+ continue
+ if mask_a[tuple(np.floor(pos).astype('int'))]:
+ loc_to = (corr_a[tuple(np.floor(pos).astype('int'))].copy() * 0.5 + 0.5) * img_b.shape[:2][::-1]
+ tasks.append(RefinementTask(img_a, img_b, loc_from, loc_to, area_a, area_b, converge_iters, zoom_ins, identifier=i))
+ if len(tasks) < max_corrs:
+ extra = max_corrs - len(tasks)
+ counter = 0
+ for i, loc_from in enumerate(queries_a):
+ if counter >= extra:
+ break
+ pos = loc_from[::-1]
+ if (pos > np.array(img_a.shape[:2]) - 1).any() or (pos < 0).any():
+ continue
+ if mask_a[tuple(np.floor(pos).astype('int'))] == False:
+ loc_to = (corr_a[tuple(np.floor(pos).astype('int'))].copy() * 0.5 + 0.5) * img_b.shape[:2][::-1]
+ tasks.append(RefinementTask(img_a, img_b, loc_from, loc_to, area_a, area_b, converge_iters, zoom_ins, identifier=i))
+ counter += 1
+ return tasks
+
+ def cotr_corr_multiscale(self, img_a, img_b, zoom_ins=[1.0], converge_iters=1, max_corrs=1000, queries_a=None, return_idx=False, force=False, return_tasks_only=False, areas=None):
+ '''
+        currently only supports fixed queries_a
+ '''
+ img_a = img_a.copy()
+ img_b = img_b.copy()
+ img_a_shape = img_a.shape[:2]
+ img_b_shape = img_b.shape[:2]
+ if queries_a is not None:
+ queries_a = queries_a.copy()
+ tasks = self.gen_tasks(img_a, img_b, zoom_ins, converge_iters, max_corrs, queries_a, force, areas)
+ while True:
+ num_g = self.num_good_tasks(tasks)
+ print(f'{num_g} / {max_corrs} | {self.num_finished_tasks(tasks)} / {len(tasks)}')
+ task_ref, img_batch, query_batch = self.form_batch(tasks)
+ if len(task_ref) == 0:
+ break
+ if num_g >= max_corrs:
+ break
+ out = self.infer_batch(img_batch, query_batch)
+ for t, o in zip(task_ref, out):
+ t.step(o)
+ if return_tasks_only:
+ return tasks
+ if return_idx:
+ corrs, idx = self.conclude_tasks(tasks, return_idx=True, force=force,
+ img_a_shape=img_a_shape,
+ img_b_shape=img_b_shape,)
+ corrs = corrs[:max_corrs]
+ idx = idx[:max_corrs]
+ return corrs, idx
+ else:
+ corrs = self.conclude_tasks(tasks, force=force,
+ img_a_shape=img_a_shape,
+ img_b_shape=img_b_shape,)
+ corrs = corrs[:max_corrs]
+ return corrs
+
+ def cotr_corr_multiscale_with_cycle_consistency(self, img_a, img_b, zoom_ins=[1.0], converge_iters=1, max_corrs=1000, queries_a=None, return_idx=False, return_cycle_error=False):
+ EXTRACTION_RATE = 0.3
+ temp_max_corrs = int(max_corrs / EXTRACTION_RATE)
+ if queries_a is not None:
+ temp_max_corrs = min(temp_max_corrs, queries_a.shape[0])
+ queries_a = queries_a.copy()
+ corr_f, idx_f = self.cotr_corr_multiscale(img_a.copy(), img_b.copy(),
+ zoom_ins=zoom_ins,
+ converge_iters=converge_iters,
+ max_corrs=temp_max_corrs,
+ queries_a=queries_a,
+ return_idx=True)
+ assert corr_f.shape[0] > 0
+ corr_b, idx_b = self.cotr_corr_multiscale(img_b.copy(), img_a.copy(),
+ zoom_ins=zoom_ins,
+ converge_iters=converge_iters,
+ max_corrs=corr_f.shape[0],
+ queries_a=corr_f[:, 2:].copy(),
+ return_idx=True)
+ assert corr_b.shape[0] > 0
+ cycle_errors = np.linalg.norm(corr_f[idx_b][:, :2] - corr_b[:, 2:], axis=1)
+ order = np.argsort(cycle_errors)
+ out = [corr_f[idx_b][order][:max_corrs]]
+ if return_idx:
+ out.append(idx_f[idx_b][order][:max_corrs])
+ if return_cycle_error:
+ out.append(cycle_errors[order][:max_corrs])
+ if len(out) == 1:
+ out = out[0]
+ return out
+
+
+class FasterSparseEngine(SparseEngine):
+ '''
+    Search for and merge nearby tasks to speed up inference.
+    This makes spatial accuracy slightly worse.
+ '''
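+    # Tasks whose source and target patches overlap (within a safety margin) are
+    # grouped so that a single crop and forward pass serves many queries; see
+    # form_squad() and form_grouped_batch() below.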
+
+ def __init__(self, model, batch_size, mode='stretching', max_load=256):
+ super().__init__(model, batch_size, mode=mode)
+ self.max_load = max_load
+
+ def infer_batch_grouped(self, img_batch, query_batch):
+ device = next(self.model.parameters()).device
+ img_batch = img_batch.to(device)
+ query_batch = query_batch.to(device)
+ out = self.model(img_batch, query_batch)['pred_corrs'].clone().detach().cpu().numpy()
+ return out
+
+ def get_tasks_map(self, zoom, tasks):
+ maps = []
+ ids = []
+ for i, t in enumerate(tasks):
+ if t.status == 'unfinished' and t.submitted == False and t.cur_zoom == zoom:
+ t_info = t.peek()
+ point = np.concatenate([t_info['loc_from'], t_info['loc_to']])
+ maps.append(point)
+ ids.append(i)
+ return np.array(maps), np.array(ids)
+
+ def form_squad(self, zoom, pilot, pilot_id, tasks, tasks_map, task_ids, bookkeeping):
+ assert pilot.status == 'unfinished' and pilot.submitted == False and pilot.cur_zoom == zoom
+ SAFE_AREA = 0.5
+ pilot_info = pilot.peek()
+ pilot_from_center_x = pilot_info['patch_from'].x + pilot_info['patch_from'].w/2
+ pilot_from_center_y = pilot_info['patch_from'].y + pilot_info['patch_from'].h/2
+ pilot_from_left = pilot_from_center_x - pilot_info['patch_from'].w/2 * SAFE_AREA
+ pilot_from_right = pilot_from_center_x + pilot_info['patch_from'].w/2 * SAFE_AREA
+ pilot_from_upper = pilot_from_center_y - pilot_info['patch_from'].h/2 * SAFE_AREA
+ pilot_from_lower = pilot_from_center_y + pilot_info['patch_from'].h/2 * SAFE_AREA
+
+ pilot_to_center_x = pilot_info['patch_to'].x + pilot_info['patch_to'].w/2
+ pilot_to_center_y = pilot_info['patch_to'].y + pilot_info['patch_to'].h/2
+ pilot_to_left = pilot_to_center_x - pilot_info['patch_to'].w/2 * SAFE_AREA
+ pilot_to_right = pilot_to_center_x + pilot_info['patch_to'].w/2 * SAFE_AREA
+ pilot_to_upper = pilot_to_center_y - pilot_info['patch_to'].h/2 * SAFE_AREA
+ pilot_to_lower = pilot_to_center_y + pilot_info['patch_to'].h/2 * SAFE_AREA
+
+ img, query = pilot.get_task()
+ assert pilot.submitted == True
+ members = [pilot]
+ queries = [query]
+ bookkeeping[pilot_id] = False
+
+ loads = np.where(((tasks_map[:, 0] > pilot_from_left) &
+ (tasks_map[:, 0] < pilot_from_right) &
+ (tasks_map[:, 1] > pilot_from_upper) &
+ (tasks_map[:, 1] < pilot_from_lower) &
+ (tasks_map[:, 2] > pilot_to_left) &
+ (tasks_map[:, 2] < pilot_to_right) &
+ (tasks_map[:, 3] > pilot_to_upper) &
+ (tasks_map[:, 3] < pilot_to_lower)) *
+ bookkeeping)[0][: self.max_load]
+
+ for ti in task_ids[loads]:
+ t = tasks[ti]
+ assert t.status == 'unfinished' and t.submitted == False and t.cur_zoom == zoom
+ _, query = t.get_task_pilot(pilot)
+ members.append(t)
+ queries.append(query)
+ queries = torch.stack(queries, axis=1)
+ bookkeeping[loads] = False
+ return members, img, queries, bookkeeping
+
+ def form_grouped_batch(self, zoom, tasks):
+ counter = 0
+ task_ref = []
+ img_batch = []
+ query_batch = []
+ tasks_map, task_ids = self.get_tasks_map(zoom, tasks)
+ shuffle = np.random.permutation(tasks_map.shape[0])
+ tasks_map = np.take(tasks_map, shuffle, axis=0)
+ task_ids = np.take(task_ids, shuffle, axis=0)
+ bookkeeping = np.ones_like(task_ids).astype(bool)
+
+ for i, ti in enumerate(task_ids):
+ t = tasks[ti]
+ if t.status == 'unfinished' and t.submitted == False and t.cur_zoom == zoom:
+ members, img, queries, bookkeeping = self.form_squad(zoom, t, i, tasks, tasks_map, task_ids, bookkeeping)
+ task_ref.append(members)
+ img_batch.append(img)
+ query_batch.append(queries)
+ counter += 1
+ if counter >= self.batch_size:
+ break
+ if len(task_ref) == 0:
+ return [], [], []
+
+ max_len = max([q.shape[1] for q in query_batch])
+ for i in range(len(query_batch)):
+ q = query_batch[i]
+ query_batch[i] = torch.cat([q, torch.zeros([1, max_len - q.shape[1], 2])], axis=1)
+ img_batch = torch.stack(img_batch)
+ query_batch = torch.cat(query_batch)
+ return task_ref, img_batch, query_batch
+
+ def cotr_corr_multiscale(self, img_a, img_b, zoom_ins=[1.0], converge_iters=1, max_corrs=1000, queries_a=None, return_idx=False, force=False, return_tasks_only=False, areas=None):
+ '''
+        currently only supports fixed queries_a
+ '''
+ img_a = img_a.copy()
+ img_b = img_b.copy()
+ img_a_shape = img_a.shape[:2]
+ img_b_shape = img_b.shape[:2]
+ if queries_a is not None:
+ queries_a = queries_a.copy()
+ tasks = self.gen_tasks(img_a, img_b, zoom_ins, converge_iters, max_corrs, queries_a, force, areas)
+ for zm in zoom_ins:
+ print(f'======= Zoom: {zm} ======')
+ while True:
+ num_g = self.num_good_tasks(tasks)
+ task_ref, img_batch, query_batch = self.form_grouped_batch(zm, tasks)
+ if len(task_ref) == 0:
+ break
+ if num_g >= max_corrs:
+ break
+ out = self.infer_batch_grouped(img_batch, query_batch)
+ num_steps = 0
+ for i, temp in enumerate(task_ref):
+ for j, t in enumerate(temp):
+ t.step(out[i, j])
+ num_steps += 1
+ print(f'solved {num_steps} sub-tasks in one invocation with {img_batch.shape[0]} image pairs')
+ if num_steps <= self.batch_size:
+ break
+            # Fall back to the default (ungrouped) inference when too few valid tasks can be grouped together.
+ while True:
+ num_g = self.num_good_tasks(tasks)
+ print(f'{num_g} / {max_corrs} | {self.num_finished_tasks(tasks)} / {len(tasks)}')
+ task_ref, img_batch, query_batch = self.form_batch(tasks, zm)
+ if len(task_ref) == 0:
+ break
+ if num_g >= max_corrs:
+ break
+ out = self.infer_batch(img_batch, query_batch)
+ for t, o in zip(task_ref, out):
+ t.step(o)
+
+ if return_tasks_only:
+ return tasks
+ if return_idx:
+ corrs, idx = self.conclude_tasks(tasks, return_idx=True, force=force,
+ img_a_shape=img_a_shape,
+ img_b_shape=img_b_shape,)
+ corrs = corrs[:max_corrs]
+ idx = idx[:max_corrs]
+ return corrs, idx
+ else:
+ corrs = self.conclude_tasks(tasks, force=force,
+ img_a_shape=img_a_shape,
+ img_b_shape=img_b_shape,)
+ corrs = corrs[:max_corrs]
+ return corrs
diff --git a/imcui/third_party/COTR/COTR/models/__init__.py b/imcui/third_party/COTR/COTR/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..263568f9eb947e24801c8e27242c66a12b4a97ed
--- /dev/null
+++ b/imcui/third_party/COTR/COTR/models/__init__.py
@@ -0,0 +1,10 @@
+'''
+The COTR model is modified from DETR code base.
+https://github.com/facebookresearch/detr
+'''
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+from .cotr_model import build
+
+
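+# Note: `args` is expected to be an argparse Namespace carrying the fields consumed
+# by the sub-builders (backbone, layer, dilation, dim_feedforward, hidden_dim,
+# position_embedding, dropout, nheads, enc_layers, dec_layers), e.g. as populated by
+# COTR.options.options.set_COTR_arguments (dim_feedforward is presumably added
+# elsewhere in the options).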
+def build_model(args):
+ return build(args)
diff --git a/imcui/third_party/COTR/COTR/models/backbone.py b/imcui/third_party/COTR/COTR/models/backbone.py
new file mode 100644
index 0000000000000000000000000000000000000000..bf5b5def2bfe403d3d12077b48ad76d8d97f2f84
--- /dev/null
+++ b/imcui/third_party/COTR/COTR/models/backbone.py
@@ -0,0 +1,135 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+"""
+Backbone modules.
+"""
+from collections import OrderedDict
+
+import torch
+import torch.nn.functional as F
+import torchvision
+from torch import nn
+from torchvision.models._utils import IntermediateLayerGetter
+from typing import Dict, List
+
+from .misc import NestedTensor
+
+from .position_encoding import build_position_encoding
+from COTR.utils import debug_utils, constants
+
+
+class FrozenBatchNorm2d(torch.nn.Module):
+ """
+ BatchNorm2d where the batch statistics and the affine parameters are fixed.
+
+    Copy-paste from torchvision.misc.ops with added eps before rsqrt,
+ without which any other models than torchvision.models.resnet[18,34,50,101]
+ produce nans.
+ """
+
+ def __init__(self, n):
+ super(FrozenBatchNorm2d, self).__init__()
+ self.register_buffer("weight", torch.ones(n))
+ self.register_buffer("bias", torch.zeros(n))
+ self.register_buffer("running_mean", torch.zeros(n))
+ self.register_buffer("running_var", torch.ones(n))
+
+ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
+ missing_keys, unexpected_keys, error_msgs):
+ num_batches_tracked_key = prefix + 'num_batches_tracked'
+ if num_batches_tracked_key in state_dict:
+ del state_dict[num_batches_tracked_key]
+
+ super(FrozenBatchNorm2d, self)._load_from_state_dict(
+ state_dict, prefix, local_metadata, strict,
+ missing_keys, unexpected_keys, error_msgs)
+
+ def forward(self, x):
+ # move reshapes to the beginning
+ # to make it fuser-friendly
+ w = self.weight.reshape(1, -1, 1, 1)
+ b = self.bias.reshape(1, -1, 1, 1)
+ rv = self.running_var.reshape(1, -1, 1, 1)
+ rm = self.running_mean.reshape(1, -1, 1, 1)
+ eps = 1e-5
+ scale = w * (rv + eps).rsqrt()
+ bias = b - rm * scale
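+        # equivalent to (x - running_mean) / sqrt(running_var + eps) * weight + bias,
+        # refactored into a single scale-and-shift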
+ return x * scale + bias
+
+
+class BackboneBase(nn.Module):
+
+ def __init__(self, backbone: nn.Module, train_backbone: bool, num_channels: int, return_interm_layers: bool, layer='layer3'):
+ super().__init__()
+ for name, parameter in backbone.named_parameters():
+ if not train_backbone or 'layer2' not in name and 'layer3' not in name and 'layer4' not in name:
+ parameter.requires_grad_(False)
+ # print(f'freeze {name}')
+ if return_interm_layers:
+ return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"}
+ else:
+ return_layers = {layer: "0"}
+ self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
+ self.num_channels = num_channels
+
+ def forward_raw(self, x):
+ y = self.body(x)
+ assert len(y.keys()) == 1
+ return y['0']
+
+ def forward(self, tensor_list: NestedTensor):
+ assert tensor_list.tensors.shape[-2:] == (constants.MAX_SIZE, constants.MAX_SIZE * 2)
+ left = self.body(tensor_list.tensors[..., 0:constants.MAX_SIZE])
+ right = self.body(tensor_list.tensors[..., constants.MAX_SIZE:2 * constants.MAX_SIZE])
+ xs = {}
+ for k in left.keys():
+ xs[k] = torch.cat([left[k], right[k]], dim=-1)
+ out: Dict[str, NestedTensor] = {}
+ for name, x in xs.items():
+ m = tensor_list.mask
+ assert m is not None
+ mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0]
+ out[name] = NestedTensor(x, mask)
+ return out
+
+
+class Backbone(BackboneBase):
+ """ResNet backbone with frozen BatchNorm."""
+
+ def __init__(self, name: str,
+ train_backbone: bool,
+ return_interm_layers: bool,
+ dilation: bool,
+ layer='layer3',
+ num_channels=1024):
+ backbone = getattr(torchvision.models, name)(
+ replace_stride_with_dilation=[False, False, dilation],
+ pretrained=True, norm_layer=FrozenBatchNorm2d)
+ super().__init__(backbone, train_backbone, num_channels, return_interm_layers, layer)
+
+
+class Joiner(nn.Sequential):
+ def __init__(self, backbone, position_embedding):
+ super().__init__(backbone, position_embedding)
+
+ def forward(self, tensor_list: NestedTensor):
+ xs = self[0](tensor_list)
+ out: List[NestedTensor] = []
+ pos = []
+ for name, x in xs.items():
+ out.append(x)
+ # position encoding
+ pos.append(self[1](x).to(x.tensors.dtype))
+
+ return out, pos
+
+
+def build_backbone(args):
+ position_embedding = build_position_encoding(args)
+ if hasattr(args, 'lr_backbone'):
+ train_backbone = args.lr_backbone > 0
+ else:
+ train_backbone = False
+ backbone = Backbone(args.backbone, train_backbone, False, args.dilation, layer=args.layer, num_channels=args.dim_feedforward)
+ model = Joiner(backbone, position_embedding)
+ model.num_channels = backbone.num_channels
+ return model
diff --git a/imcui/third_party/COTR/COTR/models/cotr_model.py b/imcui/third_party/COTR/COTR/models/cotr_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..72eced8ada79d04898acfeacd7cf61516b52f8a3
--- /dev/null
+++ b/imcui/third_party/COTR/COTR/models/cotr_model.py
@@ -0,0 +1,51 @@
+import math
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from COTR.utils import debug_utils, constants, utils
+from .misc import (NestedTensor, nested_tensor_from_tensor_list)
+from .backbone import build_backbone
+from .transformer import build_transformer
+from .position_encoding import NerfPositionalEncoding, MLP
+
+
+class COTR(nn.Module):
+
+ def __init__(self, backbone, transformer, sine_type='lin_sine'):
+ super().__init__()
+ self.transformer = transformer
+ hidden_dim = transformer.d_model
+ self.corr_embed = MLP(hidden_dim, hidden_dim, 2, 3)
+ self.query_proj = NerfPositionalEncoding(hidden_dim // 4, sine_type)
+ self.input_proj = nn.Conv2d(backbone.num_channels, hidden_dim, kernel_size=1)
+ self.backbone = backbone
+
+ def forward(self, samples: NestedTensor, queries):
+ if isinstance(samples, (list, torch.Tensor)):
+ samples = nested_tensor_from_tensor_list(samples)
+ features, pos = self.backbone(samples)
+
+ src, mask = features[-1].decompose()
+ assert mask is not None
+ _b, _q, _ = queries.shape
+ queries = queries.reshape(-1, 2)
+ queries = self.query_proj(queries).reshape(_b, _q, -1)
+ queries = queries.permute(1, 0, 2)
+ hs = self.transformer(self.input_proj(src), mask, queries, pos[-1])[0]
+ outputs_corr = self.corr_embed(hs)
+ out = {'pred_corrs': outputs_corr[-1]}
+ return out
+
+
+def build(args):
+ backbone = build_backbone(args)
+ transformer = build_transformer(args)
+ model = COTR(
+ backbone,
+ transformer,
+ sine_type=args.position_embedding,
+ )
+ return model
diff --git a/imcui/third_party/COTR/COTR/models/misc.py b/imcui/third_party/COTR/COTR/models/misc.py
new file mode 100644
index 0000000000000000000000000000000000000000..f6093194fe1b68a75f819cd90a233d7cb00867a7
--- /dev/null
+++ b/imcui/third_party/COTR/COTR/models/misc.py
@@ -0,0 +1,112 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+"""
+Misc functions, including distributed helpers.
+
+Mostly copy-paste from torchvision references.
+"""
+import os
+import subprocess
+import time
+from collections import defaultdict, deque
+import datetime
+import pickle
+from typing import Optional, List
+
+import torch
+import torch.distributed as dist
+from torch import Tensor
+
+# needed due to empty tensor bug in pytorch and torchvision 0.5
+import torchvision
+if float(torchvision.__version__.split('.')[1]) < 7:
+ from torchvision.ops import _new_empty_tensor
+ from torchvision.ops.misc import _output_size
+
+
+def _max_by_axis(the_list):
+ # type: (List[List[int]]) -> List[int]
+ maxes = the_list[0]
+ for sublist in the_list[1:]:
+ for index, item in enumerate(sublist):
+ maxes[index] = max(maxes[index], item)
+ return maxes
+
+
+class NestedTensor(object):
+ def __init__(self, tensors, mask: Optional[Tensor]):
+ self.tensors = tensors
+ self.mask = mask
+
+ def to(self, device):
+ # type: (Device) -> NestedTensor # noqa
+ cast_tensor = self.tensors.to(device)
+ mask = self.mask
+ if mask is not None:
+ assert mask is not None
+ cast_mask = mask.to(device)
+ else:
+ cast_mask = None
+ return NestedTensor(cast_tensor, cast_mask)
+
+ def decompose(self):
+ return self.tensors, self.mask
+
+ def __repr__(self):
+ return str(self.tensors)
+
+
+def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
+ # TODO make this more general
+ if tensor_list[0].ndim == 3:
+ if torchvision._is_tracing():
+ # nested_tensor_from_tensor_list() does not export well to ONNX
+ # call _onnx_nested_tensor_from_tensor_list() instead
+ return _onnx_nested_tensor_from_tensor_list(tensor_list)
+
+ # TODO make it support different-sized images
+ max_size = _max_by_axis([list(img.shape) for img in tensor_list])
+ # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list]))
+ batch_shape = [len(tensor_list)] + max_size
+ b, c, h, w = batch_shape
+ dtype = tensor_list[0].dtype
+ device = tensor_list[0].device
+ tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
+ mask = torch.ones((b, h, w), dtype=torch.bool, device=device)
+ for img, pad_img, m in zip(tensor_list, tensor, mask):
+ pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
+ m[: img.shape[1], :img.shape[2]] = False
+ else:
+ raise ValueError('not supported')
+ return NestedTensor(tensor, mask)
+
+
+# _onnx_nested_tensor_from_tensor_list() is an implementation of
+# nested_tensor_from_tensor_list() that is supported by ONNX tracing.
+@torch.jit.unused
+def _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor]) -> NestedTensor:
+ max_size = []
+ for i in range(tensor_list[0].dim()):
+ max_size_i = torch.max(torch.stack([img.shape[i] for img in tensor_list]).to(torch.float32)).to(torch.int64)
+ max_size.append(max_size_i)
+ max_size = tuple(max_size)
+
+ # work around for
+ # pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
+ # m[: img.shape[1], :img.shape[2]] = False
+ # which is not yet supported in onnx
+ padded_imgs = []
+ padded_masks = []
+ for img in tensor_list:
+ padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))]
+ padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0]))
+ padded_imgs.append(padded_img)
+
+ m = torch.zeros_like(img[0], dtype=torch.int, device=img.device)
+ padded_mask = torch.nn.functional.pad(m, (0, padding[2], 0, padding[1]), "constant", 1)
+ padded_masks.append(padded_mask.to(torch.bool))
+
+ tensor = torch.stack(padded_imgs)
+ mask = torch.stack(padded_masks)
+
+ return NestedTensor(tensor, mask=mask)
+
diff --git a/imcui/third_party/COTR/COTR/models/position_encoding.py b/imcui/third_party/COTR/COTR/models/position_encoding.py
new file mode 100644
index 0000000000000000000000000000000000000000..9207015d2671b87ec80c58b1f711e185af5db8de
--- /dev/null
+++ b/imcui/third_party/COTR/COTR/models/position_encoding.py
@@ -0,0 +1,83 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+"""
+Various positional encodings for the transformer.
+"""
+import math
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+from .misc import NestedTensor
+from COTR.utils import debug_utils
+
+
+class MLP(nn.Module):
+ """ Very simple multi-layer perceptron (also called FFN)"""
+
+ def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
+ super().__init__()
+ self.num_layers = num_layers
+ h = [hidden_dim] * (num_layers - 1)
+ self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
+
+ def forward(self, x):
+ for i, layer in enumerate(self.layers):
+ x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
+ return x
+
+
+class NerfPositionalEncoding(nn.Module):
+ def __init__(self, depth=10, sine_type='lin_sine'):
+ '''
+ out_dim = in_dim * depth * 2
+ '''
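+        # e.g. COTR builds this with depth = hidden_dim // 4 = 64 for 2-D query
+        # coordinates, giving out_dim = 2 * 64 * 2 = 256 = hidden_dim.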
+ super().__init__()
+ if sine_type == 'lin_sine':
+ self.bases = [i+1 for i in range(depth)]
+ elif sine_type == 'exp_sine':
+ self.bases = [2**i for i in range(depth)]
+ print(f'using {sine_type} as positional encoding')
+
+ @torch.no_grad()
+ def forward(self, inputs):
+ out = torch.cat([torch.sin(i * math.pi * inputs) for i in self.bases] + [torch.cos(i * math.pi * inputs) for i in self.bases], axis=-1)
+ assert torch.isnan(out).any() == False
+ return out
+
+
+class PositionEmbeddingSine(nn.Module):
+ """
+ This is a more standard version of the position embedding, very similar to the one
+ used by the Attention is all you need paper, generalized to work on images.
+ """
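+    # Unlike the original DETR embedding, the per-axis sinusoids here come from
+    # NerfPositionalEncoding applied to the normalized (x, y) grid built in forward().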
+ def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None, sine_type='lin_sine'):
+ super().__init__()
+ self.num_pos_feats = num_pos_feats
+ self.temperature = temperature
+ self.normalize = normalize
+ self.sine = NerfPositionalEncoding(num_pos_feats//2, sine_type)
+
+ @torch.no_grad()
+ def forward(self, tensor_list: NestedTensor):
+ x = tensor_list.tensors
+ mask = tensor_list.mask
+ assert mask is not None
+ not_mask = ~mask
+ y_embed = not_mask.cumsum(1, dtype=torch.float32)
+ x_embed = not_mask.cumsum(2, dtype=torch.float32)
+ eps = 1e-6
+ y_embed = (y_embed-0.5) / (y_embed[:, -1:, :] + eps)
+ x_embed = (x_embed-0.5) / (x_embed[:, :, -1:] + eps)
+ pos = torch.stack([x_embed, y_embed], dim=-1)
+ return self.sine(pos).permute(0, 3, 1, 2)
+
+
+def build_position_encoding(args):
+ N_steps = args.hidden_dim // 2
+ if args.position_embedding in ('lin_sine', 'exp_sine'):
+ # TODO find a better way of exposing other arguments
+ position_embedding = PositionEmbeddingSine(N_steps, normalize=True, sine_type=args.position_embedding)
+ else:
+ raise ValueError(f"not supported {args.position_embedding}")
+
+ return position_embedding
diff --git a/imcui/third_party/COTR/COTR/models/transformer.py b/imcui/third_party/COTR/COTR/models/transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..69ba7e318dd4c0ba80042742e19360ec6dfad683
--- /dev/null
+++ b/imcui/third_party/COTR/COTR/models/transformer.py
@@ -0,0 +1,228 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+"""
+COTR/DETR Transformer class.
+
+Copy-paste from torch.nn.Transformer with modifications:
+ * positional encodings are passed in MHattention
+ * extra LN at the end of encoder is removed
+ * decoder returns a stack of activations from all decoding layers
+"""
+import copy
+from typing import Optional, List
+
+import torch
+import torch.nn.functional as F
+from torch import nn, Tensor
+
+from COTR.utils import debug_utils
+
+
+class Transformer(nn.Module):
+
+ def __init__(self, d_model=512, nhead=8, num_encoder_layers=6,
+ num_decoder_layers=6, dim_feedforward=2048, dropout=0.1,
+ activation="relu", return_intermediate_dec=False):
+ super().__init__()
+
+ encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward,
+ dropout, activation)
+ self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers)
+
+ decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward,
+ dropout, activation)
+ decoder_norm = nn.LayerNorm(d_model)
+ self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm,
+ return_intermediate=return_intermediate_dec)
+
+ self._reset_parameters()
+
+ self.d_model = d_model
+ self.nhead = nhead
+
+ def _reset_parameters(self):
+ for p in self.parameters():
+ if p.dim() > 1:
+ nn.init.xavier_uniform_(p)
+
+ def forward(self, src, mask, query_embed, pos_embed):
+ # flatten NxCxHxW to HWxNxC
+ bs, c, h, w = src.shape
+ src = src.flatten(2).permute(2, 0, 1)
+ pos_embed = pos_embed.flatten(2).permute(2, 0, 1)
+ mask = mask.flatten(1)
+
+ tgt = torch.zeros_like(query_embed)
+ memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)
+ hs = self.decoder(tgt, memory, memory_key_padding_mask=mask,
+ pos=pos_embed, query_pos=query_embed)
+ return hs.transpose(1, 2), memory.permute(1, 2, 0).view(bs, c, h, w)
+
+
+class TransformerEncoder(nn.Module):
+
+ def __init__(self, encoder_layer, num_layers):
+ super().__init__()
+ self.layers = _get_clones(encoder_layer, num_layers)
+ self.num_layers = num_layers
+
+ def forward(self, src,
+ mask: Optional[Tensor] = None,
+ src_key_padding_mask: Optional[Tensor] = None,
+ pos: Optional[Tensor] = None):
+ output = src
+
+ for layer in self.layers:
+ output = layer(output, src_mask=mask,
+ src_key_padding_mask=src_key_padding_mask, pos=pos)
+
+ return output
+
+
+class TransformerDecoder(nn.Module):
+
+ def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
+ super().__init__()
+ self.layers = _get_clones(decoder_layer, num_layers)
+ self.num_layers = num_layers
+ self.norm = norm
+ self.return_intermediate = return_intermediate
+
+ def forward(self, tgt, memory,
+ tgt_mask: Optional[Tensor] = None,
+ memory_mask: Optional[Tensor] = None,
+ tgt_key_padding_mask: Optional[Tensor] = None,
+ memory_key_padding_mask: Optional[Tensor] = None,
+ pos: Optional[Tensor] = None,
+ query_pos: Optional[Tensor] = None):
+ output = tgt
+
+ intermediate = []
+
+ for layer in self.layers:
+ output = layer(output, memory, tgt_mask=tgt_mask,
+ memory_mask=memory_mask,
+ tgt_key_padding_mask=tgt_key_padding_mask,
+ memory_key_padding_mask=memory_key_padding_mask,
+ pos=pos, query_pos=query_pos)
+ if self.return_intermediate:
+ intermediate.append(self.norm(output))
+
+ if self.norm is not None:
+ output = self.norm(output)
+ if self.return_intermediate:
+ intermediate.pop()
+ intermediate.append(output)
+
+ if self.return_intermediate:
+ return torch.stack(intermediate)
+
+ return output.unsqueeze(0)
+
+
+class TransformerEncoderLayer(nn.Module):
+
+ def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
+ activation="relu"):
+ super().__init__()
+ self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
+ # Implementation of Feedforward model
+ self.linear1 = nn.Linear(d_model, dim_feedforward)
+ self.dropout = nn.Dropout(dropout)
+ self.linear2 = nn.Linear(dim_feedforward, d_model)
+
+ self.norm1 = nn.LayerNorm(d_model)
+ self.norm2 = nn.LayerNorm(d_model)
+ self.dropout1 = nn.Dropout(dropout)
+ self.dropout2 = nn.Dropout(dropout)
+
+ self.activation = _get_activation_fn(activation)
+
+ def with_pos_embed(self, tensor, pos: Optional[Tensor]):
+ return tensor if pos is None else tensor + pos
+
+ def forward(self,
+ src,
+ src_mask: Optional[Tensor] = None,
+ src_key_padding_mask: Optional[Tensor] = None,
+ pos: Optional[Tensor] = None):
+ q = k = self.with_pos_embed(src, pos)
+ src2 = self.self_attn(query=q,
+ key=k,
+ value=src,
+ attn_mask=src_mask,
+ key_padding_mask=src_key_padding_mask)[0]
+ src = src + self.dropout1(src2)
+ src = self.norm1(src)
+ src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
+ src = src + self.dropout2(src2)
+ src = self.norm2(src)
+ return src
+
+
+class TransformerDecoderLayer(nn.Module):
+
+ def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
+ activation="relu"):
+ super().__init__()
+ self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
+ # Implementation of Feedforward model
+ self.linear1 = nn.Linear(d_model, dim_feedforward)
+ self.dropout = nn.Dropout(dropout)
+ self.linear2 = nn.Linear(dim_feedforward, d_model)
+
+ self.norm1 = nn.LayerNorm(d_model)
+ self.norm2 = nn.LayerNorm(d_model)
+ self.norm3 = nn.LayerNorm(d_model)
+ self.dropout1 = nn.Dropout(dropout)
+ self.dropout2 = nn.Dropout(dropout)
+ self.dropout3 = nn.Dropout(dropout)
+
+ self.activation = _get_activation_fn(activation)
+
+ def with_pos_embed(self, tensor, pos: Optional[Tensor]):
+ return tensor if pos is None else tensor + pos
+
+ def forward(self, tgt, memory,
+ tgt_mask: Optional[Tensor] = None,
+ memory_mask: Optional[Tensor] = None,
+ tgt_key_padding_mask: Optional[Tensor] = None,
+ memory_key_padding_mask: Optional[Tensor] = None,
+ pos: Optional[Tensor] = None,
+ query_pos: Optional[Tensor] = None):
+ tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos),
+ key=self.with_pos_embed(memory, pos),
+ value=memory, attn_mask=memory_mask,
+ key_padding_mask=memory_key_padding_mask)[0]
+ tgt = tgt + self.dropout2(tgt2)
+ tgt = self.norm2(tgt)
+ tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
+ tgt = tgt + self.dropout3(tgt2)
+ tgt = self.norm3(tgt)
+ return tgt
+
+
+def _get_clones(module, N):
+ return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
+
+
+def build_transformer(args):
+ return Transformer(
+ d_model=args.hidden_dim,
+ dropout=args.dropout,
+ nhead=args.nheads,
+ dim_feedforward=args.dim_feedforward,
+ num_encoder_layers=args.enc_layers,
+ num_decoder_layers=args.dec_layers,
+ return_intermediate_dec=True,
+ )
+
+
+def _get_activation_fn(activation):
+ """Return an activation function given a string"""
+ if activation == "relu":
+ return F.relu
+ if activation == "gelu":
+ return F.gelu
+ if activation == "glu":
+ return F.glu
+    raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")
diff --git a/imcui/third_party/COTR/COTR/options/options.py b/imcui/third_party/COTR/COTR/options/options.py
new file mode 100644
index 0000000000000000000000000000000000000000..d3b7d8b739b4de554d7c3c0bb3679a5f10434659
--- /dev/null
+++ b/imcui/third_party/COTR/COTR/options/options.py
@@ -0,0 +1,52 @@
+import sys
+import argparse
+import json
+import os
+
+
+from COTR.options.options_utils import str2bool
+from COTR.options import options_utils
+from COTR.global_configs import general_config, dataset_config
+from COTR.utils import debug_utils
+
+
+def set_general_arguments(parser):
+ general_arg = parser.add_argument_group('General')
+ general_arg.add_argument('--confirm', type=str2bool,
+                             default=True, help='prompt the user for confirmation')
+ general_arg.add_argument('--use_cuda', type=str2bool,
+ default=True, help='use cuda')
+ general_arg.add_argument('--use_cc', type=str2bool,
+ default=False, help='use computecanada')
+
+
+def set_dataset_arguments(parser):
+ data_arg = parser.add_argument_group('Data')
+ data_arg.add_argument('--dataset_name', type=str, default='megadepth', help='dataset name')
+ data_arg.add_argument('--shuffle_data', type=str2bool, default=True, help='use sequence dataset or shuffled dataset')
+ data_arg.add_argument('--use_ram', type=str2bool, default=False, help='load image/depth/pcd to ram')
+ data_arg.add_argument('--info_level', choices=['rgb', 'rgbd'], type=str, default='rgbd', help='the information level of dataset')
+ data_arg.add_argument('--scene_file', type=str, default=None, required=False, help='what scene/seq want to use')
+ data_arg.add_argument('--workers', type=int, default=0, help='worker for loading data')
+ data_arg.add_argument('--crop_cam', choices=['no_crop', 'crop_center', 'crop_center_and_resize'], type=str, default='crop_center_and_resize', help='crop the center of image to avoid changing aspect ratio, resize to make the operations batch-able.')
+
+
+def set_nn_arguments(parser):
+ nn_arg = parser.add_argument_group('Nearest neighbors')
+ nn_arg.add_argument('--nn_method', choices=['netvlad', 'overlapping'], type=str, default='overlapping', help='how to select nearest neighbors')
+ nn_arg.add_argument('--pool_size', type=int, default=20, help='a pool of sorted nn candidates')
+ nn_arg.add_argument('--k_size', type=int, default=1, help='select the nn randomly from pool')
+
+
+def set_COTR_arguments(parser):
+ cotr_arg = parser.add_argument_group('COTR model')
+ cotr_arg.add_argument('--backbone', type=str, default='resnet50')
+ cotr_arg.add_argument('--hidden_dim', type=int, default=256)
+ cotr_arg.add_argument('--dilation', type=str2bool, default=False)
+ cotr_arg.add_argument('--dropout', type=float, default=0.1)
+ cotr_arg.add_argument('--nheads', type=int, default=8)
+ cotr_arg.add_argument('--layer', type=str, default='layer3', help='which layer from resnet')
+ cotr_arg.add_argument('--enc_layers', type=int, default=6)
+ cotr_arg.add_argument('--dec_layers', type=int, default=6)
+ cotr_arg.add_argument('--position_embedding', type=str, default='lin_sine', help='sine wave type')
+
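+
+ if __name__ == '__main__':
+     # Minimal sketch (hypothetical, not one of the training scripts): compose the
+     # argument groups above into a single parser and inspect the resulting defaults.
+     parser = argparse.ArgumentParser()
+     set_general_arguments(parser)
+     set_dataset_arguments(parser)
+     set_nn_arguments(parser)
+     set_COTR_arguments(parser)
+     opt = parser.parse_args([])
+     print(options_utils.opt_to_string(opt))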
diff --git a/imcui/third_party/COTR/COTR/options/options_utils.py b/imcui/third_party/COTR/COTR/options/options_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..12fc1f1ae98ec89ad249fd99c7f9415e350f93de
--- /dev/null
+++ b/imcui/third_party/COTR/COTR/options/options_utils.py
@@ -0,0 +1,108 @@
+'''utils for argparse
+'''
+
+import sys
+import os
+from os import path
+import time
+import json
+
+from COTR.utils import utils, debug_utils
+from COTR.global_configs import general_config, dataset_config
+
+
+def str2bool(v: str) -> bool:
+ return v.lower() in ('true', '1', 'yes', 'y', 't')
+
+
+def get_compact_naming_cotr(opt) -> str:
+ base_str = 'model:cotr_{0}_{1}_{2}_dset:{3}_bs:{4}_pe:{5}_lrbackbone:{6}'
+ result = base_str.format(opt.backbone,
+ opt.layer,
+ opt.dim_feedforward,
+ opt.dataset_name,
+ opt.batch_size,
+ opt.position_embedding,
+ opt.lr_backbone,
+ )
+ if opt.suffix:
+ result = result + '_suffix:{0}'.format(opt.suffix)
+ return result
+
+
+def print_opt(opt):
+ content_list = []
+ args = list(vars(opt))
+ args.sort()
+ for arg in args:
+ content_list += [arg.rjust(25, ' ') + ' ' + str(getattr(opt, arg))]
+ utils.print_notification(content_list, 'OPTIONS')
+
+
+def confirm_opt(opt):
+ print_opt(opt)
+ if opt.use_cc == False:
+ if not utils.confirm():
+ exit(1)
+
+
+def opt_to_string(opt) -> str:
+ string = '\n\n'
+ string += 'python ' + ' '.join(sys.argv)
+ string += '\n\n'
+ # string += '---------------------- CONFIG ----------------------\n'
+ args = list(vars(opt))
+ args.sort()
+ for arg in args:
+ string += arg.rjust(25, ' ') + ' ' + str(getattr(opt, arg)) + '\n\n'
+ # string += '----------------------------------------------------\n'
+ return string
+
+
+def save_opt(opt):
+ '''save options to a json file
+ '''
+ if not os.path.exists(opt.out):
+ os.makedirs(opt.out)
+ json_path = os.path.join(opt.out, 'params.json')
+ if 'debug' not in opt.suffix and path.isfile(json_path):
+ assert opt.resume, 'You are trying to modify a model without resuming: {0}'.format(opt.out)
+ old_dict = json.load(open(json_path))
+ new_dict = vars(opt)
+ # assert old_dict.keys() == new_dict.keys(), 'New configuration keys is different from old one.\nold: {0}\nnew: {1}'.format(old_dict.keys(), new_dict.keys())
+ if new_dict != old_dict:
+ exception_keys = ['command']
+ for key in set(old_dict.keys()).union(set(new_dict.keys())):
+ if key not in exception_keys:
+ old_val = old_dict[key] if key in old_dict else 'not exists(old)'
+ new_val = new_dict[key] if key in new_dict else 'not exists(new)'
+ if old_val != new_val:
+ print('key: {0}, old_val: {1}, new_val: {2}'.format(key, old_val, new_val))
+ if opt.use_cc == False:
+ if not utils.confirm('Please manually confirm'):
+ exit(1)
+ with open(json_path, 'w') as fp:
+ json.dump(vars(opt), fp, indent=0, sort_keys=True)
+
+
+def build_scenes_name_list_from_opt(opt):
+ if hasattr(opt, 'scene_file') and opt.scene_file is not None:
+ assert os.path.isfile(opt.scene_file), opt.scene_file
+ with open(opt.scene_file, 'r') as f:
+ scenes_list = json.load(f)
+ else:
+ scenes_list = [{'scene': opt.scene, 'seq': opt.seq}]
+ if 'megadepth' in opt.dataset_name:
+ assert opt.info_level in ['rgb', 'rgbd']
+ scenes_name_list = []
+ if opt.info_level == 'rgb':
+ dir_list = ['scene_dir', 'image_dir']
+ elif opt.info_level == 'rgbd':
+ dir_list = ['scene_dir', 'image_dir', 'depth_dir']
+ dir_list = {dir_name: dataset_config[opt.dataset_name][dir_name] for dir_name in dir_list}
+ for item in scenes_list:
+ cur_scene = {key: val.format(item['scene'], item['seq']) for key, val in dir_list.items()}
+ scenes_name_list.append(cur_scene)
+ else:
+ raise NotImplementedError()
+ return scenes_name_list
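+
+
+ # Note on the expected --scene_file format (values below are hypothetical): the file
+ # should contain a JSON list of records with 'scene' and 'seq' keys, e.g.
+ #     [{"scene": "0015", "seq": "0"}, {"scene": "0022", "seq": "0"}]
+ # Each pair is substituted into the *_dir templates of dataset_config to build the
+ # per-scene directory list returned by build_scenes_name_list_from_opt.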
diff --git a/imcui/third_party/COTR/COTR/projector/pcd_projector.py b/imcui/third_party/COTR/COTR/projector/pcd_projector.py
new file mode 100644
index 0000000000000000000000000000000000000000..4847bd50920c84066fa7b36ea1c434521d7f12d9
--- /dev/null
+++ b/imcui/third_party/COTR/COTR/projector/pcd_projector.py
@@ -0,0 +1,210 @@
+'''
+ a point cloud projector based on numpy
+'''
+
+import numpy as np
+
+from COTR.utils import debug_utils, utils
+
+
+def render_point_cloud_at_capture(point_cloud, capture, render_type='rgb', return_pcd=False):
+ assert render_type in ['rgb', 'bw', 'depth']
+ if render_type == 'rgb':
+ assert point_cloud.shape[1] == 6
+ else:
+ point_cloud = point_cloud[:, :3]
+ assert point_cloud.shape[1] == 3
+ if render_type in ['bw', 'rgb']:
+ keep_z = False
+ else:
+ keep_z = True
+
+ pcd_2d = PointCloudProjector.pcd_3d_to_pcd_2d_np(point_cloud,
+ capture.intrinsic_mat,
+ capture.extrinsic_mat,
+ capture.size,
+ keep_z=True,
+ crop=True,
+ filter_neg=True,
+ norm_coord=False,
+ return_index=False)
+ reproj = PointCloudProjector.pcd_2d_to_img_2d_np(pcd_2d,
+ capture.size,
+ has_z=True,
+ keep_z=keep_z)
+ if return_pcd:
+ return reproj, pcd_2d
+ else:
+ return reproj
+
+
+def optical_flow_from_a_to_b(cap_a, cap_b):
+ cap_a_intrinsic = cap_a.pinhole_cam.intrinsic_mat
+ cap_a_img_size = cap_a.pinhole_cam.shape[:2]
+ _h, _w = cap_b.pinhole_cam.shape[:2]
+ x, y = np.meshgrid(
+ np.linspace(0, _w - 1, num=_w),
+ np.linspace(0, _h - 1, num=_h),
+ )
+ coord_map = np.concatenate([np.expand_dims(x, 2), np.expand_dims(y, 2)], axis=2)
+ pcd_from_cap_b = cap_b.get_point_cloud_world_from_depth(coord_map)
+ # pcd_from_cap_b = cap_b.point_cloud_world_w_feat(['pos', 'coord'])
+ optical_flow = PointCloudProjector.pcd_2d_to_img_2d_np(PointCloudProjector.pcd_3d_to_pcd_2d_np(pcd_from_cap_b, cap_a_intrinsic, cap_a.cam_pose.world_to_camera[0:3, :], cap_a_img_size, keep_z=True, crop=True, filter_neg=True, norm_coord=False), cap_a_img_size, has_z=True, keep_z=False)
+ return optical_flow
+
+
+class PointCloudProjector():
+ def __init__(self):
+ pass
+
+ @staticmethod
+ def pcd_2d_to_pcd_3d_np(pcd, depth, intrinsic, motion=None, return_index=False):
+ assert isinstance(pcd, np.ndarray), 'cannot process data type: {0}'.format(type(pcd))
+ assert isinstance(intrinsic, np.ndarray), 'cannot process data type: {0}'.format(type(intrinsic))
+ assert len(pcd.shape) == 2 and pcd.shape[1] >= 2
+ assert len(depth.shape) == 2 and depth.shape[1] == 1
+ assert intrinsic.shape == (3, 3)
+ if motion is not None:
+ assert isinstance(motion, np.ndarray), 'cannot process data type: {0}'.format(type(motion))
+ assert motion.shape == (4, 4)
+ # exec(debug_utils.embed_breakpoint())
+ x, y, z = pcd[:, 0], pcd[:, 1], depth[:, 0]
+ append_ones = np.ones_like(x)
+ xyz = np.stack([x, y, append_ones], axis=1) # shape: [num_points, 3]
+ inv_intrinsic_mat = np.linalg.inv(intrinsic)
+ xyz = np.matmul(inv_intrinsic_mat, xyz.T).T * z[..., None]
+ valid_mask_1 = np.where(xyz[:, 2] > 0)
+ xyz = xyz[valid_mask_1]
+
+ if motion is not None:
+ append_ones = np.ones_like(xyz[:, 0:1])
+ xyzw = np.concatenate([xyz, append_ones], axis=1)
+ xyzw = np.matmul(motion, xyzw.T).T
+ valid_mask_2 = np.where(xyzw[:, 3] != 0)
+ xyzw = xyzw[valid_mask_2]
+ xyzw /= xyzw[:, 3:4]
+ xyz = xyzw[:, 0:3]
+
+ if pcd.shape[1] > 2:
+ features = pcd[:, 2:]
+ try:
+ features = features[valid_mask_1][valid_mask_2]
+ except UnboundLocalError:
+ features = features[valid_mask_1]
+ assert xyz.shape[0] == features.shape[0]
+ xyz = np.concatenate([xyz, features], axis=1)
+ if return_index:
+ points_index = np.arange(pcd.shape[0])[valid_mask_1][valid_mask_2]
+ return xyz, points_index
+ return xyz
+
+ @staticmethod
+ def img_2d_to_pcd_3d_np(depth, intrinsic, img=None, motion=None):
+ '''
+ Lift a depth map (and, optionally, per-pixel features from img) to a 3D point cloud.
+ If motion is None, the output point cloud is in camera space;
+ if motion is the camera-to-world matrix, the output is in world space.
+ In both cases the output is a plain numpy array.
+ '''
+
+ assert isinstance(depth, np.ndarray), 'cannot process data type: {0}'.format(type(depth))
+ assert isinstance(intrinsic, np.ndarray), 'cannot process data type: {0}'.format(type(intrinsic))
+ assert len(depth.shape) == 2
+ assert intrinsic.shape == (3, 3)
+ if img is not None:
+ assert isinstance(img, np.ndarray), 'cannot process data type: {0}'.format(type(img))
+ assert len(img.shape) == 3
+ assert img.shape[:2] == depth.shape[:2], 'feature should have the same resolution as the depth'
+ if motion is not None:
+ assert isinstance(motion, np.ndarray), 'cannot process data type: {0}'.format(type(motion))
+ assert motion.shape == (4, 4)
+
+ pcd_image_space = PointCloudProjector.img_2d_to_pcd_2d_np(depth[..., None], norm_coord=False)
+ valid_mask_1 = np.where(pcd_image_space[:, 2] > 0)
+ pcd_image_space = pcd_image_space[valid_mask_1]
+ xy = pcd_image_space[:, :2]
+ z = pcd_image_space[:, 2:3]
+ if img is not None:
+ _c = img.shape[-1]
+ feat = img.reshape(-1, _c)
+ feat = feat[valid_mask_1]
+ xy = np.concatenate([xy, feat], axis=1)
+ pcd_3d = PointCloudProjector.pcd_2d_to_pcd_3d_np(xy, z, intrinsic, motion=motion)
+ return pcd_3d
+
+ @staticmethod
+ def pcd_3d_to_pcd_2d_np(pcd, intrinsic, extrinsic, size, keep_z: bool, crop: bool = True, filter_neg: bool = True, norm_coord: bool = True, return_index: bool = False):
+ assert isinstance(pcd, np.ndarray), 'cannot process data type: {0}'.format(type(pcd))
+ assert isinstance(intrinsic, np.ndarray), 'cannot process data type: {0}'.format(type(intrinsic))
+ assert isinstance(extrinsic, np.ndarray), 'cannot process data type: {0}'.format(type(extrinsic))
+ assert len(pcd.shape) == 2 and pcd.shape[1] >= 3, 'seems the input pcd is not a valid 3d point cloud: {0}'.format(pcd.shape)
+
+ xyzw = np.concatenate([pcd[:, 0:3], np.ones_like(pcd[:, 0:1])], axis=1)
+ mvp_mat = np.matmul(intrinsic, extrinsic)
+ camera_points = np.matmul(mvp_mat, xyzw.T).T
+ if filter_neg:
+ valid_mask_1 = camera_points[:, 2] > 0.0
+ else:
+ valid_mask_1 = np.ones_like(camera_points[:, 2], dtype=bool)
+ camera_points = camera_points[valid_mask_1]
+ image_points = camera_points / camera_points[:, 2:3]
+ image_points = image_points[:, :2]
+ if crop:
+ valid_mask_2 = (image_points[:, 0] >= 0) * (image_points[:, 0] < size[1] - 1) * (image_points[:, 1] >= 0) * (image_points[:, 1] < size[0] - 1)
+ else:
+ valid_mask_2 = np.ones_like(image_points[:, 0], dtype=bool)
+ if norm_coord:
+ image_points = ((image_points / size[::-1]) * 2) - 1
+
+ if keep_z:
+ image_points = np.concatenate([image_points[valid_mask_2], camera_points[valid_mask_2][:, 2:3], pcd[valid_mask_1][:, 3:][valid_mask_2]], axis=1)
+ else:
+ image_points = np.concatenate([image_points[valid_mask_2], pcd[valid_mask_1][:, 3:][valid_mask_2]], axis=1)
+ # if filter_neg and crop:
+ # exec(debug_utils.embed_breakpoint('pcd_3d_to_pcd_2d_np'))
+ if return_index:
+ points_index = np.arange(pcd.shape[0])[valid_mask_1][valid_mask_2]
+ return image_points, points_index
+ return image_points
+
+ @staticmethod
+ def pcd_2d_to_img_2d_np(pcd, size, has_z=False, keep_z=False):
+ assert len(pcd.shape) == 2 and pcd.shape[-1] >= 2, 'seems the input pcd is not a valid point cloud: {0}'.format(pcd.shape)
+ # assert 0, 'pass Z values in'
+ if has_z:
+ pcd = pcd[pcd[:, 2].argsort()[::-1]]
+ if not keep_z:
+ pcd = np.delete(pcd, [2], axis=1)
+ index_list = np.round(pcd[:, 0:2]).astype(np.int32)
+ index_list[:, 0] = np.clip(index_list[:, 0], 0, size[1] - 1)
+ index_list[:, 1] = np.clip(index_list[:, 1], 0, size[0] - 1)
+ _h, _w, _c = *size, pcd.shape[-1] - 2
+ if _c == 0:
+ canvas = np.zeros((_h, _w, 1))
+ canvas[index_list[:, 1], index_list[:, 0]] = 1.0
+ else:
+ canvas = np.zeros((_h, _w, _c))
+ canvas[index_list[:, 1], index_list[:, 0]] = pcd[:, 2:]
+
+ return canvas
+
+ @staticmethod
+ def img_2d_to_pcd_2d_np(img, norm_coord=True):
+ assert isinstance(img, np.ndarray), 'cannot process data type: {0}'.format(type(img))
+ assert len(img.shape) == 3
+
+ _h, _w, _c = img.shape
+ if norm_coord:
+ x, y = np.meshgrid(
+ np.linspace(-1, 1, num=_w),
+ np.linspace(-1, 1, num=_h),
+ )
+ else:
+ x, y = np.meshgrid(
+ np.linspace(0, _w - 1, num=_w),
+ np.linspace(0, _h - 1, num=_h),
+ )
+ x, y = x.reshape(-1, 1), y.reshape(-1, 1)
+ feat = img.reshape(-1, _c)
+ pcd_2d = np.concatenate([x, y, feat], axis=1)
+ return pcd_2d
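+
+
+ if __name__ == '__main__':
+     # Minimal round-trip sketch on toy data (hypothetical intrinsics, identity extrinsics),
+     # not part of the COTR pipeline: depth map -> camera-space point cloud -> depth map.
+     toy_intrinsic = np.array([[2.0, 0.0, 2.0],
+                               [0.0, 2.0, 2.0],
+                               [0.0, 0.0, 1.0]])
+     toy_depth = np.ones((4, 4))
+     pcd_3d = PointCloudProjector.img_2d_to_pcd_3d_np(toy_depth, toy_intrinsic)
+     pcd_2d = PointCloudProjector.pcd_3d_to_pcd_2d_np(pcd_3d, toy_intrinsic, np.eye(4)[:3],
+                                                      (4, 4), keep_z=True, norm_coord=False)
+     canvas = PointCloudProjector.pcd_2d_to_img_2d_np(pcd_2d, (4, 4), has_z=True, keep_z=True)
+     print(canvas.shape)  # (4, 4, 1): the reprojected depth map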
diff --git a/imcui/third_party/COTR/COTR/sfm_scenes/knn_search.py b/imcui/third_party/COTR/COTR/sfm_scenes/knn_search.py
new file mode 100644
index 0000000000000000000000000000000000000000..af2bda01cd48571c8641569f2f7d243288de4899
--- /dev/null
+++ b/imcui/third_party/COTR/COTR/sfm_scenes/knn_search.py
@@ -0,0 +1,56 @@
+'''
+Given one capture in a scene, search for its KNN captures
+'''
+
+import os
+
+import numpy as np
+
+from COTR.utils import debug_utils
+from COTR.utils.constants import VALID_NN_OVERLAPPING_THRESH
+
+
+class ReprojRatioKnnSearch():
+ def __init__(self, scene):
+ self.scene = scene
+ self.distance_mat = None
+ self.nn_index = None
+ self._read_dist_mat()
+ self._build_nn_index()
+
+ def _read_dist_mat(self):
+ dist_mat_path = os.path.join(os.path.dirname(os.path.dirname(self.scene.captures[0].depth_path)), 'dist_mat/dist_mat.npy')
+ self.distance_mat = np.load(dist_mat_path)
+
+ def _build_nn_index(self):
+ # argsort is in ascending order, so we take negative
+ self.nn_index = (-1 * self.distance_mat).argsort(axis=1)
+
+ def get_knn(self, query, k, db_mask=None):
+ query_index = self.scene.img_path_to_index_dict[query.img_path]
+ if db_mask is not None:
+ query_mask = np.setdiff1d(np.arange(self.distance_mat[query_index].shape[0]), db_mask)
+ num_pos = (self.distance_mat[query_index] > VALID_NN_OVERLAPPING_THRESH).sum() if db_mask is None else (self.distance_mat[query_index][db_mask] > VALID_NN_OVERLAPPING_THRESH).sum()
+ # check whether we have enough valid nearest neighbours
+ if num_pos > k:
+ if db_mask is None:
+ ind = self.nn_index[query_index][:k + 1]
+ else:
+ temp_dist = self.distance_mat[query_index].copy()
+ temp_dist[query_mask] = -1
+ ind = (-1 * temp_dist).argsort(axis=0)[:k + 1]
+ # remove self
+ if query_index in ind:
+ ind = np.delete(ind, np.argwhere(ind == query_index))
+ else:
+ ind = ind[:k]
+ assert 0 < ind.shape[0] <= k
+ else:
+ k = num_pos
+ if db_mask is None:
+ ind = self.nn_index[query_index][:max(k, 1)]
+ else:
+ temp_dist = self.distance_mat[query_index].copy()
+ temp_dist[query_mask] = -1
+ ind = (-1 * temp_dist).argsort(axis=0)[:max(k, 1)]
+ return self.scene.get_captures_given_index_list(ind)
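+
+
+ if __name__ == '__main__':
+     # Toy illustration of the descending-order trick used in _build_nn_index
+     # (a made-up 3x3 overlap matrix, not a real dist_mat.npy).
+     toy_dist = np.array([[1.0, 0.3, 0.8],
+                          [0.3, 1.0, 0.1],
+                          [0.8, 0.1, 1.0]])
+     toy_nn_index = (-1 * toy_dist).argsort(axis=1)
+     print(toy_nn_index[0])  # [0 2 1]: self first, then neighbours by decreasing overlap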
diff --git a/imcui/third_party/COTR/COTR/sfm_scenes/sfm_scenes.py b/imcui/third_party/COTR/COTR/sfm_scenes/sfm_scenes.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a3cf0dece9e613073e445eae4e94a14e2145087
--- /dev/null
+++ b/imcui/third_party/COTR/COTR/sfm_scenes/sfm_scenes.py
@@ -0,0 +1,87 @@
+'''
+ Scene reconstructed from SfM (mainly COLMAP)
+'''
+import os
+import copy
+import math
+
+import numpy as np
+from numpy.linalg import inv
+from tqdm import tqdm
+
+from COTR.transformations import transformations
+from COTR.transformations.transform_basics import Translation, Rotation
+from COTR.cameras.camera_pose import CameraPose
+from COTR.utils import debug_utils
+
+
+class SfmScene():
+ def __init__(self, captures, point_cloud=None):
+ self.captures = captures
+ if isinstance(point_cloud, tuple):
+ self.point_cloud = point_cloud[0]
+ self.point_meta = point_cloud[1]
+ else:
+ self.point_cloud = point_cloud
+ self.img_path_to_index_dict = {}
+ self.img_id_to_index_dict = {}
+ self.fname_to_index_dict = {}
+ self._build_img_X_to_index_dict()
+
+ def __str__(self):
+ string = 'Scene contains {0} captures'.format(len(self.captures))
+ return string
+
+ def __getitem__(self, x):
+ if isinstance(x, str):
+ try:
+ return self.captures[self.img_path_to_index_dict[x]]
+ except KeyError:
+ return self.captures[self.fname_to_index_dict[x]]
+ else:
+ return self.captures[x]
+
+ def _build_img_X_to_index_dict(self):
+ assert self.captures is not None, 'There are no captures'
+ for i, cap in enumerate(self.captures):
+ assert cap.img_path not in self.img_path_to_index_dict, 'Image already exists'
+ self.img_path_to_index_dict[cap.img_path] = i
+ assert os.path.basename(cap.img_path) not in self.fname_to_index_dict, 'Image already exists'
+ self.fname_to_index_dict[os.path.basename(cap.img_path)] = i
+ if hasattr(cap, 'image_id'):
+ self.img_id_to_index_dict[cap.image_id] = i
+
+ def get_captures_given_index_list(self, index_list):
+ captures_list = []
+ for i in index_list:
+ captures_list.append(self.captures[i])
+ return captures_list
+
+ def get_covisible_caps(self, cap):
+ assert cap.img_path in self.img_path_to_index_dict
+ covis_img_id = set()
+ point_ids = cap.point3d_id
+ for i in point_ids:
+ covis_img_id = covis_img_id.union(set(self.point_meta[i].image_ids))
+ covis_caps = []
+ for i in covis_img_id:
+ if i in self.img_id_to_index_dict:
+ covis_caps.append(self.captures[self.img_id_to_index_dict[i]])
+ else:
+ pass
+ return covis_caps
+
+ def read_data_to_ram(self, data_list):
+ print('warning: you are going to use a lot of RAM.')
+ sum_bytes = 0.0
+ pbar = tqdm(self.captures, desc='reading data, memory usage {0:.2f} MB'.format(sum_bytes / (1024.0 * 1024.0)))
+ for cap in pbar:
+ if 'image' in data_list:
+ sum_bytes += cap.read_image_to_ram()
+ if 'depth' in data_list:
+ sum_bytes += cap.read_depth_to_ram()
+ if 'pcd' in data_list:
+ sum_bytes += cap.read_pcd_to_ram()
+ pbar.set_description('reading data, memory usage {0:.2f} MB'.format(sum_bytes / (1024.0 * 1024.0)))
+ print('----- total memory usage for images: {0} MB-----'.format(sum_bytes / (1024.0 * 1024.0)))
+
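+
+ if __name__ == '__main__':
+     # Minimal sketch with stand-in capture objects (only img_path is populated);
+     # real captures are constructed elsewhere in the COTR codebase.
+     from types import SimpleNamespace
+     toy_caps = [SimpleNamespace(img_path='/data/scene/images/0000.jpg'),
+                 SimpleNamespace(img_path='/data/scene/images/0001.jpg')]
+     toy_scene = SfmScene(toy_caps)
+     print(toy_scene)                       # Scene contains 2 captures
+     print(toy_scene['0001.jpg'].img_path)  # lookup by file name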
diff --git a/imcui/third_party/COTR/COTR/trainers/base_trainer.py b/imcui/third_party/COTR/COTR/trainers/base_trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..83afea69bb3ec76f1344e62f121efcb8b4c00a8b
--- /dev/null
+++ b/imcui/third_party/COTR/COTR/trainers/base_trainer.py
@@ -0,0 +1,111 @@
+import os
+import math
+import abc
+import time
+
+import tqdm
+import torch.nn as nn
+import tensorboardX
+
+from COTR.trainers import tensorboard_helper
+from COTR.utils import utils
+from COTR.options import options_utils
+
+
+class BaseTrainer(abc.ABC):
+ '''base trainer class.
+ contains methods for training, validation, and writing output.
+ '''
+
+ def __init__(self, opt, model, optimizer, criterion,
+ train_loader, val_loader):
+ self.opt = opt
+ self.use_cuda = opt.use_cuda
+ self.model = model
+ self.optim = optimizer
+ self.criterion = criterion
+ self.train_loader = train_loader
+ self.val_loader = val_loader
+ self.out = opt.out
+ if not os.path.exists(opt.out):
+ os.makedirs(opt.out)
+ self.epoch = 0
+ self.iteration = 0
+ self.max_iter = opt.max_iter
+ self.valid_iter = opt.valid_iter
+ self.tb_pusher = tensorboard_helper.TensorboardPusher(opt)
+ self.push_opt_to_tb()
+ self.need_resume = opt.resume
+ if self.need_resume:
+ self.resume()
+ if self.opt.load_weights:
+ self.load_pretrained_weights()
+
+ def push_opt_to_tb(self):
+ opt_str = options_utils.opt_to_string(self.opt)
+ tb_datapack = tensorboard_helper.TensorboardDatapack()
+ tb_datapack.set_training(False)
+ tb_datapack.set_iteration(self.iteration)
+ tb_datapack.add_text({'options': opt_str})
+ self.tb_pusher.push_to_tensorboard(tb_datapack)
+
+ @abc.abstractmethod
+ def validate_batch(self, data_pack):
+ pass
+
+ @abc.abstractmethod
+ def validate(self):
+ pass
+
+ @abc.abstractmethod
+ def train_batch(self, data_pack):
+ '''train for one batch of data
+ '''
+ pass
+
+ def train_epoch(self):
+ '''train for one epoch
+ one epoch iterates over the whole training dataset once
+ '''
+ self.model.train()
+ for batch_idx, data_pack in tqdm.tqdm(enumerate(self.train_loader),
+ initial=self.iteration % len(
+ self.train_loader),
+ total=len(self.train_loader),
+ desc='Train epoch={0}'.format(
+ self.epoch),
+ ncols=80,
+ leave=True,
+ ):
+
+ # iteration = batch_idx + self.epoch * len(self.train_loader)
+ # if self.iteration != 0 and (iteration - 1) != self.iteration:
+ # continue # for resuming
+ # self.iteration = iteration
+ # self.iteration += 1
+ if self.iteration % self.valid_iter == 0:
+ time.sleep(2) # Prevent possible deadlock during epoch transition
+ self.validate()
+ self.train_batch(data_pack)
+
+ if self.iteration >= self.max_iter:
+ break
+ self.iteration += 1
+
+ def train(self):
+ '''entrance of the whole training process
+ '''
+ max_epoch = int(math.ceil(1. * self.max_iter / len(self.train_loader)))
+ for epoch in tqdm.trange(self.epoch,
+ max_epoch,
+ desc='Train',
+ ncols=80):
+ self.epoch = epoch
+ time.sleep(2) # Prevent possible deadlock during epoch transition
+ self.train_epoch()
+ if self.iteration >= self.max_iter:
+ break
+
+ @abc.abstractmethod
+ def resume(self):
+ pass
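+
+
+ if __name__ == '__main__':
+     # Sketch of the subclassing contract (COTRTrainer in cotr_trainer.py is the real
+     # implementation): a concrete trainer only has to fill in the abstract hooks.
+     class ToyTrainer(BaseTrainer):
+         def validate_batch(self, data_pack):
+             pass
+
+         def validate(self):
+             pass
+
+         def train_batch(self, data_pack):
+             pass
+
+         def resume(self):
+             pass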
diff --git a/imcui/third_party/COTR/COTR/trainers/cotr_trainer.py b/imcui/third_party/COTR/COTR/trainers/cotr_trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..cf87677810b875ef5d29984857e11fc8c3cd7b17
--- /dev/null
+++ b/imcui/third_party/COTR/COTR/trainers/cotr_trainer.py
@@ -0,0 +1,200 @@
+import os
+import math
+import os.path as osp
+import time
+
+import tqdm
+import torch
+import numpy as np
+import torchvision.utils as vutils
+from PIL import Image, ImageDraw
+
+
+from COTR.utils import utils, debug_utils, constants
+from COTR.trainers import base_trainer, tensorboard_helper
+from COTR.projector import pcd_projector
+
+
+class COTRTrainer(base_trainer.BaseTrainer):
+ def __init__(self, opt, model, optimizer, criterion,
+ train_loader, val_loader):
+ super().__init__(opt, model, optimizer, criterion,
+ train_loader, val_loader)
+
+ def validate_batch(self, data_pack):
+ assert self.model.training is False
+ with torch.no_grad():
+ img = data_pack['image'].cuda()
+ query = data_pack['queries'].cuda()
+ target = data_pack['targets'].cuda()
+ self.optim.zero_grad()
+ pred = self.model(img, query)['pred_corrs']
+ loss = torch.nn.functional.mse_loss(pred, target)
+ if self.opt.cycle_consis and self.opt.bidirectional:
+ cycle = self.model(img, pred)['pred_corrs']
+ mask = torch.norm(cycle - query, dim=-1) < 10 / constants.MAX_SIZE
+ if mask.sum() > 0:
+ cycle_loss = torch.nn.functional.mse_loss(cycle[mask], query[mask])
+ loss += cycle_loss
+ elif self.opt.cycle_consis and not self.opt.bidirectional:
+ img_reverse = torch.cat([img[..., constants.MAX_SIZE:], img[..., :constants.MAX_SIZE]], axis=-1)
+ query_reverse = pred.clone()
+ query_reverse[..., 0] = query_reverse[..., 0] - 0.5
+ cycle = self.model(img_reverse, query_reverse)['pred_corrs']
+ cycle[..., 0] = cycle[..., 0] - 0.5
+ mask = torch.norm(cycle - query, dim=-1) < 10 / constants.MAX_SIZE
+ if mask.sum() > 0:
+ cycle_loss = torch.nn.functional.mse_loss(cycle[mask], query[mask])
+ loss += cycle_loss
+ loss_data = loss.data.item()
+ if np.isnan(loss_data):
+ print('loss is nan while validating')
+ return loss_data, pred
+
+ def validate(self):
+ '''validate for whole validation dataset
+ '''
+ training = self.model.training
+ self.model.eval()
+ val_loss_list = []
+ for batch_idx, data_pack in tqdm.tqdm(
+ enumerate(self.val_loader), total=len(self.val_loader),
+ desc='Valid iteration=%d' % self.iteration, ncols=80,
+ leave=False):
+ loss_data, pred = self.validate_batch(data_pack)
+ val_loss_list.append(loss_data)
+ mean_loss = np.array(val_loss_list).mean()
+ validation_data = {'val_loss': mean_loss,
+ 'pred': pred,
+ }
+ self.push_validation_data(data_pack, validation_data)
+ self.save_model()
+ if training:
+ self.model.train()
+
+ def save_model(self):
+ torch.save({
+ 'epoch': self.epoch,
+ 'iteration': self.iteration,
+ 'optim_state_dict': self.optim.state_dict(),
+ 'model_state_dict': self.model.state_dict(),
+ }, osp.join(self.out, 'checkpoint.pth.tar'))
+ if self.iteration % (10 * self.valid_iter) == 0:
+ torch.save({
+ 'epoch': self.epoch,
+ 'iteration': self.iteration,
+ 'optim_state_dict': self.optim.state_dict(),
+ 'model_state_dict': self.model.state_dict(),
+ }, osp.join(self.out, f'{self.iteration}_checkpoint.pth.tar'))
+
+ def draw_corrs(self, imgs, corrs, col=(255, 0, 0)):
+ imgs = utils.torch_img_to_np_img(imgs)
+ out = []
+ for img, corr in zip(imgs, corrs):
+ img = np.interp(img, [img.min(), img.max()], [0, 255]).astype(np.uint8)
+ img = Image.fromarray(img)
+ draw = ImageDraw.Draw(img)
+ corr *= np.array([constants.MAX_SIZE * 2, constants.MAX_SIZE, constants.MAX_SIZE * 2, constants.MAX_SIZE])
+ for c in corr:
+ draw.line(c, fill=col)
+ out.append(np.array(img))
+ out = np.array(out) / 255.0
+ return utils.np_img_to_torch_img(out)
+
+ def push_validation_data(self, data_pack, validation_data):
+ val_loss = validation_data['val_loss']
+ pred_corrs = np.concatenate([data_pack['queries'].numpy(), validation_data['pred'].cpu().numpy()], axis=-1)
+ pred_corrs = self.draw_corrs(data_pack['image'], pred_corrs)
+ gt_corrs = np.concatenate([data_pack['queries'].numpy(), data_pack['targets'].cpu().numpy()], axis=-1)
+ gt_corrs = self.draw_corrs(data_pack['image'], gt_corrs, (0, 255, 0))
+
+ gt_img = vutils.make_grid(gt_corrs, normalize=True, scale_each=True)
+ pred_img = vutils.make_grid(pred_corrs, normalize=True, scale_each=True)
+ tb_datapack = tensorboard_helper.TensorboardDatapack()
+ tb_datapack.set_training(False)
+ tb_datapack.set_iteration(self.iteration)
+ tb_datapack.add_scalar({'loss/val': val_loss})
+ tb_datapack.add_image({'image/gt_corrs': gt_img})
+ tb_datapack.add_image({'image/pred_corrs': pred_img})
+ self.tb_pusher.push_to_tensorboard(tb_datapack)
+
+ def train_batch(self, data_pack):
+ '''train for one batch of data
+ '''
+ img = data_pack['image'].cuda()
+ query = data_pack['queries'].cuda()
+ target = data_pack['targets'].cuda()
+
+ self.optim.zero_grad()
+ pred = self.model(img, query)['pred_corrs']
+ loss = torch.nn.functional.mse_loss(pred, target)
+ if self.opt.cycle_consis and self.opt.bidirectional:
+ cycle = self.model(img, pred)['pred_corrs']
+ mask = torch.norm(cycle - query, dim=-1) < 10 / constants.MAX_SIZE
+ if mask.sum() > 0:
+ cycle_loss = torch.nn.functional.mse_loss(cycle[mask], query[mask])
+ loss += cycle_loss
+ elif self.opt.cycle_consis and not self.opt.bidirectional:
+ img_reverse = torch.cat([img[..., constants.MAX_SIZE:], img[..., :constants.MAX_SIZE]], axis=-1)
+ query_reverse = pred.clone()
+ query_reverse[..., 0] = query_reverse[..., 0] - 0.5
+ cycle = self.model(img_reverse, query_reverse)['pred_corrs']
+ cycle[..., 0] = cycle[..., 0] - 0.5
+ mask = torch.norm(cycle - query, dim=-1) < 10 / constants.MAX_SIZE
+ if mask.sum() > 0:
+ cycle_loss = torch.nn.functional.mse_loss(cycle[mask], query[mask])
+ loss += cycle_loss
+ loss_data = loss.data.item()
+ if np.isnan(loss_data):
+ print('loss is nan during training')
+ self.optim.zero_grad()
+ else:
+ loss.backward()
+ self.push_training_data(data_pack, pred, target, loss)
+ self.optim.step()
+
+ def push_training_data(self, data_pack, pred, target, loss):
+ tb_datapack = tensorboard_helper.TensorboardDatapack()
+ tb_datapack.set_training(True)
+ tb_datapack.set_iteration(self.iteration)
+ tb_datapack.add_histogram({'distribution/pred': pred})
+ tb_datapack.add_histogram({'distribution/target': target})
+ tb_datapack.add_scalar({'loss/train': loss})
+ self.tb_pusher.push_to_tensorboard(tb_datapack)
+
+ def resume(self):
+ '''resume training:
+ restore the recorded epoch, iteration, optimizer state, and saved weights
+ from the checkpoint stored under the same output directory (opt.out).
+ '''
+ if hasattr(self.opt, 'load_weights'):
+ assert self.opt.load_weights is None or self.opt.load_weights == False
+ # 1. load check point
+ checkpoint_path = os.path.join(self.opt.out, 'checkpoint.pth.tar')
+ if os.path.isfile(checkpoint_path):
+ checkpoint = torch.load(checkpoint_path)
+ else:
+ raise FileNotFoundError(
+ 'model checkpoint cannot be found: {0}'.format(checkpoint_path))
+ # 2. load data
+ self.epoch = checkpoint['epoch']
+ self.iteration = checkpoint['iteration']
+ self.load_pretrained_weights()
+ self.optim.load_state_dict(checkpoint['optim_state_dict'])
+
+ def load_pretrained_weights(self):
+ '''
+ load pretrained weights from another model
+ '''
+ # if hasattr(self.opt, 'resume'):
+ # assert self.opt.resume is False
+ assert os.path.isfile(self.opt.load_weights_path), self.opt.load_weights_path
+
+ saved_weights = torch.load(self.opt.load_weights_path)['model_state_dict']
+ utils.safe_load_weights(self.model, saved_weights)
+ content_list = []
+ content_list += [f'Loaded pretrained weights from {self.opt.load_weights_path}']
+ utils.print_notification(content_list)
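+
+
+ if __name__ == '__main__':
+     # Toy illustration of the cycle-consistency mask used in train_batch/validate_batch,
+     # on random tensors rather than real model predictions.
+     toy_query = torch.rand(2, 100, 2)
+     toy_cycle = toy_query + 1e-3 * torch.randn_like(toy_query)   # pretend round-trip output
+     toy_mask = torch.norm(toy_cycle - toy_query, dim=-1) < 10 / constants.MAX_SIZE
+     if toy_mask.sum() > 0:
+         print(torch.nn.functional.mse_loss(toy_cycle[toy_mask], toy_query[toy_mask]).item())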
diff --git a/imcui/third_party/COTR/COTR/trainers/tensorboard_helper.py b/imcui/third_party/COTR/COTR/trainers/tensorboard_helper.py
new file mode 100644
index 0000000000000000000000000000000000000000..4550fe05c300b1b6c2729c8677e7f402c69fb8e7
--- /dev/null
+++ b/imcui/third_party/COTR/COTR/trainers/tensorboard_helper.py
@@ -0,0 +1,97 @@
+import abc
+
+import tensorboardX
+
+
+class TensorboardDatapack():
+ '''data dictionary for pushing to tensorboard
+ '''
+
+ def __init__(self):
+ self.SCALAR_NAME = 'scalar'
+ self.HISTOGRAM_NAME = 'histogram'
+ self.IMAGE_NAME = 'image'
+ self.TEXT_NAME = 'text'
+ self.datapack = {}
+ self.datapack[self.SCALAR_NAME] = {}
+ self.datapack[self.HISTOGRAM_NAME] = {}
+ self.datapack[self.IMAGE_NAME] = {}
+ self.datapack[self.TEXT_NAME] = {}
+
+ def set_training(self, training):
+ self.training = training
+
+ def set_iteration(self, iteration):
+ self.iteration = iteration
+
+ def add_scalar(self, scalar_dict):
+ self.datapack[self.SCALAR_NAME].update(scalar_dict)
+
+ def add_histogram(self, histogram_dict):
+ self.datapack[self.HISTOGRAM_NAME].update(histogram_dict)
+
+ def add_image(self, image_dict):
+ self.datapack[self.IMAGE_NAME].update(image_dict)
+
+ def add_text(self, text_dict):
+ self.datapack[self.TEXT_NAME].update(text_dict)
+
+
+class TensorboardHelperBase(abc.ABC):
+ '''abstract base class for tb helpers
+ '''
+
+ def __init__(self, tb_writer):
+ self.tb_writer = tb_writer
+
+ @abc.abstractmethod
+ def add_data(self, tb_datapack):
+ pass
+
+
+class TensorboardScalarHelper(TensorboardHelperBase):
+ def add_data(self, tb_datapack):
+ scalar_dict = tb_datapack.datapack[tb_datapack.SCALAR_NAME]
+ for key, val in scalar_dict.items():
+ self.tb_writer.add_scalar(
+ key, val, global_step=tb_datapack.iteration)
+
+
+class TensorboardHistogramHelper(TensorboardHelperBase):
+ def add_data(self, tb_datapack):
+ histogram_dict = tb_datapack.datapack[tb_datapack.HISTOGRAM_NAME]
+ for key, val in histogram_dict.items():
+ self.tb_writer.add_histogram(
+ key, val, global_step=tb_datapack.iteration)
+
+
+class TensorboardImageHelper(TensorboardHelperBase):
+ def add_data(self, tb_datapack):
+ image_dict = tb_datapack.datapack[tb_datapack.IMAGE_NAME]
+ for key, val in image_dict.items():
+ self.tb_writer.add_image(
+ key, val, global_step=tb_datapack.iteration)
+
+
+class TensorboardTextHelper(TensorboardHelperBase):
+ def add_data(self, tb_datapack):
+ text_dict = tb_datapack.datapack[tb_datapack.TEXT_NAME]
+ for key, val in text_dict.items():
+ self.tb_writer.add_text(
+ key, val, global_step=tb_datapack.iteration)
+
+
+class TensorboardPusher():
+ def __init__(self, opt):
+ self.tb_writer = tensorboardX.SummaryWriter(opt.tb_out)
+ scalar_helper = TensorboardScalarHelper(self.tb_writer)
+ histogram_helper = TensorboardHistogramHelper(self.tb_writer)
+ image_helper = TensorboardImageHelper(self.tb_writer)
+ text_helper = TensorboardTextHelper(self.tb_writer)
+ self.helper_list = [scalar_helper,
+ histogram_helper, image_helper, text_helper]
+
+ def push_to_tensorboard(self, tb_datapack):
+ for helper in self.helper_list:
+ helper.add_data(tb_datapack)
+ self.tb_writer.flush()
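+
+
+ if __name__ == '__main__':
+     # Minimal usage sketch (temporary log directory, single scalar); in practice this is
+     # driven by BaseTrainer.push_opt_to_tb and the COTRTrainer push_* methods.
+     import tempfile
+     from types import SimpleNamespace
+     pusher = TensorboardPusher(SimpleNamespace(tb_out=tempfile.mkdtemp()))
+     pack = TensorboardDatapack()
+     pack.set_training(True)
+     pack.set_iteration(0)
+     pack.add_scalar({'loss/train': 0.5})
+     pusher.push_to_tensorboard(pack)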
diff --git a/imcui/third_party/COTR/COTR/transformations/transform_basics.py b/imcui/third_party/COTR/COTR/transformations/transform_basics.py
new file mode 100644
index 0000000000000000000000000000000000000000..26cdb8068cfb857fba6c680724013c0d4b4721da
--- /dev/null
+++ b/imcui/third_party/COTR/COTR/transformations/transform_basics.py
@@ -0,0 +1,114 @@
+import numpy as np
+
+from COTR.transformations import transformations
+from COTR.utils import constants
+
+
+class Rotation():
+ def __init__(self, quat):
+ """
+ quaternion format (w, x, y, z)
+ """
+ assert quat.dtype == np.float32
+ self.quaternion = quat
+
+ def __str__(self):
+ string = '{0}'.format(self.quaternion)
+ return string
+
+ @classmethod
+ def from_matrix(cls, mat):
+ assert isinstance(mat, np.ndarray)
+ if mat.shape == (3, 3):
+ id_mat = np.eye(4)
+ id_mat[0:3, 0:3] = mat
+ mat = id_mat
+ assert mat.shape == (4, 4)
+ quat = transformations.quaternion_from_matrix(mat).astype(constants.DEFAULT_PRECISION)
+ return cls(quat)
+
+ @property
+ def rotation_matrix(self):
+ return transformations.quaternion_matrix(self.quaternion).astype(constants.DEFAULT_PRECISION)
+
+ @rotation_matrix.setter
+ def rotation_matrix(self, mat):
+ assert isinstance(mat, np.ndarray)
+ assert mat.shape == (4, 4)
+ quat = transformations.quaternion_from_matrix(mat)
+ self.quaternion = quat
+
+ @property
+ def quaternion(self):
+ assert isinstance(self._quaternion, np.ndarray)
+ assert self._quaternion.shape == (4,)
+ assert np.isclose(np.linalg.norm(self._quaternion), 1.0), 'self._quaternion is not normalized or valid'
+ return self._quaternion
+
+ @quaternion.setter
+ def quaternion(self, quat):
+ assert isinstance(quat, np.ndarray)
+ assert quat.shape == (4,)
+ if not np.isclose(np.linalg.norm(quat), 1.0):
+ print(f'WARNING: normalizing the input quaternion to a unit quaternion: {np.linalg.norm(quat)}')
+ quat = quat / np.linalg.norm(quat)
+ assert np.isclose(np.linalg.norm(quat), 1.0), f'input quaternion is not normalized or valid: {quat}'
+ self._quaternion = quat
+
+
+class UnstableRotation():
+ def __init__(self, mat):
+ assert isinstance(mat, np.ndarray)
+ if mat.shape == (3, 3):
+ id_mat = np.eye(4)
+ id_mat[0:3, 0:3] = mat
+ mat = id_mat
+ assert mat.shape == (4, 4)
+ mat[:3, 3] = 0
+ self._rotation_matrix = mat
+
+ def __str__(self):
+ string = f'rotation_matrix: {self.rotation_matrix}'
+ return string
+
+ @property
+ def rotation_matrix(self):
+ return self._rotation_matrix
+
+
+class Translation():
+ def __init__(self, vec):
+ assert vec.dtype == np.float32
+ self.translation_vector = vec
+
+ def __str__(self):
+ string = '{0}'.format(self.translation_vector)
+ return string
+
+ @classmethod
+ def from_matrix(cls, mat):
+ assert isinstance(mat, np.ndarray)
+ assert mat.shape == (4, 4)
+ vec = transformations.translation_from_matrix(mat)
+ return cls(vec)
+
+ @property
+ def translation_matrix(self):
+ return transformations.translation_matrix(self.translation_vector).astype(constants.DEFAULT_PRECISION)
+
+ @translation_matrix.setter
+ def translation_matrix(self, mat):
+ assert isinstance(mat, np.ndarray)
+ assert mat.shape == (4, 4)
+ vec = transformations.translation_from_matrix(mat)
+ self.translation_vector = vec
+
+ @property
+ def translation_vector(self):
+ return self._translation_vector
+
+ @translation_vector.setter
+ def translation_vector(self, vec):
+ assert isinstance(vec, np.ndarray)
+ assert vec.shape == (3,)
+ self._translation_vector = vec
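+
+
+ if __name__ == '__main__':
+     # Minimal sketch (assumes constants.DEFAULT_PRECISION is np.float32, as the dtype
+     # asserts above require): round-trip an identity rotation and a zero translation.
+     toy_rot = Rotation.from_matrix(np.eye(3, dtype=np.float32))
+     print(toy_rot.quaternion)            # identity quaternion [1. 0. 0. 0.]
+     toy_trans = Translation(np.zeros(3, dtype=np.float32))
+     print(toy_trans.translation_matrix)  # 4x4 identity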
diff --git a/imcui/third_party/COTR/COTR/transformations/transformations.py b/imcui/third_party/COTR/COTR/transformations/transformations.py
new file mode 100644
index 0000000000000000000000000000000000000000..809ce6683c27e641de5b845ac3b5c53d6a3167f6
--- /dev/null
+++ b/imcui/third_party/COTR/COTR/transformations/transformations.py
@@ -0,0 +1,1951 @@
+# -*- coding: utf-8 -*-
+# transformations.py
+
+# Copyright (c) 2006-2019, Christoph Gohlke
+# Copyright (c) 2006-2019, The Regents of the University of California
+# Produced at the Laboratory for Fluorescence Dynamics
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright notice,
+# this list of conditions and the following disclaimer.
+#
+# * Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# * Neither the name of the copyright holder nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
+# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+# POSSIBILITY OF SUCH DAMAGE.
+
+"""Homogeneous Transformation Matrices and Quaternions.
+
+Transformations is a Python library for calculating 4x4 matrices for
+translating, rotating, reflecting, scaling, shearing, projecting,
+orthogonalizing, and superimposing arrays of 3D homogeneous coordinates
+as well as for converting between rotation matrices, Euler angles,
+and quaternions. Also includes an Arcball control object and
+functions to decompose transformation matrices.
+
+:Author:
+ `Christoph Gohlke `_
+
+:Organization:
+ Laboratory for Fluorescence Dynamics. University of California, Irvine
+
+:License: 3-clause BSD
+
+:Version: 2019.2.20
+
+Requirements
+------------
+* `CPython 2.7 or 3.5+ `_
+* `Numpy 1.14 `_
+* A Python distutils compatible C compiler (build)
+
+Revisions
+---------
+2019.1.1
+ Update copyright year.
+
+Notes
+-----
+Transformations.py is no longer actively developed and has a few known issues
+and numerical instabilities. The module is mostly superseded by other modules
+for 3D transformations and quaternions:
+
+* `Scipy.spatial.transform `_
+* `Transforms3d `_
+ (includes most code of this module)
+* `Numpy-quaternion `_
+* `Blender.mathutils `_
+
+The API is not stable yet and is expected to change between revisions.
+
+Python 2.7 and 3.4 are deprecated.
+
+This Python code is not optimized for speed. Refer to the transformations.c
+module for a faster implementation of some functions.
+
+Documentation in HTML format can be generated with epydoc.
+
+Matrices (M) can be inverted using numpy.linalg.inv(M), be concatenated using
+numpy.dot(M0, M1), or transform homogeneous coordinate arrays (v) using
+numpy.dot(M, v) for shape (4, \*) column vectors, respectively
+numpy.dot(v, M.T) for shape (\*, 4) row vectors ("array of points").
+
+This module follows the "column vectors on the right" and "row major storage"
+(C contiguous) conventions. The translation components are in the right column
+of the transformation matrix, i.e. M[:3, 3].
+The transpose of the transformation matrices may have to be used to interface
+with other graphics systems, e.g. OpenGL's glMultMatrixd(). See also [16].
+
+Calculations are carried out with numpy.float64 precision.
+
+Vector, point, quaternion, and matrix function arguments are expected to be
+"array like", i.e. tuple, list, or numpy arrays.
+
+Return types are numpy arrays unless specified otherwise.
+
+Angles are in radians unless specified otherwise.
+
+Quaternions w+ix+jy+kz are represented as [w, x, y, z].
+
+A triple of Euler angles can be applied/interpreted in 24 ways, which can
+be specified using a 4 character string or encoded 4-tuple:
+
+ *Axes 4-string*: e.g. 'sxyz' or 'ryxy'
+
+ - first character : rotations are applied to 's'tatic or 'r'otating frame
+ - remaining characters : successive rotation axis 'x', 'y', or 'z'
+
+ *Axes 4-tuple*: e.g. (0, 0, 0, 0) or (1, 1, 1, 1)
+
+ - inner axis: code of axis ('x':0, 'y':1, 'z':2) of rightmost matrix.
+ - parity : even (0) if inner axis 'x' is followed by 'y', 'y' is followed
+ by 'z', or 'z' is followed by 'x'. Otherwise odd (1).
+ - repetition : first and last axis are same (1) or different (0).
+ - frame : rotations are applied to static (0) or rotating (1) frame.
+
+References
+----------
+(1) Matrices and transformations. Ronald Goldman.
+ In "Graphics Gems I", pp 472-475. Morgan Kaufmann, 1990.
+(2) More matrices and transformations: shear and pseudo-perspective.
+ Ronald Goldman. In "Graphics Gems II", pp 320-323. Morgan Kaufmann, 1991.
+(3) Decomposing a matrix into simple transformations. Spencer Thomas.
+ In "Graphics Gems II", pp 320-323. Morgan Kaufmann, 1991.
+(4) Recovering the data from the transformation matrix. Ronald Goldman.
+ In "Graphics Gems II", pp 324-331. Morgan Kaufmann, 1991.
+(5) Euler angle conversion. Ken Shoemake.
+ In "Graphics Gems IV", pp 222-229. Morgan Kaufmann, 1994.
+(6) Arcball rotation control. Ken Shoemake.
+ In "Graphics Gems IV", pp 175-192. Morgan Kaufmann, 1994.
+(7) Representing attitude: Euler angles, unit quaternions, and rotation
+ vectors. James Diebel. 2006.
+(8) A discussion of the solution for the best rotation to relate two sets
+ of vectors. W Kabsch. Acta Cryst. 1978. A34, 827-828.
+(9) Closed-form solution of absolute orientation using unit quaternions.
+ BKP Horn. J Opt Soc Am A. 1987. 4(4):629-642.
+(10) Quaternions. Ken Shoemake.
+ http://www.sfu.ca/~jwa3/cmpt461/files/quatut.pdf
+(11) From quaternion to matrix and back. JMP van Waveren. 2005.
+ http://www.intel.com/cd/ids/developer/asmo-na/eng/293748.htm
+(12) Uniform random rotations. Ken Shoemake.
+ In "Graphics Gems III", pp 124-132. Morgan Kaufmann, 1992.
+(13) Quaternion in molecular modeling. CFF Karney.
+ J Mol Graph Mod, 25(5):595-604
+(14) New method for extracting the quaternion from a rotation matrix.
+ Itzhack Y Bar-Itzhack, J Guid Contr Dynam. 2000. 23(6): 1085-1087.
+(15) Multiple View Geometry in Computer Vision. Hartley and Zissermann.
+ Cambridge University Press; 2nd Ed. 2004. Chapter 4, Algorithm 4.7, p 130.
+(16) Column Vectors vs. Row Vectors.
+ http://steve.hollasch.net/cgindex/math/matrix/column-vec.html
+
+Examples
+--------
+>>> alpha, beta, gamma = 0.123, -1.234, 2.345
+>>> origin, xaxis, yaxis, zaxis = [0, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1]
+>>> I = identity_matrix()
+>>> Rx = rotation_matrix(alpha, xaxis)
+>>> Ry = rotation_matrix(beta, yaxis)
+>>> Rz = rotation_matrix(gamma, zaxis)
+>>> R = concatenate_matrices(Rx, Ry, Rz)
+>>> euler = euler_from_matrix(R, 'rxyz')
+>>> numpy.allclose([alpha, beta, gamma], euler)
+True
+>>> Re = euler_matrix(alpha, beta, gamma, 'rxyz')
+>>> is_same_transform(R, Re)
+True
+>>> al, be, ga = euler_from_matrix(Re, 'rxyz')
+>>> is_same_transform(Re, euler_matrix(al, be, ga, 'rxyz'))
+True
+>>> qx = quaternion_about_axis(alpha, xaxis)
+>>> qy = quaternion_about_axis(beta, yaxis)
+>>> qz = quaternion_about_axis(gamma, zaxis)
+>>> q = quaternion_multiply(qx, qy)
+>>> q = quaternion_multiply(q, qz)
+>>> Rq = quaternion_matrix(q)
+>>> is_same_transform(R, Rq)
+True
+>>> S = scale_matrix(1.23, origin)
+>>> T = translation_matrix([1, 2, 3])
+>>> Z = shear_matrix(beta, xaxis, origin, zaxis)
+>>> R = random_rotation_matrix(numpy.random.rand(3))
+>>> M = concatenate_matrices(T, R, Z, S)
+>>> scale, shear, angles, trans, persp = decompose_matrix(M)
+>>> numpy.allclose(scale, 1.23)
+True
+>>> numpy.allclose(trans, [1, 2, 3])
+True
+>>> numpy.allclose(shear, [0, math.tan(beta), 0])
+True
+>>> is_same_transform(R, euler_matrix(axes='sxyz', *angles))
+True
+>>> M1 = compose_matrix(scale, shear, angles, trans, persp)
+>>> is_same_transform(M, M1)
+True
+>>> v0, v1 = random_vector(3), random_vector(3)
+>>> M = rotation_matrix(angle_between_vectors(v0, v1), vector_product(v0, v1))
+>>> v2 = numpy.dot(v0, M[:3,:3].T)
+>>> numpy.allclose(unit_vector(v1), unit_vector(v2))
+True
+
+"""
+
+from __future__ import division, print_function
+
+__version__ = '2019.2.20'
+__docformat__ = 'restructuredtext en'
+
+import math
+
+import numpy
+
+
+def identity_matrix():
+ """Return 4x4 identity/unit matrix.
+
+ >>> I = identity_matrix()
+ >>> numpy.allclose(I, numpy.dot(I, I))
+ True
+ >>> numpy.sum(I), numpy.trace(I)
+ (4.0, 4.0)
+ >>> numpy.allclose(I, numpy.identity(4))
+ True
+
+ """
+ return numpy.identity(4)
+
+
+def translation_matrix(direction):
+ """Return matrix to translate by direction vector.
+
+ >>> v = numpy.random.random(3) - 0.5
+ >>> numpy.allclose(v, translation_matrix(v)[:3, 3])
+ True
+
+ """
+ M = numpy.identity(4)
+ M[:3, 3] = direction[:3]
+ return M
+
+
+def translation_from_matrix(matrix):
+ """Return translation vector from translation matrix.
+
+ >>> v0 = numpy.random.random(3) - 0.5
+ >>> v1 = translation_from_matrix(translation_matrix(v0))
+ >>> numpy.allclose(v0, v1)
+ True
+
+ """
+ return numpy.array(matrix, copy=False)[:3, 3].copy()
+
+
+def reflection_matrix(point, normal):
+ """Return matrix to mirror at plane defined by point and normal vector.
+
+ >>> v0 = numpy.random.random(4) - 0.5
+ >>> v0[3] = 1.
+ >>> v1 = numpy.random.random(3) - 0.5
+ >>> R = reflection_matrix(v0, v1)
+ >>> numpy.allclose(2, numpy.trace(R))
+ True
+ >>> numpy.allclose(v0, numpy.dot(R, v0))
+ True
+ >>> v2 = v0.copy()
+ >>> v2[:3] += v1
+ >>> v3 = v0.copy()
+ >>> v2[:3] -= v1
+ >>> numpy.allclose(v2, numpy.dot(R, v3))
+ True
+
+ """
+ normal = unit_vector(normal[:3])
+ M = numpy.identity(4)
+ M[:3, :3] -= 2.0 * numpy.outer(normal, normal)
+ M[:3, 3] = (2.0 * numpy.dot(point[:3], normal)) * normal
+ return M
+
+
+def reflection_from_matrix(matrix):
+ """Return mirror plane point and normal vector from reflection matrix.
+
+ >>> v0 = numpy.random.random(3) - 0.5
+ >>> v1 = numpy.random.random(3) - 0.5
+ >>> M0 = reflection_matrix(v0, v1)
+ >>> point, normal = reflection_from_matrix(M0)
+ >>> M1 = reflection_matrix(point, normal)
+ >>> is_same_transform(M0, M1)
+ True
+
+ """
+ M = numpy.array(matrix, dtype=numpy.float64, copy=False)
+ # normal: unit eigenvector corresponding to eigenvalue -1
+ w, V = numpy.linalg.eig(M[:3, :3])
+ i = numpy.where(abs(numpy.real(w) + 1.0) < 1e-8)[0]
+ if not len(i):
+ raise ValueError('no unit eigenvector corresponding to eigenvalue -1')
+ normal = numpy.real(V[:, i[0]]).squeeze()
+ # point: any unit eigenvector corresponding to eigenvalue 1
+ w, V = numpy.linalg.eig(M)
+ i = numpy.where(abs(numpy.real(w) - 1.0) < 1e-8)[0]
+ if not len(i):
+ raise ValueError('no unit eigenvector corresponding to eigenvalue 1')
+ point = numpy.real(V[:, i[-1]]).squeeze()
+ point /= point[3]
+ return point, normal
+
+
+def rotation_matrix(angle, direction, point=None):
+ """Return matrix to rotate about axis defined by point and direction.
+
+ >>> R = rotation_matrix(math.pi/2, [0, 0, 1], [1, 0, 0])
+ >>> numpy.allclose(numpy.dot(R, [0, 0, 0, 1]), [1, -1, 0, 1])
+ True
+ >>> angle = (random.random() - 0.5) * (2*math.pi)
+ >>> direc = numpy.random.random(3) - 0.5
+ >>> point = numpy.random.random(3) - 0.5
+ >>> R0 = rotation_matrix(angle, direc, point)
+ >>> R1 = rotation_matrix(angle-2*math.pi, direc, point)
+ >>> is_same_transform(R0, R1)
+ True
+ >>> R0 = rotation_matrix(angle, direc, point)
+ >>> R1 = rotation_matrix(-angle, -direc, point)
+ >>> is_same_transform(R0, R1)
+ True
+ >>> I = numpy.identity(4, numpy.float64)
+ >>> numpy.allclose(I, rotation_matrix(math.pi*2, direc))
+ True
+ >>> numpy.allclose(2, numpy.trace(rotation_matrix(math.pi/2,
+ ... direc, point)))
+ True
+
+ """
+ sina = math.sin(angle)
+ cosa = math.cos(angle)
+ direction = unit_vector(direction[:3])
+ # rotation matrix around unit vector
+ R = numpy.diag([cosa, cosa, cosa])
+ R += numpy.outer(direction, direction) * (1.0 - cosa)
+ direction *= sina
+ R += numpy.array([[0.0, -direction[2], direction[1]],
+ [direction[2], 0.0, -direction[0]],
+ [-direction[1], direction[0], 0.0]])
+ M = numpy.identity(4)
+ M[:3, :3] = R
+ if point is not None:
+ # rotation not around origin
+ point = numpy.array(point[:3], dtype=numpy.float64, copy=False)
+ M[:3, 3] = point - numpy.dot(R, point)
+ return M
+
+
+def rotation_from_matrix(matrix):
+ """Return rotation angle and axis from rotation matrix.
+
+ >>> angle = (random.random() - 0.5) * (2*math.pi)
+ >>> direc = numpy.random.random(3) - 0.5
+ >>> point = numpy.random.random(3) - 0.5
+ >>> R0 = rotation_matrix(angle, direc, point)
+ >>> angle, direc, point = rotation_from_matrix(R0)
+ >>> R1 = rotation_matrix(angle, direc, point)
+ >>> is_same_transform(R0, R1)
+ True
+
+ """
+ R = numpy.array(matrix, dtype=numpy.float64, copy=False)
+ R33 = R[:3, :3]
+ # direction: unit eigenvector of R33 corresponding to eigenvalue of 1
+ w, W = numpy.linalg.eig(R33.T)
+ i = numpy.where(abs(numpy.real(w) - 1.0) < 1e-8)[0]
+ if not len(i):
+ raise ValueError('no unit eigenvector corresponding to eigenvalue 1')
+ direction = numpy.real(W[:, i[-1]]).squeeze()
+ # point: unit eigenvector of R33 corresponding to eigenvalue of 1
+ w, Q = numpy.linalg.eig(R)
+ i = numpy.where(abs(numpy.real(w) - 1.0) < 1e-8)[0]
+ if not len(i):
+ raise ValueError('no unit eigenvector corresponding to eigenvalue 1')
+ point = numpy.real(Q[:, i[-1]]).squeeze()
+ point /= point[3]
+ # rotation angle depending on direction
+ cosa = (numpy.trace(R33) - 1.0) / 2.0
+ if abs(direction[2]) > 1e-8:
+ sina = (R[1, 0] + (cosa-1.0)*direction[0]*direction[1]) / direction[2]
+ elif abs(direction[1]) > 1e-8:
+ sina = (R[0, 2] + (cosa-1.0)*direction[0]*direction[2]) / direction[1]
+ else:
+ sina = (R[2, 1] + (cosa-1.0)*direction[1]*direction[2]) / direction[0]
+ angle = math.atan2(sina, cosa)
+ return angle, direction, point
+
+
+def scale_matrix(factor, origin=None, direction=None):
+ """Return matrix to scale by factor around origin in direction.
+
+ Use factor -1 for point symmetry.
+
+ >>> v = (numpy.random.rand(4, 5) - 0.5) * 20
+ >>> v[3] = 1
+ >>> S = scale_matrix(-1.234)
+ >>> numpy.allclose(numpy.dot(S, v)[:3], -1.234*v[:3])
+ True
+ >>> factor = random.random() * 10 - 5
+ >>> origin = numpy.random.random(3) - 0.5
+ >>> direct = numpy.random.random(3) - 0.5
+ >>> S = scale_matrix(factor, origin)
+ >>> S = scale_matrix(factor, origin, direct)
+
+ """
+ if direction is None:
+ # uniform scaling
+ M = numpy.diag([factor, factor, factor, 1.0])
+ if origin is not None:
+ M[:3, 3] = origin[:3]
+ M[:3, 3] *= 1.0 - factor
+ else:
+ # nonuniform scaling
+ direction = unit_vector(direction[:3])
+ factor = 1.0 - factor
+ M = numpy.identity(4)
+ M[:3, :3] -= factor * numpy.outer(direction, direction)
+ if origin is not None:
+ M[:3, 3] = (factor * numpy.dot(origin[:3], direction)) * direction
+ return M
+
+
+def scale_from_matrix(matrix):
+ """Return scaling factor, origin and direction from scaling matrix.
+
+ >>> factor = random.random() * 10 - 5
+ >>> origin = numpy.random.random(3) - 0.5
+ >>> direct = numpy.random.random(3) - 0.5
+ >>> S0 = scale_matrix(factor, origin)
+ >>> factor, origin, direction = scale_from_matrix(S0)
+ >>> S1 = scale_matrix(factor, origin, direction)
+ >>> is_same_transform(S0, S1)
+ True
+ >>> S0 = scale_matrix(factor, origin, direct)
+ >>> factor, origin, direction = scale_from_matrix(S0)
+ >>> S1 = scale_matrix(factor, origin, direction)
+ >>> is_same_transform(S0, S1)
+ True
+
+ """
+ M = numpy.array(matrix, dtype=numpy.float64, copy=False)
+ M33 = M[:3, :3]
+ factor = numpy.trace(M33) - 2.0
+ try:
+ # direction: unit eigenvector corresponding to eigenvalue factor
+ w, V = numpy.linalg.eig(M33)
+ i = numpy.where(abs(numpy.real(w) - factor) < 1e-8)[0][0]
+ direction = numpy.real(V[:, i]).squeeze()
+ direction /= vector_norm(direction)
+ except IndexError:
+ # uniform scaling
+ factor = (factor + 2.0) / 3.0
+ direction = None
+ # origin: any eigenvector corresponding to eigenvalue 1
+ w, V = numpy.linalg.eig(M)
+ i = numpy.where(abs(numpy.real(w) - 1.0) < 1e-8)[0]
+ if not len(i):
+ raise ValueError('no eigenvector corresponding to eigenvalue 1')
+ origin = numpy.real(V[:, i[-1]]).squeeze()
+ origin /= origin[3]
+ return factor, origin, direction
+
+
+def projection_matrix(point, normal, direction=None,
+ perspective=None, pseudo=False):
+ """Return matrix to project onto plane defined by point and normal.
+
+ Using either perspective point, projection direction, or none of both.
+
+ If pseudo is True, perspective projections will preserve relative depth
+ such that Perspective = dot(Orthogonal, PseudoPerspective).
+
+ >>> P = projection_matrix([0, 0, 0], [1, 0, 0])
+ >>> numpy.allclose(P[1:, 1:], numpy.identity(4)[1:, 1:])
+ True
+ >>> point = numpy.random.random(3) - 0.5
+ >>> normal = numpy.random.random(3) - 0.5
+ >>> direct = numpy.random.random(3) - 0.5
+ >>> persp = numpy.random.random(3) - 0.5
+ >>> P0 = projection_matrix(point, normal)
+ >>> P1 = projection_matrix(point, normal, direction=direct)
+ >>> P2 = projection_matrix(point, normal, perspective=persp)
+ >>> P3 = projection_matrix(point, normal, perspective=persp, pseudo=True)
+ >>> is_same_transform(P2, numpy.dot(P0, P3))
+ True
+ >>> P = projection_matrix([3, 0, 0], [1, 1, 0], [1, 0, 0])
+ >>> v0 = (numpy.random.rand(4, 5) - 0.5) * 20
+ >>> v0[3] = 1
+ >>> v1 = numpy.dot(P, v0)
+ >>> numpy.allclose(v1[1], v0[1])
+ True
+ >>> numpy.allclose(v1[0], 3-v1[1])
+ True
+
+ """
+ M = numpy.identity(4)
+ point = numpy.array(point[:3], dtype=numpy.float64, copy=False)
+ normal = unit_vector(normal[:3])
+ if perspective is not None:
+ # perspective projection
+ perspective = numpy.array(perspective[:3], dtype=numpy.float64,
+ copy=False)
+ M[0, 0] = M[1, 1] = M[2, 2] = numpy.dot(perspective-point, normal)
+ M[:3, :3] -= numpy.outer(perspective, normal)
+ if pseudo:
+ # preserve relative depth
+ M[:3, :3] -= numpy.outer(normal, normal)
+ M[:3, 3] = numpy.dot(point, normal) * (perspective+normal)
+ else:
+ M[:3, 3] = numpy.dot(point, normal) * perspective
+ M[3, :3] = -normal
+ M[3, 3] = numpy.dot(perspective, normal)
+ elif direction is not None:
+ # parallel projection
+ direction = numpy.array(direction[:3], dtype=numpy.float64, copy=False)
+ scale = numpy.dot(direction, normal)
+ M[:3, :3] -= numpy.outer(direction, normal) / scale
+ M[:3, 3] = direction * (numpy.dot(point, normal) / scale)
+ else:
+ # orthogonal projection
+ M[:3, :3] -= numpy.outer(normal, normal)
+ M[:3, 3] = numpy.dot(point, normal) * normal
+ return M
+
+
+def projection_from_matrix(matrix, pseudo=False):
+ """Return projection plane and perspective point from projection matrix.
+
+ Return values are same as arguments for projection_matrix function:
+ point, normal, direction, perspective, and pseudo.
+
+ >>> point = numpy.random.random(3) - 0.5
+ >>> normal = numpy.random.random(3) - 0.5
+ >>> direct = numpy.random.random(3) - 0.5
+ >>> persp = numpy.random.random(3) - 0.5
+ >>> P0 = projection_matrix(point, normal)
+ >>> result = projection_from_matrix(P0)
+ >>> P1 = projection_matrix(*result)
+ >>> is_same_transform(P0, P1)
+ True
+ >>> P0 = projection_matrix(point, normal, direct)
+ >>> result = projection_from_matrix(P0)
+ >>> P1 = projection_matrix(*result)
+ >>> is_same_transform(P0, P1)
+ True
+ >>> P0 = projection_matrix(point, normal, perspective=persp, pseudo=False)
+ >>> result = projection_from_matrix(P0, pseudo=False)
+ >>> P1 = projection_matrix(*result)
+ >>> is_same_transform(P0, P1)
+ True
+ >>> P0 = projection_matrix(point, normal, perspective=persp, pseudo=True)
+ >>> result = projection_from_matrix(P0, pseudo=True)
+ >>> P1 = projection_matrix(*result)
+ >>> is_same_transform(P0, P1)
+ True
+
+ """
+ M = numpy.array(matrix, dtype=numpy.float64, copy=False)
+ M33 = M[:3, :3]
+ w, V = numpy.linalg.eig(M)
+ i = numpy.where(abs(numpy.real(w) - 1.0) < 1e-8)[0]
+ if not pseudo and len(i):
+ # point: any eigenvector corresponding to eigenvalue 1
+ point = numpy.real(V[:, i[-1]]).squeeze()
+ point /= point[3]
+ # direction: unit eigenvector corresponding to eigenvalue 0
+ w, V = numpy.linalg.eig(M33)
+ i = numpy.where(abs(numpy.real(w)) < 1e-8)[0]
+ if not len(i):
+ raise ValueError('no eigenvector corresponding to eigenvalue 0')
+ direction = numpy.real(V[:, i[0]]).squeeze()
+ direction /= vector_norm(direction)
+ # normal: unit eigenvector of M33.T corresponding to eigenvalue 0
+ w, V = numpy.linalg.eig(M33.T)
+ i = numpy.where(abs(numpy.real(w)) < 1e-8)[0]
+ if len(i):
+ # parallel projection
+ normal = numpy.real(V[:, i[0]]).squeeze()
+ normal /= vector_norm(normal)
+ return point, normal, direction, None, False
+ else:
+ # orthogonal projection, where normal equals direction vector
+ return point, direction, None, None, False
+ else:
+ # perspective projection
+ i = numpy.where(abs(numpy.real(w)) > 1e-8)[0]
+ if not len(i):
+ raise ValueError(
+ 'no eigenvector not corresponding to eigenvalue 0')
+ point = numpy.real(V[:, i[-1]]).squeeze()
+ point /= point[3]
+ normal = - M[3, :3]
+ perspective = M[:3, 3] / numpy.dot(point[:3], normal)
+ if pseudo:
+ perspective -= normal
+ return point, normal, None, perspective, pseudo
+
+
+def clip_matrix(left, right, bottom, top, near, far, perspective=False):
+ """Return matrix to obtain normalized device coordinates from frustum.
+
+ The frustum bounds are axis-aligned along x (left, right),
+ y (bottom, top) and z (near, far).
+
+ Normalized device coordinates are in range [-1, 1] if coordinates are
+ inside the frustum.
+
+ If perspective is True the frustum is a truncated pyramid with the
+ perspective point at origin and direction along z axis, otherwise an
+ orthographic canonical view volume (a box).
+
+ Homogeneous coordinates transformed by the perspective clip matrix
+ need to be dehomogenized (divided by w coordinate).
+
+ >>> frustum = numpy.random.rand(6)
+ >>> frustum[1] += frustum[0]
+ >>> frustum[3] += frustum[2]
+ >>> frustum[5] += frustum[4]
+ >>> M = clip_matrix(perspective=False, *frustum)
+ >>> numpy.dot(M, [frustum[0], frustum[2], frustum[4], 1])
+ array([-1., -1., -1., 1.])
+ >>> numpy.dot(M, [frustum[1], frustum[3], frustum[5], 1])
+ array([ 1., 1., 1., 1.])
+ >>> M = clip_matrix(perspective=True, *frustum)
+ >>> v = numpy.dot(M, [frustum[0], frustum[2], frustum[4], 1])
+ >>> v / v[3]
+ array([-1., -1., -1., 1.])
+ >>> v = numpy.dot(M, [frustum[1], frustum[3], frustum[4], 1])
+ >>> v / v[3]
+ array([ 1., 1., -1., 1.])
+
+ """
+ if left >= right or bottom >= top or near >= far:
+ raise ValueError('invalid frustum')
+ if perspective:
+ if near <= _EPS:
+ raise ValueError('invalid frustum: near <= 0')
+ t = 2.0 * near
+ M = [[t/(left-right), 0.0, (right+left)/(right-left), 0.0],
+ [0.0, t/(bottom-top), (top+bottom)/(top-bottom), 0.0],
+ [0.0, 0.0, (far+near)/(near-far), t*far/(far-near)],
+ [0.0, 0.0, -1.0, 0.0]]
+ else:
+ M = [[2.0/(right-left), 0.0, 0.0, (right+left)/(left-right)],
+ [0.0, 2.0/(top-bottom), 0.0, (top+bottom)/(bottom-top)],
+ [0.0, 0.0, 2.0/(far-near), (far+near)/(near-far)],
+ [0.0, 0.0, 0.0, 1.0]]
+ return numpy.array(M)
+
+
+def shear_matrix(angle, direction, point, normal):
+ """Return matrix to shear by angle along direction vector on shear plane.
+
+ The shear plane is defined by a point and normal vector. The direction
+ vector must be orthogonal to the plane's normal vector.
+
+ A point P is transformed by the shear matrix into P" such that
+ the vector P-P" is parallel to the direction vector and its extent is
+ given by the angle of P-P'-P", where P' is the orthogonal projection
+ of P onto the shear plane.
+
+ >>> angle = (random.random() - 0.5) * 4*math.pi
+ >>> direct = numpy.random.random(3) - 0.5
+ >>> point = numpy.random.random(3) - 0.5
+ >>> normal = numpy.cross(direct, numpy.random.random(3))
+ >>> S = shear_matrix(angle, direct, point, normal)
+ >>> numpy.allclose(1, numpy.linalg.det(S))
+ True
+
+ """
+ normal = unit_vector(normal[:3])
+ direction = unit_vector(direction[:3])
+ if abs(numpy.dot(normal, direction)) > 1e-6:
+ raise ValueError('direction and normal vectors are not orthogonal')
+ angle = math.tan(angle)
+ M = numpy.identity(4)
+ M[:3, :3] += angle * numpy.outer(direction, normal)
+ M[:3, 3] = -angle * numpy.dot(point[:3], normal) * direction
+ return M
+
+
+def shear_from_matrix(matrix):
+ """Return shear angle, direction and plane from shear matrix.
+
+ >>> angle = (random.random() - 0.5) * 4*math.pi
+ >>> direct = numpy.random.random(3) - 0.5
+ >>> point = numpy.random.random(3) - 0.5
+ >>> normal = numpy.cross(direct, numpy.random.random(3))
+ >>> S0 = shear_matrix(angle, direct, point, normal)
+ >>> angle, direct, point, normal = shear_from_matrix(S0)
+ >>> S1 = shear_matrix(angle, direct, point, normal)
+ >>> is_same_transform(S0, S1)
+ True
+
+ """
+ M = numpy.array(matrix, dtype=numpy.float64, copy=False)
+ M33 = M[:3, :3]
+    # normal: cross product of independent eigenvectors corresponding to eigenvalue 1
+ w, V = numpy.linalg.eig(M33)
+ i = numpy.where(abs(numpy.real(w) - 1.0) < 1e-4)[0]
+ if len(i) < 2:
+        raise ValueError('no two linearly independent eigenvectors found %s' % w)
+ V = numpy.real(V[:, i]).squeeze().T
+ lenorm = -1.0
+ for i0, i1 in ((0, 1), (0, 2), (1, 2)):
+ n = numpy.cross(V[i0], V[i1])
+ w = vector_norm(n)
+ if w > lenorm:
+ lenorm = w
+ normal = n
+ normal /= lenorm
+ # direction and angle
+ direction = numpy.dot(M33 - numpy.identity(3), normal)
+ angle = vector_norm(direction)
+ direction /= angle
+ angle = math.atan(angle)
+ # point: eigenvector corresponding to eigenvalue 1
+ w, V = numpy.linalg.eig(M)
+ i = numpy.where(abs(numpy.real(w) - 1.0) < 1e-8)[0]
+ if not len(i):
+ raise ValueError('no eigenvector corresponding to eigenvalue 1')
+ point = numpy.real(V[:, i[-1]]).squeeze()
+ point /= point[3]
+ return angle, direction, point, normal
+
+
+def decompose_matrix(matrix):
+ """Return sequence of transformations from transformation matrix.
+
+ matrix : array_like
+        Non-degenerate homogeneous transformation matrix
+
+ Return tuple of:
+ scale : vector of 3 scaling factors
+ shear : list of shear factors for x-y, x-z, y-z axes
+ angles : list of Euler angles about static x, y, z axes
+ translate : translation vector along x, y, z axes
+ perspective : perspective partition of matrix
+
+    Raise ValueError if matrix is of wrong type or degenerate.
+
+ >>> T0 = translation_matrix([1, 2, 3])
+ >>> scale, shear, angles, trans, persp = decompose_matrix(T0)
+ >>> T1 = translation_matrix(trans)
+ >>> numpy.allclose(T0, T1)
+ True
+ >>> S = scale_matrix(0.123)
+ >>> scale, shear, angles, trans, persp = decompose_matrix(S)
+ >>> scale[0]
+ 0.123
+ >>> R0 = euler_matrix(1, 2, 3)
+ >>> scale, shear, angles, trans, persp = decompose_matrix(R0)
+ >>> R1 = euler_matrix(*angles)
+ >>> numpy.allclose(R0, R1)
+ True
+
+ """
+ M = numpy.array(matrix, dtype=numpy.float64, copy=True).T
+ if abs(M[3, 3]) < _EPS:
+ raise ValueError('M[3, 3] is zero')
+ M /= M[3, 3]
+ P = M.copy()
+ P[:, 3] = 0.0, 0.0, 0.0, 1.0
+ if not numpy.linalg.det(P):
+ raise ValueError('matrix is singular')
+
+ scale = numpy.zeros((3, ))
+ shear = [0.0, 0.0, 0.0]
+ angles = [0.0, 0.0, 0.0]
+
+ if any(abs(M[:3, 3]) > _EPS):
+ perspective = numpy.dot(M[:, 3], numpy.linalg.inv(P.T))
+ M[:, 3] = 0.0, 0.0, 0.0, 1.0
+ else:
+ perspective = numpy.array([0.0, 0.0, 0.0, 1.0])
+
+ translate = M[3, :3].copy()
+ M[3, :3] = 0.0
+
+ row = M[:3, :3].copy()
+ scale[0] = vector_norm(row[0])
+ row[0] /= scale[0]
+ shear[0] = numpy.dot(row[0], row[1])
+ row[1] -= row[0] * shear[0]
+ scale[1] = vector_norm(row[1])
+ row[1] /= scale[1]
+ shear[0] /= scale[1]
+ shear[1] = numpy.dot(row[0], row[2])
+ row[2] -= row[0] * shear[1]
+ shear[2] = numpy.dot(row[1], row[2])
+ row[2] -= row[1] * shear[2]
+ scale[2] = vector_norm(row[2])
+ row[2] /= scale[2]
+ shear[1:] /= scale[2]
+
+ if numpy.dot(row[0], numpy.cross(row[1], row[2])) < 0:
+ numpy.negative(scale, scale)
+ numpy.negative(row, row)
+
+ angles[1] = math.asin(-row[0, 2])
+ if math.cos(angles[1]):
+ angles[0] = math.atan2(row[1, 2], row[2, 2])
+ angles[2] = math.atan2(row[0, 1], row[0, 0])
+ else:
+ # angles[0] = math.atan2(row[1, 0], row[1, 1])
+ angles[0] = math.atan2(-row[2, 1], row[1, 1])
+ angles[2] = 0.0
+
+ return scale, shear, angles, translate, perspective
+
+
+def compose_matrix(scale=None, shear=None, angles=None, translate=None,
+ perspective=None):
+ """Return transformation matrix from sequence of transformations.
+
+ This is the inverse of the decompose_matrix function.
+
+ Sequence of transformations:
+ scale : vector of 3 scaling factors
+ shear : list of shear factors for x-y, x-z, y-z axes
+ angles : list of Euler angles about static x, y, z axes
+ translate : translation vector along x, y, z axes
+ perspective : perspective partition of matrix
+
+ >>> scale = numpy.random.random(3) - 0.5
+ >>> shear = numpy.random.random(3) - 0.5
+ >>> angles = (numpy.random.random(3) - 0.5) * (2*math.pi)
+ >>> trans = numpy.random.random(3) - 0.5
+ >>> persp = numpy.random.random(4) - 0.5
+ >>> M0 = compose_matrix(scale, shear, angles, trans, persp)
+ >>> result = decompose_matrix(M0)
+ >>> M1 = compose_matrix(*result)
+ >>> is_same_transform(M0, M1)
+ True
+
+ """
+ M = numpy.identity(4)
+ if perspective is not None:
+ P = numpy.identity(4)
+ P[3, :] = perspective[:4]
+ M = numpy.dot(M, P)
+ if translate is not None:
+ T = numpy.identity(4)
+ T[:3, 3] = translate[:3]
+ M = numpy.dot(M, T)
+ if angles is not None:
+ R = euler_matrix(angles[0], angles[1], angles[2], 'sxyz')
+ M = numpy.dot(M, R)
+ if shear is not None:
+ Z = numpy.identity(4)
+ Z[1, 2] = shear[2]
+ Z[0, 2] = shear[1]
+ Z[0, 1] = shear[0]
+ M = numpy.dot(M, Z)
+ if scale is not None:
+ S = numpy.identity(4)
+ S[0, 0] = scale[0]
+ S[1, 1] = scale[1]
+ S[2, 2] = scale[2]
+ M = numpy.dot(M, S)
+ M /= M[3, 3]
+ return M
+
+
+def orthogonalization_matrix(lengths, angles):
+ """Return orthogonalization matrix for crystallographic cell coordinates.
+
+ Angles are expected in degrees.
+
+ The de-orthogonalization matrix is the inverse.
+
+ >>> O = orthogonalization_matrix([10, 10, 10], [90, 90, 90])
+ >>> numpy.allclose(O[:3, :3], numpy.identity(3, float) * 10)
+ True
+ >>> O = orthogonalization_matrix([9.8, 12.0, 15.5], [87.2, 80.7, 69.7])
+ >>> numpy.allclose(numpy.sum(O), 43.063229)
+ True
+
+ """
+ a, b, c = lengths
+ angles = numpy.radians(angles)
+ sina, sinb, _ = numpy.sin(angles)
+ cosa, cosb, cosg = numpy.cos(angles)
+ co = (cosa * cosb - cosg) / (sina * sinb)
+ return numpy.array([
+ [a*sinb*math.sqrt(1.0-co*co), 0.0, 0.0, 0.0],
+ [-a*sinb*co, b*sina, 0.0, 0.0],
+ [a*cosb, b*cosa, c, 0.0],
+ [0.0, 0.0, 0.0, 1.0]])
+
+
+def affine_matrix_from_points(v0, v1, shear=True, scale=True, usesvd=True):
+ """Return affine transform matrix to register two point sets.
+
+ v0 and v1 are shape (ndims, \*) arrays of at least ndims non-homogeneous
+ coordinates, where ndims is the dimensionality of the coordinate space.
+
+ If shear is False, a similarity transformation matrix is returned.
+ If also scale is False, a rigid/Euclidean transformation matrix
+ is returned.
+
+    By default the algorithm by Hartley and Zisserman [15] is used.
+    If usesvd is True, similarity and Euclidean transformation matrices
+    are calculated by minimizing the weighted sum of squared deviations
+    (RMSD) according to the algorithm by Kabsch [8].
+    Otherwise, and if ndims is 3, the quaternion-based algorithm by Horn [9]
+    is used, which is slower when using this Python implementation.
+
+ The returned matrix performs rotation, translation and uniform scaling
+ (if specified).
+
+ >>> v0 = [[0, 1031, 1031, 0], [0, 0, 1600, 1600]]
+ >>> v1 = [[675, 826, 826, 677], [55, 52, 281, 277]]
+ >>> affine_matrix_from_points(v0, v1)
+ array([[ 0.14549, 0.00062, 675.50008],
+ [ 0.00048, 0.14094, 53.24971],
+ [ 0. , 0. , 1. ]])
+ >>> T = translation_matrix(numpy.random.random(3)-0.5)
+ >>> R = random_rotation_matrix(numpy.random.random(3))
+ >>> S = scale_matrix(random.random())
+ >>> M = concatenate_matrices(T, R, S)
+ >>> v0 = (numpy.random.rand(4, 100) - 0.5) * 20
+ >>> v0[3] = 1
+ >>> v1 = numpy.dot(M, v0)
+ >>> v0[:3] += numpy.random.normal(0, 1e-8, 300).reshape(3, -1)
+ >>> M = affine_matrix_from_points(v0[:3], v1[:3])
+ >>> numpy.allclose(v1, numpy.dot(M, v0))
+ True
+
+ More examples in superimposition_matrix()
+
+ """
+ v0 = numpy.array(v0, dtype=numpy.float64, copy=True)
+ v1 = numpy.array(v1, dtype=numpy.float64, copy=True)
+
+ ndims = v0.shape[0]
+ if ndims < 2 or v0.shape[1] < ndims or v0.shape != v1.shape:
+ raise ValueError('input arrays are of wrong shape or type')
+
+ # move centroids to origin
+ t0 = -numpy.mean(v0, axis=1)
+ M0 = numpy.identity(ndims+1)
+ M0[:ndims, ndims] = t0
+ v0 += t0.reshape(ndims, 1)
+ t1 = -numpy.mean(v1, axis=1)
+ M1 = numpy.identity(ndims+1)
+ M1[:ndims, ndims] = t1
+ v1 += t1.reshape(ndims, 1)
+
+ if shear:
+ # Affine transformation
+ A = numpy.concatenate((v0, v1), axis=0)
+ u, s, vh = numpy.linalg.svd(A.T)
+ vh = vh[:ndims].T
+ B = vh[:ndims]
+ C = vh[ndims:2*ndims]
+ t = numpy.dot(C, numpy.linalg.pinv(B))
+ t = numpy.concatenate((t, numpy.zeros((ndims, 1))), axis=1)
+ M = numpy.vstack((t, ((0.0,)*ndims) + (1.0,)))
+ elif usesvd or ndims != 3:
+ # Rigid transformation via SVD of covariance matrix
+ u, s, vh = numpy.linalg.svd(numpy.dot(v1, v0.T))
+ # rotation matrix from SVD orthonormal bases
+ R = numpy.dot(u, vh)
+ if numpy.linalg.det(R) < 0.0:
+ # R does not constitute right handed system
+ R -= numpy.outer(u[:, ndims-1], vh[ndims-1, :]*2.0)
+ s[-1] *= -1.0
+ # homogeneous transformation matrix
+ M = numpy.identity(ndims+1)
+ M[:ndims, :ndims] = R
+ else:
+ # Rigid transformation matrix via quaternion
+ # compute symmetric matrix N
+ xx, yy, zz = numpy.sum(v0 * v1, axis=1)
+ xy, yz, zx = numpy.sum(v0 * numpy.roll(v1, -1, axis=0), axis=1)
+ xz, yx, zy = numpy.sum(v0 * numpy.roll(v1, -2, axis=0), axis=1)
+ N = [[xx+yy+zz, 0.0, 0.0, 0.0],
+ [yz-zy, xx-yy-zz, 0.0, 0.0],
+ [zx-xz, xy+yx, yy-xx-zz, 0.0],
+ [xy-yx, zx+xz, yz+zy, zz-xx-yy]]
+ # quaternion: eigenvector corresponding to most positive eigenvalue
+ w, V = numpy.linalg.eigh(N)
+ q = V[:, numpy.argmax(w)]
+ q /= vector_norm(q) # unit quaternion
+ # homogeneous transformation matrix
+ M = quaternion_matrix(q)
+
+ if scale and not shear:
+        # similarity transform: scale is the ratio of RMS deviations from centroid
+ v0 *= v0
+ v1 *= v1
+ M[:ndims, :ndims] *= math.sqrt(numpy.sum(v1) / numpy.sum(v0))
+
+ # move centroids back
+ M = numpy.dot(numpy.linalg.inv(M1), numpy.dot(M, M0))
+ M /= M[ndims, ndims]
+ return M
+
+
+def superimposition_matrix(v0, v1, scale=False, usesvd=True):
+ """Return matrix to transform given 3D point set into second point set.
+
+ v0 and v1 are shape (3, \*) or (4, \*) arrays of at least 3 points.
+
+ The parameters scale and usesvd are explained in the more general
+ affine_matrix_from_points function.
+
+ The returned matrix is a similarity or Euclidean transformation matrix.
+ This function has a fast C implementation in transformations.c.
+
+ >>> v0 = numpy.random.rand(3, 10)
+ >>> M = superimposition_matrix(v0, v0)
+ >>> numpy.allclose(M, numpy.identity(4))
+ True
+ >>> R = random_rotation_matrix(numpy.random.random(3))
+ >>> v0 = [[1,0,0], [0,1,0], [0,0,1], [1,1,1]]
+ >>> v1 = numpy.dot(R, v0)
+ >>> M = superimposition_matrix(v0, v1)
+ >>> numpy.allclose(v1, numpy.dot(M, v0))
+ True
+ >>> v0 = (numpy.random.rand(4, 100) - 0.5) * 20
+ >>> v0[3] = 1
+ >>> v1 = numpy.dot(R, v0)
+ >>> M = superimposition_matrix(v0, v1)
+ >>> numpy.allclose(v1, numpy.dot(M, v0))
+ True
+ >>> S = scale_matrix(random.random())
+ >>> T = translation_matrix(numpy.random.random(3)-0.5)
+ >>> M = concatenate_matrices(T, R, S)
+ >>> v1 = numpy.dot(M, v0)
+ >>> v0[:3] += numpy.random.normal(0, 1e-9, 300).reshape(3, -1)
+ >>> M = superimposition_matrix(v0, v1, scale=True)
+ >>> numpy.allclose(v1, numpy.dot(M, v0))
+ True
+ >>> M = superimposition_matrix(v0, v1, scale=True, usesvd=False)
+ >>> numpy.allclose(v1, numpy.dot(M, v0))
+ True
+ >>> v = numpy.empty((4, 100, 3))
+ >>> v[:, :, 0] = v0
+ >>> M = superimposition_matrix(v0, v1, scale=True, usesvd=False)
+ >>> numpy.allclose(v1, numpy.dot(M, v[:, :, 0]))
+ True
+
+ """
+ v0 = numpy.array(v0, dtype=numpy.float64, copy=False)[:3]
+ v1 = numpy.array(v1, dtype=numpy.float64, copy=False)[:3]
+ return affine_matrix_from_points(v0, v1, shear=False,
+ scale=scale, usesvd=usesvd)
+
+
+def euler_matrix(ai, aj, ak, axes='sxyz'):
+ """Return homogeneous rotation matrix from Euler angles and axis sequence.
+
+ ai, aj, ak : Euler's roll, pitch and yaw angles
+ axes : One of 24 axis sequences as string or encoded tuple
+
+ >>> R = euler_matrix(1, 2, 3, 'syxz')
+ >>> numpy.allclose(numpy.sum(R[0]), -1.34786452)
+ True
+ >>> R = euler_matrix(1, 2, 3, (0, 1, 0, 1))
+ >>> numpy.allclose(numpy.sum(R[0]), -0.383436184)
+ True
+ >>> ai, aj, ak = (4*math.pi) * (numpy.random.random(3) - 0.5)
+ >>> for axes in _AXES2TUPLE.keys():
+ ... R = euler_matrix(ai, aj, ak, axes)
+ >>> for axes in _TUPLE2AXES.keys():
+ ... R = euler_matrix(ai, aj, ak, axes)
+
+ """
+ try:
+ firstaxis, parity, repetition, frame = _AXES2TUPLE[axes]
+ except (AttributeError, KeyError):
+ _TUPLE2AXES[axes] # noqa: validation
+ firstaxis, parity, repetition, frame = axes
+
+ i = firstaxis
+ j = _NEXT_AXIS[i+parity]
+ k = _NEXT_AXIS[i-parity+1]
+
+ if frame:
+ ai, ak = ak, ai
+ if parity:
+ ai, aj, ak = -ai, -aj, -ak
+
+ si, sj, sk = math.sin(ai), math.sin(aj), math.sin(ak)
+ ci, cj, ck = math.cos(ai), math.cos(aj), math.cos(ak)
+ cc, cs = ci*ck, ci*sk
+ sc, ss = si*ck, si*sk
+
+ M = numpy.identity(4)
+ if repetition:
+ M[i, i] = cj
+ M[i, j] = sj*si
+ M[i, k] = sj*ci
+ M[j, i] = sj*sk
+ M[j, j] = -cj*ss+cc
+ M[j, k] = -cj*cs-sc
+ M[k, i] = -sj*ck
+ M[k, j] = cj*sc+cs
+ M[k, k] = cj*cc-ss
+ else:
+ M[i, i] = cj*ck
+ M[i, j] = sj*sc-cs
+ M[i, k] = sj*cc+ss
+ M[j, i] = cj*sk
+ M[j, j] = sj*ss+cc
+ M[j, k] = sj*cs-sc
+ M[k, i] = -sj
+ M[k, j] = cj*si
+ M[k, k] = cj*ci
+ return M
+
+
+def euler_from_matrix(matrix, axes='sxyz'):
+ """Return Euler angles from rotation matrix for specified axis sequence.
+
+ axes : One of 24 axis sequences as string or encoded tuple
+
+ Note that many Euler angle triplets can describe one matrix.
+
+ >>> R0 = euler_matrix(1, 2, 3, 'syxz')
+ >>> al, be, ga = euler_from_matrix(R0, 'syxz')
+ >>> R1 = euler_matrix(al, be, ga, 'syxz')
+ >>> numpy.allclose(R0, R1)
+ True
+ >>> angles = (4*math.pi) * (numpy.random.random(3) - 0.5)
+ >>> for axes in _AXES2TUPLE.keys():
+ ... R0 = euler_matrix(axes=axes, *angles)
+ ... R1 = euler_matrix(axes=axes, *euler_from_matrix(R0, axes))
+ ... if not numpy.allclose(R0, R1): print(axes, "failed")
+
+ """
+ try:
+ firstaxis, parity, repetition, frame = _AXES2TUPLE[axes.lower()]
+ except (AttributeError, KeyError):
+ _TUPLE2AXES[axes] # noqa: validation
+ firstaxis, parity, repetition, frame = axes
+
+ i = firstaxis
+ j = _NEXT_AXIS[i+parity]
+ k = _NEXT_AXIS[i-parity+1]
+
+ M = numpy.array(matrix, dtype=numpy.float64, copy=False)[:3, :3]
+ if repetition:
+ sy = math.sqrt(M[i, j]*M[i, j] + M[i, k]*M[i, k])
+ if sy > _EPS:
+ ax = math.atan2(M[i, j], M[i, k])
+ ay = math.atan2(sy, M[i, i])
+ az = math.atan2(M[j, i], -M[k, i])
+ else:
+ ax = math.atan2(-M[j, k], M[j, j])
+ ay = math.atan2(sy, M[i, i])
+ az = 0.0
+ else:
+ cy = math.sqrt(M[i, i]*M[i, i] + M[j, i]*M[j, i])
+ if cy > _EPS:
+ ax = math.atan2(M[k, j], M[k, k])
+ ay = math.atan2(-M[k, i], cy)
+ az = math.atan2(M[j, i], M[i, i])
+ else:
+ ax = math.atan2(-M[j, k], M[j, j])
+ ay = math.atan2(-M[k, i], cy)
+ az = 0.0
+
+ if parity:
+ ax, ay, az = -ax, -ay, -az
+ if frame:
+ ax, az = az, ax
+ return ax, ay, az
+
+
+def euler_from_quaternion(quaternion, axes='sxyz'):
+ """Return Euler angles from quaternion for specified axis sequence.
+
+ >>> angles = euler_from_quaternion([0.99810947, 0.06146124, 0, 0])
+ >>> numpy.allclose(angles, [0.123, 0, 0])
+ True
+
+ """
+ return euler_from_matrix(quaternion_matrix(quaternion), axes)
+
+
+def quaternion_from_euler(ai, aj, ak, axes='sxyz'):
+ """Return quaternion from Euler angles and axis sequence.
+
+ ai, aj, ak : Euler's roll, pitch and yaw angles
+ axes : One of 24 axis sequences as string or encoded tuple
+
+ >>> q = quaternion_from_euler(1, 2, 3, 'ryxz')
+ >>> numpy.allclose(q, [0.435953, 0.310622, -0.718287, 0.444435])
+ True
+
+ """
+ try:
+ firstaxis, parity, repetition, frame = _AXES2TUPLE[axes.lower()]
+ except (AttributeError, KeyError):
+ _TUPLE2AXES[axes] # noqa: validation
+ firstaxis, parity, repetition, frame = axes
+
+ i = firstaxis + 1
+ j = _NEXT_AXIS[i+parity-1] + 1
+ k = _NEXT_AXIS[i-parity] + 1
+
+ if frame:
+ ai, ak = ak, ai
+ if parity:
+ aj = -aj
+
+ ai /= 2.0
+ aj /= 2.0
+ ak /= 2.0
+ ci = math.cos(ai)
+ si = math.sin(ai)
+ cj = math.cos(aj)
+ sj = math.sin(aj)
+ ck = math.cos(ak)
+ sk = math.sin(ak)
+ cc = ci*ck
+ cs = ci*sk
+ sc = si*ck
+ ss = si*sk
+
+ q = numpy.empty((4, ))
+ if repetition:
+ q[0] = cj*(cc - ss)
+ q[i] = cj*(cs + sc)
+ q[j] = sj*(cc + ss)
+ q[k] = sj*(cs - sc)
+ else:
+ q[0] = cj*cc + sj*ss
+ q[i] = cj*sc - sj*cs
+ q[j] = cj*ss + sj*cc
+ q[k] = cj*cs - sj*sc
+ if parity:
+ q[j] *= -1.0
+
+ return q
+
+
+def quaternion_about_axis(angle, axis):
+ """Return quaternion for rotation about axis.
+
+ >>> q = quaternion_about_axis(0.123, [1, 0, 0])
+ >>> numpy.allclose(q, [0.99810947, 0.06146124, 0, 0])
+ True
+
+ """
+ q = numpy.array([0.0, axis[0], axis[1], axis[2]])
+ qlen = vector_norm(q)
+ if qlen > _EPS:
+ q *= math.sin(angle/2.0) / qlen
+ q[0] = math.cos(angle/2.0)
+ return q
+
+
+def quaternion_matrix(quaternion):
+ """Return homogeneous rotation matrix from quaternion.
+
+ >>> M = quaternion_matrix([0.99810947, 0.06146124, 0, 0])
+ >>> numpy.allclose(M, rotation_matrix(0.123, [1, 0, 0]))
+ True
+ >>> M = quaternion_matrix([1, 0, 0, 0])
+ >>> numpy.allclose(M, numpy.identity(4))
+ True
+ >>> M = quaternion_matrix([0, 1, 0, 0])
+ >>> numpy.allclose(M, numpy.diag([1, -1, -1, 1]))
+ True
+
+ """
+ q = numpy.array(quaternion, dtype=numpy.float64, copy=True)
+ n = numpy.dot(q, q)
+ if n < _EPS:
+ return numpy.identity(4)
+ q *= math.sqrt(2.0 / n)
+ q = numpy.outer(q, q)
+ return numpy.array([
+ [1.0-q[2, 2]-q[3, 3], q[1, 2]-q[3, 0], q[1, 3]+q[2, 0], 0.0],
+ [q[1, 2]+q[3, 0], 1.0-q[1, 1]-q[3, 3], q[2, 3]-q[1, 0], 0.0],
+ [q[1, 3]-q[2, 0], q[2, 3]+q[1, 0], 1.0-q[1, 1]-q[2, 2], 0.0],
+ [0.0, 0.0, 0.0, 1.0]])
+
+
+def quaternion_from_matrix(matrix, isprecise=False):
+ """Return quaternion from rotation matrix.
+
+ If isprecise is True, the input matrix is assumed to be a precise rotation
+ matrix and a faster algorithm is used.
+
+ >>> q = quaternion_from_matrix(numpy.identity(4), True)
+ >>> numpy.allclose(q, [1, 0, 0, 0])
+ True
+ >>> q = quaternion_from_matrix(numpy.diag([1, -1, -1, 1]))
+ >>> numpy.allclose(q, [0, 1, 0, 0]) or numpy.allclose(q, [0, -1, 0, 0])
+ True
+ >>> R = rotation_matrix(0.123, (1, 2, 3))
+ >>> q = quaternion_from_matrix(R, True)
+ >>> numpy.allclose(q, [0.9981095, 0.0164262, 0.0328524, 0.0492786])
+ True
+ >>> R = [[-0.545, 0.797, 0.260, 0], [0.733, 0.603, -0.313, 0],
+ ... [-0.407, 0.021, -0.913, 0], [0, 0, 0, 1]]
+ >>> q = quaternion_from_matrix(R)
+ >>> numpy.allclose(q, [0.19069, 0.43736, 0.87485, -0.083611])
+ True
+ >>> R = [[0.395, 0.362, 0.843, 0], [-0.626, 0.796, -0.056, 0],
+ ... [-0.677, -0.498, 0.529, 0], [0, 0, 0, 1]]
+ >>> q = quaternion_from_matrix(R)
+ >>> numpy.allclose(q, [0.82336615, -0.13610694, 0.46344705, -0.29792603])
+ True
+ >>> R = random_rotation_matrix()
+ >>> q = quaternion_from_matrix(R)
+ >>> is_same_transform(R, quaternion_matrix(q))
+ True
+ >>> is_same_quaternion(quaternion_from_matrix(R, isprecise=False),
+ ... quaternion_from_matrix(R, isprecise=True))
+ True
+ >>> R = euler_matrix(0.0, 0.0, numpy.pi/2.0)
+ >>> is_same_quaternion(quaternion_from_matrix(R, isprecise=False),
+ ... quaternion_from_matrix(R, isprecise=True))
+ True
+
+ """
+ M = numpy.array(matrix, dtype=numpy.float64, copy=False)[:4, :4]
+ if isprecise:
+ q = numpy.empty((4, ))
+ t = numpy.trace(M)
+ if t > M[3, 3]:
+ q[0] = t
+ q[3] = M[1, 0] - M[0, 1]
+ q[2] = M[0, 2] - M[2, 0]
+ q[1] = M[2, 1] - M[1, 2]
+ else:
+ i, j, k = 0, 1, 2
+ if M[1, 1] > M[0, 0]:
+ i, j, k = 1, 2, 0
+ if M[2, 2] > M[i, i]:
+ i, j, k = 2, 0, 1
+ t = M[i, i] - (M[j, j] + M[k, k]) + M[3, 3]
+ q[i] = t
+ q[j] = M[i, j] + M[j, i]
+ q[k] = M[k, i] + M[i, k]
+ q[3] = M[k, j] - M[j, k]
+ q = q[[3, 0, 1, 2]]
+ q *= 0.5 / math.sqrt(t * M[3, 3])
+ else:
+ m00 = M[0, 0]
+ m01 = M[0, 1]
+ m02 = M[0, 2]
+ m10 = M[1, 0]
+ m11 = M[1, 1]
+ m12 = M[1, 2]
+ m20 = M[2, 0]
+ m21 = M[2, 1]
+ m22 = M[2, 2]
+ # symmetric matrix K
+ K = numpy.array([[m00-m11-m22, 0.0, 0.0, 0.0],
+ [m01+m10, m11-m00-m22, 0.0, 0.0],
+ [m02+m20, m12+m21, m22-m00-m11, 0.0],
+ [m21-m12, m02-m20, m10-m01, m00+m11+m22]])
+ K /= 3.0
+ # quaternion is eigenvector of K that corresponds to largest eigenvalue
+ w, V = numpy.linalg.eigh(K)
+ q = V[[3, 0, 1, 2], numpy.argmax(w)]
+ if q[0] < 0.0:
+ numpy.negative(q, q)
+ return q
+
+
+def quaternion_multiply(quaternion1, quaternion0):
+ """Return multiplication of two quaternions.
+
+ >>> q = quaternion_multiply([4, 1, -2, 3], [8, -5, 6, 7])
+ >>> numpy.allclose(q, [28, -44, -14, 48])
+ True
+
+ """
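+    # quaternions are in (w, x, y, z) order; this is the Hamilton product
+    # q1 * q0, i.e. the rotation q0 followed by the rotation q1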
+ w0, x0, y0, z0 = quaternion0
+ w1, x1, y1, z1 = quaternion1
+ return numpy.array([
+ -x1*x0 - y1*y0 - z1*z0 + w1*w0,
+ x1*w0 + y1*z0 - z1*y0 + w1*x0,
+ -x1*z0 + y1*w0 + z1*x0 + w1*y0,
+ x1*y0 - y1*x0 + z1*w0 + w1*z0], dtype=numpy.float64)
+
+
+def quaternion_conjugate(quaternion):
+ """Return conjugate of quaternion.
+
+ >>> q0 = random_quaternion()
+ >>> q1 = quaternion_conjugate(q0)
+ >>> q1[0] == q0[0] and all(q1[1:] == -q0[1:])
+ True
+
+ """
+ q = numpy.array(quaternion, dtype=numpy.float64, copy=True)
+ numpy.negative(q[1:], q[1:])
+ return q
+
+
+def quaternion_inverse(quaternion):
+ """Return inverse of quaternion.
+
+ >>> q0 = random_quaternion()
+ >>> q1 = quaternion_inverse(q0)
+ >>> numpy.allclose(quaternion_multiply(q0, q1), [1, 0, 0, 0])
+ True
+
+ """
+ q = numpy.array(quaternion, dtype=numpy.float64, copy=True)
+ numpy.negative(q[1:], q[1:])
+ return q / numpy.dot(q, q)
+
+
+def quaternion_real(quaternion):
+ """Return real part of quaternion.
+
+ >>> quaternion_real([3, 0, 1, 2])
+ 3.0
+
+ """
+ return float(quaternion[0])
+
+
+def quaternion_imag(quaternion):
+ """Return imaginary part of quaternion.
+
+ >>> quaternion_imag([3, 0, 1, 2])
+ array([ 0., 1., 2.])
+
+ """
+ return numpy.array(quaternion[1:4], dtype=numpy.float64, copy=True)
+
+
+def quaternion_slerp(quat0, quat1, fraction, spin=0, shortestpath=True):
+ """Return spherical linear interpolation between two quaternions.
+
+ >>> q0 = random_quaternion()
+ >>> q1 = random_quaternion()
+ >>> q = quaternion_slerp(q0, q1, 0)
+ >>> numpy.allclose(q, q0)
+ True
+ >>> q = quaternion_slerp(q0, q1, 1, 1)
+ >>> numpy.allclose(q, q1)
+ True
+ >>> q = quaternion_slerp(q0, q1, 0.5)
+ >>> angle = math.acos(numpy.dot(q0, q))
+ >>> numpy.allclose(2, math.acos(numpy.dot(q0, q1)) / angle) or \
+ numpy.allclose(2, math.acos(-numpy.dot(q0, q1)) / angle)
+ True
+
+ """
+ q0 = unit_vector(quat0[:4])
+ q1 = unit_vector(quat1[:4])
+ if fraction == 0.0:
+ return q0
+ elif fraction == 1.0:
+ return q1
+ d = numpy.dot(q0, q1)
+ if abs(abs(d) - 1.0) < _EPS:
+ return q0
+ if shortestpath and d < 0.0:
+ # invert rotation
+ d = -d
+ numpy.negative(q1, q1)
+ angle = math.acos(d) + spin * math.pi
+ if abs(angle) < _EPS:
+ return q0
+ isin = 1.0 / math.sin(angle)
+ q0 *= math.sin((1.0 - fraction) * angle) * isin
+ q1 *= math.sin(fraction * angle) * isin
+ q0 += q1
+ return q0
+
+
+def random_quaternion(rand=None):
+ """Return uniform random unit quaternion.
+
+    rand: array-like or None
+ Three independent random variables that are uniformly distributed
+ between 0 and 1.
+
+ >>> q = random_quaternion()
+ >>> numpy.allclose(1, vector_norm(q))
+ True
+ >>> q = random_quaternion(numpy.random.random(3))
+ >>> len(q.shape), q.shape[0]==4
+ (1, True)
+
+ """
+ if rand is None:
+ rand = numpy.random.rand(3)
+ else:
+ assert len(rand) == 3
+ r1 = numpy.sqrt(1.0 - rand[0])
+ r2 = numpy.sqrt(rand[0])
+ pi2 = math.pi * 2.0
+ t1 = pi2 * rand[1]
+ t2 = pi2 * rand[2]
+ return numpy.array([numpy.cos(t2)*r2, numpy.sin(t1)*r1,
+ numpy.cos(t1)*r1, numpy.sin(t2)*r2])
+
+
+def random_rotation_matrix(rand=None):
+ """Return uniform random rotation matrix.
+
+    rand: array-like
+ Three independent random variables that are uniformly distributed
+ between 0 and 1 for each returned quaternion.
+
+ >>> R = random_rotation_matrix()
+ >>> numpy.allclose(numpy.dot(R.T, R), numpy.identity(4))
+ True
+
+ """
+ return quaternion_matrix(random_quaternion(rand))
+
+
+class Arcball(object):
+ """Virtual Trackball Control.
+
+ >>> ball = Arcball()
+ >>> ball = Arcball(initial=numpy.identity(4))
+ >>> ball.place([320, 320], 320)
+ >>> ball.down([500, 250])
+ >>> ball.drag([475, 275])
+ >>> R = ball.matrix()
+ >>> numpy.allclose(numpy.sum(R), 3.90583455)
+ True
+ >>> ball = Arcball(initial=[1, 0, 0, 0])
+ >>> ball.place([320, 320], 320)
+ >>> ball.setaxes([1, 1, 0], [-1, 1, 0])
+ >>> ball.constrain = True
+ >>> ball.down([400, 200])
+ >>> ball.drag([200, 400])
+ >>> R = ball.matrix()
+ >>> numpy.allclose(numpy.sum(R), 0.2055924)
+ True
+ >>> ball.next()
+
+ """
+
+ def __init__(self, initial=None):
+ """Initialize virtual trackball control.
+
+ initial : quaternion or rotation matrix
+
+ """
+ self._axis = None
+ self._axes = None
+ self._radius = 1.0
+ self._center = [0.0, 0.0]
+ self._vdown = numpy.array([0.0, 0.0, 1.0])
+ self._constrain = False
+ if initial is None:
+ self._qdown = numpy.array([1.0, 0.0, 0.0, 0.0])
+ else:
+ initial = numpy.array(initial, dtype=numpy.float64)
+ if initial.shape == (4, 4):
+ self._qdown = quaternion_from_matrix(initial)
+ elif initial.shape == (4, ):
+ initial /= vector_norm(initial)
+ self._qdown = initial
+ else:
+ raise ValueError("initial not a quaternion or matrix")
+ self._qnow = self._qpre = self._qdown
+
+ def place(self, center, radius):
+ """Place Arcball, e.g. when window size changes.
+
+ center : sequence[2]
+ Window coordinates of trackball center.
+ radius : float
+ Radius of trackball in window coordinates.
+
+ """
+ self._radius = float(radius)
+ self._center[0] = center[0]
+ self._center[1] = center[1]
+
+ def setaxes(self, *axes):
+ """Set axes to constrain rotations."""
+ if axes is None:
+ self._axes = None
+ else:
+ self._axes = [unit_vector(axis) for axis in axes]
+
+ @property
+ def constrain(self):
+        """Return state of constrain-to-axis mode."""
+ return self._constrain
+
+ @constrain.setter
+ def constrain(self, value):
+        """Set state of constrain-to-axis mode."""
+ self._constrain = bool(value)
+
+ def down(self, point):
+ """Set initial cursor window coordinates and pick constrain-axis."""
+ self._vdown = arcball_map_to_sphere(point, self._center, self._radius)
+ self._qdown = self._qpre = self._qnow
+ if self._constrain and self._axes is not None:
+ self._axis = arcball_nearest_axis(self._vdown, self._axes)
+ self._vdown = arcball_constrain_to_axis(self._vdown, self._axis)
+ else:
+ self._axis = None
+
+ def drag(self, point):
+ """Update current cursor window coordinates."""
+ vnow = arcball_map_to_sphere(point, self._center, self._radius)
+ if self._axis is not None:
+ vnow = arcball_constrain_to_axis(vnow, self._axis)
+ self._qpre = self._qnow
+ t = numpy.cross(self._vdown, vnow)
+ if numpy.dot(t, t) < _EPS:
+ self._qnow = self._qdown
+ else:
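+            # the dot and cross products form a unit quaternion whose rotation
+            # angle is twice the angle between the two sphere points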
+ q = [numpy.dot(self._vdown, vnow), t[0], t[1], t[2]]
+ self._qnow = quaternion_multiply(q, self._qdown)
+
+ def next(self, acceleration=0.0):
+ """Continue rotation in direction of last drag."""
+ q = quaternion_slerp(self._qpre, self._qnow, 2.0+acceleration, False)
+ self._qpre, self._qnow = self._qnow, q
+
+ def matrix(self):
+ """Return homogeneous rotation matrix."""
+ return quaternion_matrix(self._qnow)
+
+
+def arcball_map_to_sphere(point, center, radius):
+ """Return unit sphere coordinates from window coordinates."""
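+    # points inside the trackball radius map onto the front hemisphere;
+    # points outside are projected back onto the sphere's equator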
+ v0 = (point[0] - center[0]) / radius
+ v1 = (center[1] - point[1]) / radius
+ n = v0*v0 + v1*v1
+ if n > 1.0:
+ # position outside of sphere
+ n = math.sqrt(n)
+ return numpy.array([v0/n, v1/n, 0.0])
+ else:
+ return numpy.array([v0, v1, math.sqrt(1.0 - n)])
+
+
+def arcball_constrain_to_axis(point, axis):
+ """Return sphere point perpendicular to axis."""
+ v = numpy.array(point, dtype=numpy.float64, copy=True)
+ a = numpy.array(axis, dtype=numpy.float64, copy=True)
+ v -= a * numpy.dot(a, v) # on plane
+ n = vector_norm(v)
+ if n > _EPS:
+ if v[2] < 0.0:
+ numpy.negative(v, v)
+ v /= n
+ return v
+ if a[2] == 1.0:
+ return numpy.array([1.0, 0.0, 0.0])
+ return unit_vector([-a[1], a[0], 0.0])
+
+
+def arcball_nearest_axis(point, axes):
+    """Return the axis whose arc is nearest to point."""
+ point = numpy.array(point, dtype=numpy.float64, copy=False)
+ nearest = None
+ mx = -1.0
+ for axis in axes:
+ t = numpy.dot(arcball_constrain_to_axis(point, axis), point)
+ if t > mx:
+ nearest = axis
+ mx = t
+ return nearest
+
+
+# epsilon for testing whether a number is close to zero
+_EPS = numpy.finfo(float).eps * 4.0
+
+# axis sequences for Euler angles
+_NEXT_AXIS = [1, 2, 0, 1]
+
+# map axes strings to/from tuples of inner axis, parity, repetition, frame
+_AXES2TUPLE = {
+ 'sxyz': (0, 0, 0, 0), 'sxyx': (0, 0, 1, 0), 'sxzy': (0, 1, 0, 0),
+ 'sxzx': (0, 1, 1, 0), 'syzx': (1, 0, 0, 0), 'syzy': (1, 0, 1, 0),
+ 'syxz': (1, 1, 0, 0), 'syxy': (1, 1, 1, 0), 'szxy': (2, 0, 0, 0),
+ 'szxz': (2, 0, 1, 0), 'szyx': (2, 1, 0, 0), 'szyz': (2, 1, 1, 0),
+ 'rzyx': (0, 0, 0, 1), 'rxyx': (0, 0, 1, 1), 'ryzx': (0, 1, 0, 1),
+ 'rxzx': (0, 1, 1, 1), 'rxzy': (1, 0, 0, 1), 'ryzy': (1, 0, 1, 1),
+ 'rzxy': (1, 1, 0, 1), 'ryxy': (1, 1, 1, 1), 'ryxz': (2, 0, 0, 1),
+ 'rzxz': (2, 0, 1, 1), 'rxyz': (2, 1, 0, 1), 'rzyz': (2, 1, 1, 1)}
+
+_TUPLE2AXES = dict((v, k) for k, v in _AXES2TUPLE.items())
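+# For example, 'sxyz' maps to (0, 0, 0, 0): first axis x, even parity, no
+# repetition, static (extrinsic) frame; the 'r...' strings set frame=1 for
+# rotating (intrinsic) axes, e.g. 'rzyx' maps to (0, 0, 0, 1).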
+
+
+def vector_norm(data, axis=None, out=None):
+ """Return length, i.e. Euclidean norm, of ndarray along axis.
+
+ >>> v = numpy.random.random(3)
+ >>> n = vector_norm(v)
+ >>> numpy.allclose(n, numpy.linalg.norm(v))
+ True
+ >>> v = numpy.random.rand(6, 5, 3)
+ >>> n = vector_norm(v, axis=-1)
+ >>> numpy.allclose(n, numpy.sqrt(numpy.sum(v*v, axis=2)))
+ True
+ >>> n = vector_norm(v, axis=1)
+ >>> numpy.allclose(n, numpy.sqrt(numpy.sum(v*v, axis=1)))
+ True
+ >>> v = numpy.random.rand(5, 4, 3)
+ >>> n = numpy.empty((5, 3))
+ >>> vector_norm(v, axis=1, out=n)
+ >>> numpy.allclose(n, numpy.sqrt(numpy.sum(v*v, axis=1)))
+ True
+ >>> vector_norm([])
+ 0.0
+ >>> vector_norm([1])
+ 1.0
+
+ """
+ data = numpy.array(data, dtype=numpy.float64, copy=True)
+ if out is None:
+ if data.ndim == 1:
+ return math.sqrt(numpy.dot(data, data))
+ data *= data
+ out = numpy.atleast_1d(numpy.sum(data, axis=axis))
+ numpy.sqrt(out, out)
+ return out
+ else:
+ data *= data
+ numpy.sum(data, axis=axis, out=out)
+ numpy.sqrt(out, out)
+
+
+def unit_vector(data, axis=None, out=None):
+ """Return ndarray normalized by length, i.e. Euclidean norm, along axis.
+
+ >>> v0 = numpy.random.random(3)
+ >>> v1 = unit_vector(v0)
+ >>> numpy.allclose(v1, v0 / numpy.linalg.norm(v0))
+ True
+ >>> v0 = numpy.random.rand(5, 4, 3)
+ >>> v1 = unit_vector(v0, axis=-1)
+ >>> v2 = v0 / numpy.expand_dims(numpy.sqrt(numpy.sum(v0*v0, axis=2)), 2)
+ >>> numpy.allclose(v1, v2)
+ True
+ >>> v1 = unit_vector(v0, axis=1)
+ >>> v2 = v0 / numpy.expand_dims(numpy.sqrt(numpy.sum(v0*v0, axis=1)), 1)
+ >>> numpy.allclose(v1, v2)
+ True
+ >>> v1 = numpy.empty((5, 4, 3))
+ >>> unit_vector(v0, axis=1, out=v1)
+ >>> numpy.allclose(v1, v2)
+ True
+ >>> list(unit_vector([]))
+ []
+ >>> list(unit_vector([1]))
+ [1.0]
+
+ """
+ if out is None:
+ data = numpy.array(data, dtype=numpy.float64, copy=True)
+ if data.ndim == 1:
+ data /= math.sqrt(numpy.dot(data, data))
+ return data
+ else:
+ if out is not data:
+ out[:] = numpy.array(data, copy=False)
+ data = out
+ length = numpy.atleast_1d(numpy.sum(data*data, axis))
+ numpy.sqrt(length, length)
+ if axis is not None:
+ length = numpy.expand_dims(length, axis)
+ data /= length
+ if out is None:
+ return data
+
+
+def random_vector(size):
+ """Return array of random doubles in the half-open interval [0.0, 1.0).
+
+ >>> v = random_vector(10000)
+ >>> numpy.all(v >= 0) and numpy.all(v < 1)
+ True
+ >>> v0 = random_vector(10)
+ >>> v1 = random_vector(10)
+ >>> numpy.any(v0 == v1)
+ False
+
+ """
+ return numpy.random.random(size)
+
+
+def vector_product(v0, v1, axis=0):
+ """Return vector perpendicular to vectors.
+
+ >>> v = vector_product([2, 0, 0], [0, 3, 0])
+ >>> numpy.allclose(v, [0, 0, 6])
+ True
+ >>> v0 = [[2, 0, 0, 2], [0, 2, 0, 2], [0, 0, 2, 2]]
+ >>> v1 = [[3], [0], [0]]
+ >>> v = vector_product(v0, v1)
+ >>> numpy.allclose(v, [[0, 0, 0, 0], [0, 0, 6, 6], [0, -6, 0, -6]])
+ True
+ >>> v0 = [[2, 0, 0], [2, 0, 0], [0, 2, 0], [2, 0, 0]]
+ >>> v1 = [[0, 3, 0], [0, 0, 3], [0, 0, 3], [3, 3, 3]]
+ >>> v = vector_product(v0, v1, axis=1)
+ >>> numpy.allclose(v, [[0, 0, 6], [0, -6, 0], [6, 0, 0], [0, -6, 6]])
+ True
+
+ """
+ return numpy.cross(v0, v1, axis=axis)
+
+
+def angle_between_vectors(v0, v1, directed=True, axis=0):
+ """Return angle between vectors.
+
+ If directed is False, the input vectors are interpreted as undirected axes,
+ i.e. the maximum angle is pi/2.
+
+ >>> a = angle_between_vectors([1, -2, 3], [-1, 2, -3])
+ >>> numpy.allclose(a, math.pi)
+ True
+ >>> a = angle_between_vectors([1, -2, 3], [-1, 2, -3], directed=False)
+ >>> numpy.allclose(a, 0)
+ True
+ >>> v0 = [[2, 0, 0, 2], [0, 2, 0, 2], [0, 0, 2, 2]]
+ >>> v1 = [[3], [0], [0]]
+ >>> a = angle_between_vectors(v0, v1)
+ >>> numpy.allclose(a, [0, 1.5708, 1.5708, 0.95532])
+ True
+ >>> v0 = [[2, 0, 0], [2, 0, 0], [0, 2, 0], [2, 0, 0]]
+ >>> v1 = [[0, 3, 0], [0, 0, 3], [0, 0, 3], [3, 3, 3]]
+ >>> a = angle_between_vectors(v0, v1, axis=1)
+ >>> numpy.allclose(a, [1.5708, 1.5708, 1.5708, 0.95532])
+ True
+
+ """
+ v0 = numpy.array(v0, dtype=numpy.float64, copy=False)
+ v1 = numpy.array(v1, dtype=numpy.float64, copy=False)
+ dot = numpy.sum(v0 * v1, axis=axis)
+ dot /= vector_norm(v0, axis=axis) * vector_norm(v1, axis=axis)
+ dot = numpy.clip(dot, -1.0, 1.0)
+ return numpy.arccos(dot if directed else numpy.fabs(dot))
+
+
+def inverse_matrix(matrix):
+ """Return inverse of square transformation matrix.
+
+ >>> M0 = random_rotation_matrix()
+ >>> M1 = inverse_matrix(M0.T)
+ >>> numpy.allclose(M1, numpy.linalg.inv(M0.T))
+ True
+ >>> for size in range(1, 7):
+ ... M0 = numpy.random.rand(size, size)
+ ... M1 = inverse_matrix(M0)
+ ... if not numpy.allclose(M1, numpy.linalg.inv(M0)): print(size)
+
+ """
+ return numpy.linalg.inv(matrix)
+
+
+def concatenate_matrices(*matrices):
+ """Return concatenation of series of transformation matrices.
+
+ >>> M = numpy.random.rand(16).reshape((4, 4)) - 0.5
+ >>> numpy.allclose(M, concatenate_matrices(M))
+ True
+ >>> numpy.allclose(numpy.dot(M, M.T), concatenate_matrices(M, M.T))
+ True
+
+ """
+ M = numpy.identity(4)
+ for i in matrices:
+ M = numpy.dot(M, i)
+ return M
+
+
+def is_same_transform(matrix0, matrix1):
+ """Return True if two matrices perform same transformation.
+
+ >>> is_same_transform(numpy.identity(4), numpy.identity(4))
+ True
+ >>> is_same_transform(numpy.identity(4), random_rotation_matrix())
+ False
+
+ """
+ matrix0 = numpy.array(matrix0, dtype=numpy.float64, copy=True)
+ matrix0 /= matrix0[3, 3]
+ matrix1 = numpy.array(matrix1, dtype=numpy.float64, copy=True)
+ matrix1 /= matrix1[3, 3]
+ return numpy.allclose(matrix0, matrix1)
+
+
+def is_same_quaternion(q0, q1):
+ """Return True if two quaternions are equal."""
+ q0 = numpy.array(q0)
+ q1 = numpy.array(q1)
+ return numpy.allclose(q0, q1) or numpy.allclose(q0, -q1)
+
+
+def _import_module(name, package=None, warn=True, postfix='_py', ignore='_'):
+    """Try to import all public attributes from module into global namespace.
+
+    Existing attributes with name clashes are renamed by appending postfix.
+ Attributes starting with underscore are ignored by default.
+
+ Return True on successful import.
+
+ """
+ import warnings
+ from importlib import import_module
+ try:
+ if not package:
+ module = import_module(name)
+ else:
+ module = import_module('.' + name, package=package)
+ except ImportError as err:
+ if warn:
+ warnings.warn(str(err))
+ else:
+ for attr in dir(module):
+ if ignore and attr.startswith(ignore):
+ continue
+ if postfix:
+ if attr in globals():
+ globals()[attr + postfix] = globals()[attr]
+ elif warn:
+ warnings.warn('no Python implementation of ' + attr)
+ globals()[attr] = getattr(module, attr)
+ return True
+
+
+_import_module('_transformations', __package__, warn=False)
+
+
+if __name__ == '__main__':
+ import doctest
+ import random # noqa: used in doctests
+ try:
+ numpy.set_printoptions(suppress=True, precision=5, legacy='1.13')
+ except TypeError:
+ numpy.set_printoptions(suppress=True, precision=5)
+ doctest.testmod()
diff --git a/imcui/third_party/COTR/COTR/utils/constants.py b/imcui/third_party/COTR/COTR/utils/constants.py
new file mode 100644
index 0000000000000000000000000000000000000000..0ae0035c8a180a9035a272e0932be5108828771f
--- /dev/null
+++ b/imcui/third_party/COTR/COTR/utils/constants.py
@@ -0,0 +1,3 @@
+DEFAULT_PRECISION = 'float32'
+MAX_SIZE = 256
+VALID_NN_OVERLAPPING_THRESH = 0.1
diff --git a/imcui/third_party/COTR/COTR/utils/debug_utils.py b/imcui/third_party/COTR/COTR/utils/debug_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..74db4f44148fd784604053084e277f8841d0d0a4
--- /dev/null
+++ b/imcui/third_party/COTR/COTR/utils/debug_utils.py
@@ -0,0 +1,15 @@
+def embed_breakpoint(debug_info='', terminate=True):
+    print('\nyou are inside a breakpoint')
+ if debug_info:
+ print('debug info: {0}'.format(debug_info))
+ print('')
+ embedding = ('import IPython\n'
+ 'import matplotlib.pyplot as plt\n'
+ 'IPython.embed()\n'
+ )
+ if terminate:
+ embedding += (
+ 'assert 0, \'force termination\'\n'
+ )
+
+ return embedding
diff --git a/imcui/third_party/COTR/COTR/utils/utils.py b/imcui/third_party/COTR/COTR/utils/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..1d3422188ad93dea94926259ec1088552b11ef31
--- /dev/null
+++ b/imcui/third_party/COTR/COTR/utils/utils.py
@@ -0,0 +1,271 @@
+import random
+import smtplib
+import ssl
+from collections import namedtuple
+
+from COTR.utils import debug_utils
+
+import numpy as np
+import torch
+import cv2
+import matplotlib.pyplot as plt
+import PIL
+
+
+'''
+ImagePatch: patch: patch content, np array or None
+ x: left bound in original resolution
+ y: upper bound in original resolution
+ w: width of patch
+ h: height of patch
+ ow: width of original resolution
+ oh: height of original resolution
+'''
+ImagePatch = namedtuple('ImagePatch', ['patch', 'x', 'y', 'w', 'h', 'ow', 'oh'])
+Point3D = namedtuple("Point3D", ["id", "arr_idx", "image_ids"])
+Point2D = namedtuple("Point2D", ["id_3d", "xy"])
+
+
+class CropCamConfig():
+ def __init__(self, x, y, w, h, out_w, out_h, orig_w, orig_h):
+ '''
+ xy: left upper corner
+ '''
+ # assert x > 0 and x < orig_w
+ # assert y > 0 and y < orig_h
+ # assert w < orig_w and h < orig_h
+ # assert x - w / 2 > 0 and x + w / 2 < orig_w
+ # assert y - h / 2 > 0 and y + h / 2 < orig_h
+ # assert h / w == out_h / out_w
+ self.x = x
+ self.y = y
+ self.w = w
+ self.h = h
+ self.out_w = out_w
+ self.out_h = out_h
+ self.orig_w = orig_w
+ self.orig_h = orig_h
+
+ def __str__(self):
+ out = f'original image size(h,w): [{self.orig_h}, {self.orig_w}]\n'
+ out += f'crop at(x,y): [{self.x}, {self.y}]\n'
+ out += f'crop size(h,w): [{self.h}, {self.w}]\n'
+ out += f'resize crop to(h,w): [{self.out_h}, {self.out_w}]'
+ return out
+
+
+def fix_randomness(seed=42):
+ random.seed(seed)
+ torch.backends.cudnn.deterministic = True
+ torch.backends.cudnn.benchmark = False
+ torch.manual_seed(seed)
+ np.random.seed(seed)
+
+
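+# Intended as a torch DataLoader worker_init_fn: give every worker a distinct
+# numpy seed derived from the base seed so augmentations are not duplicated.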
+def worker_init_fn(worker_id):
+ np.random.seed(np.random.get_state()[1][0] + worker_id)
+
+
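+# Resize an image channel by channel with PIL (PIL cannot resize arbitrary
+# multi-channel float arrays directly) and restore the original channel layout.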
+def float_image_resize(img, shape, interp=PIL.Image.BILINEAR):
+ missing_channel = False
+ if len(img.shape) == 2:
+ missing_channel = True
+ img = img[..., None]
+ layers = []
+ img = img.transpose(2, 0, 1)
+    for channel in img:
+        channel = np.array(
+            PIL.Image.fromarray(channel).resize(shape[::-1], resample=interp))
+        assert channel.shape[:2] == shape
+        layers.append(channel)
+ if missing_channel:
+ return np.stack(layers, axis=-1)[..., 0]
+ else:
+ return np.stack(layers, axis=-1)
+
+
+def is_nan(x):
+ """
+ get mask of nan values.
+ :param x: torch or numpy var.
+    :return: an N-D array of bool. True -> nan, False -> ok.
+ """
+ return x != x
+
+
+def has_nan(x) -> bool:
+ """
+ check whether x contains nan.
+ :param x: torch or numpy var.
+    :return: single bool, True -> x contains nan, False -> ok.
+ """
+ if x is None:
+ return False
+ return is_nan(x).any()
+
+
+def confirm(question='OK to continue?'):
+ """
+ Ask user to enter Y or N (case-insensitive).
+ :return: True if the answer is Y.
+ :rtype: bool
+ """
+ answer = ""
+ while answer not in ["y", "n"]:
+ answer = input(question + ' [y/n] ').lower()
+ return answer == "y"
+
+
+def print_notification(content_list, notification_type='NOTIFICATION'):
+ print('---------------------- {0} ----------------------'.format(notification_type))
+ print()
+ for content in content_list:
+ print(content)
+ print()
+ print('----------------------------------------------------')
+
+
+def torch_img_to_np_img(torch_img):
+    '''convert a torch image to a matplotlib-compatible numpy image
+    torch uses Channels x Height x Width
+    numpy uses Height x Width x Channels
+    Arguments:
+        torch_img {torch.Tensor} -- image tensor of shape (N, C, H, W), (C, H, W) or (H, W)
+ '''
+ assert isinstance(torch_img, torch.Tensor), 'cannot process data type: {0}'.format(type(torch_img))
+ if len(torch_img.shape) == 4 and (torch_img.shape[1] == 3 or torch_img.shape[1] == 1):
+ return np.transpose(torch_img.detach().cpu().numpy(), (0, 2, 3, 1))
+ if len(torch_img.shape) == 3 and (torch_img.shape[0] == 3 or torch_img.shape[0] == 1):
+ return np.transpose(torch_img.detach().cpu().numpy(), (1, 2, 0))
+ elif len(torch_img.shape) == 2:
+ return torch_img.detach().cpu().numpy()
+ else:
+ raise ValueError('cannot process this image')
+
+
+def np_img_to_torch_img(np_img):
+    """convert a numpy image to a torch image
+    numpy uses Height x Width x Channels
+    torch uses Channels x Height x Width
+
+    Arguments:
+        np_img {np.ndarray} -- image array of shape (N, H, W, C), (H, W, C) or (H, W)
+ """
+ assert isinstance(np_img, np.ndarray), 'cannot process data type: {0}'.format(type(np_img))
+ if len(np_img.shape) == 4 and (np_img.shape[3] == 3 or np_img.shape[3] == 1):
+ return torch.from_numpy(np.transpose(np_img, (0, 3, 1, 2)))
+ if len(np_img.shape) == 3 and (np_img.shape[2] == 3 or np_img.shape[2] == 1):
+ return torch.from_numpy(np.transpose(np_img, (2, 0, 1)))
+ elif len(np_img.shape) == 2:
+ return torch.from_numpy(np_img)
+ else:
+ raise ValueError('cannot process this image with shape: {0}'.format(np_img.shape))
+
+
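+# Try progressively more permissive loading strategies: the exact state dict,
+# then with the DataParallel 'module.' prefix stripped or added, and finally a
+# partial load of only the tensors whose names and shapes match the model,
+# warning about the keys that remain unmatched.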
+def safe_load_weights(model, saved_weights):
+ try:
+ model.load_state_dict(saved_weights)
+ except RuntimeError:
+ try:
+ weights = saved_weights
+ weights = {k.replace('module.', ''): v for k, v in weights.items()}
+ model.load_state_dict(weights)
+ except RuntimeError:
+ try:
+ weights = saved_weights
+ weights = {'module.' + k: v for k, v in weights.items()}
+ model.load_state_dict(weights)
+ except RuntimeError:
+ try:
+ pretrained_dict = saved_weights
+ model_dict = model.state_dict()
+ pretrained_dict = {k: v for k, v in pretrained_dict.items() if ((k in model_dict) and (model_dict[k].shape == pretrained_dict[k].shape))}
+ assert len(pretrained_dict) != 0
+ model_dict.update(pretrained_dict)
+ model.load_state_dict(model_dict)
+ non_match_keys = set(model.state_dict().keys()) - set(pretrained_dict.keys())
+ notification = []
+ notification += ['pretrained weights PARTIALLY loaded, following are missing:']
+ notification += [str(non_match_keys)]
+ print_notification(notification, 'WARNING')
+ except Exception as e:
+ print(f'pretrained weights loading failed {e}')
+ exit()
+ print('weights safely loaded')
+
+
+def visualize_corrs(img1, img2, corrs, mask=None):
+ if mask is None:
+ mask = np.ones(len(corrs)).astype(bool)
+
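+    # bring both images to a common width (capped at max_w below) so they can
+    # be stacked vertically, rescaling the keypoints accordingly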
+ scale1 = 1.0
+ scale2 = 1.0
+ if img1.shape[1] > img2.shape[1]:
+ scale2 = img1.shape[1] / img2.shape[1]
+ w = img1.shape[1]
+ else:
+ scale1 = img2.shape[1] / img1.shape[1]
+ w = img2.shape[1]
+ # Resize if too big
+ max_w = 400
+ if w > max_w:
+ scale1 *= max_w / w
+ scale2 *= max_w / w
+ img1 = cv2.resize(img1, (0, 0), fx=scale1, fy=scale1)
+ img2 = cv2.resize(img2, (0, 0), fx=scale2, fy=scale2)
+
+ x1, x2 = corrs[:, :2], corrs[:, 2:]
+ h1, w1 = img1.shape[:2]
+ h2, w2 = img2.shape[:2]
+ img = np.zeros((h1 + h2, max(w1, w2), 3), dtype=img1.dtype)
+ img[:h1, :w1] = img1
+ img[h1:, :w2] = img2
+    # scale keypoints from original to resized image coordinates
+ x1 = x1 * scale1
+ x2 = x2 * scale2
+    # shift the second image's keypoints down by h1 (images are stacked vertically)
+ x2p = x2 + np.array([[0, h1]])
+ fig = plt.figure(frameon=False)
+ fig = plt.imshow(img)
+
+ cols = [
+ [0.0, 0.67, 0.0],
+ [0.9, 0.1, 0.1],
+ ]
+ lw = .5
+ alpha = 1
+
+ # Draw outliers
+ _x1 = x1[~mask]
+ _x2p = x2p[~mask]
+ xs = np.stack([_x1[:, 0], _x2p[:, 0]], axis=1).T
+ ys = np.stack([_x1[:, 1], _x2p[:, 1]], axis=1).T
+ plt.plot(
+ xs, ys,
+ alpha=alpha,
+ linestyle="-",
+ linewidth=lw,
+ aa=False,
+ color=cols[1],
+ )
+
+
+ # Draw Inliers
+ _x1 = x1[mask]
+ _x2p = x2p[mask]
+ xs = np.stack([_x1[:, 0], _x2p[:, 0]], axis=1).T
+ ys = np.stack([_x1[:, 1], _x2p[:, 1]], axis=1).T
+ plt.plot(
+ xs, ys,
+ alpha=alpha,
+ linestyle="-",
+ linewidth=lw,
+ aa=False,
+ color=cols[0],
+ )
+ plt.scatter(xs, ys)
+
+ fig.axes.get_xaxis().set_visible(False)
+ fig.axes.get_yaxis().set_visible(False)
+ ax = plt.gca()
+ ax.set_axis_off()
+ plt.show()
diff --git a/imcui/third_party/COTR/demo_face.py b/imcui/third_party/COTR/demo_face.py
new file mode 100644
index 0000000000000000000000000000000000000000..4975356d3246c484703d2d6cf60fe9e1a2e67be8
--- /dev/null
+++ b/imcui/third_party/COTR/demo_face.py
@@ -0,0 +1,69 @@
+'''
+COTR demo for human face
+We use an off-the-shelf face landmark detector: https://github.com/1adrianb/face-alignment
+'''
+import argparse
+import os
+import time
+
+import cv2
+import numpy as np
+import torch
+import imageio
+import matplotlib.pyplot as plt
+
+from COTR.utils import utils, debug_utils
+from COTR.models import build_model
+from COTR.options.options import *
+from COTR.options.options_utils import *
+from COTR.inference.inference_helper import triangulate_corr
+from COTR.inference.sparse_engine import SparseEngine
+
+utils.fix_randomness(0)
+torch.set_grad_enabled(False)
+
+
+def main(opt):
+ model = build_model(opt)
+ model = model.cuda()
+ weights = torch.load(opt.load_weights_path, map_location='cpu')['model_state_dict']
+ utils.safe_load_weights(model, weights)
+ model = model.eval()
+
+ img_a = imageio.imread('./sample_data/imgs/face_1.png', pilmode='RGB')
+ img_b = imageio.imread('./sample_data/imgs/face_2.png', pilmode='RGB')
+ queries = np.load('./sample_data/face_landmarks.npy')[0]
+
+ engine = SparseEngine(model, 32, mode='stretching')
+ corrs = engine.cotr_corr_multiscale(img_a, img_b, np.linspace(0.5, 0.0625, 4), 1, queries_a=queries, force=False)
+
+ f, axarr = plt.subplots(1, 2)
+ axarr[0].imshow(img_a)
+ axarr[0].scatter(*queries.T, s=1)
+ axarr[0].title.set_text('Reference Face')
+ axarr[0].axis('off')
+ axarr[1].imshow(img_b)
+ axarr[1].scatter(*corrs[:, 2:].T, s=1)
+ axarr[1].title.set_text('Target Face')
+ axarr[1].axis('off')
+ plt.show()
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ set_COTR_arguments(parser)
+ parser.add_argument('--out_dir', type=str, default=general_config['out'], help='out directory')
+ parser.add_argument('--load_weights', type=str, default=None, help='load a pretrained set of weights, you need to provide the model id')
+
+ opt = parser.parse_args()
+ opt.command = ' '.join(sys.argv)
+
+ layer_2_channels = {'layer1': 256,
+ 'layer2': 512,
+ 'layer3': 1024,
+ 'layer4': 2048, }
+ opt.dim_feedforward = layer_2_channels[opt.layer]
+ if opt.load_weights:
+ opt.load_weights_path = os.path.join(opt.out_dir, opt.load_weights, 'checkpoint.pth.tar')
+ print_opt(opt)
+ main(opt)
diff --git a/imcui/third_party/COTR/demo_guided_matching.py b/imcui/third_party/COTR/demo_guided_matching.py
new file mode 100644
index 0000000000000000000000000000000000000000..1e68d922e0785144ef3087cee7c50453f484076c
--- /dev/null
+++ b/imcui/third_party/COTR/demo_guided_matching.py
@@ -0,0 +1,85 @@
+'''
+Feature-free COTR guided matching for keypoints.
+We use DISK (https://github.com/cvlab-epfl/disk) keypoint locations.
+We apply RANSAC with a fundamental matrix model to further prune outliers.
+Note: This script doesn't use descriptors.
+'''
+import argparse
+import os
+import time
+
+import cv2
+import numpy as np
+import torch
+import imageio
+from scipy.spatial import distance_matrix
+
+from COTR.utils import utils, debug_utils
+from COTR.models import build_model
+from COTR.options.options import *
+from COTR.options.options_utils import *
+from COTR.inference.sparse_engine import SparseEngine, FasterSparseEngine
+
+utils.fix_randomness(0)
+torch.set_grad_enabled(False)
+
+
+def main(opt):
+ model = build_model(opt)
+ model = model.cuda()
+ weights = torch.load(opt.load_weights_path)['model_state_dict']
+ utils.safe_load_weights(model, weights)
+ model = model.eval()
+
+ img_a = imageio.imread('./sample_data/imgs/21526113_4379776807.jpg')
+ img_b = imageio.imread('./sample_data/imgs/21126421_4537535153.jpg')
+ kp_a = np.load('./sample_data/21526113_4379776807.jpg.disk.kpts.npy')
+ kp_b = np.load('./sample_data/21126421_4537535153.jpg.disk.kpts.npy')
+
+ if opt.faster_infer:
+ engine = FasterSparseEngine(model, 32, mode='tile')
+ else:
+ engine = SparseEngine(model, 32, mode='tile')
+ t0 = time.time()
+ corrs_a_b = engine.cotr_corr_multiscale(img_a, img_b, np.linspace(0.5, 0.0625, 4), 1, max_corrs=kp_a.shape[0], queries_a=kp_a, force=True)
+ corrs_b_a = engine.cotr_corr_multiscale(img_b, img_a, np.linspace(0.5, 0.0625, 4), 1, max_corrs=kp_b.shape[0], queries_a=kp_b, force=True)
+ t1 = time.time()
+ print(f'COTR spent {t1-t0} seconds.')
+ inds_a_b = np.argmin(distance_matrix(corrs_a_b[:, 2:], kp_b), axis=1)
+ matched_a_b = np.stack([np.arange(kp_a.shape[0]), inds_a_b]).T
+ inds_b_a = np.argmin(distance_matrix(corrs_b_a[:, 2:], kp_a), axis=1)
+ matched_b_a = np.stack([np.arange(kp_b.shape[0]), inds_b_a]).T
+
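+    # cross-consistency check: keep a pair only if the a->b and b->a
+    # nearest-neighbour assignments agree (mutual nearest neighbours)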
+ good = 0
+ final_matches = []
+ for m_ab in matched_a_b:
+ for m_ba in matched_b_a:
+ if (m_ab == m_ba[::-1]).all():
+ good += 1
+ final_matches.append(m_ab)
+ break
+ final_matches = np.array(final_matches)
+ final_corrs = np.concatenate([kp_a[final_matches[:, 0]], kp_b[final_matches[:, 1]]], axis=1)
+ _, mask = cv2.findFundamentalMat(final_corrs[:, :2], final_corrs[:, 2:], cv2.FM_RANSAC, ransacReprojThreshold=5, confidence=0.999999)
+ utils.visualize_corrs(img_a, img_b, final_corrs[np.where(mask[:, 0])])
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ set_COTR_arguments(parser)
+ parser.add_argument('--out_dir', type=str, default=general_config['out'], help='out directory')
+ parser.add_argument('--load_weights', type=str, default=None, help='load a pretrained set of weights, you need to provide the model id')
+    parser.add_argument('--faster_infer', type=str2bool, default=False, help='use faster inference')
+
+ opt = parser.parse_args()
+ opt.command = ' '.join(sys.argv)
+
+ layer_2_channels = {'layer1': 256,
+ 'layer2': 512,
+ 'layer3': 1024,
+ 'layer4': 2048, }
+ opt.dim_feedforward = layer_2_channels[opt.layer]
+ if opt.load_weights:
+ opt.load_weights_path = os.path.join(opt.out_dir, opt.load_weights, 'checkpoint.pth.tar')
+ print_opt(opt)
+ main(opt)
diff --git a/imcui/third_party/COTR/demo_homography.py b/imcui/third_party/COTR/demo_homography.py
new file mode 100644
index 0000000000000000000000000000000000000000..cf5962ec68494a4a8817ac46ece468c952a71cd7
--- /dev/null
+++ b/imcui/third_party/COTR/demo_homography.py
@@ -0,0 +1,84 @@
+'''
+COTR demo for homography estimation
+'''
+import argparse
+import os
+import time
+
+import cv2
+import numpy as np
+import torch
+import imageio
+import matplotlib.pyplot as plt
+
+from COTR.utils import utils, debug_utils
+from COTR.models import build_model
+from COTR.options.options import *
+from COTR.options.options_utils import *
+from COTR.inference.inference_helper import triangulate_corr
+from COTR.inference.sparse_engine import SparseEngine
+
+utils.fix_randomness(0)
+torch.set_grad_enabled(False)
+
+
+def main(opt):
+ model = build_model(opt)
+ model = model.cuda()
+ weights = torch.load(opt.load_weights_path, map_location='cpu')['model_state_dict']
+ utils.safe_load_weights(model, weights)
+ model = model.eval()
+
+ img_a = imageio.imread('./sample_data/imgs/paint_1.JPG', pilmode='RGB')
+ img_b = imageio.imread('./sample_data/imgs/paint_2.jpg', pilmode='RGB')
+ rep_img = imageio.imread('./sample_data/imgs/Meisje_met_de_parel.jpg', pilmode='RGB')
+ rep_mask = np.ones(rep_img.shape[:2])
+ lu_corner = [932, 1025]
+ ru_corner = [2469, 901]
+ lb_corner = [908, 2927]
+ rb_corner = [2436, 3080]
+ queries = np.array([lu_corner, ru_corner, lb_corner, rb_corner]).astype(np.float32)
+ rep_coord = np.array([[0, 0], [rep_img.shape[1], 0], [0, rep_img.shape[0]], [rep_img.shape[1], rep_img.shape[0]]]).astype(np.float32)
+
+ engine = SparseEngine(model, 32, mode='stretching')
+ corrs = engine.cotr_corr_multiscale(img_a, img_b, np.linspace(0.5, 0.0625, 4), 1, queries_a=queries, force=True)
+
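+    # Estimate the homography that maps the replacement painting's corners onto the
+    # matched corner locations in the target frame, then composite the warped painting
+    # over the target image.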
+ T = cv2.getPerspectiveTransform(rep_coord, corrs[:, 2:].astype(np.float32))
+ vmask = cv2.warpPerspective(rep_mask, T, (img_b.shape[1], img_b.shape[0])) > 0
+ warped = cv2.warpPerspective(rep_img, T, (img_b.shape[1], img_b.shape[0]))
+ out = warped * vmask[..., None] + img_b * (~vmask[..., None])
+
+ f, axarr = plt.subplots(1, 4)
+ axarr[0].imshow(rep_img)
+ axarr[0].title.set_text('Virtual Paint')
+ axarr[0].axis('off')
+ axarr[1].imshow(img_a)
+ axarr[1].title.set_text('Annotated Frame')
+ axarr[1].axis('off')
+ axarr[2].imshow(img_b)
+ axarr[2].title.set_text('Target Frame')
+ axarr[2].axis('off')
+ axarr[3].imshow(out)
+ axarr[3].title.set_text('Overlay')
+ axarr[3].axis('off')
+ plt.show()
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ set_COTR_arguments(parser)
+ parser.add_argument('--out_dir', type=str, default=general_config['out'], help='out directory')
+    parser.add_argument('--load_weights', type=str, default=None, help='load a pretrained set of weights; you need to provide the model id')
+
+ opt = parser.parse_args()
+ opt.command = ' '.join(sys.argv)
+
+ layer_2_channels = {'layer1': 256,
+ 'layer2': 512,
+ 'layer3': 1024,
+ 'layer4': 2048, }
+ opt.dim_feedforward = layer_2_channels[opt.layer]
+ if opt.load_weights:
+ opt.load_weights_path = os.path.join(opt.out_dir, opt.load_weights, 'checkpoint.pth.tar')
+ print_opt(opt)
+ main(opt)
diff --git a/imcui/third_party/COTR/demo_reconstruction.py b/imcui/third_party/COTR/demo_reconstruction.py
new file mode 100644
index 0000000000000000000000000000000000000000..06262668446595b3dc368e416ad5373dc4b26873
--- /dev/null
+++ b/imcui/third_party/COTR/demo_reconstruction.py
@@ -0,0 +1,92 @@
+'''
+COTR two view reconstruction with known extrinsic/intrinsic demo
+'''
+import argparse
+import os
+import time
+
+import numpy as np
+import torch
+import imageio
+import open3d as o3d
+
+from COTR.utils import utils, debug_utils
+from COTR.models import build_model
+from COTR.options.options import *
+from COTR.options.options_utils import *
+from COTR.inference.sparse_engine import SparseEngine, FasterSparseEngine
+from COTR.projector import pcd_projector
+
+utils.fix_randomness(0)
+torch.set_grad_enabled(False)
+
+
+def triangulate_rays_to_pcd(center_a, dir_a, center_b, dir_b):
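+    # For each correspondence, return the point on ray A (origin center_a, direction
+    # dir_a) that is closest to ray B (origin center_b, direction dir_b), using the
+    # standard closed-form solution with c = B - A.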
+ A = center_a
+ a = dir_a / np.linalg.norm(dir_a, axis=1, keepdims=True)
+ B = center_b
+ b = dir_b / np.linalg.norm(dir_b, axis=1, keepdims=True)
+ c = B - A
+ D = A + a * ((-np.sum(a * b, axis=1) * np.sum(b * c, axis=1) + np.sum(a * c, axis=1) * np.sum(b * b, axis=1)) / (np.sum(a * a, axis=1) * np.sum(b * b, axis=1) - np.sum(a * b, axis=1) * np.sum(a * b, axis=1)))[..., None]
+ return D
+
+
+def main(opt):
+ model = build_model(opt)
+ model = model.cuda()
+ weights = torch.load(opt.load_weights_path, map_location='cpu')['model_state_dict']
+ utils.safe_load_weights(model, weights)
+ model = model.eval()
+
+ img_a = imageio.imread('./sample_data/imgs/img_0.jpg', pilmode='RGB')
+ img_b = imageio.imread('./sample_data/imgs/img_1.jpg', pilmode='RGB')
+
+ if opt.faster_infer:
+ engine = FasterSparseEngine(model, 32, mode='tile')
+ else:
+ engine = SparseEngine(model, 32, mode='tile')
+ t0 = time.time()
+ corrs = engine.cotr_corr_multiscale_with_cycle_consistency(img_a, img_b, np.linspace(0.5, 0.0625, 4), 1, max_corrs=opt.max_corrs, queries_a=None)
+ t1 = time.time()
+ print(f'spent {t1-t0} seconds for {opt.max_corrs} correspondences.')
+
+ camera_a = np.load('./sample_data/camera_0.npy', allow_pickle=True).item()
+ camera_b = np.load('./sample_data/camera_1.npy', allow_pickle=True).item()
+ center_a = camera_a['cam_center']
+ center_b = camera_b['cam_center']
+ rays_a = pcd_projector.PointCloudProjector.pcd_2d_to_pcd_3d_np(corrs[:, :2], np.ones([corrs.shape[0], 1]) * 2, camera_a['intrinsic'], motion=camera_a['c2w'])
+ rays_b = pcd_projector.PointCloudProjector.pcd_2d_to_pcd_3d_np(corrs[:, 2:], np.ones([corrs.shape[0], 1]) * 2, camera_b['intrinsic'], motion=camera_b['c2w'])
+ dir_a = rays_a - center_a
+ dir_b = rays_b - center_b
+ center_a = np.array([center_a] * corrs.shape[0])
+ center_b = np.array([center_b] * corrs.shape[0])
+ points = triangulate_rays_to_pcd(center_a, dir_a, center_b, dir_b)
+ colors = (img_a[tuple(np.floor(corrs[:, :2]).astype(int)[:, ::-1].T)] / 255 + img_b[tuple(np.floor(corrs[:, 2:]).astype(int)[:, ::-1].T)] / 255) / 2
+ colors = np.array(colors)
+
+ pcd = o3d.geometry.PointCloud()
+ pcd.points = o3d.utility.Vector3dVector(points)
+ pcd.colors = o3d.utility.Vector3dVector(colors)
+ o3d.visualization.draw_geometries([pcd])
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ set_COTR_arguments(parser)
+ parser.add_argument('--out_dir', type=str, default=general_config['out'], help='out directory')
+    parser.add_argument('--load_weights', type=str, default=None, help='load a pretrained set of weights; you need to provide the model id')
+    parser.add_argument('--max_corrs', type=int, default=2048, help='number of correspondences')
+    parser.add_argument('--faster_infer', type=str2bool, default=False, help='use faster inference')
+
+ opt = parser.parse_args()
+ opt.command = ' '.join(sys.argv)
+
+ layer_2_channels = {'layer1': 256,
+ 'layer2': 512,
+ 'layer3': 1024,
+ 'layer4': 2048, }
+ opt.dim_feedforward = layer_2_channels[opt.layer]
+ if opt.load_weights:
+ opt.load_weights_path = os.path.join(opt.out_dir, opt.load_weights, 'checkpoint.pth.tar')
+ print_opt(opt)
+ main(opt)
diff --git a/imcui/third_party/COTR/demo_single_pair.py b/imcui/third_party/COTR/demo_single_pair.py
new file mode 100644
index 0000000000000000000000000000000000000000..babc99589542f342e3225f2c345e3aa05535a2f1
--- /dev/null
+++ b/imcui/third_party/COTR/demo_single_pair.py
@@ -0,0 +1,66 @@
+'''
+COTR demo for a single image pair
+'''
+import argparse
+import os
+import time
+
+import cv2
+import numpy as np
+import torch
+import imageio
+import matplotlib.pyplot as plt
+
+from COTR.utils import utils, debug_utils
+from COTR.models import build_model
+from COTR.options.options import *
+from COTR.options.options_utils import *
+from COTR.inference.inference_helper import triangulate_corr
+from COTR.inference.sparse_engine import SparseEngine
+
+utils.fix_randomness(0)
+torch.set_grad_enabled(False)
+
+
+def main(opt):
+ model = build_model(opt)
+ model = model.cuda()
+ weights = torch.load(opt.load_weights_path, map_location='cpu')['model_state_dict']
+ utils.safe_load_weights(model, weights)
+ model = model.eval()
+
+ img_a = imageio.imread('./sample_data/imgs/cathedral_1.jpg', pilmode='RGB')
+ img_b = imageio.imread('./sample_data/imgs/cathedral_2.jpg', pilmode='RGB')
+
+ engine = SparseEngine(model, 32, mode='tile')
+ t0 = time.time()
+ corrs = engine.cotr_corr_multiscale_with_cycle_consistency(img_a, img_b, np.linspace(0.5, 0.0625, 4), 1, max_corrs=opt.max_corrs, queries_a=None)
+ t1 = time.time()
+
+ utils.visualize_corrs(img_a, img_b, corrs)
+ print(f'spent {t1-t0} seconds for {opt.max_corrs} correspondences.')
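+    # Densify the sparse correspondences into a dense flow field via triangulation,
+    # then warp image B into image A's frame for a visual sanity check.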
+ dense = triangulate_corr(corrs, img_a.shape, img_b.shape)
+ warped = cv2.remap(img_b, dense[..., 0].astype(np.float32), dense[..., 1].astype(np.float32), interpolation=cv2.INTER_LINEAR, borderMode=cv2.BORDER_CONSTANT)
+ plt.imshow(warped / 255 * 0.5 + img_a / 255 * 0.5)
+ plt.show()
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ set_COTR_arguments(parser)
+ parser.add_argument('--out_dir', type=str, default=general_config['out'], help='out directory')
+    parser.add_argument('--load_weights', type=str, default=None, help='load a pretrained set of weights; you need to provide the model id')
+ parser.add_argument('--max_corrs', type=int, default=100, help='number of correspondences')
+
+ opt = parser.parse_args()
+ opt.command = ' '.join(sys.argv)
+
+ layer_2_channels = {'layer1': 256,
+ 'layer2': 512,
+ 'layer3': 1024,
+ 'layer4': 2048, }
+ opt.dim_feedforward = layer_2_channels[opt.layer]
+ if opt.load_weights:
+ opt.load_weights_path = os.path.join(opt.out_dir, opt.load_weights, 'checkpoint.pth.tar')
+ print_opt(opt)
+ main(opt)
diff --git a/imcui/third_party/COTR/demo_wbs.py b/imcui/third_party/COTR/demo_wbs.py
new file mode 100644
index 0000000000000000000000000000000000000000..1592ca506cdc0fa06c0c5290b741a3f60f6023ae
--- /dev/null
+++ b/imcui/third_party/COTR/demo_wbs.py
@@ -0,0 +1,71 @@
+'''
+Manually pass the scale to COTR, skipping the scale-difference estimation.
+'''
+import argparse
+import os
+import time
+
+import cv2
+import numpy as np
+import torch
+import imageio
+from scipy.spatial import distance_matrix
+import matplotlib.pyplot as plt
+
+from COTR.utils import utils, debug_utils
+from COTR.models import build_model
+from COTR.options.options import *
+from COTR.options.options_utils import *
+from COTR.inference.sparse_engine import SparseEngine
+
+utils.fix_randomness(0)
+torch.set_grad_enabled(False)
+
+
+def main(opt):
+ model = build_model(opt)
+ model = model.cuda()
+ weights = torch.load(opt.load_weights_path)['model_state_dict']
+ utils.safe_load_weights(model, weights)
+ model = model.eval()
+
+ img_a = imageio.imread('./sample_data/imgs/petrzin_01.png')
+ img_b = imageio.imread('./sample_data/imgs/petrzin_02.png')
+ img_a_area = 1.0
+ img_b_area = 1.0
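+    # Passing explicit image areas skips COTR's scale-difference estimation (see the
+    # docstring above); equal areas assume no scale difference between the two images.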
+ gt_corrs = np.loadtxt('./sample_data/petrzin_pts.txt')
+ kp_a = gt_corrs[:, :2]
+ kp_b = gt_corrs[:, 2:]
+
+ engine = SparseEngine(model, 32, mode='tile')
+ t0 = time.time()
+ corrs = engine.cotr_corr_multiscale(img_a, img_b, np.linspace(0.75, 0.1, 4), 1, max_corrs=kp_a.shape[0], queries_a=kp_a, force=True, areas=[img_a_area, img_b_area])
+ t1 = time.time()
+ print(f'COTR spent {t1-t0} seconds.')
+
+ utils.visualize_corrs(img_a, img_b, corrs)
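+    # Overlay on image B: ground-truth keypoints, COTR predictions, and red segments
+    # connecting each ground-truth location to its predicted correspondence.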
+ plt.imshow(img_b)
+ plt.scatter(kp_b[:,0], kp_b[:,1])
+ plt.scatter(corrs[:,2], corrs[:,3])
+ plt.plot(np.stack([kp_b[:,0], corrs[:,2]], axis=1).T, np.stack([kp_b[:,1], corrs[:,3]], axis=1).T, color=[1,0,0])
+ plt.show()
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ set_COTR_arguments(parser)
+ parser.add_argument('--out_dir', type=str, default=general_config['out'], help='out directory')
+    parser.add_argument('--load_weights', type=str, default=None, help='load a pretrained set of weights; you need to provide the model id')
+
+ opt = parser.parse_args()
+ opt.command = ' '.join(sys.argv)
+
+ layer_2_channels = {'layer1': 256,
+ 'layer2': 512,
+ 'layer3': 1024,
+ 'layer4': 2048, }
+ opt.dim_feedforward = layer_2_channels[opt.layer]
+ if opt.load_weights:
+ opt.load_weights_path = os.path.join(opt.out_dir, opt.load_weights, 'checkpoint.pth.tar')
+ print_opt(opt)
+ main(opt)
diff --git a/imcui/third_party/COTR/environment.yml b/imcui/third_party/COTR/environment.yml
new file mode 100644
index 0000000000000000000000000000000000000000..bc9ffc910bf9809b616504348d9b369043dee5cb
--- /dev/null
+++ b/imcui/third_party/COTR/environment.yml
@@ -0,0 +1,104 @@
+name: cotr_env
+channels:
+ - pytorch
+ - defaults
+dependencies:
+ - _libgcc_mutex=0.1=main
+ - backcall=0.2.0=pyhd3eb1b0_0
+ - blas=1.0=mkl
+ - bzip2=1.0.8=h7b6447c_0
+ - ca-certificates=2021.4.13=h06a4308_1
+ - cairo=1.16.0=hf32fb01_1
+ - certifi=2020.12.5=py37h06a4308_0
+ - cudatoolkit=10.2.89=hfd86e86_1
+ - cycler=0.10.0=py37_0
+ - dbus=1.13.18=hb2f20db_0
+ - decorator=5.0.6=pyhd3eb1b0_0
+ - expat=2.3.0=h2531618_2
+ - ffmpeg=4.0=hcdf2ecd_0
+ - fontconfig=2.13.1=h6c09931_0
+ - freeglut=3.0.0=hf484d3e_5
+ - freetype=2.10.4=h5ab3b9f_0
+ - glib=2.68.1=h36276a3_0
+ - graphite2=1.3.14=h23475e2_0
+ - gst-plugins-base=1.14.0=h8213a91_2
+ - gstreamer=1.14.0=h28cd5cc_2
+ - harfbuzz=1.8.8=hffaf4a1_0
+ - hdf5=1.10.2=hba1933b_1
+ - icu=58.2=he6710b0_3
+ - imageio=2.9.0=pyhd3eb1b0_0
+ - intel-openmp=2021.2.0=h06a4308_610
+ - ipython=7.22.0=py37hb070fc8_0
+ - ipython_genutils=0.2.0=pyhd3eb1b0_1
+ - jasper=2.0.14=h07fcdf6_1
+ - jedi=0.17.0=py37_0
+ - jpeg=9b=h024ee3a_2
+ - kiwisolver=1.3.1=py37h2531618_0
+ - lcms2=2.12=h3be6417_0
+ - ld_impl_linux-64=2.33.1=h53a641e_7
+ - libffi=3.3=he6710b0_2
+ - libgcc-ng=9.1.0=hdf63c60_0
+ - libgfortran-ng=7.3.0=hdf63c60_0
+ - libglu=9.0.0=hf484d3e_1
+ - libopencv=3.4.2=hb342d67_1
+ - libopus=1.3.1=h7b6447c_0
+ - libpng=1.6.37=hbc83047_0
+ - libstdcxx-ng=9.1.0=hdf63c60_0
+ - libtiff=4.1.0=h2733197_1
+ - libuuid=1.0.3=h1bed415_2
+ - libuv=1.40.0=h7b6447c_0
+ - libvpx=1.7.0=h439df22_0
+ - libxcb=1.14=h7b6447c_0
+ - libxml2=2.9.10=hb55368b_3
+ - lz4-c=1.9.3=h2531618_0
+ - matplotlib=3.3.4=py37h06a4308_0
+ - matplotlib-base=3.3.4=py37h62a2d02_0
+ - mkl=2020.2=256
+ - mkl-service=2.3.0=py37he8ac12f_0
+ - mkl_fft=1.3.0=py37h54f3939_0
+ - mkl_random=1.1.1=py37h0573a6f_0
+ - ncurses=6.2=he6710b0_1
+ - ninja=1.10.2=hff7bd54_1
+ - numpy=1.19.2=py37h54aff64_0
+ - numpy-base=1.19.2=py37hfa32c7d_0
+ - olefile=0.46=py37_0
+ - opencv=3.4.2=py37h6fd60c2_1
+ - openssl=1.1.1k=h27cfd23_0
+ - parso=0.8.2=pyhd3eb1b0_0
+ - pcre=8.44=he6710b0_0
+ - pexpect=4.8.0=pyhd3eb1b0_3
+ - pickleshare=0.7.5=pyhd3eb1b0_1003
+ - pillow=8.2.0=py37he98fc37_0
+ - pip=21.0.1=py37h06a4308_0
+ - pixman=0.40.0=h7b6447c_0
+ - prompt-toolkit=3.0.17=pyh06a4308_0
+ - ptyprocess=0.7.0=pyhd3eb1b0_2
+ - py-opencv=3.4.2=py37hb342d67_1
+ - pygments=2.8.1=pyhd3eb1b0_0
+ - pyparsing=2.4.7=pyhd3eb1b0_0
+ - pyqt=5.9.2=py37h05f1152_2
+ - python=3.7.10=hdb3f193_0
+ - python-dateutil=2.8.1=pyhd3eb1b0_0
+ - pytorch=1.7.1=py3.7_cuda10.2.89_cudnn7.6.5_0
+ - qt=5.9.7=h5867ecd_1
+ - readline=8.1=h27cfd23_0
+ - scipy=1.2.1=py37h7c811a0_0
+ - setuptools=52.0.0=py37h06a4308_0
+ - sip=4.19.8=py37hf484d3e_0
+ - six=1.15.0=py37h06a4308_0
+ - sqlite=3.35.4=hdfb4753_0
+ - tk=8.6.10=hbc83047_0
+ - torchaudio=0.7.2=py37
+ - torchvision=0.8.2=py37_cu102
+ - tornado=6.1=py37h27cfd23_0
+ - tqdm=4.59.0=pyhd3eb1b0_1
+ - traitlets=5.0.5=pyhd3eb1b0_0
+ - typing_extensions=3.7.4.3=pyha847dfd_0
+ - vispy=0.5.3=py37hee6b756_0
+ - wcwidth=0.2.5=py_0
+ - wheel=0.36.2=pyhd3eb1b0_0
+ - xz=5.2.5=h7b6447c_0
+ - zlib=1.2.11=h7b6447c_3
+ - zstd=1.4.9=haebb681_0
+ - pip:
+ - tables==3.6.1
diff --git a/imcui/third_party/COTR/scripts/prepare_megadepth_split.py b/imcui/third_party/COTR/scripts/prepare_megadepth_split.py
new file mode 100644
index 0000000000000000000000000000000000000000..9da905d157cc88a167d91664689eccb8286107ee
--- /dev/null
+++ b/imcui/third_party/COTR/scripts/prepare_megadepth_split.py
@@ -0,0 +1,36 @@
+import os
+import random
+import json
+
+
+# read the json
+valid_list_json_path = './megadepth_valid_list.json'
+assert os.path.isfile(valid_list_json_path), 'Change to the valid list json'
+with open(valid_list_json_path, 'r') as f:
+ all_list = json.load(f)
+
+# build a scene-to-image dictionary
+scene_img_dict = {}
+for item in all_list:
+ if not item[:4] in scene_img_dict:
+ scene_img_dict[item[:4]] = []
+ scene_img_dict[item[:4]].append(item)
+
+train_split = []
+val_split = []
+test_split = []
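+# Split rule: scene 0204 is held out for validation, the remaining scenes numbered up
+# to 0240 go to training, and everything above 0240 becomes the test split.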
+for k in sorted(scene_img_dict.keys()):
+ if int(k) == 204:
+ val_split += scene_img_dict[k]
+ elif int(k) <= 240 and int(k) != 204:
+ train_split += scene_img_dict[k]
+ else:
+ test_split += scene_img_dict[k]
+
+# save split to json
+with open('megadepth_train.json', 'w') as outfile:
+ json.dump(sorted(train_split), outfile, indent=4)
+with open('megadepth_val.json', 'w') as outfile:
+ json.dump(sorted(val_split), outfile, indent=4)
+with open('megadepth_test.json', 'w') as outfile:
+ json.dump(sorted(test_split), outfile, indent=4)
diff --git a/imcui/third_party/COTR/scripts/prepare_megadepth_valid_list.py b/imcui/third_party/COTR/scripts/prepare_megadepth_valid_list.py
new file mode 100644
index 0000000000000000000000000000000000000000..e08cf1fa75ecba25c46ee749d76d1356f7838dc9
--- /dev/null
+++ b/imcui/third_party/COTR/scripts/prepare_megadepth_valid_list.py
@@ -0,0 +1,41 @@
+import os
+import json
+
+import tables
+from tqdm import tqdm
+import numpy as np
+
+
+def read_all_imgs(base_dir):
+ all_imgs = []
+ for cur, dirs, files in os.walk(base_dir):
+ if 'imgs' in cur:
+ all_imgs += [os.path.join(cur, f) for f in files]
+ all_imgs.sort()
+ return all_imgs
+
+
+def filter_semantic_depth(imgs):
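+    # Depth maps containing negative values are MegaDepth's semantic/ordinal depth
+    # labels rather than metric depth; keep only images whose depth map is
+    # non-negative everywhere.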
+ valid_imgs = []
+ for item in tqdm(imgs):
+ f_name = os.path.splitext(os.path.basename(item))[0] + '.h5'
+ depth_dir = os.path.abspath(os.path.join(os.path.dirname(item), '../depths'))
+ depth_path = os.path.join(depth_dir, f_name)
+ depth_h5 = tables.open_file(depth_path, mode='r')
+ _depth = np.array(depth_h5.root.depth)
+ if _depth.min() >= 0:
+ prefix = os.path.abspath(os.path.join(item, '../../../../')) + '/'
+ rel_image_path = item.replace(prefix, '')
+ valid_imgs.append(rel_image_path)
+ depth_h5.close()
+ valid_imgs.sort()
+ return valid_imgs
+
+
+if __name__ == "__main__":
+ MegaDepth_v1 = '/media/jiangwei/data_ssd/MegaDepth_v1/'
+ assert os.path.isdir(MegaDepth_v1), 'Change to your local path'
+ all_imgs = read_all_imgs(MegaDepth_v1)
+ valid_imgs = filter_semantic_depth(all_imgs)
+ with open('megadepth_valid_list.json', 'w') as outfile:
+ json.dump(valid_imgs, outfile, indent=4)
diff --git a/imcui/third_party/COTR/scripts/prepare_nn_distance_mat.py b/imcui/third_party/COTR/scripts/prepare_nn_distance_mat.py
new file mode 100644
index 0000000000000000000000000000000000000000..507eb3161d7c6c638dfaf99b560fa06c531cbec2
--- /dev/null
+++ b/imcui/third_party/COTR/scripts/prepare_nn_distance_mat.py
@@ -0,0 +1,145 @@
+'''
+Compute the distance matrix for MegaDepth using ComputeCanada.
+'''
+
+import sys
+sys.path.append('..')
+
+
+import os
+import argparse
+import numpy as np
+from joblib import Parallel, delayed
+from tqdm import tqdm
+
+from COTR.options.options import *
+from COTR.options.options_utils import *
+from COTR.utils import debug_utils, utils, constants
+from COTR.datasets import colmap_helper
+from COTR.projector import pcd_projector
+from COTR.global_configs import dataset_config
+
+
+assert colmap_helper.COVISIBILITY_CHECK, 'Please enable COVISIBILITY_CHECK'
+assert colmap_helper.LOAD_PCD, 'Please enable LOAD_PCD'
+
+OFFSET_THRESHOLD = 1.0
+
+
+def get_index_pairs(dist_mat, cells):
+ pairs = []
+ for row in range(dist_mat.shape[0]):
+ for col in range(dist_mat.shape[0]):
+ if dist_mat[row][col] == -1:
+ pairs.append([row, col])
+ if len(pairs) == cells:
+ return pairs
+ return pairs
+
+
+def load_dist_mat(path, size=None):
+ if os.path.isfile(path):
+ dist_mat = np.load(path)
+ assert dist_mat.shape[0] == dist_mat.shape[1]
+ else:
+ dist_mat = np.ones([size, size], dtype=np.float32) * -1
+ assert dist_mat.shape[0] == dist_mat.shape[1]
+ return dist_mat
+
+
+def distance_between_two_caps(caps):
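+    # Co-visibility score between two captures: reproject cap_2's 3D points into
+    # cap_1's view, keep pixels whose reprojected depth agrees with cap_1's depth map
+    # within OFFSET_THRESHOLD, and return the IoU of that region against the union of
+    # both valid-depth masks.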
+ cap_1, cap_2 = caps
+ try:
+ if len(np.intersect1d(cap_1.point3d_id, cap_2.point3d_id)) == 0:
+ return 0.0
+ pcd = cap_2.point_cloud_world
+ extrin_cap_1 = cap_1.cam_pose.world_to_camera[0:3, :]
+ intrin_cap_1 = cap_1.pinhole_cam.intrinsic_mat
+ size = cap_1.pinhole_cam.shape[:2]
+ reproj = pcd_projector.PointCloudProjector.pcd_3d_to_pcd_2d_np(pcd[:, 0:3], intrin_cap_1, extrin_cap_1, size, keep_z=True, crop=True, filter_neg=True, norm_coord=False)
+ reproj = pcd_projector.PointCloudProjector.pcd_2d_to_img_2d_np(reproj, size)[..., 0]
+ # 1. calculate the iou
+ query_mask = cap_1.depth_map > 0
+ reproj_mask = reproj > 0
+ intersection_mask = query_mask * reproj_mask
+ union_mask = query_mask | reproj_mask
+ if union_mask.sum() == 0:
+ return 0.0
+ intersection_mask = (abs(cap_1.depth_map - reproj) * intersection_mask < OFFSET_THRESHOLD) * intersection_mask
+ ratio = intersection_mask.sum() / union_mask.sum()
+ if ratio == 0.0:
+ return 0.0
+ return ratio
+ except Exception as e:
+ print(e)
+ return 0.0
+
+
+def fill_covisibility(scene, dist_mat):
+ for i in range(dist_mat.shape[0]):
+ nns = scene.get_covisible_caps(scene[i])
+ covis_list = [scene.img_id_to_index_dict[cap.image_id] for cap in nns]
+ for j in range(dist_mat.shape[0]):
+ if j not in covis_list:
+ dist_mat[i][j] = 0
+ return dist_mat
+
+
+def main(opt):
+ # fast fail
+ try:
+ dist_mat = load_dist_mat(opt.out_path)
+ if dist_mat.min() >= 0.0:
+ print(f'{opt.out_path} is complete!')
+ exit()
+ else:
+ print('continue working')
+ except Exception as e:
+ print(e)
+        print('first run, starting from scratch')
+ scene_dir = opt.scenes_name_list[0]['scene_dir']
+ image_dir = opt.scenes_name_list[0]['image_dir']
+ depth_dir = opt.scenes_name_list[0]['depth_dir']
+ scene = colmap_helper.ColmapWithDepthAsciiReader.read_sfm_scene_given_valid_list_path(scene_dir, image_dir, depth_dir, dataset_config[opt.dataset_name]['valid_list_json'], opt.crop_cam)
+ size = len(scene.captures)
+ dist_mat = load_dist_mat(opt.out_path, size)
+ if opt.use_ram:
+ scene.read_data_to_ram(['depth'])
+ if dist_mat.max() == -1:
+ dist_mat = fill_covisibility(scene, dist_mat)
+ np.save(opt.out_path, dist_mat)
+ pairs = get_index_pairs(dist_mat, opt.cells)
+ in_pairs = [[scene[p[0]], scene[p[1]]] for p in pairs]
+ results = Parallel(n_jobs=opt.num_cpus)(delayed(distance_between_two_caps)(pair) for pair in tqdm(in_pairs, desc='calculating distance matrix', total=len(in_pairs)))
+ for i, p in enumerate(pairs):
+ r, c = p
+ dist_mat[r][c] = results[i]
+ np.save(opt.out_path, dist_mat)
+ print(f'finished from {pairs[0][0]}-{pairs[0][1]} -> {pairs[-1][0]}-{pairs[-1][1]}')
+ print(f'in total {len(pairs)} cells')
+ print(f'progress {(dist_mat >= 0).sum() / dist_mat.size}')
+ print(f'save at {opt.out_path}')
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ set_general_arguments(parser)
+ parser.add_argument('--dataset_name', type=str, default='megadepth', help='dataset name')
+ parser.add_argument('--use_ram', type=str2bool, default=False, help='load image/depth to ram')
+    parser.add_argument('--info_level', type=str, default='rgbd', help='the information level of the dataset')
+    parser.add_argument('--scene', type=str, default='0000', required=True, help='which scene to use')
+    parser.add_argument('--seq', type=str, default='0', required=True, help='which sequence to use')
+    parser.add_argument('--crop_cam', choices=['no_crop', 'crop_center', 'crop_center_and_resize'], type=str, default='no_crop', help='crop the center of the image to avoid changing the aspect ratio; resize to make the operations batchable.')
+ parser.add_argument('--cells', type=int, default=10000, help='the number of cells to be computed in this run')
+ parser.add_argument('--num_cpus', type=int, default=6, help='num of cores')
+
+ opt = parser.parse_args()
+ opt.scenes_name_list = options_utils.build_scenes_name_list_from_opt(opt)
+ opt.out_dir = os.path.join(os.path.dirname(opt.scenes_name_list[0]['depth_dir']), 'dist_mat')
+ opt.out_path = os.path.join(opt.out_dir, 'dist_mat.npy')
+ os.makedirs(opt.out_dir, exist_ok=True)
+ if opt.confirm:
+ confirm_opt(opt)
+ else:
+ print_opt(opt)
+ main(opt)
diff --git a/imcui/third_party/COTR/scripts/rectify_megadepth.py b/imcui/third_party/COTR/scripts/rectify_megadepth.py
new file mode 100644
index 0000000000000000000000000000000000000000..257dee530d2d8fac96c994b07e9f36cba984fedb
--- /dev/null
+++ b/imcui/third_party/COTR/scripts/rectify_megadepth.py
@@ -0,0 +1,299 @@
+'''
+Rectify the MegaDepth SfM models from SIMPLE_RADIAL to PINHOLE cameras.
+'''
+import os
+
+command_1 = 'colmap image_undistorter --image_path={0} --input_path={1} --output_path={2}'
+command_2 = 'colmap model_converter --input_path={0} --output_path={1} --output_type=TXT'
+command_3 = 'mv {0} {1}'
+command_4 = 'python sort_images_txt.py --reference={0} --unordered={1} --save_to={2}'
+
+MegaDepth_v1_SfM = '/media/jiangwei/data_ssd/MegaDepth_v1_SfM/'
+assert os.path.isdir(MegaDepth_v1_SfM), 'Change to your local path'
+all_scenes = [
+ '0000/sparse/manhattan/0',
+ '0000/sparse/manhattan/1',
+ '0001/sparse/manhattan/0',
+ '0002/sparse/manhattan/0',
+ '0003/sparse/manhattan/0',
+ '0004/sparse/manhattan/0',
+ '0004/sparse/manhattan/1',
+ '0004/sparse/manhattan/2',
+ '0005/sparse/manhattan/0',
+ '0005/sparse/manhattan/1',
+ '0007/sparse/manhattan/0',
+ '0007/sparse/manhattan/1',
+ '0008/sparse/manhattan/0',
+ '0011/sparse/manhattan/0',
+ '0012/sparse/manhattan/0',
+ '0013/sparse/manhattan/0',
+ '0015/sparse/manhattan/0',
+ '0015/sparse/manhattan/1',
+ '0016/sparse/manhattan/0',
+ '0017/sparse/manhattan/0',
+ '0019/sparse/manhattan/0',
+ '0019/sparse/manhattan/1',
+ '0020/sparse/manhattan/0',
+ '0020/sparse/manhattan/1',
+ '0021/sparse/manhattan/0',
+ '0022/sparse/manhattan/0',
+ '0023/sparse/manhattan/0',
+ '0023/sparse/manhattan/1',
+ '0024/sparse/manhattan/0',
+ '0025/sparse/manhattan/0',
+ '0025/sparse/manhattan/1',
+ '0026/sparse/manhattan/0',
+ '0027/sparse/manhattan/0',
+ '0032/sparse/manhattan/0',
+ '0032/sparse/manhattan/1',
+ '0033/sparse/manhattan/0',
+ '0034/sparse/manhattan/0',
+ '0035/sparse/manhattan/0',
+ '0036/sparse/manhattan/0',
+ '0037/sparse/manhattan/0',
+ '0039/sparse/manhattan/0',
+ '0041/sparse/manhattan/0',
+ '0041/sparse/manhattan/1',
+ '0042/sparse/manhattan/0',
+ '0043/sparse/manhattan/0',
+ '0044/sparse/manhattan/0',
+ '0046/sparse/manhattan/0',
+ '0046/sparse/manhattan/1',
+ '0046/sparse/manhattan/2',
+ '0047/sparse/manhattan/0',
+ '0048/sparse/manhattan/0',
+ '0049/sparse/manhattan/0',
+ '0050/sparse/manhattan/0',
+ '0056/sparse/manhattan/0',
+ '0057/sparse/manhattan/0',
+ '0058/sparse/manhattan/0',
+ '0058/sparse/manhattan/1',
+ '0060/sparse/manhattan/0',
+ '0061/sparse/manhattan/0',
+ '0062/sparse/manhattan/0',
+ '0062/sparse/manhattan/1',
+ '0063/sparse/manhattan/0',
+ '0063/sparse/manhattan/1',
+ '0063/sparse/manhattan/2',
+ '0063/sparse/manhattan/3',
+ '0064/sparse/manhattan/0',
+ '0065/sparse/manhattan/0',
+ '0067/sparse/manhattan/0',
+ '0070/sparse/manhattan/0',
+ '0071/sparse/manhattan/0',
+ '0071/sparse/manhattan/1',
+ '0076/sparse/manhattan/0',
+ '0078/sparse/manhattan/0',
+ '0080/sparse/manhattan/0',
+ '0083/sparse/manhattan/0',
+ '0086/sparse/manhattan/0',
+ '0087/sparse/manhattan/0',
+ '0087/sparse/manhattan/1',
+ '0090/sparse/manhattan/0',
+ '0092/sparse/manhattan/0',
+ '0092/sparse/manhattan/1',
+ '0094/sparse/manhattan/0',
+ '0095/sparse/manhattan/0',
+ '0095/sparse/manhattan/1',
+ '0095/sparse/manhattan/2',
+ '0098/sparse/manhattan/0',
+ '0099/sparse/manhattan/0',
+ '0100/sparse/manhattan/0',
+ '0101/sparse/manhattan/0',
+ '0102/sparse/manhattan/0',
+ '0103/sparse/manhattan/0',
+ '0104/sparse/manhattan/0',
+ '0104/sparse/manhattan/1',
+ '0105/sparse/manhattan/0',
+ '0107/sparse/manhattan/0',
+ '0115/sparse/manhattan/0',
+ '0117/sparse/manhattan/0',
+ '0117/sparse/manhattan/1',
+ '0117/sparse/manhattan/2',
+ '0121/sparse/manhattan/0',
+ '0121/sparse/manhattan/1',
+ '0122/sparse/manhattan/0',
+ '0129/sparse/manhattan/0',
+ '0130/sparse/manhattan/0',
+ '0130/sparse/manhattan/1',
+ '0130/sparse/manhattan/2',
+ '0133/sparse/manhattan/0',
+ '0133/sparse/manhattan/1',
+ '0137/sparse/manhattan/0',
+ '0137/sparse/manhattan/1',
+ '0137/sparse/manhattan/2',
+ '0141/sparse/manhattan/0',
+ '0143/sparse/manhattan/0',
+ '0147/sparse/manhattan/0',
+ '0147/sparse/manhattan/1',
+ '0148/sparse/manhattan/0',
+ '0148/sparse/manhattan/1',
+ '0149/sparse/manhattan/0',
+ '0150/sparse/manhattan/0',
+ '0151/sparse/manhattan/0',
+ '0156/sparse/manhattan/0',
+ '0160/sparse/manhattan/0',
+ '0160/sparse/manhattan/1',
+ '0160/sparse/manhattan/2',
+ '0162/sparse/manhattan/0',
+ '0162/sparse/manhattan/1',
+ '0168/sparse/manhattan/0',
+ '0175/sparse/manhattan/0',
+ '0176/sparse/manhattan/0',
+ '0176/sparse/manhattan/1',
+ '0176/sparse/manhattan/2',
+ '0177/sparse/manhattan/0',
+ '0178/sparse/manhattan/0',
+ '0178/sparse/manhattan/1',
+ '0181/sparse/manhattan/0',
+ '0183/sparse/manhattan/0',
+ '0185/sparse/manhattan/0',
+ '0186/sparse/manhattan/0',
+ '0189/sparse/manhattan/0',
+ '0190/sparse/manhattan/0',
+ '0197/sparse/manhattan/0',
+ '0200/sparse/manhattan/0',
+ '0200/sparse/manhattan/1',
+ '0204/sparse/manhattan/0',
+ '0204/sparse/manhattan/1',
+ '0205/sparse/manhattan/0',
+ '0205/sparse/manhattan/1',
+ '0209/sparse/manhattan/1',
+ '0212/sparse/manhattan/0',
+ '0212/sparse/manhattan/1',
+ '0214/sparse/manhattan/0',
+ '0214/sparse/manhattan/1',
+ '0217/sparse/manhattan/0',
+ '0223/sparse/manhattan/0',
+ '0223/sparse/manhattan/1',
+ '0223/sparse/manhattan/2',
+ '0224/sparse/manhattan/0',
+ '0224/sparse/manhattan/1',
+ '0229/sparse/manhattan/0',
+ '0231/sparse/manhattan/0',
+ '0235/sparse/manhattan/0',
+ '0237/sparse/manhattan/0',
+ '0238/sparse/manhattan/0',
+ '0240/sparse/manhattan/0',
+ '0243/sparse/manhattan/0',
+ '0252/sparse/manhattan/0',
+ '0257/sparse/manhattan/0',
+ '0258/sparse/manhattan/0',
+ '0265/sparse/manhattan/0',
+ '0265/sparse/manhattan/1',
+ '0269/sparse/manhattan/0',
+ '0269/sparse/manhattan/1',
+ '0269/sparse/manhattan/2',
+ '0271/sparse/manhattan/0',
+ '0275/sparse/manhattan/0',
+ '0277/sparse/manhattan/0',
+ '0277/sparse/manhattan/1',
+ '0281/sparse/manhattan/0',
+ '0285/sparse/manhattan/0',
+ '0286/sparse/manhattan/0',
+ '0286/sparse/manhattan/1',
+ '0290/sparse/manhattan/0',
+ '0290/sparse/manhattan/1',
+ '0294/sparse/manhattan/0',
+ '0299/sparse/manhattan/0',
+ '0303/sparse/manhattan/0',
+ '0306/sparse/manhattan/0',
+ '0307/sparse/manhattan/0',
+ '0312/sparse/manhattan/0',
+ '0312/sparse/manhattan/1',
+ '0323/sparse/manhattan/0',
+ '0326/sparse/manhattan/0',
+ '0327/sparse/manhattan/0',
+ '0327/sparse/manhattan/1',
+ '0327/sparse/manhattan/2',
+ '0331/sparse/manhattan/0',
+ '0335/sparse/manhattan/0',
+ '0335/sparse/manhattan/1',
+ '0341/sparse/manhattan/0',
+ '0341/sparse/manhattan/1',
+ '0348/sparse/manhattan/0',
+ '0349/sparse/manhattan/0',
+ '0349/sparse/manhattan/1',
+ '0360/sparse/manhattan/0',
+ '0360/sparse/manhattan/1',
+ '0360/sparse/manhattan/2',
+ '0366/sparse/manhattan/0',
+ '0377/sparse/manhattan/0',
+ '0380/sparse/manhattan/0',
+ '0387/sparse/manhattan/0',
+ '0389/sparse/manhattan/0',
+ '0389/sparse/manhattan/1',
+ '0394/sparse/manhattan/0',
+ '0394/sparse/manhattan/1',
+ '0402/sparse/manhattan/0',
+ '0402/sparse/manhattan/1',
+ '0406/sparse/manhattan/0',
+ '0407/sparse/manhattan/0',
+ '0411/sparse/manhattan/0',
+ '0411/sparse/manhattan/1',
+ '0412/sparse/manhattan/0',
+ '0412/sparse/manhattan/1',
+ '0412/sparse/manhattan/2',
+ '0430/sparse/manhattan/0',
+ '0430/sparse/manhattan/1',
+ '0430/sparse/manhattan/2',
+ '0443/sparse/manhattan/0',
+ '0446/sparse/manhattan/0',
+ '0455/sparse/manhattan/0',
+ '0472/sparse/manhattan/0',
+ '0472/sparse/manhattan/1',
+ '0474/sparse/manhattan/0',
+ '0474/sparse/manhattan/1',
+ '0474/sparse/manhattan/2',
+ '0476/sparse/manhattan/0',
+ '0476/sparse/manhattan/1',
+ '0476/sparse/manhattan/2',
+ '0478/sparse/manhattan/0',
+ '0478/sparse/manhattan/1',
+ '0482/sparse/manhattan/0',
+ '0493/sparse/manhattan/0',
+ '0493/sparse/manhattan/1',
+ '0494/sparse/manhattan/1',
+ '0496/sparse/manhattan/0',
+ '0505/sparse/manhattan/0',
+ '0559/sparse/manhattan/0',
+ '0733/sparse/manhattan/0',
+ '0733/sparse/manhattan/1',
+ '0768/sparse/manhattan/0',
+ '0860/sparse/manhattan/0',
+ '0860/sparse/manhattan/1',
+ '1001/sparse/manhattan/0',
+ '1017/sparse/manhattan/0',
+ '1589/sparse/manhattan/0',
+ '3346/sparse/manhattan/0',
+ '4541/sparse/manhattan/0',
+ '5000/sparse/manhattan/0',
+ '5001/sparse/manhattan/0',
+ '5002/sparse/manhattan/0',
+ '5003/sparse/manhattan/0',
+ '5004/sparse/manhattan/0',
+ '5005/sparse/manhattan/0',
+ '5006/sparse/manhattan/0',
+ '5007/sparse/manhattan/0',
+ '5008/sparse/manhattan/0',
+ '5009/sparse/manhattan/0',
+ '5010/sparse/manhattan/0',
+ '5011/sparse/manhattan/0',
+ '5012/sparse/manhattan/0',
+ '5013/sparse/manhattan/0',
+ '5014/sparse/manhattan/0',
+ '5015/sparse/manhattan/0',
+ '5016/sparse/manhattan/0',
+ '5017/sparse/manhattan/0',
+ '5018/sparse/manhattan/0',
+]
+
+with open('rectify.sh', "w") as fid:
+ for s in all_scenes:
+ s = os.path.join(MegaDepth_v1_SfM, s)
+ new_dir = s + '_rectified'
+ img_dir = s[:s.find('sparse')] + 'images'
+ fid.write(command_1.format(img_dir, s, new_dir) + '\n')
+ fid.write(command_2.format(new_dir + '/sparse', new_dir + '/sparse') + '\n')
+ fid.write(command_3.format(new_dir + '/sparse/images.txt', new_dir + '/sparse/unorder_images.txt') + '\n')
+ fid.write(command_4.format(s + '/images.txt', new_dir + '/sparse/unorder_images.txt', new_dir + '/sparse/images.txt') + '\n')
diff --git a/imcui/third_party/COTR/scripts/sort_images_txt.py b/imcui/third_party/COTR/scripts/sort_images_txt.py
new file mode 100644
index 0000000000000000000000000000000000000000..aa52bddc294a62610aa657b01a8ec366ce9f3267
--- /dev/null
+++ b/imcui/third_party/COTR/scripts/sort_images_txt.py
@@ -0,0 +1,78 @@
+import sys
+assert sys.version_info >= (3, 7), 'ordered dict is required'
+import os
+import argparse
+import re
+
+from tqdm import tqdm
+
+
+def read_images_meta(images_txt_path):
+ images_meta = {}
+ with open(images_txt_path, "r") as fid:
+ line = fid.readline()
+ assert line == '# Image list with two lines of data per image:\n'
+ line = fid.readline()
+ assert line == '# IMAGE_ID, QW, QX, QY, QZ, TX, TY, TZ, CAMERA_ID, NAME\n'
+ line = fid.readline()
+ assert line == '# POINTS2D[] as (X, Y, POINT3D_ID)\n'
+ line = fid.readline()
+        assert re.search(r'^# Number of images: \d+, mean observations per image: [-+]?\d*\.\d+|\d+\n$', line)
+ num_images, mean_ob_per_img = re.findall(r"[-+]?\d*\.\d+|\d+", line)
+ num_images = int(num_images)
+ mean_ob_per_img = float(mean_ob_per_img)
+
+ for _ in tqdm(range(num_images), desc='reading images meta'):
+ l = fid.readline()
+ elems = l.split()
+ image_id = int(elems[0])
+ l2 = fid.readline()
+ images_meta[image_id] = [l, l2]
+ return images_meta
+
+
+def read_header(images_txt_path):
+ header = []
+ with open(images_txt_path, "r") as fid:
+ line = fid.readline()
+ assert line == '# Image list with two lines of data per image:\n'
+ header.append(line)
+ line = fid.readline()
+ assert line == '# IMAGE_ID, QW, QX, QY, QZ, TX, TY, TZ, CAMERA_ID, NAME\n'
+ header.append(line)
+ line = fid.readline()
+ assert line == '# POINTS2D[] as (X, Y, POINT3D_ID)\n'
+ header.append(line)
+ line = fid.readline()
+        assert re.search(r'^# Number of images: \d+, mean observations per image: [-+]?\d*\.\d+|\d+\n$', line)
+ header.append(line)
+ return header
+
+
+def export_images_txt(save_to, header, content):
+    assert not os.path.isfile(save_to), 'refusing to overwrite an existing file'
+ with open(save_to, "w") as fid:
+ for l in header:
+ fid.write(l)
+ for k, item in content.items():
+ for l in item:
+ fid.write(l)
+
+
+def main(opt):
+ reference = read_images_meta(opt.reference_images_txt)
+ unordered = read_images_meta(opt.unordered_images_txt)
+ ordered = {}
+ for k in reference.keys():
+ ordered[k] = unordered[k]
+ header = read_header(opt.unordered_images_txt)
+ export_images_txt(opt.save_to, header, ordered)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--reference_images_txt', type=str, default=None, required=True)
+ parser.add_argument('--unordered_images_txt', type=str, default=None, required=True)
+ parser.add_argument('--save_to', type=str, default=None, required=True)
+ opt = parser.parse_args()
+ main(opt)
diff --git a/imcui/third_party/COTR/train_cotr.py b/imcui/third_party/COTR/train_cotr.py
new file mode 100644
index 0000000000000000000000000000000000000000..dd06fda41c7c97edaac784195895f8a2b36d07f3
--- /dev/null
+++ b/imcui/third_party/COTR/train_cotr.py
@@ -0,0 +1,149 @@
+import argparse
+import subprocess
+import pprint
+
+import numpy as np
+import torch
+# import torch.multiprocessing
+# torch.multiprocessing.set_sharing_strategy('file_system')
+from torch.utils.data import DataLoader
+
+from COTR.models import build_model
+from COTR.utils import debug_utils, utils
+from COTR.datasets import cotr_dataset
+from COTR.trainers.cotr_trainer import COTRTrainer
+from COTR.global_configs import general_config
+from COTR.options.options import *
+from COTR.options.options_utils import *
+
+
+utils.fix_randomness(0)
+
+
+def train(opt):
+ pprint.pprint(dict(os.environ), width=1)
+ result = subprocess.Popen(["nvidia-smi"], stdout=subprocess.PIPE)
+ print(result.stdout.read().decode())
+ device = torch.cuda.current_device()
+ print(f'can see {torch.cuda.device_count()} gpus')
+ print(f'current using gpu at {device} -- {torch.cuda.get_device_name(device)}')
+ # dummy = torch.rand(3758725612).to(device)
+ # del dummy
+ torch.cuda.empty_cache()
+ model = build_model(opt)
+ model = model.to(device)
+ if opt.enable_zoom:
+ train_dset = cotr_dataset.COTRZoomDataset(opt, 'train')
+ val_dset = cotr_dataset.COTRZoomDataset(opt, 'val')
+ else:
+ train_dset = cotr_dataset.COTRDataset(opt, 'train')
+ val_dset = cotr_dataset.COTRDataset(opt, 'val')
+
+ train_loader = DataLoader(train_dset, batch_size=opt.batch_size,
+ shuffle=opt.shuffle_data, num_workers=opt.workers,
+ worker_init_fn=utils.worker_init_fn, pin_memory=True)
+ val_loader = DataLoader(val_dset, batch_size=opt.batch_size,
+ shuffle=opt.shuffle_data, num_workers=opt.workers,
+ drop_last=True, worker_init_fn=utils.worker_init_fn, pin_memory=True)
+
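+    # The transformer, correspondence embedding and projection layers share one
+    # learning rate; the backbone gets its own rate and is left out of the optimizer
+    # (i.e. kept frozen) when opt.lr_backbone <= 0.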
+ optim_list = [{"params": model.transformer.parameters(), "lr": opt.learning_rate},
+ {"params": model.corr_embed.parameters(), "lr": opt.learning_rate},
+ {"params": model.query_proj.parameters(), "lr": opt.learning_rate},
+ {"params": model.input_proj.parameters(), "lr": opt.learning_rate},
+ ]
+ if opt.lr_backbone > 0:
+ optim_list.append({"params": model.backbone.parameters(), "lr": opt.lr_backbone})
+
+ optim = torch.optim.Adam(optim_list)
+ trainer = COTRTrainer(opt, model, optim, None, train_loader, val_loader)
+ trainer.train()
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+
+ set_general_arguments(parser)
+ set_dataset_arguments(parser)
+ set_nn_arguments(parser)
+ set_COTR_arguments(parser)
+ parser.add_argument('--num_kp', type=int,
+ default=100)
+ parser.add_argument('--kp_pool', type=int,
+ default=100)
+ parser.add_argument('--enable_zoom', type=str2bool,
+ default=False)
+ parser.add_argument('--zoom_start', type=float,
+ default=1.0)
+ parser.add_argument('--zoom_end', type=float,
+ default=0.1)
+ parser.add_argument('--zoom_levels', type=int,
+ default=10)
+ parser.add_argument('--zoom_jitter', type=float,
+ default=0.5)
+
+ parser.add_argument('--out_dir', type=str, default=general_config['out'], help='out directory')
+ parser.add_argument('--tb_dir', type=str, default=general_config['tb_out'], help='tensorboard runs directory')
+
+ parser.add_argument('--learning_rate', type=float,
+ default=1e-4, help='learning rate')
+ parser.add_argument('--lr_backbone', type=float,
+ default=1e-5, help='backbone learning rate')
+ parser.add_argument('--batch_size', type=int,
+ default=32, help='batch size for training')
+ parser.add_argument('--cycle_consis', type=str2bool, default=True,
+ help='cycle consistency')
+ parser.add_argument('--bidirectional', type=str2bool, default=True,
+ help='left2right and right2left')
+ parser.add_argument('--max_iter', type=int,
+ default=200000, help='total training iterations')
+ parser.add_argument('--valid_iter', type=int,
+                        default=1000, help='interval between validations')
+ parser.add_argument('--resume', type=str2bool, default=False,
+ help='resume training with same model name')
+ parser.add_argument('--cc_resume', type=str2bool, default=False,
+ help='resume from last run if possible')
+ parser.add_argument('--need_rotation', type=str2bool, default=False,
+ help='rotation augmentation')
+ parser.add_argument('--max_rotation', type=float, default=0,
+ help='max rotation for data augmentation')
+ parser.add_argument('--rotation_chance', type=float, default=0,
+ help='the probability of being rotated')
+    parser.add_argument('--load_weights', type=str, default=None, help='load a pretrained set of weights; you need to provide the model id')
+ parser.add_argument('--suffix', type=str, default='', help='model suffix')
+
+ opt = parser.parse_args()
+ opt.command = ' '.join(sys.argv)
+
+ layer_2_channels = {'layer1': 256,
+ 'layer2': 512,
+ 'layer3': 1024,
+ 'layer4': 2048, }
+ opt.dim_feedforward = layer_2_channels[opt.layer]
+ opt.num_queries = opt.num_kp
+
+ opt.name = get_compact_naming_cotr(opt)
+ opt.out = os.path.join(opt.out_dir, opt.name)
+ opt.tb_out = os.path.join(opt.tb_dir, opt.name)
+
+ if opt.cc_resume:
+ if os.path.isfile(os.path.join(opt.out, 'checkpoint.pth.tar')):
+ print('resuming from last run')
+ opt.load_weights = None
+ opt.resume = True
+ else:
+ opt.resume = False
+ assert (bool(opt.load_weights) and opt.resume) == False
+ if opt.load_weights:
+ opt.load_weights_path = os.path.join(opt.out_dir, opt.load_weights, 'checkpoint.pth.tar')
+ if opt.resume:
+ opt.load_weights_path = os.path.join(opt.out, 'checkpoint.pth.tar')
+
+ opt.scenes_name_list = build_scenes_name_list_from_opt(opt)
+
+ if opt.confirm:
+ confirm_opt(opt)
+ else:
+ print_opt(opt)
+
+ save_opt(opt)
+ train(opt)
diff --git a/imcui/third_party/DKM/demo/demo_fundamental.py b/imcui/third_party/DKM/demo/demo_fundamental.py
new file mode 100644
index 0000000000000000000000000000000000000000..e19766d5d3ce1abf0d18483cbbce71b2696983be
--- /dev/null
+++ b/imcui/third_party/DKM/demo/demo_fundamental.py
@@ -0,0 +1,37 @@
+from PIL import Image
+import torch
+import torch.nn.functional as F
+import numpy as np
+from dkm.utils.utils import tensor_to_pil
+import cv2
+from dkm import DKMv3_outdoor
+
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+
+if __name__ == "__main__":
+ from argparse import ArgumentParser
+ parser = ArgumentParser()
+ parser.add_argument("--im_A_path", default="assets/sacre_coeur_A.jpg", type=str)
+ parser.add_argument("--im_B_path", default="assets/sacre_coeur_B.jpg", type=str)
+
+ args, _ = parser.parse_known_args()
+ im1_path = args.im_A_path
+ im2_path = args.im_B_path
+
+ # Create model
+ dkm_model = DKMv3_outdoor(device=device)
+
+
+ W_A, H_A = Image.open(im1_path).size
+ W_B, H_B = Image.open(im2_path).size
+
+ # Match
+ warp, certainty = dkm_model.match(im1_path, im2_path, device=device)
+ # Sample matches for estimation
+ matches, certainty = dkm_model.sample(warp, certainty)
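+    # Convert the sampled matches from DKM's normalized coordinates to pixel
+    # coordinates, then robustly estimate the fundamental matrix with MAGSAC++.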
+ kpts1, kpts2 = dkm_model.to_pixel_coordinates(matches, H_A, W_A, H_B, W_B)
+ F, mask = cv2.findFundamentalMat(
+ kpts1.cpu().numpy(), kpts2.cpu().numpy(), ransacReprojThreshold=0.2, method=cv2.USAC_MAGSAC, confidence=0.999999, maxIters=10000
+ )
+ # TODO: some better visualization
\ No newline at end of file
diff --git a/imcui/third_party/DKM/demo/demo_match.py b/imcui/third_party/DKM/demo/demo_match.py
new file mode 100644
index 0000000000000000000000000000000000000000..fb901894d8654a884819162d3b9bb8094529e034
--- /dev/null
+++ b/imcui/third_party/DKM/demo/demo_match.py
@@ -0,0 +1,48 @@
+from PIL import Image
+import torch
+import torch.nn.functional as F
+import numpy as np
+from dkm.utils.utils import tensor_to_pil
+
+from dkm import DKMv3_outdoor
+
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+
+if __name__ == "__main__":
+ from argparse import ArgumentParser
+ parser = ArgumentParser()
+ parser.add_argument("--im_A_path", default="assets/sacre_coeur_A.jpg", type=str)
+ parser.add_argument("--im_B_path", default="assets/sacre_coeur_B.jpg", type=str)
+ parser.add_argument("--save_path", default="demo/dkmv3_warp_sacre_coeur.jpg", type=str)
+
+ args, _ = parser.parse_known_args()
+ im1_path = args.im_A_path
+ im2_path = args.im_B_path
+ save_path = args.save_path
+
+ # Create model
+ dkm_model = DKMv3_outdoor(device=device)
+
+ H, W = 864, 1152
+
+ im1 = Image.open(im1_path).resize((W, H))
+ im2 = Image.open(im2_path).resize((W, H))
+
+ # Match
+ warp, certainty = dkm_model.match(im1_path, im2_path, device=device)
+ # Sampling not needed, but can be done with model.sample(warp, certainty)
+ dkm_model.sample(warp, certainty)
+ x1 = (torch.tensor(np.array(im1)) / 255).to(device).permute(2, 0, 1)
+ x2 = (torch.tensor(np.array(im2)) / 255).to(device).permute(2, 0, 1)
+
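+    # The symmetric warp stacks both directions side by side: in the left half
+    # (image A grid) channels 2:4 hold matching coordinates in image B, and in the
+    # right half (image B grid) channels 0:2 hold coordinates in image A. grid_sample
+    # pulls colours from the other image, and the certainty map blends the warped
+    # result with a white background.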
+ im2_transfer_rgb = F.grid_sample(
+ x2[None], warp[:,:W, 2:][None], mode="bilinear", align_corners=False
+ )[0]
+ im1_transfer_rgb = F.grid_sample(
+ x1[None], warp[:, W:, :2][None], mode="bilinear", align_corners=False
+ )[0]
+ warp_im = torch.cat((im2_transfer_rgb,im1_transfer_rgb),dim=2)
+ white_im = torch.ones((H,2*W),device=device)
+ vis_im = certainty * warp_im + (1 - certainty) * white_im
+ tensor_to_pil(vis_im, unnormalize=False).save(save_path)
diff --git a/imcui/third_party/DKM/dkm/__init__.py b/imcui/third_party/DKM/dkm/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a9b47632780acc7762bcccc348e2025fe99f3726
--- /dev/null
+++ b/imcui/third_party/DKM/dkm/__init__.py
@@ -0,0 +1,4 @@
+from .models import (
+ DKMv3_outdoor,
+ DKMv3_indoor,
+ )
diff --git a/imcui/third_party/DKM/dkm/benchmarks/__init__.py b/imcui/third_party/DKM/dkm/benchmarks/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..57643fd314a2301138aecdc804a5877d0ce9274e
--- /dev/null
+++ b/imcui/third_party/DKM/dkm/benchmarks/__init__.py
@@ -0,0 +1,4 @@
+from .hpatches_sequences_homog_benchmark import HpatchesHomogBenchmark
+from .scannet_benchmark import ScanNetBenchmark
+from .megadepth1500_benchmark import Megadepth1500Benchmark
+from .megadepth_dense_benchmark import MegadepthDenseBenchmark
diff --git a/imcui/third_party/DKM/dkm/benchmarks/deprecated/hpatches_sequences_dense_benchmark.py b/imcui/third_party/DKM/dkm/benchmarks/deprecated/hpatches_sequences_dense_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..079622fdaf77c75aeadd675629f2512c45d04c2d
--- /dev/null
+++ b/imcui/third_party/DKM/dkm/benchmarks/deprecated/hpatches_sequences_dense_benchmark.py
@@ -0,0 +1,100 @@
+from PIL import Image
+import numpy as np
+
+import os
+
+import torch
+from tqdm import tqdm
+
+from dkm.utils import *
+
+
+class HpatchesDenseBenchmark:
+ """WARNING: HPATCHES grid goes from [0,n-1] instead of [0.5,n-0.5]"""
+
+ def __init__(self, dataset_path) -> None:
+ seqs_dir = "hpatches-sequences-release"
+ self.seqs_path = os.path.join(dataset_path, seqs_dir)
+ self.seq_names = sorted(os.listdir(self.seqs_path))
+
+ def convert_coordinates(self, query_coords, query_to_support, wq, hq, wsup, hsup):
+ # Get matches in output format on the grid [0, n] where the center of the top-left coordinate is [0.5, 0.5]
+ offset = (
+ 0.5 # Hpatches assumes that the center of the top-left pixel is at [0,0]
+ )
+ query_coords = (
+ torch.stack(
+ (
+ wq * (query_coords[..., 0] + 1) / 2,
+ hq * (query_coords[..., 1] + 1) / 2,
+ ),
+ axis=-1,
+ )
+ - offset
+ )
+ query_to_support = (
+ torch.stack(
+ (
+ wsup * (query_to_support[..., 0] + 1) / 2,
+ hsup * (query_to_support[..., 1] + 1) / 2,
+ ),
+ axis=-1,
+ )
+ - offset
+ )
+ return query_coords, query_to_support
+
+ def inside_image(self, x, w, h):
+ return torch.logical_and(
+ x[:, 0] < (w - 1),
+ torch.logical_and(x[:, 1] < (h - 1), (x > 0).prod(dim=-1)),
+ )
+
+ def benchmark(self, model):
+ use_cuda = torch.cuda.is_available()
+ device = torch.device("cuda:0" if use_cuda else "cpu")
+ aepes = []
+ pcks = []
+ for seq_idx, seq_name in tqdm(
+ enumerate(self.seq_names), total=len(self.seq_names)
+ ):
+ if seq_name[0] == "i":
+ continue
+ im1_path = os.path.join(self.seqs_path, seq_name, "1.ppm")
+ im1 = Image.open(im1_path)
+ w1, h1 = im1.size
+ for im_idx in range(2, 7):
+ im2_path = os.path.join(self.seqs_path, seq_name, f"{im_idx}.ppm")
+ im2 = Image.open(im2_path)
+ w2, h2 = im2.size
+ matches, certainty = model.match(im2, im1, do_pred_in_og_res=True)
+ matches, certainty = matches.reshape(-1, 4), certainty.reshape(-1)
+ inv_homography = torch.from_numpy(
+ np.loadtxt(
+ os.path.join(self.seqs_path, seq_name, "H_1_" + str(im_idx))
+ )
+ ).to(device)
+ homography = torch.linalg.inv(inv_homography)
+ pos_a, pos_b = self.convert_coordinates(
+ matches[:, :2], matches[:, 2:], w2, h2, w1, h1
+ )
+ pos_a, pos_b = pos_a.double(), pos_b.double()
+ pos_a_h = torch.cat(
+ [pos_a, torch.ones([pos_a.shape[0], 1], device=device)], dim=1
+ )
+ pos_b_proj_h = (homography @ pos_a_h.t()).t()
+ pos_b_proj = pos_b_proj_h[:, :2] / pos_b_proj_h[:, 2:]
+ mask = self.inside_image(pos_b_proj, w1, h1)
+ residual = pos_b - pos_b_proj
+ dist = (residual**2).sum(dim=1).sqrt()[mask]
+ aepes.append(torch.mean(dist).item())
+ pck1 = (dist < 1.0).float().mean().item()
+ pck3 = (dist < 3.0).float().mean().item()
+ pck5 = (dist < 5.0).float().mean().item()
+ pcks.append([pck1, pck3, pck5])
+ m_pcks = np.mean(np.array(pcks), axis=0)
+ return {
+ "hp_pck1": m_pcks[0],
+ "hp_pck3": m_pcks[1],
+ "hp_pck5": m_pcks[2],
+ }
diff --git a/imcui/third_party/DKM/dkm/benchmarks/deprecated/yfcc100m_benchmark.py b/imcui/third_party/DKM/dkm/benchmarks/deprecated/yfcc100m_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..781a291f0c358cbd435790b0a639f2a2510145b2
--- /dev/null
+++ b/imcui/third_party/DKM/dkm/benchmarks/deprecated/yfcc100m_benchmark.py
@@ -0,0 +1,119 @@
+import pickle
+import h5py
+import numpy as np
+import torch
+from dkm.utils import *
+from PIL import Image
+from tqdm import tqdm
+
+
+class Yfcc100mBenchmark:
+ def __init__(self, data_root="data/yfcc100m_test") -> None:
+ self.scenes = [
+ "buckingham_palace",
+ "notre_dame_front_facade",
+ "reichstag",
+ "sacre_coeur",
+ ]
+ self.data_root = data_root
+
+ def benchmark(self, model, r=2):
+ model.train(False)
+ with torch.no_grad():
+ data_root = self.data_root
+ meta_info = open(
+ f"{data_root}/yfcc_test_pairs_with_gt.txt", "r"
+ ).readlines()
+ tot_e_t, tot_e_R, tot_e_pose = [], [], []
+ for scene_ind in range(len(self.scenes)):
+ scene = self.scenes[scene_ind]
+ pairs = np.array(
+ pickle.load(
+ open(f"{data_root}/pairs/{scene}-te-1000-pairs.pkl", "rb")
+ )
+ )
+ scene_dir = f"{data_root}/yfcc100m/{scene}/test/"
+ calibs = open(scene_dir + "calibration.txt", "r").read().split("\n")
+ images = open(scene_dir + "images.txt", "r").read().split("\n")
+ pair_inds = np.random.choice(
+ range(len(pairs)), size=len(pairs), replace=False
+ )
+ for pairind in tqdm(pair_inds):
+ idx1, idx2 = pairs[pairind]
+ params = meta_info[1000 * scene_ind + pairind].split()
+ rot1, rot2 = int(params[2]), int(params[3])
+ calib1 = h5py.File(scene_dir + calibs[idx1], "r")
+ K1, R1, t1, _, _ = get_pose(calib1)
+ calib2 = h5py.File(scene_dir + calibs[idx2], "r")
+ K2, R2, t2, _, _ = get_pose(calib2)
+
+ R, t = compute_relative_pose(R1, t1, R2, t2)
+ im1 = images[idx1]
+ im2 = images[idx2]
+ im1 = Image.open(scene_dir + im1).rotate(rot1 * 90, expand=True)
+ w1, h1 = im1.size
+ im2 = Image.open(scene_dir + im2).rotate(rot2 * 90, expand=True)
+ w2, h2 = im2.size
+ K1 = rotate_intrinsic(K1, rot1)
+ K2 = rotate_intrinsic(K2, rot2)
+
+ dense_matches, dense_certainty = model.match(im1, im2)
+ dense_certainty = dense_certainty ** (1 / r)
+ sparse_matches, sparse_confidence = model.sample(
+ dense_matches, dense_certainty, 10000
+ )
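+                    # Rescale the image dimensions and intrinsics so that the shorter
+                    # side is 480 px; the keypoints below are expressed at this
+                    # resolution.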
+ scale1 = 480 / min(w1, h1)
+ scale2 = 480 / min(w2, h2)
+ w1, h1 = scale1 * w1, scale1 * h1
+ w2, h2 = scale2 * w2, scale2 * h2
+ K1 = K1 * scale1
+ K2 = K2 * scale2
+
+ kpts1 = sparse_matches[:, :2]
+ kpts1 = np.stack(
+ (w1 * kpts1[:, 0] / 2, h1 * kpts1[:, 1] / 2), axis=-1
+ )
+ kpts2 = sparse_matches[:, 2:]
+ kpts2 = np.stack(
+ (w2 * kpts2[:, 0] / 2, h2 * kpts2[:, 1] / 2), axis=-1
+ )
+ try:
+ threshold = 1.0
+ norm_threshold = threshold / (
+ np.mean(np.abs(K1[:2, :2])) + np.mean(np.abs(K2[:2, :2]))
+ )
+ R_est, t_est, mask = estimate_pose(
+ kpts1,
+ kpts2,
+ K1[:2, :2],
+ K2[:2, :2],
+ norm_threshold,
+ conf=0.9999999,
+ )
+ T1_to_2 = np.concatenate((R_est, t_est), axis=-1) #
+ e_t, e_R = compute_pose_error(T1_to_2, R, t)
+ e_pose = max(e_t, e_R)
+ except:
+ e_t, e_R = 90, 90
+ e_pose = max(e_t, e_R)
+ tot_e_t.append(e_t)
+ tot_e_R.append(e_R)
+ tot_e_pose.append(e_pose)
+ tot_e_pose = np.array(tot_e_pose)
+ thresholds = [5, 10, 20]
+ auc = pose_auc(tot_e_pose, thresholds)
+ acc_5 = (tot_e_pose < 5).mean()
+ acc_10 = (tot_e_pose < 10).mean()
+ acc_15 = (tot_e_pose < 15).mean()
+ acc_20 = (tot_e_pose < 20).mean()
+ map_5 = acc_5
+ map_10 = np.mean([acc_5, acc_10])
+ map_20 = np.mean([acc_5, acc_10, acc_15, acc_20])
+ return {
+ "auc_5": auc[0],
+ "auc_10": auc[1],
+ "auc_20": auc[2],
+ "map_5": map_5,
+ "map_10": map_10,
+ "map_20": map_20,
+ }
diff --git a/imcui/third_party/DKM/dkm/benchmarks/hpatches_sequences_homog_benchmark.py b/imcui/third_party/DKM/dkm/benchmarks/hpatches_sequences_homog_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..9c3febe5ca9e3a683bc7122cec635c4f54b66f7c
--- /dev/null
+++ b/imcui/third_party/DKM/dkm/benchmarks/hpatches_sequences_homog_benchmark.py
@@ -0,0 +1,114 @@
+from PIL import Image
+import numpy as np
+
+import os
+
+from tqdm import tqdm
+from dkm.utils import pose_auc
+import cv2
+
+
+class HpatchesHomogBenchmark:
+ """Hpatches grid goes from [0,n-1] instead of [0.5,n-0.5]"""
+
+ def __init__(self, dataset_path) -> None:
+ seqs_dir = "hpatches-sequences-release"
+ self.seqs_path = os.path.join(dataset_path, seqs_dir)
+ self.seq_names = sorted(os.listdir(self.seqs_path))
+ # Ignore seqs is same as LoFTR.
+ self.ignore_seqs = set(
+ [
+ "i_contruction",
+ "i_crownnight",
+ "i_dc",
+ "i_pencils",
+ "i_whitebuilding",
+ "v_artisans",
+ "v_astronautis",
+ "v_talent",
+ ]
+ )
+
+ def convert_coordinates(self, query_coords, query_to_support, wq, hq, wsup, hsup):
+ offset = 0.5 # Hpatches assumes that the center of the top-left pixel is at [0,0] (I think)
+ query_coords = (
+ np.stack(
+ (
+ wq * (query_coords[..., 0] + 1) / 2,
+ hq * (query_coords[..., 1] + 1) / 2,
+ ),
+ axis=-1,
+ )
+ - offset
+ )
+ query_to_support = (
+ np.stack(
+ (
+ wsup * (query_to_support[..., 0] + 1) / 2,
+ hsup * (query_to_support[..., 1] + 1) / 2,
+ ),
+ axis=-1,
+ )
+ - offset
+ )
+ return query_coords, query_to_support
+
+ def benchmark(self, model, model_name = None):
+ n_matches = []
+ homog_dists = []
+ for seq_idx, seq_name in tqdm(
+ enumerate(self.seq_names), total=len(self.seq_names)
+ ):
+ if seq_name in self.ignore_seqs:
+ continue
+ im1_path = os.path.join(self.seqs_path, seq_name, "1.ppm")
+ im1 = Image.open(im1_path)
+ w1, h1 = im1.size
+ for im_idx in range(2, 7):
+ im2_path = os.path.join(self.seqs_path, seq_name, f"{im_idx}.ppm")
+ im2 = Image.open(im2_path)
+ w2, h2 = im2.size
+ H = np.loadtxt(
+ os.path.join(self.seqs_path, seq_name, "H_1_" + str(im_idx))
+ )
+ dense_matches, dense_certainty = model.match(
+ im1_path, im2_path
+ )
+ good_matches, _ = model.sample(dense_matches, dense_certainty, 5000)
+ pos_a, pos_b = self.convert_coordinates(
+ good_matches[:, :2], good_matches[:, 2:], w1, h1, w2, h2
+ )
+ try:
+ H_pred, inliers = cv2.findHomography(
+ pos_a,
+ pos_b,
+ method = cv2.RANSAC,
+ confidence = 0.99999,
+ ransacReprojThreshold = 3 * min(w2, h2) / 480,
+ )
+ except:
+ H_pred = None
+ if H_pred is None:
+ H_pred = np.zeros((3, 3))
+ H_pred[2, 2] = 1.0
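+                # Homography error: mean transfer error of the four image corners,
+                # normalized as if the shorter side of image 2 were 480 px (matching
+                # the RANSAC threshold above).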
+ corners = np.array(
+ [[0, 0, 1], [0, h1 - 1, 1], [w1 - 1, 0, 1], [w1 - 1, h1 - 1, 1]]
+ )
+ real_warped_corners = np.dot(corners, np.transpose(H))
+ real_warped_corners = (
+ real_warped_corners[:, :2] / real_warped_corners[:, 2:]
+ )
+ warped_corners = np.dot(corners, np.transpose(H_pred))
+ warped_corners = warped_corners[:, :2] / warped_corners[:, 2:]
+ mean_dist = np.mean(
+ np.linalg.norm(real_warped_corners - warped_corners, axis=1)
+ ) / (min(w2, h2) / 480.0)
+ homog_dists.append(mean_dist)
+ n_matches = np.array(n_matches)
+ thresholds = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
+ auc = pose_auc(np.array(homog_dists), thresholds)
+ return {
+ "hpatches_homog_auc_3": auc[2],
+ "hpatches_homog_auc_5": auc[4],
+ "hpatches_homog_auc_10": auc[9],
+ }
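+
+
+if __name__ == "__main__":
+    # Minimal usage sketch, not part of the original file. Assumes DKMv3_outdoor
+    # is importable from the dkm package and that the hpatches-sequences-release
+    # folder has been extracted under data/hpatches.
+    from dkm import DKMv3_outdoor
+
+    benchmark = HpatchesHomogBenchmark("data/hpatches")
+    print(benchmark.benchmark(DKMv3_outdoor()))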
diff --git a/imcui/third_party/DKM/dkm/benchmarks/megadepth1500_benchmark.py b/imcui/third_party/DKM/dkm/benchmarks/megadepth1500_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..6b1193745ff18d239165aeb3376642fb17033874
--- /dev/null
+++ b/imcui/third_party/DKM/dkm/benchmarks/megadepth1500_benchmark.py
@@ -0,0 +1,124 @@
+import numpy as np
+import torch
+from dkm.utils import *
+from PIL import Image
+from tqdm import tqdm
+import torch.nn.functional as F
+
+class Megadepth1500Benchmark:
+ def __init__(self, data_root="data/megadepth", scene_names = None) -> None:
+ if scene_names is None:
+ self.scene_names = [
+ "0015_0.1_0.3.npz",
+ "0015_0.3_0.5.npz",
+ "0022_0.1_0.3.npz",
+ "0022_0.3_0.5.npz",
+ "0022_0.5_0.7.npz",
+ ]
+ else:
+ self.scene_names = scene_names
+ self.scenes = [
+ np.load(f"{data_root}/{scene}", allow_pickle=True)
+ for scene in self.scene_names
+ ]
+ self.data_root = data_root
+
+ def benchmark(self, model):
+ with torch.no_grad():
+ data_root = self.data_root
+ tot_e_t, tot_e_R, tot_e_pose = [], [], []
+ for scene_ind in range(len(self.scenes)):
+ scene = self.scenes[scene_ind]
+ pairs = scene["pair_infos"]
+ intrinsics = scene["intrinsics"]
+ poses = scene["poses"]
+ im_paths = scene["image_paths"]
+ pair_inds = range(len(pairs))
+ for pairind in tqdm(pair_inds):
+ idx1, idx2 = pairs[pairind][0]
+ K1 = intrinsics[idx1].copy()
+ T1 = poses[idx1].copy()
+ R1, t1 = T1[:3, :3], T1[:3, 3]
+ K2 = intrinsics[idx2].copy()
+ T2 = poses[idx2].copy()
+ R2, t2 = T2[:3, :3], T2[:3, 3]
+ R, t = compute_relative_pose(R1, t1, R2, t2)
+ im1_path = f"{data_root}/{im_paths[idx1]}"
+ im2_path = f"{data_root}/{im_paths[idx2]}"
+ im1 = Image.open(im1_path)
+ w1, h1 = im1.size
+ im2 = Image.open(im2_path)
+ w2, h2 = im2.size
+ scale1 = 1200 / max(w1, h1)
+ scale2 = 1200 / max(w2, h2)
+ w1, h1 = scale1 * w1, scale1 * h1
+ w2, h2 = scale2 * w2, scale2 * h2
+ K1[:2] = K1[:2] * scale1
+ K2[:2] = K2[:2] * scale2
+ dense_matches, dense_certainty = model.match(im1_path, im2_path)
+ sparse_matches,_ = model.sample(
+ dense_matches, dense_certainty, 5000
+ )
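+                    # sparse_matches are in normalized [-1, 1] coordinates; the two
+                    # stacks below map them to pixel coordinates of the rescaled images.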
+ kpts1 = sparse_matches[:, :2]
+ kpts1 = (
+ torch.stack(
+ (
+ w1 * (kpts1[:, 0] + 1) / 2,
+ h1 * (kpts1[:, 1] + 1) / 2,
+ ),
+ axis=-1,
+ )
+ )
+ kpts2 = sparse_matches[:, 2:]
+ kpts2 = (
+ torch.stack(
+ (
+ w2 * (kpts2[:, 0] + 1) / 2,
+ h2 * (kpts2[:, 1] + 1) / 2,
+ ),
+ axis=-1,
+ )
+ )
+ for _ in range(5):
+ shuffling = np.random.permutation(np.arange(len(kpts1)))
+ kpts1 = kpts1[shuffling]
+ kpts2 = kpts2[shuffling]
+ try:
+ norm_threshold = 0.5 / (
+ np.mean(np.abs(K1[:2, :2])) + np.mean(np.abs(K2[:2, :2])))
+ R_est, t_est, mask = estimate_pose(
+ kpts1.cpu().numpy(),
+ kpts2.cpu().numpy(),
+ K1,
+ K2,
+ norm_threshold,
+ conf=0.99999,
+ )
+ T1_to_2_est = np.concatenate((R_est, t_est), axis=-1) #
+ e_t, e_R = compute_pose_error(T1_to_2_est, R, t)
+ e_pose = max(e_t, e_R)
+ except Exception as e:
+ print(repr(e))
+ e_t, e_R = 90, 90
+ e_pose = max(e_t, e_R)
+ tot_e_t.append(e_t)
+ tot_e_R.append(e_R)
+ tot_e_pose.append(e_pose)
+ tot_e_pose = np.array(tot_e_pose)
+ thresholds = [5, 10, 20]
+ auc = pose_auc(tot_e_pose, thresholds)
+ acc_5 = (tot_e_pose < 5).mean()
+ acc_10 = (tot_e_pose < 10).mean()
+ acc_15 = (tot_e_pose < 15).mean()
+ acc_20 = (tot_e_pose < 20).mean()
+ map_5 = acc_5
+ map_10 = np.mean([acc_5, acc_10])
+ map_20 = np.mean([acc_5, acc_10, acc_15, acc_20])
+ return {
+ "auc_5": auc[0],
+ "auc_10": auc[1],
+ "auc_20": auc[2],
+ "map_5": map_5,
+ "map_10": map_10,
+ "map_20": map_20,
+ }
diff --git a/imcui/third_party/DKM/dkm/benchmarks/megadepth_dense_benchmark.py b/imcui/third_party/DKM/dkm/benchmarks/megadepth_dense_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..0b370644497efd62563105e68e692e10ff339669
--- /dev/null
+++ b/imcui/third_party/DKM/dkm/benchmarks/megadepth_dense_benchmark.py
@@ -0,0 +1,86 @@
+import torch
+import numpy as np
+import tqdm
+from dkm.datasets import MegadepthBuilder
+from dkm.utils import warp_kpts
+from torch.utils.data import ConcatDataset
+
+
+class MegadepthDenseBenchmark:
+ def __init__(self, data_root="data/megadepth", h = 384, w = 512, num_samples = 2000, device=None) -> None:
+ mega = MegadepthBuilder(data_root=data_root)
+ self.dataset = ConcatDataset(
+ mega.build_scenes(split="test_loftr", ht=h, wt=w)
+ ) # fixed resolution of 384,512
+ self.num_samples = num_samples
+ if device is None:
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ self.device = device
+
+ def geometric_dist(self, depth1, depth2, T_1to2, K1, K2, dense_matches):
+ b, h1, w1, d = dense_matches.shape
+ with torch.no_grad():
+ x1 = dense_matches[..., :2].reshape(b, h1 * w1, 2)
+ # x1 = torch.stack((2*x1[...,0]/w1-1,2*x1[...,1]/h1-1),dim=-1)
+ mask, x2 = warp_kpts(
+ x1.double(),
+ depth1.double(),
+ depth2.double(),
+ T_1to2.double(),
+ K1.double(),
+ K2.double(),
+ )
+ x2 = torch.stack(
+ (w1 * (x2[..., 0] + 1) / 2, h1 * (x2[..., 1] + 1) / 2), dim=-1
+ )
+ prob = mask.float().reshape(b, h1, w1)
+ x2_hat = dense_matches[..., 2:]
+ x2_hat = torch.stack(
+ (w1 * (x2_hat[..., 0] + 1) / 2, h1 * (x2_hat[..., 1] + 1) / 2), dim=-1
+ )
+ gd = (x2_hat - x2.reshape(b, h1, w1, 2)).norm(dim=-1)
+ gd = gd[prob == 1]
+ pck_1 = (gd < 1.0).float().mean()
+ pck_3 = (gd < 3.0).float().mean()
+ pck_5 = (gd < 5.0).float().mean()
+ gd = gd.mean()
+ return gd, pck_1, pck_3, pck_5
+
+ def benchmark(self, model, batch_size=8):
+ model.train(False)
+ with torch.no_grad():
+ gd_tot = 0.0
+ pck_1_tot = 0.0
+ pck_3_tot = 0.0
+ pck_5_tot = 0.0
+ sampler = torch.utils.data.WeightedRandomSampler(
+ torch.ones(len(self.dataset)), replacement=False, num_samples=self.num_samples
+ )
+ dataloader = torch.utils.data.DataLoader(
+                self.dataset, batch_size=batch_size, num_workers=batch_size, sampler=sampler
+ )
+ for data in tqdm.tqdm(dataloader):
+ im1, im2, depth1, depth2, T_1to2, K1, K2 = (
+ data["query"],
+ data["support"],
+ data["query_depth"].to(self.device),
+ data["support_depth"].to(self.device),
+ data["T_1to2"].to(self.device),
+ data["K1"].to(self.device),
+ data["K2"].to(self.device),
+ )
+ matches, certainty = model.match(im1, im2, batched=True)
+ gd, pck_1, pck_3, pck_5 = self.geometric_dist(
+ depth1, depth2, T_1to2, K1, K2, matches
+ )
+ gd_tot, pck_1_tot, pck_3_tot, pck_5_tot = (
+ gd_tot + gd,
+ pck_1_tot + pck_1,
+ pck_3_tot + pck_3,
+ pck_5_tot + pck_5,
+ )
+ return {
+ "mega_pck_1": pck_1_tot.item() / len(dataloader),
+ "mega_pck_3": pck_3_tot.item() / len(dataloader),
+ "mega_pck_5": pck_5_tot.item() / len(dataloader),
+ }
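+
+
+if __name__ == "__main__":
+    # Minimal usage sketch, not part of the original file. Assumes DKMv3_outdoor
+    # is importable from the dkm package and that MegaDepth (including
+    # prep_scene_info) is available under data/megadepth.
+    from dkm import DKMv3_outdoor
+
+    benchmark = MegadepthDenseBenchmark("data/megadepth", num_samples=200)
+    print(benchmark.benchmark(DKMv3_outdoor(), batch_size=4))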
diff --git a/imcui/third_party/DKM/dkm/benchmarks/scannet_benchmark.py b/imcui/third_party/DKM/dkm/benchmarks/scannet_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..ca938cb462c351845ce035f8be0714cf81214452
--- /dev/null
+++ b/imcui/third_party/DKM/dkm/benchmarks/scannet_benchmark.py
@@ -0,0 +1,143 @@
+import os.path as osp
+import numpy as np
+import torch
+from dkm.utils import *
+from PIL import Image
+from tqdm import tqdm
+
+
+class ScanNetBenchmark:
+ def __init__(self, data_root="data/scannet") -> None:
+ self.data_root = data_root
+
+ def benchmark(self, model, model_name = None):
+ model.train(False)
+ with torch.no_grad():
+ data_root = self.data_root
+ tmp = np.load(osp.join(data_root, "test.npz"))
+ pairs, rel_pose = tmp["name"], tmp["rel_pose"]
+ tot_e_t, tot_e_R, tot_e_pose = [], [], []
+ pair_inds = np.random.choice(
+ range(len(pairs)), size=len(pairs), replace=False
+ )
+ for pairind in tqdm(pair_inds, smoothing=0.9):
+ scene = pairs[pairind]
+ scene_name = f"scene0{scene[0]}_00"
+ im1_path = osp.join(
+ self.data_root,
+ "scans_test",
+ scene_name,
+ "color",
+ f"{scene[2]}.jpg",
+ )
+ im1 = Image.open(im1_path)
+ im2_path = osp.join(
+ self.data_root,
+ "scans_test",
+ scene_name,
+ "color",
+ f"{scene[3]}.jpg",
+ )
+ im2 = Image.open(im2_path)
+ T_gt = rel_pose[pairind].reshape(3, 4)
+ R, t = T_gt[:3, :3], T_gt[:3, 3]
+ K = np.stack(
+ [
+ np.array([float(i) for i in r.split()])
+ for r in open(
+ osp.join(
+ self.data_root,
+ "scans_test",
+ scene_name,
+ "intrinsic",
+ "intrinsic_color.txt",
+ ),
+ "r",
+ )
+ .read()
+ .split("\n")
+ if r
+ ]
+ )
+ w1, h1 = im1.size
+ w2, h2 = im2.size
+ K1 = K.copy()
+ K2 = K.copy()
+ dense_matches, dense_certainty = model.match(im1_path, im2_path)
+ sparse_matches, sparse_certainty = model.sample(
+ dense_matches, dense_certainty, 5000
+ )
+ scale1 = 480 / min(w1, h1)
+ scale2 = 480 / min(w2, h2)
+ w1, h1 = scale1 * w1, scale1 * h1
+ w2, h2 = scale2 * w2, scale2 * h2
+ K1 = K1 * scale1
+ K2 = K2 * scale2
+
+ offset = 0.5
+ kpts1 = sparse_matches[:, :2]
+ kpts1 = (
+ np.stack(
+ (
+ w1 * (kpts1[:, 0] + 1) / 2 - offset,
+ h1 * (kpts1[:, 1] + 1) / 2 - offset,
+ ),
+ axis=-1,
+ )
+ )
+ kpts2 = sparse_matches[:, 2:]
+ kpts2 = (
+ np.stack(
+ (
+ w2 * (kpts2[:, 0] + 1) / 2 - offset,
+ h2 * (kpts2[:, 1] + 1) / 2 - offset,
+ ),
+ axis=-1,
+ )
+ )
+ for _ in range(5):
+ shuffling = np.random.permutation(np.arange(len(kpts1)))
+ kpts1 = kpts1[shuffling]
+ kpts2 = kpts2[shuffling]
+ try:
+ norm_threshold = 0.5 / (
+ np.mean(np.abs(K1[:2, :2])) + np.mean(np.abs(K2[:2, :2])))
+ R_est, t_est, mask = estimate_pose(
+ kpts1,
+ kpts2,
+ K1,
+ K2,
+ norm_threshold,
+ conf=0.99999,
+ )
+ T1_to_2_est = np.concatenate((R_est, t_est), axis=-1) #
+ e_t, e_R = compute_pose_error(T1_to_2_est, R, t)
+ e_pose = max(e_t, e_R)
+ except Exception as e:
+ print(repr(e))
+ e_t, e_R = 90, 90
+ e_pose = max(e_t, e_R)
+ tot_e_t.append(e_t)
+ tot_e_R.append(e_R)
+ tot_e_pose.append(e_pose)
+ tot_e_t.append(e_t)
+ tot_e_R.append(e_R)
+ tot_e_pose.append(e_pose)
+ tot_e_pose = np.array(tot_e_pose)
+ thresholds = [5, 10, 20]
+ auc = pose_auc(tot_e_pose, thresholds)
+ acc_5 = (tot_e_pose < 5).mean()
+ acc_10 = (tot_e_pose < 10).mean()
+ acc_15 = (tot_e_pose < 15).mean()
+ acc_20 = (tot_e_pose < 20).mean()
+ map_5 = acc_5
+ map_10 = np.mean([acc_5, acc_10])
+ map_20 = np.mean([acc_5, acc_10, acc_15, acc_20])
+ return {
+ "auc_5": auc[0],
+ "auc_10": auc[1],
+ "auc_20": auc[2],
+ "map_5": map_5,
+ "map_10": map_10,
+ "map_20": map_20,
+ }
diff --git a/imcui/third_party/DKM/dkm/checkpointing/__init__.py b/imcui/third_party/DKM/dkm/checkpointing/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..22f5afe727aa6f6e8fffa9ecf5be69cbff686577
--- /dev/null
+++ b/imcui/third_party/DKM/dkm/checkpointing/__init__.py
@@ -0,0 +1 @@
+from .checkpoint import CheckPoint
diff --git a/imcui/third_party/DKM/dkm/checkpointing/checkpoint.py b/imcui/third_party/DKM/dkm/checkpointing/checkpoint.py
new file mode 100644
index 0000000000000000000000000000000000000000..715eeb587ebb87ed0d1bcf9940e048adbe35cde2
--- /dev/null
+++ b/imcui/third_party/DKM/dkm/checkpointing/checkpoint.py
@@ -0,0 +1,31 @@
+import os
+import torch
+from torch.nn.parallel.data_parallel import DataParallel
+from torch.nn.parallel.distributed import DistributedDataParallel
+from loguru import logger
+
+
+class CheckPoint:
+ def __init__(self, dir=None, name="tmp"):
+ self.name = name
+ self.dir = dir
+ os.makedirs(self.dir, exist_ok=True)
+
+ def __call__(
+ self,
+ model,
+ optimizer,
+ lr_scheduler,
+ n,
+ ):
+ assert model is not None
+ if isinstance(model, (DataParallel, DistributedDataParallel)):
+ model = model.module
+ states = {
+ "model": model.state_dict(),
+ "n": n,
+ "optimizer": optimizer.state_dict(),
+ "lr_scheduler": lr_scheduler.state_dict(),
+ }
+        torch.save(states, os.path.join(self.dir, f"{self.name}_latest.pth"))
+ logger.info(f"Saved states {list(states.keys())}, at step {n}")
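+
+
+if __name__ == "__main__":
+    # Minimal usage sketch, not part of the original file; the model, optimizer
+    # and scheduler below are stand-ins.
+    model = torch.nn.Linear(2, 2)
+    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
+    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
+    checkpointer = CheckPoint(dir="workspace/checkpoints/", name="dkm_run")
+    checkpointer(model, optimizer, lr_scheduler, n=0)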
diff --git a/imcui/third_party/DKM/dkm/datasets/__init__.py b/imcui/third_party/DKM/dkm/datasets/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..6b81083212edaf345c30f0cb1116c5f9de284ce6
--- /dev/null
+++ b/imcui/third_party/DKM/dkm/datasets/__init__.py
@@ -0,0 +1 @@
+from .megadepth import MegadepthBuilder
diff --git a/imcui/third_party/DKM/dkm/datasets/megadepth.py b/imcui/third_party/DKM/dkm/datasets/megadepth.py
new file mode 100644
index 0000000000000000000000000000000000000000..c580607e910ce1926b7711b5473aa82b20865369
--- /dev/null
+++ b/imcui/third_party/DKM/dkm/datasets/megadepth.py
@@ -0,0 +1,177 @@
+import os
+import random
+from PIL import Image
+import h5py
+import numpy as np
+import torch
+from torch.utils.data import Dataset, DataLoader, ConcatDataset
+
+from dkm.utils import get_depth_tuple_transform_ops, get_tuple_transform_ops
+import torchvision.transforms.functional as tvf
+from dkm.utils.transforms import GeometricSequential
+import kornia.augmentation as K
+
+
+class MegadepthScene:
+ def __init__(
+ self,
+ data_root,
+ scene_info,
+ ht=384,
+ wt=512,
+ min_overlap=0.0,
+ shake_t=0,
+ rot_prob=0.0,
+ normalize=True,
+ ) -> None:
+ self.data_root = data_root
+ self.image_paths = scene_info["image_paths"]
+ self.depth_paths = scene_info["depth_paths"]
+ self.intrinsics = scene_info["intrinsics"]
+ self.poses = scene_info["poses"]
+ self.pairs = scene_info["pairs"]
+ self.overlaps = scene_info["overlaps"]
+ threshold = self.overlaps > min_overlap
+ self.pairs = self.pairs[threshold]
+ self.overlaps = self.overlaps[threshold]
+ if len(self.pairs) > 100000:
+ pairinds = np.random.choice(
+ np.arange(0, len(self.pairs)), 100000, replace=False
+ )
+ self.pairs = self.pairs[pairinds]
+ self.overlaps = self.overlaps[pairinds]
+ # counts, bins = np.histogram(self.overlaps,20)
+ # print(counts)
+ self.im_transform_ops = get_tuple_transform_ops(
+ resize=(ht, wt), normalize=normalize
+ )
+ self.depth_transform_ops = get_depth_tuple_transform_ops(
+ resize=(ht, wt), normalize=False
+ )
+ self.wt, self.ht = wt, ht
+ self.shake_t = shake_t
+ self.H_generator = GeometricSequential(K.RandomAffine(degrees=90, p=rot_prob))
+
+ def load_im(self, im_ref, crop=None):
+ im = Image.open(im_ref)
+ return im
+
+ def load_depth(self, depth_ref, crop=None):
+ depth = np.array(h5py.File(depth_ref, "r")["depth"])
+ return torch.from_numpy(depth)
+
+ def __len__(self):
+ return len(self.pairs)
+
+ def scale_intrinsic(self, K, wi, hi):
+ sx, sy = self.wt / wi, self.ht / hi
+ sK = torch.tensor([[sx, 0, 0], [0, sy, 0], [0, 0, 1]])
+ return sK @ K
+
+ def rand_shake(self, *things):
+ t = np.random.choice(range(-self.shake_t, self.shake_t + 1), size=2)
+ return [
+ tvf.affine(thing, angle=0.0, translate=list(t), scale=1.0, shear=[0.0, 0.0])
+ for thing in things
+ ], t
+
+ def __getitem__(self, pair_idx):
+ # read intrinsics of original size
+ idx1, idx2 = self.pairs[pair_idx]
+ K1 = torch.tensor(self.intrinsics[idx1].copy(), dtype=torch.float).reshape(3, 3)
+ K2 = torch.tensor(self.intrinsics[idx2].copy(), dtype=torch.float).reshape(3, 3)
+
+ # read and compute relative poses
+ T1 = self.poses[idx1]
+ T2 = self.poses[idx2]
+ T_1to2 = torch.tensor(np.matmul(T2, np.linalg.inv(T1)), dtype=torch.float)[
+ :4, :4
+ ] # (4, 4)
+
+ # Load positive pair data
+ im1, im2 = self.image_paths[idx1], self.image_paths[idx2]
+ depth1, depth2 = self.depth_paths[idx1], self.depth_paths[idx2]
+ im_src_ref = os.path.join(self.data_root, im1)
+ im_pos_ref = os.path.join(self.data_root, im2)
+ depth_src_ref = os.path.join(self.data_root, depth1)
+ depth_pos_ref = os.path.join(self.data_root, depth2)
+ # return torch.randn((1000,1000))
+ im_src = self.load_im(im_src_ref)
+ im_pos = self.load_im(im_pos_ref)
+ depth_src = self.load_depth(depth_src_ref)
+ depth_pos = self.load_depth(depth_pos_ref)
+
+ # Recompute camera intrinsic matrix due to the resize
+ K1 = self.scale_intrinsic(K1, im_src.width, im_src.height)
+ K2 = self.scale_intrinsic(K2, im_pos.width, im_pos.height)
+ # Process images
+ im_src, im_pos = self.im_transform_ops((im_src, im_pos))
+ depth_src, depth_pos = self.depth_transform_ops(
+ (depth_src[None, None], depth_pos[None, None])
+ )
+ [im_src, im_pos, depth_src, depth_pos], t = self.rand_shake(
+ im_src, im_pos, depth_src, depth_pos
+ )
+ im_src, Hq = self.H_generator(im_src[None])
+ depth_src = self.H_generator.apply_transform(depth_src, Hq)
+ K1[:2, 2] += t
+ K2[:2, 2] += t
+ K1 = Hq[0] @ K1
+ data_dict = {
+ "query": im_src[0],
+ "query_identifier": self.image_paths[idx1].split("/")[-1].split(".jpg")[0],
+ "support": im_pos,
+ "support_identifier": self.image_paths[idx2]
+ .split("/")[-1]
+ .split(".jpg")[0],
+ "query_depth": depth_src[0, 0],
+ "support_depth": depth_pos[0, 0],
+ "K1": K1,
+ "K2": K2,
+ "T_1to2": T_1to2,
+ }
+ return data_dict
+
+
+class MegadepthBuilder:
+ def __init__(self, data_root="data/megadepth") -> None:
+ self.data_root = data_root
+ self.scene_info_root = os.path.join(data_root, "prep_scene_info")
+ self.all_scenes = os.listdir(self.scene_info_root)
+ self.test_scenes = ["0017.npy", "0004.npy", "0048.npy", "0013.npy"]
+ self.test_scenes_loftr = ["0015.npy", "0022.npy"]
+
+ def build_scenes(self, split="train", min_overlap=0.0, **kwargs):
+ if split == "train":
+ scene_names = set(self.all_scenes) - set(self.test_scenes)
+ elif split == "train_loftr":
+ scene_names = set(self.all_scenes) - set(self.test_scenes_loftr)
+ elif split == "test":
+ scene_names = self.test_scenes
+ elif split == "test_loftr":
+ scene_names = self.test_scenes_loftr
+ else:
+ raise ValueError(f"Split {split} not available")
+ scenes = []
+ for scene_name in scene_names:
+ scene_info = np.load(
+ os.path.join(self.scene_info_root, scene_name), allow_pickle=True
+ ).item()
+ scenes.append(
+ MegadepthScene(
+ self.data_root, scene_info, min_overlap=min_overlap, **kwargs
+ )
+ )
+ return scenes
+
+ def weight_scenes(self, concat_dataset, alpha=0.5):
+ ns = []
+ for d in concat_dataset.datasets:
+ ns.append(len(d))
+ ws = torch.cat([torch.ones(n) / n**alpha for n in ns])
+ return ws
+
+
+if __name__ == "__main__":
+ mega_test = ConcatDataset(MegadepthBuilder().build_scenes(split="train"))
+ mega_test[0]
diff --git a/imcui/third_party/DKM/dkm/datasets/scannet.py b/imcui/third_party/DKM/dkm/datasets/scannet.py
new file mode 100644
index 0000000000000000000000000000000000000000..6ac39b41480f7585c4755cc30e0677ef74ed5e0c
--- /dev/null
+++ b/imcui/third_party/DKM/dkm/datasets/scannet.py
@@ -0,0 +1,151 @@
+import os
+import random
+from PIL import Image
+import cv2
+import h5py
+import numpy as np
+import torch
+from torch.utils.data import (
+ Dataset,
+ DataLoader,
+ ConcatDataset)
+
+import torchvision.transforms.functional as tvf
+import kornia.augmentation as K
+import os.path as osp
+import matplotlib.pyplot as plt
+from dkm.utils import get_depth_tuple_transform_ops, get_tuple_transform_ops
+from dkm.utils.transforms import GeometricSequential
+
+from tqdm import tqdm
+
+class ScanNetScene:
+ def __init__(self, data_root, scene_info, ht = 384, wt = 512, min_overlap=0., shake_t = 0, rot_prob=0.) -> None:
+ self.scene_root = osp.join(data_root,"scans","scans_train")
+ self.data_names = scene_info['name']
+ self.overlaps = scene_info['score']
+        # Only keep pairs whose two frame indices are both multiples of 10
+ valid = (self.data_names[:,-2:] % 10).sum(axis=-1) == 0
+ self.overlaps = self.overlaps[valid]
+ self.data_names = self.data_names[valid]
+ if len(self.data_names) > 10000:
+ pairinds = np.random.choice(np.arange(0,len(self.data_names)),10000,replace=False)
+ self.data_names = self.data_names[pairinds]
+ self.overlaps = self.overlaps[pairinds]
+ self.im_transform_ops = get_tuple_transform_ops(resize=(ht, wt), normalize=True)
+ self.depth_transform_ops = get_depth_tuple_transform_ops(resize=(ht, wt), normalize=False)
+ self.wt, self.ht = wt, ht
+ self.shake_t = shake_t
+ self.H_generator = GeometricSequential(K.RandomAffine(degrees=90, p=rot_prob))
+
+ def load_im(self, im_ref, crop=None):
+ im = Image.open(im_ref)
+ return im
+
+ def load_depth(self, depth_ref, crop=None):
+ depth = cv2.imread(str(depth_ref), cv2.IMREAD_UNCHANGED)
+ depth = depth / 1000
+ depth = torch.from_numpy(depth).float() # (h, w)
+ return depth
+
+ def __len__(self):
+ return len(self.data_names)
+
+ def scale_intrinsic(self, K, wi, hi):
+ sx, sy = self.wt / wi, self.ht / hi
+ sK = torch.tensor([[sx, 0, 0],
+ [0, sy, 0],
+ [0, 0, 1]])
+ return sK@K
+
+ def read_scannet_pose(self,path):
+ """ Read ScanNet's Camera2World pose and transform it to World2Camera.
+
+ Returns:
+ pose_w2c (np.ndarray): (4, 4)
+ """
+ cam2world = np.loadtxt(path, delimiter=' ')
+ world2cam = np.linalg.inv(cam2world)
+ return world2cam
+
+
+ def read_scannet_intrinsic(self,path):
+ """ Read ScanNet's intrinsic matrix and return the 3x3 matrix.
+ """
+ intrinsic = np.loadtxt(path, delimiter=' ')
+ return intrinsic[:-1, :-1]
+
+ def __getitem__(self, pair_idx):
+ # read intrinsics of original size
+ data_name = self.data_names[pair_idx]
+ scene_name, scene_sub_name, stem_name_1, stem_name_2 = data_name
+ scene_name = f'scene{scene_name:04d}_{scene_sub_name:02d}'
+
+ # read the intrinsic of depthmap
+ K1 = K2 = self.read_scannet_intrinsic(osp.join(self.scene_root,
+ scene_name,
+ 'intrinsic', 'intrinsic_color.txt'))#the depth K is not the same, but doesnt really matter
+ # read and compute relative poses
+ T1 = self.read_scannet_pose(osp.join(self.scene_root,
+ scene_name,
+ 'pose', f'{stem_name_1}.txt'))
+ T2 = self.read_scannet_pose(osp.join(self.scene_root,
+ scene_name,
+ 'pose', f'{stem_name_2}.txt'))
+ T_1to2 = torch.tensor(np.matmul(T2, np.linalg.inv(T1)), dtype=torch.float)[:4, :4] # (4, 4)
+
+ # Load positive pair data
+ im_src_ref = os.path.join(self.scene_root, scene_name, 'color', f'{stem_name_1}.jpg')
+ im_pos_ref = os.path.join(self.scene_root, scene_name, 'color', f'{stem_name_2}.jpg')
+ depth_src_ref = os.path.join(self.scene_root, scene_name, 'depth', f'{stem_name_1}.png')
+ depth_pos_ref = os.path.join(self.scene_root, scene_name, 'depth', f'{stem_name_2}.png')
+
+ im_src = self.load_im(im_src_ref)
+ im_pos = self.load_im(im_pos_ref)
+ depth_src = self.load_depth(depth_src_ref)
+ depth_pos = self.load_depth(depth_pos_ref)
+
+ # Recompute camera intrinsic matrix due to the resize
+ K1 = self.scale_intrinsic(K1, im_src.width, im_src.height)
+ K2 = self.scale_intrinsic(K2, im_pos.width, im_pos.height)
+ # Process images
+ im_src, im_pos = self.im_transform_ops((im_src, im_pos))
+ depth_src, depth_pos = self.depth_transform_ops((depth_src[None,None], depth_pos[None,None]))
+
+ data_dict = {'query': im_src,
+ 'support': im_pos,
+ 'query_depth': depth_src[0,0],
+ 'support_depth': depth_pos[0,0],
+ 'K1': K1,
+ 'K2': K2,
+ 'T_1to2':T_1to2,
+ }
+ return data_dict
+
+
+class ScanNetBuilder:
+ def __init__(self, data_root = 'data/scannet') -> None:
+ self.data_root = data_root
+ self.scene_info_root = os.path.join(data_root,'scannet_indices')
+ self.all_scenes = os.listdir(self.scene_info_root)
+
+ def build_scenes(self, split = 'train', min_overlap=0., **kwargs):
+ # Note: split doesn't matter here as we always use same scannet_train scenes
+ scene_names = self.all_scenes
+ scenes = []
+ for scene_name in tqdm(scene_names):
+ scene_info = np.load(os.path.join(self.scene_info_root,scene_name), allow_pickle=True)
+ scenes.append(ScanNetScene(self.data_root, scene_info, min_overlap=min_overlap, **kwargs))
+ return scenes
+
+ def weight_scenes(self, concat_dataset, alpha=.5):
+ ns = []
+ for d in concat_dataset.datasets:
+ ns.append(len(d))
+ ws = torch.cat([torch.ones(n)/n**alpha for n in ns])
+ return ws
+
+
+if __name__ == "__main__":
+ mega_test = ConcatDataset(ScanNetBuilder("data/scannet").build_scenes(split='train'))
+ mega_test[0]
\ No newline at end of file
diff --git a/imcui/third_party/DKM/dkm/losses/__init__.py b/imcui/third_party/DKM/dkm/losses/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..71914f50d891079d204a07c57367159888f892de
--- /dev/null
+++ b/imcui/third_party/DKM/dkm/losses/__init__.py
@@ -0,0 +1 @@
+from .depth_match_regression_loss import DepthRegressionLoss
diff --git a/imcui/third_party/DKM/dkm/losses/depth_match_regression_loss.py b/imcui/third_party/DKM/dkm/losses/depth_match_regression_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..80da70347b4b4addc721e2a14ed489f8683fd48a
--- /dev/null
+++ b/imcui/third_party/DKM/dkm/losses/depth_match_regression_loss.py
@@ -0,0 +1,128 @@
+from einops.einops import rearrange
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from dkm.utils.utils import warp_kpts
+
+
+class DepthRegressionLoss(nn.Module):
+ def __init__(
+ self,
+ robust=True,
+ center_coords=False,
+ scale_normalize=False,
+ ce_weight=0.01,
+ local_loss=True,
+ local_dist=4.0,
+ local_largest_scale=8,
+ ):
+ super().__init__()
+ self.robust = robust # measured in pixels
+ self.center_coords = center_coords
+ self.scale_normalize = scale_normalize
+ self.ce_weight = ce_weight
+ self.local_loss = local_loss
+ self.local_dist = local_dist
+ self.local_largest_scale = local_largest_scale
+
+ def geometric_dist(self, depth1, depth2, T_1to2, K1, K2, dense_matches, scale):
+        """Compute the per-pixel geometric distance between the predicted dense
+        matches and the ground-truth correspondences obtained by warping a
+        regular grid with the query depth.
+
+        Args:
+            depth1, depth2: query and support depth maps.
+            T_1to2: relative pose from the query to the support frame.
+            K1, K2: camera intrinsics.
+            dense_matches: predicted dense flow, shape (b, h, w, 2), in normalized coordinates.
+            scale: current refinement scale (not used in the computation itself).
+
+        Returns:
+            gd: geometric distance per pixel, shape (b, h, w).
+            prob: validity mask of the ground-truth warp, shape (b, h, w).
+        """
+ b, h1, w1, d = dense_matches.shape
+ with torch.no_grad():
+ x1_n = torch.meshgrid(
+ *[
+ torch.linspace(
+ -1 + 1 / n, 1 - 1 / n, n, device=dense_matches.device
+ )
+ for n in (b, h1, w1)
+ ]
+ )
+ x1_n = torch.stack((x1_n[2], x1_n[1]), dim=-1).reshape(b, h1 * w1, 2)
+ mask, x2 = warp_kpts(
+ x1_n.double(),
+ depth1.double(),
+ depth2.double(),
+ T_1to2.double(),
+ K1.double(),
+ K2.double(),
+ )
+ prob = mask.float().reshape(b, h1, w1)
+ gd = (dense_matches - x2.reshape(b, h1, w1, 2)).norm(dim=-1) # *scale?
+ return gd, prob
+
+ def dense_depth_loss(self, dense_certainty, prob, gd, scale, eps=1e-8):
+        """Compute the certainty (BCE) loss and the depth-regression loss at one scale.
+
+        Args:
+            dense_certainty: predicted certainty logits; channel 0 is used, shape (b, 1, h, w).
+            prob: ground-truth validity mask, shape (b, h, w).
+            gd: per-pixel geometric distance, shape (b, h, w).
+            scale: current refinement scale, used to name the returned losses.
+            eps: small constant, currently unused. Defaults to 1e-8.
+
+        Returns:
+            dict with keys f"ce_loss_{scale}" and f"depth_loss_{scale}".
+        """
+ smooth_prob = prob
+ ce_loss = F.binary_cross_entropy_with_logits(dense_certainty[:, 0], smooth_prob)
+ depth_loss = gd[prob > 0]
+ if not torch.any(prob > 0).item():
+ depth_loss = (gd * 0.0).mean() # Prevent issues where prob is 0 everywhere
+ return {
+ f"ce_loss_{scale}": ce_loss.mean(),
+ f"depth_loss_{scale}": depth_loss.mean(),
+ }
+
+ def forward(self, dense_corresps, batch):
+        """Aggregate the depth-regression loss over all predicted scales.
+
+        Args:
+            dense_corresps: dict mapping scale -> {"dense_certainty", "dense_flow"}.
+            batch: dict with query/support depths, intrinsics K1/K2 and pose T_1to2.
+
+        Returns:
+            Scalar total loss, summed over scales.
+        """
+ scales = list(dense_corresps.keys())
+ tot_loss = 0.0
+ prev_gd = 0.0
+ for scale in scales:
+ dense_scale_corresps = dense_corresps[scale]
+ dense_scale_certainty, dense_scale_coords = (
+ dense_scale_corresps["dense_certainty"],
+ dense_scale_corresps["dense_flow"],
+ )
+ dense_scale_coords = rearrange(dense_scale_coords, "b d h w -> b h w d")
+ b, h, w, d = dense_scale_coords.shape
+ gd, prob = self.geometric_dist(
+ batch["query_depth"],
+ batch["support_depth"],
+ batch["T_1to2"],
+ batch["K1"],
+ batch["K2"],
+ dense_scale_coords,
+ scale,
+ )
+ if (
+ scale <= self.local_largest_scale and self.local_loss
+ ): # Thought here is that fine matching loss should not be punished by coarse mistakes, but should identify wrong matching
+ prob = prob * (
+ F.interpolate(prev_gd[:, None], size=(h, w), mode="nearest")[:, 0]
+ < (2 / 512) * (self.local_dist * scale)
+ )
+ depth_losses = self.dense_depth_loss(dense_scale_certainty, prob, gd, scale)
+ scale_loss = (
+ self.ce_weight * depth_losses[f"ce_loss_{scale}"]
+ + depth_losses[f"depth_loss_{scale}"]
+ ) # scale ce loss for coarser scales
+ if self.scale_normalize:
+ scale_loss = scale_loss * 1 / scale
+ tot_loss = tot_loss + scale_loss
+ prev_gd = gd.detach()
+ return tot_loss
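+
+
+if __name__ == "__main__":
+    # Structure sketch, not part of the original file: shows the inputs the loss
+    # expects (shapes are illustrative, taken from the forward pass above).
+    loss_fn = DepthRegressionLoss(ce_weight=0.01, local_dist=4.0)
+    # dense_corresps = {scale: {"dense_certainty": (b, 1, h, w) logits,
+    #                           "dense_flow": (b, 2, h, w) coords in [-1, 1]}}
+    # batch = {"query_depth": (b, H, W), "support_depth": (b, H, W),
+    #          "K1": (b, 3, 3), "K2": (b, 3, 3), "T_1to2": (b, 4, 4)}
+    # total = loss_fn(dense_corresps, batch)
+    print(loss_fn)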
diff --git a/imcui/third_party/DKM/dkm/models/__init__.py b/imcui/third_party/DKM/dkm/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d4fc321ec70fd116beca23e94248cb6bbe771523
--- /dev/null
+++ b/imcui/third_party/DKM/dkm/models/__init__.py
@@ -0,0 +1,4 @@
+from .model_zoo import (
+ DKMv3_outdoor,
+ DKMv3_indoor,
+)
diff --git a/imcui/third_party/DKM/dkm/models/deprecated/build_model.py b/imcui/third_party/DKM/dkm/models/deprecated/build_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..dd28335f3e348ab6c90b26ba91b95e864b0bbbb9
--- /dev/null
+++ b/imcui/third_party/DKM/dkm/models/deprecated/build_model.py
@@ -0,0 +1,787 @@
+import torch
+import torch.nn as nn
+from dkm import *
+from .local_corr import LocalCorr
+from .corr_channels import NormedCorr
+from torchvision.models import resnet as tv_resnet
+
+# Default device for the model builders in this module.
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+dkm_pretrained_urls = {
+ "DKM": {
+ "mega_synthetic": "https://github.com/Parskatt/storage/releases/download/dkm_mega_synthetic/dkm_mega_synthetic.pth",
+ "mega": "https://github.com/Parskatt/storage/releases/download/dkm_mega/dkm_mega.pth",
+ },
+ "DKMv2":{
+ "outdoor": "https://github.com/Parskatt/storage/releases/download/dkmv2/dkm_v2_outdoor.pth",
+ "indoor": "https://github.com/Parskatt/storage/releases/download/dkmv2/dkm_v2_indoor.pth",
+ }
+}
+
+
+def DKM(pretrained=True, version="mega_synthetic", device=None):
+ if device is None:
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ gp_dim = 256
+ dfn_dim = 384
+ feat_dim = 256
+ coordinate_decoder = DFN(
+ internal_dim=dfn_dim,
+ feat_input_modules=nn.ModuleDict(
+ {
+ "32": nn.Conv2d(512, feat_dim, 1, 1),
+ "16": nn.Conv2d(512, feat_dim, 1, 1),
+ }
+ ),
+ pred_input_modules=nn.ModuleDict(
+ {
+ "32": nn.Identity(),
+ "16": nn.Identity(),
+ }
+ ),
+ rrb_d_dict=nn.ModuleDict(
+ {
+ "32": RRB(gp_dim + feat_dim, dfn_dim),
+ "16": RRB(gp_dim + feat_dim, dfn_dim),
+ }
+ ),
+ cab_dict=nn.ModuleDict(
+ {
+ "32": CAB(2 * dfn_dim, dfn_dim),
+ "16": CAB(2 * dfn_dim, dfn_dim),
+ }
+ ),
+ rrb_u_dict=nn.ModuleDict(
+ {
+ "32": RRB(dfn_dim, dfn_dim),
+ "16": RRB(dfn_dim, dfn_dim),
+ }
+ ),
+ terminal_module=nn.ModuleDict(
+ {
+ "32": nn.Conv2d(dfn_dim, 3, 1, 1, 0),
+ "16": nn.Conv2d(dfn_dim, 3, 1, 1, 0),
+ }
+ ),
+ )
+ dw = True
+ hidden_blocks = 8
+ kernel_size = 5
+ conv_refiner = nn.ModuleDict(
+ {
+ "16": ConvRefiner(
+ 2 * 512,
+ 1024,
+ 3,
+ kernel_size=kernel_size,
+ dw=dw,
+ hidden_blocks=hidden_blocks,
+ ),
+ "8": ConvRefiner(
+ 2 * 512,
+ 1024,
+ 3,
+ kernel_size=kernel_size,
+ dw=dw,
+ hidden_blocks=hidden_blocks,
+ ),
+ "4": ConvRefiner(
+ 2 * 256,
+ 512,
+ 3,
+ kernel_size=kernel_size,
+ dw=dw,
+ hidden_blocks=hidden_blocks,
+ ),
+ "2": ConvRefiner(
+ 2 * 64,
+ 128,
+ 3,
+ kernel_size=kernel_size,
+ dw=dw,
+ hidden_blocks=hidden_blocks,
+ ),
+ "1": ConvRefiner(
+ 2 * 3,
+ 24,
+ 3,
+ kernel_size=kernel_size,
+ dw=dw,
+ hidden_blocks=hidden_blocks,
+ ),
+ }
+ )
+ kernel_temperature = 0.2
+ learn_temperature = False
+ no_cov = True
+ kernel = CosKernel
+ only_attention = False
+ basis = "fourier"
+ gp32 = GP(
+ kernel,
+ T=kernel_temperature,
+ learn_temperature=learn_temperature,
+ only_attention=only_attention,
+ gp_dim=gp_dim,
+ basis=basis,
+ no_cov=no_cov,
+ )
+ gp16 = GP(
+ kernel,
+ T=kernel_temperature,
+ learn_temperature=learn_temperature,
+ only_attention=only_attention,
+ gp_dim=gp_dim,
+ basis=basis,
+ no_cov=no_cov,
+ )
+ gps = nn.ModuleDict({"32": gp32, "16": gp16})
+ proj = nn.ModuleDict(
+ {"16": nn.Conv2d(1024, 512, 1, 1), "32": nn.Conv2d(2048, 512, 1, 1)}
+ )
+ decoder = Decoder(coordinate_decoder, gps, proj, conv_refiner, detach=True)
+ h, w = 384, 512
+ encoder = Encoder(
+ tv_resnet.resnet50(pretrained=not pretrained),
+ ) # only load pretrained weights if not loading a pretrained matcher ;)
+ matcher = RegressionMatcher(encoder, decoder, h=h, w=w).to(device)
+ if pretrained:
+ weights = torch.hub.load_state_dict_from_url(
+ dkm_pretrained_urls["DKM"][version]
+ )
+ matcher.load_state_dict(weights)
+ return matcher
+
+def DKMv2(pretrained=True, version="outdoor", resolution = "low", **kwargs):
+ gp_dim = 256
+ dfn_dim = 384
+ feat_dim = 256
+ coordinate_decoder = DFN(
+ internal_dim=dfn_dim,
+ feat_input_modules=nn.ModuleDict(
+ {
+ "32": nn.Conv2d(512, feat_dim, 1, 1),
+ "16": nn.Conv2d(512, feat_dim, 1, 1),
+ }
+ ),
+ pred_input_modules=nn.ModuleDict(
+ {
+ "32": nn.Identity(),
+ "16": nn.Identity(),
+ }
+ ),
+ rrb_d_dict=nn.ModuleDict(
+ {
+ "32": RRB(gp_dim + feat_dim, dfn_dim),
+ "16": RRB(gp_dim + feat_dim, dfn_dim),
+ }
+ ),
+ cab_dict=nn.ModuleDict(
+ {
+ "32": CAB(2 * dfn_dim, dfn_dim),
+ "16": CAB(2 * dfn_dim, dfn_dim),
+ }
+ ),
+ rrb_u_dict=nn.ModuleDict(
+ {
+ "32": RRB(dfn_dim, dfn_dim),
+ "16": RRB(dfn_dim, dfn_dim),
+ }
+ ),
+ terminal_module=nn.ModuleDict(
+ {
+ "32": nn.Conv2d(dfn_dim, 3, 1, 1, 0),
+ "16": nn.Conv2d(dfn_dim, 3, 1, 1, 0),
+ }
+ ),
+ )
+ dw = True
+ hidden_blocks = 8
+ kernel_size = 5
+ displacement_emb = "linear"
+ conv_refiner = nn.ModuleDict(
+ {
+ "16": ConvRefiner(
+ 2 * 512+128,
+ 1024+128,
+ 3,
+ kernel_size=kernel_size,
+ dw=dw,
+ hidden_blocks=hidden_blocks,
+ displacement_emb=displacement_emb,
+ displacement_emb_dim=128,
+ ),
+ "8": ConvRefiner(
+ 2 * 512+64,
+ 1024+64,
+ 3,
+ kernel_size=kernel_size,
+ dw=dw,
+ hidden_blocks=hidden_blocks,
+ displacement_emb=displacement_emb,
+ displacement_emb_dim=64,
+ ),
+ "4": ConvRefiner(
+ 2 * 256+32,
+ 512+32,
+ 3,
+ kernel_size=kernel_size,
+ dw=dw,
+ hidden_blocks=hidden_blocks,
+ displacement_emb=displacement_emb,
+ displacement_emb_dim=32,
+ ),
+ "2": ConvRefiner(
+ 2 * 64+16,
+ 128+16,
+ 3,
+ kernel_size=kernel_size,
+ dw=dw,
+ hidden_blocks=hidden_blocks,
+ displacement_emb=displacement_emb,
+ displacement_emb_dim=16,
+ ),
+ "1": ConvRefiner(
+ 2 * 3+6,
+ 24,
+ 3,
+ kernel_size=kernel_size,
+ dw=dw,
+ hidden_blocks=hidden_blocks,
+ displacement_emb=displacement_emb,
+ displacement_emb_dim=6,
+ ),
+ }
+ )
+ kernel_temperature = 0.2
+ learn_temperature = False
+ no_cov = True
+ kernel = CosKernel
+ only_attention = False
+ basis = "fourier"
+ gp32 = GP(
+ kernel,
+ T=kernel_temperature,
+ learn_temperature=learn_temperature,
+ only_attention=only_attention,
+ gp_dim=gp_dim,
+ basis=basis,
+ no_cov=no_cov,
+ )
+ gp16 = GP(
+ kernel,
+ T=kernel_temperature,
+ learn_temperature=learn_temperature,
+ only_attention=only_attention,
+ gp_dim=gp_dim,
+ basis=basis,
+ no_cov=no_cov,
+ )
+ gps = nn.ModuleDict({"32": gp32, "16": gp16})
+ proj = nn.ModuleDict(
+ {"16": nn.Conv2d(1024, 512, 1, 1), "32": nn.Conv2d(2048, 512, 1, 1)}
+ )
+ decoder = Decoder(coordinate_decoder, gps, proj, conv_refiner, detach=True)
+ if resolution == "low":
+ h, w = 384, 512
+ elif resolution == "high":
+ h, w = 480, 640
+ encoder = Encoder(
+ tv_resnet.resnet50(pretrained=not pretrained),
+ ) # only load pretrained weights if not loading a pretrained matcher ;)
+ matcher = RegressionMatcher(encoder, decoder, h=h, w=w,**kwargs).to(device)
+ if pretrained:
+ try:
+ weights = torch.hub.load_state_dict_from_url(
+ dkm_pretrained_urls["DKMv2"][version]
+ )
+        except Exception:
+ weights = torch.load(
+ dkm_pretrained_urls["DKMv2"][version]
+ )
+ matcher.load_state_dict(weights)
+ return matcher
+
+
+def local_corr(pretrained=True, version="mega_synthetic"):
+ gp_dim = 256
+ dfn_dim = 384
+ feat_dim = 256
+ coordinate_decoder = DFN(
+ internal_dim=dfn_dim,
+ feat_input_modules=nn.ModuleDict(
+ {
+ "32": nn.Conv2d(512, feat_dim, 1, 1),
+ "16": nn.Conv2d(512, feat_dim, 1, 1),
+ }
+ ),
+ pred_input_modules=nn.ModuleDict(
+ {
+ "32": nn.Identity(),
+ "16": nn.Identity(),
+ }
+ ),
+ rrb_d_dict=nn.ModuleDict(
+ {
+ "32": RRB(gp_dim + feat_dim, dfn_dim),
+ "16": RRB(gp_dim + feat_dim, dfn_dim),
+ }
+ ),
+ cab_dict=nn.ModuleDict(
+ {
+ "32": CAB(2 * dfn_dim, dfn_dim),
+ "16": CAB(2 * dfn_dim, dfn_dim),
+ }
+ ),
+ rrb_u_dict=nn.ModuleDict(
+ {
+ "32": RRB(dfn_dim, dfn_dim),
+ "16": RRB(dfn_dim, dfn_dim),
+ }
+ ),
+ terminal_module=nn.ModuleDict(
+ {
+ "32": nn.Conv2d(dfn_dim, 3, 1, 1, 0),
+ "16": nn.Conv2d(dfn_dim, 3, 1, 1, 0),
+ }
+ ),
+ )
+ dw = True
+ hidden_blocks = 8
+ kernel_size = 5
+ conv_refiner = nn.ModuleDict(
+ {
+ "16": LocalCorr(
+ 81,
+ 81 * 12,
+ 3,
+ kernel_size=kernel_size,
+ dw=dw,
+ hidden_blocks=hidden_blocks,
+ ),
+ "8": LocalCorr(
+ 81,
+ 81 * 12,
+ 3,
+ kernel_size=kernel_size,
+ dw=dw,
+ hidden_blocks=hidden_blocks,
+ ),
+ "4": LocalCorr(
+ 81,
+ 81 * 6,
+ 3,
+ kernel_size=kernel_size,
+ dw=dw,
+ hidden_blocks=hidden_blocks,
+ ),
+ "2": LocalCorr(
+ 81,
+ 81,
+ 3,
+ kernel_size=kernel_size,
+ dw=dw,
+ hidden_blocks=hidden_blocks,
+ ),
+ "1": ConvRefiner(
+ 2 * 3,
+ 24,
+ 3,
+ kernel_size=kernel_size,
+ dw=dw,
+ hidden_blocks=hidden_blocks,
+ ),
+ }
+ )
+ kernel_temperature = 0.2
+ learn_temperature = False
+ no_cov = True
+ kernel = CosKernel
+ only_attention = False
+ basis = "fourier"
+ gp32 = GP(
+ kernel,
+ T=kernel_temperature,
+ learn_temperature=learn_temperature,
+ only_attention=only_attention,
+ gp_dim=gp_dim,
+ basis=basis,
+ no_cov=no_cov,
+ )
+ gp16 = GP(
+ kernel,
+ T=kernel_temperature,
+ learn_temperature=learn_temperature,
+ only_attention=only_attention,
+ gp_dim=gp_dim,
+ basis=basis,
+ no_cov=no_cov,
+ )
+ gps = nn.ModuleDict({"32": gp32, "16": gp16})
+ proj = nn.ModuleDict(
+ {"16": nn.Conv2d(1024, 512, 1, 1), "32": nn.Conv2d(2048, 512, 1, 1)}
+ )
+ decoder = Decoder(coordinate_decoder, gps, proj, conv_refiner, detach=True)
+ h, w = 384, 512
+ encoder = Encoder(
+ tv_resnet.resnet50(pretrained=not pretrained)
+ ) # only load pretrained weights if not loading a pretrained matcher ;)
+ matcher = RegressionMatcher(encoder, decoder, h=h, w=w).to(device)
+ if pretrained:
+ weights = torch.hub.load_state_dict_from_url(
+ dkm_pretrained_urls["local_corr"][version]
+ )
+ matcher.load_state_dict(weights)
+ return matcher
+
+
+def corr_channels(pretrained=True, version="mega_synthetic"):
+ h, w = 384, 512
+ gp_dim = (h // 32) * (w // 32), (h // 16) * (w // 16)
+ dfn_dim = 384
+ feat_dim = 256
+ coordinate_decoder = DFN(
+ internal_dim=dfn_dim,
+ feat_input_modules=nn.ModuleDict(
+ {
+ "32": nn.Conv2d(512, feat_dim, 1, 1),
+ "16": nn.Conv2d(512, feat_dim, 1, 1),
+ }
+ ),
+ pred_input_modules=nn.ModuleDict(
+ {
+ "32": nn.Identity(),
+ "16": nn.Identity(),
+ }
+ ),
+ rrb_d_dict=nn.ModuleDict(
+ {
+ "32": RRB(gp_dim[0] + feat_dim, dfn_dim),
+ "16": RRB(gp_dim[1] + feat_dim, dfn_dim),
+ }
+ ),
+ cab_dict=nn.ModuleDict(
+ {
+ "32": CAB(2 * dfn_dim, dfn_dim),
+ "16": CAB(2 * dfn_dim, dfn_dim),
+ }
+ ),
+ rrb_u_dict=nn.ModuleDict(
+ {
+ "32": RRB(dfn_dim, dfn_dim),
+ "16": RRB(dfn_dim, dfn_dim),
+ }
+ ),
+ terminal_module=nn.ModuleDict(
+ {
+ "32": nn.Conv2d(dfn_dim, 3, 1, 1, 0),
+ "16": nn.Conv2d(dfn_dim, 3, 1, 1, 0),
+ }
+ ),
+ )
+ dw = True
+ hidden_blocks = 8
+ kernel_size = 5
+ conv_refiner = nn.ModuleDict(
+ {
+ "16": ConvRefiner(
+ 2 * 512,
+ 1024,
+ 3,
+ kernel_size=kernel_size,
+ dw=dw,
+ hidden_blocks=hidden_blocks,
+ ),
+ "8": ConvRefiner(
+ 2 * 512,
+ 1024,
+ 3,
+ kernel_size=kernel_size,
+ dw=dw,
+ hidden_blocks=hidden_blocks,
+ ),
+ "4": ConvRefiner(
+ 2 * 256,
+ 512,
+ 3,
+ kernel_size=kernel_size,
+ dw=dw,
+ hidden_blocks=hidden_blocks,
+ ),
+ "2": ConvRefiner(
+ 2 * 64,
+ 128,
+ 3,
+ kernel_size=kernel_size,
+ dw=dw,
+ hidden_blocks=hidden_blocks,
+ ),
+ "1": ConvRefiner(
+ 2 * 3,
+ 24,
+ 3,
+ kernel_size=kernel_size,
+ dw=dw,
+ hidden_blocks=hidden_blocks,
+ ),
+ }
+ )
+ gp32 = NormedCorr()
+ gp16 = NormedCorr()
+ gps = nn.ModuleDict({"32": gp32, "16": gp16})
+ proj = nn.ModuleDict(
+ {"16": nn.Conv2d(1024, 512, 1, 1), "32": nn.Conv2d(2048, 512, 1, 1)}
+ )
+ decoder = Decoder(coordinate_decoder, gps, proj, conv_refiner, detach=True)
+ h, w = 384, 512
+ encoder = Encoder(
+ tv_resnet.resnet50(pretrained=not pretrained)
+ ) # only load pretrained weights if not loading a pretrained matcher ;)
+ matcher = RegressionMatcher(encoder, decoder, h=h, w=w).to(device)
+ if pretrained:
+ weights = torch.hub.load_state_dict_from_url(
+ dkm_pretrained_urls["corr_channels"][version]
+ )
+ matcher.load_state_dict(weights)
+ return matcher
+
+
+def baseline(pretrained=True, version="mega_synthetic"):
+ h, w = 384, 512
+ gp_dim = (h // 32) * (w // 32), (h // 16) * (w // 16)
+ dfn_dim = 384
+ feat_dim = 256
+ coordinate_decoder = DFN(
+ internal_dim=dfn_dim,
+ feat_input_modules=nn.ModuleDict(
+ {
+ "32": nn.Conv2d(512, feat_dim, 1, 1),
+ "16": nn.Conv2d(512, feat_dim, 1, 1),
+ }
+ ),
+ pred_input_modules=nn.ModuleDict(
+ {
+ "32": nn.Identity(),
+ "16": nn.Identity(),
+ }
+ ),
+ rrb_d_dict=nn.ModuleDict(
+ {
+ "32": RRB(gp_dim[0] + feat_dim, dfn_dim),
+ "16": RRB(gp_dim[1] + feat_dim, dfn_dim),
+ }
+ ),
+ cab_dict=nn.ModuleDict(
+ {
+ "32": CAB(2 * dfn_dim, dfn_dim),
+ "16": CAB(2 * dfn_dim, dfn_dim),
+ }
+ ),
+ rrb_u_dict=nn.ModuleDict(
+ {
+ "32": RRB(dfn_dim, dfn_dim),
+ "16": RRB(dfn_dim, dfn_dim),
+ }
+ ),
+ terminal_module=nn.ModuleDict(
+ {
+ "32": nn.Conv2d(dfn_dim, 3, 1, 1, 0),
+ "16": nn.Conv2d(dfn_dim, 3, 1, 1, 0),
+ }
+ ),
+ )
+ dw = True
+ hidden_blocks = 8
+ kernel_size = 5
+ conv_refiner = nn.ModuleDict(
+ {
+ "16": LocalCorr(
+ 81,
+ 81 * 12,
+ 3,
+ kernel_size=kernel_size,
+ dw=dw,
+ hidden_blocks=hidden_blocks,
+ ),
+ "8": LocalCorr(
+ 81,
+ 81 * 12,
+ 3,
+ kernel_size=kernel_size,
+ dw=dw,
+ hidden_blocks=hidden_blocks,
+ ),
+ "4": LocalCorr(
+ 81,
+ 81 * 6,
+ 3,
+ kernel_size=kernel_size,
+ dw=dw,
+ hidden_blocks=hidden_blocks,
+ ),
+ "2": LocalCorr(
+ 81,
+ 81,
+ 3,
+ kernel_size=kernel_size,
+ dw=dw,
+ hidden_blocks=hidden_blocks,
+ ),
+ "1": ConvRefiner(
+ 2 * 3,
+ 24,
+ 3,
+ kernel_size=kernel_size,
+ dw=dw,
+ hidden_blocks=hidden_blocks,
+ ),
+ }
+ )
+ gp32 = NormedCorr()
+ gp16 = NormedCorr()
+ gps = nn.ModuleDict({"32": gp32, "16": gp16})
+ proj = nn.ModuleDict(
+ {"16": nn.Conv2d(1024, 512, 1, 1), "32": nn.Conv2d(2048, 512, 1, 1)}
+ )
+ decoder = Decoder(coordinate_decoder, gps, proj, conv_refiner, detach=True)
+ h, w = 384, 512
+ encoder = Encoder(
+ tv_resnet.resnet50(pretrained=not pretrained)
+ ) # only load pretrained weights if not loading a pretrained matcher ;)
+ matcher = RegressionMatcher(encoder, decoder, h=h, w=w).to(device)
+ if pretrained:
+ weights = torch.hub.load_state_dict_from_url(
+ dkm_pretrained_urls["baseline"][version]
+ )
+ matcher.load_state_dict(weights)
+ return matcher
+
+
+def linear(pretrained=True, version="mega_synthetic"):
+ gp_dim = 256
+ dfn_dim = 384
+ feat_dim = 256
+ coordinate_decoder = DFN(
+ internal_dim=dfn_dim,
+ feat_input_modules=nn.ModuleDict(
+ {
+ "32": nn.Conv2d(512, feat_dim, 1, 1),
+ "16": nn.Conv2d(512, feat_dim, 1, 1),
+ }
+ ),
+ pred_input_modules=nn.ModuleDict(
+ {
+ "32": nn.Identity(),
+ "16": nn.Identity(),
+ }
+ ),
+ rrb_d_dict=nn.ModuleDict(
+ {
+ "32": RRB(gp_dim + feat_dim, dfn_dim),
+ "16": RRB(gp_dim + feat_dim, dfn_dim),
+ }
+ ),
+ cab_dict=nn.ModuleDict(
+ {
+ "32": CAB(2 * dfn_dim, dfn_dim),
+ "16": CAB(2 * dfn_dim, dfn_dim),
+ }
+ ),
+ rrb_u_dict=nn.ModuleDict(
+ {
+ "32": RRB(dfn_dim, dfn_dim),
+ "16": RRB(dfn_dim, dfn_dim),
+ }
+ ),
+ terminal_module=nn.ModuleDict(
+ {
+ "32": nn.Conv2d(dfn_dim, 3, 1, 1, 0),
+ "16": nn.Conv2d(dfn_dim, 3, 1, 1, 0),
+ }
+ ),
+ )
+ dw = True
+ hidden_blocks = 8
+ kernel_size = 5
+ conv_refiner = nn.ModuleDict(
+ {
+ "16": ConvRefiner(
+ 2 * 512,
+ 1024,
+ 3,
+ kernel_size=kernel_size,
+ dw=dw,
+ hidden_blocks=hidden_blocks,
+ ),
+ "8": ConvRefiner(
+ 2 * 512,
+ 1024,
+ 3,
+ kernel_size=kernel_size,
+ dw=dw,
+ hidden_blocks=hidden_blocks,
+ ),
+ "4": ConvRefiner(
+ 2 * 256,
+ 512,
+ 3,
+ kernel_size=kernel_size,
+ dw=dw,
+ hidden_blocks=hidden_blocks,
+ ),
+ "2": ConvRefiner(
+ 2 * 64,
+ 128,
+ 3,
+ kernel_size=kernel_size,
+ dw=dw,
+ hidden_blocks=hidden_blocks,
+ ),
+ "1": ConvRefiner(
+ 2 * 3,
+ 24,
+ 3,
+ kernel_size=kernel_size,
+ dw=dw,
+ hidden_blocks=hidden_blocks,
+ ),
+ }
+ )
+ kernel_temperature = 0.2
+ learn_temperature = False
+ no_cov = True
+ kernel = CosKernel
+ only_attention = False
+ basis = "linear"
+ gp32 = GP(
+ kernel,
+ T=kernel_temperature,
+ learn_temperature=learn_temperature,
+ only_attention=only_attention,
+ gp_dim=gp_dim,
+ basis=basis,
+ no_cov=no_cov,
+ )
+ gp16 = GP(
+ kernel,
+ T=kernel_temperature,
+ learn_temperature=learn_temperature,
+ only_attention=only_attention,
+ gp_dim=gp_dim,
+ basis=basis,
+ no_cov=no_cov,
+ )
+ gps = nn.ModuleDict({"32": gp32, "16": gp16})
+ proj = nn.ModuleDict(
+ {"16": nn.Conv2d(1024, 512, 1, 1), "32": nn.Conv2d(2048, 512, 1, 1)}
+ )
+ decoder = Decoder(coordinate_decoder, gps, proj, conv_refiner, detach=True)
+ h, w = 384, 512
+ encoder = Encoder(
+ tv_resnet.resnet50(pretrained=not pretrained)
+ ) # only load pretrained weights if not loading a pretrained matcher ;)
+ matcher = RegressionMatcher(encoder, decoder, h=h, w=w).to(device)
+ if pretrained:
+ weights = torch.hub.load_state_dict_from_url(
+ dkm_pretrained_urls["linear"][version]
+ )
+ matcher.load_state_dict(weights)
+ return matcher
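+
+
+if __name__ == "__main__":
+    # Construction sketch for the deprecated builders, not part of the original
+    # file; pretrained=False skips fetching the released DKM checkpoint.
+    matcher = DKM(pretrained=False, version="mega")
+    print(sum(p.numel() for p in matcher.parameters()))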
diff --git a/imcui/third_party/DKM/dkm/models/deprecated/corr_channels.py b/imcui/third_party/DKM/dkm/models/deprecated/corr_channels.py
new file mode 100644
index 0000000000000000000000000000000000000000..8713b0d8c7a0ce91da4d2105ba29097a4969a037
--- /dev/null
+++ b/imcui/third_party/DKM/dkm/models/deprecated/corr_channels.py
@@ -0,0 +1,34 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange
+
+
+class NormedCorrelationKernel(nn.Module): # similar to softmax kernel
+ def __init__(self):
+ super().__init__()
+
+ def __call__(self, x, y, eps=1e-6):
+ c = torch.einsum("bnd,bmd->bnm", x, y) / (
+ x.norm(dim=-1)[..., None] * y.norm(dim=-1)[:, None] + eps
+ )
+ return c
+
+
+class NormedCorr(nn.Module):
+ def __init__(
+ self,
+ ):
+ super().__init__()
+ self.corr = NormedCorrelationKernel()
+
+ def reshape(self, x):
+ return rearrange(x, "b d h w -> b (h w) d")
+
+ def forward(self, x, y, **kwargs):
+ b, c, h, w = y.shape
+ assert x.shape == y.shape
+ x, y = self.reshape(x), self.reshape(y)
+ corr_xy = self.corr(x, y)
+ corr_xy_flat = rearrange(corr_xy, "b (h w) c -> b c h w", h=h, w=w)
+ return corr_xy_flat
diff --git a/imcui/third_party/DKM/dkm/models/deprecated/local_corr.py b/imcui/third_party/DKM/dkm/models/deprecated/local_corr.py
new file mode 100644
index 0000000000000000000000000000000000000000..681fe4c0079561fa7a4c44e82a8879a4a27273a1
--- /dev/null
+++ b/imcui/third_party/DKM/dkm/models/deprecated/local_corr.py
@@ -0,0 +1,630 @@
+import torch
+import torch.nn.functional as F
+
+try:
+ import cupy
+except ImportError:
+ print("Cupy not found, local correlation will not work")
+import re
+from ..dkm import ConvRefiner
+
+
+class Stream:
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+    if device.type == 'cuda':
+ stream = torch.cuda.current_stream(device=device).cuda_stream
+ else:
+ stream = None
+
+
+kernel_Correlation_rearrange = """
+ extern "C" __global__ void kernel_Correlation_rearrange(
+ const int n,
+ const float* input,
+ float* output
+ ) {
+ int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x;
+ if (intIndex >= n) {
+ return;
+ }
+ int intSample = blockIdx.z;
+ int intChannel = blockIdx.y;
+ float dblValue = input[(((intSample * SIZE_1(input)) + intChannel) * SIZE_2(input) * SIZE_3(input)) + intIndex];
+ __syncthreads();
+ int intPaddedY = (intIndex / SIZE_3(input)) + 4;
+ int intPaddedX = (intIndex % SIZE_3(input)) + 4;
+ int intRearrange = ((SIZE_3(input) + 8) * intPaddedY) + intPaddedX;
+ output[(((intSample * SIZE_1(output) * SIZE_2(output)) + intRearrange) * SIZE_1(input)) + intChannel] = dblValue;
+ }
+"""
+
+kernel_Correlation_updateOutput = """
+ extern "C" __global__ void kernel_Correlation_updateOutput(
+ const int n,
+ const float* rbot0,
+ const float* rbot1,
+ float* top
+ ) {
+ extern __shared__ char patch_data_char[];
+ float *patch_data = (float *)patch_data_char;
+ // First (upper left) position of kernel upper-left corner in current center position of neighborhood in image 1
+ int x1 = blockIdx.x + 4;
+ int y1 = blockIdx.y + 4;
+ int item = blockIdx.z;
+ int ch_off = threadIdx.x;
+ // Load 3D patch into shared shared memory
+ for (int j = 0; j < 1; j++) { // HEIGHT
+ for (int i = 0; i < 1; i++) { // WIDTH
+ int ji_off = (j + i) * SIZE_3(rbot0);
+ for (int ch = ch_off; ch < SIZE_3(rbot0); ch += 32) { // CHANNELS
+ int idx1 = ((item * SIZE_1(rbot0) + y1+j) * SIZE_2(rbot0) + x1+i) * SIZE_3(rbot0) + ch;
+ int idxPatchData = ji_off + ch;
+ patch_data[idxPatchData] = rbot0[idx1];
+ }
+ }
+ }
+ __syncthreads();
+ __shared__ float sum[32];
+ // Compute correlation
+ for (int top_channel = 0; top_channel < SIZE_1(top); top_channel++) {
+ sum[ch_off] = 0;
+ int s2o = top_channel % 9 - 4;
+ int s2p = top_channel / 9 - 4;
+ for (int j = 0; j < 1; j++) { // HEIGHT
+ for (int i = 0; i < 1; i++) { // WIDTH
+ int ji_off = (j + i) * SIZE_3(rbot0);
+ for (int ch = ch_off; ch < SIZE_3(rbot0); ch += 32) { // CHANNELS
+ int x2 = x1 + s2o;
+ int y2 = y1 + s2p;
+ int idxPatchData = ji_off + ch;
+ int idx2 = ((item * SIZE_1(rbot0) + y2+j) * SIZE_2(rbot0) + x2+i) * SIZE_3(rbot0) + ch;
+ sum[ch_off] += patch_data[idxPatchData] * rbot1[idx2];
+ }
+ }
+ }
+ __syncthreads();
+ if (ch_off == 0) {
+ float total_sum = 0;
+ for (int idx = 0; idx < 32; idx++) {
+ total_sum += sum[idx];
+ }
+ const int sumelems = SIZE_3(rbot0);
+ const int index = ((top_channel*SIZE_2(top) + blockIdx.y)*SIZE_3(top))+blockIdx.x;
+ top[index + item*SIZE_1(top)*SIZE_2(top)*SIZE_3(top)] = total_sum / (float)sumelems;
+ }
+ }
+ }
+"""
+
+kernel_Correlation_updateGradFirst = """
+ #define ROUND_OFF 50000
+ extern "C" __global__ void kernel_Correlation_updateGradFirst(
+ const int n,
+ const int intSample,
+ const float* rbot0,
+ const float* rbot1,
+ const float* gradOutput,
+ float* gradFirst,
+ float* gradSecond
+ ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {
+ int n = intIndex % SIZE_1(gradFirst); // channels
+ int l = (intIndex / SIZE_1(gradFirst)) % SIZE_3(gradFirst) + 4; // w-pos
+ int m = (intIndex / SIZE_1(gradFirst) / SIZE_3(gradFirst)) % SIZE_2(gradFirst) + 4; // h-pos
+ // round_off is a trick to enable integer division with ceil, even for negative numbers
+ // We use a large offset, for the inner part not to become negative.
+ const int round_off = ROUND_OFF;
+ const int round_off_s1 = round_off;
+ // We add round_off before_s1 the int division and subtract round_off after it, to ensure the formula matches ceil behavior:
+ int xmin = (l - 4 + round_off_s1 - 1) + 1 - round_off; // ceil (l - 4)
+ int ymin = (m - 4 + round_off_s1 - 1) + 1 - round_off; // ceil (l - 4)
+ // Same here:
+ int xmax = (l - 4 + round_off_s1) - round_off; // floor (l - 4)
+ int ymax = (m - 4 + round_off_s1) - round_off; // floor (m - 4)
+ float sum = 0;
+ if (xmax>=0 && ymax>=0 && (xmin<=SIZE_3(gradOutput)-1) && (ymin<=SIZE_2(gradOutput)-1)) {
+ xmin = max(0,xmin);
+ xmax = min(SIZE_3(gradOutput)-1,xmax);
+ ymin = max(0,ymin);
+ ymax = min(SIZE_2(gradOutput)-1,ymax);
+ for (int p = -4; p <= 4; p++) {
+ for (int o = -4; o <= 4; o++) {
+ // Get rbot1 data:
+ int s2o = o;
+ int s2p = p;
+ int idxbot1 = ((intSample * SIZE_1(rbot0) + (m+s2p)) * SIZE_2(rbot0) + (l+s2o)) * SIZE_3(rbot0) + n;
+ float bot1tmp = rbot1[idxbot1]; // rbot1[l+s2o,m+s2p,n]
+ // Index offset for gradOutput in following loops:
+ int op = (p+4) * 9 + (o+4); // index[o,p]
+ int idxopoffset = (intSample * SIZE_1(gradOutput) + op);
+ for (int y = ymin; y <= ymax; y++) {
+ for (int x = xmin; x <= xmax; x++) {
+ int idxgradOutput = (idxopoffset * SIZE_2(gradOutput) + y) * SIZE_3(gradOutput) + x; // gradOutput[x,y,o,p]
+ sum += gradOutput[idxgradOutput] * bot1tmp;
+ }
+ }
+ }
+ }
+ }
+ const int sumelems = SIZE_1(gradFirst);
+ const int bot0index = ((n * SIZE_2(gradFirst)) + (m-4)) * SIZE_3(gradFirst) + (l-4);
+ gradFirst[bot0index + intSample*SIZE_1(gradFirst)*SIZE_2(gradFirst)*SIZE_3(gradFirst)] = sum / (float)sumelems;
+ } }
+"""
+
+kernel_Correlation_updateGradSecond = """
+ #define ROUND_OFF 50000
+ extern "C" __global__ void kernel_Correlation_updateGradSecond(
+ const int n,
+ const int intSample,
+ const float* rbot0,
+ const float* rbot1,
+ const float* gradOutput,
+ float* gradFirst,
+ float* gradSecond
+ ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {
+ int n = intIndex % SIZE_1(gradSecond); // channels
+ int l = (intIndex / SIZE_1(gradSecond)) % SIZE_3(gradSecond) + 4; // w-pos
+ int m = (intIndex / SIZE_1(gradSecond) / SIZE_3(gradSecond)) % SIZE_2(gradSecond) + 4; // h-pos
+ // round_off is a trick to enable integer division with ceil, even for negative numbers
+ // We use a large offset, for the inner part not to become negative.
+ const int round_off = ROUND_OFF;
+ const int round_off_s1 = round_off;
+ float sum = 0;
+ for (int p = -4; p <= 4; p++) {
+ for (int o = -4; o <= 4; o++) {
+ int s2o = o;
+ int s2p = p;
+ //Get X,Y ranges and clamp
+ // We add round_off before_s1 the int division and subtract round_off after it, to ensure the formula matches ceil behavior:
+ int xmin = (l - 4 - s2o + round_off_s1 - 1) + 1 - round_off; // ceil (l - 4 - s2o)
+ int ymin = (m - 4 - s2p + round_off_s1 - 1) + 1 - round_off; // ceil (l - 4 - s2o)
+ // Same here:
+ int xmax = (l - 4 - s2o + round_off_s1) - round_off; // floor (l - 4 - s2o)
+ int ymax = (m - 4 - s2p + round_off_s1) - round_off; // floor (m - 4 - s2p)
+ if (xmax>=0 && ymax>=0 && (xmin<=SIZE_3(gradOutput)-1) && (ymin<=SIZE_2(gradOutput)-1)) {
+ xmin = max(0,xmin);
+ xmax = min(SIZE_3(gradOutput)-1,xmax);
+ ymin = max(0,ymin);
+ ymax = min(SIZE_2(gradOutput)-1,ymax);
+ // Get rbot0 data:
+ int idxbot0 = ((intSample * SIZE_1(rbot0) + (m-s2p)) * SIZE_2(rbot0) + (l-s2o)) * SIZE_3(rbot0) + n;
+ float bot0tmp = rbot0[idxbot0]; // rbot1[l+s2o,m+s2p,n]
+ // Index offset for gradOutput in following loops:
+ int op = (p+4) * 9 + (o+4); // index[o,p]
+ int idxopoffset = (intSample * SIZE_1(gradOutput) + op);
+ for (int y = ymin; y <= ymax; y++) {
+ for (int x = xmin; x <= xmax; x++) {
+ int idxgradOutput = (idxopoffset * SIZE_2(gradOutput) + y) * SIZE_3(gradOutput) + x; // gradOutput[x,y,o,p]
+ sum += gradOutput[idxgradOutput] * bot0tmp;
+ }
+ }
+ }
+ }
+ }
+ const int sumelems = SIZE_1(gradSecond);
+ const int bot1index = ((n * SIZE_2(gradSecond)) + (m-4)) * SIZE_3(gradSecond) + (l-4);
+ gradSecond[bot1index + intSample*SIZE_1(gradSecond)*SIZE_2(gradSecond)*SIZE_3(gradSecond)] = sum / (float)sumelems;
+ } }
+"""
+
+
+def cupy_kernel(strFunction, objectVariables):
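+    # Replace SIZE_i(tensor) macros in the CUDA source with the concrete tensor
+    # sizes, and VALUE_i(tensor, idx...) macros with strided element accesses,
+    # so the kernel string can be compiled for these exact shapes.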
+ strKernel = globals()[strFunction]
+
+ while True:
+ objectMatch = re.search(r"(SIZE_)([0-4])(\()([^\)]*)(\))", strKernel)
+
+ if objectMatch is None:
+ break
+
+ intArg = int(objectMatch.group(2))
+
+ strTensor = objectMatch.group(4)
+ intSizes = objectVariables[strTensor].size()
+
+ strKernel = strKernel.replace(objectMatch.group(), str(intSizes[intArg]))
+
+ while True:
+ objectMatch = re.search(r"(VALUE_)([0-4])(\()([^\)]+)(\))", strKernel)
+
+ if objectMatch is None:
+ break
+
+ intArgs = int(objectMatch.group(2))
+ strArgs = objectMatch.group(4).split(",")
+
+ strTensor = strArgs[0]
+ intStrides = objectVariables[strTensor].stride()
+ strIndex = [
+ "(("
+ + strArgs[intArg + 1].replace("{", "(").replace("}", ")").strip()
+ + ")*"
+ + str(intStrides[intArg])
+ + ")"
+ for intArg in range(intArgs)
+ ]
+
+ strKernel = strKernel.replace(
+ objectMatch.group(0), strTensor + "[" + str.join("+", strIndex) + "]"
+ )
+
+ return strKernel
+
+
+try:
+
+ @cupy.memoize(for_each_device=True)
+ def cupy_launch(strFunction, strKernel):
+ return cupy.RawModule(code=strKernel).get_function(strFunction)
+
+except Exception:
+ pass
+
+
+class _FunctionCorrelation(torch.autograd.Function):
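+    # Autograd wrapper around the cupy kernels above: forward rearranges both
+    # feature maps into padded, channel-last buffers and computes an 81-channel
+    # (9x9 neighbourhood) local correlation volume; backward launches the
+    # corresponding gradient kernels per sample.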
+ @staticmethod
+ def forward(self, first, second):
+ rbot0 = first.new_zeros(
+ [first.size(0), first.size(2) + 8, first.size(3) + 8, first.size(1)]
+ )
+ rbot1 = first.new_zeros(
+ [first.size(0), first.size(2) + 8, first.size(3) + 8, first.size(1)]
+ )
+
+ self.save_for_backward(first, second, rbot0, rbot1)
+
+ first = first.contiguous()
+ second = second.contiguous()
+
+ output = first.new_zeros([first.size(0), 81, first.size(2), first.size(3)])
+
+ if first.is_cuda == True:
+ n = first.size(2) * first.size(3)
+ cupy_launch(
+ "kernel_Correlation_rearrange",
+ cupy_kernel(
+ "kernel_Correlation_rearrange", {"input": first, "output": rbot0}
+ ),
+ )(
+ grid=tuple([int((n + 16 - 1) / 16), first.size(1), first.size(0)]),
+ block=tuple([16, 1, 1]),
+ args=[n, first.data_ptr(), rbot0.data_ptr()],
+ stream=Stream,
+ )
+
+ n = second.size(2) * second.size(3)
+ cupy_launch(
+ "kernel_Correlation_rearrange",
+ cupy_kernel(
+ "kernel_Correlation_rearrange", {"input": second, "output": rbot1}
+ ),
+ )(
+ grid=tuple([int((n + 16 - 1) / 16), second.size(1), second.size(0)]),
+ block=tuple([16, 1, 1]),
+ args=[n, second.data_ptr(), rbot1.data_ptr()],
+ stream=Stream,
+ )
+
+ n = output.size(1) * output.size(2) * output.size(3)
+ cupy_launch(
+ "kernel_Correlation_updateOutput",
+ cupy_kernel(
+ "kernel_Correlation_updateOutput",
+ {"rbot0": rbot0, "rbot1": rbot1, "top": output},
+ ),
+ )(
+ grid=tuple([output.size(3), output.size(2), output.size(0)]),
+ block=tuple([32, 1, 1]),
+ shared_mem=first.size(1) * 4,
+ args=[n, rbot0.data_ptr(), rbot1.data_ptr(), output.data_ptr()],
+ stream=Stream,
+ )
+
+ elif first.is_cuda == False:
+ raise NotImplementedError()
+
+ return output
+
+ @staticmethod
+ def backward(self, gradOutput):
+ first, second, rbot0, rbot1 = self.saved_tensors
+
+ gradOutput = gradOutput.contiguous()
+
+ assert gradOutput.is_contiguous() == True
+
+ gradFirst = (
+ first.new_zeros(
+ [first.size(0), first.size(1), first.size(2), first.size(3)]
+ )
+ if self.needs_input_grad[0] == True
+ else None
+ )
+ gradSecond = (
+ first.new_zeros(
+ [first.size(0), first.size(1), first.size(2), first.size(3)]
+ )
+ if self.needs_input_grad[1] == True
+ else None
+ )
+
+ if first.is_cuda == True:
+ if gradFirst is not None:
+ for intSample in range(first.size(0)):
+ n = first.size(1) * first.size(2) * first.size(3)
+ cupy_launch(
+ "kernel_Correlation_updateGradFirst",
+ cupy_kernel(
+ "kernel_Correlation_updateGradFirst",
+ {
+ "rbot0": rbot0,
+ "rbot1": rbot1,
+ "gradOutput": gradOutput,
+ "gradFirst": gradFirst,
+ "gradSecond": None,
+ },
+ ),
+ )(
+ grid=tuple([int((n + 512 - 1) / 512), 1, 1]),
+ block=tuple([512, 1, 1]),
+ args=[
+ n,
+ intSample,
+ rbot0.data_ptr(),
+ rbot1.data_ptr(),
+ gradOutput.data_ptr(),
+ gradFirst.data_ptr(),
+ None,
+ ],
+ stream=Stream,
+ )
+
+ if gradSecond is not None:
+ for intSample in range(first.size(0)):
+ n = first.size(1) * first.size(2) * first.size(3)
+ cupy_launch(
+ "kernel_Correlation_updateGradSecond",
+ cupy_kernel(
+ "kernel_Correlation_updateGradSecond",
+ {
+ "rbot0": rbot0,
+ "rbot1": rbot1,
+ "gradOutput": gradOutput,
+ "gradFirst": None,
+ "gradSecond": gradSecond,
+ },
+ ),
+ )(
+ grid=tuple([int((n + 512 - 1) / 512), 1, 1]),
+ block=tuple([512, 1, 1]),
+ args=[
+ n,
+ intSample,
+ rbot0.data_ptr(),
+ rbot1.data_ptr(),
+ gradOutput.data_ptr(),
+ None,
+ gradSecond.data_ptr(),
+ ],
+ stream=Stream,
+ )
+
+ elif first.is_cuda == False:
+ raise NotImplementedError()
+
+ return gradFirst, gradSecond
+
+
+class _FunctionCorrelationTranspose(torch.autograd.Function):
+ @staticmethod
+ def forward(self, input, second):
+ rbot0 = second.new_zeros(
+ [second.size(0), second.size(2) + 8, second.size(3) + 8, second.size(1)]
+ )
+ rbot1 = second.new_zeros(
+ [second.size(0), second.size(2) + 8, second.size(3) + 8, second.size(1)]
+ )
+
+ self.save_for_backward(input, second, rbot0, rbot1)
+
+ input = input.contiguous()
+ second = second.contiguous()
+
+ output = second.new_zeros(
+ [second.size(0), second.size(1), second.size(2), second.size(3)]
+ )
+
+ if second.is_cuda == True:
+ n = second.size(2) * second.size(3)
+ cupy_launch(
+ "kernel_Correlation_rearrange",
+ cupy_kernel(
+ "kernel_Correlation_rearrange", {"input": second, "output": rbot1}
+ ),
+ )(
+ grid=tuple([int((n + 16 - 1) / 16), second.size(1), second.size(0)]),
+ block=tuple([16, 1, 1]),
+ args=[n, second.data_ptr(), rbot1.data_ptr()],
+ stream=Stream,
+ )
+
+ for intSample in range(second.size(0)):
+ n = second.size(1) * second.size(2) * second.size(3)
+ cupy_launch(
+ "kernel_Correlation_updateGradFirst",
+ cupy_kernel(
+ "kernel_Correlation_updateGradFirst",
+ {
+ "rbot0": rbot0,
+ "rbot1": rbot1,
+ "gradOutput": input,
+ "gradFirst": output,
+ "gradSecond": None,
+ },
+ ),
+ )(
+ grid=tuple([int((n + 512 - 1) / 512), 1, 1]),
+ block=tuple([512, 1, 1]),
+ args=[
+ n,
+ intSample,
+ rbot0.data_ptr(),
+ rbot1.data_ptr(),
+ input.data_ptr(),
+ output.data_ptr(),
+ None,
+ ],
+ stream=Stream,
+ )
+
+ elif second.is_cuda == False:
+ raise NotImplementedError()
+
+ return output
+
+ @staticmethod
+ def backward(self, gradOutput):
+ input, second, rbot0, rbot1 = self.saved_tensors
+
+ gradOutput = gradOutput.contiguous()
+
+ gradInput = (
+ input.new_zeros(
+ [input.size(0), input.size(1), input.size(2), input.size(3)]
+ )
+ if self.needs_input_grad[0] == True
+ else None
+ )
+ gradSecond = (
+ second.new_zeros(
+ [second.size(0), second.size(1), second.size(2), second.size(3)]
+ )
+ if self.needs_input_grad[1] == True
+ else None
+ )
+
+ if second.is_cuda == True:
+ if gradInput is not None or gradSecond is not None:
+ n = second.size(2) * second.size(3)
+ cupy_launch(
+ "kernel_Correlation_rearrange",
+ cupy_kernel(
+ "kernel_Correlation_rearrange",
+ {"input": gradOutput, "output": rbot0},
+ ),
+ )(
+ grid=tuple(
+ [int((n + 16 - 1) / 16), gradOutput.size(1), gradOutput.size(0)]
+ ),
+ block=tuple([16, 1, 1]),
+ args=[n, gradOutput.data_ptr(), rbot0.data_ptr()],
+ stream=Stream,
+ )
+
+ if gradInput is not None:
+ n = gradInput.size(1) * gradInput.size(2) * gradInput.size(3)
+ cupy_launch(
+ "kernel_Correlation_updateOutput",
+ cupy_kernel(
+ "kernel_Correlation_updateOutput",
+ {"rbot0": rbot0, "rbot1": rbot1, "top": gradInput},
+ ),
+ )(
+ grid=tuple(
+ [gradInput.size(3), gradInput.size(2), gradInput.size(0)]
+ ),
+ block=tuple([32, 1, 1]),
+ shared_mem=gradOutput.size(1) * 4,
+ args=[n, rbot0.data_ptr(), rbot1.data_ptr(), gradInput.data_ptr()],
+ stream=Stream,
+ )
+
+ if gradSecond is not None:
+ for intSample in range(second.size(0)):
+ n = second.size(1) * second.size(2) * second.size(3)
+ cupy_launch(
+ "kernel_Correlation_updateGradSecond",
+ cupy_kernel(
+ "kernel_Correlation_updateGradSecond",
+ {
+ "rbot0": rbot0,
+ "rbot1": rbot1,
+ "gradOutput": input,
+ "gradFirst": None,
+ "gradSecond": gradSecond,
+ },
+ ),
+ )(
+ grid=tuple([int((n + 512 - 1) / 512), 1, 1]),
+ block=tuple([512, 1, 1]),
+ args=[
+ n,
+ intSample,
+ rbot0.data_ptr(),
+ rbot1.data_ptr(),
+ input.data_ptr(),
+ None,
+ gradSecond.data_ptr(),
+ ],
+ stream=Stream,
+ )
+
+ elif second.is_cuda == False:
+ raise NotImplementedError()
+
+ return gradInput, gradSecond
+
+
+def FunctionCorrelation(reference_features, query_features):
+ return _FunctionCorrelation.apply(reference_features, query_features)
+
+
+class ModuleCorrelation(torch.nn.Module):
+ def __init__(self):
+ super(ModuleCorrelation, self).__init__()
+
+ def forward(self, tensorFirst, tensorSecond):
+ return _FunctionCorrelation.apply(tensorFirst, tensorSecond)
+
+
+def FunctionCorrelationTranspose(reference_features, query_features):
+ return _FunctionCorrelationTranspose.apply(reference_features, query_features)
+
+
+class ModuleCorrelationTranspose(torch.nn.Module):
+ def __init__(self):
+ super(ModuleCorrelationTranspose, self).__init__()
+
+ def forward(self, tensorFirst, tensorSecond):
+ return _FunctionCorrelationTranspose.apply(tensorFirst, tensorSecond)
+
+
+class LocalCorr(ConvRefiner):
+ def forward(self, x, y, flow):
+        """Computes the relative refining displacement in pixels for a given image pair x, y and a coarse flow field between them.
+
+        Args:
+            x (torch.Tensor): query features, shape (B, C, H, W)
+            y (torch.Tensor): support features, shape (B, C, H, W)
+            flow (torch.Tensor): coarse flow from x to y in normalized [-1, 1] coordinates, shape (B, 2, H, W)
+
+        Returns:
+            tuple: (certainty logits of shape (B, out_dim - 2, H, W), displacement of shape (B, 2, H, W))
+        """
+ with torch.no_grad():
+ x_hat = F.grid_sample(y, flow.permute(0, 2, 3, 1), align_corners=False)
+ corr = FunctionCorrelation(x, x_hat)
+ d = self.block1(corr)
+ d = self.hidden_blocks(d)
+ d = self.out_conv(d)
+ certainty, displacement = d[:, :-2], d[:, -2:]
+ return certainty, displacement
+
+
+if __name__ == "__main__":
+    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+    x = torch.randn(2, 128, 32, 32).to(device)
+    y = torch.randn(2, 128, 32, 32).to(device)
+    # Dummy coarse flow in normalized [-1, 1] coordinates, shape (B, 2, H, W).
+    flow = torch.zeros(2, 2, 32, 32).to(device)
+    local_corr = LocalCorr(in_dim=81, hidden_dim=81 * 4).to(device)
+    certainty, displacement = local_corr(x, y, flow)
+    print(certainty.shape, displacement.shape)
diff --git a/imcui/third_party/DKM/dkm/models/dkm.py b/imcui/third_party/DKM/dkm/models/dkm.py
new file mode 100644
index 0000000000000000000000000000000000000000..27c3f6d59ad3a8e976e3d719868908ddf443883e
--- /dev/null
+++ b/imcui/third_party/DKM/dkm/models/dkm.py
@@ -0,0 +1,759 @@
+import math
+import os
+import numpy as np
+from PIL import Image
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from ..utils import get_tuple_transform_ops
+from einops import rearrange
+from ..utils.local_correlation import local_correlation
+
+
+class ConvRefiner(nn.Module):
+ def __init__(
+ self,
+ in_dim=6,
+ hidden_dim=16,
+ out_dim=2,
+ dw=False,
+ kernel_size=5,
+ hidden_blocks=3,
+ displacement_emb = None,
+ displacement_emb_dim = None,
+ local_corr_radius = None,
+ corr_in_other = None,
+ no_support_fm = False,
+ ):
+ super().__init__()
+ self.block1 = self.create_block(
+ in_dim, hidden_dim, dw=dw, kernel_size=kernel_size
+ )
+ self.hidden_blocks = nn.Sequential(
+ *[
+ self.create_block(
+ hidden_dim,
+ hidden_dim,
+ dw=dw,
+ kernel_size=kernel_size,
+ )
+ for hb in range(hidden_blocks)
+ ]
+ )
+ self.out_conv = nn.Conv2d(hidden_dim, out_dim, 1, 1, 0)
+ if displacement_emb:
+ self.has_displacement_emb = True
+ self.disp_emb = nn.Conv2d(2,displacement_emb_dim,1,1,0)
+ else:
+ self.has_displacement_emb = False
+ self.local_corr_radius = local_corr_radius
+ self.corr_in_other = corr_in_other
+ self.no_support_fm = no_support_fm
+ def create_block(
+ self,
+ in_dim,
+ out_dim,
+ dw=False,
+ kernel_size=5,
+ ):
+ num_groups = 1 if not dw else in_dim
+ if dw:
+ assert (
+ out_dim % in_dim == 0
+ ), "outdim must be divisible by indim for depthwise"
+ conv1 = nn.Conv2d(
+ in_dim,
+ out_dim,
+ kernel_size=kernel_size,
+ stride=1,
+ padding=kernel_size // 2,
+ groups=num_groups,
+ )
+ norm = nn.BatchNorm2d(out_dim)
+ relu = nn.ReLU(inplace=True)
+ conv2 = nn.Conv2d(out_dim, out_dim, 1, 1, 0)
+ return nn.Sequential(conv1, norm, relu, conv2)
+
+ def forward(self, x, y, flow):
+        """Computes the relative refining displacement in pixels for a given image pair x, y and a coarse flow field between them.
+
+        Args:
+            x (torch.Tensor): query features, shape (B, C, H, W)
+            y (torch.Tensor): support features, shape (B, C, H, W)
+            flow (torch.Tensor): coarse flow from x to y in normalized [-1, 1] coordinates, shape (B, 2, H, W)
+
+        Returns:
+            tuple: (certainty logits of shape (B, out_dim - 2, H, W), displacement of shape (B, 2, H, W))
+        """
+ device = x.device
+ b,c,hs,ws = x.shape
+ with torch.no_grad():
+ x_hat = F.grid_sample(y, flow.permute(0, 2, 3, 1), align_corners=False)
+ if self.has_displacement_emb:
+ query_coords = torch.meshgrid(
+ (
+ torch.linspace(-1 + 1 / hs, 1 - 1 / hs, hs, device=device),
+ torch.linspace(-1 + 1 / ws, 1 - 1 / ws, ws, device=device),
+ )
+ )
+ query_coords = torch.stack((query_coords[1], query_coords[0]))
+ query_coords = query_coords[None].expand(b, 2, hs, ws)
+ in_displacement = flow-query_coords
+ emb_in_displacement = self.disp_emb(in_displacement)
+ if self.local_corr_radius:
+ #TODO: should corr have gradient?
+ if self.corr_in_other:
+ # Corr in other means take a kxk grid around the predicted coordinate in other image
+ local_corr = local_correlation(x,y,local_radius=self.local_corr_radius,flow = flow)
+ else:
+ # Otherwise we use the warp to sample in the first image
+ # This is actually different operations, especially for large viewpoint changes
+ local_corr = local_correlation(x, x_hat, local_radius=self.local_corr_radius,)
+ if self.no_support_fm:
+ x_hat = torch.zeros_like(x)
+ d = torch.cat((x, x_hat, emb_in_displacement, local_corr), dim=1)
+ else:
+ d = torch.cat((x, x_hat, emb_in_displacement), dim=1)
+ else:
+ if self.no_support_fm:
+ x_hat = torch.zeros_like(x)
+ d = torch.cat((x, x_hat), dim=1)
+ d = self.block1(d)
+ d = self.hidden_blocks(d)
+ d = self.out_conv(d)
+ certainty, displacement = d[:, :-2], d[:, -2:]
+ return certainty, displacement
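+    # Hypothetical usage sketch (shapes only; not part of the original file):
+    #   refiner = ConvRefiner(in_dim=2 * 256 + 32, hidden_dim=256, out_dim=3,
+    #                         displacement_emb="linear", displacement_emb_dim=32)
+    #   certainty, displacement = refiner(x, y, flow)
+    # with x, y of shape (B, 256, H, W) and flow of shape (B, 2, H, W) in [-1, 1];
+    # in_dim = 2 * C + displacement_emb_dim matches the concatenation above.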
+
+
+class CosKernel(nn.Module): # similar to softmax kernel
+ def __init__(self, T, learn_temperature=False):
+ super().__init__()
+ self.learn_temperature = learn_temperature
+ if self.learn_temperature:
+ self.T = nn.Parameter(torch.tensor(T))
+ else:
+ self.T = T
+
+ def __call__(self, x, y, eps=1e-6):
+ c = torch.einsum("bnd,bmd->bnm", x, y) / (
+ x.norm(dim=-1)[..., None] * y.norm(dim=-1)[:, None] + eps
+ )
+ if self.learn_temperature:
+ T = self.T.abs() + 0.01
+ else:
+ T = torch.tensor(self.T, device=c.device)
+ K = ((c - 1.0) / T).exp()
+ return K
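+        # i.e. K(x, y) = exp((cos_sim(x, y) - 1) / T): 1 for identical directions,
+        # decaying towards exp(-2 / T) for opposite ones; T acts as a temperature.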
+
+
+class CAB(nn.Module):
+ def __init__(self, in_channels, out_channels):
+ super(CAB, self).__init__()
+ self.global_pooling = nn.AdaptiveAvgPool2d(1)
+ self.conv1 = nn.Conv2d(
+ in_channels, out_channels, kernel_size=1, stride=1, padding=0
+ )
+ self.relu = nn.ReLU()
+ self.conv2 = nn.Conv2d(
+ out_channels, out_channels, kernel_size=1, stride=1, padding=0
+ )
+        self.sigmoid = nn.Sigmoid()
+
+ def forward(self, x):
+ x1, x2 = x # high, low (old, new)
+ x = torch.cat([x1, x2], dim=1)
+ x = self.global_pooling(x)
+ x = self.conv1(x)
+ x = self.relu(x)
+ x = self.conv2(x)
+        x = self.sigmoid(x)
+ x2 = x * x2
+ res = x2 + x1
+ return res
+
+
+class RRB(nn.Module):
+ def __init__(self, in_channels, out_channels, kernel_size=3):
+ super(RRB, self).__init__()
+ self.conv1 = nn.Conv2d(
+ in_channels, out_channels, kernel_size=1, stride=1, padding=0
+ )
+ self.conv2 = nn.Conv2d(
+ out_channels,
+ out_channels,
+ kernel_size=kernel_size,
+ stride=1,
+ padding=kernel_size // 2,
+ )
+ self.relu = nn.ReLU()
+ self.bn = nn.BatchNorm2d(out_channels)
+ self.conv3 = nn.Conv2d(
+ out_channels,
+ out_channels,
+ kernel_size=kernel_size,
+ stride=1,
+ padding=kernel_size // 2,
+ )
+
+ def forward(self, x):
+ x = self.conv1(x)
+ res = self.conv2(x)
+ res = self.bn(res)
+ res = self.relu(res)
+ res = self.conv3(res)
+ return self.relu(x + res)
+
+
+class DFN(nn.Module):
+ def __init__(
+ self,
+ internal_dim,
+ feat_input_modules,
+ pred_input_modules,
+ rrb_d_dict,
+ cab_dict,
+ rrb_u_dict,
+ use_global_context=False,
+ global_dim=None,
+ terminal_module=None,
+ upsample_mode="bilinear",
+ align_corners=False,
+ ):
+ super().__init__()
+ if use_global_context:
+ assert (
+ global_dim is not None
+ ), "Global dim must be provided when using global context"
+ self.align_corners = align_corners
+ self.internal_dim = internal_dim
+ self.feat_input_modules = feat_input_modules
+ self.pred_input_modules = pred_input_modules
+ self.rrb_d = rrb_d_dict
+ self.cab = cab_dict
+ self.rrb_u = rrb_u_dict
+ self.use_global_context = use_global_context
+ if use_global_context:
+ self.global_to_internal = nn.Conv2d(global_dim, self.internal_dim, 1, 1, 0)
+ self.global_pooling = nn.AdaptiveAvgPool2d(1)
+ self.terminal_module = (
+ terminal_module if terminal_module is not None else nn.Identity()
+ )
+ self.upsample_mode = upsample_mode
+ self._scales = [int(key) for key in self.terminal_module.keys()]
+
+ def scales(self):
+ return self._scales.copy()
+
+ def forward(self, embeddings, feats, context, key):
+ feats = self.feat_input_modules[str(key)](feats)
+ embeddings = torch.cat([feats, embeddings], dim=1)
+ embeddings = self.rrb_d[str(key)](embeddings)
+ context = self.cab[str(key)]([context, embeddings])
+ context = self.rrb_u[str(key)](context)
+ preds = self.terminal_module[str(key)](context)
+ pred_coord = preds[:, -2:]
+ pred_certainty = preds[:, :-2]
+ return pred_coord, pred_certainty, context
+
+
+class GP(nn.Module):
+ def __init__(
+ self,
+ kernel,
+ T=1,
+ learn_temperature=False,
+ only_attention=False,
+ gp_dim=64,
+ basis="fourier",
+ covar_size=5,
+ only_nearest_neighbour=False,
+ sigma_noise=0.1,
+ no_cov=False,
+ predict_features = False,
+ ):
+ super().__init__()
+ self.K = kernel(T=T, learn_temperature=learn_temperature)
+ self.sigma_noise = sigma_noise
+ self.covar_size = covar_size
+ self.pos_conv = torch.nn.Conv2d(2, gp_dim, 1, 1)
+ self.only_attention = only_attention
+ self.only_nearest_neighbour = only_nearest_neighbour
+ self.basis = basis
+ self.no_cov = no_cov
+ self.dim = gp_dim
+ self.predict_features = predict_features
+
+ def get_local_cov(self, cov):
+ K = self.covar_size
+ b, h, w, h, w = cov.shape
+ hw = h * w
+ cov = F.pad(cov, 4 * (K // 2,)) # pad v_q
+ delta = torch.stack(
+ torch.meshgrid(
+ torch.arange(-(K // 2), K // 2 + 1), torch.arange(-(K // 2), K // 2 + 1)
+ ),
+ dim=-1,
+ )
+ positions = torch.stack(
+ torch.meshgrid(
+ torch.arange(K // 2, h + K // 2), torch.arange(K // 2, w + K // 2)
+ ),
+ dim=-1,
+ )
+ neighbours = positions[:, :, None, None, :] + delta[None, :, :]
+ points = torch.arange(hw)[:, None].expand(hw, K**2)
+ local_cov = cov.reshape(b, hw, h + K - 1, w + K - 1)[
+ :,
+ points.flatten(),
+ neighbours[..., 0].flatten(),
+ neighbours[..., 1].flatten(),
+ ].reshape(b, h, w, K**2)
+ return local_cov
+
+ def reshape(self, x):
+ return rearrange(x, "b d h w -> b (h w) d")
+
+ def project_to_basis(self, x):
+ if self.basis == "fourier":
+ return torch.cos(8 * math.pi * self.pos_conv(x))
+ elif self.basis == "linear":
+ return self.pos_conv(x)
+ else:
+ raise ValueError(
+ "No other bases other than fourier and linear currently supported in public release"
+ )
+
+ def get_pos_enc(self, y):
+ b, c, h, w = y.shape
+ coarse_coords = torch.meshgrid(
+ (
+ torch.linspace(-1 + 1 / h, 1 - 1 / h, h, device=y.device),
+ torch.linspace(-1 + 1 / w, 1 - 1 / w, w, device=y.device),
+ )
+ )
+
+ coarse_coords = torch.stack((coarse_coords[1], coarse_coords[0]), dim=-1)[
+ None
+ ].expand(b, h, w, 2)
+ coarse_coords = rearrange(coarse_coords, "b h w d -> b d h w")
+ coarse_embedded_coords = self.project_to_basis(coarse_coords)
+ return coarse_embedded_coords
+
+ def forward(self, x, y, **kwargs):
+ b, c, h1, w1 = x.shape
+ b, c, h2, w2 = y.shape
+ f = self.get_pos_enc(y)
+ if self.predict_features:
+ f = f + y[:,:self.dim] # Stupid way to predict features
+ b, d, h2, w2 = f.shape
+ #assert x.shape == y.shape
+ x, y, f = self.reshape(x), self.reshape(y), self.reshape(f)
+ K_xx = self.K(x, x)
+ K_yy = self.K(y, y)
+ K_xy = self.K(x, y)
+ K_yx = K_xy.permute(0, 2, 1)
+ sigma_noise = self.sigma_noise * torch.eye(h2 * w2, device=x.device)[None, :, :]
+        # To avoid the noisy warnings from https://github.com/pytorch/pytorch/issues/16963, invert per batch element when N is large
+ if len(K_yy[0]) > 2000:
+ K_yy_inv = torch.cat([torch.linalg.inv(K_yy[k:k+1] + sigma_noise[k:k+1]) for k in range(b)])
+ else:
+ K_yy_inv = torch.linalg.inv(K_yy + sigma_noise)
+
+ mu_x = K_xy.matmul(K_yy_inv.matmul(f))
+ mu_x = rearrange(mu_x, "b (h w) d -> b d h w", h=h1, w=w1)
+ if not self.no_cov:
+ cov_x = K_xx - K_xy.matmul(K_yy_inv.matmul(K_yx))
+ cov_x = rearrange(cov_x, "b (h w) (r c) -> b h w r c", h=h1, w=w1, r=h1, c=w1)
+ local_cov_x = self.get_local_cov(cov_x)
+ local_cov_x = rearrange(local_cov_x, "b h w K -> b K h w")
+ gp_feats = torch.cat((mu_x, local_cov_x), dim=1)
+ else:
+ gp_feats = mu_x
+ return gp_feats
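+        # The lines above implement standard GP regression on the support positional
+        # embedding f:
+        #   mu_x  = K_xy (K_yy + sigma_noise * I)^{-1} f
+        #   cov_x = K_xx - K_xy (K_yy + sigma_noise * I)^{-1} K_yx
+        # mu_x is reshaped to a per-query-pixel embedding; cov_x is reduced to a local
+        # window of covar_size**2 entries unless no_cov is set.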
+
+
+class Encoder(nn.Module):
+ def __init__(self, resnet):
+ super().__init__()
+ self.resnet = resnet
+ def forward(self, x):
+ x0 = x
+ b, c, h, w = x.shape
+ x = self.resnet.conv1(x)
+ x = self.resnet.bn1(x)
+ x1 = self.resnet.relu(x)
+
+ x = self.resnet.maxpool(x1)
+ x2 = self.resnet.layer1(x)
+
+ x3 = self.resnet.layer2(x2)
+
+ x4 = self.resnet.layer3(x3)
+
+ x5 = self.resnet.layer4(x4)
+ feats = {32: x5, 16: x4, 8: x3, 4: x2, 2: x1, 1: x0}
+ return feats
+
+ def train(self, mode=True):
+ super().train(mode)
+ for m in self.modules():
+ if isinstance(m, nn.BatchNorm2d):
+ m.eval()
+ pass
+
+
+class Decoder(nn.Module):
+ def __init__(
+ self, embedding_decoder, gps, proj, conv_refiner, transformers = None, detach=False, scales="all", pos_embeddings = None,
+ ):
+ super().__init__()
+ self.embedding_decoder = embedding_decoder
+ self.gps = gps
+ self.proj = proj
+ self.conv_refiner = conv_refiner
+ self.detach = detach
+ if scales == "all":
+ self.scales = ["32", "16", "8", "4", "2", "1"]
+ else:
+ self.scales = scales
+
+ def upsample_preds(self, flow, certainty, query, support):
+ b, hs, ws, d = flow.shape
+ b, c, h, w = query.shape
+ flow = flow.permute(0, 3, 1, 2)
+ certainty = F.interpolate(
+ certainty, size=(h, w), align_corners=False, mode="bilinear"
+ )
+ flow = F.interpolate(
+ flow, size=(h, w), align_corners=False, mode="bilinear"
+ )
+ delta_certainty, delta_flow = self.conv_refiner["1"](query, support, flow)
+ flow = torch.stack(
+ (
+ flow[:, 0] + delta_flow[:, 0] / (4 * w),
+ flow[:, 1] + delta_flow[:, 1] / (4 * h),
+ ),
+ dim=1,
+ )
+ flow = flow.permute(0, 2, 3, 1)
+ certainty = certainty + delta_certainty
+ return flow, certainty
+
+ def get_placeholder_flow(self, b, h, w, device):
+ coarse_coords = torch.meshgrid(
+ (
+ torch.linspace(-1 + 1 / h, 1 - 1 / h, h, device=device),
+ torch.linspace(-1 + 1 / w, 1 - 1 / w, w, device=device),
+ )
+ )
+ coarse_coords = torch.stack((coarse_coords[1], coarse_coords[0]), dim=-1)[
+ None
+ ].expand(b, h, w, 2)
+ coarse_coords = rearrange(coarse_coords, "b h w d -> b d h w")
+ return coarse_coords
+
+
+ def forward(self, f1, f2, upsample = False, dense_flow = None, dense_certainty = None):
+ coarse_scales = self.embedding_decoder.scales()
+ all_scales = self.scales if not upsample else ["8", "4", "2", "1"]
+ sizes = {scale: f1[scale].shape[-2:] for scale in f1}
+ h, w = sizes[1]
+ b = f1[1].shape[0]
+ device = f1[1].device
+ coarsest_scale = int(all_scales[0])
+ old_stuff = torch.zeros(
+ b, self.embedding_decoder.internal_dim, *sizes[coarsest_scale], device=f1[coarsest_scale].device
+ )
+ dense_corresps = {}
+ if not upsample:
+ dense_flow = self.get_placeholder_flow(b, *sizes[coarsest_scale], device)
+ dense_certainty = 0.0
+ else:
+ dense_flow = F.interpolate(
+ dense_flow,
+ size=sizes[coarsest_scale],
+ align_corners=False,
+ mode="bilinear",
+ )
+ dense_certainty = F.interpolate(
+ dense_certainty,
+ size=sizes[coarsest_scale],
+ align_corners=False,
+ mode="bilinear",
+ )
+ for new_scale in all_scales:
+ ins = int(new_scale)
+ f1_s, f2_s = f1[ins], f2[ins]
+ if new_scale in self.proj:
+ f1_s, f2_s = self.proj[new_scale](f1_s), self.proj[new_scale](f2_s)
+ b, c, hs, ws = f1_s.shape
+ if ins in coarse_scales:
+ old_stuff = F.interpolate(
+ old_stuff, size=sizes[ins], mode="bilinear", align_corners=False
+ )
+ new_stuff = self.gps[new_scale](f1_s, f2_s, dense_flow=dense_flow)
+ dense_flow, dense_certainty, old_stuff = self.embedding_decoder(
+ new_stuff, f1_s, old_stuff, new_scale
+ )
+
+ if new_scale in self.conv_refiner:
+ delta_certainty, displacement = self.conv_refiner[new_scale](
+ f1_s, f2_s, dense_flow
+ )
+ dense_flow = torch.stack(
+ (
+ dense_flow[:, 0] + ins * displacement[:, 0] / (4 * w),
+ dense_flow[:, 1] + ins * displacement[:, 1] / (4 * h),
+ ),
+ dim=1,
+ )
+ dense_certainty = (
+ dense_certainty + delta_certainty
+ ) # predict both certainty and displacement
+
+ dense_corresps[ins] = {
+ "dense_flow": dense_flow,
+ "dense_certainty": dense_certainty,
+ }
+
+ if new_scale != "1":
+ dense_flow = F.interpolate(
+ dense_flow,
+ size=sizes[ins // 2],
+ align_corners=False,
+ mode="bilinear",
+ )
+
+ dense_certainty = F.interpolate(
+ dense_certainty,
+ size=sizes[ins // 2],
+ align_corners=False,
+ mode="bilinear",
+ )
+ if self.detach:
+ dense_flow = dense_flow.detach()
+ dense_certainty = dense_certainty.detach()
+ return dense_corresps
+
+
+class RegressionMatcher(nn.Module):
+ def __init__(
+ self,
+ encoder,
+ decoder,
+ h=384,
+ w=512,
+ use_contrastive_loss = False,
+ alpha = 1,
+ beta = 0,
+ sample_mode = "threshold",
+ upsample_preds = False,
+ symmetric = False,
+ name = None,
+ use_soft_mutual_nearest_neighbours = False,
+ ):
+ super().__init__()
+ self.encoder = encoder
+ self.decoder = decoder
+ self.w_resized = w
+ self.h_resized = h
+ self.og_transforms = get_tuple_transform_ops(resize=None, normalize=True)
+ self.use_contrastive_loss = use_contrastive_loss
+ self.alpha = alpha
+ self.beta = beta
+ self.sample_mode = sample_mode
+ self.upsample_preds = upsample_preds
+ self.symmetric = symmetric
+ self.name = name
+ self.sample_thresh = 0.05
+ self.upsample_res = (864,1152)
+ if use_soft_mutual_nearest_neighbours:
+ assert symmetric, "MNS requires symmetric inference"
+ self.use_soft_mutual_nearest_neighbours = use_soft_mutual_nearest_neighbours
+
+ def extract_backbone_features(self, batch, batched = True, upsample = True):
+ #TODO: only extract stride [1,2,4,8] for upsample = True
+ x_q = batch["query"]
+ x_s = batch["support"]
+ if batched:
+ X = torch.cat((x_q, x_s))
+ feature_pyramid = self.encoder(X)
+ else:
+ feature_pyramid = self.encoder(x_q), self.encoder(x_s)
+ return feature_pyramid
+
+ def sample(
+ self,
+ dense_matches,
+ dense_certainty,
+ num=10000,
+ ):
+ if "threshold" in self.sample_mode:
+ upper_thresh = self.sample_thresh
+ dense_certainty = dense_certainty.clone()
+ dense_certainty[dense_certainty > upper_thresh] = 1
+ elif "pow" in self.sample_mode:
+ dense_certainty = dense_certainty**(1/3)
+ elif "naive" in self.sample_mode:
+ dense_certainty = torch.ones_like(dense_certainty)
+ matches, certainty = (
+ dense_matches.reshape(-1, 4),
+ dense_certainty.reshape(-1),
+ )
+ expansion_factor = 4 if "balanced" in self.sample_mode else 1
+ good_samples = torch.multinomial(certainty,
+ num_samples = min(expansion_factor*num, len(certainty)),
+ replacement=False)
+ good_matches, good_certainty = matches[good_samples], certainty[good_samples]
+ if "balanced" not in self.sample_mode:
+ return good_matches, good_certainty
+
+ from ..utils.kde import kde
+ density = kde(good_matches, std=0.1)
+ p = 1 / (density+1)
+ p[density < 10] = 1e-7 # Basically should have at least 10 perfect neighbours, or around 100 ok ones
+ balanced_samples = torch.multinomial(p,
+ num_samples = min(num,len(good_certainty)),
+ replacement=False)
+ return good_matches[balanced_samples], good_certainty[balanced_samples]
+
+ def forward(self, batch, batched = True):
+ feature_pyramid = self.extract_backbone_features(batch, batched=batched)
+ if batched:
+ f_q_pyramid = {
+ scale: f_scale.chunk(2)[0] for scale, f_scale in feature_pyramid.items()
+ }
+ f_s_pyramid = {
+ scale: f_scale.chunk(2)[1] for scale, f_scale in feature_pyramid.items()
+ }
+ else:
+ f_q_pyramid, f_s_pyramid = feature_pyramid
+ dense_corresps = self.decoder(f_q_pyramid, f_s_pyramid)
+ if self.training and self.use_contrastive_loss:
+ return dense_corresps, (f_q_pyramid, f_s_pyramid)
+ else:
+ return dense_corresps
+
+ def forward_symmetric(self, batch, upsample = False, batched = True):
+ feature_pyramid = self.extract_backbone_features(batch, upsample = upsample, batched = batched)
+ f_q_pyramid = feature_pyramid
+ f_s_pyramid = {
+ scale: torch.cat((f_scale.chunk(2)[1], f_scale.chunk(2)[0]))
+ for scale, f_scale in feature_pyramid.items()
+ }
+ dense_corresps = self.decoder(f_q_pyramid, f_s_pyramid, upsample = upsample, **(batch["corresps"] if "corresps" in batch else {}))
+ return dense_corresps
+
+ def to_pixel_coordinates(self, matches, H_A, W_A, H_B, W_B):
+ kpts_A, kpts_B = matches[...,:2], matches[...,2:]
+ kpts_A = torch.stack((W_A/2 * (kpts_A[...,0]+1), H_A/2 * (kpts_A[...,1]+1)),axis=-1)
+ kpts_B = torch.stack((W_B/2 * (kpts_B[...,0]+1), H_B/2 * (kpts_B[...,1]+1)),axis=-1)
+ return kpts_A, kpts_B
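+    # Typical use together with match()/sample() (a sketch; paths are placeholders):
+    #   warp, certainty = matcher.match(im_A_path, im_B_path)
+    #   matches, certainty = matcher.sample(warp, certainty)
+    #   kpts_A, kpts_B = matcher.to_pixel_coordinates(matches, H_A, W_A, H_B, W_B)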
+
+ def match(
+ self,
+ im1_path,
+ im2_path,
+ *args,
+ batched=False,
+ device = None
+ ):
+        assert not (batched and self.upsample_preds), "Cannot upsample predictions in batched mode (no access to the high-resolution images); set model.upsample_preds = False to disable."
+ if isinstance(im1_path, (str, os.PathLike)):
+ im1, im2 = Image.open(im1_path), Image.open(im2_path)
+ else: # assume it is a PIL Image
+ im1, im2 = im1_path, im2_path
+ if device is None:
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ symmetric = self.symmetric
+ self.train(False)
+ with torch.no_grad():
+ if not batched:
+ b = 1
+ w, h = im1.size
+ w2, h2 = im2.size
+ # Get images in good format
+ ws = self.w_resized
+ hs = self.h_resized
+
+ test_transform = get_tuple_transform_ops(
+ resize=(hs, ws), normalize=True
+ )
+ query, support = test_transform((im1, im2))
+ batch = {"query": query[None].to(device), "support": support[None].to(device)}
+ else:
+ b, c, h, w = im1.shape
+ b, c, h2, w2 = im2.shape
+ assert w == w2 and h == h2, "For batched images we assume same size"
+ batch = {"query": im1.to(device), "support": im2.to(device)}
+ hs, ws = self.h_resized, self.w_resized
+ finest_scale = 1
+ # Run matcher
+ if symmetric:
+ dense_corresps = self.forward_symmetric(batch, batched = True)
+ else:
+ dense_corresps = self.forward(batch, batched = True)
+
+ if self.upsample_preds:
+ hs, ws = self.upsample_res
+ low_res_certainty = F.interpolate(
+ dense_corresps[16]["dense_certainty"], size=(hs, ws), align_corners=False, mode="bilinear"
+ )
+ cert_clamp = 0
+ factor = 0.5
+ low_res_certainty = factor*low_res_certainty*(low_res_certainty < cert_clamp)
+
+ if self.upsample_preds:
+ test_transform = get_tuple_transform_ops(
+ resize=(hs, ws), normalize=True
+ )
+ query, support = test_transform((im1, im2))
+ query, support = query[None].to(device), support[None].to(device)
+ batch = {"query": query, "support": support, "corresps": dense_corresps[finest_scale]}
+ if symmetric:
+ dense_corresps = self.forward_symmetric(batch, upsample = True, batched=True)
+ else:
+ dense_corresps = self.forward(batch, batched = True, upsample=True)
+ query_to_support = dense_corresps[finest_scale]["dense_flow"]
+ dense_certainty = dense_corresps[finest_scale]["dense_certainty"]
+
+ # Get certainty interpolation
+ dense_certainty = dense_certainty - low_res_certainty
+ query_to_support = query_to_support.permute(
+ 0, 2, 3, 1
+ )
+ # Create im1 meshgrid
+ query_coords = torch.meshgrid(
+ (
+ torch.linspace(-1 + 1 / hs, 1 - 1 / hs, hs, device=device),
+ torch.linspace(-1 + 1 / ws, 1 - 1 / ws, ws, device=device),
+ )
+ )
+ query_coords = torch.stack((query_coords[1], query_coords[0]))
+ query_coords = query_coords[None].expand(b, 2, hs, ws)
+ dense_certainty = dense_certainty.sigmoid() # logits -> probs
+ query_coords = query_coords.permute(0, 2, 3, 1)
+            if (query_to_support.abs() > 1).any():
+ wrong = (query_to_support.abs() > 1).sum(dim=-1) > 0
+ dense_certainty[wrong[:,None]] = 0
+
+ query_to_support = torch.clamp(query_to_support, -1, 1)
+ if symmetric:
+ support_coords = query_coords
+ qts, stq = query_to_support.chunk(2)
+ q_warp = torch.cat((query_coords, qts), dim=-1)
+ s_warp = torch.cat((stq, support_coords), dim=-1)
+ warp = torch.cat((q_warp, s_warp),dim=2)
+ dense_certainty = torch.cat(dense_certainty.chunk(2), dim=3)[:,0]
+ else:
+ warp = torch.cat((query_coords, query_to_support), dim=-1)
+ if batched:
+ return (
+ warp,
+ dense_certainty
+ )
+ else:
+ return (
+ warp[0],
+ dense_certainty[0],
+ )
diff --git a/imcui/third_party/DKM/dkm/models/encoders.py b/imcui/third_party/DKM/dkm/models/encoders.py
new file mode 100644
index 0000000000000000000000000000000000000000..29077e1797196611e9b59a753130a5b153e0aa05
--- /dev/null
+++ b/imcui/third_party/DKM/dkm/models/encoders.py
@@ -0,0 +1,147 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torchvision.models as tvm
+
+class ResNet18(nn.Module):
+ def __init__(self, pretrained=False) -> None:
+ super().__init__()
+ self.net = tvm.resnet18(pretrained=pretrained)
+ def forward(self, x):
+        net = self.net
+        x1 = x
+        x = net.conv1(x1)
+        x = net.bn1(x)
+        x2 = net.relu(x)
+        x = net.maxpool(x2)
+        x4 = net.layer1(x)
+        x8 = net.layer2(x4)
+        x16 = net.layer3(x8)
+        x32 = net.layer4(x16)
+        return {32: x32, 16: x16, 8: x8, 4: x4, 2: x2, 1: x1}
+
+ def train(self, mode=True):
+ super().train(mode)
+ for m in self.modules():
+ if isinstance(m, nn.BatchNorm2d):
+ m.eval()
+ pass
+
+class ResNet50(nn.Module):
+ def __init__(self, pretrained=False, high_res = False, weights = None, dilation = None, freeze_bn = True, anti_aliased = False) -> None:
+ super().__init__()
+ if dilation is None:
+ dilation = [False,False,False]
+ if anti_aliased:
+ pass
+ else:
+ if weights is not None:
+ self.net = tvm.resnet50(weights = weights,replace_stride_with_dilation=dilation)
+ else:
+ self.net = tvm.resnet50(pretrained=pretrained,replace_stride_with_dilation=dilation)
+
+ self.high_res = high_res
+ self.freeze_bn = freeze_bn
+ def forward(self, x):
+ net = self.net
+ feats = {1:x}
+ x = net.conv1(x)
+ x = net.bn1(x)
+ x = net.relu(x)
+ feats[2] = x
+ x = net.maxpool(x)
+ x = net.layer1(x)
+ feats[4] = x
+ x = net.layer2(x)
+ feats[8] = x
+ x = net.layer3(x)
+ feats[16] = x
+ x = net.layer4(x)
+ feats[32] = x
+ return feats
+
+ def train(self, mode=True):
+ super().train(mode)
+ if self.freeze_bn:
+ for m in self.modules():
+ if isinstance(m, nn.BatchNorm2d):
+ m.eval()
+ pass
+
+
+
+
+class ResNet101(nn.Module):
+ def __init__(self, pretrained=False, high_res = False, weights = None) -> None:
+ super().__init__()
+ if weights is not None:
+ self.net = tvm.resnet101(weights = weights)
+ else:
+ self.net = tvm.resnet101(pretrained=pretrained)
+ self.high_res = high_res
+ self.scale_factor = 1 if not high_res else 1.5
+ def forward(self, x):
+ net = self.net
+ feats = {1:x}
+ sf = self.scale_factor
+ if self.high_res:
+ x = F.interpolate(x, scale_factor=sf, align_corners=False, mode="bicubic")
+ x = net.conv1(x)
+ x = net.bn1(x)
+ x = net.relu(x)
+ feats[2] = x if not self.high_res else F.interpolate(x,scale_factor=1/sf,align_corners=False, mode="bilinear")
+ x = net.maxpool(x)
+ x = net.layer1(x)
+ feats[4] = x if not self.high_res else F.interpolate(x,scale_factor=1/sf,align_corners=False, mode="bilinear")
+ x = net.layer2(x)
+ feats[8] = x if not self.high_res else F.interpolate(x,scale_factor=1/sf,align_corners=False, mode="bilinear")
+ x = net.layer3(x)
+ feats[16] = x if not self.high_res else F.interpolate(x,scale_factor=1/sf,align_corners=False, mode="bilinear")
+ x = net.layer4(x)
+ feats[32] = x if not self.high_res else F.interpolate(x,scale_factor=1/sf,align_corners=False, mode="bilinear")
+ return feats
+
+ def train(self, mode=True):
+ super().train(mode)
+ for m in self.modules():
+ if isinstance(m, nn.BatchNorm2d):
+ m.eval()
+ pass
+
+
+class WideResNet50(nn.Module):
+ def __init__(self, pretrained=False, high_res = False, weights = None) -> None:
+ super().__init__()
+ if weights is not None:
+ self.net = tvm.wide_resnet50_2(weights = weights)
+ else:
+ self.net = tvm.wide_resnet50_2(pretrained=pretrained)
+ self.high_res = high_res
+ self.scale_factor = 1 if not high_res else 1.5
+ def forward(self, x):
+ net = self.net
+ feats = {1:x}
+ sf = self.scale_factor
+ if self.high_res:
+ x = F.interpolate(x, scale_factor=sf, align_corners=False, mode="bicubic")
+ x = net.conv1(x)
+ x = net.bn1(x)
+ x = net.relu(x)
+ feats[2] = x if not self.high_res else F.interpolate(x,scale_factor=1/sf,align_corners=False, mode="bilinear")
+ x = net.maxpool(x)
+ x = net.layer1(x)
+ feats[4] = x if not self.high_res else F.interpolate(x,scale_factor=1/sf,align_corners=False, mode="bilinear")
+ x = net.layer2(x)
+ feats[8] = x if not self.high_res else F.interpolate(x,scale_factor=1/sf,align_corners=False, mode="bilinear")
+ x = net.layer3(x)
+ feats[16] = x if not self.high_res else F.interpolate(x,scale_factor=1/sf,align_corners=False, mode="bilinear")
+ x = net.layer4(x)
+ feats[32] = x if not self.high_res else F.interpolate(x,scale_factor=1/sf,align_corners=False, mode="bilinear")
+ return feats
+
+ def train(self, mode=True):
+ super().train(mode)
+ for m in self.modules():
+ if isinstance(m, nn.BatchNorm2d):
+ m.eval()
+ pass
\ No newline at end of file
diff --git a/imcui/third_party/DKM/dkm/models/model_zoo/DKMv3.py b/imcui/third_party/DKM/dkm/models/model_zoo/DKMv3.py
new file mode 100644
index 0000000000000000000000000000000000000000..6f4c9ede3863d778f679a033d8d2287b8776e894
--- /dev/null
+++ b/imcui/third_party/DKM/dkm/models/model_zoo/DKMv3.py
@@ -0,0 +1,150 @@
+import torch
+
+from torch import nn
+from ..dkm import *
+from ..encoders import *
+
+
+def DKMv3(weights, h, w, symmetric = True, sample_mode= "threshold_balanced", device = None, **kwargs):
+ if device is None:
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ gp_dim = 256
+ dfn_dim = 384
+ feat_dim = 256
+ coordinate_decoder = DFN(
+ internal_dim=dfn_dim,
+ feat_input_modules=nn.ModuleDict(
+ {
+ "32": nn.Conv2d(512, feat_dim, 1, 1),
+ "16": nn.Conv2d(512, feat_dim, 1, 1),
+ }
+ ),
+ pred_input_modules=nn.ModuleDict(
+ {
+ "32": nn.Identity(),
+ "16": nn.Identity(),
+ }
+ ),
+ rrb_d_dict=nn.ModuleDict(
+ {
+ "32": RRB(gp_dim + feat_dim, dfn_dim),
+ "16": RRB(gp_dim + feat_dim, dfn_dim),
+ }
+ ),
+ cab_dict=nn.ModuleDict(
+ {
+ "32": CAB(2 * dfn_dim, dfn_dim),
+ "16": CAB(2 * dfn_dim, dfn_dim),
+ }
+ ),
+ rrb_u_dict=nn.ModuleDict(
+ {
+ "32": RRB(dfn_dim, dfn_dim),
+ "16": RRB(dfn_dim, dfn_dim),
+ }
+ ),
+ terminal_module=nn.ModuleDict(
+ {
+ "32": nn.Conv2d(dfn_dim, 3, 1, 1, 0),
+ "16": nn.Conv2d(dfn_dim, 3, 1, 1, 0),
+ }
+ ),
+ )
+ dw = True
+ hidden_blocks = 8
+ kernel_size = 5
+ displacement_emb = "linear"
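+    # For the scales that use a local correlation window, ConvRefiner in_dim below is
+    # 2 * feat_dim + displacement_emb_dim + (2 * local_corr_radius + 1) ** 2, i.e.
+    # query feats + warped support feats + displacement embedding + correlation window
+    # (see the concatenation in ConvRefiner.forward).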
+ conv_refiner = nn.ModuleDict(
+ {
+ "16": ConvRefiner(
+ 2 * 512+128+(2*7+1)**2,
+ 2 * 512+128+(2*7+1)**2,
+ 3,
+ kernel_size=kernel_size,
+ dw=dw,
+ hidden_blocks=hidden_blocks,
+ displacement_emb=displacement_emb,
+ displacement_emb_dim=128,
+ local_corr_radius = 7,
+ corr_in_other = True,
+ ),
+ "8": ConvRefiner(
+ 2 * 512+64+(2*3+1)**2,
+ 2 * 512+64+(2*3+1)**2,
+ 3,
+ kernel_size=kernel_size,
+ dw=dw,
+ hidden_blocks=hidden_blocks,
+ displacement_emb=displacement_emb,
+ displacement_emb_dim=64,
+ local_corr_radius = 3,
+ corr_in_other = True,
+ ),
+ "4": ConvRefiner(
+ 2 * 256+32+(2*2+1)**2,
+ 2 * 256+32+(2*2+1)**2,
+ 3,
+ kernel_size=kernel_size,
+ dw=dw,
+ hidden_blocks=hidden_blocks,
+ displacement_emb=displacement_emb,
+ displacement_emb_dim=32,
+ local_corr_radius = 2,
+ corr_in_other = True,
+ ),
+ "2": ConvRefiner(
+ 2 * 64+16,
+ 128+16,
+ 3,
+ kernel_size=kernel_size,
+ dw=dw,
+ hidden_blocks=hidden_blocks,
+ displacement_emb=displacement_emb,
+ displacement_emb_dim=16,
+ ),
+ "1": ConvRefiner(
+ 2 * 3+6,
+ 24,
+ 3,
+ kernel_size=kernel_size,
+ dw=dw,
+ hidden_blocks=hidden_blocks,
+ displacement_emb=displacement_emb,
+ displacement_emb_dim=6,
+ ),
+ }
+ )
+ kernel_temperature = 0.2
+ learn_temperature = False
+ no_cov = True
+ kernel = CosKernel
+ only_attention = False
+ basis = "fourier"
+ gp32 = GP(
+ kernel,
+ T=kernel_temperature,
+ learn_temperature=learn_temperature,
+ only_attention=only_attention,
+ gp_dim=gp_dim,
+ basis=basis,
+ no_cov=no_cov,
+ )
+ gp16 = GP(
+ kernel,
+ T=kernel_temperature,
+ learn_temperature=learn_temperature,
+ only_attention=only_attention,
+ gp_dim=gp_dim,
+ basis=basis,
+ no_cov=no_cov,
+ )
+ gps = nn.ModuleDict({"32": gp32, "16": gp16})
+ proj = nn.ModuleDict(
+ {"16": nn.Conv2d(1024, 512, 1, 1), "32": nn.Conv2d(2048, 512, 1, 1)}
+ )
+ decoder = Decoder(coordinate_decoder, gps, proj, conv_refiner, detach=True)
+
+ encoder = ResNet50(pretrained = False, high_res = False, freeze_bn=False)
+ matcher = RegressionMatcher(encoder, decoder, h=h, w=w, name = "DKMv3", sample_mode=sample_mode, symmetric = symmetric, **kwargs).to(device)
+    matcher.load_state_dict(weights)
+ return matcher
diff --git a/imcui/third_party/DKM/dkm/models/model_zoo/__init__.py b/imcui/third_party/DKM/dkm/models/model_zoo/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c85da2920c1acfac140ada2d87623203607d42ca
--- /dev/null
+++ b/imcui/third_party/DKM/dkm/models/model_zoo/__init__.py
@@ -0,0 +1,39 @@
+weight_urls = {
+ "DKMv3": {
+ "outdoor": "https://github.com/Parskatt/storage/releases/download/dkmv3/DKMv3_outdoor.pth",
+ "indoor": "https://github.com/Parskatt/storage/releases/download/dkmv3/DKMv3_indoor.pth",
+ },
+}
+import torch
+from .DKMv3 import DKMv3
+
+
+def DKMv3_outdoor(path_to_weights = None, device=None):
+ """
+    Loads DKMv3 outdoor weights; uses an internal resolution of (540, 720) by default.
+    The resolution can be changed by setting model.h_resized and model.w_resized later.
+    Additionally upsamples predictions to a fixed resolution of (864, 1152),
+    which can be turned off via model.upsample_preds = False.
+ """
+ if device is None:
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ if path_to_weights is not None:
+ weights = torch.load(path_to_weights, map_location='cpu')
+ else:
+ weights = torch.hub.load_state_dict_from_url(weight_urls["DKMv3"]["outdoor"],
+ map_location='cpu')
+ return DKMv3(weights, 540, 720, upsample_preds = True, device=device)
+
+def DKMv3_indoor(path_to_weights = None, device=None):
+ """
+    Loads DKMv3 indoor weights; uses an internal resolution of (480, 640) by default.
+    The resolution can be changed by setting model.h_resized and model.w_resized later.
+ """
+ if device is None:
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ if path_to_weights is not None:
+ weights = torch.load(path_to_weights, map_location=device)
+ else:
+ weights = torch.hub.load_state_dict_from_url(weight_urls["DKMv3"]["indoor"],
+ map_location=device)
+ return DKMv3(weights, 480, 640, upsample_preds = False, device=device)
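+
+# Minimal usage sketch (image paths are placeholders):
+#   model = DKMv3_outdoor()
+#   warp, certainty = model.match("im_A.jpg", "im_B.jpg")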
diff --git a/imcui/third_party/DKM/dkm/train/__init__.py b/imcui/third_party/DKM/dkm/train/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..90269dc0f345a575e0ba21f5afa34202c7e6b433
--- /dev/null
+++ b/imcui/third_party/DKM/dkm/train/__init__.py
@@ -0,0 +1 @@
+from .train import train_k_epochs
diff --git a/imcui/third_party/DKM/dkm/train/train.py b/imcui/third_party/DKM/dkm/train/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..b580221f56a2667784836f0237955cc75131b88c
--- /dev/null
+++ b/imcui/third_party/DKM/dkm/train/train.py
@@ -0,0 +1,67 @@
+from tqdm import tqdm
+from dkm.utils.utils import to_cuda
+
+
+def train_step(train_batch, model, objective, optimizer, **kwargs):
+ optimizer.zero_grad()
+ out = model(train_batch)
+ l = objective(out, train_batch)
+ l.backward()
+ optimizer.step()
+ return {"train_out": out, "train_loss": l.item()}
+
+
+def train_k_steps(
+ n_0, k, dataloader, model, objective, optimizer, lr_scheduler, progress_bar=True
+):
+ for n in tqdm(range(n_0, n_0 + k), disable=not progress_bar):
+ batch = next(dataloader)
+ model.train(True)
+ batch = to_cuda(batch)
+ train_step(
+ train_batch=batch,
+ model=model,
+ objective=objective,
+ optimizer=optimizer,
+ lr_scheduler=lr_scheduler,
+ n=n,
+ )
+ lr_scheduler.step()
+
+
+def train_epoch(
+ dataloader=None,
+ model=None,
+ objective=None,
+ optimizer=None,
+ lr_scheduler=None,
+ epoch=None,
+):
+ model.train(True)
+ print(f"At epoch {epoch}")
+ for batch in tqdm(dataloader, mininterval=5.0):
+ batch = to_cuda(batch)
+ train_step(
+ train_batch=batch, model=model, objective=objective, optimizer=optimizer
+ )
+ lr_scheduler.step()
+ return {
+ "model": model,
+ "optimizer": optimizer,
+ "lr_scheduler": lr_scheduler,
+ "epoch": epoch,
+ }
+
+
+def train_k_epochs(
+ start_epoch, end_epoch, dataloader, model, objective, optimizer, lr_scheduler
+):
+ for epoch in range(start_epoch, end_epoch + 1):
+ train_epoch(
+ dataloader=dataloader,
+ model=model,
+ objective=objective,
+ optimizer=optimizer,
+ lr_scheduler=lr_scheduler,
+ epoch=epoch,
+ )
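+
+
+# Hypothetical wiring (all names below are placeholders, not defined in this repo):
+#   train_k_epochs(0, 10, loader, model, objective, optimizer, scheduler)
+# where objective(model_output, batch) returns a scalar loss tensor.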
diff --git a/imcui/third_party/DKM/dkm/utils/__init__.py b/imcui/third_party/DKM/dkm/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..05367ac9521664992f587738caa231f32ae2e81c
--- /dev/null
+++ b/imcui/third_party/DKM/dkm/utils/__init__.py
@@ -0,0 +1,13 @@
+from .utils import (
+ pose_auc,
+ get_pose,
+ compute_relative_pose,
+ compute_pose_error,
+ estimate_pose,
+ rotate_intrinsic,
+ get_tuple_transform_ops,
+ get_depth_tuple_transform_ops,
+ warp_kpts,
+ numpy_to_pil,
+ tensor_to_pil,
+)
diff --git a/imcui/third_party/DKM/dkm/utils/kde.py b/imcui/third_party/DKM/dkm/utils/kde.py
new file mode 100644
index 0000000000000000000000000000000000000000..fa392455e70fda4c9c77c28bda76bcb7ef9045b0
--- /dev/null
+++ b/imcui/third_party/DKM/dkm/utils/kde.py
@@ -0,0 +1,26 @@
+import torch
+import torch.nn.functional as F
+import numpy as np
+
+def fast_kde(x, std = 0.1, kernel_size = 9, dilation = 3, padding = 9//2, stride = 1):
+ raise NotImplementedError("WIP, use at your own risk.")
+ # Note: when doing symmetric matching this might not be very exact, since we only check neighbours on the grid
+ x = x.permute(0,3,1,2)
+ B,C,H,W = x.shape
+ K = kernel_size ** 2
+ unfolded_x = F.unfold(x,kernel_size=kernel_size, dilation = dilation, padding = padding, stride = stride).reshape(B, C, K, H, W)
+ scores = (-(unfolded_x - x[:,:,None]).sum(dim=1)**2/(2*std**2)).exp()
+ density = scores.sum(dim=1)
+ return density
+
+
+def kde(x, std = 0.1, device=None):
+ if device is None:
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ if isinstance(x, np.ndarray):
+ x = torch.from_numpy(x)
+ # use a gaussian kernel to estimate density
+ x = x.to(device)
+ scores = (-torch.cdist(x,x)**2/(2*std**2)).exp()
+ density = scores.sum(dim=-1)
+ return density
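+
+# density[i] = sum_j exp(-||x_i - x_j||^2 / (2 * std**2)); with std = 0.1 this roughly
+# counts how many other matches fall within ~0.1 of x_i in the normalized match space.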
diff --git a/imcui/third_party/DKM/dkm/utils/local_correlation.py b/imcui/third_party/DKM/dkm/utils/local_correlation.py
new file mode 100644
index 0000000000000000000000000000000000000000..c0c1c06291d0b760376a2b2162bcf49d6eb1303c
--- /dev/null
+++ b/imcui/third_party/DKM/dkm/utils/local_correlation.py
@@ -0,0 +1,40 @@
+import torch
+import torch.nn.functional as F
+
+
+def local_correlation(
+ feature0,
+ feature1,
+ local_radius,
+ padding_mode="zeros",
+ flow = None
+):
+ device = feature0.device
+ b, c, h, w = feature0.size()
+ if flow is None:
+ # If flow is None, assume feature0 and feature1 are aligned
+ coords = torch.meshgrid(
+ (
+ torch.linspace(-1 + 1 / h, 1 - 1 / h, h, device=device),
+ torch.linspace(-1 + 1 / w, 1 - 1 / w, w, device=device),
+ ))
+ coords = torch.stack((coords[1], coords[0]), dim=-1)[
+ None
+ ].expand(b, h, w, 2)
+ else:
+ coords = flow.permute(0,2,3,1) # If using flow, sample around flow target.
+ r = local_radius
+ local_window = torch.meshgrid(
+ (
+ torch.linspace(-2*local_radius/h, 2*local_radius/h, 2*r+1, device=device),
+ torch.linspace(-2*local_radius/w, 2*local_radius/w, 2*r+1, device=device),
+ ))
+ local_window = torch.stack((local_window[1], local_window[0]), dim=-1)[
+ None
+ ].expand(b, 2*r+1, 2*r+1, 2).reshape(b, (2*r+1)**2, 2)
+ coords = (coords[:,:,:,None]+local_window[:,None,None]).reshape(b,h,w*(2*r+1)**2,2)
+ window_feature = F.grid_sample(
+ feature1, coords, padding_mode=padding_mode, align_corners=False
+ )[...,None].reshape(b,c,h,w,(2*r+1)**2)
+ corr = torch.einsum("bchw, bchwk -> bkhw", feature0, window_feature)/(c**.5)
+ return corr
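+
+# For feature maps of shape (B, C, H, W) and radius r, the result has shape
+# (B, (2*r+1)**2, H, W); channel k holds the dot product (scaled by 1/sqrt(C)) between
+# feature0 at (h, w) and feature1 sampled at the k-th offset around the (optionally
+# flow-warped) target location.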
diff --git a/imcui/third_party/DKM/dkm/utils/transforms.py b/imcui/third_party/DKM/dkm/utils/transforms.py
new file mode 100644
index 0000000000000000000000000000000000000000..754d853fda4cbcf89d2111bed4f44b0ca84f0518
--- /dev/null
+++ b/imcui/third_party/DKM/dkm/utils/transforms.py
@@ -0,0 +1,104 @@
+from typing import Dict
+import numpy as np
+import torch
+import kornia.augmentation as K
+from kornia.geometry.transform import warp_perspective
+
+# Adapted from Kornia
+class GeometricSequential:
+ def __init__(self, *transforms, align_corners=True) -> None:
+ self.transforms = transforms
+ self.align_corners = align_corners
+
+ def __call__(self, x, mode="bilinear"):
+ b, c, h, w = x.shape
+ M = torch.eye(3, device=x.device)[None].expand(b, 3, 3)
+ for t in self.transforms:
+ if np.random.rand() < t.p:
+ M = M.matmul(
+ t.compute_transformation(x, t.generate_parameters((b, c, h, w)))
+ )
+ return (
+ warp_perspective(
+ x, M, dsize=(h, w), mode=mode, align_corners=self.align_corners
+ ),
+ M,
+ )
+
+ def apply_transform(self, x, M, mode="bilinear"):
+ b, c, h, w = x.shape
+ return warp_perspective(
+ x, M, dsize=(h, w), align_corners=self.align_corners, mode=mode
+ )
+
+
+class RandomPerspective(K.RandomPerspective):
+ def generate_parameters(self, batch_shape: torch.Size) -> Dict[str, torch.Tensor]:
+ distortion_scale = torch.as_tensor(
+ self.distortion_scale, device=self._device, dtype=self._dtype
+ )
+ return self.random_perspective_generator(
+ batch_shape[0],
+ batch_shape[-2],
+ batch_shape[-1],
+ distortion_scale,
+ self.same_on_batch,
+ self.device,
+ self.dtype,
+ )
+
+ def random_perspective_generator(
+ self,
+ batch_size: int,
+ height: int,
+ width: int,
+ distortion_scale: torch.Tensor,
+ same_on_batch: bool = False,
+ device: torch.device = torch.device("cpu"),
+ dtype: torch.dtype = torch.float32,
+ ) -> Dict[str, torch.Tensor]:
+ r"""Get parameters for ``perspective`` for a random perspective transform.
+
+ Args:
+ batch_size (int): the tensor batch size.
+ height (int) : height of the image.
+ width (int): width of the image.
+ distortion_scale (torch.Tensor): it controls the degree of distortion and ranges from 0 to 1.
+ same_on_batch (bool): apply the same transformation across the batch. Default: False.
+ device (torch.device): the device on which the random numbers will be generated. Default: cpu.
+ dtype (torch.dtype): the data type of the generated random numbers. Default: float32.
+
+ Returns:
+ params Dict[str, torch.Tensor]: parameters to be passed for transformation.
+ - start_points (torch.Tensor): element-wise perspective source areas with a shape of (B, 4, 2).
+ - end_points (torch.Tensor): element-wise perspective target areas with a shape of (B, 4, 2).
+
+ Note:
+ The generated random numbers are not reproducible across different devices and dtypes.
+ """
+ if not (distortion_scale.dim() == 0 and 0 <= distortion_scale <= 1):
+ raise AssertionError(
+ f"'distortion_scale' must be a scalar within [0, 1]. Got {distortion_scale}."
+ )
+ if not (
+ type(height) is int and height > 0 and type(width) is int and width > 0
+ ):
+ raise AssertionError(
+ f"'height' and 'width' must be integers. Got {height}, {width}."
+ )
+
+ start_points: torch.Tensor = torch.tensor(
+ [[[0.0, 0], [width - 1, 0], [width - 1, height - 1], [0, height - 1]]],
+ device=distortion_scale.device,
+ dtype=distortion_scale.dtype,
+ ).expand(batch_size, -1, -1)
+
+ # generate random offset not larger than half of the image
+ fx = distortion_scale * width / 2
+ fy = distortion_scale * height / 2
+
+ factor = torch.stack([fx, fy], dim=0).view(-1, 1, 2)
+ offset = (torch.rand_like(start_points) - 0.5) * 2
+ end_points = start_points + factor * offset
+
+ return dict(start_points=start_points, end_points=end_points)
diff --git a/imcui/third_party/DKM/dkm/utils/utils.py b/imcui/third_party/DKM/dkm/utils/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..46bbe60260930aed184c6fa5907c837c0177b304
--- /dev/null
+++ b/imcui/third_party/DKM/dkm/utils/utils.py
@@ -0,0 +1,341 @@
+import numpy as np
+import cv2
+import torch
+from torchvision import transforms
+from torchvision.transforms.functional import InterpolationMode
+import torch.nn.functional as F
+from PIL import Image
+
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+# Code taken from https://github.com/PruneTruong/DenseMatching/blob/40c29a6b5c35e86b9509e65ab0cd12553d998e5f/validation/utils_pose_estimation.py
+# --- GEOMETRY ---
+def estimate_pose(kpts0, kpts1, K0, K1, norm_thresh, conf=0.99999):
+ if len(kpts0) < 5:
+ return None
+ K0inv = np.linalg.inv(K0[:2,:2])
+ K1inv = np.linalg.inv(K1[:2,:2])
+
+ kpts0 = (K0inv @ (kpts0-K0[None,:2,2]).T).T
+ kpts1 = (K1inv @ (kpts1-K1[None,:2,2]).T).T
+
+ E, mask = cv2.findEssentialMat(
+ kpts0, kpts1, np.eye(3), threshold=norm_thresh, prob=conf, method=cv2.RANSAC
+ )
+
+ ret = None
+ if E is not None:
+ best_num_inliers = 0
+
+ for _E in np.split(E, len(E) / 3):
+ n, R, t, _ = cv2.recoverPose(_E, kpts0, kpts1, np.eye(3), 1e9, mask=mask)
+ if n > best_num_inliers:
+ best_num_inliers = n
+ ret = (R, t, mask.ravel() > 0)
+ return ret
+
+
+def rotate_intrinsic(K, n):
+ base_rot = np.array([[0, 1, 0], [-1, 0, 0], [0, 0, 1]])
+ rot = np.linalg.matrix_power(base_rot, n)
+ return rot @ K
+
+
+def rotate_pose_inplane(i_T_w, rot):
+ rotation_matrices = [
+ np.array(
+ [
+ [np.cos(r), -np.sin(r), 0.0, 0.0],
+ [np.sin(r), np.cos(r), 0.0, 0.0],
+ [0.0, 0.0, 1.0, 0.0],
+ [0.0, 0.0, 0.0, 1.0],
+ ],
+ dtype=np.float32,
+ )
+ for r in [np.deg2rad(d) for d in (0, 270, 180, 90)]
+ ]
+ return np.dot(rotation_matrices[rot], i_T_w)
+
+
+def scale_intrinsics(K, scales):
+ scales = np.diag([1.0 / scales[0], 1.0 / scales[1], 1.0])
+ return np.dot(scales, K)
+
+
+def to_homogeneous(points):
+ return np.concatenate([points, np.ones_like(points[:, :1])], axis=-1)
+
+
+def angle_error_mat(R1, R2):
+ cos = (np.trace(np.dot(R1.T, R2)) - 1) / 2
+    cos = np.clip(cos, -1.0, 1.0)  # numerical errors can make it out of bounds
+ return np.rad2deg(np.abs(np.arccos(cos)))
+
+
+def angle_error_vec(v1, v2):
+ n = np.linalg.norm(v1) * np.linalg.norm(v2)
+ return np.rad2deg(np.arccos(np.clip(np.dot(v1, v2) / n, -1.0, 1.0)))
+
+
+def compute_pose_error(T_0to1, R, t):
+ R_gt = T_0to1[:3, :3]
+ t_gt = T_0to1[:3, 3]
+ error_t = angle_error_vec(t.squeeze(), t_gt)
+ error_t = np.minimum(error_t, 180 - error_t) # ambiguity of E estimation
+ error_R = angle_error_mat(R, R_gt)
+ return error_t, error_R
+
+
+def pose_auc(errors, thresholds):
+ sort_idx = np.argsort(errors)
+ errors = np.array(errors.copy())[sort_idx]
+ recall = (np.arange(len(errors)) + 1) / len(errors)
+ errors = np.r_[0.0, errors]
+ recall = np.r_[0.0, recall]
+ aucs = []
+ for t in thresholds:
+ last_index = np.searchsorted(errors, t)
+ r = np.r_[recall[:last_index], recall[last_index - 1]]
+ e = np.r_[errors[:last_index], t]
+ aucs.append(np.trapz(r, x=e) / t)
+ return aucs
+
+
+# From Patch2Pix https://github.com/GrumpyZhou/patch2pix
+def get_depth_tuple_transform_ops(resize=None, normalize=True, unscale=False):
+ ops = []
+ if resize:
+ ops.append(TupleResize(resize, mode=InterpolationMode.BILINEAR))
+ return TupleCompose(ops)
+
+
+def get_tuple_transform_ops(resize=None, normalize=True, unscale=False):
+ ops = []
+ if resize:
+ ops.append(TupleResize(resize))
+ if normalize:
+ ops.append(TupleToTensorScaled())
+ ops.append(
+ TupleNormalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+ ) # Imagenet mean/std
+ else:
+ if unscale:
+ ops.append(TupleToTensorUnscaled())
+ else:
+ ops.append(TupleToTensorScaled())
+ return TupleCompose(ops)
+
+
+class ToTensorScaled(object):
+ """Convert a RGB PIL Image to a CHW ordered Tensor, scale the range to [0, 1]"""
+
+ def __call__(self, im):
+ if not isinstance(im, torch.Tensor):
+ im = np.array(im, dtype=np.float32).transpose((2, 0, 1))
+ im /= 255.0
+ return torch.from_numpy(im)
+ else:
+ return im
+
+ def __repr__(self):
+ return "ToTensorScaled(./255)"
+
+
+class TupleToTensorScaled(object):
+ def __init__(self):
+ self.to_tensor = ToTensorScaled()
+
+ def __call__(self, im_tuple):
+ return [self.to_tensor(im) for im in im_tuple]
+
+ def __repr__(self):
+ return "TupleToTensorScaled(./255)"
+
+
+class ToTensorUnscaled(object):
+ """Convert a RGB PIL Image to a CHW ordered Tensor"""
+
+ def __call__(self, im):
+ return torch.from_numpy(np.array(im, dtype=np.float32).transpose((2, 0, 1)))
+
+ def __repr__(self):
+ return "ToTensorUnscaled()"
+
+
+class TupleToTensorUnscaled(object):
+ """Convert a RGB PIL Image to a CHW ordered Tensor"""
+
+ def __init__(self):
+ self.to_tensor = ToTensorUnscaled()
+
+ def __call__(self, im_tuple):
+ return [self.to_tensor(im) for im in im_tuple]
+
+ def __repr__(self):
+ return "TupleToTensorUnscaled()"
+
+
+class TupleResize(object):
+ def __init__(self, size, mode=InterpolationMode.BICUBIC):
+ self.size = size
+ self.resize = transforms.Resize(size, mode)
+
+ def __call__(self, im_tuple):
+ return [self.resize(im) for im in im_tuple]
+
+ def __repr__(self):
+ return "TupleResize(size={})".format(self.size)
+
+
+class TupleNormalize(object):
+ def __init__(self, mean, std):
+ self.mean = mean
+ self.std = std
+ self.normalize = transforms.Normalize(mean=mean, std=std)
+
+ def __call__(self, im_tuple):
+ return [self.normalize(im) for im in im_tuple]
+
+ def __repr__(self):
+ return "TupleNormalize(mean={}, std={})".format(self.mean, self.std)
+
+
+class TupleCompose(object):
+ def __init__(self, transforms):
+ self.transforms = transforms
+
+ def __call__(self, im_tuple):
+ for t in self.transforms:
+ im_tuple = t(im_tuple)
+ return im_tuple
+
+ def __repr__(self):
+ format_string = self.__class__.__name__ + "("
+ for t in self.transforms:
+ format_string += "\n"
+ format_string += " {0}".format(t)
+ format_string += "\n)"
+ return format_string
+
+
+@torch.no_grad()
+def warp_kpts(kpts0, depth0, depth1, T_0to1, K0, K1):
+ """Warp kpts0 from I0 to I1 with depth, K and Rt
+ Also check covisibility and depth consistency.
+    Depth is consistent if the relative error is < 0.05 (hard-coded below).
+    # Adapted from https://github.com/zju3dv/LoFTR/blob/94e98b695be18acb43d5d3250f52226a8e36f839/src/loftr/utils/geometry.py
+    Args:
+        kpts0 (torch.Tensor): [N, L, 2] - <x, y>, normalized to (-1, 1)
+ depth0 (torch.Tensor): [N, H, W],
+ depth1 (torch.Tensor): [N, H, W],
+ T_0to1 (torch.Tensor): [N, 3, 4],
+ K0 (torch.Tensor): [N, 3, 3],
+ K1 (torch.Tensor): [N, 3, 3],
+ Returns:
+ calculable_mask (torch.Tensor): [N, L]
+ warped_keypoints0 (torch.Tensor): [N, L, 2]
+ """
+ (
+ n,
+ h,
+ w,
+ ) = depth0.shape
+ kpts0_depth = F.grid_sample(depth0[:, None], kpts0[:, :, None], mode="bilinear")[
+ :, 0, :, 0
+ ]
+ kpts0 = torch.stack(
+ (w * (kpts0[..., 0] + 1) / 2, h * (kpts0[..., 1] + 1) / 2), dim=-1
+ ) # [-1+1/h, 1-1/h] -> [0.5, h-0.5]
+ # Sample depth, get calculable_mask on depth != 0
+ nonzero_mask = kpts0_depth != 0
+
+ # Unproject
+ kpts0_h = (
+ torch.cat([kpts0, torch.ones_like(kpts0[:, :, [0]])], dim=-1)
+ * kpts0_depth[..., None]
+ ) # (N, L, 3)
+ kpts0_n = K0.inverse() @ kpts0_h.transpose(2, 1) # (N, 3, L)
+ kpts0_cam = kpts0_n
+
+ # Rigid Transform
+ w_kpts0_cam = T_0to1[:, :3, :3] @ kpts0_cam + T_0to1[:, :3, [3]] # (N, 3, L)
+ w_kpts0_depth_computed = w_kpts0_cam[:, 2, :]
+
+ # Project
+ w_kpts0_h = (K1 @ w_kpts0_cam).transpose(2, 1) # (N, L, 3)
+ w_kpts0 = w_kpts0_h[:, :, :2] / (
+ w_kpts0_h[:, :, [2]] + 1e-4
+ ) # (N, L, 2), +1e-4 to avoid zero depth
+
+ # Covisible Check
+ h, w = depth1.shape[1:3]
+ covisible_mask = (
+ (w_kpts0[:, :, 0] > 0)
+ * (w_kpts0[:, :, 0] < w - 1)
+ * (w_kpts0[:, :, 1] > 0)
+ * (w_kpts0[:, :, 1] < h - 1)
+ )
+ w_kpts0 = torch.stack(
+ (2 * w_kpts0[..., 0] / w - 1, 2 * w_kpts0[..., 1] / h - 1), dim=-1
+ ) # from [0.5,h-0.5] -> [-1+1/h, 1-1/h]
+ # w_kpts0[~covisible_mask, :] = -5 # xd
+
+ w_kpts0_depth = F.grid_sample(
+ depth1[:, None], w_kpts0[:, :, None], mode="bilinear"
+ )[:, 0, :, 0]
+ consistent_mask = (
+ (w_kpts0_depth - w_kpts0_depth_computed) / w_kpts0_depth
+ ).abs() < 0.05
+ valid_mask = nonzero_mask * covisible_mask * consistent_mask
+
+ return valid_mask, w_kpts0
+
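+# Illustrative usage sketch for warp_kpts (added comment, not part of the
+# original code; all tensor values below are made-up assumptions):
+#
+#   kpts0 = torch.rand(1, 100, 2) * 2 - 1        # [N, L, 2], normalized to (-1, 1)
+#   depth0 = torch.rand(1, 480, 640) + 0.5       # [N, H, W]
+#   depth1 = torch.rand(1, 480, 640) + 0.5
+#   T_0to1 = torch.eye(4)[None, :3, :]           # [N, 3, 4]
+#   K = torch.eye(3)[None]                       # [N, 3, 3]
+#   valid, w_kpts0 = warp_kpts(kpts0, depth0, depth1, T_0to1, K, K)
+#   # valid: [N, L] boolean-like mask, w_kpts0: [N, L, 2] again in (-1, 1)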
+
+imagenet_mean = torch.tensor([0.485, 0.456, 0.406]).to(device)
+imagenet_std = torch.tensor([0.229, 0.224, 0.225]).to(device)
+
+
+def numpy_to_pil(x: np.ndarray):
+ """
+ Args:
+ x: Assumed to be of shape (h,w,c)
+ """
+ if isinstance(x, torch.Tensor):
+ x = x.detach().cpu().numpy()
+ if x.max() <= 1.01:
+ x *= 255
+ x = x.astype(np.uint8)
+ return Image.fromarray(x)
+
+
+def tensor_to_pil(x, unnormalize=False):
+ if unnormalize:
+ x = x * imagenet_std[:, None, None] + imagenet_mean[:, None, None]
+ x = x.detach().permute(1, 2, 0).cpu().numpy()
+ x = np.clip(x, 0.0, 1.0)
+ return numpy_to_pil(x)
+
+
+def to_cuda(batch):
+ for key, value in batch.items():
+ if isinstance(value, torch.Tensor):
+ batch[key] = value.to(device)
+ return batch
+
+
+def to_cpu(batch):
+ for key, value in batch.items():
+ if isinstance(value, torch.Tensor):
+ batch[key] = value.cpu()
+ return batch
+
+
+def get_pose(calib):
+ w, h = np.array(calib["imsize"])[0]
+ return np.array(calib["K"]), np.array(calib["R"]), np.array(calib["T"]).T, h, w
+
+
+def compute_relative_pose(R1, t1, R2, t2):
+ rots = R2 @ (R1.T)
+ trans = -rots @ t1 + t2
+ return rots, trans
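+# Added comment: relation implemented by compute_relative_pose. A world point X
+# maps to camera i as x_i = R_i @ X + t_i, so the pose taking camera-1
+# coordinates into camera-2 coordinates is
+#   R_12 = R2 @ R1.T
+#   t_12 = t2 - R_12 @ t1
+# which is exactly the (rots, trans) pair returned above.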
diff --git a/imcui/third_party/DKM/setup.py b/imcui/third_party/DKM/setup.py
new file mode 100644
index 0000000000000000000000000000000000000000..73ae664126066249200d72c1ec3166f4f2d76b10
--- /dev/null
+++ b/imcui/third_party/DKM/setup.py
@@ -0,0 +1,9 @@
+from setuptools import setup, find_packages
+
+setup(
+ name="dkm",
+ packages=find_packages(include=("dkm*",)),
+ version="0.3.0",
+ author="Johan Edstedt",
+ install_requires=open("requirements.txt", "r").read().split("\n"),
+)
diff --git a/imcui/third_party/DarkFeat/configs/config.yaml b/imcui/third_party/DarkFeat/configs/config.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..7ffead73fc3eac520aa7aa4bf3811c5069a4c149
--- /dev/null
+++ b/imcui/third_party/DarkFeat/configs/config.yaml
@@ -0,0 +1,24 @@
+training:
+ optimizer: 'SGD'
+ lr: 0.01
+ momentum: 0.9
+ weight_decay: 0.0001
+ lr_gamma: 0.1
+ lr_step: 200000
+network:
+ input_type: 'raw-demosaic'
+ noise: true
+ noise_maxstep: 1
+ model: 'Quad_L2Net'
+ loss_type: 'HARD_CONTRASTIVE'
+ photaug: true
+ resize: 480
+ use_corr_n: 512
+ det:
+ corr_weight: true
+ safe_radius: 12
+ kpt_n: 512
+ score_thld: -1
+ edge_thld: 10
+ nms_size: 3
+ eof_size: 5
\ No newline at end of file
diff --git a/imcui/third_party/DarkFeat/configs/config_stage1.yaml b/imcui/third_party/DarkFeat/configs/config_stage1.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..f94e1da377bf8f507d6fa6db394b1016227d0e25
--- /dev/null
+++ b/imcui/third_party/DarkFeat/configs/config_stage1.yaml
@@ -0,0 +1,24 @@
+training:
+ optimizer: 'SGD'
+ lr: 0.1
+ momentum: 0.9
+ weight_decay: 0.0001
+ lr_gamma: 0.1
+ lr_step: 200000
+network:
+ input_type: 'raw-demosaic'
+ noise: true
+ noise_maxstep: 1
+ model: 'Quad_L2Net'
+ loss_type: 'HARD_CONTRASTIVE'
+ photaug: true
+ resize: 480
+ use_corr_n: 512
+ det:
+ corr_weight: true
+ safe_radius: 12
+ kpt_n: 512
+ score_thld: -1
+ edge_thld: 10
+ nms_size: 3
+ eof_size: 5
\ No newline at end of file
diff --git a/imcui/third_party/DarkFeat/darkfeat.py b/imcui/third_party/DarkFeat/darkfeat.py
new file mode 100644
index 0000000000000000000000000000000000000000..d146e2b41f5399ff3fc2f52ec5daff1c56e491c0
--- /dev/null
+++ b/imcui/third_party/DarkFeat/darkfeat.py
@@ -0,0 +1,359 @@
+import torch
+from torch import nn
+from torch.nn.parameter import Parameter
+import torchvision.transforms as tvf
+import torch.nn.functional as F
+import numpy as np
+
+
+def gather_nd(params, indices):
+ orig_shape = list(indices.shape)
+ num_samples = np.prod(orig_shape[:-1])
+ m = orig_shape[-1]
+ n = len(params.shape)
+
+ if m <= n:
+ out_shape = orig_shape[:-1] + list(params.shape)[m:]
+ else:
+ raise ValueError(
+ f'the last dimension of indices must be less than or equal to the rank of params. Got indices:{indices.shape}, params:{params.shape}. {m} > {n}'
+ )
+
+ indices = indices.reshape((num_samples, m)).transpose(0, 1).tolist()
+ output = params[indices] # (num_samples, ...)
+ return output.reshape(out_shape).contiguous()
+
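+# Added comment-only sketch: gather_nd mirrors tf.gather_nd, indexing `params`
+# with the last axis of `indices`. For example:
+#
+#   params = torch.arange(12.).reshape(3, 4)
+#   idx = torch.tensor([[0, 1], [2, 3]])         # two (row, col) pairs
+#   gather_nd(params, idx)                       # -> tensor([ 1., 11.])
+#
+# When m < n the trailing dimensions of params are kept, e.g. indexing an
+# [H, W, 128] feature map with [K, 2] indices returns a [K, 128] tensor.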
+
+# input: pos [kpt_n, 2]; inputs [H, W, 128] / [H, W]
+# output: [kpt_n, 128] / [kpt_n]
+def interpolate(pos, inputs, nd=True):
+ h = inputs.shape[0]
+ w = inputs.shape[1]
+
+ i = pos[:, 0]
+ j = pos[:, 1]
+
+ i_top_left = torch.clamp(torch.floor(i).int(), 0, h - 1)
+ j_top_left = torch.clamp(torch.floor(j).int(), 0, w - 1)
+
+ i_top_right = torch.clamp(torch.floor(i).int(), 0, h - 1)
+ j_top_right = torch.clamp(torch.ceil(j).int(), 0, w - 1)
+
+ i_bottom_left = torch.clamp(torch.ceil(i).int(), 0, h - 1)
+ j_bottom_left = torch.clamp(torch.floor(j).int(), 0, w - 1)
+
+ i_bottom_right = torch.clamp(torch.ceil(i).int(), 0, h - 1)
+ j_bottom_right = torch.clamp(torch.ceil(j).int(), 0, w - 1)
+
+ dist_i_top_left = i - i_top_left.float()
+ dist_j_top_left = j - j_top_left.float()
+ w_top_left = (1 - dist_i_top_left) * (1 - dist_j_top_left)
+ w_top_right = (1 - dist_i_top_left) * dist_j_top_left
+ w_bottom_left = dist_i_top_left * (1 - dist_j_top_left)
+ w_bottom_right = dist_i_top_left * dist_j_top_left
+
+ if nd:
+ w_top_left = w_top_left[..., None]
+ w_top_right = w_top_right[..., None]
+ w_bottom_left = w_bottom_left[..., None]
+ w_bottom_right = w_bottom_right[..., None]
+
+ interpolated_val = (
+ w_top_left * gather_nd(inputs, torch.stack([i_top_left, j_top_left], axis=-1)) +
+ w_top_right * gather_nd(inputs, torch.stack([i_top_right, j_top_right], axis=-1)) +
+ w_bottom_left * gather_nd(inputs, torch.stack([i_bottom_left, j_bottom_left], axis=-1)) +
+ w_bottom_right *
+ gather_nd(inputs, torch.stack([i_bottom_right, j_bottom_right], axis=-1))
+ )
+
+ return interpolated_val
+
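+# Added usage sketch for interpolate (values are illustrative assumptions):
+# `pos` holds sub-pixel (row, col) locations and `inputs` is an [H, W, C]
+# feature map; the result is the bilinearly interpolated feature per location.
+#
+#   feat = torch.randn(60, 80, 128)                  # [H, W, C]
+#   pos = torch.tensor([[10.3, 20.7], [5.0, 5.0]])   # [kpt_n, 2] as (i, j)
+#   desc = interpolate(pos, feat)                    # -> [2, 128]
+#   score = interpolate(pos, feat[..., 0], nd=False) # -> [2]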
+
+def edge_mask(inputs, n_channel, dilation=1, edge_thld=5):
+ b, c, h, w = inputs.size()
+ device = inputs.device
+
+ dii_filter = torch.tensor(
+ [[0, 1., 0], [0, -2., 0], [0, 1., 0]]
+ ).view(1, 1, 3, 3)
+ dij_filter = 0.25 * torch.tensor(
+ [[1., 0, -1.], [0, 0., 0], [-1., 0, 1.]]
+ ).view(1, 1, 3, 3)
+ djj_filter = torch.tensor(
+ [[0, 0, 0], [1., -2., 1.], [0, 0, 0]]
+ ).view(1, 1, 3, 3)
+
+ dii = F.conv2d(
+ inputs.view(-1, 1, h, w), dii_filter.to(device), padding=dilation, dilation=dilation
+ ).view(b, c, h, w)
+ dij = F.conv2d(
+ inputs.view(-1, 1, h, w), dij_filter.to(device), padding=dilation, dilation=dilation
+ ).view(b, c, h, w)
+ djj = F.conv2d(
+ inputs.view(-1, 1, h, w), djj_filter.to(device), padding=dilation, dilation=dilation
+ ).view(b, c, h, w)
+
+ det = dii * djj - dij * dij
+ tr = dii + djj
+ del dii, dij, djj
+
+ threshold = (edge_thld + 1) ** 2 / edge_thld
+ is_not_edge = torch.min(tr * tr / det <= threshold, det > 0)
+
+ return is_not_edge
+
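+# Added note: edge_mask applies the SIFT-style edge rejection on the Hessian of
+# the score map. With eigenvalues a and b of the 2x2 Hessian,
+#   tr^2 / det = (a + b)^2 / (a * b)
+# grows with the ratio r = a / b, so tr^2 / det <= (edge_thld + 1)^2 / edge_thld
+# bounds that ratio by edge_thld, while det > 0 rejects saddle points.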
+
+# input: score_map [batch_size, 1, H, W]
+# output: indices [2, k, 2], scores [2, k]
+def extract_kpts(score_map, k=256, score_thld=0, edge_thld=0, nms_size=3, eof_size=5):
+ h = score_map.shape[2]
+ w = score_map.shape[3]
+
+ mask = score_map > score_thld
+ if nms_size > 0:
+ nms_mask = F.max_pool2d(score_map, kernel_size=nms_size, stride=1, padding=nms_size//2)
+ nms_mask = torch.eq(score_map, nms_mask)
+ mask = torch.logical_and(nms_mask, mask)
+ if eof_size > 0:
+ eof_mask = torch.ones((1, 1, h - 2 * eof_size, w - 2 * eof_size), dtype=torch.float32, device=score_map.device)
+ eof_mask = F.pad(eof_mask, [eof_size] * 4, value=0)
+ eof_mask = eof_mask.bool()
+ mask = torch.logical_and(eof_mask, mask)
+ if edge_thld > 0:
+ non_edge_mask = edge_mask(score_map, 1, dilation=3, edge_thld=edge_thld)
+ mask = torch.logical_and(non_edge_mask, mask)
+
+ bs = score_map.shape[0]
+ if bs is None: # never true for torch tensors, so the batched branch below is always taken
+ indices = torch.nonzero(mask)[0]
+ scores = gather_nd(score_map, indices)[0]
+ sample = torch.sort(scores, descending=True)[1][0:k]
+ indices = indices[sample].unsqueeze(0)
+ scores = scores[sample].unsqueeze(0)
+ else:
+ indices = []
+ scores = []
+ for i in range(bs):
+ tmp_mask = mask[i][0]
+ tmp_score_map = score_map[i][0]
+ tmp_indices = torch.nonzero(tmp_mask)
+ tmp_scores = gather_nd(tmp_score_map, tmp_indices)
+ tmp_sample = torch.sort(tmp_scores, descending=True)[1][0:k]
+ tmp_indices = tmp_indices[tmp_sample]
+ tmp_scores = tmp_scores[tmp_sample]
+ indices.append(tmp_indices)
+ scores.append(tmp_scores)
+ try:
+ indices = torch.stack(indices, dim=0)
+ scores = torch.stack(scores, dim=0)
+ except RuntimeError:
+ # samples yield different keypoint counts; truncate each to the smallest count
+ min_num = np.min([len(i) for i in indices])
+ indices = torch.stack([i[:min_num] for i in indices], dim=0)
+ scores = torch.stack([i[:min_num] for i in scores], dim=0)
+ return indices, scores
+
+
+# input: [batch_size, C, H, W]
+# output: [batch_size, C, H, W], [batch_size, C, H, W]
+def peakiness_score(inputs, moving_instance_max, ksize=3, dilation=1):
+ inputs = inputs / moving_instance_max
+
+ batch_size, C, H, W = inputs.shape
+
+ pad_size = ksize // 2 + (dilation - 1)
+ kernel = torch.ones([C, 1, ksize, ksize], device=inputs.device) / (ksize * ksize)
+
+ pad_inputs = F.pad(inputs, [pad_size] * 4, mode='reflect')
+
+ avg_spatial_inputs = F.conv2d(
+ pad_inputs,
+ kernel,
+ stride=1,
+ dilation=dilation,
+ padding=0,
+ groups=C
+ )
+ avg_channel_inputs = torch.mean(inputs, axis=1, keepdim=True) # channel dimension is 1
+ # print(avg_spatial_inputs.shape)
+
+ alpha = F.softplus(inputs - avg_spatial_inputs)
+ beta = F.softplus(inputs - avg_channel_inputs)
+
+ return alpha, beta
+
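+# Added comment sketch (not from the original repo): peakiness_score compares
+# each activation against its local spatial average (alpha) and its channel
+# average (beta); their product is used as the detection score in forward().
+#
+#   feat = torch.randn(1, 128, 60, 80)
+#   alpha, beta = peakiness_score(feat, moving_instance_max=torch.tensor(1.0))
+#   score_vol = alpha * beta                            # [1, 128, 60, 80]
+#   score_map = score_vol.max(dim=1, keepdim=True)[0]   # [1, 1, 60, 80]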
+
+class DarkFeat(nn.Module):
+ default_config = {
+ 'model_path': '',
+ 'input_type': 'raw-demosaic',
+ 'kpt_n': 5000,
+ 'kpt_refinement': True,
+ 'score_thld': 0.5,
+ 'edge_thld': 10,
+ 'multi_scale': False,
+ 'multi_level': True,
+ 'nms_size': 3,
+ 'eof_size': 5,
+ 'need_norm': True,
+ 'use_peakiness': True
+ }
+
+ def __init__(self, model_path='', inchan=3, dilated=True, dilation=1, bn=True, bn_affine=False):
+ super(DarkFeat, self).__init__()
+ inchan = 3 if self.default_config['input_type'] == 'rgb' or self.default_config['input_type'] == 'raw-demosaic' else 1
+ self.config = {**self.default_config}
+
+ self.inchan = inchan
+ self.curchan = inchan
+ self.dilated = dilated
+ self.dilation = dilation
+ self.bn = bn
+ self.bn_affine = bn_affine
+ self.config['model_path'] = model_path
+
+ dim = 128
+ mchan = 4
+
+ self.conv0 = self._add_conv( 8*mchan)
+ self.conv1 = self._add_conv( 8*mchan, bn=False)
+ self.bn1 = self._make_bn(8*mchan)
+ self.conv2 = self._add_conv( 16*mchan, stride=2)
+ self.conv3 = self._add_conv( 16*mchan, bn=False)
+ self.bn3 = self._make_bn(16*mchan)
+ self.conv4 = self._add_conv( 32*mchan, stride=2)
+ self.conv5 = self._add_conv( 32*mchan)
+ # replace last 8x8 convolution with 3 3x3 convolutions
+ self.conv6_0 = self._add_conv( 32*mchan)
+ self.conv6_1 = self._add_conv( 32*mchan)
+ self.conv6_2 = self._add_conv(dim, bn=False, relu=False)
+ self.out_dim = dim
+
+ self.moving_avg_params = nn.ParameterList([
+ Parameter(torch.tensor(1.), requires_grad=False),
+ Parameter(torch.tensor(1.), requires_grad=False),
+ Parameter(torch.tensor(1.), requires_grad=False)
+ ])
+ self.clf = nn.Conv2d(128, 2, kernel_size=1)
+
+ state_dict = torch.load(self.config["model_path"], map_location="cpu")
+ new_state_dict = {}
+
+ for key in state_dict:
+ if 'running_mean' not in key and 'running_var' not in key and 'num_batches_tracked' not in key:
+ new_state_dict[key] = state_dict[key]
+
+ self.load_state_dict(new_state_dict)
+ print('Loaded DarkFeat model')
+
+ def _make_bn(self, outd):
+ return nn.BatchNorm2d(outd, affine=self.bn_affine, track_running_stats=False)
+
+ def _add_conv(self, outd, k=3, stride=1, dilation=1, bn=True, relu=True, k_pool = 1, pool_type='max', bias=False):
+ d = self.dilation * dilation
+ conv_params = dict(padding=((k-1)*d)//2, dilation=d, stride=stride, bias=bias)
+
+ ops = nn.ModuleList([])
+
+ ops.append( nn.Conv2d(self.curchan, outd, kernel_size=k, **conv_params) )
+ if bn and self.bn: ops.append( self._make_bn(outd) )
+ if relu: ops.append( nn.ReLU(inplace=True) )
+ self.curchan = outd
+
+ if k_pool > 1:
+ if pool_type == 'avg':
+ ops.append(torch.nn.AvgPool2d(kernel_size=k_pool))
+ elif pool_type == 'max':
+ ops.append(torch.nn.MaxPool2d(kernel_size=k_pool))
+ else:
+ print(f"Error, unknown pooling type {pool_type}...")
+
+ return nn.Sequential(*ops)
+
+ def forward(self, input):
+ """ Compute keypoints, scores, descriptors for image """
+ data = input['image']
+ H, W = data.shape[2:]
+
+ if self.config['input_type'] == 'rgb':
+ # 3-channel rgb
+ RGB_mean = [0.485, 0.456, 0.406]
+ RGB_std = [0.229, 0.224, 0.225]
+ norm_RGB = tvf.Normalize(mean=RGB_mean, std=RGB_std)
+ data = norm_RGB(data)
+
+ elif self.config['input_type'] == 'gray':
+ # 1-channel
+ data = torch.mean(data, dim=1, keepdim=True)
+ norm_gray0 = tvf.Normalize(mean=data.mean(), std=data.std())
+ data = norm_gray0(data)
+
+ elif self.config['input_type'] == 'raw':
+ # 4-channel
+ pass
+ elif self.config['input_type'] == 'raw-demosaic':
+ # 3-channel
+ pass
+ else:
+ raise NotImplementedError()
+
+ # x: [N, C, H, W]
+ x0 = self.conv0(data)
+ x1 = self.conv1(x0)
+ x1_bn = self.bn1(x1)
+ x2 = self.conv2(x1_bn)
+ x3 = self.conv3(x2)
+ x3_bn = self.bn3(x3)
+ x4 = self.conv4(x3_bn)
+ x5 = self.conv5(x4)
+ x6_0 = self.conv6_0(x5)
+ x6_1 = self.conv6_1(x6_0)
+ x6_2 = self.conv6_2(x6_1)
+
+ comb_weights = torch.tensor([1., 2., 3.], device=data.device)
+ comb_weights /= torch.sum(comb_weights)
+ ksize = [3, 2, 1] # despite the name, these are the per-level dilations passed to peakiness_score
+ det_score_maps = []
+
+ for idx, xx in enumerate([x1, x3, x6_2]):
+ alpha, beta = peakiness_score(xx, self.moving_avg_params[idx].detach(), ksize=3, dilation=ksize[idx])
+ score_vol = alpha * beta
+ det_score_map = torch.max(score_vol, dim=1, keepdim=True)[0]
+ det_score_map = F.interpolate(det_score_map, size=data.shape[2:], mode='bilinear', align_corners=True)
+ det_score_map = comb_weights[idx] * det_score_map
+ det_score_maps.append(det_score_map)
+
+ det_score_map = torch.sum(torch.stack(det_score_maps, dim=0), dim=0)
+
+ desc = x6_2
+ score_map = det_score_map
+ conf = F.softmax(self.clf((desc)**2), dim=1)[:,1:2]
+ score_map = score_map * F.interpolate(conf, size=score_map.shape[2:], mode='bilinear', align_corners=True)
+
+ kpt_inds, kpt_score = extract_kpts(
+ score_map,
+ k=self.config['kpt_n'],
+ score_thld=self.config['score_thld'],
+ nms_size=self.config['nms_size'],
+ eof_size=self.config['eof_size'],
+ edge_thld=self.config['edge_thld']
+ )
+
+ descs = F.normalize(
+ interpolate(kpt_inds.squeeze(0) / 4, desc.squeeze(0).permute(1, 2, 0)),
+ p=2,
+ dim=-1
+ ).detach().cpu().numpy()
+ kpts = np.squeeze(torch.stack([kpt_inds[:, :, 1], kpt_inds[:, :, 0]], dim=-1).cpu(), axis=0) \
+ * np.array([W / data.shape[3], H / data.shape[2]], dtype=np.float32)
+ scores = np.squeeze(kpt_score.detach().cpu().numpy(), axis=0)
+
+ idxs = np.negative(scores).argsort()[0:self.config['kpt_n']]
+ descs = descs[idxs]
+ kpts = kpts[idxs]
+ scores = scores[idxs]
+
+ return {
+ 'keypoints': kpts,
+ 'scores': torch.from_numpy(scores),
+ 'descriptors': torch.from_numpy(descs.T),
+ }
diff --git a/imcui/third_party/DarkFeat/datasets/InvISP/__init__.py b/imcui/third_party/DarkFeat/datasets/InvISP/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/imcui/third_party/DarkFeat/datasets/InvISP/cal_metrics.py b/imcui/third_party/DarkFeat/datasets/InvISP/cal_metrics.py
new file mode 100644
index 0000000000000000000000000000000000000000..cc3e501664487de4c08ab8c89328dd266fba2868
--- /dev/null
+++ b/imcui/third_party/DarkFeat/datasets/InvISP/cal_metrics.py
@@ -0,0 +1,114 @@
+import cv2
+import numpy as np
+import math
+# from skimage.metrics import structural_similarity as ssim
+from skimage.measure import compare_ssim
+from scipy.misc import imread
+from glob import glob
+
+import argparse
+
+parser = argparse.ArgumentParser(description="evaluation codes")
+
+parser.add_argument("--path", type=str, help="Path to evaluate images.")
+
+args = parser.parse_args()
+
+def psnr(img1, img2):
+ mse = np.mean( (img1/255. - img2/255.) ** 2 )
+ if mse < 1.0e-10:
+ return 100
+ PIXEL_MAX = 1
+ return 20 * math.log10(PIXEL_MAX / math.sqrt(mse))
+
+def psnr_raw(img1, img2):
+ mse = np.mean( (img1 - img2) ** 2 )
+ if mse < 1.0e-10:
+ return 100
+ PIXEL_MAX = 1
+ return 20 * math.log10(PIXEL_MAX / math.sqrt(mse))
+
+
+def my_ssim(img1, img2):
+ return compare_ssim(img1, img2, data_range=img1.max() - img1.min(), multichannel=True)
+
+
+def quan_eval(path, suffix="jpg"):
+ # path: /disk2/yazhou/projects/IISP/exps/test_final_unet_globalEDV2/
+ # ours
+ gt_imgs = sorted(glob(path+"tar*.%s"%suffix))
+ pred_imgs = sorted(glob(path+"pred*.%s"%suffix))
+
+ # with open(split_path + "test_gt.txt", 'r') as f_gt, open(split_path+"test_rgb.txt","r") as f_rgb:
+ # gt_imgs = [line.rstrip() for line in f_gt.readlines()]
+ # pred_imgs = [line.rstrip() for line in f_rgb.readlines()]
+
+ assert len(gt_imgs) == len(pred_imgs)
+
+ psnr_avg = 0.
+ ssim_avg = 0.
+ for i in range(len(gt_imgs)):
+ gt = imread(gt_imgs[i])
+ pred = imread(pred_imgs[i])
+ psnr_temp = psnr(gt, pred)
+ psnr_avg += psnr_temp
+ ssim_temp = my_ssim(gt, pred)
+ ssim_avg += ssim_temp
+
+ print("psnr: ", psnr_temp)
+ print("ssim: ", ssim_temp)
+
+ psnr_avg /= float(len(gt_imgs))
+ ssim_avg /= float(len(gt_imgs))
+
+ print("psnr_avg: ", psnr_avg)
+ print("ssim_avg: ", ssim_avg)
+
+ return psnr_avg, ssim_avg
+
+def mse(gt, pred):
+ return np.mean((gt-pred)**2)
+
+def mse_raw(path, suffix="npy"):
+ gt_imgs = sorted(glob(path+"raw_tar*.%s"%suffix))
+ pred_imgs = sorted(glob(path+"raw_pred*.%s"%suffix))
+
+ # with open(split_path + "test_gt.txt", 'r') as f_gt, open(split_path+"test_rgb.txt","r") as f_rgb:
+ # gt_imgs = [line.rstrip() for line in f_gt.readlines()]
+ # pred_imgs = [line.rstrip() for line in f_rgb.readlines()]
+
+ assert len(gt_imgs) == len(pred_imgs)
+
+ mse_avg = 0.
+ psnr_avg = 0.
+ for i in range(len(gt_imgs)):
+ gt = np.load(gt_imgs[i])
+ pred = np.load(pred_imgs[i])
+ mse_temp = mse(gt, pred)
+ mse_avg += mse_temp
+ psnr_temp = psnr_raw(gt, pred)
+ psnr_avg += psnr_temp
+
+ print("mse: ", mse_temp)
+ print("psnr: ", psnr_temp)
+
+ mse_avg /= float(len(gt_imgs))
+ psnr_avg /= float(len(gt_imgs))
+
+ print("mse_avg: ", mse_avg)
+ print("psnr_avg: ", psnr_avg)
+
+ return mse_avg, psnr_avg
+
+test_full = False
+
+# if test_full:
+# psnr_avg, ssim_avg = quan_eval(ROOT_PATH+"%s/vis_%s_full/"%(args.task, args.ckpt), "jpeg")
+# mse_avg, psnr_avg_raw = mse_raw(ROOT_PATH+"%s/vis_%s_full/"%(args.task, args.ckpt))
+# else:
+psnr_avg, ssim_avg = quan_eval(args.path, "jpg")
+mse_avg, psnr_avg_raw = mse_raw(args.path)
+
+print("pnsr: {}, ssim: {}, mse: {}, psnr raw: {}".format(psnr_avg, ssim_avg, mse_avg, psnr_avg_raw))
+
+
diff --git a/imcui/third_party/DarkFeat/datasets/InvISP/config/config.py b/imcui/third_party/DarkFeat/datasets/InvISP/config/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..dc42182ecf7464cc85ed5c77b7aeb9ee4e3ecd74
--- /dev/null
+++ b/imcui/third_party/DarkFeat/datasets/InvISP/config/config.py
@@ -0,0 +1,21 @@
+import argparse
+
+BATCH_SIZE = 1
+
+DATA_PATH = "./data/"
+
+
+
+def get_arguments():
+ parser = argparse.ArgumentParser(description="training codes")
+
+ parser.add_argument("--task", type=str, help="Name of this training")
+ parser.add_argument("--data_path", type=str, default=DATA_PATH, help="Dataset root path.")
+ parser.add_argument("--batch_size", type=int, default=BATCH_SIZE, help="Batch size for training. ")
+ parser.add_argument("--debug_mode", dest='debug_mode', action='store_true', help="If debug mode, load less data.")
+ parser.add_argument("--gamma", dest='gamma', action='store_true', help="Use gamma compression for raw data.")
+ parser.add_argument("--camera", type=str, default="NIKON_D700", choices=["NIKON_D700", "Canon_EOS_5D"], help="Choose which camera to use. ")
+ parser.add_argument("--rgb_weight", type=float, default=1, help="Weight for rgb loss. ")
+
+
+ return parser
diff --git a/imcui/third_party/DarkFeat/datasets/InvISP/data/data_preprocess.py b/imcui/third_party/DarkFeat/datasets/InvISP/data/data_preprocess.py
new file mode 100644
index 0000000000000000000000000000000000000000..62271771a17a4863b730136d49f2a23aed0e49b2
--- /dev/null
+++ b/imcui/third_party/DarkFeat/datasets/InvISP/data/data_preprocess.py
@@ -0,0 +1,56 @@
+import rawpy
+import numpy as np
+import glob, os
+import colour_demosaicing
+import imageio
+import argparse
+from PIL import Image as PILImage
+import scipy.io as scio
+
+parser = argparse.ArgumentParser(description="data preprocess")
+
+parser.add_argument("--camera", type=str, default="NIKON_D700", help="Camera Name")
+parser.add_argument("--Bayer_Pattern", type=str, default="RGGB", help="Bayer Pattern of RAW")
+parser.add_argument("--JPEG_Quality", type=int, default=90, help="Jpeg Quality of the ground truth.")
+
+args = parser.parse_args()
+camera_name = args.camera
+Bayer_Pattern = args.Bayer_Pattern
+JPEG_Quality = args.JPEG_Quality
+
+dng_path = sorted(glob.glob('/mnt/nvme2n1/hyz/data/' + camera_name + '/DNG/*.cr2'))
+rgb_target_path = '/mnt/nvme2n1/hyz/data/'+ camera_name + '/RGB/'
+raw_input_path = '/mnt/nvme2n1/hyz/data/' + camera_name + '/RAW/'
+if not os.path.isdir(rgb_target_path):
+ os.mkdir(rgb_target_path)
+if not os.path.isdir(raw_input_path):
+ os.mkdir(raw_input_path)
+
+def flip(raw_img, flip):
+ if flip == 3:
+ raw_img = np.rot90(raw_img, k=2)
+ elif flip == 5:
+ raw_img = np.rot90(raw_img, k=1)
+ elif flip == 6:
+ raw_img = np.rot90(raw_img, k=3)
+ else:
+ pass
+ return raw_img
+
+
+
+for path in dng_path:
+ print("Start Processing %s" % os.path.basename(path))
+ raw = rawpy.imread(path)
+ file_name = path.split('/')[-1].split('.')[0]
+ im = raw.postprocess(use_camera_wb=True,no_auto_bright=True)
+ flip_val = raw.sizes.flip
+ cwb = raw.camera_whitebalance
+ raw_img = raw.raw_image_visible
+ if camera_name == 'Canon_EOS_5D':
+ raw_img = np.maximum(raw_img - 127.0, 0)
+ de_raw = colour_demosaicing.demosaicing_CFA_Bayer_bilinear(raw_img, Bayer_Pattern)
+ de_raw = flip(de_raw, flip_val)
+ PILImage.fromarray(im).save(rgb_target_path + file_name + '.jpg', quality = JPEG_Quality, subsampling = 1)
+ np.savez(raw_input_path + file_name + '.npz', raw=de_raw, wb=cwb)
+
diff --git a/imcui/third_party/DarkFeat/datasets/InvISP/dataset/FiveK_dataset.py b/imcui/third_party/DarkFeat/datasets/InvISP/dataset/FiveK_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..4c71bd3b4162bd21761983deef6b94fa46a364f6
--- /dev/null
+++ b/imcui/third_party/DarkFeat/datasets/InvISP/dataset/FiveK_dataset.py
@@ -0,0 +1,132 @@
+from __future__ import print_function, division
+import os, random, time
+import torch
+import numpy as np
+from torch.utils.data import Dataset
+from torchvision import transforms, utils
+import rawpy
+from glob import glob
+from PIL import Image as PILImage
+import numbers
+from scipy.misc import imread
+from .base_dataset import BaseDataset
+
+
+class FiveKDatasetTrain(BaseDataset):
+ def __init__(self, opt):
+ super().__init__(opt=opt)
+ self.patch_size = 256
+ input_RAWs_WBs, target_RGBs = self.load(is_train=True)
+ assert len(input_RAWs_WBs) == len(target_RGBs)
+ self.data = {'input_RAWs_WBs':input_RAWs_WBs, 'target_RGBs':target_RGBs}
+
+ def random_flip(self, input_raw, target_rgb):
+ idx = np.random.randint(2)
+ input_raw = np.flip(input_raw,axis=idx).copy()
+ target_rgb = np.flip(target_rgb,axis=idx).copy()
+
+ return input_raw, target_rgb
+
+ def random_rotate(self, input_raw, target_rgb):
+ idx = np.random.randint(4)
+ input_raw = np.rot90(input_raw,k=idx)
+ target_rgb = np.rot90(target_rgb,k=idx)
+
+ return input_raw, target_rgb
+
+ def random_crop(self, patch_size, input_raw, target_rgb,flow=False,demos=False):
+ H, W, _ = input_raw.shape
+ rnd_h = random.randint(0, max(0, H - patch_size))
+ rnd_w = random.randint(0, max(0, W - patch_size))
+
+ patch_input_raw = input_raw[rnd_h:rnd_h + patch_size, rnd_w:rnd_w + patch_size, :]
+ if flow or demos:
+ patch_target_rgb = target_rgb[rnd_h:rnd_h + patch_size, rnd_w:rnd_w + patch_size, :]
+ else:
+ patch_target_rgb = target_rgb[rnd_h*2:rnd_h*2 + patch_size*2, rnd_w*2:rnd_w*2 + patch_size*2, :]
+
+ return patch_input_raw, patch_target_rgb
+
+ def aug(self, patch_size, input_raw, target_rgb, flow=False, demos=False):
+ input_raw, target_rgb = self.random_crop(patch_size, input_raw,target_rgb,flow=flow, demos=demos)
+ input_raw, target_rgb = self.random_rotate(input_raw,target_rgb)
+ input_raw, target_rgb = self.random_flip(input_raw,target_rgb)
+
+ return input_raw, target_rgb
+
+ def __len__(self):
+ return len(self.data['input_RAWs_WBs'])
+
+ def __getitem__(self, idx):
+ input_raw_wb_path = self.data['input_RAWs_WBs'][idx]
+ target_rgb_path = self.data['target_RGBs'][idx]
+
+ target_rgb_img = imread(target_rgb_path)
+ input_raw_wb = np.load(input_raw_wb_path)
+ input_raw_img = input_raw_wb['raw']
+ wb = input_raw_wb['wb']
+ wb = wb / wb.max()
+ input_raw_img = input_raw_img * wb[:-1]
+
+ self.patch_size = 256
+ input_raw_img, target_rgb_img = self.aug(self.patch_size, input_raw_img, target_rgb_img, flow=True, demos=True)
+
+ if self.gamma:
+ norm_value = np.power(4095, 1/2.2) if self.camera_name=='Canon_EOS_5D' else np.power(16383, 1/2.2)
+ input_raw_img = np.power(input_raw_img, 1/2.2)
+ else:
+ norm_value = 4095 if self.camera_name=='Canon_EOS_5D' else 16383
+
+ target_rgb_img = self.norm_img(target_rgb_img, max_value=255)
+ input_raw_img = self.norm_img(input_raw_img, max_value=norm_value)
+ target_raw_img = input_raw_img.copy()
+
+ input_raw_img = self.np2tensor(input_raw_img).float()
+ target_rgb_img = self.np2tensor(target_rgb_img).float()
+ target_raw_img = self.np2tensor(target_raw_img).float()
+
+ sample = {'input_raw':input_raw_img, 'target_rgb':target_rgb_img, 'target_raw':target_raw_img,
+ 'file_name':input_raw_wb_path.split("/")[-1].split(".")[0]}
+ return sample
+
+class FiveKDatasetTest(BaseDataset):
+ def __init__(self, opt):
+ super().__init__(opt=opt)
+ self.patch_size = 256
+
+ input_RAWs_WBs, target_RGBs = self.load(is_train=False)
+ assert len(input_RAWs_WBs) == len(target_RGBs)
+ self.data = {'input_RAWs_WBs':input_RAWs_WBs, 'target_RGBs':target_RGBs}
+
+ def __len__(self):
+ return len(self.data['input_RAWs_WBs'])
+
+ def __getitem__(self, idx):
+ input_raw_wb_path = self.data['input_RAWs_WBs'][idx]
+ target_rgb_path = self.data['target_RGBs'][idx]
+
+ target_rgb_img = imread(target_rgb_path)
+ input_raw_wb = np.load(input_raw_wb_path)
+ input_raw_img = input_raw_wb['raw']
+ wb = input_raw_wb['wb']
+ wb = wb / wb.max()
+ input_raw_img = input_raw_img * wb[:-1]
+
+ if self.gamma:
+ norm_value = np.power(4095, 1/2.2) if self.camera_name=='Canon_EOS_5D' else np.power(16383, 1/2.2)
+ input_raw_img = np.power(input_raw_img, 1/2.2)
+ else:
+ norm_value = 4095 if self.camera_name=='Canon_EOS_5D' else 16383
+
+ target_rgb_img = self.norm_img(target_rgb_img, max_value=255)
+ input_raw_img = self.norm_img(input_raw_img, max_value=norm_value)
+ target_raw_img = input_raw_img.copy()
+
+ input_raw_img = self.np2tensor(input_raw_img).float()
+ target_rgb_img = self.np2tensor(target_rgb_img).float()
+ target_raw_img = self.np2tensor(target_raw_img).float()
+
+ sample = {'input_raw':input_raw_img, 'target_rgb':target_rgb_img, 'target_raw':target_raw_img,
+ 'file_name':input_raw_wb_path.split("/")[-1].split(".")[0]}
+ return sample
+
diff --git a/imcui/third_party/DarkFeat/datasets/InvISP/dataset/__init__.py b/imcui/third_party/DarkFeat/datasets/InvISP/dataset/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/imcui/third_party/DarkFeat/datasets/InvISP/dataset/base_dataset.py b/imcui/third_party/DarkFeat/datasets/InvISP/dataset/base_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..34c5de9f75dbfb5323c2cdad532cb0a42c09df22
--- /dev/null
+++ b/imcui/third_party/DarkFeat/datasets/InvISP/dataset/base_dataset.py
@@ -0,0 +1,84 @@
+from __future__ import print_function, division
+import numpy as np
+from torch.utils.data import Dataset
+import torch
+
+class BaseDataset(Dataset):
+ def __init__(self, opt):
+ self.crop_size = 512
+ self.debug_mode = opt.debug_mode
+ self.data_path = opt.data_path # dataset path. e.g., ./data/
+ self.camera_name = opt.camera
+ self.gamma = opt.gamma
+
+ def norm_img(self, img, max_value):
+ img = img / float(max_value)
+ return img
+
+ def pack_raw(self, raw):
+ # pack Bayer image to 4 channels
+ im = np.expand_dims(raw, axis=2)
+ H, W = raw.shape[0], raw.shape[1]
+ # RGBG
+ out = np.concatenate((im[0:H:2, 0:W:2, :],
+ im[0:H:2, 1:W:2, :],
+ im[1:H:2, 1:W:2, :],
+ im[1:H:2, 0:W:2, :]), axis=2)
+ return out
+
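+ # Added note: pack_raw turns an [H, W] Bayer mosaic into an [H/2, W/2, 4]
+ # stack of the four 2x2 phase offsets; for an RGGB pattern the channel
+ # order is (R, G1, B, G2). Quick illustrative check:
+ #   raw = np.arange(16).reshape(4, 4)
+ #   BaseDataset.pack_raw(None, raw).shape   # -> (2, 2, 4)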
+ def np2tensor(self, array):
+ return torch.Tensor(array).permute(2,0,1)
+
+ def center_crop(self, img, crop_size=None):
+ H = img.shape[0]
+ W = img.shape[1]
+
+ if crop_size is not None:
+ th, tw = crop_size[0], crop_size[1]
+ else:
+ th, tw = self.crop_size, self.crop_size
+ x1_img = int(round((W - tw) / 2.))
+ y1_img = int(round((H - th) / 2.))
+ if img.ndim == 3:
+ input_patch = img[y1_img:y1_img + th, x1_img:x1_img + tw, :]
+ else:
+ input_patch = img[y1_img:y1_img + th, x1_img:x1_img + tw]
+
+ return input_patch
+
+ def load(self, is_train=True):
+ # ./data
+ # ./data/NIKON D700/RAW, ./data/NIKON D700/RGB
+ # ./data/Canon EOS 5D/RAW, ./data/Canon EOS 5D/RGB
+ # ./data/NIKON D700_train.txt, ./data/NIKON D700_test.txt
+ # ./data/NIKON D700_train.txt: a0016, ...
+ input_RAWs_WBs = []
+ target_RGBs = []
+
+ data_path = self.data_path # ./data/
+ if is_train:
+ txt_path = data_path + self.camera_name + "_train.txt"
+ else:
+ txt_path = data_path + self.camera_name + "_test.txt"
+
+ with open(txt_path, "r") as f_read:
+ # valid_camera_list = [os.path.basename(line.strip()).split('.')[0] for line in f_read.readlines()]
+ valid_camera_list = [line.strip() for line in f_read.readlines()]
+
+ if self.debug_mode:
+ valid_camera_list = valid_camera_list[:10]
+
+ for i,name in enumerate(valid_camera_list):
+ full_name = data_path + self.camera_name
+ input_RAWs_WBs.append(full_name + "/RAW/" + name + ".npz")
+ target_RGBs.append(full_name + "/RGB/" + name + ".jpg")
+
+ return input_RAWs_WBs, target_RGBs
+
+
+ def __len__(self):
+ return 0
+
+ def __getitem__(self, idx):
+
+ return None
diff --git a/imcui/third_party/DarkFeat/datasets/InvISP/environment.yml b/imcui/third_party/DarkFeat/datasets/InvISP/environment.yml
new file mode 100644
index 0000000000000000000000000000000000000000..20a58415354b80fb01f72fbbeb8d55edee6067ce
--- /dev/null
+++ b/imcui/third_party/DarkFeat/datasets/InvISP/environment.yml
@@ -0,0 +1,56 @@
+name: invertible-isp
+channels:
+ - defaults
+dependencies:
+ - _libgcc_mutex=0.1=main
+ - _pytorch_select=0.2=gpu_0
+ - blas=1.0=mkl
+ - ca-certificates=2021.1.19=h06a4308_1
+ - certifi=2020.12.5=py36h06a4308_0
+ - cffi=1.14.5=py36h261ae71_0
+ - cudatoolkit=10.1.243=h6bb024c_0
+ - cudnn=7.6.5=cuda10.1_0
+ - freetype=2.10.4=h5ab3b9f_0
+ - intel-openmp=2020.2=254
+ - jpeg=9b=h024ee3a_2
+ - lcms2=2.11=h396b838_0
+ - ld_impl_linux-64=2.33.1=h53a641e_7
+ - libffi=3.3=he6710b0_2
+ - libgcc-ng=9.1.0=hdf63c60_0
+ - libpng=1.6.37=hbc83047_0
+ - libstdcxx-ng=9.1.0=hdf63c60_0
+ - libtiff=4.1.0=h2733197_1
+ - lz4-c=1.9.3=h2531618_0
+ - mkl=2020.2=256
+ - mkl-service=2.3.0=py36he8ac12f_0
+ - mkl_fft=1.3.0=py36h54f3939_0
+ - mkl_random=1.1.1=py36h0573a6f_0
+ - ncurses=6.2=he6710b0_1
+ - ninja=1.10.2=py36hff7bd54_0
+ - numpy=1.19.2=py36h54aff64_0
+ - numpy-base=1.19.2=py36hfa32c7d_0
+ - olefile=0.46=py36_0
+ - openssl=1.1.1k=h27cfd23_0
+ - pillow=8.2.0=py36he98fc37_0
+ - pip=21.0.1=py36h06a4308_0
+ - pycparser=2.20=py_2
+ - python=3.6.13=hdb3f193_0
+ - pytorch=1.4.0=cuda101py36h02f0884_0
+ - readline=8.1=h27cfd23_0
+ - setuptools=52.0.0=py36h06a4308_0
+ - six=1.15.0=py36h06a4308_0
+ - sqlite=3.35.3=hdfb4753_0
+ - tk=8.6.10=hbc83047_0
+ - torchvision=0.2.1=py36_0
+ - wheel=0.36.2=pyhd3eb1b0_0
+ - xz=5.2.5=h7b6447c_0
+ - zlib=1.2.11=h7b6447c_3
+ - zstd=1.4.9=haebb681_0
+ - pip:
+ - colour-demosaicing==0.1.6
+ - colour-science==0.3.16
+ - imageio==2.9.0
+ - rawpy==0.16.0
+ - scipy==1.2.0
+ - tqdm==4.59.0
+
diff --git a/imcui/third_party/DarkFeat/datasets/InvISP/model/__init__.py b/imcui/third_party/DarkFeat/datasets/InvISP/model/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/imcui/third_party/DarkFeat/datasets/InvISP/model/loss.py b/imcui/third_party/DarkFeat/datasets/InvISP/model/loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..abe8b599d5402c367bb7c84b7e370964d8273518
--- /dev/null
+++ b/imcui/third_party/DarkFeat/datasets/InvISP/model/loss.py
@@ -0,0 +1,15 @@
+import torch.nn.functional as F
+import torch
+
+
+def l1_loss(output, target_rgb, target_raw, weight=1.):
+ raw_loss = F.l1_loss(output['reconstruct_raw'], target_raw)
+ rgb_loss = F.l1_loss(output['reconstruct_rgb'], target_rgb)
+ total_loss = raw_loss + weight * rgb_loss
+ return total_loss, raw_loss, rgb_loss
+
+def l2_loss(output, target_rgb, target_raw, weight=1.):
+ raw_loss = F.mse_loss(output['reconstruct_raw'], target_raw)
+ rgb_loss = F.mse_loss(output['reconstruct_rgb'], target_rgb)
+ total_loss = raw_loss + weight * rgb_loss
+ return total_loss, raw_loss, rgb_loss
\ No newline at end of file
diff --git a/imcui/third_party/DarkFeat/datasets/InvISP/model/model.py b/imcui/third_party/DarkFeat/datasets/InvISP/model/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..9dd0e33cee8ebb26d621ece84622bd2611b33a60
--- /dev/null
+++ b/imcui/third_party/DarkFeat/datasets/InvISP/model/model.py
@@ -0,0 +1,179 @@
+import math
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+import torch.nn.init as init
+
+from .modules import InvertibleConv1x1
+
+
+def initialize_weights(net_l, scale=1):
+ if not isinstance(net_l, list):
+ net_l = [net_l]
+ for net in net_l:
+ for m in net.modules():
+ if isinstance(m, nn.Conv2d):
+ init.kaiming_normal_(m.weight, a=0, mode='fan_in')
+ m.weight.data *= scale # for residual block
+ if m.bias is not None:
+ m.bias.data.zero_()
+ elif isinstance(m, nn.Linear):
+ init.kaiming_normal_(m.weight, a=0, mode='fan_in')
+ m.weight.data *= scale
+ if m.bias is not None:
+ m.bias.data.zero_()
+ elif isinstance(m, nn.BatchNorm2d):
+ init.constant_(m.weight, 1)
+ init.constant_(m.bias.data, 0.0)
+
+
+def initialize_weights_xavier(net_l, scale=1):
+ if not isinstance(net_l, list):
+ net_l = [net_l]
+ for net in net_l:
+ for m in net.modules():
+ if isinstance(m, nn.Conv2d):
+ init.xavier_normal_(m.weight)
+ m.weight.data *= scale # for residual block
+ if m.bias is not None:
+ m.bias.data.zero_()
+ elif isinstance(m, nn.Linear):
+ init.xavier_normal_(m.weight)
+ m.weight.data *= scale
+ if m.bias is not None:
+ m.bias.data.zero_()
+ elif isinstance(m, nn.BatchNorm2d):
+ init.constant_(m.weight, 1)
+ init.constant_(m.bias.data, 0.0)
+
+
+class DenseBlock(nn.Module):
+ def __init__(self, channel_in, channel_out, init='xavier', gc=32, bias=True):
+ super(DenseBlock, self).__init__()
+ self.conv1 = nn.Conv2d(channel_in, gc, 3, 1, 1, bias=bias)
+ self.conv2 = nn.Conv2d(channel_in + gc, gc, 3, 1, 1, bias=bias)
+ self.conv3 = nn.Conv2d(channel_in + 2 * gc, gc, 3, 1, 1, bias=bias)
+ self.conv4 = nn.Conv2d(channel_in + 3 * gc, gc, 3, 1, 1, bias=bias)
+ self.conv5 = nn.Conv2d(channel_in + 4 * gc, channel_out, 3, 1, 1, bias=bias)
+ self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
+
+ if init == 'xavier':
+ initialize_weights_xavier([self.conv1, self.conv2, self.conv3, self.conv4], 0.1)
+ else:
+ initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4], 0.1)
+ initialize_weights(self.conv5, 0)
+
+ def forward(self, x):
+ x1 = self.lrelu(self.conv1(x))
+ x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
+ x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
+ x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
+ x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
+
+ return x5
+
+def subnet(net_structure, init='xavier'):
+ def constructor(channel_in, channel_out):
+ if net_structure == 'DBNet':
+ if init == 'xavier':
+ return DenseBlock(channel_in, channel_out, init)
+ else:
+ return DenseBlock(channel_in, channel_out)
+ # return UNetBlock(channel_in, channel_out)
+ else:
+ return None
+
+ return constructor
+
+
+class InvBlock(nn.Module):
+ def __init__(self, subnet_constructor, channel_num, channel_split_num, clamp=0.8):
+ super(InvBlock, self).__init__()
+ # channel_num: 3
+ # channel_split_num: 1
+
+ self.split_len1 = channel_split_num # 1
+ self.split_len2 = channel_num - channel_split_num # 2
+
+ self.clamp = clamp
+
+ self.F = subnet_constructor(self.split_len2, self.split_len1)
+ self.G = subnet_constructor(self.split_len1, self.split_len2)
+ self.H = subnet_constructor(self.split_len1, self.split_len2)
+
+ in_channels = 3
+ self.invconv = InvertibleConv1x1(in_channels, LU_decomposed=True)
+ self.flow_permutation = lambda z, logdet, rev: self.invconv(z, logdet, rev)
+
+ def forward(self, x, rev=False):
+ if not rev:
+ # invert1x1conv
+ x, logdet = self.flow_permutation(x, logdet=0, rev=False)
+
+ # split to 1 channel and 2 channel.
+ x1, x2 = (x.narrow(1, 0, self.split_len1), x.narrow(1, self.split_len1, self.split_len2))
+
+ y1 = x1 + self.F(x2) # 1 channel
+ self.s = self.clamp * (torch.sigmoid(self.H(y1)) * 2 - 1)
+ y2 = x2.mul(torch.exp(self.s)) + self.G(y1) # 2 channel
+ out = torch.cat((y1, y2), 1)
+ else:
+ # split.
+ x1, x2 = (x.narrow(1, 0, self.split_len1), x.narrow(1, self.split_len1, self.split_len2))
+ self.s = self.clamp * (torch.sigmoid(self.H(x1)) * 2 - 1)
+ y2 = (x2 - self.G(x1)).div(torch.exp(self.s))
+ y1 = x1 - self.F(y2)
+
+ x = torch.cat((y1, y2), 1)
+
+ # inv permutation
+ out, logdet = self.flow_permutation(x, logdet=0, rev=True)
+
+ return out
+
+class InvISPNet(nn.Module):
+ def __init__(self, channel_in=3, channel_out=3, subnet_constructor=subnet('DBNet'), block_num=8):
+ super(InvISPNet, self).__init__()
+ operations = []
+
+ current_channel = channel_in
+ channel_num = channel_in
+ channel_split_num = 1
+
+ for j in range(block_num):
+ b = InvBlock(subnet_constructor, channel_num, channel_split_num) # one block is one flow step.
+ operations.append(b)
+
+ self.operations = nn.ModuleList(operations)
+
+ self.initialize()
+
+ def initialize(self):
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ init.xavier_normal_(m.weight)
+ m.weight.data *= 1. # for residual block
+ if m.bias is not None:
+ m.bias.data.zero_()
+ elif isinstance(m, nn.Linear):
+ init.xavier_normal_(m.weight)
+ m.weight.data *= 1.
+ if m.bias is not None:
+ m.bias.data.zero_()
+ elif isinstance(m, nn.BatchNorm2d):
+ init.constant_(m.weight, 1)
+ init.constant_(m.bias.data, 0.0)
+
+ def forward(self, x, rev=False):
+ out = x # x: [N,3,H,W]
+
+ if not rev:
+ for op in self.operations:
+ out = op.forward(out, rev)
+ else:
+ for op in reversed(self.operations):
+ out = op.forward(out, rev)
+
+ return out
+
diff --git a/imcui/third_party/DarkFeat/datasets/InvISP/model/modules.py b/imcui/third_party/DarkFeat/datasets/InvISP/model/modules.py
new file mode 100644
index 0000000000000000000000000000000000000000..88244c0b211860d97be78ba4f60f4743228171a7
--- /dev/null
+++ b/imcui/third_party/DarkFeat/datasets/InvISP/model/modules.py
@@ -0,0 +1,387 @@
+import math
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .utils import split_feature, compute_same_pad
+
+
+def gaussian_p(mean, logs, x):
+ """
+ lnL = -1/2 * { ln|Var| + ((X - Mu)^T)(Var^-1)(X - Mu) + kln(2*PI) }
+ k = 1 (Independent)
+ Var = logs ** 2
+ """
+ c = math.log(2 * math.pi)
+ return -0.5 * (logs * 2.0 + ((x - mean) ** 2) / torch.exp(logs * 2.0) + c)
+
+
+def gaussian_likelihood(mean, logs, x):
+ p = gaussian_p(mean, logs, x)
+ return torch.sum(p, dim=[1, 2, 3])
+
+
+def gaussian_sample(mean, logs, temperature=1):
+ # Sample from Gaussian with temperature
+ z = torch.normal(mean, torch.exp(logs) * temperature)
+
+ return z
+
+
+def squeeze2d(input, factor):
+ if factor == 1:
+ return input
+
+ B, C, H, W = input.size()
+
+ assert H % factor == 0 and W % factor == 0, "H or W modulo factor is not 0"
+
+ x = input.view(B, C, H // factor, factor, W // factor, factor)
+ x = x.permute(0, 1, 3, 5, 2, 4).contiguous()
+ x = x.view(B, C * factor * factor, H // factor, W // factor)
+
+ return x
+
+
+def unsqueeze2d(input, factor):
+ if factor == 1:
+ return input
+
+ factor2 = factor ** 2
+
+ B, C, H, W = input.size()
+
+ assert C % (factor2) == 0, "C modulo factor squared is not 0"
+
+ x = input.view(B, C // factor2, factor, factor, H, W)
+ x = x.permute(0, 1, 4, 2, 5, 3).contiguous()
+ x = x.view(B, C // (factor2), H * factor, W * factor)
+
+ return x
+
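+# Added shape sketch: squeeze2d trades spatial resolution for channels and
+# unsqueeze2d inverts it exactly (illustrative example, not original code):
+#
+#   x = torch.randn(2, 3, 8, 8)
+#   y = squeeze2d(x, 2)                   # -> [2, 12, 4, 4]
+#   torch.equal(unsqueeze2d(y, 2), x)     # -> True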
+
+class _ActNorm(nn.Module):
+ """
+ Activation Normalization
+ Initialize the bias and scale with a given minibatch,
+ so that the per-channel output has zero mean and unit variance for that minibatch.
+
+ After initialization, `bias` and `logs` will be trained as parameters.
+ """
+
+ def __init__(self, num_features, scale=1.0):
+ super().__init__()
+ # register mean and scale
+ size = [1, num_features, 1, 1]
+ self.bias = nn.Parameter(torch.zeros(*size))
+ self.logs = nn.Parameter(torch.zeros(*size))
+ self.num_features = num_features
+ self.scale = scale
+ self.inited = False
+
+ def initialize_parameters(self, input):
+ if not self.training:
+ raise ValueError("In Eval mode, but ActNorm not inited")
+
+ with torch.no_grad():
+ bias = -torch.mean(input.clone(), dim=[0, 2, 3], keepdim=True)
+ vars = torch.mean((input.clone() + bias) ** 2, dim=[0, 2, 3], keepdim=True)
+ logs = torch.log(self.scale / (torch.sqrt(vars) + 1e-6))
+
+ self.bias.data.copy_(bias.data)
+ self.logs.data.copy_(logs.data)
+
+ self.inited = True
+
+ def _center(self, input, reverse=False):
+ if reverse:
+ return input - self.bias
+ else:
+ return input + self.bias
+
+ def _scale(self, input, logdet=None, reverse=False):
+
+ if reverse:
+ input = input * torch.exp(-self.logs)
+ else:
+ input = input * torch.exp(self.logs)
+
+ if logdet is not None:
+ """
+ logs is log_std of `mean of channels`
+ so we need to multiply by number of pixels
+ """
+ b, c, h, w = input.shape
+
+ dlogdet = torch.sum(self.logs) * h * w
+
+ if reverse:
+ dlogdet *= -1
+
+ logdet = logdet + dlogdet
+
+ return input, logdet
+
+ def forward(self, input, logdet=None, reverse=False):
+ self._check_input_dim(input)
+
+ if not self.inited:
+ self.initialize_parameters(input)
+
+ if reverse:
+ input, logdet = self._scale(input, logdet, reverse)
+ input = self._center(input, reverse)
+ else:
+ input = self._center(input, reverse)
+ input, logdet = self._scale(input, logdet, reverse)
+
+ return input, logdet
+
+
+class ActNorm2d(_ActNorm):
+ def __init__(self, num_features, scale=1.0):
+ super().__init__(num_features, scale)
+
+ def _check_input_dim(self, input):
+ assert len(input.size()) == 4
+ assert input.size(1) == self.num_features, (
+ "[ActNorm]: input should be in shape as `BCHW`,"
+ " channels should be {} rather than {}".format(
+ self.num_features, input.size()
+ )
+ )
+
+
+class LinearZeros(nn.Module):
+ def __init__(self, in_channels, out_channels, logscale_factor=3):
+ super().__init__()
+
+ self.linear = nn.Linear(in_channels, out_channels)
+ self.linear.weight.data.zero_()
+ self.linear.bias.data.zero_()
+
+ self.logscale_factor = logscale_factor
+
+ self.logs = nn.Parameter(torch.zeros(out_channels))
+
+ def forward(self, input):
+ output = self.linear(input)
+ return output * torch.exp(self.logs * self.logscale_factor)
+
+
+class Conv2d(nn.Module):
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ kernel_size=(3, 3),
+ stride=(1, 1),
+ padding="same",
+ do_actnorm=True,
+ weight_std=0.05,
+ ):
+ super().__init__()
+
+ if padding == "same":
+ padding = compute_same_pad(kernel_size, stride)
+ elif padding == "valid":
+ padding = 0
+
+ self.conv = nn.Conv2d(
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride,
+ padding,
+ bias=(not do_actnorm),
+ )
+
+ # init weight with std
+ self.conv.weight.data.normal_(mean=0.0, std=weight_std)
+
+ if not do_actnorm:
+ self.conv.bias.data.zero_()
+ else:
+ self.actnorm = ActNorm2d(out_channels)
+
+ self.do_actnorm = do_actnorm
+
+ def forward(self, input):
+ x = self.conv(input)
+ if self.do_actnorm:
+ x, _ = self.actnorm(x)
+ return x
+
+
+class Conv2dZeros(nn.Module):
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ kernel_size=(3, 3),
+ stride=(1, 1),
+ padding="same",
+ logscale_factor=3,
+ ):
+ super().__init__()
+
+ if padding == "same":
+ padding = compute_same_pad(kernel_size, stride)
+ elif padding == "valid":
+ padding = 0
+
+ self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
+
+ self.conv.weight.data.zero_()
+ self.conv.bias.data.zero_()
+
+ self.logscale_factor = logscale_factor
+ self.logs = nn.Parameter(torch.zeros(out_channels, 1, 1))
+
+ def forward(self, input):
+ output = self.conv(input)
+ return output * torch.exp(self.logs * self.logscale_factor)
+
+
+class Permute2d(nn.Module):
+ def __init__(self, num_channels, shuffle):
+ super().__init__()
+ self.num_channels = num_channels
+ self.indices = torch.arange(self.num_channels - 1, -1, -1, dtype=torch.long)
+ self.indices_inverse = torch.zeros((self.num_channels), dtype=torch.long)
+
+ for i in range(self.num_channels):
+ self.indices_inverse[self.indices[i]] = i
+
+ if shuffle:
+ self.reset_indices()
+
+ def reset_indices(self):
+ shuffle_idx = torch.randperm(self.indices.shape[0])
+ self.indices = self.indices[shuffle_idx]
+
+ for i in range(self.num_channels):
+ self.indices_inverse[self.indices[i]] = i
+
+ def forward(self, input, reverse=False):
+ assert len(input.size()) == 4
+
+ if not reverse:
+ input = input[:, self.indices, :, :]
+ return input
+ else:
+ return input[:, self.indices_inverse, :, :]
+
+
+class Split2d(nn.Module):
+ def __init__(self, num_channels):
+ super().__init__()
+ self.conv = Conv2dZeros(num_channels // 2, num_channels)
+
+ def split2d_prior(self, z):
+ h = self.conv(z)
+ return split_feature(h, "cross")
+
+ def forward(self, input, logdet=0.0, reverse=False, temperature=None):
+ if reverse:
+ z1 = input
+ mean, logs = self.split2d_prior(z1)
+ z2 = gaussian_sample(mean, logs, temperature)
+ z = torch.cat((z1, z2), dim=1)
+ return z, logdet
+ else:
+ z1, z2 = split_feature(input, "split")
+ mean, logs = self.split2d_prior(z1)
+ logdet = gaussian_likelihood(mean, logs, z2) + logdet
+ return z1, logdet
+
+
+class SqueezeLayer(nn.Module):
+ def __init__(self, factor):
+ super().__init__()
+ self.factor = factor
+
+ def forward(self, input, logdet=None, reverse=False):
+ if reverse:
+ output = unsqueeze2d(input, self.factor)
+ else:
+ output = squeeze2d(input, self.factor)
+
+ return output, logdet
+
+
+class InvertibleConv1x1(nn.Module):
+ def __init__(self, num_channels, LU_decomposed):
+ super().__init__()
+ w_shape = [num_channels, num_channels]
+ w_init = torch.linalg.qr(torch.randn(*w_shape))[0]
+
+ if not LU_decomposed:
+ self.weight = nn.Parameter(torch.Tensor(w_init))
+ else:
+ p, lower, upper = torch.lu_unpack(*torch.lu(w_init))
+ s = torch.diag(upper)
+ sign_s = torch.sign(s)
+ log_s = torch.log(torch.abs(s))
+ upper = torch.triu(upper, 1)
+ l_mask = torch.tril(torch.ones(w_shape), -1)
+ eye = torch.eye(*w_shape)
+
+ self.register_buffer("p", p)
+ self.register_buffer("sign_s", sign_s)
+ self.lower = nn.Parameter(lower)
+ self.log_s = nn.Parameter(log_s)
+ self.upper = nn.Parameter(upper)
+ self.l_mask = l_mask
+ self.eye = eye
+
+ self.w_shape = w_shape
+ self.LU_decomposed = LU_decomposed
+
+ def get_weight(self, input, reverse):
+ b, c, h, w = input.shape
+
+ if not self.LU_decomposed:
+ dlogdet = torch.slogdet(self.weight)[1] * h * w
+ if reverse:
+ weight = torch.inverse(self.weight)
+ else:
+ weight = self.weight
+ else:
+ self.l_mask = self.l_mask.to(input.device)
+ self.eye = self.eye.to(input.device)
+
+ lower = self.lower * self.l_mask + self.eye
+
+ u = self.upper * self.l_mask.transpose(0, 1).contiguous()
+ u += torch.diag(self.sign_s * torch.exp(self.log_s))
+
+ dlogdet = torch.sum(self.log_s) * h * w
+
+ if reverse:
+ u_inv = torch.inverse(u)
+ l_inv = torch.inverse(lower)
+ p_inv = torch.inverse(self.p)
+
+ weight = torch.matmul(u_inv, torch.matmul(l_inv, p_inv))
+ else:
+ weight = torch.matmul(self.p, torch.matmul(lower, u))
+
+ return weight.view(self.w_shape[0], self.w_shape[1], 1, 1), dlogdet
+
+ def forward(self, input, logdet=None, reverse=False):
+ """
+ log-det = log(abs(det(W))) * pixels
+ """
+ weight, dlogdet = self.get_weight(input, reverse)
+
+ if not reverse:
+ z = F.conv2d(input, weight)
+ if logdet is not None:
+ logdet = logdet + dlogdet
+ return z, logdet
+ else:
+ z = F.conv2d(input, weight)
+ if logdet is not None:
+ logdet = logdet - dlogdet
+ return z, logdet
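+# Added note: with the LU parameterisation above,
+# W = P @ L @ (U + diag(sign_s * exp(log_s))), so log|det W| = sum(log_s) and
+# the per-sample contribution is sum(log_s) * h * w, matching get_weight().
+# A quick round-trip consistency check one could run:
+#
+#   conv = InvertibleConv1x1(4, LU_decomposed=True)
+#   x = torch.randn(1, 4, 8, 8)
+#   z, _ = conv(x, logdet=torch.zeros(1), reverse=False)
+#   x_rec, _ = conv(z, logdet=torch.zeros(1), reverse=True)
+#   torch.allclose(x, x_rec, atol=1e-5)   # expect True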
diff --git a/imcui/third_party/DarkFeat/datasets/InvISP/model/utils.py b/imcui/third_party/DarkFeat/datasets/InvISP/model/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..d1bef31afd7d61d4c942ffd895c818b90571b4b7
--- /dev/null
+++ b/imcui/third_party/DarkFeat/datasets/InvISP/model/utils.py
@@ -0,0 +1,52 @@
+import math
+import torch
+
+
+def compute_same_pad(kernel_size, stride):
+ if isinstance(kernel_size, int):
+ kernel_size = [kernel_size]
+
+ if isinstance(stride, int):
+ stride = [stride]
+
+ assert len(stride) == len(
+ kernel_size
+ ), "Pass kernel size and stride both as int, or both as equal length iterable"
+
+ return [((k - 1) * s + 1) // 2 for k, s in zip(kernel_size, stride)]
+
+
+def uniform_binning_correction(x, n_bits=8):
+ """Replaces x^i with q^i(x) = U(x, x + 1.0 / 256.0).
+
+ Args:
+ x: 4-D Tensor of shape (NCHW)
+ n_bits: optional.
+ Returns:
+ x: x ~ U(x, x + 1.0 / 256)
+ objective: Equivalent to -q(x)*log(q(x)).
+ """
+ b, c, h, w = x.size()
+ n_bins = 2 ** n_bits
+ chw = c * h * w
+ x += torch.zeros_like(x).uniform_(0, 1.0 / n_bins)
+
+ objective = -math.log(n_bins) * chw * torch.ones(b, device=x.device)
+ return x, objective
+
+
+def split_feature(tensor, type="split"):
+ """
+ type = ["split", "cross"]
+ """
+ C = tensor.size(1)
+ if type == "split":
+ # return tensor[:, : C // 2, ...], tensor[:, C // 2 :, ...]
+ return tensor[:, :1, ...], tensor[:,1:, ...]
+ elif type == "cross":
+ # return tensor[:, 0::2, ...], tensor[:, 1::2, ...]
+ return tensor[:, 0::2, ...], tensor[:, 1::2, ...]
+
+
+
+
diff --git a/imcui/third_party/DarkFeat/datasets/InvISP/test_raw.py b/imcui/third_party/DarkFeat/datasets/InvISP/test_raw.py
new file mode 100644
index 0000000000000000000000000000000000000000..37610f8268e4586864e0275236c5bb1932f894df
--- /dev/null
+++ b/imcui/third_party/DarkFeat/datasets/InvISP/test_raw.py
@@ -0,0 +1,118 @@
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.autograd import Variable
+import torch
+import numpy as np
+import os, time, random
+import argparse
+from torch.utils.data import Dataset, DataLoader
+from PIL import Image as PILImage
+from glob import glob
+from tqdm import tqdm
+
+from model.model import InvISPNet
+from dataset.FiveK_dataset import FiveKDatasetTest
+from config.config import get_arguments
+
+from utils.JPEG import DiffJPEG
+from utils.commons import denorm, preprocess_test_patch
+
+
+os.system('nvidia-smi -q -d Memory |grep -A4 GPU|grep Free >tmp')
+os.environ['CUDA_VISIBLE_DEVICES'] = str(np.argmax([int(x.split()[2]) for x in open('tmp', 'r').readlines()]))
+# os.environ['CUDA_VISIBLE_DEVICES'] = '7'
+os.system('rm tmp')
+
+DiffJPEG = DiffJPEG(differentiable=True, quality=90).cuda()
+
+parser = get_arguments()
+parser.add_argument("--ckpt", type=str, help="Checkpoint path.")
+parser.add_argument("--out_path", type=str, default="./exps/", help="Path to save checkpoint. ")
+parser.add_argument("--split_to_patch", dest='split_to_patch', action='store_true', help="Test on patch. ")
+args = parser.parse_args()
+print("Parsed arguments: {}".format(args))
+
+
+ckpt_name = args.ckpt.split("/")[-1].split(".")[0]
+if args.split_to_patch:
+ os.makedirs(args.out_path+"%s/results_metric_%s/"%(args.task, ckpt_name), exist_ok=True)
+ out_path = args.out_path+"%s/results_metric_%s/"%(args.task, ckpt_name)
+else:
+ os.makedirs(args.out_path+"%s/results_%s/"%(args.task, ckpt_name), exist_ok=True)
+ out_path = args.out_path+"%s/results_%s/"%(args.task, ckpt_name)
+
+
+def main(args):
+ # ======================================define the model============================================
+ net = InvISPNet(channel_in=3, channel_out=3, block_num=8)
+ device = torch.device("cuda:0")
+
+ net.to(device)
+ net.eval()
+ # load the pretrained weight if there exists one
+ if os.path.isfile(args.ckpt):
+ net.load_state_dict(torch.load(args.ckpt), strict=False)
+ print("[INFO] Loaded checkpoint: {}".format(args.ckpt))
+
+ print("[INFO] Start data load and preprocessing")
+ RAWDataset = FiveKDatasetTest(opt=args)
+ dataloader = DataLoader(RAWDataset, batch_size=1, shuffle=False, num_workers=0, drop_last=True)
+
+ input_RGBs = sorted(glob(out_path+"pred*jpg"))
+ input_RGBs_names = [path.split("/")[-1].split(".")[0][5:] for path in input_RGBs]
+
+ print("[INFO] Start test...")
+ for i_batch, sample_batched in enumerate(tqdm(dataloader)):
+ step_time = time.time()
+
+ input, target_rgb, target_raw = sample_batched['input_raw'].to(device), sample_batched['target_rgb'].to(device), \
+ sample_batched['target_raw'].to(device)
+ file_name = sample_batched['file_name'][0]
+
+ if args.split_to_patch:
+ input_list, target_rgb_list, target_raw_list = preprocess_test_patch(input, target_rgb, target_raw)
+ else:
+ # remove [:,:,::2,::2] if you have enough GPU memory to test the full resolution
+ input_list, target_rgb_list, target_raw_list = [input[:,:,::2,::2]], [target_rgb[:,:,::2,::2]], [target_raw[:,:,::2,::2]]
+
+ for i_patch in range(len(input_list)):
+ file_name_patch = file_name + "_%05d"%i_patch
+ idx = input_RGBs_names.index(file_name_patch)
+ input_RGB_path = input_RGBs[idx]
+ input_RGB = torch.from_numpy(np.array(PILImage.open(input_RGB_path))/255.0).unsqueeze(0).permute(0,3,1,2).float().to(device)
+
+ target_raw_patch = target_raw_list[i_patch]
+
+ with torch.no_grad():
+ reconstruct_raw = net(input_RGB, rev=True)
+
+ pred_raw = reconstruct_raw.detach().permute(0,2,3,1)
+ pred_raw = torch.clamp(pred_raw, 0, 1)
+
+ target_raw_patch = target_raw_patch.permute(0,2,3,1)
+ pred_raw = denorm(pred_raw, 255)
+ target_raw_patch = denorm(target_raw_patch, 255)
+
+ pred_raw = pred_raw.cpu().numpy()
+ target_raw_patch = target_raw_patch.cpu().numpy().astype(np.float32)
+
+ raw_pred = PILImage.fromarray(np.uint8(pred_raw[0,:,:,0]))
+ raw_tar_pred = PILImage.fromarray(np.hstack((np.uint8(target_raw_patch[0,:,:,0]), np.uint8(pred_raw[0,:,:,0]))))
+
+ raw_tar = PILImage.fromarray(np.uint8(target_raw_patch[0,:,:,0]))
+
+ raw_pred.save(out_path+"raw_pred_%s_%05d.jpg"%(file_name, i_patch))
+ raw_tar.save(out_path+"raw_tar_%s_%05d.jpg"%(file_name, i_patch))
+ raw_tar_pred.save(out_path+"raw_gt_pred_%s_%05d.jpg"%(file_name, i_patch))
+
+ np.save(out_path+"raw_pred_%s_%05d.npy"%(file_name, i_patch), pred_raw[0,:,:,:]/255.0)
+ np.save(out_path+"raw_tar_%s_%05d.npy"%(file_name, i_patch), target_raw_patch[0,:,:,:]/255.0)
+
+ del reconstruct_raw
+
+
+if __name__ == '__main__':
+
+ torch.set_num_threads(4)
+ main(args)
+
diff --git a/imcui/third_party/DarkFeat/datasets/InvISP/test_rgb.py b/imcui/third_party/DarkFeat/datasets/InvISP/test_rgb.py
new file mode 100644
index 0000000000000000000000000000000000000000..d1e054b899d9142609e3f90f4a12d367a45aeac0
--- /dev/null
+++ b/imcui/third_party/DarkFeat/datasets/InvISP/test_rgb.py
@@ -0,0 +1,105 @@
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.autograd import Variable
+import torch
+import numpy as np
+import os, time, random
+import argparse
+from torch.utils.data import Dataset, DataLoader
+from PIL import Image as PILImage
+
+from model.model import InvISPNet
+from dataset.FiveK_dataset import FiveKDatasetTest
+from config.config import get_arguments
+
+from utils.JPEG import DiffJPEG
+from utils.commons import denorm, preprocess_test_patch
+from tqdm import tqdm
+
+os.system('nvidia-smi -q -d Memory |grep -A4 GPU|grep Free >tmp')
+os.environ['CUDA_VISIBLE_DEVICES'] = str(np.argmax([int(x.split()[2]) for x in open('tmp', 'r').readlines()]))
+# os.environ['CUDA_VISIBLE_DEVICES'] = '7'
+os.system('rm tmp')
+
+DiffJPEG = DiffJPEG(differentiable=True, quality=90).cuda()
+
+parser = get_arguments()
+parser.add_argument("--ckpt", type=str, help="Checkpoint path.")
+parser.add_argument("--out_path", type=str, default="./exps/", help="Path to save results. ")
+parser.add_argument("--split_to_patch", dest='split_to_patch', action='store_true', help="Test on patch. ")
+args = parser.parse_args()
+print("Parsed arguments: {}".format(args))
+
+
+ckpt_name = args.ckpt.split("/")[-1].split(".")[0]
+if args.split_to_patch:
+ os.makedirs(args.out_path+"%s/results_metric_%s/"%(args.task, ckpt_name), exist_ok=True)
+ out_path = args.out_path+"%s/results_metric_%s/"%(args.task, ckpt_name)
+else:
+ os.makedirs(args.out_path+"%s/results_%s/"%(args.task, ckpt_name), exist_ok=True)
+ out_path = args.out_path+"%s/results_%s/"%(args.task, ckpt_name)
+
+
+def main(args):
+ # ======================================define the model============================================
+ net = InvISPNet(channel_in=3, channel_out=3, block_num=8)
+ device = torch.device("cuda:0")
+
+ net.to(device)
+ net.eval()
+ # load the pretrained weight if there exists one
+ if os.path.isfile(args.ckpt):
+ net.load_state_dict(torch.load(args.ckpt), strict=False)
+ print("[INFO] Loaded checkpoint: {}".format(args.ckpt))
+
+ print("[INFO] Start data load and preprocessing")
+ RAWDataset = FiveKDatasetTest(opt=args)
+ dataloader = DataLoader(RAWDataset, batch_size=1, shuffle=False, num_workers=0, drop_last=True)
+
+ print("[INFO] Start test...")
+ for i_batch, sample_batched in enumerate(tqdm(dataloader)):
+ step_time = time.time()
+
+ input, target_rgb, target_raw = sample_batched['input_raw'].to(device), sample_batched['target_rgb'].to(device), \
+ sample_batched['target_raw'].to(device)
+ file_name = sample_batched['file_name'][0]
+
+ if args.split_to_patch:
+ input_list, target_rgb_list, target_raw_list = preprocess_test_patch(input, target_rgb, target_raw)
+ else:
+ # remove [:,:,::2,::2] if you have enough GPU memory to test the full resolution
+ input_list, target_rgb_list, target_raw_list = [input[:,:,::2,::2]], [target_rgb[:,:,::2,::2]], [target_raw[:,:,::2,::2]]
+
+ for i_patch in range(len(input_list)):
+ input_patch = input_list[i_patch]
+ target_rgb_patch = target_rgb_list[i_patch]
+ target_raw_patch = target_raw_list[i_patch]
+
+ with torch.no_grad():
+ reconstruct_rgb = net(input_patch)
+ reconstruct_rgb = torch.clamp(reconstruct_rgb, 0, 1)
+
+ pred_rgb = reconstruct_rgb.detach().permute(0,2,3,1)
+ target_rgb_patch = target_rgb_patch.permute(0,2,3,1)
+
+ pred_rgb = denorm(pred_rgb, 255)
+ target_rgb_patch = denorm(target_rgb_patch, 255)
+ pred_rgb = pred_rgb.cpu().numpy()
+ target_rgb_patch = target_rgb_patch.cpu().numpy().astype(np.float32)
+
+ # print(type(pred_rgb))
+ pred = PILImage.fromarray(np.uint8(pred_rgb[0,:,:,:]))
+ tar_pred = PILImage.fromarray(np.hstack((np.uint8(target_rgb_patch[0,:,:,:]), np.uint8(pred_rgb[0,:,:,:]))))
+
+ tar = PILImage.fromarray(np.uint8(target_rgb_patch[0,:,:,:]))
+
+ pred.save(out_path+"pred_%s_%05d.jpg"%(file_name, i_patch), quality=90, subsampling=1)
+ tar.save(out_path+"tar_%s_%05d.jpg"%(file_name, i_patch), quality=90, subsampling=1)
+ tar_pred.save(out_path+"gt_pred_%s_%05d.jpg"%(file_name, i_patch), quality=90, subsampling=1)
+
+ del reconstruct_rgb
+
+if __name__ == '__main__':
+ torch.set_num_threads(4)
+ main(args)
+
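+# The pred_*.jpg files saved above (JPEG quality 90, 4:2:2 subsampling) are the
+# images that test_raw.py later reloads and pushes back through the network
+# with rev=True to recover RAW, so both scripts should share the same --task,
+# --ckpt and --out_path settings.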
diff --git a/imcui/third_party/DarkFeat/datasets/InvISP/train.py b/imcui/third_party/DarkFeat/datasets/InvISP/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..16186cb38d825ac1299e5c4164799d35bfa79907
--- /dev/null
+++ b/imcui/third_party/DarkFeat/datasets/InvISP/train.py
@@ -0,0 +1,98 @@
+import numpy as np
+import os, time, random
+import argparse
+import json
+
+import torch.nn.functional as F
+import torch
+from torch.utils.data import Dataset, DataLoader
+from torch.optim import lr_scheduler
+
+from model.model import InvISPNet
+from dataset.FiveK_dataset import FiveKDatasetTrain
+from config.config import get_arguments
+
+from utils.JPEG import DiffJPEG
+
+os.system('nvidia-smi -q -d Memory |grep -A4 GPU|grep Free >tmp')
+os.environ['CUDA_VISIBLE_DEVICES'] = str(np.argmax([int(x.split()[2]) for x in open('tmp', 'r').readlines()]))
+# os.environ['CUDA_VISIBLE_DEVICES'] = "1"
+os.system('rm tmp')
+
+DiffJPEG = DiffJPEG(differentiable=True, quality=90).cuda()
+
+parser = get_arguments()
+parser.add_argument("--out_path", type=str, default="./exps/", help="Path to save checkpoint. ")
+parser.add_argument("--resume", dest='resume', action='store_true', help="Resume training. ")
+parser.add_argument("--loss", type=str, default="L1", choices=["L1", "L2"], help="Choose which loss function to use. ")
+parser.add_argument("--lr", type=float, default=0.0001, help="Learning rate")
+parser.add_argument("--aug", dest='aug', action='store_true', help="Use data augmentation.")
+args = parser.parse_args()
+print("Parsed arguments: {}".format(args))
+
+os.makedirs(args.out_path, exist_ok=True)
+os.makedirs(args.out_path+"%s"%args.task, exist_ok=True)
+os.makedirs(args.out_path+"%s/checkpoint"%args.task, exist_ok=True)
+
+with open(args.out_path+"%s/commandline_args.yaml"%args.task , 'w') as f:
+ json.dump(args.__dict__, f, indent=2)
+
+def main(args):
+ # ======================================define the model======================================
+ net = InvISPNet(channel_in=3, channel_out=3, block_num=8)
+ net.cuda()
+ # load the pretrained weight if there exists one
+ if args.resume:
+ net.load_state_dict(torch.load(args.out_path+"%s/checkpoint/latest.pth"%args.task))
+ print("[INFO] loaded " + args.out_path+"%s/checkpoint/latest.pth"%args.task)
+
+ optimizer = torch.optim.Adam(net.parameters(), lr=args.lr)
+ scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=[50, 80], gamma=0.5)
+
+ print("[INFO] Start data loading and preprocessing")
+ RAWDataset = FiveKDatasetTrain(opt=args)
+ dataloader = DataLoader(RAWDataset, batch_size=args.batch_size, shuffle=True, num_workers=0, drop_last=True)
+
+ print("[INFO] Start to train")
+ step = 0
+ for epoch in range(0, 300):
+ epoch_time = time.time()
+
+ for i_batch, sample_batched in enumerate(dataloader):
+ step_time = time.time()
+
+ input, target_rgb, target_raw = sample_batched['input_raw'].cuda(), sample_batched['target_rgb'].cuda(), \
+ sample_batched['target_raw'].cuda()
+
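+ # Bidirectional training step: the forward pass renders RAW -> RGB
+ # (supervised by rgb_loss); the rendered RGB is then passed through the
+ # differentiable JPEG layer and inverted with rev=True back to RAW, and
+ # raw_loss enforces that RAW content survives the RGB + JPEG round trip.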
+ reconstruct_rgb = net(input)
+ reconstruct_rgb = torch.clamp(reconstruct_rgb, 0, 1)
+ rgb_loss = F.l1_loss(reconstruct_rgb, target_rgb)
+ reconstruct_rgb = DiffJPEG(reconstruct_rgb)
+ reconstruct_raw = net(reconstruct_rgb, rev=True)
+ raw_loss = F.l1_loss(reconstruct_raw, target_raw)
+
+ loss = args.rgb_weight * rgb_loss + raw_loss
+
+ optimizer.zero_grad()
+ loss.backward()
+ optimizer.step()
+
+ print("task: %s Epoch: %d Step: %d || loss: %.5f raw_loss: %.5f rgb_loss: %.5f || lr: %f time: %f"%(
+ args.task, epoch, step, loss.detach().cpu().numpy(), raw_loss.detach().cpu().numpy(),
+ rgb_loss.detach().cpu().numpy(), optimizer.param_groups[0]['lr'], time.time()-step_time
+ ))
+ step += 1
+
+ torch.save(net.state_dict(), args.out_path+"%s/checkpoint/latest.pth"%args.task)
+ if (epoch+1) % 10 == 0:
+ # os.makedirs(args.out_path+"%s/checkpoint/%04d"%(args.task,epoch), exist_ok=True)
+ torch.save(net.state_dict(), args.out_path+"%s/checkpoint/%04d.pth"%(args.task,epoch))
+ print("[INFO] Successfully saved "+args.out_path+"%s/checkpoint/%04d.pth"%(args.task,epoch))
+ scheduler.step()
+
+ print("[INFO] Epoch time: ", time.time()-epoch_time, "task: ", args.task)
+
+if __name__ == '__main__':
+
+ torch.set_num_threads(4)
+ main(args)
diff --git a/imcui/third_party/DarkFeat/datasets/InvISP/utils/JPEG.py b/imcui/third_party/DarkFeat/datasets/InvISP/utils/JPEG.py
new file mode 100644
index 0000000000000000000000000000000000000000..8997ee98a41668b4737a9b2acc2341032f173bd3
--- /dev/null
+++ b/imcui/third_party/DarkFeat/datasets/InvISP/utils/JPEG.py
@@ -0,0 +1,43 @@
+
+
+import torch
+import torch.nn as nn
+
+from .JPEG_utils import diff_round, quality_to_factor, Quantization
+from .compression import compress_jpeg
+from .decompression import decompress_jpeg
+
+
+class DiffJPEG(nn.Module):
+ def __init__(self, differentiable=True, quality=75):
+ ''' Initialize the DiffJPEG layer.
+ Inputs:
+ differentiable(bool): If True, use the custom differentiable
+ rounding function; if False, use standard torch.round.
+ quality(float): Quality factor for the JPEG compression scheme.
+ '''

+ super(DiffJPEG, self).__init__()
+ if differentiable:
+ rounding = diff_round
+ # rounding = Quantization()
+ else:
+ rounding = torch.round
+ factor = quality_to_factor(quality)
+ self.compress = compress_jpeg(rounding=rounding, factor=factor)
+ # self.decompress = decompress_jpeg(height, width, rounding=rounding,
+ # factor=factor)
+ self.decompress = decompress_jpeg(rounding=rounding, factor=factor)
+
+ def forward(self, x):
+ ''' Apply differentiable JPEG compression followed by decompression to x
+ (batch x 3 x height x width, values in [0, 1]), returning a tensor of the
+ same shape.
+ '''
+ org_height = x.shape[2]
+ org_width = x.shape[3]
+ y, cb, cr = self.compress(x)
+
+ recovered = self.decompress(y, cb, cr, org_height, org_width)
+ return recovered
+
+
diff --git a/imcui/third_party/DarkFeat/datasets/InvISP/utils/JPEG_utils.py b/imcui/third_party/DarkFeat/datasets/InvISP/utils/JPEG_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..e2ebd9bdc184e869ade58eea1c6763baa1d9fc91
--- /dev/null
+++ b/imcui/third_party/DarkFeat/datasets/InvISP/utils/JPEG_utils.py
@@ -0,0 +1,75 @@
+# Standard libraries
+import numpy as np
+# PyTorch
+import torch
+import torch.nn as nn
+import math
+
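+# The two 8x8 tables below are the standard JPEG luminance (y_table) and
+# chrominance (c_table) quantization tables, stored transposed (.T) relative
+# to their usual row-major presentation.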
+y_table = np.array(
+ [[16, 11, 10, 16, 24, 40, 51, 61], [12, 12, 14, 19, 26, 58, 60,
+ 55], [14, 13, 16, 24, 40, 57, 69, 56],
+ [14, 17, 22, 29, 51, 87, 80, 62], [18, 22, 37, 56, 68, 109, 103,
+ 77], [24, 35, 55, 64, 81, 104, 113, 92],
+ [49, 64, 78, 87, 103, 121, 120, 101], [72, 92, 95, 98, 112, 100, 103, 99]],
+ dtype=np.float32).T
+
+y_table = nn.Parameter(torch.from_numpy(y_table))
+#
+c_table = np.empty((8, 8), dtype=np.float32)
+c_table.fill(99)
+c_table[:4, :4] = np.array([[17, 18, 24, 47], [18, 21, 26, 66],
+ [24, 26, 56, 99], [47, 66, 99, 99]]).T
+c_table = nn.Parameter(torch.from_numpy(c_table))
+
+
+def diff_round_back(x):
+ """ Differentiable rounding function
+ Input:
+ x(tensor)
+ Output:
+ x(tensor)
+ """
+ return torch.round(x) + (x - torch.round(x))**3
+
+
+
+def diff_round(input_tensor):
+ test = 0
+ for n in range(1, 10):
+ test += math.pow(-1, n+1) / n * torch.sin(2 * math.pi * n * input_tensor)
+ final_tensor = input_tensor - 1 / math.pi * test
+ return final_tensor
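+
+# diff_round approximates rounding with a truncated Fourier series:
+#   round(x) ~= x - (1/pi) * sum_{n=1..9} [ (-1)^(n+1) / n * sin(2*pi*n*x) ],
+# which keeps the quantization step smooth so gradients can flow through the
+# simulated JPEG compression during training.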
+
+
+class Quant(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, input):
+ input = torch.clamp(input, 0, 1)
+ output = (input * 255.).round() / 255.
+ return output
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ return grad_output
+
+class Quantization(nn.Module):
+ def __init__(self):
+ super(Quantization, self).__init__()
+
+ def forward(self, input):
+ return Quant.apply(input)
+
+
+def quality_to_factor(quality):
+ """ Calculate factor corresponding to quality
+ Input:
+ quality(float): Quality for jpeg compression
+ Output:
+ factor(float): Compression factor
+ """
+ if quality < 50:
+ quality = 5000. / quality
+ else:
+ quality = 200. - quality*2
+ return quality / 100.
\ No newline at end of file
diff --git a/imcui/third_party/DarkFeat/datasets/InvISP/utils/__init__.py b/imcui/third_party/DarkFeat/datasets/InvISP/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/imcui/third_party/DarkFeat/datasets/InvISP/utils/commons.py b/imcui/third_party/DarkFeat/datasets/InvISP/utils/commons.py
new file mode 100644
index 0000000000000000000000000000000000000000..e594e0597bac601edc2015d9cae670799f981495
--- /dev/null
+++ b/imcui/third_party/DarkFeat/datasets/InvISP/utils/commons.py
@@ -0,0 +1,23 @@
+import numpy as np
+
+
+def denorm(img, max_value):
+ img = img * float(max_value)
+ return img
+
+def preprocess_test_patch(input_image, target_image, gt_image):
+ input_patch_list = []
+ target_patch_list = []
+ gt_patch_list = []
+ H = input_image.shape[2]
+ W = input_image.shape[3]
+ for i in range(3):
+ for j in range(3):
+ input_patch = input_image[:,:,int(i * H / 3):int((i+1) * H / 3),int(j * W / 3):int((j+1) * W / 3)]
+ target_patch = target_image[:,:,int(i * H / 3):int((i+1) * H / 3),int(j * W / 3):int((j+1) * W / 3)]
+ gt_patch = gt_image[:,:,int(i * H / 3):int((i+1) * H / 3),int(j * W / 3):int((j+1) * W / 3)]
+ input_patch_list.append(input_patch)
+ target_patch_list.append(target_patch)
+ gt_patch_list.append(gt_patch)
+
+ return input_patch_list, target_patch_list, gt_patch_list
diff --git a/imcui/third_party/DarkFeat/datasets/InvISP/utils/compression.py b/imcui/third_party/DarkFeat/datasets/InvISP/utils/compression.py
new file mode 100644
index 0000000000000000000000000000000000000000..3ae22f8839517bfd7e3c774528943e8fff59dce7
--- /dev/null
+++ b/imcui/third_party/DarkFeat/datasets/InvISP/utils/compression.py
@@ -0,0 +1,185 @@
+# Standard libraries
+import itertools
+import numpy as np
+# PyTorch
+import torch
+import torch.nn as nn
+# Local
+from . import JPEG_utils
+
+
+class rgb_to_ycbcr_jpeg(nn.Module):
+ """ Converts RGB image to YCbCr
+ Input:
+ image(tensor): batch x 3 x height x width
+ Output:
+ result(tensor): batch x height x width x 3
+ """
+ def __init__(self):
+ super(rgb_to_ycbcr_jpeg, self).__init__()
+ matrix = np.array(
+ [[0.299, 0.587, 0.114], [-0.168736, -0.331264, 0.5],
+ [0.5, -0.418688, -0.081312]], dtype=np.float32).T
+ self.shift = nn.Parameter(torch.tensor([0., 128., 128.]))
+ #
+ self.matrix = nn.Parameter(torch.from_numpy(matrix))
+
+ def forward(self, image):
+ image = image.permute(0, 2, 3, 1)
+ result = torch.tensordot(image, self.matrix, dims=1) + self.shift
+ # result = torch.from_numpy(result)
+ result.view(image.shape)
+ return result
+
+
+
+class chroma_subsampling(nn.Module):
+ """ Chroma subsampling on CbCv channels
+ Input:
+ image(tensor): batch x height x width x 3
+ Output:
+ y(tensor): batch x height x width
+ cb(tensor): batch x height/2 x width/2
+ cr(tensor): batch x height/2 x width/2
+ """
+ def __init__(self):
+ super(chroma_subsampling, self).__init__()
+
+ def forward(self, image):
+ image_2 = image.permute(0, 3, 1, 2).clone()
+ avg_pool = nn.AvgPool2d(kernel_size=2, stride=(2, 2),
+ count_include_pad=False)
+ cb = avg_pool(image_2[:, 1, :, :].unsqueeze(1))
+ cr = avg_pool(image_2[:, 2, :, :].unsqueeze(1))
+ cb = cb.permute(0, 2, 3, 1)
+ cr = cr.permute(0, 2, 3, 1)
+ return image[:, :, :, 0], cb.squeeze(3), cr.squeeze(3)
+
+
+class block_splitting(nn.Module):
+ """ Splitting image into patches
+ Input:
+ image(tensor): batch x height x width
+ Output:
+ patch(tensor): batch x h*w/64 x 8 x 8
+ """
+ def __init__(self):
+ super(block_splitting, self).__init__()
+ self.k = 8
+
+ def forward(self, image):
+ height, width = image.shape[1:3]
+ # print(height, width)
+ batch_size = image.shape[0]
+ # print(image.shape)
+ image_reshaped = image.view(batch_size, height // self.k, self.k, -1, self.k)
+ image_transposed = image_reshaped.permute(0, 1, 3, 2, 4)
+ return image_transposed.contiguous().view(batch_size, -1, self.k, self.k)
+
+
+class dct_8x8(nn.Module):
+ """ Discrete Cosine Transformation
+ Input:
+ image(tensor): batch x height x width
+ Output:
+ dcp(tensor): batch x height x width
+ """
+ def __init__(self):
+ super(dct_8x8, self).__init__()
+ tensor = np.zeros((8, 8, 8, 8), dtype=np.float32)
+ for x, y, u, v in itertools.product(range(8), repeat=4):
+ tensor[x, y, u, v] = np.cos((2 * x + 1) * u * np.pi / 16) * np.cos(
+ (2 * y + 1) * v * np.pi / 16)
+ alpha = np.array([1. / np.sqrt(2)] + [1] * 7)
+ #
+ self.tensor = nn.Parameter(torch.from_numpy(tensor).float())
+ self.scale = nn.Parameter(torch.from_numpy(np.outer(alpha, alpha) * 0.25).float() )
+
+ def forward(self, image):
+ image = image - 128
+ result = self.scale * torch.tensordot(image, self.tensor, dims=2)
+ result.view(image.shape)
+ return result
+
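+# Note on dct_8x8: `tensor` holds the full 8x8x8x8 DCT-II basis, so the single
+# tensordot over the last two dimensions applies a 2-D DCT to every 8x8 block
+# at once, and `scale` carries the alpha(u) * alpha(v) / 4 normalization.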
+
+class y_quantize(nn.Module):
+ """ JPEG Quantization for Y channel
+ Input:
+ image(tensor): batch x height x width
+ rounding(function): rounding function to use
+ factor(float): Degree of compression
+ Output:
+ image(tensor): batch x height x width
+ """
+ def __init__(self, rounding, factor=1):
+ super(y_quantize, self).__init__()
+ self.rounding = rounding
+ self.factor = factor
+ self.y_table = JPEG_utils.y_table
+
+ def forward(self, image):
+ image = image.float() / (self.y_table * self.factor)
+ image = self.rounding(image)
+ return image
+
+
+class c_quantize(nn.Module):
+ """ JPEG Quantization for CrCb channels
+ Input:
+ image(tensor): batch x height x width
+ rounding(function): rounding function to use
+ factor(float): Degree of compression
+ Output:
+ image(tensor): batch x height x width
+ """
+ def __init__(self, rounding, factor=1):
+ super(c_quantize, self).__init__()
+ self.rounding = rounding
+ self.factor = factor
+ self.c_table = JPEG_utils.c_table
+
+ def forward(self, image):
+ image = image.float() / (self.c_table * self.factor)
+ image = self.rounding(image)
+ return image
+
+
+class compress_jpeg(nn.Module):
+ """ Full JPEG compression algortihm
+ Input:
+ imgs(tensor): batch x 3 x height x width
+ rounding(function): rounding function to use
+ factor(float): Compression factor
+ Output:
+ compressed(dict(tensor)): batch x h*w/64 x 8 x 8
+ """
+ def __init__(self, rounding=torch.round, factor=1):
+ super(compress_jpeg, self).__init__()
+ self.l1 = nn.Sequential(
+ rgb_to_ycbcr_jpeg(),
+ # comment this line if no subsampling
+ chroma_subsampling()
+ )
+ self.l2 = nn.Sequential(
+ block_splitting(),
+ dct_8x8()
+ )
+ self.c_quantize = c_quantize(rounding=rounding, factor=factor)
+ self.y_quantize = y_quantize(rounding=rounding, factor=factor)
+
+ def forward(self, image):
+ y, cb, cr = self.l1(image*255) # modify
+
+ # y, cb, cr = result[:,:,:,0], result[:,:,:,1], result[:,:,:,2]
+ components = {'y': y, 'cb': cb, 'cr': cr}
+ for k in components.keys():
+ comp = self.l2(components[k])
+ # print(comp.shape)
+ if k in ('cb', 'cr'):
+ comp = self.c_quantize(comp)
+ else:
+ comp = self.y_quantize(comp)
+
+ components[k] = comp
+
+ return components['y'], components['cb'], components['cr']
\ No newline at end of file
diff --git a/imcui/third_party/DarkFeat/datasets/InvISP/utils/decompression.py b/imcui/third_party/DarkFeat/datasets/InvISP/utils/decompression.py
new file mode 100644
index 0000000000000000000000000000000000000000..b73ff96d5f6818e1d0464b9c4133f559a3b23fba
--- /dev/null
+++ b/imcui/third_party/DarkFeat/datasets/InvISP/utils/decompression.py
@@ -0,0 +1,190 @@
+# Standard libraries
+import itertools
+import numpy as np
+# PyTorch
+import torch
+import torch.nn as nn
+# Local
+from . import JPEG_utils as utils
+
+
+class y_dequantize(nn.Module):
+ """ Dequantize Y channel
+ Inputs:
+ image(tensor): batch x height x width
+ factor(float): compression factor
+ Outputs:
+ image(tensor): batch x height x width
+ """
+ def __init__(self, factor=1):
+ super(y_dequantize, self).__init__()
+ self.y_table = utils.y_table
+ self.factor = factor
+
+ def forward(self, image):
+ return image * (self.y_table * self.factor)
+
+
+class c_dequantize(nn.Module):
+ """ Dequantize CbCr channel
+ Inputs:
+ image(tensor): batch x height x width
+ factor(float): compression factor
+ Outputs:
+ image(tensor): batch x height x width
+ """
+ def __init__(self, factor=1):
+ super(c_dequantize, self).__init__()
+ self.factor = factor
+ self.c_table = utils.c_table
+
+ def forward(self, image):
+ return image * (self.c_table * self.factor)
+
+
+class idct_8x8(nn.Module):
+ """ Inverse discrete Cosine Transformation
+ Input:
+ dcp(tensor): batch x height x width
+ Output:
+ image(tensor): batch x height x width
+ """
+ def __init__(self):
+ super(idct_8x8, self).__init__()
+ alpha = np.array([1. / np.sqrt(2)] + [1] * 7)
+ self.alpha = nn.Parameter(torch.from_numpy(np.outer(alpha, alpha)).float())
+ tensor = np.zeros((8, 8, 8, 8), dtype=np.float32)
+ for x, y, u, v in itertools.product(range(8), repeat=4):
+ tensor[x, y, u, v] = np.cos((2 * u + 1) * x * np.pi / 16) * np.cos(
+ (2 * v + 1) * y * np.pi / 16)
+ self.tensor = nn.Parameter(torch.from_numpy(tensor).float())
+
+ def forward(self, image):
+
+ image = image * self.alpha
+ result = 0.25 * torch.tensordot(image, self.tensor, dims=2) + 128
+ result.view(image.shape)
+ return result
+
+
+class block_merging(nn.Module):
+ """ Merge pathces into image
+ Inputs:
+ patches(tensor) batch x height*width/64, height x width
+ height(int)
+ width(int)
+ Output:
+ image(tensor): batch x height x width
+ """
+ def __init__(self):
+ super(block_merging, self).__init__()
+
+ def forward(self, patches, height, width):
+ k = 8
+ batch_size = patches.shape[0]
+ # print(patches.shape) # (1,1024,8,8)
+ image_reshaped = patches.view(batch_size, height//k, width//k, k, k)
+ image_transposed = image_reshaped.permute(0, 1, 3, 2, 4)
+ return image_transposed.contiguous().view(batch_size, height, width)
+
+
+class chroma_upsampling(nn.Module):
+ """ Upsample chroma layers
+ Input:
+ y(tensor): y channel image
+ cb(tensor): cb channel
+ cr(tensor): cr channel
+ Output:
+ image(tensor): batch x height x width x 3
+ """
+ def __init__(self):
+ super(chroma_upsampling, self).__init__()
+
+ def forward(self, y, cb, cr):
+ def repeat(x, k=2):
+ height, width = x.shape[1:3]
+ x = x.unsqueeze(-1)
+ x = x.repeat(1, 1, k, k)
+ x = x.view(-1, height * k, width * k)
+ return x
+
+ cb = repeat(cb)
+ cr = repeat(cr)
+
+ return torch.cat([y.unsqueeze(3), cb.unsqueeze(3), cr.unsqueeze(3)], dim=3)
+
+
+class ycbcr_to_rgb_jpeg(nn.Module):
+ """ Converts YCbCr image to RGB JPEG
+ Input:
+ image(tensor): batch x height x width x 3
+ Output:
+ result(tensor): batch x 3 x height x width
+ """
+ def __init__(self):
+ super(ycbcr_to_rgb_jpeg, self).__init__()
+
+ matrix = np.array(
+ [[1., 0., 1.402], [1, -0.344136, -0.714136], [1, 1.772, 0]],
+ dtype=np.float32).T
+ self.shift = nn.Parameter(torch.tensor([0, -128., -128.]))
+ self.matrix = nn.Parameter(torch.from_numpy(matrix))
+
+ def forward(self, image):
+ result = torch.tensordot(image + self.shift, self.matrix, dims=1)
+ #result = torch.from_numpy(result)
+ result.view(image.shape)
+ return result.permute(0, 3, 1, 2)
+
+
+class decompress_jpeg(nn.Module):
+ """ Full JPEG decompression algortihm
+ Input:
+ compressed(dict(tensor)): batch x h*w/64 x 8 x 8
+ rounding(function): rounding function to use
+ factor(float): Compression factor
+ Output:
+ image(tensor): batch x 3 x height x width
+ """
+ # def __init__(self, height, width, rounding=torch.round, factor=1):
+ def __init__(self, rounding=torch.round, factor=1):
+ super(decompress_jpeg, self).__init__()
+ self.c_dequantize = c_dequantize(factor=factor)
+ self.y_dequantize = y_dequantize(factor=factor)
+ self.idct = idct_8x8()
+ self.merging = block_merging()
+ # comment this line if no subsampling
+ self.chroma = chroma_upsampling()
+ self.colors = ycbcr_to_rgb_jpeg()
+
+ # self.height, self.width = height, width
+
+ def forward(self, y, cb, cr, height, width):
+ components = {'y': y, 'cb': cb, 'cr': cr}
+ # height = y.shape[0]
+ # width = y.shape[1]
+ self.height = height
+ self.width = width
+ for k in components.keys():
+ if k in ('cb', 'cr'):
+ comp = self.c_dequantize(components[k])
+ # comment this line if no subsampling
+ height, width = int(self.height/2), int(self.width/2)
+ # height, width = int(self.height), int(self.width)
+
+ else:
+ comp = self.y_dequantize(components[k])
+ # comment this line if no subsampling
+ height, width = self.height, self.width
+ comp = self.idct(comp)
+ components[k] = self.merging(comp, height, width)
+ #
+ # comment this line if no subsampling
+ image = self.chroma(components['y'], components['cb'], components['cr'])
+ # image = torch.cat([components['y'].unsqueeze(3), components['cb'].unsqueeze(3), components['cr'].unsqueeze(3)], dim=3)
+ image = self.colors(image)
+
+ image = torch.min(255*torch.ones_like(image),
+ torch.max(torch.zeros_like(image), image))
+ return image/255
+
diff --git a/imcui/third_party/DarkFeat/datasets/__init__.py b/imcui/third_party/DarkFeat/datasets/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/imcui/third_party/DarkFeat/datasets/gl3d/io.py b/imcui/third_party/DarkFeat/datasets/gl3d/io.py
new file mode 100644
index 0000000000000000000000000000000000000000..9e5b4b0459d6814ef6af17a0a322b59202037d4f
--- /dev/null
+++ b/imcui/third_party/DarkFeat/datasets/gl3d/io.py
@@ -0,0 +1,76 @@
+import os
+import re
+import cv2
+import numpy as np
+
+from ..utils.common import Notify
+
+def read_list(list_path):
+ """Read list."""
+ if list_path is None or not os.path.exists(list_path):
+ print(Notify.FAIL, 'Not exist', list_path, Notify.ENDC)
+ exit(-1)
+ content = open(list_path).read().splitlines()
+ return content
+
+
+def load_pfm(pfm_path):
+ with open(pfm_path, 'rb') as fin:
+ color = None
+ width = None
+ height = None
+ scale = None
+ data_type = None
+ header = str(fin.readline().decode('UTF-8')).rstrip()
+
+ if header == 'PF':
+ color = True
+ elif header == 'Pf':
+ color = False
+ else:
+ raise Exception('Not a PFM file.')
+
+ dim_match = re.match(r'^(\d+)\s(\d+)\s$',
+ fin.readline().decode('UTF-8'))
+ if dim_match:
+ width, height = map(int, dim_match.groups())
+ else:
+ raise Exception('Malformed PFM header.')
+ scale = float((fin.readline().decode('UTF-8')).rstrip())
+ if scale < 0: # little-endian
+ data_type = '<f'
+ else:
+ data_type = '>f' # big-endian
+ data = np.frombuffer(fin.read(), data_type)
+ shape = (height, width, 3) if color else (height, width)
+ data = np.flip(np.reshape(data, shape), 0)
+ return data
+
+
+def _parse_img(img_paths, idx, config):
+ img = cv2.imread(img_paths[idx])
+ if config['resize'] > 0:
+ img = cv2.resize(
+ img, (config['resize'], config['resize']))
+ return img
+
+
+def _parse_depth(depth_paths, idx, config):
+ depth = load_pfm(depth_paths[idx])
+
+ if config['resize'] > 0:
+ target_size = config['resize']
+ if config['input_type'] == 'raw':
+ depth = cv2.resize(depth, (int(target_size/2), int(target_size/2)))
+ else:
+ depth = cv2.resize(depth, (target_size, target_size))
+ return depth
+
+
+def _parse_kpts(kpts_paths, idx, config):
+ kpts = np.load(kpts_paths[idx])['pts']
+ # output: [N, 2] (W first H last)
+ return kpts
diff --git a/imcui/third_party/DarkFeat/datasets/gl3d_dataset.py b/imcui/third_party/DarkFeat/datasets/gl3d_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..db3d2db646ae7fce81424f5f72cdff7e6e34ba60
--- /dev/null
+++ b/imcui/third_party/DarkFeat/datasets/gl3d_dataset.py
@@ -0,0 +1,127 @@
+import os
+import numpy as np
+import torch
+from torch.utils.data import Dataset
+from random import shuffle, seed
+
+from .gl3d.io import read_list, _parse_img, _parse_depth, _parse_kpts
+from .utils.common import Notify
+from .utils.photaug import photaug
+
+
+class GL3DDataset(Dataset):
+ def __init__(self, dataset_dir, config, data_split, is_training):
+ self.dataset_dir = dataset_dir
+ self.config = config
+ self.is_training = is_training
+ self.data_split = data_split
+
+ self.match_set_list, self.global_img_list, \
+ self.global_depth_list = self.prepare_match_sets()
+
+ pass
+
+
+ def __len__(self):
+ return len(self.match_set_list)
+
+
+ def __getitem__(self, idx):
+ match_set_path = self.match_set_list[idx]
+ decoded = np.fromfile(match_set_path, dtype=np.float32)
+
+ idx0, idx1 = int(decoded[0]), int(decoded[1])
+ inlier_num = int(decoded[2])
+ ori_img_size0 = np.reshape(decoded[3:5], (2,))
+ ori_img_size1 = np.reshape(decoded[5:7], (2,))
+ K0 = np.reshape(decoded[7:16], (3, 3))
+ K1 = np.reshape(decoded[16:25], (3, 3))
+ rel_pose = np.reshape(decoded[34:46], (3, 4))
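+ # Record layout of the .match_set file (all float32): [0:2] image indices,
+ # [2] inlier count, [3:7] the two original image sizes, [7:25] the two 3x3
+ # intrinsics, [34:46] the 3x4 relative pose; the remaining fields are unused
+ # here.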
+
+ # parse images.
+ img0 = _parse_img(self.global_img_list, idx0, self.config)
+ img1 = _parse_img(self.global_img_list, idx1, self.config)
+ # parse depths
+ depth0 = _parse_depth(self.global_depth_list, idx0, self.config)
+ depth1 = _parse_depth(self.global_depth_list, idx1, self.config)
+
+ # photometric augmentation
+ img0 = photaug(img0)
+ img1 = photaug(img1)
+
+ return {
+ 'img0': img0 / 255.,
+ 'img1': img1 / 255.,
+ 'depth0': depth0,
+ 'depth1': depth1,
+ 'ori_img_size0': ori_img_size0,
+ 'ori_img_size1': ori_img_size1,
+ 'K0': K0,
+ 'K1': K1,
+ 'rel_pose': rel_pose,
+ 'inlier_num': inlier_num
+ }
+
+
+ def points_to_2D(self, pnts, H, W):
+ labels = np.zeros((H, W))
+ pnts = pnts.astype(int)
+ labels[pnts[:, 1], pnts[:, 0]] = 1
+ return labels
+
+
+ def prepare_match_sets(self, q_diff_thld=3, rot_diff_thld=60):
+ """Get match sets.
+ Args:
+ is_training: Use training imageset or testing imageset.
+ data_split: Data split name.
+ Returns:
+ match_set_list: List of match sets path.
+ global_img_list: List of global image path.
+ global_context_feat_list:
+ """
+ # get necessary lists.
+ gl3d_list_folder = os.path.join(self.dataset_dir, 'list', self.data_split)
+ global_info = read_list(os.path.join(
+ gl3d_list_folder, 'image_index_offset.txt'))
+ global_img_list = [os.path.join(self.dataset_dir, i) for i in read_list(
+ os.path.join(gl3d_list_folder, 'image_list.txt'))]
+ global_depth_list = [os.path.join(self.dataset_dir, i) for i in read_list(
+ os.path.join(gl3d_list_folder, 'depth_list.txt'))]
+
+ imageset_list_name = 'imageset_train.txt' if self.is_training else 'imageset_test.txt'
+ match_set_list = self.get_match_set_list(os.path.join(
+ gl3d_list_folder, imageset_list_name), q_diff_thld, rot_diff_thld)
+ return match_set_list, global_img_list, global_depth_list
+
+
+ def get_match_set_list(self, imageset_list_path, q_diff_thld, rot_diff_thld):
+ """Get the path list of match sets.
+ Args:
+ imageset_list_path: Path to imageset list.
+ q_diff_thld: Threshold of image pair sampling regarding camera orientation.
+ Returns:
+ match_set_list: List of match set path.
+ """
+ imageset_list = [os.path.join(self.dataset_dir, 'data', i)
+ for i in read_list(imageset_list_path)]
+ print(Notify.INFO, 'Use # imageset', len(imageset_list), Notify.ENDC)
+ match_set_list = []
+ # discard image pairs whose image similarity is beyond the threshold.
+ for i in imageset_list:
+ match_set_folder = os.path.join(i, 'match_sets')
+ if os.path.exists(match_set_folder):
+ match_set_files = os.listdir(match_set_folder)
+ for val in match_set_files:
+ name, ext = os.path.splitext(val)
+ if ext == '.match_set':
+ splits = name.split('_')
+ q_diff = int(splits[2])
+ rot_diff = int(splits[3])
+ if q_diff >= q_diff_thld and rot_diff <= rot_diff_thld:
+ match_set_list.append(
+ os.path.join(match_set_folder, val))
+
+ print(Notify.INFO, 'Get # match sets', len(match_set_list), Notify.ENDC)
+ return match_set_list
+
diff --git a/imcui/third_party/DarkFeat/datasets/noise.py b/imcui/third_party/DarkFeat/datasets/noise.py
new file mode 100644
index 0000000000000000000000000000000000000000..aa68c98183186e9e9185e78e1a3e7335ac8d5bb1
--- /dev/null
+++ b/imcui/third_party/DarkFeat/datasets/noise.py
@@ -0,0 +1,82 @@
+import numpy as np
+import random
+from scipy.stats import tukeylambda
+
+camera_params = {
+ 'Kmin': 0.2181895124454343,
+ 'Kmax': 3.0,
+ 'G_shape': np.array([0.15714286, 0.14285714, 0.08571429, 0.08571429, 0.2 ,
+ 0.2 , 0.1 , 0.08571429, 0.05714286, 0.07142857,
+ 0.02857143, 0.02857143, 0.01428571, 0.02857143, 0.08571429,
+ 0.07142857, 0.11428571, 0.11428571]),
+ 'Profile-1': {
+ 'R_scale': {
+ 'slope': 0.4712797750747537,
+ 'bias': -0.8078958947116487,
+ 'sigma': 0.2436176299944695
+ },
+ 'g_scale': {
+ 'slope': 0.6771267783987617,
+ 'bias': 1.5121876510805845,
+ 'sigma': 0.24641096601611254
+ },
+ 'G_scale': {
+ 'slope': 0.6558756156508007,
+ 'bias': 1.09268679594838,
+ 'sigma': 0.28604721742277756
+ }
+ },
+ 'black_level': 2048,
+ 'max_value': 16383
+}
+
+
+# photon shot noise
+def addPStarNoise(img, K):
+ return np.random.poisson(img / K).astype(np.float32) * K
+
+
+# read noise
+# tukey lambda distribution
+def addGStarNoise(img, K, G_shape, G_scale_param):
+ # sample a shape parameter [lambda] from histogram of samples
+ a, b = np.histogram(G_shape, bins=10, range=(-0.25, 0.25))
+ a, b = np.array(a), np.array(b)
+ a = a / a.sum()
+
+ rand_num = random.uniform(0, 1)
+ idx = np.sum(np.cumsum(a) < rand_num)
+ lam = random.uniform(b[idx], b[idx+1])
+
+ # calculate scale parameter [G_scale]
+ log_K = np.log(K)
+ log_G_scale = np.random.standard_normal() * G_scale_param['sigma'] * 1 +\
+ G_scale_param['slope'] * log_K + G_scale_param['bias']
+ G_scale = np.exp(log_G_scale)
+ # print(f'G_scale: {G_scale}')
+
+ return img + tukeylambda.rvs(lam, scale=G_scale, size=img.shape).astype(np.float32)
+
+
+# row noise
+# uniform distribution for each row
+def addRowNoise(img, K, R_scale_param):
+ # calculate scale parameter [R_scale]
+ log_K = np.log(K)
+ log_R_scale = np.random.standard_normal() * R_scale_param['sigma'] * 1 +\
+ R_scale_param['slope'] * log_K + R_scale_param['bias']
+ R_scale = np.exp(log_R_scale)
+ # print(f'R_scale: {R_scale}')
+
+ row_noise = np.random.randn(img.shape[0], 1).astype(np.float32) * R_scale
+ return img + np.tile(row_noise, (1, img.shape[1]))
+
+
+# quantization noise
+# uniform distribution
+def addQuantNoise(img, q):
+ return img + np.random.uniform(low=-0.5*q, high=0.5*q, size=img.shape)
+
+
+def sampleK(Kmin, Kmax):
+ return np.exp(np.random.uniform(low=np.log(Kmin), high=np.log(Kmax)))
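+
+
+# A minimal sketch of how these pieces compose into the full noise model
+# (this mirrors NoiseSimulator.raw2noisyRaw in noise_simulator.py):
+#   K = sampleK(camera_params['Kmin'], camera_params['Kmax'])
+#   q = 1 / (camera_params['max_value'] - camera_params['black_level'])
+#   raw = addPStarNoise(raw, K)
+#   raw = addGStarNoise(raw, K, camera_params['G_shape'],
+#                       camera_params['Profile-1']['G_scale'])
+#   raw = addRowNoise(raw, K, camera_params['Profile-1']['R_scale'])
+#   raw = addQuantNoise(raw, q)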
diff --git a/imcui/third_party/DarkFeat/datasets/noise_simulator.py b/imcui/third_party/DarkFeat/datasets/noise_simulator.py
new file mode 100644
index 0000000000000000000000000000000000000000..17e21d3b3443aaa3585ae8460709f60b05835a84
--- /dev/null
+++ b/imcui/third_party/DarkFeat/datasets/noise_simulator.py
@@ -0,0 +1,244 @@
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.autograd import Variable
+import torch
+import numpy as np
+import os, time, random
+import argparse
+from torch.utils.data import Dataset, DataLoader
+from PIL import Image as PILImage
+from glob import glob
+from tqdm import tqdm
+import rawpy
+import colour_demosaicing
+
+from .InvISP.model.model import InvISPNet
+from .utils.common import Notify
+from datasets.noise import camera_params, addGStarNoise, addPStarNoise, addQuantNoise, addRowNoise, sampleK
+
+
+class NoiseSimulator:
+ def __init__(self, device, ckpt_path='./datasets/InvISP/pretrained/canon.pth'):
+ self.device = device
+
+ # load Invertible ISP Network
+ self.net = InvISPNet(channel_in=3, channel_out=3, block_num=8).to(self.device).eval()
+ self.net.load_state_dict(torch.load(ckpt_path), strict=False)
+ print(Notify.INFO, "Loaded ISPNet checkpoint: {}".format(ckpt_path), Notify.ENDC)
+
+ # white balance parameters
+ self.wb = np.array([2020.0, 1024.0, 1458.0, 1024.0])
+
+ # use Canon EOS 5D4 noise parameters provided by ELD
+ self.camera_params = camera_params
+
+ # random specify exposure time ratio from 50 to 150
+ self.ratio_min = 50
+ self.ratio_max = 150
+ pass
+
+ # inverse demosaic
+ # input: [H, W, 3]
+ # output: [H, W]
+ def invDemosaic(self, img):
+ img_R = img[::2, ::2, 0]
+ img_G1 = img[::2, 1::2, 1]
+ img_G2 = img[1::2, ::2, 1]
+ img_B = img[1::2, 1::2, 2]
+ raw_img = np.ones(img.shape[:2])
+ raw_img[::2, ::2] = img_R
+ raw_img[::2, 1::2] = img_G1
+ raw_img[1::2, ::2] = img_G2
+ raw_img[1::2, 1::2] = img_B
+ return raw_img
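+
+ # (Both the inverse and forward demosaic helpers here assume an RGGB Bayer
+ # layout: R at even rows/cols, G at (even, odd) and (odd, even), B at odd
+ # rows/cols -- the same 'RGGB' pattern passed to colour_demosaicing below.)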
+
+ # demosaic - nearest ver
+ # input: [H, W]
+ # output: [H, W, 3]
+ def demosaicNearest(self, img):
+ raw = np.ones((img.shape[0], img.shape[1], 3))
+ raw[::2, ::2, 0] = img[::2, ::2]
+ raw[::2, 1::2, 0] = img[::2, ::2]
+ raw[1::2, ::2, 0] = img[::2, ::2]
+ raw[1::2, 1::2, 0] = img[::2, ::2]
+ raw[::2, ::2, 2] = img[1::2, 1::2]
+ raw[::2, 1::2, 2] = img[1::2, 1::2]
+ raw[1::2, ::2, 2] = img[1::2, 1::2]
+ raw[1::2, 1::2, 2] = img[1::2, 1::2]
+ raw[::2, ::2, 1] = img[::2, 1::2]
+ raw[::2, 1::2, 1] = img[::2, 1::2]
+ raw[1::2, ::2, 1] = img[1::2, ::2]
+ raw[1::2, 1::2, 1] = img[1::2, ::2]
+ return raw
+
+ # demosaic
+ # input: [H, W]
+ # output: [H, W, 3]
+ def demosaic(self, img):
+ return colour_demosaicing.demosaicing_CFA_Bayer_bilinear(img, 'RGGB')
+
+ # load rgb image
+ def path2rgb(self, path):
+ return torch.from_numpy(np.array(PILImage.open(path))/255.0)
+
+ # InvISP
+ # input: rgb image [H, W, 3]
+ # output: raw image [H, W]
+ def rgb2raw(self, rgb, batched=False):
+ # 1. rgb -> invnet
+ if not batched:
+ rgb = rgb.unsqueeze(0)
+
+ rgb = rgb.permute(0,3,1,2).float().to(self.device)
+ with torch.no_grad():
+ reconstruct_raw = self.net(rgb, rev=True)
+
+ pred_raw = reconstruct_raw.detach().permute(0,2,3,1)
+ pred_raw = torch.clamp(pred_raw, 0, 1)
+
+ if not batched:
+ pred_raw = pred_raw[0, ...]
+
+ pred_raw = pred_raw.cpu().numpy()
+
+ # 2. -> inv gamma
+ norm_value = np.power(16383, 1/2.2)
+ pred_raw *= norm_value
+ pred_raw = np.power(pred_raw, 2.2)
+
+ # 3. -> inv white balance
+ wb = self.wb / self.wb.max()
+ pred_raw = pred_raw / wb[:-1]
+
+ # 4. -> add black level
+ pred_raw += self.camera_params['black_level']
+
+ # 5. -> inv demosaic
+ if not batched:
+ pred_raw = self.invDemosaic(pred_raw)
+ else:
+ preds = []
+ for i in range(pred_raw.shape[0]):
+ preds.append(self.invDemosaic(pred_raw[i]))
+ pred_raw = np.stack(preds, axis=0)
+
+ return pred_raw
+
+
+ def raw2noisyRaw(self, raw, ratio_dec=1, batched=False):
+ if not batched:
+ ratio = (random.uniform(self.ratio_min, self.ratio_max) - 1) * ratio_dec + 1
+ raw = raw.copy() / ratio
+
+ K = sampleK(self.camera_params['Kmin'], self.camera_params['Kmax'])
+ q = 1 / (self.camera_params['max_value'] - self.camera_params['black_level'])
+
+ raw = addPStarNoise(raw, K)
+ raw = addGStarNoise(raw, K, self.camera_params['G_shape'], self.camera_params['Profile-1']['G_scale'])
+ raw = addRowNoise(raw, K, self.camera_params['Profile-1']['R_scale'])
+ raw = addQuantNoise(raw, q)
+ raw *= ratio
+ return raw
+
+ else:
+ raw = raw.copy()
+ for i in range(raw.shape[0]):
+ ratio = random.uniform(self.ratio_min, self.ratio_max)
+ raw[i] /= ratio
+
+ K = sampleK(self.camera_params['Kmin'], self.camera_params['Kmax'])
+ q = 1 / (self.camera_params['max_value'] - self.camera_params['black_level'])
+
+ raw[i] = addPStarNoise(raw[i], K)
+ raw[i] = addGStarNoise(raw[i], K, self.camera_params['G_shape'], self.camera_params['Profile-1']['G_scale'])
+ raw[i] = addRowNoise(raw[i], K, self.camera_params['Profile-1']['R_scale'])
+ raw[i] = addQuantNoise(raw[i], q)
+ raw[i] *= ratio
+ return raw
+
+ def raw2rgb(self, raw, batched=False):
+ # 1. -> demosaic
+ if not batched:
+ raw = self.demosaic(raw)
+ else:
+ raws = []
+ for i in range(raw.shape[0]):
+ raws.append(self.demosaic(raw[i]))
+ raw = np.stack(raws, axis=0)
+
+ # 2. -> subtract black level
+ raw -= self.camera_params['black_level']
+ raw = np.clip(raw, 0, self.camera_params['max_value'] - self.camera_params['black_level'])
+
+ # 3. -> white balance
+ wb = self.wb / self.wb.max()
+ raw = raw * wb[:-1]
+
+ # 4. -> gamma
+ norm_value = np.power(16383, 1/2.2)
+ raw = np.power(raw, 1/2.2)
+ raw /= norm_value
+
+ # 5. -> ispnet
+ if not batched:
+ input_raw_img = torch.Tensor(raw).permute(2,0,1).float().to(self.device)[np.newaxis, ...]
+ else:
+ input_raw_img = torch.Tensor(raw).permute(0,3,1,2).float().to(self.device)
+
+ with torch.no_grad():
+ reconstruct_rgb = self.net(input_raw_img)
+ reconstruct_rgb = torch.clamp(reconstruct_rgb, 0, 1)
+
+ pred_rgb = reconstruct_rgb.detach().permute(0,2,3,1)
+
+ if not batched:
+ pred_rgb = pred_rgb[0, ...]
+ pred_rgb = pred_rgb.cpu().numpy()
+
+ return pred_rgb
+
+
+ def raw2packedRaw(self, raw, batched=False):
+ # 1. -> subtract black level
+ raw -= self.camera_params['black_level']
+ raw = np.clip(raw, 0, self.camera_params['max_value'] - self.camera_params['black_level'])
+ raw /= self.camera_params['max_value']
+
+ # 2. pack
+ if not batched:
+ im = np.expand_dims(raw, axis=2)
+ img_shape = im.shape
+ H = img_shape[0]
+ W = img_shape[1]
+
+ out = np.concatenate((im[0:H:2, 0:W:2, :],
+ im[0:H:2, 1:W:2, :],
+ im[1:H:2, 1:W:2, :],
+ im[1:H:2, 0:W:2, :]), axis=2)
+ else:
+ im = np.expand_dims(raw, axis=3)
+ img_shape = im.shape
+ H = img_shape[1]
+ W = img_shape[2]
+
+ out = np.concatenate((im[:, 0:H:2, 0:W:2, :],
+ im[:, 0:H:2, 1:W:2, :],
+ im[:, 1:H:2, 1:W:2, :],
+ im[:, 1:H:2, 0:W:2, :]), axis=3)
+ return out
+
+ def raw2demosaicRaw(self, raw, batched=False):
+ # 1. -> demosaic
+ if not batched:
+ raw = self.demosaic(raw)
+ else:
+ raws = []
+ for i in range(raw.shape[0]):
+ raws.append(self.demosaic(raw[i]))
+ raw = np.stack(raws, axis=0)
+
+ # 2. -> subtract black level
+ raw -= self.camera_params['black_level']
+ raw = np.clip(raw, 0, self.camera_params['max_value'] - self.camera_params['black_level'])
+ raw /= self.camera_params['max_value']
+ return raw
diff --git a/imcui/third_party/DarkFeat/datasets/utils/common.py b/imcui/third_party/DarkFeat/datasets/utils/common.py
new file mode 100644
index 0000000000000000000000000000000000000000..6433408a39e53fcedb634901268754ed1ba971b3
--- /dev/null
+++ b/imcui/third_party/DarkFeat/datasets/utils/common.py
@@ -0,0 +1,58 @@
+#!/usr/bin/env python
+"""
+Copyright 2017, Zixin Luo, HKUST.
+Commonly used functions
+"""
+
+from __future__ import print_function
+
+import os
+from datetime import datetime
+
+
+class ClassProperty(property):
+ """For dynamically obtaining system time"""
+
+ def __get__(self, cls, owner):
+ return classmethod(self.fget).__get__(None, owner)()
+
+
+class Notify(object):
+ """Colorful printing prefix.
+ A quick example:
+ print(Notify.INFO, YOUR TEXT, Notify.ENDC)
+ """
+
+ def __init__(self):
+ pass
+
+ @ClassProperty
+ def HEADER(cls):
+ return str(datetime.now()) + ': \033[95m'
+
+ @ClassProperty
+ def INFO(cls):
+ return str(datetime.now()) + ': \033[92mI'
+
+ @ClassProperty
+ def OKBLUE(cls):
+ return str(datetime.now()) + ': \033[94m'
+
+ @ClassProperty
+ def WARNING(cls):
+ return str(datetime.now()) + ': \033[93mW'
+
+ @ClassProperty
+ def FAIL(cls):
+ return str(datetime.now()) + ': \033[91mF'
+
+ @ClassProperty
+ def BOLD(cls):
+ return str(datetime.now()) + ': \033[1mB'
+
+ @ClassProperty
+ def UNDERLINE(cls):
+ return str(datetime.now()) + ': \033[4mU'
+ ENDC = '\033[0m'
+
+
diff --git a/imcui/third_party/DarkFeat/datasets/utils/photaug.py b/imcui/third_party/DarkFeat/datasets/utils/photaug.py
new file mode 100644
index 0000000000000000000000000000000000000000..41f2278c720355470f00a881a1516cf1b71d2c4a
--- /dev/null
+++ b/imcui/third_party/DarkFeat/datasets/utils/photaug.py
@@ -0,0 +1,50 @@
+import cv2
+import numpy as np
+import random
+
+
+def random_brightness_np(image, max_abs_change=50):
+ delta = random.uniform(-max_abs_change, max_abs_change)
+ return np.clip(image + delta, 0, 255)
+
+def random_contrast_np(image, strength_range=[0.3, 1.5]):
+ delta = random.uniform(*strength_range)
+ mean = image.mean()
+ return np.clip((image - mean) * delta + mean, 0, 255)
+
+def motion_blur_np(img, max_kernel_size=3):
+ # Either vertical, horizontal or diagonal blur
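+ # The selected line kernel is then modulated by an isotropic Gaussian and
+ # renormalized, so the blur energy tapers toward the kernel edges.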
+ mode = np.random.choice(['h', 'v', 'diag_down', 'diag_up'])
+ ksize = np.random.randint(
+ 0, (max_kernel_size+1)/2)*2 + 1 # make sure is odd
+ center = int((ksize-1)/2)
+ kernel = np.zeros((ksize, ksize))
+ if mode == 'h':
+ kernel[center, :] = 1.
+ elif mode == 'v':
+ kernel[:, center] = 1.
+ elif mode == 'diag_down':
+ kernel = np.eye(ksize)
+ elif mode == 'diag_up':
+ kernel = np.flip(np.eye(ksize), 0)
+ var = ksize * ksize / 16.
+ grid = np.repeat(np.arange(ksize)[:, np.newaxis], ksize, axis=-1)
+ gaussian = np.exp(-(np.square(grid-center) +
+ np.square(grid.T-center))/(2.*var))
+ kernel *= gaussian
+ kernel /= np.sum(kernel)
+ img = cv2.filter2D(img, -1, kernel)
+ return np.clip(img, 0, 255)
+
+def additive_gaussian_noise(image, stddev_range=[5, 95]):
+ stddev = random.uniform(*stddev_range)
+ noise = np.random.normal(size=image.shape, scale=stddev)
+ noisy_image = np.clip(image + noise, 0, 255)
+ return noisy_image
+
+def photaug(img):
+ img = random_brightness_np(img)
+ img = random_contrast_np(img)
+ # img = additive_gaussian_noise(img)
+ img = motion_blur_np(img)
+ return img
diff --git a/imcui/third_party/DarkFeat/demo_darkfeat.py b/imcui/third_party/DarkFeat/demo_darkfeat.py
new file mode 100644
index 0000000000000000000000000000000000000000..ca50ae5b892e7a90e75da7197c33bc0c06e699bf
--- /dev/null
+++ b/imcui/third_party/DarkFeat/demo_darkfeat.py
@@ -0,0 +1,124 @@
+from pathlib import Path
+import argparse
+import cv2
+import matplotlib.cm as cm
+import torch
+import numpy as np
+from utils.nnmatching import NNMatching
+from utils.misc import (AverageTimer, VideoStreamer, make_matching_plot_fast, frame2tensor)
+
+torch.set_grad_enabled(False)
+
+
+def compute_essential(matched_kp1, matched_kp2, K):
+ pts1 = cv2.undistortPoints(matched_kp1,cameraMatrix=K, distCoeffs = (-0.117918271740560,0.075246403574314,0,0))
+ pts2 = cv2.undistortPoints(matched_kp2,cameraMatrix=K, distCoeffs = (-0.117918271740560,0.075246403574314,0,0))
+ K_1 = np.eye(3)
+ # Estimate the homography between the matches using RANSAC
+ ransac_model, ransac_inliers = cv2.findEssentialMat(pts1, pts2, K_1, method=cv2.RANSAC, prob=0.999, threshold=0.001, maxIters=10000)
+ if ransac_inliers is None or ransac_model.shape != (3,3):
+ ransac_inliers = np.array([])
+ ransac_model = None
+ return ransac_model, ransac_inliers, pts1, pts2
+
+
+sizer = (960, 640)
+focallength_x = 4.504986436499113e+03/(6744/sizer[0])
+focallength_y = 4.513311442889859e+03/(4502/sizer[1])
+K = np.eye(3)
+K[0,0] = focallength_x
+K[1,1] = focallength_y
+K[0,2] = 3.363322177533149e+03/(6744/sizer[0])# * 0.5
+K[1,2] = 2.291824660547715e+03/(4502/sizer[1])# * 0.5
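+# The intrinsics above are rescaled from the full-resolution sensor
+# (6744 x 4502) to the processed frame size `sizer` (960 x 640), keeping the
+# focal lengths and principal point consistent with the resized images.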
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser(
+ description='DarkFeat demo',
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+ parser.add_argument(
+ '--input', type=str,
+ help='path to an image directory')
+ parser.add_argument(
+ '--output_dir', type=str, default=None,
+ help='Directory where to write output frames (If None, no output)')
+
+ parser.add_argument(
+ '--image_glob', type=str, nargs='+', default=['*.ARW'],
+ help='Glob if a directory of images is specified')
+ parser.add_argument(
+ '--resize', type=int, nargs='+', default=[640, 480],
+ help='Resize the input image before running inference. If two numbers, '
+ 'resize to the exact dimensions, if one number, resize the max '
+ 'dimension, if -1, do not resize')
+ parser.add_argument(
+ '--force_cpu', action='store_true',
+ help='Force pytorch to run in CPU mode.')
+ parser.add_argument('--model_path', type=str,
+ help='Path to the pretrained model')
+
+ opt = parser.parse_args()
+ print(opt)
+
+ assert len(opt.resize) == 2
+ print('Will resize to {}x{} (WxH)'.format(opt.resize[0], opt.resize[1]))
+
+ device = 'cuda' if torch.cuda.is_available() and not opt.force_cpu else 'cpu'
+ print('Running inference on device \"{}\"'.format(device))
+ matching = NNMatching(opt.model_path).eval().to(device)
+ keys = ['keypoints', 'scores', 'descriptors']
+
+ vs = VideoStreamer(opt.input, opt.resize, opt.image_glob)
+ frame, ret = vs.next_frame()
+ assert ret, 'Error when reading the first frame (try different --input?)'
+
+ frame_tensor = frame2tensor(frame, device)
+ last_data = matching.darkfeat({'image': frame_tensor})
+ last_data = {k+'0': [last_data[k]] for k in keys}
+ last_data['image0'] = frame_tensor
+ last_frame = frame
+ last_image_id = 0
+
+ if opt.output_dir is not None:
+ print('==> Will write outputs to {}'.format(opt.output_dir))
+ Path(opt.output_dir).mkdir(exist_ok=True)
+
+ timer = AverageTimer()
+
+ while True:
+ frame, ret = vs.next_frame()
+ if not ret:
+ print('Finished demo_darkfeat.py')
+ break
+ timer.update('data')
+ stem0, stem1 = last_image_id, vs.i - 1
+
+ frame_tensor = frame2tensor(frame, device)
+ pred = matching({**last_data, 'image1': frame_tensor})
+ kpts0 = last_data['keypoints0'][0].cpu().numpy()
+ kpts1 = pred['keypoints1'][0].cpu().numpy()
+ matches = pred['matches0'][0].cpu().numpy()
+ confidence = pred['matching_scores0'][0].cpu().numpy()
+ timer.update('forward')
+
+ valid = matches > -1
+ mkpts0 = kpts0[valid]
+ mkpts1 = kpts1[matches[valid]]
+
+ E, inliers, pts1, pts2 = compute_essential(mkpts0, mkpts1, K)
+ color = cm.jet(np.clip(confidence[valid][inliers[:, 0].astype('bool')] * 2 - 1, -1, 1))
+
+ text = [
+ 'DarkFeat',
+ 'Matches: {}'.format(inliers.sum())
+ ]
+
+ out = make_matching_plot_fast(
+ last_frame, frame, mkpts0[inliers[:, 0].astype('bool')], mkpts1[inliers[:, 0].astype('bool')], color, text,
+ path=None, small_text=' ')
+
+ if opt.output_dir is not None:
+ stem = 'matches_{:06}_{:06}'.format(stem0, stem1)
+ out_file = str(Path(opt.output_dir, stem + '.png'))
+ print('Writing image to {}'.format(out_file))
+ cv2.imwrite(out_file, out)
diff --git a/imcui/third_party/DarkFeat/export_features.py b/imcui/third_party/DarkFeat/export_features.py
new file mode 100644
index 0000000000000000000000000000000000000000..c7caea5e57890948728f84cbb7e68e59d455e171
--- /dev/null
+++ b/imcui/third_party/DarkFeat/export_features.py
@@ -0,0 +1,128 @@
+import argparse
+import glob
+import math
+import subprocess
+import numpy as np
+import os
+import tqdm
+import torch
+import torch.nn as nn
+import cv2
+from darkfeat import DarkFeat
+from utils import matching
+
+def darkfeat_pre(img, cuda):
+ H, W = img.shape[0], img.shape[1]
+ inp = img.copy()
+ inp = inp.transpose(2, 0, 1)
+ inp = torch.from_numpy(inp)
+ inp = torch.autograd.Variable(inp).view(1, 3, H, W)
+ if cuda:
+ inp = inp.cuda()
+ return inp
+
+if __name__ == '__main__':
+ # Parse command line arguments.
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--H', type=int, default=int(640))
+ parser.add_argument('--W', type=int, default=int(960))
+ parser.add_argument('--histeq', action='store_true')
+ parser.add_argument('--model_path', type=str)
+ parser.add_argument('--dataset_dir', type=str, default='/data/hyz/MID/')
+ opt = parser.parse_args()
+
+ sizer = (opt.W, opt.H)
+ focallength_x = 4.504986436499113e+03/(6744/sizer[0])
+ focallength_y = 4.513311442889859e+03/(4502/sizer[1])
+ K = np.eye(3)
+ K[0,0] = focallength_x
+ K[1,1] = focallength_y
+ K[0,2] = 3.363322177533149e+03/(6744/sizer[0])# * 0.5
+ K[1,2] = 2.291824660547715e+03/(4502/sizer[1])# * 0.5
+ Kinv = np.linalg.inv(K)
+ Kinvt = np.transpose(Kinv)
+
+ cuda = True
+ if cuda:
+ darkfeat = DarkFeat(opt.model_path).cuda().eval()
+
+ for scene in ['Indoor', 'Outdoor']:
+ base_save = './result/' + scene + '/'
+ dir_base = opt.dataset_dir + '/' + scene + '/'
+ pair_list = sorted(os.listdir(dir_base))
+
+ for pair in tqdm.tqdm(pair_list):
+ opention = 1
+ if scene == 'Outdoor':
+ pass
+ else:
+ if int(pair[4::]) <= 17:
+ opention = 0
+ else:
+ pass
+ name=[]
+ files = sorted(os.listdir(dir_base+pair))
+ for file_ in files:
+ if file_.endswith('.cr2'):
+ name.append(file_[0:9])
+ ISO = ['00100', '00200', '00400', '00800', '01600', '03200', '06400', '12800']
+ if opention == 1:
+ Shutter_speed = ['0.005','0.01','0.025','0.05','0.17','0.5']
+ else:
+ Shutter_speed = ['0.01','0.02','0.05','0.1','0.3','1']
+
+ E_GT = np.load(dir_base+pair+'/GT_Correspondence/'+'E_estimated.npy')
+ F_GT = np.dot(np.dot(Kinvt,E_GT),Kinv)
+ R_GT = np.load(dir_base+pair+'/GT_Correspondence/'+'R_GT.npy')
+ t_GT = np.load(dir_base+pair+'/GT_Correspondence/'+'T_GT.npy')
+
+ id0, id1 = sorted([ int(i.split('/')[-1]) for i in glob.glob(f'{dir_base+pair}/?????') ])
+
+ cnt = 0
+
+ for iso in ISO:
+ for ex in Shutter_speed:
+ dark_name1 = name[0] + iso+'_'+ex+'_'+scene+'.npy'
+ dark_name2 = name[1] + iso+'_'+ex+'_'+scene+'.npy'
+
+ if not opt.histeq:
+ dst_T1_None = f'{dir_base}{pair}/{id0:05d}-npy-nohisteq/{dark_name1}'
+ dst_T2_None = f'{dir_base}{pair}/{id1:05d}-npy-nohisteq/{dark_name2}'
+
+ img1_orig_None = np.load(dst_T1_None)
+ img2_orig_None = np.load(dst_T2_None)
+
+ dir_save = base_save + pair + '/None/'
+
+ img_input1 = darkfeat_pre(img1_orig_None.astype('float32')/255.0, cuda)
+ img_input2 = darkfeat_pre(img2_orig_None.astype('float32')/255.0, cuda)
+
+ else:
+ dst_T1_histeq = f'{dir_base}{pair}/{id0:05d}-npy/{dark_name1}'
+ dst_T2_histeq = f'{dir_base}{pair}/{id1:05d}-npy/{dark_name2}'
+
+ img1_orig_histeq = np.load(dst_T1_histeq)
+ img2_orig_histeq = np.load(dst_T2_histeq)
+
+ dir_save = base_save + pair + '/HistEQ/'
+
+ img_input1 = darkfeat_pre(img1_orig_histeq.astype('float32')/255.0, cuda)
+ img_input2 = darkfeat_pre(img2_orig_histeq.astype('float32')/255.0, cuda)
+
+ result1 = darkfeat({'image': img_input1})
+ result2 = darkfeat({'image': img_input2})
+
+ mkpts0, mkpts1, _ = matching.match_descriptors(
+ cv2.KeyPoint_convert(result1['keypoints'].detach().cpu().float().numpy()), result1['descriptors'].detach().cpu().numpy(),
+ cv2.KeyPoint_convert(result2['keypoints'].detach().cpu().float().numpy()), result2['descriptors'].detach().cpu().numpy(),
+ ORB=False
+ )
+
+ POINT_1_dir = dir_save+f'DarkFeat/POINT_1/'
+ POINT_2_dir = dir_save+f'DarkFeat/POINT_2/'
+
+ subprocess.check_output(['mkdir', '-p', POINT_1_dir])
+ subprocess.check_output(['mkdir', '-p', POINT_2_dir])
+ np.save(POINT_1_dir+dark_name1[0:-3]+'npy',mkpts0)
+ np.save(POINT_2_dir+dark_name2[0:-3]+'npy',mkpts1)
+
diff --git a/imcui/third_party/DarkFeat/nets/__init__.py b/imcui/third_party/DarkFeat/nets/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/imcui/third_party/DarkFeat/nets/geom.py b/imcui/third_party/DarkFeat/nets/geom.py
new file mode 100644
index 0000000000000000000000000000000000000000..043ca6e8f5917c56defd6aa17c1ff236a431f8c0
--- /dev/null
+++ b/imcui/third_party/DarkFeat/nets/geom.py
@@ -0,0 +1,323 @@
+import time
+import numpy as np
+import torch
+import torch.nn.functional as F
+
+
+def rnd_sample(inputs, n_sample):
+ cur_size = inputs[0].shape[0]
+ rnd_idx = torch.randperm(cur_size)[0:n_sample]
+ outputs = [i[rnd_idx] for i in inputs]
+ return outputs
+
+
+def _grid_positions(h, w, bs):
+ x_rng = torch.arange(0, w.int())
+ y_rng = torch.arange(0, h.int())
+ xv, yv = torch.meshgrid(x_rng, y_rng, indexing='xy')
+ return torch.reshape(
+ torch.stack((yv, xv), axis=-1),
+ (1, -1, 2)
+ ).repeat(bs, 1, 1).float()
+
+
+def getK(ori_img_size, cur_feat_size, K):
+ # WARNING: cur_feat_size's order is [h, w]
+ r = ori_img_size / cur_feat_size[[1, 0]]
+ r_K0 = torch.stack([K[:, 0] / r[:, 0][..., None], K[:, 1] /
+ r[:, 1][..., None], K[:, 2]], axis=1)
+ return r_K0
+
+
+def gather_nd(params, indices):
+ """ The same as tf.gather_nd but batched gather is not supported yet.
+ indices is a k-dimensional integer tensor, best thought of as a (k-1)-dimensional tensor of indices into params, where each element defines a slice of params:
+
+ output[\\(i_0, ..., i_{k-2}\\)] = params[indices[\\(i_0, ..., i_{k-2}\\)]]
+
+ Args:
+ params (Tensor): "n" dimensions. shape: [x_0, x_1, x_2, ..., x_{n-1}]
+ indices (Tensor): "k" dimensions. shape: [y_0,y_2,...,y_{k-2}, m]. m <= n.
+
+ Returns: gathered Tensor.
+ shape [y_0,y_2,...y_{k-2}] + params.shape[m:]
+
+ """
+ orig_shape = list(indices.shape)
+ num_samples = np.prod(orig_shape[:-1])
+ m = orig_shape[-1]
+ n = len(params.shape)
+
+ if m <= n:
+ out_shape = orig_shape[:-1] + list(params.shape)[m:]
+ else:
+ raise ValueError(
+ f'the last dimension of indices must be less than or equal to the rank of params. Got indices:{indices.shape}, params:{params.shape}. {m} > {n}'
+ )
+
+ indices = indices.reshape((num_samples, m)).transpose(0, 1).tolist()
+ output = params[indices] # (num_samples, ...)
+ return output.reshape(out_shape).contiguous()
+
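+# Example: for params of shape [H, W, C] and indices of shape [N, 2],
+# gather_nd(params, indices) returns an [N, C] tensor with
+# out[n] == params[indices[n, 0], indices[n, 1]]; interpolate() below uses this
+# to sample feature maps at the four integer corners around each keypoint.
+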
+# input: pos [kpt_n, 2]; inputs [H, W, 128] / [H, W]
+# output: [kpt_n, 128] / [kpt_n]
+def interpolate(pos, inputs, nd=True):
+ h = inputs.shape[0]
+ w = inputs.shape[1]
+
+ i = pos[:, 0]
+ j = pos[:, 1]
+
+ i_top_left = torch.clamp(torch.floor(i).int(), 0, h - 1)
+ j_top_left = torch.clamp(torch.floor(j).int(), 0, w - 1)
+
+ i_top_right = torch.clamp(torch.floor(i).int(), 0, h - 1)
+ j_top_right = torch.clamp(torch.ceil(j).int(), 0, w - 1)
+
+ i_bottom_left = torch.clamp(torch.ceil(i).int(), 0, h - 1)
+ j_bottom_left = torch.clamp(torch.floor(j).int(), 0, w - 1)
+
+ i_bottom_right = torch.clamp(torch.ceil(i).int(), 0, h - 1)
+ j_bottom_right = torch.clamp(torch.ceil(j).int(), 0, w - 1)
+
+ dist_i_top_left = i - i_top_left.float()
+ dist_j_top_left = j - j_top_left.float()
+ w_top_left = (1 - dist_i_top_left) * (1 - dist_j_top_left)
+ w_top_right = (1 - dist_i_top_left) * dist_j_top_left
+ w_bottom_left = dist_i_top_left * (1 - dist_j_top_left)
+ w_bottom_right = dist_i_top_left * dist_j_top_left
+
+ if nd:
+ w_top_left = w_top_left[..., None]
+ w_top_right = w_top_right[..., None]
+ w_bottom_left = w_bottom_left[..., None]
+ w_bottom_right = w_bottom_right[..., None]
+
+ interpolated_val = (
+ w_top_left * gather_nd(inputs, torch.stack([i_top_left, j_top_left], axis=-1)) +
+ w_top_right * gather_nd(inputs, torch.stack([i_top_right, j_top_right], axis=-1)) +
+ w_bottom_left * gather_nd(inputs, torch.stack([i_bottom_left, j_bottom_left], axis=-1)) +
+ w_bottom_right *
+ gather_nd(inputs, torch.stack([i_bottom_right, j_bottom_right], axis=-1))
+ )
+
+ return interpolated_val
+
+
+def validate_and_interpolate(pos, inputs, validate_corner=True, validate_val=None, nd=False):
+ if nd:
+ h, w, c = inputs.shape
+ else:
+ h, w = inputs.shape
+ ids = torch.arange(0, pos.shape[0])
+
+ i = pos[:, 0]
+ j = pos[:, 1]
+
+ i_top_left = torch.floor(i).int()
+ j_top_left = torch.floor(j).int()
+
+ i_top_right = torch.floor(i).int()
+ j_top_right = torch.ceil(j).int()
+
+ i_bottom_left = torch.ceil(i).int()
+ j_bottom_left = torch.floor(j).int()
+
+ i_bottom_right = torch.ceil(i).int()
+ j_bottom_right = torch.ceil(j).int()
+
+ if validate_corner:
+ # Valid corner
+ valid_top_left = torch.logical_and(i_top_left >= 0, j_top_left >= 0)
+ valid_top_right = torch.logical_and(i_top_right >= 0, j_top_right < w)
+ valid_bottom_left = torch.logical_and(i_bottom_left < h, j_bottom_left >= 0)
+ valid_bottom_right = torch.logical_and(i_bottom_right < h, j_bottom_right < w)
+
+ valid_corner = torch.logical_and(
+ torch.logical_and(valid_top_left, valid_top_right),
+ torch.logical_and(valid_bottom_left, valid_bottom_right)
+ )
+
+ i_top_left = i_top_left[valid_corner]
+ j_top_left = j_top_left[valid_corner]
+
+ i_top_right = i_top_right[valid_corner]
+ j_top_right = j_top_right[valid_corner]
+
+ i_bottom_left = i_bottom_left[valid_corner]
+ j_bottom_left = j_bottom_left[valid_corner]
+
+ i_bottom_right = i_bottom_right[valid_corner]
+ j_bottom_right = j_bottom_right[valid_corner]
+
+ ids = ids[valid_corner]
+
+ if validate_val is not None:
+ # Valid depth
+ valid_depth = torch.logical_and(
+ torch.logical_and(
+ gather_nd(inputs, torch.stack([i_top_left, j_top_left], axis=-1)) > 0,
+ gather_nd(inputs, torch.stack([i_top_right, j_top_right], axis=-1)) > 0
+ ),
+ torch.logical_and(
+ gather_nd(inputs, torch.stack([i_bottom_left, j_bottom_left], axis=-1)) > 0,
+ gather_nd(inputs, torch.stack([i_bottom_right, j_bottom_right], axis=-1)) > 0
+ )
+ )
+
+ i_top_left = i_top_left[valid_depth]
+ j_top_left = j_top_left[valid_depth]
+
+ i_top_right = i_top_right[valid_depth]
+ j_top_right = j_top_right[valid_depth]
+
+ i_bottom_left = i_bottom_left[valid_depth]
+ j_bottom_left = j_bottom_left[valid_depth]
+
+ i_bottom_right = i_bottom_right[valid_depth]
+ j_bottom_right = j_bottom_right[valid_depth]
+
+ ids = ids[valid_depth]
+
+ # Interpolation
+ i = i[ids]
+ j = j[ids]
+ dist_i_top_left = i - i_top_left.float()
+ dist_j_top_left = j - j_top_left.float()
+ w_top_left = (1 - dist_i_top_left) * (1 - dist_j_top_left)
+ w_top_right = (1 - dist_i_top_left) * dist_j_top_left
+ w_bottom_left = dist_i_top_left * (1 - dist_j_top_left)
+ w_bottom_right = dist_i_top_left * dist_j_top_left
+
+ if nd:
+ w_top_left = w_top_left[..., None]
+ w_top_right = w_top_right[..., None]
+ w_bottom_left = w_bottom_left[..., None]
+ w_bottom_right = w_bottom_right[..., None]
+
+ interpolated_val = (
+ w_top_left * gather_nd(inputs, torch.stack([i_top_left, j_top_left], axis=-1)) +
+ w_top_right * gather_nd(inputs, torch.stack([i_top_right, j_top_right], axis=-1)) +
+ w_bottom_left * gather_nd(inputs, torch.stack([i_bottom_left, j_bottom_left], axis=-1)) +
+ w_bottom_right * gather_nd(inputs, torch.stack([i_bottom_right, j_bottom_right], axis=-1))
+ )
+
+ pos = torch.stack([i, j], axis=1)
+ return [interpolated_val, pos, ids]
+
+
+# pos0: [2, 230400, 2]
+# depth0: [2, 480, 480]
+def getWarp(pos0, rel_pose, depth0, K0, depth1, K1, bs):
+ def swap_axis(data):
+ return torch.stack([data[:, 1], data[:, 0]], axis=-1)
+
+ all_pos0 = []
+ all_pos1 = []
+ all_ids = []
+ for i in range(bs):
+ z0, new_pos0, ids = validate_and_interpolate(pos0[i], depth0[i], validate_val=0)
+
+ uv0_homo = torch.cat([swap_axis(new_pos0), torch.ones((new_pos0.shape[0], 1)).to(new_pos0.device)], axis=-1)
+ xy0_homo = torch.matmul(torch.linalg.inv(K0[i]), uv0_homo.t())
+ xyz0_homo = torch.cat([torch.unsqueeze(z0, 0) * xy0_homo,
+ torch.ones((1, new_pos0.shape[0])).to(z0.device)], axis=0)
+
+ xyz1 = torch.matmul(rel_pose[i], xyz0_homo)
+ xy1_homo = xyz1 / torch.unsqueeze(xyz1[-1, :], axis=0)
+ uv1 = torch.matmul(K1[i], xy1_homo).t()[:, 0:2]
+
+ new_pos1 = swap_axis(uv1)
+ annotated_depth, new_pos1, new_ids = validate_and_interpolate(
+ new_pos1, depth1[i], validate_val=0)
+
+ ids = ids[new_ids]
+ new_pos0 = new_pos0[new_ids]
+ estimated_depth = xyz1.t()[new_ids][:, -1]
+
+ inlier_mask = torch.abs(estimated_depth - annotated_depth) < 0.05
+
+ all_ids.append(ids[inlier_mask])
+ all_pos0.append(new_pos0[inlier_mask])
+ all_pos1.append(new_pos1[inlier_mask])
+ # all_pos0 & all_pos1: [inlier_num, 2] * batch_size
+ return all_pos0, all_pos1, all_ids
+
+
+# pos0: [2, 230400, 2]
+# depth0: [2, 480, 480]
+def getWarpNoValidate(pos0, rel_pose, depth0, K0, depth1, K1, bs):
+ def swap_axis(data):
+ return torch.stack([data[:, 1], data[:, 0]], axis=-1)
+
+ all_pos0 = []
+ all_pos1 = []
+ all_ids = []
+ for i in range(bs):
+ z0, new_pos0, ids = validate_and_interpolate(pos0[i], depth0[i], validate_val=0)
+
+ uv0_homo = torch.cat([swap_axis(new_pos0), torch.ones((new_pos0.shape[0], 1)).to(new_pos0.device)], axis=-1)
+ xy0_homo = torch.matmul(torch.linalg.inv(K0[i]), uv0_homo.t())
+ xyz0_homo = torch.cat([torch.unsqueeze(z0, 0) * xy0_homo,
+ torch.ones((1, new_pos0.shape[0])).to(z0.device)], axis=0)
+
+ xyz1 = torch.matmul(rel_pose[i], xyz0_homo)
+ xy1_homo = xyz1 / torch.unsqueeze(xyz1[-1, :], axis=0)
+ uv1 = torch.matmul(K1[i], xy1_homo).t()[:, 0:2]
+
+ new_pos1 = swap_axis(uv1)
+ _, new_pos1, new_ids = validate_and_interpolate(
+ new_pos1, depth1[i], validate_val=0)
+
+ ids = ids[new_ids]
+ new_pos0 = new_pos0[new_ids]
+
+ all_ids.append(ids)
+ all_pos0.append(new_pos0)
+ all_pos1.append(new_pos1)
+ # all_pos0 & all_pos1: [inlier_num, 2] * batch_size
+ return all_pos0, all_pos1, all_ids
+
+
+# pos0: [2, 230400, 2]
+# depth0: [2, 480, 480]
+def getWarpNoValidate2(pos0, rel_pose, depth0, K0, depth1, K1):
+ def swap_axis(data):
+ return torch.stack([data[:, 1], data[:, 0]], axis=-1)
+
+ z0 = interpolate(pos0, depth0, nd=False)
+
+ uv0_homo = torch.cat([swap_axis(pos0), torch.ones((pos0.shape[0], 1)).to(pos0.device)], axis=-1)
+ xy0_homo = torch.matmul(torch.linalg.inv(K0), uv0_homo.t())
+ xyz0_homo = torch.cat([torch.unsqueeze(z0, 0) * xy0_homo,
+ torch.ones((1, pos0.shape[0])).to(z0.device)], axis=0)
+
+ xyz1 = torch.matmul(rel_pose, xyz0_homo)
+ xy1_homo = xyz1 / torch.unsqueeze(xyz1[-1, :], axis=0)
+ uv1 = torch.matmul(K1, xy1_homo).t()[:, 0:2]
+
+ new_pos1 = swap_axis(uv1)
+
+ return new_pos1
+
+
+
+def get_dist_mat(feat1, feat2, dist_type):
+ eps = 1e-6
+ cos_dist_mat = torch.matmul(feat1, feat2.t())
+ if dist_type == 'cosine_dist':
+ dist_mat = torch.clamp(cos_dist_mat, -1, 1)
+ elif dist_type == 'euclidean_dist':
+ dist_mat = torch.sqrt(torch.clamp(2 - 2 * cos_dist_mat, min=eps))
+ elif dist_type == 'euclidean_dist_no_norm':
+ norm1 = torch.sum(feat1 * feat1, axis=-1, keepdims=True)
+ norm2 = torch.sum(feat2 * feat2, axis=-1, keepdims=True)
+ dist_mat = torch.sqrt(
+ torch.clamp(
+ norm1 - 2 * cos_dist_mat + norm2.t(),
+ min=0.
+ ) + eps
+ )
+ else:
+ raise NotImplementedError()
+ return dist_mat
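
As a quick sanity check of the bilinear weighting that `interpolate()` above applies at fractional (row, col) positions, the same blend of the four integer neighbours can be computed by hand on a toy map (standalone sketch, illustrative values only):

```python
import torch

# 4x4 toy map and a single fractional position (i, j) = (1.25, 2.5).
inputs = torch.arange(16, dtype=torch.float32).reshape(4, 4)  # [H, W]
pos = torch.tensor([[1.25, 2.5]])                             # (row, col)

i, j = pos[:, 0], pos[:, 1]
i0, j0 = i.floor().long(), j.floor().long()
di, dj = i - i0, j - j0
val = ((1 - di) * (1 - dj) * inputs[i0, j0]
       + (1 - di) * dj * inputs[i0, j0 + 1]
       + di * (1 - dj) * inputs[i0 + 1, j0]
       + di * dj * inputs[i0 + 1, j0 + 1])
print(val)  # tensor([7.5000]) = 0.75 * 6.5 + 0.25 * 10.5
```
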
diff --git a/imcui/third_party/DarkFeat/nets/l2net.py b/imcui/third_party/DarkFeat/nets/l2net.py
new file mode 100644
index 0000000000000000000000000000000000000000..e1ddfe8919bd4d5fe75215d253525123e1402952
--- /dev/null
+++ b/imcui/third_party/DarkFeat/nets/l2net.py
@@ -0,0 +1,116 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn.parameter import Parameter
+
+from .score import peakiness_score
+
+
+class BaseNet(nn.Module):
+ """ Helper class to construct a fully-convolutional network that
+ extracts an L2-normalized patch descriptor.
+ """
+ def __init__(self, inchan=3, dilated=True, dilation=1, bn=True, bn_affine=False):
+ super(BaseNet, self).__init__()
+ self.inchan = inchan
+ self.curchan = inchan
+ self.dilated = dilated
+ self.dilation = dilation
+ self.bn = bn
+ self.bn_affine = bn_affine
+
+ def _make_bn(self, outd):
+ return nn.BatchNorm2d(outd, affine=self.bn_affine)
+
+ def _add_conv(self, outd, k=3, stride=1, dilation=1, bn=True, relu=True, k_pool = 1, pool_type='max', bias=False):
+ # as in the original implementation, dilation is applied at the end of the layer, so it only takes effect from the next layer onwards
+ d = self.dilation * dilation
+ # if self.dilated:
+ # conv_params = dict(padding=((k-1)*d)//2, dilation=d, stride=1)
+ # self.dilation *= stride
+ # else:
+ # conv_params = dict(padding=((k-1)*d)//2, dilation=d, stride=stride)
+ conv_params = dict(padding=((k-1)*d)//2, dilation=d, stride=stride, bias=bias)
+
+ ops = nn.ModuleList([])
+
+ ops.append( nn.Conv2d(self.curchan, outd, kernel_size=k, **conv_params) )
+ if bn and self.bn: ops.append( self._make_bn(outd) )
+ if relu: ops.append( nn.ReLU(inplace=True) )
+ self.curchan = outd
+
+ if k_pool > 1:
+ if pool_type == 'avg':
+ ops.append(torch.nn.AvgPool2d(kernel_size=k_pool))
+ elif pool_type == 'max':
+ ops.append(torch.nn.MaxPool2d(kernel_size=k_pool))
+ else:
+ raise ValueError(f"unknown pooling type {pool_type}")
+
+ return nn.Sequential(*ops)
+
+
+class Quad_L2Net(BaseNet):
+ """ Same as L2_Net, but replaces the final 8x8 conv with 3 successive 3x3 convs.
+ """
+ def __init__(self, dim=128, mchan=4, relu22=False, **kw):
+ BaseNet.__init__(self, **kw)
+ self.conv0 = self._add_conv( 8*mchan)
+ self.conv1 = self._add_conv( 8*mchan, bn=False)
+ self.bn1 = self._make_bn(8*mchan)
+ self.conv2 = self._add_conv( 16*mchan, stride=2)
+ self.conv3 = self._add_conv( 16*mchan, bn=False)
+ self.bn3 = self._make_bn(16*mchan)
+ self.conv4 = self._add_conv( 32*mchan, stride=2)
+ self.conv5 = self._add_conv( 32*mchan)
+ # replace last 8x8 convolution with 3 3x3 convolutions
+ self.conv6_0 = self._add_conv( 32*mchan)
+ self.conv6_1 = self._add_conv( 32*mchan)
+ self.conv6_2 = self._add_conv(dim, bn=False, relu=False)
+ self.out_dim = dim
+
+ self.moving_avg_params = nn.ParameterList([
+ Parameter(torch.tensor(1.), requires_grad=False),
+ Parameter(torch.tensor(1.), requires_grad=False),
+ Parameter(torch.tensor(1.), requires_grad=False)
+ ])
+
+ def forward(self, x):
+ # x: [N, C, H, W]
+ x0 = self.conv0(x)
+ x1 = self.conv1(x0)
+ x1_bn = self.bn1(x1)
+ x2 = self.conv2(x1_bn)
+ x3 = self.conv3(x2)
+ x3_bn = self.bn3(x3)
+ x4 = self.conv4(x3_bn)
+ x5 = self.conv5(x4)
+ x6_0 = self.conv6_0(x5)
+ x6_1 = self.conv6_1(x6_0)
+ x6_2 = self.conv6_2(x6_1)
+
+ # calculate score map
+ comb_weights = torch.tensor([1., 2., 3.], device=x.device)
+ comb_weights /= torch.sum(comb_weights)
+ ksize = [3, 2, 1]
+ det_score_maps = []
+
+ for idx, xx in enumerate([x1, x3, x6_2]):
+ if self.training:
+ instance_max = torch.max(xx)
+ self.moving_avg_params[idx].data = self.moving_avg_params[idx] * 0.99 + instance_max.detach() * 0.01
+ else:
+ pass
+
+ alpha, beta = peakiness_score(xx, self.moving_avg_params[idx].detach(), ksize=3, dilation=ksize[idx])
+
+ score_vol = alpha * beta
+ det_score_map = torch.max(score_vol, dim=1, keepdim=True)[0]
+ det_score_map = F.interpolate(det_score_map, size=x.shape[2:], mode='bilinear', align_corners=True)
+ det_score_map = comb_weights[idx] * det_score_map
+ det_score_maps.append(det_score_map)
+
+ det_score_map = torch.sum(torch.stack(det_score_maps, dim=0), dim=0)
+ # print([param.data for param in self.moving_avg_params])
+
+ return x6_2, det_score_map, x1, x3
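
A shape sketch for the backbone above, assuming the DarkFeat directory is on `sys.path` (trainer.py later in this patch imports it the same way): the two stride-2 convolutions give a 1/4-resolution descriptor map, while the detector score map is upsampled back to the input resolution. Channel counts follow the defaults `mchan=4`, `dim=128`.

```python
import torch
from nets.l2net import Quad_L2Net  # assumes the DarkFeat root is on sys.path

net = Quad_L2Net(inchan=3).eval()
with torch.no_grad():
    desc, score, x1, x3 = net(torch.randn(1, 3, 64, 64))
print(desc.shape)          # torch.Size([1, 128, 16, 16]) descriptors at 1/4 res
print(score.shape)         # torch.Size([1, 1, 64, 64])   per-pixel detector score
print(x1.shape, x3.shape)  # intermediate maps feeding the multi-level score
```
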
diff --git a/imcui/third_party/DarkFeat/nets/loss.py b/imcui/third_party/DarkFeat/nets/loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..0dd42b4214d021137ddfe72771ccad0264d2321f
--- /dev/null
+++ b/imcui/third_party/DarkFeat/nets/loss.py
@@ -0,0 +1,260 @@
+import torch
+import torch.nn.functional as F
+
+from .geom import rnd_sample, interpolate, get_dist_mat
+
+
+def make_detector_loss(pos0, pos1, dense_feat_map0, dense_feat_map1,
+ score_map0, score_map1, batch_size, num_corr, loss_type, config):
+ joint_loss = 0.
+ accuracy = 0.
+ all_valid_pos0 = []
+ all_valid_pos1 = []
+ all_valid_match = []
+ for i in range(batch_size):
+ # random sample
+ valid_pos0, valid_pos1 = rnd_sample([pos0[i], pos1[i]], num_corr)
+ valid_num = valid_pos0.shape[0]
+
+ valid_feat0 = interpolate(valid_pos0 / 4, dense_feat_map0[i])
+ valid_feat1 = interpolate(valid_pos1 / 4, dense_feat_map1[i])
+
+ valid_feat0 = F.normalize(valid_feat0, p=2, dim=-1)
+ valid_feat1 = F.normalize(valid_feat1, p=2, dim=-1)
+
+ valid_score0 = interpolate(valid_pos0, torch.squeeze(score_map0[i], dim=-1), nd=False)
+ valid_score1 = interpolate(valid_pos1, torch.squeeze(score_map1[i], dim=-1), nd=False)
+
+ if config['network']['det']['corr_weight']:
+ corr_weight = valid_score0 * valid_score1
+ else:
+ corr_weight = None
+
+ safe_radius = config['network']['det']['safe_radius']
+ if safe_radius > 0:
+ radius_mask_row = get_dist_mat(
+ valid_pos1, valid_pos1, "euclidean_dist_no_norm")
+ radius_mask_row = torch.le(radius_mask_row, safe_radius)
+ radius_mask_col = get_dist_mat(
+ valid_pos0, valid_pos0, "euclidean_dist_no_norm")
+ radius_mask_col = torch.le(radius_mask_col, safe_radius)
+ radius_mask_row = radius_mask_row.float() - torch.eye(valid_num, device=radius_mask_row.device)
+ radius_mask_col = radius_mask_col.float() - torch.eye(valid_num, device=radius_mask_col.device)
+ else:
+ radius_mask_row = None
+ radius_mask_col = None
+
+ if valid_num < 32:
+ si_loss, si_accuracy, matched_mask = 0., 1., torch.zeros((1, valid_num)).bool()
+ else:
+ si_loss, si_accuracy, matched_mask = make_structured_loss(
+ torch.unsqueeze(valid_feat0, 0), torch.unsqueeze(valid_feat1, 0),
+ loss_type=loss_type,
+ radius_mask_row=radius_mask_row, radius_mask_col=radius_mask_col,
+ corr_weight=torch.unsqueeze(corr_weight, 0) if corr_weight is not None else None
+ )
+
+ joint_loss += si_loss / batch_size
+ accuracy += si_accuracy / batch_size
+ all_valid_match.append(torch.squeeze(matched_mask, dim=0))
+ all_valid_pos0.append(valid_pos0)
+ all_valid_pos1.append(valid_pos1)
+
+ return joint_loss, accuracy
+
+
+def make_structured_loss(feat_anc, feat_pos,
+ loss_type='RATIO', inlier_mask=None,
+ radius_mask_row=None, radius_mask_col=None,
+ corr_weight=None, dist_mat=None):
+ """
+ Structured loss construction.
+ Args:
+ feat_anc, feat_pos: Feature matrices of shape [batch_size, num_corr, C].
+ loss_type: Loss type ('HARD_TRIPLET', 'HARD_CONTRASTIVE' or 'CIRCLE').
+ inlier_mask: Optional boolean mask [batch_size, num_corr] of valid correspondences.
+ Returns:
+ loss, accuracy, matched_mask.
+ """
+ batch_size = feat_anc.shape[0]
+ num_corr = feat_anc.shape[1]
+ if inlier_mask is None:
+ inlier_mask = torch.ones((batch_size, num_corr), device=feat_anc.device).bool()
+ inlier_num = torch.count_nonzero(inlier_mask.float(), dim=-1)
+
+ if loss_type == 'L2NET' or loss_type == 'CIRCLE':
+ dist_type = 'cosine_dist'
+ elif loss_type.find('HARD') >= 0:
+ dist_type = 'euclidean_dist'
+ else:
+ raise NotImplementedError()
+
+ if dist_mat is None:
+ dist_mat = get_dist_mat(feat_anc.squeeze(0), feat_pos.squeeze(0), dist_type).unsqueeze(0)
+ pos_vec = dist_mat[0].diag().unsqueeze(0)
+
+ if loss_type.find('HARD') >= 0:
+ neg_margin = 1
+ dist_mat_without_min_on_diag = dist_mat + \
+ 10 * torch.unsqueeze(torch.eye(num_corr, device=dist_mat.device), dim=0)
+ mask = torch.le(dist_mat_without_min_on_diag, 0.008).float()
+ dist_mat_without_min_on_diag += mask*10
+
+ if radius_mask_row is not None:
+ hard_neg_dist_row = dist_mat_without_min_on_diag + 10 * radius_mask_row
+ else:
+ hard_neg_dist_row = dist_mat_without_min_on_diag
+ if radius_mask_col is not None:
+ hard_neg_dist_col = dist_mat_without_min_on_diag + 10 * radius_mask_col
+ else:
+ hard_neg_dist_col = dist_mat_without_min_on_diag
+
+ hard_neg_dist_row = torch.min(hard_neg_dist_row, dim=-1)[0]
+ hard_neg_dist_col = torch.min(hard_neg_dist_col, dim=-2)[0]
+
+ if loss_type == 'HARD_TRIPLET':
+ loss_row = torch.clamp(neg_margin + pos_vec - hard_neg_dist_row, min=0)
+ loss_col = torch.clamp(neg_margin + pos_vec - hard_neg_dist_col, min=0)
+ elif loss_type == 'HARD_CONTRASTIVE':
+ pos_margin = 0.2
+ pos_loss = torch.clamp(pos_vec - pos_margin, min=0)
+ loss_row = pos_loss + torch.clamp(neg_margin - hard_neg_dist_row, min=0)
+ loss_col = pos_loss + torch.clamp(neg_margin - hard_neg_dist_col, min=0)
+ else:
+ raise NotImplementedError()
+
+ elif loss_type == 'CIRCLE':
+ log_scale = 512
+ m = 0.1
+ neg_mask_row = torch.unsqueeze(torch.eye(num_corr, device=feat_anc.device), 0)
+ if radius_mask_row is not None:
+ neg_mask_row += radius_mask_row
+ neg_mask_col = torch.unsqueeze(torch.eye(num_corr, device=feat_anc.device), 0)
+ if radius_mask_col is not None:
+ neg_mask_col += radius_mask_col
+
+ pos_margin = 1 - m
+ neg_margin = m
+ pos_optimal = 1 + m
+ neg_optimal = -m
+
+ neg_mat_row = dist_mat - 128 * neg_mask_row
+ neg_mat_col = dist_mat - 128 * neg_mask_col
+
+ lse_positive = torch.logsumexp(-log_scale * (pos_vec[..., None] - pos_margin) * \
+ torch.clamp(pos_optimal - pos_vec[..., None], min=0).detach(), dim=-1)
+
+ lse_negative_row = torch.logsumexp(log_scale * (neg_mat_row - neg_margin) * \
+ torch.clamp(neg_mat_row - neg_optimal, min=0).detach(), dim=-1)
+
+ lse_negative_col = torch.logsumexp(log_scale * (neg_mat_col - neg_margin) * \
+ torch.clamp(neg_mat_col - neg_optimal, min=0).detach(), dim=-2)
+
+ loss_row = F.softplus(lse_positive + lse_negative_row) / log_scale
+ loss_col = F.softplus(lse_positive + lse_negative_col) / log_scale
+
+ else:
+ raise NotImplementedError()
+
+ if dist_type == 'cosine_dist':
+ err_row = dist_mat - torch.unsqueeze(pos_vec, -1)
+ err_col = dist_mat - torch.unsqueeze(pos_vec, -2)
+ elif dist_type == 'euclidean_dist' or dist_type == 'euclidean_dist_no_norm':
+ err_row = torch.unsqueeze(pos_vec, -1) - dist_mat
+ err_col = torch.unsqueeze(pos_vec, -2) - dist_mat
+ else:
+ raise NotImplementedError()
+ if radius_mask_row is not None:
+ err_row = err_row - 10 * radius_mask_row
+ if radius_mask_col is not None:
+ err_col = err_col - 10 * radius_mask_col
+ err_row = torch.sum(torch.clamp(err_row, min=0), dim=-1)
+ err_col = torch.sum(torch.clamp(err_col, min=0), dim=-2)
+
+ loss = 0
+ accuracy = 0
+
+ tot_loss = (loss_row + loss_col) / 2
+ if corr_weight is not None:
+ tot_loss = tot_loss * corr_weight
+
+ for i in range(batch_size):
+ if corr_weight is not None:
+ loss += torch.sum(tot_loss[i][inlier_mask[i]]) / \
+ (torch.sum(corr_weight[i][inlier_mask[i]]) + 1e-6)
+ else:
+ loss += torch.mean(tot_loss[i][inlier_mask[i]])
+ cnt_err_row = torch.count_nonzero(err_row[i][inlier_mask[i]]).float()
+ cnt_err_col = torch.count_nonzero(err_col[i][inlier_mask[i]]).float()
+ tot_err = cnt_err_row + cnt_err_col
+ if inlier_num[i] != 0:
+ accuracy += 1. - tot_err / inlier_num[i] / batch_size / 2.
+ else:
+ accuracy += 1.
+
+ matched_mask = torch.logical_and(torch.eq(err_row, 0), torch.eq(err_col, 0))
+ matched_mask = torch.logical_and(matched_mask, inlier_mask)
+
+ loss /= batch_size
+ accuracy /= batch_size
+
+ return loss, accuracy, matched_mask
+
+
+# for the neighborhood areas of keypoints extracted from the normal image, the scores from noise_score_map should stay close to the normal scores
+# for the rest, the noise image's score should be less than the normal image's
+# input: score_map [batch_size, H, W, 1]; indices [2, k, 2]
+# output: loss [scalar]
+def make_noise_score_map_loss(score_map, noise_score_map, indices, batch_size, thld=0.):
+ H, W = score_map.shape[1:3]
+ loss = 0
+ for i in range(batch_size):
+ kpts_coords = indices[i].T # (2, num_kpts)
+ mask = torch.zeros([H, W], device=score_map.device)
+ mask[kpts_coords[0], kpts_coords[1]] = 1  # mark each keypoint pixel (row, col)
+
+ # using a 3x3 kernel to put the kpts' neighborhood area into the mask
+ kernel = torch.ones([1, 1, 3, 3], device=score_map.device)
+ mask = F.conv2d(mask.unsqueeze(0).unsqueeze(0), kernel, padding=1)[0, 0] > 0
+
+ loss1 = torch.sum(torch.abs(score_map[i] - noise_score_map[i]).squeeze() * mask) / torch.sum(mask)
+ loss2 = torch.sum(torch.clamp(noise_score_map[i] - score_map[i] - thld, min=0).squeeze() * torch.logical_not(mask)) / (H * W - torch.sum(mask))
+
+ loss += loss1
+ loss += loss2
+
+ if i == 0:
+ first_mask = mask
+
+ return loss, first_mask
+
+
+def make_noise_score_map_loss_labelmap(score_map, noise_score_map, labelmap, batch_size, thld=0.):
+ H, W = score_map.shape[1:3]
+ loss = 0
+ for i in range(batch_size):
+ # using a 3x3 kernel to put the kpts' neighborhood area into the mask
+ kernel = torch.ones([1, 1, 3, 3], device=score_map.device)
+ mask = F.conv2d(labelmap[i].unsqueeze(0).to(score_map.device).float(), kernel, padding=1)[0, 0] > 0
+
+ loss1 = torch.sum(torch.abs(score_map[i] - noise_score_map[i]).squeeze() * mask) / torch.sum(mask)
+ loss2 = torch.sum(torch.clamp(noise_score_map[i] - score_map[i] - thld, min=0).squeeze() * torch.logical_not(mask)) / (H * W - torch.sum(mask))
+
+ loss += loss1
+ loss += loss2
+
+ if i == 0:
+ first_mask = mask
+
+ return loss, first_mask
+
+
+def make_score_map_peakiness_loss(score_map, scores, batch_size):
+ H, W = score_map.shape[1:3]
+ loss = 0
+
+ for i in range(batch_size):
+ loss += torch.mean(scores[i]) - torch.mean(score_map[i])
+
+ loss /= batch_size
+ return 1 - loss
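
The `HARD_TRIPLET` branch of `make_structured_loss` mines the hardest negative per row/column by masking the diagonal (the true matches) of the descriptor distance matrix before taking the minimum. A minimal standalone sketch of that idea on random descriptors:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
feat0 = F.normalize(torch.randn(64, 128), dim=-1)
feat1 = F.normalize(feat0 + 0.1 * torch.randn(64, 128), dim=-1)

# Euclidean distance between unit descriptors; positives sit on the diagonal.
dist = torch.sqrt(torch.clamp(2 - 2 * feat0 @ feat1.t(), min=1e-6))
pos = dist.diag()
masked = dist + 10 * torch.eye(dist.shape[0])   # exclude the true match
hard_neg_row = masked.min(dim=-1).values
hard_neg_col = masked.min(dim=-2).values
margin = 1.0
loss = (torch.clamp(margin + pos - hard_neg_row, min=0)
        + torch.clamp(margin + pos - hard_neg_col, min=0)).mean() / 2
print(float(loss))
```
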
diff --git a/imcui/third_party/DarkFeat/nets/multi_sampler.py b/imcui/third_party/DarkFeat/nets/multi_sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..dc400fb2afeb50575cd81d3c01b605bea6db1121
--- /dev/null
+++ b/imcui/third_party/DarkFeat/nets/multi_sampler.py
@@ -0,0 +1,172 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+
+from .geom import rnd_sample, interpolate
+
+class MultiSampler (nn.Module):
+ """ Similar to NghSampler, but doesn't warp the 2nd image.
+ Distance to GT => 0 ... pos_d ... neg_d ... ngh
+ Pixel label => + + + + + + 0 0 - - - - - - -
+
+ Subsample on query side: if > 0, regular grid
+ < 0, random points
+ In both cases, the number of query points is = W*H/subq**2
+ """
+ def __init__(self, ngh, subq=1, subd=1, pos_d=0, neg_d=2, border=None,
+ maxpool_pos=True, subd_neg=0):
+ nn.Module.__init__(self)
+ assert 0 <= pos_d < neg_d <= (ngh if ngh else 99)
+ self.ngh = ngh
+ self.pos_d = pos_d
+ self.neg_d = neg_d
+ assert subd <= ngh or ngh == 0
+ assert subq != 0
+ self.sub_q = subq
+ self.sub_d = subd
+ self.sub_d_neg = subd_neg
+ if border is None: border = ngh
+ assert border >= ngh, 'border has to be larger than ngh'
+ self.border = border
+ self.maxpool_pos = maxpool_pos
+ self.precompute_offsets()
+
+ def precompute_offsets(self):
+ pos_d2 = self.pos_d**2
+ neg_d2 = self.neg_d**2
+ rad2 = self.ngh**2
+ rad = (self.ngh//self.sub_d) * self.ngh # make an integer multiple
+ pos = []
+ neg = []
+ for j in range(-rad, rad+1, self.sub_d):
+ for i in range(-rad, rad+1, self.sub_d):
+ d2 = i*i + j*j
+ if d2 <= pos_d2:
+ pos.append( (i,j) )
+ elif neg_d2 <= d2 <= rad2:
+ neg.append( (i,j) )
+
+ self.register_buffer('pos_offsets', torch.LongTensor(pos).view(-1,2).t())
+ self.register_buffer('neg_offsets', torch.LongTensor(neg).view(-1,2).t())
+
+
+ def forward(self, feat0, feat1, noise_feat0, noise_feat1, conf0, conf1, noise_conf0, noise_conf1, pos0, pos1, B, H, W, N=2500):
+ pscores_ls, nscores_ls, distractors_ls = [], [], []
+ valid_feat0_ls = []
+ noise_pscores_ls, noise_nscores_ls, noise_distractors_ls = [], [], []
+ valid_noise_feat0_ls = []
+ valid_pos1_ls, valid_pos2_ls = [], []
+ qconf_ls = []
+ noise_qconf_ls = []
+ mask_ls = []
+
+ for i in range(B):
+ tmp_mask = (pos0[i][:, 1] >= self.border) * (pos0[i][:, 1] < W-self.border) \
+ * (pos0[i][:, 0] >= self.border) * (pos0[i][:, 0] < H-self.border)
+
+ selected_pos0 = pos0[i][tmp_mask]
+ selected_pos1 = pos1[i][tmp_mask]
+ valid_pos0, valid_pos1 = rnd_sample([selected_pos0, selected_pos1], N)
+
+ # sample features from first image
+ valid_feat0 = interpolate(valid_pos0 / 4, feat0[i]) # [N, 128]
+ valid_feat0 = F.normalize(valid_feat0, p=2, dim=-1) # [N, 128]
+ qconf = interpolate(valid_pos0 / 4, conf0[i])
+
+ valid_noise_feat0 = interpolate(valid_pos0 / 4, noise_feat0[i]) # [N, 128]
+ valid_noise_feat0 = F.normalize(valid_noise_feat0, p=2, dim=-1) # [N, 128]
+ noise_qconf = interpolate(valid_pos0 / 4, noise_conf0[i])
+
+ # sample GT from second image
+ mask = (valid_pos1[:, 1] >= 0) * (valid_pos1[:, 1] < W) \
+ * (valid_pos1[:, 0] >= 0) * (valid_pos1[:, 0] < H)
+
+ def clamp(xy):
+ xy = xy
+ torch.clamp(xy[0], 0, H-1, out=xy[0])
+ torch.clamp(xy[1], 0, W-1, out=xy[1])
+ return xy
+
+ # compute positive scores
+ valid_pos1p = clamp(valid_pos1.t()[:,None,:] + self.pos_offsets[:,:,None].to(valid_pos1.device)) # [2, 29, N]
+ valid_pos1p = valid_pos1p.permute(1, 2, 0).reshape(-1, 2) # [29, N, 2] -> [29*N, 2]
+ valid_feat1p = interpolate(valid_pos1p / 4, feat1[i]).reshape(self.pos_offsets.shape[-1], -1, 128) # [29, N, 128]
+ valid_feat1p = F.normalize(valid_feat1p, p=2, dim=-1) # [29, N, 128]
+ valid_noise_feat1p = interpolate(valid_pos1p / 4, noise_feat1[i]).reshape(self.pos_offsets.shape[-1], -1, 128) # [29, N, 128], sampled from the noise feature map
+ valid_noise_feat1p = F.normalize(valid_noise_feat1p, p=2, dim=-1) # [29, N, 128]
+
+ pscores = (valid_feat0[None,:,:] * valid_feat1p).sum(dim=-1).t() # [N, 29]
+ pscores, pos = pscores.max(dim=1, keepdim=True)
+ sel = clamp(valid_pos1.t() + self.pos_offsets[:,pos.view(-1)].to(valid_pos1.device))
+ qconf = (qconf + interpolate(sel.t() / 4, conf1[i]))/2
+ noise_pscores = (valid_noise_feat0[None,:,:] * valid_noise_feat1p).sum(dim=-1).t() # [N, 29]
+ noise_pscores, noise_pos = noise_pscores.max(dim=1, keepdim=True)
+ noise_sel = clamp(valid_pos1.t() + self.pos_offsets[:,noise_pos.view(-1)].to(valid_pos1.device))
+ noise_qconf = (noise_qconf + interpolate(noise_sel.t() / 4, noise_conf1[i]))/2
+
+ # compute negative scores
+ valid_pos1n = clamp(valid_pos1.t()[:,None,:] + self.neg_offsets[:,:,None].to(valid_pos1.device)) # [2, 29, N]
+ valid_pos1n = valid_pos1n.permute(1, 2, 0).reshape(-1, 2) # [29, N, 2] -> [29*N, 2]
+ valid_feat1n = interpolate(valid_pos1n / 4, feat1[i]).reshape(self.neg_offsets.shape[-1], -1, 128) # [29, N, 128]
+ valid_feat1n = F.normalize(valid_feat1n, p=2, dim=-1) # [29, N, 128]
+ nscores = (valid_feat0[None,:,:] * valid_feat1n).sum(dim=-1).t() # [N, 29]
+ valid_noise_feat1n = interpolate(valid_pos1n / 4, noise_feat1[i]).reshape(self.neg_offsets.shape[-1], -1, 128) # [29, N, 128]
+ valid_noise_feat1n = F.normalize(valid_noise_feat1n, p=2, dim=-1) # [29, N, 128]
+ noise_nscores = (valid_noise_feat0[None,:,:] * valid_noise_feat1n).sum(dim=-1).t() # [N, 29]
+
+ if self.sub_d_neg:
+ valid_pos2 = rnd_sample([selected_pos1], N)[0]
+ distractors = interpolate(valid_pos2 / 4, feat1[i])
+ distractors = F.normalize(distractors, p=2, dim=-1)
+ noise_distractors = interpolate(valid_pos2 / 4, noise_feat1[i])
+ noise_distractors = F.normalize(noise_distractors, p=2, dim=-1)
+
+ pscores_ls.append(pscores)
+ nscores_ls.append(nscores)
+ distractors_ls.append(distractors)
+ valid_feat0_ls.append(valid_feat0)
+ noise_pscores_ls.append(noise_pscores)
+ noise_nscores_ls.append(noise_nscores)
+ noise_distractors_ls.append(noise_distractors)
+ valid_noise_feat0_ls.append(valid_noise_feat0)
+ valid_pos1_ls.append(valid_pos1)
+ valid_pos2_ls.append(valid_pos2)
+ qconf_ls.append(qconf)
+ noise_qconf_ls.append(noise_qconf)
+ mask_ls.append(mask)
+
+ N = np.min([len(i) for i in qconf_ls])
+
+ # merge batches
+ qconf = torch.stack([i[:N] for i in qconf_ls], dim=0).squeeze(-1)
+ mask = torch.stack([i[:N] for i in mask_ls], dim=0)
+ pscores = torch.cat([i[:N] for i in pscores_ls], dim=0)
+ nscores = torch.cat([i[:N] for i in nscores_ls], dim=0)
+ distractors = torch.cat([i[:N] for i in distractors_ls], dim=0)
+ valid_feat0 = torch.cat([i[:N] for i in valid_feat0_ls], dim=0)
+ valid_pos1 = torch.cat([i[:N] for i in valid_pos1_ls], dim=0)
+ valid_pos2 = torch.cat([i[:N] for i in valid_pos2_ls], dim=0)
+
+ noise_qconf = torch.stack([i[:N] for i in noise_qconf_ls], dim=0).squeeze(-1)
+ noise_pscores = torch.cat([i[:N] for i in noise_pscores_ls], dim=0)
+ noise_nscores = torch.cat([i[:N] for i in noise_nscores_ls], dim=0)
+ noise_distractors = torch.cat([i[:N] for i in noise_distractors_ls], dim=0)
+ valid_noise_feat0 = torch.cat([i[:N] for i in valid_noise_feat0_ls], dim=0)
+
+ # remove scores that correspond to positives or nulls
+ dscores = torch.matmul(valid_feat0, distractors.t())
+ noise_dscores = torch.matmul(valid_noise_feat0, noise_distractors.t())
+
+ dis2 = (valid_pos2[:, 1] - valid_pos1[:, 1][:,None])**2 + (valid_pos2[:, 0] - valid_pos1[:, 0][:,None])**2
+ b = torch.arange(B, device=dscores.device)[:,None].expand(B, N).reshape(-1)
+ dis2 += (b != b[:,None]).long() * self.neg_d**2
+ dscores[dis2 < self.neg_d**2] = 0
+ noise_dscores[dis2 < self.neg_d**2] = 0
+ scores = torch.cat((pscores, nscores, dscores), dim=1)
+ noise_scores = torch.cat((noise_pscores, noise_nscores, noise_dscores), dim=1)
+
+ gt = scores.new_zeros(scores.shape, dtype=torch.uint8)
+ gt[:, :pscores.shape[1]] = 1
+
+ return scores, noise_scores, gt, mask, qconf, noise_qconf
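
`precompute_offsets()` above turns the (`pos_d`, `neg_d`, `ngh`) radii into two fixed offset lists: everything within `pos_d` of the ground-truth match counts as positive, the ring between `neg_d` and `ngh` as negative, and the band in between is ignored. A standalone sketch with the radii used by trainer.py in this patch (`ngh=7, pos_d=3, neg_d=5, subd=1`):

```python
ngh, pos_d, neg_d, sub_d = 7, 3, 5, 1
rad = (ngh // sub_d) * ngh
pos, neg = [], []
for j in range(-rad, rad + 1, sub_d):
    for i in range(-rad, rad + 1, sub_d):
        d2 = i * i + j * j
        if d2 <= pos_d ** 2:
            pos.append((i, j))
        elif neg_d ** 2 <= d2 <= ngh ** 2:
            neg.append((i, j))
print(len(pos), len(neg))  # 29 positive offsets, 80 negative offsets
```
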
diff --git a/imcui/third_party/DarkFeat/nets/noise_reliability_loss.py b/imcui/third_party/DarkFeat/nets/noise_reliability_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..9efddae149653c225ee7f2c1eb5fed5f92cef15c
--- /dev/null
+++ b/imcui/third_party/DarkFeat/nets/noise_reliability_loss.py
@@ -0,0 +1,40 @@
+import torch
+import torch.nn as nn
+from .reliability_loss import APLoss
+
+
+class MultiPixelAPLoss (nn.Module):
+ """ Computes the pixel-wise AP loss:
+ Given two images and ground-truth optical flow, computes the AP per pixel.
+
+ feat1: (B, C, H, W) pixel-wise features extracted from img1
+ feat2: (B, C, H, W) pixel-wise features extracted from img2
+ aflow: (B, 2, H, W) absolute flow: aflow[...,y1,x1] = x2,y2
+ """
+ def __init__(self, sampler, nq=20):
+ nn.Module.__init__(self)
+ self.aploss = APLoss(nq, min=0, max=1, euc=False)
+ self.sampler = sampler
+ self.base = 0.25
+ self.dec_base = 0.20
+
+ def loss_from_ap(self, ap, rel, noise_ap, noise_rel):
+ dec_ap = torch.clamp(ap - noise_ap, min=0, max=1)
+ return (1 - ap*noise_rel - (1-noise_rel)*self.base), (1. - dec_ap*(1-noise_rel) - noise_rel*self.dec_base)
+
+ def forward(self, feat0, feat1, noise_feat0, noise_feat1, conf0, conf1, noise_conf0, noise_conf1, pos0, pos1, B, H, W, N=1500):
+ # subsample things
+ scores, noise_scores, gt, msk, qconf, noise_qconf = self.sampler(feat0, feat1, noise_feat0, noise_feat1, \
+ conf0, conf1, noise_conf0, noise_conf1, pos0, pos1, B, H, W, N=1500)
+
+ # compute pixel-wise AP
+ n = qconf.numel()
+ if n == 0: return 0, 0
+ scores, noise_scores, gt = scores.view(n,-1), noise_scores, gt.view(n,-1)
+ ap = self.aploss(scores, gt).view(msk.shape)
+ noise_ap = self.aploss(noise_scores, gt).view(msk.shape)
+
+ pixel_loss = self.loss_from_ap(ap, qconf, noise_ap, noise_qconf)
+
+ loss = pixel_loss[0][msk].mean(), pixel_loss[1][msk].mean()
+ return loss
\ No newline at end of file
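
Roughly, `loss_from_ap()` above combines two terms: the first pushes the clean image's AP up wherever the noise-image reliability is high (falling back to the constant `base` elsewhere), and the second rewards a clean-over-noisy AP gap wherever that reliability is low. A toy numeric illustration with the constructor defaults `base=0.25`, `dec_base=0.20`:

```python
import torch

base, dec_base = 0.25, 0.20
ap, noise_ap, noise_rel = torch.tensor(0.9), torch.tensor(0.6), torch.tensor(0.8)

dec_ap = torch.clamp(ap - noise_ap, min=0, max=1)
term1 = 1 - ap * noise_rel - (1 - noise_rel) * base
term2 = 1 - dec_ap * (1 - noise_rel) - noise_rel * dec_base
print(float(term1), float(term2))  # ~0.23 and ~0.78 for this toy setting
```
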
diff --git a/imcui/third_party/DarkFeat/nets/reliability_loss.py b/imcui/third_party/DarkFeat/nets/reliability_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..527f9886a2d4785680bac52ff2fa20033b8d8920
--- /dev/null
+++ b/imcui/third_party/DarkFeat/nets/reliability_loss.py
@@ -0,0 +1,105 @@
+import torch
+import torch.nn as nn
+import numpy as np
+
+
+class APLoss (nn.Module):
+ """ differentiable AP loss, through quantization.
+
+ Input: (N, M) values in [min, max]
+ label: (N, M) values in {0, 1}
+
+ Returns: list of query AP (for each n in {1..N})
+ Note: typically, you want to minimize 1 - mean(AP)
+ """
+ def __init__(self, nq=25, min=0, max=1, euc=False):
+ nn.Module.__init__(self)
+ assert isinstance(nq, int) and 2 <= nq <= 100
+ self.nq = nq
+ self.min = min
+ self.max = max
+ self.euc = euc
+ gap = max - min
+ assert gap > 0
+
+ # init quantizer = non-learnable (fixed) convolution
+ self.quantizer = q = nn.Conv1d(1, 2*nq, kernel_size=1, bias=True)
+ a = (nq-1) / gap
+ #1st half = lines passing through (min+x,1) and (min+x+1/a,0) with x = {nq-1..0}*gap/(nq-1)
+ q.weight.data[:nq] = -a
+ q.bias.data[:nq] = torch.from_numpy(a*min + np.arange(nq, 0, -1)) # b = 1 + a*(min+x)
+ #2nd half = lines passing through (min+x,1) and (min+x-1/a,0) with x = {nq-1..0}*gap/(nq-1)
+ q.weight.data[nq:] = a
+ q.bias.data[nq:] = torch.from_numpy(np.arange(2-nq, 2, 1) - a*min) # b = 1 - a*(min+x)
+ # first and last one are special: just horizontal straight line
+ q.weight.data[0] = q.weight.data[-1] = 0
+ q.bias.data[0] = q.bias.data[-1] = 1
+
+ def compute_AP(self, x, label):
+ N, M = x.shape
+ # print(x.shape, label.shape)
+ if self.euc: # euclidean distance in the same range as similarities
+ x = 1 - torch.sqrt(2.001 - 2*x)
+
+ # quantize all predictions
+ q = self.quantizer(x.unsqueeze(1))
+ q = torch.min(q[:,:self.nq], q[:,self.nq:]).clamp(min=0) # N x Q x M [1600, 20, 1681]
+
+ nbs = q.sum(dim=-1) # number of samples N x Q = c
+ rec = (q * label.view(N,1,M).float()).sum(dim=-1) # nb of correct samples = c+ N x Q
+ prec = rec.cumsum(dim=-1) / (1e-16 + nbs.cumsum(dim=-1)) # precision
+ rec /= rec.sum(dim=-1).unsqueeze(1) # norm in [0,1]
+
+ ap = (prec * rec).sum(dim=-1) # per-image AP
+ return ap
+
+ def forward(self, x, label):
+ assert x.shape == label.shape # N x M
+ return self.compute_AP(x, label)
+
+
+class PixelAPLoss (nn.Module):
+ """ Computes the pixel-wise AP loss:
+ Given two images and ground-truth optical flow, computes the AP per pixel.
+
+ feat1: (B, C, H, W) pixel-wise features extracted from img1
+ feat2: (B, C, H, W) pixel-wise features extracted from img2
+ aflow: (B, 2, H, W) absolute flow: aflow[...,y1,x1] = x2,y2
+ """
+ def __init__(self, sampler, nq=20):
+ nn.Module.__init__(self)
+ self.aploss = APLoss(nq, min=0, max=1, euc=False)
+ self.name = 'pixAP'
+ self.sampler = sampler
+
+ def loss_from_ap(self, ap, rel):
+ return 1 - ap
+
+ def forward(self, feat0, feat1, conf0, conf1, pos0, pos1, B, H, W, N=1200):
+ # subsample things
+ scores, gt, msk, qconf = self.sampler(feat0, feat1, conf0, conf1, pos0, pos1, B, H, W, N=1200)
+
+ # compute pixel-wise AP
+ n = qconf.numel()
+ if n == 0: return 0
+ scores, gt = scores.view(n,-1), gt.view(n,-1)
+ ap = self.aploss(scores, gt).view(msk.shape)
+
+ pixel_loss = self.loss_from_ap(ap, qconf)
+
+ loss = pixel_loss[msk].mean()
+ return loss
+
+
+class ReliabilityLoss (PixelAPLoss):
+ """ Same as PixelAPLoss, but also trains a pixel-wise confidence
+ that this pixel is going to have a good AP.
+ """
+ def __init__(self, sampler, base=0.5, **kw):
+ PixelAPLoss.__init__(self, sampler, **kw)
+ assert 0 <= base < 1
+ self.base = base
+
+ def loss_from_ap(self, ap, rel):
+ return 1 - ap*rel - (1-rel)*self.base
+
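
The reliability-weighted AP loss above is linear in the predicted confidence, `loss = 1 - base - rel*(ap - base)`, so the optimal reliability is 1 wherever a pixel's AP exceeds `base` (0.5 by default) and 0 otherwise. A quick numeric check:

```python
base = 0.5
for ap in (0.3, 0.5, 0.8):
    for rel in (0.0, 1.0):
        loss = 1 - ap * rel - (1 - rel) * base
        print(f"ap={ap:.1f} rel={rel:.0f} loss={loss:.2f}")
# ap=0.3 prefers rel=0 (0.50 < 0.70); ap=0.8 prefers rel=1 (0.20 < 0.50)
```
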
diff --git a/imcui/third_party/DarkFeat/nets/sampler.py b/imcui/third_party/DarkFeat/nets/sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..b732a3671872d5675be9826f76b0818d3b99d466
--- /dev/null
+++ b/imcui/third_party/DarkFeat/nets/sampler.py
@@ -0,0 +1,160 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+
+from .geom import rnd_sample, interpolate
+
+class NghSampler2 (nn.Module):
+ """ Similar to NghSampler, but doesn't warp the 2nd image.
+ Distance to GT => 0 ... pos_d ... neg_d ... ngh
+ Pixel label => + + + + + + 0 0 - - - - - - -
+
+ Subsample on query side: if > 0, regular grid
+ < 0, random points
+ In both cases, the number of query points is = W*H/subq**2
+ """
+ def __init__(self, ngh, subq=1, subd=1, pos_d=0, neg_d=2, border=None,
+ maxpool_pos=True, subd_neg=0):
+ nn.Module.__init__(self)
+ assert 0 <= pos_d < neg_d <= (ngh if ngh else 99)
+ self.ngh = ngh
+ self.pos_d = pos_d
+ self.neg_d = neg_d
+ assert subd <= ngh or ngh == 0
+ assert subq != 0
+ self.sub_q = subq
+ self.sub_d = subd
+ self.sub_d_neg = subd_neg
+ if border is None: border = ngh
+ assert border >= ngh, 'border has to be larger than ngh'
+ self.border = border
+ self.maxpool_pos = maxpool_pos
+ self.precompute_offsets()
+
+ def precompute_offsets(self):
+ pos_d2 = self.pos_d**2
+ neg_d2 = self.neg_d**2
+ rad2 = self.ngh**2
+ rad = (self.ngh//self.sub_d) * self.ngh # make an integer multiple
+ pos = []
+ neg = []
+ for j in range(-rad, rad+1, self.sub_d):
+ for i in range(-rad, rad+1, self.sub_d):
+ d2 = i*i + j*j
+ if d2 <= pos_d2:
+ pos.append( (i,j) )
+ elif neg_d2 <= d2 <= rad2:
+ neg.append( (i,j) )
+
+ self.register_buffer('pos_offsets', torch.LongTensor(pos).view(-1,2).t())
+ self.register_buffer('neg_offsets', torch.LongTensor(neg).view(-1,2).t())
+
+ def gen_grid(self, step, B, H, W, dev):
+ b1 = torch.arange(B, device=dev)
+ if step > 0:
+ # regular grid
+ x1 = torch.arange(self.border, W-self.border, step, device=dev)
+ y1 = torch.arange(self.border, H-self.border, step, device=dev)
+ H1, W1 = len(y1), len(x1)
+ x1 = x1[None,None,:].expand(B,H1,W1).reshape(-1)
+ y1 = y1[None,:,None].expand(B,H1,W1).reshape(-1)
+ b1 = b1[:,None,None].expand(B,H1,W1).reshape(-1)
+ shape = (B, H1, W1)
+ else:
+ # randomly spread
+ n = (H - 2*self.border) * (W - 2*self.border) // step**2
+ x1 = torch.randint(self.border, W-self.border, (n,), device=dev)
+ y1 = torch.randint(self.border, H-self.border, (n,), device=dev)
+ x1 = x1[None,:].expand(B,n).reshape(-1)
+ y1 = y1[None,:].expand(B,n).reshape(-1)
+ b1 = b1[:,None].expand(B,n).reshape(-1)
+ shape = (B, n)
+ return b1, y1, x1, shape
+
+ def forward(self, feat0, feat1, conf0, conf1, pos0, pos1, B, H, W, N=2500):
+ pscores_ls, nscores_ls, distractors_ls = [], [], []
+ valid_feat0_ls = []
+ valid_pos1_ls, valid_pos2_ls = [], []
+ qconf_ls = []
+ mask_ls = []
+
+ for i in range(B):
+ # positions in the first image
+ tmp_mask = (pos0[i][:, 1] >= self.border) * (pos0[i][:, 1] < W-self.border) \
+ * (pos0[i][:, 0] >= self.border) * (pos0[i][:, 0] < H-self.border)
+
+ selected_pos0 = pos0[i][tmp_mask]
+ selected_pos1 = pos1[i][tmp_mask]
+ valid_pos0, valid_pos1 = rnd_sample([selected_pos0, selected_pos1], N)
+
+ # sample features from first image
+ valid_feat0 = interpolate(valid_pos0 / 4, feat0[i]) # [N, 128]
+ valid_feat0 = F.normalize(valid_feat0, p=2, dim=-1) # [N, 128]
+ qconf = interpolate(valid_pos0 / 4, conf0[i])
+
+ # sample GT from second image
+ mask = (valid_pos1[:, 1] >= 0) * (valid_pos1[:, 1] < W) \
+ * (valid_pos1[:, 0] >= 0) * (valid_pos1[:, 0] < H)
+
+ def clamp(xy):
+ xy = xy
+ torch.clamp(xy[0], 0, H-1, out=xy[0])
+ torch.clamp(xy[1], 0, W-1, out=xy[1])
+ return xy
+
+ # compute positive scores
+ valid_pos1p = clamp(valid_pos1.t()[:,None,:] + self.pos_offsets[:,:,None].to(valid_pos1.device)) # [2, 29, N]
+ valid_pos1p = valid_pos1p.permute(1, 2, 0).reshape(-1, 2) # [29, N, 2] -> [29*N, 2]
+ valid_feat1p = interpolate(valid_pos1p / 4, feat1[i]).reshape(self.pos_offsets.shape[-1], -1, 128) # [29, N, 128]
+ valid_feat1p = F.normalize(valid_feat1p, p=2, dim=-1) # [29, N, 128]
+
+ pscores = (valid_feat0[None,:,:] * valid_feat1p).sum(dim=-1).t() # [N, 29]
+ pscores, pos = pscores.max(dim=1, keepdim=True)
+ sel = clamp(valid_pos1.t() + self.pos_offsets[:,pos.view(-1)].to(valid_pos1.device))
+ qconf = (qconf + interpolate(sel.t() / 4, conf1[i]))/2
+
+ # compute negative scores
+ valid_pos1n = clamp(valid_pos1.t()[:,None,:] + self.neg_offsets[:,:,None].to(valid_pos1.device)) # [2, 29, N]
+ valid_pos1n = valid_pos1n.permute(1, 2, 0).reshape(-1, 2) # [29, N, 2] -> [29*N, 2]
+ valid_feat1n = interpolate(valid_pos1n / 4, feat1[i]).reshape(self.neg_offsets.shape[-1], -1, 128) # [29, N, 128]
+ valid_feat1n = F.normalize(valid_feat1n, p=2, dim=-1) # [29, N, 128]
+ nscores = (valid_feat0[None,:,:] * valid_feat1n).sum(dim=-1).t() # [N, 29]
+
+ if self.sub_d_neg:
+ valid_pos2 = rnd_sample([selected_pos1], N)[0]
+ distractors = interpolate(valid_pos2 / 4, feat1[i])
+ distractors = F.normalize(distractors, p=2, dim=-1)
+
+ pscores_ls.append(pscores)
+ nscores_ls.append(nscores)
+ distractors_ls.append(distractors)
+ valid_feat0_ls.append(valid_feat0)
+ valid_pos1_ls.append(valid_pos1)
+ valid_pos2_ls.append(valid_pos2)
+ qconf_ls.append(qconf)
+ mask_ls.append(mask)
+
+ N = np.min([len(i) for i in qconf_ls])
+
+ # merge batches
+ qconf = torch.stack([i[:N] for i in qconf_ls], dim=0).squeeze(-1)
+ mask = torch.stack([i[:N] for i in mask_ls], dim=0)
+ pscores = torch.cat([i[:N] for i in pscores_ls], dim=0)
+ nscores = torch.cat([i[:N] for i in nscores_ls], dim=0)
+ distractors = torch.cat([i[:N] for i in distractors_ls], dim=0)
+ valid_feat0 = torch.cat([i[:N] for i in valid_feat0_ls], dim=0)
+ valid_pos1 = torch.cat([i[:N] for i in valid_pos1_ls], dim=0)
+ valid_pos2 = torch.cat([i[:N] for i in valid_pos2_ls], dim=0)
+
+ dscores = torch.matmul(valid_feat0, distractors.t())
+ dis2 = (valid_pos2[:, 1] - valid_pos1[:, 1][:,None])**2 + (valid_pos2[:, 0] - valid_pos1[:, 0][:,None])**2
+ b = torch.arange(B, device=dscores.device)[:,None].expand(B, N).reshape(-1)
+ dis2 += (b != b[:,None]).long() * self.neg_d**2
+ dscores[dis2 < self.neg_d**2] = 0
+ scores = torch.cat((pscores, nscores, dscores), dim=1)
+
+ gt = scores.new_zeros(scores.shape, dtype=torch.uint8)
+ gt[:, :pscores.shape[1]] = 1
+
+ return scores, gt, mask, qconf
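
The docstring of `NghSampler2` describes the query subsampling implemented by `gen_grid()`: with `subq < 0` the queries are `(H-2*border)*(W-2*border)/subq**2` random points drawn inside the border margin. A standalone sketch with toy sizes (480x480, `border=16`, `subq=-8`):

```python
import torch

B, H, W, border, subq = 2, 480, 480, 16, -8
n = (H - 2 * border) * (W - 2 * border) // subq ** 2
x1 = torch.randint(border, W - border, (n,))
y1 = torch.randint(border, H - border, (n,))
b1 = torch.arange(B)[:, None].expand(B, n).reshape(-1)  # batch index per query
print(n, int(x1.min()) >= border, int(x1.max()) < W - border)  # 3136 True True
```
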
diff --git a/imcui/third_party/DarkFeat/nets/score.py b/imcui/third_party/DarkFeat/nets/score.py
new file mode 100644
index 0000000000000000000000000000000000000000..a78cf1c893bc338c12803697d55e121a75171f2c
--- /dev/null
+++ b/imcui/third_party/DarkFeat/nets/score.py
@@ -0,0 +1,116 @@
+import torch
+import torch.nn.functional as F
+import numpy as np
+
+from .geom import gather_nd
+
+# input: [batch_size, C, H, W]
+# output: [batch_size, C, H, W], [batch_size, C, H, W]
+def peakiness_score(inputs, moving_instance_max, ksize=3, dilation=1):
+ inputs = inputs / moving_instance_max
+
+ batch_size, C, H, W = inputs.shape
+
+ pad_size = ksize // 2 + (dilation - 1)
+ kernel = torch.ones([C, 1, ksize, ksize], device=inputs.device) / (ksize * ksize)
+
+ pad_inputs = F.pad(inputs, [pad_size] * 4, mode='reflect')
+
+ avg_spatial_inputs = F.conv2d(
+ pad_inputs,
+ kernel,
+ stride=1,
+ dilation=dilation,
+ padding=0,
+ groups=C
+ )
+ avg_channel_inputs = torch.mean(inputs, axis=1, keepdim=True) # channel dimension is 1
+
+ alpha = F.softplus(inputs - avg_spatial_inputs)
+ beta = F.softplus(inputs - avg_channel_inputs)
+
+ return alpha, beta
+
+
+# input: score_map [batch_size, 1, H, W]
+# output: indices [2, k, 2], scores [2, k]
+def extract_kpts(score_map, k=256, score_thld=0, edge_thld=0, nms_size=3, eof_size=5):
+ h = score_map.shape[2]
+ w = score_map.shape[3]
+
+ mask = score_map > score_thld
+ if nms_size > 0:
+ nms_mask = F.max_pool2d(score_map, kernel_size=nms_size, stride=1, padding=nms_size//2)
+ nms_mask = torch.eq(score_map, nms_mask)
+ mask = torch.logical_and(nms_mask, mask)
+ if eof_size > 0:
+ eof_mask = torch.ones((1, 1, h - 2 * eof_size, w - 2 * eof_size), dtype=torch.float32, device=score_map.device)
+ eof_mask = F.pad(eof_mask, [eof_size] * 4, value=0)
+ eof_mask = eof_mask.bool()
+ mask = torch.logical_and(eof_mask, mask)
+ if edge_thld > 0:
+ non_edge_mask = edge_mask(score_map, 1, dilation=3, edge_thld=edge_thld)
+ mask = torch.logical_and(non_edge_mask, mask)
+
+ bs = score_map.shape[0]
+ if bs is None:
+ indices = torch.nonzero(mask)[0]
+ scores = gather_nd(score_map, indices)[0]
+ sample = torch.sort(scores, descending=True)[1][0:k]
+ indices = indices[sample].unsqueeze(0)
+ scores = scores[sample].unsqueeze(0)
+ else:
+ indices = []
+ scores = []
+ for i in range(bs):
+ tmp_mask = mask[i][0]
+ tmp_score_map = score_map[i][0]
+ tmp_indices = torch.nonzero(tmp_mask)
+ tmp_scores = gather_nd(tmp_score_map, tmp_indices)
+ tmp_sample = torch.sort(tmp_scores, descending=True)[1][0:k]
+ tmp_indices = tmp_indices[tmp_sample]
+ tmp_scores = tmp_scores[tmp_sample]
+ indices.append(tmp_indices)
+ scores.append(tmp_scores)
+ try:
+ indices = torch.stack(indices, dim=0)
+ scores = torch.stack(scores, dim=0)
+ except:
+ min_num = np.min([len(i) for i in indices])
+ indices = torch.stack([i[:min_num] for i in indices], dim=0)
+ scores = torch.stack([i[:min_num] for i in scores], dim=0)
+ return indices, scores
+
+
+def edge_mask(inputs, n_channel, dilation=1, edge_thld=5):
+ b, c, h, w = inputs.size()
+ device = inputs.device
+
+ dii_filter = torch.tensor(
+ [[0, 1., 0], [0, -2., 0], [0, 1., 0]]
+ ).view(1, 1, 3, 3)
+ dij_filter = 0.25 * torch.tensor(
+ [[1., 0, -1.], [0, 0., 0], [-1., 0, 1.]]
+ ).view(1, 1, 3, 3)
+ djj_filter = torch.tensor(
+ [[0, 0, 0], [1., -2., 1.], [0, 0, 0]]
+ ).view(1, 1, 3, 3)
+
+ dii = F.conv2d(
+ inputs.view(-1, 1, h, w), dii_filter.to(device), padding=dilation, dilation=dilation
+ ).view(b, c, h, w)
+ dij = F.conv2d(
+ inputs.view(-1, 1, h, w), dij_filter.to(device), padding=dilation, dilation=dilation
+ ).view(b, c, h, w)
+ djj = F.conv2d(
+ inputs.view(-1, 1, h, w), djj_filter.to(device), padding=dilation, dilation=dilation
+ ).view(b, c, h, w)
+
+ det = dii * djj - dij * dij
+ tr = dii + djj
+ del dii, dij, djj
+
+ threshold = (edge_thld + 1) ** 2 / edge_thld
+ is_not_edge = torch.min(tr * tr / det <= threshold, det > 0)
+
+ return is_not_edge
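
A minimal standalone sketch of the NMS step inside `extract_kpts()`: a pixel survives only if it equals the maximum of its `nms_size` x `nms_size` neighbourhood and clears the score threshold.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
score_map = torch.rand(1, 1, 8, 8)
nms_size, score_thld = 3, 0.5

local_max = F.max_pool2d(score_map, kernel_size=nms_size, stride=1,
                         padding=nms_size // 2)
mask = (score_map == local_max) & (score_map > score_thld)
kpts = mask[0, 0].nonzero()            # (row, col) indices of surviving peaks
scores = score_map[0, 0][mask[0, 0]]   # their detector scores
print(kpts.shape[0], "keypoints kept")
```
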
diff --git a/imcui/third_party/DarkFeat/pose_estimation.py b/imcui/third_party/DarkFeat/pose_estimation.py
new file mode 100644
index 0000000000000000000000000000000000000000..c87877191e7e31c3bc0a362d7d481dfd5d4b5757
--- /dev/null
+++ b/imcui/third_party/DarkFeat/pose_estimation.py
@@ -0,0 +1,137 @@
+import argparse
+import cv2
+import numpy as np
+import os
+import math
+import subprocess
+from tqdm import tqdm
+
+
+def compute_essential(matched_kp1, matched_kp2, K):
+ pts1 = cv2.undistortPoints(matched_kp1,cameraMatrix=K, distCoeffs = (-0.117918271740560,0.075246403574314,0,0))
+ pts2 = cv2.undistortPoints(matched_kp2,cameraMatrix=K, distCoeffs = (-0.117918271740560,0.075246403574314,0,0))
+ K_1 = np.eye(3)
+ # Estimate the homography between the matches using RANSAC
+ ransac_model, ransac_inliers = cv2.findEssentialMat(pts1, pts2, K_1, method=cv2.RANSAC, prob=0.999, threshold=0.001, maxIters=10000)
+ if ransac_inliers is None or ransac_model.shape != (3,3):
+ ransac_inliers = np.array([])
+ ransac_model = None
+ return ransac_model, ransac_inliers, pts1, pts2
+
+
+def compute_error(R_GT,t_GT,E,pts1_norm, pts2_norm, inliers):
+ """Compute the angular error between the estimated and ground-truth relative pose.
+ Keyword arguments:
+ R_GT, t_GT -- ground truth rotation matrix and translation (column vector)
+ E -- estimated essential matrix
+ pts1_norm, pts2_norm -- normalized matched keypoints in both images
+ inliers -- RANSAC inlier mask from the essential-matrix estimation
+ Returns (dR, dT), the rotation and translation angular errors in degrees.
+ """
+
+ inliers = inliers.ravel()
+ R = np.eye(3)
+ t = np.zeros((3,1))
+ sst = True
+ try:
+ _, R, t, _ = cv2.recoverPose(E, pts1_norm, pts2_norm, np.eye(3), inliers)
+ except:
+ sst = False
+ # calculate angle between provided rotations
+ #
+ if sst:
+ dR = np.matmul(R, np.transpose(R_GT))
+ dR = cv2.Rodrigues(dR)[0]
+ dR = np.linalg.norm(dR) * 180 / math.pi
+
+ # calculate angle between provided translations
+ dT = float(np.dot(t_GT.T, t))
+ dT /= float(np.linalg.norm(t_GT))
+
+ if dT > 1 or dT < -1:
+ print("Domain warning! dT:",dT)
+ dT = max(-1,min(1,dT))
+ dT = math.acos(dT) * 180 / math.pi
+ dT = np.minimum(dT, 180 - dT) # ambiguity of E estimation
+ else:
+ dR, dT = 180.0, 180.0
+ return dR, dT
+
+
+def pose_evaluation(result_base_dir, dark_name1, dark_name2, enhancer, K, R_GT, t_GT):
+ try:
+ m_kp1 = np.load(result_base_dir+enhancer+'/DarkFeat/POINT_1/'+dark_name1)
+ m_kp2 = np.load(result_base_dir+enhancer+'/DarkFeat/POINT_2/'+dark_name2)
+ except:
+ return 180.0, 180.0
+ try:
+ E, inliers, pts1, pts2 = compute_essential(m_kp1, m_kp2, K)
+ except:
+ E, inliers, pts1, pts2 = np.zeros((3, 3)), np.array([]), None, None
+ dR, dT = compute_error(R_GT, t_GT, E, pts1, pts2, inliers)
+ return dR, dT
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--histeq', action='store_true')
+ parser.add_argument('--dataset_dir', type=str, default='/data/hyz/MID/')
+ opt = parser.parse_args()
+
+ sizer = (960, 640)
+ focallength_x = 4.504986436499113e+03/(6744/sizer[0])
+ focallength_y = 4.513311442889859e+03/(4502/sizer[1])
+ K = np.eye(3)
+ K[0,0] = focallength_x
+ K[1,1] = focallength_y
+ K[0,2] = 3.363322177533149e+03/(6744/sizer[0])
+ K[1,2] = 2.291824660547715e+03/(4502/sizer[1])
+ Kinv = np.linalg.inv(K)
+ Kinvt = np.transpose(Kinv)
+
+ PE_MT = np.zeros((6, 8))
+
+ enhancer = 'None' if not opt.histeq else 'HistEQ'
+
+ for scene in ['Indoor', 'Outdoor']:
+ dir_base = opt.dataset_dir + '/' + scene + '/'
+ base_save = 'result_errors/' + scene + '/'
+ pair_list = sorted(os.listdir(dir_base))
+
+ os.makedirs(base_save, exist_ok=True)
+
+ for pair in tqdm(pair_list):
+ # Indoor pairs with index <= 17 use the longer shutter-speed set (option 0)
+ option = 1
+ if scene != 'Outdoor' and int(pair[4:]) <= 17:
+ option = 0
+ name = []
+ files = sorted(os.listdir(dir_base+pair))
+ for file_ in files:
+ if file_.endswith('.cr2'):
+ name.append(file_[0:9])
+ ISO = ['00100', '00200', '00400', '00800', '01600', '03200', '06400', '12800']
+ if option == 1:
+ Shutter_speed = ['0.005','0.01','0.025','0.05','0.17','0.5']
+ else:
+ Shutter_speed = ['0.01','0.02','0.05','0.1','0.3','1']
+
+ E_GT = np.load(dir_base+pair+'/GT_Correspondence/'+'E_estimated.npy')
+ F_GT = np.dot(np.dot(Kinvt,E_GT),Kinv)
+ R_GT = np.load(dir_base+pair+'/GT_Correspondence/'+'R_GT.npy')
+ t_GT = np.load(dir_base+pair+'/GT_Correspondence/'+'T_GT.npy')
+ result_base_dir ='result/' +scene+'/'+pair+'/'
+ for iso in ISO:
+ for ex in Shutter_speed:
+ dark_name1 = name[0]+iso+'_'+ex+'_'+scene+'.npy'
+ dark_name2 = name[1]+iso+'_'+ex+'_'+scene+'.npy'
+
+ dr, dt = pose_evaluation(result_base_dir,dark_name1,dark_name2,enhancer,K,R_GT,t_GT)
+ PE_MT[Shutter_speed.index(ex),ISO.index(iso)] = max(dr, dt)
+
+ subprocess.check_output(['mkdir', '-p', base_save + pair + f'/{enhancer}/'])
+ np.save(base_save + pair + f'/{enhancer}/Pose_error_DarkFeat.npy', PE_MT)
+
\ No newline at end of file
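
A quick standalone sanity check of the angular-error formulas used in `compute_error()` above: an estimate rotated 10 degrees away from the ground truth about the z-axis, with the true (unit) translation direction, should give dR of about 10 and dT of about 0.

```python
import math
import cv2
import numpy as np

angle = math.radians(10)
R_GT, t_GT = np.eye(3), np.array([[1.0], [0.0], [0.0]])
R = cv2.Rodrigues(np.array([0.0, 0.0, angle]).reshape(3, 1))[0]
t = t_GT.copy()  # recoverPose returns a unit translation, as assumed here

dR = np.linalg.norm(cv2.Rodrigues(R @ R_GT.T)[0]) * 180 / math.pi
cos_t = np.clip((t_GT.T @ t).item() / np.linalg.norm(t_GT), -1, 1)
dT = math.degrees(math.acos(cos_t))
print(round(dR, 3), round(dT, 3))  # 10.0 0.0
```
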
diff --git a/imcui/third_party/DarkFeat/raw_preprocess.py b/imcui/third_party/DarkFeat/raw_preprocess.py
new file mode 100644
index 0000000000000000000000000000000000000000..226155a84e97f15782d3650f4ef6b3fa1880e07b
--- /dev/null
+++ b/imcui/third_party/DarkFeat/raw_preprocess.py
@@ -0,0 +1,62 @@
+import glob
+import rawpy
+import cv2
+import os
+import numpy as np
+import colour_demosaicing
+from tqdm import tqdm
+
+
+def process_raw(args, path, w_new, h_new):
+ raw = rawpy.imread(str(path)).raw_image_visible
+ if '_00200_' in str(path) or '_00100_' in str(path):
+ raw = np.clip(raw.astype('float32') - 512, 0, 65535)
+ else:
+ raw = np.clip(raw.astype('float32') - 2048, 0, 65535)
+ img = colour_demosaicing.demosaicing_CFA_Bayer_bilinear(raw, 'RGGB').astype('float32')
+ img = np.clip(img, 0, 16383)
+
+ # HistEQ start
+ if args.histeq:
+ img2 = np.zeros_like(img)
+ for i in range(3):
+ hist,bins = np.histogram(img[..., i].flatten(),16384,[0,16384])
+ cdf = hist.cumsum()
+ cdf_normalized = cdf * float(hist.max()) / cdf.max()
+ cdf_m = np.ma.masked_equal(cdf,0)
+ cdf_m = (cdf_m - cdf_m.min())*16383/(cdf_m.max()-cdf_m.min())
+ cdf = np.ma.filled(cdf_m,0).astype('uint16')
+ img2[..., i] = cdf[img[..., i].astype('int16')]
+ img[..., i] = img2[..., i].astype('float32')
+ # HistEQ end
+
+ m = img.mean()
+ d = np.abs(img - img.mean()).mean()
+ img = (img - m + 2 * d) / (4 * d) * 255
+ image = np.clip(img, 0, 255)
+
+ image = cv2.resize(image.astype('float32'), (w_new, h_new), interpolation=cv2.INTER_AREA)
+
+ if args.histeq:
+ path=str(path)
+ os.makedirs('/'.join(path.split('/')[:-2]+[path.split('/')[-2]+'-npy']), exist_ok=True)
+ np.save('/'.join(path.split('/')[:-2]+[path.split('/')[-2]+'-npy']+[path.split('/')[-1].replace('cr2','npy')]), image)
+ else:
+ path=str(path)
+ os.makedirs('/'.join(path.split('/')[:-2]+[path.split('/')[-2]+'-npy-nohisteq']), exist_ok=True)
+ np.save('/'.join(path.split('/')[:-2]+[path.split('/')[-2]+'-npy-nohisteq']+[path.split('/')[-1].replace('cr2','npy')]), image)
+
+
+if __name__ == '__main__':
+ import argparse
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--H', type=int, default=int(640))
+ parser.add_argument('--W', type=int, default=int(960))
+ parser.add_argument('--histeq', action='store_true')
+ parser.add_argument('--dataset_dir', type=str, default='/data/hyz/MID/')
+ args = parser.parse_args()
+
+ path_ls = glob.glob(args.dataset_dir + '/*/pair*/?????/*')
+ for path in tqdm(path_ls):
+ process_raw(args, path, args.W, args.H)
+
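
The brightness normalization above centers the demosaiced raw image at its mean, shifts by twice the mean absolute deviation d, and scales a 4d window onto the 8-bit range, so the mean lands near mid-grey. A toy check on synthetic values:

```python
import numpy as np

rng = np.random.default_rng(0)
img = rng.normal(loc=900.0, scale=120.0, size=(64, 64)).astype('float32')

m = img.mean()
d = np.abs(img - m).mean()
out = np.clip((img - m + 2 * d) / (4 * d) * 255, 0, 255)
print(round(float(out.mean()), 1))  # ~127.5: the mean maps to mid-grey
```
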
diff --git a/imcui/third_party/DarkFeat/read_error.py b/imcui/third_party/DarkFeat/read_error.py
new file mode 100644
index 0000000000000000000000000000000000000000..406b92dbd3877a11e51aebc3a705cd8d8d17e173
--- /dev/null
+++ b/imcui/third_party/DarkFeat/read_error.py
@@ -0,0 +1,56 @@
+import os
+import numpy as np
+import subprocess
+
+# def ratio(losses, thresholds=[1,2,3,4,5,6,7,8,9,10]):
+def ratio(losses, thresholds=[5,10]):
+ return [
+ '{:.3f}'.format(np.mean(losses < threshold))
+ for threshold in thresholds
+ ]
+
+if __name__ == '__main__':
+ scene = 'Indoor'
+ dir_base = 'result_errors/Indoor/'
+ save_pt = 'resultfinal_errors/Indoor/'
+
+ subprocess.check_output(['mkdir', '-p', save_pt])
+
+ with open(save_pt +'ratio_methods_'+scene+'.txt','w') as f:
+ f.write('5deg 10deg'+'\n')
+ pair_list = os.listdir(dir_base)
+ enhancer = os.listdir(dir_base+'/pair9/')
+ for method in enhancer:
+ pose_error_list = sorted(os.listdir(dir_base+'/pair9/'+method))
+ for pose_error in pose_error_list:
+ error_array = np.expand_dims(np.zeros((6, 8)),axis=2)
+ for pair in pair_list:
+ try:
+ error = np.expand_dims(np.load(dir_base+'/'+pair+'/'+method+'/'+pose_error),axis=2)
+ except:
+ print('error in', dir_base+'/'+pair+'/'+method+'/'+pose_error)
+ continue
+ error_array = np.concatenate((error_array,error),axis=2)
+ ratio_result = ratio(error_array[:,:,1::].flatten())
+ f.write(method + '_' + pose_error[11:-4] +' '+' '.join([str(i) for i in ratio_result])+"\n")
+
+
+ scene = 'Outdoor'
+ dir_base = 'result_errors/Outdoor/'
+ save_pt = 'resultfinal_errors/Outdoor/'
+
+ subprocess.check_output(['mkdir', '-p', save_pt])
+
+ with open(save_pt +'ratio_methods_'+scene+'.txt','w') as f:
+ f.write('5deg 10deg'+'\n')
+ pair_list = os.listdir(dir_base)
+ enhancer = os.listdir(dir_base+'/pair9/')
+ for method in enhancer:
+ pose_error_list = sorted(os.listdir(dir_base+'/pair9/'+method))
+ for pose_error in pose_error_list:
+ error_array = np.expand_dims(np.zeros((6, 8)),axis=2)
+ for pair in pair_list:
+ error = np.expand_dims(np.load(dir_base+'/'+pair+'/'+method+'/'+pose_error),axis=2)
+ error_array = np.concatenate((error_array,error),axis=2)
+ ratio_result = ratio(error_array[:,:,1::].flatten())
+ f.write(method + '_' + pose_error[11:-4] +' '+' '.join([str(i) for i in ratio_result])+"\n")
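
The `ratio()` helper above reports, for each threshold (5 and 10 degrees), the fraction of pose errors below it. A toy example of the metric:

```python
import numpy as np

def ratio(losses, thresholds=(5, 10)):
    return ['{:.3f}'.format(np.mean(losses < t)) for t in thresholds]

errors = np.array([1.2, 4.9, 7.0, 12.0, 180.0])
print(ratio(errors))  # ['0.400', '0.600']
```
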
diff --git a/imcui/third_party/DarkFeat/run.py b/imcui/third_party/DarkFeat/run.py
new file mode 100644
index 0000000000000000000000000000000000000000..0e4c87053d2970fc927d8991aa0dab208f3c4917
--- /dev/null
+++ b/imcui/third_party/DarkFeat/run.py
@@ -0,0 +1,48 @@
+import cv2
+import yaml
+import argparse
+import os
+from torch.utils.data import DataLoader
+
+from datasets.gl3d_dataset import GL3DDataset
+from trainer import Trainer
+from trainer_single_norel import SingleTrainerNoRel
+from trainer_single import SingleTrainer
+
+
+if __name__ == '__main__':
+ # add argument parser
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--config', type=str, default='./configs/config.yaml')
+ parser.add_argument('--dataset_dir', type=str, default='/mnt/nvme2n1/hyz/data/GL3D')
+ parser.add_argument('--data_split', type=str, default='comb')
+ parser.add_argument('--is_training', type=bool, default=True)
+ parser.add_argument('--job_name', type=str, default='')
+ parser.add_argument('--gpu', type=str, default='0')
+ parser.add_argument('--start_cnt', type=int, default=0)
+ parser.add_argument('--stage', type=int, default=1)
+ args = parser.parse_args()
+
+ # load global config
+ with open(args.config, 'r') as f:
+ config = yaml.load(f, Loader=yaml.FullLoader)
+
+ # setup dataloader
+ dataset = GL3DDataset(args.dataset_dir, config['network'], args.data_split, is_training=args.is_training)
+ data_loader = DataLoader(dataset, batch_size=2, shuffle=True, num_workers=4)
+
+ os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
+
+
+ if args.stage == 1:
+ trainer = SingleTrainerNoRel(config, f'cuda:0', data_loader, args.job_name, args.start_cnt)
+ elif args.stage == 2:
+ trainer = SingleTrainer(config, f'cuda:0', data_loader, args.job_name, args.start_cnt)
+ elif args.stage == 3:
+ trainer = Trainer(config, f'cuda:0', data_loader, args.job_name, args.start_cnt)
+ else:
+ raise NotImplementedError()
+
+ trainer.train()
+
+
\ No newline at end of file
diff --git a/imcui/third_party/DarkFeat/trainer.py b/imcui/third_party/DarkFeat/trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..e6ff2af9608e934b6899058d756bb2ab7d0fee2d
--- /dev/null
+++ b/imcui/third_party/DarkFeat/trainer.py
@@ -0,0 +1,348 @@
+import os
+import cv2
+import time
+import yaml
+import torch
+import datetime
+from tensorboardX import SummaryWriter
+import torchvision.transforms as tvf
+import torch.nn as nn
+import torch.nn.functional as F
+
+from nets.geom import getK, getWarp, _grid_positions, getWarpNoValidate
+from nets.loss import make_detector_loss, make_noise_score_map_loss
+from nets.score import extract_kpts
+from nets.multi_sampler import MultiSampler
+from nets.noise_reliability_loss import MultiPixelAPLoss
+from datasets.noise_simulator import NoiseSimulator
+from nets.l2net import Quad_L2Net
+
+
+class Trainer:
+ def __init__(self, config, device, loader, job_name, start_cnt):
+ self.config = config
+ self.device = device
+ self.loader = loader
+
+ # tensorboard writer construction
+ os.makedirs('./runs/', exist_ok=True)
+ if job_name != '':
+ self.log_dir = f'runs/{job_name}'
+ else:
+ self.log_dir = f'runs/{datetime.datetime.now().strftime("%m-%d-%H%M%S")}'
+
+ self.writer = SummaryWriter(self.log_dir)
+ with open(f'{self.log_dir}/config.yaml', 'w') as f:
+ yaml.dump(config, f)
+
+ if config['network']['input_type'] == 'gray':
+ self.model = eval(f'{config["network"]["model"]}(inchan=1)').to(device)
+ elif config['network']['input_type'] == 'rgb' or config['network']['input_type'] == 'raw-demosaic':
+ self.model = eval(f'{config["network"]["model"]}(inchan=3)').to(device)
+ elif config['network']['input_type'] == 'raw':
+ self.model = eval(f'{config["network"]["model"]}(inchan=4)').to(device)
+ else:
+ raise NotImplementedError()
+
+ # noise maker
+ self.noise_maker = NoiseSimulator(device)
+
+ # reliability map conv
+ self.model.clf = nn.Conv2d(128, 2, kernel_size=1).cuda()
+
+ # load model
+ self.cnt = 0
+ if start_cnt != 0:
+ self.model.load_state_dict(torch.load(f'{self.log_dir}/model_{start_cnt:06d}.pth', map_location=device))
+ self.cnt = start_cnt + 1
+
+ # sampler
+ sampler = MultiSampler(ngh=7, subq=-8, subd=1, pos_d=3, neg_d=5, border=16,
+ subd_neg=-8,maxpool_pos=True).to(device)
+ self.reliability_relitive_loss = MultiPixelAPLoss(sampler, nq=20).to(device)
+
+
+ # optimizer and scheduler
+ if self.config['training']['optimizer'] == 'SGD':
+ self.optimizer = torch.optim.SGD(
+ [{'params': self.model.parameters(), 'initial_lr': self.config['training']['lr']}],
+ lr=self.config['training']['lr'],
+ momentum=self.config['training']['momentum'],
+ weight_decay=self.config['training']['weight_decay'],
+ )
+ elif self.config['training']['optimizer'] == 'Adam':
+ self.optimizer = torch.optim.Adam(
+ [{'params': self.model.parameters(), 'initial_lr': self.config['training']['lr']}],
+ lr=self.config['training']['lr'],
+ weight_decay=self.config['training']['weight_decay']
+ )
+ else:
+ raise NotImplementedError()
+
+ self.lr_scheduler = torch.optim.lr_scheduler.StepLR(
+ self.optimizer,
+ step_size=self.config['training']['lr_step'],
+ gamma=self.config['training']['lr_gamma'],
+ last_epoch=start_cnt
+ )
+ for param_tensor in self.model.state_dict():
+ print(param_tensor, "\t", self.model.state_dict()[param_tensor].size())
+
+
+ def save(self, iter_num):
+ torch.save(self.model.state_dict(), f'{self.log_dir}/model_{iter_num:06d}.pth')
+
+ def load(self, path):
+ self.model.load_state_dict(torch.load(path))
+
+ def train(self):
+ self.model.train()
+
+ for epoch in range(2):
+ for batch_idx, inputs in enumerate(self.loader):
+ self.optimizer.zero_grad()
+ t = time.time()
+
+ # preprocess and add noise
+ img0_ori, noise_img0_ori = self.preprocess_noise_pair(inputs['img0'], self.cnt)
+ img1_ori, noise_img1_ori = self.preprocess_noise_pair(inputs['img1'], self.cnt)
+
+ img0 = img0_ori.permute(0, 3, 1, 2).float().to(self.device)
+ img1 = img1_ori.permute(0, 3, 1, 2).float().to(self.device)
+ noise_img0 = noise_img0_ori.permute(0, 3, 1, 2).float().to(self.device)
+ noise_img1 = noise_img1_ori.permute(0, 3, 1, 2).float().to(self.device)
+
+ if self.config['network']['input_type'] == 'rgb':
+ # 3-channel rgb
+ RGB_mean = [0.485, 0.456, 0.406]
+ RGB_std = [0.229, 0.224, 0.225]
+ norm_RGB = tvf.Normalize(mean=RGB_mean, std=RGB_std)
+ img0 = norm_RGB(img0)
+ img1 = norm_RGB(img1)
+ noise_img0 = norm_RGB(noise_img0)
+ noise_img1 = norm_RGB(noise_img1)
+
+ elif self.config['network']['input_type'] == 'gray':
+ # 1-channel
+ img0 = torch.mean(img0, dim=1, keepdim=True)
+ img1 = torch.mean(img1, dim=1, keepdim=True)
+ noise_img0 = torch.mean(noise_img0, dim=1, keepdim=True)
+ noise_img1 = torch.mean(noise_img1, dim=1, keepdim=True)
+ norm_gray0 = tvf.Normalize(mean=img0.mean(), std=img0.std())
+ norm_gray1 = tvf.Normalize(mean=img1.mean(), std=img1.std())
+ img0 = norm_gray0(img0)
+ img1 = norm_gray1(img1)
+ noise_img0 = norm_gray0(noise_img0)
+ noise_img1 = norm_gray1(noise_img1)
+
+ elif self.config['network']['input_type'] == 'raw':
+ # 4-channel
+ pass
+
+ elif self.config['network']['input_type'] == 'raw-demosaic':
+ # 3-channel
+ pass
+
+ else:
+ raise NotImplementedError()
+
+ desc0, score_map0, _, _ = self.model(img0)
+ desc1, score_map1, _, _ = self.model(img1)
+
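+ # reliability confidence: softmax over the 2-channel 1x1 clf head applied to the squared descriptors, keeping channel 1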
+ conf0 = F.softmax(self.model.clf(torch.abs(desc0)**2.0), dim=1)[:,1:2]
+ conf1 = F.softmax(self.model.clf(torch.abs(desc1)**2.0), dim=1)[:,1:2]
+
+ noise_desc0, noise_score_map0, noise_at0, noise_att0 = self.model(noise_img0)
+ noise_desc1, noise_score_map1, noise_at1, noise_att1 = self.model(noise_img1)
+
+ noise_conf0 = F.softmax(self.model.clf(torch.abs(noise_desc0)**2.0), dim=1)[:,1:2]
+ noise_conf1 = F.softmax(self.model.clf(torch.abs(noise_desc1)**2.0), dim=1)[:,1:2]
+
+ cur_feat_size0 = torch.tensor(score_map0.shape[2:])
+ cur_feat_size1 = torch.tensor(score_map1.shape[2:])
+
+ desc0 = desc0.permute(0, 2, 3, 1)
+ desc1 = desc1.permute(0, 2, 3, 1)
+ score_map0 = score_map0.permute(0, 2, 3, 1)
+ score_map1 = score_map1.permute(0, 2, 3, 1)
+ noise_desc0 = noise_desc0.permute(0, 2, 3, 1)
+ noise_desc1 = noise_desc1.permute(0, 2, 3, 1)
+ noise_score_map0 = noise_score_map0.permute(0, 2, 3, 1)
+ noise_score_map1 = noise_score_map1.permute(0, 2, 3, 1)
+ conf0 = conf0.permute(0, 2, 3, 1)
+ conf1 = conf1.permute(0, 2, 3, 1)
+ noise_conf0 = noise_conf0.permute(0, 2, 3, 1)
+ noise_conf1 = noise_conf1.permute(0, 2, 3, 1)
+
+ r_K0 = getK(inputs['ori_img_size0'], cur_feat_size0, inputs['K0']).to(self.device)
+ r_K1 = getK(inputs['ori_img_size1'], cur_feat_size1, inputs['K1']).to(self.device)
+
+ pos0 = _grid_positions(
+ cur_feat_size0[0], cur_feat_size0[1], img0.shape[0]).to(self.device)
+
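+ # warp the dense grid positions into the other view: the unvalidated warp feeds the reliability loss, the depth-validated warp feeds the detector loss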
+ pos0_for_rel, pos1_for_rel, _ = getWarpNoValidate(
+ pos0, inputs['rel_pose'].to(self.device), inputs['depth0'].to(self.device),
+ r_K0, inputs['depth1'].to(self.device), r_K1, img0.shape[0])
+
+ pos0, pos1, _ = getWarp(
+ pos0, inputs['rel_pose'].to(self.device), inputs['depth0'].to(self.device),
+ r_K0, inputs['depth1'].to(self.device), r_K1, img0.shape[0])
+
+ reliab_loss_relative = self.reliability_relitive_loss(desc0, desc1, noise_desc0, noise_desc1, conf0, conf1, noise_conf0, noise_conf1, pos0_for_rel, pos1_for_rel, img0.shape[0], img0.shape[2], img0.shape[3])
+
+ det_structured_loss, det_accuracy = make_detector_loss(
+ pos0, pos1, desc0, desc1,
+ score_map0, score_map1, img0.shape[0],
+ self.config['network']['use_corr_n'],
+ self.config['network']['loss_type'],
+ self.config
+ )
+
+ det_structured_loss_noise, det_accuracy_noise = make_detector_loss(
+ pos0, pos1, noise_desc0, noise_desc1,
+ noise_score_map0, noise_score_map1, img0.shape[0],
+ self.config['network']['use_corr_n'],
+ self.config['network']['loss_type'],
+ self.config
+ )
+
+ indices0, scores0 = extract_kpts(
+ score_map0.permute(0, 3, 1, 2),
+ k=self.config['network']['det']['kpt_n'],
+ score_thld=self.config['network']['det']['score_thld'],
+ nms_size=self.config['network']['det']['nms_size'],
+ eof_size=self.config['network']['det']['eof_size'],
+ edge_thld=self.config['network']['det']['edge_thld']
+ )
+ indices1, scores1 = extract_kpts(
+ score_map1.permute(0, 3, 1, 2),
+ k=self.config['network']['det']['kpt_n'],
+ score_thld=self.config['network']['det']['score_thld'],
+ nms_size=self.config['network']['det']['nms_size'],
+ eof_size=self.config['network']['det']['eof_size'],
+ edge_thld=self.config['network']['det']['edge_thld']
+ )
+
+ noise_score_loss0, mask0 = make_noise_score_map_loss(score_map0, noise_score_map0, indices0, img0.shape[0], thld=0.1)
+ noise_score_loss1, mask1 = make_noise_score_map_loss(score_map1, noise_score_map1, indices1, img1.shape[0], thld=0.1)
+
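+ # total objective: detector loss on the clean and noisy pairs, plus weighted noise-consistency score losses and reliability terms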
+ total_loss = det_structured_loss + det_structured_loss_noise
+ total_loss += noise_score_loss0 / 2. * 1.
+ total_loss += noise_score_loss1 / 2. * 1.
+ total_loss += reliab_loss_relative[0] / 2. * 0.5
+ total_loss += reliab_loss_relative[1] / 2. * 0.5
+
+ self.writer.add_scalar("acc/normal_acc", det_accuracy, self.cnt)
+ self.writer.add_scalar("acc/noise_acc", det_accuracy_noise, self.cnt)
+ self.writer.add_scalar("loss/total_loss", total_loss, self.cnt)
+ self.writer.add_scalar("loss/noise_score_loss", (noise_score_loss0 + noise_score_loss1) / 2., self.cnt)
+ self.writer.add_scalar("loss/det_loss_normal", det_structured_loss, self.cnt)
+ self.writer.add_scalar("loss/det_loss_noise", det_structured_loss_noise, self.cnt)
+ print('iter={},\tloss={:.4f},\tacc={:.4f},\t{:.4f}s/iter'.format(self.cnt, total_loss, det_accuracy, time.time()-t))
+ # print(f'normal_loss: {det_structured_loss}, noise_loss: {det_structured_loss_noise}, reliab_loss: {reliab_loss_relative[0]}, {reliab_loss_relative[1]}')
+
+ if det_structured_loss != 0:
+ total_loss.backward()
+ self.optimizer.step()
+ self.lr_scheduler.step()
+
+ if self.cnt % 100 == 0:
+ noise_indices0, noise_scores0 = extract_kpts(
+ noise_score_map0.permute(0, 3, 1, 2),
+ k=self.config['network']['det']['kpt_n'],
+ score_thld=self.config['network']['det']['score_thld'],
+ nms_size=self.config['network']['det']['nms_size'],
+ eof_size=self.config['network']['det']['eof_size'],
+ edge_thld=self.config['network']['det']['edge_thld']
+ )
+ noise_indices1, noise_scores1 = extract_kpts(
+ noise_score_map1.permute(0, 3, 1, 2),
+ k=self.config['network']['det']['kpt_n'],
+ score_thld=self.config['network']['det']['score_thld'],
+ nms_size=self.config['network']['det']['nms_size'],
+ eof_size=self.config['network']['det']['eof_size'],
+ edge_thld=self.config['network']['det']['edge_thld']
+ )
+ if self.config['network']['input_type'] == 'raw':
+ kpt_img0 = self.showKeyPoints(img0_ori[0][..., :3] * 255., indices0[0])
+ kpt_img1 = self.showKeyPoints(img1_ori[0][..., :3] * 255., indices1[0])
+ noise_kpt_img0 = self.showKeyPoints(noise_img0_ori[0][..., :3] * 255., noise_indices0[0])
+ noise_kpt_img1 = self.showKeyPoints(noise_img1_ori[0][..., :3] * 255., noise_indices1[0])
+ else:
+ kpt_img0 = self.showKeyPoints(img0_ori[0] * 255., indices0[0])
+ kpt_img1 = self.showKeyPoints(img1_ori[0] * 255., indices1[0])
+ noise_kpt_img0 = self.showKeyPoints(noise_img0_ori[0] * 255., noise_indices0[0])
+ noise_kpt_img1 = self.showKeyPoints(noise_img1_ori[0] * 255., noise_indices1[0])
+
+ self.writer.add_image('img0/kpts', kpt_img0, self.cnt, dataformats='HWC')
+ self.writer.add_image('img1/kpts', kpt_img1, self.cnt, dataformats='HWC')
+ self.writer.add_image('img0/noise_kpts', noise_kpt_img0, self.cnt, dataformats='HWC')
+ self.writer.add_image('img1/noise_kpts', noise_kpt_img1, self.cnt, dataformats='HWC')
+ self.writer.add_image('img0/score_map', score_map0[0], self.cnt, dataformats='HWC')
+ self.writer.add_image('img1/score_map', score_map1[0], self.cnt, dataformats='HWC')
+ self.writer.add_image('img0/noise_score_map', noise_score_map0[0], self.cnt, dataformats='HWC')
+ self.writer.add_image('img1/noise_score_map', noise_score_map1[0], self.cnt, dataformats='HWC')
+ self.writer.add_image('img0/kpt_mask', mask0.unsqueeze(2), self.cnt, dataformats='HWC')
+ self.writer.add_image('img1/kpt_mask', mask1.unsqueeze(2), self.cnt, dataformats='HWC')
+ self.writer.add_image('img0/conf', conf0[0], self.cnt, dataformats='HWC')
+ self.writer.add_image('img1/conf', conf1[0], self.cnt, dataformats='HWC')
+ self.writer.add_image('img0/noise_conf', noise_conf0[0], self.cnt, dataformats='HWC')
+ self.writer.add_image('img1/noise_conf', noise_conf1[0], self.cnt, dataformats='HWC')
+
+ if self.cnt % 5000 == 0:
+ self.save(self.cnt)
+
+ self.cnt += 1
+
+
+ def showKeyPoints(self, img, indices):
+ key_points = cv2.KeyPoint_convert(indices.cpu().float().numpy()[:, ::-1])
+ img = img.numpy().astype('uint8')
+ img = cv2.drawKeypoints(img, key_points, None, color=(0, 255, 0))
+ return img
+
+
+ def preprocess(self, img, iter_idx):
+ if not self.config['network']['noise'] and 'raw' not in self.config['network']['input_type']:
+ return img
+
+ raw = self.noise_maker.rgb2raw(img, batched=True)
+
+ if self.config['network']['noise']:
+ ratio_dec = min(self.config['network']['noise_maxstep'], iter_idx) / self.config['network']['noise_maxstep']
+ raw = self.noise_maker.raw2noisyRaw(raw, ratio_dec=ratio_dec, batched=True)
+
+ if self.config['network']['input_type'] == 'raw':
+ return torch.tensor(self.noise_maker.raw2packedRaw(raw, batched=True))
+
+ if self.config['network']['input_type'] == 'raw-demosaic':
+ return torch.tensor(self.noise_maker.raw2demosaicRaw(raw, batched=True))
+
+ rgb = self.noise_maker.raw2rgb(raw, batched=True)
+ if self.config['network']['input_type'] == 'rgb' or self.config['network']['input_type'] == 'gray':
+ return torch.tensor(rgb)
+
+ raise NotImplementedError()
+
+
+ def preprocess_noise_pair(self, img, iter_idx):
+ assert self.config['network']['noise']
+
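+ # build the clean/noisy training pair: unprocess RGB to RAW, inject synthetic noise into one copy, then convert both to the configured input type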
+ raw = self.noise_maker.rgb2raw(img, batched=True)
+
+ ratio_dec = min(self.config['network']['noise_maxstep'], iter_idx) / self.config['network']['noise_maxstep']
+ noise_raw = self.noise_maker.raw2noisyRaw(raw, ratio_dec=ratio_dec, batched=True)
+
+ if self.config['network']['input_type'] == 'raw':
+ return torch.tensor(self.noise_maker.raw2packedRaw(raw, batched=True)), \
+ torch.tensor(self.noise_maker.raw2packedRaw(noise_raw, batched=True))
+
+ if self.config['network']['input_type'] == 'raw-demosaic':
+ return torch.tensor(self.noise_maker.raw2demosaicRaw(raw, batched=True)), \
+ torch.tensor(self.noise_maker.raw2demosaicRaw(noise_raw, batched=True))
+
+ noise_rgb = self.noise_maker.raw2rgb(noise_raw, batched=True)
+ if self.config['network']['input_type'] == 'rgb' or self.config['network']['input_type'] == 'gray':
+ return img, torch.tensor(noise_rgb)
+
+ raise NotImplementedError()
diff --git a/imcui/third_party/DarkFeat/trainer_single.py b/imcui/third_party/DarkFeat/trainer_single.py
new file mode 100644
index 0000000000000000000000000000000000000000..65566e7e27cfd605eba000d308b6d3610f29e746
--- /dev/null
+++ b/imcui/third_party/DarkFeat/trainer_single.py
@@ -0,0 +1,294 @@
+import os
+import cv2
+import time
+import yaml
+import torch
+import datetime
+from tensorboardX import SummaryWriter
+import torchvision.transforms as tvf
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+
+from nets.geom import getK, getWarp, _grid_positions, getWarpNoValidate
+from nets.loss import make_detector_loss
+from nets.score import extract_kpts
+from nets.sampler import NghSampler2
+from nets.reliability_loss import ReliabilityLoss
+from datasets.noise_simulator import NoiseSimulator
+from nets.l2net import Quad_L2Net
+
+
+class SingleTrainer:
+ def __init__(self, config, device, loader, job_name, start_cnt):
+ self.config = config
+ self.device = device
+ self.loader = loader
+
+ # tensorboard writer construction
+ os.makedirs('./runs/', exist_ok=True)
+ if job_name != '':
+ self.log_dir = f'runs/{job_name}'
+ else:
+ self.log_dir = f'runs/{datetime.datetime.now().strftime("%m-%d-%H%M%S")}'
+
+ self.writer = SummaryWriter(self.log_dir)
+ with open(f'{self.log_dir}/config.yaml', 'w') as f:
+ yaml.dump(config, f)
+
+ if config['network']['input_type'] == 'gray' or config['network']['input_type'] == 'raw-gray':
+ self.model = eval(f'{config["network"]["model"]}(inchan=1)').to(device)
+ elif config['network']['input_type'] == 'rgb' or config['network']['input_type'] == 'raw-demosaic':
+ self.model = eval(f'{config["network"]["model"]}(inchan=3)').to(device)
+ elif config['network']['input_type'] == 'raw':
+ self.model = eval(f'{config["network"]["model"]}(inchan=4)').to(device)
+ else:
+ raise NotImplementedError()
+
+ # noise maker
+ self.noise_maker = NoiseSimulator(device)
+
+ # load model
+ self.cnt = 0
+ if start_cnt != 0:
+ self.model.load_state_dict(torch.load(f'{self.log_dir}/model_{start_cnt:06d}.pth'))
+ self.cnt = start_cnt + 1
+
+ # sampler
+ sampler = NghSampler2(ngh=7, subq=-8, subd=1, pos_d=3, neg_d=5, border=16,
+ subd_neg=-8,maxpool_pos=True).to(device)
+ self.reliability_loss = ReliabilityLoss(sampler, base=0.3, nq=20).to(device)
+ # reliability map conv
+ self.model.clf = nn.Conv2d(128, 2, kernel_size=1).cuda()
+
+ # optimizer and scheduler
+ if self.config['training']['optimizer'] == 'SGD':
+ self.optimizer = torch.optim.SGD(
+ [{'params': self.model.parameters(), 'initial_lr': self.config['training']['lr']}],
+ lr=self.config['training']['lr'],
+ momentum=self.config['training']['momentum'],
+ weight_decay=self.config['training']['weight_decay'],
+ )
+ elif self.config['training']['optimizer'] == 'Adam':
+ self.optimizer = torch.optim.Adam(
+ [{'params': self.model.parameters(), 'initial_lr': self.config['training']['lr']}],
+ lr=self.config['training']['lr'],
+ weight_decay=self.config['training']['weight_decay']
+ )
+ else:
+ raise NotImplementedError()
+
+ self.lr_scheduler = torch.optim.lr_scheduler.StepLR(
+ self.optimizer,
+ step_size=self.config['training']['lr_step'],
+ gamma=self.config['training']['lr_gamma'],
+ last_epoch=start_cnt
+ )
+ for param_tensor in self.model.state_dict():
+ print(param_tensor, "\t", self.model.state_dict()[param_tensor].size())
+
+
+ def save(self, iter_num):
+ torch.save(self.model.state_dict(), f'{self.log_dir}/model_{iter_num:06d}.pth')
+
+ def load(self, path):
+ self.model.load_state_dict(torch.load(path))
+
+ def train(self):
+ self.model.train()
+
+ for epoch in range(2):
+ for batch_idx, inputs in enumerate(self.loader):
+ self.optimizer.zero_grad()
+ t = time.time()
+
+ # preprocess and add noise
+ img0_ori, noise_img0_ori = self.preprocess_noise_pair(inputs['img0'], self.cnt)
+ img1_ori, noise_img1_ori = self.preprocess_noise_pair(inputs['img1'], self.cnt)
+
+ img0 = img0_ori.permute(0, 3, 1, 2).float().to(self.device)
+ img1 = img1_ori.permute(0, 3, 1, 2).float().to(self.device)
+ # the noisy counterparts are referenced by the rgb/gray normalization branches below
+ noise_img0 = noise_img0_ori.permute(0, 3, 1, 2).float().to(self.device)
+ noise_img1 = noise_img1_ori.permute(0, 3, 1, 2).float().to(self.device)
+
+ if self.config['network']['input_type'] == 'rgb':
+ # 3-channel rgb
+ RGB_mean = [0.485, 0.456, 0.406]
+ RGB_std = [0.229, 0.224, 0.225]
+ norm_RGB = tvf.Normalize(mean=RGB_mean, std=RGB_std)
+ img0 = norm_RGB(img0)
+ img1 = norm_RGB(img1)
+ noise_img0 = norm_RGB(noise_img0)
+ noise_img1 = norm_RGB(noise_img1)
+
+ elif self.config['network']['input_type'] == 'gray':
+ # 1-channel
+ img0 = torch.mean(img0, dim=1, keepdim=True)
+ img1 = torch.mean(img1, dim=1, keepdim=True)
+ noise_img0 = torch.mean(noise_img0, dim=1, keepdim=True)
+ noise_img1 = torch.mean(noise_img1, dim=1, keepdim=True)
+ norm_gray0 = tvf.Normalize(mean=img0.mean(), std=img0.std())
+ norm_gray1 = tvf.Normalize(mean=img1.mean(), std=img1.std())
+ img0 = norm_gray0(img0)
+ img1 = norm_gray1(img1)
+ noise_img0 = norm_gray0(noise_img0)
+ noise_img1 = norm_gray1(noise_img1)
+
+ elif self.config['network']['input_type'] == 'raw':
+ # 4-channel
+ pass
+
+ elif self.config['network']['input_type'] == 'raw-demosaic':
+ # 3-channel
+ pass
+
+ else:
+ raise NotImplementedError()
+
+ desc0, score_map0, _, _ = self.model(img0)
+ desc1, score_map1, _, _ = self.model(img1)
+
+ cur_feat_size0 = torch.tensor(score_map0.shape[2:])
+ cur_feat_size1 = torch.tensor(score_map1.shape[2:])
+
+ conf0 = F.softmax(self.model.clf(torch.abs(desc0)**2.0), dim=1)[:,1:2]
+ conf1 = F.softmax(self.model.clf(torch.abs(desc1)**2.0), dim=1)[:,1:2]
+
+ desc0 = desc0.permute(0, 2, 3, 1)
+ desc1 = desc1.permute(0, 2, 3, 1)
+ score_map0 = score_map0.permute(0, 2, 3, 1)
+ score_map1 = score_map1.permute(0, 2, 3, 1)
+ conf0 = conf0.permute(0, 2, 3, 1)
+ conf1 = conf1.permute(0, 2, 3, 1)
+
+ r_K0 = getK(inputs['ori_img_size0'], cur_feat_size0, inputs['K0']).to(self.device)
+ r_K1 = getK(inputs['ori_img_size1'], cur_feat_size1, inputs['K1']).to(self.device)
+
+ pos0 = _grid_positions(
+ cur_feat_size0[0], cur_feat_size0[1], img0.shape[0]).to(self.device)
+
+ pos0_for_rel, pos1_for_rel, _ = getWarpNoValidate(
+ pos0, inputs['rel_pose'].to(self.device), inputs['depth0'].to(self.device),
+ r_K0, inputs['depth1'].to(self.device), r_K1, img0.shape[0])
+
+ pos0, pos1, _ = getWarp(
+ pos0, inputs['rel_pose'].to(self.device), inputs['depth0'].to(self.device),
+ r_K0, inputs['depth1'].to(self.device), r_K1, img0.shape[0])
+
+ reliab_loss = self.reliability_loss(desc0, desc1, conf0, conf1, pos0_for_rel, pos1_for_rel, img0.shape[0], img0.shape[2], img0.shape[3])
+
+ det_structured_loss, det_accuracy = make_detector_loss(
+ pos0, pos1, desc0, desc1,
+ score_map0, score_map1, img0.shape[0],
+ self.config['network']['use_corr_n'],
+ self.config['network']['loss_type'],
+ self.config
+ )
+
+ total_loss = det_structured_loss
+ self.writer.add_scalar("loss/det_loss_normal", det_structured_loss, self.cnt)
+
+ total_loss += reliab_loss
+
+ self.writer.add_scalar("acc/normal_acc", det_accuracy, self.cnt)
+ self.writer.add_scalar("loss/total_loss", total_loss, self.cnt)
+ self.writer.add_scalar("loss/reliab_loss", reliab_loss, self.cnt)
+ print('iter={},\tloss={:.4f},\tacc={:.4f},\t{:.4f}s/iter'.format(self.cnt, total_loss, det_accuracy, time.time()-t))
+
+ if det_structured_loss != 0:
+ total_loss.backward()
+ self.optimizer.step()
+ self.lr_scheduler.step()
+
+ if self.cnt % 100 == 0:
+ indices0, scores0 = extract_kpts(
+ score_map0.permute(0, 3, 1, 2),
+ k=self.config['network']['det']['kpt_n'],
+ score_thld=self.config['network']['det']['score_thld'],
+ nms_size=self.config['network']['det']['nms_size'],
+ eof_size=self.config['network']['det']['eof_size'],
+ edge_thld=self.config['network']['det']['edge_thld']
+ )
+ indices1, scores1 = extract_kpts(
+ score_map1.permute(0, 3, 1, 2),
+ k=self.config['network']['det']['kpt_n'],
+ score_thld=self.config['network']['det']['score_thld'],
+ nms_size=self.config['network']['det']['nms_size'],
+ eof_size=self.config['network']['det']['eof_size'],
+ edge_thld=self.config['network']['det']['edge_thld']
+ )
+
+ if self.config['network']['input_type'] == 'raw':
+ kpt_img0 = self.showKeyPoints(img0_ori[0][..., :3] * 255., indices0[0])
+ kpt_img1 = self.showKeyPoints(img1_ori[0][..., :3] * 255., indices1[0])
+ else:
+ kpt_img0 = self.showKeyPoints(img0_ori[0] * 255., indices0[0])
+ kpt_img1 = self.showKeyPoints(img1_ori[0] * 255., indices1[0])
+
+ self.writer.add_image('img0/kpts', kpt_img0, self.cnt, dataformats='HWC')
+ self.writer.add_image('img1/kpts', kpt_img1, self.cnt, dataformats='HWC')
+ self.writer.add_image('img0/score_map', score_map0[0], self.cnt, dataformats='HWC')
+ self.writer.add_image('img1/score_map', score_map1[0], self.cnt, dataformats='HWC')
+ self.writer.add_image('img0/conf', conf0[0], self.cnt, dataformats='HWC')
+ self.writer.add_image('img1/conf', conf1[0], self.cnt, dataformats='HWC')
+
+ if self.cnt % 10000 == 0:
+ self.save(self.cnt)
+
+ self.cnt += 1
+
+
+ def showKeyPoints(self, img, indices):
+ key_points = cv2.KeyPoint_convert(indices.cpu().float().numpy()[:, ::-1])
+ img = img.numpy().astype('uint8')
+ img = cv2.drawKeypoints(img, key_points, None, color=(0, 255, 0))
+ return img
+
+
+ def preprocess(self, img, iter_idx):
+ if not self.config['network']['noise'] and 'raw' not in self.config['network']['input_type']:
+ return img
+
+ raw = self.noise_maker.rgb2raw(img, batched=True)
+
+ if self.config['network']['noise']:
+ ratio_dec = min(self.config['network']['noise_maxstep'], iter_idx) / self.config['network']['noise_maxstep']
+ raw = self.noise_maker.raw2noisyRaw(raw, ratio_dec=ratio_dec, batched=True)
+
+ if self.config['network']['input_type'] == 'raw':
+ return torch.tensor(self.noise_maker.raw2packedRaw(raw, batched=True))
+
+ if self.config['network']['input_type'] == 'raw-demosaic':
+ return torch.tensor(self.noise_maker.raw2demosaicRaw(raw, batched=True))
+
+ rgb = self.noise_maker.raw2rgb(raw, batched=True)
+ if self.config['network']['input_type'] == 'rgb' or self.config['network']['input_type'] == 'gray':
+ return torch.tensor(rgb)
+
+ raise NotImplementedError()
+
+
+ def preprocess_noise_pair(self, img, iter_idx):
+ assert self.config['network']['noise']
+
+ raw = self.noise_maker.rgb2raw(img, batched=True)
+
+ ratio_dec = min(self.config['network']['noise_maxstep'], iter_idx) / self.config['network']['noise_maxstep']
+ noise_raw = self.noise_maker.raw2noisyRaw(raw, ratio_dec=ratio_dec, batched=True)
+
+ if self.config['network']['input_type'] == 'raw':
+ return torch.tensor(self.noise_maker.raw2packedRaw(raw, batched=True)), \
+ torch.tensor(self.noise_maker.raw2packedRaw(noise_raw, batched=True))
+
+ if self.config['network']['input_type'] == 'raw-demosaic':
+ return torch.tensor(self.noise_maker.raw2demosaicRaw(raw, batched=True)), \
+ torch.tensor(self.noise_maker.raw2demosaicRaw(noise_raw, batched=True))
+
+ if self.config['network']['input_type'] == 'raw-gray':
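+ # BT.601 luma weights for converting the demosaiced RAW to a single gray channel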
+ factor = torch.tensor([0.299, 0.587, 0.114]).double()
+ return torch.matmul(torch.tensor(self.noise_maker.raw2demosaicRaw(raw, batched=True)), factor).unsqueeze(-1), \
+ torch.matmul(torch.tensor(self.noise_maker.raw2demosaicRaw(noise_raw, batched=True)), factor).unsqueeze(-1)
+
+ noise_rgb = self.noise_maker.raw2rgb(noise_raw, batched=True)
+ if self.config['network']['input_type'] == 'rgb' or self.config['network']['input_type'] == 'gray':
+ return img, torch.tensor(noise_rgb)
+
+ raise NotImplementedError()
diff --git a/imcui/third_party/DarkFeat/trainer_single_norel.py b/imcui/third_party/DarkFeat/trainer_single_norel.py
new file mode 100644
index 0000000000000000000000000000000000000000..a572e9c599adc30e5753e11e668d121cd378672a
--- /dev/null
+++ b/imcui/third_party/DarkFeat/trainer_single_norel.py
@@ -0,0 +1,265 @@
+import os
+import cv2
+import time
+import yaml
+import torch
+import datetime
+from tensorboardX import SummaryWriter
+import torchvision.transforms as tvf
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+
+from nets.l2net import Quad_L2Net
+from nets.geom import getK, getWarp, _grid_positions
+from nets.loss import make_detector_loss
+from nets.score import extract_kpts
+from datasets.noise_simulator import NoiseSimulator
+from nets.l2net import Quad_L2Net
+
+
+class SingleTrainerNoRel:
+ def __init__(self, config, device, loader, job_name, start_cnt):
+ self.config = config
+ self.device = device
+ self.loader = loader
+
+ # tensorboard writer construction
+ os.makedirs('./runs/', exist_ok=True)
+ if job_name != '':
+ self.log_dir = f'runs/{job_name}'
+ else:
+ self.log_dir = f'runs/{datetime.datetime.now().strftime("%m-%d-%H%M%S")}'
+
+ self.writer = SummaryWriter(self.log_dir)
+ with open(f'{self.log_dir}/config.yaml', 'w') as f:
+ yaml.dump(config, f)
+
+ if config['network']['input_type'] == 'gray' or config['network']['input_type'] == 'raw-gray':
+ self.model = eval(f'{config["network"]["model"]}(inchan=1)').to(device)
+ elif config['network']['input_type'] == 'rgb' or config['network']['input_type'] == 'raw-demosaic':
+ self.model = eval(f'{config["network"]["model"]}(inchan=3)').to(device)
+ elif config['network']['input_type'] == 'raw':
+ self.model = eval(f'{config["network"]["model"]}(inchan=4)').to(device)
+ else:
+ raise NotImplementedError()
+
+ # noise maker
+ self.noise_maker = NoiseSimulator(device)
+
+ # load model
+ self.cnt = 0
+ if start_cnt != 0:
+ self.model.load_state_dict(torch.load(f'{self.log_dir}/model_{start_cnt:06d}.pth'))
+ self.cnt = start_cnt + 1
+
+ # optimizer and scheduler
+ if self.config['training']['optimizer'] == 'SGD':
+ self.optimizer = torch.optim.SGD(
+ [{'params': self.model.parameters(), 'initial_lr': self.config['training']['lr']}],
+ lr=self.config['training']['lr'],
+ momentum=self.config['training']['momentum'],
+ weight_decay=self.config['training']['weight_decay'],
+ )
+ elif self.config['training']['optimizer'] == 'Adam':
+ self.optimizer = torch.optim.Adam(
+ [{'params': self.model.parameters(), 'initial_lr': self.config['training']['lr']}],
+ lr=self.config['training']['lr'],
+ weight_decay=self.config['training']['weight_decay']
+ )
+ else:
+ raise NotImplementedError()
+
+ self.lr_scheduler = torch.optim.lr_scheduler.StepLR(
+ self.optimizer,
+ step_size=self.config['training']['lr_step'],
+ gamma=self.config['training']['lr_gamma'],
+ last_epoch=start_cnt
+ )
+ for param_tensor in self.model.state_dict():
+ print(param_tensor, "\t", self.model.state_dict()[param_tensor].size())
+
+
+ def save(self, iter_num):
+ torch.save(self.model.state_dict(), f'{self.log_dir}/model_{iter_num:06d}.pth')
+
+ def load(self, path):
+ self.model.load_state_dict(torch.load(path))
+
+ def train(self):
+ self.model.train()
+
+ for epoch in range(2):
+ for batch_idx, inputs in enumerate(self.loader):
+ self.optimizer.zero_grad()
+ t = time.time()
+
+ # preprocess and add noise
+ img0_ori, noise_img0_ori = self.preprocess_noise_pair(inputs['img0'], self.cnt)
+ img1_ori, noise_img1_ori = self.preprocess_noise_pair(inputs['img1'], self.cnt)
+
+ img0 = img0_ori.permute(0, 3, 1, 2).float().to(self.device)
+ img1 = img1_ori.permute(0, 3, 1, 2).float().to(self.device)
+ # define the noisy counterparts as well; the rgb/gray normalization branches below reference them
+ noise_img0 = noise_img0_ori.permute(0, 3, 1, 2).float().to(self.device)
+ noise_img1 = noise_img1_ori.permute(0, 3, 1, 2).float().to(self.device)
+
+ if self.config['network']['input_type'] == 'rgb':
+ # 3-channel rgb
+ RGB_mean = [0.485, 0.456, 0.406]
+ RGB_std = [0.229, 0.224, 0.225]
+ norm_RGB = tvf.Normalize(mean=RGB_mean, std=RGB_std)
+ img0 = norm_RGB(img0)
+ img1 = norm_RGB(img1)
+ noise_img0 = norm_RGB(noise_img0)
+ noise_img1 = norm_RGB(noise_img1)
+
+ elif self.config['network']['input_type'] == 'gray':
+ # 1-channel
+ img0 = torch.mean(img0, dim=1, keepdim=True)
+ img1 = torch.mean(img1, dim=1, keepdim=True)
+ noise_img0 = torch.mean(noise_img0, dim=1, keepdim=True)
+ noise_img1 = torch.mean(noise_img1, dim=1, keepdim=True)
+ norm_gray0 = tvf.Normalize(mean=img0.mean(), std=img0.std())
+ norm_gray1 = tvf.Normalize(mean=img1.mean(), std=img1.std())
+ img0 = norm_gray0(img0)
+ img1 = norm_gray1(img1)
+ noise_img0 = norm_gray0(noise_img0)
+ noise_img1 = norm_gray1(noise_img1)
+
+ elif self.config['network']['input_type'] == 'raw':
+ # 4-channel
+ pass
+
+ elif self.config['network']['input_type'] == 'raw-demosaic':
+ # 3-channel
+ pass
+
+ else:
+ raise NotImplementedError()
+
+ desc0, score_map0, _, _ = self.model(img0)
+ desc1, score_map1, _, _ = self.model(img1)
+
+ cur_feat_size0 = torch.tensor(score_map0.shape[2:])
+ cur_feat_size1 = torch.tensor(score_map1.shape[2:])
+
+ desc0 = desc0.permute(0, 2, 3, 1)
+ desc1 = desc1.permute(0, 2, 3, 1)
+ score_map0 = score_map0.permute(0, 2, 3, 1)
+ score_map1 = score_map1.permute(0, 2, 3, 1)
+
+ r_K0 = getK(inputs['ori_img_size0'], cur_feat_size0, inputs['K0']).to(self.device)
+ r_K1 = getK(inputs['ori_img_size1'], cur_feat_size1, inputs['K1']).to(self.device)
+
+ pos0 = _grid_positions(
+ cur_feat_size0[0], cur_feat_size0[1], img0.shape[0]).to(self.device)
+
+ pos0, pos1, _ = getWarp(
+ pos0, inputs['rel_pose'].to(self.device), inputs['depth0'].to(self.device),
+ r_K0, inputs['depth1'].to(self.device), r_K1, img0.shape[0])
+
+ det_structured_loss, det_accuracy = make_detector_loss(
+ pos0, pos1, desc0, desc1,
+ score_map0, score_map1, img0.shape[0],
+ self.config['network']['use_corr_n'],
+ self.config['network']['loss_type'],
+ self.config
+ )
+
+ total_loss = det_structured_loss
+
+ self.writer.add_scalar("acc/normal_acc", det_accuracy, self.cnt)
+ self.writer.add_scalar("loss/total_loss", total_loss, self.cnt)
+ self.writer.add_scalar("loss/det_loss_normal", det_structured_loss, self.cnt)
+ print('iter={},\tloss={:.4f},\tacc={:.4f},\t{:.4f}s/iter'.format(self.cnt, total_loss, det_accuracy, time.time()-t))
+
+ if det_structured_loss != 0:
+ total_loss.backward()
+ self.optimizer.step()
+ self.lr_scheduler.step()
+
+ if self.cnt % 100 == 0:
+ indices0, scores0 = extract_kpts(
+ score_map0.permute(0, 3, 1, 2),
+ k=self.config['network']['det']['kpt_n'],
+ score_thld=self.config['network']['det']['score_thld'],
+ nms_size=self.config['network']['det']['nms_size'],
+ eof_size=self.config['network']['det']['eof_size'],
+ edge_thld=self.config['network']['det']['edge_thld']
+ )
+ indices1, scores1 = extract_kpts(
+ score_map1.permute(0, 3, 1, 2),
+ k=self.config['network']['det']['kpt_n'],
+ score_thld=self.config['network']['det']['score_thld'],
+ nms_size=self.config['network']['det']['nms_size'],
+ eof_size=self.config['network']['det']['eof_size'],
+ edge_thld=self.config['network']['det']['edge_thld']
+ )
+
+ if self.config['network']['input_type'] == 'raw':
+ kpt_img0 = self.showKeyPoints(img0_ori[0][..., :3] * 255., indices0[0])
+ kpt_img1 = self.showKeyPoints(img1_ori[0][..., :3] * 255., indices1[0])
+ else:
+ kpt_img0 = self.showKeyPoints(img0_ori[0] * 255., indices0[0])
+ kpt_img1 = self.showKeyPoints(img1_ori[0] * 255., indices1[0])
+
+ self.writer.add_image('img0/kpts', kpt_img0, self.cnt, dataformats='HWC')
+ self.writer.add_image('img1/kpts', kpt_img1, self.cnt, dataformats='HWC')
+ self.writer.add_image('img0/score_map', score_map0[0], self.cnt, dataformats='HWC')
+ self.writer.add_image('img1/score_map', score_map1[0], self.cnt, dataformats='HWC')
+
+ if self.cnt % 10000 == 0:
+ self.save(self.cnt)
+
+ self.cnt += 1
+
+
+ def showKeyPoints(self, img, indices):
+ key_points = cv2.KeyPoint_convert(indices.cpu().float().numpy()[:, ::-1])
+ img = img.numpy().astype('uint8')
+ img = cv2.drawKeypoints(img, key_points, None, color=(0, 255, 0))
+ return img
+
+
+ def preprocess(self, img, iter_idx):
+ if not self.config['network']['noise'] and 'raw' not in self.config['network']['input_type']:
+ return img
+
+ raw = self.noise_maker.rgb2raw(img, batched=True)
+
+ if self.config['network']['noise']:
+ ratio_dec = min(self.config['network']['noise_maxstep'], iter_idx) / self.config['network']['noise_maxstep']
+ raw = self.noise_maker.raw2noisyRaw(raw, ratio_dec=ratio_dec, batched=True)
+
+ if self.config['network']['input_type'] == 'raw':
+ return torch.tensor(self.noise_maker.raw2packedRaw(raw, batched=True))
+
+ if self.config['network']['input_type'] == 'raw-demosaic':
+ return torch.tensor(self.noise_maker.raw2demosaicRaw(raw, batched=True))
+
+ rgb = self.noise_maker.raw2rgb(raw, batched=True)
+ if self.config['network']['input_type'] == 'rgb' or self.config['network']['input_type'] == 'gray':
+ return torch.tensor(rgb)
+
+ raise NotImplementedError()
+
+
+ def preprocess_noise_pair(self, img, iter_idx):
+ assert self.config['network']['noise']
+
+ raw = self.noise_maker.rgb2raw(img, batched=True)
+
+ ratio_dec = min(self.config['network']['noise_maxstep'], iter_idx) / self.config['network']['noise_maxstep']
+ noise_raw = self.noise_maker.raw2noisyRaw(raw, ratio_dec=ratio_dec, batched=True)
+
+ if self.config['network']['input_type'] == 'raw':
+ return torch.tensor(self.noise_maker.raw2packedRaw(raw, batched=True)), \
+ torch.tensor(self.noise_maker.raw2packedRaw(noise_raw, batched=True))
+
+ if self.config['network']['input_type'] == 'raw-demosaic':
+ return torch.tensor(self.noise_maker.raw2demosaicRaw(raw, batched=True)), \
+ torch.tensor(self.noise_maker.raw2demosaicRaw(noise_raw, batched=True))
+
+ noise_rgb = self.noise_maker.raw2rgb(noise_raw, batched=True)
+ if self.config['network']['input_type'] == 'rgb' or self.config['network']['input_type'] == 'gray':
+ return img, torch.tensor(noise_rgb)
+
+ raise NotImplementedError()
diff --git a/imcui/third_party/DarkFeat/utils/__init__.py b/imcui/third_party/DarkFeat/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/imcui/third_party/DarkFeat/utils/matching.py b/imcui/third_party/DarkFeat/utils/matching.py
new file mode 100644
index 0000000000000000000000000000000000000000..ca091f418bb4dc4d278611e5126a930aa51e7f3f
--- /dev/null
+++ b/imcui/third_party/DarkFeat/utils/matching.py
@@ -0,0 +1,128 @@
+import math
+import numpy as np
+import cv2
+
+def extract_ORB_keypoints_and_descriptors(img):
+ # gray_img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
+ detector = cv2.ORB_create(nfeatures=1000)
+ kp, desc = detector.detectAndCompute(img, None)
+ return kp, desc
+
+def match_descriptors_NG(kp1, desc1, kp2, desc2):
+ bf = cv2.BFMatcher()
+ try:
+ matches = bf.knnMatch(desc1, desc2,k=2)
+ except Exception:
+ matches = []
+ good_matches=[]
+ image1_kp = []
+ image2_kp = []
+ ratios = []
+ try:
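+ # Lowe's ratio test: keep a match only if its best distance is below 0.8x the second-best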
+ for (m1,m2) in matches:
+ if m1.distance < 0.8 * m2.distance:
+ good_matches.append(m1)
+ image2_kp.append(kp2[m1.trainIdx].pt)
+ image1_kp.append(kp1[m1.queryIdx].pt)
+ ratios.append(m1.distance / m2.distance)
+ except Exception:
+ pass
+ image1_kp = np.array([image1_kp])
+ image2_kp = np.array([image2_kp])
+ ratios = np.array([ratios])
+ ratios = np.expand_dims(ratios, 2)
+ return image1_kp, image2_kp, good_matches, ratios
+
+def match_descriptors(kp1, desc1, kp2, desc2, ORB):
+ if ORB:
+ bf = cv2.BFMatcher(cv2.NORM_HAMMING, crossCheck=True)
+ try:
+ matches = bf.match(desc1,desc2)
+ matches = sorted(matches, key = lambda x:x.distance)
+ except Exception:
+ matches = []
+ good_matches=[]
+ image1_kp = []
+ image2_kp = []
+ count = 0
+ try:
+ for m in matches:
+ count+=1
+ if count < 1000:
+ good_matches.append(m)
+ image2_kp.append(kp2[m.trainIdx].pt)
+ image1_kp.append(kp1[m.queryIdx].pt)
+ except Exception:
+ pass
+ else:
+ # Match the keypoints with the warped_keypoints with nearest neighbor search
+ bf = cv2.BFMatcher(cv2.NORM_L2, crossCheck=True)
+ try:
+ matches = bf.match(desc1.transpose(1,0), desc2.transpose(1,0))
+ matches = sorted(matches, key = lambda x:x.distance)
+ except Exception:
+ matches = []
+ good_matches=[]
+ image1_kp = []
+ image2_kp = []
+ try:
+ for m in matches:
+ good_matches.append(m)
+ image2_kp.append(kp2[m.trainIdx].pt)
+ image1_kp.append(kp1[m.queryIdx].pt)
+ except Exception:
+ pass
+
+ image1_kp = np.array([image1_kp])
+ image2_kp = np.array([image2_kp])
+ return image1_kp, image2_kp, good_matches
+
+
+def compute_essential(matched_kp1, matched_kp2, K):
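+ # note: the distCoeffs below are fixed OpenCV distortion coefficients (k1, k2, p1, p2), presumably specific to the evaluation camera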
+ pts1 = cv2.undistortPoints(matched_kp1,cameraMatrix=K, distCoeffs = (-0.117918271740560,0.075246403574314,0,0))
+ pts2 = cv2.undistortPoints(matched_kp2,cameraMatrix=K, distCoeffs = (-0.117918271740560,0.075246403574314,0,0))
+ K_1 = np.eye(3)
+ # Estimate the homography between the matches using RANSAC
+ ransac_model, ransac_inliers = cv2.findEssentialMat(pts1, pts2, K_1, method=cv2.FM_RANSAC, prob=0.999, threshold=0.001)
+ if ransac_inliers is None or ransac_model.shape != (3,3):
+ ransac_inliers = np.array([])
+ ransac_model = None
+ return ransac_model, ransac_inliers, pts1, pts2
+
+
+def compute_error(R_GT,t_GT,E,pts1_norm, pts2_norm, inliers):
+ """Compute the angular error between two rotation matrices and two translation vectors.
+ Keyword arguments:
+ R -- 2D numpy array containing an estimated rotation
+ gt_R -- 2D numpy array containing the corresponding ground truth rotation
+ t -- 2D numpy array containing an estimated translation as column
+ gt_t -- 2D numpy array containing the corresponding ground truth translation
+ """
+
+ inliers = inliers.ravel()
+ R = np.eye(3)
+ t = np.zeros((3,1))
+ sst = True
+ try:
+ cv2.recoverPose(E, pts1_norm, pts2_norm, np.eye(3), R, t, inliers)
+ except Exception:
+ sst = False
+ # calculate angle between provided rotations
+ #
+ if sst:
+ dR = np.matmul(R, np.transpose(R_GT))
+ dR = cv2.Rodrigues(dR)[0]
+ dR = np.linalg.norm(dR) * 180 / math.pi
+
+ # calculate angle between provided translations
+ dT = float(np.dot(t_GT.T, t))
+ dT /= float(np.linalg.norm(t_GT))
+
+ if dT > 1 or dT < -1:
+ print("Domain warning! dT:",dT)
+ dT = max(-1,min(1,dT))
+ dT = math.acos(dT) * 180 / math.pi
+ dT = np.minimum(dT, 180 - dT) # ambiguity of E estimation
+ else:
+ dR,dT = 180.0, 180.0
+ return dR, dT
diff --git a/imcui/third_party/DarkFeat/utils/misc.py b/imcui/third_party/DarkFeat/utils/misc.py
new file mode 100644
index 0000000000000000000000000000000000000000..1df6fdec97121486dbb94e0b32a2f66c85c48f7d
--- /dev/null
+++ b/imcui/third_party/DarkFeat/utils/misc.py
@@ -0,0 +1,158 @@
+from pathlib import Path
+import time
+from collections import OrderedDict
+import numpy as np
+import cv2
+import rawpy
+import torch
+import colour_demosaicing
+
+
+class AverageTimer:
+ """ Class to help manage printing simple timing of code execution. """
+
+ def __init__(self, smoothing=0.3, newline=False):
+ self.smoothing = smoothing
+ self.newline = newline
+ self.times = OrderedDict()
+ self.will_print = OrderedDict()
+ self.reset()
+
+ def reset(self):
+ now = time.time()
+ self.start = now
+ self.last_time = now
+ for name in self.will_print:
+ self.will_print[name] = False
+
+ def update(self, name='default'):
+ now = time.time()
+ dt = now - self.last_time
+ if name in self.times:
+ dt = self.smoothing * dt + (1 - self.smoothing) * self.times[name]
+ self.times[name] = dt
+ self.will_print[name] = True
+ self.last_time = now
+
+ def print(self, text='Timer'):
+ total = 0.
+ print('[{}]'.format(text), end=' ')
+ for key in self.times:
+ val = self.times[key]
+ if self.will_print[key]:
+ print('%s=%.3f' % (key, val), end=' ')
+ total += val
+ print('total=%.3f sec {%.1f FPS}' % (total, 1./total), end=' ')
+ if self.newline:
+ print(flush=True)
+ else:
+ print(end='\r', flush=True)
+ self.reset()
+
+
+class VideoStreamer:
+ def __init__(self, basedir, resize, image_glob):
+ self.listing = []
+ self.resize = resize
+ self.i = 0
+ if Path(basedir).is_dir():
+ print('==> Processing image directory input: {}'.format(basedir))
+ self.listing = list(Path(basedir).glob(image_glob[0]))
+ for j in range(1, len(image_glob)):
+ image_path = list(Path(basedir).glob(image_glob[j]))
+ self.listing = self.listing + image_path
+ self.listing.sort()
+ if len(self.listing) == 0:
+ raise IOError('No images found (maybe bad \'image_glob\' ?)')
+ self.max_length = len(self.listing)
+ else:
+ raise ValueError('VideoStreamer input \"{}\" not recognized.'.format(basedir))
+
+ def load_image(self, impath):
+ raw = rawpy.imread(str(impath)).raw_image_visible
+ raw = np.clip(raw.astype('float32') - 512, 0, 65535)
+ img = colour_demosaicing.demosaicing_CFA_Bayer_bilinear(raw, 'RGGB').astype('float32')
+ img = np.clip(img, 0, 16383)
+
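+ # robust normalization: centre on the mean and scale by the mean absolute deviation so roughly +-2 deviations span [0, 255]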
+ m = img.mean()
+ d = np.abs(img - img.mean()).mean()
+ img = (img - m + 2 * d) / (4 * d) * 255
+ image = np.clip(img, 0, 255)
+
+ w_new, h_new = self.resize[0], self.resize[1]
+
+ im = cv2.resize(image.astype('float32'), (w_new, h_new), interpolation=cv2.INTER_AREA)
+ return im
+
+ def next_frame(self):
+ if self.i == self.max_length:
+ return (None, False)
+ image_file = str(self.listing[self.i])
+ image = self.load_image(image_file)
+ self.i = self.i + 1
+ return (image, True)
+
+
+def frame2tensor(frame, device):
+ if len(frame.shape) == 2:
+ return torch.from_numpy(frame/255.).float()[None, None].to(device)
+ else:
+ return torch.from_numpy(frame/255.).float().permute(2, 0, 1)[None].to(device)
+
+
+def make_matching_plot_fast(image0, image1, mkpts0, mkpts1,
+ color, text, path=None, margin=10,
+ opencv_display=False, opencv_title='',
+ small_text=[]):
+ H0, W0 = image0.shape[:2]
+ H1, W1 = image1.shape[:2]
+ H, W = max(H0, H1), W0 + W1 + margin
+
+ out = 255*np.ones((H, W, 3), np.uint8)
+ out[:H0, :W0, :] = image0
+ out[:H1, W0+margin:, :] = image1
+
+ # Scale factor for consistent visualization across scales.
+ sc = min(H / 640., 2.0)
+
+ # Big text.
+ Ht = int(30 * sc) # text height
+ txt_color_fg = (255, 255, 255)
+ txt_color_bg = (0, 0, 0)
+
+ for i, t in enumerate(text):
+ cv2.putText(out, t, (int(8*sc), Ht*(i+1)), cv2.FONT_HERSHEY_DUPLEX,
+ 1.0*sc, txt_color_bg, 2, cv2.LINE_AA)
+ cv2.putText(out, t, (int(8*sc), Ht*(i+1)), cv2.FONT_HERSHEY_DUPLEX,
+ 1.0*sc, txt_color_fg, 1, cv2.LINE_AA)
+
+ out_backup = out.copy()
+
+ mkpts0, mkpts1 = np.round(mkpts0).astype(int), np.round(mkpts1).astype(int)
+ color = (np.array(color[:, :3])*255).astype(int)[:, ::-1]
+ for (x0, y0), (x1, y1), c in zip(mkpts0, mkpts1, color):
+ c = c.tolist()
+ cv2.line(out, (x0, y0), (x1 + margin + W0, y1),
+ color=c, thickness=1, lineType=cv2.LINE_AA)
+ # display line end-points as circles
+ cv2.circle(out, (x0, y0), 2, c, -1, lineType=cv2.LINE_AA)
+ cv2.circle(out, (x1 + margin + W0, y1), 2, c, -1,
+ lineType=cv2.LINE_AA)
+
+ # Small text.
+ Ht = int(18 * sc) # text height
+ for i, t in enumerate(reversed(small_text)):
+ cv2.putText(out, t, (int(8*sc), int(H-Ht*(i+.6))), cv2.FONT_HERSHEY_DUPLEX,
+ 0.5*sc, txt_color_bg, 2, cv2.LINE_AA)
+ cv2.putText(out, t, (int(8*sc), int(H-Ht*(i+.6))), cv2.FONT_HERSHEY_DUPLEX,
+ 0.5*sc, txt_color_fg, 1, cv2.LINE_AA)
+
+ if path is not None:
+ cv2.imwrite(str(path), out)
+
+ if opencv_display:
+ cv2.imshow(opencv_title, out)
+ cv2.waitKey(1)
+
+ return out / 2 + out_backup / 2
+
diff --git a/imcui/third_party/DarkFeat/utils/nn.py b/imcui/third_party/DarkFeat/utils/nn.py
new file mode 100644
index 0000000000000000000000000000000000000000..8a80631d6e12d848cceee3b636baf49deaa7647a
--- /dev/null
+++ b/imcui/third_party/DarkFeat/utils/nn.py
@@ -0,0 +1,50 @@
+import torch
+from torch import nn
+
+
+class NN2(nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, data):
+ desc1, desc2 = data['descriptors0'].cuda(), data['descriptors1'].cuda()
+ kpts1, kpts2 = data['keypoints0'].cuda(), data['keypoints1'].cuda()
+
+ # torch.cuda.synchronize()
+ # t = time.time()
+
+ if kpts1.shape[1] <= 1 or kpts2.shape[1] <= 1: # no keypoints
+ shape0, shape1 = kpts1.shape[:-1], kpts2.shape[:-1]
+ return {
+ 'matches0': kpts1.new_full(shape0, -1, dtype=torch.int),
+ 'matches1': kpts2.new_full(shape1, -1, dtype=torch.int),
+ 'matching_scores0': kpts1.new_zeros(shape0),
+ 'matching_scores1': kpts2.new_zeros(shape1),
+ }
+
+ sim = torch.matmul(desc1.squeeze().T, desc2.squeeze())
+ ids1 = torch.arange(0, sim.shape[0], device=desc1.device)
+ nn12 = torch.argmax(sim, dim=1)
+
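+ # mutual nearest-neighbour filtering: keep only pairs that are each other's best match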
+ nn21 = torch.argmax(sim, dim=0)
+ mask = torch.eq(ids1, nn21[nn12])
+ matches = torch.stack([torch.masked_select(ids1, mask), torch.masked_select(nn12, mask)])
+ # matches = torch.stack([ids1, nn12])
+ indices0 = torch.ones((1, desc1.shape[-1]), dtype=int) * -1
+ mscores0 = torch.ones((1, desc1.shape[-1]), dtype=float) * -1
+
+ # torch.cuda.synchronize()
+ # print(time.time() - t)
+
+ matches_0 = matches[0].cpu().int().numpy()
+ matches_1 = matches[1].cpu().int()
+ for i in range(matches.shape[-1]):
+ indices0[0, matches_0[i]] = matches_1[i].int()
+ mscores0[0, matches_0[i]] = sim[matches_0[i], matches_1[i]]
+
+ return {
+ 'matches0': indices0, # use -1 for invalid match
+ 'matches1': indices0, # use -1 for invalid match
+ 'matching_scores0': mscores0,
+ 'matching_scores1': mscores0,
+ }
diff --git a/imcui/third_party/DarkFeat/utils/nnmatching.py b/imcui/third_party/DarkFeat/utils/nnmatching.py
new file mode 100644
index 0000000000000000000000000000000000000000..7be6f98c050fc2e416ef48e25ca0f293106c1082
--- /dev/null
+++ b/imcui/third_party/DarkFeat/utils/nnmatching.py
@@ -0,0 +1,41 @@
+import torch
+
+from .nn import NN2
+from darkfeat import DarkFeat
+
+class NNMatching(torch.nn.Module):
+ def __init__(self, model_path=''):
+ super().__init__()
+ self.nn = NN2().eval()
+ self.darkfeat = DarkFeat(model_path).eval()
+
+ def forward(self, data):
+ """ Run DarkFeat and nearest neighborhood matching
+ Args:
+ data: dictionary with minimal keys: ['image0', 'image1']
+ """
+ pred = {}
+
+ # Extract DarkFeat (keypoints, scores, descriptors)
+ if 'keypoints0' not in data:
+ pred0 = self.darkfeat({'image': data['image0']})
+ # print({k+'0': v[0].shape for k, v in pred0.items()})
+ pred = {**pred, **{k+'0': [v] for k, v in pred0.items()}}
+ if 'keypoints1' not in data:
+ pred1 = self.darkfeat({'image': data['image1']})
+ pred = {**pred, **{k+'1': [v] for k, v in pred1.items()}}
+
+
+ # Batch all features
+ # We should either have i) one image per batch, or
+ # ii) the same number of local features for all images in the batch.
+ data = {**data, **pred}
+
+ for k in data:
+ if isinstance(data[k], (list, tuple)):
+ data[k] = torch.stack(data[k])
+
+ # Perform the matching
+ pred = {**pred, **self.nn(data)}
+
+ return pred
diff --git a/imcui/third_party/DeDoDe/DeDoDe/__init__.py b/imcui/third_party/DeDoDe/DeDoDe/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9716b62f0672cfc604ca95280d8aa51a04944d4f
--- /dev/null
+++ b/imcui/third_party/DeDoDe/DeDoDe/__init__.py
@@ -0,0 +1,2 @@
+from .model_zoo import dedode_detector_B, dedode_detector_L, dedode_descriptor_B, dedode_descriptor_G
+DEBUG_MODE = False
diff --git a/imcui/third_party/DeDoDe/DeDoDe/benchmarks/__init__.py b/imcui/third_party/DeDoDe/DeDoDe/benchmarks/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..06d86ba8d4e509dae88e7f5297407a542d9a8774
--- /dev/null
+++ b/imcui/third_party/DeDoDe/DeDoDe/benchmarks/__init__.py
@@ -0,0 +1,4 @@
+from .num_inliers import NumInliersBenchmark
+from .mega_pose_est import MegaDepthPoseEstimationBenchmark
+from .mega_pose_est_mnn import MegaDepthPoseMNNBenchmark
+from .nll_benchmark import MegadepthNLLBenchmark
\ No newline at end of file
diff --git a/imcui/third_party/DeDoDe/DeDoDe/benchmarks/mega_pose_est.py b/imcui/third_party/DeDoDe/DeDoDe/benchmarks/mega_pose_est.py
new file mode 100644
index 0000000000000000000000000000000000000000..2104284b54d5fe339d6f12d9ae14dcdd3c0fb564
--- /dev/null
+++ b/imcui/third_party/DeDoDe/DeDoDe/benchmarks/mega_pose_est.py
@@ -0,0 +1,114 @@
+import numpy as np
+import torch
+from DeDoDe.utils import *
+from PIL import Image
+from tqdm import tqdm
+import torch.nn.functional as F
+
+class MegaDepthPoseEstimationBenchmark:
+ def __init__(self, data_root="data/megadepth", scene_names = None) -> None:
+ if scene_names is None:
+ self.scene_names = [
+ "0015_0.1_0.3.npz",
+ "0015_0.3_0.5.npz",
+ "0022_0.1_0.3.npz",
+ "0022_0.3_0.5.npz",
+ "0022_0.5_0.7.npz",
+ ]
+ else:
+ self.scene_names = scene_names
+ self.scenes = [
+ np.load(f"{data_root}/{scene}", allow_pickle=True)
+ for scene in self.scene_names
+ ]
+ self.data_root = data_root
+
+ def benchmark(self, keypoint_model, matching_model, model_name = None, resolution = None, scale_intrinsics = True, calibrated = True):
+ H,W = matching_model.get_output_resolution()
+ with torch.no_grad():
+ data_root = self.data_root
+ tot_e_t, tot_e_R, tot_e_pose = [], [], []
+ thresholds = [5, 10, 20]
+ for scene_ind in range(len(self.scenes)):
+ import os
+ scene_name = os.path.splitext(self.scene_names[scene_ind])[0]
+ scene = self.scenes[scene_ind]
+ pairs = scene["pair_infos"]
+ intrinsics = scene["intrinsics"]
+ poses = scene["poses"]
+ im_paths = scene["image_paths"]
+ pair_inds = range(len(pairs))
+ for pairind in tqdm(pair_inds):
+ idx1, idx2 = pairs[pairind][0]
+ K1 = intrinsics[idx1].copy()
+ T1 = poses[idx1].copy()
+ R1, t1 = T1[:3, :3], T1[:3, 3]
+ K2 = intrinsics[idx2].copy()
+ T2 = poses[idx2].copy()
+ R2, t2 = T2[:3, :3], T2[:3, 3]
+ R, t = compute_relative_pose(R1, t1, R2, t2)
+ T1_to_2 = np.concatenate((R,t[:,None]), axis=-1)
+ im_A_path = f"{data_root}/{im_paths[idx1]}"
+ im_B_path = f"{data_root}/{im_paths[idx2]}"
+
+ keypoints_A = keypoint_model.detect_from_path(im_A_path, num_keypoints = 20_000)["keypoints"][0]
+ keypoints_B = keypoint_model.detect_from_path(im_B_path, num_keypoints = 20_000)["keypoints"][0]
+ warp, certainty = matching_model.match(im_A_path, im_B_path)
+ matches = matching_model.match_keypoints(keypoints_A, keypoints_B, warp, certainty, return_tuple = False)
+ im_A = Image.open(im_A_path)
+ w1, h1 = im_A.size
+ im_B = Image.open(im_B_path)
+ w2, h2 = im_B.size
+ if scale_intrinsics:
+ scale1 = 1200 / max(w1, h1)
+ scale2 = 1200 / max(w2, h2)
+ w1, h1 = scale1 * w1, scale1 * h1
+ w2, h2 = scale2 * w2, scale2 * h2
+ K1, K2 = K1.copy(), K2.copy()
+ K1[:2] = K1[:2] * scale1
+ K2[:2] = K2[:2] * scale2
+ kpts1, kpts2 = matching_model.to_pixel_coordinates(matches, h1, w1, h2, w2)
+ for _ in range(1):
+ shuffling = np.random.permutation(np.arange(len(kpts1)))
+ kpts1 = kpts1[shuffling]
+ kpts2 = kpts2[shuffling]
+ try:
+ threshold = 0.5
+ if calibrated:
+ norm_threshold = threshold / (np.mean(np.abs(K1[:2, :2])) + np.mean(np.abs(K2[:2, :2])))
+ R_est, t_est, mask = estimate_pose(
+ kpts1.cpu().numpy(),
+ kpts2.cpu().numpy(),
+ K1,
+ K2,
+ norm_threshold,
+ conf=0.99999,
+ )
+ T1_to_2_est = np.concatenate((R_est, t_est), axis=-1) #
+ e_t, e_R = compute_pose_error(T1_to_2_est, R, t)
+ e_pose = max(e_t, e_R)
+ except Exception as e:
+ print(repr(e))
+ e_t, e_R = 90, 90
+ e_pose = max(e_t, e_R)
+ tot_e_t.append(e_t)
+ tot_e_R.append(e_R)
+ tot_e_pose.append(e_pose)
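+ # aggregate over all pairs: pose AUC at 5/10/20 degrees and mAP as cumulative means of the per-threshold accuracies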
+ tot_e_pose = np.array(tot_e_pose)
+ auc = pose_auc(tot_e_pose, thresholds)
+ acc_5 = (tot_e_pose < 5).mean()
+ acc_10 = (tot_e_pose < 10).mean()
+ acc_15 = (tot_e_pose < 15).mean()
+ acc_20 = (tot_e_pose < 20).mean()
+ map_5 = acc_5
+ map_10 = np.mean([acc_5, acc_10])
+ map_20 = np.mean([acc_5, acc_10, acc_15, acc_20])
+ print(f"{model_name} auc: {auc}")
+ return {
+ "auc_5": auc[0],
+ "auc_10": auc[1],
+ "auc_20": auc[2],
+ "map_5": map_5,
+ "map_10": map_10,
+ "map_20": map_20,
+ }
\ No newline at end of file
diff --git a/imcui/third_party/DeDoDe/DeDoDe/benchmarks/mega_pose_est_mnn.py b/imcui/third_party/DeDoDe/DeDoDe/benchmarks/mega_pose_est_mnn.py
new file mode 100644
index 0000000000000000000000000000000000000000..d717a09701889fdae42eb7aba7050025ad7c6c52
--- /dev/null
+++ b/imcui/third_party/DeDoDe/DeDoDe/benchmarks/mega_pose_est_mnn.py
@@ -0,0 +1,119 @@
+import numpy as np
+import torch
+from DeDoDe.utils import *
+from PIL import Image
+from tqdm import tqdm
+import torch.nn.functional as F
+
+class MegaDepthPoseMNNBenchmark:
+ def __init__(self, data_root="data/megadepth", scene_names = None) -> None:
+ if scene_names is None:
+ self.scene_names = [
+ "0015_0.1_0.3.npz",
+ "0015_0.3_0.5.npz",
+ "0022_0.1_0.3.npz",
+ "0022_0.3_0.5.npz",
+ "0022_0.5_0.7.npz",
+ ]
+ else:
+ self.scene_names = scene_names
+ self.scenes = [
+ np.load(f"{data_root}/{scene}", allow_pickle=True)
+ for scene in self.scene_names
+ ]
+ self.data_root = data_root
+
+ def benchmark(self, detector_model, descriptor_model, matcher_model, model_name = None, resolution = None, scale_intrinsics = False, calibrated = True):
+ with torch.no_grad():
+ data_root = self.data_root
+ tot_e_t, tot_e_R, tot_e_pose = [], [], []
+ thresholds = [5, 10, 20]
+ for scene_ind in range(len(self.scenes)):
+ import os
+ scene_name = os.path.splitext(self.scene_names[scene_ind])[0]
+ scene = self.scenes[scene_ind]
+ pairs = scene["pair_infos"]
+ intrinsics = scene["intrinsics"]
+ poses = scene["poses"]
+ im_paths = scene["image_paths"]
+ pair_inds = range(len(pairs))
+ for pairind in tqdm(pair_inds):
+ idx1, idx2 = pairs[pairind][0]
+ K1 = intrinsics[idx1].copy()
+ T1 = poses[idx1].copy()
+ R1, t1 = T1[:3, :3], T1[:3, 3]
+ K2 = intrinsics[idx2].copy()
+ T2 = poses[idx2].copy()
+ R2, t2 = T2[:3, :3], T2[:3, 3]
+ R, t = compute_relative_pose(R1, t1, R2, t2)
+ T1_to_2 = np.concatenate((R,t[:,None]), axis=-1)
+ im_A_path = f"{data_root}/{im_paths[idx1]}"
+ im_B_path = f"{data_root}/{im_paths[idx2]}"
+ detections_A = detector_model.detect_from_path(im_A_path)
+ keypoints_A, P_A = detections_A["keypoints"], detections_A["confidence"]
+ detections_B = detector_model.detect_from_path(im_B_path)
+ keypoints_B, P_B = detections_B["keypoints"], detections_B["confidence"]
+ description_A = descriptor_model.describe_keypoints_from_path(im_A_path, keypoints_A)["descriptions"]
+ description_B = descriptor_model.describe_keypoints_from_path(im_B_path, keypoints_B)["descriptions"]
+ matches_A, matches_B, batch_ids = matcher_model.match(keypoints_A, description_A,
+ keypoints_B, description_B,
+ P_A = P_A, P_B = P_B,
+ normalize = True, inv_temp=20, threshold = 0.01)
+
+ im_A = Image.open(im_A_path)
+ w1, h1 = im_A.size
+ im_B = Image.open(im_B_path)
+ w2, h2 = im_B.size
+ if scale_intrinsics:
+ scale1 = 840 / max(w1, h1)
+ scale2 = 840 / max(w2, h2)
+ w1, h1 = scale1 * w1, scale1 * h1
+ w2, h2 = scale2 * w2, scale2 * h2
+ K1, K2 = K1.copy(), K2.copy()
+ K1[:2] = K1[:2] * scale1
+ K2[:2] = K2[:2] * scale2
+ kpts1, kpts2 = matcher_model.to_pixel_coords(matches_A, matches_B, h1, w1, h2, w2)
+ for _ in range(1):
+ shuffling = np.random.permutation(np.arange(len(kpts1)))
+ kpts1 = kpts1[shuffling]
+ kpts2 = kpts2[shuffling]
+ try:
+ threshold = 0.5
+ if calibrated:
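+ # normalize the pixel threshold by the (approximate) focal lengths so that
+ # RANSAC operates in calibrated image coordinates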
+ norm_threshold = threshold / (np.mean(np.abs(K1[:2, :2])) + np.mean(np.abs(K2[:2, :2])))
+ R_est, t_est, mask = estimate_pose(
+ kpts1.cpu().numpy(),
+ kpts2.cpu().numpy(),
+ K1,
+ K2,
+ norm_threshold,
+ conf=0.99999,
+ )
+ T1_to_2_est = np.concatenate((R_est, t_est), axis=-1)
+ e_t, e_R = compute_pose_error(T1_to_2_est, R, t)
+ e_pose = max(e_t, e_R)
+ except Exception as e:
+ print(repr(e))
+ e_t, e_R = 90, 90
+ e_pose = max(e_t, e_R)
+ tot_e_t.append(e_t)
+ tot_e_R.append(e_R)
+ tot_e_pose.append(e_pose)
+ tot_e_pose = np.array(tot_e_pose)
+ auc = pose_auc(tot_e_pose, thresholds)
+ acc_5 = (tot_e_pose < 5).mean()
+ acc_10 = (tot_e_pose < 10).mean()
+ acc_15 = (tot_e_pose < 15).mean()
+ acc_20 = (tot_e_pose < 20).mean()
+ map_5 = acc_5
+ map_10 = np.mean([acc_5, acc_10])
+ map_20 = np.mean([acc_5, acc_10, acc_15, acc_20])
+ print(f"{model_name} auc: {auc}")
+ return {
+ "auc_5": auc[0],
+ "auc_10": auc[1],
+ "auc_20": auc[2],
+ "map_5": map_5,
+ "map_10": map_10,
+ "map_20": map_20,
+ }
\ No newline at end of file
diff --git a/imcui/third_party/DeDoDe/DeDoDe/benchmarks/nll_benchmark.py b/imcui/third_party/DeDoDe/DeDoDe/benchmarks/nll_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..d64103708919594bf8d297d92a908afb79f48002
--- /dev/null
+++ b/imcui/third_party/DeDoDe/DeDoDe/benchmarks/nll_benchmark.py
@@ -0,0 +1,57 @@
+import torch
+import torch.nn as nn
+from DeDoDe.utils import *
+import DeDoDe
+
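+# Benchmark that reports the mean negative log-likelihood of ground-truth
+# mutual-nearest correspondences (found by warping detected keypoints with depth
+# and relative pose) under the dual-softmax matching distribution.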
+class MegadepthNLLBenchmark(nn.Module):
+
+ def __init__(self, dataset, num_samples = 1000, batch_size = 8, device = "cuda") -> None:
+ super().__init__()
+ sampler = torch.utils.data.WeightedRandomSampler(
+ torch.ones(len(dataset)), replacement=False, num_samples=num_samples
+ )
+ dataloader = torch.utils.data.DataLoader(
+ dataset, batch_size=batch_size, num_workers=batch_size, sampler=sampler
+ )
+ self.dataloader = dataloader
+ self.tracked_metrics = {}
+ self.batch_size = batch_size
+ self.N = len(dataloader)
+
+ def compute_batch_metrics(self, detector, descriptor, batch, device = "cuda"):
+ kpts = detector.detect(batch)["keypoints"]
+ descriptions_A, descriptions_B = descriptor.describe_keypoints(batch, kpts)["descriptions"].chunk(2)
+ kpts_A, kpts_B = kpts.chunk(2)
+ mask_A_to_B, kpts_A_to_B = warp_kpts(kpts_A,
+ batch["im_A_depth"],
+ batch["im_B_depth"],
+ batch["T_1to2"],
+ batch["K1"],
+ batch["K2"],)
+ mask_B_to_A, kpts_B_to_A = warp_kpts(kpts_B,
+ batch["im_B_depth"],
+ batch["im_A_depth"],
+ batch["T_1to2"].inverse(),
+ batch["K2"],
+ batch["K1"],)
+ with torch.no_grad():
+ D_B = torch.cdist(kpts_A_to_B, kpts_B)
+ D_A = torch.cdist(kpts_A, kpts_B_to_A)
+ inds = torch.nonzero((D_B == D_B.min(dim=-1, keepdim = True).values)
+ * (D_A == D_A.min(dim=-2, keepdim = True).values)
+ * (D_B < 0.01)
+ * (D_A < 0.01))
+ logP_A_B = dual_log_softmax_matcher(descriptions_A, descriptions_B,
+ normalize = True,
+ inv_temperature = 20)
+ neg_log_likelihood = -logP_A_B[inds[:,0], inds[:,1], inds[:,2]].mean()
+ self.tracked_metrics["neg_log_likelihood"] = self.tracked_metrics.get("neg_log_likelihood", 0) + 1/self.N * neg_log_likelihood
+
+ def benchmark(self, detector, descriptor):
+ self.tracked_metrics = {}
+ from tqdm import tqdm
+ print("Evaluating percent inliers...")
+ for idx, batch in tqdm(enumerate(self.dataloader), mininterval = 10.):
+ batch = to_cuda(batch)
+ self.compute_batch_metrics(detector, descriptor, batch)
+ for name, metric in self.tracked_metrics.items():
+ print(name, metric.item() * self.N / (idx+1))
\ No newline at end of file
diff --git a/imcui/third_party/DeDoDe/DeDoDe/benchmarks/num_inliers.py b/imcui/third_party/DeDoDe/DeDoDe/benchmarks/num_inliers.py
new file mode 100644
index 0000000000000000000000000000000000000000..cb3a5869d8ff15ff4d0b300da8259a99e38c5cf2
--- /dev/null
+++ b/imcui/third_party/DeDoDe/DeDoDe/benchmarks/num_inliers.py
@@ -0,0 +1,76 @@
+import torch
+import torch.nn as nn
+from DeDoDe.utils import *
+import DeDoDe
+
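+# Benchmark that warps keypoints detected in image A into image B using ground-truth
+# depth and pose, and reports how often they land within several normalized-distance
+# thresholds of a keypoint detected in image B.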
+class NumInliersBenchmark(nn.Module):
+
+ def __init__(self, dataset, num_samples = 1000, batch_size = 8, num_keypoints = 10_000, device = get_best_device()) -> None:
+ super().__init__()
+ sampler = torch.utils.data.WeightedRandomSampler(
+ torch.ones(len(dataset)), replacement=False, num_samples=num_samples
+ )
+ dataloader = torch.utils.data.DataLoader(
+ dataset, batch_size=batch_size, num_workers=batch_size, sampler=sampler
+ )
+ self.dataloader = dataloader
+ self.tracked_metrics = {}
+ self.batch_size = batch_size
+ self.N = len(dataloader)
+ self.num_keypoints = num_keypoints
+
+ def compute_batch_metrics(self, outputs, batch, device = get_best_device()):
+ kpts_A, kpts_B = outputs["keypoints_A"], outputs["keypoints_B"]
+ B, K, H, W = batch["im_A"].shape
+ gt_warp_A_to_B, valid_mask_A_to_B = get_gt_warp(
+ batch["im_A_depth"],
+ batch["im_B_depth"],
+ batch["T_1to2"],
+ batch["K1"],
+ batch["K2"],
+ H=H,
+ W=W,
+ )
+ kpts_A_to_B = F.grid_sample(gt_warp_A_to_B[...,2:].float().permute(0,3,1,2), kpts_A[...,None,:],
+ align_corners=False, mode = 'bilinear')[...,0].mT
+ legit_A_to_B = F.grid_sample(valid_mask_A_to_B.reshape(B,1,H,W), kpts_A[...,None,:],
+ align_corners=False, mode = 'bilinear')[...,0,:,0]
+ dists = (torch.cdist(kpts_A_to_B, kpts_B).min(dim=-1).values[legit_A_to_B > 0.]).float()
+ if legit_A_to_B.sum() == 0:
+ return
+ percent_inliers_at_1 = (dists < 0.02).float().mean()
+ percent_inliers_at_05 = (dists < 0.01).float().mean()
+ percent_inliers_at_025 = (dists < 0.005).float().mean()
+ percent_inliers_at_01 = (dists < 0.002).float().mean()
+ percent_inliers_at_005 = (dists < 0.001).float().mean()
+
+ inlier_bins = torch.linspace(0, 0.002, steps = 100, device = device)[None]
+ inlier_counts = (dists[...,None] < inlier_bins).float().mean(dim=0)
+ self.tracked_metrics["inlier_counts"] = self.tracked_metrics.get("inlier_counts", 0) + 1/self.N * inlier_counts
+ self.tracked_metrics["percent_inliers_at_1"] = self.tracked_metrics.get("percent_inliers_at_1", 0) + 1/self.N * percent_inliers_at_1
+ self.tracked_metrics["percent_inliers_at_05"] = self.tracked_metrics.get("percent_inliers_at_05", 0) + 1/self.N * percent_inliers_at_05
+ self.tracked_metrics["percent_inliers_at_025"] = self.tracked_metrics.get("percent_inliers_at_025", 0) + 1/self.N * percent_inliers_at_025
+ self.tracked_metrics["percent_inliers_at_01"] = self.tracked_metrics.get("percent_inliers_at_01", 0) + 1/self.N * percent_inliers_at_01
+ self.tracked_metrics["percent_inliers_at_005"] = self.tracked_metrics.get("percent_inliers_at_005", 0) + 1/self.N * percent_inliers_at_005
+
+ def benchmark(self, detector):
+ self.tracked_metrics = {}
+ from tqdm import tqdm
+ print("Evaluating percent inliers...")
+ for idx, batch in tqdm(enumerate(self.dataloader), mininterval = 10.):
+ batch = to_best_device(batch)
+ outputs = detector.detect(batch, num_keypoints = self.num_keypoints)
+ keypoints_A, keypoints_B = outputs["keypoints"][:self.batch_size], outputs["keypoints"][self.batch_size:]
+ if isinstance(outputs["keypoints"], (tuple, list)):
+ keypoints_A, keypoints_B = torch.stack(keypoints_A), torch.stack(keypoints_B)
+ outputs = {"keypoints_A": keypoints_A, "keypoints_B": keypoints_B}
+ self.compute_batch_metrics(outputs, batch)
+ import matplotlib.pyplot as plt
+ plt.plot(torch.linspace(0, 0.002, steps = 100), self.tracked_metrics["inlier_counts"].cpu())
+ import numpy as np
+ x = np.linspace(0,0.002, 100)
+ sigma = 0.52 * 2 / 512
+ F = 1 - np.exp(-x**2 / (2*sigma**2))
+ plt.plot(x, F)
+ plt.savefig("vis/inlier_counts")
+ for name, metric in self.tracked_metrics.items():
+ if "percent" in name:
+ print(name, metric.item() * self.N / (idx+1))
\ No newline at end of file
diff --git a/imcui/third_party/DeDoDe/DeDoDe/checkpoint.py b/imcui/third_party/DeDoDe/DeDoDe/checkpoint.py
new file mode 100644
index 0000000000000000000000000000000000000000..07d6f80ae09acf5702475504a8e8d61f40c21cd3
--- /dev/null
+++ b/imcui/third_party/DeDoDe/DeDoDe/checkpoint.py
@@ -0,0 +1,59 @@
+import os
+import torch
+from torch.nn.parallel.data_parallel import DataParallel
+from torch.nn.parallel.distributed import DistributedDataParallel
+import gc
+
+import DeDoDe
+
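+# Simple checkpointing helper: rank 0 saves and restores model, optimizer and
+# lr-scheduler state under <dir><name>_latest.pth.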
+class CheckPoint:
+ def __init__(self, dir=None, name="tmp"):
+ self.name = name
+ self.dir = dir
+ os.makedirs(self.dir, exist_ok=True)
+
+ def save(
+ self,
+ model,
+ optimizer,
+ lr_scheduler,
+ n,
+ ):
+ if DeDoDe.RANK == 0:
+ assert model is not None
+ if isinstance(model, (DataParallel, DistributedDataParallel)):
+ model = model.module
+ states = {
+ "model": model.state_dict(),
+ "n": n,
+ "optimizer": optimizer.state_dict(),
+ "lr_scheduler": lr_scheduler.state_dict(),
+ }
+ torch.save(states, self.dir + self.name + f"_latest.pth")
+ print(f"Saved states {list(states.keys())}, at step {n}")
+
+ def load(
+ self,
+ model,
+ optimizer,
+ lr_scheduler,
+ n,
+ ):
+ if os.path.exists(self.dir + self.name + f"_latest.pth") and DeDoDe.RANK == 0:
+ states = torch.load(self.dir + self.name + f"_latest.pth")
+ if "model" in states:
+ model.load_state_dict(states["model"])
+ if "n" in states:
+ n = states["n"] if states["n"] else n
+ if "optimizer" in states:
+ try:
+ optimizer.load_state_dict(states["optimizer"])
+ except Exception as e:
+ print(f"Failed to load states for optimizer, with error {e}")
+ if "lr_scheduler" in states:
+ lr_scheduler.load_state_dict(states["lr_scheduler"])
+ print(f"Loaded states {list(states.keys())}, at step {n}")
+ del states
+ gc.collect()
+ torch.cuda.empty_cache()
+ return model, optimizer, lr_scheduler, n
\ No newline at end of file
diff --git a/imcui/third_party/DeDoDe/DeDoDe/datasets/__init__.py b/imcui/third_party/DeDoDe/DeDoDe/datasets/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/imcui/third_party/DeDoDe/DeDoDe/datasets/megadepth.py b/imcui/third_party/DeDoDe/DeDoDe/datasets/megadepth.py
new file mode 100644
index 0000000000000000000000000000000000000000..7de9d9a8e270fb74a6591944878c0e5e70ddf650
--- /dev/null
+++ b/imcui/third_party/DeDoDe/DeDoDe/datasets/megadepth.py
@@ -0,0 +1,269 @@
+import os
+from PIL import Image
+import h5py
+import numpy as np
+import torch
+import torchvision.transforms.functional as tvf
+from tqdm import tqdm
+
+from DeDoDe.utils import get_depth_tuple_transform_ops, get_tuple_transform_ops
+import DeDoDe
+from DeDoDe.utils import *
+
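+# A single MegaDepth scene: serves image pairs filtered by covisibility overlap,
+# together with depth maps, poses, intrinsics and precomputed SfM detections /
+# 3D tracks, with optional shake, horizontal-flip, grayscale and CLAHE augmentation.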
+class MegadepthScene:
+ def __init__(
+ self,
+ data_root,
+ scene_info,
+ ht=512,
+ wt=512,
+ min_overlap=0.0,
+ max_overlap=1.0,
+ shake_t=0,
+ scene_info_detections=None,
+ scene_info_detections3D=None,
+ normalize=True,
+ max_num_pairs = 100_000,
+ scene_name = None,
+ use_horizontal_flip_aug = False,
+ grayscale = False,
+ clahe = False,
+ ) -> None:
+ self.data_root = data_root
+ self.scene_name = os.path.splitext(scene_name)[0]+f"_{min_overlap}_{max_overlap}"
+ self.image_paths = scene_info["image_paths"]
+ self.depth_paths = scene_info["depth_paths"]
+ self.intrinsics = scene_info["intrinsics"]
+ self.poses = scene_info["poses"]
+ self.pairs = scene_info["pairs"]
+ self.overlaps = scene_info["overlaps"]
+ threshold = (self.overlaps > min_overlap) & (self.overlaps < max_overlap)
+ self.pairs = self.pairs[threshold]
+ self.overlaps = self.overlaps[threshold]
+ self.detections = scene_info_detections
+ self.tracks3D = scene_info_detections3D
+ if len(self.pairs) > max_num_pairs:
+ pairinds = np.random.choice(
+ np.arange(0, len(self.pairs)), max_num_pairs, replace=False
+ )
+ self.pairs = self.pairs[pairinds]
+ self.overlaps = self.overlaps[pairinds]
+ self.im_transform_ops = get_tuple_transform_ops(
+ resize=(ht, wt), normalize=normalize, clahe = clahe,
+ )
+ self.depth_transform_ops = get_depth_tuple_transform_ops(
+ resize=(ht, wt), normalize=False
+ )
+ self.wt, self.ht = wt, ht
+ self.shake_t = shake_t
+ self.use_horizontal_flip_aug = use_horizontal_flip_aug
+ self.grayscale = grayscale
+
+ def load_im(self, im_B, crop=None):
+ im = Image.open(im_B)
+ return im
+
+ def horizontal_flip(self, im_A, im_B, depth_A, depth_B, K_A, K_B):
+ im_A = im_A.flip(-1)
+ im_B = im_B.flip(-1)
+ depth_A, depth_B = depth_A.flip(-1), depth_B.flip(-1)
+ flip_mat = torch.tensor([[-1, 0, self.wt],[0,1,0],[0,0,1.]]).to(K_A.device)
+ K_A = flip_mat@K_A
+ K_B = flip_mat@K_B
+
+ return im_A, im_B, depth_A, depth_B, K_A, K_B
+
+ def load_depth(self, depth_ref, crop=None):
+ depth = np.array(h5py.File(depth_ref, "r")["depth"])
+ return torch.from_numpy(depth)
+
+ def __len__(self):
+ return len(self.pairs)
+
+ def scale_intrinsic(self, K, wi, hi):
+ sx, sy = self.wt / wi, self.ht / hi
+ sK = torch.tensor([[sx, 0, 0], [0, sy, 0], [0, 0, 1]])
+ return sK @ K
+
+ def scale_detections(self, detections, wi, hi):
+ sx, sy = self.wt / wi, self.ht / hi
+ return detections * torch.tensor([[sx,sy]])
+
+ def rand_shake(self, *things):
+ t = np.random.choice(range(-self.shake_t, self.shake_t + 1), size=(2))
+ return [
+ tvf.affine(thing, angle=0.0, translate=list(t), scale=1.0, shear=[0.0, 0.0])
+ for thing in things
+ ], t
+
+ def tracks_to_detections(self, tracks3D, pose, intrinsics, H, W):
+ tracks3D = tracks3D.double()
+ intrinsics = intrinsics.double()
+ bearing_vectors = pose[...,:3,:3] @ tracks3D.mT + pose[...,:3,3:]
+ hom_pixel_coords = (intrinsics @ bearing_vectors).mT
+ pixel_coords = hom_pixel_coords[...,:2] / (hom_pixel_coords[...,2:]+1e-12)
+ legit_detections = (pixel_coords > 0).prod(dim = -1) * (pixel_coords[...,0] < W - 1) * (pixel_coords[...,1] < H - 1) * (tracks3D != 0).prod(dim=-1)
+ return pixel_coords.float(), legit_detections.bool()
+
+ def __getitem__(self, pair_idx):
+ try:
+ # read intrinsics of original size
+ idx1, idx2 = self.pairs[pair_idx]
+ K1 = torch.tensor(self.intrinsics[idx1].copy(), dtype=torch.float).reshape(3, 3)
+ K2 = torch.tensor(self.intrinsics[idx2].copy(), dtype=torch.float).reshape(3, 3)
+
+ # read and compute relative poses
+ T1 = self.poses[idx1]
+ T2 = self.poses[idx2]
+ T_1to2 = torch.tensor(np.matmul(T2, np.linalg.inv(T1)), dtype=torch.float)[
+ :4, :4
+ ] # (4, 4)
+
+ # Load positive pair data
+ im_A, im_B = self.image_paths[idx1], self.image_paths[idx2]
+ depth1, depth2 = self.depth_paths[idx1], self.depth_paths[idx2]
+ im_A_ref = os.path.join(self.data_root, im_A)
+ im_B_ref = os.path.join(self.data_root, im_B)
+ depth_A_ref = os.path.join(self.data_root, depth1)
+ depth_B_ref = os.path.join(self.data_root, depth2)
+ # return torch.randn((1000,1000))
+ im_A = self.load_im(im_A_ref)
+ im_B = self.load_im(im_B_ref)
+ depth_A = self.load_depth(depth_A_ref)
+ depth_B = self.load_depth(depth_B_ref)
+
+ # Recompute camera intrinsic matrix due to the resize
+ W_A, H_A = im_A.width, im_A.height
+ W_B, H_B = im_B.width, im_B.height
+
+ detections2D_A = self.detections[idx1]
+ detections2D_B = self.detections[idx2]
+
+ K = 10000
+ tracks3D_A = torch.zeros(K,3)
+ tracks3D_B = torch.zeros(K,3)
+ tracks3D_A[:len(detections2D_A)] = torch.tensor(self.tracks3D[detections2D_A[:K,-1].astype(np.int32)])
+ tracks3D_B[:len(detections2D_B)] = torch.tensor(self.tracks3D[detections2D_B[:K,-1].astype(np.int32)])
+
+ #projs_A, _ = self.tracks_to_detections(tracks3D_A, T1, K1, W_A, H_A)
+ #tracks3D_B = torch.zeros(K,2)
+
+ K1 = self.scale_intrinsic(K1, W_A, H_A)
+ K2 = self.scale_intrinsic(K2, W_B, H_B)
+
+ # Process images
+ im_A, im_B = self.im_transform_ops((im_A, im_B))
+ depth_A, depth_B = self.depth_transform_ops(
+ (depth_A[None, None], depth_B[None, None])
+ )
+ [im_A, depth_A], t_A = self.rand_shake(im_A, depth_A)
+ [im_B, depth_B], t_B = self.rand_shake(im_B, depth_B)
+
+ detections_A = -torch.ones(K,2)
+ detections_B = -torch.ones(K,2)
+ detections_A[:len(self.detections[idx1])] = self.scale_detections(torch.tensor(detections2D_A[:K,:2]), W_A, H_A) + t_A
+ detections_B[:len(self.detections[idx2])] = self.scale_detections(torch.tensor(detections2D_B[:K,:2]), W_B, H_B) + t_B
+
+
+ K1[:2, 2] += t_A
+ K2[:2, 2] += t_B
+
+ if self.use_horizontal_flip_aug:
+ if np.random.rand() > 0.5:
+ im_A, im_B, depth_A, depth_B, K1, K2 = self.horizontal_flip(im_A, im_B, depth_A, depth_B, K1, K2)
+ detections_A[:,0] = self.wt - detections_A[:,0]
+ detections_B[:,0] = self.wt - detections_B[:,0]
+
+ if DeDoDe.DEBUG_MODE:
+ tensor_to_pil(im_A[0], unnormalize=True).save(
+ f"vis/im_A.jpg")
+ tensor_to_pil(im_B[0], unnormalize=True).save(
+ f"vis/im_B.jpg")
+ if self.grayscale:
+ im_A = im_A.mean(dim=-3,keepdim=True)
+ im_B = im_B.mean(dim=-3,keepdim=True)
+ data_dict = {
+ "im_A": im_A,
+ "im_A_identifier": self.image_paths[idx1].split("/")[-1].split(".jpg")[0],
+ "im_B": im_B,
+ "im_B_identifier": self.image_paths[idx2].split("/")[-1].split(".jpg")[0],
+ "im_A_depth": depth_A[0, 0],
+ "im_B_depth": depth_B[0, 0],
+ "pose_A": T1,
+ "pose_B": T2,
+ "detections_A": detections_A,
+ "detections_B": detections_B,
+ "tracks3D_A": tracks3D_A,
+ "tracks3D_B": tracks3D_B,
+ "K1": K1,
+ "K2": K2,
+ "T_1to2": T_1to2,
+ "im_A_path": im_A_ref,
+ "im_B_path": im_B_ref,
+ }
+ except Exception as e:
+ print(e)
+ print(f"Failed to load image pair {self.pairs[pair_idx]}")
+ print("Loading a random pair in scene instead")
+ rand_ind = np.random.choice(range(len(self)))
+ return self[rand_ind]
+ return data_dict
+
+
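+# Builds MegadepthScene datasets for a given split, optionally skipping the LoFTR
+# and IMC2021 ignore lists, and provides per-scene sampling weights.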
+class MegadepthBuilder:
+ def __init__(self, data_root="data/megadepth", loftr_ignore=True, imc21_ignore = True) -> None:
+ self.data_root = data_root
+ self.scene_info_root = os.path.join(data_root, "prep_scene_info")
+ self.all_scenes = os.listdir(self.scene_info_root)
+ self.test_scenes = ["0017.npy", "0004.npy", "0048.npy", "0013.npy"]
+ # LoFTR did the D2-Net preprocessing differently than we did, which results in more scenes to ignore; these can optionally be ignored here as well
+ self.loftr_ignore_scenes = set(['0121.npy', '0133.npy', '0168.npy', '0178.npy', '0229.npy', '0349.npy', '0412.npy', '0430.npy', '0443.npy', '1001.npy', '5014.npy', '5015.npy', '5016.npy'])
+ self.imc21_scenes = set(['0008.npy', '0019.npy', '0021.npy', '0024.npy', '0025.npy', '0032.npy', '0063.npy', '1589.npy'])
+ self.test_scenes_loftr = ["0015.npy", "0022.npy"]
+ self.loftr_ignore = loftr_ignore
+ self.imc21_ignore = imc21_ignore
+
+ def build_scenes(self, split="train", min_overlap=0.0, scene_names = None, **kwargs):
+ if split == "train":
+ scene_names = set(self.all_scenes) - set(self.test_scenes)
+ elif split == "train_loftr":
+ scene_names = set(self.all_scenes) - set(self.test_scenes_loftr)
+ elif split == "test":
+ scene_names = self.test_scenes
+ elif split == "test_loftr":
+ scene_names = self.test_scenes_loftr
+ elif split == "custom":
+ scene_names = scene_names
+ else:
+ raise ValueError(f"Split {split} not available")
+ scenes = []
+ for scene_name in tqdm(scene_names):
+ if self.loftr_ignore and scene_name in self.loftr_ignore_scenes:
+ continue
+ if self.imc21_ignore and scene_name in self.imc21_scenes:
+ continue
+ if ".npy" not in scene_name:
+ continue
+ scene_info = np.load(
+ os.path.join(self.scene_info_root, scene_name), allow_pickle=True
+ ).item()
+ scene_info_detections = np.load(
+ os.path.join(self.scene_info_root, "detections", f"detections_{scene_name}"), allow_pickle=True
+ ).item()
+ scene_info_detections3D = np.load(
+ os.path.join(self.scene_info_root, "detections3D", f"detections3D_{scene_name}"), allow_pickle=True
+ )
+
+ scenes.append(
+ MegadepthScene(
+ self.data_root, scene_info, scene_info_detections = scene_info_detections, scene_info_detections3D = scene_info_detections3D, min_overlap=min_overlap,scene_name = scene_name, **kwargs
+ )
+ )
+ return scenes
+
+ def weight_scenes(self, concat_dataset, alpha=0.5):
+ ns = []
+ for d in concat_dataset.datasets:
+ ns.append(len(d))
+ ws = torch.cat([torch.ones(n) / n**alpha for n in ns])
+ return ws
diff --git a/imcui/third_party/DeDoDe/DeDoDe/decoder.py b/imcui/third_party/DeDoDe/DeDoDe/decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..4e1b58fcc588e6ee12c591b5f446829a914bc611
--- /dev/null
+++ b/imcui/third_party/DeDoDe/DeDoDe/decoder.py
@@ -0,0 +1,90 @@
+import torch
+import torch.nn as nn
+import torchvision.models as tvm
+
+
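+# Coarse-to-fine decoder: each scale has its own refiner whose output is split into
+# num_prototypes logit/descriptor channels and a context tensor that is concatenated
+# onto the features of the next (finer) scale.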
+class Decoder(nn.Module):
+ def __init__(self, layers, *args, super_resolution = False, num_prototypes = 1, **kwargs) -> None:
+ super().__init__(*args, **kwargs)
+ self.layers = layers
+ self.scales = self.layers.keys()
+ self.super_resolution = super_resolution
+ self.num_prototypes = num_prototypes
+ def forward(self, features, context = None, scale = None):
+ if context is not None:
+ features = torch.cat((features, context), dim = 1)
+ stuff = self.layers[scale](features)
+ logits, context = stuff[:,:self.num_prototypes], stuff[:,self.num_prototypes:]
+ return logits, context
+
+class ConvRefiner(nn.Module):
+ def __init__(
+ self,
+ in_dim=6,
+ hidden_dim=16,
+ out_dim=2,
+ dw=True,
+ kernel_size=5,
+ hidden_blocks=5,
+ amp = True,
+ residual = False,
+ amp_dtype = torch.float16,
+ ):
+ super().__init__()
+ self.block1 = self.create_block(
+ in_dim, hidden_dim, dw=False, kernel_size=1,
+ )
+ self.hidden_blocks = nn.Sequential(
+ *[
+ self.create_block(
+ hidden_dim,
+ hidden_dim,
+ dw=dw,
+ kernel_size=kernel_size,
+ )
+ for hb in range(hidden_blocks)
+ ]
+ )
+ self.out_conv = nn.Conv2d(hidden_dim, out_dim, 1, 1, 0)
+ self.amp = amp
+ self.amp_dtype = amp_dtype
+ self.residual = residual
+
+ def create_block(
+ self,
+ in_dim,
+ out_dim,
+ dw=True,
+ kernel_size=5,
+ bias = True,
+ norm_type = nn.BatchNorm2d,
+ ):
+ num_groups = 1 if not dw else in_dim
+ if dw:
+ assert (
+ out_dim % in_dim == 0
+ ), "outdim must be divisible by indim for depthwise"
+ conv1 = nn.Conv2d(
+ in_dim,
+ out_dim,
+ kernel_size=kernel_size,
+ stride=1,
+ padding=kernel_size // 2,
+ groups=num_groups,
+ bias=bias,
+ )
+ norm = norm_type(out_dim) if norm_type is nn.BatchNorm2d else norm_type(num_channels = out_dim)
+ relu = nn.ReLU(inplace=True)
+ conv2 = nn.Conv2d(out_dim, out_dim, 1, 1, 0)
+ return nn.Sequential(conv1, norm, relu, conv2)
+
+ def forward(self, feats):
+ b,c,hs,ws = feats.shape
+ with torch.autocast("cuda", enabled=self.amp, dtype = self.amp_dtype):
+ x0 = self.block1(feats)
+ x = self.hidden_blocks(x0)
+ if self.residual:
+ x = (x + x0)/1.4
+ x = self.out_conv(x)
+ return x
diff --git a/imcui/third_party/DeDoDe/DeDoDe/descriptors/__init__.py b/imcui/third_party/DeDoDe/DeDoDe/descriptors/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/imcui/third_party/DeDoDe/DeDoDe/descriptors/dedode_descriptor.py b/imcui/third_party/DeDoDe/DeDoDe/descriptors/dedode_descriptor.py
new file mode 100644
index 0000000000000000000000000000000000000000..47629729f36b96aef4604e05bb99bd59b6ee070c
--- /dev/null
+++ b/imcui/third_party/DeDoDe/DeDoDe/descriptors/dedode_descriptor.py
@@ -0,0 +1,50 @@
+import torch
+from PIL import Image
+import torch.nn as nn
+import torchvision.models as tvm
+import torch.nn.functional as F
+import numpy as np
+from DeDoDe.utils import get_best_device
+
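+# Dense descriptor model: encoder features are decoded coarse-to-fine into a
+# description grid, which is bilinearly sampled at normalized keypoint locations.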
+class DeDoDeDescriptor(nn.Module):
+ def __init__(self, encoder, decoder, *args, **kwargs) -> None:
+ super().__init__(*args, **kwargs)
+ self.encoder = encoder
+ self.decoder = decoder
+ import torchvision.transforms as transforms
+ self.normalizer = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+
+ def forward(
+ self,
+ batch,
+ ):
+ if "im_A" in batch:
+ images = torch.cat((batch["im_A"], batch["im_B"]))
+ else:
+ images = batch["image"]
+ features, sizes = self.encoder(images)
+ descriptor = 0
+ context = None
+ scales = self.decoder.scales
+ for idx, (feature_map, scale) in enumerate(zip(reversed(features), scales)):
+ delta_descriptor, context = self.decoder(feature_map, scale = scale, context = context)
+ descriptor = descriptor + delta_descriptor
+ if idx < len(scales) - 1:
+ size = sizes[-(idx+2)]
+ descriptor = F.interpolate(descriptor, size = size, mode = "bilinear", align_corners = False)
+ context = F.interpolate(context, size = size, mode = "bilinear", align_corners = False)
+ return {"description_grid" : descriptor}
+
+ @torch.inference_mode()
+ def describe_keypoints(self, batch, keypoints):
+ self.train(False)
+ description_grid = self.forward(batch)["description_grid"]
+ described_keypoints = F.grid_sample(description_grid.float(), keypoints[:,None], mode = "bilinear", align_corners = False)[:,:,0].mT
+ return {"descriptions": described_keypoints}
+
+ def read_image(self, im_path, H = 784, W = 784, device=get_best_device()):
+ pil_im = Image.open(im_path).resize((W, H))
+ standard_im = np.array(pil_im)/255.
+ return self.normalizer(torch.from_numpy(standard_im).permute(2,0,1)).float().to(device)[None]
+
+ def describe_keypoints_from_path(self, im_path, keypoints, H = 784, W = 784):
+ batch = {"image": self.read_image(im_path, H = H, W = W)}
+ return self.describe_keypoints(batch, keypoints)
\ No newline at end of file
diff --git a/imcui/third_party/DeDoDe/DeDoDe/descriptors/descriptor_loss.py b/imcui/third_party/DeDoDe/DeDoDe/descriptors/descriptor_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..7ece7fed2db02ea8ea51b4b5f49391cdcaef0903
--- /dev/null
+++ b/imcui/third_party/DeDoDe/DeDoDe/descriptors/descriptor_loss.py
@@ -0,0 +1,68 @@
+import torch
+import torch.nn as nn
+import math
+import torch.nn.functional as F
+
+from DeDoDe.utils import *
+import DeDoDe
+
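+# Descriptor training loss: keypoints from the provided detector (run in inference
+# mode) are warped with depth/pose or a homography to form ground-truth
+# mutual-nearest matches, and the loss is their negative log-likelihood under the
+# dual-softmax matcher.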
+class DescriptorLoss(nn.Module):
+
+ def __init__(self, detector, num_keypoints = 5000, normalize_descriptions = False, inv_temp = 1, device = get_best_device()) -> None:
+ super().__init__()
+ self.detector = detector
+ self.tracked_metrics = {}
+ self.num_keypoints = num_keypoints
+ self.normalize_descriptions = normalize_descriptions
+ self.inv_temp = inv_temp
+
+ def warp_from_depth(self, batch, kpts_A, kpts_B):
+ mask_A_to_B, kpts_A_to_B = warp_kpts(kpts_A,
+ batch["im_A_depth"],
+ batch["im_B_depth"],
+ batch["T_1to2"],
+ batch["K1"],
+ batch["K2"],)
+ mask_B_to_A, kpts_B_to_A = warp_kpts(kpts_B,
+ batch["im_B_depth"],
+ batch["im_A_depth"],
+ batch["T_1to2"].inverse(),
+ batch["K2"],
+ batch["K1"],)
+ return (mask_A_to_B, kpts_A_to_B), (mask_B_to_A, kpts_B_to_A)
+
+ def warp_from_homog(self, batch, kpts_A, kpts_B):
+ kpts_A_to_B = homog_transform(batch["Homog_A_to_B"], kpts_A)
+ kpts_B_to_A = homog_transform(batch["Homog_A_to_B"].inverse(), kpts_B)
+ return (None, kpts_A_to_B), (None, kpts_B_to_A)
+
+ def supervised_loss(self, outputs, batch):
+ kpts_A, kpts_B = self.detector.detect(batch, num_keypoints = self.num_keypoints)['keypoints'].clone().chunk(2)
+ desc_grid_A, desc_grid_B = outputs["description_grid"].chunk(2)
+ desc_A = F.grid_sample(desc_grid_A.float(), kpts_A[:,None], mode = "bilinear", align_corners = False)[:,:,0].mT
+ desc_B = F.grid_sample(desc_grid_B.float(), kpts_B[:,None], mode = "bilinear", align_corners = False)[:,:,0].mT
+ if "im_A_depth" in batch:
+ (mask_A_to_B, kpts_A_to_B), (mask_B_to_A, kpts_B_to_A) = self.warp_from_depth(batch, kpts_A, kpts_B)
+ elif "Homog_A_to_B" in batch:
+ (mask_A_to_B, kpts_A_to_B), (mask_B_to_A, kpts_B_to_A) = self.warp_from_homog(batch, kpts_A, kpts_B)
+
+ with torch.no_grad():
+ D_B = torch.cdist(kpts_A_to_B, kpts_B)
+ D_A = torch.cdist(kpts_A, kpts_B_to_A)
+ inds = torch.nonzero((D_B == D_B.min(dim=-1, keepdim = True).values)
+ * (D_A == D_A.min(dim=-2, keepdim = True).values)
+ * (D_B < 0.01)
+ * (D_A < 0.01))
+
+ logP_A_B = dual_log_softmax_matcher(desc_A, desc_B,
+ normalize = self.normalize_descriptions,
+ inv_temperature = self.inv_temp)
+ neg_log_likelihood = -logP_A_B[inds[:,0], inds[:,1], inds[:,2]].mean()
+ self.tracked_metrics["neg_log_likelihood"] = (0.99 * self.tracked_metrics.get("neg_log_likelihood", neg_log_likelihood.detach().item()) + 0.01 * neg_log_likelihood.detach().item())
+ if np.random.rand() > 0.99:
+ print(self.tracked_metrics["neg_log_likelihood"])
+ return neg_log_likelihood
+
+ def forward(self, outputs, batch):
+ losses = self.supervised_loss(outputs, batch)
+ return losses
\ No newline at end of file
diff --git a/imcui/third_party/DeDoDe/DeDoDe/detectors/__init__.py b/imcui/third_party/DeDoDe/DeDoDe/detectors/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/imcui/third_party/DeDoDe/DeDoDe/detectors/dedode_detector.py b/imcui/third_party/DeDoDe/DeDoDe/detectors/dedode_detector.py
new file mode 100644
index 0000000000000000000000000000000000000000..4a02f621a4a93a30df94c2fe5f6fd0297ce53f95
--- /dev/null
+++ b/imcui/third_party/DeDoDe/DeDoDe/detectors/dedode_detector.py
@@ -0,0 +1,76 @@
+import torch
+from PIL import Image
+import torch.nn as nn
+import torchvision.models as tvm
+import torch.nn.functional as F
+import numpy as np
+
+from DeDoDe.utils import sample_keypoints, to_pixel_coords, to_normalized_coords, get_best_device
+
+
+
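+# Keypoint detector: decodes encoder features coarse-to-fine into a keypoint logit
+# map and samples the top-scoring locations from the resulting probability map.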
+class DeDoDeDetector(nn.Module):
+ def __init__(self, encoder, decoder, *args, remove_borders = False, **kwargs) -> None:
+ super().__init__(*args, **kwargs)
+ self.encoder = encoder
+ self.decoder = decoder
+ import torchvision.transforms as transforms
+ self.normalizer = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+ self.remove_borders = remove_borders
+
+ def forward(
+ self,
+ batch,
+ ):
+ if "im_A" in batch:
+ images = torch.cat((batch["im_A"], batch["im_B"]))
+ else:
+ images = batch["image"]
+ features, sizes = self.encoder(images)
+ logits = 0
+ context = None
+ scales = ["8", "4", "2", "1"]
+ for idx, (feature_map, scale) in enumerate(zip(reversed(features), scales)):
+ delta_logits, context = self.decoder(feature_map, context = context, scale = scale)
+ logits = logits + delta_logits.float() # ensure float (F.interpolate does not support bfloat16)
+ if idx < len(scales) - 1:
+ size = sizes[-(idx+2)]
+ logits = F.interpolate(logits, size = size, mode = "bicubic", align_corners = False)
+ context = F.interpolate(context.float(), size = size, mode = "bilinear", align_corners = False)
+ return {"keypoint_logits" : logits.float()}
+
+ @torch.inference_mode()
+ def detect(self, batch, num_keypoints = 10_000):
+ self.train(False)
+ keypoint_logits = self.forward(batch)["keypoint_logits"]
+ B,K,H,W = keypoint_logits.shape
+ keypoint_p = keypoint_logits.reshape(B, K*H*W).softmax(dim=-1).reshape(B, K, H*W).sum(dim=1)
+ keypoints, confidence = sample_keypoints(keypoint_p.reshape(B,H,W),
+ use_nms = False, sample_topk = True, num_samples = num_keypoints,
+ return_scoremap=True, sharpen = False, upsample = False,
+ increase_coverage=True, remove_borders = self.remove_borders)
+ return {"keypoints": keypoints, "confidence": confidence}
+
+ @torch.inference_mode()
+ def detect_dense(self, batch):
+ self.train(False)
+ keypoint_logits = self.forward(batch)["keypoint_logits"]
+ return {"dense_keypoint_logits": keypoint_logits}
+
+ def read_image(self, im_path, H = 784, W = 784, device=get_best_device()):
+ pil_im = Image.open(im_path).resize((W, H))
+ standard_im = np.array(pil_im)/255.
+ return self.normalizer(torch.from_numpy(standard_im).permute(2,0,1)).float().to(device)[None]
+
+ def detect_from_path(self, im_path, num_keypoints = 30_000, H = 784, W = 784, dense = False):
+ batch = {"image": self.read_image(im_path, H = H, W = W)}
+ if dense:
+ return self.detect_dense(batch)
+ else:
+ return self.detect(batch, num_keypoints = num_keypoints)
+
+ def to_pixel_coords(self, x, H, W):
+ return to_pixel_coords(x, H, W)
+
+ def to_normalized_coords(self, x, H, W):
+ return to_normalized_coords(x, H, W)
\ No newline at end of file
diff --git a/imcui/third_party/DeDoDe/DeDoDe/detectors/keypoint_loss.py b/imcui/third_party/DeDoDe/DeDoDe/detectors/keypoint_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..d8dfbd7747aedad25101dd4b59e9cf950bfe4880
--- /dev/null
+++ b/imcui/third_party/DeDoDe/DeDoDe/detectors/keypoint_loss.py
@@ -0,0 +1,185 @@
+import torch
+import torch.nn as nn
+import math
+
+from DeDoDe.utils import *
+import DeDoDe
+
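+# Detector training loss: a joint negative log-likelihood over covisible locations
+# (with extra weight near the provided SfM detections) plus a matchability term that
+# ties the keypoint distribution to regions with valid depth.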
+class KeyPointLoss(nn.Module):
+
+ def __init__(self, smoothing_size = 1, use_max_logit = False, entropy_target = 80,
+ num_matches = 1024, jacobian_density_adjustment = False,
+ matchability_weight = 1, device = "cuda") -> None:
+ super().__init__()
+ X = torch.linspace(-1,1,smoothing_size, device = device)
+ G = (-X**2 / (2 *1/2**2)).exp()
+ G = G/G.sum()
+ self.use_max_logit = use_max_logit
+ self.entropy_target = entropy_target
+ self.smoothing_kernel = G[None, None, None,:]
+ self.smoothing_size = smoothing_size
+ self.tracked_metrics = {}
+ self.center = None
+ self.num_matches = num_matches
+ self.jacobian_density_adjustment = jacobian_density_adjustment
+ self.matchability_weight = matchability_weight
+
+ def compute_consistency(self, logits_A, logits_B_to_A, mask = None):
+
+ masked_logits_A = torch.full_like(logits_A, -torch.inf)
+ masked_logits_A[mask] = logits_A[mask]
+
+ masked_logits_B_to_A = torch.full_like(logits_B_to_A, -torch.inf)
+ masked_logits_B_to_A[mask] = logits_B_to_A[mask]
+
+ log_p_A = masked_logits_A.log_softmax(dim=-1)[mask]
+ log_p_B_to_A = masked_logits_B_to_A.log_softmax(dim=-1)[mask]
+
+ return self.compute_jensen_shannon_div(log_p_A, log_p_B_to_A)
+
+ def compute_joint_neg_log_likelihood(self, logits_A, logits_B_to_A, detections_A = None, detections_B_to_A = None, mask = None, device = "cuda", dtype = torch.float32, num_matches = None):
+ B, K, HW = logits_A.shape
+ logits_A, logits_B_to_A = logits_A.to(dtype), logits_B_to_A.to(dtype)
+ mask = mask[:,None].expand(B, K, HW).reshape(B, K*HW)
+ log_p_B_to_A = self.masked_log_softmax(logits_B_to_A.reshape(B,K*HW), mask = mask)
+ log_p_A = self.masked_log_softmax(logits_A.reshape(B,K*HW), mask = mask)
+ log_p = log_p_A + log_p_B_to_A
+ if detections_A is None:
+ detections_A = torch.zeros_like(log_p_A)
+ if detections_B_to_A is None:
+ detections_B_to_A = torch.zeros_like(log_p_B_to_A)
+ detections_A = detections_A.reshape(B, HW)
+ detections_A[~mask] = 0
+ detections_B_to_A = detections_B_to_A.reshape(B, HW)
+ detections_B_to_A[~mask] = 0
+ log_p_target = log_p.detach() + 50*detections_A + 50*detections_B_to_A
+ num_matches = self.num_matches if num_matches is None else num_matches
+ best_k = -(-log_p_target).flatten().kthvalue(k = B * num_matches, dim=-1).values
+ p_target = (log_p_target > best_k[..., None]).float().reshape(B,K*HW)/num_matches
+ return self.compute_cross_entropy(log_p_A[mask], p_target[mask]) + self.compute_cross_entropy(log_p_B_to_A[mask], p_target[mask])
+
+ def compute_jensen_shannon_div(self, log_p, log_q):
+ return 1/2 * (self.compute_kl_div(log_p, log_q) + self.compute_kl_div(log_q, log_p))
+
+ def compute_kl_div(self, log_p, log_q):
+ return (log_p.exp()*(log_p-log_q)).sum(dim=-1)
+
+ def masked_log_softmax(self, logits, mask):
+ masked_logits = torch.full_like(logits, -torch.inf)
+ masked_logits[mask] = logits[mask]
+ log_p = masked_logits.log_softmax(dim=-1)
+ return log_p
+
+ def masked_softmax(self, logits, mask):
+ masked_logits = torch.full_like(logits, -torch.inf)
+ masked_logits[mask] = logits[mask]
+ log_p = masked_logits.softmax(dim=-1)
+ return log_p
+
+ def compute_detection_img(self, detections, mask, B, H, W, device = "cuda"):
+ kernel_size = 5
+ X = torch.linspace(-2,2,kernel_size, device = device)
+ G = (-X**2 / (2 * (1/2)**2)).exp() # half pixel std
+ G = G/G.sum()
+ det_smoothing_kernel = G[None, None, None,:]
+ det_img = torch.zeros((B,1,H,W), device = device)
+ for b in range(B):
+ valid_detections = (detections[b][mask[b]]).int()
+ det_img[b,0][valid_detections[:,1], valid_detections[:,0]] = 1
+ det_img = F.conv2d(det_img, weight = det_smoothing_kernel, padding = (kernel_size//2, 0))
+ det_img = F.conv2d(det_img, weight = det_smoothing_kernel.mT, padding = (0, kernel_size//2))
+ return det_img
+
+ def compute_cross_entropy(self, log_p_hat, p):
+ return -(log_p_hat * p).sum(dim=-1)
+
+ def compute_matchability(self, keypoint_p, has_depth, B, K, H, W, device = "cuda"):
+ smooth_keypoint_p = F.conv2d(keypoint_p.reshape(B,1,H,W), weight = self.smoothing_kernel, padding = (self.smoothing_size//2,0))
+ smooth_keypoint_p = F.conv2d(smooth_keypoint_p, weight = self.smoothing_kernel.mT, padding = (0,self.smoothing_size//2))
+ log_p_hat = (smooth_keypoint_p+1e-8).log().reshape(B,H*W).log_softmax(dim=-1)
+ smooth_has_depth = F.conv2d(has_depth.reshape(B,1,H,W), weight = self.smoothing_kernel, padding = (0,self.smoothing_size//2))
+ smooth_has_depth = F.conv2d(smooth_has_depth, weight = self.smoothing_kernel.mT, padding = (self.smoothing_size//2,0)).reshape(B,H*W)
+ p = smooth_has_depth/smooth_has_depth.sum(dim=-1,keepdim=True)
+ return self.compute_cross_entropy(log_p_hat, p) - self.compute_cross_entropy((p+1e-12).log(), p)
+
+ def supervised_loss(self, outputs, batch):
+ keypoint_logits_A, keypoint_logits_B = outputs["keypoint_logits"].chunk(2)
+ B, K, H, W = keypoint_logits_A.shape
+
+ detections_A, detections_B = batch["detections_A"], batch["detections_B"]
+
+ gt_warp_A_to_B, valid_mask_A_to_B = get_gt_warp(
+ batch["im_A_depth"],
+ batch["im_B_depth"],
+ batch["T_1to2"],
+ batch["K1"],
+ batch["K2"],
+ H=H,
+ W=W,
+ )
+ gt_warp_B_to_A, valid_mask_B_to_A = get_gt_warp(
+ batch["im_B_depth"],
+ batch["im_A_depth"],
+ batch["T_1to2"].inverse(),
+ batch["K2"],
+ batch["K1"],
+ H=H,
+ W=W,
+ )
+ keypoint_logits_A = keypoint_logits_A.reshape(B, K, H*W)
+ keypoint_logits_B = keypoint_logits_B.reshape(B, K, H*W)
+ keypoint_logits = torch.cat((keypoint_logits_A, keypoint_logits_B))
+
+ B = 2*B
+ gt_warp = torch.cat((gt_warp_A_to_B, gt_warp_B_to_A))
+ valid_mask = torch.cat((valid_mask_A_to_B, valid_mask_B_to_A))
+ valid_mask = valid_mask.reshape(B,H*W)
+ binary_mask = valid_mask == 1
+ detections = torch.cat((detections_A, detections_B))
+ legit_detections = ((detections > 0).prod(dim = -1) * (detections[...,0] < W) * (detections[...,1] < H)).bool()
+ det_imgs_A, det_imgs_B = self.compute_detection_img(detections, legit_detections, B, H, W).chunk(2)
+ det_imgs = torch.cat((det_imgs_A, det_imgs_B))
+ det_imgs_backwarped = F.grid_sample(torch.cat((det_imgs_B, det_imgs_A)).reshape(B,1,H,W),
+ gt_warp[...,-2:].reshape(B,H,W,2).float(), align_corners = False, mode = "bicubic")
+
+ keypoint_logits_backwarped = F.grid_sample(torch.cat((keypoint_logits_B, keypoint_logits_A)).reshape(B,K,H,W),
+ gt_warp[...,-2:].reshape(B,H,W,2).float(), align_corners = False, mode = "bicubic")
+
+ keypoint_logits_backwarped = (keypoint_logits_backwarped).reshape(B,K,H*W)
+
+
+ depth = F.interpolate(torch.cat((batch["im_A_depth"][:,None],batch["im_B_depth"][:,None]),dim=0), size = (H,W), mode = "bilinear", align_corners=False)
+ has_depth = (depth > 0).float().reshape(B,H*W)
+
+ joint_log_likelihood_loss = self.compute_joint_neg_log_likelihood(keypoint_logits, keypoint_logits_backwarped,
+ mask = binary_mask, detections_A = det_imgs,
+ detections_B_to_A = det_imgs_backwarped).mean()
+ keypoint_p = keypoint_logits.reshape(B, K*H*W).softmax(dim=-1).reshape(B, K, H*W).sum(dim=1)
+ matchability_loss = self.compute_matchability(keypoint_p, has_depth, B, K, H, W).mean()
+ B = B//2
+ kpts_A = sample_keypoints(keypoint_p[:B].reshape(B,H,W),
+ use_nms = False, sample_topk = True, num_samples = 4*2048)
+ kpts_B = sample_keypoints(keypoint_p[B:].reshape(B,H,W),
+ use_nms = False, sample_topk = True, num_samples = 4*2048)
+ kpts_A_to_B = F.grid_sample(gt_warp_A_to_B[...,2:].float().permute(0,3,1,2), kpts_A[...,None,:],
+ align_corners=False, mode = 'bilinear')[...,0].mT
+ legit_A_to_B = F.grid_sample(valid_mask_A_to_B.reshape(B,1,H,W), kpts_A[...,None,:],
+ align_corners=False, mode = 'bilinear')[...,0,:,0]
+ percent_inliers = (torch.cdist(kpts_A_to_B, kpts_B).min(dim=-1).values[legit_A_to_B > 0] < 0.01).float().mean()
+ self.tracked_metrics["mega_percent_inliers"] = (0.9 * self.tracked_metrics.get("mega_percent_inliers", percent_inliers) + 0.1 * percent_inliers)
+
+ tot_loss = joint_log_likelihood_loss + self.matchability_weight * matchability_loss
+ if torch.rand(1) > 1: # always False, so this debug printout is disabled
+ print(f"Percent Inliers: {self.tracked_metrics.get('mega_percent_inliers', 0)}")
+ print(f"{joint_log_likelihood_loss=} {matchability_loss=}")
+ print(f"Total Loss: {tot_loss.item()}")
+ return tot_loss
+
+ def forward(self, outputs, batch):
+
+ if not isinstance(outputs, list):
+ outputs = [outputs]
+ losses = 0
+ for output in outputs:
+ losses = losses + self.supervised_loss(output, batch)
+ return losses
\ No newline at end of file
diff --git a/imcui/third_party/DeDoDe/DeDoDe/encoder.py b/imcui/third_party/DeDoDe/DeDoDe/encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..91880e7d5e98b02259127b107a459401b99bb157
--- /dev/null
+++ b/imcui/third_party/DeDoDe/DeDoDe/encoder.py
@@ -0,0 +1,87 @@
+import torch
+import torch.nn as nn
+import torchvision.models as tvm
+
+
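+# VGG encoder: feature maps are collected right before every max-pool, yielding
+# features at strides 1, 2, 4 and 8 together with their spatial sizes.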
+class VGG19(nn.Module):
+ def __init__(self, pretrained=False, amp = False, amp_dtype = torch.float16) -> None:
+ super().__init__()
+ self.layers = nn.ModuleList(tvm.vgg19_bn(pretrained=pretrained).features[:40])
+ # Maxpool layers: 6, 13, 26, 39
+ self.amp = amp
+ self.amp_dtype = amp_dtype
+
+ def forward(self, x, **kwargs):
+ with torch.autocast("cuda", enabled=self.amp, dtype = self.amp_dtype):
+ feats = []
+ sizes = []
+ for layer in self.layers:
+ if isinstance(layer, nn.MaxPool2d):
+ feats.append(x)
+ sizes.append(x.shape[-2:])
+ x = layer(x)
+ return feats, sizes
+
+class VGG(nn.Module):
+ def __init__(self, size = "19", pretrained=False, amp = False, amp_dtype = torch.float16) -> None:
+ super().__init__()
+ if size == "11":
+ self.layers = nn.ModuleList(tvm.vgg11_bn(pretrained=pretrained).features[:22])
+ elif size == "13":
+ self.layers = nn.ModuleList(tvm.vgg13_bn(pretrained=pretrained).features[:28])
+ elif size == "19":
+ self.layers = nn.ModuleList(tvm.vgg19_bn(pretrained=pretrained).features[:40])
+ # Maxpool layers: 6, 13, 26, 39
+ self.amp = amp
+ self.amp_dtype = amp_dtype
+
+ def forward(self, x, **kwargs):
+ with torch.autocast("cuda", enabled=self.amp, dtype = self.amp_dtype):
+ feats = []
+ sizes = []
+ for layer in self.layers:
+ if isinstance(layer, nn.MaxPool2d):
+ feats.append(x)
+ sizes.append(x.shape[-2:])
+ x = layer(x)
+ return feats, sizes
+
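+# Frozen DINOv2 ViT-L/14 backbone: weights are downloaded from the official release
+# when not supplied, the module is stored in a plain list so DDP does not register
+# its parameters, and it returns stride-14 patch tokens as an extra feature level.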
+class FrozenDINOv2(nn.Module):
+ def __init__(self, amp = True, amp_dtype = torch.float16, dinov2_weights = None):
+ super().__init__()
+ if dinov2_weights is None:
+ dinov2_weights = torch.hub.load_state_dict_from_url("https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_pretrain.pth", map_location="cpu")
+ from .transformer import vit_large
+ vit_kwargs = dict(img_size= 518,
+ patch_size= 14,
+ init_values = 1.0,
+ ffn_layer = "mlp",
+ block_chunks = 0,
+ )
+ dinov2_vitl14 = vit_large(**vit_kwargs).eval()
+ dinov2_vitl14.load_state_dict(dinov2_weights)
+ self.amp = amp
+ self.amp_dtype = amp_dtype
+ if self.amp:
+ dinov2_vitl14 = dinov2_vitl14.to(self.amp_dtype)
+ self.dinov2_vitl14 = [dinov2_vitl14] # ugly hack to not show parameters to DDP
+ def forward(self, x):
+ B, C, H, W = x.shape
+ if self.dinov2_vitl14[0].device != x.device:
+ self.dinov2_vitl14[0] = self.dinov2_vitl14[0].to(x.device).to(self.amp_dtype)
+ with torch.inference_mode():
+ dinov2_features_16 = self.dinov2_vitl14[0].forward_features(x.to(self.amp_dtype))
+ features_16 = dinov2_features_16['x_norm_patchtokens'].permute(0,2,1).reshape(B,1024,H//14, W//14)
+ return [features_16.clone()], [(H//14, W//14)] # clone from inference mode to use in autograd
+
+class VGG_DINOv2(nn.Module):
+ def __init__(self, vgg_kwargs = None, dinov2_kwargs = None):
+ assert vgg_kwargs is not None and dinov2_kwargs is not None, "Both vgg_kwargs and dinov2_kwargs must be provided"
+ super().__init__()
+ self.vgg = VGG(**vgg_kwargs)
+ self.frozen_dinov2 = FrozenDINOv2(**dinov2_kwargs)
+
+ def forward(self, x):
+ feats_vgg, sizes_vgg = self.vgg(x)
+ feat_dinov2, size_dinov2 = self.frozen_dinov2(x)
+ return feats_vgg + feat_dinov2, sizes_vgg + size_dinov2
diff --git a/imcui/third_party/DeDoDe/DeDoDe/matchers/__init__.py b/imcui/third_party/DeDoDe/DeDoDe/matchers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/imcui/third_party/DeDoDe/DeDoDe/matchers/dual_softmax_matcher.py b/imcui/third_party/DeDoDe/DeDoDe/matchers/dual_softmax_matcher.py
new file mode 100644
index 0000000000000000000000000000000000000000..5cc76cad77ee403d7d5ab729c786982a47fbe6e9
--- /dev/null
+++ b/imcui/third_party/DeDoDe/DeDoDe/matchers/dual_softmax_matcher.py
@@ -0,0 +1,38 @@
+import torch
+from PIL import Image
+import torch.nn as nn
+import torchvision.models as tvm
+import torch.nn.functional as F
+import numpy as np
+from DeDoDe.utils import dual_softmax_matcher, to_pixel_coords, to_normalized_coords
+
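+# Mutual-nearest-neighbour matching on the dual-softmax similarity: a pair is kept
+# if it is both the row-wise and column-wise maximum and its score exceeds the threshold.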
+class DualSoftMaxMatcher(nn.Module):
+ @torch.inference_mode()
+ def match(self, keypoints_A, descriptions_A,
+ keypoints_B, descriptions_B, P_A = None, P_B = None,
+ normalize = False, inv_temp = 1, threshold = 0.0):
+ if isinstance(descriptions_A, list):
+ matches = [self.match(k_A[None], d_A[None], k_B[None], d_B[None], normalize = normalize,
+ inv_temp = inv_temp, threshold = threshold)
+ for k_A,d_A,k_B,d_B in
+ zip(keypoints_A, descriptions_A, keypoints_B, descriptions_B)]
+ matches_A = torch.cat([m[0] for m in matches])
+ matches_B = torch.cat([m[1] for m in matches])
+ inds = torch.cat([m[2] + b for b, m in enumerate(matches)])
+ return matches_A, matches_B, inds
+
+ P = dual_softmax_matcher(descriptions_A, descriptions_B,
+ normalize = normalize, inv_temperature=inv_temp,
+ )
+ inds = torch.nonzero((P == P.max(dim=-1, keepdim = True).values)
+ * (P == P.max(dim=-2, keepdim = True).values) * (P > threshold))
+ batch_inds = inds[:,0]
+ matches_A = keypoints_A[batch_inds, inds[:,1]]
+ matches_B = keypoints_B[batch_inds, inds[:,2]]
+ return matches_A, matches_B, batch_inds
+
+ def to_pixel_coords(self, x_A, x_B, H_A, W_A, H_B, W_B):
+ return to_pixel_coords(x_A, H_A, W_A), to_pixel_coords(x_B, H_B, W_B)
+
+ def to_normalized_coords(self, x_A, x_B, H_A, W_A, H_B, W_B):
+ return to_normalized_coords(x_A, H_A, W_A), to_normalized_coords(x_B, H_B, W_B)
\ No newline at end of file
diff --git a/imcui/third_party/DeDoDe/DeDoDe/model_zoo/__init__.py b/imcui/third_party/DeDoDe/DeDoDe/model_zoo/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..0775d438f94b6095d094e119f788368170694c4c
--- /dev/null
+++ b/imcui/third_party/DeDoDe/DeDoDe/model_zoo/__init__.py
@@ -0,0 +1,3 @@
+from .dedode_models import dedode_detector_B, dedode_detector_L, dedode_descriptor_B, dedode_descriptor_G
+
+
\ No newline at end of file
diff --git a/imcui/third_party/DeDoDe/DeDoDe/model_zoo/dedode_models.py b/imcui/third_party/DeDoDe/DeDoDe/model_zoo/dedode_models.py
new file mode 100644
index 0000000000000000000000000000000000000000..deac312b81691024c2124ebd825f374f9e8c9db1
--- /dev/null
+++ b/imcui/third_party/DeDoDe/DeDoDe/model_zoo/dedode_models.py
@@ -0,0 +1,249 @@
+import torch
+import torch.nn as nn
+
+from DeDoDe.detectors.dedode_detector import DeDoDeDetector
+from DeDoDe.descriptors.dedode_descriptor import DeDoDeDescriptor
+from DeDoDe.decoder import ConvRefiner, Decoder
+from DeDoDe.encoder import VGG19, VGG, VGG_DINOv2
+from DeDoDe.utils import get_best_device
+
+
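+# Model-zoo builders: each assembles a VGG (or VGG+DINOv2) encoder and a per-scale
+# ConvRefiner decoder into a detector or descriptor; all but dedode_detector_B
+# download a released checkpoint when no weights are passed.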
+def dedode_detector_B(device = get_best_device(), weights = None):
+ residual = True
+ hidden_blocks = 5
+ amp_dtype = torch.float16
+ amp = True
+ NUM_PROTOTYPES = 1
+ conv_refiner = nn.ModuleDict(
+ {
+ "8": ConvRefiner(
+ 512,
+ 512,
+ 256 + NUM_PROTOTYPES,
+ hidden_blocks = hidden_blocks,
+ residual = residual,
+ amp = amp,
+ amp_dtype = amp_dtype,
+ ),
+ "4": ConvRefiner(
+ 256+256,
+ 256,
+ 128 + NUM_PROTOTYPES,
+ hidden_blocks = hidden_blocks,
+ residual = residual,
+ amp = amp,
+ amp_dtype = amp_dtype,
+
+ ),
+ "2": ConvRefiner(
+ 128+128,
+ 64,
+ 32 + NUM_PROTOTYPES,
+ hidden_blocks = hidden_blocks,
+ residual = residual,
+ amp = amp,
+ amp_dtype = amp_dtype,
+
+ ),
+ "1": ConvRefiner(
+ 64 + 32,
+ 32,
+ 1 + NUM_PROTOTYPES,
+ hidden_blocks = hidden_blocks,
+ residual = residual,
+ amp = amp,
+ amp_dtype = amp_dtype,
+ ),
+ }
+ )
+ encoder = VGG19(pretrained = False, amp = amp, amp_dtype = amp_dtype)
+ decoder = Decoder(conv_refiner)
+ model = DeDoDeDetector(encoder = encoder, decoder = decoder).to(device)
+ if weights is not None:
+ model.load_state_dict(weights)
+ return model
+
+
+def dedode_detector_L(device = get_best_device(), weights = None, remove_borders = False):
+ if weights is None:
+ weights = torch.hub.load_state_dict_from_url("https://github.com/Parskatt/DeDoDe/releases/download/v2/dedode_detector_L_v2.pth", map_location = device)
+ NUM_PROTOTYPES = 1
+ residual = True
+ hidden_blocks = 8
+ amp_dtype = torch.float16 # alternatively: torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
+ amp = True
+ conv_refiner = nn.ModuleDict(
+ {
+ "8": ConvRefiner(
+ 512,
+ 512,
+ 256 + NUM_PROTOTYPES,
+ hidden_blocks = hidden_blocks,
+ residual = residual,
+ amp = amp,
+ amp_dtype = amp_dtype,
+ ),
+ "4": ConvRefiner(
+ 256+256,
+ 256,
+ 128 + NUM_PROTOTYPES,
+ hidden_blocks = hidden_blocks,
+ residual = residual,
+ amp = amp,
+ amp_dtype = amp_dtype,
+
+ ),
+ "2": ConvRefiner(
+ 128+128,
+ 128,
+ 64 + NUM_PROTOTYPES,
+ hidden_blocks = hidden_blocks,
+ residual = residual,
+ amp = amp,
+ amp_dtype = amp_dtype,
+
+ ),
+ "1": ConvRefiner(
+ 64 + 64,
+ 64,
+ 1 + NUM_PROTOTYPES,
+ hidden_blocks = hidden_blocks,
+ residual = residual,
+ amp = amp,
+ amp_dtype = amp_dtype,
+ ),
+ }
+ )
+ encoder = VGG19(pretrained = False, amp = amp, amp_dtype = amp_dtype)
+ decoder = Decoder(conv_refiner)
+ model = DeDoDeDetector(encoder = encoder, decoder = decoder, remove_borders = remove_borders).to(device)
+ if weights is not None:
+ model.load_state_dict(weights)
+ return model
+
+
+
+def dedode_descriptor_B(device = get_best_device(), weights = None):
+ if weights is None:
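+ # NOTE: this default URL appears to point at the detector_L checkpoint; the
+ # descriptor_B weights are presumably a separate asset on the release page.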
+ weights = torch.hub.load_state_dict_from_url("https://github.com/Parskatt/DeDoDe/releases/download/dedode_pretrained_models/dedode_detector_L.pth", map_location=device)
+ NUM_PROTOTYPES = 256 # == descriptor size
+ residual = True
+ hidden_blocks = 5
+ amp_dtype = torch.float16 # alternatively: torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
+ amp = True
+ conv_refiner = nn.ModuleDict(
+ {
+ "8": ConvRefiner(
+ 512,
+ 512,
+ 256 + NUM_PROTOTYPES,
+ hidden_blocks = hidden_blocks,
+ residual = residual,
+ amp = amp,
+ amp_dtype = amp_dtype,
+ ),
+ "4": ConvRefiner(
+ 256+256,
+ 256,
+ 128 + NUM_PROTOTYPES,
+ hidden_blocks = hidden_blocks,
+ residual = residual,
+ amp = amp,
+ amp_dtype = amp_dtype,
+
+ ),
+ "2": ConvRefiner(
+ 128+128,
+ 64,
+ 32 + NUM_PROTOTYPES,
+ hidden_blocks = hidden_blocks,
+ residual = residual,
+ amp = amp,
+ amp_dtype = amp_dtype,
+
+ ),
+ "1": ConvRefiner(
+ 64 + 32,
+ 32,
+ 1 + NUM_PROTOTYPES,
+ hidden_blocks = hidden_blocks,
+ residual = residual,
+ amp = amp,
+ amp_dtype = amp_dtype,
+ ),
+ }
+ )
+ encoder = VGG(size = "19", pretrained = False, amp = amp, amp_dtype = amp_dtype)
+ decoder = Decoder(conv_refiner, num_prototypes=NUM_PROTOTYPES)
+ model = DeDoDeDescriptor(encoder = encoder, decoder = decoder).to(device)
+ if weights is not None:
+ model.load_state_dict(weights)
+ return model
+
+def dedode_descriptor_G(device = get_best_device(), weights = None, dinov2_weights = None):
+ if weights is None:
+ weights = torch.hub.load_state_dict_from_url("https://github.com/Parskatt/DeDoDe/releases/download/dedode_pretrained_models/dedode_descriptor_G.pth", map_location=device)
+ NUM_PROTOTYPES = 256 # == descriptor size
+ residual = True
+ hidden_blocks = 5
+ amp_dtype = torch.float16 # alternatively: torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
+ amp = True
+ conv_refiner = nn.ModuleDict(
+ {
+ "14": ConvRefiner(
+ 1024,
+ 768,
+ 512 + NUM_PROTOTYPES,
+ hidden_blocks = hidden_blocks,
+ residual = residual,
+ amp = amp,
+ amp_dtype = amp_dtype,
+ ),
+ "8": ConvRefiner(
+ 512 + 512,
+ 512,
+ 256 + NUM_PROTOTYPES,
+ hidden_blocks = hidden_blocks,
+ residual = residual,
+ amp = amp,
+ amp_dtype = amp_dtype,
+ ),
+ "4": ConvRefiner(
+ 256+256,
+ 256,
+ 128 + NUM_PROTOTYPES,
+ hidden_blocks = hidden_blocks,
+ residual = residual,
+ amp = amp,
+ amp_dtype = amp_dtype,
+
+ ),
+ "2": ConvRefiner(
+ 128+128,
+ 64,
+ 32 + NUM_PROTOTYPES,
+ hidden_blocks = hidden_blocks,
+ residual = residual,
+ amp = amp,
+ amp_dtype = amp_dtype,
+
+ ),
+ "1": ConvRefiner(
+ 64 + 32,
+ 32,
+ 1 + NUM_PROTOTYPES,
+ hidden_blocks = hidden_blocks,
+ residual = residual,
+ amp = amp,
+ amp_dtype = amp_dtype,
+ ),
+ }
+ )
+ vgg_kwargs = dict(size = "19", pretrained = False, amp = amp, amp_dtype = amp_dtype)
+ dinov2_kwargs = dict(amp = amp, amp_dtype = amp_dtype, dinov2_weights = dinov2_weights)
+ encoder = VGG_DINOv2(vgg_kwargs = vgg_kwargs, dinov2_kwargs = dinov2_kwargs)
+ decoder = Decoder(conv_refiner, num_prototypes=NUM_PROTOTYPES)
+ model = DeDoDeDescriptor(encoder = encoder, decoder = decoder).to(device)
+ if weights is not None:
+ model.load_state_dict(weights)
+ return model
\ No newline at end of file
diff --git a/imcui/third_party/DeDoDe/DeDoDe/train.py b/imcui/third_party/DeDoDe/DeDoDe/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..342cdd636c8d5ae0b693bf6220ba088bdbc2035c
--- /dev/null
+++ b/imcui/third_party/DeDoDe/DeDoDe/train.py
@@ -0,0 +1,76 @@
+import torch
+from tqdm import tqdm
+from DeDoDe.utils import to_cuda, to_best_device
+
+
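+# One optimization step: with a GradScaler the loss is scaled, gradients are
+# unscaled and clipped to a max norm of 0.01, and the scaled step is taken;
+# otherwise a plain backward/step is used.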
+def train_step(train_batch, model, objective, optimizer, grad_scaler = None,**kwargs):
+ optimizer.zero_grad()
+ out = model(train_batch)
+ l = objective(out, train_batch)
+ if grad_scaler is not None:
+ grad_scaler.scale(l).backward()
+ grad_scaler.unscale_(optimizer)
+ torch.nn.utils.clip_grad_norm_(model.parameters(), 0.01)
+ grad_scaler.step(optimizer)
+ grad_scaler.update()
+ else:
+ l.backward()
+ optimizer.step()
+ return {"train_out": out, "train_loss": l.item()}
+
+
+def train_k_steps(
+ n_0, k, dataloader, model, objective, optimizer, lr_scheduler, grad_scaler = None, progress_bar=True
+):
+ for n in tqdm(range(n_0, n_0 + k), disable=not progress_bar, mininterval = 10.):
+ batch = next(dataloader)
+ model.train(True)
+ batch = to_best_device(batch)
+ train_step(
+ train_batch=batch,
+ model=model,
+ objective=objective,
+ optimizer=optimizer,
+ lr_scheduler=lr_scheduler,
+ n=n,
+ grad_scaler = grad_scaler,
+ )
+ lr_scheduler.step()
+
+
+def train_epoch(
+ dataloader=None,
+ model=None,
+ objective=None,
+ optimizer=None,
+ lr_scheduler=None,
+ epoch=None,
+):
+ model.train(True)
+ print(f"At epoch {epoch}")
+ for batch in tqdm(dataloader, mininterval=5.0):
+ batch = to_best_device(batch)
+ train_step(
+ train_batch=batch, model=model, objective=objective, optimizer=optimizer
+ )
+ lr_scheduler.step()
+ return {
+ "model": model,
+ "optimizer": optimizer,
+ "lr_scheduler": lr_scheduler,
+ "epoch": epoch,
+ }
+
+
+def train_k_epochs(
+ start_epoch, end_epoch, dataloader, model, objective, optimizer, lr_scheduler
+):
+ for epoch in range(start_epoch, end_epoch + 1):
+ train_epoch(
+ dataloader=dataloader,
+ model=model,
+ objective=objective,
+ optimizer=optimizer,
+ lr_scheduler=lr_scheduler,
+ epoch=epoch,
+ )
diff --git a/imcui/third_party/DeDoDe/DeDoDe/transformer/__init__.py b/imcui/third_party/DeDoDe/DeDoDe/transformer/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..031d52e998bc18f6d5264fb8b791a6339cf793b5
--- /dev/null
+++ b/imcui/third_party/DeDoDe/DeDoDe/transformer/__init__.py
@@ -0,0 +1,8 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from DeDoDe.utils import get_grid
+from .layers.block import Block
+from .layers.attention import MemEffAttention
+from .dinov2 import vit_large
\ No newline at end of file
diff --git a/imcui/third_party/DeDoDe/DeDoDe/transformer/dinov2.py b/imcui/third_party/DeDoDe/DeDoDe/transformer/dinov2.py
new file mode 100644
index 0000000000000000000000000000000000000000..b556c63096d17239c8603d5fe626c331963099fd
--- /dev/null
+++ b/imcui/third_party/DeDoDe/DeDoDe/transformer/dinov2.py
@@ -0,0 +1,359 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
+
+from functools import partial
+import math
+import logging
+from typing import Sequence, Tuple, Union, Callable
+
+import torch
+import torch.nn as nn
+import torch.utils.checkpoint
+from torch.nn.init import trunc_normal_
+
+from .layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block
+
+
+
+def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module:
+ if not depth_first and include_root:
+ fn(module=module, name=name)
+ for child_name, child_module in module.named_children():
+ child_name = ".".join((name, child_name)) if name else child_name
+ named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True)
+ if depth_first and include_root:
+ fn(module=module, name=name)
+ return module
+
+
+class BlockChunk(nn.ModuleList):
+ def forward(self, x):
+ for b in self:
+ x = b(x)
+ return x
+
+
+class DinoVisionTransformer(nn.Module):
+ def __init__(
+ self,
+ img_size=224,
+ patch_size=16,
+ in_chans=3,
+ embed_dim=768,
+ depth=12,
+ num_heads=12,
+ mlp_ratio=4.0,
+ qkv_bias=True,
+ ffn_bias=True,
+ proj_bias=True,
+ drop_path_rate=0.0,
+ drop_path_uniform=False,
+ init_values=None, # for layerscale: None or 0 => no layerscale
+ embed_layer=PatchEmbed,
+ act_layer=nn.GELU,
+ block_fn=Block,
+ ffn_layer="mlp",
+ block_chunks=1,
+ ):
+ """
+ Args:
+ img_size (int, tuple): input image size
+ patch_size (int, tuple): patch size
+ in_chans (int): number of input channels
+ embed_dim (int): embedding dimension
+ depth (int): depth of transformer
+ num_heads (int): number of attention heads
+ mlp_ratio (int): ratio of mlp hidden dim to embedding dim
+ qkv_bias (bool): enable bias for qkv if True
+ proj_bias (bool): enable bias for proj in attn if True
+ ffn_bias (bool): enable bias for ffn if True
+ drop_path_rate (float): stochastic depth rate
+ drop_path_uniform (bool): apply uniform drop rate across blocks
+ weight_init (str): weight init scheme
+ init_values (float): layer-scale init values
+ embed_layer (nn.Module): patch embedding layer
+ act_layer (nn.Module): MLP activation layer
+ block_fn (nn.Module): transformer block class
+ ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
+ block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
+ """
+ super().__init__()
+ norm_layer = partial(nn.LayerNorm, eps=1e-6)
+
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
+ self.num_tokens = 1
+ self.n_blocks = depth
+ self.num_heads = num_heads
+ self.patch_size = patch_size
+
+ self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
+ num_patches = self.patch_embed.num_patches
+
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
+
+ if drop_path_uniform is True:
+ dpr = [drop_path_rate] * depth
+ else:
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
+
+ if ffn_layer == "mlp":
+ ffn_layer = Mlp
+ elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
+ ffn_layer = SwiGLUFFNFused
+ elif ffn_layer == "identity":
+
+ def f(*args, **kwargs):
+ return nn.Identity()
+
+ ffn_layer = f
+ else:
+ raise NotImplementedError
+
+ blocks_list = [
+ block_fn(
+ dim=embed_dim,
+ num_heads=num_heads,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ proj_bias=proj_bias,
+ ffn_bias=ffn_bias,
+ drop_path=dpr[i],
+ norm_layer=norm_layer,
+ act_layer=act_layer,
+ ffn_layer=ffn_layer,
+ init_values=init_values,
+ )
+ for i in range(depth)
+ ]
+ if block_chunks > 0:
+ self.chunked_blocks = True
+ chunked_blocks = []
+ chunksize = depth // block_chunks
+ for i in range(0, depth, chunksize):
+ # this is to keep the block index consistent if we chunk the block list
+ chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize])
+ self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
+ else:
+ self.chunked_blocks = False
+ self.blocks = nn.ModuleList(blocks_list)
+
+ self.norm = norm_layer(embed_dim)
+ self.head = nn.Identity()
+
+ self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))
+
+ self.init_weights()
+ for param in self.parameters():
+ param.requires_grad = False
+
+ @property
+ def device(self):
+ return self.cls_token.device
+
+ def init_weights(self):
+ trunc_normal_(self.pos_embed, std=0.02)
+ nn.init.normal_(self.cls_token, std=1e-6)
+ named_apply(init_weights_vit_timm, self)
+
+ def interpolate_pos_encoding(self, x, w, h):
+ previous_dtype = x.dtype
+ npatch = x.shape[1] - 1
+ N = self.pos_embed.shape[1] - 1
+ if npatch == N and w == h:
+ return self.pos_embed
+ pos_embed = self.pos_embed.float()
+ class_pos_embed = pos_embed[:, 0]
+ patch_pos_embed = pos_embed[:, 1:]
+ dim = x.shape[-1]
+ w0 = w // self.patch_size
+ h0 = h // self.patch_size
+ # we add a small number to avoid floating point error in the interpolation
+ # see discussion at https://github.com/facebookresearch/dino/issues/8
+ w0, h0 = w0 + 0.1, h0 + 0.1
+
+ patch_pos_embed = nn.functional.interpolate(
+ patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
+ scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
+ mode="bicubic",
+ )
+
+ assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
+ return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)
+
+ def prepare_tokens_with_masks(self, x, masks=None):
+ B, nc, w, h = x.shape
+ x = self.patch_embed(x)
+ if masks is not None:
+ x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)
+
+ x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
+ x = x + self.interpolate_pos_encoding(x, w, h)
+
+ return x
+
+ def forward_features_list(self, x_list, masks_list):
+ x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)]
+ for blk in self.blocks:
+ x = blk(x)
+
+ all_x = x
+ output = []
+ for x, masks in zip(all_x, masks_list):
+ x_norm = self.norm(x)
+ output.append(
+ {
+ "x_norm_clstoken": x_norm[:, 0],
+ "x_norm_patchtokens": x_norm[:, 1:],
+ "x_prenorm": x,
+ "masks": masks,
+ }
+ )
+ return output
+
+ def forward_features(self, x, masks=None):
+ if isinstance(x, list):
+ return self.forward_features_list(x, masks)
+
+ x = self.prepare_tokens_with_masks(x, masks)
+
+ for blk in self.blocks:
+ x = blk(x)
+
+ x_norm = self.norm(x)
+ return {
+ "x_norm_clstoken": x_norm[:, 0],
+ "x_norm_patchtokens": x_norm[:, 1:],
+ "x_prenorm": x,
+ "masks": masks,
+ }
+
+ def _get_intermediate_layers_not_chunked(self, x, n=1):
+ x = self.prepare_tokens_with_masks(x)
+ # If n is an int, take the n last blocks. If it's a list, take them
+ output, total_block_len = [], len(self.blocks)
+ blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
+ for i, blk in enumerate(self.blocks):
+ x = blk(x)
+ if i in blocks_to_take:
+ output.append(x)
+ assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
+ return output
+
+ def _get_intermediate_layers_chunked(self, x, n=1):
+ x = self.prepare_tokens_with_masks(x)
+ output, i, total_block_len = [], 0, len(self.blocks[-1])
+ # If n is an int, take the n last blocks. If it's a list, take them
+ blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
+ for block_chunk in self.blocks:
+ for blk in block_chunk[i:]: # Passing the nn.Identity()
+ x = blk(x)
+ if i in blocks_to_take:
+ output.append(x)
+ i += 1
+ assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
+ return output
+
+ def get_intermediate_layers(
+ self,
+ x: torch.Tensor,
+ n: Union[int, Sequence] = 1, # Layers or n last layers to take
+ reshape: bool = False,
+ return_class_token: bool = False,
+ norm=True,
+ ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
+ if self.chunked_blocks:
+ outputs = self._get_intermediate_layers_chunked(x, n)
+ else:
+ outputs = self._get_intermediate_layers_not_chunked(x, n)
+ if norm:
+ outputs = [self.norm(out) for out in outputs]
+ class_tokens = [out[:, 0] for out in outputs]
+ outputs = [out[:, 1:] for out in outputs]
+ if reshape:
+ B, _, w, h = x.shape
+ outputs = [
+ out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous()
+ for out in outputs
+ ]
+ if return_class_token:
+ return tuple(zip(outputs, class_tokens))
+ return tuple(outputs)
+
+ def forward(self, *args, is_training=False, **kwargs):
+ ret = self.forward_features(*args, **kwargs)
+ if is_training:
+ return ret
+ else:
+ return self.head(ret["x_norm_clstoken"])
+
+
+def init_weights_vit_timm(module: nn.Module, name: str = ""):
+ """ViT weight initialization, original timm impl (for reproducibility)"""
+ if isinstance(module, nn.Linear):
+ trunc_normal_(module.weight, std=0.02)
+ if module.bias is not None:
+ nn.init.zeros_(module.bias)
+
+
+def vit_small(patch_size=16, **kwargs):
+ model = DinoVisionTransformer(
+ patch_size=patch_size,
+ embed_dim=384,
+ depth=12,
+ num_heads=6,
+ mlp_ratio=4,
+ block_fn=partial(Block, attn_class=MemEffAttention),
+ **kwargs,
+ )
+ return model
+
+
+def vit_base(patch_size=16, **kwargs):
+ model = DinoVisionTransformer(
+ patch_size=patch_size,
+ embed_dim=768,
+ depth=12,
+ num_heads=12,
+ mlp_ratio=4,
+ block_fn=partial(Block, attn_class=MemEffAttention),
+ **kwargs,
+ )
+ return model
+
+
+def vit_large(patch_size=16, **kwargs):
+ model = DinoVisionTransformer(
+ patch_size=patch_size,
+ embed_dim=1024,
+ depth=24,
+ num_heads=16,
+ mlp_ratio=4,
+ block_fn=partial(Block, attn_class=MemEffAttention),
+ **kwargs,
+ )
+ return model
+
+
+def vit_giant2(patch_size=16, **kwargs):
+ """
+ Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64
+ """
+ model = DinoVisionTransformer(
+ patch_size=patch_size,
+ embed_dim=1536,
+ depth=40,
+ num_heads=24,
+ mlp_ratio=4,
+ block_fn=partial(Block, attn_class=MemEffAttention),
+ **kwargs,
+ )
+ return model
\ No newline at end of file
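
The backbone above is frozen at construction (the requires_grad = False loop in __init__), so it only acts as a fixed feature extractor. A shape-level sketch of querying it; the weights here are random, the released DINOv2 checkpoints are loaded elsewhere in DeDoDe, and patch_size=14 is an assumption matching ViT-L/14:

    import torch
    from DeDoDe.transformer.dinov2 import vit_large

    vit = vit_large(patch_size=14, block_chunks=0)  # parameters are frozen in __init__
    img = torch.randn(1, 3, 14 * 32, 14 * 32)       # H and W must be multiples of the patch size
    feats = vit.get_intermediate_layers(img, n=1, reshape=True)[0]
    print(feats.shape)  # torch.Size([1, 1024, 32, 32]): per-patch features from the last block
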
diff --git a/imcui/third_party/DeDoDe/DeDoDe/transformer/layers/__init__.py b/imcui/third_party/DeDoDe/DeDoDe/transformer/layers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..31f196aacac5be8a7c537a3dfa8f97084671b466
--- /dev/null
+++ b/imcui/third_party/DeDoDe/DeDoDe/transformer/layers/__init__.py
@@ -0,0 +1,12 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from .dino_head import DINOHead
+from .mlp import Mlp
+from .patch_embed import PatchEmbed
+from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused
+from .block import NestedTensorBlock
+from .attention import MemEffAttention
diff --git a/imcui/third_party/DeDoDe/DeDoDe/transformer/layers/attention.py b/imcui/third_party/DeDoDe/DeDoDe/transformer/layers/attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..1f9b0c94b40967dfdff4f261c127cbd21328c905
--- /dev/null
+++ b/imcui/third_party/DeDoDe/DeDoDe/transformer/layers/attention.py
@@ -0,0 +1,81 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
+
+import logging
+
+from torch import Tensor
+from torch import nn
+
+
+logger = logging.getLogger("dinov2")
+
+
+try:
+ from xformers.ops import memory_efficient_attention, unbind, fmha
+
+ XFORMERS_AVAILABLE = True
+except ImportError:
+ logger.warning("xFormers not available")
+ XFORMERS_AVAILABLE = False
+
+
+class Attention(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int = 8,
+ qkv_bias: bool = False,
+ proj_bias: bool = True,
+ attn_drop: float = 0.0,
+ proj_drop: float = 0.0,
+ ) -> None:
+ super().__init__()
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.scale = head_dim**-0.5
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim, bias=proj_bias)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ def forward(self, x: Tensor) -> Tensor:
+ B, N, C = x.shape
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+
+ q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
+ attn = q @ k.transpose(-2, -1)
+
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+class MemEffAttention(Attention):
+ def forward(self, x: Tensor, attn_bias=None) -> Tensor:
+ if not XFORMERS_AVAILABLE:
+ assert attn_bias is None, "xFormers is required for nested tensors usage"
+ return super().forward(x)
+
+ B, N, C = x.shape
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
+
+ q, k, v = unbind(qkv, 2)
+
+ x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
+ x = x.reshape([B, N, C])
+
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
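
MemEffAttention uses the fused xFormers kernel when available and otherwise falls back to the plain Attention forward (in which case attn_bias must be None). A quick shape-level sanity check with illustrative sizes:

    import torch
    from DeDoDe.transformer.layers.attention import MemEffAttention

    attn = MemEffAttention(dim=64, num_heads=8)
    tokens = torch.randn(2, 197, 64)  # (batch, tokens, channels), e.g. 196 patches + 1 cls token
    out = attn(tokens)
    assert out.shape == tokens.shape  # attention is shape-preserving
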
diff --git a/imcui/third_party/DeDoDe/DeDoDe/transformer/layers/block.py b/imcui/third_party/DeDoDe/DeDoDe/transformer/layers/block.py
new file mode 100644
index 0000000000000000000000000000000000000000..25488f57cc0ad3c692f86b62555f6668e2a66db1
--- /dev/null
+++ b/imcui/third_party/DeDoDe/DeDoDe/transformer/layers/block.py
@@ -0,0 +1,252 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
+
+import logging
+from typing import Callable, List, Any, Tuple, Dict
+
+import torch
+from torch import nn, Tensor
+
+from .attention import Attention, MemEffAttention
+from .drop_path import DropPath
+from .layer_scale import LayerScale
+from .mlp import Mlp
+
+
+logger = logging.getLogger("dinov2")
+
+
+try:
+ from xformers.ops import fmha
+ from xformers.ops import scaled_index_add, index_select_cat
+
+ XFORMERS_AVAILABLE = True
+except ImportError:
+ logger.warning("xFormers not available")
+ XFORMERS_AVAILABLE = False
+
+
+class Block(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int,
+ mlp_ratio: float = 4.0,
+ qkv_bias: bool = False,
+ proj_bias: bool = True,
+ ffn_bias: bool = True,
+ drop: float = 0.0,
+ attn_drop: float = 0.0,
+ init_values=None,
+ drop_path: float = 0.0,
+ act_layer: Callable[..., nn.Module] = nn.GELU,
+ norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
+ attn_class: Callable[..., nn.Module] = Attention,
+ ffn_layer: Callable[..., nn.Module] = Mlp,
+ ) -> None:
+ super().__init__()
+ # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}")
+ self.norm1 = norm_layer(dim)
+ self.attn = attn_class(
+ dim,
+ num_heads=num_heads,
+ qkv_bias=qkv_bias,
+ proj_bias=proj_bias,
+ attn_drop=attn_drop,
+ proj_drop=drop,
+ )
+ self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
+ self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = ffn_layer(
+ in_features=dim,
+ hidden_features=mlp_hidden_dim,
+ act_layer=act_layer,
+ drop=drop,
+ bias=ffn_bias,
+ )
+ self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
+ self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+ self.sample_drop_ratio = drop_path
+
+ def forward(self, x: Tensor) -> Tensor:
+ def attn_residual_func(x: Tensor) -> Tensor:
+ return self.ls1(self.attn(self.norm1(x)))
+
+ def ffn_residual_func(x: Tensor) -> Tensor:
+ return self.ls2(self.mlp(self.norm2(x)))
+
+ if self.training and self.sample_drop_ratio > 0.1:
+ # the overhead is compensated only for a drop path rate larger than 0.1
+ x = drop_add_residual_stochastic_depth(
+ x,
+ residual_func=attn_residual_func,
+ sample_drop_ratio=self.sample_drop_ratio,
+ )
+ x = drop_add_residual_stochastic_depth(
+ x,
+ residual_func=ffn_residual_func,
+ sample_drop_ratio=self.sample_drop_ratio,
+ )
+ elif self.training and self.sample_drop_ratio > 0.0:
+ x = x + self.drop_path1(attn_residual_func(x))
+ x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2
+ else:
+ x = x + attn_residual_func(x)
+ x = x + ffn_residual_func(x)
+ return x
+
+
+def drop_add_residual_stochastic_depth(
+ x: Tensor,
+ residual_func: Callable[[Tensor], Tensor],
+ sample_drop_ratio: float = 0.0,
+) -> Tensor:
+ # 1) extract subset using permutation
+ b, n, d = x.shape
+ sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
+ brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
+ x_subset = x[brange]
+
+ # 2) apply residual_func to get residual
+ residual = residual_func(x_subset)
+
+ x_flat = x.flatten(1)
+ residual = residual.flatten(1)
+
+ residual_scale_factor = b / sample_subset_size
+
+ # 3) add the residual
+ x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
+ return x_plus_residual.view_as(x)
+
+
+def get_branges_scales(x, sample_drop_ratio=0.0):
+ b, n, d = x.shape
+ sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
+ brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
+ residual_scale_factor = b / sample_subset_size
+ return brange, residual_scale_factor
+
+
+def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
+ if scaling_vector is None:
+ x_flat = x.flatten(1)
+ residual = residual.flatten(1)
+ x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
+ else:
+ x_plus_residual = scaled_index_add(
+ x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor
+ )
+ return x_plus_residual
+
+
+attn_bias_cache: Dict[Tuple, Any] = {}
+
+
+def get_attn_bias_and_cat(x_list, branges=None):
+ """
+ this will perform the index select, cat the tensors, and provide the attn_bias from cache
+ """
+ batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list]
+ all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
+ if all_shapes not in attn_bias_cache.keys():
+ seqlens = []
+ for b, x in zip(batch_sizes, x_list):
+ for _ in range(b):
+ seqlens.append(x.shape[1])
+ attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
+ attn_bias._batch_sizes = batch_sizes
+ attn_bias_cache[all_shapes] = attn_bias
+
+ if branges is not None:
+ cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1])
+ else:
+ tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
+ cat_tensors = torch.cat(tensors_bs1, dim=1)
+
+ return attn_bias_cache[all_shapes], cat_tensors
+
+
+def drop_add_residual_stochastic_depth_list(
+ x_list: List[Tensor],
+ residual_func: Callable[[Tensor, Any], Tensor],
+ sample_drop_ratio: float = 0.0,
+ scaling_vector=None,
+) -> Tensor:
+ # 1) generate random set of indices for dropping samples in the batch
+ branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]
+ branges = [s[0] for s in branges_scales]
+ residual_scale_factors = [s[1] for s in branges_scales]
+
+ # 2) get attention bias and index+concat the tensors
+ attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)
+
+ # 3) apply residual_func to get residual, and split the result
+ residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore
+
+ outputs = []
+ for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors):
+ outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x))
+ return outputs
+
+
+class NestedTensorBlock(Block):
+ def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
+ """
+ x_list contains a list of tensors to nest together and run
+ """
+ assert isinstance(self.attn, MemEffAttention)
+
+ if self.training and self.sample_drop_ratio > 0.0:
+
+ def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
+ return self.attn(self.norm1(x), attn_bias=attn_bias)
+
+ def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
+ return self.mlp(self.norm2(x))
+
+ x_list = drop_add_residual_stochastic_depth_list(
+ x_list,
+ residual_func=attn_residual_func,
+ sample_drop_ratio=self.sample_drop_ratio,
+ scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None,
+ )
+ x_list = drop_add_residual_stochastic_depth_list(
+ x_list,
+ residual_func=ffn_residual_func,
+ sample_drop_ratio=self.sample_drop_ratio,
+                scaling_vector=self.ls2.gamma if isinstance(self.ls2, LayerScale) else None,
+ )
+ return x_list
+ else:
+
+ def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
+ return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))
+
+ def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
+ return self.ls2(self.mlp(self.norm2(x)))
+
+ attn_bias, x = get_attn_bias_and_cat(x_list)
+ x = x + attn_residual_func(x, attn_bias=attn_bias)
+ x = x + ffn_residual_func(x)
+ return attn_bias.split(x)
+
+ def forward(self, x_or_x_list):
+ if isinstance(x_or_x_list, Tensor):
+ return super().forward(x_or_x_list)
+ elif isinstance(x_or_x_list, list):
+ assert XFORMERS_AVAILABLE, "Please install xFormers for nested tensors usage"
+ return self.forward_nested(x_or_x_list)
+ else:
+ raise AssertionError
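
NestedTensorBlock accepts either a single tensor (plain residual block) or a list of tensors with different sequence lengths; the list path concatenates everything into one sequence under a block-diagonal attention bias and splits it back afterwards, which requires xFormers and a CUDA device. A sketch under those assumptions:

    import torch
    from DeDoDe.transformer.layers.attention import MemEffAttention
    from DeDoDe.transformer.layers.block import NestedTensorBlock

    block = NestedTensorBlock(dim=64, num_heads=8, attn_class=MemEffAttention).cuda()
    x_short = torch.randn(2, 100, 64, device="cuda")
    x_long = torch.randn(2, 400, 64, device="cuda")
    out_short, out_long = block([x_short, x_long])  # processed jointly, split per input afterwards
    assert out_short.shape == x_short.shape and out_long.shape == x_long.shape
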
diff --git a/imcui/third_party/DeDoDe/DeDoDe/transformer/layers/dino_head.py b/imcui/third_party/DeDoDe/DeDoDe/transformer/layers/dino_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..7212db92a4fd8d4c7230e284e551a0234e9d8623
--- /dev/null
+++ b/imcui/third_party/DeDoDe/DeDoDe/transformer/layers/dino_head.py
@@ -0,0 +1,59 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn as nn
+from torch.nn.init import trunc_normal_
+from torch.nn.utils import weight_norm
+
+
+class DINOHead(nn.Module):
+ def __init__(
+ self,
+ in_dim,
+ out_dim,
+ use_bn=False,
+ nlayers=3,
+ hidden_dim=2048,
+ bottleneck_dim=256,
+ mlp_bias=True,
+ ):
+ super().__init__()
+ nlayers = max(nlayers, 1)
+ self.mlp = _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=hidden_dim, use_bn=use_bn, bias=mlp_bias)
+ self.apply(self._init_weights)
+ self.last_layer = weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
+ self.last_layer.weight_g.data.fill_(1)
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=0.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+
+ def forward(self, x):
+ x = self.mlp(x)
+ eps = 1e-6 if x.dtype == torch.float16 else 1e-12
+ x = nn.functional.normalize(x, dim=-1, p=2, eps=eps)
+ x = self.last_layer(x)
+ return x
+
+
+def _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=None, use_bn=False, bias=True):
+ if nlayers == 1:
+ return nn.Linear(in_dim, bottleneck_dim, bias=bias)
+ else:
+ layers = [nn.Linear(in_dim, hidden_dim, bias=bias)]
+ if use_bn:
+ layers.append(nn.BatchNorm1d(hidden_dim))
+ layers.append(nn.GELU())
+ for _ in range(nlayers - 2):
+ layers.append(nn.Linear(hidden_dim, hidden_dim, bias=bias))
+ if use_bn:
+ layers.append(nn.BatchNorm1d(hidden_dim))
+ layers.append(nn.GELU())
+ layers.append(nn.Linear(hidden_dim, bottleneck_dim, bias=bias))
+ return nn.Sequential(*layers)
diff --git a/imcui/third_party/DeDoDe/DeDoDe/transformer/layers/drop_path.py b/imcui/third_party/DeDoDe/DeDoDe/transformer/layers/drop_path.py
new file mode 100644
index 0000000000000000000000000000000000000000..af05625984dd14682cc96a63bf0c97bab1f123b1
--- /dev/null
+++ b/imcui/third_party/DeDoDe/DeDoDe/transformer/layers/drop_path.py
@@ -0,0 +1,35 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py
+
+
+from torch import nn
+
+
+def drop_path(x, drop_prob: float = 0.0, training: bool = False):
+ if drop_prob == 0.0 or not training:
+ return x
+ keep_prob = 1 - drop_prob
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
+ random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
+ if keep_prob > 0.0:
+ random_tensor.div_(keep_prob)
+ output = x * random_tensor
+ return output
+
+
+class DropPath(nn.Module):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
+
+ def __init__(self, drop_prob=None):
+ super(DropPath, self).__init__()
+ self.drop_prob = drop_prob
+
+ def forward(self, x):
+ return drop_path(x, self.drop_prob, self.training)
diff --git a/imcui/third_party/DeDoDe/DeDoDe/transformer/layers/layer_scale.py b/imcui/third_party/DeDoDe/DeDoDe/transformer/layers/layer_scale.py
new file mode 100644
index 0000000000000000000000000000000000000000..ca5daa52bd81d3581adeb2198ea5b7dba2a3aea1
--- /dev/null
+++ b/imcui/third_party/DeDoDe/DeDoDe/transformer/layers/layer_scale.py
@@ -0,0 +1,28 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110
+
+from typing import Union
+
+import torch
+from torch import Tensor
+from torch import nn
+
+
+class LayerScale(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ init_values: Union[float, Tensor] = 1e-5,
+ inplace: bool = False,
+ ) -> None:
+ super().__init__()
+ self.inplace = inplace
+ self.gamma = nn.Parameter(init_values * torch.ones(dim))
+
+ def forward(self, x: Tensor) -> Tensor:
+ return x.mul_(self.gamma) if self.inplace else x * self.gamma
diff --git a/imcui/third_party/DeDoDe/DeDoDe/transformer/layers/mlp.py b/imcui/third_party/DeDoDe/DeDoDe/transformer/layers/mlp.py
new file mode 100644
index 0000000000000000000000000000000000000000..5e4b315f972f9a9f54aef1e4ef4e81b52976f018
--- /dev/null
+++ b/imcui/third_party/DeDoDe/DeDoDe/transformer/layers/mlp.py
@@ -0,0 +1,41 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py
+
+
+from typing import Callable, Optional
+
+from torch import Tensor, nn
+
+
+class Mlp(nn.Module):
+ def __init__(
+ self,
+ in_features: int,
+ hidden_features: Optional[int] = None,
+ out_features: Optional[int] = None,
+ act_layer: Callable[..., nn.Module] = nn.GELU,
+ drop: float = 0.0,
+ bias: bool = True,
+ ) -> None:
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
+ self.act = act_layer()
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
+ self.drop = nn.Dropout(drop)
+
+ def forward(self, x: Tensor) -> Tensor:
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.drop(x)
+ x = self.fc2(x)
+ x = self.drop(x)
+ return x
diff --git a/imcui/third_party/DeDoDe/DeDoDe/transformer/layers/patch_embed.py b/imcui/third_party/DeDoDe/DeDoDe/transformer/layers/patch_embed.py
new file mode 100644
index 0000000000000000000000000000000000000000..574abe41175568d700a389b8b96d1ba554914779
--- /dev/null
+++ b/imcui/third_party/DeDoDe/DeDoDe/transformer/layers/patch_embed.py
@@ -0,0 +1,89 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
+
+from typing import Callable, Optional, Tuple, Union
+
+from torch import Tensor
+import torch.nn as nn
+
+
+def make_2tuple(x):
+ if isinstance(x, tuple):
+ assert len(x) == 2
+ return x
+
+ assert isinstance(x, int)
+ return (x, x)
+
+
+class PatchEmbed(nn.Module):
+ """
+ 2D image to patch embedding: (B,C,H,W) -> (B,N,D)
+
+ Args:
+ img_size: Image size.
+ patch_size: Patch token size.
+ in_chans: Number of input image channels.
+ embed_dim: Number of linear projection output channels.
+ norm_layer: Normalization layer.
+ """
+
+ def __init__(
+ self,
+ img_size: Union[int, Tuple[int, int]] = 224,
+ patch_size: Union[int, Tuple[int, int]] = 16,
+ in_chans: int = 3,
+ embed_dim: int = 768,
+ norm_layer: Optional[Callable] = None,
+ flatten_embedding: bool = True,
+ ) -> None:
+ super().__init__()
+
+ image_HW = make_2tuple(img_size)
+ patch_HW = make_2tuple(patch_size)
+ patch_grid_size = (
+ image_HW[0] // patch_HW[0],
+ image_HW[1] // patch_HW[1],
+ )
+
+ self.img_size = image_HW
+ self.patch_size = patch_HW
+ self.patches_resolution = patch_grid_size
+ self.num_patches = patch_grid_size[0] * patch_grid_size[1]
+
+ self.in_chans = in_chans
+ self.embed_dim = embed_dim
+
+ self.flatten_embedding = flatten_embedding
+
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
+
+ def forward(self, x: Tensor) -> Tensor:
+ _, _, H, W = x.shape
+ patch_H, patch_W = self.patch_size
+
+ assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
+ assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}"
+
+ x = self.proj(x) # B C H W
+ H, W = x.size(2), x.size(3)
+ x = x.flatten(2).transpose(1, 2) # B HW C
+ x = self.norm(x)
+ if not self.flatten_embedding:
+ x = x.reshape(-1, H, W, self.embed_dim) # B H W C
+ return x
+
+ def flops(self) -> float:
+ Ho, Wo = self.patches_resolution
+ flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
+ if self.norm is not None:
+ flops += Ho * Wo * self.embed_dim
+ return flops
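
PatchEmbed maps a (B, C, H, W) image to a (B, N, D) token sequence with N = (H / patch_size) * (W / patch_size), for example:

    import torch
    from DeDoDe.transformer.layers.patch_embed import PatchEmbed

    embed = PatchEmbed(img_size=224, patch_size=16, in_chans=3, embed_dim=768)
    tokens = embed(torch.randn(2, 3, 224, 224))
    print(tokens.shape)  # torch.Size([2, 196, 768]): 14 x 14 patches, 768-dim each
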
diff --git a/imcui/third_party/DeDoDe/DeDoDe/transformer/layers/swiglu_ffn.py b/imcui/third_party/DeDoDe/DeDoDe/transformer/layers/swiglu_ffn.py
new file mode 100644
index 0000000000000000000000000000000000000000..b3324b266fb0a50ccf8c3a0ede2ae10ac4dfa03e
--- /dev/null
+++ b/imcui/third_party/DeDoDe/DeDoDe/transformer/layers/swiglu_ffn.py
@@ -0,0 +1,63 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Callable, Optional
+
+from torch import Tensor, nn
+import torch.nn.functional as F
+
+
+class SwiGLUFFN(nn.Module):
+ def __init__(
+ self,
+ in_features: int,
+ hidden_features: Optional[int] = None,
+ out_features: Optional[int] = None,
+ act_layer: Callable[..., nn.Module] = None,
+ drop: float = 0.0,
+ bias: bool = True,
+ ) -> None:
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
+ self.w3 = nn.Linear(hidden_features, out_features, bias=bias)
+
+ def forward(self, x: Tensor) -> Tensor:
+ x12 = self.w12(x)
+ x1, x2 = x12.chunk(2, dim=-1)
+ hidden = F.silu(x1) * x2
+ return self.w3(hidden)
+
+
+try:
+ from xformers.ops import SwiGLU
+
+ XFORMERS_AVAILABLE = True
+except ImportError:
+ SwiGLU = SwiGLUFFN
+ XFORMERS_AVAILABLE = False
+
+
+class SwiGLUFFNFused(SwiGLU):
+ def __init__(
+ self,
+ in_features: int,
+ hidden_features: Optional[int] = None,
+ out_features: Optional[int] = None,
+ act_layer: Callable[..., nn.Module] = None,
+ drop: float = 0.0,
+ bias: bool = True,
+ ) -> None:
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
+ super().__init__(
+ in_features=in_features,
+ hidden_features=hidden_features,
+ out_features=out_features,
+ bias=bias,
+ )
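
The hidden-width adjustment in SwiGLUFFNFused scales the requested width by 2/3 (so the gated two-matrix w12 keeps roughly the parameter count of an ordinary two-layer MLP) and then rounds up to a multiple of 8, presumably for hardware-friendly shapes; for example, hidden_features=4096 becomes (int(4096 * 2 / 3) + 7) // 8 * 8 = 2736.
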
diff --git a/imcui/third_party/DeDoDe/DeDoDe/utils.py b/imcui/third_party/DeDoDe/DeDoDe/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..9475dc8927aa2256fc9d947cc3034dff9420e6c4
--- /dev/null
+++ b/imcui/third_party/DeDoDe/DeDoDe/utils.py
@@ -0,0 +1,717 @@
+import warnings
+import numpy as np
+import math
+import cv2
+import torch
+from torchvision import transforms
+from torchvision.transforms.functional import InterpolationMode
+import torch.nn.functional as F
+from PIL import Image
+from einops import rearrange
+from time import perf_counter
+
+
+def get_best_device(verbose = False):
+ device = torch.device('cpu')
+ if torch.cuda.is_available():
+ device = torch.device('cuda')
+ elif torch.backends.mps.is_available():
+ device = torch.device('mps')
+ else:
+ device = torch.device('cpu')
+ if verbose: print (f"Fastest device found is: {device}")
+ return device
+
+
+def recover_pose(E, kpts0, kpts1, K0, K1, mask):
+ best_num_inliers = 0
+ K0inv = np.linalg.inv(K0[:2,:2])
+ K1inv = np.linalg.inv(K1[:2,:2])
+
+ kpts0_n = (K0inv @ (kpts0-K0[None,:2,2]).T).T
+ kpts1_n = (K1inv @ (kpts1-K1[None,:2,2]).T).T
+
+    for _E in np.split(E, len(E) // 3):  # integer sections; newer NumPy rejects a float here
+ n, R, t, _ = cv2.recoverPose(_E, kpts0_n, kpts1_n, np.eye(3), 1e9, mask=mask)
+ if n > best_num_inliers:
+ best_num_inliers = n
+ ret = (R, t, mask.ravel() > 0)
+ return ret
+
+
+
+# Code taken from https://github.com/PruneTruong/DenseMatching/blob/40c29a6b5c35e86b9509e65ab0cd12553d998e5f/validation/utils_pose_estimation.py
+# --- GEOMETRY ---
+def estimate_pose(kpts0, kpts1, K0, K1, norm_thresh, conf=0.99999):
+ if len(kpts0) < 5:
+ return None
+ K0inv = np.linalg.inv(K0[:2,:2])
+ K1inv = np.linalg.inv(K1[:2,:2])
+
+ kpts0 = (K0inv @ (kpts0-K0[None,:2,2]).T).T
+ kpts1 = (K1inv @ (kpts1-K1[None,:2,2]).T).T
+ E, mask = cv2.findEssentialMat(
+ kpts0, kpts1, np.eye(3), threshold=norm_thresh, prob=conf
+ )
+
+ ret = None
+ if E is not None:
+ best_num_inliers = 0
+
+        for _E in np.split(E, len(E) // 3):  # integer sections; newer NumPy rejects a float here
+ n, R, t, _ = cv2.recoverPose(_E, kpts0, kpts1, np.eye(3), 1e9, mask=mask)
+ if n > best_num_inliers:
+ best_num_inliers = n
+ ret = (R, t, mask.ravel() > 0)
+ return ret
+
+
+def get_grid(B,H,W, device = get_best_device()):
+ x1_n = torch.meshgrid(
+ *[
+ torch.linspace(
+ -1 + 1 / n, 1 - 1 / n, n, device=device
+ )
+ for n in (B, H, W)
+ ]
+ )
+ x1_n = torch.stack((x1_n[2], x1_n[1]), dim=-1).reshape(B, H * W, 2)
+ return x1_n
+
+@torch.no_grad()
+def finite_diff_hessian(f: tuple(["B", "H", "W"]), device = get_best_device()):
+ dxx = torch.tensor([[0,0,0],[1,-2,1],[0,0,0]], device = device)[None,None]/2
+ dxy = torch.tensor([[1,0,-1],[0,0,0],[-1,0,1]], device = device)[None,None]/4
+ dyy = dxx.mT
+ Hxx = F.conv2d(f[:,None], dxx, padding = 1)[:,0]
+ Hxy = F.conv2d(f[:,None], dxy, padding = 1)[:,0]
+ Hyy = F.conv2d(f[:,None], dyy, padding = 1)[:,0]
+ H = torch.stack((Hxx, Hxy, Hxy, Hyy), dim = -1).reshape(*f.shape,2,2)
+ return H
+
+def finite_diff_grad(f: tuple(["B", "H", "W"]), device = get_best_device()):
+ dx = torch.tensor([[0,0,0],[-1,0,1],[0,0,0]],device = device)[None,None]/2
+ dy = dx.mT
+ gx = F.conv2d(f[:,None], dx, padding = 1)
+ gy = F.conv2d(f[:,None], dy, padding = 1)
+ g = torch.cat((gx, gy), dim = 1)
+ return g
+
+def fast_inv_2x2(matrix: tuple[...,2,2], eps = 1e-10):
+ return 1/(torch.linalg.det(matrix)[...,None,None]+eps) * torch.stack((matrix[...,1,1],-matrix[...,0,1],
+ -matrix[...,1,0],matrix[...,0,0]),dim=-1).reshape(*matrix.shape)
+
+def newton_step(f:tuple["B","H","W"], inds, device = get_best_device()):
+ B,H,W = f.shape
+ Hess = finite_diff_hessian(f).reshape(B,H*W,2,2)
+ Hess = torch.gather(Hess, dim = 1, index = inds[...,None].expand(B,-1,2,2))
+ grad = finite_diff_grad(f).reshape(B,H*W,2)
+ grad = torch.gather(grad, dim = 1, index = inds)
+ Hessinv = fast_inv_2x2(Hess-torch.eye(2, device = device)[None,None])
+ step = (Hessinv @ grad[...,None])
+ return step[...,0]
+
+@torch.no_grad()
+def sample_keypoints(scoremap, num_samples = 8192, device = get_best_device(), use_nms = True,
+ sample_topk = False, return_scoremap = False, sharpen = False, upsample = False,
+ increase_coverage = False, remove_borders = False):
+ #scoremap = scoremap**2
+ log_scoremap = (scoremap+1e-10).log()
+ if upsample:
+ log_scoremap = F.interpolate(log_scoremap[:,None], scale_factor = 3, mode = "bicubic", align_corners = False)[:,0]#.clamp(min = 0)
+ scoremap = log_scoremap.exp()
+ B,H,W = scoremap.shape
+ if increase_coverage:
+ weights = (-torch.linspace(-2, 2, steps = 51, device = device)**2).exp()[None,None]
+        # scaling by 1e4 only guards against numerical underflow; the sampling result is invariant to this constant
+ local_density_x = F.conv2d((scoremap[:,None]+1e-6)*10000,weights[...,None,:], padding = (0,51//2))
+ local_density = F.conv2d(local_density_x, weights[...,None], padding = (51//2,0))[:,0]
+ scoremap = scoremap * (local_density+1e-8)**(-1/2)
+ grid = get_grid(B,H,W, device=device).reshape(B,H*W,2)
+ if sharpen:
+ laplace_operator = torch.tensor([[[[0,1,0],[1,-4,1],[0,1,0]]]], device = device)/4
+ scoremap = scoremap[:,None] - 0.5 * F.conv2d(scoremap[:,None], weight = laplace_operator, padding = 1)
+ scoremap = scoremap[:,0].clamp(min = 0)
+ if use_nms:
+ scoremap = scoremap * (scoremap == F.max_pool2d(scoremap, (3, 3), stride = 1, padding = 1))
+ if remove_borders:
+ frame = torch.zeros_like(scoremap)
+        # zero out a hardcoded 4 px border so keypoints are never sampled right at the image edge
+ frame[...,4:-4, 4:-4] = 1
+ scoremap = scoremap * frame
+ if sample_topk:
+ inds = torch.topk(scoremap.reshape(B,H*W), k = num_samples).indices
+ else:
+ inds = torch.multinomial(scoremap.reshape(B,H*W), num_samples = num_samples, replacement=False)
+ kps = torch.gather(grid, dim = 1, index = inds[...,None].expand(B,num_samples,2))
+ if return_scoremap:
+ return kps, torch.gather(scoremap.reshape(B,H*W), dim = 1, index = inds)
+ return kps
+
+@torch.no_grad()
+def jacobi_determinant(warp, certainty, R = 3, device = get_best_device(), dtype = torch.float32):
+ t = perf_counter()
+ *dims, _ = warp.shape
+ warp = warp.to(dtype)
+ certainty = certainty.to(dtype)
+
+ dtype = warp.dtype
+ match_regions = torch.zeros((*dims, 4, R, R), device = device).to(dtype)
+ match_regions[:,1:-1, 1:-1] = warp.unfold(1,R,1).unfold(2,R,1)
+ match_regions = rearrange(match_regions,"B H W D R1 R2 -> B H W (R1 R2) D") - warp[...,None,:]
+
+ match_regions_cert = torch.zeros((*dims, R, R), device = device).to(dtype)
+ match_regions_cert[:,1:-1, 1:-1] = certainty.unfold(1,R,1).unfold(2,R,1)
+ match_regions_cert = rearrange(match_regions_cert,"B H W R1 R2 -> B H W (R1 R2)")[..., None]
+
+ #print("Time for unfold", perf_counter()-t)
+ #t = perf_counter()
+ *dims, N, D = match_regions.shape
+ # standardize:
+ mu, sigma = match_regions.mean(dim=(-2,-1), keepdim = True), match_regions.std(dim=(-2,-1),keepdim=True)
+ match_regions = (match_regions-mu)/(sigma+1e-6)
+ x_a, x_b = match_regions.chunk(2,-1)
+
+
+ A = torch.zeros((*dims,2*x_a.shape[-2],4), device = device).to(dtype)
+ A[...,::2,:2] = x_a * match_regions_cert
+ A[...,1::2,2:] = x_a * match_regions_cert
+
+ a_block = A[...,::2,:2]
+ ata = a_block.mT @ a_block
+ #print("Time for ata", perf_counter()-t)
+ #t = perf_counter()
+
+ #atainv = torch.linalg.inv(ata+1e-5*torch.eye(2,device=device).to(dtype))
+ atainv = fast_inv_2x2(ata)
+ ATA_inv = torch.zeros((*dims, 4, 4), device = device, dtype = dtype)
+ ATA_inv[...,:2,:2] = atainv
+ ATA_inv[...,2:,2:] = atainv
+ atb = A.mT @ (match_regions_cert*x_b).reshape(*dims,N*2,1)
+ theta = ATA_inv @ atb
+ #print("Time for theta", perf_counter()-t)
+ #t = perf_counter()
+
+ J = theta.reshape(*dims, 2, 2)
+ abs_J_det = torch.linalg.det(J+1e-8*torch.eye(2,2,device = device).expand(*dims,2,2)).abs() # Note: This should always be positive for correct warps, but still taking abs here
+ abs_J_logdet = (abs_J_det+1e-12).log()
+ B = certainty.shape[0]
+ # Handle outliers
+    robust_abs_J_logdet = abs_J_logdet.clamp(-3, 3)  # shouldn't be more than exp(3) ~ 20x area change
+ #print("Time for logdet", perf_counter()-t)
+ #t = perf_counter()
+
+ return robust_abs_J_logdet
+
+def get_gt_warp(depth1, depth2, T_1to2, K1, K2, depth_interpolation_mode = 'bilinear', relative_depth_error_threshold = 0.05, H = None, W = None):
+
+ if H is None:
+ B,H,W = depth1.shape
+ else:
+ B = depth1.shape[0]
+ with torch.no_grad():
+ x1_n = torch.meshgrid(
+ *[
+ torch.linspace(
+ -1 + 1 / n, 1 - 1 / n, n, device=depth1.device
+ )
+ for n in (B, H, W)
+ ]
+ )
+ x1_n = torch.stack((x1_n[2], x1_n[1]), dim=-1).reshape(B, H * W, 2)
+ mask, x2 = warp_kpts(
+ x1_n.double(),
+ depth1.double(),
+ depth2.double(),
+ T_1to2.double(),
+ K1.double(),
+ K2.double(),
+ depth_interpolation_mode = depth_interpolation_mode,
+ relative_depth_error_threshold = relative_depth_error_threshold,
+ )
+ prob = mask.float().reshape(B, H, W)
+ x2 = x2.reshape(B, H, W, 2)
+ return torch.cat((x1_n.reshape(B,H,W,2),x2),dim=-1), prob
+
+def unnormalize_coords(x_n,h,w):
+ x = torch.stack(
+ (w * (x_n[..., 0] + 1) / 2, h * (x_n[..., 1] + 1) / 2), dim=-1
+ ) # [-1+1/h, 1-1/h] -> [0.5, h-0.5]
+ return x
+
+
+def rotate_intrinsic(K, n):
+ base_rot = np.array([[0, 1, 0], [-1, 0, 0], [0, 0, 1]])
+ rot = np.linalg.matrix_power(base_rot, n)
+ return rot @ K
+
+
+def rotate_pose_inplane(i_T_w, rot):
+ rotation_matrices = [
+ np.array(
+ [
+ [np.cos(r), -np.sin(r), 0.0, 0.0],
+ [np.sin(r), np.cos(r), 0.0, 0.0],
+ [0.0, 0.0, 1.0, 0.0],
+ [0.0, 0.0, 0.0, 1.0],
+ ],
+ dtype=np.float32,
+ )
+ for r in [np.deg2rad(d) for d in (0, 270, 180, 90)]
+ ]
+ return np.dot(rotation_matrices[rot], i_T_w)
+
+
+def scale_intrinsics(K, scales):
+ scales = np.diag([1.0 / scales[0], 1.0 / scales[1], 1.0])
+ return np.dot(scales, K)
+
+def angle_error_mat(R1, R2):
+ cos = (np.trace(np.dot(R1.T, R2)) - 1) / 2
+    cos = np.clip(cos, -1.0, 1.0)  # numerical errors can make it out of bounds
+ return np.rad2deg(np.abs(np.arccos(cos)))
+
+
+def angle_error_vec(v1, v2):
+ n = np.linalg.norm(v1) * np.linalg.norm(v2)
+ return np.rad2deg(np.arccos(np.clip(np.dot(v1, v2) / n, -1.0, 1.0)))
+
+
+def compute_pose_error(T_0to1, R, t):
+ R_gt = T_0to1[:3, :3]
+ t_gt = T_0to1[:3, 3]
+ error_t = angle_error_vec(t.squeeze(), t_gt)
+ error_t = np.minimum(error_t, 180 - error_t) # ambiguity of E estimation
+ error_R = angle_error_mat(R, R_gt)
+ return error_t, error_R
+
+
+def pose_auc(errors, thresholds):
+ sort_idx = np.argsort(errors)
+ errors = np.array(errors.copy())[sort_idx]
+ recall = (np.arange(len(errors)) + 1) / len(errors)
+ errors = np.r_[0.0, errors]
+ recall = np.r_[0.0, recall]
+ aucs = []
+ for t in thresholds:
+ last_index = np.searchsorted(errors, t)
+ r = np.r_[recall[:last_index], recall[last_index - 1]]
+ e = np.r_[errors[:last_index], t]
+ aucs.append(np.trapz(r, x=e) / t)
+ return aucs
+
+
+# From Patch2Pix https://github.com/GrumpyZhou/patch2pix
+def get_depth_tuple_transform_ops(resize=None, normalize=True, unscale=False):
+ ops = []
+ if resize:
+ ops.append(TupleResize(resize, mode=InterpolationMode.BILINEAR, antialias = False))
+ return TupleCompose(ops)
+
+
+def get_tuple_transform_ops(resize=None, normalize=True, unscale=False, clahe = False):
+ ops = []
+ if resize:
+ ops.append(TupleResize(resize, antialias = True))
+ if clahe:
+ ops.append(TupleClahe())
+ if normalize:
+ ops.append(TupleToTensorScaled())
+ ops.append(
+ TupleNormalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+ ) # Imagenet mean/std
+ else:
+ if unscale:
+ ops.append(TupleToTensorUnscaled())
+ else:
+ ops.append(TupleToTensorScaled())
+ return TupleCompose(ops)
+
+class Clahe:
+ def __init__(self, cliplimit = 2, blocksize = 8) -> None:
+ self.clahe = cv2.createCLAHE(cliplimit,(blocksize,blocksize))
+ def __call__(self, im):
+ im_hsv = cv2.cvtColor(np.array(im),cv2.COLOR_RGB2HSV)
+ im_v = self.clahe.apply(im_hsv[:,:,2])
+ im_hsv[...,2] = im_v
+ im_clahe = cv2.cvtColor(im_hsv,cv2.COLOR_HSV2RGB)
+ return Image.fromarray(im_clahe)
+
+class TupleClahe:
+ def __init__(self, cliplimit = 8, blocksize = 8) -> None:
+ self.clahe = Clahe(cliplimit,blocksize)
+ def __call__(self, ims):
+ return [self.clahe(im) for im in ims]
+
+class ToTensorScaled(object):
+ """Convert a RGB PIL Image to a CHW ordered Tensor, scale the range to [0, 1]"""
+
+ def __call__(self, im):
+ if not isinstance(im, torch.Tensor):
+ im = np.array(im, dtype=np.float32).transpose((2, 0, 1))
+ im /= 255.0
+ return torch.from_numpy(im)
+ else:
+ return im
+
+ def __repr__(self):
+ return "ToTensorScaled(./255)"
+
+
+class TupleToTensorScaled(object):
+ def __init__(self):
+ self.to_tensor = ToTensorScaled()
+
+ def __call__(self, im_tuple):
+ return [self.to_tensor(im) for im in im_tuple]
+
+ def __repr__(self):
+ return "TupleToTensorScaled(./255)"
+
+
+class ToTensorUnscaled(object):
+ """Convert a RGB PIL Image to a CHW ordered Tensor"""
+
+ def __call__(self, im):
+ return torch.from_numpy(np.array(im, dtype=np.float32).transpose((2, 0, 1)))
+
+ def __repr__(self):
+ return "ToTensorUnscaled()"
+
+
+class TupleToTensorUnscaled(object):
+ """Convert a RGB PIL Image to a CHW ordered Tensor"""
+
+ def __init__(self):
+ self.to_tensor = ToTensorUnscaled()
+
+ def __call__(self, im_tuple):
+ return [self.to_tensor(im) for im in im_tuple]
+
+ def __repr__(self):
+ return "TupleToTensorUnscaled()"
+
+
+class TupleResize(object):
+ def __init__(self, size, mode=InterpolationMode.BICUBIC, antialias = None):
+ self.size = size
+ self.resize = transforms.Resize(size, mode, antialias = antialias)
+
+ def __call__(self, im_tuple):
+ return [self.resize(im) for im in im_tuple]
+
+ def __repr__(self):
+ return "TupleResize(size={})".format(self.size)
+
+class Normalize:
+ def __call__(self,im):
+        mean = im.mean(dim=(1, 2), keepdim=True)
+        std = im.std(dim=(1, 2), keepdim=True)
+ return (im-mean)/std
+
+
+class TupleNormalize(object):
+ def __init__(self, mean, std):
+ self.mean = mean
+ self.std = std
+ self.normalize = transforms.Normalize(mean=mean, std=std)
+
+ def __call__(self, im_tuple):
+ c,h,w = im_tuple[0].shape
+ if c > 3:
+ warnings.warn(f"Number of channels {c=} > 3, assuming first 3 are rgb")
+ return [self.normalize(im[:3]) for im in im_tuple]
+
+ def __repr__(self):
+ return "TupleNormalize(mean={}, std={})".format(self.mean, self.std)
+
+
+class TupleCompose(object):
+ def __init__(self, transforms):
+ self.transforms = transforms
+
+ def __call__(self, im_tuple):
+ for t in self.transforms:
+ im_tuple = t(im_tuple)
+ return im_tuple
+
+ def __repr__(self):
+ format_string = self.__class__.__name__ + "("
+ for t in self.transforms:
+ format_string += "\n"
+ format_string += " {0}".format(t)
+ format_string += "\n)"
+ return format_string
+
+
+@torch.no_grad()
+def warp_kpts(kpts0, depth0, depth1, T_0to1, K0, K1, smooth_mask = False, return_relative_depth_error = False, depth_interpolation_mode = "bilinear", relative_depth_error_threshold = 0.05):
+ """Warp kpts0 from I0 to I1 with depth, K and Rt
+ Also check covisibility and depth consistency.
+    Depth is consistent if the relative error is below `relative_depth_error_threshold` (0.05 by default).
+    Adapted from https://github.com/zju3dv/LoFTR/blob/94e98b695be18acb43d5d3250f52226a8e36f839/src/loftr/utils/geometry.py
+ Args:
+ kpts0 (torch.Tensor): [N, L, 2] - , should be normalized in (-1,1)
+ depth0 (torch.Tensor): [N, H, W],
+ depth1 (torch.Tensor): [N, H, W],
+ T_0to1 (torch.Tensor): [N, 3, 4],
+ K0 (torch.Tensor): [N, 3, 3],
+ K1 (torch.Tensor): [N, 3, 3],
+ Returns:
+ calculable_mask (torch.Tensor): [N, L]
+ warped_keypoints0 (torch.Tensor): [N, L, 2]
+ """
+ (
+ n,
+ h,
+ w,
+ ) = depth0.shape
+ if depth_interpolation_mode == "combined":
+ # Inspired by approach in inloc, try to fill holes from bilinear interpolation by nearest neighbour interpolation
+ if smooth_mask:
+ raise NotImplementedError("Combined bilinear and NN warp not implemented")
+ valid_bilinear, warp_bilinear = warp_kpts(kpts0, depth0, depth1, T_0to1, K0, K1,
+ smooth_mask = smooth_mask,
+ return_relative_depth_error = return_relative_depth_error,
+ depth_interpolation_mode = "bilinear",
+ relative_depth_error_threshold = relative_depth_error_threshold)
+ valid_nearest, warp_nearest = warp_kpts(kpts0, depth0, depth1, T_0to1, K0, K1,
+ smooth_mask = smooth_mask,
+ return_relative_depth_error = return_relative_depth_error,
+ depth_interpolation_mode = "nearest-exact",
+ relative_depth_error_threshold = relative_depth_error_threshold)
+ nearest_valid_bilinear_invalid = (~valid_bilinear).logical_and(valid_nearest)
+ warp = warp_bilinear.clone()
+ warp[nearest_valid_bilinear_invalid] = warp_nearest[nearest_valid_bilinear_invalid]
+ valid = valid_bilinear | valid_nearest
+ return valid, warp
+
+
+ kpts0_depth = F.grid_sample(depth0[:, None], kpts0[:, :, None], mode = depth_interpolation_mode, align_corners=False)[
+ :, 0, :, 0
+ ]
+ kpts0 = torch.stack(
+ (w * (kpts0[..., 0] + 1) / 2, h * (kpts0[..., 1] + 1) / 2), dim=-1
+ ) # [-1+1/h, 1-1/h] -> [0.5, h-0.5]
+ # Sample depth, get calculable_mask on depth != 0
+ nonzero_mask = kpts0_depth != 0
+
+ # Unproject
+ kpts0_h = (
+ torch.cat([kpts0, torch.ones_like(kpts0[:, :, [0]])], dim=-1)
+ * kpts0_depth[..., None]
+ ) # (N, L, 3)
+ kpts0_n = K0.inverse() @ kpts0_h.transpose(2, 1) # (N, 3, L)
+ kpts0_cam = kpts0_n
+
+ # Rigid Transform
+ w_kpts0_cam = T_0to1[:, :3, :3] @ kpts0_cam + T_0to1[:, :3, [3]] # (N, 3, L)
+ w_kpts0_depth_computed = w_kpts0_cam[:, 2, :]
+
+ # Project
+ w_kpts0_h = (K1 @ w_kpts0_cam).transpose(2, 1) # (N, L, 3)
+ w_kpts0 = w_kpts0_h[:, :, :2] / (
+ w_kpts0_h[:, :, [2]] + 1e-4
+ ) # (N, L, 2), +1e-4 to avoid zero depth
+
+ # Covisible Check
+ h, w = depth1.shape[1:3]
+ covisible_mask = (
+ (w_kpts0[:, :, 0] > 0)
+ * (w_kpts0[:, :, 0] < w - 1)
+ * (w_kpts0[:, :, 1] > 0)
+ * (w_kpts0[:, :, 1] < h - 1)
+ )
+ w_kpts0 = torch.stack(
+ (2 * w_kpts0[..., 0] / w - 1, 2 * w_kpts0[..., 1] / h - 1), dim=-1
+ ) # from [0.5,h-0.5] -> [-1+1/h, 1-1/h]
+ # w_kpts0[~covisible_mask, :] = -5 # xd
+
+ w_kpts0_depth = F.grid_sample(
+ depth1[:, None], w_kpts0[:, :, None], mode=depth_interpolation_mode, align_corners=False
+ )[:, 0, :, 0]
+
+ relative_depth_error = (
+ (w_kpts0_depth - w_kpts0_depth_computed) / w_kpts0_depth
+ ).abs()
+ if not smooth_mask:
+ consistent_mask = relative_depth_error < relative_depth_error_threshold
+ else:
+ consistent_mask = (-relative_depth_error/smooth_mask).exp()
+ valid_mask = nonzero_mask * covisible_mask * consistent_mask
+ if return_relative_depth_error:
+ return relative_depth_error, w_kpts0
+ else:
+ return valid_mask, w_kpts0
+
+imagenet_mean = torch.tensor([0.485, 0.456, 0.406])
+imagenet_std = torch.tensor([0.229, 0.224, 0.225])
+
+
+def numpy_to_pil(x: np.ndarray):
+ """
+ Args:
+ x: Assumed to be of shape (h,w,c)
+ """
+ if isinstance(x, torch.Tensor):
+ x = x.detach().cpu().numpy()
+ if x.max() <= 1.01:
+ x *= 255
+ x = x.astype(np.uint8)
+ return Image.fromarray(x)
+
+
+def tensor_to_pil(x, unnormalize=False, autoscale = False):
+ if unnormalize:
+ x = x * (imagenet_std[:, None, None].to(x.device)) + (imagenet_mean[:, None, None].to(x.device))
+ if autoscale:
+ if x.max() == x.min():
+ warnings.warn("x max == x min, cant autoscale")
+ else:
+ x = (x-x.min())/(x.max()-x.min())
+
+ x = x.detach().permute(1, 2, 0).cpu().numpy()
+ x = np.clip(x, 0.0, 1.0)
+ return numpy_to_pil(x)
+
+
+def to_cuda(batch):
+ for key, value in batch.items():
+ if isinstance(value, torch.Tensor):
+ batch[key] = value.cuda()
+ return batch
+
+
+def to_best_device(batch, device=get_best_device()):
+ for key, value in batch.items():
+ if isinstance(value, torch.Tensor):
+ batch[key] = value.to(device)
+ return batch
+
+
+def to_cpu(batch):
+ for key, value in batch.items():
+ if isinstance(value, torch.Tensor):
+ batch[key] = value.cpu()
+ return batch
+
+
+def get_pose(calib):
+ w, h = np.array(calib["imsize"])[0]
+ return np.array(calib["K"]), np.array(calib["R"]), np.array(calib["T"]).T, h, w
+
+
+def compute_relative_pose(R1, t1, R2, t2):
+ rots = R2 @ (R1.T)
+ trans = -rots @ t1 + t2
+ return rots, trans
+
+def to_pixel_coords(flow, h1, w1):
+ flow = (
+ torch.stack(
+ (
+ w1 * (flow[..., 0] + 1) / 2,
+ h1 * (flow[..., 1] + 1) / 2,
+ ),
+ axis=-1,
+ )
+ )
+ return flow
+
+def to_normalized_coords(flow, h1, w1):
+ flow = (
+ torch.stack(
+ (
+ 2 * (flow[..., 0]) / w1 - 1,
+ 2 * (flow[..., 1]) / h1 - 1,
+ ),
+ axis=-1,
+ )
+ )
+ return flow
+
+
+def warp_to_pixel_coords(warp, h1, w1, h2, w2):
+ warp1 = warp[..., :2]
+ warp1 = (
+ torch.stack(
+ (
+ w1 * (warp1[..., 0] + 1) / 2,
+ h1 * (warp1[..., 1] + 1) / 2,
+ ),
+ axis=-1,
+ )
+ )
+ warp2 = warp[..., 2:]
+ warp2 = (
+ torch.stack(
+ (
+ w2 * (warp2[..., 0] + 1) / 2,
+ h2 * (warp2[..., 1] + 1) / 2,
+ ),
+ axis=-1,
+ )
+ )
+ return torch.cat((warp1,warp2), dim=-1)
+
+
+def to_homogeneous(x):
+ ones = torch.ones_like(x[...,-1:])
+ return torch.cat((x, ones), dim = -1)
+
+def from_homogeneous(xh, eps = 1e-12):
+ return xh[...,:-1] / (xh[...,-1:]+eps)
+
+def homog_transform(Homog, x):
+ xh = to_homogeneous(x)
+ yh = (Homog @ xh.mT).mT
+ y = from_homogeneous(yh)
+ return y
+
+def get_homog_warp(Homog, H, W, device = get_best_device()):
+ grid = torch.meshgrid(torch.linspace(-1+1/H,1-1/H,H, device = device), torch.linspace(-1+1/W,1-1/W,W, device = device))
+
+ x_A = torch.stack((grid[1], grid[0]), dim = -1)[None]
+ x_A_to_B = homog_transform(Homog, x_A)
+ mask = ((x_A_to_B > -1) * (x_A_to_B < 1)).prod(dim=-1).float()
+ return torch.cat((x_A.expand(*x_A_to_B.shape), x_A_to_B),dim=-1), mask
+
+def dual_log_softmax_matcher(desc_A: tuple['B','N','C'], desc_B: tuple['B','M','C'], inv_temperature = 1, normalize = False):
+ B, N, C = desc_A.shape
+ if normalize:
+ desc_A = desc_A/desc_A.norm(dim=-1,keepdim=True)
+ desc_B = desc_B/desc_B.norm(dim=-1,keepdim=True)
+ corr = torch.einsum("b n c, b m c -> b n m", desc_A, desc_B) * inv_temperature
+ else:
+ corr = torch.einsum("b n c, b m c -> b n m", desc_A, desc_B) * inv_temperature
+ logP = corr.log_softmax(dim = -2) + corr.log_softmax(dim= -1)
+ return logP
+
+def dual_softmax_matcher(desc_A: tuple['B','N','C'], desc_B: tuple['B','M','C'], inv_temperature = 1, normalize = False):
+ if len(desc_A.shape) < 3:
+ desc_A, desc_B = desc_A[None], desc_B[None]
+ B, N, C = desc_A.shape
+ if normalize:
+ desc_A = desc_A/desc_A.norm(dim=-1,keepdim=True)
+ desc_B = desc_B/desc_B.norm(dim=-1,keepdim=True)
+ corr = torch.einsum("b n c, b m c -> b n m", desc_A, desc_B) * inv_temperature
+ else:
+ corr = torch.einsum("b n c, b m c -> b n m", desc_A, desc_B) * inv_temperature
+ P = corr.softmax(dim = -2) * corr.softmax(dim= -1)
+ return P
+
+def conditional_softmax_matcher(desc_A: tuple['B','N','C'], desc_B: tuple['B','M','C'], inv_temperature = 1, normalize = False):
+ if len(desc_A.shape) < 3:
+ desc_A, desc_B = desc_A[None], desc_B[None]
+ B, N, C = desc_A.shape
+ if normalize:
+ desc_A = desc_A/desc_A.norm(dim=-1,keepdim=True)
+ desc_B = desc_B/desc_B.norm(dim=-1,keepdim=True)
+ corr = torch.einsum("b n c, b m c -> b n m", desc_A, desc_B) * inv_temperature
+ else:
+ corr = torch.einsum("b n c, b m c -> b n m", desc_A, desc_B) * inv_temperature
+ P_B_cond_A = corr.softmax(dim = -1)
+ P_A_cond_B = corr.softmax(dim = -2)
+
+ return P_A_cond_B, P_B_cond_A
\ No newline at end of file
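
A short sketch of the matcher utilities at the end of this file; the descriptors are random and the inverse temperature is an illustrative value:

    import torch
    from DeDoDe.utils import dual_softmax_matcher

    desc_A = torch.randn(1, 500, 256)  # (B, N, C) descriptors for image A
    desc_B = torch.randn(1, 600, 256)  # (B, M, C) descriptors for image B
    P = dual_softmax_matcher(desc_A, desc_B, inv_temperature=20, normalize=True)
    print(P.shape)                # torch.Size([1, 500, 600]) soft mutual-assignment scores
    best_in_B = P.argmax(dim=-1)  # for each keypoint in A, the best candidate index in B
    # keypoints in normalized (-1, 1) coordinates map back to pixels via to_pixel_coords(kpts, H, W)
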
diff --git a/imcui/third_party/DeDoDe/data_prep/prep_keypoints.py b/imcui/third_party/DeDoDe/data_prep/prep_keypoints.py
new file mode 100644
index 0000000000000000000000000000000000000000..04fc3c7b110dbb3292b57028f75293325444e242
--- /dev/null
+++ b/imcui/third_party/DeDoDe/data_prep/prep_keypoints.py
@@ -0,0 +1,103 @@
+import argparse
+import numpy as np
+
+import os
+
+
+base_path = "data/megadepth"
+# Remove the trailing / if need be.
+if base_path[-1] in ['/', '\\']:
+ base_path = base_path[: - 1]
+
+
+base_depth_path = os.path.join(
+ base_path, 'phoenix/S6/zl548/MegaDepth_v1'
+)
+base_undistorted_sfm_path = os.path.join(
+ base_path, 'Undistorted_SfM'
+)
+
+scene_ids = os.listdir(base_undistorted_sfm_path)
+for scene_id in scene_ids:
+ if os.path.exists(f"{base_path}/prep_scene_info/detections/detections_{scene_id}.npy"):
+ print(f"skipping {scene_id} as it exists")
+ continue
+ undistorted_sparse_path = os.path.join(
+ base_undistorted_sfm_path, scene_id, 'sparse-txt'
+ )
+ if not os.path.exists(undistorted_sparse_path):
+ print("sparse path doesnt exist")
+ continue
+
+ depths_path = os.path.join(
+ base_depth_path, scene_id, 'dense0', 'depths'
+ )
+ if not os.path.exists(depths_path):
+ print("depths doesnt exist")
+
+ continue
+
+ images_path = os.path.join(
+ base_undistorted_sfm_path, scene_id, 'images'
+ )
+ if not os.path.exists(images_path):
+ print("images path doesnt exist")
+ continue
+
+ # Process cameras.txt
+ if not os.path.exists(os.path.join(undistorted_sparse_path, 'cameras.txt')):
+ print("no cameras")
+ continue
+ with open(os.path.join(undistorted_sparse_path, 'cameras.txt'), 'r') as f:
+ raw = f.readlines()[3 :] # skip the header
+
+ camera_intrinsics = {}
+ for camera in raw:
+ camera = camera.split(' ')
+ camera_intrinsics[int(camera[0])] = [float(elem) for elem in camera[2 :]]
+
+ # Process points3D.txt
+ with open(os.path.join(undistorted_sparse_path, 'points3D.txt'), 'r') as f:
+ raw = f.readlines()[3 :] # skip the header
+
+ points3D = {}
+ for point3D in raw:
+ point3D = point3D.split(' ')
+ points3D[int(point3D[0])] = np.array([
+ float(point3D[1]), float(point3D[2]), float(point3D[3])
+ ])
+
+ points3D_np = np.zeros((max(points3D.keys())+1, 3))
+ for idx, point in points3D.items():
+ points3D_np[idx] = point
+ np.save(f"{base_path}/prep_scene_info/detections3D/detections3D_{scene_id}.npy",
+ points3D_np)
+
+ # Process images.txt
+ with open(os.path.join(undistorted_sparse_path, 'images.txt'), 'r') as f:
+ raw = f.readlines()[4 :] # skip the header
+
+ image_id_to_idx = {}
+ image_names = []
+ raw_pose = []
+ camera = []
+ points3D_id_to_2D = []
+ n_points3D = []
+ id_to_detections = {}
+ for idx, (image, points) in enumerate(zip(raw[:: 2], raw[1 :: 2])):
+ image = image.split(' ')
+ points = points.split(' ')
+
+ image_id_to_idx[int(image[0])] = idx
+
+ image_name = image[-1].strip('\n')
+ image_names.append(image_name)
+
+ raw_pose.append([float(elem) for elem in image[1 : -2]])
+ camera.append(int(image[-2]))
+ points_np = np.array(points).astype(np.float32).reshape(len(points)//3, 3)
+ visible_points = points_np[points_np[:,2] != -1]
+ id_to_detections[idx] = visible_points
+ np.save(f"{base_path}/prep_scene_info/detections/detections_{scene_id}.npy",
+ id_to_detections)
+ print(f"{scene_id} done")
diff --git a/imcui/third_party/DeDoDe/demo/demo_kpts.py b/imcui/third_party/DeDoDe/demo/demo_kpts.py
new file mode 100644
index 0000000000000000000000000000000000000000..cc0dddbe8d5de9b67abe2976ebc5d9c23f412340
--- /dev/null
+++ b/imcui/third_party/DeDoDe/demo/demo_kpts.py
@@ -0,0 +1,24 @@
+import torch
+import cv2
+import numpy as np
+from PIL import Image
+from DeDoDe import dedode_detector_L
+from DeDoDe.utils import *
+
+def draw_kpts(im, kpts):
+ kpts = [cv2.KeyPoint(x,y,1.) for x,y in kpts.cpu().numpy()]
+ im = np.array(im)
+ ret = cv2.drawKeypoints(im, kpts, None)
+ return ret
+
+
+if __name__ == "__main__":
+ device = get_best_device()
+ detector = dedode_detector_L(weights = torch.load("dedode_detector_L.pth", map_location = device))
+ im_path = "assets/im_A.jpg"
+ im = Image.open(im_path)
+ out = detector.detect_from_path(im_path, num_keypoints = 10_000)
+ W,H = im.size
+ kps = out["keypoints"]
+ kps = detector.to_pixel_coords(kps, H, W)
+ Image.fromarray(draw_kpts(im, kps[0])).save("demo/keypoints.png")
diff --git a/imcui/third_party/DeDoDe/demo/demo_match.py b/imcui/third_party/DeDoDe/demo/demo_match.py
new file mode 100644
index 0000000000000000000000000000000000000000..01143998f007ee1d2fb17adc64dcf8387510ac80
--- /dev/null
+++ b/imcui/third_party/DeDoDe/demo/demo_match.py
@@ -0,0 +1,46 @@
+import torch
+from DeDoDe import dedode_detector_L, dedode_descriptor_B
+from DeDoDe.matchers.dual_softmax_matcher import DualSoftMaxMatcher
+from DeDoDe.utils import *
+from PIL import Image
+import cv2
+import numpy as np
+
+
+def draw_matches(im_A, kpts_A, im_B, kpts_B):
+ kpts_A = [cv2.KeyPoint(x,y,1.) for x,y in kpts_A.cpu().numpy()]
+ kpts_B = [cv2.KeyPoint(x,y,1.) for x,y in kpts_B.cpu().numpy()]
+ matches_A_to_B = [cv2.DMatch(idx, idx, 0.) for idx in range(len(kpts_A))]
+ im_A, im_B = np.array(im_A), np.array(im_B)
+ ret = cv2.drawMatches(im_A, kpts_A, im_B, kpts_B,
+ matches_A_to_B, None)
+ return ret
+
+if __name__ == "__main__":
+ device = get_best_device()
+ detector = dedode_detector_L(weights = torch.load("dedode_detector_L.pth", map_location = device))
+ descriptor = dedode_descriptor_B(weights = torch.load("dedode_descriptor_B.pth", map_location = device))
+ matcher = DualSoftMaxMatcher()
+
+ im_A_path = "assets/im_A.jpg"
+ im_B_path = "assets/im_B.jpg"
+ im_A = Image.open(im_A_path)
+ im_B = Image.open(im_B_path)
+ W_A, H_A = im_A.size
+ W_B, H_B = im_B.size
+
+
+ detections_A = detector.detect_from_path(im_A_path, num_keypoints = 10_000)
+ keypoints_A, P_A = detections_A["keypoints"], detections_A["confidence"]
+ detections_B = detector.detect_from_path(im_B_path, num_keypoints = 10_000)
+ keypoints_B, P_B = detections_B["keypoints"], detections_B["confidence"]
+ description_A = descriptor.describe_keypoints_from_path(im_A_path, keypoints_A)["descriptions"]
+ description_B = descriptor.describe_keypoints_from_path(im_B_path, keypoints_B)["descriptions"]
+ matches_A, matches_B, batch_ids = matcher.match(keypoints_A, description_A,
+ keypoints_B, description_B,
+ P_A = P_A, P_B = P_B,
+ normalize = True, inv_temp=20, threshold = 0.01)#Increasing threshold -> fewer matches, fewer outliers
+
+ matches_A, matches_B = matcher.to_pixel_coords(matches_A, matches_B, H_A, W_A, H_B, W_B)
+
+ Image.fromarray(draw_matches(im_A, matches_A, im_B, matches_B)).save("demo/matches.png")
\ No newline at end of file
diff --git a/imcui/third_party/DeDoDe/demo/demo_match_dedode_G.py b/imcui/third_party/DeDoDe/demo/demo_match_dedode_G.py
new file mode 100644
index 0000000000000000000000000000000000000000..586da9a0949067264d643bfabb29cd541c9e624a
--- /dev/null
+++ b/imcui/third_party/DeDoDe/demo/demo_match_dedode_G.py
@@ -0,0 +1,45 @@
+import torch
+from DeDoDe import dedode_detector_L, dedode_descriptor_G
+from DeDoDe.matchers.dual_softmax_matcher import DualSoftMaxMatcher
+from DeDoDe.utils import *
+from PIL import Image
+import cv2
+import numpy as np
+
+
+def draw_matches(im_A, kpts_A, im_B, kpts_B):
+ kpts_A = [cv2.KeyPoint(x,y,1.) for x,y in kpts_A.cpu().numpy()]
+ kpts_B = [cv2.KeyPoint(x,y,1.) for x,y in kpts_B.cpu().numpy()]
+ matches_A_to_B = [cv2.DMatch(idx, idx, 0.) for idx in range(len(kpts_A))]
+ im_A, im_B = np.array(im_A), np.array(im_B)
+ ret = cv2.drawMatches(im_A, kpts_A, im_B, kpts_B,
+ matches_A_to_B, None)
+ return ret
+
+
+if __name__ == "__main__":
+ device = get_best_device()
+ detector = dedode_detector_L(weights = torch.load("dedode_detector_L.pth", map_location = device))
+ descriptor = dedode_descriptor_G(weights = torch.load("dedode_descriptor_G.pth", map_location = device))
+ matcher = DualSoftMaxMatcher()
+
+ im_A_path = "assets/im_A.jpg"
+ im_B_path = "assets/im_B.jpg"
+ im_A = Image.open(im_A_path)
+ im_B = Image.open(im_B_path)
+ W_A, H_A = im_A.size
+ W_B, H_B = im_B.size
+
+ detections_A = detector.detect_from_path(im_A_path, num_keypoints = 10_000)
+ keypoints_A, P_A = detections_A["keypoints"], detections_A["confidence"]
+ detections_B = detector.detect_from_path(im_B_path, num_keypoints = 10_000)
+ keypoints_B, P_B = detections_B["keypoints"], detections_B["confidence"]
+ description_A = descriptor.describe_keypoints_from_path(im_A_path, keypoints_A)["descriptions"]
+ description_B = descriptor.describe_keypoints_from_path(im_B_path, keypoints_B)["descriptions"]
+ matches_A, matches_B, batch_ids = matcher.match(keypoints_A, description_A,
+ keypoints_B, description_B,
+ P_A = P_A, P_B = P_B,
+ normalize = True, inv_temp=20, threshold = 0.01)#Increasing threshold -> fewer matches, fewer outliers
+
+ matches_A, matches_B = matcher.to_pixel_coords(matches_A, matches_B, H_A, W_A, H_B, W_B)
+ Image.fromarray(draw_matches(im_A, matches_A, im_B, matches_B)).save("demo/matches.jpg")
\ No newline at end of file
diff --git a/imcui/third_party/DeDoDe/demo/demo_scoremap.py b/imcui/third_party/DeDoDe/demo/demo_scoremap.py
new file mode 100644
index 0000000000000000000000000000000000000000..c5ae13a89ea18b364671a29692d47d550c8e88f0
--- /dev/null
+++ b/imcui/third_party/DeDoDe/demo/demo_scoremap.py
@@ -0,0 +1,23 @@
+import torch
+from PIL import Image
+import numpy as np
+
+from DeDoDe import dedode_detector_L
+from DeDoDe.utils import tensor_to_pil, get_best_device
+
+
+if __name__ == "__main__":
+ device = get_best_device()
+ detector = dedode_detector_L(weights = torch.load("dedode_detector_L.pth", map_location = device))
+ H, W = 784, 784
+ im_path = "assets/im_A.jpg"
+
+ out = detector.detect_from_path(im_path, dense = True, H = H, W = W)
+
+ logit_map = out["dense_keypoint_logits"].clone()
+ vmin = logit_map.max() - 3 # renamed from `min` to avoid shadowing the builtin
+ logit_map[logit_map < vmin] = vmin
+ logit_map = (logit_map - vmin) / (logit_map.max() - vmin)
+ logit_map = logit_map.cpu()[0].expand(3,H,W)
+ im_A = torch.tensor(np.array(Image.open(im_path).resize((W,H)))/255.).permute(2,0,1)
+ tensor_to_pil(logit_map * logit_map + 0.15 * (1-logit_map) * im_A).save("demo/dense_logits.png")
diff --git a/imcui/third_party/DeDoDe/setup.py b/imcui/third_party/DeDoDe/setup.py
new file mode 100644
index 0000000000000000000000000000000000000000..d175ab96e3493a2e53e2daaae99eb822a71b463e
--- /dev/null
+++ b/imcui/third_party/DeDoDe/setup.py
@@ -0,0 +1,11 @@
+from setuptools import setup, find_packages
+
+
+setup(
+ name="DeDoDe",
+ packages=find_packages(include= ["DeDoDe*"]),
+ install_requires=open("requirements.txt", "r").read().split("\n"),
+ python_requires='>=3.9.0',
+ version="0.0.1",
+ author="Johan Edstedt",
+)
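+# Example usage (added note, not part of setup.py): from the DeDoDe repository root,
+# `pip install -e .` installs the package in editable mode together with the
+# dependencies listed in requirements.txt.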
diff --git a/imcui/third_party/EfficientLoFTR/configs/data/__init__.py b/imcui/third_party/EfficientLoFTR/configs/data/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/imcui/third_party/EfficientLoFTR/configs/data/base.py b/imcui/third_party/EfficientLoFTR/configs/data/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..03aab160fa4137ccc04380f94854a56fbb549074
--- /dev/null
+++ b/imcui/third_party/EfficientLoFTR/configs/data/base.py
@@ -0,0 +1,35 @@
+"""
+The data config will be the last one merged into the main config.
+Settings in data configs will override all existing settings!
+"""
+
+from yacs.config import CfgNode as CN
+_CN = CN()
+_CN.DATASET = CN()
+_CN.TRAINER = CN()
+
+# training data config
+_CN.DATASET.TRAIN_DATA_ROOT = None
+_CN.DATASET.TRAIN_POSE_ROOT = None
+_CN.DATASET.TRAIN_NPZ_ROOT = None
+_CN.DATASET.TRAIN_LIST_PATH = None
+_CN.DATASET.TRAIN_INTRINSIC_PATH = None
+# validation set config
+_CN.DATASET.VAL_DATA_ROOT = None
+_CN.DATASET.VAL_POSE_ROOT = None
+_CN.DATASET.VAL_NPZ_ROOT = None
+_CN.DATASET.VAL_LIST_PATH = None
+_CN.DATASET.VAL_INTRINSIC_PATH = None
+
+# testing data config
+_CN.DATASET.TEST_DATA_ROOT = None
+_CN.DATASET.TEST_POSE_ROOT = None
+_CN.DATASET.TEST_NPZ_ROOT = None
+_CN.DATASET.TEST_LIST_PATH = None
+_CN.DATASET.TEST_INTRINSIC_PATH = None
+
+# dataset config
+_CN.DATASET.MIN_OVERLAP_SCORE_TRAIN = 0.4
+_CN.DATASET.MIN_OVERLAP_SCORE_TEST = 0.0 # for both test and val
+
+cfg = _CN
diff --git a/imcui/third_party/EfficientLoFTR/configs/data/megadepth_test_1500.py b/imcui/third_party/EfficientLoFTR/configs/data/megadepth_test_1500.py
new file mode 100644
index 0000000000000000000000000000000000000000..876bd4cad7772922d81c83ad3107ab6b8af599a3
--- /dev/null
+++ b/imcui/third_party/EfficientLoFTR/configs/data/megadepth_test_1500.py
@@ -0,0 +1,13 @@
+from configs.data.base import cfg
+
+TEST_BASE_PATH = "assets/megadepth_test_1500_scene_info"
+
+cfg.DATASET.TEST_DATA_SOURCE = "MegaDepth"
+cfg.DATASET.TEST_DATA_ROOT = "data/megadepth/test"
+cfg.DATASET.TEST_NPZ_ROOT = f"{TEST_BASE_PATH}"
+cfg.DATASET.TEST_LIST_PATH = f"{TEST_BASE_PATH}/megadepth_test_1500.txt"
+
+cfg.DATASET.MGDPT_IMG_RESIZE = 832
+cfg.DATASET.MIN_OVERLAP_SCORE_TEST = 0.0
+
+cfg.DATASET.NPE_NAME = 'megadepth'
\ No newline at end of file
diff --git a/imcui/third_party/EfficientLoFTR/configs/data/megadepth_trainval_832.py b/imcui/third_party/EfficientLoFTR/configs/data/megadepth_trainval_832.py
new file mode 100644
index 0000000000000000000000000000000000000000..b4ce0dd463cf09d031464176a5f28a6fe5ba2ad3
--- /dev/null
+++ b/imcui/third_party/EfficientLoFTR/configs/data/megadepth_trainval_832.py
@@ -0,0 +1,24 @@
+from configs.data.base import cfg
+
+
+TRAIN_BASE_PATH = "data/megadepth/index"
+cfg.DATASET.TRAINVAL_DATA_SOURCE = "MegaDepth"
+cfg.DATASET.TRAIN_DATA_ROOT = "data/megadepth/train"
+cfg.DATASET.TRAIN_NPZ_ROOT = f"{TRAIN_BASE_PATH}/scene_info_0.1_0.7"
+cfg.DATASET.TRAIN_LIST_PATH = f"{TRAIN_BASE_PATH}/trainvaltest_list/train_list.txt"
+cfg.DATASET.MIN_OVERLAP_SCORE_TRAIN = 0.0
+
+TEST_BASE_PATH = "data/megadepth/index"
+cfg.DATASET.TEST_DATA_SOURCE = "MegaDepth"
+cfg.DATASET.VAL_DATA_ROOT = cfg.DATASET.TEST_DATA_ROOT = "data/megadepth/test"
+cfg.DATASET.VAL_NPZ_ROOT = cfg.DATASET.TEST_NPZ_ROOT = f"{TEST_BASE_PATH}/scene_info_val_1500"
+cfg.DATASET.VAL_LIST_PATH = cfg.DATASET.TEST_LIST_PATH = f"{TEST_BASE_PATH}/trainvaltest_list/val_list.txt"
+cfg.DATASET.MIN_OVERLAP_SCORE_TEST = 0.0 # for both test and val
+
+# 368 scenes in total for MegaDepth
+# (with difficulty balanced (further split each scene to 3 sub-scenes))
+cfg.TRAINER.N_SAMPLES_PER_SUBSET = 100
+
+cfg.DATASET.MGDPT_IMG_RESIZE = 832 # for training on 32GB memory GPUs
+
+cfg.DATASET.NPE_NAME = 'megadepth'
\ No newline at end of file
diff --git a/imcui/third_party/EfficientLoFTR/configs/data/scannet_test_1500.py b/imcui/third_party/EfficientLoFTR/configs/data/scannet_test_1500.py
new file mode 100644
index 0000000000000000000000000000000000000000..ca98ed4b120d699f8de00016f169a83c0c8ddac8
--- /dev/null
+++ b/imcui/third_party/EfficientLoFTR/configs/data/scannet_test_1500.py
@@ -0,0 +1,16 @@
+from configs.data.base import cfg
+
+TEST_BASE_PATH = "assets/scannet_test_1500"
+
+cfg.DATASET.TEST_DATA_SOURCE = "ScanNet"
+cfg.DATASET.TEST_DATA_ROOT = "data/scannet/test"
+cfg.DATASET.TEST_NPZ_ROOT = f"{TEST_BASE_PATH}"
+cfg.DATASET.TEST_LIST_PATH = f"{TEST_BASE_PATH}/scannet_test.txt"
+cfg.DATASET.TEST_INTRINSIC_PATH = f"{TEST_BASE_PATH}/intrinsics.npz"
+
+cfg.DATASET.MIN_OVERLAP_SCORE_TEST = 0.0
+
+cfg.DATASET.SCAN_IMG_RESIZEX = 640
+cfg.DATASET.SCAN_IMG_RESIZEY = 480
+
+cfg.DATASET.NPE_NAME = 'scannet'
\ No newline at end of file
diff --git a/imcui/third_party/EfficientLoFTR/configs/loftr/eloftr_full.py b/imcui/third_party/EfficientLoFTR/configs/loftr/eloftr_full.py
new file mode 100644
index 0000000000000000000000000000000000000000..24ff5f33b6cf6ee11c4b564050fbe736126b8bc5
--- /dev/null
+++ b/imcui/third_party/EfficientLoFTR/configs/loftr/eloftr_full.py
@@ -0,0 +1,36 @@
+from src.config.default import _CN as cfg
+
+# training config
+cfg.TRAINER.CANONICAL_LR = 8e-3
+cfg.TRAINER.WARMUP_STEP = 1875 # 3 epochs
+cfg.TRAINER.WARMUP_RATIO = 0.1
+cfg.TRAINER.MSLR_MILESTONES = [8, 12, 16, 20, 24]
+cfg.TRAINER.RANSAC_PIXEL_THR = 0.5
+cfg.TRAINER.OPTIMIZER = "adamw"
+cfg.TRAINER.ADAMW_DECAY = 0.1
+cfg.TRAINER.EPI_ERR_THR = 5e-4 # recommendation: 5e-4 for ScanNet, 1e-4 for MegaDepth (from SuperGlue)
+cfg.TRAINER.GRADIENT_CLIPPING = 0.0
+cfg.LOFTR.LOSS.FINE_TYPE = 'l2' # ['l2_with_std', 'l2']
+cfg.LOFTR.LOSS.COARSE_OVERLAP_WEIGHT = True
+cfg.LOFTR.LOSS.FINE_OVERLAP_WEIGHT = True
+cfg.LOFTR.LOSS.LOCAL_WEIGHT = 0.25
+cfg.LOFTR.MATCH_COARSE.TRAIN_COARSE_PERCENT = 0.3
+cfg.LOFTR.MATCH_COARSE.SPARSE_SPVS = True
+
+# model config
+cfg.LOFTR.RESOLUTION = (8, 1)
+cfg.LOFTR.FINE_WINDOW_SIZE = 8 # window_size in fine_level, must be even
+cfg.LOFTR.ALIGN_CORNER = False
+cfg.LOFTR.MP = True # just for reproducing paper, FP16 is much faster on modern GPUs
+cfg.LOFTR.REPLACE_NAN = True
+cfg.LOFTR.EVAL_TIMES = 5
+cfg.LOFTR.COARSE.NO_FLASH = True # disable FlashAttention, only to reproduce the paper's timing
+cfg.LOFTR.MATCH_COARSE.THR = 0.2 # recommend 0.2 for full model and 25 for optimized model
+cfg.LOFTR.MATCH_FINE.LOCAL_REGRESS_TEMPERATURE = 10.0
+cfg.LOFTR.MATCH_FINE.LOCAL_REGRESS_SLICEDIM = 8
+
+# dataset config
+cfg.DATASET.FP16 = False
+
+# full model config
+cfg.LOFTR.MATCH_COARSE.FP16MATMUL = False
\ No newline at end of file
diff --git a/imcui/third_party/EfficientLoFTR/configs/loftr/eloftr_optimized.py b/imcui/third_party/EfficientLoFTR/configs/loftr/eloftr_optimized.py
new file mode 100644
index 0000000000000000000000000000000000000000..5c044e49db7ecb31e22570d8295d8ac617dcf64c
--- /dev/null
+++ b/imcui/third_party/EfficientLoFTR/configs/loftr/eloftr_optimized.py
@@ -0,0 +1,37 @@
+from src.config.default import _CN as cfg
+
+# training config
+cfg.TRAINER.CANONICAL_LR = 8e-3
+cfg.TRAINER.WARMUP_STEP = 1875 # 3 epochs
+cfg.TRAINER.WARMUP_RATIO = 0.1
+cfg.TRAINER.MSLR_MILESTONES = [8, 12, 16, 20, 24]
+cfg.TRAINER.RANSAC_PIXEL_THR = 0.5
+cfg.TRAINER.OPTIMIZER = "adamw"
+cfg.TRAINER.ADAMW_DECAY = 0.1
+cfg.TRAINER.EPI_ERR_THR = 5e-4 # recommendation: 5e-4 for ScanNet, 1e-4 for MegaDepth (from SuperGlue)
+cfg.TRAINER.GRADIENT_CLIPPING = 0.0
+cfg.LOFTR.LOSS.FINE_TYPE = 'l2' # ['l2_with_std', 'l2']
+cfg.LOFTR.LOSS.COARSE_OVERLAP_WEIGHT = True
+cfg.LOFTR.LOSS.FINE_OVERLAP_WEIGHT = True
+cfg.LOFTR.LOSS.LOCAL_WEIGHT = 0.25
+cfg.LOFTR.MATCH_COARSE.TRAIN_COARSE_PERCENT = 0.3
+cfg.LOFTR.MATCH_COARSE.SPARSE_SPVS = True
+
+# model config
+cfg.LOFTR.RESOLUTION = (8, 1)
+cfg.LOFTR.FINE_WINDOW_SIZE = 8 # window_size in fine_level, must be even
+cfg.LOFTR.ALIGN_CORNER = False
+cfg.LOFTR.MP = True # just for reproducing paper, FP16 is much faster on modern GPUs
+cfg.LOFTR.REPLACE_NAN = True
+cfg.LOFTR.EVAL_TIMES = 5
+cfg.LOFTR.COARSE.NO_FLASH = True # disable FlashAttention, only to reproduce the paper's timing
+cfg.LOFTR.MATCH_FINE.LOCAL_REGRESS_TEMPERATURE = 10.0
+cfg.LOFTR.MATCH_FINE.LOCAL_REGRESS_SLICEDIM = 8
+
+# dataset config
+cfg.DATASET.FP16 = False
+
+# optimized model config
+cfg.LOFTR.MATCH_COARSE.FP16MATMUL = True
+cfg.LOFTR.MATCH_COARSE.SKIP_SOFTMAX = True
+cfg.LOFTR.MATCH_COARSE.THR = 25.0 # recommend 0.2 for full model and 25 for optimized model
\ No newline at end of file
diff --git a/imcui/third_party/EfficientLoFTR/environment.yaml b/imcui/third_party/EfficientLoFTR/environment.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..52bc6f68c1b0d7c0f020453427873370753234bc
--- /dev/null
+++ b/imcui/third_party/EfficientLoFTR/environment.yaml
@@ -0,0 +1,7 @@
+name: eloftr
+channels:
+ - pytorch
+ - nvidia
+dependencies:
+ - python=3.8
+ - pip
\ No newline at end of file
diff --git a/imcui/third_party/EfficientLoFTR/environment_training.yaml b/imcui/third_party/EfficientLoFTR/environment_training.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..e9b8f38a07da3d29cec8eb9f5e6a6379a2d9800b
--- /dev/null
+++ b/imcui/third_party/EfficientLoFTR/environment_training.yaml
@@ -0,0 +1,9 @@
+name: eloftr_training
+channels:
+ - pytorch
+ - nvidia
+dependencies:
+ - python=3.8
+ - cudatoolkit=11.3
+ - pytorch=1.12.1
+ - pip
\ No newline at end of file
diff --git a/imcui/third_party/EfficientLoFTR/src/__init__.py b/imcui/third_party/EfficientLoFTR/src/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/imcui/third_party/EfficientLoFTR/src/config/default.py b/imcui/third_party/EfficientLoFTR/src/config/default.py
new file mode 100644
index 0000000000000000000000000000000000000000..03d98095be47b6870cf1475bcfe239de44ee98f9
--- /dev/null
+++ b/imcui/third_party/EfficientLoFTR/src/config/default.py
@@ -0,0 +1,182 @@
+from yacs.config import CfgNode as CN
+_CN = CN()
+
+############## ↓ LoFTR Pipeline ↓ ##############
+_CN.LOFTR = CN()
+_CN.LOFTR.BACKBONE_TYPE = 'RepVGG'
+_CN.LOFTR.ALIGN_CORNER = False
+_CN.LOFTR.RESOLUTION = (8, 1)
+_CN.LOFTR.FINE_WINDOW_SIZE = 8 # window_size in fine_level, must be even
+_CN.LOFTR.MP = False
+_CN.LOFTR.REPLACE_NAN = False
+_CN.LOFTR.EVAL_TIMES = 1
+_CN.LOFTR.HALF = False
+
+# 1. LoFTR-backbone (local feature CNN) config
+_CN.LOFTR.BACKBONE = CN()
+_CN.LOFTR.BACKBONE.BLOCK_DIMS = [64, 128, 256] # s1, s2, s3
+
+# 2. LoFTR-coarse module config
+_CN.LOFTR.COARSE = CN()
+_CN.LOFTR.COARSE.D_MODEL = 256
+_CN.LOFTR.COARSE.D_FFN = 256
+_CN.LOFTR.COARSE.NHEAD = 8
+_CN.LOFTR.COARSE.LAYER_NAMES = ['self', 'cross'] * 4
+_CN.LOFTR.COARSE.AGG_SIZE0 = 4
+_CN.LOFTR.COARSE.AGG_SIZE1 = 4
+_CN.LOFTR.COARSE.NO_FLASH = False
+_CN.LOFTR.COARSE.ROPE = True
+_CN.LOFTR.COARSE.NPE = None # [832, 832, long_side, long_side] Suggest setting based on the long side of the input image, especially when the long_side > 832
+
+# 3. Coarse-Matching config
+_CN.LOFTR.MATCH_COARSE = CN()
+_CN.LOFTR.MATCH_COARSE.THR = 0.2 # recommend 0.2 for full model and 25 for optimized model
+_CN.LOFTR.MATCH_COARSE.BORDER_RM = 2
+_CN.LOFTR.MATCH_COARSE.DSMAX_TEMPERATURE = 0.1
+_CN.LOFTR.MATCH_COARSE.TRAIN_COARSE_PERCENT = 0.2 # training tricks: save GPU memory
+_CN.LOFTR.MATCH_COARSE.TRAIN_PAD_NUM_GT_MIN = 200 # training tricks: avoid DDP deadlock
+_CN.LOFTR.MATCH_COARSE.SPARSE_SPVS = True
+_CN.LOFTR.MATCH_COARSE.SKIP_SOFTMAX = False
+_CN.LOFTR.MATCH_COARSE.FP16MATMUL = False
+
+# 4. Fine-Matching config
+_CN.LOFTR.MATCH_FINE = CN()
+_CN.LOFTR.MATCH_FINE.SPARSE_SPVS = True
+_CN.LOFTR.MATCH_FINE.LOCAL_REGRESS_TEMPERATURE = 1.0
+_CN.LOFTR.MATCH_FINE.LOCAL_REGRESS_SLICEDIM = 8
+
+# 5. LoFTR Losses
+# -- # coarse-level
+_CN.LOFTR.LOSS = CN()
+_CN.LOFTR.LOSS.COARSE_TYPE = 'focal' # ['focal', 'cross_entropy']
+_CN.LOFTR.LOSS.COARSE_WEIGHT = 1.0
+_CN.LOFTR.LOSS.COARSE_SIGMOID_WEIGHT = 1.0
+_CN.LOFTR.LOSS.LOCAL_WEIGHT = 0.5
+_CN.LOFTR.LOSS.COARSE_OVERLAP_WEIGHT = False
+_CN.LOFTR.LOSS.FINE_OVERLAP_WEIGHT = False
+_CN.LOFTR.LOSS.FINE_OVERLAP_WEIGHT2 = False
+# -- - -- # focal loss (coarse)
+_CN.LOFTR.LOSS.FOCAL_ALPHA = 0.25
+_CN.LOFTR.LOSS.FOCAL_GAMMA = 2.0
+_CN.LOFTR.LOSS.POS_WEIGHT = 1.0
+_CN.LOFTR.LOSS.NEG_WEIGHT = 1.0
+
+# -- # fine-level
+_CN.LOFTR.LOSS.FINE_TYPE = 'l2_with_std' # ['l2_with_std', 'l2']
+_CN.LOFTR.LOSS.FINE_WEIGHT = 1.0
+_CN.LOFTR.LOSS.FINE_CORRECT_THR = 1.0 # for filtering valid fine-level gts (some gt matches might fall out of the fine-level window)
+
+
+############## Dataset ##############
+_CN.DATASET = CN()
+# 1. data config
+# training and validating
+_CN.DATASET.TRAINVAL_DATA_SOURCE = None # options: ['ScanNet', 'MegaDepth']
+_CN.DATASET.TRAIN_DATA_ROOT = None
+_CN.DATASET.TRAIN_POSE_ROOT = None # (optional directory for poses)
+_CN.DATASET.TRAIN_NPZ_ROOT = None
+_CN.DATASET.TRAIN_LIST_PATH = None
+_CN.DATASET.TRAIN_INTRINSIC_PATH = None
+_CN.DATASET.VAL_DATA_ROOT = None
+_CN.DATASET.VAL_POSE_ROOT = None # (optional directory for poses)
+_CN.DATASET.VAL_NPZ_ROOT = None
+_CN.DATASET.VAL_LIST_PATH = None # None if val data from all scenes are bundled into a single npz file
+_CN.DATASET.VAL_INTRINSIC_PATH = None
+_CN.DATASET.FP16 = False
+# testing
+_CN.DATASET.TEST_DATA_SOURCE = None
+_CN.DATASET.TEST_DATA_ROOT = None
+_CN.DATASET.TEST_POSE_ROOT = None # (optional directory for poses)
+_CN.DATASET.TEST_NPZ_ROOT = None
+_CN.DATASET.TEST_LIST_PATH = None # None if test data from all scenes are bundled into a single npz file
+_CN.DATASET.TEST_INTRINSIC_PATH = None
+
+# 2. dataset config
+# general options
+_CN.DATASET.MIN_OVERLAP_SCORE_TRAIN = 0.4 # discard data with overlap_score < min_overlap_score
+_CN.DATASET.MIN_OVERLAP_SCORE_TEST = 0.0
+_CN.DATASET.AUGMENTATION_TYPE = None # options: [None, 'dark', 'mobile']
+
+# scanNet options
+_CN.DATASET.SCAN_IMG_RESIZEX = 640 # target width of the resized ScanNet images
+_CN.DATASET.SCAN_IMG_RESIZEY = 480 # target height of the resized ScanNet images
+
+# MegaDepth options
+_CN.DATASET.MGDPT_IMG_RESIZE = 640 # resize the longer side, zero-pad bottom-right to square.
+_CN.DATASET.MGDPT_IMG_PAD = True # pad img to square with size = MGDPT_IMG_RESIZE
+_CN.DATASET.MGDPT_DEPTH_PAD = True # pad depthmap to square with size = 2000
+_CN.DATASET.MGDPT_DF = 8
+
+_CN.DATASET.NPE_NAME = None
+
+############## Trainer ##############
+_CN.TRAINER = CN()
+_CN.TRAINER.WORLD_SIZE = 1
+_CN.TRAINER.CANONICAL_BS = 64
+_CN.TRAINER.CANONICAL_LR = 6e-3
+_CN.TRAINER.SCALING = None # this will be calculated automatically
+_CN.TRAINER.FIND_LR = False # use learning rate finder from pytorch-lightning
+
+# optimizer
+_CN.TRAINER.OPTIMIZER = "adamw" # [adam, adamw]
+_CN.TRAINER.TRUE_LR = None # this will be calculated automatically at runtime
+_CN.TRAINER.ADAM_DECAY = 0. # ADAM: for adam
+_CN.TRAINER.ADAMW_DECAY = 0.1
+
+# step-based warm-up
+_CN.TRAINER.WARMUP_TYPE = 'linear' # [linear, constant]
+_CN.TRAINER.WARMUP_RATIO = 0.
+_CN.TRAINER.WARMUP_STEP = 4800
+
+# learning rate scheduler
+_CN.TRAINER.SCHEDULER = 'MultiStepLR' # [MultiStepLR, CosineAnnealing, ExponentialLR]
+_CN.TRAINER.SCHEDULER_INTERVAL = 'epoch' # [epoch, step]
+_CN.TRAINER.MSLR_MILESTONES = [3, 6, 9, 12] # MSLR: MultiStepLR
+_CN.TRAINER.MSLR_GAMMA = 0.5
+_CN.TRAINER.COSA_TMAX = 30 # COSA: CosineAnnealing
+_CN.TRAINER.ELR_GAMMA = 0.999992 # ELR: ExponentialLR, this value for 'step' interval
+
+# plotting related
+_CN.TRAINER.ENABLE_PLOTTING = True
+_CN.TRAINER.N_VAL_PAIRS_TO_PLOT = 32 # number of val/test pairs for plotting
+_CN.TRAINER.PLOT_MODE = 'evaluation' # ['evaluation', 'confidence']
+_CN.TRAINER.PLOT_MATCHES_ALPHA = 'dynamic'
+
+# geometric metrics and pose solver
+_CN.TRAINER.EPI_ERR_THR = 5e-4 # recommendation: 5e-4 for ScanNet, 1e-4 for MegaDepth (from SuperGlue)
+_CN.TRAINER.POSE_GEO_MODEL = 'E' # ['E', 'F', 'H']
+_CN.TRAINER.POSE_ESTIMATION_METHOD = 'RANSAC' # [RANSAC, LO-RANSAC]
+_CN.TRAINER.RANSAC_PIXEL_THR = 0.5
+_CN.TRAINER.RANSAC_CONF = 0.99999
+_CN.TRAINER.RANSAC_MAX_ITERS = 10000
+_CN.TRAINER.USE_MAGSACPP = False
+
+# data sampler for train_dataloader
+_CN.TRAINER.DATA_SAMPLER = 'scene_balance' # options: ['scene_balance', 'random', 'normal']
+# 'scene_balance' config
+_CN.TRAINER.N_SAMPLES_PER_SUBSET = 200
+_CN.TRAINER.SB_SUBSET_SAMPLE_REPLACEMENT = True # whether to sample each scene with replacement or not
+_CN.TRAINER.SB_SUBSET_SHUFFLE = True # whether to shuffle within the epoch after sampling from the scenes
+_CN.TRAINER.SB_REPEAT = 1 # repeat the sampled data N times for training
+# 'random' config
+_CN.TRAINER.RDM_REPLACEMENT = True
+_CN.TRAINER.RDM_NUM_SAMPLES = None
+
+# gradient clipping
+_CN.TRAINER.GRADIENT_CLIPPING = 0.5
+
+# reproducibility
+# This seed affects the data sampling. With the same seed, the data sampling is guaranteed
+# to be the same. When resuming training from a checkpoint, it's better to use a different
+# seed, otherwise the sampled data will be exactly the same as before resuming, which leads
+# to fewer unique data items being sampled during the entire training run.
+# Using different seed values might affect the final training result, since not all data items
+# are used during training on ScanNet (60M image pairs are sampled during training out of 230M pairs in total).
+_CN.TRAINER.SEED = 66
+
+
+def get_cfg_defaults():
+ """Get a yacs CfgNode object with default values for my_project."""
+ # Return a clone so that the defaults will not be altered
+ # This is for the "local variable" use pattern
+ return _CN.clone()
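+# Illustrative sketch (added, not part of the original file): the Python configs under
+# configs/loftr/ import this `_CN` as `cfg` and mutate it directly; programmatic
+# overrides can also use the standard yacs API, e.g.
+# config = get_cfg_defaults()
+# config.merge_from_list(["LOFTR.MATCH_COARSE.THR", 0.2]) # CLI-style key/value override
+# config.freeze()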
diff --git a/imcui/third_party/EfficientLoFTR/src/datasets/megadepth.py b/imcui/third_party/EfficientLoFTR/src/datasets/megadepth.py
new file mode 100644
index 0000000000000000000000000000000000000000..5f070b2b6da7ef779d41090773d4c45592e95514
--- /dev/null
+++ b/imcui/third_party/EfficientLoFTR/src/datasets/megadepth.py
@@ -0,0 +1,133 @@
+import os.path as osp
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torch.utils.data import Dataset
+from loguru import logger
+
+from src.utils.dataset import read_megadepth_gray, read_megadepth_depth
+
+
+class MegaDepthDataset(Dataset):
+ def __init__(self,
+ root_dir,
+ npz_path,
+ mode='train',
+ min_overlap_score=0.4,
+ img_resize=None,
+ df=None,
+ img_padding=False,
+ depth_padding=False,
+ augment_fn=None,
+ fp16=False,
+ **kwargs):
+ """
+ Manage one scene(npz_path) of MegaDepth dataset.
+
+ Args:
+ root_dir (str): megadepth root directory that has `phoenix`.
+ npz_path (str): {scene_id}.npz path. This contains image pair information of a scene.
+ mode (str): options are ['train', 'val', 'test']
+ min_overlap_score (float): how much a pair should have in common. In range of [0, 1]. Set to 0 when testing.
+ img_resize (int, optional): the longer edge of resized images. None for no resize. 640 is recommended.
+ This is useful during training with batches and testing with memory intensive algorithms.
+ df (int, optional): image size division factor. NOTE: this will change the final image size after img_resize.
+ img_padding (bool): If set to 'True', zero-pad the image to squared size. This is useful during training.
+ depth_padding (bool): If set to 'True', zero-pad depthmap to (2000, 2000). This is useful during training.
+ augment_fn (callable, optional): augments images with pre-defined visual effects.
+ """
+ super().__init__()
+ self.root_dir = root_dir
+ self.mode = mode
+ self.scene_id = npz_path.split('.')[0]
+
+ # prepare scene_info and pair_info
+ if mode == 'test' and min_overlap_score != 0:
+ logger.warning("You are using `min_overlap_score`!=0 in test mode. Set to 0.")
+ min_overlap_score = 0
+ self.scene_info = np.load(npz_path, allow_pickle=True)
+ self.pair_infos = self.scene_info['pair_infos'].copy()
+
+ del self.scene_info['pair_infos']
+ self.pair_infos = [pair_info for pair_info in self.pair_infos if pair_info[1] > min_overlap_score]
+
+ # parameters for image resizing, padding and depthmap padding
+ if mode == 'train':
+ assert img_resize is not None and img_padding and depth_padding
+ self.img_resize = img_resize
+ self.df = df
+ self.img_padding = img_padding
+ self.depth_max_size = 2000 if depth_padding else None # the upper bound of depthmap sizes in MegaDepth.
+
+ # for training LoFTR
+ self.augment_fn = augment_fn if mode == 'train' else None
+ self.coarse_scale = kwargs.get('coarse_scale', 0.125) # getattr on a dict would always fall back to the default
+
+ self.fp16 = fp16
+
+ def __len__(self):
+ return len(self.pair_infos)
+
+ def __getitem__(self, idx):
+ (idx0, idx1), overlap_score, central_matches = self.pair_infos[idx]
+
+ # read grayscale image and mask. (1, h, w) and (h, w)
+ img_name0 = osp.join(self.root_dir, self.scene_info['image_paths'][idx0])
+ img_name1 = osp.join(self.root_dir, self.scene_info['image_paths'][idx1])
+
+ # TODO: Support augmentation & handle seeds for each worker correctly.
+ image0, mask0, scale0 = read_megadepth_gray(
+ img_name0, self.img_resize, self.df, self.img_padding, None)
+ # np.random.choice([self.augment_fn, None], p=[0.5, 0.5]))
+ image1, mask1, scale1 = read_megadepth_gray(
+ img_name1, self.img_resize, self.df, self.img_padding, None)
+ # np.random.choice([self.augment_fn, None], p=[0.5, 0.5]))
+
+ # read depth. shape: (h, w)
+ if self.mode in ['train', 'val']:
+ depth0 = read_megadepth_depth(
+ osp.join(self.root_dir, self.scene_info['depth_paths'][idx0]), pad_to=self.depth_max_size)
+ depth1 = read_megadepth_depth(
+ osp.join(self.root_dir, self.scene_info['depth_paths'][idx1]), pad_to=self.depth_max_size)
+ else:
+ depth0 = depth1 = torch.tensor([])
+
+ # read intrinsics of original size
+ K_0 = torch.tensor(self.scene_info['intrinsics'][idx0].copy(), dtype=torch.float).reshape(3, 3)
+ K_1 = torch.tensor(self.scene_info['intrinsics'][idx1].copy(), dtype=torch.float).reshape(3, 3)
+
+ # read and compute relative poses
+ T0 = self.scene_info['poses'][idx0]
+ T1 = self.scene_info['poses'][idx1]
+ T_0to1 = torch.tensor(np.matmul(T1, np.linalg.inv(T0)), dtype=torch.float)[:4, :4] # (4, 4)
+ T_1to0 = T_0to1.inverse()
+
+ if self.fp16:
+ image0, image1, depth0, depth1, scale0, scale1 = map(lambda x: x.half(),
+ [image0, image1, depth0, depth1, scale0, scale1])
+ data = {
+ 'image0': image0, # (1, h, w)
+ 'depth0': depth0, # (h, w)
+ 'image1': image1,
+ 'depth1': depth1,
+ 'T_0to1': T_0to1, # (4, 4)
+ 'T_1to0': T_1to0,
+ 'K0': K_0, # (3, 3)
+ 'K1': K_1,
+ 'scale0': scale0, # [scale_w, scale_h]
+ 'scale1': scale1,
+ 'dataset_name': 'MegaDepth',
+ 'scene_id': self.scene_id,
+ 'pair_id': idx,
+ 'pair_names': (self.scene_info['image_paths'][idx0], self.scene_info['image_paths'][idx1]),
+ }
+ # for LoFTR training
+ if mask0 is not None: # img_padding is True
+ if self.coarse_scale:
+ [ts_mask_0, ts_mask_1] = F.interpolate(torch.stack([mask0, mask1], dim=0)[None].float(),
+ scale_factor=self.coarse_scale,
+ mode='nearest',
+ recompute_scale_factor=False)[0].bool()
+ data.update({'mask0': ts_mask_0, 'mask1': ts_mask_1})
+
+ return data
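+# Illustrative sketch (added, not part of the original file): standalone use of this
+# dataset; the scene npz path below is a placeholder.
+# dataset = MegaDepthDataset("data/megadepth/train", "<scene_id>.npz", mode="train",
+# img_resize=832, df=8, img_padding=True, depth_padding=True)
+# sample = dataset[0] # dict with images, depths, intrinsics and relative poses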
diff --git a/imcui/third_party/EfficientLoFTR/src/datasets/sampler.py b/imcui/third_party/EfficientLoFTR/src/datasets/sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..81b6f435645632a013476f9a665a0861ab7fcb61
--- /dev/null
+++ b/imcui/third_party/EfficientLoFTR/src/datasets/sampler.py
@@ -0,0 +1,77 @@
+import torch
+from torch.utils.data import Sampler, ConcatDataset
+
+
+class RandomConcatSampler(Sampler):
+ """ Random sampler for ConcatDataset. At each epoch, `n_samples_per_subset` samples will be draw from each subset
+ in the ConcatDataset. If `subset_replacement` is ``True``, sampling within each subset will be done with replacement.
+ However, it is impossible to sample data without replacement between epochs, unless bulding a stateful sampler lived along the entire training phase.
+
+ For current implementation, the randomness of sampling is ensured no matter the sampler is recreated across epochs or not and call `torch.manual_seed()` or not.
+ Args:
+ shuffle (bool): shuffle the random sampled indices across all sub-datsets.
+ repeat (int): repeatedly use the sampled indices multiple times for training.
+ [arXiv:1902.05509, arXiv:1901.09335]
+ NOTE: Don't re-initialize the sampler between epochs (will lead to repeated samples)
+ NOTE: This sampler behaves differently with DistributedSampler.
+ It assume the dataset is splitted across ranks instead of replicated.
+ TODO: Add a `set_epoch()` method to fullfill sampling without replacement across epochs.
+ ref: https://github.com/PyTorchLightning/pytorch-lightning/blob/e9846dd758cfb1500eb9dba2d86f6912eb487587/pytorch_lightning/trainer/training_loop.py#L373
+ """
+ def __init__(self,
+ data_source: ConcatDataset,
+ n_samples_per_subset: int,
+ subset_replacement: bool=True,
+ shuffle: bool=True,
+ repeat: int=1,
+ seed: int=None):
+ if not isinstance(data_source, ConcatDataset):
+ raise TypeError("data_source should be torch.utils.data.ConcatDataset")
+
+ self.data_source = data_source
+ self.n_subset = len(self.data_source.datasets)
+ self.n_samples_per_subset = n_samples_per_subset
+ self.n_samples = self.n_subset * self.n_samples_per_subset * repeat
+ self.subset_replacement = subset_replacement
+ self.repeat = repeat
+ self.shuffle = shuffle
+ self.generator = torch.manual_seed(seed)
+ assert self.repeat >= 1
+
+ def __len__(self):
+ return self.n_samples
+
+ def __iter__(self):
+ indices = []
+ # sample from each sub-dataset
+ for d_idx in range(self.n_subset):
+ low = 0 if d_idx==0 else self.data_source.cumulative_sizes[d_idx-1]
+ high = self.data_source.cumulative_sizes[d_idx]
+ if self.subset_replacement:
+ rand_tensor = torch.randint(low, high, (self.n_samples_per_subset, ),
+ generator=self.generator, dtype=torch.int64)
+ else: # sample without replacement
+ len_subset = len(self.data_source.datasets[d_idx])
+ rand_tensor = torch.randperm(len_subset, generator=self.generator) + low
+ if len_subset >= self.n_samples_per_subset:
+ rand_tensor = rand_tensor[:self.n_samples_per_subset]
+ else: # padding with replacement
+ rand_tensor_replacement = torch.randint(low, high, (self.n_samples_per_subset - len_subset, ),
+ generator=self.generator, dtype=torch.int64)
+ rand_tensor = torch.cat([rand_tensor, rand_tensor_replacement])
+ indices.append(rand_tensor)
+ indices = torch.cat(indices)
+ if self.shuffle: # shuffle the sampled dataset (from multiple subsets)
+ rand_tensor = torch.randperm(len(indices), generator=self.generator)
+ indices = indices[rand_tensor]
+
+ # repeat the sampled indices (can be used for RepeatAugmentation or pure RepeatSampling)
+ if self.repeat > 1:
+ repeat_indices = [indices.clone() for _ in range(self.repeat - 1)]
+ if self.shuffle:
+ _choice = lambda x: x[torch.randperm(len(x), generator=self.generator)]
+ repeat_indices = map(_choice, repeat_indices)
+ indices = torch.cat([indices, *repeat_indices], 0)
+
+ assert indices.shape[0] == self.n_samples
+ return iter(indices.tolist())
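+# Illustrative sketch (added, not part of the original file): pairing the sampler with a
+# ConcatDataset and a DataLoader (the per-scene datasets are placeholders).
+# concat = ConcatDataset([scene_dataset_0, scene_dataset_1])
+# sampler = RandomConcatSampler(concat, n_samples_per_subset=200, seed=66)
+# loader = torch.utils.data.DataLoader(concat, sampler=sampler, batch_size=1)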
diff --git a/imcui/third_party/EfficientLoFTR/src/datasets/scannet.py b/imcui/third_party/EfficientLoFTR/src/datasets/scannet.py
new file mode 100644
index 0000000000000000000000000000000000000000..41743aaa0b0f6827c116ab6166ae71964515d196
--- /dev/null
+++ b/imcui/third_party/EfficientLoFTR/src/datasets/scannet.py
@@ -0,0 +1,129 @@
+from os import path as osp
+from typing import Dict
+
+import numpy as np
+import torch
+import torch.utils as utils
+from numpy.linalg import inv
+from src.utils.dataset import (
+ read_scannet_gray,
+ read_scannet_depth,
+ read_scannet_pose,
+ read_scannet_intrinsic
+)
+
+
+class ScanNetDataset(utils.data.Dataset):
+ def __init__(self,
+ root_dir,
+ npz_path,
+ intrinsic_path,
+ mode='train',
+ min_overlap_score=0.4,
+ augment_fn=None,
+ pose_dir=None,
+ img_resize=None,
+ fp16=False,
+ **kwargs):
+ """Manage one scene of ScanNet Dataset.
+ Args:
+ root_dir (str): ScanNet root directory that contains scene folders.
+ npz_path (str): {scene_id}.npz path. This contains image pair information of a scene.
+ intrinsic_path (str): path to depth-camera intrinsic file.
+ mode (str): options are ['train', 'val', 'test'].
+ augment_fn (callable, optional): augments images with pre-defined visual effects.
+ pose_dir (str): ScanNet root directory that contains all poses.
+ (we use a separate (optional) pose_dir since we store images and poses separately.)
+ """
+ super().__init__()
+ self.root_dir = root_dir
+ self.pose_dir = pose_dir if pose_dir is not None else root_dir
+ self.mode = mode
+
+ # prepare data_names, intrinsics and extrinsics(T)
+ with np.load(npz_path) as data:
+ self.data_names = data['name']
+ if 'score' in data.keys() and mode not in ['val', 'test']:
+ kept_mask = data['score'] > min_overlap_score
+ self.data_names = self.data_names[kept_mask]
+ self.intrinsics = dict(np.load(intrinsic_path))
+
+ # for training LoFTR
+ self.augment_fn = augment_fn if mode == 'train' else None
+
+ self.fp16 = fp16
+ self.img_resize = img_resize
+
+ def __len__(self):
+ return len(self.data_names)
+
+ def _read_abs_pose(self, scene_name, name):
+ pth = osp.join(self.pose_dir,
+ scene_name,
+ 'pose', f'{name}.txt')
+ return read_scannet_pose(pth)
+
+ def _compute_rel_pose(self, scene_name, name0, name1):
+ pose0 = self._read_abs_pose(scene_name, name0)
+ pose1 = self._read_abs_pose(scene_name, name1)
+
+ return np.matmul(pose1, inv(pose0)) # (4, 4)
+
+ def __getitem__(self, idx):
+ data_name = self.data_names[idx]
+ scene_name, scene_sub_name, stem_name_0, stem_name_1 = data_name
+ scene_name = f'scene{scene_name:04d}_{scene_sub_name:02d}'
+
+ # read the grayscale image which will be resized to (1, 480, 640)
+ img_name0 = osp.join(self.root_dir, scene_name, 'color', f'{stem_name_0}.jpg')
+ img_name1 = osp.join(self.root_dir, scene_name, 'color', f'{stem_name_1}.jpg')
+
+ # TODO: Support augmentation & handle seeds for each worker correctly.
+ image0 = read_scannet_gray(img_name0, resize=self.img_resize, augment_fn=None)
+ # augment_fn=np.random.choice([self.augment_fn, None], p=[0.5, 0.5]))
+ image1 = read_scannet_gray(img_name1, resize=self.img_resize, augment_fn=None)
+ # augment_fn=np.random.choice([self.augment_fn, None], p=[0.5, 0.5]))
+
+ # read the depthmap which is stored as (480, 640)
+ if self.mode in ['train', 'val']:
+ depth0 = read_scannet_depth(osp.join(self.root_dir, scene_name, 'depth', f'{stem_name_0}.png'))
+ depth1 = read_scannet_depth(osp.join(self.root_dir, scene_name, 'depth', f'{stem_name_1}.png'))
+ else:
+ depth0 = depth1 = torch.tensor([])
+
+ # read the intrinsic of depthmap
+ K_0 = K_1 = torch.tensor(self.intrinsics[scene_name].copy(), dtype=torch.float).reshape(3, 3)
+
+ # read and compute relative poses
+ T_0to1 = torch.tensor(self._compute_rel_pose(scene_name, stem_name_0, stem_name_1),
+ dtype=torch.float32)
+ T_1to0 = T_0to1.inverse()
+
+ h_new, w_new = self.img_resize[1], self.img_resize[0]
+ scale0 = torch.tensor([640/w_new, 480/h_new], dtype=torch.float)
+ scale1 = torch.tensor([640/w_new, 480/h_new], dtype=torch.float)
+
+ if self.fp16:
+ image0, image1, depth0, depth1, scale0, scale1 = map(lambda x: x.half(),
+ [image0, image1, depth0, depth1, scale0, scale1])
+
+ data = {
+ 'image0': image0, # (1, h, w)
+ 'depth0': depth0, # (h, w)
+ 'image1': image1,
+ 'depth1': depth1,
+ 'T_0to1': T_0to1, # (4, 4)
+ 'T_1to0': T_1to0,
+ 'K0': K_0, # (3, 3)
+ 'K1': K_1,
+ 'scale0': scale0, # [scale_w, scale_h]
+ 'scale1': scale1,
+ 'dataset_name': 'ScanNet',
+ 'scene_id': scene_name,
+ 'pair_id': idx,
+ 'pair_names': (osp.join(scene_name, 'color', f'{stem_name_0}.jpg'),
+ osp.join(scene_name, 'color', f'{stem_name_1}.jpg'))
+ }
+
+ return data
\ No newline at end of file
diff --git a/imcui/third_party/EfficientLoFTR/src/lightning/data.py b/imcui/third_party/EfficientLoFTR/src/lightning/data.py
new file mode 100644
index 0000000000000000000000000000000000000000..28730fc6bf30fcd99a35b6708e26b7a9a1eca9df
--- /dev/null
+++ b/imcui/third_party/EfficientLoFTR/src/lightning/data.py
@@ -0,0 +1,357 @@
+import os
+import math
+from collections import abc
+from loguru import logger
+from torch.utils.data.dataset import Dataset
+from tqdm import tqdm
+from os import path as osp
+from pathlib import Path
+from joblib import Parallel, delayed
+
+import pytorch_lightning as pl
+from torch import distributed as dist
+from torch.utils.data import (
+ Dataset,
+ DataLoader,
+ ConcatDataset,
+ DistributedSampler,
+ RandomSampler,
+ dataloader
+)
+
+from src.utils.augment import build_augmentor
+from src.utils.dataloader import get_local_split
+from src.utils.misc import tqdm_joblib
+from src.utils import comm
+from src.datasets.megadepth import MegaDepthDataset
+from src.datasets.scannet import ScanNetDataset
+from src.datasets.sampler import RandomConcatSampler
+
+
+class MultiSceneDataModule(pl.LightningDataModule):
+ """
+ For distributed training, each training process is assigned
+ only a part of the training scenes to reduce memory overhead.
+ """
+ def __init__(self, args, config):
+ super().__init__()
+
+ # 1. data config
+ # Train and Val should come from the same data source
+ self.trainval_data_source = config.DATASET.TRAINVAL_DATA_SOURCE
+ self.test_data_source = config.DATASET.TEST_DATA_SOURCE
+ # training and validating
+ self.train_data_root = config.DATASET.TRAIN_DATA_ROOT
+ self.train_pose_root = config.DATASET.TRAIN_POSE_ROOT # (optional)
+ self.train_npz_root = config.DATASET.TRAIN_NPZ_ROOT
+ self.train_list_path = config.DATASET.TRAIN_LIST_PATH
+ self.train_intrinsic_path = config.DATASET.TRAIN_INTRINSIC_PATH
+ self.val_data_root = config.DATASET.VAL_DATA_ROOT
+ self.val_pose_root = config.DATASET.VAL_POSE_ROOT # (optional)
+ self.val_npz_root = config.DATASET.VAL_NPZ_ROOT
+ self.val_list_path = config.DATASET.VAL_LIST_PATH
+ self.val_intrinsic_path = config.DATASET.VAL_INTRINSIC_PATH
+ # testing
+ self.test_data_root = config.DATASET.TEST_DATA_ROOT
+ self.test_pose_root = config.DATASET.TEST_POSE_ROOT # (optional)
+ self.test_npz_root = config.DATASET.TEST_NPZ_ROOT
+ self.test_list_path = config.DATASET.TEST_LIST_PATH
+ self.test_intrinsic_path = config.DATASET.TEST_INTRINSIC_PATH
+
+ # 2. dataset config
+ # general options
+ self.min_overlap_score_test = config.DATASET.MIN_OVERLAP_SCORE_TEST # 0.4, omit data with overlap_score < min_overlap_score
+ self.min_overlap_score_train = config.DATASET.MIN_OVERLAP_SCORE_TRAIN
+ self.augment_fn = build_augmentor(config.DATASET.AUGMENTATION_TYPE) # None, options: [None, 'dark', 'mobile']
+
+ # ScanNet options
+ self.scan_img_resizeX = config.DATASET.SCAN_IMG_RESIZEX # 640
+ self.scan_img_resizeY = config.DATASET.SCAN_IMG_RESIZEY # 480
+
+
+ # MegaDepth options
+ self.mgdpt_img_resize = config.DATASET.MGDPT_IMG_RESIZE # 832
+ self.mgdpt_img_pad = config.DATASET.MGDPT_IMG_PAD # True
+ self.mgdpt_depth_pad = config.DATASET.MGDPT_DEPTH_PAD # True
+ self.mgdpt_df = config.DATASET.MGDPT_DF # 8
+ self.coarse_scale = 1 / config.LOFTR.RESOLUTION[0] # 0.125. for training loftr.
+
+ self.fp16 = config.DATASET.FP16
+
+ # 3.loader parameters
+ self.train_loader_params = {
+ 'batch_size': args.batch_size,
+ 'num_workers': args.num_workers,
+ 'pin_memory': getattr(args, 'pin_memory', True)
+ }
+ self.val_loader_params = {
+ 'batch_size': 1,
+ 'shuffle': False,
+ 'num_workers': args.num_workers,
+ 'pin_memory': getattr(args, 'pin_memory', True)
+ }
+ self.test_loader_params = {
+ 'batch_size': 1,
+ 'shuffle': False,
+ 'num_workers': args.num_workers,
+ 'pin_memory': True
+ }
+
+ # 4. sampler
+ self.data_sampler = config.TRAINER.DATA_SAMPLER
+ self.n_samples_per_subset = config.TRAINER.N_SAMPLES_PER_SUBSET
+ self.subset_replacement = config.TRAINER.SB_SUBSET_SAMPLE_REPLACEMENT
+ self.shuffle = config.TRAINER.SB_SUBSET_SHUFFLE
+ self.repeat = config.TRAINER.SB_REPEAT
+
+ # (optional) RandomSampler for debugging
+
+ # misc configurations
+ self.parallel_load_data = getattr(args, 'parallel_load_data', False)
+ self.seed = config.TRAINER.SEED # 66
+
+ def setup(self, stage=None):
+ """
+ Setup train / val / test dataset. This method will be called by PL automatically.
+ Args:
+ stage (str): 'fit' in training phase, and 'test' in testing phase.
+ """
+
+ assert stage in ['fit', 'validate', 'test'], "stage must be one of ['fit', 'validate', 'test']"
+
+ try:
+ self.world_size = dist.get_world_size()
+ self.rank = dist.get_rank()
+ logger.info(f"[rank:{self.rank}] world_size: {self.world_size}")
+ except AssertionError as ae:
+ self.world_size = 1
+ self.rank = 0
+ logger.warning(str(ae) + " (set world_size=1 and rank=0)")
+
+ if stage == 'fit':
+ self.train_dataset = self._setup_dataset(
+ self.train_data_root,
+ self.train_npz_root,
+ self.train_list_path,
+ self.train_intrinsic_path,
+ mode='train',
+ min_overlap_score=self.min_overlap_score_train,
+ pose_dir=self.train_pose_root)
+ # setup multiple (optional) validation subsets
+ if isinstance(self.val_list_path, (list, tuple)):
+ self.val_dataset = []
+ if not isinstance(self.val_npz_root, (list, tuple)):
+ self.val_npz_root = [self.val_npz_root for _ in range(len(self.val_list_path))]
+ for npz_list, npz_root in zip(self.val_list_path, self.val_npz_root):
+ self.val_dataset.append(self._setup_dataset(
+ self.val_data_root,
+ npz_root,
+ npz_list,
+ self.val_intrinsic_path,
+ mode='val',
+ min_overlap_score=self.min_overlap_score_test,
+ pose_dir=self.val_pose_root))
+ else:
+ self.val_dataset = self._setup_dataset(
+ self.val_data_root,
+ self.val_npz_root,
+ self.val_list_path,
+ self.val_intrinsic_path,
+ mode='val',
+ min_overlap_score=self.min_overlap_score_test,
+ pose_dir=self.val_pose_root)
+ logger.info(f'[rank:{self.rank}] Train & Val Dataset loaded!')
+ elif stage == 'validate':
+ if isinstance(self.val_list_path, (list, tuple)):
+ self.val_dataset = []
+ if not isinstance(self.val_npz_root, (list, tuple)):
+ self.val_npz_root = [self.val_npz_root for _ in range(len(self.val_list_path))]
+ for npz_list, npz_root in zip(self.val_list_path, self.val_npz_root):
+ self.val_dataset.append(self._setup_dataset(
+ self.val_data_root,
+ npz_root,
+ npz_list,
+ self.val_intrinsic_path,
+ mode='val',
+ min_overlap_score=self.min_overlap_score_test,
+ pose_dir=self.val_pose_root))
+ else:
+ self.val_dataset = self._setup_dataset(
+ self.val_data_root,
+ self.val_npz_root,
+ self.val_list_path,
+ self.val_intrinsic_path,
+ mode='val',
+ min_overlap_score=self.min_overlap_score_test,
+ pose_dir=self.val_pose_root)
+ logger.info(f'[rank:{self.rank}] Val Dataset loaded!')
+ else: # stage == 'test'
+ self.test_dataset = self._setup_dataset(
+ self.test_data_root,
+ self.test_npz_root,
+ self.test_list_path,
+ self.test_intrinsic_path,
+ mode='test',
+ min_overlap_score=self.min_overlap_score_test,
+ pose_dir=self.test_pose_root)
+ logger.info(f'[rank:{self.rank}]: Test Dataset loaded!')
+
+ def _setup_dataset(self,
+ data_root,
+ split_npz_root,
+ scene_list_path,
+ intri_path,
+ mode='train',
+ min_overlap_score=0.,
+ pose_dir=None):
+ """ Setup train / val / test set"""
+ with open(scene_list_path, 'r') as f:
+ npz_names = [name.split()[0] for name in f.readlines()]
+
+ if mode == 'train':
+ local_npz_names = get_local_split(npz_names, self.world_size, self.rank, self.seed)
+ else:
+ local_npz_names = npz_names
+ logger.info(f'[rank {self.rank}]: {len(local_npz_names)} scene(s) assigned.')
+
+ dataset_builder = self._build_concat_dataset_parallel \
+ if self.parallel_load_data \
+ else self._build_concat_dataset
+ return dataset_builder(data_root, local_npz_names, split_npz_root, intri_path,
+ mode=mode, min_overlap_score=min_overlap_score, pose_dir=pose_dir)
+
+ def _build_concat_dataset(
+ self,
+ data_root,
+ npz_names,
+ npz_dir,
+ intrinsic_path,
+ mode,
+ min_overlap_score=0.,
+ pose_dir=None
+ ):
+ datasets = []
+ augment_fn = self.augment_fn if mode == 'train' else None
+ data_source = self.trainval_data_source if mode in ['train', 'val'] else self.test_data_source
+ if str(data_source).lower() == 'megadepth':
+ npz_names = [f'{n}.npz' for n in npz_names]
+ for npz_name in tqdm(npz_names,
+ desc=f'[rank:{self.rank}] loading {mode} datasets',
+ disable=int(self.rank) != 0):
+ # `ScanNetDataset`/`MegaDepthDataset` load all data from npz_path when initialized, which might take time.
+ npz_path = osp.join(npz_dir, npz_name)
+ if data_source == 'ScanNet':
+ datasets.append(
+ ScanNetDataset(data_root,
+ npz_path,
+ intrinsic_path,
+ mode=mode,
+ min_overlap_score=min_overlap_score,
+ augment_fn=augment_fn,
+ pose_dir=pose_dir,
+ img_resize=(self.scan_img_resizeX, self.scan_img_resizeY),
+ fp16 = self.fp16,
+ ))
+ elif data_source == 'MegaDepth':
+ datasets.append(
+ MegaDepthDataset(data_root,
+ npz_path,
+ mode=mode,
+ min_overlap_score=min_overlap_score,
+ img_resize=self.mgdpt_img_resize,
+ df=self.mgdpt_df,
+ img_padding=self.mgdpt_img_pad,
+ depth_padding=self.mgdpt_depth_pad,
+ augment_fn=augment_fn,
+ coarse_scale=self.coarse_scale,
+ fp16 = self.fp16,
+ ))
+ else:
+ raise NotImplementedError()
+ return ConcatDataset(datasets)
+
+ def _build_concat_dataset_parallel(
+ self,
+ data_root,
+ npz_names,
+ npz_dir,
+ intrinsic_path,
+ mode,
+ min_overlap_score=0.,
+ pose_dir=None,
+ ):
+ augment_fn = self.augment_fn if mode == 'train' else None
+ data_source = self.trainval_data_source if mode in ['train', 'val'] else self.test_data_source
+ if str(data_source).lower() == 'megadepth':
+ npz_names = [f'{n}.npz' for n in npz_names]
+ with tqdm_joblib(tqdm(desc=f'[rank:{self.rank}] loading {mode} datasets',
+ total=len(npz_names), disable=int(self.rank) != 0)):
+ if data_source == 'ScanNet':
+ datasets = Parallel(n_jobs=math.floor(len(os.sched_getaffinity(0)) * 0.9 / comm.get_local_size()))(
+ delayed(lambda x: _build_dataset(
+ ScanNetDataset,
+ data_root,
+ osp.join(npz_dir, x),
+ intrinsic_path,
+ mode=mode,
+ min_overlap_score=min_overlap_score,
+ augment_fn=augment_fn,
+ pose_dir=pose_dir))(name)
+ for name in npz_names)
+ elif data_source == 'MegaDepth':
+ # TODO: _pickle.PicklingError: Could not pickle the task to send it to the workers.
+ raise NotImplementedError()
+ datasets = Parallel(n_jobs=math.floor(len(os.sched_getaffinity(0)) * 0.9 / comm.get_local_size()))(
+ delayed(lambda x: _build_dataset(
+ MegaDepthDataset,
+ data_root,
+ osp.join(npz_dir, x),
+ mode=mode,
+ min_overlap_score=min_overlap_score,
+ img_resize=self.mgdpt_img_resize,
+ df=self.mgdpt_df,
+ img_padding=self.mgdpt_img_pad,
+ depth_padding=self.mgdpt_depth_pad,
+ augment_fn=augment_fn,
+ coarse_scale=self.coarse_scale))(name)
+ for name in npz_names)
+ else:
+ raise ValueError(f'Unknown dataset: {data_source}')
+ return ConcatDataset(datasets)
+
+ def train_dataloader(self):
+ """ Build training dataloader for ScanNet / MegaDepth. """
+ assert self.data_sampler in ['scene_balance']
+ logger.info(f'[rank:{self.rank}/{self.world_size}]: Train Sampler and DataLoader re-init (should not re-init between epochs!).')
+ if self.data_sampler == 'scene_balance':
+ sampler = RandomConcatSampler(self.train_dataset,
+ self.n_samples_per_subset,
+ self.subset_replacement,
+ self.shuffle, self.repeat, self.seed)
+ else:
+ sampler = None
+ dataloader = DataLoader(self.train_dataset, sampler=sampler, **self.train_loader_params)
+ return dataloader
+
+ def val_dataloader(self):
+ """ Build validation dataloader for ScanNet / MegaDepth. """
+ logger.info(f'[rank:{self.rank}/{self.world_size}]: Val Sampler and DataLoader re-init.')
+ if not isinstance(self.val_dataset, abc.Sequence):
+ sampler = DistributedSampler(self.val_dataset, shuffle=False)
+ return DataLoader(self.val_dataset, sampler=sampler, **self.val_loader_params)
+ else:
+ dataloaders = []
+ for dataset in self.val_dataset:
+ sampler = DistributedSampler(dataset, shuffle=False)
+ dataloaders.append(DataLoader(dataset, sampler=sampler, **self.val_loader_params))
+ return dataloaders
+
+ def test_dataloader(self, *args, **kwargs):
+ logger.info(f'[rank:{self.rank}/{self.world_size}]: Test Sampler and DataLoader re-init.')
+ sampler = DistributedSampler(self.test_dataset, shuffle=False)
+ return DataLoader(self.test_dataset, sampler=sampler, **self.test_loader_params)
+
+
+def _build_dataset(dataset: Dataset, *args, **kwargs):
+ return dataset(*args, **kwargs)
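+# Illustrative sketch (added, not part of the original file): pytorch_lightning drives this
+# data module roughly as below; `args` is a hypothetical argparse namespace providing
+# batch_size / num_workers, and `config` a yacs config from src.config.default.
+# dm = MultiSceneDataModule(args, config)
+# dm.setup(stage='fit') # builds the train / val ConcatDatasets
+# train_loader = dm.train_dataloader() # scene-balanced RandomConcatSampler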
diff --git a/imcui/third_party/EfficientLoFTR/src/lightning/lightning_loftr.py b/imcui/third_party/EfficientLoFTR/src/lightning/lightning_loftr.py
new file mode 100644
index 0000000000000000000000000000000000000000..a1a8e3ef2725e63b633c6022ae5bfd1a138e438b
--- /dev/null
+++ b/imcui/third_party/EfficientLoFTR/src/lightning/lightning_loftr.py
@@ -0,0 +1,272 @@
+
+from collections import defaultdict
+import pprint
+from loguru import logger
+from pathlib import Path
+
+import torch
+import numpy as np
+import pytorch_lightning as pl
+from matplotlib import pyplot as plt
+
+from src.loftr import LoFTR
+from src.loftr.utils.supervision import compute_supervision_coarse, compute_supervision_fine
+from src.losses.loftr_loss import LoFTRLoss
+from src.optimizers import build_optimizer, build_scheduler
+from src.utils.metrics import (
+ compute_symmetrical_epipolar_errors,
+ compute_pose_errors,
+ aggregate_metrics
+)
+from src.utils.plotting import make_matching_figures
+from src.utils.comm import gather, all_gather
+from src.utils.misc import lower_config, flattenList
+from src.utils.profiler import PassThroughProfiler
+
+from torch.profiler import profile
+
+def reparameter(matcher):
+ module = matcher.backbone.layer0
+ if hasattr(module, 'switch_to_deploy'):
+ module.switch_to_deploy()
+ for modules in [matcher.backbone.layer1, matcher.backbone.layer2, matcher.backbone.layer3]:
+ for module in modules:
+ if hasattr(module, 'switch_to_deploy'):
+ module.switch_to_deploy()
+ for modules in [matcher.fine_preprocess.layer2_outconv2, matcher.fine_preprocess.layer1_outconv2]:
+ for module in modules:
+ if hasattr(module, 'switch_to_deploy'):
+ module.switch_to_deploy()
+ return matcher
+
+
+class PL_LoFTR(pl.LightningModule):
+ def __init__(self, config, pretrained_ckpt=None, profiler=None, dump_dir=None):
+ """
+ TODO:
+ - use the new version of PL logging API.
+ """
+ super().__init__()
+ # Misc
+ self.config = config # full config
+ _config = lower_config(self.config)
+ self.loftr_cfg = lower_config(_config['loftr'])
+ self.profiler = profiler or PassThroughProfiler()
+ self.n_vals_plot = max(config.TRAINER.N_VAL_PAIRS_TO_PLOT // config.TRAINER.WORLD_SIZE, 1)
+
+ # Matcher: LoFTR
+ self.matcher = LoFTR(config=_config['loftr'], profiler=self.profiler)
+ self.loss = LoFTRLoss(_config)
+
+ # Pretrained weights
+ if pretrained_ckpt:
+ state_dict = torch.load(pretrained_ckpt, map_location='cpu')['state_dict']
+ msg=self.matcher.load_state_dict(state_dict, strict=False)
+ logger.info(f"Load \'{pretrained_ckpt}\' as pretrained checkpoint")
+
+ # Testing
+ self.warmup = False
+ self.reparameter = False
+ self.start_event = torch.cuda.Event(enable_timing=True)
+ self.end_event = torch.cuda.Event(enable_timing=True)
+ self.total_ms = 0
+
+ def configure_optimizers(self):
+ # FIXME: The scheduler did not work properly when `--resume_from_checkpoint`
+ optimizer = build_optimizer(self, self.config)
+ scheduler = build_scheduler(self.config, optimizer)
+ return [optimizer], [scheduler]
+
+ def optimizer_step(
+ self, epoch, batch_idx, optimizer, optimizer_idx,
+ optimizer_closure, on_tpu, using_native_amp, using_lbfgs):
+ # learning rate warm up
+ warmup_step = self.config.TRAINER.WARMUP_STEP
+ if self.trainer.global_step < warmup_step:
+ if self.config.TRAINER.WARMUP_TYPE == 'linear':
+ base_lr = self.config.TRAINER.WARMUP_RATIO * self.config.TRAINER.TRUE_LR
+ lr = base_lr + \
+ (self.trainer.global_step / self.config.TRAINER.WARMUP_STEP) * \
+ abs(self.config.TRAINER.TRUE_LR - base_lr)
+ for pg in optimizer.param_groups:
+ pg['lr'] = lr
+ elif self.config.TRAINER.WARMUP_TYPE == 'constant':
+ pass
+ else:
+ raise ValueError(f'Unknown lr warm-up strategy: {self.config.TRAINER.WARMUP_TYPE}')
+
+ # update params
+ optimizer.step(closure=optimizer_closure)
+ optimizer.zero_grad()
+
+ def _trainval_inference(self, batch):
+ with self.profiler.profile("Compute coarse supervision"):
+ with torch.autocast(enabled=False, device_type='cuda'):
+ compute_supervision_coarse(batch, self.config)
+
+ with self.profiler.profile("LoFTR"):
+ with torch.autocast(enabled=self.config.LOFTR.MP, device_type='cuda'):
+ self.matcher(batch)
+
+ with self.profiler.profile("Compute fine supervision"):
+ with torch.autocast(enabled=False, device_type='cuda'):
+ compute_supervision_fine(batch, self.config, self.logger)
+
+ with self.profiler.profile("Compute losses"):
+ with torch.autocast(enabled=self.config.LOFTR.MP, device_type='cuda'):
+ self.loss(batch)
+
+ def _compute_metrics(self, batch):
+ compute_symmetrical_epipolar_errors(batch) # compute epi_errs for each match
+ compute_pose_errors(batch, self.config) # compute R_errs, t_errs, pose_errs for each pair
+
+ rel_pair_names = list(zip(*batch['pair_names']))
+ bs = batch['image0'].size(0)
+ metrics = {
+ # to filter duplicate pairs caused by DistributedSampler
+ 'identifiers': ['#'.join(rel_pair_names[b]) for b in range(bs)],
+ 'epi_errs': [(batch['epi_errs'].reshape(-1,1))[batch['m_bids'] == b].reshape(-1).cpu().numpy() for b in range(bs)],
+ 'R_errs': batch['R_errs'],
+ 't_errs': batch['t_errs'],
+ 'inliers': batch['inliers'],
+ 'num_matches': [batch['mconf'].shape[0]], # batch size = 1 only
+ }
+ ret_dict = {'metrics': metrics}
+ return ret_dict, rel_pair_names
+
+ def training_step(self, batch, batch_idx):
+ self._trainval_inference(batch)
+
+ # logging
+ if self.trainer.global_rank == 0 and self.global_step % self.trainer.log_every_n_steps == 0:
+ # scalars
+ for k, v in batch['loss_scalars'].items():
+ self.logger.experiment.add_scalar(f'train/{k}', v, self.global_step)
+
+ # figures
+ if self.config.TRAINER.ENABLE_PLOTTING:
+ compute_symmetrical_epipolar_errors(batch) # compute epi_errs for each match
+ figures = make_matching_figures(batch, self.config, self.config.TRAINER.PLOT_MODE)
+ for k, v in figures.items():
+ self.logger.experiment.add_figure(f'train_match/{k}', v, self.global_step)
+ return {'loss': batch['loss']}
+
+ def training_epoch_end(self, outputs):
+ avg_loss = torch.stack([x['loss'] for x in outputs]).mean()
+ if self.trainer.global_rank == 0:
+ self.logger.experiment.add_scalar(
+ 'train/avg_loss_on_epoch', avg_loss,
+ global_step=self.current_epoch)
+
+ def on_validation_epoch_start(self):
+ self.matcher.fine_matching.validate = True
+
+ def validation_step(self, batch, batch_idx):
+ self._trainval_inference(batch)
+
+ ret_dict, _ = self._compute_metrics(batch)
+
+ val_plot_interval = max(self.trainer.num_val_batches[0] // self.n_vals_plot, 1)
+ figures = {self.config.TRAINER.PLOT_MODE: []}
+ if batch_idx % val_plot_interval == 0:
+ figures = make_matching_figures(batch, self.config, mode=self.config.TRAINER.PLOT_MODE)
+
+ return {
+ **ret_dict,
+ 'loss_scalars': batch['loss_scalars'],
+ 'figures': figures,
+ }
+
+ def validation_epoch_end(self, outputs):
+ self.matcher.fine_matching.validate = False
+ # handle multiple validation sets
+ multi_outputs = [outputs] if not isinstance(outputs[0], (list, tuple)) else outputs
+ multi_val_metrics = defaultdict(list)
+
+ for valset_idx, outputs in enumerate(multi_outputs):
+ # since pl performs sanity_check at the very beginning of the training
+ cur_epoch = self.trainer.current_epoch
+ if not self.trainer.resume_from_checkpoint and self.trainer.running_sanity_check:
+ cur_epoch = -1
+
+ # 1. loss_scalars: dict of list, on cpu
+ _loss_scalars = [o['loss_scalars'] for o in outputs]
+ loss_scalars = {k: flattenList(all_gather([_ls[k] for _ls in _loss_scalars])) for k in _loss_scalars[0]}
+
+ # 2. val metrics: dict of list, numpy
+ _metrics = [o['metrics'] for o in outputs]
+ metrics = {k: flattenList(all_gather(flattenList([_me[k] for _me in _metrics]))) for k in _metrics[0]}
+ # NOTE: all ranks need to run `aggregate_metrics`, but only rank-0 logs
+ val_metrics_4tb = aggregate_metrics(metrics, self.config.TRAINER.EPI_ERR_THR, config=self.config)
+ for thr in [5, 10, 20]:
+ multi_val_metrics[f'auc@{thr}'].append(val_metrics_4tb[f'auc@{thr}'])
+
+ # 3. figures
+ _figures = [o['figures'] for o in outputs]
+ figures = {k: flattenList(gather(flattenList([_me[k] for _me in _figures]))) for k in _figures[0]}
+
+ # tensorboard records only on rank 0
+ if self.trainer.global_rank == 0:
+ for k, v in loss_scalars.items():
+ mean_v = torch.stack(v).mean()
+ self.logger.experiment.add_scalar(f'val_{valset_idx}/avg_{k}', mean_v, global_step=cur_epoch)
+
+ for k, v in val_metrics_4tb.items():
+ self.logger.experiment.add_scalar(f"metrics_{valset_idx}/{k}", v, global_step=cur_epoch)
+
+ for k, v in figures.items():
+ if self.trainer.global_rank == 0:
+ for plot_idx, fig in enumerate(v):
+ self.logger.experiment.add_figure(
+ f'val_match_{valset_idx}/{k}/pair-{plot_idx}', fig, cur_epoch, close=True)
+ plt.close('all')
+
+ for thr in [5, 10, 20]:
+ # log on all ranks for ModelCheckpoint callback to work properly
+ self.log(f'auc@{thr}', torch.tensor(np.mean(multi_val_metrics[f'auc@{thr}']))) # ckpt monitors on this
+
+ def test_step(self, batch, batch_idx):
+ if (self.config.LOFTR.BACKBONE_TYPE == 'RepVGG') and not self.reparameter:
+ self.matcher = reparameter(self.matcher)
+ if self.config.LOFTR.HALF:
+ self.matcher = self.matcher.eval().half()
+ self.reparameter = True
+
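+ # Warm-up (added comment): run the matcher 50 times before timing so that one-time CUDA
+ # initialization and kernel caching do not contaminate the measured latency.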
+ if not self.warmup:
+ if self.config.LOFTR.HALF:
+ for i in range(50):
+ self.matcher(batch)
+ else:
+ with torch.autocast(enabled=self.config.LOFTR.MP, device_type='cuda'):
+ for i in range(50):
+ self.matcher(batch)
+ self.warmup = True
+ torch.cuda.synchronize()
+
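+ # Time a single forward pass with CUDA events (GPU time) and accumulate it into total_ms.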
+ if self.config.LOFTR.HALF:
+ self.start_event.record()
+ self.matcher(batch)
+ self.end_event.record()
+ torch.cuda.synchronize()
+ self.total_ms += self.start_event.elapsed_time(self.end_event)
+ else:
+ with torch.autocast(enabled=self.config.LOFTR.MP, device_type='cuda'):
+ self.start_event.record()
+ self.matcher(batch)
+ self.end_event.record()
+ torch.cuda.synchronize()
+ self.total_ms += self.start_event.elapsed_time(self.end_event)
+
+ ret_dict, rel_pair_names = self._compute_metrics(batch)
+ return ret_dict
+
+ def test_epoch_end(self, outputs):
+ # metrics: dict of list, numpy
+ _metrics = [o['metrics'] for o in outputs]
+ metrics = {k: flattenList(gather(flattenList([_me[k] for _me in _metrics]))) for k in _metrics[0]}
+
+ # [{key: [{...}, *#bs]}, *#batch]
+ if self.trainer.global_rank == 0:
+ print('Averaged Matching time over 1500 pairs: {:.2f} ms'.format(self.total_ms / 1500))
+ val_metrics_4tb = aggregate_metrics(metrics, self.config.TRAINER.EPI_ERR_THR, config=self.config)
+ logger.info('\n' + pprint.pformat(val_metrics_4tb))
\ No newline at end of file
diff --git a/imcui/third_party/EfficientLoFTR/src/loftr/__init__.py b/imcui/third_party/EfficientLoFTR/src/loftr/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..362c0d04fa437c0073016a9ceac607ec5cffdfb7
--- /dev/null
+++ b/imcui/third_party/EfficientLoFTR/src/loftr/__init__.py
@@ -0,0 +1,4 @@
+from .loftr import LoFTR
+from .utils.full_config import full_default_cfg
+from .utils.opt_config import opt_default_cfg
+from .loftr import reparameter
\ No newline at end of file
diff --git a/imcui/third_party/EfficientLoFTR/src/loftr/backbone/__init__.py b/imcui/third_party/EfficientLoFTR/src/loftr/backbone/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..aa2c02f486958a542550c0943ce284cf97e44050
--- /dev/null
+++ b/imcui/third_party/EfficientLoFTR/src/loftr/backbone/__init__.py
@@ -0,0 +1,11 @@
+from .backbone import RepVGG_8_1_align
+
+def build_backbone(config):
+ if config['backbone_type'] == 'RepVGG':
+ if config['align_corner'] is False:
+ if config['resolution'] == (8, 1):
+ return RepVGG_8_1_align(config['backbone'])
+ else:
+ raise ValueError(f"LOFTR.ALIGN_CORNER {config['align_corner']} not supported.")
+ else:
+ raise ValueError(f"LOFTR.BACKBONE_TYPE {config['backbone_type']} not supported.")
diff --git a/imcui/third_party/EfficientLoFTR/src/loftr/backbone/backbone.py b/imcui/third_party/EfficientLoFTR/src/loftr/backbone/backbone.py
new file mode 100644
index 0000000000000000000000000000000000000000..f1d921971914d4d465e5e4c7fe87c90e773d4ef1
--- /dev/null
+++ b/imcui/third_party/EfficientLoFTR/src/loftr/backbone/backbone.py
@@ -0,0 +1,37 @@
+import torch.nn as nn
+import torch.nn.functional as F
+from .repvgg import create_RepVGG
+
+class RepVGG_8_1_align(nn.Module):
+ """
+ RepVGG backbone, output resolutions are 1/8 and 1.
+ Each block has 2 layers.
+ """
+
+ def __init__(self, config):
+ super().__init__()
+ backbone = create_RepVGG(False)
+
+ self.layer0, self.layer1, self.layer2, self.layer3 = backbone.stage0, backbone.stage1, backbone.stage2, backbone.stage3
+
+ for layer in [self.layer0, self.layer1, self.layer2, self.layer3]:
+ for m in layer.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+ elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
+ nn.init.constant_(m.weight, 1)
+ nn.init.constant_(m.bias, 0)
+
+ def forward(self, x):
+ out = self.layer0(x) # 1/2
+ for module in self.layer1:
+ out = module(out) # 1/2
+ x1 = out
+ for module in self.layer2:
+ out = module(out) # 1/4
+ x2 = out
+ for module in self.layer3:
+ out = module(out) # 1/8
+ x3 = out
+
+ return {'feats_c': x3, 'feats_f': None, 'feats_x2': x2, 'feats_x1': x1}
diff --git a/imcui/third_party/EfficientLoFTR/src/loftr/backbone/repvgg.py b/imcui/third_party/EfficientLoFTR/src/loftr/backbone/repvgg.py
new file mode 100644
index 0000000000000000000000000000000000000000..45b038c82511e902947b65d7c63ff461dc24c3e7
--- /dev/null
+++ b/imcui/third_party/EfficientLoFTR/src/loftr/backbone/repvgg.py
@@ -0,0 +1,224 @@
+# --------------------------------------------------------
+# RepVGG: Making VGG-style ConvNets Great Again (https://openaccess.thecvf.com/content/CVPR2021/papers/Ding_RepVGG_Making_VGG-Style_ConvNets_Great_Again_CVPR_2021_paper.pdf)
+# Github source: https://github.com/DingXiaoH/RepVGG
+# Licensed under The MIT License [see LICENSE for details]
+# Modified from: https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py
+# --------------------------------------------------------
+import torch.nn as nn
+import numpy as np
+import torch
+import copy
+# from se_block import SEBlock
+import torch.utils.checkpoint as checkpoint
+from loguru import logger
+
+def conv_bn(in_channels, out_channels, kernel_size, stride, padding, groups=1):
+ result = nn.Sequential()
+ result.add_module('conv', nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
+ kernel_size=kernel_size, stride=stride, padding=padding, groups=groups, bias=False))
+ result.add_module('bn', nn.BatchNorm2d(num_features=out_channels))
+ return result
+
+class RepVGGBlock(nn.Module):
+
+ def __init__(self, in_channels, out_channels, kernel_size,
+ stride=1, padding=0, dilation=1, groups=1, padding_mode='zeros', deploy=False, use_se=False):
+ super(RepVGGBlock, self).__init__()
+ self.deploy = deploy
+ self.groups = groups
+ self.in_channels = in_channels
+
+ assert kernel_size == 3
+ assert padding == 1
+
+ padding_11 = padding - kernel_size // 2
+
+ self.nonlinearity = nn.ReLU()
+
+ if use_se:
+ # Note that RepVGG-D2se uses SE before the nonlinearity, but RepVGGplus models use SE after it.
+ # self.se = SEBlock(out_channels, internal_neurons=out_channels // 16)
+ raise ValueError("SEBlock not supported")
+ else:
+ self.se = nn.Identity()
+
+ if deploy:
+ self.rbr_reparam = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride,
+ padding=padding, dilation=dilation, groups=groups, bias=True, padding_mode=padding_mode)
+ else:
+ self.rbr_identity = nn.BatchNorm2d(num_features=in_channels) if out_channels == in_channels and stride == 1 else None
+ self.rbr_dense = conv_bn(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding, groups=groups)
+ self.rbr_1x1 = conv_bn(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=stride, padding=padding_11, groups=groups)
+
+ def forward(self, inputs):
+ if hasattr(self, 'rbr_reparam'):
+ return self.nonlinearity(self.se(self.rbr_reparam(inputs)))
+
+ if self.rbr_identity is None:
+ id_out = 0
+ else:
+ id_out = self.rbr_identity(inputs)
+
+ return self.nonlinearity(self.se(self.rbr_dense(inputs) + self.rbr_1x1(inputs) + id_out))
+
+
+ # Optional. This may improve the accuracy and facilitate quantization in some cases.
+ # 1. Cancel the original weight decay on rbr_dense.conv.weight and rbr_1x1.conv.weight.
+ # 2. Use it like this:
+ # loss = criterion(....)
+ # for every RepVGGBlock blk:
+ # loss += weight_decay_coefficient * 0.5 * blk.get_custom_L2()
+ # optimizer.zero_grad()
+ # loss.backward()
+ def get_custom_L2(self):
+ K3 = self.rbr_dense.conv.weight
+ K1 = self.rbr_1x1.conv.weight
+ t3 = (self.rbr_dense.bn.weight / ((self.rbr_dense.bn.running_var + self.rbr_dense.bn.eps).sqrt())).reshape(-1, 1, 1, 1).detach()
+ t1 = (self.rbr_1x1.bn.weight / ((self.rbr_1x1.bn.running_var + self.rbr_1x1.bn.eps).sqrt())).reshape(-1, 1, 1, 1).detach()
+
+ l2_loss_circle = (K3 ** 2).sum() - (K3[:, :, 1:2, 1:2] ** 2).sum() # The L2 loss of the "circle" of weights in 3x3 kernel. Use regular L2 on them.
+ eq_kernel = K3[:, :, 1:2, 1:2] * t3 + K1 * t1 # The equivalent resultant central point of 3x3 kernel.
+ l2_loss_eq_kernel = (eq_kernel ** 2 / (t3 ** 2 + t1 ** 2)).sum() # Normalize for an L2 coefficient comparable to regular L2.
+ return l2_loss_eq_kernel + l2_loss_circle
+
+
+
+ # This func derives the equivalent kernel and bias in a DIFFERENTIABLE way.
+ # You can get the equivalent kernel and bias at any time and do whatever you want,
+ # for example, apply some penalties or constraints during training, just like you do to the other models.
+ # May be useful for quantization or pruning.
+ def get_equivalent_kernel_bias(self):
+ kernel3x3, bias3x3 = self._fuse_bn_tensor(self.rbr_dense)
+ kernel1x1, bias1x1 = self._fuse_bn_tensor(self.rbr_1x1)
+ kernelid, biasid = self._fuse_bn_tensor(self.rbr_identity)
+ return kernel3x3 + self._pad_1x1_to_3x3_tensor(kernel1x1) + kernelid, bias3x3 + bias1x1 + biasid
+
+ def _pad_1x1_to_3x3_tensor(self, kernel1x1):
+ if kernel1x1 is None:
+ return 0
+ else:
+ return torch.nn.functional.pad(kernel1x1, [1,1,1,1])
+
+ def _fuse_bn_tensor(self, branch):
+ if branch is None:
+ return 0, 0
+ if isinstance(branch, nn.Sequential):
+ kernel = branch.conv.weight
+ running_mean = branch.bn.running_mean
+ running_var = branch.bn.running_var
+ gamma = branch.bn.weight
+ beta = branch.bn.bias
+ eps = branch.bn.eps
+ else:
+ assert isinstance(branch, nn.BatchNorm2d)
+ if not hasattr(self, 'id_tensor'):
+ input_dim = self.in_channels // self.groups
+ kernel_value = np.zeros((self.in_channels, input_dim, 3, 3), dtype=np.float32)
+ for i in range(self.in_channels):
+ kernel_value[i, i % input_dim, 1, 1] = 1
+ self.id_tensor = torch.from_numpy(kernel_value).to(branch.weight.device)
+ kernel = self.id_tensor
+ running_mean = branch.running_mean
+ running_var = branch.running_var
+ gamma = branch.weight
+ beta = branch.bias
+ eps = branch.eps
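+ # Fold BN into the conv: w' = w * gamma / sqrt(var + eps), b' = beta - mean * gamma / sqrt(var + eps)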
+ std = (running_var + eps).sqrt()
+ t = (gamma / std).reshape(-1, 1, 1, 1)
+ return kernel * t, beta - running_mean * gamma / std
+
+ def switch_to_deploy(self):
+ if hasattr(self, 'rbr_reparam'):
+ return
+ kernel, bias = self.get_equivalent_kernel_bias()
+ self.rbr_reparam = nn.Conv2d(in_channels=self.rbr_dense.conv.in_channels, out_channels=self.rbr_dense.conv.out_channels,
+ kernel_size=self.rbr_dense.conv.kernel_size, stride=self.rbr_dense.conv.stride,
+ padding=self.rbr_dense.conv.padding, dilation=self.rbr_dense.conv.dilation, groups=self.rbr_dense.conv.groups, bias=True)
+ self.rbr_reparam.weight.data = kernel
+ self.rbr_reparam.bias.data = bias
+ self.__delattr__('rbr_dense')
+ self.__delattr__('rbr_1x1')
+ if hasattr(self, 'rbr_identity'):
+ self.__delattr__('rbr_identity')
+ if hasattr(self, 'id_tensor'):
+ self.__delattr__('id_tensor')
+ self.deploy = True
+
+
+
+class RepVGG(nn.Module):
+
+ def __init__(self, num_blocks, num_classes=1000, width_multiplier=None, override_groups_map=None, deploy=False, use_se=False, use_checkpoint=False):
+ super(RepVGG, self).__init__()
+ assert len(width_multiplier) == 4
+ self.deploy = deploy
+ self.override_groups_map = override_groups_map or dict()
+ assert 0 not in self.override_groups_map
+ self.use_se = use_se
+ self.use_checkpoint = use_checkpoint
+
+ self.in_planes = min(64, int(64 * width_multiplier[0]))
+ self.stage0 = RepVGGBlock(in_channels=1, out_channels=self.in_planes, kernel_size=3, stride=2, padding=1, deploy=self.deploy, use_se=self.use_se)
+ self.cur_layer_idx = 1
+ self.stage1 = self._make_stage(int(64 * width_multiplier[0]), num_blocks[0], stride=1)
+ self.stage2 = self._make_stage(int(128 * width_multiplier[1]), num_blocks[1], stride=2)
+ self.stage3 = self._make_stage(int(256 * width_multiplier[2]), num_blocks[2], stride=2)
+
+ def _make_stage(self, planes, num_blocks, stride):
+ strides = [stride] + [1]*(num_blocks-1)
+ blocks = []
+ for stride in strides:
+ cur_groups = self.override_groups_map.get(self.cur_layer_idx, 1)
+ blocks.append(RepVGGBlock(in_channels=self.in_planes, out_channels=planes, kernel_size=3,
+ stride=stride, padding=1, groups=cur_groups, deploy=self.deploy, use_se=self.use_se))
+ self.in_planes = planes
+ self.cur_layer_idx += 1
+ return nn.ModuleList(blocks)
+
+ def forward(self, x):
+ out = self.stage0(x)
+ for stage in (self.stage1, self.stage2, self.stage3):
+ for block in stage:
+ if self.use_checkpoint:
+ out = checkpoint.checkpoint(block, out)
+ else:
+ out = block(out)
+ # The ImageNet classification head (gap + linear) of the original RepVGG is not built
+ # in __init__ here, so the stage-3 feature map is returned directly.
+ return out
+
+
+optional_groupwise_layers = [2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26]
+g2_map = {l: 2 for l in optional_groupwise_layers}
+g4_map = {l: 4 for l in optional_groupwise_layers}
+
+def create_RepVGG(deploy=False, use_checkpoint=False):
+ return RepVGG(num_blocks=[2, 4, 14, 1], num_classes=1000,
+ width_multiplier=[1, 1, 1, 2.5], override_groups_map=None, deploy=deploy, use_checkpoint=use_checkpoint)
+
+# Use this for converting a RepVGG model or a bigger model with RepVGG as its component
+# Use like this
+# model = create_RepVGG_A0(deploy=False)
+# train model or load weights
+# repvgg_model_convert(model, save_path='repvgg_deploy.pth')
+# If you want to preserve the original model, call with do_copy=True
+
+# ====================== for using RepVGG as the backbone of a bigger model, e.g., PSPNet, the pseudo code will be like
+# train_backbone = create_RepVGG_B2(deploy=False)
+# train_backbone.load_state_dict(torch.load('RepVGG-B2-train.pth'))
+# train_pspnet = build_pspnet(backbone=train_backbone)
+# segmentation_train(train_pspnet)
+# deploy_pspnet = repvgg_model_convert(train_pspnet)
+# segmentation_test(deploy_pspnet)
+# ===================== example_pspnet.py shows an example
+
+def repvgg_model_convert(model:torch.nn.Module, save_path=None, do_copy=True):
+ if do_copy:
+ model = copy.deepcopy(model)
+ for module in model.modules():
+ if hasattr(module, 'switch_to_deploy'):
+ module.switch_to_deploy()
+ if save_path is not None:
+ torch.save(model.state_dict(), save_path)
+ return model
diff --git a/imcui/third_party/EfficientLoFTR/src/loftr/loftr.py b/imcui/third_party/EfficientLoFTR/src/loftr/loftr.py
new file mode 100644
index 0000000000000000000000000000000000000000..8f76939a1b0c68504f535d4c9eb4ef91d19cd63a
--- /dev/null
+++ b/imcui/third_party/EfficientLoFTR/src/loftr/loftr.py
@@ -0,0 +1,124 @@
+import torch
+import torch.nn as nn
+from einops.einops import rearrange
+
+from .backbone import build_backbone
+from .loftr_module import LocalFeatureTransformer, FinePreprocess
+from .utils.coarse_matching import CoarseMatching
+from .utils.fine_matching import FineMatching
+from ..utils.misc import detect_NaN
+
+from loguru import logger
+
+def reparameter(matcher):
+ module = matcher.backbone.layer0
+ if hasattr(module, 'switch_to_deploy'):
+ module.switch_to_deploy()
+ for modules in [matcher.backbone.layer1, matcher.backbone.layer2, matcher.backbone.layer3]:
+ for module in modules:
+ if hasattr(module, 'switch_to_deploy'):
+ module.switch_to_deploy()
+ for modules in [matcher.fine_preprocess.layer2_outconv2, matcher.fine_preprocess.layer1_outconv2]:
+ for module in modules:
+ if hasattr(module, 'switch_to_deploy'):
+ module.switch_to_deploy()
+ return matcher
+
+class LoFTR(nn.Module):
+ def __init__(self, config, profiler=None):
+ super().__init__()
+ # Misc
+ self.config = config
+ self.profiler = profiler
+
+ # Modules
+ self.backbone = build_backbone(config)
+ self.loftr_coarse = LocalFeatureTransformer(config)
+ self.coarse_matching = CoarseMatching(config['match_coarse'])
+ self.fine_preprocess = FinePreprocess(config)
+ self.fine_matching = FineMatching(config)
+
+ def forward(self, data):
+ """
+ Update:
+ data (dict): {
+ 'image0': (torch.Tensor): (N, 1, H, W)
+ 'image1': (torch.Tensor): (N, 1, H, W)
+ 'mask0'(optional) : (torch.Tensor): (N, H, W) '0' indicates a padded position
+ 'mask1'(optional) : (torch.Tensor): (N, H, W)
+ }
+ """
+ # 1. Local Feature CNN
+ data.update({
+ 'bs': data['image0'].size(0),
+ 'hw0_i': data['image0'].shape[2:], 'hw1_i': data['image1'].shape[2:]
+ })
+
+ if data['hw0_i'] == data['hw1_i']: # faster & better BN convergence
+ ret_dict = self.backbone(torch.cat([data['image0'], data['image1']], dim=0))
+ feats_c = ret_dict['feats_c']
+ data.update({
+ 'feats_x2': ret_dict['feats_x2'],
+ 'feats_x1': ret_dict['feats_x1'],
+ })
+ (feat_c0, feat_c1) = feats_c.split(data['bs'])
+ else: # handle different input shapes
+ ret_dict0, ret_dict1 = self.backbone(data['image0']), self.backbone(data['image1'])
+ feat_c0 = ret_dict0['feats_c']
+ feat_c1 = ret_dict1['feats_c']
+ data.update({
+ 'feats_x2_0': ret_dict0['feats_x2'],
+ 'feats_x1_0': ret_dict0['feats_x1'],
+ 'feats_x2_1': ret_dict1['feats_x2'],
+ 'feats_x1_1': ret_dict1['feats_x1'],
+ })
+
+
+ mul = self.config['resolution'][0] // self.config['resolution'][1]
+ data.update({
+ 'hw0_c': feat_c0.shape[2:], 'hw1_c': feat_c1.shape[2:],
+ 'hw0_f': [feat_c0.shape[2] * mul, feat_c0.shape[3] * mul],
+ 'hw1_f': [feat_c1.shape[2] * mul, feat_c1.shape[3] * mul]
+ })
+
+ # 2. coarse-level loftr module
+ mask_c0 = mask_c1 = None # mask is useful in training
+ if 'mask0' in data:
+ mask_c0, mask_c1 = data['mask0'], data['mask1']
+
+ feat_c0, feat_c1 = self.loftr_coarse(feat_c0, feat_c1, mask_c0, mask_c1)
+
+ feat_c0 = rearrange(feat_c0, 'n c h w -> n (h w) c')
+ feat_c1 = rearrange(feat_c1, 'n c h w -> n (h w) c')
+
+ # detect NaN during mixed precision training
+ if self.config['replace_nan'] and (torch.any(torch.isnan(feat_c0)) or torch.any(torch.isnan(feat_c1))):
+ detect_NaN(feat_c0, feat_c1)
+
+ # 3. match coarse-level
+ self.coarse_matching(feat_c0, feat_c1, data,
+ mask_c0=mask_c0.view(mask_c0.size(0), -1) if mask_c0 is not None else mask_c0,
+ mask_c1=mask_c1.view(mask_c1.size(0), -1) if mask_c1 is not None else mask_c1
+ )
+
+ # prevent fp16 overflow during mixed precision training
+ feat_c0, feat_c1 = map(lambda feat: feat / feat.shape[-1]**.5,
+ [feat_c0, feat_c1])
+
+ # 4. fine-level refinement
+ feat_f0_unfold, feat_f1_unfold = self.fine_preprocess(feat_c0, feat_c1, data)
+
+ # detect NaN during mixed precision training
+ if self.config['replace_nan'] and (torch.any(torch.isnan(feat_f0_unfold)) or torch.any(torch.isnan(feat_f1_unfold))):
+ detect_NaN(feat_f0_unfold, feat_f1_unfold)
+
+ del feat_c0, feat_c1, mask_c0, mask_c1
+
+ # 5. match fine-level
+ self.fine_matching(feat_f0_unfold, feat_f1_unfold, data)
+
+ def load_state_dict(self, state_dict, *args, **kwargs):
+ for k in list(state_dict.keys()):
+ if k.startswith('matcher.'):
+ state_dict[k.replace('matcher.', '', 1)] = state_dict.pop(k)
+ return super().load_state_dict(state_dict, *args, **kwargs)
\ No newline at end of file
diff --git a/imcui/third_party/EfficientLoFTR/src/loftr/loftr_module/__init__.py b/imcui/third_party/EfficientLoFTR/src/loftr/loftr_module/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ca51db4f50a0c4f3dcd795e74b83e633ab2e990a
--- /dev/null
+++ b/imcui/third_party/EfficientLoFTR/src/loftr/loftr_module/__init__.py
@@ -0,0 +1,2 @@
+from .transformer import LocalFeatureTransformer
+from .fine_preprocess import FinePreprocess
diff --git a/imcui/third_party/EfficientLoFTR/src/loftr/loftr_module/fine_preprocess.py b/imcui/third_party/EfficientLoFTR/src/loftr/loftr_module/fine_preprocess.py
new file mode 100644
index 0000000000000000000000000000000000000000..ca37e02a0e709650d8133db04e84e9cfccbd6bf0
--- /dev/null
+++ b/imcui/third_party/EfficientLoFTR/src/loftr/loftr_module/fine_preprocess.py
@@ -0,0 +1,112 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops.einops import rearrange, repeat
+
+from loguru import logger
+
+def conv1x1(in_planes, out_planes, stride=1):
+ """1x1 convolution without padding"""
+ return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, padding=0, bias=False)
+
+
+def conv3x3(in_planes, out_planes, stride=1):
+ """3x3 convolution with padding"""
+ return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
+
+class FinePreprocess(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+
+ self.config = config
+ block_dims = config['backbone']['block_dims']
+ self.W = self.config['fine_window_size']
+ self.fine_d_model = block_dims[0]
+
+ self.layer3_outconv = conv1x1(block_dims[2], block_dims[2])
+ self.layer2_outconv = conv1x1(block_dims[1], block_dims[2])
+ self.layer2_outconv2 = nn.Sequential(
+ conv3x3(block_dims[2], block_dims[2]),
+ nn.BatchNorm2d(block_dims[2]),
+ nn.LeakyReLU(),
+ conv3x3(block_dims[2], block_dims[1]),
+ )
+ self.layer1_outconv = conv1x1(block_dims[0], block_dims[1])
+ self.layer1_outconv2 = nn.Sequential(
+ conv3x3(block_dims[1], block_dims[1]),
+ nn.BatchNorm2d(block_dims[1]),
+ nn.LeakyReLU(),
+ conv3x3(block_dims[1], block_dims[0]),
+ )
+
+ self._reset_parameters()
+
+ def _reset_parameters(self):
+ for p in self.parameters():
+ if p.dim() > 1:
+ nn.init.kaiming_normal_(p, mode="fan_out", nonlinearity="relu")
+
+ def inter_fpn(self, feat_c, x2, x1, stride):
+ feat_c = self.layer3_outconv(feat_c)
+ feat_c = F.interpolate(feat_c, scale_factor=2., mode='bilinear', align_corners=False)
+
+ x2 = self.layer2_outconv(x2)
+ x2 = self.layer2_outconv2(x2+feat_c)
+ x2 = F.interpolate(x2, scale_factor=2., mode='bilinear', align_corners=False)
+
+ x1 = self.layer1_outconv(x1)
+ x1 = self.layer1_outconv2(x1+x2)
+ x1 = F.interpolate(x1, scale_factor=2., mode='bilinear', align_corners=False)
+ return x1
+
+ def forward(self, feat_c0, feat_c1, data):
+ W = self.W
+ stride = data['hw0_f'][0] // data['hw0_c'][0]
+
+ data.update({'W': W})
+ if data['b_ids'].shape[0] == 0:
+ feat0 = torch.empty(0, self.W**2, self.fine_d_model, device=feat_c0.device)
+ feat1 = torch.empty(0, self.W**2, self.fine_d_model, device=feat_c0.device)
+ return feat0, feat1
+
+ if data['hw0_i'] == data['hw1_i']:
+ feat_c = rearrange(torch.cat([feat_c0, feat_c1], 0), 'b (h w) c -> b c h w', h=data['hw0_c'][0]) # 1/8 feat
+ x2 = data['feats_x2'] # 1/4 feat
+ x1 = data['feats_x1'] # 1/2 feat
+ del data['feats_x2'], data['feats_x1']
+
+ # 1. fine feature extraction
+ x1 = self.inter_fpn(feat_c, x2, x1, stride)
+ feat_f0, feat_f1 = torch.chunk(x1, 2, dim=0)
+
+ # 2. unfold(crop) all local windows
+ feat_f0 = F.unfold(feat_f0, kernel_size=(W, W), stride=stride, padding=0)
+ feat_f0 = rearrange(feat_f0, 'n (c ww) l -> n l ww c', ww=W**2)
+ feat_f1 = F.unfold(feat_f1, kernel_size=(W+2, W+2), stride=stride, padding=1)
+ feat_f1 = rearrange(feat_f1, 'n (c ww) l -> n l ww c', ww=(W+2)**2)
+
+ # 3. select only the predicted matches
+ feat_f0 = feat_f0[data['b_ids'], data['i_ids']] # [n, ww, cf]
+ feat_f1 = feat_f1[data['b_ids'], data['j_ids']]
+
+ return feat_f0, feat_f1
+ else: # handle different input shapes
+ feat_c0, feat_c1 = rearrange(feat_c0, 'b (h w) c -> b c h w', h=data['hw0_c'][0]), rearrange(feat_c1, 'b (h w) c -> b c h w', h=data['hw1_c'][0]) # 1/8 feat
+ x2_0, x2_1 = data['feats_x2_0'], data['feats_x2_1'] # 1/4 feat
+ x1_0, x1_1 = data['feats_x1_0'], data['feats_x1_1'] # 1/2 feat
+ del data['feats_x2_0'], data['feats_x1_0'], data['feats_x2_1'], data['feats_x1_1']
+
+ # 1. fine feature extraction
+ feat_f0, feat_f1 = self.inter_fpn(feat_c0, x2_0, x1_0, stride), self.inter_fpn(feat_c1, x2_1, x1_1, stride)
+
+ # 2. unfold(crop) all local windows
+ feat_f0 = F.unfold(feat_f0, kernel_size=(W, W), stride=stride, padding=0)
+ feat_f0 = rearrange(feat_f0, 'n (c ww) l -> n l ww c', ww=W**2)
+ feat_f1 = F.unfold(feat_f1, kernel_size=(W+2, W+2), stride=stride, padding=1)
+ feat_f1 = rearrange(feat_f1, 'n (c ww) l -> n l ww c', ww=(W+2)**2)
+
+ # 3. select only the predicted matches
+ feat_f0 = feat_f0[data['b_ids'], data['i_ids']] # [n, ww, cf]
+ feat_f1 = feat_f1[data['b_ids'], data['j_ids']]
+
+ return feat_f0, feat_f1
\ No newline at end of file
diff --git a/imcui/third_party/EfficientLoFTR/src/loftr/loftr_module/linear_attention.py b/imcui/third_party/EfficientLoFTR/src/loftr/loftr_module/linear_attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..4d95414e35441522c7df65c88a403672d3aa227b
--- /dev/null
+++ b/imcui/third_party/EfficientLoFTR/src/loftr/loftr_module/linear_attention.py
@@ -0,0 +1,103 @@
+"""
+Linear Transformer proposed in "Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention"
+Modified from: https://github.com/idiap/fast-transformers/blob/master/fast_transformers/attention/linear_attention.py
+"""
+
+import torch
+from torch.nn import Module
+import torch.nn.functional as F
+from einops.einops import rearrange
+
+if hasattr(F, 'scaled_dot_product_attention'):
+ FLASH_AVAILABLE = True
+ from torch.backends.cuda import sdp_kernel
+else:
+ FLASH_AVAILABLE = False
+
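+# Note (added): with a padding mask, the valid (unpadded) region is cropped out before attention and
+# the attention output is zero-padded back to the original spatial size afterwards.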
+def crop_feature(query, key, value, x_mask, source_mask):
+ mask_h0, mask_w0, mask_h1, mask_w1 = x_mask[0].sum(-2)[0], x_mask[0].sum(-1)[0], source_mask[0].sum(-2)[0], source_mask[0].sum(-1)[0]
+ query = query[:, :mask_h0, :mask_w0, :]
+ key = key[:, :mask_h1, :mask_w1, :]
+ value = value[:, :mask_h1, :mask_w1, :]
+ return query, key, value, mask_h0, mask_w0
+
+def pad_feature(m, mask_h0, mask_w0, x_mask):
+ bs, L, H, D = m.size()
+ m = m.view(bs, mask_h0, mask_w0, H, D)
+ if mask_h0 != x_mask.size(-2):
+ m = torch.cat([m, torch.zeros(m.size(0), x_mask.size(-2)-mask_h0, x_mask.size(-1), H, D, device=m.device, dtype=m.dtype)], dim=1)
+ elif mask_w0 != x_mask.size(-1):
+ m = torch.cat([m, torch.zeros(m.size(0), x_mask.size(-2), x_mask.size(-1)-mask_w0, H, D, device=m.device, dtype=m.dtype)], dim=2)
+ return m
+
+class Attention(Module):
+ def __init__(self, no_flash=False, nhead=8, dim=256, fp32=False):
+ super().__init__()
+ self.flash = FLASH_AVAILABLE and not no_flash
+ self.nhead = nhead
+ self.dim = dim
+ self.fp32 = fp32
+
+ def attention(self, query, key, value, q_mask=None, kv_mask=None):
+ assert q_mask is None and kv_mask is None, "Generalized attention masks are not supported yet."
+ if self.flash and not self.fp32:
+ args = [x.contiguous() for x in [query, key, value]]
+ with sdp_kernel(enable_math=False, enable_flash=True, enable_mem_efficient=False):
+ out = F.scaled_dot_product_attention(*args)
+ elif self.flash:
+ args = [x.contiguous() for x in [query, key, value]]
+ out = F.scaled_dot_product_attention(*args)
+ else:
+ QK = torch.einsum("nlhd,nshd->nlsh", query, key)
+
+ # Compute the attention and the weighted average
+ softmax_temp = 1. / query.size(3)**.5 # sqrt(D)
+ A = torch.softmax(softmax_temp * QK, dim=2)
+
+ out = torch.einsum("nlsh,nshd->nlhd", A, value)
+ return out
+
+ def _forward(self, query, key, value, q_mask=None, kv_mask=None):
+ if q_mask is not None:
+ query, key, value, mask_h0, mask_w0 = crop_feature(query, key, value, q_mask, kv_mask)
+
+ if self.flash:
+ query, key, value = map(lambda x: rearrange(x, 'n h w (nhead d) -> n nhead (h w) d', nhead=self.nhead, d=self.dim), [query, key, value])
+ else:
+ query, key, value = map(lambda x: rearrange(x, 'n h w (nhead d) -> n (h w) nhead d', nhead=self.nhead, d=self.dim), [query, key, value])
+
+ m = self.attention(query, key, value, q_mask=None, kv_mask=None)
+
+ if self.flash:
+ m = rearrange(m, 'n nhead L d -> n L nhead d', nhead=self.nhead, d=self.dim)
+
+ if q_mask is not None:
+ m = pad_feature(m, mask_h0, mask_w0, q_mask)
+
+ return m
+
+ def forward(self, query, key, value, q_mask=None, kv_mask=None):
+ """ Multi-head scaled dot-product attention, a.k.a full attention.
+ Args:
+ if FLASH_AVAILABLE: # pytorch scaled_dot_product_attention
+ queries: [N, H, L, D]
+ keys: [N, H, S, D]
+ values: [N, H, S, D]
+ else:
+ queries: [N, L, H, D]
+ keys: [N, S, H, D]
+ values: [N, S, H, D]
+ q_mask: [N, L]
+ kv_mask: [N, S]
+ Returns:
+ queried_values: (N, L, H, D)
+ """
+ bs = query.size(0)
+ if bs == 1 or q_mask is None:
+ m = self._forward(query, key, value, q_mask=q_mask, kv_mask=kv_mask)
+ else: # for faster training with padding mask while batch size > 1
+ m_list = []
+ for i in range(bs):
+ m_list.append(self._forward(query[i:i+1], key[i:i+1], value[i:i+1], q_mask=q_mask[i:i+1], kv_mask=kv_mask[i:i+1]))
+ m = torch.cat(m_list, dim=0)
+ return m
\ No newline at end of file
diff --git a/imcui/third_party/EfficientLoFTR/src/loftr/loftr_module/transformer.py b/imcui/third_party/EfficientLoFTR/src/loftr/loftr_module/transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..e97a033a185049539a9f2fd29483333a839a3bcd
--- /dev/null
+++ b/imcui/third_party/EfficientLoFTR/src/loftr/loftr_module/transformer.py
@@ -0,0 +1,164 @@
+import copy
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from .linear_attention import Attention, crop_feature, pad_feature
+from einops.einops import rearrange
+from collections import OrderedDict
+from ..utils.position_encoding import RoPEPositionEncodingSine
+import numpy as np
+from loguru import logger
+
+class AG_RoPE_EncoderLayer(nn.Module):
+ def __init__(self,
+ d_model,
+ nhead,
+ agg_size0=4,
+ agg_size1=4,
+ no_flash=False,
+ rope=False,
+ npe=None,
+ fp32=False,
+ ):
+ super(AG_RoPE_EncoderLayer, self).__init__()
+
+ self.dim = d_model // nhead
+ self.nhead = nhead
+ self.agg_size0, self.agg_size1 = agg_size0, agg_size1
+ self.rope = rope
+
+ # aggregate and position encoding
+ self.aggregate = nn.Conv2d(d_model, d_model, kernel_size=agg_size0, padding=0, stride=agg_size0, bias=False, groups=d_model) if self.agg_size0 != 1 else nn.Identity()
+ self.max_pool = torch.nn.MaxPool2d(kernel_size=self.agg_size1, stride=self.agg_size1) if self.agg_size1 != 1 else nn.Identity()
+ if self.rope:
+ self.rope_pos_enc = RoPEPositionEncodingSine(d_model, max_shape=(256, 256), npe=npe, ropefp16=True)
+
+ # multi-head attention
+ self.q_proj = nn.Linear(d_model, d_model, bias=False)
+ self.k_proj = nn.Linear(d_model, d_model, bias=False)
+ self.v_proj = nn.Linear(d_model, d_model, bias=False)
+ self.attention = Attention(no_flash, self.nhead, self.dim, fp32)
+ self.merge = nn.Linear(d_model, d_model, bias=False)
+
+ # feed-forward network
+ self.mlp = nn.Sequential(
+ nn.Linear(d_model*2, d_model*2, bias=False),
+ nn.LeakyReLU(inplace = True),
+ nn.Linear(d_model*2, d_model, bias=False),
+ )
+
+ # norm
+ self.norm1 = nn.LayerNorm(d_model)
+ self.norm2 = nn.LayerNorm(d_model)
+
+ def forward(self, x, source, x_mask=None, source_mask=None):
+ """
+ Args:
+ x (torch.Tensor): [N, C, H0, W0]
+ source (torch.Tensor): [N, C, H1, W1]
+ x_mask (torch.Tensor): [N, H0, W0] (optional) (L = H0*W0)
+ source_mask (torch.Tensor): [N, H1, W1] (optional) (S = H1*W1)
+ """
+ bs, C, H0, W0 = x.size()
+ H1, W1 = source.size(-2), source.size(-1)
+
+ # Aggregate features
+ query, source = self.norm1(self.aggregate(x).permute(0,2,3,1)), self.norm1(self.max_pool(source).permute(0,2,3,1)) # [N, H, W, C]
+ if x_mask is not None:
+ x_mask, source_mask = map(lambda x: self.max_pool(x.float()).bool(), [x_mask, source_mask])
+ query, key, value = self.q_proj(query), self.k_proj(source), self.v_proj(source)
+
+ # Positional encoding
+ if self.rope:
+ query = self.rope_pos_enc(query)
+ key = self.rope_pos_enc(key)
+
+ # multi-head attention handle padding mask
+ m = self.attention(query, key, value, q_mask=x_mask, kv_mask=source_mask)
+ m = self.merge(m.reshape(bs, -1, self.nhead*self.dim)) # [N, L, C]
+
+ # Upsample feature
+ m = rearrange(m, 'b (h w) c -> b c h w', h=H0 // self.agg_size0, w=W0 // self.agg_size0) # [N, C, H0, W0]
+ if self.agg_size0 != 1:
+ m = torch.nn.functional.interpolate(m, scale_factor=self.agg_size0, mode='bilinear', align_corners=False) # [N, C, H0, W0]
+
+ # feed-forward network
+ m = self.mlp(torch.cat([x, m], dim=1).permute(0, 2, 3, 1)) # [N, H0, W0, C]
+ m = self.norm2(m).permute(0, 3, 1, 2) # [N, C, H0, W0]
+
+ return x + m
+
+class LocalFeatureTransformer(nn.Module):
+ """A Local Feature Transformer (LoFTR) module."""
+
+ def __init__(self, config):
+ super(LocalFeatureTransformer, self).__init__()
+
+ self.full_config = config
+ self.fp32 = not (config['mp'] or config['half'])
+ config = config['coarse']
+ self.d_model = config['d_model']
+ self.nhead = config['nhead']
+ self.layer_names = config['layer_names']
+ self.agg_size0, self.agg_size1 = config['agg_size0'], config['agg_size1']
+ self.rope = config['rope']
+
+ self_layer = AG_RoPE_EncoderLayer(config['d_model'], config['nhead'], config['agg_size0'], config['agg_size1'],
+ config['no_flash'], config['rope'], config['npe'], self.fp32)
+ cross_layer = AG_RoPE_EncoderLayer(config['d_model'], config['nhead'], config['agg_size0'], config['agg_size1'],
+ config['no_flash'], False, config['npe'], self.fp32)
+ self.layers = nn.ModuleList([copy.deepcopy(self_layer) if _ == 'self' else copy.deepcopy(cross_layer) for _ in self.layer_names])
+ self._reset_parameters()
+
+ def _reset_parameters(self):
+ for p in self.parameters():
+ if p.dim() > 1:
+ nn.init.xavier_uniform_(p)
+
+ def forward(self, feat0, feat1, mask0=None, mask1=None, data=None):
+ """
+ Args:
+ feat0 (torch.Tensor): [N, C, H, W]
+ feat1 (torch.Tensor): [N, C, H, W]
+ mask0 (torch.Tensor): [N, L] (optional)
+ mask1 (torch.Tensor): [N, S] (optional)
+ """
+ H0, W0, H1, W1 = feat0.size(-2), feat0.size(-1), feat1.size(-2), feat1.size(-1)
+ bs = feat0.shape[0]
+
+ feature_cropped = False
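+ # For batch size 1 with padding masks, crop the features to the valid region (rounded down to a
+ # multiple of the aggregation size) and zero-pad them back after the transformer layers.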
+ if bs == 1 and mask0 is not None and mask1 is not None:
+ mask_H0, mask_W0, mask_H1, mask_W1 = mask0.size(-2), mask0.size(-1), mask1.size(-2), mask1.size(-1)
+ mask_h0, mask_w0, mask_h1, mask_w1 = mask0[0].sum(-2)[0], mask0[0].sum(-1)[0], mask1[0].sum(-2)[0], mask1[0].sum(-1)[0]
+ mask_h0, mask_w0, mask_h1, mask_w1 = mask_h0//self.agg_size0*self.agg_size0, mask_w0//self.agg_size0*self.agg_size0, mask_h1//self.agg_size1*self.agg_size1, mask_w1//self.agg_size1*self.agg_size1
+ feat0 = feat0[:, :, :mask_h0, :mask_w0]
+ feat1 = feat1[:, :, :mask_h1, :mask_w1]
+ feature_cropped = True
+
+ for i, (layer, name) in enumerate(zip(self.layers, self.layer_names)):
+ if feature_cropped:
+ mask0, mask1 = None, None
+ if name == 'self':
+ feat0 = layer(feat0, feat0, mask0, mask0)
+ feat1 = layer(feat1, feat1, mask1, mask1)
+ elif name == 'cross':
+ feat0 = layer(feat0, feat1, mask0, mask1)
+ feat1 = layer(feat1, feat0, mask1, mask0)
+ else:
+ raise KeyError
+
+ if feature_cropped:
+ # padding feature
+ bs, c, mask_h0, mask_w0 = feat0.size()
+ if mask_h0 != mask_H0:
+ feat0 = torch.cat([feat0, torch.zeros(bs, c, mask_H0-mask_h0, mask_W0, device=feat0.device, dtype=feat0.dtype)], dim=-2)
+ elif mask_w0 != mask_W0:
+ feat0 = torch.cat([feat0, torch.zeros(bs, c, mask_H0, mask_W0-mask_w0, device=feat0.device, dtype=feat0.dtype)], dim=-1)
+
+ bs, c, mask_h1, mask_w1 = feat1.size()
+ if mask_h1 != mask_H1:
+ feat1 = torch.cat([feat1, torch.zeros(bs, c, mask_H1-mask_h1, mask_W1, device=feat1.device, dtype=feat1.dtype)], dim=-2)
+ elif mask_w1 != mask_W1:
+ feat1 = torch.cat([feat1, torch.zeros(bs, c, mask_H1, mask_W1-mask_w1, device=feat1.device, dtype=feat1.dtype)], dim=-1)
+
+ return feat0, feat1
\ No newline at end of file
diff --git a/imcui/third_party/EfficientLoFTR/src/loftr/utils/coarse_matching.py b/imcui/third_party/EfficientLoFTR/src/loftr/utils/coarse_matching.py
new file mode 100644
index 0000000000000000000000000000000000000000..156c9eecf8c2cfb54b8eb22a8663d5cda5afa6a8
--- /dev/null
+++ b/imcui/third_party/EfficientLoFTR/src/loftr/utils/coarse_matching.py
@@ -0,0 +1,241 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops.einops import rearrange, repeat
+
+from loguru import logger
+import numpy as np
+
+INF = 1e9
+
+def mask_border(m, b: int, v):
+ """ Mask borders with value
+ Args:
+ m (torch.Tensor): [N, H0, W0, H1, W1]
+ b (int)
+ v (m.dtype)
+ """
+ if b <= 0:
+ return
+
+ m[:, :b] = v
+ m[:, :, :b] = v
+ m[:, :, :, :b] = v
+ m[:, :, :, :, :b] = v
+ m[:, -b:] = v
+ m[:, :, -b:] = v
+ m[:, :, :, -b:] = v
+ m[:, :, :, :, -b:] = v
+
+
+def mask_border_with_padding(m, bd, v, p_m0, p_m1):
+ if bd <= 0:
+ return
+
+ m[:, :bd] = v
+ m[:, :, :bd] = v
+ m[:, :, :, :bd] = v
+ m[:, :, :, :, :bd] = v
+
+ h0s, w0s = p_m0.sum(1).max(-1)[0].int(), p_m0.sum(-1).max(-1)[0].int()
+ h1s, w1s = p_m1.sum(1).max(-1)[0].int(), p_m1.sum(-1).max(-1)[0].int()
+ for b_idx, (h0, w0, h1, w1) in enumerate(zip(h0s, w0s, h1s, w1s)):
+ m[b_idx, h0 - bd:] = v
+ m[b_idx, :, w0 - bd:] = v
+ m[b_idx, :, :, h1 - bd:] = v
+ m[b_idx, :, :, :, w1 - bd:] = v
+
+
+def compute_max_candidates(p_m0, p_m1):
+ """Compute the max candidates of all pairs within a batch
+
+ Args:
+ p_m0, p_m1 (torch.Tensor): padded masks
+ """
+ h0s, w0s = p_m0.sum(1).max(-1)[0], p_m0.sum(-1).max(-1)[0]
+ h1s, w1s = p_m1.sum(1).max(-1)[0], p_m1.sum(-1).max(-1)[0]
+ max_cand = torch.sum(
+ torch.min(torch.stack([h0s * w0s, h1s * w1s], -1), -1)[0])
+ return max_cand
+
+class CoarseMatching(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ # general config
+ self.thr = config['thr']
+ self.border_rm = config['border_rm']
+ self.temperature = config['dsmax_temperature']
+ self.skip_softmax = config['skip_softmax']
+ self.fp16matmul = config['fp16matmul']
+ # -- # for training fine-level LoFTR
+ self.train_coarse_percent = config['train_coarse_percent']
+ self.train_pad_num_gt_min = config['train_pad_num_gt_min']
+
+ def forward(self, feat_c0, feat_c1, data, mask_c0=None, mask_c1=None):
+ """
+ Args:
+ feat0 (torch.Tensor): [N, L, C]
+ feat1 (torch.Tensor): [N, S, C]
+ data (dict)
+ mask_c0 (torch.Tensor): [N, L] (optional)
+ mask_c1 (torch.Tensor): [N, S] (optional)
+ Update:
+ data (dict): {
+ 'b_ids' (torch.Tensor): [M'],
+ 'i_ids' (torch.Tensor): [M'],
+ 'j_ids' (torch.Tensor): [M'],
+ 'm_bids' (torch.Tensor): [M],
+ 'mkpts0_c' (torch.Tensor): [M, 2],
+ 'mkpts1_c' (torch.Tensor): [M, 2],
+ 'mconf' (torch.Tensor): [M]}
+ NOTE: M' != M during training.
+ """
+ N, L, S, C = feat_c0.size(0), feat_c0.size(1), feat_c1.size(1), feat_c0.size(2)
+
+ # normalize
+ feat_c0, feat_c1 = map(lambda feat: feat / feat.shape[-1]**.5,
+ [feat_c0, feat_c1])
+
+ if self.fp16matmul:
+ sim_matrix = torch.einsum("nlc,nsc->nls", feat_c0,
+ feat_c1) / self.temperature
+ del feat_c0, feat_c1
+ if mask_c0 is not None:
+ sim_matrix = sim_matrix.masked_fill(
+ ~(mask_c0[..., None] * mask_c1[:, None]).bool(),
+ -1e4
+ )
+ else:
+ with torch.autocast(enabled=False, device_type='cuda'):
+ sim_matrix = torch.einsum("nlc,nsc->nls", feat_c0,
+ feat_c1) / self.temperature
+ del feat_c0, feat_c1
+ if mask_c0 is not None:
+ sim_matrix = sim_matrix.float().masked_fill(
+ ~(mask_c0[..., None] * mask_c1[:, None]).bool(),
+ -INF
+ )
+ if not self.skip_softmax:
+ sim_matrix = F.softmax(sim_matrix, 1) * F.softmax(sim_matrix, 2)
+
+ data.update({'conf_matrix': sim_matrix})
+
+ # predict coarse matches from conf_matrix
+ data.update(**self.get_coarse_match(sim_matrix, data))
+
+ @torch.no_grad()
+ def get_coarse_match(self, conf_matrix, data):
+ """
+ Args:
+ conf_matrix (torch.Tensor): [N, L, S]
+ data (dict): with keys ['hw0_i', 'hw1_i', 'hw0_c', 'hw1_c']
+ Returns:
+ coarse_matches (dict): {
+ 'b_ids' (torch.Tensor): [M'],
+ 'i_ids' (torch.Tensor): [M'],
+ 'j_ids' (torch.Tensor): [M'],
+ 'm_bids' (torch.Tensor): [M],
+ 'mkpts0_c' (torch.Tensor): [M, 2],
+ 'mkpts1_c' (torch.Tensor): [M, 2],
+ 'mconf' (torch.Tensor): [M]}
+ """
+ axes_lengths = {
+ 'h0c': data['hw0_c'][0],
+ 'w0c': data['hw0_c'][1],
+ 'h1c': data['hw1_c'][0],
+ 'w1c': data['hw1_c'][1]
+ }
+ _device = conf_matrix.device
+ # 1. confidence thresholding
+ mask = conf_matrix > self.thr
+ mask = rearrange(mask, 'b (h0c w0c) (h1c w1c) -> b h0c w0c h1c w1c',
+ **axes_lengths)
+
+ if 'mask0' not in data:
+ mask_border(mask, self.border_rm, False)
+ else:
+ mask_border_with_padding(mask, self.border_rm, False,
+ data['mask0'], data['mask1'])
+ mask = rearrange(mask, 'b h0c w0c h1c w1c -> b (h0c w0c) (h1c w1c)',
+ **axes_lengths)
+
+ # 2. mutual nearest
+ mask = mask \
+ * (conf_matrix == conf_matrix.max(dim=2, keepdim=True)[0]) \
+ * (conf_matrix == conf_matrix.max(dim=1, keepdim=True)[0])
+
+ # 3. find all valid coarse matches
+ # this only works when at most one `True` in each row
+ mask_v, all_j_ids = mask.max(dim=2)
+ b_ids, i_ids = torch.where(mask_v)
+ j_ids = all_j_ids[b_ids, i_ids]
+ mconf = conf_matrix[b_ids, i_ids, j_ids]
+
+ # 4. Random sampling of training samples for fine-level LoFTR
+ # (optional) pad samples with gt coarse-level matches
+ if self.training:
+ # NOTE:
+ # The sampling is performed across all pairs in a batch without manually balancing
+ # #samples for fine-level increases w.r.t. batch_size
+ if 'mask0' not in data:
+ num_candidates_max = mask.size(0) * max(
+ mask.size(1), mask.size(2))
+ else:
+ num_candidates_max = compute_max_candidates(
+ data['mask0'], data['mask1'])
+ num_matches_train = int(num_candidates_max *
+ self.train_coarse_percent)
+ num_matches_pred = len(b_ids)
+ assert self.train_pad_num_gt_min < num_matches_train, "min-num-gt-pad should be less than num-train-matches"
+
+ # pred_indices is to select from prediction
+ if num_matches_pred <= num_matches_train - self.train_pad_num_gt_min:
+ pred_indices = torch.arange(num_matches_pred, device=_device)
+ else:
+ pred_indices = torch.randint(
+ num_matches_pred,
+ (num_matches_train - self.train_pad_num_gt_min, ),
+ device=_device)
+
+ # gt_pad_indices is to select from gt padding. e.g. max(3787-4800, 200)
+ gt_pad_indices = torch.randint(
+ len(data['spv_b_ids']),
+ (max(num_matches_train - num_matches_pred,
+ self.train_pad_num_gt_min), ),
+ device=_device)
+ mconf_gt = torch.zeros(len(data['spv_b_ids']), device=_device) # set conf of gt paddings to all zero
+
+ b_ids, i_ids, j_ids, mconf = map(
+ lambda x, y: torch.cat([x[pred_indices], y[gt_pad_indices]],
+ dim=0),
+ *zip([b_ids, data['spv_b_ids']], [i_ids, data['spv_i_ids']],
+ [j_ids, data['spv_j_ids']], [mconf, mconf_gt]))
+
+ # These matches select patches that feed into fine-level network
+ coarse_matches = {'b_ids': b_ids, 'i_ids': i_ids, 'j_ids': j_ids}
+
+ # 5. Update with matches in original image resolution
+ scale = data['hw0_i'][0] / data['hw0_c'][0]
+
+ scale0 = scale * data['scale0'][b_ids] if 'scale0' in data else scale
+ scale1 = scale * data['scale1'][b_ids] if 'scale1' in data else scale
+ mkpts0_c = torch.stack(
+ [i_ids % data['hw0_c'][1], i_ids // data['hw0_c'][1]],
+ dim=1) * scale0
+ mkpts1_c = torch.stack(
+ [j_ids % data['hw1_c'][1], j_ids // data['hw1_c'][1]],
+ dim=1) * scale1
+
+ m_bids = b_ids[mconf != 0]
+ # These matches are the current prediction (for visualization)
+ coarse_matches.update({
+ 'm_bids': m_bids, # mconf == 0 => gt matches
+ 'mkpts0_c': mkpts0_c[mconf != 0],
+ 'mkpts1_c': mkpts1_c[mconf != 0],
+ 'mconf': mconf[mconf != 0]
+ })
+
+ return coarse_matches
\ No newline at end of file
diff --git a/imcui/third_party/EfficientLoFTR/src/loftr/utils/fine_matching.py b/imcui/third_party/EfficientLoFTR/src/loftr/utils/fine_matching.py
new file mode 100644
index 0000000000000000000000000000000000000000..6adb6f8c8a1c3d25babda3d5cbd79b44285c2eb9
--- /dev/null
+++ b/imcui/third_party/EfficientLoFTR/src/loftr/utils/fine_matching.py
@@ -0,0 +1,156 @@
+import math
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from kornia.geometry.subpix import dsnt
+from kornia.utils.grid import create_meshgrid
+
+from loguru import logger
+
+class FineMatching(nn.Module):
+ """FineMatching with s2d paradigm"""
+
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.local_regress_temperature = config['match_fine']['local_regress_temperature']
+ self.local_regress_slicedim = config['match_fine']['local_regress_slicedim']
+ self.fp16 = config['half']
+ self.validate = False
+
+ def forward(self, feat_0, feat_1, data):
+ """
+ Args:
+ feat0 (torch.Tensor): [M, WW, C]
+ feat1 (torch.Tensor): [M, WW, C]
+ data (dict)
+ Update:
+ data (dict):{
+ 'expec_f' (torch.Tensor): [M, 3],
+ 'mkpts0_f' (torch.Tensor): [M, 2],
+ 'mkpts1_f' (torch.Tensor): [M, 2]}
+ """
+ M, WW, C = feat_0.shape
+ W = int(math.sqrt(WW))
+ scale = data['hw0_i'][0] / data['hw0_f'][0]
+ self.M, self.W, self.WW, self.C, self.scale = M, W, WW, C, scale
+
+ # corner case: if no coarse matches found
+ if M == 0:
+ assert not self.training, "M is always > 0 while training, see coarse_matching.py"
+ data.update({
+ 'conf_matrix_f': torch.empty(0, WW, WW, device=feat_0.device),
+ 'mkpts0_f': data['mkpts0_c'],
+ 'mkpts1_f': data['mkpts1_c'],
+ })
+ return
+
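+ # Two-stage fine matching (added comment): all but the last `local_regress_slicedim` channels build
+ # a mutual-softmax correlation over the fine windows; the remaining channels build a second
+ # correlation used later for sub-pixel refinement over a 3x3 neighborhood.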
+ # compute pixel-level confidence matrix
+ with torch.autocast(enabled=True if not (self.training or self.validate) else False, device_type='cuda'):
+ feat_f0, feat_f1 = feat_0[...,:-self.local_regress_slicedim], feat_1[...,:-self.local_regress_slicedim]
+ feat_ff0, feat_ff1 = feat_0[...,-self.local_regress_slicedim:], feat_1[...,-self.local_regress_slicedim:]
+ feat_f0, feat_f1 = feat_f0 / C**.5, feat_f1 / C**.5
+ conf_matrix_f = torch.einsum('mlc,mrc->mlr', feat_f0, feat_f1)
+ conf_matrix_ff = torch.einsum('mlc,mrc->mlr', feat_ff0, feat_ff1 / (self.local_regress_slicedim)**.5)
+
+ softmax_matrix_f = F.softmax(conf_matrix_f, 1) * F.softmax(conf_matrix_f, 2)
+ softmax_matrix_f = softmax_matrix_f.reshape(M, self.WW, self.W+2, self.W+2)
+ softmax_matrix_f = softmax_matrix_f[...,1:-1,1:-1].reshape(M, self.WW, self.WW)
+
+ # for fine-level supervision
+ if self.training or self.validate:
+ data.update({'sim_matrix_ff': conf_matrix_ff})
+ data.update({'conf_matrix_f': softmax_matrix_f})
+
+ # compute pixel-level absolute kpt coords
+ self.get_fine_ds_match(softmax_matrix_f, data)
+
+ # generate second-stage 3x3 grid
+ idx_l, idx_r = data['idx_l'], data['idx_r']
+ m_ids = torch.arange(M, device=idx_l.device, dtype=torch.long).unsqueeze(-1)
+ m_ids = m_ids[:len(data['mconf'])]
+ idx_r_iids, idx_r_jids = idx_r // W, idx_r % W
+
+ m_ids, idx_l, idx_r_iids, idx_r_jids = m_ids.reshape(-1), idx_l.reshape(-1), idx_r_iids.reshape(-1), idx_r_jids.reshape(-1)
+ delta = create_meshgrid(3, 3, True, conf_matrix_ff.device).to(torch.long) # [1, 3, 3, 2]
+
+ m_ids = m_ids[...,None,None].expand(-1, 3, 3)
+ idx_l = idx_l[...,None,None].expand(-1, 3, 3) # [m, k, 3, 3]
+
+ idx_r_iids = idx_r_iids[...,None,None].expand(-1, 3, 3) + delta[None, ..., 1]
+ idx_r_jids = idx_r_jids[...,None,None].expand(-1, 3, 3) + delta[None, ..., 0]
+
+ if idx_l.numel() == 0:
+ data.update({
+ 'mkpts0_f': data['mkpts0_c'],
+ 'mkpts1_f': data['mkpts1_c'],
+ })
+ return
+
+ # compute second-stage heatmap
+ conf_matrix_ff = conf_matrix_ff.reshape(M, self.WW, self.W+2, self.W+2)
+ conf_matrix_ff = conf_matrix_ff[m_ids, idx_l, idx_r_iids, idx_r_jids]
+ conf_matrix_ff = conf_matrix_ff.reshape(-1, 9)
+ conf_matrix_ff = F.softmax(conf_matrix_ff / self.local_regress_temperature, -1)
+ heatmap = conf_matrix_ff.reshape(-1, 3, 3)
+
+ # compute coordinates from heatmap
+ coords_normalized = dsnt.spatial_expectation2d(heatmap[None], True)[0]
+
+ if data['bs'] == 1:
+ scale1 = scale * data['scale1'] if 'scale0' in data else scale
+ else:
+ scale1 = scale * data['scale1'][data['b_ids']][:len(data['mconf']), ...][:,None,:].expand(-1, -1, 2).reshape(-1, 2) if 'scale0' in data else scale
+
+ # compute subpixel-level absolute kpt coords
+ self.get_fine_match_local(coords_normalized, data, scale1)
+
+ def get_fine_match_local(self, coords_normed, data, scale1):
+ W, WW, C, scale = self.W, self.WW, self.C, self.scale
+
+ mkpts0_c, mkpts1_c = data['mkpts0_c'], data['mkpts1_c']
+
+ # mkpts0_f and mkpts1_f
+ mkpts0_f = mkpts0_c
+ mkpts1_f = mkpts1_c + (coords_normed * (3 // 2) * scale1)
+
+ data.update({
+ "mkpts0_f": mkpts0_f,
+ "mkpts1_f": mkpts1_f
+ })
+
+ @torch.no_grad()
+ def get_fine_ds_match(self, conf_matrix, data):
+ W, WW, C, scale = self.W, self.WW, self.C, self.scale
+ m, _, _ = conf_matrix.shape
+
+ conf_matrix = conf_matrix.reshape(m, -1)[:len(data['mconf']),...]
+ val, idx = torch.max(conf_matrix, dim = -1)
+ idx = idx[:,None]
+ idx_l, idx_r = idx // WW, idx % WW
+
+ data.update({'idx_l': idx_l, 'idx_r': idx_r})
+
+ if self.fp16:
+ grid = create_meshgrid(W, W, False, conf_matrix.device, dtype=torch.float16) - W // 2 + 0.5 # kornia >= 0.5.1
+ else:
+ grid = create_meshgrid(W, W, False, conf_matrix.device) - W // 2 + 0.5
+ grid = grid.reshape(1, -1, 2).expand(m, -1, -1)
+ delta_l = torch.gather(grid, 1, idx_l.unsqueeze(-1).expand(-1, -1, 2))
+ delta_r = torch.gather(grid, 1, idx_r.unsqueeze(-1).expand(-1, -1, 2))
+
+ scale0 = scale * data['scale0'][data['b_ids']] if 'scale0' in data else scale
+ scale1 = scale * data['scale1'][data['b_ids']] if 'scale0' in data else scale
+
+ if torch.is_tensor(scale0) and scale0.numel() > 1: # scale0 is a tensor
+ mkpts0_f = (data['mkpts0_c'][:,None,:] + (delta_l * scale0[:len(data['mconf']),...][:,None,:])).reshape(-1, 2)
+ mkpts1_f = (data['mkpts1_c'][:,None,:] + (delta_r * scale1[:len(data['mconf']),...][:,None,:])).reshape(-1, 2)
+ else: # scale0 is a float
+ mkpts0_f = (data['mkpts0_c'][:,None,:] + (delta_l * scale0)).reshape(-1, 2)
+ mkpts1_f = (data['mkpts1_c'][:,None,:] + (delta_r * scale1)).reshape(-1, 2)
+
+ data.update({
+ "mkpts0_c": mkpts0_f,
+ "mkpts1_c": mkpts1_f
+ })
\ No newline at end of file
diff --git a/imcui/third_party/EfficientLoFTR/src/loftr/utils/full_config.py b/imcui/third_party/EfficientLoFTR/src/loftr/utils/full_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..ccf84b48f6693be4bf9306c8bd987d87a6f43792
--- /dev/null
+++ b/imcui/third_party/EfficientLoFTR/src/loftr/utils/full_config.py
@@ -0,0 +1,50 @@
+from yacs.config import CfgNode as CN
+
+
+def lower_config(yacs_cfg):
+ if not isinstance(yacs_cfg, CN):
+ return yacs_cfg
+ return {k.lower(): lower_config(v) for k, v in yacs_cfg.items()}
+
+
+_CN = CN()
+_CN.BACKBONE_TYPE = 'RepVGG'
+_CN.ALIGN_CORNER = False
+_CN.RESOLUTION = (8, 1)
+_CN.FINE_WINDOW_SIZE = 8 # window_size in fine_level, must be even
+_CN.MP = False
+_CN.REPLACE_NAN = True
+_CN.HALF = False
+
+# 1. LoFTR-backbone (local feature CNN) config
+_CN.BACKBONE = CN()
+_CN.BACKBONE.BLOCK_DIMS = [64, 128, 256] # s1, s2, s3
+
+# 2. LoFTR-coarse module config
+_CN.COARSE = CN()
+_CN.COARSE.D_MODEL = 256
+_CN.COARSE.D_FFN = 256
+_CN.COARSE.NHEAD = 8
+_CN.COARSE.LAYER_NAMES = ['self', 'cross'] * 4
+_CN.COARSE.AGG_SIZE0 = 4
+_CN.COARSE.AGG_SIZE1 = 4
+_CN.COARSE.NO_FLASH = False
+_CN.COARSE.ROPE = True
+_CN.COARSE.NPE = [832, 832, 832, 832] # [832, 832, long_side, long_side] Suggest setting based on the long side of the input image, especially when the long_side > 832
+
+# 3. Coarse-Matching config
+_CN.MATCH_COARSE = CN()
+_CN.MATCH_COARSE.THR = 0.2 # recommend 0.2 for full model and 25 for optimized model
+_CN.MATCH_COARSE.BORDER_RM = 2
+_CN.MATCH_COARSE.DSMAX_TEMPERATURE = 0.1
+_CN.MATCH_COARSE.SKIP_SOFTMAX = False # False for full model and True for optimized model
+_CN.MATCH_COARSE.FP16MATMUL = False # False for full model and True for optimized model
+_CN.MATCH_COARSE.TRAIN_COARSE_PERCENT = 0.2 # training tricks: save GPU memory
+_CN.MATCH_COARSE.TRAIN_PAD_NUM_GT_MIN = 200 # training tricks: avoid DDP deadlock
+
+# 4. Fine-Matching config
+_CN.MATCH_FINE = CN()
+_CN.MATCH_FINE.LOCAL_REGRESS_TEMPERATURE = 10.0 # use 10.0 as fine local regress temperature, not 1.0
+_CN.MATCH_FINE.LOCAL_REGRESS_SLICEDIM = 8
+
+full_default_cfg = lower_config(_CN)
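+
+# Usage sketch (illustrative): lower_config() lowercases all keys, so downstream
+# code reads e.g. full_default_cfg['match_coarse']['thr'] -> 0.2 and
+# full_default_cfg['coarse']['npe'] -> [832, 832, 832, 832].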
diff --git a/imcui/third_party/EfficientLoFTR/src/loftr/utils/geometry.py b/imcui/third_party/EfficientLoFTR/src/loftr/utils/geometry.py
new file mode 100644
index 0000000000000000000000000000000000000000..f95cdb65b48324c4f4ceb20231b1bed992b41116
--- /dev/null
+++ b/imcui/third_party/EfficientLoFTR/src/loftr/utils/geometry.py
@@ -0,0 +1,54 @@
+import torch
+
+
+@torch.no_grad()
+def warp_kpts(kpts0, depth0, depth1, T_0to1, K0, K1):
+ """ Warp kpts0 from I0 to I1 with depth, K and Rt
+ Also check covisibility and depth consistency.
+ Depth is consistent if relative error < 0.2 (hard-coded).
+
+ Args:
+ kpts0 (torch.Tensor): [N, L, 2] - <x, y>,
+ depth0 (torch.Tensor): [N, H, W],
+ depth1 (torch.Tensor): [N, H, W],
+ T_0to1 (torch.Tensor): [N, 3, 4],
+ K0 (torch.Tensor): [N, 3, 3],
+ K1 (torch.Tensor): [N, 3, 3],
+ Returns:
+ calculable_mask (torch.Tensor): [N, L]
+ warped_keypoints0 (torch.Tensor): [N, L, 2]
+ """
+ kpts0_long = kpts0.round().long()
+
+ # Sample depth, get calculable_mask on depth != 0
+ kpts0_depth = torch.stack(
+ [depth0[i, kpts0_long[i, :, 1], kpts0_long[i, :, 0]] for i in range(kpts0.shape[0])], dim=0
+ ) # (N, L)
+ nonzero_mask = kpts0_depth != 0
+
+ # Unproject
+ kpts0_h = torch.cat([kpts0, torch.ones_like(kpts0[:, :, [0]])], dim=-1) * kpts0_depth[..., None] # (N, L, 3)
+ kpts0_cam = K0.inverse() @ kpts0_h.transpose(2, 1) # (N, 3, L)
+
+ # Rigid Transform
+ w_kpts0_cam = T_0to1[:, :3, :3] @ kpts0_cam + T_0to1[:, :3, [3]] # (N, 3, L)
+ w_kpts0_depth_computed = w_kpts0_cam[:, 2, :]
+
+ # Project
+ w_kpts0_h = (K1 @ w_kpts0_cam).transpose(2, 1) # (N, L, 3)
+ w_kpts0 = w_kpts0_h[:, :, :2] / (w_kpts0_h[:, :, [2]] + 1e-4) # (N, L, 2), +1e-4 to avoid zero depth
+
+ # Covisible Check
+ h, w = depth1.shape[1:3]
+ covisible_mask = (w_kpts0[:, :, 0] > 0) * (w_kpts0[:, :, 0] < w-1) * \
+ (w_kpts0[:, :, 1] > 0) * (w_kpts0[:, :, 1] < h-1)
+ w_kpts0_long = w_kpts0.long()
+ w_kpts0_long[~covisible_mask, :] = 0
+
+ w_kpts0_depth = torch.stack(
+ [depth1[i, w_kpts0_long[i, :, 1], w_kpts0_long[i, :, 0]] for i in range(w_kpts0_long.shape[0])], dim=0
+ ) # (N, L)
+ consistent_mask = ((w_kpts0_depth - w_kpts0_depth_computed) / w_kpts0_depth).abs() < 0.2
+ valid_mask = nonzero_mask * covisible_mask * consistent_mask
+
+ return valid_mask, w_kpts0
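+
+
+# Usage sketch (shapes only, names are illustrative): for a batch of N pairs with
+# L query keypoints,
+#   valid_mask, w_kpts0 = warp_kpts(kpts0, depth0, depth1, T_0to1, K0, K1)
+# returns an [N, L] boolean mask of covisible, depth-consistent points and their
+# warped [N, L, 2] pixel coordinates in image 1.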
diff --git a/imcui/third_party/EfficientLoFTR/src/loftr/utils/opt_config.py b/imcui/third_party/EfficientLoFTR/src/loftr/utils/opt_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..61b7fa1e88a72db226dbbbf3b47c2b4f40e7aff7
--- /dev/null
+++ b/imcui/third_party/EfficientLoFTR/src/loftr/utils/opt_config.py
@@ -0,0 +1,50 @@
+from yacs.config import CfgNode as CN
+
+
+def lower_config(yacs_cfg):
+ if not isinstance(yacs_cfg, CN):
+ return yacs_cfg
+ return {k.lower(): lower_config(v) for k, v in yacs_cfg.items()}
+
+
+_CN = CN()
+_CN.BACKBONE_TYPE = 'RepVGG'
+_CN.ALIGN_CORNER = False
+_CN.RESOLUTION = (8, 1)
+_CN.FINE_WINDOW_SIZE = 8 # window_size in fine_level, must be even
+_CN.MP = False
+_CN.REPLACE_NAN = True
+_CN.HALF = False
+
+# 1. LoFTR-backbone (local feature CNN) config
+_CN.BACKBONE = CN()
+_CN.BACKBONE.BLOCK_DIMS = [64, 128, 256] # s1, s2, s3
+
+# 2. LoFTR-coarse module config
+_CN.COARSE = CN()
+_CN.COARSE.D_MODEL = 256
+_CN.COARSE.D_FFN = 256
+_CN.COARSE.NHEAD = 8
+_CN.COARSE.LAYER_NAMES = ['self', 'cross'] * 4
+_CN.COARSE.AGG_SIZE0 = 4
+_CN.COARSE.AGG_SIZE1 = 4
+_CN.COARSE.NO_FLASH = False
+_CN.COARSE.ROPE = True
+_CN.COARSE.NPE = [832, 832, 832, 832] # [832, 832, long_side, long_side] Suggest setting based on the long side of the input image, especially when the long_side > 832
+
+# 3. Coarse-Matching config
+_CN.MATCH_COARSE = CN()
+_CN.MATCH_COARSE.THR = 25 # recommend 0.2 for full model and 25 for optimized model
+_CN.MATCH_COARSE.BORDER_RM = 2
+_CN.MATCH_COARSE.DSMAX_TEMPERATURE = 0.1
+_CN.MATCH_COARSE.SKIP_SOFTMAX = True # False for full model and True for optimized model
+_CN.MATCH_COARSE.FP16MATMUL = True # False for full model and True for optimized model
+_CN.MATCH_COARSE.TRAIN_COARSE_PERCENT = 0.2 # training tricks: save GPU memory
+_CN.MATCH_COARSE.TRAIN_PAD_NUM_GT_MIN = 200 # training tricks: avoid DDP deadlock
+
+# 4. Fine-Matching config
+_CN.MATCH_FINE = CN()
+_CN.MATCH_FINE.LOCAL_REGRESS_TEMPERATURE = 10.0 # use 10.0 as fine local regress temperature, not 1.0
+_CN.MATCH_FINE.LOCAL_REGRESS_SLICEDIM = 8
+
+opt_default_cfg = lower_config(_CN)
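+
+# Note: this optimized profile skips the coarse softmax (SKIP_SOFTMAX = True), so
+# the much larger MATCH_COARSE.THR = 25 thresholds un-normalized similarity scores
+# rather than the 0.2 probability threshold used by the full configuration.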
diff --git a/imcui/third_party/EfficientLoFTR/src/loftr/utils/position_encoding.py b/imcui/third_party/EfficientLoFTR/src/loftr/utils/position_encoding.py
new file mode 100644
index 0000000000000000000000000000000000000000..6431d0e2ce468fad5b0f0f6838c2ae2e5c089b32
--- /dev/null
+++ b/imcui/third_party/EfficientLoFTR/src/loftr/utils/position_encoding.py
@@ -0,0 +1,50 @@
+import math
+import torch
+from torch import nn
+
+class RoPEPositionEncodingSine(nn.Module):
+ """
+ This is a sinusoidal position encoding generalized to 2-dimensional images
+ """
+
+ def __init__(self, d_model, max_shape=(256, 256), npe=None, ropefp16=True):
+ """
+ Args:
+ max_shape (tuple): for 1/8 featmap, the max length of 256 corresponds to 2048 pixels
+ """
+ super().__init__()
+
+ i_position = torch.ones(max_shape).cumsum(0).float().unsqueeze(-1) # [H, W, 1]
+ j_position = torch.ones(max_shape).cumsum(1).float().unsqueeze(-1) # [H, W, 1]
+
+ assert npe is not None
+ train_res_H, train_res_W, test_res_H, test_res_W = npe[0], npe[1], npe[2], npe[3] # train_res_H, train_res_W, test_res_H, test_res_W
+ i_position, j_position = i_position * train_res_H / test_res_H, j_position * train_res_W / test_res_W
+
+ div_term = torch.exp(torch.arange(0, d_model//4, 1).float() * (-math.log(10000.0) / (d_model//4)))
+ div_term = div_term[None, None, :] # [1, 1, C//4]
+
+ sin = torch.zeros(*max_shape, d_model//2, dtype=torch.float16 if ropefp16 else torch.float32)
+ cos = torch.zeros(*max_shape, d_model//2, dtype=torch.float16 if ropefp16 else torch.float32)
+ sin[:, :, 0::2] = torch.sin(i_position * div_term).half() if ropefp16 else torch.sin(i_position * div_term)
+ sin[:, :, 1::2] = torch.sin(j_position * div_term).half() if ropefp16 else torch.sin(j_position * div_term)
+ cos[:, :, 0::2] = torch.cos(i_position * div_term).half() if ropefp16 else torch.cos(i_position * div_term)
+ cos[:, :, 1::2] = torch.cos(j_position * div_term).half() if ropefp16 else torch.cos(j_position * div_term)
+
+ sin = sin.repeat_interleave(2, dim=-1)
+ cos = cos.repeat_interleave(2, dim=-1)
+
+ self.register_buffer('sin', sin.unsqueeze(0), persistent=False) # [1, H, W, C]
+ self.register_buffer('cos', cos.unsqueeze(0), persistent=False) # [1, H, W, C]
+
+ def forward(self, x, ratio=1):
+ """
+ Args:
+ x: [N, H, W, C]
+ """
+ return (x * self.cos[:, :x.size(1), :x.size(2), :]) + (self.rotate_half(x) * self.sin[:, :x.size(1), :x.size(2), :])
+
+ def rotate_half(self, x):
+ x = x.unflatten(-1, (-1, 2))
+ x1, x2 = x.unbind(dim=-1)
+ return torch.stack((-x2, x1), dim=-1).flatten(start_dim=-2)
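+
+# Usage sketch (values are illustrative): for coarse features x of shape
+# [N, H, W, 256], rope = RoPEPositionEncodingSine(256, npe=[832, 832, 832, 832])
+# and rope(x) returns a tensor of the same shape with the rotary encoding applied.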
\ No newline at end of file
diff --git a/imcui/third_party/EfficientLoFTR/src/loftr/utils/supervision.py b/imcui/third_party/EfficientLoFTR/src/loftr/utils/supervision.py
new file mode 100644
index 0000000000000000000000000000000000000000..a1ae8036ef5a8108ddb7eab21bdc5efb26d356d8
--- /dev/null
+++ b/imcui/third_party/EfficientLoFTR/src/loftr/utils/supervision.py
@@ -0,0 +1,275 @@
+from math import log
+from loguru import logger as loguru_logger
+
+import torch
+import torch.nn.functional as F
+from einops import rearrange, repeat
+from kornia.utils import create_meshgrid
+from src.utils.plotting import make_matching_figures
+
+from .geometry import warp_kpts
+
+from kornia.geometry.subpix import dsnt
+from kornia.utils.grid import create_meshgrid
+
+def static_vars(**kwargs):
+ def decorate(func):
+ for k in kwargs:
+ setattr(func, k, kwargs[k])
+ return func
+ return decorate
+
+############## ↓ Coarse-Level supervision ↓ ##############
+
+
+@torch.no_grad()
+def mask_pts_at_padded_regions(grid_pt, mask):
+ """For megadepth dataset, zero-padding exists in images"""
+ mask = repeat(mask, 'n h w -> n (h w) c', c=2)
+ grid_pt[~mask.bool()] = 0
+ return grid_pt
+
+
+@torch.no_grad()
+def spvs_coarse(data, config):
+ """
+ Update:
+ data (dict): {
+ "conf_matrix_gt": [N, hw0, hw1],
+ 'spv_b_ids': [M]
+ 'spv_i_ids': [M]
+ 'spv_j_ids': [M]
+ 'spv_w_pt0_i': [N, hw0, 2], in original image resolution
+ 'spv_pt1_i': [N, hw1, 2], in original image resolution
+ }
+
+ NOTE:
+ - for scannet dataset, there're 3 kinds of resolution {i, c, f}
+ - for megadepth dataset, there're 4 kinds of resolution {i, i_resize, c, f}
+ """
+ # 1. misc
+ device = data['image0'].device
+ N, _, H0, W0 = data['image0'].shape
+ _, _, H1, W1 = data['image1'].shape
+ scale = config['LOFTR']['RESOLUTION'][0]
+ scale0 = scale * data['scale0'][:, None] if 'scale0' in data else scale
+ scale1 = scale * data['scale1'][:, None] if 'scale1' in data else scale
+ h0, w0, h1, w1 = map(lambda x: x // scale, [H0, W0, H1, W1])
+
+ # 2. warp grids
+ # create kpts in meshgrid and resize them to image resolution
+ grid_pt0_c = create_meshgrid(h0, w0, False, device).reshape(1, h0*w0, 2).repeat(N, 1, 1) # [N, hw, 2]
+ grid_pt0_i = scale0 * grid_pt0_c
+ grid_pt1_c = create_meshgrid(h1, w1, False, device).reshape(1, h1*w1, 2).repeat(N, 1, 1)
+ grid_pt1_i = scale1 * grid_pt1_c
+
+ # mask padded region to (0, 0), so no need to manually mask conf_matrix_gt
+ if 'mask0' in data:
+ grid_pt0_i = mask_pts_at_padded_regions(grid_pt0_i, data['mask0'])
+ grid_pt1_i = mask_pts_at_padded_regions(grid_pt1_i, data['mask1'])
+
+ # warp kpts bi-directionally and resize them to coarse-level resolution
+ # (no depth consistency check, since it leads to worse results experimentally)
+ # (unhandled edge case: points with 0-depth will be warped to the left-up corner)
+ _, w_pt0_i = warp_kpts(grid_pt0_i, data['depth0'], data['depth1'], data['T_0to1'], data['K0'], data['K1'])
+ _, w_pt1_i = warp_kpts(grid_pt1_i, data['depth1'], data['depth0'], data['T_1to0'], data['K1'], data['K0'])
+ w_pt0_c = w_pt0_i / scale1
+ w_pt1_c = w_pt1_i / scale0
+
+ # 3. check if mutual nearest neighbor
+ w_pt0_c_round = w_pt0_c[:, :, :].round()
+ # calculate the overlap area between warped patch and grid patch as the loss weight.
+ # (larger overlap area between warped patches and grid patch with higher weight)
+ # (the overlap area ranges over [0, 1] rather than [0.25, 1], as a penalty for warped kpts that fall on the midpoint of two grid kpts)
+ if config.LOFTR.LOSS.COARSE_OVERLAP_WEIGHT:
+ w_pt0_c_error = (1.0 - 2*torch.abs(w_pt0_c - w_pt0_c_round)).prod(-1)
+ w_pt0_c_round = w_pt0_c_round[:, :, :].long()
+ nearest_index1 = w_pt0_c_round[..., 0] + w_pt0_c_round[..., 1] * w1
+
+ w_pt1_c_round = w_pt1_c[:, :, :].round().long()
+ nearest_index0 = w_pt1_c_round[..., 0] + w_pt1_c_round[..., 1] * w0
+
+ # corner case: out of boundary
+ def out_bound_mask(pt, w, h):
+ return (pt[..., 0] < 0) + (pt[..., 0] >= w) + (pt[..., 1] < 0) + (pt[..., 1] >= h)
+ nearest_index1[out_bound_mask(w_pt0_c_round, w1, h1)] = 0
+ nearest_index0[out_bound_mask(w_pt1_c_round, w0, h0)] = 0
+
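+ # mutual-nearest-neighbour check: cell i of image0 is a gt coarse match only if
+ # warping it to image1 and warping that target cell back lands on i again.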
+ loop_back = torch.stack([nearest_index0[_b][_i] for _b, _i in enumerate(nearest_index1)], dim=0)
+ correct_0to1 = loop_back == torch.arange(h0*w0, device=device)[None].repeat(N, 1)
+ correct_0to1[:, 0] = False # ignore the top-left corner
+
+ # 4. construct a gt conf_matrix
+ conf_matrix_gt = torch.zeros(N, h0*w0, h1*w1, device=device)
+ b_ids, i_ids = torch.where(correct_0to1 != 0)
+ j_ids = nearest_index1[b_ids, i_ids]
+
+ conf_matrix_gt[b_ids, i_ids, j_ids] = 1
+ data.update({'conf_matrix_gt': conf_matrix_gt})
+
+ # use overlap area as loss weight
+ if config.LOFTR.LOSS.COARSE_OVERLAP_WEIGHT:
+ conf_matrix_error_gt = w_pt0_c_error[b_ids, i_ids] # weight range: [0.0, 1.0]
+ data.update({'conf_matrix_error_gt': conf_matrix_error_gt})
+
+
+ # 5. save coarse matches(gt) for training fine level
+ if len(b_ids) == 0:
+ loguru_logger.warning(f"No groundtruth coarse match found for: {data['pair_names']}")
+ # this won't affect fine-level loss calculation
+ b_ids = torch.tensor([0], device=device)
+ i_ids = torch.tensor([0], device=device)
+ j_ids = torch.tensor([0], device=device)
+
+ data.update({
+ 'spv_b_ids': b_ids,
+ 'spv_i_ids': i_ids,
+ 'spv_j_ids': j_ids
+ })
+
+ # 6. save intermediate results (for fast fine-level computation)
+ data.update({
+ 'spv_w_pt0_i': w_pt0_i,
+ 'spv_pt1_i': grid_pt1_i
+ })
+
+
+def compute_supervision_coarse(data, config):
+ assert len(set(data['dataset_name'])) == 1, "Do not support mixed datasets training!"
+ data_source = data['dataset_name'][0]
+ if data_source.lower() in ['scannet', 'megadepth']:
+ spvs_coarse(data, config)
+ else:
+ raise ValueError(f'Unknown data source: {data_source}')
+
+
+############## ↓ Fine-Level supervision ↓ ##############
+
+@static_vars(counter = 0)
+@torch.no_grad()
+def spvs_fine(data, config, logger = None):
+ """
+ Update:
+ data (dict):{
+ "expec_f_gt": [M, 2], used as subpixel-level gt
+ "conf_matrix_f_gt": [M, WW, WW], M is the number of all coarse-level gt matches
+ "conf_matrix_f_error_gt": [Mp], Mp is the number of all pixel-level gt matches
+ "m_ids_f": [Mp]
+ "i_ids_f": [Mp]
+ "j_ids_f_di": [Mp]
+ "j_ids_f_dj": [Mp]
+ }
+ """
+ # 1. misc
+ pt1_i = data['spv_pt1_i']
+ W = config['LOFTR']['FINE_WINDOW_SIZE']
+ WW = W*W
+ scale = config['LOFTR']['RESOLUTION'][1]
+ device = data['image0'].device
+ N, _, H0, W0 = data['image0'].shape
+ _, _, H1, W1 = data['image1'].shape
+ hf0, wf0, hf1, wf1 = data['hw0_f'][0], data['hw0_f'][1], data['hw1_f'][0], data['hw1_f'][1] # h, w of fine feature
+ assert not config.LOFTR.ALIGN_CORNER, 'only support training with align_corner=False for now.'
+
+ # 2. get coarse prediction
+ b_ids, i_ids, j_ids = data['b_ids'], data['i_ids'], data['j_ids']
+ scalei0 = scale * data['scale0'][b_ids] if 'scale0' in data else scale
+ scalei1 = scale * data['scale1'][b_ids] if 'scale1' in data else scale
+
+ # 3. compute gt
+ m = b_ids.shape[0]
+ if m == 0: # special case: there is no coarse gt
+ conf_matrix_f_gt = torch.zeros(m, WW, WW, device=device)
+
+ data.update({'conf_matrix_f_gt': conf_matrix_f_gt})
+ if config.LOFTR.LOSS.FINE_OVERLAP_WEIGHT:
+ conf_matrix_f_error_gt = torch.zeros(1, device=device)
+ data.update({'conf_matrix_f_error_gt': conf_matrix_f_error_gt})
+
+ data.update({'expec_f': torch.zeros(1, 2, device=device)})
+ data.update({'expec_f_gt': torch.zeros(1, 2, device=device)})
+ else:
+ grid_pt0_f = create_meshgrid(hf0, wf0, False, device) - W // 2 + 0.5 # [1, hf0, wf0, 2] # use fine coordinates
+ grid_pt0_f = rearrange(grid_pt0_f, 'n h w c -> n c h w')
+ # 1. unfold(crop) all local windows
+ if config.LOFTR.ALIGN_CORNER is False: # even windows
+ assert W==8
+ grid_pt0_f_unfold = F.unfold(grid_pt0_f, kernel_size=(W, W), stride=W, padding=0)
+ grid_pt0_f_unfold = rearrange(grid_pt0_f_unfold, 'n (c ww) l -> n l ww c', ww=W**2) # [1, hc0*wc0, W*W, 2]
+ grid_pt0_f_unfold = repeat(grid_pt0_f_unfold[0], 'l ww c -> N l ww c', N=N)
+
+ # 2. select only the predicted matches
+ grid_pt0_f_unfold = grid_pt0_f_unfold[data['b_ids'], data['i_ids']] # [m, ww, 2]
+ grid_pt0_f_unfold = scalei0[:,None,:] * grid_pt0_f_unfold # [m, ww, 2]
+
+ # 3. warp grids and get covisible & depth_consistent mask
+ correct_0to1_f = torch.zeros(m, WW, device=device, dtype=torch.bool)
+ w_pt0_i = torch.zeros(m, WW, 2, device=device, dtype=torch.float32)
+ for b in range(N):
+ mask = b_ids == b # mask of each batch
+ match = int(mask.sum())
+ correct_0to1_f_mask, w_pt0_i_mask = warp_kpts(grid_pt0_f_unfold[mask].reshape(1,-1,2), data['depth0'][[b],...],
+ data['depth1'][[b],...], data['T_0to1'][[b],...],
+ data['K0'][[b],...], data['K1'][[b],...]) # [k, WW], [k, WW, 2]
+ correct_0to1_f[mask] = correct_0to1_f_mask.reshape(match, WW)
+ w_pt0_i[mask] = w_pt0_i_mask.reshape(match, WW, 2)
+
+ # 4. calculate the gt index of pixel-level refinement
+ delta_w_pt0_i = w_pt0_i - pt1_i[b_ids, j_ids][:,None,:] # [m, WW, 2]
+ del b_ids, i_ids, j_ids
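+ # dividing by scalei1 converts the image-resolution offset into fine-feature
+ # pixels of image1; adding W // 2 - 0.5 shifts the window-centred offsets so
+ # valid positions fall in [0, W-1] (align_corner=False convention).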
+ delta_w_pt0_f = delta_w_pt0_i / scalei1[:,None,:] + W // 2 - 0.5
+ delta_w_pt0_f_round = delta_w_pt0_f[:, :, :].round()
+ if config.LOFTR.LOSS.FINE_OVERLAP_WEIGHT:
+ # calculate the overlap area between warped patch and grid patch as the loss weight.
+ w_pt0_f_error = (1.0 - 2*torch.abs(delta_w_pt0_f - delta_w_pt0_f_round)).prod(-1) # [0, 1]
+ delta_w_pt0_f_round = delta_w_pt0_f_round.long()
+
+ nearest_index1 = delta_w_pt0_f_round[..., 0] + delta_w_pt0_f_round[..., 1] * W # [m, WW]
+
+ # corner case: out of fine windows
+ def out_bound_mask(pt, w, h):
+ return (pt[..., 0] < 0) + (pt[..., 0] >= w) + (pt[..., 1] < 0) + (pt[..., 1] >= h)
+ ob_mask = out_bound_mask(delta_w_pt0_f_round, W, W)
+ nearest_index1[ob_mask] = 0
+ correct_0to1_f[ob_mask] = 0
+
+ m_ids, i_ids = torch.where(correct_0to1_f != 0)
+ j_ids = nearest_index1[m_ids, i_ids] # i_ids, j_ids range from [0, WW-1]
+ j_ids_di, j_ids_dj = j_ids // W, j_ids % W # further get the (i, j) index in fine windows of image1 (right image); j_ids_di, j_ids_dj range from [0, W-1]
+ m_ids, i_ids, j_ids_di, j_ids_dj = m_ids.to(torch.long), i_ids.to(torch.long), j_ids_di.to(torch.long), j_ids_dj.to(torch.long)
+
+ # expec_f_gt will be used as the gt of subpixel-level refinement
+ expec_f_gt = delta_w_pt0_f - delta_w_pt0_f_round
+
+ if m_ids.numel() == 0: # special case: there is no pixel-level gt
+ loguru_logger.warning(f"No groundtruth fine match found for local regress: {data['pair_names']}")
+ # this won't affect fine-level loss calculation
+ data.update({'expec_f': torch.zeros(1, 2, device=device)})
+ data.update({'expec_f_gt': torch.zeros(1, 2, device=device)})
+ else:
+ expec_f_gt = expec_f_gt[m_ids, i_ids]
+ data.update({"expec_f_gt": expec_f_gt})
+ data.update({"m_ids_f": m_ids,
+ "i_ids_f": i_ids,
+ "j_ids_f_di": j_ids_di,
+ "j_ids_f_dj": j_ids_dj
+ })
+
+ # 5. construct a pixel-level gt conf_matrix
+ conf_matrix_f_gt = torch.zeros(m, WW, WW, device=device, dtype=torch.bool)
+ conf_matrix_f_gt[m_ids, i_ids, j_ids] = 1
+ data.update({'conf_matrix_f_gt': conf_matrix_f_gt})
+ if config.LOFTR.LOSS.FINE_OVERLAP_WEIGHT:
+ # calculate the overlap area between warped pixel and grid pixel as the loss weight.
+ w_pt0_f_error = w_pt0_f_error[m_ids, i_ids]
+ data.update({'conf_matrix_f_error_gt': w_pt0_f_error})
+
+ if conf_matrix_f_gt.sum() == 0:
+ loguru_logger.info(f'no fine matches to supervise')
+
+def compute_supervision_fine(data, config, logger=None):
+ data_source = data['dataset_name'][0]
+ if data_source.lower() in ['scannet', 'megadepth']:
+ spvs_fine(data, config, logger)
+ else:
+ raise NotImplementedError
\ No newline at end of file
diff --git a/imcui/third_party/EfficientLoFTR/src/losses/loftr_loss.py b/imcui/third_party/EfficientLoFTR/src/losses/loftr_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..eea71e59c1b43111bfb0d24f704df1a90bb66a03
--- /dev/null
+++ b/imcui/third_party/EfficientLoFTR/src/losses/loftr_loss.py
@@ -0,0 +1,229 @@
+from loguru import logger
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from kornia.geometry.subpix import dsnt
+from kornia.utils.grid import create_meshgrid
+
+
+class LoFTRLoss(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config # config under the global namespace
+
+ self.loss_config = config['loftr']['loss']
+ self.match_type = 'dual_softmax'
+ self.sparse_spvs = self.config['loftr']['match_coarse']['sparse_spvs']
+ self.fine_sparse_spvs = self.config['loftr']['match_fine']['sparse_spvs']
+
+ # coarse-level
+ self.correct_thr = self.loss_config['fine_correct_thr']
+ self.c_pos_w = self.loss_config['pos_weight']
+ self.c_neg_w = self.loss_config['neg_weight']
+ # coarse_overlap_weight
+ self.overlap_weightc = self.config['loftr']['loss']['coarse_overlap_weight']
+ self.overlap_weightf = self.config['loftr']['loss']['fine_overlap_weight']
+ # subpixel-level
+ self.local_regressw = self.config['loftr']['fine_window_size']
+ self.local_regress_temperature = self.config['loftr']['match_fine']['local_regress_temperature']
+
+
+ def compute_coarse_loss(self, conf, conf_gt, weight=None, overlap_weight=None):
+ """ Point-wise CE / Focal Loss with 0 / 1 confidence as gt.
+ Args:
+ conf (torch.Tensor): (N, HW0, HW1) / (N, HW0+1, HW1+1)
+ conf_gt (torch.Tensor): (N, HW0, HW1)
+ weight (torch.Tensor): (N, HW0, HW1)
+ """
+ pos_mask, neg_mask = conf_gt == 1, conf_gt == 0
+ del conf_gt
+ # logger.info(f'real sum of conf_matrix_c_gt: {pos_mask.sum().item()}')
+ c_pos_w, c_neg_w = self.c_pos_w, self.c_neg_w
+ # corner case: no gt coarse-level match at all
+ if not pos_mask.any(): # assign a wrong gt
+ pos_mask[0, 0, 0] = True
+ if weight is not None:
+ weight[0, 0, 0] = 0.
+ c_pos_w = 0.
+ if not neg_mask.any():
+ neg_mask[0, 0, 0] = True
+ if weight is not None:
+ weight[0, 0, 0] = 0.
+ c_neg_w = 0.
+
+ if self.loss_config['coarse_type'] == 'focal':
+ conf = torch.clamp(conf, 1e-6, 1-1e-6)
+ alpha = self.loss_config['focal_alpha']
+ gamma = self.loss_config['focal_gamma']
+
+ if self.sparse_spvs:
+ pos_conf = conf[pos_mask]
+ loss_pos = - alpha * torch.pow(1 - pos_conf, gamma) * pos_conf.log()
+ # handle loss weights
+ if weight is not None:
+ # Different from dense-spvs, the loss w.r.t. padded regions isn't directly zeroed out,
+ # but only handled by manually setting the corresponding regions in sim_matrix to '-inf'.
+ loss_pos = loss_pos * weight[pos_mask]
+ if self.overlap_weightc:
+ loss_pos = loss_pos * overlap_weight # already masked and sliced in supervision
+
+ loss = c_pos_w * loss_pos.mean()
+ return loss
+ else: # dense supervision
+ loss_pos = - alpha * torch.pow(1 - conf[pos_mask], gamma) * (conf[pos_mask]).log()
+ loss_neg = - alpha * torch.pow(conf[neg_mask], gamma) * (1 - conf[neg_mask]).log()
+ logger.info("conf_pos_c: {loss_pos}, conf_neg_c: {loss_neg}".format(loss_pos=conf[pos_mask].mean(), loss_neg=conf[neg_mask].mean()))
+ if weight is not None:
+ loss_pos = loss_pos * weight[pos_mask]
+ loss_neg = loss_neg * weight[neg_mask]
+ if self.overlap_weightc:
+ loss_pos = loss_pos * overlap_weight # already masked and sliced in supervision
+
+ loss_pos_mean, loss_neg_mean = loss_pos.mean(), loss_neg.mean()
+ logger.info("conf_pos_c: {loss_pos}, conf_neg_c: {loss_neg}".format(loss_pos=conf[pos_mask].mean(), loss_neg=conf[neg_mask].mean()))
+ return c_pos_w * loss_pos_mean + c_neg_w * loss_neg_mean
+ # each negative element occupies a smaller proportion than positive elements => a higher negative loss weight is needed
+ else:
+ raise ValueError('Unknown coarse loss: {type}'.format(type=self.loss_config['coarse_type']))
+
+ def compute_fine_loss(self, conf_matrix_f, conf_matrix_f_gt, overlap_weight=None):
+ """
+ Args:
+ conf_matrix_f (torch.Tensor): [m, WW, WW]
+ conf_matrix_f_gt (torch.Tensor): [m, WW, WW]
+ """
+ if conf_matrix_f_gt.shape[0] == 0:
+ if self.training: # this seldom happens during training, since we pad prediction with gt
+ # sometimes there is no coarse-level gt at all.
+ logger.warning("assign a false supervision to avoid ddp deadlock")
+ pass
+ else:
+ return None
+ pos_mask, neg_mask = conf_matrix_f_gt == 1, conf_matrix_f_gt == 0
+ del conf_matrix_f_gt
+ c_pos_w, c_neg_w = self.c_pos_w, self.c_neg_w
+
+ if not pos_mask.any(): # assign a wrong gt
+ pos_mask[0, 0, 0] = True
+ c_pos_w = 0.
+ if not neg_mask.any():
+ neg_mask[0, 0, 0] = True
+ c_neg_w = 0.
+
+ conf = torch.clamp(conf_matrix_f, 1e-6, 1-1e-6)
+ alpha = self.loss_config['focal_alpha']
+ gamma = self.loss_config['focal_gamma']
+
+ if self.fine_sparse_spvs:
+ loss_pos = - alpha * torch.pow(1 - conf[pos_mask], gamma) * (conf[pos_mask]).log()
+ if self.overlap_weightf:
+ loss_pos = loss_pos * overlap_weight # already masked and sliced in supervision
+ return c_pos_w * loss_pos.mean()
+ else:
+ loss_pos = - alpha * torch.pow(1 - conf[pos_mask], gamma) * (conf[pos_mask]).log()
+ loss_neg = - alpha * torch.pow(conf[neg_mask], gamma) * (1 - conf[neg_mask]).log()
+ logger.info("conf_pos_f: {loss_pos}, conf_neg_f: {loss_neg}".format(loss_pos=conf[pos_mask].mean(), loss_neg=conf[neg_mask].mean()))
+ if self.overlap_weightf:
+ loss_pos = loss_pos * overlap_weight # already masked and sliced in supervision
+
+ return c_pos_w * loss_pos.mean() + c_neg_w * loss_neg.mean()
+
+
+ def _compute_local_loss_l2(self, expec_f, expec_f_gt):
+ """
+ Args:
+ expec_f (torch.Tensor): [M, 2]
+ expec_f_gt (torch.Tensor): [M, 2]
+ """
+ correct_mask = torch.linalg.norm(expec_f_gt, ord=float('inf'), dim=1) < self.correct_thr
+ if correct_mask.sum() == 0:
+ if self.training: # this seldom happens during training, since we pad prediction with gt
+ logger.warning("assign a false supervision to avoid ddp deadlock")
+ correct_mask[0] = True
+ else:
+ return None
+ offset_l2 = ((expec_f_gt[correct_mask] - expec_f[correct_mask]) ** 2).sum(-1)
+ return offset_l2.mean()
+
+ @torch.no_grad()
+ def compute_c_weight(self, data):
+ """ compute element-wise weights for computing coarse-level loss. """
+ if 'mask0' in data:
+ c_weight = (data['mask0'].flatten(-2)[..., None] * data['mask1'].flatten(-2)[:, None])
+ else:
+ c_weight = None
+ return c_weight
+
+ def forward(self, data):
+ """
+ Update:
+ data (dict): update{
+ 'loss': [1] the reduced loss across a batch,
+ 'loss_scalars' (dict): loss scalars for tensorboard_record
+ }
+ """
+ loss_scalars = {}
+ # 0. compute element-wise loss weight
+ c_weight = self.compute_c_weight(data)
+
+ # 1. coarse-level loss
+ if self.overlap_weightc:
+ loss_c = self.compute_coarse_loss(
+ data['conf_matrix_with_bin'] if self.sparse_spvs and self.match_type == 'sinkhorn' \
+ else data['conf_matrix'],
+ data['conf_matrix_gt'],
+ weight=c_weight, overlap_weight=data['conf_matrix_error_gt'])
+
+ else:
+ loss_c = self.compute_coarse_loss(
+ data['conf_matrix_with_bin'] if self.sparse_spvs and self.match_type == 'sinkhorn' \
+ else data['conf_matrix'],
+ data['conf_matrix_gt'],
+ weight=c_weight)
+
+ loss = loss_c * self.loss_config['coarse_weight']
+ loss_scalars.update({"loss_c": loss_c.clone().detach().cpu()})
+
+ # 2. pixel-level loss (first-stage refinement)
+ if self.overlap_weightf:
+ loss_f = self.compute_fine_loss(data['conf_matrix_f'], data['conf_matrix_f_gt'], data['conf_matrix_f_error_gt'])
+ else:
+ loss_f = self.compute_fine_loss(data['conf_matrix_f'], data['conf_matrix_f_gt'])
+ if loss_f is not None:
+ loss += loss_f * self.loss_config['fine_weight']
+ loss_scalars.update({"loss_f": loss_f.clone().detach().cpu()})
+ else:
+ assert self.training is False
+ loss_scalars.update({'loss_f': torch.tensor(1.)}) # 1 is the upper bound
+
+ # 3. subpixel-level loss (second-stage refinement)
+ # we calculate subpixel-level loss for all pixel-level gt
+ if 'expec_f' not in data:
+ sim_matrix_f, m_ids, i_ids, j_ids_di, j_ids_dj = data['sim_matrix_ff'], data['m_ids_f'], data['i_ids_f'], data['j_ids_f_di'], data['j_ids_f_dj']
+ del data['sim_matrix_ff'], data['m_ids_f'], data['i_ids_f'], data['j_ids_f_di'], data['j_ids_f_dj']
+ delta = create_meshgrid(3, 3, True, sim_matrix_f.device).to(torch.long) # [1, 3, 3, 2]
+ m_ids = m_ids[...,None,None].expand(-1, 3, 3)
+ i_ids = i_ids[...,None,None].expand(-1, 3, 3)
+ # Note that j_ids_di & j_ids_dj in (i, j) format while delta in (x, y) format
+ j_ids_di = j_ids_di[...,None,None].expand(-1, 3, 3) + delta[None, ..., 1]
+ j_ids_dj = j_ids_dj[...,None,None].expand(-1, 3, 3) + delta[None, ..., 0]
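+ # each pixel-level gt match now addresses a 3x3 neighbourhood of its fine window
+ # inside the (W+2) x (W+2) similarity map gathered below, which is then turned
+ # into a local softmax heatmap for subpixel regression.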
+
+ sim_matrix_f = sim_matrix_f.reshape(-1, self.local_regressw*self.local_regressw, self.local_regressw+2, self.local_regressw+2) # [M, WW, W+2, W+2]
+ sim_matrix_f = sim_matrix_f[m_ids, i_ids, j_ids_di, j_ids_dj]
+ sim_matrix_f = sim_matrix_f.reshape(-1, 9)
+
+ sim_matrix_f = F.softmax(sim_matrix_f / self.local_regress_temperature, dim=-1)
+ heatmap = sim_matrix_f.reshape(-1, 3, 3)
+
+ # compute coordinates from heatmap
+ coords_normalized = dsnt.spatial_expectation2d(heatmap[None], True)[0]
+ data.update({'expec_f': coords_normalized})
+ loss_l = self._compute_local_loss_l2(data['expec_f'], data['expec_f_gt'])
+
+ loss += loss_l * self.loss_config['local_weight']
+ loss_scalars.update({"loss_l": loss_l.clone().detach().cpu()})
+
+ loss_scalars.update({'loss': loss.clone().detach().cpu()})
+ data.update({"loss": loss, "loss_scalars": loss_scalars})
\ No newline at end of file
diff --git a/imcui/third_party/EfficientLoFTR/src/optimizers/__init__.py b/imcui/third_party/EfficientLoFTR/src/optimizers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e1db2285352586c250912bdd2c4ae5029620ab5f
--- /dev/null
+++ b/imcui/third_party/EfficientLoFTR/src/optimizers/__init__.py
@@ -0,0 +1,42 @@
+import torch
+from torch.optim.lr_scheduler import MultiStepLR, CosineAnnealingLR, ExponentialLR
+
+
+def build_optimizer(model, config):
+ name = config.TRAINER.OPTIMIZER
+ lr = config.TRAINER.TRUE_LR
+
+ if name == "adam":
+ return torch.optim.Adam(model.parameters(), lr=lr, weight_decay=config.TRAINER.ADAM_DECAY)
+ elif name == "adamw":
+ return torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=config.TRAINER.ADAMW_DECAY)
+ else:
+ raise ValueError(f"TRAINER.OPTIMIZER = {name} is not a valid optimizer!")
+
+
+def build_scheduler(config, optimizer):
+ """
+ Returns:
+ scheduler (dict):{
+ 'scheduler': lr_scheduler,
+ 'interval': 'step', # or 'epoch'
+ 'monitor': 'val_f1', (optional)
+ 'frequency': x, (optional)
+ }
+ """
+ scheduler = {'interval': config.TRAINER.SCHEDULER_INTERVAL}
+ name = config.TRAINER.SCHEDULER
+
+ if name == 'MultiStepLR':
+ scheduler.update(
+ {'scheduler': MultiStepLR(optimizer, config.TRAINER.MSLR_MILESTONES, gamma=config.TRAINER.MSLR_GAMMA)})
+ elif name == 'CosineAnnealing':
+ scheduler.update(
+ {'scheduler': CosineAnnealingLR(optimizer, config.TRAINER.COSA_TMAX)})
+ elif name == 'ExponentialLR':
+ scheduler.update(
+ {'scheduler': ExponentialLR(optimizer, config.TRAINER.ELR_GAMMA)})
+ else:
+ raise NotImplementedError()
+
+ return scheduler
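+
+
+# Usage sketch (config values are hypothetical): with TRAINER.OPTIMIZER = 'adamw',
+# TRAINER.TRUE_LR = 8e-3 and TRAINER.SCHEDULER = 'MultiStepLR',
+#   optimizer = build_optimizer(model, config)
+#   scheduler = build_scheduler(config, optimizer)
+# produces the optimizer plus the scheduler dict expected by a Lightning-style
+# configure_optimizers() hook.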
diff --git a/imcui/third_party/EfficientLoFTR/src/utils/augment.py b/imcui/third_party/EfficientLoFTR/src/utils/augment.py
new file mode 100644
index 0000000000000000000000000000000000000000..d7c5d3e11b6fe083aaeff7555bb7ce3a4bfb755d
--- /dev/null
+++ b/imcui/third_party/EfficientLoFTR/src/utils/augment.py
@@ -0,0 +1,55 @@
+import albumentations as A
+
+
+class DarkAug(object):
+ """
+ Extreme dark augmentation aiming at Aachen Day-Night
+ """
+
+ def __init__(self) -> None:
+ self.augmentor = A.Compose([
+ A.RandomBrightnessContrast(p=0.75, brightness_limit=(-0.6, 0.0), contrast_limit=(-0.5, 0.3)),
+ A.Blur(p=0.1, blur_limit=(3, 9)),
+ A.MotionBlur(p=0.2, blur_limit=(3, 25)),
+ A.RandomGamma(p=0.1, gamma_limit=(15, 65)),
+ A.HueSaturationValue(p=0.1, val_shift_limit=(-100, -40))
+ ], p=0.75)
+
+ def __call__(self, x):
+ return self.augmentor(image=x)['image']
+
+
+class MobileAug(object):
+ """
+ Random augmentations aiming at images of mobile/handheld devices.
+ """
+
+ def __init__(self):
+ self.augmentor = A.Compose([
+ A.MotionBlur(p=0.25),
+ A.ColorJitter(p=0.5),
+ A.RandomRain(p=0.1), # random occlusion
+ A.RandomSunFlare(p=0.1),
+ A.JpegCompression(p=0.25),
+ A.ISONoise(p=0.25)
+ ], p=1.0)
+
+ def __call__(self, x):
+ return self.augmentor(image=x)['image']
+
+
+def build_augmentor(method=None, **kwargs):
+ if method is not None:
+ raise NotImplementedError('Use of augmentation functions is not supported yet!')
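+ # NOTE: with the guard above, any non-None method raises before reaching the
+ # branches below, so DarkAug / MobileAug are effectively disabled here.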
+ if method == 'dark':
+ return DarkAug()
+ elif method == 'mobile':
+ return MobileAug()
+ elif method is None:
+ return None
+ else:
+ raise ValueError(f'Invalid augmentation method: {method}')
+
+
+if __name__ == '__main__':
+ augmentor = build_augmentor('FDA')
diff --git a/imcui/third_party/EfficientLoFTR/src/utils/comm.py b/imcui/third_party/EfficientLoFTR/src/utils/comm.py
new file mode 100644
index 0000000000000000000000000000000000000000..26ec9517cc47e224430106d8ae9aa99a3fe49167
--- /dev/null
+++ b/imcui/third_party/EfficientLoFTR/src/utils/comm.py
@@ -0,0 +1,265 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+"""
+[Copied from detectron2]
+This file contains primitives for multi-gpu communication.
+This is useful when doing distributed training.
+"""
+
+import functools
+import logging
+import numpy as np
+import pickle
+import torch
+import torch.distributed as dist
+
+_LOCAL_PROCESS_GROUP = None
+"""
+ A torch process group which only includes processes that are on the same machine as the current process.
+This variable is set when processes are spawned by `launch()` in "engine/launch.py".
+"""
+
+
+def get_world_size() -> int:
+ if not dist.is_available():
+ return 1
+ if not dist.is_initialized():
+ return 1
+ return dist.get_world_size()
+
+
+def get_rank() -> int:
+ if not dist.is_available():
+ return 0
+ if not dist.is_initialized():
+ return 0
+ return dist.get_rank()
+
+
+def get_local_rank() -> int:
+ """
+ Returns:
+ The rank of the current process within the local (per-machine) process group.
+ """
+ if not dist.is_available():
+ return 0
+ if not dist.is_initialized():
+ return 0
+ assert _LOCAL_PROCESS_GROUP is not None
+ return dist.get_rank(group=_LOCAL_PROCESS_GROUP)
+
+
+def get_local_size() -> int:
+ """
+ Returns:
+ The size of the per-machine process group,
+ i.e. the number of processes per machine.
+ """
+ if not dist.is_available():
+ return 1
+ if not dist.is_initialized():
+ return 1
+ return dist.get_world_size(group=_LOCAL_PROCESS_GROUP)
+
+
+def is_main_process() -> bool:
+ return get_rank() == 0
+
+
+def synchronize():
+ """
+ Helper function to synchronize (barrier) among all processes when
+ using distributed training
+ """
+ if not dist.is_available():
+ return
+ if not dist.is_initialized():
+ return
+ world_size = dist.get_world_size()
+ if world_size == 1:
+ return
+ dist.barrier()
+
+
+@functools.lru_cache()
+def _get_global_gloo_group():
+ """
+ Return a process group based on gloo backend, containing all the ranks
+ The result is cached.
+ """
+ if dist.get_backend() == "nccl":
+ return dist.new_group(backend="gloo")
+ else:
+ return dist.group.WORLD
+
+
+def _serialize_to_tensor(data, group):
+ backend = dist.get_backend(group)
+ assert backend in ["gloo", "nccl"]
+ device = torch.device("cpu" if backend == "gloo" else "cuda")
+
+ buffer = pickle.dumps(data)
+ if len(buffer) > 1024 ** 3:
+ logger = logging.getLogger(__name__)
+ logger.warning(
+ "Rank {} trying to all-gather {:.2f} GB of data on device {}".format(
+ get_rank(), len(buffer) / (1024 ** 3), device
+ )
+ )
+ storage = torch.ByteStorage.from_buffer(buffer)
+ tensor = torch.ByteTensor(storage).to(device=device)
+ return tensor
+
+
+def _pad_to_largest_tensor(tensor, group):
+ """
+ Returns:
+ list[int]: size of the tensor, on each rank
+ Tensor: padded tensor that has the max size
+ """
+ world_size = dist.get_world_size(group=group)
+ assert (
+ world_size >= 1
+ ), "comm.gather/all_gather must be called from ranks within the given group!"
+ local_size = torch.tensor([tensor.numel()], dtype=torch.int64, device=tensor.device)
+ size_list = [
+ torch.zeros([1], dtype=torch.int64, device=tensor.device) for _ in range(world_size)
+ ]
+ dist.all_gather(size_list, local_size, group=group)
+
+ size_list = [int(size.item()) for size in size_list]
+
+ max_size = max(size_list)
+
+ # we pad the tensor because torch all_gather does not support
+ # gathering tensors of different shapes
+ if local_size != max_size:
+ padding = torch.zeros((max_size - local_size,), dtype=torch.uint8, device=tensor.device)
+ tensor = torch.cat((tensor, padding), dim=0)
+ return size_list, tensor
+
+
+def all_gather(data, group=None):
+ """
+ Run all_gather on arbitrary picklable data (not necessarily tensors).
+
+ Args:
+ data: any picklable object
+ group: a torch process group. By default, will use a group which
+ contains all ranks on gloo backend.
+
+ Returns:
+ list[data]: list of data gathered from each rank
+ """
+ if get_world_size() == 1:
+ return [data]
+ if group is None:
+ group = _get_global_gloo_group()
+ if dist.get_world_size(group) == 1:
+ return [data]
+
+ tensor = _serialize_to_tensor(data, group)
+
+ size_list, tensor = _pad_to_largest_tensor(tensor, group)
+ max_size = max(size_list)
+
+ # receiving Tensor from all ranks
+ tensor_list = [
+ torch.empty((max_size,), dtype=torch.uint8, device=tensor.device) for _ in size_list
+ ]
+ dist.all_gather(tensor_list, tensor, group=group)
+
+ data_list = []
+ for size, tensor in zip(size_list, tensor_list):
+ buffer = tensor.cpu().numpy().tobytes()[:size]
+ data_list.append(pickle.loads(buffer))
+
+ return data_list
+
+
+def gather(data, dst=0, group=None):
+ """
+ Run gather on arbitrary picklable data (not necessarily tensors).
+
+ Args:
+ data: any picklable object
+ dst (int): destination rank
+ group: a torch process group. By default, will use a group which
+ contains all ranks on gloo backend.
+
+ Returns:
+ list[data]: on dst, a list of data gathered from each rank. Otherwise,
+ an empty list.
+ """
+ if get_world_size() == 1:
+ return [data]
+ if group is None:
+ group = _get_global_gloo_group()
+ if dist.get_world_size(group=group) == 1:
+ return [data]
+ rank = dist.get_rank(group=group)
+
+ tensor = _serialize_to_tensor(data, group)
+ size_list, tensor = _pad_to_largest_tensor(tensor, group)
+
+ # receiving Tensor from all ranks
+ if rank == dst:
+ max_size = max(size_list)
+ tensor_list = [
+ torch.empty((max_size,), dtype=torch.uint8, device=tensor.device) for _ in size_list
+ ]
+ dist.gather(tensor, tensor_list, dst=dst, group=group)
+
+ data_list = []
+ for size, tensor in zip(size_list, tensor_list):
+ buffer = tensor.cpu().numpy().tobytes()[:size]
+ data_list.append(pickle.loads(buffer))
+ return data_list
+ else:
+ dist.gather(tensor, [], dst=dst, group=group)
+ return []
+
+
+def shared_random_seed():
+ """
+ Returns:
+ int: a random number that is the same across all workers.
+ If workers need a shared RNG, they can use this shared seed to
+ create one.
+
+ All workers must call this function, otherwise it will deadlock.
+ """
+ ints = np.random.randint(2 ** 31)
+ all_ints = all_gather(ints)
+ return all_ints[0]
+
+
+def reduce_dict(input_dict, average=True):
+ """
+ Reduce the values in the dictionary from all processes so that process with rank
+ 0 has the reduced results.
+
+ Args:
+ input_dict (dict): inputs to be reduced. All the values must be scalar CUDA Tensor.
+ average (bool): whether to do average or sum
+
+ Returns:
+ a dict with the same keys as input_dict, after reduction.
+ """
+ world_size = get_world_size()
+ if world_size < 2:
+ return input_dict
+ with torch.no_grad():
+ names = []
+ values = []
+ # sort the keys so that they are consistent across processes
+ for k in sorted(input_dict.keys()):
+ names.append(k)
+ values.append(input_dict[k])
+ values = torch.stack(values, dim=0)
+ dist.reduce(values, dst=0)
+ if dist.get_rank() == 0 and average:
+ # only main process gets accumulated, so only divide by
+ # world_size in this case
+ values /= world_size
+ reduced_dict = {k: v for k, v in zip(names, values)}
+ return reduced_dict
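+
+
+# Usage sketch (illustrative): inside a DDP evaluation loop,
+#   gathered = all_gather(local_metrics)   # list with one picklable entry per rank
+#   reduced = reduce_dict({'loss': loss})  # scalar CUDA tensors averaged onto rank 0
+# lets rank 0 aggregate results across processes.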
diff --git a/imcui/third_party/EfficientLoFTR/src/utils/dataloader.py b/imcui/third_party/EfficientLoFTR/src/utils/dataloader.py
new file mode 100644
index 0000000000000000000000000000000000000000..6da37b880a290c2bb3ebb028d0c8dab592acc5c1
--- /dev/null
+++ b/imcui/third_party/EfficientLoFTR/src/utils/dataloader.py
@@ -0,0 +1,23 @@
+import numpy as np
+
+
+# --- PL-DATAMODULE ---
+
+def get_local_split(items: list, world_size: int, rank: int, seed: int):
+ """ The local rank only loads a split of the dataset. """
+ n_items = len(items)
+ items_permute = np.random.RandomState(seed).permutation(items)
+ if n_items % world_size == 0:
+ padded_items = items_permute
+ else:
+ padding = np.random.RandomState(seed).choice(
+ items,
+ world_size - (n_items % world_size),
+ replace=True)
+ padded_items = np.concatenate([items_permute, padding])
+ assert len(padded_items) % world_size == 0, \
+ f'len(padded_items): {len(padded_items)}; world_size: {world_size}; len(padding): {len(padding)}'
+ n_per_rank = len(padded_items) // world_size
+ local_items = padded_items[n_per_rank * rank: n_per_rank * (rank+1)]
+
+ return local_items
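+
+
+# Worked example (illustrative): with 10 items, world_size=4 and a fixed seed, the
+# permuted list is padded with 2 resampled items to length 12, so each rank
+# receives exactly 3 items.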
diff --git a/imcui/third_party/EfficientLoFTR/src/utils/dataset.py b/imcui/third_party/EfficientLoFTR/src/utils/dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..37831292d08b5e9f13eeb0dee64ae8882f52a63f
--- /dev/null
+++ b/imcui/third_party/EfficientLoFTR/src/utils/dataset.py
@@ -0,0 +1,186 @@
+import io
+from loguru import logger
+
+import cv2
+import numpy as np
+import h5py
+import torch
+from numpy.linalg import inv
+
+
+try:
+ # for internal use only
+ from .client import MEGADEPTH_CLIENT, SCANNET_CLIENT
+except Exception:
+ MEGADEPTH_CLIENT = SCANNET_CLIENT = None
+
+# --- DATA IO ---
+
+def load_array_from_s3(
+ path, client, cv_type,
+ use_h5py=False,
+):
+ byte_str = client.Get(path)
+ try:
+ if not use_h5py:
+ raw_array = np.frombuffer(byte_str, np.uint8) # np.fromstring is deprecated for binary data
+ data = cv2.imdecode(raw_array, cv_type)
+ else:
+ f = io.BytesIO(byte_str)
+ data = np.array(h5py.File(f, 'r')['/depth'])
+ except Exception as ex:
+ print(f"==> Data loading failure: {path}")
+ raise ex
+
+ assert data is not None
+ return data
+
+
+def imread_gray(path, augment_fn=None, client=SCANNET_CLIENT):
+ cv_type = cv2.IMREAD_GRAYSCALE if augment_fn is None \
+ else cv2.IMREAD_COLOR
+ if str(path).startswith('s3://'):
+ image = load_array_from_s3(str(path), client, cv_type)
+ else:
+ image = cv2.imread(str(path), cv_type)
+
+ if augment_fn is not None:
+ image = cv2.imread(str(path), cv2.IMREAD_COLOR)
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+ image = augment_fn(image)
+ image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
+ return image # (h, w)
+
+
+def get_resized_wh(w, h, resize=None):
+ if resize is not None: # resize the longer edge
+ scale = resize / max(h, w)
+ w_new, h_new = int(round(w*scale)), int(round(h*scale))
+ else:
+ w_new, h_new = w, h
+ return w_new, h_new
+
+
+def get_divisible_wh(w, h, df=None):
+ if df is not None:
+ w_new, h_new = map(lambda x: int(x // df * df), [w, h])
+ else:
+ w_new, h_new = w, h
+ return w_new, h_new
+
+
+def pad_bottom_right(inp, pad_size, ret_mask=False):
+ assert isinstance(pad_size, int) and pad_size >= max(inp.shape[-2:]), f"{pad_size} < {max(inp.shape[-2:])}"
+ mask = None
+ if inp.ndim == 2:
+ padded = np.zeros((pad_size, pad_size), dtype=inp.dtype)
+ padded[:inp.shape[0], :inp.shape[1]] = inp
+ if ret_mask:
+ mask = np.zeros((pad_size, pad_size), dtype=bool)
+ mask[:inp.shape[0], :inp.shape[1]] = True
+ elif inp.ndim == 3:
+ padded = np.zeros((inp.shape[0], pad_size, pad_size), dtype=inp.dtype)
+ padded[:, :inp.shape[1], :inp.shape[2]] = inp
+ if ret_mask:
+ mask = np.zeros((inp.shape[0], pad_size, pad_size), dtype=bool)
+ mask[:, :inp.shape[1], :inp.shape[2]] = True
+ else:
+ raise NotImplementedError()
+ return padded, mask
+
+
+# --- MEGADEPTH ---
+
+def read_megadepth_gray(path, resize=None, df=None, padding=False, augment_fn=None):
+ """
+ Args:
+ resize (int, optional): the longer edge of resized images. None for no resize.
+ padding (bool): If set to 'True', zero-pad resized images to squared size.
+ augment_fn (callable, optional): augments images with pre-defined visual effects
+ Returns:
+ image (torch.tensor): (1, h, w)
+ mask (torch.tensor): (h, w)
+ scale (torch.tensor): [w/w_new, h/h_new]
+ """
+ # read image
+ image = imread_gray(path, augment_fn, client=MEGADEPTH_CLIENT)
+
+ # resize image
+ w, h = image.shape[1], image.shape[0]
+ w_new, h_new = get_resized_wh(w, h, resize)
+ w_new, h_new = get_divisible_wh(w_new, h_new, df)
+
+ image = cv2.resize(image, (w_new, h_new))
+ scale = torch.tensor([w/w_new, h/h_new], dtype=torch.float)
+
+ if padding: # padding
+ pad_to = max(h_new, w_new)
+ image, mask = pad_bottom_right(image, pad_to, ret_mask=True)
+ else:
+ mask = None
+
+ image = torch.from_numpy(image).float()[None] / 255 # (h, w) -> (1, h, w) and normalized
+ if mask is not None:
+ mask = torch.from_numpy(mask)
+
+ return image, mask, scale
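+
+
+# Usage sketch (path and values are illustrative):
+#   image, mask, scale = read_megadepth_gray('assets/pair0.jpg', resize=832, df=32, padding=True)
+# yields a normalized (1, h, w) tensor, the zero-padding mask and the
+# [w/w_new, h/h_new] factors used to map matches back to the original resolution.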
+
+
+def read_megadepth_depth(path, pad_to=None):
+ if str(path).startswith('s3://'):
+ depth = load_array_from_s3(path, MEGADEPTH_CLIENT, None, use_h5py=True)
+ else:
+ depth = np.array(h5py.File(path, 'r')['depth'])
+ if pad_to is not None:
+ depth, _ = pad_bottom_right(depth, pad_to, ret_mask=False)
+ depth = torch.from_numpy(depth).float() # (h, w)
+ return depth
+
+
+# --- ScanNet ---
+
+def read_scannet_gray(path, resize=(640, 480), augment_fn=None):
+ """
+ Args:
+ resize (tuple): align image to depthmap, in (w, h).
+ augment_fn (callable, optional): augments images with pre-defined visual effects
+ Returns:
+ image (torch.tensor): (1, h, w)
+ mask (torch.tensor): (h, w)
+ scale (torch.tensor): [w/w_new, h/h_new]
+ """
+ # read and resize image
+ image = imread_gray(path, augment_fn)
+ image = cv2.resize(image, resize)
+
+ # (h, w) -> (1, h, w) and normalized
+ image = torch.from_numpy(image).float()[None] / 255
+ return image
+
+
+def read_scannet_depth(path):
+ if str(path).startswith('s3://'):
+ depth = load_array_from_s3(str(path), SCANNET_CLIENT, cv2.IMREAD_UNCHANGED)
+ else:
+ depth = cv2.imread(str(path), cv2.IMREAD_UNCHANGED)
+ depth = depth / 1000
+ depth = torch.from_numpy(depth).float() # (h, w)
+ return depth
+
+
+def read_scannet_pose(path):
+ """ Read ScanNet's Camera2World pose and transform it to World2Camera.
+
+ Returns:
+ pose_w2c (np.ndarray): (4, 4)
+ """
+ cam2world = np.loadtxt(path, delimiter=' ')
+ world2cam = inv(cam2world)
+ return world2cam
+
+
+def read_scannet_intrinsic(path):
+ """ Read ScanNet's intrinsic matrix and return the 3x3 matrix.
+ """
+ intrinsic = np.loadtxt(path, delimiter=' ')
+ return intrinsic[:-1, :-1]
diff --git a/imcui/third_party/EfficientLoFTR/src/utils/metrics.py b/imcui/third_party/EfficientLoFTR/src/utils/metrics.py
new file mode 100644
index 0000000000000000000000000000000000000000..1cc37896f363567d1f91d630def7bb717569a4ed
--- /dev/null
+++ b/imcui/third_party/EfficientLoFTR/src/utils/metrics.py
@@ -0,0 +1,264 @@
+import torch
+import cv2
+import numpy as np
+from collections import OrderedDict
+from loguru import logger
+from kornia.geometry.epipolar import numeric
+from kornia.geometry.conversions import convert_points_to_homogeneous
+import pprint
+
+
+# --- METRICS ---
+
+def relative_pose_error(T_0to1, R, t, ignore_gt_t_thr=0.0):
+ # angle error between 2 vectors
+ t_gt = T_0to1[:3, 3]
+ n = np.linalg.norm(t) * np.linalg.norm(t_gt)
+ t_err = np.rad2deg(np.arccos(np.clip(np.dot(t, t_gt) / n, -1.0, 1.0)))
+ t_err = np.minimum(t_err, 180 - t_err) # handle E ambiguity
+ if np.linalg.norm(t_gt) < ignore_gt_t_thr: # pure rotation is challenging
+ t_err = 0
+
+ # angle error between 2 rotation matrices
+ R_gt = T_0to1[:3, :3]
+ cos = (np.trace(np.dot(R.T, R_gt)) - 1) / 2
+ cos = np.clip(cos, -1., 1.) # handle numerical errors
+ R_err = np.rad2deg(np.abs(np.arccos(cos)))
+
+ return t_err, R_err
+
+
+def symmetric_epipolar_distance(pts0, pts1, E, K0, K1):
+ """Squared symmetric epipolar distance.
+ This can be seen as a biased estimation of the reprojection error.
+ Args:
+ pts0 (torch.Tensor): [N, 2]
+ E (torch.Tensor): [3, 3]
+ """
+ pts0 = (pts0 - K0[[0, 1], [2, 2]][None]) / K0[[0, 1], [0, 1]][None]
+ pts1 = (pts1 - K1[[0, 1], [2, 2]][None]) / K1[[0, 1], [0, 1]][None]
+ pts0 = convert_points_to_homogeneous(pts0)
+ pts1 = convert_points_to_homogeneous(pts1)
+
+ Ep0 = pts0 @ E.T # [N, 3]
+ p1Ep0 = torch.sum(pts1 * Ep0, -1) # [N,]
+ Etp1 = pts1 @ E # [N, 3]
+
+ d = p1Ep0**2 * (1.0 / (Ep0[:, 0]**2 + Ep0[:, 1]**2) + 1.0 / (Etp1[:, 0]**2 + Etp1[:, 1]**2)) # N
+ return d
+
+
+def compute_symmetrical_epipolar_errors(data):
+ """
+ Update:
+ data (dict):{"epi_errs": [M]}
+ """
+ Tx = numeric.cross_product_matrix(data['T_0to1'][:, :3, 3])
+ E_mat = Tx @ data['T_0to1'][:, :3, :3]
+
+ m_bids = data['m_bids']
+ pts0 = data['mkpts0_f']
+ pts1 = data['mkpts1_f']
+
+ epi_errs = []
+ for bs in range(Tx.size(0)):
+ mask = m_bids == bs
+ epi_errs.append(
+ symmetric_epipolar_distance(pts0[mask], pts1[mask], E_mat[bs], data['K0'][bs], data['K1'][bs]))
+ epi_errs = torch.cat(epi_errs, dim=0)
+
+ data.update({'epi_errs': epi_errs})
+
+
+def estimate_pose(kpts0, kpts1, K0, K1, thresh, conf=0.99999):
+ if len(kpts0) < 5:
+ return None
+ # normalize keypoints
+ kpts0 = (kpts0 - K0[[0, 1], [2, 2]][None]) / K0[[0, 1], [0, 1]][None]
+ kpts1 = (kpts1 - K1[[0, 1], [2, 2]][None]) / K1[[0, 1], [0, 1]][None]
+
+ # normalize ransac threshold
+ ransac_thr = thresh / np.mean([K0[0, 0], K1[1, 1], K0[0, 0], K1[1, 1]])
+
+ # compute pose with cv2
+ E, mask = cv2.findEssentialMat(
+ kpts0, kpts1, np.eye(3), threshold=ransac_thr, prob=conf, method=cv2.RANSAC)
+ if E is None:
+ print("\nE is None while trying to recover pose.\n")
+ return None
+
+ # recover pose from E
+ best_num_inliers = 0
+ ret = None
+ for _E in np.split(E, len(E) / 3):
+ n, R, t, _ = cv2.recoverPose(_E, kpts0, kpts1, np.eye(3), 1e9, mask=mask)
+ if n > best_num_inliers:
+ ret = (R, t[:, 0], mask.ravel() > 0)
+ best_num_inliers = n
+
+ return ret
+
+
+def estimate_lo_pose(kpts0, kpts1, K0, K1, thresh, conf=0.99999):
+ from .warppers import Camera, Pose
+ import poselib
+ camera0, camera1 = Camera.from_calibration_matrix(K0).float(), Camera.from_calibration_matrix(K1).float()
+ pts0, pts1 = kpts0, kpts1
+
+ M, info = poselib.estimate_relative_pose(
+ pts0,
+ pts1,
+ camera0.to_cameradict(),
+ camera1.to_cameradict(),
+ {
+ "max_epipolar_error": thresh,
+ },
+ )
+ success = M is not None and ( ((M.t != [0., 0., 0.]).all()) or ((M.q != [1., 0., 0., 0.]).all()) )
+ if success:
+ M = Pose.from_Rt(torch.tensor(M.R), torch.tensor(M.t)) # .to(pts0)
+ # print(M)
+ else:
+ M = Pose.from_4x4mat(torch.eye(4).numpy()) # .to(pts0)
+ # print(M)
+
+ estimation = {
+ "success": success,
+ "M_0to1": M,
+ "inliers": torch.tensor(info.pop("inliers")), # .to(pts0),
+ **info,
+ }
+ return estimation
+
+
+def compute_pose_errors(data, config):
+ """
+ Update:
+ data (dict):{
+ "R_errs" List[float]: [N]
+ "t_errs" List[float]: [N]
+ "inliers" List[np.ndarray]: [N]
+ }
+ """
+ pixel_thr = config.TRAINER.RANSAC_PIXEL_THR # 0.5
+ conf = config.TRAINER.RANSAC_CONF # 0.99999
+ RANSAC = config.TRAINER.POSE_ESTIMATION_METHOD
+ data.update({'R_errs': [], 't_errs': [], 'inliers': []})
+
+ m_bids = data['m_bids'].cpu().numpy()
+ pts0 = data['mkpts0_f'].cpu().numpy()
+ pts1 = data['mkpts1_f'].cpu().numpy()
+ K0 = data['K0'].cpu().numpy()
+ K1 = data['K1'].cpu().numpy()
+ T_0to1 = data['T_0to1'].cpu().numpy()
+
+ for bs in range(K0.shape[0]):
+ mask = m_bids == bs
+ if config.LOFTR.EVAL_TIMES >= 1:
+ bpts0, bpts1 = pts0[mask], pts1[mask]
+ R_list, T_list, inliers_list = [], [], []
+ # for _ in range(config.LOFTR.EVAL_TIMES):
+ for _ in range(5):
+ shuffling = np.random.permutation(np.arange(len(bpts0)))
+ if _ >= config.LOFTR.EVAL_TIMES:
+ continue
+ bpts0 = bpts0[shuffling]
+ bpts1 = bpts1[shuffling]
+
+ if RANSAC == 'RANSAC':
+ ret = estimate_pose(bpts0, bpts1, K0[bs], K1[bs], pixel_thr, conf=conf)
+ if ret is None:
+ R_list.append(np.inf)
+ T_list.append(np.inf)
+ inliers_list.append(np.array([]).astype(bool))
+ else:
+ R, t, inliers = ret
+ t_err, R_err = relative_pose_error(T_0to1[bs], R, t, ignore_gt_t_thr=0.0)
+ R_list.append(R_err)
+ T_list.append(t_err)
+ inliers_list.append(inliers)
+
+ elif RANSAC == 'LO-RANSAC':
+ est = estimate_lo_pose(bpts0, bpts1, K0[bs], K1[bs], pixel_thr, conf=conf)
+ if not est["success"]:
+ R_list.append(90)
+ T_list.append(90)
+ inliers_list.append(np.array([]).astype(bool))
+ else:
+ M = est["M_0to1"]
+ inl = est["inliers"].numpy()
+ t_error, r_error = relative_pose_error(T_0to1[bs], M.R, M.t, ignore_gt_t_thr=0.0)
+ R_list.append(r_error)
+ T_list.append(t_error)
+ inliers_list.append(inl)
+ else:
+ raise ValueError(f"Unknown RANSAC method: {RANSAC}")
+
+ data['R_errs'].append(R_list)
+ data['t_errs'].append(T_list)
+ data['inliers'].append(inliers_list[0])
+
+
+# --- METRIC AGGREGATION ---
+
+def error_auc(errors, thresholds):
+ """
+ Args:
+ errors (list): [N,]
+ thresholds (list)
+ """
+ errors = [0] + sorted(list(errors))
+ recall = list(np.linspace(0, 1, len(errors)))
+
+ aucs = []
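+ # NOTE: the thresholds argument is overridden below; pose AUC is always reported
+ # at the 5 / 10 / 20 degree thresholds.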
+ thresholds = [5, 10, 20]
+ for thr in thresholds:
+ last_index = np.searchsorted(errors, thr)
+ y = recall[:last_index] + [recall[last_index-1]]
+ x = errors[:last_index] + [thr]
+ aucs.append(np.trapz(y, x) / thr)
+
+ return {f'auc@{t}': auc for t, auc in zip(thresholds, aucs)}
+
+
+def epidist_prec(errors, thresholds, ret_dict=False):
+ precs = []
+ for thr in thresholds:
+ prec_ = []
+ for errs in errors:
+ correct_mask = errs < thr
+ prec_.append(np.mean(correct_mask) if len(correct_mask) > 0 else 0)
+ precs.append(np.mean(prec_) if len(prec_) > 0 else 0)
+ if ret_dict:
+ return {f'prec@{t:.0e}': prec for t, prec in zip(thresholds, precs)}
+ else:
+ return precs
+
+
+def aggregate_metrics(metrics, epi_err_thr=5e-4, config=None):
+ """ Aggregate metrics for the whole dataset:
+ (This method should be called once per dataset)
+ 1. AUC of the pose error (angular) at the threshold [5, 10, 20]
+ 2. Mean matching precision at the threshold 5e-4(ScanNet), 1e-4(MegaDepth)
+ """
+ # filter duplicates
+ unq_ids = OrderedDict((iden, id) for id, iden in enumerate(metrics['identifiers']))
+ unq_ids = list(unq_ids.values())
+ logger.info(f'Aggregating metrics over {len(unq_ids)} unique items...')
+
+ # pose auc
+ angular_thresholds = [5, 10, 20]
+
+ if config.LOFTR.EVAL_TIMES >= 1:
+ pose_errors = np.max(np.stack([metrics['R_errs'], metrics['t_errs']]), axis=0).reshape(-1, config.LOFTR.EVAL_TIMES)[unq_ids].reshape(-1)
+ else:
+ pose_errors = np.max(np.stack([metrics['R_errs'], metrics['t_errs']]), axis=0)[unq_ids]
+ aucs = error_auc(pose_errors, angular_thresholds) # (auc@5, auc@10, auc@20)
+
+ # matching precision
+ dist_thresholds = [epi_err_thr]
+ precs = epidist_prec(np.array(metrics['epi_errs'], dtype=object)[unq_ids], dist_thresholds, True) # (prec@err_thr)
+
+ u_num_matches = np.array(metrics['num_matches'], dtype=object)[unq_ids]
+ num_matches = {'num_matches': u_num_matches.mean()}
+ return {**aucs, **precs, **num_matches}
diff --git a/imcui/third_party/EfficientLoFTR/src/utils/misc.py b/imcui/third_party/EfficientLoFTR/src/utils/misc.py
new file mode 100644
index 0000000000000000000000000000000000000000..604b9bc93e0fb92a9750fe72f3d692edc84207b5
--- /dev/null
+++ b/imcui/third_party/EfficientLoFTR/src/utils/misc.py
@@ -0,0 +1,106 @@
+import os
+import contextlib
+import joblib
+from typing import Union
+from loguru import _Logger, logger
+from itertools import chain
+
+import torch
+from yacs.config import CfgNode as CN
+from pytorch_lightning.utilities import rank_zero_only
+
+
+def lower_config(yacs_cfg):
+ if not isinstance(yacs_cfg, CN):
+ return yacs_cfg
+ return {k.lower(): lower_config(v) for k, v in yacs_cfg.items()}
+
+
+def upper_config(dict_cfg):
+ if not isinstance(dict_cfg, dict):
+ return dict_cfg
+ return {k.upper(): upper_config(v) for k, v in dict_cfg.items()}
+
+
+def log_on(condition, message, level):
+ if condition:
+ assert level in ['INFO', 'DEBUG', 'WARNING', 'ERROR', 'CRITICAL']
+ logger.log(level, message)
+
+
+def get_rank_zero_only_logger(logger: _Logger):
+ if rank_zero_only.rank == 0:
+ return logger
+ else:
+ for _level in logger._core.levels.keys():
+ level = _level.lower()
+ setattr(logger, level,
+ lambda x: None)
+ logger._log = lambda x: None
+ return logger
+
+
+def setup_gpus(gpus: Union[str, int]) -> int:
+ """ A temporary fix for pytorch-lighting 1.3.x """
+ gpus = str(gpus)
+ gpu_ids = []
+
+ if ',' not in gpus:
+ n_gpus = int(gpus)
+ return n_gpus if n_gpus != -1 else torch.cuda.device_count()
+ else:
+ gpu_ids = [i.strip() for i in gpus.split(',') if i != '']
+
+ # setup environment variables
+ visible_devices = os.getenv('CUDA_VISIBLE_DEVICES')
+ if visible_devices is None:
+ os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
+ os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(str(i) for i in gpu_ids)
+ visible_devices = os.getenv('CUDA_VISIBLE_DEVICES')
+ logger.warning(f'[Temporary Fix] manually set CUDA_VISIBLE_DEVICES when specifying gpus to use: {visible_devices}')
+ else:
+ logger.warning('[Temporary Fix] CUDA_VISIBLE_DEVICES already set by user or the main process.')
+ return len(gpu_ids)
+
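+# Illustrative usage sketch (not part of the original module); the returned counts
+# assume the machine actually exposes the listed devices:
+#   setup_gpus(2)       # -> 2
+#   setup_gpus(-1)      # -> torch.cuda.device_count()
+#   setup_gpus('0,2')   # -> 2, and sets CUDA_VISIBLE_DEVICES='0,2' if it is unset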
+
+def flattenList(x):
+ return list(chain(*x))
+
+
+@contextlib.contextmanager
+def tqdm_joblib(tqdm_object):
+ """Context manager to patch joblib to report into tqdm progress bar given as argument
+
+ Usage:
+ with tqdm_joblib(tqdm(desc="My calculation", total=10)) as progress_bar:
+ Parallel(n_jobs=16)(delayed(sqrt)(i**2) for i in range(10))
+
+ When iterating over a generator, using tqdm directly is also a solution (but it monitors task queuing instead of task completion)
+ ret_vals = Parallel(n_jobs=args.world_size)(
+ delayed(lambda x: _compute_cov_score(pid, *x))(param)
+ for param in tqdm(combinations(image_ids, 2),
+ desc=f'Computing cov_score of [{pid}]',
+ total=len(image_ids)*(len(image_ids)-1)/2))
+ Src: https://stackoverflow.com/a/58936697
+ """
+ class TqdmBatchCompletionCallback(joblib.parallel.BatchCompletionCallBack):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ def __call__(self, *args, **kwargs):
+ tqdm_object.update(n=self.batch_size)
+ return super().__call__(*args, **kwargs)
+
+ old_batch_callback = joblib.parallel.BatchCompletionCallBack
+ joblib.parallel.BatchCompletionCallBack = TqdmBatchCompletionCallback
+ try:
+ yield tqdm_object
+ finally:
+ joblib.parallel.BatchCompletionCallBack = old_batch_callback
+ tqdm_object.close()
+
+def detect_NaN(feat_0, feat_1):
+ logger.info('NaN detected in features')
+ logger.info(f"#NaN in feat_0: {torch.isnan(feat_0).int().sum()}, #NaN in feat_1: {torch.isnan(feat_1).int().sum()}")
+ feat_0[torch.isnan(feat_0)] = 0
+ feat_1[torch.isnan(feat_1)] = 0
diff --git a/imcui/third_party/EfficientLoFTR/src/utils/plotting.py b/imcui/third_party/EfficientLoFTR/src/utils/plotting.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d4260c7487cfbc76dda94c589957601cea972d4
--- /dev/null
+++ b/imcui/third_party/EfficientLoFTR/src/utils/plotting.py
@@ -0,0 +1,154 @@
+import bisect
+import numpy as np
+import matplotlib.pyplot as plt
+import matplotlib
+
+import torch
+
+def _compute_conf_thresh(data):
+ dataset_name = data['dataset_name'][0].lower()
+ if dataset_name == 'scannet':
+ thr = 5e-4
+ elif dataset_name == 'megadepth':
+ thr = 1e-4
+ else:
+ raise ValueError(f'Unknown dataset: {dataset_name}')
+ return thr
+
+
+# --- VISUALIZATION --- #
+
+def make_matching_figure(
+ img0, img1, mkpts0, mkpts1, color,
+ kpts0=None, kpts1=None, text=[], dpi=75, path=None):
+ # draw image pair
+ assert mkpts0.shape[0] == mkpts1.shape[0], f'mkpts0: {mkpts0.shape[0]} v.s. mkpts1: {mkpts1.shape[0]}'
+ fig, axes = plt.subplots(1, 2, figsize=(10, 6), dpi=dpi)
+ axes[0].imshow(img0, cmap='gray')
+ axes[1].imshow(img1, cmap='gray')
+ for i in range(2): # clear all frames
+ axes[i].get_yaxis().set_ticks([])
+ axes[i].get_xaxis().set_ticks([])
+ for spine in axes[i].spines.values():
+ spine.set_visible(False)
+ plt.tight_layout(pad=1)
+
+ if kpts0 is not None:
+ assert kpts1 is not None
+ axes[0].scatter(kpts0[:, 0], kpts0[:, 1], c='w', s=2)
+ axes[1].scatter(kpts1[:, 0], kpts1[:, 1], c='w', s=2)
+
+ # draw matches
+ if mkpts0.shape[0] != 0 and mkpts1.shape[0] != 0:
+ fig.canvas.draw()
+ transFigure = fig.transFigure.inverted()
+ fkpts0 = transFigure.transform(axes[0].transData.transform(mkpts0))
+ fkpts1 = transFigure.transform(axes[1].transData.transform(mkpts1))
+ fig.lines = [matplotlib.lines.Line2D((fkpts0[i, 0], fkpts1[i, 0]),
+ (fkpts0[i, 1], fkpts1[i, 1]),
+ transform=fig.transFigure, c=color[i], linewidth=1)
+ for i in range(len(mkpts0))]
+
+ axes[0].scatter(mkpts0[:, 0], mkpts0[:, 1], c=color, s=4)
+ axes[1].scatter(mkpts1[:, 0], mkpts1[:, 1], c=color, s=4)
+
+ # put txts
+ txt_color = 'k' if img0[:100, :200].mean() > 200 else 'w'
+ fig.text(
+ 0.01, 0.99, '\n'.join(text), transform=fig.axes[0].transAxes,
+ fontsize=15, va='top', ha='left', color=txt_color)
+
+ # save or return figure
+ if path:
+ plt.savefig(str(path), bbox_inches='tight', pad_inches=0)
+ plt.close()
+ else:
+ return fig
+
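+# Illustrative usage sketch (img0/img1, mkpts0/mkpts1 and epi_errs are hypothetical
+# inputs; error_colormap is defined later in this module):
+#   color = error_colormap(epi_errs, thr=5e-4, alpha=0.6)       # (N, 4) RGBA, green -> red
+#   fig = make_matching_figure(img0, img1, mkpts0, mkpts1, color,
+#                              text=[f'#Matches {len(mkpts0)}'])
+#   fig.savefig('matches.png')   # or pass path=... to save and close directly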
+
+def _make_evaluation_figure(data, b_id, alpha='dynamic'):
+ b_mask = data['m_bids'] == b_id
+ conf_thr = _compute_conf_thresh(data)
+
+ img0 = (data['image0'][b_id][0].cpu().numpy() * 255).round().astype(np.int32)
+ img1 = (data['image1'][b_id][0].cpu().numpy() * 255).round().astype(np.int32)
+ kpts0 = data['mkpts0_f'][b_mask].cpu().numpy()
+ kpts1 = data['mkpts1_f'][b_mask].cpu().numpy()
+
+ # for megadepth, we visualize matches on the resized image
+ if 'scale0' in data:
+ kpts0 = kpts0 / data['scale0'][b_id].cpu().numpy()[[1, 0]]
+ kpts1 = kpts1 / data['scale1'][b_id].cpu().numpy()[[1, 0]]
+
+ epi_errs = data['epi_errs'][b_mask].cpu().numpy()
+ correct_mask = epi_errs < conf_thr
+ precision = np.mean(correct_mask) if len(correct_mask) > 0 else 0
+ n_correct = np.sum(correct_mask)
+ n_gt_matches = int(data['conf_matrix_gt'][b_id].sum().cpu())
+ recall = 0 if n_gt_matches == 0 else n_correct / (n_gt_matches)
+ # recall might be larger than 1, since conf_matrix_gt is computed from
+ # ground-truth depths and camera poses, while the epipolar distance is used here.
+
+ # matching info
+ if alpha == 'dynamic':
+ alpha = dynamic_alpha(len(correct_mask))
+ color = error_colormap(epi_errs, conf_thr, alpha=alpha)
+
+ text = [
+ f'#Matches {len(kpts0)}',
+ f'Precision({conf_thr:.2e}) ({100 * precision:.1f}%): {n_correct}/{len(kpts0)}',
+ f'Recall({conf_thr:.2e}) ({100 * recall:.1f}%): {n_correct}/{n_gt_matches}'
+ ]
+
+ # make the figure
+ figure = make_matching_figure(img0, img1, kpts0, kpts1,
+ color, text=text)
+ return figure
+
+def _make_confidence_figure(data, b_id):
+ # TODO: Implement confidence figure
+ raise NotImplementedError()
+
+def make_matching_figures(data, config, mode='evaluation'):
+ """ Make matching figures for a batch.
+
+ Args:
+ data (Dict): a batch updated by PL_LoFTR.
+ config (Dict): matcher config
+ Returns:
+ figures (Dict[str, List[plt.figure]])
+ """
+ assert mode in ['evaluation', 'confidence', 'gt'] # 'confidence'
+ figures = {mode: []}
+ for b_id in range(data['image0'].size(0)):
+ if mode == 'evaluation':
+ fig = _make_evaluation_figure(
+ data, b_id,
+ alpha=config.TRAINER.PLOT_MATCHES_ALPHA)
+ elif mode == 'confidence':
+ fig = _make_confidence_figure(data, b_id)
+ else:
+ raise ValueError(f'Unknown plot mode: {mode}')
+ figures[mode].append(fig)
+ return figures
+
+
+def dynamic_alpha(n_matches,
+ milestones=[0, 300, 1000, 2000],
+ alphas=[1.0, 0.8, 0.4, 0.2]):
+ if n_matches == 0:
+ return 1.0
+ ranges = list(zip(alphas, alphas[1:] + [None]))
+ loc = bisect.bisect_right(milestones, n_matches) - 1
+ _range = ranges[loc]
+ if _range[1] is None:
+ return _range[0]
+ return _range[1] + (milestones[loc + 1] - n_matches) / (
+ milestones[loc + 1] - milestones[loc]) * (_range[0] - _range[1])
+
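+# Worked example (with the defaults above): n_matches = 650 falls in the [300, 1000)
+# bin, so alpha is linearly interpolated between 0.8 and 0.4:
+#   0.4 + (1000 - 650) / (1000 - 300) * (0.8 - 0.4) = 0.6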
+
+def error_colormap(err, thr, alpha=1.0):
+ assert alpha <= 1.0 and alpha > 0, f"Invalid alpha value: {alpha}"
+ x = 1 - np.clip(err / (thr * 2), 0, 1)
+ return np.clip(
+ np.stack([2-x*2, x*2, np.zeros_like(x), np.ones_like(x)*alpha], -1), 0, 1)
\ No newline at end of file
diff --git a/imcui/third_party/EfficientLoFTR/src/utils/profiler.py b/imcui/third_party/EfficientLoFTR/src/utils/profiler.py
new file mode 100644
index 0000000000000000000000000000000000000000..6d21ed79fb506ef09c75483355402c48a195aaa9
--- /dev/null
+++ b/imcui/third_party/EfficientLoFTR/src/utils/profiler.py
@@ -0,0 +1,39 @@
+import torch
+from pytorch_lightning.profiler import SimpleProfiler, PassThroughProfiler
+from contextlib import contextmanager
+from pytorch_lightning.utilities import rank_zero_only
+
+
+class InferenceProfiler(SimpleProfiler):
+ """
+ This profiler records duration of actions with cuda.synchronize()
+ Use this in test time.
+ """
+
+ def __init__(self):
+ super().__init__()
+ self.start = rank_zero_only(self.start)
+ self.stop = rank_zero_only(self.stop)
+ self.summary = rank_zero_only(self.summary)
+
+ @contextmanager
+ def profile(self, action_name: str) -> None:
+ try:
+ torch.cuda.synchronize()
+ self.start(action_name)
+ yield action_name
+ finally:
+ torch.cuda.synchronize()
+ self.stop(action_name)
+
+
+def build_profiler(name):
+ if name == 'inference':
+ return InferenceProfiler()
+ elif name == 'pytorch':
+ from pytorch_lightning.profiler import PyTorchProfiler
+ return PyTorchProfiler(use_cuda=True, profile_memory=True, row_limit=100)
+ elif name is None:
+ return PassThroughProfiler()
+ else:
+ raise ValueError(f'Invalid profiler: {name}')
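+
+
+# Illustrative usage sketch (not part of the original module):
+#   profiler = build_profiler('inference')   # CUDA-synchronized timings for test time
+#   profiler = build_profiler(None)          # PassThroughProfiler, i.e. no profiling overhead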
diff --git a/imcui/third_party/EfficientLoFTR/src/utils/warppers.py b/imcui/third_party/EfficientLoFTR/src/utils/warppers.py
new file mode 100644
index 0000000000000000000000000000000000000000..a2f33b78f0d1645b3eefc8c9c6dbe14f24f7b2d6
--- /dev/null
+++ b/imcui/third_party/EfficientLoFTR/src/utils/warppers.py
@@ -0,0 +1,426 @@
+"""
+Convenience classes for an SE3 pose and a pinhole Camera with lens distortion.
+Based on PyTorch tensors: differentiable, batched, with GPU support.
+Modified from: https://github.com/cvg/glue-factory/blob/scannet1500/gluefactory/geometry/wrappers.py
+"""
+
+import functools
+import inspect
+import math
+from typing import Dict, List, NamedTuple, Optional, Tuple, Union
+
+import numpy as np
+import torch
+
+from .warppers_utils import (
+ J_distort_points,
+ distort_points,
+ skew_symmetric,
+ so3exp_map,
+ to_homogeneous,
+)
+
+
+def autocast(func):
+ """Cast the inputs of a TensorWrapper method to PyTorch tensors
+ if they are numpy arrays. Use the device and dtype of the wrapper.
+ """
+
+ @functools.wraps(func)
+ def wrap(self, *args):
+ device = torch.device("cpu")
+ dtype = None
+ if isinstance(self, TensorWrapper):
+ if self._data is not None:
+ device = self.device
+ dtype = self.dtype
+ elif not inspect.isclass(self) or not issubclass(self, TensorWrapper):
+ raise ValueError(self)
+
+ cast_args = []
+ for arg in args:
+ if isinstance(arg, np.ndarray):
+ arg = torch.from_numpy(arg)
+ arg = arg.to(device=device, dtype=dtype)
+ cast_args.append(arg)
+ return func(self, *cast_args)
+
+ return wrap
+
+
+class TensorWrapper:
+ _data = None
+
+ @autocast
+ def __init__(self, data: torch.Tensor):
+ self._data = data
+
+ @property
+ def shape(self):
+ return self._data.shape[:-1]
+
+ @property
+ def device(self):
+ return self._data.device
+
+ @property
+ def dtype(self):
+ return self._data.dtype
+
+ def __getitem__(self, index):
+ return self.__class__(self._data[index])
+
+ def __setitem__(self, index, item):
+ self._data[index] = item.data
+
+ def to(self, *args, **kwargs):
+ return self.__class__(self._data.to(*args, **kwargs))
+
+ def cpu(self):
+ return self.__class__(self._data.cpu())
+
+ def cuda(self):
+ return self.__class__(self._data.cuda())
+
+ def pin_memory(self):
+ return self.__class__(self._data.pin_memory())
+
+ def float(self):
+ return self.__class__(self._data.float())
+
+ def double(self):
+ return self.__class__(self._data.double())
+
+ def detach(self):
+ return self.__class__(self._data.detach())
+
+ @classmethod
+ def stack(cls, objects: List, dim=0, *, out=None):
+ data = torch.stack([obj._data for obj in objects], dim=dim, out=out)
+ return cls(data)
+
+ @classmethod
+ def __torch_function__(self, func, types, args=(), kwargs=None):
+ if kwargs is None:
+ kwargs = {}
+ if func is torch.stack:
+ return self.stack(*args, **kwargs)
+ else:
+ return NotImplemented
+
+
+class Pose(TensorWrapper):
+ def __init__(self, data: torch.Tensor):
+ assert data.shape[-1] == 12
+ super().__init__(data)
+
+ @classmethod
+ @autocast
+ def from_Rt(cls, R: torch.Tensor, t: torch.Tensor):
+ """Pose from a rotation matrix and translation vector.
+ Accepts numpy arrays or PyTorch tensors.
+
+ Args:
+ R: rotation matrix with shape (..., 3, 3).
+ t: translation vector with shape (..., 3).
+ """
+ assert R.shape[-2:] == (3, 3)
+ assert t.shape[-1] == 3
+ assert R.shape[:-2] == t.shape[:-1]
+ data = torch.cat([R.flatten(start_dim=-2), t], -1)
+ return cls(data)
+
+ @classmethod
+ @autocast
+ def from_aa(cls, aa: torch.Tensor, t: torch.Tensor):
+ """Pose from an axis-angle rotation vector and translation vector.
+ Accepts numpy arrays or PyTorch tensors.
+
+ Args:
+ aa: axis-angle rotation vector with shape (..., 3).
+ t: translation vector with shape (..., 3).
+ """
+ assert aa.shape[-1] == 3
+ assert t.shape[-1] == 3
+ assert aa.shape[:-1] == t.shape[:-1]
+ return cls.from_Rt(so3exp_map(aa), t)
+
+ @classmethod
+ def from_4x4mat(cls, T: torch.Tensor):
+ """Pose from an SE(3) transformation matrix.
+ Args:
+ T: transformation matrix with shape (..., 4, 4).
+ """
+ assert T.shape[-2:] == (4, 4)
+ R, t = T[..., :3, :3], T[..., :3, 3]
+ return cls.from_Rt(R, t)
+
+ @classmethod
+ def from_colmap(cls, image: NamedTuple):
+ """Pose from a COLMAP Image."""
+ return cls.from_Rt(image.qvec2rotmat(), image.tvec)
+
+ @property
+ def R(self) -> torch.Tensor:
+ """Underlying rotation matrix with shape (..., 3, 3)."""
+ rvec = self._data[..., :9]
+ return rvec.reshape(rvec.shape[:-1] + (3, 3))
+
+ @property
+ def t(self) -> torch.Tensor:
+ """Underlying translation vector with shape (..., 3)."""
+ return self._data[..., -3:]
+
+ def inv(self) -> "Pose":
+ """Invert an SE(3) pose."""
+ R = self.R.transpose(-1, -2)
+ t = -(R @ self.t.unsqueeze(-1)).squeeze(-1)
+ return self.__class__.from_Rt(R, t)
+
+ def compose(self, other: "Pose") -> "Pose":
+ """Chain two SE(3) poses: T_B2C.compose(T_A2B) -> T_A2C."""
+ R = self.R @ other.R
+ t = self.t + (self.R @ other.t.unsqueeze(-1)).squeeze(-1)
+ return self.__class__.from_Rt(R, t)
+
+ @autocast
+ def transform(self, p3d: torch.Tensor) -> torch.Tensor:
+ """Transform a set of 3D points.
+ Args:
+ p3d: 3D points, numpy array or PyTorch tensor with shape (..., 3).
+ """
+ assert p3d.shape[-1] == 3
+ # assert p3d.shape[:-2] == self.shape # allow broadcasting
+ return p3d @ self.R.transpose(-1, -2) + self.t.unsqueeze(-2)
+
+ def __mul__(self, p3D: torch.Tensor) -> torch.Tensor:
+ """Transform a set of 3D points: T_A2B * p3D_A -> p3D_B."""
+ return self.transform(p3D)
+
+ def __matmul__(
+ self, other: Union["Pose", torch.Tensor]
+ ) -> Union["Pose", torch.Tensor]:
+ """Transform a set of 3D points: T_A2B * p3D_A -> p3D_B.
+ or chain two SE(3) poses: T_B2C @ T_A2B -> T_A2C."""
+ if isinstance(other, self.__class__):
+ return self.compose(other)
+ else:
+ return self.transform(other)
+
+ @autocast
+ def J_transform(self, p3d_out: torch.Tensor):
+ # [[1,0,0,0,-pz,py],
+ # [0,1,0,pz,0,-px],
+ # [0,0,1,-py,px,0]]
+ J_t = torch.diag_embed(torch.ones_like(p3d_out))
+ J_rot = -skew_symmetric(p3d_out)
+ J = torch.cat([J_t, J_rot], dim=-1)
+ return J # N x 3 x 6
+
+ def numpy(self) -> Tuple[np.ndarray]:
+ return self.R.numpy(), self.t.numpy()
+
+ def magnitude(self) -> Tuple[torch.Tensor]:
+ """Magnitude of the SE(3) transformation.
+ Returns:
+ dr: rotation angle in degrees.
+ dt: translation distance in meters.
+ """
+ trace = torch.diagonal(self.R, dim1=-1, dim2=-2).sum(-1)
+ cos = torch.clamp((trace - 1) / 2, -1, 1)
+ dr = torch.acos(cos).abs() / math.pi * 180
+ dt = torch.norm(self.t, dim=-1)
+ return dr, dt
+
+ def __repr__(self):
+ return f"Pose: {self.shape} {self.dtype} {self.device}"
+
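+# Illustrative usage sketch (T0_w2c / T1_w2c are hypothetical (4, 4) world-to-camera
+# matrices, p3d_cam0 a hypothetical (..., 3) point array):
+#   T0 = Pose.from_4x4mat(T0_w2c)
+#   T1 = Pose.from_4x4mat(T1_w2c)
+#   T_0to1 = T1 @ T0.inv()         # relative pose, camera 0 -> camera 1
+#   p3d_cam1 = T_0to1 * p3d_cam0   # transform points into camera-1 coordinates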
+
+class Camera(TensorWrapper):
+ eps = 1e-4
+
+ def __init__(self, data: torch.Tensor):
+ assert data.shape[-1] in {6, 8, 10}
+ super().__init__(data)
+
+ @classmethod
+ def from_colmap(cls, camera: Union[Dict, NamedTuple]):
+ """Camera from a COLMAP Camera tuple or dictionary.
+ We use the corner convention from COLMAP (the center of the top-left pixel is (0.5, 0.5)).
+ """
+ if isinstance(camera, tuple):
+ camera = camera._asdict()
+
+ model = camera["model"]
+ params = camera["params"]
+
+ if model in ["OPENCV", "PINHOLE", "RADIAL"]:
+ (fx, fy, cx, cy), params = np.split(params, [4])
+ elif model in ["SIMPLE_PINHOLE", "SIMPLE_RADIAL"]:
+ (f, cx, cy), params = np.split(params, [3])
+ fx = fy = f
+ if model == "SIMPLE_RADIAL":
+ params = np.r_[params, 0.0]
+ else:
+ raise NotImplementedError(model)
+
+ data = np.r_[camera["width"], camera["height"], fx, fy, cx, cy, params]
+ return cls(data)
+
+ @classmethod
+ @autocast
+ def from_calibration_matrix(cls, K: torch.Tensor):
+ cx, cy = K[..., 0, 2], K[..., 1, 2]
+ fx, fy = K[..., 0, 0], K[..., 1, 1]
+ data = torch.stack([2 * cx, 2 * cy, fx, fy, cx, cy], -1)
+ return cls(data)
+
+ @autocast
+ def calibration_matrix(self):
+ K = torch.zeros(
+ *self._data.shape[:-1],
+ 3,
+ 3,
+ device=self._data.device,
+ dtype=self._data.dtype,
+ )
+ K[..., 0, 2] = self._data[..., 4]
+ K[..., 1, 2] = self._data[..., 5]
+ K[..., 0, 0] = self._data[..., 2]
+ K[..., 1, 1] = self._data[..., 3]
+ K[..., 2, 2] = 1.0
+ return K
+
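+ # Illustrative round-trip sketch (fx, fy, cx, cy are hypothetical intrinsics):
+ #   K = torch.tensor([[fx, 0., cx], [0., fy, cy], [0., 0., 1.]])
+ #   cam = Camera.from_calibration_matrix(K)   # image size is approximated as (2*cx, 2*cy)
+ #   assert torch.allclose(cam.calibration_matrix(), K)
+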
+ @property
+ def size(self) -> torch.Tensor:
+ """Size (width height) of the images, with shape (..., 2)."""
+ return self._data[..., :2]
+
+ @property
+ def f(self) -> torch.Tensor:
+ """Focal lengths (fx, fy) with shape (..., 2)."""
+ return self._data[..., 2:4]
+
+ @property
+ def c(self) -> torch.Tensor:
+ """Principal points (cx, cy) with shape (..., 2)."""
+ return self._data[..., 4:6]
+
+ @property
+ def dist(self) -> torch.Tensor:
+ """Distortion parameters, with shape (..., {0, 2, 4})."""
+ return self._data[..., 6:]
+
+ @autocast
+ def scale(self, scales: torch.Tensor):
+ """Update the camera parameters after resizing an image."""
+ s = scales
+ data = torch.cat([self.size * s, self.f * s, self.c * s, self.dist], -1)
+ return self.__class__(data)
+
+ def crop(self, left_top: Tuple[float], size: Tuple[int]):
+ """Update the camera parameters after cropping an image."""
+ left_top = self._data.new_tensor(left_top)
+ size = self._data.new_tensor(size)
+ data = torch.cat([size, self.f, self.c - left_top, self.dist], -1)
+ return self.__class__(data)
+
+ @autocast
+ def in_image(self, p2d: torch.Tensor):
+ """Check if 2D points are within the image boundaries."""
+ assert p2d.shape[-1] == 2
+ # assert p2d.shape[:-2] == self.shape # allow broadcasting
+ size = self.size.unsqueeze(-2)
+ valid = torch.all((p2d >= 0) & (p2d <= (size - 1)), -1)
+ return valid
+
+ @autocast
+ def project(self, p3d: torch.Tensor) -> Tuple[torch.Tensor]:
+ """Project 3D points into the camera plane and check for visibility."""
+ z = p3d[..., -1]
+ valid = z > self.eps
+ z = z.clamp(min=self.eps)
+ p2d = p3d[..., :-1] / z.unsqueeze(-1)
+ return p2d, valid
+
+ def J_project(self, p3d: torch.Tensor):
+ x, y, z = p3d[..., 0], p3d[..., 1], p3d[..., 2]
+ zero = torch.zeros_like(z)
+ z = z.clamp(min=self.eps)
+ J = torch.stack([1 / z, zero, -x / z**2, zero, 1 / z, -y / z**2], dim=-1)
+ J = J.reshape(p3d.shape[:-1] + (2, 3))
+ return J # N x 2 x 3
+
+ @autocast
+ def distort(self, pts: torch.Tensor) -> Tuple[torch.Tensor]:
+ """Distort normalized 2D coordinates
+ and check for validity of the distortion model.
+ """
+ assert pts.shape[-1] == 2
+ # assert pts.shape[:-2] == self.shape # allow broadcasting
+ return distort_points(pts, self.dist)
+
+ def J_distort(self, pts: torch.Tensor):
+ return J_distort_points(pts, self.dist) # N x 2 x 2
+
+ @autocast
+ def denormalize(self, p2d: torch.Tensor) -> torch.Tensor:
+ """Convert normalized 2D coordinates into pixel coordinates."""
+ return p2d * self.f.unsqueeze(-2) + self.c.unsqueeze(-2)
+
+ @autocast
+ def normalize(self, p2d: torch.Tensor) -> torch.Tensor:
+ """Convert normalized 2D coordinates into pixel coordinates."""
+ return (p2d - self.c.unsqueeze(-2)) / self.f.unsqueeze(-2)
+
+ def J_denormalize(self):
+ return torch.diag_embed(self.f).unsqueeze(-3) # 1 x 2 x 2
+
+ @autocast
+ def cam2image(self, p3d: torch.Tensor) -> Tuple[torch.Tensor]:
+ """Transform 3D points into 2D pixel coordinates."""
+ p2d, visible = self.project(p3d)
+ p2d, mask = self.distort(p2d)
+ p2d = self.denormalize(p2d)
+ valid = visible & mask & self.in_image(p2d)
+ return p2d, valid
+
+ def J_world2image(self, p3d: torch.Tensor):
+ p2d_dist, valid = self.project(p3d)
+ J = self.J_denormalize() @ self.J_distort(p2d_dist) @ self.J_project(p3d)
+ return J, valid
+
+ @autocast
+ def image2cam(self, p2d: torch.Tensor) -> torch.Tensor:
+ """Convert 2D pixel corrdinates to 3D points with z=1"""
+ assert self._data.shape
+ p2d = self.normalize(p2d)
+ # iterative undistortion
+ return to_homogeneous(p2d)
+
+ def to_cameradict(self, camera_model: Optional[str] = None) -> List[Dict]:
+ data = self._data.clone()
+ if data.dim() == 1:
+ data = data.unsqueeze(0)
+ assert data.dim() == 2
+ b, d = data.shape
+ if camera_model is None:
+ camera_model = {6: "PINHOLE", 8: "RADIAL", 10: "OPENCV"}[d]
+ cameras = []
+ for i in range(b):
+ if camera_model.startswith("SIMPLE_"):
+ params = [x.item() for x in data[i, 3 : min(d, 7)]]
+ else:
+ params = [x.item() for x in data[i, 2:]]
+ cameras.append(
+ {
+ "model": camera_model,
+ "width": int(data[i, 0].item()),
+ "height": int(data[i, 1].item()),
+ "params": params,
+ }
+ )
+ return cameras if self._data.dim() == 2 else cameras[0]
+
+ def __repr__(self):
+ return f"Camera {self.shape} {self.dtype} {self.device}"
\ No newline at end of file
diff --git a/imcui/third_party/EfficientLoFTR/src/utils/warppers_utils.py b/imcui/third_party/EfficientLoFTR/src/utils/warppers_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..ad3ef5c05d74cd3bd46f5b3b0d8c6d331a17dfad
--- /dev/null
+++ b/imcui/third_party/EfficientLoFTR/src/utils/warppers_utils.py
@@ -0,0 +1,171 @@
+"""
+Modified from: https://github.com/cvg/glue-factory/blob/scannet1500/gluefactory/geometry/utils.py
+"""
+
+import numpy as np
+import torch
+
+
+def to_homogeneous(points):
+ """Convert N-dimensional points to homogeneous coordinates.
+ Args:
+ points: torch.Tensor or numpy.ndarray with size (..., N).
+ Returns:
+ A torch.Tensor or numpy.ndarray with size (..., N+1).
+ """
+ if isinstance(points, torch.Tensor):
+ pad = points.new_ones(points.shape[:-1] + (1,))
+ return torch.cat([points, pad], dim=-1)
+ elif isinstance(points, np.ndarray):
+ pad = np.ones((points.shape[:-1] + (1,)), dtype=points.dtype)
+ return np.concatenate([points, pad], axis=-1)
+ else:
+ raise ValueError
+
+
+def from_homogeneous(points, eps=0.0):
+ """Remove the homogeneous dimension of N-dimensional points.
+ Args:
+ points: torch.Tensor or numpy.ndarray with size (..., N+1).
+ eps: Epsilon value to prevent zero division.
+ Returns:
+ A torch.Tensor or numpy ndarray with size (..., N).
+ """
+ return points[..., :-1] / (points[..., -1:] + eps)
+
+
+def batched_eye_like(x: torch.Tensor, n: int):
+ """Create a batch of identity matrices.
+ Args:
+ x: a reference torch.Tensor whose batch dimension will be copied.
+ n: the size of each identity matrix.
+ Returns:
+ A torch.Tensor of size (B, n, n), with same dtype and device as x.
+ """
+ return torch.eye(n).to(x)[None].repeat(len(x), 1, 1)
+
+
+def skew_symmetric(v):
+ """Create a skew-symmetric matrix from a (batched) vector of size (..., 3)."""
+ z = torch.zeros_like(v[..., 0])
+ M = torch.stack(
+ [
+ z,
+ -v[..., 2],
+ v[..., 1],
+ v[..., 2],
+ z,
+ -v[..., 0],
+ -v[..., 1],
+ v[..., 0],
+ z,
+ ],
+ dim=-1,
+ ).reshape(v.shape[:-1] + (3, 3))
+ return M
+
+
+def transform_points(T, points):
+ return from_homogeneous(to_homogeneous(points) @ T.transpose(-1, -2))
+
+
+def is_inside(pts, shape):
+ return (pts > 0).all(-1) & (pts < shape[:, None]).all(-1)
+
+
+def so3exp_map(w, eps: float = 1e-7):
+ """Compute rotation matrices from batched twists.
+ Args:
+ w: batched 3D axis-angle vectors of size (..., 3).
+ Returns:
+ A batch of rotation matrices of size (..., 3, 3).
+ """
+ theta = w.norm(p=2, dim=-1, keepdim=True)
+ small = theta < eps
+ div = torch.where(small, torch.ones_like(theta), theta)
+ W = skew_symmetric(w / div)
+ theta = theta[..., None] # ... x 1 x 1
+ res = W * torch.sin(theta) + (W @ W) * (1 - torch.cos(theta))
+ res = torch.where(small[..., None], W, res) # first-order Taylor approx
+ return torch.eye(3).to(W) + res
+
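+# Worked example: a rotation of pi/2 about the z-axis.
+#   w = torch.tensor([0.0, 0.0, 3.14159265 / 2])
+#   so3exp_map(w)   # ~ [[0, -1, 0], [1, 0, 0], [0, 0, 1]]
+# A near-zero vector falls back to the first-order approximation I + skew(w).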
+
+@torch.jit.script
+def distort_points(pts, dist):
+ """Distort normalized 2D coordinates
+ and check for validity of the distortion model.
+ """
+ dist = dist.unsqueeze(-2) # add point dimension
+ ndist = dist.shape[-1]
+ undist = pts
+ valid = torch.ones(pts.shape[:-1], device=pts.device, dtype=torch.bool)
+ if ndist > 0:
+ k1, k2 = dist[..., :2].split(1, -1)
+ r2 = torch.sum(pts**2, -1, keepdim=True)
+ radial = k1 * r2 + k2 * r2**2
+ undist = undist + pts * radial
+
+ # The distortion model is supposedly only valid within the image
+ # boundaries. Because of the negative radial distortion, points that
+ # are far outside of the boundaries might actually be mapped back
+ # within the image. To account for this, we discard points that are
+ # beyond the inflection point of the distortion model,
+ # e.g. such that d(r + k_1 r^3 + k2 r^5)/dr = 0
+ limited = ((k2 > 0) & ((9 * k1**2 - 20 * k2) > 0)) | ((k2 <= 0) & (k1 > 0))
+ limit = torch.abs(
+ torch.where(
+ k2 > 0,
+ (torch.sqrt(9 * k1**2 - 20 * k2) - 3 * k1) / (10 * k2),
+ 1 / (3 * k1),
+ )
+ )
+ valid = valid & torch.squeeze(~limited | (r2 < limit), -1)
+
+ if ndist > 2:
+ p12 = dist[..., 2:]
+ p21 = p12.flip(-1)
+ uv = torch.prod(pts, -1, keepdim=True)
+ undist = undist + 2 * p12 * uv + p21 * (r2 + 2 * pts**2)
+ # TODO: handle tangential boundaries
+
+ return undist, valid
+
+
+@torch.jit.script
+def J_distort_points(pts, dist):
+ dist = dist.unsqueeze(-2) # add point dimension
+ ndist = dist.shape[-1]
+
+ J_diag = torch.ones_like(pts)
+ J_cross = torch.zeros_like(pts)
+ if ndist > 0:
+ k1, k2 = dist[..., :2].split(1, -1)
+ r2 = torch.sum(pts**2, -1, keepdim=True)
+ uv = torch.prod(pts, -1, keepdim=True)
+ radial = k1 * r2 + k2 * r2**2
+ d_radial = 2 * k1 + 4 * k2 * r2
+ J_diag += radial + (pts**2) * d_radial
+ J_cross += uv * d_radial
+
+ if ndist > 2:
+ p12 = dist[..., 2:]
+ p21 = p12.flip(-1)
+ J_diag += 2 * p12 * pts.flip(-1) + 6 * p21 * pts
+ J_cross += 2 * p12 * pts + 2 * p21 * pts.flip(-1)
+
+ J = torch.diag_embed(J_diag) + torch.diag_embed(J_cross).flip(-1)
+ return J
+
+
+def get_image_coords(img):
+ h, w = img.shape[-2:]
+ return (
+ torch.stack(
+ torch.meshgrid(
+ torch.arange(h, dtype=torch.float32, device=img.device),
+ torch.arange(w, dtype=torch.float32, device=img.device),
+ indexing="ij",
+ )[::-1],
+ dim=0,
+ ).permute(1, 2, 0)
+ )[None] + 0.5
\ No newline at end of file
diff --git a/imcui/third_party/EfficientLoFTR/test.py b/imcui/third_party/EfficientLoFTR/test.py
new file mode 100644
index 0000000000000000000000000000000000000000..e7a04324e9cd16c74ec3affbe17e9764d2a0002b
--- /dev/null
+++ b/imcui/third_party/EfficientLoFTR/test.py
@@ -0,0 +1,143 @@
+import pytorch_lightning as pl
+import argparse
+import pprint
+from loguru import logger as loguru_logger
+
+from src.config.default import get_cfg_defaults
+from src.utils.profiler import build_profiler
+
+from src.lightning.data import MultiSceneDataModule
+from src.lightning.lightning_loftr import PL_LoFTR
+
+import torch
+
+def parse_args():
+ # init a custom parser which will be added into the pl.Trainer parser
+ # check documentation: https://pytorch-lightning.readthedocs.io/en/latest/common/trainer.html#trainer-flags
+ parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+ parser.add_argument(
+ 'data_cfg_path', type=str, help='data config path')
+ parser.add_argument(
+ 'main_cfg_path', type=str, help='main config path')
+ parser.add_argument(
+ '--ckpt_path', type=str, default="weights/indoor_ds.ckpt", help='path to the checkpoint')
+ parser.add_argument(
+ '--dump_dir', type=str, default=None, help="if set, the matching results will be dumped to dump_dir")
+ parser.add_argument(
+ '--profiler_name', type=str, default=None, help='options: [inference, pytorch], or leave it unset')
+ parser.add_argument(
+ '--batch_size', type=int, default=1, help='batch_size per gpu')
+ parser.add_argument(
+ '--num_workers', type=int, default=2)
+ parser.add_argument(
+ '--thr', type=float, default=None, help='modify the coarse-level matching threshold.')
+ parser.add_argument(
+ '--pixel_thr', type=float, default=None, help='modify the RANSAC threshold.')
+ parser.add_argument(
+ '--ransac', type=str, default=None, help='modify the RANSAC method')
+ parser.add_argument(
+ '--scannetX', type=int, default=None, help='ScanNet resize X')
+ parser.add_argument(
+ '--scannetY', type=int, default=None, help='ScanNet resize Y')
+ parser.add_argument(
+ '--megasize', type=int, default=None, help='MegaDepth resize')
+ parser.add_argument(
+ '--npe', action='store_true', default=False, help='adapt the RoPE positional-encoding sizes (LOFTR.COARSE.NPE) to the test resolution')
+ parser.add_argument(
+ '--fp32', action='store_true', default=False, help='run in full fp32 precision (disable mixed precision)')
+ parser.add_argument(
+ '--ransac_times', type=int, default=None, help='repeat ransac multiple times for more robust evaluation')
+ parser.add_argument(
+ '--rmbd', type=int, default=None, help='remove border matches')
+ parser.add_argument(
+ '--deter', action='store_true', default=False, help='use deterministic mode for testing')
+ parser.add_argument(
+ '--half', action='store_true', default=False, help='run the model in pure fp16 (half precision)')
+ parser.add_argument(
+ '--flash', action='store_true', default=False, help='enable FlashAttention in the coarse matching module')
+
+ parser = pl.Trainer.add_argparse_args(parser)
+ return parser.parse_args()
+
+def inplace_relu(m):
+ classname = m.__class__.__name__
+ if classname.find('ReLU') != -1:
+ m.inplace=True
+
+if __name__ == '__main__':
+ # parse arguments
+ args = parse_args()
+ pprint.pprint(vars(args))
+
+ # init default-cfg and merge it with the main- and data-cfg
+ config = get_cfg_defaults()
+ config.merge_from_file(args.main_cfg_path)
+ config.merge_from_file(args.data_cfg_path)
+ if args.deter:
+ torch.backends.cudnn.deterministic = True
+ pl.seed_everything(config.TRAINER.SEED) # reproducibility
+
+ # tune when testing
+ if args.thr is not None:
+ config.LOFTR.MATCH_COARSE.THR = args.thr
+
+ if args.scannetX is not None and args.scannetY is not None:
+ config.DATASET.SCAN_IMG_RESIZEX = args.scannetX
+ config.DATASET.SCAN_IMG_RESIZEY = args.scannetY
+ if args.megasize is not None:
+ config.DATASET.MGDPT_IMG_RESIZE = args.megasize
+
+ if args.npe:
+ if config.LOFTR.COARSE.ROPE:
+ assert config.DATASET.NPE_NAME is not None
+ if config.DATASET.NPE_NAME is not None:
+ if config.DATASET.NPE_NAME == 'megadepth':
+ config.LOFTR.COARSE.NPE = [832, 832, config.DATASET.MGDPT_IMG_RESIZE, config.DATASET.MGDPT_IMG_RESIZE] # [832, 832, 1152, 1152]
+ elif config.DATASET.NPE_NAME == 'scannet':
+ config.LOFTR.COARSE.NPE = [832, 832, config.DATASET.SCAN_IMG_RESIZEX, config.DATASET.SCAN_IMG_RESIZEX] # [832, 832, 640, 640]
+ else:
+ config.LOFTR.COARSE.NPE = [832, 832, 832, 832]
+
+ if args.ransac_times is not None:
+ config.LOFTR.EVAL_TIMES = args.ransac_times
+
+ if args.rmbd is not None:
+ config.LOFTR.MATCH_COARSE.BORDER_RM = args.rmbd
+
+ if args.pixel_thr is not None:
+ config.TRAINER.RANSAC_PIXEL_THR = args.pixel_thr
+
+ if args.ransac is not None:
+ config.TRAINER.POSE_ESTIMATION_METHOD = args.ransac
+ if args.ransac == 'LO-RANSAC' and config.TRAINER.RANSAC_PIXEL_THR == 0.5:
+ config.TRAINER.RANSAC_PIXEL_THR = 2.0
+
+ if args.fp32:
+ config.LOFTR.MP = False
+
+ if args.half:
+ config.LOFTR.HALF = True
+ config.DATASET.FP16 = True
+ else:
+ config.LOFTR.HALF = False
+ config.DATASET.FP16 = False
+
+ if args.flash:
+ config.LOFTR.COARSE.NO_FLASH = False
+
+ loguru_logger.info(f"Args and config initialized!")
+
+ # lightning module
+ profiler = build_profiler(args.profiler_name)
+ model = PL_LoFTR(config, pretrained_ckpt=args.ckpt_path, profiler=profiler, dump_dir=args.dump_dir)
+ loguru_logger.info(f"LoFTR-lightning initialized!")
+
+ # lightning data
+ data_module = MultiSceneDataModule(args, config)
+ loguru_logger.info(f"DataModule initialized!")
+
+ # lightning trainer
+ trainer = pl.Trainer.from_argparse_args(args, replace_sampler_ddp=False, logger=False)
+
+ loguru_logger.info(f"Start testing!")
+ trainer.test(model, datamodule=data_module, verbose=False)
\ No newline at end of file
diff --git a/imcui/third_party/EfficientLoFTR/train.py b/imcui/third_party/EfficientLoFTR/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..6d74512464fbf5d35cb3ee48b4683b8cb870ce6e
--- /dev/null
+++ b/imcui/third_party/EfficientLoFTR/train.py
@@ -0,0 +1,154 @@
+import math
+import argparse
+import pprint
+from distutils.util import strtobool
+from pathlib import Path
+from loguru import logger as loguru_logger
+
+import pytorch_lightning as pl
+from pytorch_lightning.utilities import rank_zero_only
+from pytorch_lightning.loggers import TensorBoardLogger
+from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
+from pytorch_lightning.plugins import DDPPlugin, NativeMixedPrecisionPlugin
+
+from src.config.default import get_cfg_defaults
+from src.utils.misc import get_rank_zero_only_logger, setup_gpus
+from src.utils.profiler import build_profiler
+from src.lightning.data import MultiSceneDataModule
+from src.lightning.lightning_loftr import PL_LoFTR
+import torch
+
+loguru_logger = get_rank_zero_only_logger(loguru_logger)
+
+import os
+os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:1024"
+
+def parse_args():
+ # init a custom parser which will be added into the pl.Trainer parser
+ # check documentation: https://pytorch-lightning.readthedocs.io/en/latest/common/trainer.html#trainer-flags
+ parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+ parser.add_argument(
+ 'data_cfg_path', type=str, help='data config path')
+ parser.add_argument(
+ 'main_cfg_path', type=str, help='main config path')
+ parser.add_argument(
+ '--exp_name', type=str, default='default_exp_name')
+ parser.add_argument(
+ '--batch_size', type=int, default=4, help='batch_size per gpu')
+ parser.add_argument(
+ '--num_workers', type=int, default=4)
+ parser.add_argument(
+ '--pin_memory', type=lambda x: bool(strtobool(x)),
+ nargs='?', default=True, help='whether loading data to pinned memory or not')
+ parser.add_argument(
+ '--ckpt_path', type=str, default=None,
+ help='pretrained checkpoint path, helpful for using a pre-trained coarse-only LoFTR')
+ parser.add_argument(
+ '--disable_ckpt', action='store_true',
+ help='disable checkpoint saving (useful for debugging).')
+ parser.add_argument(
+ '--profiler_name', type=str, default=None,
+ help='options: [inference, pytorch], or leave it unset')
+ parser.add_argument(
+ '--parallel_load_data', action='store_true',
+ help='load datasets with multiple processes.')
+ parser.add_argument(
+ '--thr', type=float, default=0.1)
+ parser.add_argument(
+ '--train_coarse_percent', type=float, default=0.1, help='training tricks: save GPU memory')
+ parser.add_argument(
+ '--disable_mp', action='store_true', help='disable mixed-precision training')
+ parser.add_argument(
+ '--deter', action='store_true', help='use deterministic mode for training')
+
+ parser = pl.Trainer.add_argparse_args(parser)
+ return parser.parse_args()
+
+def inplace_relu(m):
+ classname = m.__class__.__name__
+ if classname.find('ReLU') != -1:
+ m.inplace=True
+
+def main():
+ # parse arguments
+ args = parse_args()
+ rank_zero_only(pprint.pprint)(vars(args))
+
+ # init default-cfg and merge it with the main- and data-cfg
+ get_cfg_default = get_cfg_defaults
+
+ config = get_cfg_default()
+ config.merge_from_file(args.main_cfg_path)
+ config.merge_from_file(args.data_cfg_path)
+
+ if config.LOFTR.COARSE.NPE is None:
+ config.LOFTR.COARSE.NPE = [832, 832, 832, 832] # training at 832 resolution on MegaDepth datasets
+
+ if args.deter:
+ torch.backends.cudnn.deterministic = True
+ torch.backends.cudnn.benchmark = False
+
+ pl.seed_everything(config.TRAINER.SEED) # reproducibility
+ # TODO: Use different seeds for each dataloader workers
+ # This is needed for data augmentation
+
+ # scale lr and warmup-step automatically
+ args.gpus = _n_gpus = setup_gpus(args.gpus)
+ config.TRAINER.WORLD_SIZE = _n_gpus * args.num_nodes
+ config.TRAINER.TRUE_BATCH_SIZE = config.TRAINER.WORLD_SIZE * args.batch_size
+ _scaling = config.TRAINER.TRUE_BATCH_SIZE / config.TRAINER.CANONICAL_BS
+ config.TRAINER.SCALING = _scaling
+ config.TRAINER.TRUE_LR = config.TRAINER.CANONICAL_LR * _scaling
+ config.TRAINER.WARMUP_STEP = math.floor(config.TRAINER.WARMUP_STEP / _scaling)
+
+ if args.thr is not None:
+ config.LOFTR.MATCH_COARSE.THR = args.thr
+ if args.disable_mp:
+ config.LOFTR.MP = False
+
+ # lightning module
+ profiler = build_profiler(args.profiler_name)
+ model = PL_LoFTR(config, pretrained_ckpt=args.ckpt_path, profiler=profiler)
+ loguru_logger.info(f"LoFTR LightningModule initialized!")
+
+ # lightning data
+ data_module = MultiSceneDataModule(args, config)
+ loguru_logger.info(f"LoFTR DataModule initialized!")
+
+ # TensorBoard Logger
+ logger = TensorBoardLogger(save_dir='logs/tb_logs', name=args.exp_name, default_hp_metric=False)
+ ckpt_dir = Path(logger.log_dir) / 'checkpoints'
+
+ # Callbacks
+ # TODO: update ModelCheckpoint to monitor multiple metrics
+ ckpt_callback = ModelCheckpoint(monitor='auc@10', verbose=True, save_top_k=5, mode='max',
+ save_last=True,
+ dirpath=str(ckpt_dir),
+ filename='{epoch}-{auc@5:.3f}-{auc@10:.3f}-{auc@20:.3f}')
+ lr_monitor = LearningRateMonitor(logging_interval='step')
+ callbacks = [lr_monitor]
+ if not args.disable_ckpt:
+ callbacks.append(ckpt_callback)
+
+ # Lightning Trainer
+ trainer = pl.Trainer.from_argparse_args(
+ args,
+ plugins=[DDPPlugin(find_unused_parameters=False,
+ num_nodes=args.num_nodes,
+ sync_batchnorm=config.TRAINER.WORLD_SIZE > 0), NativeMixedPrecisionPlugin()],
+ gradient_clip_val=config.TRAINER.GRADIENT_CLIPPING,
+ callbacks=callbacks,
+ logger=logger,
+ sync_batchnorm=config.TRAINER.WORLD_SIZE > 0,
+ replace_sampler_ddp=False, # use custom sampler
+ reload_dataloaders_every_epoch=False, # avoid repeated samples!
+ weights_summary='full',
+ profiler=profiler)
+ loguru_logger.info(f"Trainer initialized!")
+ loguru_logger.info(f"Start training!")
+
+ trainer.fit(model, datamodule=data_module)
+
+
+if __name__ == '__main__':
+ main()
\ No newline at end of file
diff --git a/imcui/third_party/GlueStick/gluestick/__init__.py b/imcui/third_party/GlueStick/gluestick/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d3051821ecfb2e18f4b9b4dfb50f35064106eb57
--- /dev/null
+++ b/imcui/third_party/GlueStick/gluestick/__init__.py
@@ -0,0 +1,53 @@
+import collections.abc as collections
+from pathlib import Path
+
+import torch
+
+GLUESTICK_ROOT = Path(__file__).parent.parent
+
+
+def get_class(mod_name, base_path, BaseClass):
+ """Get the class object which inherits from BaseClass and is defined in
+ the module named mod_name, child of base_path.
+ """
+ import inspect
+ mod_path = '{}.{}'.format(base_path, mod_name)
+ mod = __import__(mod_path, fromlist=[''])
+ classes = inspect.getmembers(mod, inspect.isclass)
+ # Filter classes defined in the module
+ classes = [c for c in classes if c[1].__module__ == mod_path]
+ # Filter classes inherited from BaseModel
+ classes = [c for c in classes if issubclass(c[1], BaseClass)]
+ assert len(classes) == 1, classes
+ return classes[0][1]
+
+
+def get_model(name):
+ from .models.base_model import BaseModel
+ return get_class('models.' + name, __name__, BaseModel)
+
+
+def numpy_image_to_torch(image):
+ """Normalize the image tensor and reorder the dimensions."""
+ if image.ndim == 3:
+ image = image.transpose((2, 0, 1)) # HxWxC to CxHxW
+ elif image.ndim == 2:
+ image = image[None] # add channel axis
+ else:
+ raise ValueError(f'Not an image: {image.shape}')
+ return torch.from_numpy(image / 255.).float()
+
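+# Illustrative usage sketch ('image.png' is a hypothetical path; cv2 / OpenCV is
+# assumed only for this example):
+#   gray = cv2.imread('image.png', cv2.IMREAD_GRAYSCALE)   # (H, W) uint8
+#   tensor = numpy_image_to_torch(gray)                    # (1, H, W) float32 in [0, 1]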
+
+def map_tensor(input_, func):
+ if isinstance(input_, (str, bytes)):
+ return input_
+ elif isinstance(input_, collections.Mapping):
+ return {k: map_tensor(sample, func) for k, sample in input_.items()}
+ elif isinstance(input_, collections.Sequence):
+ return [map_tensor(sample, func) for sample in input_]
+ else:
+ return func(input_)
+
+
+def batch_to_np(batch):
+ return map_tensor(batch, lambda t: t.detach().cpu().numpy()[0])
diff --git a/imcui/third_party/GlueStick/gluestick/drawing.py b/imcui/third_party/GlueStick/gluestick/drawing.py
new file mode 100644
index 0000000000000000000000000000000000000000..8e6d24b6bfedc93449142647410057d978d733ef
--- /dev/null
+++ b/imcui/third_party/GlueStick/gluestick/drawing.py
@@ -0,0 +1,166 @@
+import matplotlib
+import matplotlib.pyplot as plt
+import numpy as np
+import seaborn as sns
+
+
+def plot_images(imgs, titles=None, cmaps='gray', dpi=100, pad=.5,
+ adaptive=True):
+ """Plot a set of images horizontally.
+ Args:
+ imgs: a list of NumPy or PyTorch images, RGB (H, W, 3) or mono (H, W).
+ titles: a list of strings, as titles for each image.
+ cmaps: colormaps for monochrome images.
+ adaptive: whether the figure size should fit the image aspect ratios.
+ """
+ n = len(imgs)
+ if not isinstance(cmaps, (list, tuple)):
+ cmaps = [cmaps] * n
+
+ if adaptive:
+ ratios = [i.shape[1] / i.shape[0] for i in imgs] # W / H
+ else:
+ ratios = [4 / 3] * n
+ figsize = [sum(ratios) * 4.5, 4.5]
+ fig, ax = plt.subplots(
+ 1, n, figsize=figsize, dpi=dpi, gridspec_kw={'width_ratios': ratios})
+ if n == 1:
+ ax = [ax]
+ for i in range(n):
+ ax[i].imshow(imgs[i], cmap=plt.get_cmap(cmaps[i]))
+ ax[i].get_yaxis().set_ticks([])
+ ax[i].get_xaxis().set_ticks([])
+ ax[i].set_axis_off()
+ for spine in ax[i].spines.values(): # remove frame
+ spine.set_visible(False)
+ if titles:
+ ax[i].set_title(titles[i])
+ fig.tight_layout(pad=pad)
+ return ax
+
+
+def plot_keypoints(kpts, colors='lime', ps=4, alpha=1):
+ """Plot keypoints for existing images.
+ Args:
+ kpts: list of ndarrays of size (N, 2).
+ colors: string, or list of list of tuples (one for each keypoint).
+ ps: size of the keypoints as float.
+ """
+ if not isinstance(colors, list):
+ colors = [colors] * len(kpts)
+ axes = plt.gcf().axes
+ for a, k, c in zip(axes, kpts, colors):
+ a.scatter(k[:, 0], k[:, 1], c=c, s=ps, alpha=alpha, linewidths=0)
+
+
+def plot_matches(kpts0, kpts1, color=None, lw=1.5, ps=4, indices=(0, 1), a=1.):
+ """Plot matches for a pair of existing images.
+ Args:
+ kpts0, kpts1: corresponding keypoints of size (N, 2).
+ color: color of each match, string or RGB tuple. Random if not given.
+ lw: width of the lines.
+ ps: size of the end points (no endpoint if ps=0)
+ indices: indices of the images to draw the matches on.
+ a: alpha opacity of the match lines.
+ """
+ fig = plt.gcf()
+ ax = fig.axes
+ assert len(ax) > max(indices)
+ ax0, ax1 = ax[indices[0]], ax[indices[1]]
+ fig.canvas.draw()
+
+ assert len(kpts0) == len(kpts1)
+ if color is None:
+ color = matplotlib.cm.hsv(np.random.rand(len(kpts0))).tolist()
+ elif len(color) > 0 and not isinstance(color[0], (tuple, list)):
+ color = [color] * len(kpts0)
+
+ if lw > 0:
+ # transform the points into the figure coordinate system
+ transFigure = fig.transFigure.inverted()
+ fkpts0 = transFigure.transform(ax0.transData.transform(kpts0))
+ fkpts1 = transFigure.transform(ax1.transData.transform(kpts1))
+ fig.lines += [matplotlib.lines.Line2D(
+ (fkpts0[i, 0], fkpts1[i, 0]), (fkpts0[i, 1], fkpts1[i, 1]),
+ zorder=1, transform=fig.transFigure, c=color[i], linewidth=lw,
+ alpha=a)
+ for i in range(len(kpts0))]
+
+ # freeze the axes to prevent the transform to change
+ ax0.autoscale(enable=False)
+ ax1.autoscale(enable=False)
+
+ if ps > 0:
+ ax0.scatter(kpts0[:, 0], kpts0[:, 1], c=color, s=ps)
+ ax1.scatter(kpts1[:, 0], kpts1[:, 1], c=color, s=ps)
+
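+# Illustrative usage sketch (img0/img1 and kpts0/kpts1 are hypothetical inputs):
+#   plot_images([img0, img1], titles=['reference', 'query'])
+#   plot_matches(kpts0, kpts1, color='lime', lw=1.0, ps=3)
+#   plt.savefig('matches.png')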
+
+def plot_lines(lines, line_colors='orange', point_colors='cyan',
+ ps=4, lw=2, alpha=1., indices=(0, 1)):
+ """ Plot lines and endpoints for existing images.
+ Args:
+ lines: list of ndarrays of size (N, 2, 2).
+ line_colors, point_colors: string, or list of colors (one for each line).
+ ps: size of the keypoints as float pixels.
+ lw: line width as float pixels.
+ alpha: transparency of the points and lines.
+ indices: indices of the images to draw the matches on.
+ """
+ if not isinstance(line_colors, list):
+ line_colors = [line_colors] * len(lines)
+ if not isinstance(point_colors, list):
+ point_colors = [point_colors] * len(lines)
+
+ fig = plt.gcf()
+ ax = fig.axes
+ assert len(ax) > max(indices)
+ axes = [ax[i] for i in indices]
+ fig.canvas.draw()
+
+ # Plot the lines and junctions
+ for a, l, lc, pc in zip(axes, lines, line_colors, point_colors):
+ for i in range(len(l)):
+ line = matplotlib.lines.Line2D((l[i, 0, 0], l[i, 1, 0]),
+ (l[i, 0, 1], l[i, 1, 1]),
+ zorder=1, c=lc, linewidth=lw,
+ alpha=alpha)
+ a.add_line(line)
+ pts = l.reshape(-1, 2)
+ a.scatter(pts[:, 0], pts[:, 1],
+ c=pc, s=ps, linewidths=0, zorder=2, alpha=alpha)
+
+
+def plot_color_line_matches(lines, correct_matches=None,
+ lw=2, indices=(0, 1)):
+ """Plot line matches for existing images with multiple colors.
+ Args:
+ lines: list of ndarrays of size (N, 2, 2).
+ correct_matches: bool array of size (N,) indicating correct matches.
+ lw: line width as float pixels.
+ indices: indices of the images to draw the matches on.
+ """
+ n_lines = len(lines[0])
+ colors = sns.color_palette('husl', n_colors=n_lines)
+ np.random.shuffle(colors)
+ alphas = np.ones(n_lines)
+ # If correct_matches is not None, display wrong matches with a low alpha
+ if correct_matches is not None:
+ alphas[~np.array(correct_matches)] = 0.2
+
+ fig = plt.gcf()
+ ax = fig.axes
+ assert len(ax) > max(indices)
+ axes = [ax[i] for i in indices]
+ fig.canvas.draw()
+
+ # Plot the lines
+ for a, l in zip(axes, lines):
+ # Transform the points into the figure coordinate system
+ transFigure = fig.transFigure.inverted()
+ endpoint0 = transFigure.transform(a.transData.transform(l[:, 0]))
+ endpoint1 = transFigure.transform(a.transData.transform(l[:, 1]))
+ fig.lines += [matplotlib.lines.Line2D(
+ (endpoint0[i, 0], endpoint1[i, 0]),
+ (endpoint0[i, 1], endpoint1[i, 1]),
+ zorder=1, transform=fig.transFigure, c=colors[i],
+ alpha=alphas[i], linewidth=lw) for i in range(n_lines)]
diff --git a/imcui/third_party/GlueStick/gluestick/geometry.py b/imcui/third_party/GlueStick/gluestick/geometry.py
new file mode 100644
index 0000000000000000000000000000000000000000..97853c4807d319eb9ea0377db7385e9a72fb400b
--- /dev/null
+++ b/imcui/third_party/GlueStick/gluestick/geometry.py
@@ -0,0 +1,175 @@
+from typing import Tuple
+
+import numpy as np
+import torch
+
+
+def to_homogeneous(points):
+ """Convert N-dimensional points to homogeneous coordinates.
+ Args:
+ points: torch.Tensor or numpy.ndarray with size (..., N).
+ Returns:
+ A torch.Tensor or numpy.ndarray with size (..., N+1).
+ """
+ if isinstance(points, torch.Tensor):
+ pad = points.new_ones(points.shape[:-1] + (1,))
+ return torch.cat([points, pad], dim=-1)
+ elif isinstance(points, np.ndarray):
+ pad = np.ones((points.shape[:-1] + (1,)), dtype=points.dtype)
+ return np.concatenate([points, pad], axis=-1)
+ else:
+ raise ValueError
+
+
+def from_homogeneous(points, eps=0.):
+ """Remove the homogeneous dimension of N-dimensional points.
+ Args:
+ points: torch.Tensor or numpy.ndarray with size (..., N+1).
+ Returns:
+ A torch.Tensor or numpy ndarray with size (..., N).
+ """
+ return points[..., :-1] / (points[..., -1:] + eps)
+
+
+def skew_symmetric(v):
+ """Create a skew-symmetric matrix from a (batched) vector of size (..., 3).
+ """
+ z = torch.zeros_like(v[..., 0])
+ M = torch.stack([
+ z, -v[..., 2], v[..., 1],
+ v[..., 2], z, -v[..., 0],
+ -v[..., 1], v[..., 0], z,
+ ], dim=-1).reshape(v.shape[:-1] + (3, 3))
+ return M
+
+
+def T_to_E(T):
+ """Convert batched poses (..., 4, 4) to batched essential matrices."""
+ return skew_symmetric(T[..., :3, 3]) @ T[..., :3, :3]
+
+
+def warp_points_torch(points, H, inverse=True):
+ """
+ Warp a list of points with the INVERSE of the given homography.
+ The inverse is used to be coherent with tf.contrib.image.transform
+ Arguments:
+ points: batched list of N points, shape (B, N, 2).
+ H: homography parameters, batched or not (shapes (B, 8) and (8,) respectively).
+ Returns: a Tensor of shape (B, N, 2) containing the new coordinates of the warped points.
+ """
+ # H = np.expand_dims(homography, axis=0) if len(homography.shape) == 1 else homography
+
+ # Get the points to the homogeneous format
+ points = to_homogeneous(points)
+
+ # Apply the homography
+ out_shape = tuple(list(H.shape[:-1]) + [3, 3])
+ H_mat = torch.cat([H, torch.ones_like(H[..., :1])], axis=-1).reshape(out_shape)
+ if inverse:
+ H_mat = torch.inverse(H_mat)
+ warped_points = torch.einsum('...nj,...ji->...ni', points, H_mat.transpose(-2, -1))
+
+ warped_points = from_homogeneous(warped_points, eps=1e-5)
+
+ return warped_points
+
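+# Illustrative sketch: an 8-parameter homography encoding a pure translation of
+# (10, 5) pixels (the ninth entry is fixed to 1 inside warp_points_torch):
+#   H = torch.tensor([1., 0., 10., 0., 1., 5., 0., 0.])
+#   pts = torch.rand(1, 4, 2) * 100
+#   warp_points_torch(pts, H, inverse=False)   # ~ pts + (10, 5)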
+
+def seg_equation(segs):
+ # calculate list of start, end and midpoints points from both lists
+ start_points, end_points = to_homogeneous(segs[..., 0, :]), to_homogeneous(segs[..., 1, :])
+ # Compute the line equations as ax + by + c = 0 , where x^2 + y^2 = 1
+ lines = torch.cross(start_points, end_points, dim=-1)
+ lines_norm = (torch.sqrt(lines[..., 0] ** 2 + lines[..., 1] ** 2)[..., None])
+ assert torch.all(lines_norm > 0), 'Error: trying to compute the equation of a line with a single point'
+ lines = lines / lines_norm
+ return lines
+
+
+def is_inside_img(pts: torch.Tensor, img_shape: Tuple[int, int]):
+ h, w = img_shape
+ return (pts >= 0).all(dim=-1) & (pts[..., 0] < w) & (pts[..., 1] < h) & (~torch.isinf(pts).any(dim=-1))
+
+
+def shrink_segs_to_img(segs: torch.Tensor, img_shape: Tuple[int, int]) -> torch.Tensor:
+ """
+ Shrink an array of segments to fit inside the image.
+ :param segs: The tensor of segments with shape (N, 2, 2)
+ :param img_shape: The image shape in format (H, W)
+ """
+ EPS = 1e-4
+ device = segs.device
+ w, h = img_shape[1], img_shape[0]
+ # Project the segments to the reference image
+ segs = segs.clone()
+ eqs = seg_equation(segs)
+ x0, y0 = torch.tensor([1., 0, 0.], device=device), torch.tensor([0., 1, 0], device=device)
+ x0 = x0.repeat(eqs.shape[:-1] + (1,))
+ y0 = y0.repeat(eqs.shape[:-1] + (1,))
+ pt_x0s = torch.cross(eqs, x0, dim=-1)
+ pt_x0s = pt_x0s[..., :-1] / pt_x0s[..., None, -1]
+ pt_x0s_valid = is_inside_img(pt_x0s, img_shape)
+ pt_y0s = torch.cross(eqs, y0, dim=-1)
+ pt_y0s = pt_y0s[..., :-1] / pt_y0s[..., None, -1]
+ pt_y0s_valid = is_inside_img(pt_y0s, img_shape)
+
+ xW, yH = torch.tensor([1., 0, EPS - w], device=device), torch.tensor([0., 1, EPS - h], device=device)
+ xW = xW.repeat(eqs.shape[:-1] + (1,))
+ yH = yH.repeat(eqs.shape[:-1] + (1,))
+ pt_xWs = torch.cross(eqs, xW, dim=-1)
+ pt_xWs = pt_xWs[..., :-1] / pt_xWs[..., None, -1]
+ pt_xWs_valid = is_inside_img(pt_xWs, img_shape)
+ pt_yHs = torch.cross(eqs, yH, dim=-1)
+ pt_yHs = pt_yHs[..., :-1] / pt_yHs[..., None, -1]
+ pt_yHs_valid = is_inside_img(pt_yHs, img_shape)
+
+ # If the X coordinate of the first endpoint is out
+ mask = (segs[..., 0, 0] < 0) & pt_x0s_valid
+ segs[mask, 0, :] = pt_x0s[mask]
+ mask = (segs[..., 0, 0] > (w - 1)) & pt_xWs_valid
+ segs[mask, 0, :] = pt_xWs[mask]
+ # If the X coordinate of the second endpoint is out
+ mask = (segs[..., 1, 0] < 0) & pt_x0s_valid
+ segs[mask, 1, :] = pt_x0s[mask]
+ mask = (segs[..., 1, 0] > (w - 1)) & pt_xWs_valid
+ segs[mask, 1, :] = pt_xWs[mask]
+ # If the Y coordinate of the first endpoint is out
+ mask = (segs[..., 0, 1] < 0) & pt_y0s_valid
+ segs[mask, 0, :] = pt_y0s[mask]
+ mask = (segs[..., 0, 1] > (h - 1)) & pt_yHs_valid
+ segs[mask, 0, :] = pt_yHs[mask]
+ # If the Y coordinate of the second endpoint is out
+ mask = (segs[..., 1, 1] < 0) & pt_y0s_valid
+ segs[mask, 1, :] = pt_y0s[mask]
+ mask = (segs[..., 1, 1] > (h - 1)) & pt_yHs_valid
+ segs[mask, 1, :] = pt_yHs[mask]
+
+ assert torch.all(segs >= 0) and torch.all(segs[..., 0] < w) and torch.all(segs[..., 1] < h)
+ return segs
+
+
+def warp_lines_torch(lines, H, inverse=True, dst_shape: Tuple[int, int] = None) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ :param lines: A tensor of shape (B, N, 2, 2) where B is the batch size, N the number of lines.
+ :param H: The homography used to convert the lines. batched or not (shapes (B, 8) and (8,) respectively).
+ :param inverse: Whether to apply H or the inverse of H
+ :param dst_shape: If provided, lines are trimmed to fit inside the image
+ """
+ device = lines.device
+ batch_size, n = lines.shape[:2]
+ lines = warp_points_torch(lines.reshape(batch_size, -1, 2), H, inverse).reshape(lines.shape)
+
+ if dst_shape is None:
+ return lines, torch.ones(lines.shape[:-2], dtype=torch.bool, device=device)
+
+ out_img = torch.any((lines < 0) | (lines >= torch.tensor(dst_shape[::-1], device=device)), -1)
+ valid = ~out_img.all(-1)
+ any_out_of_img = out_img.any(-1)
+ lines_to_trim = valid & any_out_of_img
+
+ for b in range(batch_size):
+ lines_to_trim_mask_b = lines_to_trim[b]
+ lines_to_trim_b = lines[b][lines_to_trim_mask_b]
+ corrected_lines = shrink_segs_to_img(lines_to_trim_b, dst_shape)
+ lines[b][lines_to_trim_mask_b] = corrected_lines
+
+ return lines, valid
diff --git a/imcui/third_party/GlueStick/gluestick/models/__init__.py b/imcui/third_party/GlueStick/gluestick/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/imcui/third_party/GlueStick/gluestick/models/base_model.py b/imcui/third_party/GlueStick/gluestick/models/base_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..30ca991655a28ca88074b42312c33b360f655fab
--- /dev/null
+++ b/imcui/third_party/GlueStick/gluestick/models/base_model.py
@@ -0,0 +1,126 @@
+"""
+Base class for trainable models.
+"""
+
+from abc import ABCMeta, abstractmethod
+import omegaconf
+from omegaconf import OmegaConf
+from torch import nn
+from copy import copy
+
+
+class MetaModel(ABCMeta):
+ def __prepare__(name, bases, **kwds):
+ total_conf = OmegaConf.create()
+ for base in bases:
+ for key in ('base_default_conf', 'default_conf'):
+ update = getattr(base, key, {})
+ if isinstance(update, dict):
+ update = OmegaConf.create(update)
+ total_conf = OmegaConf.merge(total_conf, update)
+ return dict(base_default_conf=total_conf)
+
+
+class BaseModel(nn.Module, metaclass=MetaModel):
+ """
+ What the child model is expected to declare:
+ default_conf: dictionary of the default configuration of the model.
+ It recursively updates the default_conf of all parent classes, and
+ it is updated by the user-provided configuration passed to __init__.
+ Configurations can be nested.
+
+ required_data_keys: list of expected keys in the input data dictionary.
+
+ strict_conf (optional): boolean. If false, BaseModel does not raise
+ an error when the user provides an unknown configuration entry.
+
+ _init(self, conf): initialization method, where conf is the final
+ configuration object (also accessible with `self.conf`). Accessing
+ unknown configuration entries will raise an error.
+
+ _forward(self, data): method that returns a dictionary of batched
+ prediction tensors based on a dictionary of batched input data tensors.
+
+ loss(self, pred, data): method that returns a dictionary of losses,
+ computed from model predictions and input data. Each loss is a batch
+ of scalars, i.e. a torch.Tensor of shape (B,).
+ The total loss to be optimized has the key `'total'`.
+
+ metrics(self, pred, data): method that returns a dictionary of metrics,
+ each as a batch of scalars.
+ """
+ default_conf = {
+ 'name': None,
+ 'trainable': True, # if false: do not optimize this model's parameters
+ 'freeze_batch_normalization': False, # use test-time statistics
+ }
+ required_data_keys = []
+ strict_conf = True
+
+ def __init__(self, conf):
+ """Perform some logic and call the _init method of the child model."""
+ super().__init__()
+ default_conf = OmegaConf.merge(
+ self.base_default_conf, OmegaConf.create(self.default_conf))
+ if self.strict_conf:
+ OmegaConf.set_struct(default_conf, True)
+
+ # fixme: backward compatibility
+ if 'pad' in conf and 'pad' not in default_conf: # backward compat.
+ with omegaconf.read_write(conf):
+ with omegaconf.open_dict(conf):
+ conf['interpolation'] = {'pad': conf.pop('pad')}
+
+ if isinstance(conf, dict):
+ conf = OmegaConf.create(conf)
+ self.conf = conf = OmegaConf.merge(default_conf, conf)
+ OmegaConf.set_readonly(conf, True)
+ OmegaConf.set_struct(conf, True)
+ self.required_data_keys = copy(self.required_data_keys)
+ self._init(conf)
+
+ if not conf.trainable:
+ for p in self.parameters():
+ p.requires_grad = False
+
+ def train(self, mode=True):
+ super().train(mode)
+
+ def freeze_bn(module):
+ if isinstance(module, nn.modules.batchnorm._BatchNorm):
+ module.eval()
+ if self.conf.freeze_batch_normalization:
+ self.apply(freeze_bn)
+
+ return self
+
+ def forward(self, data):
+ """Check the data and call the _forward method of the child model."""
+ def recursive_key_check(expected, given):
+ for key in expected:
+ assert key in given, f'Missing key {key} in data'
+ if isinstance(expected, dict):
+ recursive_key_check(expected[key], given[key])
+
+ recursive_key_check(self.required_data_keys, data)
+ return self._forward(data)
+
+ @abstractmethod
+ def _init(self, conf):
+ """To be implemented by the child class."""
+ raise NotImplementedError
+
+ @abstractmethod
+ def _forward(self, data):
+ """To be implemented by the child class."""
+ raise NotImplementedError
+
+ @abstractmethod
+ def loss(self, pred, data):
+ """To be implemented by the child class."""
+ raise NotImplementedError
+
+ @abstractmethod
+ def metrics(self, pred, data):
+ """To be implemented by the child class."""
+ raise NotImplementedError
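+
+
+# --- Editor's illustration (not part of the upstream GlueStick code) ---------
+# A minimal child model following the contract documented in BaseModel above.
+# The class `_ToyModel`, its configuration entries and the tensor shapes are
+# invented for illustration only; the sketch runs when this file is executed
+# directly.
+if __name__ == '__main__':
+    import torch
+
+    class _ToyModel(BaseModel):
+        default_conf = {'hidden_dim': 8}  # merged on top of the inherited defaults
+        required_data_keys = ['x']
+
+        def _init(self, conf):
+            self.linear = nn.Linear(1, conf.hidden_dim)
+
+        def _forward(self, data):
+            return {'features': self.linear(data['x'])}
+
+        def loss(self, pred, data):
+            # one scalar per batch element, i.e. a tensor of shape (B,)
+            return {'total': pred['features'].abs().mean(dim=(1, 2))}
+
+        def metrics(self, pred, data):
+            return {}
+
+    model = _ToyModel({'hidden_dim': 4})  # user conf overrides the default
+    out = model({'x': torch.ones(2, 3, 1)})
+    print(out['features'].shape)  # torch.Size([2, 3, 4])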
diff --git a/imcui/third_party/GlueStick/gluestick/models/gluestick.py b/imcui/third_party/GlueStick/gluestick/models/gluestick.py
new file mode 100644
index 0000000000000000000000000000000000000000..98550ff9d8918bcf49a13ae606d1d631448b8f96
--- /dev/null
+++ b/imcui/third_party/GlueStick/gluestick/models/gluestick.py
@@ -0,0 +1,563 @@
+import os.path
+import warnings
+from copy import deepcopy
+
+warnings.filterwarnings("ignore", category=UserWarning)
+import torch
+import torch.utils.checkpoint
+from torch import nn
+from .base_model import BaseModel
+
+ETH_EPS = 1e-8
+
+
+class GlueStick(BaseModel):
+ default_conf = {
+ 'input_dim': 256,
+ 'descriptor_dim': 256,
+ 'bottleneck_dim': None,
+ 'weights': None,
+ 'keypoint_encoder': [32, 64, 128, 256],
+ 'GNN_layers': ['self', 'cross'] * 9,
+ 'num_line_iterations': 1,
+ 'line_attention': False,
+ 'filter_threshold': 0.2,
+ 'checkpointed': False,
+ 'skip_init': False,
+ 'inter_supervision': None,
+ 'loss': {
+ 'nll_weight': 1.,
+ 'nll_balancing': 0.5,
+ 'reward_weight': 0.,
+ 'bottleneck_l2_weight': 0.,
+ 'dense_nll_weight': 0.,
+ 'inter_supervision': [0.3, 0.6],
+ },
+ }
+ required_data_keys = [
+ 'keypoints0', 'keypoints1',
+ 'descriptors0', 'descriptors1',
+ 'keypoint_scores0', 'keypoint_scores1']
+
+ DEFAULT_LOSS_CONF = {'nll_weight': 1., 'nll_balancing': 0.5, 'reward_weight': 0., 'bottleneck_l2_weight': 0.}
+
+ def _init(self, conf):
+ if conf.bottleneck_dim is not None:
+ self.bottleneck_down = nn.Conv1d(
+ conf.input_dim, conf.bottleneck_dim, kernel_size=1)
+ self.bottleneck_up = nn.Conv1d(
+ conf.bottleneck_dim, conf.input_dim, kernel_size=1)
+ nn.init.constant_(self.bottleneck_down.bias, 0.0)
+ nn.init.constant_(self.bottleneck_up.bias, 0.0)
+
+ if conf.input_dim != conf.descriptor_dim:
+ self.input_proj = nn.Conv1d(
+ conf.input_dim, conf.descriptor_dim, kernel_size=1)
+ nn.init.constant_(self.input_proj.bias, 0.0)
+
+ self.kenc = KeypointEncoder(conf.descriptor_dim,
+ conf.keypoint_encoder)
+ self.lenc = EndPtEncoder(conf.descriptor_dim, conf.keypoint_encoder)
+ self.gnn = AttentionalGNN(conf.descriptor_dim, conf.GNN_layers,
+ checkpointed=conf.checkpointed,
+ inter_supervision=conf.inter_supervision,
+ num_line_iterations=conf.num_line_iterations,
+ line_attention=conf.line_attention)
+ self.final_proj = nn.Conv1d(conf.descriptor_dim, conf.descriptor_dim,
+ kernel_size=1)
+ nn.init.constant_(self.final_proj.bias, 0.0)
+ nn.init.orthogonal_(self.final_proj.weight, gain=1)
+ self.final_line_proj = nn.Conv1d(
+ conf.descriptor_dim, conf.descriptor_dim, kernel_size=1)
+ nn.init.constant_(self.final_line_proj.bias, 0.0)
+ nn.init.orthogonal_(self.final_line_proj.weight, gain=1)
+ if conf.inter_supervision is not None:
+ self.inter_line_proj = nn.ModuleList(
+ [nn.Conv1d(conf.descriptor_dim, conf.descriptor_dim, kernel_size=1)
+ for _ in conf.inter_supervision])
+ self.layer2idx = {}
+ for i, l in enumerate(conf.inter_supervision):
+ nn.init.constant_(self.inter_line_proj[i].bias, 0.0)
+ nn.init.orthogonal_(self.inter_line_proj[i].weight, gain=1)
+ self.layer2idx[l] = i
+
+ bin_score = torch.nn.Parameter(torch.tensor(1.))
+ self.register_parameter('bin_score', bin_score)
+ line_bin_score = torch.nn.Parameter(torch.tensor(1.))
+ self.register_parameter('line_bin_score', line_bin_score)
+
+ if conf.weights:
+ assert isinstance(conf.weights, str)
+ if os.path.exists(conf.weights):
+ state_dict = torch.load(conf.weights, map_location='cpu')
+ else:
+ weights_url = "https://github.com/cvg/GlueStick/releases/download/v0.1_arxiv/checkpoint_GlueStick_MD.tar"
+ state_dict = torch.hub.load_state_dict_from_url(weights_url, map_location='cpu')
+ if 'model' in state_dict:
+ state_dict = {k.replace('matcher.', ''): v for k, v in state_dict['model'].items() if 'matcher.' in k}
+ state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}
+ self.load_state_dict(state_dict)
+
+ def _forward(self, data):
+ device = data['keypoints0'].device
+ b_size = len(data['keypoints0'])
+ image_size0 = (data['image_size0'] if 'image_size0' in data
+ else data['image0'].shape)
+ image_size1 = (data['image_size1'] if 'image_size1' in data
+ else data['image1'].shape)
+
+ pred = {}
+ desc0, desc1 = data['descriptors0'], data['descriptors1']
+ kpts0, kpts1 = data['keypoints0'], data['keypoints1']
+
+ n_kpts0, n_kpts1 = kpts0.shape[1], kpts1.shape[1]
+ n_lines0, n_lines1 = data['lines0'].shape[1], data['lines1'].shape[1]
+ if n_kpts0 == 0 or n_kpts1 == 0:
+            # No keypoints detected: return empty point and line matches
+ pred['log_assignment'] = torch.zeros(
+ b_size, n_kpts0, n_kpts1, dtype=torch.float, device=device)
+ pred['matches0'] = torch.full(
+ (b_size, n_kpts0), -1, device=device, dtype=torch.int64)
+ pred['matches1'] = torch.full(
+ (b_size, n_kpts1), -1, device=device, dtype=torch.int64)
+ pred['match_scores0'] = torch.zeros(
+ (b_size, n_kpts0), device=device, dtype=torch.float32)
+ pred['match_scores1'] = torch.zeros(
+ (b_size, n_kpts1), device=device, dtype=torch.float32)
+ pred['line_log_assignment'] = torch.zeros(b_size, n_lines0, n_lines1,
+ dtype=torch.float, device=device)
+ pred['line_matches0'] = torch.full((b_size, n_lines0), -1,
+ device=device, dtype=torch.int64)
+ pred['line_matches1'] = torch.full((b_size, n_lines1), -1,
+ device=device, dtype=torch.int64)
+ pred['line_match_scores0'] = torch.zeros(
+ (b_size, n_lines0), device=device, dtype=torch.float32)
+ pred['line_match_scores1'] = torch.zeros(
+                (b_size, n_lines1), device=device, dtype=torch.float32)
+ return pred
+
+ lines0 = data['lines0'].flatten(1, 2)
+ lines1 = data['lines1'].flatten(1, 2)
+ lines_junc_idx0 = data['lines_junc_idx0'].flatten(1, 2) # [b_size, num_lines * 2]
+ lines_junc_idx1 = data['lines_junc_idx1'].flatten(1, 2)
+
+ if self.conf.bottleneck_dim is not None:
+ pred['down_descriptors0'] = desc0 = self.bottleneck_down(desc0)
+ pred['down_descriptors1'] = desc1 = self.bottleneck_down(desc1)
+ desc0 = self.bottleneck_up(desc0)
+ desc1 = self.bottleneck_up(desc1)
+ desc0 = nn.functional.normalize(desc0, p=2, dim=1)
+ desc1 = nn.functional.normalize(desc1, p=2, dim=1)
+ pred['bottleneck_descriptors0'] = desc0
+ pred['bottleneck_descriptors1'] = desc1
+ if self.conf.loss.nll_weight == 0:
+ desc0 = desc0.detach()
+ desc1 = desc1.detach()
+
+ if self.conf.input_dim != self.conf.descriptor_dim:
+ desc0 = self.input_proj(desc0)
+ desc1 = self.input_proj(desc1)
+
+ kpts0 = normalize_keypoints(kpts0, image_size0)
+ kpts1 = normalize_keypoints(kpts1, image_size1)
+
+ assert torch.all(kpts0 >= -1) and torch.all(kpts0 <= 1)
+ assert torch.all(kpts1 >= -1) and torch.all(kpts1 <= 1)
+ desc0 = desc0 + self.kenc(kpts0, data['keypoint_scores0'])
+ desc1 = desc1 + self.kenc(kpts1, data['keypoint_scores1'])
+
+ if n_lines0 != 0 and n_lines1 != 0:
+ # Pre-compute the line encodings
+ lines0 = normalize_keypoints(lines0, image_size0).reshape(
+ b_size, n_lines0, 2, 2)
+ lines1 = normalize_keypoints(lines1, image_size1).reshape(
+ b_size, n_lines1, 2, 2)
+ line_enc0 = self.lenc(lines0, data['line_scores0'])
+ line_enc1 = self.lenc(lines1, data['line_scores1'])
+ else:
+ line_enc0 = torch.zeros(
+ b_size, self.conf.descriptor_dim, n_lines0 * 2,
+ dtype=torch.float, device=device)
+ line_enc1 = torch.zeros(
+ b_size, self.conf.descriptor_dim, n_lines1 * 2,
+ dtype=torch.float, device=device)
+
+ desc0, desc1 = self.gnn(desc0, desc1, line_enc0, line_enc1,
+ lines_junc_idx0, lines_junc_idx1)
+
+ # Match all points (KP and line junctions)
+ mdesc0, mdesc1 = self.final_proj(desc0), self.final_proj(desc1)
+
+ kp_scores = torch.einsum('bdn,bdm->bnm', mdesc0, mdesc1)
+ kp_scores = kp_scores / self.conf.descriptor_dim ** .5
+ kp_scores = log_double_softmax(kp_scores, self.bin_score)
+ m0, m1, mscores0, mscores1 = self._get_matches(kp_scores)
+ pred['log_assignment'] = kp_scores
+ pred['matches0'] = m0
+ pred['matches1'] = m1
+ pred['match_scores0'] = mscores0
+ pred['match_scores1'] = mscores1
+
+ # Match the lines
+ if n_lines0 > 0 and n_lines1 > 0:
+ (line_scores, m0_lines, m1_lines, mscores0_lines,
+ mscores1_lines, raw_line_scores) = self._get_line_matches(
+ desc0[:, :, :2 * n_lines0], desc1[:, :, :2 * n_lines1],
+ lines_junc_idx0, lines_junc_idx1, self.final_line_proj)
+ if self.conf.inter_supervision:
+                for l in self.conf.inter_supervision:
+                    (line_scores_i, m0_lines_i, m1_lines_i, mscores0_lines_i,
+                     mscores1_lines_i, _) = self._get_line_matches(
+                        self.gnn.inter_layers[l][0][:, :, :2 * n_lines0],
+                        self.gnn.inter_layers[l][1][:, :, :2 * n_lines1],
+                        lines_junc_idx0, lines_junc_idx1,
+                        self.inter_line_proj[self.layer2idx[l]])
+                    pred[f'line_{l}_log_assignment'] = line_scores_i
+                    pred[f'line_{l}_matches0'] = m0_lines_i
+                    pred[f'line_{l}_matches1'] = m1_lines_i
+                    pred[f'line_{l}_match_scores0'] = mscores0_lines_i
+                    pred[f'line_{l}_match_scores1'] = mscores1_lines_i
+ else:
+ line_scores = torch.zeros(b_size, n_lines0, n_lines1,
+ dtype=torch.float, device=device)
+ m0_lines = torch.full((b_size, n_lines0), -1,
+ device=device, dtype=torch.int64)
+ m1_lines = torch.full((b_size, n_lines1), -1,
+ device=device, dtype=torch.int64)
+ mscores0_lines = torch.zeros(
+ (b_size, n_lines0), device=device, dtype=torch.float32)
+ mscores1_lines = torch.zeros(
+ (b_size, n_lines1), device=device, dtype=torch.float32)
+ raw_line_scores = torch.zeros(b_size, n_lines0, n_lines1,
+ dtype=torch.float, device=device)
+ pred['line_log_assignment'] = line_scores
+ pred['line_matches0'] = m0_lines
+ pred['line_matches1'] = m1_lines
+ pred['line_match_scores0'] = mscores0_lines
+ pred['line_match_scores1'] = mscores1_lines
+ pred['raw_line_scores'] = raw_line_scores
+
+ return pred
+
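+    # Editor's note (added comment, not upstream code): _get_matches below
+    # performs a mutual-nearest-neighbor check on the log-assignment matrix:
+    # a pair (i, j) is kept only if j is the best match of i, i is the best
+    # match of j, and the exponentiated score exceeds conf.filter_threshold.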
+ def _get_matches(self, scores_mat):
+ max0 = scores_mat[:, :-1, :-1].max(2)
+ max1 = scores_mat[:, :-1, :-1].max(1)
+ m0, m1 = max0.indices, max1.indices
+ mutual0 = arange_like(m0, 1)[None] == m1.gather(1, m0)
+ mutual1 = arange_like(m1, 1)[None] == m0.gather(1, m1)
+ zero = scores_mat.new_tensor(0)
+ mscores0 = torch.where(mutual0, max0.values.exp(), zero)
+ mscores1 = torch.where(mutual1, mscores0.gather(1, m1), zero)
+ valid0 = mutual0 & (mscores0 > self.conf.filter_threshold)
+ valid1 = mutual1 & valid0.gather(1, m1)
+ m0 = torch.where(valid0, m0, m0.new_tensor(-1))
+ m1 = torch.where(valid1, m1, m1.new_tensor(-1))
+ return m0, m1, mscores0, mscores1
+
+ def _get_line_matches(self, ldesc0, ldesc1, lines_junc_idx0,
+ lines_junc_idx1, final_proj):
+ mldesc0 = final_proj(ldesc0)
+ mldesc1 = final_proj(ldesc1)
+
+ line_scores = torch.einsum('bdn,bdm->bnm', mldesc0, mldesc1)
+ line_scores = line_scores / self.conf.descriptor_dim ** .5
+
+ # Get the line representation from the junction descriptors
+ n2_lines0 = lines_junc_idx0.shape[1]
+ n2_lines1 = lines_junc_idx1.shape[1]
+ line_scores = torch.gather(
+ line_scores, dim=2,
+ index=lines_junc_idx1[:, None, :].repeat(1, line_scores.shape[1], 1))
+ line_scores = torch.gather(
+ line_scores, dim=1,
+ index=lines_junc_idx0[:, :, None].repeat(1, 1, n2_lines1))
+ line_scores = line_scores.reshape((-1, n2_lines0 // 2, 2,
+ n2_lines1 // 2, 2))
+
+ # Match either in one direction or the other
+ raw_line_scores = 0.5 * torch.maximum(
+ line_scores[:, :, 0, :, 0] + line_scores[:, :, 1, :, 1],
+ line_scores[:, :, 0, :, 1] + line_scores[:, :, 1, :, 0])
+ line_scores = log_double_softmax(raw_line_scores, self.line_bin_score)
+ m0_lines, m1_lines, mscores0_lines, mscores1_lines = self._get_matches(
+ line_scores)
+ return (line_scores, m0_lines, m1_lines, mscores0_lines,
+ mscores1_lines, raw_line_scores)
+
+ def loss(self, pred, data):
+ raise NotImplementedError()
+
+ def metrics(self, pred, data):
+ raise NotImplementedError()
+
+
+def MLP(channels, do_bn=True):
+ n = len(channels)
+ layers = []
+ for i in range(1, n):
+ layers.append(
+ nn.Conv1d(channels[i - 1], channels[i], kernel_size=1, bias=True))
+ if i < (n - 1):
+ if do_bn:
+ layers.append(nn.BatchNorm1d(channels[i]))
+ layers.append(nn.ReLU())
+ return nn.Sequential(*layers)
+
+
+def normalize_keypoints(kpts, shape_or_size):
+ if isinstance(shape_or_size, (tuple, list)):
+ # it's a shape
+ h, w = shape_or_size[-2:]
+ size = kpts.new_tensor([[w, h]])
+ else:
+ # it's a size
+ assert isinstance(shape_or_size, torch.Tensor)
+ size = shape_or_size.to(kpts)
+ c = size / 2
+ f = size.max(1, keepdim=True).values * 0.7 # somehow we used 0.7 for SG
+ return (kpts - c[:, None, :]) / f[:, None, :]
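+
+# Editor's note (illustration, not upstream code): for a 640x480 image,
+# normalize_keypoints(kpts, torch.tensor([[640., 480.]])) maps a keypoint at
+# (320, 240) to (0, 0) and divides coordinates by 0.7 * 640 = 448, so the
+# normalized values stay well within [-1, 1].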
+
+
+class KeypointEncoder(nn.Module):
+ def __init__(self, feature_dim, layers):
+ super().__init__()
+        self.encoder = MLP([3] + list(layers) + [feature_dim], do_bn=True)  # input: x, y, score
+ nn.init.constant_(self.encoder[-1].bias, 0.0)
+
+ def forward(self, kpts, scores):
+ inputs = [kpts.transpose(1, 2), scores.unsqueeze(1)]
+ return self.encoder(torch.cat(inputs, dim=1))
+
+
+class EndPtEncoder(nn.Module):
+ def __init__(self, feature_dim, layers):
+ super().__init__()
+        self.encoder = MLP([5] + list(layers) + [feature_dim], do_bn=True)  # input: x, y, offset_x, offset_y, score
+ nn.init.constant_(self.encoder[-1].bias, 0.0)
+
+ def forward(self, endpoints, scores):
+ # endpoints should be [B, N, 2, 2]
+ # output is [B, feature_dim, N * 2]
+ b_size, n_pts, _, _ = endpoints.shape
+ assert tuple(endpoints.shape[-2:]) == (2, 2)
+ endpt_offset = (endpoints[:, :, 1] - endpoints[:, :, 0]).unsqueeze(2)
+ endpt_offset = torch.cat([endpt_offset, -endpt_offset], dim=2)
+ endpt_offset = endpt_offset.reshape(b_size, 2 * n_pts, 2).transpose(1, 2)
+ inputs = [endpoints.flatten(1, 2).transpose(1, 2),
+ endpt_offset, scores.repeat(1, 2).unsqueeze(1)]
+ return self.encoder(torch.cat(inputs, dim=1))
+
+
+@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
+def attention(query, key, value):
+ dim = query.shape[1]
+ scores = torch.einsum('bdhn,bdhm->bhnm', query, key) / dim ** .5
+ prob = torch.nn.functional.softmax(scores, dim=-1)
+ return torch.einsum('bhnm,bdhm->bdhn', prob, value), prob
+
+
+class MultiHeadedAttention(nn.Module):
+ def __init__(self, h, d_model):
+ super().__init__()
+ assert d_model % h == 0
+ self.dim = d_model // h
+ self.h = h
+ self.merge = nn.Conv1d(d_model, d_model, kernel_size=1)
+ self.proj = nn.ModuleList([deepcopy(self.merge) for _ in range(3)])
+ # self.prob = []
+
+ def forward(self, query, key, value):
+ b = query.size(0)
+ query, key, value = [l(x).view(b, self.dim, self.h, -1)
+ for l, x in zip(self.proj, (query, key, value))]
+ x, prob = attention(query, key, value)
+ # self.prob.append(prob.mean(dim=1))
+ return self.merge(x.contiguous().view(b, self.dim * self.h, -1))
+
+
+class AttentionalPropagation(nn.Module):
+ def __init__(self, num_dim, num_heads, skip_init=False):
+ super().__init__()
+ self.attn = MultiHeadedAttention(num_heads, num_dim)
+ self.mlp = MLP([num_dim * 2, num_dim * 2, num_dim], do_bn=True)
+ nn.init.constant_(self.mlp[-1].bias, 0.0)
+ if skip_init:
+ self.register_parameter('scaling', nn.Parameter(torch.tensor(0.)))
+ else:
+ self.scaling = 1.
+
+ def forward(self, x, source):
+ message = self.attn(x, source, source)
+ return self.mlp(torch.cat([x, message], dim=1)) * self.scaling
+
+
+class GNNLayer(nn.Module):
+ def __init__(self, feature_dim, layer_type, skip_init):
+ super().__init__()
+ assert layer_type in ['cross', 'self']
+ self.type = layer_type
+ self.update = AttentionalPropagation(feature_dim, 4, skip_init)
+
+ def forward(self, desc0, desc1):
+ if self.type == 'cross':
+ src0, src1 = desc1, desc0
+ elif self.type == 'self':
+ src0, src1 = desc0, desc1
+ else:
+ raise ValueError("Unknown layer type: " + self.type)
+ # self.update.attn.prob = []
+ delta0, delta1 = self.update(desc0, src0), self.update(desc1, src1)
+ desc0, desc1 = (desc0 + delta0), (desc1 + delta1)
+ return desc0, desc1
+
+
+class LineLayer(nn.Module):
+ def __init__(self, feature_dim, line_attention=False):
+ super().__init__()
+ self.dim = feature_dim
+ self.mlp = MLP([self.dim * 3, self.dim * 2, self.dim], do_bn=True)
+ self.line_attention = line_attention
+ if line_attention:
+ self.proj_node = nn.Conv1d(self.dim, self.dim, kernel_size=1)
+ self.proj_neigh = nn.Conv1d(2 * self.dim, self.dim, kernel_size=1)
+
+ def get_endpoint_update(self, ldesc, line_enc, lines_junc_idx):
+ # ldesc is [bs, D, n_junc], line_enc [bs, D, n_lines * 2]
+ # and lines_junc_idx [bs, n_lines * 2]
+ # Create one message per line endpoint
+ b_size = lines_junc_idx.shape[0]
+ line_desc = torch.gather(
+ ldesc, 2, lines_junc_idx[:, None].repeat(1, self.dim, 1))
+ message = torch.cat([
+ line_desc,
+ line_desc.reshape(b_size, self.dim, -1, 2).flip([-1]).flatten(2, 3).clone(),
+ line_enc], dim=1)
+ return self.mlp(message) # [b_size, D, n_lines * 2]
+
+ def get_endpoint_attention(self, ldesc, line_enc, lines_junc_idx):
+ # ldesc is [bs, D, n_junc], line_enc [bs, D, n_lines * 2]
+ # and lines_junc_idx [bs, n_lines * 2]
+ b_size = lines_junc_idx.shape[0]
+ expanded_lines_junc_idx = lines_junc_idx[:, None].repeat(1, self.dim, 1)
+
+ # Query: desc of the current node
+ query = self.proj_node(ldesc) # [b_size, D, n_junc]
+ query = torch.gather(query, 2, expanded_lines_junc_idx)
+ # query is [b_size, D, n_lines * 2]
+
+ # Key: combination of neighboring desc and line encodings
+ line_desc = torch.gather(ldesc, 2, expanded_lines_junc_idx)
+ key = self.proj_neigh(torch.cat([
+ line_desc.reshape(b_size, self.dim, -1, 2).flip([-1]).flatten(2, 3).clone(),
+ line_enc], dim=1)) # [b_size, D, n_lines * 2]
+
+ # Compute the attention weights with a custom softmax per junction
+ prob = (query * key).sum(dim=1) / self.dim ** .5 # [b_size, n_lines * 2]
+ prob = torch.exp(prob - prob.max())
+ denom = torch.zeros_like(ldesc[:, 0]).scatter_reduce_(
+ dim=1, index=lines_junc_idx,
+ src=prob, reduce='sum', include_self=False) # [b_size, n_junc]
+ denom = torch.gather(denom, 1, lines_junc_idx) # [b_size, n_lines * 2]
+ prob = prob / (denom + ETH_EPS)
+ return prob # [b_size, n_lines * 2]
+
+ def forward(self, ldesc0, ldesc1, line_enc0, line_enc1, lines_junc_idx0,
+ lines_junc_idx1):
+ # Gather the endpoint updates
+ lupdate0 = self.get_endpoint_update(ldesc0, line_enc0, lines_junc_idx0)
+ lupdate1 = self.get_endpoint_update(ldesc1, line_enc1, lines_junc_idx1)
+
+ update0, update1 = torch.zeros_like(ldesc0), torch.zeros_like(ldesc1)
+ dim = ldesc0.shape[1]
+ if self.line_attention:
+ # Compute an attention for each neighbor and do a weighted average
+ prob0 = self.get_endpoint_attention(ldesc0, line_enc0,
+ lines_junc_idx0)
+ lupdate0 = lupdate0 * prob0[:, None]
+ update0 = update0.scatter_reduce_(
+ dim=2, index=lines_junc_idx0[:, None].repeat(1, dim, 1),
+ src=lupdate0, reduce='sum', include_self=False)
+ prob1 = self.get_endpoint_attention(ldesc1, line_enc1,
+ lines_junc_idx1)
+ lupdate1 = lupdate1 * prob1[:, None]
+ update1 = update1.scatter_reduce_(
+ dim=2, index=lines_junc_idx1[:, None].repeat(1, dim, 1),
+ src=lupdate1, reduce='sum', include_self=False)
+ else:
+ # Average the updates for each junction (requires torch > 1.12)
+ update0 = update0.scatter_reduce_(
+ dim=2, index=lines_junc_idx0[:, None].repeat(1, dim, 1),
+ src=lupdate0, reduce='mean', include_self=False)
+ update1 = update1.scatter_reduce_(
+ dim=2, index=lines_junc_idx1[:, None].repeat(1, dim, 1),
+ src=lupdate1, reduce='mean', include_self=False)
+
+ # Update
+ ldesc0 = ldesc0 + update0
+ ldesc1 = ldesc1 + update1
+
+ return ldesc0, ldesc1
+
+
+class AttentionalGNN(nn.Module):
+ def __init__(self, feature_dim, layer_types, checkpointed=False,
+ skip=False, inter_supervision=None, num_line_iterations=1,
+ line_attention=False):
+ super().__init__()
+ self.checkpointed = checkpointed
+ self.inter_supervision = inter_supervision
+ self.num_line_iterations = num_line_iterations
+ self.inter_layers = {}
+ self.layers = nn.ModuleList([
+ GNNLayer(feature_dim, layer_type, skip)
+ for layer_type in layer_types])
+ self.line_layers = nn.ModuleList(
+ [LineLayer(feature_dim, line_attention)
+ for _ in range(len(layer_types) // 2)])
+
+ def forward(self, desc0, desc1, line_enc0, line_enc1,
+ lines_junc_idx0, lines_junc_idx1):
+ for i, layer in enumerate(self.layers):
+ if self.checkpointed:
+ desc0, desc1 = torch.utils.checkpoint.checkpoint(
+ layer, desc0, desc1, preserve_rng_state=False)
+ else:
+ desc0, desc1 = layer(desc0, desc1)
+ if (layer.type == 'self' and lines_junc_idx0.shape[1] > 0
+ and lines_junc_idx1.shape[1] > 0):
+ # Add line self attention layers after every self layer
+ for _ in range(self.num_line_iterations):
+ if self.checkpointed:
+ desc0, desc1 = torch.utils.checkpoint.checkpoint(
+ self.line_layers[i // 2], desc0, desc1, line_enc0,
+ line_enc1, lines_junc_idx0, lines_junc_idx1,
+ preserve_rng_state=False)
+ else:
+ desc0, desc1 = self.line_layers[i // 2](
+ desc0, desc1, line_enc0, line_enc1,
+ lines_junc_idx0, lines_junc_idx1)
+
+ # Optionally store the line descriptor at intermediate layers
+ if (self.inter_supervision is not None
+ and (i // 2) in self.inter_supervision
+ and layer.type == 'cross'):
+ self.inter_layers[i // 2] = (desc0.clone(), desc1.clone())
+ return desc0, desc1
+
+
+def log_double_softmax(scores, bin_score):
+ b, m, n = scores.shape
+ bin_ = bin_score[None, None, None]
+ scores0 = torch.cat([scores, bin_.expand(b, m, 1)], 2)
+ scores1 = torch.cat([scores, bin_.expand(b, 1, n)], 1)
+ scores0 = torch.nn.functional.log_softmax(scores0, 2)
+ scores1 = torch.nn.functional.log_softmax(scores1, 1)
+ scores = scores.new_full((b, m + 1, n + 1), 0)
+ scores[:, :m, :n] = (scores0[:, :, :n] + scores1[:, :m, :]) / 2
+ scores[:, :-1, -1] = scores0[:, :, -1]
+ scores[:, -1, :-1] = scores1[:, -1, :]
+ return scores
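+
+# Editor's note (illustration, not upstream code): log_double_softmax turns an
+# m x n similarity matrix into an (m+1) x (n+1) log-assignment matrix by
+# appending a learnable "dustbin" score for unmatched rows/columns and
+# averaging a row-wise and a column-wise log-softmax, e.g.:
+#   sim = torch.randn(1, 5, 7)
+#   log_assignment = log_double_softmax(sim, torch.tensor(1.))
+#   assert log_assignment.shape == (1, 6, 8)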
+
+
+def arange_like(x, dim):
+ return x.new_ones(x.shape[dim]).cumsum(0) - 1 # traceable in 1.1
diff --git a/imcui/third_party/GlueStick/gluestick/models/superpoint.py b/imcui/third_party/GlueStick/gluestick/models/superpoint.py
new file mode 100644
index 0000000000000000000000000000000000000000..872063275f4fde27f552bf2c2674dc60d5220ec9
--- /dev/null
+++ b/imcui/third_party/GlueStick/gluestick/models/superpoint.py
@@ -0,0 +1,229 @@
+"""
+Inference model of SuperPoint, a feature detector and descriptor.
+
+Described in:
+ SuperPoint: Self-Supervised Interest Point Detection and Description,
+ Daniel DeTone, Tomasz Malisiewicz, Andrew Rabinovich, CVPRW 2018.
+
+Original code: github.com/MagicLeapResearch/SuperPointPretrainedNetwork
+"""
+
+import torch
+from torch import nn
+
+from .. import GLUESTICK_ROOT
+from ..models.base_model import BaseModel
+
+
+def simple_nms(scores, radius):
+    """Perform non-maximum suppression on the heatmap using max-pooling.
+    This method does not suppress contiguous points that have the same score.
+    Args:
+        scores: the score heatmap of size `(B, H, W)`.
+        radius: an integer scalar, the radius of the NMS window.
+ """
+
+ def max_pool(x):
+ return torch.nn.functional.max_pool2d(
+ x, kernel_size=radius * 2 + 1, stride=1, padding=radius)
+
+ zeros = torch.zeros_like(scores)
+ max_mask = scores == max_pool(scores)
+ for _ in range(2):
+ supp_mask = max_pool(max_mask.float()) > 0
+ supp_scores = torch.where(supp_mask, zeros, scores)
+ new_max_mask = supp_scores == max_pool(supp_scores)
+ max_mask = max_mask | (new_max_mask & (~supp_mask))
+ return torch.where(max_mask, scores, zeros)
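+
+# Editor's note (illustration, not upstream code):
+#   heat = torch.rand(1, 480, 640)
+#   nms_heat = simple_nms(heat, radius=4)  # same shape, non-maxima set to zero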
+
+
+def remove_borders(keypoints, scores, b, h, w):
+ mask_h = (keypoints[:, 0] >= b) & (keypoints[:, 0] < (h - b))
+ mask_w = (keypoints[:, 1] >= b) & (keypoints[:, 1] < (w - b))
+ mask = mask_h & mask_w
+ return keypoints[mask], scores[mask]
+
+
+def top_k_keypoints(keypoints, scores, k):
+ if k >= len(keypoints):
+ return keypoints, scores
+ scores, indices = torch.topk(scores, k, dim=0, sorted=True)
+ return keypoints[indices], scores
+
+
+def sample_descriptors(keypoints, descriptors, s):
+ b, c, h, w = descriptors.shape
+ keypoints = keypoints - s / 2 + 0.5
+ keypoints /= torch.tensor([(w * s - s / 2 - 0.5), (h * s - s / 2 - 0.5)],
+ ).to(keypoints)[None]
+ keypoints = keypoints * 2 - 1 # normalize to (-1, 1)
+ args = {'align_corners': True} if torch.__version__ >= '1.3' else {}
+ descriptors = torch.nn.functional.grid_sample(
+ descriptors, keypoints.view(b, 1, -1, 2), mode='bilinear', **args)
+ descriptors = torch.nn.functional.normalize(
+ descriptors.reshape(b, c, -1), p=2, dim=1)
+ return descriptors
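+
+# Editor's note (illustration, not upstream code): sample_descriptors bilinearly
+# interpolates the dense descriptor map at the given (x, y) keypoint locations
+# (in pixels of the full-resolution image), with s the downsampling stride of
+# the map (8 for SuperPoint). Keypoints of shape (1, N, 2) and a map of shape
+# (1, 256, H/8, W/8) yield L2-normalized descriptors of shape (1, 256, N).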
+
+
+class SuperPoint(BaseModel):
+ default_conf = {
+ 'has_detector': True,
+ 'has_descriptor': True,
+ 'descriptor_dim': 256,
+
+ # Inference
+ 'return_all': False,
+ 'sparse_outputs': True,
+ 'nms_radius': 4,
+ 'detection_threshold': 0.005,
+ 'max_num_keypoints': -1,
+ 'force_num_keypoints': False,
+ 'remove_borders': 4,
+ }
+ required_data_keys = ['image']
+
+ def _init(self, conf):
+ self.relu = nn.ReLU(inplace=True)
+ self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
+ c1, c2, c3, c4, c5 = 64, 64, 128, 128, 256
+
+ self.conv1a = nn.Conv2d(1, c1, kernel_size=3, stride=1, padding=1)
+ self.conv1b = nn.Conv2d(c1, c1, kernel_size=3, stride=1, padding=1)
+ self.conv2a = nn.Conv2d(c1, c2, kernel_size=3, stride=1, padding=1)
+ self.conv2b = nn.Conv2d(c2, c2, kernel_size=3, stride=1, padding=1)
+ self.conv3a = nn.Conv2d(c2, c3, kernel_size=3, stride=1, padding=1)
+ self.conv3b = nn.Conv2d(c3, c3, kernel_size=3, stride=1, padding=1)
+ self.conv4a = nn.Conv2d(c3, c4, kernel_size=3, stride=1, padding=1)
+ self.conv4b = nn.Conv2d(c4, c4, kernel_size=3, stride=1, padding=1)
+
+ if conf.has_detector:
+ self.convPa = nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1)
+ self.convPb = nn.Conv2d(c5, 65, kernel_size=1, stride=1, padding=0)
+
+ if conf.has_descriptor:
+ self.convDa = nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1)
+ self.convDb = nn.Conv2d(
+ c5, conf.descriptor_dim, kernel_size=1, stride=1, padding=0)
+
+ path = GLUESTICK_ROOT / 'resources' / 'weights' / 'superpoint_v1.pth'
+ if path.exists():
+ weights = torch.load(str(path), map_location='cpu')
+ else:
+ weights_url = "https://github.com/cvg/GlueStick/raw/main/resources/weights/superpoint_v1.pth"
+ weights = torch.hub.load_state_dict_from_url(weights_url, map_location='cpu')
+ self.load_state_dict(weights, strict=False)
+
+ def _forward(self, data):
+ image = data['image']
+ if image.shape[1] == 3: # RGB
+ scale = image.new_tensor([0.299, 0.587, 0.114]).view(1, 3, 1, 1)
+ image = (image * scale).sum(1, keepdim=True)
+
+ # Shared Encoder
+ x = self.relu(self.conv1a(image))
+ x = self.relu(self.conv1b(x))
+ x = self.pool(x)
+ x = self.relu(self.conv2a(x))
+ x = self.relu(self.conv2b(x))
+ x = self.pool(x)
+ x = self.relu(self.conv3a(x))
+ x = self.relu(self.conv3b(x))
+ x = self.pool(x)
+ x = self.relu(self.conv4a(x))
+ x = self.relu(self.conv4b(x))
+
+ pred = {}
+ if self.conf.has_detector and self.conf.max_num_keypoints != 0:
+ # Compute the dense keypoint scores
+ cPa = self.relu(self.convPa(x))
+ scores = self.convPb(cPa)
+ scores = torch.nn.functional.softmax(scores, 1)[:, :-1]
+ b, c, h, w = scores.shape
+ scores = scores.permute(0, 2, 3, 1).reshape(b, h, w, 8, 8)
+ scores = scores.permute(0, 1, 3, 2, 4).reshape(b, h * 8, w * 8)
+ pred['keypoint_scores'] = dense_scores = scores
+ if self.conf.has_descriptor:
+ # Compute the dense descriptors
+ cDa = self.relu(self.convDa(x))
+ all_desc = self.convDb(cDa)
+ all_desc = torch.nn.functional.normalize(all_desc, p=2, dim=1)
+ pred['descriptors'] = all_desc
+
+ if self.conf.max_num_keypoints == 0: # Predict dense descriptors only
+ b_size = len(image)
+ device = image.device
+ return {
+ 'keypoints': torch.empty(b_size, 0, 2, device=device),
+ 'keypoint_scores': torch.empty(b_size, 0, device=device),
+ 'descriptors': torch.empty(b_size, self.conf.descriptor_dim, 0, device=device),
+ 'all_descriptors': all_desc
+ }
+
+ if self.conf.sparse_outputs:
+ assert self.conf.has_detector and self.conf.has_descriptor
+
+ scores = simple_nms(scores, self.conf.nms_radius)
+
+ # Extract keypoints
+ keypoints = [
+ torch.nonzero(s > self.conf.detection_threshold)
+ for s in scores]
+ scores = [s[tuple(k.t())] for s, k in zip(scores, keypoints)]
+
+ # Discard keypoints near the image borders
+ keypoints, scores = list(zip(*[
+ remove_borders(k, s, self.conf.remove_borders, h * 8, w * 8)
+ for k, s in zip(keypoints, scores)]))
+
+ # Keep the k keypoints with highest score
+ if self.conf.max_num_keypoints > 0:
+ keypoints, scores = list(zip(*[
+ top_k_keypoints(k, s, self.conf.max_num_keypoints)
+ for k, s in zip(keypoints, scores)]))
+
+ # Convert (h, w) to (x, y)
+ keypoints = [torch.flip(k, [1]).float() for k in keypoints]
+
+ if self.conf.force_num_keypoints:
+ _, _, h, w = data['image'].shape
+ assert self.conf.max_num_keypoints > 0
+ scores = list(scores)
+ for i in range(len(keypoints)):
+ k, s = keypoints[i], scores[i]
+ missing = self.conf.max_num_keypoints - len(k)
+ if missing > 0:
+ new_k = torch.rand(missing, 2).to(k)
+ new_k = new_k * k.new_tensor([[w - 1, h - 1]])
+ new_s = torch.zeros(missing).to(s)
+ keypoints[i] = torch.cat([k, new_k], 0)
+ scores[i] = torch.cat([s, new_s], 0)
+
+ # Extract descriptors
+ desc = [sample_descriptors(k[None], d[None], 8)[0]
+ for k, d in zip(keypoints, all_desc)]
+
+ if (len(keypoints) == 1) or self.conf.force_num_keypoints:
+ keypoints = torch.stack(keypoints, 0)
+ scores = torch.stack(scores, 0)
+ desc = torch.stack(desc, 0)
+
+ pred = {
+ 'keypoints': keypoints,
+ 'keypoint_scores': scores,
+ 'descriptors': desc,
+ }
+
+ if self.conf.return_all:
+ pred['all_descriptors'] = all_desc
+ pred['dense_score'] = dense_scores
+ else:
+ del all_desc
+ torch.cuda.empty_cache()
+
+ return pred
+
+ def loss(self, pred, data):
+ raise NotImplementedError
+
+ def metrics(self, pred, data):
+ raise NotImplementedError
diff --git a/imcui/third_party/GlueStick/gluestick/models/two_view_pipeline.py b/imcui/third_party/GlueStick/gluestick/models/two_view_pipeline.py
new file mode 100644
index 0000000000000000000000000000000000000000..e0e21c1f62e2bd4ad573ebb87ea5635742b5032e
--- /dev/null
+++ b/imcui/third_party/GlueStick/gluestick/models/two_view_pipeline.py
@@ -0,0 +1,176 @@
+"""
+A two-view sparse feature matching pipeline.
+
+This model contains sub-models for each step:
+ feature extraction, feature matching, outlier filtering, pose estimation.
+Each step is optional, and the features or matches can be provided as input.
+Default: SuperPoint with nearest neighbor matching.
+
+Convention for the matches: m0[i] is the index of the keypoint in image 1
+that corresponds to the keypoint i in image 0. m0[i] = -1 if i is unmatched.
+"""
+
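+# Editor's note (illustration, not upstream code): with the convention above,
+# matched index pairs can be recovered from a prediction dict `pred` of this
+# pipeline (key names assumed from the matcher outputs used in this repo) as:
+#   m0 = pred['matches0'][0]              # (N0,) indices into image 1, or -1
+#   valid = m0 > -1
+#   pairs = torch.stack([torch.where(valid)[0], m0[valid]], -1)  # (M, 2)
+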
+import numpy as np
+import torch
+
+from .. import get_model
+from .base_model import BaseModel
+
+
+def keep_quadrant_kp_subset(keypoints, scores, descs, h, w):
+    """Keep only the keypoints lying in one of the four quadrants of the image."""
+ h2, w2 = h // 2, w // 2
+ w_x = np.random.choice([0, w2])
+ w_y = np.random.choice([0, h2])
+ valid_mask = ((keypoints[..., 0] >= w_x)
+ & (keypoints[..., 0] < w_x + w2)
+ & (keypoints[..., 1] >= w_y)
+ & (keypoints[..., 1] < w_y + h2))
+ keypoints = keypoints[valid_mask][None]
+ scores = scores[valid_mask][None]
+ descs = descs.permute(0, 2, 1)[valid_mask].t()[None]
+ return keypoints, scores, descs
+
+
+def keep_random_kp_subset(keypoints, scores, descs, num_selected):
+ """Keep a random subset of keypoints."""
+ num_kp = keypoints.shape[1]
+ selected_kp = torch.randperm(num_kp)[:num_selected]
+ keypoints = keypoints[:, selected_kp]
+ scores = scores[:, selected_kp]
+ descs = descs[:, :, selected_kp]
+ return keypoints, scores, descs
+
+
+def keep_best_kp_subset(keypoints, scores, descs, num_selected):
+ """Keep the top num_selected best keypoints."""
+ sorted_indices = torch.sort(scores, dim=1)[1]
+ selected_kp = sorted_indices[:, -num_selected:]
+ keypoints = torch.gather(keypoints, 1,
+ selected_kp[:, :, None].repeat(1, 1, 2))
+ scores = torch.gather(scores, 1, selected_kp)
+ descs = torch.gather(descs, 2,
+ selected_kp[:, None].repeat(1, descs.shape[1], 1))
+ return keypoints, scores, descs
+
+
+class TwoViewPipeline(BaseModel):
+ default_conf = {
+ 'extractor': {
+ 'name': 'superpoint',
+ 'trainable': False,
+ },
+ 'use_lines': False,
+ 'use_points': True,
+ 'randomize_num_kp': False,
+ 'detector': {'name': None},
+ 'descriptor': {'name': None},
+ 'matcher': {'name': 'nearest_neighbor_matcher'},
+ 'filter': {'name': None},
+ 'solver': {'name': None},
+ 'ground_truth': {
+ 'from_pose_depth': False,
+ 'from_homography': False,
+ 'th_positive': 3,
+ 'th_negative': 5,
+ 'reward_positive': 1,
+ 'reward_negative': -0.25,
+ 'is_likelihood_soft': True,
+ 'p_random_occluders': 0,
+ 'n_line_sampled_pts': 50,
+ 'line_perp_dist_th': 5,
+ 'overlap_th': 0.2,
+ 'min_visibility_th': 0.5
+ },
+ }
+ required_data_keys = ['image0', 'image1']
+ strict_conf = False # need to pass new confs to children models
+ components = [
+ 'extractor', 'detector', 'descriptor', 'matcher', 'filter', 'solver']
+
+ def _init(self, conf):
+ if conf.extractor.name:
+ self.extractor = get_model(conf.extractor.name)(conf.extractor)
+ else:
+ if self.conf.detector.name:
+ self.detector = get_model(conf.detector.name)(conf.detector)
+ else:
+ self.required_data_keys += ['keypoints0', 'keypoints1']
+ if self.conf.descriptor.name:
+ self.descriptor = get_model(conf.descriptor.name)(
+ conf.descriptor)
+ else:
+ self.required_data_keys += ['descriptors0', 'descriptors1']
+
+ if conf.matcher.name:
+ self.matcher = get_model(conf.matcher.name)(conf.matcher)
+ else:
+ self.required_data_keys += ['matches0']
+
+ if conf.filter.name:
+ self.filter = get_model(conf.filter.name)(conf.filter)
+
+ if conf.solver.name:
+ self.solver = get_model(conf.solver.name)(conf.solver)
+
+ def _forward(self, data):
+
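+        # Editor's note (added comment, not upstream code): keys ending in
+        # '0'/'1' (e.g. 'image0') belong to the first/second view;
+        # process_siamese strips the suffix and runs the shared
+        # extractor/detector/descriptor on each view independently.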
+ def process_siamese(data, i):
+ data_i = {k[:-1]: v for k, v in data.items() if k[-1] == i}
+ if self.conf.extractor.name:
+ pred_i = self.extractor(data_i)
+ else:
+ pred_i = {}
+ if self.conf.detector.name:
+ pred_i = self.detector(data_i)
+ else:
+ for k in ['keypoints', 'keypoint_scores', 'descriptors',
+ 'lines', 'line_scores', 'line_descriptors',
+ 'valid_lines']:
+ if k in data_i:
+ pred_i[k] = data_i[k]
+ if self.conf.descriptor.name:
+ pred_i = {
+ **pred_i, **self.descriptor({**data_i, **pred_i})}
+ return pred_i
+
+ pred0 = process_siamese(data, '0')
+ pred1 = process_siamese(data, '1')
+
+ pred = {**{k + '0': v for k, v in pred0.items()},
+ **{k + '1': v for k, v in pred1.items()}}
+
+ if self.conf.matcher.name:
+ pred = {**pred, **self.matcher({**data, **pred})}
+
+ if self.conf.filter.name:
+ pred = {**pred, **self.filter({**data, **pred})}
+
+ if self.conf.solver.name:
+ pred = {**pred, **self.solver({**data, **pred})}
+
+ return pred
+
+ def loss(self, pred, data):
+ losses = {}
+ total = 0
+ for k in self.components:
+ if self.conf[k].name:
+ try:
+ losses_ = getattr(self, k).loss(pred, {**pred, **data})
+ except NotImplementedError:
+ continue
+ losses = {**losses, **losses_}
+ total = losses_['total'] + total
+ return {**losses, 'total': total}
+
+ def metrics(self, pred, data):
+ metrics = {}
+ for k in self.components:
+ if self.conf[k].name:
+ try:
+ metrics_ = getattr(self, k).metrics(pred, {**pred, **data})
+ except NotImplementedError:
+ continue
+ metrics = {**metrics, **metrics_}
+ return metrics
diff --git a/imcui/third_party/GlueStick/gluestick/models/wireframe.py b/imcui/third_party/GlueStick/gluestick/models/wireframe.py
new file mode 100644
index 0000000000000000000000000000000000000000..0e3dd9873c6fdb4edcb4c75a103673ee2cb3b3fa
--- /dev/null
+++ b/imcui/third_party/GlueStick/gluestick/models/wireframe.py
@@ -0,0 +1,274 @@
+import numpy as np
+import torch
+from pytlsd import lsd
+from sklearn.cluster import DBSCAN
+
+from .base_model import BaseModel
+from .superpoint import SuperPoint, sample_descriptors
+from ..geometry import warp_lines_torch
+
+
+def lines_to_wireframe(lines, line_scores, all_descs, conf):
+    """ Given a set of lines, their scores and dense descriptors,
+ merge close-by endpoints and compute a wireframe defined by
+ its junctions and connectivity.
+ Returns:
+ junctions: list of [num_junc, 2] tensors listing all wireframe junctions
+ junc_scores: list of [num_junc] tensors with the junction score
+ junc_descs: list of [dim, num_junc] tensors with the junction descriptors
+ connectivity: list of [num_junc, num_junc] bool arrays with True when 2 junctions are connected
+ new_lines: the new set of [b_size, num_lines, 2, 2] lines
+ lines_junc_idx: a [b_size, num_lines, 2] tensor with the indices of the junctions of each endpoint
+ num_true_junctions: a list of the number of valid junctions for each image in the batch,
+ i.e. before filling with random ones
+ """
+ b_size, _, _, _ = all_descs.shape
+ device = lines.device
+ endpoints = lines.reshape(b_size, -1, 2)
+
+ (junctions, junc_scores, junc_descs, connectivity, new_lines,
+ lines_junc_idx, num_true_junctions) = [], [], [], [], [], [], []
+ for bs in range(b_size):
+ # Cluster the junctions that are close-by
+ db = DBSCAN(eps=conf.nms_radius, min_samples=1).fit(
+ endpoints[bs].cpu().numpy())
+ clusters = db.labels_
+ n_clusters = len(set(clusters))
+ num_true_junctions.append(n_clusters)
+
+ # Compute the average junction and score for each cluster
+ clusters = torch.tensor(clusters, dtype=torch.long,
+ device=device)
+ new_junc = torch.zeros(n_clusters, 2, dtype=torch.float,
+ device=device)
+ new_junc.scatter_reduce_(0, clusters[:, None].repeat(1, 2),
+ endpoints[bs], reduce='mean',
+ include_self=False)
+ junctions.append(new_junc)
+ new_scores = torch.zeros(n_clusters, dtype=torch.float, device=device)
+ new_scores.scatter_reduce_(
+ 0, clusters, torch.repeat_interleave(line_scores[bs], 2),
+ reduce='mean', include_self=False)
+ junc_scores.append(new_scores)
+
+ # Compute the new lines
+ new_lines.append(junctions[-1][clusters].reshape(-1, 2, 2))
+ lines_junc_idx.append(clusters.reshape(-1, 2))
+
+ # Compute the junction connectivity
+ junc_connect = torch.eye(n_clusters, dtype=torch.bool,
+ device=device)
+ pairs = clusters.reshape(-1, 2) # these pairs are connected by a line
+ junc_connect[pairs[:, 0], pairs[:, 1]] = True
+ junc_connect[pairs[:, 1], pairs[:, 0]] = True
+ connectivity.append(junc_connect)
+
+ # Interpolate the new junction descriptors
+ junc_descs.append(sample_descriptors(
+ junctions[-1][None], all_descs[bs:(bs + 1)], 8)[0])
+
+ new_lines = torch.stack(new_lines, dim=0)
+ lines_junc_idx = torch.stack(lines_junc_idx, dim=0)
+ return (junctions, junc_scores, junc_descs, connectivity,
+ new_lines, lines_junc_idx, num_true_junctions)
+
+
+class SPWireframeDescriptor(BaseModel):
+ default_conf = {
+ 'sp_params': {
+ 'has_detector': True,
+ 'has_descriptor': True,
+ 'descriptor_dim': 256,
+ 'trainable': False,
+
+ # Inference
+ 'return_all': True,
+ 'sparse_outputs': True,
+ 'nms_radius': 4,
+ 'detection_threshold': 0.005,
+ 'max_num_keypoints': 1000,
+ 'force_num_keypoints': True,
+ 'remove_borders': 4,
+ },
+ 'wireframe_params': {
+ 'merge_points': True,
+ 'merge_line_endpoints': True,
+ 'nms_radius': 3,
+ 'max_n_junctions': 500,
+ },
+ 'max_n_lines': 250,
+ 'min_length': 15,
+ }
+ required_data_keys = ['image']
+
+ def _init(self, conf):
+ self.conf = conf
+ self.sp = SuperPoint(conf.sp_params)
+
+ def detect_lsd_lines(self, x, max_n_lines=None):
+ if max_n_lines is None:
+ max_n_lines = self.conf.max_n_lines
+ lines, scores, valid_lines = [], [], []
+ for b in range(len(x)):
+ # For each image on batch
+ img = (x[b].squeeze().cpu().numpy() * 255).astype(np.uint8)
+ if max_n_lines is None:
+ b_segs = lsd(img)
+ else:
+ for s in [0.3, 0.4, 0.5, 0.7, 0.8, 1.0]:
+ b_segs = lsd(img, scale=s)
+ if len(b_segs) >= max_n_lines:
+ break
+
+ segs_length = np.linalg.norm(b_segs[:, 2:4] - b_segs[:, 0:2], axis=1)
+ # Remove short lines
+ b_segs = b_segs[segs_length >= self.conf.min_length]
+ segs_length = segs_length[segs_length >= self.conf.min_length]
+ b_scores = b_segs[:, -1] * np.sqrt(segs_length)
+            # Keep only the most relevant segments, i.e. those with the highest scores
+ indices = np.argsort(-b_scores)
+ if max_n_lines is not None:
+ indices = indices[:max_n_lines]
+ lines.append(torch.from_numpy(b_segs[indices, :4].reshape(-1, 2, 2)))
+ scores.append(torch.from_numpy(b_scores[indices]))
+ valid_lines.append(torch.ones_like(scores[-1], dtype=torch.bool))
+
+ lines = torch.stack(lines).to(x)
+ scores = torch.stack(scores).to(x)
+ valid_lines = torch.stack(valid_lines).to(x.device)
+ return lines, scores, valid_lines
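+
+    # Editor's note (added comment, not upstream code): detect_lsd_lines above
+    # runs the LSD detector at several scales until at least max_n_lines
+    # segments are found, drops segments shorter than conf.min_length and keeps
+    # the max_n_lines highest-scoring ones, where a segment's score is its LSD
+    # confidence weighted by the square root of its length.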
+
+ def _forward(self, data):
+ b_size, _, h, w = data['image'].shape
+ device = data['image'].device
+
+ if not self.conf.sp_params.force_num_keypoints:
+ assert b_size == 1, "Only batch size of 1 accepted for non padded inputs"
+
+ # Line detection
+ if 'lines' not in data or 'line_scores' not in data:
+ if 'original_img' in data:
+ # Detect more lines, because when projecting them to the image most of them will be discarded
+ lines, line_scores, valid_lines = self.detect_lsd_lines(
+ data['original_img'], self.conf.max_n_lines * 3)
+ # Apply the same transformation that is applied in homography_adaptation
+ lines, valid_lines2 = warp_lines_torch(lines, data['H'], False, data['image'].shape[-2:])
+ valid_lines = valid_lines & valid_lines2
+ lines[~valid_lines] = -1
+ line_scores[~valid_lines] = 0
+                # Re-sort the line segments to pick the ones that are inside the image and have a higher score
+ sorted_scores, sorting_indices = torch.sort(line_scores, dim=-1, descending=True)
+ line_scores = sorted_scores[:, :self.conf.max_n_lines]
+ sorting_indices = sorting_indices[:, :self.conf.max_n_lines]
+ lines = torch.take_along_dim(lines, sorting_indices[..., None, None], 1)
+ valid_lines = torch.take_along_dim(valid_lines, sorting_indices, 1)
+ else:
+ lines, line_scores, valid_lines = self.detect_lsd_lines(data['image'])
+
+ else:
+ lines, line_scores, valid_lines = data['lines'], data['line_scores'], data['valid_lines']
+ if line_scores.shape[-1] != 0:
+ line_scores /= (line_scores.new_tensor(1e-8) + line_scores.max(dim=1).values[:, None])
+
+ # SuperPoint prediction
+ pred = self.sp(data)
+
+ # Remove keypoints that are too close to line endpoints
+ if self.conf.wireframe_params.merge_points:
+ kp = pred['keypoints']
+ line_endpts = lines.reshape(b_size, -1, 2)
+ dist_pt_lines = torch.norm(
+ kp[:, :, None] - line_endpts[:, None], dim=-1)
+ # For each keypoint, mark it as valid or to remove
+ pts_to_remove = torch.any(
+ dist_pt_lines < self.conf.sp_params.nms_radius, dim=2)
+ # Simply remove them (we assume batch_size = 1 here)
+ assert len(kp) == 1
+ pred['keypoints'] = pred['keypoints'][0][~pts_to_remove[0]][None]
+ pred['keypoint_scores'] = pred['keypoint_scores'][0][~pts_to_remove[0]][None]
+ pred['descriptors'] = pred['descriptors'][0].T[~pts_to_remove[0]].T[None]
+
+ # Connect the lines together to form a wireframe
+ orig_lines = lines.clone()
+ if self.conf.wireframe_params.merge_line_endpoints and len(lines[0]) > 0:
+ # Merge first close-by endpoints to connect lines
+ (line_points, line_pts_scores, line_descs, line_association,
+ lines, lines_junc_idx, num_true_junctions) = lines_to_wireframe(
+ lines, line_scores, pred['all_descriptors'],
+ conf=self.conf.wireframe_params)
+
+ # Add the keypoints to the junctions and fill the rest with random keypoints
+ (all_points, all_scores, all_descs,
+ pl_associativity) = [], [], [], []
+ for bs in range(b_size):
+ all_points.append(torch.cat(
+ [line_points[bs], pred['keypoints'][bs]], dim=0))
+ all_scores.append(torch.cat(
+ [line_pts_scores[bs], pred['keypoint_scores'][bs]], dim=0))
+ all_descs.append(torch.cat(
+ [line_descs[bs], pred['descriptors'][bs]], dim=1))
+
+ associativity = torch.eye(len(all_points[-1]), dtype=torch.bool, device=device)
+ associativity[:num_true_junctions[bs], :num_true_junctions[bs]] = \
+ line_association[bs][:num_true_junctions[bs], :num_true_junctions[bs]]
+ pl_associativity.append(associativity)
+
+ all_points = torch.stack(all_points, dim=0)
+ all_scores = torch.stack(all_scores, dim=0)
+ all_descs = torch.stack(all_descs, dim=0)
+ pl_associativity = torch.stack(pl_associativity, dim=0)
+ else:
+ # Lines are independent
+ all_points = torch.cat([lines.reshape(b_size, -1, 2),
+ pred['keypoints']], dim=1)
+ n_pts = all_points.shape[1]
+ num_lines = lines.shape[1]
+ num_true_junctions = [num_lines * 2] * b_size
+ all_scores = torch.cat([
+ torch.repeat_interleave(line_scores, 2, dim=1),
+ pred['keypoint_scores']], dim=1)
+ pred['line_descriptors'] = self.endpoints_pooling(
+ lines, pred['all_descriptors'], (h, w))
+ all_descs = torch.cat([
+ pred['line_descriptors'].reshape(b_size, self.conf.sp_params.descriptor_dim, -1),
+ pred['descriptors']], dim=2)
+ pl_associativity = torch.eye(
+ n_pts, dtype=torch.bool,
+ device=device)[None].repeat(b_size, 1, 1)
+ lines_junc_idx = torch.arange(
+ num_lines * 2, device=device).reshape(1, -1, 2).repeat(b_size, 1, 1)
+
+ del pred['all_descriptors'] # Remove dense descriptors to save memory
+ torch.cuda.empty_cache()
+
+ return {'keypoints': all_points,
+ 'keypoint_scores': all_scores,
+ 'descriptors': all_descs,
+ 'pl_associativity': pl_associativity,
+ 'num_junctions': torch.tensor(num_true_junctions),
+ 'lines': lines,
+ 'orig_lines': orig_lines,
+ 'lines_junc_idx': lines_junc_idx,
+ 'line_scores': line_scores,
+ 'valid_lines': valid_lines}
+
+ @staticmethod
+ def endpoints_pooling(segs, all_descriptors, img_shape):
+ assert segs.ndim == 4 and segs.shape[-2:] == (2, 2)
+ filter_shape = all_descriptors.shape[-2:]
+ scale_x = filter_shape[1] / img_shape[1]
+ scale_y = filter_shape[0] / img_shape[0]
+
+ scaled_segs = torch.round(segs * torch.tensor([scale_x, scale_y]).to(segs)).long()
+ scaled_segs[..., 0] = torch.clip(scaled_segs[..., 0], 0, filter_shape[1] - 1)
+ scaled_segs[..., 1] = torch.clip(scaled_segs[..., 1], 0, filter_shape[0] - 1)
+ line_descriptors = [all_descriptors[None, b, ..., torch.squeeze(b_segs[..., 1]), torch.squeeze(b_segs[..., 0])]
+ for b, b_segs in enumerate(scaled_segs)]
+ line_descriptors = torch.cat(line_descriptors)
+        return line_descriptors  # Shape (b_size, descriptor_dim, n_lines, 2)
+
+ def loss(self, pred, data):
+ raise NotImplementedError
+
+ def metrics(self, pred, data):
+ return {}
diff --git a/imcui/third_party/GlueStick/gluestick/run.py b/imcui/third_party/GlueStick/gluestick/run.py
new file mode 100644
index 0000000000000000000000000000000000000000..85fd8af801dd18936163ac1af6d331f54965bfa5
--- /dev/null
+++ b/imcui/third_party/GlueStick/gluestick/run.py
@@ -0,0 +1,107 @@
+import argparse
+import os
+from os.path import join
+
+import cv2
+import torch
+from matplotlib import pyplot as plt
+
+from gluestick import batch_to_np, numpy_image_to_torch, GLUESTICK_ROOT
+from .drawing import plot_images, plot_lines, plot_color_line_matches, plot_keypoints, plot_matches
+from .models.two_view_pipeline import TwoViewPipeline
+
+
+def main():
+ # Parse input parameters
+ parser = argparse.ArgumentParser(
+ prog='GlueStick Demo',
+ description='Demo app to show the point and line matches obtained by GlueStick')
+    parser.add_argument('-img1', default=join('resources', 'img1.jpg'))
+    parser.add_argument('-img2', default=join('resources', 'img2.jpg'))
+ parser.add_argument('--max_pts', type=int, default=1000)
+ parser.add_argument('--max_lines', type=int, default=300)
+ parser.add_argument('--skip-imshow', default=False, action='store_true')
+ args = parser.parse_args()
+
+ # Evaluation config
+ conf = {
+ 'name': 'two_view_pipeline',
+ 'use_lines': True,
+ 'extractor': {
+ 'name': 'wireframe',
+ 'sp_params': {
+ 'force_num_keypoints': False,
+ 'max_num_keypoints': args.max_pts,
+ },
+ 'wireframe_params': {
+ 'merge_points': True,
+ 'merge_line_endpoints': True,
+ },
+ 'max_n_lines': args.max_lines,
+ },
+ 'matcher': {
+ 'name': 'gluestick',
+ 'weights': str(GLUESTICK_ROOT / 'resources' / 'weights' / 'checkpoint_GlueStick_MD.tar'),
+ 'trainable': False,
+ },
+ 'ground_truth': {
+ 'from_pose_depth': False,
+ }
+ }
+
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+ pipeline_model = TwoViewPipeline(conf).to(device).eval()
+
+ gray0 = cv2.imread(args.img1, 0)
+ gray1 = cv2.imread(args.img2, 0)
+
+ torch_gray0, torch_gray1 = numpy_image_to_torch(gray0), numpy_image_to_torch(gray1)
+ torch_gray0, torch_gray1 = torch_gray0.to(device)[None], torch_gray1.to(device)[None]
+ x = {'image0': torch_gray0, 'image1': torch_gray1}
+ pred = pipeline_model(x)
+
+ pred = batch_to_np(pred)
+ kp0, kp1 = pred["keypoints0"], pred["keypoints1"]
+ m0 = pred["matches0"]
+
+ line_seg0, line_seg1 = pred["lines0"], pred["lines1"]
+ line_matches = pred["line_matches0"]
+
+ valid_matches = m0 != -1
+ match_indices = m0[valid_matches]
+ matched_kps0 = kp0[valid_matches]
+ matched_kps1 = kp1[match_indices]
+
+ valid_matches = line_matches != -1
+ match_indices = line_matches[valid_matches]
+ matched_lines0 = line_seg0[valid_matches]
+ matched_lines1 = line_seg1[match_indices]
+
+ # Plot the matches
+ img0, img1 = cv2.cvtColor(gray0, cv2.COLOR_GRAY2BGR), cv2.cvtColor(gray1, cv2.COLOR_GRAY2BGR)
+ plot_images([img0, img1], ['Image 1 - detected lines', 'Image 2 - detected lines'], dpi=200, pad=2.0)
+ plot_lines([line_seg0, line_seg1], ps=4, lw=2)
+ plt.gcf().canvas.manager.set_window_title('Detected Lines')
+ plt.savefig('detected_lines.png')
+
+ plot_images([img0, img1], ['Image 1 - detected points', 'Image 2 - detected points'], dpi=200, pad=2.0)
+ plot_keypoints([kp0, kp1], colors='c')
+ plt.gcf().canvas.manager.set_window_title('Detected Points')
+ plt.savefig('detected_points.png')
+
+ plot_images([img0, img1], ['Image 1 - line matches', 'Image 2 - line matches'], dpi=200, pad=2.0)
+ plot_color_line_matches([matched_lines0, matched_lines1], lw=2)
+ plt.gcf().canvas.manager.set_window_title('Line Matches')
+ plt.savefig('line_matches.png')
+
+ plot_images([img0, img1], ['Image 1 - point matches', 'Image 2 - point matches'], dpi=200, pad=2.0)
+ plot_matches(matched_kps0, matched_kps1, 'green', lw=1, ps=0)
+ plt.gcf().canvas.manager.set_window_title('Point Matches')
+ plt.savefig('point_matches.png')
+ if not args.skip_imshow:
+ plt.show()
+
+
+if __name__ == '__main__':
+ main()
diff --git a/imcui/third_party/LightGlue/.github/workflows/code-quality.yml b/imcui/third_party/LightGlue/.github/workflows/code-quality.yml
new file mode 100644
index 0000000000000000000000000000000000000000..368b225f17a52121ddb6626ca7d9699c6538fcb3
--- /dev/null
+++ b/imcui/third_party/LightGlue/.github/workflows/code-quality.yml
@@ -0,0 +1,24 @@
+name: Format and Lint Checks
+on:
+ push:
+ branches:
+ - main
+ paths:
+ - '*.py'
+ pull_request:
+ types: [ assigned, opened, synchronize, reopened ]
+jobs:
+ check:
+ name: Format and Lint Checks
+ runs-on: ubuntu-latest
+ steps:
+ - uses: actions/checkout@v3
+ - uses: actions/setup-python@v4
+ with:
+ python-version: '3.10'
+ cache: 'pip'
+ - run: python -m pip install --upgrade pip
+ - run: python -m pip install .[dev]
+ - run: python -m flake8 .
+ - run: python -m isort . --check-only --diff
+ - run: python -m black . --check --diff
diff --git a/imcui/third_party/LightGlue/benchmark.py b/imcui/third_party/LightGlue/benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..b160f3a37bf64d2a42884ea29f165fb3f325b9cf
--- /dev/null
+++ b/imcui/third_party/LightGlue/benchmark.py
@@ -0,0 +1,255 @@
+# Benchmark script for LightGlue on real images
+import argparse
+import time
+from collections import defaultdict
+from pathlib import Path
+
+import matplotlib.pyplot as plt
+import numpy as np
+import torch
+import torch._dynamo
+
+from lightglue import LightGlue, SuperPoint
+from lightglue.utils import load_image
+
+torch.set_grad_enabled(False)
+
+
+def measure(matcher, data, device=torch.device("cuda"), r=100):
+ timings = np.zeros((r, 1))
+ if device.type == "cuda":
+ starter = torch.cuda.Event(enable_timing=True)
+ ender = torch.cuda.Event(enable_timing=True)
+ # warmup
+ for _ in range(10):
+ _ = matcher(data)
+ # measurements
+ with torch.no_grad():
+ for rep in range(r):
+ if device.type == "cuda":
+ starter.record()
+ _ = matcher(data)
+ ender.record()
+ # sync gpu
+ torch.cuda.synchronize()
+ curr_time = starter.elapsed_time(ender)
+ else:
+ start = time.perf_counter()
+ _ = matcher(data)
+ curr_time = (time.perf_counter() - start) * 1e3
+ timings[rep] = curr_time
+ mean_syn = np.sum(timings) / r
+ std_syn = np.std(timings)
+ return {"mean": mean_syn, "std": std_syn}
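+
+# Editor's note (illustration, not upstream code): `measure` works with any
+# callable taking a single dict, e.g. (feats0/feats1 are hypothetical features):
+#   stats = measure(matcher, {"image0": feats0, "image1": feats1},
+#                   device=torch.device("cpu"), r=10)
+#   print(f"{stats['mean']:.1f} +/- {stats['std']:.1f} ms")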
+
+
+def print_as_table(d, title, cnames):
+ print()
+ header = f"{title:30} " + " ".join([f"{x:>7}" for x in cnames])
+ print(header)
+ print("-" * len(header))
+ for k, l in d.items():
+ print(f"{k:30}", " ".join([f"{x:>7.1f}" for x in l]))
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description="Benchmark script for LightGlue")
+ parser.add_argument(
+ "--device",
+ choices=["auto", "cuda", "cpu", "mps"],
+ default="auto",
+ help="device to benchmark on",
+ )
+ parser.add_argument("--compile", action="store_true", help="Compile LightGlue runs")
+ parser.add_argument(
+ "--no_flash", action="store_true", help="disable FlashAttention"
+ )
+ parser.add_argument(
+ "--no_prune_thresholds",
+ action="store_true",
+ help="disable pruning thresholds (i.e. always do pruning)",
+ )
+ parser.add_argument(
+ "--add_superglue",
+ action="store_true",
+ help="add SuperGlue to the benchmark (requires hloc)",
+ )
+ parser.add_argument(
+ "--measure", default="time", choices=["time", "log-time", "throughput"]
+ )
+ parser.add_argument(
+ "--repeat", "--r", type=int, default=100, help="repetitions of measurements"
+ )
+ parser.add_argument(
+ "--num_keypoints",
+ nargs="+",
+ type=int,
+ default=[256, 512, 1024, 2048, 4096],
+ help="number of keypoints (list separated by spaces)",
+ )
+ parser.add_argument(
+ "--matmul_precision", default="highest", choices=["highest", "high", "medium"]
+ )
+ parser.add_argument(
+ "--save", default=None, type=str, help="path where figure should be saved"
+ )
+ args = parser.parse_intermixed_args()
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ if args.device != "auto":
+ device = torch.device(args.device)
+
+ print("Running benchmark on device:", device)
+
+ images = Path("assets")
+ inputs = {
+ "easy": (
+ load_image(images / "DSC_0411.JPG"),
+ load_image(images / "DSC_0410.JPG"),
+ ),
+ "difficult": (
+ load_image(images / "sacre_coeur1.jpg"),
+ load_image(images / "sacre_coeur2.jpg"),
+ ),
+ }
+
+ configs = {
+ "LightGlue-full": {
+ "depth_confidence": -1,
+ "width_confidence": -1,
+ },
+ # 'LG-prune': {
+ # 'width_confidence': -1,
+ # },
+ # 'LG-depth': {
+ # 'depth_confidence': -1,
+ # },
+ "LightGlue-adaptive": {},
+ }
+
+ if args.compile:
+ configs = {**configs, **{k + "-compile": v for k, v in configs.items()}}
+
+ sg_configs = {
+ # 'SuperGlue': {},
+ "SuperGlue-fast": {"sinkhorn_iterations": 5}
+ }
+
+ torch.set_float32_matmul_precision(args.matmul_precision)
+
+ results = {k: defaultdict(list) for k, v in inputs.items()}
+
+ extractor = SuperPoint(max_num_keypoints=None, detection_threshold=-1)
+ extractor = extractor.eval().to(device)
+ figsize = (len(inputs) * 4.5, 4.5)
+ fig, axes = plt.subplots(1, len(inputs), sharey=True, figsize=figsize)
+ axes = axes if len(inputs) > 1 else [axes]
+ fig.canvas.manager.set_window_title(f"LightGlue benchmark ({device.type})")
+
+ for title, ax in zip(inputs.keys(), axes):
+ ax.set_xscale("log", base=2)
+ bases = [2**x for x in range(7, 16)]
+ ax.set_xticks(bases, bases)
+ ax.grid(which="major")
+ if args.measure == "log-time":
+ ax.set_yscale("log")
+ yticks = [10**x for x in range(6)]
+ ax.set_yticks(yticks, yticks)
+ mpos = [10**x * i for x in range(6) for i in range(2, 10)]
+ mlabel = [
+ 10**x * i if i in [2, 5] else None
+ for x in range(6)
+ for i in range(2, 10)
+ ]
+ ax.set_yticks(mpos, mlabel, minor=True)
+ ax.grid(which="minor", linewidth=0.2)
+ ax.set_title(title)
+
+ ax.set_xlabel("# keypoints")
+ if args.measure == "throughput":
+ ax.set_ylabel("Throughput [pairs/s]")
+ else:
+ ax.set_ylabel("Latency [ms]")
+
+ for name, conf in configs.items():
+ print("Run benchmark for:", name)
+ torch.cuda.empty_cache()
+ matcher = LightGlue(features="superpoint", flash=not args.no_flash, **conf)
+ if args.no_prune_thresholds:
+ matcher.pruning_keypoint_thresholds = {
+ k: -1 for k in matcher.pruning_keypoint_thresholds
+ }
+ matcher = matcher.eval().to(device)
+ if name.endswith("compile"):
+ import torch._dynamo
+
+ torch._dynamo.reset() # avoid buffer overflow
+ matcher.compile()
+ for pair_name, ax in zip(inputs.keys(), axes):
+ image0, image1 = [x.to(device) for x in inputs[pair_name]]
+ runtimes = []
+ for num_kpts in args.num_keypoints:
+ extractor.conf.max_num_keypoints = num_kpts
+ feats0 = extractor.extract(image0)
+ feats1 = extractor.extract(image1)
+ runtime = measure(
+ matcher,
+ {"image0": feats0, "image1": feats1},
+ device=device,
+ r=args.repeat,
+ )["mean"]
+ results[pair_name][name].append(
+ 1000 / runtime if args.measure == "throughput" else runtime
+ )
+ ax.plot(
+ args.num_keypoints, results[pair_name][name], label=name, marker="o"
+ )
+ del matcher, feats0, feats1
+
+ if args.add_superglue:
+ from hloc.matchers.superglue import SuperGlue
+
+ for name, conf in sg_configs.items():
+ print("Run benchmark for:", name)
+ matcher = SuperGlue(conf)
+ matcher = matcher.eval().to(device)
+ for pair_name, ax in zip(inputs.keys(), axes):
+ image0, image1 = [x.to(device) for x in inputs[pair_name]]
+ runtimes = []
+ for num_kpts in args.num_keypoints:
+ extractor.conf.max_num_keypoints = num_kpts
+ feats0 = extractor.extract(image0)
+ feats1 = extractor.extract(image1)
+ data = {
+ "image0": image0[None],
+ "image1": image1[None],
+ **{k + "0": v for k, v in feats0.items()},
+ **{k + "1": v for k, v in feats1.items()},
+ }
+ data["scores0"] = data["keypoint_scores0"]
+ data["scores1"] = data["keypoint_scores1"]
+ data["descriptors0"] = (
+ data["descriptors0"].transpose(-1, -2).contiguous()
+ )
+ data["descriptors1"] = (
+ data["descriptors1"].transpose(-1, -2).contiguous()
+ )
+ runtime = measure(matcher, data, device=device, r=args.repeat)[
+ "mean"
+ ]
+ results[pair_name][name].append(
+ 1000 / runtime if args.measure == "throughput" else runtime
+ )
+ ax.plot(
+ args.num_keypoints, results[pair_name][name], label=name, marker="o"
+ )
+ del matcher, data, image0, image1, feats0, feats1
+
+ for name, runtimes in results.items():
+ print_as_table(runtimes, name, args.num_keypoints)
+
+ axes[0].legend()
+ fig.tight_layout()
+ if args.save:
+ plt.savefig(args.save, dpi=fig.dpi)
+ plt.show()
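+
+# Example invocation (comments only; assumes this script is the upstream benchmark.py
+# and the output path is a placeholder):
+#   python benchmark.py --device cuda --compile --num_keypoints 512 1024 2048 --save bench.png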
diff --git a/imcui/third_party/LightGlue/lightglue/__init__.py b/imcui/third_party/LightGlue/lightglue/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b84d285cf2a29e3b17c8c2c052a45f856dcf89c0
--- /dev/null
+++ b/imcui/third_party/LightGlue/lightglue/__init__.py
@@ -0,0 +1,7 @@
+from .aliked import ALIKED # noqa
+from .disk import DISK # noqa
+from .dog_hardnet import DoGHardNet # noqa
+from .lightglue import LightGlue # noqa
+from .sift import SIFT # noqa
+from .superpoint import SuperPoint # noqa
+from .utils import match_pair # noqa
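+
+# Usage sketch (comments only). The image paths are placeholders and load_image is
+# assumed to live in .utils as in the upstream LightGlue repository:
+#
+#   from lightglue import LightGlue, SuperPoint, match_pair
+#   from lightglue.utils import load_image
+#
+#   extractor = SuperPoint(max_num_keypoints=2048).eval().cuda()
+#   matcher = LightGlue(features="superpoint").eval().cuda()
+#   image0, image1 = load_image("im0.jpg").cuda(), load_image("im1.jpg").cuda()
+#   feats0, feats1, matches01 = match_pair(extractor, matcher, image0, image1)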
diff --git a/imcui/third_party/LightGlue/lightglue/aliked.py b/imcui/third_party/LightGlue/lightglue/aliked.py
new file mode 100644
index 0000000000000000000000000000000000000000..1161e1fc2d0cce32583031229e8ad4bb84f9a40c
--- /dev/null
+++ b/imcui/third_party/LightGlue/lightglue/aliked.py
@@ -0,0 +1,758 @@
+# BSD 3-Clause License
+
+# Copyright (c) 2022, Zhao Xiaoming
+# All rights reserved.
+
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+
+# 1. Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+
+# 3. Neither the name of the copyright holder nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+# Authors:
+# Xiaoming Zhao, Xingming Wu, Weihai Chen, Peter C.Y. Chen, Qingsong Xu, and Zhengguo Li
+# Code from https://github.com/Shiaoming/ALIKED
+
+from typing import Callable, Optional
+
+import torch
+import torch.nn.functional as F
+import torchvision
+from kornia.color import grayscale_to_rgb
+from packaging import version
+from torch import nn
+from torch.nn.modules.utils import _pair
+from torchvision.models import resnet
+
+from .utils import Extractor
+
+
+def get_patches(
+ tensor: torch.Tensor, required_corners: torch.Tensor, ps: int
+) -> torch.Tensor:
+ c, h, w = tensor.shape
+ corner = (required_corners - ps / 2 + 1).long()
+ corner[:, 0] = corner[:, 0].clamp(min=0, max=w - 1 - ps)
+ corner[:, 1] = corner[:, 1].clamp(min=0, max=h - 1 - ps)
+ offset = torch.arange(0, ps)
+
+ kw = {"indexing": "ij"} if torch.__version__ >= "1.10" else {}
+ x, y = torch.meshgrid(offset, offset, **kw)
+ patches = torch.stack((x, y)).permute(2, 1, 0).unsqueeze(2)
+ patches = patches.to(corner) + corner[None, None]
+ pts = patches.reshape(-1, 2)
+ sampled = tensor.permute(1, 2, 0)[tuple(pts.T)[::-1]]
+ sampled = sampled.reshape(ps, ps, -1, c)
+ assert sampled.shape[:3] == patches.shape[:3]
+ return sampled.permute(2, 3, 0, 1)
+
+
+def simple_nms(scores: torch.Tensor, nms_radius: int):
+ """Fast Non-maximum suppression to remove nearby points"""
+
+ zeros = torch.zeros_like(scores)
+ max_mask = scores == torch.nn.functional.max_pool2d(
+ scores, kernel_size=nms_radius * 2 + 1, stride=1, padding=nms_radius
+ )
+
+ for _ in range(2):
+ supp_mask = (
+ torch.nn.functional.max_pool2d(
+ max_mask.float(),
+ kernel_size=nms_radius * 2 + 1,
+ stride=1,
+ padding=nms_radius,
+ )
+ > 0
+ )
+ supp_scores = torch.where(supp_mask, zeros, scores)
+ new_max_mask = supp_scores == torch.nn.functional.max_pool2d(
+ supp_scores, kernel_size=nms_radius * 2 + 1, stride=1, padding=nms_radius
+ )
+ max_mask = max_mask | (new_max_mask & (~supp_mask))
+ return torch.where(max_mask, scores, zeros)
+
+
+class DKD(nn.Module):
+ def __init__(
+ self,
+ radius: int = 2,
+ top_k: int = 0,
+ scores_th: float = 0.2,
+ n_limit: int = 20000,
+ ):
+ """
+        Args:
+            radius: soft detection radius; the kernel size is (2 * radius + 1)
+            top_k: if top_k > 0, return the top_k highest-scoring keypoints
+            scores_th: threshold used when top_k <= 0:
+                if scores_th > 0, return keypoints with scores > scores_th,
+                otherwise return keypoints with scores > scores.mean()
+            n_limit: maximum number of keypoints in threshold mode
+ """
+ super().__init__()
+ self.radius = radius
+ self.top_k = top_k
+ self.scores_th = scores_th
+ self.n_limit = n_limit
+ self.kernel_size = 2 * self.radius + 1
+ self.temperature = 0.1 # tuned temperature
+ self.unfold = nn.Unfold(kernel_size=self.kernel_size, padding=self.radius)
+ # local xy grid
+ x = torch.linspace(-self.radius, self.radius, self.kernel_size)
+ # (kernel_size*kernel_size) x 2 : (w,h)
+ kw = {"indexing": "ij"} if torch.__version__ >= "1.10" else {}
+ self.hw_grid = (
+ torch.stack(torch.meshgrid([x, x], **kw)).view(2, -1).t()[:, [1, 0]]
+ )
+
+ def forward(
+ self,
+ scores_map: torch.Tensor,
+ sub_pixel: bool = True,
+ image_size: Optional[torch.Tensor] = None,
+ ):
+ """
+ :param scores_map: Bx1xHxW
+ :param descriptor_map: BxCxHxW
+ :param sub_pixel: whether to use sub-pixel keypoint detection
+ :return: kpts: list[Nx2,...]; kptscores: list[N,....] normalised position: -1~1
+ """
+ b, c, h, w = scores_map.shape
+ scores_nograd = scores_map.detach()
+ nms_scores = simple_nms(scores_nograd, self.radius)
+
+ # remove border
+ nms_scores[:, :, : self.radius, :] = 0
+ nms_scores[:, :, :, : self.radius] = 0
+ if image_size is not None:
+ for i in range(scores_map.shape[0]):
+ w, h = image_size[i].long()
+ nms_scores[i, :, h.item() - self.radius :, :] = 0
+ nms_scores[i, :, :, w.item() - self.radius :] = 0
+ else:
+ nms_scores[:, :, -self.radius :, :] = 0
+ nms_scores[:, :, :, -self.radius :] = 0
+
+ # detect keypoints without grad
+ if self.top_k > 0:
+ topk = torch.topk(nms_scores.view(b, -1), self.top_k)
+ indices_keypoints = [topk.indices[i] for i in range(b)] # B x top_k
+ else:
+ if self.scores_th > 0:
+ masks = nms_scores > self.scores_th
+ if masks.sum() == 0:
+ th = scores_nograd.reshape(b, -1).mean(dim=1) # th = self.scores_th
+ masks = nms_scores > th.reshape(b, 1, 1, 1)
+ else:
+ th = scores_nograd.reshape(b, -1).mean(dim=1) # th = self.scores_th
+ masks = nms_scores > th.reshape(b, 1, 1, 1)
+ masks = masks.reshape(b, -1)
+
+ indices_keypoints = [] # list, B x (any size)
+ scores_view = scores_nograd.reshape(b, -1)
+ for mask, scores in zip(masks, scores_view):
+ indices = mask.nonzero()[:, 0]
+ if len(indices) > self.n_limit:
+ kpts_sc = scores[indices]
+ sort_idx = kpts_sc.sort(descending=True)[1]
+ sel_idx = sort_idx[: self.n_limit]
+ indices = indices[sel_idx]
+ indices_keypoints.append(indices)
+
+ wh = torch.tensor([w - 1, h - 1], device=scores_nograd.device)
+
+ keypoints = []
+ scoredispersitys = []
+ kptscores = []
+ if sub_pixel:
+ # detect soft keypoints with grad backpropagation
+ patches = self.unfold(scores_map) # B x (kernel**2) x (H*W)
+ self.hw_grid = self.hw_grid.to(scores_map) # to device
+ for b_idx in range(b):
+ patch = patches[b_idx].t() # (H*W) x (kernel**2)
+ indices_kpt = indices_keypoints[
+ b_idx
+ ] # one dimension vector, say its size is M
+ patch_scores = patch[indices_kpt] # M x (kernel**2)
+ keypoints_xy_nms = torch.stack(
+ [indices_kpt % w, torch.div(indices_kpt, w, rounding_mode="trunc")],
+ dim=1,
+ ) # Mx2
+
+ # max is detached to prevent undesired backprop loops in the graph
+ max_v = patch_scores.max(dim=1).values.detach()[:, None]
+ x_exp = (
+ (patch_scores - max_v) / self.temperature
+ ).exp() # M * (kernel**2), in [0, 1]
+
+ # \frac{ \sum{(i,j) \times \exp(x/T)} }{ \sum{\exp(x/T)} }
+ xy_residual = (
+ x_exp @ self.hw_grid / x_exp.sum(dim=1)[:, None]
+ ) # Soft-argmax, Mx2
+
+ hw_grid_dist2 = (
+ torch.norm(
+ (self.hw_grid[None, :, :] - xy_residual[:, None, :])
+ / self.radius,
+ dim=-1,
+ )
+ ** 2
+ )
+ scoredispersity = (x_exp * hw_grid_dist2).sum(dim=1) / x_exp.sum(dim=1)
+
+ # compute result keypoints
+ keypoints_xy = keypoints_xy_nms + xy_residual
+ keypoints_xy = keypoints_xy / wh * 2 - 1 # (w,h) -> (-1~1,-1~1)
+
+ kptscore = torch.nn.functional.grid_sample(
+ scores_map[b_idx].unsqueeze(0),
+ keypoints_xy.view(1, 1, -1, 2),
+ mode="bilinear",
+ align_corners=True,
+ )[
+ 0, 0, 0, :
+ ] # CxN
+
+ keypoints.append(keypoints_xy)
+ scoredispersitys.append(scoredispersity)
+ kptscores.append(kptscore)
+ else:
+ for b_idx in range(b):
+ indices_kpt = indices_keypoints[
+ b_idx
+ ] # one dimension vector, say its size is M
+ # To avoid warning: UserWarning: __floordiv__ is deprecated
+ keypoints_xy_nms = torch.stack(
+ [indices_kpt % w, torch.div(indices_kpt, w, rounding_mode="trunc")],
+ dim=1,
+ ) # Mx2
+ keypoints_xy = keypoints_xy_nms / wh * 2 - 1 # (w,h) -> (-1~1,-1~1)
+ kptscore = torch.nn.functional.grid_sample(
+ scores_map[b_idx].unsqueeze(0),
+ keypoints_xy.view(1, 1, -1, 2),
+ mode="bilinear",
+ align_corners=True,
+ )[
+ 0, 0, 0, :
+ ] # CxN
+ keypoints.append(keypoints_xy)
+                scoredispersitys.append(kptscore)  # placeholder, for jit.script compatibility
+ kptscores.append(kptscore)
+
+ return keypoints, scoredispersitys, kptscores
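+    # Sketch of the two detection modes configured in __init__ (comments only):
+    #   DKD(radius=2, top_k=1024)                            -> fixed 1024 keypoints
+    #   DKD(radius=2, top_k=0, scores_th=0.2, n_limit=4096)  -> all keypoints with
+    #       score > 0.2, capped at 4096; if none pass the threshold, the per-image
+    #       mean score is used instead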
+
+
+class InputPadder(object):
+ """Pads images such that dimensions are divisible by 8"""
+
+ def __init__(self, h: int, w: int, divis_by: int = 8):
+ self.ht = h
+ self.wd = w
+ pad_ht = (((self.ht // divis_by) + 1) * divis_by - self.ht) % divis_by
+ pad_wd = (((self.wd // divis_by) + 1) * divis_by - self.wd) % divis_by
+ self._pad = [
+ pad_wd // 2,
+ pad_wd - pad_wd // 2,
+ pad_ht // 2,
+ pad_ht - pad_ht // 2,
+ ]
+
+ def pad(self, x: torch.Tensor):
+ assert x.ndim == 4
+ return F.pad(x, self._pad, mode="replicate")
+
+ def unpad(self, x: torch.Tensor):
+ assert x.ndim == 4
+ ht = x.shape[-2]
+ wd = x.shape[-1]
+ c = [self._pad[2], ht - self._pad[3], self._pad[0], wd - self._pad[1]]
+ return x[..., c[0] : c[1], c[2] : c[3]]
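+    # Example of the padding arithmetic above (comments only): with divis_by=32, an
+    # input of height 480 and width 500 gives pad_wd = ((500 // 32 + 1) * 32 - 500) % 32
+    # = 12, i.e. 6 columns on each side (width 512); 480 is already divisible by 32,
+    # so pad_ht = 0.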
+
+
+class DeformableConv2d(nn.Module):
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias=False,
+ mask=False,
+ ):
+ super(DeformableConv2d, self).__init__()
+
+ self.padding = padding
+ self.mask = mask
+
+ self.channel_num = (
+ 3 * kernel_size * kernel_size if mask else 2 * kernel_size * kernel_size
+ )
+ self.offset_conv = nn.Conv2d(
+ in_channels,
+ self.channel_num,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=self.padding,
+ bias=True,
+ )
+
+ self.regular_conv = nn.Conv2d(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=self.padding,
+ bias=bias,
+ )
+
+ def forward(self, x):
+ h, w = x.shape[2:]
+ max_offset = max(h, w) / 4.0
+
+ out = self.offset_conv(x)
+ if self.mask:
+ o1, o2, mask = torch.chunk(out, 3, dim=1)
+ offset = torch.cat((o1, o2), dim=1)
+ mask = torch.sigmoid(mask)
+ else:
+ offset = out
+ mask = None
+ offset = offset.clamp(-max_offset, max_offset)
+ x = torchvision.ops.deform_conv2d(
+ input=x,
+ offset=offset,
+ weight=self.regular_conv.weight,
+ bias=self.regular_conv.bias,
+ padding=self.padding,
+ mask=mask,
+ )
+ return x
+
+
+def get_conv(
+ inplanes,
+ planes,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias=False,
+ conv_type="conv",
+ mask=False,
+):
+ if conv_type == "conv":
+ conv = nn.Conv2d(
+ inplanes,
+ planes,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=padding,
+ bias=bias,
+ )
+ elif conv_type == "dcn":
+ conv = DeformableConv2d(
+ inplanes,
+ planes,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=_pair(padding),
+ bias=bias,
+ mask=mask,
+ )
+ else:
+ raise TypeError
+ return conv
+
+
+class ConvBlock(nn.Module):
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ gate: Optional[Callable[..., nn.Module]] = None,
+ norm_layer: Optional[Callable[..., nn.Module]] = None,
+ conv_type: str = "conv",
+ mask: bool = False,
+ ):
+ super().__init__()
+ if gate is None:
+ self.gate = nn.ReLU(inplace=True)
+ else:
+ self.gate = gate
+ if norm_layer is None:
+ norm_layer = nn.BatchNorm2d
+ self.conv1 = get_conv(
+ in_channels, out_channels, kernel_size=3, conv_type=conv_type, mask=mask
+ )
+ self.bn1 = norm_layer(out_channels)
+ self.conv2 = get_conv(
+ out_channels, out_channels, kernel_size=3, conv_type=conv_type, mask=mask
+ )
+ self.bn2 = norm_layer(out_channels)
+
+ def forward(self, x):
+ x = self.gate(self.bn1(self.conv1(x))) # B x in_channels x H x W
+ x = self.gate(self.bn2(self.conv2(x))) # B x out_channels x H x W
+ return x
+
+
+# modified from torchvision/models/resnet.py (BasicBlock)
+class ResBlock(nn.Module):
+ expansion: int = 1
+
+ def __init__(
+ self,
+ inplanes: int,
+ planes: int,
+ stride: int = 1,
+ downsample: Optional[nn.Module] = None,
+ groups: int = 1,
+ base_width: int = 64,
+ dilation: int = 1,
+ gate: Optional[Callable[..., nn.Module]] = None,
+ norm_layer: Optional[Callable[..., nn.Module]] = None,
+ conv_type: str = "conv",
+ mask: bool = False,
+ ) -> None:
+ super(ResBlock, self).__init__()
+ if gate is None:
+ self.gate = nn.ReLU(inplace=True)
+ else:
+ self.gate = gate
+ if norm_layer is None:
+ norm_layer = nn.BatchNorm2d
+ if groups != 1 or base_width != 64:
+ raise ValueError("ResBlock only supports groups=1 and base_width=64")
+ if dilation > 1:
+ raise NotImplementedError("Dilation > 1 not supported in ResBlock")
+ # Both self.conv1 and self.downsample layers
+ # downsample the input when stride != 1
+ self.conv1 = get_conv(
+ inplanes, planes, kernel_size=3, conv_type=conv_type, mask=mask
+ )
+ self.bn1 = norm_layer(planes)
+ self.conv2 = get_conv(
+ planes, planes, kernel_size=3, conv_type=conv_type, mask=mask
+ )
+ self.bn2 = norm_layer(planes)
+ self.downsample = downsample
+ self.stride = stride
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ identity = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.gate(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+
+ if self.downsample is not None:
+ identity = self.downsample(x)
+
+ out += identity
+ out = self.gate(out)
+
+ return out
+
+
+class SDDH(nn.Module):
+ def __init__(
+ self,
+ dims: int,
+ kernel_size: int = 3,
+ n_pos: int = 8,
+ gate=nn.ReLU(),
+ conv2D=False,
+ mask=False,
+ ):
+ super(SDDH, self).__init__()
+ self.kernel_size = kernel_size
+ self.n_pos = n_pos
+ self.conv2D = conv2D
+ self.mask = mask
+
+ self.get_patches_func = get_patches
+
+ # estimate offsets
+ self.channel_num = 3 * n_pos if mask else 2 * n_pos
+ self.offset_conv = nn.Sequential(
+ nn.Conv2d(
+ dims,
+ self.channel_num,
+ kernel_size=kernel_size,
+ stride=1,
+ padding=0,
+ bias=True,
+ ),
+ gate,
+ nn.Conv2d(
+ self.channel_num,
+ self.channel_num,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ bias=True,
+ ),
+ )
+
+ # sampled feature conv
+ self.sf_conv = nn.Conv2d(
+ dims, dims, kernel_size=1, stride=1, padding=0, bias=False
+ )
+
+ # convM
+ if not conv2D:
+ # deformable desc weights
+ agg_weights = torch.nn.Parameter(torch.rand(n_pos, dims, dims))
+ self.register_parameter("agg_weights", agg_weights)
+ else:
+ self.convM = nn.Conv2d(
+ dims * n_pos, dims, kernel_size=1, stride=1, padding=0, bias=False
+ )
+
+ def forward(self, x, keypoints):
+ # x: [B,C,H,W]
+ # keypoints: list, [[N_kpts,2], ...] (w,h)
+ b, c, h, w = x.shape
+ wh = torch.tensor([[w - 1, h - 1]], device=x.device)
+ max_offset = max(h, w) / 4.0
+
+ offsets = []
+ descriptors = []
+ # get offsets for each keypoint
+ for ib in range(b):
+ xi, kptsi = x[ib], keypoints[ib]
+ kptsi_wh = (kptsi / 2 + 0.5) * wh
+ N_kpts = len(kptsi)
+
+ if self.kernel_size > 1:
+ patch = self.get_patches_func(
+ xi, kptsi_wh.long(), self.kernel_size
+ ) # [N_kpts, C, K, K]
+ else:
+ kptsi_wh_long = kptsi_wh.long()
+ patch = (
+ xi[:, kptsi_wh_long[:, 1], kptsi_wh_long[:, 0]]
+ .permute(1, 0)
+ .reshape(N_kpts, c, 1, 1)
+ )
+
+ offset = self.offset_conv(patch).clamp(
+ -max_offset, max_offset
+ ) # [N_kpts, 2*n_pos, 1, 1]
+            if self.mask:
+                offset = (
+                    offset[:, :, 0, 0].view(N_kpts, 3, self.n_pos).permute(0, 2, 1)
+                )  # [N_kpts, n_pos, 3]
+                # read the mask weight from the last channel before dropping it
+                mask_weight = torch.sigmoid(offset[:, :, -1])  # [N_kpts, n_pos]
+                offset = offset[:, :, :-1]  # [N_kpts, n_pos, 2]
+ else:
+ offset = (
+ offset[:, :, 0, 0].view(N_kpts, 2, self.n_pos).permute(0, 2, 1)
+ ) # [N_kpts, n_pos, 2]
+ offsets.append(offset) # for visualization
+
+ # get sample positions
+ pos = kptsi_wh.unsqueeze(1) + offset # [N_kpts, n_pos, 2]
+ pos = 2.0 * pos / wh[None] - 1
+ pos = pos.reshape(1, N_kpts * self.n_pos, 1, 2)
+
+ # sample features
+ features = F.grid_sample(
+ xi.unsqueeze(0), pos, mode="bilinear", align_corners=True
+ ) # [1,C,(N_kpts*n_pos),1]
+ features = features.reshape(c, N_kpts, self.n_pos, 1).permute(
+ 1, 0, 2, 3
+ ) # [N_kpts, C, n_pos, 1]
+ if self.mask:
+ features = torch.einsum("ncpo,np->ncpo", features, mask_weight)
+
+ features = torch.selu_(self.sf_conv(features)).squeeze(
+ -1
+ ) # [N_kpts, C, n_pos]
+ # convM
+ if not self.conv2D:
+ descs = torch.einsum(
+ "ncp,pcd->nd", features, self.agg_weights
+ ) # [N_kpts, C]
+ else:
+ features = features.reshape(N_kpts, -1)[
+ :, :, None, None
+ ] # [N_kpts, C*n_pos, 1, 1]
+ descs = self.convM(features).squeeze() # [N_kpts, C]
+
+ # normalize
+ descs = F.normalize(descs, p=2.0, dim=1)
+ descriptors.append(descs)
+
+ return descriptors, offsets
+
+
+class ALIKED(Extractor):
+ default_conf = {
+ "model_name": "aliked-n16",
+ "max_num_keypoints": -1,
+ "detection_threshold": 0.2,
+ "nms_radius": 2,
+ }
+
+ checkpoint_url = "https://github.com/Shiaoming/ALIKED/raw/main/models/{}.pth"
+
+ n_limit_max = 20000
+
+ # c1, c2, c3, c4, dim, K, M
+ cfgs = {
+ "aliked-t16": [8, 16, 32, 64, 64, 3, 16],
+ "aliked-n16": [16, 32, 64, 128, 128, 3, 16],
+ "aliked-n16rot": [16, 32, 64, 128, 128, 3, 16],
+ "aliked-n32": [16, 32, 64, 128, 128, 3, 32],
+ }
+ preprocess_conf = {
+ "resize": 1024,
+ }
+
+ required_data_keys = ["image"]
+
+ def __init__(self, **conf):
+ super().__init__(**conf) # Update with default configuration.
+ conf = self.conf
+ c1, c2, c3, c4, dim, K, M = self.cfgs[conf.model_name]
+ conv_types = ["conv", "conv", "dcn", "dcn"]
+ conv2D = False
+ mask = False
+
+ # build model
+ self.pool2 = nn.AvgPool2d(kernel_size=2, stride=2)
+ self.pool4 = nn.AvgPool2d(kernel_size=4, stride=4)
+ self.norm = nn.BatchNorm2d
+ self.gate = nn.SELU(inplace=True)
+ self.block1 = ConvBlock(3, c1, self.gate, self.norm, conv_type=conv_types[0])
+ self.block2 = self.get_resblock(c1, c2, conv_types[1], mask)
+ self.block3 = self.get_resblock(c2, c3, conv_types[2], mask)
+ self.block4 = self.get_resblock(c3, c4, conv_types[3], mask)
+
+ self.conv1 = resnet.conv1x1(c1, dim // 4)
+ self.conv2 = resnet.conv1x1(c2, dim // 4)
+ self.conv3 = resnet.conv1x1(c3, dim // 4)
+ self.conv4 = resnet.conv1x1(dim, dim // 4)
+ self.upsample2 = nn.Upsample(
+ scale_factor=2, mode="bilinear", align_corners=True
+ )
+ self.upsample4 = nn.Upsample(
+ scale_factor=4, mode="bilinear", align_corners=True
+ )
+ self.upsample8 = nn.Upsample(
+ scale_factor=8, mode="bilinear", align_corners=True
+ )
+ self.upsample32 = nn.Upsample(
+ scale_factor=32, mode="bilinear", align_corners=True
+ )
+ self.score_head = nn.Sequential(
+ resnet.conv1x1(dim, 8),
+ self.gate,
+ resnet.conv3x3(8, 4),
+ self.gate,
+ resnet.conv3x3(4, 4),
+ self.gate,
+ resnet.conv3x3(4, 1),
+ )
+ self.desc_head = SDDH(dim, K, M, gate=self.gate, conv2D=conv2D, mask=mask)
+ self.dkd = DKD(
+ radius=conf.nms_radius,
+ top_k=-1 if conf.detection_threshold > 0 else conf.max_num_keypoints,
+ scores_th=conf.detection_threshold,
+ n_limit=conf.max_num_keypoints
+ if conf.max_num_keypoints > 0
+ else self.n_limit_max,
+ )
+
+ state_dict = torch.hub.load_state_dict_from_url(
+ self.checkpoint_url.format(conf.model_name), map_location="cpu"
+ )
+ self.load_state_dict(state_dict, strict=True)
+
+ def get_resblock(self, c_in, c_out, conv_type, mask):
+ return ResBlock(
+ c_in,
+ c_out,
+ 1,
+ nn.Conv2d(c_in, c_out, 1),
+ gate=self.gate,
+ norm_layer=self.norm,
+ conv_type=conv_type,
+ mask=mask,
+ )
+
+ def extract_dense_map(self, image):
+        # Pad the image so that its dimensions are divisible by 2**5 = 32
+ div_by = 2**5
+ padder = InputPadder(image.shape[-2], image.shape[-1], div_by)
+ image = padder.pad(image)
+
+ # ================================== feature encoder
+ x1 = self.block1(image) # B x c1 x H x W
+ x2 = self.pool2(x1)
+ x2 = self.block2(x2) # B x c2 x H/2 x W/2
+ x3 = self.pool4(x2)
+ x3 = self.block3(x3) # B x c3 x H/8 x W/8
+ x4 = self.pool4(x3)
+ x4 = self.block4(x4) # B x dim x H/32 x W/32
+ # ================================== feature aggregation
+ x1 = self.gate(self.conv1(x1)) # B x dim//4 x H x W
+ x2 = self.gate(self.conv2(x2)) # B x dim//4 x H//2 x W//2
+ x3 = self.gate(self.conv3(x3)) # B x dim//4 x H//8 x W//8
+ x4 = self.gate(self.conv4(x4)) # B x dim//4 x H//32 x W//32
+ x2_up = self.upsample2(x2) # B x dim//4 x H x W
+ x3_up = self.upsample8(x3) # B x dim//4 x H x W
+ x4_up = self.upsample32(x4) # B x dim//4 x H x W
+ x1234 = torch.cat([x1, x2_up, x3_up, x4_up], dim=1)
+ # ================================== score head
+ score_map = torch.sigmoid(self.score_head(x1234))
+ feature_map = torch.nn.functional.normalize(x1234, p=2, dim=1)
+
+ # Unpads images
+ feature_map = padder.unpad(feature_map)
+ score_map = padder.unpad(score_map)
+
+ return feature_map, score_map
+
+ def forward(self, data: dict) -> dict:
+ image = data["image"]
+ if image.shape[1] == 1:
+ image = grayscale_to_rgb(image)
+ feature_map, score_map = self.extract_dense_map(image)
+ keypoints, kptscores, scoredispersitys = self.dkd(
+ score_map, image_size=data.get("image_size")
+ )
+ descriptors, offsets = self.desc_head(feature_map, keypoints)
+
+ _, _, h, w = image.shape
+ wh = torch.tensor([w - 1, h - 1], device=image.device)
+        # stacking without padding requires the same number of keypoints per image,
+        # which holds when detection_threshold <= 0 and max_num_keypoints > 0 (top-k mode)
+ return {
+ "keypoints": wh * (torch.stack(keypoints) + 1) / 2.0, # B x N x 2
+ "descriptors": torch.stack(descriptors), # B x N x D
+ "keypoint_scores": torch.stack(kptscores), # B x N
+ }
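+
+# Usage sketch (comments only; assumes a CUDA device and an image tensor as produced
+# by lightglue.utils.load_image). For the default aliked-n16, descriptors are 128-D:
+#
+#   extractor = ALIKED(max_num_keypoints=2048, detection_threshold=-1).eval().cuda()
+#   feats = extractor.extract(image)  # image: 3 x H x W tensor in [0, 1]
+#   # feats["keypoints"]: 1 x 2048 x 2, feats["descriptors"]: 1 x 2048 x 128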
diff --git a/imcui/third_party/LightGlue/lightglue/disk.py b/imcui/third_party/LightGlue/lightglue/disk.py
new file mode 100644
index 0000000000000000000000000000000000000000..8cb2195fe2f95c32959b5be4b09ad91bb51a35d5
--- /dev/null
+++ b/imcui/third_party/LightGlue/lightglue/disk.py
@@ -0,0 +1,55 @@
+import kornia
+import torch
+
+from .utils import Extractor
+
+
+class DISK(Extractor):
+ default_conf = {
+ "weights": "depth",
+ "max_num_keypoints": None,
+ "desc_dim": 128,
+ "nms_window_size": 5,
+ "detection_threshold": 0.0,
+ "pad_if_not_divisible": True,
+ }
+
+ preprocess_conf = {
+ "resize": 1024,
+ "grayscale": False,
+ }
+
+ required_data_keys = ["image"]
+
+ def __init__(self, **conf) -> None:
+ super().__init__(**conf) # Update with default configuration.
+ self.model = kornia.feature.DISK.from_pretrained(self.conf.weights)
+
+ def forward(self, data: dict) -> dict:
+ """Compute keypoints, scores, descriptors for image"""
+ for key in self.required_data_keys:
+ assert key in data, f"Missing key {key} in data"
+ image = data["image"]
+ if image.shape[1] == 1:
+ image = kornia.color.grayscale_to_rgb(image)
+ features = self.model(
+ image,
+ n=self.conf.max_num_keypoints,
+ window_size=self.conf.nms_window_size,
+ score_threshold=self.conf.detection_threshold,
+ pad_if_not_divisible=self.conf.pad_if_not_divisible,
+ )
+ keypoints = [f.keypoints for f in features]
+ scores = [f.detection_scores for f in features]
+ descriptors = [f.descriptors for f in features]
+ del features
+
+ keypoints = torch.stack(keypoints, 0)
+ scores = torch.stack(scores, 0)
+ descriptors = torch.stack(descriptors, 0)
+
+ return {
+ "keypoints": keypoints.to(image).contiguous(),
+ "keypoint_scores": scores.to(image).contiguous(),
+ "descriptors": descriptors.to(image).contiguous(),
+ }
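+
+# Note (comments only): DISK descriptors are 128-dimensional (desc_dim above), which
+# matches the "disk" LightGlue weights. Hypothetical pairing:
+#
+#   extractor = DISK(max_num_keypoints=2048).eval().cuda()
+#   matcher = LightGlue(features="disk").eval().cuda()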
diff --git a/imcui/third_party/LightGlue/lightglue/dog_hardnet.py b/imcui/third_party/LightGlue/lightglue/dog_hardnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..cce307ae1f11e2066312fd44ecac8884d1de3358
--- /dev/null
+++ b/imcui/third_party/LightGlue/lightglue/dog_hardnet.py
@@ -0,0 +1,41 @@
+import torch
+from kornia.color import rgb_to_grayscale
+from kornia.feature import HardNet, LAFDescriptor, laf_from_center_scale_ori
+
+from .sift import SIFT
+
+
+class DoGHardNet(SIFT):
+ required_data_keys = ["image"]
+
+ def __init__(self, **conf):
+ super().__init__(**conf)
+ self.laf_desc = LAFDescriptor(HardNet(True)).eval()
+
+ def forward(self, data: dict) -> dict:
+ image = data["image"]
+ if image.shape[1] == 3:
+ image = rgb_to_grayscale(image)
+ device = image.device
+ self.laf_desc = self.laf_desc.to(device)
+ self.laf_desc.descriptor = self.laf_desc.descriptor.eval()
+ pred = []
+ if "image_size" in data.keys():
+ im_size = data.get("image_size").long()
+ else:
+ im_size = None
+ for k in range(len(image)):
+ img = image[k]
+ if im_size is not None:
+ w, h = data["image_size"][k]
+ img = img[:, : h.to(torch.int32), : w.to(torch.int32)]
+ p = self.extract_single_image(img)
+ lafs = laf_from_center_scale_ori(
+ p["keypoints"].reshape(1, -1, 2),
+ 6.0 * p["scales"].reshape(1, -1, 1, 1),
+ torch.rad2deg(p["oris"]).reshape(1, -1, 1),
+ ).to(device)
+ p["descriptors"] = self.laf_desc(img[None], lafs).reshape(-1, 128)
+ pred.append(p)
+ pred = {k: torch.stack([p[k] for p in pred], 0).to(device) for k in pred[0]}
+ return pred
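+
+# Note (comments only): DoGHardNet reuses the DoG keypoints, scales and orientations
+# from SIFT.extract_single_image and only replaces the descriptors with 128-D HardNet
+# descriptors computed on local affine frames (LAFs), so it pairs with
+# LightGlue(features="doghardnet").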
diff --git a/imcui/third_party/LightGlue/lightglue/lightglue.py b/imcui/third_party/LightGlue/lightglue/lightglue.py
new file mode 100644
index 0000000000000000000000000000000000000000..14e6a61e25764e1a4258e94363c30211bcbc7c44
--- /dev/null
+++ b/imcui/third_party/LightGlue/lightglue/lightglue.py
@@ -0,0 +1,655 @@
+import warnings
+from pathlib import Path
+from types import SimpleNamespace
+from typing import Callable, List, Optional, Tuple
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+try:
+ from flash_attn.modules.mha import FlashCrossAttention
+except ModuleNotFoundError:
+ FlashCrossAttention = None
+
+if FlashCrossAttention or hasattr(F, "scaled_dot_product_attention"):
+ FLASH_AVAILABLE = True
+else:
+ FLASH_AVAILABLE = False
+
+torch.backends.cudnn.deterministic = True
+
+
+@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
+def normalize_keypoints(
+ kpts: torch.Tensor, size: Optional[torch.Tensor] = None
+) -> torch.Tensor:
+ if size is None:
+ size = 1 + kpts.max(-2).values - kpts.min(-2).values
+ elif not isinstance(size, torch.Tensor):
+ size = torch.tensor(size, device=kpts.device, dtype=kpts.dtype)
+ size = size.to(kpts)
+ shift = size / 2
+ scale = size.max(-1).values / 2
+ kpts = (kpts - shift[..., None, :]) / scale[..., None, None]
+ return kpts
+
+
+def pad_to_length(x: torch.Tensor, length: int) -> Tuple[torch.Tensor, torch.Tensor]:
+ if length <= x.shape[-2]:
+ return x, torch.ones_like(x[..., :1], dtype=torch.bool)
+ pad = torch.ones(
+ *x.shape[:-2], length - x.shape[-2], x.shape[-1], device=x.device, dtype=x.dtype
+ )
+ y = torch.cat([x, pad], dim=-2)
+ mask = torch.zeros(*y.shape[:-1], 1, dtype=torch.bool, device=x.device)
+ mask[..., : x.shape[-2], :] = True
+ return y, mask
+
+
+def rotate_half(x: torch.Tensor) -> torch.Tensor:
+ x = x.unflatten(-1, (-1, 2))
+ x1, x2 = x.unbind(dim=-1)
+ return torch.stack((-x2, x1), dim=-1).flatten(start_dim=-2)
+
+
+def apply_cached_rotary_emb(freqs: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
+ return (t * freqs[0]) + (rotate_half(t) * freqs[1])
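+# Note on the two helpers above (comments only): with freqs = (cos, sin) cached per
+# keypoint, every 2-D slice (x1, x2) of a query/key vector is rotated as
+#   (x1, x2) -> (x1 * cos - x2 * sin, x2 * cos + x1 * sin),
+# i.e. a rotary positional encoding applied to queries and keys before attention.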
+
+
+class LearnableFourierPositionalEncoding(nn.Module):
+ def __init__(self, M: int, dim: int, F_dim: int = None, gamma: float = 1.0) -> None:
+ super().__init__()
+ F_dim = F_dim if F_dim is not None else dim
+ self.gamma = gamma
+ self.Wr = nn.Linear(M, F_dim // 2, bias=False)
+ nn.init.normal_(self.Wr.weight.data, mean=0, std=self.gamma**-2)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """encode position vector"""
+ projected = self.Wr(x)
+ cosines, sines = torch.cos(projected), torch.sin(projected)
+ emb = torch.stack([cosines, sines], 0).unsqueeze(-3)
+ return emb.repeat_interleave(2, dim=-1)
+
+
+class TokenConfidence(nn.Module):
+ def __init__(self, dim: int) -> None:
+ super().__init__()
+ self.token = nn.Sequential(nn.Linear(dim, 1), nn.Sigmoid())
+
+ def forward(self, desc0: torch.Tensor, desc1: torch.Tensor):
+ """get confidence tokens"""
+ return (
+ self.token(desc0.detach()).squeeze(-1),
+ self.token(desc1.detach()).squeeze(-1),
+ )
+
+
+class Attention(nn.Module):
+ def __init__(self, allow_flash: bool) -> None:
+ super().__init__()
+ if allow_flash and not FLASH_AVAILABLE:
+ warnings.warn(
+ "FlashAttention is not available. For optimal speed, "
+ "consider installing torch >= 2.0 or flash-attn.",
+ stacklevel=2,
+ )
+ self.enable_flash = allow_flash and FLASH_AVAILABLE
+ self.has_sdp = hasattr(F, "scaled_dot_product_attention")
+ if allow_flash and FlashCrossAttention:
+ self.flash_ = FlashCrossAttention()
+ if self.has_sdp:
+ torch.backends.cuda.enable_flash_sdp(allow_flash)
+
+ def forward(self, q, k, v, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
+ if q.shape[-2] == 0 or k.shape[-2] == 0:
+ return q.new_zeros((*q.shape[:-1], v.shape[-1]))
+ if self.enable_flash and q.device.type == "cuda":
+ # use torch 2.0 scaled_dot_product_attention with flash
+ if self.has_sdp:
+ args = [x.half().contiguous() for x in [q, k, v]]
+ v = F.scaled_dot_product_attention(*args, attn_mask=mask).to(q.dtype)
+ return v if mask is None else v.nan_to_num()
+ else:
+ assert mask is None
+ q, k, v = [x.transpose(-2, -3).contiguous() for x in [q, k, v]]
+ m = self.flash_(q.half(), torch.stack([k, v], 2).half())
+ return m.transpose(-2, -3).to(q.dtype).clone()
+ elif self.has_sdp:
+ args = [x.contiguous() for x in [q, k, v]]
+ v = F.scaled_dot_product_attention(*args, attn_mask=mask)
+ return v if mask is None else v.nan_to_num()
+ else:
+ s = q.shape[-1] ** -0.5
+ sim = torch.einsum("...id,...jd->...ij", q, k) * s
+ if mask is not None:
+                # masked_fill is out-of-place, so reassign the result
+                sim = sim.masked_fill(~mask, -float("inf"))
+ attn = F.softmax(sim, -1)
+ return torch.einsum("...ij,...jd->...id", attn, v)
+
+
+class SelfBlock(nn.Module):
+ def __init__(
+ self, embed_dim: int, num_heads: int, flash: bool = False, bias: bool = True
+ ) -> None:
+ super().__init__()
+ self.embed_dim = embed_dim
+ self.num_heads = num_heads
+ assert self.embed_dim % num_heads == 0
+ self.head_dim = self.embed_dim // num_heads
+ self.Wqkv = nn.Linear(embed_dim, 3 * embed_dim, bias=bias)
+ self.inner_attn = Attention(flash)
+ self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+ self.ffn = nn.Sequential(
+ nn.Linear(2 * embed_dim, 2 * embed_dim),
+ nn.LayerNorm(2 * embed_dim, elementwise_affine=True),
+ nn.GELU(),
+ nn.Linear(2 * embed_dim, embed_dim),
+ )
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ encoding: torch.Tensor,
+ mask: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ qkv = self.Wqkv(x)
+ qkv = qkv.unflatten(-1, (self.num_heads, -1, 3)).transpose(1, 2)
+ q, k, v = qkv[..., 0], qkv[..., 1], qkv[..., 2]
+ q = apply_cached_rotary_emb(encoding, q)
+ k = apply_cached_rotary_emb(encoding, k)
+ context = self.inner_attn(q, k, v, mask=mask)
+ message = self.out_proj(context.transpose(1, 2).flatten(start_dim=-2))
+ return x + self.ffn(torch.cat([x, message], -1))
+
+
+class CrossBlock(nn.Module):
+ def __init__(
+ self, embed_dim: int, num_heads: int, flash: bool = False, bias: bool = True
+ ) -> None:
+ super().__init__()
+ self.heads = num_heads
+ dim_head = embed_dim // num_heads
+ self.scale = dim_head**-0.5
+ inner_dim = dim_head * num_heads
+ self.to_qk = nn.Linear(embed_dim, inner_dim, bias=bias)
+ self.to_v = nn.Linear(embed_dim, inner_dim, bias=bias)
+ self.to_out = nn.Linear(inner_dim, embed_dim, bias=bias)
+ self.ffn = nn.Sequential(
+ nn.Linear(2 * embed_dim, 2 * embed_dim),
+ nn.LayerNorm(2 * embed_dim, elementwise_affine=True),
+ nn.GELU(),
+ nn.Linear(2 * embed_dim, embed_dim),
+ )
+ if flash and FLASH_AVAILABLE:
+ self.flash = Attention(True)
+ else:
+ self.flash = None
+
+ def map_(self, func: Callable, x0: torch.Tensor, x1: torch.Tensor):
+ return func(x0), func(x1)
+
+ def forward(
+ self, x0: torch.Tensor, x1: torch.Tensor, mask: Optional[torch.Tensor] = None
+ ) -> List[torch.Tensor]:
+ qk0, qk1 = self.map_(self.to_qk, x0, x1)
+ v0, v1 = self.map_(self.to_v, x0, x1)
+ qk0, qk1, v0, v1 = map(
+ lambda t: t.unflatten(-1, (self.heads, -1)).transpose(1, 2),
+ (qk0, qk1, v0, v1),
+ )
+ if self.flash is not None and qk0.device.type == "cuda":
+ m0 = self.flash(qk0, qk1, v1, mask)
+ m1 = self.flash(
+ qk1, qk0, v0, mask.transpose(-1, -2) if mask is not None else None
+ )
+ else:
+ qk0, qk1 = qk0 * self.scale**0.5, qk1 * self.scale**0.5
+ sim = torch.einsum("bhid, bhjd -> bhij", qk0, qk1)
+ if mask is not None:
+ sim = sim.masked_fill(~mask, -float("inf"))
+ attn01 = F.softmax(sim, dim=-1)
+ attn10 = F.softmax(sim.transpose(-2, -1).contiguous(), dim=-1)
+ m0 = torch.einsum("bhij, bhjd -> bhid", attn01, v1)
+ m1 = torch.einsum("bhji, bhjd -> bhid", attn10.transpose(-2, -1), v0)
+ if mask is not None:
+ m0, m1 = m0.nan_to_num(), m1.nan_to_num()
+ m0, m1 = self.map_(lambda t: t.transpose(1, 2).flatten(start_dim=-2), m0, m1)
+ m0, m1 = self.map_(self.to_out, m0, m1)
+ x0 = x0 + self.ffn(torch.cat([x0, m0], -1))
+ x1 = x1 + self.ffn(torch.cat([x1, m1], -1))
+ return x0, x1
+
+
+class TransformerLayer(nn.Module):
+ def __init__(self, *args, **kwargs):
+ super().__init__()
+ self.self_attn = SelfBlock(*args, **kwargs)
+ self.cross_attn = CrossBlock(*args, **kwargs)
+
+ def forward(
+ self,
+ desc0,
+ desc1,
+ encoding0,
+ encoding1,
+ mask0: Optional[torch.Tensor] = None,
+ mask1: Optional[torch.Tensor] = None,
+ ):
+ if mask0 is not None and mask1 is not None:
+ return self.masked_forward(desc0, desc1, encoding0, encoding1, mask0, mask1)
+ else:
+ desc0 = self.self_attn(desc0, encoding0)
+ desc1 = self.self_attn(desc1, encoding1)
+ return self.cross_attn(desc0, desc1)
+
+ # This part is compiled and allows padding inputs
+ def masked_forward(self, desc0, desc1, encoding0, encoding1, mask0, mask1):
+ mask = mask0 & mask1.transpose(-1, -2)
+ mask0 = mask0 & mask0.transpose(-1, -2)
+ mask1 = mask1 & mask1.transpose(-1, -2)
+ desc0 = self.self_attn(desc0, encoding0, mask0)
+ desc1 = self.self_attn(desc1, encoding1, mask1)
+ return self.cross_attn(desc0, desc1, mask)
+
+
+def sigmoid_log_double_softmax(
+ sim: torch.Tensor, z0: torch.Tensor, z1: torch.Tensor
+) -> torch.Tensor:
+ """create the log assignment matrix from logits and similarity"""
+ b, m, n = sim.shape
+ certainties = F.logsigmoid(z0) + F.logsigmoid(z1).transpose(1, 2)
+ scores0 = F.log_softmax(sim, 2)
+ scores1 = F.log_softmax(sim.transpose(-1, -2).contiguous(), 2).transpose(-1, -2)
+ scores = sim.new_full((b, m + 1, n + 1), 0)
+ scores[:, :m, :n] = scores0 + scores1 + certainties
+ scores[:, :-1, -1] = F.logsigmoid(-z0.squeeze(-1))
+ scores[:, -1, :-1] = F.logsigmoid(-z1.squeeze(-1))
+ return scores
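+# Note (comments only): for matchable entries (i < m, j < n) the log assignment above is
+#   scores[i, j] = log_softmax_over_j(sim)[i, j] + log_softmax_over_i(sim)[i, j]
+#                  + log sigmoid(z0[i]) + log sigmoid(z1[j]),
+# while the last column and row hold the unmatched log-probabilities
+# log sigmoid(-z0[i]) and log sigmoid(-z1[j]).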
+
+
+class MatchAssignment(nn.Module):
+ def __init__(self, dim: int) -> None:
+ super().__init__()
+ self.dim = dim
+ self.matchability = nn.Linear(dim, 1, bias=True)
+ self.final_proj = nn.Linear(dim, dim, bias=True)
+
+ def forward(self, desc0: torch.Tensor, desc1: torch.Tensor):
+ """build assignment matrix from descriptors"""
+ mdesc0, mdesc1 = self.final_proj(desc0), self.final_proj(desc1)
+ _, _, d = mdesc0.shape
+ mdesc0, mdesc1 = mdesc0 / d**0.25, mdesc1 / d**0.25
+ sim = torch.einsum("bmd,bnd->bmn", mdesc0, mdesc1)
+ z0 = self.matchability(desc0)
+ z1 = self.matchability(desc1)
+ scores = sigmoid_log_double_softmax(sim, z0, z1)
+ return scores, sim
+
+ def get_matchability(self, desc: torch.Tensor):
+ return torch.sigmoid(self.matchability(desc)).squeeze(-1)
+
+
+def filter_matches(scores: torch.Tensor, th: float):
+ """obtain matches from a log assignment matrix [Bx M+1 x N+1]"""
+ max0, max1 = scores[:, :-1, :-1].max(2), scores[:, :-1, :-1].max(1)
+ m0, m1 = max0.indices, max1.indices
+ indices0 = torch.arange(m0.shape[1], device=m0.device)[None]
+ indices1 = torch.arange(m1.shape[1], device=m1.device)[None]
+ mutual0 = indices0 == m1.gather(1, m0)
+ mutual1 = indices1 == m0.gather(1, m1)
+ max0_exp = max0.values.exp()
+ zero = max0_exp.new_tensor(0)
+ mscores0 = torch.where(mutual0, max0_exp, zero)
+ mscores1 = torch.where(mutual1, mscores0.gather(1, m1), zero)
+ valid0 = mutual0 & (mscores0 > th)
+ valid1 = mutual1 & valid0.gather(1, m1)
+ m0 = torch.where(valid0, m0, -1)
+ m1 = torch.where(valid1, m1, -1)
+ return m0, m1, mscores0, mscores1
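+# Note (comments only): a pair (i, j) is kept only if it is a mutual nearest neighbour
+# in the assignment (j is the argmax of row i and i the argmax of column j) and its
+# exponentiated score exceeds th; unmatched keypoints are assigned index -1.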
+
+
+class LightGlue(nn.Module):
+ default_conf = {
+ "name": "lightglue", # just for interfacing
+ "input_dim": 256, # input descriptor dimension (autoselected from weights)
+ "descriptor_dim": 256,
+ "add_scale_ori": False,
+ "n_layers": 9,
+ "num_heads": 4,
+ "flash": True, # enable FlashAttention if available.
+ "mp": False, # enable mixed precision
+ "depth_confidence": 0.95, # early stopping, disable with -1
+ "width_confidence": 0.99, # point pruning, disable with -1
+ "filter_threshold": 0.1, # match threshold
+ "weights": None,
+ }
+
+ # Point pruning involves an overhead (gather).
+ # Therefore, we only activate it if there are enough keypoints.
+ pruning_keypoint_thresholds = {
+ "cpu": -1,
+ "mps": -1,
+ "cuda": 1024,
+ "flash": 1536,
+ }
+
+ required_data_keys = ["image0", "image1"]
+
+ version = "v0.1_arxiv"
+ url = "https://github.com/cvg/LightGlue/releases/download/{}/{}_lightglue.pth"
+
+ features = {
+ "superpoint": {
+ "weights": "superpoint_lightglue",
+ "input_dim": 256,
+ },
+ "disk": {
+ "weights": "disk_lightglue",
+ "input_dim": 128,
+ },
+ "aliked": {
+ "weights": "aliked_lightglue",
+ "input_dim": 128,
+ },
+ "sift": {
+ "weights": "sift_lightglue",
+ "input_dim": 128,
+ "add_scale_ori": True,
+ },
+ "doghardnet": {
+ "weights": "doghardnet_lightglue",
+ "input_dim": 128,
+ "add_scale_ori": True,
+ },
+ }
+
+ def __init__(self, features="superpoint", **conf) -> None:
+ super().__init__()
+ self.conf = conf = SimpleNamespace(**{**self.default_conf, **conf})
+ if features is not None:
+ if features not in self.features:
+ raise ValueError(
+ f"Unsupported features: {features} not in "
+ f"{{{','.join(self.features)}}}"
+ )
+ for k, v in self.features[features].items():
+ setattr(conf, k, v)
+
+ if conf.input_dim != conf.descriptor_dim:
+ self.input_proj = nn.Linear(conf.input_dim, conf.descriptor_dim, bias=True)
+ else:
+ self.input_proj = nn.Identity()
+
+ head_dim = conf.descriptor_dim // conf.num_heads
+ self.posenc = LearnableFourierPositionalEncoding(
+ 2 + 2 * self.conf.add_scale_ori, head_dim, head_dim
+ )
+
+ h, n, d = conf.num_heads, conf.n_layers, conf.descriptor_dim
+
+ self.transformers = nn.ModuleList(
+ [TransformerLayer(d, h, conf.flash) for _ in range(n)]
+ )
+
+ self.log_assignment = nn.ModuleList([MatchAssignment(d) for _ in range(n)])
+ self.token_confidence = nn.ModuleList(
+ [TokenConfidence(d) for _ in range(n - 1)]
+ )
+ self.register_buffer(
+ "confidence_thresholds",
+ torch.Tensor(
+ [self.confidence_threshold(i) for i in range(self.conf.n_layers)]
+ ),
+ )
+
+ state_dict = None
+ if features is not None:
+ fname = f"{conf.weights}_{self.version.replace('.', '-')}.pth"
+ state_dict = torch.hub.load_state_dict_from_url(
+ self.url.format(self.version, features), file_name=fname
+ )
+ self.load_state_dict(state_dict, strict=False)
+ elif conf.weights is not None:
+ path = Path(__file__).parent
+ path = path / "weights/{}.pth".format(self.conf.weights)
+ state_dict = torch.load(str(path), map_location="cpu")
+
+ if state_dict:
+ # rename old state dict entries
+ for i in range(self.conf.n_layers):
+ pattern = f"self_attn.{i}", f"transformers.{i}.self_attn"
+ state_dict = {k.replace(*pattern): v for k, v in state_dict.items()}
+ pattern = f"cross_attn.{i}", f"transformers.{i}.cross_attn"
+ state_dict = {k.replace(*pattern): v for k, v in state_dict.items()}
+ self.load_state_dict(state_dict, strict=False)
+
+ # static lengths LightGlue is compiled for (only used with torch.compile)
+ self.static_lengths = None
+
+ def compile(
+ self, mode="reduce-overhead", static_lengths=[256, 512, 768, 1024, 1280, 1536]
+ ):
+ if self.conf.width_confidence != -1:
+ warnings.warn(
+ "Point pruning is partially disabled for compiled forward.",
+ stacklevel=2,
+ )
+
+ torch._inductor.cudagraph_mark_step_begin()
+ for i in range(self.conf.n_layers):
+ self.transformers[i].masked_forward = torch.compile(
+ self.transformers[i].masked_forward, mode=mode, fullgraph=True
+ )
+
+ self.static_lengths = static_lengths
+
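+    # Sketch (comments only): once compile() sets static_lengths, _forward pads each
+    # input to the smallest static length that fits it, so the compiled masked_forward
+    # can be reused. Hypothetical use:
+    #
+    #   matcher = LightGlue(features="superpoint").eval().cuda()
+    #   matcher.compile(mode="reduce-overhead")  # static lengths 256, 512, ..., 1536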
+ def forward(self, data: dict) -> dict:
+ """
+ Match keypoints and descriptors between two images
+
+ Input (dict):
+ image0: dict
+ keypoints: [B x M x 2]
+ descriptors: [B x M x D]
+ image: [B x C x H x W] or image_size: [B x 2]
+ image1: dict
+ keypoints: [B x N x 2]
+ descriptors: [B x N x D]
+ image: [B x C x H x W] or image_size: [B x 2]
+ Output (dict):
+ matches0: [B x M]
+ matching_scores0: [B x M]
+ matches1: [B x N]
+ matching_scores1: [B x N]
+ matches: List[[Si x 2]]
+ scores: List[[Si]]
+ stop: int
+ prune0: [B x M]
+ prune1: [B x N]
+ """
+ with torch.autocast(enabled=self.conf.mp, device_type="cuda"):
+ return self._forward(data)
+
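+    # Sketch of the interface documented above (comments only), with feats0/feats1 as
+    # produced by one of the extractors in this package:
+    #
+    #   matcher = LightGlue(features="superpoint").eval().to(device)
+    #   out = matcher({"image0": feats0, "image1": feats1})
+    #   matches = out["matches"][0]  # S x 2 indices into the two keypoint sets
+    #   kpts0 = feats0["keypoints"][0][matches[:, 0]]
+    #   kpts1 = feats1["keypoints"][0][matches[:, 1]]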
+ def _forward(self, data: dict) -> dict:
+ for key in self.required_data_keys:
+ assert key in data, f"Missing key {key} in data"
+ data0, data1 = data["image0"], data["image1"]
+ kpts0, kpts1 = data0["keypoints"], data1["keypoints"]
+ b, m, _ = kpts0.shape
+ b, n, _ = kpts1.shape
+ device = kpts0.device
+ size0, size1 = data0.get("image_size"), data1.get("image_size")
+ kpts0 = normalize_keypoints(kpts0, size0).clone()
+ kpts1 = normalize_keypoints(kpts1, size1).clone()
+
+ if self.conf.add_scale_ori:
+ kpts0 = torch.cat(
+ [kpts0] + [data0[k].unsqueeze(-1) for k in ("scales", "oris")], -1
+ )
+ kpts1 = torch.cat(
+ [kpts1] + [data1[k].unsqueeze(-1) for k in ("scales", "oris")], -1
+ )
+ desc0 = data0["descriptors"].detach().contiguous()
+ desc1 = data1["descriptors"].detach().contiguous()
+
+ assert desc0.shape[-1] == self.conf.input_dim
+ assert desc1.shape[-1] == self.conf.input_dim
+
+ if torch.is_autocast_enabled():
+ desc0 = desc0.half()
+ desc1 = desc1.half()
+
+ mask0, mask1 = None, None
+ c = max(m, n)
+ do_compile = self.static_lengths and c <= max(self.static_lengths)
+ if do_compile:
+ kn = min([k for k in self.static_lengths if k >= c])
+ desc0, mask0 = pad_to_length(desc0, kn)
+ desc1, mask1 = pad_to_length(desc1, kn)
+ kpts0, _ = pad_to_length(kpts0, kn)
+ kpts1, _ = pad_to_length(kpts1, kn)
+ desc0 = self.input_proj(desc0)
+ desc1 = self.input_proj(desc1)
+ # cache positional embeddings
+ encoding0 = self.posenc(kpts0)
+ encoding1 = self.posenc(kpts1)
+
+ # GNN + final_proj + assignment
+ do_early_stop = self.conf.depth_confidence > 0
+ do_point_pruning = self.conf.width_confidence > 0 and not do_compile
+ pruning_th = self.pruning_min_kpts(device)
+ if do_point_pruning:
+ ind0 = torch.arange(0, m, device=device)[None]
+ ind1 = torch.arange(0, n, device=device)[None]
+ # We store the index of the layer at which pruning is detected.
+ prune0 = torch.ones_like(ind0)
+ prune1 = torch.ones_like(ind1)
+ token0, token1 = None, None
+ for i in range(self.conf.n_layers):
+ if desc0.shape[1] == 0 or desc1.shape[1] == 0: # no keypoints
+ break
+ desc0, desc1 = self.transformers[i](
+ desc0, desc1, encoding0, encoding1, mask0=mask0, mask1=mask1
+ )
+ if i == self.conf.n_layers - 1:
+ continue # no early stopping or adaptive width at last layer
+
+ if do_early_stop:
+ token0, token1 = self.token_confidence[i](desc0, desc1)
+ if self.check_if_stop(token0[..., :m], token1[..., :n], i, m + n):
+ break
+ if do_point_pruning and desc0.shape[-2] > pruning_th:
+ scores0 = self.log_assignment[i].get_matchability(desc0)
+ prunemask0 = self.get_pruning_mask(token0, scores0, i)
+ keep0 = torch.where(prunemask0)[1]
+ ind0 = ind0.index_select(1, keep0)
+ desc0 = desc0.index_select(1, keep0)
+ encoding0 = encoding0.index_select(-2, keep0)
+ prune0[:, ind0] += 1
+ if do_point_pruning and desc1.shape[-2] > pruning_th:
+ scores1 = self.log_assignment[i].get_matchability(desc1)
+ prunemask1 = self.get_pruning_mask(token1, scores1, i)
+ keep1 = torch.where(prunemask1)[1]
+ ind1 = ind1.index_select(1, keep1)
+ desc1 = desc1.index_select(1, keep1)
+ encoding1 = encoding1.index_select(-2, keep1)
+ prune1[:, ind1] += 1
+
+ if desc0.shape[1] == 0 or desc1.shape[1] == 0: # no keypoints
+ m0 = desc0.new_full((b, m), -1, dtype=torch.long)
+ m1 = desc1.new_full((b, n), -1, dtype=torch.long)
+ mscores0 = desc0.new_zeros((b, m))
+ mscores1 = desc1.new_zeros((b, n))
+ matches = desc0.new_empty((b, 0, 2), dtype=torch.long)
+ mscores = desc0.new_empty((b, 0))
+ if not do_point_pruning:
+ prune0 = torch.ones_like(mscores0) * self.conf.n_layers
+ prune1 = torch.ones_like(mscores1) * self.conf.n_layers
+ return {
+ "matches0": m0,
+ "matches1": m1,
+ "matching_scores0": mscores0,
+ "matching_scores1": mscores1,
+ "stop": i + 1,
+ "matches": matches,
+ "scores": mscores,
+ "prune0": prune0,
+ "prune1": prune1,
+ }
+
+ desc0, desc1 = desc0[..., :m, :], desc1[..., :n, :] # remove padding
+ scores, _ = self.log_assignment[i](desc0, desc1)
+ m0, m1, mscores0, mscores1 = filter_matches(scores, self.conf.filter_threshold)
+ matches, mscores = [], []
+ for k in range(b):
+ valid = m0[k] > -1
+ m_indices_0 = torch.where(valid)[0]
+ m_indices_1 = m0[k][valid]
+ if do_point_pruning:
+ m_indices_0 = ind0[k, m_indices_0]
+ m_indices_1 = ind1[k, m_indices_1]
+ matches.append(torch.stack([m_indices_0, m_indices_1], -1))
+ mscores.append(mscores0[k][valid])
+
+ # TODO: Remove when hloc switches to the compact format.
+ if do_point_pruning:
+ m0_ = torch.full((b, m), -1, device=m0.device, dtype=m0.dtype)
+ m1_ = torch.full((b, n), -1, device=m1.device, dtype=m1.dtype)
+ m0_[:, ind0] = torch.where(m0 == -1, -1, ind1.gather(1, m0.clamp(min=0)))
+ m1_[:, ind1] = torch.where(m1 == -1, -1, ind0.gather(1, m1.clamp(min=0)))
+ mscores0_ = torch.zeros((b, m), device=mscores0.device)
+ mscores1_ = torch.zeros((b, n), device=mscores1.device)
+ mscores0_[:, ind0] = mscores0
+ mscores1_[:, ind1] = mscores1
+ m0, m1, mscores0, mscores1 = m0_, m1_, mscores0_, mscores1_
+ else:
+ prune0 = torch.ones_like(mscores0) * self.conf.n_layers
+ prune1 = torch.ones_like(mscores1) * self.conf.n_layers
+
+ return {
+ "matches0": m0,
+ "matches1": m1,
+ "matching_scores0": mscores0,
+ "matching_scores1": mscores1,
+ "stop": i + 1,
+ "matches": matches,
+ "scores": mscores,
+ "prune0": prune0,
+ "prune1": prune1,
+ }
+
+ def confidence_threshold(self, layer_index: int) -> float:
+ """scaled confidence threshold"""
+ threshold = 0.8 + 0.1 * np.exp(-4.0 * layer_index / self.conf.n_layers)
+ return np.clip(threshold, 0, 1)
+
+ def get_pruning_mask(
+ self, confidences: torch.Tensor, scores: torch.Tensor, layer_index: int
+ ) -> torch.Tensor:
+ """mask points which should be removed"""
+ keep = scores > (1 - self.conf.width_confidence)
+ if confidences is not None: # Low-confidence points are never pruned.
+ keep |= confidences <= self.confidence_thresholds[layer_index]
+ return keep
+
+ def check_if_stop(
+ self,
+ confidences0: torch.Tensor,
+ confidences1: torch.Tensor,
+ layer_index: int,
+ num_points: int,
+ ) -> torch.Tensor:
+ """evaluate stopping condition"""
+ confidences = torch.cat([confidences0, confidences1], -1)
+ threshold = self.confidence_thresholds[layer_index]
+ ratio_confident = 1.0 - (confidences < threshold).float().sum() / num_points
+ return ratio_confident > self.conf.depth_confidence
+
+ def pruning_min_kpts(self, device: torch.device):
+ if self.conf.flash and FLASH_AVAILABLE and device.type == "cuda":
+ return self.pruning_keypoint_thresholds["flash"]
+ else:
+ return self.pruning_keypoint_thresholds[device.type]
diff --git a/imcui/third_party/LightGlue/lightglue/sift.py b/imcui/third_party/LightGlue/lightglue/sift.py
new file mode 100644
index 0000000000000000000000000000000000000000..802fc1c2eb9ee852691e0e4dd67455f822f8405f
--- /dev/null
+++ b/imcui/third_party/LightGlue/lightglue/sift.py
@@ -0,0 +1,216 @@
+import warnings
+
+import cv2
+import numpy as np
+import torch
+from kornia.color import rgb_to_grayscale
+from packaging import version
+
+try:
+ import pycolmap
+except ImportError:
+ pycolmap = None
+
+from .utils import Extractor
+
+
+def filter_dog_point(points, scales, angles, image_shape, nms_radius, scores=None):
+ h, w = image_shape
+ ij = np.round(points - 0.5).astype(int).T[::-1]
+
+ # Remove duplicate points (identical coordinates).
+ # Pick highest scale or score
+ s = scales if scores is None else scores
+ buffer = np.zeros((h, w))
+ np.maximum.at(buffer, tuple(ij), s)
+ keep = np.where(buffer[tuple(ij)] == s)[0]
+
+ # Pick lowest angle (arbitrary).
+ ij = ij[:, keep]
+ buffer[:] = np.inf
+ o_abs = np.abs(angles[keep])
+ np.minimum.at(buffer, tuple(ij), o_abs)
+ mask = buffer[tuple(ij)] == o_abs
+ ij = ij[:, mask]
+ keep = keep[mask]
+
+ if nms_radius > 0:
+ # Apply NMS on the remaining points
+ buffer[:] = 0
+ buffer[tuple(ij)] = s[keep] # scores or scale
+
+ local_max = torch.nn.functional.max_pool2d(
+ torch.from_numpy(buffer).unsqueeze(0),
+ kernel_size=nms_radius * 2 + 1,
+ stride=1,
+ padding=nms_radius,
+ ).squeeze(0)
+ is_local_max = buffer == local_max.numpy()
+ keep = keep[is_local_max[tuple(ij)]]
+ return keep
+
+
+def sift_to_rootsift(x: torch.Tensor, eps=1e-6) -> torch.Tensor:
+ x = torch.nn.functional.normalize(x, p=1, dim=-1, eps=eps)
+ x.clip_(min=eps).sqrt_()
+ return torch.nn.functional.normalize(x, p=2, dim=-1, eps=eps)
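+# Note (comments only): this is RootSIFT (Arandjelovic & Zisserman, 2012): L1-normalise,
+# take the element-wise square root, then L2-normalise, so that Euclidean distances
+# between descriptors approximate the Hellinger kernel on the original SIFT vectors.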
+
+
+def run_opencv_sift(features: cv2.Feature2D, image: np.ndarray) -> np.ndarray:
+ """
+ Detect keypoints using OpenCV Detector.
+ Optionally, perform description.
+ Args:
+ features: OpenCV based keypoints detector and descriptor
+ image: Grayscale image of uint8 data type
+ Returns:
+ keypoints: 1D array of detected cv2.KeyPoint
+ scores: 1D array of responses
+ descriptors: 1D array of descriptors
+ """
+ detections, descriptors = features.detectAndCompute(image, None)
+ points = np.array([k.pt for k in detections], dtype=np.float32)
+ scores = np.array([k.response for k in detections], dtype=np.float32)
+ scales = np.array([k.size for k in detections], dtype=np.float32)
+ angles = np.deg2rad(np.array([k.angle for k in detections], dtype=np.float32))
+ return points, scores, scales, angles, descriptors
+
+
+class SIFT(Extractor):
+ default_conf = {
+ "rootsift": True,
+ "nms_radius": 0, # None to disable filtering entirely.
+ "max_num_keypoints": 4096,
+ "backend": "opencv", # in {opencv, pycolmap, pycolmap_cpu, pycolmap_cuda}
+ "detection_threshold": 0.0066667, # from COLMAP
+ "edge_threshold": 10,
+ "first_octave": -1, # only used by pycolmap, the default of COLMAP
+ "num_octaves": 4,
+ }
+
+ preprocess_conf = {
+ "resize": 1024,
+ }
+
+ required_data_keys = ["image"]
+
+ def __init__(self, **conf):
+ super().__init__(**conf) # Update with default configuration.
+ backend = self.conf.backend
+ if backend.startswith("pycolmap"):
+ if pycolmap is None:
+ raise ImportError(
+ "Cannot find module pycolmap: install it with pip"
+ "or use backend=opencv."
+ )
+ options = {
+ "peak_threshold": self.conf.detection_threshold,
+ "edge_threshold": self.conf.edge_threshold,
+ "first_octave": self.conf.first_octave,
+ "num_octaves": self.conf.num_octaves,
+ "normalization": pycolmap.Normalization.L2, # L1_ROOT is buggy.
+ }
+ device = (
+ "auto" if backend == "pycolmap" else backend.replace("pycolmap_", "")
+ )
+ if (
+ backend == "pycolmap_cpu" or not pycolmap.has_cuda
+ ) and pycolmap.__version__ < "0.5.0":
+ warnings.warn(
+ "The pycolmap CPU SIFT is buggy in version < 0.5.0, "
+ "consider upgrading pycolmap or use the CUDA version.",
+ stacklevel=1,
+ )
+ else:
+ options["max_num_features"] = self.conf.max_num_keypoints
+ self.sift = pycolmap.Sift(options=options, device=device)
+ elif backend == "opencv":
+ self.sift = cv2.SIFT_create(
+ contrastThreshold=self.conf.detection_threshold,
+ nfeatures=self.conf.max_num_keypoints,
+ edgeThreshold=self.conf.edge_threshold,
+ nOctaveLayers=self.conf.num_octaves,
+ )
+ else:
+ backends = {"opencv", "pycolmap", "pycolmap_cpu", "pycolmap_cuda"}
+ raise ValueError(
+ f"Unknown backend: {backend} not in " f"{{{','.join(backends)}}}."
+ )
+
+ def extract_single_image(self, image: torch.Tensor):
+ image_np = image.cpu().numpy().squeeze(0)
+
+ if self.conf.backend.startswith("pycolmap"):
+ if version.parse(pycolmap.__version__) >= version.parse("0.5.0"):
+ detections, descriptors = self.sift.extract(image_np)
+ scores = None # Scores are not exposed by COLMAP anymore.
+ else:
+ detections, scores, descriptors = self.sift.extract(image_np)
+ keypoints = detections[:, :2] # Keep only (x, y).
+ scales, angles = detections[:, -2:].T
+ if scores is not None and (
+ self.conf.backend == "pycolmap_cpu" or not pycolmap.has_cuda
+ ):
+ # Set the scores as a combination of abs. response and scale.
+ scores = np.abs(scores) * scales
+ elif self.conf.backend == "opencv":
+ # TODO: Check if opencv keypoints are already in corner convention
+ keypoints, scores, scales, angles, descriptors = run_opencv_sift(
+ self.sift, (image_np * 255.0).astype(np.uint8)
+ )
+ pred = {
+ "keypoints": keypoints,
+ "scales": scales,
+ "oris": angles,
+ "descriptors": descriptors,
+ }
+ if scores is not None:
+ pred["keypoint_scores"] = scores
+
+        # pycolmap sometimes returns points outside the image; remove them
+ if self.conf.backend.startswith("pycolmap"):
+ is_inside = (
+ pred["keypoints"] + 0.5 < np.array([image_np.shape[-2:][::-1]])
+ ).all(-1)
+ pred = {k: v[is_inside] for k, v in pred.items()}
+
+ if self.conf.nms_radius is not None:
+ keep = filter_dog_point(
+ pred["keypoints"],
+ pred["scales"],
+ pred["oris"],
+ image_np.shape,
+ self.conf.nms_radius,
+ scores=pred.get("keypoint_scores"),
+ )
+ pred = {k: v[keep] for k, v in pred.items()}
+
+ pred = {k: torch.from_numpy(v) for k, v in pred.items()}
+ if scores is not None:
+ # Keep the k keypoints with highest score
+ num_points = self.conf.max_num_keypoints
+ if num_points is not None and len(pred["keypoints"]) > num_points:
+ indices = torch.topk(pred["keypoint_scores"], num_points).indices
+ pred = {k: v[indices] for k, v in pred.items()}
+
+ return pred
+
+ def forward(self, data: dict) -> dict:
+ image = data["image"]
+ if image.shape[1] == 3:
+ image = rgb_to_grayscale(image)
+ device = image.device
+ image = image.cpu()
+ pred = []
+ for k in range(len(image)):
+ img = image[k]
+ if "image_size" in data.keys():
+ # avoid extracting points in padded areas
+ w, h = data["image_size"][k]
+ img = img[:, :h, :w]
+ p = self.extract_single_image(img)
+ pred.append(p)
+ pred = {k: torch.stack([p[k] for p in pred], 0).to(device) for k in pred[0]}
+ if self.conf.rootsift:
+ pred["descriptors"] = sift_to_rootsift(pred["descriptors"])
+ return pred
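+
+# A minimal usage sketch (illustrative, not part of the upstream file). It
+# assumes the extractor class defined above is exported as `SIFT` from the
+# `lightglue` package, as in upstream LightGlue; adjust the import if the
+# vendored package layout differs:
+#
+# import torch
+# from lightglue import SIFT
+#
+# extractor = SIFT(backend="opencv", max_num_keypoints=2048).eval()
+# image = torch.rand(1, 480, 640)   # grayscale image in [0, 1]
+# feats = extractor.extract(image)  # keypoints, scales, oris, descriptors, ...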
diff --git a/imcui/third_party/LightGlue/lightglue/superpoint.py b/imcui/third_party/LightGlue/lightglue/superpoint.py
new file mode 100644
index 0000000000000000000000000000000000000000..d6d380eb5737b92ce417e3ed0e5db1ee9d4a1d04
--- /dev/null
+++ b/imcui/third_party/LightGlue/lightglue/superpoint.py
@@ -0,0 +1,227 @@
+# %BANNER_BEGIN%
+# ---------------------------------------------------------------------
+# %COPYRIGHT_BEGIN%
+#
+# Magic Leap, Inc. ("COMPANY") CONFIDENTIAL
+#
+# Unpublished Copyright (c) 2020
+# Magic Leap, Inc., All Rights Reserved.
+#
+# NOTICE: All information contained herein is, and remains the property
+# of COMPANY. The intellectual and technical concepts contained herein
+# are proprietary to COMPANY and may be covered by U.S. and Foreign
+# Patents, patents in process, and are protected by trade secret or
+# copyright law. Dissemination of this information or reproduction of
+# this material is strictly forbidden unless prior written permission is
+# obtained from COMPANY. Access to the source code contained herein is
+# hereby forbidden to anyone except current COMPANY employees, managers
+# or contractors who have executed Confidentiality and Non-disclosure
+# agreements explicitly covering such access.
+#
+# The copyright notice above does not evidence any actual or intended
+# publication or disclosure of this source code, which includes
+# information that is confidential and/or proprietary, and is a trade
+# secret, of COMPANY. ANY REPRODUCTION, MODIFICATION, DISTRIBUTION,
+# PUBLIC PERFORMANCE, OR PUBLIC DISPLAY OF OR THROUGH USE OF THIS
+# SOURCE CODE WITHOUT THE EXPRESS WRITTEN CONSENT OF COMPANY IS
+# STRICTLY PROHIBITED, AND IN VIOLATION OF APPLICABLE LAWS AND
+# INTERNATIONAL TREATIES. THE RECEIPT OR POSSESSION OF THIS SOURCE
+# CODE AND/OR RELATED INFORMATION DOES NOT CONVEY OR IMPLY ANY RIGHTS
+# TO REPRODUCE, DISCLOSE OR DISTRIBUTE ITS CONTENTS, OR TO MANUFACTURE,
+# USE, OR SELL ANYTHING THAT IT MAY DESCRIBE, IN WHOLE OR IN PART.
+#
+# %COPYRIGHT_END%
+# ----------------------------------------------------------------------
+# %AUTHORS_BEGIN%
+#
+# Originating Authors: Paul-Edouard Sarlin
+#
+# %AUTHORS_END%
+# --------------------------------------------------------------------*/
+# %BANNER_END%
+
+# Adapted by Remi Pautrat, Philipp Lindenberger
+
+import torch
+from kornia.color import rgb_to_grayscale
+from torch import nn
+
+from .utils import Extractor
+
+
+def simple_nms(scores, nms_radius: int):
+ """Fast Non-maximum suppression to remove nearby points"""
+ assert nms_radius >= 0
+
+ def max_pool(x):
+ return torch.nn.functional.max_pool2d(
+ x, kernel_size=nms_radius * 2 + 1, stride=1, padding=nms_radius
+ )
+
+ zeros = torch.zeros_like(scores)
+ max_mask = scores == max_pool(scores)
+ for _ in range(2):
+ supp_mask = max_pool(max_mask.float()) > 0
+ supp_scores = torch.where(supp_mask, zeros, scores)
+ new_max_mask = supp_scores == max_pool(supp_scores)
+ max_mask = max_mask | (new_max_mask & (~supp_mask))
+ return torch.where(max_mask, scores, zeros)
+
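+# Example (illustrative, not upstream): only local maxima within each
+# (2 * nms_radius + 1)-sized window keep their score and everything else is
+# zeroed; this is how `forward` below thins the dense detection heatmap.
+#
+# scores = torch.rand(1, 64, 64)
+# kept = simple_nms(scores, nms_radius=4)
+# assert kept.shape == scores.shape and (kept > 0).sum() <= (scores > 0).sum()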
+
+def top_k_keypoints(keypoints, scores, k):
+ if k >= len(keypoints):
+ return keypoints, scores
+ scores, indices = torch.topk(scores, k, dim=0, sorted=True)
+ return keypoints[indices], scores
+
+
+def sample_descriptors(keypoints, descriptors, s: int = 8):
+ """Interpolate descriptors at keypoint locations"""
+ b, c, h, w = descriptors.shape
+ keypoints = keypoints - s / 2 + 0.5
+ keypoints /= torch.tensor(
+ [(w * s - s / 2 - 0.5), (h * s - s / 2 - 0.5)],
+ ).to(
+ keypoints
+ )[None]
+ keypoints = keypoints * 2 - 1 # normalize to (-1, 1)
+ args = {"align_corners": True} if torch.__version__ >= "1.3" else {}
+ descriptors = torch.nn.functional.grid_sample(
+ descriptors, keypoints.view(b, 1, -1, 2), mode="bilinear", **args
+ )
+ descriptors = torch.nn.functional.normalize(
+ descriptors.reshape(b, c, -1), p=2, dim=1
+ )
+ return descriptors
+
+
+class SuperPoint(Extractor):
+ """SuperPoint Convolutional Detector and Descriptor
+
+ SuperPoint: Self-Supervised Interest Point Detection and
+ Description. Daniel DeTone, Tomasz Malisiewicz, and Andrew
+ Rabinovich. In CVPRW, 2019. https://arxiv.org/abs/1712.07629
+
+ """
+
+ default_conf = {
+ "descriptor_dim": 256,
+ "nms_radius": 4,
+ "max_num_keypoints": None,
+ "detection_threshold": 0.0005,
+ "remove_borders": 4,
+ }
+
+ preprocess_conf = {
+ "resize": 1024,
+ }
+
+ required_data_keys = ["image"]
+
+ def __init__(self, **conf):
+ super().__init__(**conf) # Update with default configuration.
+ self.relu = nn.ReLU(inplace=True)
+ self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
+ c1, c2, c3, c4, c5 = 64, 64, 128, 128, 256
+
+ self.conv1a = nn.Conv2d(1, c1, kernel_size=3, stride=1, padding=1)
+ self.conv1b = nn.Conv2d(c1, c1, kernel_size=3, stride=1, padding=1)
+ self.conv2a = nn.Conv2d(c1, c2, kernel_size=3, stride=1, padding=1)
+ self.conv2b = nn.Conv2d(c2, c2, kernel_size=3, stride=1, padding=1)
+ self.conv3a = nn.Conv2d(c2, c3, kernel_size=3, stride=1, padding=1)
+ self.conv3b = nn.Conv2d(c3, c3, kernel_size=3, stride=1, padding=1)
+ self.conv4a = nn.Conv2d(c3, c4, kernel_size=3, stride=1, padding=1)
+ self.conv4b = nn.Conv2d(c4, c4, kernel_size=3, stride=1, padding=1)
+
+ self.convPa = nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1)
+ self.convPb = nn.Conv2d(c5, 65, kernel_size=1, stride=1, padding=0)
+
+ self.convDa = nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1)
+ self.convDb = nn.Conv2d(
+ c5, self.conf.descriptor_dim, kernel_size=1, stride=1, padding=0
+ )
+
+ url = "https://github.com/cvg/LightGlue/releases/download/v0.1_arxiv/superpoint_v1.pth" # noqa
+ self.load_state_dict(torch.hub.load_state_dict_from_url(url))
+
+ if self.conf.max_num_keypoints is not None and self.conf.max_num_keypoints <= 0:
+ raise ValueError("max_num_keypoints must be positive or None")
+
+ def forward(self, data: dict) -> dict:
+ """Compute keypoints, scores, descriptors for image"""
+ for key in self.required_data_keys:
+ assert key in data, f"Missing key {key} in data"
+ image = data["image"]
+ if image.shape[1] == 3:
+ image = rgb_to_grayscale(image)
+
+ # Shared Encoder
+ x = self.relu(self.conv1a(image))
+ x = self.relu(self.conv1b(x))
+ x = self.pool(x)
+ x = self.relu(self.conv2a(x))
+ x = self.relu(self.conv2b(x))
+ x = self.pool(x)
+ x = self.relu(self.conv3a(x))
+ x = self.relu(self.conv3b(x))
+ x = self.pool(x)
+ x = self.relu(self.conv4a(x))
+ x = self.relu(self.conv4b(x))
+
+ # Compute the dense keypoint scores
+ cPa = self.relu(self.convPa(x))
+ scores = self.convPb(cPa)
+ scores = torch.nn.functional.softmax(scores, 1)[:, :-1]
+ b, _, h, w = scores.shape
+ scores = scores.permute(0, 2, 3, 1).reshape(b, h, w, 8, 8)
+ scores = scores.permute(0, 1, 3, 2, 4).reshape(b, h * 8, w * 8)
+ scores = simple_nms(scores, self.conf.nms_radius)
+
+ # Discard keypoints near the image borders
+ if self.conf.remove_borders:
+ pad = self.conf.remove_borders
+ scores[:, :pad] = -1
+ scores[:, :, :pad] = -1
+ scores[:, -pad:] = -1
+ scores[:, :, -pad:] = -1
+
+ # Extract keypoints
+ best_kp = torch.where(scores > self.conf.detection_threshold)
+ scores = scores[best_kp]
+
+ # Separate into batches
+ keypoints = [
+ torch.stack(best_kp[1:3], dim=-1)[best_kp[0] == i] for i in range(b)
+ ]
+ scores = [scores[best_kp[0] == i] for i in range(b)]
+
+ # Keep the k keypoints with highest score
+ if self.conf.max_num_keypoints is not None:
+ keypoints, scores = list(
+ zip(
+ *[
+ top_k_keypoints(k, s, self.conf.max_num_keypoints)
+ for k, s in zip(keypoints, scores)
+ ]
+ )
+ )
+
+ # Convert (h, w) to (x, y)
+ keypoints = [torch.flip(k, [1]).float() for k in keypoints]
+
+ # Compute the dense descriptors
+ cDa = self.relu(self.convDa(x))
+ descriptors = self.convDb(cDa)
+ descriptors = torch.nn.functional.normalize(descriptors, p=2, dim=1)
+
+ # Extract descriptors
+ descriptors = [
+ sample_descriptors(k[None], d[None], 8)[0]
+ for k, d in zip(keypoints, descriptors)
+ ]
+
+ return {
+ "keypoints": torch.stack(keypoints, 0),
+ "keypoint_scores": torch.stack(scores, 0),
+ "descriptors": torch.stack(descriptors, 0).transpose(-1, -2).contiguous(),
+ }
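+
+# A minimal usage sketch (illustrative, not part of the upstream file). The
+# weights are downloaded in __init__ from the URL hard-coded above:
+#
+# import torch
+# sp = SuperPoint(max_num_keypoints=1024).eval()
+# image = torch.rand(1, 1, 480, 640)   # grayscale batch in [0, 1]
+# with torch.no_grad():
+#     pred = sp({"image": image})      # keypoints, keypoint_scores, descriptors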
diff --git a/imcui/third_party/LightGlue/lightglue/utils.py b/imcui/third_party/LightGlue/lightglue/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..d1c1ab2e94716b1c54191a6ed5d01023036836c1
--- /dev/null
+++ b/imcui/third_party/LightGlue/lightglue/utils.py
@@ -0,0 +1,165 @@
+import collections.abc as collections
+from pathlib import Path
+from types import SimpleNamespace
+from typing import Callable, List, Optional, Tuple, Union
+
+import cv2
+import kornia
+import numpy as np
+import torch
+
+
+class ImagePreprocessor:
+ default_conf = {
+ "resize": None, # target edge length, None for no resizing
+ "side": "long",
+ "interpolation": "bilinear",
+ "align_corners": None,
+ "antialias": True,
+ }
+
+ def __init__(self, **conf) -> None:
+ super().__init__()
+ self.conf = {**self.default_conf, **conf}
+ self.conf = SimpleNamespace(**self.conf)
+
+ def __call__(self, img: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Resize and preprocess an image, return image and resize scale"""
+ h, w = img.shape[-2:]
+ if self.conf.resize is not None:
+ img = kornia.geometry.transform.resize(
+ img,
+ self.conf.resize,
+ side=self.conf.side,
+ antialias=self.conf.antialias,
+ align_corners=self.conf.align_corners,
+ )
+ scale = torch.Tensor([img.shape[-1] / w, img.shape[-2] / h]).to(img)
+ return img, scale
+
+
+def map_tensor(input_, func: Callable):
+ string_classes = (str, bytes)
+ if isinstance(input_, string_classes):
+ return input_
+ elif isinstance(input_, collections.Mapping):
+ return {k: map_tensor(sample, func) for k, sample in input_.items()}
+ elif isinstance(input_, collections.Sequence):
+ return [map_tensor(sample, func) for sample in input_]
+ elif isinstance(input_, torch.Tensor):
+ return func(input_)
+ else:
+ return input_
+
+
+def batch_to_device(batch: dict, device: str = "cpu", non_blocking: bool = True):
+ """Move batch (dict) to device"""
+
+ def _func(tensor):
+ return tensor.to(device=device, non_blocking=non_blocking).detach()
+
+ return map_tensor(batch, _func)
+
+
+def rbd(data: dict) -> dict:
+ """Remove batch dimension from elements in data"""
+ return {
+ k: v[0] if isinstance(v, (torch.Tensor, np.ndarray, list)) else v
+ for k, v in data.items()
+ }
+
+
+def read_image(path: Path, grayscale: bool = False) -> np.ndarray:
+ """Read an image from path as RGB or grayscale"""
+ if not Path(path).exists():
+ raise FileNotFoundError(f"No image at path {path}.")
+ mode = cv2.IMREAD_GRAYSCALE if grayscale else cv2.IMREAD_COLOR
+ image = cv2.imread(str(path), mode)
+ if image is None:
+ raise IOError(f"Could not read image at {path}.")
+ if not grayscale:
+ image = image[..., ::-1]
+ return image
+
+
+def numpy_image_to_torch(image: np.ndarray) -> torch.Tensor:
+ """Normalize the image tensor and reorder the dimensions."""
+ if image.ndim == 3:
+ image = image.transpose((2, 0, 1)) # HxWxC to CxHxW
+ elif image.ndim == 2:
+ image = image[None] # add channel axis
+ else:
+ raise ValueError(f"Not an image: {image.shape}")
+ return torch.tensor(image / 255.0, dtype=torch.float)
+
+
+def resize_image(
+ image: np.ndarray,
+ size: Union[List[int], int],
+ fn: str = "max",
+ interp: Optional[str] = "area",
+) -> Tuple[np.ndarray, Tuple[float, float]]:
+ """Resize an image to a fixed size, or according to max or min edge."""
+ h, w = image.shape[:2]
+
+ fn = {"max": max, "min": min}[fn]
+ if isinstance(size, int):
+ scale = size / fn(h, w)
+ h_new, w_new = int(round(h * scale)), int(round(w * scale))
+ scale = (w_new / w, h_new / h)
+ elif isinstance(size, (tuple, list)):
+ h_new, w_new = size
+ scale = (w_new / w, h_new / h)
+ else:
+ raise ValueError(f"Incorrect new size: {size}")
+ mode = {
+ "linear": cv2.INTER_LINEAR,
+ "cubic": cv2.INTER_CUBIC,
+ "nearest": cv2.INTER_NEAREST,
+ "area": cv2.INTER_AREA,
+ }[interp]
+ return cv2.resize(image, (w_new, h_new), interpolation=mode), scale
+
+
+def load_image(path: Path, resize: Optional[int] = None, **kwargs) -> torch.Tensor:
+ image = read_image(path)
+ if resize is not None:
+ image, _ = resize_image(image, resize, **kwargs)
+ return numpy_image_to_torch(image)
+
+
+class Extractor(torch.nn.Module):
+ def __init__(self, **conf):
+ super().__init__()
+ self.conf = SimpleNamespace(**{**self.default_conf, **conf})
+
+ @torch.no_grad()
+ def extract(self, img: torch.Tensor, **conf) -> dict:
+ """Perform extraction with online resizing"""
+ if img.dim() == 3:
+ img = img[None] # add batch dim
+ assert img.dim() == 4 and img.shape[0] == 1
+ shape = img.shape[-2:][::-1]
+ img, scales = ImagePreprocessor(**{**self.preprocess_conf, **conf})(img)
+ feats = self.forward({"image": img})
+ feats["image_size"] = torch.tensor(shape)[None].to(img).float()
+ feats["keypoints"] = (feats["keypoints"] + 0.5) / scales[None] - 0.5
+ return feats
+
+
+def match_pair(
+ extractor,
+ matcher,
+ image0: torch.Tensor,
+ image1: torch.Tensor,
+ device: str = "cpu",
+ **preprocess,
+):
+ """Match a pair of images (image0, image1) with an extractor and matcher"""
+ feats0 = extractor.extract(image0, **preprocess)
+ feats1 = extractor.extract(image1, **preprocess)
+ matches01 = matcher({"image0": feats0, "image1": feats1})
+ data = [feats0, feats1, matches01]
+ # remove batch dim and move to target device
+ feats0, feats1, matches01 = [batch_to_device(rbd(x), device) for x in data]
+ return feats0, feats1, matches01
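+
+# A minimal usage sketch for the helpers above (illustrative, not part of the
+# upstream file). The SuperPoint/LightGlue imports assume the upstream
+# LightGlue package layout; the image paths are placeholders:
+#
+# from lightglue import LightGlue, SuperPoint
+#
+# extractor = SuperPoint(max_num_keypoints=2048).eval()
+# matcher = LightGlue(features="superpoint").eval()
+# image0 = load_image("path/to/image0.jpg")
+# image1 = load_image("path/to/image1.jpg")
+# feats0, feats1, matches01 = match_pair(extractor, matcher, image0, image1)
+# matches = matches01["matches"]  # (K, 2) indices into feats0/feats1 keypoints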
diff --git a/imcui/third_party/LightGlue/lightglue/viz2d.py b/imcui/third_party/LightGlue/lightglue/viz2d.py
new file mode 100644
index 0000000000000000000000000000000000000000..22dc3f65662181666b1ff57403af4ad18bfdc271
--- /dev/null
+++ b/imcui/third_party/LightGlue/lightglue/viz2d.py
@@ -0,0 +1,184 @@
+"""
+2D visualization primitives based on Matplotlib.
+1) Plot images with `plot_images`.
+2) Call `plot_keypoints` or `plot_matches` any number of times.
+3) Optionally: save a .png or .pdf plot (nice in papers!) with `save_plot`.
+"""
+
+import matplotlib
+import matplotlib.patheffects as path_effects
+import matplotlib.pyplot as plt
+import numpy as np
+import torch
+
+
+def cm_RdGn(x):
+ """Custom colormap: red (0) -> yellow (0.5) -> green (1)."""
+ x = np.clip(x, 0, 1)[..., None] * 2
+ c = x * np.array([[0, 1.0, 0]]) + (2 - x) * np.array([[1.0, 0, 0]])
+ return np.clip(c, 0, 1)
+
+
+def cm_BlRdGn(x_):
+ """Custom colormap: blue (-1) -> red (0.0) -> green (1)."""
+ x = np.clip(x_, 0, 1)[..., None] * 2
+ c = x * np.array([[0, 1.0, 0, 1.0]]) + (2 - x) * np.array([[1.0, 0, 0, 1.0]])
+
+ xn = -np.clip(x_, -1, 0)[..., None] * 2
+ cn = xn * np.array([[0, 0.1, 1, 1.0]]) + (2 - xn) * np.array([[1.0, 0, 0, 1.0]])
+ out = np.clip(np.where(x_[..., None] < 0, cn, c), 0, 1)
+ return out
+
+
+def cm_prune(x_):
+ """Custom colormap to visualize pruning"""
+ if isinstance(x_, torch.Tensor):
+ x_ = x_.cpu().numpy()
+ max_i = max(x_)
+ norm_x = np.where(x_ == max_i, -1, (x_ - 1) / 9)
+ return cm_BlRdGn(norm_x)
+
+
+def plot_images(imgs, titles=None, cmaps="gray", dpi=100, pad=0.5, adaptive=True):
+ """Plot a set of images horizontally.
+ Args:
+ imgs: list of NumPy RGB (H, W, 3) or PyTorch RGB (3, H, W) or mono (H, W).
+ titles: a list of strings, as titles for each image.
+ cmaps: colormaps for monochrome images.
+ adaptive: whether the figure size should fit the image aspect ratios.
+ """
+ # conversion to (H, W, 3) for torch.Tensor
+ imgs = [
+ img.permute(1, 2, 0).cpu().numpy()
+ if (isinstance(img, torch.Tensor) and img.dim() == 3)
+ else img
+ for img in imgs
+ ]
+
+ n = len(imgs)
+ if not isinstance(cmaps, (list, tuple)):
+ cmaps = [cmaps] * n
+
+ if adaptive:
+ ratios = [i.shape[1] / i.shape[0] for i in imgs] # W / H
+ else:
+ ratios = [4 / 3] * n
+ figsize = [sum(ratios) * 4.5, 4.5]
+ fig, ax = plt.subplots(
+ 1, n, figsize=figsize, dpi=dpi, gridspec_kw={"width_ratios": ratios}
+ )
+ if n == 1:
+ ax = [ax]
+ for i in range(n):
+ ax[i].imshow(imgs[i], cmap=plt.get_cmap(cmaps[i]))
+ ax[i].get_yaxis().set_ticks([])
+ ax[i].get_xaxis().set_ticks([])
+ ax[i].set_axis_off()
+ for spine in ax[i].spines.values(): # remove frame
+ spine.set_visible(False)
+ if titles:
+ ax[i].set_title(titles[i])
+ fig.tight_layout(pad=pad)
+
+
+def plot_keypoints(kpts, colors="lime", ps=4, axes=None, a=1.0):
+ """Plot keypoints for existing images.
+ Args:
+ kpts: list of ndarrays of size (N, 2).
+ colors: string, or list of list of tuples (one for each keypoint).
+ ps: size of the keypoints as float.
+ """
+ if not isinstance(colors, list):
+ colors = [colors] * len(kpts)
+ if not isinstance(a, list):
+ a = [a] * len(kpts)
+ if axes is None:
+ axes = plt.gcf().axes
+ for ax, k, c, alpha in zip(axes, kpts, colors, a):
+ if isinstance(k, torch.Tensor):
+ k = k.cpu().numpy()
+ ax.scatter(k[:, 0], k[:, 1], c=c, s=ps, linewidths=0, alpha=alpha)
+
+
+def plot_matches(kpts0, kpts1, color=None, lw=1.5, ps=4, a=1.0, labels=None, axes=None):
+ """Plot matches for a pair of existing images.
+ Args:
+ kpts0, kpts1: corresponding keypoints of size (N, 2).
+ color: color of each match, string or RGB tuple. Random if not given.
+ lw: width of the lines.
+ ps: size of the end points (no endpoint if ps=0)
+ axes: the two axes to draw the matches on; defaults to the current figure's axes.
+ a: alpha opacity of the match lines.
+ """
+ fig = plt.gcf()
+ if axes is None:
+ ax = fig.axes
+ ax0, ax1 = ax[0], ax[1]
+ else:
+ ax0, ax1 = axes
+ if isinstance(kpts0, torch.Tensor):
+ kpts0 = kpts0.cpu().numpy()
+ if isinstance(kpts1, torch.Tensor):
+ kpts1 = kpts1.cpu().numpy()
+ assert len(kpts0) == len(kpts1)
+ if color is None:
+ color = matplotlib.cm.hsv(np.random.rand(len(kpts0))).tolist()
+ elif len(color) > 0 and not isinstance(color[0], (tuple, list)):
+ color = [color] * len(kpts0)
+
+ if lw > 0:
+ for i in range(len(kpts0)):
+ line = matplotlib.patches.ConnectionPatch(
+ xyA=(kpts0[i, 0], kpts0[i, 1]),
+ xyB=(kpts1[i, 0], kpts1[i, 1]),
+ coordsA=ax0.transData,
+ coordsB=ax1.transData,
+ axesA=ax0,
+ axesB=ax1,
+ zorder=1,
+ color=color[i],
+ linewidth=lw,
+ clip_on=True,
+ alpha=a,
+ label=None if labels is None else labels[i],
+ picker=5.0,
+ )
+ line.set_annotation_clip(True)
+ fig.add_artist(line)
+
+ # freeze the axes to prevent the transform from changing
+ ax0.autoscale(enable=False)
+ ax1.autoscale(enable=False)
+
+ if ps > 0:
+ ax0.scatter(kpts0[:, 0], kpts0[:, 1], c=color, s=ps)
+ ax1.scatter(kpts1[:, 0], kpts1[:, 1], c=color, s=ps)
+
+
+def add_text(
+ idx,
+ text,
+ pos=(0.01, 0.99),
+ fs=15,
+ color="w",
+ lcolor="k",
+ lwidth=2,
+ ha="left",
+ va="top",
+):
+ ax = plt.gcf().axes[idx]
+ t = ax.text(
+ *pos, text, fontsize=fs, ha=ha, va=va, color=color, transform=ax.transAxes
+ )
+ if lcolor is not None:
+ t.set_path_effects(
+ [
+ path_effects.Stroke(linewidth=lwidth, foreground=lcolor),
+ path_effects.Normal(),
+ ]
+ )
+
+
+def save_plot(path, **kw):
+ """Save the current figure without any white margin."""
+ plt.savefig(path, bbox_inches="tight", pad_inches=0, **kw)
diff --git a/imcui/third_party/RoMa/demo/demo_3D_effect.py b/imcui/third_party/RoMa/demo/demo_3D_effect.py
new file mode 100644
index 0000000000000000000000000000000000000000..ae26caaf92deb884dfabb6eca96aec3406325c3f
--- /dev/null
+++ b/imcui/third_party/RoMa/demo/demo_3D_effect.py
@@ -0,0 +1,47 @@
+from PIL import Image
+import torch
+import torch.nn.functional as F
+import numpy as np
+from romatch.utils.utils import tensor_to_pil
+
+from romatch import roma_outdoor
+
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+if torch.backends.mps.is_available():
+ device = torch.device('mps')
+
+if __name__ == "__main__":
+ from argparse import ArgumentParser
+ parser = ArgumentParser()
+ parser.add_argument("--im_A_path", default="assets/toronto_A.jpg", type=str)
+ parser.add_argument("--im_B_path", default="assets/toronto_B.jpg", type=str)
+ parser.add_argument("--save_path", default="demo/gif/roma_warp_toronto", type=str)
+
+ args, _ = parser.parse_known_args()
+ im1_path = args.im_A_path
+ im2_path = args.im_B_path
+ save_path = args.save_path
+
+ # Create model
+ roma_model = roma_outdoor(device=device, coarse_res=560, upsample_res=(864, 1152))
+ roma_model.symmetric = False
+
+ H, W = roma_model.get_output_resolution()
+
+ im1 = Image.open(im1_path).resize((W, H))
+ im2 = Image.open(im2_path).resize((W, H))
+
+ # Match
+ warp, certainty = roma_model.match(im1_path, im2_path, device=device)
+ # Sampling not needed, but can be done with model.sample(warp, certainty)
+ x1 = (torch.tensor(np.array(im1)) / 255).to(device).permute(2, 0, 1)
+ x2 = (torch.tensor(np.array(im2)) / 255).to(device).permute(2, 0, 1)
+
+ coords_A, coords_B = warp[...,:2], warp[...,2:]
+ for i, x in enumerate(np.linspace(0,2*np.pi,200)):
+ t = (1 + np.cos(x))/2
+ interp_warp = (1-t)*coords_A + t*coords_B
+ im2_transfer_rgb = F.grid_sample(
+ x2[None], interp_warp[None], mode="bilinear", align_corners=False
+ )[0]
+ tensor_to_pil(im2_transfer_rgb, unnormalize=False).save(f"{save_path}_{i:03d}.jpg")
\ No newline at end of file
diff --git a/imcui/third_party/RoMa/demo/demo_fundamental.py b/imcui/third_party/RoMa/demo/demo_fundamental.py
new file mode 100644
index 0000000000000000000000000000000000000000..65ea9ccb76525da3e88e4f426bdebdc4fe742161
--- /dev/null
+++ b/imcui/third_party/RoMa/demo/demo_fundamental.py
@@ -0,0 +1,34 @@
+from PIL import Image
+import torch
+import cv2
+from romatch import roma_outdoor
+
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+if torch.backends.mps.is_available():
+ device = torch.device('mps')
+
+if __name__ == "__main__":
+ from argparse import ArgumentParser
+ parser = ArgumentParser()
+ parser.add_argument("--im_A_path", default="assets/sacre_coeur_A.jpg", type=str)
+ parser.add_argument("--im_B_path", default="assets/sacre_coeur_B.jpg", type=str)
+
+ args, _ = parser.parse_known_args()
+ im1_path = args.im_A_path
+ im2_path = args.im_B_path
+
+ # Create model
+ roma_model = roma_outdoor(device=device)
+
+
+ W_A, H_A = Image.open(im1_path).size
+ W_B, H_B = Image.open(im2_path).size
+
+ # Match
+ warp, certainty = roma_model.match(im1_path, im2_path, device=device)
+ # Sample matches for estimation
+ matches, certainty = roma_model.sample(warp, certainty)
+ kpts1, kpts2 = roma_model.to_pixel_coordinates(matches, H_A, W_A, H_B, W_B)
+ F, mask = cv2.findFundamentalMat(
+ kpts1.cpu().numpy(), kpts2.cpu().numpy(), ransacReprojThreshold=0.2, method=cv2.USAC_MAGSAC, confidence=0.999999, maxIters=10000
+ )
\ No newline at end of file
diff --git a/imcui/third_party/RoMa/demo/demo_match.py b/imcui/third_party/RoMa/demo/demo_match.py
new file mode 100644
index 0000000000000000000000000000000000000000..582767e19d8b50c6c241ea32f81cabb38f52fce2
--- /dev/null
+++ b/imcui/third_party/RoMa/demo/demo_match.py
@@ -0,0 +1,50 @@
+import os
+os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'
+import torch
+from PIL import Image
+import torch.nn.functional as F
+import numpy as np
+from romatch.utils.utils import tensor_to_pil
+
+from romatch import roma_outdoor
+
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+if torch.backends.mps.is_available():
+ device = torch.device('mps')
+
+if __name__ == "__main__":
+ from argparse import ArgumentParser
+ parser = ArgumentParser()
+ parser.add_argument("--im_A_path", default="assets/toronto_A.jpg", type=str)
+ parser.add_argument("--im_B_path", default="assets/toronto_B.jpg", type=str)
+ parser.add_argument("--save_path", default="demo/roma_warp_toronto.jpg", type=str)
+
+ args, _ = parser.parse_known_args()
+ im1_path = args.im_A_path
+ im2_path = args.im_B_path
+ save_path = args.save_path
+
+ # Create model
+ roma_model = roma_outdoor(device=device, coarse_res=560, upsample_res=(864, 1152))
+
+ H, W = roma_model.get_output_resolution()
+
+ im1 = Image.open(im1_path).resize((W, H))
+ im2 = Image.open(im2_path).resize((W, H))
+
+ # Match
+ warp, certainty = roma_model.match(im1_path, im2_path, device=device)
+ # Sampling not needed, but can be done with model.sample(warp, certainty)
+ x1 = (torch.tensor(np.array(im1)) / 255).to(device).permute(2, 0, 1)
+ x2 = (torch.tensor(np.array(im2)) / 255).to(device).permute(2, 0, 1)
+
+ im2_transfer_rgb = F.grid_sample(
+ x2[None], warp[:,:W, 2:][None], mode="bilinear", align_corners=False
+ )[0]
+ im1_transfer_rgb = F.grid_sample(
+ x1[None], warp[:, W:, :2][None], mode="bilinear", align_corners=False
+ )[0]
+ warp_im = torch.cat((im2_transfer_rgb,im1_transfer_rgb),dim=2)
+ white_im = torch.ones((H,2*W),device=device)
+ vis_im = certainty * warp_im + (1 - certainty) * white_im
+ tensor_to_pil(vis_im, unnormalize=False).save(save_path)
\ No newline at end of file
diff --git a/imcui/third_party/RoMa/demo/demo_match_opencv_sift.py b/imcui/third_party/RoMa/demo/demo_match_opencv_sift.py
new file mode 100644
index 0000000000000000000000000000000000000000..3196fcfaab248f6c4c6247a0afb4db745206aee8
--- /dev/null
+++ b/imcui/third_party/RoMa/demo/demo_match_opencv_sift.py
@@ -0,0 +1,43 @@
+from PIL import Image
+import numpy as np
+import cv2 as cv
+import matplotlib.pyplot as plt
+
+
+
+if __name__ == "__main__":
+ from argparse import ArgumentParser
+ parser = ArgumentParser()
+ parser.add_argument("--im_A_path", default="assets/toronto_A.jpg", type=str)
+ parser.add_argument("--im_B_path", default="assets/toronto_B.jpg", type=str)
+ parser.add_argument("--save_path", default="demo/roma_warp_toronto.jpg", type=str)
+
+ args, _ = parser.parse_known_args()
+ im1_path = args.im_A_path
+ im2_path = args.im_B_path
+ save_path = args.save_path
+
+ img1 = cv.imread(im1_path,cv.IMREAD_GRAYSCALE) # queryImage
+ img2 = cv.imread(im2_path,cv.IMREAD_GRAYSCALE) # trainImage
+ # Initiate SIFT detector
+ sift = cv.SIFT_create()
+ # find the keypoints and descriptors with SIFT
+ kp1, des1 = sift.detectAndCompute(img1,None)
+ kp2, des2 = sift.detectAndCompute(img2,None)
+ # BFMatcher with default params
+ bf = cv.BFMatcher()
+ matches = bf.knnMatch(des1,des2,k=2)
+ # Apply ratio test
+ good = []
+ for m,n in matches:
+ if m.distance < 0.75*n.distance:
+ good.append([m])
+ # cv.drawMatchesKnn expects list of lists as matches.
+ draw_params = dict(matchColor = (255,0,0), # draw matches in red color
+ singlePointColor = None,
+ flags = 2)
+
+ img3 = cv.drawMatchesKnn(img1,kp1,img2,kp2,good,None,**draw_params)
+ Image.fromarray(img3).save("demo/sift_matches.png")
diff --git a/imcui/third_party/RoMa/demo/demo_match_tiny.py b/imcui/third_party/RoMa/demo/demo_match_tiny.py
new file mode 100644
index 0000000000000000000000000000000000000000..b8e66a4b80a2361e22673ddc59632f48ad653b69
--- /dev/null
+++ b/imcui/third_party/RoMa/demo/demo_match_tiny.py
@@ -0,0 +1,77 @@
+import os
+os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'
+import torch
+from PIL import Image
+import torch.nn.functional as F
+import numpy as np
+from romatch.utils.utils import tensor_to_pil
+
+from romatch import tiny_roma_v1_outdoor
+
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+if torch.backends.mps.is_available():
+ device = torch.device('mps')
+
+if __name__ == "__main__":
+ from argparse import ArgumentParser
+ parser = ArgumentParser()
+ parser.add_argument("--im_A_path", default="assets/sacre_coeur_A.jpg", type=str)
+ parser.add_argument("--im_B_path", default="assets/sacre_coeur_B.jpg", type=str)
+ parser.add_argument("--save_A_path", default="demo/tiny_roma_warp_A.jpg", type=str)
+ parser.add_argument("--save_B_path", default="demo/tiny_roma_warp_B.jpg", type=str)
+
+ args, _ = parser.parse_known_args()
+ im1_path = args.im_A_path
+ im2_path = args.im_B_path
+
+ # Create model
+ roma_model = tiny_roma_v1_outdoor(device=device)
+
+ # Match
+ warp, certainty1 = roma_model.match(im1_path, im2_path)
+
+ h1, w1 = warp.shape[:2]
+
+ # maybe im1.size != im2.size
+ im1 = Image.open(im1_path).resize((w1, h1))
+ im2 = Image.open(im2_path)
+ x1 = (torch.tensor(np.array(im1)) / 255).to(device).permute(2, 0, 1)
+ x2 = (torch.tensor(np.array(im2)) / 255).to(device).permute(2, 0, 1)
+
+ h2, w2 = x2.shape[1:]
+ g1_p2x = w2 / 2 * (warp[..., 2] + 1)
+ g1_p2y = h2 / 2 * (warp[..., 3] + 1)
+ g2_p1x = torch.zeros((h2, w2), dtype=torch.float32).to(device) - 2
+ g2_p1y = torch.zeros((h2, w2), dtype=torch.float32).to(device) - 2
+
+ x, y = torch.meshgrid(
+ torch.arange(w1, device=device),
+ torch.arange(h1, device=device),
+ indexing="xy",
+ )
+ g2x = torch.round(g1_p2x[y, x]).long()
+ g2y = torch.round(g1_p2y[y, x]).long()
+ idx_x = torch.bitwise_and(0 <= g2x, g2x < w2)
+ idx_y = torch.bitwise_and(0 <= g2y, g2y < h2)
+ idx = torch.bitwise_and(idx_x, idx_y)
+ g2_p1x[g2y[idx], g2x[idx]] = x[idx].float() * 2 / w1 - 1
+ g2_p1y[g2y[idx], g2x[idx]] = y[idx].float() * 2 / h1 - 1
+
+ certainty2 = F.grid_sample(
+ certainty1[None][None],
+ torch.stack([g2_p1x, g2_p1y], dim=2)[None],
+ mode="bilinear",
+ align_corners=False,
+ )[0]
+
+ white_im1 = torch.ones((h1, w1), device = device)
+ white_im2 = torch.ones((h2, w2), device = device)
+
+ certainty1 = F.avg_pool2d(certainty1[None], kernel_size=5, stride=1, padding=2)[0]
+ certainty2 = F.avg_pool2d(certainty2[None], kernel_size=5, stride=1, padding=2)[0]
+
+ vis_im1 = certainty1 * x1 + (1 - certainty1) * white_im1
+ vis_im2 = certainty2 * x2 + (1 - certainty2) * white_im2
+
+ tensor_to_pil(vis_im1, unnormalize=False).save(args.save_A_path)
+ tensor_to_pil(vis_im2, unnormalize=False).save(args.save_B_path)
\ No newline at end of file
diff --git a/imcui/third_party/RoMa/romatch/__init__.py b/imcui/third_party/RoMa/romatch/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..45c2aff7bb510b1d8dfcdc9c4d91ffc3aa5021f6
--- /dev/null
+++ b/imcui/third_party/RoMa/romatch/__init__.py
@@ -0,0 +1,8 @@
+import os
+from .models import roma_outdoor, tiny_roma_v1_outdoor, roma_indoor
+
+DEBUG_MODE = False
+RANK = int(os.environ.get('RANK', default = 0))
+GLOBAL_STEP = 0
+STEP_SIZE = 1
+LOCAL_RANK = -1
\ No newline at end of file
diff --git a/imcui/third_party/RoMa/romatch/benchmarks/__init__.py b/imcui/third_party/RoMa/romatch/benchmarks/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..af32a46ba4a48d719e3ad38f9b2355a13fe6cc44
--- /dev/null
+++ b/imcui/third_party/RoMa/romatch/benchmarks/__init__.py
@@ -0,0 +1,6 @@
+from .hpatches_sequences_homog_benchmark import HpatchesHomogBenchmark
+from .scannet_benchmark import ScanNetBenchmark
+from .megadepth_pose_estimation_benchmark import MegaDepthPoseEstimationBenchmark
+from .megadepth_dense_benchmark import MegadepthDenseBenchmark
+from .megadepth_pose_estimation_benchmark_poselib import Mega1500PoseLibBenchmark
+#from .scannet_benchmark_poselib import ScanNetPoselibBenchmark
\ No newline at end of file
diff --git a/imcui/third_party/RoMa/romatch/benchmarks/hpatches_sequences_homog_benchmark.py b/imcui/third_party/RoMa/romatch/benchmarks/hpatches_sequences_homog_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..5972361f80d4f4e5cafd8fd359c87c0433a0a5a5
--- /dev/null
+++ b/imcui/third_party/RoMa/romatch/benchmarks/hpatches_sequences_homog_benchmark.py
@@ -0,0 +1,113 @@
+from PIL import Image
+import numpy as np
+
+import os
+
+from tqdm import tqdm
+from romatch.utils import pose_auc
+import cv2
+
+
+class HpatchesHomogBenchmark:
+ """Hpatches grid goes from [0,n-1] instead of [0.5,n-0.5]"""
+
+ def __init__(self, dataset_path) -> None:
+ seqs_dir = "hpatches-sequences-release"
+ self.seqs_path = os.path.join(dataset_path, seqs_dir)
+ self.seq_names = sorted(os.listdir(self.seqs_path))
+ # Ignored seqs are the same as in LoFTR.
+ self.ignore_seqs = set(
+ [
+ "i_contruction",
+ "i_crownnight",
+ "i_dc",
+ "i_pencils",
+ "i_whitebuilding",
+ "v_artisans",
+ "v_astronautis",
+ "v_talent",
+ ]
+ )
+
+ def convert_coordinates(self, im_A_coords, im_A_to_im_B, wq, hq, wsup, hsup):
+ offset = 0.5 # Hpatches assumes that the center of the top-left pixel is at [0,0] (I think)
+ im_A_coords = (
+ np.stack(
+ (
+ wq * (im_A_coords[..., 0] + 1) / 2,
+ hq * (im_A_coords[..., 1] + 1) / 2,
+ ),
+ axis=-1,
+ )
+ - offset
+ )
+ im_A_to_im_B = (
+ np.stack(
+ (
+ wsup * (im_A_to_im_B[..., 0] + 1) / 2,
+ hsup * (im_A_to_im_B[..., 1] + 1) / 2,
+ ),
+ axis=-1,
+ )
+ - offset
+ )
+ return im_A_coords, im_A_to_im_B
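+
+ # Example (illustrative, not upstream): with wq = 640, a normalized
+ # x-coordinate of -1 maps to pixel -0.5 and +1 maps to 639.5, so integer
+ # positions 0..639 are pixel centers, matching the [0, n-1] grid noted in
+ # the class docstring.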
+
+ def benchmark(self, model, model_name = None):
+ n_matches = []
+ homog_dists = []
+ for seq_idx, seq_name in tqdm(
+ enumerate(self.seq_names), total=len(self.seq_names)
+ ):
+ im_A_path = os.path.join(self.seqs_path, seq_name, "1.ppm")
+ im_A = Image.open(im_A_path)
+ w1, h1 = im_A.size
+ for im_idx in range(2, 7):
+ im_B_path = os.path.join(self.seqs_path, seq_name, f"{im_idx}.ppm")
+ im_B = Image.open(im_B_path)
+ w2, h2 = im_B.size
+ H = np.loadtxt(
+ os.path.join(self.seqs_path, seq_name, "H_1_" + str(im_idx))
+ )
+ dense_matches, dense_certainty = model.match(
+ im_A_path, im_B_path
+ )
+ good_matches, _ = model.sample(dense_matches, dense_certainty, 5000)
+ pos_a, pos_b = self.convert_coordinates(
+ good_matches[:, :2], good_matches[:, 2:], w1, h1, w2, h2
+ )
+ try:
+ H_pred, inliers = cv2.findHomography(
+ pos_a,
+ pos_b,
+ method = cv2.RANSAC,
+ confidence = 0.99999,
+ ransacReprojThreshold = 3 * min(w2, h2) / 480,
+ )
+ except:
+ H_pred = None
+ if H_pred is None:
+ H_pred = np.zeros((3, 3))
+ H_pred[2, 2] = 1.0
+ corners = np.array(
+ [[0, 0, 1], [0, h1 - 1, 1], [w1 - 1, 0, 1], [w1 - 1, h1 - 1, 1]]
+ )
+ real_warped_corners = np.dot(corners, np.transpose(H))
+ real_warped_corners = (
+ real_warped_corners[:, :2] / real_warped_corners[:, 2:]
+ )
+ warped_corners = np.dot(corners, np.transpose(H_pred))
+ warped_corners = warped_corners[:, :2] / warped_corners[:, 2:]
+ mean_dist = np.mean(
+ np.linalg.norm(real_warped_corners - warped_corners, axis=1)
+ ) / (min(w2, h2) / 480.0)
+ homog_dists.append(mean_dist)
+
+ n_matches = np.array(n_matches)
+ thresholds = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
+ auc = pose_auc(np.array(homog_dists), thresholds)
+ return {
+ "hpatches_homog_auc_3": auc[2],
+ "hpatches_homog_auc_5": auc[4],
+ "hpatches_homog_auc_10": auc[9],
+ }
diff --git a/imcui/third_party/RoMa/romatch/benchmarks/megadepth_dense_benchmark.py b/imcui/third_party/RoMa/romatch/benchmarks/megadepth_dense_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..09d0a297f2937afc609eed3a74aa0c3c4c7ccebc
--- /dev/null
+++ b/imcui/third_party/RoMa/romatch/benchmarks/megadepth_dense_benchmark.py
@@ -0,0 +1,106 @@
+import torch
+import numpy as np
+import tqdm
+from romatch.datasets import MegadepthBuilder
+from romatch.utils import warp_kpts
+from torch.utils.data import ConcatDataset
+import romatch
+
+class MegadepthDenseBenchmark:
+ def __init__(self, data_root="data/megadepth", h = 384, w = 512, num_samples = 2000) -> None:
+ mega = MegadepthBuilder(data_root=data_root)
+ self.dataset = ConcatDataset(
+ mega.build_scenes(split="test_loftr", ht=h, wt=w)
+ ) # fixed resolution of 384,512
+ self.num_samples = num_samples
+
+ def geometric_dist(self, depth1, depth2, T_1to2, K1, K2, dense_matches):
+ b, h1, w1, d = dense_matches.shape
+ with torch.no_grad():
+ x1 = dense_matches[..., :2].reshape(b, h1 * w1, 2)
+ mask, x2 = warp_kpts(
+ x1.double(),
+ depth1.double(),
+ depth2.double(),
+ T_1to2.double(),
+ K1.double(),
+ K2.double(),
+ )
+ x2 = torch.stack(
+ (w1 * (x2[..., 0] + 1) / 2, h1 * (x2[..., 1] + 1) / 2), dim=-1
+ )
+ prob = mask.float().reshape(b, h1, w1)
+ x2_hat = dense_matches[..., 2:]
+ x2_hat = torch.stack(
+ (w1 * (x2_hat[..., 0] + 1) / 2, h1 * (x2_hat[..., 1] + 1) / 2), dim=-1
+ )
+ gd = (x2_hat - x2.reshape(b, h1, w1, 2)).norm(dim=-1)
+ gd = gd[prob == 1]
+ pck_1 = (gd < 1.0).float().mean()
+ pck_3 = (gd < 3.0).float().mean()
+ pck_5 = (gd < 5.0).float().mean()
+ return gd, pck_1, pck_3, pck_5, prob
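+
+ # Note (illustrative, not upstream): pck_k is the fraction of valid
+ # correspondences (prob == 1) whose predicted match x2_hat lies within k
+ # pixels of the ground-truth warp x2; gd holds those per-point distances.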
+
+ def benchmark(self, model, batch_size=8):
+ model.train(False)
+ with torch.no_grad():
+ gd_tot = 0.0
+ pck_1_tot = 0.0
+ pck_3_tot = 0.0
+ pck_5_tot = 0.0
+ sampler = torch.utils.data.WeightedRandomSampler(
+ torch.ones(len(self.dataset)), replacement=False, num_samples=self.num_samples
+ )
+ B = batch_size
+ dataloader = torch.utils.data.DataLoader(
+ self.dataset, batch_size=B, num_workers=batch_size, sampler=sampler
+ )
+ for idx, data in tqdm.tqdm(enumerate(dataloader), disable = romatch.RANK > 0):
+ im_A, im_B, depth1, depth2, T_1to2, K1, K2 = (
+ data["im_A"].cuda(),
+ data["im_B"].cuda(),
+ data["im_A_depth"].cuda(),
+ data["im_B_depth"].cuda(),
+ data["T_1to2"].cuda(),
+ data["K1"].cuda(),
+ data["K2"].cuda(),
+ )
+ matches, certainty = model.match(im_A, im_B, batched=True)
+ gd, pck_1, pck_3, pck_5, prob = self.geometric_dist(
+ depth1, depth2, T_1to2, K1, K2, matches
+ )
+ if romatch.DEBUG_MODE:
+ from romatch.utils.utils import tensor_to_pil
+ import torch.nn.functional as F
+ path = "vis"
+ H, W = model.get_output_resolution()
+ white_im = torch.ones((B,1,H,W),device="cuda")
+ im_B_transfer_rgb = F.grid_sample(
+ im_B.cuda(), matches[:,:,:W, 2:], mode="bilinear", align_corners=False
+ )
+ warp_im = im_B_transfer_rgb
+ c_b = certainty[:,None]#(certainty*0.9 + 0.1*torch.ones_like(certainty))[:,None]
+ vis_im = c_b * warp_im + (1 - c_b) * white_im
+ for b in range(B):
+ import os
+ os.makedirs(f"{path}/{model.name}/{idx}_{b}_{H}_{W}",exist_ok=True)
+ tensor_to_pil(vis_im[b], unnormalize=True).save(
+ f"{path}/{model.name}/{idx}_{b}_{H}_{W}/warp.jpg")
+ tensor_to_pil(im_A[b].cuda(), unnormalize=True).save(
+ f"{path}/{model.name}/{idx}_{b}_{H}_{W}/im_A.jpg")
+ tensor_to_pil(im_B[b].cuda(), unnormalize=True).save(
+ f"{path}/{model.name}/{idx}_{b}_{H}_{W}/im_B.jpg")
+
+
+ gd_tot, pck_1_tot, pck_3_tot, pck_5_tot = (
+ gd_tot + gd.mean(),
+ pck_1_tot + pck_1,
+ pck_3_tot + pck_3,
+ pck_5_tot + pck_5,
+ )
+ return {
+ "epe": gd_tot.item() / len(dataloader),
+ "mega_pck_1": pck_1_tot.item() / len(dataloader),
+ "mega_pck_3": pck_3_tot.item() / len(dataloader),
+ "mega_pck_5": pck_5_tot.item() / len(dataloader),
+ }
diff --git a/imcui/third_party/RoMa/romatch/benchmarks/megadepth_pose_estimation_benchmark.py b/imcui/third_party/RoMa/romatch/benchmarks/megadepth_pose_estimation_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..36f293f9556d919643f6f39156314a5b402d9082
--- /dev/null
+++ b/imcui/third_party/RoMa/romatch/benchmarks/megadepth_pose_estimation_benchmark.py
@@ -0,0 +1,118 @@
+import numpy as np
+import torch
+from romatch.utils import *
+from PIL import Image
+from tqdm import tqdm
+import torch.nn.functional as F
+import romatch
+import kornia.geometry.epipolar as kepi
+
+class MegaDepthPoseEstimationBenchmark:
+ def __init__(self, data_root="data/megadepth", scene_names = None) -> None:
+ if scene_names is None:
+ self.scene_names = [
+ "0015_0.1_0.3.npz",
+ "0015_0.3_0.5.npz",
+ "0022_0.1_0.3.npz",
+ "0022_0.3_0.5.npz",
+ "0022_0.5_0.7.npz",
+ ]
+ else:
+ self.scene_names = scene_names
+ self.scenes = [
+ np.load(f"{data_root}/{scene}", allow_pickle=True)
+ for scene in self.scene_names
+ ]
+ self.data_root = data_root
+
+ def benchmark(self, model, model_name = None):
+ with torch.no_grad():
+ data_root = self.data_root
+ tot_e_t, tot_e_R, tot_e_pose = [], [], []
+ thresholds = [5, 10, 20]
+ for scene_ind in range(len(self.scenes)):
+ import os
+ scene_name = os.path.splitext(self.scene_names[scene_ind])[0]
+ scene = self.scenes[scene_ind]
+ pairs = scene["pair_infos"]
+ intrinsics = scene["intrinsics"]
+ poses = scene["poses"]
+ im_paths = scene["image_paths"]
+ pair_inds = range(len(pairs))
+ for pairind in tqdm(pair_inds):
+ idx1, idx2 = pairs[pairind][0]
+ K1 = intrinsics[idx1].copy()
+ T1 = poses[idx1].copy()
+ R1, t1 = T1[:3, :3], T1[:3, 3]
+ K2 = intrinsics[idx2].copy()
+ T2 = poses[idx2].copy()
+ R2, t2 = T2[:3, :3], T2[:3, 3]
+ R, t = compute_relative_pose(R1, t1, R2, t2)
+ T1_to_2 = np.concatenate((R,t[:,None]), axis=-1)
+ im_A_path = f"{data_root}/{im_paths[idx1]}"
+ im_B_path = f"{data_root}/{im_paths[idx2]}"
+ dense_matches, dense_certainty = model.match(
+ im_A_path, im_B_path, K1.copy(), K2.copy(), T1_to_2.copy()
+ )
+ sparse_matches,_ = model.sample(
+ dense_matches, dense_certainty, 5_000
+ )
+
+ im_A = Image.open(im_A_path)
+ w1, h1 = im_A.size
+ im_B = Image.open(im_B_path)
+ w2, h2 = im_B.size
+ if True: # Note: we keep this true as it was used in DKM/RoMa papers. There is very little difference compared to setting to False.
+ scale1 = 1200 / max(w1, h1)
+ scale2 = 1200 / max(w2, h2)
+ w1, h1 = scale1 * w1, scale1 * h1
+ w2, h2 = scale2 * w2, scale2 * h2
+ K1, K2 = K1.copy(), K2.copy()
+ K1[:2] = K1[:2] * scale1
+ K2[:2] = K2[:2] * scale2
+
+ kpts1, kpts2 = model.to_pixel_coordinates(sparse_matches, h1, w1, h2, w2)
+ kpts1, kpts2 = kpts1.cpu().numpy(), kpts2.cpu().numpy()
+ for _ in range(5):
+ shuffling = np.random.permutation(np.arange(len(kpts1)))
+ kpts1 = kpts1[shuffling]
+ kpts2 = kpts2[shuffling]
+ try:
+ threshold = 0.5
+ norm_threshold = threshold / (np.mean(np.abs(K1[:2, :2])) + np.mean(np.abs(K2[:2, :2])))
+ R_est, t_est, mask = estimate_pose(
+ kpts1,
+ kpts2,
+ K1,
+ K2,
+ norm_threshold,
+ conf=0.99999,
+ )
+ T1_to_2_est = np.concatenate((R_est, t_est), axis=-1) #
+ e_t, e_R = compute_pose_error(T1_to_2_est, R, t)
+ e_pose = max(e_t, e_R)
+ except Exception as e:
+ print(repr(e))
+ e_t, e_R = 90, 90
+ e_pose = max(e_t, e_R)
+ tot_e_t.append(e_t)
+ tot_e_R.append(e_R)
+ tot_e_pose.append(e_pose)
+ tot_e_pose = np.array(tot_e_pose)
+ auc = pose_auc(tot_e_pose, thresholds)
+ acc_5 = (tot_e_pose < 5).mean()
+ acc_10 = (tot_e_pose < 10).mean()
+ acc_15 = (tot_e_pose < 15).mean()
+ acc_20 = (tot_e_pose < 20).mean()
+ map_5 = acc_5
+ map_10 = np.mean([acc_5, acc_10])
+ map_20 = np.mean([acc_5, acc_10, acc_15, acc_20])
+ print(f"{model_name} auc: {auc}")
+ return {
+ "auc_5": auc[0],
+ "auc_10": auc[1],
+ "auc_20": auc[2],
+ "map_5": map_5,
+ "map_10": map_10,
+ "map_20": map_20,
+ }
diff --git a/imcui/third_party/RoMa/romatch/benchmarks/megadepth_pose_estimation_benchmark_poselib.py b/imcui/third_party/RoMa/romatch/benchmarks/megadepth_pose_estimation_benchmark_poselib.py
new file mode 100644
index 0000000000000000000000000000000000000000..4732ccf2af5b50e6db60831d7c63c5bf70ec727c
--- /dev/null
+++ b/imcui/third_party/RoMa/romatch/benchmarks/megadepth_pose_estimation_benchmark_poselib.py
@@ -0,0 +1,119 @@
+import numpy as np
+import torch
+from romatch.utils import *
+from PIL import Image
+from tqdm import tqdm
+import torch.nn.functional as F
+import romatch
+import kornia.geometry.epipolar as kepi
+
+# wrapped because poselib is still in development
+# will be added to the dependencies later
+import poselib
+
+class Mega1500PoseLibBenchmark:
+ def __init__(self, data_root="data/megadepth", scene_names = None, num_ransac_iter = 5, test_every = 1) -> None:
+ if scene_names is None:
+ self.scene_names = [
+ "0015_0.1_0.3.npz",
+ "0015_0.3_0.5.npz",
+ "0022_0.1_0.3.npz",
+ "0022_0.3_0.5.npz",
+ "0022_0.5_0.7.npz",
+ ]
+ else:
+ self.scene_names = scene_names
+ self.scenes = [
+ np.load(f"{data_root}/{scene}", allow_pickle=True)
+ for scene in self.scene_names
+ ]
+ self.data_root = data_root
+ self.num_ransac_iter = num_ransac_iter
+ self.test_every = test_every
+
+ def benchmark(self, model, model_name = None):
+ with torch.no_grad():
+ data_root = self.data_root
+ tot_e_t, tot_e_R, tot_e_pose = [], [], []
+ thresholds = [5, 10, 20]
+ for scene_ind in range(len(self.scenes)):
+ import os
+ scene_name = os.path.splitext(self.scene_names[scene_ind])[0]
+ scene = self.scenes[scene_ind]
+ pairs = scene["pair_infos"]
+ intrinsics = scene["intrinsics"]
+ poses = scene["poses"]
+ im_paths = scene["image_paths"]
+ pair_inds = range(len(pairs))[::self.test_every]
+ for pairind in (pbar := tqdm(pair_inds, desc = "Current AUC: ?")):
+ idx1, idx2 = pairs[pairind][0]
+ K1 = intrinsics[idx1].copy()
+ T1 = poses[idx1].copy()
+ R1, t1 = T1[:3, :3], T1[:3, 3]
+ K2 = intrinsics[idx2].copy()
+ T2 = poses[idx2].copy()
+ R2, t2 = T2[:3, :3], T2[:3, 3]
+ R, t = compute_relative_pose(R1, t1, R2, t2)
+ T1_to_2 = np.concatenate((R,t[:,None]), axis=-1)
+ im_A_path = f"{data_root}/{im_paths[idx1]}"
+ im_B_path = f"{data_root}/{im_paths[idx2]}"
+ dense_matches, dense_certainty = model.match(
+ im_A_path, im_B_path, K1.copy(), K2.copy(), T1_to_2.copy()
+ )
+ sparse_matches,_ = model.sample(
+ dense_matches, dense_certainty, 5_000
+ )
+
+ im_A = Image.open(im_A_path)
+ w1, h1 = im_A.size
+ im_B = Image.open(im_B_path)
+ w2, h2 = im_B.size
+ kpts1, kpts2 = model.to_pixel_coordinates(sparse_matches, h1, w1, h2, w2)
+ kpts1, kpts2 = kpts1.cpu().numpy(), kpts2.cpu().numpy()
+ for _ in range(self.num_ransac_iter):
+ shuffling = np.random.permutation(np.arange(len(kpts1)))
+ kpts1 = kpts1[shuffling]
+ kpts2 = kpts2[shuffling]
+ try:
+ threshold = 1
+ camera1 = {'model': 'PINHOLE', 'width': w1, 'height': h1, 'params': K1[[0,1,0,1], [0,1,2,2]]}
+ camera2 = {'model': 'PINHOLE', 'width': w2, 'height': h2, 'params': K2[[0,1,0,1], [0,1,2,2]]}
+ relpose, res = poselib.estimate_relative_pose(
+ kpts1,
+ kpts2,
+ camera1,
+ camera2,
+ ransac_opt = {"max_reproj_error": 2*threshold, "max_epipolar_error": threshold, "min_inliers": 8, "max_iterations": 10_000},
+ )
+ Rt_est = relpose.Rt
+ R_est, t_est = Rt_est[:3,:3], Rt_est[:3,3:]
+ mask = np.array(res['inliers']).astype(np.float32)
+ T1_to_2_est = np.concatenate((R_est, t_est), axis=-1) #
+ e_t, e_R = compute_pose_error(T1_to_2_est, R, t)
+ e_pose = max(e_t, e_R)
+ except Exception as e:
+ print(repr(e))
+ e_t, e_R = 90, 90
+ e_pose = max(e_t, e_R)
+ tot_e_t.append(e_t)
+ tot_e_R.append(e_R)
+ tot_e_pose.append(e_pose)
+ pbar.set_description(f"Current AUC: {pose_auc(tot_e_pose, thresholds)}")
+ tot_e_pose = np.array(tot_e_pose)
+ auc = pose_auc(tot_e_pose, thresholds)
+ acc_5 = (tot_e_pose < 5).mean()
+ acc_10 = (tot_e_pose < 10).mean()
+ acc_15 = (tot_e_pose < 15).mean()
+ acc_20 = (tot_e_pose < 20).mean()
+ map_5 = acc_5
+ map_10 = np.mean([acc_5, acc_10])
+ map_20 = np.mean([acc_5, acc_10, acc_15, acc_20])
+ print(f"{model_name} auc: {auc}")
+ return {
+ "auc_5": auc[0],
+ "auc_10": auc[1],
+ "auc_20": auc[2],
+ "map_5": map_5,
+ "map_10": map_10,
+ "map_20": map_20,
+ }
diff --git a/imcui/third_party/RoMa/romatch/benchmarks/scannet_benchmark.py b/imcui/third_party/RoMa/romatch/benchmarks/scannet_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..8c6b3c0eeb1c211d224edde84974529c66b52460
--- /dev/null
+++ b/imcui/third_party/RoMa/romatch/benchmarks/scannet_benchmark.py
@@ -0,0 +1,143 @@
+import os.path as osp
+import numpy as np
+import torch
+from romatch.utils import *
+from PIL import Image
+from tqdm import tqdm
+
+
+class ScanNetBenchmark:
+ def __init__(self, data_root="data/scannet") -> None:
+ self.data_root = data_root
+
+ def benchmark(self, model, model_name = None):
+ model.train(False)
+ with torch.no_grad():
+ data_root = self.data_root
+ tmp = np.load(osp.join(data_root, "test.npz"))
+ pairs, rel_pose = tmp["name"], tmp["rel_pose"]
+ tot_e_t, tot_e_R, tot_e_pose = [], [], []
+ pair_inds = np.random.choice(
+ range(len(pairs)), size=len(pairs), replace=False
+ )
+ for pairind in tqdm(pair_inds, smoothing=0.9):
+ scene = pairs[pairind]
+ scene_name = f"scene0{scene[0]}_00"
+ im_A_path = osp.join(
+ self.data_root,
+ "scans_test",
+ scene_name,
+ "color",
+ f"{scene[2]}.jpg",
+ )
+ im_A = Image.open(im_A_path)
+ im_B_path = osp.join(
+ self.data_root,
+ "scans_test",
+ scene_name,
+ "color",
+ f"{scene[3]}.jpg",
+ )
+ im_B = Image.open(im_B_path)
+ T_gt = rel_pose[pairind].reshape(3, 4)
+ R, t = T_gt[:3, :3], T_gt[:3, 3]
+ K = np.stack(
+ [
+ np.array([float(i) for i in r.split()])
+ for r in open(
+ osp.join(
+ self.data_root,
+ "scans_test",
+ scene_name,
+ "intrinsic",
+ "intrinsic_color.txt",
+ ),
+ "r",
+ )
+ .read()
+ .split("\n")
+ if r
+ ]
+ )
+ w1, h1 = im_A.size
+ w2, h2 = im_B.size
+ K1 = K.copy()
+ K2 = K.copy()
+ dense_matches, dense_certainty = model.match(im_A_path, im_B_path)
+ sparse_matches, sparse_certainty = model.sample(
+ dense_matches, dense_certainty, 5000
+ )
+ scale1 = 480 / min(w1, h1)
+ scale2 = 480 / min(w2, h2)
+ w1, h1 = scale1 * w1, scale1 * h1
+ w2, h2 = scale2 * w2, scale2 * h2
+ K1 = K1 * scale1
+ K2 = K2 * scale2
+
+ offset = 0.5
+ kpts1 = sparse_matches[:, :2]
+ kpts1 = (
+ np.stack(
+ (
+ w1 * (kpts1[:, 0] + 1) / 2 - offset,
+ h1 * (kpts1[:, 1] + 1) / 2 - offset,
+ ),
+ axis=-1,
+ )
+ )
+ kpts2 = sparse_matches[:, 2:]
+ kpts2 = (
+ np.stack(
+ (
+ w2 * (kpts2[:, 0] + 1) / 2 - offset,
+ h2 * (kpts2[:, 1] + 1) / 2 - offset,
+ ),
+ axis=-1,
+ )
+ )
+ for _ in range(5):
+ shuffling = np.random.permutation(np.arange(len(kpts1)))
+ kpts1 = kpts1[shuffling]
+ kpts2 = kpts2[shuffling]
+ try:
+ norm_threshold = 0.5 / (
+ np.mean(np.abs(K1[:2, :2])) + np.mean(np.abs(K2[:2, :2])))
+ R_est, t_est, mask = estimate_pose(
+ kpts1,
+ kpts2,
+ K1,
+ K2,
+ norm_threshold,
+ conf=0.99999,
+ )
+ T1_to_2_est = np.concatenate((R_est, t_est), axis=-1) #
+ e_t, e_R = compute_pose_error(T1_to_2_est, R, t)
+ e_pose = max(e_t, e_R)
+ except Exception as e:
+ print(repr(e))
+ e_t, e_R = 90, 90
+ e_pose = max(e_t, e_R)
+ tot_e_t.append(e_t)
+ tot_e_R.append(e_R)
+ tot_e_pose.append(e_pose)
+ tot_e_t.append(e_t)
+ tot_e_R.append(e_R)
+ tot_e_pose.append(e_pose)
+ tot_e_pose = np.array(tot_e_pose)
+ thresholds = [5, 10, 20]
+ auc = pose_auc(tot_e_pose, thresholds)
+ acc_5 = (tot_e_pose < 5).mean()
+ acc_10 = (tot_e_pose < 10).mean()
+ acc_15 = (tot_e_pose < 15).mean()
+ acc_20 = (tot_e_pose < 20).mean()
+ map_5 = acc_5
+ map_10 = np.mean([acc_5, acc_10])
+ map_20 = np.mean([acc_5, acc_10, acc_15, acc_20])
+ return {
+ "auc_5": auc[0],
+ "auc_10": auc[1],
+ "auc_20": auc[2],
+ "map_5": map_5,
+ "map_10": map_10,
+ "map_20": map_20,
+ }
diff --git a/imcui/third_party/RoMa/romatch/checkpointing/__init__.py b/imcui/third_party/RoMa/romatch/checkpointing/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..22f5afe727aa6f6e8fffa9ecf5be69cbff686577
--- /dev/null
+++ b/imcui/third_party/RoMa/romatch/checkpointing/__init__.py
@@ -0,0 +1 @@
+from .checkpoint import CheckPoint
diff --git a/imcui/third_party/RoMa/romatch/checkpointing/checkpoint.py b/imcui/third_party/RoMa/romatch/checkpointing/checkpoint.py
new file mode 100644
index 0000000000000000000000000000000000000000..ab5a5131322241475323f1b992d9d3f7b21dbdac
--- /dev/null
+++ b/imcui/third_party/RoMa/romatch/checkpointing/checkpoint.py
@@ -0,0 +1,60 @@
+import os
+import torch
+from torch.nn.parallel.data_parallel import DataParallel
+from torch.nn.parallel.distributed import DistributedDataParallel
+from loguru import logger
+import gc
+
+import romatch
+
+class CheckPoint:
+ def __init__(self, dir=None, name="tmp"):
+ self.name = name
+ self.dir = dir
+ os.makedirs(self.dir, exist_ok=True)
+
+ def save(
+ self,
+ model,
+ optimizer,
+ lr_scheduler,
+ n,
+ ):
+ if romatch.RANK == 0:
+ assert model is not None
+ if isinstance(model, (DataParallel, DistributedDataParallel)):
+ model = model.module
+ states = {
+ "model": model.state_dict(),
+ "n": n,
+ "optimizer": optimizer.state_dict(),
+ "lr_scheduler": lr_scheduler.state_dict(),
+ }
+ torch.save(states, self.dir + self.name + f"_latest.pth")
+ logger.info(f"Saved states {list(states.keys())}, at step {n}")
+
+ def load(
+ self,
+ model,
+ optimizer,
+ lr_scheduler,
+ n,
+ ):
+ if os.path.exists(self.dir + self.name + f"_latest.pth") and romatch.RANK == 0:
+ states = torch.load(self.dir + self.name + f"_latest.pth")
+ if "model" in states:
+ model.load_state_dict(states["model"])
+ if "n" in states:
+ n = states["n"] if states["n"] else n
+ if "optimizer" in states:
+ try:
+ optimizer.load_state_dict(states["optimizer"])
+ except Exception as e:
+ print(f"Failed to load states for optimizer, with error {e}")
+ if "lr_scheduler" in states:
+ lr_scheduler.load_state_dict(states["lr_scheduler"])
+ print(f"Loaded states {list(states.keys())}, at step {n}")
+ del states
+ gc.collect()
+ torch.cuda.empty_cache()
+ return model, optimizer, lr_scheduler, n
\ No newline at end of file
diff --git a/imcui/third_party/RoMa/romatch/datasets/__init__.py b/imcui/third_party/RoMa/romatch/datasets/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b60c709926a4a7bd019b73eac10879063a996c90
--- /dev/null
+++ b/imcui/third_party/RoMa/romatch/datasets/__init__.py
@@ -0,0 +1,2 @@
+from .megadepth import MegadepthBuilder
+from .scannet import ScanNetBuilder
\ No newline at end of file
diff --git a/imcui/third_party/RoMa/romatch/datasets/megadepth.py b/imcui/third_party/RoMa/romatch/datasets/megadepth.py
new file mode 100644
index 0000000000000000000000000000000000000000..88f775ef412f7bd062cc9a1d67d95a030e7a15dd
--- /dev/null
+++ b/imcui/third_party/RoMa/romatch/datasets/megadepth.py
@@ -0,0 +1,232 @@
+import os
+from PIL import Image
+import h5py
+import numpy as np
+import torch
+import torchvision.transforms.functional as tvf
+import kornia.augmentation as K
+from romatch.utils import get_depth_tuple_transform_ops, get_tuple_transform_ops
+import romatch
+from romatch.utils import *
+import math
+
+class MegadepthScene:
+ def __init__(
+ self,
+ data_root,
+ scene_info,
+ ht=384,
+ wt=512,
+ min_overlap=0.0,
+ max_overlap=1.0,
+ shake_t=0,
+ rot_prob=0.0,
+ normalize=True,
+ max_num_pairs = 100_000,
+ scene_name = None,
+ use_horizontal_flip_aug = False,
+ use_single_horizontal_flip_aug = False,
+ colorjiggle_params = None,
+ random_eraser = None,
+ use_randaug = False,
+ randaug_params = None,
+ randomize_size = False,
+ ) -> None:
+ self.data_root = data_root
+ self.scene_name = os.path.splitext(scene_name)[0]+f"_{min_overlap}_{max_overlap}"
+ self.image_paths = scene_info["image_paths"]
+ self.depth_paths = scene_info["depth_paths"]
+ self.intrinsics = scene_info["intrinsics"]
+ self.poses = scene_info["poses"]
+ self.pairs = scene_info["pairs"]
+ self.overlaps = scene_info["overlaps"]
+ threshold = (self.overlaps > min_overlap) & (self.overlaps < max_overlap)
+ self.pairs = self.pairs[threshold]
+ self.overlaps = self.overlaps[threshold]
+ if len(self.pairs) > max_num_pairs:
+ pairinds = np.random.choice(
+ np.arange(0, len(self.pairs)), max_num_pairs, replace=False
+ )
+ self.pairs = self.pairs[pairinds]
+ self.overlaps = self.overlaps[pairinds]
+ if randomize_size:
+ area = ht * wt
+ s = int(16 * (math.sqrt(area)//16))
+ sizes = ((ht,wt), (s,s), (wt,ht))
+ choice = romatch.RANK % 3
+ ht, wt = sizes[choice]
+ # counts, bins = np.histogram(self.overlaps,20)
+ # print(counts)
+ self.im_transform_ops = get_tuple_transform_ops(
+ resize=(ht, wt), normalize=normalize, colorjiggle_params = colorjiggle_params,
+ )
+ self.depth_transform_ops = get_depth_tuple_transform_ops(
+ resize=(ht, wt)
+ )
+ self.wt, self.ht = wt, ht
+ self.shake_t = shake_t
+ self.random_eraser = random_eraser
+ if use_horizontal_flip_aug and use_single_horizontal_flip_aug:
+ raise ValueError("Can't both flip both images and only flip one")
+ self.use_horizontal_flip_aug = use_horizontal_flip_aug
+ self.use_single_horizontal_flip_aug = use_single_horizontal_flip_aug
+ self.use_randaug = use_randaug
+
+ def load_im(self, im_path):
+ im = Image.open(im_path)
+ return im
+
+ def horizontal_flip(self, im_A, im_B, depth_A, depth_B, K_A, K_B):
+ im_A = im_A.flip(-1)
+ im_B = im_B.flip(-1)
+ depth_A, depth_B = depth_A.flip(-1), depth_B.flip(-1)
+ flip_mat = torch.tensor([[-1, 0, self.wt],[0,1,0],[0,0,1.]]).to(K_A.device)
+ K_A = flip_mat@K_A
+ K_B = flip_mat@K_B
+
+ return im_A, im_B, depth_A, depth_B, K_A, K_B
+
+ def load_depth(self, depth_ref, crop=None):
+ depth = np.array(h5py.File(depth_ref, "r")["depth"])
+ return torch.from_numpy(depth)
+
+ def __len__(self):
+ return len(self.pairs)
+
+ def scale_intrinsic(self, K, wi, hi):
+ sx, sy = self.wt / wi, self.ht / hi
+ sK = torch.tensor([[sx, 0, 0], [0, sy, 0], [0, 0, 1]])
+ return sK @ K
+
+ def rand_shake(self, *things):
+ t = np.random.choice(range(-self.shake_t, self.shake_t + 1), size=2)
+ return [
+ tvf.affine(thing, angle=0.0, translate=list(t), scale=1.0, shear=[0.0, 0.0])
+ for thing in things
+ ], t
+
+ def __getitem__(self, pair_idx):
+ # read intrinsics of original size
+ idx1, idx2 = self.pairs[pair_idx]
+ K1 = torch.tensor(self.intrinsics[idx1].copy(), dtype=torch.float).reshape(3, 3)
+ K2 = torch.tensor(self.intrinsics[idx2].copy(), dtype=torch.float).reshape(3, 3)
+
+ # read and compute relative poses
+ T1 = self.poses[idx1]
+ T2 = self.poses[idx2]
+ T_1to2 = torch.tensor(np.matmul(T2, np.linalg.inv(T1)), dtype=torch.float)[
+ :4, :4
+ ] # (4, 4)
+
+ # Load positive pair data
+ im_A, im_B = self.image_paths[idx1], self.image_paths[idx2]
+ depth1, depth2 = self.depth_paths[idx1], self.depth_paths[idx2]
+ im_A_ref = os.path.join(self.data_root, im_A)
+ im_B_ref = os.path.join(self.data_root, im_B)
+ depth_A_ref = os.path.join(self.data_root, depth1)
+ depth_B_ref = os.path.join(self.data_root, depth2)
+ im_A = self.load_im(im_A_ref)
+ im_B = self.load_im(im_B_ref)
+ K1 = self.scale_intrinsic(K1, im_A.width, im_A.height)
+ K2 = self.scale_intrinsic(K2, im_B.width, im_B.height)
+
+ if self.use_randaug:
+ im_A, im_B = self.rand_augment(im_A, im_B)
+
+ depth_A = self.load_depth(depth_A_ref)
+ depth_B = self.load_depth(depth_B_ref)
+ # Process images
+ im_A, im_B = self.im_transform_ops((im_A, im_B))
+ depth_A, depth_B = self.depth_transform_ops(
+ (depth_A[None, None], depth_B[None, None])
+ )
+
+ [im_A, im_B, depth_A, depth_B], t = self.rand_shake(im_A, im_B, depth_A, depth_B)
+ K1[:2, 2] += t
+ K2[:2, 2] += t
+
+ im_A, im_B = im_A[None], im_B[None]
+ if self.random_eraser is not None:
+ im_A, depth_A = self.random_eraser(im_A, depth_A)
+ im_B, depth_B = self.random_eraser(im_B, depth_B)
+
+ if self.use_horizontal_flip_aug:
+ if np.random.rand() > 0.5:
+ im_A, im_B, depth_A, depth_B, K1, K2 = self.horizontal_flip(im_A, im_B, depth_A, depth_B, K1, K2)
+ if self.use_single_horizontal_flip_aug:
+ if np.random.rand() > 0.5:
+ im_B, depth_B, K2 = self.single_horizontal_flip(im_B, depth_B, K2)
+
+ if romatch.DEBUG_MODE:
+ tensor_to_pil(im_A[0], unnormalize=True).save(
+ f"vis/im_A.jpg")
+ tensor_to_pil(im_B[0], unnormalize=True).save(
+ f"vis/im_B.jpg")
+
+ data_dict = {
+ "im_A": im_A[0],
+ "im_A_identifier": self.image_paths[idx1].split("/")[-1].split(".jpg")[0],
+ "im_B": im_B[0],
+ "im_B_identifier": self.image_paths[idx2].split("/")[-1].split(".jpg")[0],
+ "im_A_depth": depth_A[0, 0],
+ "im_B_depth": depth_B[0, 0],
+ "K1": K1,
+ "K2": K2,
+ "T_1to2": T_1to2,
+ "im_A_path": im_A_ref,
+ "im_B_path": im_B_ref,
+
+ }
+ return data_dict
+
+
+class MegadepthBuilder:
+ def __init__(self, data_root="data/megadepth", loftr_ignore=True, imc21_ignore = True) -> None:
+ self.data_root = data_root
+ self.scene_info_root = os.path.join(data_root, "prep_scene_info")
+ self.all_scenes = os.listdir(self.scene_info_root)
+ self.test_scenes = ["0017.npy", "0004.npy", "0048.npy", "0013.npy"]
+ # LoFTR did the D2-Net preprocessing differently than we did and ended up with more ignore scenes; these can optionally be ignored as well
+ self.loftr_ignore_scenes = set(['0121.npy', '0133.npy', '0168.npy', '0178.npy', '0229.npy', '0349.npy', '0412.npy', '0430.npy', '0443.npy', '1001.npy', '5014.npy', '5015.npy', '5016.npy'])
+ self.imc21_scenes = set(['0008.npy', '0019.npy', '0021.npy', '0024.npy', '0025.npy', '0032.npy', '0063.npy', '1589.npy'])
+ self.test_scenes_loftr = ["0015.npy", "0022.npy"]
+ self.loftr_ignore = loftr_ignore
+ self.imc21_ignore = imc21_ignore
+
+ def build_scenes(self, split="train", min_overlap=0.0, scene_names = None, **kwargs):
+ if split == "train":
+ scene_names = set(self.all_scenes) - set(self.test_scenes)
+ elif split == "train_loftr":
+ scene_names = set(self.all_scenes) - set(self.test_scenes_loftr)
+ elif split == "test":
+ scene_names = self.test_scenes
+ elif split == "test_loftr":
+ scene_names = self.test_scenes_loftr
+ elif split == "custom":
+ pass # use the scene_names provided by the caller
+ else:
+ raise ValueError(f"Split {split} not available")
+ scenes = []
+ for scene_name in scene_names:
+ if self.loftr_ignore and scene_name in self.loftr_ignore_scenes:
+ continue
+ if self.imc21_ignore and scene_name in self.imc21_scenes:
+ continue
+ if ".npy" not in scene_name:
+ continue
+ scene_info = np.load(
+ os.path.join(self.scene_info_root, scene_name), allow_pickle=True
+ ).item()
+ scenes.append(
+ MegadepthScene(
+ self.data_root, scene_info, min_overlap=min_overlap,scene_name = scene_name, **kwargs
+ )
+ )
+ return scenes
+
+ def weight_scenes(self, concat_dataset, alpha=0.5):
+ ns = []
+ for d in concat_dataset.datasets:
+ ns.append(len(d))
+ ws = torch.cat([torch.ones(n) / n**alpha for n in ns])
+ return ws
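+# Rough usage sketch for the classes above (paths and keyword values are placeholders, not fixed by this file):
+#
+#   from torch.utils.data import ConcatDataset
+#   builder = MegadepthBuilder(data_root="data/megadepth")
+#   scenes = builder.build_scenes(split="train_loftr", min_overlap=0.01, ht=384, wt=512, shake_t=32)
+#   dataset = ConcatDataset(scenes)
+#   weights = builder.weight_scenes(dataset, alpha=0.75)  # one weight per pair, down-weighting large scenes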
diff --git a/imcui/third_party/RoMa/romatch/datasets/scannet.py b/imcui/third_party/RoMa/romatch/datasets/scannet.py
new file mode 100644
index 0000000000000000000000000000000000000000..e03261557147cd3449c76576a5e5e22c0ae288e9
--- /dev/null
+++ b/imcui/third_party/RoMa/romatch/datasets/scannet.py
@@ -0,0 +1,160 @@
+import os
+import random
+from PIL import Image
+import cv2
+import h5py
+import numpy as np
+import torch
+from torch.utils.data import (
+ Dataset,
+ DataLoader,
+ ConcatDataset)
+
+import torchvision.transforms.functional as tvf
+import kornia.augmentation as K
+import os.path as osp
+import matplotlib.pyplot as plt
+import romatch
+from romatch.utils import get_depth_tuple_transform_ops, get_tuple_transform_ops
+from romatch.utils.transforms import GeometricSequential
+from tqdm import tqdm
+
+class ScanNetScene:
+ def __init__(self, data_root, scene_info, ht = 384, wt = 512, min_overlap=0., shake_t = 0, rot_prob=0.,
+ use_horizontal_flip_aug = False) -> None:
+ self.scene_root = osp.join(data_root,"scans","scans_train")
+ self.data_names = scene_info['name']
+ self.overlaps = scene_info['score']
+ # Only keep pairs whose frame indices are both multiples of 10
+ valid = (self.data_names[:,-2:] % 10).sum(axis=-1) == 0
+ self.overlaps = self.overlaps[valid]
+ self.data_names = self.data_names[valid]
+ if len(self.data_names) > 10000:
+ pairinds = np.random.choice(np.arange(0,len(self.data_names)),10000,replace=False)
+ self.data_names = self.data_names[pairinds]
+ self.overlaps = self.overlaps[pairinds]
+ self.im_transform_ops = get_tuple_transform_ops(resize=(ht, wt), normalize=True)
+ self.depth_transform_ops = get_depth_tuple_transform_ops(resize=(ht, wt), normalize=False)
+ self.wt, self.ht = wt, ht
+ self.shake_t = shake_t
+ self.H_generator = GeometricSequential(K.RandomAffine(degrees=90, p=rot_prob))
+ self.use_horizontal_flip_aug = use_horizontal_flip_aug
+
+ def load_im(self, im_B, crop=None):
+ im = Image.open(im_B)
+ return im
+
+ def load_depth(self, depth_ref, crop=None):
+ depth = cv2.imread(str(depth_ref), cv2.IMREAD_UNCHANGED)
+ depth = depth / 1000
+ depth = torch.from_numpy(depth).float() # (h, w)
+ return depth
+
+ def __len__(self):
+ return len(self.data_names)
+
+ def scale_intrinsic(self, K, wi, hi):
+ sx, sy = self.wt / wi, self.ht / hi
+ sK = torch.tensor([[sx, 0, 0],
+ [0, sy, 0],
+ [0, 0, 1]])
+ return sK@K
+
+ def horizontal_flip(self, im_A, im_B, depth_A, depth_B, K_A, K_B):
+ im_A = im_A.flip(-1)
+ im_B = im_B.flip(-1)
+ depth_A, depth_B = depth_A.flip(-1), depth_B.flip(-1)
+ flip_mat = torch.tensor([[-1, 0, self.wt],[0,1,0],[0,0,1.]]).to(K_A.device)
+ K_A = flip_mat@K_A
+ K_B = flip_mat@K_B
+
+ return im_A, im_B, depth_A, depth_B, K_A, K_B
+ def read_scannet_pose(self,path):
+ """ Read ScanNet's Camera2World pose and transform it to World2Camera.
+
+ Returns:
+ pose_w2c (np.ndarray): (4, 4)
+ """
+ cam2world = np.loadtxt(path, delimiter=' ')
+ world2cam = np.linalg.inv(cam2world)
+ return world2cam
+
+
+ def read_scannet_intrinsic(self,path):
+ """ Read ScanNet's intrinsic matrix and return the 3x3 matrix.
+ """
+ intrinsic = np.loadtxt(path, delimiter=' ')
+ return torch.tensor(intrinsic[:-1, :-1], dtype = torch.float)
+
+ def __getitem__(self, pair_idx):
+ # read intrinsics of original size
+ data_name = self.data_names[pair_idx]
+ scene_name, scene_sub_name, stem_name_1, stem_name_2 = data_name
+ scene_name = f'scene{scene_name:04d}_{scene_sub_name:02d}'
+
+ # read the intrinsic of depthmap
+ K1 = K2 = self.read_scannet_intrinsic(osp.join(self.scene_root,
+ scene_name,
+ 'intrinsic', 'intrinsic_color.txt')) # the depth K is not the same, but it doesn't really matter
+ # read and compute relative poses
+ T1 = self.read_scannet_pose(osp.join(self.scene_root,
+ scene_name,
+ 'pose', f'{stem_name_1}.txt'))
+ T2 = self.read_scannet_pose(osp.join(self.scene_root,
+ scene_name,
+ 'pose', f'{stem_name_2}.txt'))
+ T_1to2 = torch.tensor(np.matmul(T2, np.linalg.inv(T1)), dtype=torch.float)[:4, :4] # (4, 4)
+
+ # Load positive pair data
+ im_A_ref = os.path.join(self.scene_root, scene_name, 'color', f'{stem_name_1}.jpg')
+ im_B_ref = os.path.join(self.scene_root, scene_name, 'color', f'{stem_name_2}.jpg')
+ depth_A_ref = os.path.join(self.scene_root, scene_name, 'depth', f'{stem_name_1}.png')
+ depth_B_ref = os.path.join(self.scene_root, scene_name, 'depth', f'{stem_name_2}.png')
+
+ im_A = self.load_im(im_A_ref)
+ im_B = self.load_im(im_B_ref)
+ depth_A = self.load_depth(depth_A_ref)
+ depth_B = self.load_depth(depth_B_ref)
+
+ # Recompute camera intrinsic matrix due to the resize
+ K1 = self.scale_intrinsic(K1, im_A.width, im_A.height)
+ K2 = self.scale_intrinsic(K2, im_B.width, im_B.height)
+ # Process images
+ im_A, im_B = self.im_transform_ops((im_A, im_B))
+ depth_A, depth_B = self.depth_transform_ops((depth_A[None,None], depth_B[None,None]))
+ if self.use_horizontal_flip_aug:
+ if np.random.rand() > 0.5:
+ im_A, im_B, depth_A, depth_B, K1, K2 = self.horizontal_flip(im_A, im_B, depth_A, depth_B, K1, K2)
+
+ data_dict = {'im_A': im_A,
+ 'im_B': im_B,
+ 'im_A_depth': depth_A[0,0],
+ 'im_B_depth': depth_B[0,0],
+ 'K1': K1,
+ 'K2': K2,
+ 'T_1to2':T_1to2,
+ }
+ return data_dict
+
+
+class ScanNetBuilder:
+ def __init__(self, data_root = 'data/scannet') -> None:
+ self.data_root = data_root
+ self.scene_info_root = os.path.join(data_root,'scannet_indices')
+ self.all_scenes = os.listdir(self.scene_info_root)
+
+ def build_scenes(self, split = 'train', min_overlap=0., **kwargs):
+ # Note: split doesn't matter here as we always use same scannet_train scenes
+ scene_names = self.all_scenes
+ scenes = []
+ for scene_name in tqdm(scene_names, disable = romatch.RANK > 0):
+ scene_info = np.load(os.path.join(self.scene_info_root,scene_name), allow_pickle=True)
+ scenes.append(ScanNetScene(self.data_root, scene_info, min_overlap=min_overlap, **kwargs))
+ return scenes
+
+ def weight_scenes(self, concat_dataset, alpha=.5):
+ ns = []
+ for d in concat_dataset.datasets:
+ ns.append(len(d))
+ ws = torch.cat([torch.ones(n)/n**alpha for n in ns])
+ return ws
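+# Rough usage sketch (illustrative; the sampler settings below are placeholders, not fixed by this file):
+#
+#   from torch.utils.data import ConcatDataset, DataLoader, WeightedRandomSampler
+#   builder = ScanNetBuilder(data_root="data/scannet")
+#   scenes = builder.build_scenes(split="train", ht=384, wt=512)
+#   dataset = ConcatDataset(scenes)
+#   sampler = WeightedRandomSampler(builder.weight_scenes(dataset, alpha=0.75), num_samples=8000)
+#   loader = DataLoader(dataset, batch_size=8, sampler=sampler, num_workers=8)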
diff --git a/imcui/third_party/RoMa/romatch/losses/__init__.py b/imcui/third_party/RoMa/romatch/losses/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2e08abacfc0f83d7de0f2ddc0583766a80bf53cf
--- /dev/null
+++ b/imcui/third_party/RoMa/romatch/losses/__init__.py
@@ -0,0 +1 @@
+from .robust_loss import RobustLosses
\ No newline at end of file
diff --git a/imcui/third_party/RoMa/romatch/losses/robust_loss.py b/imcui/third_party/RoMa/romatch/losses/robust_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..80d430069666fabe2471ec7eda2fa6e9c996f041
--- /dev/null
+++ b/imcui/third_party/RoMa/romatch/losses/robust_loss.py
@@ -0,0 +1,161 @@
+from einops.einops import rearrange
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from romatch.utils.utils import get_gt_warp
+import wandb
+import romatch
+import math
+
+class RobustLosses(nn.Module):
+ def __init__(
+ self,
+ robust=False,
+ center_coords=False,
+ scale_normalize=False,
+ ce_weight=0.01,
+ local_loss=True,
+ local_dist=4.0, # indexed as local_dist[scale] in forward(), so pass a dict {scale: dist} when local masking is used
+ local_largest_scale=8,
+ smooth_mask = False,
+ depth_interpolation_mode = "bilinear",
+ mask_depth_loss = False,
+ relative_depth_error_threshold = 0.05,
+ alpha = 1.,
+ c = 1e-3,
+ ):
+ super().__init__()
+ self.robust = robust # measured in pixels
+ self.center_coords = center_coords
+ self.scale_normalize = scale_normalize
+ self.ce_weight = ce_weight
+ self.local_loss = local_loss
+ self.local_dist = local_dist
+ self.local_largest_scale = local_largest_scale
+ self.smooth_mask = smooth_mask
+ self.depth_interpolation_mode = depth_interpolation_mode
+ self.mask_depth_loss = mask_depth_loss
+ self.relative_depth_error_threshold = relative_depth_error_threshold
+ self.avg_overlap = dict()
+ self.alpha = alpha
+ self.c = c
+
+ def gm_cls_loss(self, x2, prob, scale_gm_cls, gm_certainty, scale):
+ with torch.no_grad():
+ B, C, H, W = scale_gm_cls.shape
+ device = x2.device
+ cls_res = round(math.sqrt(C))
+ G = torch.meshgrid(*[torch.linspace(-1+1/cls_res, 1 - 1/cls_res, steps = cls_res,device = device) for _ in range(2)], indexing='ij')
+ G = torch.stack((G[1], G[0]), dim = -1).reshape(C,2)
+ GT = (G[None,:,None,None,:]-x2[:,None]).norm(dim=-1).min(dim=1).indices
+ cls_loss = F.cross_entropy(scale_gm_cls, GT, reduction = 'none')[prob > 0.99]
+ certainty_loss = F.binary_cross_entropy_with_logits(gm_certainty[:,0], prob)
+ if not torch.any(cls_loss):
+ cls_loss = (certainty_loss * 0.0) # Prevent issues where prob is 0 everywhere
+
+ losses = {
+ f"gm_certainty_loss_{scale}": certainty_loss.mean(),
+ f"gm_cls_loss_{scale}": cls_loss.mean(),
+ }
+ wandb.log(losses, step = romatch.GLOBAL_STEP)
+ return losses
+
+ def delta_cls_loss(self, x2, prob, flow_pre_delta, delta_cls, certainty, scale, offset_scale):
+ with torch.no_grad():
+ B, C, H, W = delta_cls.shape
+ device = x2.device
+ cls_res = round(math.sqrt(C))
+ G = torch.meshgrid(*[torch.linspace(-1+1/cls_res, 1 - 1/cls_res, steps = cls_res,device = device) for _ in range(2)], indexing='ij')
+ G = torch.stack((G[1], G[0]), dim = -1).reshape(C,2) * offset_scale
+ GT = (G[None,:,None,None,:] + flow_pre_delta[:,None] - x2[:,None]).norm(dim=-1).min(dim=1).indices
+ cls_loss = F.cross_entropy(delta_cls, GT, reduction = 'none')[prob > 0.99]
+ certainty_loss = F.binary_cross_entropy_with_logits(certainty[:,0], prob)
+ if not torch.any(cls_loss):
+ cls_loss = (certainty_loss * 0.0) # Prevent issues where prob is 0 everywhere
+ losses = {
+ f"delta_certainty_loss_{scale}": certainty_loss.mean(),
+ f"delta_cls_loss_{scale}": cls_loss.mean(),
+ }
+ wandb.log(losses, step = romatch.GLOBAL_STEP)
+ return losses
+
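+ # The regression term below is a generalized Charbonnier penalty on the end-point error x:
+ # rho(x) = cs^a * ((x / cs)^2 + 1)^(a/2), with cs = c * scale.
+ # For x << cs it behaves like a scaled quadratic, rho(x) ~ cs^a + (a/2) * cs^(a-2) * x^2,
+ # while for x >> cs it grows like |x|^a, so large errors are down-weighted whenever a < 2.
+ # (Sketch of the intent only; a and c come from the alpha / c arguments above.)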
+ def regression_loss(self, x2, prob, flow, certainty, scale, eps=1e-8, mode = "delta"):
+ epe = (flow.permute(0,2,3,1) - x2).norm(dim=-1)
+ if scale == 1:
+ pck_05 = (epe[prob > 0.99] < 0.5 * (2/512)).float().mean()
+ wandb.log({"train_pck_05": pck_05}, step = romatch.GLOBAL_STEP)
+
+ ce_loss = F.binary_cross_entropy_with_logits(certainty[:, 0], prob)
+ a = self.alpha[scale] if isinstance(self.alpha, dict) else self.alpha
+ cs = self.c * scale
+ x = epe[prob > 0.99]
+ reg_loss = cs**a * ((x/(cs))**2 + 1**2)**(a/2)
+ if not torch.any(reg_loss):
+ reg_loss = (ce_loss * 0.0) # Prevent issues where prob is 0 everywhere
+ losses = {
+ f"{mode}_certainty_loss_{scale}": ce_loss.mean(),
+ f"{mode}_regression_loss_{scale}": reg_loss.mean(),
+ }
+ wandb.log(losses, step = romatch.GLOBAL_STEP)
+ return losses
+
+ def forward(self, corresps, batch):
+ scales = list(corresps.keys())
+ tot_loss = 0.0
+ # scale_weights due to differences in scale for regression gradients and classification gradients
+ scale_weights = {1:1, 2:1, 4:1, 8:1, 16:1}
+ for scale in scales:
+ scale_corresps = corresps[scale]
+ scale_certainty, flow_pre_delta, delta_cls, offset_scale, scale_gm_cls, scale_gm_certainty, flow, scale_gm_flow = (
+ scale_corresps["certainty"],
+ scale_corresps.get("flow_pre_delta"),
+ scale_corresps.get("delta_cls"),
+ scale_corresps.get("offset_scale"),
+ scale_corresps.get("gm_cls"),
+ scale_corresps.get("gm_certainty"),
+ scale_corresps["flow"],
+ scale_corresps.get("gm_flow"),
+
+ )
+ if flow_pre_delta is not None:
+ flow_pre_delta = rearrange(flow_pre_delta, "b d h w -> b h w d")
+ b, h, w, d = flow_pre_delta.shape
+ else:
+ # _ = 1
+ b, _, h, w = scale_certainty.shape
+ gt_warp, gt_prob = get_gt_warp(
+ batch["im_A_depth"],
+ batch["im_B_depth"],
+ batch["T_1to2"],
+ batch["K1"],
+ batch["K2"],
+ H=h,
+ W=w,
+ )
+ x2 = gt_warp.float()
+ prob = gt_prob
+
+ if self.local_largest_scale >= scale:
+ prob = prob * (
+ F.interpolate(prev_epe[:, None], size=(h, w), mode="nearest-exact")[:, 0]
+ < (2 / 512) * (self.local_dist[scale] * scale))
+
+ if scale_gm_cls is not None:
+ gm_cls_losses = self.gm_cls_loss(x2, prob, scale_gm_cls, scale_gm_certainty, scale)
+ gm_loss = self.ce_weight * gm_cls_losses[f"gm_certainty_loss_{scale}"] + gm_cls_losses[f"gm_cls_loss_{scale}"]
+ tot_loss = tot_loss + scale_weights[scale] * gm_loss
+ elif scale_gm_flow is not None:
+ gm_flow_losses = self.regression_loss(x2, prob, scale_gm_flow, scale_gm_certainty, scale, mode = "gm")
+ gm_loss = self.ce_weight * gm_flow_losses[f"gm_certainty_loss_{scale}"] + gm_flow_losses[f"gm_regression_loss_{scale}"]
+ tot_loss = tot_loss + scale_weights[scale] * gm_loss
+
+ if delta_cls is not None:
+ delta_cls_losses = self.delta_cls_loss(x2, prob, flow_pre_delta, delta_cls, scale_certainty, scale, offset_scale)
+ delta_cls_loss = self.ce_weight * delta_cls_losses[f"delta_certainty_loss_{scale}"] + delta_cls_losses[f"delta_cls_loss_{scale}"]
+ tot_loss = tot_loss + scale_weights[scale] * delta_cls_loss
+ else:
+ delta_regression_losses = self.regression_loss(x2, prob, flow, scale_certainty, scale)
+ reg_loss = self.ce_weight * delta_regression_losses[f"delta_certainty_loss_{scale}"] + delta_regression_losses[f"delta_regression_loss_{scale}"]
+ tot_loss = tot_loss + scale_weights[scale] * reg_loss
+ prev_epe = (flow.permute(0,2,3,1) - x2).norm(dim=-1).detach()
+ return tot_loss
diff --git a/imcui/third_party/RoMa/romatch/losses/robust_loss_tiny_roma.py b/imcui/third_party/RoMa/romatch/losses/robust_loss_tiny_roma.py
new file mode 100644
index 0000000000000000000000000000000000000000..a17c24678b093ca843d16c1a17ea16f19fa594d5
--- /dev/null
+++ b/imcui/third_party/RoMa/romatch/losses/robust_loss_tiny_roma.py
@@ -0,0 +1,160 @@
+from einops.einops import rearrange
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from romatch.utils.utils import get_gt_warp
+import wandb
+import romatch
+import math
+
+# This is slightly different from regular romatch due to significantly worse corresps
+# The confidence loss is quite tricky here //Johan
+
+class RobustLosses(nn.Module):
+ def __init__(
+ self,
+ robust=False,
+ center_coords=False,
+ scale_normalize=False,
+ ce_weight=0.01,
+ local_loss=True,
+ local_dist=None,
+ smooth_mask = False,
+ depth_interpolation_mode = "bilinear",
+ mask_depth_loss = False,
+ relative_depth_error_threshold = 0.05,
+ alpha = 1.,
+ c = 1e-3,
+ epe_mask_prob_th = None,
+ cert_only_on_consistent_depth = False,
+ ):
+ super().__init__()
+ if local_dist is None:
+ local_dist = {}
+ self.robust = robust # measured in pixels
+ self.center_coords = center_coords
+ self.scale_normalize = scale_normalize
+ self.ce_weight = ce_weight
+ self.local_loss = local_loss
+ self.local_dist = local_dist
+ self.smooth_mask = smooth_mask
+ self.depth_interpolation_mode = depth_interpolation_mode
+ self.mask_depth_loss = mask_depth_loss
+ self.relative_depth_error_threshold = relative_depth_error_threshold
+ self.avg_overlap = dict()
+ self.alpha = alpha
+ self.c = c
+ self.epe_mask_prob_th = epe_mask_prob_th
+ self.cert_only_on_consistent_depth = cert_only_on_consistent_depth
+
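+ # Sketch of the correspondence-volume loss below: corr_volume has shape (B, H, W, H, W) and is
+ # flattened to (B, H*W, H*W). A symmetric InfoNCE-style negative log-likelihood is formed by
+ # log-softmaxing the temperature-scaled volume over both axes, and it is evaluated only at the
+ # mutual-nearest-neighbour indices mnn, whose rows are (batch_index, index_in_A, index_in_B).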
+ def corr_volume_loss(self, mnn:torch.Tensor, corr_volume:torch.Tensor, scale):
+ b, h,w, h,w = corr_volume.shape
+ inv_temp = 10
+ corr_volume = corr_volume.reshape(-1, h*w, h*w)
+ nll = -(inv_temp*corr_volume).log_softmax(dim = 1) - (inv_temp*corr_volume).log_softmax(dim = 2)
+ corr_volume_loss = nll[mnn[:,0], mnn[:,1], mnn[:,2]].mean()
+
+ losses = {
+ f"gm_corr_volume_loss_{scale}": corr_volume_loss.mean(),
+ }
+ wandb.log(losses, step = romatch.GLOBAL_STEP)
+ return losses
+
+
+
+ def regression_loss(self, x2, prob, flow, certainty, scale, eps=1e-8, mode = "delta"):
+ epe = (flow.permute(0,2,3,1) - x2).norm(dim=-1)
+ if scale in self.local_dist:
+ prob = prob * (epe < (2 / 512) * (self.local_dist[scale] * scale)).float()
+ if scale == 1:
+ pck_05 = (epe[prob > 0.99] < 0.5 * (2/512)).float().mean()
+ wandb.log({"train_pck_05": pck_05}, step = romatch.GLOBAL_STEP)
+ if self.epe_mask_prob_th is not None:
+ # if too far away from gt, certainty should be 0
+ gt_cert = prob * (epe < scale * self.epe_mask_prob_th)
+ else:
+ gt_cert = prob
+ if self.cert_only_on_consistent_depth:
+ ce_loss = F.binary_cross_entropy_with_logits(certainty[:, 0][prob > 0], gt_cert[prob > 0])
+ else:
+ ce_loss = F.binary_cross_entropy_with_logits(certainty[:, 0], gt_cert)
+ a = self.alpha[scale] if isinstance(self.alpha, dict) else self.alpha
+ cs = self.c * scale
+ x = epe[prob > 0.99]
+ reg_loss = cs**a * ((x/(cs))**2 + 1**2)**(a/2)
+ if not torch.any(reg_loss):
+ reg_loss = (ce_loss * 0.0) # Prevent issues where prob is 0 everywhere
+ losses = {
+ f"{mode}_certainty_loss_{scale}": ce_loss.mean(),
+ f"{mode}_regression_loss_{scale}": reg_loss.mean(),
+ }
+ wandb.log(losses, step = romatch.GLOBAL_STEP)
+ return losses
+
+ def forward(self, corresps, batch):
+ scales = list(corresps.keys())
+ tot_loss = 0.0
+ # unlike the full model, no per-scale weights are applied here; each scale contributes equally to tot_loss
+ for scale in scales:
+ scale_corresps = corresps[scale]
+ scale_certainty, flow_pre_delta, delta_cls, offset_scale, scale_gm_corr_volume, scale_gm_certainty, flow, scale_gm_flow = (
+ scale_corresps["certainty"],
+ scale_corresps.get("flow_pre_delta"),
+ scale_corresps.get("delta_cls"),
+ scale_corresps.get("offset_scale"),
+ scale_corresps.get("corr_volume"),
+ scale_corresps.get("gm_certainty"),
+ scale_corresps["flow"],
+ scale_corresps.get("gm_flow"),
+
+ )
+ if flow_pre_delta is not None:
+ flow_pre_delta = rearrange(flow_pre_delta, "b d h w -> b h w d")
+ b, h, w, d = flow_pre_delta.shape
+ else:
+ # _ = 1
+ b, _, h, w = scale_certainty.shape
+ gt_warp, gt_prob = get_gt_warp(
+ batch["im_A_depth"],
+ batch["im_B_depth"],
+ batch["T_1to2"],
+ batch["K1"],
+ batch["K2"],
+ H=h,
+ W=w,
+ )
+ x2 = gt_warp.float()
+ prob = gt_prob
+
+ if scale_gm_corr_volume is not None:
+ gt_warp_back, _ = get_gt_warp(
+ batch["im_B_depth"],
+ batch["im_A_depth"],
+ batch["T_1to2"].inverse(),
+ batch["K2"],
+ batch["K1"],
+ H=h,
+ W=w,
+ )
+ grid = torch.stack(torch.meshgrid(torch.linspace(-1+1/w, 1-1/w, w), torch.linspace(-1+1/h, 1-1/h, h), indexing='xy'), dim =-1).to(gt_warp.device)
+ #fwd_bck = F.grid_sample(gt_warp_back.permute(0,3,1,2), gt_warp, align_corners=False, mode = 'bilinear').permute(0,2,3,1)
+ #diff = (fwd_bck - grid).norm(dim = -1)
+ with torch.no_grad():
+ D_B = torch.cdist(gt_warp.float().reshape(-1,h*w,2), grid.reshape(-1,h*w,2))
+ D_A = torch.cdist(grid.reshape(-1,h*w,2), gt_warp_back.float().reshape(-1,h*w,2))
+ inds = torch.nonzero((D_B == D_B.min(dim=-1, keepdim = True).values)
+ * (D_A == D_A.min(dim=-2, keepdim = True).values)
+ * (D_B < 0.01)
+ * (D_A < 0.01))
+
+ gm_cls_losses = self.corr_volume_loss(inds, scale_gm_corr_volume, scale)
+ gm_loss = gm_cls_losses[f"gm_corr_volume_loss_{scale}"]
+ tot_loss = tot_loss + gm_loss
+ elif scale_gm_flow is not None:
+ gm_flow_losses = self.regression_loss(x2, prob, scale_gm_flow, scale_gm_certainty, scale, mode = "gm")
+ gm_loss = self.ce_weight * gm_flow_losses[f"gm_certainty_loss_{scale}"] + gm_flow_losses[f"gm_regression_loss_{scale}"]
+ tot_loss = tot_loss + gm_loss
+ delta_regression_losses = self.regression_loss(x2, prob, flow, scale_certainty, scale)
+ reg_loss = self.ce_weight * delta_regression_losses[f"delta_certainty_loss_{scale}"] + delta_regression_losses[f"delta_regression_loss_{scale}"]
+ tot_loss = tot_loss + reg_loss
+ return tot_loss
diff --git a/imcui/third_party/RoMa/romatch/models/__init__.py b/imcui/third_party/RoMa/romatch/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..7650c9c7480920905e27578f175fcb5f995cc8ba
--- /dev/null
+++ b/imcui/third_party/RoMa/romatch/models/__init__.py
@@ -0,0 +1 @@
+from .model_zoo import roma_outdoor, tiny_roma_v1_outdoor, roma_indoor
\ No newline at end of file
diff --git a/imcui/third_party/RoMa/romatch/models/encoders.py b/imcui/third_party/RoMa/romatch/models/encoders.py
new file mode 100644
index 0000000000000000000000000000000000000000..84fb54395139a2ca21860ce2c18d033ad0afb19f
--- /dev/null
+++ b/imcui/third_party/RoMa/romatch/models/encoders.py
@@ -0,0 +1,122 @@
+from typing import Optional, Union
+import torch
+from torch import device
+import torch.nn as nn
+import torch.nn.functional as F
+import torchvision.models as tvm
+import gc
+from romatch.utils.utils import get_autocast_params
+
+
+class ResNet50(nn.Module):
+ def __init__(self, pretrained=False, high_res = False, weights = None,
+ dilation = None, freeze_bn = True, anti_aliased = False, early_exit = False, amp = False, amp_dtype = torch.float16) -> None:
+ super().__init__()
+ if dilation is None:
+ dilation = [False,False,False]
+ if anti_aliased:
+ pass # NOTE: no anti-aliased backbone is bundled here, so self.net is left unset in this branch
+ else:
+ if weights is not None:
+ self.net = tvm.resnet50(weights = weights,replace_stride_with_dilation=dilation)
+ else:
+ self.net = tvm.resnet50(pretrained=pretrained,replace_stride_with_dilation=dilation)
+
+ self.high_res = high_res
+ self.freeze_bn = freeze_bn
+ self.early_exit = early_exit
+ self.amp = amp
+ self.amp_dtype = amp_dtype
+
+ def forward(self, x, **kwargs):
+ autocast_device, autocast_enabled, autocast_dtype = get_autocast_params(x.device, self.amp, self.amp_dtype)
+ with torch.autocast(autocast_device, enabled=autocast_enabled, dtype = autocast_dtype):
+ net = self.net
+ feats = {1:x}
+ x = net.conv1(x)
+ x = net.bn1(x)
+ x = net.relu(x)
+ feats[2] = x
+ x = net.maxpool(x)
+ x = net.layer1(x)
+ feats[4] = x
+ x = net.layer2(x)
+ feats[8] = x
+ if self.early_exit:
+ return feats
+ x = net.layer3(x)
+ feats[16] = x
+ x = net.layer4(x)
+ feats[32] = x
+ return feats
+
+ def train(self, mode=True):
+ super().train(mode)
+ if self.freeze_bn:
+ for m in self.modules():
+ if isinstance(m, nn.BatchNorm2d):
+ m.eval()
+ pass
+
+class VGG19(nn.Module):
+ def __init__(self, pretrained=False, amp = False, amp_dtype = torch.float16) -> None:
+ super().__init__()
+ self.layers = nn.ModuleList(tvm.vgg19_bn(pretrained=pretrained).features[:40])
+ self.amp = amp
+ self.amp_dtype = amp_dtype
+
+ def forward(self, x, **kwargs):
+ autocast_device, autocast_enabled, autocast_dtype = get_autocast_params(x.device, self.amp, self.amp_dtype)
+ with torch.autocast(device_type=autocast_device, enabled=autocast_enabled, dtype = autocast_dtype):
+ feats = {}
+ scale = 1
+ for layer in self.layers:
+ if isinstance(layer, nn.MaxPool2d):
+ feats[scale] = x
+ scale = scale*2
+ x = layer(x)
+ return feats
+
+class CNNandDinov2(nn.Module):
+ def __init__(self, cnn_kwargs = None, amp = False, use_vgg = False, dinov2_weights = None, amp_dtype = torch.float16):
+ super().__init__()
+ if dinov2_weights is None:
+ dinov2_weights = torch.hub.load_state_dict_from_url("https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_pretrain.pth", map_location="cpu")
+ from .transformer import vit_large
+ vit_kwargs = dict(img_size= 518,
+ patch_size= 14,
+ init_values = 1.0,
+ ffn_layer = "mlp",
+ block_chunks = 0,
+ )
+
+ dinov2_vitl14 = vit_large(**vit_kwargs).eval()
+ dinov2_vitl14.load_state_dict(dinov2_weights)
+ cnn_kwargs = cnn_kwargs if cnn_kwargs is not None else {}
+ if not use_vgg:
+ self.cnn = ResNet50(**cnn_kwargs)
+ else:
+ self.cnn = VGG19(**cnn_kwargs)
+ self.amp = amp
+ self.amp_dtype = amp_dtype
+ if self.amp:
+ dinov2_vitl14 = dinov2_vitl14.to(self.amp_dtype)
+ self.dinov2_vitl14 = [dinov2_vitl14] # ugly hack to not show parameters to DDP
+
+
+ def train(self, mode: bool = True):
+ return self.cnn.train(mode)
+
+ def forward(self, x, upsample = False):
+ B,C,H,W = x.shape
+ feature_pyramid = self.cnn(x)
+
+ if not upsample:
+ with torch.no_grad():
+ if self.dinov2_vitl14[0].device != x.device:
+ self.dinov2_vitl14[0] = self.dinov2_vitl14[0].to(x.device).to(self.amp_dtype)
+ dinov2_features_16 = self.dinov2_vitl14[0].forward_features(x.to(self.amp_dtype))
+ features_16 = dinov2_features_16['x_norm_patchtokens'].permute(0,2,1).reshape(B,1024,H//14, W//14)
+ del dinov2_features_16
+ feature_pyramid[16] = features_16
+ return feature_pyramid
\ No newline at end of file
diff --git a/imcui/third_party/RoMa/romatch/models/matcher.py b/imcui/third_party/RoMa/romatch/models/matcher.py
new file mode 100644
index 0000000000000000000000000000000000000000..8108a92393f4657e2c87d75c4072a46e7d61cdd6
--- /dev/null
+++ b/imcui/third_party/RoMa/romatch/models/matcher.py
@@ -0,0 +1,748 @@
+import os
+import math
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange
+import warnings
+from warnings import warn
+from PIL import Image
+
+from romatch.utils import get_tuple_transform_ops
+from romatch.utils.local_correlation import local_correlation
+from romatch.utils.utils import cls_to_flow_refine, get_autocast_params
+from romatch.utils.kde import kde
+
+class ConvRefiner(nn.Module):
+ def __init__(
+ self,
+ in_dim=6,
+ hidden_dim=16,
+ out_dim=2,
+ dw=False,
+ kernel_size=5,
+ hidden_blocks=3,
+ displacement_emb = None,
+ displacement_emb_dim = None,
+ local_corr_radius = None,
+ corr_in_other = None,
+ no_im_B_fm = False,
+ amp = False,
+ concat_logits = False,
+ use_bias_block_1 = True,
+ use_cosine_corr = False,
+ disable_local_corr_grad = False,
+ is_classifier = False,
+ sample_mode = "bilinear",
+ norm_type = nn.BatchNorm2d,
+ bn_momentum = 0.1,
+ amp_dtype = torch.float16,
+ ):
+ super().__init__()
+ self.bn_momentum = bn_momentum
+ self.block1 = self.create_block(
+ in_dim, hidden_dim, dw=dw, kernel_size=kernel_size, bias = use_bias_block_1,
+ )
+ self.hidden_blocks = nn.Sequential(
+ *[
+ self.create_block(
+ hidden_dim,
+ hidden_dim,
+ dw=dw,
+ kernel_size=kernel_size,
+ norm_type=norm_type,
+ )
+ for hb in range(hidden_blocks)
+ ]
+ )
+ self.out_conv = nn.Conv2d(hidden_dim, out_dim, 1, 1, 0)
+ if displacement_emb:
+ self.has_displacement_emb = True
+ self.disp_emb = nn.Conv2d(2,displacement_emb_dim,1,1,0)
+ else:
+ self.has_displacement_emb = False
+ self.local_corr_radius = local_corr_radius
+ self.corr_in_other = corr_in_other
+ self.no_im_B_fm = no_im_B_fm
+ self.amp = amp
+ self.concat_logits = concat_logits
+ self.use_cosine_corr = use_cosine_corr
+ self.disable_local_corr_grad = disable_local_corr_grad
+ self.is_classifier = is_classifier
+ self.sample_mode = sample_mode
+ self.amp_dtype = amp_dtype
+
+ def create_block(
+ self,
+ in_dim,
+ out_dim,
+ dw=False,
+ kernel_size=5,
+ bias = True,
+ norm_type = nn.BatchNorm2d,
+ ):
+ num_groups = 1 if not dw else in_dim
+ if dw:
+ assert (
+ out_dim % in_dim == 0
+ ), "outdim must be divisible by indim for depthwise"
+ conv1 = nn.Conv2d(
+ in_dim,
+ out_dim,
+ kernel_size=kernel_size,
+ stride=1,
+ padding=kernel_size // 2,
+ groups=num_groups,
+ bias=bias,
+ )
+ norm = norm_type(out_dim, momentum = self.bn_momentum) if norm_type is nn.BatchNorm2d else norm_type(num_channels = out_dim)
+ relu = nn.ReLU(inplace=True)
+ conv2 = nn.Conv2d(out_dim, out_dim, 1, 1, 0)
+ return nn.Sequential(conv1, norm, relu, conv2)
+
+ def forward(self, x, y, flow, scale_factor = 1, logits = None):
+ b,c,hs,ws = x.shape
+ autocast_device, autocast_enabled, autocast_dtype = get_autocast_params(x.device, enabled=self.amp, dtype=self.amp_dtype)
+ with torch.autocast(autocast_device, enabled=autocast_enabled, dtype = autocast_dtype):
+ x_hat = F.grid_sample(y, flow.permute(0, 2, 3, 1), align_corners=False, mode = self.sample_mode)
+ if self.has_displacement_emb:
+ im_A_coords = torch.meshgrid(
+ (
+ torch.linspace(-1 + 1 / hs, 1 - 1 / hs, hs, device=x.device),
+ torch.linspace(-1 + 1 / ws, 1 - 1 / ws, ws, device=x.device),
+ ), indexing='ij'
+ )
+ im_A_coords = torch.stack((im_A_coords[1], im_A_coords[0]))
+ im_A_coords = im_A_coords[None].expand(b, 2, hs, ws)
+ in_displacement = flow-im_A_coords
+ emb_in_displacement = self.disp_emb(40/32 * scale_factor * in_displacement)
+ if self.local_corr_radius:
+ if self.corr_in_other:
+ # Corr in other means take a k x k grid around the predicted coordinate in the other image
+ local_corr = local_correlation(x,y,local_radius=self.local_corr_radius,flow = flow,
+ sample_mode = self.sample_mode)
+ else:
+ raise NotImplementedError("Local corr in own frame should not be used.")
+ if self.no_im_B_fm:
+ x_hat = torch.zeros_like(x)
+ d = torch.cat((x, x_hat, emb_in_displacement, local_corr), dim=1)
+ else:
+ d = torch.cat((x, x_hat, emb_in_displacement), dim=1)
+ else:
+ if self.no_im_B_fm:
+ x_hat = torch.zeros_like(x)
+ d = torch.cat((x, x_hat), dim=1)
+ if self.concat_logits:
+ d = torch.cat((d, logits), dim=1)
+ d = self.block1(d)
+ d = self.hidden_blocks(d)
+ d = self.out_conv(d.float())
+ displacement, certainty = d[:, :-1], d[:, -1:]
+ return displacement, certainty
+
+class CosKernel(nn.Module): # similar to softmax kernel
+ def __init__(self, T, learn_temperature=False):
+ super().__init__()
+ self.learn_temperature = learn_temperature
+ if self.learn_temperature:
+ self.T = nn.Parameter(torch.tensor(T))
+ else:
+ self.T = T
+
+ def __call__(self, x, y, eps=1e-6):
+ c = torch.einsum("bnd,bmd->bnm", x, y) / (
+ x.norm(dim=-1)[..., None] * y.norm(dim=-1)[:, None] + eps
+ )
+ if self.learn_temperature:
+ T = self.T.abs() + 0.01
+ else:
+ T = torch.tensor(self.T, device=c.device)
+ K = ((c - 1.0) / T).exp()
+ return K
+
+class GP(nn.Module):
+ def __init__(
+ self,
+ kernel,
+ T=1,
+ learn_temperature=False,
+ only_attention=False,
+ gp_dim=64,
+ basis="fourier",
+ covar_size=5,
+ only_nearest_neighbour=False,
+ sigma_noise=0.1,
+ no_cov=False,
+ predict_features = False,
+ ):
+ super().__init__()
+ self.K = kernel(T=T, learn_temperature=learn_temperature)
+ self.sigma_noise = sigma_noise
+ self.covar_size = covar_size
+ self.pos_conv = torch.nn.Conv2d(2, gp_dim, 1, 1)
+ self.only_attention = only_attention
+ self.only_nearest_neighbour = only_nearest_neighbour
+ self.basis = basis
+ self.no_cov = no_cov
+ self.dim = gp_dim
+ self.predict_features = predict_features
+
+ def get_local_cov(self, cov):
+ K = self.covar_size
+ b, h, w, h, w = cov.shape
+ hw = h * w
+ cov = F.pad(cov, 4 * (K // 2,)) # pad v_q
+ delta = torch.stack(
+ torch.meshgrid(
+ torch.arange(-(K // 2), K // 2 + 1), torch.arange(-(K // 2), K // 2 + 1),
+ indexing = 'ij'),
+ dim=-1,
+ )
+ positions = torch.stack(
+ torch.meshgrid(
+ torch.arange(K // 2, h + K // 2), torch.arange(K // 2, w + K // 2),
+ indexing = 'ij'),
+ dim=-1,
+ )
+ neighbours = positions[:, :, None, None, :] + delta[None, :, :]
+ points = torch.arange(hw)[:, None].expand(hw, K**2)
+ local_cov = cov.reshape(b, hw, h + K - 1, w + K - 1)[
+ :,
+ points.flatten(),
+ neighbours[..., 0].flatten(),
+ neighbours[..., 1].flatten(),
+ ].reshape(b, h, w, K**2)
+ return local_cov
+
+ def reshape(self, x):
+ return rearrange(x, "b d h w -> b (h w) d")
+
+ def project_to_basis(self, x):
+ if self.basis == "fourier":
+ return torch.cos(8 * math.pi * self.pos_conv(x))
+ elif self.basis == "linear":
+ return self.pos_conv(x)
+ else:
+ raise ValueError(
+ "No other bases other than fourier and linear currently im_Bed in public release"
+ )
+
+ def get_pos_enc(self, y):
+ b, c, h, w = y.shape
+ coarse_coords = torch.meshgrid(
+ (
+ torch.linspace(-1 + 1 / h, 1 - 1 / h, h, device=y.device),
+ torch.linspace(-1 + 1 / w, 1 - 1 / w, w, device=y.device),
+ ),
+ indexing = 'ij'
+ )
+
+ coarse_coords = torch.stack((coarse_coords[1], coarse_coords[0]), dim=-1)[
+ None
+ ].expand(b, h, w, 2)
+ coarse_coords = rearrange(coarse_coords, "b h w d -> b d h w")
+ coarse_embedded_coords = self.project_to_basis(coarse_coords)
+ return coarse_embedded_coords
+
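+ # The forward pass below is standard GP regression on embedded coordinates: with f the positional
+ # encoding of frame B, the posterior mean is mu_x = K_xy (K_yy + sigma_noise * I)^{-1} f and,
+ # unless no_cov is set, a local window of the posterior covariance
+ # cov_x = K_xx - K_xy (K_yy + sigma_noise * I)^{-1} K_yx is appended as extra channels.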
+ def forward(self, x, y, **kwargs):
+ b, c, h1, w1 = x.shape
+ b, c, h2, w2 = y.shape
+ f = self.get_pos_enc(y)
+ b, d, h2, w2 = f.shape
+ x, y, f = self.reshape(x.float()), self.reshape(y.float()), self.reshape(f)
+ K_xx = self.K(x, x)
+ K_yy = self.K(y, y)
+ K_xy = self.K(x, y)
+ K_yx = K_xy.permute(0, 2, 1)
+ sigma_noise = self.sigma_noise * torch.eye(h2 * w2, device=x.device)[None, :, :]
+ with warnings.catch_warnings():
+ K_yy_inv = torch.linalg.inv(K_yy + sigma_noise)
+
+ mu_x = K_xy.matmul(K_yy_inv.matmul(f))
+ mu_x = rearrange(mu_x, "b (h w) d -> b d h w", h=h1, w=w1)
+ if not self.no_cov:
+ cov_x = K_xx - K_xy.matmul(K_yy_inv.matmul(K_yx))
+ cov_x = rearrange(cov_x, "b (h w) (r c) -> b h w r c", h=h1, w=w1, r=h1, c=w1)
+ local_cov_x = self.get_local_cov(cov_x)
+ local_cov_x = rearrange(local_cov_x, "b h w K -> b K h w")
+ gp_feats = torch.cat((mu_x, local_cov_x), dim=1)
+ else:
+ gp_feats = mu_x
+ return gp_feats
+
+class Decoder(nn.Module):
+ def __init__(
+ self, embedding_decoder, gps, proj, conv_refiner, detach=False, scales="all", pos_embeddings = None,
+ num_refinement_steps_per_scale = 1, warp_noise_std = 0.0, displacement_dropout_p = 0.0, gm_warp_dropout_p = 0.0,
+ flow_upsample_mode = "bilinear", amp_dtype = torch.float16,
+ ):
+ super().__init__()
+ self.embedding_decoder = embedding_decoder
+ self.num_refinement_steps_per_scale = num_refinement_steps_per_scale
+ self.gps = gps
+ self.proj = proj
+ self.conv_refiner = conv_refiner
+ self.detach = detach
+ if pos_embeddings is None:
+ self.pos_embeddings = {}
+ else:
+ self.pos_embeddings = pos_embeddings
+ if scales == "all":
+ self.scales = ["32", "16", "8", "4", "2", "1"]
+ else:
+ self.scales = scales
+ self.warp_noise_std = warp_noise_std
+ self.refine_init = 4
+ self.displacement_dropout_p = displacement_dropout_p
+ self.gm_warp_dropout_p = gm_warp_dropout_p
+ self.flow_upsample_mode = flow_upsample_mode
+ self.amp_dtype = amp_dtype
+
+ def get_placeholder_flow(self, b, h, w, device):
+ coarse_coords = torch.meshgrid(
+ (
+ torch.linspace(-1 + 1 / h, 1 - 1 / h, h, device=device),
+ torch.linspace(-1 + 1 / w, 1 - 1 / w, w, device=device),
+ ),
+ indexing = 'ij'
+ )
+ coarse_coords = torch.stack((coarse_coords[1], coarse_coords[0]), dim=-1)[
+ None
+ ].expand(b, h, w, 2)
+ coarse_coords = rearrange(coarse_coords, "b h w d -> b d h w")
+ return coarse_coords
+
+ def get_positional_embedding(self, b, h ,w, device):
+ coarse_coords = torch.meshgrid(
+ (
+ torch.linspace(-1 + 1 / h, 1 - 1 / h, h, device=device),
+ torch.linspace(-1 + 1 / w, 1 - 1 / w, w, device=device),
+ ),
+ indexing = 'ij'
+ )
+
+ coarse_coords = torch.stack((coarse_coords[1], coarse_coords[0]), dim=-1)[
+ None
+ ].expand(b, h, w, 2)
+ coarse_coords = rearrange(coarse_coords, "b h w d -> b d h w")
+ coarse_embedded_coords = self.pos_embedding(coarse_coords)
+ return coarse_embedded_coords
+
+ def forward(self, f1, f2, gt_warp = None, gt_prob = None, upsample = False, flow = None, certainty = None, scale_factor = 1):
+ coarse_scales = self.embedding_decoder.scales()
+ all_scales = self.scales if not upsample else ["8", "4", "2", "1"]
+ sizes = {scale: f1[scale].shape[-2:] for scale in f1}
+ h, w = sizes[1]
+ b = f1[1].shape[0]
+ device = f1[1].device
+ coarsest_scale = int(all_scales[0])
+ old_stuff = torch.zeros(
+ b, self.embedding_decoder.hidden_dim, *sizes[coarsest_scale], device=f1[coarsest_scale].device
+ )
+ corresps = {}
+ if not upsample:
+ flow = self.get_placeholder_flow(b, *sizes[coarsest_scale], device)
+ certainty = 0.0
+ else:
+ flow = F.interpolate(
+ flow,
+ size=sizes[coarsest_scale],
+ align_corners=False,
+ mode="bilinear",
+ )
+ certainty = F.interpolate(
+ certainty,
+ size=sizes[coarsest_scale],
+ align_corners=False,
+ mode="bilinear",
+ )
+ displacement = 0.0
+ for new_scale in all_scales:
+ ins = int(new_scale)
+ corresps[ins] = {}
+ f1_s, f2_s = f1[ins], f2[ins]
+ if new_scale in self.proj:
+ autocast_device, autocast_enabled, autocast_dtype = get_autocast_params(f1_s.device, f1_s.device.type=='cuda', self.amp_dtype)
+ with torch.autocast(autocast_device, enabled=autocast_enabled, dtype = autocast_dtype):
+ if not autocast_enabled:
+ f1_s, f2_s = f1_s.to(torch.float32), f2_s.to(torch.float32)
+ f1_s, f2_s = self.proj[new_scale](f1_s), self.proj[new_scale](f2_s)
+
+ if ins in coarse_scales:
+ old_stuff = F.interpolate(
+ old_stuff, size=sizes[ins], mode="bilinear", align_corners=False
+ )
+ gp_posterior = self.gps[new_scale](f1_s, f2_s)
+ gm_warp_or_cls, certainty, old_stuff = self.embedding_decoder(
+ gp_posterior, f1_s, old_stuff, new_scale
+ )
+
+ if self.embedding_decoder.is_classifier:
+ flow = cls_to_flow_refine(
+ gm_warp_or_cls,
+ ).permute(0,3,1,2)
+ corresps[ins].update({"gm_cls": gm_warp_or_cls,"gm_certainty": certainty,}) if self.training else None
+ else:
+ corresps[ins].update({"gm_flow": gm_warp_or_cls,"gm_certainty": certainty,}) if self.training else None
+ flow = gm_warp_or_cls.detach()
+
+ if new_scale in self.conv_refiner:
+ corresps[ins].update({"flow_pre_delta": flow}) if self.training else None
+ delta_flow, delta_certainty = self.conv_refiner[new_scale](
+ f1_s, f2_s, flow, scale_factor = scale_factor, logits = certainty,
+ )
+ corresps[ins].update({"delta_flow": delta_flow,}) if self.training else None
+ displacement = ins*torch.stack((delta_flow[:, 0].float() / (self.refine_init * w),
+ delta_flow[:, 1].float() / (self.refine_init * h),),dim=1,)
+ flow = flow + displacement
+ certainty = (
+ certainty + delta_certainty
+ ) # predict both certainty and displacement
+ corresps[ins].update({
+ "certainty": certainty,
+ "flow": flow,
+ })
+ if new_scale != "1":
+ flow = F.interpolate(
+ flow,
+ size=sizes[ins // 2],
+ mode=self.flow_upsample_mode,
+ )
+ certainty = F.interpolate(
+ certainty,
+ size=sizes[ins // 2],
+ mode=self.flow_upsample_mode,
+ )
+ if self.detach:
+ flow = flow.detach()
+ certainty = certainty.detach()
+ #torch.cuda.empty_cache()
+ return corresps
+
+
+class RegressionMatcher(nn.Module):
+ def __init__(
+ self,
+ encoder,
+ decoder,
+ h=448,
+ w=448,
+ sample_mode = "threshold_balanced",
+ upsample_preds = False,
+ symmetric = False,
+ name = None,
+ attenuate_cert = None,
+ ):
+ super().__init__()
+ self.attenuate_cert = attenuate_cert
+ self.encoder = encoder
+ self.decoder = decoder
+ self.name = name
+ self.w_resized = w
+ self.h_resized = h
+ self.og_transforms = get_tuple_transform_ops(resize=None, normalize=True)
+ self.sample_mode = sample_mode
+ self.upsample_preds = upsample_preds
+ self.upsample_res = (14*16*6, 14*16*6)
+ self.symmetric = symmetric
+ self.sample_thresh = 0.05
+
+ def get_output_resolution(self):
+ if not self.upsample_preds:
+ return self.h_resized, self.w_resized
+ else:
+ return self.upsample_res
+
+ def extract_backbone_features(self, batch, batched = True, upsample = False):
+ x_q = batch["im_A"]
+ x_s = batch["im_B"]
+ if batched:
+ X = torch.cat((x_q, x_s), dim = 0)
+ feature_pyramid = self.encoder(X, upsample = upsample)
+ else:
+ feature_pyramid = self.encoder(x_q, upsample = upsample), self.encoder(x_s, upsample = upsample)
+ return feature_pyramid
+
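+ # Sampling sketch: in "threshold" modes, certainties above sample_thresh are clamped to 1 so all
+ # confident matches are treated equally; matches are then drawn by multinomial sampling on the
+ # certainty, and in "balanced" modes they are re-sampled with weight ~ 1 / (1 + kde density) so
+ # that densely clustered regions do not dominate. Illustrative call, assuming warp and certainty
+ # come from match() further below:
+ #   matches, match_certainty = matcher.sample(warp, certainty, num=5000)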
+ def sample(
+ self,
+ matches,
+ certainty,
+ num=10000,
+ ):
+ if "threshold" in self.sample_mode:
+ upper_thresh = self.sample_thresh
+ certainty = certainty.clone()
+ certainty[certainty > upper_thresh] = 1
+ matches, certainty = (
+ matches.reshape(-1, 4),
+ certainty.reshape(-1),
+ )
+ expansion_factor = 4 if "balanced" in self.sample_mode else 1
+ good_samples = torch.multinomial(certainty,
+ num_samples = min(expansion_factor*num, len(certainty)),
+ replacement=False)
+ good_matches, good_certainty = matches[good_samples], certainty[good_samples]
+ if "balanced" not in self.sample_mode:
+ return good_matches, good_certainty
+ density = kde(good_matches, std=0.1)
+ p = 1 / (density+1)
+ p[density < 10] = 1e-7 # Basically should have at least 10 perfect neighbours, or around 100 ok ones
+ balanced_samples = torch.multinomial(p,
+ num_samples = min(num,len(good_certainty)),
+ replacement=False)
+ return good_matches[balanced_samples], good_certainty[balanced_samples]
+
+ def forward(self, batch, batched = True, upsample = False, scale_factor = 1):
+ feature_pyramid = self.extract_backbone_features(batch, batched=batched, upsample = upsample)
+ if batched:
+ f_q_pyramid = {
+ scale: f_scale.chunk(2)[0] for scale, f_scale in feature_pyramid.items()
+ }
+ f_s_pyramid = {
+ scale: f_scale.chunk(2)[1] for scale, f_scale in feature_pyramid.items()
+ }
+ else:
+ f_q_pyramid, f_s_pyramid = feature_pyramid
+ corresps = self.decoder(f_q_pyramid,
+ f_s_pyramid,
+ upsample = upsample,
+ **(batch["corresps"] if "corresps" in batch else {}),
+ scale_factor=scale_factor)
+
+ return corresps
+
+ def forward_symmetric(self, batch, batched = True, upsample = False, scale_factor = 1):
+ feature_pyramid = self.extract_backbone_features(batch, batched = batched, upsample = upsample)
+ f_q_pyramid = feature_pyramid
+ f_s_pyramid = {
+ scale: torch.cat((f_scale.chunk(2)[1], f_scale.chunk(2)[0]), dim = 0)
+ for scale, f_scale in feature_pyramid.items()
+ }
+ corresps = self.decoder(f_q_pyramid,
+ f_s_pyramid,
+ upsample = upsample,
+ **(batch["corresps"] if "corresps" in batch else {}),
+ scale_factor=scale_factor)
+ return corresps
+
+ def conf_from_fb_consistency(self, flow_forward, flow_backward, th = 2):
+ # assumes that flow forward is of shape (..., H, W, 2)
+ has_batch = False
+ if len(flow_forward.shape) == 3:
+ flow_forward, flow_backward = flow_forward[None], flow_backward[None]
+ else:
+ has_batch = True
+ H,W = flow_forward.shape[-3:-1]
+ th_n = 2 * th / max(H,W)
+ coords = torch.stack(torch.meshgrid(
+ torch.linspace(-1 + 1 / W, 1 - 1 / W, W),
+ torch.linspace(-1 + 1 / H, 1 - 1 / H, H), indexing = "xy"),
+ dim = -1).to(flow_forward.device)
+ coords_fb = F.grid_sample(
+ flow_backward.permute(0, 3, 1, 2),
+ flow_forward,
+ align_corners=False, mode="bilinear").permute(0, 2, 3, 1)
+ diff = (coords - coords_fb).norm(dim=-1)
+ in_th = (diff < th_n).float()
+ if not has_batch:
+ in_th = in_th[0]
+ return in_th
+
+ def to_pixel_coordinates(self, coords, H_A, W_A, H_B = None, W_B = None):
+ if coords.shape[-1] == 2:
+ return self._to_pixel_coordinates(coords, H_A, W_A)
+
+ if isinstance(coords, (list, tuple)):
+ kpts_A, kpts_B = coords[0], coords[1]
+ else:
+ kpts_A, kpts_B = coords[...,:2], coords[...,2:]
+ return self._to_pixel_coordinates(kpts_A, H_A, W_A), self._to_pixel_coordinates(kpts_B, H_B, W_B)
+
+ def _to_pixel_coordinates(self, coords, H, W):
+ kpts = torch.stack((W/2 * (coords[...,0]+1), H/2 * (coords[...,1]+1)),axis=-1)
+ return kpts
+
+ def to_normalized_coordinates(self, coords, H_A, W_A, H_B, W_B):
+ if isinstance(coords, (list, tuple)):
+ kpts_A, kpts_B = coords[0], coords[1]
+ else:
+ kpts_A, kpts_B = coords[...,:2], coords[...,2:]
+ kpts_A = torch.stack((2/W_A * kpts_A[...,0] - 1, 2/H_A * kpts_A[...,1] - 1),axis=-1)
+ kpts_B = torch.stack((2/W_B * kpts_B[...,0] - 1, 2/H_B * kpts_B[...,1] - 1),axis=-1)
+ return kpts_A, kpts_B
+
+ def match_keypoints(self, x_A, x_B, warp, certainty, return_tuple = True, return_inds = False):
+ x_A_to_B = F.grid_sample(warp[...,-2:].permute(2,0,1)[None], x_A[None,None], align_corners = False, mode = "bilinear")[0,:,0].mT
+ cert_A_to_B = F.grid_sample(certainty[None,None,...], x_A[None,None], align_corners = False, mode = "bilinear")[0,0,0]
+ D = torch.cdist(x_A_to_B, x_B)
+ inds_A, inds_B = torch.nonzero((D == D.min(dim=-1, keepdim = True).values) * (D == D.min(dim=-2, keepdim = True).values) * (cert_A_to_B[:,None] > self.sample_thresh), as_tuple = True)
+
+ if return_tuple:
+ if return_inds:
+ return inds_A, inds_B
+ else:
+ return x_A[inds_A], x_B[inds_B]
+ else:
+ if return_inds:
+ return torch.cat((inds_A, inds_B),dim=-1)
+ else:
+ return torch.cat((x_A[inds_A], x_B[inds_B]),dim=-1)
+
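+ # End-to-end sketch of the matching flow (image paths and sizes are placeholders):
+ #   warp, certainty = matcher.match("im_A.jpg", "im_B.jpg", device="cuda")
+ #   matches, match_certainty = matcher.sample(warp, certainty)
+ #   kpts_A, kpts_B = matcher.to_pixel_coordinates(matches, H_A, W_A, H_B, W_B)
+ # where (H_A, W_A) and (H_B, W_B) are the original image sizes and `matcher` is a
+ # RegressionMatcher instance such as the models returned by the model zoo helpers.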
+ @torch.inference_mode()
+ def match(
+ self,
+ im_A_path,
+ im_B_path,
+ *args,
+ batched=False,
+ device = None,
+ ):
+ if device is None:
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ if isinstance(im_A_path, (str, os.PathLike)):
+ im_A, im_B = Image.open(im_A_path).convert("RGB"), Image.open(im_B_path).convert("RGB")
+ else:
+ im_A, im_B = im_A_path, im_B_path
+
+ symmetric = self.symmetric
+ self.train(False)
+ with torch.no_grad():
+ if not batched:
+ b = 1
+ w, h = im_A.size
+ w2, h2 = im_B.size
+ # Get images in good format
+ ws = self.w_resized
+ hs = self.h_resized
+
+ test_transform = get_tuple_transform_ops(
+ resize=(hs, ws), normalize=True, clahe = False
+ )
+ im_A, im_B = test_transform((im_A, im_B))
+ batch = {"im_A": im_A[None].to(device), "im_B": im_B[None].to(device)}
+ else:
+ b, c, h, w = im_A.shape
+ b, c, h2, w2 = im_B.shape
+ assert w == w2 and h == h2, "For batched images we assume same size"
+ batch = {"im_A": im_A.to(device), "im_B": im_B.to(device)}
+ if h != self.h_resized or self.w_resized != w:
+ warn("Model resolution and batch resolution differ, may produce unexpected results")
+ hs, ws = h, w
+ finest_scale = 1
+ # Run matcher
+ if symmetric:
+ corresps = self.forward_symmetric(batch)
+ else:
+ corresps = self.forward(batch, batched = True)
+
+ if self.upsample_preds:
+ hs, ws = self.upsample_res
+
+ if self.attenuate_cert:
+ low_res_certainty = F.interpolate(
+ corresps[16]["certainty"], size=(hs, ws), align_corners=False, mode="bilinear"
+ )
+ cert_clamp = 0
+ factor = 0.5
+ low_res_certainty = factor*low_res_certainty*(low_res_certainty < cert_clamp)
+
+ if self.upsample_preds:
+ finest_corresps = corresps[finest_scale]
+ torch.cuda.empty_cache()
+ test_transform = get_tuple_transform_ops(
+ resize=(hs, ws), normalize=True
+ )
+ im_A, im_B = test_transform((Image.open(im_A_path).convert('RGB'), Image.open(im_B_path).convert('RGB')))
+ im_A, im_B = im_A[None].to(device), im_B[None].to(device)
+ scale_factor = math.sqrt(self.upsample_res[0] * self.upsample_res[1] / (self.w_resized * self.h_resized))
+ batch = {"im_A": im_A, "im_B": im_B, "corresps": finest_corresps}
+ if symmetric:
+ corresps = self.forward_symmetric(batch, upsample = True, batched=True, scale_factor = scale_factor)
+ else:
+ corresps = self.forward(batch, batched = True, upsample=True, scale_factor = scale_factor)
+
+ im_A_to_im_B = corresps[finest_scale]["flow"]
+ certainty = corresps[finest_scale]["certainty"] - (low_res_certainty if self.attenuate_cert else 0)
+ if finest_scale != 1:
+ im_A_to_im_B = F.interpolate(
+ im_A_to_im_B, size=(hs, ws), align_corners=False, mode="bilinear"
+ )
+ certainty = F.interpolate(
+ certainty, size=(hs, ws), align_corners=False, mode="bilinear"
+ )
+ im_A_to_im_B = im_A_to_im_B.permute(
+ 0, 2, 3, 1
+ )
+ # Create im_A meshgrid
+ im_A_coords = torch.meshgrid(
+ (
+ torch.linspace(-1 + 1 / hs, 1 - 1 / hs, hs, device=device),
+ torch.linspace(-1 + 1 / ws, 1 - 1 / ws, ws, device=device),
+ ),
+ indexing = 'ij'
+ )
+ im_A_coords = torch.stack((im_A_coords[1], im_A_coords[0]))
+ im_A_coords = im_A_coords[None].expand(b, 2, hs, ws)
+ certainty = certainty.sigmoid() # logits -> probs
+ im_A_coords = im_A_coords.permute(0, 2, 3, 1)
+ if (im_A_to_im_B.abs() > 1).any():
+ wrong = (im_A_to_im_B.abs() > 1).sum(dim=-1) > 0
+ certainty[wrong[:,None]] = 0
+ im_A_to_im_B = torch.clamp(im_A_to_im_B, -1, 1)
+ if symmetric:
+ A_to_B, B_to_A = im_A_to_im_B.chunk(2)
+ q_warp = torch.cat((im_A_coords, A_to_B), dim=-1)
+ im_B_coords = im_A_coords
+ s_warp = torch.cat((B_to_A, im_B_coords), dim=-1)
+ warp = torch.cat((q_warp, s_warp),dim=2)
+ certainty = torch.cat(certainty.chunk(2), dim=3)
+ else:
+ warp = torch.cat((im_A_coords, im_A_to_im_B), dim=-1)
+ if batched:
+ return (
+ warp,
+ certainty[:, 0]
+ )
+ else:
+ return (
+ warp[0],
+ certainty[0, 0],
+ )
+
+ def visualize_warp(self, warp, certainty, im_A = None, im_B = None,
+ im_A_path = None, im_B_path = None, device = "cuda", symmetric = True, save_path = None, unnormalize = False):
+ #assert symmetric == True, "Currently assuming bidirectional warp, might update this if someone complains ;)"
+ H,W2,_ = warp.shape
+ W = W2//2 if symmetric else W2
+ if im_A is None:
+ from PIL import Image
+ im_A, im_B = Image.open(im_A_path).convert("RGB"), Image.open(im_B_path).convert("RGB")
+ if not isinstance(im_A, torch.Tensor):
+ im_A = im_A.resize((W,H))
+ im_B = im_B.resize((W,H))
+ x_B = (torch.tensor(np.array(im_B)) / 255).to(device).permute(2, 0, 1)
+ if symmetric:
+ x_A = (torch.tensor(np.array(im_A)) / 255).to(device).permute(2, 0, 1)
+ else:
+ if symmetric:
+ x_A = im_A
+ x_B = im_B
+ im_A_transfer_rgb = F.grid_sample(
+ x_B[None], warp[:,:W, 2:][None], mode="bilinear", align_corners=False
+ )[0]
+ if symmetric:
+ im_B_transfer_rgb = F.grid_sample(
+ x_A[None], warp[:, W:, :2][None], mode="bilinear", align_corners=False
+ )[0]
+ warp_im = torch.cat((im_A_transfer_rgb,im_B_transfer_rgb),dim=2)
+ white_im = torch.ones((H,2*W),device=device)
+ else:
+ warp_im = im_A_transfer_rgb
+ white_im = torch.ones((H, W), device = device)
+ vis_im = certainty * warp_im + (1 - certainty) * white_im
+ if save_path is not None:
+ from romatch.utils import tensor_to_pil
+ tensor_to_pil(vis_im, unnormalize=unnormalize).save(save_path)
+ return vis_im
diff --git a/imcui/third_party/RoMa/romatch/models/model_zoo/__init__.py b/imcui/third_party/RoMa/romatch/models/model_zoo/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d0470ca3f0c3b8064b1b2f01663dfb13742d7a10
--- /dev/null
+++ b/imcui/third_party/RoMa/romatch/models/model_zoo/__init__.py
@@ -0,0 +1,73 @@
+from typing import Union
+import torch
+from .roma_models import roma_model, tiny_roma_v1_model
+
+weight_urls = {
+ "romatch": {
+ "outdoor": "https://github.com/Parskatt/storage/releases/download/roma/roma_outdoor.pth",
+ "indoor": "https://github.com/Parskatt/storage/releases/download/roma/roma_indoor.pth",
+ },
+ "tiny_roma_v1": {
+ "outdoor": "https://github.com/Parskatt/storage/releases/download/roma/tiny_roma_v1_outdoor.pth",
+ },
+ "dinov2": "https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_pretrain.pth", #hopefully this doesnt change :D
+}
+
+def tiny_roma_v1_outdoor(device, weights = None, xfeat = None):
+ if weights is None:
+ weights = torch.hub.load_state_dict_from_url(
+ weight_urls["tiny_roma_v1"]["outdoor"],
+ map_location=device)
+ if xfeat is None:
+ xfeat = torch.hub.load(
+ 'verlab/accelerated_features',
+ 'XFeat',
+ pretrained = True,
+ top_k = 4096).net
+
+ return tiny_roma_v1_model(weights = weights, xfeat = xfeat).to(device)
+
+def roma_outdoor(device, weights=None, dinov2_weights=None, coarse_res: Union[int,tuple[int,int]] = 560, upsample_res: Union[int,tuple[int,int]] = 864, amp_dtype: torch.dtype = torch.float16):
+ if isinstance(coarse_res, int):
+ coarse_res = (coarse_res, coarse_res)
+ if isinstance(upsample_res, int):
+ upsample_res = (upsample_res, upsample_res)
+
+ if str(device) == 'cpu':
+ amp_dtype = torch.float32
+
+ assert coarse_res[0] % 14 == 0, "Needs to be multiple of 14 for backbone"
+ assert coarse_res[1] % 14 == 0, "Needs to be multiple of 14 for backbone"
+
+ if weights is None:
+ weights = torch.hub.load_state_dict_from_url(weight_urls["romatch"]["outdoor"],
+ map_location=device)
+ if dinov2_weights is None:
+ dinov2_weights = torch.hub.load_state_dict_from_url(weight_urls["dinov2"],
+ map_location=device)
+ model = roma_model(resolution=coarse_res, upsample_preds=True,
+ weights=weights,dinov2_weights = dinov2_weights,device=device, amp_dtype=amp_dtype)
+ model.upsample_res = upsample_res
+ print(f"Using coarse resolution {coarse_res}, and upsample res {model.upsample_res}")
+ return model
+
+def roma_indoor(device, weights=None, dinov2_weights=None, coarse_res: Union[int,tuple[int,int]] = 560, upsample_res: Union[int,tuple[int,int]] = 864, amp_dtype: torch.dtype = torch.float16):
+ if isinstance(coarse_res, int):
+ coarse_res = (coarse_res, coarse_res)
+ if isinstance(upsample_res, int):
+ upsample_res = (upsample_res, upsample_res)
+
+ assert coarse_res[0] % 14 == 0, "Needs to be multiple of 14 for backbone"
+ assert coarse_res[1] % 14 == 0, "Needs to be multiple of 14 for backbone"
+
+ if weights is None:
+ weights = torch.hub.load_state_dict_from_url(weight_urls["romatch"]["indoor"],
+ map_location=device)
+ if dinov2_weights is None:
+ dinov2_weights = torch.hub.load_state_dict_from_url(weight_urls["dinov2"],
+ map_location=device)
+ model = roma_model(resolution=coarse_res, upsample_preds=True,
+ weights=weights,dinov2_weights = dinov2_weights,device=device, amp_dtype=amp_dtype)
+ model.upsample_res = upsample_res
+ print(f"Using coarse resolution {coarse_res}, and upsample res {model.upsample_res}")
+ return model
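+# Illustrative setup sketch (the image paths are placeholders):
+#
+#   import torch
+#   device = "cuda" if torch.cuda.is_available() else "cpu"
+#   roma_model = roma_outdoor(device=device)  # or roma_indoor(device=device)
+#   warp, certainty = roma_model.match("im_A.jpg", "im_B.jpg", device=device)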
diff --git a/imcui/third_party/RoMa/romatch/models/model_zoo/roma_models.py b/imcui/third_party/RoMa/romatch/models/model_zoo/roma_models.py
new file mode 100644
index 0000000000000000000000000000000000000000..4f8a08fc26d760a09f048bdd98bf8a8fffc8202c
--- /dev/null
+++ b/imcui/third_party/RoMa/romatch/models/model_zoo/roma_models.py
@@ -0,0 +1,170 @@
+import warnings
+import torch.nn as nn
+import torch
+from romatch.models.matcher import *
+from romatch.models.transformer import Block, TransformerDecoder, MemEffAttention
+from romatch.models.encoders import *
+from romatch.models.tiny import TinyRoMa
+
+def tiny_roma_v1_model(weights = None, freeze_xfeat=False, exact_softmax=False, xfeat = None):
+ model = TinyRoMa(
+ xfeat = xfeat,
+ freeze_xfeat=freeze_xfeat,
+ exact_softmax=exact_softmax)
+ if weights is not None:
+ model.load_state_dict(weights)
+ return model
+
+def roma_model(resolution, upsample_preds, device = None, weights=None, dinov2_weights=None, amp_dtype: torch.dtype=torch.float16, **kwargs):
+ # romatch weights and dinov2 weights are loaded separately, as the dinov2 weights are not model parameters
+ #torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul. TODO: may hurt matching accuracy, enable with care
+ #torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
+ warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')
+ gp_dim = 512
+ feat_dim = 512
+ decoder_dim = gp_dim + feat_dim
+ cls_to_coord_res = 64
+ coordinate_decoder = TransformerDecoder(
+ nn.Sequential(*[Block(decoder_dim, 8, attn_class=MemEffAttention) for _ in range(5)]),
+ decoder_dim,
+ cls_to_coord_res**2 + 1,
+ is_classifier=True,
+ amp = True,
+ pos_enc = False,)
+ dw = True
+ hidden_blocks = 8
+ kernel_size = 5
+ displacement_emb = "linear"
+ disable_local_corr_grad = True
+
+ conv_refiner = nn.ModuleDict(
+ {
+ "16": ConvRefiner(
+ 2 * 512+128+(2*7+1)**2,
+ 2 * 512+128+(2*7+1)**2,
+ 2 + 1,
+ kernel_size=kernel_size,
+ dw=dw,
+ hidden_blocks=hidden_blocks,
+ displacement_emb=displacement_emb,
+ displacement_emb_dim=128,
+ local_corr_radius = 7,
+ corr_in_other = True,
+ amp = True,
+ disable_local_corr_grad = disable_local_corr_grad,
+ bn_momentum = 0.01,
+ ),
+ "8": ConvRefiner(
+ 2 * 512+64+(2*3+1)**2,
+ 2 * 512+64+(2*3+1)**2,
+ 2 + 1,
+ kernel_size=kernel_size,
+ dw=dw,
+ hidden_blocks=hidden_blocks,
+ displacement_emb=displacement_emb,
+ displacement_emb_dim=64,
+ local_corr_radius = 3,
+ corr_in_other = True,
+ amp = True,
+ disable_local_corr_grad = disable_local_corr_grad,
+ bn_momentum = 0.01,
+ ),
+ "4": ConvRefiner(
+ 2 * 256+32+(2*2+1)**2,
+ 2 * 256+32+(2*2+1)**2,
+ 2 + 1,
+ kernel_size=kernel_size,
+ dw=dw,
+ hidden_blocks=hidden_blocks,
+ displacement_emb=displacement_emb,
+ displacement_emb_dim=32,
+ local_corr_radius = 2,
+ corr_in_other = True,
+ amp = True,
+ disable_local_corr_grad = disable_local_corr_grad,
+ bn_momentum = 0.01,
+ ),
+ "2": ConvRefiner(
+ 2 * 64+16,
+ 128+16,
+ 2 + 1,
+ kernel_size=kernel_size,
+ dw=dw,
+ hidden_blocks=hidden_blocks,
+ displacement_emb=displacement_emb,
+ displacement_emb_dim=16,
+ amp = True,
+ disable_local_corr_grad = disable_local_corr_grad,
+ bn_momentum = 0.01,
+ ),
+ "1": ConvRefiner(
+ 2 * 9 + 6,
+ 24,
+ 2 + 1,
+ kernel_size=kernel_size,
+ dw=dw,
+ hidden_blocks = hidden_blocks,
+ displacement_emb = displacement_emb,
+ displacement_emb_dim = 6,
+ amp = True,
+ disable_local_corr_grad = disable_local_corr_grad,
+ bn_momentum = 0.01,
+ ),
+ }
+ )
+ kernel_temperature = 0.2
+ learn_temperature = False
+ no_cov = True
+ kernel = CosKernel
+ only_attention = False
+ basis = "fourier"
+ gp16 = GP(
+ kernel,
+ T=kernel_temperature,
+ learn_temperature=learn_temperature,
+ only_attention=only_attention,
+ gp_dim=gp_dim,
+ basis=basis,
+ no_cov=no_cov,
+ )
+ gps = nn.ModuleDict({"16": gp16})
+ proj16 = nn.Sequential(nn.Conv2d(1024, 512, 1, 1), nn.BatchNorm2d(512))
+ proj8 = nn.Sequential(nn.Conv2d(512, 512, 1, 1), nn.BatchNorm2d(512))
+ proj4 = nn.Sequential(nn.Conv2d(256, 256, 1, 1), nn.BatchNorm2d(256))
+ proj2 = nn.Sequential(nn.Conv2d(128, 64, 1, 1), nn.BatchNorm2d(64))
+ proj1 = nn.Sequential(nn.Conv2d(64, 9, 1, 1), nn.BatchNorm2d(9))
+ proj = nn.ModuleDict({
+ "16": proj16,
+ "8": proj8,
+ "4": proj4,
+ "2": proj2,
+ "1": proj1,
+ })
+ displacement_dropout_p = 0.0
+ gm_warp_dropout_p = 0.0
+ decoder = Decoder(coordinate_decoder,
+ gps,
+ proj,
+ conv_refiner,
+ detach=True,
+ scales=["16", "8", "4", "2", "1"],
+ displacement_dropout_p = displacement_dropout_p,
+ gm_warp_dropout_p = gm_warp_dropout_p)
+
+ encoder = CNNandDinov2(
+ cnn_kwargs = dict(
+ pretrained=False,
+ amp = True),
+ amp = True,
+ use_vgg = True,
+ dinov2_weights = dinov2_weights,
+ amp_dtype=amp_dtype,
+ )
+ h,w = resolution
+ symmetric = True
+ attenuate_cert = True
+ sample_mode = "threshold_balanced"
+ matcher = RegressionMatcher(encoder, decoder, h=h, w=w, upsample_preds=upsample_preds,
+ symmetric = symmetric, attenuate_cert = attenuate_cert, sample_mode = sample_mode, **kwargs).to(device)
+ matcher.load_state_dict(weights)
+ return matcher
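+
+# Minimal direct-call sketch: both state dicts must already be loaded (the
+# roma_outdoor / roma_indoor helpers fetch them via torch.hub.load_state_dict_from_url):
+#   matcher = roma_model(resolution=(560, 560), upsample_preds=True,
+#                        weights=weights, dinov2_weights=dinov2_weights,
+#                        device="cuda", amp_dtype=torch.float16)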
diff --git a/imcui/third_party/RoMa/romatch/models/tiny.py b/imcui/third_party/RoMa/romatch/models/tiny.py
new file mode 100644
index 0000000000000000000000000000000000000000..88f44af6c9a6255831734d096167724f89d040ce
--- /dev/null
+++ b/imcui/third_party/RoMa/romatch/models/tiny.py
@@ -0,0 +1,304 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import os
+import torch
+from pathlib import Path
+import math
+import numpy as np
+
+from torch import nn
+from PIL import Image
+from torchvision.transforms import ToTensor
+from romatch.utils.kde import kde
+
+class BasicLayer(nn.Module):
+ """
+ Basic Convolutional Layer: Conv2d -> BatchNorm -> ReLU
+ """
+ def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, dilation=1, bias=False, relu = True):
+ super().__init__()
+ self.layer = nn.Sequential(
+ nn.Conv2d( in_channels, out_channels, kernel_size, padding = padding, stride=stride, dilation=dilation, bias = bias),
+ nn.BatchNorm2d(out_channels, affine=False),
+ nn.ReLU(inplace = True) if relu else nn.Identity()
+ )
+
+ def forward(self, x):
+ return self.layer(x)
+
+class TinyRoMa(nn.Module):
+ """
+ Tiny RoMa matcher built on the feature backbone from
+ "XFeat: Accelerated Features for Lightweight Image Matching, CVPR 2024."
+ """
+
+ def __init__(self, xfeat = None,
+ freeze_xfeat = True,
+ sample_mode = "threshold_balanced",
+ symmetric = False,
+ exact_softmax = False):
+ super().__init__()
+ del xfeat.heatmap_head, xfeat.keypoint_head, xfeat.fine_matcher
+ if freeze_xfeat:
+ xfeat.train(False)
+ self.xfeat = [xfeat]  # plain list hides the frozen xfeat params from DDP / module registration
+ else:
+ self.xfeat = nn.ModuleList([xfeat])
+ self.freeze_xfeat = freeze_xfeat
+ match_dim = 256
+ self.coarse_matcher = nn.Sequential(
+ BasicLayer(64+64+2, match_dim,),
+ BasicLayer(match_dim, match_dim,),
+ BasicLayer(match_dim, match_dim,),
+ BasicLayer(match_dim, match_dim,),
+ nn.Conv2d(match_dim, 3, kernel_size=1, bias=True, padding=0))
+ fine_match_dim = 64
+ self.fine_matcher = nn.Sequential(
+ BasicLayer(24+24+2, fine_match_dim,),
+ BasicLayer(fine_match_dim, fine_match_dim,),
+ BasicLayer(fine_match_dim, fine_match_dim,),
+ BasicLayer(fine_match_dim, fine_match_dim,),
+ nn.Conv2d(fine_match_dim, 3, kernel_size=1, bias=True, padding=0),)
+ self.sample_mode = sample_mode
+ self.sample_thresh = 0.05
+ self.symmetric = symmetric
+ self.exact_softmax = exact_softmax
+
+ @property
+ def device(self):
+ return self.fine_matcher[-1].weight.device
+
+ def preprocess_tensor(self, x):
+ """ Guarantee that image is divisible by 32 to avoid aliasing artifacts. """
+ H, W = x.shape[-2:]
+ _H, _W = (H//32) * 32, (W//32) * 32
+ rh, rw = H/_H, W/_W
+
+ x = F.interpolate(x, (_H, _W), mode='bilinear', align_corners=False)
+ return x, rh, rw
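+ # Example: a 720x1280 input is resized to 704x1280 (multiples of 32), and
+ # rh, rw = 720/704 ≈ 1.023, 1280/1280 = 1.0 can be used to map coordinates
+ # back to the original resolution.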
+
+ def forward_single(self, x):
+ with torch.inference_mode(self.freeze_xfeat or not self.training):
+ xfeat = self.xfeat[0]
+ with torch.no_grad():
+ x = x.mean(dim=1, keepdim = True)
+ x = xfeat.norm(x)
+
+ #main backbone
+ x1 = xfeat.block1(x)
+ x2 = xfeat.block2(x1 + xfeat.skip1(x))
+ x3 = xfeat.block3(x2)
+ x4 = xfeat.block4(x3)
+ x5 = xfeat.block5(x4)
+ x4 = F.interpolate(x4, (x3.shape[-2], x3.shape[-1]), mode='bilinear')
+ x5 = F.interpolate(x5, (x3.shape[-2], x3.shape[-1]), mode='bilinear')
+ feats = xfeat.block_fusion( x3 + x4 + x5 )
+ if self.freeze_xfeat:
+ return x2.clone(), feats.clone()
+ return x2, feats
+
+ def to_pixel_coordinates(self, coords, H_A, W_A, H_B = None, W_B = None):
+ if coords.shape[-1] == 2:
+ return self._to_pixel_coordinates(coords, H_A, W_A)
+
+ if isinstance(coords, (list, tuple)):
+ kpts_A, kpts_B = coords[0], coords[1]
+ else:
+ kpts_A, kpts_B = coords[...,:2], coords[...,2:]
+ return self._to_pixel_coordinates(kpts_A, H_A, W_A), self._to_pixel_coordinates(kpts_B, H_B, W_B)
+
+ def _to_pixel_coordinates(self, coords, H, W):
+ kpts = torch.stack((W/2 * (coords[...,0]+1), H/2 * (coords[...,1]+1)),axis=-1)
+ return kpts
+
+ def pos_embed(self, corr_volume: torch.Tensor):
+ B, H1, W1, H0, W0 = corr_volume.shape
+ grid = torch.stack(
+ torch.meshgrid(
+ torch.linspace(-1+1/W1,1-1/W1, W1),
+ torch.linspace(-1+1/H1,1-1/H1, H1),
+ indexing = "xy"),
+ dim = -1).float().to(corr_volume).reshape(H1*W1, 2)
+ down = 4
+ if not self.training and not self.exact_softmax:
+ grid_lr = torch.stack(
+ torch.meshgrid(
+ torch.linspace(-1+down/W1,1-down/W1, W1//down),
+ torch.linspace(-1+down/H1,1-down/H1, H1//down),
+ indexing = "xy"),
+ dim = -1).float().to(corr_volume).reshape(H1*W1 //down**2, 2)
+ cv = corr_volume
+ best_match = cv.reshape(B,H1*W1,H0,W0).argmax(dim=1)  # (B, H0, W0) indices into the H1*W1 positions
+ P_lowres = torch.cat((cv[:,::down,::down].reshape(B,H1*W1 // down**2,H0,W0), best_match[:,None]),dim=1).softmax(dim=1)
+ pos_embeddings = torch.einsum('bchw,cd->bdhw', P_lowres[:,:-1], grid_lr)
+ pos_embeddings += P_lowres[:,-1] * grid[best_match].permute(0,3,1,2)
+ #print("hej")
+ else:
+ P = corr_volume.reshape(B,H1*W1,H0,W0).softmax(dim=1) # B, HW, H, W
+ pos_embeddings = torch.einsum('bchw,cd->bdhw', P, grid)
+ return pos_embeddings
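+ # The result is a soft-argmax over the correlation volume: a coarse warp of
+ # shape (B, 2, H0, W0) in normalized [-1, 1] coordinates of image B. The
+ # non-exact branch softmaxes over a strided subset of candidates plus the
+ # argmax entry to approximate the full softmax cheaply at inference time.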
+
+ def visualize_warp(self, warp, certainty, im_A = None, im_B = None,
+ im_A_path = None, im_B_path = None, symmetric = True, save_path = None, unnormalize = False):
+ device = warp.device
+ H,W2,_ = warp.shape
+ W = W2//2 if symmetric else W2
+ if im_A is None:
+ from PIL import Image
+ im_A, im_B = Image.open(im_A_path).convert("RGB"), Image.open(im_B_path).convert("RGB")
+ if not isinstance(im_A, torch.Tensor):
+ im_A = im_A.resize((W,H))
+ im_B = im_B.resize((W,H))
+ x_B = (torch.tensor(np.array(im_B)) / 255).to(device).permute(2, 0, 1)
+ if symmetric:
+ x_A = (torch.tensor(np.array(im_A)) / 255).to(device).permute(2, 0, 1)
+ else:
+ if symmetric:
+ x_A = im_A
+ x_B = im_B
+ im_A_transfer_rgb = F.grid_sample(
+ x_B[None], warp[:,:W, 2:][None], mode="bilinear", align_corners=False
+ )[0]
+ if symmetric:
+ im_B_transfer_rgb = F.grid_sample(
+ x_A[None], warp[:, W:, :2][None], mode="bilinear", align_corners=False
+ )[0]
+ warp_im = torch.cat((im_A_transfer_rgb,im_B_transfer_rgb),dim=2)
+ white_im = torch.ones((H,2*W),device=device)
+ else:
+ warp_im = im_A_transfer_rgb
+ white_im = torch.ones((H, W), device = device)
+ vis_im = certainty * warp_im + (1 - certainty) * white_im
+ if save_path is not None:
+ from romatch.utils import tensor_to_pil
+ tensor_to_pil(vis_im, unnormalize=unnormalize).save(save_path)
+ return vis_im
+
+ def corr_volume(self, feat0, feat1):
+ """
+ input:
+ feat0 -> torch.Tensor(B, C, H, W)
+ feat1 -> torch.Tensor(B, C, H, W)
+ return:
+ corr_volume -> torch.Tensor(B, H, W, H, W)
+ """
+ B, C, H0, W0 = feat0.shape
+ B, C, H1, W1 = feat1.shape
+ feat0 = feat0.view(B, C, H0*W0)
+ feat1 = feat1.view(B, C, H1*W1)
+ corr_volume = torch.einsum('bci,bcj->bji', feat0, feat1).reshape(B, H1, W1, H0 , W0)/math.sqrt(C) #16*16*16
+ return corr_volume
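+ # Each entry is the dot product between a feature in image 0 and a feature in
+ # image 1, scaled by 1/sqrt(C); e.g. 80x60 feature maps give roughly
+ # 80*60*80*60 ≈ 23M entries per batch element, so this is only done at the coarse scale.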
+
+ @torch.inference_mode()
+ def match_from_path(self, im0_path, im1_path):
+ device = self.device
+ im0 = ToTensor()(Image.open(im0_path))[None].to(device)
+ im1 = ToTensor()(Image.open(im1_path))[None].to(device)
+ return self.match(im0, im1, batched = False)
+
+ @torch.inference_mode()
+ def match(self, im0, im1, *args, batched = True):
+ # convenience: accept file paths or PIL images as well as tensors
+ if isinstance(im0, (str, Path)):
+ return self.match_from_path(im0, im1)
+ elif isinstance(im0, Image.Image):
+ batched = False
+ device = self.device
+ im0 = ToTensor()(im0)[None].to(device)
+ im1 = ToTensor()(im1)[None].to(device)
+
+ B,C,H0,W0 = im0.shape
+ B,C,H1,W1 = im1.shape
+ self.train(False)
+ corresps = self.forward({"im_A":im0, "im_B":im1})
+ flow = F.interpolate(
+ corresps[4]["flow"],
+ size = (H0, W0),
+ mode = "bilinear", align_corners = False).permute(0,2,3,1).reshape(B,H0,W0,2)
+ grid = torch.stack(
+ torch.meshgrid(
+ torch.linspace(-1+1/W0,1-1/W0, W0),
+ torch.linspace(-1+1/H0,1-1/H0, H0),
+ indexing = "xy"),
+ dim = -1).float().to(flow.device).expand(B, H0, W0, 2)
+
+ certainty = F.interpolate(corresps[4]["certainty"], size = (H0,W0), mode = "bilinear", align_corners = False)
+ warp, cert = torch.cat((grid, flow), dim = -1), certainty[:,0].sigmoid()
+ if batched:
+ return warp, cert
+ else:
+ return warp[0], cert[0]
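+ # In the non-batched case warp[y, x] = (x_A_norm, y_A_norm, x_B_norm, y_B_norm):
+ # a dense pixel-to-pixel correspondence field plus a per-pixel certainty in [0, 1].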
+
+ def sample(
+ self,
+ matches,
+ certainty,
+ num=5_000,
+ ):
+ H,W,_ = matches.shape
+ if "threshold" in self.sample_mode:
+ upper_thresh = self.sample_thresh
+ certainty = certainty.clone()
+ certainty[certainty > upper_thresh] = 1
+ matches, certainty = (
+ matches.reshape(-1, 4),
+ certainty.reshape(-1),
+ )
+ expansion_factor = 4 if "balanced" in self.sample_mode else 1
+ good_samples = torch.multinomial(certainty,
+ num_samples = min(expansion_factor*num, len(certainty)),
+ replacement=False)
+ good_matches, good_certainty = matches[good_samples], certainty[good_samples]
+ if "balanced" not in self.sample_mode:
+ return good_matches, good_certainty
+ use_half = True if matches.device.type == "cuda" else False
+ down = 1 if matches.device.type == "cuda" else 8
+ density = kde(good_matches, std=0.1, half = use_half, down = down)
+ p = 1 / (density+1)
+ p[density < 10] = 1e-7 # Basically should have at least 10 perfect neighbours, or around 100 ok ones
+ balanced_samples = torch.multinomial(p,
+ num_samples = min(num,len(good_certainty)),
+ replacement=False)
+ return good_matches[balanced_samples], good_certainty[balanced_samples]
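+ # "balanced" mode reweights the thresholded samples by inverse KDE density so
+ # matches spread over the image instead of clustering in high-certainty regions;
+ # 4x more candidates are drawn first, then thinned down to `num`.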
+
+
+ def forward(self, batch):
+ """
+ input:
+ batch -> dict with "im_A", "im_B": torch.Tensor(B, C, H, W) grayscale or rgb images
+ return:
+ corresps -> dict {8: ..., 4: ...} of per-scale "flow" and "certainty" maps in normalized coordinates
+ """
+ im0 = batch["im_A"]
+ im1 = batch["im_B"]
+ corresps = {}
+ im0, rh0, rw0 = self.preprocess_tensor(im0)
+ im1, rh1, rw1 = self.preprocess_tensor(im1)
+ B, C, H0, W0 = im0.shape
+ B, C, H1, W1 = im1.shape
+ to_normalized = torch.tensor((2/W1, 2/H1, 1)).to(im0.device)[None,:,None,None]
+
+ if im0.shape[-2:] == im1.shape[-2:]:
+ x = torch.cat([im0, im1], dim=0)
+ x = self.forward_single(x)
+ feats_x0_c, feats_x1_c = x[1].chunk(2)
+ feats_x0_f, feats_x1_f = x[0].chunk(2)
+ else:
+ feats_x0_f, feats_x0_c = self.forward_single(im0)
+ feats_x1_f, feats_x1_c = self.forward_single(im1)
+ corr_volume = self.corr_volume(feats_x0_c, feats_x1_c)
+ coarse_warp = self.pos_embed(corr_volume)
+ coarse_matches = torch.cat((coarse_warp, torch.zeros_like(coarse_warp[:,-1:])), dim=1)
+ feats_x1_c_warped = F.grid_sample(feats_x1_c, coarse_matches.permute(0, 2, 3, 1)[...,:2], mode = 'bilinear', align_corners = False)
+ coarse_matches_delta = self.coarse_matcher(torch.cat((feats_x0_c, feats_x1_c_warped, coarse_warp), dim=1))
+ coarse_matches = coarse_matches + coarse_matches_delta * to_normalized
+ corresps[8] = {"flow": coarse_matches[:,:2], "certainty": coarse_matches[:,2:]}
+ coarse_matches_up = F.interpolate(coarse_matches, size = feats_x0_f.shape[-2:], mode = "bilinear", align_corners = False)
+ coarse_matches_up_detach = coarse_matches_up.detach()  # detach so fine-matcher gradients do not flow into the coarse stage
+ feats_x1_f_warped = F.grid_sample(feats_x1_f, coarse_matches_up_detach.permute(0, 2, 3, 1)[...,:2], mode = 'bilinear', align_corners = False)
+ fine_matches_delta = self.fine_matcher(torch.cat((feats_x0_f, feats_x1_f_warped, coarse_matches_up_detach[:,:2]), dim=1))
+ fine_matches = coarse_matches_up_detach+fine_matches_delta * to_normalized
+ corresps[4] = {"flow": fine_matches[:,:2], "certainty": fine_matches[:,2:]}
+ return corresps
\ No newline at end of file
diff --git a/imcui/third_party/RoMa/romatch/models/transformer/__init__.py b/imcui/third_party/RoMa/romatch/models/transformer/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..983f03ccc51cdbcef6166a160fe50652a81418d7
--- /dev/null
+++ b/imcui/third_party/RoMa/romatch/models/transformer/__init__.py
@@ -0,0 +1,48 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from romatch.utils.utils import get_grid, get_autocast_params
+from .layers.block import Block
+from .layers.attention import MemEffAttention
+from .dinov2 import vit_large
+
+class TransformerDecoder(nn.Module):
+ def __init__(self, blocks, hidden_dim, out_dim, is_classifier = False, *args,
+ amp = False, pos_enc = True, learned_embeddings = False, embedding_dim = None, amp_dtype = torch.float16, **kwargs) -> None:
+ super().__init__(*args, **kwargs)
+ self.blocks = blocks
+ self.to_out = nn.Linear(hidden_dim, out_dim)
+ self.hidden_dim = hidden_dim
+ self.out_dim = out_dim
+ self._scales = [16]
+ self.is_classifier = is_classifier
+ self.amp = amp
+ self.amp_dtype = amp_dtype
+ self.pos_enc = pos_enc
+ self.learned_embeddings = learned_embeddings
+ if self.learned_embeddings:
+ self.learned_pos_embeddings = nn.Parameter(nn.init.kaiming_normal_(torch.empty((1, hidden_dim, embedding_dim, embedding_dim))))
+
+ def scales(self):
+ return self._scales.copy()
+
+ def forward(self, gp_posterior, features, old_stuff, new_scale):
+ autocast_device, autocast_enabled, autocast_dtype = get_autocast_params(gp_posterior.device, enabled=self.amp, dtype=self.amp_dtype)
+ with torch.autocast(autocast_device, enabled=autocast_enabled, dtype = autocast_dtype):
+ B,C,H,W = gp_posterior.shape
+ x = torch.cat((gp_posterior, features), dim = 1)
+ B,C,H,W = x.shape
+ grid = get_grid(B, H, W, x.device).reshape(B,H*W,2)
+ if self.learned_embeddings:
+ pos_enc = F.interpolate(self.learned_pos_embeddings, size = (H,W), mode = 'bilinear', align_corners = False).permute(0,2,3,1).reshape(1,H*W,C)
+ else:
+ pos_enc = 0
+ tokens = x.reshape(B,C,H*W).permute(0,2,1) + pos_enc
+ z = self.blocks(tokens)
+ out = self.to_out(z)
+ out = out.permute(0,2,1).reshape(B, self.out_dim, H, W)
+ warp, certainty = out[:, :-1], out[:, -1:]
+ return warp, certainty, None
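+ # Output shapes: warp is (B, out_dim - 1, H, W) and certainty is (B, 1, H, W).
+ # With is_classifier=True (as configured in roma_model, cls_to_coord_res = 64)
+ # the warp channels are classification logits over a 64x64 coordinate grid
+ # rather than regressed coordinates.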
+
+
diff --git a/imcui/third_party/RoMa/romatch/models/transformer/dinov2.py b/imcui/third_party/RoMa/romatch/models/transformer/dinov2.py
new file mode 100644
index 0000000000000000000000000000000000000000..b556c63096d17239c8603d5fe626c331963099fd
--- /dev/null
+++ b/imcui/third_party/RoMa/romatch/models/transformer/dinov2.py
@@ -0,0 +1,359 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
+
+from functools import partial
+import math
+import logging
+from typing import Sequence, Tuple, Union, Callable
+
+import torch
+import torch.nn as nn
+import torch.utils.checkpoint
+from torch.nn.init import trunc_normal_
+
+from .layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block
+
+
+
+def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module:
+ if not depth_first and include_root:
+ fn(module=module, name=name)
+ for child_name, child_module in module.named_children():
+ child_name = ".".join((name, child_name)) if name else child_name
+ named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True)
+ if depth_first and include_root:
+ fn(module=module, name=name)
+ return module
+
+
+class BlockChunk(nn.ModuleList):
+ def forward(self, x):
+ for b in self:
+ x = b(x)
+ return x
+
+
+class DinoVisionTransformer(nn.Module):
+ def __init__(
+ self,
+ img_size=224,
+ patch_size=16,
+ in_chans=3,
+ embed_dim=768,
+ depth=12,
+ num_heads=12,
+ mlp_ratio=4.0,
+ qkv_bias=True,
+ ffn_bias=True,
+ proj_bias=True,
+ drop_path_rate=0.0,
+ drop_path_uniform=False,
+ init_values=None, # for layerscale: None or 0 => no layerscale
+ embed_layer=PatchEmbed,
+ act_layer=nn.GELU,
+ block_fn=Block,
+ ffn_layer="mlp",
+ block_chunks=1,
+ ):
+ """
+ Args:
+ img_size (int, tuple): input image size
+ patch_size (int, tuple): patch size
+ in_chans (int): number of input channels
+ embed_dim (int): embedding dimension
+ depth (int): depth of transformer
+ num_heads (int): number of attention heads
+ mlp_ratio (int): ratio of mlp hidden dim to embedding dim
+ qkv_bias (bool): enable bias for qkv if True
+ proj_bias (bool): enable bias for proj in attn if True
+ ffn_bias (bool): enable bias for ffn if True
+ drop_path_rate (float): stochastic depth rate
+ drop_path_uniform (bool): apply uniform drop rate across blocks
+ init_values (float): layer-scale init values
+ embed_layer (nn.Module): patch embedding layer
+ act_layer (nn.Module): MLP activation layer
+ block_fn (nn.Module): transformer block class
+ ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
+ block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
+ """
+ super().__init__()
+ norm_layer = partial(nn.LayerNorm, eps=1e-6)
+
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
+ self.num_tokens = 1
+ self.n_blocks = depth
+ self.num_heads = num_heads
+ self.patch_size = patch_size
+
+ self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
+ num_patches = self.patch_embed.num_patches
+
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
+
+ if drop_path_uniform is True:
+ dpr = [drop_path_rate] * depth
+ else:
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
+
+ if ffn_layer == "mlp":
+ ffn_layer = Mlp
+ elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
+ ffn_layer = SwiGLUFFNFused
+ elif ffn_layer == "identity":
+
+ def f(*args, **kwargs):
+ return nn.Identity()
+
+ ffn_layer = f
+ else:
+ raise NotImplementedError
+
+ blocks_list = [
+ block_fn(
+ dim=embed_dim,
+ num_heads=num_heads,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ proj_bias=proj_bias,
+ ffn_bias=ffn_bias,
+ drop_path=dpr[i],
+ norm_layer=norm_layer,
+ act_layer=act_layer,
+ ffn_layer=ffn_layer,
+ init_values=init_values,
+ )
+ for i in range(depth)
+ ]
+ if block_chunks > 0:
+ self.chunked_blocks = True
+ chunked_blocks = []
+ chunksize = depth // block_chunks
+ for i in range(0, depth, chunksize):
+ # this is to keep the block index consistent if we chunk the block list
+ chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize])
+ self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
+ else:
+ self.chunked_blocks = False
+ self.blocks = nn.ModuleList(blocks_list)
+
+ self.norm = norm_layer(embed_dim)
+ self.head = nn.Identity()
+
+ self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))
+
+ self.init_weights()
+ for param in self.parameters():
+ param.requires_grad = False
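+ # the DINOv2 backbone is used frozen here: every parameter is set to
+ # requires_grad=False right after initialization.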
+
+ @property
+ def device(self):
+ return self.cls_token.device
+
+ def init_weights(self):
+ trunc_normal_(self.pos_embed, std=0.02)
+ nn.init.normal_(self.cls_token, std=1e-6)
+ named_apply(init_weights_vit_timm, self)
+
+ def interpolate_pos_encoding(self, x, w, h):
+ previous_dtype = x.dtype
+ npatch = x.shape[1] - 1
+ N = self.pos_embed.shape[1] - 1
+ if npatch == N and w == h:
+ return self.pos_embed
+ pos_embed = self.pos_embed.float()
+ class_pos_embed = pos_embed[:, 0]
+ patch_pos_embed = pos_embed[:, 1:]
+ dim = x.shape[-1]
+ w0 = w // self.patch_size
+ h0 = h // self.patch_size
+ # we add a small number to avoid floating point error in the interpolation
+ # see discussion at https://github.com/facebookresearch/dino/issues/8
+ w0, h0 = w0 + 0.1, h0 + 0.1
+
+ patch_pos_embed = nn.functional.interpolate(
+ patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
+ scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
+ mode="bicubic",
+ )
+
+ assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
+ return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)
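+ # Example with the defaults in this file: a pos_embed trained for 224/16 = 14x14
+ # patches is bicubically resized to 28x28 for a 448x448 input, while the class
+ # token embedding is kept unchanged.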
+
+ def prepare_tokens_with_masks(self, x, masks=None):
+ B, nc, w, h = x.shape
+ x = self.patch_embed(x)
+ if masks is not None:
+ x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)
+
+ x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
+ x = x + self.interpolate_pos_encoding(x, w, h)
+
+ return x
+
+ def forward_features_list(self, x_list, masks_list):
+ x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)]
+ for blk in self.blocks:
+ x = blk(x)
+
+ all_x = x
+ output = []
+ for x, masks in zip(all_x, masks_list):
+ x_norm = self.norm(x)
+ output.append(
+ {
+ "x_norm_clstoken": x_norm[:, 0],
+ "x_norm_patchtokens": x_norm[:, 1:],
+ "x_prenorm": x,
+ "masks": masks,
+ }
+ )
+ return output
+
+ def forward_features(self, x, masks=None):
+ if isinstance(x, list):
+ return self.forward_features_list(x, masks)
+
+ x = self.prepare_tokens_with_masks(x, masks)
+
+ for blk in self.blocks:
+ x = blk(x)
+
+ x_norm = self.norm(x)
+ return {
+ "x_norm_clstoken": x_norm[:, 0],
+ "x_norm_patchtokens": x_norm[:, 1:],
+ "x_prenorm": x,
+ "masks": masks,
+ }
+
+ def _get_intermediate_layers_not_chunked(self, x, n=1):
+ x = self.prepare_tokens_with_masks(x)
+ # If n is an int, take the n last blocks. If it's a list, take them
+ output, total_block_len = [], len(self.blocks)
+ blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
+ for i, blk in enumerate(self.blocks):
+ x = blk(x)
+ if i in blocks_to_take:
+ output.append(x)
+ assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
+ return output
+
+ def _get_intermediate_layers_chunked(self, x, n=1):
+ x = self.prepare_tokens_with_masks(x)
+ output, i, total_block_len = [], 0, len(self.blocks[-1])
+ # If n is an int, take the n last blocks. If it's a list, take them
+ blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
+ for block_chunk in self.blocks:
+ for blk in block_chunk[i:]: # Passing the nn.Identity()
+ x = blk(x)
+ if i in blocks_to_take:
+ output.append(x)
+ i += 1
+ assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
+ return output
+
+ def get_intermediate_layers(
+ self,
+ x: torch.Tensor,
+ n: Union[int, Sequence] = 1, # Layers or n last layers to take
+ reshape: bool = False,
+ return_class_token: bool = False,
+ norm=True,
+ ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
+ if self.chunked_blocks:
+ outputs = self._get_intermediate_layers_chunked(x, n)
+ else:
+ outputs = self._get_intermediate_layers_not_chunked(x, n)
+ if norm:
+ outputs = [self.norm(out) for out in outputs]
+ class_tokens = [out[:, 0] for out in outputs]
+ outputs = [out[:, 1:] for out in outputs]
+ if reshape:
+ B, _, w, h = x.shape
+ outputs = [
+ out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous()
+ for out in outputs
+ ]
+ if return_class_token:
+ return tuple(zip(outputs, class_tokens))
+ return tuple(outputs)
+
+ def forward(self, *args, is_training=False, **kwargs):
+ ret = self.forward_features(*args, **kwargs)
+ if is_training:
+ return ret
+ else:
+ return self.head(ret["x_norm_clstoken"])
+
+
+def init_weights_vit_timm(module: nn.Module, name: str = ""):
+ """ViT weight initialization, original timm impl (for reproducibility)"""
+ if isinstance(module, nn.Linear):
+ trunc_normal_(module.weight, std=0.02)
+ if module.bias is not None:
+ nn.init.zeros_(module.bias)
+
+
+def vit_small(patch_size=16, **kwargs):
+ model = DinoVisionTransformer(
+ patch_size=patch_size,
+ embed_dim=384,
+ depth=12,
+ num_heads=6,
+ mlp_ratio=4,
+ block_fn=partial(Block, attn_class=MemEffAttention),
+ **kwargs,
+ )
+ return model
+
+
+def vit_base(patch_size=16, **kwargs):
+ model = DinoVisionTransformer(
+ patch_size=patch_size,
+ embed_dim=768,
+ depth=12,
+ num_heads=12,
+ mlp_ratio=4,
+ block_fn=partial(Block, attn_class=MemEffAttention),
+ **kwargs,
+ )
+ return model
+
+
+def vit_large(patch_size=16, **kwargs):
+ model = DinoVisionTransformer(
+ patch_size=patch_size,
+ embed_dim=1024,
+ depth=24,
+ num_heads=16,
+ mlp_ratio=4,
+ block_fn=partial(Block, attn_class=MemEffAttention),
+ **kwargs,
+ )
+ return model
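+
+# vit_large (embed_dim 1024) is the variant re-exported by romatch.models.transformer;
+# its 1024-dim patch features match the proj16 Conv2d(1024, 512, ...) in roma_models.py.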
+
+
+def vit_giant2(patch_size=16, **kwargs):
+ """
+ Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64
+ """
+ model = DinoVisionTransformer(
+ patch_size=patch_size,
+ embed_dim=1536,
+ depth=40,
+ num_heads=24,
+ mlp_ratio=4,
+ block_fn=partial(Block, attn_class=MemEffAttention),
+ **kwargs,
+ )
+ return model
\ No newline at end of file
diff --git a/imcui/third_party/RoMa/romatch/models/transformer/layers/__init__.py b/imcui/third_party/RoMa/romatch/models/transformer/layers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..31f196aacac5be8a7c537a3dfa8f97084671b466
--- /dev/null
+++ b/imcui/third_party/RoMa/romatch/models/transformer/layers/__init__.py
@@ -0,0 +1,12 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from .dino_head import DINOHead
+from .mlp import Mlp
+from .patch_embed import PatchEmbed
+from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused
+from .block import NestedTensorBlock
+from .attention import MemEffAttention
diff --git a/imcui/third_party/RoMa/romatch/models/transformer/layers/attention.py b/imcui/third_party/RoMa/romatch/models/transformer/layers/attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..1f9b0c94b40967dfdff4f261c127cbd21328c905
--- /dev/null
+++ b/imcui/third_party/RoMa/romatch/models/transformer/layers/attention.py
@@ -0,0 +1,81 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
+
+import logging
+
+from torch import Tensor
+from torch import nn
+
+
+logger = logging.getLogger("dinov2")
+
+
+try:
+ from xformers.ops import memory_efficient_attention, unbind, fmha
+
+ XFORMERS_AVAILABLE = True
+except ImportError:
+ logger.warning("xFormers not available")
+ XFORMERS_AVAILABLE = False
+
+
+class Attention(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int = 8,
+ qkv_bias: bool = False,
+ proj_bias: bool = True,
+ attn_drop: float = 0.0,
+ proj_drop: float = 0.0,
+ ) -> None:
+ super().__init__()
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.scale = head_dim**-0.5
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim, bias=proj_bias)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ def forward(self, x: Tensor) -> Tensor:
+ B, N, C = x.shape
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+
+ q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
+ attn = q @ k.transpose(-2, -1)
+
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
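+ # The attention matrix materialized here is (B, num_heads, N, N), i.e. memory
+ # grows quadratically with the token count; MemEffAttention below uses xFormers'
+ # memory_efficient_attention to avoid materializing it when available.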
+
+
+class MemEffAttention(Attention):
+ def forward(self, x: Tensor, attn_bias=None) -> Tensor:
+ if not XFORMERS_AVAILABLE:
+ assert attn_bias is None, "xFormers is required for nested tensors usage"
+ return super().forward(x)
+
+ B, N, C = x.shape
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
+
+ q, k, v = unbind(qkv, 2)
+
+ x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
+ x = x.reshape([B, N, C])
+
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
diff --git a/imcui/third_party/RoMa/romatch/models/transformer/layers/block.py b/imcui/third_party/RoMa/romatch/models/transformer/layers/block.py
new file mode 100644
index 0000000000000000000000000000000000000000..25488f57cc0ad3c692f86b62555f6668e2a66db1
--- /dev/null
+++ b/imcui/third_party/RoMa/romatch/models/transformer/layers/block.py
@@ -0,0 +1,252 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
+
+import logging
+from typing import Callable, List, Any, Tuple, Dict
+
+import torch
+from torch import nn, Tensor
+
+from .attention import Attention, MemEffAttention
+from .drop_path import DropPath
+from .layer_scale import LayerScale
+from .mlp import Mlp
+
+
+logger = logging.getLogger("dinov2")
+
+
+try:
+ from xformers.ops import fmha
+ from xformers.ops import scaled_index_add, index_select_cat
+
+ XFORMERS_AVAILABLE = True
+except ImportError:
+ logger.warning("xFormers not available")
+ XFORMERS_AVAILABLE = False
+
+
+class Block(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int,
+ mlp_ratio: float = 4.0,
+ qkv_bias: bool = False,
+ proj_bias: bool = True,
+ ffn_bias: bool = True,
+ drop: float = 0.0,
+ attn_drop: float = 0.0,
+ init_values=None,
+ drop_path: float = 0.0,
+ act_layer: Callable[..., nn.Module] = nn.GELU,
+ norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
+ attn_class: Callable[..., nn.Module] = Attention,
+ ffn_layer: Callable[..., nn.Module] = Mlp,
+ ) -> None:
+ super().__init__()
+ # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}")
+ self.norm1 = norm_layer(dim)
+ self.attn = attn_class(
+ dim,
+ num_heads=num_heads,
+ qkv_bias=qkv_bias,
+ proj_bias=proj_bias,
+ attn_drop=attn_drop,
+ proj_drop=drop,
+ )
+ self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
+ self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = ffn_layer(
+ in_features=dim,
+ hidden_features=mlp_hidden_dim,
+ act_layer=act_layer,
+ drop=drop,
+ bias=ffn_bias,
+ )
+ self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
+ self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+ self.sample_drop_ratio = drop_path
+
+ def forward(self, x: Tensor) -> Tensor:
+ def attn_residual_func(x: Tensor) -> Tensor:
+ return self.ls1(self.attn(self.norm1(x)))
+
+ def ffn_residual_func(x: Tensor) -> Tensor:
+ return self.ls2(self.mlp(self.norm2(x)))
+
+ if self.training and self.sample_drop_ratio > 0.1:
+ # the overhead is compensated only for a drop path rate larger than 0.1
+ x = drop_add_residual_stochastic_depth(
+ x,
+ residual_func=attn_residual_func,
+ sample_drop_ratio=self.sample_drop_ratio,
+ )
+ x = drop_add_residual_stochastic_depth(
+ x,
+ residual_func=ffn_residual_func,
+ sample_drop_ratio=self.sample_drop_ratio,
+ )
+ elif self.training and self.sample_drop_ratio > 0.0:
+ x = x + self.drop_path1(attn_residual_func(x))
+ x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2
+ else:
+ x = x + attn_residual_func(x)
+ x = x + ffn_residual_func(x)
+ return x
+
+
+def drop_add_residual_stochastic_depth(
+ x: Tensor,
+ residual_func: Callable[[Tensor], Tensor],
+ sample_drop_ratio: float = 0.0,
+) -> Tensor:
+ # 1) extract subset using permutation
+ b, n, d = x.shape
+ sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
+ brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
+ x_subset = x[brange]
+
+ # 2) apply residual_func to get residual
+ residual = residual_func(x_subset)
+
+ x_flat = x.flatten(1)
+ residual = residual.flatten(1)
+
+ residual_scale_factor = b / sample_subset_size
+
+ # 3) add the residual
+ x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
+ return x_plus_residual.view_as(x)
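+ # Example: with batch size b = 8 and sample_drop_ratio = 0.25, only 6 samples
+ # receive a residual, and those residuals are scaled by 8/6 ≈ 1.33 so the
+ # expected update matches plain (non-stochastic) residual addition.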
+
+
+def get_branges_scales(x, sample_drop_ratio=0.0):
+ b, n, d = x.shape
+ sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
+ brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
+ residual_scale_factor = b / sample_subset_size
+ return brange, residual_scale_factor
+
+
+def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
+ if scaling_vector is None:
+ x_flat = x.flatten(1)
+ residual = residual.flatten(1)
+ x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
+ else:
+ x_plus_residual = scaled_index_add(
+ x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor
+ )
+ return x_plus_residual
+
+
+attn_bias_cache: Dict[Tuple, Any] = {}
+
+
+def get_attn_bias_and_cat(x_list, branges=None):
+ """
+ this will perform the index select, cat the tensors, and provide the attn_bias from cache
+ """
+ batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list]
+ all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
+ if all_shapes not in attn_bias_cache.keys():
+ seqlens = []
+ for b, x in zip(batch_sizes, x_list):
+ for _ in range(b):
+ seqlens.append(x.shape[1])
+ attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
+ attn_bias._batch_sizes = batch_sizes
+ attn_bias_cache[all_shapes] = attn_bias
+
+ if branges is not None:
+ cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1])
+ else:
+ tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
+ cat_tensors = torch.cat(tensors_bs1, dim=1)
+
+ return attn_bias_cache[all_shapes], cat_tensors
+
+
+def drop_add_residual_stochastic_depth_list(
+ x_list: List[Tensor],
+ residual_func: Callable[[Tensor, Any], Tensor],
+ sample_drop_ratio: float = 0.0,
+ scaling_vector=None,
+) -> Tensor:
+ # 1) generate random set of indices for dropping samples in the batch
+ branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]
+ branges = [s[0] for s in branges_scales]
+ residual_scale_factors = [s[1] for s in branges_scales]
+
+ # 2) get attention bias and index+concat the tensors
+ attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)
+
+ # 3) apply residual_func to get residual, and split the result
+ residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore
+
+ outputs = []
+ for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors):
+ outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x))
+ return outputs
+
+
+class NestedTensorBlock(Block):
+ def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
+ """
+ x_list contains a list of tensors to nest together and run
+ """
+ assert isinstance(self.attn, MemEffAttention)
+
+ if self.training and self.sample_drop_ratio > 0.0:
+
+ def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
+ return self.attn(self.norm1(x), attn_bias=attn_bias)
+
+ def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
+ return self.mlp(self.norm2(x))
+
+ x_list = drop_add_residual_stochastic_depth_list(
+ x_list,
+ residual_func=attn_residual_func,
+ sample_drop_ratio=self.sample_drop_ratio,
+ scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None,
+ )
+ x_list = drop_add_residual_stochastic_depth_list(
+ x_list,
+ residual_func=ffn_residual_func,
+ sample_drop_ratio=self.sample_drop_ratio,
+ scaling_vector=self.ls2.gamma if isinstance(self.ls1, LayerScale) else None,
+ )
+ return x_list
+ else:
+
+ def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
+ return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))
+
+ def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
+ return self.ls2(self.mlp(self.norm2(x)))
+
+ attn_bias, x = get_attn_bias_and_cat(x_list)
+ x = x + attn_residual_func(x, attn_bias=attn_bias)
+ x = x + ffn_residual_func(x)
+ return attn_bias.split(x)
+
+ def forward(self, x_or_x_list):
+ if isinstance(x_or_x_list, Tensor):
+ return super().forward(x_or_x_list)
+ elif isinstance(x_or_x_list, list):
+ assert XFORMERS_AVAILABLE, "Please install xFormers for nested tensors usage"
+ return self.forward_nested(x_or_x_list)
+ else:
+ raise AssertionError
diff --git a/imcui/third_party/RoMa/romatch/models/transformer/layers/dino_head.py b/imcui/third_party/RoMa/romatch/models/transformer/layers/dino_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..7212db92a4fd8d4c7230e284e551a0234e9d8623
--- /dev/null
+++ b/imcui/third_party/RoMa/romatch/models/transformer/layers/dino_head.py
@@ -0,0 +1,59 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn as nn
+from torch.nn.init import trunc_normal_
+from torch.nn.utils import weight_norm
+
+
+class DINOHead(nn.Module):
+ def __init__(
+ self,
+ in_dim,
+ out_dim,
+ use_bn=False,
+ nlayers=3,
+ hidden_dim=2048,
+ bottleneck_dim=256,
+ mlp_bias=True,
+ ):
+ super().__init__()
+ nlayers = max(nlayers, 1)
+ self.mlp = _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=hidden_dim, use_bn=use_bn, bias=mlp_bias)
+ self.apply(self._init_weights)
+ self.last_layer = weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
+ self.last_layer.weight_g.data.fill_(1)
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=0.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+
+ def forward(self, x):
+ x = self.mlp(x)
+ eps = 1e-6 if x.dtype == torch.float16 else 1e-12
+ x = nn.functional.normalize(x, dim=-1, p=2, eps=eps)
+ x = self.last_layer(x)
+ return x
+
+
+def _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=None, use_bn=False, bias=True):
+ if nlayers == 1:
+ return nn.Linear(in_dim, bottleneck_dim, bias=bias)
+ else:
+ layers = [nn.Linear(in_dim, hidden_dim, bias=bias)]
+ if use_bn:
+ layers.append(nn.BatchNorm1d(hidden_dim))
+ layers.append(nn.GELU())
+ for _ in range(nlayers - 2):
+ layers.append(nn.Linear(hidden_dim, hidden_dim, bias=bias))
+ if use_bn:
+ layers.append(nn.BatchNorm1d(hidden_dim))
+ layers.append(nn.GELU())
+ layers.append(nn.Linear(hidden_dim, bottleneck_dim, bias=bias))
+ return nn.Sequential(*layers)
diff --git a/imcui/third_party/RoMa/romatch/models/transformer/layers/drop_path.py b/imcui/third_party/RoMa/romatch/models/transformer/layers/drop_path.py
new file mode 100644
index 0000000000000000000000000000000000000000..af05625984dd14682cc96a63bf0c97bab1f123b1
--- /dev/null
+++ b/imcui/third_party/RoMa/romatch/models/transformer/layers/drop_path.py
@@ -0,0 +1,35 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py
+
+
+from torch import nn
+
+
+def drop_path(x, drop_prob: float = 0.0, training: bool = False):
+ if drop_prob == 0.0 or not training:
+ return x
+ keep_prob = 1 - drop_prob
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
+ random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
+ if keep_prob > 0.0:
+ random_tensor.div_(keep_prob)
+ output = x * random_tensor
+ return output
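+ # Example: drop_prob = 0.1 keeps each sample's residual path with probability
+ # 0.9 and rescales survivors by 1/0.9 ≈ 1.11, so the expected output is unchanged.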
+
+
+class DropPath(nn.Module):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
+
+ def __init__(self, drop_prob=None):
+ super(DropPath, self).__init__()
+ self.drop_prob = drop_prob
+
+ def forward(self, x):
+ return drop_path(x, self.drop_prob, self.training)
diff --git a/imcui/third_party/RoMa/romatch/models/transformer/layers/layer_scale.py b/imcui/third_party/RoMa/romatch/models/transformer/layers/layer_scale.py
new file mode 100644
index 0000000000000000000000000000000000000000..ca5daa52bd81d3581adeb2198ea5b7dba2a3aea1
--- /dev/null
+++ b/imcui/third_party/RoMa/romatch/models/transformer/layers/layer_scale.py
@@ -0,0 +1,28 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110
+
+from typing import Union
+
+import torch
+from torch import Tensor
+from torch import nn
+
+
+class LayerScale(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ init_values: Union[float, Tensor] = 1e-5,
+ inplace: bool = False,
+ ) -> None:
+ super().__init__()
+ self.inplace = inplace
+ self.gamma = nn.Parameter(init_values * torch.ones(dim))
+
+ def forward(self, x: Tensor) -> Tensor:
+ return x.mul_(self.gamma) if self.inplace else x * self.gamma
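+
+# With the default init_values = 1e-5 each residual branch starts scaled close
+# to zero, so blocks begin near-identity; this LayerScale trick helps stabilize
+# training of deep ViTs.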
diff --git a/imcui/third_party/RoMa/romatch/models/transformer/layers/mlp.py b/imcui/third_party/RoMa/romatch/models/transformer/layers/mlp.py
new file mode 100644
index 0000000000000000000000000000000000000000..5e4b315f972f9a9f54aef1e4ef4e81b52976f018
--- /dev/null
+++ b/imcui/third_party/RoMa/romatch/models/transformer/layers/mlp.py
@@ -0,0 +1,41 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py
+
+
+from typing import Callable, Optional
+
+from torch import Tensor, nn
+
+
+class Mlp(nn.Module):
+ def __init__(
+ self,
+ in_features: int,
+ hidden_features: Optional[int] = None,
+ out_features: Optional[int] = None,
+ act_layer: Callable[..., nn.Module] = nn.GELU,
+ drop: float = 0.0,
+ bias: bool = True,
+ ) -> None:
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
+ self.act = act_layer()
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
+ self.drop = nn.Dropout(drop)
+
+ def forward(self, x: Tensor) -> Tensor:
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.drop(x)
+ x = self.fc2(x)
+ x = self.drop(x)
+ return x
diff --git a/imcui/third_party/RoMa/romatch/models/transformer/layers/patch_embed.py b/imcui/third_party/RoMa/romatch/models/transformer/layers/patch_embed.py
new file mode 100644
index 0000000000000000000000000000000000000000..574abe41175568d700a389b8b96d1ba554914779
--- /dev/null
+++ b/imcui/third_party/RoMa/romatch/models/transformer/layers/patch_embed.py
@@ -0,0 +1,89 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
+
+from typing import Callable, Optional, Tuple, Union
+
+from torch import Tensor
+import torch.nn as nn
+
+
+def make_2tuple(x):
+ if isinstance(x, tuple):
+ assert len(x) == 2
+ return x
+
+ assert isinstance(x, int)
+ return (x, x)
+
+
+class PatchEmbed(nn.Module):
+ """
+ 2D image to patch embedding: (B,C,H,W) -> (B,N,D)
+
+ Args:
+ img_size: Image size.
+ patch_size: Patch token size.
+ in_chans: Number of input image channels.
+ embed_dim: Number of linear projection output channels.
+ norm_layer: Normalization layer.
+ """
+
+ def __init__(
+ self,
+ img_size: Union[int, Tuple[int, int]] = 224,
+ patch_size: Union[int, Tuple[int, int]] = 16,
+ in_chans: int = 3,
+ embed_dim: int = 768,
+ norm_layer: Optional[Callable] = None,
+ flatten_embedding: bool = True,
+ ) -> None:
+ super().__init__()
+
+ image_HW = make_2tuple(img_size)
+ patch_HW = make_2tuple(patch_size)
+ patch_grid_size = (
+ image_HW[0] // patch_HW[0],
+ image_HW[1] // patch_HW[1],
+ )
+
+ self.img_size = image_HW
+ self.patch_size = patch_HW
+ self.patches_resolution = patch_grid_size
+ self.num_patches = patch_grid_size[0] * patch_grid_size[1]
+
+ self.in_chans = in_chans
+ self.embed_dim = embed_dim
+
+ self.flatten_embedding = flatten_embedding
+
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
+
+ def forward(self, x: Tensor) -> Tensor:
+ _, _, H, W = x.shape
+ patch_H, patch_W = self.patch_size
+
+ assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
+ assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}"
+
+ x = self.proj(x) # B C H W
+ H, W = x.size(2), x.size(3)
+ x = x.flatten(2).transpose(1, 2) # B HW C
+ x = self.norm(x)
+ if not self.flatten_embedding:
+ x = x.reshape(-1, H, W, self.embed_dim) # B H W C
+ return x
+
+ def flops(self) -> float:
+ Ho, Wo = self.patches_resolution
+ flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
+ if self.norm is not None:
+ flops += Ho * Wo * self.embed_dim
+ return flops
diff --git a/imcui/third_party/RoMa/romatch/models/transformer/layers/swiglu_ffn.py b/imcui/third_party/RoMa/romatch/models/transformer/layers/swiglu_ffn.py
new file mode 100644
index 0000000000000000000000000000000000000000..b3324b266fb0a50ccf8c3a0ede2ae10ac4dfa03e
--- /dev/null
+++ b/imcui/third_party/RoMa/romatch/models/transformer/layers/swiglu_ffn.py
@@ -0,0 +1,63 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Callable, Optional
+
+from torch import Tensor, nn
+import torch.nn.functional as F
+
+
+class SwiGLUFFN(nn.Module):
+ def __init__(
+ self,
+ in_features: int,
+ hidden_features: Optional[int] = None,
+ out_features: Optional[int] = None,
+ act_layer: Callable[..., nn.Module] = None,
+ drop: float = 0.0,
+ bias: bool = True,
+ ) -> None:
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
+ self.w3 = nn.Linear(hidden_features, out_features, bias=bias)
+
+ def forward(self, x: Tensor) -> Tensor:
+ x12 = self.w12(x)
+ x1, x2 = x12.chunk(2, dim=-1)
+ hidden = F.silu(x1) * x2
+ return self.w3(hidden)
+
+
+try:
+ from xformers.ops import SwiGLU
+
+ XFORMERS_AVAILABLE = True
+except ImportError:
+ SwiGLU = SwiGLUFFN
+ XFORMERS_AVAILABLE = False
+
+
+class SwiGLUFFNFused(SwiGLU):
+ def __init__(
+ self,
+ in_features: int,
+ hidden_features: Optional[int] = None,
+ out_features: Optional[int] = None,
+ act_layer: Callable[..., nn.Module] = None,
+ drop: float = 0.0,
+ bias: bool = True,
+ ) -> None:
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
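+ # e.g. hidden_features = 4096 -> int(4096 * 2/3) = 2730 -> rounded up to the
+ # next multiple of 8 = 2736; the 2/3 factor compensates for SwiGLU's extra projection.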
+ super().__init__(
+ in_features=in_features,
+ hidden_features=hidden_features,
+ out_features=out_features,
+ bias=bias,
+ )
diff --git a/imcui/third_party/RoMa/romatch/train/__init__.py b/imcui/third_party/RoMa/romatch/train/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..90269dc0f345a575e0ba21f5afa34202c7e6b433
--- /dev/null
+++ b/imcui/third_party/RoMa/romatch/train/__init__.py
@@ -0,0 +1 @@
+from .train import train_k_epochs
diff --git a/imcui/third_party/RoMa/romatch/train/train.py b/imcui/third_party/RoMa/romatch/train/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..7cb02ed1e816fd39f174f76ec15bce49ae2a3da8
--- /dev/null
+++ b/imcui/third_party/RoMa/romatch/train/train.py
@@ -0,0 +1,102 @@
+from tqdm import tqdm
+from romatch.utils.utils import to_cuda
+import romatch
+import torch
+import wandb
+
+def log_param_statistics(named_parameters, norm_type = 2):
+ named_parameters = list(named_parameters)
+ grads = [p.grad for n, p in named_parameters if p.grad is not None]
+ weight_norms = [p.norm(p=norm_type) for n, p in named_parameters if p.grad is not None]
+ names = [n for n,p in named_parameters if p.grad is not None]
+ param_norm = torch.stack(weight_norms).norm(p=norm_type)
+ device = grads[0].device
+ grad_norms = torch.stack([torch.norm(g.detach(), norm_type).to(device) for g in grads])
+ nans_or_infs = torch.isinf(grad_norms) | torch.isnan(grad_norms)
+ nan_inf_names = [name for name, naninf in zip(names, nans_or_infs) if naninf]
+ total_grad_norm = torch.norm(grad_norms, norm_type)
+ if torch.any(nans_or_infs):
+ print(f"These params have nan or inf grads: {nan_inf_names}")
+ wandb.log({"grad_norm": total_grad_norm.item()}, step = romatch.GLOBAL_STEP)
+ wandb.log({"param_norm": param_norm.item()}, step = romatch.GLOBAL_STEP)
+
+def train_step(train_batch, model, objective, optimizer, grad_scaler, grad_clip_norm = 1.,**kwargs):
+ optimizer.zero_grad()
+ out = model(train_batch)
+ l = objective(out, train_batch)
+ grad_scaler.scale(l).backward()
+ grad_scaler.unscale_(optimizer)
+ log_param_statistics(model.named_parameters())
+ torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip_norm)  # clip total grad norm to grad_clip_norm (default 1.0)
+ grad_scaler.step(optimizer)
+ grad_scaler.update()
+ wandb.log({"grad_scale": grad_scaler._scale.item()}, step = romatch.GLOBAL_STEP)
+ if grad_scaler._scale < 1.:
+ grad_scaler._scale = torch.tensor(1.).to(grad_scaler._scale)
+ romatch.GLOBAL_STEP = romatch.GLOBAL_STEP + romatch.STEP_SIZE # increment global step
+ return {"train_out": out, "train_loss": l.item()}
+
+
+def train_k_steps(
+ n_0, k, dataloader, model, objective, optimizer, lr_scheduler, grad_scaler, progress_bar=True, grad_clip_norm = 1., warmup = None, ema_model = None, pbar_n_seconds = 1,
+):
+ for n in tqdm(range(n_0, n_0 + k), disable=(not progress_bar) or romatch.RANK > 0, mininterval=pbar_n_seconds):
+ batch = next(dataloader)
+ model.train(True)
+ batch = to_cuda(batch)
+ train_step(
+ train_batch=batch,
+ model=model,
+ objective=objective,
+ optimizer=optimizer,
+ lr_scheduler=lr_scheduler,
+ grad_scaler=grad_scaler,
+ n=n,
+ grad_clip_norm = grad_clip_norm,
+ )
+ if ema_model is not None:
+ ema_model.update()
+ if warmup is not None:
+ with warmup.dampening():
+ lr_scheduler.step()
+ else:
+ lr_scheduler.step()
+ [wandb.log({f"lr_group_{grp}": lr}) for grp, lr in enumerate(lr_scheduler.get_last_lr())]
+
+
+def train_epoch(
+ dataloader=None,
+ model=None,
+ objective=None,
+ optimizer=None,
+ lr_scheduler=None,
+ epoch=None,
+):
+ model.train(True)
+ print(f"At epoch {epoch}")
+ for batch in tqdm(dataloader, mininterval=5.0):
+ batch = to_cuda(batch)
+ train_step(
+ train_batch=batch, model=model, objective=objective, optimizer=optimizer
+ )
+ lr_scheduler.step()
+ return {
+ "model": model,
+ "optimizer": optimizer,
+ "lr_scheduler": lr_scheduler,
+ "epoch": epoch,
+ }
+
+
+def train_k_epochs(
+ start_epoch, end_epoch, dataloader, model, objective, optimizer, lr_scheduler
+):
+ for epoch in range(start_epoch, end_epoch + 1):
+ train_epoch(
+ dataloader=dataloader,
+ model=model,
+ objective=objective,
+ optimizer=optimizer,
+ lr_scheduler=lr_scheduler,
+ epoch=epoch,
+ )
diff --git a/imcui/third_party/RoMa/romatch/utils/__init__.py b/imcui/third_party/RoMa/romatch/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2709f5e586150289085a4e2cbd458bc443fab7f3
--- /dev/null
+++ b/imcui/third_party/RoMa/romatch/utils/__init__.py
@@ -0,0 +1,16 @@
+from .utils import (
+ pose_auc,
+ get_pose,
+ compute_relative_pose,
+ compute_pose_error,
+ estimate_pose,
+ estimate_pose_uncalibrated,
+ rotate_intrinsic,
+ get_tuple_transform_ops,
+ get_depth_tuple_transform_ops,
+ warp_kpts,
+ numpy_to_pil,
+ tensor_to_pil,
+ recover_pose,
+ signed_left_to_right_epipolar_distance,
+)
diff --git a/imcui/third_party/RoMa/romatch/utils/kde.py b/imcui/third_party/RoMa/romatch/utils/kde.py
new file mode 100644
index 0000000000000000000000000000000000000000..46ed2e5e106bbca93e703f39f3ad3af350666e34
--- /dev/null
+++ b/imcui/third_party/RoMa/romatch/utils/kde.py
@@ -0,0 +1,13 @@
+import torch
+
+
+def kde(x, std = 0.1, half = True, down = None):
+ # use a gaussian kernel to estimate density
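+    # A minimal usage sketch (illustrative only; the shapes below are assumptions,
+    # not taken from the calling code):
+    #     matches = torch.rand(1000, 4)                # candidate match coords
+    #     density = kde(matches, std=0.1, half=False)  # -> (1000,) densities
+    #     keep = density > density.median()            # drop isolated matches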
+ if half:
+ x = x.half() # Do it in half precision TODO: remove hardcoding
+ if down is not None:
+ scores = (-torch.cdist(x,x[::down])**2/(2*std**2)).exp()
+ else:
+ scores = (-torch.cdist(x,x)**2/(2*std**2)).exp()
+ density = scores.sum(dim=-1)
+ return density
\ No newline at end of file
diff --git a/imcui/third_party/RoMa/romatch/utils/local_correlation.py b/imcui/third_party/RoMa/romatch/utils/local_correlation.py
new file mode 100644
index 0000000000000000000000000000000000000000..fe1322a20bf82d0331159f958241cb87f75f4e21
--- /dev/null
+++ b/imcui/third_party/RoMa/romatch/utils/local_correlation.py
@@ -0,0 +1,48 @@
+import torch
+import torch.nn.functional as F
+
+def local_correlation(
+ feature0,
+ feature1,
+ local_radius,
+ padding_mode="zeros",
+ flow = None,
+ sample_mode = "bilinear",
+):
+ r = local_radius
+ K = (2*r+1)**2
+ B, c, h, w = feature0.size()
+ corr = torch.empty((B,K,h,w), device = feature0.device, dtype=feature0.dtype)
+ if flow is None:
+ # If flow is None, assume feature0 and feature1 are aligned
+ coords = torch.meshgrid(
+ (
+ torch.linspace(-1 + 1 / h, 1 - 1 / h, h, device=feature0.device),
+ torch.linspace(-1 + 1 / w, 1 - 1 / w, w, device=feature0.device),
+ ),
+ indexing = 'ij'
+ )
+ coords = torch.stack((coords[1], coords[0]), dim=-1)[
+ None
+ ].expand(B, h, w, 2)
+ else:
+ coords = flow.permute(0,2,3,1) # If using flow, sample around flow target.
+ local_window = torch.meshgrid(
+ (
+ torch.linspace(-2*local_radius/h, 2*local_radius/h, 2*r+1, device=feature0.device),
+ torch.linspace(-2*local_radius/w, 2*local_radius/w, 2*r+1, device=feature0.device),
+ ),
+ indexing = 'ij'
+ )
+ local_window = torch.stack((local_window[1], local_window[0]), dim=-1)[
+ None
+ ].expand(1, 2*r+1, 2*r+1, 2).reshape(1, (2*r+1)**2, 2)
+ for _ in range(B):
+ with torch.no_grad():
+ local_window_coords = (coords[_,:,:,None]+local_window[:,None,None]).reshape(1,h,w*(2*r+1)**2,2)
+ window_feature = F.grid_sample(
+ feature1[_:_+1], local_window_coords, padding_mode=padding_mode, align_corners=False, mode = sample_mode, #
+ )
+ window_feature = window_feature.reshape(c,h,w,(2*r+1)**2)
+ corr[_] = (feature0[_,...,None]/(c**.5)*window_feature).sum(dim=0).permute(2,0,1)
+ return corr
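+
+# Shape note and minimal sketch (illustrative only, not part of the original file):
+# for feature maps of shape (B, c, h, w) and local_radius r, the returned
+# correlation volume has shape (B, (2*r+1)**2, h, w).
+#
+#     f0 = torch.randn(2, 64, 40, 40)
+#     f1 = torch.randn(2, 64, 40, 40)
+#     corr = local_correlation(f0, f1, local_radius=4)  # -> (2, 81, 40, 40)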
diff --git a/imcui/third_party/RoMa/romatch/utils/transforms.py b/imcui/third_party/RoMa/romatch/utils/transforms.py
new file mode 100644
index 0000000000000000000000000000000000000000..ea6476bd816a31df36f7d1b5417853637b65474b
--- /dev/null
+++ b/imcui/third_party/RoMa/romatch/utils/transforms.py
@@ -0,0 +1,118 @@
+from typing import Dict
+import numpy as np
+import torch
+import kornia.augmentation as K
+from kornia.geometry.transform import warp_perspective
+
+# Adapted from Kornia
+class GeometricSequential:
+ def __init__(self, *transforms, align_corners=True) -> None:
+ self.transforms = transforms
+ self.align_corners = align_corners
+
+ def __call__(self, x, mode="bilinear"):
+ b, c, h, w = x.shape
+ M = torch.eye(3, device=x.device)[None].expand(b, 3, 3)
+ for t in self.transforms:
+ if np.random.rand() < t.p:
+ M = M.matmul(
+ t.compute_transformation(x, t.generate_parameters((b, c, h, w)), None)
+ )
+ return (
+ warp_perspective(
+ x, M, dsize=(h, w), mode=mode, align_corners=self.align_corners
+ ),
+ M,
+ )
+
+ def apply_transform(self, x, M, mode="bilinear"):
+ b, c, h, w = x.shape
+ return warp_perspective(
+ x, M, dsize=(h, w), align_corners=self.align_corners, mode=mode
+ )
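+
+    # Minimal usage sketch (illustrative only; the augmentation choice below is
+    # an assumption, not taken from the calling code):
+    #     seq = GeometricSequential(K.RandomPerspective(0.4, p=1.0))
+    #     warped, M = seq(images)                       # images: (B, C, H, W)
+    #     warped_depth = seq.apply_transform(depth, M)  # re-apply the same warp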
+
+
+class RandomPerspective(K.RandomPerspective):
+ def generate_parameters(self, batch_shape: torch.Size) -> Dict[str, torch.Tensor]:
+ distortion_scale = torch.as_tensor(
+ self.distortion_scale, device=self._device, dtype=self._dtype
+ )
+ return self.random_perspective_generator(
+ batch_shape[0],
+ batch_shape[-2],
+ batch_shape[-1],
+ distortion_scale,
+ self.same_on_batch,
+ self.device,
+ self.dtype,
+ )
+
+ def random_perspective_generator(
+ self,
+ batch_size: int,
+ height: int,
+ width: int,
+ distortion_scale: torch.Tensor,
+ same_on_batch: bool = False,
+ device: torch.device = torch.device("cpu"),
+ dtype: torch.dtype = torch.float32,
+ ) -> Dict[str, torch.Tensor]:
+ r"""Get parameters for ``perspective`` for a random perspective transform.
+
+ Args:
+ batch_size (int): the tensor batch size.
+ height (int) : height of the image.
+ width (int): width of the image.
+ distortion_scale (torch.Tensor): it controls the degree of distortion and ranges from 0 to 1.
+ same_on_batch (bool): apply the same transformation across the batch. Default: False.
+ device (torch.device): the device on which the random numbers will be generated. Default: cpu.
+ dtype (torch.dtype): the data type of the generated random numbers. Default: float32.
+
+ Returns:
+ params Dict[str, torch.Tensor]: parameters to be passed for transformation.
+ - start_points (torch.Tensor): element-wise perspective source areas with a shape of (B, 4, 2).
+ - end_points (torch.Tensor): element-wise perspective target areas with a shape of (B, 4, 2).
+
+ Note:
+ The generated random numbers are not reproducible across different devices and dtypes.
+ """
+ if not (distortion_scale.dim() == 0 and 0 <= distortion_scale <= 1):
+ raise AssertionError(
+ f"'distortion_scale' must be a scalar within [0, 1]. Got {distortion_scale}."
+ )
+ if not (
+ type(height) is int and height > 0 and type(width) is int and width > 0
+ ):
+ raise AssertionError(
+ f"'height' and 'width' must be integers. Got {height}, {width}."
+ )
+
+ start_points: torch.Tensor = torch.tensor(
+ [[[0.0, 0], [width - 1, 0], [width - 1, height - 1], [0, height - 1]]],
+ device=distortion_scale.device,
+ dtype=distortion_scale.dtype,
+ ).expand(batch_size, -1, -1)
+
+ # generate random offset not larger than half of the image
+ fx = distortion_scale * width / 2
+ fy = distortion_scale * height / 2
+
+ factor = torch.stack([fx, fy], dim=0).view(-1, 1, 2)
+ offset = (torch.rand_like(start_points) - 0.5) * 2
+ end_points = start_points + factor * offset
+
+ return dict(start_points=start_points, end_points=end_points)
+
+
+
+class RandomErasing:
+ def __init__(self, p = 0., scale = 0.) -> None:
+ self.p = p
+ self.scale = scale
+ self.random_eraser = K.RandomErasing(scale = (0.02, scale), p = p)
+ def __call__(self, image, depth):
+ if self.p > 0:
+ image = self.random_eraser(image)
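+            # Re-use the parameters sampled for the image so the identical region
+            # is erased from the depth map.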
+ depth = self.random_eraser(depth, params=self.random_eraser._params)
+ return image, depth
+
\ No newline at end of file
diff --git a/imcui/third_party/RoMa/romatch/utils/utils.py b/imcui/third_party/RoMa/romatch/utils/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..1c522b16b7020779a0ea58b28973f2f609145838
--- /dev/null
+++ b/imcui/third_party/RoMa/romatch/utils/utils.py
@@ -0,0 +1,654 @@
+import warnings
+import numpy as np
+import cv2
+import math
+import torch
+from torchvision import transforms
+from torchvision.transforms.functional import InterpolationMode
+import torch.nn.functional as F
+from PIL import Image
+import kornia
+
+def recover_pose(E, kpts0, kpts1, K0, K1, mask):
+ best_num_inliers = 0
+ K0inv = np.linalg.inv(K0[:2,:2])
+ K1inv = np.linalg.inv(K1[:2,:2])
+
+ kpts0_n = (K0inv @ (kpts0-K0[None,:2,2]).T).T
+ kpts1_n = (K1inv @ (kpts1-K1[None,:2,2]).T).T
+
+    for _E in np.split(E, len(E) // 3):
+ n, R, t, _ = cv2.recoverPose(_E, kpts0_n, kpts1_n, np.eye(3), 1e9, mask=mask)
+ if n > best_num_inliers:
+ best_num_inliers = n
+ ret = (R, t, mask.ravel() > 0)
+ return ret
+
+
+
+# Code taken from https://github.com/PruneTruong/DenseMatching/blob/40c29a6b5c35e86b9509e65ab0cd12553d998e5f/validation/utils_pose_estimation.py
+# --- GEOMETRY ---
+def estimate_pose(kpts0, kpts1, K0, K1, norm_thresh, conf=0.99999):
+ if len(kpts0) < 5:
+ return None
+ K0inv = np.linalg.inv(K0[:2,:2])
+ K1inv = np.linalg.inv(K1[:2,:2])
+
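+    # Normalise pixel coordinates with the intrinsics (subtract the principal
+    # point, undo focal/skew) so the essential matrix can be estimated with an
+    # identity camera matrix; norm_thresh is therefore in normalised units.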
+ kpts0 = (K0inv @ (kpts0-K0[None,:2,2]).T).T
+ kpts1 = (K1inv @ (kpts1-K1[None,:2,2]).T).T
+ E, mask = cv2.findEssentialMat(
+ kpts0, kpts1, np.eye(3), threshold=norm_thresh, prob=conf
+ )
+
+ ret = None
+ if E is not None:
+ best_num_inliers = 0
+
+        for _E in np.split(E, len(E) // 3):
+ n, R, t, _ = cv2.recoverPose(_E, kpts0, kpts1, np.eye(3), 1e9, mask=mask)
+ if n > best_num_inliers:
+ best_num_inliers = n
+ ret = (R, t, mask.ravel() > 0)
+ return ret
+
+def estimate_pose_uncalibrated(kpts0, kpts1, K0, K1, norm_thresh, conf=0.99999):
+ if len(kpts0) < 5:
+ return None
+ method = cv2.USAC_ACCURATE
+ F, mask = cv2.findFundamentalMat(
+ kpts0, kpts1, ransacReprojThreshold=norm_thresh, confidence=conf, method=method, maxIters=10000
+ )
+ E = K1.T@F@K0
+ ret = None
+ if E is not None:
+ best_num_inliers = 0
+ K0inv = np.linalg.inv(K0[:2,:2])
+ K1inv = np.linalg.inv(K1[:2,:2])
+
+ kpts0_n = (K0inv @ (kpts0-K0[None,:2,2]).T).T
+ kpts1_n = (K1inv @ (kpts1-K1[None,:2,2]).T).T
+
+        for _E in np.split(E, len(E) // 3):
+ n, R, t, _ = cv2.recoverPose(_E, kpts0_n, kpts1_n, np.eye(3), 1e9, mask=mask)
+ if n > best_num_inliers:
+ best_num_inliers = n
+ ret = (R, t, mask.ravel() > 0)
+ return ret
+
+def unnormalize_coords(x_n,h,w):
+ x = torch.stack(
+ (w * (x_n[..., 0] + 1) / 2, h * (x_n[..., 1] + 1) / 2), dim=-1
+ ) # [-1+1/h, 1-1/h] -> [0.5, h-0.5]
+ return x
+
+
+def rotate_intrinsic(K, n):
+ base_rot = np.array([[0, 1, 0], [-1, 0, 0], [0, 0, 1]])
+ rot = np.linalg.matrix_power(base_rot, n)
+ return rot @ K
+
+
+def rotate_pose_inplane(i_T_w, rot):
+ rotation_matrices = [
+ np.array(
+ [
+ [np.cos(r), -np.sin(r), 0.0, 0.0],
+ [np.sin(r), np.cos(r), 0.0, 0.0],
+ [0.0, 0.0, 1.0, 0.0],
+ [0.0, 0.0, 0.0, 1.0],
+ ],
+ dtype=np.float32,
+ )
+ for r in [np.deg2rad(d) for d in (0, 270, 180, 90)]
+ ]
+ return np.dot(rotation_matrices[rot], i_T_w)
+
+
+def scale_intrinsics(K, scales):
+ scales = np.diag([1.0 / scales[0], 1.0 / scales[1], 1.0])
+ return np.dot(scales, K)
+
+
+def to_homogeneous(points):
+ return np.concatenate([points, np.ones_like(points[:, :1])], axis=-1)
+
+
+def angle_error_mat(R1, R2):
+ cos = (np.trace(np.dot(R1.T, R2)) - 1) / 2
+    cos = np.clip(cos, -1.0, 1.0)  # numerical errors can make it out of bounds
+ return np.rad2deg(np.abs(np.arccos(cos)))
+
+
+def angle_error_vec(v1, v2):
+ n = np.linalg.norm(v1) * np.linalg.norm(v2)
+ return np.rad2deg(np.arccos(np.clip(np.dot(v1, v2) / n, -1.0, 1.0)))
+
+
+def compute_pose_error(T_0to1, R, t):
+ R_gt = T_0to1[:3, :3]
+ t_gt = T_0to1[:3, 3]
+ error_t = angle_error_vec(t.squeeze(), t_gt)
+ error_t = np.minimum(error_t, 180 - error_t) # ambiguity of E estimation
+ error_R = angle_error_mat(R, R_gt)
+ return error_t, error_R
+
+
+def pose_auc(errors, thresholds):
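+    # For each threshold t, compute the area under the recall-vs-error curve on
+    # [0, t] (trapezoidal rule over the sorted errors), normalised by t.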
+ sort_idx = np.argsort(errors)
+ errors = np.array(errors.copy())[sort_idx]
+ recall = (np.arange(len(errors)) + 1) / len(errors)
+ errors = np.r_[0.0, errors]
+ recall = np.r_[0.0, recall]
+ aucs = []
+ for t in thresholds:
+ last_index = np.searchsorted(errors, t)
+ r = np.r_[recall[:last_index], recall[last_index - 1]]
+ e = np.r_[errors[:last_index], t]
+ aucs.append(np.trapz(r, x=e) / t)
+ return aucs
+
+
+# From Patch2Pix https://github.com/GrumpyZhou/patch2pix
+def get_depth_tuple_transform_ops_nearest_exact(resize=None):
+ ops = []
+ if resize:
+ ops.append(TupleResizeNearestExact(resize))
+ return TupleCompose(ops)
+
+def get_depth_tuple_transform_ops(resize=None, normalize=True, unscale=False):
+ ops = []
+ if resize:
+ ops.append(TupleResize(resize, mode=InterpolationMode.BILINEAR))
+ return TupleCompose(ops)
+
+
+def get_tuple_transform_ops(resize=None, normalize=True, unscale=False, clahe = False, colorjiggle_params = None):
+ ops = []
+ if resize:
+ ops.append(TupleResize(resize))
+ ops.append(TupleToTensorScaled())
+ if normalize:
+ ops.append(
+ TupleNormalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+ ) # Imagenet mean/std
+ return TupleCompose(ops)
+
+class ToTensorScaled(object):
+ """Convert a RGB PIL Image to a CHW ordered Tensor, scale the range to [0, 1]"""
+
+ def __call__(self, im):
+ if not isinstance(im, torch.Tensor):
+ im = np.array(im, dtype=np.float32).transpose((2, 0, 1))
+ im /= 255.0
+ return torch.from_numpy(im)
+ else:
+ return im
+
+ def __repr__(self):
+ return "ToTensorScaled(./255)"
+
+
+class TupleToTensorScaled(object):
+ def __init__(self):
+ self.to_tensor = ToTensorScaled()
+
+ def __call__(self, im_tuple):
+ return [self.to_tensor(im) for im in im_tuple]
+
+ def __repr__(self):
+ return "TupleToTensorScaled(./255)"
+
+
+class ToTensorUnscaled(object):
+ """Convert a RGB PIL Image to a CHW ordered Tensor"""
+
+ def __call__(self, im):
+ return torch.from_numpy(np.array(im, dtype=np.float32).transpose((2, 0, 1)))
+
+ def __repr__(self):
+ return "ToTensorUnscaled()"
+
+
+class TupleToTensorUnscaled(object):
+ """Convert a RGB PIL Image to a CHW ordered Tensor"""
+
+ def __init__(self):
+ self.to_tensor = ToTensorUnscaled()
+
+ def __call__(self, im_tuple):
+ return [self.to_tensor(im) for im in im_tuple]
+
+ def __repr__(self):
+ return "TupleToTensorUnscaled()"
+
+class TupleResizeNearestExact:
+ def __init__(self, size):
+ self.size = size
+ def __call__(self, im_tuple):
+ return [F.interpolate(im, size = self.size, mode = 'nearest-exact') for im in im_tuple]
+
+ def __repr__(self):
+ return "TupleResizeNearestExact(size={})".format(self.size)
+
+
+class TupleResize(object):
+ def __init__(self, size, mode=InterpolationMode.BICUBIC):
+ self.size = size
+ self.resize = transforms.Resize(size, mode)
+ def __call__(self, im_tuple):
+ return [self.resize(im) for im in im_tuple]
+
+ def __repr__(self):
+ return "TupleResize(size={})".format(self.size)
+
+class Normalize:
+ def __call__(self,im):
+        mean = im.mean(dim=(1,2), keepdim=True)
+        std = im.std(dim=(1,2), keepdim=True)
+ return (im-mean)/std
+
+
+class TupleNormalize(object):
+ def __init__(self, mean, std):
+ self.mean = mean
+ self.std = std
+ self.normalize = transforms.Normalize(mean=mean, std=std)
+
+ def __call__(self, im_tuple):
+ c,h,w = im_tuple[0].shape
+ if c > 3:
+ warnings.warn(f"Number of channels c={c} > 3, assuming first 3 are rgb")
+ return [self.normalize(im[:3]) for im in im_tuple]
+
+ def __repr__(self):
+ return "TupleNormalize(mean={}, std={})".format(self.mean, self.std)
+
+
+class TupleCompose(object):
+ def __init__(self, transforms):
+ self.transforms = transforms
+
+ def __call__(self, im_tuple):
+ for t in self.transforms:
+ im_tuple = t(im_tuple)
+ return im_tuple
+
+ def __repr__(self):
+ format_string = self.__class__.__name__ + "("
+ for t in self.transforms:
+ format_string += "\n"
+ format_string += " {0}".format(t)
+ format_string += "\n)"
+ return format_string
+
+@torch.no_grad()
+def cls_to_flow(cls, deterministic_sampling = True):
+ B,C,H,W = cls.shape
+ device = cls.device
+ res = round(math.sqrt(C))
+ G = torch.meshgrid(
+ *[torch.linspace(-1+1/res, 1-1/res, steps = res, device = device) for _ in range(2)],
+ indexing = 'ij'
+ )
+ G = torch.stack([G[1],G[0]],dim=-1).reshape(C,2)
+ if deterministic_sampling:
+ sampled_cls = cls.max(dim=1).indices
+ else:
+ sampled_cls = torch.multinomial(cls.permute(0,2,3,1).reshape(B*H*W,C).softmax(dim=-1), 1).reshape(B,H,W)
+ flow = G[sampled_cls]
+ return flow
+
+@torch.no_grad()
+def cls_to_flow_refine(cls):
+ B,C,H,W = cls.shape
+ device = cls.device
+ res = round(math.sqrt(C))
+ G = torch.meshgrid(
+ *[torch.linspace(-1+1/res, 1-1/res, steps = res, device = device) for _ in range(2)],
+ indexing = 'ij'
+ )
+ G = torch.stack([G[1],G[0]],dim=-1).reshape(C,2)
+ # FIXME: below softmax line causes mps to bug, don't know why.
+ if device.type == 'mps':
+ cls = cls.log_softmax(dim=1).exp()
+ else:
+ cls = cls.softmax(dim=1)
+ mode = cls.max(dim=1).indices
+
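+    # Soft refinement of the argmax: gather the mode and its 4-connected
+    # neighbours (left/right/up/down in the res x res anchor grid) and average
+    # their grid coordinates weighted by the class probabilities.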
+ index = torch.stack((mode-1, mode, mode+1, mode - res, mode + res), dim = 1).clamp(0,C - 1).long()
+ neighbours = torch.gather(cls, dim = 1, index = index)[...,None]
+ flow = neighbours[:,0] * G[index[:,0]] + neighbours[:,1] * G[index[:,1]] + neighbours[:,2] * G[index[:,2]] + neighbours[:,3] * G[index[:,3]] + neighbours[:,4] * G[index[:,4]]
+ tot_prob = neighbours.sum(dim=1)
+ flow = flow / tot_prob
+ return flow
+
+
+def get_gt_warp(depth1, depth2, T_1to2, K1, K2, depth_interpolation_mode = 'bilinear', relative_depth_error_threshold = 0.05, H = None, W = None):
+
+ if H is None:
+ B,H,W = depth1.shape
+ else:
+ B = depth1.shape[0]
+ with torch.no_grad():
+ x1_n = torch.meshgrid(
+ *[
+ torch.linspace(
+ -1 + 1 / n, 1 - 1 / n, n, device=depth1.device
+ )
+ for n in (B, H, W)
+ ],
+ indexing = 'ij'
+ )
+ x1_n = torch.stack((x1_n[2], x1_n[1]), dim=-1).reshape(B, H * W, 2)
+ mask, x2 = warp_kpts(
+ x1_n.double(),
+ depth1.double(),
+ depth2.double(),
+ T_1to2.double(),
+ K1.double(),
+ K2.double(),
+ depth_interpolation_mode = depth_interpolation_mode,
+ relative_depth_error_threshold = relative_depth_error_threshold,
+ )
+ prob = mask.float().reshape(B, H, W)
+ x2 = x2.reshape(B, H, W, 2)
+ return x2, prob
+
+@torch.no_grad()
+def warp_kpts(kpts0, depth0, depth1, T_0to1, K0, K1, smooth_mask = False, return_relative_depth_error = False, depth_interpolation_mode = "bilinear", relative_depth_error_threshold = 0.05):
+ """Warp kpts0 from I0 to I1 with depth, K and Rt
+ Also check covisibility and depth consistency.
+ Depth is consistent if relative error < 0.2 (hard-coded).
+ # https://github.com/zju3dv/LoFTR/blob/94e98b695be18acb43d5d3250f52226a8e36f839/src/loftr/utils/geometry.py adapted from here
+ Args:
+ kpts0 (torch.Tensor): [N, L, 2] - , should be normalized in (-1,1)
+ depth0 (torch.Tensor): [N, H, W],
+ depth1 (torch.Tensor): [N, H, W],
+ T_0to1 (torch.Tensor): [N, 3, 4],
+ K0 (torch.Tensor): [N, 3, 3],
+ K1 (torch.Tensor): [N, 3, 3],
+ Returns:
+ calculable_mask (torch.Tensor): [N, L]
+ warped_keypoints0 (torch.Tensor): [N, L, 2]
+ """
+ (
+ n,
+ h,
+ w,
+ ) = depth0.shape
+ if depth_interpolation_mode == "combined":
+ # Inspired by approach in inloc, try to fill holes from bilinear interpolation by nearest neighbour interpolation
+ if smooth_mask:
+ raise NotImplementedError("Combined bilinear and NN warp not implemented")
+ valid_bilinear, warp_bilinear = warp_kpts(kpts0, depth0, depth1, T_0to1, K0, K1,
+ smooth_mask = smooth_mask,
+ return_relative_depth_error = return_relative_depth_error,
+ depth_interpolation_mode = "bilinear",
+ relative_depth_error_threshold = relative_depth_error_threshold)
+ valid_nearest, warp_nearest = warp_kpts(kpts0, depth0, depth1, T_0to1, K0, K1,
+ smooth_mask = smooth_mask,
+ return_relative_depth_error = return_relative_depth_error,
+ depth_interpolation_mode = "nearest-exact",
+ relative_depth_error_threshold = relative_depth_error_threshold)
+ nearest_valid_bilinear_invalid = (~valid_bilinear).logical_and(valid_nearest)
+ warp = warp_bilinear.clone()
+ warp[nearest_valid_bilinear_invalid] = warp_nearest[nearest_valid_bilinear_invalid]
+ valid = valid_bilinear | valid_nearest
+ return valid, warp
+
+
+ kpts0_depth = F.grid_sample(depth0[:, None], kpts0[:, :, None], mode = depth_interpolation_mode, align_corners=False)[
+ :, 0, :, 0
+ ]
+ kpts0 = torch.stack(
+ (w * (kpts0[..., 0] + 1) / 2, h * (kpts0[..., 1] + 1) / 2), dim=-1
+ ) # [-1+1/h, 1-1/h] -> [0.5, h-0.5]
+ # Sample depth, get calculable_mask on depth != 0
+ nonzero_mask = kpts0_depth != 0
+
+ # Unproject
+ kpts0_h = (
+ torch.cat([kpts0, torch.ones_like(kpts0[:, :, [0]])], dim=-1)
+ * kpts0_depth[..., None]
+ ) # (N, L, 3)
+ kpts0_n = K0.inverse() @ kpts0_h.transpose(2, 1) # (N, 3, L)
+ kpts0_cam = kpts0_n
+
+ # Rigid Transform
+ w_kpts0_cam = T_0to1[:, :3, :3] @ kpts0_cam + T_0to1[:, :3, [3]] # (N, 3, L)
+ w_kpts0_depth_computed = w_kpts0_cam[:, 2, :]
+
+ # Project
+ w_kpts0_h = (K1 @ w_kpts0_cam).transpose(2, 1) # (N, L, 3)
+ w_kpts0 = w_kpts0_h[:, :, :2] / (
+ w_kpts0_h[:, :, [2]] + 1e-4
+ ) # (N, L, 2), +1e-4 to avoid zero depth
+
+ # Covisible Check
+ h, w = depth1.shape[1:3]
+ covisible_mask = (
+ (w_kpts0[:, :, 0] > 0)
+ * (w_kpts0[:, :, 0] < w - 1)
+ * (w_kpts0[:, :, 1] > 0)
+ * (w_kpts0[:, :, 1] < h - 1)
+ )
+ w_kpts0 = torch.stack(
+ (2 * w_kpts0[..., 0] / w - 1, 2 * w_kpts0[..., 1] / h - 1), dim=-1
+ ) # from [0.5,h-0.5] -> [-1+1/h, 1-1/h]
+ # w_kpts0[~covisible_mask, :] = -5 # xd
+
+ w_kpts0_depth = F.grid_sample(
+ depth1[:, None], w_kpts0[:, :, None], mode=depth_interpolation_mode, align_corners=False
+ )[:, 0, :, 0]
+
+ relative_depth_error = (
+ (w_kpts0_depth - w_kpts0_depth_computed) / w_kpts0_depth
+ ).abs()
+ if not smooth_mask:
+ consistent_mask = relative_depth_error < relative_depth_error_threshold
+ else:
+ consistent_mask = (-relative_depth_error/smooth_mask).exp()
+ valid_mask = nonzero_mask * covisible_mask * consistent_mask
+ if return_relative_depth_error:
+ return relative_depth_error, w_kpts0
+ else:
+ return valid_mask, w_kpts0
+
+imagenet_mean = torch.tensor([0.485, 0.456, 0.406])
+imagenet_std = torch.tensor([0.229, 0.224, 0.225])
+
+
+def numpy_to_pil(x: np.ndarray):
+ """
+ Args:
+ x: Assumed to be of shape (h,w,c)
+ """
+ if isinstance(x, torch.Tensor):
+ x = x.detach().cpu().numpy()
+ if x.max() <= 1.01:
+ x *= 255
+ x = x.astype(np.uint8)
+ return Image.fromarray(x)
+
+
+def tensor_to_pil(x, unnormalize=False):
+ if unnormalize:
+ x = x * (imagenet_std[:, None, None].to(x.device)) + (imagenet_mean[:, None, None].to(x.device))
+ x = x.detach().permute(1, 2, 0).cpu().numpy()
+ x = np.clip(x, 0.0, 1.0)
+ return numpy_to_pil(x)
+
+
+def to_cuda(batch):
+ for key, value in batch.items():
+ if isinstance(value, torch.Tensor):
+ batch[key] = value.cuda()
+ return batch
+
+
+def to_cpu(batch):
+ for key, value in batch.items():
+ if isinstance(value, torch.Tensor):
+ batch[key] = value.cpu()
+ return batch
+
+
+def get_pose(calib):
+ w, h = np.array(calib["imsize"])[0]
+ return np.array(calib["K"]), np.array(calib["R"]), np.array(calib["T"]).T, h, w
+
+
+def compute_relative_pose(R1, t1, R2, t2):
+ rots = R2 @ (R1.T)
+ trans = -rots @ t1 + t2
+ return rots, trans
+
+@torch.no_grad()
+def reset_opt(opt):
+ for group in opt.param_groups:
+ for p in group['params']:
+ if p.requires_grad:
+ state = opt.state[p]
+ # State initialization
+
+ # Exponential moving average of gradient values
+ state['exp_avg'] = torch.zeros_like(p)
+ # Exponential moving average of squared gradient values
+ state['exp_avg_sq'] = torch.zeros_like(p)
+ # Exponential moving average of gradient difference
+ state['exp_avg_diff'] = torch.zeros_like(p)
+
+
+def flow_to_pixel_coords(flow, h1, w1):
+ flow = (
+ torch.stack(
+ (
+ w1 * (flow[..., 0] + 1) / 2,
+ h1 * (flow[..., 1] + 1) / 2,
+ ),
+ axis=-1,
+ )
+ )
+ return flow
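+
+# Example (illustrative): a normalised x-coordinate of -1 + 1/w1 maps to pixel
+# centre 0.5, and 1 - 1/w1 maps to w1 - 0.5 (likewise for y with h1).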
+
+to_pixel_coords = flow_to_pixel_coords # just an alias
+
+def flow_to_normalized_coords(flow, h1, w1):
+ flow = (
+ torch.stack(
+ (
+ 2 * (flow[..., 0]) / w1 - 1,
+ 2 * (flow[..., 1]) / h1 - 1,
+ ),
+ axis=-1,
+ )
+ )
+ return flow
+
+to_normalized_coords = flow_to_normalized_coords # just an alias
+
+def warp_to_pixel_coords(warp, h1, w1, h2, w2):
+ warp1 = warp[..., :2]
+ warp1 = (
+ torch.stack(
+ (
+ w1 * (warp1[..., 0] + 1) / 2,
+ h1 * (warp1[..., 1] + 1) / 2,
+ ),
+ axis=-1,
+ )
+ )
+ warp2 = warp[..., 2:]
+ warp2 = (
+ torch.stack(
+ (
+ w2 * (warp2[..., 0] + 1) / 2,
+ h2 * (warp2[..., 1] + 1) / 2,
+ ),
+ axis=-1,
+ )
+ )
+ return torch.cat((warp1,warp2), dim=-1)
+
+
+
+def signed_point_line_distance(point, line, eps: float = 1e-9):
+ r"""Return the distance from points to lines.
+
+ Args:
+ point: (possibly homogeneous) points :math:`(*, N, 2 or 3)`.
+ line: lines coefficients :math:`(a, b, c)` with shape :math:`(*, N, 3)`, where :math:`ax + by + c = 0`.
+ eps: Small constant for safe sqrt.
+
+ Returns:
+ the computed distance with shape :math:`(*, N)`.
+ """
+
+ if not point.shape[-1] in (2, 3):
+ raise ValueError(f"pts must be a (*, 2 or 3) tensor. Got {point.shape}")
+
+ if not line.shape[-1] == 3:
+ raise ValueError(f"lines must be a (*, 3) tensor. Got {line.shape}")
+
+ numerator = (line[..., 0] * point[..., 0] + line[..., 1] * point[..., 1] + line[..., 2])
+ denominator = line[..., :2].norm(dim=-1)
+
+ return numerator / (denominator + eps)
+
+
+def signed_left_to_right_epipolar_distance(pts1, pts2, Fm):
+ r"""Return one-sided epipolar distance for correspondences given the fundamental matrix.
+
+ This method measures the distance from points in the right images to the epilines
+ of the corresponding points in the left images as they reflect in the right images.
+
+ Args:
+ pts1: correspondences from the left images with shape
+ :math:`(*, N, 2 or 3)`. If they are not homogeneous, converted automatically.
+ pts2: correspondences from the right images with shape
+ :math:`(*, N, 2 or 3)`. If they are not homogeneous, converted automatically.
+ Fm: Fundamental matrices with shape :math:`(*, 3, 3)`. Called Fm to
+ avoid ambiguity with torch.nn.functional.
+
+ Returns:
+ the computed Symmetrical distance with shape :math:`(*, N)`.
+ """
+ if (len(Fm.shape) < 3) or not Fm.shape[-2:] == (3, 3):
+ raise ValueError(f"Fm must be a (*, 3, 3) tensor. Got {Fm.shape}")
+
+ if pts1.shape[-1] == 2:
+ pts1 = kornia.geometry.convert_points_to_homogeneous(pts1)
+
+ F_t = Fm.transpose(dim0=-2, dim1=-1)
+ line1_in_2 = pts1 @ F_t
+
+ return signed_point_line_distance(pts2, line1_in_2)
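+
+# Minimal sketch (illustrative only; random tensors, not real correspondences):
+#     Fm = torch.randn(1, 3, 3)
+#     pts1, pts2 = torch.rand(1, 100, 2), torch.rand(1, 100, 2)
+#     d = signed_left_to_right_epipolar_distance(pts1, pts2, Fm)  # -> (1, 100)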
+
+def get_grid(b, h, w, device):
+ grid = torch.meshgrid(
+ *[
+ torch.linspace(-1 + 1 / n, 1 - 1 / n, n, device=device)
+ for n in (b, h, w)
+ ],
+ indexing = 'ij'
+ )
+ grid = torch.stack((grid[2], grid[1]), dim=-1).reshape(b, h, w, 2)
+ return grid
+
+
+def get_autocast_params(device=None, enabled=False, dtype=None):
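+    # Resolve (device_type, enabled, dtype) for torch.autocast: autocast is only
+    # enabled when the device string contains 'cuda'; cpu/mps/None fall back to a
+    # disabled cpu context with bfloat16 as the nominal dtype.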
+ if device is None:
+ autocast_device = "cuda" if torch.cuda.is_available() else "cpu"
+ else:
+ #strip :X from device
+ autocast_device = str(device).split(":")[0]
+ if 'cuda' in str(device):
+ out_dtype = dtype
+ enabled = True
+ else:
+ out_dtype = torch.bfloat16
+ enabled = False
+ # mps is not supported
+ autocast_device = "cpu"
+ return autocast_device, enabled, out_dtype
\ No newline at end of file
diff --git a/imcui/third_party/RoMa/setup.py b/imcui/third_party/RoMa/setup.py
new file mode 100644
index 0000000000000000000000000000000000000000..7ec18f3bbb71b85d943fdfeed3ed5c47033aebbc
--- /dev/null
+++ b/imcui/third_party/RoMa/setup.py
@@ -0,0 +1,9 @@
+from setuptools import setup, find_packages
+
+setup(
+ name="romatch",
+ packages=find_packages(include=("romatch*",)),
+ version="0.0.1",
+ author="Johan Edstedt",
+ install_requires=open("requirements.txt", "r").read().split("\n"),
+)
diff --git a/imcui/third_party/RoRD/demo/__init__.py b/imcui/third_party/RoRD/demo/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/imcui/third_party/RoRD/demo/register.py b/imcui/third_party/RoRD/demo/register.py
new file mode 100644
index 0000000000000000000000000000000000000000..ba626920887639c6c95f869231d8080de64c2ee8
--- /dev/null
+++ b/imcui/third_party/RoRD/demo/register.py
@@ -0,0 +1,265 @@
+import numpy as np
+import copy
+import argparse
+import os, sys
+import open3d as o3d
+from sys import argv
+from PIL import Image
+import math
+import cv2
+import torch
+
+sys.path.append("../")
+from lib.extractMatchTop import getPerspKeypoints, getPerspKeypointsEnsemble, siftMatching
+from lib.model_test import D2Net
+
+#### Cuda ####
+use_cuda = torch.cuda.is_available()
+device = torch.device('cuda:0' if use_cuda else 'cpu')
+
+#### Argument Parsing ####
+parser = argparse.ArgumentParser(description='RoRD ICP evaluation')
+
+parser.add_argument(
+ '--rgb1', type=str, default = 'rgb/rgb2_1.jpg',
+ help='path to the rgb image1'
+)
+parser.add_argument(
+ '--rgb2', type=str, default = 'rgb/rgb2_2.jpg',
+ help='path to the rgb image2'
+)
+
+parser.add_argument(
+ '--depth1', type=str, default = 'depth/depth2_1.png',
+ help='path to the depth image1'
+)
+
+parser.add_argument(
+ '--depth2', type=str, default = 'depth/depth2_2.png',
+ help='path to the depth image2'
+)
+
+parser.add_argument(
+ '--model_rord', type=str, default = '../models/rord.pth',
+ help='path to the RoRD model for evaluation'
+)
+
+parser.add_argument(
+ '--model_d2', type=str,
+ help='path to the vanilla D2-Net model for evaluation'
+)
+
+parser.add_argument(
+ '--model_ens', action='store_true',
+ help='ensemble model of RoRD + D2-Net'
+)
+
+parser.add_argument(
+ '--sift', action='store_true',
+ help='Sift'
+)
+
+parser.add_argument(
+ '--camera_file', type=str, default='../configs/camera.txt',
+ help='path to the camera intrinsics file. In order: focal_x, focal_y, center_x, center_y, scaling_factor.'
+)
+
+parser.add_argument(
+ '--viz3d', action='store_true',
+ help='visualize the pointcloud registrations'
+)
+
+args = parser.parse_args()
+
+if args.model_ens: # Change default paths accordingly for ensemble
+ model1_ens = '../../models/rord.pth'
+ model2_ens = '../../models/d2net.pth'
+
+def draw_registration_result(source, target, transformation):
+ source_temp = copy.deepcopy(source)
+ target_temp = copy.deepcopy(target)
+ source_temp.transform(transformation)
+
+ target_temp += source_temp
+ # print("Saved registered PointCloud.")
+ # o3d.io.write_point_cloud("registered.pcd", target_temp)
+
+ trgSph.append(source_temp); trgSph.append(target_temp)
+ axis1 = o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.5, origin=[0, 0, 0])
+ axis2 = o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.5, origin=[0, 0, 0])
+ axis2.transform(transformation)
+ trgSph.append(axis1); trgSph.append(axis2)
+ print("Showing registered PointCloud.")
+ o3d.visualization.draw_geometries(trgSph)
+
+
+def readDepth(depthFile):
+ depth = Image.open(depthFile)
+ if depth.mode != "I":
+ raise Exception("Depth image is not in intensity format")
+
+ return np.asarray(depth)
+
+def readCamera(camera):
+ with open (camera, "rt") as file:
+ contents = file.read().split()
+
+ focalX = float(contents[0])
+ focalY = float(contents[1])
+ centerX = float(contents[2])
+ centerY = float(contents[3])
+ scalingFactor = float(contents[4])
+
+ return focalX, focalY, centerX, centerY, scalingFactor
+
+def getPointCloud(rgbFile, depthFile, pts):
+ thresh = 15.0
+
+ depth = readDepth(depthFile)
+ rgb = Image.open(rgbFile)
+
+ points = []
+ colors = []
+
+ corIdx = [-1]*len(pts)
+ corPts = [None]*len(pts)
+ ptIdx = 0
+
+ for v in range(depth.shape[0]):
+ for u in range(depth.shape[1]):
+ Z = depth[v, u] / scalingFactor
+ if Z==0: continue
+ if (Z > thresh): continue
+
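+            # Back-project pixel (u, v) with depth Z through the pinhole model:
+            # X = (u - cx) * Z / fx, Y = (v - cy) * Z / fy.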
+ X = (u - centerX) * Z / focalX
+ Y = (v - centerY) * Z / focalY
+
+ points.append((X, Y, Z))
+ colors.append(rgb.getpixel((u, v)))
+
+ if((u, v) in pts):
+ # print("Point found.")
+ index = pts.index((u, v))
+ corIdx[index] = ptIdx
+ corPts[index] = (X, Y, Z)
+
+ ptIdx = ptIdx+1
+
+ points = np.asarray(points)
+ colors = np.asarray(colors)
+
+ pcd = o3d.geometry.PointCloud()
+ pcd.points = o3d.utility.Vector3dVector(points)
+ pcd.colors = o3d.utility.Vector3dVector(colors/255)
+
+ return pcd, corIdx, corPts
+
+
+def convertPts(A):
+ X = A[0]; Y = A[1]
+
+ x = []; y = []
+
+ for i in range(len(X)):
+ x.append(int(float(X[i])))
+
+ for i in range(len(Y)):
+ y.append(int(float(Y[i])))
+
+ pts = []
+ for i in range(len(x)):
+ pts.append((x[i], y[i]))
+
+ return pts
+
+
+def getSphere(pts):
+ sphs = []
+
+ for ele in pts:
+ if(ele is not None):
+ sphere = o3d.geometry.TriangleMesh.create_sphere(radius=0.03)
+ sphere.paint_uniform_color([0.9, 0.2, 0])
+
+ trans = np.identity(4)
+ trans[0, 3] = ele[0]
+ trans[1, 3] = ele[1]
+ trans[2, 3] = ele[2]
+
+ sphere.transform(trans)
+ sphs.append(sphere)
+
+ return sphs
+
+
+def get3dCor(src, trg):
+ corr = []
+
+ for sId, tId in zip(src, trg):
+ if(sId != -1 and tId != -1):
+ corr.append((sId, tId))
+
+ corr = np.asarray(corr)
+
+ return corr
+
+if __name__ == "__main__":
+
+ focalX, focalY, centerX, centerY, scalingFactor = readCamera(args.camera_file)
+
+ rgb_name_src = os.path.basename(args.rgb1)
+ H_name_src = os.path.splitext(rgb_name_src)[0] + '.npy'
+ srcH = os.path.join(os.path.dirname(args.rgb1), H_name_src)
+ rgb_name_trg = os.path.basename(args.rgb2)
+ H_name_trg = os.path.splitext(rgb_name_trg)[0] + '.npy'
+ trgH = os.path.join(os.path.dirname(args.rgb2), H_name_trg)
+
+ use_cuda = torch.cuda.is_available()
+ device = torch.device('cuda:0' if use_cuda else 'cpu')
+ model1 = D2Net(model_file=args.model_d2)
+ model1 = model1.to(device)
+ model2 = D2Net(model_file=args.model_rord)
+ model2 = model2.to(device)
+
+ if args.model_rord:
+ srcPts, trgPts, matchImg, matchImgOrtho = getPerspKeypoints(args.rgb1, args.rgb2, srcH, trgH, model2, device)
+ elif args.model_d2:
+ srcPts, trgPts, matchImg, matchImgOrtho = getPerspKeypoints(args.rgb1, args.rgb2, srcH, trgH, model1, device)
+ elif args.model_ens:
+ model1 = D2Net(model_file=model1_ens)
+ model1 = model1.to(device)
+ model2 = D2Net(model_file=model2_ens)
+ model2 = model2.to(device)
+ srcPts, trgPts, matchImg, matchImgOrtho = getPerspKeypointsEnsemble(model1, model2, args.rgb1, args.rgb2, srcH, trgH, device)
+ elif args.sift:
+ srcPts, trgPts, matchImg, matchImgOrtho = siftMatching(args.rgb1, args.rgb2, srcH, trgH, device)
+
+ #### Visualization ####
+ print("\nShowing matches in perspective and orthographic view. Press q\n")
+ cv2.imshow('Orthographic view', matchImgOrtho)
+ cv2.imshow('Perspective view', matchImg)
+ cv2.waitKey()
+
+ srcPts = convertPts(srcPts)
+ trgPts = convertPts(trgPts)
+
+ srcCld, srcIdx, srcCor = getPointCloud(args.rgb1, args.depth1, srcPts)
+ trgCld, trgIdx, trgCor = getPointCloud(args.rgb2, args.depth2, trgPts)
+
+ srcSph = getSphere(srcCor)
+ trgSph = getSphere(trgCor)
+ axis = o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.5, origin=[0, 0, 0])
+ srcSph.append(srcCld); srcSph.append(axis)
+ trgSph.append(trgCld); trgSph.append(axis)
+
+ corr = get3dCor(srcIdx, trgIdx)
+
+    p2p = o3d.pipelines.registration.TransformationEstimationPointToPoint()
+ trans_init = p2p.compute_transformation(srcCld, trgCld, o3d.utility.Vector2iVector(corr))
+ print("Transformation matrix: \n", trans_init)
+
+ if args.viz3d:
+ # o3d.visualization.draw_geometries(srcSph)
+ # o3d.visualization.draw_geometries(trgSph)
+
+ draw_registration_result(srcCld, trgCld, trans_init)
diff --git a/imcui/third_party/RoRD/evaluation/DiverseView/evalRT.py b/imcui/third_party/RoRD/evaluation/DiverseView/evalRT.py
new file mode 100644
index 0000000000000000000000000000000000000000..d0be9aef58e408668112e0587a03b2b33012a342
--- /dev/null
+++ b/imcui/third_party/RoRD/evaluation/DiverseView/evalRT.py
@@ -0,0 +1,307 @@
+import numpy as np
+import argparse
+import copy
+import os, sys
+import open3d as o3d
+from sys import argv, exit
+from PIL import Image
+import math
+from tqdm import tqdm
+import cv2
+
+
+sys.path.append("../../")
+
+from lib.extractMatchTop import getPerspKeypoints, getPerspKeypointsEnsemble, siftMatching
+import pandas as pd
+
+
+import torch
+from lib.model_test import D2Net
+
+#### Cuda ####
+use_cuda = torch.cuda.is_available()
+device = torch.device('cuda:0' if use_cuda else 'cpu')
+
+#### Argument Parsing ####
+parser = argparse.ArgumentParser(description='RoRD ICP evaluation on a DiverseView dataset sequence.')
+
+parser.add_argument('--dataset', type=str, default='/scratch/udit/realsense/RoRD_data/preprocessed/',
+ help='path to the dataset folder')
+
+parser.add_argument('--sequence', type=str, default='data1')
+
+parser.add_argument(
+ '--output_dir', type=str, default='out',
+ help='output directory for RT estimates'
+)
+
+parser.add_argument(
+ '--model_rord', type=str, help='path to the RoRD model for evaluation'
+)
+
+parser.add_argument(
+ '--model_d2', type=str, help='path to the vanilla D2-Net model for evaluation'
+)
+
+parser.add_argument(
+ '--model_ens', action='store_true',
+ help='ensemble model of RoRD + D2-Net'
+)
+
+parser.add_argument(
+ '--sift', action='store_true',
+ help='Sift'
+)
+
+parser.add_argument(
+ '--viz3d', action='store_true',
+ help='visualize the pointcloud registrations'
+)
+
+parser.add_argument(
+ '--log_interval', type=int, default=9,
+ help='Matched image logging interval'
+)
+
+parser.add_argument(
+ '--camera_file', type=str, default='../../configs/camera.txt',
+ help='path to the camera intrinsics file. In order: focal_x, focal_y, center_x, center_y, scaling_factor.'
+)
+
+parser.add_argument(
+ '--persp', action='store_true', default=False,
+ help='Feature matching on perspective images.'
+)
+
+parser.set_defaults(fp16=False)
+args = parser.parse_args()
+
+
+if args.model_ens: # Change default paths accordingly for ensemble
+ model1_ens = '../../models/rord.pth'
+ model2_ens = '../../models/d2net.pth'
+
+def draw_registration_result(source, target, transformation):
+ source_temp = copy.deepcopy(source)
+ target_temp = copy.deepcopy(target)
+ source_temp.transform(transformation)
+ trgSph.append(source_temp); trgSph.append(target_temp)
+ axis1 = o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.5, origin=[0, 0, 0])
+ axis2 = o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.5, origin=[0, 0, 0])
+ axis2.transform(transformation)
+ trgSph.append(axis1); trgSph.append(axis2)
+ o3d.visualization.draw_geometries(trgSph)
+
+def readDepth(depthFile):
+ depth = Image.open(depthFile)
+ if depth.mode != "I":
+ raise Exception("Depth image is not in intensity format")
+
+ return np.asarray(depth)
+
+def readCamera(camera):
+ with open (camera, "rt") as file:
+ contents = file.read().split()
+
+ focalX = float(contents[0])
+ focalY = float(contents[1])
+ centerX = float(contents[2])
+ centerY = float(contents[3])
+ scalingFactor = float(contents[4])
+
+ return focalX, focalY, centerX, centerY, scalingFactor
+
+
+def getPointCloud(rgbFile, depthFile, pts):
+ thresh = 15.0
+
+ depth = readDepth(depthFile)
+ rgb = Image.open(rgbFile)
+
+ points = []
+ colors = []
+
+ corIdx = [-1]*len(pts)
+ corPts = [None]*len(pts)
+ ptIdx = 0
+
+ for v in range(depth.shape[0]):
+ for u in range(depth.shape[1]):
+ Z = depth[v, u] / scalingFactor
+ if Z==0: continue
+ if (Z > thresh): continue
+
+ X = (u - centerX) * Z / focalX
+ Y = (v - centerY) * Z / focalY
+
+ points.append((X, Y, Z))
+ colors.append(rgb.getpixel((u, v)))
+
+ if((u, v) in pts):
+ index = pts.index((u, v))
+ corIdx[index] = ptIdx
+ corPts[index] = (X, Y, Z)
+
+ ptIdx = ptIdx+1
+
+ points = np.asarray(points)
+ colors = np.asarray(colors)
+
+ pcd = o3d.geometry.PointCloud()
+ pcd.points = o3d.utility.Vector3dVector(points)
+ pcd.colors = o3d.utility.Vector3dVector(colors/255)
+
+ return pcd, corIdx, corPts
+
+
+def convertPts(A):
+ X = A[0]; Y = A[1]
+
+ x = []; y = []
+
+ for i in range(len(X)):
+ x.append(int(float(X[i])))
+
+ for i in range(len(Y)):
+ y.append(int(float(Y[i])))
+
+ pts = []
+ for i in range(len(x)):
+ pts.append((x[i], y[i]))
+
+ return pts
+
+
+def getSphere(pts):
+ sphs = []
+
+ for element in pts:
+ if(element is not None):
+ sphere = o3d.geometry.TriangleMesh.create_sphere(radius=0.03)
+ sphere.paint_uniform_color([0.9, 0.2, 0])
+
+ trans = np.identity(4)
+ trans[0, 3] = element[0]
+ trans[1, 3] = element[1]
+ trans[2, 3] = element[2]
+
+ sphere.transform(trans)
+ sphs.append(sphere)
+
+ return sphs
+
+
+def get3dCor(src, trg):
+ corr = []
+
+ for sId, tId in zip(src, trg):
+ if(sId != -1 and tId != -1):
+ corr.append((sId, tId))
+
+ corr = np.asarray(corr)
+
+ return corr
+
+if __name__ == "__main__":
+ camera_file = args.camera_file
+ rgb_csv = args.dataset + args.sequence + '/rtImagesRgb.csv'
+ depth_csv = args.dataset + args.sequence + '/rtImagesDepth.csv'
+
+ os.makedirs(os.path.join(args.output_dir, 'vis'), exist_ok=True)
+ dir_name = args.output_dir
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ focalX, focalY, centerX, centerY, scalingFactor = readCamera(camera_file)
+
+ df_rgb = pd.read_csv(rgb_csv)
+ df_dep = pd.read_csv(depth_csv)
+
+ model1 = D2Net(model_file=args.model_d2).to(device)
+ model2 = D2Net(model_file=args.model_rord).to(device)
+
+ queryId = 0
+ for im_q, dep_q in tqdm(zip(df_rgb['query'], df_dep['query']), total=df_rgb.shape[0]):
+ filter_list = []
+ dbId = 0
+        for im_d, dep_d in tqdm(zip(df_rgb.items(), df_dep.items()), total=df_rgb.shape[1]):
+ if im_d[0] == 'query':
+ continue
+ rgb_name_src = os.path.basename(im_q)
+ H_name_src = os.path.splitext(rgb_name_src)[0] + '.npy'
+ srcH = args.dataset + args.sequence + '/rgb/' + H_name_src
+ rgb_name_trg = os.path.basename(im_d[1][1])
+ H_name_trg = os.path.splitext(rgb_name_trg)[0] + '.npy'
+ trgH = args.dataset + args.sequence + '/rgb/' + H_name_trg
+
+ srcImg = srcH.replace('.npy', '.jpg')
+ trgImg = trgH.replace('.npy', '.jpg')
+
+ if args.model_rord:
+ if args.persp:
+ srcPts, trgPts, matchImg, _ = getPerspKeypoints(srcImg, trgImg, HFile1=None, HFile2=None, model=model2, device=device)
+ else:
+ srcPts, trgPts, matchImg, _ = getPerspKeypoints(srcImg, trgImg, srcH, trgH, model2, device)
+
+ elif args.model_d2:
+ if args.persp:
+ srcPts, trgPts, matchImg, _ = getPerspKeypoints(srcImg, trgImg, HFile1=None, HFile2=None, model=model2, device=device)
+ else:
+ srcPts, trgPts, matchImg, _ = getPerspKeypoints(srcImg, trgImg, srcH, trgH, model1, device)
+
+ elif args.model_ens:
+ model1 = D2Net(model_file=model1_ens)
+ model1 = model1.to(device)
+ model2 = D2Net(model_file=model2_ens)
+ model2 = model2.to(device)
+                srcPts, trgPts, matchImg, _ = getPerspKeypointsEnsemble(model1, model2, srcImg, trgImg, srcH, trgH, device)
+
+ elif args.sift:
+ if args.persp:
+ srcPts, trgPts, matchImg, _ = siftMatching(srcImg, trgImg, HFile1=None, HFile2=None, device=device)
+ else:
+ srcPts, trgPts, matchImg, _ = siftMatching(srcImg, trgImg, srcH, trgH, device)
+
+ if(isinstance(srcPts, list) == True):
+ print(np.identity(4))
+ filter_list.append(np.identity(4))
+ continue
+
+
+ srcPts = convertPts(srcPts)
+ trgPts = convertPts(trgPts)
+
+ depth_name_src = os.path.dirname(os.path.dirname(args.dataset)) + '/' + dep_q
+ depth_name_trg = os.path.dirname(os.path.dirname(args.dataset)) + '/' + dep_d[1][1]
+
+ srcCld, srcIdx, srcCor = getPointCloud(srcImg, depth_name_src, srcPts)
+ trgCld, trgIdx, trgCor = getPointCloud(trgImg, depth_name_trg, trgPts)
+
+ srcSph = getSphere(srcCor)
+ trgSph = getSphere(trgCor)
+ axis = o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.5, origin=[0, 0, 0])
+ srcSph.append(srcCld); srcSph.append(axis)
+ trgSph.append(trgCld); trgSph.append(axis)
+
+ corr = get3dCor(srcIdx, trgIdx)
+
+ p2p = o3d.pipelines.registration.TransformationEstimationPointToPoint()
+ trans_init = p2p.compute_transformation(srcCld, trgCld, o3d.utility.Vector2iVector(corr))
+ # print(trans_init)
+ filter_list.append(trans_init)
+
+ if args.viz3d:
+ o3d.visualization.draw_geometries(srcSph)
+ o3d.visualization.draw_geometries(trgSph)
+ draw_registration_result(srcCld, trgCld, trans_init)
+
+ if(dbId%args.log_interval == 0):
+ cv2.imwrite(os.path.join(args.output_dir, 'vis') + "/matchImg.%02d.%02d.jpg"%(queryId, dbId//args.log_interval), matchImg)
+ dbId += 1
+
+
+ RT = np.stack(filter_list).transpose(1,2,0)
+
+ np.save(os.path.join(dir_name, str(queryId) + '.npy'), RT)
+ queryId += 1
+ print('-----check-------', RT.shape)
diff --git a/imcui/third_party/RoRD/extractMatch.py b/imcui/third_party/RoRD/extractMatch.py
new file mode 100644
index 0000000000000000000000000000000000000000..b413dde1334b52fef294fb0c10c2acfe5b901534
--- /dev/null
+++ b/imcui/third_party/RoRD/extractMatch.py
@@ -0,0 +1,195 @@
+import argparse
+
+import numpy as np
+
+import imageio
+
+import torch
+
+from tqdm import tqdm
+import time
+import scipy
+import scipy.io
+import scipy.misc
+import os
+import sys
+
+from lib.model_test import D2Net
+from lib.utils import preprocess_image
+from lib.pyramid import process_multiscale
+
+import cv2
+import matplotlib.pyplot as plt
+from PIL import Image
+from skimage.feature import match_descriptors
+from skimage.measure import ransac
+from skimage.transform import ProjectiveTransform, AffineTransform
+import pydegensac
+
+
+parser = argparse.ArgumentParser(description='Feature extraction script')
+parser.add_argument('imgs', type=str, nargs=2)
+parser.add_argument(
+ '--preprocessing', type=str, default='caffe',
+ help='image preprocessing (caffe or torch)'
+)
+
+parser.add_argument(
+ '--model_file', type=str,
+ help='path to the full model'
+)
+
+parser.add_argument(
+ '--no-relu', dest='use_relu', action='store_false',
+ help='remove ReLU after the dense feature extraction module'
+)
+parser.set_defaults(use_relu=True)
+
+parser.add_argument(
+ '--sift', dest='use_sift', action='store_true',
+ help='Show sift matching as well'
+)
+parser.set_defaults(use_sift=False)
+
+
+def extract(image, args, model, device):
+ if len(image.shape) == 2:
+ image = image[:, :, np.newaxis]
+ image = np.repeat(image, 3, -1)
+
+ input_image = preprocess_image(
+ image,
+ preprocessing=args.preprocessing
+ )
+ with torch.no_grad():
+ keypoints, scores, descriptors = process_multiscale(
+ torch.tensor(
+ input_image[np.newaxis, :, :, :].astype(np.float32),
+ device=device
+ ),
+ model,
+ scales=[1]
+ )
+
+ keypoints = keypoints[:, [1, 0, 2]]
+
+ feat = {}
+ feat['keypoints'] = keypoints
+ feat['scores'] = scores
+ feat['descriptors'] = descriptors
+
+ return feat
+
+
+def rordMatching(image1, image2, feat1, feat2, matcher="BF"):
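+    # Brute-force L2 matching with cross-check, then RANSAC homography fitting
+    # (pydegensac) to keep only geometrically consistent inliers, which are
+    # drawn with cv2.drawMatches.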
+ if(matcher == "BF"):
+
+ t0 = time.time()
+ bf = cv2.BFMatcher(cv2.NORM_L2, crossCheck=True)
+ matches = bf.match(feat1['descriptors'], feat2['descriptors'])
+ matches = sorted(matches, key=lambda x:x.distance)
+ t1 = time.time()
+ print("Time to extract matches: ", t1-t0)
+
+ print("Number of raw matches:", len(matches))
+
+ match1 = [m.queryIdx for m in matches]
+ match2 = [m.trainIdx for m in matches]
+
+ keypoints_left = feat1['keypoints'][match1, : 2]
+ keypoints_right = feat2['keypoints'][match2, : 2]
+
+ np.random.seed(0)
+
+ t0 = time.time()
+
+ H, inliers = pydegensac.findHomography(keypoints_left, keypoints_right, 10.0, 0.99, 10000)
+
+ t1 = time.time()
+ print("Time for ransac: ", t1-t0)
+
+ n_inliers = np.sum(inliers)
+ print('Number of inliers: %d.' % n_inliers)
+
+ inlier_keypoints_left = [cv2.KeyPoint(point[0], point[1], 1) for point in keypoints_left[inliers]]
+ inlier_keypoints_right = [cv2.KeyPoint(point[0], point[1], 1) for point in keypoints_right[inliers]]
+ placeholder_matches = [cv2.DMatch(idx, idx, 1) for idx in range(n_inliers)]
+
+ draw_params = dict(matchColor = (0,255,0),
+ singlePointColor = (255,0,0),
+ # matchesMask = matchesMask,
+ flags = 0)
+ image3 = cv2.drawMatches(image1, inlier_keypoints_left, image2, inlier_keypoints_right, placeholder_matches, None, **draw_params)
+
+ plt.figure(figsize=(20, 20))
+ plt.imshow(image3)
+ plt.axis('off')
+ plt.show()
+
+
+def siftMatching(img1, img2):
+ img1 = np.array(cv2.cvtColor(np.array(img1), cv2.COLOR_BGR2RGB))
+ img2 = np.array(cv2.cvtColor(np.array(img2), cv2.COLOR_BGR2RGB))
+
+ # surf = cv2.xfeatures2d.SURF_create(100)
+ surf = cv2.xfeatures2d.SIFT_create()
+
+ kp1, des1 = surf.detectAndCompute(img1, None)
+ kp2, des2 = surf.detectAndCompute(img2, None)
+
+ FLANN_INDEX_KDTREE = 0
+ index_params = dict(algorithm = FLANN_INDEX_KDTREE, trees = 5)
+ search_params = dict(checks = 50)
+ flann = cv2.FlannBasedMatcher(index_params, search_params)
+ matches = flann.knnMatch(des1,des2,k=2)
+ good = []
+ for m, n in matches:
+ if m.distance < 0.7*n.distance:
+ good.append(m)
+
+ src_pts = np.float32([ kp1[m.queryIdx].pt for m in good ]).reshape(-1, 2)
+ dst_pts = np.float32([ kp2[m.trainIdx].pt for m in good ]).reshape(-1, 2)
+
+ model, inliers = pydegensac.findHomography(src_pts, dst_pts, 10.0, 0.99, 10000)
+
+ n_inliers = np.sum(inliers)
+ print('Number of inliers: %d.' % n_inliers)
+
+ inlier_keypoints_left = [cv2.KeyPoint(point[0], point[1], 1) for point in src_pts[inliers]]
+ inlier_keypoints_right = [cv2.KeyPoint(point[0], point[1], 1) for point in dst_pts[inliers]]
+ placeholder_matches = [cv2.DMatch(idx, idx, 1) for idx in range(n_inliers)]
+ image3 = cv2.drawMatches(img1, inlier_keypoints_left, img2, inlier_keypoints_right, placeholder_matches, None)
+
+ cv2.imshow('Matches', image3)
+ cv2.waitKey(0)
+
+ src_pts = np.float32([ inlier_keypoints_left[m.queryIdx].pt for m in placeholder_matches ]).reshape(-1, 2)
+ dst_pts = np.float32([ inlier_keypoints_right[m.trainIdx].pt for m in placeholder_matches ]).reshape(-1, 2)
+
+ return src_pts, dst_pts
+
+
+if __name__ == '__main__':
+ use_cuda = torch.cuda.is_available()
+ device = torch.device("cuda:0" if use_cuda else "cpu")
+ args = parser.parse_args()
+
+ model = D2Net(
+ model_file=args.model_file,
+ use_relu=args.use_relu,
+ use_cuda=use_cuda
+ )
+
+ image1 = np.array(Image.open(args.imgs[0]))
+ image2 = np.array(Image.open(args.imgs[1]))
+
+ print('--\nRoRD\n--')
+ feat1 = extract(image1, args, model, device)
+ feat2 = extract(image2, args, model, device)
+ print("Features extracted.")
+
+ rordMatching(image1, image2, feat1, feat2, matcher="BF")
+
+ if(args.use_sift):
+ print('--\nSIFT\n--')
+ siftMatching(image1, image2)
diff --git a/imcui/third_party/RoRD/scripts/getRTImages.py b/imcui/third_party/RoRD/scripts/getRTImages.py
new file mode 100644
index 0000000000000000000000000000000000000000..6972c349c0dc2c046c67e194ba79ea6d7da725bd
--- /dev/null
+++ b/imcui/third_party/RoRD/scripts/getRTImages.py
@@ -0,0 +1,54 @@
+import os
+import re
+from sys import argv, exit
+import csv
+import numpy as np
+
+
+def natural_sort(l):
+ convert = lambda text: int(text) if text.isdigit() else text.lower()
+ alphanum_key = lambda key: [ convert(c) for c in re.split('([0-9]+)', key) ]
+ return sorted(l, key = alphanum_key)
+
+
+def getPairs(imgs):
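+    # Pick 10 evenly spaced query frames over the whole sequence and 100 evenly
+    # spaced database frames from the interval [10, len(imgs) - 10].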
+ queryIdxs = np.linspace(start=0, stop=len(imgs)-1, num=10).astype(int).tolist()
+ databaseIdxs = np.linspace(start=10, stop=len(imgs)-10, num=100).astype(int).tolist()
+
+ queryImgs = [imgs[idx] for idx in queryIdxs]
+ databaseImgs = [imgs[idx] for idx in databaseIdxs]
+
+ return queryImgs, databaseImgs
+
+
+def writeCSV(qImgs, dImgs):
+ with open('rtImagesDepth.csv', 'w', newline='') as file:
+ writer = csv.writer(file)
+
+ title = []
+ title.append('query')
+
+ for i in range(len(dImgs)):
+ title.append('data' + str(i+1))
+
+ writer.writerow(title)
+
+ for qImg in qImgs:
+ row = []
+ row.append(qImg)
+
+ for dImg in dImgs:
+ row.append(dImg)
+
+ writer.writerow(row)
+
+
+if __name__ == '__main__':
+ rgbDir = argv[1]
+ rgbImgs = natural_sort([file for file in os.listdir(rgbDir) if (file.find("jpg") != -1 or file.find("png") != -1)])
+
+ rgbImgs = [os.path.join(rgbDir, img) for img in rgbImgs]
+
+ queryImgs, databaseImgs = getPairs(rgbImgs)
+
+ writeCSV(queryImgs, databaseImgs)
\ No newline at end of file
diff --git a/imcui/third_party/RoRD/scripts/metricRT.py b/imcui/third_party/RoRD/scripts/metricRT.py
new file mode 100644
index 0000000000000000000000000000000000000000..99a323b269e79d4c8f179bae3227224beff57f6c
--- /dev/null
+++ b/imcui/third_party/RoRD/scripts/metricRT.py
@@ -0,0 +1,63 @@
+import numpy as np
+import re
+import os
+import argparse
+
+
+def natural_sort(l):
+ convert = lambda text: int(text) if text.isdigit() else text.lower()
+ alphanum_key = lambda key: [ convert(c) for c in re.split('([0-9]+)', key) ]
+
+ return sorted(l, key = alphanum_key)
+
+
+def angular_distance_np(R_hat, R):
+    # measure the angular distance between two rotation matrices
+ # R1,R2: [n, 3, 3]
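+    # Geodesic distance in degrees: theta = arccos((trace(R_hat @ R^T) - 1) / 2),
+    # clipped to [-1, 1] for numerical safety.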
+ if R_hat.shape == (3,3):
+ R_hat = R_hat[np.newaxis,:]
+ if R.shape == (3,3):
+ R = R[np.newaxis,:]
+ n = R.shape[0]
+ trace_idx = [0,4,8]
+ trace = np.matmul(R_hat, R.transpose(0,2,1)).reshape(n,-1)[:,trace_idx].sum(1)
+ metric = np.arccos(((trace - 1)/2).clip(-1,1)) / np.pi * 180.0
+
+ return metric
+
+
+def main():
+ parser = argparse.ArgumentParser(description='Rotation and translation metric.')
+ parser.add_argument('--trans1', type=str)
+ parser.add_argument('--trans2', type=str)
+
+ args = parser.parse_args()
+
+ transFiles1 = natural_sort([file for file in os.listdir(args.trans1) if (file.find("npy") != -1 )])
+ transFiles1 = [os.path.join(args.trans1, img) for img in transFiles1]
+
+ transFiles2 = natural_sort([file for file in os.listdir(args.trans2) if (file.find("npy") != -1 )])
+ transFiles2 = [os.path.join(args.trans2, img) for img in transFiles2]
+
+ # print(len(transFiles1), transFiles1)
+ # print(len(transFiles2), transFiles2)
+
+ for T1_file, T2_file in zip(transFiles1, transFiles2):
+ T1 = np.load(T1_file)
+ T2 = np.load(T2_file)
+ print("Shapes: ", T1.shape, T2.shape)
+
+ for i in range(T1.shape[2]):
+ R1 = T1[:3, :3, i]
+ R2 = T2[:3, :3, i]
+ t1 = T1[:4, -1, i]
+ t2 = T2[:4, -1, i]
+
+ R_norm = angular_distance_np(R1.reshape(1,3,3), R2.reshape(1,3,3))[0]
+
+ print("R norm:", R_norm)
+ exit(1)
+
+
+if __name__ == "__main__":
+ main()
\ No newline at end of file
diff --git a/imcui/third_party/RoRD/trainPT_ipr.py b/imcui/third_party/RoRD/trainPT_ipr.py
new file mode 100644
index 0000000000000000000000000000000000000000..f730bbb52338509956e9979ddb07d5bef0bd57d0
--- /dev/null
+++ b/imcui/third_party/RoRD/trainPT_ipr.py
@@ -0,0 +1,225 @@
+import argparse
+import numpy as np
+import os
+import sys
+
+import shutil
+
+import torch
+import torch.optim as optim
+
+from torch.utils.data import DataLoader
+
+from tqdm import tqdm
+
+import warnings
+
+from lib.exceptions import NoGradientError
+from lib.losses.lossPhotoTourism import loss_function
+from lib.model import D2Net
+from lib.dataloaders.datasetPhotoTourism_ipr import PhotoTourismIPR
+
+
+# CUDA
+use_cuda = torch.cuda.is_available()
+device = torch.device("cuda:0" if use_cuda else "cpu")
+
+# Seed
+torch.manual_seed(1)
+if use_cuda:
+ torch.cuda.manual_seed(1)
+np.random.seed(1)
+
+# Argument parsing
+parser = argparse.ArgumentParser(description='Training script')
+
+parser.add_argument(
+ '--dataset_path', type=str, default="/scratch/udit/phototourism/",
+ help='path to the dataset'
+)
+
+parser.add_argument(
+ '--preprocessing', type=str, default='caffe',
+ help='image preprocessing (caffe or torch)'
+)
+
+parser.add_argument(
+ '--init_model', type=str, default='models/d2net.pth',
+ help='path to the initial model'
+)
+
+parser.add_argument(
+ '--num_epochs', type=int, default=10,
+ help='number of training epochs'
+)
+parser.add_argument(
+ '--lr', type=float, default=1e-3,
+ help='initial learning rate'
+)
+parser.add_argument(
+ '--batch_size', type=int, default=1,
+ help='batch size'
+)
+parser.add_argument(
+ '--num_workers', type=int, default=16,
+ help='number of workers for data loading'
+)
+
+parser.add_argument(
+ '--log_interval', type=int, default=250,
+ help='loss logging interval'
+)
+
+parser.add_argument(
+ '--log_file', type=str, default='log.txt',
+ help='loss logging file'
+)
+
+parser.add_argument(
+ '--plot', dest='plot', action='store_true',
+ help='plot training pairs'
+)
+parser.set_defaults(plot=False)
+
+parser.add_argument(
+ '--checkpoint_directory', type=str, default='checkpoints',
+ help='directory for training checkpoints'
+)
+parser.add_argument(
+ '--checkpoint_prefix', type=str, default='rord',
+ help='prefix for training checkpoints'
+)
+
+args = parser.parse_args()
+print(args)
+
+# Creating CNN model
+model = D2Net(
+ model_file=args.init_model,
+ use_cuda=False
+)
+model = model.to(device)
+
+# Optimizer
+optimizer = optim.Adam(
+ filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr
+)
+
+training_dataset = PhotoTourismIPR(
+ base_path=args.dataset_path,
+ preprocessing=args.preprocessing
+)
+training_dataset.build_dataset()
+
+training_dataloader = DataLoader(
+ training_dataset,
+ batch_size=args.batch_size,
+ num_workers=args.num_workers
+)
+
+# Define epoch function
+def process_epoch(
+ epoch_idx,
+ model, loss_function, optimizer, dataloader, device,
+ log_file, args, train=True, plot_path=None
+):
+ epoch_losses = []
+
+ torch.set_grad_enabled(train)
+
+ progress_bar = tqdm(enumerate(dataloader), total=len(dataloader))
+ for batch_idx, batch in progress_bar:
+ if train:
+ optimizer.zero_grad()
+
+ batch['train'] = train
+ batch['epoch_idx'] = epoch_idx
+ batch['batch_idx'] = batch_idx
+ batch['batch_size'] = args.batch_size
+ batch['preprocessing'] = args.preprocessing
+ batch['log_interval'] = args.log_interval
+
+ try:
+ loss = loss_function(model, batch, device, plot=args.plot, plot_path=plot_path)
+ except NoGradientError:
+ # print("failed")
+ continue
+
+ current_loss = loss.data.cpu().numpy()[0]
+ epoch_losses.append(current_loss)
+
+ progress_bar.set_postfix(loss=('%.4f' % np.mean(epoch_losses)))
+
+ if batch_idx % args.log_interval == 0:
+ log_file.write('[%s] epoch %d - batch %d / %d - avg_loss: %f\n' % (
+ 'train' if train else 'valid',
+ epoch_idx, batch_idx, len(dataloader), np.mean(epoch_losses)
+ ))
+
+ if train:
+ loss.backward()
+ optimizer.step()
+
+ log_file.write('[%s] epoch %d - avg_loss: %f\n' % (
+ 'train' if train else 'valid',
+ epoch_idx,
+ np.mean(epoch_losses)
+ ))
+ log_file.flush()
+
+ return np.mean(epoch_losses)
+
+
+# Create the checkpoint directory
+checkpoint_directory = os.path.join(args.checkpoint_directory, args.checkpoint_prefix)
+if os.path.isdir(checkpoint_directory):
+ print('[Warning] Checkpoint directory already exists.')
+else:
+ os.makedirs(checkpoint_directory, exist_ok=True)
+
+# Open the log file for writing
+log_file = os.path.join(checkpoint_directory,args.log_file)
+if os.path.exists(log_file):
+ print('[Warning] Log file already exists.')
+log_file = open(log_file, 'a+')
+
+# Create the folders for plotting if need be
+plot_path=None
+if args.plot:
+ plot_path = os.path.join(checkpoint_directory,'train_vis')
+ if os.path.isdir(plot_path):
+ print('[Warning] Plotting directory already exists.')
+ else:
+ os.makedirs(plot_path, exist_ok=True)
+
+
+# Initialize the history
+train_loss_history = []
+
+# Start the training
+for epoch_idx in range(1, args.num_epochs + 1):
+ # Process epoch
+ train_loss_history.append(
+ process_epoch(
+ epoch_idx,
+ model, loss_function, optimizer, training_dataloader, device,
+ log_file, args, train=True, plot_path=plot_path
+ )
+ )
+
+ # Save the current checkpoint
+ checkpoint_path = os.path.join(
+ checkpoint_directory,
+ '%02d.pth' % (epoch_idx)
+ )
+ checkpoint = {
+ 'args': args,
+ 'epoch_idx': epoch_idx,
+ 'model': model.state_dict(),
+ 'optimizer': optimizer.state_dict(),
+ 'train_loss_history': train_loss_history,
+ }
+ torch.save(checkpoint, checkpoint_path)
+
+# Close the log file
+log_file.close()
diff --git a/imcui/third_party/RoRD/trainers/trainPT_combined.py b/imcui/third_party/RoRD/trainers/trainPT_combined.py
new file mode 100644
index 0000000000000000000000000000000000000000..a32fcf00937a451195270bc5f2e3e4f43af36237
--- /dev/null
+++ b/imcui/third_party/RoRD/trainers/trainPT_combined.py
@@ -0,0 +1,289 @@
+
+import argparse
+import numpy as np
+import os
+import sys
+sys.path.append("../")
+
+import shutil
+
+import torch
+import torch.optim as optim
+
+from torch.utils.data import DataLoader
+
+from tqdm import tqdm
+
+import warnings
+
+# from lib.dataset import MegaDepthDataset
+
+from lib.exceptions import NoGradientError
+from lib.loss import loss_function as orig_loss
+from lib.losses.lossPhotoTourism import loss_function as ipr_loss
+from lib.model import D2Net
+from lib.dataloaders.datasetPhotoTourism_combined import PhotoTourismCombined
+
+
+# CUDA
+use_cuda = torch.cuda.is_available()
+device = torch.device("cuda:1" if use_cuda else "cpu")
+
+# Seed
+torch.manual_seed(1)
+if use_cuda:
+ torch.cuda.manual_seed(1)
+np.random.seed(1)
+
+# Argument parsing
+parser = argparse.ArgumentParser(description='Training script')
+
+parser.add_argument(
+ '--dataset_path', type=str, default="/scratch/udit/phototourism/",
+ help='path to the dataset'
+)
+# parser.add_argument(
+# '--scene_info_path', type=str, required=True,
+# help='path to the processed scenes'
+# )
+
+parser.add_argument(
+ '--preprocessing', type=str, default='caffe',
+ help='image preprocessing (caffe or torch)'
+)
+
+parser.add_argument(
+ '--model_file', type=str, default='models/d2_ots.pth',
+ help='path to the full model'
+)
+
+parser.add_argument(
+ '--num_epochs', type=int, default=10,
+ help='number of training epochs'
+)
+parser.add_argument(
+ '--lr', type=float, default=1e-3,
+ help='initial learning rate'
+)
+parser.add_argument(
+ '--batch_size', type=int, default=1,
+ help='batch size'
+)
+parser.add_argument(
+ '--num_workers', type=int, default=16,
+ help='number of workers for data loading'
+)
+
+parser.add_argument(
+ '--use_validation', dest='use_validation', action='store_true',
+ help='use the validation split'
+)
+parser.set_defaults(use_validation=False)
+
+parser.add_argument(
+ '--log_interval', type=int, default=250,
+ help='loss logging interval'
+)
+
+parser.add_argument(
+ '--log_file', type=str, default='log.txt',
+ help='loss logging file'
+)
+
+parser.add_argument(
+ '--plot', dest='plot', action='store_true',
+ help='plot training pairs'
+)
+parser.set_defaults(plot=False)
+
+parser.add_argument(
+ '--checkpoint_directory', type=str, default='checkpoints',
+ help='directory for training checkpoints'
+)
+parser.add_argument(
+ '--checkpoint_prefix', type=str, default='d2',
+ help='prefix for training checkpoints'
+)
+
+args = parser.parse_args()
+print(args)
+
+# Creating CNN model
+model = D2Net(
+ model_file=args.model_file,
+ use_cuda=False
+)
+model = model.to(device)
+
+# Optimizer
+optimizer = optim.Adam(
+ filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr
+)
+
+# Dataset
+if args.use_validation:
+ validation_dataset = PhotoTourismCombined(
+ # scene_list_path='megadepth_utils/valid_scenes.txt',
+ # scene_info_path=args.scene_info_path,
+ base_path=args.dataset_path,
+ train=False,
+ preprocessing=args.preprocessing,
+ pairs_per_scene=25
+ )
+ # validation_dataset.build_dataset()
+ validation_dataloader = DataLoader(
+ validation_dataset,
+ batch_size=args.batch_size,
+ num_workers=args.num_workers
+ )
+
+training_dataset = PhotoTourismCombined(
+ # scene_list_path='megadepth_utils/train_scenes.txt',
+ # scene_info_path=args.scene_info_path,
+ base_path=args.dataset_path,
+ preprocessing=args.preprocessing
+)
+# training_dataset.build_dataset()
+
+training_dataloader = DataLoader(
+ training_dataset,
+ batch_size=args.batch_size,
+ num_workers=args.num_workers
+)
+
+
+# Define epoch function
+def process_epoch(
+ epoch_idx,
+ model, loss_function, optimizer, dataloader, device,
+ log_file, args, train=True, plot_path=None
+):
+ epoch_losses = []
+
+ torch.set_grad_enabled(train)
+
+ progress_bar = tqdm(enumerate(dataloader), total=len(dataloader))
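+ # The combined dataloader yields (batch, method) tuples; `method` indexes the
+ # loss_function list passed in below ([orig_loss, ipr_loss]), so each pair is
+ # scored with the corresponding loss.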
+ for batch_idx, (batch,method) in progress_bar:
+ if train:
+ optimizer.zero_grad()
+
+ batch['train'] = train
+ batch['epoch_idx'] = epoch_idx
+ batch['batch_idx'] = batch_idx
+ batch['batch_size'] = args.batch_size
+ batch['preprocessing'] = args.preprocessing
+ batch['log_interval'] = args.log_interval
+
+ try:
+ loss = loss_function[method](model, batch, device, plot=args.plot, plot_path=plot_path)
+ except NoGradientError:
+ # print("failed")
+ continue
+
+ current_loss = loss.data.cpu().numpy()[0]
+ epoch_losses.append(current_loss)
+
+ progress_bar.set_postfix(loss=('%.4f' % np.mean(epoch_losses)))
+
+ if batch_idx % args.log_interval == 0:
+ log_file.write('[%s] epoch %d - batch %d / %d - avg_loss: %f\n' % (
+ 'train' if train else 'valid',
+ epoch_idx, batch_idx, len(dataloader), np.mean(epoch_losses)
+ ))
+
+ if train:
+ loss.backward()
+ optimizer.step()
+
+ log_file.write('[%s] epoch %d - avg_loss: %f\n' % (
+ 'train' if train else 'valid',
+ epoch_idx,
+ np.mean(epoch_losses)
+ ))
+ log_file.flush()
+
+ return np.mean(epoch_losses)
+
+
+# Create the checkpoint directory
+checkpoint_directory = os.path.join(args.checkpoint_directory, args.checkpoint_prefix)
+if os.path.isdir(checkpoint_directory):
+ print('[Warning] Checkpoint directory already exists.')
+else:
+ os.makedirs(checkpoint_directory, exist_ok=True)
+
+# Open the log file for writing
+log_file = os.path.join(checkpoint_directory,args.log_file)
+if os.path.exists(log_file):
+ print('[Warning] Log file already exists.')
+log_file = open(log_file, 'a+')
+
+# Create the folders for plotting if need be
+plot_path=None
+if args.plot:
+ plot_path = os.path.join(checkpoint_directory,'train_vis')
+ if os.path.isdir(plot_path):
+ print('[Warning] Plotting directory already exists.')
+ else:
+ os.makedirs(plot_path, exist_ok=True)
+
+
+# Initialize the history
+train_loss_history = []
+validation_loss_history = []
+if args.use_validation:
+ min_validation_loss = process_epoch(
+ 0,
+ model, [orig_loss, ipr_loss], optimizer, validation_dataloader, device,
+ log_file, args,
+ train=False
+ )
+
+# Start the training
+for epoch_idx in range(1, args.num_epochs + 1):
+ # Process epoch
+ train_loss_history.append(
+ process_epoch(
+ epoch_idx,
+ model, [orig_loss, ipr_loss], optimizer, training_dataloader, device,
+ log_file, args, train=True, plot_path=plot_path
+ )
+ )
+
+ if args.use_validation:
+ validation_loss_history.append(
+ process_epoch(
+ epoch_idx,
+ model, [orig_loss, ipr_loss], optimizer, validation_dataloader, device,
+ log_file, args,
+ train=False
+ )
+ )
+
+ # Save the current checkpoint
+ checkpoint_path = os.path.join(
+ checkpoint_directory,
+ '%02d.pth' % (epoch_idx)
+ )
+ checkpoint = {
+ 'args': args,
+ 'epoch_idx': epoch_idx,
+ 'model': model.state_dict(),
+ 'optimizer': optimizer.state_dict(),
+ 'train_loss_history': train_loss_history,
+ 'validation_loss_history': validation_loss_history
+ }
+ torch.save(checkpoint, checkpoint_path)
+ if (
+ args.use_validation and
+ validation_loss_history[-1] < min_validation_loss
+ ):
+ min_validation_loss = validation_loss_history[-1]
+ best_checkpoint_path = os.path.join(
+ checkpoint_directory,
+ '%s.best.pth' % args.checkpoint_prefix
+ )
+ shutil.copy(checkpoint_path, best_checkpoint_path)
+
+# Close the log file
+log_file.close()
diff --git a/imcui/third_party/SGMNet/components/__init__.py b/imcui/third_party/SGMNet/components/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c10d2027efcf985c68abf7185f28b947012cae45
--- /dev/null
+++ b/imcui/third_party/SGMNet/components/__init__.py
@@ -0,0 +1,3 @@
+from . import extractors
+from . import matchers
+from .load_component import load_component
\ No newline at end of file
diff --git a/imcui/third_party/SGMNet/components/evaluators.py b/imcui/third_party/SGMNet/components/evaluators.py
new file mode 100644
index 0000000000000000000000000000000000000000..59bf0bd7ce3dd085dc86072fc41bad24b9805991
--- /dev/null
+++ b/imcui/third_party/SGMNet/components/evaluators.py
@@ -0,0 +1,127 @@
+import numpy as np
+import sys
+import os
+ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
+sys.path.insert(0, ROOT_DIR)
+
+from utils import evaluation_utils,metrics,fm_utils
+import cv2
+
+class auc_eval:
+ def __init__(self,config):
+ self.config=config
+ self.err_r,self.err_t,self.err=[],[],[]
+ self.ms=[]
+ self.precision=[]
+
+ def run(self,info):
+ E,r_gt,t_gt=info['e'],info['r_gt'],info['t_gt']
+ K1,K2,img1,img2=info['K1'],info['K2'],info['img1'],info['img2']
+ corr1,corr2=info['corr1'],info['corr2']
+ corr1,corr2=evaluation_utils.normalize_intrinsic(corr1,K1),evaluation_utils.normalize_intrinsic(corr2,K2)
+ size1,size2=max(img1.shape),max(img2.shape)
+ scale1,scale2=self.config['rescale']/size1,self.config['rescale']/size2
+ #ransac
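+ # express a ~4 px RANSAC threshold in normalized coordinates by dividing by
+ # the mean rescaled focal length of the two cameras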
+ ransac_th=4./((K1[0,0]+K1[1,1])*scale1+(K2[0,0]+K2[1,1])*scale2)
+ R_hat,t_hat,E_hat=self.estimate(corr1,corr2,ransac_th)
+ #get pose error
+ err_r, err_t=metrics.evaluate_R_t(r_gt,t_gt,R_hat,t_hat)
+ err=max(err_r,err_t)
+
+ if len(corr1)>1:
+ inlier_mask=metrics.compute_epi_inlier(corr1,corr2,E,self.config['inlier_th'])
+ precision=inlier_mask.mean()
+ ms=inlier_mask.sum()/len(info['x1'])
+ else:
+ ms=precision=0
+
+ return {'err_r':err_r,'err_t':err_t,'err':err,'ms':ms,'precision':precision}
+
+ def res_inqueue(self,res):
+ self.err_r.append(res['err_r']),self.err_t.append(res['err_t']),self.err.append(res['err'])
+ self.ms.append(res['ms']),self.precision.append(res['precision'])
+
+ def estimate(self,corr1,corr2,th):
+ num_inlier = -1
+ if corr1.shape[0] >= 5:
+ E, mask_new = cv2.findEssentialMat(corr1, corr2,method=cv2.RANSAC, threshold=th,prob=1-1e-5)
+ if E is None:
+ E=np.eye(3)
+ # findEssentialMat may return several stacked 3x3 candidates; score each one
+ for _E in np.split(E, len(E) // 3):
+ _num_inlier, _R, _t, _ = cv2.recoverPose(_E, corr1, corr2,np.eye(3), 1e9,mask=mask_new)
+ if _num_inlier > num_inlier:
+ num_inlier = _num_inlier
+ R = _R
+ t = _t
+ E = _E
+ else:
+ E,R,t=np.eye(3),np.eye(3),np.zeros(3)
+ return R,t,E
+
+ def parse(self):
+ ths = np.arange(7) * 5
+ approx_auc=metrics.approx_pose_auc(self.err,ths)
+ exact_auc=metrics.pose_auc(self.err,ths)
+ mean_pre,mean_ms=np.mean(np.asarray(self.precision)),np.mean(np.asarray(self.ms))
+
+ print('auc th: ',ths[1:])
+ print('approx auc: ',approx_auc)
+ print('exact auc: ', exact_auc)
+ print('mean match score: ',mean_ms*100)
+ print('mean precision: ',mean_pre*100)
+
+
+
+class FMbench_eval:
+
+ def __init__(self,config):
+ self.config=config
+ self.pre,self.pre_post,self.sgd=[],[],[]
+ self.num_corr,self.num_corr_post=[],[]
+
+ def run(self,info):
+ corr1,corr2=info['corr1'],info['corr2']
+ F=info['f']
+ img1,img2=info['img1'],info['img2']
+
+ if len(corr1)>1:
+ pre_bf=fm_utils.compute_inlier_rate(corr1,corr2,np.flip(img1.shape[:2]),np.flip(img2.shape[:2]),F,th=self.config['inlier_th']).mean()
+ F_hat,mask_F=cv2.findFundamentalMat(corr1,corr2,method=cv2.FM_RANSAC,ransacReprojThreshold=1,confidence=1-1e-5)
+ if F_hat is None:
+ F_hat=np.ones([3,3])
+ mask_F=np.ones([len(corr1)]).astype(bool)
+ else:
+ mask_F=mask_F.squeeze().astype(bool)
+ F_hat=F_hat[:3]
+ pre_af=fm_utils.compute_inlier_rate(corr1[mask_F],corr2[mask_F],np.flip(img1.shape[:2]),np.flip(img2.shape[:2]),F,th=self.config['inlier_th']).mean()
+ num_corr_af=mask_F.sum()
+ num_corr=len(corr1)
+ sgd=fm_utils.compute_SGD(F,F_hat,np.flip(img1.shape[:2]),np.flip(img2.shape[:2]))
+ else:
+ pre_bf,pre_af,sgd=0,0,1e8
+ num_corr,num_corr_af=0,0
+ return {'pre':pre_bf,'pre_post':pre_af,'sgd':sgd,'num_corr':num_corr,'num_corr_post':num_corr_af}
+
+
+ def res_inqueue(self,res):
+ self.pre.append(res['pre']),self.pre_post.append(res['pre_post']),self.sgd.append(res['sgd'])
+ self.num_corr.append(res['num_corr']),self.num_corr_post.append(res['num_corr_post'])
+
+ def parse(self):
+ for seq_index in range(len(self.config['seq'])):
+ seq=self.config['seq'][seq_index]
+ offset=seq_index*1000
+ pre=np.asarray(self.pre)[offset:offset+1000].mean()
+ pre_post=np.asarray(self.pre_post)[offset:offset+1000].mean()
+ num_corr=np.asarray(self.num_corr)[offset:offset+1000].mean()
+ num_corr_post=np.asarray(self.num_corr_post)[offset:offset+1000].mean()
+ f_recall=(np.asarray(self.sgd)[offset:offset+1000]<self.config['sgd_inlier_th']).mean()
+ mask_th,index,index2=score[:,0]>self.p_th,index[:,0],index2.squeeze(0)
+ mask_mc=index2[index] == torch.arange(len(p)).cuda()
+ mask=mask_th&mask_mc
+ index1,index2=torch.nonzero(mask).squeeze(1),index[mask]
+ return index1,index2
+
+
+class NN_Matcher(object):
+
+ def __init__(self,config):
+ config=namedtuple('config',config.keys())(*config.values())
+ self.mutual_check=config.mutual_check
+ self.ratio_th=config.ratio_th
+
+ def run(self,test_data):
+ desc1,desc2,x1,x2=test_data['desc1'],test_data['desc2'],test_data['x1'],test_data['x2']
+ desc_mat=np.sqrt(abs((desc1**2).sum(-1)[:,np.newaxis]+(desc2**2).sum(-1)[np.newaxis]-2*desc1@desc2.T))
+ nn_index=np.argpartition(desc_mat,kth=(1,2),axis=-1)
+ dis_value12=np.take_along_axis(desc_mat,nn_index, axis=-1)
+ ratio_score=dis_value12[:,0]/dis_value12[:,1]
+ nn_index1=nn_index[:,0]
+ nn_index2=np.argmin(desc_mat,axis=0)
+ mask_ratio,mask_mutual=ratio_score<self.ratio_th,nn_index2[nn_index1]==np.arange(len(desc1))
+ # (pairs are additionally filtered by the angle_th and overlap_th bounds before this check)
+ if len(corr_index)>self.config['min_corr'] and len(incorr_index1)>self.config['min_incorr'] and len(incorr_index2)>self.config['min_incorr']:
+ info['corr'].append(corr_index),info['incorr1'].append(incorr_index1),info['incorr2'].append(incorr_index2)
+ info['dR'].append(dR),info['dt'].append(dt),info['K1'].append(K1),info['K2'].append(K2),info['img_path1'].append(img_path1),info['img_path2'].append(img_path2)
+ info['fea_path1'].append(fea_path1),info['fea_path2'].append(fea_path2),info['size1'].append(size1),info['size2'].append(size2)
+ sample_number+=1
+ if sample_number==sample_target:
+ break
+ info['pair_num']=sample_number
+ #dump info
+ self.dump_info(seq,info)
+
+
+ def collect_meta(self):
+ print('collecting meta info...')
+ dump_path,seq_list=[],[]
+ if self.config['dump_train']:
+ dump_path.append(os.path.join(self.config['dataset_dump_dir'],'train'))
+ seq_list.append(self.train_list)
+ if self.config['dump_valid']:
+ dump_path.append(os.path.join(self.config['dataset_dump_dir'],'valid'))
+ seq_list.append(self.valid_list)
+ for pth,seqs in zip(dump_path,seq_list):
+ if not os.path.exists(pth):
+ os.mkdir(pth)
+ pair_num_list,total_pair=[],0
+ for seq_index in range(len(seqs)):
+ seq=seqs[seq_index]
+ pair_num=np.loadtxt(os.path.join(self.config['dataset_dump_dir'],seq,'pair_num.txt'),dtype=int)
+ pair_num_list.append(str(pair_num))
+ total_pair+=pair_num
+ pair_num_list=np.stack([np.asarray(seqs,dtype=str),np.asarray(pair_num_list,dtype=str)],axis=1)
+ pair_num_list=np.concatenate([np.asarray([['total',str(total_pair)]]),pair_num_list],axis=0)
+ np.savetxt(os.path.join(pth,'pair_num.txt'),pair_num_list,fmt='%s')
+
+ def format_dump_data(self):
+ print('Formatting data...')
+ iteration_num=len(self.seq_list)//self.config['num_process']
+ if len(self.seq_list)%self.config['num_process']!=0:
+ iteration_num+=1
+ pool=Pool(self.config['num_process'])
+ for index in trange(iteration_num):
+ indices=range(index*self.config['num_process'],min((index+1)*self.config['num_process'],len(self.seq_list)))
+ pool.map(self.format_seq,indices)
+ pool.close()
+ pool.join()
+
+ self.collect_meta()
\ No newline at end of file
diff --git a/imcui/third_party/SGMNet/datadump/dumper/scannet.py b/imcui/third_party/SGMNet/datadump/dumper/scannet.py
new file mode 100644
index 0000000000000000000000000000000000000000..2556f727fcc9b4c621e44d9ee5cb4e99cb19b7e8
--- /dev/null
+++ b/imcui/third_party/SGMNet/datadump/dumper/scannet.py
@@ -0,0 +1,72 @@
+import os
+import glob
+import pickle
+from posixpath import basename
+import numpy as np
+import h5py
+from .base_dumper import BaseDumper
+
+import sys
+ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../"))
+sys.path.insert(0, ROOT_DIR)
+import utils
+
+class scannet(BaseDumper):
+ def get_seqs(self):
+ self.pair_list=np.loadtxt('../assets/scannet_eval_list.txt',dtype=str)
+ self.seq_list=np.unique(np.asarray([path.split('/')[0] for path in self.pair_list[:,0]],dtype=str))
+ self.dump_seq,self.img_seq=[],[]
+ for seq in self.seq_list:
+ dump_dir=os.path.join(self.config['feature_dump_dir'],seq)
+ cur_img_seq=glob.glob(os.path.join(os.path.join(self.config['rawdata_dir'],seq,'img','*.jpg')))
+ cur_dump_seq=[os.path.join(dump_dir,path.split('/')[-1])+'_'+self.config['extractor']['name']+'_'+str(self.config['extractor']['num_kpt'])\
+ +'.hdf5' for path in cur_img_seq]
+ self.img_seq+=cur_img_seq
+ self.dump_seq+=cur_dump_seq
+
+ def format_dump_folder(self):
+ if not os.path.exists(self.config['feature_dump_dir']):
+ os.mkdir(self.config['feature_dump_dir'])
+ for seq in self.seq_list:
+ seq_dir=os.path.join(self.config['feature_dump_dir'],seq)
+ if not os.path.exists(seq_dir):
+ os.mkdir(seq_dir)
+
+ def format_dump_data(self):
+ print('Formatting data...')
+ self.data={'K1':[],'K2':[],'R':[],'T':[],'e':[],'f':[],'fea_path1':[],'fea_path2':[],'img_path1':[],'img_path2':[]}
+
+ for pair in self.pair_list:
+ img_path1,img_path2=pair[0],pair[1]
+ seq=img_path1.split('/')[0]
+ index1,index2=int(img_path1.split('/')[-1][:-4]),int(img_path2.split('/')[-1][:-4])
+ ex1,ex2=np.loadtxt(os.path.join(self.config['rawdata_dir'],seq,'extrinsic',str(index1)+'.txt'),dtype=float),\
+ np.loadtxt(os.path.join(self.config['rawdata_dir'],seq,'extrinsic',str(index2)+'.txt'),dtype=float)
+ K1,K2=np.loadtxt(os.path.join(self.config['rawdata_dir'],seq,'intrinsic',str(index1)+'.txt'),dtype=float),\
+ np.loadtxt(os.path.join(self.config['rawdata_dir'],seq,'intrinsic',str(index2)+'.txt'),dtype=float)
+
+
+ relative_extrinsic=np.matmul(np.linalg.inv(ex2),ex1)
+ dR,dt=relative_extrinsic[:3,:3],relative_extrinsic[:3,3]
+ dt /= np.sqrt(np.sum(dt**2))
+
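+ # ground-truth essential matrix E = [t]_x R (normalized), and the matching
+ # fundamental matrix F = K2^{-T} E K1^{-1}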
+ e_gt_unnorm = np.reshape(np.matmul(
+ np.reshape(utils.evaluation_utils.np_skew_symmetric(dt.astype('float64').reshape(1, 3)), (3, 3)),
+ np.reshape(dR.astype('float64'), (3, 3))), (3, 3))
+ e_gt = e_gt_unnorm / np.linalg.norm(e_gt_unnorm)
+ f_gt_unnorm=np.linalg.inv(K2.T)@e_gt@np.linalg.inv(K1)
+ f_gt = f_gt_unnorm / np.linalg.norm(f_gt_unnorm)
+
+ self.data['K1'].append(K1),self.data['K2'].append(K2)
+ self.data['R'].append(dR),self.data['T'].append(dt)
+ self.data['e'].append(e_gt),self.data['f'].append(f_gt)
+
+ dump_seq_dir=os.path.join(self.config['feature_dump_dir'],seq)
+ fea_path1,fea_path2=os.path.join(dump_seq_dir,img_path1.split('/')[-1]+'_'+self.config['extractor']['name']
+ +'_'+str(self.config['extractor']['num_kpt'])+'.hdf5'),\
+ os.path.join(dump_seq_dir,img_path2.split('/')[-1]+'_'+self.config['extractor']['name']
+ +'_'+str(self.config['extractor']['num_kpt'])+'.hdf5')
+ self.data['img_path1'].append(img_path1),self.data['img_path2'].append(img_path2)
+ self.data['fea_path1'].append(fea_path1),self.data['fea_path2'].append(fea_path2)
+
+ self.form_standard_dataset()
diff --git a/imcui/third_party/SGMNet/datadump/dumper/yfcc.py b/imcui/third_party/SGMNet/datadump/dumper/yfcc.py
new file mode 100644
index 0000000000000000000000000000000000000000..0c52e4324bba3e5ed424fe58af7a94fd3132b1e5
--- /dev/null
+++ b/imcui/third_party/SGMNet/datadump/dumper/yfcc.py
@@ -0,0 +1,87 @@
+import os
+import glob
+import pickle
+import numpy as np
+import h5py
+from .base_dumper import BaseDumper
+
+import sys
+ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../"))
+sys.path.insert(0, ROOT_DIR)
+import utils
+
+class yfcc(BaseDumper):
+
+ def get_seqs(self):
+ data_dir=os.path.join(self.config['rawdata_dir'],'yfcc100m')
+ for seq in self.config['data_seq']:
+ for split in self.config['data_split']:
+ split_dir=os.path.join(data_dir,seq,split)
+ dump_dir=os.path.join(self.config['feature_dump_dir'],seq,split)
+ cur_img_seq=glob.glob(os.path.join(split_dir,'images','*.jpg'))
+ cur_dump_seq=[os.path.join(dump_dir,path.split('/')[-1])+'_'+self.config['extractor']['name']+'_'+str(self.config['extractor']['num_kpt'])\
+ +'.hdf5' for path in cur_img_seq]
+ self.img_seq+=cur_img_seq
+ self.dump_seq+=cur_dump_seq
+
+ def format_dump_folder(self):
+ if not os.path.exists(self.config['feature_dump_dir']):
+ os.mkdir(self.config['feature_dump_dir'])
+ for seq in self.config['data_seq']:
+ seq_dir=os.path.join(self.config['feature_dump_dir'],seq)
+ if not os.path.exists(seq_dir):
+ os.mkdir(seq_dir)
+ for split in self.config['data_split']:
+ split_dir=os.path.join(seq_dir,split)
+ if not os.path.exists(split_dir):
+ os.mkdir(split_dir)
+
+ def format_dump_data(self):
+ print('Formatting data...')
+ pair_path=os.path.join(self.config['rawdata_dir'],'pairs')
+ self.data={'K1':[],'K2':[],'R':[],'T':[],'e':[],'f':[],'fea_path1':[],'fea_path2':[],'img_path1':[],'img_path2':[]}
+
+ for seq in self.config['data_seq']:
+ pair_name=os.path.join(pair_path,seq+'-te-1000-pairs.pkl')
+ with open(pair_name, 'rb') as f:
+ pairs=pickle.load(f)
+
+ #generate id list
+ seq_dir=os.path.join(self.config['rawdata_dir'],'yfcc100m',seq,'test')
+ name_list=np.loadtxt(os.path.join(seq_dir,'images.txt'),dtype=str)
+ cam_name_list=np.loadtxt(os.path.join(seq_dir,'calibration.txt'),dtype=str)
+
+ for cur_pair in pairs:
+ index1,index2=cur_pair[0],cur_pair[1]
+ cam1,cam2=h5py.File(os.path.join(seq_dir,cam_name_list[index1]),'r'),h5py.File(os.path.join(seq_dir,cam_name_list[index2]),'r')
+ K1,K2=cam1['K'][()],cam2['K'][()]
+ [w1,h1],[w2,h2]=cam1['imsize'][()][0],cam2['imsize'][()][0]
+ cx1,cy1,cx2,cy2 = (w1 - 1.0) * 0.5,(h1 - 1.0) * 0.5, (w2 - 1.0) * 0.5,(h2 - 1.0) * 0.5
+ K1[0,2],K1[1,2],K2[0,2],K2[1,2]=cx1,cy1,cx2,cy2
+
+ R1,R2,t1,t2=cam1['R'][()],cam2['R'][()],cam1['T'][()].reshape([3,1]),cam2['T'][()].reshape([3,1])
+ dR = np.dot(R2, R1.T)
+ dt = t2 - np.dot(dR, t1)
+ dt /= np.sqrt(np.sum(dt**2))
+
+ e_gt_unnorm = np.reshape(np.matmul(
+ np.reshape(utils.evaluation_utils.np_skew_symmetric(dt.astype('float64').reshape(1, 3)), (3, 3)),
+ np.reshape(dR.astype('float64'), (3, 3))), (3, 3))
+ e_gt = e_gt_unnorm / np.linalg.norm(e_gt_unnorm)
+ f_gt_unnorm=np.linalg.inv(K2.T)@e_gt@np.linalg.inv(K1)
+ f_gt = f_gt_unnorm / np.linalg.norm(f_gt_unnorm)
+
+ self.data['K1'].append(K1),self.data['K2'].append(K2)
+ self.data['R'].append(dR),self.data['T'].append(dt)
+ self.data['e'].append(e_gt),self.data['f'].append(f_gt)
+
+ img_path1,img_path2=os.path.join('yfcc100m',seq,'test',name_list[index1]),os.path.join('yfcc100m',seq,'test',name_list[index2])
+ dump_seq_dir=os.path.join(self.config['feature_dump_dir'],seq,'test')
+ fea_path1,fea_path2=os.path.join(dump_seq_dir,name_list[index1].split('/')[-1]+'_'+self.config['extractor']['name']
+ +'_'+str(self.config['extractor']['num_kpt'])+'.hdf5'),\
+ os.path.join(dump_seq_dir,name_list[index2].split('/')[-1]+'_'+self.config['extractor']['name']
+ +'_'+str(self.config['extractor']['num_kpt'])+'.hdf5')
+ self.data['img_path1'].append(img_path1),self.data['img_path2'].append(img_path2)
+ self.data['fea_path1'].append(fea_path1),self.data['fea_path2'].append(fea_path2)
+
+ self.form_standard_dataset()
diff --git a/imcui/third_party/SGMNet/demo/configs/nn_config.yaml b/imcui/third_party/SGMNet/demo/configs/nn_config.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..a87bfafce0cb7f8ab64e59311923d309aabcfab9
--- /dev/null
+++ b/imcui/third_party/SGMNet/demo/configs/nn_config.yaml
@@ -0,0 +1,10 @@
+extractor:
+ name: root
+ num_kpt: 4000
+ resize: [-1]
+ det_th: 0.00001
+
+matcher:
+ name: NN
+ ratio_th: 0.9
+ mutual_check: True
\ No newline at end of file
diff --git a/imcui/third_party/SGMNet/demo/configs/sgm_config.yaml b/imcui/third_party/SGMNet/demo/configs/sgm_config.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..91de752010daa54ef0b508ef79d2dc4ac23945ec
--- /dev/null
+++ b/imcui/third_party/SGMNet/demo/configs/sgm_config.yaml
@@ -0,0 +1,21 @@
+extractor:
+ name: root
+ num_kpt: 4000
+ resize: [-1]
+ det_th: 0.00001
+
+matcher:
+ name: SGM
+ model_dir: ../weights/sgm/root
+ seed_top_k: [256,256]
+ seed_radius_coe: 0.01
+ net_channels: 128
+ layer_num: 9
+ head: 4
+ seedlayer: [0,6]
+ use_mc_seeding: True
+ use_score_encoding: False
+ conf_bar: [1.11,0.1]
+ sink_iter: [10,100]
+ detach_iter: 1000000
+ p_th: 0.2
diff --git a/imcui/third_party/SGMNet/demo/demo.py b/imcui/third_party/SGMNet/demo/demo.py
new file mode 100644
index 0000000000000000000000000000000000000000..cbe277e26d09121f5517854a7ea014b0797a2bde
--- /dev/null
+++ b/imcui/third_party/SGMNet/demo/demo.py
@@ -0,0 +1,45 @@
+import cv2
+import yaml
+import numpy as np
+import os
+import sys
+
+ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
+sys.path.insert(0, ROOT_DIR)
+from components import load_component
+from utils import evaluation_utils
+
+import argparse
+parser = argparse.ArgumentParser()
+parser.add_argument('--config_path', type=str, default='configs/sgm_config.yaml',
+ help='path to the matcher/extractor config file.')
+parser.add_argument('--img1_path', type=str, default='demo_1.jpg',
+ help='path to the first input image.')
+parser.add_argument('--img2_path', type=str, default='demo_2.jpg',
+ help='path to the second input image.')
+
+
+args = parser.parse_args()
+
+if __name__=='__main__':
+ with open(args.config_path, 'r') as f:
+ demo_config = yaml.load(f, Loader=yaml.FullLoader)
+
+ extractor=load_component('extractor',demo_config['extractor']['name'],demo_config['extractor'])
+
+ img1,img2=cv2.imread(args.img1_path),cv2.imread(args.img2_path)
+ size1,size2=np.flip(np.asarray(img1.shape[:2])),np.flip(np.asarray(img2.shape[:2]))
+ kpt1,desc1=extractor.run(args.img1_path)
+ kpt2,desc2=extractor.run(args.img2_path)
+
+ matcher=load_component('matcher',demo_config['matcher']['name'],demo_config['matcher'])
+ test_data={'x1':kpt1,'x2':kpt2,'desc1':desc1,'desc2':desc2,'size1':size1,'size2':size2}
+ corr1,corr2= matcher.run(test_data)
+
+ #draw points
+ dis_points_1 = evaluation_utils.draw_points(img1, kpt1)
+ dis_points_2 = evaluation_utils.draw_points(img2, kpt2)
+
+ #visualize match
+ display=evaluation_utils.draw_match(dis_points_1,dis_points_2,corr1,corr2)
+ cv2.imwrite('match.png',display)
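+
+# Example usage (paths are the argparse defaults above):
+#   python demo.py --config_path configs/sgm_config.yaml --img1_path demo_1.jpg --img2_path demo_2.jpg
+# The match visualization is written to match.png.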
diff --git a/imcui/third_party/SGMNet/evaluation/configs/cost/sg_cost.yaml b/imcui/third_party/SGMNet/evaluation/configs/cost/sg_cost.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..05ea5ddc7bce8ad94d3ef3ec350363b5cc846ed8
--- /dev/null
+++ b/imcui/third_party/SGMNet/evaluation/configs/cost/sg_cost.yaml
@@ -0,0 +1,4 @@
+net_channels: 128
+layer_num: 9
+head: 4
+use_score_encoding: True
\ No newline at end of file
diff --git a/imcui/third_party/SGMNet/evaluation/configs/cost/sgm_cost.yaml b/imcui/third_party/SGMNet/evaluation/configs/cost/sgm_cost.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..2f43193fb63fb26d50a8c3abd3cf53c43734dbca
--- /dev/null
+++ b/imcui/third_party/SGMNet/evaluation/configs/cost/sgm_cost.yaml
@@ -0,0 +1,11 @@
+seed_top_k: [256,256]
+seed_radius_coe: 0.01
+net_channels: 128
+layer_num: 9
+head: 4
+seedlayer: [0,6]
+use_mc_seeding: True
+use_score_encoding: False
+conf_bar: [1,0]
+sink_iter: [10,10]
+detach_iter: 1000000
\ No newline at end of file
diff --git a/imcui/third_party/SGMNet/evaluation/configs/eval/fm_eval_nn.yaml b/imcui/third_party/SGMNet/evaluation/configs/eval/fm_eval_nn.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..0d467a814559a27938f010dbf79a8e208551b2b5
--- /dev/null
+++ b/imcui/third_party/SGMNet/evaluation/configs/eval/fm_eval_nn.yaml
@@ -0,0 +1,18 @@
+reader:
+ name: standard
+ rawdata_dir: FM-Bench/Dataset
+ dataset_dir: test_fmbench_root/fmbench_root_4000.hdf5
+ num_kpt: 4000
+
+matcher:
+ name: NN
+ mutual_check: False
+ ratio_th: 0.8
+
+evaluator:
+ name: FM
+ seq: ['CPC','KITTI','TUM','Tanks_and_Temples']
+ num_pair: 4000
+ inlier_th: 0.003
+ sgd_inlier_th: 0.05
+
diff --git a/imcui/third_party/SGMNet/evaluation/configs/eval/fm_eval_sg.yaml b/imcui/third_party/SGMNet/evaluation/configs/eval/fm_eval_sg.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..ec22b1340d62fad20f22584ddbded30fcc59d1c9
--- /dev/null
+++ b/imcui/third_party/SGMNet/evaluation/configs/eval/fm_eval_sg.yaml
@@ -0,0 +1,22 @@
+reader:
+ name: standard
+ rawdata_dir: FM-Bench/Dataset
+ dataset_dir: test_fmbench_root/fmbench_root_4000.hdf5
+ num_kpt: 4000
+
+matcher:
+ name: SG
+ model_dir: ../weights/sg/root
+ net_channels: 128
+ layer_num: 9
+ head: 4
+ use_score_encoding: True
+ sink_iter: [100]
+ p_th: 0.2
+
+evaluator:
+ name: FM
+ seq: ['CPC','KITTI','TUM','Tanks_and_Temples']
+ num_pair: 4000
+ inlier_th: 0.003
+ sgd_inlier_th: 0.05
diff --git a/imcui/third_party/SGMNet/evaluation/configs/eval/fm_eval_sgm.yaml b/imcui/third_party/SGMNet/evaluation/configs/eval/fm_eval_sgm.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..cd23165c95451cd44063a2b6cccea21c68fb6fa0
--- /dev/null
+++ b/imcui/third_party/SGMNet/evaluation/configs/eval/fm_eval_sgm.yaml
@@ -0,0 +1,28 @@
+reader:
+ name: standard
+ rawdata_dir: FM-Bench/Dataset
+ dataset_dir: test_fmbench_root/fmbench_root_4000.hdf5
+ num_kpt: 4000
+
+matcher:
+ name: SGM
+ model_dir: ../weights/sgm/root
+ seed_top_k: [256,256]
+ seed_radius_coe: 0.01
+ net_channels: 128
+ layer_num: 9
+ head: 4
+ seedlayer: [0,6]
+ use_mc_seeding: True
+ use_score_encoding: False
+ conf_bar: [1.11,0.1] #set to [1,0.1] for sp
+ sink_iter: [10,100]
+ detach_iter: 1000000
+ p_th: 0.2
+
+evaluator:
+ name: FM
+ seq: ['CPC','KITTI','TUM','Tanks_and_Temples']
+ num_pair: 4000
+ inlier_th: 0.003
+ sgd_inlier_th: 0.05
diff --git a/imcui/third_party/SGMNet/evaluation/configs/eval/scannet_eval_nn.yaml b/imcui/third_party/SGMNet/evaluation/configs/eval/scannet_eval_nn.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..51ad5402b6266b60a365181371be8a5e64751d2f
--- /dev/null
+++ b/imcui/third_party/SGMNet/evaluation/configs/eval/scannet_eval_nn.yaml
@@ -0,0 +1,17 @@
+reader:
+ name: standard
+ rawdata_dir: scannet_eval
+ dataset_dir: scannet_test_root/scannet_root_2000.hdf5
+ num_kpt: 2000
+
+matcher:
+ name: NN
+ mutual_check: False
+ ratio_th: 0.8
+
+evaluator:
+ name: AUC
+ rescale: 640
+ num_pair: 1500
+ inlier_th: 0.005
+
diff --git a/imcui/third_party/SGMNet/evaluation/configs/eval/scannet_eval_sg.yaml b/imcui/third_party/SGMNet/evaluation/configs/eval/scannet_eval_sg.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..0d0ef70cfa07b1471816cc7905d6a632599d134c
--- /dev/null
+++ b/imcui/third_party/SGMNet/evaluation/configs/eval/scannet_eval_sg.yaml
@@ -0,0 +1,22 @@
+reader:
+ name: standard
+ rawdata_dir: scannet_eval
+ dataset_dir: scannet_test_root/scannet_root_2000.hdf5
+ num_kpt: 2000
+
+matcher:
+ name: SG
+ model_dir: ../weights/sg/root
+ net_channels: 128
+ layer_num: 9
+ head: 4
+ use_score_encoding: True
+ sink_iter: [100]
+ p_th: 0.2
+
+evaluator:
+ name: AUC
+ rescale: 640
+ num_pair: 1500
+ inlier_th: 0.005
+
diff --git a/imcui/third_party/SGMNet/evaluation/configs/eval/scannet_eval_sgm.yaml b/imcui/third_party/SGMNet/evaluation/configs/eval/scannet_eval_sgm.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..e524845a514e6d8d50f97bced5c9beeaed26ebe5
--- /dev/null
+++ b/imcui/third_party/SGMNet/evaluation/configs/eval/scannet_eval_sgm.yaml
@@ -0,0 +1,28 @@
+reader:
+ name: standard
+ rawdata_dir: scannet_eval
+ dataset_dir: scannet_test_root/scannet_root_2000.hdf5
+ num_kpt: 2000
+
+matcher:
+ name: SGM
+ model_dir: ../weights/sgm/root
+ seed_top_k: [128,128]
+ seed_radius_coe: 0.01
+ net_channels: 128
+ layer_num: 9
+ head: 4
+ seedlayer: [0,6]
+ use_mc_seeding: True
+ use_score_encoding: False
+ conf_bar: [1.11,0.1]
+ sink_iter: [10,100]
+ detach_iter: 1000000
+ p_th: 0.2
+
+evaluator:
+ name: AUC
+ rescale: 640
+ num_pair: 1500
+ inlier_th: 0.005
+
diff --git a/imcui/third_party/SGMNet/evaluation/configs/eval/yfcc_eval_nn.yaml b/imcui/third_party/SGMNet/evaluation/configs/eval/yfcc_eval_nn.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..8ecd1eef2cff9b93f3665a9cf4af6bc9f68339f0
--- /dev/null
+++ b/imcui/third_party/SGMNet/evaluation/configs/eval/yfcc_eval_nn.yaml
@@ -0,0 +1,17 @@
+reader:
+ name: standard
+ rawdata_dir: yfcc_rawdata
+ dataset_dir: yfcc_test_root/yfcc_root_2000.hdf5
+ num_kpt: 2000
+
+matcher:
+ name: NN
+ mutual_check: False
+ ratio_th: 0.8
+
+evaluator:
+ name: AUC
+ rescale: 1600
+ num_pair: 4000
+ inlier_th: 0.005
+
diff --git a/imcui/third_party/SGMNet/evaluation/configs/eval/yfcc_eval_sg.yaml b/imcui/third_party/SGMNet/evaluation/configs/eval/yfcc_eval_sg.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..beb2b93639160448dd955cd576e5a19a936b08f1
--- /dev/null
+++ b/imcui/third_party/SGMNet/evaluation/configs/eval/yfcc_eval_sg.yaml
@@ -0,0 +1,22 @@
+reader:
+ name: standard
+ rawdata_dir: yfcc_rawdata
+ dataset_dir: yfcc_test_root/yfcc_root_2000.hdf5
+ num_kpt: 2000
+
+matcher:
+ name: SG
+ model_dir: ../weights/sg/root
+ net_channels: 128
+ layer_num: 9
+ head: 4
+ use_score_encoding: True
+ sink_iter: [100]
+ p_th: 0.2
+
+evaluator:
+ name: AUC
+ rescale: 1600
+ num_pair: 4000
+ inlier_th: 0.005
+
diff --git a/imcui/third_party/SGMNet/evaluation/configs/eval/yfcc_eval_sgm.yaml b/imcui/third_party/SGMNet/evaluation/configs/eval/yfcc_eval_sgm.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..6c9aee8a8aa786ff209a5afadf0469f62ef2a50f
--- /dev/null
+++ b/imcui/third_party/SGMNet/evaluation/configs/eval/yfcc_eval_sgm.yaml
@@ -0,0 +1,28 @@
+reader:
+ name: standard
+ rawdata_dir: yfcc_rawdata
+ dataset_dir: yfcc_test_root/yfcc_root_2000.hdf5
+ num_kpt: 2000
+
+matcher:
+ name: SGM
+ model_dir: ../weights/sgm/root
+ seed_top_k: [128,128]
+ seed_radius_coe: 0.01
+ net_channels: 128
+ layer_num: 9
+ head: 4
+ seedlayer: [0,6]
+ use_mc_seeding: True
+ use_score_encoding: False
+ conf_bar: [1.11,0.1] #set to [1,0.1] for sp
+ sink_iter: [10,100]
+ detach_iter: 1000000
+ p_th: 0.2
+
+evaluator:
+ name: AUC
+ rescale: 1600
+ num_pair: 4000
+ inlier_th: 0.005
+
diff --git a/imcui/third_party/SGMNet/evaluation/eval_cost.py b/imcui/third_party/SGMNet/evaluation/eval_cost.py
new file mode 100644
index 0000000000000000000000000000000000000000..dd3f88abc93290c96ed3d7fa8624c3534e006911
--- /dev/null
+++ b/imcui/third_party/SGMNet/evaluation/eval_cost.py
@@ -0,0 +1,60 @@
+import torch
+import yaml
+import time
+from collections import OrderedDict,namedtuple
+import os
+import sys
+ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
+sys.path.insert(0, ROOT_DIR)
+
+from sgmnet import matcher as SGM_Model
+from superglue import matcher as SG_Model
+
+
+import argparse
+parser = argparse.ArgumentParser()
+parser.add_argument('--matcher_name', type=str, default='SGM',
+ help='matcher to benchmark (SGM or SG).')
+parser.add_argument('--config_path', type=str, default='configs/cost/sgm_cost.yaml',
+ help='path to the matcher config file.')
+parser.add_argument('--num_kpt', type=int, default=4000,
+ help='number of random keypoints, default: 4000')
+parser.add_argument('--iter_num', type=int, default=100,
+ help='number of timed iterations, default: 100')
+
+
+def test_cost(test_data,model):
+ with torch.no_grad():
+ #warm up call
+ _=model(test_data)
+ torch.cuda.synchronize()
+ a=time.time()
+ for _ in range(int(args.iter_num)):
+ _=model(test_data)
+ torch.cuda.synchronize()
+ b=time.time()
+ print('Average time per run(ms): ',(b-a)/args.iter_num*1e3)
+ print('Peak memory(MB): ',torch.cuda.max_memory_allocated()/1e6)
+
+
+if __name__=='__main__':
+ torch.backends.cudnn.benchmark=False
+ args = parser.parse_args()
+ with open(args.config_path, 'r') as f:
+ model_config = yaml.load(f, Loader=yaml.FullLoader)
+ model_config=namedtuple('model_config',model_config.keys())(*model_config.values())
+
+ if args.matcher_name=='SGM':
+ model = SGM_Model(model_config)
+ elif args.matcher_name=='SG':
+ model = SG_Model(model_config)
+ model.cuda(),model.eval()
+
+ test_data = {
+ 'x1':torch.rand(1,args.num_kpt,2).cuda()-0.5,
+ 'x2':torch.rand(1,args.num_kpt,2).cuda()-0.5,
+ 'desc1': torch.rand(1,args.num_kpt,128).cuda(),
+ 'desc2': torch.rand(1,args.num_kpt,128).cuda()
+ }
+
+ test_cost(test_data,model)
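+
+# Example usage (defaults shown above):
+#   python eval_cost.py --matcher_name SGM --config_path configs/cost/sgm_cost.yaml --num_kpt 4000 --iter_num 100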
diff --git a/imcui/third_party/SGMNet/evaluation/evaluate.py b/imcui/third_party/SGMNet/evaluation/evaluate.py
new file mode 100644
index 0000000000000000000000000000000000000000..dd5229375caa03b2763bf37a266fb76e80f8e25e
--- /dev/null
+++ b/imcui/third_party/SGMNet/evaluation/evaluate.py
@@ -0,0 +1,117 @@
+import os
+from torch.multiprocessing import Process,Manager,set_start_method,Pool
+import functools
+import argparse
+import yaml
+import numpy as np
+import sys
+import cv2
+from tqdm import trange
+set_start_method('spawn',force=True)
+
+
+ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
+sys.path.insert(0, ROOT_DIR)
+
+from components import load_component
+from utils import evaluation_utils,metrics
+
+parser = argparse.ArgumentParser(description='dump eval data.')
+parser.add_argument('--config_path', type=str, default='configs/eval/scannet_eval_sgm.yaml')
+parser.add_argument('--num_process_match', type=int, default=4)
+parser.add_argument('--num_process_eval', type=int, default=4)
+parser.add_argument('--vis_folder',type=str,default=None)
+args=parser.parse_args()
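+
+# Evaluation runs as a three-stage pipeline: a reader process pushes pairs into
+# read_que, a matcher process batches them through a worker pool into match_que,
+# and an evaluator process aggregates the metrics and prints the summary.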
+
+def feed_match(info,matcher):
+ x1,x2,desc1,desc2,size1,size2=info['x1'],info['x2'],info['desc1'],info['desc2'],info['img1'].shape[:2],info['img2'].shape[:2]
+ test_data = {'x1': x1,'x2': x2,'desc1': desc1,'desc2': desc2,'size1':np.flip(np.asarray(size1)),'size2':np.flip(np.asarray(size2)) }
+ corr1,corr2=matcher.run(test_data)
+ return [corr1,corr2]
+
+
+def reader_handler(config,read_que):
+ reader=load_component('reader',config['name'],config)
+ for index in range(len(reader)):
+ index+=0
+ info=reader.run(index)
+ read_que.put(info)
+ read_que.put('over')
+
+
+def match_handler(config,read_que,match_que):
+ matcher=load_component('matcher',config['name'],config)
+ match_func=functools.partial(feed_match,matcher=matcher)
+ pool = Pool(args.num_process_match)
+ cache=[]
+ while True:
+ item=read_que.get()
+ #clear cache
+ if item=='over':
+ if len(cache)!=0:
+ results=pool.map(match_func,cache)
+ for cur_item,cur_result in zip(cache,results):
+ cur_item['corr1'],cur_item['corr2']=cur_result[0],cur_result[1]
+ match_que.put(cur_item)
+ match_que.put('over')
+ break
+ cache.append(item)
+ #print(len(cache))
+ if len(cache)==args.num_process_match:
+ #matching in parallel
+ results=pool.map(match_func,cache)
+ for cur_item,cur_result in zip(cache,results):
+ cur_item['corr1'],cur_item['corr2']=cur_result[0],cur_result[1]
+ match_que.put(cur_item)
+ cache=[]
+ pool.close()
+ pool.join()
+
+
+def evaluate_handler(config,match_que):
+ evaluator=load_component('evaluator',config['name'],config)
+ pool = Pool(args.num_process_eval)
+ cache=[]
+ for _ in trange(config['num_pair']):
+ item=match_que.get()
+ if item=='over':
+ if len(cache)!=0:
+ results=pool.map(evaluator.run,cache)
+ for cur_res in results:
+ evaluator.res_inqueue(cur_res)
+ break
+ cache.append(item)
+ if len(cache)==args.num_process_eval:
+ results=pool.map(evaluator.run,cache)
+ for cur_res in results:
+ evaluator.res_inqueue(cur_res)
+ cache=[]
+ if args.vis_folder is not None:
+ #dump visualization
+ corr1_norm,corr2_norm=evaluation_utils.normalize_intrinsic(item['corr1'],item['K1']),\
+ evaluation_utils.normalize_intrinsic(item['corr2'],item['K2'])
+ inlier_mask=metrics.compute_epi_inlier(corr1_norm,corr2_norm,item['e'],config['inlier_th'])
+ display=evaluation_utils.draw_match(item['img1'],item['img2'],item['corr1'],item['corr2'],inlier_mask)
+ cv2.imwrite(os.path.join(args.vis_folder,str(item['index'])+'.png'),display)
+ evaluator.parse()
+
+
+if __name__=='__main__':
+ with open(args.config_path, 'r') as f:
+ config = yaml.load(f, Loader=yaml.FullLoader)
+ if args.vis_folder is not None and not os.path.exists(args.vis_folder):
+ os.mkdir(args.vis_folder)
+
+ read_que,match_que,estimate_que=Manager().Queue(maxsize=100),Manager().Queue(maxsize=100),Manager().Queue(maxsize=100)
+
+ read_process=Process(target=reader_handler,args=(config['reader'],read_que))
+ match_process=Process(target=match_handler,args=(config['matcher'],read_que,match_que))
+ evaluate_process=Process(target=evaluate_handler,args=(config['evaluator'],match_que))
+
+ read_process.start()
+ match_process.start()
+ evaluate_process.start()
+
+ read_process.join()
+ match_process.join()
+ evaluate_process.join()
\ No newline at end of file
diff --git a/imcui/third_party/SGMNet/sgmnet/__init__.py b/imcui/third_party/SGMNet/sgmnet/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..828543beceebb10d05fd9d5fdfcc4b1c91e5af6b
--- /dev/null
+++ b/imcui/third_party/SGMNet/sgmnet/__init__.py
@@ -0,0 +1 @@
+from .match_model import matcher
\ No newline at end of file
diff --git a/imcui/third_party/SGMNet/sgmnet/match_model.py b/imcui/third_party/SGMNet/sgmnet/match_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..8760815cba9e34749b748cdb485bdc73b1cc9edb
--- /dev/null
+++ b/imcui/third_party/SGMNet/sgmnet/match_model.py
@@ -0,0 +1,223 @@
+import torch
+import torch.nn as nn
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+eps=1e-8
+
+def sinkhorn(M,r,c,iteration):
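+ # Sinkhorn normalization: alternately rescale rows and columns of the
+ # softmaxed score matrix so its marginals match r and c.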
+ p = torch.softmax(M, dim=-1)
+ u = torch.ones_like(r)
+ v = torch.ones_like(c)
+ for _ in range(iteration):
+ u = r / ((p * v.unsqueeze(-2)).sum(-1) + eps)
+ v = c / ((p * u.unsqueeze(-1)).sum(-2) + eps)
+ p = p * u.unsqueeze(-1) * v.unsqueeze(-2)
+ return p
+
+def sink_algorithm(M,dustbin,iteration):
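+ # Augment the score matrix with a learnable dustbin row/column (for unmatched
+ # keypoints), then run Sinkhorn with the corresponding marginals.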
+ M = torch.cat([M, dustbin.expand([M.shape[0], M.shape[1], 1])], dim=-1)
+ M = torch.cat([M, dustbin.expand([M.shape[0], 1, M.shape[2]])], dim=-2)
+ r = torch.ones([M.shape[0], M.shape[1] - 1],device=device)
+ r = torch.cat([r, torch.ones([M.shape[0], 1],device=device) * M.shape[1]], dim=-1)
+ c = torch.ones([M.shape[0], M.shape[2] - 1],device=device)
+ c = torch.cat([c, torch.ones([M.shape[0], 1],device=device) * M.shape[2]], dim=-1)
+ p=sinkhorn(M,r,c,iteration)
+ return p
+
+
+def seeding(nn_index1,nn_index2,x1,x2,topk,match_score,confbar,nms_radius,use_mc=True,test=False):
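+ # Select confident, well-spread seed correspondences: optional mutual-NN check,
+ # non-maximum suppression within a relative radius, a confidence bar on the
+ # match score, and finally a top-k selection.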
+
+ #apply mutual check before nms
+ if use_mc:
+ mask_not_mutual=nn_index2.gather(dim=-1,index=nn_index1)!=torch.arange(nn_index1.shape[1],device=device)
+ match_score[mask_not_mutual]=-1
+ #NMS
+ pos_dismat1=((x1.norm(p=2,dim=-1)**2).unsqueeze_(-1)+(x1.norm(p=2,dim=-1)**2).unsqueeze_(-2)-2*(x1@x1.transpose(1,2))).abs_().sqrt_()
+ x2=x2.gather(index=nn_index1.unsqueeze(-1).expand(-1,-1,2),dim=1)
+ pos_dismat2=((x2.norm(p=2,dim=-1)**2).unsqueeze_(-1)+(x2.norm(p=2,dim=-1)**2).unsqueeze_(-2)-2*(x2@x2.transpose(1,2))).abs_().sqrt_()
+ radius1, radius2 = nms_radius * pos_dismat1.mean(dim=(1,2),keepdim=True), nms_radius * pos_dismat2.mean(dim=(1,2),keepdim=True)
+ nms_mask = (pos_dismat1 >= radius1) & (pos_dismat2 >= radius2)
+ mask_not_local_max=(match_score.unsqueeze(-1)>=match_score.unsqueeze(-2))|nms_mask
+ mask_not_local_max=~(mask_not_local_max.min(dim=-1).values)
+ match_score[mask_not_local_max] = -1
+
+ #confidence bar
+ match_score[match_score<confbar]=-1
+ mask_survive=match_score>0
+ if test:
+ topk=min(mask_survive.sum(dim=1)[0]+2,topk)
+ _,topindex = torch.topk(match_score,topk,dim=-1)#b*k
+ seed_index1,seed_index2=topindex,nn_index1.gather(index=topindex,dim=-1)
+ return seed_index1,seed_index2
+
+
+
+class PointCN(nn.Module):
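+ # Residual 1x1-conv block with instance/batch normalization and a shortcut
+ # projection, applied pointwise to per-keypoint features.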
+ def __init__(self, channels,out_channels):
+ nn.Module.__init__(self)
+ self.shot_cut = nn.Conv1d(channels, out_channels, kernel_size=1)
+ self.conv = nn.Sequential(
+ nn.InstanceNorm1d(channels, eps=1e-3),
+ nn.SyncBatchNorm(channels),
+ nn.ReLU(),
+ nn.Conv1d(channels, channels, kernel_size=1),
+ nn.InstanceNorm1d(channels, eps=1e-3),
+ nn.SyncBatchNorm(channels),
+ nn.ReLU(),
+ nn.Conv1d(channels, out_channels, kernel_size=1)
+ )
+
+ def forward(self, x):
+ return self.conv(x) + self.shot_cut(x)
+
+
+class attention_propagantion(nn.Module):
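+ # Multi-head attention where desc1 provides the queries and desc2 the keys and
+ # values; an optional weight_v re-weights the values (used for seed confidence).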
+
+ def __init__(self,channel,head):
+ nn.Module.__init__(self)
+ self.head=head
+ self.head_dim=channel//head
+ self.query_filter,self.key_filter,self.value_filter=nn.Conv1d(channel,channel,kernel_size=1),nn.Conv1d(channel,channel,kernel_size=1),\
+ nn.Conv1d(channel,channel,kernel_size=1)
+ self.mh_filter=nn.Conv1d(channel,channel,kernel_size=1)
+ self.cat_filter=nn.Sequential(nn.Conv1d(2*channel,2*channel, kernel_size=1), nn.SyncBatchNorm(2*channel), nn.ReLU(),
+ nn.Conv1d(2*channel, channel, kernel_size=1))
+
+ def forward(self,desc1,desc2,weight_v=None):
+ #desc1(q) attend to desc2(k,v)
+ batch_size=desc1.shape[0]
+ query,key,value=self.query_filter(desc1).view(batch_size,self.head,self.head_dim,-1),self.key_filter(desc2).view(batch_size,self.head,self.head_dim,-1),\
+ self.value_filter(desc2).view(batch_size,self.head,self.head_dim,-1)
+ if weight_v is not None:
+ value=value*weight_v.view(batch_size,1,1,-1)
+ score=torch.softmax(torch.einsum('bhdn,bhdm->bhnm',query,key)/ self.head_dim ** 0.5,dim=-1)
+ add_value=torch.einsum('bhnm,bhdm->bhdn',score,value).reshape(batch_size,self.head_dim*self.head,-1)
+ add_value=self.mh_filter(add_value)
+ desc1_new=desc1+self.cat_filter(torch.cat([desc1,add_value],dim=1))
+ return desc1_new
+
+
+class hybrid_block(nn.Module):
+ def __init__(self,channel,head):
+ nn.Module.__init__(self)
+ self.head=head
+ self.channel=channel
+ self.attention_block_down = attention_propagantion(channel, head)
+ self.cluster_filter=nn.Sequential(nn.Conv1d(2*channel,2*channel, kernel_size=1), nn.SyncBatchNorm(2*channel), nn.ReLU(),
+ nn.Conv1d(2*channel, 2*channel, kernel_size=1))
+ self.cross_filter=attention_propagantion(channel,head)
+ self.confidence_filter=PointCN(2*channel,1)
+ self.attention_block_self=attention_propagantion(channel,head)
+ self.attention_block_up=attention_propagantion(channel,head)
+
+ def forward(self,desc1,desc2,seed_index1,seed_index2):
+ cluster1, cluster2 = desc1.gather(dim=-1, index=seed_index1.unsqueeze(1).expand(-1, self.channel, -1)), \
+ desc2.gather(dim=-1, index=seed_index2.unsqueeze(1).expand(-1, self.channel, -1))
+
+ #pooling
+ cluster1, cluster2 = self.attention_block_down(cluster1, desc1), self.attention_block_down(cluster2, desc2)
+ concate_cluster=self.cluster_filter(torch.cat([cluster1,cluster2],dim=1))
+ #filtering
+ cluster1,cluster2=self.cross_filter(concate_cluster[:,:self.channel],concate_cluster[:,self.channel:]),\
+ self.cross_filter(concate_cluster[:,self.channel:],concate_cluster[:,:self.channel])
+ cluster1,cluster2=self.attention_block_self(cluster1,cluster1),self.attention_block_self(cluster2,cluster2)
+ #unpooling
+ seed_weight=self.confidence_filter(torch.cat([cluster1,cluster2],dim=1))
+ seed_weight=torch.sigmoid(seed_weight).squeeze(1)
+ desc1_new,desc2_new=self.attention_block_up(desc1,cluster1,seed_weight),self.attention_block_up(desc2,cluster2,seed_weight)
+ return desc1_new,desc2_new,seed_weight
+
+
+
+class matcher(nn.Module):
+ def __init__(self,config):
+ nn.Module.__init__(self)
+ self.seed_top_k=config.seed_top_k
+ self.conf_bar=config.conf_bar
+ self.seed_radius_coe=config.seed_radius_coe
+ self.use_score_encoding=config.use_score_encoding
+ self.detach_iter=config.detach_iter
+ self.seedlayer=config.seedlayer
+ self.layer_num=config.layer_num
+ self.sink_iter=config.sink_iter
+
+ self.position_encoder = nn.Sequential(nn.Conv1d(3, 32, kernel_size=1) if config.use_score_encoding else nn.Conv1d(2, 32, kernel_size=1),
+ nn.SyncBatchNorm(32),nn.ReLU(),
+ nn.Conv1d(32, 64, kernel_size=1), nn.SyncBatchNorm(64),nn.ReLU(),
+ nn.Conv1d(64, 128, kernel_size=1), nn.SyncBatchNorm(128),nn.ReLU(),
+ nn.Conv1d(128, 256, kernel_size=1), nn.SyncBatchNorm(256),nn.ReLU(),
+ nn.Conv1d(256, config.net_channels, kernel_size=1))
+
+
+ self.hybrid_block=nn.Sequential(*[hybrid_block(config.net_channels, config.head) for _ in range(config.layer_num)])
+ self.final_project = nn.Conv1d(config.net_channels, config.net_channels, kernel_size=1)
+ self.dustbin=nn.Parameter(torch.tensor(1.5,dtype=torch.float32))
+
+ #if reseeding
+ if len(config.seedlayer)!=1:
+ self.mid_dustbin=nn.ParameterDict({str(i):nn.Parameter(torch.tensor(2,dtype=torch.float32)) for i in config.seedlayer[1:]})
+ self.mid_final_project = nn.Conv1d(config.net_channels, config.net_channels, kernel_size=1)
+
+ def forward(self,data,test_mode=True):
+ x1, x2, desc1, desc2 = data['x1'][:,:,:2], data['x2'][:,:,:2], data['desc1'], data['desc2']
+ desc1, desc2 = torch.nn.functional.normalize(desc1,dim=-1), torch.nn.functional.normalize(desc2,dim=-1)
+ if test_mode:
+ encode_x1,encode_x2=data['x1'],data['x2']
+ else:
+ encode_x1,encode_x2=data['aug_x1'], data['aug_x2']
+
+ #preparation
+ desc_dismat=(2-2*torch.matmul(desc1,desc2.transpose(1,2))).sqrt_()
+ values,nn_index=torch.topk(desc_dismat,k=2,largest=False,dim=-1,sorted=True)
+ nn_index2=torch.min(desc_dismat,dim=1).indices.squeeze(1)
+ inverse_ratio_score,nn_index1=values[:,:,1]/values[:,:,0],nn_index[:,:,0]#get inverse score
+
+ #initial seeding
+ seed_index1,seed_index2=seeding(nn_index1,nn_index2,x1,x2,self.seed_top_k[0],inverse_ratio_score,self.conf_bar[0],\
+ self.seed_radius_coe,test=test_mode)
+
+ #position encoding
+ desc1,desc2=desc1.transpose(1,2),desc2.transpose(1,2)
+ if not self.use_score_encoding:
+ encode_x1,encode_x2=encode_x1[:,:,:2],encode_x2[:,:,:2]
+ encode_x1,encode_x2=encode_x1.transpose(1,2),encode_x2.transpose(1,2)
+ x1_pos_embedding, x2_pos_embedding = self.position_encoder(encode_x1), self.position_encoder(encode_x2)
+ aug_desc1, aug_desc2 = x1_pos_embedding + desc1, x2_pos_embedding + desc2
+
+ seed_weight_tower,mid_p_tower,seed_index_tower,nn_index_tower=[],[],[],[]
+ seed_index_tower.append(torch.stack([seed_index1, seed_index2],dim=-1))
+ nn_index_tower.append(nn_index1)
+
+ seed_para_index=0
+ for i in range(self.layer_num):
+ #mid seeding
+ if i in self.seedlayer and i!= 0:
+ seed_para_index+=1
+ aug_desc1,aug_desc2=self.mid_final_project(aug_desc1),self.mid_final_project(aug_desc2)
+ M=torch.matmul(aug_desc1.transpose(1,2),aug_desc2)
+ p=sink_algorithm(M,self.mid_dustbin[str(i)],self.sink_iter[seed_para_index-1])
+ mid_p_tower.append(p)
+ #rematching with p
+ values,nn_index=torch.topk(p[:,:-1,:-1],k=1,dim=-1)
+ nn_index2=torch.max(p[:,:-1,:-1],dim=1).indices.squeeze(1)
+ p_match_score,nn_index1=values[:,:,0],nn_index[:,:,0]
+ #reseeding
+ seed_index1, seed_index2 = seeding(nn_index1,nn_index2,x1,x2,self.seed_top_k[seed_para_index],p_match_score,\
+ self.conf_bar[seed_para_index],self.seed_radius_coe,test=test_mode)
+ seed_index_tower.append(torch.stack([seed_index1, seed_index2],dim=-1)), nn_index_tower.append(nn_index1)
+ if not test_mode and data['step']<self.detach_iter:
+ p=p.detach()
+ score1,score2=torch.softmax(torch.einsum('bdhn,bdhm->bhnm',query1,key1)/self.head_dim**0.5,dim=-1),\
+ torch.softmax(torch.einsum('bdhn,bdhm->bhnm',query2,key2)/self.head_dim**0.5,dim=-1)
+ add_value1, add_value2 = torch.einsum('bhnm,bdhm->bdhn', score1, value1), torch.einsum('bhnm,bdhm->bdhn',score2, value2)
+ else:
+ score1,score2 = torch.softmax(torch.einsum('bdhn,bdhm->bhnm', query1, key2) / self.head_dim ** 0.5,dim=-1), \
+ torch.softmax(torch.einsum('bdhn,bdhm->bhnm', query2, key1) / self.head_dim ** 0.5, dim=-1)
+ add_value1, add_value2 =torch.einsum('bhnm,bdhm->bdhn',score1,value2),torch.einsum('bhnm,bdhm->bdhn',score2,value1)
+ add_value1,add_value2=self.mh_filter(add_value1.contiguous().view(batch_size,self.head*self.head_dim,n)),self.mh_filter(add_value2.contiguous().view(batch_size,self.head*self.head_dim,m))
+ fea11, fea22 = torch.cat([fea1, add_value1], dim=1), torch.cat([fea2, add_value2], dim=1)
+ fea1, fea2 = fea1+self.attention_filter(fea11), fea2+self.attention_filter(fea22)
+
+ return fea1,fea2
+
+
+class matcher(nn.Module):
+ def __init__(self, config):
+ nn.Module.__init__(self)
+ self.use_score_encoding=config.use_score_encoding
+ self.layer_num=config.layer_num
+ self.sink_iter=config.sink_iter
+ self.position_encoder = nn.Sequential(nn.Conv1d(3, 32, kernel_size=1) if config.use_score_encoding else nn.Conv1d(2, 32, kernel_size=1),
+ nn.SyncBatchNorm(32), nn.ReLU(),
+ nn.Conv1d(32, 64, kernel_size=1), nn.SyncBatchNorm(64),nn.ReLU(),
+ nn.Conv1d(64, 128, kernel_size=1), nn.SyncBatchNorm(128), nn.ReLU(),
+ nn.Conv1d(128, 256, kernel_size=1), nn.SyncBatchNorm(256), nn.ReLU(),
+ nn.Conv1d(256, config.net_channels, kernel_size=1))
+
+ self.dustbin=nn.Parameter(torch.tensor(1,dtype=torch.float32,device='cuda'))
+ self.self_attention_block=nn.Sequential(*[attention_block(config.net_channels,config.head,'self') for _ in range(config.layer_num)])
+ self.cross_attention_block=nn.Sequential(*[attention_block(config.net_channels,config.head,'cross') for _ in range(config.layer_num)])
+ self.final_project=nn.Conv1d(config.net_channels, config.net_channels, kernel_size=1)
+
+ def forward(self,data,test_mode=True):
+ desc1, desc2 = data['desc1'], data['desc2']
+ desc1, desc2 = torch.nn.functional.normalize(desc1,dim=-1), torch.nn.functional.normalize(desc2,dim=-1)
+ desc1,desc2=desc1.transpose(1,2),desc2.transpose(1,2)
+ if test_mode:
+ encode_x1,encode_x2=data['x1'],data['x2']
+ else:
+ encode_x1,encode_x2=data['aug_x1'], data['aug_x2']
+ if not self.use_score_encoding:
+ encode_x1,encode_x2=encode_x1[:,:,:2],encode_x2[:,:,:2]
+
+ encode_x1,encode_x2=encode_x1.transpose(1,2),encode_x2.transpose(1,2)
+
+ x1_pos_embedding, x2_pos_embedding = self.position_encoder(encode_x1), self.position_encoder(encode_x2)
+ aug_desc1, aug_desc2 = x1_pos_embedding + desc1, x2_pos_embedding+desc2
+ for i in range(self.layer_num):
+ aug_desc1,aug_desc2=self.self_attention_block[i](aug_desc1,aug_desc2)
+ aug_desc1,aug_desc2=self.cross_attention_block[i](aug_desc1,aug_desc2)
+
+ aug_desc1,aug_desc2=self.final_project(aug_desc1),self.final_project(aug_desc2)
+ desc_mat = torch.matmul(aug_desc1.transpose(1, 2), aug_desc2)
+ p = sink_algorithm(desc_mat, self.dustbin,self.sink_iter[0])
+ return {'p':p}
+
+
diff --git a/imcui/third_party/SGMNet/superpoint/__init__.py b/imcui/third_party/SGMNet/superpoint/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..111c8882a7bc7512c6191ca86a0e71c3b1404233
--- /dev/null
+++ b/imcui/third_party/SGMNet/superpoint/__init__.py
@@ -0,0 +1 @@
+from .superpoint import SuperPoint
\ No newline at end of file
diff --git a/imcui/third_party/SGMNet/superpoint/superpoint.py b/imcui/third_party/SGMNet/superpoint/superpoint.py
new file mode 100644
index 0000000000000000000000000000000000000000..d4e3ce481409264a3188270ad01aa62b1614377f
--- /dev/null
+++ b/imcui/third_party/SGMNet/superpoint/superpoint.py
@@ -0,0 +1,140 @@
+import torch
+from torch import nn
+
+
+def simple_nms(scores, nms_radius):
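+ # Iterative max-pool based non-maximum suppression over the score map.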
+ assert(nms_radius >= 0)
+
+ def max_pool(x):
+ return torch.nn.functional.max_pool2d(
+ x, kernel_size=nms_radius*2+1, stride=1, padding=nms_radius)
+
+ zeros = torch.zeros_like(scores)
+ max_mask = scores == max_pool(scores)
+ for _ in range(2):
+ supp_mask = max_pool(max_mask.float()) > 0
+ supp_scores = torch.where(supp_mask, zeros, scores)
+ new_max_mask = supp_scores == max_pool(supp_scores)
+ max_mask = max_mask | (new_max_mask & (~supp_mask))
+ return torch.where(max_mask, scores, zeros)
+
+
+def remove_borders(keypoints, scores, b, h, w):
+ mask_h = (keypoints[:, 0] >= b) & (keypoints[:, 0] < (h - b))
+ mask_w = (keypoints[:, 1] >= b) & (keypoints[:, 1] < (w - b))
+ mask = mask_h & mask_w
+ return keypoints[mask], scores[mask]
+
+
+def top_k_keypoints(keypoints, scores, k):
+ if k >= len(keypoints):
+ return keypoints, scores
+ scores, indices = torch.topk(scores, k, dim=0)
+ return keypoints[indices], scores
+
+
+def sample_descriptors(keypoints, descriptors, s):
+ b, c, h, w = descriptors.shape
+ keypoints = keypoints - s / 2 + 0.5
+ keypoints /= torch.tensor([(w*s - s/2 - 0.5), (h*s - s/2 - 0.5)],
+ ).to(keypoints)[None]
+ keypoints = keypoints*2 - 1 # normalize to (-1, 1)
+ args = {'align_corners': True} if int(torch.__version__[2]) > 2 else {}
+ descriptors = torch.nn.functional.grid_sample(
+ descriptors, keypoints.view(b, 1, -1, 2), mode='bilinear', **args)
+ descriptors = torch.nn.functional.normalize(
+ descriptors.reshape(b, c, -1), p=2, dim=1)
+ return descriptors
+
+
+class SuperPoint(nn.Module):
+
+ def __init__(self, config):
+ super().__init__()
+ self.config = {**config}
+
+ self.relu = nn.ReLU(inplace=True)
+ self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
+ c1, c2, c3, c4, c5 = 64, 64, 128, 128, 256
+
+ self.conv1a = nn.Conv2d(1, c1, kernel_size=3, stride=1, padding=1)
+ self.conv1b = nn.Conv2d(c1, c1, kernel_size=3, stride=1, padding=1)
+ self.conv2a = nn.Conv2d(c1, c2, kernel_size=3, stride=1, padding=1)
+ self.conv2b = nn.Conv2d(c2, c2, kernel_size=3, stride=1, padding=1)
+ self.conv3a = nn.Conv2d(c2, c3, kernel_size=3, stride=1, padding=1)
+ self.conv3b = nn.Conv2d(c3, c3, kernel_size=3, stride=1, padding=1)
+ self.conv4a = nn.Conv2d(c3, c4, kernel_size=3, stride=1, padding=1)
+ self.conv4b = nn.Conv2d(c4, c4, kernel_size=3, stride=1, padding=1)
+
+ self.convPa = nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1)
+ self.convPb = nn.Conv2d(c5, 65, kernel_size=1, stride=1, padding=0)
+
+ self.convDa = nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1)
+ self.convDb = nn.Conv2d(
+ c5, self.config['descriptor_dim'],
+ kernel_size=1, stride=1, padding=0)
+
+ self.load_state_dict(torch.load(config['model_path']))
+
+ mk = self.config['max_keypoints']
+ if mk == 0 or mk < -1:
+ raise ValueError('\"max_keypoints\" must be positive or \"-1\"')
+
+ print('Loaded SuperPoint model')
+
+ def forward(self, data):
+ # Shared Encoder
+ x = self.relu(self.conv1a(data))
+ x = self.relu(self.conv1b(x))
+ x = self.pool(x)
+ x = self.relu(self.conv2a(x))
+ x = self.relu(self.conv2b(x))
+ x = self.pool(x)
+ x = self.relu(self.conv3a(x))
+ x = self.relu(self.conv3b(x))
+ x = self.pool(x)
+ x = self.relu(self.conv4a(x))
+ x = self.relu(self.conv4b(x))
+ # Compute the dense keypoint scores
+ cPa = self.relu(self.convPa(x))
+ scores = self.convPb(cPa)
+ scores = torch.nn.functional.softmax(scores, 1)[:, :-1]
+ b, c, h, w = scores.shape
+ scores = scores.permute(0, 2, 3, 1).reshape(b, h, w, 8, 8)
+ scores = scores.permute(0, 1, 3, 2, 4).reshape(b, h*8, w*8)
+ scores = simple_nms(scores, self.config['nms_radius'])
+
+ # Extract keypoints
+ keypoints = [
+ torch.nonzero(s > self.config['detection_threshold'])
+ for s in scores]
+ scores = [s[tuple(k.t())] for s, k in zip(scores, keypoints)]
+
+ # Discard keypoints near the image borders
+ keypoints, scores = list(zip(*[
+ remove_borders(k, s, self.config['remove_borders'], h*8, w*8)
+ for k, s in zip(keypoints, scores)]))
+
+ # Keep the k keypoints with highest score
+ if self.config['max_keypoints'] >= 0:
+ keypoints, scores = list(zip(*[
+ top_k_keypoints(k, s, self.config['max_keypoints'])
+ for k, s in zip(keypoints, scores)]))
+
+ # Convert (h, w) to (x, y)
+ keypoints = [torch.flip(k, [1]).float() for k in keypoints]
+
+ # Compute the dense descriptors
+ cDa = self.relu(self.convDa(x))
+ descriptors = self.convDb(cDa)
+ descriptors = torch.nn.functional.normalize(descriptors, p=2, dim=1)
+
+ # Extract descriptors
+ descriptors = [sample_descriptors(k[None], d[None], 8)[0]
+ for k, d in zip(keypoints, descriptors)]
+
+ return {
+ 'keypoints': keypoints,
+ 'scores': scores,
+ 'descriptors': descriptors,
+ }
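A minimal usage sketch for the SuperPoint module added above (annotation, not part of the patch): the config keys mirror the ones read in __init__ and forward, while the weight path, the image tensor, and the chosen values are placeholder assumptions.

    import torch
    from superpoint import SuperPoint

    config = {
        'descriptor_dim': 256,              # output channels of convDb
        'nms_radius': 4,
        'detection_threshold': 0.005,
        'remove_borders': 4,
        'max_keypoints': 1000,              # -1 keeps all keypoints
        'model_path': 'superpoint_v1.pth',  # hypothetical path to pretrained weights
    }
    model = SuperPoint(config).eval()
    image = torch.rand(1, 1, 480, 640)      # grayscale batch in [0, 1]; this variant takes the tensor directly
    with torch.no_grad():
        out = model(image)
    print(out['keypoints'][0].shape, out['descriptors'][0].shape)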
diff --git a/imcui/third_party/SGMNet/train/config.py b/imcui/third_party/SGMNet/train/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..31c4c1c6deef3d6dd568897f4202d96456586376
--- /dev/null
+++ b/imcui/third_party/SGMNet/train/config.py
@@ -0,0 +1,126 @@
+import argparse
+
+def str2bool(v):
+ return v.lower() in ("true", "1")
+
+
+arg_lists = []
+parser = argparse.ArgumentParser()
+
+
+def add_argument_group(name):
+ arg = parser.add_argument_group(name)
+ arg_lists.append(arg)
+ return arg
+
+
+# -----------------------------------------------------------------------------
+# Network
+net_arg = add_argument_group("Network")
+net_arg.add_argument(
+ "--model_name", type=str,default='SGM', help=""
+ "model for training")
+net_arg.add_argument(
+ "--config_path", type=str,default='configs/sgm.yaml', help=""
+ "config path for model")
+
+# -----------------------------------------------------------------------------
+# Data
+data_arg = add_argument_group("Data")
+data_arg.add_argument(
+ "--rawdata_path", type=str, default='rawdata', help=""
+ "path for rawdata")
+data_arg.add_argument(
+ "--dataset_path", type=str, default='dataset', help=""
+ "path for dataset")
+data_arg.add_argument(
+ "--desc_path", type=str, default='desc', help=""
+ "path for descriptor(kpt) dir")
+data_arg.add_argument(
+ "--num_kpt", type=int, default=1000, help=""
+ "number of kpt for training")
+data_arg.add_argument(
+ "--input_normalize", type=str, default='img', help=""
+ "normalize type for input kpt, img or intrinsic")
+data_arg.add_argument(
+ "--data_aug", type=str2bool, default=True, help=""
+ "apply kpt coordinate homography augmentation")
+data_arg.add_argument(
+ "--desc_suffix", type=str, default='suffix', help=""
+ "desc file suffix")
+
+
+# -----------------------------------------------------------------------------
+# Loss
+loss_arg = add_argument_group("loss")
+loss_arg.add_argument(
+ "--momentum", type=float, default=0.9, help=""
+ "momentum")
+loss_arg.add_argument(
+ "--seed_loss_weight", type=float, default=250, help=""
+ "confidence loss weight for sgm")
+loss_arg.add_argument(
+ "--mid_loss_weight", type=float, default=1, help=""
+ "midseeding loss weight for sgm")
+loss_arg.add_argument(
+ "--inlier_th", type=float, default=5e-3, help=""
+ "inlier threshold for epipolar distance (for sgm and visualization)")
+
+
+# -----------------------------------------------------------------------------
+# Training
+train_arg = add_argument_group("Train")
+train_arg.add_argument(
+ "--train_lr", type=float, default=1e-4, help=""
+ "learning rate")
+train_arg.add_argument(
+ "--train_batch_size", type=int, default=16, help=""
+ "batch size")
+train_arg.add_argument(
+ "--gpu_id", type=str,default='0', help='id(s) for CUDA_VISIBLE_DEVICES')
+train_arg.add_argument(
+ "--train_iter", type=int, default=1000000, help=""
+ "training iterations to perform")
+train_arg.add_argument(
+ "--log_base", type=str, default="./log/", help=""
+ "log path")
+train_arg.add_argument(
+ "--val_intv", type=int, default=20000, help=""
+ "validation interval")
+train_arg.add_argument(
+ "--save_intv", type=int, default=1000, help=""
+ "summary interval")
+train_arg.add_argument(
+ "--log_intv", type=int, default=100, help=""
+ "log interval")
+train_arg.add_argument(
+ "--decay_rate", type=float, default=0.999996, help=""
+ "lr decay rate")
+train_arg.add_argument(
+ "--decay_iter", type=float, default=300000, help=""
+ "lr decay iter")
+train_arg.add_argument(
+ "--local_rank", type=int, default=0, help=""
+ "local rank for ddp")
+train_arg.add_argument(
+ "--train_vis_folder", type=str, default='.', help=""
+ "visualization folder during training")
+
+# -----------------------------------------------------------------------------
+# Visualization
+vis_arg = add_argument_group('Visualization')
+vis_arg.add_argument(
+ "--tqdm_width", type=int, default=79, help=""
+ "width of the tqdm bar"
+)
+
+def get_config():
+ config, unparsed = parser.parse_known_args()
+ return config, unparsed
+
+
+def print_usage():
+ parser.print_usage()
+
+#
+# config.py ends here
\ No newline at end of file
diff --git a/imcui/third_party/SGMNet/train/configs/sg.yaml b/imcui/third_party/SGMNet/train/configs/sg.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..bb03f39f9d8445b1e345d8f8f6ac17eb6d981bc1
--- /dev/null
+++ b/imcui/third_party/SGMNet/train/configs/sg.yaml
@@ -0,0 +1,5 @@
+net_channels: 128
+layer_num: 9
+head: 4
+use_score_encoding: True
+p_th: 0.2
\ No newline at end of file
diff --git a/imcui/third_party/SGMNet/train/configs/sgm.yaml b/imcui/third_party/SGMNet/train/configs/sgm.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..d674adf562a8932192a0a3bb1a993cf90d28e989
--- /dev/null
+++ b/imcui/third_party/SGMNet/train/configs/sgm.yaml
@@ -0,0 +1,12 @@
+seed_top_k: [128,128]
+seed_radius_coe: 0.01
+net_channels: 128
+layer_num: 9
+head: 4
+seedlayer: [0,6]
+use_mc_seeding: True
+use_score_encoding: False
+conf_bar: [1,0.1]
+sink_iter: [10,100]
+detach_iter: 140000
+p_th: 0.2
\ No newline at end of file
diff --git a/imcui/third_party/SGMNet/train/dataset.py b/imcui/third_party/SGMNet/train/dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..d07a84e9588b755a86119363f08860187d1668c0
--- /dev/null
+++ b/imcui/third_party/SGMNet/train/dataset.py
@@ -0,0 +1,143 @@
+import numpy as np
+import torch
+import torch.utils.data as data
+import cv2
+import os
+import h5py
+import random
+
+import sys
+ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), "../"))
+sys.path.insert(0, ROOT_DIR)
+
+from utils import train_utils,evaluation_utils
+
+torch.multiprocessing.set_sharing_strategy('file_system')
+
+
+class Offline_Dataset(data.Dataset):
+ def __init__(self,config,mode):
+ assert mode=='train' or mode=='valid'
+
+ self.config = config
+ self.mode = mode
+ metadir=os.path.join(config.dataset_path,'valid') if mode=='valid' else os.path.join(config.dataset_path,'train')
+
+ pair_num_list=np.loadtxt(os.path.join(metadir,'pair_num.txt'),dtype=str)
+ self.total_pairs=int(pair_num_list[0,1])
+ self.pair_seq_list,self.accu_pair_num=train_utils.parse_pair_seq(pair_num_list)
+
+
+ def collate_fn(self, batch):
+ batch_size, num_pts = len(batch), batch[0]['x1'].shape[0]
+
+ data = {}
+ dtype=['x1','x2','kpt1','kpt2','desc1','desc2','num_corr','num_incorr1','num_incorr2','e_gt','pscore1','pscore2','img_path1','img_path2']
+ for key in dtype:
+ data[key]=[]
+ for sample in batch:
+ for key in dtype:
+ data[key].append(sample[key])
+
+ for key in ['x1', 'x2','kpt1','kpt2', 'desc1', 'desc2','e_gt','pscore1','pscore2']:
+ data[key] = torch.from_numpy(np.stack(data[key])).float()
+ for key in ['num_corr', 'num_incorr1', 'num_incorr2']:
+ data[key] = torch.from_numpy(np.stack(data[key])).int()
+
+ # kpt augmentation with random homography
+ if (self.mode == 'train' and self.config.data_aug):
+ homo_mat = torch.from_numpy(train_utils.get_rnd_homography(batch_size)).unsqueeze(1)
+ aug_seed=random.random()
+ if aug_seed<0.5:
+ x1_homo = torch.cat([data['x1'], torch.ones([batch_size, num_pts, 1])], dim=-1).unsqueeze(-1)
+ x1_homo = torch.matmul(homo_mat.float(), x1_homo.float()).squeeze(-1)
+ data['aug_x1'] = x1_homo[:, :, :2] / x1_homo[:, :, 2].unsqueeze(-1)
+ data['aug_x2']=data['x2']
+ else:
+ x2_homo = torch.cat([data['x2'], torch.ones([batch_size, num_pts, 1])], dim=-1).unsqueeze(-1)
+ x2_homo = torch.matmul(homo_mat.float(), x2_homo.float()).squeeze(-1)
+ data['aug_x2'] = x2_homo[:, :, :2] / x2_homo[:, :, 2].unsqueeze(-1)
+ data['aug_x1']=data['x1']
+ else:
+ data['aug_x1'],data['aug_x2']=data['x1'],data['x2']
+ return data
+
+
+ def __getitem__(self, index):
+ seq=self.pair_seq_list[index]
+ index_within_seq=index-self.accu_pair_num[seq]
+
+ with h5py.File(os.path.join(self.config.dataset_path,seq,'info.h5py'),'r') as data:
+ R,t = data['dR'][str(index_within_seq)][()], data['dt'][str(index_within_seq)][()]
+ egt = np.reshape(np.matmul(np.reshape(evaluation_utils.np_skew_symmetric(t.astype('float64').reshape(1, 3)), (3, 3)),np.reshape(R.astype('float64'), (3, 3))), (3, 3))
+ egt = egt / np.linalg.norm(egt)
+ K1, K2 = data['K1'][str(index_within_seq)][()],data['K2'][str(index_within_seq)][()]
+ size1,size2=data['size1'][str(index_within_seq)][()],data['size2'][str(index_within_seq)][()]
+
+ img_path1,img_path2=data['img_path1'][str(index_within_seq)][()][0].decode(),data['img_path2'][str(index_within_seq)][()][0].decode()
+ img_name1,img_name2=img_path1.split('/')[-1],img_path2.split('/')[-1]
+ img_path1,img_path2=os.path.join(self.config.rawdata_path,img_path1),os.path.join(self.config.rawdata_path,img_path2)
+ fea_path1,fea_path2=os.path.join(self.config.desc_path,seq,img_name1+self.config.desc_suffix),\
+ os.path.join(self.config.desc_path,seq,img_name2+self.config.desc_suffix)
+ with h5py.File(fea_path1,'r') as fea1, h5py.File(fea_path2,'r') as fea2:
+ desc1,kpt1,pscore1=fea1['descriptors'][()],fea1['keypoints'][()][:,:2],fea1['keypoints'][()][:,2]
+ desc2,kpt2,pscore2=fea2['descriptors'][()],fea2['keypoints'][()][:,:2],fea2['keypoints'][()][:,2]
+ kpt1,kpt2,desc1,desc2=kpt1[:self.config.num_kpt],kpt2[:self.config.num_kpt],desc1[:self.config.num_kpt],desc2[:self.config.num_kpt]
+
+ # normalize kpt
+ if self.config.input_normalize=='intrinsic':
+ x1, x2 = np.concatenate([kpt1, np.ones([kpt1.shape[0], 1])], axis=-1), np.concatenate(
+ [kpt2, np.ones([kpt2.shape[0], 1])], axis=-1)
+ x1, x2 = np.matmul(np.linalg.inv(K1), x1.T).T[:, :2], np.matmul(np.linalg.inv(K2), x2.T).T[:, :2]
+ elif self.config.input_normalize=='img' :
+ x1,x2=(kpt1-size1/2)/size1,(kpt2-size2/2)/size2
+ S1_inv,S2_inv=np.asarray([[size1[0],0,0.5*size1[0]],[0,size1[1],0.5*size1[1]],[0,0,1]]),\
+ np.asarray([[size2[0],0,0.5*size2[0]],[0,size2[1],0.5*size2[1]],[0,0,1]])
+ M1,M2=np.matmul(np.linalg.inv(K1),S1_inv),np.matmul(np.linalg.inv(K2),S2_inv)
+ egt=np.matmul(np.matmul(M2.transpose(),egt),M1)
+ egt = egt / np.linalg.norm(egt)
+ else:
+ raise NotImplementedError
+
+ corr=data['corr'][str(index_within_seq)][()]
+ incorr1,incorr2=data['incorr1'][str(index_within_seq)][()],data['incorr2'][str(index_within_seq)][()]
+
+ #permute kpt
+        valid_corr=corr[corr.max(axis=-1)<self.config.num_kpt]
+        valid_incorr1,valid_incorr2=incorr1[incorr1<self.config.num_kpt],incorr2[incorr2<self.config.num_kpt]
+        num_corr,num_incorr1,num_incorr2=len(valid_corr),len(valid_incorr1),len(valid_incorr2)
+        # pad each image with keypoints that carry no correspondence label
+        cur_kpt1,cur_kpt2=self.config.num_kpt-num_corr-num_incorr1,self.config.num_kpt-num_corr-num_incorr2
+        mask1,mask2=np.zeros(len(kpt1),dtype=bool),np.zeros(len(kpt2),dtype=bool)
+        mask1[valid_corr[:,0]],mask1[valid_incorr1]=True,True
+        mask2[valid_corr[:,1]],mask2[valid_incorr2]=True,True
+        invalid_index1,invalid_index2=np.nonzero(~mask1)[0],np.nonzero(~mask2)[0]
+        if (invalid_index1.shape[0] < cur_kpt1):
+            sub_idx1 = np.concatenate([np.arange(len(invalid_index1)),np.random.randint(len(invalid_index1),size=cur_kpt1-len(invalid_index1))])
+        if (invalid_index1.shape[0] >= cur_kpt1):
+ sub_idx1 =np.random.choice(len(invalid_index1), cur_kpt1,replace=False)
+ if (invalid_index2.shape[0] < cur_kpt2):
+ sub_idx2 = np.concatenate([np.arange(len(invalid_index2)),np.random.randint(len(invalid_index2),size=cur_kpt2-len(invalid_index2))])
+ if (invalid_index2.shape[0] >= cur_kpt2):
+ sub_idx2 = np.random.choice(len(invalid_index2), cur_kpt2,replace=False)
+
+ per_idx1,per_idx2=np.concatenate([valid_corr[:,0],valid_incorr1,invalid_index1[sub_idx1]]),\
+ np.concatenate([valid_corr[:,1],valid_incorr2,invalid_index2[sub_idx2]])
+
+ pscore1,pscore2=pscore1[per_idx1][:,np.newaxis],pscore2[per_idx2][:,np.newaxis]
+ x1,x2=x1[per_idx1][:,:2],x2[per_idx2][:,:2]
+ desc1,desc2=desc1[per_idx1],desc2[per_idx2]
+ kpt1,kpt2=kpt1[per_idx1],kpt2[per_idx2]
+
+ return {'x1': x1, 'x2': x2, 'kpt1':kpt1,'kpt2':kpt2,'desc1': desc1, 'desc2': desc2, 'num_corr': num_corr, 'num_incorr1': num_incorr1,'num_incorr2': num_incorr2,'e_gt':egt,\
+ 'pscore1':pscore1,'pscore2':pscore2,'img_path1':img_path1,'img_path2':img_path2}
+
+ def __len__(self):
+ return self.total_pairs
+
+
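As an aside (annotation, not part of the patch), the keypoint augmentation in collate_fn above reduces to a homogeneous warp: append a 1, left-multiply by a 3x3 homography, then dehomogenize. A small self-contained sketch, with an identity homography standing in for train_utils.get_rnd_homography:

    import torch

    B, N = 2, 5
    x1 = torch.rand(B, N, 2)                              # normalized keypoints, as in data['x1']
    homo_mat = torch.eye(3).expand(B, 3, 3).unsqueeze(1)  # stand-in for get_rnd_homography(B)
    x1_homo = torch.cat([x1, torch.ones(B, N, 1)], dim=-1).unsqueeze(-1)  # B x N x 3 x 1
    x1_homo = torch.matmul(homo_mat, x1_homo).squeeze(-1)                 # B x N x 3
    aug_x1 = x1_homo[:, :, :2] / x1_homo[:, :, 2].unsqueeze(-1)           # dehomogenize
    assert torch.allclose(aug_x1, x1)  # identity homography leaves keypoints unchanged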
diff --git a/imcui/third_party/SGMNet/train/loss.py b/imcui/third_party/SGMNet/train/loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..fad4234fc5827321c31e72c08ad4a3466bad1c30
--- /dev/null
+++ b/imcui/third_party/SGMNet/train/loss.py
@@ -0,0 +1,125 @@
+import torch
+import numpy as np
+
+
+def batch_episym(x1, x2, F):
+ batch_size, num_pts = x1.shape[0], x1.shape[1]
+ x1 = torch.cat([x1, x1.new_ones(batch_size, num_pts,1)], dim=-1).reshape(batch_size, num_pts,3,1)
+ x2 = torch.cat([x2, x2.new_ones(batch_size, num_pts,1)], dim=-1).reshape(batch_size, num_pts,3,1)
+ F = F.reshape(-1,1,3,3).repeat(1,num_pts,1,1)
+ x2Fx1 = torch.matmul(x2.transpose(2,3), torch.matmul(F, x1)).reshape(batch_size,num_pts)
+ Fx1 = torch.matmul(F,x1).reshape(batch_size,num_pts,3)
+ Ftx2 = torch.matmul(F.transpose(2,3),x2).reshape(batch_size,num_pts,3)
+ ys = (x2Fx1**2 * (
+ 1.0 / (Fx1[:, :, 0]**2 + Fx1[:, :, 1]**2 + 1e-15) +
+ 1.0 / (Ftx2[:, :, 0]**2 + Ftx2[:, :, 1]**2 + 1e-15))).sqrt()
+ return ys
+
+
+def CELoss(seed_x1,seed_x2,e,confidence,inlier_th,batch_mask=1):
+ #seed_x: b*k*2
+ ys=batch_episym(seed_x1,seed_x2,e)
+ mask_pos,mask_neg=(ys<=inlier_th).float(),(ys>inlier_th).float()
+ num_pos,num_neg=torch.relu(torch.sum(mask_pos, dim=1) - 1.0) + 1.0,torch.relu(torch.sum(mask_neg, dim=1) - 1.0) + 1.0
+ loss_pos,loss_neg=-torch.log(abs(confidence) + 1e-8)*mask_pos,-torch.log(abs(1-confidence)+1e-8)*mask_neg
+ classif_loss = torch.mean(loss_pos * 0.5 / num_pos.unsqueeze(-1) + loss_neg * 0.5 / num_neg.unsqueeze(-1),dim=-1)
+ classif_loss =classif_loss*batch_mask
+ classif_loss=classif_loss.mean()
+ precision = torch.mean(
+ torch.sum((confidence > 0.5).type(confidence.type()) * mask_pos, dim=1) /
+ (torch.sum((confidence > 0.5).type(confidence.type()), dim=1)+1e-8)
+ )
+ recall = torch.mean(
+ torch.sum((confidence > 0.5).type(confidence.type()) * mask_pos, dim=1) /
+ num_pos
+ )
+ return classif_loss,precision,recall
+
+
+def CorrLoss(desc_mat,batch_num_corr,batch_num_incorr1,batch_num_incorr2):
+ total_loss_corr,total_loss_incorr=0,0
+ total_acc_corr,total_acc_incorr=0,0
+ batch_size = desc_mat.shape[0]
+ log_p=torch.log(abs(desc_mat)+1e-8)
+
+ for i in range(batch_size):
+ cur_log_p=log_p[i]
+ num_corr=batch_num_corr[i]
+ num_incorr1,num_incorr2=batch_num_incorr1[i],batch_num_incorr2[i]
+
+ #loss and acc
+ loss_corr = -torch.diag(cur_log_p)[:num_corr].mean()
+ loss_incorr=(-cur_log_p[num_corr:num_corr+num_incorr1,-1].mean()-cur_log_p[-1,num_corr:num_corr+num_incorr2].mean())/2
+
+ value_row, row_index = torch.max(desc_mat[i,:-1,:-1], dim=-1)
+ value_col, col_index = torch.max(desc_mat[i,:-1,:-1], dim=-2)
+ acc_incorr=((value_row[num_corr:num_corr+num_incorr1]<0.2).float().mean()+
+ (value_col[num_corr:num_corr+num_incorr2]<0.2).float().mean())/2
+
+ acc_row_mask = row_index[:num_corr] == torch.arange(num_corr).cuda()
+ acc_col_mask = col_index[:num_corr] == torch.arange(num_corr).cuda()
+ acc = (acc_col_mask & acc_row_mask).float().mean()
+
+ total_loss_corr+=loss_corr
+ total_loss_incorr+=loss_incorr
+ total_acc_corr += acc
+ total_acc_incorr+=acc_incorr
+
+ total_acc_corr/=batch_size
+ total_acc_incorr/=batch_size
+ total_loss_corr/=batch_size
+ total_loss_incorr/=batch_size
+ return total_loss_corr,total_loss_incorr,total_acc_corr,total_acc_incorr
+
+
+class SGMLoss:
+ def __init__(self,config,model_config):
+ self.config=config
+ self.model_config=model_config
+
+ def run(self,data,result):
+ loss_corr,loss_incorr,acc_corr,acc_incorr=CorrLoss(result['p'],data['num_corr'],data['num_incorr1'],data['num_incorr2'])
+ loss_mid_corr_tower,loss_mid_incorr_tower,acc_mid_tower=[],[],[]
+
+ #mid loss
+ for i in range(len(result['mid_p'])):
+ mid_p=result['mid_p'][i]
+ loss_mid_corr,loss_mid_incorr,mid_acc_corr,mid_acc_incorr=CorrLoss(mid_p,data['num_corr'],data['num_incorr1'],data['num_incorr2'])
+ loss_mid_corr_tower.append(loss_mid_corr),loss_mid_incorr_tower.append(loss_mid_incorr),acc_mid_tower.append(mid_acc_corr)
+ if len(result['mid_p']) != 0:
+ loss_mid_corr_tower,loss_mid_incorr_tower, acc_mid_tower = torch.stack(loss_mid_corr_tower), torch.stack(loss_mid_incorr_tower), torch.stack(acc_mid_tower)
+ else:
+ loss_mid_corr_tower,loss_mid_incorr_tower, acc_mid_tower= torch.zeros(1).cuda(), torch.zeros(1).cuda(),torch.zeros(1).cuda()
+
+ #seed confidence loss
+ classif_loss_tower,classif_precision_tower,classif_recall_tower=[],[],[]
+ for layer in range(len(result['seed_conf'])):
+ confidence=result['seed_conf'][layer]
+ seed_index=result['seed_index'][(np.asarray(self.model_config.seedlayer)<=layer).nonzero()[0][-1]]
+ seed_x1,seed_x2=data['x1'].gather(dim=1, index=seed_index[:,:,0,None].expand(-1, -1,2)),\
+ data['x2'].gather(dim=1, index=seed_index[:,:,1,None].expand(-1, -1,2))
+ classif_loss,classif_precision,classif_recall=CELoss(seed_x1,seed_x2,data['e_gt'],confidence,self.config.inlier_th)
+ classif_loss_tower.append(classif_loss), classif_precision_tower.append(classif_precision), classif_recall_tower.append(classif_recall)
+ classif_loss, classif_precision_tower, classif_recall_tower=torch.stack(classif_loss_tower).mean(),torch.stack(classif_precision_tower), \
+ torch.stack(classif_recall_tower)
+
+
+ classif_loss*=self.config.seed_loss_weight
+ loss_mid_corr_tower*=self.config.mid_loss_weight
+ loss_mid_incorr_tower*=self.config.mid_loss_weight
+ total_loss=loss_corr+loss_incorr+classif_loss+loss_mid_corr_tower.sum()+loss_mid_incorr_tower.sum()
+
+ return {'loss_corr':loss_corr,'loss_incorr':loss_incorr,'acc_corr':acc_corr,'acc_incorr':acc_incorr,'loss_seed_conf':classif_loss,
+ 'pre_seed_conf':classif_precision_tower,'recall_seed_conf':classif_recall_tower,'loss_corr_mid':loss_mid_corr_tower,
+ 'loss_incorr_mid':loss_mid_incorr_tower,'mid_acc_corr':acc_mid_tower,'total_loss':total_loss}
+
+class SGLoss:
+ def __init__(self,config,model_config):
+ self.config=config
+ self.model_config=model_config
+
+ def run(self,data,result):
+ loss_corr,loss_incorr,acc_corr,acc_incorr=CorrLoss(result['p'],data['num_corr'],data['num_incorr1'],data['num_incorr2'])
+ total_loss=loss_corr+loss_incorr
+ return {'loss_corr':loss_corr,'loss_incorr':loss_incorr,'acc_corr':acc_corr,'acc_incorr':acc_incorr,'total_loss':total_loss}
+
\ No newline at end of file
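For reference (annotation, not part of the patch), batch_episym above computes the symmetric epipolar distance. Written out, with the small epsilon terms in the code omitted:

    d(x_1, x_2; E) = \sqrt{ (x_2^\top E x_1)^2 \left( \frac{1}{(E x_1)_1^2 + (E x_1)_2^2} + \frac{1}{(E^\top x_2)_1^2 + (E^\top x_2)_2^2} \right) }

where x_1 and x_2 are the keypoints in homogeneous coordinates and (.)_k denotes the k-th component. CELoss then labels a seed as inlier when this distance falls below inlier_th.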
diff --git a/imcui/third_party/SGMNet/train/main.py b/imcui/third_party/SGMNet/train/main.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d4c8fff432a3b2d58c82b9e5f2897a4e702b2dd
--- /dev/null
+++ b/imcui/third_party/SGMNet/train/main.py
@@ -0,0 +1,61 @@
+import torch.utils.data
+from dataset import Offline_Dataset
+import yaml
+from sgmnet.match_model import matcher as SGM_Model
+from superglue.match_model import matcher as SG_Model
+import torch.distributed as dist
+import torch
+import os
+from collections import namedtuple
+from train import train
+from config import get_config, print_usage
+
+
+def main(config,model_config):
+ """The main function."""
+ # Initialize network
+ if config.model_name=='SGM':
+ model = SGM_Model(model_config)
+ elif config.model_name=='SG':
+ model= SG_Model(model_config)
+ else:
+ raise NotImplementedError
+
+ #initialize ddp
+ torch.cuda.set_device(config.local_rank)
+ device = torch.device(f'cuda:{config.local_rank}')
+ model.to(device)
+ dist.init_process_group(backend='nccl',init_method='env://')
+ model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[config.local_rank])
+
+ if config.local_rank==0:
+ os.system('nvidia-smi')
+
+ #initialize dataset
+ train_dataset = Offline_Dataset(config,'train')
+ train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset,shuffle=True)
+ train_loader=torch.utils.data.DataLoader(train_dataset, batch_size=config.train_batch_size//torch.distributed.get_world_size(),
+ num_workers=8//dist.get_world_size(), pin_memory=False,sampler=train_sampler,collate_fn=train_dataset.collate_fn)
+
+ valid_dataset = Offline_Dataset(config,'valid')
+ valid_sampler = torch.utils.data.distributed.DistributedSampler(valid_dataset,shuffle=False)
+ valid_loader=torch.utils.data.DataLoader(valid_dataset, batch_size=config.train_batch_size,
+ num_workers=8//dist.get_world_size(), pin_memory=False,collate_fn=valid_dataset.collate_fn,sampler=valid_sampler)
+
+ if config.local_rank==0:
+ print('start training .....')
+ train(model,train_loader, valid_loader, config,model_config)
+
+if __name__ == "__main__":
+ # ----------------------------------------
+ # Parse configuration
+ config, unparsed = get_config()
+ with open(config.config_path, 'r') as f:
+        model_config = yaml.load(f, Loader=yaml.FullLoader)
+ model_config=namedtuple('model_config',model_config.keys())(*model_config.values())
+ # If we have unparsed arguments, print usage and exit
+ if len(unparsed) > 0:
+ print_usage()
+ exit(1)
+
+ main(config,model_config)
diff --git a/imcui/third_party/SGMNet/train/train.py b/imcui/third_party/SGMNet/train/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..31e848e1d2e5f028d4ff3abaf0cc446be7d89c65
--- /dev/null
+++ b/imcui/third_party/SGMNet/train/train.py
@@ -0,0 +1,160 @@
+import torch
+import torch.optim as optim
+from tqdm import trange
+import os
+from tensorboardX import SummaryWriter
+import numpy as np
+import cv2
+from loss import SGMLoss,SGLoss
+from valid import valid,dump_train_vis
+
+import sys
+ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
+sys.path.insert(0, ROOT_DIR)
+
+
+from utils import train_utils
+
+def train_step(optimizer, model, match_loss, data,step,pre_avg_loss):
+ data['step']=step
+ result=model(data,test_mode=False)
+ loss_res=match_loss.run(data,result)
+
+ optimizer.zero_grad()
+ loss_res['total_loss'].backward()
+ #apply reduce on all record tensor
+ for key in loss_res.keys():
+ loss_res[key]=train_utils.reduce_tensor(loss_res[key],'mean')
+
+ if loss_res['total_loss']<7*pre_avg_loss or step<200 or pre_avg_loss==0:
+ optimizer.step()
+ unusual_loss=False
+ else:
+ optimizer.zero_grad()
+ unusual_loss=True
+ return loss_res,unusual_loss
+
+
+def train(model, train_loader, valid_loader, config,model_config):
+ model.train()
+ optimizer = optim.Adam(model.parameters(), lr=config.train_lr)
+
+ if config.model_name=='SGM':
+ match_loss = SGMLoss(config,model_config)
+ elif config.model_name=='SG':
+ match_loss= SGLoss(config,model_config)
+ else:
+ raise NotImplementedError
+
+ checkpoint_path = os.path.join(config.log_base, 'checkpoint.pth')
+ config.resume = os.path.isfile(checkpoint_path)
+ if config.resume:
+ if config.local_rank==0:
+ print('==> Resuming from checkpoint..')
+ checkpoint = torch.load(checkpoint_path,map_location='cuda:{}'.format(config.local_rank))
+ model.load_state_dict(checkpoint['state_dict'])
+ best_acc = checkpoint['best_acc']
+ start_step = checkpoint['step']
+ optimizer.load_state_dict(checkpoint['optimizer'])
+ else:
+ best_acc = -1
+ start_step = 0
+ train_loader_iter = iter(train_loader)
+
+ if config.local_rank==0:
+ writer=SummaryWriter(os.path.join(config.log_base,'log_file'))
+
+ train_loader.sampler.set_epoch(start_step*config.train_batch_size//len(train_loader.dataset))
+ pre_avg_loss=0
+
+ progress_bar=trange(start_step, config.train_iter,ncols=config.tqdm_width) if config.local_rank==0 else range(start_step, config.train_iter)
+ for step in progress_bar:
+ try:
+ train_data = next(train_loader_iter)
+ except StopIteration:
+ if config.local_rank==0:
+ print('epoch: ',step*config.train_batch_size//len(train_loader.dataset))
+ train_loader.sampler.set_epoch(step*config.train_batch_size//len(train_loader.dataset))
+ train_loader_iter = iter(train_loader)
+ train_data = next(train_loader_iter)
+
+ train_data = train_utils.tocuda(train_data)
+ lr=min(config.train_lr*config.decay_rate**(step-config.decay_iter),config.train_lr)
+ for param_group in optimizer.param_groups:
+ param_group['lr'] = lr
+
+ # run training
+ loss_res,unusual_loss = train_step(optimizer, model, match_loss, train_data,step-start_step,pre_avg_loss)
+ if (step-start_step)<=200:
+ pre_avg_loss=loss_res['total_loss'].data
+ if (step-start_step)>200 and not unusual_loss:
+ pre_avg_loss=pre_avg_loss.data*0.9+loss_res['total_loss'].data*0.1
+ if unusual_loss and config.local_rank==0:
+ print('unusual loss! pre_avg_loss: ',pre_avg_loss,'cur_loss: ',loss_res['total_loss'].data)
+ #log
+ if config.local_rank==0 and step%config.log_intv==0 and not unusual_loss:
+ writer.add_scalar('TotalLoss',loss_res['total_loss'],step)
+ writer.add_scalar('CorrLoss',loss_res['loss_corr'],step)
+ writer.add_scalar('InCorrLoss', loss_res['loss_incorr'], step)
+ writer.add_scalar('dustbin', model.module.dustbin, step)
+
+ if config.model_name=='SGM':
+ writer.add_scalar('SeedConfLoss', loss_res['loss_seed_conf'], step)
+ writer.add_scalar('MidCorrLoss', loss_res['loss_corr_mid'].sum(), step)
+ writer.add_scalar('MidInCorrLoss', loss_res['loss_incorr_mid'].sum(), step)
+
+
+        # valid and save
+ b_save = ((step + 1) % config.save_intv) == 0
+ b_validate = ((step + 1) % config.val_intv) == 0
+ if b_validate:
+ total_loss,acc_corr,acc_incorr,seed_precision_tower,seed_recall_tower,acc_mid=valid(valid_loader, model, match_loss, config,model_config)
+ if config.local_rank==0:
+ writer.add_scalar('ValidAcc', acc_corr, step)
+ writer.add_scalar('ValidLoss', total_loss, step)
+
+ if config.model_name=='SGM':
+ for i in range(len(seed_recall_tower)):
+ writer.add_scalar('seed_conf_pre_%d'%i,seed_precision_tower[i],step)
+                        writer.add_scalar('seed_conf_recall_%d' % i, seed_recall_tower[i], step)
+ for i in range(len(acc_mid)):
+ writer.add_scalar('acc_mid%d'%i,acc_mid[i],step)
+ print('acc_corr: ',acc_corr.data,'acc_incorr: ',acc_incorr.data,'seed_conf_pre: ',seed_precision_tower.mean().data,
+ 'seed_conf_recall: ',seed_recall_tower.mean().data,'acc_mid: ',acc_mid.mean().data)
+ else:
+ print('acc_corr: ',acc_corr.data,'acc_incorr: ',acc_incorr.data)
+
+ #saving best
+ if acc_corr > best_acc:
+ print("Saving best model with va_res = {}".format(acc_corr))
+ best_acc = acc_corr
+ save_dict={'step': step + 1,
+ 'state_dict': model.state_dict(),
+ 'best_acc': best_acc,
+ 'optimizer' : optimizer.state_dict()}
+ save_dict.update(save_dict)
+ torch.save(save_dict, os.path.join(config.log_base, 'model_best.pth'))
+
+ if b_save:
+ if config.local_rank==0:
+ save_dict={'step': step + 1,
+ 'state_dict': model.state_dict(),
+ 'best_acc': best_acc,
+ 'optimizer' : optimizer.state_dict()}
+ torch.save(save_dict, checkpoint_path)
+
+ #draw match results
+ model.eval()
+ with torch.no_grad():
+ if config.local_rank==0:
+ if not os.path.exists(os.path.join(config.train_vis_folder,'train_vis')):
+ os.mkdir(os.path.join(config.train_vis_folder,'train_vis'))
+ if not os.path.exists(os.path.join(config.train_vis_folder,'train_vis',config.log_base)):
+ os.mkdir(os.path.join(config.train_vis_folder,'train_vis',config.log_base))
+ os.mkdir(os.path.join(config.train_vis_folder,'train_vis',config.log_base,str(step)))
+ res=model(train_data)
+ dump_train_vis(res,train_data,step,config)
+ model.train()
+
+ if config.local_rank==0:
+ writer.close()
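A quick numeric sketch (annotation, not part of the patch) of the schedule computed inside the training loop above, using the defaults from config.py: the learning rate stays at train_lr until decay_iter and then decays geometrically by decay_rate per step.

    train_lr, decay_rate, decay_iter = 1e-4, 0.999996, 300000
    for step in (0, 300000, 600000, 1000000):
        lr = min(train_lr * decay_rate ** (step - decay_iter), train_lr)
        print(step, lr)  # 1e-4, 1e-4, ~3.0e-5, ~6.1e-6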
diff --git a/imcui/third_party/SGMNet/train/valid.py b/imcui/third_party/SGMNet/train/valid.py
new file mode 100644
index 0000000000000000000000000000000000000000..443694d85104730cd50aeb342326ce593dc5684d
--- /dev/null
+++ b/imcui/third_party/SGMNet/train/valid.py
@@ -0,0 +1,77 @@
+import torch
+import numpy as np
+import cv2
+import os
+from loss import batch_episym
+from tqdm import tqdm
+
+import sys
+ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
+sys.path.insert(0, ROOT_DIR)
+
+from utils import evaluation_utils,train_utils
+
+
+def valid(valid_loader, model,match_loss, config,model_config):
+ model.eval()
+ loader_iter = iter(valid_loader)
+ num_pair = 0
+ total_loss,total_acc_corr,total_acc_incorr=0,0,0
+ total_precision,total_recall=torch.zeros(model_config.layer_num ,device='cuda'),\
+ torch.zeros(model_config.layer_num ,device='cuda')
+ total_acc_mid=torch.zeros(len(model_config.seedlayer)-1,device='cuda')
+
+ with torch.no_grad():
+ if config.local_rank==0:
+ loader_iter=tqdm(loader_iter)
+ print('validating...')
+ for test_data in loader_iter:
+ num_pair+= 1
+ test_data = train_utils.tocuda(test_data)
+ res= model(test_data)
+ loss_res=match_loss.run(test_data,res)
+
+ total_acc_corr+=loss_res['acc_corr']
+ total_acc_incorr+=loss_res['acc_incorr']
+ total_loss+=loss_res['total_loss']
+
+ if config.model_name=='SGM':
+ total_acc_mid+=loss_res['mid_acc_corr']
+ total_precision,total_recall=total_precision+loss_res['pre_seed_conf'],total_recall+loss_res['recall_seed_conf']
+
+ total_acc_corr/=num_pair
+ total_acc_incorr /= num_pair
+ total_precision/=num_pair
+ total_recall/=num_pair
+ total_acc_mid/=num_pair
+
+ #apply tensor reduction
+ total_loss,total_acc_corr,total_acc_incorr,total_precision,total_recall,total_acc_mid=train_utils.reduce_tensor(total_loss,'sum'),\
+ train_utils.reduce_tensor(total_acc_corr,'mean'),train_utils.reduce_tensor(total_acc_incorr,'mean'),\
+ train_utils.reduce_tensor(total_precision,'mean'),train_utils.reduce_tensor(total_recall,'mean'),train_utils.reduce_tensor(total_acc_mid,'mean')
+ model.train()
+ return total_loss,total_acc_corr,total_acc_incorr,total_precision,total_recall,total_acc_mid
+
+
+
+def dump_train_vis(res,data,step,config):
+ #batch matching
+ p=res['p'][:,:-1,:-1]
+ score,index1=torch.max(p,dim=-1)
+ _,index2=torch.max(p,dim=-2)
+ mask_th=score>0.2
+ mask_mc=index2.gather(index=index1,dim=1) == torch.arange(len(p[0])).cuda()[None]
+ mask_p=mask_th&mask_mc#B*N
+
+ corr1,corr2=data['x1'],data['x2'].gather(index=index1[:,:,None].expand(-1,-1,2),dim=1)
+ corr1_kpt,corr2_kpt=data['kpt1'],data['kpt2'].gather(index=index1[:,:,None].expand(-1,-1,2),dim=1)
+ epi_dis=batch_episym(corr1,corr2,data['e_gt'])
+    mask_inlier=epi_dis<config.inlier_th #B*N
+
+    # keep only keypoints whose four neighbouring depth samples are all valid
+    valid_depth = np.logical_and(
+        np.logical_and(depth_top_left > 0, depth_top_right > 0),
+        np.logical_and(depth_down_left > 0, depth_down_right > 0)
+    )
+ ids=ids[valid_depth]
+ depth_top_left,depth_top_right,depth_down_left,depth_down_right=depth_top_left[valid_depth],depth_top_right[valid_depth],\
+ depth_down_left[valid_depth],depth_down_right[valid_depth]
+
+ i,j,i_top_left,j_top_left=i[valid_depth],j[valid_depth],i_top_left[valid_depth],j_top_left[valid_depth]
+
+ # Interpolation
+ dist_i_top_left = i - i_top_left.astype(np.float32)
+ dist_j_top_left = j - j_top_left.astype(np.float32)
+ w_top_left = (1 - dist_i_top_left) * (1 - dist_j_top_left)
+ w_top_right = (1 - dist_i_top_left) * dist_j_top_left
+ w_bottom_left = dist_i_top_left * (1 - dist_j_top_left)
+ w_bottom_right = dist_i_top_left * dist_j_top_left
+
+ interpolated_depth = (
+ w_top_left * depth_top_left +
+ w_top_right * depth_top_right+
+ w_bottom_left * depth_down_left +
+ w_bottom_right * depth_down_right
+ )
+ return [interpolated_depth, ids]
+
+
+def reprojection(depth_map,kpt,dR,dt,K1_img2depth,K1,K2):
+ #warp kpt from img1 to img2
+ def swap_axis(data):
+ return np.stack([data[:, 1], data[:, 0]], axis=-1)
+
+ kp_depth = unnorm_kp(K1_img2depth,kpt)
+ uv_depth = swap_axis(kp_depth)
+ z,valid_idx = interpolate_depth(uv_depth, depth_map)
+
+ norm_kp=norm_kpt(K1,kpt)
+ norm_kp_valid = np.concatenate([norm_kp[valid_idx, :], np.ones((len(valid_idx), 1))], axis=-1)
+ xyz_valid = norm_kp_valid * z.reshape(-1, 1)
+ xyz2 = np.matmul(xyz_valid, dR.T) + dt.reshape(1, 3)
+ xy2 = xyz2[:, :2] / xyz2[:, 2:]
+ kp2, valid = np.ones(kpt.shape) * 1e5, np.zeros(kpt.shape[0])
+ kp2[valid_idx] = unnorm_kp(K2,xy2)
+ valid[valid_idx] = 1
+ return kp2, valid.astype(bool)
+
+def reprojection_2s(kp1, kp2,depth1, depth2, K1, K2, dR, dt, size1,size2):
+ #size:H*W
+ depth_size1,depth_size2 = [depth1.shape[0], depth1.shape[1]], [depth2.shape[0], depth2.shape[1]]
+ scale_1= [float(depth_size1[0]) / size1[0], float(depth_size1[1]) / size1[1], 1]
+ scale_2= [float(depth_size2[0]) / size2[0], float(depth_size2[1]) / size2[1], 1]
+ K1_img2depth, K2_img2depth = np.diag(np.asarray(scale_1)), np.diag(np.asarray(scale_2))
+ kp1_2_proj, valid1_2 = reprojection(depth1, kp1, dR, dt, K1_img2depth,K1,K2)
+ kp2_1_proj, valid2_1 = reprojection(depth2, kp2, dR.T, -np.matmul(dR.T, dt), K2_img2depth,K2,K1)
+ return [kp1_2_proj,kp2_1_proj],[valid1_2,valid2_1]
+
+def make_corr(kp1,kp2,desc1,desc2,depth1,depth2,K1,K2,dR,dt,size1,size2,corr_th,incorr_th,check_desc=False):
+ #make reprojection
+ [kp1_2,kp2_1],[valid1_2,valid2_1]=reprojection_2s(kp1,kp2,depth1,depth2,K1,K2,dR,dt,size1,size2)
+ num_pts1, num_pts2 = kp1.shape[0], kp2.shape[0]
+ #reprojection error
+ dis_mat1=np.sqrt(abs((kp1 ** 2).sum(1,keepdims=True) + (kp2_1 ** 2).sum(1,keepdims=False)[np.newaxis] - 2 * np.matmul(kp1, kp2_1.T)))
+ dis_mat2 =np.sqrt(abs((kp2 ** 2).sum(1,keepdims=True) + (kp1_2 ** 2).sum(1,keepdims=False)[np.newaxis] - 2 * np.matmul(kp2,kp1_2.T)))
+ repro_error = np.maximum(dis_mat1,dis_mat2.T) #n1*n2
+
+ # find corr index
+ nn_sort1 = np.argmin(repro_error, axis=1)
+ nn_sort2 = np.argmin(repro_error, axis=0)
+ mask_mutual = nn_sort2[nn_sort1] == np.arange(kp1.shape[0])
+    mask_inlier=np.take_along_axis(repro_error,indices=nn_sort1[:,np.newaxis],axis=-1).squeeze(1)<corr_th
+    corr_index=np.stack([np.arange(num_pts1)[mask_mutual&mask_inlier],nn_sort1[mask_mutual&mask_inlier]],axis=-1)
+    # mark matched keypoints that share identical coordinates with other keypoints
+    mask_samepos1=np.equal(kp1[corr_index[:,0],np.newaxis],kp1[np.newaxis]).all(axis=-1)
+    mask_samepos2=np.equal(kp2[corr_index[:,1],np.newaxis],kp2[np.newaxis]).all(axis=-1)
+    duplicated_mask=np.logical_and(mask_samepos1.sum(-1)>1,mask_samepos2.sum(-1)>1)
+ duplicated_index=np.nonzero(duplicated_mask)[0]
+
+ unique_corr_index=corr_index[~duplicated_mask]
+ clean_duplicated_corr=[]
+ for index in duplicated_index:
+ cur_desc1, cur_desc2 = desc1[mask_samepos1[index]], desc2[mask_samepos2[index]]
+ cur_desc_mat = np.matmul(cur_desc1, cur_desc2.T)
+ cur_max_index =[np.argmax(cur_desc_mat)//cur_desc_mat.shape[1],np.argmax(cur_desc_mat)%cur_desc_mat.shape[1]]
+ clean_duplicated_corr.append(np.stack([np.arange(num_pts1)[mask_samepos1[index]][cur_max_index[0]],
+ np.arange(num_pts2)[mask_samepos2[index]][cur_max_index[1]]]))
+
+ clean_corr_index=unique_corr_index
+ if len(clean_duplicated_corr)!=0:
+ clean_duplicated_corr=np.stack(clean_duplicated_corr,axis=0)
+ clean_corr_index=np.concatenate([clean_corr_index,clean_duplicated_corr],axis=0)
+ else:
+ clean_corr_index=corr_index
+ # find incorr
+ mask_incorr1 = np.min(dis_mat2.T[valid1_2], axis=-1) > incorr_th
+ mask_incorr2 = np.min(dis_mat1.T[valid2_1], axis=-1) > incorr_th
+ incorr_index1, incorr_index2 = np.arange(num_pts1)[valid1_2][mask_incorr1.squeeze()], \
+ np.arange(num_pts2)[valid2_1][mask_incorr2.squeeze()]
+
+ return clean_corr_index,incorr_index1,incorr_index2
+
diff --git a/imcui/third_party/SGMNet/utils/evaluation_utils.py b/imcui/third_party/SGMNet/utils/evaluation_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..82c4715a192d3c361c849896b035cd91ee56dc42
--- /dev/null
+++ b/imcui/third_party/SGMNet/utils/evaluation_utils.py
@@ -0,0 +1,58 @@
+import numpy as np
+import h5py
+import cv2
+
+def normalize_intrinsic(x,K):
+ #print(x,K)
+ return (x-K[:2,2])/np.diag(K)[:2]
+
+def normalize_size(x,size,scale=1):
+ size=size.reshape([1,2])
+ norm_fac=size.max()
+ return (x-size/2+0.5)/(norm_fac*scale)
+
+def np_skew_symmetric(v):
+ zero = np.zeros_like(v[:, 0])
+ M = np.stack([
+ zero, -v[:, 2], v[:, 1],
+ v[:, 2], zero, -v[:, 0],
+ -v[:, 1], v[:, 0], zero,
+ ], axis=1)
+ return M
+
+def draw_points(img,points,color=(0,255,0),radius=3):
+ dp = [(int(points[i, 0]), int(points[i, 1])) for i in range(points.shape[0])]
+ for i in range(points.shape[0]):
+ cv2.circle(img, dp[i],radius=radius,color=color)
+ return img
+
+
+def draw_match(img1, img2, corr1, corr2,inlier=[True],color=None,radius1=1,radius2=1,resize=None):
+ if resize is not None:
+ scale1,scale2=[img1.shape[1]/resize[0],img1.shape[0]/resize[1]],[img2.shape[1]/resize[0],img2.shape[0]/resize[1]]
+ img1,img2=cv2.resize(img1, resize, interpolation=cv2.INTER_AREA),cv2.resize(img2, resize, interpolation=cv2.INTER_AREA)
+ corr1,corr2=corr1/np.asarray(scale1)[np.newaxis],corr2/np.asarray(scale2)[np.newaxis]
+ corr1_key = [cv2.KeyPoint(corr1[i, 0], corr1[i, 1], radius1) for i in range(corr1.shape[0])]
+ corr2_key = [cv2.KeyPoint(corr2[i, 0], corr2[i, 1], radius2) for i in range(corr2.shape[0])]
+
+ assert len(corr1) == len(corr2)
+
+ draw_matches = [cv2.DMatch(i, i, 0) for i in range(len(corr1))]
+ if color is None:
+ color = [(0, 255, 0) if cur_inlier else (0,0,255) for cur_inlier in inlier]
+ if len(color)==1:
+ display = cv2.drawMatches(img1, corr1_key, img2, corr2_key, draw_matches, None,
+ matchColor=color[0],
+ singlePointColor=color[0],
+ flags=4
+ )
+ else:
+ height,width=max(img1.shape[0],img2.shape[0]),img1.shape[1]+img2.shape[1]
+ display=np.zeros([height,width,3],np.uint8)
+ display[:img1.shape[0],:img1.shape[1]]=img1
+ display[:img2.shape[0],img1.shape[1]:]=img2
+ for i in range(len(corr1)):
+ left_x,left_y,right_x,right_y=int(corr1[i][0]),int(corr1[i][1]),int(corr2[i][0]+img1.shape[1]),int(corr2[i][1])
+ cur_color=(int(color[i][0]),int(color[i][1]),int(color[i][2]))
+ cv2.line(display, (left_x,left_y), (right_x,right_y),cur_color,1,lineType=cv2.LINE_AA)
+ return display
\ No newline at end of file
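A hypothetical usage sketch for draw_match above (annotation, not part of the patch): image paths and correspondences are placeholders, and the import assumes the SGMNet utils package is on the path, as in train/valid.py.

    import cv2
    import numpy as np
    from utils import evaluation_utils

    img1, img2 = cv2.imread('img1.jpg'), cv2.imread('img2.jpg')   # placeholder images
    corr1 = np.array([[10.0, 20.0], [100.0, 50.0]])
    corr2 = np.array([[12.0, 22.0], [300.0, 40.0]])
    inlier = [True, False]                                        # green for inliers, red for outliers
    vis = evaluation_utils.draw_match(img1, img2, corr1, corr2, inlier=inlier, resize=(640, 480))
    cv2.imwrite('match_vis.png', vis)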
diff --git a/imcui/third_party/SGMNet/utils/fm_utils.py b/imcui/third_party/SGMNet/utils/fm_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..f9cbbeefe5d6b59c1ae1fa26cdaa42146ad22a74
--- /dev/null
+++ b/imcui/third_party/SGMNet/utils/fm_utils.py
@@ -0,0 +1,95 @@
+import numpy as np
+
+
+def line_to_border(line,size):
+ #line:(a,b,c), ax+by+c=0
+ #size:(W,H)
+ H,W=size[1],size[0]
+ a,b,c=line[0],line[1],line[2]
+ epsa=1e-8 if a>=0 else -1e-8
+ epsb=1e-8 if b>=0 else -1e-8
+ intersection_list=[]
+
+ y_left=-c/(b+epsb)
+ y_right=(-c-a*(W-1))/(b+epsb)
+ x_top=-c/(a+epsa)
+ x_down=(-c-b*(H-1))/(a+epsa)
+
+ if y_left>=0 and y_left<=H-1:
+ intersection_list.append([0,y_left])
+ if y_right>=0 and y_right<=H-1:
+ intersection_list.append([W-1,y_right])
+ if x_top>=0 and x_top<=W-1:
+ intersection_list.append([x_top,0])
+ if x_down>=0 and x_down<=W-1:
+ intersection_list.append([x_down,H-1])
+ if len(intersection_list)!=2:
+ return None
+ intersection_list=np.asarray(intersection_list)
+ return intersection_list
+
+def find_point_in_line(end_point):
+ x_span,y_span=end_point[1,0]-end_point[0,0],end_point[1,1]-end_point[0,1]
+ mv=np.random.uniform()
+ point=np.asarray([end_point[0,0]+x_span*mv,end_point[0,1]+y_span*mv])
+ return point
+
+def epi_line(point,F):
+ homo=np.concatenate([point,np.ones([len(point),1])],axis=-1)
+ epi=np.matmul(homo,F.T)
+ return epi
+
+def dis_point_to_line(line,point):
+ homo=np.concatenate([point,np.ones([len(point),1])],axis=-1)
+ dis=line*homo
+ dis=dis.sum(axis=-1)/(np.linalg.norm(line[:,:2],axis=-1)+1e-8)
+ return abs(dis)
+
+def SGD_oneiter(F1,F2,size1,size2):
+ H1,W1=size1[1],size1[0]
+ factor1 = 1 / np.linalg.norm(size1)
+ factor2 = 1 / np.linalg.norm(size2)
+ p0=np.asarray([(W1-1)*np.random.uniform(),(H1-1)*np.random.uniform()])
+ epi1=epi_line(p0[np.newaxis],F1)[0]
+ border_point1=line_to_border(epi1,size2)
+ if border_point1 is None:
+ return -1
+
+ p1=find_point_in_line(border_point1)
+ epi2=epi_line(p0[np.newaxis],F2)
+ d1=dis_point_to_line(epi2,p1[np.newaxis])[0]*factor2
+ epi3=epi_line(p1[np.newaxis],F2.T)
+ d2=dis_point_to_line(epi3,p0[np.newaxis])[0]*factor1
+ return (d1+d2)/2
+
+def compute_SGD(F1,F2,size1,size2):
+ np.random.seed(1234)
+ N=1000
+ max_iter=N*10
+ count,sgd=0,0
+ for i in range(max_iter):
+ d1=SGD_oneiter(F1,F2,size1,size2)
+ if d1<0:
+ continue
+ d2=SGD_oneiter(F2,F1,size1,size2)
+ if d2<0:
+ continue
+ count+=1
+ sgd+=(d1+d2)/2
+ if count==N:
+ break
+ if count==0:
+ return 1
+ else:
+ return sgd/count
+
+def compute_inlier_rate(x1,x2,size1,size2,F_gt,th=0.003):
+ t1,t2=np.linalg.norm(size1)*th,np.linalg.norm(size2)*th
+ epi1,epi2=epi_line(x1,F_gt),epi_line(x2,F_gt.T)
+ dis1,dis2=dis_point_to_line(epi1,x2),dis_point_to_line(epi2,x1)
+    mask_inlier=np.logical_and(dis1<t2,dis2<t1)
+    return mask_inlier.mean()
+
+:Organization:
+ Laboratory for Fluorescence Dynamics, University of California, Irvine
+
+:Version: 2015.07.18
+
+Requirements
+------------
+* `CPython 2.7 or 3.4 `_
+* `Numpy 1.9 `_
+* `Transformations.c 2015.07.18 `_
+ (recommended for speedup of some functions)
+
+Notes
+-----
+The API is not stable yet and is expected to change between revisions.
+
+This Python code is not optimized for speed. Refer to the transformations.c
+module for a faster implementation of some functions.
+
+Documentation in HTML format can be generated with epydoc.
+
+Matrices (M) can be inverted using numpy.linalg.inv(M), be concatenated using
+numpy.dot(M0, M1), or transform homogeneous coordinate arrays (v) using
+numpy.dot(M, v) for shape (4, \*) column vectors, respectively
+numpy.dot(v, M.T) for shape (\*, 4) row vectors ("array of points").
+
+This module follows the "column vectors on the right" and "row major storage"
+(C contiguous) conventions. The translation components are in the right column
+of the transformation matrix, i.e. M[:3, 3].
+The transpose of the transformation matrices may have to be used to interface
+with other graphics systems, e.g. with OpenGL's glMultMatrixd(). See also [16].
+
+Calculations are carried out with numpy.float64 precision.
+
+Vector, point, quaternion, and matrix function arguments are expected to be
+"array like", i.e. tuple, list, or numpy arrays.
+
+Return types are numpy arrays unless specified otherwise.
+
+Angles are in radians unless specified otherwise.
+
+Quaternions w+ix+jy+kz are represented as [w, x, y, z].
+
+A triple of Euler angles can be applied/interpreted in 24 ways, which can
+be specified using a 4 character string or encoded 4-tuple:
+
+ *Axes 4-string*: e.g. 'sxyz' or 'ryxy'
+
+ - first character : rotations are applied to 's'tatic or 'r'otating frame
+ - remaining characters : successive rotation axis 'x', 'y', or 'z'
+
+ *Axes 4-tuple*: e.g. (0, 0, 0, 0) or (1, 1, 1, 1)
+
+ - inner axis: code of axis ('x':0, 'y':1, 'z':2) of rightmost matrix.
+ - parity : even (0) if inner axis 'x' is followed by 'y', 'y' is followed
+ by 'z', or 'z' is followed by 'x'. Otherwise odd (1).
+ - repetition : first and last axis are same (1) or different (0).
+ - frame : rotations are applied to static (0) or rotating (1) frame.
+
+Other Python packages and modules for 3D transformations and quaternions:
+
+* `Transforms3d `_
+ includes most code of this module.
+* `Blender.mathutils `_
+* `numpy-dtypes `_
+
+References
+----------
+(1) Matrices and transformations. Ronald Goldman.
+ In "Graphics Gems I", pp 472-475. Morgan Kaufmann, 1990.
+(2) More matrices and transformations: shear and pseudo-perspective.
+ Ronald Goldman. In "Graphics Gems II", pp 320-323. Morgan Kaufmann, 1991.
+(3) Decomposing a matrix into simple transformations. Spencer Thomas.
+ In "Graphics Gems II", pp 320-323. Morgan Kaufmann, 1991.
+(4) Recovering the data from the transformation matrix. Ronald Goldman.
+ In "Graphics Gems II", pp 324-331. Morgan Kaufmann, 1991.
+(5) Euler angle conversion. Ken Shoemake.
+ In "Graphics Gems IV", pp 222-229. Morgan Kaufmann, 1994.
+(6) Arcball rotation control. Ken Shoemake.
+ In "Graphics Gems IV", pp 175-192. Morgan Kaufmann, 1994.
+(7) Representing attitude: Euler angles, unit quaternions, and rotation
+ vectors. James Diebel. 2006.
+(8) A discussion of the solution for the best rotation to relate two sets
+ of vectors. W Kabsch. Acta Cryst. 1978. A34, 827-828.
+(9) Closed-form solution of absolute orientation using unit quaternions.
+ BKP Horn. J Opt Soc Am A. 1987. 4(4):629-642.
+(10) Quaternions. Ken Shoemake.
+ http://www.sfu.ca/~jwa3/cmpt461/files/quatut.pdf
+(11) From quaternion to matrix and back. JMP van Waveren. 2005.
+ http://www.intel.com/cd/ids/developer/asmo-na/eng/293748.htm
+(12) Uniform random rotations. Ken Shoemake.
+ In "Graphics Gems III", pp 124-132. Morgan Kaufmann, 1992.
+(13) Quaternion in molecular modeling. CFF Karney.
+ J Mol Graph Mod, 25(5):595-604
+(14) New method for extracting the quaternion from a rotation matrix.
+ Itzhack Y Bar-Itzhack, J Guid Contr Dynam. 2000. 23(6): 1085-1087.
+(15) Multiple View Geometry in Computer Vision. Hartley and Zissermann.
+ Cambridge University Press; 2nd Ed. 2004. Chapter 4, Algorithm 4.7, p 130.
+(16) Column Vectors vs. Row Vectors.
+ http://steve.hollasch.net/cgindex/math/matrix/column-vec.html
+
+Examples
+--------
+>>> alpha, beta, gamma = 0.123, -1.234, 2.345
+>>> origin, xaxis, yaxis, zaxis = [0, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1]
+>>> I = identity_matrix()
+>>> Rx = rotation_matrix(alpha, xaxis)
+>>> Ry = rotation_matrix(beta, yaxis)
+>>> Rz = rotation_matrix(gamma, zaxis)
+>>> R = concatenate_matrices(Rx, Ry, Rz)
+>>> euler = euler_from_matrix(R, 'rxyz')
+>>> numpy.allclose([alpha, beta, gamma], euler)
+True
+>>> Re = euler_matrix(alpha, beta, gamma, 'rxyz')
+>>> is_same_transform(R, Re)
+True
+>>> al, be, ga = euler_from_matrix(Re, 'rxyz')
+>>> is_same_transform(Re, euler_matrix(al, be, ga, 'rxyz'))
+True
+>>> qx = quaternion_about_axis(alpha, xaxis)
+>>> qy = quaternion_about_axis(beta, yaxis)
+>>> qz = quaternion_about_axis(gamma, zaxis)
+>>> q = quaternion_multiply(qx, qy)
+>>> q = quaternion_multiply(q, qz)
+>>> Rq = quaternion_matrix(q)
+>>> is_same_transform(R, Rq)
+True
+>>> S = scale_matrix(1.23, origin)
+>>> T = translation_matrix([1, 2, 3])
+>>> Z = shear_matrix(beta, xaxis, origin, zaxis)
+>>> R = random_rotation_matrix(numpy.random.rand(3))
+>>> M = concatenate_matrices(T, R, Z, S)
+>>> scale, shear, angles, trans, persp = decompose_matrix(M)
+>>> numpy.allclose(scale, 1.23)
+True
+>>> numpy.allclose(trans, [1, 2, 3])
+True
+>>> numpy.allclose(shear, [0, math.tan(beta), 0])
+True
+>>> is_same_transform(R, euler_matrix(axes='sxyz', *angles))
+True
+>>> M1 = compose_matrix(scale, shear, angles, trans, persp)
+>>> is_same_transform(M, M1)
+True
+>>> v0, v1 = random_vector(3), random_vector(3)
+>>> M = rotation_matrix(angle_between_vectors(v0, v1), vector_product(v0, v1))
+>>> v2 = numpy.dot(v0, M[:3,:3].T)
+>>> numpy.allclose(unit_vector(v1), unit_vector(v2))
+True
+
+"""
+
+from __future__ import division, print_function
+
+import math
+
+import numpy
+
+__version__ = '2015.07.18'
+__docformat__ = 'restructuredtext en'
+__all__ = ()
+
+
+def identity_matrix():
+ """Return 4x4 identity/unit matrix.
+
+ >>> I = identity_matrix()
+ >>> numpy.allclose(I, numpy.dot(I, I))
+ True
+ >>> numpy.sum(I), numpy.trace(I)
+ (4.0, 4.0)
+ >>> numpy.allclose(I, numpy.identity(4))
+ True
+
+ """
+ return numpy.identity(4)
+
+
+def translation_matrix(direction):
+ """Return matrix to translate by direction vector.
+
+ >>> v = numpy.random.random(3) - 0.5
+ >>> numpy.allclose(v, translation_matrix(v)[:3, 3])
+ True
+
+ """
+ M = numpy.identity(4)
+ M[:3, 3] = direction[:3]
+ return M
+
+
+def translation_from_matrix(matrix):
+ """Return translation vector from translation matrix.
+
+ >>> v0 = numpy.random.random(3) - 0.5
+ >>> v1 = translation_from_matrix(translation_matrix(v0))
+ >>> numpy.allclose(v0, v1)
+ True
+
+ """
+ return numpy.array(matrix, copy=False)[:3, 3].copy()
+
+
+def reflection_matrix(point, normal):
+ """Return matrix to mirror at plane defined by point and normal vector.
+
+ >>> v0 = numpy.random.random(4) - 0.5
+ >>> v0[3] = 1.
+ >>> v1 = numpy.random.random(3) - 0.5
+ >>> R = reflection_matrix(v0, v1)
+ >>> numpy.allclose(2, numpy.trace(R))
+ True
+ >>> numpy.allclose(v0, numpy.dot(R, v0))
+ True
+ >>> v2 = v0.copy()
+ >>> v2[:3] += v1
+ >>> v3 = v0.copy()
+ >>> v2[:3] -= v1
+ >>> numpy.allclose(v2, numpy.dot(R, v3))
+ True
+
+ """
+ normal = unit_vector(normal[:3])
+ M = numpy.identity(4)
+ M[:3, :3] -= 2.0 * numpy.outer(normal, normal)
+ M[:3, 3] = (2.0 * numpy.dot(point[:3], normal)) * normal
+ return M
+
+
+def reflection_from_matrix(matrix):
+ """Return mirror plane point and normal vector from reflection matrix.
+
+ >>> v0 = numpy.random.random(3) - 0.5
+ >>> v1 = numpy.random.random(3) - 0.5
+ >>> M0 = reflection_matrix(v0, v1)
+ >>> point, normal = reflection_from_matrix(M0)
+ >>> M1 = reflection_matrix(point, normal)
+ >>> is_same_transform(M0, M1)
+ True
+
+ """
+ M = numpy.array(matrix, dtype=numpy.float64, copy=False)
+ # normal: unit eigenvector corresponding to eigenvalue -1
+ w, V = numpy.linalg.eig(M[:3, :3])
+ i = numpy.where(abs(numpy.real(w) + 1.0) < 1e-8)[0]
+ if not len(i):
+ raise ValueError("no unit eigenvector corresponding to eigenvalue -1")
+ normal = numpy.real(V[:, i[0]]).squeeze()
+ # point: any unit eigenvector corresponding to eigenvalue 1
+ w, V = numpy.linalg.eig(M)
+ i = numpy.where(abs(numpy.real(w) - 1.0) < 1e-8)[0]
+ if not len(i):
+ raise ValueError("no unit eigenvector corresponding to eigenvalue 1")
+ point = numpy.real(V[:, i[-1]]).squeeze()
+ point /= point[3]
+ return point, normal
+
+
+def rotation_matrix(angle, direction, point=None):
+ """Return matrix to rotate about axis defined by point and direction.
+
+ >>> R = rotation_matrix(math.pi/2, [0, 0, 1], [1, 0, 0])
+ >>> numpy.allclose(numpy.dot(R, [0, 0, 0, 1]), [1, -1, 0, 1])
+ True
+ >>> angle = (random.random() - 0.5) * (2*math.pi)
+ >>> direc = numpy.random.random(3) - 0.5
+ >>> point = numpy.random.random(3) - 0.5
+ >>> R0 = rotation_matrix(angle, direc, point)
+ >>> R1 = rotation_matrix(angle-2*math.pi, direc, point)
+ >>> is_same_transform(R0, R1)
+ True
+ >>> R0 = rotation_matrix(angle, direc, point)
+ >>> R1 = rotation_matrix(-angle, -direc, point)
+ >>> is_same_transform(R0, R1)
+ True
+ >>> I = numpy.identity(4, numpy.float64)
+ >>> numpy.allclose(I, rotation_matrix(math.pi*2, direc))
+ True
+ >>> numpy.allclose(2, numpy.trace(rotation_matrix(math.pi/2,
+ ... direc, point)))
+ True
+
+ """
+ sina = math.sin(angle)
+ cosa = math.cos(angle)
+ direction = unit_vector(direction[:3])
+ # rotation matrix around unit vector
+ R = numpy.diag([cosa, cosa, cosa])
+ R += numpy.outer(direction, direction) * (1.0 - cosa)
+ direction *= sina
+ R += numpy.array([[ 0.0, -direction[2], direction[1]],
+ [ direction[2], 0.0, -direction[0]],
+ [-direction[1], direction[0], 0.0]])
+ M = numpy.identity(4)
+ M[:3, :3] = R
+ if point is not None:
+ # rotation not around origin
+ point = numpy.array(point[:3], dtype=numpy.float64, copy=False)
+ M[:3, 3] = point - numpy.dot(R, point)
+ return M
+
+
+def rotation_from_matrix(matrix):
+ """Return rotation angle and axis from rotation matrix.
+
+ >>> angle = (random.random() - 0.5) * (2*math.pi)
+ >>> direc = numpy.random.random(3) - 0.5
+ >>> point = numpy.random.random(3) - 0.5
+ >>> R0 = rotation_matrix(angle, direc, point)
+ >>> angle, direc, point = rotation_from_matrix(R0)
+ >>> R1 = rotation_matrix(angle, direc, point)
+ >>> is_same_transform(R0, R1)
+ True
+
+ """
+ R = numpy.array(matrix, dtype=numpy.float64, copy=False)
+ R33 = R[:3, :3]
+ # direction: unit eigenvector of R33 corresponding to eigenvalue of 1
+ w, W = numpy.linalg.eig(R33.T)
+ i = numpy.where(abs(numpy.real(w) - 1.0) < 1e-8)[0]
+ if not len(i):
+ raise ValueError("no unit eigenvector corresponding to eigenvalue 1")
+ direction = numpy.real(W[:, i[-1]]).squeeze()
+ # point: unit eigenvector of R33 corresponding to eigenvalue of 1
+ w, Q = numpy.linalg.eig(R)
+ i = numpy.where(abs(numpy.real(w) - 1.0) < 1e-8)[0]
+ if not len(i):
+ raise ValueError("no unit eigenvector corresponding to eigenvalue 1")
+ point = numpy.real(Q[:, i[-1]]).squeeze()
+ point /= point[3]
+ # rotation angle depending on direction
+ cosa = (numpy.trace(R33) - 1.0) / 2.0
+ if abs(direction[2]) > 1e-8:
+ sina = (R[1, 0] + (cosa-1.0)*direction[0]*direction[1]) / direction[2]
+ elif abs(direction[1]) > 1e-8:
+ sina = (R[0, 2] + (cosa-1.0)*direction[0]*direction[2]) / direction[1]
+ else:
+ sina = (R[2, 1] + (cosa-1.0)*direction[1]*direction[2]) / direction[0]
+ angle = math.atan2(sina, cosa)
+ return angle, direction, point
+
+
+def scale_matrix(factor, origin=None, direction=None):
+ """Return matrix to scale by factor around origin in direction.
+
+ Use factor -1 for point symmetry.
+
+ >>> v = (numpy.random.rand(4, 5) - 0.5) * 20
+ >>> v[3] = 1
+ >>> S = scale_matrix(-1.234)
+ >>> numpy.allclose(numpy.dot(S, v)[:3], -1.234*v[:3])
+ True
+ >>> factor = random.random() * 10 - 5
+ >>> origin = numpy.random.random(3) - 0.5
+ >>> direct = numpy.random.random(3) - 0.5
+ >>> S = scale_matrix(factor, origin)
+ >>> S = scale_matrix(factor, origin, direct)
+
+ """
+ if direction is None:
+ # uniform scaling
+ M = numpy.diag([factor, factor, factor, 1.0])
+ if origin is not None:
+ M[:3, 3] = origin[:3]
+ M[:3, 3] *= 1.0 - factor
+ else:
+ # nonuniform scaling
+ direction = unit_vector(direction[:3])
+ factor = 1.0 - factor
+ M = numpy.identity(4)
+ M[:3, :3] -= factor * numpy.outer(direction, direction)
+ if origin is not None:
+ M[:3, 3] = (factor * numpy.dot(origin[:3], direction)) * direction
+ return M
+
+
+def scale_from_matrix(matrix):
+ """Return scaling factor, origin and direction from scaling matrix.
+
+ >>> factor = random.random() * 10 - 5
+ >>> origin = numpy.random.random(3) - 0.5
+ >>> direct = numpy.random.random(3) - 0.5
+ >>> S0 = scale_matrix(factor, origin)
+ >>> factor, origin, direction = scale_from_matrix(S0)
+ >>> S1 = scale_matrix(factor, origin, direction)
+ >>> is_same_transform(S0, S1)
+ True
+ >>> S0 = scale_matrix(factor, origin, direct)
+ >>> factor, origin, direction = scale_from_matrix(S0)
+ >>> S1 = scale_matrix(factor, origin, direction)
+ >>> is_same_transform(S0, S1)
+ True
+
+ """
+ M = numpy.array(matrix, dtype=numpy.float64, copy=False)
+ M33 = M[:3, :3]
+ factor = numpy.trace(M33) - 2.0
+ try:
+ # direction: unit eigenvector corresponding to eigenvalue factor
+ w, V = numpy.linalg.eig(M33)
+ i = numpy.where(abs(numpy.real(w) - factor) < 1e-8)[0][0]
+ direction = numpy.real(V[:, i]).squeeze()
+ direction /= vector_norm(direction)
+ except IndexError:
+ # uniform scaling
+ factor = (factor + 2.0) / 3.0
+ direction = None
+ # origin: any eigenvector corresponding to eigenvalue 1
+ w, V = numpy.linalg.eig(M)
+ i = numpy.where(abs(numpy.real(w) - 1.0) < 1e-8)[0]
+ if not len(i):
+ raise ValueError("no eigenvector corresponding to eigenvalue 1")
+ origin = numpy.real(V[:, i[-1]]).squeeze()
+ origin /= origin[3]
+ return factor, origin, direction
+
+
+def projection_matrix(point, normal, direction=None,
+ perspective=None, pseudo=False):
+ """Return matrix to project onto plane defined by point and normal.
+
+ Using either a perspective point, a projection direction, or neither.
+
+ If pseudo is True, perspective projections will preserve relative depth
+ such that Perspective = dot(Orthogonal, PseudoPerspective).
+
+ >>> P = projection_matrix([0, 0, 0], [1, 0, 0])
+ >>> numpy.allclose(P[1:, 1:], numpy.identity(4)[1:, 1:])
+ True
+ >>> point = numpy.random.random(3) - 0.5
+ >>> normal = numpy.random.random(3) - 0.5
+ >>> direct = numpy.random.random(3) - 0.5
+ >>> persp = numpy.random.random(3) - 0.5
+ >>> P0 = projection_matrix(point, normal)
+ >>> P1 = projection_matrix(point, normal, direction=direct)
+ >>> P2 = projection_matrix(point, normal, perspective=persp)
+ >>> P3 = projection_matrix(point, normal, perspective=persp, pseudo=True)
+ >>> is_same_transform(P2, numpy.dot(P0, P3))
+ True
+ >>> P = projection_matrix([3, 0, 0], [1, 1, 0], [1, 0, 0])
+ >>> v0 = (numpy.random.rand(4, 5) - 0.5) * 20
+ >>> v0[3] = 1
+ >>> v1 = numpy.dot(P, v0)
+ >>> numpy.allclose(v1[1], v0[1])
+ True
+ >>> numpy.allclose(v1[0], 3-v1[1])
+ True
+
+ """
+ M = numpy.identity(4)
+ point = numpy.array(point[:3], dtype=numpy.float64, copy=False)
+ normal = unit_vector(normal[:3])
+ if perspective is not None:
+ # perspective projection
+ perspective = numpy.array(perspective[:3], dtype=numpy.float64,
+ copy=False)
+ M[0, 0] = M[1, 1] = M[2, 2] = numpy.dot(perspective-point, normal)
+ M[:3, :3] -= numpy.outer(perspective, normal)
+ if pseudo:
+ # preserve relative depth
+ M[:3, :3] -= numpy.outer(normal, normal)
+ M[:3, 3] = numpy.dot(point, normal) * (perspective+normal)
+ else:
+ M[:3, 3] = numpy.dot(point, normal) * perspective
+ M[3, :3] = -normal
+ M[3, 3] = numpy.dot(perspective, normal)
+ elif direction is not None:
+ # parallel projection
+ direction = numpy.array(direction[:3], dtype=numpy.float64, copy=False)
+ scale = numpy.dot(direction, normal)
+ M[:3, :3] -= numpy.outer(direction, normal) / scale
+ M[:3, 3] = direction * (numpy.dot(point, normal) / scale)
+ else:
+ # orthogonal projection
+ M[:3, :3] -= numpy.outer(normal, normal)
+ M[:3, 3] = numpy.dot(point, normal) * normal
+ return M
+
+
+def projection_from_matrix(matrix, pseudo=False):
+ """Return projection plane and perspective point from projection matrix.
+
+ Return values are the same as the arguments of the projection_matrix function:
+ point, normal, direction, perspective, and pseudo.
+
+ >>> point = numpy.random.random(3) - 0.5
+ >>> normal = numpy.random.random(3) - 0.5
+ >>> direct = numpy.random.random(3) - 0.5
+ >>> persp = numpy.random.random(3) - 0.5
+ >>> P0 = projection_matrix(point, normal)
+ >>> result = projection_from_matrix(P0)
+ >>> P1 = projection_matrix(*result)
+ >>> is_same_transform(P0, P1)
+ True
+ >>> P0 = projection_matrix(point, normal, direct)
+ >>> result = projection_from_matrix(P0)
+ >>> P1 = projection_matrix(*result)
+ >>> is_same_transform(P0, P1)
+ True
+ >>> P0 = projection_matrix(point, normal, perspective=persp, pseudo=False)
+ >>> result = projection_from_matrix(P0, pseudo=False)
+ >>> P1 = projection_matrix(*result)
+ >>> is_same_transform(P0, P1)
+ True
+ >>> P0 = projection_matrix(point, normal, perspective=persp, pseudo=True)
+ >>> result = projection_from_matrix(P0, pseudo=True)
+ >>> P1 = projection_matrix(*result)
+ >>> is_same_transform(P0, P1)
+ True
+
+ """
+ M = numpy.array(matrix, dtype=numpy.float64, copy=False)
+ M33 = M[:3, :3]
+ w, V = numpy.linalg.eig(M)
+ i = numpy.where(abs(numpy.real(w) - 1.0) < 1e-8)[0]
+ if not pseudo and len(i):
+ # point: any eigenvector corresponding to eigenvalue 1
+ point = numpy.real(V[:, i[-1]]).squeeze()
+ point /= point[3]
+ # direction: unit eigenvector corresponding to eigenvalue 0
+ w, V = numpy.linalg.eig(M33)
+ i = numpy.where(abs(numpy.real(w)) < 1e-8)[0]
+ if not len(i):
+ raise ValueError("no eigenvector corresponding to eigenvalue 0")
+ direction = numpy.real(V[:, i[0]]).squeeze()
+ direction /= vector_norm(direction)
+ # normal: unit eigenvector of M33.T corresponding to eigenvalue 0
+ w, V = numpy.linalg.eig(M33.T)
+ i = numpy.where(abs(numpy.real(w)) < 1e-8)[0]
+ if len(i):
+ # parallel projection
+ normal = numpy.real(V[:, i[0]]).squeeze()
+ normal /= vector_norm(normal)
+ return point, normal, direction, None, False
+ else:
+ # orthogonal projection, where normal equals direction vector
+ return point, direction, None, None, False
+ else:
+ # perspective projection
+ i = numpy.where(abs(numpy.real(w)) > 1e-8)[0]
+ if not len(i):
+ raise ValueError(
+ "no eigenvector not corresponding to eigenvalue 0")
+ point = numpy.real(V[:, i[-1]]).squeeze()
+ point /= point[3]
+ normal = - M[3, :3]
+ perspective = M[:3, 3] / numpy.dot(point[:3], normal)
+ if pseudo:
+ perspective -= normal
+ return point, normal, None, perspective, pseudo
+
+
+def clip_matrix(left, right, bottom, top, near, far, perspective=False):
+ """Return matrix to obtain normalized device coordinates from frustum.
+
+ The frustum bounds are axis-aligned along x (left, right),
+ y (bottom, top) and z (near, far).
+
+ Normalized device coordinates are in range [-1, 1] if coordinates are
+ inside the frustum.
+
+ If perspective is True, the frustum is a truncated pyramid with the
+ perspective point at the origin and the view direction along the z axis;
+ otherwise an orthographic canonical view volume (a box) is used.
+
+ Homogeneous coordinates transformed by the perspective clip matrix
+ need to be dehomogenized (divided by w coordinate).
+
+ >>> frustum = numpy.random.rand(6)
+ >>> frustum[1] += frustum[0]
+ >>> frustum[3] += frustum[2]
+ >>> frustum[5] += frustum[4]
+ >>> M = clip_matrix(perspective=False, *frustum)
+ >>> numpy.dot(M, [frustum[0], frustum[2], frustum[4], 1])
+ array([-1., -1., -1., 1.])
+ >>> numpy.dot(M, [frustum[1], frustum[3], frustum[5], 1])
+ array([ 1., 1., 1., 1.])
+ >>> M = clip_matrix(perspective=True, *frustum)
+ >>> v = numpy.dot(M, [frustum[0], frustum[2], frustum[4], 1])
+ >>> v / v[3]
+ array([-1., -1., -1., 1.])
+ >>> v = numpy.dot(M, [frustum[1], frustum[3], frustum[4], 1])
+ >>> v / v[3]
+ array([ 1., 1., -1., 1.])
+
+ """
+ if left >= right or bottom >= top or near >= far:
+ raise ValueError("invalid frustum")
+ if perspective:
+ if near <= _EPS:
+ raise ValueError("invalid frustum: near <= 0")
+ t = 2.0 * near
+ M = [[t/(left-right), 0.0, (right+left)/(right-left), 0.0],
+ [0.0, t/(bottom-top), (top+bottom)/(top-bottom), 0.0],
+ [0.0, 0.0, (far+near)/(near-far), t*far/(far-near)],
+ [0.0, 0.0, -1.0, 0.0]]
+ else:
+ M = [[2.0/(right-left), 0.0, 0.0, (right+left)/(left-right)],
+ [0.0, 2.0/(top-bottom), 0.0, (top+bottom)/(bottom-top)],
+ [0.0, 0.0, 2.0/(far-near), (far+near)/(near-far)],
+ [0.0, 0.0, 0.0, 1.0]]
+ return numpy.array(M)
+
+
+def shear_matrix(angle, direction, point, normal):
+ """Return matrix to shear by angle along direction vector on shear plane.
+
+ The shear plane is defined by a point and normal vector. The direction
+ vector must be orthogonal to the plane's normal vector.
+
+ A point P is transformed by the shear matrix into P" such that
+ the vector P-P" is parallel to the direction vector and its extent is
+ given by the angle of P-P'-P", where P' is the orthogonal projection
+ of P onto the shear plane.
+
+ >>> angle = (random.random() - 0.5) * 4*math.pi
+ >>> direct = numpy.random.random(3) - 0.5
+ >>> point = numpy.random.random(3) - 0.5
+ >>> normal = numpy.cross(direct, numpy.random.random(3))
+ >>> S = shear_matrix(angle, direct, point, normal)
+ >>> numpy.allclose(1, numpy.linalg.det(S))
+ True
+
+ """
+ normal = unit_vector(normal[:3])
+ direction = unit_vector(direction[:3])
+ if abs(numpy.dot(normal, direction)) > 1e-6:
+ raise ValueError("direction and normal vectors are not orthogonal")
+ angle = math.tan(angle)
+ M = numpy.identity(4)
+ M[:3, :3] += angle * numpy.outer(direction, normal)
+ M[:3, 3] = -angle * numpy.dot(point[:3], normal) * direction
+ return M
+
+
+def shear_from_matrix(matrix):
+ """Return shear angle, direction and plane from shear matrix.
+
+ >>> angle = (random.random() - 0.5) * 4*math.pi
+ >>> direct = numpy.random.random(3) - 0.5
+ >>> point = numpy.random.random(3) - 0.5
+ >>> normal = numpy.cross(direct, numpy.random.random(3))
+ >>> S0 = shear_matrix(angle, direct, point, normal)
+ >>> angle, direct, point, normal = shear_from_matrix(S0)
+ >>> S1 = shear_matrix(angle, direct, point, normal)
+ >>> is_same_transform(S0, S1)
+ True
+
+ """
+ M = numpy.array(matrix, dtype=numpy.float64, copy=False)
+ M33 = M[:3, :3]
+ # normal: cross product of independent eigenvectors corresponding to eigenvalue 1
+ w, V = numpy.linalg.eig(M33)
+ i = numpy.where(abs(numpy.real(w) - 1.0) < 1e-4)[0]
+ if len(i) < 2:
+ raise ValueError("no two linearly independent eigenvectors found %s" % w)
+ V = numpy.real(V[:, i]).squeeze().T
+ lenorm = -1.0
+ for i0, i1 in ((0, 1), (0, 2), (1, 2)):
+ n = numpy.cross(V[i0], V[i1])
+ w = vector_norm(n)
+ if w > lenorm:
+ lenorm = w
+ normal = n
+ normal /= lenorm
+ # direction and angle
+ direction = numpy.dot(M33 - numpy.identity(3), normal)
+ angle = vector_norm(direction)
+ direction /= angle
+ angle = math.atan(angle)
+ # point: eigenvector corresponding to eigenvalue 1
+ w, V = numpy.linalg.eig(M)
+ i = numpy.where(abs(numpy.real(w) - 1.0) < 1e-8)[0]
+ if not len(i):
+ raise ValueError("no eigenvector corresponding to eigenvalue 1")
+ point = numpy.real(V[:, i[-1]]).squeeze()
+ point /= point[3]
+ return angle, direction, point, normal
+
+
+def decompose_matrix(matrix):
+ """Return sequence of transformations from transformation matrix.
+
+ matrix : array_like
+ Non-degenerate homogeneous transformation matrix
+
+ Return tuple of:
+ scale : vector of 3 scaling factors
+ shear : list of shear factors for x-y, x-z, y-z axes
+ angles : list of Euler angles about static x, y, z axes
+ translate : translation vector along x, y, z axes
+ perspective : perspective partition of matrix
+
+ Raise ValueError if matrix is of wrong type or degenerate.
+
+ >>> T0 = translation_matrix([1, 2, 3])
+ >>> scale, shear, angles, trans, persp = decompose_matrix(T0)
+ >>> T1 = translation_matrix(trans)
+ >>> numpy.allclose(T0, T1)
+ True
+ >>> S = scale_matrix(0.123)
+ >>> scale, shear, angles, trans, persp = decompose_matrix(S)
+ >>> scale[0]
+ 0.123
+ >>> R0 = euler_matrix(1, 2, 3)
+ >>> scale, shear, angles, trans, persp = decompose_matrix(R0)
+ >>> R1 = euler_matrix(*angles)
+ >>> numpy.allclose(R0, R1)
+ True
+
+ """
+ M = numpy.array(matrix, dtype=numpy.float64, copy=True).T
+ if abs(M[3, 3]) < _EPS:
+ raise ValueError("M[3, 3] is zero")
+ M /= M[3, 3]
+ P = M.copy()
+ P[:, 3] = 0.0, 0.0, 0.0, 1.0
+ if not numpy.linalg.det(P):
+ raise ValueError("matrix is singular")
+
+ scale = numpy.zeros((3, ))
+ shear = [0.0, 0.0, 0.0]
+ angles = [0.0, 0.0, 0.0]
+
+ if any(abs(M[:3, 3]) > _EPS):
+ perspective = numpy.dot(M[:, 3], numpy.linalg.inv(P.T))
+ M[:, 3] = 0.0, 0.0, 0.0, 1.0
+ else:
+ perspective = numpy.array([0.0, 0.0, 0.0, 1.0])
+
+ translate = M[3, :3].copy()
+ M[3, :3] = 0.0
+
+ row = M[:3, :3].copy()
+ scale[0] = vector_norm(row[0])
+ row[0] /= scale[0]
+ shear[0] = numpy.dot(row[0], row[1])
+ row[1] -= row[0] * shear[0]
+ scale[1] = vector_norm(row[1])
+ row[1] /= scale[1]
+ shear[0] /= scale[1]
+ shear[1] = numpy.dot(row[0], row[2])
+ row[2] -= row[0] * shear[1]
+ shear[2] = numpy.dot(row[1], row[2])
+ row[2] -= row[1] * shear[2]
+ scale[2] = vector_norm(row[2])
+ row[2] /= scale[2]
+ shear[1:] /= scale[2]
+
+ if numpy.dot(row[0], numpy.cross(row[1], row[2])) < 0:
+ numpy.negative(scale, scale)
+ numpy.negative(row, row)
+
+ angles[1] = math.asin(-row[0, 2])
+ if math.cos(angles[1]):
+ angles[0] = math.atan2(row[1, 2], row[2, 2])
+ angles[2] = math.atan2(row[0, 1], row[0, 0])
+ else:
+ #angles[0] = math.atan2(row[1, 0], row[1, 1])
+ angles[0] = math.atan2(-row[2, 1], row[1, 1])
+ angles[2] = 0.0
+
+ return scale, shear, angles, translate, perspective
+
+
+def compose_matrix(scale=None, shear=None, angles=None, translate=None,
+ perspective=None):
+ """Return transformation matrix from sequence of transformations.
+
+ This is the inverse of the decompose_matrix function.
+
+ Sequence of transformations:
+ scale : vector of 3 scaling factors
+ shear : list of shear factors for x-y, x-z, y-z axes
+ angles : list of Euler angles about static x, y, z axes
+ translate : translation vector along x, y, z axes
+ perspective : perspective partition of matrix
+
+ >>> scale = numpy.random.random(3) - 0.5
+ >>> shear = numpy.random.random(3) - 0.5
+ >>> angles = (numpy.random.random(3) - 0.5) * (2*math.pi)
+ >>> trans = numpy.random.random(3) - 0.5
+ >>> persp = numpy.random.random(4) - 0.5
+ >>> M0 = compose_matrix(scale, shear, angles, trans, persp)
+ >>> result = decompose_matrix(M0)
+ >>> M1 = compose_matrix(*result)
+ >>> is_same_transform(M0, M1)
+ True
+
+ """
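+ # The factors below are composed in the fixed order
+ # M = Perspective @ Translate @ Rotate @ Shear @ Scale,
+ # which mirrors the decomposition performed by decompose_matrix.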
+ M = numpy.identity(4)
+ if perspective is not None:
+ P = numpy.identity(4)
+ P[3, :] = perspective[:4]
+ M = numpy.dot(M, P)
+ if translate is not None:
+ T = numpy.identity(4)
+ T[:3, 3] = translate[:3]
+ M = numpy.dot(M, T)
+ if angles is not None:
+ R = euler_matrix(angles[0], angles[1], angles[2], 'sxyz')
+ M = numpy.dot(M, R)
+ if shear is not None:
+ Z = numpy.identity(4)
+ Z[1, 2] = shear[2]
+ Z[0, 2] = shear[1]
+ Z[0, 1] = shear[0]
+ M = numpy.dot(M, Z)
+ if scale is not None:
+ S = numpy.identity(4)
+ S[0, 0] = scale[0]
+ S[1, 1] = scale[1]
+ S[2, 2] = scale[2]
+ M = numpy.dot(M, S)
+ M /= M[3, 3]
+ return M
+
+
+def orthogonalization_matrix(lengths, angles):
+ """Return orthogonalization matrix for crystallographic cell coordinates.
+
+ Angles are expected in degrees.
+
+ The de-orthogonalization matrix is the inverse.
+
+ >>> O = orthogonalization_matrix([10, 10, 10], [90, 90, 90])
+ >>> numpy.allclose(O[:3, :3], numpy.identity(3, float) * 10)
+ True
+ >>> O = orthogonalization_matrix([9.8, 12.0, 15.5], [87.2, 80.7, 69.7])
+ >>> numpy.allclose(numpy.sum(O), 43.063229)
+ True
+
+ """
+ a, b, c = lengths
+ angles = numpy.radians(angles)
+ sina, sinb, _ = numpy.sin(angles)
+ cosa, cosb, cosg = numpy.cos(angles)
+ co = (cosa * cosb - cosg) / (sina * sinb)
+ return numpy.array([
+ [ a*sinb*math.sqrt(1.0-co*co), 0.0, 0.0, 0.0],
+ [-a*sinb*co, b*sina, 0.0, 0.0],
+ [ a*cosb, b*cosa, c, 0.0],
+ [ 0.0, 0.0, 0.0, 1.0]])
+
+
+def affine_matrix_from_points(v0, v1, shear=True, scale=True, usesvd=True):
+ """Return affine transform matrix to register two point sets.
+
+ v0 and v1 are shape (ndims, \*) arrays of at least ndims non-homogeneous
+ coordinates, where ndims is the dimensionality of the coordinate space.
+
+ If shear is False, a similarity transformation matrix is returned.
+ If also scale is False, a rigid/Euclidean transformation matrix
+ is returned.
+
+ By default the algorithm by Hartley and Zisserman [15] is used.
+ If usesvd is True, similarity and Euclidean transformation matrices
+ are calculated by minimizing the weighted sum of squared deviations
+ (RMSD) according to the algorithm by Kabsch [8].
+ Otherwise, and if ndims is 3, the quaternion-based algorithm by Horn [9]
+ is used, which is slower when using this Python implementation.
+
+ The returned matrix performs rotation, translation and uniform scaling
+ (if specified).
+
+ >>> v0 = [[0, 1031, 1031, 0], [0, 0, 1600, 1600]]
+ >>> v1 = [[675, 826, 826, 677], [55, 52, 281, 277]]
+ >>> affine_matrix_from_points(v0, v1)
+ array([[ 0.14549, 0.00062, 675.50008],
+ [ 0.00048, 0.14094, 53.24971],
+ [ 0. , 0. , 1. ]])
+ >>> T = translation_matrix(numpy.random.random(3)-0.5)
+ >>> R = random_rotation_matrix(numpy.random.random(3))
+ >>> S = scale_matrix(random.random())
+ >>> M = concatenate_matrices(T, R, S)
+ >>> v0 = (numpy.random.rand(4, 100) - 0.5) * 20
+ >>> v0[3] = 1
+ >>> v1 = numpy.dot(M, v0)
+ >>> v0[:3] += numpy.random.normal(0, 1e-8, 300).reshape(3, -1)
+ >>> M = affine_matrix_from_points(v0[:3], v1[:3])
+ >>> numpy.allclose(v1, numpy.dot(M, v0))
+ True
+
+ More examples in superimposition_matrix()
+
+ """
+ v0 = numpy.array(v0, dtype=numpy.float64, copy=True)
+ v1 = numpy.array(v1, dtype=numpy.float64, copy=True)
+
+ ndims = v0.shape[0]
+ if ndims < 2 or v0.shape[1] < ndims or v0.shape != v1.shape:
+ raise ValueError("input arrays are of wrong shape or type")
+
+ # move centroids to origin
+ t0 = -numpy.mean(v0, axis=1)
+ M0 = numpy.identity(ndims+1)
+ M0[:ndims, ndims] = t0
+ v0 += t0.reshape(ndims, 1)
+ t1 = -numpy.mean(v1, axis=1)
+ M1 = numpy.identity(ndims+1)
+ M1[:ndims, ndims] = t1
+ v1 += t1.reshape(ndims, 1)
+
+ if shear:
+ # Affine transformation
+ A = numpy.concatenate((v0, v1), axis=0)
+ u, s, vh = numpy.linalg.svd(A.T)
+ vh = vh[:ndims].T
+ B = vh[:ndims]
+ C = vh[ndims:2*ndims]
+ t = numpy.dot(C, numpy.linalg.pinv(B))
+ t = numpy.concatenate((t, numpy.zeros((ndims, 1))), axis=1)
+ M = numpy.vstack((t, ((0.0,)*ndims) + (1.0,)))
+ elif usesvd or ndims != 3:
+ # Rigid transformation via SVD of covariance matrix
+ u, s, vh = numpy.linalg.svd(numpy.dot(v1, v0.T))
+ # rotation matrix from SVD orthonormal bases
+ R = numpy.dot(u, vh)
+ if numpy.linalg.det(R) < 0.0:
+ # R does not constitute right handed system
+ R -= numpy.outer(u[:, ndims-1], vh[ndims-1, :]*2.0)
+ s[-1] *= -1.0
+ # homogeneous transformation matrix
+ M = numpy.identity(ndims+1)
+ M[:ndims, :ndims] = R
+ else:
+ # Rigid transformation matrix via quaternion
+ # compute symmetric matrix N
+ xx, yy, zz = numpy.sum(v0 * v1, axis=1)
+ xy, yz, zx = numpy.sum(v0 * numpy.roll(v1, -1, axis=0), axis=1)
+ xz, yx, zy = numpy.sum(v0 * numpy.roll(v1, -2, axis=0), axis=1)
+ N = [[xx+yy+zz, 0.0, 0.0, 0.0],
+ [yz-zy, xx-yy-zz, 0.0, 0.0],
+ [zx-xz, xy+yx, yy-xx-zz, 0.0],
+ [xy-yx, zx+xz, yz+zy, zz-xx-yy]]
+ # quaternion: eigenvector corresponding to most positive eigenvalue
+ w, V = numpy.linalg.eigh(N)
+ q = V[:, numpy.argmax(w)]
+ q /= vector_norm(q) # unit quaternion
+ # homogeneous transformation matrix
+ M = quaternion_matrix(q)
+
+ if scale and not shear:
+ # Affine transformation; scale is ratio of RMS deviations from centroid
+ v0 *= v0
+ v1 *= v1
+ M[:ndims, :ndims] *= math.sqrt(numpy.sum(v1) / numpy.sum(v0))
+
+ # move centroids back
+ M = numpy.dot(numpy.linalg.inv(M1), numpy.dot(M, M0))
+ M /= M[ndims, ndims]
+ return M
+
+
+def superimposition_matrix(v0, v1, scale=False, usesvd=True):
+ """Return matrix to transform given 3D point set into second point set.
+
+ v0 and v1 are shape (3, \*) or (4, \*) arrays of at least 3 points.
+
+ The parameters scale and usesvd are explained in the more general
+ affine_matrix_from_points function.
+
+ The returned matrix is a similarity or Euclidean transformation matrix.
+ This function has a fast C implementation in transformations.c.
+
+ >>> v0 = numpy.random.rand(3, 10)
+ >>> M = superimposition_matrix(v0, v0)
+ >>> numpy.allclose(M, numpy.identity(4))
+ True
+ >>> R = random_rotation_matrix(numpy.random.random(3))
+ >>> v0 = [[1,0,0], [0,1,0], [0,0,1], [1,1,1]]
+ >>> v1 = numpy.dot(R, v0)
+ >>> M = superimposition_matrix(v0, v1)
+ >>> numpy.allclose(v1, numpy.dot(M, v0))
+ True
+ >>> v0 = (numpy.random.rand(4, 100) - 0.5) * 20
+ >>> v0[3] = 1
+ >>> v1 = numpy.dot(R, v0)
+ >>> M = superimposition_matrix(v0, v1)
+ >>> numpy.allclose(v1, numpy.dot(M, v0))
+ True
+ >>> S = scale_matrix(random.random())
+ >>> T = translation_matrix(numpy.random.random(3)-0.5)
+ >>> M = concatenate_matrices(T, R, S)
+ >>> v1 = numpy.dot(M, v0)
+ >>> v0[:3] += numpy.random.normal(0, 1e-9, 300).reshape(3, -1)
+ >>> M = superimposition_matrix(v0, v1, scale=True)
+ >>> numpy.allclose(v1, numpy.dot(M, v0))
+ True
+ >>> M = superimposition_matrix(v0, v1, scale=True, usesvd=False)
+ >>> numpy.allclose(v1, numpy.dot(M, v0))
+ True
+ >>> v = numpy.empty((4, 100, 3))
+ >>> v[:, :, 0] = v0
+ >>> M = superimposition_matrix(v0, v1, scale=True, usesvd=False)
+ >>> numpy.allclose(v1, numpy.dot(M, v[:, :, 0]))
+ True
+
+ """
+ v0 = numpy.array(v0, dtype=numpy.float64, copy=False)[:3]
+ v1 = numpy.array(v1, dtype=numpy.float64, copy=False)[:3]
+ return affine_matrix_from_points(v0, v1, shear=False,
+ scale=scale, usesvd=usesvd)
+
+
+def euler_matrix(ai, aj, ak, axes='sxyz'):
+ """Return homogeneous rotation matrix from Euler angles and axis sequence.
+
+ ai, aj, ak : Euler's roll, pitch and yaw angles
+ axes : One of 24 axis sequences as string or encoded tuple
+
+ >>> R = euler_matrix(1, 2, 3, 'syxz')
+ >>> numpy.allclose(numpy.sum(R[0]), -1.34786452)
+ True
+ >>> R = euler_matrix(1, 2, 3, (0, 1, 0, 1))
+ >>> numpy.allclose(numpy.sum(R[0]), -0.383436184)
+ True
+ >>> ai, aj, ak = (4*math.pi) * (numpy.random.random(3) - 0.5)
+ >>> for axes in _AXES2TUPLE.keys():
+ ... R = euler_matrix(ai, aj, ak, axes)
+ >>> for axes in _TUPLE2AXES.keys():
+ ... R = euler_matrix(ai, aj, ak, axes)
+
+ """
+ try:
+ firstaxis, parity, repetition, frame = _AXES2TUPLE[axes]
+ except (AttributeError, KeyError):
+ _TUPLE2AXES[axes] # validation
+ firstaxis, parity, repetition, frame = axes
+
+ i = firstaxis
+ j = _NEXT_AXIS[i+parity]
+ k = _NEXT_AXIS[i-parity+1]
+
+ if frame:
+ ai, ak = ak, ai
+ if parity:
+ ai, aj, ak = -ai, -aj, -ak
+
+ si, sj, sk = math.sin(ai), math.sin(aj), math.sin(ak)
+ ci, cj, ck = math.cos(ai), math.cos(aj), math.cos(ak)
+ cc, cs = ci*ck, ci*sk
+ sc, ss = si*ck, si*sk
+
+ M = numpy.identity(4)
+ if repetition:
+ M[i, i] = cj
+ M[i, j] = sj*si
+ M[i, k] = sj*ci
+ M[j, i] = sj*sk
+ M[j, j] = -cj*ss+cc
+ M[j, k] = -cj*cs-sc
+ M[k, i] = -sj*ck
+ M[k, j] = cj*sc+cs
+ M[k, k] = cj*cc-ss
+ else:
+ M[i, i] = cj*ck
+ M[i, j] = sj*sc-cs
+ M[i, k] = sj*cc+ss
+ M[j, i] = cj*sk
+ M[j, j] = sj*ss+cc
+ M[j, k] = sj*cs-sc
+ M[k, i] = -sj
+ M[k, j] = cj*si
+ M[k, k] = cj*ci
+ return M
+
+
+def euler_from_matrix(matrix, axes='sxyz'):
+ """Return Euler angles from rotation matrix for specified axis sequence.
+
+ axes : One of 24 axis sequences as string or encoded tuple
+
+ Note that many Euler angle triplets can describe one matrix.
+
+ >>> R0 = euler_matrix(1, 2, 3, 'syxz')
+ >>> al, be, ga = euler_from_matrix(R0, 'syxz')
+ >>> R1 = euler_matrix(al, be, ga, 'syxz')
+ >>> numpy.allclose(R0, R1)
+ True
+ >>> angles = (4*math.pi) * (numpy.random.random(3) - 0.5)
+ >>> for axes in _AXES2TUPLE.keys():
+ ... R0 = euler_matrix(axes=axes, *angles)
+ ... R1 = euler_matrix(axes=axes, *euler_from_matrix(R0, axes))
+ ... if not numpy.allclose(R0, R1): print(axes, "failed")
+
+ """
+ try:
+ firstaxis, parity, repetition, frame = _AXES2TUPLE[axes.lower()]
+ except (AttributeError, KeyError):
+ _TUPLE2AXES[axes] # validation
+ firstaxis, parity, repetition, frame = axes
+
+ i = firstaxis
+ j = _NEXT_AXIS[i+parity]
+ k = _NEXT_AXIS[i-parity+1]
+
+ M = numpy.array(matrix, dtype=numpy.float64, copy=False)[:3, :3]
+ if repetition:
+ sy = math.sqrt(M[i, j]*M[i, j] + M[i, k]*M[i, k])
+ if sy > _EPS:
+ ax = math.atan2( M[i, j], M[i, k])
+ ay = math.atan2( sy, M[i, i])
+ az = math.atan2( M[j, i], -M[k, i])
+ else:
+ ax = math.atan2(-M[j, k], M[j, j])
+ ay = math.atan2( sy, M[i, i])
+ az = 0.0
+ else:
+ cy = math.sqrt(M[i, i]*M[i, i] + M[j, i]*M[j, i])
+ if cy > _EPS:
+ ax = math.atan2( M[k, j], M[k, k])
+ ay = math.atan2(-M[k, i], cy)
+ az = math.atan2( M[j, i], M[i, i])
+ else:
+ ax = math.atan2(-M[j, k], M[j, j])
+ ay = math.atan2(-M[k, i], cy)
+ az = 0.0
+
+ if parity:
+ ax, ay, az = -ax, -ay, -az
+ if frame:
+ ax, az = az, ax
+ return ax, ay, az
+
+
+def euler_from_quaternion(quaternion, axes='sxyz'):
+ """Return Euler angles from quaternion for specified axis sequence.
+
+ >>> angles = euler_from_quaternion([0.99810947, 0.06146124, 0, 0])
+ >>> numpy.allclose(angles, [0.123, 0, 0])
+ True
+
+ """
+ return euler_from_matrix(quaternion_matrix(quaternion), axes)
+
+
+def quaternion_from_euler(ai, aj, ak, axes='sxyz'):
+ """Return quaternion from Euler angles and axis sequence.
+
+ ai, aj, ak : Euler's roll, pitch and yaw angles
+ axes : One of 24 axis sequences as string or encoded tuple
+
+ >>> q = quaternion_from_euler(1, 2, 3, 'ryxz')
+ >>> numpy.allclose(q, [0.435953, 0.310622, -0.718287, 0.444435])
+ True
+
+ """
+ try:
+ firstaxis, parity, repetition, frame = _AXES2TUPLE[axes.lower()]
+ except (AttributeError, KeyError):
+ _TUPLE2AXES[axes] # validation
+ firstaxis, parity, repetition, frame = axes
+
+ i = firstaxis + 1
+ j = _NEXT_AXIS[i+parity-1] + 1
+ k = _NEXT_AXIS[i-parity] + 1
+
+ if frame:
+ ai, ak = ak, ai
+ if parity:
+ aj = -aj
+
+ ai /= 2.0
+ aj /= 2.0
+ ak /= 2.0
+ ci = math.cos(ai)
+ si = math.sin(ai)
+ cj = math.cos(aj)
+ sj = math.sin(aj)
+ ck = math.cos(ak)
+ sk = math.sin(ak)
+ cc = ci*ck
+ cs = ci*sk
+ sc = si*ck
+ ss = si*sk
+
+ q = numpy.empty((4, ))
+ if repetition:
+ q[0] = cj*(cc - ss)
+ q[i] = cj*(cs + sc)
+ q[j] = sj*(cc + ss)
+ q[k] = sj*(cs - sc)
+ else:
+ q[0] = cj*cc + sj*ss
+ q[i] = cj*sc - sj*cs
+ q[j] = cj*ss + sj*cc
+ q[k] = cj*cs - sj*sc
+ if parity:
+ q[j] *= -1.0
+
+ return q
+
+
+def quaternion_about_axis(angle, axis):
+ """Return quaternion for rotation about axis.
+
+ >>> q = quaternion_about_axis(0.123, [1, 0, 0])
+ >>> numpy.allclose(q, [0.99810947, 0.06146124, 0, 0])
+ True
+
+ """
+ q = numpy.array([0.0, axis[0], axis[1], axis[2]])
+ qlen = vector_norm(q)
+ if qlen > _EPS:
+ q *= math.sin(angle/2.0) / qlen
+ q[0] = math.cos(angle/2.0)
+ return q
+
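+# Convention used throughout this module: quaternions are stored in
+# [w, x, y, z] order (scalar part first), as quaternion_about_axis above and
+# quaternion_real/quaternion_imag below assume.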
+
+def quaternion_matrix(quaternion):
+ """Return homogeneous rotation matrix from quaternion.
+
+ >>> M = quaternion_matrix([0.99810947, 0.06146124, 0, 0])
+ >>> numpy.allclose(M, rotation_matrix(0.123, [1, 0, 0]))
+ True
+ >>> M = quaternion_matrix([1, 0, 0, 0])
+ >>> numpy.allclose(M, numpy.identity(4))
+ True
+ >>> M = quaternion_matrix([0, 1, 0, 0])
+ >>> numpy.allclose(M, numpy.diag([1, -1, -1, 1]))
+ True
+
+ """
+ q = numpy.array(quaternion, dtype=numpy.float64, copy=True)
+ n = numpy.dot(q, q)
+ if n < _EPS:
+ return numpy.identity(4)
+ q *= math.sqrt(2.0 / n)
+ q = numpy.outer(q, q)
+ return numpy.array([
+ [1.0-q[2, 2]-q[3, 3], q[1, 2]-q[3, 0], q[1, 3]+q[2, 0], 0.0],
+ [ q[1, 2]+q[3, 0], 1.0-q[1, 1]-q[3, 3], q[2, 3]-q[1, 0], 0.0],
+ [ q[1, 3]-q[2, 0], q[2, 3]+q[1, 0], 1.0-q[1, 1]-q[2, 2], 0.0],
+ [ 0.0, 0.0, 0.0, 1.0]])
+
+
+def quaternion_from_matrix(matrix, isprecise=False):
+ """Return quaternion from rotation matrix.
+
+ If isprecise is True, the input matrix is assumed to be a precise rotation
+ matrix and a faster algorithm is used.
+
+ >>> q = quaternion_from_matrix(numpy.identity(4), True)
+ >>> numpy.allclose(q, [1, 0, 0, 0])
+ True
+ >>> q = quaternion_from_matrix(numpy.diag([1, -1, -1, 1]))
+ >>> numpy.allclose(q, [0, 1, 0, 0]) or numpy.allclose(q, [0, -1, 0, 0])
+ True
+ >>> R = rotation_matrix(0.123, (1, 2, 3))
+ >>> q = quaternion_from_matrix(R, True)
+ >>> numpy.allclose(q, [0.9981095, 0.0164262, 0.0328524, 0.0492786])
+ True
+ >>> R = [[-0.545, 0.797, 0.260, 0], [0.733, 0.603, -0.313, 0],
+ ... [-0.407, 0.021, -0.913, 0], [0, 0, 0, 1]]
+ >>> q = quaternion_from_matrix(R)
+ >>> numpy.allclose(q, [0.19069, 0.43736, 0.87485, -0.083611])
+ True
+ >>> R = [[0.395, 0.362, 0.843, 0], [-0.626, 0.796, -0.056, 0],
+ ... [-0.677, -0.498, 0.529, 0], [0, 0, 0, 1]]
+ >>> q = quaternion_from_matrix(R)
+ >>> numpy.allclose(q, [0.82336615, -0.13610694, 0.46344705, -0.29792603])
+ True
+ >>> R = random_rotation_matrix()
+ >>> q = quaternion_from_matrix(R)
+ >>> is_same_transform(R, quaternion_matrix(q))
+ True
+ >>> R = euler_matrix(0.0, 0.0, numpy.pi/2.0)
+ >>> numpy.allclose(quaternion_from_matrix(R, isprecise=False),
+ ... quaternion_from_matrix(R, isprecise=True))
+ True
+
+ """
+ M = numpy.array(matrix, dtype=numpy.float64, copy=False)[:4, :4]
+ if isprecise:
+ q = numpy.empty((4, ))
+ t = numpy.trace(M)
+ if t > M[3, 3]:
+ q[0] = t
+ q[3] = M[1, 0] - M[0, 1]
+ q[2] = M[0, 2] - M[2, 0]
+ q[1] = M[2, 1] - M[1, 2]
+ else:
+ i, j, k = 1, 2, 3
+ if M[1, 1] > M[0, 0]:
+ i, j, k = 2, 3, 1
+ if M[2, 2] > M[i, i]:
+ i, j, k = 3, 1, 2
+ t = M[i, i] - (M[j, j] + M[k, k]) + M[3, 3]
+ q[i] = t
+ q[j] = M[i, j] + M[j, i]
+ q[k] = M[k, i] + M[i, k]
+ q[3] = M[k, j] - M[j, k]
+ q *= 0.5 / math.sqrt(t * M[3, 3])
+ else:
+ m00 = M[0, 0]
+ m01 = M[0, 1]
+ m02 = M[0, 2]
+ m10 = M[1, 0]
+ m11 = M[1, 1]
+ m12 = M[1, 2]
+ m20 = M[2, 0]
+ m21 = M[2, 1]
+ m22 = M[2, 2]
+ # symmetric matrix K
+ K = numpy.array([[m00-m11-m22, 0.0, 0.0, 0.0],
+ [m01+m10, m11-m00-m22, 0.0, 0.0],
+ [m02+m20, m12+m21, m22-m00-m11, 0.0],
+ [m21-m12, m02-m20, m10-m01, m00+m11+m22]])
+ K /= 3.0
+ # quaternion is eigenvector of K that corresponds to largest eigenvalue
+ w, V = numpy.linalg.eigh(K)
+ q = V[[3, 0, 1, 2], numpy.argmax(w)]
+ if q[0] < 0.0:
+ numpy.negative(q, q)
+ return q
+
+
+def quaternion_multiply(quaternion1, quaternion0):
+ """Return multiplication of two quaternions.
+
+ >>> q = quaternion_multiply([4, 1, -2, 3], [8, -5, 6, 7])
+ >>> numpy.allclose(q, [28, -44, -14, 48])
+ True
+
+ """
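+ # Hamilton product; note the argument order: the result represents the
+ # rotation of quaternion0 followed by that of quaternion1.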
+ w0, x0, y0, z0 = quaternion0
+ w1, x1, y1, z1 = quaternion1
+ return numpy.array([-x1*x0 - y1*y0 - z1*z0 + w1*w0,
+ x1*w0 + y1*z0 - z1*y0 + w1*x0,
+ -x1*z0 + y1*w0 + z1*x0 + w1*y0,
+ x1*y0 - y1*x0 + z1*w0 + w1*z0], dtype=numpy.float64)
+
+
+def quaternion_conjugate(quaternion):
+ """Return conjugate of quaternion.
+
+ >>> q0 = random_quaternion()
+ >>> q1 = quaternion_conjugate(q0)
+ >>> q1[0] == q0[0] and all(q1[1:] == -q0[1:])
+ True
+
+ """
+ q = numpy.array(quaternion, dtype=numpy.float64, copy=True)
+ numpy.negative(q[1:], q[1:])
+ return q
+
+
+def quaternion_inverse(quaternion):
+ """Return inverse of quaternion.
+
+ >>> q0 = random_quaternion()
+ >>> q1 = quaternion_inverse(q0)
+ >>> numpy.allclose(quaternion_multiply(q0, q1), [1, 0, 0, 0])
+ True
+
+ """
+ q = numpy.array(quaternion, dtype=numpy.float64, copy=True)
+ numpy.negative(q[1:], q[1:])
+ return q / numpy.dot(q, q)
+
+
+def quaternion_real(quaternion):
+ """Return real part of quaternion.
+
+ >>> quaternion_real([3, 0, 1, 2])
+ 3.0
+
+ """
+ return float(quaternion[0])
+
+
+def quaternion_imag(quaternion):
+ """Return imaginary part of quaternion.
+
+ >>> quaternion_imag([3, 0, 1, 2])
+ array([ 0., 1., 2.])
+
+ """
+ return numpy.array(quaternion[1:4], dtype=numpy.float64, copy=True)
+
+
+def quaternion_slerp(quat0, quat1, fraction, spin=0, shortestpath=True):
+ """Return spherical linear interpolation between two quaternions.
+
+ >>> q0 = random_quaternion()
+ >>> q1 = random_quaternion()
+ >>> q = quaternion_slerp(q0, q1, 0)
+ >>> numpy.allclose(q, q0)
+ True
+ >>> q = quaternion_slerp(q0, q1, 1, 1)
+ >>> numpy.allclose(q, q1)
+ True
+ >>> q = quaternion_slerp(q0, q1, 0.5)
+ >>> angle = math.acos(numpy.dot(q0, q))
+ >>> numpy.allclose(2, math.acos(numpy.dot(q0, q1)) / angle) or \
+ numpy.allclose(2, math.acos(-numpy.dot(q0, q1)) / angle)
+ True
+
+ """
+ q0 = unit_vector(quat0[:4])
+ q1 = unit_vector(quat1[:4])
+ if fraction == 0.0:
+ return q0
+ elif fraction == 1.0:
+ return q1
+ d = numpy.dot(q0, q1)
+ if abs(abs(d) - 1.0) < _EPS:
+ return q0
+ if shortestpath and d < 0.0:
+ # invert rotation
+ d = -d
+ numpy.negative(q1, q1)
+ angle = math.acos(d) + spin * math.pi
+ if abs(angle) < _EPS:
+ return q0
+ isin = 1.0 / math.sin(angle)
+ q0 *= math.sin((1.0 - fraction) * angle) * isin
+ q1 *= math.sin(fraction * angle) * isin
+ q0 += q1
+ return q0
+
+
+def random_quaternion(rand=None):
+ """Return uniform random unit quaternion.
+
+ rand: array like or None
+ Three independent random variables that are uniformly distributed
+ between 0 and 1.
+
+ >>> q = random_quaternion()
+ >>> numpy.allclose(1, vector_norm(q))
+ True
+ >>> q = random_quaternion(numpy.random.random(3))
+ >>> len(q.shape), q.shape[0]==4
+ (1, True)
+
+ """
+ if rand is None:
+ rand = numpy.random.rand(3)
+ else:
+ assert len(rand) == 3
+ r1 = numpy.sqrt(1.0 - rand[0])
+ r2 = numpy.sqrt(rand[0])
+ pi2 = math.pi * 2.0
+ t1 = pi2 * rand[1]
+ t2 = pi2 * rand[2]
+ return numpy.array([numpy.cos(t2)*r2, numpy.sin(t1)*r1,
+ numpy.cos(t1)*r1, numpy.sin(t2)*r2])
+
+
+def random_rotation_matrix(rand=None):
+ """Return uniform random rotation matrix.
+
+ rand: array like or None
+ Three independent random variables that are uniformly distributed
+ between 0 and 1; they determine the underlying random quaternion.
+
+ >>> R = random_rotation_matrix()
+ >>> numpy.allclose(numpy.dot(R.T, R), numpy.identity(4))
+ True
+
+ """
+ return quaternion_matrix(random_quaternion(rand))
+
+
+class Arcball(object):
+ """Virtual Trackball Control.
+
+ >>> ball = Arcball()
+ >>> ball = Arcball(initial=numpy.identity(4))
+ >>> ball.place([320, 320], 320)
+ >>> ball.down([500, 250])
+ >>> ball.drag([475, 275])
+ >>> R = ball.matrix()
+ >>> numpy.allclose(numpy.sum(R), 3.90583455)
+ True
+ >>> ball = Arcball(initial=[1, 0, 0, 0])
+ >>> ball.place([320, 320], 320)
+ >>> ball.setaxes([1, 1, 0], [-1, 1, 0])
+ >>> ball.constrain = True
+ >>> ball.down([400, 200])
+ >>> ball.drag([200, 400])
+ >>> R = ball.matrix()
+ >>> numpy.allclose(numpy.sum(R), 0.2055924)
+ True
+ >>> ball.next()
+
+ """
+ def __init__(self, initial=None):
+ """Initialize virtual trackball control.
+
+ initial : quaternion or rotation matrix
+
+ """
+ self._axis = None
+ self._axes = None
+ self._radius = 1.0
+ self._center = [0.0, 0.0]
+ self._vdown = numpy.array([0.0, 0.0, 1.0])
+ self._constrain = False
+ if initial is None:
+ self._qdown = numpy.array([1.0, 0.0, 0.0, 0.0])
+ else:
+ initial = numpy.array(initial, dtype=numpy.float64)
+ if initial.shape == (4, 4):
+ self._qdown = quaternion_from_matrix(initial)
+ elif initial.shape == (4, ):
+ initial /= vector_norm(initial)
+ self._qdown = initial
+ else:
+ raise ValueError("initial not a quaternion or matrix")
+ self._qnow = self._qpre = self._qdown
+
+ def place(self, center, radius):
+ """Place Arcball, e.g. when window size changes.
+
+ center : sequence[2]
+ Window coordinates of trackball center.
+ radius : float
+ Radius of trackball in window coordinates.
+
+ """
+ self._radius = float(radius)
+ self._center[0] = center[0]
+ self._center[1] = center[1]
+
+ def setaxes(self, *axes):
+ """Set axes to constrain rotations."""
+ if axes is None:
+ self._axes = None
+ else:
+ self._axes = [unit_vector(axis) for axis in axes]
+
+ @property
+ def constrain(self):
+ """Return state of constrain to axis mode."""
+ return self._constrain
+
+ @constrain.setter
+ def constrain(self, value):
+ """Set state of constrain to axis mode."""
+ self._constrain = bool(value)
+
+ def down(self, point):
+ """Set initial cursor window coordinates and pick constrain-axis."""
+ self._vdown = arcball_map_to_sphere(point, self._center, self._radius)
+ self._qdown = self._qpre = self._qnow
+ if self._constrain and self._axes is not None:
+ self._axis = arcball_nearest_axis(self._vdown, self._axes)
+ self._vdown = arcball_constrain_to_axis(self._vdown, self._axis)
+ else:
+ self._axis = None
+
+ def drag(self, point):
+ """Update current cursor window coordinates."""
+ vnow = arcball_map_to_sphere(point, self._center, self._radius)
+ if self._axis is not None:
+ vnow = arcball_constrain_to_axis(vnow, self._axis)
+ self._qpre = self._qnow
+ t = numpy.cross(self._vdown, vnow)
+ if numpy.dot(t, t) < _EPS:
+ self._qnow = self._qdown
+ else:
+ q = [numpy.dot(self._vdown, vnow), t[0], t[1], t[2]]
+ self._qnow = quaternion_multiply(q, self._qdown)
+
+ def next(self, acceleration=0.0):
+ """Continue rotation in direction of last drag."""
+ q = quaternion_slerp(self._qpre, self._qnow, 2.0+acceleration, False)
+ self._qpre, self._qnow = self._qnow, q
+
+ def matrix(self):
+ """Return homogeneous rotation matrix."""
+ return quaternion_matrix(self._qnow)
+
+
+def arcball_map_to_sphere(point, center, radius):
+ """Return unit sphere coordinates from window coordinates."""
+ v0 = (point[0] - center[0]) / radius
+ v1 = (center[1] - point[1]) / radius
+ n = v0*v0 + v1*v1
+ if n > 1.0:
+ # position outside of sphere
+ n = math.sqrt(n)
+ return numpy.array([v0/n, v1/n, 0.0])
+ else:
+ return numpy.array([v0, v1, math.sqrt(1.0 - n)])
+
+
+def arcball_constrain_to_axis(point, axis):
+ """Return sphere point perpendicular to axis."""
+ v = numpy.array(point, dtype=numpy.float64, copy=True)
+ a = numpy.array(axis, dtype=numpy.float64, copy=True)
+ v -= a * numpy.dot(a, v) # on plane
+ n = vector_norm(v)
+ if n > _EPS:
+ if v[2] < 0.0:
+ numpy.negative(v, v)
+ v /= n
+ return v
+ if a[2] == 1.0:
+ return numpy.array([1.0, 0.0, 0.0])
+ return unit_vector([-a[1], a[0], 0.0])
+
+
+def arcball_nearest_axis(point, axes):
+ """Return the axis whose arc is nearest to point."""
+ point = numpy.array(point, dtype=numpy.float64, copy=False)
+ nearest = None
+ mx = -1.0
+ for axis in axes:
+ t = numpy.dot(arcball_constrain_to_axis(point, axis), point)
+ if t > mx:
+ nearest = axis
+ mx = t
+ return nearest
+
+
+# epsilon for testing whether a number is close to zero
+_EPS = numpy.finfo(float).eps * 4.0
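+# (approximately 8.9e-16 for IEEE 754 double precision floats)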
+
+# axis sequences for Euler angles
+_NEXT_AXIS = [1, 2, 0, 1]
+
+# map axes strings to/from tuples of inner axis, parity, repetition, frame
+_AXES2TUPLE = {
+ 'sxyz': (0, 0, 0, 0), 'sxyx': (0, 0, 1, 0), 'sxzy': (0, 1, 0, 0),
+ 'sxzx': (0, 1, 1, 0), 'syzx': (1, 0, 0, 0), 'syzy': (1, 0, 1, 0),
+ 'syxz': (1, 1, 0, 0), 'syxy': (1, 1, 1, 0), 'szxy': (2, 0, 0, 0),
+ 'szxz': (2, 0, 1, 0), 'szyx': (2, 1, 0, 0), 'szyz': (2, 1, 1, 0),
+ 'rzyx': (0, 0, 0, 1), 'rxyx': (0, 0, 1, 1), 'ryzx': (0, 1, 0, 1),
+ 'rxzx': (0, 1, 1, 1), 'rxzy': (1, 0, 0, 1), 'ryzy': (1, 0, 1, 1),
+ 'rzxy': (1, 1, 0, 1), 'ryxy': (1, 1, 1, 1), 'ryxz': (2, 0, 0, 1),
+ 'rzxz': (2, 0, 1, 1), 'rxyz': (2, 1, 0, 1), 'rzyz': (2, 1, 1, 1)}
+
+_TUPLE2AXES = dict((v, k) for k, v in _AXES2TUPLE.items())
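+# For example, 'sxyz' -> (0, 0, 0, 0) encodes: inner axis x, even parity,
+# no repetition, static frame; 'r'-prefixed strings use a rotating frame
+# (frame flag 1), e.g. 'rzyx' -> (0, 0, 0, 1).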
+
+
+def vector_norm(data, axis=None, out=None):
+ """Return length, i.e. Euclidean norm, of ndarray along axis.
+
+ >>> v = numpy.random.random(3)
+ >>> n = vector_norm(v)
+ >>> numpy.allclose(n, numpy.linalg.norm(v))
+ True
+ >>> v = numpy.random.rand(6, 5, 3)
+ >>> n = vector_norm(v, axis=-1)
+ >>> numpy.allclose(n, numpy.sqrt(numpy.sum(v*v, axis=2)))
+ True
+ >>> n = vector_norm(v, axis=1)
+ >>> numpy.allclose(n, numpy.sqrt(numpy.sum(v*v, axis=1)))
+ True
+ >>> v = numpy.random.rand(5, 4, 3)
+ >>> n = numpy.empty((5, 3))
+ >>> vector_norm(v, axis=1, out=n)
+ >>> numpy.allclose(n, numpy.sqrt(numpy.sum(v*v, axis=1)))
+ True
+ >>> vector_norm([])
+ 0.0
+ >>> vector_norm([1])
+ 1.0
+
+ """
+ data = numpy.array(data, dtype=numpy.float64, copy=True)
+ if out is None:
+ if data.ndim == 1:
+ return math.sqrt(numpy.dot(data, data))
+ data *= data
+ out = numpy.atleast_1d(numpy.sum(data, axis=axis))
+ numpy.sqrt(out, out)
+ return out
+ else:
+ data *= data
+ numpy.sum(data, axis=axis, out=out)
+ numpy.sqrt(out, out)
+
+
+def unit_vector(data, axis=None, out=None):
+ """Return ndarray normalized by length, i.e. Euclidean norm, along axis.
+
+ >>> v0 = numpy.random.random(3)
+ >>> v1 = unit_vector(v0)
+ >>> numpy.allclose(v1, v0 / numpy.linalg.norm(v0))
+ True
+ >>> v0 = numpy.random.rand(5, 4, 3)
+ >>> v1 = unit_vector(v0, axis=-1)
+ >>> v2 = v0 / numpy.expand_dims(numpy.sqrt(numpy.sum(v0*v0, axis=2)), 2)
+ >>> numpy.allclose(v1, v2)
+ True
+ >>> v1 = unit_vector(v0, axis=1)
+ >>> v2 = v0 / numpy.expand_dims(numpy.sqrt(numpy.sum(v0*v0, axis=1)), 1)
+ >>> numpy.allclose(v1, v2)
+ True
+ >>> v1 = numpy.empty((5, 4, 3))
+ >>> unit_vector(v0, axis=1, out=v1)
+ >>> numpy.allclose(v1, v2)
+ True
+ >>> list(unit_vector([]))
+ []
+ >>> list(unit_vector([1]))
+ [1.0]
+
+ """
+ if out is None:
+ data = numpy.array(data, dtype=numpy.float64, copy=True)
+ if data.ndim == 1:
+ data /= math.sqrt(numpy.dot(data, data))
+ return data
+ else:
+ if out is not data:
+ out[:] = numpy.array(data, copy=False)
+ data = out
+ length = numpy.atleast_1d(numpy.sum(data*data, axis))
+ numpy.sqrt(length, length)
+ if axis is not None:
+ length = numpy.expand_dims(length, axis)
+ data /= length
+ if out is None:
+ return data
+
+
+def random_vector(size):
+ """Return array of random doubles in the half-open interval [0.0, 1.0).
+
+ >>> v = random_vector(10000)
+ >>> numpy.all(v >= 0) and numpy.all(v < 1)
+ True
+ >>> v0 = random_vector(10)
+ >>> v1 = random_vector(10)
+ >>> numpy.any(v0 == v1)
+ False
+
+ """
+ return numpy.random.random(size)
+
+
+def vector_product(v0, v1, axis=0):
+ """Return vector perpendicular to vectors.
+
+ >>> v = vector_product([2, 0, 0], [0, 3, 0])
+ >>> numpy.allclose(v, [0, 0, 6])
+ True
+ >>> v0 = [[2, 0, 0, 2], [0, 2, 0, 2], [0, 0, 2, 2]]
+ >>> v1 = [[3], [0], [0]]
+ >>> v = vector_product(v0, v1)
+ >>> numpy.allclose(v, [[0, 0, 0, 0], [0, 0, 6, 6], [0, -6, 0, -6]])
+ True
+ >>> v0 = [[2, 0, 0], [2, 0, 0], [0, 2, 0], [2, 0, 0]]
+ >>> v1 = [[0, 3, 0], [0, 0, 3], [0, 0, 3], [3, 3, 3]]
+ >>> v = vector_product(v0, v1, axis=1)
+ >>> numpy.allclose(v, [[0, 0, 6], [0, -6, 0], [6, 0, 0], [0, -6, 6]])
+ True
+
+ """
+ return numpy.cross(v0, v1, axis=axis)
+
+
+def angle_between_vectors(v0, v1, directed=True, axis=0):
+ """Return angle between vectors.
+
+ If directed is False, the input vectors are interpreted as undirected axes,
+ i.e. the maximum angle is pi/2.
+
+ >>> a = angle_between_vectors([1, -2, 3], [-1, 2, -3])
+ >>> numpy.allclose(a, math.pi)
+ True
+ >>> a = angle_between_vectors([1, -2, 3], [-1, 2, -3], directed=False)
+ >>> numpy.allclose(a, 0)
+ True
+ >>> v0 = [[2, 0, 0, 2], [0, 2, 0, 2], [0, 0, 2, 2]]
+ >>> v1 = [[3], [0], [0]]
+ >>> a = angle_between_vectors(v0, v1)
+ >>> numpy.allclose(a, [0, 1.5708, 1.5708, 0.95532])
+ True
+ >>> v0 = [[2, 0, 0], [2, 0, 0], [0, 2, 0], [2, 0, 0]]
+ >>> v1 = [[0, 3, 0], [0, 0, 3], [0, 0, 3], [3, 3, 3]]
+ >>> a = angle_between_vectors(v0, v1, axis=1)
+ >>> numpy.allclose(a, [1.5708, 1.5708, 1.5708, 0.95532])
+ True
+
+ """
+ v0 = numpy.array(v0, dtype=numpy.float64, copy=False)
+ v1 = numpy.array(v1, dtype=numpy.float64, copy=False)
+ dot = numpy.sum(v0 * v1, axis=axis)
+ dot /= vector_norm(v0, axis=axis) * vector_norm(v1, axis=axis)
+ return numpy.arccos(dot if directed else numpy.fabs(dot))
+
+
+def inverse_matrix(matrix):
+ """Return inverse of square transformation matrix.
+
+ >>> M0 = random_rotation_matrix()
+ >>> M1 = inverse_matrix(M0.T)
+ >>> numpy.allclose(M1, numpy.linalg.inv(M0.T))
+ True
+ >>> for size in range(1, 7):
+ ... M0 = numpy.random.rand(size, size)
+ ... M1 = inverse_matrix(M0)
+ ... if not numpy.allclose(M1, numpy.linalg.inv(M0)): print(size)
+
+ """
+ return numpy.linalg.inv(matrix)
+
+
+def concatenate_matrices(*matrices):
+ """Return concatenation of a series of transformation matrices.
+
+ >>> M = numpy.random.rand(16).reshape((4, 4)) - 0.5
+ >>> numpy.allclose(M, concatenate_matrices(M))
+ True
+ >>> numpy.allclose(numpy.dot(M, M.T), concatenate_matrices(M, M.T))
+ True
+
+ """
+ M = numpy.identity(4)
+ for i in matrices:
+ M = numpy.dot(M, i)
+ return M
+
+
+def is_same_transform(matrix0, matrix1):
+ """Return True if two matrices perform the same transformation.
+
+ >>> is_same_transform(numpy.identity(4), numpy.identity(4))
+ True
+ >>> is_same_transform(numpy.identity(4), random_rotation_matrix())
+ False
+
+ """
+ matrix0 = numpy.array(matrix0, dtype=numpy.float64, copy=True)
+ matrix0 /= matrix0[3, 3]
+ matrix1 = numpy.array(matrix1, dtype=numpy.float64, copy=True)
+ matrix1 /= matrix1[3, 3]
+ return numpy.allclose(matrix0, matrix1)
+
+
+def _import_module(name, package=None, warn=True, prefix='_py_', ignore='_'):
+ """Try to import all public attributes from module into the global namespace.
+
+ Existing attributes with name clashes are renamed with prefix.
+ Attributes starting with underscore are ignored by default.
+
+ Return True on successful import.
+
+ """
+ import warnings
+ from importlib import import_module
+ try:
+ if not package:
+ module = import_module(name)
+ else:
+ module = import_module('.' + name, package=package)
+ except ImportError:
+ if warn:
+ #warnings.warn("failed to import module %s" % name)
+ pass
+ else:
+ for attr in dir(module):
+ if ignore and attr.startswith(ignore):
+ continue
+ if prefix:
+ if attr in globals():
+ globals()[prefix + attr] = globals()[attr]
+ elif warn:
+ warnings.warn("no Python implementation of " + attr)
+ globals()[attr] = getattr(module, attr)
+ return True
+
+
+_import_module('_transformations')
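+# If the optional C extension '_transformations' is importable, its public
+# attributes replace the Python implementations above; shadowed Python
+# versions remain reachable under the '_py_' prefix (see _import_module).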
+
+if __name__ == "__main__":
+ import doctest
+ import random # used in doctests
+ numpy.set_printoptions(suppress=True, precision=5)
+ doctest.testmod()
+
diff --git a/imcui/third_party/SOLD2/notebooks/__init__.py b/imcui/third_party/SOLD2/notebooks/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/imcui/third_party/SOLD2/setup.py b/imcui/third_party/SOLD2/setup.py
new file mode 100644
index 0000000000000000000000000000000000000000..69f72fecdc54cf9b43a7fc55144470e83c5a862d
--- /dev/null
+++ b/imcui/third_party/SOLD2/setup.py
@@ -0,0 +1,4 @@
+from setuptools import setup
+
+
+setup(name='sold2', version="0.0", packages=['sold2'])
diff --git a/imcui/third_party/SOLD2/sold2/config/__init__.py b/imcui/third_party/SOLD2/sold2/config/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/imcui/third_party/SOLD2/sold2/config/export_line_features.yaml b/imcui/third_party/SOLD2/sold2/config/export_line_features.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..f19c7b6d684b7a826d6f2909b8c9f94528fdbf94
--- /dev/null
+++ b/imcui/third_party/SOLD2/sold2/config/export_line_features.yaml
@@ -0,0 +1,80 @@
+### [Model config]
+model_cfg:
+ ### [Model parameters]
+ model_name: "lcnn_simple"
+ model_architecture: "simple"
+ # Backbone related config
+ backbone: "lcnn"
+ backbone_cfg:
+ input_channel: 1 # Number of input channels: 1 for grayscale, 3 for RGB images.
+ depth: 4
+ num_stacks: 2
+ num_blocks: 1
+ num_classes: 5
+ # Junction decoder related config
+ junction_decoder: "superpoint_decoder"
+ junc_decoder_cfg:
+ # Heatmap decoder related config
+ heatmap_decoder: "pixel_shuffle"
+ heatmap_decoder_cfg:
+ # Descriptor decoder related config
+ descriptor_decoder: "superpoint_descriptor"
+ descriptor_decoder_cfg:
+ # Shared configurations
+ grid_size: 8
+ keep_border_valid: True
+ # Threshold of junction detection
+ detection_thresh: 0.0153846 # 1/65
+ max_num_junctions: 300
+ # Threshold of heatmap detection
+ prob_thresh: 0.5
+
+ ### [Loss parameters]
+ weighting_policy: "dynamic"
+ # [Heatmap loss]
+ w_heatmap: 0.
+ w_heatmap_class: 1
+ heatmap_loss_func: "cross_entropy"
+ heatmap_loss_cfg:
+ policy: "dynamic"
+ # [Junction loss]
+ w_junc: 0.
+ junction_loss_func: "superpoint"
+ junction_loss_cfg:
+ policy: "dynamic"
+ # [Descriptor loss]
+ w_desc: 0.
+ descriptor_loss_func: "regular_sampling"
+ descriptor_loss_cfg:
+ dist_threshold: 8
+ grid_size: 4
+ margin: 1
+ policy: "dynamic"
+
+### [Line detector config]
+line_detector_cfg:
+ detect_thresh: 0.5
+ num_samples: 64
+ sampling_method: "local_max"
+ inlier_thresh: 0.99
+ use_candidate_suppression: True
+ nms_dist_tolerance: 3.
+ use_heatmap_refinement: True
+ heatmap_refine_cfg:
+ mode: "local"
+ ratio: 0.2
+ valid_thresh: 0.001
+ num_blocks: 20
+ overlap_ratio: 0.5
+ use_junction_refinement: True
+ junction_refine_cfg:
+ num_perturbs: 9
+ perturb_interval: 0.25
+
+### [Line matcher config]
+line_matcher_cfg:
+ cross_check: True
+ num_samples: 5
+ min_dist_pts: 8
+ top_k_candidates: 10
+ grid_size: 4
\ No newline at end of file
diff --git a/imcui/third_party/SOLD2/sold2/config/holicity_dataset.yaml b/imcui/third_party/SOLD2/sold2/config/holicity_dataset.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..72e9380dbf496dc4b4d6430d58534e0663c85f0e
--- /dev/null
+++ b/imcui/third_party/SOLD2/sold2/config/holicity_dataset.yaml
@@ -0,0 +1,76 @@
+### General dataset parameters
+dataset_name: "holicity"
+train_splits: ["2018-01"] # 5720 images
+add_augmentation_to_all_splits: False
+gray_scale: True
+# Ground truth source ('official' or path to the exported h5 dataset).
+#gt_source_train: "" # Fill with your own export file
+#gt_source_test: "" # Fill with your own export file
+# Return type: (1) single (to train the detector only)
+# or (2) paired_desc (to train the detector + descriptor)
+return_type: "single"
+random_seed: 0
+
+### Descriptor training parameters
+# Number of points extracted per line
+max_num_samples: 10
+# Max number of training line points extracted in the whole image
+max_pts: 1000
+# Min distance between two points on a line (in pixels)
+min_dist_pts: 10
+# Small jittering of the sampled points during training
+jittering: 0
+
+### Data preprocessing configuration
+preprocessing:
+ resize: [512, 512]
+ blur_size: 11
+augmentation:
+ random_scaling:
+ enable: True
+ range: [0.7, 1.5]
+ photometric:
+ enable: True
+ primitives: ['random_brightness', 'random_contrast',
+ 'additive_speckle_noise', 'additive_gaussian_noise',
+ 'additive_shade', 'motion_blur' ]
+ params:
+ random_brightness: {brightness: 0.2}
+ random_contrast: {contrast: [0.3, 1.5]}
+ additive_gaussian_noise: {stddev_range: [0, 10]}
+ additive_speckle_noise: {prob_range: [0, 0.0035]}
+ additive_shade:
+ transparency_range: [-0.5, 0.5]
+ kernel_size_range: [100, 150]
+ motion_blur: {max_kernel_size: 3}
+ random_order: True
+ homographic:
+ enable: True
+ params:
+ translation: true
+ rotation: true
+ scaling: true
+ perspective: true
+ scaling_amplitude: 0.2
+ perspective_amplitude_x: 0.2
+ perspective_amplitude_y: 0.2
+ patch_ratio: 0.85
+ max_angle: 1.57
+ allow_artifacts: true
+ valid_border_margin: 3
+
+### Homography adaptation configuration
+homography_adaptation:
+ num_iter: 100
+ valid_border_margin: 3
+ min_counts: 30
+ homographies:
+ translation: true
+ rotation: true
+ scaling: true
+ perspective: true
+ scaling_amplitude: 0.2
+ perspective_amplitude_x: 0.2
+ perspective_amplitude_y: 0.2
+ allow_artifacts: true
+ patch_ratio: 0.85
\ No newline at end of file
diff --git a/imcui/third_party/SOLD2/sold2/config/merge_dataset.yaml b/imcui/third_party/SOLD2/sold2/config/merge_dataset.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..f70465b71e507cbc9f258a8bbf45f41e435ee9b0
--- /dev/null
+++ b/imcui/third_party/SOLD2/sold2/config/merge_dataset.yaml
@@ -0,0 +1,54 @@
+dataset_name: "merge"
+datasets: ["wireframe", "holicity"]
+weights: [0.5, 0.5]
+gt_source_train: ["", ""] # Fill with your own [wireframe, holicity] exported ground-truth
+gt_source_test: ["", ""] # Fill with your own [wireframe, holicity] exported ground-truth
+train_splits: ["", "2018-01"]
+add_augmentation_to_all_splits: False
+gray_scale: True
+# Return type: (1) single (to train the detector only) or (2) paired_desc (detector + descriptor)
+return_type: "paired_desc"
+# Number of points extracted per line
+max_num_samples: 10
+# Max number of training line points extracted in the whole image
+max_pts: 1000
+# Min distance between two points on a line (in pixels)
+min_dist_pts: 10
+# Small jittering of the sampled points during training
+jittering: 0
+# Random seed
+random_seed: 0
+# Data preprocessing configuration.
+preprocessing:
+ resize: [512, 512]
+ blur_size: 11
+augmentation:
+ photometric:
+ enable: True
+ primitives: [
+ 'random_brightness', 'random_contrast', 'additive_speckle_noise',
+ 'additive_gaussian_noise', 'additive_shade', 'motion_blur' ]
+ params:
+ random_brightness: {brightness: 0.2}
+ random_contrast: {contrast: [0.3, 1.5]}
+ additive_gaussian_noise: {stddev_range: [0, 10]}
+ additive_speckle_noise: {prob_range: [0, 0.0035]}
+ additive_shade:
+ transparency_range: [-0.5, 0.5]
+ kernel_size_range: [100, 150]
+ motion_blur: {max_kernel_size: 3}
+ random_order: True
+ homographic:
+ enable: True
+ params:
+ translation: true
+ rotation: true
+ scaling: true
+ perspective: true
+ scaling_amplitude: 0.2
+ perspective_amplitude_x: 0.2
+ perspective_amplitude_y: 0.2
+ patch_ratio: 0.85
+ max_angle: 1.57
+ allow_artifacts: true
+ valid_border_margin: 3
diff --git a/imcui/third_party/SOLD2/sold2/config/project_config.py b/imcui/third_party/SOLD2/sold2/config/project_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..42ed00d1c1900e71568d1b06ff4f9d19a295232d
--- /dev/null
+++ b/imcui/third_party/SOLD2/sold2/config/project_config.py
@@ -0,0 +1,41 @@
+"""
+Project configurations.
+"""
+import os
+
+
+class Config(object):
+ """ Datasets and experiments folders for the whole project. """
+ #####################
+ ## Dataset setting ##
+ #####################
+ DATASET_ROOT = os.getenv("DATASET_ROOT", "./datasets/") # TODO: path to your datasets folder
+ if not os.path.exists(DATASET_ROOT):
+ os.makedirs(DATASET_ROOT)
+
+ # Synthetic shape dataset
+ synthetic_dataroot = os.path.join(DATASET_ROOT, "synthetic_shapes")
+ synthetic_cache_path = os.path.join(DATASET_ROOT, "synthetic_shapes")
+ if not os.path.exists(synthetic_dataroot):
+ os.makedirs(synthetic_dataroot)
+
+ # Exported predictions dataset
+ export_dataroot = os.path.join(DATASET_ROOT, "export_datasets")
+ export_cache_path = os.path.join(DATASET_ROOT, "export_datasets")
+ if not os.path.exists(export_dataroot):
+ os.makedirs(export_dataroot)
+
+ # Wireframe dataset
+ wireframe_dataroot = os.path.join(DATASET_ROOT, "wireframe")
+ wireframe_cache_path = os.path.join(DATASET_ROOT, "wireframe")
+
+ # Holicity dataset
+ holicity_dataroot = os.path.join(DATASET_ROOT, "Holicity")
+ holicity_cache_path = os.path.join(DATASET_ROOT, "Holicity")
+
+ ########################
+ ## Experiment Setting ##
+ ########################
+ EXP_PATH = os.getenv("EXP_PATH", "./experiments/") # TODO: path to your experiments folder
+ if not os.path.exists(EXP_PATH):
+ os.makedirs(EXP_PATH)
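+
+ # Illustrative usage (an assumption, not defined in this file): other
+ # modules can read these locations as plain class attributes, e.g.
+ # from sold2.config.project_config import Config
+ # ckpt_dir = os.path.join(Config.EXP_PATH, "my_experiment") # hypothetical name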
diff --git a/imcui/third_party/SOLD2/sold2/config/synthetic_dataset.yaml b/imcui/third_party/SOLD2/sold2/config/synthetic_dataset.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..d9fa44522b6c09500100dbc56a11bc8a24d56832
--- /dev/null
+++ b/imcui/third_party/SOLD2/sold2/config/synthetic_dataset.yaml
@@ -0,0 +1,48 @@
+### General dataset parameters
+dataset_name: "synthetic_shape"
+primitives: "all"
+add_augmentation_to_all_splits: True
+test_augmentation_seed: 200
+# Shape generation configuration
+generation:
+ split_sizes: {'train': 20000, 'val': 2000, 'test': 400}
+ random_seed: 10
+ image_size: [960, 1280]
+ min_len: 0.0985
+ min_label_len: 0.099
+ params:
+ generate_background:
+ min_kernel_size: 150
+ max_kernel_size: 500
+ min_rad_ratio: 0.02
+ max_rad_ratio: 0.031
+ draw_stripes:
+ transform_params: [0.1, 0.1]
+ draw_multiple_polygons:
+ kernel_boundaries: [50, 100]
+
+### Data preprocessing configuration.
+preprocessing:
+ resize: [400, 400]
+ blur_size: 11
+augmentation:
+ photometric:
+ enable: True
+ primitives: 'all'
+ params: {}
+ random_order: True
+ homographic:
+ enable: True
+ params:
+ translation: true
+ rotation: true
+ scaling: true
+ perspective: true
+ scaling_amplitude: 0.2
+ perspective_amplitude_x: 0.2
+ perspective_amplitude_y: 0.2
+ patch_ratio: 0.8
+ max_angle: 1.57
+ allow_artifacts: true
+ translation_overflow: 0.05
+ valid_border_margin: 0
diff --git a/imcui/third_party/SOLD2/sold2/config/train_detector.yaml b/imcui/third_party/SOLD2/sold2/config/train_detector.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..c53c35a6464eb1c37a9ea71c939225f793543aec
--- /dev/null
+++ b/imcui/third_party/SOLD2/sold2/config/train_detector.yaml
@@ -0,0 +1,51 @@
+### [Model parameters]
+model_name: "lcnn_simple"
+model_architecture: "simple"
+# Backbone related config
+backbone: "lcnn"
+backbone_cfg:
+ input_channel: 1 # Number of input channels: 1 for grayscale, 3 for RGB images.
+ depth: 4
+ num_stacks: 2
+ num_blocks: 1
+ num_classes: 5
+# Junction decoder related config
+junction_decoder: "superpoint_decoder"
+junc_decoder_cfg:
+# Heatmap decoder related config
+heatmap_decoder: "pixel_shuffle"
+heatmap_decoder_cfg:
+# Shared configurations
+grid_size: 8
+keep_border_valid: True
+# Threshold of junction detection
+detection_thresh: 0.0153846 # 1/65
+# Threshold of heatmap detection
+prob_thresh: 0.5
+
+### [Loss parameters]
+weighting_policy: "dynamic"
+# [Heatmap loss]
+w_heatmap: 0.
+w_heatmap_class: 1
+heatmap_loss_func: "cross_entropy"
+heatmap_loss_cfg:
+ policy: "dynamic"
+# [Junction loss]
+w_junc: 0.
+junction_loss_func: "superpoint"
+junction_loss_cfg:
+ policy: "dynamic"
+
+### [Training parameters]
+learning_rate: 0.0005
+epochs: 200
+train:
+ batch_size: 6
+ num_workers: 8
+test:
+ batch_size: 6
+ num_workers: 8
+disp_freq: 100
+summary_freq: 200
+max_ckpt: 150
\ No newline at end of file
diff --git a/imcui/third_party/SOLD2/sold2/config/train_full_pipeline.yaml b/imcui/third_party/SOLD2/sold2/config/train_full_pipeline.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..233d898f47110c14beabbe63ee82044d506cc15a
--- /dev/null
+++ b/imcui/third_party/SOLD2/sold2/config/train_full_pipeline.yaml
@@ -0,0 +1,62 @@
+### [Model parameters]
+model_name: "lcnn_simple"
+model_architecture: "simple"
+# Backbone related config
+backbone: "lcnn"
+backbone_cfg:
+    input_channel: 1 # 1 for grayscale input, 3 for RGB input.
+ depth: 4
+ num_stacks: 2
+ num_blocks: 1
+ num_classes: 5
+# Junction decoder related config
+junction_decoder: "superpoint_decoder"
+junc_decoder_cfg:
+# Heatmap decoder related config
+heatmap_decoder: "pixel_shuffle"
+heatmap_decoder_cfg:
+# Descriptor decoder related config
+descriptor_decoder: "superpoint_descriptor"
+descriptor_decoder_cfg:
+# Shared configurations
+grid_size: 8
+keep_border_valid: True
+# Threshold of junction detection
+detection_thresh: 0.0153846 # 1/65
+# Threshold of heatmap detection
+prob_thresh: 0.5
+
+### [Loss parameters]
+weighting_policy: "dynamic"
+# [Heatmap loss]
+w_heatmap: 0.
+w_heatmap_class: 1
+heatmap_loss_func: "cross_entropy"
+heatmap_loss_cfg:
+ policy: "dynamic"
+# [Junction loss]
+w_junc: 0.
+junction_loss_func: "superpoint"
+junction_loss_cfg:
+ policy: "dynamic"
+# [Descriptor loss]
+w_desc: 0.
+descriptor_loss_func: "regular_sampling"
+descriptor_loss_cfg:
+ dist_threshold: 8
+ grid_size: 4
+ margin: 1
+ policy: "dynamic"
+
+### [Training parameters]
+learning_rate: 0.0005
+epochs: 130
+train:
+ batch_size: 4
+ num_workers: 8
+test:
+ batch_size: 4
+ num_workers: 8
+disp_freq: 100
+summary_freq: 200
+max_ckpt: 130
\ No newline at end of file
diff --git a/imcui/third_party/SOLD2/sold2/config/wireframe_dataset.yaml b/imcui/third_party/SOLD2/sold2/config/wireframe_dataset.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..15abd3dbd6462dca21ac331a802b86a8ef050bff
--- /dev/null
+++ b/imcui/third_party/SOLD2/sold2/config/wireframe_dataset.yaml
@@ -0,0 +1,75 @@
+### General dataset parameters
+dataset_name: "wireframe"
+add_augmentation_to_all_splits: False
+gray_scale: True
+# Ground truth source ('official' or path to the exported h5 dataset.)
+# gt_source_train: "" # Fill with your own export file
+# gt_source_test: "" # Fill with your own export file
+# Return type: (1) single (to train the detector only)
+# or (2) paired_desc (to train the detector + descriptor)
+return_type: "single"
+random_seed: 0
+
+### Descriptor training parameters
+# Number of points extracted per line
+max_num_samples: 10
+# Max number of training line points extracted in the whole image
+max_pts: 1000
+# Min distance between two points on a line (in pixels)
+min_dist_pts: 10
+# Small jittering of the sampled points during training
+jittering: 0
+
+### Data preprocessing configuration
+preprocessing:
+ resize: [512, 512]
+ blur_size: 11
+augmentation:
+ random_scaling:
+ enable: True
+ range: [0.7, 1.5]
+ photometric:
+ enable: True
+ primitives: ['random_brightness', 'random_contrast',
+ 'additive_speckle_noise', 'additive_gaussian_noise',
+ 'additive_shade', 'motion_blur' ]
+ params:
+ random_brightness: {brightness: 0.2}
+ random_contrast: {contrast: [0.3, 1.5]}
+ additive_gaussian_noise: {stddev_range: [0, 10]}
+ additive_speckle_noise: {prob_range: [0, 0.0035]}
+ additive_shade:
+ transparency_range: [-0.5, 0.5]
+ kernel_size_range: [100, 150]
+ motion_blur: {max_kernel_size: 3}
+ random_order: True
+ homographic:
+ enable: True
+ params:
+ translation: true
+ rotation: true
+ scaling: true
+ perspective: true
+ scaling_amplitude: 0.2
+ perspective_amplitude_x: 0.2
+ perspective_amplitude_y: 0.2
+ patch_ratio: 0.85
+ max_angle: 1.57
+ allow_artifacts: true
+ valid_border_margin: 3
+
+### Homography adaptation configuration
+homography_adaptation:
+ num_iter: 100
+ valid_border_margin: 3
+ min_counts: 30
+ homographies:
+ translation: true
+ rotation: true
+ scaling: true
+ perspective: true
+ scaling_amplitude: 0.2
+ perspective_amplitude_x: 0.2
+ perspective_amplitude_y: 0.2
+ allow_artifacts: true
+ patch_ratio: 0.85
\ No newline at end of file
diff --git a/imcui/third_party/SOLD2/sold2/dataset/__init__.py b/imcui/third_party/SOLD2/sold2/dataset/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/imcui/third_party/SOLD2/sold2/dataset/dataset_util.py b/imcui/third_party/SOLD2/sold2/dataset/dataset_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..50439ef3e2958d82719da0f6d10f4a7d98322f9a
--- /dev/null
+++ b/imcui/third_party/SOLD2/sold2/dataset/dataset_util.py
@@ -0,0 +1,60 @@
+"""
+Interface for initializing the different datasets.
+"""
+from .synthetic_dataset import SyntheticShapes
+from .wireframe_dataset import WireframeDataset
+from .holicity_dataset import HolicityDataset
+from .merge_dataset import MergeDataset
+
+
+def get_dataset(mode="train", dataset_cfg=None):
+    """ Initialize the dataset specified in the configuration. """
+ # Check dataset config is given
+ if dataset_cfg is None:
+ raise ValueError("[Error] The dataset config is required!")
+
+ # Synthetic dataset
+ if dataset_cfg["dataset_name"] == "synthetic_shape":
+ dataset = SyntheticShapes(
+ mode, dataset_cfg
+ )
+
+ # Get the collate_fn
+ from .synthetic_dataset import synthetic_collate_fn
+ collate_fn = synthetic_collate_fn
+
+ # Wireframe dataset
+ elif dataset_cfg["dataset_name"] == "wireframe":
+ dataset = WireframeDataset(
+ mode, dataset_cfg
+ )
+
+ # Get the collate_fn
+ from .wireframe_dataset import wireframe_collate_fn
+ collate_fn = wireframe_collate_fn
+
+ # Holicity dataset
+ elif dataset_cfg["dataset_name"] == "holicity":
+ dataset = HolicityDataset(
+ mode, dataset_cfg
+ )
+
+ # Get the collate_fn
+ from .holicity_dataset import holicity_collate_fn
+ collate_fn = holicity_collate_fn
+
+ # Dataset merging several datasets in one
+ elif dataset_cfg["dataset_name"] == "merge":
+ dataset = MergeDataset(
+ mode, dataset_cfg
+ )
+
+ # Get the collate_fn
+ from .holicity_dataset import holicity_collate_fn
+ collate_fn = holicity_collate_fn
+
+ else:
+ raise ValueError(
+ "[Error] The dataset '%s' is not supported" % dataset_cfg["dataset_name"])
+
+ return dataset, collate_fn
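+
+
+# A minimal usage sketch (assumptions: a YAML config parsed into a dict and a
+# standard torch.utils.data.DataLoader; the file path and batch size below are
+# illustrative only):
+#
+#   import yaml
+#   from torch.utils.data import DataLoader
+#
+#   with open("sold2/config/wireframe_dataset.yaml", "r") as f:
+#       dataset_cfg = yaml.safe_load(f)
+#   dataset, collate_fn = get_dataset(mode="train", dataset_cfg=dataset_cfg)
+#   loader = DataLoader(dataset, batch_size=4, shuffle=True,
+#                       collate_fn=collate_fn)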
diff --git a/imcui/third_party/SOLD2/sold2/dataset/holicity_dataset.py b/imcui/third_party/SOLD2/sold2/dataset/holicity_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..e4437f37bda366983052de902a41467ca01412bd
--- /dev/null
+++ b/imcui/third_party/SOLD2/sold2/dataset/holicity_dataset.py
@@ -0,0 +1,797 @@
+"""
+File to process and load the Holicity dataset.
+"""
+import os
+import math
+import copy
+import PIL
+import numpy as np
+import h5py
+import cv2
+import pickle
+from skimage.io import imread
+from skimage import color
+import torch
+import torch.utils.data.dataloader as torch_loader
+from torch.utils.data import Dataset
+from torchvision import transforms
+
+from ..config.project_config import Config as cfg
+from .transforms import photometric_transforms as photoaug
+from .transforms import homographic_transforms as homoaug
+from .transforms.utils import random_scaling
+from .synthetic_util import get_line_heatmap
+from ..misc.geometry_utils import warp_points, mask_points
+from ..misc.train_utils import parse_h5_data
+
+
+def holicity_collate_fn(batch):
+ """ Customized collate_fn. """
+ batch_keys = ["image", "junction_map", "valid_mask", "heatmap",
+ "heatmap_pos", "heatmap_neg", "homography",
+ "line_points", "line_indices"]
+ list_keys = ["junctions", "line_map", "line_map_pos",
+ "line_map_neg", "file_key"]
+
+ outputs = {}
+ for data_key in batch[0].keys():
+ batch_match = sum([_ in data_key for _ in batch_keys])
+ list_match = sum([_ in data_key for _ in list_keys])
+ # print(batch_match, list_match)
+ if batch_match > 0 and list_match == 0:
+ outputs[data_key] = torch_loader.default_collate(
+ [b[data_key] for b in batch])
+ elif batch_match == 0 and list_match > 0:
+ outputs[data_key] = [b[data_key] for b in batch]
+ elif batch_match == 0 and list_match == 0:
+ continue
+ else:
+ raise ValueError(
+ "[Error] A key matches batch keys and list keys simultaneously.")
+
+ return outputs
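+
+# Sketch of how the substring matching above routes keys (illustrative only):
+# a sample key such as "ref_image" matches the batch key "image" and is
+# stacked with default_collate, while "ref_junctions" matches the list key
+# "junctions" and is kept as a plain Python list with one entry per sample;
+# keys matching neither set are dropped from the batch.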
+
+
+class HolicityDataset(Dataset):
+ def __init__(self, mode="train", config=None):
+ super(HolicityDataset, self).__init__()
+        if mode not in ["train", "test"]:
+ raise ValueError(
+ "[Error] Unknown mode for Holicity dataset. Only 'train' and 'test'.")
+ self.mode = mode
+
+ if config is None:
+ self.config = self.get_default_config()
+ else:
+ self.config = config
+ # Also get the default config
+ self.default_config = self.get_default_config()
+
+ # Get cache setting
+ self.dataset_name = self.get_dataset_name()
+ self.cache_name = self.get_cache_name()
+ self.cache_path = cfg.holicity_cache_path
+
+ # Get the ground truth source if it exists
+ self.gt_source = None
+ if "gt_source_%s"%(self.mode) in self.config:
+ self.gt_source = self.config.get("gt_source_%s"%(self.mode))
+ self.gt_source = os.path.join(cfg.export_dataroot, self.gt_source)
+ # Check the full path exists
+ if not os.path.exists(self.gt_source):
+ raise ValueError(
+ "[Error] The specified ground truth source does not exist.")
+
+ # Get the filename dataset
+ print("[Info] Initializing Holicity dataset...")
+ self.filename_dataset, self.datapoints = self.construct_dataset()
+
+ # Get dataset length
+ self.dataset_length = len(self.datapoints)
+
+ # Print some info
+ print("[Info] Successfully initialized dataset")
+ print("\t Name: Holicity")
+ print("\t Mode: %s" %(self.mode))
+ print("\t Gt: %s" %(self.config.get("gt_source_%s"%(self.mode),
+ "None")))
+ print("\t Counts: %d" %(self.dataset_length))
+ print("----------------------------------------")
+
+ #######################################
+ ## Dataset construction related APIs ##
+ #######################################
+ def construct_dataset(self):
+ """ Construct the dataset (from scratch or from cache). """
+ # Check if the filename cache exists
+ # If cache exists, load from cache
+ if self.check_dataset_cache():
+ print("\t Found filename cache %s at %s"%(self.cache_name,
+ self.cache_path))
+ print("\t Load filename cache...")
+ filename_dataset, datapoints = self.get_filename_dataset_from_cache()
+ # If not, initialize dataset from scratch
+ else:
+ print("\t Can't find filename cache ...")
+ print("\t Create filename dataset from scratch...")
+ filename_dataset, datapoints = self.get_filename_dataset()
+ print("\t Create filename dataset cache...")
+ self.create_filename_dataset_cache(filename_dataset, datapoints)
+
+ return filename_dataset, datapoints
+
+ def create_filename_dataset_cache(self, filename_dataset, datapoints):
+ """ Create filename dataset cache for faster initialization. """
+ # Check cache path exists
+ if not os.path.exists(self.cache_path):
+ os.makedirs(self.cache_path)
+
+ cache_file_path = os.path.join(self.cache_path, self.cache_name)
+ data = {
+ "filename_dataset": filename_dataset,
+ "datapoints": datapoints
+ }
+ with open(cache_file_path, "wb") as f:
+ pickle.dump(data, f, pickle.HIGHEST_PROTOCOL)
+
+ def get_filename_dataset_from_cache(self):
+ """ Get filename dataset from cache. """
+ # Load from pkl cache
+ cache_file_path = os.path.join(self.cache_path, self.cache_name)
+ with open(cache_file_path, "rb") as f:
+ data = pickle.load(f)
+
+ return data["filename_dataset"], data["datapoints"]
+
+ def get_filename_dataset(self):
+ """ Get the path to the dataset. """
+ if self.mode == "train":
+ # Contains 5720 or 11872 images
+ dataset_path = [os.path.join(cfg.holicity_dataroot, p)
+ for p in self.config["train_splits"]]
+ else:
+ # Test mode - Contains 520 images
+ dataset_path = [os.path.join(cfg.holicity_dataroot, "2018-03")]
+
+ # Get paths to all image files
+ image_paths = []
+ for folder in dataset_path:
+ image_paths += [os.path.join(folder, img)
+ for img in os.listdir(folder)
+ if os.path.splitext(img)[-1] == ".jpg"]
+ image_paths = sorted(image_paths)
+
+ # Verify all the images exist
+ for idx in range(len(image_paths)):
+ image_path = image_paths[idx]
+            if not os.path.exists(image_path):
+ raise ValueError(
+ "[Error] The image does not exist. %s"%(image_path))
+
+ # Construct the filename dataset
+ num_pad = int(math.ceil(math.log10(len(image_paths))) + 1)
+ filename_dataset = {}
+ for idx in range(len(image_paths)):
+ # Get the file key
+ key = self.get_padded_filename(num_pad, idx)
+
+ filename_dataset[key] = {"image": image_paths[idx]}
+
+ # Get the datapoints
+ datapoints = list(sorted(filename_dataset.keys()))
+
+ return filename_dataset, datapoints
+
+ def get_dataset_name(self):
+ """ Get dataset name from dataset config / default config. """
+ dataset_name = self.config.get("dataset_name",
+ self.default_config["dataset_name"])
+ dataset_name = dataset_name + "_%s" % self.mode
+ return dataset_name
+
+ def get_cache_name(self):
+ """ Get cache name from dataset config / default config. """
+ dataset_name = self.config.get("dataset_name",
+ self.default_config["dataset_name"])
+ dataset_name = dataset_name + "_%s" % self.mode
+ # Compose cache name
+ cache_name = dataset_name + "_cache.pkl"
+ return cache_name
+
+ def check_dataset_cache(self):
+ """ Check if dataset cache exists. """
+ cache_file_path = os.path.join(self.cache_path, self.cache_name)
+ if os.path.exists(cache_file_path):
+ return True
+ else:
+ return False
+
+ @staticmethod
+ def get_padded_filename(num_pad, idx):
+ """ Get the padded filename using adaptive padding. """
+ file_len = len("%d" % (idx))
+ filename = "0" * (num_pad - file_len) + "%d" % (idx)
+ return filename
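+
+    # Illustrative example (hypothetical values): get_padded_filename(5, 42)
+    # returns "00042", i.e. the index is zero-padded to a width of num_pad
+    # characters.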
+
+ def get_default_config(self):
+ """ Get the default configuration. """
+ return {
+ "dataset_name": "holicity",
+ "train_split": "2018-01",
+ "add_augmentation_to_all_splits": False,
+ "preprocessing": {
+ "resize": [512, 512],
+ "blur_size": 11
+ },
+ "augmentation":{
+ "photometric":{
+ "enable": False
+ },
+ "homographic":{
+ "enable": False
+ },
+ },
+ }
+
+ ############################################
+ ## Pytorch and preprocessing related APIs ##
+ ############################################
+ @staticmethod
+ def get_data_from_path(data_path):
+ """ Get data from the information from filename dataset. """
+ output = {}
+
+ # Get image data
+ image_path = data_path["image"]
+ image = imread(image_path)
+ output["image"] = image
+
+ return output
+
+ @staticmethod
+ def convert_line_map(lcnn_line_map, num_junctions):
+ """ Convert the line_pos or line_neg
+ (represented by two junction indexes) to our line map. """
+ # Initialize empty line map
+ line_map = np.zeros([num_junctions, num_junctions])
+
+ # Iterate through all the lines
+ for idx in range(lcnn_line_map.shape[0]):
+ index1 = lcnn_line_map[idx, 0]
+ index2 = lcnn_line_map[idx, 1]
+
+ line_map[index1, index2] = 1
+ line_map[index2, index1] = 1
+
+ return line_map
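+
+    # Illustrative example (hypothetical values): with num_junctions=3 and
+    # lcnn_line_map=np.array([[0, 1], [1, 2]]), the returned line_map is the
+    # symmetric adjacency matrix
+    #   [[0, 1, 0],
+    #    [1, 0, 1],
+    #    [0, 1, 0]]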
+
+ @staticmethod
+ def junc_to_junc_map(junctions, image_size):
+ """ Convert junction points to junction maps. """
+        junctions = np.round(junctions).astype(int)
+ # Clip the boundary by image size
+ junctions[:, 0] = np.clip(junctions[:, 0], 0., image_size[0]-1)
+ junctions[:, 1] = np.clip(junctions[:, 1], 0., image_size[1]-1)
+
+ # Create junction map
+ junc_map = np.zeros([image_size[0], image_size[1]])
+ junc_map[junctions[:, 0], junctions[:, 1]] = 1
+
+        return junc_map[..., None].astype(int)
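+
+    # Illustrative example (hypothetical values): with image_size=(4, 4) and
+    # junctions=np.array([[1.2, 2.7]]), the point is rounded to (1, 3) and the
+    # returned map has shape (4, 4, 1) with a single 1 at [1, 3, 0].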
+
+ def parse_transforms(self, names, all_transforms):
+ """ Parse the transform. """
+ trans = all_transforms if (names == 'all') \
+ else (names if isinstance(names, list) else [names])
+ assert set(trans) <= set(all_transforms)
+ return trans
+
+ def get_photo_transform(self):
+ """ Get list of photometric transforms (according to the config). """
+ # Get the photometric transform config
+ photo_config = self.config["augmentation"]["photometric"]
+ if not photo_config["enable"]:
+ raise ValueError(
+ "[Error] Photometric augmentation is not enabled.")
+
+ # Parse photometric transforms
+ trans_lst = self.parse_transforms(photo_config["primitives"],
+ photoaug.available_augmentations)
+ trans_config_lst = [photo_config["params"].get(p, {})
+ for p in trans_lst]
+
+ # List of photometric augmentation
+ photometric_trans_lst = [
+ getattr(photoaug, trans)(**conf) \
+ for (trans, conf) in zip(trans_lst, trans_config_lst)
+ ]
+
+ return photometric_trans_lst
+
+ def get_homo_transform(self):
+ """ Get homographic transforms (according to the config). """
+ # Get homographic transforms for image
+ homo_config = self.config["augmentation"]["homographic"]["params"]
+ if not self.config["augmentation"]["homographic"]["enable"]:
+ raise ValueError(
+ "[Error] Homographic augmentation is not enabled")
+
+ # Parse the homographic transforms
+ image_shape = self.config["preprocessing"]["resize"]
+
+ # Compute the min_label_len from config
+        try:
+            min_label_tmp = self.config["generation"]["min_label_len"]
+        except (KeyError, TypeError):
+            min_label_tmp = None
+
+ # float label len => fraction
+ if isinstance(min_label_tmp, float): # Skip if not provided
+ min_label_len = min_label_tmp * min(image_shape)
+ # int label len => length in pixel
+ elif isinstance(min_label_tmp, int):
+ scale_ratio = (self.config["preprocessing"]["resize"]
+ / self.config["generation"]["image_size"][0])
+ min_label_len = (self.config["generation"]["min_label_len"]
+ * scale_ratio)
+ # if none => no restriction
+ else:
+ min_label_len = 0
+
+ # Initialize the transform
+ homographic_trans = homoaug.homography_transform(
+ image_shape, homo_config, 0, min_label_len)
+
+ return homographic_trans
+
+ def get_line_points(self, junctions, line_map, H1=None, H2=None,
+ img_size=None, warp=False):
+        """ Sample points evenly along each line segment
+            and keep track of the line index. """
+ if np.sum(line_map) == 0:
+ # No segment detected in the image
+ line_indices = np.zeros(self.config["max_pts"], dtype=int)
+ line_points = np.zeros((self.config["max_pts"], 2), dtype=float)
+ return line_points, line_indices
+
+ # Extract all pairs of connected junctions
+ junc_indices = np.array(
+ [[i, j] for (i, j) in zip(*np.where(line_map)) if j > i])
+ line_segments = np.stack([junctions[junc_indices[:, 0]],
+ junctions[junc_indices[:, 1]]], axis=1)
+ # line_segments is (num_lines, 2, 2)
+ line_lengths = np.linalg.norm(
+ line_segments[:, 0] - line_segments[:, 1], axis=1)
+
+ # Sample the points separated by at least min_dist_pts along each line
+ # The number of samples depends on the length of the line
+ num_samples = np.minimum(line_lengths // self.config["min_dist_pts"],
+ self.config["max_num_samples"])
+ line_points = []
+ line_indices = []
+ cur_line_idx = 1
+ for n in np.arange(2, self.config["max_num_samples"] + 1):
+ # Consider all lines where we can fit up to n points
+ cur_line_seg = line_segments[num_samples == n]
+ line_points_x = np.linspace(cur_line_seg[:, 0, 0],
+ cur_line_seg[:, 1, 0],
+ n, axis=-1).flatten()
+ line_points_y = np.linspace(cur_line_seg[:, 0, 1],
+ cur_line_seg[:, 1, 1],
+ n, axis=-1).flatten()
+ jitter = self.config.get("jittering", 0)
+ if jitter:
+ # Add a small random jittering of all points along the line
+ angles = np.arctan2(
+ cur_line_seg[:, 1, 0] - cur_line_seg[:, 0, 0],
+ cur_line_seg[:, 1, 1] - cur_line_seg[:, 0, 1]).repeat(n)
+ jitter_hyp = (np.random.rand(len(angles)) * 2 - 1) * jitter
+ line_points_x += jitter_hyp * np.sin(angles)
+ line_points_y += jitter_hyp * np.cos(angles)
+ line_points.append(np.stack([line_points_x, line_points_y], axis=-1))
+ # Keep track of the line indices for each sampled point
+ num_cur_lines = len(cur_line_seg)
+ line_idx = np.arange(cur_line_idx, cur_line_idx + num_cur_lines)
+ line_indices.append(line_idx.repeat(n))
+ cur_line_idx += num_cur_lines
+ line_points = np.concatenate(line_points,
+ axis=0)[:self.config["max_pts"]]
+ line_indices = np.concatenate(line_indices,
+ axis=0)[:self.config["max_pts"]]
+
+        # Warp the points if needed, and filter out invalid ones
+ # If the other view is also warped
+ if warp and H2 is not None:
+ warp_points2 = warp_points(line_points, H2)
+ line_points = warp_points(line_points, H1)
+ mask = mask_points(line_points, img_size)
+ mask2 = mask_points(warp_points2, img_size)
+ mask = mask * mask2
+ # If the other view is not warped
+ elif warp and H2 is None:
+ line_points = warp_points(line_points, H1)
+ mask = mask_points(line_points, img_size)
+ else:
+ if H1 is not None:
+ raise ValueError("[Error] Wrong combination of homographies.")
+ # Remove points that would be outside of img_size if warped by H
+ warped_points = warp_points(line_points, H1)
+ mask = mask_points(warped_points, img_size)
+ line_points = line_points[mask]
+ line_indices = line_indices[mask]
+
+ # Pad the line points to a fixed length
+ # Index of 0 means padded line
+ line_indices = np.concatenate([line_indices, np.zeros(
+ self.config["max_pts"] - len(line_indices))], axis=0)
+ line_points = np.concatenate(
+ [line_points,
+ np.zeros((self.config["max_pts"] - len(line_points), 2),
+ dtype=float)], axis=0)
+
+ return line_points, line_indices
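+
+    # Sampling sketch (example values taken from the wireframe config):
+    # with min_dist_pts=10 and max_num_samples=10, a 35-pixel segment yields
+    # floor(35 / 10) = 3 evenly spaced points (endpoints included), all
+    # tagged with the same 1-based line index; index 0 is reserved for the
+    # padding points appended above.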
+
+ def export_preprocessing(self, data, numpy=False):
+ """ Preprocess the exported data. """
+ # Fetch the corresponding entries
+ image = data["image"]
+ image_size = image.shape[:2]
+
+ # Resize the image before photometric and homographical augmentations
+        if list(image_size) != self.config["preprocessing"]["resize"]:
+ # Resize the image and the point location.
+ size_old = list(image.shape)[:2] # Only H and W dimensions
+
+ image = cv2.resize(
+ image, tuple(self.config['preprocessing']['resize'][::-1]),
+ interpolation=cv2.INTER_LINEAR)
+ image = np.array(image, dtype=np.uint8)
+
+ # Optionally convert the image to grayscale
+ if self.config["gray_scale"]:
+ image = (color.rgb2gray(image) * 255.).astype(np.uint8)
+
+ image = photoaug.normalize_image()(image)
+
+ # Convert to tensor and return the results
+ to_tensor = transforms.ToTensor()
+ if not numpy:
+ return {"image": to_tensor(image)}
+ else:
+ return {"image": image}
+
+ def train_preprocessing_exported(
+ self, data, numpy=False, disable_homoaug=False, desc_training=False,
+ H1=None, H1_scale=None, H2=None, scale=1., h_crop=None, w_crop=None):
+ """ Train preprocessing for the exported labels. """
+ data = copy.deepcopy(data)
+ # Fetch the corresponding entries
+ image = data["image"]
+ junctions = data["junctions"]
+ line_map = data["line_map"]
+ image_size = image.shape[:2]
+
+ # Define the random crop for scaling if necessary
+ if h_crop is None or w_crop is None:
+ h_crop, w_crop = 0, 0
+ if scale > 1:
+ H, W = self.config["preprocessing"]["resize"]
+ H_scale, W_scale = round(H * scale), round(W * scale)
+ if H_scale > H:
+ h_crop = np.random.randint(H_scale - H)
+ if W_scale > W:
+ w_crop = np.random.randint(W_scale - W)
+
+ # Resize the image before photometric and homographical augmentations
+        if list(image_size) != self.config["preprocessing"]["resize"]:
+ # Resize the image and the point location.
+ size_old = list(image.shape)[:2] # Only H and W dimensions
+
+ image = cv2.resize(
+ image, tuple(self.config['preprocessing']['resize'][::-1]),
+ interpolation=cv2.INTER_LINEAR)
+ image = np.array(image, dtype=np.uint8)
+
+ # # In HW format
+ # junctions = (junctions * np.array(
+ # self.config['preprocessing']['resize'], np.float)
+ # / np.array(size_old, np.float))
+
+ # Generate the line heatmap after post-processing
+ junctions_xy = np.flip(np.round(junctions).astype(np.int32), axis=1)
+ image_size = image.shape[:2]
+ heatmap = get_line_heatmap(junctions_xy, line_map, image_size)
+
+ # Optionally convert the image to grayscale
+ if self.config["gray_scale"]:
+ image = (color.rgb2gray(image) * 255.).astype(np.uint8)
+
+ # Check if we need to apply augmentations
+ # In training mode => yes.
+ # In homography adaptation mode (export mode) => No
+ if self.config["augmentation"]["photometric"]["enable"]:
+ photo_trans_lst = self.get_photo_transform()
+ ### Image transform ###
+ np.random.shuffle(photo_trans_lst)
+ image_transform = transforms.Compose(
+ photo_trans_lst + [photoaug.normalize_image()])
+ else:
+ image_transform = photoaug.normalize_image()
+ image = image_transform(image)
+
+ # Perform the random scaling
+ if scale != 1.:
+ image, junctions, line_map, valid_mask = random_scaling(
+ image, junctions, line_map, scale,
+ h_crop=h_crop, w_crop=w_crop)
+ else:
+ # Declare default valid mask (all ones)
+ valid_mask = np.ones(image_size)
+
+ # Initialize the empty output dict
+ outputs = {}
+ # Convert to tensor and return the results
+ to_tensor = transforms.ToTensor()
+
+ # Check homographic augmentation
+ warp = (self.config["augmentation"]["homographic"]["enable"]
+                and not disable_homoaug)
+ if warp:
+ homo_trans = self.get_homo_transform()
+ # Perform homographic transform
+ if H1 is None:
+ homo_outputs = homo_trans(image, junctions, line_map,
+ valid_mask=valid_mask)
+ else:
+ homo_outputs = homo_trans(
+ image, junctions, line_map, homo=H1, scale=H1_scale,
+ valid_mask=valid_mask)
+ homography_mat = homo_outputs["homo"]
+
+ # Give the warp of the other view
+ if H1 is None:
+ H1 = homo_outputs["homo"]
+
+ # Sample points along each line segments for the descriptor
+ if desc_training:
+ line_points, line_indices = self.get_line_points(
+ junctions, line_map, H1=H1, H2=H2,
+ img_size=image_size, warp=warp)
+
+ # Record the warped results
+ if warp:
+ junctions = homo_outputs["junctions"] # Should be HW format
+ image = homo_outputs["warped_image"]
+ line_map = homo_outputs["line_map"]
+ valid_mask = homo_outputs["valid_mask"] # Same for pos and neg
+ heatmap = homo_outputs["warped_heatmap"]
+
+ # Optionally put warping information first.
+ if not numpy:
+ outputs["homography_mat"] = to_tensor(
+ homography_mat).to(torch.float32)[0, ...]
+ else:
+ outputs["homography_mat"] = homography_mat.astype(np.float32)
+
+ junction_map = self.junc_to_junc_map(junctions, image_size)
+
+ if not numpy:
+ outputs.update({
+ "image": to_tensor(image),
+ "junctions": to_tensor(junctions).to(torch.float32)[0, ...],
+ "junction_map": to_tensor(junction_map).to(torch.int),
+ "line_map": to_tensor(line_map).to(torch.int32)[0, ...],
+ "heatmap": to_tensor(heatmap).to(torch.int32),
+ "valid_mask": to_tensor(valid_mask).to(torch.int32)
+ })
+ if desc_training:
+ outputs.update({
+ "line_points": to_tensor(
+ line_points).to(torch.float32)[0],
+ "line_indices": torch.tensor(line_indices,
+ dtype=torch.int)
+ })
+ else:
+ outputs.update({
+ "image": image,
+ "junctions": junctions.astype(np.float32),
+ "junction_map": junction_map.astype(np.int32),
+ "line_map": line_map.astype(np.int32),
+ "heatmap": heatmap.astype(np.int32),
+ "valid_mask": valid_mask.astype(np.int32)
+ })
+ if desc_training:
+ outputs.update({
+ "line_points": line_points.astype(np.float32),
+ "line_indices": line_indices.astype(int)
+ })
+
+ return outputs
+
+ def preprocessing_exported_paired_desc(self, data, numpy=False, scale=1.):
+ """ Train preprocessing for paired data for the exported labels
+ for descriptor training. """
+ outputs = {}
+
+ # Define the random crop for scaling if necessary
+ h_crop, w_crop = 0, 0
+ if scale > 1:
+ H, W = self.config["preprocessing"]["resize"]
+ H_scale, W_scale = round(H * scale), round(W * scale)
+ if H_scale > H:
+ h_crop = np.random.randint(H_scale - H)
+ if W_scale > W:
+ w_crop = np.random.randint(W_scale - W)
+
+ # Sample ref homography first
+ homo_config = self.config["augmentation"]["homographic"]["params"]
+ image_shape = self.config["preprocessing"]["resize"]
+ ref_H, ref_scale = homoaug.sample_homography(image_shape,
+ **homo_config)
+
+ # Data for target view (All augmentation)
+ target_data = self.train_preprocessing_exported(
+ data, numpy=numpy, desc_training=True, H1=None, H2=ref_H,
+ scale=scale, h_crop=h_crop, w_crop=w_crop)
+
+ # Data for reference view (No homographical augmentation)
+ ref_data = self.train_preprocessing_exported(
+ data, numpy=numpy, desc_training=True, H1=ref_H,
+ H1_scale=ref_scale, H2=target_data['homography_mat'].numpy(),
+ scale=scale, h_crop=h_crop, w_crop=w_crop)
+
+ # Spread ref data
+ for key, val in ref_data.items():
+ outputs["ref_" + key] = val
+
+ # Spread target data
+ for key, val in target_data.items():
+ outputs["target_" + key] = val
+
+ return outputs
+
+ def test_preprocessing_exported(self, data, numpy=False):
+ """ Test preprocessing for the exported labels. """
+ data = copy.deepcopy(data)
+ # Fetch the corresponding entries
+ image = data["image"]
+ junctions = data["junctions"]
+ line_map = data["line_map"]
+ image_size = image.shape[:2]
+
+ # Resize the image before photometric and homographical augmentations
+        if list(image_size) != self.config["preprocessing"]["resize"]:
+ # Resize the image and the point location.
+ size_old = list(image.shape)[:2] # Only H and W dimensions
+
+ image = cv2.resize(
+ image, tuple(self.config['preprocessing']['resize'][::-1]),
+ interpolation=cv2.INTER_LINEAR)
+ image = np.array(image, dtype=np.uint8)
+
+ # # In HW format
+ # junctions = (junctions * np.array(
+ # self.config['preprocessing']['resize'], np.float)
+ # / np.array(size_old, np.float))
+
+ # Optionally convert the image to grayscale
+ if self.config["gray_scale"]:
+ image = (color.rgb2gray(image) * 255.).astype(np.uint8)
+
+ # Still need to normalize image
+ image_transform = photoaug.normalize_image()
+ image = image_transform(image)
+
+ # Generate the line heatmap after post-processing
+ junctions_xy = np.flip(np.round(junctions).astype(np.int32), axis=1)
+ image_size = image.shape[:2]
+ heatmap = get_line_heatmap(junctions_xy, line_map, image_size)
+
+ # Declare default valid mask (all ones)
+ valid_mask = np.ones(image_size)
+
+ junction_map = self.junc_to_junc_map(junctions, image_size)
+
+ # Convert to tensor and return the results
+ to_tensor = transforms.ToTensor()
+ if not numpy:
+ outputs = {
+ "image": to_tensor(image),
+ "junctions": to_tensor(junctions).to(torch.float32)[0, ...],
+ "junction_map": to_tensor(junction_map).to(torch.int),
+ "line_map": to_tensor(line_map).to(torch.int32)[0, ...],
+ "heatmap": to_tensor(heatmap).to(torch.int32),
+ "valid_mask": to_tensor(valid_mask).to(torch.int32)
+ }
+ else:
+ outputs = {
+ "image": image,
+ "junctions": junctions.astype(np.float32),
+ "junction_map": junction_map.astype(np.int32),
+ "line_map": line_map.astype(np.int32),
+ "heatmap": heatmap.astype(np.int32),
+ "valid_mask": valid_mask.astype(np.int32)
+ }
+
+ return outputs
+
+ def __len__(self):
+ return self.dataset_length
+
+ def get_data_from_key(self, file_key):
+ """ Get data from file_key. """
+ # Check key exists
+        if file_key not in self.filename_dataset:
+ raise ValueError(
+ "[Error] the specified key is not in the dataset.")
+
+ # Get the data paths
+ data_path = self.filename_dataset[file_key]
+ # Read in the image and npz labels
+ data = self.get_data_from_path(data_path)
+
+ # Perform transform and augmentation
+ if (self.mode == "train"
+ or self.config["add_augmentation_to_all_splits"]):
+ data = self.train_preprocessing(data, numpy=True)
+ else:
+ data = self.test_preprocessing(data, numpy=True)
+
+ # Add file key to the output
+ data["file_key"] = file_key
+
+ return data
+
+ def __getitem__(self, idx):
+ """Return data
+ file_key: str, keys used to retrieve data from the filename dataset.
+ image: torch.float, C*H*W range 0~1,
+ junctions: torch.float, N*2,
+ junction_map: torch.int32, 1*H*W range 0 or 1,
+ line_map: torch.int32, N*N range 0 or 1,
+ heatmap: torch.int32, 1*H*W range 0 or 1,
+ valid_mask: torch.int32, 1*H*W range 0 or 1
+ """
+ # Get the corresponding datapoint and contents from filename dataset
+ file_key = self.datapoints[idx]
+ data_path = self.filename_dataset[file_key]
+ # Read in the image and npz labels
+ data = self.get_data_from_path(data_path)
+
+ if self.gt_source:
+ with h5py.File(self.gt_source, "r") as f:
+ exported_label = parse_h5_data(f[file_key])
+
+ data["junctions"] = exported_label["junctions"]
+ data["line_map"] = exported_label["line_map"]
+
+ # Perform transform and augmentation
+ return_type = self.config.get("return_type", "single")
+ if self.gt_source is None:
+ # For export only
+ data = self.export_preprocessing(data)
+ elif (self.mode == "train"
+ or self.config["add_augmentation_to_all_splits"]):
+ # Perform random scaling first
+ if self.config["augmentation"]["random_scaling"]["enable"]:
+ scale_range = self.config["augmentation"]["random_scaling"]["range"]
+ # Decide the scaling
+ scale = np.random.uniform(min(scale_range), max(scale_range))
+ else:
+ scale = 1.
+ if self.mode == "train" and return_type == "paired_desc":
+ data = self.preprocessing_exported_paired_desc(data,
+ scale=scale)
+ else:
+ data = self.train_preprocessing_exported(data, scale=scale)
+ else:
+ if return_type == "paired_desc":
+ data = self.preprocessing_exported_paired_desc(data)
+ else:
+ data = self.test_preprocessing_exported(data)
+
+ # Add file key to the output
+ data["file_key"] = file_key
+
+ return data
+
diff --git a/imcui/third_party/SOLD2/sold2/dataset/merge_dataset.py b/imcui/third_party/SOLD2/sold2/dataset/merge_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..178d3822d56639a49a99f68e392330e388fa8fc3
--- /dev/null
+++ b/imcui/third_party/SOLD2/sold2/dataset/merge_dataset.py
@@ -0,0 +1,37 @@
+""" Compose multiple datasets in a single loader. """
+
+import numpy as np
+from copy import deepcopy
+from torch.utils.data import Dataset
+
+from .wireframe_dataset import WireframeDataset
+from .holicity_dataset import HolicityDataset
+
+
+class MergeDataset(Dataset):
+ def __init__(self, mode, config=None):
+ super(MergeDataset, self).__init__()
+ # Initialize the datasets
+ self._datasets = []
+ spec_config = deepcopy(config)
+ for i, d in enumerate(config['datasets']):
+ spec_config['dataset_name'] = d
+ spec_config['gt_source_train'] = config['gt_source_train'][i]
+ spec_config['gt_source_test'] = config['gt_source_test'][i]
+ if d == "wireframe":
+ self._datasets.append(WireframeDataset(mode, spec_config))
+ elif d == "holicity":
+ spec_config['train_split'] = config['train_splits'][i]
+ self._datasets.append(HolicityDataset(mode, spec_config))
+ else:
+ raise ValueError("Unknown dataset: " + d)
+
+ self._weights = config['weights']
+
+ def __getitem__(self, item):
+ dataset = self._datasets[np.random.choice(
+ range(len(self._datasets)), p=self._weights)]
+ return dataset[np.random.randint(len(dataset))]
+
+ def __len__(self):
+ return np.sum([len(d) for d in self._datasets])
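+
+
+# A minimal config sketch (keys inferred from __init__ above; the file names
+# and weights are illustrative placeholders, not shipped defaults):
+#
+#   merge_cfg = {
+#       "datasets": ["wireframe", "holicity"],
+#       "weights": [0.5, 0.5],
+#       "gt_source_train": ["wireframe_train.h5", "holicity_train.h5"],
+#       "gt_source_test": ["wireframe_test.h5", "holicity_test.h5"],
+#       # ... plus the train splits, preprocessing and augmentation fields
+#       # expected by the wrapped datasets.
+#   }
+#   dataset = MergeDataset("train", merge_cfg)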
diff --git a/imcui/third_party/SOLD2/sold2/dataset/synthetic_dataset.py b/imcui/third_party/SOLD2/sold2/dataset/synthetic_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..cf5f11e5407e65887f4995291156f7cc361843d1
--- /dev/null
+++ b/imcui/third_party/SOLD2/sold2/dataset/synthetic_dataset.py
@@ -0,0 +1,712 @@
+"""
+This file implements the synthetic shape dataset object for pytorch
+"""
+from __future__ import print_function
+from __future__ import division
+from __future__ import absolute_import
+
+import os
+import math
+import h5py
+import pickle
+import torch
+import numpy as np
+import cv2
+from tqdm import tqdm
+from torchvision import transforms
+from torch.utils.data import Dataset
+import torch.utils.data.dataloader as torch_loader
+
+from ..config.project_config import Config as cfg
+from . import synthetic_util
+from .transforms import photometric_transforms as photoaug
+from .transforms import homographic_transforms as homoaug
+from ..misc.train_utils import parse_h5_data
+
+
+def synthetic_collate_fn(batch):
+ """ Customized collate_fn. """
+ batch_keys = ["image", "junction_map", "heatmap",
+ "valid_mask", "homography"]
+ list_keys = ["junctions", "line_map", "file_key"]
+
+ outputs = {}
+ for data_key in batch[0].keys():
+ batch_match = sum([_ in data_key for _ in batch_keys])
+ list_match = sum([_ in data_key for _ in list_keys])
+ # print(batch_match, list_match)
+ if batch_match > 0 and list_match == 0:
+ outputs[data_key] = torch_loader.default_collate([b[data_key]
+ for b in batch])
+ elif batch_match == 0 and list_match > 0:
+ outputs[data_key] = [b[data_key] for b in batch]
+ elif batch_match == 0 and list_match == 0:
+ continue
+ else:
+ raise ValueError(
+ "[Error] A key matches batch keys and list keys simultaneously.")
+
+ return outputs
+
+
+class SyntheticShapes(Dataset):
+ """ Dataset of synthetic shapes. """
+ # Initialize the dataset
+ def __init__(self, mode="train", config=None):
+ super(SyntheticShapes, self).__init__()
+        if mode not in ["train", "val", "test"]:
+ raise ValueError(
+ "[Error] Supported dataset modes are 'train', 'val', and 'test'.")
+ self.mode = mode
+
+        # Get the configuration (also keep the default config around, since
+        # get_dataset_name() and get_cache_name() fall back to it)
+        self.default_config = self.get_default_config()
+        self.config = self.default_config if config is None else config
+
+ # Set all available primitives
+ self.available_primitives = [
+ 'draw_lines',
+ 'draw_polygon',
+ 'draw_multiple_polygons',
+ 'draw_star',
+ 'draw_checkerboard_multiseg',
+ 'draw_stripes_multiseg',
+ 'draw_cube',
+ 'gaussian_noise'
+ ]
+
+ # Some cache setting
+ self.dataset_name = self.get_dataset_name()
+ self.cache_name = self.get_cache_name()
+ self.cache_path = cfg.synthetic_cache_path
+
+ # Check if export dataset exists
+ print("===============================================")
+ self.filename_dataset, self.datapoints = self.construct_dataset()
+ self.print_dataset_info()
+
+ # Initialize h5 file handle
+ self.dataset_path = os.path.join(cfg.synthetic_dataroot, self.dataset_name + ".h5")
+
+ # Fix the random seed for torch and numpy in testing mode
+ if ((self.mode == "val" or self.mode == "test")
+ and self.config["add_augmentation_to_all_splits"]):
+ seed = self.config.get("test_augmentation_seed", 200)
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+ # For CuDNN
+ torch.backends.cudnn.deterministic = True
+ torch.backends.cudnn.benchmark = False
+
+ ##########################################
+ ## Dataset construction related methods ##
+ ##########################################
+ def construct_dataset(self):
+ """ Dataset constructor. """
+ # Check if the filename cache exists
+ # If cache exists, load from cache
+ if self._check_dataset_cache():
+ print("[Info]: Found filename cache at ...")
+ print("\t Load filename cache...")
+ filename_dataset, datapoints = self.get_filename_dataset_from_cache()
+            print("\t Check if all files exist...")
+            # If all files exist, continue
+ if self._check_file_existence(filename_dataset):
+ print("\t All files exist!")
+ # If not, need to re-export the synthetic dataset
+ else:
+ print("\t Some files are missing. Re-export the synthetic shape dataset.")
+ self.export_synthetic_shapes()
+ print("\t Initialize filename dataset")
+ filename_dataset, datapoints = self.get_filename_dataset()
+ print("\t Create filename dataset cache...")
+ self.create_filename_dataset_cache(filename_dataset,
+ datapoints)
+
+ # If not, initialize dataset from scratch
+ else:
+ print("[Info]: Can't find filename cache ...")
+            print("\t First check if the exported dataset exists.")
+ # If export dataset exists, then just update the filename_dataset
+ if self._check_export_dataset():
+ print("\t Synthetic dataset exists. Initialize the dataset ...")
+
+ # If export dataset does not exist, export from scratch
+ else:
+ print("\t Synthetic dataset does not exist. Export the synthetic dataset.")
+ self.export_synthetic_shapes()
+ print("\t Initialize filename dataset")
+
+ filename_dataset, datapoints = self.get_filename_dataset()
+ print("\t Create filename dataset cache...")
+ self.create_filename_dataset_cache(filename_dataset, datapoints)
+
+ return filename_dataset, datapoints
+
+ def get_cache_name(self):
+ """ Get cache name from dataset config / default config. """
+ if self.config["dataset_name"] is None:
+ dataset_name = self.default_config["dataset_name"] + "_%s" % self.mode
+ else:
+ dataset_name = self.config["dataset_name"] + "_%s" % self.mode
+ # Compose cache name
+ cache_name = dataset_name + "_cache.pkl"
+
+ return cache_name
+
+ def get_dataset_name(self):
+ """Get dataset name from dataset config / default config. """
+ if self.config["dataset_name"] is None:
+ dataset_name = self.default_config["dataset_name"] + "_%s" % self.mode
+ else:
+ dataset_name = self.config["dataset_name"] + "_%s" % self.mode
+
+ return dataset_name
+
+ def get_filename_dataset_from_cache(self):
+ """ Get filename dataset from cache. """
+ # Load from the pkl cache
+ cache_file_path = os.path.join(self.cache_path, self.cache_name)
+ with open(cache_file_path, "rb") as f:
+ data = pickle.load(f)
+
+ return data["filename_dataset"], data["datapoints"]
+
+ def get_filename_dataset(self):
+ """ Get filename dataset from scratch. """
+ # Path to the exported dataset
+ dataset_path = os.path.join(cfg.synthetic_dataroot,
+ self.dataset_name + ".h5")
+
+ filename_dataset = {}
+ datapoints = []
+ # Open the h5 dataset
+ with h5py.File(dataset_path, "r") as f:
+ # Iterate through all the primitives
+ for prim_name in f.keys():
+ filenames = sorted(f[prim_name].keys())
+ filenames_full = [os.path.join(prim_name, _)
+ for _ in filenames]
+
+ filename_dataset[prim_name] = filenames_full
+ datapoints += filenames_full
+
+ return filename_dataset, datapoints
+
+ def create_filename_dataset_cache(self, filename_dataset, datapoints):
+ """ Create filename dataset cache for faster initialization. """
+ # Check cache path exists
+ if not os.path.exists(self.cache_path):
+ os.makedirs(self.cache_path)
+
+ cache_file_path = os.path.join(self.cache_path, self.cache_name)
+ data = {
+ "filename_dataset": filename_dataset,
+ "datapoints": datapoints
+ }
+ with open(cache_file_path, "wb") as f:
+ pickle.dump(data, f, pickle.HIGHEST_PROTOCOL)
+
+ def export_synthetic_shapes(self):
+ """ Export synthetic shapes to disk. """
+ # Set the global random state for data generation
+ synthetic_util.set_random_state(np.random.RandomState(
+ self.config["generation"]["random_seed"]))
+
+ # Define the export path
+ dataset_path = os.path.join(cfg.synthetic_dataroot,
+ self.dataset_name + ".h5")
+
+ # Open h5py file
+ with h5py.File(dataset_path, "w", libver="latest") as f:
+ # Iterate through all types of shape
+ primitives = self.parse_drawing_primitives(
+ self.config["primitives"])
+ split_size = self.config["generation"]["split_sizes"][self.mode]
+ for prim in primitives:
+ # Create h5 group
+ group = f.create_group(prim)
+ # Export single primitive
+ self.export_single_primitive(prim, split_size, group)
+
+ f.swmr_mode = True
+
+ def export_single_primitive(self, primitive, split_size, group):
+ """ Export single primitive. """
+ # Check if the primitive is valid or not
+ if primitive not in self.available_primitives:
+ raise ValueError(
+ "[Error]: %s is not a supported primitive" % primitive)
+ # Set the random seed
+ synthetic_util.set_random_state(np.random.RandomState(
+ self.config["generation"]["random_seed"]))
+
+ # Generate shapes
+ print("\t Generating %s ..." % primitive)
+ for idx in tqdm(range(split_size), ascii=True):
+ # Generate background image
+ image = synthetic_util.generate_background(
+ self.config['generation']['image_size'],
+ **self.config['generation']['params']['generate_background'])
+
+ # Generate points
+ drawing_func = getattr(synthetic_util, primitive)
+ kwarg = self.config["generation"]["params"].get(primitive, {})
+
+ # Get min_len and min_label_len
+ min_len = self.config["generation"]["min_len"]
+ min_label_len = self.config["generation"]["min_label_len"]
+
+            # Some primitives only take min_label_len; gaussian_noise takes neither
+ if primitive in ["draw_lines", "draw_polygon",
+ "draw_multiple_polygons", "draw_star"]:
+ data = drawing_func(image, min_len=min_len,
+ min_label_len=min_label_len, **kwarg)
+ elif primitive in ["draw_checkerboard_multiseg",
+ "draw_stripes_multiseg", "draw_cube"]:
+ data = drawing_func(image, min_label_len=min_label_len,
+ **kwarg)
+ else:
+ data = drawing_func(image, **kwarg)
+
+ # Convert the data
+ if data["points"] is not None:
+                points = np.flip(data["points"], axis=1).astype(float)
+ line_map = data["line_map"].astype(np.int32)
+ else:
+                points = np.zeros([0, 2]).astype(float)
+ line_map = np.zeros([0, 0]).astype(np.int32)
+
+ # Post-processing
+ blur_size = self.config["preprocessing"]["blur_size"]
+ image = cv2.GaussianBlur(image, (blur_size, blur_size), 0)
+
+ # Resize the image and the point location.
+            points = (points
+                      * np.array(self.config['preprocessing']['resize'],
+                                 float)
+                      / np.array(self.config['generation']['image_size'],
+                                 float))
+ image = cv2.resize(
+ image, tuple(self.config['preprocessing']['resize'][::-1]),
+ interpolation=cv2.INTER_LINEAR)
+ image = np.array(image, dtype=np.uint8)
+
+ # Generate the line heatmap after post-processing
+ junctions = np.flip(np.round(points).astype(np.int32), axis=1)
+ heatmap = (synthetic_util.get_line_heatmap(
+ junctions, line_map,
+ size=image.shape) * 255.).astype(np.uint8)
+
+ # Record the data in group
+ num_pad = math.ceil(math.log10(split_size)) + 1
+ file_key_name = self.get_padded_filename(num_pad, idx)
+ file_group = group.create_group(file_key_name)
+
+ # Store data
+ file_group.create_dataset("points", data=points,
+ compression="gzip")
+ file_group.create_dataset("image", data=image,
+ compression="gzip")
+ file_group.create_dataset("line_map", data=line_map,
+ compression="gzip")
+ file_group.create_dataset("heatmap", data=heatmap,
+ compression="gzip")
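+
+    # Resulting HDF5 layout (sketch): one group per primitive and one
+    # sub-group per sample, each holding the "points", "image", "line_map"
+    # and "heatmap" datasets, e.g. f["draw_lines"]["00042"]["image"].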
+
+ def get_default_config(self):
+ """ Get default configuration of the dataset. """
+ # Initialize the default configuration
+ self.default_config = {
+ "dataset_name": "synthetic_shape",
+ "primitives": "all",
+ "add_augmentation_to_all_splits": False,
+ # Shape generation configuration
+ "generation": {
+ "split_sizes": {'train': 10000, 'val': 400, 'test': 500},
+ "random_seed": 10,
+ "image_size": [960, 1280],
+ "min_len": 0.09,
+ "min_label_len": 0.1,
+ 'params': {
+ 'generate_background': {
+ 'min_kernel_size': 150, 'max_kernel_size': 500,
+ 'min_rad_ratio': 0.02, 'max_rad_ratio': 0.031},
+ 'draw_stripes': {'transform_params': (0.1, 0.1)},
+ 'draw_multiple_polygons': {'kernel_boundaries': (50, 100)}
+ },
+ },
+            # Data preprocessing configuration.
+ "preprocessing": {
+ "resize": [240, 320],
+ "blur_size": 11
+ },
+ 'augmentation': {
+ 'photometric': {
+ 'enable': False,
+ 'primitives': 'all',
+ 'params': {},
+ 'random_order': True,
+ },
+ 'homographic': {
+ 'enable': False,
+ 'params': {},
+ 'valid_border_margin': 0,
+ },
+ }
+ }
+
+ return self.default_config
+
+ def parse_drawing_primitives(self, names):
+ """ Parse the primitives in config to list of primitive names. """
+ if names == "all":
+ p = self.available_primitives
+ else:
+ if isinstance(names, list):
+ p = names
+ else:
+ p = [names]
+
+ assert set(p) <= set(self.available_primitives)
+
+ return p
+
+ @staticmethod
+ def get_padded_filename(num_pad, idx):
+ """ Get the padded filename using adaptive padding. """
+ file_len = len("%d" % (idx))
+ filename = "0" * (num_pad - file_len) + "%d" % (idx)
+
+ return filename
+
+ def print_dataset_info(self):
+ """ Print dataset info. """
+ print("\t ---------Summary------------------")
+ print("\t Dataset mode: \t\t %s" % self.mode)
+        print("\t Number of primitives: \t %d" % len(self.filename_dataset.keys()))
+        print("\t Number of datapoints: \t %d" % len(self.datapoints))
+ print("\t ----------------------------------")
+
+ #########################
+ ## Pytorch related API ##
+ #########################
+ def get_data_from_datapoint(self, datapoint, reader=None):
+ """ Get data given the datapoint
+            (keyname of the h5 dataset, e.g. "draw_lines/00000"). """
+ # Check if the datapoint is valid
+        if datapoint not in self.datapoints:
+ raise ValueError(
+ "[Error] The specified datapoint is not in available datapoints.")
+
+ # Get data from h5 dataset
+ if reader is None:
+ raise ValueError(
+ "[Error] The reader must be provided in __getitem__.")
+ else:
+ data = reader[datapoint]
+
+ return parse_h5_data(data)
+
+ def get_data_from_signature(self, primitive_name, index):
+ """ Get data given the primitive name and index ("draw_lines", 10) """
+ # Check the primitive name and index
+ self._check_primitive_and_index(primitive_name, index)
+
+ # Get the datapoint from filename dataset
+ datapoint = self.filename_dataset[primitive_name][index]
+
+ return self.get_data_from_datapoint(datapoint)
+
+ def parse_transforms(self, names, all_transforms):
+ trans = all_transforms if (names == 'all') \
+ else (names if isinstance(names, list) else [names])
+ assert set(trans) <= set(all_transforms)
+ return trans
+
+ def get_photo_transform(self):
+ """ Get list of photometric transforms (according to the config). """
+ # Get the photometric transform config
+ photo_config = self.config["augmentation"]["photometric"]
+ if not photo_config["enable"]:
+ raise ValueError(
+ "[Error] Photometric augmentation is not enabled.")
+
+ # Parse photometric transforms
+ trans_lst = self.parse_transforms(photo_config["primitives"],
+ photoaug.available_augmentations)
+ trans_config_lst = [photo_config["params"].get(p, {})
+ for p in trans_lst]
+
+ # List of photometric augmentation
+ photometric_trans_lst = [
+ getattr(photoaug, trans)(**conf) \
+ for (trans, conf) in zip(trans_lst, trans_config_lst)
+ ]
+
+ return photometric_trans_lst
+
+ def get_homo_transform(self):
+ """ Get homographic transforms (according to the config). """
+ # Get homographic transforms for image
+ homo_config = self.config["augmentation"]["homographic"]["params"]
+ if not self.config["augmentation"]["homographic"]["enable"]:
+ raise ValueError(
+ "[Error] Homographic augmentation is not enabled")
+
+ # Parse the homographic transforms
+ # ToDo: use the shape from the config
+ image_shape = self.config["preprocessing"]["resize"]
+
+ # Compute the min_label_len from config
+        try:
+            min_label_tmp = self.config["generation"]["min_label_len"]
+        except (KeyError, TypeError):
+            min_label_tmp = None
+
+ # float label len => fraction
+ if isinstance(min_label_tmp, float): # Skip if not provided
+ min_label_len = min_label_tmp * min(image_shape)
+ # int label len => length in pixel
+ elif isinstance(min_label_tmp, int):
+ scale_ratio = (self.config["preprocessing"]["resize"]
+ / self.config["generation"]["image_size"][0])
+ min_label_len = (self.config["generation"]["min_label_len"]
+ * scale_ratio)
+ # if none => no restriction
+ else:
+ min_label_len = 0
+
+ # Initialize the transform
+ homographic_trans = homoaug.homography_transform(
+ image_shape, homo_config, 0, min_label_len)
+
+ return homographic_trans
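+
+    # Worked example of the min_label_len conversion above, using the default
+    # config values: min_label_len=0.1 is a float, so it is treated as a
+    # fraction and becomes 0.1 * min([240, 320]) = 24 pixels after resizing.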
+
+ @staticmethod
+ def junc_to_junc_map(junctions, image_size):
+ """ Convert junction points to junction maps. """
+        junctions = np.round(junctions).astype(int)
+ # Clip the boundary by image size
+ junctions[:, 0] = np.clip(junctions[:, 0], 0., image_size[0]-1)
+ junctions[:, 1] = np.clip(junctions[:, 1], 0., image_size[1]-1)
+
+ # Create junction map
+ junc_map = np.zeros([image_size[0], image_size[1]])
+ junc_map[junctions[:, 0], junctions[:, 1]] = 1
+
+        return junc_map[..., None].astype(int)
+
+ def train_preprocessing(self, data, disable_homoaug=False):
+ """ Training preprocessing. """
+ # Fetch corresponding entries
+ image = data["image"]
+ junctions = data["points"]
+ line_map = data["line_map"]
+ heatmap = data["heatmap"]
+ image_size = image.shape[:2]
+
+ # Resize the image before the photometric and homographic transforms
+ # Check if we need to do the resizing
+        if list(image.shape) != self.config["preprocessing"]["resize"]:
+ # Resize the image and the point location.
+ size_old = list(image.shape)
+ image = cv2.resize(
+ image, tuple(self.config['preprocessing']['resize'][::-1]),
+ interpolation=cv2.INTER_LINEAR)
+ image = np.array(image, dtype=np.uint8)
+
+ junctions = (
+ junctions
+                * np.array(self.config['preprocessing']['resize'], float)
+                / np.array(size_old, float))
+
+ # Generate the line heatmap after post-processing
+ junctions_xy = np.flip(np.round(junctions).astype(np.int32),
+ axis=1)
+ heatmap = synthetic_util.get_line_heatmap(junctions_xy, line_map,
+ size=image.shape)
+ heatmap = (heatmap * 255.).astype(np.uint8)
+
+ # Update image size
+ image_size = image.shape[:2]
+
+ # Declare default valid mask (all ones)
+ valid_mask = np.ones(image_size)
+
+ # Check if we need to apply augmentations
+ # In training mode => yes.
+ # In homography adaptation mode (export mode) => No
+ # Check photometric augmentation
+ if self.config["augmentation"]["photometric"]["enable"]:
+ photo_trans_lst = self.get_photo_transform()
+ ### Image transform ###
+ np.random.shuffle(photo_trans_lst)
+ image_transform = transforms.Compose(
+ photo_trans_lst + [photoaug.normalize_image()])
+ else:
+ image_transform = photoaug.normalize_image()
+ image = image_transform(image)
+
+ # Initialize the empty output dict
+ outputs = {}
+ # Convert to tensor and return the results
+ to_tensor = transforms.ToTensor()
+ # Check homographic augmentation
+ if (self.config["augmentation"]["homographic"]["enable"]
+            and not disable_homoaug):
+ homo_trans = self.get_homo_transform()
+ # Perform homographic transform
+ homo_outputs = homo_trans(image, junctions, line_map)
+
+ # Record the warped results
+ junctions = homo_outputs["junctions"] # Should be HW format
+ image = homo_outputs["warped_image"]
+ line_map = homo_outputs["line_map"]
+ heatmap = homo_outputs["warped_heatmap"]
+ valid_mask = homo_outputs["valid_mask"] # Same for pos and neg
+ homography_mat = homo_outputs["homo"]
+
+            # Optionally put warping information first.
+ outputs["homography_mat"] = to_tensor(
+ homography_mat).to(torch.float32)[0, ...]
+
+ junction_map = self.junc_to_junc_map(junctions, image_size)
+
+ outputs.update({
+ "image": to_tensor(image),
+ "junctions": to_tensor(np.ascontiguousarray(
+ junctions).copy()).to(torch.float32)[0, ...],
+ "junction_map": to_tensor(junction_map).to(torch.int),
+ "line_map": to_tensor(line_map).to(torch.int32)[0, ...],
+ "heatmap": to_tensor(heatmap).to(torch.int32),
+ "valid_mask": to_tensor(valid_mask).to(torch.int32),
+ })
+
+ return outputs
+
+ def test_preprocessing(self, data):
+ """ Test preprocessing. """
+ # Fetch corresponding entries
+ image = data["image"]
+ points = data["points"]
+ line_map = data["line_map"]
+ heatmap = data["heatmap"]
+ image_size = image.shape[:2]
+
+ # Resize the image before the photometric and homographic transforms
+        if list(image.shape) != self.config["preprocessing"]["resize"]:
+ # Resize the image and the point location.
+ size_old = list(image.shape)
+ image = cv2.resize(
+ image, tuple(self.config['preprocessing']['resize'][::-1]),
+ interpolation=cv2.INTER_LINEAR)
+ image = np.array(image, dtype=np.uint8)
+
+            points = (points
+                      * np.array(self.config['preprocessing']['resize'],
+                                 float)
+                      / np.array(size_old, float))
+
+ # Generate the line heatmap after post-processing
+ junctions = np.flip(np.round(points).astype(np.int32), axis=1)
+ heatmap = synthetic_util.get_line_heatmap(junctions, line_map,
+ size=image.shape)
+ heatmap = (heatmap * 255.).astype(np.uint8)
+
+ # Update image size
+ image_size = image.shape[:2]
+
+ ### image transform ###
+ image_transform = photoaug.normalize_image()
+ image = image_transform(image)
+
+ ### joint transform ###
+ junction_map = self.junc_to_junc_map(points, image_size)
+ to_tensor = transforms.ToTensor()
+ image = to_tensor(image)
+ junctions = to_tensor(points)
+ junction_map = to_tensor(junction_map).to(torch.int)
+ line_map = to_tensor(line_map)
+ heatmap = to_tensor(heatmap)
+ valid_mask = to_tensor(np.ones(image_size)).to(torch.int32)
+
+ return {
+ "image": image,
+ "junctions": junctions,
+ "junction_map": junction_map,
+ "line_map": line_map,
+ "heatmap": heatmap,
+ "valid_mask": valid_mask
+ }
+
+ def __getitem__(self, index):
+ datapoint = self.datapoints[index]
+
+ # Initialize reader and use it
+ with h5py.File(self.dataset_path, "r", swmr=True) as reader:
+ data = self.get_data_from_datapoint(datapoint, reader)
+
+        # Apply different transforms in different modes.
+ if (self.mode == "train"
+ or self.config["add_augmentation_to_all_splits"]):
+ return_type = self.config.get("return_type", "single")
+ data = self.train_preprocessing(data)
+ else:
+ data = self.test_preprocessing(data)
+
+ return data
+
+ def __len__(self):
+ return len(self.datapoints)
+
+ ########################
+ ## Some other methods ##
+ ########################
+ def _check_dataset_cache(self):
+ """ Check if dataset cache exists. """
+ cache_file_path = os.path.join(self.cache_path, self.cache_name)
+ if os.path.exists(cache_file_path):
+ return True
+ else:
+ return False
+
+ def _check_export_dataset(self):
+ """ Check if exported dataset exists. """
+ dataset_path = os.path.join(cfg.synthetic_dataroot, self.dataset_name)
+ if os.path.exists(dataset_path) and len(os.listdir(dataset_path)) > 0:
+ return True
+ else:
+ return False
+
+ def _check_file_existence(self, filename_dataset):
+        """ Check if all exported files exist. """
+ # Path to the exported dataset
+ dataset_path = os.path.join(cfg.synthetic_dataroot,
+ self.dataset_name + ".h5")
+
+ flag = True
+ # Open the h5 dataset
+ with h5py.File(dataset_path, "r") as f:
+ # Iterate through all the primitives
+ for prim_name in f.keys():
+ if (len(filename_dataset[prim_name])
+ != len(f[prim_name].keys())):
+ flag = False
+
+ return flag
+
+ def _check_primitive_and_index(self, primitive, index):
+        """ Check if the primitive and index are valid. """
+ # Check primitives
+        if primitive not in self.available_primitives:
+ raise ValueError(
+ "[Error] The primitive is not in available primitives.")
+
+ prim_len = len(self.filename_dataset[primitive])
+ # Check the index
+        if index >= prim_len:
+ raise ValueError(
+ "[Error] The index exceeds the total file counts %d for %s"
+ % (prim_len, primitive))
diff --git a/imcui/third_party/SOLD2/sold2/dataset/synthetic_util.py b/imcui/third_party/SOLD2/sold2/dataset/synthetic_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..af009e0ce7e91391e31d7069064ae6121aa84cc0
--- /dev/null
+++ b/imcui/third_party/SOLD2/sold2/dataset/synthetic_util.py
@@ -0,0 +1,1232 @@
+"""
+Code adapted from https://github.com/rpautrat/SuperPoint
+Module used to generate geometrical synthetic shapes
+"""
+import math
+import cv2 as cv
+import numpy as np
+import shapely.geometry
+from itertools import combinations
+
+random_state = np.random.RandomState(None)
+
+
+def set_random_state(state):
+ global random_state
+ random_state = state
+
+
+def get_random_color(background_color):
+    """ Output a random grayscale value with at least a small contrast
+    with the background color. """
+ color = random_state.randint(256)
+ if abs(color - background_color) < 30: # not enough contrast
+ color = (color + 128) % 256
+ return color
+
+
+def get_different_color(previous_colors, min_dist=50, max_count=20):
+ """ Output a color that contrasts with the previous colors.
+ Parameters:
+ previous_colors: np.array of the previous colors
+ min_dist: the difference between the new color and
+ the previous colors must be at least min_dist
+ max_count: maximal number of iterations
+ """
+ color = random_state.randint(256)
+ count = 0
+ while np.any(np.abs(previous_colors - color) < min_dist) and count < max_count:
+ count += 1
+ color = random_state.randint(256)
+ return color
+
+
+def add_salt_and_pepper(img):
+ """ Add salt and pepper noise to an image. """
+ noise = np.zeros((img.shape[0], img.shape[1]), dtype=np.uint8)
+ cv.randu(noise, 0, 255)
+ black = noise < 30
+ white = noise > 225
+ img[white > 0] = 255
+ img[black > 0] = 0
+ cv.blur(img, (5, 5), img)
+    return np.empty((0, 2), dtype=int)
+
+
+def generate_background(size=(960, 1280), nb_blobs=100, min_rad_ratio=0.01,
+ max_rad_ratio=0.05, min_kernel_size=50,
+ max_kernel_size=300):
+ """ Generate a customized background image.
+ Parameters:
+ size: size of the image
+ nb_blobs: number of circles to draw
+        min_rad_ratio: the radius of blobs is at least min_rad_ratio * max(size)
+        max_rad_ratio: the radius of blobs is at most max_rad_ratio * max(size)
+ min_kernel_size: minimal size of the kernel
+ max_kernel_size: maximal size of the kernel
+ """
+ img = np.zeros(size, dtype=np.uint8)
+ dim = max(size)
+ cv.randu(img, 0, 255)
+ cv.threshold(img, random_state.randint(256), 255, cv.THRESH_BINARY, img)
+ background_color = int(np.mean(img))
+ blobs = np.concatenate(
+ [random_state.randint(0, size[1], size=(nb_blobs, 1)),
+ random_state.randint(0, size[0], size=(nb_blobs, 1))], axis=1)
+ for i in range(nb_blobs):
+ col = get_random_color(background_color)
+ cv.circle(img, (blobs[i][0], blobs[i][1]),
+ np.random.randint(int(dim * min_rad_ratio),
+ int(dim * max_rad_ratio)),
+ col, -1)
+ kernel_size = random_state.randint(min_kernel_size, max_kernel_size)
+ cv.blur(img, (kernel_size, kernel_size), img)
+ return img
+
+
+def generate_custom_background(size, background_color, nb_blobs=3000,
+ kernel_boundaries=(50, 100)):
+ """ Generate a customized background to fill the shapes.
+ Parameters:
+ background_color: average color of the background image
+ nb_blobs: number of circles to draw
+ kernel_boundaries: interval of the possible sizes of the kernel
+ """
+ img = np.zeros(size, dtype=np.uint8)
+ img = img + get_random_color(background_color)
+ blobs = np.concatenate(
+ [np.random.randint(0, size[1], size=(nb_blobs, 1)),
+ np.random.randint(0, size[0], size=(nb_blobs, 1))], axis=1)
+ for i in range(nb_blobs):
+ col = get_random_color(background_color)
+ cv.circle(img, (blobs[i][0], blobs[i][1]),
+ np.random.randint(20), col, -1)
+ kernel_size = np.random.randint(kernel_boundaries[0],
+ kernel_boundaries[1])
+ cv.blur(img, (kernel_size, kernel_size), img)
+ return img
+
+
+def final_blur(img, kernel_size=(5, 5)):
+ """ Gaussian blur applied to an image.
+ Parameters:
+ kernel_size: size of the kernel
+ """
+ cv.GaussianBlur(img, kernel_size, 0, img)
+
+
+def ccw(A, B, C, dim):
+ """ Check if the points are listed in counter-clockwise order. """
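+    # The test checks the sign of the 2D cross product of (B - A) and (C - A), avoiding any division.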
+ if dim == 2: # only 2 dimensions
+ return((C[:, 1] - A[:, 1]) * (B[:, 0] - A[:, 0])
+ > (B[:, 1] - A[:, 1]) * (C[:, 0] - A[:, 0]))
+ else: # dim should be equal to 3
+ return((C[:, 1, :] - A[:, 1, :])
+ * (B[:, 0, :] - A[:, 0, :])
+ > (B[:, 1, :] - A[:, 1, :])
+ * (C[:, 0, :] - A[:, 0, :]))
+
+
+def intersect(A, B, C, D, dim):
+ """ Return true if line segments AB and CD intersect """
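+    # Segments AB and CD cross iff C and D lie on opposite sides of line AB
+    # and A and B lie on opposite sides of line CD.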
+ return np.any((ccw(A, C, D, dim) != ccw(B, C, D, dim)) &
+ (ccw(A, B, C, dim) != ccw(A, B, D, dim)))
+
+
+def keep_points_inside(points, size):
+ """ Keep only the points whose coordinates are inside the dimensions of
+ the image of size 'size' """
+ mask = (points[:, 0] >= 0) & (points[:, 0] < size[1]) &\
+ (points[:, 1] >= 0) & (points[:, 1] < size[0])
+ return points[mask, :]
+
+
+def get_unique_junctions(segments, min_label_len):
+ """ Get unique junction points from line segments. """
+ # Get all junctions from segments
+ junctions_all = np.concatenate((segments[:, :2], segments[:, 2:]), axis=0)
+ if junctions_all.shape[0] == 0:
+ junc_points = None
+ line_map = None
+
+ # Get all unique junction points
+ else:
+ junc_points = np.unique(junctions_all, axis=0)
+ # Generate line map from points and segments
+ line_map = get_line_map(junc_points, segments)
+
+ return junc_points, line_map
+
+
+def get_line_map(points: np.ndarray, segments: np.ndarray) -> np.ndarray:
+ """ Get line map given the points and segment sets. """
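+    # line_map is a symmetric adjacency matrix: entry (i, j) is 1 iff junctions i and j
+    # are joined by a segment. Example with hypothetical values:
+    #   points   = [[0, 0], [0, 5], [5, 0]]
+    #   segments = [[0, 0, 0, 5], [0, 0, 5, 0]]
+    #   -> line_map = [[0, 1, 1], [1, 0, 0], [1, 0, 0]]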
+ # create empty line map
+ num_point = points.shape[0]
+ line_map = np.zeros([num_point, num_point])
+
+ # Iterate through every segment
+ for idx in range(segments.shape[0]):
+        # Get the junctions from a single segment
+ seg = segments[idx, :]
+ junction1 = seg[:2]
+ junction2 = seg[2:]
+
+ # Get index
+ idx_junction1 = np.where((points == junction1).sum(axis=1) == 2)[0]
+ idx_junction2 = np.where((points == junction2).sum(axis=1) == 2)[0]
+
+ # label the corresponding entries
+ line_map[idx_junction1, idx_junction2] = 1
+ line_map[idx_junction2, idx_junction1] = 1
+
+ return line_map
+
+
+def get_line_heatmap(junctions, line_map, size=[480, 640], thickness=1):
+ """ Get line heat map from junctions and line map. """
+    # Make sure that the thickness is an integer
+ if not isinstance(thickness, int):
+ thickness = int(thickness)
+
+ # If the junction points are not int => round them and convert to int
+    if not junctions.dtype == int:
+        junctions = (np.round(junctions)).astype(int)
+
+ # Initialize empty map
+ heat_map = np.zeros(size)
+
+ if junctions.shape[0] > 0: # If empty, just return zero map
+ # Iterate through all the junctions
+ for idx in range(junctions.shape[0]):
+ # if no connectivity, just skip it
+ if line_map[idx, :].sum() == 0:
+ continue
+ # Plot the line segment
+ else:
+ # Iterate through all the connected junctions
+ for idx2 in np.where(line_map[idx, :] == 1)[0]:
+ point1 = junctions[idx, :]
+ point2 = junctions[idx2, :]
+
+ # Draw line
+ cv.line(heat_map, tuple(point1), tuple(point2), 1., thickness)
+
+ return heat_map
+
+
+def draw_lines(img, nb_lines=10, min_len=32, min_label_len=32):
+    """ Draw random lines and output the positions of the junction pairs
+    and the line associativity map.
+ Parameters:
+ nb_lines: maximal number of lines
+ """
+ # Set line number and points placeholder
+ num_lines = random_state.randint(1, nb_lines)
+    segments = np.empty((0, 4), dtype=int)
+    points = np.empty((0, 2), dtype=int)
+ background_color = int(np.mean(img))
+ min_dim = min(img.shape)
+
+    # Convert the length constraints to pixels if given as floats
+ if isinstance(min_len, float) and min_len <= 1.:
+ min_len = int(min_dim * min_len)
+ if isinstance(min_label_len, float) and min_label_len <= 1.:
+ min_label_len = int(min_dim * min_label_len)
+
+ # Generate lines one by one
+ for i in range(num_lines):
+ x1 = random_state.randint(img.shape[1])
+ y1 = random_state.randint(img.shape[0])
+ p1 = np.array([[x1, y1]])
+ x2 = random_state.randint(img.shape[1])
+ y2 = random_state.randint(img.shape[0])
+ p2 = np.array([[x2, y2]])
+
+ # Check the length of the line
+ line_length = np.sqrt(np.sum((p1 - p2) ** 2))
+ if line_length < min_len:
+ continue
+
+ # Check that there is no overlap
+ if intersect(segments[:, 0:2], segments[:, 2:4], p1, p2, 2):
+ continue
+
+ col = get_random_color(background_color)
+ thickness = random_state.randint(min_dim * 0.01, min_dim * 0.02)
+ cv.line(img, (x1, y1), (x2, y2), col, thickness)
+
+ # Only record the segments longer than min_label_len
+ seg_len = math.sqrt((x1 - x2) ** 2 + (y1 - y2) ** 2)
+ if seg_len >= min_label_len:
+ segments = np.concatenate([segments,
+ np.array([[x1, y1, x2, y2]])], axis=0)
+ points = np.concatenate([points,
+ np.array([[x1, y1], [x2, y2]])], axis=0)
+
+ # If no line is drawn, recursively call the function
+ if points.shape[0] == 0:
+ return draw_lines(img, nb_lines, min_len, min_label_len)
+
+ # Get the line associativity map
+ line_map = get_line_map(points, segments)
+
+ return {
+ "points": points,
+ "line_map": line_map
+ }
+
+
+def check_segment_len(segments, min_len=32):
+ """ Check if one of the segments is too short (True means too short). """
+ point1_vec = segments[:, :2]
+ point2_vec = segments[:, 2:]
+ diff = point1_vec - point2_vec
+
+ dist = np.sqrt(np.sum(diff ** 2, axis=1))
+ if np.any(dist < min_len):
+ return True
+ else:
+ return False
+
+
+def draw_polygon(img, max_sides=8, min_len=32, min_label_len=64):
+ """ Draw a polygon with a random number of corners and return the position
+ of the junctions + line map.
+ Parameters:
+ max_sides: maximal number of sides + 1
+ """
+ num_corners = random_state.randint(3, max_sides)
+ min_dim = min(img.shape[0], img.shape[1])
+ rad = max(random_state.rand() * min_dim / 2, min_dim / 10)
+ # Center of a circle
+ x = random_state.randint(rad, img.shape[1] - rad)
+ y = random_state.randint(rad, img.shape[0] - rad)
+
+    # Convert the length constraints to pixels if given as floats
+ if isinstance(min_len, float) and min_len <= 1.:
+ min_len = int(min_dim * min_len)
+ if isinstance(min_label_len, float) and min_label_len <= 1.:
+ min_label_len = int(min_dim * min_label_len)
+
+ # Sample num_corners points inside the circle
+ slices = np.linspace(0, 2 * math.pi, num_corners + 1)
+ angles = [slices[i] + random_state.rand() * (slices[i+1] - slices[i])
+ for i in range(num_corners)]
+ points = np.array(
+ [[int(x + max(random_state.rand(), 0.4) * rad * math.cos(a)),
+ int(y + max(random_state.rand(), 0.4) * rad * math.sin(a))]
+ for a in angles])
+
+ # Filter the points that are too close or that have an angle too flat
+ norms = [np.linalg.norm(points[(i-1) % num_corners, :]
+ - points[i, :]) for i in range(num_corners)]
+ mask = np.array(norms) > 0.01
+ points = points[mask, :]
+ num_corners = points.shape[0]
+ corner_angles = [angle_between_vectors(points[(i-1) % num_corners, :] -
+ points[i, :],
+ points[(i+1) % num_corners, :] -
+ points[i, :])
+ for i in range(num_corners)]
+ mask = np.array(corner_angles) < (2 * math.pi / 3)
+ points = points[mask, :]
+ num_corners = points.shape[0]
+
+ # Get junction pairs from points
+ segments = np.zeros([0, 4])
+    # Used to record all the segments, whether we label them or not.
+ segments_raw = np.zeros([0, 4])
+ for idx in range(num_corners):
+ if idx == (num_corners - 1):
+ p1 = points[idx]
+ p2 = points[0]
+ else:
+ p1 = points[idx]
+ p2 = points[idx + 1]
+
+ segment = np.concatenate((p1, p2), axis=0)
+ # Only record the segments longer than min_label_len
+ seg_len = np.sqrt(np.sum((p1 - p2) ** 2))
+ if seg_len >= min_label_len:
+ segments = np.concatenate((segments, segment[None, ...]), axis=0)
+ segments_raw = np.concatenate((segments_raw, segment[None, ...]),
+ axis=0)
+
+    # If there are not enough corners, just regenerate a new polygon
+ if (num_corners < 3) or check_segment_len(segments_raw, min_len):
+ return draw_polygon(img, max_sides, min_len, min_label_len)
+
+ # Get junctions from segments
+ junctions_all = np.concatenate((segments[:, :2], segments[:, 2:]), axis=0)
+ if junctions_all.shape[0] == 0:
+ junc_points = None
+ line_map = None
+
+ else:
+ junc_points = np.unique(junctions_all, axis=0)
+
+ # Get the line map
+ line_map = get_line_map(junc_points, segments)
+
+ corners = points.reshape((-1, 1, 2))
+ col = get_random_color(int(np.mean(img)))
+ cv.fillPoly(img, [corners], col)
+
+ return {
+ "points": junc_points,
+ "line_map": line_map
+ }
+
+
+def overlap(center, rad, centers, rads):
+    """ Return True if the circle defined by (center, rad)
+    overlaps with any of the other circles. """
+ flag = False
+ for i in range(len(rads)):
+ if np.linalg.norm(center - centers[i]) < rad + rads[i]:
+ flag = True
+ break
+ return flag
+
+
+def angle_between_vectors(v1, v2):
+ """ Compute the angle (in rad) between the two vectors v1 and v2. """
+ v1_u = v1 / np.linalg.norm(v1)
+ v2_u = v2 / np.linalg.norm(v2)
+ return np.arccos(np.clip(np.dot(v1_u, v2_u), -1.0, 1.0))
+
+
+def draw_multiple_polygons(img, max_sides=8, nb_polygons=30, min_len=32,
+ min_label_len=64, safe_margin=5, **extra):
+ """ Draw multiple polygons with a random number of corners
+ and return the junction points + line map.
+ Parameters:
+ max_sides: maximal number of sides + 1
+ nb_polygons: maximal number of polygons
+ """
+    segments = np.empty((0, 4), dtype=int)
+    label_segments = np.empty((0, 4), dtype=int)
+ centers = []
+ rads = []
+    points = np.empty((0, 2), dtype=int)
+ background_color = int(np.mean(img))
+
+ min_dim = min(img.shape[0], img.shape[1])
+    # Convert the length constraints to pixels if given as floats
+ if isinstance(min_len, float) and min_len <= 1.:
+ min_len = int(min_dim * min_len)
+ if isinstance(min_label_len, float) and min_label_len <= 1.:
+ min_label_len = int(min_dim * min_label_len)
+ if isinstance(safe_margin, float) and safe_margin <= 1.:
+ safe_margin = int(min_dim * safe_margin)
+
+ # Sequentially generate polygons
+ for i in range(nb_polygons):
+ num_corners = random_state.randint(3, max_sides)
+ min_dim = min(img.shape[0], img.shape[1])
+
+        # Sample the outer radius and the inner (real) radius with the safe margin
+ rad = max(random_state.rand() * min_dim / 2, min_dim / 9)
+ rad_real = rad - safe_margin
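+        # The outer points/radius serve to test overlap with existing shapes;
+        # the inner (real) points are the ones that actually get drawn and labelled.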
+
+ # Center of a circle
+ x = random_state.randint(rad, img.shape[1] - rad)
+ y = random_state.randint(rad, img.shape[0] - rad)
+
+ # Sample num_corners points inside the circle
+ slices = np.linspace(0, 2 * math.pi, num_corners + 1)
+ angles = [slices[i] + random_state.rand() * (slices[i+1] - slices[i])
+ for i in range(num_corners)]
+
+ # Sample outer points and inner points
+ new_points = []
+ new_points_real = []
+ for a in angles:
+ x_offset = max(random_state.rand(), 0.4)
+ y_offset = max(random_state.rand(), 0.4)
+ new_points.append([int(x + x_offset * rad * math.cos(a)),
+ int(y + y_offset * rad * math.sin(a))])
+ new_points_real.append(
+ [int(x + x_offset * rad_real * math.cos(a)),
+ int(y + y_offset * rad_real * math.sin(a))])
+ new_points = np.array(new_points)
+ new_points_real = np.array(new_points_real)
+
+ # Filter the points that are too close or that have an angle too flat
+ norms = [np.linalg.norm(new_points[(i-1) % num_corners, :]
+ - new_points[i, :])
+ for i in range(num_corners)]
+ mask = np.array(norms) > 0.01
+ new_points = new_points[mask, :]
+ new_points_real = new_points_real[mask, :]
+
+ num_corners = new_points.shape[0]
+ corner_angles = [
+ angle_between_vectors(new_points[(i-1) % num_corners, :] -
+ new_points[i, :],
+ new_points[(i+1) % num_corners, :] -
+ new_points[i, :])
+ for i in range(num_corners)]
+ mask = np.array(corner_angles) < (2 * math.pi / 3)
+ new_points = new_points[mask, :]
+ new_points_real = new_points_real[mask, :]
+ num_corners = new_points.shape[0]
+
+ # Not enough corners
+ if num_corners < 3:
+ continue
+
+ # Segments for checking overlap (outer circle)
+ new_segments = np.zeros((1, 4, num_corners))
+ new_segments[:, 0, :] = [new_points[i][0] for i in range(num_corners)]
+ new_segments[:, 1, :] = [new_points[i][1] for i in range(num_corners)]
+ new_segments[:, 2, :] = [new_points[(i+1) % num_corners][0]
+ for i in range(num_corners)]
+ new_segments[:, 3, :] = [new_points[(i+1) % num_corners][1]
+ for i in range(num_corners)]
+
+ # Segments to record (inner circle)
+ new_segments_real = np.zeros((1, 4, num_corners))
+ new_segments_real[:, 0, :] = [new_points_real[i][0]
+ for i in range(num_corners)]
+ new_segments_real[:, 1, :] = [new_points_real[i][1]
+ for i in range(num_corners)]
+ new_segments_real[:, 2, :] = [
+ new_points_real[(i + 1) % num_corners][0]
+ for i in range(num_corners)]
+ new_segments_real[:, 3, :] = [
+ new_points_real[(i + 1) % num_corners][1]
+ for i in range(num_corners)]
+
+ # Check that the polygon will not overlap with pre-existing shapes
+ if intersect(segments[:, 0:2, None], segments[:, 2:4, None],
+ new_segments[:, 0:2, :], new_segments[:, 2:4, :],
+ 3) or overlap(np.array([x, y]), rad, centers, rads):
+ continue
+
+        # Check that the edges of the polygon are not too short
+ if check_segment_len(new_segments_real, min_len):
+ continue
+
+ # If the polygon is valid, append it to the polygon set
+ centers.append(np.array([x, y]))
+ rads.append(rad)
+ new_segments = np.reshape(np.swapaxes(new_segments, 0, 2), (-1, 4))
+ segments = np.concatenate([segments, new_segments], axis=0)
+
+ # Only record the segments longer than min_label_len
+ new_segments_real = np.reshape(np.swapaxes(new_segments_real, 0, 2),
+ (-1, 4))
+ points1 = new_segments_real[:, :2]
+ points2 = new_segments_real[:, 2:]
+ seg_len = np.sqrt(np.sum((points1 - points2) ** 2, axis=1))
+ new_label_segment = new_segments_real[seg_len >= min_label_len, :]
+ label_segments = np.concatenate([label_segments, new_label_segment],
+ axis=0)
+
+ # Color the polygon with a custom background
+ corners = new_points_real.reshape((-1, 1, 2))
+ mask = np.zeros(img.shape, np.uint8)
+ custom_background = generate_custom_background(
+ img.shape, background_color, **extra)
+
+ cv.fillPoly(mask, [corners], 255)
+ locs = np.where(mask != 0)
+ img[locs[0], locs[1]] = custom_background[locs[0], locs[1]]
+ points = np.concatenate([points, new_points], axis=0)
+
+ # Get all junctions from label segments
+ junctions_all = np.concatenate(
+ (label_segments[:, :2], label_segments[:, 2:]), axis=0)
+ if junctions_all.shape[0] == 0:
+ junc_points = None
+ line_map = None
+
+ else:
+ junc_points = np.unique(junctions_all, axis=0)
+
+ # Generate line map from points and segments
+ line_map = get_line_map(junc_points, label_segments)
+
+ return {
+ "points": junc_points,
+ "line_map": line_map
+ }
+
+
+def draw_ellipses(img, nb_ellipses=20):
+ """ Draw several ellipses.
+ Parameters:
+ nb_ellipses: maximal number of ellipses
+ """
+    centers = np.empty((0, 2), dtype=int)
+    rads = np.empty((0, 1), dtype=int)
+ min_dim = min(img.shape[0], img.shape[1]) / 4
+ background_color = int(np.mean(img))
+ for i in range(nb_ellipses):
+ ax = int(max(random_state.rand() * min_dim, min_dim / 5))
+ ay = int(max(random_state.rand() * min_dim, min_dim / 5))
+ max_rad = max(ax, ay)
+ x = random_state.randint(max_rad, img.shape[1] - max_rad) # center
+ y = random_state.randint(max_rad, img.shape[0] - max_rad)
+ new_center = np.array([[x, y]])
+
+        # Check that the ellipse will not overlap with pre-existing shapes
+ diff = centers - new_center
+ if np.any(max_rad > (np.sqrt(np.sum(diff * diff, axis=1)) - rads)):
+ continue
+ centers = np.concatenate([centers, new_center], axis=0)
+ rads = np.concatenate([rads, np.array([[max_rad]])], axis=0)
+
+ col = get_random_color(background_color)
+ angle = random_state.rand() * 90
+ cv.ellipse(img, (x, y), (ax, ay), angle, 0, 360, col, -1)
+    return np.empty((0, 2), dtype=int)
+
+
+def draw_star(img, nb_branches=6, min_len=32, min_label_len=64):
+ """ Draw a star and return the junction points + line map.
+ Parameters:
+ nb_branches: number of branches of the star
+ """
+ num_branches = random_state.randint(3, nb_branches)
+ min_dim = min(img.shape[0], img.shape[1])
+    # Convert the length constraints to pixels if given as floats
+ if isinstance(min_len, float) and min_len <= 1.:
+ min_len = int(min_dim * min_len)
+ if isinstance(min_label_len, float) and min_label_len <= 1.:
+ min_label_len = int(min_dim * min_label_len)
+
+ thickness = random_state.randint(min_dim * 0.01, min_dim * 0.025)
+ rad = max(random_state.rand() * min_dim / 2, min_dim / 5)
+ x = random_state.randint(rad, img.shape[1] - rad)
+ y = random_state.randint(rad, img.shape[0] - rad)
+ # Sample num_branches points inside the circle
+ slices = np.linspace(0, 2 * math.pi, num_branches + 1)
+ angles = [slices[i] + random_state.rand() * (slices[i+1] - slices[i])
+ for i in range(num_branches)]
+ points = np.array(
+ [[int(x + max(random_state.rand(), 0.3) * rad * math.cos(a)),
+ int(y + max(random_state.rand(), 0.3) * rad * math.sin(a))]
+ for a in angles])
+ points = np.concatenate(([[x, y]], points), axis=0)
+
+ # Generate segments and check the length
+ segments = np.array([[x, y, _[0], _[1]] for _ in points[1:, :]])
+ if check_segment_len(segments, min_len):
+ return draw_star(img, nb_branches, min_len, min_label_len)
+
+ # Only record the segments longer than min_label_len
+ points1 = segments[:, :2]
+ points2 = segments[:, 2:]
+ seg_len = np.sqrt(np.sum((points1 - points2) ** 2, axis=1))
+ label_segments = segments[seg_len >= min_label_len, :]
+
+ # Get all junctions from label segments
+ junctions_all = np.concatenate(
+ (label_segments[:, :2], label_segments[:, 2:]), axis=0)
+ if junctions_all.shape[0] == 0:
+ junc_points = None
+ line_map = None
+
+ # Get all unique junction points
+ else:
+ junc_points = np.unique(junctions_all, axis=0)
+ # Generate line map from points and segments
+ line_map = get_line_map(junc_points, label_segments)
+
+ background_color = int(np.mean(img))
+ for i in range(1, num_branches + 1):
+ col = get_random_color(background_color)
+ cv.line(img, (points[0][0], points[0][1]),
+ (points[i][0], points[i][1]),
+ col, thickness)
+ return {
+ "points": junc_points,
+ "line_map": line_map
+ }
+
+
+def draw_checkerboard_multiseg(img, max_rows=7, max_cols=7,
+ transform_params=(0.05, 0.15),
+ min_label_len=64, seed=None):
+ """ Draw a checkerboard and output the junctions + line segments
+ Parameters:
+ max_rows: maximal number of rows + 1
+ max_cols: maximal number of cols + 1
+ transform_params: set the range of the parameters of the transformations
+ """
+ if seed is None:
+ global random_state
+ else:
+ random_state = np.random.RandomState(seed)
+
+ background_color = int(np.mean(img))
+
+ min_dim = min(img.shape)
+ if isinstance(min_label_len, float) and min_label_len <= 1.:
+ min_label_len = int(min_dim * min_label_len)
+ # Create the grid
+ rows = random_state.randint(3, max_rows) # number of rows
+ cols = random_state.randint(3, max_cols) # number of cols
+ s = min((img.shape[1] - 1) // cols, (img.shape[0] - 1) // rows)
+ x_coord = np.tile(range(cols + 1),
+ rows + 1).reshape(((rows + 1) * (cols + 1), 1))
+ y_coord = np.repeat(range(rows + 1),
+ cols + 1).reshape(((rows + 1) * (cols + 1), 1))
+ # points are the grid coordinates
+ points = s * np.concatenate([x_coord, y_coord], axis=1)
+
+    # Warp the grid using an affine transformation and a homography
+ alpha_affine = np.max(img.shape) * (
+ transform_params[0] + random_state.rand() * transform_params[1])
+ center_square = np.float32(img.shape) // 2
+ min_dim = min(img.shape)
+ square_size = min_dim // 3
+ pts1 = np.float32([center_square + square_size,
+ [center_square[0] + square_size,
+ center_square[1] - square_size],
+ center_square - square_size,
+ [center_square[0] - square_size,
+ center_square[1] + square_size]])
+ pts2 = pts1 + random_state.uniform(-alpha_affine, alpha_affine,
+ size=pts1.shape).astype(np.float32)
+ affine_transform = cv.getAffineTransform(pts1[:3], pts2[:3])
+ pts2 = pts1 + random_state.uniform(-alpha_affine / 2, alpha_affine / 2,
+ size=pts1.shape).astype(np.float32)
+ perspective_transform = cv.getPerspectiveTransform(pts1, pts2)
+
+ # Apply the affine transformation
+ points = np.transpose(np.concatenate(
+ (points, np.ones(((rows + 1) * (cols + 1), 1))), axis=1))
+ warped_points = np.transpose(np.dot(affine_transform, points))
+
+ # Apply the homography
+ warped_col0 = np.add(np.sum(np.multiply(
+ warped_points, perspective_transform[0, :2]), axis=1),
+ perspective_transform[0, 2])
+ warped_col1 = np.add(np.sum(np.multiply(
+ warped_points, perspective_transform[1, :2]), axis=1),
+ perspective_transform[1, 2])
+ warped_col2 = np.add(np.sum(np.multiply(
+ warped_points, perspective_transform[2, :2]), axis=1),
+ perspective_transform[2, 2])
+ warped_col0 = np.divide(warped_col0, warped_col2)
+ warped_col1 = np.divide(warped_col1, warped_col2)
+ warped_points = np.concatenate(
+ [warped_col0[:, None], warped_col1[:, None]], axis=1)
+ warped_points_float = warped_points.copy()
+ warped_points = warped_points.astype(int)
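+    # Keep the float copy for sub-pixel segment labels; the integer copy is only used for drawing.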
+
+ # Fill the rectangles
+ colors = np.zeros((rows * cols,), np.int32)
+ for i in range(rows):
+ for j in range(cols):
+            # Get a color that contrasts with the neighboring cells
+ if i == 0 and j == 0:
+ col = get_random_color(background_color)
+ else:
+ neighboring_colors = []
+ if i != 0:
+ neighboring_colors.append(colors[(i - 1) * cols + j])
+ if j != 0:
+ neighboring_colors.append(colors[i * cols + j - 1])
+ col = get_different_color(np.array(neighboring_colors))
+ colors[i * cols + j] = col
+
+ # Fill the cell
+ cv.fillConvexPoly(img, np.array(
+ [(warped_points[i * (cols + 1) + j, 0],
+ warped_points[i * (cols + 1) + j, 1]),
+ (warped_points[i * (cols + 1) + j + 1, 0],
+ warped_points[i * (cols + 1) + j + 1, 1]),
+ (warped_points[(i + 1) * (cols + 1) + j + 1, 0],
+ warped_points[(i + 1) * (cols + 1) + j + 1, 1]),
+ (warped_points[(i + 1) * (cols + 1) + j, 0],
+ warped_points[(i + 1) * (cols + 1) + j, 1])]), col)
+
+    label_segments = np.empty([0, 4], dtype=int)
+ # Iterate through rows
+ for row_idx in range(rows + 1):
+        # Include all the combinations of the junctions
+        # Iterate through all the combinations of junction indices in that row
+ multi_seg_lst = [
+ np.array([warped_points_float[id1, 0],
+ warped_points_float[id1, 1],
+ warped_points_float[id2, 0],
+ warped_points_float[id2, 1]])[None, ...]
+ for (id1, id2) in combinations(range(
+ row_idx * (cols + 1), (row_idx + 1) * (cols + 1), 1), 2)]
+ multi_seg = np.concatenate(multi_seg_lst, axis=0)
+ label_segments = np.concatenate((label_segments, multi_seg), axis=0)
+
+ # Iterate through columns
+ for col_idx in range(cols + 1): # for 5 columns, we will have 5 + 1 edges
+        # Include all the combinations of the junctions
+        # Iterate through all the combinations of junction indices in that column
+ multi_seg_lst = [
+ np.array([warped_points_float[id1, 0],
+ warped_points_float[id1, 1],
+ warped_points_float[id2, 0],
+ warped_points_float[id2, 1]])[None, ...]
+ for (id1, id2) in combinations(range(
+ col_idx, col_idx + ((rows + 1) * (cols + 1)), cols + 1), 2)]
+ multi_seg = np.concatenate(multi_seg_lst, axis=0)
+ label_segments = np.concatenate((label_segments, multi_seg), axis=0)
+
+ label_segments_filtered = np.zeros([0, 4])
+ # Define image boundary polygon (in x y manner)
+ image_poly = shapely.geometry.Polygon(
+ [[0, 0], [img.shape[1] - 1, 0], [img.shape[1] - 1, img.shape[0] - 1],
+ [0, img.shape[0] - 1]])
+ for idx in range(label_segments.shape[0]):
+ # Get the line segment
+ seg_raw = label_segments[idx, :]
+ seg = shapely.geometry.LineString([seg_raw[:2], seg_raw[2:]])
+
+        # The line segment is entirely inside the image.
+ if seg.intersection(image_poly) == seg:
+ label_segments_filtered = np.concatenate(
+ (label_segments_filtered, seg_raw[None, ...]), axis=0)
+
+ # Intersect with the image.
+ elif seg.intersects(image_poly):
+ # Check intersection
+ try:
+ p = np.array(seg.intersection(
+ image_poly).coords).reshape([-1, 4])
+            # If it intersects at exactly one point, just skip it
+ except:
+ continue
+ segment = p
+ label_segments_filtered = np.concatenate(
+ (label_segments_filtered, segment), axis=0)
+
+ else:
+ continue
+
+    label_segments = np.round(label_segments_filtered).astype(int)
+
+ # Only record the segments longer than min_label_len
+ points1 = label_segments[:, :2]
+ points2 = label_segments[:, 2:]
+ seg_len = np.sqrt(np.sum((points1 - points2) ** 2, axis=1))
+ label_segments = label_segments[seg_len >= min_label_len, :]
+
+ # Get all junctions from label segments
+ junc_points, line_map = get_unique_junctions(label_segments,
+ min_label_len)
+
+ # Draw lines on the boundaries of the board at random
+ nb_rows = random_state.randint(2, rows + 2)
+ nb_cols = random_state.randint(2, cols + 2)
+ thickness = random_state.randint(min_dim * 0.01, min_dim * 0.015)
+ for _ in range(nb_rows):
+ row_idx = random_state.randint(rows + 1)
+ col_idx1 = random_state.randint(cols + 1)
+ col_idx2 = random_state.randint(cols + 1)
+ col = get_random_color(background_color)
+ cv.line(img, (warped_points[row_idx * (cols + 1) + col_idx1, 0],
+ warped_points[row_idx * (cols + 1) + col_idx1, 1]),
+ (warped_points[row_idx * (cols + 1) + col_idx2, 0],
+ warped_points[row_idx * (cols + 1) + col_idx2, 1]),
+ col, thickness)
+ for _ in range(nb_cols):
+ col_idx = random_state.randint(cols + 1)
+ row_idx1 = random_state.randint(rows + 1)
+ row_idx2 = random_state.randint(rows + 1)
+ col = get_random_color(background_color)
+ cv.line(img, (warped_points[row_idx1 * (cols + 1) + col_idx, 0],
+ warped_points[row_idx1 * (cols + 1) + col_idx, 1]),
+ (warped_points[row_idx2 * (cols + 1) + col_idx, 0],
+ warped_points[row_idx2 * (cols + 1) + col_idx, 1]),
+ col, thickness)
+
+ # Keep only the points inside the image
+ points = keep_points_inside(warped_points, img.shape[:2])
+ return {
+ "points": junc_points,
+ "line_map": line_map
+ }
+
+
+def draw_stripes_multiseg(img, max_nb_cols=13, min_len=0.04, min_label_len=64,
+ transform_params=(0.05, 0.15), seed=None):
+ """ Draw stripes in a distorted rectangle
+ and output the junctions points + line map.
+ Parameters:
+ max_nb_cols: maximal number of stripes to be drawn
+        min_len: the minimal width of a stripe is
+            min_len * smallest dimension of the image (if given as a float)
+ transform_params: set the range of the parameters of the transformations
+ """
+    # Set the optional random seed (mostly for debugging)
+ if seed is None:
+ global random_state
+ else:
+ random_state = np.random.RandomState(seed)
+
+ background_color = int(np.mean(img))
+ # Create the grid
+ board_size = (int(img.shape[0] * (1 + random_state.rand())),
+ int(img.shape[1] * (1 + random_state.rand())))
+
+ # Number of cols
+ col = random_state.randint(5, max_nb_cols)
+ cols = np.concatenate([board_size[1] * random_state.rand(col - 1),
+ np.array([0, board_size[1] - 1])], axis=0)
+ cols = np.unique(cols.astype(int))
+
+ # Remove the indices that are too close
+ min_dim = min(img.shape)
+
+    # Convert the length constraints to pixels if given as floats
+ if isinstance(min_len, float) and min_len <= 1.:
+ min_len = int(min_dim * min_len)
+ if isinstance(min_label_len, float) and min_label_len <= 1.:
+ min_label_len = int(min_dim * min_label_len)
+
+ cols = cols[(np.concatenate([cols[1:],
+ np.array([board_size[1] + min_len])],
+ axis=0) - cols) >= min_len]
+ # Update the number of cols
+ col = cols.shape[0] - 1
+ cols = np.reshape(cols, (col + 1, 1))
+ cols1 = np.concatenate([cols, np.zeros((col + 1, 1), np.int32)], axis=1)
+ cols2 = np.concatenate(
+ [cols, (board_size[0] - 1) * np.ones((col + 1, 1), np.int32)], axis=1)
+ points = np.concatenate([cols1, cols2], axis=0)
+
+ # Warp the grid using an affine transformation and a homography
+ alpha_affine = np.max(img.shape) * (
+ transform_params[0] + random_state.rand() * transform_params[1])
+ center_square = np.float32(img.shape) // 2
+ square_size = min(img.shape) // 3
+ pts1 = np.float32([center_square + square_size,
+ [center_square[0]+square_size,
+ center_square[1]-square_size],
+ center_square - square_size,
+ [center_square[0]-square_size,
+ center_square[1]+square_size]])
+ pts2 = pts1 + random_state.uniform(-alpha_affine, alpha_affine,
+ size=pts1.shape).astype(np.float32)
+ affine_transform = cv.getAffineTransform(pts1[:3], pts2[:3])
+ pts2 = pts1 + random_state.uniform(-alpha_affine / 2, alpha_affine / 2,
+ size=pts1.shape).astype(np.float32)
+ perspective_transform = cv.getPerspectiveTransform(pts1, pts2)
+
+ # Apply the affine transformation
+ points = np.transpose(np.concatenate((points,
+ np.ones((2 * (col + 1), 1))),
+ axis=1))
+ warped_points = np.transpose(np.dot(affine_transform, points))
+
+ # Apply the homography
+ warped_col0 = np.add(np.sum(np.multiply(
+ warped_points, perspective_transform[0, :2]), axis=1),
+ perspective_transform[0, 2])
+ warped_col1 = np.add(np.sum(np.multiply(
+ warped_points, perspective_transform[1, :2]), axis=1),
+ perspective_transform[1, 2])
+ warped_col2 = np.add(np.sum(np.multiply(
+ warped_points, perspective_transform[2, :2]), axis=1),
+ perspective_transform[2, 2])
+ warped_col0 = np.divide(warped_col0, warped_col2)
+ warped_col1 = np.divide(warped_col1, warped_col2)
+ warped_points = np.concatenate(
+ [warped_col0[:, None], warped_col1[:, None]], axis=1)
+ warped_points_float = warped_points.copy()
+ warped_points = warped_points.astype(int)
+
+ # Fill the rectangles and get the segments
+ color = get_random_color(background_color)
+ # segments_debug = np.zeros([0, 4])
+ for i in range(col):
+ # Fill the color
+ color = (color + 128 + random_state.randint(-30, 30)) % 256
+ cv.fillConvexPoly(img, np.array([(warped_points[i, 0],
+ warped_points[i, 1]),
+ (warped_points[i+1, 0],
+ warped_points[i+1, 1]),
+ (warped_points[i+col+2, 0],
+ warped_points[i+col+2, 1]),
+ (warped_points[i+col+1, 0],
+ warped_points[i+col+1, 1])]),
+ color)
+
+ segments = np.zeros([0, 4])
+ row = 1 # in stripes case
+ # Iterate through rows
+ for row_idx in range(row + 1):
+        # Include all the combinations of the junctions
+        # Iterate through all the combinations of junction indices in that row
+ multi_seg_lst = [np.array(
+ [warped_points_float[id1, 0],
+ warped_points_float[id1, 1],
+ warped_points_float[id2, 0],
+ warped_points_float[id2, 1]])[None, ...]
+ for (id1, id2) in combinations(range(
+ row_idx * (col + 1), (row_idx + 1) * (col + 1), 1), 2)]
+ multi_seg = np.concatenate(multi_seg_lst, axis=0)
+ segments = np.concatenate((segments, multi_seg), axis=0)
+
+ # Iterate through columns
+ for col_idx in range(col + 1): # for 5 columns, we will have 5 + 1 edges.
+        # Include all the combinations of the junctions
+        # Iterate through all the combinations of junction indices in that column
+ multi_seg_lst = [np.array(
+ [warped_points_float[id1, 0],
+ warped_points_float[id1, 1],
+ warped_points_float[id2, 0],
+ warped_points_float[id2, 1]])[None, ...]
+ for (id1, id2) in combinations(range(
+ col_idx, col_idx + (row * col) + 2, col + 1), 2)]
+ multi_seg = np.concatenate(multi_seg_lst, axis=0)
+ segments = np.concatenate((segments, multi_seg), axis=0)
+
+ # Select and refine the segments
+ segments_new = np.zeros([0, 4])
+ # Define image boundary polygon (in x y manner)
+ image_poly = shapely.geometry.Polygon(
+ [[0, 0], [img.shape[1]-1, 0], [img.shape[1]-1, img.shape[0]-1],
+ [0, img.shape[0]-1]])
+ for idx in range(segments.shape[0]):
+ # Get the line segment
+ seg_raw = segments[idx, :]
+ seg = shapely.geometry.LineString([seg_raw[:2], seg_raw[2:]])
+
+        # The line segment is entirely inside the image.
+ if seg.intersection(image_poly) == seg:
+ segments_new = np.concatenate(
+ (segments_new, seg_raw[None, ...]), axis=0)
+
+ # Intersect with the image.
+ elif seg.intersects(image_poly):
+ # Check intersection
+ try:
+ p = np.array(
+ seg.intersection(image_poly).coords).reshape([-1, 4])
+        # If it intersects at exactly one point, just continue.
+ except:
+ continue
+ segment = p
+ segments_new = np.concatenate((segments_new, segment), axis=0)
+
+ else:
+ continue
+
+    segments = (np.round(segments_new)).astype(int)
+
+ # Only record the segments longer than min_label_len
+ points1 = segments[:, :2]
+ points2 = segments[:, 2:]
+ seg_len = np.sqrt(np.sum((points1 - points2) ** 2, axis=1))
+ label_segments = segments[seg_len >= min_label_len, :]
+
+ # Get all junctions from label segments
+ junctions_all = np.concatenate(
+ (label_segments[:, :2], label_segments[:, 2:]), axis=0)
+ if junctions_all.shape[0] == 0:
+ junc_points = None
+ line_map = None
+
+ # Get all unique junction points
+ else:
+ junc_points = np.unique(junctions_all, axis=0)
+ # Generate line map from points and segments
+ line_map = get_line_map(junc_points, label_segments)
+
+ # Draw lines on the boundaries of the stripes at random
+ nb_rows = random_state.randint(2, 5)
+ nb_cols = random_state.randint(2, col + 2)
+ thickness = random_state.randint(min_dim * 0.01, min_dim * 0.011)
+ for _ in range(nb_rows):
+ row_idx = random_state.choice([0, col + 1])
+ col_idx1 = random_state.randint(col + 1)
+ col_idx2 = random_state.randint(col + 1)
+ color = get_random_color(background_color)
+ cv.line(img, (warped_points[row_idx + col_idx1, 0],
+ warped_points[row_idx + col_idx1, 1]),
+ (warped_points[row_idx + col_idx2, 0],
+ warped_points[row_idx + col_idx2, 1]),
+ color, thickness)
+
+ for _ in range(nb_cols):
+ col_idx = random_state.randint(col + 1)
+ color = get_random_color(background_color)
+ cv.line(img, (warped_points[col_idx, 0],
+ warped_points[col_idx, 1]),
+ (warped_points[col_idx + col + 1, 0],
+ warped_points[col_idx + col + 1, 1]),
+ color, thickness)
+
+ # Keep only the points inside the image
+ # points = keep_points_inside(warped_points, img.shape[:2])
+ return {
+ "points": junc_points,
+ "line_map": line_map
+ }
+
+
+def draw_cube(img, min_size_ratio=0.2, min_label_len=64,
+ scale_interval=(0.4, 0.6), trans_interval=(0.5, 0.2)):
+    """ Draw a 2D projection of a cube and output the visible junctions.
+ Parameters:
+ min_size_ratio: min(img.shape) * min_size_ratio is the smallest
+ achievable cube side size
+ scale_interval: the scale is between scale_interval[0] and
+ scale_interval[0]+scale_interval[1]
+ trans_interval: the translation is between img.shape*trans_interval[0]
+ and img.shape*(trans_interval[0] + trans_interval[1])
+ """
+ # Generate a cube and apply to it an affine transformation
+ # The order matters!
+    # The indices of two adjacent vertices differ by only one bit (Gray code)
+ background_color = int(np.mean(img))
+ min_dim = min(img.shape[:2])
+ min_side = min_dim * min_size_ratio
+ lx = min_side + random_state.rand() * 2 * min_dim / 3 # dims of the cube
+ ly = min_side + random_state.rand() * 2 * min_dim / 3
+ lz = min_side + random_state.rand() * 2 * min_dim / 3
+ cube = np.array([[0, 0, 0],
+ [lx, 0, 0],
+ [0, ly, 0],
+ [lx, ly, 0],
+ [0, 0, lz],
+ [lx, 0, lz],
+ [0, ly, lz],
+ [lx, ly, lz]])
+ rot_angles = random_state.rand(3) * 3 * math.pi / 10. + math.pi / 10.
+ rotation_1 = np.array([[math.cos(rot_angles[0]),
+ -math.sin(rot_angles[0]), 0],
+ [math.sin(rot_angles[0]),
+ math.cos(rot_angles[0]), 0],
+ [0, 0, 1]])
+ rotation_2 = np.array([[1, 0, 0],
+ [0, math.cos(rot_angles[1]),
+ -math.sin(rot_angles[1])],
+ [0, math.sin(rot_angles[1]),
+ math.cos(rot_angles[1])]])
+ rotation_3 = np.array([[math.cos(rot_angles[2]), 0,
+ -math.sin(rot_angles[2])],
+ [0, 1, 0],
+ [math.sin(rot_angles[2]), 0,
+ math.cos(rot_angles[2])]])
+ scaling = np.array([[scale_interval[0] +
+ random_state.rand() * scale_interval[1], 0, 0],
+ [0, scale_interval[0] +
+ random_state.rand() * scale_interval[1], 0],
+ [0, 0, scale_interval[0] +
+ random_state.rand() * scale_interval[1]]])
+ trans = np.array([img.shape[1] * trans_interval[0] +
+ random_state.randint(-img.shape[1] * trans_interval[1],
+ img.shape[1] * trans_interval[1]),
+ img.shape[0] * trans_interval[0] +
+ random_state.randint(-img.shape[0] * trans_interval[1],
+ img.shape[0] * trans_interval[1]),
+ 0])
+ cube = trans + np.transpose(
+ np.dot(scaling, np.dot(rotation_1,
+ np.dot(rotation_2, np.dot(rotation_3, np.transpose(cube))))))
+
+ # The hidden corner is 0 by construction
+ # The front one is 7
+ cube = cube[:, :2] # project on the plane z=0
+ cube = cube.astype(int)
+ points = cube[1:, :] # get rid of the hidden corner
+
+ # Get the three visible faces
+ faces = np.array([[7, 3, 1, 5], [7, 5, 4, 6], [7, 6, 2, 3]])
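+    # Corner 0 is hidden and corner 7 is the closest to the viewer,
+    # so the three visible faces are the ones containing corner 7.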
+
+ # Get all visible line segments
+ segments = np.zeros([0, 4])
+ # Iterate through all the faces
+ for face_idx in range(faces.shape[0]):
+ face = faces[face_idx, :]
+        # Brute-force: enumerate the four edge segments of this face
+ segment = np.array(
+ [np.concatenate((cube[face[0]], cube[face[1]]), axis=0),
+ np.concatenate((cube[face[1]], cube[face[2]]), axis=0),
+ np.concatenate((cube[face[2]], cube[face[3]]), axis=0),
+ np.concatenate((cube[face[3]], cube[face[0]]), axis=0)])
+ segments = np.concatenate((segments, segment), axis=0)
+
+ # Select and refine the segments
+ segments_new = np.zeros([0, 4])
+ # Define image boundary polygon (in x y manner)
+ image_poly = shapely.geometry.Polygon(
+ [[0, 0], [img.shape[1] - 1, 0], [img.shape[1] - 1, img.shape[0] - 1],
+ [0, img.shape[0] - 1]])
+ for idx in range(segments.shape[0]):
+ # Get the line segment
+ seg_raw = segments[idx, :]
+ seg = shapely.geometry.LineString([seg_raw[:2], seg_raw[2:]])
+
+        # The line segment is entirely inside the image.
+ if seg.intersection(image_poly) == seg:
+ segments_new = np.concatenate(
+ (segments_new, seg_raw[None, ...]), axis=0)
+
+ # Intersect with the image.
+ elif seg.intersects(image_poly):
+ try:
+ p = np.array(
+ seg.intersection(image_poly).coords).reshape([-1, 4])
+ except:
+ continue
+ segment = p
+ segments_new = np.concatenate((segments_new, segment), axis=0)
+
+ else:
+ continue
+
+    segments = (np.round(segments_new)).astype(int)
+
+ # Only record the segments longer than min_label_len
+ points1 = segments[:, :2]
+ points2 = segments[:, 2:]
+ seg_len = np.sqrt(np.sum((points1 - points2) ** 2, axis=1))
+ label_segments = segments[seg_len >= min_label_len, :]
+
+ # Get all junctions from label segments
+ junctions_all = np.concatenate(
+ (label_segments[:, :2], label_segments[:, 2:]), axis=0)
+ if junctions_all.shape[0] == 0:
+ junc_points = None
+ line_map = None
+
+ # Get all unique junction points
+ else:
+ junc_points = np.unique(junctions_all, axis=0)
+ # Generate line map from points and segments
+ line_map = get_line_map(junc_points, label_segments)
+
+ # Fill the faces and draw the contours
+ col_face = get_random_color(background_color)
+ for i in [0, 1, 2]:
+ cv.fillPoly(img, [cube[faces[i]].reshape((-1, 1, 2))],
+ col_face)
+ thickness = random_state.randint(min_dim * 0.003, min_dim * 0.015)
+ for i in [0, 1, 2]:
+ for j in [0, 1, 2, 3]:
+ col_edge = (col_face + 128
+ + random_state.randint(-64, 64))\
+                % 256  # color that contrasts with the face color
+ cv.line(img, (cube[faces[i][j], 0], cube[faces[i][j], 1]),
+ (cube[faces[i][(j + 1) % 4], 0],
+ cube[faces[i][(j + 1) % 4], 1]),
+ col_edge, thickness)
+
+ return {
+ "points": junc_points,
+ "line_map": line_map
+ }
+
+
+def gaussian_noise(img):
+ """ Apply random noise to the image. """
+ cv.randu(img, 0, 255)
+ return {
+ "points": None,
+ "line_map": None
+ }
diff --git a/imcui/third_party/SOLD2/sold2/dataset/transforms/__init__.py b/imcui/third_party/SOLD2/sold2/dataset/transforms/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/imcui/third_party/SOLD2/sold2/dataset/transforms/homographic_transforms.py b/imcui/third_party/SOLD2/sold2/dataset/transforms/homographic_transforms.py
new file mode 100644
index 0000000000000000000000000000000000000000..d9338abb169f7a86f3c6e702a031e1c0de86c339
--- /dev/null
+++ b/imcui/third_party/SOLD2/sold2/dataset/transforms/homographic_transforms.py
@@ -0,0 +1,350 @@
+"""
+This file implements the homographic transforms for data augmentation.
+Code adapted from https://github.com/rpautrat/SuperPoint
+"""
+import numpy as np
+from math import pi
+
+from ..synthetic_util import get_line_map, get_line_heatmap
+import cv2
+import copy
+import shapely.geometry
+
+
+def sample_homography(
+ shape, perspective=True, scaling=True, rotation=True,
+ translation=True, n_scales=5, n_angles=25, scaling_amplitude=0.1,
+ perspective_amplitude_x=0.1, perspective_amplitude_y=0.1,
+ patch_ratio=0.5, max_angle=pi/2, allow_artifacts=False,
+ translation_overflow=0.):
+ """
+ Computes the homography transformation between a random patch in the
+ original image and a warped projection with the same image size.
+ As in `tf.contrib.image.transform`, it maps the output point
+ (warped patch) to a transformed input point (original patch).
+ The original patch, initialized with a simple half-size centered crop,
+ is iteratively projected, scaled, rotated and translated.
+
+ Arguments:
+ shape: A rank-2 `Tensor` specifying the height and width of the original image.
+ perspective: A boolean that enables the perspective and affine transformations.
+ scaling: A boolean that enables the random scaling of the patch.
+ rotation: A boolean that enables the random rotation of the patch.
+ translation: A boolean that enables the random translation of the patch.
+ n_scales: The number of tentative scales that are sampled when scaling.
+        n_angles: The number of tentative angles that are sampled when rotating.
+ scaling_amplitude: Controls the amount of scale.
+ perspective_amplitude_x: Controls the perspective effect in x direction.
+ perspective_amplitude_y: Controls the perspective effect in y direction.
+ patch_ratio: Controls the size of the patches used to create the homography.
+ max_angle: Maximum angle used in rotations.
+ allow_artifacts: A boolean that enables artifacts when applying the homography.
+ translation_overflow: Amount of border artifacts caused by translation.
+
+ Returns:
+        homo_mat: A numpy array of shape `[3, 3]` corresponding to the
+ homography transform.
+ selected_scale: The selected scaling factor.
+ """
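+    # Minimal usage sketch (illustrative, assuming a 480x640 image `image`):
+    #   homo_mat, scale = sample_homography(np.array([480, 640]))
+    #   warped = cv2.warpPerspective(image, homo_mat, (640, 480))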
+    # Convert shape to ndarray
+ if not isinstance(shape, np.ndarray):
+ shape = np.array(shape)
+
+ # Corners of the output image
+ pts1 = np.array([[0., 0.], [0., 1.], [1., 1.], [1., 0.]])
+ # Corners of the input patch
+ margin = (1 - patch_ratio) / 2
+ pts2 = margin + np.array([[0, 0], [0, patch_ratio],
+ [patch_ratio, patch_ratio], [patch_ratio, 0]])
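+    # Both corner sets live in normalized [0, 1] coordinates and are rescaled to pixels at the end.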
+
+ # Random perspective and affine perturbations
+ if perspective:
+ if not allow_artifacts:
+ perspective_amplitude_x = min(perspective_amplitude_x, margin)
+ perspective_amplitude_y = min(perspective_amplitude_y, margin)
+
+ # normal distribution with mean=0, std=perspective_amplitude_y/2
+ perspective_displacement = np.random.normal(
+ 0., perspective_amplitude_y/2, [1])
+ h_displacement_left = np.random.normal(
+ 0., perspective_amplitude_x/2, [1])
+ h_displacement_right = np.random.normal(
+ 0., perspective_amplitude_x/2, [1])
+ pts2 += np.stack([np.concatenate([h_displacement_left,
+ perspective_displacement], 0),
+ np.concatenate([h_displacement_left,
+ -perspective_displacement], 0),
+ np.concatenate([h_displacement_right,
+ perspective_displacement], 0),
+ np.concatenate([h_displacement_right,
+ -perspective_displacement], 0)])
+
+ # Random scaling: sample several scales, check collision with borders,
+ # randomly pick a valid one
+    selected_scale = 1.  # default value returned if random scaling is disabled
+    if scaling:
+ scales = np.concatenate(
+ [[1.], np.random.normal(1, scaling_amplitude/2, [n_scales])], 0)
+ center = np.mean(pts2, axis=0, keepdims=True)
+ scaled = (pts2 - center)[None, ...] * scales[..., None, None] + center
+ # all scales are valid except scale=1
+ if allow_artifacts:
+ valid = np.array(range(n_scales))
+        # Check which scales are valid
+ else:
+ valid = np.where(np.all((scaled >= 0.)
+ & (scaled < 1.), (1, 2)))[0]
+ # No valid scale found => recursively call
+ if valid.shape[0] == 0:
+ return sample_homography(
+ shape, perspective, scaling, rotation, translation,
+ n_scales, n_angles, scaling_amplitude,
+ perspective_amplitude_x, perspective_amplitude_y,
+ patch_ratio, max_angle, allow_artifacts, translation_overflow)
+
+ idx = valid[np.random.uniform(0., valid.shape[0], ()).astype(np.int32)]
+ pts2 = scaled[idx]
+
+ # Additionally save and return the selected scale.
+ selected_scale = scales[idx]
+
+ # Random translation
+ if translation:
+ t_min, t_max = np.min(pts2, axis=0), np.min(1 - pts2, axis=0)
+ if allow_artifacts:
+ t_min += translation_overflow
+ t_max += translation_overflow
+ pts2 += (np.stack([np.random.uniform(-t_min[0], t_max[0], ()),
+ np.random.uniform(-t_min[1],
+ t_max[1], ())]))[None, ...]
+
+ # Random rotation: sample several rotations, check collision with borders,
+ # randomly pick a valid one
+ if rotation:
+ angles = np.linspace(-max_angle, max_angle, n_angles)
+ # in case no rotation is valid
+ angles = np.concatenate([[0.], angles], axis=0)
+ center = np.mean(pts2, axis=0, keepdims=True)
+ rot_mat = np.reshape(np.stack(
+ [np.cos(angles), -np.sin(angles),
+ np.sin(angles), np.cos(angles)], axis=1), [-1, 2, 2])
+ rotated = np.matmul(
+ np.tile((pts2 - center)[None, ...], [n_angles+1, 1, 1]),
+ rot_mat) + center
+ if allow_artifacts:
+ # All angles are valid, except angle=0
+ valid = np.array(range(n_angles))
+ else:
+ valid = np.where(np.all((rotated >= 0.)
+ & (rotated < 1.), axis=(1, 2)))[0]
+
+ if valid.shape[0] == 0:
+ return sample_homography(
+ shape, perspective, scaling, rotation, translation,
+ n_scales, n_angles, scaling_amplitude,
+ perspective_amplitude_x, perspective_amplitude_y,
+ patch_ratio, max_angle, allow_artifacts, translation_overflow)
+
+ idx = valid[np.random.uniform(0., valid.shape[0],
+ ()).astype(np.int32)]
+ pts2 = rotated[idx]
+
+ # Rescale to actual size
+ shape = shape[::-1].astype(np.float32) # different convention [y, x]
+ pts1 *= shape[None, ...]
+ pts2 *= shape[None, ...]
+
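+    # Direct Linear Transform: each of the four correspondences contributes two rows (ax, ay)
+    # to an 8x8 linear system, solved in the least-squares sense for the eight homography
+    # parameters; the bottom-right entry of the matrix is fixed to 1.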
+ def ax(p, q): return [p[0], p[1], 1, 0, 0, 0, -p[0] * q[0], -p[1] * q[0]]
+
+ def ay(p, q): return [0, 0, 0, p[0], p[1], 1, -p[0] * q[1], -p[1] * q[1]]
+
+ a_mat = np.stack([f(pts1[i], pts2[i]) for i in range(4)
+ for f in (ax, ay)], axis=0)
+ p_mat = np.transpose(np.stack([[pts2[i][j] for i in range(4)
+ for j in range(2)]], axis=0))
+ homo_vec, _, _, _ = np.linalg.lstsq(a_mat, p_mat, rcond=None)
+
+ # Compose the homography vector back to matrix
+ homo_mat = np.concatenate([
+ homo_vec[0:3, 0][None, ...], homo_vec[3:6, 0][None, ...],
+ np.concatenate((homo_vec[6], homo_vec[7], [1]),
+ axis=0)[None, ...]], axis=0)
+
+ return homo_mat, selected_scale
+
+
+def convert_to_line_segments(junctions, line_map):
+ """ Convert junctions and line map to line segments. """
+ # Copy the line map
+ line_map_tmp = copy.copy(line_map)
+
+ line_segments = np.zeros([0, 4])
+ for idx in range(junctions.shape[0]):
+ # If no connectivity, just skip it
+ if line_map_tmp[idx, :].sum() == 0:
+ continue
+ # Record the line segment
+ else:
+ for idx2 in np.where(line_map_tmp[idx, :] == 1)[0]:
+ p1 = junctions[idx, :]
+ p2 = junctions[idx2, :]
+ line_segments = np.concatenate(
+ (line_segments,
+ np.array([p1[0], p1[1], p2[0], p2[1]])[None, ...]),
+ axis=0)
+ # Update line_map
+ line_map_tmp[idx, idx2] = 0
+ line_map_tmp[idx2, idx] = 0
+
+ return line_segments
+
+
+def compute_valid_mask(image_size, homography,
+ border_margin, valid_mask=None):
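+    """ Compute the valid mask of the warped image by warping an all-ones mask
+        (or the provided valid_mask) and eroding it by border_margin pixels
+        (or dilating it if border_margin is negative). """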
+ # Warp the mask
+ if valid_mask is None:
+ initial_mask = np.ones(image_size)
+ else:
+ initial_mask = valid_mask
+ mask = cv2.warpPerspective(
+ initial_mask, homography, (image_size[1], image_size[0]),
+ flags=cv2.INTER_NEAREST)
+
+ # Optionally perform erosion
+ if border_margin > 0:
+ kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE,
+ (border_margin*2, )*2)
+ mask = cv2.erode(mask, kernel)
+
+ # Perform dilation if border_margin is negative
+ if border_margin < 0:
+ kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE,
+ (abs(int(border_margin))*2, )*2)
+ mask = cv2.dilate(mask, kernel)
+
+ return mask
+
+
+def warp_line_segment(line_segments, homography, image_size):
+ """ Warp the line segments using a homography. """
+    # Separate the line segments into 2N points to apply the matrix operation
+ num_segments = line_segments.shape[0]
+
+ junctions = np.concatenate(
+ (line_segments[:, :2], # The first junction of each segment.
+ line_segments[:, 2:]), # The second junction of each segment.
+ axis=0)
+ # Convert to homogeneous coordinates
+ # Flip the junctions before converting to homogeneous (xy format)
+ junctions = np.flip(junctions, axis=1)
+ junctions = np.concatenate((junctions, np.ones([2*num_segments, 1])),
+ axis=1)
+ warped_junctions = np.matmul(homography, junctions.T).T
+
+ # Convert back to segments
+ warped_junctions = warped_junctions[:, :2] / warped_junctions[:, 2:]
+ # (Convert back to hw format)
+ warped_junctions = np.flip(warped_junctions, axis=1)
+ warped_segments = np.concatenate(
+ (warped_junctions[:num_segments, :],
+ warped_junctions[num_segments:, :]),
+ axis=1
+ )
+
+ # Check the intersections with the boundary
+ warped_segments_new = np.zeros([0, 4])
+ image_poly = shapely.geometry.Polygon(
+ [[0, 0], [image_size[1]-1, 0], [image_size[1]-1, image_size[0]-1],
+ [0, image_size[0]-1]])
+ for idx in range(warped_segments.shape[0]):
+ # Get the line segment
+ seg_raw = warped_segments[idx, :] # in HW format.
+ # Convert to shapely line (flip to xy format)
+ seg = shapely.geometry.LineString([np.flip(seg_raw[:2]),
+ np.flip(seg_raw[2:])])
+
+        # The line segment is entirely inside the image.
+ if seg.intersection(image_poly) == seg:
+ warped_segments_new = np.concatenate((warped_segments_new,
+ seg_raw[None, ...]), axis=0)
+
+ # Intersect with the image.
+ elif seg.intersects(image_poly):
+ # Check intersection
+ try:
+ p = np.array(
+ seg.intersection(image_poly).coords).reshape([-1, 4])
+            # If it intersects at exactly one point, just continue.
+ except:
+ continue
+ segment = np.concatenate([np.flip(p[0, :2]), np.flip(p[0, 2:],
+ axis=0)])[None, ...]
+ warped_segments_new = np.concatenate(
+ (warped_segments_new, segment), axis=0)
+
+ else:
+ continue
+
+    warped_segments = (np.round(warped_segments_new)).astype(int)
+ return warped_segments
+
+
+class homography_transform(object):
+    """ Homography transformations. """
+ def __init__(self, image_size, homograpy_config,
+ border_margin=0, min_label_len=20):
+ self.homo_config = homograpy_config
+ self.image_size = image_size
+ self.target_size = (self.image_size[1], self.image_size[0])
+ self.border_margin = border_margin
+ if (min_label_len < 1) and isinstance(min_label_len, float):
+ raise ValueError("[Error] min_label_len should be in pixels.")
+ self.min_label_len = min_label_len
+
+ def __call__(self, input_image, junctions, line_map,
+ valid_mask=None, homo=None, scale=None):
+ # Sample one random homography or use the given one
+ if homo is None or scale is None:
+ homo, scale = sample_homography(self.image_size,
+ **self.homo_config)
+
+ # Warp the image
+ warped_image = cv2.warpPerspective(
+ input_image, homo, self.target_size, flags=cv2.INTER_LINEAR)
+
+ valid_mask = compute_valid_mask(self.image_size, homo,
+ self.border_margin, valid_mask)
+
+ # Convert junctions and line_map back to line segments
+ line_segments = convert_to_line_segments(junctions, line_map)
+
+ # Warp the segments and check the length.
+ # Adjust the min_label_length
+ warped_segments = warp_line_segment(line_segments, homo,
+ self.image_size)
+
+ # Convert back to junctions and line_map
+ junctions_new = np.concatenate((warped_segments[:, :2],
+ warped_segments[:, 2:]), axis=0)
+ if junctions_new.shape[0] == 0:
+ junctions_new = np.zeros([0, 2])
+ line_map = np.zeros([0, 0])
+ warped_heatmap = np.zeros(self.image_size)
+ else:
+ junctions_new = np.unique(junctions_new, axis=0)
+
+ # Generate line map from points and segments
+ line_map = get_line_map(junctions_new,
+                                    warped_segments).astype(int)
+ # Compute the heatmap
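+            # Junctions are stored as (row, col); flip them to (x, y) before drawing with OpenCV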
+ warped_heatmap = get_line_heatmap(np.flip(junctions_new, axis=1),
+ line_map, self.image_size)
+
+ return {
+ "junctions": junctions_new,
+ "warped_image": warped_image,
+ "valid_mask": valid_mask,
+ "line_map": line_map,
+ "warped_heatmap": warped_heatmap,
+ "homo": homo,
+ "scale": scale
+ }
diff --git a/imcui/third_party/SOLD2/sold2/dataset/transforms/photometric_transforms.py b/imcui/third_party/SOLD2/sold2/dataset/transforms/photometric_transforms.py
new file mode 100644
index 0000000000000000000000000000000000000000..8fa44bf0efa93a47e5f8012988058f1cbd49324f
--- /dev/null
+++ b/imcui/third_party/SOLD2/sold2/dataset/transforms/photometric_transforms.py
@@ -0,0 +1,185 @@
+"""
+Common photometric transforms for data augmentation.
+"""
+import numpy as np
+from PIL import Image
+from torchvision import transforms as transforms
+import cv2
+
+
+# List all the available augmentations
+available_augmentations = [
+ 'additive_gaussian_noise',
+ 'additive_speckle_noise',
+ 'random_brightness',
+ 'random_contrast',
+ 'additive_shade',
+ 'motion_blur'
+]
+
+
+class additive_gaussian_noise(object):
+ """ Additive gaussian noise. """
+ def __init__(self, stddev_range=None):
+ # If std is not given, use the default setting
+ if stddev_range is None:
+ self.stddev_range = [5, 95]
+ else:
+ self.stddev_range = stddev_range
+
+ def __call__(self, input_image):
+ # Get the noise stddev
+ stddev = np.random.uniform(self.stddev_range[0], self.stddev_range[1])
+ noise = np.random.normal(0., stddev, size=input_image.shape)
+ noisy_image = (input_image + noise).clip(0., 255.)
+
+ return noisy_image
+
+
+class additive_speckle_noise(object):
+ """ Additive speckle noise. """
+ def __init__(self, prob_range=None):
+ # If prob range is not given, use the default setting
+ if prob_range is None:
+ self.prob_range = [0.0, 0.005]
+ else:
+ self.prob_range = prob_range
+
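+    # Salt-and-pepper style noise: each pixel is independently set to 0 or 255
+    # with a probability sampled from `prob_range`.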
+ def __call__(self, input_image):
+ # Sample
+ prob = np.random.uniform(self.prob_range[0], self.prob_range[1])
+ sample = np.random.uniform(0., 1., size=input_image.shape)
+
+ # Get the mask
+ mask0 = sample <= prob
+ mask1 = sample >= (1 - prob)
+
+        # Mask the image (here we assume the image values range from 0 to 255)
+ noisy = input_image.copy()
+ noisy[mask0] = 0.
+ noisy[mask1] = 255.
+
+ return noisy
+
+
+class random_brightness(object):
+ """ Brightness change. """
+ def __init__(self, brightness=None):
+ # If the brightness is not given, use the default setting
+ if brightness is None:
+ self.brightness = 0.5
+ else:
+ self.brightness = brightness
+
+ # Initialize the transformer
+ self.transform = transforms.ColorJitter(brightness=self.brightness)
+
+ def __call__(self, input_image):
+ # Convert to PIL image
+ if isinstance(input_image, np.ndarray):
+ input_image = Image.fromarray(input_image.astype(np.uint8))
+
+ return np.array(self.transform(input_image))
+
+
+class random_contrast(object):
+ """ Additive contrast. """
+ def __init__(self, contrast=None):
+        # If the contrast is not given, use the default setting
+ if contrast is None:
+ self.contrast = 0.5
+ else:
+ self.contrast = contrast
+
+ # Initialize the transformer
+ self.transform = transforms.ColorJitter(contrast=self.contrast)
+
+ def __call__(self, input_image):
+ # Convert to PIL image
+ if isinstance(input_image, np.ndarray):
+ input_image = Image.fromarray(input_image.astype(np.uint8))
+
+ return np.array(self.transform(input_image))
+
+
+class additive_shade(object):
+ """ Additive shade. """
+ def __init__(self, nb_ellipses=20, transparency_range=None,
+ kernel_size_range=None):
+ self.nb_ellipses = nb_ellipses
+ if transparency_range is None:
+ self.transparency_range = [-0.5, 0.8]
+ else:
+ self.transparency_range = transparency_range
+
+ if kernel_size_range is None:
+ self.kernel_size_range = [250, 350]
+ else:
+ self.kernel_size_range = kernel_size_range
+
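+    # Draws random elliptical blobs on a mask, blurs them with a large
+    # Gaussian kernel and blends them into the image with a random
+    # transparency; a negative transparency brightens instead of darkening.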
+ def __call__(self, input_image):
+        # TODO: decide whether the input should be converted to a numpy array first.
+ min_dim = min(input_image.shape[:2]) / 4
+ mask = np.zeros(input_image.shape[:2], np.uint8)
+ for i in range(self.nb_ellipses):
+ ax = int(max(np.random.rand() * min_dim, min_dim / 5))
+ ay = int(max(np.random.rand() * min_dim, min_dim / 5))
+ max_rad = max(ax, ay)
+ x = np.random.randint(max_rad, input_image.shape[1] - max_rad)
+ y = np.random.randint(max_rad, input_image.shape[0] - max_rad)
+ angle = np.random.rand() * 90
+ cv2.ellipse(mask, (x, y), (ax, ay), angle, 0, 360, 255, -1)
+
+ transparency = np.random.uniform(*self.transparency_range)
+ kernel_size = np.random.randint(*self.kernel_size_range)
+
+ # kernel_size has to be odd
+ if (kernel_size % 2) == 0:
+ kernel_size += 1
+ mask = cv2.GaussianBlur(mask.astype(np.float32),
+ (kernel_size, kernel_size), 0)
+ shaded = (input_image[..., None]
+ * (1 - transparency * mask[..., np.newaxis]/255.))
+ shaded = np.clip(shaded, 0, 255)
+
+ return np.reshape(shaded, input_image.shape)
+
+
+class motion_blur(object):
+ """ Motion blur. """
+ def __init__(self, max_kernel_size=10):
+ self.max_kernel_size = max_kernel_size
+
+ def __call__(self, input_image):
+ # Either vertical, horizontal or diagonal blur
+ mode = np.random.choice(['h', 'v', 'diag_down', 'diag_up'])
+ ksize = np.random.randint(
+ 0, int(round((self.max_kernel_size + 1) / 2))) * 2 + 1
+ center = int((ksize - 1) / 2)
+ kernel = np.zeros((ksize, ksize))
+ if mode == 'h':
+ kernel[center, :] = 1.
+ elif mode == 'v':
+ kernel[:, center] = 1.
+ elif mode == 'diag_down':
+ kernel = np.eye(ksize)
+ elif mode == 'diag_up':
+ kernel = np.flip(np.eye(ksize), 0)
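+        # Modulate the 1-pixel-wide line kernel with an isotropic Gaussian
+        # centered on the kernel and renormalize it, so the blur preserves
+        # overall image brightness.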
+ var = ksize * ksize / 16.
+ grid = np.repeat(np.arange(ksize)[:, np.newaxis], ksize, axis=-1)
+ gaussian = np.exp(-(np.square(grid - center)
+ + np.square(grid.T - center)) / (2. * var))
+ kernel *= gaussian
+ kernel /= np.sum(kernel)
+ blurred = cv2.filter2D(input_image, -1, kernel)
+
+ return np.reshape(blurred, input_image.shape)
+
+
+class normalize_image(object):
+ """ Image normalization to the range [0, 1]. """
+ def __init__(self):
+ self.normalize_value = 255
+
+ def __call__(self, input_image):
+ return (input_image / self.normalize_value).astype(np.float32)
diff --git a/imcui/third_party/SOLD2/sold2/dataset/transforms/utils.py b/imcui/third_party/SOLD2/sold2/dataset/transforms/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..5f1ed09e5b32e2ae2f3577e0e8e5491495e7b05b
--- /dev/null
+++ b/imcui/third_party/SOLD2/sold2/dataset/transforms/utils.py
@@ -0,0 +1,121 @@
+"""
+Some useful functions for dataset pre-processing
+"""
+import cv2
+import numpy as np
+import shapely.geometry as sg
+
+from ..synthetic_util import get_line_map
+from . import homographic_transforms as homoaug
+
+
+def random_scaling(image, junctions, line_map, scale=1., h_crop=0, w_crop=0):
+ H, W = image.shape[:2]
+ H_scale, W_scale = round(H * scale), round(W * scale)
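+    # scale > 1 zooms in (resize up, then crop back to H x W); scale < 1 zooms
+    # out (resize down, then paste onto a zero canvas of the original size).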
+
+ # Nothing to do if the scale is too close to 1
+ if H_scale == H and W_scale == W:
+        return (image, junctions, line_map, np.ones([H, W], dtype=int))
+
+ # Zoom-in => resize and random crop
+ if scale >= 1.:
+ image_big = cv2.resize(image, (W_scale, H_scale),
+ interpolation=cv2.INTER_LINEAR)
+ # Crop the image
+ image = image_big[h_crop:h_crop+H, w_crop:w_crop+W, ...]
+        valid_mask = np.ones([H, W], dtype=int)
+
+ # Process junctions
+ junctions, line_map = process_junctions_and_line_map(
+ h_crop, w_crop, H, W, H_scale, W_scale,
+ junctions, line_map, "zoom-in")
+ # Zoom-out => resize and pad
+ else:
+ image_shape_raw = image.shape
+ image_small = cv2.resize(image, (W_scale, H_scale),
+ interpolation=cv2.INTER_AREA)
+ # Decide the pasting location
+ h_start = round((H - H_scale) / 2)
+ w_start = round((W - W_scale) / 2)
+ # Paste the image to the middle
+        image = np.zeros(image_shape_raw, dtype=float)
+ image[h_start:h_start+H_scale,
+ w_start:w_start+W_scale, ...] = image_small
+        valid_mask = np.zeros([H, W], dtype=int)
+ valid_mask[h_start:h_start+H_scale, w_start:w_start+W_scale] = 1
+
+ # Process the junctions
+ junctions, line_map = process_junctions_and_line_map(
+ h_start, w_start, H, W, H_scale, W_scale,
+ junctions, line_map, "zoom-out")
+
+ return image, junctions, line_map, valid_mask
+
+
+def process_junctions_and_line_map(h_start, w_start, H, W, H_scale, W_scale,
+ junctions, line_map, mode="zoom-in"):
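+    # Junctions are in (row, col) format. For zoom-in, segments are clipped to
+    # the crop window and converted back to a junction / line_map pair; for
+    # zoom-out, junction coordinates are simply rescaled and shifted.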
+ if mode == "zoom-in":
+ junctions[:, 0] = junctions[:, 0] * H_scale / H
+ junctions[:, 1] = junctions[:, 1] * W_scale / W
+ line_segments = homoaug.convert_to_line_segments(junctions, line_map)
+ # Crop segments to the new boundaries
+ line_segments_new = np.zeros([0, 4])
+ image_poly = sg.Polygon(
+ [[w_start, h_start],
+ [w_start+W, h_start],
+ [w_start+W, h_start+H],
+ [w_start, h_start+H]
+ ])
+ for idx in range(line_segments.shape[0]):
+ # Get the line segment
+ seg_raw = line_segments[idx, :] # in HW format.
+ # Convert to shapely line (flip to xy format)
+ seg = sg.LineString([np.flip(seg_raw[:2]),
+ np.flip(seg_raw[2:])])
+ # The line segment is just inside the image.
+ if seg.intersection(image_poly) == seg:
+ line_segments_new = np.concatenate(
+ (line_segments_new, seg_raw[None, ...]), axis=0)
+ # Intersect with the image.
+ elif seg.intersects(image_poly):
+ # Check intersection
+                try:
+                    p = np.array(
+                        seg.intersection(image_poly).coords).reshape([-1, 4])
+                except Exception:
+                    # The intersection degenerates to a single point (or is
+                    # not a single segment), so skip this segment.
+                    continue
+ segment = np.concatenate([np.flip(p[0, :2]), np.flip(p[0, 2:],
+ axis=0)])[None, ...]
+ line_segments_new = np.concatenate(
+ (line_segments_new, segment), axis=0)
+ else:
+ continue
+        line_segments_new = np.round(line_segments_new).astype(int)
+ # Filter segments with 0 length
+ segment_lens = np.linalg.norm(
+ line_segments_new[:, :2] - line_segments_new[:, 2:], axis=-1)
+ seg_mask = segment_lens != 0
+ line_segments_new = line_segments_new[seg_mask, :]
+ # Convert back to junctions and line_map
+ junctions_new = np.concatenate(
+ (line_segments_new[:, :2], line_segments_new[:, 2:]), axis=0)
+ if junctions_new.shape[0] == 0:
+ junctions_new = np.zeros([0, 2])
+ line_map = np.zeros([0, 0])
+ else:
+ junctions_new = np.unique(junctions_new, axis=0)
+ # Generate line map from points and segments
+ line_map = get_line_map(junctions_new,
+                                    line_segments_new).astype(int)
+ junctions_new[:, 0] -= h_start
+ junctions_new[:, 1] -= w_start
+ junctions = junctions_new
+ elif mode == "zoom-out":
+ # Process the junctions
+ junctions[:, 0] = (junctions[:, 0] * H_scale / H) + h_start
+ junctions[:, 1] = (junctions[:, 1] * W_scale / W) + w_start
+ else:
+ raise ValueError("[Error] unknown mode...")
+
+ return junctions, line_map
diff --git a/imcui/third_party/SOLD2/sold2/dataset/wireframe_dataset.py b/imcui/third_party/SOLD2/sold2/dataset/wireframe_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..ed5bb910bed1b89934ddaaec3bcddf111ea0faef
--- /dev/null
+++ b/imcui/third_party/SOLD2/sold2/dataset/wireframe_dataset.py
@@ -0,0 +1,1000 @@
+"""
+This file implements the wireframe dataset object for pytorch.
+Some parts of the code are adapted from https://github.com/zhou13/lcnn
+"""
+import os
+import math
+import copy
+from skimage.io import imread
+from skimage import color
+import PIL
+import numpy as np
+import h5py
+import cv2
+import pickle
+import torch
+import torch.utils.data.dataloader as torch_loader
+from torch.utils.data import Dataset
+from torchvision import transforms
+
+from ..config.project_config import Config as cfg
+from .transforms import photometric_transforms as photoaug
+from .transforms import homographic_transforms as homoaug
+from .transforms.utils import random_scaling
+from .synthetic_util import get_line_heatmap
+from ..misc.train_utils import parse_h5_data
+from ..misc.geometry_utils import warp_points, mask_points
+
+
+def wireframe_collate_fn(batch):
+ """ Customized collate_fn for wireframe dataset. """
+ batch_keys = ["image", "junction_map", "valid_mask", "heatmap",
+ "heatmap_pos", "heatmap_neg", "homography",
+ "line_points", "line_indices"]
+ list_keys = ["junctions", "line_map", "line_map_pos",
+ "line_map_neg", "file_key"]
+
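+    # Keys matching `batch_keys` are stacked by the default collate function;
+    # keys matching `list_keys` have sample-dependent shapes (variable numbers
+    # of junctions / lines) and are kept as plain Python lists.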
+ outputs = {}
+ for data_key in batch[0].keys():
+ batch_match = sum([_ in data_key for _ in batch_keys])
+ list_match = sum([_ in data_key for _ in list_keys])
+ # print(batch_match, list_match)
+ if batch_match > 0 and list_match == 0:
+ outputs[data_key] = torch_loader.default_collate(
+ [b[data_key] for b in batch])
+ elif batch_match == 0 and list_match > 0:
+ outputs[data_key] = [b[data_key] for b in batch]
+ elif batch_match == 0 and list_match == 0:
+ continue
+ else:
+ raise ValueError(
+ "[Error] A key matches batch keys and list keys simultaneously.")
+
+ return outputs
+
+
+class WireframeDataset(Dataset):
+ def __init__(self, mode="train", config=None):
+ super(WireframeDataset, self).__init__()
+        if mode not in ["train", "test"]:
+ raise ValueError(
+ "[Error] Unknown mode for Wireframe dataset. Only 'train' and 'test'.")
+ self.mode = mode
+
+ if config is None:
+ self.config = self.get_default_config()
+ else:
+ self.config = config
+ # Also get the default config
+ self.default_config = self.get_default_config()
+
+ # Get cache setting
+ self.dataset_name = self.get_dataset_name()
+ self.cache_name = self.get_cache_name()
+ self.cache_path = cfg.wireframe_cache_path
+
+ # Get the ground truth source
+ self.gt_source = self.config.get("gt_source_%s"%(self.mode),
+ "official")
+ if not self.gt_source == "official":
+ # Convert gt_source to full path
+ self.gt_source = os.path.join(cfg.export_dataroot, self.gt_source)
+ # Check the full path exists
+ if not os.path.exists(self.gt_source):
+ raise ValueError(
+ "[Error] The specified ground truth source does not exist.")
+
+
+ # Get the filename dataset
+ print("[Info] Initializing wireframe dataset...")
+ self.filename_dataset, self.datapoints = self.construct_dataset()
+
+ # Get dataset length
+ self.dataset_length = len(self.datapoints)
+
+ # Print some info
+ print("[Info] Successfully initialized dataset")
+ print("\t Name: wireframe")
+ print("\t Mode: %s" %(self.mode))
+ print("\t Gt: %s" %(self.config.get("gt_source_%s"%(self.mode),
+ "official")))
+ print("\t Counts: %d" %(self.dataset_length))
+ print("----------------------------------------")
+
+ #######################################
+ ## Dataset construction related APIs ##
+ #######################################
+ def construct_dataset(self):
+ """ Construct the dataset (from scratch or from cache). """
+ # Check if the filename cache exists
+ # If cache exists, load from cache
+ if self._check_dataset_cache():
+ print("\t Found filename cache %s at %s"%(self.cache_name,
+ self.cache_path))
+ print("\t Load filename cache...")
+ filename_dataset, datapoints = self.get_filename_dataset_from_cache()
+ # If not, initialize dataset from scratch
+ else:
+ print("\t Can't find filename cache ...")
+ print("\t Create filename dataset from scratch...")
+ filename_dataset, datapoints = self.get_filename_dataset()
+ print("\t Create filename dataset cache...")
+ self.create_filename_dataset_cache(filename_dataset, datapoints)
+
+ return filename_dataset, datapoints
+
+ def create_filename_dataset_cache(self, filename_dataset, datapoints):
+ """ Create filename dataset cache for faster initialization. """
+ # Check cache path exists
+ if not os.path.exists(self.cache_path):
+ os.makedirs(self.cache_path)
+
+ cache_file_path = os.path.join(self.cache_path, self.cache_name)
+ data = {
+ "filename_dataset": filename_dataset,
+ "datapoints": datapoints
+ }
+ with open(cache_file_path, "wb") as f:
+ pickle.dump(data, f, pickle.HIGHEST_PROTOCOL)
+
+ def get_filename_dataset_from_cache(self):
+ """ Get filename dataset from cache. """
+ # Load from pkl cache
+ cache_file_path = os.path.join(self.cache_path, self.cache_name)
+ with open(cache_file_path, "rb") as f:
+ data = pickle.load(f)
+
+ return data["filename_dataset"], data["datapoints"]
+
+ def get_filename_dataset(self):
+ # Get the path to the dataset
+ if self.mode == "train":
+ dataset_path = os.path.join(cfg.wireframe_dataroot, "train")
+ elif self.mode == "test":
+ dataset_path = os.path.join(cfg.wireframe_dataroot, "valid")
+
+ # Get paths to all image files
+ image_paths = sorted([os.path.join(dataset_path, _)
+ for _ in os.listdir(dataset_path)\
+ if os.path.splitext(_)[-1] == ".png"])
+ # Get the shared prefix
+ prefix_paths = [_.split(".png")[0] for _ in image_paths]
+
+ # Get the label paths (different procedure for different split)
+ if self.mode == "train":
+ label_paths = [_ + "_label.npz" for _ in prefix_paths]
+ else:
+ label_paths = [_ + "_label.npz" for _ in prefix_paths]
+ mat_paths = [p[:-2] + "_line.mat" for p in prefix_paths]
+
+ # Verify all the images and labels exist
+ for idx in range(len(image_paths)):
+ image_path = image_paths[idx]
+ label_path = label_paths[idx]
+ if (not (os.path.exists(image_path)
+ and os.path.exists(label_path))):
+ raise ValueError(
+ "[Error] The image and label do not exist. %s"%(image_path))
+ # Further verify mat paths for test split
+ if self.mode == "test":
+ mat_path = mat_paths[idx]
+ if not os.path.exists(mat_path):
+ raise ValueError(
+ "[Error] The mat file does not exist. %s"%(mat_path))
+
+ # Construct the filename dataset
+ num_pad = int(math.ceil(math.log10(len(image_paths))) + 1)
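+        # e.g. 4500 images => num_pad = 5, so keys look like "00000", "00001", ...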
+ filename_dataset = {}
+ for idx in range(len(image_paths)):
+ # Get the file key
+ key = self.get_padded_filename(num_pad, idx)
+
+ filename_dataset[key] = {
+ "image": image_paths[idx],
+ "label": label_paths[idx]
+ }
+
+ # Get the datapoints
+ datapoints = list(sorted(filename_dataset.keys()))
+
+ return filename_dataset, datapoints
+
+ def get_dataset_name(self):
+ """ Get dataset name from dataset config / default config. """
+ if self.config["dataset_name"] is None:
+ dataset_name = self.default_config["dataset_name"] + "_%s" % self.mode
+ else:
+ dataset_name = self.config["dataset_name"] + "_%s" % self.mode
+
+ return dataset_name
+
+ def get_cache_name(self):
+ """ Get cache name from dataset config / default config. """
+ if self.config["dataset_name"] is None:
+ dataset_name = self.default_config["dataset_name"] + "_%s" % self.mode
+ else:
+ dataset_name = self.config["dataset_name"] + "_%s" % self.mode
+ # Compose cache name
+ cache_name = dataset_name + "_cache.pkl"
+
+ return cache_name
+
+ @staticmethod
+ def get_padded_filename(num_pad, idx):
+ """ Get the padded filename using adaptive padding. """
+ file_len = len("%d" % (idx))
+ filename = "0" * (num_pad - file_len) + "%d" % (idx)
+
+ return filename
+
+ def get_default_config(self):
+ """ Get the default configuration. """
+ return {
+ "dataset_name": "wireframe",
+ "add_augmentation_to_all_splits": False,
+ "preprocessing": {
+ "resize": [240, 320],
+ "blur_size": 11
+ },
+ "augmentation":{
+ "photometric":{
+ "enable": False
+ },
+ "homographic":{
+ "enable": False
+ },
+ },
+ }
+
+
+ ############################################
+ ## Pytorch and preprocessing related APIs ##
+ ############################################
+ # Get data from the information from filename dataset
+ @staticmethod
+ def get_data_from_path(data_path):
+ output = {}
+
+ # Get image data
+ image_path = data_path["image"]
+ image = imread(image_path)
+ output["image"] = image
+
+ # Get the npz label
+ """ Data entries in the npz file
+ jmap: [J, H, W] Junction heat map (H and W are 4x smaller)
+ joff: [J, 2, H, W] Junction offset within each pixel (Not sure about offsets)
+ lmap: [H, W] Line heat map with anti-aliasing (H and W are 4x smaller)
+ junc: [Na, 3] Junction coordinates (coordinates from 0~128 => 4x smaller.)
+ Lpos: [M, 2] Positive lines represented with junction indices
+ Lneg: [M, 2] Negative lines represented with junction indices
+ lpos: [Np, 2, 3] Positive lines represented with junction coordinates
+ lneg: [Nn, 2, 3] Negative lines represented with junction coordinates
+ """
+ label_path = data_path["label"]
+ label = np.load(label_path)
+ for key in list(label.keys()):
+ output[key] = label[key]
+
+ # If there's "line_mat" entry.
+ # TODO: How to process mat data
+ if data_path.get("line_mat") is not None:
+ raise NotImplementedError
+
+ return output
+
+ @staticmethod
+ def convert_line_map(lcnn_line_map, num_junctions):
+ """ Convert the line_pos or line_neg
+ (represented by two junction indexes) to our line map. """
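+        # Illustrative example: num_junctions = 3 and
+        # lcnn_line_map = np.array([[0, 2]]) yield
+        # [[0, 0, 1], [0, 0, 0], [1, 0, 0]].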
+ # Initialize empty line map
+ line_map = np.zeros([num_junctions, num_junctions])
+
+ # Iterate through all the lines
+ for idx in range(lcnn_line_map.shape[0]):
+ index1 = lcnn_line_map[idx, 0]
+ index2 = lcnn_line_map[idx, 1]
+
+ line_map[index1, index2] = 1
+ line_map[index2, index1] = 1
+
+ return line_map
+
+ @staticmethod
+ def junc_to_junc_map(junctions, image_size):
+ """ Convert junction points to junction maps. """
+        junctions = np.round(junctions).astype(int)
+ # Clip the boundary by image size
+ junctions[:, 0] = np.clip(junctions[:, 0], 0., image_size[0]-1)
+ junctions[:, 1] = np.clip(junctions[:, 1], 0., image_size[1]-1)
+
+ # Create junction map
+ junc_map = np.zeros([image_size[0], image_size[1]])
+ junc_map[junctions[:, 0], junctions[:, 1]] = 1
+
+        return junc_map[..., None].astype(int)
+
+ def parse_transforms(self, names, all_transforms):
+ """ Parse the transform. """
+ trans = all_transforms if (names == 'all') \
+ else (names if isinstance(names, list) else [names])
+ assert set(trans) <= set(all_transforms)
+ return trans
+
+ def get_photo_transform(self):
+ """ Get list of photometric transforms (according to the config). """
+ # Get the photometric transform config
+ photo_config = self.config["augmentation"]["photometric"]
+ if not photo_config["enable"]:
+ raise ValueError(
+ "[Error] Photometric augmentation is not enabled.")
+
+ # Parse photometric transforms
+ trans_lst = self.parse_transforms(photo_config["primitives"],
+ photoaug.available_augmentations)
+ trans_config_lst = [photo_config["params"].get(p, {})
+ for p in trans_lst]
+
+ # List of photometric augmentation
+ photometric_trans_lst = [
+ getattr(photoaug, trans)(**conf) \
+ for (trans, conf) in zip(trans_lst, trans_config_lst)
+ ]
+
+ return photometric_trans_lst
+
+ def get_homo_transform(self):
+ """ Get homographic transforms (according to the config). """
+ # Get homographic transforms for image
+ homo_config = self.config["augmentation"]["homographic"]["params"]
+ if not self.config["augmentation"]["homographic"]["enable"]:
+ raise ValueError(
+ "[Error] Homographic augmentation is not enabled.")
+
+ # Parse the homographic transforms
+ image_shape = self.config["preprocessing"]["resize"]
+
+ # Compute the min_label_len from config
+ try:
+ min_label_tmp = self.config["generation"]["min_label_len"]
+        except Exception:
+ min_label_tmp = None
+
+        # float label len => fraction of the (smaller) image dimension
+        if isinstance(min_label_tmp, float):
+ min_label_len = min_label_tmp * min(image_shape)
+ # int label len => length in pixel
+ elif isinstance(min_label_tmp, int):
+ scale_ratio = (self.config["preprocessing"]["resize"]
+ / self.config["generation"]["image_size"][0])
+ min_label_len = (self.config["generation"]["min_label_len"]
+ * scale_ratio)
+ # if none => no restriction
+ else:
+ min_label_len = 0
+
+ # Initialize the transform
+ homographic_trans = homoaug.homography_transform(
+ image_shape, homo_config, 0, min_label_len)
+
+ return homographic_trans
+
+ def get_line_points(self, junctions, line_map, H1=None, H2=None,
+ img_size=None, warp=False):
+ """ Sample evenly points along each line segments
+ and keep track of line idx. """
+ if np.sum(line_map) == 0:
+ # No segment detected in the image
+ line_indices = np.zeros(self.config["max_pts"], dtype=int)
+ line_points = np.zeros((self.config["max_pts"], 2), dtype=float)
+ return line_points, line_indices
+
+ # Extract all pairs of connected junctions
+ junc_indices = np.array(
+ [[i, j] for (i, j) in zip(*np.where(line_map)) if j > i])
+ line_segments = np.stack([junctions[junc_indices[:, 0]],
+ junctions[junc_indices[:, 1]]], axis=1)
+ # line_segments is (num_lines, 2, 2)
+ line_lengths = np.linalg.norm(
+ line_segments[:, 0] - line_segments[:, 1], axis=1)
+
+ # Sample the points separated by at least min_dist_pts along each line
+ # The number of samples depends on the length of the line
+ num_samples = np.minimum(line_lengths // self.config["min_dist_pts"],
+ self.config["max_num_samples"])
+ line_points = []
+ line_indices = []
+ cur_line_idx = 1
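+        # Index 0 is reserved for padded points, so real line ids start at 1.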
+ for n in np.arange(2, self.config["max_num_samples"] + 1):
+ # Consider all lines where we can fit up to n points
+ cur_line_seg = line_segments[num_samples == n]
+ line_points_x = np.linspace(cur_line_seg[:, 0, 0],
+ cur_line_seg[:, 1, 0],
+ n, axis=-1).flatten()
+ line_points_y = np.linspace(cur_line_seg[:, 0, 1],
+ cur_line_seg[:, 1, 1],
+ n, axis=-1).flatten()
+ jitter = self.config.get("jittering", 0)
+ if jitter:
+ # Add a small random jittering of all points along the line
+ angles = np.arctan2(
+ cur_line_seg[:, 1, 0] - cur_line_seg[:, 0, 0],
+ cur_line_seg[:, 1, 1] - cur_line_seg[:, 0, 1]).repeat(n)
+ jitter_hyp = (np.random.rand(len(angles)) * 2 - 1) * jitter
+ line_points_x += jitter_hyp * np.sin(angles)
+ line_points_y += jitter_hyp * np.cos(angles)
+ line_points.append(np.stack([line_points_x, line_points_y], axis=-1))
+ # Keep track of the line indices for each sampled point
+ num_cur_lines = len(cur_line_seg)
+ line_idx = np.arange(cur_line_idx, cur_line_idx + num_cur_lines)
+ line_indices.append(line_idx.repeat(n))
+ cur_line_idx += num_cur_lines
+ line_points = np.concatenate(line_points,
+ axis=0)[:self.config["max_pts"]]
+ line_indices = np.concatenate(line_indices,
+ axis=0)[:self.config["max_pts"]]
+
+        # Warp the points if need be, and filter out invalid ones
+ # If the other view is also warped
+ if warp and H2 is not None:
+ warp_points2 = warp_points(line_points, H2)
+ line_points = warp_points(line_points, H1)
+ mask = mask_points(line_points, img_size)
+ mask2 = mask_points(warp_points2, img_size)
+ mask = mask * mask2
+ # If the other view is not warped
+ elif warp and H2 is None:
+ line_points = warp_points(line_points, H1)
+ mask = mask_points(line_points, img_size)
+ else:
+ if H1 is not None:
+ raise ValueError("[Error] Wrong combination of homographies.")
+ # Remove points that would be outside of img_size if warped by H
+ warped_points = warp_points(line_points, H1)
+ mask = mask_points(warped_points, img_size)
+ line_points = line_points[mask]
+ line_indices = line_indices[mask]
+
+ # Pad the line points to a fixed length
+ # Index of 0 means padded line
+ line_indices = np.concatenate([line_indices, np.zeros(
+ self.config["max_pts"] - len(line_indices))], axis=0)
+ line_points = np.concatenate(
+ [line_points,
+ np.zeros((self.config["max_pts"] - len(line_points), 2),
+ dtype=float)], axis=0)
+
+ return line_points, line_indices
+
+ def train_preprocessing(self, data, numpy=False):
+ """ Train preprocessing for GT data. """
+ # Fetch the corresponding entries
+ image = data["image"]
+ junctions = data["junc"][:, :2]
+ line_pos = data["Lpos"]
+ line_neg = data["Lneg"]
+ image_size = image.shape[:2]
+ # Convert junctions to pixel coordinates (from 128x128)
+ junctions[:, 0] *= image_size[0] / 128
+ junctions[:, 1] *= image_size[1] / 128
+
+ # Resize the image before photometric and homographical augmentations
+        if list(image_size) != self.config["preprocessing"]["resize"]:
+ # Resize the image and the point location.
+ size_old = list(image.shape)[:2] # Only H and W dimensions
+
+ image = cv2.resize(
+ image, tuple(self.config['preprocessing']['resize'][::-1]),
+ interpolation=cv2.INTER_LINEAR)
+ image = np.array(image, dtype=np.uint8)
+
+ # In HW format
+ junctions = (junctions * np.array(
+                self.config['preprocessing']['resize'], float)
+                / np.array(size_old, float))
+
+ # Convert to positive line map and negative line map (our format)
+ num_junctions = junctions.shape[0]
+ line_map_pos = self.convert_line_map(line_pos, num_junctions)
+ line_map_neg = self.convert_line_map(line_neg, num_junctions)
+
+ # Generate the line heatmap after post-processing
+ junctions_xy = np.flip(np.round(junctions).astype(np.int32), axis=1)
+ # Update image size
+ image_size = image.shape[:2]
+ heatmap_pos = get_line_heatmap(junctions_xy, line_map_pos, image_size)
+ heatmap_neg = get_line_heatmap(junctions_xy, line_map_neg, image_size)
+ # Declare default valid mask (all ones)
+ valid_mask = np.ones(image_size)
+
+ # Optionally convert the image to grayscale
+ if self.config["gray_scale"]:
+ image = (color.rgb2gray(image) * 255.).astype(np.uint8)
+
+ # Check if we need to apply augmentations
+ # In training mode => yes.
+ # In homography adaptation mode (export mode) => No
+ if self.config["augmentation"]["photometric"]["enable"]:
+ photo_trans_lst = self.get_photo_transform()
+ ### Image transform ###
+ np.random.shuffle(photo_trans_lst)
+ image_transform = transforms.Compose(
+ photo_trans_lst + [photoaug.normalize_image()])
+ else:
+ image_transform = photoaug.normalize_image()
+ image = image_transform(image)
+
+ # Check homographic augmentation
+ if self.config["augmentation"]["homographic"]["enable"]:
+ homo_trans = self.get_homo_transform()
+ # Perform homographic transform
+ outputs_pos = homo_trans(image, junctions, line_map_pos)
+ outputs_neg = homo_trans(image, junctions, line_map_neg)
+
+ # record the warped results
+ junctions = outputs_pos["junctions"] # Should be HW format
+ image = outputs_pos["warped_image"]
+ line_map_pos = outputs_pos["line_map"]
+ line_map_neg = outputs_neg["line_map"]
+ heatmap_pos = outputs_pos["warped_heatmap"]
+ heatmap_neg = outputs_neg["warped_heatmap"]
+ valid_mask = outputs_pos["valid_mask"] # Same for pos and neg
+
+ junction_map = self.junc_to_junc_map(junctions, image_size)
+
+ # Convert to tensor and return the results
+ to_tensor = transforms.ToTensor()
+ if not numpy:
+ return {
+ "image": to_tensor(image),
+ "junctions": to_tensor(junctions).to(torch.float32)[0, ...],
+ "junction_map": to_tensor(junction_map).to(torch.int),
+ "line_map_pos": to_tensor(
+ line_map_pos).to(torch.int32)[0, ...],
+ "line_map_neg": to_tensor(
+ line_map_neg).to(torch.int32)[0, ...],
+ "heatmap_pos": to_tensor(heatmap_pos).to(torch.int32),
+ "heatmap_neg": to_tensor(heatmap_neg).to(torch.int32),
+ "valid_mask": to_tensor(valid_mask).to(torch.int32)
+ }
+ else:
+ return {
+ "image": image,
+ "junctions": junctions.astype(np.float32),
+ "junction_map": junction_map.astype(np.int32),
+ "line_map_pos": line_map_pos.astype(np.int32),
+ "line_map_neg": line_map_neg.astype(np.int32),
+ "heatmap_pos": heatmap_pos.astype(np.int32),
+ "heatmap_neg": heatmap_neg.astype(np.int32),
+ "valid_mask": valid_mask.astype(np.int32)
+ }
+
+ def train_preprocessing_exported(
+ self, data, numpy=False, disable_homoaug=False,
+ desc_training=False, H1=None, H1_scale=None, H2=None, scale=1.,
+ h_crop=None, w_crop=None):
+ """ Train preprocessing for the exported labels. """
+ data = copy.deepcopy(data)
+ # Fetch the corresponding entries
+ image = data["image"]
+ junctions = data["junctions"]
+ line_map = data["line_map"]
+ image_size = image.shape[:2]
+
+ # Define the random crop for scaling if necessary
+ if h_crop is None or w_crop is None:
+ h_crop, w_crop = 0, 0
+ if scale > 1:
+ H, W = self.config["preprocessing"]["resize"]
+ H_scale, W_scale = round(H * scale), round(W * scale)
+ if H_scale > H:
+ h_crop = np.random.randint(H_scale - H)
+ if W_scale > W:
+ w_crop = np.random.randint(W_scale - W)
+
+ # Resize the image before photometric and homographical augmentations
+        if list(image_size) != self.config["preprocessing"]["resize"]:
+ # Resize the image and the point location.
+ size_old = list(image.shape)[:2] # Only H and W dimensions
+
+ image = cv2.resize(
+ image, tuple(self.config['preprocessing']['resize'][::-1]),
+ interpolation=cv2.INTER_LINEAR)
+ image = np.array(image, dtype=np.uint8)
+
+ # # In HW format
+ # junctions = (junctions * np.array(
+ # self.config['preprocessing']['resize'], np.float)
+ # / np.array(size_old, np.float))
+
+ # Generate the line heatmap after post-processing
+ junctions_xy = np.flip(np.round(junctions).astype(np.int32), axis=1)
+ image_size = image.shape[:2]
+ heatmap = get_line_heatmap(junctions_xy, line_map, image_size)
+
+ # Optionally convert the image to grayscale
+ if self.config["gray_scale"]:
+ image = (color.rgb2gray(image) * 255.).astype(np.uint8)
+
+ # Check if we need to apply augmentations
+ # In training mode => yes.
+ # In homography adaptation mode (export mode) => No
+ if self.config["augmentation"]["photometric"]["enable"]:
+ photo_trans_lst = self.get_photo_transform()
+ ### Image transform ###
+ np.random.shuffle(photo_trans_lst)
+ image_transform = transforms.Compose(
+ photo_trans_lst + [photoaug.normalize_image()])
+ else:
+ image_transform = photoaug.normalize_image()
+ image = image_transform(image)
+
+ # Perform the random scaling
+ if scale != 1.:
+ image, junctions, line_map, valid_mask = random_scaling(
+ image, junctions, line_map, scale,
+ h_crop=h_crop, w_crop=w_crop)
+ else:
+ # Declare default valid mask (all ones)
+ valid_mask = np.ones(image_size)
+
+ # Initialize the empty output dict
+ outputs = {}
+ # Convert to tensor and return the results
+ to_tensor = transforms.ToTensor()
+
+ # Check homographic augmentation
+ warp = (self.config["augmentation"]["homographic"]["enable"]
+                and not disable_homoaug)
+ if warp:
+ homo_trans = self.get_homo_transform()
+ # Perform homographic transform
+ if H1 is None:
+ homo_outputs = homo_trans(
+ image, junctions, line_map, valid_mask=valid_mask)
+ else:
+ homo_outputs = homo_trans(
+ image, junctions, line_map, homo=H1, scale=H1_scale,
+ valid_mask=valid_mask)
+ homography_mat = homo_outputs["homo"]
+
+ # Give the warp of the other view
+ if H1 is None:
+ H1 = homo_outputs["homo"]
+
+ # Sample points along each line segments for the descriptor
+ if desc_training:
+ line_points, line_indices = self.get_line_points(
+ junctions, line_map, H1=H1, H2=H2,
+ img_size=image_size, warp=warp)
+
+ # Record the warped results
+ if warp:
+ junctions = homo_outputs["junctions"] # Should be HW format
+ image = homo_outputs["warped_image"]
+ line_map = homo_outputs["line_map"]
+ valid_mask = homo_outputs["valid_mask"] # Same for pos and neg
+ heatmap = homo_outputs["warped_heatmap"]
+
+ # Optionally put warping information first.
+ if not numpy:
+ outputs["homography_mat"] = to_tensor(
+ homography_mat).to(torch.float32)[0, ...]
+ else:
+ outputs["homography_mat"] = homography_mat.astype(np.float32)
+
+ junction_map = self.junc_to_junc_map(junctions, image_size)
+
+ if not numpy:
+ outputs.update({
+ "image": to_tensor(image).to(torch.float32),
+ "junctions": to_tensor(junctions).to(torch.float32)[0, ...],
+ "junction_map": to_tensor(junction_map).to(torch.int),
+ "line_map": to_tensor(line_map).to(torch.int32)[0, ...],
+ "heatmap": to_tensor(heatmap).to(torch.int32),
+ "valid_mask": to_tensor(valid_mask).to(torch.int32)
+ })
+ if desc_training:
+ outputs.update({
+ "line_points": to_tensor(
+ line_points).to(torch.float32)[0],
+ "line_indices": torch.tensor(line_indices,
+ dtype=torch.int)
+ })
+ else:
+ outputs.update({
+ "image": image,
+ "junctions": junctions.astype(np.float32),
+ "junction_map": junction_map.astype(np.int32),
+ "line_map": line_map.astype(np.int32),
+ "heatmap": heatmap.astype(np.int32),
+ "valid_mask": valid_mask.astype(np.int32)
+ })
+ if desc_training:
+ outputs.update({
+ "line_points": line_points.astype(np.float32),
+ "line_indices": line_indices.astype(int)
+ })
+
+ return outputs
+
+ def preprocessing_exported_paired_desc(self, data, numpy=False, scale=1.):
+ """ Train preprocessing for paired data for the exported labels
+ for descriptor training. """
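+        # Two views of the same sample are produced: the target view samples
+        # its own homography, while the reference view is warped with the
+        # pre-sampled `ref_H`; each view receives the other's homography so
+        # that line points invisible in either view are filtered consistently.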
+ outputs = {}
+
+ # Define the random crop for scaling if necessary
+ h_crop, w_crop = 0, 0
+ if scale > 1:
+ H, W = self.config["preprocessing"]["resize"]
+ H_scale, W_scale = round(H * scale), round(W * scale)
+ if H_scale > H:
+ h_crop = np.random.randint(H_scale - H)
+ if W_scale > W:
+ w_crop = np.random.randint(W_scale - W)
+
+ # Sample ref homography first
+ homo_config = self.config["augmentation"]["homographic"]["params"]
+ image_shape = self.config["preprocessing"]["resize"]
+ ref_H, ref_scale = homoaug.sample_homography(image_shape,
+ **homo_config)
+
+ # Data for target view (All augmentation)
+ target_data = self.train_preprocessing_exported(
+ data, numpy=numpy, desc_training=True, H1=None, H2=ref_H,
+ scale=scale, h_crop=h_crop, w_crop=w_crop)
+
+ # Data for reference view (No homographical augmentation)
+ ref_data = self.train_preprocessing_exported(
+ data, numpy=numpy, desc_training=True, H1=ref_H,
+ H1_scale=ref_scale, H2=target_data["homography_mat"].numpy(),
+ scale=scale, h_crop=h_crop, w_crop=w_crop)
+
+ # Spread ref data
+ for key, val in ref_data.items():
+ outputs["ref_" + key] = val
+
+ # Spread target data
+ for key, val in target_data.items():
+ outputs["target_" + key] = val
+
+ return outputs
+
+ def test_preprocessing(self, data, numpy=False):
+ """ Test preprocessing for GT data. """
+ data = copy.deepcopy(data)
+ # Fetch the corresponding entries
+ image = data["image"]
+ junctions = data["junc"][:, :2]
+ line_pos = data["Lpos"]
+ line_neg = data["Lneg"]
+ image_size = image.shape[:2]
+ # Convert junctions to pixel coordinates (from 128x128)
+ junctions[:, 0] *= image_size[0] / 128
+ junctions[:, 1] *= image_size[1] / 128
+
+ # Resize the image before photometric and homographical augmentations
+        if list(image_size) != self.config["preprocessing"]["resize"]:
+ # Resize the image and the point location.
+ size_old = list(image.shape)[:2] # Only H and W dimensions
+
+ image = cv2.resize(
+ image, tuple(self.config['preprocessing']['resize'][::-1]),
+ interpolation=cv2.INTER_LINEAR)
+ image = np.array(image, dtype=np.uint8)
+
+ # In HW format
+ junctions = (junctions * np.array(
+                self.config['preprocessing']['resize'], float)
+                / np.array(size_old, float))
+
+ # Optionally convert the image to grayscale
+ if self.config["gray_scale"]:
+ image = (color.rgb2gray(image) * 255.).astype(np.uint8)
+
+ # Still need to normalize image
+ image_transform = photoaug.normalize_image()
+ image = image_transform(image)
+
+ # Convert to positive line map and negative line map (our format)
+ num_junctions = junctions.shape[0]
+ line_map_pos = self.convert_line_map(line_pos, num_junctions)
+ line_map_neg = self.convert_line_map(line_neg, num_junctions)
+
+ # Generate the line heatmap after post-processing
+ junctions_xy = np.flip(np.round(junctions).astype(np.int32), axis=1)
+ # Update image size
+ image_size = image.shape[:2]
+ heatmap_pos = get_line_heatmap(junctions_xy, line_map_pos, image_size)
+ heatmap_neg = get_line_heatmap(junctions_xy, line_map_neg, image_size)
+ # Declare default valid mask (all ones)
+ valid_mask = np.ones(image_size)
+
+ junction_map = self.junc_to_junc_map(junctions, image_size)
+
+ # Convert to tensor and return the results
+ to_tensor = transforms.ToTensor()
+ if not numpy:
+ return {
+ "image": to_tensor(image),
+ "junctions": to_tensor(junctions).to(torch.float32)[0, ...],
+ "junction_map": to_tensor(junction_map).to(torch.int),
+ "line_map_pos": to_tensor(
+ line_map_pos).to(torch.int32)[0, ...],
+ "line_map_neg": to_tensor(
+ line_map_neg).to(torch.int32)[0, ...],
+ "heatmap_pos": to_tensor(heatmap_pos).to(torch.int32),
+ "heatmap_neg": to_tensor(heatmap_neg).to(torch.int32),
+ "valid_mask": to_tensor(valid_mask).to(torch.int32)
+ }
+ else:
+ return {
+ "image": image,
+ "junctions": junctions.astype(np.float32),
+ "junction_map": junction_map.astype(np.int32),
+ "line_map_pos": line_map_pos.astype(np.int32),
+ "line_map_neg": line_map_neg.astype(np.int32),
+ "heatmap_pos": heatmap_pos.astype(np.int32),
+ "heatmap_neg": heatmap_neg.astype(np.int32),
+ "valid_mask": valid_mask.astype(np.int32)
+ }
+
+ def test_preprocessing_exported(self, data, numpy=False, scale=1.):
+ """ Test preprocessing for the exported labels. """
+ data = copy.deepcopy(data)
+ # Fetch the corresponding entries
+ image = data["image"]
+ junctions = data["junctions"]
+ line_map = data["line_map"]
+ image_size = image.shape[:2]
+
+ # Resize the image before photometric and homographical augmentations
+        if list(image_size) != self.config["preprocessing"]["resize"]:
+ # Resize the image and the point location.
+ size_old = list(image.shape)[:2] # Only H and W dimensions
+
+ image = cv2.resize(
+ image, tuple(self.config['preprocessing']['resize'][::-1]),
+ interpolation=cv2.INTER_LINEAR)
+ image = np.array(image, dtype=np.uint8)
+
+ # # In HW format
+ # junctions = (junctions * np.array(
+ # self.config['preprocessing']['resize'], np.float)
+ # / np.array(size_old, np.float))
+
+ # Optionally convert the image to grayscale
+ if self.config["gray_scale"]:
+ image = (color.rgb2gray(image) * 255.).astype(np.uint8)
+
+ # Still need to normalize image
+ image_transform = photoaug.normalize_image()
+ image = image_transform(image)
+
+ # Generate the line heatmap after post-processing
+ junctions_xy = np.flip(np.round(junctions).astype(np.int32), axis=1)
+ image_size = image.shape[:2]
+ heatmap = get_line_heatmap(junctions_xy, line_map, image_size)
+
+ # Declare default valid mask (all ones)
+ valid_mask = np.ones(image_size)
+
+ junction_map = self.junc_to_junc_map(junctions, image_size)
+
+ # Convert to tensor and return the results
+ to_tensor = transforms.ToTensor()
+ if not numpy:
+ outputs = {
+ "image": to_tensor(image),
+ "junctions": to_tensor(junctions).to(torch.float32)[0, ...],
+ "junction_map": to_tensor(junction_map).to(torch.int),
+ "line_map": to_tensor(line_map).to(torch.int32)[0, ...],
+ "heatmap": to_tensor(heatmap).to(torch.int32),
+ "valid_mask": to_tensor(valid_mask).to(torch.int32)
+ }
+ else:
+ outputs = {
+ "image": image,
+ "junctions": junctions.astype(np.float32),
+ "junction_map": junction_map.astype(np.int32),
+ "line_map": line_map.astype(np.int32),
+ "heatmap": heatmap.astype(np.int32),
+ "valid_mask": valid_mask.astype(np.int32)
+ }
+
+ return outputs
+
+ def __len__(self):
+ return self.dataset_length
+
+ def get_data_from_key(self, file_key):
+ """ Get data from file_key. """
+ # Check key exists
+        if file_key not in self.filename_dataset:
+ raise ValueError("[Error] the specified key is not in the dataset.")
+
+ # Get the data paths
+ data_path = self.filename_dataset[file_key]
+ # Read in the image and npz labels (but haven't applied any transform)
+ data = self.get_data_from_path(data_path)
+
+ # Perform transform and augmentation
+ if self.mode == "train" or self.config["add_augmentation_to_all_splits"]:
+ data = self.train_preprocessing(data, numpy=True)
+ else:
+ data = self.test_preprocessing(data, numpy=True)
+
+ # Add file key to the output
+ data["file_key"] = file_key
+
+ return data
+
+ def __getitem__(self, idx):
+ """Return data
+ file_key: str, keys used to retrieve data from the filename dataset.
+ image: torch.float, C*H*W range 0~1,
+ junctions: torch.float, N*2,
+ junction_map: torch.int32, 1*H*W range 0 or 1,
+ line_map_pos: torch.int32, N*N range 0 or 1,
+ line_map_neg: torch.int32, N*N range 0 or 1,
+ heatmap_pos: torch.int32, 1*H*W range 0 or 1,
+ heatmap_neg: torch.int32, 1*H*W range 0 or 1,
+ valid_mask: torch.int32, 1*H*W range 0 or 1
+ """
+ # Get the corresponding datapoint and contents from filename dataset
+ file_key = self.datapoints[idx]
+ data_path = self.filename_dataset[file_key]
+ # Read in the image and npz labels (but haven't applied any transform)
+ data = self.get_data_from_path(data_path)
+
+ # Also load the exported labels if not using the official ground truth
+ if not self.gt_source == "official":
+ with h5py.File(self.gt_source, "r") as f:
+ exported_label = parse_h5_data(f[file_key])
+
+ data["junctions"] = exported_label["junctions"]
+ data["line_map"] = exported_label["line_map"]
+
+ # Perform transform and augmentation
+ return_type = self.config.get("return_type", "single")
+ if (self.mode == "train"
+ or self.config["add_augmentation_to_all_splits"]):
+ # Perform random scaling first
+ if self.config["augmentation"]["random_scaling"]["enable"]:
+ scale_range = self.config["augmentation"]["random_scaling"]["range"]
+ # Decide the scaling
+ scale = np.random.uniform(min(scale_range), max(scale_range))
+ else:
+ scale = 1.
+ if self.gt_source == "official":
+ data = self.train_preprocessing(data)
+ else:
+ if return_type == "paired_desc":
+ data = self.preprocessing_exported_paired_desc(
+ data, scale=scale)
+ else:
+ data = self.train_preprocessing_exported(data,
+ scale=scale)
+ else:
+ if self.gt_source == "official":
+ data = self.test_preprocessing(data)
+ elif return_type == "paired_desc":
+ data = self.preprocessing_exported_paired_desc(data)
+ else:
+ data = self.test_preprocessing_exported(data)
+
+ # Add file key to the output
+ data["file_key"] = file_key
+
+ return data
+
+ ########################
+ ## Some other methods ##
+ ########################
+ def _check_dataset_cache(self):
+ """ Check if dataset cache exists. """
+ cache_file_path = os.path.join(self.cache_path, self.cache_name)
+ if os.path.exists(cache_file_path):
+ return True
+ else:
+ return False
diff --git a/imcui/third_party/SOLD2/sold2/experiment.py b/imcui/third_party/SOLD2/sold2/experiment.py
new file mode 100644
index 0000000000000000000000000000000000000000..3bf4db1c9f148b9e33c6d7d0ba973375cd770a14
--- /dev/null
+++ b/imcui/third_party/SOLD2/sold2/experiment.py
@@ -0,0 +1,227 @@
+"""
+Main file to launch training and testing experiments.
+"""
+
+import yaml
+import os
+import argparse
+import numpy as np
+import torch
+
+from .config.project_config import Config as cfg
+from .train import train_net
+from .export import export_predictions, export_homograpy_adaptation
+
+
+# Pytorch configurations
+torch.cuda.empty_cache()
+torch.backends.cudnn.benchmark = True
+
+
+def load_config(config_path):
+ """ Load configurations from a given yaml file. """
+ # Check file exists
+ if not os.path.exists(config_path):
+ raise ValueError("[Error] The provided config path is not valid.")
+
+ # Load the configuration
+ with open(config_path, "r") as f:
+ config = yaml.safe_load(f)
+
+ return config
+
+
+def update_config(path, model_cfg=None, dataset_cfg=None):
+ """ Update configuration file from the resume path. """
+ # Check we need to update or completely override.
+ model_cfg = {} if model_cfg is None else model_cfg
+ dataset_cfg = {} if dataset_cfg is None else dataset_cfg
+
+ # Load saved configs
+ with open(os.path.join(path, "model_cfg.yaml"), "r") as f:
+ model_cfg_saved = yaml.safe_load(f)
+ model_cfg.update(model_cfg_saved)
+ with open(os.path.join(path, "dataset_cfg.yaml"), "r") as f:
+ dataset_cfg_saved = yaml.safe_load(f)
+ dataset_cfg.update(dataset_cfg_saved)
+
+ # Update the saved yaml file
+ if not model_cfg == model_cfg_saved:
+ with open(os.path.join(path, "model_cfg.yaml"), "w") as f:
+ yaml.dump(model_cfg, f)
+ if not dataset_cfg == dataset_cfg_saved:
+ with open(os.path.join(path, "dataset_cfg.yaml"), "w") as f:
+ yaml.dump(dataset_cfg, f)
+
+ return model_cfg, dataset_cfg
+
+
+def record_config(model_cfg, dataset_cfg, output_path):
+ """ Record dataset config to the log path. """
+ # Record model config
+ with open(os.path.join(output_path, "model_cfg.yaml"), "w") as f:
+ yaml.safe_dump(model_cfg, f)
+
+ # Record dataset config
+ with open(os.path.join(output_path, "dataset_cfg.yaml"), "w") as f:
+ yaml.safe_dump(dataset_cfg, f)
+
+
+def train(args, dataset_cfg, model_cfg, output_path):
+ """ Training function. """
+ # Update model config from the resume path (only in resume mode)
+ if args.resume:
+ if os.path.realpath(output_path) != os.path.realpath(args.resume_path):
+ record_config(model_cfg, dataset_cfg, output_path)
+
+ # First time, then write the config file to the output path
+ else:
+ record_config(model_cfg, dataset_cfg, output_path)
+
+ # Launch the training
+ train_net(args, dataset_cfg, model_cfg, output_path)
+
+
+def export(args, dataset_cfg, model_cfg, output_path,
+ export_dataset_mode=None, device=torch.device("cuda")):
+ """ Export function. """
+ # Choose between normal predictions export or homography adaptation
+ if dataset_cfg.get("homography_adaptation") is not None:
+ print("[Info] Export predictions with homography adaptation.")
+ export_homograpy_adaptation(args, dataset_cfg, model_cfg, output_path,
+ export_dataset_mode, device)
+ else:
+ print("[Info] Export predictions normally.")
+ export_predictions(args, dataset_cfg, model_cfg, output_path,
+ export_dataset_mode)
+
+
+def main(args, dataset_cfg, model_cfg, export_dataset_mode=None,
+ device=torch.device("cuda")):
+ """ Main function. """
+ # Make the output path
+ output_path = os.path.join(cfg.EXP_PATH, args.exp_name)
+
+ if args.mode == "train":
+ if not os.path.exists(output_path):
+ os.makedirs(output_path)
+ print("[Info] Training mode")
+ print("\t Output path: %s" % output_path)
+ train(args, dataset_cfg, model_cfg, output_path)
+ elif args.mode == "export":
+ # Different output_path in export mode
+ output_path = os.path.join(cfg.export_dataroot, args.exp_name)
+ print("[Info] Export mode")
+ print("\t Output path: %s" % output_path)
+ export(args, dataset_cfg, model_cfg, output_path, export_dataset_mode, device=device)
+ else:
+ raise ValueError("[Error]: Unknown mode: " + args.mode)
+
+
+def set_random_seed(seed):
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+
+
+if __name__ == "__main__":
+ # Parse input arguments
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--mode", type=str, default="train",
+ help="'train' or 'export'.")
+ parser.add_argument("--dataset_config", type=str, default=None,
+ help="Path to the dataset config.")
+ parser.add_argument("--model_config", type=str, default=None,
+ help="Path to the model config.")
+ parser.add_argument("--exp_name", type=str, default="exp",
+ help="Experiment name.")
+ parser.add_argument("--resume", action="store_true", default=False,
+ help="Load a previously trained model.")
+ parser.add_argument("--pretrained", action="store_true", default=False,
+ help="Start training from a pre-trained model.")
+ parser.add_argument("--resume_path", default=None,
+ help="Path from which to resume training.")
+ parser.add_argument("--pretrained_path", default=None,
+ help="Path to the pre-trained model.")
+ parser.add_argument("--checkpoint_name", default=None,
+ help="Name of the checkpoint to use.")
+ parser.add_argument("--export_dataset_mode", default=None,
+ help="'train' or 'test'.")
+ parser.add_argument("--export_batch_size", default=4, type=int,
+ help="Export batch size.")
+
+ args = parser.parse_args()
+
+ # Check if GPU is available
+ # Get the model
+ if torch.cuda.is_available():
+ device = torch.device("cuda")
+ else:
+ device = torch.device("cpu")
+
+ # Check if dataset config and model config is given.
+ if (((args.dataset_config is None) or (args.model_config is None))
+ and (not args.resume) and (args.mode == "train")):
+ raise ValueError(
+ "[Error] The dataset config and model config should be given in non-resume mode")
+
+ # If resume, check if the resume path has been given
+ if args.resume and (args.resume_path is None):
+ raise ValueError(
+ "[Error] Missing resume path.")
+
+ # [Training] Load the config file.
+ if args.mode == "train" and (not args.resume):
+ # Check the pretrained checkpoint_path exists
+ if args.pretrained:
+ checkpoint_folder = args.resume_path
+ checkpoint_path = os.path.join(args.pretrained_path,
+ args.checkpoint_name)
+ if not os.path.exists(checkpoint_path):
+ raise ValueError("[Error] Missing checkpoint: "
+ + checkpoint_path)
+ dataset_cfg = load_config(args.dataset_config)
+ model_cfg = load_config(args.model_config)
+
+ # [resume Training, Test, Export] Load the config file.
+ elif (args.mode == "train" and args.resume) or (args.mode == "export"):
+ # Check checkpoint path exists
+ checkpoint_folder = args.resume_path
+ checkpoint_path = os.path.join(args.resume_path, args.checkpoint_name)
+ if not os.path.exists(checkpoint_path):
+ raise ValueError("[Error] Missing checkpoint: " + checkpoint_path)
+
+ # Load model_cfg from checkpoint folder if not provided
+ if args.model_config is None:
+ print("[Info] No model config provided. Loading from checkpoint folder.")
+ model_cfg_path = os.path.join(checkpoint_folder, "model_cfg.yaml")
+ if not os.path.exists(model_cfg_path):
+ raise ValueError(
+ "[Error] Missing model config in checkpoint path.")
+ model_cfg = load_config(model_cfg_path)
+ else:
+ model_cfg = load_config(args.model_config)
+
+ # Load dataset_cfg from checkpoint folder if not provided
+ if args.dataset_config is None:
+ print("[Info] No dataset config provided. Loading from checkpoint folder.")
+ dataset_cfg_path = os.path.join(checkpoint_folder,
+ "dataset_cfg.yaml")
+ if not os.path.exists(dataset_cfg_path):
+ raise ValueError(
+ "[Error] Missing dataset config in checkpoint path.")
+ dataset_cfg = load_config(dataset_cfg_path)
+ else:
+ dataset_cfg = load_config(args.dataset_config)
+
+ # Check the --export_dataset_mode flag
+ if (args.mode == "export") and (args.export_dataset_mode is None):
+ raise ValueError("[Error] Empty --export_dataset_mode flag.")
+ else:
+ raise ValueError("[Error] Unknown mode: " + args.mode)
+
+ # Set the random seed
+ seed = dataset_cfg.get("random_seed", 0)
+ set_random_seed(seed)
+
+ main(args, dataset_cfg, model_cfg,
+ export_dataset_mode=args.export_dataset_mode, device=device)
diff --git a/imcui/third_party/SOLD2/sold2/export.py b/imcui/third_party/SOLD2/sold2/export.py
new file mode 100644
index 0000000000000000000000000000000000000000..19683d982c6d7fd429b27868b620fd20562d1aa7
--- /dev/null
+++ b/imcui/third_party/SOLD2/sold2/export.py
@@ -0,0 +1,342 @@
+import numpy as np
+import copy
+import cv2
+import h5py
+import math
+from tqdm import tqdm
+import torch
+from torch.nn.functional import pixel_shuffle, softmax
+from torch.utils.data import DataLoader
+from kornia.geometry import warp_perspective
+
+from .dataset.dataset_util import get_dataset
+from .model.model_util import get_model
+from .misc.train_utils import get_latest_checkpoint
+from .train import convert_junc_predictions
+from .dataset.transforms.homographic_transforms import sample_homography
+
+
+def restore_weights(model, state_dict):
+ """ Restore weights in compatible mode. """
+ # Try to directly load state dict
+ try:
+ model.load_state_dict(state_dict)
+    except Exception:
+ err = model.load_state_dict(state_dict, strict=False)
+ # missing keys are those in model but not in state_dict
+ missing_keys = err.missing_keys
+ # Unexpected keys are those in state_dict but not in model
+ unexpected_keys = err.unexpected_keys
+
+ # Load mismatched keys manually
+ model_dict = model.state_dict()
+ for idx, key in enumerate(missing_keys):
+ dict_keys = [_ for _ in unexpected_keys if not "tracked" in _]
+ model_dict[key] = state_dict[dict_keys[idx]]
+ model.load_state_dict(model_dict)
+ return model
+
+
+def get_padded_filename(num_pad, idx):
+ """ Get the filename padded with 0. """
+ file_len = len("%d" % (idx))
+ filename = "0" * (num_pad - file_len) + "%d" % (idx)
+ return filename
+
+
+def export_predictions(args, dataset_cfg, model_cfg, output_path,
+ export_dataset_mode):
+ """ Export predictions. """
+ # Get the test configuration
+ test_cfg = model_cfg["test"]
+
+ # Create the dataset and dataloader based on the export_dataset_mode
+ print("\t Initializing dataset and dataloader")
+ batch_size = 4
+ export_dataset, collate_fn = get_dataset(export_dataset_mode, dataset_cfg)
+ export_loader = DataLoader(export_dataset, batch_size=batch_size,
+ num_workers=test_cfg.get("num_workers", 4),
+ shuffle=False, pin_memory=False,
+ collate_fn=collate_fn)
+ print("\t Successfully intialized dataset and dataloader.")
+
+ # Initialize model and load the checkpoint
+ model = get_model(model_cfg, mode="test")
+ checkpoint = get_latest_checkpoint(args.resume_path, args.checkpoint_name)
+ model = restore_weights(model, checkpoint["model_state_dict"])
+ model = model.cuda()
+ model.eval()
+ print("\t Successfully initialized model")
+
+ # Start the export process
+ print("[Info] Start exporting predictions")
+ output_dataset_path = output_path + ".h5"
+ filename_idx = 0
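+    # Each sample becomes one zero-padded group in the output .h5 file,
+    # holding one gzip-compressed dataset per entry (image, junctions, ...).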
+ with h5py.File(output_dataset_path, "w", libver="latest", swmr=True) as f:
+ # Iterate through all the data in dataloader
+ for data in tqdm(export_loader, ascii=True):
+ # Fetch the data
+ junc_map = data["junction_map"]
+ heatmap = data["heatmap"]
+ valid_mask = data["valid_mask"]
+ input_images = data["image"].cuda()
+
+ # Run the forward pass
+ with torch.no_grad():
+ outputs = model(input_images)
+
+ # Convert predictions
+ junc_np = convert_junc_predictions(
+ outputs["junctions"], model_cfg["grid_size"],
+ model_cfg["detection_thresh"], 300)
+ junc_map_np = junc_map.numpy().transpose(0, 2, 3, 1)
+ heatmap_np = softmax(outputs["heatmap"].detach(),
+ dim=1).cpu().numpy().transpose(0, 2, 3, 1)
+ heatmap_gt_np = heatmap.numpy().transpose(0, 2, 3, 1)
+ valid_mask_np = valid_mask.numpy().transpose(0, 2, 3, 1)
+
+ # Data entries to save
+ current_batch_size = input_images.shape[0]
+ for batch_idx in range(current_batch_size):
+ output_data = {
+ "image": input_images.cpu().numpy().transpose(0, 2, 3, 1)[batch_idx],
+ "junc_gt": junc_map_np[batch_idx],
+ "junc_pred": junc_np["junc_pred"][batch_idx],
+ "junc_pred_nms": junc_np["junc_pred_nms"][batch_idx].astype(np.float32),
+ "heatmap_gt": heatmap_gt_np[batch_idx],
+ "heatmap_pred": heatmap_np[batch_idx],
+ "valid_mask": valid_mask_np[batch_idx],
+ "junc_points": data["junctions"][batch_idx].numpy()[0].round().astype(np.int32),
+ "line_map": data["line_map"][batch_idx].numpy()[0].astype(np.int32)
+ }
+
+ # Save data to h5 dataset
+ num_pad = math.ceil(math.log10(len(export_loader))) + 1
+ output_key = get_padded_filename(num_pad, filename_idx)
+ f_group = f.create_group(output_key)
+
+ # Store data
+            for key, value in output_data.items():
+                f_group.create_dataset(key, data=value,
+                                       compression="gzip")
+ filename_idx += 1
+
+
+def export_homograpy_adaptation(args, dataset_cfg, model_cfg, output_path,
+ export_dataset_mode, device):
+ """ Export homography adaptation results. """
+ # Check if the export_dataset_mode is supported
+ supported_modes = ["train", "test"]
+    if export_dataset_mode not in supported_modes:
+ raise ValueError(
+ "[Error] The specified export_dataset_mode is not supported.")
+
+ # Get the test configuration
+ test_cfg = model_cfg["test"]
+
+ # Get the homography adaptation configurations
+ homography_cfg = dataset_cfg.get("homography_adaptation", None)
+ if homography_cfg is None:
+ raise ValueError(
+ "[Error] Empty homography_adaptation entry in config.")
+
+ # Create the dataset and dataloader based on the export_dataset_mode
+ print("\t Initializing dataset and dataloader")
+ batch_size = args.export_batch_size
+
+ export_dataset, collate_fn = get_dataset(export_dataset_mode, dataset_cfg)
+ export_loader = DataLoader(export_dataset, batch_size=batch_size,
+ num_workers=test_cfg.get("num_workers", 4),
+ shuffle=False, pin_memory=False,
+ collate_fn=collate_fn)
+ print("\t Successfully intialized dataset and dataloader.")
+
+ # Initialize model and load the checkpoint
+ model = get_model(model_cfg, mode="test")
+ checkpoint = get_latest_checkpoint(args.resume_path, args.checkpoint_name,
+ device)
+ model = restore_weights(model, checkpoint["model_state_dict"])
+ model = model.to(device).eval()
+ print("\t Successfully initialized model")
+
+ # Start the export process
+ print("[Info] Start exporting predictions")
+ output_dataset_path = output_path + ".h5"
+ with h5py.File(output_dataset_path, "w", libver="latest") as f:
+        f.swmr_mode = True
+ for _, data in enumerate(tqdm(export_loader, ascii=True)):
+ input_images = data["image"].to(device)
+ file_keys = data["file_key"]
+ batch_size = input_images.shape[0]
+
+            # Run the homography adaptation
+ outputs = homography_adaptation(input_images, model,
+ model_cfg["grid_size"],
+ homography_cfg)
+
+ # Save the entries
+ for batch_idx in range(batch_size):
+ # Get the save key
+ save_key = file_keys[batch_idx]
+ output_data = {
+ "image": input_images.cpu().numpy().transpose(0, 2, 3, 1)[batch_idx],
+ "junc_prob_mean": outputs["junc_probs_mean"].cpu().numpy().transpose(0, 2, 3, 1)[batch_idx],
+ "junc_prob_max": outputs["junc_probs_max"].cpu().numpy().transpose(0, 2, 3, 1)[batch_idx],
+ "junc_count": outputs["junc_counts"].cpu().numpy().transpose(0, 2, 3, 1)[batch_idx],
+ "heatmap_prob_mean": outputs["heatmap_probs_mean"].cpu().numpy().transpose(0, 2, 3, 1)[batch_idx],
+ "heatmap_prob_max": outputs["heatmap_probs_max"].cpu().numpy().transpose(0, 2, 3, 1)[batch_idx],
+ "heatmap_cout": outputs["heatmap_counts"].cpu().numpy().transpose(0, 2, 3, 1)[batch_idx]
+ }
+
+ # Create group and write data
+ f_group = f.create_group(save_key)
+                for key, value in output_data.items():
+                    f_group.create_dataset(key, data=value,
+                                           compression="gzip")
+
+
+def homography_adaptation(input_images, model, grid_size, homography_cfg):
+ """ The homography adaptation process.
+ Arguments:
+ input_images: The images to be evaluated.
+ model: The pytorch model in evaluation mode.
+ grid_size: Grid size of the junction decoder.
+ homography_cfg: Homography adaptation configurations.
+ """
+ # Get the device of the current model
+ device = next(model.parameters()).device
+
+ # Define some constants and placeholder
+ batch_size, _, H, W = input_images.shape
+ num_iter = homography_cfg["num_iter"]
+ junc_probs = torch.zeros([batch_size, num_iter, H, W], device=device)
+ junc_counts = torch.zeros([batch_size, 1, H, W], device=device)
+ heatmap_probs = torch.zeros([batch_size, num_iter, H, W], device=device)
+ heatmap_counts = torch.zeros([batch_size, 1, H, W], device=device)
+ margin = homography_cfg["valid_border_margin"]
+
+ # Keep a config with no artifacts
+ homography_cfg_no_artifacts = copy.copy(homography_cfg["homographies"])
+ homography_cfg_no_artifacts["allow_artifacts"] = False
+
+ for idx in range(num_iter):
+ if idx <= num_iter // 5:
+ # Ensure that 20% of the homographies have no artifact
+ H_mat_lst = [sample_homography(
+ [H,W], **homography_cfg_no_artifacts)[0][None]
+ for _ in range(batch_size)]
+ else:
+ H_mat_lst = [sample_homography(
+ [H,W], **homography_cfg["homographies"])[0][None]
+ for _ in range(batch_size)]
+
+ H_mats = np.concatenate(H_mat_lst, axis=0)
+ H_tensor = torch.tensor(H_mats, dtype=torch.float, device=device)
+ H_inv_tensor = torch.inverse(H_tensor)
+
+ # Perform the homography warp
+ images_warped = warp_perspective(input_images, H_tensor, (H, W),
+ flags="bilinear")
+
+ # Warp the mask
+ masks_junc_warped = warp_perspective(
+ torch.ones([batch_size, 1, H, W], device=device),
+ H_tensor, (H, W), flags="nearest")
+ masks_heatmap_warped = warp_perspective(
+ torch.ones([batch_size, 1, H, W], device=device),
+ H_tensor, (H, W), flags="nearest")
+
+ # Run the network forward pass
+ with torch.no_grad():
+ outputs = model(images_warped)
+
+ # Unwarp and mask the junction prediction
+ junc_prob_warped = pixel_shuffle(softmax(
+ outputs["junctions"], dim=1)[:, :-1, :, :], grid_size)
+ junc_prob = warp_perspective(junc_prob_warped, H_inv_tensor,
+ (H, W), flags="bilinear")
+
+ # Create the out of boundary mask
+ out_boundary_mask = warp_perspective(
+ torch.ones([batch_size, 1, H, W], device=device),
+ H_inv_tensor, (H, W), flags="nearest")
+ out_boundary_mask = adjust_border(out_boundary_mask, device, margin)
+
+ junc_prob = junc_prob * out_boundary_mask
+ junc_count = warp_perspective(masks_junc_warped * out_boundary_mask,
+ H_inv_tensor, (H, W), flags="nearest")
+
+ # Unwarp the mask and heatmap prediction
+ # Always fetch only one channel
+ if outputs["heatmap"].shape[1] == 2:
+ # Convert to single channel directly from here
+ heatmap_prob_warped = softmax(outputs["heatmap"],
+ dim=1)[:, 1:, :, :]
+ else:
+ heatmap_prob_warped = torch.sigmoid(outputs["heatmap"])
+
+ heatmap_prob_warped = heatmap_prob_warped * masks_heatmap_warped
+ heatmap_prob = warp_perspective(heatmap_prob_warped, H_inv_tensor,
+ (H, W), flags="bilinear")
+ heatmap_count = warp_perspective(masks_heatmap_warped, H_inv_tensor,
+ (H, W), flags="nearest")
+
+ # Record the results
+ junc_probs[:, idx:idx+1, :, :] = junc_prob
+ heatmap_probs[:, idx:idx+1, :, :] = heatmap_prob
+ junc_counts += junc_count
+ heatmap_counts += heatmap_count
+
+ # Perform the accumulation operation
+ if homography_cfg["min_counts"] > 0:
+ min_counts = homography_cfg["min_counts"]
+ junc_count_mask = (junc_counts < min_counts)
+ heatmap_count_mask = (heatmap_counts < min_counts)
+ junc_counts[junc_count_mask] = 0
+ heatmap_counts[heatmap_count_mask] = 0
+ else:
+        junc_count_mask = torch.zeros_like(junc_counts, dtype=torch.bool)
+        heatmap_count_mask = torch.zeros_like(heatmap_counts, dtype=torch.bool)
+
+ # Compute the mean accumulation
+ junc_probs_mean = torch.sum(junc_probs, dim=1, keepdim=True) / junc_counts
+ junc_probs_mean[junc_count_mask] = 0.
+ heatmap_probs_mean = (torch.sum(heatmap_probs, dim=1, keepdim=True)
+ / heatmap_counts)
+ heatmap_probs_mean[heatmap_count_mask] = 0.
+
+ # Compute the max accumulation
+ junc_probs_max = torch.max(junc_probs, dim=1, keepdim=True)[0]
+ junc_probs_max[junc_count_mask] = 0.
+ heatmap_probs_max = torch.max(heatmap_probs, dim=1, keepdim=True)[0]
+ heatmap_probs_max[heatmap_count_mask] = 0.
+
+ return {"junc_probs_mean": junc_probs_mean,
+ "junc_probs_max": junc_probs_max,
+ "junc_counts": junc_counts,
+ "heatmap_probs_mean": heatmap_probs_mean,
+ "heatmap_probs_max": heatmap_probs_max,
+ "heatmap_counts": heatmap_counts}
+
+
+def adjust_border(input_masks, device, margin=3):
+ """ Adjust the border of the counts and valid_mask. """
+ # Convert the mask to numpy array
+ dtype = input_masks.dtype
+ input_masks = np.squeeze(input_masks.cpu().numpy(), axis=1)
+
+ erosion_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE,
+ (margin*2, margin*2))
+ batch_size = input_masks.shape[0]
+
+ output_mask_lst = []
+ # Erode all the masks
+ for i in range(batch_size):
+ output_mask = cv2.erode(input_masks[i, ...], erosion_kernel)
+
+ output_mask_lst.append(
+ torch.tensor(output_mask, dtype=dtype, device=device)[None])
+
+ # Concat back along the batch dimension.
+ output_masks = torch.cat(output_mask_lst, dim=0)
+ return output_masks.unsqueeze(dim=1)
diff --git a/imcui/third_party/SOLD2/sold2/export_line_features.py b/imcui/third_party/SOLD2/sold2/export_line_features.py
new file mode 100644
index 0000000000000000000000000000000000000000..4cbde860a446d758dff254ea5320ca13bb79e6b7
--- /dev/null
+++ b/imcui/third_party/SOLD2/sold2/export_line_features.py
@@ -0,0 +1,74 @@
+"""
+ Export line detections and descriptors given a list of input images.
+"""
+import os
+import argparse
+import cv2
+import numpy as np
+import torch
+from tqdm import tqdm
+
+from .experiment import load_config
+from .model.line_matcher import LineMatcher
+
+
+def export_descriptors(images_list, ckpt_path, config, device, extension,
+ output_folder, multiscale=False):
+ # Extract the image paths
+ with open(images_list, 'r') as f:
+ image_files = f.readlines()
+ image_files = [path.strip('\n') for path in image_files]
+
+ # Initialize the line matcher
+ line_matcher = LineMatcher(
+ config["model_cfg"], ckpt_path, device, config["line_detector_cfg"],
+ config["line_matcher_cfg"], multiscale)
+ print("\t Successfully initialized model")
+
+ # Run the inference on each image and write the output on disk
+ for img_path in tqdm(image_files):
+ img = cv2.imread(img_path, 0)
+ img = torch.tensor(img[None, None] / 255., dtype=torch.float,
+ device=device)
+
+ # Run the line detection and description
+ ref_detection = line_matcher.line_detection(img)
+ ref_line_seg = ref_detection["line_segments"]
+ ref_descriptors = ref_detection["descriptor"][0].cpu().numpy()
+
+ # Write the output on disk
+ img_name = os.path.splitext(os.path.basename(img_path))[0]
+ output_file = os.path.join(output_folder, img_name + extension)
+ np.savez_compressed(output_file, line_seg=ref_line_seg,
+ descriptors=ref_descriptors)
+
+
+if __name__ == "__main__":
+ # Parse input arguments
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--img_list", type=str, required=True,
+ help="List of input images in a text file.")
+ parser.add_argument("--output_folder", type=str, required=True,
+ help="Path to the output folder.")
+ parser.add_argument("--config", type=str,
+ default="config/export_line_features.yaml")
+ parser.add_argument("--checkpoint_path", type=str,
+ default="pretrained_models/sold2_wireframe.tar")
+ parser.add_argument("--multiscale", action="store_true", default=False)
+ parser.add_argument("--extension", type=str, default=None)
+ args = parser.parse_args()
+
+ # Get the device
+ if torch.cuda.is_available():
+ device = torch.device("cuda")
+ else:
+ device = torch.device("cpu")
+
+ # Get the model config, extension and checkpoint path
+ config = load_config(args.config)
+ ckpt_path = os.path.abspath(args.checkpoint_path)
+ extension = 'sold2' if args.extension is None else args.extension
+ extension = "." + extension
+
+ export_descriptors(args.img_list, ckpt_path, config, device, extension,
+ args.output_folder, args.multiscale)
diff --git a/imcui/third_party/SOLD2/sold2/misc/__init__.py b/imcui/third_party/SOLD2/sold2/misc/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/imcui/third_party/SOLD2/sold2/misc/geometry_utils.py b/imcui/third_party/SOLD2/sold2/misc/geometry_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..50f0478062cd19ebac812bff62b6c3a3d5f124c2
--- /dev/null
+++ b/imcui/third_party/SOLD2/sold2/misc/geometry_utils.py
@@ -0,0 +1,81 @@
+import numpy as np
+import torch
+
+
+### Point-related utils
+
+# Warp a list of points using a homography
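+# (e.g. warping [[2., 3.]] (hw order) with the identity homography returns [[2., 3.]])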
+def warp_points(points, homography):
+ # Convert to homogeneous and in xy format
+ new_points = np.concatenate([points[..., [1, 0]],
+ np.ones_like(points[..., :1])], axis=-1)
+ # Warp
+ new_points = (homography @ new_points.T).T
+ # Convert back to inhomogeneous and hw format
+ new_points = new_points[..., [1, 0]] / new_points[..., 2:]
+ return new_points
+
+
+# Mask out the points that are outside of img_size
+def mask_points(points, img_size):
+ mask = ((points[..., 0] >= 0)
+ & (points[..., 0] < img_size[0])
+ & (points[..., 1] >= 0)
+ & (points[..., 1] < img_size[1]))
+ return mask
+
+
+# Convert a tensor [N, 2] or batched tensor [B, N, 2] of N keypoints into
+# a grid in [-1, 1]² that can be used in torch.nn.functional.grid_sample
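+# (e.g. a keypoint at (0, 0) maps to the top-left grid corner (-1, -1))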
+def keypoints_to_grid(keypoints, img_size):
+ n_points = keypoints.size()[-2]
+ device = keypoints.device
+ grid_points = keypoints.float() * 2. / torch.tensor(
+ img_size, dtype=torch.float, device=device) - 1.
+ grid_points = grid_points[..., [1, 0]].view(-1, n_points, 1, 2)
+ return grid_points
+
+
+# Return a 2D matrix indicating the local neighborhood of each point
+# for a given threshold and two lists of corresponding keypoints
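+# (two matches count as neighbors if their keypoints lie within dist_thresh of
+#  each other in at least one of the two images; the output is a boolean
+#  (b*n) x (b*n) matrix restricted to the valid entries)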
+def get_dist_mask(kp0, kp1, valid_mask, dist_thresh):
+ b_size, n_points, _ = kp0.size()
+ dist_mask0 = torch.norm(kp0.unsqueeze(2) - kp0.unsqueeze(1), dim=-1)
+ dist_mask1 = torch.norm(kp1.unsqueeze(2) - kp1.unsqueeze(1), dim=-1)
+ dist_mask = torch.min(dist_mask0, dist_mask1)
+ dist_mask = dist_mask <= dist_thresh
+ dist_mask = dist_mask.repeat(1, 1, b_size).reshape(b_size * n_points,
+ b_size * n_points)
+ dist_mask = dist_mask[valid_mask, :][:, valid_mask]
+ return dist_mask
+
+
+### Line-related utils
+
+# Sample n points along lines of shape (num_lines, 2, 2)
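+# (e.g. a single line from (0, 0) to (4, 0) sampled with n=5 yields the points
+#  (0, 0), (1, 0), (2, 0), (3, 0), (4, 0), returned with shape (1, 5, 2))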
+def sample_line_points(lines, n):
+ line_points_x = np.linspace(lines[:, 0, 0], lines[:, 1, 0], n, axis=-1)
+ line_points_y = np.linspace(lines[:, 0, 1], lines[:, 1, 1], n, axis=-1)
+ line_points = np.stack([line_points_x, line_points_y], axis=2)
+ return line_points
+
+
+# Return a mask of the valid lines that are within a valid mask of an image
+def mask_lines(lines, valid_mask):
+ h, w = valid_mask.shape
+ int_lines = np.clip(np.round(lines).astype(int), 0, [h - 1, w - 1])
+    start_valid = valid_mask[int_lines[:, 0, 0], int_lines[:, 0, 1]]
+    end_valid = valid_mask[int_lines[:, 1, 0], int_lines[:, 1, 1]]
+    valid = start_valid & end_valid
+ return valid
+
+
+# Return a 2D matrix indicating for each pair of points
+# if they are on the same line or not
+def get_common_line_mask(line_indices, valid_mask):
+ b_size, n_points = line_indices.shape
+ common_mask = line_indices[:, :, None] == line_indices[:, None, :]
+ common_mask = common_mask.repeat(1, 1, b_size).reshape(b_size * n_points,
+ b_size * n_points)
+ common_mask = common_mask[valid_mask, :][:, valid_mask]
+ return common_mask
diff --git a/imcui/third_party/SOLD2/sold2/misc/train_utils.py b/imcui/third_party/SOLD2/sold2/misc/train_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..d5ada35eea660df1f78b9f20d9bf7ed726eaee2c
--- /dev/null
+++ b/imcui/third_party/SOLD2/sold2/misc/train_utils.py
@@ -0,0 +1,74 @@
+"""
+This file contains some useful functions for train / val.
+"""
+import os
+import numpy as np
+import torch
+
+
+#################
+## image utils ##
+#################
+def convert_image(input_tensor, axis):
+ """ Convert single channel images to 3-channel images. """
+ image_lst = [input_tensor for _ in range(3)]
+ outputs = np.concatenate(image_lst, axis)
+ return outputs
+
+
+######################
+## checkpoint utils ##
+######################
+def get_latest_checkpoint(checkpoint_root, checkpoint_name,
+ device=torch.device("cuda")):
+ """ Get the latest checkpoint or by filename. """
+ # Load specific checkpoint
+ if checkpoint_name is not None:
+ checkpoint = torch.load(
+ os.path.join(checkpoint_root, checkpoint_name),
+ map_location=device)
+ # Load the latest checkpoint
+ else:
+        latest_checkpoint = sorted(
+            [f for f in os.listdir(checkpoint_root) if f.endswith(".tar")])[-1]
+        checkpoint = torch.load(os.path.join(
+            checkpoint_root, latest_checkpoint), map_location=device)
+ return checkpoint
+
+
+def remove_old_checkpoints(checkpoint_root, max_ckpt=15):
+ """ Remove the outdated checkpoints. """
+ # Get sorted list of checkpoints
+ checkpoint_list = sorted(
+ [_ for _ in os.listdir(os.path.join(checkpoint_root))
+ if _.endswith(".tar")])
+
+ # Get the checkpoints to be removed
+ if len(checkpoint_list) > max_ckpt:
+ remove_list = checkpoint_list[:-max_ckpt]
+ for _ in remove_list:
+ full_name = os.path.join(checkpoint_root, _)
+ os.remove(full_name)
+ print("[Debug] Remove outdated checkpoint %s" % (full_name))
+
+
+def adapt_checkpoint(state_dict):
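+    # Strip the "module." prefix added by DataParallel so the weights load into
+    # a non-parallel model, e.g. "module.conv.weight" -> "conv.weight".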
+ new_state_dict = {}
+ for k, v in state_dict.items():
+ if k.startswith('module.'):
+ new_state_dict[k[7:]] = v
+ else:
+ new_state_dict[k] = v
+ return new_state_dict
+
+
+################
+## HDF5 utils ##
+################
+def parse_h5_data(h5_data):
+ """ Parse h5 dataset. """
+ output_data = {}
+ for key in h5_data.keys():
+ output_data[key] = np.array(h5_data[key])
+
+ return output_data
diff --git a/imcui/third_party/SOLD2/sold2/misc/visualize_util.py b/imcui/third_party/SOLD2/sold2/misc/visualize_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..4aa46877f79724221b7caa423de6916acdc021f8
--- /dev/null
+++ b/imcui/third_party/SOLD2/sold2/misc/visualize_util.py
@@ -0,0 +1,526 @@
+""" Organize some frequently used visualization functions. """
+import cv2
+import numpy as np
+import matplotlib
+import matplotlib.pyplot as plt
+import copy
+import seaborn as sns
+
+
+# Plot junctions onto the image (return a separate copy)
+def plot_junctions(input_image, junctions, junc_size=3, color=None):
+ """
+ input_image: can be 0~1 float or 0~255 uint8.
+ junctions: Nx2 or 2xN np array.
+ junc_size: the size of the plotted circles.
+ """
+ # Create image copy
+ image = copy.copy(input_image)
+ # Make sure the image is converted to 255 uint8
+ if image.dtype == np.uint8:
+ pass
+ # A float type image ranging from 0~1
+    elif image.dtype in [np.float32, np.float64] and image.max() <= 2.:
+        image = (image * 255.).astype(np.uint8)
+    # A float type image ranging from 0.~255.
+    elif image.dtype in [np.float32, np.float64] and image.mean() > 10.:
+ image = image.astype(np.uint8)
+ else:
+ raise ValueError("[Error] Unknown image data type. Expect 0~1 float or 0~255 uint8.")
+
+ # Check whether the image is single channel
+ if len(image.shape) == 2 or ((len(image.shape) == 3) and (image.shape[-1] == 1)):
+ # Squeeze to H*W first
+ image = image.squeeze()
+
+        # Stack to 3 channels
+ image = np.concatenate([image[..., None] for _ in range(3)], axis=-1)
+
+ # Junction dimensions should be N*2
+ if not len(junctions.shape) == 2:
+ raise ValueError("[Error] junctions should be 2-dim array.")
+
+ # Always convert to N*2
+ if junctions.shape[-1] != 2:
+ if junctions.shape[0] == 2:
+ junctions = junctions.T
+ else:
+ raise ValueError("[Error] At least one of the two dims should be 2.")
+
+ # Round and convert junctions to int (and check the boundary)
+ H, W = image.shape[:2]
+    junctions = np.round(junctions).astype(np.int32)
+ junctions[junctions < 0] = 0
+ junctions[junctions[:, 0] >= H, 0] = H-1 # (first dim) max bounded by H-1
+ junctions[junctions[:, 1] >= W, 1] = W-1 # (second dim) max bounded by W-1
+
+ # Iterate through all the junctions
+ num_junc = junctions.shape[0]
+ if color is None:
+ color = (0, 255., 0)
+ for idx in range(num_junc):
+ # Fetch one junction
+ junc = junctions[idx, :]
+ cv2.circle(image, tuple(np.flip(junc)), radius=junc_size,
+ color=color, thickness=3)
+
+ return image
+
+
+# Plot line segments given junctions and the line adjacency map
+def plot_line_segments(input_image, junctions, line_map, junc_size=3,
+ color=(0, 255., 0), line_width=1, plot_survived_junc=True):
+ """
+ input_image: can be 0~1 float or 0~255 uint8.
+ junctions: Nx2 or 2xN np array.
+ line_map: NxN np array
+ junc_size: the size of the plotted circles.
+ color: color of the line segments (can be string "random")
+ line_width: width of the drawn segments.
+ plot_survived_junc: whether we only plot the survived junctions.
+ """
+ # Create image copy
+ image = copy.copy(input_image)
+ # Make sure the image is converted to 255 uint8
+ if image.dtype == np.uint8:
+ pass
+ # A float type image ranging from 0~1
+    elif image.dtype in [np.float32, np.float64] and image.max() <= 2.:
+        image = (image * 255.).astype(np.uint8)
+    # A float type image ranging from 0.~255.
+    elif image.dtype in [np.float32, np.float64] and image.mean() > 10.:
+ image = image.astype(np.uint8)
+ else:
+ raise ValueError("[Error] Unknown image data type. Expect 0~1 float or 0~255 uint8.")
+
+ # Check whether the image is single channel
+ if len(image.shape) == 2 or ((len(image.shape) == 3) and (image.shape[-1] == 1)):
+ # Squeeze to H*W first
+ image = image.squeeze()
+
+        # Stack to 3 channels
+ image = np.concatenate([image[..., None] for _ in range(3)], axis=-1)
+
+ # Junction dimensions should be 2
+ if not len(junctions.shape) == 2:
+ raise ValueError("[Error] junctions should be 2-dim array.")
+
+ # Always convert to N*2
+ if junctions.shape[-1] != 2:
+ if junctions.shape[0] == 2:
+ junctions = junctions.T
+ else:
+ raise ValueError("[Error] At least one of the two dims should be 2.")
+
+ # line_map dimension should be 2
+ if not len(line_map.shape) == 2:
+ raise ValueError("[Error] line_map should be 2-dim array.")
+
+ # Color should be "random" or a list or tuple with length 3
+ if color != "random":
+ if not (isinstance(color, tuple) or isinstance(color, list)):
+ raise ValueError("[Error] color should have type list or tuple.")
+ else:
+ if len(color) != 3:
+ raise ValueError("[Error] color should be a list or tuple with length 3.")
+
+ # Make a copy of the line_map
+ line_map_tmp = copy.copy(line_map)
+
+ # Parse line_map back to segment pairs
+ segments = np.zeros([0, 4])
+ for idx in range(junctions.shape[0]):
+ # if no connectivity, just skip it
+ if line_map_tmp[idx, :].sum() == 0:
+ continue
+ # record the line segment
+ else:
+ for idx2 in np.where(line_map_tmp[idx, :] == 1)[0]:
+ p1 = np.flip(junctions[idx, :]) # Convert to xy format
+ p2 = np.flip(junctions[idx2, :]) # Convert to xy format
+ segments = np.concatenate((segments, np.array([p1[0], p1[1], p2[0], p2[1]])[None, ...]), axis=0)
+
+ # Update line_map
+ line_map_tmp[idx, idx2] = 0
+ line_map_tmp[idx2, idx] = 0
+
+ # Draw segment pairs
+ for idx in range(segments.shape[0]):
+        seg = np.round(segments[idx, :]).astype(np.int32)
+ # Decide the color
+ if color != "random":
+ color = tuple(color)
+ else:
+ color = tuple(np.random.rand(3,))
+ cv2.line(image, tuple(seg[:2]), tuple(seg[2:]), color=color, thickness=line_width)
+
+ # Also draw the junctions
+ if not plot_survived_junc:
+ num_junc = junctions.shape[0]
+ for idx in range(num_junc):
+ # Fetch one junction
+ junc = junctions[idx, :]
+ cv2.circle(image, tuple(np.flip(junc)), radius=junc_size,
+ color=(0, 255., 0), thickness=3)
+ # Only plot the junctions which are part of a line segment
+ else:
+ for idx in range(segments.shape[0]):
+            seg = np.round(segments[idx, :]).astype(np.int32)  # Already in HW format.
+ cv2.circle(image, tuple(seg[:2]), radius=junc_size,
+ color=(0, 255., 0), thickness=3)
+ cv2.circle(image, tuple(seg[2:]), radius=junc_size,
+ color=(0, 255., 0), thickness=3)
+
+ return image
+
+
+# Plot line segments given Nx4 or Nx2x2 line segments
+def plot_line_segments_from_segments(input_image, line_segments, junc_size=3,
+ color=(0, 255., 0), line_width=1):
+ # Create image copy
+ image = copy.copy(input_image)
+ # Make sure the image is converted to 255 uint8
+ if image.dtype == np.uint8:
+ pass
+ # A float type image ranging from 0~1
+    elif image.dtype in [np.float32, np.float64] and image.max() <= 2.:
+        image = (image * 255.).astype(np.uint8)
+    # A float type image ranging from 0.~255.
+    elif image.dtype in [np.float32, np.float64] and image.mean() > 10.:
+ image = image.astype(np.uint8)
+ else:
+ raise ValueError("[Error] Unknown image data type. Expect 0~1 float or 0~255 uint8.")
+
+ # Check whether the image is single channel
+ if len(image.shape) == 2 or ((len(image.shape) == 3) and (image.shape[-1] == 1)):
+ # Squeeze to H*W first
+ image = image.squeeze()
+
+        # Stack to 3 channels
+ image = np.concatenate([image[..., None] for _ in range(3)], axis=-1)
+
+    # Check whether line_segments are in (1) Nx4 or (2) Nx2x2 format.
+ H, W, _ = image.shape
+ # (1) Nx4 format
+ if len(line_segments.shape) == 2 and line_segments.shape[-1] == 4:
+ # Round to int32
+ line_segments = line_segments.astype(np.int32)
+
+ # Clip H dimension
+ line_segments[:, 0] = np.clip(line_segments[:, 0], a_min=0, a_max=H-1)
+ line_segments[:, 2] = np.clip(line_segments[:, 2], a_min=0, a_max=H-1)
+
+ # Clip W dimension
+ line_segments[:, 1] = np.clip(line_segments[:, 1], a_min=0, a_max=W-1)
+ line_segments[:, 3] = np.clip(line_segments[:, 3], a_min=0, a_max=W-1)
+
+ # Convert to Nx2x2 format
+ line_segments = np.concatenate(
+ [np.expand_dims(line_segments[:, :2], axis=1),
+ np.expand_dims(line_segments[:, 2:], axis=1)],
+ axis=1
+ )
+
+ # (2) Nx2x2 format
+ elif len(line_segments.shape) == 3 and line_segments.shape[-1] == 2:
+ # Round to int32
+ line_segments = line_segments.astype(np.int32)
+
+ # Clip H dimension
+ line_segments[:, :, 0] = np.clip(line_segments[:, :, 0], a_min=0, a_max=H-1)
+ line_segments[:, :, 1] = np.clip(line_segments[:, :, 1], a_min=0, a_max=W-1)
+
+ else:
+ raise ValueError("[Error] line_segments should be either Nx4 or Nx2x2 in HW format.")
+
+ # Draw segment pairs (all segments should be in HW format)
+ image = image.copy()
+ for idx in range(line_segments.shape[0]):
+ seg = np.round(line_segments[idx, :, :]).astype(np.int32)
+ # Decide the color
+ if color != "random":
+ color = tuple(color)
+ else:
+ color = tuple(np.random.rand(3,))
+ cv2.line(image, tuple(np.flip(seg[0, :])),
+ tuple(np.flip(seg[1, :])),
+ color=color, thickness=line_width)
+
+ # Also draw the junctions
+ cv2.circle(image, tuple(np.flip(seg[0, :])), radius=junc_size, color=(0, 255., 0), thickness=3)
+ cv2.circle(image, tuple(np.flip(seg[1, :])), radius=junc_size, color=(0, 255., 0), thickness=3)
+
+ return image
+
+
+# Additional functions to visualize multiple images at the same time,
+# e.g. for line matching
+def plot_images(imgs, titles=None, cmaps='gray', dpi=100, size=6, pad=.5):
+ """Plot a set of images horizontally.
+ Args:
+ imgs: a list of NumPy or PyTorch images, RGB (H, W, 3) or mono (H, W).
+ titles: a list of strings, as titles for each image.
+ cmaps: colormaps for monochrome images.
+ """
+ n = len(imgs)
+ if not isinstance(cmaps, (list, tuple)):
+ cmaps = [cmaps] * n
+ figsize = (size*n, size*3/4) if size is not None else None
+ fig, ax = plt.subplots(1, n, figsize=figsize, dpi=dpi)
+ if n == 1:
+ ax = [ax]
+ for i in range(n):
+ ax[i].imshow(imgs[i], cmap=plt.get_cmap(cmaps[i]))
+ ax[i].get_yaxis().set_ticks([])
+ ax[i].get_xaxis().set_ticks([])
+ ax[i].set_axis_off()
+ for spine in ax[i].spines.values(): # remove frame
+ spine.set_visible(False)
+ if titles:
+ ax[i].set_title(titles[i])
+ fig.tight_layout(pad=pad)
+
+
+def plot_keypoints(kpts, colors='lime', ps=4):
+ """Plot keypoints for existing images.
+ Args:
+ kpts: list of ndarrays of size (N, 2).
+        colors: string, or list of colors (one for each keypoint set).
+ ps: size of the keypoints as float.
+ """
+ if not isinstance(colors, list):
+ colors = [colors] * len(kpts)
+ axes = plt.gcf().axes
+ for a, k, c in zip(axes, kpts, colors):
+ a.scatter(k[:, 0], k[:, 1], c=c, s=ps, linewidths=0)
+
+
+def plot_matches(kpts0, kpts1, color=None, lw=1.5, ps=4, indices=(0, 1), a=1.):
+ """Plot matches for a pair of existing images.
+ Args:
+ kpts0, kpts1: corresponding keypoints of size (N, 2).
+ color: color of each match, string or RGB tuple. Random if not given.
+ lw: width of the lines.
+ ps: size of the end points (no endpoint if ps=0)
+ indices: indices of the images to draw the matches on.
+ a: alpha opacity of the match lines.
+ """
+ fig = plt.gcf()
+ ax = fig.axes
+ assert len(ax) > max(indices)
+ ax0, ax1 = ax[indices[0]], ax[indices[1]]
+ fig.canvas.draw()
+
+ assert len(kpts0) == len(kpts1)
+ if color is None:
+ color = matplotlib.cm.hsv(np.random.rand(len(kpts0))).tolist()
+ elif len(color) > 0 and not isinstance(color[0], (tuple, list)):
+ color = [color] * len(kpts0)
+
+ if lw > 0:
+ # transform the points into the figure coordinate system
+ transFigure = fig.transFigure.inverted()
+ fkpts0 = transFigure.transform(ax0.transData.transform(kpts0))
+ fkpts1 = transFigure.transform(ax1.transData.transform(kpts1))
+ fig.lines += [matplotlib.lines.Line2D(
+ (fkpts0[i, 0], fkpts1[i, 0]), (fkpts0[i, 1], fkpts1[i, 1]),
+ zorder=1, transform=fig.transFigure, c=color[i], linewidth=lw,
+ alpha=a)
+ for i in range(len(kpts0))]
+
+ # freeze the axes to prevent the transform to change
+ ax0.autoscale(enable=False)
+ ax1.autoscale(enable=False)
+
+ if ps > 0:
+ ax0.scatter(kpts0[:, 0], kpts0[:, 1], c=color, s=ps, zorder=2)
+ ax1.scatter(kpts1[:, 0], kpts1[:, 1], c=color, s=ps, zorder=2)
+
+
+def plot_lines(lines, line_colors='orange', point_colors='cyan',
+ ps=4, lw=2, indices=(0, 1)):
+ """Plot lines and endpoints for existing images.
+ Args:
+ lines: list of ndarrays of size (N, 2, 2).
+        line_colors: string, or list of colors (one for each set of lines).
+        point_colors: string, or list of colors for the endpoints.
+ ps: size of the keypoints as float pixels.
+ lw: line width as float pixels.
+ indices: indices of the images to draw the matches on.
+ """
+ if not isinstance(line_colors, list):
+ line_colors = [line_colors] * len(lines)
+ if not isinstance(point_colors, list):
+ point_colors = [point_colors] * len(lines)
+
+ fig = plt.gcf()
+ ax = fig.axes
+ assert len(ax) > max(indices)
+ axes = [ax[i] for i in indices]
+ fig.canvas.draw()
+
+ # Plot the lines and junctions
+ for a, l, lc, pc in zip(axes, lines, line_colors, point_colors):
+ for i in range(len(l)):
+ line = matplotlib.lines.Line2D((l[i, 0, 0], l[i, 1, 0]),
+ (l[i, 0, 1], l[i, 1, 1]),
+ zorder=1, c=lc, linewidth=lw)
+ a.add_line(line)
+ pts = l.reshape(-1, 2)
+ a.scatter(pts[:, 0], pts[:, 1],
+ c=pc, s=ps, linewidths=0, zorder=2)
+
+
+def plot_line_matches(kpts0, kpts1, color=None, lw=1.5, indices=(0, 1), a=1.):
+ """Plot matches for a pair of existing images, parametrized by their middle point.
+ Args:
+ kpts0, kpts1: corresponding middle points of the lines of size (N, 2).
+ color: color of each match, string or RGB tuple. Random if not given.
+ lw: width of the lines.
+ indices: indices of the images to draw the matches on.
+ a: alpha opacity of the match lines.
+ """
+ fig = plt.gcf()
+ ax = fig.axes
+ assert len(ax) > max(indices)
+ ax0, ax1 = ax[indices[0]], ax[indices[1]]
+ fig.canvas.draw()
+
+ assert len(kpts0) == len(kpts1)
+ if color is None:
+ color = matplotlib.cm.hsv(np.random.rand(len(kpts0))).tolist()
+ elif len(color) > 0 and not isinstance(color[0], (tuple, list)):
+ color = [color] * len(kpts0)
+
+ if lw > 0:
+ # transform the points into the figure coordinate system
+ transFigure = fig.transFigure.inverted()
+ fkpts0 = transFigure.transform(ax0.transData.transform(kpts0))
+ fkpts1 = transFigure.transform(ax1.transData.transform(kpts1))
+ fig.lines += [matplotlib.lines.Line2D(
+ (fkpts0[i, 0], fkpts1[i, 0]), (fkpts0[i, 1], fkpts1[i, 1]),
+ zorder=1, transform=fig.transFigure, c=color[i], linewidth=lw,
+ alpha=a)
+ for i in range(len(kpts0))]
+
+ # freeze the axes to prevent the transform to change
+ ax0.autoscale(enable=False)
+ ax1.autoscale(enable=False)
+
+
+def plot_color_line_matches(lines, correct_matches=None,
+ lw=2, indices=(0, 1)):
+ """Plot line matches for existing images with multiple colors.
+ Args:
+ lines: list of ndarrays of size (N, 2, 2).
+ correct_matches: bool array of size (N,) indicating correct matches.
+ lw: line width as float pixels.
+ indices: indices of the images to draw the matches on.
+ """
+ n_lines = len(lines[0])
+ colors = sns.color_palette('husl', n_colors=n_lines)
+ np.random.shuffle(colors)
+ alphas = np.ones(n_lines)
+ # If correct_matches is not None, display wrong matches with a low alpha
+ if correct_matches is not None:
+ alphas[~np.array(correct_matches)] = 0.2
+
+ fig = plt.gcf()
+ ax = fig.axes
+ assert len(ax) > max(indices)
+ axes = [ax[i] for i in indices]
+ fig.canvas.draw()
+
+ # Plot the lines
+ for a, l in zip(axes, lines):
+ # Transform the points into the figure coordinate system
+ transFigure = fig.transFigure.inverted()
+ endpoint0 = transFigure.transform(a.transData.transform(l[:, 0]))
+ endpoint1 = transFigure.transform(a.transData.transform(l[:, 1]))
+ fig.lines += [matplotlib.lines.Line2D(
+ (endpoint0[i, 0], endpoint1[i, 0]),
+ (endpoint0[i, 1], endpoint1[i, 1]),
+ zorder=1, transform=fig.transFigure, c=colors[i],
+ alpha=alphas[i], linewidth=lw) for i in range(n_lines)]
+
+
+def plot_color_lines(lines, correct_matches, wrong_matches,
+ lw=2, indices=(0, 1)):
+ """Plot line matches for existing images with multiple colors:
+ green for correct matches, red for wrong ones, and blue for the rest.
+ Args:
+ lines: list of ndarrays of size (N, 2, 2).
+ correct_matches: list of bool arrays of size N with correct matches.
+        wrong_matches: list of bool arrays of size (N,) with wrong matches.
+ lw: line width as float pixels.
+ indices: indices of the images to draw the matches on.
+ """
+ # palette = sns.color_palette()
+ palette = sns.color_palette("hls", 8)
+ blue = palette[5] # palette[0]
+ red = palette[0] # palette[3]
+ green = palette[2] # palette[2]
+ colors = [np.array([blue] * len(l)) for l in lines]
+ for i, c in enumerate(colors):
+ c[np.array(correct_matches[i])] = green
+ c[np.array(wrong_matches[i])] = red
+
+ fig = plt.gcf()
+ ax = fig.axes
+ assert len(ax) > max(indices)
+ axes = [ax[i] for i in indices]
+ fig.canvas.draw()
+
+ # Plot the lines
+ for a, l, c in zip(axes, lines, colors):
+ # Transform the points into the figure coordinate system
+ transFigure = fig.transFigure.inverted()
+ endpoint0 = transFigure.transform(a.transData.transform(l[:, 0]))
+ endpoint1 = transFigure.transform(a.transData.transform(l[:, 1]))
+ fig.lines += [matplotlib.lines.Line2D(
+ (endpoint0[i, 0], endpoint1[i, 0]),
+ (endpoint0[i, 1], endpoint1[i, 1]),
+ zorder=1, transform=fig.transFigure, c=c[i],
+ linewidth=lw) for i in range(len(l))]
+
+
+def plot_subsegment_matches(lines, subsegments, lw=2, indices=(0, 1)):
+ """ Plot line matches for existing images with multiple colors and
+ highlight the actually matched subsegments.
+ Args:
+ lines: list of ndarrays of size (N, 2, 2).
+ subsegments: list of ndarrays of size (N, 2, 2).
+ lw: line width as float pixels.
+ indices: indices of the images to draw the matches on.
+ """
+ n_lines = len(lines[0])
+ colors = sns.cubehelix_palette(start=2, rot=-0.2, dark=0.3, light=.7,
+ gamma=1.3, hue=1, n_colors=n_lines)
+
+ fig = plt.gcf()
+ ax = fig.axes
+ assert len(ax) > max(indices)
+ axes = [ax[i] for i in indices]
+ fig.canvas.draw()
+
+ # Plot the lines
+ for a, l, ss in zip(axes, lines, subsegments):
+ # Transform the points into the figure coordinate system
+ transFigure = fig.transFigure.inverted()
+
+ # Draw full line
+ endpoint0 = transFigure.transform(a.transData.transform(l[:, 0]))
+ endpoint1 = transFigure.transform(a.transData.transform(l[:, 1]))
+ fig.lines += [matplotlib.lines.Line2D(
+ (endpoint0[i, 0], endpoint1[i, 0]),
+ (endpoint0[i, 1], endpoint1[i, 1]),
+ zorder=1, transform=fig.transFigure, c='red',
+ alpha=0.7, linewidth=lw) for i in range(n_lines)]
+
+ # Draw matched subsegment
+ endpoint0 = transFigure.transform(a.transData.transform(ss[:, 0]))
+ endpoint1 = transFigure.transform(a.transData.transform(ss[:, 1]))
+ fig.lines += [matplotlib.lines.Line2D(
+ (endpoint0[i, 0], endpoint1[i, 0]),
+ (endpoint0[i, 1], endpoint1[i, 1]),
+ zorder=1, transform=fig.transFigure, c=colors[i],
+ alpha=1, linewidth=lw) for i in range(n_lines)]
\ No newline at end of file
diff --git a/imcui/third_party/SOLD2/sold2/model/__init__.py b/imcui/third_party/SOLD2/sold2/model/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/imcui/third_party/SOLD2/sold2/model/line_detection.py b/imcui/third_party/SOLD2/sold2/model/line_detection.py
new file mode 100644
index 0000000000000000000000000000000000000000..d0d1928515a8494833a8ef6509008f4299cd74c4
--- /dev/null
+++ b/imcui/third_party/SOLD2/sold2/model/line_detection.py
@@ -0,0 +1,506 @@
+"""
+Implementation of the line segment detection module.
+"""
+import math
+import numpy as np
+import torch
+
+
+class LineSegmentDetectionModule(object):
+ """ Module extracting line segments from junctions and line heatmaps. """
+ def __init__(
+ self, detect_thresh, num_samples=64, sampling_method="local_max",
+ inlier_thresh=0., heatmap_low_thresh=0.15, heatmap_high_thresh=0.2,
+ max_local_patch_radius=3, lambda_radius=2.,
+ use_candidate_suppression=False, nms_dist_tolerance=3.,
+ use_heatmap_refinement=False, heatmap_refine_cfg=None,
+ use_junction_refinement=False, junction_refine_cfg=None):
+ """
+ Parameters:
+ detect_thresh: The probability threshold for mean activation (0. ~ 1.)
+ num_samples: Number of sampling locations along the line segments.
+ sampling_method: Sampling method on locations ("bilinear" or "local_max").
+ inlier_thresh: The min inlier ratio to satisfy (0. ~ 1.) => 0. means no threshold.
+ heatmap_low_thresh: The lowest threshold for the pixel to be considered as candidate in junction recovery.
+ heatmap_high_thresh: The higher threshold for NMS in junction recovery.
+ max_local_patch_radius: The max patch to be considered in local maximum search.
+ lambda_radius: The lambda factor in linear local maximum search formulation
+ use_candidate_suppression: Apply candidate suppression to break long segments into short sub-segments.
+ nms_dist_tolerance: The distance tolerance for nms. Decide whether the junctions are on the line.
+ use_heatmap_refinement: Use heatmap refinement method or not.
+ heatmap_refine_cfg: The configs for heatmap refinement methods.
+ use_junction_refinement: Use junction refinement method or not.
+ junction_refine_cfg: The configs for junction refinement methods.
+ """
+ # Line detection parameters
+ self.detect_thresh = detect_thresh
+
+ # Line sampling parameters
+ self.num_samples = num_samples
+ self.sampling_method = sampling_method
+ self.inlier_thresh = inlier_thresh
+ self.local_patch_radius = max_local_patch_radius
+ self.lambda_radius = lambda_radius
+
+ # Detecting junctions on the boundary parameters
+ self.low_thresh = heatmap_low_thresh
+ self.high_thresh = heatmap_high_thresh
+
+ # Pre-compute the linspace sampler
+ self.sampler = np.linspace(0, 1, self.num_samples)
+ self.torch_sampler = torch.linspace(0, 1, self.num_samples)
+
+ # Long line segment suppression configuration
+ self.use_candidate_suppression = use_candidate_suppression
+ self.nms_dist_tolerance = nms_dist_tolerance
+
+ # Heatmap refinement configuration
+ self.use_heatmap_refinement = use_heatmap_refinement
+ self.heatmap_refine_cfg = heatmap_refine_cfg
+ if self.use_heatmap_refinement and self.heatmap_refine_cfg is None:
+ raise ValueError("[Error] Missing heatmap refinement config.")
+
+ # Junction refinement configuration
+ self.use_junction_refinement = use_junction_refinement
+ self.junction_refine_cfg = junction_refine_cfg
+ if self.use_junction_refinement and self.junction_refine_cfg is None:
+ raise ValueError("[Error] Missing junction refinement config.")
+
+ def convert_inputs(self, inputs, device):
+ """ Convert inputs to desired torch tensor. """
+ if isinstance(inputs, np.ndarray):
+ outputs = torch.tensor(inputs, dtype=torch.float32, device=device)
+ elif isinstance(inputs, torch.Tensor):
+ outputs = inputs.to(torch.float32).to(device)
+ else:
+ raise ValueError(
+ "[Error] Inputs must either be torch tensor or numpy ndarray.")
+
+ return outputs
+
+ def detect(self, junctions, heatmap, device=torch.device("cpu")):
+ """ Main function performing line segment detection. """
+ # Convert inputs to torch tensor
+ junctions = self.convert_inputs(junctions, device=device)
+ heatmap = self.convert_inputs(heatmap, device=device)
+
+ # Perform the heatmap refinement
+ if self.use_heatmap_refinement:
+ if self.heatmap_refine_cfg["mode"] == "global":
+ heatmap = self.refine_heatmap(
+ heatmap,
+ self.heatmap_refine_cfg["ratio"],
+ self.heatmap_refine_cfg["valid_thresh"]
+ )
+ elif self.heatmap_refine_cfg["mode"] == "local":
+ heatmap = self.refine_heatmap_local(
+ heatmap,
+ self.heatmap_refine_cfg["num_blocks"],
+ self.heatmap_refine_cfg["overlap_ratio"],
+ self.heatmap_refine_cfg["ratio"],
+ self.heatmap_refine_cfg["valid_thresh"]
+ )
+
+ # Initialize empty line map
+ num_junctions = junctions.shape[0]
+ line_map_pred = torch.zeros([num_junctions, num_junctions],
+ device=device, dtype=torch.int32)
+
+ # Stop if there are not enough junctions
+ if num_junctions < 2:
+ return line_map_pred, junctions, heatmap
+
+ # Generate the candidate map
+ candidate_map = torch.triu(torch.ones(
+ [num_junctions, num_junctions], device=device, dtype=torch.int32),
+ diagonal=1)
+
+ # Fetch the image boundary
+ if len(heatmap.shape) > 2:
+ H, W, _ = heatmap.shape
+ else:
+ H, W = heatmap.shape
+
+ # Optionally perform candidate filtering
+ if self.use_candidate_suppression:
+ candidate_map = self.candidate_suppression(junctions,
+ candidate_map)
+
+ # Fetch the candidates
+ candidate_index_map = torch.where(candidate_map)
+ candidate_index_map = torch.cat([candidate_index_map[0][..., None],
+ candidate_index_map[1][..., None]],
+ dim=-1)
+
+ # Get the corresponding start and end junctions
+ candidate_junc_start = junctions[candidate_index_map[:, 0], :]
+ candidate_junc_end = junctions[candidate_index_map[:, 1], :]
+
+ # Get the sampling locations (N x 64)
+ sampler = self.torch_sampler.to(device)[None, ...]
+ cand_samples_h = candidate_junc_start[:, 0:1] * sampler + \
+ candidate_junc_end[:, 0:1] * (1 - sampler)
+ cand_samples_w = candidate_junc_start[:, 1:2] * sampler + \
+ candidate_junc_end[:, 1:2] * (1 - sampler)
+
+ # Clip to image boundary
+ cand_h = torch.clamp(cand_samples_h, min=0, max=H-1)
+ cand_w = torch.clamp(cand_samples_w, min=0, max=W-1)
+
+ # Local maximum search
+ if self.sampling_method == "local_max":
+ # Compute normalized segment lengths
+ segments_length = torch.sqrt(torch.sum(
+ (candidate_junc_start.to(torch.float32) -
+ candidate_junc_end.to(torch.float32)) ** 2, dim=-1))
+ normalized_seg_length = (segments_length
+ / (((H ** 2) + (W ** 2)) ** 0.5))
+
+ # Perform local max search
+ num_cand = cand_h.shape[0]
+ group_size = 10000
+ if num_cand > group_size:
+ num_iter = math.ceil(num_cand / group_size)
+ sampled_feat_lst = []
+ for iter_idx in range(num_iter):
+                    if iter_idx != num_iter - 1:
+ cand_h_ = cand_h[iter_idx * group_size:
+ (iter_idx+1) * group_size, :]
+ cand_w_ = cand_w[iter_idx * group_size:
+ (iter_idx+1) * group_size, :]
+ normalized_seg_length_ = normalized_seg_length[
+ iter_idx * group_size: (iter_idx+1) * group_size]
+ else:
+ cand_h_ = cand_h[iter_idx * group_size:, :]
+ cand_w_ = cand_w[iter_idx * group_size:, :]
+ normalized_seg_length_ = normalized_seg_length[
+ iter_idx * group_size:]
+ sampled_feat_ = self.detect_local_max(
+ heatmap, cand_h_, cand_w_, H, W,
+ normalized_seg_length_, device)
+ sampled_feat_lst.append(sampled_feat_)
+ sampled_feat = torch.cat(sampled_feat_lst, dim=0)
+ else:
+ sampled_feat = self.detect_local_max(
+ heatmap, cand_h, cand_w, H, W,
+ normalized_seg_length, device)
+ # Bilinear sampling
+ elif self.sampling_method == "bilinear":
+ # Perform bilinear sampling
+ sampled_feat = self.detect_bilinear(
+ heatmap, cand_h, cand_w, H, W, device)
+ else:
+ raise ValueError("[Error] Unknown sampling method.")
+
+ # [Simple threshold detection]
+ # detection_results is a mask over all candidates
+ detection_results = (torch.mean(sampled_feat, dim=-1)
+ > self.detect_thresh)
+
+ # [Inlier threshold detection]
+ if self.inlier_thresh > 0.:
+ inlier_ratio = torch.sum(
+ sampled_feat > self.detect_thresh,
+ dim=-1).to(torch.float32) / self.num_samples
+ detection_results_inlier = inlier_ratio >= self.inlier_thresh
+ detection_results = detection_results * detection_results_inlier
+
+ # Convert detection results back to line_map_pred
+ detected_junc_indexes = candidate_index_map[detection_results, :]
+ line_map_pred[detected_junc_indexes[:, 0],
+ detected_junc_indexes[:, 1]] = 1
+ line_map_pred[detected_junc_indexes[:, 1],
+ detected_junc_indexes[:, 0]] = 1
+
+ # Perform junction refinement
+ if self.use_junction_refinement and len(detected_junc_indexes) > 0:
+ junctions, line_map_pred = self.refine_junction_perturb(
+ junctions, line_map_pred, heatmap, H, W, device)
+
+ return line_map_pred, junctions, heatmap
+
+ def refine_heatmap(self, heatmap, ratio=0.2, valid_thresh=1e-2):
+ """ Global heatmap refinement method. """
+        # Normalize by the mean of the top `ratio` fraction of valid values
+        heatmap_values = heatmap[heatmap > valid_thresh]
+        sorted_values = torch.sort(heatmap_values, descending=True)[0]
+        top_length = math.ceil(sorted_values.shape[0] * ratio)
+        top_mean = torch.mean(sorted_values[:top_length])
+        heatmap = torch.clamp(heatmap / top_mean, min=0., max=1.)
+ return heatmap
+
+ def refine_heatmap_local(self, heatmap, num_blocks=5, overlap_ratio=0.5,
+ ratio=0.2, valid_thresh=2e-3):
+ """ Local heatmap refinement method. """
+ # Get the shape of the heatmap
+ H, W = heatmap.shape
+ increase_ratio = 1 - overlap_ratio
+ h_block = round(H / (1 + (num_blocks - 1) * increase_ratio))
+ w_block = round(W / (1 + (num_blocks - 1) * increase_ratio))
+
+ count_map = torch.zeros(heatmap.shape, dtype=torch.float,
+ device=heatmap.device)
+ heatmap_output = torch.zeros(heatmap.shape, dtype=torch.float,
+ device=heatmap.device)
+ # Iterate through each block
+ for h_idx in range(num_blocks):
+ for w_idx in range(num_blocks):
+ # Fetch the heatmap
+ h_start = round(h_idx * h_block * increase_ratio)
+ w_start = round(w_idx * w_block * increase_ratio)
+ h_end = h_start + h_block if h_idx < num_blocks - 1 else H
+ w_end = w_start + w_block if w_idx < num_blocks - 1 else W
+
+ subheatmap = heatmap[h_start:h_end, w_start:w_end]
+ if subheatmap.max() > valid_thresh:
+ subheatmap = self.refine_heatmap(
+ subheatmap, ratio, valid_thresh=valid_thresh)
+
+ # Aggregate it to the final heatmap
+ heatmap_output[h_start:h_end, w_start:w_end] += subheatmap
+ count_map[h_start:h_end, w_start:w_end] += 1
+ heatmap_output = torch.clamp(heatmap_output / count_map,
+ max=1., min=0.)
+
+ return heatmap_output
+
+ def candidate_suppression(self, junctions, candidate_map):
+ """ Suppress overlapping long lines in the candidate segments. """
+ # Define the distance tolerance
+ dist_tolerance = self.nms_dist_tolerance
+
+ # Compute distance between junction pairs
+ # (num_junc x 1 x 2) - (1 x num_junc x 2) => num_junc x num_junc map
+ line_dist_map = torch.sum((torch.unsqueeze(junctions, dim=1)
+ - junctions[None, ...]) ** 2, dim=-1) ** 0.5
+
+ # Fetch all the "detected lines"
+ seg_indexes = torch.where(torch.triu(candidate_map, diagonal=1))
+ start_point_idxs = seg_indexes[0]
+ end_point_idxs = seg_indexes[1]
+ start_points = junctions[start_point_idxs, :]
+ end_points = junctions[end_point_idxs, :]
+
+ # Fetch corresponding entries
+ line_dists = line_dist_map[start_point_idxs, end_point_idxs]
+
+ # Check whether they are on the line
+ dir_vecs = ((end_points - start_points)
+ / torch.norm(end_points - start_points,
+ dim=-1)[..., None])
+ # Get the orthogonal distance
+ cand_vecs = junctions[None, ...] - start_points.unsqueeze(dim=1)
+ cand_vecs_norm = torch.norm(cand_vecs, dim=-1)
+ # Check whether they are projected directly onto the segment
+ proj = (torch.einsum('bij,bjk->bik', cand_vecs, dir_vecs[..., None])
+ / line_dists[..., None, None])
+ # proj is num_segs x num_junction x 1
+        proj_mask = (proj >= 0) * (proj <= 1)
+ cand_angles = torch.acos(
+ torch.einsum('bij,bjk->bik', cand_vecs, dir_vecs[..., None])
+ / cand_vecs_norm[..., None])
+ cand_dists = cand_vecs_norm[..., None] * torch.sin(cand_angles)
+ junc_dist_mask = cand_dists <= dist_tolerance
+ junc_mask = junc_dist_mask * proj_mask
+
+ # Minus starting points
+ num_segs = start_point_idxs.shape[0]
+ junc_counts = torch.sum(junc_mask, dim=[1, 2])
+ junc_counts -= junc_mask[..., 0][torch.arange(0, num_segs),
+ start_point_idxs].to(torch.int)
+ junc_counts -= junc_mask[..., 0][torch.arange(0, num_segs),
+ end_point_idxs].to(torch.int)
+
+ # Get the invalid candidate mask
+ final_mask = junc_counts > 0
+ candidate_map[start_point_idxs[final_mask],
+ end_point_idxs[final_mask]] = 0
+
+ return candidate_map
+
+ def refine_junction_perturb(self, junctions, line_map_pred,
+ heatmap, H, W, device):
+ """ Refine the line endpoints in a similar way as in LSD. """
+ # Get the config
+ junction_refine_cfg = self.junction_refine_cfg
+
+ # Fetch refinement parameters
+ num_perturbs = junction_refine_cfg["num_perturbs"]
+ perturb_interval = junction_refine_cfg["perturb_interval"]
+ side_perturbs = (num_perturbs - 1) // 2
+ # Fetch the 2D perturb mat
+ perturb_vec = torch.arange(
+ start=-perturb_interval*side_perturbs,
+ end=perturb_interval*(side_perturbs+1),
+ step=perturb_interval, device=device)
+ w1_grid, h1_grid, w2_grid, h2_grid = torch.meshgrid(
+ perturb_vec, perturb_vec, perturb_vec, perturb_vec)
+ perturb_tensor = torch.cat([
+ w1_grid[..., None], h1_grid[..., None],
+ w2_grid[..., None], h2_grid[..., None]], dim=-1)
+ perturb_tensor_flat = perturb_tensor.view(-1, 2, 2)
+
+ # Fetch the junctions and line_map
+ junctions = junctions.clone()
+ line_map = line_map_pred
+
+ # Fetch all the detected lines
+ detected_seg_indexes = torch.where(torch.triu(line_map, diagonal=1))
+ start_point_idxs = detected_seg_indexes[0]
+ end_point_idxs = detected_seg_indexes[1]
+ start_points = junctions[start_point_idxs, :]
+ end_points = junctions[end_point_idxs, :]
+
+ line_segments = torch.cat([start_points.unsqueeze(dim=1),
+ end_points.unsqueeze(dim=1)], dim=1)
+
+ line_segment_candidates = (line_segments.unsqueeze(dim=1)
+ + perturb_tensor_flat[None, ...])
+ # Clip the boundaries
+ line_segment_candidates[..., 0] = torch.clamp(
+ line_segment_candidates[..., 0], min=0, max=H - 1)
+ line_segment_candidates[..., 1] = torch.clamp(
+ line_segment_candidates[..., 1], min=0, max=W - 1)
+
+ # Iterate through all the segments
+ refined_segment_lst = []
+ num_segments = line_segments.shape[0]
+ for idx in range(num_segments):
+ segment = line_segment_candidates[idx, ...]
+ # Get the corresponding start and end junctions
+ candidate_junc_start = segment[:, 0, :]
+ candidate_junc_end = segment[:, 1, :]
+
+ # Get the sampling locations (N x 64)
+ sampler = self.torch_sampler.to(device)[None, ...]
+ cand_samples_h = (candidate_junc_start[:, 0:1] * sampler +
+ candidate_junc_end[:, 0:1] * (1 - sampler))
+ cand_samples_w = (candidate_junc_start[:, 1:2] * sampler +
+ candidate_junc_end[:, 1:2] * (1 - sampler))
+
+ # Clip to image boundary
+ cand_h = torch.clamp(cand_samples_h, min=0, max=H - 1)
+ cand_w = torch.clamp(cand_samples_w, min=0, max=W - 1)
+
+ # Perform bilinear sampling
+ segment_feat = self.detect_bilinear(
+ heatmap, cand_h, cand_w, H, W, device)
+ segment_results = torch.mean(segment_feat, dim=-1)
+ max_idx = torch.argmax(segment_results)
+ refined_segment_lst.append(segment[max_idx, ...][None, ...])
+
+ # Concatenate back to segments
+ refined_segments = torch.cat(refined_segment_lst, dim=0)
+
+ # Convert back to junctions and line_map
+ junctions_new = torch.cat(
+ [refined_segments[:, 0, :], refined_segments[:, 1, :]], dim=0)
+ junctions_new = torch.unique(junctions_new, dim=0)
+ line_map_new = self.segments_to_line_map(junctions_new,
+ refined_segments)
+
+ return junctions_new, line_map_new
+
+ def segments_to_line_map(self, junctions, segments):
+ """ Convert the list of segments to line map. """
+ # Create empty line map
+ device = junctions.device
+ num_junctions = junctions.shape[0]
+ line_map = torch.zeros([num_junctions, num_junctions], device=device)
+
+ # Iterate through every segment
+ for idx in range(segments.shape[0]):
+            # Get the junctions from a single segment
+ seg = segments[idx, ...]
+ junction1 = seg[0, :]
+ junction2 = seg[1, :]
+
+ # Get index
+ idx_junction1 = torch.where(
+ (junctions == junction1).sum(axis=1) == 2)[0]
+ idx_junction2 = torch.where(
+ (junctions == junction2).sum(axis=1) == 2)[0]
+
+ # label the corresponding entries
+ line_map[idx_junction1, idx_junction2] = 1
+ line_map[idx_junction2, idx_junction1] = 1
+
+ return line_map
+
+ def detect_bilinear(self, heatmap, cand_h, cand_w, H, W, device):
+ """ Detection by bilinear sampling. """
+ # Get the floor and ceiling locations
+ cand_h_floor = torch.floor(cand_h).to(torch.long)
+ cand_h_ceil = torch.ceil(cand_h).to(torch.long)
+ cand_w_floor = torch.floor(cand_w).to(torch.long)
+ cand_w_ceil = torch.ceil(cand_w).to(torch.long)
+
+ # Perform the bilinear sampling
+ cand_samples_feat = (
+ heatmap[cand_h_floor, cand_w_floor] * (cand_h_ceil - cand_h)
+ * (cand_w_ceil - cand_w) + heatmap[cand_h_floor, cand_w_ceil]
+ * (cand_h_ceil - cand_h) * (cand_w - cand_w_floor) +
+ heatmap[cand_h_ceil, cand_w_floor] * (cand_h - cand_h_floor)
+ * (cand_w_ceil - cand_w) + heatmap[cand_h_ceil, cand_w_ceil]
+ * (cand_h - cand_h_floor) * (cand_w - cand_w_floor))
+
+ return cand_samples_feat
+
+ def detect_local_max(self, heatmap, cand_h, cand_w, H, W,
+ normalized_seg_length, device):
+ """ Detection by local maximum search. """
+ # Compute the distance threshold
+ dist_thresh = (0.5 * (2 ** 0.5)
+ + self.lambda_radius * normalized_seg_length)
+ # Make it N x 64
+ dist_thresh = torch.repeat_interleave(dist_thresh[..., None],
+ self.num_samples, dim=-1)
+
+ # Compute the candidate points
+ cand_points = torch.cat([cand_h[..., None], cand_w[..., None]],
+ dim=-1)
+ cand_points_round = torch.round(cand_points) # N x 64 x 2
+
+ # Construct local patches 9x9 = 81
+ patch_mask = torch.zeros([int(2 * self.local_patch_radius + 1),
+ int(2 * self.local_patch_radius + 1)],
+ device=device)
+ patch_center = torch.tensor(
+ [[self.local_patch_radius, self.local_patch_radius]],
+ device=device, dtype=torch.float32)
+ H_patch_points, W_patch_points = torch.where(patch_mask >= 0)
+ patch_points = torch.cat([H_patch_points[..., None],
+ W_patch_points[..., None]], dim=-1)
+ # Fetch the circle region
+ patch_center_dist = torch.sqrt(torch.sum(
+ (patch_points - patch_center) ** 2, dim=-1))
+ patch_points = (patch_points[patch_center_dist
+ <= self.local_patch_radius, :])
+ # Shift [0, 0] to the center
+ patch_points = patch_points - self.local_patch_radius
+
+ # Construct local patch mask
+ patch_points_shifted = (torch.unsqueeze(cand_points_round, dim=2)
+ + patch_points[None, None, ...])
+ patch_dist = torch.sqrt(torch.sum((torch.unsqueeze(cand_points, dim=2)
+ - patch_points_shifted) ** 2,
+ dim=-1))
+ patch_dist_mask = patch_dist < dist_thresh[..., None]
+
+ # Get all points => num_points_center x num_patch_points x 2
+ points_H = torch.clamp(patch_points_shifted[:, :, :, 0], min=0,
+ max=H - 1).to(torch.long)
+ points_W = torch.clamp(patch_points_shifted[:, :, :, 1], min=0,
+ max=W - 1).to(torch.long)
+ points = torch.cat([points_H[..., None], points_W[..., None]], dim=-1)
+
+ # Sample the feature (N x 64 x 81)
+ sampled_feat = heatmap[points[:, :, :, 0], points[:, :, :, 1]]
+ # Filtering using the valid mask
+ sampled_feat = sampled_feat * patch_dist_mask.to(torch.float32)
+ if len(sampled_feat) == 0:
+ sampled_feat_lmax = torch.empty(0, 64)
+ else:
+ sampled_feat_lmax, _ = torch.max(sampled_feat, dim=-1)
+
+ return sampled_feat_lmax
diff --git a/imcui/third_party/SOLD2/sold2/model/line_detector.py b/imcui/third_party/SOLD2/sold2/model/line_detector.py
new file mode 100644
index 0000000000000000000000000000000000000000..2f3d059e130178c482e8e569171ef9e0370424c7
--- /dev/null
+++ b/imcui/third_party/SOLD2/sold2/model/line_detector.py
@@ -0,0 +1,127 @@
+"""
+Line segment detection from raw images.
+"""
+import time
+import numpy as np
+import torch
+from torch.nn.functional import softmax
+
+from .model_util import get_model
+from .loss import get_loss_and_weights
+from .line_detection import LineSegmentDetectionModule
+from ..train import convert_junc_predictions
+from ..misc.train_utils import adapt_checkpoint
+
+
+def line_map_to_segments(junctions, line_map):
+ """ Convert a line map to a Nx2x2 list of segments. """
+ line_map_tmp = line_map.copy()
+
+ output_segments = np.zeros([0, 2, 2])
+ for idx in range(junctions.shape[0]):
+ # if no connectivity, just skip it
+ if line_map_tmp[idx, :].sum() == 0:
+ continue
+ # Record the line segment
+ else:
+ for idx2 in np.where(line_map_tmp[idx, :] == 1)[0]:
+ p1 = junctions[idx, :] # HW format
+ p2 = junctions[idx2, :]
+ single_seg = np.concatenate([p1[None, ...], p2[None, ...]],
+ axis=0)
+ output_segments = np.concatenate(
+ (output_segments, single_seg[None, ...]), axis=0)
+
+ # Update line_map
+ line_map_tmp[idx, idx2] = 0
+ line_map_tmp[idx2, idx] = 0
+
+ return output_segments
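+
+# Illustrative usage sketch (not part of the original SOLD2 code): with three
+# junctions in HW format and an adjacency map connecting 0-1 and 1-2,
+# junctions = np.array([[0, 0], [0, 10], [5, 10]])
+# line_map = np.array([[0, 1, 0], [1, 0, 1], [0, 1, 0]])
+# line_map_to_segments(junctions, line_map) returns a (2, 2, 2) array holding
+# the two segments [[0, 0], [0, 10]] and [[0, 10], [5, 10]].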
+
+
+class LineDetector(object):
+ def __init__(self, model_cfg, ckpt_path, device, line_detector_cfg,
+ junc_detect_thresh=None):
+ """ SOLD² line detector taking raw images as input.
+ Parameters:
+ model_cfg: config for CNN model
+ ckpt_path: path to the weights
+ line_detector_cfg: config file for the line detection module
+ """
+ # Get loss weights if dynamic weighting
+ _, loss_weights = get_loss_and_weights(model_cfg, device)
+ self.device = device
+
+ # Initialize the cnn backbone
+ self.model = get_model(model_cfg, loss_weights)
+ checkpoint = torch.load(ckpt_path, map_location=self.device)
+ checkpoint = adapt_checkpoint(checkpoint["model_state_dict"])
+ self.model.load_state_dict(checkpoint)
+ self.model = self.model.to(self.device)
+ self.model = self.model.eval()
+
+ self.grid_size = model_cfg["grid_size"]
+
+ if junc_detect_thresh is not None:
+ self.junc_detect_thresh = junc_detect_thresh
+ else:
+ self.junc_detect_thresh = model_cfg.get("detection_thresh", 1/65)
+ self.max_num_junctions = model_cfg.get("max_num_junctions", 300)
+
+ # Initialize the line detector
+ self.line_detector_cfg = line_detector_cfg
+ self.line_detector = LineSegmentDetectionModule(**line_detector_cfg)
+
+ def __call__(self, input_image, valid_mask=None,
+ return_heatmap=False, profile=False):
+ # Now we restrict input_image to 4D torch tensor
+ if ((not len(input_image.shape) == 4)
+ or (not isinstance(input_image, torch.Tensor))):
+ raise ValueError(
+ "[Error] the input image should be a 4D torch tensor.")
+
+ # Move the input to corresponding device
+ input_image = input_image.to(self.device)
+
+ # Forward of the CNN backbone
+ start_time = time.time()
+ with torch.no_grad():
+ net_outputs = self.model(input_image)
+
+ junc_np = convert_junc_predictions(
+ net_outputs["junctions"], self.grid_size,
+ self.junc_detect_thresh, self.max_num_junctions)
+ if valid_mask is None:
+ junctions = np.where(junc_np["junc_pred_nms"].squeeze())
+ else:
+ junctions = np.where(junc_np["junc_pred_nms"].squeeze()
+ * valid_mask)
+ junctions = np.concatenate(
+ [junctions[0][..., None], junctions[1][..., None]], axis=-1)
+
+ if net_outputs["heatmap"].shape[1] == 2:
+ # Convert to single channel directly from here
+ heatmap = softmax(net_outputs["heatmap"], dim=1)[:, 1:, :, :]
+ else:
+ heatmap = torch.sigmoid(net_outputs["heatmap"])
+ heatmap = heatmap.cpu().numpy().transpose(0, 2, 3, 1)[0, :, :, 0]
+
+ # Run the line detector.
+ line_map, junctions, heatmap = self.line_detector.detect(
+ junctions, heatmap, device=self.device)
+ heatmap = heatmap.cpu().numpy()
+ if isinstance(line_map, torch.Tensor):
+ line_map = line_map.cpu().numpy()
+ if isinstance(junctions, torch.Tensor):
+ junctions = junctions.cpu().numpy()
+ line_segments = line_map_to_segments(junctions, line_map)
+ end_time = time.time()
+
+ outputs = {"line_segments": line_segments}
+
+ if return_heatmap:
+ outputs["heatmap"] = heatmap
+ if profile:
+ outputs["time"] = end_time - start_time
+
+ return outputs
diff --git a/imcui/third_party/SOLD2/sold2/model/line_matcher.py b/imcui/third_party/SOLD2/sold2/model/line_matcher.py
new file mode 100644
index 0000000000000000000000000000000000000000..bc5a003573c91313e2295c75871edcb1c113662a
--- /dev/null
+++ b/imcui/third_party/SOLD2/sold2/model/line_matcher.py
@@ -0,0 +1,279 @@
+"""
+Implements the full pipeline from raw images to line matches.
+"""
+import time
+import cv2
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torch.nn.functional import softmax
+
+from .model_util import get_model
+from .loss import get_loss_and_weights
+from .metrics import super_nms
+from .line_detection import LineSegmentDetectionModule
+from .line_matching import WunschLineMatcher
+from ..train import convert_junc_predictions
+from ..misc.train_utils import adapt_checkpoint
+from .line_detector import line_map_to_segments
+
+
+class LineMatcher(object):
+ """ Full line matcher including line detection and matching
+ with the Needleman-Wunsch algorithm. """
+ def __init__(self, model_cfg, ckpt_path, device, line_detector_cfg,
+ line_matcher_cfg, multiscale=False, scales=[1., 2.]):
+ # Get loss weights if dynamic weighting
+ _, loss_weights = get_loss_and_weights(model_cfg, device)
+ self.device = device
+
+ # Initialize the cnn backbone
+ self.model = get_model(model_cfg, loss_weights)
+ checkpoint = torch.load(ckpt_path, map_location=self.device)
+ checkpoint = adapt_checkpoint(checkpoint["model_state_dict"])
+ self.model.load_state_dict(checkpoint)
+ self.model = self.model.to(self.device)
+ self.model = self.model.eval()
+
+ self.grid_size = model_cfg["grid_size"]
+ self.junc_detect_thresh = model_cfg["detection_thresh"]
+ self.max_num_junctions = model_cfg.get("max_num_junctions", 300)
+
+ # Initialize the line detector
+ self.line_detector = LineSegmentDetectionModule(**line_detector_cfg)
+ self.multiscale = multiscale
+ self.scales = scales
+
+ # Initialize the line matcher
+ self.line_matcher = WunschLineMatcher(**line_matcher_cfg)
+
+ # Print some debug messages
+ for key, val in line_detector_cfg.items():
+ print(f"[Debug] {key}: {val}")
+ # print("[Debug] detect_thresh: %f" % (line_detector_cfg["detect_thresh"]))
+ # print("[Debug] num_samples: %d" % (line_detector_cfg["num_samples"]))
+
+
+
+ # Perform line detection and descriptor inference on a single image
+ def line_detection(self, input_image, valid_mask=None,
+ desc_only=False, profile=False):
+ # Restrict input_image to 4D torch tensor
+ if ((not len(input_image.shape) == 4)
+ or (not isinstance(input_image, torch.Tensor))):
+ raise ValueError(
+ "[Error] the input image should be a 4D torch tensor")
+
+ # Move the input to corresponding device
+ input_image = input_image.to(self.device)
+
+ # Forward of the CNN backbone
+ start_time = time.time()
+ with torch.no_grad():
+ net_outputs = self.model(input_image)
+
+ outputs = {"descriptor": net_outputs["descriptors"]}
+
+ if not desc_only:
+ junc_np = convert_junc_predictions(
+ net_outputs["junctions"], self.grid_size,
+ self.junc_detect_thresh, self.max_num_junctions)
+ if valid_mask is None:
+ junctions = np.where(junc_np["junc_pred_nms"].squeeze())
+ else:
+ junctions = np.where(
+ junc_np["junc_pred_nms"].squeeze() * valid_mask)
+ junctions = np.concatenate([junctions[0][..., None],
+ junctions[1][..., None]], axis=-1)
+
+ if net_outputs["heatmap"].shape[1] == 2:
+ # Convert to single channel directly from here
+ heatmap = softmax(
+ net_outputs["heatmap"],
+ dim=1)[:, 1:, :, :].cpu().numpy().transpose(0, 2, 3, 1)
+ else:
+ heatmap = torch.sigmoid(
+ net_outputs["heatmap"]).cpu().numpy().transpose(0, 2, 3, 1)
+ heatmap = heatmap[0, :, :, 0]
+
+ # Run the line detector.
+ line_map, junctions, heatmap = self.line_detector.detect(
+ junctions, heatmap, device=self.device)
+ if isinstance(line_map, torch.Tensor):
+ line_map = line_map.cpu().numpy()
+ if isinstance(junctions, torch.Tensor):
+ junctions = junctions.cpu().numpy()
+ outputs["heatmap"] = heatmap.cpu().numpy()
+ outputs["junctions"] = junctions
+
+ # If it's a line map with multiple detect_thresh and inlier_thresh
+ if len(line_map.shape) > 2:
+ num_detect_thresh = line_map.shape[0]
+ num_inlier_thresh = line_map.shape[1]
+ line_segments = []
+ for detect_idx in range(num_detect_thresh):
+ line_segments_inlier = []
+ for inlier_idx in range(num_inlier_thresh):
+ line_map_tmp = line_map[detect_idx, inlier_idx, :, :]
+ line_segments_tmp = line_map_to_segments(junctions, line_map_tmp)
+ line_segments_inlier.append(line_segments_tmp)
+ line_segments.append(line_segments_inlier)
+ else:
+ line_segments = line_map_to_segments(junctions, line_map)
+
+ outputs["line_segments"] = line_segments
+
+ end_time = time.time()
+
+ if profile:
+ outputs["time"] = end_time - start_time
+
+ return outputs
+
+ # Perform line detection and descriptor inference at multiple scales
+ def multiscale_line_detection(self, input_image, valid_mask=None,
+ desc_only=False, profile=False,
+ scales=[1., 2.], aggregation='mean'):
+ # Restrict input_image to 4D torch tensor
+ if ((not len(input_image.shape) == 4)
+ or (not isinstance(input_image, torch.Tensor))):
+ raise ValueError(
+ "[Error] the input image should be a 4D torch tensor")
+
+ # Move the input to corresponding device
+ input_image = input_image.to(self.device)
+ img_size = input_image.shape[2:4]
+ desc_size = tuple(np.array(img_size) // 4)
+
+ # Run the inference at multiple image scales
+ start_time = time.time()
+ junctions, heatmaps, descriptors = [], [], []
+ for s in scales:
+ # Resize the image
+ resized_img = F.interpolate(input_image, scale_factor=s,
+ mode='bilinear')
+
+ # Forward of the CNN backbone
+ with torch.no_grad():
+ net_outputs = self.model(resized_img)
+
+ descriptors.append(F.interpolate(
+ net_outputs["descriptors"], size=desc_size, mode="bilinear"))
+
+ if not desc_only:
+ junc_prob = convert_junc_predictions(
+ net_outputs["junctions"], self.grid_size)["junc_pred"]
+ junctions.append(cv2.resize(junc_prob.squeeze(),
+ (img_size[1], img_size[0]),
+ interpolation=cv2.INTER_LINEAR))
+
+ if net_outputs["heatmap"].shape[1] == 2:
+ # Convert to single channel directly from here
+ heatmap = softmax(net_outputs["heatmap"],
+ dim=1)[:, 1:, :, :]
+ else:
+ heatmap = torch.sigmoid(net_outputs["heatmap"])
+ heatmaps.append(F.interpolate(heatmap, size=img_size,
+ mode="bilinear"))
+
+ # Aggregate the results
+ if aggregation == 'mean':
+ # Aggregation through the mean activation
+ descriptors = torch.stack(descriptors, dim=0).mean(0)
+ else:
+ # Aggregation through the max activation
+ descriptors = torch.stack(descriptors, dim=0).max(0)[0]
+ outputs = {"descriptor": descriptors}
+
+ if not desc_only:
+ if aggregation == 'mean':
+ junctions = np.stack(junctions, axis=0).mean(0)[None]
+ heatmap = torch.stack(heatmaps, dim=0).mean(0)[0, 0, :, :]
+ heatmap = heatmap.cpu().numpy()
+ else:
+ junctions = np.stack(junctions, axis=0).max(0)[None]
+ heatmap = torch.stack(heatmaps, dim=0).max(0)[0][0, 0, :, :]
+ heatmap = heatmap.cpu().numpy()
+
+ # Extract junctions
+ junc_pred_nms = super_nms(
+ junctions[..., None], self.grid_size,
+ self.junc_detect_thresh, self.max_num_junctions)
+ if valid_mask is None:
+ junctions = np.where(junc_pred_nms.squeeze())
+ else:
+ junctions = np.where(junc_pred_nms.squeeze() * valid_mask)
+ junctions = np.concatenate([junctions[0][..., None],
+ junctions[1][..., None]], axis=-1)
+
+ # Run the line detector.
+ line_map, junctions, heatmap = self.line_detector.detect(
+ junctions, heatmap, device=self.device)
+ if isinstance(line_map, torch.Tensor):
+ line_map = line_map.cpu().numpy()
+ if isinstance(junctions, torch.Tensor):
+ junctions = junctions.cpu().numpy()
+ outputs["heatmap"] = heatmap.cpu().numpy()
+ outputs["junctions"] = junctions
+
+ # If it's a line map with multiple detect_thresh and inlier_thresh
+ if len(line_map.shape) > 2:
+ num_detect_thresh = line_map.shape[0]
+ num_inlier_thresh = line_map.shape[1]
+ line_segments = []
+ for detect_idx in range(num_detect_thresh):
+ line_segments_inlier = []
+ for inlier_idx in range(num_inlier_thresh):
+ line_map_tmp = line_map[detect_idx, inlier_idx, :, :]
+ line_segments_tmp = line_map_to_segments(
+ junctions, line_map_tmp)
+ line_segments_inlier.append(line_segments_tmp)
+ line_segments.append(line_segments_inlier)
+ else:
+ line_segments = line_map_to_segments(junctions, line_map)
+
+ outputs["line_segments"] = line_segments
+
+ end_time = time.time()
+
+ if profile:
+ outputs["time"] = end_time - start_time
+
+ return outputs
+
+ def __call__(self, images, valid_masks=[None, None], profile=False):
+ # Line detection and descriptor inference on both images
+ if self.multiscale:
+ forward_outputs = [
+ self.multiscale_line_detection(
+ images[0], valid_masks[0], profile=profile,
+ scales=self.scales),
+ self.multiscale_line_detection(
+ images[1], valid_masks[1], profile=profile,
+ scales=self.scales)]
+ else:
+ forward_outputs = [
+ self.line_detection(images[0], valid_masks[0],
+ profile=profile),
+ self.line_detection(images[1], valid_masks[1],
+ profile=profile)]
+ line_seg1 = forward_outputs[0]["line_segments"]
+ line_seg2 = forward_outputs[1]["line_segments"]
+ desc1 = forward_outputs[0]["descriptor"]
+ desc2 = forward_outputs[1]["descriptor"]
+
+ # Match the lines in both images
+ start_time = time.time()
+ matches = self.line_matcher.forward(line_seg1, line_seg2,
+ desc1, desc2)
+ end_time = time.time()
+
+ outputs = {"line_segments": [line_seg1, line_seg2],
+ "matches": matches}
+
+ if profile:
+ outputs["line_detection_time"] = (forward_outputs[0]["time"]
+ + forward_outputs[1]["time"])
+ outputs["line_matching_time"] = end_time - start_time
+
+ return outputs
diff --git a/imcui/third_party/SOLD2/sold2/model/line_matching.py b/imcui/third_party/SOLD2/sold2/model/line_matching.py
new file mode 100644
index 0000000000000000000000000000000000000000..89b71879e3104f9a8b52c1cf5e534cd124fe83b2
--- /dev/null
+++ b/imcui/third_party/SOLD2/sold2/model/line_matching.py
@@ -0,0 +1,390 @@
+"""
+Implementation of the line matching methods.
+"""
+import numpy as np
+import cv2
+import torch
+import torch.nn.functional as F
+
+from ..misc.geometry_utils import keypoints_to_grid
+
+
+class WunschLineMatcher(object):
+ """ Class matching two sets of line segments
+ with the Needleman-Wunsch algorithm. """
+ def __init__(self, cross_check=True, num_samples=10, min_dist_pts=8,
+ top_k_candidates=10, grid_size=8, sampling="regular",
+ line_score=False):
+ self.cross_check = cross_check
+ self.num_samples = num_samples
+ self.min_dist_pts = min_dist_pts
+ self.top_k_candidates = top_k_candidates
+ self.grid_size = grid_size
+ self.line_score = line_score # True to compute saliency on a line
+ self.sampling_mode = sampling
+ if sampling not in ["regular", "d2_net", "asl_feat"]:
+ raise ValueError("Wrong sampling mode: " + sampling)
+
+ def forward(self, line_seg1, line_seg2, desc1, desc2):
+ """
+ Find the best matches between two sets of line segments
+ and their corresponding descriptors.
+ """
+ img_size1 = (desc1.shape[2] * self.grid_size,
+ desc1.shape[3] * self.grid_size)
+ img_size2 = (desc2.shape[2] * self.grid_size,
+ desc2.shape[3] * self.grid_size)
+ device = desc1.device
+
+ # Default case when an image has no lines
+ if len(line_seg1) == 0:
+ return np.empty((0), dtype=int)
+ if len(line_seg2) == 0:
+ return -np.ones(len(line_seg1), dtype=int)
+
+ # Sample points regularly along each line
+ if self.sampling_mode == "regular":
+ line_points1, valid_points1 = self.sample_line_points(line_seg1)
+ line_points2, valid_points2 = self.sample_line_points(line_seg2)
+ else:
+ line_points1, valid_points1 = self.sample_salient_points(
+ line_seg1, desc1, img_size1, self.sampling_mode)
+ line_points2, valid_points2 = self.sample_salient_points(
+ line_seg2, desc2, img_size2, self.sampling_mode)
+ line_points1 = torch.tensor(line_points1.reshape(-1, 2),
+ dtype=torch.float, device=device)
+ line_points2 = torch.tensor(line_points2.reshape(-1, 2),
+ dtype=torch.float, device=device)
+
+ # Extract the descriptors for each point
+ grid1 = keypoints_to_grid(line_points1, img_size1)
+ grid2 = keypoints_to_grid(line_points2, img_size2)
+ desc1 = F.normalize(F.grid_sample(desc1, grid1)[0, :, :, 0], dim=0)
+ desc2 = F.normalize(F.grid_sample(desc2, grid2)[0, :, :, 0], dim=0)
+
+ # Precompute the distance between line points for every pair of lines
+ # Assign a score of -1 for invalid points
+ scores = desc1.t() @ desc2
+ scores[~valid_points1.flatten()] = -1
+ scores[:, ~valid_points2.flatten()] = -1
+ scores = scores.reshape(len(line_seg1), self.num_samples,
+ len(line_seg2), self.num_samples)
+ scores = scores.permute(0, 2, 1, 3)
+ # scores.shape = (n_lines1, n_lines2, num_samples, num_samples)
+
+ # Pre-filter the line candidates and find the best match for each line
+ matches = self.filter_and_match_lines(scores)
+
+ # [Optionally] filter matches with mutual nearest neighbor filtering
+ if self.cross_check:
+ matches2 = self.filter_and_match_lines(
+ scores.permute(1, 0, 3, 2))
+ mutual = matches2[matches] == np.arange(len(line_seg1))
+ matches[~mutual] = -1
+
+ return matches
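+
+ # Shape sketch (not part of the original SOLD2 code): with n_lines1 = 4,
+ # n_lines2 = 6 and num_samples = 10, desc1.t() @ desc2 is (40, 60) and the
+ # reshaped/permuted scores tensor is (4, 6, 10, 10), i.e. one 10x10
+ # point-to-point similarity matrix per candidate line pair.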
+
+ def d2_net_saliency_score(self, desc):
+ """ Compute the D2-Net saliency score
+ on a 3D or 4D descriptor. """
+ is_3d = len(desc.shape) == 3
+ b_size = len(desc)
+ feat = F.relu(desc)
+
+ # Compute the soft local max
+ exp = torch.exp(feat)
+ if is_3d:
+ sum_exp = 3 * F.avg_pool1d(exp, kernel_size=3, stride=1,
+ padding=1)
+ else:
+ sum_exp = 9 * F.avg_pool2d(exp, kernel_size=3, stride=1,
+ padding=1)
+ soft_local_max = exp / sum_exp
+
+ # Compute the depth-wise maximum
+ depth_wise_max = torch.max(feat, dim=1)[0]
+ depth_wise_max = feat / depth_wise_max.unsqueeze(1)
+
+ # Total saliency score
+ score = torch.max(soft_local_max * depth_wise_max, dim=1)[0]
+ normalization = torch.sum(score.reshape(b_size, -1), dim=1)
+ if is_3d:
+ normalization = normalization.reshape(b_size, 1)
+ else:
+ normalization = normalization.reshape(b_size, 1, 1)
+ score = score / normalization
+ return score
+
+ def asl_feat_saliency_score(self, desc):
+ """ Compute the ASLFeat saliency score on a 3D or 4D descriptor. """
+ is_3d = len(desc.shape) == 3
+ b_size = len(desc)
+
+ # Compute the soft local peakiness
+ if is_3d:
+ local_avg = F.avg_pool1d(desc, kernel_size=3, stride=1, padding=1)
+ else:
+ local_avg = F.avg_pool2d(desc, kernel_size=3, stride=1, padding=1)
+ soft_local_score = F.softplus(desc - local_avg)
+
+ # Compute the depth-wise peakiness
+ depth_wise_mean = torch.mean(desc, dim=1).unsqueeze(1)
+ depth_wise_score = F.softplus(desc - depth_wise_mean)
+
+ # Total saliency score
+ score = torch.max(soft_local_score * depth_wise_score, dim=1)[0]
+ normalization = torch.sum(score.reshape(b_size, -1), dim=1)
+ if is_3d:
+ normalization = normalization.reshape(b_size, 1)
+ else:
+ normalization = normalization.reshape(b_size, 1, 1)
+ score = score / normalization
+ return score
+
+ def sample_salient_points(self, line_seg, desc, img_size,
+ saliency_type='d2_net'):
+ """
+ Sample the most salient points along each line segment, with a
+ minimal distance between each point. Pad the remaining points.
+ Inputs:
+ line_seg: an Nx2x2 torch.Tensor.
+ desc: an NxDxHxW torch.Tensor.
+ img_size: the original image size.
+ saliency_type: 'd2_net' or 'asl_feat'.
+ Outputs:
+ line_points: an Nxnum_samplesx2 np.array.
+ valid_points: a boolean Nxnum_samples np.array.
+ """
+ device = desc.device
+ if not self.line_score:
+ # Compute the score map
+ if saliency_type == "d2_net":
+ score = self.d2_net_saliency_score(desc)
+ else:
+ score = self.asl_feat_saliency_score(desc)
+
+ num_lines = len(line_seg)
+ line_lengths = np.linalg.norm(line_seg[:, 0] - line_seg[:, 1], axis=1)
+
+ # The number of samples depends on the length of the line
+ num_samples_lst = np.clip(line_lengths // self.min_dist_pts,
+ 2, self.num_samples)
+ line_points = np.empty((num_lines, self.num_samples, 2), dtype=float)
+ valid_points = np.empty((num_lines, self.num_samples), dtype=bool)
+
+ # Sample the score on a fixed number of points of each line
+ n_samples_per_region = 4
+ for n in np.arange(2, self.num_samples + 1):
+ sample_rate = n * n_samples_per_region
+ # Consider all lines where we can fit up to n points
+ cur_mask = num_samples_lst == n
+ cur_line_seg = line_seg[cur_mask]
+ cur_num_lines = len(cur_line_seg)
+ if cur_num_lines == 0:
+ continue
+ line_points_x = np.linspace(cur_line_seg[:, 0, 0],
+ cur_line_seg[:, 1, 0],
+ sample_rate, axis=-1)
+ line_points_y = np.linspace(cur_line_seg[:, 0, 1],
+ cur_line_seg[:, 1, 1],
+ sample_rate, axis=-1)
+ cur_line_points = np.stack([line_points_x, line_points_y],
+ axis=-1).reshape(-1, 2)
+ # cur_line_points is of shape (n_cur_lines * sample_rate, 2)
+ cur_line_points = torch.tensor(cur_line_points, dtype=torch.float,
+ device=device)
+ grid_points = keypoints_to_grid(cur_line_points, img_size)
+
+ if self.line_score:
+ # The saliency score is high when the activations are locally
+ # maximal along the line (and not in a square neighborhood)
+ line_desc = F.grid_sample(desc, grid_points).squeeze()
+ line_desc = line_desc.reshape(-1, cur_num_lines, sample_rate)
+ line_desc = line_desc.permute(1, 0, 2)
+ if saliency_type == "d2_net":
+ scores = self.d2_net_saliency_score(line_desc)
+ else:
+ scores = self.asl_feat_saliency_score(line_desc)
+ else:
+ scores = F.grid_sample(score.unsqueeze(1),
+ grid_points).squeeze()
+
+ # Take the most salient point in n distinct regions
+ scores = scores.reshape(-1, n, n_samples_per_region)
+ best = torch.max(scores, dim=2, keepdim=True)[1].cpu().numpy()
+ cur_line_points = cur_line_points.reshape(-1, n,
+ n_samples_per_region, 2)
+ cur_line_points = np.take_along_axis(
+ cur_line_points, best[..., None], axis=2)[:, :, 0]
+
+ # Pad
+ cur_valid_points = np.ones((cur_num_lines, self.num_samples),
+ dtype=bool)
+ cur_valid_points[:, n:] = False
+ cur_line_points = np.concatenate([
+ cur_line_points,
+ np.zeros((cur_num_lines, self.num_samples - n, 2), dtype=float)],
+ axis=1)
+
+ line_points[cur_mask] = cur_line_points
+ valid_points[cur_mask] = cur_valid_points
+
+ return line_points, valid_points
+
+ def sample_line_points(self, line_seg):
+ """
+ Regularly sample points along each line segment, with a minimal
+ distance between each point. Pad the remaining points.
+ Inputs:
+ line_seg: an Nx2x2 torch.Tensor.
+ Outputs:
+ line_points: an Nxnum_samplesx2 np.array.
+ valid_points: a boolean Nxnum_samples np.array.
+ """
+ num_lines = len(line_seg)
+ line_lengths = np.linalg.norm(line_seg[:, 0] - line_seg[:, 1], axis=1)
+
+ # Sample the points separated by at least min_dist_pts along each line
+ # The number of samples depends on the length of the line
+ num_samples_lst = np.clip(line_lengths // self.min_dist_pts,
+ 2, self.num_samples)
+ line_points = np.empty((num_lines, self.num_samples, 2), dtype=float)
+ valid_points = np.empty((num_lines, self.num_samples), dtype=bool)
+ for n in np.arange(2, self.num_samples + 1):
+ # Consider all lines where we can fit up to n points
+ cur_mask = num_samples_lst == n
+ cur_line_seg = line_seg[cur_mask]
+ line_points_x = np.linspace(cur_line_seg[:, 0, 0],
+ cur_line_seg[:, 1, 0],
+ n, axis=-1)
+ line_points_y = np.linspace(cur_line_seg[:, 0, 1],
+ cur_line_seg[:, 1, 1],
+ n, axis=-1)
+ cur_line_points = np.stack([line_points_x, line_points_y], axis=-1)
+
+ # Pad
+ cur_num_lines = len(cur_line_seg)
+ cur_valid_points = np.ones((cur_num_lines, self.num_samples),
+ dtype=bool)
+ cur_valid_points[:, n:] = False
+ cur_line_points = np.concatenate([
+ cur_line_points,
+ np.zeros((cur_num_lines, self.num_samples - n, 2), dtype=float)],
+ axis=1)
+
+ line_points[cur_mask] = cur_line_points
+ valid_points[cur_mask] = cur_valid_points
+
+ return line_points, valid_points
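+
+ # Worked example (not part of the original SOLD2 code): with the default
+ # min_dist_pts = 8 and num_samples = 10, a line of length 50 gets
+ # n = clip(50 // 8, 2, 10) = 6 evenly spaced points (endpoints included);
+ # the remaining 4 slots are zero-padded and flagged as False in valid_points.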
+
+ def filter_and_match_lines(self, scores):
+ """
+ Use the scores to keep the top k best lines, run the Needleman-
+ Wunsch algorithm on each candidate pair, and keep the highest score.
+ Inputs:
+ scores: a (N, M, n, n) torch.Tensor containing the pairwise scores
+ of the elements to match.
+ Outputs:
+ matches: a (N) np.array containing the index of the best match
+ in the second set for each line of the first set.
+ """
+ # Pre-filter the pairs and keep the top k best candidate lines
+ line_scores1 = scores.max(3)[0]
+ valid_scores1 = line_scores1 != -1
+ line_scores1 = ((line_scores1 * valid_scores1).sum(2)
+ / valid_scores1.sum(2))
+ line_scores2 = scores.max(2)[0]
+ valid_scores2 = line_scores2 != -1
+ line_scores2 = ((line_scores2 * valid_scores2).sum(2)
+ / valid_scores2.sum(2))
+ line_scores = (line_scores1 + line_scores2) / 2
+ topk_lines = torch.argsort(line_scores,
+ dim=1)[:, -self.top_k_candidates:]
+ scores, topk_lines = scores.cpu().numpy(), topk_lines.cpu().numpy()
+ # topk_lines.shape = (n_lines1, top_k_candidates)
+ top_scores = np.take_along_axis(scores, topk_lines[:, :, None, None],
+ axis=1)
+
+ # Consider the reversed line segments as well
+ top_scores = np.concatenate([top_scores, top_scores[..., ::-1]],
+ axis=1)
+
+ # Compute the line distance matrix with Needleman-Wunsch algo and
+ # retrieve the closest line neighbor
+ n_lines1, top2k, n, m = top_scores.shape
+ top_scores = top_scores.reshape(n_lines1 * top2k, n, m)
+ nw_scores = self.needleman_wunsch(top_scores)
+ nw_scores = nw_scores.reshape(n_lines1, top2k)
+ matches = np.mod(np.argmax(nw_scores, axis=1), top2k // 2)
+ matches = topk_lines[np.arange(n_lines1), matches]
+ return matches
+
+ def needleman_wunsch(self, scores):
+ """
+ Batched implementation of the Needleman-Wunsch algorithm.
+ The cost of the InDel operation is set to 0 by subtracting the gap
+ penalty from the scores.
+ Inputs:
+ scores: a (B, N, M) np.array containing the pairwise scores
+ of the elements to match.
+ """
+ b, n, m = scores.shape
+
+ # Recalibrate the scores to get a gap score of 0
+ gap = 0.1
+ nw_scores = scores - gap
+
+ # Run the dynamic programming algorithm
+ nw_grid = np.zeros((b, n + 1, m + 1), dtype=float)
+ for i in range(n):
+ for j in range(m):
+ nw_grid[:, i + 1, j + 1] = np.maximum(
+ np.maximum(nw_grid[:, i + 1, j], nw_grid[:, i, j + 1]),
+ nw_grid[:, i, j] + nw_scores[:, i, j])
+
+ return nw_grid[:, -1, -1]
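+
+ # Worked example (not part of the original SOLD2 code): for a single pair
+ # with scores = [[0.9, 0.1], [0.2, 0.8]] and gap = 0.1, the recalibrated
+ # scores are [[0.8, 0.0], [0.1, 0.7]] and the DP table yields
+ # nw_grid[:, -1, -1] = 0.8 + 0.7 = 1.5, i.e. both samples are aligned along
+ # the diagonal rather than skipped.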
+
+ def get_pairwise_distance(self, line_seg1, line_seg2, desc1, desc2):
+ """
+ Compute the OPPOSITE of the NW score for pairs of line segments
+ and their corresponding descriptors.
+ """
+ num_lines = len(line_seg1)
+ assert num_lines == len(line_seg2), "The same number of lines is required in pairwise score."
+ img_size1 = (desc1.shape[2] * self.grid_size,
+ desc1.shape[3] * self.grid_size)
+ img_size2 = (desc2.shape[2] * self.grid_size,
+ desc2.shape[3] * self.grid_size)
+ device = desc1.device
+
+ # Sample points regularly along each line
+ line_points1, valid_points1 = self.sample_line_points(line_seg1)
+ line_points2, valid_points2 = self.sample_line_points(line_seg2)
+ line_points1 = torch.tensor(line_points1.reshape(-1, 2),
+ dtype=torch.float, device=device)
+ line_points2 = torch.tensor(line_points2.reshape(-1, 2),
+ dtype=torch.float, device=device)
+
+ # Extract the descriptors for each point
+ grid1 = keypoints_to_grid(line_points1, img_size1)
+ grid2 = keypoints_to_grid(line_points2, img_size2)
+ desc1 = F.normalize(F.grid_sample(desc1, grid1)[0, :, :, 0], dim=0)
+ desc1 = desc1.reshape(-1, num_lines, self.num_samples)
+ desc2 = F.normalize(F.grid_sample(desc2, grid2)[0, :, :, 0], dim=0)
+ desc2 = desc2.reshape(-1, num_lines, self.num_samples)
+
+ # Compute the distance between line points for every pair of lines
+ # Assign a score of -1 for invalid points
+ scores = torch.einsum('dns,dnt->nst', desc1, desc2).cpu().numpy()
+ scores = scores.reshape(num_lines * self.num_samples,
+ self.num_samples)
+ scores[~valid_points1.flatten()] = -1
+ scores = scores.reshape(num_lines, self.num_samples, self.num_samples)
+ scores = scores.transpose(1, 0, 2).reshape(self.num_samples, -1)
+ scores[:, ~valid_points2.flatten()] = -1
+ scores = scores.reshape(self.num_samples, num_lines, self.num_samples)
+ scores = scores.transpose(1, 0, 2)
+ # scores.shape = (num_lines, num_samples, num_samples)
+
+ # Compute the NW score for each pair of lines
+ pairwise_scores = np.array([self.needleman_wunsch(s) for s in scores])
+ return -pairwise_scores
diff --git a/imcui/third_party/SOLD2/sold2/model/loss.py b/imcui/third_party/SOLD2/sold2/model/loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..aaad3c67f3fd59db308869901f8a56623901e318
--- /dev/null
+++ b/imcui/third_party/SOLD2/sold2/model/loss.py
@@ -0,0 +1,445 @@
+"""
+Loss function implementations.
+"""
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from kornia.geometry import warp_perspective
+
+from ..misc.geometry_utils import (keypoints_to_grid, get_dist_mask,
+ get_common_line_mask)
+
+
+def get_loss_and_weights(model_cfg, device=torch.device("cuda")):
+ """ Get loss functions and either static or dynamic weighting. """
+ # Get the global weighting policy
+ w_policy = model_cfg.get("weighting_policy", "static")
+ if not w_policy in ["static", "dynamic"]:
+ raise ValueError("[Error] Not supported weighting policy.")
+
+ loss_func = {}
+ loss_weight = {}
+ # Get junction loss function and weight
+ w_junc, junc_loss_func = get_junction_loss_and_weight(model_cfg, w_policy)
+ loss_func["junc_loss"] = junc_loss_func.to(device)
+ loss_weight["w_junc"] = w_junc
+
+ # Get heatmap loss function and weight
+ w_heatmap, heatmap_loss_func = get_heatmap_loss_and_weight(
+ model_cfg, w_policy, device)
+ loss_func["heatmap_loss"] = heatmap_loss_func.to(device)
+ loss_weight["w_heatmap"] = w_heatmap
+
+ # [Optionally] get descriptor loss function and weight
+ if model_cfg.get("descriptor_loss_func", None) is not None:
+ w_descriptor, descriptor_loss_func = get_descriptor_loss_and_weight(
+ model_cfg, w_policy)
+ loss_func["descriptor_loss"] = descriptor_loss_func.to(device)
+ loss_weight["w_desc"] = w_descriptor
+
+ return loss_func, loss_weight
+
+
+def get_junction_loss_and_weight(model_cfg, global_w_policy):
+ """ Get the junction loss function and weight. """
+ junction_loss_cfg = model_cfg.get("junction_loss_cfg", {})
+
+ # Get the junction loss weight
+ w_policy = junction_loss_cfg.get("policy", global_w_policy)
+ if w_policy == "static":
+ w_junc = torch.tensor(model_cfg["w_junc"], dtype=torch.float32)
+ elif w_policy == "dynamic":
+ w_junc = nn.Parameter(
+ torch.tensor(model_cfg["w_junc"], dtype=torch.float32),
+ requires_grad=True)
+ else:
+ raise ValueError(
+ "[Error] Unknown weighting policy for junction loss weight.")
+
+ # Get the junction loss function
+ junc_loss_name = model_cfg.get("junction_loss_func", "superpoint")
+ if junc_loss_name == "superpoint":
+ junc_loss_func = JunctionDetectionLoss(model_cfg["grid_size"],
+ model_cfg["keep_border_valid"])
+ else:
+ raise ValueError("[Error] Not supported junction loss function.")
+
+ return w_junc, junc_loss_func
+
+
+def get_heatmap_loss_and_weight(model_cfg, global_w_policy, device):
+ """ Get the heatmap loss function and weight. """
+ heatmap_loss_cfg = model_cfg.get("heatmap_loss_cfg", {})
+
+ # Get the heatmap loss weight
+ w_policy = heatmap_loss_cfg.get("policy", global_w_policy)
+ if w_policy == "static":
+ w_heatmap = torch.tensor(model_cfg["w_heatmap"], dtype=torch.float32)
+ elif w_policy == "dynamic":
+ w_heatmap = nn.Parameter(
+ torch.tensor(model_cfg["w_heatmap"], dtype=torch.float32),
+ requires_grad=True)
+ else:
+ raise ValueError(
+ "[Error] Unknown weighting policy for junction loss weight.")
+
+ # Get the corresponding heatmap loss based on the config
+ heatmap_loss_name = model_cfg.get("heatmap_loss_func", "cross_entropy")
+ if heatmap_loss_name == "cross_entropy":
+ # Get the heatmap class weight (always static)
+ heatmap_class_w = model_cfg.get("w_heatmap_class", 1.)
+ class_weight = torch.tensor(
+ np.array([1., heatmap_class_w])).to(torch.float).to(device)
+ heatmap_loss_func = HeatmapLoss(class_weight=class_weight)
+ else:
+ raise ValueError("[Error] Not supported heatmap loss function.")
+
+ return w_heatmap, heatmap_loss_func
+
+
+def get_descriptor_loss_and_weight(model_cfg, global_w_policy):
+ """ Get the descriptor loss function and weight. """
+ descriptor_loss_cfg = model_cfg.get("descriptor_loss_cfg", {})
+
+ # Get the descriptor loss weight
+ w_policy = descriptor_loss_cfg.get("policy", global_w_policy)
+ if w_policy == "static":
+ w_descriptor = torch.tensor(model_cfg["w_desc"], dtype=torch.float32)
+ elif w_policy == "dynamic":
+ w_descriptor = nn.Parameter(torch.tensor(model_cfg["w_desc"],
+ dtype=torch.float32), requires_grad=True)
+ else:
+ raise ValueError(
+ "[Error] Unknown weighting policy for descriptor loss weight.")
+
+ # Get the descriptor loss function
+ descriptor_loss_name = model_cfg.get("descriptor_loss_func",
+ "regular_sampling")
+ if descriptor_loss_name == "regular_sampling":
+ descriptor_loss_func = TripletDescriptorLoss(
+ descriptor_loss_cfg["grid_size"],
+ descriptor_loss_cfg["dist_threshold"],
+ descriptor_loss_cfg["margin"])
+ else:
+ raise ValueError("[Error] Not supported descriptor loss function.")
+
+ return w_descriptor, descriptor_loss_func
+
+
+def space_to_depth(input_tensor, grid_size):
+ """ PixelUnshuffle for pytorch. """
+ N, C, H, W = input_tensor.size()
+ # (N, C, H//bs, bs, W//bs, bs)
+ x = input_tensor.view(N, C, H // grid_size, grid_size, W // grid_size, grid_size)
+ # (N, bs, bs, C, H//bs, W//bs)
+ x = x.permute(0, 3, 5, 1, 2, 4).contiguous()
+ # (N, C*bs^2, H//bs, W//bs)
+ x = x.view(N, C * (grid_size ** 2), H // grid_size, W // grid_size)
+ return x
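+
+# Worked example (not part of the original SOLD2 code): for a single-channel
+# 4x4 map holding the values 0..15 and grid_size = 2, the output has shape
+# (1, 4, 2, 2) and output[0, :, 0, 0] == [0, 1, 4, 5], i.e. each 2x2 cell of
+# the input becomes one stack of 4 "channel" values. For C = 1 inputs this
+# matches torch.nn.functional.pixel_unshuffle(x, 2).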
+
+
+def junction_detection_loss(junction_map, junc_predictions, valid_mask=None,
+ grid_size=8, keep_border=True):
+ """ Junction detection loss. """
+ # Convert junc_map to channel tensor
+ junc_map = space_to_depth(junction_map, grid_size)
+ map_shape = junc_map.shape[-2:]
+ batch_size = junc_map.shape[0]
+ dust_bin_label = torch.ones(
+ [batch_size, 1, map_shape[0],
+ map_shape[1]]).to(junc_map.device).to(torch.int)
+ junc_map = torch.cat([junc_map*2, dust_bin_label], dim=1)
+ labels = torch.argmax(
+ junc_map.to(torch.float) +
+ torch.distributions.Uniform(0, 0.1).sample(junc_map.shape).to(junc_map.device),
+ dim=1)
+
+ # Also convert the valid mask to channel tensor
+ valid_mask = (torch.ones(junction_map.shape) if valid_mask is None
+ else valid_mask)
+ valid_mask = space_to_depth(valid_mask, grid_size)
+
+ # Compute junction loss on the border patch or not
+ if keep_border:
+ valid_mask = torch.sum(valid_mask.to(torch.bool).to(torch.int),
+ dim=1, keepdim=True) > 0
+ else:
+ valid_mask = torch.sum(valid_mask.to(torch.bool).to(torch.int),
+ dim=1, keepdim=True) >= grid_size * grid_size
+
+ # Compute the classification loss
+ loss_func = nn.CrossEntropyLoss(reduction="none")
+ # The loss still needs NCHW format
+ loss = loss_func(input=junc_predictions,
+ target=labels.to(torch.long))
+
+ # Weighted sum by the valid mask
+ loss_ = torch.sum(loss * torch.squeeze(valid_mask.to(torch.float),
+ dim=1), dim=[0, 1, 2])
+ loss_final = loss_ / torch.sum(torch.squeeze(valid_mask.to(torch.float),
+ dim=1))
+
+ return loss_final
+
+
+def heatmap_loss(heatmap_gt, heatmap_pred, valid_mask=None,
+ class_weight=None):
+ """ Heatmap prediction loss. """
+ # Compute the classification loss on each pixel
+ if class_weight is None:
+ loss_func = nn.CrossEntropyLoss(reduction="none")
+ else:
+ loss_func = nn.CrossEntropyLoss(class_weight, reduction="none")
+
+ loss = loss_func(input=heatmap_pred,
+ target=torch.squeeze(heatmap_gt.to(torch.long), dim=1))
+
+ # Weighted sum by the valid mask
+ # Sum over H and W
+ loss_spatial_sum = torch.sum(loss * torch.squeeze(
+ valid_mask.to(torch.float), dim=1), dim=[1, 2])
+ valid_spatial_sum = torch.sum(torch.squeeze(valid_mask.to(torch.float32),
+ dim=1), dim=[1, 2])
+ # Mean to single scalar over batch dimension
+ loss = torch.sum(loss_spatial_sum) / torch.sum(valid_spatial_sum)
+
+ return loss
+
+
+class JunctionDetectionLoss(nn.Module):
+ """ Junction detection loss. """
+ def __init__(self, grid_size, keep_border):
+ super(JunctionDetectionLoss, self).__init__()
+ self.grid_size = grid_size
+ self.keep_border = keep_border
+
+ def forward(self, prediction, target, valid_mask=None):
+ return junction_detection_loss(target, prediction, valid_mask,
+ self.grid_size, self.keep_border)
+
+
+class HeatmapLoss(nn.Module):
+ """ Heatmap prediction loss. """
+ def __init__(self, class_weight):
+ super(HeatmapLoss, self).__init__()
+ self.class_weight = class_weight
+
+ def forward(self, prediction, target, valid_mask=None):
+ return heatmap_loss(target, prediction, valid_mask, self.class_weight)
+
+
+class RegularizationLoss(nn.Module):
+ """ Module for regularization loss. """
+ def __init__(self):
+ super(RegularizationLoss, self).__init__()
+ self.name = "regularization_loss"
+ self.loss_init = torch.zeros([])
+
+ def forward(self, loss_weights):
+ # Place it to the same device
+ loss = self.loss_init.to(loss_weights["w_junc"].device)
+ for _, val in loss_weights.items():
+ if isinstance(val, nn.Parameter):
+ loss += val
+
+ return loss
+
+
+def triplet_loss(desc_pred1, desc_pred2, points1, points2, line_indices,
+ epoch, grid_size=8, dist_threshold=8,
+ init_dist_threshold=64, margin=1):
+ """ Regular triplet loss for descriptor learning. """
+ b_size, _, Hc, Wc = desc_pred1.size()
+ img_size = (Hc * grid_size, Wc * grid_size)
+ device = desc_pred1.device
+
+ # Extract valid keypoints
+ n_points = line_indices.size()[1]
+ valid_points = line_indices.bool().flatten()
+ n_correct_points = torch.sum(valid_points).item()
+ if n_correct_points == 0:
+ return torch.tensor(0., dtype=torch.float, device=device)
+
+ # Check which keypoints are too close to be matched
+ # dist_threshold is decreased at each epoch for easier training
+ dist_threshold = max(dist_threshold,
+ 2 * init_dist_threshold // (epoch + 1))
+ dist_mask = get_dist_mask(points1, points2, valid_points, dist_threshold)
+
+ # Additionally ban negative mining along the same line
+ common_line_mask = get_common_line_mask(line_indices, valid_points)
+ dist_mask = dist_mask | common_line_mask
+
+ # Convert the keypoints to a grid suitable for interpolation
+ grid1 = keypoints_to_grid(points1, img_size)
+ grid2 = keypoints_to_grid(points2, img_size)
+
+ # Extract the descriptors
+ desc1 = F.grid_sample(desc_pred1, grid1).permute(
+ 0, 2, 3, 1).reshape(b_size * n_points, -1)[valid_points]
+ desc1 = F.normalize(desc1, dim=1)
+ desc2 = F.grid_sample(desc_pred2, grid2).permute(
+ 0, 2, 3, 1).reshape(b_size * n_points, -1)[valid_points]
+ desc2 = F.normalize(desc2, dim=1)
+ desc_dists = 2 - 2 * (desc1 @ desc2.t())
+
+ # Positive distance loss
+ pos_dist = torch.diag(desc_dists)
+
+ # Negative distance loss
+ max_dist = torch.tensor(4., dtype=torch.float, device=device)
+ desc_dists[
+ torch.arange(n_correct_points, dtype=torch.long),
+ torch.arange(n_correct_points, dtype=torch.long)] = max_dist
+ desc_dists[dist_mask] = max_dist
+ neg_dist = torch.min(torch.min(desc_dists, dim=1)[0],
+ torch.min(desc_dists, dim=0)[0])
+
+ triplet_loss = F.relu(margin + pos_dist - neg_dist)
+ return triplet_loss, grid1, grid2, valid_points
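+
+# Illustrative note (not part of the original SOLD2 code): with the default
+# thresholds above, the negative-mining ban radius follows
+# max(8, 2 * 64 // (epoch + 1)): epoch 0 -> 128, epoch 3 -> 32, epoch 7 -> 16,
+# epoch >= 15 -> 8, so only clearly separated points serve as negatives early
+# in training.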
+
+
+class TripletDescriptorLoss(nn.Module):
+ """ Triplet descriptor loss. """
+ def __init__(self, grid_size, dist_threshold, margin):
+ super(TripletDescriptorLoss, self).__init__()
+ self.grid_size = grid_size
+ self.init_dist_threshold = 64
+ self.dist_threshold = dist_threshold
+ self.margin = margin
+
+ def forward(self, desc_pred1, desc_pred2, points1,
+ points2, line_indices, epoch):
+ return self.descriptor_loss(desc_pred1, desc_pred2, points1,
+ points2, line_indices, epoch)
+
+ # The descriptor loss based on regularly sampled points along the lines
+ def descriptor_loss(self, desc_pred1, desc_pred2, points1,
+ points2, line_indices, epoch):
+ return torch.mean(triplet_loss(
+ desc_pred1, desc_pred2, points1, points2, line_indices, epoch,
+ self.grid_size, self.dist_threshold, self.init_dist_threshold,
+ self.margin)[0])
+
+
+class TotalLoss(nn.Module):
+ """ Total loss summing junction, heatma, descriptor
+ and regularization losses. """
+ def __init__(self, loss_funcs, loss_weights, weighting_policy):
+ super(TotalLoss, self).__init__()
+ # Whether we need to compute the descriptor loss
+ self.compute_descriptors = "descriptor_loss" in loss_funcs.keys()
+
+ self.loss_funcs = loss_funcs
+ self.loss_weights = loss_weights
+ self.weighting_policy = weighting_policy
+
+ # Always add regularization loss (it will return zero if not used)
+ self.loss_funcs["reg_loss"] = RegularizationLoss().cuda()
+
+ def forward(self, junc_pred, junc_target, heatmap_pred,
+ heatmap_target, valid_mask=None):
+ """ Detection only loss. """
+ # Compute the junction loss
+ junc_loss = self.loss_funcs["junc_loss"](junc_pred, junc_target,
+ valid_mask)
+ # Compute the heatmap loss
+ heatmap_loss = self.loss_funcs["heatmap_loss"](
+ heatmap_pred, heatmap_target, valid_mask)
+
+ # Compute the total loss.
+ if self.weighting_policy == "dynamic":
+ reg_loss = self.loss_funcs["reg_loss"](self.loss_weights)
+ total_loss = junc_loss * torch.exp(-self.loss_weights["w_junc"]) + \
+ heatmap_loss * torch.exp(-self.loss_weights["w_heatmap"]) + \
+ reg_loss
+
+ return {
+ "total_loss": total_loss,
+ "junc_loss": junc_loss,
+ "heatmap_loss": heatmap_loss,
+ "reg_loss": reg_loss,
+ "w_junc": torch.exp(-self.loss_weights["w_junc"]).item(),
+ "w_heatmap": torch.exp(-self.loss_weights["w_heatmap"]).item(),
+ }
+
+ elif self.weighting_policy == "static":
+ total_loss = junc_loss * self.loss_weights["w_junc"] + \
+ heatmap_loss * self.loss_weights["w_heatmap"]
+
+ return {
+ "total_loss": total_loss,
+ "junc_loss": junc_loss,
+ "heatmap_loss": heatmap_loss
+ }
+
+ else:
+ raise ValueError("[Error] Unknown weighting policy.")
+
+ def forward_descriptors(self,
+ junc_map_pred1, junc_map_pred2, junc_map_target1,
+ junc_map_target2, heatmap_pred1, heatmap_pred2, heatmap_target1,
+ heatmap_target2, line_points1, line_points2, line_indices,
+ desc_pred1, desc_pred2, epoch, valid_mask1=None,
+ valid_mask2=None):
+ """ Loss for detection + description. """
+ # Compute junction loss
+ junc_loss = self.loss_funcs["junc_loss"](
+ torch.cat([junc_map_pred1, junc_map_pred2], dim=0),
+ torch.cat([junc_map_target1, junc_map_target2], dim=0),
+ torch.cat([valid_mask1, valid_mask2], dim=0)
+ )
+ # Get junction loss weight (dynamic or not)
+ if isinstance(self.loss_weights["w_junc"], nn.Parameter):
+ w_junc = torch.exp(-self.loss_weights["w_junc"])
+ else:
+ w_junc = self.loss_weights["w_junc"]
+
+ # Compute heatmap loss
+ heatmap_loss = self.loss_funcs["heatmap_loss"](
+ torch.cat([heatmap_pred1, heatmap_pred2], dim=0),
+ torch.cat([heatmap_target1, heatmap_target2], dim=0),
+ torch.cat([valid_mask1, valid_mask2], dim=0)
+ )
+ # Get heatmap loss weight (dynamic or not)
+ if isinstance(self.loss_weights["w_heatmap"], nn.Parameter):
+ w_heatmap = torch.exp(-self.loss_weights["w_heatmap"])
+ else:
+ w_heatmap = self.loss_weights["w_heatmap"]
+
+ # Compute the descriptor loss
+ descriptor_loss = self.loss_funcs["descriptor_loss"](
+ desc_pred1, desc_pred2, line_points1,
+ line_points2, line_indices, epoch)
+ # Get descriptor loss weight (dynamic or not)
+ if isinstance(self.loss_weights["w_desc"], nn.Parameter):
+ w_descriptor = torch.exp(-self.loss_weights["w_desc"])
+ else:
+ w_descriptor = self.loss_weights["w_desc"]
+
+ # Update the total loss
+ total_loss = (junc_loss * w_junc
+ + heatmap_loss * w_heatmap
+ + descriptor_loss * w_descriptor)
+ outputs = {
+ "junc_loss": junc_loss,
+ "heatmap_loss": heatmap_loss,
+ "w_junc": w_junc.item() \
+ if isinstance(w_junc, nn.Parameter) else w_junc,
+ "w_heatmap": w_heatmap.item() \
+ if isinstance(w_heatmap, nn.Parameter) else w_heatmap,
+ "descriptor_loss": descriptor_loss,
+ "w_desc": w_descriptor.item() \
+ if isinstance(w_descriptor, nn.Parameter) else w_descriptor
+ }
+
+ # Compute the regularization loss
+ reg_loss = self.loss_funcs["reg_loss"](self.loss_weights)
+ total_loss += reg_loss
+ outputs.update({
+ "reg_loss": reg_loss,
+ "total_loss": total_loss
+ })
+
+ return outputs
diff --git a/imcui/third_party/SOLD2/sold2/model/lr_scheduler.py b/imcui/third_party/SOLD2/sold2/model/lr_scheduler.py
new file mode 100644
index 0000000000000000000000000000000000000000..3faa4f68a67564719008a932b40c16c5e908949f
--- /dev/null
+++ b/imcui/third_party/SOLD2/sold2/model/lr_scheduler.py
@@ -0,0 +1,22 @@
+"""
+This file implements different learning rate schedulers
+"""
+import torch
+
+
+def get_lr_scheduler(lr_decay, lr_decay_cfg, optimizer):
+ """ Get the learning rate scheduler according to the config. """
+ # If no lr_decay is specified => return None
+ if (lr_decay == False) or (lr_decay_cfg is None):
+ scheduler = None
+ # Exponential decay
+ elif (lr_decay == True) and (lr_decay_cfg["policy"] == "exp"):
+ scheduler = torch.optim.lr_scheduler.ExponentialLR(
+ optimizer,
+ gamma=lr_decay_cfg["gamma"]
+ )
+ # Unknown policy
+ else:
+ raise ValueError("[Error] Unknown learning rate decay policy!")
+
+ return scheduler
\ No newline at end of file
diff --git a/imcui/third_party/SOLD2/sold2/model/metrics.py b/imcui/third_party/SOLD2/sold2/model/metrics.py
new file mode 100644
index 0000000000000000000000000000000000000000..0894a7207ee4afa344cb332c605c715b14db73a4
--- /dev/null
+++ b/imcui/third_party/SOLD2/sold2/model/metrics.py
@@ -0,0 +1,528 @@
+"""
+This file implements the evaluation metrics.
+"""
+import torch
+import torch.nn.functional as F
+import numpy as np
+from torchvision.ops.boxes import batched_nms
+
+from ..misc.geometry_utils import keypoints_to_grid
+
+
+class Metrics(object):
+ """ Metric evaluation calculator. """
+ def __init__(self, detection_thresh, prob_thresh, grid_size,
+ junc_metric_lst=None, heatmap_metric_lst=None,
+ pr_metric_lst=None, desc_metric_lst=None):
+ # List supported metrics
+ self.supported_junc_metrics = ["junc_precision", "junc_precision_nms",
+ "junc_recall", "junc_recall_nms"]
+ self.supported_heatmap_metrics = ["heatmap_precision",
+ "heatmap_recall"]
+ self.supported_pr_metrics = ["junc_pr", "junc_nms_pr"]
+ self.supported_desc_metrics = ["matching_score"]
+
+ # If metric_lst is None, default to use all metrics
+ if junc_metric_lst is None:
+ self.junc_metric_lst = self.supported_junc_metrics
+ else:
+ self.junc_metric_lst = junc_metric_lst
+ if heatmap_metric_lst is None:
+ self.heatmap_metric_lst = self.supported_heatmap_metrics
+ else:
+ self.heatmap_metric_lst = heatmap_metric_lst
+ if pr_metric_lst is None:
+ self.pr_metric_lst = self.supported_pr_metrics
+ else:
+ self.pr_metric_lst = pr_metric_lst
+ # For the descriptors, the default None assumes no desc metric at all
+ if desc_metric_lst is None:
+ self.desc_metric_lst = []
+ elif desc_metric_lst == 'all':
+ self.desc_metric_lst = self.supported_desc_metrics
+ else:
+ self.desc_metric_lst = desc_metric_lst
+
+ if not self._check_metrics():
+ raise ValueError(
+ "[Error] Some elements in the metric_lst are invalid.")
+
+ # Metric mapping table
+ self.metric_table = {
+ "junc_precision": junction_precision(detection_thresh),
+ "junc_precision_nms": junction_precision(detection_thresh),
+ "junc_recall": junction_recall(detection_thresh),
+ "junc_recall_nms": junction_recall(detection_thresh),
+ "heatmap_precision": heatmap_precision(prob_thresh),
+ "heatmap_recall": heatmap_recall(prob_thresh),
+ "junc_pr": junction_pr(),
+ "junc_nms_pr": junction_pr(),
+ "matching_score": matching_score(grid_size)
+ }
+
+ # Initialize the results
+ self.metric_results = {}
+ for key in self.metric_table.keys():
+ self.metric_results[key] = 0.
+
+ def evaluate(self, junc_pred, junc_pred_nms, junc_gt, heatmap_pred,
+ heatmap_gt, valid_mask, line_points1=None, line_points2=None,
+ desc_pred1=None, desc_pred2=None, valid_points=None):
+ """ Perform evaluation. """
+ for metric in self.junc_metric_lst:
+ # If nms metrics then use nms to compute it.
+ if "nms" in metric:
+ junc_pred_input = junc_pred_nms
+ # Use normal inputs instead.
+ else:
+ junc_pred_input = junc_pred
+ self.metric_results[metric] = self.metric_table[metric](
+ junc_pred_input, junc_gt, valid_mask)
+
+ for metric in self.heatmap_metric_lst:
+ self.metric_results[metric] = self.metric_table[metric](
+ heatmap_pred, heatmap_gt, valid_mask)
+
+ for metric in self.pr_metric_lst:
+ if "nms" in metric:
+ self.metric_results[metric] = self.metric_table[metric](
+ junc_pred_nms, junc_gt, valid_mask)
+ else:
+ self.metric_results[metric] = self.metric_table[metric](
+ junc_pred, junc_gt, valid_mask)
+
+ for metric in self.desc_metric_lst:
+ self.metric_results[metric] = self.metric_table[metric](
+ line_points1, line_points2, desc_pred1,
+ desc_pred2, valid_points)
+
+ def _check_metrics(self):
+ """ Check if all input metrics are valid. """
+ flag = True
+ for metric in self.junc_metric_lst:
+ if not metric in self.supported_junc_metrics:
+ flag = False
+ break
+ for metric in self.heatmap_metric_lst:
+ if not metric in self.supported_heatmap_metrics:
+ flag = False
+ break
+ for metric in self.desc_metric_lst:
+ if not metric in self.supported_desc_metrics:
+ flag = False
+ break
+
+ return flag
+
+
+class AverageMeter(object):
+ def __init__(self, junc_metric_lst=None, heatmap_metric_lst=None,
+ is_training=True, desc_metric_lst=None):
+ # List supported metrics
+ self.supported_junc_metrics = ["junc_precision", "junc_precision_nms",
+ "junc_recall", "junc_recall_nms"]
+ self.supported_heatmap_metrics = ["heatmap_precision",
+ "heatmap_recall"]
+ self.supported_pr_metrics = ["junc_pr", "junc_nms_pr"]
+ self.supported_desc_metrics = ["matching_score"]
+ # Record loss in training mode
+ # if is_training:
+ self.supported_loss = [
+ "junc_loss", "heatmap_loss", "descriptor_loss", "total_loss"]
+
+ self.is_training = is_training
+
+ # If metric_lst is None, default to use all metrics
+ if junc_metric_lst is None:
+ self.junc_metric_lst = self.supported_junc_metrics
+ else:
+ self.junc_metric_lst = junc_metric_lst
+ if heatmap_metric_lst is None:
+ self.heatmap_metric_lst = self.supported_heatmap_metrics
+ else:
+ self.heatmap_metric_lst = heatmap_metric_lst
+ # For the descriptors, the default None assumes no desc metric at all
+ if desc_metric_lst is None:
+ self.desc_metric_lst = []
+ elif desc_metric_lst == 'all':
+ self.desc_metric_lst = self.supported_desc_metrics
+ else:
+ self.desc_metric_lst = desc_metric_lst
+
+ if not self._check_metrics():
+ raise ValueError(
+ "[Error] Some elements in the metric_lst are invalid.")
+
+ # Initialize the results
+ self.metric_results = {}
+ for key in (self.supported_junc_metrics
+ + self.supported_heatmap_metrics
+ + self.supported_loss + self.supported_desc_metrics):
+ self.metric_results[key] = 0.
+ for key in self.supported_pr_metrics:
+ # Use a separate list per counter so the entries do not alias each other
+ self.metric_results[key] = {
+ "tp": [0] * 50,
+ "tn": [0] * 50,
+ "fp": [0] * 50,
+ "fn": [0] * 50,
+ "precision": [0] * 50,
+ "recall": [0] * 50
+ }
+
+ # Initialize total count
+ self.count = 0
+
+ def update(self, metrics, loss_dict=None, num_samples=1):
+ # loss should be given in the training mode
+ if self.is_training and (loss_dict is None):
+ raise ValueError(
+ "[Error] loss info should be given in the training mode.")
+
+ # update total counts
+ self.count += num_samples
+
+ # update all the metrics
+ for met in (self.supported_junc_metrics
+ + self.supported_heatmap_metrics
+ + self.supported_desc_metrics):
+ self.metric_results[met] += (num_samples
+ * metrics.metric_results[met])
+
+ # Update all the losses
+ for loss in loss_dict.keys():
+ self.metric_results[loss] += num_samples * loss_dict[loss]
+
+ # Update all pr counts
+ for pr_met in self.supported_pr_metrics:
+ # Update all tp, tn, fp, fn, precision, and recall.
+ for key in metrics.metric_results[pr_met].keys():
+ # Update each interval
+ for idx in range(len(self.metric_results[pr_met][key])):
+ self.metric_results[pr_met][key][idx] += (
+ num_samples
+ * metrics.metric_results[pr_met][key][idx])
+
+ def average(self):
+ results = {}
+ for met in self.metric_results.keys():
+ # Skip pr curve metrics
+ if not met in self.supported_pr_metrics:
+ results[met] = self.metric_results[met] / self.count
+ # Only update precision and recall in pr metrics
+ else:
+ met_results = {
+ "tp": self.metric_results[met]["tp"],
+ "tn": self.metric_results[met]["tn"],
+ "fp": self.metric_results[met]["fp"],
+ "fn": self.metric_results[met]["fn"],
+ "precision": [],
+ "recall": []
+ }
+ for idx in range(len(self.metric_results[met]["precision"])):
+ met_results["precision"].append(
+ self.metric_results[met]["precision"][idx]
+ / self.count)
+ met_results["recall"].append(
+ self.metric_results[met]["recall"][idx] / self.count)
+
+ results[met] = met_results
+
+ return results
+
+ def _check_metrics(self):
+ """ Check if all input metrics are valid. """
+ flag = True
+ for metric in self.junc_metric_lst:
+ if not metric in self.supported_junc_metrics:
+ flag = False
+ break
+ for metric in self.heatmap_metric_lst:
+ if not metric in self.supported_heatmap_metrics:
+ flag = False
+ break
+ for metric in self.desc_metric_lst:
+ if not metric in self.supported_desc_metrics:
+ flag = False
+ break
+
+ return flag
+
+
+class junction_precision(object):
+ """ Junction precision. """
+ def __init__(self, detection_thresh):
+ self.detection_thresh = detection_thresh
+
+ # Compute the evaluation result
+ def __call__(self, junc_pred, junc_gt, valid_mask):
+ # Convert prediction to discrete detection
+ junc_pred = (junc_pred >= self.detection_thresh).astype(int)
+ junc_pred = junc_pred * valid_mask.squeeze()
+
+ # Deal with the corner case of the prediction
+ if np.sum(junc_pred) > 0:
+ precision = (np.sum(junc_pred * junc_gt.squeeze())
+ / np.sum(junc_pred))
+ else:
+ precision = 0
+
+ return float(precision)
+
+
+class junction_recall(object):
+ """ Junction recall. """
+ def __init__(self, detection_thresh):
+ self.detection_thresh = detection_thresh
+
+ # Compute the evaluation result
+ def __call__(self, junc_pred, junc_gt, valid_mask):
+ # Convert prediction to discrete detection
+ junc_pred = (junc_pred >= self.detection_thresh).astype(int)
+ junc_pred = junc_pred * valid_mask.squeeze()
+
+ # Deal with the corner case of the recall.
+ if np.sum(junc_gt):
+ recall = np.sum(junc_pred * junc_gt.squeeze()) / np.sum(junc_gt)
+ else:
+ recall = 0
+
+ return float(recall)
+
+
+class junction_pr(object):
+ """ Junction precision-recall info. """
+ def __init__(self, num_threshold=50):
+ self.max = 0.4
+ step = self.max / num_threshold
+ self.min = step
+ self.intervals = np.flip(np.arange(self.min, self.max + step, step))
+
+ def __call__(self, junc_pred_raw, junc_gt, valid_mask):
+ tp_lst = []
+ fp_lst = []
+ tn_lst = []
+ fn_lst = []
+ precision_lst = []
+ recall_lst = []
+
+ valid_mask = valid_mask.squeeze()
+ # Iterate through all the thresholds
+ for thresh in list(self.intervals):
+ # Convert prediction to discrete detection
+ junc_pred = (junc_pred_raw >= thresh).astype(int)
+ junc_pred = junc_pred * valid_mask
+
+ # Compute tp, fp, tn, fn
+ junc_gt = junc_gt.squeeze()
+ tp = np.sum(junc_pred * junc_gt)
+ tn = np.sum((junc_pred == 0).astype(float)
+ * (junc_gt == 0).astype(float) * valid_mask)
+ fp = np.sum((junc_pred == 1).astype(float)
+ * (junc_gt == 0).astype(float) * valid_mask)
+ fn = np.sum((junc_pred == 0).astype(float)
+ * (junc_gt == 1).astype(float) * valid_mask)
+
+ tp_lst.append(tp)
+ tn_lst.append(tn)
+ fp_lst.append(fp)
+ fn_lst.append(fn)
+ precision_lst.append(tp / (tp + fp))
+ recall_lst.append(tp / (tp + fn))
+
+ return {
+ "tp": np.array(tp_lst),
+ "tn": np.array(tn_lst),
+ "fp": np.array(fp_lst),
+ "fn": np.array(fn_lst),
+ "precision": np.array(precision_lst),
+ "recall": np.array(recall_lst)
+ }
+
+
+class heatmap_precision(object):
+ """ Heatmap precision. """
+ def __init__(self, prob_thresh):
+ self.prob_thresh = prob_thresh
+
+ def __call__(self, heatmap_pred, heatmap_gt, valid_mask):
+ # Assume NHWC (Handle L1 and L2 cases) NxHxWx1
+ heatmap_pred = np.squeeze(heatmap_pred > self.prob_thresh)
+ heatmap_pred = heatmap_pred * valid_mask.squeeze()
+
+ # Deal with the corner case of the prediction
+ if np.sum(heatmap_pred) > 0:
+ precision = (np.sum(heatmap_pred * heatmap_gt.squeeze())
+ / np.sum(heatmap_pred))
+ else:
+ precision = 0.
+
+ return precision
+
+
+class heatmap_recall(object):
+ """ Heatmap recall. """
+ def __init__(self, prob_thresh):
+ self.prob_thresh = prob_thresh
+
+ def __call__(self, heatmap_pred, heatmap_gt, valid_mask):
+ # Assume NHWC (Handle L1 and L2 cases) NxHxWx1
+ heatmap_pred = np.squeeze(heatmap_pred > self.prob_thresh)
+ heatmap_pred = heatmap_pred * valid_mask.squeeze()
+
+ # Deal with the corner case of the ground truth
+ if np.sum(heatmap_gt) > 0:
+ recall = (np.sum(heatmap_pred * heatmap_gt.squeeze())
+ / np.sum(heatmap_gt))
+ else:
+ recall = 0.
+
+ return recall
+
+
+class matching_score(object):
+ """ Descriptors matching score. """
+ def __init__(self, grid_size):
+ self.grid_size = grid_size
+
+ def __call__(self, points1, points2, desc_pred1,
+ desc_pred2, line_indices):
+ b_size, _, Hc, Wc = desc_pred1.size()
+ img_size = (Hc * self.grid_size, Wc * self.grid_size)
+ device = desc_pred1.device
+
+ # Extract valid keypoints
+ n_points = line_indices.size()[1]
+ valid_points = line_indices.bool().flatten()
+ n_correct_points = torch.sum(valid_points).item()
+ if n_correct_points == 0:
+ return torch.tensor(0., dtype=torch.float, device=device)
+
+ # Convert the keypoints to a grid suitable for interpolation
+ grid1 = keypoints_to_grid(points1, img_size)
+ grid2 = keypoints_to_grid(points2, img_size)
+
+ # Extract the descriptors
+ desc1 = F.grid_sample(desc_pred1, grid1).permute(
+ 0, 2, 3, 1).reshape(b_size * n_points, -1)[valid_points]
+ desc1 = F.normalize(desc1, dim=1)
+ desc2 = F.grid_sample(desc_pred2, grid2).permute(
+ 0, 2, 3, 1).reshape(b_size * n_points, -1)[valid_points]
+ desc2 = F.normalize(desc2, dim=1)
+ desc_dists = 2 - 2 * (desc1 @ desc2.t())
+
+ # Compute percentage of correct matches
+ matches0 = torch.min(desc_dists, dim=1)[1]
+ matches1 = torch.min(desc_dists, dim=0)[1]
+ matching_score = (matches1[matches0]
+ == torch.arange(len(matches0)).to(device))
+ matching_score = matching_score.float().mean()
+ return matching_score
+
+
+def super_nms(prob_predictions, dist_thresh, prob_thresh=0.01, top_k=0):
+ """ Non-maximum suppression adapted from SuperPoint. """
+ # Iterate through batch dimension
+ im_h = prob_predictions.shape[1]
+ im_w = prob_predictions.shape[2]
+ output_lst = []
+ for i in range(prob_predictions.shape[0]):
+ # print(i)
+ prob_pred = prob_predictions[i, ...]
+ # Filter the points using prob_thresh
+ coord = np.where(prob_pred >= prob_thresh) # HW format
+ points = np.concatenate((coord[0][..., None], coord[1][..., None]),
+ axis=1) # HW format
+
+ # Get the probability score
+ prob_score = prob_pred[points[:, 0], points[:, 1]]
+
+ # Perform super nms
+ # Modify the in_points to xy format (instead of HW format)
+ in_points = np.concatenate((coord[1][..., None], coord[0][..., None],
+ prob_score[..., None]), axis=1).T
+ keep_points_, keep_inds = nms_fast(in_points, im_h, im_w, dist_thresh)
+ # Remember to flip outputs back to HW format
+ keep_points = np.round(np.flip(keep_points_[:2, :], axis=0).T)
+ keep_score = keep_points_[-1, :].T
+
+ # Whether we only keep the topk value
+ if (top_k is not None) and (top_k > 0):
+ k = min([keep_points.shape[0], top_k])
+ keep_points = keep_points[:k, :]
+ keep_score = keep_score[:k]
+
+ # Re-compose the probability map
+ output_map = np.zeros([im_h, im_w])
+ output_map[keep_points[:, 0].astype(int),
+ keep_points[:, 1].astype(int)] = keep_score.squeeze()
+
+ output_lst.append(output_map[None, ...])
+
+ return np.concatenate(output_lst, axis=0)
+
+
+def nms_fast(in_corners, H, W, dist_thresh):
+ """
+ Run a faster approximate Non-Max-Suppression on numpy corners shaped:
+ 3xN [x_i,y_i,conf_i]^T
+
+ Algo summary: Create a grid sized HxW. Assign each corner location a 1,
+ rest are zeros. Iterate through all the 1's and convert them to -1 or 0.
+ Suppress points by setting nearby values to 0.
+
+ Grid Value Legend:
+ -1 : Kept.
+ 0 : Empty or suppressed.
+ 1 : To be processed (converted to either kept or suppressed).
+
+ NOTE: The NMS first rounds points to integers, so NMS distance might not
+ be exactly dist_thresh. It also assumes points are within image boundary.
+
+ Inputs
+ in_corners - 3xN numpy array with corners [x_i, y_i, confidence_i]^T.
+ H - Image height.
+ W - Image width.
+ dist_thresh - Distance to suppress, measured as an infinity norm distance.
+ Returns
+ nmsed_corners - 3xN numpy matrix with surviving corners.
+ nmsed_inds - N length numpy vector with surviving corner indices.
+ """
+ grid = np.zeros((H, W)).astype(int) # Track NMS data.
+ inds = np.zeros((H, W)).astype(int) # Store indices of points.
+ # Sort by confidence and round to nearest int.
+ inds1 = np.argsort(-in_corners[2, :])
+ corners = in_corners[:, inds1]
+ rcorners = corners[:2, :].round().astype(int) # Rounded corners.
+ # Check for edge case of 0 or 1 corners.
+ if rcorners.shape[1] == 0:
+ return np.zeros((3, 0)).astype(int), np.zeros(0).astype(int)
+ if rcorners.shape[1] == 1:
+ out = np.vstack((rcorners, in_corners[2])).reshape(3, 1)
+ return out, np.zeros((1)).astype(int)
+ # Initialize the grid.
+ for i, rc in enumerate(rcorners.T):
+ grid[rcorners[1, i], rcorners[0, i]] = 1
+ inds[rcorners[1, i], rcorners[0, i]] = i
+ # Pad the border of the grid, so that we can NMS points near the border.
+ pad = dist_thresh
+ grid = np.pad(grid, ((pad, pad), (pad, pad)), mode='constant')
+ # Iterate through points, highest to lowest conf, suppress neighborhood.
+ count = 0
+ for i, rc in enumerate(rcorners.T):
+ # Account for top and left padding.
+ pt = (rc[0] + pad, rc[1] + pad)
+ if grid[pt[1], pt[0]] == 1: # If not yet suppressed.
+ grid[pt[1] - pad:pt[1] + pad + 1, pt[0] - pad:pt[0] + pad + 1] = 0
+ grid[pt[1], pt[0]] = -1
+ count += 1
+ # Get all surviving -1's and return sorted array of remaining corners.
+ keepy, keepx = np.where(grid == -1)
+ keepy, keepx = keepy - pad, keepx - pad
+ inds_keep = inds[keepy, keepx]
+ out = corners[:, inds_keep]
+ values = out[-1, :]
+ inds2 = np.argsort(-values)
+ out = out[:, inds2]
+ out_inds = inds1[inds_keep[inds2]]
+ return out, out_inds
diff --git a/imcui/third_party/SOLD2/sold2/model/model_util.py b/imcui/third_party/SOLD2/sold2/model/model_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..f70d80da40a72c207edfcfc1509e820846f0b731
--- /dev/null
+++ b/imcui/third_party/SOLD2/sold2/model/model_util.py
@@ -0,0 +1,203 @@
+import torch
+import torch.nn as nn
+import torch.nn.init as init
+
+from .nets.backbone import HourglassBackbone, SuperpointBackbone
+from .nets.junction_decoder import SuperpointDecoder
+from .nets.heatmap_decoder import PixelShuffleDecoder
+from .nets.descriptor_decoder import SuperpointDescriptor
+
+
+def get_model(model_cfg=None, loss_weights=None, mode="train"):
+ """ Get model based on the model configuration. """
+ # Check dataset config is given
+ if model_cfg is None:
+ raise ValueError("[Error] The model config is required!")
+
+ # List the supported options here
+ print("\n\n\t--------Initializing model----------")
+ supported_arch = ["simple"]
+ if model_cfg["model_architecture"] not in supported_arch:
+ raise ValueError(
+ "[Error] The model architecture is not in supported arch!")
+
+ if model_cfg["model_architecture"] == "simple":
+ model = SOLD2Net(model_cfg)
+ else:
+ raise ValueError(
+ "[Error] The model architecture is not in supported arch!")
+
+ # Optionally register loss weights to the model
+ if mode == "train":
+ if loss_weights is not None:
+ for param_name, param in loss_weights.items():
+ if isinstance(param, nn.Parameter):
+ print("\t [Debug] Adding %s with value %f to model"
+ % (param_name, param.item()))
+ model.register_parameter(param_name, param)
+ else:
+ raise ValueError(
+ "[Error] The loss weights cannot be None in dynamic "
+ "weighting mode during training.")
+
+ # Display some summary info.
+ print("\tModel architecture: %s" % model_cfg["model_architecture"])
+ print("\tBackbone: %s" % model_cfg["backbone"])
+ print("\tJunction decoder: %s" % model_cfg["junction_decoder"])
+ print("\tHeatmap decoder: %s" % model_cfg["heatmap_decoder"])
+ print("\t-------------------------------------")
+
+ return model
+
+
+class SOLD2Net(nn.Module):
+ """ Full network for SOLD². """
+ def __init__(self, model_cfg):
+ super(SOLD2Net, self).__init__()
+ self.name = model_cfg["model_name"]
+ self.cfg = model_cfg
+
+ # List supported network options
+ self.supported_backbone = ["lcnn", "superpoint"]
+ self.backbone_net, self.feat_channel = self.get_backbone()
+
+ # List supported junction decoder options
+ self.supported_junction_decoder = ["superpoint_decoder"]
+ self.junction_decoder = self.get_junction_decoder()
+
+ # List supported heatmap decoder options
+ self.supported_heatmap_decoder = ["pixel_shuffle",
+ "pixel_shuffle_single"]
+ self.heatmap_decoder = self.get_heatmap_decoder()
+
+ # List supported descriptor decoder options
+ if "descriptor_decoder" in self.cfg:
+ self.supported_descriptor_decoder = ["superpoint_descriptor"]
+ self.descriptor_decoder = self.get_descriptor_decoder()
+
+ # Initialize the model weights
+ self.apply(weight_init)
+
+ def forward(self, input_images):
+ # The backbone
+ features = self.backbone_net(input_images)
+
+ # junction decoder
+ junctions = self.junction_decoder(features)
+
+ # heatmap decoder
+ heatmaps = self.heatmap_decoder(features)
+
+ outputs = {"junctions": junctions, "heatmap": heatmaps}
+
+ # Descriptor decoder
+ if "descriptor_decoder" in self.cfg:
+ outputs["descriptors"] = self.descriptor_decoder(features)
+
+ return outputs
+
+ def get_backbone(self):
+ """ Retrieve the backbone encoder network. """
+ if not self.cfg["backbone"] in self.supported_backbone:
+ raise ValueError(
+ "[Error] The backbone selection is not supported.")
+
+ # lcnn backbone (stacked hourglass)
+ if self.cfg["backbone"] == "lcnn":
+ backbone_cfg = self.cfg["backbone_cfg"]
+ backbone = HourglassBackbone(**backbone_cfg)
+ feat_channel = 256
+
+ elif self.cfg["backbone"] == "superpoint":
+ backbone_cfg = self.cfg["backbone_cfg"]
+ backbone = SuperpointBackbone()
+ feat_channel = 128
+
+ else:
+ raise ValueError(
+ "[Error] The backbone selection is not supported.")
+
+ return backbone, feat_channel
+
+ def get_junction_decoder(self):
+ """ Get the junction decoder. """
+ if (not self.cfg["junction_decoder"]
+ in self.supported_junction_decoder):
+ raise ValueError(
+ "[Error] The junction decoder selection is not supported.")
+
+ # superpoint decoder
+ if self.cfg["junction_decoder"] == "superpoint_decoder":
+ decoder = SuperpointDecoder(self.feat_channel,
+ self.cfg["backbone"])
+ else:
+ raise ValueError(
+ "[Error] The junction decoder selection is not supported.")
+
+ return decoder
+
+ def get_heatmap_decoder(self):
+ """ Get the heatmap decoder. """
+ if not self.cfg["heatmap_decoder"] in self.supported_heatmap_decoder:
+ raise ValueError(
+ "[Error] The heatmap decoder selection is not supported.")
+
+ # Pixel_shuffle decoder
+ if self.cfg["heatmap_decoder"] == "pixel_shuffle":
+ if self.cfg["backbone"] == "lcnn":
+ decoder = PixelShuffleDecoder(self.feat_channel,
+ num_upsample=2)
+ elif self.cfg["backbone"] == "superpoint":
+ decoder = PixelShuffleDecoder(self.feat_channel,
+ num_upsample=3)
+ else:
+ raise ValueError("[Error] Unknown backbone option.")
+ # Pixel_shuffle decoder with single channel output
+ elif self.cfg["heatmap_decoder"] == "pixel_shuffle_single":
+ if self.cfg["backbone"] == "lcnn":
+ decoder = PixelShuffleDecoder(
+ self.feat_channel, num_upsample=2, output_channel=1)
+ elif self.cfg["backbone"] == "superpoint":
+ decoder = PixelShuffleDecoder(
+ self.feat_channel, num_upsample=3, output_channel=1)
+ else:
+ raise ValueError("[Error] Unknown backbone option.")
+ else:
+ raise ValueError(
+ "[Error] The heatmap decoder selection is not supported.")
+
+ return decoder
+
+ def get_descriptor_decoder(self):
+ """ Get the descriptor decoder. """
+ if (not self.cfg["descriptor_decoder"]
+ in self.supported_descriptor_decoder):
+ raise ValueError(
+ "[Error] The descriptor decoder selection is not supported.")
+
+ # SuperPoint descriptor
+ if self.cfg["descriptor_decoder"] == "superpoint_descriptor":
+ decoder = SuperpointDescriptor(self.feat_channel)
+ else:
+ raise ValueError(
+ "[Error] The descriptor decoder selection is not supported.")
+
+ return decoder
+
+
+def weight_init(m):
+ """ Weight initialization function. """
+ # Conv2D
+ if isinstance(m, nn.Conv2d):
+ init.xavier_normal_(m.weight.data)
+ if m.bias is not None:
+ init.normal_(m.bias.data)
+ # Batchnorm
+ elif isinstance(m, nn.BatchNorm2d):
+ init.normal_(m.weight.data, mean=1, std=0.02)
+ init.constant_(m.bias.data, 0)
+ # Linear
+ elif isinstance(m, nn.Linear):
+ init.xavier_normal_(m.weight.data)
+ init.normal_(m.bias.data)
+ else:
+ pass
diff --git a/imcui/third_party/SOLD2/sold2/model/nets/__init__.py b/imcui/third_party/SOLD2/sold2/model/nets/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/imcui/third_party/SOLD2/sold2/model/nets/backbone.py b/imcui/third_party/SOLD2/sold2/model/nets/backbone.py
new file mode 100644
index 0000000000000000000000000000000000000000..71f260aef108c77d54319cab7bc082c3c51112e7
--- /dev/null
+++ b/imcui/third_party/SOLD2/sold2/model/nets/backbone.py
@@ -0,0 +1,65 @@
+import torch
+import torch.nn as nn
+
+from .lcnn_hourglass import MultitaskHead, hg
+
+
+class HourglassBackbone(nn.Module):
+ """ Hourglass backbone. """
+ def __init__(self, input_channel=1, depth=4, num_stacks=2,
+ num_blocks=1, num_classes=5):
+ super(HourglassBackbone, self).__init__()
+ self.head = MultitaskHead
+ self.net = hg(**{
+ "head": self.head,
+ "depth": depth,
+ "num_stacks": num_stacks,
+ "num_blocks": num_blocks,
+ "num_classes": num_classes,
+ "input_channels": input_channel
+ })
+
+ def forward(self, input_images):
+ return self.net(input_images)[1]
+
+
+class SuperpointBackbone(nn.Module):
+ """ SuperPoint backbone. """
+ def __init__(self):
+ super(SuperpointBackbone, self).__init__()
+ self.relu = torch.nn.ReLU(inplace=True)
+ self.pool = torch.nn.MaxPool2d(kernel_size=2, stride=2)
+ c1, c2, c3, c4 = 64, 64, 128, 128
+ # Shared Encoder.
+ self.conv1a = torch.nn.Conv2d(1, c1, kernel_size=3,
+ stride=1, padding=1)
+ self.conv1b = torch.nn.Conv2d(c1, c1, kernel_size=3,
+ stride=1, padding=1)
+ self.conv2a = torch.nn.Conv2d(c1, c2, kernel_size=3,
+ stride=1, padding=1)
+ self.conv2b = torch.nn.Conv2d(c2, c2, kernel_size=3,
+ stride=1, padding=1)
+ self.conv3a = torch.nn.Conv2d(c2, c3, kernel_size=3,
+ stride=1, padding=1)
+ self.conv3b = torch.nn.Conv2d(c3, c3, kernel_size=3,
+ stride=1, padding=1)
+ self.conv4a = torch.nn.Conv2d(c3, c4, kernel_size=3,
+ stride=1, padding=1)
+ self.conv4b = torch.nn.Conv2d(c4, c4, kernel_size=3,
+ stride=1, padding=1)
+
+ def forward(self, input_images):
+ # Shared Encoder.
+ x = self.relu(self.conv1a(input_images))
+ x = self.relu(self.conv1b(x))
+ x = self.pool(x)
+ x = self.relu(self.conv2a(x))
+ x = self.relu(self.conv2b(x))
+ x = self.pool(x)
+ x = self.relu(self.conv3a(x))
+ x = self.relu(self.conv3b(x))
+ x = self.pool(x)
+ x = self.relu(self.conv4a(x))
+ x = self.relu(self.conv4b(x))
+
+ return x
diff --git a/imcui/third_party/SOLD2/sold2/model/nets/descriptor_decoder.py b/imcui/third_party/SOLD2/sold2/model/nets/descriptor_decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..6ed4306fad764efab2c22ede9cae253c9b17d6c2
--- /dev/null
+++ b/imcui/third_party/SOLD2/sold2/model/nets/descriptor_decoder.py
@@ -0,0 +1,19 @@
+import torch
+import torch.nn as nn
+
+
+class SuperpointDescriptor(nn.Module):
+ """ Descriptor decoder based on the SuperPoint arcihtecture. """
+ def __init__(self, input_feat_dim=128):
+ super(SuperpointDescriptor, self).__init__()
+ self.relu = torch.nn.ReLU(inplace=True)
+ self.convPa = torch.nn.Conv2d(input_feat_dim, 256, kernel_size=3,
+ stride=1, padding=1)
+ self.convPb = torch.nn.Conv2d(256, 128, kernel_size=1,
+ stride=1, padding=0)
+
+ def forward(self, input_features):
+ feat = self.relu(self.convPa(input_features))
+ semi = self.convPb(feat)
+
+ return semi
\ No newline at end of file
diff --git a/imcui/third_party/SOLD2/sold2/model/nets/heatmap_decoder.py b/imcui/third_party/SOLD2/sold2/model/nets/heatmap_decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..bd5157ca740c8c7e25f2183b2a3c1fefa813deca
--- /dev/null
+++ b/imcui/third_party/SOLD2/sold2/model/nets/heatmap_decoder.py
@@ -0,0 +1,59 @@
+import torch.nn as nn
+
+
+class PixelShuffleDecoder(nn.Module):
+ """ Pixel shuffle decoder. """
+ def __init__(self, input_feat_dim=128, num_upsample=2, output_channel=2):
+ super(PixelShuffleDecoder, self).__init__()
+ # Get channel parameters
+ self.channel_conf = self.get_channel_conf(num_upsample)
+
+ # Define the pixel shuffle
+ self.pixshuffle = nn.PixelShuffle(2)
+
+ # Process the feature
+ self.conv_block_lst = []
+ # The input block
+ self.conv_block_lst.append(
+ nn.Sequential(
+ nn.Conv2d(input_feat_dim, self.channel_conf[0],
+ kernel_size=3, stride=1, padding=1),
+ nn.BatchNorm2d(self.channel_conf[0]),
+ nn.ReLU(inplace=True)
+ ))
+
+ # Intermediate block
+ for channel in self.channel_conf[1:-1]:
+ self.conv_block_lst.append(
+ nn.Sequential(
+ nn.Conv2d(channel, channel, kernel_size=3,
+ stride=1, padding=1),
+ nn.BatchNorm2d(channel),
+ nn.ReLU(inplace=True)
+ ))
+
+ # Output block
+ self.conv_block_lst.append(
+ nn.Conv2d(self.channel_conf[-1], output_channel,
+ kernel_size=1, stride=1, padding=0)
+ )
+ self.conv_block_lst = nn.ModuleList(self.conv_block_lst)
+
+ # Get num of channels based on number of upsampling.
+ def get_channel_conf(self, num_upsample):
+ if num_upsample == 2:
+ return [256, 64, 16]
+ elif num_upsample == 3:
+ return [256, 64, 16, 4]
+
+ def forward(self, input_features):
+ # Iterate until the output block
+ out = input_features
+ for block in self.conv_block_lst[:-1]:
+ out = block(out)
+ out = self.pixshuffle(out)
+
+ # Output layer
+ out = self.conv_block_lst[-1](out)
+
+ return out
diff --git a/imcui/third_party/SOLD2/sold2/model/nets/junction_decoder.py b/imcui/third_party/SOLD2/sold2/model/nets/junction_decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..d2bb649518896501c784940028a772d688c2b3a7
--- /dev/null
+++ b/imcui/third_party/SOLD2/sold2/model/nets/junction_decoder.py
@@ -0,0 +1,27 @@
+import torch
+import torch.nn as nn
+
+
+class SuperpointDecoder(nn.Module):
+ """ Junction decoder based on the SuperPoint architecture. """
+ def __init__(self, input_feat_dim=128, backbone_name="lcnn"):
+ super(SuperpointDecoder, self).__init__()
+ self.relu = torch.nn.ReLU(inplace=True)
+ # Perform strided convolution when using lcnn backbone.
+ if backbone_name == "lcnn":
+ self.convPa = torch.nn.Conv2d(input_feat_dim, 256, kernel_size=3,
+ stride=2, padding=1)
+ elif backbone_name == "superpoint":
+ self.convPa = torch.nn.Conv2d(input_feat_dim, 256, kernel_size=3,
+ stride=1, padding=1)
+ else:
+ raise ValueError("[Error] Unknown backbone option.")
+
+ self.convPb = torch.nn.Conv2d(256, 65, kernel_size=1,
+ stride=1, padding=0)
+
+ def forward(self, input_features):
+ feat = self.relu(self.convPa(input_features))
+ semi = self.convPb(feat)
+
+ return semi
\ No newline at end of file
diff --git a/imcui/third_party/SOLD2/sold2/model/nets/lcnn_hourglass.py b/imcui/third_party/SOLD2/sold2/model/nets/lcnn_hourglass.py
new file mode 100644
index 0000000000000000000000000000000000000000..a9dc78eef34e7ee146166b1b66c10070799d63f3
--- /dev/null
+++ b/imcui/third_party/SOLD2/sold2/model/nets/lcnn_hourglass.py
@@ -0,0 +1,226 @@
+"""
+Hourglass network, taken from https://github.com/zhou13/lcnn
+"""
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+__all__ = ["HourglassNet", "hg"]
+
+
+class MultitaskHead(nn.Module):
+ def __init__(self, input_channels, num_class):
+ super(MultitaskHead, self).__init__()
+
+ m = int(input_channels / 4)
+ head_size = [[2], [1], [2]]
+ heads = []
+ for output_channels in sum(head_size, []):
+ heads.append(
+ nn.Sequential(
+ nn.Conv2d(input_channels, m, kernel_size=3, padding=1),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(m, output_channels, kernel_size=1),
+ )
+ )
+ self.heads = nn.ModuleList(heads)
+ assert num_class == sum(sum(head_size, []))
+
+ def forward(self, x):
+ return torch.cat([head(x) for head in self.heads], dim=1)
+
+
+class Bottleneck2D(nn.Module):
+ expansion = 2
+
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
+ super(Bottleneck2D, self).__init__()
+
+ self.bn1 = nn.BatchNorm2d(inplanes)
+ self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1)
+ self.bn2 = nn.BatchNorm2d(planes)
+ self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
+ stride=stride, padding=1)
+ self.bn3 = nn.BatchNorm2d(planes)
+ self.conv3 = nn.Conv2d(planes, planes * 2, kernel_size=1)
+ self.relu = nn.ReLU(inplace=True)
+ self.downsample = downsample
+ self.stride = stride
+
+ def forward(self, x):
+ residual = x
+
+ out = self.bn1(x)
+ out = self.relu(out)
+ out = self.conv1(out)
+
+ out = self.bn2(out)
+ out = self.relu(out)
+ out = self.conv2(out)
+
+ out = self.bn3(out)
+ out = self.relu(out)
+ out = self.conv3(out)
+
+ if self.downsample is not None:
+ residual = self.downsample(x)
+
+ out += residual
+
+ return out
+
+
+class Hourglass(nn.Module):
+ def __init__(self, block, num_blocks, planes, depth):
+ super(Hourglass, self).__init__()
+ self.depth = depth
+ self.block = block
+ self.hg = self._make_hour_glass(block, num_blocks, planes, depth)
+
+ def _make_residual(self, block, num_blocks, planes):
+ layers = []
+ for i in range(0, num_blocks):
+ layers.append(block(planes * block.expansion, planes))
+ return nn.Sequential(*layers)
+
+ def _make_hour_glass(self, block, num_blocks, planes, depth):
+ hg = []
+ for i in range(depth):
+ res = []
+ for j in range(3):
+ res.append(self._make_residual(block, num_blocks, planes))
+ if i == 0:
+ res.append(self._make_residual(block, num_blocks, planes))
+ hg.append(nn.ModuleList(res))
+ return nn.ModuleList(hg)
+
+ def _hour_glass_forward(self, n, x):
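+ # Recursive U-shaped pass: process at the current scale, pool and
+ # recurse to the coarser scales, then upsample and add back (up1 + up2).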
+ up1 = self.hg[n - 1][0](x)
+ low1 = F.max_pool2d(x, 2, stride=2)
+ low1 = self.hg[n - 1][1](low1)
+
+ if n > 1:
+ low2 = self._hour_glass_forward(n - 1, low1)
+ else:
+ low2 = self.hg[n - 1][3](low1)
+ low3 = self.hg[n - 1][2](low2)
+ # up2 = F.interpolate(low3, scale_factor=2)
+ up2 = F.interpolate(low3, size=up1.shape[2:])
+ out = up1 + up2
+ return out
+
+ def forward(self, x):
+ return self._hour_glass_forward(self.depth, x)
+
+
+class HourglassNet(nn.Module):
+ """Hourglass model from Newell et al ECCV 2016"""
+
+ def __init__(self, block, head, depth, num_stacks, num_blocks,
+ num_classes, input_channels):
+ super(HourglassNet, self).__init__()
+
+ self.inplanes = 64
+ self.num_feats = 128
+ self.num_stacks = num_stacks
+ self.conv1 = nn.Conv2d(input_channels, self.inplanes, kernel_size=7,
+ stride=2, padding=3)
+ self.bn1 = nn.BatchNorm2d(self.inplanes)
+ self.relu = nn.ReLU(inplace=True)
+ self.layer1 = self._make_residual(block, self.inplanes, 1)
+ self.layer2 = self._make_residual(block, self.inplanes, 1)
+ self.layer3 = self._make_residual(block, self.num_feats, 1)
+ self.maxpool = nn.MaxPool2d(2, stride=2)
+
+ # build hourglass modules
+ ch = self.num_feats * block.expansion
+ # vpts = []
+ hg, res, fc, score, fc_, score_ = [], [], [], [], [], []
+ for i in range(num_stacks):
+ hg.append(Hourglass(block, num_blocks, self.num_feats, depth))
+ res.append(self._make_residual(block, self.num_feats, num_blocks))
+ fc.append(self._make_fc(ch, ch))
+ score.append(head(ch, num_classes))
+ # vpts.append(VptsHead(ch))
+ # vpts.append(nn.Linear(ch, 9))
+ # score.append(nn.Conv2d(ch, num_classes, kernel_size=1))
+ # score[i].bias.data[0] += 4.6
+ # score[i].bias.data[2] += 4.6
+ if i < num_stacks - 1:
+ fc_.append(nn.Conv2d(ch, ch, kernel_size=1))
+ score_.append(nn.Conv2d(num_classes, ch, kernel_size=1))
+ self.hg = nn.ModuleList(hg)
+ self.res = nn.ModuleList(res)
+ self.fc = nn.ModuleList(fc)
+ self.score = nn.ModuleList(score)
+ # self.vpts = nn.ModuleList(vpts)
+ self.fc_ = nn.ModuleList(fc_)
+ self.score_ = nn.ModuleList(score_)
+
+ def _make_residual(self, block, planes, blocks, stride=1):
+ downsample = None
+ if stride != 1 or self.inplanes != planes * block.expansion:
+ downsample = nn.Sequential(
+ nn.Conv2d(
+ self.inplanes,
+ planes * block.expansion,
+ kernel_size=1,
+ stride=stride,
+ )
+ )
+
+ layers = []
+ layers.append(block(self.inplanes, planes, stride, downsample))
+ self.inplanes = planes * block.expansion
+ for i in range(1, blocks):
+ layers.append(block(self.inplanes, planes))
+
+ return nn.Sequential(*layers)
+
+ def _make_fc(self, inplanes, outplanes):
+ bn = nn.BatchNorm2d(inplanes)
+ conv = nn.Conv2d(inplanes, outplanes, kernel_size=1)
+ return nn.Sequential(conv, bn, self.relu)
+
+ def forward(self, x):
+ out = []
+ # out_vps = []
+ x = self.conv1(x)
+ x = self.bn1(x)
+ x = self.relu(x)
+
+ x = self.layer1(x)
+ x = self.maxpool(x)
+ x = self.layer2(x)
+ x = self.layer3(x)
+
+ for i in range(self.num_stacks):
+ y = self.hg[i](x)
+ y = self.res[i](y)
+ y = self.fc[i](y)
+ score = self.score[i](y)
+ # pre_vpts = F.adaptive_avg_pool2d(x, (1, 1))
+ # pre_vpts = pre_vpts.reshape(-1, 256)
+ # vpts = self.vpts[i](x)
+ out.append(score)
+ # out_vps.append(vpts)
+ if i < self.num_stacks - 1:
+ fc_ = self.fc_[i](y)
+ score_ = self.score_[i](score)
+ x = x + fc_ + score_
+
+ return out[::-1], y # , out_vps[::-1]
+
+
+def hg(**kwargs):
+ model = HourglassNet(
+ Bottleneck2D,
+ head=kwargs.get("head",
+ lambda c_in, c_out: nn.Conv2d(c_in, c_out, 1)),
+ depth=kwargs["depth"],
+ num_stacks=kwargs["num_stacks"],
+ num_blocks=kwargs["num_blocks"],
+ num_classes=kwargs["num_classes"],
+ input_channels=kwargs["input_channels"]
+ )
+ return model
diff --git a/imcui/third_party/SOLD2/sold2/postprocess/__init__.py b/imcui/third_party/SOLD2/sold2/postprocess/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/imcui/third_party/SOLD2/sold2/postprocess/convert_homography_results.py b/imcui/third_party/SOLD2/sold2/postprocess/convert_homography_results.py
new file mode 100644
index 0000000000000000000000000000000000000000..352eebbde00f6d8a9c20517dccd7024fd0758ffd
--- /dev/null
+++ b/imcui/third_party/SOLD2/sold2/postprocess/convert_homography_results.py
@@ -0,0 +1,136 @@
+"""
+Convert the aggregation results from the homography adaptation to GT labels.
+"""
+import sys
+sys.path.append("../")
+import os
+import yaml
+import argparse
+import numpy as np
+import h5py
+import torch
+from tqdm import tqdm
+
+from config.project_config import Config as cfg
+from model.line_detection import LineSegmentDetectionModule
+from model.metrics import super_nms
+from misc.train_utils import parse_h5_data
+
+
+def convert_raw_exported_predictions(input_data, grid_size=8,
+ detect_thresh=1/65, topk=300):
+ """ Convert the exported junctions and heatmaps predictions
+ to a standard format.
+ Arguments:
+ input_data: the raw data (dict) decoded from the hdf5 dataset
+ outputs: dict containing required entries including:
+ junctions_pred: Nx2 ndarray containing nms junction predictions.
+ heatmap_pred: HxW ndarray containing predicted heatmaps
+ valid_mask: HxW ndarray containing the valid mask
+ """
+ # Check the input_data is from (1) single prediction,
+ # or (2) homography adaptation.
+ # Homography adaptation raw predictions
+ if (("junc_prob_mean" in input_data.keys())
+ and ("heatmap_prob_mean" in input_data.keys())):
+ # Get the junction predictions and convert them to Nx2 format
+ junc_prob = input_data["junc_prob_mean"]
+ junc_pred_np = junc_prob[None, ...]
+ junc_pred_np_nms = super_nms(junc_pred_np, grid_size,
+ detect_thresh, topk)
+ junctions = np.where(junc_pred_np_nms.squeeze())
+ junc_points_pred = np.concatenate([junctions[0][..., None],
+ junctions[1][..., None]], axis=-1)
+
+ # Get the heatmap predictions
+ heatmap_pred = input_data["heatmap_prob_mean"].squeeze()
+ valid_mask = np.ones(heatmap_pred.shape, dtype=np.int32)
+
+ # Single predictions
+ else:
+ # Get the junction point predictions and convert to Nx2 format
+ junc_points_pred = np.where(input_data["junc_pred_nms"])
+ junc_points_pred = np.concatenate(
+ [junc_points_pred[0][..., None],
+ junc_points_pred[1][..., None]], axis=-1)
+
+ # Get the heatmap predictions
+ heatmap_pred = input_data["heatmap_pred"]
+ valid_mask = input_data["valid_mask"]
+
+ return {
+ "junctions_pred": junc_points_pred,
+ "heatmap_pred": heatmap_pred,
+ "valid_mask": valid_mask
+ }
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("input_dataset", type=str,
+ help="Name of the exported dataset.")
+ parser.add_argument("output_dataset", type=str,
+ help="Name of the output dataset.")
+ parser.add_argument("config", type=str,
+ help="Path to the model config.")
+ args = parser.parse_args()
+
+ # Define the path to the input exported dataset
+ exported_dataset_path = os.path.join(cfg.export_dataroot,
+ args.input_dataset)
+ if not os.path.exists(exported_dataset_path):
+ raise ValueError("Missing input dataset: " + exported_dataset_path)
+ exported_dataset = h5py.File(exported_dataset_path, "r")
+
+ # Define the output path for the results
+ output_dataset_path = os.path.join(cfg.export_dataroot,
+ args.output_dataset)
+
+ device = torch.device("cuda")
+ nms_device = torch.device("cuda")
+
+ # Read the config file
+ if not os.path.exists(args.config):
+ raise ValueError("Missing config file: " + args.config)
+ with open(args.config, "r") as f:
+ config = yaml.safe_load(f)
+ model_cfg = config["model_cfg"]
+ line_detector_cfg = config["line_detector_cfg"]
+
+ # Initialize the line detection module
+ line_detector = LineSegmentDetectionModule(**line_detector_cfg)
+
+ # Iterate through all the dataset keys
+ with h5py.File(output_dataset_path, "w") as output_dataset:
+ for idx, output_key in enumerate(tqdm(list(exported_dataset.keys()),
+ ascii=True)):
+ # Get the data
+ data = parse_h5_data(exported_dataset[output_key])
+
+ # Preprocess the data
+ converted_data = convert_raw_exported_predictions(
+ data, grid_size=model_cfg["grid_size"],
+ detect_thresh=model_cfg["detection_thresh"])
+ junctions_pred_raw = converted_data["junctions_pred"]
+ heatmap_pred = converted_data["heatmap_pred"]
+ valid_mask = converted_data["valid_mask"]
+
+ line_map_pred, junctions_pred, heatmap_pred = line_detector.detect(
+ junctions_pred_raw, heatmap_pred, device=device)
+ if isinstance(line_map_pred, torch.Tensor):
+ line_map_pred = line_map_pred.cpu().numpy()
+ if isinstance(junctions_pred, torch.Tensor):
+ junctions_pred = junctions_pred.cpu().numpy()
+ if isinstance(heatmap_pred, torch.Tensor):
+ heatmap_pred = heatmap_pred.cpu().numpy()
+
+ output_data = {"junctions": junctions_pred,
+ "line_map": line_map_pred}
+
+ # Record it to the h5 dataset
+ f_group = output_dataset.create_group(output_key)
+
+ # Store data
+ for key, value in output_data.items():
+ f_group.create_dataset(key, data=value,
+ compression="gzip")
diff --git a/imcui/third_party/SOLD2/sold2/train.py b/imcui/third_party/SOLD2/sold2/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..2064e00e6d192f9202f011c3626d6f53c4fe6270
--- /dev/null
+++ b/imcui/third_party/SOLD2/sold2/train.py
@@ -0,0 +1,752 @@
+"""
+This file implements the training process and all the summaries.
+"""
+import os
+import numpy as np
+import cv2
+import torch
+from torch.nn.functional import pixel_shuffle, softmax
+from torch.utils.data import DataLoader
+import torch.utils.data.dataloader as torch_loader
+from tensorboardX import SummaryWriter
+
+from .dataset.dataset_util import get_dataset
+from .model.model_util import get_model
+from .model.loss import TotalLoss, get_loss_and_weights
+from .model.metrics import AverageMeter, Metrics, super_nms
+from .model.lr_scheduler import get_lr_scheduler
+from .misc.train_utils import (convert_image, get_latest_checkpoint,
+ remove_old_checkpoints)
+
+
+def customized_collate_fn(batch):
+ """ Customized collate_fn. """
+ batch_keys = ["image", "junction_map", "heatmap", "valid_mask"]
+ list_keys = ["junctions", "line_map"]
+
+ outputs = {}
+ for key in batch_keys:
+ outputs[key] = torch_loader.default_collate([b[key] for b in batch])
+ for key in list_keys:
+ outputs[key] = [b[key] for b in batch]
+
+ return outputs
+
+
+def restore_weights(model, state_dict, strict=True):
+ """ Restore weights in compatible mode. """
+ # Try to directly load state dict
+ try:
+ model.load_state_dict(state_dict, strict=strict)
+ # Deal with version compatibility issues (mismatched parameter names)
+ except Exception:
+ err = model.load_state_dict(state_dict, strict=False)
+
+ # missing keys are those in model but not in state_dict
+ missing_keys = err.missing_keys
+ # Unexpected keys are those in state_dict but not in model
+ unexpected_keys = err.unexpected_keys
+
+ # Load mismatched keys manually
+ model_dict = model.state_dict()
+ for idx, key in enumerate(missing_keys):
+ dict_keys = [k for k in unexpected_keys if "tracked" not in k]
+ model_dict[key] = state_dict[dict_keys[idx]]
+ model.load_state_dict(model_dict)
+
+ return model
+
+
+def train_net(args, dataset_cfg, model_cfg, output_path):
+ """ Main training function. """
+ # Add some version compatibility check
+ if model_cfg.get("weighting_policy") is None:
+ # Default to static
+ model_cfg["weighting_policy"] = "static"
+
+ # Get the train, val, test config
+ train_cfg = model_cfg["train"]
+ test_cfg = model_cfg["test"]
+
+ # Create train and test dataset
+ print("\t Initializing dataset...")
+ train_dataset, train_collate_fn = get_dataset("train", dataset_cfg)
+ test_dataset, test_collate_fn = get_dataset("test", dataset_cfg)
+
+ # Create the dataloader
+ train_loader = DataLoader(train_dataset,
+ batch_size=train_cfg["batch_size"],
+ num_workers=8,
+ shuffle=True, pin_memory=True,
+ collate_fn=train_collate_fn)
+ test_loader = DataLoader(test_dataset,
+ batch_size=test_cfg.get("batch_size", 1),
+ num_workers=test_cfg.get("num_workers", 1),
+ shuffle=False, pin_memory=False,
+ collate_fn=test_collate_fn)
+ print("\t Successfully intialized dataloaders.")
+
+ # Get the loss function and weight first
+ loss_funcs, loss_weights = get_loss_and_weights(model_cfg)
+
+ # If resume.
+ if args.resume:
+ # Create model and load the state dict
+ checkpoint = get_latest_checkpoint(args.resume_path,
+ args.checkpoint_name)
+ model = get_model(model_cfg, loss_weights)
+ model = restore_weights(model, checkpoint["model_state_dict"])
+ model = model.cuda()
+ optimizer = torch.optim.Adam(
+ [{"params": model.parameters(),
+ "initial_lr": model_cfg["learning_rate"]}],
+ model_cfg["learning_rate"],
+ amsgrad=True)
+ optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
+ # Optionally get the learning rate scheduler
+ scheduler = get_lr_scheduler(
+ lr_decay=model_cfg.get("lr_decay", False),
+ lr_decay_cfg=model_cfg.get("lr_decay_cfg", None),
+ optimizer=optimizer)
+ # If we start to use learning rate scheduler from the middle
+ if ((scheduler is not None)
+ and (checkpoint.get("scheduler_state_dict", None) is not None)):
+ scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
+ start_epoch = checkpoint["epoch"] + 1
+ # Initialize all the components.
+ else:
+ # Create model and optimizer
+ model = get_model(model_cfg, loss_weights)
+ # Optionally get the pretrained weights
+ if args.pretrained:
+ print("\t [Debug] Loading pretrained weights...")
+ checkpoint = get_latest_checkpoint(args.pretrained_path,
+ args.checkpoint_name)
+ # If auto weighting restore from non-auto weighting
+ model = restore_weights(model, checkpoint["model_state_dict"],
+ strict=False)
+ print("\t [Debug] Finished loading pretrained weights!")
+
+ model = model.cuda()
+ optimizer = torch.optim.Adam(
+ [{"params": model.parameters(),
+ "initial_lr": model_cfg["learning_rate"]}],
+ model_cfg["learning_rate"],
+ amsgrad=True)
+ # Optionally get the learning rate scheduler
+ scheduler = get_lr_scheduler(
+ lr_decay=model_cfg.get("lr_decay", False),
+ lr_decay_cfg=model_cfg.get("lr_decay_cfg", None),
+ optimizer=optimizer)
+ start_epoch = 0
+
+ print("\t Successfully initialized model")
+
+ # Define the total loss
+ policy = model_cfg.get("weighting_policy", "static")
+ loss_func = TotalLoss(loss_funcs, loss_weights, policy).cuda()
+ if "descriptor_decoder" in model_cfg:
+ metric_func = Metrics(model_cfg["detection_thresh"],
+ model_cfg["prob_thresh"],
+ model_cfg["descriptor_loss_cfg"]["grid_size"],
+ desc_metric_lst='all')
+ else:
+ metric_func = Metrics(model_cfg["detection_thresh"],
+ model_cfg["prob_thresh"],
+ model_cfg["grid_size"])
+
+ # Define the summary writer
+ logdir = os.path.join(output_path, "log")
+ writer = SummaryWriter(logdir=logdir)
+
+ # Start the training loop
+ for epoch in range(start_epoch, model_cfg["epochs"]):
+ # Record the learning rate
+ current_lr = optimizer.state_dict()["param_groups"][0]["lr"]
+ writer.add_scalar("LR/lr", current_lr, epoch)
+
+ # Train for one epoch
+ print("\n\n================== Training ====================")
+ train_single_epoch(
+ model=model,
+ model_cfg=model_cfg,
+ optimizer=optimizer,
+ loss_func=loss_func,
+ metric_func=metric_func,
+ train_loader=train_loader,
+ writer=writer,
+ epoch=epoch)
+
+ # Do the validation
+ print("\n\n================== Validation ==================")
+ validate(
+ model=model,
+ model_cfg=model_cfg,
+ loss_func=loss_func,
+ metric_func=metric_func,
+ val_loader=test_loader,
+ writer=writer,
+ epoch=epoch)
+
+ # Update the scheduler
+ if scheduler is not None:
+ scheduler.step()
+
+ # Save checkpoints
+ file_name = os.path.join(output_path,
+ "checkpoint-epoch%03d-end.tar"%(epoch))
+ print("[Info] Saving checkpoint %s ..." % file_name)
+ save_dict = {
+ "epoch": epoch,
+ "model_state_dict": model.state_dict(),
+ "optimizer_state_dict": optimizer.state_dict(),
+ "model_cfg": model_cfg}
+ if scheduler is not None:
+ save_dict.update({"scheduler_state_dict": scheduler.state_dict()})
+ torch.save(save_dict, file_name)
+
+ # Remove the outdated checkpoints
+ remove_old_checkpoints(output_path, model_cfg.get("max_ckpt", 15))
+
+
+def train_single_epoch(model, model_cfg, optimizer, loss_func, metric_func,
+ train_loader, writer, epoch):
+ """ Train for one epoch. """
+ # Switch the model to training mode
+ model.train()
+
+ # Initialize the average meter
+ compute_descriptors = loss_func.compute_descriptors
+ if compute_descriptors:
+ average_meter = AverageMeter(is_training=True, desc_metric_lst='all')
+ else:
+ average_meter = AverageMeter(is_training=True)
+
+ # The training loop
+ for idx, data in enumerate(train_loader):
+ if compute_descriptors:
+ junc_map = data["ref_junction_map"].cuda()
+ junc_map2 = data["target_junction_map"].cuda()
+ heatmap = data["ref_heatmap"].cuda()
+ heatmap2 = data["target_heatmap"].cuda()
+ line_points = data["ref_line_points"].cuda()
+ line_points2 = data["target_line_points"].cuda()
+ line_indices = data["ref_line_indices"].cuda()
+ valid_mask = data["ref_valid_mask"].cuda()
+ valid_mask2 = data["target_valid_mask"].cuda()
+ input_images = data["ref_image"].cuda()
+ input_images2 = data["target_image"].cuda()
+
+ # Run the forward pass
+ outputs = model(input_images)
+ outputs2 = model(input_images2)
+
+ # Compute losses
+ losses = loss_func.forward_descriptors(
+ outputs["junctions"], outputs2["junctions"],
+ junc_map, junc_map2, outputs["heatmap"], outputs2["heatmap"],
+ heatmap, heatmap2, line_points, line_points2,
+ line_indices, outputs['descriptors'], outputs2['descriptors'],
+ epoch, valid_mask, valid_mask2)
+ else:
+ junc_map = data["junction_map"].cuda()
+ heatmap = data["heatmap"].cuda()
+ valid_mask = data["valid_mask"].cuda()
+ input_images = data["image"].cuda()
+
+ # Run the forward pass
+ outputs = model(input_images)
+
+ # Compute losses
+ losses = loss_func(
+ outputs["junctions"], junc_map,
+ outputs["heatmap"], heatmap,
+ valid_mask)
+
+ total_loss = losses["total_loss"]
+
+ # Update the model
+ optimizer.zero_grad()
+ total_loss.backward()
+ optimizer.step()
+
+ # Compute the global step
+ global_step = epoch * len(train_loader) + idx
+ ############## Measure the metric error #########################
+ # Only do this when needed
+ if (((idx % model_cfg["disp_freq"]) == 0)
+ or ((idx % model_cfg["summary_freq"]) == 0)):
+ junc_np = convert_junc_predictions(
+ outputs["junctions"], model_cfg["grid_size"],
+ model_cfg["detection_thresh"], 300)
+ junc_map_np = junc_map.cpu().numpy().transpose(0, 2, 3, 1)
+
+ # Always fetch only one channel (compatible with L1, L2, and CE)
+ if outputs["heatmap"].shape[1] == 2:
+ heatmap_np = softmax(outputs["heatmap"].detach(),
+ dim=1).cpu().numpy()
+ heatmap_np = heatmap_np.transpose(0, 2, 3, 1)[:, :, :, 1:]
+ else:
+ heatmap_np = torch.sigmoid(outputs["heatmap"].detach())
+ heatmap_np = heatmap_np.cpu().numpy().transpose(0, 2, 3, 1)
+
+ heatmap_gt_np = heatmap.cpu().numpy().transpose(0, 2, 3, 1)
+ valid_mask_np = valid_mask.cpu().numpy().transpose(0, 2, 3, 1)
+
+ # Evaluate metric results
+ if compute_descriptors:
+ metric_func.evaluate(
+ junc_np["junc_pred"], junc_np["junc_pred_nms"],
+ junc_map_np, heatmap_np, heatmap_gt_np, valid_mask_np,
+ line_points, line_points2, outputs["descriptors"],
+ outputs2["descriptors"], line_indices)
+ else:
+ metric_func.evaluate(
+ junc_np["junc_pred"], junc_np["junc_pred_nms"],
+ junc_map_np, heatmap_np, heatmap_gt_np, valid_mask_np)
+ # Update average meter
+ junc_loss = losses["junc_loss"].item()
+ heatmap_loss = losses["heatmap_loss"].item()
+ loss_dict = {
+ "junc_loss": junc_loss,
+ "heatmap_loss": heatmap_loss,
+ "total_loss": total_loss.item()}
+ if compute_descriptors:
+ descriptor_loss = losses["descriptor_loss"].item()
+ loss_dict["descriptor_loss"] = losses["descriptor_loss"].item()
+
+ average_meter.update(metric_func, loss_dict, num_samples=junc_map.shape[0])
+
+ # Display the progress
+ if (idx % model_cfg["disp_freq"]) == 0:
+ results = metric_func.metric_results
+ average = average_meter.average()
+ # Get gpu memory usage in GB
+ gpu_mem_usage = torch.cuda.max_memory_allocated() / (1024 ** 3)
+ if compute_descriptors:
+ print("Epoch [%d / %d] Iter [%d / %d] loss=%.4f (%.4f), junc_loss=%.4f (%.4f), heatmap_loss=%.4f (%.4f), descriptor_loss=%.4f (%.4f), gpu_mem=%.4fGB"
+ % (epoch, model_cfg["epochs"], idx, len(train_loader),
+ total_loss.item(), average["total_loss"], junc_loss,
+ average["junc_loss"], heatmap_loss,
+ average["heatmap_loss"], descriptor_loss,
+ average["descriptor_loss"], gpu_mem_usage))
+ else:
+ print("Epoch [%d / %d] Iter [%d / %d] loss=%.4f (%.4f), junc_loss=%.4f (%.4f), heatmap_loss=%.4f (%.4f), gpu_mem=%.4fGB"
+ % (epoch, model_cfg["epochs"], idx, len(train_loader),
+ total_loss.item(), average["total_loss"],
+ junc_loss, average["junc_loss"], heatmap_loss,
+ average["heatmap_loss"], gpu_mem_usage))
+ print("\t Junction precision=%.4f (%.4f) / recall=%.4f (%.4f)"
+ % (results["junc_precision"], average["junc_precision"],
+ results["junc_recall"], average["junc_recall"]))
+ print("\t Junction nms precision=%.4f (%.4f) / recall=%.4f (%.4f)"
+ % (results["junc_precision_nms"],
+ average["junc_precision_nms"],
+ results["junc_recall_nms"], average["junc_recall_nms"]))
+ print("\t Heatmap precision=%.4f (%.4f) / recall=%.4f (%.4f)"
+ %(results["heatmap_precision"],
+ average["heatmap_precision"],
+ results["heatmap_recall"], average["heatmap_recall"]))
+ if compute_descriptors:
+ print("\t Descriptors matching score=%.4f (%.4f)"
+ %(results["matching_score"], average["matching_score"]))
+
+ # Record summaries
+ if (idx % model_cfg["summary_freq"]) == 0:
+ results = metric_func.metric_results
+ average = average_meter.average()
+ # Add the shared losses
+ scalar_summaries = {
+ "junc_loss": junc_loss,
+ "heatmap_loss": heatmap_loss,
+ "total_loss": total_loss.detach().cpu().numpy(),
+ "metrics": results,
+ "average": average}
+ # Add descriptor terms
+ if compute_descriptors:
+ scalar_summaries["descriptor_loss"] = descriptor_loss
+ scalar_summaries["w_desc"] = losses["w_desc"]
+
+ # Add weighting terms (even for static terms)
+ scalar_summaries["w_junc"] = losses["w_junc"]
+ scalar_summaries["w_heatmap"] = losses["w_heatmap"]
+ scalar_summaries["reg_loss"] = losses["reg_loss"].item()
+
+ num_images = 3
+ junc_pred_binary = (junc_np["junc_pred"][:num_images, ...]
+ > model_cfg["detection_thresh"])
+ junc_pred_nms_binary = (junc_np["junc_pred_nms"][:num_images, ...]
+ > model_cfg["detection_thresh"])
+ image_summaries = {
+ "image": input_images.cpu().numpy()[:num_images, ...],
+ "valid_mask": valid_mask_np[:num_images, ...],
+ "junc_map_pred": junc_pred_binary,
+ "junc_map_pred_nms": junc_pred_nms_binary,
+ "junc_map_gt": junc_map_np[:num_images, ...],
+ "junc_prob_map": junc_np["junc_prob"][:num_images, ...],
+ "heatmap_pred": heatmap_np[:num_images, ...],
+ "heatmap_gt": heatmap_gt_np[:num_images, ...]}
+ # Record the training summary
+ record_train_summaries(
+ writer, global_step, scalars=scalar_summaries,
+ images=image_summaries)
+
+
+def validate(model, model_cfg, loss_func, metric_func, val_loader, writer, epoch):
+ """ Validation. """
+ # Switch the model to eval mode
+ model.eval()
+
+ # Initialize the average meter
+ compute_descriptors = loss_func.compute_descriptors
+ if compute_descriptors:
+ average_meter = AverageMeter(is_training=True, desc_metric_lst='all')
+ else:
+ average_meter = AverageMeter(is_training=True)
+
+ # The validation loop
+ for idx, data in enumerate(val_loader):
+ if compute_descriptors:
+ junc_map = data["ref_junction_map"].cuda()
+ junc_map2 = data["target_junction_map"].cuda()
+ heatmap = data["ref_heatmap"].cuda()
+ heatmap2 = data["target_heatmap"].cuda()
+ line_points = data["ref_line_points"].cuda()
+ line_points2 = data["target_line_points"].cuda()
+ line_indices = data["ref_line_indices"].cuda()
+ valid_mask = data["ref_valid_mask"].cuda()
+ valid_mask2 = data["target_valid_mask"].cuda()
+ input_images = data["ref_image"].cuda()
+ input_images2 = data["target_image"].cuda()
+
+ # Run the forward pass
+ with torch.no_grad():
+ outputs = model(input_images)
+ outputs2 = model(input_images2)
+
+ # Compute losses
+ losses = loss_func.forward_descriptors(
+ outputs["junctions"], outputs2["junctions"],
+ junc_map, junc_map2, outputs["heatmap"],
+ outputs2["heatmap"], heatmap, heatmap2, line_points,
+ line_points2, line_indices, outputs['descriptors'],
+ outputs2['descriptors'], epoch, valid_mask, valid_mask2)
+ else:
+ junc_map = data["junction_map"].cuda()
+ heatmap = data["heatmap"].cuda()
+ valid_mask = data["valid_mask"].cuda()
+ input_images = data["image"].cuda()
+
+ # Run the forward pass
+ with torch.no_grad():
+ outputs = model(input_images)
+
+ # Compute losses
+ losses = loss_func(
+ outputs["junctions"], junc_map,
+ outputs["heatmap"], heatmap,
+ valid_mask)
+ total_loss = losses["total_loss"]
+
+ ############## Measure the metric error #########################
+ junc_np = convert_junc_predictions(
+ outputs["junctions"], model_cfg["grid_size"],
+ model_cfg["detection_thresh"], 300)
+ junc_map_np = junc_map.cpu().numpy().transpose(0, 2, 3, 1)
+ # Always fetch only one channel (compatible with L1, L2, and CE)
+ if outputs["heatmap"].shape[1] == 2:
+ heatmap_np = softmax(outputs["heatmap"].detach(),
+ dim=1).cpu().numpy().transpose(0, 2, 3, 1)
+ heatmap_np = heatmap_np[:, :, :, 1:]
+ else:
+ heatmap_np = torch.sigmoid(outputs["heatmap"].detach())
+ heatmap_np = heatmap_np.cpu().numpy().transpose(0, 2, 3, 1)
+
+ heatmap_gt_np = heatmap.cpu().numpy().transpose(0, 2, 3, 1)
+ valid_mask_np = valid_mask.cpu().numpy().transpose(0, 2, 3, 1)
+
+ # Evaluate metric results
+ if compute_descriptors:
+ metric_func.evaluate(
+ junc_np["junc_pred"], junc_np["junc_pred_nms"],
+ junc_map_np, heatmap_np, heatmap_gt_np, valid_mask_np,
+ line_points, line_points2, outputs["descriptors"],
+ outputs2["descriptors"], line_indices)
+ else:
+ metric_func.evaluate(
+ junc_np["junc_pred"], junc_np["junc_pred_nms"], junc_map_np,
+ heatmap_np, heatmap_gt_np, valid_mask_np)
+ # Update average meter
+ junc_loss = losses["junc_loss"].item()
+ heatmap_loss = losses["heatmap_loss"].item()
+ loss_dict = {
+ "junc_loss": junc_loss,
+ "heatmap_loss": heatmap_loss,
+ "total_loss": total_loss.item()}
+ if compute_descriptors:
+ descriptor_loss = losses["descriptor_loss"].item()
+ loss_dict["descriptor_loss"] = losses["descriptor_loss"].item()
+ average_meter.update(metric_func, loss_dict, num_samples=junc_map.shape[0])
+
+ # Display the progress
+ if (idx % model_cfg["disp_freq"]) == 0:
+ results = metric_func.metric_results
+ average = average_meter.average()
+ if compute_descriptors:
+ print("Iter [%d / %d] loss=%.4f (%.4f), junc_loss=%.4f (%.4f), heatmap_loss=%.4f (%.4f), descriptor_loss=%.4f (%.4f)"
+ % (idx, len(val_loader),
+ total_loss.item(), average["total_loss"],
+ junc_loss, average["junc_loss"],
+ heatmap_loss, average["heatmap_loss"],
+ descriptor_loss, average["descriptor_loss"]))
+ else:
+ print("Iter [%d / %d] loss=%.4f (%.4f), junc_loss=%.4f (%.4f), heatmap_loss=%.4f (%.4f)"
+ % (idx, len(val_loader),
+ total_loss.item(), average["total_loss"],
+ junc_loss, average["junc_loss"],
+ heatmap_loss, average["heatmap_loss"]))
+ print("\t Junction precision=%.4f (%.4f) / recall=%.4f (%.4f)"
+ % (results["junc_precision"], average["junc_precision"],
+ results["junc_recall"], average["junc_recall"]))
+ print("\t Junction nms precision=%.4f (%.4f) / recall=%.4f (%.4f)"
+ % (results["junc_precision_nms"],
+ average["junc_precision_nms"],
+ results["junc_recall_nms"], average["junc_recall_nms"]))
+ print("\t Heatmap precision=%.4f (%.4f) / recall=%.4f (%.4f)"
+ % (results["heatmap_precision"],
+ average["heatmap_precision"],
+ results["heatmap_recall"], average["heatmap_recall"]))
+ if compute_descriptors:
+ print("\t Descriptors matching score=%.4f (%.4f)"
+ %(results["matching_score"], average["matching_score"]))
+
+ # Record summaries
+ average = average_meter.average()
+ scalar_summaries = {"average": average}
+ # Record the training summary
+ record_test_summaries(writer, epoch, scalar_summaries)
+
+
+def convert_junc_predictions(predictions, grid_size,
+ detect_thresh=1/65, topk=300):
+ """ Convert torch predictions to numpy arrays for evaluation. """
+ # Convert to probability outputs first
+ junc_prob = softmax(predictions.detach(), dim=1).cpu()
+ junc_pred = junc_prob[:, :-1, :, :]
+
+ junc_prob_np = junc_prob.numpy().transpose(0, 2, 3, 1)[:, :, :, :-1]
+ junc_prob_np = np.sum(junc_prob_np, axis=-1)
+ junc_pred_np = pixel_shuffle(
+ junc_pred, grid_size).cpu().numpy().transpose(0, 2, 3, 1)
+ junc_pred_np_nms = super_nms(junc_pred_np, grid_size, detect_thresh, topk)
+ junc_pred_np = junc_pred_np.squeeze(-1)
+
+ return {"junc_pred": junc_pred_np, "junc_pred_nms": junc_pred_np_nms,
+ "junc_prob": junc_prob_np}
+
+
+def record_train_summaries(writer, global_step, scalars, images):
+ """ Record training summaries. """
+ # Record the scalar summaries
+ results = scalars["metrics"]
+ average = scalars["average"]
+
+ # GPU memory part
+ # Get gpu memory usage in GB
+ gpu_mem_usage = torch.cuda.max_memory_allocated() / (1024 ** 3)
+ writer.add_scalar("GPU/GPU_memory_usage", gpu_mem_usage, global_step)
+
+ # Loss part
+ writer.add_scalar("Train_loss/junc_loss", scalars["junc_loss"],
+ global_step)
+ writer.add_scalar("Train_loss/heatmap_loss", scalars["heatmap_loss"],
+ global_step)
+ writer.add_scalar("Train_loss/total_loss", scalars["total_loss"],
+ global_step)
+ # Add regularization loss
+ if "reg_loss" in scalars.keys():
+ writer.add_scalar("Train_loss/reg_loss", scalars["reg_loss"],
+ global_step)
+ # Add descriptor loss
+ if "descriptor_loss" in scalars.keys():
+ key = "descriptor_loss"
+ writer.add_scalar("Train_loss/%s"%(key), scalars[key], global_step)
+ writer.add_scalar("Train_loss_average/%s"%(key), average[key],
+ global_step)
+
+ # Record weighting
+ for key in scalars.keys():
+ if "w_" in key:
+ writer.add_scalar("Train_weight/%s"%(key), scalars[key],
+ global_step)
+
+ # Smoothed loss
+ writer.add_scalar("Train_loss_average/junc_loss", average["junc_loss"],
+ global_step)
+ writer.add_scalar("Train_loss_average/heatmap_loss",
+ average["heatmap_loss"], global_step)
+ writer.add_scalar("Train_loss_average/total_loss", average["total_loss"],
+ global_step)
+ # Add smoothed descriptor loss
+ if "descriptor_loss" in average.keys():
+ writer.add_scalar("Train_loss_average/descriptor_loss",
+ average["descriptor_loss"], global_step)
+
+ # Metrics part
+ writer.add_scalar("Train_metrics/junc_precision",
+ results["junc_precision"], global_step)
+ writer.add_scalar("Train_metrics/junc_precision_nms",
+ results["junc_precision_nms"], global_step)
+ writer.add_scalar("Train_metrics/junc_recall",
+ results["junc_recall"], global_step)
+ writer.add_scalar("Train_metrics/junc_recall_nms",
+ results["junc_recall_nms"], global_step)
+ writer.add_scalar("Train_metrics/heatmap_precision",
+ results["heatmap_precision"], global_step)
+ writer.add_scalar("Train_metrics/heatmap_recall",
+ results["heatmap_recall"], global_step)
+ # Add descriptor metric
+ if "matching_score" in results.keys():
+ writer.add_scalar("Train_metrics/matching_score",
+ results["matching_score"], global_step)
+
+ # Average part
+ writer.add_scalar("Train_metrics_average/junc_precision",
+ average["junc_precision"], global_step)
+ writer.add_scalar("Train_metrics_average/junc_precision_nms",
+ average["junc_precision_nms"], global_step)
+ writer.add_scalar("Train_metrics_average/junc_recall",
+ average["junc_recall"], global_step)
+ writer.add_scalar("Train_metrics_average/junc_recall_nms",
+ average["junc_recall_nms"], global_step)
+ writer.add_scalar("Train_metrics_average/heatmap_precision",
+ average["heatmap_precision"], global_step)
+ writer.add_scalar("Train_metrics_average/heatmap_recall",
+ average["heatmap_recall"], global_step)
+ # Add smoothed descriptor metric
+ if "matching_score" in average.keys():
+ writer.add_scalar("Train_metrics_average/matching_score",
+ average["matching_score"], global_step)
+
+ # Record the image summary
+ # Image part
+ image_tensor = convert_image(images["image"], 1)
+ valid_masks = convert_image(images["valid_mask"], -1)
+ writer.add_images("Train/images", image_tensor, global_step,
+ dataformats="NCHW")
+ writer.add_images("Train/valid_map", valid_masks, global_step,
+ dataformats="NHWC")
+
+ # Heatmap part
+ writer.add_images("Train/heatmap_gt",
+ convert_image(images["heatmap_gt"], -1), global_step,
+ dataformats="NHWC")
+ writer.add_images("Train/heatmap_pred",
+ convert_image(images["heatmap_pred"], -1), global_step,
+ dataformats="NHWC")
+
+ # Junction prediction part
+ junc_plots = plot_junction_detection(
+ image_tensor, images["junc_map_pred"],
+ images["junc_map_pred_nms"], images["junc_map_gt"])
+ writer.add_images("Train/junc_gt", junc_plots["junc_gt_plot"] / 255.,
+ global_step, dataformats="NHWC")
+ writer.add_images("Train/junc_pred", junc_plots["junc_pred_plot"] / 255.,
+ global_step, dataformats="NHWC")
+ writer.add_images("Train/junc_pred_nms",
+ junc_plots["junc_pred_nms_plot"] / 255., global_step,
+ dataformats="NHWC")
+ writer.add_images(
+ "Train/junc_prob_map",
+ convert_image(images["junc_prob_map"][..., None], axis=-1),
+ global_step, dataformats="NHWC")
+
+
+def record_test_summaries(writer, epoch, scalars):
+ """ Record testing summaries. """
+ average = scalars["average"]
+
+ # Average loss
+ writer.add_scalar("Val_loss/junc_loss", average["junc_loss"], epoch)
+ writer.add_scalar("Val_loss/heatmap_loss", average["heatmap_loss"], epoch)
+ writer.add_scalar("Val_loss/total_loss", average["total_loss"], epoch)
+ # Add descriptor loss
+ if "descriptor_loss" in average.keys():
+ key = "descriptor_loss"
+ writer.add_scalar("Val_loss/%s"%(key), average[key], epoch)
+
+ # Average metrics
+ writer.add_scalar("Val_metrics/junc_precision", average["junc_precision"],
+ epoch)
+ writer.add_scalar("Val_metrics/junc_precision_nms",
+ average["junc_precision_nms"], epoch)
+ writer.add_scalar("Val_metrics/junc_recall",
+ average["junc_recall"], epoch)
+ writer.add_scalar("Val_metrics/junc_recall_nms",
+ average["junc_recall_nms"], epoch)
+ writer.add_scalar("Val_metrics/heatmap_precision",
+ average["heatmap_precision"], epoch)
+ writer.add_scalar("Val_metrics/heatmap_recall",
+ average["heatmap_recall"], epoch)
+ # Add descriptor metric
+ if "matching_score" in average.keys():
+ writer.add_scalar("Val_metrics/matching_score",
+ average["matching_score"], epoch)
+
+
+def plot_junction_detection(image_tensor, junc_pred_tensor,
+ junc_pred_nms_tensor, junc_gt_tensor):
+ """ Plot the junction points on images. """
+ # Get the batch_size
+ batch_size = image_tensor.shape[0]
+
+ # Process through batch dimension
+ junc_pred_lst = []
+ junc_pred_nms_lst = []
+ junc_gt_lst = []
+ for i in range(batch_size):
+ # Convert image to 255 uint8
+ image = (image_tensor[i, :, :, :]
+ * 255.).astype(np.uint8).transpose(1, 2, 0)
+
+ # Plot groundtruth onto image
+ junc_gt = junc_gt_tensor[i, ...]
+ coord_gt = np.where(junc_gt.squeeze() > 0)
+ points_gt = np.concatenate((coord_gt[0][..., None],
+ coord_gt[1][..., None]),
+ axis=1)
+ plot_gt = image.copy()
+ for id in range(points_gt.shape[0]):
+ cv2.circle(plot_gt, tuple(np.flip(points_gt[id, :])), 3,
+ color=(255, 0, 0), thickness=2)
+ junc_gt_lst.append(plot_gt[None, ...])
+
+ # Plot junc_pred
+ junc_pred = junc_pred_tensor[i, ...]
+ coord_pred = np.where(junc_pred > 0)
+ points_pred = np.concatenate((coord_pred[0][..., None],
+ coord_pred[1][..., None]),
+ axis=1)
+ plot_pred = image.copy()
+ for id in range(points_pred.shape[0]):
+ cv2.circle(plot_pred, tuple(np.flip(points_pred[id, :])), 3,
+ color=(0, 255, 0), thickness=2)
+ junc_pred_lst.append(plot_pred[None, ...])
+
+ # Plot junc_pred_nms
+ junc_pred_nms = junc_pred_nms_tensor[i, ...]
+ coord_pred_nms = np.where(junc_pred_nms > 0)
+ points_pred_nms = np.concatenate((coord_pred_nms[0][..., None],
+ coord_pred_nms[1][..., None]),
+ axis=1)
+ plot_pred_nms = image.copy()
+ for id in range(points_pred_nms.shape[0]):
+ cv2.circle(plot_pred_nms, tuple(np.flip(points_pred_nms[id, :])),
+ 3, color=(0, 255, 0), thickness=2)
+ junc_pred_nms_lst.append(plot_pred_nms[None, ...])
+
+ return {"junc_gt_plot": np.concatenate(junc_gt_lst, axis=0),
+ "junc_pred_plot": np.concatenate(junc_pred_lst, axis=0),
+ "junc_pred_nms_plot": np.concatenate(junc_pred_nms_lst, axis=0)}
diff --git a/imcui/third_party/SuperGluePretrainedNetwork/demo_superglue.py b/imcui/third_party/SuperGluePretrainedNetwork/demo_superglue.py
new file mode 100644
index 0000000000000000000000000000000000000000..32d4ad3c7df1b7da141c4c6aa51f871a7d756aaf
--- /dev/null
+++ b/imcui/third_party/SuperGluePretrainedNetwork/demo_superglue.py
@@ -0,0 +1,259 @@
+#! /usr/bin/env python3
+#
+# %BANNER_BEGIN%
+# ---------------------------------------------------------------------
+# %COPYRIGHT_BEGIN%
+#
+# Magic Leap, Inc. ("COMPANY") CONFIDENTIAL
+#
+# Unpublished Copyright (c) 2020
+# Magic Leap, Inc., All Rights Reserved.
+#
+# NOTICE: All information contained herein is, and remains the property
+# of COMPANY. The intellectual and technical concepts contained herein
+# are proprietary to COMPANY and may be covered by U.S. and Foreign
+# Patents, patents in process, and are protected by trade secret or
+# copyright law. Dissemination of this information or reproduction of
+# this material is strictly forbidden unless prior written permission is
+# obtained from COMPANY. Access to the source code contained herein is
+# hereby forbidden to anyone except current COMPANY employees, managers
+# or contractors who have executed Confidentiality and Non-disclosure
+# agreements explicitly covering such access.
+#
+# The copyright notice above does not evidence any actual or intended
+# publication or disclosure of this source code, which includes
+# information that is confidential and/or proprietary, and is a trade
+# secret, of COMPANY. ANY REPRODUCTION, MODIFICATION, DISTRIBUTION,
+# PUBLIC PERFORMANCE, OR PUBLIC DISPLAY OF OR THROUGH USE OF THIS
+# SOURCE CODE WITHOUT THE EXPRESS WRITTEN CONSENT OF COMPANY IS
+# STRICTLY PROHIBITED, AND IN VIOLATION OF APPLICABLE LAWS AND
+# INTERNATIONAL TREATIES. THE RECEIPT OR POSSESSION OF THIS SOURCE
+# CODE AND/OR RELATED INFORMATION DOES NOT CONVEY OR IMPLY ANY RIGHTS
+# TO REPRODUCE, DISCLOSE OR DISTRIBUTE ITS CONTENTS, OR TO MANUFACTURE,
+# USE, OR SELL ANYTHING THAT IT MAY DESCRIBE, IN WHOLE OR IN PART.
+#
+# %COPYRIGHT_END%
+# ----------------------------------------------------------------------
+# %AUTHORS_BEGIN%
+#
+# Originating Authors: Paul-Edouard Sarlin
+# Daniel DeTone
+# Tomasz Malisiewicz
+#
+# %AUTHORS_END%
+# --------------------------------------------------------------------*/
+# %BANNER_END%
+
+from pathlib import Path
+import argparse
+import cv2
+import matplotlib.cm as cm
+import torch
+
+from models.matching import Matching
+from models.utils import (AverageTimer, VideoStreamer,
+ make_matching_plot_fast, frame2tensor)
+
+torch.set_grad_enabled(False)
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser(
+ description='SuperGlue demo',
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+ parser.add_argument(
+ '--input', type=str, default='0',
+ help='ID of a USB webcam, URL of an IP camera, '
+ 'or path to an image directory or movie file')
+ parser.add_argument(
+ '--output_dir', type=str, default=None,
+ help='Directory where to write output frames (If None, no output)')
+
+ parser.add_argument(
+ '--image_glob', type=str, nargs='+', default=['*.png', '*.jpg', '*.jpeg'],
+ help='Glob if a directory of images is specified')
+ parser.add_argument(
+ '--skip', type=int, default=1,
+ help='Images to skip if input is a movie or directory')
+ parser.add_argument(
+ '--max_length', type=int, default=1000000,
+ help='Maximum length if input is a movie or directory')
+ parser.add_argument(
+ '--resize', type=int, nargs='+', default=[640, 480],
+ help='Resize the input image before running inference. If two numbers, '
+ 'resize to the exact dimensions; if one number, resize the max '
+ 'dimension; if -1, do not resize')
+
+ parser.add_argument(
+ '--superglue', choices={'indoor', 'outdoor'}, default='indoor',
+ help='SuperGlue weights')
+ parser.add_argument(
+ '--max_keypoints', type=int, default=-1,
+ help='Maximum number of keypoints detected by SuperPoint'
+ ' (\'-1\' keeps all keypoints)')
+ parser.add_argument(
+ '--keypoint_threshold', type=float, default=0.005,
+ help='SuperPoint keypoint detector confidence threshold')
+ parser.add_argument(
+ '--nms_radius', type=int, default=4,
+ help='SuperPoint Non Maximum Suppression (NMS) radius'
+ ' (Must be positive)')
+ parser.add_argument(
+ '--sinkhorn_iterations', type=int, default=20,
+ help='Number of Sinkhorn iterations performed by SuperGlue')
+ parser.add_argument(
+ '--match_threshold', type=float, default=0.2,
+ help='SuperGlue match threshold')
+
+ parser.add_argument(
+ '--show_keypoints', action='store_true',
+ help='Show the detected keypoints')
+ parser.add_argument(
+ '--no_display', action='store_true',
+ help='Do not display images to screen. Useful if running remotely')
+ parser.add_argument(
+ '--force_cpu', action='store_true',
+ help='Force pytorch to run in CPU mode.')
+
+ opt = parser.parse_args()
+ print(opt)
+
+ if len(opt.resize) == 2 and opt.resize[1] == -1:
+ opt.resize = opt.resize[0:1]
+ if len(opt.resize) == 2:
+ print('Will resize to {}x{} (WxH)'.format(
+ opt.resize[0], opt.resize[1]))
+ elif len(opt.resize) == 1 and opt.resize[0] > 0:
+ print('Will resize max dimension to {}'.format(opt.resize[0]))
+ elif len(opt.resize) == 1:
+ print('Will not resize images')
+ else:
+ raise ValueError('Cannot specify more than two integers for --resize')
+
+ device = 'cuda' if torch.cuda.is_available() and not opt.force_cpu else 'cpu'
+ print('Running inference on device \"{}\"'.format(device))
+ config = {
+ 'superpoint': {
+ 'nms_radius': opt.nms_radius,
+ 'keypoint_threshold': opt.keypoint_threshold,
+ 'max_keypoints': opt.max_keypoints
+ },
+ 'superglue': {
+ 'weights': opt.superglue,
+ 'sinkhorn_iterations': opt.sinkhorn_iterations,
+ 'match_threshold': opt.match_threshold,
+ }
+ }
+ matching = Matching(config).eval().to(device)
+ keys = ['keypoints', 'scores', 'descriptors']
+
+ vs = VideoStreamer(opt.input, opt.resize, opt.skip,
+ opt.image_glob, opt.max_length)
+ frame, ret = vs.next_frame()
+ assert ret, 'Error when reading the first frame (try different --input?)'
+
+ frame_tensor = frame2tensor(frame, device)
+ last_data = matching.superpoint({'image': frame_tensor})
+ last_data = {k+'0': last_data[k] for k in keys}
+ last_data['image0'] = frame_tensor
+ last_frame = frame
+ last_image_id = 0
+
+ if opt.output_dir is not None:
+ print('==> Will write outputs to {}'.format(opt.output_dir))
+ Path(opt.output_dir).mkdir(exist_ok=True)
+
+ # Create a window to display the demo.
+ if not opt.no_display:
+ cv2.namedWindow('SuperGlue matches', cv2.WINDOW_NORMAL)
+ cv2.resizeWindow('SuperGlue matches', 640*2, 480)
+ else:
+ print('Skipping visualization, will not show a GUI.')
+
+ # Print the keyboard help menu.
+ print('==> Keyboard control:\n'
+ '\tn: select the current frame as the anchor\n'
+ '\te/r: increase/decrease the keypoint confidence threshold\n'
+ '\td/f: increase/decrease the match filtering threshold\n'
+ '\tk: toggle the visualization of keypoints\n'
+ '\tq: quit')
+
+ timer = AverageTimer()
+
+ while True:
+ frame, ret = vs.next_frame()
+ if not ret:
+ print('Finished demo_superglue.py')
+ break
+ timer.update('data')
+ stem0, stem1 = last_image_id, vs.i - 1
+
+ frame_tensor = frame2tensor(frame, device)
+ pred = matching({**last_data, 'image1': frame_tensor})
+ kpts0 = last_data['keypoints0'][0].cpu().numpy()
+ kpts1 = pred['keypoints1'][0].cpu().numpy()
+ matches = pred['matches0'][0].cpu().numpy()
+ confidence = pred['matching_scores0'][0].cpu().numpy()
+ timer.update('forward')
+
+ valid = matches > -1
+ mkpts0 = kpts0[valid]
+ mkpts1 = kpts1[matches[valid]]
+ color = cm.jet(confidence[valid])
+ text = [
+ 'SuperGlue',
+ 'Keypoints: {}:{}'.format(len(kpts0), len(kpts1)),
+ 'Matches: {}'.format(len(mkpts0))
+ ]
+ k_thresh = matching.superpoint.config['keypoint_threshold']
+ m_thresh = matching.superglue.config['match_threshold']
+ small_text = [
+ 'Keypoint Threshold: {:.4f}'.format(k_thresh),
+ 'Match Threshold: {:.2f}'.format(m_thresh),
+ 'Image Pair: {:06}:{:06}'.format(stem0, stem1),
+ ]
+ out = make_matching_plot_fast(
+ last_frame, frame, kpts0, kpts1, mkpts0, mkpts1, color, text,
+ path=None, show_keypoints=opt.show_keypoints, small_text=small_text)
+
+ if not opt.no_display:
+ cv2.imshow('SuperGlue matches', out)
+ key = chr(cv2.waitKey(1) & 0xFF)
+ if key == 'q':
+ vs.cleanup()
+ print('Exiting (via q) demo_superglue.py')
+ break
+ elif key == 'n': # set the current frame as anchor
+ last_data = {k+'0': pred[k+'1'] for k in keys}
+ last_data['image0'] = frame_tensor
+ last_frame = frame
+ last_image_id = (vs.i - 1)
+ elif key in ['e', 'r']:
+ # Increase/decrease keypoint threshold by 10% each keypress.
+ d = 0.1 * (-1 if key == 'e' else 1)
+ matching.superpoint.config['keypoint_threshold'] = min(max(
+ 0.0001, matching.superpoint.config['keypoint_threshold']*(1+d)), 1)
+ print('\nChanged the keypoint threshold to {:.4f}'.format(
+ matching.superpoint.config['keypoint_threshold']))
+ elif key in ['d', 'f']:
+ # Increase/decrease match threshold by 0.05 each keypress.
+ d = 0.05 * (-1 if key == 'd' else 1)
+ matching.superglue.config['match_threshold'] = min(max(
+ 0.05, matching.superglue.config['match_threshold']+d), .95)
+ print('\nChanged the match threshold to {:.2f}'.format(
+ matching.superglue.config['match_threshold']))
+ elif key == 'k':
+ opt.show_keypoints = not opt.show_keypoints
+
+ timer.update('viz')
+ timer.print()
+
+ if opt.output_dir is not None:
+ #stem = 'matches_{:06}_{:06}'.format(last_image_id, vs.i-1)
+ stem = 'matches_{:06}_{:06}'.format(stem0, stem1)
+ out_file = str(Path(opt.output_dir, stem + '.png'))
+ print('\nWriting image to {}'.format(out_file))
+ cv2.imwrite(out_file, out)
+
+ cv2.destroyAllWindows()
+ vs.cleanup()
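Editor's note: the interactive loop above reduces to one Matching call per frame pair. Below is a minimal non-interactive sketch of the same flow; the image paths are hypothetical, and it assumes the pretrained SuperPoint/SuperGlue weight files are available, as the model constructors require.

import cv2
import torch

from models.matching import Matching
from models.utils import frame2tensor

device = 'cuda' if torch.cuda.is_available() else 'cpu'
matching = Matching({'superglue': {'weights': 'indoor'}}).eval().to(device)

# Hypothetical grayscale inputs; any two overlapping views will do.
frame0 = cv2.imread('anchor.png', cv2.IMREAD_GRAYSCALE)
frame1 = cv2.imread('query.png', cv2.IMREAD_GRAYSCALE)

with torch.no_grad():
    pred = matching({'image0': frame2tensor(frame0, device),
                     'image1': frame2tensor(frame1, device)})

matches = pred['matches0'][0].cpu().numpy()   # -1 marks unmatched keypoints
print('valid matches:', int((matches > -1).sum()))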
diff --git a/imcui/third_party/SuperGluePretrainedNetwork/match_pairs.py b/imcui/third_party/SuperGluePretrainedNetwork/match_pairs.py
new file mode 100644
index 0000000000000000000000000000000000000000..7079687cf69fd71d810ec80442548ad2a7b869e0
--- /dev/null
+++ b/imcui/third_party/SuperGluePretrainedNetwork/match_pairs.py
@@ -0,0 +1,425 @@
+#! /usr/bin/env python3
+#
+# %BANNER_BEGIN%
+# ---------------------------------------------------------------------
+# %COPYRIGHT_BEGIN%
+#
+# Magic Leap, Inc. ("COMPANY") CONFIDENTIAL
+#
+# Unpublished Copyright (c) 2020
+# Magic Leap, Inc., All Rights Reserved.
+#
+# NOTICE: All information contained herein is, and remains the property
+# of COMPANY. The intellectual and technical concepts contained herein
+# are proprietary to COMPANY and may be covered by U.S. and Foreign
+# Patents, patents in process, and are protected by trade secret or
+# copyright law. Dissemination of this information or reproduction of
+# this material is strictly forbidden unless prior written permission is
+# obtained from COMPANY. Access to the source code contained herein is
+# hereby forbidden to anyone except current COMPANY employees, managers
+# or contractors who have executed Confidentiality and Non-disclosure
+# agreements explicitly covering such access.
+#
+# The copyright notice above does not evidence any actual or intended
+# publication or disclosure of this source code, which includes
+# information that is confidential and/or proprietary, and is a trade
+# secret, of COMPANY. ANY REPRODUCTION, MODIFICATION, DISTRIBUTION,
+# PUBLIC PERFORMANCE, OR PUBLIC DISPLAY OF OR THROUGH USE OF THIS
+# SOURCE CODE WITHOUT THE EXPRESS WRITTEN CONSENT OF COMPANY IS
+# STRICTLY PROHIBITED, AND IN VIOLATION OF APPLICABLE LAWS AND
+# INTERNATIONAL TREATIES. THE RECEIPT OR POSSESSION OF THIS SOURCE
+# CODE AND/OR RELATED INFORMATION DOES NOT CONVEY OR IMPLY ANY RIGHTS
+# TO REPRODUCE, DISCLOSE OR DISTRIBUTE ITS CONTENTS, OR TO MANUFACTURE,
+# USE, OR SELL ANYTHING THAT IT MAY DESCRIBE, IN WHOLE OR IN PART.
+#
+# %COPYRIGHT_END%
+# ----------------------------------------------------------------------
+# %AUTHORS_BEGIN%
+#
+# Originating Authors: Paul-Edouard Sarlin
+# Daniel DeTone
+# Tomasz Malisiewicz
+#
+# %AUTHORS_END%
+# --------------------------------------------------------------------*/
+# %BANNER_END%
+
+from pathlib import Path
+import argparse
+import random
+import numpy as np
+import matplotlib.cm as cm
+import torch
+
+
+from models.matching import Matching
+from models.utils import (compute_pose_error, compute_epipolar_error,
+ estimate_pose, make_matching_plot,
+ error_colormap, AverageTimer, pose_auc, read_image,
+ rotate_intrinsics, rotate_pose_inplane,
+ scale_intrinsics)
+
+torch.set_grad_enabled(False)
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser(
+ description='Image pair matching and pose evaluation with SuperGlue',
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+
+ parser.add_argument(
+ '--input_pairs', type=str, default='assets/scannet_sample_pairs_with_gt.txt',
+ help='Path to the list of image pairs')
+ parser.add_argument(
+ '--input_dir', type=str, default='assets/scannet_sample_images/',
+ help='Path to the directory that contains the images')
+ parser.add_argument(
+ '--output_dir', type=str, default='dump_match_pairs/',
+ help='Path to the directory in which the .npz results and, optionally, '
+ 'the visualization images are written')
+
+ parser.add_argument(
+ '--max_length', type=int, default=-1,
+ help='Maximum number of pairs to evaluate')
+ parser.add_argument(
+ '--resize', type=int, nargs='+', default=[640, 480],
+ help='Resize the input image before running inference. If two numbers, '
+ 'resize to the exact dimensions; if one number, resize the max '
+ 'dimension; if -1, do not resize')
+ parser.add_argument(
+ '--resize_float', action='store_true',
+ help='Resize the image after casting uint8 to float')
+
+ parser.add_argument(
+ '--superglue', choices={'indoor', 'outdoor'}, default='indoor',
+ help='SuperGlue weights')
+ parser.add_argument(
+ '--max_keypoints', type=int, default=1024,
+ help='Maximum number of keypoints detected by SuperPoint'
+ ' (\'-1\' keeps all keypoints)')
+ parser.add_argument(
+ '--keypoint_threshold', type=float, default=0.005,
+ help='SuperPoint keypoint detector confidence threshold')
+ parser.add_argument(
+ '--nms_radius', type=int, default=4,
+ help='SuperPoint Non Maximum Suppression (NMS) radius'
+ ' (Must be positive)')
+ parser.add_argument(
+ '--sinkhorn_iterations', type=int, default=20,
+ help='Number of Sinkhorn iterations performed by SuperGlue')
+ parser.add_argument(
+ '--match_threshold', type=float, default=0.2,
+ help='SuperGlue match threshold')
+
+ parser.add_argument(
+ '--viz', action='store_true',
+ help='Visualize the matches and dump the plots')
+ parser.add_argument(
+ '--eval', action='store_true',
+ help='Perform the evaluation'
+ ' (requires ground truth pose and intrinsics)')
+ parser.add_argument(
+ '--fast_viz', action='store_true',
+ help='Use faster image visualization with OpenCV instead of Matplotlib')
+ parser.add_argument(
+ '--cache', action='store_true',
+ help='Skip the pair if output .npz files are already found')
+ parser.add_argument(
+ '--show_keypoints', action='store_true',
+ help='Plot the keypoints in addition to the matches')
+ parser.add_argument(
+ '--viz_extension', type=str, default='png', choices=['png', 'pdf'],
+ help='Visualization file extension. Use pdf for highest-quality.')
+ parser.add_argument(
+ '--opencv_display', action='store_true',
+ help='Visualize via OpenCV before saving output images')
+ parser.add_argument(
+ '--shuffle', action='store_true',
+ help='Shuffle ordering of pairs before processing')
+ parser.add_argument(
+ '--force_cpu', action='store_true',
+ help='Force pytorch to run in CPU mode.')
+
+ opt = parser.parse_args()
+ print(opt)
+
+ assert not (opt.opencv_display and not opt.viz), 'Must use --viz with --opencv_display'
+ assert not (opt.opencv_display and not opt.fast_viz), 'Cannot use --opencv_display without --fast_viz'
+ assert not (opt.fast_viz and not opt.viz), 'Must use --viz with --fast_viz'
+ assert not (opt.fast_viz and opt.viz_extension == 'pdf'), 'Cannot use pdf extension with --fast_viz'
+
+ if len(opt.resize) == 2 and opt.resize[1] == -1:
+ opt.resize = opt.resize[0:1]
+ if len(opt.resize) == 2:
+ print('Will resize to {}x{} (WxH)'.format(
+ opt.resize[0], opt.resize[1]))
+ elif len(opt.resize) == 1 and opt.resize[0] > 0:
+ print('Will resize max dimension to {}'.format(opt.resize[0]))
+ elif len(opt.resize) == 1:
+ print('Will not resize images')
+ else:
+ raise ValueError('Cannot specify more than two integers for --resize')
+
+ with open(opt.input_pairs, 'r') as f:
+ pairs = [l.split() for l in f.readlines()]
+
+ if opt.max_length > -1:
+ pairs = pairs[0:np.min([len(pairs), opt.max_length])]
+
+ if opt.shuffle:
+ random.Random(0).shuffle(pairs)
+
+ if opt.eval:
+ if not all([len(p) == 38 for p in pairs]):
+ raise ValueError(
+ 'All pairs should have ground truth info for evaluation. '
+ 'File \"{}\" needs 38 valid entries per row'.format(opt.input_pairs))
+
+ # Load the SuperPoint and SuperGlue models.
+ device = 'cuda' if torch.cuda.is_available() and not opt.force_cpu else 'cpu'
+ print('Running inference on device \"{}\"'.format(device))
+ config = {
+ 'superpoint': {
+ 'nms_radius': opt.nms_radius,
+ 'keypoint_threshold': opt.keypoint_threshold,
+ 'max_keypoints': opt.max_keypoints
+ },
+ 'superglue': {
+ 'weights': opt.superglue,
+ 'sinkhorn_iterations': opt.sinkhorn_iterations,
+ 'match_threshold': opt.match_threshold,
+ }
+ }
+ matching = Matching(config).eval().to(device)
+
+ # Create the output directories if they do not exist already.
+ input_dir = Path(opt.input_dir)
+ print('Looking for data in directory \"{}\"'.format(input_dir))
+ output_dir = Path(opt.output_dir)
+ output_dir.mkdir(exist_ok=True, parents=True)
+ print('Will write matches to directory \"{}\"'.format(output_dir))
+ if opt.eval:
+ print('Will write evaluation results',
+ 'to directory \"{}\"'.format(output_dir))
+ if opt.viz:
+ print('Will write visualization images to',
+ 'directory \"{}\"'.format(output_dir))
+
+ timer = AverageTimer(newline=True)
+ for i, pair in enumerate(pairs):
+ name0, name1 = pair[:2]
+ stem0, stem1 = Path(name0).stem, Path(name1).stem
+ matches_path = output_dir / '{}_{}_matches.npz'.format(stem0, stem1)
+ eval_path = output_dir / '{}_{}_evaluation.npz'.format(stem0, stem1)
+ viz_path = output_dir / '{}_{}_matches.{}'.format(stem0, stem1, opt.viz_extension)
+ viz_eval_path = output_dir / \
+ '{}_{}_evaluation.{}'.format(stem0, stem1, opt.viz_extension)
+
+ # Handle --cache logic.
+ do_match = True
+ do_eval = opt.eval
+ do_viz = opt.viz
+ do_viz_eval = opt.eval and opt.viz
+ if opt.cache:
+ if matches_path.exists():
+ try:
+ results = np.load(matches_path)
+ except Exception:
+ raise IOError('Cannot load matches .npz file: %s' %
+ matches_path)
+
+ kpts0, kpts1 = results['keypoints0'], results['keypoints1']
+ matches, conf = results['matches'], results['match_confidence']
+ do_match = False
+ if opt.eval and eval_path.exists():
+ try:
+ results = np.load(eval_path)
+ except Exception:
+ raise IOError('Cannot load eval .npz file: %s' % eval_path)
+ err_R, err_t = results['error_R'], results['error_t']
+ precision = results['precision']
+ matching_score = results['matching_score']
+ num_correct = results['num_correct']
+ epi_errs = results['epipolar_errors']
+ do_eval = False
+ if opt.viz and viz_path.exists():
+ do_viz = False
+ if opt.viz and opt.eval and viz_eval_path.exists():
+ do_viz_eval = False
+ timer.update('load_cache')
+
+ if not (do_match or do_eval or do_viz or do_viz_eval):
+ timer.print('Finished pair {:5} of {:5}'.format(i, len(pairs)))
+ continue
+
+ # If a rotation integer is provided (e.g. from EXIF data), use it:
+ if len(pair) >= 5:
+ rot0, rot1 = int(pair[2]), int(pair[3])
+ else:
+ rot0, rot1 = 0, 0
+
+ # Load the image pair.
+ image0, inp0, scales0 = read_image(
+ input_dir / name0, device, opt.resize, rot0, opt.resize_float)
+ image1, inp1, scales1 = read_image(
+ input_dir / name1, device, opt.resize, rot1, opt.resize_float)
+ if image0 is None or image1 is None:
+ print('Problem reading image pair: {} {}'.format(
+ input_dir/name0, input_dir/name1))
+ exit(1)
+ timer.update('load_image')
+
+ if do_match:
+ # Perform the matching.
+ pred = matching({'image0': inp0, 'image1': inp1})
+ pred = {k: v[0].cpu().numpy() for k, v in pred.items()}
+ kpts0, kpts1 = pred['keypoints0'], pred['keypoints1']
+ matches, conf = pred['matches0'], pred['matching_scores0']
+ timer.update('matcher')
+
+ # Write the matches to disk.
+ out_matches = {'keypoints0': kpts0, 'keypoints1': kpts1,
+ 'matches': matches, 'match_confidence': conf}
+ np.savez(str(matches_path), **out_matches)
+
+ # Keep the matching keypoints.
+ valid = matches > -1
+ mkpts0 = kpts0[valid]
+ mkpts1 = kpts1[matches[valid]]
+ mconf = conf[valid]
+
+ if do_eval:
+ # Estimate the pose and compute the pose error.
+ assert len(pair) == 38, 'Pair does not have ground truth info'
+ K0 = np.array(pair[4:13]).astype(float).reshape(3, 3)
+ K1 = np.array(pair[13:22]).astype(float).reshape(3, 3)
+ T_0to1 = np.array(pair[22:]).astype(float).reshape(4, 4)
+
+ # Scale the intrinsics to resized image.
+ K0 = scale_intrinsics(K0, scales0)
+ K1 = scale_intrinsics(K1, scales1)
+
+ # Update the intrinsics + extrinsics if EXIF rotation was found.
+ if rot0 != 0 or rot1 != 0:
+ cam0_T_w = np.eye(4)
+ cam1_T_w = T_0to1
+ if rot0 != 0:
+ K0 = rotate_intrinsics(K0, image0.shape, rot0)
+ cam0_T_w = rotate_pose_inplane(cam0_T_w, rot0)
+ if rot1 != 0:
+ K1 = rotate_intrinsics(K1, image1.shape, rot1)
+ cam1_T_w = rotate_pose_inplane(cam1_T_w, rot1)
+ cam1_T_cam0 = cam1_T_w @ np.linalg.inv(cam0_T_w)
+ T_0to1 = cam1_T_cam0
+
+ epi_errs = compute_epipolar_error(mkpts0, mkpts1, T_0to1, K0, K1)
+ correct = epi_errs < 5e-4
+ num_correct = np.sum(correct)
+ precision = np.mean(correct) if len(correct) > 0 else 0
+ matching_score = num_correct / len(kpts0) if len(kpts0) > 0 else 0
+
+ thresh = 1. # In pixels relative to resized image size.
+ ret = estimate_pose(mkpts0, mkpts1, K0, K1, thresh)
+ if ret is None:
+ err_t, err_R = np.inf, np.inf
+ else:
+ R, t, inliers = ret
+ err_t, err_R = compute_pose_error(T_0to1, R, t)
+
+ # Write the evaluation results to disk.
+ out_eval = {'error_t': err_t,
+ 'error_R': err_R,
+ 'precision': precision,
+ 'matching_score': matching_score,
+ 'num_correct': num_correct,
+ 'epipolar_errors': epi_errs}
+ np.savez(str(eval_path), **out_eval)
+ timer.update('eval')
+
+ if do_viz:
+ # Visualize the matches.
+ color = cm.jet(mconf)
+ text = [
+ 'SuperGlue',
+ 'Keypoints: {}:{}'.format(len(kpts0), len(kpts1)),
+ 'Matches: {}'.format(len(mkpts0)),
+ ]
+ if rot0 != 0 or rot1 != 0:
+ text.append('Rotation: {}:{}'.format(rot0, rot1))
+
+ # Display extra parameter info.
+ k_thresh = matching.superpoint.config['keypoint_threshold']
+ m_thresh = matching.superglue.config['match_threshold']
+ small_text = [
+ 'Keypoint Threshold: {:.4f}'.format(k_thresh),
+ 'Match Threshold: {:.2f}'.format(m_thresh),
+ 'Image Pair: {}:{}'.format(stem0, stem1),
+ ]
+
+ make_matching_plot(
+ image0, image1, kpts0, kpts1, mkpts0, mkpts1, color,
+ text, viz_path, opt.show_keypoints,
+ opt.fast_viz, opt.opencv_display, 'Matches', small_text)
+
+ timer.update('viz_match')
+
+ if do_viz_eval:
+ # Visualize the evaluation results for the image pair.
+ color = np.clip((epi_errs - 0) / (1e-3 - 0), 0, 1)
+ color = error_colormap(1 - color)
+ deg, delta = ' deg', 'Delta '
+ if not opt.fast_viz:
+ deg, delta = '°', '$\\Delta$'
+ e_t = 'FAIL' if np.isinf(err_t) else '{:.1f}{}'.format(err_t, deg)
+ e_R = 'FAIL' if np.isinf(err_R) else '{:.1f}{}'.format(err_R, deg)
+ text = [
+ 'SuperGlue',
+ '{}R: {}'.format(delta, e_R), '{}t: {}'.format(delta, e_t),
+ 'inliers: {}/{}'.format(num_correct, (matches > -1).sum()),
+ ]
+ if rot0 != 0 or rot1 != 0:
+ text.append('Rotation: {}:{}'.format(rot0, rot1))
+
+ # Display extra parameter info (only works with --fast_viz).
+ k_thresh = matching.superpoint.config['keypoint_threshold']
+ m_thresh = matching.superglue.config['match_threshold']
+ small_text = [
+ 'Keypoint Threshold: {:.4f}'.format(k_thresh),
+ 'Match Threshold: {:.2f}'.format(m_thresh),
+ 'Image Pair: {}:{}'.format(stem0, stem1),
+ ]
+
+ make_matching_plot(
+ image0, image1, kpts0, kpts1, mkpts0,
+ mkpts1, color, text, viz_eval_path,
+ opt.show_keypoints, opt.fast_viz,
+ opt.opencv_display, 'Relative Pose', small_text)
+
+ timer.update('viz_eval')
+
+ timer.print('Finished pair {:5} of {:5}'.format(i, len(pairs)))
+
+ if opt.eval:
+ # Collate the results into a final table and print to terminal.
+ pose_errors = []
+ precisions = []
+ matching_scores = []
+ for pair in pairs:
+ name0, name1 = pair[:2]
+ stem0, stem1 = Path(name0).stem, Path(name1).stem
+ eval_path = output_dir / \
+ '{}_{}_evaluation.npz'.format(stem0, stem1)
+ results = np.load(eval_path)
+ pose_error = np.maximum(results['error_t'], results['error_R'])
+ pose_errors.append(pose_error)
+ precisions.append(results['precision'])
+ matching_scores.append(results['matching_score'])
+ thresholds = [5, 10, 20]
+ aucs = pose_auc(pose_errors, thresholds)
+ aucs = [100.*yy for yy in aucs]
+ prec = 100.*np.mean(precisions)
+ ms = 100.*np.mean(matching_scores)
+ print('Evaluation Results (mean over {} pairs):'.format(len(pairs)))
+ print('AUC@5\t AUC@10\t AUC@20\t Prec\t MScore\t')
+ print('{:.2f}\t {:.2f}\t {:.2f}\t {:.2f}\t {:.2f}\t'.format(
+ aucs[0], aucs[1], aucs[2], prec, ms))
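Editor's note: the *_matches.npz files dumped above can be consumed independently of this script. A minimal sketch follows; the file name is hypothetical, and the keys match the out_matches dict written above.

import numpy as np

path = 'dump_match_pairs/img0_img1_matches.npz'   # hypothetical output file
results = np.load(path)
kpts0, kpts1 = results['keypoints0'], results['keypoints1']
matches, conf = results['matches'], results['match_confidence']

valid = matches > -1                      # -1 means "no match" for that keypoint
mkpts0 = kpts0[valid]
mkpts1 = kpts1[matches[valid]]
print('{} matches, mean confidence {:.3f}'.format(len(mkpts0), conf[valid].mean()))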
diff --git a/imcui/third_party/SuperGluePretrainedNetwork/models/__init__.py b/imcui/third_party/SuperGluePretrainedNetwork/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/imcui/third_party/SuperGluePretrainedNetwork/models/matching.py b/imcui/third_party/SuperGluePretrainedNetwork/models/matching.py
new file mode 100644
index 0000000000000000000000000000000000000000..5d174208d146373230a8a68dd1420fc59c180633
--- /dev/null
+++ b/imcui/third_party/SuperGluePretrainedNetwork/models/matching.py
@@ -0,0 +1,84 @@
+# %BANNER_BEGIN%
+# ---------------------------------------------------------------------
+# %COPYRIGHT_BEGIN%
+#
+# Magic Leap, Inc. ("COMPANY") CONFIDENTIAL
+#
+# Unpublished Copyright (c) 2020
+# Magic Leap, Inc., All Rights Reserved.
+#
+# NOTICE: All information contained herein is, and remains the property
+# of COMPANY. The intellectual and technical concepts contained herein
+# are proprietary to COMPANY and may be covered by U.S. and Foreign
+# Patents, patents in process, and are protected by trade secret or
+# copyright law. Dissemination of this information or reproduction of
+# this material is strictly forbidden unless prior written permission is
+# obtained from COMPANY. Access to the source code contained herein is
+# hereby forbidden to anyone except current COMPANY employees, managers
+# or contractors who have executed Confidentiality and Non-disclosure
+# agreements explicitly covering such access.
+#
+# The copyright notice above does not evidence any actual or intended
+# publication or disclosure of this source code, which includes
+# information that is confidential and/or proprietary, and is a trade
+# secret, of COMPANY. ANY REPRODUCTION, MODIFICATION, DISTRIBUTION,
+# PUBLIC PERFORMANCE, OR PUBLIC DISPLAY OF OR THROUGH USE OF THIS
+# SOURCE CODE WITHOUT THE EXPRESS WRITTEN CONSENT OF COMPANY IS
+# STRICTLY PROHIBITED, AND IN VIOLATION OF APPLICABLE LAWS AND
+# INTERNATIONAL TREATIES. THE RECEIPT OR POSSESSION OF THIS SOURCE
+# CODE AND/OR RELATED INFORMATION DOES NOT CONVEY OR IMPLY ANY RIGHTS
+# TO REPRODUCE, DISCLOSE OR DISTRIBUTE ITS CONTENTS, OR TO MANUFACTURE,
+# USE, OR SELL ANYTHING THAT IT MAY DESCRIBE, IN WHOLE OR IN PART.
+#
+# %COPYRIGHT_END%
+# ----------------------------------------------------------------------
+# %AUTHORS_BEGIN%
+#
+# Originating Authors: Paul-Edouard Sarlin
+#
+# %AUTHORS_END%
+# --------------------------------------------------------------------*/
+# %BANNER_END%
+
+import torch
+
+from .superpoint import SuperPoint
+from .superglue import SuperGlue
+
+
+class Matching(torch.nn.Module):
+ """ Image Matching Frontend (SuperPoint + SuperGlue) """
+ def __init__(self, config={}):
+ super().__init__()
+ self.superpoint = SuperPoint(config.get('superpoint', {}))
+ self.superglue = SuperGlue(config.get('superglue', {}))
+
+ def forward(self, data):
+ """ Run SuperPoint (optionally) and SuperGlue
+ SuperPoint is skipped if ['keypoints0', 'keypoints1'] exist in input
+ Args:
+ data: dictionary with minimal keys: ['image0', 'image1']
+ """
+ pred = {}
+
+ # Extract SuperPoint (keypoints, scores, descriptors) if not provided
+ if 'keypoints0' not in data:
+ pred0 = self.superpoint({'image': data['image0']})
+ pred = {**pred, **{k+'0': v for k, v in pred0.items()}}
+ if 'keypoints1' not in data:
+ pred1 = self.superpoint({'image': data['image1']})
+ pred = {**pred, **{k+'1': v for k, v in pred1.items()}}
+
+ # Batch all features
+ # We should either have i) one image per batch, or
+ # ii) the same number of local features for all images in the batch.
+ data = {**data, **pred}
+
+ for k in data:
+ if isinstance(data[k], (list, tuple)):
+ data[k] = torch.stack(data[k])
+
+ # Perform the matching
+ pred = {**pred, **self.superglue(data)}
+
+ return pred
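Editor's note: as the forward() docstring notes, SuperPoint is skipped for an image whose keypoints are already present in the input. A minimal sketch of reusing cached features for image0 is shown below; the tensors are random placeholders for illustration, and the weight files are assumed to be in place.

import torch

from models.matching import Matching

device = 'cpu'
matching = Matching().eval().to(device)

# Placeholder grayscale tensors in [0, 1], shape (B, 1, H, W).
image0 = torch.rand(1, 1, 480, 640, device=device)
image1 = torch.rand(1, 1, 480, 640, device=device)

with torch.no_grad():
    feats0 = matching.superpoint({'image': image0})       # run SuperPoint once
    data = {k + '0': v for k, v in feats0.items()}        # keypoints0, scores0, descriptors0
    data.update({'image0': image0, 'image1': image1})
    pred = matching(data)                                  # SuperPoint now runs only on image1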
diff --git a/imcui/third_party/SuperGluePretrainedNetwork/models/superglue.py b/imcui/third_party/SuperGluePretrainedNetwork/models/superglue.py
new file mode 100644
index 0000000000000000000000000000000000000000..5a89b0348075bcb918eab123bc988c7102137a3d
--- /dev/null
+++ b/imcui/third_party/SuperGluePretrainedNetwork/models/superglue.py
@@ -0,0 +1,285 @@
+# %BANNER_BEGIN%
+# ---------------------------------------------------------------------
+# %COPYRIGHT_BEGIN%
+#
+# Magic Leap, Inc. ("COMPANY") CONFIDENTIAL
+#
+# Unpublished Copyright (c) 2020
+# Magic Leap, Inc., All Rights Reserved.
+#
+# NOTICE: All information contained herein is, and remains the property
+# of COMPANY. The intellectual and technical concepts contained herein
+# are proprietary to COMPANY and may be covered by U.S. and Foreign
+# Patents, patents in process, and are protected by trade secret or
+# copyright law. Dissemination of this information or reproduction of
+# this material is strictly forbidden unless prior written permission is
+# obtained from COMPANY. Access to the source code contained herein is
+# hereby forbidden to anyone except current COMPANY employees, managers
+# or contractors who have executed Confidentiality and Non-disclosure
+# agreements explicitly covering such access.
+#
+# The copyright notice above does not evidence any actual or intended
+# publication or disclosure of this source code, which includes
+# information that is confidential and/or proprietary, and is a trade
+# secret, of COMPANY. ANY REPRODUCTION, MODIFICATION, DISTRIBUTION,
+# PUBLIC PERFORMANCE, OR PUBLIC DISPLAY OF OR THROUGH USE OF THIS
+# SOURCE CODE WITHOUT THE EXPRESS WRITTEN CONSENT OF COMPANY IS
+# STRICTLY PROHIBITED, AND IN VIOLATION OF APPLICABLE LAWS AND
+# INTERNATIONAL TREATIES. THE RECEIPT OR POSSESSION OF THIS SOURCE
+# CODE AND/OR RELATED INFORMATION DOES NOT CONVEY OR IMPLY ANY RIGHTS
+# TO REPRODUCE, DISCLOSE OR DISTRIBUTE ITS CONTENTS, OR TO MANUFACTURE,
+# USE, OR SELL ANYTHING THAT IT MAY DESCRIBE, IN WHOLE OR IN PART.
+#
+# %COPYRIGHT_END%
+# ----------------------------------------------------------------------
+# %AUTHORS_BEGIN%
+#
+# Originating Authors: Paul-Edouard Sarlin
+#
+# %AUTHORS_END%
+# --------------------------------------------------------------------*/
+# %BANNER_END%
+
+from copy import deepcopy
+from pathlib import Path
+from typing import List, Tuple
+
+import torch
+from torch import nn
+
+
+def MLP(channels: List[int], do_bn: bool = True) -> nn.Module:
+ """ Multi-layer perceptron """
+ n = len(channels)
+ layers = []
+ for i in range(1, n):
+ layers.append(
+ nn.Conv1d(channels[i - 1], channels[i], kernel_size=1, bias=True))
+ if i < (n-1):
+ if do_bn:
+ layers.append(nn.BatchNorm1d(channels[i]))
+ layers.append(nn.ReLU())
+ return nn.Sequential(*layers)
+
+
+def normalize_keypoints(kpts, image_shape):
+ """ Normalize keypoints locations based on image image_shape"""
+ _, _, height, width = image_shape
+ one = kpts.new_tensor(1)
+ size = torch.stack([one*width, one*height])[None]
+ center = size / 2
+ scaling = size.max(1, keepdim=True).values * 0.7
+ return (kpts - center[:, None, :]) / scaling[:, None, :]
+
+
+class KeypointEncoder(nn.Module):
+ """ Joint encoding of visual appearance and location using MLPs"""
+ def __init__(self, feature_dim: int, layers: List[int]) -> None:
+ super().__init__()
+ self.encoder = MLP([3] + layers + [feature_dim])
+ nn.init.constant_(self.encoder[-1].bias, 0.0)
+
+ def forward(self, kpts, scores):
+ inputs = [kpts.transpose(1, 2), scores.unsqueeze(1)]
+ return self.encoder(torch.cat(inputs, dim=1))
+
+
+def attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> Tuple[torch.Tensor,torch.Tensor]:
+ dim = query.shape[1]
+ scores = torch.einsum('bdhn,bdhm->bhnm', query, key) / dim**.5
+ prob = torch.nn.functional.softmax(scores, dim=-1)
+ return torch.einsum('bhnm,bdhm->bdhn', prob, value), prob
+
+
+class MultiHeadedAttention(nn.Module):
+ """ Multi-head attention to increase model expressivitiy """
+ def __init__(self, num_heads: int, d_model: int):
+ super().__init__()
+ assert d_model % num_heads == 0
+ self.dim = d_model // num_heads
+ self.num_heads = num_heads
+ self.merge = nn.Conv1d(d_model, d_model, kernel_size=1)
+ self.proj = nn.ModuleList([deepcopy(self.merge) for _ in range(3)])
+
+ def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
+ batch_dim = query.size(0)
+ query, key, value = [l(x).view(batch_dim, self.dim, self.num_heads, -1)
+ for l, x in zip(self.proj, (query, key, value))]
+ x, _ = attention(query, key, value)
+ return self.merge(x.contiguous().view(batch_dim, self.dim*self.num_heads, -1))
+
+
+class AttentionalPropagation(nn.Module):
+ def __init__(self, feature_dim: int, num_heads: int):
+ super().__init__()
+ self.attn = MultiHeadedAttention(num_heads, feature_dim)
+ self.mlp = MLP([feature_dim*2, feature_dim*2, feature_dim])
+ nn.init.constant_(self.mlp[-1].bias, 0.0)
+
+ def forward(self, x: torch.Tensor, source: torch.Tensor) -> torch.Tensor:
+ message = self.attn(x, source, source)
+ return self.mlp(torch.cat([x, message], dim=1))
+
+
+class AttentionalGNN(nn.Module):
+ def __init__(self, feature_dim: int, layer_names: List[str]) -> None:
+ super().__init__()
+ self.layers = nn.ModuleList([
+ AttentionalPropagation(feature_dim, 4)
+ for _ in range(len(layer_names))])
+ self.names = layer_names
+
+ def forward(self, desc0: torch.Tensor, desc1: torch.Tensor) -> Tuple[torch.Tensor,torch.Tensor]:
+ for layer, name in zip(self.layers, self.names):
+ if name == 'cross':
+ src0, src1 = desc1, desc0
+ else: # if name == 'self':
+ src0, src1 = desc0, desc1
+ delta0, delta1 = layer(desc0, src0), layer(desc1, src1)
+ desc0, desc1 = (desc0 + delta0), (desc1 + delta1)
+ return desc0, desc1
+
+
+def log_sinkhorn_iterations(Z: torch.Tensor, log_mu: torch.Tensor, log_nu: torch.Tensor, iters: int) -> torch.Tensor:
+ """ Perform Sinkhorn Normalization in Log-space for stability"""
+ u, v = torch.zeros_like(log_mu), torch.zeros_like(log_nu)
+ for _ in range(iters):
+ u = log_mu - torch.logsumexp(Z + v.unsqueeze(1), dim=2)
+ v = log_nu - torch.logsumexp(Z + u.unsqueeze(2), dim=1)
+ return Z + u.unsqueeze(2) + v.unsqueeze(1)
+
+
+def log_optimal_transport(scores: torch.Tensor, alpha: torch.Tensor, iters: int) -> torch.Tensor:
+ """ Perform Differentiable Optimal Transport in Log-space for stability"""
+ b, m, n = scores.shape
+ one = scores.new_tensor(1)
+ ms, ns = (m*one).to(scores), (n*one).to(scores)
+
+ bins0 = alpha.expand(b, m, 1)
+ bins1 = alpha.expand(b, 1, n)
+ alpha = alpha.expand(b, 1, 1)
+
+ couplings = torch.cat([torch.cat([scores, bins0], -1),
+ torch.cat([bins1, alpha], -1)], 1)
+
+ norm = - (ms + ns).log()
+ log_mu = torch.cat([norm.expand(m), ns.log()[None] + norm])
+ log_nu = torch.cat([norm.expand(n), ms.log()[None] + norm])
+ log_mu, log_nu = log_mu[None].expand(b, -1), log_nu[None].expand(b, -1)
+
+ Z = log_sinkhorn_iterations(couplings, log_mu, log_nu, iters)
+ Z = Z - norm # multiply probabilities by M+N
+ return Z
+
+
+def arange_like(x, dim: int):
+ return x.new_ones(x.shape[dim]).cumsum(0) - 1 # traceable in 1.1
+
+
+class SuperGlue(nn.Module):
+ """SuperGlue feature matching middle-end
+
+ Given two sets of keypoints and locations, we determine the
+ correspondences by:
+ 1. Keypoint Encoding (normalization + visual feature and location fusion)
+ 2. Graph Neural Network with multiple self and cross-attention layers
+ 3. Final projection layer
+ 4. Optimal Transport Layer (a differentiable Hungarian matching algorithm)
+ 5. Thresholding matrix based on mutual exclusivity and a match_threshold
+
+ The correspondence ids use -1 to indicate non-matching points.
+
+ Paul-Edouard Sarlin, Daniel DeTone, Tomasz Malisiewicz, and Andrew
+ Rabinovich. SuperGlue: Learning Feature Matching with Graph Neural
+ Networks. In CVPR, 2020. https://arxiv.org/abs/1911.11763
+
+ """
+ default_config = {
+ 'descriptor_dim': 256,
+ 'weights': 'indoor',
+ 'keypoint_encoder': [32, 64, 128, 256],
+ 'GNN_layers': ['self', 'cross'] * 9,
+ 'sinkhorn_iterations': 100,
+ 'match_threshold': 0.2,
+ }
+
+ def __init__(self, config):
+ super().__init__()
+ self.config = {**self.default_config, **config}
+
+ self.kenc = KeypointEncoder(
+ self.config['descriptor_dim'], self.config['keypoint_encoder'])
+
+ self.gnn = AttentionalGNN(
+ feature_dim=self.config['descriptor_dim'], layer_names=self.config['GNN_layers'])
+
+ self.final_proj = nn.Conv1d(
+ self.config['descriptor_dim'], self.config['descriptor_dim'],
+ kernel_size=1, bias=True)
+
+ bin_score = torch.nn.Parameter(torch.tensor(1.))
+ self.register_parameter('bin_score', bin_score)
+
+ assert self.config['weights'] in ['indoor', 'outdoor']
+ path = Path(__file__).parent
+ path = path / 'weights/superglue_{}.pth'.format(self.config['weights'])
+ self.load_state_dict(torch.load(str(path)))
+ print('Loaded SuperGlue model (\"{}\" weights)'.format(
+ self.config['weights']))
+
+ def forward(self, data):
+ """Run SuperGlue on a pair of keypoints and descriptors"""
+ desc0, desc1 = data['descriptors0'], data['descriptors1']
+ kpts0, kpts1 = data['keypoints0'], data['keypoints1']
+
+ if kpts0.shape[1] == 0 or kpts1.shape[1] == 0: # no keypoints
+ shape0, shape1 = kpts0.shape[:-1], kpts1.shape[:-1]
+ return {
+ 'matches0': kpts0.new_full(shape0, -1, dtype=torch.int),
+ 'matches1': kpts1.new_full(shape1, -1, dtype=torch.int),
+ 'matching_scores0': kpts0.new_zeros(shape0),
+ 'matching_scores1': kpts1.new_zeros(shape1),
+ }
+
+ # Keypoint normalization.
+ kpts0 = normalize_keypoints(kpts0, data['image0'].shape)
+ kpts1 = normalize_keypoints(kpts1, data['image1'].shape)
+
+ # Keypoint MLP encoder.
+ desc0 = desc0 + self.kenc(kpts0, data['scores0'])
+ desc1 = desc1 + self.kenc(kpts1, data['scores1'])
+
+ # Multi-layer Transformer network.
+ desc0, desc1 = self.gnn(desc0, desc1)
+
+ # Final MLP projection.
+ mdesc0, mdesc1 = self.final_proj(desc0), self.final_proj(desc1)
+
+ # Compute matching descriptor distance.
+ scores = torch.einsum('bdn,bdm->bnm', mdesc0, mdesc1)
+ scores = scores / self.config['descriptor_dim']**.5
+
+ # Run the optimal transport.
+ scores = log_optimal_transport(
+ scores, self.bin_score,
+ iters=self.config['sinkhorn_iterations'])
+
+ # Get the matches with score above "match_threshold".
+ max0, max1 = scores[:, :-1, :-1].max(2), scores[:, :-1, :-1].max(1)
+ indices0, indices1 = max0.indices, max1.indices
+ mutual0 = arange_like(indices0, 1)[None] == indices1.gather(1, indices0)
+ mutual1 = arange_like(indices1, 1)[None] == indices0.gather(1, indices1)
+ zero = scores.new_tensor(0)
+ mscores0 = torch.where(mutual0, max0.values.exp(), zero)
+ mscores1 = torch.where(mutual1, mscores0.gather(1, indices1), zero)
+ valid0 = mutual0 & (mscores0 > self.config['match_threshold'])
+ valid1 = mutual1 & valid0.gather(1, indices1)
+ indices0 = torch.where(valid0, indices0, indices0.new_tensor(-1))
+ indices1 = torch.where(valid1, indices1, indices1.new_tensor(-1))
+
+ return {
+ 'matches0': indices0, # use -1 for invalid match
+ 'matches1': indices1, # use -1 for invalid match
+ 'matching_scores0': mscores0,
+ 'matching_scores1': mscores1,
+ }
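Editor's note: log_optimal_transport above augments the score matrix with a dustbin row and column and runs Sinkhorn in log space; the final "- norm" rescales the transport plan so its total mass is m + n rather than 1. A standalone toy sketch with illustrative numbers:

import torch

from models.superglue import log_optimal_transport

scores = torch.randn(1, 4, 6)        # 4 keypoints in image0, 6 in image1
bin_score = torch.tensor(1.0)        # learnable parameter in SuperGlue; fixed here

Z = log_optimal_transport(scores, bin_score, iters=20)
print(Z.shape)                       # torch.Size([1, 5, 7]): one extra dustbin row/column
P = Z.exp()
print(float(P.sum()))                # ~10.0, i.e. m + n, per the "multiply by M+N" comment above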
diff --git a/imcui/third_party/SuperGluePretrainedNetwork/models/superpoint.py b/imcui/third_party/SuperGluePretrainedNetwork/models/superpoint.py
new file mode 100644
index 0000000000000000000000000000000000000000..0577e1ec47c3397e45bc9a3cf2e47f211c32877e
--- /dev/null
+++ b/imcui/third_party/SuperGluePretrainedNetwork/models/superpoint.py
@@ -0,0 +1,206 @@
+# %BANNER_BEGIN%
+# ---------------------------------------------------------------------
+# %COPYRIGHT_BEGIN%
+#
+# Magic Leap, Inc. ("COMPANY") CONFIDENTIAL
+#
+# Unpublished Copyright (c) 2020
+# Magic Leap, Inc., All Rights Reserved.
+#
+# NOTICE: All information contained herein is, and remains the property
+# of COMPANY. The intellectual and technical concepts contained herein
+# are proprietary to COMPANY and may be covered by U.S. and Foreign
+# Patents, patents in process, and are protected by trade secret or
+# copyright law. Dissemination of this information or reproduction of
+# this material is strictly forbidden unless prior written permission is
+# obtained from COMPANY. Access to the source code contained herein is
+# hereby forbidden to anyone except current COMPANY employees, managers
+# or contractors who have executed Confidentiality and Non-disclosure
+# agreements explicitly covering such access.
+#
+# The copyright notice above does not evidence any actual or intended
+# publication or disclosure of this source code, which includes
+# information that is confidential and/or proprietary, and is a trade
+# secret, of COMPANY. ANY REPRODUCTION, MODIFICATION, DISTRIBUTION,
+# PUBLIC PERFORMANCE, OR PUBLIC DISPLAY OF OR THROUGH USE OF THIS
+# SOURCE CODE WITHOUT THE EXPRESS WRITTEN CONSENT OF COMPANY IS
+# STRICTLY PROHIBITED, AND IN VIOLATION OF APPLICABLE LAWS AND
+# INTERNATIONAL TREATIES. THE RECEIPT OR POSSESSION OF THIS SOURCE
+# CODE AND/OR RELATED INFORMATION DOES NOT CONVEY OR IMPLY ANY RIGHTS
+# TO REPRODUCE, DISCLOSE OR DISTRIBUTE ITS CONTENTS, OR TO MANUFACTURE,
+# USE, OR SELL ANYTHING THAT IT MAY DESCRIBE, IN WHOLE OR IN PART.
+#
+# %COPYRIGHT_END%
+# ----------------------------------------------------------------------
+# %AUTHORS_BEGIN%
+#
+# Originating Authors: Paul-Edouard Sarlin
+#
+# %AUTHORS_END%
+# --------------------------------------------------------------------*/
+# %BANNER_END%
+
+from pathlib import Path
+import torch
+from torch import nn
+
+def simple_nms(scores, nms_radius: int):
+ """ Fast Non-maximum suppression to remove nearby points """
+ assert(nms_radius >= 0)
+
+ def max_pool(x):
+ return torch.nn.functional.max_pool2d(
+ x, kernel_size=nms_radius*2+1, stride=1, padding=nms_radius)
+
+ zeros = torch.zeros_like(scores)
+ max_mask = scores == max_pool(scores)
+ for _ in range(2):
+ supp_mask = max_pool(max_mask.float()) > 0
+ supp_scores = torch.where(supp_mask, zeros, scores)
+ new_max_mask = supp_scores == max_pool(supp_scores)
+ max_mask = max_mask | (new_max_mask & (~supp_mask))
+ return torch.where(max_mask, scores, zeros)
+
+
+def remove_borders(keypoints, scores, border: int, height: int, width: int):
+ """ Removes keypoints too close to the border """
+ mask_h = (keypoints[:, 0] >= border) & (keypoints[:, 0] < (height - border))
+ mask_w = (keypoints[:, 1] >= border) & (keypoints[:, 1] < (width - border))
+ mask = mask_h & mask_w
+ return keypoints[mask], scores[mask]
+
+
+def top_k_keypoints(keypoints, scores, k: int):
+ if k >= len(keypoints):
+ return keypoints, scores
+ scores, indices = torch.topk(scores, k, dim=0)
+ return keypoints[indices], scores
+
+
+def sample_descriptors(keypoints, descriptors, s: int = 8):
+ """ Interpolate descriptors at keypoint locations """
+ b, c, h, w = descriptors.shape
+ keypoints = keypoints - s / 2 + 0.5
+ keypoints /= torch.tensor([(w*s - s/2 - 0.5), (h*s - s/2 - 0.5)],
+ ).to(keypoints)[None]
+ keypoints = keypoints*2 - 1 # normalize to (-1, 1)
+ major, minor = (int(v) for v in torch.__version__.split('.')[:2])
+ args = {'align_corners': True} if (major, minor) >= (1, 3) else {}
+ descriptors = torch.nn.functional.grid_sample(
+ descriptors, keypoints.view(b, 1, -1, 2), mode='bilinear', **args)
+ descriptors = torch.nn.functional.normalize(
+ descriptors.reshape(b, c, -1), p=2, dim=1)
+ return descriptors
+
+
+class SuperPoint(nn.Module):
+ """SuperPoint Convolutional Detector and Descriptor
+
+ SuperPoint: Self-Supervised Interest Point Detection and
+ Description. Daniel DeTone, Tomasz Malisiewicz, and Andrew
+ Rabinovich. In CVPRW, 2019. https://arxiv.org/abs/1712.07629
+
+ """
+ default_config = {
+ 'descriptor_dim': 256,
+ 'nms_radius': 4,
+ 'keypoint_threshold': 0.005,
+ 'max_keypoints': -1,
+ 'remove_borders': 4,
+ }
+
+ def __init__(self, config):
+ super().__init__()
+ self.config = {**self.default_config, **config}
+
+ self.relu = nn.ReLU(inplace=True)
+ self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
+ c1, c2, c3, c4, c5 = 64, 64, 128, 128, 256
+
+ self.conv1a = nn.Conv2d(1, c1, kernel_size=3, stride=1, padding=1)
+ self.conv1b = nn.Conv2d(c1, c1, kernel_size=3, stride=1, padding=1)
+ self.conv2a = nn.Conv2d(c1, c2, kernel_size=3, stride=1, padding=1)
+ self.conv2b = nn.Conv2d(c2, c2, kernel_size=3, stride=1, padding=1)
+ self.conv3a = nn.Conv2d(c2, c3, kernel_size=3, stride=1, padding=1)
+ self.conv3b = nn.Conv2d(c3, c3, kernel_size=3, stride=1, padding=1)
+ self.conv4a = nn.Conv2d(c3, c4, kernel_size=3, stride=1, padding=1)
+ self.conv4b = nn.Conv2d(c4, c4, kernel_size=3, stride=1, padding=1)
+
+ self.convPa = nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1)
+ self.convPb = nn.Conv2d(c5, 65, kernel_size=1, stride=1, padding=0)
+
+ self.convDa = nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1)
+ self.convDb = nn.Conv2d(
+ c5, self.config['descriptor_dim'],
+ kernel_size=1, stride=1, padding=0)
+
+ path = Path(__file__).parent / 'weights/superpoint_v1.pth'
+ self.load_state_dict(torch.load(str(path)))
+
+ mk = self.config['max_keypoints']
+ if mk == 0 or mk < -1:
+ raise ValueError('\"max_keypoints\" must be positive or \"-1\"')
+
+ print('Loaded SuperPoint model')
+
+ def forward(self, data, cfg={}):
+ """Compute keypoints, scores, descriptors for image"""
+ self.config = {
+ **self.config,
+ **cfg,
+ }
+ # Shared Encoder
+ x = self.relu(self.conv1a(data['image']))
+ x = self.relu(self.conv1b(x))
+ x = self.pool(x)
+ x = self.relu(self.conv2a(x))
+ x = self.relu(self.conv2b(x))
+ x = self.pool(x)
+ x = self.relu(self.conv3a(x))
+ x = self.relu(self.conv3b(x))
+ x = self.pool(x)
+ x = self.relu(self.conv4a(x))
+ x = self.relu(self.conv4b(x))
+
+ # Compute the dense keypoint scores
+ cPa = self.relu(self.convPa(x))
+ scores = self.convPb(cPa)
+ scores = torch.nn.functional.softmax(scores, 1)[:, :-1]
+ b, _, h, w = scores.shape
+ scores = scores.permute(0, 2, 3, 1).reshape(b, h, w, 8, 8)
+ scores = scores.permute(0, 1, 3, 2, 4).reshape(b, h*8, w*8)
+ scores = simple_nms(scores, self.config['nms_radius'])
+
+ # Extract keypoints
+ keypoints = [
+ torch.nonzero(s > self.config['keypoint_threshold'])
+ for s in scores]
+ scores = [s[tuple(k.t())] for s, k in zip(scores, keypoints)]
+
+ # Discard keypoints near the image borders
+ keypoints, scores = list(zip(*[
+ remove_borders(k, s, self.config['remove_borders'], h*8, w*8)
+ for k, s in zip(keypoints, scores)]))
+
+ # Keep the k keypoints with highest score
+ if self.config['max_keypoints'] >= 0:
+ keypoints, scores = list(zip(*[
+ top_k_keypoints(k, s, self.config['max_keypoints'])
+ for k, s in zip(keypoints, scores)]))
+
+ # Convert (h, w) to (x, y)
+ keypoints = [torch.flip(k, [1]).float() for k in keypoints]
+
+ # Compute the dense descriptors
+ cDa = self.relu(self.convDa(x))
+ descriptors = self.convDb(cDa)
+ descriptors = torch.nn.functional.normalize(descriptors, p=2, dim=1)
+
+ # Extract descriptors
+ descriptors = [sample_descriptors(k[None], d[None], 8)[0]
+ for k, d in zip(keypoints, descriptors)]
+
+ return {
+ 'keypoints': keypoints,
+ 'scores': scores,
+ 'descriptors': descriptors,
+ }
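Editor's note: SuperPoint can also be run on its own, e.g. to inspect the keypoints fed to SuperGlue. A minimal sketch with a placeholder input follows; it assumes weights/superpoint_v1.pth is present, as the constructor above requires.

import torch

from models.superpoint import SuperPoint

sp = SuperPoint({'max_keypoints': 512}).eval()

# Placeholder grayscale image in [0, 1], shape (B, 1, H, W).
image = torch.rand(1, 1, 480, 640)

with torch.no_grad():
    out = sp({'image': image})

print(out['keypoints'][0].shape)     # (N, 2) keypoints in (x, y) order
print(out['descriptors'][0].shape)   # (256, N) L2-normalized descriptors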
diff --git a/imcui/third_party/SuperGluePretrainedNetwork/models/utils.py b/imcui/third_party/SuperGluePretrainedNetwork/models/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..1206244aa2a004d9f653782de798bfef9e5e726b
--- /dev/null
+++ b/imcui/third_party/SuperGluePretrainedNetwork/models/utils.py
@@ -0,0 +1,555 @@
+# %BANNER_BEGIN%
+# ---------------------------------------------------------------------
+# %COPYRIGHT_BEGIN%
+#
+# Magic Leap, Inc. ("COMPANY") CONFIDENTIAL
+#
+# Unpublished Copyright (c) 2020
+# Magic Leap, Inc., All Rights Reserved.
+#
+# NOTICE: All information contained herein is, and remains the property
+# of COMPANY. The intellectual and technical concepts contained herein
+# are proprietary to COMPANY and may be covered by U.S. and Foreign
+# Patents, patents in process, and are protected by trade secret or
+# copyright law. Dissemination of this information or reproduction of
+# this material is strictly forbidden unless prior written permission is
+# obtained from COMPANY. Access to the source code contained herein is
+# hereby forbidden to anyone except current COMPANY employees, managers
+# or contractors who have executed Confidentiality and Non-disclosure
+# agreements explicitly covering such access.
+#
+# The copyright notice above does not evidence any actual or intended
+# publication or disclosure of this source code, which includes
+# information that is confidential and/or proprietary, and is a trade
+# secret, of COMPANY. ANY REPRODUCTION, MODIFICATION, DISTRIBUTION,
+# PUBLIC PERFORMANCE, OR PUBLIC DISPLAY OF OR THROUGH USE OF THIS
+# SOURCE CODE WITHOUT THE EXPRESS WRITTEN CONSENT OF COMPANY IS
+# STRICTLY PROHIBITED, AND IN VIOLATION OF APPLICABLE LAWS AND
+# INTERNATIONAL TREATIES. THE RECEIPT OR POSSESSION OF THIS SOURCE
+# CODE AND/OR RELATED INFORMATION DOES NOT CONVEY OR IMPLY ANY RIGHTS
+# TO REPRODUCE, DISCLOSE OR DISTRIBUTE ITS CONTENTS, OR TO MANUFACTURE,
+# USE, OR SELL ANYTHING THAT IT MAY DESCRIBE, IN WHOLE OR IN PART.
+#
+# %COPYRIGHT_END%
+# ----------------------------------------------------------------------
+# %AUTHORS_BEGIN%
+#
+# Originating Authors: Paul-Edouard Sarlin
+# Daniel DeTone
+# Tomasz Malisiewicz
+#
+# %AUTHORS_END%
+# --------------------------------------------------------------------*/
+# %BANNER_END%
+
+from pathlib import Path
+import time
+from collections import OrderedDict
+from threading import Thread
+import numpy as np
+import cv2
+import torch
+import matplotlib.pyplot as plt
+import matplotlib
+matplotlib.use('Agg')
+
+
+class AverageTimer:
+ """ Class to help manage printing simple timing of code execution. """
+
+ def __init__(self, smoothing=0.3, newline=False):
+ self.smoothing = smoothing
+ self.newline = newline
+ self.times = OrderedDict()
+ self.will_print = OrderedDict()
+ self.reset()
+
+ def reset(self):
+ now = time.time()
+ self.start = now
+ self.last_time = now
+ for name in self.will_print:
+ self.will_print[name] = False
+
+ def update(self, name='default'):
+ now = time.time()
+ dt = now - self.last_time
+ if name in self.times:
+ dt = self.smoothing * dt + (1 - self.smoothing) * self.times[name]
+ self.times[name] = dt
+ self.will_print[name] = True
+ self.last_time = now
+
+ def print(self, text='Timer'):
+ total = 0.
+ print('[{}]'.format(text), end=' ')
+ for key in self.times:
+ val = self.times[key]
+ if self.will_print[key]:
+ print('%s=%.3f' % (key, val), end=' ')
+ total += val
+ print('total=%.3f sec {%.1f FPS}' % (total, 1./total), end=' ')
+ if self.newline:
+ print(flush=True)
+ else:
+ print(end='\r', flush=True)
+ self.reset()
+
+
+class VideoStreamer:
+ """ Class to help process image streams. Four types of possible inputs:"
+ 1.) USB Webcam.
+ 2.) An IP camera
+ 3.) A directory of images (files in directory matching 'image_glob').
+ 4.) A video file, such as an .mp4 or .avi file.
+ """
+ def __init__(self, basedir, resize, skip, image_glob, max_length=1000000):
+ self._ip_grabbed = False
+ self._ip_running = False
+ self._ip_camera = False
+ self._ip_image = None
+ self._ip_index = 0
+ self.cap = []
+ self.camera = True
+ self.video_file = False
+ self.listing = []
+ self.resize = resize
+ self.interp = cv2.INTER_AREA
+ self.i = 0
+ self.skip = skip
+ self.max_length = max_length
+ if isinstance(basedir, int) or basedir.isdigit():
+ print('==> Processing USB webcam input: {}'.format(basedir))
+ self.cap = cv2.VideoCapture(int(basedir))
+ self.listing = range(0, self.max_length)
+ elif basedir.startswith(('http', 'rtsp')):
+ print('==> Processing IP camera input: {}'.format(basedir))
+ self.cap = cv2.VideoCapture(basedir)
+ self.start_ip_camera_thread()
+ self._ip_camera = True
+ self.listing = range(0, self.max_length)
+ elif Path(basedir).is_dir():
+ print('==> Processing image directory input: {}'.format(basedir))
+ self.listing = list(Path(basedir).glob(image_glob[0]))
+ for j in range(1, len(image_glob)):
+ image_path = list(Path(basedir).glob(image_glob[j]))
+ self.listing = self.listing + image_path
+ self.listing.sort()
+ self.listing = self.listing[::self.skip]
+ self.max_length = np.min([self.max_length, len(self.listing)])
+ if self.max_length == 0:
+ raise IOError('No images found (maybe bad \'image_glob\' ?)')
+ self.listing = self.listing[:self.max_length]
+ self.camera = False
+ elif Path(basedir).exists():
+ print('==> Processing video input: {}'.format(basedir))
+ self.cap = cv2.VideoCapture(basedir)
+ self.cap.set(cv2.CAP_PROP_BUFFERSIZE, 1)
+ num_frames = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT))
+ self.listing = range(0, num_frames)
+ self.listing = self.listing[::self.skip]
+ self.video_file = True
+ self.max_length = np.min([self.max_length, len(self.listing)])
+ self.listing = self.listing[:self.max_length]
+ else:
+ raise ValueError('VideoStreamer input \"{}\" not recognized.'.format(basedir))
+ if self.camera and not self.cap.isOpened():
+ raise IOError('Could not read camera')
+
+ def load_image(self, impath):
+ """ Read image as grayscale and resize to img_size.
+ Inputs
+ impath: Path to input image.
+ Returns
+ grayim: uint8 numpy array sized H x W.
+ """
+ grayim = cv2.imread(impath, 0)
+ if grayim is None:
+ raise Exception('Error reading image %s' % impath)
+ w, h = grayim.shape[1], grayim.shape[0]
+ w_new, h_new = process_resize(w, h, self.resize)
+ grayim = cv2.resize(
+ grayim, (w_new, h_new), interpolation=self.interp)
+ return grayim
+
+ def next_frame(self):
+ """ Return the next frame, and increment internal counter.
+ Returns
+ image: Next H x W image.
+ status: True or False depending on whether the image was loaded.
+ """
+
+ if self.i == self.max_length:
+ return (None, False)
+ if self.camera:
+
+ if self._ip_camera:
+ # Wait for first image, making sure we haven't exited
+ while self._ip_grabbed is False and self._ip_exited is False:
+ time.sleep(.001)
+
+ ret, image = self._ip_grabbed, self._ip_image.copy()
+ if ret is False:
+ self._ip_running = False
+ else:
+ ret, image = self.cap.read()
+ if ret is False:
+ print('VideoStreamer: Cannot get image from camera')
+ return (None, False)
+ w, h = image.shape[1], image.shape[0]
+ if self.video_file:
+ self.cap.set(cv2.CAP_PROP_POS_FRAMES, self.listing[self.i])
+
+ w_new, h_new = process_resize(w, h, self.resize)
+ image = cv2.resize(image, (w_new, h_new),
+ interpolation=self.interp)
+ image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
+ else:
+ image_file = str(self.listing[self.i])
+ image = self.load_image(image_file)
+ self.i = self.i + 1
+ return (image, True)
+
+ def start_ip_camera_thread(self):
+ self._ip_thread = Thread(target=self.update_ip_camera, args=())
+ self._ip_running = True
+ self._ip_thread.start()
+ self._ip_exited = False
+ return self
+
+ def update_ip_camera(self):
+ while self._ip_running:
+ ret, img = self.cap.read()
+ if ret is False:
+ self._ip_running = False
+ self._ip_exited = True
+ self._ip_grabbed = False
+ return
+
+ self._ip_image = img
+ self._ip_grabbed = ret
+ self._ip_index += 1
+ #print('IPCAMERA THREAD got frame {}'.format(self._ip_index))
+
+
+ def cleanup(self):
+ self._ip_running = False
+
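+# A minimal usage sketch (illustrative only, not part of the original interface; the
+# directory path is hypothetical): stream frames from an image folder and convert each
+# one to a normalized tensor with frame2tensor (defined below).
+#
+# vs = VideoStreamer('assets/my_images/', resize=[640, 480], skip=1,
+#                    image_glob=['*.png', '*.jpg'])
+# while True:
+#     frame, ret = vs.next_frame()
+#     if not ret:
+#         break
+#     inp = frame2tensor(frame, 'cpu')  # (1, 1, H, W) float tensor in [0, 1]
+# vs.cleanup()
+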
+# --- PREPROCESSING ---
+
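+# process_resize semantics (summarizing the implementation below): a single positive value
+# rescales the longer side to that value, e.g. (w, h) = (1920, 1080) with resize=[640]
+# becomes (640, 360); resize=[-1] keeps the original size; two values are used directly
+# as the target (width, height).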
+def process_resize(w, h, resize):
+ assert(len(resize) > 0 and len(resize) <= 2)
+ if len(resize) == 1 and resize[0] > -1:
+ scale = resize[0] / max(h, w)
+ w_new, h_new = int(round(w*scale)), int(round(h*scale))
+ elif len(resize) == 1 and resize[0] == -1:
+ w_new, h_new = w, h
+ else: # len(resize) == 2:
+ w_new, h_new = resize[0], resize[1]
+
+ # Issue warning if resolution is too small or too large.
+ if max(w_new, h_new) < 160:
+ print('Warning: input resolution is very small, results may vary')
+ elif max(w_new, h_new) > 2000:
+ print('Warning: input resolution is very large, results may vary')
+
+ return w_new, h_new
+
+
+def frame2tensor(frame, device):
+ return torch.from_numpy(frame/255.).float()[None, None].to(device)
+
+
+def read_image(path, device, resize, rotation, resize_float):
+ image = cv2.imread(str(path), cv2.IMREAD_GRAYSCALE)
+ if image is None:
+ return None, None, None
+ w, h = image.shape[1], image.shape[0]
+ w_new, h_new = process_resize(w, h, resize)
+ scales = (float(w) / float(w_new), float(h) / float(h_new))
+
+ if resize_float:
+ image = cv2.resize(image.astype('float32'), (w_new, h_new))
+ else:
+ image = cv2.resize(image, (w_new, h_new)).astype('float32')
+
+ if rotation != 0:
+ image = np.rot90(image, k=rotation)
+ if rotation % 2:
+ scales = scales[::-1]
+
+ inp = frame2tensor(image, device)
+ return image, inp, scales
+
+
+# --- GEOMETRY ---
+
+
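+# Summary of the logic below: estimate_pose first normalizes the keypoints with the
+# intrinsics (subtract the principal point, divide by the focal lengths), so the RANSAC
+# threshold given in pixels is converted to normalized units by dividing by the mean focal
+# length. cv2.findEssentialMat may return several stacked 3x3 candidates; each is scored
+# with cv2.recoverPose and the candidate with the most inliers is kept, returning
+# (R, t, inlier_mask), or None if fewer than 5 matches are available.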
+def estimate_pose(kpts0, kpts1, K0, K1, thresh, conf=0.99999):
+ if len(kpts0) < 5:
+ return None
+
+ f_mean = np.mean([K0[0, 0], K1[1, 1], K0[0, 0], K1[1, 1]])
+ norm_thresh = thresh / f_mean
+
+ kpts0 = (kpts0 - K0[[0, 1], [2, 2]][None]) / K0[[0, 1], [0, 1]][None]
+ kpts1 = (kpts1 - K1[[0, 1], [2, 2]][None]) / K1[[0, 1], [0, 1]][None]
+
+ E, mask = cv2.findEssentialMat(
+ kpts0, kpts1, np.eye(3), threshold=norm_thresh, prob=conf,
+ method=cv2.RANSAC)
+
+ assert E is not None
+
+ best_num_inliers = 0
+ ret = None
+ for _E in np.split(E, len(E) / 3):
+ n, R, t, _ = cv2.recoverPose(
+ _E, kpts0, kpts1, np.eye(3), 1e9, mask=mask)
+ if n > best_num_inliers:
+ best_num_inliers = n
+ ret = (R, t[:, 0], mask.ravel() > 0)
+ return ret
+
+
+def rotate_intrinsics(K, image_shape, rot):
+ """image_shape is the shape of the image after rotation"""
+ assert rot <= 3
+ h, w = image_shape[:2][::-1 if (rot % 2) else 1]
+ fx, fy, cx, cy = K[0, 0], K[1, 1], K[0, 2], K[1, 2]
+ rot = rot % 4
+ if rot == 1:
+ return np.array([[fy, 0., cy],
+ [0., fx, w-1-cx],
+ [0., 0., 1.]], dtype=K.dtype)
+ elif rot == 2:
+ return np.array([[fx, 0., w-1-cx],
+ [0., fy, h-1-cy],
+ [0., 0., 1.]], dtype=K.dtype)
+ else: # if rot == 3:
+ return np.array([[fy, 0., h-1-cy],
+ [0., fx, cx],
+ [0., 0., 1.]], dtype=K.dtype)
+
+
+def rotate_pose_inplane(i_T_w, rot):
+ rotation_matrices = [
+ np.array([[np.cos(r), -np.sin(r), 0., 0.],
+ [np.sin(r), np.cos(r), 0., 0.],
+ [0., 0., 1., 0.],
+ [0., 0., 0., 1.]], dtype=np.float32)
+ for r in [np.deg2rad(d) for d in (0, 270, 180, 90)]
+ ]
+ return np.dot(rotation_matrices[rot], i_T_w)
+
+
+def scale_intrinsics(K, scales):
+ scales = np.diag([1./scales[0], 1./scales[1], 1.])
+ return np.dot(scales, K)
+
+
+def to_homogeneous(points):
+ return np.concatenate([points, np.ones_like(points[:, :1])], axis=-1)
+
+
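+# compute_epipolar_error builds the essential matrix E = [t]_x R from the ground-truth
+# relative pose T_0to1 and returns, for each correspondence, the squared symmetric epipolar
+# distance evaluated in normalized image coordinates.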
+def compute_epipolar_error(kpts0, kpts1, T_0to1, K0, K1):
+ kpts0 = (kpts0 - K0[[0, 1], [2, 2]][None]) / K0[[0, 1], [0, 1]][None]
+ kpts1 = (kpts1 - K1[[0, 1], [2, 2]][None]) / K1[[0, 1], [0, 1]][None]
+ kpts0 = to_homogeneous(kpts0)
+ kpts1 = to_homogeneous(kpts1)
+
+ t0, t1, t2 = T_0to1[:3, 3]
+ t_skew = np.array([
+ [0, -t2, t1],
+ [t2, 0, -t0],
+ [-t1, t0, 0]
+ ])
+ E = t_skew @ T_0to1[:3, :3]
+
+ Ep0 = kpts0 @ E.T # N x 3
+ p1Ep0 = np.sum(kpts1 * Ep0, -1) # N
+ Etp1 = kpts1 @ E # N x 3
+ d = p1Ep0**2 * (1.0 / (Ep0[:, 0]**2 + Ep0[:, 1]**2)
+ + 1.0 / (Etp1[:, 0]**2 + Etp1[:, 1]**2))
+ return d
+
+
+def angle_error_mat(R1, R2):
+ cos = (np.trace(np.dot(R1.T, R2)) - 1) / 2
+ cos = np.clip(cos, -1., 1.) # numerical errors can push it out of bounds
+ return np.rad2deg(np.abs(np.arccos(cos)))
+
+
+def angle_error_vec(v1, v2):
+ n = np.linalg.norm(v1) * np.linalg.norm(v2)
+ return np.rad2deg(np.arccos(np.clip(np.dot(v1, v2) / n, -1.0, 1.0)))
+
+
+def compute_pose_error(T_0to1, R, t):
+ R_gt = T_0to1[:3, :3]
+ t_gt = T_0to1[:3, 3]
+ error_t = angle_error_vec(t, t_gt)
+ error_t = np.minimum(error_t, 180 - error_t) # ambiguity of E estimation
+ error_R = angle_error_mat(R, R_gt)
+ return error_t, error_R
+
+
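+# pose_auc integrates the recall-vs-error curve up to each threshold and divides by the
+# threshold, so every returned AUC lies in [0, 1]. A small illustrative call (toy numbers,
+# not from any benchmark):
+# aucs = pose_auc([0.4, 1.2, 3.0, 7.5, 22.0], thresholds=[5, 10, 20])  # one value per threshold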
+def pose_auc(errors, thresholds):
+ sort_idx = np.argsort(errors)
+ errors = np.array(errors.copy())[sort_idx]
+ recall = (np.arange(len(errors)) + 1) / len(errors)
+ errors = np.r_[0., errors]
+ recall = np.r_[0., recall]
+ aucs = []
+ for t in thresholds:
+ last_index = np.searchsorted(errors, t)
+ r = np.r_[recall[:last_index], recall[last_index-1]]
+ e = np.r_[errors[:last_index], t]
+ aucs.append(np.trapz(r, x=e)/t)
+ return aucs
+
+
+# --- VISUALIZATION ---
+
+
+def plot_image_pair(imgs, dpi=100, size=6, pad=.5):
+ n = len(imgs)
+ assert n == 2, 'number of images must be two'
+ figsize = (size*n, size*3/4) if size is not None else None
+ _, ax = plt.subplots(1, n, figsize=figsize, dpi=dpi)
+ for i in range(n):
+ ax[i].imshow(imgs[i], cmap=plt.get_cmap('gray'), vmin=0, vmax=255)
+ ax[i].get_yaxis().set_ticks([])
+ ax[i].get_xaxis().set_ticks([])
+ for spine in ax[i].spines.values(): # remove frame
+ spine.set_visible(False)
+ plt.tight_layout(pad=pad)
+
+
+def plot_keypoints(kpts0, kpts1, color='w', ps=2):
+ ax = plt.gcf().axes
+ ax[0].scatter(kpts0[:, 0], kpts0[:, 1], c=color, s=ps)
+ ax[1].scatter(kpts1[:, 0], kpts1[:, 1], c=color, s=ps)
+
+
+def plot_matches(kpts0, kpts1, color, lw=1.5, ps=4):
+ fig = plt.gcf()
+ ax = fig.axes
+ fig.canvas.draw()
+
+ transFigure = fig.transFigure.inverted()
+ fkpts0 = transFigure.transform(ax[0].transData.transform(kpts0))
+ fkpts1 = transFigure.transform(ax[1].transData.transform(kpts1))
+
+ fig.lines = [matplotlib.lines.Line2D(
+ (fkpts0[i, 0], fkpts1[i, 0]), (fkpts0[i, 1], fkpts1[i, 1]), zorder=1,
+ transform=fig.transFigure, c=color[i], linewidth=lw)
+ for i in range(len(kpts0))]
+ ax[0].scatter(kpts0[:, 0], kpts0[:, 1], c=color, s=ps)
+ ax[1].scatter(kpts1[:, 0], kpts1[:, 1], c=color, s=ps)
+
+
+def make_matching_plot(image0, image1, kpts0, kpts1, mkpts0, mkpts1,
+ color, text, path, show_keypoints=False,
+ fast_viz=False, opencv_display=False,
+ opencv_title='matches', small_text=[]):
+
+ if fast_viz:
+ make_matching_plot_fast(image0, image1, kpts0, kpts1, mkpts0, mkpts1,
+ color, text, path, show_keypoints, 10,
+ opencv_display, opencv_title, small_text)
+ return
+
+ plot_image_pair([image0, image1])
+ if show_keypoints:
+ plot_keypoints(kpts0, kpts1, color='k', ps=4)
+ plot_keypoints(kpts0, kpts1, color='w', ps=2)
+ plot_matches(mkpts0, mkpts1, color)
+
+ fig = plt.gcf()
+ txt_color = 'k' if image0[:100, :150].mean() > 200 else 'w'
+ fig.text(
+ 0.01, 0.99, '\n'.join(text), transform=fig.axes[0].transAxes,
+ fontsize=15, va='top', ha='left', color=txt_color)
+
+ txt_color = 'k' if image0[-100:, :150].mean() > 200 else 'w'
+ fig.text(
+ 0.01, 0.01, '\n'.join(small_text), transform=fig.axes[0].transAxes,
+ fontsize=5, va='bottom', ha='left', color=txt_color)
+
+ plt.savefig(str(path), bbox_inches='tight', pad_inches=0)
+ plt.close()
+
+
+def make_matching_plot_fast(image0, image1, kpts0, kpts1, mkpts0,
+ mkpts1, color, text, path=None,
+ show_keypoints=False, margin=10,
+ opencv_display=False, opencv_title='',
+ small_text=[]):
+ H0, W0 = image0.shape
+ H1, W1 = image1.shape
+ H, W = max(H0, H1), W0 + W1 + margin
+
+ out = 255*np.ones((H, W), np.uint8)
+ out[:H0, :W0] = image0
+ out[:H1, W0+margin:] = image1
+ out = np.stack([out]*3, -1)
+
+ if show_keypoints:
+ kpts0, kpts1 = np.round(kpts0).astype(int), np.round(kpts1).astype(int)
+ white = (255, 255, 255)
+ black = (0, 0, 0)
+ for x, y in kpts0:
+ cv2.circle(out, (x, y), 2, black, -1, lineType=cv2.LINE_AA)
+ cv2.circle(out, (x, y), 1, white, -1, lineType=cv2.LINE_AA)
+ for x, y in kpts1:
+ cv2.circle(out, (x + margin + W0, y), 2, black, -1,
+ lineType=cv2.LINE_AA)
+ cv2.circle(out, (x + margin + W0, y), 1, white, -1,
+ lineType=cv2.LINE_AA)
+
+ mkpts0, mkpts1 = np.round(mkpts0).astype(int), np.round(mkpts1).astype(int)
+ color = (np.array(color[:, :3])*255).astype(int)[:, ::-1]
+ for (x0, y0), (x1, y1), c in zip(mkpts0, mkpts1, color):
+ c = c.tolist()
+ cv2.line(out, (x0, y0), (x1 + margin + W0, y1),
+ color=c, thickness=1, lineType=cv2.LINE_AA)
+ # display line end-points as circles
+ cv2.circle(out, (x0, y0), 2, c, -1, lineType=cv2.LINE_AA)
+ cv2.circle(out, (x1 + margin + W0, y1), 2, c, -1,
+ lineType=cv2.LINE_AA)
+
+ # Scale factor for consistent visualization across scales.
+ sc = min(H / 640., 2.0)
+
+ # Big text.
+ Ht = int(30 * sc) # text height
+ txt_color_fg = (255, 255, 255)
+ txt_color_bg = (0, 0, 0)
+ for i, t in enumerate(text):
+ cv2.putText(out, t, (int(8*sc), Ht*(i+1)), cv2.FONT_HERSHEY_DUPLEX,
+ 1.0*sc, txt_color_bg, 2, cv2.LINE_AA)
+ cv2.putText(out, t, (int(8*sc), Ht*(i+1)), cv2.FONT_HERSHEY_DUPLEX,
+ 1.0*sc, txt_color_fg, 1, cv2.LINE_AA)
+
+ # Small text.
+ Ht = int(18 * sc) # text height
+ for i, t in enumerate(reversed(small_text)):
+ cv2.putText(out, t, (int(8*sc), int(H-Ht*(i+.6))), cv2.FONT_HERSHEY_DUPLEX,
+ 0.5*sc, txt_color_bg, 2, cv2.LINE_AA)
+ cv2.putText(out, t, (int(8*sc), int(H-Ht*(i+.6))), cv2.FONT_HERSHEY_DUPLEX,
+ 0.5*sc, txt_color_fg, 1, cv2.LINE_AA)
+
+ if path is not None:
+ cv2.imwrite(str(path), out)
+
+ if opencv_display:
+ cv2.imshow(opencv_title, out)
+ cv2.waitKey(1)
+
+ return out
+
+
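+# error_colormap maps a scalar x in [0, 1] to an RGBA color: x=0 gives red (1, 0, 0, 1),
+# x=0.5 gives yellow (1, 1, 0, 1) and x=1 gives green (0, 1, 0, 1), with values in between
+# interpolated linearly and clipped to the valid range.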
+def error_colormap(x):
+ return np.clip(
+ np.stack([2-x*2, x*2, np.zeros_like(x), np.ones_like(x)], -1), 0, 1)
diff --git a/imcui/third_party/SuperGluePretrainedNetwork/models/weights/superglue_indoor.pth b/imcui/third_party/SuperGluePretrainedNetwork/models/weights/superglue_indoor.pth
new file mode 100644
index 0000000000000000000000000000000000000000..969252133f802cb03256c15a3881b8b39c1867d4
--- /dev/null
+++ b/imcui/third_party/SuperGluePretrainedNetwork/models/weights/superglue_indoor.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0e710469be25ebe1e2ccf68edcae8b2945b0617c8e7e68412251d9d47f5052b1
+size 48233807
diff --git a/imcui/third_party/SuperGluePretrainedNetwork/models/weights/superglue_outdoor.pth b/imcui/third_party/SuperGluePretrainedNetwork/models/weights/superglue_outdoor.pth
new file mode 100644
index 0000000000000000000000000000000000000000..79db4b5340b02afca3cdd419672300bb009975af
--- /dev/null
+++ b/imcui/third_party/SuperGluePretrainedNetwork/models/weights/superglue_outdoor.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2f5f5e9bb3febf07b69df633c4c3ff7a17f8af26a023aae2b9303d22339195bd
+size 48233807
diff --git a/imcui/third_party/SuperGluePretrainedNetwork/models/weights/superpoint_v1.pth b/imcui/third_party/SuperGluePretrainedNetwork/models/weights/superpoint_v1.pth
new file mode 100644
index 0000000000000000000000000000000000000000..7648726e3a3dfa2581e86bfa9c5a2a05cfb9bf74
--- /dev/null
+++ b/imcui/third_party/SuperGluePretrainedNetwork/models/weights/superpoint_v1.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:52b6708629640ca883673b5d5c097c4ddad37d8048b33f09c8ca0d69db12c40e
+size 5206086
diff --git a/imcui/third_party/TopicFM/.github/workflows/sync.yml b/imcui/third_party/TopicFM/.github/workflows/sync.yml
new file mode 100644
index 0000000000000000000000000000000000000000..efbf881c64bdeac6916473e4391e23e87af5b69d
--- /dev/null
+++ b/imcui/third_party/TopicFM/.github/workflows/sync.yml
@@ -0,0 +1,39 @@
+name: Upstream Sync
+
+permissions:
+ contents: write
+
+on:
+ schedule:
+ - cron: "0 0 * * *" # every day
+ workflow_dispatch:
+
+jobs:
+ sync_latest_from_upstream:
+ name: Sync latest commits from upstream repo
+ runs-on: ubuntu-latest
+ if: ${{ github.event.repository.fork }}
+
+ steps:
+ # Step 1: run a standard checkout action
+ - name: Checkout target repo
+ uses: actions/checkout@v3
+
+ # Step 2: run the sync action
+ - name: Sync upstream changes
+ id: sync
+ uses: aormsby/Fork-Sync-With-Upstream-action@v3.4
+ with:
+ upstream_sync_repo: TruongKhang/TopicFM
+ upstream_sync_branch: main
+ target_sync_branch: main
+ target_repo_token: ${{ secrets.GITHUB_TOKEN }} # automatically generated, no need to set
+
+ # Set test_mode to true to run tests instead of the real sync action.
+ test_mode: false
+
+ - name: Sync check
+ if: failure()
+ run: |
+ echo "::error::Due to insufficient permissions, synchronization failed (as expected). Please go to the repository homepage and manually perform [Sync fork]."
+ exit 1
diff --git a/imcui/third_party/TopicFM/configs/data/__init__.py b/imcui/third_party/TopicFM/configs/data/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/imcui/third_party/TopicFM/configs/data/base.py b/imcui/third_party/TopicFM/configs/data/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..6cab7e67019a6fee2657c1a28609c8aca5b2a1d8
--- /dev/null
+++ b/imcui/third_party/TopicFM/configs/data/base.py
@@ -0,0 +1,37 @@
+"""
+The data config will be the last one merged into the main config.
+Settings in data configs will override all existing settings!
+"""
+
+from yacs.config import CfgNode as CN
+_CN = CN()
+_CN.DATASET = CN()
+_CN.TRAINER = CN()
+
+# training data config
+_CN.DATASET.TRAIN_DATA_ROOT = None
+_CN.DATASET.TRAIN_POSE_ROOT = None
+_CN.DATASET.TRAIN_NPZ_ROOT = None
+_CN.DATASET.TRAIN_LIST_PATH = None
+_CN.DATASET.TRAIN_INTRINSIC_PATH = None
+# validation set config
+_CN.DATASET.VAL_DATA_ROOT = None
+_CN.DATASET.VAL_POSE_ROOT = None
+_CN.DATASET.VAL_NPZ_ROOT = None
+_CN.DATASET.VAL_LIST_PATH = None
+_CN.DATASET.VAL_INTRINSIC_PATH = None
+
+# testing data config
+_CN.DATASET.TEST_DATA_SOURCE = None
+_CN.DATASET.TEST_DATA_ROOT = None
+_CN.DATASET.TEST_POSE_ROOT = None
+_CN.DATASET.TEST_NPZ_ROOT = None
+_CN.DATASET.TEST_LIST_PATH = None
+_CN.DATASET.TEST_INTRINSIC_PATH = None
+_CN.DATASET.TEST_IMGSIZE = None
+
+# dataset config
+_CN.DATASET.MIN_OVERLAP_SCORE_TRAIN = 0.4
+_CN.DATASET.MIN_OVERLAP_SCORE_TEST = 0.0 # for both test and val
+
+cfg = _CN
diff --git a/imcui/third_party/TopicFM/configs/data/megadepth_test_1500.py b/imcui/third_party/TopicFM/configs/data/megadepth_test_1500.py
new file mode 100644
index 0000000000000000000000000000000000000000..9fd107fc07ecd464f793d13282939ddb26032922
--- /dev/null
+++ b/imcui/third_party/TopicFM/configs/data/megadepth_test_1500.py
@@ -0,0 +1,11 @@
+from configs.data.base import cfg
+
+TEST_BASE_PATH = "assets/megadepth_test_1500_scene_info"
+
+cfg.DATASET.TEST_DATA_SOURCE = "MegaDepth"
+cfg.DATASET.TEST_DATA_ROOT = "data/megadepth/test"
+cfg.DATASET.TEST_NPZ_ROOT = f"{TEST_BASE_PATH}"
+cfg.DATASET.TEST_LIST_PATH = f"{TEST_BASE_PATH}/megadepth_test_1500.txt"
+
+cfg.DATASET.MGDPT_IMG_RESIZE = 1200
+cfg.DATASET.MIN_OVERLAP_SCORE_TEST = 0.0
diff --git a/imcui/third_party/TopicFM/configs/data/megadepth_trainval.py b/imcui/third_party/TopicFM/configs/data/megadepth_trainval.py
new file mode 100644
index 0000000000000000000000000000000000000000..215b5c34cc41d36aa4444a58ca0cb69afbc11952
--- /dev/null
+++ b/imcui/third_party/TopicFM/configs/data/megadepth_trainval.py
@@ -0,0 +1,22 @@
+from configs.data.base import cfg
+
+
+TRAIN_BASE_PATH = "data/megadepth/index"
+cfg.DATASET.TRAINVAL_DATA_SOURCE = "MegaDepth"
+cfg.DATASET.TRAIN_DATA_ROOT = "data/megadepth/train"
+cfg.DATASET.TRAIN_NPZ_ROOT = f"{TRAIN_BASE_PATH}/scene_info_0.1_0.7"
+cfg.DATASET.TRAIN_LIST_PATH = f"{TRAIN_BASE_PATH}/trainvaltest_list/train_list.txt"
+cfg.DATASET.MIN_OVERLAP_SCORE_TRAIN = 0.0
+
+TEST_BASE_PATH = "data/megadepth/index"
+cfg.DATASET.TEST_DATA_SOURCE = "MegaDepth"
+cfg.DATASET.VAL_DATA_ROOT = cfg.DATASET.TEST_DATA_ROOT = "data/megadepth/test"
+cfg.DATASET.VAL_NPZ_ROOT = cfg.DATASET.TEST_NPZ_ROOT = f"{TEST_BASE_PATH}/scene_info_val_1500"
+cfg.DATASET.VAL_LIST_PATH = cfg.DATASET.TEST_LIST_PATH = f"{TEST_BASE_PATH}/trainvaltest_list/val_list.txt"
+cfg.DATASET.MIN_OVERLAP_SCORE_TEST = 0.0 # for both test and val
+
+# 368 scenes in total for MegaDepth
+# (difficulty balanced by further splitting each scene into 3 sub-scenes)
+cfg.TRAINER.N_SAMPLES_PER_SUBSET = 100
+
+cfg.DATASET.MGDPT_IMG_RESIZE = 800 # for training on 11GB mem GPUs
diff --git a/imcui/third_party/TopicFM/configs/data/scannet_test_1500.py b/imcui/third_party/TopicFM/configs/data/scannet_test_1500.py
new file mode 100644
index 0000000000000000000000000000000000000000..ce3b0846b61c567b053d12fb636982ce02e21a5c
--- /dev/null
+++ b/imcui/third_party/TopicFM/configs/data/scannet_test_1500.py
@@ -0,0 +1,12 @@
+from configs.data.base import cfg
+
+TEST_BASE_PATH = "assets/scannet_test_1500"
+
+cfg.DATASET.TEST_DATA_SOURCE = "ScanNet"
+cfg.DATASET.TEST_DATA_ROOT = "data/scannet/test"
+cfg.DATASET.TEST_NPZ_ROOT = f"{TEST_BASE_PATH}"
+cfg.DATASET.TEST_LIST_PATH = f"{TEST_BASE_PATH}/scannet_test.txt"
+cfg.DATASET.TEST_INTRINSIC_PATH = f"{TEST_BASE_PATH}/intrinsics.npz"
+cfg.DATASET.TEST_IMGSIZE = (640, 480)
+
+cfg.DATASET.MIN_OVERLAP_SCORE_TEST = 0.0
diff --git a/imcui/third_party/TopicFM/configs/model/indoor/model_cfg_test.py b/imcui/third_party/TopicFM/configs/model/indoor/model_cfg_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..8e8872d3b79de529aa375127ea5beb7e81d9d5b1
--- /dev/null
+++ b/imcui/third_party/TopicFM/configs/model/indoor/model_cfg_test.py
@@ -0,0 +1,4 @@
+from src.config.default import _CN as cfg
+
+cfg.MODEL.COARSE.N_SAMPLES = 5
+cfg.MODEL.MATCH_COARSE.THR = 0.3
diff --git a/imcui/third_party/TopicFM/configs/model/outdoor/model_cfg_test.py b/imcui/third_party/TopicFM/configs/model/outdoor/model_cfg_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..692497457c2a7b9ad823f94546e38f15732ca632
--- /dev/null
+++ b/imcui/third_party/TopicFM/configs/model/outdoor/model_cfg_test.py
@@ -0,0 +1,4 @@
+from src.config.default import _CN as cfg
+
+cfg.MODEL.COARSE.N_SAMPLES = 10
+cfg.MODEL.MATCH_COARSE.THR = 0.2
diff --git a/imcui/third_party/TopicFM/configs/model/outdoor/model_ds.py b/imcui/third_party/TopicFM/configs/model/outdoor/model_ds.py
new file mode 100644
index 0000000000000000000000000000000000000000..2c090edbfbdcd66cea225c39af6f62da8feb50b9
--- /dev/null
+++ b/imcui/third_party/TopicFM/configs/model/outdoor/model_ds.py
@@ -0,0 +1,16 @@
+from src.config.default import _CN as cfg
+
+cfg.MODEL.MATCH_COARSE.MATCH_TYPE = 'dual_softmax'
+cfg.MODEL.COARSE.N_SAMPLES = 8
+
+cfg.TRAINER.CANONICAL_LR = 1e-2
+cfg.TRAINER.WARMUP_STEP = 1875 # 3 epochs
+cfg.TRAINER.WARMUP_RATIO = 0.1
+cfg.TRAINER.MSLR_MILESTONES = [3, 6, 9, 12, 16, 20, 24, 28]
+
+# pose estimation
+cfg.TRAINER.RANSAC_PIXEL_THR = 0.5
+
+cfg.TRAINER.OPTIMIZER = "adamw"
+cfg.TRAINER.ADAMW_DECAY = 0.1
+cfg.MODEL.MATCH_COARSE.TRAIN_COARSE_PERCENT = 0.3
diff --git a/imcui/third_party/TopicFM/flop_counter.py b/imcui/third_party/TopicFM/flop_counter.py
new file mode 100644
index 0000000000000000000000000000000000000000..ea87fa0139897434ca52b369450aa82203311181
--- /dev/null
+++ b/imcui/third_party/TopicFM/flop_counter.py
@@ -0,0 +1,55 @@
+import torch
+from fvcore.nn import FlopCountAnalysis
+from einops.einops import rearrange
+
+from src import get_model_cfg
+from src.models.backbone import FPN as topicfm_featnet
+from src.models.modules import TopicFormer
+from src.utils.dataset import read_scannet_gray
+
+from third_party.loftr.src.loftr.utils.cvpr_ds_config import default_cfg
+from third_party.loftr.src.loftr.backbone import ResNetFPN_8_2 as loftr_featnet
+from third_party.loftr.src.loftr.loftr_module import LocalFeatureTransformer
+
+
+def feat_net_flops(feat_net, config, input):
+ model = feat_net(config)
+ model.eval()
+ flops = FlopCountAnalysis(model, input)
+ feat_c, _ = model(input)
+ return feat_c, flops.total() / 1e9
+
+
+def coarse_model_flops(coarse_model, config, inputs):
+ model = coarse_model(config)
+ model.eval()
+ flops = FlopCountAnalysis(model, inputs)
+ return flops.total() / 1e9
+
+
+if __name__ == '__main__':
+ path_img0 = "assets/scannet_sample_images/scene0711_00_frame-001680.jpg"
+ path_img1 = "assets/scannet_sample_images/scene0711_00_frame-001995.jpg"
+ img0, img1 = read_scannet_gray(path_img0), read_scannet_gray(path_img1)
+ img0, img1 = img0.unsqueeze(0), img1.unsqueeze(0)
+
+ # LoFTR
+ loftr_conf = dict(default_cfg)
+ feat_c0, loftr_featnet_flops0 = feat_net_flops(loftr_featnet, loftr_conf["resnetfpn"], img0)
+ feat_c1, loftr_featnet_flops1 = feat_net_flops(loftr_featnet, loftr_conf["resnetfpn"], img1)
+ print("FLOPs of feature extraction in LoFTR: {} GFLOPs".format((loftr_featnet_flops0 + loftr_featnet_flops1)/2))
+ feat_c0 = rearrange(feat_c0, 'n c h w -> n (h w) c')
+ feat_c1 = rearrange(feat_c1, 'n c h w -> n (h w) c')
+ loftr_coarse_model_flops = coarse_model_flops(LocalFeatureTransformer, loftr_conf["coarse"], (feat_c0, feat_c1))
+ print("FLOPs of coarse matching model in LoFTR: {} GFLOPs".format(loftr_coarse_model_flops))
+
+ # TopicFM
+ topicfm_conf = get_model_cfg()
+ feat_c0, topicfm_featnet_flops0 = feat_net_flops(topicfm_featnet, topicfm_conf["fpn"], img0)
+ feat_c1, topicfm_featnet_flops1 = feat_net_flops(topicfm_featnet, topicfm_conf["fpn"], img1)
+ print("FLOPs of feature extraction in TopicFM: {} GFLOPs".format((topicfm_featnet_flops0 + topicfm_featnet_flops1) / 2))
+ feat_c0 = rearrange(feat_c0, 'n c h w -> n (h w) c')
+ feat_c1 = rearrange(feat_c1, 'n c h w -> n (h w) c')
+ topicfm_coarse_model_flops = coarse_model_flops(TopicFormer, topicfm_conf["coarse"], (feat_c0, feat_c1))
+ print("FLOPs of coarse matching model in TopicFM: {} GFLOPs".format(topicfm_coarse_model_flops))
+
diff --git a/imcui/third_party/TopicFM/src/__init__.py b/imcui/third_party/TopicFM/src/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..30caef94f911f99e0c12510d8181b3c1537daf1a
--- /dev/null
+++ b/imcui/third_party/TopicFM/src/__init__.py
@@ -0,0 +1,11 @@
+from yacs.config import CfgNode
+from .config.default import _CN
+
+def lower_config(yacs_cfg):
+ if not isinstance(yacs_cfg, CfgNode):
+ return yacs_cfg
+ return {k.lower(): lower_config(v) for k, v in yacs_cfg.items()}
+
+def get_model_cfg():
+ cfg = lower_config(lower_config(_CN))
+ return cfg["model"]
\ No newline at end of file
diff --git a/imcui/third_party/TopicFM/src/config/default.py b/imcui/third_party/TopicFM/src/config/default.py
new file mode 100644
index 0000000000000000000000000000000000000000..591558b3f358cdce0e9e72e94acba702b2a4e896
--- /dev/null
+++ b/imcui/third_party/TopicFM/src/config/default.py
@@ -0,0 +1,171 @@
+from yacs.config import CfgNode as CN
+_CN = CN()
+
+############## ↓ MODEL Pipeline ↓ ##############
+_CN.MODEL = CN()
+_CN.MODEL.BACKBONE_TYPE = 'FPN'
+_CN.MODEL.RESOLUTION = (8, 2) # options: [(8, 2), (16, 4)]
+_CN.MODEL.FINE_WINDOW_SIZE = 5 # window_size in fine_level, must be odd
+_CN.MODEL.FINE_CONCAT_COARSE_FEAT = False
+
+# 1. MODEL-backbone (local feature CNN) config
+_CN.MODEL.FPN = CN()
+_CN.MODEL.FPN.INITIAL_DIM = 128
+_CN.MODEL.FPN.BLOCK_DIMS = [128, 192, 256, 384] # s1, s2, s3
+
+# 2. MODEL-coarse module config
+_CN.MODEL.COARSE = CN()
+_CN.MODEL.COARSE.D_MODEL = 256
+_CN.MODEL.COARSE.D_FFN = 256
+_CN.MODEL.COARSE.NHEAD = 8
+_CN.MODEL.COARSE.LAYER_NAMES = ['seed', 'seed', 'seed', 'seed', 'seed']
+_CN.MODEL.COARSE.ATTENTION = 'linear' # options: ['linear', 'full']
+_CN.MODEL.COARSE.TEMP_BUG_FIX = True
+_CN.MODEL.COARSE.N_TOPICS = 100
+_CN.MODEL.COARSE.N_SAMPLES = 6
+_CN.MODEL.COARSE.N_TOPIC_TRANSFORMERS = 1
+
+# 3. Coarse-Matching config
+_CN.MODEL.MATCH_COARSE = CN()
+_CN.MODEL.MATCH_COARSE.THR = 0.2
+_CN.MODEL.MATCH_COARSE.BORDER_RM = 2
+_CN.MODEL.MATCH_COARSE.MATCH_TYPE = 'dual_softmax'
+_CN.MODEL.MATCH_COARSE.DSMAX_TEMPERATURE = 0.1
+_CN.MODEL.MATCH_COARSE.TRAIN_COARSE_PERCENT = 0.2 # training tricks: save GPU memory
+_CN.MODEL.MATCH_COARSE.TRAIN_PAD_NUM_GT_MIN = 200 # training tricks: avoid DDP deadlock
+_CN.MODEL.MATCH_COARSE.SPARSE_SPVS = True
+
+# 4. MODEL-fine module config
+_CN.MODEL.FINE = CN()
+_CN.MODEL.FINE.D_MODEL = 128
+_CN.MODEL.FINE.D_FFN = 128
+_CN.MODEL.FINE.NHEAD = 4
+_CN.MODEL.FINE.LAYER_NAMES = ['cross'] * 1
+_CN.MODEL.FINE.ATTENTION = 'linear'
+_CN.MODEL.FINE.N_TOPICS = 1
+
+# 5. MODEL Losses
+# -- # coarse-level
+_CN.MODEL.LOSS = CN()
+_CN.MODEL.LOSS.COARSE_WEIGHT = 1.0
+# _CN.MODEL.LOSS.SPARSE_SPVS = False
+# -- - -- # focal loss (coarse)
+_CN.MODEL.LOSS.FOCAL_ALPHA = 0.25
+_CN.MODEL.LOSS.POS_WEIGHT = 1.0
+_CN.MODEL.LOSS.NEG_WEIGHT = 1.0
+# _CN.MODEL.LOSS.DUAL_SOFTMAX = False # whether coarse-level use dual-softmax or not.
+# use `_CN.MODEL.MATCH_COARSE.MATCH_TYPE`
+
+# -- # fine-level
+_CN.MODEL.LOSS.FINE_TYPE = 'l2_with_std' # ['l2_with_std', 'l2']
+_CN.MODEL.LOSS.FINE_WEIGHT = 1.0
+_CN.MODEL.LOSS.FINE_CORRECT_THR = 1.0 # for filtering valid fine-level gts (some gt matches might fall out of the fine-level window)
+
+
+############## Dataset ##############
+_CN.DATASET = CN()
+# 1. data config
+# training and validating
+_CN.DATASET.TRAINVAL_DATA_SOURCE = None # options: ['ScanNet', 'MegaDepth']
+_CN.DATASET.TRAIN_DATA_ROOT = None
+_CN.DATASET.TRAIN_POSE_ROOT = None # (optional directory for poses)
+_CN.DATASET.TRAIN_NPZ_ROOT = None
+_CN.DATASET.TRAIN_LIST_PATH = None
+_CN.DATASET.TRAIN_INTRINSIC_PATH = None
+_CN.DATASET.VAL_DATA_ROOT = None
+_CN.DATASET.VAL_POSE_ROOT = None # (optional directory for poses)
+_CN.DATASET.VAL_NPZ_ROOT = None
+_CN.DATASET.VAL_LIST_PATH = None # None if val data from all scenes are bundled into a single npz file
+_CN.DATASET.VAL_INTRINSIC_PATH = None
+# testing
+_CN.DATASET.TEST_DATA_SOURCE = None
+_CN.DATASET.TEST_DATA_ROOT = None
+_CN.DATASET.TEST_POSE_ROOT = None # (optional directory for poses)
+_CN.DATASET.TEST_NPZ_ROOT = None
+_CN.DATASET.TEST_LIST_PATH = None # None if test data from all scenes are bundled into a single npz file
+_CN.DATASET.TEST_INTRINSIC_PATH = None
+_CN.DATASET.TEST_IMGSIZE = None
+
+# 2. dataset config
+# general options
+_CN.DATASET.MIN_OVERLAP_SCORE_TRAIN = 0.4 # discard data with overlap_score < min_overlap_score
+_CN.DATASET.MIN_OVERLAP_SCORE_TEST = 0.0
+_CN.DATASET.AUGMENTATION_TYPE = None # options: [None, 'dark', 'mobile']
+
+# MegaDepth options
+_CN.DATASET.MGDPT_IMG_RESIZE = 640 # resize the longer side, zero-pad bottom-right to square.
+_CN.DATASET.MGDPT_IMG_PAD = True # pad img to square with size = MGDPT_IMG_RESIZE
+_CN.DATASET.MGDPT_DEPTH_PAD = True # pad depthmap to square with size = 2000
+_CN.DATASET.MGDPT_DF = 8
+
+############## Trainer ##############
+_CN.TRAINER = CN()
+_CN.TRAINER.WORLD_SIZE = 1
+_CN.TRAINER.CANONICAL_BS = 64
+_CN.TRAINER.CANONICAL_LR = 6e-3
+_CN.TRAINER.SCALING = None # this will be calculated automatically
+_CN.TRAINER.FIND_LR = False # use learning rate finder from pytorch-lightning
+
+# optimizer
+_CN.TRAINER.OPTIMIZER = "adamw" # [adam, adamw]
+_CN.TRAINER.TRUE_LR = None # this will be calculated automatically at runtime
+_CN.TRAINER.ADAM_DECAY = 0. # ADAM: for adam
+_CN.TRAINER.ADAMW_DECAY = 0.01
+
+# step-based warm-up
+_CN.TRAINER.WARMUP_TYPE = 'linear' # [linear, constant]
+_CN.TRAINER.WARMUP_RATIO = 0.
+_CN.TRAINER.WARMUP_STEP = 4800
+
+# learning rate scheduler
+_CN.TRAINER.SCHEDULER = 'MultiStepLR' # [MultiStepLR, CosineAnnealing, ExponentialLR]
+_CN.TRAINER.SCHEDULER_INTERVAL = 'epoch' # [epoch, step]
+_CN.TRAINER.MSLR_MILESTONES = [3, 6, 9, 12] # MSLR: MultiStepLR
+_CN.TRAINER.MSLR_GAMMA = 0.5
+_CN.TRAINER.COSA_TMAX = 30 # COSA: CosineAnnealing
+_CN.TRAINER.ELR_GAMMA = 0.999992 # ELR: ExponentialLR, this value for 'step' interval
+
+# plotting related
+_CN.TRAINER.ENABLE_PLOTTING = True
+_CN.TRAINER.N_VAL_PAIRS_TO_PLOT = 32 # number of val/test pairs for plotting
+_CN.TRAINER.PLOT_MODE = 'evaluation' # ['evaluation', 'confidence']
+_CN.TRAINER.PLOT_MATCHES_ALPHA = 'dynamic'
+
+# geometric metrics and pose solver
+_CN.TRAINER.EPI_ERR_THR = 5e-4 # recommendation: 5e-4 for ScanNet, 1e-4 for MegaDepth (from SuperGlue)
+_CN.TRAINER.POSE_GEO_MODEL = 'E' # ['E', 'F', 'H']
+_CN.TRAINER.POSE_ESTIMATION_METHOD = 'RANSAC' # [RANSAC, DEGENSAC, MAGSAC]
+_CN.TRAINER.RANSAC_PIXEL_THR = 0.5
+_CN.TRAINER.RANSAC_CONF = 0.99999
+_CN.TRAINER.RANSAC_MAX_ITERS = 10000
+_CN.TRAINER.USE_MAGSACPP = False
+
+# data sampler for train_dataloader
+_CN.TRAINER.DATA_SAMPLER = 'scene_balance' # options: ['scene_balance', 'random', 'normal']
+# 'scene_balance' config
+_CN.TRAINER.N_SAMPLES_PER_SUBSET = 200
+_CN.TRAINER.SB_SUBSET_SAMPLE_REPLACEMENT = True # whether to sample each scene with replacement or not
+_CN.TRAINER.SB_SUBSET_SHUFFLE = True # whether to shuffle within the epoch after sampling from scenes
+_CN.TRAINER.SB_REPEAT = 1 # repeat the sampled data N times for training
+# 'random' config
+_CN.TRAINER.RDM_REPLACEMENT = True
+_CN.TRAINER.RDM_NUM_SAMPLES = None
+
+# gradient clipping
+_CN.TRAINER.GRADIENT_CLIPPING = 0.5
+
+# reproducibility
+# This seed affects the data sampling. With the same seed, the data sampling is guaranteed
+# to be the same. When resuming training from a checkpoint, it's better to use a different
+# seed; otherwise the sampled data will be exactly the same as before resuming, which
+# results in fewer unique data items being seen over the entire training run.
+# Using different seed values might affect the final training result, since not all data items
+# are used during training on ScanNet (60M image pairs are sampled during training out of 230M pairs in total).
+_CN.TRAINER.SEED = 66
+
+
+def get_cfg_defaults():
+ """Get a yacs CfgNode object with default values for my_project."""
+ # Return a clone so that the defaults will not be altered
+ # This is for the "local variable" use pattern
+ return _CN.clone()
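+
+
+# A minimal usage sketch (illustrative only): clone the defaults, then override individual
+# options with yacs' merge_from_list; the .py files under configs/model/ instead mutate the
+# shared _CN node directly when imported.
+#
+# cfg = get_cfg_defaults()
+# cfg.merge_from_list(['MODEL.MATCH_COARSE.THR', 0.3, 'TRAINER.SEED', 66])
+# cfg.freeze()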
diff --git a/imcui/third_party/TopicFM/src/datasets/aachen.py b/imcui/third_party/TopicFM/src/datasets/aachen.py
new file mode 100644
index 0000000000000000000000000000000000000000..ebfeee4dbfbd78770976ec027ceee8ef333a4574
--- /dev/null
+++ b/imcui/third_party/TopicFM/src/datasets/aachen.py
@@ -0,0 +1,29 @@
+import os
+from torch.utils.data import Dataset
+
+from src.utils.dataset import read_img_gray
+
+
+class AachenDataset(Dataset):
+ def __init__(self, img_path, match_list_path, img_resize=None, down_factor=16):
+ self.img_path = img_path
+ self.img_resize = img_resize
+ self.down_factor = down_factor
+ with open(match_list_path, 'r') as f:
+ self.raw_pairs = f.readlines()
+ print("number of matching pairs: ", len(self.raw_pairs))
+
+ def __len__(self):
+ return len(self.raw_pairs)
+
+ def __getitem__(self, idx):
+ raw_pair = self.raw_pairs[idx]
+ image_name0, image_name1 = raw_pair.strip('\n').split(' ')
+ path_img0 = os.path.join(self.img_path, image_name0)
+ path_img1 = os.path.join(self.img_path, image_name1)
+ img0, scale0 = read_img_gray(path_img0, resize=self.img_resize, down_factor=self.down_factor)
+ img1, scale1 = read_img_gray(path_img1, resize=self.img_resize, down_factor=self.down_factor)
+ return {"image0": img0, "image1": img1,
+ "scale0": scale0, "scale1": scale1,
+ "pair_names": (image_name0, image_name1),
+ "dataset_name": "AachenDayNight"}
\ No newline at end of file
diff --git a/imcui/third_party/TopicFM/src/datasets/custom_dataloader.py b/imcui/third_party/TopicFM/src/datasets/custom_dataloader.py
new file mode 100644
index 0000000000000000000000000000000000000000..46d55d4f4d56d2c96cd42b6597834f945a5eb20d
--- /dev/null
+++ b/imcui/third_party/TopicFM/src/datasets/custom_dataloader.py
@@ -0,0 +1,126 @@
+from tqdm import tqdm
+from os import path as osp
+from torch.utils.data import Dataset, DataLoader, ConcatDataset
+
+from src.datasets.megadepth import MegaDepthDataset
+from src.datasets.scannet import ScanNetDataset
+from src.datasets.aachen import AachenDataset
+from src.datasets.inloc import InLocDataset
+
+
+class TestDataLoader(DataLoader):
+ """
+ For distributed training, each training process is assgined
+ only a part of the training scenes to reduce memory overhead.
+ """
+
+ def __init__(self, config):
+
+ # 1. data config
+ self.test_data_source = config.DATASET.TEST_DATA_SOURCE
+ dataset_name = str(self.test_data_source).lower()
+ # testing
+ self.test_data_root = config.DATASET.TEST_DATA_ROOT
+ self.test_pose_root = config.DATASET.TEST_POSE_ROOT # (optional)
+ self.test_npz_root = config.DATASET.TEST_NPZ_ROOT
+ self.test_list_path = config.DATASET.TEST_LIST_PATH
+ self.test_intrinsic_path = config.DATASET.TEST_INTRINSIC_PATH
+
+ # 2. dataset config
+ # general options
+ self.min_overlap_score_test = config.DATASET.MIN_OVERLAP_SCORE_TEST # 0.4, omit data with overlap_score < min_overlap_score
+
+ # MegaDepth options
+ if dataset_name == 'megadepth':
+ self.mgdpt_img_resize = config.DATASET.MGDPT_IMG_RESIZE # 800
+ self.mgdpt_img_pad = True
+ self.mgdpt_depth_pad = True
+ self.mgdpt_df = 8
+ self.coarse_scale = 0.125
+ if dataset_name == 'scannet':
+ self.img_resize = config.DATASET.TEST_IMGSIZE
+
+ if (dataset_name == 'megadepth') or (dataset_name == 'scannet'):
+ test_dataset = self._setup_dataset(
+ self.test_data_root,
+ self.test_npz_root,
+ self.test_list_path,
+ self.test_intrinsic_path,
+ mode='test',
+ min_overlap_score=self.min_overlap_score_test,
+ pose_dir=self.test_pose_root)
+ elif dataset_name == 'aachen_v1.1':
+ test_dataset = AachenDataset(self.test_data_root, self.test_list_path,
+ img_resize=config.DATASET.TEST_IMGSIZE)
+ elif dataset_name == 'inloc':
+ test_dataset = InLocDataset(self.test_data_root, self.test_list_path,
+ img_resize=config.DATASET.TEST_IMGSIZE)
+ else:
+ raise "unknown dataset"
+
+ self.test_loader_params = {
+ 'batch_size': 1,
+ 'shuffle': False,
+ 'num_workers': 4,
+ 'pin_memory': True
+ }
+
+ # sampler = Seq(self.test_dataset, shuffle=False)
+ super(TestDataLoader, self).__init__(test_dataset, **self.test_loader_params)
+
+ def _setup_dataset(self,
+ data_root,
+ split_npz_root,
+ scene_list_path,
+ intri_path,
+ mode='train',
+ min_overlap_score=0.,
+ pose_dir=None):
+ """ Setup train / val / test set"""
+ with open(scene_list_path, 'r') as f:
+ npz_names = [name.split()[0] for name in f.readlines()]
+ local_npz_names = npz_names
+
+ return self._build_concat_dataset(data_root, local_npz_names, split_npz_root, intri_path,
+ mode=mode, min_overlap_score=min_overlap_score, pose_dir=pose_dir)
+
+ def _build_concat_dataset(
+ self,
+ data_root,
+ npz_names,
+ npz_dir,
+ intrinsic_path,
+ mode,
+ min_overlap_score=0.,
+ pose_dir=None
+ ):
+ datasets = []
+ # augment_fn = self.augment_fn if mode == 'train' else None
+ data_source = self.test_data_source
+ if str(data_source).lower() == 'megadepth':
+ npz_names = [f'{n}.npz' for n in npz_names]
+ for npz_name in tqdm(npz_names):
+ # `ScanNetDataset`/`MegaDepthDataset` load all data from npz_path when initialized, which might take time.
+ npz_path = osp.join(npz_dir, npz_name)
+ if data_source == 'ScanNet':
+ datasets.append(
+ ScanNetDataset(data_root,
+ npz_path,
+ intrinsic_path,
+ mode=mode, img_resize=self.img_resize,
+ min_overlap_score=min_overlap_score,
+ pose_dir=pose_dir))
+ elif data_source == 'MegaDepth':
+ datasets.append(
+ MegaDepthDataset(data_root,
+ npz_path,
+ mode=mode,
+ min_overlap_score=min_overlap_score,
+ img_resize=self.mgdpt_img_resize,
+ df=self.mgdpt_df,
+ img_padding=self.mgdpt_img_pad,
+ depth_padding=self.mgdpt_depth_pad,
+ coarse_scale=self.coarse_scale))
+ else:
+ raise NotImplementedError()
+ return ConcatDataset(datasets)
diff --git a/imcui/third_party/TopicFM/src/datasets/inloc.py b/imcui/third_party/TopicFM/src/datasets/inloc.py
new file mode 100644
index 0000000000000000000000000000000000000000..5421099d11b4dbbea8c09568c493d844d5c6a1b0
--- /dev/null
+++ b/imcui/third_party/TopicFM/src/datasets/inloc.py
@@ -0,0 +1,29 @@
+import os
+from torch.utils.data import Dataset
+
+from src.utils.dataset import read_img_gray
+
+
+class InLocDataset(Dataset):
+ def __init__(self, img_path, match_list_path, img_resize=None, down_factor=16):
+ self.img_path = img_path
+ self.img_resize = img_resize
+ self.down_factor = down_factor
+ with open(match_list_path, 'r') as f:
+ self.raw_pairs = f.readlines()
+ print("number of matching pairs: ", len(self.raw_pairs))
+
+ def __len__(self):
+ return len(self.raw_pairs)
+
+ def __getitem__(self, idx):
+ raw_pair = self.raw_pairs[idx]
+ image_name0, image_name1 = raw_pair.strip('\n').split(' ')
+ path_img0 = os.path.join(self.img_path, image_name0)
+ path_img1 = os.path.join(self.img_path, image_name1)
+ img0, scale0 = read_img_gray(path_img0, resize=self.img_resize, down_factor=self.down_factor)
+ img1, scale1 = read_img_gray(path_img1, resize=self.img_resize, down_factor=self.down_factor)
+ return {"image0": img0, "image1": img1,
+ "scale0": scale0, "scale1": scale1,
+ "pair_names": (image_name0, image_name1),
+ "dataset_name": "InLoc"}
\ No newline at end of file
diff --git a/imcui/third_party/TopicFM/src/datasets/megadepth.py b/imcui/third_party/TopicFM/src/datasets/megadepth.py
new file mode 100644
index 0000000000000000000000000000000000000000..e92768e72e373c2a8ebeaf1158f9710fb1bfb5f1
--- /dev/null
+++ b/imcui/third_party/TopicFM/src/datasets/megadepth.py
@@ -0,0 +1,129 @@
+import os.path as osp
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torch.utils.data import Dataset
+from loguru import logger
+
+from src.utils.dataset import read_megadepth_gray, read_megadepth_depth
+
+
+class MegaDepthDataset(Dataset):
+ def __init__(self,
+ root_dir,
+ npz_path,
+ mode='train',
+ min_overlap_score=0.4,
+ img_resize=None,
+ df=None,
+ img_padding=False,
+ depth_padding=False,
+ augment_fn=None,
+ **kwargs):
+ """
+ Manage one scene (npz_path) of the MegaDepth dataset.
+
+ Args:
+ root_dir (str): megadepth root directory that has `phoenix`.
+ npz_path (str): {scene_id}.npz path. This contains image pair information of a scene.
+ mode (str): options are ['train', 'val', 'test']
+ min_overlap_score (float): how much a pair should have in common. In range of [0, 1]. Set to 0 when testing.
+ img_resize (int, optional): the longer edge of resized images. None for no resize. 640 is recommended.
+ This is useful during training with batches and testing with memory intensive algorithms.
+ df (int, optional): image size division factor. NOTE: this will change the final image size after img_resize.
+ img_padding (bool): If set to 'True', zero-pad the image to squared size. This is useful during training.
+ depth_padding (bool): If set to 'True', zero-pad depthmap to (2000, 2000). This is useful during training.
+ augment_fn (callable, optional): augments images with pre-defined visual effects.
+ """
+ super().__init__()
+ self.root_dir = root_dir
+ self.mode = mode
+ self.scene_id = npz_path.split('.')[0]
+
+ # prepare scene_info and pair_info
+ if mode == 'test' and min_overlap_score != 0:
+ logger.warning("You are using `min_overlap_score`!=0 in test mode. Set to 0.")
+ min_overlap_score = 0
+ self.scene_info = np.load(npz_path, allow_pickle=True)
+ self.pair_infos = self.scene_info['pair_infos'].copy()
+ del self.scene_info['pair_infos']
+ self.pair_infos = [pair_info for pair_info in self.pair_infos if pair_info[1] > min_overlap_score]
+
+ # parameters for image resizing, padding and depthmap padding
+ if mode == 'train':
+ assert img_resize is not None and img_padding and depth_padding
+ self.img_resize = img_resize
+ if mode == 'val':
+ self.img_resize = 864
+ self.df = df
+ self.img_padding = img_padding
+ self.depth_max_size = 2000 if depth_padding else None # the upper bound of depth map size in MegaDepth.
+
+ # for training LoFTR
+ self.augment_fn = augment_fn if mode == 'train' else None
+ self.coarse_scale = kwargs.get('coarse_scale', 0.125)
+
+ def __len__(self):
+ return len(self.pair_infos)
+
+ def __getitem__(self, idx):
+ (idx0, idx1), overlap_score, central_matches = self.pair_infos[idx]
+
+ # read grayscale image and mask. (1, h, w) and (h, w)
+ img_name0 = osp.join(self.root_dir, self.scene_info['image_paths'][idx0])
+ img_name1 = osp.join(self.root_dir, self.scene_info['image_paths'][idx1])
+
+ # TODO: Support augmentation & handle seeds for each worker correctly.
+ image0, mask0, scale0 = read_megadepth_gray(
+ img_name0, self.img_resize, self.df, self.img_padding, None)
+ # np.random.choice([self.augment_fn, None], p=[0.5, 0.5]))
+ image1, mask1, scale1 = read_megadepth_gray(
+ img_name1, self.img_resize, self.df, self.img_padding, None)
+ # np.random.choice([self.augment_fn, None], p=[0.5, 0.5]))
+
+ # read depth. shape: (h, w)
+ if self.mode in ['train', 'val']:
+ depth0 = read_megadepth_depth(
+ osp.join(self.root_dir, self.scene_info['depth_paths'][idx0]), pad_to=self.depth_max_size)
+ depth1 = read_megadepth_depth(
+ osp.join(self.root_dir, self.scene_info['depth_paths'][idx1]), pad_to=self.depth_max_size)
+ else:
+ depth0 = depth1 = torch.tensor([])
+
+ # read intrinsics of original size
+ K_0 = torch.tensor(self.scene_info['intrinsics'][idx0].copy(), dtype=torch.float).reshape(3, 3)
+ K_1 = torch.tensor(self.scene_info['intrinsics'][idx1].copy(), dtype=torch.float).reshape(3, 3)
+
+ # read and compute relative poses
+ T0 = self.scene_info['poses'][idx0]
+ T1 = self.scene_info['poses'][idx1]
+ T_0to1 = torch.tensor(np.matmul(T1, np.linalg.inv(T0)), dtype=torch.float)[:4, :4] # (4, 4)
+ T_1to0 = T_0to1.inverse()
+
+ data = {
+ 'image0': image0, # (1, h, w)
+ 'depth0': depth0, # (h, w)
+ 'image1': image1,
+ 'depth1': depth1,
+ 'T_0to1': T_0to1, # (4, 4)
+ 'T_1to0': T_1to0,
+ 'K0': K_0, # (3, 3)
+ 'K1': K_1,
+ 'scale0': scale0, # [scale_w, scale_h]
+ 'scale1': scale1,
+ 'dataset_name': 'MegaDepth',
+ 'scene_id': self.scene_id,
+ 'pair_id': idx,
+ 'pair_names': (self.scene_info['image_paths'][idx0], self.scene_info['image_paths'][idx1]),
+ }
+
+ # for LoFTR training
+ if mask0 is not None: # img_padding is True
+ if self.coarse_scale:
+ [ts_mask_0, ts_mask_1] = F.interpolate(torch.stack([mask0, mask1], dim=0)[None].float(),
+ scale_factor=self.coarse_scale,
+ mode='nearest',
+ recompute_scale_factor=False)[0].bool()
+ data.update({'mask0': ts_mask_0, 'mask1': ts_mask_1})
+
+ return data
diff --git a/imcui/third_party/TopicFM/src/datasets/sampler.py b/imcui/third_party/TopicFM/src/datasets/sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..81b6f435645632a013476f9a665a0861ab7fcb61
--- /dev/null
+++ b/imcui/third_party/TopicFM/src/datasets/sampler.py
@@ -0,0 +1,77 @@
+import torch
+from torch.utils.data import Sampler, ConcatDataset
+
+
+class RandomConcatSampler(Sampler):
+ """ Random sampler for ConcatDataset. At each epoch, `n_samples_per_subset` samples will be draw from each subset
+ in the ConcatDataset. If `subset_replacement` is ``True``, sampling within each subset will be done with replacement.
+ However, it is impossible to sample data without replacement between epochs, unless bulding a stateful sampler lived along the entire training phase.
+
+ For current implementation, the randomness of sampling is ensured no matter the sampler is recreated across epochs or not and call `torch.manual_seed()` or not.
+ Args:
+ shuffle (bool): shuffle the random sampled indices across all sub-datsets.
+ repeat (int): repeatedly use the sampled indices multiple times for training.
+ [arXiv:1902.05509, arXiv:1901.09335]
+ NOTE: Don't re-initialize the sampler between epochs (will lead to repeated samples)
+ NOTE: This sampler behaves differently with DistributedSampler.
+ It assume the dataset is splitted across ranks instead of replicated.
+ TODO: Add a `set_epoch()` method to fullfill sampling without replacement across epochs.
+ ref: https://github.com/PyTorchLightning/pytorch-lightning/blob/e9846dd758cfb1500eb9dba2d86f6912eb487587/pytorch_lightning/trainer/training_loop.py#L373
+ """
+ def __init__(self,
+ data_source: ConcatDataset,
+ n_samples_per_subset: int,
+ subset_replacement: bool=True,
+ shuffle: bool=True,
+ repeat: int=1,
+ seed: int=None):
+ if not isinstance(data_source, ConcatDataset):
+ raise TypeError("data_source should be torch.utils.data.ConcatDataset")
+
+ self.data_source = data_source
+ self.n_subset = len(self.data_source.datasets)
+ self.n_samples_per_subset = n_samples_per_subset
+ self.n_samples = self.n_subset * self.n_samples_per_subset * repeat
+ self.subset_replacement = subset_replacement
+ self.repeat = repeat
+ self.shuffle = shuffle
+ self.generator = torch.manual_seed(seed)
+ assert self.repeat >= 1
+
+ def __len__(self):
+ return self.n_samples
+
+ def __iter__(self):
+ indices = []
+ # sample from each sub-dataset
+ for d_idx in range(self.n_subset):
+ low = 0 if d_idx==0 else self.data_source.cumulative_sizes[d_idx-1]
+ high = self.data_source.cumulative_sizes[d_idx]
+ if self.subset_replacement:
+ rand_tensor = torch.randint(low, high, (self.n_samples_per_subset, ),
+ generator=self.generator, dtype=torch.int64)
+ else: # sample without replacement
+ len_subset = len(self.data_source.datasets[d_idx])
+ rand_tensor = torch.randperm(len_subset, generator=self.generator) + low
+ if len_subset >= self.n_samples_per_subset:
+ rand_tensor = rand_tensor[:self.n_samples_per_subset]
+ else: # padding with replacement
+ rand_tensor_replacement = torch.randint(low, high, (self.n_samples_per_subset - len_subset, ),
+ generator=self.generator, dtype=torch.int64)
+ rand_tensor = torch.cat([rand_tensor, rand_tensor_replacement])
+ indices.append(rand_tensor)
+ indices = torch.cat(indices)
+ if self.shuffle: # shuffle the sampled dataset (from multiple subsets)
+ rand_tensor = torch.randperm(len(indices), generator=self.generator)
+ indices = indices[rand_tensor]
+
+ # repeat the sampled indices (can be used for RepeatAugmentation or pure RepeatSampling)
+ if self.repeat > 1:
+ repeat_indices = [indices.clone() for _ in range(self.repeat - 1)]
+ if self.shuffle:
+ _choice = lambda x: x[torch.randperm(len(x), generator=self.generator)]
+ repeat_indices = map(_choice, repeat_indices)
+ indices = torch.cat([indices, *repeat_indices], 0)
+
+ assert indices.shape[0] == self.n_samples
+ return iter(indices.tolist())
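+
+
+# A minimal usage sketch (illustrative only; the toy datasets below are placeholders):
+# draw 4 samples per subset each epoch from a ConcatDataset of two small datasets.
+#
+# from torch.utils.data import DataLoader, TensorDataset
+# concat = ConcatDataset([TensorDataset(torch.arange(10)), TensorDataset(torch.arange(25))])
+# sampler = RandomConcatSampler(concat, n_samples_per_subset=4, subset_replacement=True,
+#                               shuffle=True, repeat=1, seed=66)
+# loader = DataLoader(concat, batch_size=2, sampler=sampler)  # 8 samples per epoch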
diff --git a/imcui/third_party/TopicFM/src/datasets/scannet.py b/imcui/third_party/TopicFM/src/datasets/scannet.py
new file mode 100644
index 0000000000000000000000000000000000000000..fb5dab7b150a3c6f54eb07b0459bbf3e9ba58fbf
--- /dev/null
+++ b/imcui/third_party/TopicFM/src/datasets/scannet.py
@@ -0,0 +1,115 @@
+from os import path as osp
+from typing import Dict
+from unicodedata import name
+
+import numpy as np
+import torch
+import torch.utils as utils
+from numpy.linalg import inv
+from src.utils.dataset import (
+ read_scannet_gray,
+ read_scannet_depth,
+ read_scannet_pose,
+ read_scannet_intrinsic
+)
+
+
+class ScanNetDataset(utils.data.Dataset):
+ def __init__(self,
+ root_dir,
+ npz_path,
+ intrinsic_path,
+ mode='train',
+ min_overlap_score=0.4,
+ augment_fn=None,
+ pose_dir=None,
+ **kwargs):
+ """Manage one scene of ScanNet Dataset.
+ Args:
+ root_dir (str): ScanNet root directory that contains scene folders.
+ npz_path (str): {scene_id}.npz path. This contains image pair information of a scene.
+ intrinsic_path (str): path to depth-camera intrinsic file.
+ mode (str): options are ['train', 'val', 'test'].
+ augment_fn (callable, optional): augments images with pre-defined visual effects.
+ pose_dir (str): ScanNet root directory that contains all poses.
+ (we use a separate (optional) pose_dir since we store images and poses separately.)
+ """
+ super().__init__()
+ self.root_dir = root_dir
+ self.pose_dir = pose_dir if pose_dir is not None else root_dir
+ self.mode = mode
+ self.img_resize = (640, 480) if 'img_resize' not in kwargs else kwargs['img_resize']
+
+ # prepare data_names, intrinsics and extrinsics(T)
+ with np.load(npz_path) as data:
+ self.data_names = data['name']
+ if 'score' in data.keys() and mode not in ['val', 'test']:
+ kept_mask = data['score'] > min_overlap_score
+ self.data_names = self.data_names[kept_mask]
+ self.intrinsics = dict(np.load(intrinsic_path))
+
+ # for training LoFTR
+ self.augment_fn = augment_fn if mode == 'train' else None
+
+ def __len__(self):
+ return len(self.data_names)
+
+ def _read_abs_pose(self, scene_name, name):
+ pth = osp.join(self.pose_dir,
+ scene_name,
+ 'pose', f'{name}.txt')
+ return read_scannet_pose(pth)
+
+ def _compute_rel_pose(self, scene_name, name0, name1):
+ pose0 = self._read_abs_pose(scene_name, name0)
+ pose1 = self._read_abs_pose(scene_name, name1)
+
+ return np.matmul(pose1, inv(pose0)) # (4, 4)
+
+ def __getitem__(self, idx):
+ data_name = self.data_names[idx]
+ scene_name, scene_sub_name, stem_name_0, stem_name_1 = data_name
+ scene_name = f'scene{scene_name:04d}_{scene_sub_name:02d}'
+
+ # read the grayscale image which will be resized to (1, 480, 640)
+ img_name0 = osp.join(self.root_dir, scene_name, 'color', f'{stem_name_0}.jpg')
+ img_name1 = osp.join(self.root_dir, scene_name, 'color', f'{stem_name_1}.jpg')
+
+ # TODO: Support augmentation & handle seeds for each worker correctly.
+ image0 = read_scannet_gray(img_name0, resize=self.img_resize, augment_fn=None)
+ # augment_fn=np.random.choice([self.augment_fn, None], p=[0.5, 0.5]))
+ image1 = read_scannet_gray(img_name1, resize=self.img_resize, augment_fn=None)
+ # augment_fn=np.random.choice([self.augment_fn, None], p=[0.5, 0.5]))
+
+ # read the depthmap which is stored as (480, 640)
+ if self.mode in ['train', 'val']:
+ depth0 = read_scannet_depth(osp.join(self.root_dir, scene_name, 'depth', f'{stem_name_0}.png'))
+ depth1 = read_scannet_depth(osp.join(self.root_dir, scene_name, 'depth', f'{stem_name_1}.png'))
+ else:
+ depth0 = depth1 = torch.tensor([])
+
+ # read the intrinsic of depthmap
+ K_0 = K_1 = torch.tensor(self.intrinsics[scene_name].copy(), dtype=torch.float).reshape(3, 3)
+
+ # read and compute relative poses
+ T_0to1 = torch.tensor(self._compute_rel_pose(scene_name, stem_name_0, stem_name_1),
+ dtype=torch.float32)
+ T_1to0 = T_0to1.inverse()
+
+ data = {
+ 'image0': image0, # (1, h, w)
+ 'depth0': depth0, # (h, w)
+ 'image1': image1,
+ 'depth1': depth1,
+ 'T_0to1': T_0to1, # (4, 4)
+ 'T_1to0': T_1to0,
+ 'K0': K_0, # (3, 3)
+ 'K1': K_1,
+ 'dataset_name': 'ScanNet',
+ 'scene_id': scene_name,
+ 'pair_id': idx,
+ 'pair_names': (osp.join(scene_name, 'color', f'{stem_name_0}.jpg'),
+ osp.join(scene_name, 'color', f'{stem_name_1}.jpg'))
+ }
+
+ return data
diff --git a/imcui/third_party/TopicFM/src/lightning_trainer/data.py b/imcui/third_party/TopicFM/src/lightning_trainer/data.py
new file mode 100644
index 0000000000000000000000000000000000000000..8deb713b6300e0e9e8a261e2230031174b452862
--- /dev/null
+++ b/imcui/third_party/TopicFM/src/lightning_trainer/data.py
@@ -0,0 +1,320 @@
+import os
+import math
+from collections import abc
+from loguru import logger
+from torch.utils.data.dataset import Dataset
+from tqdm import tqdm
+from os import path as osp
+from pathlib import Path
+from joblib import Parallel, delayed
+
+import pytorch_lightning as pl
+from torch import distributed as dist
+from torch.utils.data import (
+ Dataset,
+ DataLoader,
+ ConcatDataset,
+ DistributedSampler,
+ RandomSampler,
+ dataloader
+)
+
+from src.utils.augment import build_augmentor
+from src.utils.dataloader import get_local_split
+from src.utils.misc import tqdm_joblib
+from src.utils import comm
+from src.datasets.megadepth import MegaDepthDataset
+from src.datasets.scannet import ScanNetDataset
+from src.datasets.sampler import RandomConcatSampler
+
+
+class MultiSceneDataModule(pl.LightningDataModule):
+ """
+ For distributed training, each training process is assigned
+ only a part of the training scenes to reduce memory overhead.
+ """
+ def __init__(self, args, config):
+ super().__init__()
+
+ # 1. data config
+ # Train and Val should from the same data source
+ self.trainval_data_source = config.DATASET.TRAINVAL_DATA_SOURCE
+ self.test_data_source = config.DATASET.TEST_DATA_SOURCE
+ # training and validating
+ self.train_data_root = config.DATASET.TRAIN_DATA_ROOT
+ self.train_pose_root = config.DATASET.TRAIN_POSE_ROOT # (optional)
+ self.train_npz_root = config.DATASET.TRAIN_NPZ_ROOT
+ self.train_list_path = config.DATASET.TRAIN_LIST_PATH
+ self.train_intrinsic_path = config.DATASET.TRAIN_INTRINSIC_PATH
+ self.val_data_root = config.DATASET.VAL_DATA_ROOT
+ self.val_pose_root = config.DATASET.VAL_POSE_ROOT # (optional)
+ self.val_npz_root = config.DATASET.VAL_NPZ_ROOT
+ self.val_list_path = config.DATASET.VAL_LIST_PATH
+ self.val_intrinsic_path = config.DATASET.VAL_INTRINSIC_PATH
+ # testing
+ self.test_data_root = config.DATASET.TEST_DATA_ROOT
+ self.test_pose_root = config.DATASET.TEST_POSE_ROOT # (optional)
+ self.test_npz_root = config.DATASET.TEST_NPZ_ROOT
+ self.test_list_path = config.DATASET.TEST_LIST_PATH
+ self.test_intrinsic_path = config.DATASET.TEST_INTRINSIC_PATH
+
+ # 2. dataset config
+ # general options
+ self.min_overlap_score_test = config.DATASET.MIN_OVERLAP_SCORE_TEST # 0.4, omit data with overlap_score < min_overlap_score
+ self.min_overlap_score_train = config.DATASET.MIN_OVERLAP_SCORE_TRAIN
+ self.augment_fn = build_augmentor(config.DATASET.AUGMENTATION_TYPE) # None, options: [None, 'dark', 'mobile']
+
+ # MegaDepth options
+ self.mgdpt_img_resize = config.DATASET.MGDPT_IMG_RESIZE # 840
+ self.mgdpt_img_pad = config.DATASET.MGDPT_IMG_PAD # True
+ self.mgdpt_depth_pad = config.DATASET.MGDPT_DEPTH_PAD # True
+ self.mgdpt_df = config.DATASET.MGDPT_DF # 8
+ self.coarse_scale = 1 / config.MODEL.RESOLUTION[0] # 0.125. for training loftr.
+
+ # 3.loader parameters
+ self.train_loader_params = {
+ 'batch_size': args.batch_size,
+ 'num_workers': args.num_workers,
+ 'pin_memory': getattr(args, 'pin_memory', True)
+ }
+ self.val_loader_params = {
+ 'batch_size': 1,
+ 'shuffle': False,
+ 'num_workers': args.num_workers,
+ 'pin_memory': getattr(args, 'pin_memory', True)
+ }
+ self.test_loader_params = {
+ 'batch_size': 1,
+ 'shuffle': False,
+ 'num_workers': args.num_workers,
+ 'pin_memory': True
+ }
+
+ # 4. sampler
+ self.data_sampler = config.TRAINER.DATA_SAMPLER
+ self.n_samples_per_subset = config.TRAINER.N_SAMPLES_PER_SUBSET
+ self.subset_replacement = config.TRAINER.SB_SUBSET_SAMPLE_REPLACEMENT
+ self.shuffle = config.TRAINER.SB_SUBSET_SHUFFLE
+ self.repeat = config.TRAINER.SB_REPEAT
+
+ # (optional) RandomSampler for debugging
+
+ # misc configurations
+ self.parallel_load_data = getattr(args, 'parallel_load_data', False)
+ self.seed = config.TRAINER.SEED # 66
+
+ def setup(self, stage=None):
+ """
+ Setup train / val / test dataset. This method will be called by PL automatically.
+ Args:
+ stage (str): 'fit' in training phase, and 'test' in testing phase.
+ """
+
+ assert stage in ['fit', 'test'], "stage must be either fit or test"
+
+ try:
+ self.world_size = dist.get_world_size()
+ self.rank = dist.get_rank()
+ logger.info(f"[rank:{self.rank}] world_size: {self.world_size}")
+ except AssertionError as ae:
+ self.world_size = 1
+ self.rank = 0
+ logger.warning(str(ae) + " (set world_size=1 and rank=0)")
+
+ if stage == 'fit':
+ self.train_dataset = self._setup_dataset(
+ self.train_data_root,
+ self.train_npz_root,
+ self.train_list_path,
+ self.train_intrinsic_path,
+ mode='train',
+ min_overlap_score=self.min_overlap_score_train,
+ pose_dir=self.train_pose_root)
+ # setup multiple (optional) validation subsets
+ if isinstance(self.val_list_path, (list, tuple)):
+ self.val_dataset = []
+ if not isinstance(self.val_npz_root, (list, tuple)):
+ self.val_npz_root = [self.val_npz_root for _ in range(len(self.val_list_path))]
+ for npz_list, npz_root in zip(self.val_list_path, self.val_npz_root):
+ self.val_dataset.append(self._setup_dataset(
+ self.val_data_root,
+ npz_root,
+ npz_list,
+ self.val_intrinsic_path,
+ mode='val',
+ min_overlap_score=self.min_overlap_score_test,
+ pose_dir=self.val_pose_root))
+ else:
+ self.val_dataset = self._setup_dataset(
+ self.val_data_root,
+ self.val_npz_root,
+ self.val_list_path,
+ self.val_intrinsic_path,
+ mode='val',
+ min_overlap_score=self.min_overlap_score_test,
+ pose_dir=self.val_pose_root)
+ logger.info(f'[rank:{self.rank}] Train & Val Dataset loaded!')
+ else: # stage == 'test'
+ self.test_dataset = self._setup_dataset(
+ self.test_data_root,
+ self.test_npz_root,
+ self.test_list_path,
+ self.test_intrinsic_path,
+ mode='test',
+ min_overlap_score=self.min_overlap_score_test,
+ pose_dir=self.test_pose_root)
+ logger.info(f'[rank:{self.rank}]: Test Dataset loaded!')
+
+ def _setup_dataset(self,
+ data_root,
+ split_npz_root,
+ scene_list_path,
+ intri_path,
+ mode='train',
+ min_overlap_score=0.,
+ pose_dir=None):
+ """ Setup train / val / test set"""
+ with open(scene_list_path, 'r') as f:
+ npz_names = [name.split()[0] for name in f.readlines()]
+
+ if mode == 'train':
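+ # get_local_split is expected to shard the scene list deterministically (using the seed),
+ # so that each DDP rank only builds datasets for its own disjoint subset of scenes.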
+ local_npz_names = get_local_split(npz_names, self.world_size, self.rank, self.seed)
+ else:
+ local_npz_names = npz_names
+ logger.info(f'[rank {self.rank}]: {len(local_npz_names)} scene(s) assigned.')
+
+ dataset_builder = self._build_concat_dataset_parallel \
+ if self.parallel_load_data \
+ else self._build_concat_dataset
+ return dataset_builder(data_root, local_npz_names, split_npz_root, intri_path,
+ mode=mode, min_overlap_score=min_overlap_score, pose_dir=pose_dir)
+
+ def _build_concat_dataset(
+ self,
+ data_root,
+ npz_names,
+ npz_dir,
+ intrinsic_path,
+ mode,
+ min_overlap_score=0.,
+ pose_dir=None
+ ):
+ datasets = []
+ augment_fn = self.augment_fn if mode == 'train' else None
+ data_source = self.trainval_data_source if mode in ['train', 'val'] else self.test_data_source
+ if str(data_source).lower() == 'megadepth':
+ npz_names = [f'{n}.npz' for n in npz_names]
+ for npz_name in tqdm(npz_names,
+ desc=f'[rank:{self.rank}] loading {mode} datasets',
+ disable=int(self.rank) != 0):
+ # `ScanNetDataset`/`MegaDepthDataset` load all data from npz_path when initialized, which might take time.
+ npz_path = osp.join(npz_dir, npz_name)
+ if data_source == 'ScanNet':
+ datasets.append(
+ ScanNetDataset(data_root,
+ npz_path,
+ intrinsic_path,
+ mode=mode,
+ min_overlap_score=min_overlap_score,
+ augment_fn=augment_fn,
+ pose_dir=pose_dir))
+ elif data_source == 'MegaDepth':
+ datasets.append(
+ MegaDepthDataset(data_root,
+ npz_path,
+ mode=mode,
+ min_overlap_score=min_overlap_score,
+ img_resize=self.mgdpt_img_resize,
+ df=self.mgdpt_df,
+ img_padding=self.mgdpt_img_pad,
+ depth_padding=self.mgdpt_depth_pad,
+ augment_fn=augment_fn,
+ coarse_scale=self.coarse_scale))
+ else:
+ raise NotImplementedError()
+ return ConcatDataset(datasets)
+
+ def _build_concat_dataset_parallel(
+ self,
+ data_root,
+ npz_names,
+ npz_dir,
+ intrinsic_path,
+ mode,
+ min_overlap_score=0.,
+ pose_dir=None,
+ ):
+ augment_fn = self.augment_fn if mode == 'train' else None
+ data_source = self.trainval_data_source if mode in ['train', 'val'] else self.test_data_source
+ if str(data_source).lower() == 'megadepth':
+ npz_names = [f'{n}.npz' for n in npz_names]
+ with tqdm_joblib(tqdm(desc=f'[rank:{self.rank}] loading {mode} datasets',
+ total=len(npz_names), disable=int(self.rank) != 0)):
+ if data_source == 'ScanNet':
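+ # n_jobs: use ~90% of the CPUs visible to this process, divided by the number of
+ # local (per-node) ranks, so that concurrent ranks do not oversubscribe the machine.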
+ datasets = Parallel(n_jobs=math.floor(len(os.sched_getaffinity(0)) * 0.9 / comm.get_local_size()))(
+ delayed(lambda x: _build_dataset(
+ ScanNetDataset,
+ data_root,
+ osp.join(npz_dir, x),
+ intrinsic_path,
+ mode=mode,
+ min_overlap_score=min_overlap_score,
+ augment_fn=augment_fn,
+ pose_dir=pose_dir))(name)
+ for name in npz_names)
+ elif data_source == 'MegaDepth':
+ # TODO: _pickle.PicklingError: Could not pickle the task to send it to the workers.
+ raise NotImplementedError()
+ datasets = Parallel(n_jobs=math.floor(len(os.sched_getaffinity(0)) * 0.9 / comm.get_local_size()))(
+ delayed(lambda x: _build_dataset(
+ MegaDepthDataset,
+ data_root,
+ osp.join(npz_dir, x),
+ mode=mode,
+ min_overlap_score=min_overlap_score,
+ img_resize=self.mgdpt_img_resize,
+ df=self.mgdpt_df,
+ img_padding=self.mgdpt_img_pad,
+ depth_padding=self.mgdpt_depth_pad,
+ augment_fn=augment_fn,
+ coarse_scale=self.coarse_scale))(name)
+ for name in npz_names)
+ else:
+ raise ValueError(f'Unknown dataset: {data_source}')
+ return ConcatDataset(datasets)
+
+ def train_dataloader(self):
+ """ Build training dataloader for ScanNet / MegaDepth. """
+ assert self.data_sampler in ['scene_balance']
+ logger.info(f'[rank:{self.rank}/{self.world_size}]: Train Sampler and DataLoader re-init (should not re-init between epochs!).')
+ if self.data_sampler == 'scene_balance':
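+ # scene-balanced sampling: RandomConcatSampler presumably draws n_samples_per_subset
+ # pairs from every scene-level subset of the ConcatDataset each epoch, with the
+ # configured replacement / shuffling / repeat behaviour.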
+ sampler = RandomConcatSampler(self.train_dataset,
+ self.n_samples_per_subset,
+ self.subset_replacement,
+ self.shuffle, self.repeat, self.seed)
+ else:
+ sampler = None
+ dataloader = DataLoader(self.train_dataset, sampler=sampler, **self.train_loader_params)
+ return dataloader
+
+ def val_dataloader(self):
+ """ Build validation dataloader for ScanNet / MegaDepth. """
+ logger.info(f'[rank:{self.rank}/{self.world_size}]: Val Sampler and DataLoader re-init.')
+ if not isinstance(self.val_dataset, abc.Sequence):
+ sampler = DistributedSampler(self.val_dataset, shuffle=False)
+ return DataLoader(self.val_dataset, sampler=sampler, **self.val_loader_params)
+ else:
+ dataloaders = []
+ for dataset in self.val_dataset:
+ sampler = DistributedSampler(dataset, shuffle=False)
+ dataloaders.append(DataLoader(dataset, sampler=sampler, **self.val_loader_params))
+ return dataloaders
+
+ def test_dataloader(self, *args, **kwargs):
+ logger.info(f'[rank:{self.rank}/{self.world_size}]: Test Sampler and DataLoader re-init.')
+ sampler = DistributedSampler(self.test_dataset, shuffle=False)
+ return DataLoader(self.test_dataset, sampler=sampler, **self.test_loader_params)
+
+
+def _build_dataset(dataset: Dataset, *args, **kwargs):
+ return dataset(*args, **kwargs)
diff --git a/imcui/third_party/TopicFM/src/lightning_trainer/trainer.py b/imcui/third_party/TopicFM/src/lightning_trainer/trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..acf51f66130be66b7d3294ca5c081a2df3856d96
--- /dev/null
+++ b/imcui/third_party/TopicFM/src/lightning_trainer/trainer.py
@@ -0,0 +1,244 @@
+
+from collections import defaultdict
+import pprint
+from loguru import logger
+from pathlib import Path
+
+import torch
+import numpy as np
+import pytorch_lightning as pl
+from matplotlib import pyplot as plt
+
+from src.models import TopicFM
+from src.models.utils.supervision import compute_supervision_coarse, compute_supervision_fine
+from src.losses.loss import TopicFMLoss
+from src.optimizers import build_optimizer, build_scheduler
+from src.utils.metrics import (
+ compute_symmetrical_epipolar_errors,
+ compute_pose_errors,
+ aggregate_metrics
+)
+from src.utils.plotting import make_matching_figures
+from src.utils.comm import gather, all_gather
+from src.utils.misc import lower_config, flattenList
+from src.utils.profiler import PassThroughProfiler
+
+
+class PL_Trainer(pl.LightningModule):
+ def __init__(self, config, pretrained_ckpt=None, profiler=None, dump_dir=None):
+ """
+ TODO:
+ - use the new version of PL logging API.
+ """
+ super().__init__()
+ # Misc
+ self.config = config # full config
+ _config = lower_config(self.config)
+ self.model_cfg = lower_config(_config['model'])
+ self.profiler = profiler or PassThroughProfiler()
+ self.n_vals_plot = max(config.TRAINER.N_VAL_PAIRS_TO_PLOT // config.TRAINER.WORLD_SIZE, 1)
+
+ # Matcher: TopicFM
+ self.matcher = TopicFM(config=_config['model'])
+ self.loss = TopicFMLoss(_config)
+
+ # Pretrained weights
+ if pretrained_ckpt:
+ state_dict = torch.load(pretrained_ckpt, map_location='cpu')['state_dict']
+ self.matcher.load_state_dict(state_dict, strict=True)
+ logger.info(f"Load \'{pretrained_ckpt}\' as pretrained checkpoint")
+
+ # Testing
+ self.dump_dir = dump_dir
+
+ def configure_optimizers(self):
+ # FIXME: The scheduler did not work properly when `--resume_from_checkpoint`
+ optimizer = build_optimizer(self, self.config)
+ scheduler = build_scheduler(self.config, optimizer)
+ return [optimizer], [scheduler]
+
+ def optimizer_step(
+ self, epoch, batch_idx, optimizer, optimizer_idx,
+ optimizer_closure, on_tpu, using_native_amp, using_lbfgs):
+ # learning rate warm up
+ warmup_step = self.config.TRAINER.WARMUP_STEP
+ if self.trainer.global_step < warmup_step:
+ if self.config.TRAINER.WARMUP_TYPE == 'linear':
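+ # ramp the learning rate linearly from WARMUP_RATIO * TRUE_LR towards TRUE_LR
+ # over the first WARMUP_STEP optimization steps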
+ base_lr = self.config.TRAINER.WARMUP_RATIO * self.config.TRAINER.TRUE_LR
+ lr = base_lr + \
+ (self.trainer.global_step / self.config.TRAINER.WARMUP_STEP) * \
+ abs(self.config.TRAINER.TRUE_LR - base_lr)
+ for pg in optimizer.param_groups:
+ pg['lr'] = lr
+ elif self.config.TRAINER.WARMUP_TYPE == 'constant':
+ pass
+ else:
+ raise ValueError(f'Unknown lr warm-up strategy: {self.config.TRAINER.WARMUP_TYPE}')
+
+ # update params
+ optimizer.step(closure=optimizer_closure)
+ optimizer.zero_grad()
+
+ def _trainval_inference(self, batch):
+ with self.profiler.profile("Compute coarse supervision"):
+ compute_supervision_coarse(batch, self.config)
+
+ with self.profiler.profile("TopicFM"):
+ self.matcher(batch)
+
+ with self.profiler.profile("Compute fine supervision"):
+ compute_supervision_fine(batch, self.config)
+
+ with self.profiler.profile("Compute losses"):
+ self.loss(batch)
+
+ def _compute_metrics(self, batch):
+ with self.profiler.profile("Copmute metrics"):
+ compute_symmetrical_epipolar_errors(batch) # compute epi_errs for each match
+ compute_pose_errors(batch, self.config) # compute R_errs, t_errs, pose_errs for each pair
+
+ rel_pair_names = list(zip(*batch['pair_names']))
+ bs = batch['image0'].size(0)
+ metrics = {
+ # to filter duplicate pairs caused by DistributedSampler
+ 'identifiers': ['#'.join(rel_pair_names[b]) for b in range(bs)],
+ 'epi_errs': [batch['epi_errs'][batch['m_bids'] == b].cpu().numpy() for b in range(bs)],
+ 'R_errs': batch['R_errs'],
+ 't_errs': batch['t_errs'],
+ 'inliers': batch['inliers']}
+ ret_dict = {'metrics': metrics}
+ return ret_dict, rel_pair_names
+
+ def training_step(self, batch, batch_idx):
+ self._trainval_inference(batch)
+
+ # logging
+ if self.trainer.global_rank == 0 and self.global_step % self.trainer.log_every_n_steps == 0:
+ # scalars
+ for k, v in batch['loss_scalars'].items():
+ self.logger.experiment.add_scalar(f'train/{k}', v, self.global_step)
+
+ # figures
+ if self.config.TRAINER.ENABLE_PLOTTING:
+ compute_symmetrical_epipolar_errors(batch) # compute epi_errs for each match
+ figures = make_matching_figures(batch, self.config, self.config.TRAINER.PLOT_MODE)
+ for k, v in figures.items():
+ self.logger.experiment.add_figure(f'train_match/{k}', v, self.global_step)
+
+ return {'loss': batch['loss']}
+
+ def training_epoch_end(self, outputs):
+ avg_loss = torch.stack([x['loss'] for x in outputs]).mean()
+ if self.trainer.global_rank == 0:
+ self.logger.experiment.add_scalar(
+ 'train/avg_loss_on_epoch', avg_loss,
+ global_step=self.current_epoch)
+
+ def validation_step(self, batch, batch_idx):
+ self._trainval_inference(batch)
+
+ ret_dict, _ = self._compute_metrics(batch)
+
+ val_plot_interval = max(self.trainer.num_val_batches[0] // self.n_vals_plot, 1)
+ figures = {self.config.TRAINER.PLOT_MODE: []}
+ if batch_idx % val_plot_interval == 0:
+ figures = make_matching_figures(batch, self.config, mode=self.config.TRAINER.PLOT_MODE)
+
+ return {
+ **ret_dict,
+ 'loss_scalars': batch['loss_scalars'],
+ 'figures': figures,
+ }
+
+ def validation_epoch_end(self, outputs):
+ # handle multiple validation sets
+ multi_outputs = [outputs] if not isinstance(outputs[0], (list, tuple)) else outputs
+ multi_val_metrics = defaultdict(list)
+
+ for valset_idx, outputs in enumerate(multi_outputs):
+ # since PL performs the sanity check at the very beginning of training
+ cur_epoch = self.trainer.current_epoch
+ if not self.trainer.resume_from_checkpoint and self.trainer.running_sanity_check:
+ cur_epoch = -1
+
+ # 1. loss_scalars: dict of list, on cpu
+ _loss_scalars = [o['loss_scalars'] for o in outputs]
+ loss_scalars = {k: flattenList(all_gather([_ls[k] for _ls in _loss_scalars])) for k in _loss_scalars[0]}
+
+ # 2. val metrics: dict of list, numpy
+ _metrics = [o['metrics'] for o in outputs]
+ metrics = {k: flattenList(all_gather(flattenList([_me[k] for _me in _metrics]))) for k in _metrics[0]}
+ # NOTE: all ranks need to run `aggregate_metrics`, but only rank-0 logs
+ val_metrics_4tb = aggregate_metrics(metrics, self.config.TRAINER.EPI_ERR_THR)
+ for thr in [5, 10, 20]:
+ multi_val_metrics[f'auc@{thr}'].append(val_metrics_4tb[f'auc@{thr}'])
+
+ # 3. figures
+ _figures = [o['figures'] for o in outputs]
+ figures = {k: flattenList(gather(flattenList([_me[k] for _me in _figures]))) for k in _figures[0]}
+
+ # tensorboard records only on rank 0
+ if self.trainer.global_rank == 0:
+ for k, v in loss_scalars.items():
+ mean_v = torch.stack(v).mean()
+ self.logger.experiment.add_scalar(f'val_{valset_idx}/avg_{k}', mean_v, global_step=cur_epoch)
+
+ for k, v in val_metrics_4tb.items():
+ self.logger.experiment.add_scalar(f"metrics_{valset_idx}/{k}", v, global_step=cur_epoch)
+
+ for k, v in figures.items():
+ if self.trainer.global_rank == 0:
+ for plot_idx, fig in enumerate(v):
+ self.logger.experiment.add_figure(
+ f'val_match_{valset_idx}/{k}/pair-{plot_idx}', fig, cur_epoch, close=True)
+ plt.close('all')
+
+ for thr in [5, 10, 20]:
+ # log on all ranks for ModelCheckpoint callback to work properly
+ self.log(f'auc@{thr}', torch.tensor(np.mean(multi_val_metrics[f'auc@{thr}']))) # ckpt monitors on this
+
+ def test_step(self, batch, batch_idx):
+ with self.profiler.profile("TopicFM"):
+ self.matcher(batch)
+
+ ret_dict, rel_pair_names = self._compute_metrics(batch)
+
+ with self.profiler.profile("dump_results"):
+ if self.dump_dir is not None:
+ # dump results for further analysis
+ keys_to_save = {'mkpts0_f', 'mkpts1_f', 'mconf', 'epi_errs'}
+ pair_names = list(zip(*batch['pair_names']))
+ bs = batch['image0'].shape[0]
+ dumps = []
+ for b_id in range(bs):
+ item = {}
+ mask = batch['m_bids'] == b_id
+ item['pair_names'] = pair_names[b_id]
+ item['identifier'] = '#'.join(rel_pair_names[b_id])
+ for key in keys_to_save:
+ item[key] = batch[key][mask].cpu().numpy()
+ for key in ['R_errs', 't_errs', 'inliers']:
+ item[key] = batch[key][b_id]
+ dumps.append(item)
+ ret_dict['dumps'] = dumps
+
+ return ret_dict
+
+ def test_epoch_end(self, outputs):
+ # metrics: dict of list, numpy
+ _metrics = [o['metrics'] for o in outputs]
+ metrics = {k: flattenList(gather(flattenList([_me[k] for _me in _metrics]))) for k in _metrics[0]}
+
+ # [{key: [{...}, *#bs]}, *#batch]
+ if self.dump_dir is not None:
+ Path(self.dump_dir).mkdir(parents=True, exist_ok=True)
+ _dumps = flattenList([o['dumps'] for o in outputs]) # [{...}, #bs*#batch]
+ dumps = flattenList(gather(_dumps)) # [{...}, #proc*#bs*#batch]
+ logger.info(f'Prediction and evaluation results will be saved to: {self.dump_dir}')
+
+ if self.trainer.global_rank == 0:
+ print(self.profiler.summary())
+ val_metrics_4tb = aggregate_metrics(metrics, self.config.TRAINER.EPI_ERR_THR)
+ logger.info('\n' + pprint.pformat(val_metrics_4tb))
+ if self.dump_dir is not None:
+ np.save(Path(self.dump_dir) / 'TopicFM_pred_eval', dumps)
diff --git a/imcui/third_party/TopicFM/src/losses/loss.py b/imcui/third_party/TopicFM/src/losses/loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..4be58498579c9fe649ed0ce2d42f230e59cef581
--- /dev/null
+++ b/imcui/third_party/TopicFM/src/losses/loss.py
@@ -0,0 +1,182 @@
+from loguru import logger
+
+import torch
+import torch.nn as nn
+
+
+def sample_non_matches(pos_mask, match_ids=None, sampling_ratio=10):
+ # assert (pos_mask.shape == mask.shape) # [B, H*W, H*W]
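+ # For every ground-truth match (b, i, j), draw `sampling_ratio` pseudo-random column
+ # offsets of the form 3*d + 1 (modulo H*W) and mark those (b, i, j') entries as negatives;
+ # without match_ids, every non-positive entry is treated as a negative.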
+ if match_ids is not None:
+ HW = pos_mask.shape[1]
+ b_ids, i_ids, j_ids = match_ids
+ if len(b_ids) == 0:
+ return ~pos_mask
+
+ neg_mask = torch.zeros_like(pos_mask)
+ probs = torch.ones((HW - 1)//3, device=pos_mask.device)
+ for _ in range(sampling_ratio):
+ d = torch.multinomial(probs, len(j_ids), replacement=True)
+ sampled_j_ids = (j_ids + d*3 + 1) % HW
+ neg_mask[b_ids, i_ids, sampled_j_ids] = True
+ # neg_mask = neg_matrix == 1
+ else:
+ neg_mask = ~pos_mask
+
+ return neg_mask
+
+
+class TopicFMLoss(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config # config under the global namespace
+ self.loss_config = config['model']['loss']
+ self.match_type = self.config['model']['match_coarse']['match_type']
+
+ # coarse-level
+ self.correct_thr = self.loss_config['fine_correct_thr']
+ self.c_pos_w = self.loss_config['pos_weight']
+ self.c_neg_w = self.loss_config['neg_weight']
+ # fine-level
+ self.fine_type = self.loss_config['fine_type']
+
+ def compute_coarse_loss(self, conf, topic_mat, conf_gt, match_ids=None, weight=None):
+ """ Point-wise CE / Focal Loss with 0 / 1 confidence as gt.
+ Args:
+ conf (torch.Tensor): (N, HW0, HW1) / (N, HW0+1, HW1+1)
+ conf_gt (torch.Tensor): (N, HW0, HW1)
+ weight (torch.Tensor): (N, HW0, HW1)
+ """
+ pos_mask = conf_gt == 1
+ neg_mask = sample_non_matches(pos_mask, match_ids=match_ids)
+ c_pos_w, c_neg_w = self.c_pos_w, self.c_neg_w
+ # corner case: no gt coarse-level match at all
+ if not pos_mask.any(): # assign a wrong gt
+ pos_mask[0, 0, 0] = True
+ if weight is not None:
+ weight[0, 0, 0] = 0.
+ c_pos_w = 0.
+ if not neg_mask.any():
+ neg_mask[0, 0, 0] = True
+ if weight is not None:
+ weight[0, 0, 0] = 0.
+ c_neg_w = 0.
+
+ conf = torch.clamp(conf, 1e-6, 1 - 1e-6)
+ alpha = self.loss_config['focal_alpha']
+
+ loss = 0.0
+ if isinstance(topic_mat, torch.Tensor):
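+ # alpha-weighted binary cross-entropy on the topic co-assignment matrix: push the
+ # scores of ground-truth matches towards 1 and of sampled non-matches towards 0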
+ pos_topic = topic_mat[pos_mask]
+ loss_pos_topic = - alpha * (pos_topic + 1e-6).log()
+ neg_topic = topic_mat[neg_mask]
+ loss_neg_topic = - alpha * (1 - neg_topic + 1e-6).log()
+ if weight is not None:
+ loss_pos_topic = loss_pos_topic * weight[pos_mask]
+ loss_neg_topic = loss_neg_topic * weight[neg_mask]
+ loss = loss_pos_topic.mean() + loss_neg_topic.mean()
+
+ pos_conf = conf[pos_mask]
+ loss_pos = - alpha * pos_conf.log()
+ # handle loss weights
+ if weight is not None:
+ # Different from dense supervision, the loss w.r.t. padded regions is not zeroed out here;
+ # padded regions are instead handled by setting the corresponding entries of sim_matrix to '-inf'.
+ loss_pos = loss_pos * weight[pos_mask]
+
+ loss = loss + c_pos_w * loss_pos.mean()
+
+ return loss
+
+ def compute_fine_loss(self, expec_f, expec_f_gt):
+ if self.fine_type == 'l2_with_std':
+ return self._compute_fine_loss_l2_std(expec_f, expec_f_gt)
+ elif self.fine_type == 'l2':
+ return self._compute_fine_loss_l2(expec_f, expec_f_gt)
+ else:
+ raise NotImplementedError()
+
+ def _compute_fine_loss_l2(self, expec_f, expec_f_gt):
+ """
+ Args:
+ expec_f (torch.Tensor): [M, 2]
+ expec_f_gt (torch.Tensor): [M, 2]
+ """
+ correct_mask = torch.linalg.norm(expec_f_gt, ord=float('inf'), dim=1) < self.correct_thr
+ if correct_mask.sum() == 0:
+ if self.training: # this seldom happens during training, since we pad predictions with gt
+ logger.warning("assign a false supervision to avoid ddp deadlock")
+ correct_mask[0] = True
+ else:
+ return None
+ offset_l2 = ((expec_f_gt[correct_mask] - expec_f[correct_mask]) ** 2).sum(-1)
+ return offset_l2.mean()
+
+ def _compute_fine_loss_l2_std(self, expec_f, expec_f_gt):
+ """
+ Args:
+ expec_f (torch.Tensor): [M, 3]
+ expec_f_gt (torch.Tensor): [M, 2]
+ """
+ # correct_mask tells you which pair to compute fine-loss
+ correct_mask = torch.linalg.norm(expec_f_gt, ord=float('inf'), dim=1) < self.correct_thr
+
+ # use std as weight that measures uncertainty
+ std = expec_f[:, 2]
+ inverse_std = 1. / torch.clamp(std, min=1e-10)
+ weight = (inverse_std / torch.mean(inverse_std)).detach() # avoid minimizing the loss by simply increasing std
+
+ # corner case: no correct coarse match found
+ if not correct_mask.any():
+ if self.training: # this seldom happens during training, since we pad predictions with gt
+ # sometimes there is no coarse-level gt at all
+ logger.warning("assign a false supervision to avoid ddp deadlock")
+ correct_mask[0] = True
+ weight[0] = 0.
+ else:
+ return None
+
+ # l2 loss with std
+ offset_l2 = ((expec_f_gt[correct_mask] - expec_f[correct_mask, :2]) ** 2).sum(-1)
+ loss = (offset_l2 * weight[correct_mask]).mean()
+
+ return loss
+
+ @torch.no_grad()
+ def compute_c_weight(self, data):
+ """ compute element-wise weights for computing coarse-level loss. """
+ if 'mask0' in data:
+ c_weight = (data['mask0'].flatten(-2)[..., None] * data['mask1'].flatten(-2)[:, None]).float()
+ else:
+ c_weight = None
+ return c_weight
+
+ def forward(self, data):
+ """
+ Update:
+ data (dict): update{
+ 'loss': [1] the reduced loss across a batch,
+ 'loss_scalars' (dict): loss scalars for tensorboard_record
+ }
+ """
+ loss_scalars = {}
+ # 0. compute element-wise loss weight
+ c_weight = self.compute_c_weight(data)
+
+ # 1. coarse-level loss
+ loss_c = self.compute_coarse_loss(data['conf_matrix'], data['topic_matrix'],
+ data['conf_matrix_gt'], match_ids=(data['spv_b_ids'], data['spv_i_ids'], data['spv_j_ids']),
+ weight=c_weight)
+ loss = loss_c * self.loss_config['coarse_weight']
+ loss_scalars.update({"loss_c": loss_c.clone().detach().cpu()})
+
+ # 2. fine-level loss
+ loss_f = self.compute_fine_loss(data['expec_f'], data['expec_f_gt'])
+ if loss_f is not None:
+ loss += loss_f * self.loss_config['fine_weight']
+ loss_scalars.update({"loss_f": loss_f.clone().detach().cpu()})
+ else:
+ assert self.training is False
+ loss_scalars.update({'loss_f': torch.tensor(1.)}) # 1 is the upper bound
+
+ loss_scalars.update({'loss': loss.clone().detach().cpu()})
+ data.update({"loss": loss, "loss_scalars": loss_scalars})
diff --git a/imcui/third_party/TopicFM/src/models/__init__.py b/imcui/third_party/TopicFM/src/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9abdbdaebbf6c91a6fdc24e23d62c73003b204bf
--- /dev/null
+++ b/imcui/third_party/TopicFM/src/models/__init__.py
@@ -0,0 +1 @@
+from .topic_fm import TopicFM
diff --git a/imcui/third_party/TopicFM/src/models/backbone/__init__.py b/imcui/third_party/TopicFM/src/models/backbone/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..53f98db4e910b46716bed7cfc6ebbf8c8bfad399
--- /dev/null
+++ b/imcui/third_party/TopicFM/src/models/backbone/__init__.py
@@ -0,0 +1,5 @@
+from .fpn import FPN
+
+
+def build_backbone(config):
+ return FPN(config['fpn'])
diff --git a/imcui/third_party/TopicFM/src/models/backbone/fpn.py b/imcui/third_party/TopicFM/src/models/backbone/fpn.py
new file mode 100644
index 0000000000000000000000000000000000000000..93cc475f57317f9dbb8132cdfe0297391972f9e2
--- /dev/null
+++ b/imcui/third_party/TopicFM/src/models/backbone/fpn.py
@@ -0,0 +1,109 @@
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+def conv1x1(in_planes, out_planes, stride=1):
+ """1x1 convolution without padding"""
+ return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, padding=0, bias=False)
+
+
+def conv3x3(in_planes, out_planes, stride=1):
+ """3x3 convolution with padding"""
+ return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
+
+
+class ConvBlock(nn.Module):
+ def __init__(self, in_planes, planes, stride=1, bn=True):
+ super().__init__()
+ self.conv = conv3x3(in_planes, planes, stride)
+ self.bn = nn.BatchNorm2d(planes) if bn is True else None
+ self.act = nn.GELU()
+
+ def forward(self, x):
+ y = self.conv(x)
+ if self.bn:
+ y = self.bn(y) #F.layer_norm(y, y.shape[1:])
+ y = self.act(y)
+ return y
+
+
+class FPN(nn.Module):
+ """
+ ResNet+FPN, output resolutions are 1/8 and 1/2.
+ Each block has 2 layers.
+ """
+
+ def __init__(self, config):
+ super().__init__()
+ # Config
+ block = ConvBlock
+ initial_dim = config['initial_dim']
+ block_dims = config['block_dims']
+
+ # Class Variable
+ self.in_planes = initial_dim
+
+ # Networks
+ self.conv1 = nn.Conv2d(1, initial_dim, kernel_size=7, stride=2, padding=3, bias=False)
+ self.bn1 = nn.BatchNorm2d(initial_dim)
+ self.relu = nn.ReLU(inplace=True)
+
+ self.layer1 = self._make_layer(block, block_dims[0], stride=1) # 1/2
+ self.layer2 = self._make_layer(block, block_dims[1], stride=2) # 1/4
+ self.layer3 = self._make_layer(block, block_dims[2], stride=2) # 1/8
+ self.layer4 = self._make_layer(block, block_dims[3], stride=2) # 1/16
+
+ # 3. FPN upsample
+ self.layer3_outconv = conv1x1(block_dims[2], block_dims[3])
+ self.layer3_outconv2 = nn.Sequential(
+ ConvBlock(block_dims[3], block_dims[2]),
+ conv3x3(block_dims[2], block_dims[2]),
+ )
+ self.layer2_outconv = conv1x1(block_dims[1], block_dims[2])
+ self.layer2_outconv2 = nn.Sequential(
+ ConvBlock(block_dims[2], block_dims[1]),
+ conv3x3(block_dims[1], block_dims[1]),
+ )
+ self.layer1_outconv = conv1x1(block_dims[0], block_dims[1])
+ self.layer1_outconv2 = nn.Sequential(
+ ConvBlock(block_dims[1], block_dims[0]),
+ conv3x3(block_dims[0], block_dims[0]),
+ )
+
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+ elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
+ nn.init.constant_(m.weight, 1)
+ nn.init.constant_(m.bias, 0)
+
+ def _make_layer(self, block, dim, stride=1):
+ layer1 = block(self.in_planes, dim, stride=stride)
+ layer2 = block(dim, dim, stride=1)
+ layers = (layer1, layer2)
+
+ self.in_planes = dim
+ return nn.Sequential(*layers)
+
+ def forward(self, x):
+ # ResNet Backbone
+ x0 = self.relu(self.bn1(self.conv1(x)))
+ x1 = self.layer1(x0) # 1/2
+ x2 = self.layer2(x1) # 1/4
+ x3 = self.layer3(x2) # 1/8
+ x4 = self.layer4(x3) # 1/16
+
+ # FPN
+ x4_out_2x = F.interpolate(x4, scale_factor=2., mode='bilinear', align_corners=True)
+ x3_out = self.layer3_outconv(x3)
+ x3_out = self.layer3_outconv2(x3_out+x4_out_2x)
+
+ x3_out_2x = F.interpolate(x3_out, scale_factor=2., mode='bilinear', align_corners=True)
+ x2_out = self.layer2_outconv(x2)
+ x2_out = self.layer2_outconv2(x2_out+x3_out_2x)
+
+ x2_out_2x = F.interpolate(x2_out, scale_factor=2., mode='bilinear', align_corners=True)
+ x1_out = self.layer1_outconv(x1)
+ x1_out = self.layer1_outconv2(x1_out+x2_out_2x)
+
+ return [x3_out, x1_out]
diff --git a/imcui/third_party/TopicFM/src/models/modules/__init__.py b/imcui/third_party/TopicFM/src/models/modules/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..59cf36da37104dcf080e1b2c119c8123fa8d147f
--- /dev/null
+++ b/imcui/third_party/TopicFM/src/models/modules/__init__.py
@@ -0,0 +1,2 @@
+from .transformer import LocalFeatureTransformer, TopicFormer
+from .fine_preprocess import FinePreprocess
diff --git a/imcui/third_party/TopicFM/src/models/modules/fine_preprocess.py b/imcui/third_party/TopicFM/src/models/modules/fine_preprocess.py
new file mode 100644
index 0000000000000000000000000000000000000000..4c8d264c1895be8f4e124fc3982d4e0d3b876af3
--- /dev/null
+++ b/imcui/third_party/TopicFM/src/models/modules/fine_preprocess.py
@@ -0,0 +1,59 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops.einops import rearrange, repeat
+
+
+class FinePreprocess(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+
+ self.config = config
+ self.cat_c_feat = config['fine_concat_coarse_feat']
+ self.W = self.config['fine_window_size']
+
+ d_model_c = self.config['coarse']['d_model']
+ d_model_f = self.config['fine']['d_model']
+ self.d_model_f = d_model_f
+ if self.cat_c_feat:
+ self.down_proj = nn.Linear(d_model_c, d_model_f, bias=True)
+ self.merge_feat = nn.Linear(2*d_model_f, d_model_f, bias=True)
+
+ self._reset_parameters()
+
+ def _reset_parameters(self):
+ for p in self.parameters():
+ if p.dim() > 1:
+ nn.init.kaiming_normal_(p, mode="fan_out", nonlinearity="relu")
+
+ def forward(self, feat_f0, feat_f1, feat_c0, feat_c1, data):
+ W = self.W
+ stride = data['hw0_f'][0] // data['hw0_c'][0]
+
+ data.update({'W': W})
+ if data['b_ids'].shape[0] == 0:
+ feat0 = torch.empty(0, self.W**2, self.d_model_f, device=feat_f0.device)
+ feat1 = torch.empty(0, self.W**2, self.d_model_f, device=feat_f0.device)
+ return feat0, feat1
+
+ # 1. unfold(crop) all local windows
+ feat_f0_unfold = F.unfold(feat_f0, kernel_size=(W, W), stride=stride, padding=W//2)
+ feat_f0_unfold = rearrange(feat_f0_unfold, 'n (c ww) l -> n l ww c', ww=W**2)
+ feat_f1_unfold = F.unfold(feat_f1, kernel_size=(W, W), stride=stride, padding=W//2)
+ feat_f1_unfold = rearrange(feat_f1_unfold, 'n (c ww) l -> n l ww c', ww=W**2)
+
+ # 2. select only the predicted matches
+ feat_f0_unfold = feat_f0_unfold[data['b_ids'], data['i_ids']] # [n, ww, cf]
+ feat_f1_unfold = feat_f1_unfold[data['b_ids'], data['j_ids']]
+
+ # option: use coarse-level feature as context: concat and linear
+ if self.cat_c_feat:
+ feat_c_win = self.down_proj(torch.cat([feat_c0[data['b_ids'], data['i_ids']],
+ feat_c1[data['b_ids'], data['j_ids']]], 0)) # [2n, c]
+ feat_cf_win = self.merge_feat(torch.cat([
+ torch.cat([feat_f0_unfold, feat_f1_unfold], 0), # [2n, ww, cf]
+ repeat(feat_c_win, 'n c -> n ww c', ww=W**2), # [2n, ww, cf]
+ ], -1))
+ feat_f0_unfold, feat_f1_unfold = torch.chunk(feat_cf_win, 2, dim=0)
+
+ return feat_f0_unfold, feat_f1_unfold
diff --git a/imcui/third_party/TopicFM/src/models/modules/linear_attention.py b/imcui/third_party/TopicFM/src/models/modules/linear_attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..af6cd825033e98b7be15cc694ce28110ef84cc93
--- /dev/null
+++ b/imcui/third_party/TopicFM/src/models/modules/linear_attention.py
@@ -0,0 +1,81 @@
+"""
+Linear Transformer proposed in "Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention"
+Modified from: https://github.com/idiap/fast-transformers/blob/master/fast_transformers/attention/linear_attention.py
+"""
+
+import torch
+from torch.nn import Module, Dropout
+
+
+def elu_feature_map(x):
+ return torch.nn.functional.elu(x) + 1
+
+
+class LinearAttention(Module):
+ def __init__(self, eps=1e-6):
+ super().__init__()
+ self.feature_map = elu_feature_map
+ self.eps = eps
+
+ def forward(self, queries, keys, values, q_mask=None, kv_mask=None):
+ """ Multi-Head linear attention proposed in "Transformers are RNNs"
+ Args:
+ queries: [N, L, H, D]
+ keys: [N, S, H, D]
+ values: [N, S, H, D]
+ q_mask: [N, L]
+ kv_mask: [N, S]
+ Returns:
+ queried_values: (N, L, H, D)
+ """
+ Q = self.feature_map(queries)
+ K = self.feature_map(keys)
+
+ # set padded position to zero
+ if q_mask is not None:
+ Q = Q * q_mask[:, :, None, None]
+ if kv_mask is not None:
+ K = K * kv_mask[:, :, None, None]
+ values = values * kv_mask[:, :, None, None]
+
+ v_length = values.size(1)
+ values = values / v_length # prevent fp16 overflow
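+ # associativity trick of linear attention: aggregate K^T V once per head, then apply it
+ # to every query, so the cost is linear in sequence length instead of quadratic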
+ KV = torch.einsum("nshd,nshv->nhdv", K, values) # (S,D)' @ S,V
+ Z = 1 / (torch.einsum("nlhd,nhd->nlh", Q, K.sum(dim=1)) + self.eps)
+ queried_values = torch.einsum("nlhd,nhdv,nlh->nlhv", Q, KV, Z) * v_length
+
+ return queried_values.contiguous()
+
+
+class FullAttention(Module):
+ def __init__(self, use_dropout=False, attention_dropout=0.1):
+ super().__init__()
+ self.use_dropout = use_dropout
+ self.dropout = Dropout(attention_dropout)
+
+ def forward(self, queries, keys, values, q_mask=None, kv_mask=None):
+ """ Multi-head scaled dot-product attention, a.k.a full attention.
+ Args:
+ queries: [N, L, H, D]
+ keys: [N, S, H, D]
+ values: [N, S, H, D]
+ q_mask: [N, L]
+ kv_mask: [N, S]
+ Returns:
+ queried_values: (N, L, H, D)
+ """
+
+ # Compute the unnormalized attention and apply the masks
+ QK = torch.einsum("nlhd,nshd->nlsh", queries, keys)
+ if kv_mask is not None:
+ QK.masked_fill_(~(q_mask[:, :, None, None] * kv_mask[:, None, :, None]).bool(), -1e9)
+
+ # Compute the attention and the weighted average
+ softmax_temp = 1. / queries.size(3)**.5 # sqrt(D)
+ A = torch.softmax(softmax_temp * QK, dim=2)
+ if self.use_dropout:
+ A = self.dropout(A)
+
+ queried_values = torch.einsum("nlsh,nshd->nlhd", A, values)
+
+ return queried_values.contiguous()
diff --git a/imcui/third_party/TopicFM/src/models/modules/transformer.py b/imcui/third_party/TopicFM/src/models/modules/transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..27ff8f6554844b1e14a7094fcbad40876f766db8
--- /dev/null
+++ b/imcui/third_party/TopicFM/src/models/modules/transformer.py
@@ -0,0 +1,232 @@
+from loguru import logger
+import copy
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .linear_attention import LinearAttention, FullAttention
+
+
+class LoFTREncoderLayer(nn.Module):
+ def __init__(self,
+ d_model,
+ nhead,
+ attention='linear'):
+ super(LoFTREncoderLayer, self).__init__()
+
+ self.dim = d_model // nhead
+ self.nhead = nhead
+
+ # multi-head attention
+ self.q_proj = nn.Linear(d_model, d_model, bias=False)
+ self.k_proj = nn.Linear(d_model, d_model, bias=False)
+ self.v_proj = nn.Linear(d_model, d_model, bias=False)
+ self.attention = LinearAttention() if attention == 'linear' else FullAttention()
+ self.merge = nn.Linear(d_model, d_model, bias=False)
+
+ # feed-forward network
+ self.mlp = nn.Sequential(
+ nn.Linear(d_model*2, d_model*2, bias=False),
+ nn.GELU(),
+ nn.Linear(d_model*2, d_model, bias=False),
+ )
+
+ # norm and dropout
+ self.norm1 = nn.LayerNorm(d_model)
+ self.norm2 = nn.LayerNorm(d_model)
+
+ def forward(self, x, source, x_mask=None, source_mask=None):
+ """
+ Args:
+ x (torch.Tensor): [N, L, C]
+ source (torch.Tensor): [N, S, C]
+ x_mask (torch.Tensor): [N, L] (optional)
+ source_mask (torch.Tensor): [N, S] (optional)
+ """
+ bs = x.shape[0]
+ query, key, value = x, source, source
+
+ # multi-head attention
+ query = self.q_proj(query).view(bs, -1, self.nhead, self.dim) # [N, L, (H, D)]
+ key = self.k_proj(key).view(bs, -1, self.nhead, self.dim) # [N, S, (H, D)]
+ value = self.v_proj(value).view(bs, -1, self.nhead, self.dim)
+ message = self.attention(query, key, value, q_mask=x_mask, kv_mask=source_mask) # [N, L, (H, D)]
+ message = self.merge(message.view(bs, -1, self.nhead*self.dim)) # [N, L, C]
+ message = self.norm1(message)
+
+ # feed-forward network
+ message = self.mlp(torch.cat([x, message], dim=2))
+ message = self.norm2(message)
+
+ return x + message
+
+
+class TopicFormer(nn.Module):
+ """A Local Feature Transformer (LoFTR) module."""
+
+ def __init__(self, config):
+ super(TopicFormer, self).__init__()
+
+ self.config = config
+ self.d_model = config['d_model']
+ self.nhead = config['nhead']
+ self.layer_names = config['layer_names']
+ encoder_layer = LoFTREncoderLayer(config['d_model'], config['nhead'], config['attention'])
+ self.layers = nn.ModuleList([copy.deepcopy(encoder_layer) for _ in range(len(self.layer_names))])
+
+ self.topic_transformers = nn.ModuleList([copy.deepcopy(encoder_layer) for _ in range(2 * config['n_topic_transformers'])]) if config['n_samples'] > 0 else None
+ self.n_iter_topic_transformer = config['n_topic_transformers']
+
+ self.seed_tokens = nn.Parameter(torch.randn(config['n_topics'], config['d_model']))
+ self.register_parameter('seed_tokens', self.seed_tokens)
+ self.n_samples = config['n_samples']
+
+ self._reset_parameters()
+
+ def _reset_parameters(self):
+ for p in self.parameters():
+ if p.dim() > 1:
+ nn.init.xavier_uniform_(p)
+
+ def sample_topic(self, prob_topics, topics, L):
+ """
+ Args:
+ topics (torch.Tensor): [N, L+S, K]
+ """
+ prob_topics0, prob_topics1 = prob_topics[:, :L], prob_topics[:, L:]
+ topics0, topics1 = topics[:, :L], topics[:, L:]
+
+ theta0 = F.normalize(prob_topics0.sum(dim=1), p=1, dim=-1) # [N, K]
+ theta1 = F.normalize(prob_topics1.sum(dim=1), p=1, dim=-1)
+ theta = F.normalize(theta0 * theta1, p=1, dim=-1)
+ if self.n_samples == 0:
+ return None
+ if self.training:
+ sampled_inds = torch.multinomial(theta, self.n_samples)
+ sampled_values = torch.gather(theta, dim=-1, index=sampled_inds)
+ else:
+ sampled_values, sampled_inds = torch.topk(theta, self.n_samples, dim=-1)
+ sampled_topics0 = torch.gather(topics0, dim=-1, index=sampled_inds.unsqueeze(1).repeat(1, topics0.shape[1], 1))
+ sampled_topics1 = torch.gather(topics1, dim=-1, index=sampled_inds.unsqueeze(1).repeat(1, topics1.shape[1], 1))
+ return sampled_topics0, sampled_topics1
+
+ def reduce_feat(self, feat, topick, N, C):
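+ # Gather the tokens assigned to one topic (a variable count per batch element) into a
+ # zero-padded [N, max_len, C] tensor plus a validity mask; selected_ids allows scattering
+ # the refined features back to their original positions afterwards.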
+ len_topic = topick.sum(dim=-1).int()
+ max_len = len_topic.max().item()
+ selected_ids = topick.bool()
+ resized_feat = torch.zeros((N, max_len, C), dtype=torch.float32, device=feat.device)
+ new_mask = torch.zeros_like(resized_feat[..., 0]).bool()
+ for i in range(N):
+ new_mask[i, :len_topic[i]] = True
+ resized_feat[new_mask, :] = feat[selected_ids, :]
+ return resized_feat, new_mask, selected_ids
+
+ def forward(self, feat0, feat1, mask0=None, mask1=None):
+ """
+ Args:
+ feat0 (torch.Tensor): [N, L, C]
+ feat1 (torch.Tensor): [N, S, C]
+ mask0 (torch.Tensor): [N, L] (optional)
+ mask1 (torch.Tensor): [N, S] (optional)
+ """
+
+ assert self.d_model == feat0.shape[2], "the feature number of src and transformer must be equal"
+ N, L, S, C, K = feat0.shape[0], feat0.shape[1], feat1.shape[1], feat0.shape[2], self.config['n_topics']
+
+ seeds = self.seed_tokens.unsqueeze(0).repeat(N, 1, 1)
+
+ feat = torch.cat((feat0, feat1), dim=1)
+ if mask0 is not None:
+ mask = torch.cat((mask0, mask1), dim=-1)
+ else:
+ mask = None
+
+ for layer, name in zip(self.layers, self.layer_names):
+ if name == 'seed':
+ # seeds = layer(seeds, feat0, None, mask0)
+ # seeds = layer(seeds, feat1, None, mask1)
+ seeds = layer(seeds, feat, None, mask)
+ elif name == 'feat':
+ feat0 = layer(feat0, seeds, mask0, None)
+ feat1 = layer(feat1, seeds, mask1, None)
+
+ dmatrix = torch.einsum("nmd,nkd->nmk", feat, seeds)
+ prob_topics = F.softmax(dmatrix, dim=-1)
+
+ feat_topics = torch.zeros_like(dmatrix).scatter_(-1, torch.argmax(dmatrix, dim=-1, keepdim=True), 1.0)
+
+ if mask is not None:
+ feat_topics = feat_topics * mask.unsqueeze(-1)
+ prob_topics = prob_topics * mask.unsqueeze(-1)
+
+ if (feat_topics.detach().sum(dim=1).sum(dim=0) > 100).sum() <= 3:
+ logger.warning("topic distribution is highly sparse!")
+ sampled_topics = self.sample_topic(prob_topics.detach(), feat_topics, L)
+ if sampled_topics is not None:
+ updated_feat0, updated_feat1 = torch.zeros_like(feat0), torch.zeros_like(feat1)
+ s_topics0, s_topics1 = sampled_topics
+ for k in range(s_topics0.shape[-1]):
+ topick0, topick1 = s_topics0[..., k], s_topics1[..., k] # [N, L+S]
+ if (topick0.sum() > 0) and (topick1.sum() > 0):
+ new_feat0, new_mask0, selected_ids0 = self.reduce_feat(feat0, topick0, N, C)
+ new_feat1, new_mask1, selected_ids1 = self.reduce_feat(feat1, topick1, N, C)
+ for idt in range(self.n_iter_topic_transformer):
+ new_feat0 = self.topic_transformers[idt*2](new_feat0, new_feat0, new_mask0, new_mask0)
+ new_feat1 = self.topic_transformers[idt*2](new_feat1, new_feat1, new_mask1, new_mask1)
+ new_feat0 = self.topic_transformers[idt*2+1](new_feat0, new_feat1, new_mask0, new_mask1)
+ new_feat1 = self.topic_transformers[idt*2+1](new_feat1, new_feat0, new_mask1, new_mask0)
+ updated_feat0[selected_ids0, :] = new_feat0[new_mask0, :]
+ updated_feat1[selected_ids1, :] = new_feat1[new_mask1, :]
+
+ feat0 = (1 - s_topics0.sum(dim=-1, keepdim=True)) * feat0 + updated_feat0
+ feat1 = (1 - s_topics1.sum(dim=-1, keepdim=True)) * feat1 + updated_feat1
+
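+ # dual-softmax matching scores over the (topic-refined) coarse features; pairs that do
+ # not share the same hard topic (training only) and padded positions are suppressed to
+ # -1e9 before the two softmaxes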
+ conf_matrix = torch.einsum("nlc,nsc->nls", feat0, feat1) / C**.5 #(C * temperature)
+ if self.training:
+ topic_matrix = torch.einsum("nlk,nsk->nls", prob_topics[:, :L], prob_topics[:, L:])
+ outlier_mask = torch.einsum("nlk,nsk->nls", feat_topics[:, :L], feat_topics[:, L:])
+ else:
+ topic_matrix = {"img0": feat_topics[:, :L], "img1": feat_topics[:, L:]}
+ outlier_mask = torch.ones_like(conf_matrix)
+ if mask0 is not None:
+ outlier_mask = (outlier_mask * mask0[..., None] * mask1[:, None]) #.bool()
+ conf_matrix.masked_fill_(~outlier_mask.bool(), -1e9)
+ conf_matrix = F.softmax(conf_matrix, 1) * F.softmax(conf_matrix, 2) # * topic_matrix
+
+ return feat0, feat1, conf_matrix, topic_matrix
+
+
+class LocalFeatureTransformer(nn.Module):
+ """A Local Feature Transformer (LoFTR) module."""
+
+ def __init__(self, config):
+ super(LocalFeatureTransformer, self).__init__()
+
+ self.config = config
+ self.d_model = config['d_model']
+ self.nhead = config['nhead']
+ self.layer_names = config['layer_names']
+ encoder_layer = LoFTREncoderLayer(config['d_model'], config['nhead'], config['attention'])
+ self.layers = nn.ModuleList([copy.deepcopy(encoder_layer) for _ in range(2)]) #len(self.layer_names))])
+ self._reset_parameters()
+
+ def _reset_parameters(self):
+ for p in self.parameters():
+ if p.dim() > 1:
+ nn.init.xavier_uniform_(p)
+
+ def forward(self, feat0, feat1, mask0=None, mask1=None):
+ """
+ Args:
+ feat0 (torch.Tensor): [N, L, C]
+ feat1 (torch.Tensor): [N, S, C]
+ mask0 (torch.Tensor): [N, L] (optional)
+ mask1 (torch.Tensor): [N, S] (optional)
+ """
+
+ assert self.d_model == feat0.shape[2], "the feature number of src and transformer must be equal"
+
+ feat0 = self.layers[0](feat0, feat1, mask0, mask1)
+ feat1 = self.layers[1](feat1, feat0, mask1, mask0)
+
+ return feat0, feat1
diff --git a/imcui/third_party/TopicFM/src/models/topic_fm.py b/imcui/third_party/TopicFM/src/models/topic_fm.py
new file mode 100644
index 0000000000000000000000000000000000000000..95cd22f9b66d08760382fe4cd22c4df918cc9f68
--- /dev/null
+++ b/imcui/third_party/TopicFM/src/models/topic_fm.py
@@ -0,0 +1,79 @@
+import torch
+import torch.nn as nn
+from einops.einops import rearrange
+
+from .backbone import build_backbone
+from .modules import LocalFeatureTransformer, FinePreprocess, TopicFormer
+from .utils.coarse_matching import CoarseMatching
+from .utils.fine_matching import FineMatching
+
+
+class TopicFM(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ # Misc
+ self.config = config
+
+ # Modules
+ self.backbone = build_backbone(config)
+
+ self.loftr_coarse = TopicFormer(config['coarse'])
+ self.coarse_matching = CoarseMatching(config['match_coarse'])
+ self.fine_preprocess = FinePreprocess(config)
+ self.loftr_fine = LocalFeatureTransformer(config["fine"])
+ self.fine_matching = FineMatching()
+
+ def forward(self, data):
+ """
+ Update:
+ data (dict): {
+ 'image0': (torch.Tensor): (N, 1, H, W)
+ 'image1': (torch.Tensor): (N, 1, H, W)
+ 'mask0'(optional) : (torch.Tensor): (N, H, W) '0' indicates a padded position
+ 'mask1'(optional) : (torch.Tensor): (N, H, W)
+ }
+ """
+ # 1. Local Feature CNN
+ data.update({
+ 'bs': data['image0'].size(0),
+ 'hw0_i': data['image0'].shape[2:], 'hw1_i': data['image1'].shape[2:]
+ })
+
+ if data['hw0_i'] == data['hw1_i']: # faster & better BN convergence
+ feats_c, feats_f = self.backbone(torch.cat([data['image0'], data['image1']], dim=0))
+ (feat_c0, feat_c1), (feat_f0, feat_f1) = feats_c.split(data['bs']), feats_f.split(data['bs'])
+ else: # handle different input shapes
+ (feat_c0, feat_f0), (feat_c1, feat_f1) = self.backbone(data['image0']), self.backbone(data['image1'])
+
+ data.update({
+ 'hw0_c': feat_c0.shape[2:], 'hw1_c': feat_c1.shape[2:],
+ 'hw0_f': feat_f0.shape[2:], 'hw1_f': feat_f1.shape[2:]
+ })
+
+ # 2. coarse-level loftr module
+ feat_c0 = rearrange(feat_c0, 'n c h w -> n (h w) c')
+ feat_c1 = rearrange(feat_c1, 'n c h w -> n (h w) c')
+
+ mask_c0 = mask_c1 = None # mask is useful in training
+ if 'mask0' in data:
+ mask_c0, mask_c1 = data['mask0'].flatten(-2), data['mask1'].flatten(-2)
+
+ feat_c0, feat_c1, conf_matrix, topic_matrix = self.loftr_coarse(feat_c0, feat_c1, mask_c0, mask_c1)
+ data.update({"conf_matrix": conf_matrix, "topic_matrix": topic_matrix}) ######
+
+ # 3. match coarse-level
+ self.coarse_matching(data)
+
+ # 4. fine-level refinement
+ feat_f0_unfold, feat_f1_unfold = self.fine_preprocess(feat_f0, feat_f1, feat_c0.detach(), feat_c1.detach(), data)
+ if feat_f0_unfold.size(0) != 0: # at least one coarse level predicted
+ feat_f0_unfold, feat_f1_unfold = self.loftr_fine(feat_f0_unfold, feat_f1_unfold)
+
+ # 5. match fine-level
+ self.fine_matching(feat_f0_unfold, feat_f1_unfold, data)
+
+ def load_state_dict(self, state_dict, *args, **kwargs):
+ for k in list(state_dict.keys()):
+ if k.startswith('matcher.'):
+ state_dict[k.replace('matcher.', '', 1)] = state_dict.pop(k)
+ return super().load_state_dict(state_dict, *args, **kwargs)
diff --git a/imcui/third_party/TopicFM/src/models/utils/coarse_matching.py b/imcui/third_party/TopicFM/src/models/utils/coarse_matching.py
new file mode 100644
index 0000000000000000000000000000000000000000..75adbb5cc465220e759a044f96f86c08da2d7a50
--- /dev/null
+++ b/imcui/third_party/TopicFM/src/models/utils/coarse_matching.py
@@ -0,0 +1,217 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops.einops import rearrange
+
+INF = 1e9
+
+def mask_border(m, b: int, v):
+ """ Mask borders with value
+ Args:
+ m (torch.Tensor): [N, H0, W0, H1, W1]
+ b (int)
+ v (m.dtype)
+ """
+ if b <= 0:
+ return
+
+ m[:, :b] = v
+ m[:, :, :b] = v
+ m[:, :, :, :b] = v
+ m[:, :, :, :, :b] = v
+ m[:, -b:] = v
+ m[:, :, -b:] = v
+ m[:, :, :, -b:] = v
+ m[:, :, :, :, -b:] = v
+
+
+def mask_border_with_padding(m, bd, v, p_m0, p_m1):
+ if bd <= 0:
+ return
+
+ m[:, :bd] = v
+ m[:, :, :bd] = v
+ m[:, :, :, :bd] = v
+ m[:, :, :, :, :bd] = v
+
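+ # recover each sample's valid (unpadded) height/width from the padding masks, then also
+ # mask out the last `bd` rows/columns inside that valid region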
+ h0s, w0s = p_m0.sum(1).max(-1)[0].int(), p_m0.sum(-1).max(-1)[0].int()
+ h1s, w1s = p_m1.sum(1).max(-1)[0].int(), p_m1.sum(-1).max(-1)[0].int()
+ for b_idx, (h0, w0, h1, w1) in enumerate(zip(h0s, w0s, h1s, w1s)):
+ m[b_idx, h0 - bd:] = v
+ m[b_idx, :, w0 - bd:] = v
+ m[b_idx, :, :, h1 - bd:] = v
+ m[b_idx, :, :, :, w1 - bd:] = v
+
+
+def compute_max_candidates(p_m0, p_m1):
+ """Compute the max candidates of all pairs within a batch
+
+ Args:
+ p_m0, p_m1 (torch.Tensor): padded masks
+ """
+ h0s, w0s = p_m0.sum(1).max(-1)[0], p_m0.sum(-1).max(-1)[0]
+ h1s, w1s = p_m1.sum(1).max(-1)[0], p_m1.sum(-1).max(-1)[0]
+ max_cand = torch.sum(
+ torch.min(torch.stack([h0s * w0s, h1s * w1s], -1), -1)[0])
+ return max_cand
+
+
+class CoarseMatching(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ # general config
+ self.thr = config['thr']
+ self.border_rm = config['border_rm']
+ # -- # for training fine-level LoFTR
+ self.train_coarse_percent = config['train_coarse_percent']
+ self.train_pad_num_gt_min = config['train_pad_num_gt_min']
+
+ # we provide 2 options for differentiable matching
+ self.match_type = config['match_type']
+ if self.match_type == 'dual_softmax':
+ self.temperature = config['dsmax_temperature']
+ elif self.match_type == 'sinkhorn':
+ try:
+ from .superglue import log_optimal_transport
+ except ImportError:
+ raise ImportError("download superglue.py first!")
+ self.log_optimal_transport = log_optimal_transport
+ self.bin_score = nn.Parameter(
+ torch.tensor(config['skh_init_bin_score'], requires_grad=True))
+ self.skh_iters = config['skh_iters']
+ self.skh_prefilter = config['skh_prefilter']
+ else:
+ raise NotImplementedError()
+
+ def forward(self, data):
+ """
+ Args:
+ data (dict)
+ Update:
+ data (dict): {
+ 'b_ids' (torch.Tensor): [M'],
+ 'i_ids' (torch.Tensor): [M'],
+ 'j_ids' (torch.Tensor): [M'],
+ 'gt_mask' (torch.Tensor): [M'],
+ 'mkpts0_c' (torch.Tensor): [M, 2],
+ 'mkpts1_c' (torch.Tensor): [M, 2],
+ 'mconf' (torch.Tensor): [M]}
+ NOTE: M' != M during training.
+ """
+ conf_matrix = data['conf_matrix']
+ # predict coarse matches from conf_matrix
+ data.update(**self.get_coarse_match(conf_matrix, data))
+
+ @torch.no_grad()
+ def get_coarse_match(self, conf_matrix, data):
+ """
+ Args:
+ conf_matrix (torch.Tensor): [N, L, S]
+ data (dict): with keys ['hw0_i', 'hw1_i', 'hw0_c', 'hw1_c']
+ Returns:
+ coarse_matches (dict): {
+ 'b_ids' (torch.Tensor): [M'],
+ 'i_ids' (torch.Tensor): [M'],
+ 'j_ids' (torch.Tensor): [M'],
+ 'gt_mask' (torch.Tensor): [M'],
+ 'm_bids' (torch.Tensor): [M],
+ 'mkpts0_c' (torch.Tensor): [M, 2],
+ 'mkpts1_c' (torch.Tensor): [M, 2],
+ 'mconf' (torch.Tensor): [M]}
+ """
+ axes_lengths = {
+ 'h0c': data['hw0_c'][0],
+ 'w0c': data['hw0_c'][1],
+ 'h1c': data['hw1_c'][0],
+ 'w1c': data['hw1_c'][1]
+ }
+ _device = conf_matrix.device
+ # 1. confidence thresholding
+ mask = conf_matrix > self.thr
+ mask = rearrange(mask, 'b (h0c w0c) (h1c w1c) -> b h0c w0c h1c w1c',
+ **axes_lengths)
+ if 'mask0' not in data:
+ mask_border(mask, self.border_rm, False)
+ else:
+ mask_border_with_padding(mask, self.border_rm, False,
+ data['mask0'], data['mask1'])
+ mask = rearrange(mask, 'b h0c w0c h1c w1c -> b (h0c w0c) (h1c w1c)',
+ **axes_lengths)
+
+ # 2. mutual nearest
+ mask = mask \
+ * (conf_matrix == conf_matrix.max(dim=2, keepdim=True)[0]) \
+ * (conf_matrix == conf_matrix.max(dim=1, keepdim=True)[0])
+
+ # 3. find all valid coarse matches
+ # this only works when at most one `True` in each row
+ mask_v, all_j_ids = mask.max(dim=2)
+ b_ids, i_ids = torch.where(mask_v)
+ j_ids = all_j_ids[b_ids, i_ids]
+ mconf = conf_matrix[b_ids, i_ids, j_ids]
+
+ # 4. Random sampling of training samples for fine-level LoFTR
+ # (optional) pad samples with gt coarse-level matches
+ if self.training:
+ # NOTE:
+ # The sampling is performed across all pairs in a batch without manual balancing,
+ # so the number of fine-level samples grows with batch_size.
+ if 'mask0' not in data:
+ num_candidates_max = mask.size(0) * max(
+ mask.size(1), mask.size(2))
+ else:
+ num_candidates_max = compute_max_candidates(
+ data['mask0'], data['mask1'])
+ num_matches_train = int(num_candidates_max *
+ self.train_coarse_percent)
+ num_matches_pred = len(b_ids)
+ assert self.train_pad_num_gt_min < num_matches_train, "min-num-gt-pad should be less than num-train-matches"
+
+ # pred_indices is to select from prediction
+ if num_matches_pred <= num_matches_train - self.train_pad_num_gt_min:
+ pred_indices = torch.arange(num_matches_pred, device=_device)
+ else:
+ pred_indices = torch.randint(
+ num_matches_pred,
+ (num_matches_train - self.train_pad_num_gt_min, ),
+ device=_device)
+
+ # gt_pad_indices is to select from gt padding. e.g. max(3787-4800, 200)
+ gt_pad_indices = torch.randint(
+ len(data['spv_b_ids']),
+ (max(num_matches_train - num_matches_pred,
+ self.train_pad_num_gt_min), ),
+ device=_device)
+ mconf_gt = torch.zeros(len(data['spv_b_ids']), device=_device) # set conf of gt paddings to all zero
+
+ b_ids, i_ids, j_ids, mconf = map(
+ lambda x, y: torch.cat([x[pred_indices], y[gt_pad_indices]],
+ dim=0),
+ *zip([b_ids, data['spv_b_ids']], [i_ids, data['spv_i_ids']],
+ [j_ids, data['spv_j_ids']], [mconf, mconf_gt]))
+
+ # These matches select the patches that are fed into the fine-level network
+ coarse_matches = {'b_ids': b_ids, 'i_ids': i_ids, 'j_ids': j_ids}
+
+ # 5. Update with matches in the original image resolution
+ scale = data['hw0_i'][0] / data['hw0_c'][0]
+ scale0 = scale * data['scale0'][b_ids] if 'scale0' in data else scale
+ scale1 = scale * data['scale1'][b_ids] if 'scale1' in data else scale
+ mkpts0_c = torch.stack(
+ [i_ids % data['hw0_c'][1], i_ids // data['hw0_c'][1]],
+ dim=1) * scale0
+ mkpts1_c = torch.stack(
+ [j_ids % data['hw1_c'][1], j_ids // data['hw1_c'][1]],
+ dim=1) * scale1
+
+ # These matches are the current prediction (for visualization)
+ coarse_matches.update({
+ 'gt_mask': mconf == 0,
+ 'm_bids': b_ids[mconf != 0], # mconf == 0 => gt matches
+ 'mkpts0_c': mkpts0_c[mconf != 0],
+ 'mkpts1_c': mkpts1_c[mconf != 0],
+ 'mconf': mconf[mconf != 0]
+ })
+
+ return coarse_matches
diff --git a/imcui/third_party/TopicFM/src/models/utils/fine_matching.py b/imcui/third_party/TopicFM/src/models/utils/fine_matching.py
new file mode 100644
index 0000000000000000000000000000000000000000..018f2fe475600b319998c263a97237ce135c3aaf
--- /dev/null
+++ b/imcui/third_party/TopicFM/src/models/utils/fine_matching.py
@@ -0,0 +1,80 @@
+import math
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from kornia.geometry.subpix import dsnt
+from kornia.utils.grid import create_meshgrid
+
+
+class FineMatching(nn.Module):
+ """FineMatching with s2d paradigm"""
+
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, feat_f0, feat_f1, data):
+ """
+ Args:
+ feat0 (torch.Tensor): [M, WW, C]
+ feat1 (torch.Tensor): [M, WW, C]
+ data (dict)
+ Update:
+ data (dict):{
+ 'expec_f' (torch.Tensor): [M, 3],
+ 'mkpts0_f' (torch.Tensor): [M, 2],
+ 'mkpts1_f' (torch.Tensor): [M, 2]}
+ """
+ M, WW, C = feat_f0.shape
+ W = int(math.sqrt(WW))
+ scale = data['hw0_i'][0] / data['hw0_f'][0]
+ self.M, self.W, self.WW, self.C, self.scale = M, W, WW, C, scale
+
+ # corner case: if no coarse matches found
+ if M == 0:
+ assert not self.training, "M is always > 0 during training, see coarse_matching.py"
+ # logger.warning('No matches found in coarse-level.')
+ data.update({
+ 'expec_f': torch.empty(0, 3, device=feat_f0.device),
+ 'mkpts0_f': data['mkpts0_c'],
+ 'mkpts1_f': data['mkpts1_c'],
+ })
+ return
+
+ feat_f0_picked = feat_f0[:, WW//2, :]
+
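+ # correlate the centre descriptor of each window in image0 with every position of the
+ # corresponding window in image1; a soft-argmax over the resulting heatmap gives the
+ # sub-pixel refinement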
+ sim_matrix = torch.einsum('mc,mrc->mr', feat_f0_picked, feat_f1)
+ softmax_temp = 1. / C**.5
+ heatmap = torch.softmax(softmax_temp * sim_matrix, dim=1)
+ feat_f1_picked = (feat_f1 * heatmap.unsqueeze(-1)).sum(dim=1) # [M, C]
+ heatmap = heatmap.view(-1, W, W)
+
+ # compute coordinates from heatmap
+ coords1_normalized = dsnt.spatial_expectation2d(heatmap[None], True)[0] # [M, 2]
+ grid_normalized = create_meshgrid(W, W, True, heatmap.device).reshape(1, -1, 2) # [1, WW, 2]
+
+ # compute the std over the heatmap (spatial uncertainty of the soft-argmax)
+ var = torch.sum(grid_normalized**2 * heatmap.view(-1, WW, 1), dim=1) - coords1_normalized**2 # [M, 2]
+ std = torch.sum(torch.sqrt(torch.clamp(var, min=1e-10)), -1) # [M] clamp needed for numerical stability
+
+ # for fine-level supervision
+ data.update({'expec_f': torch.cat([coords1_normalized, std.unsqueeze(1)], -1),
+ 'descriptors0': feat_f0_picked.detach(), 'descriptors1': feat_f1_picked.detach()})
+
+ # compute absolute kpt coords
+ self.get_fine_match(coords1_normalized, data)
+
+ @torch.no_grad()
+ def get_fine_match(self, coords1_normed, data):
+ W, WW, C, scale = self.W, self.WW, self.C, self.scale
+
+ # mkpts0_f and mkpts1_f
+ # scale0 = scale * data['scale0'][data['b_ids']] if 'scale0' in data else scale
+ mkpts0_f = data['mkpts0_c'] # + (coords0_normed * (W // 2) * scale0 )[:len(data['mconf'])]
+ scale1 = scale * data['scale1'][data['b_ids']] if 'scale1' in data else scale
+ mkpts1_f = data['mkpts1_c'] + (coords1_normed * (W // 2) * scale1)[:len(data['mconf'])]
+
+ data.update({
+ "mkpts0_f": mkpts0_f,
+ "mkpts1_f": mkpts1_f
+ })
diff --git a/imcui/third_party/TopicFM/src/models/utils/geometry.py b/imcui/third_party/TopicFM/src/models/utils/geometry.py
new file mode 100644
index 0000000000000000000000000000000000000000..f95cdb65b48324c4f4ceb20231b1bed992b41116
--- /dev/null
+++ b/imcui/third_party/TopicFM/src/models/utils/geometry.py
@@ -0,0 +1,54 @@
+import torch
+
+
+@torch.no_grad()
+def warp_kpts(kpts0, depth0, depth1, T_0to1, K0, K1):
+ """ Warp kpts0 from I0 to I1 with depth, K and Rt
+ Also check covisibility and depth consistency.
+ Depth is consistent if relative error < 0.2 (hard-coded).
+
+ Args:
+ kpts0 (torch.Tensor): [N, L, 2] - ,
+ depth0 (torch.Tensor): [N, H, W],
+ depth1 (torch.Tensor): [N, H, W],
+ T_0to1 (torch.Tensor): [N, 3, 4],
+ K0 (torch.Tensor): [N, 3, 3],
+ K1 (torch.Tensor): [N, 3, 3],
+ Returns:
+ calculable_mask (torch.Tensor): [N, L]
+ warped_keypoints0 (torch.Tensor): [N, L, 2]
+ """
+ kpts0_long = kpts0.round().long()
+
+ # Sample depth, get calculable_mask on depth != 0
+ kpts0_depth = torch.stack(
+ [depth0[i, kpts0_long[i, :, 1], kpts0_long[i, :, 0]] for i in range(kpts0.shape[0])], dim=0
+ ) # (N, L)
+ nonzero_mask = kpts0_depth != 0
+
+ # Unproject
+ kpts0_h = torch.cat([kpts0, torch.ones_like(kpts0[:, :, [0]])], dim=-1) * kpts0_depth[..., None] # (N, L, 3)
+ kpts0_cam = K0.inverse() @ kpts0_h.transpose(2, 1) # (N, 3, L)
+
+ # Rigid Transform
+ w_kpts0_cam = T_0to1[:, :3, :3] @ kpts0_cam + T_0to1[:, :3, [3]] # (N, 3, L)
+ w_kpts0_depth_computed = w_kpts0_cam[:, 2, :]
+
+ # Project
+ w_kpts0_h = (K1 @ w_kpts0_cam).transpose(2, 1) # (N, L, 3)
+ w_kpts0 = w_kpts0_h[:, :, :2] / (w_kpts0_h[:, :, [2]] + 1e-4) # (N, L, 2), +1e-4 to avoid zero depth
+
+ # Covisible Check
+ h, w = depth1.shape[1:3]
+ covisible_mask = (w_kpts0[:, :, 0] > 0) * (w_kpts0[:, :, 0] < w-1) * \
+ (w_kpts0[:, :, 1] > 0) * (w_kpts0[:, :, 1] < h-1)
+ w_kpts0_long = w_kpts0.long()
+ w_kpts0_long[~covisible_mask, :] = 0
+
+ w_kpts0_depth = torch.stack(
+ [depth1[i, w_kpts0_long[i, :, 1], w_kpts0_long[i, :, 0]] for i in range(w_kpts0_long.shape[0])], dim=0
+ ) # (N, L)
+ consistent_mask = ((w_kpts0_depth - w_kpts0_depth_computed) / w_kpts0_depth).abs() < 0.2
+ valid_mask = nonzero_mask * covisible_mask * consistent_mask
+
+ return valid_mask, w_kpts0
diff --git a/imcui/third_party/TopicFM/src/models/utils/supervision.py b/imcui/third_party/TopicFM/src/models/utils/supervision.py
new file mode 100644
index 0000000000000000000000000000000000000000..1f1f0478fdcbe7f8ceffbc4aff4d507cec55bbd2
--- /dev/null
+++ b/imcui/third_party/TopicFM/src/models/utils/supervision.py
@@ -0,0 +1,151 @@
+from math import log
+from loguru import logger
+
+import torch
+from einops import repeat
+from kornia.utils import create_meshgrid
+
+from .geometry import warp_kpts
+
+############## ↓ Coarse-Level supervision ↓ ##############
+
+
+@torch.no_grad()
+def mask_pts_at_padded_regions(grid_pt, mask):
+ """For megadepth dataset, zero-padding exists in images"""
+ mask = repeat(mask, 'n h w -> n (h w) c', c=2)
+ grid_pt[~mask.bool()] = 0
+ return grid_pt
+
+
+@torch.no_grad()
+def spvs_coarse(data, config):
+ """
+ Update:
+ data (dict): {
+ "conf_matrix_gt": [N, hw0, hw1],
+ 'spv_b_ids': [M]
+ 'spv_i_ids': [M]
+ 'spv_j_ids': [M]
+ 'spv_w_pt0_i': [N, hw0, 2], in original image resolution
+ 'spv_pt1_i': [N, hw1, 2], in original image resolution
+ }
+
+ NOTE:
+ - for the scannet dataset, there are 3 kinds of resolution {i, c, f}
+ - for the megadepth dataset, there are 4 kinds of resolution {i, i_resize, c, f}
+ """
+ # 1. misc
+ device = data['image0'].device
+ N, _, H0, W0 = data['image0'].shape
+ _, _, H1, W1 = data['image1'].shape
+ scale = config['MODEL']['RESOLUTION'][0]
+ scale0 = scale * data['scale0'][:, None] if 'scale0' in data else scale
+ scale1 = scale * data['scale1'][:, None] if 'scale1' in data else scale
+ h0, w0, h1, w1 = map(lambda x: x // scale, [H0, W0, H1, W1])
+
+ # 2. warp grids
+ # create kpts in meshgrid and resize them to image resolution
+ grid_pt0_c = create_meshgrid(h0, w0, False, device).reshape(1, h0*w0, 2).repeat(N, 1, 1) # [N, hw, 2]
+ grid_pt0_i = scale0 * grid_pt0_c
+ grid_pt1_c = create_meshgrid(h1, w1, False, device).reshape(1, h1*w1, 2).repeat(N, 1, 1)
+ grid_pt1_i = scale1 * grid_pt1_c
+
+ # mask padded region to (0, 0), so no need to manually mask conf_matrix_gt
+ if 'mask0' in data:
+ grid_pt0_i = mask_pts_at_padded_regions(grid_pt0_i, data['mask0'])
+ grid_pt1_i = mask_pts_at_padded_regions(grid_pt1_i, data['mask1'])
+
+ # warp kpts bi-directionally and resize them to coarse-level resolution
+ # (no depth consistency check, since it leads to worse results experimentally)
+ # (unhandled edge case: points with 0-depth will be warped to the left-up corner)
+ _, w_pt0_i = warp_kpts(grid_pt0_i, data['depth0'], data['depth1'], data['T_0to1'], data['K0'], data['K1'])
+ _, w_pt1_i = warp_kpts(grid_pt1_i, data['depth1'], data['depth0'], data['T_1to0'], data['K1'], data['K0'])
+ w_pt0_c = w_pt0_i / scale1
+ w_pt1_c = w_pt1_i / scale0
+
+ # 3. check if mutual nearest neighbor
+ w_pt0_c_round = w_pt0_c[:, :, :].round().long()
+ nearest_index1 = w_pt0_c_round[..., 0] + w_pt0_c_round[..., 1] * w1
+ w_pt1_c_round = w_pt1_c[:, :, :].round().long()
+ nearest_index0 = w_pt1_c_round[..., 0] + w_pt1_c_round[..., 1] * w0
+
+ # corner case: out of boundary
+ def out_bound_mask(pt, w, h):
+ return (pt[..., 0] < 0) + (pt[..., 0] >= w) + (pt[..., 1] < 0) + (pt[..., 1] >= h)
+ nearest_index1[out_bound_mask(w_pt0_c_round, w1, h1)] = 0
+ nearest_index0[out_bound_mask(w_pt1_c_round, w0, h0)] = 0
+
+ loop_back = torch.stack([nearest_index0[_b][_i] for _b, _i in enumerate(nearest_index1)], dim=0)
+ correct_0to1 = loop_back == torch.arange(h0*w0, device=device)[None].repeat(N, 1)
+ correct_0to1[:, 0] = False # ignore the top-left corner
+
+ # 4. construct a gt conf_matrix
+ conf_matrix_gt = torch.zeros(N, h0*w0, h1*w1, device=device)
+ b_ids, i_ids = torch.where(correct_0to1 != 0)
+ j_ids = nearest_index1[b_ids, i_ids]
+
+ conf_matrix_gt[b_ids, i_ids, j_ids] = 1
+ data.update({'conf_matrix_gt': conf_matrix_gt})
+
+ # 5. save coarse matches(gt) for training fine level
+ if len(b_ids) == 0:
+ logger.warning(f"No groundtruth coarse match found for: {data['pair_names']}")
+ # this won't affect fine-level loss calculation
+ b_ids = torch.tensor([0], device=device)
+ i_ids = torch.tensor([0], device=device)
+ j_ids = torch.tensor([0], device=device)
+
+ data.update({
+ 'spv_b_ids': b_ids,
+ 'spv_i_ids': i_ids,
+ 'spv_j_ids': j_ids
+ })
+
+ # 6. save intermediate results (for fast fine-level computation)
+ data.update({
+ 'spv_w_pt0_i': w_pt0_i,
+ 'spv_pt1_i': grid_pt1_i
+ })
+
+
+def compute_supervision_coarse(data, config):
+ assert len(set(data['dataset_name'])) == 1, "Mixed-dataset training is not supported!"
+ data_source = data['dataset_name'][0]
+ if data_source.lower() in ['scannet', 'megadepth']:
+ spvs_coarse(data, config)
+ else:
+ raise ValueError(f'Unknown data source: {data_source}')
+
+
+############## ↓ Fine-Level supervision ↓ ##############
+
+@torch.no_grad()
+def spvs_fine(data, config):
+ """
+ Update:
+ data (dict):{
+ "expec_f_gt": [M, 2]}
+ """
+ # 1. misc
+ # w_pt0_i, pt1_i = data.pop('spv_w_pt0_i'), data.pop('spv_pt1_i')
+ w_pt0_i, pt1_i = data['spv_w_pt0_i'], data['spv_pt1_i']
+ scale = config['MODEL']['RESOLUTION'][1]
+ radius = config['MODEL']['FINE_WINDOW_SIZE'] // 2
+
+ # 2. get coarse prediction
+ b_ids, i_ids, j_ids = data['b_ids'], data['i_ids'], data['j_ids']
+
+ # 3. compute gt
+ scale = scale * data['scale1'][b_ids] if 'scale1' in data else scale
+ # `expec_f_gt` might exceed the window, i.e. abs(*) > 1, which would be filtered later
+ expec_f_gt = (w_pt0_i[b_ids, i_ids] - pt1_i[b_ids, j_ids]) / scale / radius # [M, 2]
+ data.update({"expec_f_gt": expec_f_gt})
+
+
+def compute_supervision_fine(data, config):
+ data_source = data['dataset_name'][0]
+ if data_source.lower() in ['scannet', 'megadepth']:
+ spvs_fine(data, config)
+ else:
+ raise NotImplementedError
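The fine-level ground truth built in `spvs_fine` is simply the offset between the warped image-0 grid point and the matched image-1 grid point, normalized by the fine resolution and the half window. A toy computation of that formula; the scale of 2 and window size of 5 are illustrative values, not necessarily the configured ones.

import torch

w_pt0_i = torch.tensor([[103., 78.]])  # grid point of image0 warped into image1, in pixels
pt1_i = torch.tensor([[100., 80.]])    # matched coarse grid point of image1, in pixels
scale, radius = 2, 5 // 2              # MODEL.RESOLUTION[1], FINE_WINDOW_SIZE // 2 (toy values)
expec_f_gt = (w_pt0_i - pt1_i) / scale / radius
print(expec_f_gt)                      # tensor([[ 0.7500, -0.5000]]) -> inside the window, |.| <= 1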
diff --git a/imcui/third_party/TopicFM/src/optimizers/__init__.py b/imcui/third_party/TopicFM/src/optimizers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e1db2285352586c250912bdd2c4ae5029620ab5f
--- /dev/null
+++ b/imcui/third_party/TopicFM/src/optimizers/__init__.py
@@ -0,0 +1,42 @@
+import torch
+from torch.optim.lr_scheduler import MultiStepLR, CosineAnnealingLR, ExponentialLR
+
+
+def build_optimizer(model, config):
+ name = config.TRAINER.OPTIMIZER
+ lr = config.TRAINER.TRUE_LR
+
+ if name == "adam":
+ return torch.optim.Adam(model.parameters(), lr=lr, weight_decay=config.TRAINER.ADAM_DECAY)
+ elif name == "adamw":
+ return torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=config.TRAINER.ADAMW_DECAY)
+ else:
+ raise ValueError(f"TRAINER.OPTIMIZER = {name} is not a valid optimizer!")
+
+
+def build_scheduler(config, optimizer):
+ """
+ Returns:
+ scheduler (dict):{
+ 'scheduler': lr_scheduler,
+ 'interval': 'step', # or 'epoch'
+ 'monitor': 'val_f1', (optional)
+ 'frequency': x, (optional)
+ }
+ """
+ scheduler = {'interval': config.TRAINER.SCHEDULER_INTERVAL}
+ name = config.TRAINER.SCHEDULER
+
+ if name == 'MultiStepLR':
+ scheduler.update(
+ {'scheduler': MultiStepLR(optimizer, config.TRAINER.MSLR_MILESTONES, gamma=config.TRAINER.MSLR_GAMMA)})
+ elif name == 'CosineAnnealing':
+ scheduler.update(
+ {'scheduler': CosineAnnealingLR(optimizer, config.TRAINER.COSA_TMAX)})
+ elif name == 'ExponentialLR':
+ scheduler.update(
+ {'scheduler': ExponentialLR(optimizer, config.TRAINER.ELR_GAMMA)})
+ else:
+ raise NotImplementedError()
+
+ return scheduler
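A sketch of how `build_optimizer` and `build_scheduler` consume the config. Only the keys read above are set, the numeric values are illustrative rather than the project defaults, and the import path assumes the TopicFM root is on PYTHONPATH.

import torch
from yacs.config import CfgNode as CN

from src.optimizers import build_optimizer, build_scheduler  # path assumed from this repo layout

cfg = CN()
cfg.TRAINER = CN()
cfg.TRAINER.OPTIMIZER = 'adamw'
cfg.TRAINER.TRUE_LR = 1e-3               # normally CANONICAL_LR scaled by the true batch size (see train.py)
cfg.TRAINER.ADAMW_DECAY = 0.1            # illustrative
cfg.TRAINER.SCHEDULER = 'MultiStepLR'
cfg.TRAINER.SCHEDULER_INTERVAL = 'epoch'
cfg.TRAINER.MSLR_MILESTONES = [3, 6, 9]  # illustrative
cfg.TRAINER.MSLR_GAMMA = 0.5

model = torch.nn.Linear(4, 2)
optimizer = build_optimizer(model, cfg)
scheduler = build_scheduler(cfg, optimizer)
print(type(optimizer).__name__, type(scheduler['scheduler']).__name__, scheduler['interval'])
# AdamW MultiStepLR epoch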
diff --git a/imcui/third_party/TopicFM/src/utils/augment.py b/imcui/third_party/TopicFM/src/utils/augment.py
new file mode 100644
index 0000000000000000000000000000000000000000..d7c5d3e11b6fe083aaeff7555bb7ce3a4bfb755d
--- /dev/null
+++ b/imcui/third_party/TopicFM/src/utils/augment.py
@@ -0,0 +1,55 @@
+import albumentations as A
+
+
+class DarkAug(object):
+ """
+ Extreme dark augmentation aiming at Aachen Day-Night
+ """
+
+ def __init__(self) -> None:
+ self.augmentor = A.Compose([
+ A.RandomBrightnessContrast(p=0.75, brightness_limit=(-0.6, 0.0), contrast_limit=(-0.5, 0.3)),
+ A.Blur(p=0.1, blur_limit=(3, 9)),
+ A.MotionBlur(p=0.2, blur_limit=(3, 25)),
+ A.RandomGamma(p=0.1, gamma_limit=(15, 65)),
+ A.HueSaturationValue(p=0.1, val_shift_limit=(-100, -40))
+ ], p=0.75)
+
+ def __call__(self, x):
+ return self.augmentor(image=x)['image']
+
+
+class MobileAug(object):
+ """
+ Random augmentations aiming at images from mobile/handheld devices.
+ """
+
+ def __init__(self):
+ self.augmentor = A.Compose([
+ A.MotionBlur(p=0.25),
+ A.ColorJitter(p=0.5),
+ A.RandomRain(p=0.1), # random occlusion
+ A.RandomSunFlare(p=0.1),
+ A.JpegCompression(p=0.25),
+ A.ISONoise(p=0.25)
+ ], p=1.0)
+
+ def __call__(self, x):
+ return self.augmentor(image=x)['image']
+
+
+def build_augmentor(method=None, **kwargs):
+ if method is not None:
+ raise NotImplementedError('Augmentation functions are not supported yet!')
+ if method == 'dark':
+ return DarkAug()
+ elif method == 'mobile':
+ return MobileAug()
+ elif method is None:
+ return None
+ else:
+ raise ValueError(f'Invalid augmentation method: {method}')
+
+
+if __name__ == '__main__':
+ augmentor = build_augmentor('FDA')
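A quick sanity check of the augmentors on a dummy image. This assumes the albumentations release the project was written against; for instance, `A.JpegCompression` used by `MobileAug` was renamed to `A.ImageCompression` in newer versions.

import numpy as np

img = np.random.randint(0, 256, (480, 640, 3), dtype=np.uint8)  # dummy RGB image
dark = DarkAug()
out = dark(img)
print(out.shape, out.dtype)  # (480, 640, 3) uint8, darkened with probability 0.75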
diff --git a/imcui/third_party/TopicFM/src/utils/comm.py b/imcui/third_party/TopicFM/src/utils/comm.py
new file mode 100644
index 0000000000000000000000000000000000000000..26ec9517cc47e224430106d8ae9aa99a3fe49167
--- /dev/null
+++ b/imcui/third_party/TopicFM/src/utils/comm.py
@@ -0,0 +1,265 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+"""
+[Copied from detectron2]
+This file contains primitives for multi-gpu communication.
+This is useful when doing distributed training.
+"""
+
+import functools
+import logging
+import numpy as np
+import pickle
+import torch
+import torch.distributed as dist
+
+_LOCAL_PROCESS_GROUP = None
+"""
+A torch process group which only includes processes that are on the same machine as the current process.
+This variable is set when processes are spawned by `launch()` in "engine/launch.py".
+"""
+
+
+def get_world_size() -> int:
+ if not dist.is_available():
+ return 1
+ if not dist.is_initialized():
+ return 1
+ return dist.get_world_size()
+
+
+def get_rank() -> int:
+ if not dist.is_available():
+ return 0
+ if not dist.is_initialized():
+ return 0
+ return dist.get_rank()
+
+
+def get_local_rank() -> int:
+ """
+ Returns:
+ The rank of the current process within the local (per-machine) process group.
+ """
+ if not dist.is_available():
+ return 0
+ if not dist.is_initialized():
+ return 0
+ assert _LOCAL_PROCESS_GROUP is not None
+ return dist.get_rank(group=_LOCAL_PROCESS_GROUP)
+
+
+def get_local_size() -> int:
+ """
+ Returns:
+ The size of the per-machine process group,
+ i.e. the number of processes per machine.
+ """
+ if not dist.is_available():
+ return 1
+ if not dist.is_initialized():
+ return 1
+ return dist.get_world_size(group=_LOCAL_PROCESS_GROUP)
+
+
+def is_main_process() -> bool:
+ return get_rank() == 0
+
+
+def synchronize():
+ """
+ Helper function to synchronize (barrier) among all processes when
+ using distributed training
+ """
+ if not dist.is_available():
+ return
+ if not dist.is_initialized():
+ return
+ world_size = dist.get_world_size()
+ if world_size == 1:
+ return
+ dist.barrier()
+
+
+@functools.lru_cache()
+def _get_global_gloo_group():
+ """
+ Return a process group based on the gloo backend, containing all the ranks.
+ The result is cached.
+ """
+ if dist.get_backend() == "nccl":
+ return dist.new_group(backend="gloo")
+ else:
+ return dist.group.WORLD
+
+
+def _serialize_to_tensor(data, group):
+ backend = dist.get_backend(group)
+ assert backend in ["gloo", "nccl"]
+ device = torch.device("cpu" if backend == "gloo" else "cuda")
+
+ buffer = pickle.dumps(data)
+ if len(buffer) > 1024 ** 3:
+ logger = logging.getLogger(__name__)
+ logger.warning(
+ "Rank {} trying to all-gather {:.2f} GB of data on device {}".format(
+ get_rank(), len(buffer) / (1024 ** 3), device
+ )
+ )
+ storage = torch.ByteStorage.from_buffer(buffer)
+ tensor = torch.ByteTensor(storage).to(device=device)
+ return tensor
+
+
+def _pad_to_largest_tensor(tensor, group):
+ """
+ Returns:
+ list[int]: size of the tensor, on each rank
+ Tensor: padded tensor that has the max size
+ """
+ world_size = dist.get_world_size(group=group)
+ assert (
+ world_size >= 1
+ ), "comm.gather/all_gather must be called from ranks within the given group!"
+ local_size = torch.tensor([tensor.numel()], dtype=torch.int64, device=tensor.device)
+ size_list = [
+ torch.zeros([1], dtype=torch.int64, device=tensor.device) for _ in range(world_size)
+ ]
+ dist.all_gather(size_list, local_size, group=group)
+
+ size_list = [int(size.item()) for size in size_list]
+
+ max_size = max(size_list)
+
+ # we pad the tensor because torch all_gather does not support
+ # gathering tensors of different shapes
+ if local_size != max_size:
+ padding = torch.zeros((max_size - local_size,), dtype=torch.uint8, device=tensor.device)
+ tensor = torch.cat((tensor, padding), dim=0)
+ return size_list, tensor
+
+
+def all_gather(data, group=None):
+ """
+ Run all_gather on arbitrary picklable data (not necessarily tensors).
+
+ Args:
+ data: any picklable object
+ group: a torch process group. By default, will use a group which
+ contains all ranks on gloo backend.
+
+ Returns:
+ list[data]: list of data gathered from each rank
+ """
+ if get_world_size() == 1:
+ return [data]
+ if group is None:
+ group = _get_global_gloo_group()
+ if dist.get_world_size(group) == 1:
+ return [data]
+
+ tensor = _serialize_to_tensor(data, group)
+
+ size_list, tensor = _pad_to_largest_tensor(tensor, group)
+ max_size = max(size_list)
+
+ # receiving Tensor from all ranks
+ tensor_list = [
+ torch.empty((max_size,), dtype=torch.uint8, device=tensor.device) for _ in size_list
+ ]
+ dist.all_gather(tensor_list, tensor, group=group)
+
+ data_list = []
+ for size, tensor in zip(size_list, tensor_list):
+ buffer = tensor.cpu().numpy().tobytes()[:size]
+ data_list.append(pickle.loads(buffer))
+
+ return data_list
+
+
+def gather(data, dst=0, group=None):
+ """
+ Run gather on arbitrary picklable data (not necessarily tensors).
+
+ Args:
+ data: any picklable object
+ dst (int): destination rank
+ group: a torch process group. By default, will use a group which
+ contains all ranks on gloo backend.
+
+ Returns:
+ list[data]: on dst, a list of data gathered from each rank. Otherwise,
+ an empty list.
+ """
+ if get_world_size() == 1:
+ return [data]
+ if group is None:
+ group = _get_global_gloo_group()
+ if dist.get_world_size(group=group) == 1:
+ return [data]
+ rank = dist.get_rank(group=group)
+
+ tensor = _serialize_to_tensor(data, group)
+ size_list, tensor = _pad_to_largest_tensor(tensor, group)
+
+ # receiving Tensor from all ranks
+ if rank == dst:
+ max_size = max(size_list)
+ tensor_list = [
+ torch.empty((max_size,), dtype=torch.uint8, device=tensor.device) for _ in size_list
+ ]
+ dist.gather(tensor, tensor_list, dst=dst, group=group)
+
+ data_list = []
+ for size, tensor in zip(size_list, tensor_list):
+ buffer = tensor.cpu().numpy().tobytes()[:size]
+ data_list.append(pickle.loads(buffer))
+ return data_list
+ else:
+ dist.gather(tensor, [], dst=dst, group=group)
+ return []
+
+
+def shared_random_seed():
+ """
+ Returns:
+ int: a random number that is the same across all workers.
+ If workers need a shared RNG, they can use this shared seed to
+ create one.
+
+ All workers must call this function, otherwise it will deadlock.
+ """
+ ints = np.random.randint(2 ** 31)
+ all_ints = all_gather(ints)
+ return all_ints[0]
+
+
+def reduce_dict(input_dict, average=True):
+ """
+ Reduce the values in the dictionary from all processes so that process with rank
+ 0 has the reduced results.
+
+ Args:
+ input_dict (dict): inputs to be reduced. All the values must be scalar CUDA Tensor.
+ average (bool): whether to do average or sum
+
+ Returns:
+ a dict with the same keys as input_dict, after reduction.
+ """
+ world_size = get_world_size()
+ if world_size < 2:
+ return input_dict
+ with torch.no_grad():
+ names = []
+ values = []
+ # sort the keys so that they are consistent across processes
+ for k in sorted(input_dict.keys()):
+ names.append(k)
+ values.append(input_dict[k])
+ values = torch.stack(values, dim=0)
+ dist.reduce(values, dst=0)
+ if dist.get_rank() == 0 and average:
+ # only main process gets accumulated, so only divide by
+ # world_size in this case
+ values /= world_size
+ reduced_dict = {k: v for k, v in zip(names, values)}
+ return reduced_dict
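These helpers degrade gracefully when torch.distributed is not initialized, which makes single-process debugging painless. A small check, assuming the TopicFM root is on PYTHONPATH:

import torch

from src.utils.comm import get_world_size, is_main_process, all_gather, reduce_dict

metrics = {'loss': torch.tensor(0.7), 'acc': torch.tensor(0.9)}
print(get_world_size(), is_main_process())  # 1 True (no process group initialized)
print(all_gather(metrics))                  # [{'loss': tensor(0.7000), 'acc': tensor(0.9000)}]
print(reduce_dict(metrics) is metrics)      # True: returned unchanged when world_size < 2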
diff --git a/imcui/third_party/TopicFM/src/utils/dataloader.py b/imcui/third_party/TopicFM/src/utils/dataloader.py
new file mode 100644
index 0000000000000000000000000000000000000000..6da37b880a290c2bb3ebb028d0c8dab592acc5c1
--- /dev/null
+++ b/imcui/third_party/TopicFM/src/utils/dataloader.py
@@ -0,0 +1,23 @@
+import numpy as np
+
+
+# --- PL-DATAMODULE ---
+
+def get_local_split(items: list, world_size: int, rank: int, seed: int):
+ """ The local rank only loads a split of the dataset. """
+ n_items = len(items)
+ items_permute = np.random.RandomState(seed).permutation(items)
+ if n_items % world_size == 0:
+ padded_items = items_permute
+ else:
+ padding = np.random.RandomState(seed).choice(
+ items,
+ world_size - (n_items % world_size),
+ replace=True)
+ padded_items = np.concatenate([items_permute, padding])
+ assert len(padded_items) % world_size == 0, \
+ f'len(padded_items): {len(padded_items)}; world_size: {world_size}; len(padding): {len(padding)}'
+ n_per_rank = len(padded_items) // world_size
+ local_items = padded_items[n_per_rank * rank: n_per_rank * (rank+1)]
+
+ return local_items
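How the split behaves when the dataset size is not divisible by the world size: 10 items are padded to 12 so that each of 3 ranks gets 4, and the shared seed keeps the permutation and padding identical across ranks. Import path assumed from this repo layout.

from src.utils.dataloader import get_local_split

items = list(range(10))
for rank in range(3):
    local = get_local_split(items, world_size=3, rank=rank, seed=42)
    print(rank, len(local), list(local))  # 4 items per rank; 12 total = 10 unique + 2 padded duplicates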
diff --git a/imcui/third_party/TopicFM/src/utils/dataset.py b/imcui/third_party/TopicFM/src/utils/dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..647bbadd821b6c90736ed45462270670b1017b0b
--- /dev/null
+++ b/imcui/third_party/TopicFM/src/utils/dataset.py
@@ -0,0 +1,201 @@
+import io
+from loguru import logger
+
+import cv2
+import numpy as np
+import h5py
+import torch
+from numpy.linalg import inv
+
+
+MEGADEPTH_CLIENT = SCANNET_CLIENT = None
+
+# --- DATA IO ---
+
+def load_array_from_s3(
+ path, client, cv_type,
+ use_h5py=False,
+):
+ byte_str = client.Get(path)
+ try:
+ if not use_h5py:
+ raw_array = np.frombuffer(byte_str, np.uint8)  # np.fromstring is deprecated for binary buffers
+ data = cv2.imdecode(raw_array, cv_type)
+ else:
+ f = io.BytesIO(byte_str)
+ data = np.array(h5py.File(f, 'r')['/depth'])
+ except Exception as ex:
+ print(f"==> Data loading failure: {path}")
+ raise ex
+
+ assert data is not None
+ return data
+
+
+def imread_gray(path, augment_fn=None, client=SCANNET_CLIENT):
+ cv_type = cv2.IMREAD_GRAYSCALE if augment_fn is None \
+ else cv2.IMREAD_COLOR
+ if str(path).startswith('s3://'):
+ image = load_array_from_s3(str(path), client, cv_type)
+ else:
+ image = cv2.imread(str(path), cv_type)
+
+ if augment_fn is not None:
+ image = cv2.imread(str(path), cv2.IMREAD_COLOR)
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+ image = augment_fn(image)
+ image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
+ return image # (h, w)
+
+
+def get_resized_wh(w, h, resize=None):
+ if (resize is not None) and (max(h,w) > resize): # resize the longer edge
+ scale = resize / max(h, w)
+ w_new, h_new = int(round(w*scale)), int(round(h*scale))
+ else:
+ w_new, h_new = w, h
+ return w_new, h_new
+
+
+def get_divisible_wh(w, h, df=None):
+ if df is not None:
+ w_new, h_new = map(lambda x: int(x // df * df), [w, h])
+ else:
+ w_new, h_new = w, h
+ return w_new, h_new
+
+
+def pad_bottom_right(inp, pad_size, ret_mask=False):
+ assert isinstance(pad_size, int) and pad_size >= max(inp.shape[-2:]), f"{pad_size} < {max(inp.shape[-2:])}"
+ mask = None
+ if inp.ndim == 2:
+ padded = np.zeros((pad_size, pad_size), dtype=inp.dtype)
+ padded[:inp.shape[0], :inp.shape[1]] = inp
+ if ret_mask:
+ mask = np.zeros((pad_size, pad_size), dtype=bool)
+ mask[:inp.shape[0], :inp.shape[1]] = True
+ elif inp.ndim == 3:
+ padded = np.zeros((inp.shape[0], pad_size, pad_size), dtype=inp.dtype)
+ padded[:, :inp.shape[1], :inp.shape[2]] = inp
+ if ret_mask:
+ mask = np.zeros((inp.shape[0], pad_size, pad_size), dtype=bool)
+ mask[:, :inp.shape[1], :inp.shape[2]] = True
+ else:
+ raise NotImplementedError()
+ return padded, mask
+
+
+# --- MEGADEPTH ---
+
+def read_megadepth_gray(path, resize=None, df=None, padding=False, augment_fn=None):
+ """
+ Args:
+ resize (int, optional): the longer edge of resized images. None for no resize.
+ padding (bool): If set to 'True', zero-pad the resized images to a square size.
+ augment_fn (callable, optional): augments images with pre-defined visual effects
+ Returns:
+ image (torch.tensor): (1, h, w)
+ mask (torch.tensor): (h, w)
+ scale (torch.tensor): [w/w_new, h/h_new]
+ """
+ # read image
+ image = imread_gray(path, augment_fn, client=MEGADEPTH_CLIENT)
+
+ # resize image
+ w, h = image.shape[1], image.shape[0]
+ w_new, h_new = get_resized_wh(w, h, resize)
+ w_new, h_new = get_divisible_wh(w_new, h_new, df)
+
+ image = cv2.resize(image, (w_new, h_new))
+ scale = torch.tensor([w/w_new, h/h_new], dtype=torch.float)
+
+ if padding: # padding
+ pad_to = resize #max(h_new, w_new)
+ image, mask = pad_bottom_right(image, pad_to, ret_mask=True)
+ else:
+ mask = None
+
+ image = torch.from_numpy(image).float()[None] / 255 # (h, w) -> (1, h, w) and normalized
+ mask = torch.from_numpy(mask) if mask is not None else None
+
+ return image, mask, scale
+
+
+def read_megadepth_depth(path, pad_to=None):
+ if str(path).startswith('s3://'):
+ depth = load_array_from_s3(path, MEGADEPTH_CLIENT, None, use_h5py=True)
+ else:
+ depth = np.array(h5py.File(path, 'r')['depth'])
+ if pad_to is not None:
+ depth, _ = pad_bottom_right(depth, pad_to, ret_mask=False)
+ depth = torch.from_numpy(depth).float() # (h, w)
+ return depth
+
+
+# --- ScanNet ---
+
+def read_scannet_gray(path, resize=(640, 480), augment_fn=None):
+ """
+ Args:
+ resize (tuple): align image to depthmap, in (w, h).
+ augment_fn (callable, optional): augments images with pre-defined visual effects
+ Returns:
+ image (torch.tensor): (1, h, w)
+ mask (torch.tensor): (h, w)
+ scale (torch.tensor): [w/w_new, h/h_new]
+ """
+ # read and resize image
+ image = imread_gray(path, augment_fn)
+ image = cv2.resize(image, resize)
+
+ # (h, w) -> (1, h, w) and normalized
+ image = torch.from_numpy(image).float()[None] / 255
+ return image
+
+
+# ---- evaluation datasets: HLoc, Aachen, InLoc
+
+def read_img_gray(path, resize=None, down_factor=16):
+ # read and resize image
+ image = imread_gray(path, None)
+ w, h = image.shape[1], image.shape[0]
+ if (resize is not None) and (max(h, w) > resize):
+ scale = float(resize / max(h, w))
+ w_new, h_new = int(round(w * scale)), int(round(h * scale))
+ else:
+ w_new, h_new = w, h
+ w_new, h_new = get_divisible_wh(w_new, h_new, down_factor)
+ image = cv2.resize(image, (w_new, h_new))
+
+ # (h, w) -> (1, h, w) and normalized
+ image = torch.from_numpy(image).float()[None] / 255
+ scale = torch.tensor([w / w_new, h / h_new], dtype=torch.float)
+ return image, scale
+
+
+def read_scannet_depth(path):
+ if str(path).startswith('s3://'):
+ depth = load_array_from_s3(str(path), SCANNET_CLIENT, cv2.IMREAD_UNCHANGED)
+ else:
+ depth = cv2.imread(str(path), cv2.IMREAD_UNCHANGED)
+ depth = depth / 1000
+ depth = torch.from_numpy(depth).float() # (h, w)
+ return depth
+
+
+def read_scannet_pose(path):
+ """ Read ScanNet's Camera2World pose and transform it to World2Camera.
+
+ Returns:
+ pose_w2c (np.ndarray): (4, 4)
+ """
+ cam2world = np.loadtxt(path, delimiter=' ')
+ world2cam = inv(cam2world)
+ return world2cam
+
+
+def read_scannet_intrinsic(path):
+ """ Read ScanNet's intrinsic matrix and return the 3x3 matrix.
+ """
+ intrinsic = np.loadtxt(path, delimiter=' ')
+ return intrinsic[:-1, :-1]
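`pad_bottom_right` is what makes batching variable-sized MegaDepth images possible: every image is padded to a fixed square and a mask marks the valid region. A small self-contained check, with the import path assumed from this repo layout:

import numpy as np

from src.utils.dataset import pad_bottom_right

img = np.arange(12, dtype=np.float32).reshape(3, 4)         # (h, w) = (3, 4)
padded, mask = pad_bottom_right(img, pad_size=8, ret_mask=True)
print(padded.shape, mask.shape)             # (8, 8) (8, 8)
print(mask[:3, :4].all(), mask[3:].any())   # True False -> only the original area is valid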
diff --git a/imcui/third_party/TopicFM/src/utils/metrics.py b/imcui/third_party/TopicFM/src/utils/metrics.py
new file mode 100644
index 0000000000000000000000000000000000000000..a93c31ed1d151cd41e2449a19be2d6abc5f9d419
--- /dev/null
+++ b/imcui/third_party/TopicFM/src/utils/metrics.py
@@ -0,0 +1,193 @@
+import torch
+import cv2
+import numpy as np
+from collections import OrderedDict
+from loguru import logger
+from kornia.geometry.epipolar import numeric
+from kornia.geometry.conversions import convert_points_to_homogeneous
+
+
+# --- METRICS ---
+
+def relative_pose_error(T_0to1, R, t, ignore_gt_t_thr=0.0):
+ # angle error between 2 vectors
+ t_gt = T_0to1[:3, 3]
+ n = np.linalg.norm(t) * np.linalg.norm(t_gt)
+ t_err = np.rad2deg(np.arccos(np.clip(np.dot(t, t_gt) / n, -1.0, 1.0)))
+ t_err = np.minimum(t_err, 180 - t_err) # handle E ambiguity
+ if np.linalg.norm(t_gt) < ignore_gt_t_thr: # pure rotation is challenging
+ t_err = 0
+
+ # angle error between 2 rotation matrices
+ R_gt = T_0to1[:3, :3]
+ cos = (np.trace(np.dot(R.T, R_gt)) - 1) / 2
+ cos = np.clip(cos, -1., 1.) # handle numercial errors
+ R_err = np.rad2deg(np.abs(np.arccos(cos)))
+
+ return t_err, R_err
+
+
+def symmetric_epipolar_distance(pts0, pts1, E, K0, K1):
+ """Squared symmetric epipolar distance.
+ This can be seen as a biased estimation of the reprojection error.
+ Args:
+ pts0 (torch.Tensor): [N, 2]
+ E (torch.Tensor): [3, 3]
+ """
+ pts0 = (pts0 - K0[[0, 1], [2, 2]][None]) / K0[[0, 1], [0, 1]][None]
+ pts1 = (pts1 - K1[[0, 1], [2, 2]][None]) / K1[[0, 1], [0, 1]][None]
+ pts0 = convert_points_to_homogeneous(pts0)
+ pts1 = convert_points_to_homogeneous(pts1)
+
+ Ep0 = pts0 @ E.T # [N, 3]
+ p1Ep0 = torch.sum(pts1 * Ep0, -1) # [N,]
+ Etp1 = pts1 @ E # [N, 3]
+
+ d = p1Ep0**2 * (1.0 / (Ep0[:, 0]**2 + Ep0[:, 1]**2) + 1.0 / (Etp1[:, 0]**2 + Etp1[:, 1]**2)) # N
+ return d
+
+
+def compute_symmetrical_epipolar_errors(data):
+ """
+ Update:
+ data (dict):{"epi_errs": [M]}
+ """
+ Tx = numeric.cross_product_matrix(data['T_0to1'][:, :3, 3])
+ E_mat = Tx @ data['T_0to1'][:, :3, :3]
+
+ m_bids = data['m_bids']
+ pts0 = data['mkpts0_f']
+ pts1 = data['mkpts1_f']
+
+ epi_errs = []
+ for bs in range(Tx.size(0)):
+ mask = m_bids == bs
+ epi_errs.append(
+ symmetric_epipolar_distance(pts0[mask], pts1[mask], E_mat[bs], data['K0'][bs], data['K1'][bs]))
+ epi_errs = torch.cat(epi_errs, dim=0)
+
+ data.update({'epi_errs': epi_errs})
+
+
+def estimate_pose(kpts0, kpts1, K0, K1, thresh, conf=0.99999):
+ if len(kpts0) < 5:
+ return None
+ # normalize keypoints
+ kpts0 = (kpts0 - K0[[0, 1], [2, 2]][None]) / K0[[0, 1], [0, 1]][None]
+ kpts1 = (kpts1 - K1[[0, 1], [2, 2]][None]) / K1[[0, 1], [0, 1]][None]
+
+ # normalize ransac threshold
+ ransac_thr = thresh / np.mean([K0[0, 0], K1[1, 1], K0[0, 0], K1[1, 1]])
+
+ # compute pose with cv2
+ E, mask = cv2.findEssentialMat(
+ kpts0, kpts1, np.eye(3), threshold=ransac_thr, prob=conf, method=cv2.RANSAC)
+ if E is None:
+ print("\nE is None while trying to recover pose.\n")
+ return None
+
+ # recover pose from E
+ best_num_inliers = 0
+ ret = None
+ for _E in np.split(E, len(E) // 3):
+ n, R, t, _ = cv2.recoverPose(_E, kpts0, kpts1, np.eye(3), 1e9, mask=mask)
+ if n > best_num_inliers:
+ ret = (R, t[:, 0], mask.ravel() > 0)
+ best_num_inliers = n
+
+ return ret
+
+
+def compute_pose_errors(data, config=None, ransac_thr=0.5, ransac_conf=0.99999):
+ """
+ Update:
+ data (dict):{
+ "R_errs" List[float]: [N]
+ "t_errs" List[float]: [N]
+ "inliers" List[np.ndarray]: [N]
+ }
+ """
+ pixel_thr = config.TRAINER.RANSAC_PIXEL_THR if config is not None else ransac_thr # 0.5
+ conf = config.TRAINER.RANSAC_CONF if config is not None else ransac_conf # 0.99999
+ data.update({'R_errs': [], 't_errs': [], 'inliers': []})
+
+ m_bids = data['m_bids'].cpu().numpy()
+ pts0 = data['mkpts0_f'].cpu().numpy()
+ pts1 = data['mkpts1_f'].cpu().numpy()
+ K0 = data['K0'].cpu().numpy()
+ K1 = data['K1'].cpu().numpy()
+ T_0to1 = data['T_0to1'].cpu().numpy()
+
+ for bs in range(K0.shape[0]):
+ mask = m_bids == bs
+ ret = estimate_pose(pts0[mask], pts1[mask], K0[bs], K1[bs], pixel_thr, conf=conf)
+
+ if ret is None:
+ data['R_errs'].append(np.inf)
+ data['t_errs'].append(np.inf)
+ data['inliers'].append(np.array([]).astype(bool))  # np.bool was removed in NumPy >= 1.24
+ else:
+ R, t, inliers = ret
+ t_err, R_err = relative_pose_error(T_0to1[bs], R, t, ignore_gt_t_thr=0.0)
+ data['R_errs'].append(R_err)
+ data['t_errs'].append(t_err)
+ data['inliers'].append(inliers)
+
+
+# --- METRIC AGGREGATION ---
+
+def error_auc(errors, thresholds):
+ """
+ Args:
+ errors (list): [N,]
+ thresholds (list)
+ """
+ errors = [0] + sorted(list(errors))
+ recall = list(np.linspace(0, 1, len(errors)))
+
+ aucs = []
+ for thr in thresholds:
+ last_index = np.searchsorted(errors, thr)
+ y = recall[:last_index] + [recall[last_index-1]]
+ x = errors[:last_index] + [thr]
+ aucs.append(np.trapz(y, x) / thr)
+
+ return {f'auc@{t}': auc for t, auc in zip(thresholds, aucs)}
+
+
+def epidist_prec(errors, thresholds, ret_dict=False):
+ precs = []
+ for thr in thresholds:
+ prec_ = []
+ for errs in errors:
+ correct_mask = errs < thr
+ prec_.append(np.mean(correct_mask) if len(correct_mask) > 0 else 0)
+ precs.append(np.mean(prec_) if len(prec_) > 0 else 0)
+ if ret_dict:
+ return {f'prec@{t:.0e}': prec for t, prec in zip(thresholds, precs)}
+ else:
+ return precs
+
+
+def aggregate_metrics(metrics, epi_err_thr=5e-4):
+ """ Aggregate metrics for the whole dataset:
+ (This method should be called once per dataset)
+ 1. AUC of the pose error (angular) at the threshold [5, 10, 20]
+ 2. Mean matching precision at the threshold 5e-4(ScanNet), 1e-4(MegaDepth)
+ """
+ # filter duplicates
+ unq_ids = OrderedDict((iden, id) for id, iden in enumerate(metrics['identifiers']))
+ unq_ids = list(unq_ids.values())
+ logger.info(f'Aggregating metrics over {len(unq_ids)} unique items...')
+
+ # pose auc
+ angular_thresholds = [5, 10, 20]
+ pose_errors = np.max(np.stack([metrics['R_errs'], metrics['t_errs']]), axis=0)[unq_ids]
+ aucs = error_auc(pose_errors, angular_thresholds) # (auc@5, auc@10, auc@20)
+
+ # matching precision
+ dist_thresholds = [epi_err_thr]
+ precs = epidist_prec(np.array(metrics['epi_errs'], dtype=object)[unq_ids], dist_thresholds, True) # (prec@err_thr)
+
+ return {**aucs, **precs}
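The two aggregation helpers can be exercised with synthetic errors; the values below are illustrative only, and the import path assumes the TopicFM root is on PYTHONPATH.

import numpy as np

from src.utils.metrics import error_auc, epidist_prec

pose_errors = np.array([1.0, 3.0, 7.5, 12.0, 40.0])   # max(R_err, t_err) per pair, in degrees
print(error_auc(pose_errors, [5, 10, 20]))            # {'auc@5': ..., 'auc@10': ..., 'auc@20': ...}

epi_errs = [np.array([1e-5, 3e-4, 2e-3]), np.array([5e-5, 9e-4])]  # per-pair epipolar errors
print(epidist_prec(epi_errs, [5e-4], ret_dict=True))  # {'prec@5e-04': ...}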
diff --git a/imcui/third_party/TopicFM/src/utils/misc.py b/imcui/third_party/TopicFM/src/utils/misc.py
new file mode 100644
index 0000000000000000000000000000000000000000..9c8db04666519753ea2df43903ab6c47ec00a9a1
--- /dev/null
+++ b/imcui/third_party/TopicFM/src/utils/misc.py
@@ -0,0 +1,101 @@
+import os
+import contextlib
+import joblib
+from typing import Union
+from loguru import _Logger, logger
+from itertools import chain
+
+import torch
+from yacs.config import CfgNode as CN
+from pytorch_lightning.utilities import rank_zero_only
+
+
+def lower_config(yacs_cfg):
+ if not isinstance(yacs_cfg, CN):
+ return yacs_cfg
+ return {k.lower(): lower_config(v) for k, v in yacs_cfg.items()}
+
+
+def upper_config(dict_cfg):
+ if not isinstance(dict_cfg, dict):
+ return dict_cfg
+ return {k.upper(): upper_config(v) for k, v in dict_cfg.items()}
+
+
+def log_on(condition, message, level):
+ if condition:
+ assert level in ['INFO', 'DEBUG', 'WARNING', 'ERROR', 'CRITICAL']
+ logger.log(level, message)
+
+
+def get_rank_zero_only_logger(logger: _Logger):
+ if rank_zero_only.rank == 0:
+ return logger
+ else:
+ for _level in logger._core.levels.keys():
+ level = _level.lower()
+ setattr(logger, level,
+ lambda x: None)
+ logger._log = lambda x: None
+ return logger
+
+
+def setup_gpus(gpus: Union[str, int]) -> int:
+ """ A temporary fix for pytorch-lighting 1.3.x """
+ gpus = str(gpus)
+ gpu_ids = []
+
+ if ',' not in gpus:
+ n_gpus = int(gpus)
+ return n_gpus if n_gpus != -1 else torch.cuda.device_count()
+ else:
+ gpu_ids = [i.strip() for i in gpus.split(',') if i != '']
+
+ # setup environment variables
+ visible_devices = os.getenv('CUDA_VISIBLE_DEVICES')
+ if visible_devices is None:
+ os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
+ os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(str(i) for i in gpu_ids)
+ visible_devices = os.getenv('CUDA_VISIBLE_DEVICES')
+ logger.warning(f'[Temporary Fix] manually set CUDA_VISIBLE_DEVICES when specifying gpus to use: {visible_devices}')
+ else:
+ logger.warning('[Temporary Fix] CUDA_VISIBLE_DEVICES already set by user or the main process.')
+ return len(gpu_ids)
+
+
+def flattenList(x):
+ return list(chain(*x))
+
+
+@contextlib.contextmanager
+def tqdm_joblib(tqdm_object):
+ """Context manager to patch joblib to report into tqdm progress bar given as argument
+
+ Usage:
+ with tqdm_joblib(tqdm(desc="My calculation", total=10)) as progress_bar:
+ Parallel(n_jobs=16)(delayed(sqrt)(i**2) for i in range(10))
+
+ When iterating over a generator, directly using tqdm is also a solution (but it monitors task queuing instead of task completion):
+ ret_vals = Parallel(n_jobs=args.world_size)(
+ delayed(lambda x: _compute_cov_score(pid, *x))(param)
+ for param in tqdm(combinations(image_ids, 2),
+ desc=f'Computing cov_score of [{pid}]',
+ total=len(image_ids)*(len(image_ids)-1)/2))
+ Src: https://stackoverflow.com/a/58936697
+ """
+ class TqdmBatchCompletionCallback(joblib.parallel.BatchCompletionCallBack):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ def __call__(self, *args, **kwargs):
+ tqdm_object.update(n=self.batch_size)
+ return super().__call__(*args, **kwargs)
+
+ old_batch_callback = joblib.parallel.BatchCompletionCallBack
+ joblib.parallel.BatchCompletionCallBack = TqdmBatchCompletionCallback
+ try:
+ yield tqdm_object
+ finally:
+ joblib.parallel.BatchCompletionCallBack = old_batch_callback
+ tqdm_object.close()
+
diff --git a/imcui/third_party/TopicFM/src/utils/plotting.py b/imcui/third_party/TopicFM/src/utils/plotting.py
new file mode 100644
index 0000000000000000000000000000000000000000..89b22ef27e6152225d07ab24bb3e62718d180b59
--- /dev/null
+++ b/imcui/third_party/TopicFM/src/utils/plotting.py
@@ -0,0 +1,313 @@
+import bisect
+import numpy as np
+import matplotlib.pyplot as plt
+import matplotlib, os, cv2
+import matplotlib.cm as cm
+from PIL import Image
+import torch.nn.functional as F
+import torch
+
+
+def _compute_conf_thresh(data):
+ dataset_name = data['dataset_name'][0].lower()
+ if dataset_name == 'scannet':
+ thr = 5e-4
+ elif dataset_name == 'megadepth':
+ thr = 1e-4
+ else:
+ raise ValueError(f'Unknown dataset: {dataset_name}')
+ return thr
+
+
+# --- VISUALIZATION --- #
+
+def make_matching_figure(
+ img0, img1, mkpts0, mkpts1, color,
+ kpts0=None, kpts1=None, text=[], dpi=75, path=None):
+ # draw image pair
+ assert mkpts0.shape[0] == mkpts1.shape[0], f'mkpts0: {mkpts0.shape[0]} v.s. mkpts1: {mkpts1.shape[0]}'
+ fig, axes = plt.subplots(1, 2, figsize=(10, 6), dpi=dpi)
+ axes[0].imshow(img0) # , cmap='gray')
+ axes[1].imshow(img1) # , cmap='gray')
+ for i in range(2): # clear all frames
+ axes[i].get_yaxis().set_ticks([])
+ axes[i].get_xaxis().set_ticks([])
+ for spine in axes[i].spines.values():
+ spine.set_visible(False)
+ plt.tight_layout(pad=1)
+
+ if kpts0 is not None:
+ assert kpts1 is not None
+ axes[0].scatter(kpts0[:, 0], kpts0[:, 1], c='w', s=5)
+ axes[1].scatter(kpts1[:, 0], kpts1[:, 1], c='w', s=5)
+
+ # draw matches
+ if mkpts0.shape[0] != 0 and mkpts1.shape[0] != 0:
+ fig.canvas.draw()
+ transFigure = fig.transFigure.inverted()
+ fkpts0 = transFigure.transform(axes[0].transData.transform(mkpts0))
+ fkpts1 = transFigure.transform(axes[1].transData.transform(mkpts1))
+ fig.lines = [matplotlib.lines.Line2D((fkpts0[i, 0], fkpts1[i, 0]),
+ (fkpts0[i, 1], fkpts1[i, 1]),
+ transform=fig.transFigure, c=color[i], linewidth=2)
+ for i in range(len(mkpts0))]
+
+ axes[0].scatter(mkpts0[:, 0], mkpts0[:, 1], c=color[..., :3], s=4)
+ axes[1].scatter(mkpts1[:, 0], mkpts1[:, 1], c=color[..., :3], s=4)
+
+ # put txts
+ txt_color = 'k' if img0[:100, :200].mean() > 200 else 'w'
+ fig.text(
+ 0.01, 0.99, '\n'.join(text), transform=fig.axes[0].transAxes,
+ fontsize=15, va='top', ha='left', color=txt_color)
+
+ # save or return figure
+ if path:
+ plt.savefig(str(path), bbox_inches='tight', pad_inches=0)
+ plt.close()
+ else:
+ return fig
+
+
+def _make_evaluation_figure(data, b_id, alpha='dynamic'):
+ b_mask = data['m_bids'] == b_id
+ conf_thr = _compute_conf_thresh(data)
+
+ img0 = (data['image0'][b_id][0].cpu().numpy() * 255).round().astype(np.int32)
+ img1 = (data['image1'][b_id][0].cpu().numpy() * 255).round().astype(np.int32)
+ kpts0 = data['mkpts0_f'][b_mask].cpu().numpy()
+ kpts1 = data['mkpts1_f'][b_mask].cpu().numpy()
+
+ # for megadepth, we visualize matches on the resized image
+ if 'scale0' in data:
+ kpts0 = kpts0 / data['scale0'][b_id].cpu().numpy()[[1, 0]]
+ kpts1 = kpts1 / data['scale1'][b_id].cpu().numpy()[[1, 0]]
+
+ epi_errs = data['epi_errs'][b_mask].cpu().numpy()
+ correct_mask = epi_errs < conf_thr
+ precision = np.mean(correct_mask) if len(correct_mask) > 0 else 0
+ n_correct = np.sum(correct_mask)
+ n_gt_matches = int(data['conf_matrix_gt'][b_id].sum().cpu())
+ recall = 0 if n_gt_matches == 0 else n_correct / (n_gt_matches)
+ # recall might be larger than 1, since the calculation of conf_matrix_gt
+ # uses groundtruth depths and camera poses, but epipolar distance is used here.
+
+ # matching info
+ if alpha == 'dynamic':
+ alpha = dynamic_alpha(len(correct_mask))
+ color = error_colormap(epi_errs, conf_thr, alpha=alpha)
+
+ text = [
+ f'#Matches {len(kpts0)}',
+ f'Precision({conf_thr:.2e}) ({100 * precision:.1f}%): {n_correct}/{len(kpts0)}',
+ f'Recall({conf_thr:.2e}) ({100 * recall:.1f}%): {n_correct}/{n_gt_matches}'
+ ]
+
+ # make the figure
+ figure = make_matching_figure(img0, img1, kpts0, kpts1,
+ color, text=text)
+ return figure
+
+def _make_confidence_figure(data, b_id):
+ # TODO: Implement confidence figure
+ raise NotImplementedError()
+
+
+def make_matching_figures(data, config, mode='evaluation'):
+ """ Make matching figures for a batch.
+
+ Args:
+ data (Dict): a batch updated by PL_LoFTR.
+ config (Dict): matcher config
+ Returns:
+ figures (Dict[str, List[plt.figure]])
+ """
+ assert mode in ['evaluation', 'confidence'] # 'confidence'
+ figures = {mode: []}
+ for b_id in range(data['image0'].size(0)):
+ if mode == 'evaluation':
+ fig = _make_evaluation_figure(
+ data, b_id,
+ alpha=config.TRAINER.PLOT_MATCHES_ALPHA)
+ elif mode == 'confidence':
+ fig = _make_confidence_figure(data, b_id)
+ else:
+ raise ValueError(f'Unknown plot mode: {mode}')
+ figures[mode].append(fig)
+ return figures
+
+
+def dynamic_alpha(n_matches,
+ milestones=[0, 300, 1000, 2000],
+ alphas=[1.0, 0.8, 0.4, 0.2]):
+ if n_matches == 0:
+ return 1.0
+ ranges = list(zip(alphas, alphas[1:] + [None]))
+ loc = bisect.bisect_right(milestones, n_matches) - 1
+ _range = ranges[loc]
+ if _range[1] is None:
+ return _range[0]
+ return _range[1] + (milestones[loc + 1] - n_matches) / (
+ milestones[loc + 1] - milestones[loc]) * (_range[0] - _range[1])
+
+
+def error_colormap(err, thr, alpha=1.0):
+ assert alpha <= 1.0 and alpha > 0, f"Invalid alpha value: {alpha}"
+ x = 1 - np.clip(err / (thr * 2), 0, 1)
+ return np.clip(
+ np.stack([2-x*2, x*2, np.zeros_like(x), np.ones_like(x)*alpha], -1), 0, 1)
+
+
+np.random.seed(1995)
+color_map = np.arange(100)
+np.random.shuffle(color_map)
+
+
+def draw_topics(data, img0, img1, saved_folder="viz_topics", show_n_topics=8, saved_name=None):
+
+ topic0, topic1 = data["topic_matrix"]["img0"], data["topic_matrix"]["img1"]
+ hw0_c, hw1_c = data["hw0_c"], data["hw1_c"]
+ hw0_i, hw1_i = data["hw0_i"], data["hw1_i"]
+ # print(hw0_i, hw1_i)
+ scale0, scale1 = hw0_i[0] // hw0_c[0], hw1_i[0] // hw1_c[0]
+ if "scale0" in data:
+ scale0 *= data["scale0"][0]
+ else:
+ scale0 = (scale0, scale0)
+ if "scale1" in data:
+ scale1 *= data["scale1"][0]
+ else:
+ scale1 = (scale1, scale1)
+
+ n_topics = topic0.shape[-1]
+ # mask0_nonzero = topic0[0].sum(dim=-1, keepdim=True) > 0
+ # mask1_nonzero = topic1[0].sum(dim=-1, keepdim=True) > 0
+ theta0 = topic0[0].sum(dim=0)
+ theta0 /= theta0.sum().float()
+ theta1 = topic1[0].sum(dim=0)
+ theta1 /= theta1.sum().float()
+ # top_topic0 = torch.argsort(theta0, descending=True)[:show_n_topics]
+ # top_topic1 = torch.argsort(theta1, descending=True)[:show_n_topics]
+ top_topics = torch.argsort(theta0*theta1, descending=True)[:show_n_topics]
+ # print(sum_topic0, sum_topic1)
+
+ topic0 = topic0[0].argmax(dim=-1, keepdim=True) #.float() / (n_topics - 1) #* 255 + 1 #
+ # topic0[~mask0_nonzero] = -1
+ topic1 = topic1[0].argmax(dim=-1, keepdim=True) #.float() / (n_topics - 1) #* 255 + 1
+ # topic1[~mask1_nonzero] = -1
+ label_img0, label_img1 = torch.zeros_like(topic0) - 1, torch.zeros_like(topic1) - 1
+ for i, k in enumerate(top_topics):
+ label_img0[topic0 == k] = color_map[k]
+ label_img1[topic1 == k] = color_map[k]
+
+# print(hw0_c, scale0)
+# print(hw1_c, scale1)
+ # map_topic0 = F.fold(label_img0.unsqueeze(0), hw0_i, kernel_size=scale0, stride=scale0)
+ map_topic0 = label_img0.float().view(hw0_c).cpu().numpy() #map_topic0.squeeze(0).squeeze(0).cpu().numpy()
+ map_topic0 = cv2.resize(map_topic0, (int(hw0_c[1] * scale0[0]), int(hw0_c[0] * scale0[1])))
+ # map_topic1 = F.fold(label_img1.unsqueeze(0), hw1_i, kernel_size=scale1, stride=scale1)
+ map_topic1 = label_img1.float().view(hw1_c).cpu().numpy() #map_topic1.squeeze(0).squeeze(0).cpu().numpy()
+ map_topic1 = cv2.resize(map_topic1, (int(hw1_c[1] * scale1[0]), int(hw1_c[0] * scale1[1])))
+
+
+ # show image0
+ if saved_name is None:
+ return map_topic0, map_topic1
+
+ if not os.path.exists(saved_folder):
+ os.makedirs(saved_folder)
+ path_saved_img0 = os.path.join(saved_folder, "{}_0.png".format(saved_name))
+ plt.imshow(img0)
+ masked_map_topic0 = np.ma.masked_where(map_topic0 < 0, map_topic0)
+ plt.imshow(masked_map_topic0, cmap=plt.cm.jet, vmin=0, vmax=n_topics-1, alpha=.3, interpolation='bilinear')
+ # plt.show()
+ plt.axis('off')
+ plt.savefig(path_saved_img0, bbox_inches='tight', pad_inches=0, dpi=250)
+ plt.close()
+
+ path_saved_img1 = os.path.join(saved_folder, "{}_1.png".format(saved_name))
+ plt.imshow(img1)
+ masked_map_topic1 = np.ma.masked_where(map_topic1 < 0, map_topic1)
+ plt.imshow(masked_map_topic1, cmap=plt.cm.jet, vmin=0, vmax=n_topics-1, alpha=.3, interpolation='bilinear')
+ plt.axis('off')
+ plt.savefig(path_saved_img1, bbox_inches='tight', pad_inches=0, dpi=250)
+ plt.close()
+
+
+def draw_topicfm_demo(data, img0, img1, mkpts0, mkpts1, mcolor, text, show_n_topics=8,
+ topic_alpha=0.3, margin=5, path=None, opencv_display=False, opencv_title=''):
+ topic_map0, topic_map1 = draw_topics(data, img0, img1, show_n_topics=show_n_topics)
+
+ mask_tm0, mask_tm1 = np.expand_dims(topic_map0 >= 0, axis=-1), np.expand_dims(topic_map1 >= 0, axis=-1)
+
+ topic_cm0, topic_cm1 = cm.jet(topic_map0 / 99.), cm.jet(topic_map1 / 99.)
+ topic_cm0 = cv2.cvtColor(topic_cm0[..., :3].astype(np.float32), cv2.COLOR_RGB2BGR)
+ topic_cm1 = cv2.cvtColor(topic_cm1[..., :3].astype(np.float32), cv2.COLOR_RGB2BGR)
+ overlay0 = (mask_tm0 * topic_cm0 + (1 - mask_tm0) * img0).astype(np.float32)
+ overlay1 = (mask_tm1 * topic_cm1 + (1 - mask_tm1) * img1).astype(np.float32)
+
+ cv2.addWeighted(overlay0, topic_alpha, img0, 1 - topic_alpha, 0, overlay0)
+ cv2.addWeighted(overlay1, topic_alpha, img1, 1 - topic_alpha, 0, overlay1)
+
+ overlay0, overlay1 = (overlay0 * 255).astype(np.uint8), (overlay1 * 255).astype(np.uint8)
+
+ h0, w0 = img0.shape[:2]
+ h1, w1 = img1.shape[:2]
+ h, w = h0 * 2 + margin * 2, w0 * 2 + margin
+ out_fig = 255 * np.ones((h, w, 3), dtype=np.uint8)
+ out_fig[:h0, :w0] = overlay0
+ if h0 >= h1:
+ start = (h0 - h1) // 2
+ out_fig[start:(start+h1), (w0+margin):(w0+margin+w1)] = overlay1
+ else:
+ start = (h1 - h0) // 2
+ out_fig[:h0, (w0+margin):(w0+margin+w1)] = overlay1[start:(start+h0)]
+
+ step_h = h0 + margin * 2
+ out_fig[step_h:step_h+h0, :w0] = (img0 * 255).astype(np.uint8)
+ if h0 >= h1:
+ start = step_h + (h0 - h1) // 2
+ out_fig[start:start+h1, (w0+margin):(w0+margin+w1)] = (img1 * 255).astype(np.uint8)
+ else:
+ start = (h1 - h0) // 2
+ out_fig[step_h:step_h+h0, (w0+margin):(w0+margin+w1)] = (img1[start:start+h0] * 255).astype(np.uint8)
+
+ # draw matching lines; this is inspired by https://raw.githubusercontent.com/magicleap/SuperGluePretrainedNetwork/master/models/utils.py
+ mkpts0, mkpts1 = np.round(mkpts0).astype(int), np.round(mkpts1).astype(int)
+ mcolor = (np.array(mcolor[:, [2, 1, 0]]) * 255).astype(int)
+
+ for (x0, y0), (x1, y1), c in zip(mkpts0, mkpts1, mcolor):
+ c = c.tolist()
+ cv2.line(out_fig, (x0, y0+step_h), (x1+margin+w0, y1+step_h+(h0-h1)//2),
+ color=c, thickness=1, lineType=cv2.LINE_AA)
+ # display line end-points as circles
+ cv2.circle(out_fig, (x0, y0+step_h), 2, c, -1, lineType=cv2.LINE_AA)
+ cv2.circle(out_fig, (x1+margin+w0, y1+step_h+(h0-h1)//2), 2, c, -1, lineType=cv2.LINE_AA)
+
+ # Scale factor for consistent visualization across scales.
+ sc = min(h / 960., 2.0)
+
+ # Big text.
+ Ht = int(30 * sc) # text height
+ txt_color_fg = (255, 255, 255)
+ txt_color_bg = (0, 0, 0)
+ for i, t in enumerate(text):
+ cv2.putText(out_fig, t, (int(8 * sc), Ht + step_h*i), cv2.FONT_HERSHEY_DUPLEX,
+ 1.0 * sc, txt_color_bg, 2, cv2.LINE_AA)
+ cv2.putText(out_fig, t, (int(8 * sc), Ht + step_h*i), cv2.FONT_HERSHEY_DUPLEX,
+ 1.0 * sc, txt_color_fg, 1, cv2.LINE_AA)
+
+ if path is not None:
+ cv2.imwrite(str(path), out_fig)
+
+ if opencv_display:
+ cv2.imshow(opencv_title, out_fig)
+ cv2.waitKey(1)
+
+ return out_fig
+
+
+
+
+
+
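`make_matching_figure` only needs the two images, the matched keypoints and one RGBA color per match, so it can be smoke-tested without running the matcher. Random data below; a non-interactive matplotlib backend is assumed on headless machines, and the import path assumes the TopicFM root is on PYTHONPATH.

import numpy as np

from src.utils.plotting import make_matching_figure

img0, img1 = np.random.rand(120, 160), np.random.rand(120, 160)
mkpts0 = np.array([[20., 30.], [80., 90.]])
mkpts1 = np.array([[25., 28.], [78., 95.]])
color = np.array([[0., 1., 0., 1.], [1., 0., 0., 1.]])   # one RGBA row per match
fig = make_matching_figure(img0, img1, mkpts0, mkpts1, color, text=['toy example'])
fig.savefig('toy_matches.png')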
diff --git a/imcui/third_party/TopicFM/src/utils/profiler.py b/imcui/third_party/TopicFM/src/utils/profiler.py
new file mode 100644
index 0000000000000000000000000000000000000000..6d21ed79fb506ef09c75483355402c48a195aaa9
--- /dev/null
+++ b/imcui/third_party/TopicFM/src/utils/profiler.py
@@ -0,0 +1,39 @@
+import torch
+from pytorch_lightning.profiler import SimpleProfiler, PassThroughProfiler
+from contextlib import contextmanager
+from pytorch_lightning.utilities import rank_zero_only
+
+
+class InferenceProfiler(SimpleProfiler):
+ """
+ This profiler records the duration of actions with cuda.synchronize().
+ Use it at test time.
+ """
+
+ def __init__(self):
+ super().__init__()
+ self.start = rank_zero_only(self.start)
+ self.stop = rank_zero_only(self.stop)
+ self.summary = rank_zero_only(self.summary)
+
+ @contextmanager
+ def profile(self, action_name: str) -> None:
+ try:
+ torch.cuda.synchronize()
+ self.start(action_name)
+ yield action_name
+ finally:
+ torch.cuda.synchronize()
+ self.stop(action_name)
+
+
+def build_profiler(name):
+ if name == 'inference':
+ return InferenceProfiler()
+ elif name == 'pytorch':
+ from pytorch_lightning.profiler import PyTorchProfiler
+ return PyTorchProfiler(use_cuda=True, profile_memory=True, row_limit=100)
+ elif name is None:
+ return PassThroughProfiler()
+ else:
+ raise ValueError(f'Invalid profiler: {name}')
diff --git a/imcui/third_party/TopicFM/test.py b/imcui/third_party/TopicFM/test.py
new file mode 100644
index 0000000000000000000000000000000000000000..aeb451cde3674b70b0d2e02f37ff1fd391004d30
--- /dev/null
+++ b/imcui/third_party/TopicFM/test.py
@@ -0,0 +1,68 @@
+import pytorch_lightning as pl
+import argparse
+import pprint
+from loguru import logger as loguru_logger
+
+from src.config.default import get_cfg_defaults
+from src.utils.profiler import build_profiler
+
+from src.lightning_trainer.data import MultiSceneDataModule
+from src.lightning_trainer.trainer import PL_Trainer
+
+
+def parse_args():
+ # init a custom parser which will be added to the pl.Trainer parser
+ # check documentation: https://pytorch-lightning.readthedocs.io/en/latest/common/trainer.html#trainer-flags
+ parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+ parser.add_argument(
+ 'data_cfg_path', type=str, help='data config path')
+ parser.add_argument(
+ 'main_cfg_path', type=str, help='main config path')
+ parser.add_argument(
+ '--ckpt_path', type=str, default="weights/indoor_ds.ckpt", help='path to the checkpoint')
+ parser.add_argument(
+ '--dump_dir', type=str, default=None, help="if set, the matching results will be dumped to dump_dir")
+ parser.add_argument(
+ '--profiler_name', type=str, default=None, help='options: [inference, pytorch], or leave it unset')
+ parser.add_argument(
+ '--batch_size', type=int, default=1, help='batch_size per gpu')
+ parser.add_argument(
+ '--num_workers', type=int, default=2)
+ parser.add_argument(
+ '--thr', type=float, default=None, help='modify the coarse-level matching threshold.')
+
+ parser = pl.Trainer.add_argparse_args(parser)
+ return parser.parse_args()
+
+
+if __name__ == '__main__':
+ # parse arguments
+ args = parse_args()
+ pprint.pprint(vars(args))
+
+ # init default-cfg and merge it with the main- and data-cfg
+ config = get_cfg_defaults()
+ config.merge_from_file(args.main_cfg_path)
+ config.merge_from_file(args.data_cfg_path)
+ pl.seed_everything(config.TRAINER.SEED) # reproducibility
+
+ # tune when testing
+ if args.thr is not None:
+ config.MODEL.MATCH_COARSE.THR = args.thr
+
+ loguru_logger.info(f"Args and config initialized!")
+
+ # lightning module
+ profiler = build_profiler(args.profiler_name)
+ model = PL_Trainer(config, pretrained_ckpt=args.ckpt_path, profiler=profiler, dump_dir=args.dump_dir)
+ loguru_logger.info(f"Model-lightning initialized!")
+
+ # lightning data
+ data_module = MultiSceneDataModule(args, config)
+ loguru_logger.info(f"DataModule initialized!")
+
+ # lightning trainer
+ trainer = pl.Trainer.from_argparse_args(args, replace_sampler_ddp=False, logger=False)
+
+ loguru_logger.info(f"Start testing!")
+ trainer.test(model, datamodule=data_module, verbose=False)
diff --git a/imcui/third_party/TopicFM/train.py b/imcui/third_party/TopicFM/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..a552c23718b81ddcb282cedbfe3ceb45e50b3f29
--- /dev/null
+++ b/imcui/third_party/TopicFM/train.py
@@ -0,0 +1,123 @@
+import math
+import argparse
+import pprint
+from distutils.util import strtobool
+from pathlib import Path
+from loguru import logger as loguru_logger
+
+import pytorch_lightning as pl
+from pytorch_lightning.utilities import rank_zero_only
+from pytorch_lightning.loggers import TensorBoardLogger
+from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
+from pytorch_lightning.plugins import DDPPlugin
+
+from src.config.default import get_cfg_defaults
+from src.utils.misc import get_rank_zero_only_logger, setup_gpus
+from src.utils.profiler import build_profiler
+from src.lightning_trainer.data import MultiSceneDataModule
+from src.lightning_trainer.trainer import PL_Trainer
+
+loguru_logger = get_rank_zero_only_logger(loguru_logger)
+
+
+def parse_args():
+ # init a custom parser which will be added to the pl.Trainer parser
+ # check documentation: https://pytorch-lightning.readthedocs.io/en/latest/common/trainer.html#trainer-flags
+ parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+ parser.add_argument(
+ 'data_cfg_path', type=str, help='data config path')
+ parser.add_argument(
+ 'main_cfg_path', type=str, help='main config path')
+ parser.add_argument(
+ '--exp_name', type=str, default='default_exp_name')
+ parser.add_argument(
+ '--batch_size', type=int, default=4, help='batch_size per gpu')
+ parser.add_argument(
+ '--num_workers', type=int, default=4)
+ parser.add_argument(
+ '--pin_memory', type=lambda x: bool(strtobool(x)),
+ nargs='?', default=True, help='whether loading data to pinned memory or not')
+ parser.add_argument(
+ '--ckpt_path', type=str, default=None,
+ help='pretrained checkpoint path, helpful for using a pre-trained coarse-only LoFTR')
+ parser.add_argument(
+ '--disable_ckpt', action='store_true',
+ help='disable checkpoint saving (useful for debugging).')
+ parser.add_argument(
+ '--profiler_name', type=str, default=None,
+ help='options: [inference, pytorch], or leave it unset')
+ parser.add_argument(
+ '--parallel_load_data', action='store_true',
+ help='load datasets with multiple processes.')
+
+ parser = pl.Trainer.add_argparse_args(parser)
+ return parser.parse_args()
+
+
+def main():
+ # parse arguments
+ args = parse_args()
+ rank_zero_only(pprint.pprint)(vars(args))
+
+ # init default-cfg and merge it with the main- and data-cfg
+ config = get_cfg_defaults()
+ config.merge_from_file(args.main_cfg_path)
+ config.merge_from_file(args.data_cfg_path)
+ pl.seed_everything(config.TRAINER.SEED) # reproducibility
+ # TODO: Use different seeds for each dataloader workers
+ # This is needed for data augmentation
+
+ # scale lr and warmup-step automatically
+ args.gpus = _n_gpus = setup_gpus(args.gpus)
+ config.TRAINER.WORLD_SIZE = _n_gpus * args.num_nodes
+ config.TRAINER.TRUE_BATCH_SIZE = config.TRAINER.WORLD_SIZE * args.batch_size
+ _scaling = config.TRAINER.TRUE_BATCH_SIZE / config.TRAINER.CANONICAL_BS
+ config.TRAINER.SCALING = _scaling
+ config.TRAINER.TRUE_LR = config.TRAINER.CANONICAL_LR * _scaling
+ config.TRAINER.WARMUP_STEP = math.floor(config.TRAINER.WARMUP_STEP / _scaling)
+
+ # lightning module
+ profiler = build_profiler(args.profiler_name)
+ model = PL_Trainer(config, pretrained_ckpt=args.ckpt_path, profiler=profiler)
+ loguru_logger.info(f"Model LightningModule initialized!")
+
+ # lightning data
+ data_module = MultiSceneDataModule(args, config)
+ loguru_logger.info(f"Model DataModule initialized!")
+
+ # TensorBoard Logger
+ logger = TensorBoardLogger(save_dir='logs/tb_logs', name=args.exp_name, default_hp_metric=False)
+ ckpt_dir = Path(logger.log_dir) / 'checkpoints'
+
+ # Callbacks
+ # TODO: update ModelCheckpoint to monitor multiple metrics
+ ckpt_callback = ModelCheckpoint(monitor='auc@10', verbose=True, save_top_k=5, mode='max',
+ save_last=True,
+ dirpath=str(ckpt_dir),
+ filename='{epoch}-{auc@5:.3f}-{auc@10:.3f}-{auc@20:.3f}')
+ lr_monitor = LearningRateMonitor(logging_interval='step')
+ callbacks = [lr_monitor]
+ if not args.disable_ckpt:
+ callbacks.append(ckpt_callback)
+
+ # Lightning Trainer
+ trainer = pl.Trainer.from_argparse_args(
+ args,
+ plugins=DDPPlugin(find_unused_parameters=False,
+ num_nodes=args.num_nodes,
+ sync_batchnorm=config.TRAINER.WORLD_SIZE > 0),
+ gradient_clip_val=config.TRAINER.GRADIENT_CLIPPING,
+ callbacks=callbacks,
+ logger=logger,
+ sync_batchnorm=config.TRAINER.WORLD_SIZE > 0,
+ replace_sampler_ddp=False, # use custom sampler
+ reload_dataloaders_every_epoch=False, # avoid repeated samples!
+ weights_summary='full',
+ profiler=profiler)
+ loguru_logger.info(f"Trainer initialized!")
+ loguru_logger.info(f"Start training!")
+ trainer.fit(model, datamodule=data_module)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/imcui/third_party/TopicFM/visualization.py b/imcui/third_party/TopicFM/visualization.py
new file mode 100644
index 0000000000000000000000000000000000000000..279b41cd88f61ce3414e2f3077fec642b2c8333a
--- /dev/null
+++ b/imcui/third_party/TopicFM/visualization.py
@@ -0,0 +1,108 @@
+#!/usr/bin/env python
+# coding: utf-8
+
+import os, glob, cv2
+import argparse
+from argparse import Namespace
+import yaml
+from tqdm import tqdm
+import torch
+from torch.utils.data import Dataset, DataLoader, SequentialSampler
+
+from src.datasets.custom_dataloader import TestDataLoader
+from src.utils.dataset import read_img_gray
+from configs.data.base import cfg as data_cfg
+import viz
+
+
+def get_model_config(method_name, dataset_name, root_dir='viz'):
+ config_file = f'{root_dir}/configs/{method_name}.yml'
+ with open(config_file, 'r') as f:
+ model_conf = yaml.load(f, Loader=yaml.FullLoader)[dataset_name]
+ return model_conf
+
+
+class DemoDataset(Dataset):
+ def __init__(self, dataset_dir, img_file=None, resize=None, down_factor=16):
+ self.dataset_dir = dataset_dir
+ if img_file is None:
+ self.list_img_files = glob.glob(os.path.join(dataset_dir, "*.*"))
+ self.list_img_files.sort()
+ else:
+ with open(img_file) as f:
+ self.list_img_files = [os.path.join(dataset_dir, img_file.strip()) for img_file in f.readlines()]
+ self.resize = resize
+ self.down_factor = down_factor
+
+ def __len__(self):
+ return len(self.list_img_files)
+
+ def __getitem__(self, idx):
+ img_path = self.list_img_files[idx] #os.path.join(self.dataset_dir, self.list_img_files[idx])
+ img, scale = read_img_gray(img_path, resize=self.resize, down_factor=self.down_factor)
+ return {"img": img, "id": idx, "img_path": img_path}
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser(description='Visualize matches')
+ parser.add_argument('--gpu', '-gpu', type=str, default='0')
+ parser.add_argument('--method', type=str, default=None)
+ parser.add_argument('--dataset_dir', type=str, default='data/aachen-day-night')
+ parser.add_argument('--pair_dir', type=str, default=None)
+ parser.add_argument(
+ '--dataset_name', type=str, choices=['megadepth', 'scannet', 'aachen_v1.1', 'inloc'], default='megadepth'
+ )
+ parser.add_argument('--measure_time', action="store_true")
+ parser.add_argument('--no_viz', action="store_true")
+ parser.add_argument('--compute_eval_metrics', action="store_true")
+ parser.add_argument('--run_demo', action="store_true")
+
+ args = parser.parse_args()
+
+ model_cfg = get_model_config(args.method, args.dataset_name)
+ class_name = model_cfg["class"]
+ model = viz.__dict__[class_name](model_cfg)
+ # all_args = Namespace(**vars(args), **model_cfg)
+ if not args.run_demo:
+ if args.dataset_name == 'megadepth':
+ from configs.data.megadepth_test_1500 import cfg
+
+ data_cfg.merge_from_other_cfg(cfg)
+ elif args.dataset_name == 'scannet':
+ from configs.data.scannet_test_1500 import cfg
+
+ data_cfg.merge_from_other_cfg(cfg)
+ elif args.dataset_name == 'aachen_v1.1':
+ data_cfg.merge_from_list(["DATASET.TEST_DATA_SOURCE", "aachen_v1.1",
+ "DATASET.TEST_DATA_ROOT", os.path.join(args.dataset_dir, "images/images_upright"),
+ "DATASET.TEST_LIST_PATH", args.pair_dir,
+ "DATASET.TEST_IMGSIZE", model_cfg["imsize"]])
+ elif args.dataset_name == 'inloc':
+ data_cfg.merge_from_list(["DATASET.TEST_DATA_SOURCE", "inloc",
+ "DATASET.TEST_DATA_ROOT", args.dataset_dir,
+ "DATASET.TEST_LIST_PATH", args.pair_dir,
+ "DATASET.TEST_IMGSIZE", model_cfg["imsize"]])
+
+ has_ground_truth = str(data_cfg.DATASET.TEST_DATA_SOURCE).lower() in ["megadepth", "scannet"]
+ dataloader = TestDataLoader(data_cfg)
+ with torch.no_grad():
+ for data_dict in tqdm(dataloader):
+ for k, v in data_dict.items():
+ if isinstance(v, torch.Tensor):
+ data_dict[k] = v.cuda() if torch.cuda.is_available() else v
+ img_root_dir = data_cfg.DATASET.TEST_DATA_ROOT
+ model.match_and_draw(data_dict, root_dir=img_root_dir, ground_truth=has_ground_truth,
+ measure_time=args.measure_time, viz_matches=(not args.no_viz))
+
+ if args.measure_time:
+ print("Running time for each image is {} miliseconds".format(model.measure_time()))
+ if args.compute_eval_metrics and has_ground_truth:
+ model.compute_eval_metrics()
+ else:
+ demo_dataset = DemoDataset(args.dataset_dir, img_file=args.pair_dir, resize=640)
+ sampler = SequentialSampler(demo_dataset)
+ dataloader = DataLoader(demo_dataset, batch_size=1, sampler=sampler)
+
+ writer = cv2.VideoWriter('topicfm_demo.mp4', cv2.VideoWriter_fourcc(*'mp4v'), 15, (640 * 2 + 5, 480 * 2 + 10))
+
+ model.run_demo(iter(dataloader), writer) #, output_dir="demo", no_display=True)
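The script above resolves the visualizer class by the `class` name stored in the per-method YAML config (`viz.__dict__[class_name]`). A minimal sketch of that lookup, assuming the TopicFM repo root (and the third-party matchers its `viz` package imports) is on `PYTHONPATH`; the inline config dict is a placeholder standing in for `viz/configs/topicfm.yml`:

```python
import viz  # the viz package added in this patch; exposes VizTopicFM, VizLoFTR, VizPatch2Pix

# Placeholder config mirroring the structure of viz/configs/topicfm.yml
model_cfg = {
    "class": "VizTopicFM",
    "ckpt": "pretrained/model_best.ckpt",
    "match_threshold": 0.2,
    "n_sampling_topics": 10,
    "show_n_topics": 6,
}

viz_cls = viz.__dict__[model_cfg["class"]]  # name-based lookup, same as in visualization.py
matcher = viz_cls(model_cfg)                # each Viz subclass accepts a dict or a Namespace
```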
diff --git a/imcui/third_party/TopicFM/viz/__init__.py b/imcui/third_party/TopicFM/viz/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f0efac33299da6fb8195ce70bcb9eb210f6cf658
--- /dev/null
+++ b/imcui/third_party/TopicFM/viz/__init__.py
@@ -0,0 +1,3 @@
+from .methods.patch2pix import VizPatch2Pix
+from .methods.loftr import VizLoFTR
+from .methods.topicfm import VizTopicFM
diff --git a/imcui/third_party/TopicFM/viz/configs/__init__.py b/imcui/third_party/TopicFM/viz/configs/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/imcui/third_party/TopicFM/viz/configs/loftr.yml b/imcui/third_party/TopicFM/viz/configs/loftr.yml
new file mode 100644
index 0000000000000000000000000000000000000000..776d625ac8ad5a0b4e4a4e65e2b99f62662bc3fc
--- /dev/null
+++ b/imcui/third_party/TopicFM/viz/configs/loftr.yml
@@ -0,0 +1,18 @@
+default: &default
+ class: 'VizLoFTR'
+ ckpt: 'third_party/loftr/pretrained/outdoor_ds.ckpt'
+ match_threshold: 0.2
+megadepth:
+ <<: *default
+scannet:
+ <<: *default
+hpatch:
+ <<: *default
+inloc:
+ <<: *default
+ imsize: 1024
+ match_threshold: 0.3
+aachen_v1.1:
+ <<: *default
+ imsize: 1024
+ match_threshold: 0.3
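These method configs lean on YAML anchors and merge keys: `default` is declared once and pulled into every dataset section with `<<: *default`, and any key set in a section overrides the inherited value. A quick check of how PyYAML (the loader used in `visualization.py`) resolves this, on a trimmed-down copy of the file above:

```python
import yaml

snippet = """
default: &default
  class: 'VizLoFTR'
  match_threshold: 0.2
inloc:
  <<: *default
  imsize: 1024
  match_threshold: 0.3
"""

conf = yaml.load(snippet, Loader=yaml.FullLoader)
assert conf["inloc"]["class"] == "VizLoFTR"     # inherited from the anchor
assert conf["inloc"]["match_threshold"] == 0.3  # overridden per dataset
```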
diff --git a/imcui/third_party/TopicFM/viz/configs/patch2pix.yml b/imcui/third_party/TopicFM/viz/configs/patch2pix.yml
new file mode 100644
index 0000000000000000000000000000000000000000..5e3efa7889098425aaf586bd7b88fc28feb74778
--- /dev/null
+++ b/imcui/third_party/TopicFM/viz/configs/patch2pix.yml
@@ -0,0 +1,19 @@
+default: &default
+ class: 'VizPatch2Pix'
+ ckpt: 'third_party/patch2pix/pretrained/patch2pix_pretrained.pth'
+ ksize: 2
+ imsize: 1024
+ match_threshold: 0.25
+megadepth:
+ <<: *default
+ imsize: 1200
+scannet:
+ <<: *default
+ imsize: [640, 480]
+hpatch:
+ <<: *default
+inloc:
+ <<: *default
+aachen_v1.1:
+ <<: *default
+ imsize: 1024
diff --git a/imcui/third_party/TopicFM/viz/configs/topicfm.yml b/imcui/third_party/TopicFM/viz/configs/topicfm.yml
new file mode 100644
index 0000000000000000000000000000000000000000..7a8071a6fcd8def21dbfec5b9b2b10200f494eee
--- /dev/null
+++ b/imcui/third_party/TopicFM/viz/configs/topicfm.yml
@@ -0,0 +1,29 @@
+default: &default
+ class: 'VizTopicFM'
+ ckpt: 'pretrained/model_best.ckpt'
+ match_threshold: 0.2
+ n_sampling_topics: 4
+ show_n_topics: 4
+megadepth:
+ <<: *default
+ n_sampling_topics: 10
+ show_n_topics: 6
+scannet:
+ <<: *default
+ match_threshold: 0.3
+ n_sampling_topics: 5
+ show_n_topics: 4
+hpatch:
+ <<: *default
+inloc:
+ <<: *default
+ imsize: 1024
+ match_threshold: 0.3
+ n_sampling_topics: 8
+ show_n_topics: 4
+aachen_v1.1:
+ <<: *default
+ imsize: 1024
+ match_threshold: 0.3
+ n_sampling_topics: 6
+ show_n_topics: 6
diff --git a/imcui/third_party/TopicFM/viz/methods/__init__.py b/imcui/third_party/TopicFM/viz/methods/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/imcui/third_party/TopicFM/viz/methods/base.py b/imcui/third_party/TopicFM/viz/methods/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..377e95134f339459bff3c5a0d30b3bfbc122d978
--- /dev/null
+++ b/imcui/third_party/TopicFM/viz/methods/base.py
@@ -0,0 +1,59 @@
+import pprint
+from abc import ABCMeta, abstractmethod
+import torch
+from itertools import chain
+
+from src.utils.plotting import make_matching_figure, error_colormap
+from src.utils.metrics import aggregate_metrics
+
+
+def flatten_list(x):
+ return list(chain(*x))
+
+
+class Viz(metaclass=ABCMeta):
+ def __init__(self):
+ super().__init__()
+ self.device = torch.device('cuda:{}'.format(0) if torch.cuda.is_available() else 'cpu')
+ torch.set_grad_enabled(False)
+
+ # for evaluation metrics of MegaDepth and ScanNet
+ self.eval_stats = []
+ self.time_stats = []
+
+ def draw_matches(self, mkpts0, mkpts1, img0, img1, conf, path=None, **kwargs):
+ thr = 5e-4
+ # mkpts0 = pe['mkpts0_f'].cpu().numpy()
+ # mkpts1 = pe['mkpts1_f'].cpu().numpy()
+ if "conf_thr" in kwargs:
+ thr = kwargs["conf_thr"]
+ color = error_colormap(conf, thr, alpha=0.1)
+
+ text = [
+ f"{self.name}",
+ f"#Matches: {len(mkpts0)}",
+ ]
+ if 'R_errs' in kwargs:
+ text.append(f"$\\Delta$R:{kwargs['R_errs']:.2f}°, $\\Delta$t:{kwargs['t_errs']:.2f}°",)
+
+ if path:
+ make_matching_figure(img0, img1, mkpts0, mkpts1, color, text=text, path=path, dpi=150)
+ else:
+ return make_matching_figure(img0, img1, mkpts0, mkpts1, color, text=text)
+
+ @abstractmethod
+ def match_and_draw(self, data_dict, **kwargs):
+ pass
+
+ def compute_eval_metrics(self, epi_err_thr=5e-4):
+ # metrics: dict of list, numpy
+ _metrics = [o['metrics'] for o in self.eval_stats]
+ metrics = {k: flatten_list([_me[k] for _me in _metrics]) for k in _metrics[0]}
+
+ val_metrics_4tb = aggregate_metrics(metrics, epi_err_thr)
+ print('\n' + pprint.pformat(val_metrics_4tb))
+
+ def measure_time(self):
+ if len(self.time_stats) == 0:
+ return 0
+ return sum(self.time_stats) / len(self.time_stats)
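`Viz` is an abstract base class: a concrete visualizer only has to implement `match_and_draw` and can reuse `draw_matches`, `compute_eval_metrics`, and `measure_time`. A hypothetical skeleton subclass, just to show the contract; the matcher it would wrap is a placeholder, not something shipped in this patch:

```python
from viz.methods.base import Viz


class VizMyMatcher(Viz):
    """Hypothetical example subclass; only the structure mirrors the real ones."""

    def __init__(self, args):
        super().__init__()
        self.name = 'MyMatcher'
        # A real implementation would build and load a matching model here.

    def match_and_draw(self, data_dict, **kwargs):
        # A real implementation would run the matcher on data_dict, then call
        # self.draw_matches(mkpts0, mkpts1, img0, img1, conf, path=...), and
        # append a {'metrics': ...} dict to self.eval_stats when ground-truth
        # poses are available.
        pass
```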
diff --git a/imcui/third_party/TopicFM/viz/methods/loftr.py b/imcui/third_party/TopicFM/viz/methods/loftr.py
new file mode 100644
index 0000000000000000000000000000000000000000..53d0c00c1a067cee10bf1587197e4780ac8b2eda
--- /dev/null
+++ b/imcui/third_party/TopicFM/viz/methods/loftr.py
@@ -0,0 +1,85 @@
+from argparse import Namespace
+import os
+import torch
+import cv2
+
+from .base import Viz
+from src.utils.metrics import compute_symmetrical_epipolar_errors, compute_pose_errors
+
+from third_party.loftr.src.loftr import LoFTR, default_cfg
+
+
+class VizLoFTR(Viz):
+ def __init__(self, args):
+ super().__init__()
+ if type(args) == dict:
+ args = Namespace(**args)
+
+ self.match_threshold = args.match_threshold
+
+ # Load model
+ conf = dict(default_cfg)
+ conf['match_coarse']['thr'] = self.match_threshold
+ print(conf)
+ self.model = LoFTR(config=conf)
+ ckpt_dict = torch.load(args.ckpt)
+ self.model.load_state_dict(ckpt_dict['state_dict'])
+ self.model = self.model.eval().to(self.device)
+
+ # Name the method
+ # self.ckpt_name = args.ckpt.split('/')[-1].split('.')[0]
+ self.name = 'LoFTR'
+
+ print(f'Initialize {self.name}')
+
+ def match_and_draw(self, data_dict, root_dir=None, ground_truth=False, measure_time=False, viz_matches=True):
+ if measure_time:
+ torch.cuda.synchronize()
+ start = torch.cuda.Event(enable_timing=True)
+ end = torch.cuda.Event(enable_timing=True)
+ start.record()
+ self.model(data_dict)
+ if measure_time:
+ torch.cuda.synchronize()
+ end.record()
+ torch.cuda.synchronize()
+ self.time_stats.append(start.elapsed_time(end))
+
+ kpts0 = data_dict['mkpts0_f'].cpu().numpy()
+ kpts1 = data_dict['mkpts1_f'].cpu().numpy()
+
+ img_name0, img_name1 = list(zip(*data_dict['pair_names']))[0]
+ img0 = cv2.imread(os.path.join(root_dir, img_name0))
+ img1 = cv2.imread(os.path.join(root_dir, img_name1))
+ if str(data_dict["dataset_name"][0]).lower() == 'scannet':
+ img0 = cv2.resize(img0, (640, 480))
+ img1 = cv2.resize(img1, (640, 480))
+
+ if viz_matches:
+ saved_name = "_".join([img_name0.split('/')[-1].split('.')[0], img_name1.split('/')[-1].split('.')[0]])
+ folder_matches = os.path.join(root_dir, "{}_viz_matches".format(self.name))
+ if not os.path.exists(folder_matches):
+ os.makedirs(folder_matches)
+ path_to_save_matches = os.path.join(folder_matches, "{}.png".format(saved_name))
+ if ground_truth:
+ compute_symmetrical_epipolar_errors(data_dict) # compute epi_errs for each match
+ compute_pose_errors(data_dict) # compute R_errs, t_errs, pose_errs for each pair
+ epi_errors = data_dict['epi_errs'].cpu().numpy()
+ R_errors, t_errors = data_dict['R_errs'][0], data_dict['t_errs'][0]
+
+ self.draw_matches(kpts0, kpts1, img0, img1, epi_errors, path=path_to_save_matches,
+ R_errs=R_errors, t_errs=t_errors)
+
+ rel_pair_names = list(zip(*data_dict['pair_names']))
+ bs = data_dict['image0'].size(0)
+ metrics = {
+ # to filter duplicate pairs caused by DistributedSampler
+ 'identifiers': ['#'.join(rel_pair_names[b]) for b in range(bs)],
+ 'epi_errs': [data_dict['epi_errs'][data_dict['m_bids'] == b].cpu().numpy() for b in range(bs)],
+ 'R_errs': data_dict['R_errs'],
+ 't_errs': data_dict['t_errs'],
+ 'inliers': data_dict['inliers']}
+ self.eval_stats.append({'metrics': metrics})
+ else:
+ m_conf = 1 - data_dict["mconf"].cpu().numpy()
+ self.draw_matches(kpts0, kpts1, img0, img1, m_conf, path=path_to_save_matches, conf_thr=0.4)
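`match_and_draw` times the forward pass with CUDA events rather than wall-clock timers, because CUDA kernels are launched asynchronously; the `torch.cuda.synchronize()` calls around the events make sure only finished GPU work is measured. The same idiom in isolation (the `model`/`data_dict` arguments are placeholders):

```python
import torch

def gpu_time_ms(model, data_dict):
    """Return the GPU time of one forward pass in milliseconds (CUDA-event idiom)."""
    torch.cuda.synchronize()                       # drain pending work first
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    model(data_dict)                               # the timed forward pass
    end.record()
    torch.cuda.synchronize()                       # wait so elapsed_time() is valid
    return start.elapsed_time(end)                 # milliseconds
```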
diff --git a/imcui/third_party/TopicFM/viz/methods/patch2pix.py b/imcui/third_party/TopicFM/viz/methods/patch2pix.py
new file mode 100644
index 0000000000000000000000000000000000000000..14a1d345881e2021be97dc5dde91d8bbe1cd18fa
--- /dev/null
+++ b/imcui/third_party/TopicFM/viz/methods/patch2pix.py
@@ -0,0 +1,80 @@
+from argparse import Namespace
+import os, sys
+import torch
+import cv2
+from pathlib import Path
+
+from .base import Viz
+from src.utils.metrics import compute_symmetrical_epipolar_errors, compute_pose_errors
+
+patch2pix_path = Path(__file__).parent / '../../third_party/patch2pix'
+sys.path.append(str(patch2pix_path))
+from third_party.patch2pix.utils.eval.model_helper import load_model, estimate_matches
+
+
+class VizPatch2Pix(Viz):
+ def __init__(self, args):
+ super().__init__()
+
+ if type(args) == dict:
+ args = Namespace(**args)
+ self.imsize = args.imsize
+ self.match_threshold = args.match_threshold
+ self.ksize = args.ksize
+ self.model = load_model(args.ckpt, method='patch2pix')
+ self.name = 'Patch2Pix'
+ print(f'Initialize {self.name} with image size {self.imsize}')
+
+ def match_and_draw(self, data_dict, root_dir=None, ground_truth=False, measure_time=False, viz_matches=True):
+ img_name0, img_name1 = list(zip(*data_dict['pair_names']))[0]
+ path_img0 = os.path.join(root_dir, img_name0)
+ path_img1 = os.path.join(root_dir, img_name1)
+ img0, img1 = cv2.imread(path_img0), cv2.imread(path_img1)
+ return_m_upscale = True
+ if str(data_dict["dataset_name"][0]).lower() == 'scannet':
+ # self.imsize = 640
+ img0 = cv2.resize(img0, tuple(self.imsize)) # (640, 480))
+ img1 = cv2.resize(img1, tuple(self.imsize)) # (640, 480))
+ return_m_upscale = False
+ outputs = estimate_matches(self.model, path_img0, path_img1,
+ ksize=self.ksize, io_thres=self.match_threshold,
+ eval_type='fine', imsize=self.imsize,
+ return_upscale=return_m_upscale, measure_time=measure_time)
+ if measure_time:
+ self.time_stats.append(outputs[-1])
+ matches, mconf = outputs[0], outputs[1]
+ kpts0 = matches[:, :2]
+ kpts1 = matches[:, 2:4]
+
+ if viz_matches:
+ saved_name = "_".join([img_name0.split('/')[-1].split('.')[0], img_name1.split('/')[-1].split('.')[0]])
+ folder_matches = os.path.join(root_dir, "{}_viz_matches".format(self.name))
+ if not os.path.exists(folder_matches):
+ os.makedirs(folder_matches)
+ path_to_save_matches = os.path.join(folder_matches, "{}.png".format(saved_name))
+
+ if ground_truth:
+ data_dict["mkpts0_f"] = torch.from_numpy(matches[:, :2]).float().to(self.device)
+ data_dict["mkpts1_f"] = torch.from_numpy(matches[:, 2:4]).float().to(self.device)
+ data_dict["m_bids"] = torch.zeros(matches.shape[0], device=self.device, dtype=torch.float32)
+ compute_symmetrical_epipolar_errors(data_dict) # compute epi_errs for each match
+ compute_pose_errors(data_dict) # compute R_errs, t_errs, pose_errs for each pair
+ epi_errors = data_dict['epi_errs'].cpu().numpy()
+ R_errors, t_errors = data_dict['R_errs'][0], data_dict['t_errs'][0]
+
+ self.draw_matches(kpts0, kpts1, img0, img1, epi_errors, path=path_to_save_matches,
+ R_errs=R_errors, t_errs=t_errors)
+
+ rel_pair_names = list(zip(*data_dict['pair_names']))
+ bs = data_dict['image0'].size(0)
+ metrics = {
+ # to filter duplicate pairs caused by DistributedSampler
+ 'identifiers': ['#'.join(rel_pair_names[b]) for b in range(bs)],
+ 'epi_errs': [data_dict['epi_errs'][data_dict['m_bids'] == b].cpu().numpy() for b in range(bs)],
+ 'R_errs': data_dict['R_errs'],
+ 't_errs': data_dict['t_errs'],
+ 'inliers': data_dict['inliers']}
+ self.eval_stats.append({'metrics': metrics})
+ else:
+ m_conf = 1 - mconf
+ self.draw_matches(kpts0, kpts1, img0, img1, m_conf, path=path_to_save_matches, conf_thr=0.4)
diff --git a/imcui/third_party/TopicFM/viz/methods/topicfm.py b/imcui/third_party/TopicFM/viz/methods/topicfm.py
new file mode 100644
index 0000000000000000000000000000000000000000..cd8b1485d5296947a38480cc031c5d7439bf163d
--- /dev/null
+++ b/imcui/third_party/TopicFM/viz/methods/topicfm.py
@@ -0,0 +1,198 @@
+from argparse import Namespace
+import os
+import torch
+import cv2
+from time import time
+from pathlib import Path
+import matplotlib.cm as cm
+import numpy as np
+
+from src.models.topic_fm import TopicFM
+from src import get_model_cfg
+from .base import Viz
+from src.utils.metrics import compute_symmetrical_epipolar_errors, compute_pose_errors
+from src.utils.plotting import draw_topics, draw_topicfm_demo, error_colormap
+
+
+class VizTopicFM(Viz):
+ def __init__(self, args):
+ super().__init__()
+ if type(args) == dict:
+ args = Namespace(**args)
+
+ self.match_threshold = args.match_threshold
+ self.n_sampling_topics = args.n_sampling_topics
+ self.show_n_topics = args.show_n_topics
+
+ # Load model
+ conf = dict(get_model_cfg())
+ conf['match_coarse']['thr'] = self.match_threshold
+ conf['coarse']['n_samples'] = self.n_sampling_topics
+ print("model config: ", conf)
+ self.model = TopicFM(config=conf)
+ ckpt_dict = torch.load(args.ckpt)
+ self.model.load_state_dict(ckpt_dict['state_dict'])
+ self.model = self.model.eval().to(self.device)
+
+ # Name the method
+ # self.ckpt_name = args.ckpt.split('/')[-1].split('.')[0]
+ self.name = 'TopicFM'
+
+ print(f'Initialize {self.name}')
+
+ def match_and_draw(self, data_dict, root_dir=None, ground_truth=False, measure_time=False, viz_matches=True):
+ if measure_time:
+ torch.cuda.synchronize()
+ start = torch.cuda.Event(enable_timing=True)
+ end = torch.cuda.Event(enable_timing=True)
+ start.record()
+ self.model(data_dict)
+ if measure_time:
+ torch.cuda.synchronize()
+ end.record()
+ torch.cuda.synchronize()
+ self.time_stats.append(start.elapsed_time(end))
+
+ kpts0 = data_dict['mkpts0_f'].cpu().numpy()
+ kpts1 = data_dict['mkpts1_f'].cpu().numpy()
+
+ img_name0, img_name1 = list(zip(*data_dict['pair_names']))[0]
+ img0 = cv2.imread(os.path.join(root_dir, img_name0))
+ img1 = cv2.imread(os.path.join(root_dir, img_name1))
+ if str(data_dict["dataset_name"][0]).lower() == 'scannet':
+ img0 = cv2.resize(img0, (640, 480))
+ img1 = cv2.resize(img1, (640, 480))
+
+ if viz_matches:
+ saved_name = "_".join([img_name0.split('/')[-1].split('.')[0], img_name1.split('/')[-1].split('.')[0]])
+ folder_matches = os.path.join(root_dir, "{}_viz_matches".format(self.name))
+ if not os.path.exists(folder_matches):
+ os.makedirs(folder_matches)
+ path_to_save_matches = os.path.join(folder_matches, "{}.png".format(saved_name))
+
+ if ground_truth:
+ compute_symmetrical_epipolar_errors(data_dict) # compute epi_errs for each match
+ compute_pose_errors(data_dict) # compute R_errs, t_errs, pose_errs for each pair
+ epi_errors = data_dict['epi_errs'].cpu().numpy()
+ R_errors, t_errors = data_dict['R_errs'][0], data_dict['t_errs'][0]
+
+ self.draw_matches(kpts0, kpts1, img0, img1, epi_errors, path=path_to_save_matches,
+ R_errs=R_errors, t_errs=t_errors)
+
+ # compute evaluation metrics
+ rel_pair_names = list(zip(*data_dict['pair_names']))
+ bs = data_dict['image0'].size(0)
+ metrics = {
+ # to filter duplicate pairs caused by DistributedSampler
+ 'identifiers': ['#'.join(rel_pair_names[b]) for b in range(bs)],
+ 'epi_errs': [data_dict['epi_errs'][data_dict['m_bids'] == b].cpu().numpy() for b in range(bs)],
+ 'R_errs': data_dict['R_errs'],
+ 't_errs': data_dict['t_errs'],
+ 'inliers': data_dict['inliers']}
+ self.eval_stats.append({'metrics': metrics})
+ else:
+ m_conf = 1 - data_dict["mconf"].cpu().numpy()
+ self.draw_matches(kpts0, kpts1, img0, img1, m_conf, path=path_to_save_matches, conf_thr=0.4)
+ if self.show_n_topics > 0:
+ folder_topics = os.path.join(root_dir, "{}_viz_topics".format(self.name))
+ if not os.path.exists(folder_topics):
+ os.makedirs(folder_topics)
+ draw_topics(data_dict, img0, img1, saved_folder=folder_topics, show_n_topics=self.show_n_topics,
+ saved_name=saved_name)
+
+ def run_demo(self, dataloader, writer=None, output_dir=None, no_display=False, skip_frames=1):
+ data_dict = next(dataloader)
+
+ frame_id = 0
+ last_image_id = 0
+ img0 = np.array(cv2.imread(str(data_dict["img_path"][0])), dtype=np.float32) / 255
+ frame_tensor = data_dict["img"].to(self.device)
+ pair_data = {'image0': frame_tensor}
+ last_frame = cv2.resize(img0, (frame_tensor.shape[-1], frame_tensor.shape[-2]), interpolation=cv2.INTER_LINEAR)
+
+ if output_dir is not None:
+ print('==> Will write outputs to {}'.format(output_dir))
+ Path(output_dir).mkdir(exist_ok=True)
+
+ # Create a window to display the demo.
+ if not no_display:
+ window_name = 'Topic-assisted Feature Matching'
+ cv2.namedWindow(window_name, cv2.WINDOW_NORMAL)
+ cv2.resizeWindow(window_name, (640 * 2, 480 * 2))
+ else:
+ print('Skipping visualization, will not show a GUI.')
+
+ # Print the keyboard help menu.
+ print('==> Keyboard control:\n'
+ '\tn: select the current frame as the reference image (left)\n'
+ '\tq: quit')
+
+ # vis_range = [kwargs["bottom_k"], kwargs["top_k"]]
+
+ while True:
+ frame_id += 1
+ if frame_id == len(dataloader):
+ print('Finished demo')
+ break
+ data_dict = next(dataloader)
+ if frame_id % skip_frames != 0:
+ # print("Skipping frame.")
+ continue
+
+ stem0, stem1 = last_image_id, data_dict["id"][0].item() - 1
+ frame = np.array(cv2.imread(str(data_dict["img_path"][0])), dtype=np.float32) / 255
+
+ frame_tensor = data_dict["img"].to(self.device)
+ frame = cv2.resize(frame, (frame_tensor.shape[-1], frame_tensor.shape[-2]), interpolation=cv2.INTER_LINEAR)
+ pair_data = {**pair_data, 'image1': frame_tensor}
+ self.model(pair_data)
+
+ total_n_matches = len(pair_data['mkpts0_f'])
+ mkpts0 = pair_data['mkpts0_f'].cpu().numpy() # [vis_range[0]:vis_range[1]]
+ mkpts1 = pair_data['mkpts1_f'].cpu().numpy() # [vis_range[0]:vis_range[1]]
+ mconf = pair_data['mconf'].cpu().numpy() # [vis_range[0]:vis_range[1]]
+
+ # Normalize confidence.
+ if len(mconf) > 0:
+ mconf = 1 - mconf
+
+ # alpha = 0
+ # color = cm.jet(mconf, alpha=alpha)
+ color = error_colormap(mconf, thr=0.4, alpha=0.1)
+
+ text = [
+ f'Topics',
+ '#Matches: {}'.format(total_n_matches),
+ ]
+
+ out = draw_topicfm_demo(pair_data, last_frame, frame, mkpts0, mkpts1, color, text,
+ show_n_topics=4, path=None)
+
+ if not no_display:
+ if writer is not None:
+ writer.write(out)
+ cv2.imshow('TopicFM Matches', out)
+ key = chr(cv2.waitKey(10) & 0xFF)
+ if key == 'q':
+ if writer is not None:
+ writer.release()
+ print('Exiting...')
+ break
+ elif key == 'n':
+ pair_data['image0'] = frame_tensor
+ last_frame = frame
+ last_image_id = (data_dict["id"][0].item() - 1)
+ frame_id_left = frame_id
+
+ elif output_dir is not None:
+ stem = 'matches_{:06}_{:06}'.format(stem0, stem1)
+ out_file = str(Path(output_dir, stem + '.png'))
+ print('\nWriting image to {}'.format(out_file))
+ cv2.imwrite(out_file, out)
+ else:
+ raise ValueError("output_dir is required when no display is given.")
+
+ cv2.destroyAllWindows()
+ if writer is not None:
+ writer.release()
+
diff --git a/imcui/third_party/XoFTR/configs/data/__init__.py b/imcui/third_party/XoFTR/configs/data/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/imcui/third_party/XoFTR/configs/data/base.py b/imcui/third_party/XoFTR/configs/data/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..03aab160fa4137ccc04380f94854a56fbb549074
--- /dev/null
+++ b/imcui/third_party/XoFTR/configs/data/base.py
@@ -0,0 +1,35 @@
+"""
+The data config will be the last one merged into the main config.
+Setups in data configs will override all existing setups!
+"""
+
+from yacs.config import CfgNode as CN
+_CN = CN()
+_CN.DATASET = CN()
+_CN.TRAINER = CN()
+
+# training data config
+_CN.DATASET.TRAIN_DATA_ROOT = None
+_CN.DATASET.TRAIN_POSE_ROOT = None
+_CN.DATASET.TRAIN_NPZ_ROOT = None
+_CN.DATASET.TRAIN_LIST_PATH = None
+_CN.DATASET.TRAIN_INTRINSIC_PATH = None
+# validation set config
+_CN.DATASET.VAL_DATA_ROOT = None
+_CN.DATASET.VAL_POSE_ROOT = None
+_CN.DATASET.VAL_NPZ_ROOT = None
+_CN.DATASET.VAL_LIST_PATH = None
+_CN.DATASET.VAL_INTRINSIC_PATH = None
+
+# testing data config
+_CN.DATASET.TEST_DATA_ROOT = None
+_CN.DATASET.TEST_POSE_ROOT = None
+_CN.DATASET.TEST_NPZ_ROOT = None
+_CN.DATASET.TEST_LIST_PATH = None
+_CN.DATASET.TEST_INTRINSIC_PATH = None
+
+# dataset config
+_CN.DATASET.MIN_OVERLAP_SCORE_TRAIN = 0.4
+_CN.DATASET.MIN_OVERLAP_SCORE_TEST = 0.0 # for both test and val
+
+cfg = _CN
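The docstring's "last one merged" rule is simply the call order used by the training scripts (see `pretrain.py` below): the main config is merged first and the data config second, so data-config values win on conflicts. A short sketch with hypothetical paths:

```python
from src.config.default import get_cfg_defaults

config = get_cfg_defaults()
# Hypothetical paths; yacs can also merge .py config files that expose a `cfg` node,
# which is how these configs are consumed by the training scripts.
config.merge_from_file("configs/xoftr/pretrain/pretrain.py")  # main config first
config.merge_from_file("configs/data/pretrain.py")            # data config last, overrides on conflict
```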
diff --git a/imcui/third_party/XoFTR/configs/data/megadepth_trainval_840.py b/imcui/third_party/XoFTR/configs/data/megadepth_trainval_840.py
new file mode 100644
index 0000000000000000000000000000000000000000..130212c3e7d55310cb822a37e026655aff9c346f
--- /dev/null
+++ b/imcui/third_party/XoFTR/configs/data/megadepth_trainval_840.py
@@ -0,0 +1,22 @@
+from configs.data.base import cfg
+
+
+TRAIN_BASE_PATH = "data/megadepth/index"
+cfg.DATASET.TRAINVAL_DATA_SOURCE = "MegaDepth"
+cfg.DATASET.TRAIN_DATA_ROOT = "data/megadepth/train"
+cfg.DATASET.TRAIN_NPZ_ROOT = f"{TRAIN_BASE_PATH}/scene_info_0.1_0.7"
+cfg.DATASET.TRAIN_LIST_PATH = f"{TRAIN_BASE_PATH}/trainvaltest_list/train_list.txt"
+cfg.DATASET.MIN_OVERLAP_SCORE_TRAIN = 0.0
+
+TEST_BASE_PATH = "data/megadepth/index"
+cfg.DATASET.TEST_DATA_SOURCE = "MegaDepth"
+cfg.DATASET.VAL_DATA_ROOT = cfg.DATASET.TEST_DATA_ROOT = "data/megadepth/test"
+cfg.DATASET.VAL_NPZ_ROOT = cfg.DATASET.TEST_NPZ_ROOT = f"{TEST_BASE_PATH}/scene_info_val_1500"
+cfg.DATASET.VAL_LIST_PATH = cfg.DATASET.TEST_LIST_PATH = f"{TEST_BASE_PATH}/trainvaltest_list/val_list.txt"
+cfg.DATASET.MIN_OVERLAP_SCORE_TEST = 0.0 # for both test and val
+
+# 368 scenes in total for MegaDepth
+# (with difficulty balanced (further split each scene to 3 sub-scenes))
+cfg.TRAINER.N_SAMPLES_PER_SUBSET = 100
+
+cfg.DATASET.MGDPT_IMG_RESIZE = 840 # for training on 32GB mem GPUs
diff --git a/imcui/third_party/XoFTR/configs/data/megadepth_vistir_trainval_640.py b/imcui/third_party/XoFTR/configs/data/megadepth_vistir_trainval_640.py
new file mode 100644
index 0000000000000000000000000000000000000000..ab33112a8abc8a6c89993bed240f98bd51d38e2d
--- /dev/null
+++ b/imcui/third_party/XoFTR/configs/data/megadepth_vistir_trainval_640.py
@@ -0,0 +1,23 @@
+from configs.data.base import cfg
+
+
+TRAIN_BASE_PATH = "data/megadepth/index"
+cfg.DATASET.TRAIN_DATA_SOURCE = "MegaDepth"
+cfg.DATASET.TRAIN_DATA_ROOT = "data/megadepth/train"
+cfg.DATASET.TRAIN_NPZ_ROOT = f"{TRAIN_BASE_PATH}/scene_info_0.1_0.7"
+cfg.DATASET.TRAIN_LIST_PATH = f"{TRAIN_BASE_PATH}/trainvaltest_list/train_list.txt"
+cfg.DATASET.MIN_OVERLAP_SCORE_TRAIN = 0.0
+
+VAL_BASE_PATH = "data/METU_VisTIR/index"
+cfg.DATASET.TEST_DATA_SOURCE = "MegaDepth"
+cfg.DATASET.VAL_DATA_SOURCE = "VisTir"
+cfg.DATASET.VAL_DATA_ROOT = cfg.DATASET.TEST_DATA_ROOT = "data/METU_VisTIR"
+cfg.DATASET.VAL_NPZ_ROOT = cfg.DATASET.TEST_NPZ_ROOT = f"{VAL_BASE_PATH}/scene_info_val"
+cfg.DATASET.VAL_LIST_PATH = cfg.DATASET.TEST_LIST_PATH = f"{VAL_BASE_PATH}/val_test_list/val_list.txt"
+cfg.DATASET.MIN_OVERLAP_SCORE_TEST = 0.0 # for both test and val
+
+# 368 scenes in total for MegaDepth
+# (with difficulty balanced (further split each scene to 3 sub-scenes))
+cfg.TRAINER.N_SAMPLES_PER_SUBSET = 100
+
+cfg.DATASET.MGDPT_IMG_RESIZE = 640 # for training on 11GB mem GPUs
diff --git a/imcui/third_party/XoFTR/configs/data/pretrain.py b/imcui/third_party/XoFTR/configs/data/pretrain.py
new file mode 100644
index 0000000000000000000000000000000000000000..0dba89ca2623985e51890e04a462f46a492395f9
--- /dev/null
+++ b/imcui/third_party/XoFTR/configs/data/pretrain.py
@@ -0,0 +1,8 @@
+from configs.data.base import cfg
+
+cfg.DATASET.TRAIN_DATA_SOURCE = "KAIST"
+cfg.DATASET.TRAIN_DATA_ROOT = "data/kaist-cvpr15"
+cfg.DATASET.VAL_DATA_SOURCE = "KAIST"
+cfg.DATASET.VAL_DATA_ROOT = cfg.DATASET.TEST_DATA_ROOT = "data/kaist-cvpr15"
+
+cfg.DATASET.PRETRAIN_IMG_RESIZE = 640
diff --git a/imcui/third_party/XoFTR/configs/xoftr/outdoor/visible_thermal.py b/imcui/third_party/XoFTR/configs/xoftr/outdoor/visible_thermal.py
new file mode 100644
index 0000000000000000000000000000000000000000..61aa1a96853671638223ecb65075ea32675e78e9
--- /dev/null
+++ b/imcui/third_party/XoFTR/configs/xoftr/outdoor/visible_thermal.py
@@ -0,0 +1,17 @@
+from src.config.default import _CN as cfg
+
+cfg.XOFTR.MATCH_COARSE.MATCH_TYPE = 'dual_softmax'
+
+cfg.TRAINER.CANONICAL_LR = 8e-3
+cfg.TRAINER.WARMUP_STEP = 1875 # 3 epochs
+cfg.TRAINER.WARMUP_RATIO = 0.1
+cfg.TRAINER.MSLR_MILESTONES = [8, 12, 16, 20, 24, 30, 36, 42]
+
+# pose estimation
+cfg.TRAINER.RANSAC_PIXEL_THR = 1.5
+
+cfg.TRAINER.OPTIMIZER = "adamw"
+cfg.TRAINER.ADAMW_DECAY = 0.1
+cfg.XOFTR.MATCH_COARSE.TRAIN_COARSE_PERCENT = 0.3
+
+cfg.TRAINER.USE_WANDB = True # use Weights & Biases
diff --git a/imcui/third_party/XoFTR/configs/xoftr/pretrain/pretrain.py b/imcui/third_party/XoFTR/configs/xoftr/pretrain/pretrain.py
new file mode 100644
index 0000000000000000000000000000000000000000..252146af93224a4449dbcaf13b7274573c3aba16
--- /dev/null
+++ b/imcui/third_party/XoFTR/configs/xoftr/pretrain/pretrain.py
@@ -0,0 +1,12 @@
+from src.config.default import _CN as cfg
+
+cfg.TRAINER.CANONICAL_LR = 4e-3
+cfg.TRAINER.WARMUP_STEP = 1250 # 2 epochs
+cfg.TRAINER.WARMUP_RATIO = 0.1
+cfg.TRAINER.MSLR_MILESTONES = [4, 6, 8, 10, 12, 14, 16, 18]
+
+cfg.TRAINER.OPTIMIZER = "adamw"
+cfg.TRAINER.ADAMW_DECAY = 0.1
+
+cfg.TRAINER.USE_WANDB = True # use Weights & Biases
+
diff --git a/imcui/third_party/XoFTR/environment.yaml b/imcui/third_party/XoFTR/environment.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..c6a25b9b6d1bc38d38010da55efb309901a0a1d2
--- /dev/null
+++ b/imcui/third_party/XoFTR/environment.yaml
@@ -0,0 +1,14 @@
+name: xoftr
+channels:
+ # - https://dx-mirrors.sensetime.com/anaconda/cloud/pytorch
+ - pytorch
+ - nvidia
+ - conda-forge
+ - defaults
+dependencies:
+ - python=3.8
+ - pytorch=2.0.1
+ - pytorch-cuda=11.8
+ - pip
+ - pip:
+ - -r requirements.txt
diff --git a/imcui/third_party/XoFTR/pretrain.py b/imcui/third_party/XoFTR/pretrain.py
new file mode 100644
index 0000000000000000000000000000000000000000..d861315931714044e6915c5b2043218e06aed913
--- /dev/null
+++ b/imcui/third_party/XoFTR/pretrain.py
@@ -0,0 +1,125 @@
+import math
+import argparse
+import pprint
+from distutils.util import strtobool
+from pathlib import Path
+from loguru import logger as loguru_logger
+from datetime import datetime
+
+import pytorch_lightning as pl
+from pytorch_lightning.utilities import rank_zero_only
+from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger
+from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
+from pytorch_lightning.plugins import DDPPlugin
+
+from src.config.default import get_cfg_defaults
+from src.utils.misc import get_rank_zero_only_logger, setup_gpus
+from src.utils.profiler import build_profiler
+from src.lightning.data_pretrain import PretrainDataModule
+from src.lightning.lightning_xoftr_pretrain import PL_XoFTR_Pretrain
+
+loguru_logger = get_rank_zero_only_logger(loguru_logger)
+
+
+def parse_args():
+ # init a custom parser which will be added into pl.Trainer parser
+ # check documentation: https://pytorch-lightning.readthedocs.io/en/latest/common/trainer.html#trainer-flags
+ parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+ parser.add_argument(
+ 'data_cfg_path', type=str, help='data config path')
+ parser.add_argument(
+ 'main_cfg_path', type=str, help='main config path')
+ parser.add_argument(
+ '--exp_name', type=str, default='default_exp_name')
+ parser.add_argument(
+ '--batch_size', type=int, default=4, help='batch_size per gpu')
+ parser.add_argument(
+ '--num_workers', type=int, default=4)
+ parser.add_argument(
+ '--pin_memory', type=lambda x: bool(strtobool(x)),
+ nargs='?', default=True, help='whether to load data into pinned memory or not')
+ parser.add_argument(
+ '--ckpt_path', type=str, default=None,
+ help='pretrained checkpoint path')
+ parser.add_argument(
+ '--disable_ckpt', action='store_true',
+ help='disable checkpoint saving (useful for debugging).')
+ parser.add_argument(
+ '--profiler_name', type=str, default=None,
+ help='options: [inference, pytorch], or leave it unset')
+ parser.add_argument(
+ '--parallel_load_data', action='store_true',
+ help='load datasets with multiple processes.')
+
+ parser = pl.Trainer.add_argparse_args(parser)
+ return parser.parse_args()
+
+
+def main():
+ # parse arguments
+ args = parse_args()
+ rank_zero_only(pprint.pprint)(vars(args))
+
+ # init default-cfg and merge it with the main- and data-cfg
+ config = get_cfg_defaults()
+ config.merge_from_file(args.main_cfg_path)
+ config.merge_from_file(args.data_cfg_path)
+ pl.seed_everything(config.TRAINER.SEED) # reproducibility
+
+ # scale lr and warmup-step automatically
+ args.gpus = _n_gpus = setup_gpus(args.gpus)
+ config.TRAINER.WORLD_SIZE = _n_gpus * args.num_nodes
+ config.TRAINER.TRUE_BATCH_SIZE = config.TRAINER.WORLD_SIZE * args.batch_size
+ _scaling = config.TRAINER.TRUE_BATCH_SIZE / config.TRAINER.CANONICAL_BS
+ config.TRAINER.SCALING = _scaling
+ config.TRAINER.TRUE_LR = config.TRAINER.CANONICAL_LR * _scaling
+ config.TRAINER.WARMUP_STEP = math.floor(config.TRAINER.WARMUP_STEP / _scaling)
+
+ # lightning module
+ profiler = build_profiler(args.profiler_name)
+ model = PL_XoFTR_Pretrain(config, pretrained_ckpt=args.ckpt_path, profiler=profiler)
+ loguru_logger.info(f"XoFTR LightningModule initialized!")
+
+ # lightning data
+ data_module = PretrainDataModule(args, config)
+ loguru_logger.info(f"XoFTR DataModule initialized!")
+
+ # TensorBoard Logger
+ logger = [TensorBoardLogger(save_dir='logs/tb_logs', name=args.exp_name, default_hp_metric=False)]
+ ckpt_dir = Path(logger[0].log_dir) / 'checkpoints'
+ if config.TRAINER.USE_WANDB:
+ logger.append(WandbLogger(name=args.exp_name + f"_{datetime.now().strftime('%Y_%m_%d_%H_%M_%S')}",
+ project='XoFTR'))
+
+ # Callbacks
+ # TODO: update ModelCheckpoint to monitor multiple metrics
+ ckpt_callback = ModelCheckpoint(verbose=True, save_top_k=-1,
+ save_last=True,
+ dirpath=str(ckpt_dir),
+ filename='{epoch}')
+ lr_monitor = LearningRateMonitor(logging_interval='step')
+ callbacks = [lr_monitor]
+ if not args.disable_ckpt:
+ callbacks.append(ckpt_callback)
+
+ # Lightning Trainer
+ trainer = pl.Trainer.from_argparse_args(
+ args,
+ plugins=DDPPlugin(find_unused_parameters=True,
+ num_nodes=args.num_nodes,
+ sync_batchnorm=config.TRAINER.WORLD_SIZE > 0),
+ gradient_clip_val=config.TRAINER.GRADIENT_CLIPPING,
+ callbacks=callbacks,
+ logger=logger,
+ sync_batchnorm=config.TRAINER.WORLD_SIZE > 0,
+ replace_sampler_ddp=False, # use custom sampler
+ reload_dataloaders_every_epoch=False, # avoid repeated samples!
+ weights_summary='full',
+ profiler=profiler)
+ loguru_logger.info(f"Trainer initialized!")
+ loguru_logger.info(f"Start training!")
+ trainer.fit(model, datamodule=data_module)
+
+
+if __name__ == '__main__':
+ main()
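`main()` rescales the learning rate and the warm-up length linearly with the effective batch size (`TRUE_BATCH_SIZE / CANONICAL_BS`). A worked example of that arithmetic, using the defaults from the configs in this patch (`CANONICAL_BS = 64`, `CANONICAL_LR = 4e-3`, `WARMUP_STEP = 1250` for pretraining) and a hypothetical single-node run with 4 GPUs and `--batch_size 4`:

```python
import math

canonical_bs, canonical_lr, warmup_step = 64, 4e-3, 1250  # values from the configs in this patch
gpus, num_nodes, batch_size = 4, 1, 4                     # hypothetical launch settings

world_size = gpus * num_nodes                  # 4
true_batch_size = world_size * batch_size      # 16
scaling = true_batch_size / canonical_bs       # 0.25
true_lr = canonical_lr * scaling               # 1e-3, the LR actually used
warmup = math.floor(warmup_step / scaling)     # 5000 warm-up steps
print(true_lr, warmup)                         # 0.001 5000
```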
diff --git a/imcui/third_party/XoFTR/src/__init__.py b/imcui/third_party/XoFTR/src/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/imcui/third_party/XoFTR/src/config/default.py b/imcui/third_party/XoFTR/src/config/default.py
new file mode 100644
index 0000000000000000000000000000000000000000..02091605a350d9fc8c0dd7510a3937a67e599593
--- /dev/null
+++ b/imcui/third_party/XoFTR/src/config/default.py
@@ -0,0 +1,203 @@
+from yacs.config import CfgNode as CN
+
+INFERENCE = False
+
+_CN = CN()
+
+############## ↓ XoFTR Pipeline ↓ ##############
+_CN.XOFTR = CN()
+_CN.XOFTR.RESOLUTION = (8, 2) # options: [(8, 2)]
+_CN.XOFTR.FINE_WINDOW_SIZE = 5 # window_size in fine_level, must be odd
+_CN.XOFTR.MEDIUM_WINDOW_SIZE = 3 # window_size in medium_level, must be odd
+
+# 1. XoFTR-backbone (local feature CNN) config
+_CN.XOFTR.RESNET = CN()
+_CN.XOFTR.RESNET.INITIAL_DIM = 128
+_CN.XOFTR.RESNET.BLOCK_DIMS = [128, 196, 256] # s1, s2, s3
+
+# 2. XoFTR-coarse module config
+_CN.XOFTR.COARSE = CN()
+_CN.XOFTR.COARSE.INFERENCE = INFERENCE
+_CN.XOFTR.COARSE.D_MODEL = 256
+_CN.XOFTR.COARSE.D_FFN = 256
+_CN.XOFTR.COARSE.NHEAD = 8
+_CN.XOFTR.COARSE.LAYER_NAMES = ['self', 'cross'] * 4
+_CN.XOFTR.COARSE.ATTENTION = 'linear' # options: ['linear', 'full']
+
+# 3. Coarse-Matching config
+_CN.XOFTR.MATCH_COARSE = CN()
+_CN.XOFTR.MATCH_COARSE.INFERENCE = INFERENCE
+_CN.XOFTR.MATCH_COARSE.D_MODEL = 256
+_CN.XOFTR.MATCH_COARSE.THR = 0.3
+_CN.XOFTR.MATCH_COARSE.BORDER_RM = 2
+_CN.XOFTR.MATCH_COARSE.MATCH_TYPE = 'dual_softmax' # options: ['dual_softmax']
+_CN.XOFTR.MATCH_COARSE.DSMAX_TEMPERATURE = 0.1
+_CN.XOFTR.MATCH_COARSE.TRAIN_COARSE_PERCENT = 0.2 # training tricks: save GPU memory
+_CN.XOFTR.MATCH_COARSE.TRAIN_PAD_NUM_GT_MIN = 200 # training tricks: avoid DDP deadlock
+
+# 4. XoFTR-fine module config
+_CN.XOFTR.FINE = CN()
+_CN.XOFTR.FINE.DENSER = False # if true, match all features in fine-level windows
+_CN.XOFTR.FINE.INFERENCE = INFERENCE
+_CN.XOFTR.FINE.DSMAX_TEMPERATURE = 0.1
+_CN.XOFTR.FINE.THR = 0.1
+_CN.XOFTR.FINE.MLP_HIDDEN_DIM_COEF = 2 # coef for mlp hidden dim (hidden_dim = feat_dim * coef)
+_CN.XOFTR.FINE.NHEAD_FINE_LEVEL = 8
+_CN.XOFTR.FINE.NHEAD_MEDIUM_LEVEL = 7
+
+
+# 5. XoFTR Losses
+
+_CN.XOFTR.LOSS = CN()
+_CN.XOFTR.LOSS.FOCAL_ALPHA = 0.25
+_CN.XOFTR.LOSS.FOCAL_GAMMA = 2.0
+_CN.XOFTR.LOSS.POS_WEIGHT = 1.0
+_CN.XOFTR.LOSS.NEG_WEIGHT = 1.0
+
+# -- # coarse-level
+_CN.XOFTR.LOSS.COARSE_WEIGHT = 0.5
+# -- # fine-level
+_CN.XOFTR.LOSS.FINE_WEIGHT = 0.3
+# -- # sub-pixel
+_CN.XOFTR.LOSS.SUB_WEIGHT = 1 * 10**4
+
+############## Dataset ##############
+_CN.DATASET = CN()
+# 1. data config
+# training and validating
+_CN.DATASET.TRAIN_DATA_SOURCE = None # options: ['ScanNet', 'MegaDepth']
+_CN.DATASET.TRAIN_DATA_ROOT = None
+_CN.DATASET.TRAIN_POSE_ROOT = None # (optional directory for poses)
+_CN.DATASET.TRAIN_NPZ_ROOT = None
+_CN.DATASET.TRAIN_LIST_PATH = None
+_CN.DATASET.TRAIN_INTRINSIC_PATH = None
+_CN.DATASET.VAL_DATA_SOURCE = None
+_CN.DATASET.VAL_DATA_ROOT = None
+_CN.DATASET.VAL_POSE_ROOT = None # (optional directory for poses)
+_CN.DATASET.VAL_NPZ_ROOT = None
+_CN.DATASET.VAL_LIST_PATH = None # None if val data from all scenes are bundled into a single npz file
+_CN.DATASET.VAL_INTRINSIC_PATH = None
+# testing
+_CN.DATASET.TEST_DATA_SOURCE = None
+_CN.DATASET.TEST_DATA_ROOT = None
+_CN.DATASET.TEST_POSE_ROOT = None # (optional directory for poses)
+_CN.DATASET.TEST_NPZ_ROOT = None
+_CN.DATASET.TEST_LIST_PATH = None # None if test data from all scenes are bundled into a single npz file
+_CN.DATASET.TEST_INTRINSIC_PATH = None
+
+# 2. dataset config
+# general options
+_CN.DATASET.MIN_OVERLAP_SCORE_TRAIN = 0.4 # discard data with overlap_score < min_overlap_score
+_CN.DATASET.MIN_OVERLAP_SCORE_TEST = 0.0
+_CN.DATASET.AUGMENTATION_TYPE = "rgb_thermal" # options: [None, 'dark', 'mobile']
+
+# MegaDepth options
+_CN.DATASET.MGDPT_IMG_RESIZE = 640 # resize the longer side, zero-pad bottom-right to square.
+_CN.DATASET.MGDPT_IMG_PAD = True # pad img to square with size = MGDPT_IMG_RESIZE
+_CN.DATASET.MGDPT_DEPTH_PAD = True # pad depthmap to square with size = 2000
+_CN.DATASET.MGDPT_DF = 8
+
+# VisTir options
+_CN.DATASET.VISTIR_IMG_RESIZE = 640 # resize the longer side, zero-pad bottom-right to square.
+_CN.DATASET.VISTIR_IMG_PAD = False # pad img to square with size = VISTIR_IMG_RESIZE
+_CN.DATASET.VISTIR_DF = 8
+
+# Pretrain dataset options
+_CN.DATASET.PRETRAIN_IMG_RESIZE = 640 # resize the longer side, zero-pad bottom-right to square.
+_CN.DATASET.PRETRAIN_IMG_PAD = True # pad img to square with size = PRETRAIN_IMG_RESIZE
+_CN.DATASET.PRETRAIN_DF = 8
+_CN.DATASET.PRETRAIN_FRAME_GAP = 2 # the gap between video frames of Kaist dataset
+
+############## Trainer ##############
+_CN.TRAINER = CN()
+_CN.TRAINER.WORLD_SIZE = 1
+_CN.TRAINER.CANONICAL_BS = 64
+_CN.TRAINER.CANONICAL_LR = 6e-3
+_CN.TRAINER.SCALING = None # this will be calculated automatically
+_CN.TRAINER.FIND_LR = False # use learning rate finder from pytorch-lightning
+
+_CN.TRAINER.USE_WANDB = False # use Weights & Biases
+
+# optimizer
+_CN.TRAINER.OPTIMIZER = "adamw" # [adam, adamw]
+_CN.TRAINER.TRUE_LR = None # this will be calculated automatically at runtime
+_CN.TRAINER.ADAM_DECAY = 0. # ADAM: for adam
+_CN.TRAINER.ADAMW_DECAY = 0.1
+
+# step-based warm-up
+_CN.TRAINER.WARMUP_TYPE = 'linear' # [linear, constant]
+_CN.TRAINER.WARMUP_RATIO = 0.
+_CN.TRAINER.WARMUP_STEP = 4800
+
+# learning rate scheduler
+_CN.TRAINER.SCHEDULER = 'MultiStepLR' # [MultiStepLR, CosineAnnealing, ExponentialLR]
+_CN.TRAINER.SCHEDULER_INTERVAL = 'epoch' # [epoch, step]
+_CN.TRAINER.MSLR_MILESTONES = [3, 6, 9, 12] # MSLR: MultiStepLR
+_CN.TRAINER.MSLR_GAMMA = 0.5
+_CN.TRAINER.COSA_TMAX = 30 # COSA: CosineAnnealing
+_CN.TRAINER.ELR_GAMMA = 0.999992 # ELR: ExponentialLR, this value for 'step' interval
+
+# plotting related
+_CN.TRAINER.ENABLE_PLOTTING = True
+_CN.TRAINER.N_VAL_PAIRS_TO_PLOT = 128 # number of val/test pairs for plotting
+_CN.TRAINER.PLOT_MODE = 'evaluation' # ['evaluation', 'confidence']
+_CN.TRAINER.PLOT_MATCHES_ALPHA = 'dynamic'
+
+# geometric metrics and pose solver
+_CN.TRAINER.EPI_ERR_THR = 5e-4 # recommendation: 5e-4 for ScanNet, 1e-4 for MegaDepth (from SuperGlue)
+_CN.TRAINER.POSE_GEO_MODEL = 'E' # ['E', 'F', 'H']
+_CN.TRAINER.POSE_ESTIMATION_METHOD = 'RANSAC' # [RANSAC, DEGENSAC, MAGSAC]
+_CN.TRAINER.RANSAC_PIXEL_THR = 0.5
+_CN.TRAINER.RANSAC_CONF = 0.99999
+_CN.TRAINER.RANSAC_MAX_ITERS = 10000
+_CN.TRAINER.USE_MAGSACPP = False
+
+# data sampler for train_dataloader
+_CN.TRAINER.DATA_SAMPLER = 'scene_balance' # options: ['scene_balance', 'random', 'normal']
+# 'scene_balance' config
+_CN.TRAINER.N_SAMPLES_PER_SUBSET = 200
+_CN.TRAINER.SB_SUBSET_SAMPLE_REPLACEMENT = True # whether to sample each scene with replacement or not
+_CN.TRAINER.SB_SUBSET_SHUFFLE = True # after sampling from scenes, whether to shuffle within the epoch or not
+_CN.TRAINER.SB_REPEAT = 1 # repeat N times for training the sampled data
+# 'random' config
+_CN.TRAINER.RDM_REPLACEMENT = True
+_CN.TRAINER.RDM_NUM_SAMPLES = None
+
+# gradient clipping
+_CN.TRAINER.GRADIENT_CLIPPING = 0.5
+
+# reproducibility
+# This seed affects the data sampling. With the same seed, the data sampling is guaranteed
+# to be the same. When resuming training from a checkpoint, it's better to use a different
+# seed, otherwise the sampled data will be exactly the same as before resuming, which will
+# result in fewer unique data items being sampled during the entire training.
+# Using different seed values might affect the final training result, since not all data items
+# are used during training on ScanNet. (60M pairs of images sampled during training from 230M pairs in total.)
+_CN.TRAINER.SEED = 66
+
+############## Pretrain ##############
+_CN.PRETRAIN = CN()
+_CN.PRETRAIN.PATCH_SIZE = 64 # patch size for masks
+_CN.PRETRAIN.MASK_RATIO = 0.5
+_CN.PRETRAIN.MAE_MARGINS = [0, 0.4, 0, 0] # margins not to be masked (up bottom left right)
+_CN.PRETRAIN.VAL_SEED = 42 # rng seed to create the same masks for validation
+
+_CN.XOFTR.PRETRAIN_PATCH_SIZE = _CN.PRETRAIN.PATCH_SIZE
+
+############## Test/Inference ##############
+_CN.TEST = CN()
+_CN.TEST.IMG0_RESIZE = 640 # resize the longer side
+_CN.TEST.IMG1_RESIZE = 640 # resize the longer side
+_CN.TEST.DF = 8
+_CN.TEST.PADDING = False # pad img to square with size = IMG0_RESIZE, IMG1_RESIZE
+_CN.TEST.COARSE_SCALE = 0.125
+
+def get_cfg_defaults(inference=False):
+ """Get a yacs CfgNode object with default values for my_project."""
+ # Return a clone so that the defaults will not be altered
+ # This is for the "local variable" use pattern
+ if inference:
+ _CN.XOFTR.COARSE.INFERENCE = True
+ _CN.XOFTR.MATCH_COARSE.INFERENCE = True
+ _CN.XOFTR.FINE.INFERENCE = True
+ return _CN.clone()
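`get_cfg_defaults(inference=True)` flips the three `INFERENCE` flags on the module-level defaults before cloning, so deployment code can skip training-only branches. A short usage sketch (the override value is arbitrary):

```python
from src.config.default import get_cfg_defaults

cfg = get_cfg_defaults(inference=True)
assert cfg.XOFTR.COARSE.INFERENCE and cfg.XOFTR.MATCH_COARSE.INFERENCE and cfg.XOFTR.FINE.INFERENCE

cfg.merge_from_list(["XOFTR.MATCH_COARSE.THR", 0.2])  # override a single value by key path
cfg.freeze()                                          # yacs: make the config immutable
```

Note that the flags are set on `_CN` itself before cloning, so a later `get_cfg_defaults()` call in the same process will also see them enabled.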
diff --git a/imcui/third_party/XoFTR/src/datasets/megadepth.py b/imcui/third_party/XoFTR/src/datasets/megadepth.py
new file mode 100644
index 0000000000000000000000000000000000000000..32eaeb7a554ed28a956c71ca9c0df5e418aebbba
--- /dev/null
+++ b/imcui/third_party/XoFTR/src/datasets/megadepth.py
@@ -0,0 +1,143 @@
+import os.path as osp
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torch.utils.data import Dataset
+from loguru import logger
+
+from src.utils.dataset import read_megadepth_gray, read_megadepth_depth
+
+def correct_image_paths(scene_info):
+ """Changes the path format from undistorted images from D2Net to MegaDepth_v1 format"""
+ image_paths = scene_info["image_paths"]
+ for ii in range(len(image_paths)):
+ if image_paths[ii] is not None:
+ folds = image_paths[ii].split("/")
+ path = osp.join("phoenix/S6/zl548/MegaDepth_v1/", folds[1], "dense0/imgs", folds[3] )
+ image_paths[ii] = path
+ scene_info["image_paths"] = image_paths
+ return scene_info
+
+class MegaDepthDataset(Dataset):
+ def __init__(self,
+ root_dir,
+ npz_path,
+ mode='train',
+ min_overlap_score=0.4,
+ img_resize=None,
+ df=None,
+ img_padding=False,
+ depth_padding=False,
+ augment_fn=None,
+ **kwargs):
+ """
+ Manage one scene(npz_path) of MegaDepth dataset.
+
+ Args:
+ root_dir (str): megadepth root directory that has `phoenix`.
+ npz_path (str): {scene_id}.npz path. This contains image pair information of a scene.
+ mode (str): options are ['train', 'val', 'test']
+ min_overlap_score (float): how much a pair should have in common. In range of [0, 1]. Set to 0 when testing.
+ img_resize (int, optional): the longer edge of resized images. None for no resize. 640 is recommended.
+ This is useful during training with batches and testing with memory intensive algorithms.
+ df (int, optional): image size division factor. NOTE: this will change the final image size after img_resize.
+ img_padding (bool): If set to 'True', zero-pad the image to squared size. This is useful during training.
+ depth_padding (bool): If set to 'True', zero-pad depthmap to (2000, 2000). This is useful during training.
+ augment_fn (callable, optional): augments images with pre-defined visual effects.
+ """
+ super().__init__()
+ self.root_dir = root_dir
+ self.mode = mode
+ self.scene_id = npz_path.split('.')[0]
+
+ # prepare scene_info and pair_info
+ if mode == 'test' and min_overlap_score != 0:
+ logger.warning("You are using `min_overlap_score`!=0 in test mode. Set to 0.")
+ min_overlap_score = 0
+ self.scene_info = np.load(npz_path, allow_pickle=True)
+ self.scene_info = correct_image_paths(self.scene_info)
+ self.pair_infos = self.scene_info['pair_infos'].copy()
+ del self.scene_info['pair_infos']
+ self.pair_infos = [pair_info for pair_info in self.pair_infos if pair_info[1] > min_overlap_score]
+
+ # parameters for image resizing, padding and depthmap padding
+ if mode == 'train':
+ assert img_resize is not None and img_padding and depth_padding
+ self.img_resize = img_resize
+ self.df = df
+ self.img_padding = img_padding
+ self.depth_max_size = 2000 if depth_padding else None # the upper bound of depthmap size in MegaDepth.
+
+ # for training XoFTR
+ # self.augment_fn = augment_fn if mode == 'train' else None
+ self.augment_fn = augment_fn
+ self.coarse_scale = kwargs.get('coarse_scale', 0.125) # kwargs is a dict, so use .get (getattr would always return the default)
+
+ def __len__(self):
+ return len(self.pair_infos)
+
+ def __getitem__(self, idx):
+ (idx0, idx1), overlap_score, central_matches = self.pair_infos[idx]
+
+ # read grayscale image and mask. (1, h, w) and (h, w)
+ img_name0 = osp.join(self.root_dir, self.scene_info['image_paths'][idx0])
+ img_name1 = osp.join(self.root_dir, self.scene_info['image_paths'][idx1])
+
+ if getattr(self.augment_fn, 'random_switch', False):
+ im_num = torch.randint(0, 2, (1,))
+ augment_fn_0 = lambda x: self.augment_fn(x, image_num=im_num)
+ augment_fn_1 = lambda x: self.augment_fn(x, image_num=1-im_num)
+ else:
+ augment_fn_0 = self.augment_fn
+ augment_fn_1 = self.augment_fn
+ image0, mask0, scale0 = read_megadepth_gray(
+ img_name0, self.img_resize, self.df, self.img_padding, augment_fn=augment_fn_0)
+ image1, mask1, scale1 = read_megadepth_gray(
+ img_name1, self.img_resize, self.df, self.img_padding, augment_fn=augment_fn_1)
+
+ # read depth. shape: (h, w)
+ if self.mode in ['train', 'val']:
+ depth0 = read_megadepth_depth(
+ osp.join(self.root_dir, self.scene_info['depth_paths'][idx0]), pad_to=self.depth_max_size)
+ depth1 = read_megadepth_depth(
+ osp.join(self.root_dir, self.scene_info['depth_paths'][idx1]), pad_to=self.depth_max_size)
+ else:
+ depth0 = depth1 = torch.tensor([])
+
+ # read intrinsics of original size
+ K_0 = torch.tensor(self.scene_info['intrinsics'][idx0].copy(), dtype=torch.float).reshape(3, 3)
+ K_1 = torch.tensor(self.scene_info['intrinsics'][idx1].copy(), dtype=torch.float).reshape(3, 3)
+
+ # read and compute relative poses
+ T0 = self.scene_info['poses'][idx0]
+ T1 = self.scene_info['poses'][idx1]
+ T_0to1 = torch.tensor(np.matmul(T1, np.linalg.inv(T0)), dtype=torch.float)[:4, :4] # (4, 4)
+ T_1to0 = T_0to1.inverse()
+
+ data = {
+ 'image0': image0, # (1, h, w)
+ 'depth0': depth0, # (h, w)
+ 'image1': image1,
+ 'depth1': depth1,
+ 'T_0to1': T_0to1, # (4, 4)
+ 'T_1to0': T_1to0,
+ 'K0': K_0, # (3, 3)
+ 'K1': K_1,
+ 'scale0': scale0, # [scale_w, scale_h]
+ 'scale1': scale1,
+ 'dataset_name': 'MegaDepth',
+ 'scene_id': self.scene_id,
+ 'pair_id': idx,
+ 'pair_names': (self.scene_info['image_paths'][idx0], self.scene_info['image_paths'][idx1]),
+ }
+
+ # for XoFTR training
+ if mask0 is not None: # img_padding is True
+ if self.coarse_scale:
+ [ts_mask_0, ts_mask_1] = F.interpolate(torch.stack([mask0, mask1], dim=0)[None].float(),
+ scale_factor=self.coarse_scale,
+ mode='nearest',
+ recompute_scale_factor=False)[0].bool()
+ data.update({'mask0': ts_mask_0, 'mask1': ts_mask_1})
+
+ return data
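When images are zero-padded to a square, the validity masks are downsampled to the coarse feature resolution (`coarse_scale = 0.125`, i.e. 1/8, matching `XOFTR.RESOLUTION[0] = 8`) with nearest-neighbour interpolation before being attached to the batch. The same operation on a hypothetical 640x640 padded pair:

```python
import torch
import torch.nn.functional as F

# Hypothetical masks for a 640x640 padded pair: True marks valid (non-padded) pixels.
mask0 = torch.ones(640, 640, dtype=torch.bool)
mask1 = torch.ones(640, 640, dtype=torch.bool)
mask1[:, 480:] = False  # image1 was zero-padded on the right

coarse_scale = 0.125    # 1/8 of the input resolution
ts_mask_0, ts_mask_1 = F.interpolate(
    torch.stack([mask0, mask1], dim=0)[None].float(),  # (1, 2, H, W)
    scale_factor=coarse_scale, mode='nearest',
    recompute_scale_factor=False)[0].bool()

print(ts_mask_0.shape, ts_mask_1.shape)  # torch.Size([80, 80]) torch.Size([80, 80])
print(ts_mask_1[:, 60:].any().item())    # False: padded columns stay masked out
```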
diff --git a/imcui/third_party/XoFTR/src/datasets/pretrain_dataset.py b/imcui/third_party/XoFTR/src/datasets/pretrain_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..17ab2f3e61299ccc13f5926b0716417bf28a5fad
--- /dev/null
+++ b/imcui/third_party/XoFTR/src/datasets/pretrain_dataset.py
@@ -0,0 +1,156 @@
+import os
+import glob
+import os.path as osp
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torch.utils.data import Dataset
+from loguru import logger
+import random
+from src.utils.dataset import read_pretrain_gray
+
+class PretrainDataset(Dataset):
+ def __init__(self,
+ root_dir,
+ mode='train',
+ img_resize=None,
+ df=None,
+ img_padding=False,
+ frame_gap=2,
+ **kwargs):
+ """
+ Manage image pairs of KAIST Multispectral Pedestrian Detection Benchmark Dataset.
+
+ Args:
+ root_dir (str): KAIST Multispectral Pedestrian dataset root directory.
+ mode (str): options are ['train', 'val']
+ img_resize (int, optional): the longer edge of resized images. None for no resize. 640 is recommended.
+ This is useful during training with batches and testing with memory intensive algorithms.
+ df (int, optional): image size division factor. NOTE: this will change the final image size after img_resize.
+ img_padding (bool): If set to 'True', zero-pad the image to squared size. This is useful during training.
+ augment_fn (callable, optional): augments images with pre-defined visual effects.
+ """
+ super().__init__()
+ self.root_dir = root_dir
+ self.mode = mode
+
+ # specify which part of the data is used for training and which for validation
+ if mode == 'train':
+ assert img_resize is not None and img_padding
+ self.start_ratio = 0.0
+ self.end_ratio = 0.9
+ elif mode == 'val':
+ assert img_resize is not None and img_padding
+ self.start_ratio = 0.9
+ self.end_ratio = 1.0
+ else:
+ raise NotImplementedError()
+
+ # parameters for image resizing, padding
+ self.img_resize = img_resize
+ self.df = df
+ self.img_padding = img_padding
+
+ # for training XoFTR
+ self.coarse_scale = kwargs.get('coarse_scale', 0.125) # kwargs is a dict, so use .get (getattr would always return the default)
+
+ self.pair_paths = self.generate_kaist_pairs(root_dir, frame_gap=frame_gap, second_frame_range=0)
+
+ def get_kaist_image_paths(self, root_dir):
+ vis_img_paths = []
+ lwir_img_paths = []
+ img_num_per_folder = []
+
+ # Recursively search for folders named "image"
+ for folder, subfolders, filenames in os.walk(root_dir):
+ if "visible" in subfolders and "lwir" in subfolders:
+ vis_img_folder = osp.join(folder, "visible")
+ lwir_img_folder = osp.join(folder, "lwir")
+ # Use glob to find image files (you can add more extensions if needed)
+ vis_imgs_i = glob.glob(osp.join(vis_img_folder, '*.jpg'))
+ vis_imgs_i.sort()
+ lwir_imgs_i = glob.glob(osp.join(lwir_img_folder, '*.jpg'))
+ lwir_imgs_i.sort()
+ vis_img_paths.append(vis_imgs_i)
+ lwir_img_paths.append(lwir_imgs_i)
+ img_num_per_folder.append(len(vis_imgs_i))
+ assert len(vis_imgs_i) == len(lwir_imgs_i), f"Image numbers do not match in {folder}, {len(vis_imgs_i)} != {len(lwir_imgs_i)}"
+ # Add more image file extensions as necessary
+ return vis_img_paths, lwir_img_paths, img_num_per_folder
+
+ def generate_kaist_pairs(self, root_dir, frame_gap, second_frame_range):
+ """ generate image pairs (Vis-TIR) from KAIST Pedestrian dataset
+ Args:
+ root_dir: root directory for the dataset
+ frame_gap (int): the frame gap between consecutive images
+ second_frame_range (int): sampling range for the second image, i.e. for a first index i, the second index j is drawn from [i - second_frame_range, i + second_frame_range]
+ Returns:
+ pair_paths (list)
+ """
+ vis_img_paths, lwir_img_paths, img_num_per_folder = self.get_kaist_image_paths(root_dir)
+ pair_paths = []
+ for i in range(len(img_num_per_folder)):
+ num_img = img_num_per_folder[i]
+ inds_vis = torch.arange(int(self.start_ratio * num_img),
+ int(self.end_ratio * num_img),
+ frame_gap, dtype=int)
+ if second_frame_range > 0:
+ inds_lwir = inds_vis + torch.randint(-second_frame_range, second_frame_range, (inds_vis.shape[0],))
+ inds_lwir[inds_lwir > int(self.end_ratio * num_img)-1] = int(self.end_ratio * num_img)-1 # clamp indices to the split range
+ else:
+ inds_lwir = inds_vis
+ for j, k in zip(inds_vis, inds_lwir):
+ img_name0 = os.path.relpath(vis_img_paths[i][j], root_dir)
+ img_name1 = os.path.relpath(lwir_img_paths[i][k], root_dir)
+
+ if torch.rand(1) > 0.5:
+ img_name0, img_name1 = img_name1, img_name0
+
+ pair_paths.append([img_name0, img_name1])
+
+ random.shuffle(pair_paths)
+ return pair_paths
+
+ def __len__(self):
+ return len(self.pair_paths)
+
+ def __getitem__(self, idx):
+ # read grayscale and normalized image, and mask. (1, h, w) and (h, w)
+ img_name0 = osp.join(self.root_dir, self.pair_paths[idx][0])
+ img_name1 = osp.join(self.root_dir, self.pair_paths[idx][1])
+
+ if self.mode == "train" and torch.rand(1) > 0.5:
+ img_name0, img_name1 = img_name1, img_name0
+
+ image0, image0_norm, mask0, scale0, image0_mean, image0_std = read_pretrain_gray(
+ img_name0, self.img_resize, self.df, self.img_padding, None)
+ image1, image1_norm, mask1, scale1, image1_mean, image1_std = read_pretrain_gray(
+ img_name1, self.img_resize, self.df, self.img_padding, None)
+
+ data = {
+ 'image0': image0, # (1, h, w)
+ 'image1': image1,
+ 'image0_norm': image0_norm,
+ 'image1_norm': image1_norm,
+ 'scale0': scale0, # [scale_w, scale_h]
+ 'scale1': scale1,
+ "image0_mean": image0_mean,
+ "image0_std": image0_std,
+ "image1_mean": image1_mean,
+ "image1_std": image1_std,
+ 'dataset_name': 'PreTrain',
+ 'pair_id': idx,
+ 'pair_names': (self.pair_paths[idx][0], self.pair_paths[idx][1]),
+ }
+
+ # for XoFTR training
+ if mask0 is not None: # img_padding is True
+ if self.coarse_scale:
+ [ts_mask_0, ts_mask_1] = F.interpolate(torch.stack([mask0, mask1], dim=0)[None].float(),
+ scale_factor=self.coarse_scale,
+ mode='nearest',
+ recompute_scale_factor=False)[0].bool()
+ data.update({'mask0': ts_mask_0, 'mask1': ts_mask_1})
+
+ return data
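`generate_kaist_pairs` walks the KAIST folders and, per folder, keeps every `frame_gap`-th visible frame inside the split window defined by `start_ratio`/`end_ratio`, pairing it with the LWIR frame of the same index when `second_frame_range == 0`. The index arithmetic in isolation, for a hypothetical folder with 100 frames in 'train' mode:

```python
import torch

num_img, frame_gap = 100, 2
start_ratio, end_ratio = 0.0, 0.9  # the 'train' split keeps the first 90% of frames

inds_vis = torch.arange(int(start_ratio * num_img),
                        int(end_ratio * num_img),
                        frame_gap, dtype=int)
inds_lwir = inds_vis  # second_frame_range == 0: the LWIR index equals the visible index

print(inds_vis[:5].tolist(), len(inds_vis))  # [0, 2, 4, 6, 8] 45
```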
diff --git a/imcui/third_party/XoFTR/src/datasets/sampler.py b/imcui/third_party/XoFTR/src/datasets/sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..81b6f435645632a013476f9a665a0861ab7fcb61
--- /dev/null
+++ b/imcui/third_party/XoFTR/src/datasets/sampler.py
@@ -0,0 +1,77 @@
+import torch
+from torch.utils.data import Sampler, ConcatDataset
+
+
+class RandomConcatSampler(Sampler):
+ """ Random sampler for ConcatDataset. At each epoch, `n_samples_per_subset` samples will be draw from each subset
+ in the ConcatDataset. If `subset_replacement` is ``True``, sampling within each subset will be done with replacement.
+ However, it is impossible to sample data without replacement between epochs, unless bulding a stateful sampler lived along the entire training phase.
+
+ For current implementation, the randomness of sampling is ensured no matter the sampler is recreated across epochs or not and call `torch.manual_seed()` or not.
+ Args:
+ shuffle (bool): shuffle the random sampled indices across all sub-datsets.
+ repeat (int): repeatedly use the sampled indices multiple times for training.
+ [arXiv:1902.05509, arXiv:1901.09335]
+ NOTE: Don't re-initialize the sampler between epochs (will lead to repeated samples)
+ NOTE: This sampler behaves differently with DistributedSampler.
+ It assume the dataset is splitted across ranks instead of replicated.
+ TODO: Add a `set_epoch()` method to fullfill sampling without replacement across epochs.
+ ref: https://github.com/PyTorchLightning/pytorch-lightning/blob/e9846dd758cfb1500eb9dba2d86f6912eb487587/pytorch_lightning/trainer/training_loop.py#L373
+ """
+ def __init__(self,
+ data_source: ConcatDataset,
+ n_samples_per_subset: int,
+ subset_replacement: bool=True,
+ shuffle: bool=True,
+ repeat: int=1,
+ seed: int=None):
+ if not isinstance(data_source, ConcatDataset):
+ raise TypeError("data_source should be torch.utils.data.ConcatDataset")
+
+ self.data_source = data_source
+ self.n_subset = len(self.data_source.datasets)
+ self.n_samples_per_subset = n_samples_per_subset
+ self.n_samples = self.n_subset * self.n_samples_per_subset * repeat
+ self.subset_replacement = subset_replacement
+ self.repeat = repeat
+ self.shuffle = shuffle
+ self.generator = torch.manual_seed(seed)
+ assert self.repeat >= 1
+
+ def __len__(self):
+ return self.n_samples
+
+ def __iter__(self):
+ indices = []
+ # sample from each sub-dataset
+ for d_idx in range(self.n_subset):
+ low = 0 if d_idx==0 else self.data_source.cumulative_sizes[d_idx-1]
+ high = self.data_source.cumulative_sizes[d_idx]
+ if self.subset_replacement:
+ rand_tensor = torch.randint(low, high, (self.n_samples_per_subset, ),
+ generator=self.generator, dtype=torch.int64)
+ else: # sample without replacement
+ len_subset = len(self.data_source.datasets[d_idx])
+ rand_tensor = torch.randperm(len_subset, generator=self.generator) + low
+ if len_subset >= self.n_samples_per_subset:
+ rand_tensor = rand_tensor[:self.n_samples_per_subset]
+ else: # padding with replacement
+ rand_tensor_replacement = torch.randint(low, high, (self.n_samples_per_subset - len_subset, ),
+ generator=self.generator, dtype=torch.int64)
+ rand_tensor = torch.cat([rand_tensor, rand_tensor_replacement])
+ indices.append(rand_tensor)
+ indices = torch.cat(indices)
+ if self.shuffle: # shuffle the sampled dataset (from multiple subsets)
+ rand_tensor = torch.randperm(len(indices), generator=self.generator)
+ indices = indices[rand_tensor]
+
+ # repeat the sampled indices (can be used for RepeatAugmentation or pure RepeatSampling)
+ if self.repeat > 1:
+ repeat_indices = [indices.clone() for _ in range(self.repeat - 1)]
+ if self.shuffle:
+ _choice = lambda x: x[torch.randperm(len(x), generator=self.generator)]
+ repeat_indices = map(_choice, repeat_indices)
+ indices = torch.cat([indices, *repeat_indices], 0)
+
+ assert indices.shape[0] == self.n_samples
+ return iter(indices.tolist())
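
A minimal usage sketch for the sampler above (a hypothetical snippet, assuming the XoFTR source tree is on PYTHONPATH so that `src.datasets.sampler` is importable); the two toy TensorDataset "scenes" stand in for the real per-scene datasets:

    import torch
    from torch.utils.data import ConcatDataset, DataLoader, TensorDataset

    from src.datasets.sampler import RandomConcatSampler

    # two toy "scenes" of very different sizes
    scene_a = TensorDataset(torch.arange(10).float())
    scene_b = TensorDataset(torch.arange(100).float())
    concat = ConcatDataset([scene_a, scene_b])

    # draw 8 samples per scene each epoch, with replacement, shuffled across scenes
    sampler = RandomConcatSampler(concat,
                                  n_samples_per_subset=8,
                                  subset_replacement=True,
                                  shuffle=True,
                                  repeat=1,
                                  seed=66)
    loader = DataLoader(concat, sampler=sampler, batch_size=4)

    assert len(sampler) == 2 * 8  # n_subset * n_samples_per_subset * repeat
    for (batch,) in loader:
        print(batch.shape)        # torch.Size([4]) -- scene-balanced mini-batches

Note that an integer seed must be given explicitly, since the default `seed=None` would not be accepted by `torch.manual_seed`.
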
diff --git a/imcui/third_party/XoFTR/src/datasets/scannet.py b/imcui/third_party/XoFTR/src/datasets/scannet.py
new file mode 100644
index 0000000000000000000000000000000000000000..a8cfa8d5a91bf275733b980fbf77641d77b16a9b
--- /dev/null
+++ b/imcui/third_party/XoFTR/src/datasets/scannet.py
@@ -0,0 +1,114 @@
+from os import path as osp
+from typing import Dict
+from unicodedata import name
+
+import numpy as np
+import torch
+import torch.utils as utils
+from numpy.linalg import inv
+from src.utils.dataset import (
+ read_scannet_gray,
+ read_scannet_depth,
+ read_scannet_pose,
+ read_scannet_intrinsic
+)
+
+
+class ScanNetDataset(utils.data.Dataset):
+ def __init__(self,
+ root_dir,
+ npz_path,
+ intrinsic_path,
+ mode='train',
+ min_overlap_score=0.4,
+ augment_fn=None,
+ pose_dir=None,
+ **kwargs):
+ """Manage one scene of ScanNet Dataset.
+ Args:
+ root_dir (str): ScanNet root directory that contains scene folders.
+ npz_path (str): {scene_id}.npz path. This contains image pair information of a scene.
+ intrinsic_path (str): path to depth-camera intrinsic file.
+ mode (str): options are ['train', 'val', 'test'].
+ augment_fn (callable, optional): augments images with pre-defined visual effects.
+ pose_dir (str): ScanNet root directory that contains all poses.
+ (we use a separate (optional) pose_dir since we store images and poses separately.)
+ """
+ super().__init__()
+ self.root_dir = root_dir
+ self.pose_dir = pose_dir if pose_dir is not None else root_dir
+ self.mode = mode
+
+ # prepare data_names, intrinsics and extrinsics(T)
+ with np.load(npz_path) as data:
+ self.data_names = data['name']
+            if 'score' in data.keys() and mode not in ['val', 'test']:
+ kept_mask = data['score'] > min_overlap_score
+ self.data_names = self.data_names[kept_mask]
+ self.intrinsics = dict(np.load(intrinsic_path))
+
+ # for training LoFTR
+ self.augment_fn = augment_fn if mode == 'train' else None
+
+ def __len__(self):
+ return len(self.data_names)
+
+ def _read_abs_pose(self, scene_name, name):
+ pth = osp.join(self.pose_dir,
+ scene_name,
+ 'pose', f'{name}.txt')
+ return read_scannet_pose(pth)
+
+ def _compute_rel_pose(self, scene_name, name0, name1):
+ pose0 = self._read_abs_pose(scene_name, name0)
+ pose1 = self._read_abs_pose(scene_name, name1)
+
+ return np.matmul(pose1, inv(pose0)) # (4, 4)
+
+ def __getitem__(self, idx):
+ data_name = self.data_names[idx]
+ scene_name, scene_sub_name, stem_name_0, stem_name_1 = data_name
+ scene_name = f'scene{scene_name:04d}_{scene_sub_name:02d}'
+
+ # read the grayscale image which will be resized to (1, 480, 640)
+ img_name0 = osp.join(self.root_dir, scene_name, 'color', f'{stem_name_0}.jpg')
+ img_name1 = osp.join(self.root_dir, scene_name, 'color', f'{stem_name_1}.jpg')
+
+ # TODO: Support augmentation & handle seeds for each worker correctly.
+ image0 = read_scannet_gray(img_name0, resize=(640, 480), augment_fn=None)
+ # augment_fn=np.random.choice([self.augment_fn, None], p=[0.5, 0.5]))
+ image1 = read_scannet_gray(img_name1, resize=(640, 480), augment_fn=None)
+ # augment_fn=np.random.choice([self.augment_fn, None], p=[0.5, 0.5]))
+
+ # read the depthmap which is stored as (480, 640)
+ if self.mode in ['train', 'val']:
+ depth0 = read_scannet_depth(osp.join(self.root_dir, scene_name, 'depth', f'{stem_name_0}.png'))
+ depth1 = read_scannet_depth(osp.join(self.root_dir, scene_name, 'depth', f'{stem_name_1}.png'))
+ else:
+ depth0 = depth1 = torch.tensor([])
+
+ # read the intrinsic of depthmap
+ K_0 = K_1 = torch.tensor(self.intrinsics[scene_name].copy(), dtype=torch.float).reshape(3, 3)
+
+ # read and compute relative poses
+ T_0to1 = torch.tensor(self._compute_rel_pose(scene_name, stem_name_0, stem_name_1),
+ dtype=torch.float32)
+ T_1to0 = T_0to1.inverse()
+
+ data = {
+ 'image0': image0, # (1, h, w)
+ 'depth0': depth0, # (h, w)
+ 'image1': image1,
+ 'depth1': depth1,
+ 'T_0to1': T_0to1, # (4, 4)
+ 'T_1to0': T_1to0,
+ 'K0': K_0, # (3, 3)
+ 'K1': K_1,
+ 'dataset_name': 'ScanNet',
+ 'scene_id': scene_name,
+ 'pair_id': idx,
+ 'pair_names': (osp.join(scene_name, 'color', f'{stem_name_0}.jpg'),
+ osp.join(scene_name, 'color', f'{stem_name_1}.jpg'))
+ }
+
+ return data
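
A small, self-contained check of the pose composition used in `_compute_rel_pose` above, with hypothetical 4x4 rigid transforms; it only verifies that `T_1to0` is the inverse of `T_0to1` and makes no assumption about which pose convention `read_scannet_pose` returns:

    import numpy as np
    from numpy.linalg import inv

    def make_pose(theta, t):
        # hypothetical 4x4 rigid transform: rotation about z by theta, translation t
        T = np.eye(4)
        c, s = np.cos(theta), np.sin(theta)
        T[:3, :3] = [[c, -s, 0], [s, c, 0], [0, 0, 1]]
        T[:3, 3] = t
        return T

    pose0 = make_pose(0.1, [1.0, 0.0, 0.0])
    pose1 = make_pose(0.3, [0.5, 0.2, 0.0])

    T_0to1 = np.matmul(pose1, inv(pose0))  # same composition as _compute_rel_pose
    T_1to0 = inv(T_0to1)

    assert np.allclose(T_1to0 @ T_0to1, np.eye(4))
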
diff --git a/imcui/third_party/XoFTR/src/datasets/vistir.py b/imcui/third_party/XoFTR/src/datasets/vistir.py
new file mode 100644
index 0000000000000000000000000000000000000000..6f09e87bac1e7470ab77830d66c632dd727b9a12
--- /dev/null
+++ b/imcui/third_party/XoFTR/src/datasets/vistir.py
@@ -0,0 +1,109 @@
+import os.path as osp
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torch.utils.data import Dataset
+from loguru import logger
+
+from src.utils.dataset import read_vistir_gray
+
+class VisTirDataset(Dataset):
+ def __init__(self,
+ root_dir,
+ npz_path,
+ mode='val',
+ img_resize=None,
+ df=None,
+ img_padding=False,
+ **kwargs):
+ """
+ Manage one scene(npz_path) of VisTir dataset.
+
+ Args:
+ root_dir (str): VisTIR root directory.
+ npz_path (str): {scene_id}.npz path. This contains image pair information of a scene.
+ mode (str): options are ['val', 'test']
+ img_resize (int, optional): the longer edge of resized images. None for no resize. 640 is recommended.
+ df (int, optional): image size division factor. NOTE: this will change the final image size after img_resize.
+ img_padding (bool): If set to 'True', zero-pad the image to squared size.
+ """
+ super().__init__()
+ self.root_dir = root_dir
+ self.mode = mode
+ self.scene_id = npz_path.split('.')[0]
+
+ # prepare scene_info and pair_info
+ self.scene_info = dict(np.load(npz_path, allow_pickle=True))
+ self.pair_infos = self.scene_info['pair_infos'].copy()
+ del self.scene_info['pair_infos']
+
+ # parameters for image resizing, padding
+ self.img_resize = img_resize
+ self.df = df
+ self.img_padding = img_padding
+
+ # for training XoFTR
+        self.coarse_scale = kwargs.get('coarse_scale', 0.125)
+
+
+ def __len__(self):
+ return len(self.pair_infos)
+
+ def __getitem__(self, idx):
+ (idx0, idx1) = self.pair_infos[idx]
+
+
+ img_name0 = osp.join(self.root_dir, self.scene_info['image_paths'][idx0][0])
+ img_name1 = osp.join(self.root_dir, self.scene_info['image_paths'][idx1][1])
+
+ # read intrinsics of original size
+ K_0 = np.array(self.scene_info['intrinsics'][idx0][0], dtype=float).reshape(3,3)
+ K_1 = np.array(self.scene_info['intrinsics'][idx1][1], dtype=float).reshape(3,3)
+
+ # read distortion coefficients
+ dist0 = np.array(self.scene_info['distortion_coefs'][idx0][0], dtype=float)
+ dist1 = np.array(self.scene_info['distortion_coefs'][idx1][1], dtype=float)
+
+ # read grayscale undistorted image and mask. (1, h, w) and (h, w)
+ image0, mask0, scale0, K_0 = read_vistir_gray(
+ img_name0, K_0, dist0, self.img_resize, self.df, self.img_padding, augment_fn=None)
+ image1, mask1, scale1, K_1 = read_vistir_gray(
+ img_name1, K_1, dist1, self.img_resize, self.df, self.img_padding, augment_fn=None)
+
+ # to tensor
+ K_0 = torch.tensor(K_0.copy(), dtype=torch.float).reshape(3, 3)
+ K_1 = torch.tensor(K_1.copy(), dtype=torch.float).reshape(3, 3)
+
+ # read and compute relative poses
+ T0 = self.scene_info['poses'][idx0]
+ T1 = self.scene_info['poses'][idx1]
+ T_0to1 = torch.tensor(np.matmul(T1, np.linalg.inv(T0)), dtype=torch.float)[:4, :4] # (4, 4)
+ T_1to0 = T_0to1.inverse()
+
+ data = {
+ 'image0': image0, # (1, h, w)
+ 'image1': image1,
+ 'T_0to1': T_0to1, # (4, 4)
+ 'T_1to0': T_1to0,
+ 'K0': K_0, # (3, 3)
+ 'K1': K_1,
+ 'dist0': dist0,
+ 'dist1': dist1,
+ 'scale0': scale0, # [scale_w, scale_h]
+ 'scale1': scale1,
+ 'dataset_name': 'VisTir',
+ 'scene_id': self.scene_id,
+ 'pair_id': idx,
+ 'pair_names': (self.scene_info['image_paths'][idx0][0], self.scene_info['image_paths'][idx1][1]),
+ }
+
+ # for XoFTR training
+ if mask0 is not None: # img_padding is True
+ if self.coarse_scale:
+ [ts_mask_0, ts_mask_1] = F.interpolate(torch.stack([mask0, mask1], dim=0)[None].float(),
+ scale_factor=self.coarse_scale,
+ mode='nearest',
+ recompute_scale_factor=False)[0].bool()
+ data.update({'mask0': ts_mask_0, 'mask1': ts_mask_1})
+
+ return data
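
A standalone sketch of the coarse-mask step at the end of `__getitem__`, using dummy 480x640 padding masks; with `coarse_scale = 0.125` (i.e. 1 / XOFTR.RESOLUTION[0]) the masks come out at 60x80, matching the 1/8-resolution coarse feature grid:

    import torch
    import torch.nn.functional as F

    mask0 = torch.ones(480, 640, dtype=torch.bool)
    mask1 = torch.ones(480, 640, dtype=torch.bool)
    mask1[:, 512:] = False  # pretend the right side of image1 is zero-padding

    coarse_scale = 0.125
    ts_mask_0, ts_mask_1 = F.interpolate(
        torch.stack([mask0, mask1], dim=0)[None].float(),  # (1, 2, H, W)
        scale_factor=coarse_scale,
        mode='nearest',
        recompute_scale_factor=False)[0].bool()

    print(ts_mask_0.shape, ts_mask_1.shape)  # torch.Size([60, 80]) twice
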
diff --git a/imcui/third_party/XoFTR/src/lightning/data.py b/imcui/third_party/XoFTR/src/lightning/data.py
new file mode 100644
index 0000000000000000000000000000000000000000..b2ad42b5ed304dcd72b82dc88d0314864fbb75c9
--- /dev/null
+++ b/imcui/third_party/XoFTR/src/lightning/data.py
@@ -0,0 +1,346 @@
+import os
+import math
+from collections import abc
+from loguru import logger
+from torch.utils.data.dataset import Dataset
+from tqdm import tqdm
+from os import path as osp
+from pathlib import Path
+from joblib import Parallel, delayed
+
+import pytorch_lightning as pl
+from torch import distributed as dist
+from torch.utils.data import (
+ Dataset,
+ DataLoader,
+ ConcatDataset,
+ DistributedSampler,
+ RandomSampler,
+ dataloader
+)
+
+from src.utils.augment import build_augmentor
+from src.utils.dataloader import get_local_split
+from src.utils.misc import tqdm_joblib
+from src.utils import comm
+from src.datasets.megadepth import MegaDepthDataset
+from src.datasets.vistir import VisTirDataset
+from src.datasets.scannet import ScanNetDataset
+from src.datasets.sampler import RandomConcatSampler
+
+
+class MultiSceneDataModule(pl.LightningDataModule):
+ """
+    For distributed training, each training process is assigned
+ only a part of the training scenes to reduce memory overhead.
+ """
+ def __init__(self, args, config):
+ super().__init__()
+
+ # 1. data config
+        # Train and Val should come from the same data source
+ self.train_data_source = config.DATASET.TRAIN_DATA_SOURCE
+ self.val_data_source = config.DATASET.VAL_DATA_SOURCE
+ self.test_data_source = config.DATASET.TEST_DATA_SOURCE
+ # training and validating
+ self.train_data_root = config.DATASET.TRAIN_DATA_ROOT
+ self.train_pose_root = config.DATASET.TRAIN_POSE_ROOT # (optional)
+ self.train_npz_root = config.DATASET.TRAIN_NPZ_ROOT
+ self.train_list_path = config.DATASET.TRAIN_LIST_PATH
+ self.train_intrinsic_path = config.DATASET.TRAIN_INTRINSIC_PATH
+ self.val_data_root = config.DATASET.VAL_DATA_ROOT
+ self.val_pose_root = config.DATASET.VAL_POSE_ROOT # (optional)
+ self.val_npz_root = config.DATASET.VAL_NPZ_ROOT
+ self.val_list_path = config.DATASET.VAL_LIST_PATH
+ self.val_intrinsic_path = config.DATASET.VAL_INTRINSIC_PATH
+ # testing
+ self.test_data_root = config.DATASET.TEST_DATA_ROOT
+ self.test_pose_root = config.DATASET.TEST_POSE_ROOT # (optional)
+ self.test_npz_root = config.DATASET.TEST_NPZ_ROOT
+ self.test_list_path = config.DATASET.TEST_LIST_PATH
+ self.test_intrinsic_path = config.DATASET.TEST_INTRINSIC_PATH
+
+ # 2. dataset config
+ # general options
+ self.min_overlap_score_test = config.DATASET.MIN_OVERLAP_SCORE_TEST # 0.4, omit data with overlap_score < min_overlap_score
+ self.min_overlap_score_train = config.DATASET.MIN_OVERLAP_SCORE_TRAIN
+ self.augment_fn = build_augmentor(config.DATASET.AUGMENTATION_TYPE) # None, options: [None, 'dark', 'mobile']
+
+ # MegaDepth options
+ self.mgdpt_img_resize = config.DATASET.MGDPT_IMG_RESIZE # 840
+ self.mgdpt_img_pad = config.DATASET.MGDPT_IMG_PAD # True
+ self.mgdpt_depth_pad = config.DATASET.MGDPT_DEPTH_PAD # True
+ self.mgdpt_df = config.DATASET.MGDPT_DF # 8
+ self.coarse_scale = 1 / config.XOFTR.RESOLUTION[0] # 0.125. for training xoftr.
+
+ # VisTir options
+ self.vistir_img_resize = config.DATASET.VISTIR_IMG_RESIZE
+ self.vistir_img_pad = config.DATASET.VISTIR_IMG_PAD
+ self.vistir_df = config.DATASET.VISTIR_DF # 8
+
+ # 3.loader parameters
+ self.train_loader_params = {
+ 'batch_size': args.batch_size,
+ 'num_workers': args.num_workers,
+ 'pin_memory': getattr(args, 'pin_memory', True)
+ }
+ self.val_loader_params = {
+ 'batch_size': 1,
+ 'shuffle': False,
+ 'num_workers': args.num_workers,
+ 'pin_memory': getattr(args, 'pin_memory', True)
+ }
+ self.test_loader_params = {
+ 'batch_size': 1,
+ 'shuffle': False,
+ 'num_workers': args.num_workers,
+ 'pin_memory': True
+ }
+
+ # 4. sampler
+ self.data_sampler = config.TRAINER.DATA_SAMPLER
+ self.n_samples_per_subset = config.TRAINER.N_SAMPLES_PER_SUBSET
+ self.subset_replacement = config.TRAINER.SB_SUBSET_SAMPLE_REPLACEMENT
+ self.shuffle = config.TRAINER.SB_SUBSET_SHUFFLE
+ self.repeat = config.TRAINER.SB_REPEAT
+
+ # (optional) RandomSampler for debugging
+
+ # misc configurations
+ self.parallel_load_data = getattr(args, 'parallel_load_data', False)
+ self.seed = config.TRAINER.SEED # 66
+
+ def setup(self, stage=None):
+ """
+ Setup train / val / test dataset. This method will be called by PL automatically.
+ Args:
+ stage (str): 'fit' in training phase, and 'test' in testing phase.
+ """
+
+ assert stage in ['fit', 'test'], "stage must be either fit or test"
+
+ try:
+ self.world_size = dist.get_world_size()
+ self.rank = dist.get_rank()
+ logger.info(f"[rank:{self.rank}] world_size: {self.world_size}")
+ except AssertionError as ae:
+ self.world_size = 1
+ self.rank = 0
+            logger.warning(str(ae) + " (set world_size=1 and rank=0)")
+
+ if stage == 'fit':
+ self.train_dataset = self._setup_dataset(
+ self.train_data_root,
+ self.train_npz_root,
+ self.train_list_path,
+ self.train_intrinsic_path,
+ mode='train',
+ min_overlap_score=self.min_overlap_score_train,
+ pose_dir=self.train_pose_root)
+ # setup multiple (optional) validation subsets
+ if isinstance(self.val_list_path, (list, tuple)):
+ self.val_dataset = []
+ if not isinstance(self.val_npz_root, (list, tuple)):
+ self.val_npz_root = [self.val_npz_root for _ in range(len(self.val_list_path))]
+ for npz_list, npz_root in zip(self.val_list_path, self.val_npz_root):
+ self.val_dataset.append(self._setup_dataset(
+ self.val_data_root,
+ npz_root,
+ npz_list,
+ self.val_intrinsic_path,
+ mode='val',
+ min_overlap_score=self.min_overlap_score_test,
+ pose_dir=self.val_pose_root))
+ else:
+ self.val_dataset = self._setup_dataset(
+ self.val_data_root,
+ self.val_npz_root,
+ self.val_list_path,
+ self.val_intrinsic_path,
+ mode='val',
+ min_overlap_score=self.min_overlap_score_test,
+ pose_dir=self.val_pose_root)
+ logger.info(f'[rank:{self.rank}] Train & Val Dataset loaded!')
+        else:  # stage == 'test'
+ self.test_dataset = self._setup_dataset(
+ self.test_data_root,
+ self.test_npz_root,
+ self.test_list_path,
+ self.test_intrinsic_path,
+ mode='test',
+ min_overlap_score=self.min_overlap_score_test,
+ pose_dir=self.test_pose_root)
+ logger.info(f'[rank:{self.rank}]: Test Dataset loaded!')
+
+ def _setup_dataset(self,
+ data_root,
+ split_npz_root,
+ scene_list_path,
+ intri_path,
+ mode='train',
+ min_overlap_score=0.,
+ pose_dir=None):
+ """ Setup train / val / test set"""
+ with open(scene_list_path, 'r') as f:
+ npz_names = [name.split()[0] for name in f.readlines()]
+
+ if mode == 'train':
+ local_npz_names = get_local_split(npz_names, self.world_size, self.rank, self.seed)
+ else:
+ local_npz_names = npz_names
+ logger.info(f'[rank {self.rank}]: {len(local_npz_names)} scene(s) assigned.')
+
+ dataset_builder = self._build_concat_dataset_parallel \
+ if self.parallel_load_data \
+ else self._build_concat_dataset
+ return dataset_builder(data_root, local_npz_names, split_npz_root, intri_path,
+ mode=mode, min_overlap_score=min_overlap_score, pose_dir=pose_dir)
+
+ def _build_concat_dataset(
+ self,
+ data_root,
+ npz_names,
+ npz_dir,
+ intrinsic_path,
+ mode,
+ min_overlap_score=0.,
+ pose_dir=None
+ ):
+ datasets = []
+ augment_fn = self.augment_fn
+ if mode == 'train':
+ data_source = self.train_data_source
+ elif mode == 'val':
+ data_source = self.val_data_source
+ else:
+ data_source = self.test_data_source
+ if str(data_source).lower() == 'megadepth':
+ npz_names = [f'{n}.npz' for n in npz_names]
+ for npz_name in tqdm(npz_names,
+ desc=f'[rank:{self.rank}] loading {mode} datasets',
+ disable=int(self.rank) != 0):
+ # `ScanNetDataset`/`MegaDepthDataset` load all data from npz_path when initialized, which might take time.
+ npz_path = osp.join(npz_dir, npz_name)
+ if data_source == 'ScanNet':
+ datasets.append(
+ ScanNetDataset(data_root,
+ npz_path,
+ intrinsic_path,
+ mode=mode,
+ min_overlap_score=min_overlap_score,
+ augment_fn=augment_fn,
+ pose_dir=pose_dir))
+ elif data_source == 'MegaDepth':
+ datasets.append(
+ MegaDepthDataset(data_root,
+ npz_path,
+ mode=mode,
+ min_overlap_score=min_overlap_score,
+ img_resize=self.mgdpt_img_resize,
+ df=self.mgdpt_df,
+ img_padding=self.mgdpt_img_pad,
+ depth_padding=self.mgdpt_depth_pad,
+ augment_fn=augment_fn,
+ coarse_scale=self.coarse_scale))
+ elif data_source == 'VisTir':
+ datasets.append(
+ VisTirDataset(data_root,
+ npz_path,
+ mode=mode,
+ img_resize=self.vistir_img_resize,
+ df=self.vistir_df,
+ img_padding=self.vistir_img_pad,
+ coarse_scale=self.coarse_scale))
+ else:
+ raise NotImplementedError()
+ return ConcatDataset(datasets)
+
+ def _build_concat_dataset_parallel(
+ self,
+ data_root,
+ npz_names,
+ npz_dir,
+ intrinsic_path,
+ mode,
+ min_overlap_score=0.,
+ pose_dir=None,
+ ):
+ augment_fn = self.augment_fn
+ if mode == 'train':
+ data_source = self.train_data_source
+ elif mode == 'val':
+ data_source = self.val_data_source
+ else:
+ data_source = self.test_data_source
+ if str(data_source).lower() == 'megadepth':
+ npz_names = [f'{n}.npz' for n in npz_names]
+ with tqdm_joblib(tqdm(desc=f'[rank:{self.rank}] loading {mode} datasets',
+ total=len(npz_names), disable=int(self.rank) != 0)):
+ if data_source == 'ScanNet':
+ datasets = Parallel(n_jobs=math.floor(len(os.sched_getaffinity(0)) * 0.9 / comm.get_local_size()))(
+ delayed(lambda x: _build_dataset(
+ ScanNetDataset,
+ data_root,
+ osp.join(npz_dir, x),
+ intrinsic_path,
+ mode=mode,
+ min_overlap_score=min_overlap_score,
+ augment_fn=augment_fn,
+ pose_dir=pose_dir))(name)
+ for name in npz_names)
+ elif data_source == 'MegaDepth':
+ # TODO: _pickle.PicklingError: Could not pickle the task to send it to the workers.
+ raise NotImplementedError()
+ datasets = Parallel(n_jobs=math.floor(len(os.sched_getaffinity(0)) * 0.9 / comm.get_local_size()))(
+ delayed(lambda x: _build_dataset(
+ MegaDepthDataset,
+ data_root,
+ osp.join(npz_dir, x),
+ mode=mode,
+ min_overlap_score=min_overlap_score,
+ img_resize=self.mgdpt_img_resize,
+ df=self.mgdpt_df,
+ img_padding=self.mgdpt_img_pad,
+ depth_padding=self.mgdpt_depth_pad,
+ augment_fn=augment_fn,
+ coarse_scale=self.coarse_scale))(name)
+ for name in npz_names)
+ else:
+ raise ValueError(f'Unknown dataset: {data_source}')
+ return ConcatDataset(datasets)
+
+ def train_dataloader(self):
+ """ Build training dataloader for ScanNet / MegaDepth. """
+ assert self.data_sampler in ['scene_balance']
+ logger.info(f'[rank:{self.rank}/{self.world_size}]: Train Sampler and DataLoader re-init (should not re-init between epochs!).')
+ if self.data_sampler == 'scene_balance':
+ sampler = RandomConcatSampler(self.train_dataset,
+ self.n_samples_per_subset,
+ self.subset_replacement,
+ self.shuffle, self.repeat, self.seed)
+ else:
+ sampler = None
+ dataloader = DataLoader(self.train_dataset, sampler=sampler, **self.train_loader_params)
+ return dataloader
+
+ def val_dataloader(self):
+ """ Build validation dataloader for ScanNet / MegaDepth. """
+ logger.info(f'[rank:{self.rank}/{self.world_size}]: Val Sampler and DataLoader re-init.')
+ if not isinstance(self.val_dataset, abc.Sequence):
+ sampler = DistributedSampler(self.val_dataset, shuffle=False)
+ return DataLoader(self.val_dataset, sampler=sampler, **self.val_loader_params)
+ else:
+ dataloaders = []
+ for dataset in self.val_dataset:
+ sampler = DistributedSampler(dataset, shuffle=False)
+ dataloaders.append(DataLoader(dataset, sampler=sampler, **self.val_loader_params))
+ return dataloaders
+
+ def test_dataloader(self, *args, **kwargs):
+ logger.info(f'[rank:{self.rank}/{self.world_size}]: Test Sampler and DataLoader re-init.')
+ sampler = DistributedSampler(self.test_dataset, shuffle=False)
+ return DataLoader(self.test_dataset, sampler=sampler, **self.test_loader_params)
+
+
+def _build_dataset(dataset: Dataset, *args, **kwargs):
+ return dataset(*args, **kwargs)
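
The per-rank scene assignment done in `_setup_dataset` relies on `get_local_split`, which is not part of this diff; the rough idea, illustrated with a hypothetical re-implementation (deterministic shuffle, then one contiguous chunk per rank), is:

    import math
    import random

    def split_scenes_across_ranks(scene_names, world_size, rank, seed):
        # illustrative only; the real helper is src.utils.dataloader.get_local_split
        names = sorted(scene_names)
        random.Random(seed).shuffle(names)
        chunk = math.ceil(len(names) / world_size)
        return names[rank * chunk:(rank + 1) * chunk]

    scenes = [f'scene{i:04d}' for i in range(10)]
    for rank in range(4):
        print(rank, split_scenes_across_ranks(scenes, world_size=4, rank=rank, seed=66))

Each rank then builds a ConcatDataset only from its own scenes, which is why training uses RandomConcatSampler rather than DistributedSampler.
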
diff --git a/imcui/third_party/XoFTR/src/lightning/data_pretrain.py b/imcui/third_party/XoFTR/src/lightning/data_pretrain.py
new file mode 100644
index 0000000000000000000000000000000000000000..11eba2cc3f7918563b3e3ed2f371eb72edf422d0
--- /dev/null
+++ b/imcui/third_party/XoFTR/src/lightning/data_pretrain.py
@@ -0,0 +1,125 @@
+from collections import abc
+from loguru import logger
+
+import pytorch_lightning as pl
+from torch import distributed as dist
+from torch.utils.data import (
+ DataLoader,
+ ConcatDataset,
+ DistributedSampler
+)
+
+from src.datasets.pretrain_dataset import PretrainDataset
+
+
+class PretrainDataModule(pl.LightningDataModule):
+ """
+    For distributed training, each training process is assigned
+ only a part of the training scenes to reduce memory overhead.
+ """
+ def __init__(self, args, config):
+ super().__init__()
+
+ # 1. data config
+        # Train and Val should come from the same data source
+ self.train_data_source = config.DATASET.TRAIN_DATA_SOURCE
+ self.val_data_source = config.DATASET.VAL_DATA_SOURCE
+ # training and validating
+ self.train_data_root = config.DATASET.TRAIN_DATA_ROOT
+ self.val_data_root = config.DATASET.VAL_DATA_ROOT
+
+        # 2. dataset config
+
+ # dataset options
+ self.pretrain_img_resize = config.DATASET.PRETRAIN_IMG_RESIZE # 840
+ self.pretrain_img_pad = config.DATASET.PRETRAIN_IMG_PAD # True
+ self.pretrain_df = config.DATASET.PRETRAIN_DF # 8
+ self.coarse_scale = 1 / config.XOFTR.RESOLUTION[0] # 0.125. for training xoftr.
+ self.frame_gap = config.DATASET.PRETRAIN_FRAME_GAP
+
+ # 3.loader parameters
+ self.train_loader_params = {
+ 'batch_size': args.batch_size,
+ 'num_workers': args.num_workers,
+ 'pin_memory': getattr(args, 'pin_memory', True)
+ }
+ self.val_loader_params = {
+ 'batch_size': 1,
+ 'shuffle': False,
+ 'num_workers': args.num_workers,
+ 'pin_memory': getattr(args, 'pin_memory', True)
+ }
+
+ def setup(self, stage=None):
+ """
+ Setup train / val / test dataset. This method will be called by PL automatically.
+ Args:
+ stage (str): 'fit' in training phase, and 'test' in testing phase.
+ """
+
+ assert stage in ['fit', 'test'], "stage must be either fit or test"
+
+ try:
+ self.world_size = dist.get_world_size()
+ self.rank = dist.get_rank()
+ logger.info(f"[rank:{self.rank}] world_size: {self.world_size}")
+ except AssertionError as ae:
+ self.world_size = 1
+ self.rank = 0
+            logger.warning(str(ae) + " (set world_size=1 and rank=0)")
+
+ if stage == 'fit':
+ self.train_dataset = self._setup_dataset(
+ self.train_data_root,
+ mode='train')
+ # setup multiple (optional) validation subsets
+ self.val_dataset = []
+ self.val_dataset.append(self._setup_dataset(
+ self.val_data_root,
+ mode='val'))
+ logger.info(f'[rank:{self.rank}] Train & Val Dataset loaded!')
+        else:  # stage == 'test'
+            raise ValueError("only the 'fit' stage is implemented")
+
+ def _setup_dataset(self,
+ data_root,
+ mode='train'):
+ """ Setup train / val / test set"""
+
+ dataset_builder = self._build_concat_dataset
+ return dataset_builder(data_root, mode=mode)
+
+ def _build_concat_dataset(
+ self,
+ data_root,
+ mode
+ ):
+ datasets = []
+
+ datasets.append(
+ PretrainDataset(data_root,
+ mode=mode,
+ img_resize=self.pretrain_img_resize,
+ df=self.pretrain_df,
+ img_padding=self.pretrain_img_pad,
+ coarse_scale=self.coarse_scale,
+ frame_gap=self.frame_gap))
+
+ return ConcatDataset(datasets)
+
+ def train_dataloader(self):
+        """ Build training dataloader for the KAIST dataset. """
+ sampler = DistributedSampler(self.train_dataset, shuffle=True)
+ dataloader = DataLoader(self.train_dataset, sampler=sampler, **self.train_loader_params)
+ return dataloader
+
+ def val_dataloader(self):
+        """ Build validation dataloader for the KAIST dataset. """
+        if not isinstance(self.val_dataset, abc.Sequence):
+            sampler = DistributedSampler(self.val_dataset, shuffle=False)
+            return DataLoader(self.val_dataset, sampler=sampler, **self.val_loader_params)
+ else:
+ dataloaders = []
+ for dataset in self.val_dataset:
+ sampler = DistributedSampler(dataset, shuffle=False)
+ dataloaders.append(DataLoader(dataset, sampler=sampler, **self.val_loader_params))
+ return dataloaders
diff --git a/imcui/third_party/XoFTR/src/lightning/lightning_xoftr.py b/imcui/third_party/XoFTR/src/lightning/lightning_xoftr.py
new file mode 100644
index 0000000000000000000000000000000000000000..16e6758330a4ac33ea2bbe794d7141756abd61e3
--- /dev/null
+++ b/imcui/third_party/XoFTR/src/lightning/lightning_xoftr.py
@@ -0,0 +1,334 @@
+
+from collections import defaultdict
+import pprint
+from loguru import logger
+from pathlib import Path
+
+import torch
+import numpy as np
+import pytorch_lightning as pl
+from matplotlib import pyplot as plt
+plt.switch_backend('agg')
+
+from src.xoftr import XoFTR
+from src.xoftr.utils.supervision import compute_supervision_coarse, compute_supervision_fine
+from src.losses.xoftr_loss import XoFTRLoss
+from src.optimizers import build_optimizer, build_scheduler
+from src.utils.metrics import (
+ compute_symmetrical_epipolar_errors,
+ compute_pose_errors,
+ aggregate_metrics
+)
+from src.utils.plotting import make_matching_figures
+from src.utils.comm import gather, all_gather
+from src.utils.misc import lower_config, flattenList
+from src.utils.profiler import PassThroughProfiler
+
+
+class PL_XoFTR(pl.LightningModule):
+ def __init__(self, config, pretrained_ckpt=None, profiler=None, dump_dir=None):
+ """
+ TODO:
+ - use the new version of PL logging API.
+ """
+ super().__init__()
+ # Misc
+ self.config = config # full config
+ _config = lower_config(self.config)
+ self.xoftr_cfg = lower_config(_config['xoftr'])
+ self.profiler = profiler or PassThroughProfiler()
+ self.n_vals_plot = max(config.TRAINER.N_VAL_PAIRS_TO_PLOT // config.TRAINER.WORLD_SIZE, 1)
+
+ # Matcher: XoFTR
+ self.matcher = XoFTR(config=_config['xoftr'])
+ self.loss = XoFTRLoss(_config)
+
+ # Pretrained weights
+ if pretrained_ckpt:
+ state_dict = torch.load(pretrained_ckpt, map_location='cpu')['state_dict']
+ self.matcher.load_state_dict(state_dict, strict=False)
+ logger.info(f"Load \'{pretrained_ckpt}\' as pretrained checkpoint")
+ for name, param in self.matcher.named_parameters():
+ if name in state_dict.keys():
+ print("in ckpt: ", name)
+ else:
+ print("out ckpt: ", name)
+
+ # Testing
+ self.dump_dir = dump_dir
+
+ def configure_optimizers(self):
+ # FIXME: The scheduler did not work properly when `--resume_from_checkpoint`
+ optimizer = build_optimizer(self, self.config)
+ scheduler = build_scheduler(self.config, optimizer)
+ return [optimizer], [scheduler]
+
+ def optimizer_step(
+ self, epoch, batch_idx, optimizer, optimizer_idx,
+ optimizer_closure, on_tpu, using_native_amp, using_lbfgs):
+ # learning rate warm up
+ warmup_step = self.config.TRAINER.WARMUP_STEP
+ if self.trainer.global_step < warmup_step:
+ if self.config.TRAINER.WARMUP_TYPE == 'linear':
+ base_lr = self.config.TRAINER.WARMUP_RATIO * self.config.TRAINER.TRUE_LR
+ lr = base_lr + \
+ (self.trainer.global_step / self.config.TRAINER.WARMUP_STEP) * \
+ abs(self.config.TRAINER.TRUE_LR - base_lr)
+ for pg in optimizer.param_groups:
+ pg['lr'] = lr
+ elif self.config.TRAINER.WARMUP_TYPE == 'constant':
+ pass
+ else:
+ raise ValueError(f'Unknown lr warm-up strategy: {self.config.TRAINER.WARMUP_TYPE}')
+
+ # update params
+ optimizer.step(closure=optimizer_closure)
+ optimizer.zero_grad()
+
+ def _trainval_inference(self, batch):
+ with self.profiler.profile("Compute coarse supervision"):
+ compute_supervision_coarse(batch, self.config)
+
+ with self.profiler.profile("XoFTR"):
+ self.matcher(batch)
+
+ with self.profiler.profile("Compute fine supervision"):
+ compute_supervision_fine(batch, self.config)
+
+ with self.profiler.profile("Compute losses"):
+ self.loss(batch)
+
+ def _compute_metrics(self, batch):
+        with self.profiler.profile("Compute metrics"):
+ compute_symmetrical_epipolar_errors(batch) # compute epi_errs for each match
+ compute_pose_errors(batch, self.config) # compute R_errs, t_errs, pose_errs for each pair
+
+ rel_pair_names = list(zip(*batch['pair_names']))
+ bs = batch['image0'].size(0)
+ metrics = {
+ # to filter duplicate pairs caused by DistributedSampler
+ 'identifiers': ['#'.join(rel_pair_names[b]) for b in range(bs)],
+ 'epi_errs': [batch['epi_errs'][batch['m_bids'] == b].cpu().numpy() for b in range(bs)],
+ 'R_errs': batch['R_errs'],
+ 't_errs': batch['t_errs'],
+ 'inliers': batch['inliers']}
+ if self.config.DATASET.VAL_DATA_SOURCE == "VisTir":
+ metrics.update({'scene_id': batch['scene_id']})
+ ret_dict = {'metrics': metrics}
+ return ret_dict, rel_pair_names
+
+ def training_step(self, batch, batch_idx):
+ self._trainval_inference(batch)
+
+ # logging
+ if self.trainer.global_rank == 0 and self.global_step % self.trainer.log_every_n_steps == 0:
+ # scalars
+ for k, v in batch['loss_scalars'].items():
+ self.logger[0].experiment.add_scalar(f'train/{k}', v, self.global_step)
+ if self.config.TRAINER.USE_WANDB:
+ self.logger[1].log_metrics({f'train/{k}': v}, self.global_step)
+
+ # figures
+ if self.config.TRAINER.ENABLE_PLOTTING:
+ compute_symmetrical_epipolar_errors(batch) # compute epi_errs for each match
+ figures = make_matching_figures(batch, self.config, self.config.TRAINER.PLOT_MODE)
+ for k, v in figures.items():
+ self.logger[0].experiment.add_figure(f'train_match/{k}', v, self.global_step)
+
+ return {'loss': batch['loss']}
+
+ def training_epoch_end(self, outputs):
+ avg_loss = torch.stack([x['loss'] for x in outputs]).mean()
+ if self.trainer.global_rank == 0:
+ self.logger[0].experiment.add_scalar(
+ 'train/avg_loss_on_epoch', avg_loss,
+ global_step=self.current_epoch)
+ if self.config.TRAINER.USE_WANDB:
+ self.logger[1].log_metrics(
+ {'train/avg_loss_on_epoch': avg_loss},
+ self.current_epoch)
+
+ def validation_step(self, batch, batch_idx):
+ # no loss calculation for VisTir during val
+ if self.config.DATASET.VAL_DATA_SOURCE == "VisTir":
+ with self.profiler.profile("XoFTR"):
+ self.matcher(batch)
+ else:
+ self._trainval_inference(batch)
+
+ ret_dict, _ = self._compute_metrics(batch)
+
+ val_plot_interval = max(self.trainer.num_val_batches[0] // self.n_vals_plot, 1)
+ figures = {self.config.TRAINER.PLOT_MODE: []}
+ if batch_idx % val_plot_interval == 0:
+ figures = make_matching_figures(batch, self.config, mode=self.config.TRAINER.PLOT_MODE, ret_dict=ret_dict)
+ if self.config.DATASET.VAL_DATA_SOURCE == "VisTir":
+ return {
+ **ret_dict,
+ 'figures': figures,
+ }
+ else:
+ return {
+ **ret_dict,
+ 'loss_scalars': batch['loss_scalars'],
+ 'figures': figures,
+ }
+
+ def validation_epoch_end(self, outputs):
+ # handle multiple validation sets
+ multi_outputs = [outputs] if not isinstance(outputs[0], (list, tuple)) else outputs
+ multi_val_metrics = defaultdict(list)
+
+ for valset_idx, outputs in enumerate(multi_outputs):
+            # since pl performs sanity_check at the very beginning of the training
+ cur_epoch = self.trainer.current_epoch
+ if not self.trainer.resume_from_checkpoint and self.trainer.running_sanity_check:
+ cur_epoch = -1
+
+ if self.config.DATASET.VAL_DATA_SOURCE == "VisTir":
+ metrics_per_scene = {}
+ for o in outputs:
+ if not o['metrics']['scene_id'][0] in metrics_per_scene.keys():
+ metrics_per_scene[o['metrics']['scene_id'][0]] = []
+ metrics_per_scene[o['metrics']['scene_id'][0]].append(o['metrics'])
+
+ aucs_per_scene = {}
+ for scene_id in metrics_per_scene.keys():
+ # 2. val metrics: dict of list, numpy
+ _metrics = metrics_per_scene[scene_id]
+ metrics = {k: flattenList(all_gather(flattenList([_me[k] for _me in _metrics]))) for k in _metrics[0]}
+                    # NOTE: all ranks need to `aggregate_metrics`, but only log at rank-0
+ val_metrics = aggregate_metrics(metrics, self.config.TRAINER.EPI_ERR_THR)
+ aucs_per_scene[scene_id] = val_metrics
+
+ # average the metrics of scenes
+                # average the metrics over the scenes,
+                # since the number of images in each scene is different
+ for thr in [5, 10, 20]:
+ temp = []
+ for scene_id in metrics_per_scene.keys():
+ temp.append(aucs_per_scene[scene_id][f'auc@{thr}'])
+ val_metrics_4tb[f'auc@{thr}'] = float(np.array(temp, dtype=float).mean())
+ temp = []
+ for scene_id in metrics_per_scene.keys():
+ temp.append(aucs_per_scene[scene_id][f'prec@{self.config.TRAINER.EPI_ERR_THR:.0e}'])
+ val_metrics_4tb[f'prec@{self.config.TRAINER.EPI_ERR_THR:.0e}'] = float(np.array(temp, dtype=float).mean())
+ else:
+ # 1. loss_scalars: dict of list, on cpu
+ _loss_scalars = [o['loss_scalars'] for o in outputs]
+ loss_scalars = {k: flattenList(all_gather([_ls[k] for _ls in _loss_scalars])) for k in _loss_scalars[0]}
+
+ # 2. val metrics: dict of list, numpy
+ _metrics = [o['metrics'] for o in outputs]
+ metrics = {k: flattenList(all_gather(flattenList([_me[k] for _me in _metrics]))) for k in _metrics[0]}
+                # NOTE: all ranks need to `aggregate_metrics`, but only log at rank-0
+ val_metrics_4tb = aggregate_metrics(metrics, self.config.TRAINER.EPI_ERR_THR)
+
+ for thr in [5, 10, 20]:
+ multi_val_metrics[f'auc@{thr}'].append(val_metrics_4tb[f'auc@{thr}'])
+
+ # 3. figures
+ _figures = [o['figures'] for o in outputs]
+ figures = {k: flattenList(gather(flattenList([_me[k] for _me in _figures]))) for k in _figures[0]}
+
+ # tensorboard records only on rank 0
+ if self.trainer.global_rank == 0:
+ if self.config.DATASET.VAL_DATA_SOURCE != "VisTir":
+ for k, v in loss_scalars.items():
+ mean_v = torch.stack(v).mean()
+                        self.logger[0].experiment.add_scalar(f'val_{valset_idx}/avg_{k}', mean_v, global_step=cur_epoch)
+
+ for k, v in val_metrics_4tb.items():
+ self.logger[0].experiment.add_scalar(f"metrics_{valset_idx}/{k}", v, global_step=cur_epoch)
+ if self.config.TRAINER.USE_WANDB:
+ self.logger[1].log_metrics({f"metrics_{valset_idx}/{k}": v}, cur_epoch)
+
+ for k, v in figures.items():
+ if self.trainer.global_rank == 0:
+ for plot_idx, fig in enumerate(v):
+ self.logger[0].experiment.add_figure(
+ f'val_match_{valset_idx}/{k}/pair-{plot_idx}', fig, cur_epoch, close=True)
+ plt.close('all')
+
+ for thr in [5, 10, 20]:
+ # log on all ranks for ModelCheckpoint callback to work properly
+ self.log(f'auc@{thr}', torch.tensor(np.mean(multi_val_metrics[f'auc@{thr}']))) # ckpt monitors on this
+
+ def test_step(self, batch, batch_idx):
+ with self.profiler.profile("XoFTR"):
+ self.matcher(batch)
+
+ ret_dict, rel_pair_names = self._compute_metrics(batch)
+
+ with self.profiler.profile("dump_results"):
+ if self.dump_dir is not None:
+ # dump results for further analysis
+ keys_to_save = {'mkpts0_f', 'mkpts1_f', 'mconf_f', 'epi_errs'}
+ pair_names = list(zip(*batch['pair_names']))
+ bs = batch['image0'].shape[0]
+ dumps = []
+ for b_id in range(bs):
+ item = {}
+ mask = batch['m_bids'] == b_id
+ item['pair_names'] = pair_names[b_id]
+ item['identifier'] = '#'.join(rel_pair_names[b_id])
+ if self.config.DATASET.TEST_DATA_SOURCE == "VisTir":
+ item['scene_id'] = batch['scene_id']
+ item['K0'] = batch['K0'][b_id].cpu().numpy()
+ item['K1'] = batch['K1'][b_id].cpu().numpy()
+ item['dist0'] = batch['dist0'][b_id].cpu().numpy()
+ item['dist1'] = batch['dist1'][b_id].cpu().numpy()
+ for key in keys_to_save:
+ item[key] = batch[key][mask].cpu().numpy()
+ for key in ['R_errs', 't_errs', 'inliers']:
+ item[key] = batch[key][b_id]
+ dumps.append(item)
+ ret_dict['dumps'] = dumps
+
+ return ret_dict
+
+ def test_epoch_end(self, outputs):
+
+ if self.config.DATASET.TEST_DATA_SOURCE == "VisTir":
+ # metrics: dict of list, numpy
+ metrics_per_scene = {}
+ for o in outputs:
+ if not o['metrics']['scene_id'][0] in metrics_per_scene.keys():
+ metrics_per_scene[o['metrics']['scene_id'][0]] = []
+ metrics_per_scene[o['metrics']['scene_id'][0]].append(o['metrics'])
+
+ aucs_per_scene = {}
+ for scene_id in metrics_per_scene.keys():
+ # 2. val metrics: dict of list, numpy
+ _metrics = metrics_per_scene[scene_id]
+ metrics = {k: flattenList(all_gather(flattenList([_me[k] for _me in _metrics]))) for k in _metrics[0]}
+                # NOTE: all ranks need to `aggregate_metrics`, but only log at rank-0
+ val_metrics = aggregate_metrics(metrics, self.config.TRAINER.EPI_ERR_THR)
+ aucs_per_scene[scene_id] = val_metrics
+
+            # average the metrics over the scenes,
+            # since the number of images in each scene is different
+ val_metrics_4tb = {}
+ for thr in [5, 10, 20]:
+ temp = []
+ for scene_id in metrics_per_scene.keys():
+ temp.append(aucs_per_scene[scene_id][f'auc@{thr}'])
+ val_metrics_4tb[f'auc@{thr}'] = np.array(temp, dtype=float).mean()
+ else:
+ # metrics: dict of list, numpy
+ _metrics = [o['metrics'] for o in outputs]
+ metrics = {k: flattenList(gather(flattenList([_me[k] for _me in _metrics]))) for k in _metrics[0]}
+
+ # [{key: [{...}, *#bs]}, *#batch]
+ if self.dump_dir is not None:
+ Path(self.dump_dir).mkdir(parents=True, exist_ok=True)
+ _dumps = flattenList([o['dumps'] for o in outputs]) # [{...}, #bs*#batch]
+ dumps = flattenList(gather(_dumps)) # [{...}, #proc*#bs*#batch]
+ logger.info(f'Prediction and evaluation results will be saved to: {self.dump_dir}')
+
+ if self.trainer.global_rank == 0:
+ print(self.profiler.summary())
+ val_metrics_4tb = aggregate_metrics(metrics, self.config.TRAINER.EPI_ERR_THR)
+ logger.info('\n' + pprint.pformat(val_metrics_4tb))
+ if self.dump_dir is not None:
+ np.save(Path(self.dump_dir) / 'XoFTR_pred_eval', dumps)
\ No newline at end of file
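
The linear warm-up in `optimizer_step` above interpolates the learning rate from `WARMUP_RATIO * TRUE_LR` up to `TRUE_LR` over the first `WARMUP_STEP` steps; a standalone sketch of that schedule, with made-up values for the three config entries:

    def warmup_lr(step, true_lr=8e-3, warmup_ratio=0.1, warmup_step=4800):
        # linear warm-up as in PL_XoFTR.optimizer_step; numeric defaults are placeholders
        if step >= warmup_step:
            return true_lr
        base_lr = warmup_ratio * true_lr
        return base_lr + (step / warmup_step) * abs(true_lr - base_lr)

    for step in (0, 1200, 2400, 4800):
        print(step, f"{warmup_lr(step):.2e}")  # 8.00e-04, 2.60e-03, 4.40e-03, 8.00e-03
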
diff --git a/imcui/third_party/XoFTR/src/lightning/lightning_xoftr_pretrain.py b/imcui/third_party/XoFTR/src/lightning/lightning_xoftr_pretrain.py
new file mode 100644
index 0000000000000000000000000000000000000000..2dc57d4205390cd1439edc2f38e3b219bad3db35
--- /dev/null
+++ b/imcui/third_party/XoFTR/src/lightning/lightning_xoftr_pretrain.py
@@ -0,0 +1,171 @@
+
+from loguru import logger
+
+import torch
+import pytorch_lightning as pl
+from matplotlib import pyplot as plt
+plt.switch_backend('agg')
+
+from src.xoftr import XoFTR_Pretrain
+from src.losses.xoftr_loss_pretrain import XoFTRLossPretrain
+from src.optimizers import build_optimizer, build_scheduler
+from src.utils.plotting import make_mae_figures
+from src.utils.comm import all_gather
+from src.utils.misc import lower_config, flattenList
+from src.utils.profiler import PassThroughProfiler
+from src.utils.pretrain_utils import generate_random_masks, get_target
+
+
+class PL_XoFTR_Pretrain(pl.LightningModule):
+ def __init__(self, config, pretrained_ckpt=None, profiler=None, dump_dir=None):
+ """
+ TODO:
+ - use the new version of PL logging API.
+ """
+ super().__init__()
+ # Misc
+ self.config = config # full config
+
+ _config = lower_config(self.config)
+ self.xoftr_cfg = lower_config(_config['xoftr'])
+ self.profiler = profiler or PassThroughProfiler()
+ self.n_vals_plot = max(config.TRAINER.N_VAL_PAIRS_TO_PLOT // config.TRAINER.WORLD_SIZE, 1)
+
+ # generator to create the same masks for validation
+ self.val_seed = self.config.PRETRAIN.VAL_SEED
+ self.val_generator = torch.Generator(device="cuda").manual_seed(self.val_seed)
+ self.mae_margins = config.PRETRAIN.MAE_MARGINS
+
+ # Matcher: XoFTR
+ self.matcher = XoFTR_Pretrain(config=_config['xoftr'])
+ self.loss = XoFTRLossPretrain(_config)
+
+ # Pretrained weights
+ if pretrained_ckpt:
+ state_dict = torch.load(pretrained_ckpt, map_location='cpu')['state_dict']
+ self.matcher.load_state_dict(state_dict, strict=False)
+ logger.info(f"Load \'{pretrained_ckpt}\' as pretrained checkpoint")
+
+ # Testing
+ self.dump_dir = dump_dir
+
+ def configure_optimizers(self):
+ # FIXME: The scheduler did not work properly when `--resume_from_checkpoint`
+ optimizer = build_optimizer(self, self.config)
+ scheduler = build_scheduler(self.config, optimizer)
+ return [optimizer], [scheduler]
+
+ def optimizer_step(
+ self, epoch, batch_idx, optimizer, optimizer_idx,
+ optimizer_closure, on_tpu, using_native_amp, using_lbfgs):
+ # learning rate warm up
+ warmup_step = self.config.TRAINER.WARMUP_STEP
+ if self.trainer.global_step < warmup_step:
+ if self.config.TRAINER.WARMUP_TYPE == 'linear':
+ base_lr = self.config.TRAINER.WARMUP_RATIO * self.config.TRAINER.TRUE_LR
+ lr = base_lr + \
+ (self.trainer.global_step / self.config.TRAINER.WARMUP_STEP) * \
+ abs(self.config.TRAINER.TRUE_LR - base_lr)
+ for pg in optimizer.param_groups:
+ pg['lr'] = lr
+ elif self.config.TRAINER.WARMUP_TYPE == 'constant':
+ pass
+ else:
+ raise ValueError(f'Unknown lr warm-up strategy: {self.config.TRAINER.WARMUP_TYPE}')
+
+ # update params
+ optimizer.step(closure=optimizer_closure)
+ optimizer.zero_grad()
+
+ def _trainval_inference(self, batch, generator=None):
+ generate_random_masks(batch,
+ patch_size=self.config.PRETRAIN.PATCH_SIZE,
+ mask_ratio=self.config.PRETRAIN.MASK_RATIO,
+ generator=generator,
+ margins=self.mae_margins)
+
+ with self.profiler.profile("XoFTR"):
+ self.matcher(batch)
+
+ with self.profiler.profile("Compute losses"):
+            # Create target patches to reconstruct
+ get_target(batch)
+ self.loss(batch)
+
+ def training_step(self, batch, batch_idx):
+ self._trainval_inference(batch)
+
+ # logging
+ if self.trainer.global_rank == 0 and self.global_step % self.trainer.log_every_n_steps == 0:
+ # scalars
+ for k, v in batch['loss_scalars'].items():
+ self.logger[0].experiment.add_scalar(f'train/{k}', v, self.global_step)
+ if self.config.TRAINER.USE_WANDB:
+ self.logger[1].log_metrics({f'train/{k}': v}, self.global_step)
+
+ if self.config.TRAINER.ENABLE_PLOTTING:
+ figures = make_mae_figures(batch)
+ for i, figure in enumerate(figures):
+ self.logger[0].experiment.add_figure(
+ f'train_mae/node_{self.trainer.global_rank}-device_{self.device.index}-batch_{i}',
+ figure, self.global_step)
+
+ return {'loss': batch['loss']}
+
+ def training_epoch_end(self, outputs):
+ avg_loss = torch.stack([x['loss'] for x in outputs]).mean()
+ if self.trainer.global_rank == 0:
+ self.logger[0].experiment.add_scalar(
+ 'train/avg_loss_on_epoch', avg_loss,
+ global_step=self.current_epoch)
+ if self.config.TRAINER.USE_WANDB:
+ self.logger[1].log_metrics(
+ {'train/avg_loss_on_epoch': avg_loss},
+ self.current_epoch)
+
+ def validation_step(self, batch, batch_idx):
+ self._trainval_inference(batch, self.val_generator)
+
+ val_plot_interval = max(self.trainer.num_val_batches[0] // \
+ (self.trainer.num_gpus * self.n_vals_plot), 1)
+ figures = []
+ if batch_idx % val_plot_interval == 0:
+ figures = make_mae_figures(batch)
+
+ return {
+ 'loss_scalars': batch['loss_scalars'],
+ 'figures': figures,
+ }
+
+ def validation_epoch_end(self, outputs):
+ self.val_generator.manual_seed(self.val_seed)
+ # handle multiple validation sets
+ multi_outputs = [outputs] if not isinstance(outputs[0], (list, tuple)) else outputs
+
+ for valset_idx, outputs in enumerate(multi_outputs):
+            # since pl performs sanity_check at the very beginning of the training
+ cur_epoch = self.trainer.current_epoch
+ if not self.trainer.resume_from_checkpoint and self.trainer.running_sanity_check:
+ cur_epoch = -1
+
+ # 1. loss_scalars: dict of list, on cpu
+ _loss_scalars = [o['loss_scalars'] for o in outputs]
+ loss_scalars = {k: flattenList(all_gather([_ls[k] for _ls in _loss_scalars])) for k in _loss_scalars[0]}
+
+ _figures = [o['figures'] for o in outputs]
+ figures = [item for sublist in _figures for item in sublist]
+
+ # tensorboard records only on rank 0
+ if self.trainer.global_rank == 0:
+ for k, v in loss_scalars.items():
+ mean_v = torch.stack(v).mean()
+ self.logger[0].experiment.add_scalar(f'val_{valset_idx}/avg_{k}', mean_v, global_step=cur_epoch)
+ if self.config.TRAINER.USE_WANDB:
+ self.logger[1].log_metrics({f'val_{valset_idx}/avg_{k}': mean_v}, cur_epoch)
+
+ for plot_idx, fig in enumerate(figures):
+ self.logger[0].experiment.add_figure(
+ f'val_mae_{valset_idx}/pair-{plot_idx}', fig, cur_epoch, close=True)
+
+ plt.close('all')
+
diff --git a/imcui/third_party/XoFTR/src/losses/xoftr_loss.py b/imcui/third_party/XoFTR/src/losses/xoftr_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..ac7ac9c1546a09fbd909ea11e3e597e03d92d56e
--- /dev/null
+++ b/imcui/third_party/XoFTR/src/losses/xoftr_loss.py
@@ -0,0 +1,170 @@
+from loguru import logger
+
+import torch
+import torch.nn as nn
+from kornia.geometry.conversions import convert_points_to_homogeneous
+from kornia.geometry.epipolar import numeric
+
+class XoFTRLoss(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config # config under the global namespace
+ self.loss_config = config['xoftr']['loss']
+ self.pos_w = self.loss_config['pos_weight']
+ self.neg_w = self.loss_config['neg_weight']
+
+
+ def compute_fine_matching_loss(self, data):
+ """ Point-wise Focal Loss with 0 / 1 confidence as gt.
+ Args:
+ data (dict): {
+ conf_matrix_fine (torch.Tensor): (N, W_f^2, W_f^2)
+ conf_matrix_f_gt (torch.Tensor): (N, W_f^2, W_f^2)
+ }
+ """
+ conf_matrix_fine = data['conf_matrix_fine']
+ conf_matrix_f_gt = data['conf_matrix_f_gt']
+ pos_mask, neg_mask = conf_matrix_f_gt > 0, conf_matrix_f_gt == 0
+ pos_w, neg_w = self.pos_w, self.neg_w
+
+ if not pos_mask.any(): # assign a wrong gt
+ pos_mask[0, 0, 0] = True
+ pos_w = 0.
+ if not neg_mask.any():
+ neg_mask[0, 0, 0] = True
+ neg_w = 0.
+
+ conf_matrix_fine = torch.clamp(conf_matrix_fine, 1e-6, 1-1e-6)
+ alpha = self.loss_config['focal_alpha']
+ gamma = self.loss_config['focal_gamma']
+
+ loss_pos = - alpha * torch.pow(1 - conf_matrix_fine[pos_mask], gamma) * (conf_matrix_fine[pos_mask]).log()
+ # loss_pos *= conf_matrix_f_gt[pos_mask]
+ loss_neg = - alpha * torch.pow(conf_matrix_fine[neg_mask], gamma) * (1 - conf_matrix_fine[neg_mask]).log()
+
+ return pos_w * loss_pos.mean() + neg_w * loss_neg.mean()
+
+ def _symmetric_epipolar_distance(self, pts0, pts1, E, K0, K1):
+ """Squared symmetric epipolar distance.
+ This can be seen as a biased estimation of the reprojection error.
+ Args:
+ pts0 (torch.Tensor): [N, 2]
+ E (torch.Tensor): [3, 3]
+ """
+ pts0 = (pts0 - K0[:, [0, 1], [2, 2]]) / K0[:, [0, 1], [0, 1]]
+ pts1 = (pts1 - K1[:, [0, 1], [2, 2]]) / K1[:, [0, 1], [0, 1]]
+ pts0 = convert_points_to_homogeneous(pts0)
+ pts1 = convert_points_to_homogeneous(pts1)
+
+ Ep0 = (pts0[:,None,:] @ E.transpose(-2,-1)).squeeze(1) # [N, 3]
+ p1Ep0 = torch.sum(pts1 * Ep0, -1) # [N,]
+ Etp1 = (pts1[:,None,:] @ E).squeeze(1) # [N, 3]
+
+ d = p1Ep0**2 * (1.0 / (Ep0[:, 0]**2 + Ep0[:, 1]**2 + 1e-9) + 1.0 / (Etp1[:, 0]**2 + Etp1[:, 1]**2 + 1e-9)) # N
+ return d
+
+ def compute_sub_pixel_loss(self, data):
+ """ symmetric epipolar distance loss.
+ Args:
+ data (dict): {
+ m_bids (torch.Tensor): (N)
+ T_0to1 (torch.Tensor): (B, 4, 4)
+ mkpts0_f_train (torch.Tensor): (N, 2)
+ mkpts1_f_train (torch.Tensor): (N, 2)
+ }
+ """
+
+ Tx = numeric.cross_product_matrix(data['T_0to1'][:, :3, 3])
+ E_mat = Tx @ data['T_0to1'][:, :3, :3]
+
+ m_bids = data['m_bids']
+ pts0 = data['mkpts0_f_train']
+ pts1 = data['mkpts1_f_train']
+
+ sym_dist = self._symmetric_epipolar_distance(pts0, pts1, E_mat[m_bids], data['K0'][m_bids], data['K1'][m_bids])
+ # filter matches with high epipolar error (only train approximately correct fine-level matches)
+ loss = sym_dist[sym_dist<1e-4]
+ if len(loss) == 0:
+ return torch.zeros(1, device=loss.device, requires_grad=False)[0]
+ return loss.mean()
+
+ def compute_coarse_loss(self, data, weight=None):
+ """ Point-wise CE / Focal Loss with 0 / 1 confidence as gt.
+ Args:
+ data (dict): {
+ conf_matrix_0_to_1 (torch.Tensor): (N, HW0, HW1)
+ conf_matrix_1_to_0 (torch.Tensor): (N, HW0, HW1)
+ conf_gt (torch.Tensor): (N, HW0, HW1)
+ }
+ weight (torch.Tensor): (N, HW0, HW1)
+ """
+
+ conf_matrix_0_to_1 = data["conf_matrix_0_to_1"]
+ conf_matrix_1_to_0 = data["conf_matrix_1_to_0"]
+ conf_gt = data["conf_matrix_gt"]
+
+ pos_mask = conf_gt == 1
+ c_pos_w = self.pos_w
+ # corner case: no gt coarse-level match at all
+ if not pos_mask.any(): # assign a wrong gt
+ pos_mask[0, 0, 0] = True
+ if weight is not None:
+ weight[0, 0, 0] = 0.
+ c_pos_w = 0.
+
+ conf_matrix_0_to_1 = torch.clamp(conf_matrix_0_to_1, 1e-6, 1-1e-6)
+ conf_matrix_1_to_0 = torch.clamp(conf_matrix_1_to_0, 1e-6, 1-1e-6)
+ alpha = self.loss_config['focal_alpha']
+ gamma = self.loss_config['focal_gamma']
+
+ loss_pos = - alpha * torch.pow(1 - conf_matrix_0_to_1[pos_mask], gamma) * (conf_matrix_0_to_1[pos_mask]).log()
+ loss_pos += - alpha * torch.pow(1 - conf_matrix_1_to_0[pos_mask], gamma) * (conf_matrix_1_to_0[pos_mask]).log()
+ if weight is not None:
+ loss_pos = loss_pos * weight[pos_mask]
+
+ loss_c = (c_pos_w * loss_pos.mean())
+
+ return loss_c
+
+ @torch.no_grad()
+ def compute_c_weight(self, data):
+ """ compute element-wise weights for computing coarse-level loss. """
+ if 'mask0' in data:
+ c_weight = (data['mask0'].flatten(-2)[..., None] * data['mask1'].flatten(-2)[:, None]).float()
+ else:
+ c_weight = None
+ return c_weight
+
+ def forward(self, data):
+ """
+ Update:
+ data (dict): update{
+ 'loss': [1] the reduced loss across a batch,
+ 'loss_scalars' (dict): loss scalars for tensorboard_record
+ }
+ """
+ loss_scalars = {}
+ # 0. compute element-wise loss weight
+ c_weight = self.compute_c_weight(data)
+
+ # 1. coarse-level loss
+ loss_c = self.compute_coarse_loss(data, weight=c_weight)
+ loss_c *= self.loss_config['coarse_weight']
+ loss = loss_c
+ loss_scalars.update({"loss_c": loss_c.clone().detach().cpu()})
+
+ # 2. fine-level matching loss for windows
+ loss_f_match = self.compute_fine_matching_loss(data)
+ loss_f_match *= self.loss_config['fine_weight']
+ loss = loss + loss_f_match
+ loss_scalars.update({"loss_f": loss_f_match.clone().detach().cpu()})
+
+ # 3. sub-pixel refinement loss
+ loss_sub = self.compute_sub_pixel_loss(data)
+ loss_sub *= self.loss_config['sub_weight']
+ loss = loss + loss_sub
+ loss_scalars.update({"loss_sub": loss_sub.clone().detach().cpu()})
+
+
+ loss_scalars.update({'loss': loss.clone().detach().cpu()})
+ data.update({"loss": loss, "loss_scalars": loss_scalars})
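
A toy, self-contained version of the focal-style weighting used in `compute_fine_matching_loss` above, on a random 1x4x4 confidence matrix; the alpha/gamma values are placeholders for the config's `focal_alpha` / `focal_gamma`, and both class weights are set to 1:

    import torch

    torch.manual_seed(0)
    conf = torch.rand(1, 4, 4).clamp(1e-6, 1 - 1e-6)  # predicted matching confidences
    gt = torch.zeros(1, 4, 4)
    gt[0, 0, 0] = gt[0, 1, 2] = 1.0                   # two ground-truth matches

    pos_mask, neg_mask = gt > 0, gt == 0
    alpha, gamma = 0.25, 2.0                          # placeholder focal parameters

    loss_pos = -alpha * (1 - conf[pos_mask]) ** gamma * conf[pos_mask].log()
    loss_neg = -alpha * conf[neg_mask] ** gamma * (1 - conf[neg_mask]).log()
    loss = loss_pos.mean() + loss_neg.mean()
    print(loss.item())
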
diff --git a/imcui/third_party/XoFTR/src/losses/xoftr_loss_pretrain.py b/imcui/third_party/XoFTR/src/losses/xoftr_loss_pretrain.py
new file mode 100644
index 0000000000000000000000000000000000000000..d00e1781b0bd105414b59697418a33b27d178b4e
--- /dev/null
+++ b/imcui/third_party/XoFTR/src/losses/xoftr_loss_pretrain.py
@@ -0,0 +1,37 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+class XoFTRLossPretrain(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config # config under the global namespace
+ self.W_f = config["xoftr"]['fine_window_size']
+
+ def forward(self, data):
+ """
+ Update:
+ data (dict): update{
+ 'loss': [1] the reduced loss across a batch,
+ 'loss_scalars' (dict): loss scalars for tensorboard_record
+ }
+ """
+ loss_scalars = {}
+
+ pred0, pred1 = data["pred0"], data["pred1"]
+ target0, target1 = data["target0"], data["target1"]
+ target0 = target0[[data['b_ids'], data['i_ids']]]
+ target1 = target1[[data['b_ids'], data['j_ids']]]
+
+ # get correct indices
+ pred0 = pred0[data["ids_image0"]]
+ pred1 = pred1[data["ids_image1"]]
+ target0 = target0[data["ids_image0"]]
+ target1 = target1[data["ids_image1"]]
+
+ loss0 = (pred0 - target0)**2
+ loss1 = (pred1 - target1)**2
+ loss = loss0.mean() + loss1.mean()
+
+ loss_scalars.update({'loss': loss.clone().detach().cpu()})
+ data.update({"loss": loss, "loss_scalars": loss_scalars})
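
A toy sketch of the masked-patch reconstruction loss above: only the patch indices selected for each image (here `ids_image0`) contribute to the mean-squared error; all shapes and tensors below are random stand-ins, and the batch/window indexing of the real code is omitted:

    import torch

    torch.manual_seed(0)
    n_patches, patch_dim = 32, 16 * 16                 # hypothetical sizes
    pred0 = torch.rand(n_patches, patch_dim)           # reconstructed patches
    target0 = torch.rand(n_patches, patch_dim)         # ground-truth pixel patches
    ids_image0 = torch.randperm(n_patches)[: n_patches // 2]  # masked-patch indices

    loss0 = ((pred0[ids_image0] - target0[ids_image0]) ** 2).mean()
    print(loss0.item())                                # MSE over the masked patches only
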
diff --git a/imcui/third_party/XoFTR/src/optimizers/__init__.py b/imcui/third_party/XoFTR/src/optimizers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e1db2285352586c250912bdd2c4ae5029620ab5f
--- /dev/null
+++ b/imcui/third_party/XoFTR/src/optimizers/__init__.py
@@ -0,0 +1,42 @@
+import torch
+from torch.optim.lr_scheduler import MultiStepLR, CosineAnnealingLR, ExponentialLR
+
+
+def build_optimizer(model, config):
+ name = config.TRAINER.OPTIMIZER
+ lr = config.TRAINER.TRUE_LR
+
+ if name == "adam":
+ return torch.optim.Adam(model.parameters(), lr=lr, weight_decay=config.TRAINER.ADAM_DECAY)
+ elif name == "adamw":
+ return torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=config.TRAINER.ADAMW_DECAY)
+ else:
+ raise ValueError(f"TRAINER.OPTIMIZER = {name} is not a valid optimizer!")
+
+
+def build_scheduler(config, optimizer):
+ """
+ Returns:
+ scheduler (dict):{
+ 'scheduler': lr_scheduler,
+ 'interval': 'step', # or 'epoch'
+ 'monitor': 'val_f1', (optional)
+ 'frequency': x, (optional)
+ }
+ """
+ scheduler = {'interval': config.TRAINER.SCHEDULER_INTERVAL}
+ name = config.TRAINER.SCHEDULER
+
+ if name == 'MultiStepLR':
+ scheduler.update(
+ {'scheduler': MultiStepLR(optimizer, config.TRAINER.MSLR_MILESTONES, gamma=config.TRAINER.MSLR_GAMMA)})
+ elif name == 'CosineAnnealing':
+ scheduler.update(
+ {'scheduler': CosineAnnealingLR(optimizer, config.TRAINER.COSA_TMAX)})
+ elif name == 'ExponentialLR':
+ scheduler.update(
+ {'scheduler': ExponentialLR(optimizer, config.TRAINER.ELR_GAMMA)})
+ else:
+ raise NotImplementedError()
+
+ return scheduler
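
A minimal sketch of how the dict returned by `build_scheduler` is consumed: PyTorch Lightning reads the 'scheduler' and 'interval' keys from `configure_optimizers`; here the loop steps a hand-built MultiStepLR with made-up milestones in place of the config-driven branch:

    import torch
    from torch.optim.lr_scheduler import MultiStepLR

    model = torch.nn.Linear(8, 2)
    optimizer = torch.optim.AdamW(model.parameters(), lr=8e-3, weight_decay=0.1)

    # same dict shape that build_scheduler returns
    scheduler = {
        'scheduler': MultiStepLR(optimizer, milestones=[3, 6, 9], gamma=0.5),
        'interval': 'epoch',
    }

    # inside a LightningModule this pair is simply returned:
    #     return [optimizer], [scheduler]
    for epoch in range(10):
        optimizer.step()
        scheduler['scheduler'].step()
        print(epoch, optimizer.param_groups[0]['lr'])
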
diff --git a/imcui/third_party/XoFTR/src/utils/augment.py b/imcui/third_party/XoFTR/src/utils/augment.py
new file mode 100644
index 0000000000000000000000000000000000000000..1f3ef976b4ad93f5eb3be89707881d1780445313
--- /dev/null
+++ b/imcui/third_party/XoFTR/src/utils/augment.py
@@ -0,0 +1,113 @@
+import albumentations as A
+import numpy as np
+import cv2
+
+class DarkAug(object):
+ """
+ Extreme dark augmentation aiming at Aachen Day-Night
+ """
+
+ def __init__(self):
+ self.augmentor = A.Compose([
+ A.RandomBrightnessContrast(p=0.75, brightness_limit=(-0.6, 0.0), contrast_limit=(-0.5, 0.3)),
+ A.Blur(p=0.1, blur_limit=(3, 9)),
+ A.MotionBlur(p=0.2, blur_limit=(3, 25)),
+ A.RandomGamma(p=0.1, gamma_limit=(15, 65)),
+ A.HueSaturationValue(p=0.1, val_shift_limit=(-100, -40))
+ ], p=0.75)
+
+ def __call__(self, x):
+ return self.augmentor(image=x)['image']
+
+
+class MobileAug(object):
+ """
+    Random augmentations aiming at images of mobile/handheld devices.
+ """
+
+ def __init__(self):
+ self.augmentor = A.Compose([
+ A.MotionBlur(p=0.25),
+ A.ColorJitter(p=0.5),
+ A.RandomRain(p=0.1), # random occlusion
+ A.RandomSunFlare(p=0.1),
+ A.JpegCompression(p=0.25),
+ A.ISONoise(p=0.25)
+ ], p=1.0)
+
+ def __call__(self, x):
+ return self.augmentor(image=x)['image']
+
+class RGBThermalAug(object):
+ """
+ Pseudo-thermal image augmentation
+ """
+
+ def __init__(self):
+ self.blur = A.Blur(p=0.7, blur_limit=(2, 4))
+ self.hsv = A.HueSaturationValue(p=0.9, val_shift_limit=(-30, +30), hue_shift_limit=(-90,+90), sat_shift_limit=(-30,+30))
+
+ # Switch images to apply augmentation
+ self.random_switch = True
+
+ # parameters for the cosine transform
+ self.w_0 = np.pi * 2 / 3
+ self.w_r = np.pi / 2
+ self.theta_r = np.pi / 2
+
+ def augment_pseudo_thermal(self, image):
+
+ # HSV augmentation
+ image = self.hsv(image=image)["image"]
+
+ # Random blur
+ image = self.blur(image=image)["image"]
+
+        # Convert the image to grayscale
+ image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
+
+ # Normalize the image between (-0.5, 0.5)
+ image = image / 255 - 0.5 # 8 bit color
+
+ # Random phase and freq for the cosine transform
+ phase = np.pi / 2 + np.random.randn(1) * self.theta_r
+ w = self.w_0 + np.abs(np.random.randn(1)) * self.w_r
+
+ # Cosine transform
+ image = np.cos(image * w + phase)
+
+ # Min-max normalization for the transformed image
+ image = (image - image.min()) / (image.max() - image.min()) * 255
+
+ # 3 channel gray
+ image = cv2.cvtColor(image.astype(np.uint8), cv2.COLOR_GRAY2RGB)
+
+ return image
+
+ def __call__(self, x, image_num):
+ if image_num==0:
+ # augmentation for RGB image can be added here
+ return x
+ elif image_num==1:
+ # pseudo-thermal augmentation
+ return self.augment_pseudo_thermal(x)
+ else:
+ raise ValueError(f'Invalid image number: {image_num}')
+
+
+def build_augmentor(method=None, **kwargs):
+
+ if method == 'dark':
+ return DarkAug()
+ elif method == 'mobile':
+ return MobileAug()
+ elif method == "rgb_thermal":
+ return RGBThermalAug()
+ elif method is None:
+ return None
+ else:
+ raise ValueError(f'Invalid augmentation method: {method}')
+
+
+if __name__ == '__main__':
+ augmentor = build_augmentor('FDA')
\ No newline at end of file
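
A standalone sketch of the cosine intensity remapping at the core of `RGBThermalAug.augment_pseudo_thermal` (the HSV jitter and blur steps are skipped here); the random phase and frequency follow the same distributions as in the class, and the input is a random toy image:

    import numpy as np

    rng = np.random.default_rng(0)
    gray = rng.integers(0, 256, size=(480, 640)).astype(np.float64)  # toy grayscale image

    w_0, w_r, theta_r = np.pi * 2 / 3, np.pi / 2, np.pi / 2
    phase = np.pi / 2 + rng.standard_normal() * theta_r
    w = w_0 + np.abs(rng.standard_normal()) * w_r

    x = gray / 255 - 0.5                   # normalize to (-0.5, 0.5)
    x = np.cos(x * w + phase)              # non-linear, randomly phase-shifted remapping
    x = (x - x.min()) / (x.max() - x.min()) * 255
    pseudo_thermal = x.astype(np.uint8)
    print(pseudo_thermal.shape, pseudo_thermal.dtype)  # (480, 640) uint8
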
diff --git a/imcui/third_party/XoFTR/src/utils/comm.py b/imcui/third_party/XoFTR/src/utils/comm.py
new file mode 100644
index 0000000000000000000000000000000000000000..26ec9517cc47e224430106d8ae9aa99a3fe49167
--- /dev/null
+++ b/imcui/third_party/XoFTR/src/utils/comm.py
@@ -0,0 +1,265 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+"""
+[Copied from detectron2]
+This file contains primitives for multi-gpu communication.
+This is useful when doing distributed training.
+"""
+
+import functools
+import logging
+import numpy as np
+import pickle
+import torch
+import torch.distributed as dist
+
+_LOCAL_PROCESS_GROUP = None
+"""
+A torch process group which only includes processes that are on the same machine as the current process.
+This variable is set when processes are spawned by `launch()` in "engine/launch.py".
+"""
+
+
+def get_world_size() -> int:
+ if not dist.is_available():
+ return 1
+ if not dist.is_initialized():
+ return 1
+ return dist.get_world_size()
+
+
+def get_rank() -> int:
+ if not dist.is_available():
+ return 0
+ if not dist.is_initialized():
+ return 0
+ return dist.get_rank()
+
+
+def get_local_rank() -> int:
+ """
+ Returns:
+ The rank of the current process within the local (per-machine) process group.
+ """
+ if not dist.is_available():
+ return 0
+ if not dist.is_initialized():
+ return 0
+ assert _LOCAL_PROCESS_GROUP is not None
+ return dist.get_rank(group=_LOCAL_PROCESS_GROUP)
+
+
+def get_local_size() -> int:
+ """
+ Returns:
+ The size of the per-machine process group,
+ i.e. the number of processes per machine.
+ """
+ if not dist.is_available():
+ return 1
+ if not dist.is_initialized():
+ return 1
+ return dist.get_world_size(group=_LOCAL_PROCESS_GROUP)
+
+
+def is_main_process() -> bool:
+ return get_rank() == 0
+
+
+def synchronize():
+ """
+ Helper function to synchronize (barrier) among all processes when
+ using distributed training
+ """
+ if not dist.is_available():
+ return
+ if not dist.is_initialized():
+ return
+ world_size = dist.get_world_size()
+ if world_size == 1:
+ return
+ dist.barrier()
+
+
+@functools.lru_cache()
+def _get_global_gloo_group():
+ """
+ Return a process group based on gloo backend, containing all the ranks.
+ The result is cached.
+ """
+ if dist.get_backend() == "nccl":
+ return dist.new_group(backend="gloo")
+ else:
+ return dist.group.WORLD
+
+
+def _serialize_to_tensor(data, group):
+ backend = dist.get_backend(group)
+ assert backend in ["gloo", "nccl"]
+ device = torch.device("cpu" if backend == "gloo" else "cuda")
+
+ buffer = pickle.dumps(data)
+ if len(buffer) > 1024 ** 3:
+ logger = logging.getLogger(__name__)
+ logger.warning(
+ "Rank {} trying to all-gather {:.2f} GB of data on device {}".format(
+ get_rank(), len(buffer) / (1024 ** 3), device
+ )
+ )
+ storage = torch.ByteStorage.from_buffer(buffer)
+ tensor = torch.ByteTensor(storage).to(device=device)
+ return tensor
+
+
+def _pad_to_largest_tensor(tensor, group):
+ """
+ Returns:
+ list[int]: size of the tensor, on each rank
+ Tensor: padded tensor that has the max size
+ """
+ world_size = dist.get_world_size(group=group)
+ assert (
+ world_size >= 1
+ ), "comm.gather/all_gather must be called from ranks within the given group!"
+ local_size = torch.tensor([tensor.numel()], dtype=torch.int64, device=tensor.device)
+ size_list = [
+ torch.zeros([1], dtype=torch.int64, device=tensor.device) for _ in range(world_size)
+ ]
+ dist.all_gather(size_list, local_size, group=group)
+
+ size_list = [int(size.item()) for size in size_list]
+
+ max_size = max(size_list)
+
+ # we pad the tensor because torch all_gather does not support
+ # gathering tensors of different shapes
+ if local_size != max_size:
+ padding = torch.zeros((max_size - local_size,), dtype=torch.uint8, device=tensor.device)
+ tensor = torch.cat((tensor, padding), dim=0)
+ return size_list, tensor
+
+
+def all_gather(data, group=None):
+ """
+ Run all_gather on arbitrary picklable data (not necessarily tensors).
+
+ Args:
+ data: any picklable object
+ group: a torch process group. By default, will use a group which
+ contains all ranks on gloo backend.
+
+ Returns:
+ list[data]: list of data gathered from each rank
+ """
+ if get_world_size() == 1:
+ return [data]
+ if group is None:
+ group = _get_global_gloo_group()
+ if dist.get_world_size(group) == 1:
+ return [data]
+
+ tensor = _serialize_to_tensor(data, group)
+
+ size_list, tensor = _pad_to_largest_tensor(tensor, group)
+ max_size = max(size_list)
+
+ # receiving Tensor from all ranks
+ tensor_list = [
+ torch.empty((max_size,), dtype=torch.uint8, device=tensor.device) for _ in size_list
+ ]
+ dist.all_gather(tensor_list, tensor, group=group)
+
+ data_list = []
+ for size, tensor in zip(size_list, tensor_list):
+ buffer = tensor.cpu().numpy().tobytes()[:size]
+ data_list.append(pickle.loads(buffer))
+
+ return data_list
+
+
+def gather(data, dst=0, group=None):
+ """
+ Run gather on arbitrary picklable data (not necessarily tensors).
+
+ Args:
+ data: any picklable object
+ dst (int): destination rank
+ group: a torch process group. By default, will use a group which
+ contains all ranks on gloo backend.
+
+ Returns:
+ list[data]: on dst, a list of data gathered from each rank. Otherwise,
+ an empty list.
+ """
+ if get_world_size() == 1:
+ return [data]
+ if group is None:
+ group = _get_global_gloo_group()
+ if dist.get_world_size(group=group) == 1:
+ return [data]
+ rank = dist.get_rank(group=group)
+
+ tensor = _serialize_to_tensor(data, group)
+ size_list, tensor = _pad_to_largest_tensor(tensor, group)
+
+ # receiving Tensor from all ranks
+ if rank == dst:
+ max_size = max(size_list)
+ tensor_list = [
+ torch.empty((max_size,), dtype=torch.uint8, device=tensor.device) for _ in size_list
+ ]
+ dist.gather(tensor, tensor_list, dst=dst, group=group)
+
+ data_list = []
+ for size, tensor in zip(size_list, tensor_list):
+ buffer = tensor.cpu().numpy().tobytes()[:size]
+ data_list.append(pickle.loads(buffer))
+ return data_list
+ else:
+ dist.gather(tensor, [], dst=dst, group=group)
+ return []
+
+
+def shared_random_seed():
+ """
+ Returns:
+ int: a random number that is the same across all workers.
+ If workers need a shared RNG, they can use this shared seed to
+ create one.
+
+ All workers must call this function, otherwise it will deadlock.
+ """
+ ints = np.random.randint(2 ** 31)
+ all_ints = all_gather(ints)
+ return all_ints[0]
+
+
+def reduce_dict(input_dict, average=True):
+ """
+ Reduce the values in the dictionary from all processes so that process with rank
+ 0 has the reduced results.
+
+ Args:
+ input_dict (dict): inputs to be reduced. All the values must be scalar CUDA Tensor.
+ average (bool): whether to do average or sum
+
+ Returns:
+ a dict with the same keys as input_dict, after reduction.
+ """
+ world_size = get_world_size()
+ if world_size < 2:
+ return input_dict
+ with torch.no_grad():
+ names = []
+ values = []
+ # sort the keys so that they are consistent across processes
+ for k in sorted(input_dict.keys()):
+ names.append(k)
+ values.append(input_dict[k])
+ values = torch.stack(values, dim=0)
+ dist.reduce(values, dst=0)
+ if dist.get_rank() == 0 and average:
+ # only main process gets accumulated, so only divide by
+ # world_size in this case
+ values /= world_size
+ reduced_dict = {k: v for k, v in zip(names, values)}
+ return reduced_dict
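A hedged usage sketch for the helpers above (the `src.utils.comm` import assumes the XoFTR `src/` tree is importable): when no process group has been initialized, `get_world_size()` is 1, so both calls degrade to single-process behavior and the snippet also runs outside distributed training.

```python
import torch
from src.utils import comm   # assumed import path; adjust to where comm.py lives

stats = {"loss": torch.tensor(0.5), "acc": torch.tensor(0.9)}
reduced = comm.reduce_dict(stats, average=True)   # rank 0 ends up with the averaged values
gathered = comm.all_gather({"rank": comm.get_rank(), "n_matches": 123})  # one entry per rank
if comm.is_main_process():
    print(reduced, len(gathered))
```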
diff --git a/imcui/third_party/XoFTR/src/utils/data_io.py b/imcui/third_party/XoFTR/src/utils/data_io.py
new file mode 100644
index 0000000000000000000000000000000000000000..c63628c9c7f9fa64a6c8817c0ed9f514f29aa9cb
--- /dev/null
+++ b/imcui/third_party/XoFTR/src/utils/data_io.py
@@ -0,0 +1,144 @@
+import torch
+from torch import nn
+import numpy as np
+import cv2
+# import torchvision.transforms as transforms
+import torch.nn.functional as F
+from yacs.config import CfgNode as CN
+
+def lower_config(yacs_cfg):
+ if not isinstance(yacs_cfg, CN):
+ return yacs_cfg
+ return {k.lower(): lower_config(v) for k, v in yacs_cfg.items()}
+
+
+def upper_config(dict_cfg):
+ if not isinstance(dict_cfg, dict):
+ return dict_cfg
+ return {k.upper(): upper_config(v) for k, v in dict_cfg.items()}
+
+
+class DataIOWrapper(nn.Module):
+ """
+ Pre-process data from different sources
+ """
+
+ def __init__(self, model, config, ckpt=None):
+ super().__init__()
+
+ self.device = torch.device('cuda:{}'.format(0) if torch.cuda.is_available() else 'cpu')
+ torch.set_grad_enabled(False)
+ self.model = model
+ self.config = config
+ self.img0_size = config['img0_resize']
+ self.img1_size = config['img1_resize']
+ self.df = config['df']
+ self.padding = config['padding']
+ self.coarse_scale = config['coarse_scale']
+
+ if ckpt:
+ ckpt_dict = torch.load(ckpt)
+ self.model.load_state_dict(ckpt_dict['state_dict'])
+ self.model = self.model.eval().to(self.device)
+
+ def preprocess_image(self, img, device, resize=None, df=None, padding=None, cam_K=None, dist=None, gray_scale=True):
+ # xoftr takes grayscale input images
+ if gray_scale and len(img.shape) == 3:
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
+
+ h, w = img.shape[:2]
+ new_K = None
+ img_undistorted = None
+ if cam_K is not None and dist is not None:
+ new_K, roi = cv2.getOptimalNewCameraMatrix(cam_K, dist, (w,h), 0, (w,h))
+ img = cv2.undistort(img, cam_K, dist, None, new_K)
+ img_undistorted = img.copy()
+
+ if resize is not None:
+ scale = resize / max(h, w)
+ w_new, h_new = int(round(w*scale)), int(round(h*scale))
+ else:
+ w_new, h_new = w, h
+
+ if df is not None:
+ w_new, h_new = map(lambda x: int(x // df * df), [w_new, h_new])
+
+ img = cv2.resize(img, (w_new, h_new))
+ scale = np.array([w/w_new, h/h_new], dtype=float)
+ if padding: # padding
+ pad_to = max(h_new, w_new)
+ img, mask = self.pad_bottom_right(img, pad_to, ret_mask=True)
+ mask = torch.from_numpy(mask).to(device)
+ else:
+ mask = None
+ # img = transforms.functional.to_tensor(img).unsqueeze(0).to(device)
+ if len(img.shape) == 2: # grayscale image
+ img = torch.from_numpy(img)[None][None].to(device).float() / 255.0
+ else: # Color image
+ img = torch.from_numpy(img).permute(2, 0, 1)[None].to(device).float() / 255.0
+ return img, scale, mask, new_K, img_undistorted
+
+ def from_cv_imgs(self, img0, img1, K0=None, K1=None, dist0=None, dist1=None):
+ img0_tensor, scale0, mask0, new_K0, img0_undistorted = self.preprocess_image(
+ img0, self.device, resize=self.img0_size, df=self.df, padding=self.padding, cam_K=K0, dist=dist0)
+ img1_tensor, scale1, mask1, new_K1, img1_undistorted = self.preprocess_image(
+ img1, self.device, resize=self.img1_size, df=self.df, padding=self.padding, cam_K=K1, dist=dist1)
+ mkpts0, mkpts1, mconf = self.match_images(img0_tensor, img1_tensor, mask0, mask1)
+ mkpts0 = mkpts0 * scale0
+ mkpts1 = mkpts1 * scale1
+ matches = np.concatenate([mkpts0, mkpts1], axis=1)
+ data = {'matches':matches,
+ 'mkpts0':mkpts0,
+ 'mkpts1':mkpts1,
+ 'mconf':mconf,
+ 'img0':img0,
+ 'img1':img1
+ }
+ if K0 is not None and dist0 is not None:
+ data.update({'new_K0':new_K0, 'img0_undistorted':img0_undistorted})
+ if K1 is not None and dist1 is not None:
+ data.update({'new_K1':new_K1, 'img1_undistorted':img1_undistorted})
+ return data
+
+ def from_paths(self, img0_pth, img1_pth, K0=None, K1=None, dist0=None, dist1=None, read_color=False):
+
+ imread_flag = cv2.IMREAD_COLOR if read_color else cv2.IMREAD_GRAYSCALE
+
+ img0 = cv2.imread(img0_pth, imread_flag)
+ img1 = cv2.imread(img1_pth, imread_flag)
+ return self.from_cv_imgs(img0, img1, K0=K0, K1=K1, dist0=dist0, dist1=dist1)
+
+ def match_images(self, image0, image1, mask0, mask1):
+ batch = {'image0': image0, 'image1': image1}
+ if mask0 is not None: # img_padding is True
+ if self.coarse_scale:
+ [ts_mask_0, ts_mask_1] = F.interpolate(torch.stack([mask0, mask1], dim=0)[None].float(),
+ scale_factor=self.coarse_scale,
+ mode='nearest',
+ recompute_scale_factor=False)[0].bool()
+ batch.update({'mask0': ts_mask_0.unsqueeze(0), 'mask1': ts_mask_1.unsqueeze(0)})
+ self.model(batch)
+ mkpts0 = batch['mkpts0_f'].cpu().numpy()
+ mkpts1 = batch['mkpts1_f'].cpu().numpy()
+ mconf = batch['mconf_f'].cpu().numpy()
+ return mkpts0, mkpts1, mconf
+
+ def pad_bottom_right(self, inp, pad_size, ret_mask=False):
+ assert isinstance(pad_size, int) and pad_size >= max(inp.shape[-2:]), f"{pad_size} < {max(inp.shape[-2:])}"
+ mask = None
+ if inp.ndim == 2:
+ padded = np.zeros((pad_size, pad_size), dtype=inp.dtype)
+ padded[:inp.shape[0], :inp.shape[1]] = inp
+ if ret_mask:
+ mask = np.zeros((pad_size, pad_size), dtype=bool)
+ mask[:inp.shape[0], :inp.shape[1]] = True
+ elif inp.ndim == 3:
+ padded = np.zeros((inp.shape[0], pad_size, pad_size), dtype=inp.dtype)
+ padded[:, :inp.shape[1], :inp.shape[2]] = inp
+ if ret_mask:
+ mask = np.zeros((inp.shape[0], pad_size, pad_size), dtype=bool)
+ mask[:, :inp.shape[1], :inp.shape[2]] = True
+ else:
+ raise NotImplementedError()
+ return padded, mask
+
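The preprocessing in `DataIOWrapper.preprocess_image` follows a fixed recipe: resize so the longer edge matches `img0_resize`/`img1_resize`, snap both sides down to a multiple of `df`, then optionally pad to a square while keeping a validity mask and the scale needed to map matches back to original pixels. A numpy-only sketch of that size arithmetic (the concrete numbers are illustrative):

```python
import numpy as np

h, w, resize, df = 480, 752, 640, 8
scale = resize / max(h, w)                                   # resize the longer edge to 640
w_new, h_new = int(round(w * scale)), int(round(h * scale))
w_new, h_new = (int(x // df * df) for x in (w_new, h_new))   # snap down to a multiple of df
pad_to = max(h_new, w_new)                                   # square canvas when padding is on
rescale = np.array([w / w_new, h / h_new])                   # maps keypoints back to input pixels
print((h_new, w_new), pad_to, rescale)                       # (408, 640) 640 [~1.175 ~1.176]
```

`from_cv_imgs` applies exactly this rescale to the matched keypoints before returning them.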
diff --git a/imcui/third_party/XoFTR/src/utils/dataloader.py b/imcui/third_party/XoFTR/src/utils/dataloader.py
new file mode 100644
index 0000000000000000000000000000000000000000..6da37b880a290c2bb3ebb028d0c8dab592acc5c1
--- /dev/null
+++ b/imcui/third_party/XoFTR/src/utils/dataloader.py
@@ -0,0 +1,23 @@
+import numpy as np
+
+
+# --- PL-DATAMODULE ---
+
+def get_local_split(items: list, world_size: int, rank: int, seed: int):
+ """ The local rank only loads a split of the dataset. """
+ n_items = len(items)
+ items_permute = np.random.RandomState(seed).permutation(items)
+ if n_items % world_size == 0:
+ padded_items = items_permute
+ else:
+ padding = np.random.RandomState(seed).choice(
+ items,
+ world_size - (n_items % world_size),
+ replace=True)
+ padded_items = np.concatenate([items_permute, padding])
+ assert len(padded_items) % world_size == 0, \
+ f'len(padded_items): {len(padded_items)}; world_size: {world_size}; len(padding): {len(padding)}'
+ n_per_rank = len(padded_items) // world_size
+ local_items = padded_items[n_per_rank * rank: n_per_rank * (rank+1)]
+
+ return local_items
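Because the permutation and the padding both derive from the same seed, every rank reconstructs an identical padded ordering and simply takes its own contiguous slice; a quick check (import path assumed):

```python
from src.utils.dataloader import get_local_split

items = [f"scene_{i}" for i in range(10)]
world_size, seed = 4, 42
splits = [get_local_split(items, world_size, rank, seed) for rank in range(world_size)]
# 10 items are padded to 12, so each of the 4 ranks gets exactly 3
assert all(len(s) == 3 for s in splits)
```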
diff --git a/imcui/third_party/XoFTR/src/utils/dataset.py b/imcui/third_party/XoFTR/src/utils/dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..e20de10040a2888fd08279b0c54bd8a4c39665bb
--- /dev/null
+++ b/imcui/third_party/XoFTR/src/utils/dataset.py
@@ -0,0 +1,279 @@
+import io
+from loguru import logger
+
+import cv2
+import numpy as np
+import h5py
+import torch
+from numpy.linalg import inv
+
+
+try:
+ # for internal use only
+ from .client import MEGADEPTH_CLIENT, SCANNET_CLIENT
+except Exception:
+ MEGADEPTH_CLIENT = SCANNET_CLIENT = None
+
+# --- DATA IO ---
+
+def load_array_from_s3(
+ path, client, cv_type,
+ use_h5py=False,
+):
+ byte_str = client.Get(path)
+ try:
+ if not use_h5py:
+ raw_array = np.frombuffer(byte_str, np.uint8)
+ data = cv2.imdecode(raw_array, cv_type)
+ else:
+ f = io.BytesIO(byte_str)
+ data = np.array(h5py.File(f, 'r')['/depth'])
+ except Exception as ex:
+ print(f"==> Data loading failure: {path}")
+ raise ex
+
+ assert data is not None
+ return data
+
+
+def imread_gray(path, augment_fn=None, client=SCANNET_CLIENT):
+ cv_type = cv2.IMREAD_GRAYSCALE if augment_fn is None \
+ else cv2.IMREAD_COLOR
+ if str(path).startswith('s3://'):
+ image = load_array_from_s3(str(path), client, cv_type)
+ else:
+ image = cv2.imread(str(path), cv_type)
+
+ if augment_fn is not None:
+ image = cv2.imread(str(path), cv2.IMREAD_COLOR)
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+ image = augment_fn(image)
+ image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
+ return image # (h, w)
+
+
+def get_resized_wh(w, h, resize=None):
+ if resize is not None: # resize the longer edge
+ scale = resize / max(h, w)
+ w_new, h_new = int(round(w*scale)), int(round(h*scale))
+ else:
+ w_new, h_new = w, h
+ return w_new, h_new
+
+
+def get_divisible_wh(w, h, df=None):
+ if df is not None:
+ w_new, h_new = map(lambda x: int(x // df * df), [w, h])
+ else:
+ w_new, h_new = w, h
+ return w_new, h_new
+
+
+def pad_bottom_right(inp, pad_size, ret_mask=False):
+ assert isinstance(pad_size, int) and pad_size >= max(inp.shape[-2:]), f"{pad_size} < {max(inp.shape[-2:])}"
+ mask = None
+ if inp.ndim == 2:
+ padded = np.zeros((pad_size, pad_size), dtype=inp.dtype)
+ padded[:inp.shape[0], :inp.shape[1]] = inp
+ if ret_mask:
+ mask = np.zeros((pad_size, pad_size), dtype=bool)
+ mask[:inp.shape[0], :inp.shape[1]] = True
+ elif inp.ndim == 3:
+ padded = np.zeros((inp.shape[0], pad_size, pad_size), dtype=inp.dtype)
+ padded[:, :inp.shape[1], :inp.shape[2]] = inp
+ if ret_mask:
+ mask = np.zeros((inp.shape[0], pad_size, pad_size), dtype=bool)
+ mask[:, :inp.shape[1], :inp.shape[2]] = True
+ else:
+ raise NotImplementedError()
+ return padded, mask
+
+
+# --- MEGADEPTH ---
+
+def read_megadepth_gray(path, resize=None, df=None, padding=False, augment_fn=None):
+ """
+ Args:
+ resize (int, optional): the longer edge of resized images. None for no resize.
+ padding (bool): If set to 'True', zero-pad resized images to squared size.
+ augment_fn (callable, optional): augments images with pre-defined visual effects
+ Returns:
+ image (torch.tensor): (1, h, w)
+ mask (torch.tensor): (h, w)
+ scale (torch.tensor): [w/w_new, h/h_new]
+ """
+ # read image
+ image = imread_gray(path, augment_fn, client=MEGADEPTH_CLIENT)
+
+ # resize image
+ w, h = image.shape[1], image.shape[0]
+ w_new, h_new = get_resized_wh(w, h, resize)
+ w_new, h_new = get_divisible_wh(w_new, h_new, df)
+
+ image = cv2.resize(image, (w_new, h_new))
+ scale = torch.tensor([w/w_new, h/h_new], dtype=torch.float)
+
+ if padding: # padding
+ pad_to = max(h_new, w_new)
+ image, mask = pad_bottom_right(image, pad_to, ret_mask=True)
+ mask = torch.from_numpy(mask)
+ else:
+ mask = None
+
+ image = torch.from_numpy(image).float()[None] / 255 # (h, w) -> (1, h, w) and normalized
+
+ return image, mask, scale
+
+
+def read_megadepth_depth(path, pad_to=None):
+ if str(path).startswith('s3://'):
+ depth = load_array_from_s3(path, MEGADEPTH_CLIENT, None, use_h5py=True)
+ else:
+ depth = np.array(h5py.File(path, 'r')['depth'])
+ if pad_to is not None:
+ depth, _ = pad_bottom_right(depth, pad_to, ret_mask=False)
+ depth = torch.from_numpy(depth).float() # (h, w)
+ return depth
+
+
+# --- ScanNet ---
+
+def read_scannet_gray(path, resize=(640, 480), augment_fn=None):
+ """
+ Args:
+ resize (tuple): align image to depthmap, in (w, h).
+ augment_fn (callable, optional): augments images with pre-defined visual effects
+ Returns:
+ image (torch.tensor): (1, h, w)
+ mask (torch.tensor): (h, w)
+ scale (torch.tensor): [w/w_new, h/h_new]
+ """
+ # read and resize image
+ image = imread_gray(path, augment_fn)
+ image = cv2.resize(image, resize)
+
+ # (h, w) -> (1, h, w) and normalized
+ image = torch.from_numpy(image).float()[None] / 255
+ return image
+
+
+def read_scannet_depth(path):
+ if str(path).startswith('s3://'):
+ depth = load_array_from_s3(str(path), SCANNET_CLIENT, cv2.IMREAD_UNCHANGED)
+ else:
+ depth = cv2.imread(str(path), cv2.IMREAD_UNCHANGED)
+ depth = depth / 1000
+ depth = torch.from_numpy(depth).float() # (h, w)
+ return depth
+
+
+def read_scannet_pose(path):
+ """ Read ScanNet's Camera2World pose and transform it to World2Camera.
+
+ Returns:
+ pose_w2c (np.ndarray): (4, 4)
+ """
+ cam2world = np.loadtxt(path, delimiter=' ')
+ world2cam = inv(cam2world)
+ return world2cam
+
+
+def read_scannet_intrinsic(path):
+ """ Read ScanNet's intrinsic matrix and return the 3x3 matrix.
+ """
+ intrinsic = np.loadtxt(path, delimiter=' ')
+ return intrinsic[:-1, :-1]
+
+
+# --- VisTir ---
+
+def read_vistir_gray(path, cam_K, dist, resize=None, df=None, padding=False, augment_fn=None):
+ """
+ Args:
+ cam_K (3, 3): camera matrix
+ dist (8): distortion coefficients
+ resize (int, optional): the longer edge of resized images. None for no resize.
+ padding (bool): If set to 'True', zero-pad resized images to squared size.
+ augment_fn (callable, optional): augments images with pre-defined visual effects
+ Returns:
+ image (torch.tensor): (1, h, w)
+ mask (torch.tensor): (h, w)
+ scale (torch.tensor): [w/w_new, h/h_new]
+ """
+ # read image
+ image = imread_gray(path, augment_fn, client=None)
+
+ h, w = image.shape[:2]
+ # update camera matrix
+ new_K, roi = cv2.getOptimalNewCameraMatrix(cam_K, dist, (w,h), 0, (w,h))
+ # undistort image
+ image = cv2.undistort(image, cam_K, dist, None, new_K)
+
+ # resize image
+ w, h = image.shape[1], image.shape[0]
+ w_new, h_new = get_resized_wh(w, h, resize)
+ w_new, h_new = get_divisible_wh(w_new, h_new, df)
+
+ image = cv2.resize(image, (w_new, h_new))
+ scale = torch.tensor([w/w_new, h/h_new], dtype=torch.float)
+
+ if padding: # padding
+ pad_to = max(h_new, w_new)
+ image, mask = pad_bottom_right(image, pad_to, ret_mask=True)
+ mask = torch.from_numpy(mask)
+ else:
+ mask = None
+
+ image = torch.from_numpy(image).float()[None] / 255 # (h, w) -> (1, h, w) and normalized
+
+ return image, mask, scale, new_K
+
+# --- PRETRAIN ---
+
+def read_pretrain_gray(path, resize=None, df=None, padding=False, augment_fn=None):
+ """
+ Args:
+ resize (int, optional): the longer edge of resized images. None for no resize.
+ padding (bool): If set to 'True', zero-pad resized images to squared size.
+ augment_fn (callable, optional): augments images with pre-defined visual effects
+ Returns:
+ image (torch.tensor): (1, h, w) gray scale image
+ image_norm (torch.tensor): (1, h, w) normalized image
+ mask (torch.tensor): (h, w)
+ scale (torch.tensor): [w/w_new, h/h_new]
+ image_mean (torch.tensor): (1, 1, 1, 1)
+ image_std (torch.tensor): (1, 1, 1, 1)
+ """
+ # read image
+ image = imread_gray(path, augment_fn, client=None)
+
+ # resize image
+ w, h = image.shape[1], image.shape[0]
+ w_new, h_new = get_resized_wh(w, h, resize)
+ w_new, h_new = get_divisible_wh(w_new, h_new, df)
+
+ image = cv2.resize(image, (w_new, h_new))
+ scale = torch.tensor([w/w_new, h/h_new], dtype=torch.float)
+
+ image = image.astype(np.float32) / 255
+
+ image_mean = image.mean()
+ image_std = image.std()
+ image_norm = (image - image_mean) / (image_std + 1e-6)
+
+ if padding: # padding
+ pad_to = max(h_new, w_new)
+ image, mask = pad_bottom_right(image, pad_to, ret_mask=True)
+ image_norm, _ = pad_bottom_right(image_norm, pad_to, ret_mask=False)
+ mask = torch.from_numpy(mask)
+ else:
+ mask = None
+
+ image_mean = torch.as_tensor(image_mean).float()[None,None,None]
+ image_std = torch.as_tensor(image_std).float()[None,None,None]
+
+ image = torch.from_numpy(image).float()[None]
+ image_norm = torch.from_numpy(image_norm).float()[None] # (h, w) -> (1, h, w) and normalized
+
+ return image, image_norm, mask, scale, image_mean, image_std
+
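All of the readers above share `pad_bottom_right`, which pads to a square canvas and returns the boolean validity mask that later becomes `mask0`/`mask1` for coarse matching. A self-contained check (assuming the `src/` tree is importable):

```python
import numpy as np
from src.utils.dataset import pad_bottom_right

img = np.ones((408, 640), dtype=np.uint8)        # e.g. a resized MegaDepth image
padded, mask = pad_bottom_right(img, 640, ret_mask=True)
assert padded.shape == (640, 640) and mask.shape == (640, 640)
assert mask[:408, :].all() and not mask[408:, :].any()   # only the original region is valid
```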
diff --git a/imcui/third_party/XoFTR/src/utils/metrics.py b/imcui/third_party/XoFTR/src/utils/metrics.py
new file mode 100644
index 0000000000000000000000000000000000000000..a708e17a828574686d81d0d591dfa0c7597fa387
--- /dev/null
+++ b/imcui/third_party/XoFTR/src/utils/metrics.py
@@ -0,0 +1,211 @@
+import torch
+import cv2
+import numpy as np
+from collections import OrderedDict
+from loguru import logger
+from kornia.geometry.epipolar import numeric
+from kornia.geometry.conversions import convert_points_to_homogeneous
+
+
+# --- METRICS ---
+
+def relative_pose_error(T_0to1, R, t, ignore_gt_t_thr=0.0):
+ # angle error between 2 vectors
+ t_gt = T_0to1[:3, 3]
+ n = np.linalg.norm(t) * np.linalg.norm(t_gt)
+ t_err = np.rad2deg(np.arccos(np.clip(np.dot(t, t_gt) / n, -1.0, 1.0)))
+ t_err = np.minimum(t_err, 180 - t_err) # handle E ambiguity
+ if np.linalg.norm(t_gt) < ignore_gt_t_thr: # pure rotation is challenging
+ t_err = 0
+
+ # angle error between 2 rotation matrices
+ R_gt = T_0to1[:3, :3]
+ cos = (np.trace(np.dot(R.T, R_gt)) - 1) / 2
+ cos = np.clip(cos, -1., 1.) # handle numerical errors
+ R_err = np.rad2deg(np.abs(np.arccos(cos)))
+
+ return t_err, R_err
+
+
+def symmetric_epipolar_distance(pts0, pts1, E, K0, K1):
+ """Squared symmetric epipolar distance.
+ This can be seen as a biased estimation of the reprojection error.
+ Args:
+ pts0 (torch.Tensor): [N, 2]
+ E (torch.Tensor): [3, 3]
+ """
+ pts0 = (pts0 - K0[[0, 1], [2, 2]][None]) / K0[[0, 1], [0, 1]][None]
+ pts1 = (pts1 - K1[[0, 1], [2, 2]][None]) / K1[[0, 1], [0, 1]][None]
+ pts0 = convert_points_to_homogeneous(pts0)
+ pts1 = convert_points_to_homogeneous(pts1)
+
+ Ep0 = pts0 @ E.T # [N, 3]
+ p1Ep0 = torch.sum(pts1 * Ep0, -1) # [N,]
+ Etp1 = pts1 @ E # [N, 3]
+
+ d = p1Ep0**2 * (1.0 / (Ep0[:, 0]**2 + Ep0[:, 1]**2) + 1.0 / (Etp1[:, 0]**2 + Etp1[:, 1]**2)) # N
+ return d
+
+def symmetric_epipolar_distance_numpy(pts0, pts1, E, K0, K1):
+ """Squared symmetric epipolar distance.
+ This can be seen as a biased estimation of the reprojection error.
+ Args:
+ pts0 (numpy.array): [N, 2]
+ E (numpy.array): [3, 3]
+ """
+ pts0 = (pts0 - K0[[0, 1], [2, 2]][None]) / K0[[0, 1], [0, 1]][None]
+ pts1 = (pts1 - K1[[0, 1], [2, 2]][None]) / K1[[0, 1], [0, 1]][None]
+ pts0 = np.hstack((pts0, np.ones((pts0.shape[0], 1))))
+ pts1 = np.hstack((pts1, np.ones((pts1.shape[0], 1))))
+
+ Ep0 = pts0 @ E.T # [N, 3]
+ p1Ep0 = np.sum(pts1 * Ep0, -1) # [N,]
+ Etp1 = pts1 @ E # [N, 3]
+
+ d = p1Ep0**2 * (1.0 / (Ep0[:, 0]**2 + Ep0[:, 1]**2) + 1.0 / (Etp1[:, 0]**2 + Etp1[:, 1]**2)) # N
+ return d
+
+def compute_symmetrical_epipolar_errors(data):
+ """
+ Update:
+ data (dict):{"epi_errs": [M]}
+ """
+ Tx = numeric.cross_product_matrix(data['T_0to1'][:, :3, 3])
+ E_mat = Tx @ data['T_0to1'][:, :3, :3]
+
+ m_bids = data['m_bids']
+ pts0 = data['mkpts0_f']
+ pts1 = data['mkpts1_f']
+
+ epi_errs = []
+ for bs in range(Tx.size(0)):
+ mask = m_bids == bs
+ epi_errs.append(
+ symmetric_epipolar_distance(pts0[mask], pts1[mask], E_mat[bs], data['K0'][bs], data['K1'][bs]))
+ epi_errs = torch.cat(epi_errs, dim=0)
+
+ data.update({'epi_errs': epi_errs})
+
+
+def estimate_pose(kpts0, kpts1, K0, K1, thresh, conf=0.99999):
+ if len(kpts0) < 5:
+ return None
+ # normalize keypoints
+ kpts0 = (kpts0 - K0[[0, 1], [2, 2]][None]) / K0[[0, 1], [0, 1]][None]
+ kpts1 = (kpts1 - K1[[0, 1], [2, 2]][None]) / K1[[0, 1], [0, 1]][None]
+
+ # normalize ransac threshold
+ ransac_thr = thresh / np.mean([K0[0, 0], K0[1, 1], K1[0, 0], K1[1, 1]])
+
+ # compute pose with cv2
+ E, mask = cv2.findEssentialMat(
+ kpts0, kpts1, np.eye(3), threshold=ransac_thr, prob=conf, method=cv2.RANSAC)
+ if E is None:
+ print("\nE is None while trying to recover pose.\n")
+ return None
+
+ # recover pose from E
+ best_num_inliers = 0
+ ret = None
+ for _E in np.split(E, len(E) / 3):
+ n, R, t, _ = cv2.recoverPose(_E, kpts0, kpts1, np.eye(3), 1e9, mask=mask)
+ if n > best_num_inliers:
+ ret = (R, t[:, 0], mask.ravel() > 0)
+ best_num_inliers = n
+
+ return ret
+
+
+def compute_pose_errors(data, config):
+ """
+ Update:
+ data (dict):{
+ "R_errs" List[float]: [N]
+ "t_errs" List[float]: [N]
+ "inliers" List[np.ndarray]: [N]
+ }
+ """
+ pixel_thr = config.TRAINER.RANSAC_PIXEL_THR # 0.5
+ conf = config.TRAINER.RANSAC_CONF # 0.99999
+ data.update({'R_errs': [], 't_errs': [], 'inliers': []})
+
+ m_bids = data['m_bids'].cpu().numpy()
+ pts0 = data['mkpts0_f'].cpu().numpy()
+ pts1 = data['mkpts1_f'].cpu().numpy()
+ K0 = data['K0'].cpu().numpy()
+ K1 = data['K1'].cpu().numpy()
+ T_0to1 = data['T_0to1'].cpu().numpy()
+
+ for bs in range(K0.shape[0]):
+ mask = m_bids == bs
+ ret = estimate_pose(pts0[mask], pts1[mask], K0[bs], K1[bs], pixel_thr, conf=conf)
+
+ if ret is None:
+ data['R_errs'].append(np.inf)
+ data['t_errs'].append(np.inf)
+ data['inliers'].append(np.array([]).astype(bool))
+ else:
+ R, t, inliers = ret
+ t_err, R_err = relative_pose_error(T_0to1[bs], R, t, ignore_gt_t_thr=0.0)
+ data['R_errs'].append(R_err)
+ data['t_errs'].append(t_err)
+ data['inliers'].append(inliers)
+
+
+# --- METRIC AGGREGATION ---
+
+def error_auc(errors, thresholds):
+ """
+ Args:
+ errors (list): [N,]
+ thresholds (list)
+ """
+ errors = [0] + sorted(list(errors))
+ recall = list(np.linspace(0, 1, len(errors)))
+
+ aucs = []
+ for thr in thresholds:
+ last_index = np.searchsorted(errors, thr)
+ y = recall[:last_index] + [recall[last_index-1]]
+ x = errors[:last_index] + [thr]
+ aucs.append(np.trapz(y, x) / thr)
+
+ return {f'auc@{t}': auc for t, auc in zip(thresholds, aucs)}
+
+
+def epidist_prec(errors, thresholds, ret_dict=False):
+ precs = []
+ for thr in thresholds:
+ prec_ = []
+ for errs in errors:
+ correct_mask = errs < thr
+ prec_.append(np.mean(correct_mask) if len(correct_mask) > 0 else 0)
+ precs.append(np.mean(prec_) if len(prec_) > 0 else 0)
+ if ret_dict:
+ return {f'prec@{t:.0e}': prec for t, prec in zip(thresholds, precs)}
+ else:
+ return precs
+
+
+def aggregate_metrics(metrics, epi_err_thr=5e-4):
+ """ Aggregate metrics for the whole dataset:
+ (This method should be called once per dataset)
+ 1. AUC of the pose error (angular) at the threshold [5, 10, 20]
+ 2. Mean matching precision at the threshold 5e-4 (ScanNet), 1e-4 (MegaDepth)
+ """
+ # filter duplicates
+ unq_ids = OrderedDict((iden, id) for id, iden in enumerate(metrics['identifiers']))
+ unq_ids = list(unq_ids.values())
+ logger.info(f'Aggregating metrics over {len(unq_ids)} unique items...')
+
+ # pose auc
+ angular_thresholds = [5, 10, 20]
+ pose_errors = np.max(np.stack([metrics['R_errs'], metrics['t_errs']]), axis=0)[unq_ids]
+ aucs = error_auc(pose_errors, angular_thresholds) # (auc@5, auc@10, auc@20)
+
+ # matching precision
+ dist_thresholds = [epi_err_thr]
+ precs = epidist_prec(np.array(metrics['epi_errs'], dtype=object)[unq_ids], dist_thresholds, True) # (prec@err_thr)
+
+ return {**aucs, **precs}
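As a small worked example of the AUC aggregation: with pose errors of 1°, 3° and 30°, the recall curve reaches 2/3 inside the 5° window, and integrating it gives auc@5 = 0.5 (import path assumed; torch, kornia and OpenCV must be installed for the module import).

```python
import numpy as np
from src.utils.metrics import error_auc

errors = np.array([1.0, 3.0, 30.0])        # per-pair max of rotation/translation error, degrees
aucs = error_auc(errors, thresholds=[5, 10, 20])
print(aucs)   # {'auc@5': 0.5, 'auc@10': ~0.583, 'auc@20': ~0.625}
```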
diff --git a/imcui/third_party/XoFTR/src/utils/misc.py b/imcui/third_party/XoFTR/src/utils/misc.py
new file mode 100644
index 0000000000000000000000000000000000000000..9c8db04666519753ea2df43903ab6c47ec00a9a1
--- /dev/null
+++ b/imcui/third_party/XoFTR/src/utils/misc.py
@@ -0,0 +1,101 @@
+import os
+import contextlib
+import joblib
+from typing import Union
+from loguru import _Logger, logger
+from itertools import chain
+
+import torch
+from yacs.config import CfgNode as CN
+from pytorch_lightning.utilities import rank_zero_only
+
+
+def lower_config(yacs_cfg):
+ if not isinstance(yacs_cfg, CN):
+ return yacs_cfg
+ return {k.lower(): lower_config(v) for k, v in yacs_cfg.items()}
+
+
+def upper_config(dict_cfg):
+ if not isinstance(dict_cfg, dict):
+ return dict_cfg
+ return {k.upper(): upper_config(v) for k, v in dict_cfg.items()}
+
+
+def log_on(condition, message, level):
+ if condition:
+ assert level in ['INFO', 'DEBUG', 'WARNING', 'ERROR', 'CRITICAL']
+ logger.log(level, message)
+
+
+def get_rank_zero_only_logger(logger: _Logger):
+ if rank_zero_only.rank == 0:
+ return logger
+ else:
+ for _level in logger._core.levels.keys():
+ level = _level.lower()
+ setattr(logger, level,
+ lambda x: None)
+ logger._log = lambda x: None
+ return logger
+
+
+def setup_gpus(gpus: Union[str, int]) -> int:
+ """ A temporary fix for pytorch-lighting 1.3.x """
+ gpus = str(gpus)
+ gpu_ids = []
+
+ if ',' not in gpus:
+ n_gpus = int(gpus)
+ return n_gpus if n_gpus != -1 else torch.cuda.device_count()
+ else:
+ gpu_ids = [i.strip() for i in gpus.split(',') if i != '']
+
+ # setup environment variables
+ visible_devices = os.getenv('CUDA_VISIBLE_DEVICES')
+ if visible_devices is None:
+ os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
+ os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(str(i) for i in gpu_ids)
+ visible_devices = os.getenv('CUDA_VISIBLE_DEVICES')
+ logger.warning(f'[Temporary Fix] manually set CUDA_VISIBLE_DEVICES when specifying gpus to use: {visible_devices}')
+ else:
+ logger.warning('[Temporary Fix] CUDA_VISIBLE_DEVICES already set by user or the main process.')
+ return len(gpu_ids)
+
+
+def flattenList(x):
+ return list(chain(*x))
+
+
+@contextlib.contextmanager
+def tqdm_joblib(tqdm_object):
+ """Context manager to patch joblib to report into tqdm progress bar given as argument
+
+ Usage:
+ with tqdm_joblib(tqdm(desc="My calculation", total=10)) as progress_bar:
+ Parallel(n_jobs=16)(delayed(sqrt)(i**2) for i in range(10))
+
+ When iterating over a generator, direct use of tqdm is also a solution (but it monitors task queuing instead of completion)
+ ret_vals = Parallel(n_jobs=args.world_size)(
+ delayed(lambda x: _compute_cov_score(pid, *x))(param)
+ for param in tqdm(combinations(image_ids, 2),
+ desc=f'Computing cov_score of [{pid}]',
+ total=len(image_ids)*(len(image_ids)-1)/2))
+ Src: https://stackoverflow.com/a/58936697
+ """
+ class TqdmBatchCompletionCallback(joblib.parallel.BatchCompletionCallBack):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ def __call__(self, *args, **kwargs):
+ tqdm_object.update(n=self.batch_size)
+ return super().__call__(*args, **kwargs)
+
+ old_batch_callback = joblib.parallel.BatchCompletionCallBack
+ joblib.parallel.BatchCompletionCallBack = TqdmBatchCompletionCallback
+ try:
+ yield tqdm_object
+ finally:
+ joblib.parallel.BatchCompletionCallBack = old_batch_callback
+ tqdm_object.close()
+
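A quick sketch of how `setup_gpus` resolves its argument: a bare count is returned as-is (with -1 expanding to all visible devices), while a comma-separated id list is exported through CUDA_VISIBLE_DEVICES (if not already set) and its length is returned.

```python
from src.utils.misc import setup_gpus   # assumed import path

assert setup_gpus(2) == 2     # plain count is passed through
n = setup_gpus("0,1")         # may export CUDA_VISIBLE_DEVICES=0,1 as a side effect
assert n == 2
```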
diff --git a/imcui/third_party/XoFTR/src/utils/plotting.py b/imcui/third_party/XoFTR/src/utils/plotting.py
new file mode 100644
index 0000000000000000000000000000000000000000..09c312223394330125b2b5572b8fae9c7a12f599
--- /dev/null
+++ b/imcui/third_party/XoFTR/src/utils/plotting.py
@@ -0,0 +1,227 @@
+import bisect
+import numpy as np
+import matplotlib.pyplot as plt
+import matplotlib
+plt.switch_backend('agg')
+from einops.einops import rearrange
+import torch.nn.functional as F
+
+
+def _compute_conf_thresh(data):
+ dataset_name = data['dataset_name'][0].lower()
+ if dataset_name == 'scannet':
+ thr = 5e-4
+ elif dataset_name == 'megadepth':
+ thr = 1e-4
+ elif dataset_name == 'vistir':
+ thr = 5e-4
+ else:
+ raise ValueError(f'Unknown dataset: {dataset_name}')
+ return thr
+
+
+# --- VISUALIZATION --- #
+
+def make_matching_figure(
+ img0, img1, mkpts0, mkpts1, color,
+ kpts0=None, kpts1=None, text=[], dpi=75, path=None):
+ # draw image pair
+ assert mkpts0.shape[0] == mkpts1.shape[0], f'mkpts0: {mkpts0.shape[0]} v.s. mkpts1: {mkpts1.shape[0]}'
+ fig, axes = plt.subplots(1, 2, figsize=(10, 6), dpi=dpi)
+ axes[0].imshow(img0, cmap='gray')
+ axes[1].imshow(img1, cmap='gray')
+ for i in range(2): # clear all frames
+ axes[i].get_yaxis().set_ticks([])
+ axes[i].get_xaxis().set_ticks([])
+ for spine in axes[i].spines.values():
+ spine.set_visible(False)
+ plt.tight_layout(pad=1)
+
+ if kpts0 is not None:
+ assert kpts1 is not None
+ axes[0].scatter(kpts0[:, 0], kpts0[:, 1], c='w', s=2)
+ axes[1].scatter(kpts1[:, 0], kpts1[:, 1], c='w', s=2)
+
+ # draw matches
+ if mkpts0.shape[0] != 0 and mkpts1.shape[0] != 0:
+ fig.canvas.draw()
+ transFigure = fig.transFigure.inverted()
+ fkpts0 = transFigure.transform(axes[0].transData.transform(mkpts0))
+ fkpts1 = transFigure.transform(axes[1].transData.transform(mkpts1))
+ fig.lines = [matplotlib.lines.Line2D((fkpts0[i, 0], fkpts1[i, 0]),
+ (fkpts0[i, 1], fkpts1[i, 1]),
+ transform=fig.transFigure, c=color[i], linewidth=1)
+ for i in range(len(mkpts0))]
+
+ axes[0].scatter(mkpts0[:, 0], mkpts0[:, 1], c=color, s=4)
+ axes[1].scatter(mkpts1[:, 0], mkpts1[:, 1], c=color, s=4)
+
+ # put txts
+ txt_color = 'k' if img0[:100, :200].mean() > 200 else 'w'
+ fig.text(
+ 0.01, 0.99, '\n'.join(text), transform=fig.axes[0].transAxes,
+ fontsize=15, va='top', ha='left', color=txt_color)
+
+ # save or return figure
+ if path:
+ plt.savefig(str(path), bbox_inches='tight', pad_inches=0)
+ plt.close()
+ else:
+ return fig
+
+
+def _make_evaluation_figure(data, b_id, alpha='dynamic', ret_dict=None):
+ b_mask = data['m_bids'] == b_id
+ conf_thr = _compute_conf_thresh(data)
+
+ img0 = (data['image0'][b_id][0].cpu().numpy() * 255).round().astype(np.int32)
+ img1 = (data['image1'][b_id][0].cpu().numpy() * 255).round().astype(np.int32)
+ kpts0 = data['mkpts0_f'][b_mask].cpu().numpy()
+ kpts1 = data['mkpts1_f'][b_mask].cpu().numpy()
+
+ # for megadepth, we visualize matches on the resized image
+ if 'scale0' in data:
+ kpts0 = kpts0 / data['scale0'][b_id].cpu().numpy()[[1, 0]]
+ kpts1 = kpts1 / data['scale1'][b_id].cpu().numpy()[[1, 0]]
+
+ epi_errs = data['epi_errs'][b_mask].cpu().numpy()
+ correct_mask = epi_errs < conf_thr
+ precision = np.mean(correct_mask) if len(correct_mask) > 0 else 0
+ n_correct = np.sum(correct_mask)
+
+ # matching info
+ if alpha == 'dynamic':
+ alpha = dynamic_alpha(len(correct_mask))
+ color = error_colormap(epi_errs, conf_thr, alpha=alpha)
+
+ text = [
+ f'#Matches {len(kpts0)}',
+ f'Precision({conf_thr:.2e}) ({100 * precision:.1f}%): {n_correct}/{len(kpts0)}']
+ if ret_dict is not None:
+ text += [f"t_err: {ret_dict['metrics']['t_errs'][b_id]:.2f}",
+ f"R_err: {ret_dict['metrics']['R_errs'][b_id]:.2f}"]
+
+ # make the figure
+ figure = make_matching_figure(img0, img1, kpts0, kpts1,
+ color, text=text)
+ return figure
+
+def _make_confidence_figure(data, b_id):
+ # TODO: Implement confidence figure
+ raise NotImplementedError()
+
+
+def make_matching_figures(data, config, mode='evaluation', ret_dict=None):
+ """ Make matching figures for a batch.
+
+ Args:
+ data (Dict): a batch updated by PL_XoFTR.
+ config (Dict): matcher config
+ Returns:
+ figures (Dict[str, List[plt.figure]]
+ """
+ assert mode in ['evaluation', 'confidence'] # 'confidence'
+ figures = {mode: []}
+ for b_id in range(data['image0'].size(0)):
+ if mode == 'evaluation':
+ fig = _make_evaluation_figure(
+ data, b_id,
+ alpha=config.TRAINER.PLOT_MATCHES_ALPHA, ret_dict=ret_dict)
+ elif mode == 'confidence':
+ fig = _make_confidence_figure(data, b_id)
+ else:
+ raise ValueError(f'Unknown plot mode: {mode}')
+ figures[mode].append(fig)
+ return figures
+
+def make_mae_figures(data):
+ """ Make mae figures for a batch.
+
+ Args:
+ data (Dict): a batch updated by PL_XoFTR_Pretrain.
+ Returns:
+ figures (List[plt.figure])
+ """
+
+ scale = data['hw0_i'][0] // data['hw0_f'][0]
+ W_f = data["W_f"]
+
+ pred0, pred1 = data["pred0"], data["pred1"]
+ target0, target1 = data["target0"], data["target1"]
+
+ # replace masked regions with predictions
+ target0[data['b_ids'][data["ids_image0"]], data['i_ids'][data["ids_image0"]]] = pred0[data["ids_image0"]]
+ target1[data['b_ids'][data["ids_image1"]], data['j_ids'][data["ids_image1"]]] = pred1[data["ids_image1"]]
+
+ # remove excess parts, since the 10x10 windows have overlapping regions
+ target0 = rearrange(target0, 'n l (h w) (p q c) -> n c (h p) (w q) l', h=W_f, w=W_f, p=scale, q=scale, c=1)
+ target1 = rearrange(target1, 'n l (h w) (p q c) -> n c (h p) (w q) l', h=W_f, w=W_f, p=scale, q=scale, c=1)
+ # target0[:,:,-scale:,:] = 0.0
+ # target0[:,:,:,-scale:] = 0.0
+ # target1[:,:,-scale:,:] = 0.0
+ # target1[:,:,:,-scale:] = 0.0
+ gap = scale //2
+ target0[:,:,-gap:,:] = 0.0
+ target0[:,:,:,-gap:] = 0.0
+ target1[:,:,-gap:,:] = 0.0
+ target1[:,:,:,-gap:] = 0.0
+ target0[:,:,:gap,:] = 0.0
+ target0[:,:,:,:gap] = 0.0
+ target1[:,:,:gap,:] = 0.0
+ target1[:,:,:,:gap] = 0.0
+ target0 = rearrange(target0, 'n c (h p) (w q) l -> n (c h p w q) l', h=W_f, w=W_f, p=scale, q=scale, c=1)
+ target1 = rearrange(target1, 'n c (h p) (w q) l -> n (c h p w q) l', h=W_f, w=W_f, p=scale, q=scale, c=1)
+
+ # windows to image
+ kernel_size = [int(W_f*scale), int(W_f*scale)]
+ padding = kernel_size[0]//2 -1 if kernel_size[0] % 2 == 0 else kernel_size[0]//2
+ stride = data['hw0_i'][0] // data['hw0_c'][0]
+ target0 = F.fold(target0, output_size=data["image0"].shape[2:], kernel_size=kernel_size, stride=stride, padding=padding)
+ target1 = F.fold(target1, output_size=data["image1"].shape[2:], kernel_size=kernel_size, stride=stride, padding=padding)
+
+ # add mean and std of original image for visualization
+ if ("image0_norm" in data) and ("image1_norm" in data):
+ target0 = target0 * data["image0_std"] + data["image0_mean"]
+ target1 = target1 * data["image1_std"] + data["image1_mean"]
+ masked_image0 = data["masked_image0"] * data["image0_std"].to("cpu") + data["image0_mean"].to("cpu")
+ masked_image1 = data["masked_image1"] * data["image1_std"].to("cpu") + data["image1_mean"].to("cpu")
+ else:
+ masked_image0 = data["masked_image0"]
+ masked_image1 = data["masked_image1"]
+
+ figures = []
+ # Create a list of these tensors
+ image_groups = [[data["image0"], masked_image0, target0],
+ [data["image1"], masked_image1, target1]]
+
+ # Iterate through the batches
+ for batch_idx in range(image_groups[0][0].shape[0]): # Assuming batch dimension is the first dimension
+ fig, axs = plt.subplots(2, 3, figsize=(9, 6))
+ for i, image_tensors in enumerate(image_groups):
+ for j, img_tensor in enumerate(image_tensors):
+ img = img_tensor[batch_idx, 0, :, :].detach().cpu().numpy() # Get the image data as a NumPy array
+ axs[i,j].imshow(img, cmap='gray', vmin=0, vmax=1) # Display the image in a subplot with correct colormap
+ axs[i,j].axis('off') # Turn off axis labels
+ fig.tight_layout()
+ figures.append(fig)
+ return figures
+
+def dynamic_alpha(n_matches,
+ milestones=[0, 300, 1000, 2000],
+ alphas=[1.0, 0.8, 0.4, 0.2]):
+ if n_matches == 0:
+ return 1.0
+ ranges = list(zip(alphas, alphas[1:] + [None]))
+ loc = bisect.bisect_right(milestones, n_matches) - 1
+ _range = ranges[loc]
+ if _range[1] is None:
+ return _range[0]
+ return _range[1] + (milestones[loc + 1] - n_matches) / (
+ milestones[loc + 1] - milestones[loc]) * (_range[0] - _range[1])
+
+
+def error_colormap(err, thr, alpha=1.0):
+ assert alpha <= 1.0 and alpha > 0, f"Invalid alpha value: {alpha}"
+ x = 1 - np.clip(err / (thr * 2), 0, 1)
+ return np.clip(
+ np.stack([2-x*2, x*2, np.zeros_like(x), np.ones_like(x)*alpha], -1), 0, 1)
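The alpha of the match lines is linearly interpolated between the milestone values; 650 matches, for instance, sits halfway between the 300 and 1000 milestones, so the alpha lands halfway between 0.8 and 0.4:

```python
import numpy as np
from src.utils.plotting import dynamic_alpha, error_colormap   # assumed import path

assert abs(dynamic_alpha(650) - 0.6) < 1e-9     # halfway between 0.8 and 0.4
colors = error_colormap(np.array([0.0, 1e-4, 1e-3]), 5e-4, alpha=dynamic_alpha(650))
print(colors.shape)   # (3, 4) RGBA rows: green for small epipolar errors, red for large ones
```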
diff --git a/imcui/third_party/XoFTR/src/utils/pretrain_utils.py b/imcui/third_party/XoFTR/src/utils/pretrain_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..1ad0296d2c249d6d0adb575422efaf1d01ba69fe
--- /dev/null
+++ b/imcui/third_party/XoFTR/src/utils/pretrain_utils.py
@@ -0,0 +1,83 @@
+import torch
+import torch.nn as nn
+from einops.einops import rearrange
+import torch.nn.functional as F
+
+def generate_random_masks(batch, patch_size, mask_ratio, generator=None, margins=[0,0,0,0]):
+ mae_mask0 = _gen_random_mask(batch['image0'], patch_size, mask_ratio, generator, margins=margins)
+ mae_mask1 = _gen_random_mask(batch['image1'], patch_size, mask_ratio, generator, margins=margins)
+ batch.update({"mae_mask0" : mae_mask0, "mae_mask1": mae_mask1})
+
+def _gen_random_mask(image, patch_size, mask_ratio, generator=None, margins=[0, 0, 0, 0]):
+ """ Random mask generator
+ Args:
+ image (torch.Tensor): [N, C, H, W]
+ patch_size (int)
+ mask_ratio (float)
+ generator (torch.Generator): RNG to create the same random masks for validation
+ margins (List[float]): fractions of the patch grid excluded from masking (top, bottom, left, right)
+ Returns:
+ mask (torch.Tensor): (N, L)
+ """
+ N = image.shape[0]
+ l = (image.shape[2] // patch_size)
+ L = l ** 2
+ len_keep = int(L * (1 - mask_ratio * (1 - sum(margins))))
+
+ margins = [int(margin * l) for margin in margins]
+
+ noise = torch.rand(N, l, l, device=image.device, generator=generator)
+ if margins[0] > 0 : noise[:,:margins[0],:] = 0
+ if margins[1] > 0 : noise[:,-margins[1]:,:] = 0
+ if margins[2] > 0 : noise[:,:,:margins[2]] = 0
+ if margins[3] > 0 : noise[:,:,-margins[3]:] = 0
+ noise = noise.flatten(1)
+
+ # sort noise for each sample
+ ids_shuffle = torch.argsort(noise, dim=1)
+ ids_restore = torch.argsort(ids_shuffle, dim=1)
+
+ # generate the binary mask: 0 is keep 1 is remove
+ mask = torch.ones([N, L], device=image.device)
+ mask[:, :len_keep] = 0
+ # unshuffle to get the binary mask
+ mask = torch.gather(mask, dim=1, index=ids_restore)
+ return mask
+
+def patchify(data):
+ """ Split images into small overlapped patches
+ Args:
+ data (dict): {
+ 'image0_norm' (torch.Tensor): [N, C, H, W] normalized image,
+ 'image1_norm' (torch.Tensor): [N, C, H, W] normalized image,
+ }
+ Returns:
+ image0 (torch.Tensor): [N, K, W_f**2, -1] (K: num of windows)
+ image1 (torch.Tensor): [N, K, W_f**2, -1] (K: num of windows)
+ """
+ stride = data['hw0_i'][0] // data['hw0_c'][0]
+ scale = data['hw0_i'][0] // data['hw0_f'][0]
+ W_f = data["W_f"]
+ kernel_size = [int(W_f*scale), int(W_f*scale)]
+ padding = kernel_size[0]//2 -1 if kernel_size[0] % 2 == 0 else kernel_size[0]//2
+
+ image0 = data["image0_norm"] if "image0_norm" in data else data["image0"]
+ image1 = data["image1_norm"] if "image1_norm" in data else data["image1"]
+
+ image0 = F.unfold(image0, kernel_size=kernel_size, stride=stride, padding=padding)
+ image0 = rearrange(image0, 'n (c h p w q) l -> n l h w p q c', h=W_f, w=W_f, p=scale, q=scale)
+ image0 = image0.flatten(4)
+ image0 = image0.reshape(*image0.shape[:2], W_f**2, -1)
+
+ image1 = F.unfold(image1, kernel_size=kernel_size, stride=stride, padding=padding)
+ image1 = rearrange(image1, 'n (c h p w q) l -> n l h w p q c', h=W_f, w=W_f, p=scale, q=scale)
+ image1 = image1.flatten(4)
+ image1 = image1.reshape(*image1.shape[:2], W_f**2, -1)
+
+ return image0, image1
+
+def get_target(data):
+ """Create target patches for mae"""
+ target0, target1 = patchify(data)
+ data.update({"target0":target0, "target1":target1})
+
+
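A small sketch of the random mask generator, using a fixed `torch.Generator` the way the validation path would to make masks reproducible: with no margins, exactly L - len_keep patches (here half of the 8x8 grid) are flagged for removal.

```python
import torch
from src.utils.pretrain_utils import _gen_random_mask   # assumed import path

image = torch.zeros(2, 1, 64, 64)           # dummy batch; 8x8 patch grid with patch_size=8
gen = torch.Generator().manual_seed(0)
mask = _gen_random_mask(image, patch_size=8, mask_ratio=0.5, generator=gen)
print(mask.shape, mask.sum(dim=1))          # torch.Size([2, 64]) tensor([32., 32.])
```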
diff --git a/imcui/third_party/XoFTR/src/utils/profiler.py b/imcui/third_party/XoFTR/src/utils/profiler.py
new file mode 100644
index 0000000000000000000000000000000000000000..6d21ed79fb506ef09c75483355402c48a195aaa9
--- /dev/null
+++ b/imcui/third_party/XoFTR/src/utils/profiler.py
@@ -0,0 +1,39 @@
+import torch
+from pytorch_lightning.profiler import SimpleProfiler, PassThroughProfiler
+from contextlib import contextmanager
+from pytorch_lightning.utilities import rank_zero_only
+
+
+class InferenceProfiler(SimpleProfiler):
+ """
+ This profiler records the duration of actions with cuda.synchronize().
+ Use it at test time.
+ """
+
+ def __init__(self):
+ super().__init__()
+ self.start = rank_zero_only(self.start)
+ self.stop = rank_zero_only(self.stop)
+ self.summary = rank_zero_only(self.summary)
+
+ @contextmanager
+ def profile(self, action_name: str) -> None:
+ try:
+ torch.cuda.synchronize()
+ self.start(action_name)
+ yield action_name
+ finally:
+ torch.cuda.synchronize()
+ self.stop(action_name)
+
+
+def build_profiler(name):
+ if name == 'inference':
+ return InferenceProfiler()
+ elif name == 'pytorch':
+ from pytorch_lightning.profiler import PyTorchProfiler
+ return PyTorchProfiler(use_cuda=True, profile_memory=True, row_limit=100)
+ elif name is None:
+ return PassThroughProfiler()
+ else:
+ raise ValueError(f'Invalid profiler: {name}')
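A brief usage sketch (profiler API as in the pytorch-lightning version pinned by the project): `InferenceProfiler` brackets each region with `cuda.synchronize()` so timings include kernel execution, while passing `None` yields a no-op profiler that is safe on CPU-only machines.

```python
from src.utils.profiler import build_profiler   # assumed import path

profiler = build_profiler(None)                  # PassThroughProfiler: no-op timing
with profiler.profile("coarse_matching"):
    pass                                         # timed region (model forward, matching, ...)
```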
diff --git a/imcui/third_party/XoFTR/src/xoftr/__init__.py b/imcui/third_party/XoFTR/src/xoftr/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..84ad2ecd63cab23f8ae42d49024eb448a6e18a59
--- /dev/null
+++ b/imcui/third_party/XoFTR/src/xoftr/__init__.py
@@ -0,0 +1,2 @@
+from .xoftr import XoFTR
+from .xoftr_pretrain import XoFTR_Pretrain
diff --git a/imcui/third_party/XoFTR/src/xoftr/backbone/__init__.py b/imcui/third_party/XoFTR/src/xoftr/backbone/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a3f84fff9207161b2ef7cda5aeaae5e1e1c9180f
--- /dev/null
+++ b/imcui/third_party/XoFTR/src/xoftr/backbone/__init__.py
@@ -0,0 +1 @@
+from .resnet import ResNet_8_2
diff --git a/imcui/third_party/XoFTR/src/xoftr/backbone/resnet.py b/imcui/third_party/XoFTR/src/xoftr/backbone/resnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..86d6b24fc31ade00e2f434ee4bc35618c45832ed
--- /dev/null
+++ b/imcui/third_party/XoFTR/src/xoftr/backbone/resnet.py
@@ -0,0 +1,95 @@
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+def conv1x1(in_planes, out_planes, stride=1):
+ """1x1 convolution without padding"""
+ return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, padding=0, bias=False)
+
+
+def conv3x3(in_planes, out_planes, stride=1):
+ """3x3 convolution with padding"""
+ return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
+
+
+class BasicBlock(nn.Module):
+ def __init__(self, in_planes, planes, stride=1):
+ super().__init__()
+ self.conv1 = conv3x3(in_planes, planes, stride)
+ self.conv2 = conv3x3(planes, planes)
+ self.bn1 = nn.BatchNorm2d(planes)
+ self.bn2 = nn.BatchNorm2d(planes)
+ self.relu = nn.ReLU(inplace=True)
+
+ if stride == 1:
+ self.downsample = None
+ else:
+ self.downsample = nn.Sequential(
+ conv1x1(in_planes, planes, stride=stride),
+ nn.BatchNorm2d(planes)
+ )
+
+ def forward(self, x):
+ y = x
+ y = self.relu(self.bn1(self.conv1(y)))
+ y = self.bn2(self.conv2(y))
+
+ if self.downsample is not None:
+ x = self.downsample(x)
+
+ return self.relu(x+y)
+
+class ResNet_8_2(nn.Module):
+ """
+ ResNet backbone; output resolutions are 1/8, 1/4, and 1/2.
+ Each block has 2 layers.
+ """
+
+ def __init__(self, config):
+ super().__init__()
+ # Config
+ block = BasicBlock
+ initial_dim = config['initial_dim']
+ block_dims = config['block_dims']
+
+ # Class Variable
+ self.in_planes = initial_dim
+
+ # Networks
+ self.conv1 = nn.Conv2d(1, initial_dim, kernel_size=7, stride=2, padding=3, bias=False)
+ self.bn1 = nn.BatchNorm2d(initial_dim)
+ self.relu = nn.ReLU(inplace=True)
+
+ self.layer1 = self._make_layer(block, block_dims[0], stride=1) # 1/2
+ self.layer2 = self._make_layer(block, block_dims[1], stride=2) # 1/4
+ self.layer3 = self._make_layer(block, block_dims[2], stride=2) # 1/8
+
+ self.layer3_outconv = conv1x1(block_dims[2], block_dims[2])
+
+
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+ elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
+ nn.init.constant_(m.weight, 1)
+ nn.init.constant_(m.bias, 0)
+
+ def _make_layer(self, block, dim, stride=1):
+ layer1 = block(self.in_planes, dim, stride=stride)
+ layer2 = block(dim, dim, stride=1)
+ layers = (layer1, layer2)
+
+ self.in_planes = dim
+ return nn.Sequential(*layers)
+
+ def forward(self, x):
+ # ResNet Backbone
+ x0 = self.relu(self.bn1(self.conv1(x)))
+ x1 = self.layer1(x0) # 1/2
+ x2 = self.layer2(x1) # 1/4
+ x3 = self.layer3(x2) # 1/8
+
+ x3_out = self.layer3_outconv(x3)
+
+ return x3_out, x2, x1
+
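A shape-check sketch for the backbone; the `initial_dim`/`block_dims` values below are illustrative stand-ins for whatever the project config specifies. A 1-channel input comes back as features at 1/8, 1/4 and 1/2 resolution.

```python
import torch
from src.xoftr.backbone import ResNet_8_2

config = {"initial_dim": 128, "block_dims": [128, 196, 256]}   # illustrative channel sizes
net = ResNet_8_2(config)
x = torch.randn(1, 1, 480, 640)                                # grayscale input
feat_8, feat_4, feat_2 = net(x)
print(feat_8.shape, feat_4.shape, feat_2.shape)
# torch.Size([1, 256, 60, 80]) torch.Size([1, 196, 120, 160]) torch.Size([1, 128, 240, 320])
```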
diff --git a/imcui/third_party/XoFTR/src/xoftr/utils/geometry.py b/imcui/third_party/XoFTR/src/xoftr/utils/geometry.py
new file mode 100644
index 0000000000000000000000000000000000000000..23c9a318e1838614aa29de8949c9f2255cd73fdf
--- /dev/null
+++ b/imcui/third_party/XoFTR/src/xoftr/utils/geometry.py
@@ -0,0 +1,107 @@
+import torch
+
+
+@torch.no_grad()
+def warp_kpts(kpts0, depth0, depth1, T_0to1, K0, K1):
+ """ Warp kpts0 from I0 to I1 with depth, K and Rt
+ Also check covisibility and depth consistency.
+ Depth is consistent if relative error < 0.2 (hard-coded).
+
+ Args:
+ kpts0 (torch.Tensor): [N, L, 2] - <x, y>,
+ depth0 (torch.Tensor): [N, H, W],
+ depth1 (torch.Tensor): [N, H, W],
+ T_0to1 (torch.Tensor): [N, 3, 4],
+ K0 (torch.Tensor): [N, 3, 3],
+ K1 (torch.Tensor): [N, 3, 3],
+ Returns:
+ calculable_mask (torch.Tensor): [N, L]
+ warped_keypoints0 (torch.Tensor): [N, L, 2]
+ """
+ kpts0_long = kpts0.round().long()
+
+ # Sample depth, get calculable_mask on depth != 0
+ kpts0_depth = torch.stack(
+ [depth0[i, kpts0_long[i, :, 1], kpts0_long[i, :, 0]] for i in range(kpts0.shape[0])], dim=0
+ ) # (N, L)
+ nonzero_mask = kpts0_depth != 0
+
+ # Unproject
+ kpts0_h = torch.cat([kpts0, torch.ones_like(kpts0[:, :, [0]])], dim=-1) * kpts0_depth[..., None] # (N, L, 3)
+ kpts0_cam = K0.inverse() @ kpts0_h.transpose(2, 1) # (N, 3, L)
+
+ # Rigid Transform
+ w_kpts0_cam = T_0to1[:, :3, :3] @ kpts0_cam + T_0to1[:, :3, [3]] # (N, 3, L)
+ w_kpts0_depth_computed = w_kpts0_cam[:, 2, :]
+
+ # Project
+ w_kpts0_h = (K1 @ w_kpts0_cam).transpose(2, 1) # (N, L, 3)
+ w_kpts0 = w_kpts0_h[:, :, :2] / (w_kpts0_h[:, :, [2]] + 1e-4) # (N, L, 2), +1e-4 to avoid zero depth
+
+ # Covisible Check
+ h, w = depth1.shape[1:3]
+ covisible_mask = (w_kpts0[:, :, 0] > 0) * (w_kpts0[:, :, 0] < w-1) * \
+ (w_kpts0[:, :, 1] > 0) * (w_kpts0[:, :, 1] < h-1)
+ w_kpts0_long = w_kpts0.long()
+ w_kpts0_long[~covisible_mask, :] = 0
+
+ w_kpts0_depth = torch.stack(
+ [depth1[i, w_kpts0_long[i, :, 1], w_kpts0_long[i, :, 0]] for i in range(w_kpts0_long.shape[0])], dim=0
+ ) # (N, L)
+ consistent_mask = ((w_kpts0_depth - w_kpts0_depth_computed) / w_kpts0_depth).abs() < 0.2
+ valid_mask = nonzero_mask * covisible_mask * consistent_mask
+
+ return valid_mask, w_kpts0
+
+@torch.no_grad()
+def warp_kpts_fine(kpts0, depth0, depth1, T_0to1, K0, K1, b_ids):
+ """ Warp kpts0 from I0 to I1 with depth, K and Rt for give batch ids
+ Also check covisibility and depth consistency.
+ Depth is consistent if relative error < 0.2 (hard-coded).
+
+ Args:
+ kpts0 (torch.Tensor): [N, L, 2] - <x, y>,
+ depth0 (torch.Tensor): [N, H, W],
+ depth1 (torch.Tensor): [N, H, W],
+ T_0to1 (torch.Tensor): [N, 3, 4],
+ K0 (torch.Tensor): [N, 3, 3],
+ K1 (torch.Tensor): [N, 3, 3],
+ b_ids (torch.Tensor): [M], selected batch ids for fine-level matching
+ Returns:
+ calculable_mask (torch.Tensor): [N, L]
+ warped_keypoints0 (torch.Tensor): [N, L, 2]
+ """
+ kpts0_long = kpts0.round().long()
+
+ # Sample depth, get calculable_mask on depth != 0
+ kpts0_depth = torch.stack(
+ [depth0[b_ids[i], kpts0_long[i, :, 1], kpts0_long[i, :, 0]] for i in range(kpts0.shape[0])], dim=0
+ ) # (N, L)
+ nonzero_mask = kpts0_depth != 0
+
+ # Unproject
+ kpts0_h = torch.cat([kpts0, torch.ones_like(kpts0[:, :, [0]])], dim=-1) * kpts0_depth[..., None] # (N, L, 3)
+ kpts0_cam = K0[b_ids].inverse() @ kpts0_h.transpose(2, 1) # (N, 3, L)
+
+ # Rigid Transform
+ w_kpts0_cam = T_0to1[b_ids, :3, :3] @ kpts0_cam + T_0to1[b_ids, :3, [3]][...,None] # (N, 3, L)
+ w_kpts0_depth_computed = w_kpts0_cam[:, 2, :]
+
+ # Project
+ w_kpts0_h = (K1[b_ids] @ w_kpts0_cam).transpose(2, 1) # (N, L, 3)
+ w_kpts0 = w_kpts0_h[:, :, :2] / (w_kpts0_h[:, :, [2]] + 1e-4) # (N, L, 2), +1e-4 to avoid zero depth
+
+ # Covisible Check
+ h, w = depth1.shape[1:3]
+ covisible_mask = (w_kpts0[:, :, 0] > 0) * (w_kpts0[:, :, 0] < w-1) * \
+ (w_kpts0[:, :, 1] > 0) * (w_kpts0[:, :, 1] < h-1)
+ w_kpts0_long = w_kpts0.long()
+ w_kpts0_long[~covisible_mask, :] = 0
+
+ w_kpts0_depth = torch.stack(
+ [depth1[b_ids[i], w_kpts0_long[i, :, 1], w_kpts0_long[i, :, 0]] for i in range(w_kpts0_long.shape[0])], dim=0
+ ) # (N, L)
+ consistent_mask = ((w_kpts0_depth - w_kpts0_depth_computed) / w_kpts0_depth).abs() < 0.2
+ valid_mask = nonzero_mask * covisible_mask * consistent_mask
+
+ return valid_mask, w_kpts0
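A minimal sanity check for `warp_kpts` with identity extrinsics and constant depth (the intrinsics are placeholders): every in-bounds keypoint warps back onto itself, up to the 1e-4 added to the depth, and is flagged valid.

```python
import torch
from src.xoftr.utils.geometry import warp_kpts   # assumed import path

N, H, W = 1, 48, 64
kpts0 = torch.tensor([[[10.0, 12.0], [30.0, 20.0]]])            # [N, L, 2], (x, y)
depth = torch.full((N, H, W), 2.0)                              # constant depth map
T_0to1 = torch.eye(4)[None, :3, :]                              # identity extrinsics, [N, 3, 4]
K = torch.tensor([[[100.0, 0.0, 32.0],
                   [0.0, 100.0, 24.0],
                   [0.0, 0.0, 1.0]]])                           # placeholder intrinsics
valid, w_kpts0 = warp_kpts(kpts0, depth, depth, T_0to1, K, K)
print(valid)      # tensor([[True, True]])
print(w_kpts0)    # approximately the original keypoints
```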
diff --git a/imcui/third_party/XoFTR/src/xoftr/utils/position_encoding.py b/imcui/third_party/XoFTR/src/xoftr/utils/position_encoding.py
new file mode 100644
index 0000000000000000000000000000000000000000..5fa07546e2f62d73ef2da22b06466adeb01c5444
--- /dev/null
+++ b/imcui/third_party/XoFTR/src/xoftr/utils/position_encoding.py
@@ -0,0 +1,36 @@
+import math
+import torch
+from torch import nn
+
+
+class PositionEncodingSine(nn.Module):
+ """
+ This is a sinusoidal position encoding generalized to 2-dimensional images
+ """
+
+ def __init__(self, d_model, max_shape=(256, 256)):
+ """
+ Args:
+ max_shape (tuple): for 1/8 featmap, the max length of 256 corresponds to 2048 pixels
+ """
+ super().__init__()
+
+ pe = torch.zeros((d_model, *max_shape))
+ y_position = torch.ones(max_shape).cumsum(0).float().unsqueeze(0)
+ x_position = torch.ones(max_shape).cumsum(1).float().unsqueeze(0)
+ div_term = torch.exp(torch.arange(0, d_model//2, 2).float() * (-math.log(10000.0) / (d_model//2)))
+
+ div_term = div_term[:, None, None] # [C//4, 1, 1]
+ pe[0::4, :, :] = torch.sin(x_position * div_term)
+ pe[1::4, :, :] = torch.cos(x_position * div_term)
+ pe[2::4, :, :] = torch.sin(y_position * div_term)
+ pe[3::4, :, :] = torch.cos(y_position * div_term)
+
+ self.register_buffer('pe', pe.unsqueeze(0), persistent=False) # [1, C, H, W]
+
+ def forward(self, x):
+ """
+ Args:
+ x: [N, C, H, W]
+ """
+ return x + self.pe[:, :, :x.size(2), :x.size(3)]
\ No newline at end of file
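A short usage sketch: the encoding is precomputed once for a 256x256 feature grid and cropped to whatever coarse feature map is passed in, so only the channel count is needed at construction time (256 below is an illustrative value).

```python
import torch
from src.xoftr.utils.position_encoding import PositionEncodingSine

pos_enc = PositionEncodingSine(d_model=256)    # channel count of the coarse features
feat_c = torch.randn(2, 256, 60, 80)           # [N, C, H/8, W/8] coarse feature map
out = pos_enc(feat_c)
assert out.shape == feat_c.shape               # encoding is cropped to the input size
```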
diff --git a/imcui/third_party/XoFTR/src/xoftr/utils/supervision.py b/imcui/third_party/XoFTR/src/xoftr/utils/supervision.py
new file mode 100644
index 0000000000000000000000000000000000000000..8924aaec70b5dc40d899f69d61e6b74fba15d7c7
--- /dev/null
+++ b/imcui/third_party/XoFTR/src/xoftr/utils/supervision.py
@@ -0,0 +1,290 @@
+from math import log
+from loguru import logger
+
+import torch
+import torch.nn.functional as F
+from einops import repeat
+from kornia.utils import create_meshgrid
+from einops.einops import rearrange
+from .geometry import warp_kpts, warp_kpts_fine
+from kornia.geometry.epipolar import fundamental_from_projections, normalize_transformation
+
+############## ↓ Coarse-Level supervision ↓ ##############
+
+
+@torch.no_grad()
+def mask_pts_at_padded_regions(grid_pt, mask):
+ """For megadepth dataset, zero-padding exists in images"""
+ mask = repeat(mask, 'n h w -> n (h w) c', c=2)
+ grid_pt[~mask.bool()] = 0
+ return grid_pt
+
+
+@torch.no_grad()
+def spvs_coarse(data, config):
+ """
+ Update:
+ data (dict): {
+ "conf_matrix_gt": [N, hw0, hw1],
+ 'spv_b_ids': [M]
+ 'spv_i_ids': [M]
+ 'spv_j_ids': [M]
+ 'spv_w_pt0_i': [N, hw0, 2], in original image resolution
+ 'spv_pt1_i': [N, hw1, 2], in original image resolution
+ }
+
+ NOTE:
+ - for the ScanNet dataset, there are 3 kinds of resolution {i, c, f}
+ - for the MegaDepth dataset, there are 4 kinds of resolution {i, i_resize, c, f}
+ """
+ # 1. misc
+ device = data['image0'].device
+ N, _, H0, W0 = data['image0'].shape
+ _, _, H1, W1 = data['image1'].shape
+ scale = config['XOFTR']['RESOLUTION'][0]
+ scale0 = scale * data['scale0'][:, None] if 'scale0' in data else scale
+ scale1 = scale * data['scale1'][:, None] if 'scale1' in data else scale
+ h0, w0, h1, w1 = map(lambda x: x // scale, [H0, W0, H1, W1])
+
+ # 2. warp grids
+ # create kpts in meshgrid and resize them to image resolution
+ grid_pt0_c = create_meshgrid(h0, w0, False, device).reshape(1, h0*w0, 2).repeat(N, 1, 1) # [N, hw, 2]
+ grid_pt0_i = scale0 * grid_pt0_c
+ grid_pt1_c = create_meshgrid(h1, w1, False, device).reshape(1, h1*w1, 2).repeat(N, 1, 1)
+ grid_pt1_i = scale1 * grid_pt1_c
+
+ # mask padded region to (0, 0), so no need to manually mask conf_matrix_gt
+ if 'mask0' in data:
+ grid_pt0_i = mask_pts_at_padded_regions(grid_pt0_i, data['mask0'])
+ grid_pt1_i = mask_pts_at_padded_regions(grid_pt1_i, data['mask1'])
+
+ # warp kpts bi-directionally and resize them to coarse-level resolution
+ # (unhandled edge case: points with 0-depth will be warped to the top-left corner)
+ valid_mask0, w_pt0_i = warp_kpts(grid_pt0_i, data['depth0'], data['depth1'], data['T_0to1'], data['K0'], data['K1'])
+ valid_mask1, w_pt1_i = warp_kpts(grid_pt1_i, data['depth1'], data['depth0'], data['T_1to0'], data['K1'], data['K0'])
+ w_pt0_i[~valid_mask0] = 0
+ w_pt1_i[~valid_mask1] = 0
+ w_pt0_c = w_pt0_i / scale1
+ w_pt1_c = w_pt1_i / scale0
+
+ # 3. nearest neighbor
+ w_pt0_c_round = w_pt0_c[:, :, :].round().long()
+ nearest_index1 = w_pt0_c_round[..., 0] + w_pt0_c_round[..., 1] * w1
+ w_pt1_c_round = w_pt1_c[:, :, :].round().long()
+ nearest_index0 = w_pt1_c_round[..., 0] + w_pt1_c_round[..., 1] * w0
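+ # nearest_index{0,1} are the flattened (x + y * w) coarse-grid indices of the warped points in the other image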
+
+ # corner case: out of boundary
+ def out_bound_mask(pt, w, h):
+ return (pt[..., 0] < 0) + (pt[..., 0] >= w) + (pt[..., 1] < 0) + (pt[..., 1] >= h)
+ nearest_index1[out_bound_mask(w_pt0_c_round, w1, h1)] = 0
+ nearest_index0[out_bound_mask(w_pt1_c_round, w0, h0)] = 0
+
+ arange_1 = torch.arange(h0*w0, device=device)[None].repeat(N, 1)
+ arange_0 = torch.arange(h0*w0, device=device)[None].repeat(N, 1)
+ arange_1[nearest_index1 == 0] = 0
+ arange_0[nearest_index0 == 0] = 0
+ arange_b = torch.arange(N, device=device).unsqueeze(1)
+
+ # 4. construct a gt conf_matrix
+ conf_matrix_gt = torch.zeros(N, h0*w0, h1*w1, device=device)
+ conf_matrix_gt[arange_b, arange_1, nearest_index1] = 1
+ conf_matrix_gt[arange_b, nearest_index0, arange_0] = 1
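+ # invalid and zero-depth points were all warped to index 0, so the (0, 0) entry is cleared below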
+ conf_matrix_gt[:, 0, 0] = False
+
+ b_ids, i_ids, j_ids = conf_matrix_gt.nonzero(as_tuple=True)
+
+ data.update({'conf_matrix_gt': conf_matrix_gt})
+
+ # 5. save coarse matches(gt) for training fine level
+ if len(b_ids) == 0:
+ logger.warning(f"No groundtruth coarse match found for: {data['pair_names']}")
+ # this won't affect fine-level loss calculation
+ b_ids = torch.tensor([0], device=device)
+ i_ids = torch.tensor([0], device=device)
+ j_ids = torch.tensor([0], device=device)
+
+ data.update({
+ 'spv_b_ids': b_ids,
+ 'spv_i_ids': i_ids,
+ 'spv_j_ids': j_ids
+ })
+
+ # 6. save intermediate results (for fast fine-level computation)
+ data.update({
+ 'spv_w_pt0_i': w_pt0_i,
+ 'spv_pt1_i': grid_pt1_i
+ })
+
+def compute_supervision_coarse(data, config):
+ assert len(set(data['dataset_name'])) == 1, "Do not support mixed datasets training!"
+ data_source = data['dataset_name'][0]
+ if data_source.lower() in ['scannet', 'megadepth']:
+ spvs_coarse(data, config)
+ else:
+ raise ValueError(f'Unknown data source: {data_source}')
+
+
+############## ↓ Fine-Level supervision ↓ ##############
+
+def compute_supervision_fine(data, config):
+ data_source = data['dataset_name'][0]
+ if data_source.lower() in ['scannet', 'megadepth']:
+ spvs_fine(data, config)
+ else:
+ raise NotImplementedError
+
+@torch.no_grad()
+def create_2d_gaussian_kernel(kernel_size, sigma, device):
+ """
+ Create a 2D Gaussian kernel.
+
+ Args:
+ kernel_size (int): Size of the kernel (both width and height).
+ sigma (float): Standard deviation of the Gaussian distribution.
+
+ Returns:
+ torch.Tensor: 2D Gaussian kernel.
+ """
+ kernel = torch.arange(kernel_size, dtype=torch.float32, device=device) - (kernel_size - 1) / 2
+ kernel = torch.exp(-kernel**2 / (2 * sigma**2))
+ kernel = kernel / kernel.sum()
+
+ # Outer product to get a 2D kernel
+ kernel = torch.outer(kernel, kernel)
+
+ return kernel
+
+@torch.no_grad()
+def create_conf_prob(points, h0, w0, h1, w1, kernel_size = 5, sigma=1):
+ """
+ Place a Gaussian kernel in the similarity matrix at each warped point
+
+ Args:
+ points (torch.Tensor): (N, L, 2), warped rounded key points
+ h0, w0, h1, w1 (int): window sizes
+ kernel_size (int): kernel size for the Gaussian
+ sigma (float): sigma value for the Gaussian
+ """
+ B = points.shape[0]
+ impulses = torch.zeros(B, h0 * w0, h1, w1, device=points.device)
+
+ # Extract the row and column indices
+ row_indices = points[:, :, 1]
+ col_indices = points[:, :, 0]
+
+ # Set the corresponding locations in the target tensor to 1
+ impulses[torch.arange(B, device=points.device).view(B, 1, 1),
+ torch.arange(h0 * w0, device=points.device).view(1, h0 * w0, 1),
+ row_indices.unsqueeze(-1), col_indices.unsqueeze(-1)] = 1
+ # mask 0,0 point
+ impulses[:,:,0,0] = 0
+
+ # Create the Gaussian kernel
+ gaussian_kernel = create_2d_gaussian_kernel(kernel_size, sigma=sigma, device=points.device)
+ gaussian_kernel = gaussian_kernel.view(1, 1, kernel_size, kernel_size)
+
+ # Create distributions at the points
+ conf_prob = F.conv2d(impulses.view(-1,1,h1,w1), gaussian_kernel, padding=kernel_size//2).view(-1, h0*w0, h1*w1)
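+ # each row of conf_prob is now a soft (Gaussian-smoothed) target distribution over the h1 x w1 window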
+
+ return conf_prob
+
+@torch.no_grad()
+def spvs_fine(data, config):
+ """
+ Args:
+ data (dict): {
+ 'b_ids': [M]
+ 'i_ids': [M]
+ 'j_ids': [M]
+ }
+
+ Update:
+ data (dict): {
+ conf_matrix_f_gt: [N, W_f^2, W_f^2], in original image resolution
+ }
+
+ """
+ # 1. misc
+ device = data['image0'].device
+ N, _, H0, W0 = data['image0'].shape
+ _, _, H1, W1 = data['image1'].shape
+ scale = config['XOFTR']['RESOLUTION'][1]
+ scale0 = scale * data['scale0'][:, None] if 'scale0' in data else scale
+ scale1 = scale * data['scale1'][:, None] if 'scale1' in data else scale
+ h0, w0, h1, w1 = map(lambda x: x // scale, [H0, W0, H1, W1])
+ scale_f_c = config['XOFTR']['RESOLUTION'][0] // config['XOFTR']['RESOLUTION'][1]
+ W_f = config['XOFTR']['FINE_WINDOW_SIZE']
+ # 2. get coarse prediction
+ b_ids, i_ids, j_ids = data['b_ids'], data['i_ids'], data['j_ids']
+
+ if len(b_ids) == 0:
+ data.update({"conf_matrix_f_gt": torch.zeros(1,W_f*W_f,W_f*W_f, device=device)})
+ return
+
+ # 2. warp grids
+ # create kpts in meshgrid and resize them to image resolution
+ grid_pt0_c = create_meshgrid(h0, w0, False, device).repeat(N, 1, 1, 1) # [N, h0, w0, 2]
+ grid_pt0_i = scale0[:,None,...] * grid_pt0_c
+ grid_pt1_c = create_meshgrid(h1, w1, False, device).repeat(N, 1, 1, 1) # [N, h1, w1, 2]
+ grid_pt1_i = scale1[:,None,...] * grid_pt1_c
+
+ # unfold (crop windows) all local windows
+ stride_f = data['hw0_f'][0] // data['hw0_c'][0]
+
+ grid_pt0_i = rearrange(grid_pt0_i, 'n h w c -> n c h w')
+ grid_pt0_i = F.unfold(grid_pt0_i, kernel_size=(W_f, W_f), stride=stride_f, padding=W_f//2)
+ grid_pt0_i = rearrange(grid_pt0_i, 'n (c ww) l -> n l ww c', ww=W_f**2)
+ grid_pt0_i = grid_pt0_i[b_ids, i_ids]
+
+ grid_pt1_i = rearrange(grid_pt1_i, 'n h w c -> n c h w')
+ grid_pt1_i = F.unfold(grid_pt1_i, kernel_size=(W_f, W_f), stride=stride_f, padding=W_f//2)
+ grid_pt1_i = rearrange(grid_pt1_i, 'n (c ww) l -> n l ww c', ww=W_f**2)
+ grid_pt1_i = grid_pt1_i[b_ids, j_ids]
+
+ # warp kpts bi-directionally and resize them to fine-level resolution
+ # (no depth consistency check)
+ # (unhandled edge case: points with 0-depth will be warped to the top-left corner)
+ _, w_pt0_i = warp_kpts_fine(grid_pt0_i, data['depth0'], data['depth1'], data['T_0to1'], data['K0'], data['K1'], b_ids)
+ _, w_pt1_i = warp_kpts_fine(grid_pt1_i, data['depth1'], data['depth0'], data['T_1to0'], data['K1'], data['K0'], b_ids)
+ w_pt0_f = w_pt0_i / scale1[b_ids]
+ w_pt1_f = w_pt1_i / scale0[b_ids]
+
+ mkpts0_c_scaled_to_f = torch.stack(
+ [i_ids % data['hw0_c'][1], i_ids // data['hw0_c'][1]],
+ dim=1) * scale_f_c - W_f//2
+ mkpts1_c_scaled_to_f = torch.stack(
+ [j_ids % data['hw1_c'][1], j_ids // data['hw1_c'][1]],
+ dim=1) * scale_f_c - W_f//2
+
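+ # shift the warped points into local fine-window coordinates by subtracting each window's top-left corner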
+ w_pt0_f = w_pt0_f - mkpts1_c_scaled_to_f[:,None,:]
+ w_pt1_f = w_pt1_f - mkpts0_c_scaled_to_f[:,None,:]
+
+ # 3. check if mutual nearest neighbor
+ w_pt0_f_round = w_pt0_f[:, :, :].round().long()
+ w_pt1_f_round = w_pt1_f[:, :, :].round().long()
+ M = w_pt0_f.shape[0]
+
+ nearest_index1 = w_pt0_f_round[..., 0] + w_pt0_f_round[..., 1] * W_f
+ nearest_index0 = w_pt1_f_round[..., 0] + w_pt1_f_round[..., 1] * W_f
+
+ # corner case: out of boundary
+ def out_bound_mask(pt, w, h):
+ return (pt[..., 0] < 0) + (pt[..., 0] >= w) + (pt[..., 1] < 0) + (pt[..., 1] >= h)
+ nearest_index1[out_bound_mask(w_pt0_f_round, W_f, W_f)] = 0
+ nearest_index0[out_bound_mask(w_pt1_f_round, W_f, W_f)] = 0
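+ # out-of-window points are redirected to index 0 and rejected by the top-left-corner check below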
+
+ loop_back = torch.stack([nearest_index0[_b][_i] for _b, _i in enumerate(nearest_index1)], dim=0)
+ correct_0to1 = loop_back == torch.arange(W_f*W_f, device=device)[None].repeat(M, 1)
+ correct_0to1[:, 0] = False # ignore the top-left corner
+
+ # 4. construct a gt conf_matrix
+ conf_matrix_f_gt = torch.zeros(M, W_f*W_f, W_f*W_f, device=device)
+ b_ids, i_ids = torch.where(correct_0to1 != 0)
+ j_ids = nearest_index1[b_ids, i_ids]
+ conf_matrix_f_gt[b_ids, i_ids, j_ids] = 1
+
+ data.update({"conf_matrix_f_gt": conf_matrix_f_gt})
+
+
diff --git a/imcui/third_party/XoFTR/src/xoftr/xoftr.py b/imcui/third_party/XoFTR/src/xoftr/xoftr.py
new file mode 100644
index 0000000000000000000000000000000000000000..78bc7c867d82a7161ce38fd46230c04c8d26b60d
--- /dev/null
+++ b/imcui/third_party/XoFTR/src/xoftr/xoftr.py
@@ -0,0 +1,94 @@
+import torch
+import torch.nn as nn
+from einops.einops import rearrange
+from .backbone import ResNet_8_2
+from .utils.position_encoding import PositionEncodingSine
+from .xoftr_module import LocalFeatureTransformer, FineProcess, CoarseMatching, FineSubMatching
+
+class XoFTR(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ # Misc
+ self.config = config
+
+ # Modules
+ self.backbone = ResNet_8_2(config['resnet'])
+ self.pos_encoding = PositionEncodingSine(config['coarse']['d_model'])
+ self.loftr_coarse = LocalFeatureTransformer(config['coarse'])
+ self.coarse_matching = CoarseMatching(config['match_coarse'])
+ self.fine_process = FineProcess(config)
+ self.fine_matching = FineSubMatching(config)
+
+
+ def forward(self, data):
+ """
+ Update:
+ data (dict): {
+ 'image0': (torch.Tensor): (N, 1, H, W)
+ 'image1': (torch.Tensor): (N, 1, H, W)
+ 'mask0'(optional) : (torch.Tensor): (N, H, W) '0' indicates a padded position
+ 'mask1'(optional) : (torch.Tensor): (N, H, W)
+ }
+ """
+ # 1. Local Feature CNN
+ data.update({
+ 'bs': data['image0'].size(0),
+ 'hw0_i': data['image0'].shape[2:], 'hw1_i': data['image1'].shape[2:]
+ })
+
+ eps = 1e-6
+
+ image0_mean = data['image0'].mean(dim=[2,3], keepdim=True)
+ image0_std = data['image0'].std(dim=[2,3], keepdim=True)
+ image0 = (data['image0'] - image0_mean) / (image0_std + eps)
+
+ image1_mean = data['image1'].mean(dim=[2,3], keepdim=True)
+ image1_std = data['image1'].std(dim=[2,3], keepdim=True)
+ image1 = (data['image1'] - image1_mean) / (image1_std + eps)
+
+ if data['hw0_i'] == data['hw1_i']: # faster & better BN convergence
+ feats_c, feats_m, feats_f = self.backbone(torch.cat([image0, image1], dim=0))
+ (feat_c0, feat_c1) = feats_c.split(data['bs'])
+ (feat_m0, feat_m1) = feats_m.split(data['bs'])
+ (feat_f0, feat_f1) = feats_f.split(data['bs'])
+ else: # handle different input shapes
+ feat_c0, feat_m0, feat_f0 = self.backbone(image0)
+ feat_c1, feat_m1, feat_f1 = self.backbone(image1)
+
+ data.update({
+ 'hw0_c': feat_c0.shape[2:], 'hw1_c': feat_c1.shape[2:],
+ 'hw0_m': feat_m0.shape[2:], 'hw1_m': feat_m1.shape[2:],
+ 'hw0_f': feat_f0.shape[2:], 'hw1_f': feat_f1.shape[2:]
+ })
+
+ # save coarse features for fine matching
+ feat_c0_pre, feat_c1_pre = feat_c0.clone(), feat_c1.clone()
+
+ # 2. coarse-level loftr module
+ # add featmap with positional encoding, then flatten it to sequence [N, HW, C]
+ feat_c0 = rearrange(self.pos_encoding(feat_c0), 'n c h w -> n (h w) c')
+ feat_c1 = rearrange(self.pos_encoding(feat_c1), 'n c h w -> n (h w) c')
+
+ mask_c0 = mask_c1 = None # mask is useful in training
+ if 'mask0' in data:
+ mask_c0, mask_c1 = data['mask0'].flatten(-2), data['mask1'].flatten(-2)
+ feat_c0, feat_c1 = self.loftr_coarse(feat_c0, feat_c1, mask_c0, mask_c1)
+
+ # 3. match coarse-level
+ self.coarse_matching(feat_c0, feat_c1, data, mask_c0=mask_c0, mask_c1=mask_c1)
+
+ # 4. fine-level matching module
+ feat_f0_unfold, feat_f1_unfold = self.fine_process(feat_f0, feat_f1,
+ feat_m0, feat_m1,
+ feat_c0, feat_c1,
+ feat_c0_pre, feat_c1_pre,
+ data)
+
+ # 5. match fine-level and sub-pixel refinement
+ self.fine_matching(feat_f0_unfold, feat_f1_unfold, data)
+
+ def load_state_dict(self, state_dict, *args, **kwargs):
+ for k in list(state_dict.keys()):
+ if k.startswith('matcher.'):
+ state_dict[k.replace('matcher.', '', 1)] = state_dict.pop(k)
+ return super().load_state_dict(state_dict, *args, **kwargs)
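+
+# Illustrative usage sketch (assumptions: a config dict `cfg` with the keys used above
+# and a checkpoint path `ckpt_path`; these names are hypothetical, not part of this file):
+#   matcher = XoFTR(cfg)
+#   matcher.load_state_dict(torch.load(ckpt_path)['state_dict'], strict=False)
+#   data = {'image0': img0, 'image1': img1}  # grayscale tensors of shape (1, 1, H, W)
+#   with torch.no_grad():
+#       matcher(data)  # fills data['mkpts0_f'], data['mkpts1_f'], data['mconf_f']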
diff --git a/imcui/third_party/XoFTR/src/xoftr/xoftr_module/__init__.py b/imcui/third_party/XoFTR/src/xoftr/xoftr_module/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9a99d3559e31f93291afbcf65ade17cf3616bbc9
--- /dev/null
+++ b/imcui/third_party/XoFTR/src/xoftr/xoftr_module/__init__.py
@@ -0,0 +1,4 @@
+from .transformer import LocalFeatureTransformer
+from .fine_process import FineProcess
+from .coarse_matching import CoarseMatching
+from .fine_matching import FineSubMatching
diff --git a/imcui/third_party/XoFTR/src/xoftr/xoftr_module/coarse_matching.py b/imcui/third_party/XoFTR/src/xoftr/xoftr_module/coarse_matching.py
new file mode 100644
index 0000000000000000000000000000000000000000..34f82abbbc912a8761cf2858183041ed17d7cd70
--- /dev/null
+++ b/imcui/third_party/XoFTR/src/xoftr/xoftr_module/coarse_matching.py
@@ -0,0 +1,305 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops.einops import rearrange
+
+INF = 1e9
+
+def mask_border(m, b: int, v):
+ """ Mask borders with value
+ Args:
+ m (torch.Tensor): [N, H0, W0, H1, W1]
+ b (int)
+ v (m.dtype)
+ """
+ if b <= 0:
+ return
+
+ m[:, :b] = v
+ m[:, :, :b] = v
+ m[:, :, :, :b] = v
+ m[:, :, :, :, :b] = v
+ m[:, -b:] = v
+ m[:, :, -b:] = v
+ m[:, :, :, -b:] = v
+ m[:, :, :, :, -b:] = v
+
+
+def mask_border_with_padding(m, bd, v, p_m0, p_m1):
+ if bd <= 0:
+ return
+
+ m[:, :bd] = v
+ m[:, :, :bd] = v
+ m[:, :, :, :bd] = v
+ m[:, :, :, :, :bd] = v
+
+ h0s, w0s = p_m0.sum(1).max(-1)[0].int(), p_m0.sum(-1).max(-1)[0].int()
+ h1s, w1s = p_m1.sum(1).max(-1)[0].int(), p_m1.sum(-1).max(-1)[0].int()
+ for b_idx, (h0, w0, h1, w1) in enumerate(zip(h0s, w0s, h1s, w1s)):
+ m[b_idx, h0 - bd:] = v
+ m[b_idx, :, w0 - bd:] = v
+ m[b_idx, :, :, h1 - bd:] = v
+ m[b_idx, :, :, :, w1 - bd:] = v
+
+
+def compute_max_candidates(p_m0, p_m1):
+ """Compute the max candidates of all pairs within a batch
+
+ Args:
+ p_m0, p_m1 (torch.Tensor): padded masks
+ """
+ h0s, w0s = p_m0.sum(1).max(-1)[0], p_m0.sum(-1).max(-1)[0]
+ h1s, w1s = p_m1.sum(1).max(-1)[0], p_m1.sum(-1).max(-1)[0]
+ max_cand = torch.sum(
+ torch.min(torch.stack([h0s * w0s, h1s * w1s], -1), -1)[0])
+ return max_cand
+
+
+class CoarseMatching(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ # general config
+ d_model = config['d_model']
+ self.thr = config['thr']
+ self.inference = config['inference']
+ self.border_rm = config['border_rm']
+ # -- # for training fine-level XoFTR
+ self.train_coarse_percent = config['train_coarse_percent']
+ self.train_pad_num_gt_min = config['train_pad_num_gt_min']
+ self.final_proj = nn.Linear(d_model, d_model, bias=True)
+
+ self.temperature = config['dsmax_temperature']
+
+ def forward(self, feat_c0, feat_c1, data, mask_c0=None, mask_c1=None):
+ """
+ Args:
+ feat0 (torch.Tensor): [N, L, C]
+ feat1 (torch.Tensor): [N, S, C]
+ data (dict)
+ mask_c0 (torch.Tensor): [N, L] (optional)
+ mask_c1 (torch.Tensor): [N, S] (optional)
+ Update:
+ data (dict): {
+ 'b_ids' (torch.Tensor): [M'],
+ 'i_ids' (torch.Tensor): [M'],
+ 'j_ids' (torch.Tensor): [M'],
+ 'gt_mask' (torch.Tensor): [M'],
+ 'mkpts0_c' (torch.Tensor): [M, 2],
+ 'mkpts1_c' (torch.Tensor): [M, 2],
+ 'mconf' (torch.Tensor): [M]}
+ NOTE: M' != M during training.
+ """
+
+ feat_c0 = self.final_proj(feat_c0)
+ feat_c1 = self.final_proj(feat_c1)
+
+ # normalize
+ feat_c0, feat_c1 = map(lambda feat: feat / feat.shape[-1]**.5,
+ [feat_c0, feat_c1])
+
+ sim_matrix = torch.einsum("nlc,nsc->nls", feat_c0,
+ feat_c1) / self.temperature
+ if mask_c0 is not None:
+ sim_matrix.masked_fill_(
+ ~(mask_c0[..., None] * mask_c1[:, None]).bool(),
+ -INF)
+ if self.inference:
+ # predict coarse matches from conf_matrix
+ data.update(**self.get_coarse_match_inference(sim_matrix, data))
+ else:
+ conf_matrix_0_to_1 = F.softmax(sim_matrix, 2)
+ conf_matrix_1_to_0 = F.softmax(sim_matrix, 1)
+ data.update({'conf_matrix_0_to_1': conf_matrix_0_to_1,
+ 'conf_matrix_1_to_0': conf_matrix_1_to_0
+ })
+ # predict coarse matches from conf_matrix
+ data.update(**self.get_coarse_match_training(conf_matrix_0_to_1, conf_matrix_1_to_0, data))
+
+ @torch.no_grad()
+ def get_coarse_match_training(self, conf_matrix_0_to_1, conf_matrix_1_to_0, data):
+ """
+ Args:
+ conf_matrix_0_to_1 (torch.Tensor): [N, L, S]
+ conf_matrix_1_to_0 (torch.Tensor): [N, L, S]
+ data (dict): with keys ['hw0_i', 'hw1_i', 'hw0_c', 'hw1_c']
+ Returns:
+ coarse_matches (dict): {
+ 'b_ids' (torch.Tensor): [M'],
+ 'i_ids' (torch.Tensor): [M'],
+ 'j_ids' (torch.Tensor): [M'],
+ 'gt_mask' (torch.Tensor): [M'],
+ 'm_bids' (torch.Tensor): [M],
+ 'mkpts0_c' (torch.Tensor): [M, 2],
+ 'mkpts1_c' (torch.Tensor): [M, 2],
+ 'mconf' (torch.Tensor): [M]}
+ """
+ axes_lengths = {
+ 'h0c': data['hw0_c'][0],
+ 'w0c': data['hw0_c'][1],
+ 'h1c': data['hw1_c'][0],
+ 'w1c': data['hw1_c'][1]
+ }
+ _device = conf_matrix_0_to_1.device
+
+ # confidence thresholding
+ # {(nearest neighbour for 0 to 1) U (nearest neighbour for 1 to 0)}
+ mask = torch.logical_or((conf_matrix_0_to_1 > self.thr) * (conf_matrix_0_to_1 == conf_matrix_0_to_1.max(dim=2, keepdim=True)[0]),
+ (conf_matrix_1_to_0 > self.thr) * (conf_matrix_1_to_0 == conf_matrix_1_to_0.max(dim=1, keepdim=True)[0]))
+
+ mask = rearrange(mask, 'b (h0c w0c) (h1c w1c) -> b h0c w0c h1c w1c',
+ **axes_lengths)
+ if 'mask0' not in data:
+ mask_border(mask, self.border_rm, False)
+ else:
+ mask_border_with_padding(mask, self.border_rm, False,
+ data['mask0'], data['mask1'])
+ mask = rearrange(mask, 'b h0c w0c h1c w1c -> b (h0c w0c) (h1c w1c)',
+ **axes_lengths)
+
+ # find all valid coarse matches
+ b_ids, i_ids, j_ids = mask.nonzero(as_tuple=True)
+
+ mconf = torch.maximum(conf_matrix_0_to_1[b_ids, i_ids, j_ids], conf_matrix_1_to_0[b_ids, i_ids, j_ids])
+
+ # random sampling of training samples for fine-level XoFTR
+ # (optional) pad samples with gt coarse-level matches
+ if self.training:
+ # NOTE:
+ # the sampling is performed across all pairs in a batch without manual balancing
+ # the number of samples fed to the fine level increases w.r.t. batch_size
+ if 'mask0' not in data:
+ num_candidates_max = mask.size(0) * max(
+ mask.size(1), mask.size(2))
+ else:
+ num_candidates_max = compute_max_candidates(
+ data['mask0'], data['mask1'])
+ num_matches_train = int(num_candidates_max *
+ self.train_coarse_percent)
+ num_matches_pred = len(b_ids)
+ assert self.train_pad_num_gt_min < num_matches_train, "min-num-gt-pad should be less than num-train-matches"
+
+ # pred_indices is to select from prediction
+ if num_matches_pred <= num_matches_train - self.train_pad_num_gt_min:
+ pred_indices = torch.arange(num_matches_pred, device=_device)
+ else:
+ pred_indices = torch.randint(
+ num_matches_pred,
+ (num_matches_train - self.train_pad_num_gt_min, ),
+ device=_device)
+
+ # gt_pad_indices is to select from gt padding. e.g. max(3787-4800, 200)
+ gt_pad_indices = torch.randint(
+ len(data['spv_b_ids']),
+ (max(num_matches_train - num_matches_pred,
+ self.train_pad_num_gt_min), ),
+ device=_device)
+ mconf_gt = torch.zeros(len(data['spv_b_ids']), device=_device) # set conf of gt paddings to all zero
+
+ b_ids, i_ids, j_ids, mconf = map(
+ lambda x, y: torch.cat([x[pred_indices], y[gt_pad_indices]],
+ dim=0),
+ *zip([b_ids, data['spv_b_ids']], [i_ids, data['spv_i_ids']],
+ [j_ids, data['spv_j_ids']], [mconf, mconf_gt]))
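+ # the same concatenation (thresholded predictions + randomly sampled GT pads) is applied to b/i/j ids and confidences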
+
+ # these matches are selected patches that feed into fine-level network
+ coarse_matches = {'b_ids': b_ids, 'i_ids': i_ids, 'j_ids': j_ids}
+
+ # update with matches in original image resolution
+ scale = data['hw0_i'][0] / data['hw0_c'][0]
+ scale0 = scale * data['scale0'][b_ids] if 'scale0' in data else scale
+ scale1 = scale * data['scale1'][b_ids] if 'scale1' in data else scale
+ mkpts0_c = torch.stack(
+ [i_ids % data['hw0_c'][1], torch.div(i_ids, data['hw0_c'][1], rounding_mode='trunc')],
+ dim=1) * scale0
+ mkpts1_c = torch.stack(
+ [j_ids % data['hw1_c'][1], torch.div(j_ids, data['hw1_c'][1], rounding_mode='trunc')],
+ dim=1) * scale1
+
+ # these matches are the current prediction (for visualization)
+ coarse_matches.update({
+ 'gt_mask': mconf == 0,
+ 'm_bids': b_ids[mconf != 0], # mconf == 0 => gt matches
+ 'mkpts0_c': mkpts0_c[mconf != 0],
+ 'mkpts1_c': mkpts1_c[mconf != 0],
+ 'mconf': mconf[mconf != 0]
+ })
+
+ return coarse_matches
+
+ @torch.no_grad()
+ def get_coarse_match_inference(self, sim_matrix, data):
+ """
+ Args:
+ sim_matrix (torch.Tensor): [N, L, S]
+ data (dict): with keys ['hw0_i', 'hw1_i', 'hw0_c', 'hw1_c']
+ Returns:
+ coarse_matches (dict): {
+ 'b_ids' (torch.Tensor): [M'],
+ 'i_ids' (torch.Tensor): [M'],
+ 'j_ids' (torch.Tensor): [M'],
+ 'gt_mask' (torch.Tensor): [M'],
+ 'm_bids' (torch.Tensor): [M],
+ 'mkpts0_c' (torch.Tensor): [M, 2],
+ 'mkpts1_c' (torch.Tensor): [M, 2],
+ 'mconf' (torch.Tensor): [M]}
+ """
+ axes_lengths = {
+ 'h0c': data['hw0_c'][0],
+ 'w0c': data['hw0_c'][1],
+ 'h1c': data['hw1_c'][0],
+ 'w1c': data['hw1_c'][1]
+ }
+
+ # softmax for 0 to 1
+ conf_matrix_ = F.softmax(sim_matrix, 2)
+
+ # confidence thresholding and nearest neighbour for 0 to 1
+ mask = (conf_matrix_ > self.thr) * (conf_matrix_ == conf_matrix_.max(dim=2, keepdim=True)[0])
+
+ # unlike training, reuse the same conf matrix to decrease VRAM consumption
+ # softmax for 1 to 0
+ conf_matrix_ = F.softmax(sim_matrix, 1)
+
+ # update mask {(nearest neighbour for 0 to 1) U (nearest neighbour for 1 to 0)}
+ mask = torch.logical_or(mask,
+ (conf_matrix_ > self.thr) * (conf_matrix_ == conf_matrix_.max(dim=1, keepdim=True)[0]))
+
+ mask = rearrange(mask, 'b (h0c w0c) (h1c w1c) -> b h0c w0c h1c w1c',
+ **axes_lengths)
+ if 'mask0' not in data:
+ mask_border(mask, self.border_rm, False)
+ else:
+ mask_border_with_padding(mask, self.border_rm, False,
+ data['mask0'], data['mask1'])
+ mask = rearrange(mask, 'b h0c w0c h1c w1c -> b (h0c w0c) (h1c w1c)',
+ **axes_lengths)
+
+ # find all valid coarse matches
+ b_ids, i_ids, j_ids = mask.nonzero(as_tuple=True)
+
+ # mconf = torch.maximum(conf_matrix_0_to_1[b_ids, i_ids, j_ids], conf_matrix_1_to_0[b_ids, i_ids, j_ids])
+
+ # these matches are selected patches that feed into fine-level network
+ coarse_matches = {'b_ids': b_ids, 'i_ids': i_ids, 'j_ids': j_ids}
+
+ # update with matches in original image resolution
+ scale = data['hw0_i'][0] / data['hw0_c'][0]
+ scale0 = scale * data['scale0'][b_ids] if 'scale0' in data else scale
+ scale1 = scale * data['scale1'][b_ids] if 'scale1' in data else scale
+ mkpts0_c = torch.stack(
+ [i_ids % data['hw0_c'][1], torch.div(i_ids, data['hw0_c'][1], rounding_mode='trunc')],
+ dim=1) * scale0
+ mkpts1_c = torch.stack(
+ [j_ids % data['hw1_c'][1], torch.div(j_ids, data['hw1_c'][1], rounding_mode='trunc')],
+ dim=1) * scale1
+
+ # these matches are the current coarse level predictions
+ coarse_matches.update({
+ 'm_bids': b_ids,
+ 'mkpts0_c': mkpts0_c,
+ 'mkpts1_c': mkpts1_c,
+ })
+
+ return coarse_matches
diff --git a/imcui/third_party/XoFTR/src/xoftr/xoftr_module/fine_matching.py b/imcui/third_party/XoFTR/src/xoftr/xoftr_module/fine_matching.py
new file mode 100644
index 0000000000000000000000000000000000000000..5430f71b0bedd60a613226e781ed8ce2baf7f9f2
--- /dev/null
+++ b/imcui/third_party/XoFTR/src/xoftr/xoftr_module/fine_matching.py
@@ -0,0 +1,170 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+class FineSubMatching(nn.Module):
+ """Fine-level and Sub-pixel matching"""
+
+ def __init__(self, config):
+ super().__init__()
+ self.temperature = config['fine']['dsmax_temperature']
+ self.W_f = config['fine_window_size']
+ self.denser = config['fine']['denser']
+ self.inference = config['fine']['inference']
+ dim_f = config['resnet']['block_dims'][0]
+ self.fine_thr = config['fine']['thr']
+ self.fine_proj = nn.Linear(dim_f, dim_f, bias=False)
+ self.subpixel_mlp = nn.Sequential(nn.Linear(2*dim_f, 2*dim_f, bias=False),
+ nn.ReLU(),
+ nn.Linear(2*dim_f, 4, bias=False))
+
+ def forward(self, feat_f0_unfold, feat_f1_unfold, data):
+ """
+ Args:
+ feat_f0_unfold (torch.Tensor): [M, WW, C]
+ feat_f1_unfold (torch.Tensor): [M, WW, C]
+ data (dict)
+ Update:
+ data (dict):{
+ 'expec_f' (torch.Tensor): [M, 3],
+ 'mkpts0_f' (torch.Tensor): [M, 2],
+ 'mkpts1_f' (torch.Tensor): [M, 2]}
+ """
+
+ feat_f0 = self.fine_proj(feat_f0_unfold)
+ feat_f1 = self.fine_proj(feat_f1_unfold)
+
+ M, WW, C = feat_f0.shape
+ W_f = self.W_f
+
+ # corner case: if no coarse matches found
+ if M == 0:
+ assert self.training == False, "M is always > 0 during training, see coarse_matching.py"
+ # logger.warning('No matches found in coarse-level.')
+ data.update({
+ 'mkpts0_f': data['mkpts0_c'],
+ 'mkpts1_f': data['mkpts1_c'],
+ 'mconf_f': torch.zeros(0, device=feat_f0_unfold.device),
+ # 'mkpts0_f_train': data['mkpts0_c'],
+ # 'mkpts1_f_train': data['mkpts1_c'],
+ # 'conf_matrix_fine': torch.zeros(1, W_f*W_f, W_f*W_f, device=feat_f0.device)
+ })
+ return
+
+ # normalize
+ feat_f0, feat_f1 = map(lambda feat: feat / feat.shape[-1]**.5,
+ [feat_f0, feat_f1])
+ sim_matrix = torch.einsum("nlc,nsc->nls", feat_f0,
+ feat_f1) / self.temperature
+
+ conf_matrix_fine = F.softmax(sim_matrix, 1) * F.softmax(sim_matrix, 2)
+ data.update({'conf_matrix_fine': conf_matrix_fine})
+
+ # predict fine-level and sub-pixel matches from conf_matrix
+ data.update(**self.get_fine_sub_match(conf_matrix_fine, feat_f0_unfold, feat_f1_unfold, data))
+
+ def get_fine_sub_match(self, conf_matrix_fine, feat_f0_unfold, feat_f1_unfold, data):
+ """
+ Args:
+ conf_matrix_fine (torch.Tensor): [M, WW, WW]
+ feat_f0_unfold (torch.Tensor): [M, WW, C]
+ feat_f1_unfold (torch.Tensor): [M, WW, C]
+ data (dict)
+ Update:
+ data (dict):{
+ 'm_bids' (torch.Tensor): [M]
+ 'expec_f' (torch.Tensor): [M, 3],
+ 'mkpts0_f' (torch.Tensor): [M, 2],
+ 'mkpts1_f' (torch.Tensor): [M, 2]}
+ """
+
+ with torch.no_grad():
+ W_f = self.W_f
+
+ # 1. confidence thresholding
+ mask = conf_matrix_fine > self.fine_thr
+
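+ # corner case: no fine match above the threshold; keep a single dummy match at (0, 0) so downstream indexing stays valid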
+ if mask.sum() == 0:
+ mask[0,0,0] = 1
+ conf_matrix_fine[0,0,0] = 1
+
+ if not self.denser:
+ # match only the highest confidence
+ mask = mask \
+ * (conf_matrix_fine == conf_matrix_fine.amax(dim=[1,2], keepdim=True))
+ else:
+ # 2. mutual nearest, match all features in fine window
+ mask = mask \
+ * (conf_matrix_fine == conf_matrix_fine.max(dim=2, keepdim=True)[0]) \
+ * (conf_matrix_fine == conf_matrix_fine.max(dim=1, keepdim=True)[0])
+
+ # 3. find all valid fine matches
+ # this only works when at most one `True` in each row
+ mask_v, all_j_ids = mask.max(dim=2)
+ b_ids, i_ids = torch.where(mask_v)
+ j_ids = all_j_ids[b_ids, i_ids]
+ mconf = conf_matrix_fine[b_ids, i_ids, j_ids]
+
+ # 4. update with matches in original image resolution
+
+ # indices from coarse matches
+ b_ids_c, i_ids_c, j_ids_c = data['b_ids'], data['i_ids'], data['j_ids']
+
+ # scale (coarse level / fine-level)
+ scale_f_c = data['hw0_f'][0] // data['hw0_c'][0]
+
+ # coarse level matches scaled to fine-level (1/2)
+ mkpts0_c_scaled_to_f = torch.stack(
+ [i_ids_c % data['hw0_c'][1], torch.div(i_ids_c, data['hw0_c'][1], rounding_mode='trunc')],
+ dim=1) * scale_f_c
+
+ mkpts1_c_scaled_to_f = torch.stack(
+ [j_ids_c % data['hw1_c'][1], torch.div(j_ids_c, data['hw1_c'][1], rounding_mode='trunc')],
+ dim=1) * scale_f_c
+
+ # updated b_ids after second thresholding
+ updated_b_ids = b_ids_c[b_ids]
+
+ # scales (image res / fine level)
+ scale = data['hw0_i'][0] / data['hw0_f'][0]
+ scale0 = scale * data['scale0'][updated_b_ids] if 'scale0' in data else scale
+ scale1 = scale * data['scale1'][updated_b_ids] if 'scale1' in data else scale
+
+ # fine-level discrete matches in window coordinates
+ mkpts0_f_window = torch.stack(
+ [i_ids % W_f, torch.div(i_ids, W_f, rounding_mode='trunc')],
+ dim=1)
+
+ mkpts1_f_window = torch.stack(
+ [j_ids % W_f, torch.div(j_ids, W_f, rounding_mode='trunc')],
+ dim=1)
+
+ # sub-pixel refinement
+ sub_ref = self.subpixel_mlp(torch.cat([feat_f0_unfold[b_ids, i_ids],
+ feat_f1_unfold[b_ids, j_ids]], dim=-1))
+ sub_ref0, sub_ref1 = torch.chunk(sub_ref, 2, dim=-1)
+ sub_ref0 = torch.tanh(sub_ref0) * 0.5
+ sub_ref1 = torch.tanh(sub_ref1) * 0.5
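+ # tanh * 0.5 bounds each sub-pixel offset to (-0.5, 0.5) fine-level pixels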
+
+ # final sub-pixel matches by (coarse-level + fine-level windowed + sub-pixel refinement)
+ mkpts0_f_train = (mkpts0_f_window + mkpts0_c_scaled_to_f[b_ids] - (W_f//2) + sub_ref0) * scale0
+ mkpts1_f_train = (mkpts1_f_window + mkpts1_c_scaled_to_f[b_ids] - (W_f//2) + sub_ref1) * scale1
+ mkpts0_f = mkpts0_f_train.clone().detach()
+ mkpts1_f = mkpts1_f_train.clone().detach()
+
+ # These matches are the current prediction (for visualization)
+ sub_pixel_matches = {
+ 'm_bids': b_ids_c[b_ids[mconf != 0]], # mconf == 0 => gt matches
+ 'mkpts0_f': mkpts0_f[mconf != 0],
+ 'mkpts1_f': mkpts1_f[mconf != 0],
+ 'mconf_f': mconf[mconf != 0]
+ }
+
+ # These matches are used for training
+ if not self.inference:
+ sub_pixel_matches.update({
+ 'mkpts0_f_train': mkpts0_f_train[mconf != 0],
+ 'mkpts1_f_train': mkpts1_f_train[mconf != 0],
+ })
+
+ return sub_pixel_matches
diff --git a/imcui/third_party/XoFTR/src/xoftr/xoftr_module/fine_process.py b/imcui/third_party/XoFTR/src/xoftr/xoftr_module/fine_process.py
new file mode 100644
index 0000000000000000000000000000000000000000..001ab663b1fdd767c407e8089137a142012caa17
--- /dev/null
+++ b/imcui/third_party/XoFTR/src/xoftr/xoftr_module/fine_process.py
@@ -0,0 +1,321 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops.einops import rearrange
+
+class Mlp(nn.Module):
+ """Multi-Layer Perceptron (MLP)"""
+
+ def __init__(self,
+ in_dim,
+ hidden_dim=None,
+ out_dim=None,
+ act_layer=nn.GELU):
+ """
+ Args:
+ in_dim: input features dimension
+ hidden_dim: hidden features dimension
+ out_dim: output features dimension
+ act_layer: activation function
+ """
+ super().__init__()
+ out_dim = out_dim or in_dim
+ hidden_dim = hidden_dim or in_dim
+ self.fc1 = nn.Linear(in_dim, hidden_dim)
+ self.act = act_layer()
+ self.fc2 = nn.Linear(hidden_dim, out_dim)
+ self.out_dim = out_dim
+
+ def forward(self, x):
+ x_size = x.size()
+ x = x.view(-1, x_size[-1])
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.fc2(x)
+ x = x.view(*x_size[:-1], self.out_dim)
+ return x
+
+
+class VanillaAttention(nn.Module):
+ def __init__(self,
+ dim,
+ num_heads=8,
+ proj_bias=False):
+ super().__init__()
+ """
+ Args:
+ dim: feature dimension
+ num_heads: number of attention head
+ proj_bias (bool): use bias in the query, key and value projections
+ """
+ self.num_heads = num_heads
+ self.head_dim = dim // num_heads
+ self.softmax_temp = self.head_dim ** -0.5
+ self.kv_proj = nn.Linear(dim, dim * 2, bias=proj_bias)
+ self.q_proj = nn.Linear(dim, dim, bias=proj_bias)
+ self.merge = nn.Linear(dim, dim)
+
+ def forward(self, x_q, x_kv=None):
+ """
+ Args:
+ x_q (torch.Tensor): [N, L, C]
+ x_kv (torch.Tensor): [N, S, C]
+ """
+ if x_kv is None:
+ x_kv = x_q
+ bs, _, dim = x_q.shape
+ bs, _, dim = x_kv.shape
+ # [N, S, 2, H, D] => [2, N, H, S, D]
+ kv = self.kv_proj(x_kv).reshape(bs, -1, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
+ # [N, L, H, D] => [N, H, L, D]
+ q = self.q_proj(x_q).reshape(bs, -1, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
+ k, v = kv[0].transpose(-2, -1).contiguous(), kv[1].contiguous() # [N, H, D, S], [N, H, S, D]
+ attn = (q @ k) * self.softmax_temp # [N, H, L, S]
+ attn = attn.softmax(dim=-1)
+ x_q = (attn @ v).transpose(1, 2).reshape(bs, -1, dim)
+ x_q = self.merge(x_q)
+ return x_q
+
+
+class CrossBidirectionalAttention(nn.Module):
+ def __init__(self, dim, num_heads, proj_bias = False):
+ super().__init__()
+ """
+ Args:
+ dim: feature dimension
+ num_heads: number of attention head
+ proj_bias (bool): use bias in the query, key and value projections
+ """
+
+ self.num_heads = num_heads
+ self.head_dim = dim // num_heads
+ self.softmax_temp = self.head_dim ** -0.5
+ self.qk_proj = nn.Linear(dim, dim, bias=proj_bias)
+ self.v_proj = nn.Linear(dim, dim, bias=proj_bias)
+ self.merge = nn.Linear(dim, dim, bias=proj_bias)
+ self.temperature = nn.Parameter(torch.tensor([0.0]), requires_grad=True)
+ # print(self.temperature)
+
+ def map_(self, func, x0, x1):
+ return func(x0), func(x1)
+
+ def forward(self, x0, x1):
+ """
+ Args:
+ x0 (torch.Tensor): [N, L, C]
+ x1 (torch.Tensor): [N, S, C]
+ """
+ bs = x0.size(0)
+
+ qk0, qk1 = self.map_(self.qk_proj, x0, x1)
+ v0, v1 = self.map_(self.v_proj, x0, x1)
+ qk0, qk1, v0, v1 = map(
+ lambda t: t.reshape(bs, -1, self.num_heads, self.head_dim).permute(0, 2, 1, 3).contiguous(),
+ (qk0, qk1, v0, v1))
+
+ qk0, qk1 = qk0 * self.softmax_temp**0.5, qk1 * self.softmax_temp**0.5
+ sim = qk0 @ qk1.transpose(-2,-1).contiguous()
+ attn01 = F.softmax(sim, dim=-1)
+ attn10 = F.softmax(sim.transpose(-2, -1).contiguous(), dim=-1)
+ x0 = attn01 @ v1
+ x1 = attn10 @ v0
+ x0, x1 = self.map_(lambda t: t.transpose(1, 2).flatten(start_dim=-2),
+ x0, x1)
+ x0, x1 = self.map_(self.merge, x0, x1)
+
+ return x0, x1
+
+
+class SwinPosEmbMLP(nn.Module):
+ def __init__(self,
+ dim):
+ super().__init__()
+ self.pos_embed = None
+ self.pos_mlp = nn.Sequential(nn.Linear(2, 512, bias=True),
+ nn.ReLU(),
+ nn.Linear(512, dim, bias=False))
+
+ def forward(self, x):
+ seq_length = x.shape[1]
+ if self.pos_embed is None or self.training:
+ seq_length = int(seq_length**0.5)
+ coords = torch.arange(0, seq_length, device=x.device, dtype = x.dtype)
+ grid = torch.stack(torch.meshgrid([coords, coords])).contiguous().unsqueeze(0)
+ grid -= seq_length // 2
+ grid /= (seq_length // 2)
+ self.pos_embed = self.pos_mlp(grid.flatten(2).transpose(1,2))
+ x = x + self.pos_embed
+ return x
+
+
+class WindowSelfAttention(nn.Module):
+ def __init__(self, dim, num_heads, mlp_hidden_coef, use_pre_pos_embed=False):
+ super().__init__()
+ self.mlp = Mlp(in_dim=dim*2, hidden_dim=dim*mlp_hidden_coef, out_dim=dim, act_layer=nn.GELU)
+ self.gamma = nn.Parameter(torch.ones(dim))
+ self.norm1 = nn.LayerNorm(dim)
+ self.norm2 = nn.LayerNorm(dim)
+ self.attn = VanillaAttention(dim, num_heads=num_heads)
+ self.pos_embed = SwinPosEmbMLP(dim)
+ self.pos_embed_pre = SwinPosEmbMLP(dim) if use_pre_pos_embed else nn.Identity()
+
+ def forward(self, x, x_pre):
+ ww = x.shape[1]
+ ww_pre = x_pre.shape[1]
+ x = self.pos_embed(x)
+ x_pre = self.pos_embed_pre(x_pre)
+ x = torch.cat((x, x_pre), dim=1)
+ x = x + self.gamma*self.norm1(self.mlp(torch.cat([x, self.attn(self.norm2(x))], dim=-1)))
+ x, x_pre = x.split([ww, ww_pre], dim=1)
+ return x, x_pre
+
+
+class WindowCrossAttention(nn.Module):
+ def __init__(self, dim, num_heads, mlp_hidden_coef):
+ super().__init__()
+ self.norm1 = nn.LayerNorm(dim)
+ self.norm2 = nn.LayerNorm(dim)
+ self.mlp = Mlp(in_dim=dim*2, hidden_dim=dim*mlp_hidden_coef, out_dim=dim, act_layer=nn.GELU)
+ self.cross_attn = CrossBidirectionalAttention(dim, num_heads=num_heads, proj_bias=False)
+ self.gamma = nn.Parameter(torch.ones(dim))
+
+ def forward(self, x0, x1):
+ m_x0, m_x1 = self.cross_attn(self.norm1(x0), self.norm1(x1))
+ x0 = x0 + self.gamma*self.norm2(self.mlp(torch.cat([x0, m_x0], dim=-1)))
+ x1 = x1 + self.gamma*self.norm2(self.mlp(torch.cat([x1, m_x1], dim=-1)))
+ return x0, x1
+
+
+class FineProcess(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ # Config
+ block_dims = config['resnet']['block_dims']
+ self.block_dims = block_dims
+ self.W_f = config['fine_window_size']
+ self.W_m = config['medium_window_size']
+ nhead_f = config["fine"]['nhead_fine_level']
+ nhead_m = config["fine"]['nhead_medium_level']
+ mlp_hidden_coef = config["fine"]['mlp_hidden_dim_coef']
+
+ # Networks
+ self.conv_merge = nn.Sequential(nn.Conv2d(block_dims[2]*2, block_dims[1], kernel_size=1, stride=1, padding=0, bias=False),
+ nn.Conv2d(block_dims[1], block_dims[1], kernel_size=3, stride=1, padding=1, groups=block_dims[1], bias=False),
+ nn.BatchNorm2d(block_dims[1])
+ )
+ self.out_conv_m = nn.Conv2d(block_dims[1], block_dims[1], kernel_size=1, stride=1, padding=0, bias=False)
+ self.out_conv_f = nn.Conv2d(block_dims[0], block_dims[0], kernel_size=1, stride=1, padding=0, bias=False)
+ self.self_attn_m = WindowSelfAttention(block_dims[1], num_heads=nhead_m,
+ mlp_hidden_coef=mlp_hidden_coef, use_pre_pos_embed=False)
+ self.cross_attn_m = WindowCrossAttention(block_dims[1], num_heads=nhead_m,
+ mlp_hidden_coef=mlp_hidden_coef)
+ self.self_attn_f = WindowSelfAttention(block_dims[0], num_heads=nhead_f,
+ mlp_hidden_coef=mlp_hidden_coef, use_pre_pos_embed=True)
+ self.cross_attn_f = WindowCrossAttention(block_dims[0], num_heads=nhead_f,
+ mlp_hidden_coef=mlp_hidden_coef)
+ self.down_proj_m_f = nn.Linear(block_dims[1], block_dims[0], bias=False)
+
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+ elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
+ nn.init.constant_(m.weight, 1)
+ nn.init.constant_(m.bias, 0)
+
+ def pre_process(self, feat_f0, feat_f1, feat_m0, feat_m1, feat_c0, feat_c1, feat_c0_pre, feat_c1_pre, data):
+ W_f = self.W_f
+ W_m = self.W_m
+ data.update({'W_f': W_f,
+ 'W_m': W_m})
+
+ # merge coarse features before and after the loftr layer, and down-project the channel dimension
+ feat_c0 = rearrange(feat_c0, 'n (h w) c -> n c h w', h =data["hw0_c"][0], w =data["hw0_c"][1])
+ feat_c1 = rearrange(feat_c1, 'n (h w) c -> n c h w', h =data["hw1_c"][0], w =data["hw1_c"][1])
+ feat_c0 = self.conv_merge(torch.cat([feat_c0, feat_c0_pre], dim=1))
+ feat_c1 = self.conv_merge(torch.cat([feat_c1, feat_c1_pre], dim=1))
+ feat_c0 = rearrange(feat_c0, 'n c h w -> n (h w) 1 c')
+ feat_c1 = rearrange(feat_c1, 'n c h w -> n (h w) 1 c')
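+ # each coarse token becomes a length-1 "window" so it can be fed alongside the unfolded medium windows in self attention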
+
+ stride_f = data['hw0_f'][0] // data['hw0_c'][0]
+ stride_m = data['hw0_m'][0] // data['hw0_c'][0]
+
+ if feat_m0.shape[2] == feat_m1.shape[2] and feat_m0.shape[3] == feat_m1.shape[3]:
+ feat_m = self.out_conv_m(torch.cat([feat_m0, feat_m1], dim=0))
+ feat_m0, feat_m1 = torch.chunk(feat_m, 2, dim=0)
+ feat_f = self.out_conv_f(torch.cat([feat_f0, feat_f1], dim=0))
+ feat_f0, feat_f1 = torch.chunk(feat_f, 2, dim=0)
+ else:
+ feat_m0 = self.out_conv_m(feat_m0)
+ feat_m1 = self.out_conv_m(feat_m1)
+ feat_f0 = self.out_conv_f(feat_f0)
+ feat_f1 = self.out_conv_f(feat_f1)
+
+ # 1. unfold (crop windows) all local windows
+ feat_m0_unfold = F.unfold(feat_m0, kernel_size=(W_m, W_m), stride=stride_m, padding=W_m//2)
+ feat_m0_unfold = rearrange(feat_m0_unfold, 'n (c ww) l -> n l ww c', ww=W_m**2)
+ feat_m1_unfold = F.unfold(feat_m1, kernel_size=(W_m, W_m), stride=stride_m, padding=W_m//2)
+ feat_m1_unfold = rearrange(feat_m1_unfold, 'n (c ww) l -> n l ww c', ww=W_m**2)
+
+ feat_f0_unfold = F.unfold(feat_f0, kernel_size=(W_f, W_f), stride=stride_f, padding=W_f//2)
+ feat_f0_unfold = rearrange(feat_f0_unfold, 'n (c ww) l -> n l ww c', ww=W_f**2)
+ feat_f1_unfold = F.unfold(feat_f1, kernel_size=(W_f, W_f), stride=stride_f, padding=W_f//2)
+ feat_f1_unfold = rearrange(feat_f1_unfold, 'n (c ww) l -> n l ww c', ww=W_f**2)
+
+ # 2. select only the predicted matches
+ feat_c0 = feat_c0[data['b_ids'], data['i_ids']] # [n, ww, cm]
+ feat_c1 = feat_c1[data['b_ids'], data['j_ids']]
+
+ feat_m0_unfold = feat_m0_unfold[data['b_ids'], data['i_ids']] # [n, ww, cm]
+ feat_m1_unfold = feat_m1_unfold[data['b_ids'], data['j_ids']]
+
+ feat_f0_unfold = feat_f0_unfold[data['b_ids'], data['i_ids']] # [n, ww, cf]
+ feat_f1_unfold = feat_f1_unfold[data['b_ids'], data['j_ids']]
+
+ return feat_c0, feat_c1, feat_m0_unfold, feat_m1_unfold, feat_f0_unfold, feat_f1_unfold
+
+ def forward(self, feat_f0, feat_f1, feat_m0, feat_m1, feat_c0, feat_c1, feat_c0_pre, feat_c1_pre, data):
+ """
+ Args:
+ feat_f0 (torch.Tensor): [N, C, H, W]
+ feat_f1 (torch.Tensor): [N, C, H, W]
+ feat_m0 (torch.Tensor): [N, C, H, W]
+ feat_m1 (torch.Tensor): [N, C, H, W]
+ feat_c0 (torch.Tensor): [N, L, C]
+ feat_c1 (torch.Tensor): [N, S, C]
+ feat_c0_pre (torch.Tensor): [N, C, H, W]
+ feat_c1_pre (torch.Tensor): [N, C, H, W]
+ data (dict): with keys ['hw0_c', 'hw1_c', 'hw0_m', 'hw1_m', 'hw0_f', 'hw1_f', 'b_ids', 'j_ids']
+ """
+
+ # TODO: Check for this case
+ if data['b_ids'].shape[0] == 0:
+ feat0 = torch.empty(0, self.W_f**2, self.block_dims[0], device=feat_f0.device)
+ feat1 = torch.empty(0, self.W_f**2, self.block_dims[0], device=feat_f0.device)
+ return feat0, feat1
+
+ feat_c0, feat_c1, feat_m0_unfold, feat_m1_unfold, \
+ feat_f0_unfold, feat_f1_unfold = self.pre_process(feat_f0, feat_f1, feat_m0, feat_m1,
+ feat_c0, feat_c1, feat_c0_pre, feat_c1_pre, data)
+
+ # self attention (c + m)
+ feat_m_unfold, _ = self.self_attn_m(torch.cat([feat_m0_unfold, feat_m1_unfold], dim=0),
+ torch.cat([feat_c0, feat_c1], dim=0))
+ feat_m0_unfold, feat_m1_unfold = torch.chunk(feat_m_unfold, 2, dim=0)
+
+ # cross attention (m0 <-> m1)
+ feat_m0_unfold, feat_m1_unfold = self.cross_attn_m(feat_m0_unfold, feat_m1_unfold)
+
+ # down proj m
+ feat_m_unfold = self.down_proj_m_f(torch.cat([feat_m0_unfold, feat_m1_unfold], dim=0))
+ feat_m0_unfold, feat_m1_unfold = torch.chunk(feat_m_unfold, 2, dim=0)
+
+ # self attention (m + f)
+ feat_f_unfold, _ = self.self_attn_f(torch.cat([feat_f0_unfold, feat_f1_unfold], dim=0),
+ torch.cat([feat_m0_unfold, feat_m1_unfold], dim=0))
+ feat_f0_unfold, feat_f1_unfold = torch.chunk(feat_f_unfold, 2, dim=0)
+
+ # cross attention (f0 <-> f1)
+ feat_f0_unfold, feat_f1_unfold = self.cross_attn_f(feat_f0_unfold, feat_f1_unfold)
+
+ return feat_f0_unfold, feat_f1_unfold
+
diff --git a/imcui/third_party/XoFTR/src/xoftr/xoftr_module/linear_attention.py b/imcui/third_party/XoFTR/src/xoftr/xoftr_module/linear_attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..61b1b8573e6454b6d340c20381ad5f945d479791
--- /dev/null
+++ b/imcui/third_party/XoFTR/src/xoftr/xoftr_module/linear_attention.py
@@ -0,0 +1,81 @@
+"""
+Linear Transformer proposed in "Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention"
+Modified from: https://github.com/idiap/fast-transformers/blob/master/fast_transformers/attention/linear_attention.py
+"""
+
+import torch
+from torch.nn import Module, Dropout
+
+
+def elu_feature_map(x):
+ return torch.nn.functional.elu(x) + 1
+
+
+class LinearAttention(Module):
+ def __init__(self, eps=1e-6):
+ super().__init__()
+ self.feature_map = elu_feature_map
+ self.eps = eps
+
+ def forward(self, queries, keys, values, q_mask=None, kv_mask=None):
+ """ Multi-Head linear attention proposed in "Transformers are RNNs"
+ Args:
+ queries: [N, L, H, D]
+ keys: [N, S, H, D]
+ values: [N, S, H, D]
+ q_mask: [N, L]
+ kv_mask: [N, S]
+ Returns:
+ queried_values: (N, L, H, D)
+ """
+ Q = self.feature_map(queries)
+ K = self.feature_map(keys)
+
+ # set padded position to zero
+ if q_mask is not None:
+ Q = Q * q_mask[:, :, None, None]
+ if kv_mask is not None:
+ K = K * kv_mask[:, :, None, None]
+ values = values * kv_mask[:, :, None, None]
+
+ v_length = values.size(1)
+ values = values / v_length # prevent fp16 overflow
+ KV = torch.einsum("nshd,nshv->nhdv", K, values) # (S,D)' @ S,V
+ Z = 1 / (torch.einsum("nlhd,nhd->nlh", Q, K.sum(dim=1)) + self.eps)
+ queried_values = torch.einsum("nlhd,nhdv,nlh->nlhv", Q, KV, Z) * v_length
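+ # multiplying back by v_length undoes the earlier fp16-overflow rescaling of the values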
+
+ return queried_values.contiguous()
+
+
+class FullAttention(Module):
+ def __init__(self, use_dropout=False, attention_dropout=0.1):
+ super().__init__()
+ self.use_dropout = use_dropout
+ self.dropout = Dropout(attention_dropout)
+
+ def forward(self, queries, keys, values, q_mask=None, kv_mask=None):
+ """ Multi-head scaled dot-product attention, a.k.a full attention.
+ Args:
+ queries: [N, L, H, D]
+ keys: [N, S, H, D]
+ values: [N, S, H, D]
+ q_mask: [N, L]
+ kv_mask: [N, S]
+ Returns:
+ queried_values: (N, L, H, D)
+ """
+
+ # Compute the unnormalized attention and apply the masks
+ QK = torch.einsum("nlhd,nshd->nlsh", queries, keys)
+ if kv_mask is not None:
+ QK.masked_fill_(~(q_mask[:, :, None, None] * kv_mask[:, None, :, None]), float('-inf'))
+
+ # Compute the attention and the weighted average
+ softmax_temp = 1. / queries.size(3)**.5 # 1 / sqrt(D)
+ A = torch.softmax(softmax_temp * QK, dim=2)
+ if self.use_dropout:
+ A = self.dropout(A)
+
+ queried_values = torch.einsum("nlsh,nshd->nlhd", A, values)
+
+ return queried_values.contiguous()
\ No newline at end of file
diff --git a/imcui/third_party/XoFTR/src/xoftr/xoftr_module/transformer.py b/imcui/third_party/XoFTR/src/xoftr/xoftr_module/transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..3eef7521a2c33e26fc66ca6797e842a7a146c21f
--- /dev/null
+++ b/imcui/third_party/XoFTR/src/xoftr/xoftr_module/transformer.py
@@ -0,0 +1,101 @@
+import copy
+import torch
+import torch.nn as nn
+from .linear_attention import LinearAttention, FullAttention
+
+
+class LoFTREncoderLayer(nn.Module):
+ def __init__(self,
+ d_model,
+ nhead,
+ attention='linear'):
+ super(LoFTREncoderLayer, self).__init__()
+
+ self.dim = d_model // nhead
+ self.nhead = nhead
+
+ # multi-head attention
+ self.q_proj = nn.Linear(d_model, d_model, bias=False)
+ self.k_proj = nn.Linear(d_model, d_model, bias=False)
+ self.v_proj = nn.Linear(d_model, d_model, bias=False)
+ self.attention = LinearAttention() if attention == 'linear' else FullAttention()
+ self.merge = nn.Linear(d_model, d_model, bias=False)
+
+ # feed-forward network
+ self.mlp = nn.Sequential(
+ nn.Linear(d_model*2, d_model*2, bias=False),
+ nn.ReLU(True),
+ nn.Linear(d_model*2, d_model, bias=False),
+ )
+
+ # norm and dropout
+ self.norm1 = nn.LayerNorm(d_model)
+ self.norm2 = nn.LayerNorm(d_model)
+
+ def forward(self, x, source, x_mask=None, source_mask=None):
+ """
+ Args:
+ x (torch.Tensor): [N, L, C]
+ source (torch.Tensor): [N, S, C]
+ x_mask (torch.Tensor): [N, L] (optional)
+ source_mask (torch.Tensor): [N, S] (optional)
+ """
+ bs = x.size(0)
+ query, key, value = x, source, source
+
+ # multi-head attention
+ query = self.q_proj(query).view(bs, -1, self.nhead, self.dim) # [N, L, (H, D)]
+ key = self.k_proj(key).view(bs, -1, self.nhead, self.dim) # [N, S, (H, D)]
+ value = self.v_proj(value).view(bs, -1, self.nhead, self.dim)
+ message = self.attention(query, key, value, q_mask=x_mask, kv_mask=source_mask) # [N, L, (H, D)]
+ message = self.merge(message.view(bs, -1, self.nhead*self.dim)) # [N, L, C]
+ message = self.norm1(message)
+
+ # feed-forward network
+ message = self.mlp(torch.cat([x, message], dim=2))
+ message = self.norm2(message)
+
+ return x + message
+
+
+class LocalFeatureTransformer(nn.Module):
+ """A Local Feature Transformer (LoFTR) module."""
+
+ def __init__(self, config):
+ super(LocalFeatureTransformer, self).__init__()
+
+ self.config = config
+ self.d_model = config['d_model']
+ self.nhead = config['nhead']
+ self.layer_names = config['layer_names']
+ encoder_layer = LoFTREncoderLayer(config['d_model'], config['nhead'], config['attention'])
+ self.layers = nn.ModuleList([copy.deepcopy(encoder_layer) for _ in range(len(self.layer_names))])
+ self._reset_parameters()
+
+ def _reset_parameters(self):
+ for p in self.parameters():
+ if p.dim() > 1:
+ nn.init.xavier_uniform_(p)
+
+ def forward(self, feat0, feat1, mask0=None, mask1=None):
+ """
+ Args:
+ feat0 (torch.Tensor): [N, L, C]
+ feat1 (torch.Tensor): [N, S, C]
+ mask0 (torch.Tensor): [N, L] (optional)
+ mask1 (torch.Tensor): [N, S] (optional)
+ """
+
+ assert self.d_model == feat0.size(2), "the feature dimension of the inputs must equal the transformer d_model"
+
+ for layer, name in zip(self.layers, self.layer_names):
+ if name == 'self':
+ feat0 = layer(feat0, feat0, mask0, mask0)
+ feat1 = layer(feat1, feat1, mask1, mask1)
+ elif name == 'cross':
+ feat0 = layer(feat0, feat1, mask0, mask1)
+ feat1 = layer(feat1, feat0, mask1, mask0)
+ else:
+ raise KeyError
+
+ return feat0, feat1
\ No newline at end of file
diff --git a/imcui/third_party/XoFTR/src/xoftr/xoftr_pretrain.py b/imcui/third_party/XoFTR/src/xoftr/xoftr_pretrain.py
new file mode 100644
index 0000000000000000000000000000000000000000..2a4934eeb2d856d72010840507dde5b2bd1d3a94
--- /dev/null
+++ b/imcui/third_party/XoFTR/src/xoftr/xoftr_pretrain.py
@@ -0,0 +1,209 @@
+import torch
+import torch.nn as nn
+from einops.einops import rearrange
+from .backbone import ResNet_8_2
+from .utils.position_encoding import PositionEncodingSine
+from .xoftr_module import LocalFeatureTransformer, FineProcess
+
+
+class XoFTR_Pretrain(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ # Misc
+ self.config = config
+ self.patch_size = config["pretrain_patch_size"]
+
+ # Modules
+ self.backbone = ResNet_8_2(config['resnet'])
+ self.pos_encoding = PositionEncodingSine(config['coarse']['d_model'])
+ self.loftr_coarse = LocalFeatureTransformer(config['coarse'])
+ self.fine_process = FineProcess(config)
+ self.mask_token_f = nn.Parameter(torch.zeros(1, config['resnet']["block_dims"][0], 1, 1))
+ self.mask_token_m = nn.Parameter(torch.zeros(1, config['resnet']["block_dims"][1], 1, 1))
+ self.mask_token_c = nn.Parameter(torch.zeros(1, config['resnet']["block_dims"][2], 1, 1))
+ self.out_proj = nn.Linear(config['resnet']["block_dims"][0], 4)
+
+ torch.nn.init.normal_(self.mask_token_f, std=.02)
+ torch.nn.init.normal_(self.mask_token_m, std=.02)
+ torch.nn.init.normal_(self.mask_token_c, std=.02)
+
+ def upsample_mae_mask(self, mae_mask, scale):
+ assert len(mae_mask.shape) == 2
+ p = int(mae_mask.shape[1] ** .5)
+ return mae_mask.reshape(-1, p, p).repeat_interleave(scale, axis=1).repeat_interleave(scale, axis=2)
+
+ def upsample_mask(self, mask, scale):
+ return mask.repeat_interleave(scale, axis=1).repeat_interleave(scale, axis=2)
+
+ def mask_layer(self, feat, mae_mask, mae_mask_scale, mask=None, mask_scale=None, mask_token=None):
+ """ Mask the feature map and replace with trainable inpu tokens if available
+ Args:
+ feat (torch.Tensor): [N, C, H, W]
+ mae_mask (torch.Tensor): (N, L) mask for masked image modeling
+ mae_mask_scale (int): the scale of layer to mae mask
+ mask (torch.Tensor): mask for padded input image
+ mask_scale (int): the scale of layer to mask (mask is created at the coarse scale)
+ mask_token (torch.Tensor): [1, C, 1, 1] learnable mae mask token
+ Returns:
+ feat (torch.Tensor): [N, C, H, W]
+ """
+ mae_mask = self.upsample_mae_mask(mae_mask, mae_mask_scale)
+ mae_mask = mae_mask.unsqueeze(1).type_as(feat)
+ if mask is not None:
+ mask = self.upsample_mask(mask, mask_scale)
+ mask = mask.unsqueeze(1).type_as(feat)
+ mae_mask = mask * mae_mask
+ feat = feat * (1. - mae_mask)
+ if mask_token is not None:
+ mask_token = mask_token.repeat(feat.shape[0], 1, feat.shape[2], feat.shape[3])
+ feat += mask_token * mae_mask
+ return feat
+
+
+ def forward(self, data):
+ """
+ Update:
+ data (dict): {
+ 'image0': (torch.Tensor): (N, 1, H, W)
+ 'image1': (torch.Tensor): (N, 1, H, W)
+ 'mask0'(optional) : (torch.Tensor): (N, H, W) '0' indicates a padded position
+ 'mask1'(optional) : (torch.Tensor): (N, H, W)
+ }
+ """
+ # 1. Local Feature CNN
+ data.update({
+ 'bs': data['image0'].size(0),
+ 'hw0_i': data['image0'].shape[2:], 'hw1_i': data['image1'].shape[2:]
+ })
+
+ image0 = data["image0_norm"] if "image0_norm" in data else data["image0"]
+ image1 = data["image1_norm"] if "image1_norm" in data else data["image1"]
+
+ mask0 = mask1 = None # mask for padded images
+ if 'mask0' in data:
+ mask0, mask1 = data['mask0'], data['mask1']
+
+ # mask input images
+ image0 = self.mask_layer(image0,
+ data["mae_mask0"],
+ mae_mask_scale=self.patch_size,
+ mask=mask0,
+ mask_scale=8)
+ image1 = self.mask_layer(image1,
+ data["mae_mask1"],
+ mae_mask_scale=self.patch_size,
+ mask=mask1,
+ mask_scale=8)
+ data.update({"masked_image0":image0.clone().detach().cpu(),
+ "masked_image1":image1.clone().detach().cpu()})
+
+ if data['hw0_i'] == data['hw1_i']: # faster & better BN convergence
+ feats_c, feats_m, feats_f = self.backbone(torch.cat([image0, image1], dim=0))
+ (feat_c0, feat_c1) = feats_c.split(data['bs'])
+ (feat_m0, feat_m1) = feats_m.split(data['bs'])
+ (feat_f0, feat_f1) = feats_f.split(data['bs'])
+ else: # handle different input shapes
+ feat_c0, feat_m0, feat_f0 = self.backbone(image0)
+ feat_c1, feat_m1, feat_f1 = self.backbone(image1)
+
+ # mask output layers of backbone and replace with trainable token
+ feat_c0 = self.mask_layer(feat_c0,
+ data["mae_mask0"],
+ mae_mask_scale=self.patch_size // 8,
+ mask=mask0,
+ mask_scale=1,
+ mask_token=self.mask_token_c)
+ feat_c1 = self.mask_layer(feat_c1,
+ data["mae_mask1"],
+ mae_mask_scale=self.patch_size // 8,
+ mask=mask1,
+ mask_scale=1,
+ mask_token=self.mask_token_c)
+ feat_m0 = self.mask_layer(feat_m0,
+ data["mae_mask0"],
+ mae_mask_scale=self.patch_size // 4,
+ mask=mask0,
+ mask_scale=2,
+ mask_token=self.mask_token_m)
+ feat_m1 = self.mask_layer(feat_m1,
+ data["mae_mask1"],
+ mae_mask_scale=self.patch_size // 4,
+ mask=mask1,
+ mask_scale=2,
+ mask_token=self.mask_token_m)
+ feat_f0 = self.mask_layer(feat_f0,
+ data["mae_mask0"],
+ mae_mask_scale=self.patch_size // 2,
+ mask=mask0,
+ mask_scale=4,
+ mask_token=self.mask_token_f)
+ feat_f1 = self.mask_layer(feat_f1,
+ data["mae_mask1"],
+ mae_mask_scale=self.patch_size // 2,
+ mask=mask1,
+ mask_scale=4,
+ mask_token=self.mask_token_f)
+ data.update({
+ 'hw0_c': feat_c0.shape[2:], 'hw1_c': feat_c1.shape[2:],
+ 'hw0_m': feat_m0.shape[2:], 'hw1_m': feat_m1.shape[2:],
+ 'hw0_f': feat_f0.shape[2:], 'hw1_f': feat_f1.shape[2:]
+ })
+
+ # save coarse features for fine matching module
+ feat_c0_pre, feat_c1_pre = feat_c0.clone(), feat_c1.clone()
+
+ # 2. Coarse-level loftr module
+ # add featmap with positional encoding, then flatten it to sequence [N, HW, C]
+ feat_c0 = rearrange(self.pos_encoding(feat_c0), 'n c h w -> n (h w) c')
+ feat_c1 = rearrange(self.pos_encoding(feat_c1), 'n c h w -> n (h w) c')
+
+ mask_c0 = mask_c1 = None # mask is useful in training
+ if 'mask0' in data:
+ mask_c0, mask_c1 = data['mask0'].flatten(-2), data['mask1'].flatten(-2)
+ feat_c0, feat_c1 = self.loftr_coarse(feat_c0, feat_c1, mask_c0, mask_c1)
+
+ # 3. Fine-level matching module as decoder
+ # generate window locations from mae mask to reconstruct
+ mae_mask_c0 = self.upsample_mae_mask( data["mae_mask0"],
+ self.patch_size // 8)
+ if mask0 is not None:
+ mae_mask_c0 = mae_mask_c0 * mask0.type_as(mae_mask_c0)
+
+ mae_mask_c1 = self.upsample_mae_mask( data["mae_mask1"],
+ self.patch_size // 8)
+ if mask1 is not None:
+ mae_mask_c1 = mae_mask_c1 * mask1.type_as(mae_mask_c1)
+
+ mae_mask_c = torch.logical_or(mae_mask_c0, mae_mask_c1)
+
+ b_ids, i_ids = mae_mask_c.flatten(1).nonzero(as_tuple=True)
+ j_ids = i_ids
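+ # during pretraining the fine "matches" are forced to the same coarse cell in both images (j_ids = i_ids), so the decoder reconstructs co-located masked windows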
+
+ # b_ids, i_ids and j_ids are the masked locations for both images
+ # ids_image0 and ids_image1 determine which indices belong to which image
+ ids_image0 = mae_mask_c0.flatten(1)[b_ids, i_ids]
+ ids_image1 = mae_mask_c1.flatten(1)[b_ids, j_ids]
+
+ data.update({'b_ids': b_ids, 'i_ids': i_ids, 'j_ids': j_ids,
+ 'ids_image0': ids_image0==1, 'ids_image1': ids_image1==1})
+
+
+ # fine level matching module
+ feat_f0_unfold, feat_f1_unfold = self.fine_process( feat_f0, feat_f1,
+ feat_m0, feat_m1,
+ feat_c0, feat_c1,
+ feat_c0_pre, feat_c1_pre,
+ data)
+
+ # output projection: each token of the 5x5 fine window predicts a 2x2 pixel patch (10x10 pixels per window)
+ pred0 = self.out_proj(feat_f0_unfold)
+ pred1 = self.out_proj(feat_f1_unfold)
+
+ data.update({"pred0":pred0, "pred1": pred1})
+
+
+ def load_state_dict(self, state_dict, *args, **kwargs):
+ for k in list(state_dict.keys()):
+ if k.startswith('matcher.'):
+ state_dict[k.replace('matcher.', '', 1)] = state_dict.pop(k)
+ return super().load_state_dict(state_dict, *args, **kwargs)
diff --git a/imcui/third_party/XoFTR/test.py b/imcui/third_party/XoFTR/test.py
new file mode 100644
index 0000000000000000000000000000000000000000..133a146d92f5d3ebc59fc1fbad6551aab4b3f5b4
--- /dev/null
+++ b/imcui/third_party/XoFTR/test.py
@@ -0,0 +1,68 @@
+import pytorch_lightning as pl
+import argparse
+import pprint
+from loguru import logger as loguru_logger
+
+from src.config.default import get_cfg_defaults
+from src.utils.profiler import build_profiler
+
+from src.lightning.data import MultiSceneDataModule
+from src.lightning.lightning_xoftr import PL_XoFTR
+
+
+def parse_args():
+ # init a custom parser which will be added into pl.Trainer parser
+ # check documentation: https://pytorch-lightning.readthedocs.io/en/latest/common/trainer.html#trainer-flags
+ parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+ parser.add_argument(
+ 'data_cfg_path', type=str, help='data config path')
+ parser.add_argument(
+ 'main_cfg_path', type=str, help='main config path')
+ parser.add_argument(
+ '--ckpt_path', type=str, default="weights/indoor_ds.ckpt", help='path to the checkpoint')
+ parser.add_argument(
+ '--dump_dir', type=str, default=None, help="if set, the matching results will be dump to dump_dir")
+ parser.add_argument(
+ '--profiler_name', type=str, default=None, help='options: [inference, pytorch], or leave it unset')
+ parser.add_argument(
+ '--batch_size', type=int, default=1, help='batch_size per gpu')
+ parser.add_argument(
+ '--num_workers', type=int, default=2)
+ parser.add_argument(
+ '--thr', type=float, default=None, help='modify the coarse-level matching threshold.')
+
+ parser = pl.Trainer.add_argparse_args(parser)
+ return parser.parse_args()
+
+
+if __name__ == '__main__':
+ # parse arguments
+ args = parse_args()
+ pprint.pprint(vars(args))
+
+ # init default-cfg and merge it with the main- and data-cfg
+ config = get_cfg_defaults()
+ config.merge_from_file(args.main_cfg_path)
+ config.merge_from_file(args.data_cfg_path)
+ pl.seed_everything(config.TRAINER.SEED) # reproducibility
+
+ # tune when testing
+ if args.thr is not None:
+ config.XoFTR.MATCH_COARSE.THR = args.thr
+
+ loguru_logger.info(f"Args and config initialized!")
+
+ # lightning module
+ profiler = build_profiler(args.profiler_name)
+ model = PL_XoFTR(config, pretrained_ckpt=args.ckpt_path, profiler=profiler, dump_dir=args.dump_dir)
+ loguru_logger.info(f"XoFTR-lightning initialized!")
+
+ # lightning data
+ data_module = MultiSceneDataModule(args, config)
+ loguru_logger.info(f"DataModule initialized!")
+
+ # lightning trainer
+ trainer = pl.Trainer.from_argparse_args(args, replace_sampler_ddp=False, logger=False)
+
+ loguru_logger.info(f"Start testing!")
+ trainer.test(model, datamodule=data_module, verbose=False)
diff --git a/imcui/third_party/XoFTR/test_relative_pose.py b/imcui/third_party/XoFTR/test_relative_pose.py
new file mode 100644
index 0000000000000000000000000000000000000000..3fb8c7ce58d46623ff9e06b784ac846355b537b9
--- /dev/null
+++ b/imcui/third_party/XoFTR/test_relative_pose.py
@@ -0,0 +1,330 @@
+from collections import defaultdict, OrderedDict
+import os
+import os.path as osp
+import numpy as np
+from tqdm import tqdm
+import argparse
+import cv2
+from pathlib import Path
+import warnings
+import json
+import time
+
+from src.utils.metrics import estimate_pose, relative_pose_error, error_auc, symmetric_epipolar_distance_numpy
+from src.utils.plotting import dynamic_alpha, error_colormap, make_matching_figure
+
+
+# Loading functions for methods
+####################################################################
+def load_xoftr(args):
+ from src.xoftr import XoFTR
+ from src.config.default import get_cfg_defaults
+ from src.utils.data_io import DataIOWrapper, lower_config
+ config = get_cfg_defaults(inference=True)
+ config = lower_config(config)
+ config["xoftr"]["match_coarse"]["thr"] = args.match_threshold
+ config["xoftr"]["fine"]["thr"] = args.fine_threshold
+ ckpt = args.ckpt
+ matcher = XoFTR(config=config["xoftr"])
+ matcher = DataIOWrapper(matcher, config=config["test"], ckpt=ckpt)
+ return matcher.from_paths
+
+####################################################################
+
+def load_vis_tir_pairs_npz(npz_root, npz_list):
+ """Load information for scene and image pairs from npz files.
+ Args:
+ npz_root: Directory path for npz files
+ npz_list: File containing the names of the npz files to be used
+ """
+ with open(npz_list, 'r') as f:
+ npz_names = [name.split()[0] for name in f.readlines()]
+ print(f"Parse {len(npz_names)} npz from {npz_list}.")
+
+ total_pairs = 0
+ scene_pairs = {}
+
+ for name in npz_names:
+ print(f"Loading {name}")
+ scene_info = np.load(f"{npz_root}/{name}", allow_pickle=True)
+ pairs = []
+
+ # Collect pairs
+ for pair_info in scene_info['pair_infos']:
+ total_pairs += 1
+ (id0, id1) = pair_info
+ im0 = scene_info['image_paths'][id0][0]
+ im1 = scene_info['image_paths'][id1][1]
+ K0 = scene_info['intrinsics'][id0][0].astype(np.float32)
+ K1 = scene_info['intrinsics'][id1][1].astype(np.float32)
+
+ dist0 = np.array(scene_info['distortion_coefs'][id0][0], dtype=float)
+ dist1 = np.array(scene_info['distortion_coefs'][id1][1], dtype=float)
+ # Compute relative pose
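+ # Assuming the stored poses are world-to-camera, T_0to1 maps camera-0 coordinates to camera-1 coordinates.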
+ T0 = scene_info['poses'][id0]
+ T1 = scene_info['poses'][id1]
+ T_0to1 = np.matmul(T1, np.linalg.inv(T0))
+ pairs.append({'im0':im0, 'im1':im1, 'dist0':dist0, 'dist1':dist1,
+ 'K0':K0, 'K1':K1, 'T_0to1':T_0to1})
+ scene_pairs[name] = pairs
+
+ print(f"Loaded {total_pairs} pairs.")
+ return scene_pairs
+
+
+
+def save_matching_figure(path, img0, img1, mkpts0, mkpts1, inlier_mask, T_0to1, K0, K1, t_err=None, R_err=None, name=None, conf_thr = 5e-4):
+ """ Make and save matching figures
+ """
+ Tx = np.cross(np.eye(3), T_0to1[:3, 3])
+ E_mat = Tx @ T_0to1[:3, :3]
+ if inlier_mask is not None and len(inlier_mask) != 0:
+ mkpts0_inliers = mkpts0[inlier_mask]
+ mkpts1_inliers = mkpts1[inlier_mask]
+ epi_errs = symmetric_epipolar_distance_numpy(mkpts0_inliers, mkpts1_inliers, E_mat, K0, K1)
+
+ correct_mask = epi_errs < conf_thr
+ precision = np.mean(correct_mask) if len(correct_mask) > 0 else 0
+ n_correct = np.sum(correct_mask)
+
+ # matching info
+ alpha = dynamic_alpha(len(correct_mask))
+ color = error_colormap(epi_errs, conf_thr, alpha=alpha)
+ text_precision = [
+ f'Precision({conf_thr:.2e}) ({100 * precision:.1f}%): {n_correct}/{len(mkpts0_inliers)}']
+ else:
+ # no inliers (e.g. pose estimation failed): draw the figure without matches
+ mkpts0_inliers = np.empty((0, 2))
+ mkpts1_inliers = np.empty((0, 2))
+ color = np.empty((0, 4)) # empty per-match color array (assumed RGBA)
+ text_precision = ['No inliers after RANSAC']
+
+ if name is not None:
+ text=[name]
+ else:
+ text = []
+
+ if t_err is not None and R_err is not None:
+ error_text = [f"err_t: {t_err:.2f} °", f"err_R: {R_err:.2f} °"]
+ text += error_text
+
+ text += text_precision
+
+ # make the figure
+ figure = make_matching_figure(img0, img1, mkpts0_inliers, mkpts1_inliers,
+ color, text=text, path=path, dpi=150)
+
+
+def aggregiate_scenes(scene_pose_auc, thresholds):
+ """Averages the auc results for cloudy_cloud and cloudy_sunny scenes
+ """
+ temp_pose_auc = {}
+ for npz_name in scene_pose_auc.keys():
+ scene_name = npz_name.split("_scene")[0]
+ temp_pose_auc[scene_name] = [np.zeros(len(thresholds), dtype=np.float32), 0] # [sum, total_number]
+ for npz_name in scene_pose_auc.keys():
+ scene_name = npz_name.split("_scene")[0]
+ temp_pose_auc[scene_name][0] += scene_pose_auc[npz_name]
+ temp_pose_auc[scene_name][1] += 1
+
+ agg_pose_auc = {}
+ for scene_name in temp_pose_auc.keys():
+ agg_pose_auc[scene_name] = temp_pose_auc[scene_name][0] / temp_pose_auc[scene_name][1]
+
+ return agg_pose_auc
+
+def eval_relapose(
+ matcher,
+ data_root,
+ scene_pairs,
+ ransac_thres,
+ thresholds,
+ save_figs,
+ figures_dir=None,
+ method=None,
+ print_out=False,
+ debug=False,
+):
+ scene_pose_auc = {}
+ for scene_name in scene_pairs.keys():
+ scene_dir = osp.join(figures_dir, scene_name.split(".")[0])
+ if save_figs and not osp.exists(scene_dir):
+ os.makedirs(scene_dir)
+
+ pairs = scene_pairs[scene_name]
+ statis = defaultdict(list)
+ np.set_printoptions(precision=2)
+
+ # Eval on pairs
+ print(f"\nStart evaluation on VisTir \n")
+ for i, pair in tqdm(enumerate(pairs), smoothing=.1, total=len(pairs)):
+ if debug and i > 10:
+ break
+
+ T_0to1 = pair['T_0to1']
+ im0 = str(data_root / pair['im0'])
+ im1 = str(data_root / pair['im1'])
+ match_res = matcher(im0, im1, pair['K0'], pair['K1'], pair['dist0'], pair['dist1'])
+ matches = match_res['matches']
+ new_K0 = match_res['new_K0']
+ new_K1 = match_res['new_K1']
+ mkpts0 = match_res['mkpts0']
+ mkpts1 = match_res['mkpts1']
+
+ # Calculate pose errors
+ ret = estimate_pose(
+ mkpts0, mkpts1, new_K0, new_K1, thresh=ransac_thres
+ )
+
+ if ret is None:
+ R, t, inliers = None, None, None
+ t_err, R_err = np.inf, np.inf
+ statis['failed'].append(i)
+ statis['R_errs'].append(R_err)
+ statis['t_errs'].append(t_err)
+ statis['inliers'].append(np.array([]).astype(np.bool_))
+ else:
+ R, t, inliers = ret
+ t_err, R_err = relative_pose_error(T_0to1, R, t)
+ statis['R_errs'].append(R_err)
+ statis['t_errs'].append(t_err)
+ statis['inliers'].append(inliers.sum() / len(mkpts0))
+ if print_out:
+ print(f"#M={len(matches)} R={R_err:.3f}, t={t_err:.3f}")
+
+ if save_figs:
+ img0_name = f"{'vis' if 'visible' in pair['im0'] else 'tir'}_{osp.basename(pair['im0']).split('.')[0]}"
+ img1_name = f"{'vis' if 'visible' in pair['im1'] else 'tir'}_{osp.basename(pair['im1']).split('.')[0]}"
+ fig_path = osp.join(scene_dir, f"{img0_name}_{img1_name}.jpg")
+ save_matching_figure(path=fig_path,
+ img0=match_res['img0_undistorted'] if 'img0_undistorted' in match_res.keys() else match_res['img0'],
+ img1=match_res['img1_undistorted'] if 'img1_undistorted' in match_res.keys() else match_res['img1'],
+ mkpts0=mkpts0,
+ mkpts1=mkpts1,
+ inlier_mask=inliers,
+ T_0to1=T_0to1,
+ K0=new_K0,
+ K1=new_K1,
+ t_err=t_err,
+ R_err=R_err,
+ name=method
+ )
+
+ print(f"Scene: {scene_name} Total samples: {len(pairs)} Failed:{len(statis['failed'])}. \n")
+ pose_errors = np.max(np.stack([statis['R_errs'], statis['t_errs']]), axis=0)
+ pose_auc = error_auc(pose_errors, thresholds) # (auc@5, auc@10, auc@20)
+ scene_pose_auc[scene_name] = 100 * np.array([pose_auc[f'auc@{t}'] for t in thresholds])
+ print(f"{scene_name} {pose_auc}")
+ agg_pose_auc = aggregiate_scenes(scene_pose_auc, thresholds)
+ return scene_pose_auc, agg_pose_auc
+
+def test_relative_pose_vistir(
+ data_root_dir,
+ method="xoftr",
+ exp_name = "VisTIR",
+ ransac_thres=1.5,
+ print_out=False,
+ save_dir=None,
+ save_figs=False,
+ debug=False,
+ args=None
+
+):
+ if not osp.exists(osp.join(save_dir, method)):
+ os.makedirs(osp.join(save_dir, method))
+
+ counter = 0
+ path = osp.join(save_dir, method, f"{exp_name}"+"_{}")
+ while osp.exists(path.format(counter)):
+ counter += 1
+ exp_dir = path.format(counter)
+ os.mkdir(exp_dir)
+ results_file = osp.join(exp_dir, "results.json")
+ figures_dir = osp.join(exp_dir, "match_figures")
+ if save_figs:
+ os.mkdir(figures_dir)
+
+ # Init paths
+ npz_root = data_root_dir / 'index/scene_info_test/'
+ npz_list = data_root_dir / 'index/val_test_list/test_list.txt'
+ data_root = data_root_dir
+
+ # Load pairs
+ scene_pairs = load_vis_tir_pairs_npz(npz_root, npz_list)
+
+ # Load method
+ matcher = eval(f"load_{method}")(args)
+
+ thresholds=[5, 10, 20]
+ # Eval
+ scene_pose_auc, agg_pose_auc = eval_relapose(
+ matcher,
+ data_root,
+ scene_pairs,
+ ransac_thres=ransac_thres,
+ thresholds=thresholds,
+ save_figs=save_figs,
+ figures_dir=figures_dir,
+ method=method,
+ print_out=print_out,
+ debug=debug,
+ )
+
+ # Create result dict
+ results = OrderedDict({"method": method,
+ "exp_name": exp_name,
+ "ransac_thres": ransac_thres,
+ "auc_thresholds": thresholds})
+ results.update({key:value for key, value in vars(args).items() if key not in results})
+ results.update({key:value.tolist() for key, value in agg_pose_auc.items()})
+ results.update({key:value.tolist() for key, value in scene_pose_auc.items()})
+
+ print(f"Results: {json.dumps(results, indent=4)}")
+
+ # Save to json file
+ with open(results_file, 'w') as outfile:
+ json.dump(results, outfile, indent=4)
+
+ print(f"Results saved to {results_file}")
+
+if __name__ == '__main__':
+
+ def add_common_arguments(parser):
+ parser.add_argument('--gpu', '-gpu', type=str, default='0')
+ parser.add_argument('--exp_name', type=str, default="VisTIR")
+ parser.add_argument('--data_root_dir', type=str, default="./data/METU_VisTIR/")
+ parser.add_argument('--save_dir', type=str, default="./results_relative_pose")
+ parser.add_argument('--ransac_thres', type=float, default=1.5)
+ parser.add_argument('--print_out', action='store_true')
+ parser.add_argument('--debug', action='store_true')
+ parser.add_argument('--save_figs', action='store_true')
+
+ def add_xoftr_arguments(subparsers):
+ subcommand = subparsers.add_parser('xoftr')
+ subcommand.add_argument('--match_threshold', type=float, default=0.3)
+ subcommand.add_argument('--fine_threshold', type=float, default=0.1)
+ subcommand.add_argument('--ckpt', type=str, default="./weights/weights_xoftr_640.ckpt")
+ add_common_arguments(subcommand)
+
+ parser = argparse.ArgumentParser(description='Benchmark Relative Pose')
+ add_common_arguments(parser)
+
+ # Create subparsers for top-level commands
+ subparsers = parser.add_subparsers(dest="method")
+ add_xoftr_arguments(subparsers)
+
+ args = parser.parse_args()
+
+ os.environ['CUDA_VISIBLE_DEVICES'] = "0"
+ tt = time.time()
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ test_relative_pose_vistir(
+ Path(args.data_root_dir),
+ args.method,
+ args.exp_name,
+ ransac_thres=args.ransac_thres,
+ print_out=args.print_out,
+ save_dir = args.save_dir,
+ save_figs = args.save_figs,
+ debug=args.debug,
+ args=args
+ )
+ print(f"Elapsed time: {time.time() - tt}")
diff --git a/imcui/third_party/XoFTR/train.py b/imcui/third_party/XoFTR/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..4fdb26d3857f276691dce65be6ad817e36db7414
--- /dev/null
+++ b/imcui/third_party/XoFTR/train.py
@@ -0,0 +1,126 @@
+import math
+import argparse
+import pprint
+from distutils.util import strtobool
+from pathlib import Path
+from loguru import logger as loguru_logger
+from datetime import datetime
+
+import pytorch_lightning as pl
+from pytorch_lightning.utilities import rank_zero_only
+from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger
+from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
+from pytorch_lightning.plugins import DDPPlugin
+
+from src.config.default import get_cfg_defaults
+from src.utils.misc import get_rank_zero_only_logger, setup_gpus
+from src.utils.profiler import build_profiler
+from src.lightning.data import MultiSceneDataModule
+from src.lightning.lightning_xoftr import PL_XoFTR
+
+loguru_logger = get_rank_zero_only_logger(loguru_logger)
+
+
+def parse_args():
+ # init a custom parser which will be added into pl.Trainer parser
+ # check documentation: https://pytorch-lightning.readthedocs.io/en/latest/common/trainer.html#trainer-flags
+ parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+ parser.add_argument(
+ 'data_cfg_path', type=str, help='data config path')
+ parser.add_argument(
+ 'main_cfg_path', type=str, help='main config path')
+ parser.add_argument(
+ '--exp_name', type=str, default='default_exp_name')
+ parser.add_argument(
+ '--batch_size', type=int, default=4, help='batch_size per gpu')
+ parser.add_argument(
+ '--num_workers', type=int, default=4)
+ parser.add_argument(
+ '--pin_memory', type=lambda x: bool(strtobool(x)),
+ nargs='?', default=True, help='whether loading data to pinned memory or not')
+ parser.add_argument(
+ '--ckpt_path', type=str, default=None,
+ help='pretrained checkpoint path, helpful for using a pre-trained coarse-only XoFTR')
+ parser.add_argument(
+ '--disable_ckpt', action='store_true',
+ help='disable checkpoint saving (useful for debugging).')
+ parser.add_argument(
+ '--profiler_name', type=str, default=None,
+ help='options: [inference, pytorch], or leave it unset')
+ parser.add_argument(
+ '--parallel_load_data', action='store_true',
+ help='load datasets with multiple processes.')
+
+ parser = pl.Trainer.add_argparse_args(parser)
+ return parser.parse_args()
+
+
+def main():
+ # parse arguments
+ args = parse_args()
+ rank_zero_only(pprint.pprint)(vars(args))
+
+ # init default-cfg and merge it with the main- and data-cfg
+ config = get_cfg_defaults()
+ config.merge_from_file(args.main_cfg_path)
+ config.merge_from_file(args.data_cfg_path)
+ pl.seed_everything(config.TRAINER.SEED) # reproducibility
+
+ # scale lr and warmup-step automatically
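+ # Linear scaling rule: the learning rate grows with the effective (world-size * per-gpu) batch size
+ # relative to the canonical batch size, while the warmup length shrinks by the same factor.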
+ args.gpus = _n_gpus = setup_gpus(args.gpus)
+ config.TRAINER.WORLD_SIZE = _n_gpus * args.num_nodes
+ config.TRAINER.TRUE_BATCH_SIZE = config.TRAINER.WORLD_SIZE * args.batch_size
+ _scaling = config.TRAINER.TRUE_BATCH_SIZE / config.TRAINER.CANONICAL_BS
+ config.TRAINER.SCALING = _scaling
+ config.TRAINER.TRUE_LR = config.TRAINER.CANONICAL_LR * _scaling
+ config.TRAINER.WARMUP_STEP = math.floor(config.TRAINER.WARMUP_STEP / _scaling)
+
+
+ # lightning module
+ profiler = build_profiler(args.profiler_name)
+ model = PL_XoFTR(config, pretrained_ckpt=args.ckpt_path, profiler=profiler)
+ loguru_logger.info(f"XoFTR LightningModule initialized!")
+
+ # lightning data
+ data_module = MultiSceneDataModule(args, config)
+ loguru_logger.info(f"XoFTR DataModule initialized!")
+
+ # TensorBoard Logger
+ logger = [TensorBoardLogger(save_dir='logs/tb_logs', name=args.exp_name, default_hp_metric=False)]
+ ckpt_dir = Path(logger[0].log_dir) / 'checkpoints'
+ if config.TRAINER.USE_WANDB:
+ logger.append(WandbLogger(name=args.exp_name + f"_{datetime.now().strftime('%Y_%m_%d_%H_%M_%S')}",
+ project='XoFTR'))
+
+ # Callbacks
+ # TODO: update ModelCheckpoint to monitor multiple metrics
+ ckpt_callback = ModelCheckpoint(monitor='auc@10', verbose=True, save_top_k=-1, mode='max',
+ save_last=True,
+ dirpath=str(ckpt_dir),
+ filename='{epoch}-{auc@5:.3f}-{auc@10:.3f}-{auc@20:.3f}')
+ lr_monitor = LearningRateMonitor(logging_interval='step')
+ callbacks = [lr_monitor]
+ if not args.disable_ckpt:
+ callbacks.append(ckpt_callback)
+
+ # Lightning Trainer
+ trainer = pl.Trainer.from_argparse_args(
+ args,
+ plugins=DDPPlugin(find_unused_parameters=True,
+ num_nodes=args.num_nodes,
+ sync_batchnorm=config.TRAINER.WORLD_SIZE > 0),
+ gradient_clip_val=config.TRAINER.GRADIENT_CLIPPING,
+ callbacks=callbacks,
+ logger=logger,
+ sync_batchnorm=config.TRAINER.WORLD_SIZE > 0,
+ replace_sampler_ddp=False, # use custom sampler
+ reload_dataloaders_every_epoch=False, # avoid repeated samples!
+ weights_summary='full',
+ profiler=profiler)
+ loguru_logger.info(f"Trainer initialized!")
+ loguru_logger.info(f"Start training!")
+ trainer.fit(model, datamodule=data_module)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/imcui/third_party/d2net/extract_features.py b/imcui/third_party/d2net/extract_features.py
new file mode 100644
index 0000000000000000000000000000000000000000..628463a7d042a90b5cadea8a317237cde86f5ae4
--- /dev/null
+++ b/imcui/third_party/d2net/extract_features.py
@@ -0,0 +1,156 @@
+import argparse
+
+import numpy as np
+
+import imageio
+
+import torch
+
+from tqdm import tqdm
+
+import scipy
+import scipy.io
+import scipy.misc
+
+from lib.model_test import D2Net
+from lib.utils import preprocess_image
+from lib.pyramid import process_multiscale
+
+# CUDA
+use_cuda = torch.cuda.is_available()
+device = torch.device("cuda:0" if use_cuda else "cpu")
+
+# Argument parsing
+parser = argparse.ArgumentParser(description='Feature extraction script')
+
+parser.add_argument(
+ '--image_list_file', type=str, required=True,
+ help='path to a file containing a list of images to process'
+)
+
+parser.add_argument(
+ '--preprocessing', type=str, default='caffe',
+ help='image preprocessing (caffe or torch)'
+)
+parser.add_argument(
+ '--model_file', type=str, default='models/d2_tf.pth',
+ help='path to the full model'
+)
+
+parser.add_argument(
+ '--max_edge', type=int, default=1600,
+ help='maximum image size at network input'
+)
+parser.add_argument(
+ '--max_sum_edges', type=int, default=2800,
+ help='maximum sum of image sizes at network input'
+)
+
+parser.add_argument(
+ '--output_extension', type=str, default='.d2-net',
+ help='extension for the output'
+)
+parser.add_argument(
+ '--output_type', type=str, default='npz',
+ help='output file type (npz or mat)'
+)
+
+parser.add_argument(
+ '--multiscale', dest='multiscale', action='store_true',
+ help='extract multiscale features'
+)
+parser.set_defaults(multiscale=False)
+
+parser.add_argument(
+ '--no-relu', dest='use_relu', action='store_false',
+ help='remove ReLU after the dense feature extraction module'
+)
+parser.set_defaults(use_relu=True)
+
+args = parser.parse_args()
+
+print(args)
+
+# Creating CNN model
+model = D2Net(
+ model_file=args.model_file,
+ use_relu=args.use_relu,
+ use_cuda=use_cuda
+)
+
+# Process the file
+with open(args.image_list_file, 'r') as f:
+ lines = f.readlines()
+for line in tqdm(lines, total=len(lines)):
+ path = line.strip()
+
+ image = imageio.imread(path)
+ if len(image.shape) == 2:
+ image = image[:, :, np.newaxis]
+ image = np.repeat(image, 3, -1)
+
+ # TODO: switch to PIL.Image due to deprecation of scipy.misc.imresize.
+ resized_image = image
+ if max(resized_image.shape) > args.max_edge:
+ resized_image = scipy.misc.imresize(
+ resized_image,
+ args.max_edge / max(resized_image.shape)
+ ).astype('float')
+ if sum(resized_image.shape[: 2]) > args.max_sum_edges:
+ resized_image = scipy.misc.imresize(
+ resized_image,
+ args.max_sum_edges / sum(resized_image.shape[: 2])
+ ).astype('float')
+
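+ # Scale factors to map keypoints detected on the resized image back to original image coordinates.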
+ fact_i = image.shape[0] / resized_image.shape[0]
+ fact_j = image.shape[1] / resized_image.shape[1]
+
+ input_image = preprocess_image(
+ resized_image,
+ preprocessing=args.preprocessing
+ )
+ with torch.no_grad():
+ if args.multiscale:
+ keypoints, scores, descriptors = process_multiscale(
+ torch.tensor(
+ input_image[np.newaxis, :, :, :].astype(np.float32),
+ device=device
+ ),
+ model
+ )
+ else:
+ keypoints, scores, descriptors = process_multiscale(
+ torch.tensor(
+ input_image[np.newaxis, :, :, :].astype(np.float32),
+ device=device
+ ),
+ model,
+ scales=[1]
+ )
+
+ # Input image coordinates
+ keypoints[:, 0] *= fact_i
+ keypoints[:, 1] *= fact_j
+ # i, j -> u, v
+ keypoints = keypoints[:, [1, 0, 2]]
+
+ if args.output_type == 'npz':
+ with open(path + args.output_extension, 'wb') as output_file:
+ np.savez(
+ output_file,
+ keypoints=keypoints,
+ scores=scores,
+ descriptors=descriptors
+ )
+ elif args.output_type == 'mat':
+ with open(path + args.output_extension, 'wb') as output_file:
+ scipy.io.savemat(
+ output_file,
+ {
+ 'keypoints': keypoints,
+ 'scores': scores,
+ 'descriptors': descriptors
+ }
+ )
+ else:
+ raise ValueError('Unknown output type.')
diff --git a/imcui/third_party/d2net/extract_kapture.py b/imcui/third_party/d2net/extract_kapture.py
new file mode 100644
index 0000000000000000000000000000000000000000..23198b978229c699dbe24cd3bc0400d62bcab030
--- /dev/null
+++ b/imcui/third_party/d2net/extract_kapture.py
@@ -0,0 +1,248 @@
+import argparse
+import numpy as np
+from PIL import Image
+import torch
+import math
+from tqdm import tqdm
+from os import path
+
+# Kapture is a pivot file format, based on text and binary files, used to describe SfM (Structure from Motion) data and, more generally, sensor-acquired data.
+# It can be installed with:
+#   pip install kapture
+# For more information, check out https://github.com/naver/kapture
+import kapture
+from kapture.io.records import get_image_fullpath
+from kapture.io.csv import kapture_from_dir, get_all_tar_handlers
+from kapture.io.csv import get_feature_csv_fullpath, keypoints_to_file, descriptors_to_file
+from kapture.io.features import get_keypoints_fullpath, keypoints_check_dir, image_keypoints_to_file
+from kapture.io.features import get_descriptors_fullpath, descriptors_check_dir, image_descriptors_to_file
+
+from lib.model_test import D2Net
+from lib.utils import preprocess_image
+from lib.pyramid import process_multiscale
+
+# import imageio
+
+# CUDA
+use_cuda = torch.cuda.is_available()
+device = torch.device("cuda:0" if use_cuda else "cpu")
+
+# Argument parsing
+parser = argparse.ArgumentParser(description='Feature extraction script')
+
+parser.add_argument(
+ '--kapture-root', type=str, required=True,
+ help='path to kapture root directory'
+)
+
+parser.add_argument(
+ '--preprocessing', type=str, default='caffe',
+ help='image preprocessing (caffe or torch)'
+)
+parser.add_argument(
+ '--model_file', type=str, default='models/d2_tf.pth',
+ help='path to the full model'
+)
+parser.add_argument(
+ '--keypoints-type', type=str, default=None,
+ help='keypoint type_name, default is filename of model'
+)
+parser.add_argument(
+ '--descriptors-type', type=str, default=None,
+ help='descriptors type_name, default is filename of model'
+)
+
+parser.add_argument(
+ '--max_edge', type=int, default=1600,
+ help='maximum image size at network input'
+)
+parser.add_argument(
+ '--max_sum_edges', type=int, default=2800,
+ help='maximum sum of image sizes at network input'
+)
+
+parser.add_argument(
+ '--multiscale', dest='multiscale', action='store_true',
+ help='extract multiscale features'
+)
+parser.set_defaults(multiscale=False)
+
+parser.add_argument(
+ '--no-relu', dest='use_relu', action='store_false',
+ help='remove ReLU after the dense feature extraction module'
+)
+parser.set_defaults(use_relu=True)
+
+parser.add_argument("--max-keypoints", type=int, default=float("+inf"),
+ help='max number of keypoints to save to disk')
+
+args = parser.parse_args()
+
+print(args)
+with get_all_tar_handlers(args.kapture_root,
+ mode={kapture.Keypoints: 'a',
+ kapture.Descriptors: 'a',
+ kapture.GlobalFeatures: 'r',
+ kapture.Matches: 'r'}) as tar_handlers:
+ kdata = kapture_from_dir(args.kapture_root,
+ skip_list=[kapture.GlobalFeatures,
+ kapture.Matches,
+ kapture.Points3d,
+ kapture.Observations],
+ tar_handlers=tar_handlers)
+ if kdata.keypoints is None:
+ kdata.keypoints = {}
+ if kdata.descriptors is None:
+ kdata.descriptors = {}
+
+ assert kdata.records_camera is not None
+ image_list = [filename for _, _, filename in kapture.flatten(kdata.records_camera)]
+ if args.keypoints_type is None:
+ args.keypoints_type = path.splitext(path.basename(args.model_file))[0]
+ print(f'keypoints_type set to {args.keypoints_type}')
+ if args.descriptors_type is None:
+ args.descriptors_type = path.splitext(path.basename(args.model_file))[0]
+ print(f'descriptors_type set to {args.descriptors_type}')
+ if args.keypoints_type in kdata.keypoints and args.descriptors_type in kdata.descriptors:
+ image_list = [name
+ for name in image_list
+ if name not in kdata.keypoints[args.keypoints_type] or
+ name not in kdata.descriptors[args.descriptors_type]]
+
+ if len(image_list) == 0:
+ print('All features were already extracted')
+ exit(0)
+ else:
+ print(f'Extracting d2net features for {len(image_list)} images')
+
+ # Creating CNN model
+ model = D2Net(
+ model_file=args.model_file,
+ use_relu=args.use_relu,
+ use_cuda=use_cuda
+ )
+
+ if args.keypoints_type not in kdata.keypoints:
+ keypoints_dtype = None
+ keypoints_dsize = None
+ else:
+ keypoints_dtype = kdata.keypoints[args.keypoints_type].dtype
+ keypoints_dsize = kdata.keypoints[args.keypoints_type].dsize
+ if args.descriptors_type not in kdata.descriptors:
+ descriptors_dtype = None
+ descriptors_dsize = None
+ else:
+ descriptors_dtype = kdata.descriptors[args.descriptors_type].dtype
+ descriptors_dsize = kdata.descriptors[args.descriptors_type].dsize
+
+ # Process the files
+ for image_name in tqdm(image_list, total=len(image_list)):
+ img_path = get_image_fullpath(args.kapture_root, image_name)
+ image = Image.open(img_path).convert('RGB')
+
+ width, height = image.size
+
+ resized_image = image
+ resized_width = width
+ resized_height = height
+
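+ # Resize with the same two constraints as extract_features.py: first cap the longest edge,
+ # then cap the sum of both edges.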
+ max_edge = args.max_edge
+ max_sum_edges = args.max_sum_edges
+ if max(resized_width, resized_height) > max_edge:
+ scale_multiplier = max_edge / max(resized_width, resized_height)
+ resized_width = math.floor(resized_width * scale_multiplier)
+ resized_height = math.floor(resized_height * scale_multiplier)
+ resized_image = image.resize((resized_width, resized_height))
+ if resized_width + resized_height > max_sum_edges:
+ scale_multiplier = max_sum_edges / (resized_width + resized_height)
+ resized_width = math.floor(resized_width * scale_multiplier)
+ resized_height = math.floor(resized_height * scale_multiplier)
+ resized_image = image.resize((resized_width, resized_height))
+
+ fact_i = width / resized_width
+ fact_j = height / resized_height
+
+ resized_image = np.array(resized_image).astype('float')
+
+ input_image = preprocess_image(
+ resized_image,
+ preprocessing=args.preprocessing
+ )
+
+ with torch.no_grad():
+ if args.multiscale:
+ keypoints, scores, descriptors = process_multiscale(
+ torch.tensor(
+ input_image[np.newaxis, :, :, :].astype(np.float32),
+ device=device
+ ),
+ model
+ )
+ else:
+ keypoints, scores, descriptors = process_multiscale(
+ torch.tensor(
+ input_image[np.newaxis, :, :, :].astype(np.float32),
+ device=device
+ ),
+ model,
+ scales=[1]
+ )
+
+ # Input image coordinates
+ keypoints[:, 0] *= fact_i
+ keypoints[:, 1] *= fact_j
+ # i, j -> u, v
+ keypoints = keypoints[:, [1, 0, 2]]
+
+ if args.max_keypoints != float("+inf"):
+ # argsort is ascending, so keep the last indices (the highest scores)
+ idx_keep = scores.argsort()[-min(len(keypoints), args.max_keypoints):]
+ keypoints = keypoints[idx_keep]
+ descriptors = descriptors[idx_keep]
+
+ if keypoints_dtype is None or descriptors_dtype is None:
+ keypoints_dtype = keypoints.dtype
+ descriptors_dtype = descriptors.dtype
+
+ keypoints_dsize = keypoints.shape[1]
+ descriptors_dsize = descriptors.shape[1]
+
+ kdata.keypoints[args.keypoints_type] = kapture.Keypoints('d2net', keypoints_dtype, keypoints_dsize)
+ kdata.descriptors[args.descriptors_type] = kapture.Descriptors('d2net', descriptors_dtype,
+ descriptors_dsize,
+ args.keypoints_type, 'L2')
+
+ keypoints_config_absolute_path = get_feature_csv_fullpath(kapture.Keypoints,
+ args.keypoints_type,
+ args.kapture_root)
+ descriptors_config_absolute_path = get_feature_csv_fullpath(kapture.Descriptors,
+ args.descriptors_type,
+ args.kapture_root)
+
+ keypoints_to_file(keypoints_config_absolute_path, kdata.keypoints[args.keypoints_type])
+ descriptors_to_file(descriptors_config_absolute_path, kdata.descriptors[args.descriptors_type])
+ else:
+ assert kdata.keypoints[args.keypoints_type].dtype == keypoints.dtype
+ assert kdata.descriptors[args.descriptors_type].dtype == descriptors.dtype
+ assert kdata.keypoints[args.keypoints_type].dsize == keypoints.shape[1]
+ assert kdata.descriptors[args.descriptors_type].dsize == descriptors.shape[1]
+ assert kdata.descriptors[args.descriptors_type].keypoints_type == args.keypoints_type
+ assert kdata.descriptors[args.descriptors_type].metric_type == 'L2'
+
+ keypoints_fullpath = get_keypoints_fullpath(args.keypoints_type, args.kapture_root,
+ image_name, tar_handlers)
+ print(f"Saving {keypoints.shape[0]} keypoints to {keypoints_fullpath}")
+ image_keypoints_to_file(keypoints_fullpath, keypoints)
+ kdata.keypoints[args.keypoints_type].add(image_name)
+
+ descriptors_fullpath = get_descriptors_fullpath(args.descriptors_type, args.kapture_root,
+ image_name, tar_handlers)
+ print(f"Saving {descriptors.shape[0]} descriptors to {descriptors_fullpath}")
+ image_descriptors_to_file(descriptors_fullpath, descriptors)
+ kdata.descriptors[args.descriptors_type].add(image_name)
+
+ if not keypoints_check_dir(kdata.keypoints[args.keypoints_type], args.keypoints_type,
+ args.kapture_root, tar_handlers) or \
+ not descriptors_check_dir(kdata.descriptors[args.descriptors_type], args.descriptors_type,
+ args.kapture_root, tar_handlers):
+ print('local feature extraction ended successfully but not all files were saved')
diff --git a/imcui/third_party/d2net/megadepth_utils/preprocess_scene.py b/imcui/third_party/d2net/megadepth_utils/preprocess_scene.py
new file mode 100644
index 0000000000000000000000000000000000000000..fc68a403795e7cddce88dfcb74b38d19ab09e133
--- /dev/null
+++ b/imcui/third_party/d2net/megadepth_utils/preprocess_scene.py
@@ -0,0 +1,242 @@
+import argparse
+
+import imagesize
+
+import numpy as np
+
+import os
+
+parser = argparse.ArgumentParser(description='MegaDepth preprocessing script')
+
+parser.add_argument(
+ '--base_path', type=str, required=True,
+ help='path to MegaDepth'
+)
+parser.add_argument(
+ '--scene_id', type=str, required=True,
+ help='scene ID'
+)
+
+parser.add_argument(
+ '--output_path', type=str, required=True,
+ help='path to the output directory'
+)
+
+args = parser.parse_args()
+
+base_path = args.base_path
+# Remove the trailing / if need be.
+if base_path[-1] in ['/', '\\']:
+ base_path = base_path[: - 1]
+scene_id = args.scene_id
+
+base_depth_path = os.path.join(
+ base_path, 'phoenix/S6/zl548/MegaDepth_v1'
+)
+base_undistorted_sfm_path = os.path.join(
+ base_path, 'Undistorted_SfM'
+)
+
+undistorted_sparse_path = os.path.join(
+ base_undistorted_sfm_path, scene_id, 'sparse-txt'
+)
+if not os.path.exists(undistorted_sparse_path):
+ exit()
+
+depths_path = os.path.join(
+ base_depth_path, scene_id, 'dense0', 'depths'
+)
+if not os.path.exists(depths_path):
+ exit()
+
+images_path = os.path.join(
+ base_undistorted_sfm_path, scene_id, 'images'
+)
+if not os.path.exists(images_path):
+ exit()
+
+# Process cameras.txt
+with open(os.path.join(undistorted_sparse_path, 'cameras.txt'), 'r') as f:
+ raw = f.readlines()[3 :] # skip the header
+
+camera_intrinsics = {}
+for camera in raw:
+ camera = camera.split(' ')
+ camera_intrinsics[int(camera[0])] = [float(elem) for elem in camera[2 :]]
+
+# Process points3D.txt
+with open(os.path.join(undistorted_sparse_path, 'points3D.txt'), 'r') as f:
+ raw = f.readlines()[3 :] # skip the header
+
+points3D = {}
+for point3D in raw:
+ point3D = point3D.split(' ')
+ points3D[int(point3D[0])] = np.array([
+ float(point3D[1]), float(point3D[2]), float(point3D[3])
+ ])
+
+# Process images.txt
+with open(os.path.join(undistorted_sparse_path, 'images.txt'), 'r') as f:
+ raw = f.readlines()[4 :] # skip the header
+
+image_id_to_idx = {}
+image_names = []
+raw_pose = []
+camera = []
+points3D_id_to_2D = []
+n_points3D = []
+for idx, (image, points) in enumerate(zip(raw[:: 2], raw[1 :: 2])):
+ image = image.split(' ')
+ points = points.split(' ')
+
+ image_id_to_idx[int(image[0])] = idx
+
+ image_name = image[-1].strip('\n')
+ image_names.append(image_name)
+
+ raw_pose.append([float(elem) for elem in image[1 : -2]])
+ camera.append(int(image[-2]))
+ current_points3D_id_to_2D = {}
+ for x, y, point3D_id in zip(points[:: 3], points[1 :: 3], points[2 :: 3]):
+ if int(point3D_id) == -1:
+ continue
+ current_points3D_id_to_2D[int(point3D_id)] = [float(x), float(y)]
+ points3D_id_to_2D.append(current_points3D_id_to_2D)
+ n_points3D.append(len(current_points3D_id_to_2D))
+n_images = len(image_names)
+
+# Image and depthmaps paths
+image_paths = []
+depth_paths = []
+for image_name in image_names:
+ image_path = os.path.join(images_path, image_name)
+
+ # Path to the depth file
+ depth_path = os.path.join(
+ depths_path, '%s.h5' % os.path.splitext(image_name)[0]
+ )
+
+ if os.path.exists(depth_path):
+ # Check if depth map or background / foreground mask
+ file_size = os.stat(depth_path).st_size
+ # Rough estimate - 75KB might work as well
+ if file_size < 100 * 1024:
+ depth_paths.append(None)
+ image_paths.append(None)
+ else:
+ depth_paths.append(depth_path[len(base_path) + 1 :])
+ image_paths.append(image_path[len(base_path) + 1 :])
+ else:
+ depth_paths.append(None)
+ image_paths.append(None)
+
+# Camera configuration
+intrinsics = []
+poses = []
+principal_axis = []
+points3D_id_to_ndepth = []
+for idx, image_name in enumerate(image_names):
+ if image_paths[idx] is None:
+ intrinsics.append(None)
+ poses.append(None)
+ principal_axis.append([0, 0, 0])
+ points3D_id_to_ndepth.append({})
+ continue
+ image_intrinsics = camera_intrinsics[camera[idx]]
+ K = np.zeros([3, 3])
+ K[0, 0] = image_intrinsics[2]
+ K[0, 2] = image_intrinsics[4]
+ K[1, 1] = image_intrinsics[3]
+ K[1, 2] = image_intrinsics[5]
+ K[2, 2] = 1
+ intrinsics.append(K)
+
+ image_pose = raw_pose[idx]
+ qvec = image_pose[: 4]
+ qvec = qvec / np.linalg.norm(qvec)
+ w, x, y, z = qvec
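+ # Convert the COLMAP unit quaternion (w, x, y, z) to a rotation matrix.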
+ R = np.array([
+ [
+ 1 - 2 * y * y - 2 * z * z,
+ 2 * x * y - 2 * z * w,
+ 2 * x * z + 2 * y * w
+ ],
+ [
+ 2 * x * y + 2 * z * w,
+ 1 - 2 * x * x - 2 * z * z,
+ 2 * y * z - 2 * x * w
+ ],
+ [
+ 2 * x * z - 2 * y * w,
+ 2 * y * z + 2 * x * w,
+ 1 - 2 * x * x - 2 * y * y
+ ]
+ ])
+ principal_axis.append(R[2, :])
+ t = image_pose[4 : 7]
+ # World-to-Camera pose
+ current_pose = np.zeros([4, 4])
+ current_pose[: 3, : 3] = R
+ current_pose[: 3, 3] = t
+ current_pose[3, 3] = 1
+ # Camera-to-World pose
+ # pose = np.zeros([4, 4])
+ # pose[: 3, : 3] = np.transpose(R)
+ # pose[: 3, 3] = -np.matmul(np.transpose(R), t)
+ # pose[3, 3] = 1
+ poses.append(current_pose)
+
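+ # Normalized depth of each 3D point: depth along the camera's principal axis divided by the
+ # mean focal length; used below to build the scale-ratio matrix between image pairs.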
+ current_points3D_id_to_ndepth = {}
+ for point3D_id in points3D_id_to_2D[idx].keys():
+ p3d = points3D[point3D_id]
+ current_points3D_id_to_ndepth[point3D_id] = (np.dot(R[2, :], p3d) + t[2]) / (.5 * (K[0, 0] + K[1, 1]))
+ points3D_id_to_ndepth.append(current_points3D_id_to_ndepth)
+principal_axis = np.array(principal_axis)
+angles = np.rad2deg(np.arccos(
+ np.clip(
+ np.dot(principal_axis, np.transpose(principal_axis)),
+ -1, 1
+ )
+))
+
+# Compute overlap score
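+# For each pair, the overlap is the fraction of one image's 3D points that are also observed in the
+# other image; the scale-ratio entry is the minimum over shared points of the (>= 1) ratio between
+# their normalized depths in the two images.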
+overlap_matrix = np.full([n_images, n_images], -1.)
+scale_ratio_matrix = np.full([n_images, n_images], -1.)
+for idx1 in range(n_images):
+ if image_paths[idx1] is None or depth_paths[idx1] is None:
+ continue
+ for idx2 in range(idx1 + 1, n_images):
+ if image_paths[idx2] is None or depth_paths[idx2] is None:
+ continue
+ matches = (
+ points3D_id_to_2D[idx1].keys() &
+ points3D_id_to_2D[idx2].keys()
+ )
+ min_num_points3D = min(
+ len(points3D_id_to_2D[idx1]), len(points3D_id_to_2D[idx2])
+ )
+ overlap_matrix[idx1, idx2] = len(matches) / len(points3D_id_to_2D[idx1]) # min_num_points3D
+ overlap_matrix[idx2, idx1] = len(matches) / len(points3D_id_to_2D[idx2]) # min_num_points3D
+ if len(matches) == 0:
+ continue
+ points3D_id_to_ndepth1 = points3D_id_to_ndepth[idx1]
+ points3D_id_to_ndepth2 = points3D_id_to_ndepth[idx2]
+ nd1 = np.array([points3D_id_to_ndepth1[match] for match in matches])
+ nd2 = np.array([points3D_id_to_ndepth2[match] for match in matches])
+ min_scale_ratio = np.min(np.maximum(nd1 / nd2, nd2 / nd1))
+ scale_ratio_matrix[idx1, idx2] = min_scale_ratio
+ scale_ratio_matrix[idx2, idx1] = min_scale_ratio
+
+np.savez(
+ os.path.join(args.output_path, '%s.npz' % scene_id),
+ image_paths=image_paths,
+ depth_paths=depth_paths,
+ intrinsics=intrinsics,
+ poses=poses,
+ overlap_matrix=overlap_matrix,
+ scale_ratio_matrix=scale_ratio_matrix,
+ angles=angles,
+ n_points3D=n_points3D,
+ points3D_id_to_2D=points3D_id_to_2D,
+ points3D_id_to_ndepth=points3D_id_to_ndepth
+)
diff --git a/imcui/third_party/d2net/megadepth_utils/undistort_reconstructions.py b/imcui/third_party/d2net/megadepth_utils/undistort_reconstructions.py
new file mode 100644
index 0000000000000000000000000000000000000000..a6b99a72f81206e6fbefae9daa9aa683c8754051
--- /dev/null
+++ b/imcui/third_party/d2net/megadepth_utils/undistort_reconstructions.py
@@ -0,0 +1,69 @@
+import argparse
+
+import imagesize
+
+import os
+
+import subprocess
+
+parser = argparse.ArgumentParser(description='MegaDepth Undistortion')
+
+parser.add_argument(
+ '--colmap_path', type=str, required=True,
+ help='path to colmap executable'
+)
+parser.add_argument(
+ '--base_path', type=str, required=True,
+ help='path to MegaDepth'
+)
+
+args = parser.parse_args()
+
+sfm_path = os.path.join(
+ args.base_path, 'MegaDepth_v1_SfM'
+)
+base_depth_path = os.path.join(
+ args.base_path, 'phoenix/S6/zl548/MegaDepth_v1'
+)
+output_path = os.path.join(
+ args.base_path, 'Undistorted_SfM'
+)
+
+os.mkdir(output_path)
+
+for scene_name in os.listdir(base_depth_path):
+ current_output_path = os.path.join(output_path, scene_name)
+ os.mkdir(current_output_path)
+
+ image_path = os.path.join(
+ base_depth_path, scene_name, 'dense0', 'imgs'
+ )
+ if not os.path.exists(image_path):
+ continue
+
+ # Find the maximum image size in scene.
+ max_image_size = 0
+ for image_name in os.listdir(image_path):
+ max_image_size = max(
+ max_image_size,
+ max(imagesize.get(os.path.join(image_path, image_name)))
+ )
+
+ # Undistort the images and update the reconstruction.
+ subprocess.call([
+ os.path.join(args.colmap_path, 'colmap'), 'image_undistorter',
+ '--image_path', os.path.join(sfm_path, scene_name, 'images'),
+ '--input_path', os.path.join(sfm_path, scene_name, 'sparse', 'manhattan', '0'),
+ '--output_path', current_output_path,
+ '--max_image_size', str(max_image_size)
+ ])
+
+ # Transform the reconstruction to raw text format.
+ sparse_txt_path = os.path.join(current_output_path, 'sparse-txt')
+ os.mkdir(sparse_txt_path)
+ subprocess.call([
+ os.path.join(args.colmap_path, 'colmap'), 'model_converter',
+ '--input_path', os.path.join(current_output_path, 'sparse'),
+ '--output_path', sparse_txt_path,
+ '--output_type', 'TXT'
+ ])
\ No newline at end of file
diff --git a/imcui/third_party/d2net/train.py b/imcui/third_party/d2net/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..5817f1712bda0779175fb18437d1f8c263f29f3b
--- /dev/null
+++ b/imcui/third_party/d2net/train.py
@@ -0,0 +1,279 @@
+import argparse
+
+import numpy as np
+
+import os
+
+import shutil
+
+import torch
+import torch.optim as optim
+
+from torch.utils.data import DataLoader
+
+from tqdm import tqdm
+
+import warnings
+
+from lib.dataset import MegaDepthDataset
+from lib.exceptions import NoGradientError
+from lib.loss import loss_function
+from lib.model import D2Net
+
+
+# CUDA
+use_cuda = torch.cuda.is_available()
+device = torch.device("cuda:0" if use_cuda else "cpu")
+
+# Seed
+torch.manual_seed(1)
+if use_cuda:
+ torch.cuda.manual_seed(1)
+np.random.seed(1)
+
+# Argument parsing
+parser = argparse.ArgumentParser(description='Training script')
+
+parser.add_argument(
+ '--dataset_path', type=str, required=True,
+ help='path to the dataset'
+)
+parser.add_argument(
+ '--scene_info_path', type=str, required=True,
+ help='path to the processed scenes'
+)
+
+parser.add_argument(
+ '--preprocessing', type=str, default='caffe',
+ help='image preprocessing (caffe or torch)'
+)
+parser.add_argument(
+ '--model_file', type=str, default='models/d2_ots.pth',
+ help='path to the full model'
+)
+
+parser.add_argument(
+ '--num_epochs', type=int, default=10,
+ help='number of training epochs'
+)
+parser.add_argument(
+ '--lr', type=float, default=1e-3,
+ help='initial learning rate'
+)
+parser.add_argument(
+ '--batch_size', type=int, default=1,
+ help='batch size'
+)
+parser.add_argument(
+ '--num_workers', type=int, default=4,
+ help='number of workers for data loading'
+)
+
+parser.add_argument(
+ '--use_validation', dest='use_validation', action='store_true',
+ help='use the validation split'
+)
+parser.set_defaults(use_validation=False)
+
+parser.add_argument(
+ '--log_interval', type=int, default=250,
+ help='loss logging interval'
+)
+
+parser.add_argument(
+ '--log_file', type=str, default='log.txt',
+ help='loss logging file'
+)
+
+parser.add_argument(
+ '--plot', dest='plot', action='store_true',
+ help='plot training pairs'
+)
+parser.set_defaults(plot=False)
+
+parser.add_argument(
+ '--checkpoint_directory', type=str, default='checkpoints',
+ help='directory for training checkpoints'
+)
+parser.add_argument(
+ '--checkpoint_prefix', type=str, default='d2',
+ help='prefix for training checkpoints'
+)
+
+args = parser.parse_args()
+
+print(args)
+
+# Create the folders for plotting if need be
+if args.plot:
+ plot_path = 'train_vis'
+ if os.path.isdir(plot_path):
+ print('[Warning] Plotting directory already exists.')
+ else:
+ os.mkdir(plot_path)
+
+# Creating CNN model
+model = D2Net(
+ model_file=args.model_file,
+ use_cuda=use_cuda
+)
+
+# Optimizer
+optimizer = optim.Adam(
+ filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr
+)
+
+# Dataset
+if args.use_validation:
+ validation_dataset = MegaDepthDataset(
+ scene_list_path='megadepth_utils/valid_scenes.txt',
+ scene_info_path=args.scene_info_path,
+ base_path=args.dataset_path,
+ train=False,
+ preprocessing=args.preprocessing,
+ pairs_per_scene=25
+ )
+ validation_dataloader = DataLoader(
+ validation_dataset,
+ batch_size=args.batch_size,
+ num_workers=args.num_workers
+ )
+
+training_dataset = MegaDepthDataset(
+ scene_list_path='megadepth_utils/train_scenes.txt',
+ scene_info_path=args.scene_info_path,
+ base_path=args.dataset_path,
+ preprocessing=args.preprocessing
+)
+training_dataloader = DataLoader(
+ training_dataset,
+ batch_size=args.batch_size,
+ num_workers=args.num_workers
+)
+
+
+# Define epoch function
+def process_epoch(
+ epoch_idx,
+ model, loss_function, optimizer, dataloader, device,
+ log_file, args, train=True
+):
+ epoch_losses = []
+
+ torch.set_grad_enabled(train)
+
+ progress_bar = tqdm(enumerate(dataloader), total=len(dataloader))
+ for batch_idx, batch in progress_bar:
+ if train:
+ optimizer.zero_grad()
+
+ batch['train'] = train
+ batch['epoch_idx'] = epoch_idx
+ batch['batch_idx'] = batch_idx
+ batch['batch_size'] = args.batch_size
+ batch['preprocessing'] = args.preprocessing
+ batch['log_interval'] = args.log_interval
+
+ try:
+ loss = loss_function(model, batch, device, plot=args.plot)
+ except NoGradientError:
+ continue
+
+ current_loss = loss.data.cpu().numpy()[0]
+ epoch_losses.append(current_loss)
+
+ progress_bar.set_postfix(loss=('%.4f' % np.mean(epoch_losses)))
+
+ if batch_idx % args.log_interval == 0:
+ log_file.write('[%s] epoch %d - batch %d / %d - avg_loss: %f\n' % (
+ 'train' if train else 'valid',
+ epoch_idx, batch_idx, len(dataloader), np.mean(epoch_losses)
+ ))
+
+ if train:
+ loss.backward()
+ optimizer.step()
+
+ log_file.write('[%s] epoch %d - avg_loss: %f\n' % (
+ 'train' if train else 'valid',
+ epoch_idx,
+ np.mean(epoch_losses)
+ ))
+ log_file.flush()
+
+ return np.mean(epoch_losses)
+
+
+# Create the checkpoint directory
+if os.path.isdir(args.checkpoint_directory):
+ print('[Warning] Checkpoint directory already exists.')
+else:
+ os.mkdir(args.checkpoint_directory)
+
+
+# Open the log file for writing
+if os.path.exists(args.log_file):
+ print('[Warning] Log file already exists.')
+log_file = open(args.log_file, 'a+')
+
+# Initialize the history
+train_loss_history = []
+validation_loss_history = []
+if args.use_validation:
+ validation_dataset.build_dataset()
+ min_validation_loss = process_epoch(
+ 0,
+ model, loss_function, optimizer, validation_dataloader, device,
+ log_file, args,
+ train=False
+ )
+
+# Start the training
+for epoch_idx in range(1, args.num_epochs + 1):
+ # Process epoch
+ training_dataset.build_dataset()
+ train_loss_history.append(
+ process_epoch(
+ epoch_idx,
+ model, loss_function, optimizer, training_dataloader, device,
+ log_file, args
+ )
+ )
+
+ if args.use_validation:
+ validation_loss_history.append(
+ process_epoch(
+ epoch_idx,
+ model, loss_function, optimizer, validation_dataloader, device,
+ log_file, args,
+ train=False
+ )
+ )
+
+ # Save the current checkpoint
+ checkpoint_path = os.path.join(
+ args.checkpoint_directory,
+ '%s.%02d.pth' % (args.checkpoint_prefix, epoch_idx)
+ )
+ checkpoint = {
+ 'args': args,
+ 'epoch_idx': epoch_idx,
+ 'model': model.state_dict(),
+ 'optimizer': optimizer.state_dict(),
+ 'train_loss_history': train_loss_history,
+ 'validation_loss_history': validation_loss_history
+ }
+ torch.save(checkpoint, checkpoint_path)
+ if (
+ args.use_validation and
+ validation_loss_history[-1] < min_validation_loss
+ ):
+ min_validation_loss = validation_loss_history[-1]
+ best_checkpoint_path = os.path.join(
+ args.checkpoint_directory,
+ '%s.best.pth' % args.checkpoint_prefix
+ )
+ shutil.copy(checkpoint_path, best_checkpoint_path)
+
+# Close the log file
+log_file.close()
diff --git a/imcui/third_party/dust3r/croco/datasets/__init__.py b/imcui/third_party/dust3r/croco/datasets/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/imcui/third_party/dust3r/croco/datasets/crops/extract_crops_from_images.py b/imcui/third_party/dust3r/croco/datasets/crops/extract_crops_from_images.py
new file mode 100644
index 0000000000000000000000000000000000000000..eb66a0474ce44b54c44c08887cbafdb045b11ff3
--- /dev/null
+++ b/imcui/third_party/dust3r/croco/datasets/crops/extract_crops_from_images.py
@@ -0,0 +1,159 @@
+# Copyright (C) 2022-present Naver Corporation. All rights reserved.
+# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+#
+# --------------------------------------------------------
+# Extracting crops for pre-training
+# --------------------------------------------------------
+
+import os
+import argparse
+from tqdm import tqdm
+from PIL import Image
+import functools
+from multiprocessing import Pool
+import math
+
+
+def arg_parser():
+ parser = argparse.ArgumentParser('Generate cropped image pairs from image crop list')
+
+ parser.add_argument('--crops', type=str, required=True, help='crop file')
+ parser.add_argument('--root-dir', type=str, required=True, help='root directory')
+ parser.add_argument('--output-dir', type=str, required=True, help='output directory')
+ parser.add_argument('--imsize', type=int, default=256, help='size of the crops')
+ parser.add_argument('--nthread', type=int, required=True, help='number of simultaneous threads')
+ parser.add_argument('--max-subdir-levels', type=int, default=5, help='maximum number of subdirectories')
+ parser.add_argument('--ideal-number-pairs-in-dir', type=int, default=500, help='number of pairs stored in a dir')
+ return parser
+
+
+def main(args):
+ listing_path = os.path.join(args.output_dir, 'listing.txt')
+
+ print(f'Loading list of crops ... ({args.nthread} threads)')
+ crops, num_crops_to_generate = load_crop_file(args.crops)
+
+ print(f'Preparing jobs ({len(crops)} candidate image pairs)...')
+ num_levels = min(math.ceil(math.log(num_crops_to_generate, args.ideal_number_pairs_in_dir)), args.max_subdir_levels)
+ num_pairs_in_dir = math.ceil(num_crops_to_generate ** (1/num_levels))
+
+ jobs = prepare_jobs(crops, num_levels, num_pairs_in_dir)
+ del crops
+
+ os.makedirs(args.output_dir, exist_ok=True)
+ mmap = Pool(args.nthread).imap_unordered if args.nthread > 1 else map
+ call = functools.partial(save_image_crops, args)
+
+ print(f"Generating cropped images to {args.output_dir} ...")
+ with open(listing_path, 'w') as listing:
+ listing.write('# pair_path\n')
+ for results in tqdm(mmap(call, jobs), total=len(jobs)):
+ for path in results:
+ listing.write(f'{path}\n')
+ print('Finished writing listing to', listing_path)
+
+
+def load_crop_file(path):
+ data = open(path).read().splitlines()
+ pairs = []
+ num_crops_to_generate = 0
+ for line in tqdm(data):
+ if line.startswith('#'):
+ continue
+ line = line.split(', ')
+ if len(line) < 8:
+ img1, img2, rotation = line
+ pairs.append((img1, img2, int(rotation), []))
+ else:
+ l1, r1, t1, b1, l2, r2, t2, b2 = map(int, line)
+ rect1, rect2 = (l1, t1, r1, b1), (l2, t2, r2, b2)
+ pairs[-1][-1].append((rect1, rect2))
+ num_crops_to_generate += 1
+ return pairs, num_crops_to_generate
+
+
+def prepare_jobs(pairs, num_levels, num_pairs_in_dir):
+ jobs = []
+ powers = [num_pairs_in_dir**level for level in reversed(range(num_levels))]
+
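+ # Map a flat crop index to a nested hex path: the parent directories encode the quotients by
+ # decreasing powers of num_pairs_in_dir, and the leaf name is the full index in hex.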
+ def get_path(idx):
+ idx_array = []
+ d = idx
+ for level in range(num_levels - 1):
+ idx_array.append(idx // powers[level])
+ idx = idx % powers[level]
+ idx_array.append(d)
+ return '/'.join(map(lambda x: hex(x)[2:], idx_array))
+
+ idx = 0
+ for pair_data in tqdm(pairs):
+ img1, img2, rotation, crops = pair_data
+ if -60 <= rotation <= 60:
+ rotation = 0 # most likely not a true rotation
+ paths = [get_path(idx + k) for k in range(len(crops))]
+ idx += len(crops)
+ jobs.append(((img1, img2), rotation, crops, paths))
+ return jobs
+
+
+def load_image(path):
+ try:
+ return Image.open(path).convert('RGB')
+ except Exception as e:
+ print('skipping', path, e)
+ raise OSError()
+
+
+def save_image_crops(args, data):
+ # load images
+ img_pair, rot, crops, paths = data
+ try:
+ img1, img2 = [load_image(os.path.join(args.root_dir, impath)) for impath in img_pair]
+ except OSError as e:
+ return []
+
+ def area(sz):
+ return sz[0] * sz[1]
+
+ tgt_size = (args.imsize, args.imsize)
+
+ def prepare_crop(img, rect, rot=0):
+ # actual crop
+ img = img.crop(rect)
+
+ # resize to desired size
+ interp = Image.Resampling.LANCZOS if area(img.size) > 4*area(tgt_size) else Image.Resampling.BICUBIC
+ img = img.resize(tgt_size, resample=interp)
+
+ # rotate the image
+ rot90 = (round(rot/90) % 4) * 90
+ if rot90 == 90:
+ img = img.transpose(Image.Transpose.ROTATE_90)
+ elif rot90 == 180:
+ img = img.transpose(Image.Transpose.ROTATE_180)
+ elif rot90 == 270:
+ img = img.transpose(Image.Transpose.ROTATE_270)
+ return img
+
+ results = []
+ for (rect1, rect2), path in zip(crops, paths):
+ crop1 = prepare_crop(img1, rect1)
+ crop2 = prepare_crop(img2, rect2, rot)
+
+ fullpath1 = os.path.join(args.output_dir, path+'_1.jpg')
+ fullpath2 = os.path.join(args.output_dir, path+'_2.jpg')
+ os.makedirs(os.path.dirname(fullpath1), exist_ok=True)
+
+ assert not os.path.isfile(fullpath1), fullpath1
+ assert not os.path.isfile(fullpath2), fullpath2
+ crop1.save(fullpath1)
+ crop2.save(fullpath2)
+ results.append(path)
+
+ return results
+
+
+if __name__ == '__main__':
+ args = arg_parser().parse_args()
+ main(args)
+
diff --git a/imcui/third_party/dust3r/croco/datasets/habitat_sim/__init__.py b/imcui/third_party/dust3r/croco/datasets/habitat_sim/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/imcui/third_party/dust3r/croco/datasets/habitat_sim/generate_from_metadata.py b/imcui/third_party/dust3r/croco/datasets/habitat_sim/generate_from_metadata.py
new file mode 100644
index 0000000000000000000000000000000000000000..fbe0d399084359495250dc8184671ff498adfbf2
--- /dev/null
+++ b/imcui/third_party/dust3r/croco/datasets/habitat_sim/generate_from_metadata.py
@@ -0,0 +1,92 @@
+# Copyright (C) 2022-present Naver Corporation. All rights reserved.
+# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+
+"""
+Script to generate image pairs for a given scene reproducing poses provided in a metadata file.
+"""
+import os
+from datasets.habitat_sim.multiview_habitat_sim_generator import MultiviewHabitatSimGenerator
+from datasets.habitat_sim.paths import SCENES_DATASET
+import argparse
+import quaternion
+import PIL.Image
+import cv2
+import json
+from tqdm import tqdm
+
+def generate_multiview_images_from_metadata(metadata_filename,
+ output_dir,
+ overload_params = dict(),
+ scene_datasets_paths=None,
+ exist_ok=False):
+ """
+ Generate images from a metadata file for reproducibility purposes.
+ """
+ # Reorder paths by decreasing label length, to avoid collisions when testing whether a path starts with a given dataset label
+ if scene_datasets_paths is not None:
+ scene_datasets_paths = dict(sorted(scene_datasets_paths.items(), key= lambda x: len(x[0]), reverse=True))
+
+ with open(metadata_filename, 'r') as f:
+ input_metadata = json.load(f)
+ metadata = dict()
+ for key, value in input_metadata.items():
+ # Optionally replace some paths
+ if key in ("scene_dataset_config_file", "scene", "navmesh") and value != "":
+ if scene_datasets_paths is not None:
+ for dataset_label, dataset_path in scene_datasets_paths.items():
+ if value.startswith(dataset_label):
+ value = os.path.normpath(os.path.join(dataset_path, os.path.relpath(value, dataset_label)))
+ break
+ metadata[key] = value
+
+ # Overload some parameters
+ for key, value in overload_params.items():
+ metadata[key] = value
+
+ generation_entries = dict([(key, value) for key, value in metadata.items() if not (key in ('multiviews', 'output_dir', 'generate_depth'))])
+ generate_depth = metadata["generate_depth"]
+
+ os.makedirs(output_dir, exist_ok=exist_ok)
+
+ generator = MultiviewHabitatSimGenerator(**generation_entries)
+
+ # Generate views
+ for idx_label, data in tqdm(metadata['multiviews'].items()):
+ positions = data["positions"]
+ orientations = data["orientations"]
+ n = len(positions)
+ for oidx in range(n):
+ observation = generator.render_viewpoint(positions[oidx], quaternion.from_float_array(orientations[oidx]))
+ observation_label = f"{oidx + 1}" # Leonid is indexing starting from 1
+ # Color image saved using PIL
+ img = PIL.Image.fromarray(observation['color'][:,:,:3])
+ filename = os.path.join(output_dir, f"{idx_label}_{observation_label}.jpeg")
+ img.save(filename)
+ if generate_depth:
+ # Depth image as EXR file
+ filename = os.path.join(output_dir, f"{idx_label}_{observation_label}_depth.exr")
+ cv2.imwrite(filename, observation['depth'], [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_HALF])
+ # Camera parameters
+ camera_params = dict([(key, observation[key].tolist()) for key in ("camera_intrinsics", "R_cam2world", "t_cam2world")])
+ filename = os.path.join(output_dir, f"{idx_label}_{observation_label}_camera_params.json")
+ with open(filename, "w") as f:
+ json.dump(camera_params, f)
+ # Save metadata
+ with open(os.path.join(output_dir, "metadata.json"), "w") as f:
+ json.dump(metadata, f)
+
+ generator.close()
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--metadata_filename", required=True)
+ parser.add_argument("--output_dir", required=True)
+ args = parser.parse_args()
+
+ generate_multiview_images_from_metadata(metadata_filename=args.metadata_filename,
+ output_dir=args.output_dir,
+ scene_datasets_paths=SCENES_DATASET,
+ overload_params=dict(),
+ exist_ok=True)
+
+
\ No newline at end of file
diff --git a/imcui/third_party/dust3r/croco/datasets/habitat_sim/generate_from_metadata_files.py b/imcui/third_party/dust3r/croco/datasets/habitat_sim/generate_from_metadata_files.py
new file mode 100644
index 0000000000000000000000000000000000000000..962ef849d8c31397b8622df4f2d9140175d78873
--- /dev/null
+++ b/imcui/third_party/dust3r/croco/datasets/habitat_sim/generate_from_metadata_files.py
@@ -0,0 +1,27 @@
+# Copyright (C) 2022-present Naver Corporation. All rights reserved.
+# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+
+"""
+Script generating command lines to generate image pairs from metadata files.
+"""
+import os
+import glob
+from tqdm import tqdm
+import argparse
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--input_dir", required=True)
+ parser.add_argument("--output_dir", required=True)
+ parser.add_argument("--prefix", default="", help="Commanline prefix, useful e.g. to setup environment.")
+ args = parser.parse_args()
+
+ input_metadata_filenames = glob.iglob(f"{args.input_dir}/**/metadata.json", recursive=True)
+
+ for metadata_filename in tqdm(input_metadata_filenames):
+ output_dir = os.path.join(args.output_dir, os.path.relpath(os.path.dirname(metadata_filename), args.input_dir))
+ # Skip the scene if its output metadata file already exists (it has already been processed)
+ if os.path.exists(os.path.join(output_dir, "metadata.json")):
+ continue
+ commandline = f"{args.prefix}python datasets/habitat_sim/generate_from_metadata.py --metadata_filename={metadata_filename} --output_dir={output_dir}"
+ print(commandline)
diff --git a/imcui/third_party/dust3r/croco/datasets/habitat_sim/generate_multiview_images.py b/imcui/third_party/dust3r/croco/datasets/habitat_sim/generate_multiview_images.py
new file mode 100644
index 0000000000000000000000000000000000000000..421d49a1696474415940493296b3f2d982398850
--- /dev/null
+++ b/imcui/third_party/dust3r/croco/datasets/habitat_sim/generate_multiview_images.py
@@ -0,0 +1,177 @@
+# Copyright (C) 2022-present Naver Corporation. All rights reserved.
+# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+
+import os
+from tqdm import tqdm
+import argparse
+import PIL.Image
+import numpy as np
+import json
+from datasets.habitat_sim.multiview_habitat_sim_generator import MultiviewHabitatSimGenerator, NoNaviguableSpaceError
+from datasets.habitat_sim.paths import list_scenes_available
+import cv2
+import quaternion
+import shutil
+
+def generate_multiview_images_for_scene(scene_dataset_config_file,
+ scene,
+ navmesh,
+ output_dir,
+ views_count,
+ size,
+ exist_ok=False,
+ generate_depth=False,
+ **kwargs):
+ """
+ Generate tuples of overlapping views for a given scene.
+ generate_depth: generate depth images and camera parameters.
+ """
+ if os.path.exists(output_dir) and not exist_ok:
+ print(f"Scene {scene}: data already generated. Ignoring generation.")
+ return
+ try:
+ print(f"Scene {scene}: {size} multiview acquisitions to generate...")
+ os.makedirs(output_dir, exist_ok=exist_ok)
+
+ metadata_filename = os.path.join(output_dir, "metadata.json")
+
+ metadata_template = dict(scene_dataset_config_file=scene_dataset_config_file,
+ scene=scene,
+ navmesh=navmesh,
+ views_count=views_count,
+ size=size,
+ generate_depth=generate_depth,
+ **kwargs)
+ metadata_template["multiviews"] = dict()
+
+ if os.path.exists(metadata_filename):
+ print("Metadata file already exists:", metadata_filename)
+ print("Loading already generated metadata file...")
+ with open(metadata_filename, "r") as f:
+ metadata = json.load(f)
+
+ for key in metadata_template.keys():
+ if key != "multiviews":
+ assert metadata_template[key] == metadata[key], f"existing file is inconsistent with the input parameters:\nKey: {key}\nmetadata: {metadata[key]}\ntemplate: {metadata_template[key]}."
+ else:
+ print("No temporary file found. Starting generation from scratch...")
+ metadata = metadata_template
+
+ starting_id = len(metadata["multiviews"])
+ print(f"Starting generation from index {starting_id}/{size}...")
+ if starting_id >= size:
+ print("Generation already done.")
+ return
+
+ generator = MultiviewHabitatSimGenerator(scene_dataset_config_file=scene_dataset_config_file,
+ scene=scene,
+ navmesh=navmesh,
+ views_count = views_count,
+ size = size,
+ **kwargs)
+
+ for idx in tqdm(range(starting_id, size)):
+ # Generate / re-generate the observations
+ try:
+ data = generator[idx]
+ observations = data["observations"]
+ positions = data["positions"]
+ orientations = data["orientations"]
+
+ idx_label = f"{idx:08}"
+ for oidx, observation in enumerate(observations):
+ observation_label = f"{oidx + 1}" # observation indices start at 1
+ # Color image saved using PIL
+ img = PIL.Image.fromarray(observation['color'][:,:,:3])
+ filename = os.path.join(output_dir, f"{idx_label}_{observation_label}.jpeg")
+ img.save(filename)
+ if generate_depth:
+ # Depth image as EXR file
+ filename = os.path.join(output_dir, f"{idx_label}_{observation_label}_depth.exr")
+ cv2.imwrite(filename, observation['depth'], [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_HALF])
+ # Camera parameters
+ camera_params = dict([(key, observation[key].tolist()) for key in ("camera_intrinsics", "R_cam2world", "t_cam2world")])
+ filename = os.path.join(output_dir, f"{idx_label}_{observation_label}_camera_params.json")
+ with open(filename, "w") as f:
+ json.dump(camera_params, f)
+ metadata["multiviews"][idx_label] = {"positions": positions.tolist(),
+ "orientations": orientations.tolist(),
+ "covisibility_ratios": data["covisibility_ratios"].tolist(),
+ "valid_fractions": data["valid_fractions"].tolist(),
+ "pairwise_visibility_ratios": data["pairwise_visibility_ratios"].tolist()}
+ except RecursionError:
+ print("Recursion error: unable to sample observations for this scene. We will stop there.")
+ break
+
+ # Regularly save a temporary metadata file, in case we need to restart the generation
+ if idx % 10 == 0:
+ with open(metadata_filename, "w") as f:
+ json.dump(metadata, f)
+
+ # Save metadata
+ with open(metadata_filename, "w") as f:
+ json.dump(metadata, f)
+
+ generator.close()
+ except NoNaviguableSpaceError:
+ pass
+
+def create_commandline(scene_data, generate_depth, exist_ok=False):
+ """
+ Create a commandline string to generate a scene.
+ """
+ def my_formatting(val):
+ if val is None or val == "":
+ return '""'
+ else:
+ return val
+ commandline = f"""python {__file__} --scene {my_formatting(scene_data.scene)}
+ --scene_dataset_config_file {my_formatting(scene_data.scene_dataset_config_file)}
+ --navmesh {my_formatting(scene_data.navmesh)}
+ --output_dir {my_formatting(scene_data.output_dir)}
+ --generate_depth {int(generate_depth)}
+ --exist_ok {int(exist_ok)}
+ """
+ commandline = " ".join(commandline.split())
+ return commandline
+
+if __name__ == "__main__":
+ os.umask(2)
+
+ parser = argparse.ArgumentParser(description="""Example of use -- listing commands to generate data for the available scenes:
+ > python datasets/habitat_sim/generate_multiview_images.py --list_commands
+ """)
+
+ parser.add_argument("--output_dir", type=str, required=True)
+ parser.add_argument("--list_commands", action='store_true', help="list commandlines to run if true")
+ parser.add_argument("--scene", type=str, default="")
+ parser.add_argument("--scene_dataset_config_file", type=str, default="")
+ parser.add_argument("--navmesh", type=str, default="")
+
+ parser.add_argument("--generate_depth", type=int, default=1)
+ parser.add_argument("--exist_ok", type=int, default=0)
+
+ kwargs = dict(resolution=(256,256), hfov=60, views_count = 2, size=1000)
+
+ args = parser.parse_args()
+ generate_depth=bool(args.generate_depth)
+ exist_ok = bool(args.exist_ok)
+
+ if args.list_commands:
+ # Listing scenes available...
+ scenes_data = list_scenes_available(base_output_dir=args.output_dir)
+
+ for scene_data in scenes_data:
+ print(create_commandline(scene_data, generate_depth=generate_depth, exist_ok=exist_ok))
+ else:
+ if args.scene == "" or args.output_dir == "":
+ print("Missing scene or output dir argument!")
+ print(parser.format_help())
+ else:
+ generate_multiview_images_for_scene(scene=args.scene,
+ scene_dataset_config_file = args.scene_dataset_config_file,
+ navmesh = args.navmesh,
+ output_dir = args.output_dir,
+ exist_ok=exist_ok,
+ generate_depth=generate_depth,
+ **kwargs)
\ No newline at end of file
diff --git a/imcui/third_party/dust3r/croco/datasets/habitat_sim/multiview_habitat_sim_generator.py b/imcui/third_party/dust3r/croco/datasets/habitat_sim/multiview_habitat_sim_generator.py
new file mode 100644
index 0000000000000000000000000000000000000000..91e5f923b836a645caf5d8e4aacc425047e3c144
--- /dev/null
+++ b/imcui/third_party/dust3r/croco/datasets/habitat_sim/multiview_habitat_sim_generator.py
@@ -0,0 +1,390 @@
+# Copyright (C) 2022-present Naver Corporation. All rights reserved.
+# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+
+import os
+import numpy as np
+import quaternion
+import habitat_sim
+import json
+from sklearn.neighbors import NearestNeighbors
+import cv2
+
+# OpenCV to habitat camera convention transformation
+R_OPENCV2HABITAT = np.stack((habitat_sim.geo.RIGHT, -habitat_sim.geo.UP, habitat_sim.geo.FRONT), axis=0)
+R_HABITAT2OPENCV = R_OPENCV2HABITAT.T
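+# The columns of R_OPENCV2HABITAT are the OpenCV camera axes (x right, y down, z forward)
+# expressed in the Habitat camera frame (x right, y up, camera looking along -z).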
+DEG2RAD = np.pi / 180
+
+def compute_camera_intrinsics(height, width, hfov):
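+ # Pinhole model with square pixels and principal point at the image center;
+ # e.g. width=256 and hfov=90 deg give f = 128 / tan(45 deg) = 128 pixels.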
+ f = width/2 / np.tan(hfov/2 * np.pi/180)
+ cu, cv = width/2, height/2
+ return f, cu, cv
+
+def compute_camera_pose_opencv_convention(camera_position, camera_orientation):
+ R_cam2world = quaternion.as_rotation_matrix(camera_orientation) @ R_OPENCV2HABITAT
+ t_cam2world = np.asarray(camera_position)
+ return R_cam2world, t_cam2world
+
+def compute_pointmap(depthmap, hfov):
+ """ Compute a HxWx3 pointmap in camera frame from a HxW depth map."""
+ height, width = depthmap.shape
+ f, cu, cv = compute_camera_intrinsics(height, width, hfov)
+ # Cast depth map to point
+ z_cam = depthmap
+ u, v = np.meshgrid(range(width), range(height))
+ x_cam = (u - cu) / f * z_cam
+ y_cam = (v - cv) / f * z_cam
+ X_cam = np.stack((x_cam, y_cam, z_cam), axis=-1)
+ return X_cam
+
+def compute_pointcloud(depthmap, hfov, camera_position, camera_rotation):
+ """Return a 3D point cloud corresponding to valid pixels of the depth map"""
+ R_cam2world, t_cam2world = compute_camera_pose_opencv_convention(camera_position, camera_rotation)
+
+ X_cam = compute_pointmap(depthmap=depthmap, hfov=hfov)
+ valid_mask = (X_cam[:,:,2] != 0.0)
+
+ X_cam = X_cam.reshape(-1, 3)[valid_mask.flatten()]
+ X_world = X_cam @ R_cam2world.T + t_cam2world.reshape(1, 3)
+ return X_world
+
+def compute_pointcloud_overlaps_scikit(pointcloud1, pointcloud2, distance_threshold, compute_symmetric=False):
+ """
+ Compute 'overlapping' metrics based on a distance threshold between two point clouds.
+ """
+ nbrs = NearestNeighbors(n_neighbors=1, algorithm = 'kd_tree').fit(pointcloud2)
+ distances, indices = nbrs.kneighbors(pointcloud1)
+ intersection1 = np.count_nonzero(distances.flatten() < distance_threshold)
+
+ data = {"intersection1": intersection1,
+ "size1": len(pointcloud1)}
+ if compute_symmetric:
+ nbrs = NearestNeighbors(n_neighbors=1, algorithm = 'kd_tree').fit(pointcloud1)
+ distances, indices = nbrs.kneighbors(pointcloud2)
+ intersection2 = np.count_nonzero(distances.flatten() < distance_threshold)
+ data["intersection2"] = intersection2
+ data["size2"] = len(pointcloud2)
+
+ return data
+
+def _append_camera_parameters(observation, hfov, camera_location, camera_rotation):
+ """
+ Add camera parameters to the observation dictionary produced by Habitat-Sim.
+ In-place modifications.
+ """
+ R_cam2world, t_cam2world = compute_camera_pose_opencv_convention(camera_location, camera_rotation)
+ height, width = observation['depth'].shape
+ f, cu, cv = compute_camera_intrinsics(height, width, hfov)
+ K = np.asarray([[f, 0, cu],
+ [0, f, cv],
+ [0, 0, 1.0]])
+ observation["camera_intrinsics"] = K
+ observation["t_cam2world"] = t_cam2world
+ observation["R_cam2world"] = R_cam2world
+
+def look_at(eye, center, up, return_cam2world=True):
+ """
+ Return camera pose looking at a given center point.
+ Analogous to the gluLookAt function, using the OpenCV camera convention.
+ """
+ z = center - eye
+ z /= np.linalg.norm(z, axis=-1, keepdims=True)
+ y = -up
+ y = y - np.sum(y * z, axis=-1, keepdims=True) * z
+ y /= np.linalg.norm(y, axis=-1, keepdims=True)
+ x = np.cross(y, z, axis=-1)
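+ # (x, y, z) forms a right-handed orthonormal basis in the OpenCV convention:
+ # z points from eye towards center, y points downwards, and x = cross(y, z) points to the right.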
+
+ if return_cam2world:
+ R = np.stack((x, y, z), axis=-1)
+ t = eye
+ else:
+ # World to camera transformation
+ # Transposed matrix
+ R = np.stack((x, y, z), axis=-2)
+ t = - np.einsum('...ij, ...j', R, eye)
+ return R, t
+
+def look_at_for_habitat(eye, center, up, return_cam2world=True):
+ R, t = look_at(eye, center, up)
+ orientation = quaternion.from_rotation_matrix(R @ R_OPENCV2HABITAT.T)
+ return orientation, t
+
+def generate_orientation_noise(pan_range, tilt_range, roll_range):
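+ # Small random rotation: pan about the UP axis, then tilt about RIGHT, then roll about FRONT,
+ # with angles (in degrees) drawn uniformly from the given ranges.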
+ return (quaternion.from_rotation_vector(np.random.uniform(*pan_range) * DEG2RAD * habitat_sim.geo.UP)
+ * quaternion.from_rotation_vector(np.random.uniform(*tilt_range) * DEG2RAD * habitat_sim.geo.RIGHT)
+ * quaternion.from_rotation_vector(np.random.uniform(*roll_range) * DEG2RAD * habitat_sim.geo.FRONT))
+
+
+class NoNaviguableSpaceError(RuntimeError):
+ def __init__(self, *args):
+ super().__init__(*args)
+
+class MultiviewHabitatSimGenerator:
+ def __init__(self,
+ scene,
+ navmesh,
+ scene_dataset_config_file,
+ resolution = (240, 320),
+ views_count=2,
+ hfov = 60,
+ gpu_id = 0,
+ size = 10000,
+ minimum_covisibility = 0.5,
+ transform = None):
+ self.scene = scene
+ self.navmesh = navmesh
+ self.scene_dataset_config_file = scene_dataset_config_file
+ self.resolution = resolution
+ self.views_count = views_count
+ assert(self.views_count >= 1)
+ self.hfov = hfov
+ self.gpu_id = gpu_id
+ self.size = size
+ self.transform = transform
+
+ # Noise added to camera orientation
+ self.pan_range = (-3, 3)
+ self.tilt_range = (-10, 10)
+ self.roll_range = (-5, 5)
+
+ # Height range to sample cameras
+ self.height_range = (1.2, 1.8)
+
+ # Random steps between the camera views
+ self.random_steps_count = 5
+ self.random_step_variance = 2.0
+
+ # Minimum fraction of the scene which should be valid (well defined depth)
+ self.minimum_valid_fraction = 0.7
+
+ # Distance threshold used to select overlapping pairs
+ self.distance_threshold = 0.05
+ # Minimum IoU of a viewpoint's point cloud with respect to the reference view for the viewpoint to be kept.
+ self.minimum_covisibility = minimum_covisibility
+
+ # Maximum number of retries.
+ self.max_attempts_count = 100
+
+ self.seed = None
+ self._lazy_initialization()
+
+ def _lazy_initialization(self):
+ # Lazy random seeding and instantiation of the simulator to deal with multiprocessing properly
+ if self.seed is None:
+ # Re-seed numpy generator
+ np.random.seed()
+ self.seed = np.random.randint(2**32-1)
+ sim_cfg = habitat_sim.SimulatorConfiguration()
+ sim_cfg.scene_id = self.scene
+ if self.scene_dataset_config_file is not None and self.scene_dataset_config_file != "":
+ sim_cfg.scene_dataset_config_file = self.scene_dataset_config_file
+ sim_cfg.random_seed = self.seed
+ sim_cfg.load_semantic_mesh = False
+ sim_cfg.gpu_device_id = self.gpu_id
+
+ depth_sensor_spec = habitat_sim.CameraSensorSpec()
+ depth_sensor_spec.uuid = "depth"
+ depth_sensor_spec.sensor_type = habitat_sim.SensorType.DEPTH
+ depth_sensor_spec.resolution = self.resolution
+ depth_sensor_spec.hfov = self.hfov
+ depth_sensor_spec.position = [0.0, 0.0, 0]
+
+ rgb_sensor_spec = habitat_sim.CameraSensorSpec()
+ rgb_sensor_spec.uuid = "color"
+ rgb_sensor_spec.sensor_type = habitat_sim.SensorType.COLOR
+ rgb_sensor_spec.resolution = self.resolution
+ rgb_sensor_spec.hfov = self.hfov
+ rgb_sensor_spec.position = [0.0, 0.0, 0]
+ agent_cfg = habitat_sim.agent.AgentConfiguration(sensor_specifications=[rgb_sensor_spec, depth_sensor_spec])
+
+ cfg = habitat_sim.Configuration(sim_cfg, [agent_cfg])
+ self.sim = habitat_sim.Simulator(cfg)
+ if self.navmesh is not None and self.navmesh != "":
+ # Use pre-computed navmesh when available (usually better than those generated automatically)
+ self.sim.pathfinder.load_nav_mesh(self.navmesh)
+
+ if not self.sim.pathfinder.is_loaded:
+ # Try to compute a navmesh
+ navmesh_settings = habitat_sim.NavMeshSettings()
+ navmesh_settings.set_defaults()
+ self.sim.recompute_navmesh(self.sim.pathfinder, navmesh_settings, True)
+
+ # Ensure that the navmesh is not empty
+ if not self.sim.pathfinder.is_loaded:
+ raise NoNaviguableSpaceError(f"No navigable location (scene: {self.scene} -- navmesh: {self.navmesh})")
+
+ self.agent = self.sim.initialize_agent(agent_id=0)
+
+ def close(self):
+ self.sim.close()
+
+ def __del__(self):
+ self.sim.close()
+
+ def __len__(self):
+ return self.size
+
+ def sample_random_viewpoint(self):
+ """ Sample a random viewpoint using the navmesh """
+ nav_point = self.sim.pathfinder.get_random_navigable_point()
+
+ # Sample a random viewpoint height
+ viewpoint_height = np.random.uniform(*self.height_range)
+ viewpoint_position = nav_point + viewpoint_height * habitat_sim.geo.UP
+ viewpoint_orientation = quaternion.from_rotation_vector(np.random.uniform(0, 2 * np.pi) * habitat_sim.geo.UP) * generate_orientation_noise(self.pan_range, self.tilt_range, self.roll_range)
+ return viewpoint_position, viewpoint_orientation, nav_point
+
+ def sample_other_random_viewpoint(self, observed_point, nav_point):
+ """ Sample a random viewpoint close to an existing one, using the navmesh and a reference observed point."""
+ other_nav_point = nav_point
+
+ walk_directions = self.random_step_variance * np.asarray([1,0,1])
+ for i in range(self.random_steps_count):
+ temp = self.sim.pathfinder.snap_point(other_nav_point + walk_directions * np.random.normal(size=3))
+ # Snapping may return nan when it fails
+ if not np.isnan(temp[0]):
+ other_nav_point = temp
+
+ other_viewpoint_height = np.random.uniform(*self.height_range)
+ other_viewpoint_position = other_nav_point + other_viewpoint_height * habitat_sim.geo.UP
+
+ # Set viewing direction towards the central point
+ rotation, position = look_at_for_habitat(eye=other_viewpoint_position, center=observed_point, up=habitat_sim.geo.UP, return_cam2world=True)
+ rotation = rotation * generate_orientation_noise(self.pan_range, self.tilt_range, self.roll_range)
+ return position, rotation, other_nav_point
+
+ def is_other_pointcloud_overlapping(self, ref_pointcloud, other_pointcloud):
+ """ Check if a viewpoint is valid and overlaps significantly with a reference one. """
+ # Observation
+ pixels_count = self.resolution[0] * self.resolution[1]
+ valid_fraction = len(other_pointcloud) / pixels_count
+ assert valid_fraction <= 1.0 and valid_fraction >= 0.0
+ overlap = compute_pointcloud_overlaps_scikit(ref_pointcloud, other_pointcloud, self.distance_threshold, compute_symmetric=True)
+ covisibility = min(overlap["intersection1"] / pixels_count, overlap["intersection2"] / pixels_count)
+ is_valid = (valid_fraction >= self.minimum_valid_fraction) and (covisibility >= self.minimum_covisibility)
+ return is_valid, valid_fraction, covisibility
+
+ def is_other_viewpoint_overlapping(self, ref_pointcloud, observation, position, rotation):
+ """ Check if a viewpoint is valid and overlaps significantly with a reference one. """
+ # Observation
+ other_pointcloud = compute_pointcloud(observation['depth'], self.hfov, position, rotation)
+ return self.is_other_pointcloud_overlapping(ref_pointcloud, other_pointcloud)
+
+ def render_viewpoint(self, viewpoint_position, viewpoint_orientation):
+ agent_state = habitat_sim.AgentState()
+ agent_state.position = viewpoint_position
+ agent_state.rotation = viewpoint_orientation
+ self.agent.set_state(agent_state)
+ viewpoint_observations = self.sim.get_sensor_observations(agent_ids=0)
+ _append_camera_parameters(viewpoint_observations, self.hfov, viewpoint_position, viewpoint_orientation)
+ return viewpoint_observations
+
+ def __getitem__(self, useless_idx):
+ ref_position, ref_orientation, nav_point = self.sample_random_viewpoint()
+ ref_observations = self.render_viewpoint(ref_position, ref_orientation)
+ # Extract point cloud
+ ref_pointcloud = compute_pointcloud(depthmap=ref_observations['depth'], hfov=self.hfov,
+ camera_position=ref_position, camera_rotation=ref_orientation)
+
+ pixels_count = self.resolution[0] * self.resolution[1]
+ ref_valid_fraction = len(ref_pointcloud) / pixels_count
+ assert ref_valid_fraction <= 1.0 and ref_valid_fraction >= 0.0
+ if ref_valid_fraction < self.minimum_valid_fraction:
+ # This should produce a recursion error at some point when something is very wrong.
+ return self[0]
+ # Pick a reference observed point in the point cloud
+ observed_point = np.mean(ref_pointcloud, axis=0)
+
+ # Add the first image as reference
+ viewpoints_observations = [ref_observations]
+ viewpoints_covisibility = [ref_valid_fraction]
+ viewpoints_positions = [ref_position]
+ viewpoints_orientations = [quaternion.as_float_array(ref_orientation)]
+ viewpoints_clouds = [ref_pointcloud]
+ viewpoints_valid_fractions = [ref_valid_fraction]
+
+ for _ in range(self.views_count - 1):
+ # Generate another viewpoint using a dummy random walk
+ successful_sampling = False
+ for sampling_attempt in range(self.max_attempts_count):
+ position, rotation, _ = self.sample_other_random_viewpoint(observed_point, nav_point)
+ # Observation
+ other_viewpoint_observations = self.render_viewpoint(position, rotation)
+ other_pointcloud = compute_pointcloud(other_viewpoint_observations['depth'], self.hfov, position, rotation)
+
+ is_valid, valid_fraction, covisibility = self.is_other_pointcloud_overlapping(ref_pointcloud, other_pointcloud)
+ if is_valid:
+ successful_sampling = True
+ break
+ if not successful_sampling:
+ print("WARNING: Maximum number of attempts reached.")
+ # Dirty hack: start over with a new reference viewpoint
+ return self[0]
+ viewpoints_observations.append(other_viewpoint_observations)
+ viewpoints_covisibility.append(covisibility)
+ viewpoints_positions.append(position)
+ viewpoints_orientations.append(quaternion.as_float_array(rotation)) # WXYZ convention for the quaternion encoding.
+ viewpoints_clouds.append(other_pointcloud)
+ viewpoints_valid_fractions.append(valid_fraction)
+
+ # Estimate relations between all pairs of images
+ pairwise_visibility_ratios = np.ones((len(viewpoints_observations), len(viewpoints_observations)))
+ for i in range(len(viewpoints_observations)):
+ pairwise_visibility_ratios[i,i] = viewpoints_valid_fractions[i]
+ for j in range(i+1, len(viewpoints_observations)):
+ overlap = compute_pointcloud_overlaps_scikit(viewpoints_clouds[i], viewpoints_clouds[j], self.distance_threshold, compute_symmetric=True)
+ pairwise_visibility_ratios[i,j] = overlap['intersection1'] / pixels_count
+ pairwise_visibility_ratios[j,i] = overlap['intersection2'] / pixels_count
+
+ # IoU is relative to the image 0
+ data = {"observations": viewpoints_observations,
+ "positions": np.asarray(viewpoints_positions),
+ "orientations": np.asarray(viewpoints_orientations),
+ "covisibility_ratios": np.asarray(viewpoints_covisibility),
+ "valid_fractions": np.asarray(viewpoints_valid_fractions, dtype=float),
+ "pairwise_visibility_ratios": np.asarray(pairwise_visibility_ratios, dtype=float),
+ }
+
+ if self.transform is not None:
+ data = self.transform(data)
+ return data
+
+ def generate_random_spiral_trajectory(self, images_count = 100, max_radius=0.5, half_turns=5, use_constant_orientation=False):
+ """
+ Return a list of images corresponding to a spiral trajectory from a random starting point.
+ Useful to generate nice visualisations.
+ Use an even number of half turns to get a nice "C1-continuous" loop effect
+ """
+ ref_position, ref_orientation, navpoint = self.sample_random_viewpoint()
+ ref_observations = self.render_viewpoint(ref_position, ref_orientation)
+ ref_pointcloud = compute_pointcloud(depthmap=ref_observations['depth'], hfov=self.hfov,
+ camera_position=ref_position, camera_rotation=ref_orientation)
+ pixels_count = self.resolution[0] * self.resolution[1]
+ if len(ref_pointcloud) / pixels_count < self.minimum_valid_fraction:
+ # Dirty hack: ensure that the valid part of the image is significant
+ return self.generate_random_spiral_trajectory(images_count, max_radius, half_turns, use_constant_orientation)
+
+ # Pick an observed point in the point cloud
+ observed_point = np.mean(ref_pointcloud, axis=0)
+ ref_R, ref_t = compute_camera_pose_opencv_convention(ref_position, ref_orientation)
+
+ images = []
+ is_valid = []
+ # Spiral trajectory, optionally with a constant orientation
+ for i, alpha in enumerate(np.linspace(0, 1, images_count)):
+ r = max_radius * np.abs(np.sin(alpha * np.pi)) # Increase then decrease the radius
+ theta = alpha * half_turns * np.pi
+ x = r * np.cos(theta)
+ y = r * np.sin(theta)
+ z = 0.0
+ position = ref_position + (ref_R @ np.asarray([x, y, z]).reshape(3,1)).flatten()
+ if use_constant_orientation:
+ orientation = ref_orientation
+ else:
+ # trajectory looking at a mean point in front of the ref observation
+ orientation, position = look_at_for_habitat(eye=position, center=observed_point, up=habitat_sim.geo.UP)
+ observations = self.render_viewpoint(position, orientation)
+ images.append(observations['color'][...,:3])
+ _is_valid, valid_fraction, iou = self.is_other_viewpoint_overlapping(ref_pointcloud, observations, position, orientation)
+ is_valid.append(_is_valid)
+ return images, np.all(is_valid)
\ No newline at end of file
diff --git a/imcui/third_party/dust3r/croco/datasets/habitat_sim/pack_metadata_files.py b/imcui/third_party/dust3r/croco/datasets/habitat_sim/pack_metadata_files.py
new file mode 100644
index 0000000000000000000000000000000000000000..10672a01f7dd615d3b4df37781f7f6f97e753ba6
--- /dev/null
+++ b/imcui/third_party/dust3r/croco/datasets/habitat_sim/pack_metadata_files.py
@@ -0,0 +1,69 @@
+# Copyright (C) 2022-present Naver Corporation. All rights reserved.
+# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+"""
+Utility script to pack metadata files of the dataset in order to be able to re-generate it elsewhere.
+"""
+import os
+import glob
+from tqdm import tqdm
+import shutil
+import json
+from datasets.habitat_sim.paths import *
+import argparse
+import collections
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("input_dir")
+ parser.add_argument("output_dir")
+ args = parser.parse_args()
+
+ input_dirname = args.input_dir
+ output_dirname = args.output_dir
+
+ input_metadata_filenames = glob.iglob(f"{input_dirname}/**/metadata.json", recursive=True)
+
+ images_count = collections.defaultdict(lambda : 0)
+
+ os.makedirs(output_dirname)
+ for input_filename in tqdm(input_metadata_filenames):
+ # Skip metadata files without any generated views
+ with open(input_filename, "r") as f:
+ original_metadata = json.load(f)
+ if "multiviews" not in original_metadata or len(original_metadata["multiviews"]) == 0:
+ print("No views in", input_filename)
+ continue
+
+ relpath = os.path.relpath(input_filename, input_dirname)
+ print(relpath)
+
+ # Copy metadata, while replacing scene paths by generic keys depending on the dataset, for portability.
+ # Data paths are sorted by decreasing length to avoid potential bugs due to paths starting with the same string pattern.
+ scenes_dataset_paths = dict(sorted(SCENES_DATASET.items(), key=lambda x: len(x[1]), reverse=True))
+ metadata = dict()
+ for key, value in original_metadata.items():
+ if key in ("scene_dataset_config_file", "scene", "navmesh") and value != "":
+ known_path = False
+ for dataset, dataset_path in scenes_dataset_paths.items():
+ if value.startswith(dataset_path):
+ value = os.path.join(dataset, os.path.relpath(value, dataset_path))
+ known_path = True
+ break
+ if not known_path:
+ raise KeyError("Unknown path: " + value)
+ metadata[key] = value
+
+ # Compile some general statistics while packing data
+ scene_split = metadata["scene"].split("/")
+ upper_level = "/".join(scene_split[:2]) if scene_split[0] == "hm3d" else scene_split[0]
+ images_count[upper_level] += len(metadata["multiviews"])
+
+ output_filename = os.path.join(output_dirname, relpath)
+ os.makedirs(os.path.dirname(output_filename), exist_ok=True)
+ with open(output_filename, "w") as f:
+ json.dump(metadata, f)
+
+ # Print statistics
+ print("Images count:")
+ for upper_level, count in images_count.items():
+ print(f"- {upper_level}: {count}")
\ No newline at end of file
diff --git a/imcui/third_party/dust3r/croco/datasets/habitat_sim/paths.py b/imcui/third_party/dust3r/croco/datasets/habitat_sim/paths.py
new file mode 100644
index 0000000000000000000000000000000000000000..4d63b5fa29c274ddfeae084734a35ba66d7edee8
--- /dev/null
+++ b/imcui/third_party/dust3r/croco/datasets/habitat_sim/paths.py
@@ -0,0 +1,129 @@
+# Copyright (C) 2022-present Naver Corporation. All rights reserved.
+# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+
+"""
+Paths to Habitat-Sim scenes
+"""
+
+import os
+import json
+import collections
+from tqdm import tqdm
+
+
+# Hardcoded path to the different scene datasets
+SCENES_DATASET = {
+ "hm3d": "./data/habitat-sim-data/scene_datasets/hm3d/",
+ "gibson": "./data/habitat-sim-data/scene_datasets/gibson/",
+ "habitat-test-scenes": "./data/habitat-sim/scene_datasets/habitat-test-scenes/",
+ "replica_cad_baked_lighting": "./data/habitat-sim/scene_datasets/replica_cad_baked_lighting/",
+ "replica_cad": "./data/habitat-sim/scene_datasets/replica_cad/",
+ "replica": "./data/habitat-sim/scene_datasets/ReplicaDataset/",
+ "scannet": "./data/habitat-sim/scene_datasets/scannet/"
+}
+
+SceneData = collections.namedtuple("SceneData", ["scene_dataset_config_file", "scene", "navmesh", "output_dir"])
+
+def list_replicacad_scenes(base_output_dir, base_path=SCENES_DATASET["replica_cad"]):
+ scene_dataset_config_file = os.path.join(base_path, "replicaCAD.scene_dataset_config.json")
+ scenes = [f"apt_{i}" for i in range(6)] + ["empty_stage"]
+ navmeshes = [f"navmeshes/apt_{i}_static_furniture.navmesh" for i in range(6)] + ["empty_stage.navmesh"]
+ scenes_data = []
+ for idx in range(len(scenes)):
+ output_dir = os.path.join(base_output_dir, "ReplicaCAD", scenes[idx])
+ # Add scene
+ data = SceneData(scene_dataset_config_file=scene_dataset_config_file,
+ scene = scenes[idx] + ".scene_instance.json",
+ navmesh = os.path.join(base_path, navmeshes[idx]),
+ output_dir = output_dir)
+ scenes_data.append(data)
+ return scenes_data
+
+def list_replica_cad_baked_lighting_scenes(base_output_dir, base_path=SCENES_DATASET["replica_cad_baked_lighting"]):
+ scene_dataset_config_file = os.path.join(base_path, "replicaCAD_baked.scene_dataset_config.json")
+ scenes = sum([[f"Baked_sc{i}_staging_{j:02}" for i in range(5)] for j in range(21)], [])
+ navmeshes = ""#[f"navmeshes/apt_{i}_static_furniture.navmesh" for i in range(6)] + ["empty_stage.navmesh"]
+ scenes_data = []
+ for idx in range(len(scenes)):
+ output_dir = os.path.join(base_output_dir, "replica_cad_baked_lighting", scenes[idx])
+ data = SceneData(scene_dataset_config_file=scene_dataset_config_file,
+ scene = scenes[idx],
+ navmesh = "",
+ output_dir = output_dir)
+ scenes_data.append(data)
+ return scenes_data
+
+def list_replica_scenes(base_output_dir, base_path):
+ scenes_data = []
+ for scene_id in os.listdir(base_path):
+ scene = os.path.join(base_path, scene_id, "mesh.ply")
+ navmesh = os.path.join(base_path, scene_id, "habitat/mesh_preseg_semantic.navmesh") # Not sure if I should use it
+ scene_dataset_config_file = ""
+ output_dir = os.path.join(base_output_dir, scene_id)
+ # Add scene only if it does not exist already, or if exist_ok
+ data = SceneData(scene_dataset_config_file = scene_dataset_config_file,
+ scene = scene,
+ navmesh = navmesh,
+ output_dir = output_dir)
+ scenes_data.append(data)
+ return scenes_data
+
+
+def list_scenes(base_output_dir, base_path):
+ """
+ Generic method iterating through a base_path folder to find scenes.
+ """
+ scenes_data = []
+ for root, dirs, files in os.walk(base_path, followlinks=True):
+ folder_scenes_data = []
+ for file in files:
+ name, ext = os.path.splitext(file)
+ if ext == ".glb":
+ scene = os.path.join(root, name + ".glb")
+ navmesh = os.path.join(root, name + ".navmesh")
+ if not os.path.exists(navmesh):
+ navmesh = ""
+ relpath = os.path.relpath(root, base_path)
+ output_dir = os.path.abspath(os.path.join(base_output_dir, relpath, name))
+ data = SceneData(scene_dataset_config_file="",
+ scene = scene,
+ navmesh = navmesh,
+ output_dir = output_dir)
+ folder_scenes_data.append(data)
+
+ # Specific check for HM3D:
+ # When two meshes xxxx.basis.glb and xxxx.glb are present, use the 'basis' version.
+ basis_scenes = [data.scene[:-len(".basis.glb")] for data in folder_scenes_data if data.scene.endswith(".basis.glb")]
+ if len(basis_scenes) != 0:
+ folder_scenes_data = [data for data in folder_scenes_data if not (data.scene[:-len(".glb")] in basis_scenes)]
+
+ scenes_data.extend(folder_scenes_data)
+ return scenes_data
+
+def list_scenes_available(base_output_dir, scenes_dataset_paths=SCENES_DATASET):
+ scenes_data = []
+
+ # HM3D
+ for split in ("minival", "train", "val", "examples"):
+ scenes_data += list_scenes(base_output_dir=os.path.join(base_output_dir, f"hm3d/{split}/"),
+ base_path=f"{scenes_dataset_paths['hm3d']}/{split}")
+
+ # Gibson
+ scenes_data += list_scenes(base_output_dir=os.path.join(base_output_dir, "gibson"),
+ base_path=scenes_dataset_paths["gibson"])
+
+ # Habitat test scenes (just a few)
+ scenes_data += list_scenes(base_output_dir=os.path.join(base_output_dir, "habitat-test-scenes"),
+ base_path=scenes_dataset_paths["habitat-test-scenes"])
+
+ # ReplicaCAD (baked lighting)
+ scenes_data += list_replica_cad_baked_lighting_scenes(base_output_dir=base_output_dir)
+
+ # ScanNet
+ scenes_data += list_scenes(base_output_dir=os.path.join(base_output_dir, "scannet"),
+ base_path=scenes_dataset_paths["scannet"])
+
+ # Replica
+ scenes_data += list_replica_scenes(base_output_dir=os.path.join(base_output_dir, "replica"),
+ base_path=scenes_dataset_paths["replica"])
+ return scenes_data
diff --git a/imcui/third_party/dust3r/croco/datasets/pairs_dataset.py b/imcui/third_party/dust3r/croco/datasets/pairs_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..9f107526b34e154d9013a9a7a0bde3d5ff6f581c
--- /dev/null
+++ b/imcui/third_party/dust3r/croco/datasets/pairs_dataset.py
@@ -0,0 +1,109 @@
+# Copyright (C) 2022-present Naver Corporation. All rights reserved.
+# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+
+import os
+from torch.utils.data import Dataset
+from PIL import Image
+
+from datasets.transforms import get_pair_transforms
+
+def load_image(impath):
+ return Image.open(impath)
+
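+# Pair cache / listing files are plain text: the cache file stores one pair per line as
+# '<relative path to image 1> <relative path to image 2>', while the listing file stores one
+# path prefix per line (to which '_1.jpg' / '_2.jpg' are appended).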
+def load_pairs_from_cache_file(fname, root=''):
+ assert os.path.isfile(fname), "cannot parse pairs from {:s}, file does not exist".format(fname)
+ with open(fname, 'r') as fid:
+ lines = fid.read().strip().splitlines()
+ pairs = [ (os.path.join(root,l.split()[0]), os.path.join(root,l.split()[1])) for l in lines]
+ return pairs
+
+def load_pairs_from_list_file(fname, root=''):
+ assert os.path.isfile(fname), "cannot parse pairs from {:s}, file does not exist".format(fname)
+ with open(fname, 'r') as fid:
+ lines = fid.read().strip().splitlines()
+ pairs = [ (os.path.join(root,l+'_1.jpg'), os.path.join(root,l+'_2.jpg')) for l in lines if not l.startswith('#')]
+ return pairs
+
+
+def write_cache_file(fname, pairs, root=''):
+ if len(root)>0:
+ if not root.endswith('/'): root+='/'
+ assert os.path.isdir(root)
+ s = ''
+ for im1, im2 in pairs:
+ if len(root)>0:
+ assert im1.startswith(root), im1
+ assert im2.startswith(root), im2
+ s += '{:s} {:s}\n'.format(im1[len(root):], im2[len(root):])
+ with open(fname, 'w') as fid:
+ fid.write(s[:-1])
+
+def parse_and_cache_all_pairs(dname, data_dir='./data/'):
+ if dname=='habitat_release':
+ dirname = os.path.join(data_dir, 'habitat_release')
+ assert os.path.isdir(dirname), "cannot find folder for habitat_release pairs: "+dirname
+ cache_file = os.path.join(dirname, 'pairs.txt')
+ assert not os.path.isfile(cache_file), "cache file already exists: "+cache_file
+
+ print('Parsing pairs for dataset: '+dname)
+ pairs = []
+ for root, dirs, files in os.walk(dirname):
+ if 'val' in root: continue
+ dirs.sort()
+ pairs += [ (os.path.join(root,f), os.path.join(root,f[:-len('_1.jpeg')]+'_2.jpeg')) for f in sorted(files) if f.endswith('_1.jpeg')]
+ print('Found {:,} pairs'.format(len(pairs)))
+ print('Writing cache to: '+cache_file)
+ write_cache_file(cache_file, pairs, root=dirname)
+
+ else:
+ raise NotImplementedError('Unknown dataset: '+dname)
+
+def dnames_to_image_pairs(dnames, data_dir='./data/'):
+ """
+ dnames: list of datasets with image pairs, separated by +
+ """
+ all_pairs = []
+ for dname in dnames.split('+'):
+ if dname=='habitat_release':
+ dirname = os.path.join(data_dir, 'habitat_release')
+ assert os.path.isdir(dirname), "cannot find folder for habitat_release pairs: "+dirname
+ cache_file = os.path.join(dirname, 'pairs.txt')
+ assert os.path.isfile(cache_file), "cannot find cache file for habitat_release pairs, please first create the cache file, see instructions. "+cache_file
+ pairs = load_pairs_from_cache_file(cache_file, root=dirname)
+ elif dname in ['ARKitScenes', 'MegaDepth', '3DStreetView', 'IndoorVL']:
+ dirname = os.path.join(data_dir, dname+'_crops')
+ assert os.path.isdir(dirname), "cannot find folder for {:s} pairs: {:s}".format(dname, dirname)
+ list_file = os.path.join(dirname, 'listing.txt')
+ assert os.path.isfile(list_file), "cannot find list file for {:s} pairs, see instructions. {:s}".format(dname, list_file)
+ pairs = load_pairs_from_list_file(list_file, root=dirname)
+ print(' {:s}: {:,} pairs'.format(dname, len(pairs)))
+ all_pairs += pairs
+ if '+' in dnames: print(' Total: {:,} pairs'.format(len(all_pairs)))
+ return all_pairs
+
+
+class PairsDataset(Dataset):
+
+ def __init__(self, dnames, trfs='', totensor=True, normalize=True, data_dir='./data/'):
+ super().__init__()
+ self.image_pairs = dnames_to_image_pairs(dnames, data_dir=data_dir)
+ self.transforms = get_pair_transforms(transform_str=trfs, totensor=totensor, normalize=normalize)
+
+ def __len__(self):
+ return len(self.image_pairs)
+
+ def __getitem__(self, index):
+ im1path, im2path = self.image_pairs[index]
+ im1 = load_image(im1path)
+ im2 = load_image(im2path)
+ if self.transforms is not None: im1, im2 = self.transforms(im1, im2)
+ return im1, im2
+
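+# Minimal usage sketch (assuming the cache/listing files have been created as described above):
+# dataset = PairsDataset('habitat_release', trfs='crop224+acolor', data_dir='./data/')
+# im1, im2 = dataset[0]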
+
+if __name__=="__main__":
+ import argparse
+ parser = argparse.ArgumentParser(prog="Computing and caching list of pairs for a given dataset")
+ parser.add_argument('--data_dir', default='./data/', type=str, help="path where data are stored")
+ parser.add_argument('--dataset', default='habitat_release', type=str, help="name of the dataset")
+ args = parser.parse_args()
+ parse_and_cache_all_pairs(dname=args.dataset, data_dir=args.data_dir)
diff --git a/imcui/third_party/dust3r/croco/datasets/transforms.py b/imcui/third_party/dust3r/croco/datasets/transforms.py
new file mode 100644
index 0000000000000000000000000000000000000000..216bac61f8254fd50e7f269ee80301f250a2d11e
--- /dev/null
+++ b/imcui/third_party/dust3r/croco/datasets/transforms.py
@@ -0,0 +1,95 @@
+# Copyright (C) 2022-present Naver Corporation. All rights reserved.
+# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+
+import torch
+import torchvision.transforms
+import torchvision.transforms.functional as F
+
+# "Pair": apply a transform on a pair
+# "Both": apply the exact same transform to both images
+
+class ComposePair(torchvision.transforms.Compose):
+ def __call__(self, img1, img2):
+ for t in self.transforms:
+ img1, img2 = t(img1, img2)
+ return img1, img2
+
+class NormalizeBoth(torchvision.transforms.Normalize):
+ def forward(self, img1, img2):
+ img1 = super().forward(img1)
+ img2 = super().forward(img2)
+ return img1, img2
+
+class ToTensorBoth(torchvision.transforms.ToTensor):
+ def __call__(self, img1, img2):
+ img1 = super().__call__(img1)
+ img2 = super().__call__(img2)
+ return img1, img2
+
+class RandomCropPair(torchvision.transforms.RandomCrop):
+ # the crop will be intentionally different for the two images with this class
+ def forward(self, img1, img2):
+ img1 = super().forward(img1)
+ img2 = super().forward(img2)
+ return img1, img2
+
+class ColorJitterPair(torchvision.transforms.ColorJitter):
+ # can be symmetric (same for both images) or asymmetric (different jitter params for each image) depending on assymetric_prob
+ def __init__(self, assymetric_prob, **kwargs):
+ super().__init__(**kwargs)
+ self.assymetric_prob = assymetric_prob
+ def jitter_one(self, img, fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor):
+ for fn_id in fn_idx:
+ if fn_id == 0 and brightness_factor is not None:
+ img = F.adjust_brightness(img, brightness_factor)
+ elif fn_id == 1 and contrast_factor is not None:
+ img = F.adjust_contrast(img, contrast_factor)
+ elif fn_id == 2 and saturation_factor is not None:
+ img = F.adjust_saturation(img, saturation_factor)
+ elif fn_id == 3 and hue_factor is not None:
+ img = F.adjust_hue(img, hue_factor)
+ return img
+
+ def forward(self, img1, img2):
+
+ fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor = self.get_params(
+ self.brightness, self.contrast, self.saturation, self.hue
+ )
+ img1 = self.jitter_one(img1, fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor)
+ if torch.rand(1) < self.assymetric_prob: # asymmetric: re-sample jitter params for the second image
+ fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor = self.get_params(
+ self.brightness, self.contrast, self.saturation, self.hue
+ )
+ img2 = self.jitter_one(img2, fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor)
+ return img1, img2
+
+def get_pair_transforms(transform_str, totensor=True, normalize=True):
+ # transform_str is e.g. 'crop224+acolor'
+ trfs = []
+ for s in transform_str.split('+'):
+ if s.startswith('crop'):
+ size = int(s[len('crop'):])
+ trfs.append(RandomCropPair(size))
+ elif s=='acolor':
+ trfs.append(ColorJitterPair(assymetric_prob=1.0, brightness=(0.6, 1.4), contrast=(0.6, 1.4), saturation=(0.6, 1.4), hue=0.0))
+ elif s=='': # if transform_str was ""
+ pass
+ else:
+ raise NotImplementedError('Unknown augmentation: '+s)
+
+ if totensor:
+ trfs.append( ToTensorBoth() )
+ if normalize:
+ trfs.append( NormalizeBoth(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) )
+
+ if len(trfs)==0:
+ return None
+ elif len(trfs)==1:
+ return trfs[0]
+ else:
+ return ComposePair(trfs)
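+# e.g. get_pair_transforms('crop224+acolor') returns a ComposePair of RandomCropPair(224),
+# an asymmetric ColorJitterPair, ToTensorBoth and NormalizeBoth (ImageNet statistics).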
+
+
+
+
+
diff --git a/imcui/third_party/dust3r/croco/demo.py b/imcui/third_party/dust3r/croco/demo.py
new file mode 100644
index 0000000000000000000000000000000000000000..91b80ccc5c98c18e20d1ce782511aa824ef28f77
--- /dev/null
+++ b/imcui/third_party/dust3r/croco/demo.py
@@ -0,0 +1,55 @@
+# Copyright (C) 2022-present Naver Corporation. All rights reserved.
+# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+
+import torch
+from models.croco import CroCoNet
+from PIL import Image
+import torchvision.transforms
+from torchvision.transforms import ToTensor, Normalize, Compose
+
+def main():
+ device = torch.device('cuda:0' if torch.cuda.is_available() and torch.cuda.device_count()>0 else 'cpu')
+
+ # load 224x224 images and transform them to tensor
+ imagenet_mean = [0.485, 0.456, 0.406]
+ imagenet_mean_tensor = torch.tensor(imagenet_mean).view(1,3,1,1).to(device, non_blocking=True)
+ imagenet_std = [0.229, 0.224, 0.225]
+ imagenet_std_tensor = torch.tensor(imagenet_std).view(1,3,1,1).to(device, non_blocking=True)
+ trfs = Compose([ToTensor(), Normalize(mean=imagenet_mean, std=imagenet_std)])
+ image1 = trfs(Image.open('assets/Chateau1.png').convert('RGB')).to(device, non_blocking=True).unsqueeze(0)
+ image2 = trfs(Image.open('assets/Chateau2.png').convert('RGB')).to(device, non_blocking=True).unsqueeze(0)
+
+ # load model
+ ckpt = torch.load('pretrained_models/CroCo_V2_ViTLarge_BaseDecoder.pth', 'cpu')
+ model = CroCoNet( **ckpt.get('croco_kwargs',{})).to(device)
+ model.eval()
+ msg = model.load_state_dict(ckpt['model'], strict=True)
+
+ # forward
+ with torch.inference_mode():
+ out, mask, target = model(image1, image2)
+
+ # the output is normalized, thus use the mean/std of the actual image to go back to RGB space
+ patchified = model.patchify(image1)
+ mean = patchified.mean(dim=-1, keepdim=True)
+ var = patchified.var(dim=-1, keepdim=True)
+ decoded_image = model.unpatchify(out * (var + 1.e-6)**.5 + mean)
+ # undo imagenet normalization, prepare masked image
+ decoded_image = decoded_image * imagenet_std_tensor + imagenet_mean_tensor
+ input_image = image1 * imagenet_std_tensor + imagenet_mean_tensor
+ ref_image = image2 * imagenet_std_tensor + imagenet_mean_tensor
+ image_masks = model.unpatchify(model.patchify(torch.ones_like(ref_image)) * mask[:,:,None])
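+ # mask == 1 marks the masked patches, so (1 - image_masks) keeps only the visible content of the input image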
+ masked_input_image = ((1 - image_masks) * input_image)
+
+ # make visualization
+ visualization = torch.cat((ref_image, masked_input_image, decoded_image, input_image), dim=3) # 4*(B, 3, H, W) -> B, 3, H, W*4
+ B, C, H, W = visualization.shape
+ visualization = visualization.permute(1, 0, 2, 3).reshape(C, B*H, W)
+ visualization = torchvision.transforms.functional.to_pil_image(torch.clamp(visualization, 0, 1))
+ fname = "demo_output.png"
+ visualization.save(fname)
+ print('Visualization saved in '+fname)
+
+
+if __name__=="__main__":
+ main()
diff --git a/imcui/third_party/dust3r/croco/models/blocks.py b/imcui/third_party/dust3r/croco/models/blocks.py
new file mode 100644
index 0000000000000000000000000000000000000000..18133524f0ae265b0bd8d062d7c9eeaa63858a9b
--- /dev/null
+++ b/imcui/third_party/dust3r/croco/models/blocks.py
@@ -0,0 +1,241 @@
+# Copyright (C) 2022-present Naver Corporation. All rights reserved.
+# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+
+
+# --------------------------------------------------------
+# Main encoder/decoder blocks
+# --------------------------------------------------------
+# References:
+# timm
+# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/helpers.py
+# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py
+# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/mlp.py
+# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/patch_embed.py
+
+
+import torch
+import torch.nn as nn
+
+from itertools import repeat
+import collections.abc
+
+
+def _ntuple(n):
+ def parse(x):
+ if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
+ return x
+ return tuple(repeat(x, n))
+ return parse
+to_2tuple = _ntuple(2)
+
+def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+ """
+ if drop_prob == 0. or not training:
+ return x
+ keep_prob = 1 - drop_prob
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
+ random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
+ if keep_prob > 0.0 and scale_by_keep:
+ random_tensor.div_(keep_prob)
+ return x * random_tensor
+
+class DropPath(nn.Module):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+ """
+ def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True):
+ super(DropPath, self).__init__()
+ self.drop_prob = drop_prob
+ self.scale_by_keep = scale_by_keep
+
+ def forward(self, x):
+ return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
+
+ def extra_repr(self):
+ return f'drop_prob={round(self.drop_prob,3):0.3f}'
+
+class Mlp(nn.Module):
+ """ MLP as used in Vision Transformer, MLP-Mixer and related networks"""
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, bias=True, drop=0.):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ bias = to_2tuple(bias)
+ drop_probs = to_2tuple(drop)
+
+ self.fc1 = nn.Linear(in_features, hidden_features, bias=bias[0])
+ self.act = act_layer()
+ self.drop1 = nn.Dropout(drop_probs[0])
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=bias[1])
+ self.drop2 = nn.Dropout(drop_probs[1])
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.drop1(x)
+ x = self.fc2(x)
+ x = self.drop2(x)
+ return x
+
+class Attention(nn.Module):
+
+ def __init__(self, dim, rope=None, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
+ super().__init__()
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.scale = head_dim ** -0.5
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+ self.rope = rope
+
+ def forward(self, x, xpos):
+ B, N, C = x.shape
+
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).transpose(1,3)
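+ # qkv: (B, num_heads, 3, N, head_dim) after the transpose; q, k, v below are each (B, num_heads, N, head_dim)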
+ q, k, v = [qkv[:,:,i] for i in range(3)]
+ # q,k,v = qkv.unbind(2) # make torchscript happy (cannot use tensor as tuple)
+
+ if self.rope is not None:
+ q = self.rope(q, xpos)
+ k = self.rope(k, xpos)
+
+ attn = (q @ k.transpose(-2, -1)) * self.scale
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+class Block(nn.Module):
+
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
+ drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, rope=None):
+ super().__init__()
+ self.norm1 = norm_layer(dim)
+ self.attn = Attention(dim, rope=rope, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+
+ def forward(self, x, xpos):
+ x = x + self.drop_path(self.attn(self.norm1(x), xpos))
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
+ return x
+
+class CrossAttention(nn.Module):
+
+ def __init__(self, dim, rope=None, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
+ super().__init__()
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.scale = head_dim ** -0.5
+
+ self.projq = nn.Linear(dim, dim, bias=qkv_bias)
+ self.projk = nn.Linear(dim, dim, bias=qkv_bias)
+ self.projv = nn.Linear(dim, dim, bias=qkv_bias)
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ self.rope = rope
+
+ def forward(self, query, key, value, qpos, kpos):
+ B, Nq, C = query.shape
+ Nk = key.shape[1]
+ Nv = value.shape[1]
+
+ q = self.projq(query).reshape(B,Nq,self.num_heads, C// self.num_heads).permute(0, 2, 1, 3)
+ k = self.projk(key).reshape(B,Nk,self.num_heads, C// self.num_heads).permute(0, 2, 1, 3)
+ v = self.projv(value).reshape(B,Nv,self.num_heads, C// self.num_heads).permute(0, 2, 1, 3)
+
+ if self.rope is not None:
+ q = self.rope(q, qpos)
+ k = self.rope(k, kpos)
+
+ attn = (q @ k.transpose(-2, -1)) * self.scale
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+
+ x = (attn @ v).transpose(1, 2).reshape(B, Nq, C)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+class DecoderBlock(nn.Module):
+
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
+ drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, norm_mem=True, rope=None):
+ super().__init__()
+ self.norm1 = norm_layer(dim)
+ self.attn = Attention(dim, rope=rope, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
+ self.cross_attn = CrossAttention(dim, rope=rope, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+ self.norm2 = norm_layer(dim)
+ self.norm3 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+ self.norm_y = norm_layer(dim) if norm_mem else nn.Identity()
+
+ def forward(self, x, y, xpos, ypos):
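+ # y acts as the memory: it is only normalized (norm_y) and used as key/value in the cross-attention, and is returned unchanged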
+ x = x + self.drop_path(self.attn(self.norm1(x), xpos))
+ y_ = self.norm_y(y)
+ x = x + self.drop_path(self.cross_attn(self.norm2(x), y_, y_, xpos, ypos))
+ x = x + self.drop_path(self.mlp(self.norm3(x)))
+ return x, y
+
+
+# patch embedding
+class PositionGetter(object):
+ """ return positions of patches """
+
+ def __init__(self):
+ self.cache_positions = {}
+
+ def __call__(self, b, h, w, device):
+ if not (h,w) in self.cache_positions:
+ x = torch.arange(w, device=device)
+ y = torch.arange(h, device=device)
+ self.cache_positions[h,w] = torch.cartesian_prod(y, x) # (h*w, 2)
+ pos = self.cache_positions[h,w].view(1, h*w, 2).expand(b, -1, 2).clone()
+ return pos
+
+class PatchEmbed(nn.Module):
+ """ just adding _init_weights + position getter compared to timm.models.layers.patch_embed.PatchEmbed"""
+
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True):
+ super().__init__()
+ img_size = to_2tuple(img_size)
+ patch_size = to_2tuple(patch_size)
+ self.img_size = img_size
+ self.patch_size = patch_size
+ self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
+ self.num_patches = self.grid_size[0] * self.grid_size[1]
+ self.flatten = flatten
+
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
+
+ self.position_getter = PositionGetter()
+
+ def forward(self, x):
+ B, C, H, W = x.shape
+ torch._assert(H == self.img_size[0], f"Input image height ({H}) doesn't match model ({self.img_size[0]}).")
+ torch._assert(W == self.img_size[1], f"Input image width ({W}) doesn't match model ({self.img_size[1]}).")
+ x = self.proj(x)
+ pos = self.position_getter(B, x.size(2), x.size(3), x.device)
+ if self.flatten:
+ x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
+ x = self.norm(x)
+ return x, pos
+
+ def _init_weights(self):
+ w = self.proj.weight.data
+ torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
+
diff --git a/imcui/third_party/dust3r/croco/models/criterion.py b/imcui/third_party/dust3r/croco/models/criterion.py
new file mode 100644
index 0000000000000000000000000000000000000000..11696c40865344490f23796ea45e8fbd5e654731
--- /dev/null
+++ b/imcui/third_party/dust3r/croco/models/criterion.py
@@ -0,0 +1,37 @@
+# Copyright (C) 2022-present Naver Corporation. All rights reserved.
+# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+#
+# --------------------------------------------------------
+# Criterion to train CroCo
+# --------------------------------------------------------
+# References:
+# MAE: https://github.com/facebookresearch/mae
+# --------------------------------------------------------
+
+import torch
+
+class MaskedMSE(torch.nn.Module):
+
+ def __init__(self, norm_pix_loss=False, masked=True):
+ """
+ norm_pix_loss: normalize each patch by their pixel mean and variance
+ masked: compute loss over the masked patches only
+ """
+ super().__init__()
+ self.norm_pix_loss = norm_pix_loss
+ self.masked = masked
+
+ def forward(self, pred, mask, target):
+
+ if self.norm_pix_loss:
+ mean = target.mean(dim=-1, keepdim=True)
+ var = target.var(dim=-1, keepdim=True)
+ target = (target - mean) / (var + 1.e-6)**.5
+
+ loss = (pred - target) ** 2
+ loss = loss.mean(dim=-1) # [N, L], mean loss per patch
+ if self.masked:
+ loss = (loss * mask).sum() / mask.sum() # mean loss on masked patches
+ else:
+ loss = loss.mean() # mean loss
+ return loss
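+
+
+if __name__ == '__main__':
+    # Minimal usage sketch (illustrative only, not part of the upstream CroCo code):
+    # random per-patch predictions/targets and a random binary mask, just to show the
+    # expected tensor shapes and that the loss reduces to a scalar. B, L, D are assumed
+    # placeholder sizes (batch, number of patches, patch dimension).
+    B, L, D = 2, 196, 768
+    pred = torch.randn(B, L, D)
+    target = torch.randn(B, L, D)
+    mask = (torch.rand(B, L) > 0.5).float()  # 1 = masked patch, 0 = visible patch
+    criterion = MaskedMSE(norm_pix_loss=True, masked=True)
+    print(criterion(pred, mask, target))  # scalar loss averaged over masked patches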
diff --git a/imcui/third_party/dust3r/croco/models/croco.py b/imcui/third_party/dust3r/croco/models/croco.py
new file mode 100644
index 0000000000000000000000000000000000000000..14c68634152d75555b4c35c25af268394c5821fe
--- /dev/null
+++ b/imcui/third_party/dust3r/croco/models/croco.py
@@ -0,0 +1,249 @@
+# Copyright (C) 2022-present Naver Corporation. All rights reserved.
+# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+
+
+# --------------------------------------------------------
+# CroCo model during pretraining
+# --------------------------------------------------------
+
+
+
+import torch
+import torch.nn as nn
+torch.backends.cuda.matmul.allow_tf32 = True # for gpu >= Ampere and pytorch >= 1.12
+from functools import partial
+
+from models.blocks import Block, DecoderBlock, PatchEmbed
+from models.pos_embed import get_2d_sincos_pos_embed, RoPE2D
+from models.masking import RandomMask
+
+
+class CroCoNet(nn.Module):
+
+ def __init__(self,
+ img_size=224, # input image size
+ patch_size=16, # patch_size
+ mask_ratio=0.9, # ratio of masked tokens
+ enc_embed_dim=768, # encoder feature dimension
+ enc_depth=12, # encoder depth
+ enc_num_heads=12, # encoder number of heads in the transformer block
+ dec_embed_dim=512, # decoder feature dimension
+ dec_depth=8, # decoder depth
+ dec_num_heads=16, # decoder number of heads in the transformer block
+ mlp_ratio=4,
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
+ norm_im2_in_dec=True, # whether to apply normalization of the 'memory' = (second image) in the decoder
+ pos_embed='cosine', # positional embedding (either cosine or RoPE100)
+ ):
+
+ super(CroCoNet, self).__init__()
+
+ # patch embeddings (with initialization done as in MAE)
+ self._set_patch_embed(img_size, patch_size, enc_embed_dim)
+
+ # mask generations
+ self._set_mask_generator(self.patch_embed.num_patches, mask_ratio)
+
+ self.pos_embed = pos_embed
+ if pos_embed=='cosine':
+ # positional embedding of the encoder
+ enc_pos_embed = get_2d_sincos_pos_embed(enc_embed_dim, int(self.patch_embed.num_patches**.5), n_cls_token=0)
+ self.register_buffer('enc_pos_embed', torch.from_numpy(enc_pos_embed).float())
+ # positional embedding of the decoder
+ dec_pos_embed = get_2d_sincos_pos_embed(dec_embed_dim, int(self.patch_embed.num_patches**.5), n_cls_token=0)
+ self.register_buffer('dec_pos_embed', torch.from_numpy(dec_pos_embed).float())
+ # pos embedding in each block
+ self.rope = None # nothing for cosine
+ elif pos_embed.startswith('RoPE'): # eg RoPE100
+ self.enc_pos_embed = None # nothing to add in the encoder with RoPE
+ self.dec_pos_embed = None # nothing to add in the decoder with RoPE
+ if RoPE2D is None: raise ImportError("Cannot find cuRoPE2D, please install it following the README instructions")
+ freq = float(pos_embed[len('RoPE'):])
+ self.rope = RoPE2D(freq=freq)
+ else:
+ raise NotImplementedError('Unknown pos_embed '+pos_embed)
+
+ # transformer for the encoder
+ self.enc_depth = enc_depth
+ self.enc_embed_dim = enc_embed_dim
+ self.enc_blocks = nn.ModuleList([
+ Block(enc_embed_dim, enc_num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer, rope=self.rope)
+ for i in range(enc_depth)])
+ self.enc_norm = norm_layer(enc_embed_dim)
+
+ # masked tokens
+ self._set_mask_token(dec_embed_dim)
+
+ # decoder
+ self._set_decoder(enc_embed_dim, dec_embed_dim, dec_num_heads, dec_depth, mlp_ratio, norm_layer, norm_im2_in_dec)
+
+ # prediction head
+ self._set_prediction_head(dec_embed_dim, patch_size)
+
+ # initialize weights
+ self.initialize_weights()
+
+ def _set_patch_embed(self, img_size=224, patch_size=16, enc_embed_dim=768):
+ self.patch_embed = PatchEmbed(img_size, patch_size, 3, enc_embed_dim)
+
+ def _set_mask_generator(self, num_patches, mask_ratio):
+ self.mask_generator = RandomMask(num_patches, mask_ratio)
+
+ def _set_mask_token(self, dec_embed_dim):
+ self.mask_token = nn.Parameter(torch.zeros(1, 1, dec_embed_dim))
+
+ def _set_decoder(self, enc_embed_dim, dec_embed_dim, dec_num_heads, dec_depth, mlp_ratio, norm_layer, norm_im2_in_dec):
+ self.dec_depth = dec_depth
+ self.dec_embed_dim = dec_embed_dim
+ # transfer from encoder to decoder
+ self.decoder_embed = nn.Linear(enc_embed_dim, dec_embed_dim, bias=True)
+ # transformer for the decoder
+ self.dec_blocks = nn.ModuleList([
+ DecoderBlock(dec_embed_dim, dec_num_heads, mlp_ratio=mlp_ratio, qkv_bias=True, norm_layer=norm_layer, norm_mem=norm_im2_in_dec, rope=self.rope)
+ for i in range(dec_depth)])
+ # final norm layer
+ self.dec_norm = norm_layer(dec_embed_dim)
+
+ def _set_prediction_head(self, dec_embed_dim, patch_size):
+ self.prediction_head = nn.Linear(dec_embed_dim, patch_size**2 * 3, bias=True)
+
+
+ def initialize_weights(self):
+ # patch embed
+ self.patch_embed._init_weights()
+ # mask tokens
+ if self.mask_token is not None: torch.nn.init.normal_(self.mask_token, std=.02)
+ # linears and layer norms
+ self.apply(self._init_weights)
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ # we use xavier_uniform following official JAX ViT:
+ torch.nn.init.xavier_uniform_(m.weight)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.LayerNorm):
+ nn.init.constant_(m.bias, 0)
+ nn.init.constant_(m.weight, 1.0)
+
+ def _encode_image(self, image, do_mask=False, return_all_blocks=False):
+ """
+ image has B x 3 x img_size x img_size
+ do_mask: whether to perform masking or not
+ return_all_blocks: if True, return the features at the end of every block
+ instead of just the features from the last block (eg for some prediction heads)
+ """
+ # embed the image into patches (x has size B x Npatches x C)
+ # and get the position of each returned patch (pos has size B x Npatches x 2)
+ x, pos = self.patch_embed(image)
+ # add positional embedding without cls token
+ if self.enc_pos_embed is not None:
+ x = x + self.enc_pos_embed[None,...]
+ # apply masking
+ B,N,C = x.size()
+ if do_mask:
+ masks = self.mask_generator(x)
+ x = x[~masks].view(B, -1, C)
+ posvis = pos[~masks].view(B, -1, 2)
+ else:
+ B,N,C = x.size()
+ masks = torch.zeros((B,N), dtype=bool)
+ posvis = pos
+ # now apply the transformer encoder and normalization
+ if return_all_blocks:
+ out = []
+ for blk in self.enc_blocks:
+ x = blk(x, posvis)
+ out.append(x)
+ out[-1] = self.enc_norm(out[-1])
+ return out, pos, masks
+ else:
+ for blk in self.enc_blocks:
+ x = blk(x, posvis)
+ x = self.enc_norm(x)
+ return x, pos, masks
+
+ def _decoder(self, feat1, pos1, masks1, feat2, pos2, return_all_blocks=False):
+ """
+ return_all_blocks: if True, return the features at the end of every block
+ instead of just the features from the last block (eg for some prediction heads)
+
+ masks1 can be None => assume image1 fully visible
+ """
+ # encoder to decoder layer
+ visf1 = self.decoder_embed(feat1)
+ f2 = self.decoder_embed(feat2)
+ # append masked tokens to the sequence
+ B,Nenc,C = visf1.size()
+ if masks1 is None: # downstreams
+ f1_ = visf1
+ else: # pretraining
+ Ntotal = masks1.size(1)
+ f1_ = self.mask_token.repeat(B, Ntotal, 1).to(dtype=visf1.dtype)
+ f1_[~masks1] = visf1.view(B * Nenc, C)
+ # add positional embedding
+ if self.dec_pos_embed is not None:
+ f1_ = f1_ + self.dec_pos_embed
+ f2 = f2 + self.dec_pos_embed
+ # apply Transformer blocks
+ out = f1_
+ out2 = f2
+ if return_all_blocks:
+ _out, out = out, []
+ for blk in self.dec_blocks:
+ _out, out2 = blk(_out, out2, pos1, pos2)
+ out.append(_out)
+ out[-1] = self.dec_norm(out[-1])
+ else:
+ for blk in self.dec_blocks:
+ out, out2 = blk(out, out2, pos1, pos2)
+ out = self.dec_norm(out)
+ return out
+
+ def patchify(self, imgs):
+ """
+ imgs: (B, 3, H, W)
+ x: (B, L, patch_size**2 *3)
+ """
+ p = self.patch_embed.patch_size[0]
+ assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0
+
+ h = w = imgs.shape[2] // p
+ x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p))
+ x = torch.einsum('nchpwq->nhwpqc', x)
+ x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 3))
+
+ return x
+
+ def unpatchify(self, x, channels=3):
+ """
+ x: (N, L, patch_size**2 *channels)
+ imgs: (N, 3, H, W)
+ """
+ patch_size = self.patch_embed.patch_size[0]
+ h = w = int(x.shape[1]**.5)
+ assert h * w == x.shape[1]
+ x = x.reshape(shape=(x.shape[0], h, w, patch_size, patch_size, channels))
+ x = torch.einsum('nhwpqc->nchpwq', x)
+ imgs = x.reshape(shape=(x.shape[0], channels, h * patch_size, h * patch_size))
+ return imgs
+
+ def forward(self, img1, img2):
+ """
+ img1: tensor of size B x 3 x img_size x img_size
+ img2: tensor of size B x 3 x img_size x img_size
+
+ out will be B x N x (3*patch_size*patch_size)
+ masks are also returned as B x N just in case
+ """
+ # encoder of the masked first image
+ feat1, pos1, mask1 = self._encode_image(img1, do_mask=True)
+ # encoder of the second image
+ feat2, pos2, _ = self._encode_image(img2, do_mask=False)
+ # decoder
+ decfeat = self._decoder(feat1, pos1, mask1, feat2, pos2)
+ # prediction head
+ out = self.prediction_head(decfeat)
+ # get target
+ target = self.patchify(img1)
+ return out, mask1, target
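+
+
+if __name__ == '__main__':
+    # Illustrative smoke test (not part of the upstream code); it assumes the croco/
+    # directory is on sys.path so that the 'models.*' imports above resolve.
+    # Build the default CroCoNet and run a forward pass on a random image pair.
+    model = CroCoNet(img_size=224, patch_size=16)
+    img1 = torch.randn(2, 3, 224, 224)
+    img2 = torch.randn(2, 3, 224, 224)
+    out, mask, target = model(img1, img2)
+    # out: B x N x (patch_size**2 * 3), mask: B x N, target = patchify(img1): B x N x 768
+    print(out.shape, mask.shape, target.shape)
+    # patchify/unpatchify round-trip sanity check on the target
+    print(torch.allclose(model.unpatchify(target), img1))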
diff --git a/imcui/third_party/dust3r/croco/models/croco_downstream.py b/imcui/third_party/dust3r/croco/models/croco_downstream.py
new file mode 100644
index 0000000000000000000000000000000000000000..159dfff4d2c1461bc235e21441b57ce1e2088f76
--- /dev/null
+++ b/imcui/third_party/dust3r/croco/models/croco_downstream.py
@@ -0,0 +1,122 @@
+# Copyright (C) 2022-present Naver Corporation. All rights reserved.
+# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+
+# --------------------------------------------------------
+# CroCo model for downstream tasks
+# --------------------------------------------------------
+
+import torch
+
+from .croco import CroCoNet
+
+
+def croco_args_from_ckpt(ckpt):
+ if 'croco_kwargs' in ckpt: # CroCo v2 released models
+ return ckpt['croco_kwargs']
+ elif 'args' in ckpt and hasattr(ckpt['args'], 'model'): # pretrained using the official code release
+ s = ckpt['args'].model # eg "CroCoNet(enc_embed_dim=1024, enc_num_heads=16, enc_depth=24)"
+ assert s.startswith('CroCoNet(')
+ return eval('dict'+s[len('CroCoNet'):]) # transform it into the string of a dictionary and evaluate it
+ else: # CroCo v1 released models
+ return dict()
+
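+# Example (illustrative sketch): rebuilding a model from a released checkpoint, assuming
+# 'path/to/croco_checkpoint.pth' points to a locally downloaded CroCo checkpoint and
+# 'head' is a head object implementing the setup()/forward() interface expected below:
+#   ckpt = torch.load('path/to/croco_checkpoint.pth', map_location='cpu')
+#   model = CroCoDownstreamBinocular(head, **croco_args_from_ckpt(ckpt))
+#   model.load_state_dict(ckpt['model'], strict=False)  # 'model' key assumed here
+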
+class CroCoDownstreamMonocularEncoder(CroCoNet):
+
+ def __init__(self,
+ head,
+ **kwargs):
+ """ Build network for monocular downstream task, only using the encoder.
+ It takes an extra argument head, which is called with the features
+ and a dictionary img_info containing 'width' and 'height' keys.
+ The head is set up with the croconet arguments in this init function.
+ NOTE: it works by calling super().__init__() but with redefined setters
+
+ """
+ super(CroCoDownstreamMonocularEncoder, self).__init__(**kwargs)
+ head.setup(self)
+ self.head = head
+
+ def _set_mask_generator(self, *args, **kwargs):
+ """ No mask generator """
+ return
+
+ def _set_mask_token(self, *args, **kwargs):
+ """ No mask token """
+ self.mask_token = None
+ return
+
+ def _set_decoder(self, *args, **kwargs):
+ """ No decoder """
+ return
+
+ def _set_prediction_head(self, *args, **kwargs):
+ """ No 'prediction head' for downstream tasks."""
+ return
+
+ def forward(self, img):
+ """
+ img is of size batch_size x 3 x h x w
+ """
+ B, C, H, W = img.size()
+ img_info = {'height': H, 'width': W}
+ need_all_layers = hasattr(self.head, 'return_all_blocks') and self.head.return_all_blocks
+ out, _, _ = self._encode_image(img, do_mask=False, return_all_blocks=need_all_layers)
+ return self.head(out, img_info)
+
+
+class CroCoDownstreamBinocular(CroCoNet):
+
+ def __init__(self,
+ head,
+ **kwargs):
+ """ Build network for binocular downstream task
+ It takes an extra argument head, which is called with the features
+ and a dictionary img_info containing 'width' and 'height' keys.
+ The head is set up with the croconet arguments in this init function.
+ """
+ super(CroCoDownstreamBinocular, self).__init__(**kwargs)
+ head.setup(self)
+ self.head = head
+
+ def _set_mask_generator(self, *args, **kwargs):
+ """ No mask generator """
+ return
+
+ def _set_mask_token(self, *args, **kwargs):
+ """ No mask token """
+ self.mask_token = None
+ return
+
+ def _set_prediction_head(self, *args, **kwargs):
+ """ No prediction head for downstream tasks, define your own head """
+ return
+
+ def encode_image_pairs(self, img1, img2, return_all_blocks=False):
+ """ run encoder for a pair of images
+ it is actually ~5% faster to concatenate the images along the batch dimension
+ than to encode them separately
+ """
+ ## the two commented lines below are the naive version with separate encoding
+ #out, pos, _ = self._encode_image(img1, do_mask=False, return_all_blocks=return_all_blocks)
+ #out2, pos2, _ = self._encode_image(img2, do_mask=False, return_all_blocks=False)
+ ## and now the faster version
+ out, pos, _ = self._encode_image( torch.cat( (img1,img2), dim=0), do_mask=False, return_all_blocks=return_all_blocks )
+ if return_all_blocks:
+ out,out2 = list(map(list, zip(*[o.chunk(2, dim=0) for o in out])))
+ out2 = out2[-1]
+ else:
+ out,out2 = out.chunk(2, dim=0)
+ pos,pos2 = pos.chunk(2, dim=0)
+ return out, out2, pos, pos2
+
+ def forward(self, img1, img2):
+ B, C, H, W = img1.size()
+ img_info = {'height': H, 'width': W}
+ return_all_blocks = hasattr(self.head, 'return_all_blocks') and self.head.return_all_blocks
+ out, out2, pos, pos2 = self.encode_image_pairs(img1, img2, return_all_blocks=return_all_blocks)
+ if return_all_blocks:
+ decout = self._decoder(out[-1], pos, None, out2, pos2, return_all_blocks=return_all_blocks)
+ decout = out+decout
+ else:
+ decout = self._decoder(out, pos, None, out2, pos2, return_all_blocks=return_all_blocks)
+ return self.head(decout, img_info)
\ No newline at end of file
diff --git a/imcui/third_party/dust3r/croco/models/curope/__init__.py b/imcui/third_party/dust3r/croco/models/curope/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..25e3d48a162760260826080f6366838e83e26878
--- /dev/null
+++ b/imcui/third_party/dust3r/croco/models/curope/__init__.py
@@ -0,0 +1,4 @@
+# Copyright (C) 2022-present Naver Corporation. All rights reserved.
+# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+
+from .curope2d import cuRoPE2D
diff --git a/imcui/third_party/dust3r/croco/models/curope/curope2d.py b/imcui/third_party/dust3r/croco/models/curope/curope2d.py
new file mode 100644
index 0000000000000000000000000000000000000000..a49c12f8c529e9a889b5ac20c5767158f238e17d
--- /dev/null
+++ b/imcui/third_party/dust3r/croco/models/curope/curope2d.py
@@ -0,0 +1,40 @@
+# Copyright (C) 2022-present Naver Corporation. All rights reserved.
+# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+
+import torch
+
+try:
+ import curope as _kernels # run `python setup.py install`
+except ModuleNotFoundError:
+ from . import curope as _kernels # run `python setup.py build_ext --inplace`
+
+
+class cuRoPE2D_func (torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, tokens, positions, base, F0=1):
+ ctx.save_for_backward(positions)
+ ctx.saved_base = base
+ ctx.saved_F0 = F0
+ # tokens = tokens.clone() # uncomment this if inplace doesn't work
+ _kernels.rope_2d( tokens, positions, base, F0 )
+ ctx.mark_dirty(tokens)
+ return tokens
+
+ @staticmethod
+ def backward(ctx, grad_res):
+ positions, base, F0 = ctx.saved_tensors[0], ctx.saved_base, ctx.saved_F0
+ _kernels.rope_2d( grad_res, positions, base, -F0 )
+ ctx.mark_dirty(grad_res)
+ return grad_res, None, None, None
+
+
+class cuRoPE2D(torch.nn.Module):
+ def __init__(self, freq=100.0, F0=1.0):
+ super().__init__()
+ self.base = freq
+ self.F0 = F0
+
+ def forward(self, tokens, positions):
+ cuRoPE2D_func.apply( tokens.transpose(1,2), positions, self.base, self.F0 )
+ return tokens
\ No newline at end of file
diff --git a/imcui/third_party/dust3r/croco/models/curope/setup.py b/imcui/third_party/dust3r/croco/models/curope/setup.py
new file mode 100644
index 0000000000000000000000000000000000000000..230632ed05e309200e8f93a3a852072333975009
--- /dev/null
+++ b/imcui/third_party/dust3r/croco/models/curope/setup.py
@@ -0,0 +1,34 @@
+# Copyright (C) 2022-present Naver Corporation. All rights reserved.
+# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+
+from setuptools import setup
+from torch import cuda
+from torch.utils.cpp_extension import BuildExtension, CUDAExtension
+
+# compile for all possible CUDA architectures
+all_cuda_archs = cuda.get_gencode_flags().replace('compute=','arch=').split()
+# alternatively, you can list cuda archs that you want, eg:
+# all_cuda_archs = [
+ # '-gencode', 'arch=compute_70,code=sm_70',
+ # '-gencode', 'arch=compute_75,code=sm_75',
+ # '-gencode', 'arch=compute_80,code=sm_80',
+ # '-gencode', 'arch=compute_86,code=sm_86'
+# ]
+
+setup(
+ name = 'curope',
+ ext_modules = [
+ CUDAExtension(
+ name='curope',
+ sources=[
+ "curope.cpp",
+ "kernels.cu",
+ ],
+ extra_compile_args = dict(
+ nvcc=['-O3','--ptxas-options=-v',"--use_fast_math"]+all_cuda_archs,
+ cxx=['-O3'])
+ )
+ ],
+ cmdclass = {
+ 'build_ext': BuildExtension
+ })
diff --git a/imcui/third_party/dust3r/croco/models/dpt_block.py b/imcui/third_party/dust3r/croco/models/dpt_block.py
new file mode 100644
index 0000000000000000000000000000000000000000..d4ddfb74e2769ceca88720d4c730e00afd71c763
--- /dev/null
+++ b/imcui/third_party/dust3r/croco/models/dpt_block.py
@@ -0,0 +1,450 @@
+# Copyright (C) 2022-present Naver Corporation. All rights reserved.
+# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+
+# --------------------------------------------------------
+# DPT head for ViTs
+# --------------------------------------------------------
+# References:
+# https://github.com/isl-org/DPT
+# https://github.com/EPFL-VILAB/MultiMAE/blob/main/multimae/output_adapters.py
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange, repeat
+from typing import Union, Tuple, Iterable, List, Optional, Dict
+
+def pair(t):
+ return t if isinstance(t, tuple) else (t, t)
+
+def make_scratch(in_shape, out_shape, groups=1, expand=False):
+ scratch = nn.Module()
+
+ out_shape1 = out_shape
+ out_shape2 = out_shape
+ out_shape3 = out_shape
+ out_shape4 = out_shape
+ if expand == True:
+ out_shape1 = out_shape
+ out_shape2 = out_shape * 2
+ out_shape3 = out_shape * 4
+ out_shape4 = out_shape * 8
+
+ scratch.layer1_rn = nn.Conv2d(
+ in_shape[0],
+ out_shape1,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias=False,
+ groups=groups,
+ )
+ scratch.layer2_rn = nn.Conv2d(
+ in_shape[1],
+ out_shape2,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias=False,
+ groups=groups,
+ )
+ scratch.layer3_rn = nn.Conv2d(
+ in_shape[2],
+ out_shape3,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias=False,
+ groups=groups,
+ )
+ scratch.layer4_rn = nn.Conv2d(
+ in_shape[3],
+ out_shape4,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias=False,
+ groups=groups,
+ )
+
+ scratch.layer_rn = nn.ModuleList([
+ scratch.layer1_rn,
+ scratch.layer2_rn,
+ scratch.layer3_rn,
+ scratch.layer4_rn,
+ ])
+
+ return scratch
+
+class ResidualConvUnit_custom(nn.Module):
+ """Residual convolution module."""
+
+ def __init__(self, features, activation, bn):
+ """Init.
+ Args:
+ features (int): number of features
+ """
+ super().__init__()
+
+ self.bn = bn
+
+ self.groups = 1
+
+ self.conv1 = nn.Conv2d(
+ features,
+ features,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias=not self.bn,
+ groups=self.groups,
+ )
+
+ self.conv2 = nn.Conv2d(
+ features,
+ features,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias=not self.bn,
+ groups=self.groups,
+ )
+
+ if self.bn == True:
+ self.bn1 = nn.BatchNorm2d(features)
+ self.bn2 = nn.BatchNorm2d(features)
+
+ self.activation = activation
+
+ self.skip_add = nn.quantized.FloatFunctional()
+
+ def forward(self, x):
+ """Forward pass.
+ Args:
+ x (tensor): input
+ Returns:
+ tensor: output
+ """
+
+ out = self.activation(x)
+ out = self.conv1(out)
+ if self.bn == True:
+ out = self.bn1(out)
+
+ out = self.activation(out)
+ out = self.conv2(out)
+ if self.bn == True:
+ out = self.bn2(out)
+
+ if self.groups > 1:
+ out = self.conv_merge(out)
+
+ return self.skip_add.add(out, x)
+
+class FeatureFusionBlock_custom(nn.Module):
+ """Feature fusion block."""
+
+ def __init__(
+ self,
+ features,
+ activation,
+ deconv=False,
+ bn=False,
+ expand=False,
+ align_corners=True,
+ width_ratio=1,
+ ):
+ """Init.
+ Args:
+ features (int): number of features
+ """
+ super(FeatureFusionBlock_custom, self).__init__()
+ self.width_ratio = width_ratio
+
+ self.deconv = deconv
+ self.align_corners = align_corners
+
+ self.groups = 1
+
+ self.expand = expand
+ out_features = features
+ if self.expand == True:
+ out_features = features // 2
+
+ self.out_conv = nn.Conv2d(
+ features,
+ out_features,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ bias=True,
+ groups=1,
+ )
+
+ self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn)
+ self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn)
+
+ self.skip_add = nn.quantized.FloatFunctional()
+
+ def forward(self, *xs):
+ """Forward pass.
+ Returns:
+ tensor: output
+ """
+ output = xs[0]
+
+ if len(xs) == 2:
+ res = self.resConfUnit1(xs[1])
+ if self.width_ratio != 1:
+ res = F.interpolate(res, size=(output.shape[2], output.shape[3]), mode='bilinear')
+
+ output = self.skip_add.add(output, res)
+ # output += res
+
+ output = self.resConfUnit2(output)
+
+ if self.width_ratio != 1:
+ # and output.shape[3] < self.width_ratio * output.shape[2]
+ #size=(image.shape[])
+ if (output.shape[3] / output.shape[2]) < (2 / 3) * self.width_ratio:
+ shape = 3 * output.shape[3]
+ else:
+ shape = int(self.width_ratio * 2 * output.shape[2])
+ output = F.interpolate(output, size=(2* output.shape[2], shape), mode='bilinear')
+ else:
+ output = nn.functional.interpolate(output, scale_factor=2,
+ mode="bilinear", align_corners=self.align_corners)
+ output = self.out_conv(output)
+ return output
+
+def make_fusion_block(features, use_bn, width_ratio=1):
+ return FeatureFusionBlock_custom(
+ features,
+ nn.ReLU(False),
+ deconv=False,
+ bn=use_bn,
+ expand=False,
+ align_corners=True,
+ width_ratio=width_ratio,
+ )
+
+class Interpolate(nn.Module):
+ """Interpolation module."""
+
+ def __init__(self, scale_factor, mode, align_corners=False):
+ """Init.
+ Args:
+ scale_factor (float): scaling
+ mode (str): interpolation mode
+ """
+ super(Interpolate, self).__init__()
+
+ self.interp = nn.functional.interpolate
+ self.scale_factor = scale_factor
+ self.mode = mode
+ self.align_corners = align_corners
+
+ def forward(self, x):
+ """Forward pass.
+ Args:
+ x (tensor): input
+ Returns:
+ tensor: interpolated data
+ """
+
+ x = self.interp(
+ x,
+ scale_factor=self.scale_factor,
+ mode=self.mode,
+ align_corners=self.align_corners,
+ )
+
+ return x
+
+class DPTOutputAdapter(nn.Module):
+ """DPT output adapter.
+
+ :param num_channels: Number of output channels
+ :param stride_level: Stride level compared to the full-sized image.
+ E.g. 4 for 1/4th the size of the image.
+ :param patch_size_full: Int or tuple of the patch size over the full image size.
+ Patch size for smaller inputs will be computed accordingly.
+ :param hooks: Index of intermediate layers
+ :param layer_dims: Dimension of intermediate layers
+ :param feature_dim: Feature dimension
+ :param last_dim: out_channels/in_channels for the last two Conv2d when head_type == regression
+ :param use_bn: If set to True, activates batch norm
+ :param dim_tokens_enc: Dimension of tokens coming from encoder
+ """
+
+ def __init__(self,
+ num_channels: int = 1,
+ stride_level: int = 1,
+ patch_size: Union[int, Tuple[int, int]] = 16,
+ main_tasks: Iterable[str] = ('rgb',),
+ hooks: List[int] = [2, 5, 8, 11],
+ layer_dims: List[int] = [96, 192, 384, 768],
+ feature_dim: int = 256,
+ last_dim: int = 32,
+ use_bn: bool = False,
+ dim_tokens_enc: Optional[int] = None,
+ head_type: str = 'regression',
+ output_width_ratio=1,
+ **kwargs):
+ super().__init__()
+ self.num_channels = num_channels
+ self.stride_level = stride_level
+ self.patch_size = pair(patch_size)
+ self.main_tasks = main_tasks
+ self.hooks = hooks
+ self.layer_dims = layer_dims
+ self.feature_dim = feature_dim
+ self.dim_tokens_enc = dim_tokens_enc * len(self.main_tasks) if dim_tokens_enc is not None else None
+ self.head_type = head_type
+
+ # Actual patch height and width, taking into account stride of input
+ self.P_H = max(1, self.patch_size[0] // stride_level)
+ self.P_W = max(1, self.patch_size[1] // stride_level)
+
+ self.scratch = make_scratch(layer_dims, feature_dim, groups=1, expand=False)
+
+ self.scratch.refinenet1 = make_fusion_block(feature_dim, use_bn, output_width_ratio)
+ self.scratch.refinenet2 = make_fusion_block(feature_dim, use_bn, output_width_ratio)
+ self.scratch.refinenet3 = make_fusion_block(feature_dim, use_bn, output_width_ratio)
+ self.scratch.refinenet4 = make_fusion_block(feature_dim, use_bn, output_width_ratio)
+
+ if self.head_type == 'regression':
+ # The "DPTDepthModel" head
+ self.head = nn.Sequential(
+ nn.Conv2d(feature_dim, feature_dim // 2, kernel_size=3, stride=1, padding=1),
+ Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
+ nn.Conv2d(feature_dim // 2, last_dim, kernel_size=3, stride=1, padding=1),
+ nn.ReLU(True),
+ nn.Conv2d(last_dim, self.num_channels, kernel_size=1, stride=1, padding=0)
+ )
+ elif self.head_type == 'semseg':
+ # The "DPTSegmentationModel" head
+ self.head = nn.Sequential(
+ nn.Conv2d(feature_dim, feature_dim, kernel_size=3, padding=1, bias=False),
+ nn.BatchNorm2d(feature_dim) if use_bn else nn.Identity(),
+ nn.ReLU(True),
+ nn.Dropout(0.1, False),
+ nn.Conv2d(feature_dim, self.num_channels, kernel_size=1),
+ Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
+ )
+ else:
+ raise ValueError('DPT head_type must be "regression" or "semseg".')
+
+ if self.dim_tokens_enc is not None:
+ self.init(dim_tokens_enc=dim_tokens_enc)
+
+ def init(self, dim_tokens_enc=768):
+ """
+ Initialize parts of decoder that are dependent on dimension of encoder tokens.
+ Should be called when setting up MultiMAE.
+
+ :param dim_tokens_enc: Dimension of tokens coming from encoder
+ """
+ #print(dim_tokens_enc)
+
+ # Set up activation postprocessing layers
+ if isinstance(dim_tokens_enc, int):
+ dim_tokens_enc = 4 * [dim_tokens_enc]
+
+ self.dim_tokens_enc = [dt * len(self.main_tasks) for dt in dim_tokens_enc]
+
+ self.act_1_postprocess = nn.Sequential(
+ nn.Conv2d(
+ in_channels=self.dim_tokens_enc[0],
+ out_channels=self.layer_dims[0],
+ kernel_size=1, stride=1, padding=0,
+ ),
+ nn.ConvTranspose2d(
+ in_channels=self.layer_dims[0],
+ out_channels=self.layer_dims[0],
+ kernel_size=4, stride=4, padding=0,
+ bias=True, dilation=1, groups=1,
+ )
+ )
+
+ self.act_2_postprocess = nn.Sequential(
+ nn.Conv2d(
+ in_channels=self.dim_tokens_enc[1],
+ out_channels=self.layer_dims[1],
+ kernel_size=1, stride=1, padding=0,
+ ),
+ nn.ConvTranspose2d(
+ in_channels=self.layer_dims[1],
+ out_channels=self.layer_dims[1],
+ kernel_size=2, stride=2, padding=0,
+ bias=True, dilation=1, groups=1,
+ )
+ )
+
+ self.act_3_postprocess = nn.Sequential(
+ nn.Conv2d(
+ in_channels=self.dim_tokens_enc[2],
+ out_channels=self.layer_dims[2],
+ kernel_size=1, stride=1, padding=0,
+ )
+ )
+
+ self.act_4_postprocess = nn.Sequential(
+ nn.Conv2d(
+ in_channels=self.dim_tokens_enc[3],
+ out_channels=self.layer_dims[3],
+ kernel_size=1, stride=1, padding=0,
+ ),
+ nn.Conv2d(
+ in_channels=self.layer_dims[3],
+ out_channels=self.layer_dims[3],
+ kernel_size=3, stride=2, padding=1,
+ )
+ )
+
+ self.act_postprocess = nn.ModuleList([
+ self.act_1_postprocess,
+ self.act_2_postprocess,
+ self.act_3_postprocess,
+ self.act_4_postprocess
+ ])
+
+ def adapt_tokens(self, encoder_tokens):
+ # Adapt tokens
+ x = []
+ x.append(encoder_tokens[:, :])
+ x = torch.cat(x, dim=-1)
+ return x
+
+ def forward(self, encoder_tokens: List[torch.Tensor], image_size):
+ #input_info: Dict):
+ assert self.dim_tokens_enc is not None, 'Need to call init(dim_tokens_enc) function first'
+ H, W = image_size
+
+ # Number of patches in height and width
+ N_H = H // (self.stride_level * self.P_H)
+ N_W = W // (self.stride_level * self.P_W)
+
+ # Hook decoder onto 4 layers from specified ViT layers
+ layers = [encoder_tokens[hook] for hook in self.hooks]
+
+ # Extract only task-relevant tokens and ignore global tokens.
+ layers = [self.adapt_tokens(l) for l in layers]
+
+ # Reshape tokens to spatial representation
+ layers = [rearrange(l, 'b (nh nw) c -> b c nh nw', nh=N_H, nw=N_W) for l in layers]
+
+ layers = [self.act_postprocess[idx](l) for idx, l in enumerate(layers)]
+ # Project layers to chosen feature dim
+ layers = [self.scratch.layer_rn[idx](l) for idx, l in enumerate(layers)]
+
+ # Fuse layers using refinement stages
+ path_4 = self.scratch.refinenet4(layers[3])
+ path_3 = self.scratch.refinenet3(path_4, layers[2])
+ path_2 = self.scratch.refinenet2(path_3, layers[1])
+ path_1 = self.scratch.refinenet1(path_2, layers[0])
+
+ # Output head
+ out = self.head(path_1)
+
+ return out
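+
+
+if __name__ == '__main__':
+    # Illustrative sketch (not part of the upstream code): run the DPT head on 12 layers
+    # of random ViT-Base-sized tokens for a 224x224 input (14x14 = 196 patches), with the
+    # default 16x16 patch size and hooks [2, 5, 8, 11].
+    dpt = DPTOutputAdapter(num_channels=1, hooks=[2, 5, 8, 11],
+                           layer_dims=[96, 192, 384, 768], feature_dim=256)
+    dpt.init(dim_tokens_enc=768)  # must be called before forward()
+    tokens = [torch.randn(1, 196, 768) for _ in range(12)]
+    out = dpt(tokens, image_size=(224, 224))
+    print(out.shape)  # expected: (1, 1, 224, 224) with the default 'regression' head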
diff --git a/imcui/third_party/dust3r/croco/models/head_downstream.py b/imcui/third_party/dust3r/croco/models/head_downstream.py
new file mode 100644
index 0000000000000000000000000000000000000000..bd40c91ba244d6c3522c6efd4ed4d724b7bdc650
--- /dev/null
+++ b/imcui/third_party/dust3r/croco/models/head_downstream.py
@@ -0,0 +1,58 @@
+# Copyright (C) 2022-present Naver Corporation. All rights reserved.
+# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+
+# --------------------------------------------------------
+# Heads for downstream tasks
+# --------------------------------------------------------
+
+"""
+A head is a module where the __init__ defines only the head hyperparameters.
+A method setup(croconet) takes a CroCoNet and sets all layers according to the head and croconet attributes.
+The forward takes the features as well as a dictionary img_info containing the keys 'width' and 'height'
+"""
+
+import torch
+import torch.nn as nn
+from .dpt_block import DPTOutputAdapter
+
+
+class PixelwiseTaskWithDPT(nn.Module):
+ """ DPT module for CroCo.
+ by default, hooks_idx will be equal to:
+ * for encoder-only: 4 equally spread layers
+ * for encoder+decoder: last encoder + 3 equally spread layers of the decoder
+ """
+
+ def __init__(self, *, hooks_idx=None, layer_dims=[96,192,384,768],
+ output_width_ratio=1, num_channels=1, postprocess=None, **kwargs):
+ super(PixelwiseTaskWithDPT, self).__init__()
+ self.return_all_blocks = True # backbone needs to return all layers
+ self.postprocess = postprocess
+ self.output_width_ratio = output_width_ratio
+ self.num_channels = num_channels
+ self.hooks_idx = hooks_idx
+ self.layer_dims = layer_dims
+
+ def setup(self, croconet):
+ dpt_args = {'output_width_ratio': self.output_width_ratio, 'num_channels': self.num_channels}
+ if self.hooks_idx is None:
+ if hasattr(croconet, 'dec_blocks'): # encoder + decoder
+ step = {8: 3, 12: 4, 24: 8}[croconet.dec_depth]
+ hooks_idx = [croconet.dec_depth+croconet.enc_depth-1-i*step for i in range(3,-1,-1)]
+ else: # encoder only
+ step = croconet.enc_depth//4
+ hooks_idx = [croconet.enc_depth-1-i*step for i in range(3,-1,-1)]
+ self.hooks_idx = hooks_idx
+ print(f' PixelwiseTaskWithDPT: automatically setting hook_idxs={self.hooks_idx}')
+ dpt_args['hooks'] = self.hooks_idx
+ dpt_args['layer_dims'] = self.layer_dims
+ self.dpt = DPTOutputAdapter(**dpt_args)
+ dim_tokens = [croconet.enc_embed_dim if hook<croconet.enc_depth else croconet.dec_embed_dim for hook in self.hooks_idx]
+ self.dpt.init(dim_tokens_enc=dim_tokens)
+ if n_cls_token>0:
+ pos_embed = np.concatenate([np.zeros([n_cls_token, embed_dim]), pos_embed], axis=0)
+ return pos_embed
+
+
+def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
+ assert embed_dim % 2 == 0
+
+ # use half of dimensions to encode grid_h
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
+
+ emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
+ return emb
+
+
+def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
+ """
+ embed_dim: output dimension for each position
+ pos: a list of positions to be encoded: size (M,)
+ out: (M, D)
+ """
+ assert embed_dim % 2 == 0
+ omega = np.arange(embed_dim // 2, dtype=float)
+ omega /= embed_dim / 2.
+ omega = 1. / 10000**omega # (D/2,)
+
+ pos = pos.reshape(-1) # (M,)
+ out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
+
+ emb_sin = np.sin(out) # (M, D/2)
+ emb_cos = np.cos(out) # (M, D/2)
+
+ emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
+ return emb
+
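+# For example, get_2d_sincos_pos_embed(768, 14) returns a (14*14, 768) array: one 768-d
+# embedding per patch token of a 224x224 image split into 16x16 patches.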
+
+# --------------------------------------------------------
+# Interpolate position embeddings for high-resolution
+# References:
+# MAE: https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
+# DeiT: https://github.com/facebookresearch/deit
+# --------------------------------------------------------
+def interpolate_pos_embed(model, checkpoint_model):
+ if 'pos_embed' in checkpoint_model:
+ pos_embed_checkpoint = checkpoint_model['pos_embed']
+ embedding_size = pos_embed_checkpoint.shape[-1]
+ num_patches = model.patch_embed.num_patches
+ num_extra_tokens = model.pos_embed.shape[-2] - num_patches
+ # height (== width) for the checkpoint position embedding
+ orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
+ # height (== width) for the new position embedding
+ new_size = int(num_patches ** 0.5)
+ # class_token and dist_token are kept unchanged
+ if orig_size != new_size:
+ print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
+ extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
+ # only the position tokens are interpolated
+ pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
+ pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
+ pos_tokens = torch.nn.functional.interpolate(
+ pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
+ pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
+ new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
+ checkpoint_model['pos_embed'] = new_pos_embed
+
+
+#----------------------------------------------------------
+# RoPE2D: RoPE implementation in 2D
+#----------------------------------------------------------
+
+try:
+ from models.curope import cuRoPE2D
+ RoPE2D = cuRoPE2D
+except ImportError:
+ print('Warning, cannot find cuda-compiled version of RoPE2D, using a slow pytorch version instead')
+
+ class RoPE2D(torch.nn.Module):
+
+ def __init__(self, freq=100.0, F0=1.0):
+ super().__init__()
+ self.base = freq
+ self.F0 = F0
+ self.cache = {}
+
+ def get_cos_sin(self, D, seq_len, device, dtype):
+ if (D,seq_len,device,dtype) not in self.cache:
+ inv_freq = 1.0 / (self.base ** (torch.arange(0, D, 2).float().to(device) / D))
+ t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype)
+ freqs = torch.einsum("i,j->ij", t, inv_freq).to(dtype)
+ freqs = torch.cat((freqs, freqs), dim=-1)
+ cos = freqs.cos() # (Seq, Dim)
+ sin = freqs.sin()
+ self.cache[D,seq_len,device,dtype] = (cos,sin)
+ return self.cache[D,seq_len,device,dtype]
+
+ @staticmethod
+ def rotate_half(x):
+ x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
+ return torch.cat((-x2, x1), dim=-1)
+
+ def apply_rope1d(self, tokens, pos1d, cos, sin):
+ assert pos1d.ndim==2
+ cos = torch.nn.functional.embedding(pos1d, cos)[:, None, :, :]
+ sin = torch.nn.functional.embedding(pos1d, sin)[:, None, :, :]
+ return (tokens * cos) + (self.rotate_half(tokens) * sin)
+
+ def forward(self, tokens, positions):
+ """
+ input:
+ * tokens: batch_size x nheads x ntokens x dim
+ * positions: batch_size x ntokens x 2 (y and x position of each token)
+ output:
+ * tokens after applying RoPE2D (batch_size x nheads x ntokens x dim)
+ """
+ assert tokens.size(3)%2==0, "number of dimensions should be a multiple of two"
+ D = tokens.size(3) // 2
+ assert positions.ndim==3 and positions.shape[-1] == 2 # Batch, Seq, 2
+ cos, sin = self.get_cos_sin(D, int(positions.max())+1, tokens.device, tokens.dtype)
+ # split features into two along the feature dimension, and apply rope1d on each half
+ y, x = tokens.chunk(2, dim=-1)
+ y = self.apply_rope1d(y, positions[:,:,0], cos, sin)
+ x = self.apply_rope1d(x, positions[:,:,1], cos, sin)
+ tokens = torch.cat((y, x), dim=-1)
+ return tokens
\ No newline at end of file
diff --git a/imcui/third_party/dust3r/croco/pretrain.py b/imcui/third_party/dust3r/croco/pretrain.py
new file mode 100644
index 0000000000000000000000000000000000000000..2c45e488015ef5380c71d0381ff453fdb860759e
--- /dev/null
+++ b/imcui/third_party/dust3r/croco/pretrain.py
@@ -0,0 +1,254 @@
+# Copyright (C) 2022-present Naver Corporation. All rights reserved.
+# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+#
+# --------------------------------------------------------
+# Pre-training CroCo
+# --------------------------------------------------------
+# References:
+# MAE: https://github.com/facebookresearch/mae
+# DeiT: https://github.com/facebookresearch/deit
+# BEiT: https://github.com/microsoft/unilm/tree/master/beit
+# --------------------------------------------------------
+import argparse
+import datetime
+import json
+import numpy as np
+import os
+import sys
+import time
+import math
+from pathlib import Path
+from typing import Iterable
+
+import torch
+import torch.distributed as dist
+import torch.backends.cudnn as cudnn
+from torch.utils.tensorboard import SummaryWriter
+import torchvision.transforms as transforms
+import torchvision.datasets as datasets
+
+import utils.misc as misc
+from utils.misc import NativeScalerWithGradNormCount as NativeScaler
+from models.croco import CroCoNet
+from models.criterion import MaskedMSE
+from datasets.pairs_dataset import PairsDataset
+
+
+def get_args_parser():
+ parser = argparse.ArgumentParser('CroCo pre-training', add_help=False)
+ # model and criterion
+ parser.add_argument('--model', default='CroCoNet()', type=str, help="string containing the model to build")
+ parser.add_argument('--norm_pix_loss', default=1, type=int, choices=[0,1], help="apply per-patch mean/std normalization before applying the loss")
+ # dataset
+ parser.add_argument('--dataset', default='habitat_release', type=str, help="training set")
+ parser.add_argument('--transforms', default='crop224+acolor', type=str, help="transforms to apply") # in the paper, we also use some homography and rotation, but find later that they were not useful or even harmful
+ # training
+ parser.add_argument('--seed', default=0, type=int, help="Random seed")
+ parser.add_argument('--batch_size', default=64, type=int, help="Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus)")
+ parser.add_argument('--epochs', default=800, type=int, help="Maximum number of epochs for the scheduler")
+ parser.add_argument('--max_epoch', default=400, type=int, help="Stop training at this epoch")
+ parser.add_argument('--accum_iter', default=1, type=int, help="Accumulate gradient iterations (for increasing the effective batch size under memory constraints)")
+ parser.add_argument('--weight_decay', type=float, default=0.05, help="weight decay (default: 0.05)")
+ parser.add_argument('--lr', type=float, default=None, metavar='LR', help='learning rate (absolute lr)')
+ parser.add_argument('--blr', type=float, default=1.5e-4, metavar='LR', help='base learning rate: absolute_lr = base_lr * total_batch_size / 256')
+ parser.add_argument('--min_lr', type=float, default=0., metavar='LR', help='lower lr bound for cyclic schedulers that hit 0')
+ parser.add_argument('--warmup_epochs', type=int, default=40, metavar='N', help='epochs to warmup LR')
+ parser.add_argument('--amp', type=int, default=1, choices=[0,1], help="Use Automatic Mixed Precision for pretraining")
+ # others
+ parser.add_argument('--num_workers', default=8, type=int)
+ parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes')
+ parser.add_argument('--local_rank', default=-1, type=int)
+ parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
+ parser.add_argument('--save_freq', default=1, type=int, help='frequency (number of epochs) at which to save a checkpoint in checkpoint-last.pth')
+ parser.add_argument('--keep_freq', default=20, type=int, help='frequency (number of epochs) at which to save a checkpoint in checkpoint-%d.pth')
+ parser.add_argument('--print_freq', default=20, type=int, help='frequency (number of iterations) at which to print info while training')
+ # paths
+ parser.add_argument('--output_dir', default='./output/', type=str, help="path where to save the output")
+ parser.add_argument('--data_dir', default='./data/', type=str, help="path where data are stored")
+ return parser
+
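+# Example invocation (illustrative; single process, placeholder paths):
+#   python pretrain.py --dataset habitat_release --data_dir ./data/ --output_dir ./output/ \
+#       --model "CroCoNet()" --batch_size 64 --epochs 800 --max_epoch 400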
+
+
+
+def main(args):
+ misc.init_distributed_mode(args)
+ global_rank = misc.get_rank()
+ world_size = misc.get_world_size()
+
+ print("output_dir: "+args.output_dir)
+ if args.output_dir:
+ Path(args.output_dir).mkdir(parents=True, exist_ok=True)
+
+ # auto resume
+ last_ckpt_fname = os.path.join(args.output_dir, f'checkpoint-last.pth')
+ args.resume = last_ckpt_fname if os.path.isfile(last_ckpt_fname) else None
+
+ print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__))))
+ print("{}".format(args).replace(', ', ',\n'))
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ device = torch.device(device)
+
+ # fix the seed
+ seed = args.seed + misc.get_rank()
+ torch.manual_seed(seed)
+ np.random.seed(seed)
+
+ cudnn.benchmark = True
+
+ ## training dataset and loader
+ print('Building dataset for {:s} with transforms {:s}'.format(args.dataset, args.transforms))
+ dataset = PairsDataset(args.dataset, trfs=args.transforms, data_dir=args.data_dir)
+ if world_size>1:
+ sampler_train = torch.utils.data.DistributedSampler(
+ dataset, num_replicas=world_size, rank=global_rank, shuffle=True
+ )
+ print("Sampler_train = %s" % str(sampler_train))
+ else:
+ sampler_train = torch.utils.data.RandomSampler(dataset)
+ data_loader_train = torch.utils.data.DataLoader(
+ dataset, sampler=sampler_train,
+ batch_size=args.batch_size,
+ num_workers=args.num_workers,
+ pin_memory=True,
+ drop_last=True,
+ )
+
+ ## model
+ print('Loading model: {:s}'.format(args.model))
+ model = eval(args.model)
+ print('Loading criterion: MaskedMSE(norm_pix_loss={:s})'.format(str(bool(args.norm_pix_loss))))
+ criterion = MaskedMSE(norm_pix_loss=bool(args.norm_pix_loss))
+
+ model.to(device)
+ model_without_ddp = model
+ print("Model = %s" % str(model_without_ddp))
+
+ eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size()
+ if args.lr is None: # only base_lr is specified
+ args.lr = args.blr * eff_batch_size / 256
+ print("base lr: %.2e" % (args.lr * 256 / eff_batch_size))
+ print("actual lr: %.2e" % args.lr)
+ print("accumulate grad iterations: %d" % args.accum_iter)
+ print("effective batch size: %d" % eff_batch_size)
+
+ if args.distributed:
+ model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True, static_graph=True)
+ model_without_ddp = model.module
+
+ param_groups = misc.get_parameter_groups(model_without_ddp, args.weight_decay) # following timm: set wd as 0 for bias and norm layers
+ optimizer = torch.optim.AdamW(param_groups, lr=args.lr, betas=(0.9, 0.95))
+ print(optimizer)
+ loss_scaler = NativeScaler()
+
+ misc.load_model(args=args, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler)
+
+ if global_rank == 0 and args.output_dir is not None:
+ log_writer = SummaryWriter(log_dir=args.output_dir)
+ else:
+ log_writer = None
+
+ print(f"Start training until {args.max_epoch} epochs")
+ start_time = time.time()
+ for epoch in range(args.start_epoch, args.max_epoch):
+ if world_size>1:
+ data_loader_train.sampler.set_epoch(epoch)
+
+ train_stats = train_one_epoch(
+ model, criterion, data_loader_train,
+ optimizer, device, epoch, loss_scaler,
+ log_writer=log_writer,
+ args=args
+ )
+
+ if args.output_dir and epoch % args.save_freq == 0 :
+ misc.save_model(
+ args=args, model_without_ddp=model_without_ddp, optimizer=optimizer,
+ loss_scaler=loss_scaler, epoch=epoch, fname='last')
+
+ if args.output_dir and (epoch % args.keep_freq == 0 or epoch + 1 == args.max_epoch) and (epoch>0 or args.max_epoch==1):
+ misc.save_model(
+ args=args, model_without_ddp=model_without_ddp, optimizer=optimizer,
+ loss_scaler=loss_scaler, epoch=epoch)
+
+ log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
+ 'epoch': epoch,}
+
+ if args.output_dir and misc.is_main_process():
+ if log_writer is not None:
+ log_writer.flush()
+ with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f:
+ f.write(json.dumps(log_stats) + "\n")
+
+ total_time = time.time() - start_time
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+ print('Training time {}'.format(total_time_str))
+
+
+
+
+def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
+ data_loader: Iterable, optimizer: torch.optim.Optimizer,
+ device: torch.device, epoch: int, loss_scaler,
+ log_writer=None,
+ args=None):
+ model.train(True)
+ metric_logger = misc.MetricLogger(delimiter=" ")
+ metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}'))
+ header = 'Epoch: [{}]'.format(epoch)
+ accum_iter = args.accum_iter
+
+ optimizer.zero_grad()
+
+ if log_writer is not None:
+ print('log_dir: {}'.format(log_writer.log_dir))
+
+ for data_iter_step, (image1, image2) in enumerate(metric_logger.log_every(data_loader, args.print_freq, header)):
+
+ # we use a per iteration lr scheduler
+ if data_iter_step % accum_iter == 0:
+ misc.adjust_learning_rate(optimizer, data_iter_step / len(data_loader) + epoch, args)
+
+ image1 = image1.to(device, non_blocking=True)
+ image2 = image2.to(device, non_blocking=True)
+ with torch.cuda.amp.autocast(enabled=bool(args.amp)):
+ out, mask, target = model(image1, image2)
+ loss = criterion(out, mask, target)
+
+ loss_value = loss.item()
+
+ if not math.isfinite(loss_value):
+ print("Loss is {}, stopping training".format(loss_value))
+ sys.exit(1)
+
+ loss /= accum_iter
+ loss_scaler(loss, optimizer, parameters=model.parameters(),
+ update_grad=(data_iter_step + 1) % accum_iter == 0)
+ if (data_iter_step + 1) % accum_iter == 0:
+ optimizer.zero_grad()
+
+ torch.cuda.synchronize()
+
+ metric_logger.update(loss=loss_value)
+
+ lr = optimizer.param_groups[0]["lr"]
+ metric_logger.update(lr=lr)
+
+ loss_value_reduce = misc.all_reduce_mean(loss_value)
+ if log_writer is not None and ((data_iter_step + 1) % (accum_iter*args.print_freq)) == 0:
+ # x-axis is based on epoch_1000x in the tensorboard, calibrating differences curves when batch size changes
+ epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000)
+ log_writer.add_scalar('train_loss', loss_value_reduce, epoch_1000x)
+ log_writer.add_scalar('lr', lr, epoch_1000x)
+
+ # gather the stats from all processes
+ metric_logger.synchronize_between_processes()
+ print("Averaged stats:", metric_logger)
+ return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
+
+
+
+if __name__ == '__main__':
+ args = get_args_parser()
+ args = args.parse_args()
+ main(args)
diff --git a/imcui/third_party/dust3r/croco/stereoflow/augmentor.py b/imcui/third_party/dust3r/croco/stereoflow/augmentor.py
new file mode 100644
index 0000000000000000000000000000000000000000..69e6117151988d94cbc4b385e0d88e982133bf10
--- /dev/null
+++ b/imcui/third_party/dust3r/croco/stereoflow/augmentor.py
@@ -0,0 +1,290 @@
+# Copyright (C) 2022-present Naver Corporation. All rights reserved.
+# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+
+# --------------------------------------------------------
+# Data augmentation for training stereo and flow
+# --------------------------------------------------------
+
+# References
+# https://github.com/autonomousvision/unimatch/blob/master/dataloader/stereo/transforms.py
+# https://github.com/autonomousvision/unimatch/blob/master/dataloader/flow/transforms.py
+
+
+import numpy as np
+import random
+from PIL import Image
+
+import cv2
+cv2.setNumThreads(0)
+cv2.ocl.setUseOpenCL(False)
+
+import torch
+from torchvision.transforms import ColorJitter
+import torchvision.transforms.functional as FF
+
+class StereoAugmentor(object):
+
+ def __init__(self, crop_size, scale_prob=0.5, scale_xonly=True, lhth=800., lminscale=0.0, lmaxscale=1.0, hminscale=-0.2, hmaxscale=0.4, scale_interp_nearest=True, rightjitterprob=0.5, v_flip_prob=0.5, color_aug_asym=True, color_choice_prob=0.5):
+ self.crop_size = crop_size
+ self.scale_prob = scale_prob
+ self.scale_xonly = scale_xonly
+ self.lhth = lhth
+ self.lminscale = lminscale
+ self.lmaxscale = lmaxscale
+ self.hminscale = hminscale
+ self.hmaxscale = hmaxscale
+ self.scale_interp_nearest = scale_interp_nearest
+ self.rightjitterprob = rightjitterprob
+ self.v_flip_prob = v_flip_prob
+ self.color_aug_asym = color_aug_asym
+ self.color_choice_prob = color_choice_prob
+
+ def _random_scale(self, img1, img2, disp):
+ ch,cw = self.crop_size
+ h,w = img1.shape[:2]
+ if self.scale_prob>0. and np.random.rand()<self.scale_prob:
+ if clip_scale>1.:
+ scale_x = clip_scale
+ scale_y = scale_x if not self.scale_xonly else 1.0
+ img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
+ img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
+ disp = cv2.resize(disp, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR if not self.scale_interp_nearest else cv2.INTER_NEAREST) * scale_x
+ return img1, img2, disp
+
+ def _random_crop(self, img1, img2, disp):
+ h,w = img1.shape[:2]
+ ch,cw = self.crop_size
+ assert ch<=h and cw<=w, (img1.shape, h,w,ch,cw)
+ offset_x = np.random.randint(w - cw + 1)
+ offset_y = np.random.randint(h - ch + 1)
+ img1 = img1[offset_y:offset_y+ch,offset_x:offset_x+cw]
+ img2 = img2[offset_y:offset_y+ch,offset_x:offset_x+cw]
+ disp = disp[offset_y:offset_y+ch,offset_x:offset_x+cw]
+ return img1, img2, disp
+
+ def _random_vflip(self, img1, img2, disp):
+ # vertical flip
+ if self.v_flip_prob>0 and np.random.rand() < self.v_flip_prob:
+ img1 = np.copy(np.flipud(img1))
+ img2 = np.copy(np.flipud(img2))
+ disp = np.copy(np.flipud(disp))
+ return img1, img2, disp
+
+ def _random_rotate_shift_right(self, img2):
+ if self.rightjitterprob>0. and np.random.rand()<self.rightjitterprob:
+ v = (xx > 0) & (xx < wd1) & (yy > 0) & (yy < ht1)
+ xx = xx[v]
+ yy = yy[v]
+ flow1 = flow1[v]
+
+ flow = np.inf * np.ones([ht1, wd1, 2], dtype=np.float32) # invalid value everywhere, before we fill it with the correct ones
+ flow[yy, xx] = flow1
+ return flow
+
+ def spatial_transform(self, img1, img2, flow, dname):
+
+ if np.random.rand() < self.spatial_aug_prob:
+ # randomly sample scale
+ ht, wd = img1.shape[:2]
+ clip_min_scale = np.maximum(
+ (self.crop_size[0] + 8) / float(ht),
+ (self.crop_size[1] + 8) / float(wd))
+ min_scale, max_scale = self.min_scale, self.max_scale
+ scale = 2 ** np.random.uniform(self.min_scale, self.max_scale)
+ scale_x = scale
+ scale_y = scale
+ if np.random.rand() < self.stretch_prob:
+ scale_x *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch)
+ scale_y *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch)
+ scale_x = np.clip(scale_x, clip_min_scale, None)
+ scale_y = np.clip(scale_y, clip_min_scale, None)
+ # rescale the images
+ img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
+ img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
+ flow = self._resize_flow(flow, scale_x, scale_y, factor=2.0 if dname=='Spring' else 1.0)
+ elif dname=="Spring":
+ flow = self._resize_flow(flow, 1.0, 1.0, factor=2.0)
+
+ if self.h_flip_prob>0. and np.random.rand() < self.h_flip_prob: # h-flip
+ img1 = img1[:, ::-1]
+ img2 = img2[:, ::-1]
+ flow = flow[:, ::-1] * [-1.0, 1.0]
+
+ if self.v_flip_prob>0. and np.random.rand() < self.v_flip_prob: # v-flip
+ img1 = img1[::-1, :]
+ img2 = img2[::-1, :]
+ flow = flow[::-1, :] * [1.0, -1.0]
+
+ # In case no cropping
+ if img1.shape[0] - self.crop_size[0] > 0:
+ y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0])
+ else:
+ y0 = 0
+ if img1.shape[1] - self.crop_size[1] > 0:
+ x0 = np.random.randint(0, img1.shape[1] - self.crop_size[1])
+ else:
+ x0 = 0
+
+ img1 = img1[y0:y0 + self.crop_size[0], x0:x0 + self.crop_size[1]]
+ img2 = img2[y0:y0 + self.crop_size[0], x0:x0 + self.crop_size[1]]
+ flow = flow[y0:y0 + self.crop_size[0], x0:x0 + self.crop_size[1]]
+
+ return img1, img2, flow
+
+ def __call__(self, img1, img2, flow, dname):
+ img1, img2, flow = self.spatial_transform(img1, img2, flow, dname)
+ img1, img2 = self.color_transform(img1, img2)
+ img1 = np.ascontiguousarray(img1)
+ img2 = np.ascontiguousarray(img2)
+ flow = np.ascontiguousarray(flow)
+ return img1, img2, flow
\ No newline at end of file
diff --git a/imcui/third_party/dust3r/croco/stereoflow/criterion.py b/imcui/third_party/dust3r/croco/stereoflow/criterion.py
new file mode 100644
index 0000000000000000000000000000000000000000..57792ebeeee34827b317a4d32b7445837bb33f17
--- /dev/null
+++ b/imcui/third_party/dust3r/croco/stereoflow/criterion.py
@@ -0,0 +1,251 @@
+# Copyright (C) 2022-present Naver Corporation. All rights reserved.
+# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+
+# --------------------------------------------------------
+# Losses, metrics per batch, metrics per dataset
+# --------------------------------------------------------
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+def _get_gtnorm(gt):
+ if gt.size(1)==1: # stereo
+ return gt
+ # flow
+ return torch.sqrt(torch.sum(gt**2, dim=1, keepdims=True)) # Bx1xHxW
+
+############ losses without confidence
+
+class L1Loss(nn.Module):
+
+ def __init__(self, max_gtnorm=None):
+ super().__init__()
+ self.max_gtnorm = max_gtnorm
+ self.with_conf = False
+
+ def _error(self, gt, predictions):
+ return torch.abs(gt-predictions)
+
+ def forward(self, predictions, gt, inspect=False):
+ mask = torch.isfinite(gt)
+ if self.max_gtnorm is not None:
+ mask *= _get_gtnorm(gt).expand(-1,gt.size(1),-1,-1)<self.max_gtnorm
+ if inspect:
+ return self._error(gt, predictions)
+ return self._error(gt[mask], predictions[mask]).mean()
+
+
+class LaplacianLossBounded(nn.Module): # used for CroCo-Flow ; in the equation of the paper, we have a=1/b
+ def __init__(self, max_gtnorm=10000., a=0.25, b=4.):
+ super().__init__()
+ self.max_gtnorm = max_gtnorm
+ self.with_conf = True
+ self.a, self.b = a, b
+
+ def forward(self, predictions, gt, conf):
+ mask = torch.isfinite(gt)
+ mask = mask[:,0,:,:]
+ if self.max_gtnorm is not None: mask *= _get_gtnorm(gt)[:,0,:,:]<self.max_gtnorm
+
+class LaplacianLossBounded2(nn.Module): # used for CroCo-Stereo (except for ETH3D) ; in the equation of the paper, we have a=b
+ def __init__(self, max_gtnorm=None, a=3.0, b=3.0):
+ super().__init__()
+ self.max_gtnorm = max_gtnorm
+ self.with_conf = True
+ self.a, self.b = a, b
+
+ def forward(self, predictions, gt, conf):
+ mask = torch.isfinite(gt)
+ mask = mask[:,0,:,:]
+ if self.max_gtnorm is not None: mask *= _get_gtnorm(gt)[:,0,:,:]<self.max_gtnorm
+
+############## metrics per batch
+
+class StereoMetrics(nn.Module):
+
+ def __init__(self, do_quantile=False):
+ super().__init__()
+ self.bad_ths = [0.5,1,2,3]
+ self.do_quantile = do_quantile
+
+ def forward(self, predictions, gt):
+ B = predictions.size(0)
+ metrics = {}
+ gtcopy = gt.clone()
+ mask = torch.isfinite(gtcopy)
+ gtcopy[~mask] = 999999.0 # we make a copy and put a non-infinite value, such that it does not become nan once multiplied by the mask value 0
+ Npx = mask.view(B,-1).sum(dim=1)
+ L1error = (torch.abs(gtcopy-predictions)*mask).view(B,-1)
+ L2error = (torch.square(gtcopy-predictions)*mask).view(B,-1)
+ # avgerr
+ metrics['avgerr'] = torch.mean(L1error.sum(dim=1)/Npx )
+ # rmse
+ metrics['rmse'] = torch.sqrt(L2error.sum(dim=1)/Npx).mean(dim=0)
+ # err > t for t in [0.5,1,2,3]
+ for ths in self.bad_ths:
+ metrics['bad@{:.1f}'.format(ths)] = (((L1error>ths)* mask.view(B,-1)).sum(dim=1)/Npx).mean(dim=0) * 100
+ return metrics
+
+class FlowMetrics(nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.bad_ths = [1,3,5]
+
+ def forward(self, predictions, gt):
+ B = predictions.size(0)
+ metrics = {}
+ mask = torch.isfinite(gt[:,0,:,:]) # both x and y would be infinite
+ Npx = mask.view(B,-1).sum(dim=1)
+ gtcopy = gt.clone() # to compute the L1/L2 error we need non-infinite values; the error computed at these locations will be ignored
+ gtcopy[:,0,:,:][~mask] = 999999.0
+ gtcopy[:,1,:,:][~mask] = 999999.0
+ L1error = (torch.abs(gtcopy-predictions).sum(dim=1)*mask).view(B,-1)
+ L2error = (torch.sqrt(torch.sum(torch.square(gtcopy-predictions),dim=1))*mask).view(B,-1)
+ metrics['L1err'] = torch.mean(L1error.sum(dim=1)/Npx )
+ metrics['EPE'] = torch.mean(L2error.sum(dim=1)/Npx )
+ for ths in self.bad_ths:
+ metrics['bad@{:.1f}'.format(ths)] = (((L2error>ths)* mask.view(B,-1)).sum(dim=1)/Npx).mean(dim=0) * 100
+ return metrics
+
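+# Illustrative check, not from the upstream file: on fully-finite ground truth,
+# FlowMetrics' 'EPE' reduces to the usual mean endpoint error, i.e. the mean L2 norm
+# of the per-pixel flow-error vectors.
+def _example_flow_epe():
+    flow_metrics = FlowMetrics()
+    gt = torch.randn(2, 2, 8, 8)
+    pred = gt + torch.randn(2, 2, 8, 8)
+    epe_ref = torch.linalg.norm(gt - pred, dim=1).mean()
+    return torch.allclose(flow_metrics(pred, gt)['EPE'], epe_ref)
+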
+############## metrics per dataset
+## we update the running average and keep track of the number of pixels while adding data batch per batch
+## at the beginning, call reset()
+## after each batch, call add_batch(...)
+## at the end, call get_results()
+## (a usage sketch follows right below)
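+## A minimal usage sketch of this protocol (illustrative; 'loader' and 'model' are
+## placeholders, not names from this file):
+##     dataset_metrics = FlowDatasetMetrics()
+##     dataset_metrics.reset()
+##     for img1, img2, gt, _ in loader:
+##         pred = model(img1, img2)
+##         dataset_metrics.add_batch(pred, gt)
+##     results = dataset_metrics.get_results()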
+
+class StereoDatasetMetrics(nn.Module):
+
+ def __init__(self):
+ super().__init__()
+ self.bad_ths = [0.5,1,2,3]
+
+ def reset(self):
+ self.agg_N = 0 # number of pixels so far
+ self.agg_L1err = torch.tensor(0.0) # L1 error so far
+ self.agg_Nbad = [0 for _ in self.bad_ths] # counter of bad pixels
+ self._metrics = None
+
+ def add_batch(self, predictions, gt):
+ assert predictions.size(1)==1, predictions.size()
+ assert gt.size(1)==1, gt.size()
+ if gt.size(2)==predictions.size(2)*2 and gt.size(3)==predictions.size(3)*2: # special case for Spring ...
+ L1err = torch.minimum( torch.minimum( torch.minimum(
+ torch.sum(torch.abs(gt[:,:,0::2,0::2]-predictions),dim=1),
+ torch.sum(torch.abs(gt[:,:,1::2,0::2]-predictions),dim=1)),
+ torch.sum(torch.abs(gt[:,:,0::2,1::2]-predictions),dim=1)),
+ torch.sum(torch.abs(gt[:,:,1::2,1::2]-predictions),dim=1))
+ valid = torch.isfinite(L1err)
+ else:
+ valid = torch.isfinite(gt[:,0,:,:]) # invalid disparities are non-finite
+ L1err = torch.sum(torch.abs(gt-predictions),dim=1)
+ N = valid.sum()
+ Nnew = self.agg_N + N
+ self.agg_L1err = float(self.agg_N)/Nnew * self.agg_L1err + L1err[valid].mean().cpu() * float(N)/Nnew
+ self.agg_N = Nnew
+ for i,th in enumerate(self.bad_ths):
+ self.agg_Nbad[i] += (L1err[valid]>th).sum().cpu()
+
+ def _compute_metrics(self):
+ if self._metrics is not None: return
+ out = {}
+ out['L1err'] = self.agg_L1err.item()
+ for i,th in enumerate(self.bad_ths):
+ out['bad@{:.1f}'.format(th)] = (float(self.agg_Nbad[i]) / self.agg_N).item() * 100.0
+ self._metrics = out
+
+ def get_results(self):
+ self._compute_metrics() # avoid recomputing the metrics multiple times
+ return self._metrics
+
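+# Note on the streaming aggregation in add_batch() above (and in FlowDatasetMetrics
+# below): a running mean over N_old already-seen pixels is merged with a batch mean
+# over N_new pixels as
+#     mean_total = N_old/(N_old+N_new) * mean_old + N_new/(N_old+N_new) * mean_batch,
+# which equals the exact mean over all N_old+N_new pixels without storing per-pixel errors.
+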
+class FlowDatasetMetrics(nn.Module):
+
+ def __init__(self):
+ super().__init__()
+ self.bad_ths = [0.5,1,3,5]
+ self.speed_ths = [(0,10),(10,40),(40,torch.inf)]
+
+ def reset(self):
+ self.agg_N = 0 # number of pixels so far
+ self.agg_L1err = torch.tensor(0.0) # L1 error so far
+ self.agg_L2err = torch.tensor(0.0) # L2 (=EPE) error so far
+ self.agg_Nbad = [0 for _ in self.bad_ths] # counter of bad pixels
+ self.agg_EPEspeed = [torch.tensor(0.0) for _ in self.speed_ths] # EPE per speed bin so far
+ self.agg_Nspeed = [0 for _ in self.speed_ths] # N pixels per speed bin so far
+ self._metrics = None
+ self.pairname_results = {}
+
+ def add_batch(self, predictions, gt):
+ assert predictions.size(1)==2, predictions.size()
+ assert gt.size(1)==2, gt.size()
+ if gt.size(2)==predictions.size(2)*2 and gt.size(3)==predictions.size(3)*2: # special case for Spring ...
+ L1err = torch.minimum( torch.minimum( torch.minimum(
+ torch.sum(torch.abs(gt[:,:,0::2,0::2]-predictions),dim=1),
+ torch.sum(torch.abs(gt[:,:,1::2,0::2]-predictions),dim=1)),
+ torch.sum(torch.abs(gt[:,:,0::2,1::2]-predictions),dim=1)),
+ torch.sum(torch.abs(gt[:,:,1::2,1::2]-predictions),dim=1))
+ L2err = torch.minimum( torch.minimum( torch.minimum(
+ torch.sqrt(torch.sum(torch.square(gt[:,:,0::2,0::2]-predictions),dim=1)),
+ torch.sqrt(torch.sum(torch.square(gt[:,:,1::2,0::2]-predictions),dim=1))),
+ torch.sqrt(torch.sum(torch.square(gt[:,:,0::2,1::2]-predictions),dim=1))),
+ torch.sqrt(torch.sum(torch.square(gt[:,:,1::2,1::2]-predictions),dim=1)))
+ valid = torch.isfinite(L1err)
+ gtspeed = (torch.sqrt(torch.sum(torch.square(gt[:,:,0::2,0::2]),dim=1)) + torch.sqrt(torch.sum(torch.square(gt[:,:,0::2,1::2]),dim=1)) +\
+ torch.sqrt(torch.sum(torch.square(gt[:,:,1::2,0::2]),dim=1)) + torch.sqrt(torch.sum(torch.square(gt[:,:,1::2,1::2]),dim=1)) ) / 4.0 # let's just average them
+ else:
+ valid = torch.isfinite(gt[:,0,:,:]) # both x and y would be infinite
+ L1err = torch.sum(torch.abs(gt-predictions),dim=1)
+ L2err = torch.sqrt(torch.sum(torch.square(gt-predictions),dim=1))
+ gtspeed = torch.sqrt(torch.sum(torch.square(gt),dim=1))
+ N = valid.sum()
+ Nnew = self.agg_N + N
+ self.agg_L1err = float(self.agg_N)/Nnew * self.agg_L1err + L1err[valid].mean().cpu() * float(N)/Nnew
+ self.agg_L2err = float(self.agg_N)/Nnew * self.agg_L2err + L2err[valid].mean().cpu() * float(N)/Nnew
+ self.agg_N = Nnew
+ for i,th in enumerate(self.bad_ths):
+ self.agg_Nbad[i] += (L2err[valid]>th).sum().cpu()
+ for i,(th1,th2) in enumerate(self.speed_ths):
+ vv = (gtspeed[valid]>=th1) * (gtspeed[valid]<th2)
+ self._prepare_data()
+ self._load_or_build_cache()
+
+ def prepare_data(self):
+ """
+ to be defined for each dataset
+ """
+ raise NotImplementedError
+
+ def __len__(self):
+ return len(self.pairnames) # each pairname is typically of the form (str, int1, int2)
+
+ def __getitem__(self, index):
+ pairname = self.pairnames[index]
+
+ # get filenames
+ img1name = self.pairname_to_img1name(pairname)
+ img2name = self.pairname_to_img2name(pairname)
+ flowname = self.pairname_to_flowname(pairname) if self.pairname_to_flowname is not None else None
+
+ # load images and disparities
+ img1 = _read_img(img1name)
+ img2 = _read_img(img2name)
+ flow = self.load_flow(flowname) if flowname is not None else None
+
+ # apply augmentations
+ if self.augmentor is not None:
+ img1, img2, flow = self.augmentor(img1, img2, flow, self.name)
+
+ if self.totensor:
+ img1 = img_to_tensor(img1)
+ img2 = img_to_tensor(img2)
+ if flow is not None:
+ flow = flow_to_tensor(flow)
+ else:
+ flow = torch.tensor([]) # to allow dataloader batching with the default collate_fn
+ pairname = str(pairname) # transform potential tuple to str to be able to batch it
+
+ return img1, img2, flow, pairname
+
+ def __rmul__(self, v):
+ self.rmul *= v
+ self.pairnames = v * self.pairnames
+ return self
+
+ def __str__(self):
+ return f'{self.__class__.__name__}_{self.split}'
+
+ def __repr__(self):
+ s = f'{self.__class__.__name__}(split={self.split}, augmentor={self.augmentor_str}, crop_size={str(self.crop_size)}, totensor={self.totensor})'
+ if self.rmul==1:
+ s+=f'\n\tnum pairs: {len(self.pairnames)}'
+ else:
+ s+=f'\n\tnum pairs: {len(self.pairnames)} ({len(self.pairnames)//self.rmul}x{self.rmul})'
+ return s
+
+ def _set_root(self):
+ self.root = dataset_to_root[self.name]
+ assert os.path.isdir(self.root), f"could not find root directory for dataset {self.name}: {self.root}"
+
+ def _load_or_build_cache(self):
+ cache_file = osp.join(cache_dir, self.name+'.pkl')
+ if osp.isfile(cache_file):
+ with open(cache_file, 'rb') as fid:
+ self.pairnames = pickle.load(fid)[self.split]
+ else:
+ tosave = self._build_cache()
+ os.makedirs(cache_dir, exist_ok=True)
+ with open(cache_file, 'wb') as fid:
+ pickle.dump(tosave, fid)
+ self.pairnames = tosave[self.split]
+
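+# Illustrative usage sketch, not from the upstream file (the constructor arguments are
+# assumptions; the real FlowDataset signature is defined earlier in this file):
+#
+#     ds = MPISintelDataset(split='subtrain_cleanpass')
+#     ds = 3 * ds            # __rmul__ replicates self.pairnames, oversampling 3x
+#     len(ds)                # 3x the original number of pairs
+#     img1, img2, flow, pairname = ds[0]
+#
+# On first use, each dataset builds its list of pairnames and pickles it under
+# cache_dir; later runs reload that cache instead of re-scanning the dataset root.
+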
+class TartanAirDataset(FlowDataset):
+
+ def _prepare_data(self):
+ self.name = "TartanAir"
+ self._set_root()
+ assert self.split in ['train']
+ self.pairname_to_img1name = lambda pairname: osp.join(self.root, pairname[0], 'image_left/{:06d}_left.png'.format(pairname[1]))
+ self.pairname_to_img2name = lambda pairname: osp.join(self.root, pairname[0], 'image_left/{:06d}_left.png'.format(pairname[2]))
+ self.pairname_to_flowname = lambda pairname: osp.join(self.root, pairname[0], 'flow/{:06d}_{:06d}_flow.npy'.format(pairname[1],pairname[2]))
+ self.pairname_to_str = lambda pairname: os.path.join(pairname[0][pairname[0].find('/')+1:], '{:06d}_{:06d}'.format(pairname[1], pairname[2]))
+ self.load_flow = _read_numpy_flow
+
+ def _build_cache(self):
+ seqs = sorted(os.listdir(self.root))
+ pairs = [(osp.join(s,s,difficulty,Pxxx),int(a[:6]),int(a[:6])+1) for s in seqs for difficulty in ['Easy','Hard'] for Pxxx in sorted(os.listdir(osp.join(self.root,s,s,difficulty))) for a in sorted(os.listdir(osp.join(self.root,s,s,difficulty,Pxxx,'image_left/')))[:-1]]
+ assert len(pairs)==306268, "incorrect parsing of pairs in TartanAir"
+ tosave = {'train': pairs}
+ return tosave
+
+class FlyingChairsDataset(FlowDataset):
+
+ def _prepare_data(self):
+ self.name = "FlyingChairs"
+ self._set_root()
+ assert self.split in ['train','val']
+ self.pairname_to_img1name = lambda pairname: osp.join(self.root, 'data', pairname+'_img1.ppm')
+ self.pairname_to_img2name = lambda pairname: osp.join(self.root, 'data', pairname+'_img2.ppm')
+ self.pairname_to_flowname = lambda pairname: osp.join(self.root, 'data', pairname+'_flow.flo')
+ self.pairname_to_str = lambda pairname: pairname
+ self.load_flow = _read_flo_file
+
+ def _build_cache(self):
+ split_file = osp.join(self.root, 'chairs_split.txt')
+ split_list = np.loadtxt(split_file, dtype=np.int32)
+ trainpairs = ['{:05d}'.format(i) for i in np.where(split_list==1)[0]+1]
+ valpairs = ['{:05d}'.format(i) for i in np.where(split_list==2)[0]+1]
+ assert len(trainpairs)==22232 and len(valpairs)==640, "incorrect parsing of pairs in FlyingChairs"
+ tosave = {'train': trainpairs, 'val': valpairs}
+ return tosave
+
+class FlyingThingsDataset(FlowDataset):
+
+ def _prepare_data(self):
+ self.name = "FlyingThings"
+ self._set_root()
+ assert self.split in [f'{set_}_{pass_}pass{camstr}' for set_ in ['train','test','test1024'] for camstr in ['','_rightcam'] for pass_ in ['clean','final','all']]
+ self.pairname_to_img1name = lambda pairname: osp.join(self.root, f'frames_{pairname[3]}pass', pairname[0].replace('into_future','').replace('into_past',''), '{:04d}.png'.format(pairname[1]))
+ self.pairname_to_img2name = lambda pairname: osp.join(self.root, f'frames_{pairname[3]}pass', pairname[0].replace('into_future','').replace('into_past',''), '{:04d}.png'.format(pairname[2]))
+ self.pairname_to_flowname = lambda pairname: osp.join(self.root, 'optical_flow', pairname[0], 'OpticalFlowInto{f:s}_{i:04d}_{c:s}.pfm'.format(f='Future' if 'future' in pairname[0] else 'Past', i=pairname[1], c='L' if 'left' in pairname[0] else 'R' ))
+ self.pairname_to_str = lambda pairname: os.path.join(pairname[3]+'pass', pairname[0], 'Into{f:s}_{i:04d}_{c:s}'.format(f='Future' if 'future' in pairname[0] else 'Past', i=pairname[1], c='L' if 'left' in pairname[0] else 'R' ))
+ self.load_flow = _read_pfm_flow
+
+ def _build_cache(self):
+ tosave = {}
+ # train and test splits for the different passes
+ for set_ in ['train', 'test']:
+ sroot = osp.join(self.root, 'optical_flow', set_.upper())
+ fname_to_i = lambda f: int(f[len('OpticalFlowIntoFuture_'):-len('_L.pfm')])
+ pp = [(osp.join(set_.upper(), d, s, 'into_future/left'),fname_to_i(fname)) for d in sorted(os.listdir(sroot)) for s in sorted(os.listdir(osp.join(sroot,d))) for fname in sorted(os.listdir(osp.join(sroot,d, s, 'into_future/left')))[:-1]]
+ pairs = [(a,i,i+1) for a,i in pp]
+ pairs += [(a.replace('into_future','into_past'),i+1,i) for a,i in pp]
+ assert len(pairs)=={'train': 40302, 'test': 7866}[set_], "incorrect parsing of pairs in FlyingThings"
+ for cam in ['left','right']:
+ camstr = '' if cam=='left' else f'_{cam}cam'
+ for pass_ in ['final', 'clean']:
+ tosave[f'{set_}_{pass_}pass{camstr}'] = [(a.replace('left',cam),i,j,pass_) for a,i,j in pairs]
+ tosave[f'{set_}_allpass{camstr}'] = tosave[f'{set_}_cleanpass{camstr}'] + tosave[f'{set_}_finalpass{camstr}']
+ # test1024: this is the same split as unimatch 'validation' split
+ # see https://github.com/autonomousvision/unimatch/blob/master/dataloader/flow/datasets.py#L229
+ test1024_nsamples = 1024
+ alltest_nsamples = len(tosave['test_cleanpass']) # 7866
+ stride = alltest_nsamples // test1024_nsamples
+ remove = alltest_nsamples % test1024_nsamples
+ for cam in ['left','right']:
+ camstr = '' if cam=='left' else f'_{cam}cam'
+ for pass_ in ['final','clean']:
+ tosave[f'test1024_{pass_}pass{camstr}'] = sorted(tosave[f'test_{pass_}pass{camstr}'])[:-remove][::stride] # warning, it was not sorted before
+ assert len(tosave['test1024_cleanpass'])==1024, "incorrect parsing of pairs in Flying Things"
+ tosave[f'test1024_allpass{camstr}'] = tosave[f'test1024_cleanpass{camstr}'] + tosave[f'test1024_finalpass{camstr}']
+ return tosave
+
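+# Worked numbers for the test1024 subsampling above, given the asserted 7866 test pairs:
+# stride = 7866 // 1024 = 7 and remove = 7866 % 1024 = 698, so dropping the last 698
+# pairs leaves 7168 = 7 * 1024 of them, and taking every 7th yields exactly 1024 pairs,
+# matching the assert.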
+
+class MPISintelDataset(FlowDataset):
+
+ def _prepare_data(self):
+ self.name = "MPISintel"
+ self._set_root()
+ assert self.split in [s+'_'+p for s in ['train','test','subval','subtrain'] for p in ['cleanpass','finalpass','allpass']]
+ self.pairname_to_img1name = lambda pairname: osp.join(self.root, pairname[0], 'frame_{:04d}.png'.format(pairname[1]))
+ self.pairname_to_img2name = lambda pairname: osp.join(self.root, pairname[0], 'frame_{:04d}.png'.format(pairname[1]+1))
+ self.pairname_to_flowname = lambda pairname: None if pairname[0].startswith('test/') else osp.join(self.root, pairname[0].replace('/clean/','/flow/').replace('/final/','/flow/'), 'frame_{:04d}.flo'.format(pairname[1]))
+ self.pairname_to_str = lambda pairname: osp.join(pairname[0], 'frame_{:04d}'.format(pairname[1]))
+ self.load_flow = _read_flo_file
+
+ def _build_cache(self):
+ trainseqs = sorted(os.listdir(self.root+'training/clean'))
+ trainpairs = [ (osp.join('training/clean', s),i) for s in trainseqs for i in range(1, len(os.listdir(self.root+'training/clean/'+s)))]
+ subvalseqs = ['temple_2','temple_3']
+ subtrainseqs = [s for s in trainseqs if s not in subvalseqs]
+ subvalpairs = [ (p,i) for p,i in trainpairs if any(s in p for s in subvalseqs)]
+ subtrainpairs = [ (p,i) for p,i in trainpairs if any(s in p for s in subtrainseqs)]
+ testseqs = sorted(os.listdir(self.root+'test/clean'))
+ testpairs = [ (osp.join('test/clean', s),i) for s in testseqs for i in range(1, len(os.listdir(self.root+'test/clean/'+s)))]
+ assert len(trainpairs)==1041 and len(testpairs)==552 and len(subvalpairs)==98 and len(subtrainpairs)==943, "incorrect parsing of pairs in MPI-Sintel"
+ tosave = {}
+ tosave['train_cleanpass'] = trainpairs
+ tosave['test_cleanpass'] = testpairs
+ tosave['subval_cleanpass'] = subvalpairs
+ tosave['subtrain_cleanpass'] = subtrainpairs
+ for t in ['train','test','subval','subtrain']:
+ tosave[t+'_finalpass'] = [(p.replace('/clean/','/final/'),i) for p,i in tosave[t+'_cleanpass']]
+ tosave[t+'_allpass'] = tosave[t+'_cleanpass'] + tosave[t+'_finalpass']
+ return tosave
+
+ def submission_save_pairname(self, pairname, prediction, outdir, _time):
+ assert prediction.shape[2]==2
+ outfile = os.path.join(outdir, 'submission', self.pairname_to_str(pairname)+'.flo')
+ os.makedirs( os.path.dirname(outfile), exist_ok=True)
+ writeFlowFile(prediction, outfile)
+
+ def finalize_submission(self, outdir):
+ assert self.split == 'test_allpass'
+ bundle_exe = "/nfs/data/ffs-3d/datasets/StereoFlow/MPI-Sintel/bundler/linux-x64/bundler" # eg